Source code for zetta_utils.tensor_ops.common

# pylint: disable=missing-docstring
from typing import (
    Any,
    Callable,
    Container,
    Generic,
    Literal,
    Mapping,
    Optional,
    Sequence,
    SupportsIndex,
    TypeVar,
    Union,
    overload,
)

import attrs
import einops
import numpy as np
import tinybrain
import torch
from numpy import typing as npt
from typeguard import typechecked
from typing_extensions import Concatenate, ParamSpec

from zetta_utils import builder, tensor_ops
from zetta_utils.tensor_typing import Tensor, TensorTypeVar

P = ParamSpec("P")
T = TypeVar("T")


@attrs.frozen
class DictSupportingTensorOp(Generic[P]):
    fn: Callable  # [Concatenate[TensorTypeVar, P], TensorTypeVar]

    @overload
    def __call__(
        self,
        data: TensorTypeVar,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> TensorTypeVar:
        ...

    @overload
    def __call__(
        self,
        data: Mapping[Any, TensorTypeVar],
        *args: P.args,
        targets: Container[str] | None = None,
        **kwargs: P.kwargs,
    ) -> dict[Any, TensorTypeVar]:
        ...

    @typechecked
    def __call__(
        self, data, *args: P.args, targets: Container[str] | None = None, **kwargs: P.kwargs
    ):
        if isinstance(data, Mapping):
            new_data = {}
            for k in data.keys():
                if targets is None or k in targets:
                    new_data[k] = self.fn(data[k], *args, **kwargs)
                else:
                    new_data[k] = data[k]
            return new_data
        else:
            if targets is not None:
                raise RuntimeError("`data` must be a Mapping when `targets` is specified.")
            return self.fn(data, *args, **kwargs)


@overload
def supports_dict(
    fn: Callable[Concatenate[npt.NDArray, P], npt.NDArray]
) -> DictSupportingTensorOp[P]:
    ...


@overload
def supports_dict(
    fn: Callable[Concatenate[torch.Tensor, P], torch.Tensor]
) -> DictSupportingTensorOp[P]:
    ...


def supports_dict(fn):
    return DictSupportingTensorOp(fn)
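

# Illustrative usage sketch (assumption: this example is not part of the
# original module). A `supports_dict`-wrapped op applies either to a bare
# tensor or key-by-key to a mapping; `targets` restricts which keys are
# transformed, and the remaining entries pass through unchanged.
def _example_supports_dict():  # hypothetical helper, for illustration only
    batch = {"img": np.ones((2, 2)), "seg": np.ones((2, 2))}
    out = multiply(batch, 2, targets=["img"])  # `multiply` is defined below
    assert out["img"][0, 0] == 2  # transformed
    assert out["seg"][0, 0] == 1  # passed through unchanged
    assert multiply(np.ones((2, 2)), 2)[0, 0] == 2  # plain tensor path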


@builder.register("rearrange")
@supports_dict
def rearrange(data: TensorTypeVar, pattern: str, **kwargs) -> TensorTypeVar:  # pragma: no cover
    return einops.rearrange(  # type: ignore # bad typing by einops
        tensor=data, pattern=pattern, **kwargs
    )


@builder.register("reduce")
@supports_dict
def reduce(data: TensorTypeVar, **kwargs) -> TensorTypeVar:  # pragma: no cover
    return einops.reduce(tensor=data, **kwargs)  # type: ignore # bad typing by einops


@builder.register("repeat")
@supports_dict
def repeat(data: TensorTypeVar, **kwargs) -> TensorTypeVar:  # pragma: no cover
    return einops.repeat(tensor=data, **kwargs)  # type: ignore # bad typing by einops


@builder.register("multiply")
@supports_dict
def multiply(data: TensorTypeVar, value) -> TensorTypeVar:  # pragma: no cover
    return value * data


@builder.register("add")
@supports_dict
def add(data: TensorTypeVar, value) -> TensorTypeVar:  # pragma: no cover
    return value + data


@builder.register("power")
@supports_dict
def power(data: TensorTypeVar, value) -> TensorTypeVar:  # pragma: no cover
    return data ** value


@builder.register("divide")
@supports_dict
def divide(data: TensorTypeVar, value) -> TensorTypeVar:  # pragma: no cover
    return data / value


@builder.register("int_divide")
@supports_dict
def int_divide(data: TensorTypeVar, value) -> TensorTypeVar:  # pragma: no cover
    return data // value


@builder.register("unsqueeze")
@supports_dict
@typechecked
def unsqueeze(
    data: TensorTypeVar, dim: Union[SupportsIndex, Sequence[SupportsIndex]] = 0
) -> TensorTypeVar:
    """
    Returns a new tensor with new dimensions of size one inserted at the specified positions.

    :param data: the input tensor.
    :param dim: indexes at which to insert new dimensions.
    :return: tensor with added dimensions.
    """
    if isinstance(data, torch.Tensor):
        if isinstance(dim, int):
            result = data.unsqueeze(dim)  # type: TensorTypeVar
        else:
            raise ValueError(f"Cannot use `torch.unsqueeze` with dim of type '{type(dim)}'")
    else:
        assert isinstance(data, np.ndarray), "Type checking failure"
        result = np.expand_dims(data, dim)

    return result
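

# Illustrative sketch (not part of the original module): `unsqueeze` accepts
# a single index for torch tensors, while numpy data additionally supports a
# sequence of indices via `np.expand_dims`.
def _example_unsqueeze():  # hypothetical helper, for illustration only
    assert unsqueeze(torch.zeros(3, 4), 0).shape == (1, 3, 4)
    assert unsqueeze(np.zeros((3, 4)), (0, 1)).shape == (1, 1, 3, 4)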


@builder.register("squeeze")
@supports_dict
@typechecked
def squeeze(
    data: TensorTypeVar,
    dim: Optional[Union[SupportsIndex, Sequence[SupportsIndex]]] = None,
) -> TensorTypeVar:
    """
    Returns a tensor with all the dimensions of input of size 1 removed.
    When dim is given, a squeeze operation is done only for the given dimensions.

    :param data: the input tensor.
    :param dim: if given, the input will be squeezed only in these dimensions.
    :return: tensor with squeezed dimensions.
    """

    if isinstance(data, torch.Tensor):
        if isinstance(dim, int) or dim is None:
            result = data.squeeze(dim)
        else:
            raise ValueError(f"Cannot use `torch.squeeze` with dim of type '{type(dim)}'")
    else:
        assert isinstance(data, np.ndarray)
        result = data.squeeze(axis=dim)  # type: ignore # mypy thinks None is not ok, but it is

    return result
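

# Illustrative sketch (not part of the original module): with `dim=None` all
# singleton dimensions are removed; with an explicit `dim`, only that one is.
def _example_squeeze():  # hypothetical helper, for illustration only
    assert squeeze(np.zeros((1, 3, 1))).shape == (3,)
    assert squeeze(torch.zeros(1, 3, 1), 0).shape == (3, 1)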


TorchInterpolationMode = Literal[
    "nearest",
    "nearest-exact",
    "linear",
    "bilinear",
    "bicubic",
    "trilinear",
    "area",
]
CustomInterpolationMode = Literal[
    "img",
    "field",
    "mask",
    "segmentation",
]
InterpolationMode = Union[TorchInterpolationMode, CustomInterpolationMode]


def _standardize_scale_factor(
    data_ndim: int,
    scale_factor: Optional[Union[float, Sequence[float]]] = None,
) -> Optional[Sequence[float]]:
    if scale_factor is None:
        result = None
    else:
        data_space_ndim = data_ndim - 2  # Batch + Channel
        if isinstance(scale_factor, (float, int)):
            result = (scale_factor,) * data_space_ndim
        else:
            result = tuple(scale_factor)
            while len(result) < data_space_ndim:
                result = (1.0,) + result

    return result
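

# Worked example (not part of the original module): for 5D B C X Y Z data
# there are 3 spatial dims. A scalar factor is broadcast to all of them, and
# a short sequence is left-padded with 1.0 so that it aligns with the
# trailing spatial dimensions.
def _example_standardize_scale_factor():  # hypothetical, for illustration only
    assert _standardize_scale_factor(5, 0.5) == (0.5, 0.5, 0.5)
    assert _standardize_scale_factor(5, (0.5, 0.5)) == (1.0, 0.5, 0.5)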


def _get_torch_interp_mode(
    spatial_ndim: int,
    mode: InterpolationMode,
) -> TorchInterpolationMode:
    if mode in ("img", "field"):
        if spatial_ndim == 3:
            torch_interp_mode = "trilinear"  # type: TorchInterpolationMode
        elif spatial_ndim == 2:
            torch_interp_mode = "bilinear"
        else:
            assert spatial_ndim == 1, "Setting validation error."
            torch_interp_mode = "linear"
    elif mode == "mask":
        torch_interp_mode = "area"
    elif mode == "segmentation":
        torch_interp_mode = "nearest-exact"
    else:
        torch_interp_mode = mode  # type: ignore # has to fit at this point

    return torch_interp_mode
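

# Illustrative sketch (not part of the original module): the custom modes
# resolve to torch modes based on the number of spatial dimensions.
def _example_get_torch_interp_mode():  # hypothetical, for illustration only
    assert _get_torch_interp_mode(2, "img") == "bilinear"
    assert _get_torch_interp_mode(3, "field") == "trilinear"
    assert _get_torch_interp_mode(2, "segmentation") == "nearest-exact"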


def _get_nearest_interp_compatible_view(data: torch.Tensor):
    # Bypass F.interpolate dtype restrictions for Nearest-Neighbor interpolation.
    # This works because NN interpolation is only index magic, values are not used.
    # See https://github.com/pytorch/pytorch/issues/5580
    mapping = {
        torch.bool: torch.uint8,
        torch.int8: torch.uint8,
        torch.int16: torch.float16,
        torch.int32: torch.float32,
        torch.int64: torch.float64,
    }

    if data.dtype in mapping:
        return data.view(dtype=mapping[data.dtype])
    return data
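

# Illustrative sketch (not part of the original module): an int64 tensor is
# reinterpreted as float64 (same element width) so `F.interpolate` accepts it;
# nearest-neighbor interpolation only copies values, so viewing the result
# back as int64 recovers exact labels.
def _example_nearest_compatible_view():  # hypothetical, for illustration only
    labels = torch.arange(4, dtype=torch.int64).reshape(1, 1, 4)
    view = _get_nearest_interp_compatible_view(labels)
    assert view.dtype == torch.float64
    up = torch.nn.functional.interpolate(view, scale_factor=2.0, mode="nearest")
    assert up.view(dtype=torch.int64).tolist() == [[[0, 0, 1, 1, 2, 2, 3, 3]]]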


def _validate_interpolation_setting(
    data: Tensor,
    size: Optional[Sequence[int]],
    scale_factor_tuple: Optional[Sequence[float]],
    allow_slice_rounding: bool,
):
    # Torch checks for some of these, but we need to check preemptively
    # as some of our pre-processing code assumes a valid setting.

    if data.ndim > 5:
        raise ValueError(f"float of dimensions must be <= 5. Got: {data.ndim}")

    if scale_factor_tuple is None and size is None:
        raise ValueError("Neither `size` nor `scale_factor` provided to `interpolate()`")
    if scale_factor_tuple is not None and size is not None:
        raise ValueError(
            "Both `size` and `scale_factor` provided to `interpolate()`. "
            "Exactly one of them must be provided."
        )

    spatial_ndim = data.ndim - 2
    if size is not None:
        if len(size) != spatial_ndim:
            raise ValueError(
                "`len(size)` must be equal to `data.ndim - 2`. "
                f"Got `len(size)` == {len(size)},  `data.ndim` == {data.ndim}."
            )

    if scale_factor_tuple is not None:
        if not allow_slice_rounding:
            result_spatial_shape = [
                data.shape[2 + i] * scale_factor_tuple[i] for i in range(spatial_ndim)
            ]
            for i in range(spatial_ndim):
                if round(result_spatial_shape[i]) != result_spatial_shape[i]:
                    raise RuntimeError(
                        f"Interpolation of array with shape {data.shape} and scale "
                        f"factor {scale_factor_tuple} would result in a non-integer shape "
                        f"along spatial dimention {i} "
                        f"({data.shape[2 + i]} -> {result_spatial_shape[i]}) while "
                        "`allow_slice_rounding` == False ."
                    )
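

# Illustrative sketch (not part of the original module): scale factors that
# would produce a non-integer spatial shape are rejected unless
# `allow_slice_rounding` is set.
def _example_validate_interpolation_setting():  # hypothetical, illustration only
    try:
        _validate_interpolation_setting(
            data=torch.zeros(1, 1, 5),
            size=None,
            scale_factor_tuple=(0.5,),
            allow_slice_rounding=False,
        )
    except RuntimeError:
        pass  # expected: 5 * 0.5 == 2.5 is not an integer shape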


@builder.register("unsqueeze_to")
@supports_dict
@typechecked
def unsqueeze_to(
    data: TensorTypeVar,
    ndim: Optional[int],
):
    """
    Returns a new tensor with new dimensions of size one inserted at position 0 until
    the tensor reaches the given number of dimensions.

    :param data: the input tensor.
    :param ndim: target number of dimensions.
    :return: tensor with added dimensions.
    """

    if ndim is not None:
        while data.ndim < ndim:
            data = unsqueeze(data, 0)
    return data
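

# Illustrative sketch (not part of the original module): dimensions are added
# at the front until the requested number of dimensions is reached.
def _example_unsqueeze_to():  # hypothetical helper, for illustration only
    assert unsqueeze_to(np.zeros((3, 4)), 5).shape == (1, 1, 1, 3, 4)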


@builder.register("squeeze_to")
@supports_dict
@typechecked
def squeeze_to(
    data: TensorTypeVar,
    ndim: Optional[int],
):
    """
    Returns a new tensor with the dimension at position 0 squeezed until
    the tensor reaches the given number of dimensions.

    :param data: the input tensor.
    :param ndim: target number of dimensions.
    :return: tensor with squeezed dimensions.
    """
    if ndim is not None:
        while data.ndim > ndim:
            if data.shape[0] != 1:
                raise RuntimeError(
                    f"Not able to squeeze tensor with shape=={data.shape} to "
                    f"ndim=={ndim}: shape[0] != 1"
                )
            data = squeeze(data, 0)

    return data
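

# Illustrative sketch (not part of the original module): leading singleton
# dimensions are removed until the requested number of dimensions is reached;
# a non-singleton leading dimension raises a RuntimeError instead.
def _example_squeeze_to():  # hypothetical helper, for illustration only
    assert squeeze_to(np.zeros((1, 1, 3, 4)), 2).shape == (3, 4)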


@builder.register("interpolate")
@supports_dict
@typechecked
def interpolate(  # pylint: disable=too-many-locals
    data: TensorTypeVar,
    size: Optional[Sequence[int]] = None,
    scale_factor: Optional[Union[float, Sequence[float]]] = None,
    mode: InterpolationMode = "img",
    mask_value_thr: float = 0,
    allow_slice_rounding: bool = False,
    unsqueeze_input_to: Optional[int] = 5,
) -> TensorTypeVar:
    """
    Interpolate the given tensor to the given ``size`` or by the given ``scale_factor``.

    :param data: Input tensor.
    :param size: Desired result shape.
    :param scale_factor: Interpolation scale factor. When provided as ``float``, applied to
        all spatial dimensions of the data.
    :param mode: Algorithm according to which the tensor should be interpolated.
    :param mask_value_thr: When ``mode == 'mask'``, threshold above which the interpolated
        value will be considered as ``True``.
    :param allow_slice_rounding: Whether to allow interpolation with scale factors that
        result in non-integer tensor shapes.
    :param unsqueeze_input_to: If provided, the tensor will be unsqueezed to the given
        number of dimensions before interpolating. New dimensions are always added to the
        front (dim 0). The result is squeezed back to the original number of dimensions
        before returning. IMPORTANT: the unsqueezed tensor must be in B C X (Y? Z?) format.
    :return: Interpolated tensor of the same type as the input tensor.
    """
    original_ndim = data.ndim
    # `data` is assumed to be unsqueezed to have batch and channel dimensions
    data = unsqueeze_to(data, unsqueeze_input_to)
    scale_factor_tuple = _standardize_scale_factor(
        data_ndim=data.ndim,
        scale_factor=scale_factor,
    )

    _validate_interpolation_setting(
        data=data,
        size=size,
        scale_factor_tuple=scale_factor_tuple,
        allow_slice_rounding=allow_slice_rounding,
    )

    if (
        mode in ("segmentation", "img", "bilinear", "linear", "trilinear")
        and scale_factor_tuple is not None
        and tuple(scale_factor_tuple)
        in (
            [(0.5 ** i, 0.5 ** i) for i in range(1, 5)]  # 2D factors of 2
            + [(0.5 ** i, 0.5 ** i, 1) for i in range(1, 5)]
            + [(0.5 ** i, 0.5 ** i, 0.5 ** i) for i in range(1, 5)]  # 3D factors of 2
        )
        and data.shape[0] == 1
    ):  # use tinybrain
        result_raw = _interpolate_with_tinybrain(
            data=data,
            scale_factor_tuple=scale_factor_tuple,
            is_segmentation=(mode == "segmentation"),
        )
    else:
        result_raw = _interpolate_with_torch(
            data=data,
            scale_factor_tuple=scale_factor_tuple,
            size=size,
            mode=mode,
            mask_value_thr=mask_value_thr,
        )

    result_final = squeeze_to(result_raw, original_ndim)
    return result_final
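

# Illustrative sketch (not part of the original module): downsample an X Y Z
# volume by 2 in X and Y. The input is unsqueezed to B C X Y Z internally and
# squeezed back before returning; this particular factor-of-2 case is routed
# to tinybrain.
def _example_interpolate():  # hypothetical helper, for illustration only
    vol = torch.rand(64, 64, 8)
    half = interpolate(vol, scale_factor=(0.5, 0.5, 1.0), mode="img")
    assert half.shape == (32, 32, 8)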


def _interpolate_with_torch(
    data: TensorTypeVar,
    scale_factor_tuple: Sequence[float] | None,
    size: Optional[Sequence[int]],
    mode: InterpolationMode,
    mask_value_thr: float,
) -> TensorTypeVar:
    torch_interp_mode = _get_torch_interp_mode(
        spatial_ndim=data.ndim - 2,
        mode=mode,
    )
    data_torch = tensor_ops.convert.to_torch(data)

    result_raw: torch.Tensor
    if torch_interp_mode in ("nearest", "nearest-exact"):
        data_in = _get_nearest_interp_compatible_view(data_torch)
        result_raw = torch.nn.functional.interpolate(
            data_in,
            size=size,
            scale_factor=scale_factor_tuple,
            mode=torch_interp_mode,
        ).view(dtype=data_torch.dtype)
    else:
        data_in = data_torch.float()
        result_raw = torch.nn.functional.interpolate(
            data_in,
            size=size,
            scale_factor=scale_factor_tuple,
            mode=torch_interp_mode,
        )

    if mode == "field":
        if scale_factor_tuple is None:
            raise NotImplementedError(  # pragma: no cover
                "`size`-based field interpolation is not currently supported."
            )
        field_dim = result_raw.shape[1]
        if field_dim != 2:
            raise NotImplementedError("Only 2D field interpolation is currently supported")
        # All of the field dimensions must share the same scale factor
        if all(e == scale_factor_tuple[0] for e in scale_factor_tuple[:field_dim]):
            multiplier = scale_factor_tuple[0]
        else:
            raise NotImplementedError(  # pragma: no cover
                f"Non-isotropic field interpolation (scale_factor={scale_factor_tuple}) "
                "is not currently supported."
            )
        result_raw *= multiplier
    elif mode == "mask":
        result_raw = result_raw > mask_value_thr

    result_raw = result_raw.to(data_torch.dtype)
    result = tensor_ops.convert.astype(result_raw, data, cast=True)
    return result


def _interpolate_with_tinybrain(
    data: TensorTypeVar, scale_factor_tuple: Sequence[float], is_segmentation: bool
) -> TensorTypeVar:
    """
    Interpolate the given segmentation tensor by the given ``scale_factor_tuple``
    using the algorithms implemented in ``tinybrain``.

    :param data: Input tensor with batch and channel dimensions (B C X Y Z?).
    :param scale_factor_tuple: Interpolation scale factors for each spatial dim.
    """
    assert all(e <= 1 for e in scale_factor_tuple)
    assert data.shape[0] == 1
    data_np = tensor_ops.convert.to_np(data)
    data_np = data_np.squeeze(0)  # cut the B dim
    data_np = np.moveaxis(data_np, 0, -1)  # put C dim to the end for tinybrain
    if is_segmentation:
        result_raw = tinybrain.downsample_segmentation(
            img=data_np, factor=[1.0 / e for e in scale_factor_tuple]
        )[0]
    else:
        result_raw = tinybrain.downsample_with_averaging(
            img=data_np, factor=[1.0 / e for e in scale_factor_tuple]
        )[0]
    result_raw = np.moveaxis(result_raw, -1, 0)  # put C dim to the front again
    result_raw = result_raw[np.newaxis, ...]  # add the B dim back
    result_final = tensor_ops.convert.astype(result_raw, data)
    return result_final


CompareMode = Literal[
    "eq",
    "==",
    "neq",
    "!=",
    "gt",
    ">",
    "gte",
    ">=",
    "lt",
    "<",
    "lte",
    "<=",
]


@builder.register("compare")
@supports_dict
@typechecked
def compare(
    data: TensorTypeVar,
    mode: CompareMode,
    value: float,
    binarize: bool = True,
    fill: Optional[float] = None,
) -> TensorTypeVar:
    """
    Compare the given tensor to the given value.

    If `binarize` is set to `True`, return the binary outcome of the comparison.
    If `fill` is set, return a tensor in which the values that pass the comparison
    are replaced by `fill`. Only one of `binarize=True` or `fill` can be set.

    :param data: the input tensor.
    :param mode: the mode of comparison.
    :param value: comparison operand.
    :param binarize: when set to `True`, will return a binary result of the comparison,
        and `fill` must be `None`.
    :param fill: when set, will return a tensor with values that pass the comparison
        replaced by `fill`, and `binarize` must be `False`.
    """
    if mode in ["eq", "=="]:
        mask = data == value
    elif mode in ["neq", "!="]:
        mask = data != value
    elif mode in ["gt", ">"]:
        mask = data > value
    elif mode in ["gte", ">="]:
        mask = data >= value
    elif mode in ["lt", "<"]:
        mask = data < value
    elif mode in ["lte", "<="]:
        mask = data <= value
    else:
        assert False, "Type checker failure."  # pragma: no cover

    if binarize:
        if fill is not None:
            raise ValueError("`fill` must be set to None when `binarize` == True")
        result = mask
    else:
        if fill is None:
            raise ValueError(
                "`fill` must be set to a floating point value when `binarize` == False"
            )
        result = data
        result[mask] = fill
    return result


@builder.register("crop")
@supports_dict
@typechecked
def crop(
    data: TensorTypeVar,
    crop: Sequence[int],  # pylint: disable=redefined-outer-name
    # mode: Literal["center"] = "center",
) -> TensorTypeVar:
    """
    Crop a multidimensional tensor.

    :param data: Input tensor.
    :param crop: Number of pixels to crop from each side. The last integer corresponds
        to the last dimension, and the count goes from there right to left.
    """
    slices = [slice(0, None) for _ in range(data.ndim - len(crop))]
    for e in crop:
        assert e >= 0
        if e != 0:
            slices.append(slice(e, -e))
        else:
            slices.append(slice(0, None))
    result = data[tuple(slices)]
    return result


@builder.register("crop_center")
@supports_dict
@typechecked
def crop_center(
    data: TensorTypeVar,
    size: Sequence[int],  # pylint: disable=redefined-outer-name
) -> TensorTypeVar:
    """
    Crop a multidimensional tensor to the center.

    :param data: Input tensor.
    :param size: Size of the crop box. The last integer corresponds to the last
        dimension, and the count goes from there right to left.
    """
    ndim = len(size)
    slices = [slice(0, None) for _ in range(data.ndim - ndim)]
    for insz, outsz in zip(data.shape[-ndim:], size):
        if isinstance(insz, torch.Tensor):  # pragma: no cover # only occurs for JIT
            insz = insz.item()
        assert insz >= outsz
        lcrop = (insz - outsz) // 2
        rcrop = (insz - outsz) - lcrop
        if rcrop != 0:
            slices.append(slice(lcrop, -rcrop))
        else:
            assert lcrop == 0
            slices.append(slice(0, None))
    result = data[tuple(slices)]
    return result


@builder.register("clone")
@typechecked
def clone(
    data: TensorTypeVar,
) -> TensorTypeVar:  # pragma: no cover; delegation
    if isinstance(data, torch.Tensor):
        return data.clone()
    else:
        return data.copy()


@builder.register("tensor_op_chain")
@supports_dict
@typechecked
def tensor_op_chain(
    data: TensorTypeVar, steps: Sequence[Callable[[TensorTypeVar], TensorTypeVar]]
) -> TensorTypeVar:  # pragma: no cover
    result = data
    for step in steps:
        result = step(result)
    return result


@builder.register("abs")
@supports_dict
@typechecked
def abs(  # pylint: disable=redefined-builtin
    data: TensorTypeVar,
) -> TensorTypeVar:  # pragma: no cover
    if isinstance(data, torch.Tensor):
        return data.abs()
    else:
        return np.abs(data)


@builder.register("pad_center_to")
@supports_dict
@typechecked
def pad_center_to(
    data: TensorTypeVar,
    shape: Sequence[int],
    mode: Literal["constant", "reflect", "replicate", "circular"] = "constant",
    value: float | None = None,
) -> TensorTypeVar:
    """
    Pad data to the given shape.

    :param data: Input tensor.
    :param shape: Shape to pad the input to. Must be bigger than ``data.shape``.
    :param mode: ``torch.nn.functional.pad``'s ``mode`` kwarg.
    :param value: ``torch.nn.functional.pad``'s ``value`` kwarg.
    :return: Padded tensor.
    """
    ndim = len(shape)
    # `F.pad` expects (left, right) pairs starting from the last dimension, so
    # prepend each pair; reversing a flat list would swap left and right pads.
    pad = []
    for insz, outsz in zip(data.shape[-ndim:], shape):
        if isinstance(insz, torch.Tensor):  # pragma: no cover # only occurs for JIT
            insz = insz.item()
        assert outsz >= insz
        lpad = (outsz - insz) // 2
        rpad = (outsz - insz) - lpad
        pad = [lpad, rpad] + pad
    # TODO: maybe use numpy instead - torch supports only float data currently
    data_torch = tensor_ops.convert.to_torch(data)
    result = torch.nn.functional.pad(data_torch, tuple(pad), mode=mode, value=value)
    return tensor_ops.convert.astype(result, data)
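

# Illustrative sketch (not part of the original module): `compare` either
# binarizes or fills in place, and `crop_center` / `pad_center_to` both align
# to trailing dimensions, so they invert each other for matching sizes.
def _example_compare_crop_pad():  # hypothetical helper, for illustration only
    x = torch.full((1, 4, 4), 2.0)
    assert bool(compare(x, ">", 1.0).all())  # boolean mask result
    filled = compare(x.clone(), ">", 1.0, binarize=False, fill=0.0)
    assert float(filled.sum()) == 0.0
    y = torch.rand(1, 10, 10)
    cropped = crop_center(y, (6, 6))
    assert cropped.shape == (1, 6, 6)
    assert pad_center_to(cropped, (10, 10)).shape == (1, 10, 10)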