Source code for zetta_utils.training.datasets.sample_indexers.chain_indexer
import bisect
from itertools import accumulate
from typing import Any, Sequence
import attrs
from typeguard import typechecked
from zetta_utils import builder
from .base import SampleIndexer
[docs]@builder.register("ChainIndexer")
@typechecked
@attrs.frozen
class ChainIndexer(SampleIndexer):
"""
Iterates over a sequence of inner indexers.
:param inner_indexer: Sequence of inner indexers.
"""
inner_indexer: Sequence[SampleIndexer]
num_samples: list[int] = attrs.field(init=False)
def __attrs_post_init__(self):
# Use `__setattr__` to keep the object frozen.
num_samples = [0] + list(accumulate(len(indexer) for indexer in self.inner_indexer))
object.__setattr__(self, "num_samples", num_samples)
def __len__(self):
return self.num_samples[-1]
def __call__(self, idx: int) -> Any:
"""Yield a sample index from an indexer given a index.
:param idx: Integer sample index.
:return: Index of the type used by the wrapped inner indexer.
"""
if idx not in range(0, len(self)):
raise ValueError(f"idx expected to be in range [0, {len(self)}), but got {idx}.")
inner_indexer = bisect.bisect_right(self.num_samples, idx) - 1
inner_index = idx - self.num_samples[inner_indexer]
return self.inner_indexer[inner_indexer](inner_index)