zetta_utils.tensor_ops

Tensor operations. Each operation must support both np.ndarray and torch.Tensor inputs.

zetta_utils.tensor_ops.convert

zetta_utils.tensor_ops.convert.to_np(data)

Convert the given tensor to numpy.ndarray.

Parameters

data (Union[Tensor, ndarray[Any, dtype[+_ScalarType_co]]]) – Input tensor.

Return type

ndarray[Any, dtype[+_ScalarType_co]]

Returns

Input tensor in numpy.ndarray format.
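The contract above can be illustrated with a minimal sketch. It is not the library's implementation: to_np_sketch is a hypothetical name, and detaching from autograd and moving to the CPU before calling .numpy() is an assumption about how a torch.Tensor would reasonably be converted.

```python
from typing import Union

import numpy as np
import numpy.typing as npt
import torch


def to_np_sketch(data: Union[torch.Tensor, npt.NDArray]) -> npt.NDArray:
    """Hypothetical illustration of the to_np contract (not the library code)."""
    if isinstance(data, torch.Tensor):
        # Assumption: detach and move to host memory first, because
        # Tensor.numpy() raises on tensors that require grad or live on a GPU.
        return data.detach().cpu().numpy()
    if isinstance(data, np.ndarray):
        # Already a NumPy array: returned unchanged in this sketch.
        return data
    raise TypeError(f"Unsupported input type: {type(data)!r}")


# Usage: a tensor that requires grad still converts cleanly.
t = torch.ones(2, 3, requires_grad=True)
a = to_np_sketch(t)
print(type(a).__name__, a.shape)  # ndarray (2, 3)
```

For a CPU tensor, the array returned by this sketch shares memory with the tensor, since .numpy() does not copy.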

zetta_utils.tensor_ops.convert.to_torch(data, device=None)

Convert the given tensor to torch.Tensor.

Parameters
  • data (Union[Tensor, ndarray[Any, dtype[+_ScalarType_co]]]) – Input tensor.

  • device (Union[device, str, int, None]) – Device on which the returned tensor will reside, given as a torch.device, a device string, or a device index.

Return type

Tensor

Returns

Input tensor in torch.Tensor format.
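A minimal sketch of the same contract for to_torch, assuming the conversion goes through torch.from_numpy and that device=None leaves the tensor on its current device. to_torch_sketch is a hypothetical name; the real function may differ in detail.

```python
from typing import Union

import numpy as np
import numpy.typing as npt
import torch

DeviceLike = Union[torch.device, str, int, None]


def to_torch_sketch(
    data: Union[torch.Tensor, npt.NDArray], device: DeviceLike = None
) -> torch.Tensor:
    """Hypothetical illustration of the to_torch contract (not the library code)."""
    if isinstance(data, np.ndarray):
        # torch.from_numpy wraps the array without copying; dtype is preserved.
        result = torch.from_numpy(data)
    elif isinstance(data, torch.Tensor):
        result = data
    else:
        raise TypeError(f"Unsupported input type: {type(data)!r}")
    # Assumption: device=None means "leave the tensor where it is".
    if device is not None:
        result = result.to(device)
    return result


# Usage: a float32 array becomes a float32 CPU tensor.
arr = np.arange(6, dtype=np.float32).reshape(2, 3)
t = to_torch_sketch(arr, device="cpu")
print(t.dtype, tuple(t.shape))  # torch.float32 (2, 3)
```

Because torch.from_numpy shares memory with the source array, in this sketch writes to the array stay visible through the tensor unless a device move forces a copy.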

zetta_utils.tensor_ops.convert.astype(data, reference, cast=False)

Convert the given tensor to np.ndarray or torch.Tensor, matching the type of the reference tensor.

Parameters
  • data (Union[Tensor, ndarray[Any, dtype[+_ScalarType_co]]]) – Input tensor.

  • reference (~TensorTypeVar) – Reference tensor whose type the output will match.

  • cast (bool) – If True, cast data to the type of reference.

Return type

~TensorTypeVar

Returns

Input tensor converted to the reference type.
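The sketch below illustrates astype together with the pattern that lets one operation accept both input types: convert to torch, compute, then convert back using the original input as the reference. astype_sketch and add_one are hypothetical names, and the TensorTypeVar definition here is a stand-in. Two behaviors are assumptions not stated in the docstring above: that cast=True matches the reference's dtype, and that torch results are placed on the reference's device.

```python
from typing import TypeVar, Union

import numpy as np
import numpy.typing as npt
import torch

# Stand-in for the library's TensorTypeVar (assumed definition).
TensorTypeVar = TypeVar("TensorTypeVar", torch.Tensor, np.ndarray)


def astype_sketch(
    data: Union[torch.Tensor, npt.NDArray], reference: TensorTypeVar, cast: bool = False
) -> TensorTypeVar:
    """Hypothetical illustration of the astype contract (not the library code)."""
    if isinstance(reference, torch.Tensor):
        result = data if isinstance(data, torch.Tensor) else torch.from_numpy(data)
        # Assumption: torch results follow the reference's device.
        result = result.to(reference.device)
        if cast:
            # Assumption: cast=True means "match the reference dtype".
            result = result.to(reference.dtype)
        return result
    if isinstance(reference, np.ndarray):
        result = data.detach().cpu().numpy() if isinstance(data, torch.Tensor) else data
        if cast:
            result = result.astype(reference.dtype)
        return result
    raise TypeError(f"Unsupported reference type: {type(reference)!r}")


def add_one(data: TensorTypeVar) -> TensorTypeVar:
    """Hypothetical op: compute in torch, return the caller's input type."""
    result = astype_sketch(data, torch.empty(0)) + 1
    return astype_sketch(result, data)


x = np.array([1, 2, 3])
y = add_one(x)
print(type(y).__name__, y.tolist())  # ndarray [2, 3, 4]

z = astype_sketch(torch.tensor([1, 2]), np.zeros(1, dtype=np.float32), cast=True)
print(z.dtype)  # float32
```

Using the input itself as the reference is what keeps the output type aligned with the caller's type, so the operation body only has to be written once against torch.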
