zetta_utils.tensor_ops
Utilities for handling tensor operations. Each operation must support both np.ndarray and torch.Tensor input types.
zetta_utils.tensor_ops.convert
- zetta_utils.tensor_ops.convert.to_torch(data, device=None)
Convert the given tensor to torch.Tensor.
- Parameters:
- Return type:
Tensor
- Returns:
Input tensor in torch.Tensor format.
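The conversion described above can be illustrated with a small stand-in. This is only a sketch under stated assumptions, not the library's implementation: `to_torch_sketch` is a hypothetical name, and the handling of `device` (leaving the tensor where it is when `device` is None) is an assumption, since this entry does not document the parameters.

```python
import numpy as np
import torch


def to_torch_sketch(data, device=None):
    """Hypothetical stand-in mirroring the documented to_torch signature."""
    if isinstance(data, torch.Tensor):
        # Already a torch.Tensor: nothing to convert.
        result = data
    else:
        # torch.from_numpy shares memory with the source array.
        result = torch.from_numpy(data)
    # Assumption: device=None leaves the tensor on its current device.
    if device is not None:
        result = result.to(device)
    return result


arr = np.arange(6, dtype=np.float32).reshape(2, 3)
converted = to_torch_sketch(arr)
already_torch = to_torch_sketch(torch.zeros(2, 3), device="cpu")
```

Accepting both input types in one function lets callers avoid branching on whether their data is an np.ndarray or a torch.Tensor.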
- zetta_utils.tensor_ops.convert.astype(data, reference, cast=False)
Convert the given tensor to np.ndarray or torch.Tensor, depending on the type of the reference tensor.
- Parameters:
data (Union[Tensor, ndarray[tuple[int, ...], dtype[TypeVar(_ScalarType_co, bound=generic, covariant=True)]]]) – Input tensor.
reference (TypeVar(TensorTypeVar, Tensor, ndarray[tuple[int, ...], dtype[TypeVar(_ScalarType_co, bound=generic, covariant=True)]], Union[Tensor, ndarray[tuple[int, ...], dtype[TypeVar(_ScalarType_co, bound=generic, covariant=True)]]])) – Tensor of the reference type.
cast (bool) – If True, cast data to the type of reference.
- Return type:
TypeVar(TensorTypeVar, Tensor, ndarray[tuple[int, ...], dtype[TypeVar(_ScalarType_co, bound=generic, covariant=True)]], Union[Tensor, ndarray[tuple[int, ...], dtype[TypeVar(_ScalarType_co, bound=generic, covariant=True)]]])
- Returns:
Input tensor converted to the reference type.
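The reference-driven conversion can be sketched with a stand-in as well. `astype_sketch` is a hypothetical name and not the library's code; in particular, reading "the type of reference" in the `cast` description as the reference's dtype is an assumption made for this illustration.

```python
import numpy as np
import torch


def astype_sketch(data, reference, cast=False):
    """Hypothetical stand-in mirroring the documented astype signature."""
    if isinstance(reference, torch.Tensor):
        # Reference is a torch.Tensor, so the result is one too.
        result = data if isinstance(data, torch.Tensor) else torch.from_numpy(data)
        # Assumption: cast=True means "match the reference dtype".
        if cast:
            result = result.to(reference.dtype)
    else:
        # Reference is an np.ndarray, so the result is one too.
        if isinstance(data, torch.Tensor):
            result = data.detach().cpu().numpy()
        else:
            result = data
        if cast:
            result = result.astype(reference.dtype)
    return result


data_np = np.ones((2, 2), dtype=np.float64)
ref_torch = torch.zeros(1, dtype=torch.float32)

# Container type follows the reference; dtype is kept unless cast=True.
as_torch = astype_sketch(data_np, ref_torch)
as_torch_cast = astype_sketch(data_np, ref_torch, cast=True)
as_np = astype_sketch(torch.ones(2, 2), np.zeros(1))
```

A typical use of this pattern is to run an operation in one framework and hand the result back in whatever format the caller originally supplied.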