zetta_utils.tensor_ops
Utilities for tensor operations. Each operation must support both np.ndarray and torch.Tensor input types.
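The requirement above means every op has to accept either container type and hand back something the caller can use. Here is a minimal sketch of one way to meet it. `normalize_minmax` is a hypothetical op that is not part of zetta_utils. Routing NumPy input through `torch.from_numpy` and back through `Tensor.numpy()` is also an assumption for illustration, not the library's actual mechanism.

```python
import numpy as np
import torch


def normalize_minmax(data):
    """Hypothetical tensor op: rescale values to the range [0, 1].

    Accepts np.ndarray or torch.Tensor and returns the same type it received.
    """
    was_numpy = isinstance(data, np.ndarray)
    # Do the arithmetic in torch, wrapping NumPy input without copying.
    tensor = torch.from_numpy(data) if was_numpy else data
    # Promote to float so the division below is well defined.
    tensor = tensor.to(torch.float32)
    lo, hi = tensor.min(), tensor.max()
    if hi > lo:
        result = (tensor - lo) / (hi - lo)
    else:
        # Constant input: avoid dividing by zero.
        result = torch.zeros_like(tensor)
    # Return the result in the caller's original container type.
    return result.numpy() if was_numpy else result
```

Calling `normalize_minmax(np.array([2.0, 4.0, 6.0]))` yields an `np.ndarray`, and passing a `torch.Tensor` yields a `torch.Tensor`. The op itself is written only once.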
zetta_utils.tensor_ops.convert
- zetta_utils.tensor_ops.convert.to_torch(data, device=None)
  Convert the given tensor to torch.Tensor.
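Below is a minimal sketch of what `to_torch` might do. It is not the zetta_utils source. It assumes two behaviors that the reference above does not specify. First, `device=None` leaves the tensor on its current device. Second, NumPy input is wrapped with `torch.from_numpy`, which shares memory with the array rather than copying it.

```python
from typing import Optional, Union

import numpy as np
import torch


def to_torch(
    data: Union[np.ndarray, torch.Tensor],
    device: Optional[Union[str, torch.device]] = None,
) -> torch.Tensor:
    """Sketch: convert `data` to torch.Tensor, optionally moving it to `device`."""
    if isinstance(data, torch.Tensor):
        result = data
    elif isinstance(data, np.ndarray):
        # from_numpy shares memory with the array (no copy on CPU).
        result = torch.from_numpy(data)
    else:
        raise TypeError(f"Expected np.ndarray or torch.Tensor, got {type(data).__name__}")
    # Assumption: only move the tensor when a device is explicitly requested.
    return result if device is None else result.to(device)
```

For example, `to_torch(np.zeros((2, 3)), device="cpu")` returns a CPU `torch.Tensor` of shape `(2, 3)` whose dtype is preserved (`torch.float64`).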
- zetta_utils.tensor_ops.convert.astype(data, reference, cast=False)
  Convert the given tensor to np.ndarray or torch.Tensor, depending on the type of the reference tensor.
  - Parameters:
    - data (Union[Tensor, ndarray[tuple[Any, ...], dtype[TypeVar(_ScalarT, bound=generic)]]]) – Input tensor.
    - reference (TypeVar(TensorTypeVar, Tensor, ndarray[tuple[Any, ...], dtype[TypeVar(_ScalarT, bound=generic)]], Union[Tensor, ndarray[tuple[Any, ...], dtype[TypeVar(_ScalarT, bound=generic)]]])) – Reference type tensor.
    - cast (bool) – If True, cast data to the type of reference.
  - Return type:
    TypeVar(TensorTypeVar, Tensor, ndarray[tuple[Any, ...], dtype[TypeVar(_ScalarT, bound=generic)]], Union[Tensor, ndarray[tuple[Any, ...], dtype[TypeVar(_ScalarT, bound=generic)]]])
  - Returns:
    Input tensor converted to the reference type.
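The following sketch illustrates the `astype` contract described above. It is not the zetta_utils implementation, and it rests on two assumptions. First, `cast=True` is read as adopting the reference's dtype, since the docs only say "the type of reference". Second, tensors are never moved between devices, apart from the `cpu()` call that NumPy conversion requires.

```python
from typing import TypeVar, Union

import numpy as np
import torch

# Mirrors the documented TensorTypeVar: either a torch.Tensor or an np.ndarray.
TensorTypeVar = TypeVar("TensorTypeVar", torch.Tensor, np.ndarray)


def astype(
    data: Union[torch.Tensor, np.ndarray],
    reference: TensorTypeVar,
    cast: bool = False,
) -> TensorTypeVar:
    """Sketch: return `data` in the same container type as `reference`."""
    if isinstance(reference, torch.Tensor):
        # NumPy input is wrapped without copying; torch input passes through.
        result = data if isinstance(data, torch.Tensor) else torch.from_numpy(data)
        if cast:
            # Assumption: "cast" means adopting the reference dtype.
            result = result.to(reference.dtype)
        return result
    if isinstance(reference, np.ndarray):
        if isinstance(data, torch.Tensor):
            # detach() drops autograd history; cpu() is required before numpy().
            result = data.detach().cpu().numpy()
        else:
            result = data
        if cast:
            # Assumption: "cast" means adopting the reference dtype.
            result = result.astype(reference.dtype)
        return result
    raise TypeError(f"Unsupported reference type: {type(reference).__name__}")
```

One pattern this enables is to compute in whichever framework is convenient and then call `astype(result, original_input)`, so the caller gets back the same container type it passed in.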