zetta_utils.tensor_ops
Utilities for tensor operations. Each operation must support both np.ndarray and torch.Tensor input types.
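The dual-type requirement can be illustrated with a small example. The function `clip_to_unit` and the type variable `T` below are hypothetical and not part of zetta_utils. They only show one way an operation can accept either np.ndarray or torch.Tensor and return the same type it received, by dispatching on the input type.

```python
from typing import TypeVar

import numpy as np
import torch

# Hypothetical constrained type variable: the output type matches the input type.
T = TypeVar("T", np.ndarray, torch.Tensor)


def clip_to_unit(data: T) -> T:
    """Clip values to [0, 1] and return the same tensor type that was passed in.

    Hypothetical example op, not part of zetta_utils.
    """
    if isinstance(data, torch.Tensor):
        # torch.clamp keeps the result as a torch.Tensor on the same device.
        return torch.clamp(data, 0.0, 1.0)
    if isinstance(data, np.ndarray):
        # np.clip keeps the result as an np.ndarray.
        return np.clip(data, 0.0, 1.0)
    raise TypeError(f"Unsupported tensor type: {type(data)!r}")
```

Calling `clip_to_unit(np.array([-1.0, 0.5, 2.0]))` returns an np.ndarray, and calling it on `torch.tensor([-1.0, 0.5, 2.0])` returns a torch.Tensor. An alternative design converts every input to one type, operates on it, and converts the result back to the caller's type.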
zetta_utils.tensor_ops.convert
- zetta_utils.tensor_ops.convert.to_torch(data, device=None)
Convert the given tensor to torch.Tensor.
- Parameters:
- Return type:
Tensor
- Returns:
Input tensor in torch.Tensor format.
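The following is a minimal sketch of what a conversion like `to_torch` involves. It is not the zetta_utils implementation. The name `to_torch_sketch` is hypothetical, and it assumes that `device` accepts anything torch treats as a device spec (for example `"cpu"` or `torch.device("cuda:0")`), with `None` meaning "leave the tensor where it is".

```python
from typing import Optional, Union

import numpy as np
import torch


def to_torch_sketch(
    data: Union[np.ndarray, torch.Tensor],
    device: Optional[Union[str, torch.device]] = None,
) -> torch.Tensor:
    """Convert data to torch.Tensor, optionally moving it to `device`.

    Hypothetical sketch; `device` is assumed to be a torch device spec.
    """
    if isinstance(data, torch.Tensor):
        result = data
    elif isinstance(data, np.ndarray):
        # torch.from_numpy rejects negative strides (e.g. arr[::-1]), so copy
        # non-contiguous arrays first. Contiguous arrays are not copied, and
        # the resulting tensor shares memory with them.
        array = data if data.flags.c_contiguous else np.ascontiguousarray(data)
        result = torch.from_numpy(array)
    else:
        raise TypeError(f"Unsupported tensor type: {type(data)!r}")
    if device is not None:
        result = result.to(device)
    return result
```

Because `torch.from_numpy` shares memory with contiguous arrays, writing to the returned tensor can also modify the original array. Whether the real `to_torch` copies or shares memory is not stated in this reference.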
- zetta_utils.tensor_ops.convert.astype(data, reference, cast=False)
Convert the given tensor to np.ndarray or torch.Tensor, depending on the type of the reference tensor.
- Parameters:
  - data (Union[Tensor, ndarray[tuple[int, ...], dtype[TypeVar(_ScalarType_co, bound=generic, covariant=True)]]]) – Input tensor.
  - reference (TypeVar(TensorTypeVar, Tensor, ndarray[tuple[int, ...], dtype[TypeVar(_ScalarType_co, bound=generic, covariant=True)]], Union[Tensor, ndarray[tuple[int, ...], dtype[TypeVar(_ScalarType_co, bound=generic, covariant=True)]]])) – Reference tensor whose type determines the output type.
  - cast (bool) – If True, cast data to the type of reference.
- Return type:
TypeVar(TensorTypeVar, Tensor, ndarray[tuple[int, ...], dtype[TypeVar(_ScalarType_co, bound=generic, covariant=True)]], Union[Tensor, ndarray[tuple[int, ...], dtype[TypeVar(_ScalarType_co, bound=generic, covariant=True)]]])
- Returns:
Input tensor converted to the reference type.
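The following is a rough sketch of how a function like `astype` can dispatch on the reference type. It is not the library's implementation. The names `astype_sketch` and `R` are hypothetical. The sketch loudly assumes two behaviors that this reference does not spell out: `cast=True` is taken to mean "also match the reference's dtype", and a torch result is placed on the reference tensor's device.

```python
from typing import TypeVar, Union

import numpy as np
import torch

TensorLike = Union[np.ndarray, torch.Tensor]
# Hypothetical stand-in for the documented TensorTypeVar.
R = TypeVar("R", np.ndarray, torch.Tensor)


def astype_sketch(data: TensorLike, reference: R, cast: bool = False) -> R:
    """Return `data` in the same container type as `reference`.

    Hypothetical sketch. ASSUMPTIONS: cast=True matches reference's dtype,
    and torch results are moved to reference's device.
    """
    if not isinstance(data, (np.ndarray, torch.Tensor)):
        raise TypeError(f"Unsupported data type: {type(data)!r}")
    if isinstance(reference, torch.Tensor):
        if isinstance(data, np.ndarray):
            # torch.from_numpy needs non-negative strides; copy only if required.
            array = data if data.flags.c_contiguous else np.ascontiguousarray(data)
            data = torch.from_numpy(array)
        result = data.to(reference.device)
        if cast:
            result = result.to(reference.dtype)
        return result
    if isinstance(reference, np.ndarray):
        if isinstance(data, torch.Tensor):
            # .numpy() requires a CPU tensor that does not track gradients.
            data = data.detach().cpu().numpy()
        return data.astype(reference.dtype) if cast else data
    raise TypeError(f"Unsupported reference type: {type(reference)!r}")
```

With `cast=False`, only the container type changes. For example, a float32 torch.Tensor converted against a float64 np.ndarray reference stays float32. Under the dtype assumption above, passing `cast=True` would make it float64.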