zetta_utils.tensor_ops

Tensor operations. Each operation must support both np.ndarray and torch.Tensor input types.

zetta_utils.tensor_ops.convert

zetta_utils.tensor_ops.convert.to_np(data)

Convert the given tensor to numpy.ndarray.

Parameters:

data (Union[Tensor, ndarray[tuple[Any, ...], dtype[TypeVar(_ScalarT, bound= generic)]]]) – Input tensor.

Return type:

ndarray[tuple[Any, ...], dtype[TypeVar(_ScalarT, bound= generic)]]

Returns:

Input tensor in numpy.ndarray format.
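The following is a minimal sketch of the conversion that to_np describes. It is written against plain NumPy and PyTorch rather than zetta_utils, so it is not the library's implementation, and the name to_np_sketch is hypothetical. It assumes a torch tensor should be detached and moved to CPU before calling .numpy(), which PyTorch requires for tensors that track gradients or live on a GPU.

```python
import numpy as np
import torch


def to_np_sketch(data):
    """Hypothetical stand-in for to_np: return `data` as a numpy.ndarray."""
    if isinstance(data, torch.Tensor):
        # .numpy() raises on tensors that require grad or live on a GPU,
        # so detach from the autograd graph and copy to host memory first.
        return data.detach().cpu().numpy()
    if isinstance(data, np.ndarray):
        # Already a NumPy array: return it unchanged.
        return data
    raise TypeError(f"Unsupported input type: {type(data).__name__}")


t = torch.ones(2, 3, requires_grad=True)
arr = to_np_sketch(t)
print(type(arr).__name__, arr.shape)  # ndarray (2, 3)
```

Passing an np.ndarray through returns the same object, so calling the conversion on data that is already in NumPy format costs nothing.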

zetta_utils.tensor_ops.convert.to_torch(data, device=None)

Convert the given tensor to torch.Tensor.

Parameters:
  • data (Union[Tensor, ndarray[tuple[Any, ...], dtype[TypeVar(_ScalarT, bound= generic)]]]) – Input tensor.

  • device (Union[device, str, int, None]) – Device on which the resulting torch tensor will reside, given as a torch.device, a device string, or a device index.

Return type:

Tensor

Returns:

Input tensor in torch.Tensor format.
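Below is a minimal sketch of the to_torch behavior using plain NumPy and PyTorch; it is not zetta_utils' own code, and the name to_torch_sketch is hypothetical. It assumes that device=None means "leave the tensor where it is", which is this sketch's choice rather than documented library behavior. Note that torch.from_numpy shares memory with the source array, so in-place edits to one are visible in the other until .to() makes a copy on a different device.

```python
import numpy as np
import torch


def to_torch_sketch(data, device=None):
    """Hypothetical stand-in for to_torch: return `data` as a torch.Tensor."""
    if isinstance(data, np.ndarray):
        # Zero-copy wrap: the tensor shares its buffer with the array.
        result = torch.from_numpy(data)
    elif isinstance(data, torch.Tensor):
        result = data
    else:
        raise TypeError(f"Unsupported input type: {type(data).__name__}")
    # Assumption: device=None keeps the tensor on its current device
    # (CPU for NumPy input); otherwise move it, copying only if needed.
    return result if device is None else result.to(device)


x = to_torch_sketch(np.arange(6, dtype=np.float32).reshape(2, 3), device="cpu")
print(x.dtype, x.device)  # torch.float32 cpu
```

The NumPy dtype carries over directly (float32 becomes torch.float32), so no implicit cast happens during conversion.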

zetta_utils.tensor_ops.convert.astype(data, reference, cast=False)

Convert the given tensor to np.ndarray or torch.Tensor, matching the type of the reference tensor.

Parameters:
  • data (Union[Tensor, ndarray[tuple[Any, ...], dtype[TypeVar(_ScalarT, bound= generic)]]]) – Input tensor.

  • reference (TypeVar(TensorTypeVar, Tensor, ndarray[tuple[Any, ...], dtype[TypeVar(_ScalarT, bound= generic)]], Union[Tensor, ndarray[tuple[Any, ...], dtype[TypeVar(_ScalarT, bound= generic)]]])) – Reference tensor whose type determines the output type.

  • cast (bool) – If True, cast data to the type of reference.

Return type:

TypeVar(TensorTypeVar, Tensor, ndarray[tuple[Any, ...], dtype[TypeVar(_ScalarT, bound= generic)]], Union[Tensor, ndarray[tuple[Any, ...], dtype[TypeVar(_ScalarT, bound= generic)]]])

Returns:

Input tensor converted to the reference type.
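A common use of astype is to let an operation accept either input type: compute in one framework, then hand the result back in the caller's format by passing the original input as reference. The sketch below illustrates that pattern with plain NumPy and PyTorch; astype_sketch and double are hypothetical names, not zetta_utils functions. It also assumes cast=True means "additionally match the dtype of reference", which is one reading of the parameter description above, not a confirmed detail of the library.

```python
import numpy as np
import torch


def astype_sketch(data, reference, cast=False):
    """Hypothetical stand-in for astype: match the container type of `reference`.

    Assumption: cast=True is read here as "also match reference's dtype".
    """
    if isinstance(reference, torch.Tensor):
        # Place the result on the same device as the reference tensor.
        result = torch.as_tensor(data, device=reference.device)
        return result.to(reference.dtype) if cast else result
    if isinstance(reference, np.ndarray):
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        result = np.asarray(data)
        return result.astype(reference.dtype) if cast else result
    raise TypeError(f"Unsupported reference type: {type(reference).__name__}")


def double(data):
    """Hypothetical op that accepts either type and returns the caller's type."""
    # Compute in torch, then convert back to whatever type `data` had.
    out = torch.as_tensor(data) * 2
    return astype_sketch(out, reference=data)


a = double(np.array([1, 2, 3]))
print(type(a).__name__, a)  # ndarray [2 4 6]

ref = torch.zeros(2, dtype=torch.float16)
b = astype_sketch(np.array([1.5, 2.5]), ref, cast=True)
print(b.dtype)  # torch.float16
```

Without cast, only the container type changes and the dtype of data is preserved; with cast, the result also adopts the reference dtype under the assumption stated above.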

zetta_utils.tensor_ops

zetta_utils.tensor_ops.interpolate(data, *args, targets=None, **kwargs)