zetta_utils.tensor_ops
Tensor operations. Each operation must support both np.ndarray and torch.Tensor as input types.
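A minimal sketch of what "support both input types" can look like for a single op. The op `flip_first_axis` and the local `TensorTypeVar` are hypothetical illustrations, not part of zetta_utils; the pattern assumed here is to normalize the input to torch, compute once, and return the caller's original container type.

```python
from typing import TypeVar

import numpy as np
import torch

# Hypothetical stand-in for the library's TensorTypeVar: either container type.
TensorTypeVar = TypeVar("TensorTypeVar", torch.Tensor, np.ndarray)


def flip_first_axis(data: TensorTypeVar) -> TensorTypeVar:
    """Hypothetical op: reverse `data` along axis 0, preserving the input type."""
    # Normalize to torch so the computation is written only once.
    data_torch = torch.as_tensor(data) if isinstance(data, np.ndarray) else data
    result = torch.flip(data_torch, dims=[0])
    # Hand back the same container type the caller passed in.
    if isinstance(data, np.ndarray):
        return result.numpy()
    return result
```

Writing the logic against one backend keeps each op in a single place; the `convert` helpers documented in this module appear suited to the normalize and restore steps.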
zetta_utils.tensor_ops.convert
zetta_utils.tensor_ops.convert.to_np(data)
Convert the given tensor to numpy.ndarray.
- Parameters
data (Union[Tensor, ndarray[Any, dtype[+_ScalarType_co]]]) – Input tensor.
- Return type
ndarray[Any, dtype[+_ScalarType_co]]
- Returns
Input tensor in numpy.ndarray format.
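A minimal sketch of the behavior described for `to_np`, assuming the common PyTorch idiom for tensor-to-array conversion. `to_np_sketch` is a hypothetical name, not the library's implementation.

```python
from typing import Union

import numpy as np
import torch


def to_np_sketch(data: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
    """Hypothetical illustration of converting either input type to numpy."""
    if isinstance(data, torch.Tensor):
        # Detach from the autograd graph and move to host memory before converting.
        return data.detach().cpu().numpy()
    # Already a numpy array: returned unchanged.
    return data
```

The `detach()` and `cpu()` calls matter because `Tensor.numpy()` raises an error on tensors that require grad or live on a GPU.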
zetta_utils.tensor_ops.convert.to_torch(data, device=None)
Convert the given tensor to torch.Tensor.
- Parameters
data (Union[Tensor, ndarray[Any, dtype[+_ScalarType_co]]]) – Input tensor.
device (Union[device, str, int, None]) – Device on which the torch tensor will reside.
- Return type
Tensor
- Returns
Input tensor in torch.Tensor format.
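A minimal sketch of the behavior described for `to_torch`, assuming numpy input is wrapped with `torch.from_numpy` and `device`, when given, is applied with `Tensor.to`. `to_torch_sketch` is a hypothetical name; whether the library shares memory with the source array is not stated in this documentation.

```python
from typing import Optional, Union

import numpy as np
import torch


def to_torch_sketch(
    data: Union[torch.Tensor, np.ndarray],
    device: Union[torch.device, str, int, None] = None,
) -> torch.Tensor:
    """Hypothetical illustration of converting either input type to torch."""
    if isinstance(data, np.ndarray):
        # torch.from_numpy shares memory with the array (no copy on CPU).
        result = torch.from_numpy(data)
    else:
        result = data
    # Only move the tensor when a device was requested.
    if device is not None:
        result = result.to(device)
    return result
```

Because `torch.from_numpy` shares memory, later in-place writes to the source array are visible through the returned CPU tensor in this sketch.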
zetta_utils.tensor_ops.convert.astype(data, reference, cast=False)
Convert the given tensor to np.ndarray or torch.Tensor, depending on the type of the reference tensor.
- Parameters
data (Union[Tensor, ndarray[Any, dtype[+_ScalarType_co]]]) – Input tensor.
reference (~TensorTypeVar) – Reference tensor whose type determines the output type.
cast (bool) – If True, cast data to the type of reference.
- Return type
~TensorTypeVar
- Returns
Input tensor converted to the reference type.
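A minimal sketch of `astype`, under two labeled assumptions not confirmed by this documentation: `cast=True` is read as also matching the reference's dtype, and a torch result is placed on the reference tensor's device. `astype_sketch` is a hypothetical name.

```python
from typing import Union

import numpy as np
import torch


def astype_sketch(
    data: Union[torch.Tensor, np.ndarray],
    reference: Union[torch.Tensor, np.ndarray],
    cast: bool = False,
) -> Union[torch.Tensor, np.ndarray]:
    """Hypothetical illustration: match the container type of `reference`."""
    if isinstance(reference, torch.Tensor):
        result = torch.as_tensor(data) if isinstance(data, np.ndarray) else data
        # Assumption: place the result on the reference tensor's device.
        result = result.to(reference.device)
        if cast:
            # Assumption: "cast" means matching the reference dtype.
            result = result.to(reference.dtype)
        return result
    if isinstance(reference, np.ndarray):
        result = data.detach().cpu().numpy() if isinstance(data, torch.Tensor) else data
        if cast:
            result = result.astype(reference.dtype)
        return result
    raise TypeError(f"Unsupported reference type: {type(reference)}")
```

Without `cast`, only the container type changes in this sketch; the element dtype carries over from `data`.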