zetta_utils.training

Overview

The zetta_utils.training module provides standardized integrations for neural network training. It is not meant to be a heavyweight framework that dictates how training must be done, but rather a set of independently useful, lightweight components. Some of the advantages of zetta_utils.training are:

  • Bells and whistles on top of PyTorch Lightning (PL). We’ve parametrized PL to include extensive checkpointing, better progress bars, Weights and Biases (wandb) integration, and more, so that you don’t have to.

  • Remote training management becomes easy for zetta_utils.training users. Any number of concurrent remote training runs can be started, monitored, and easily cancelled from the command line. zetta_utils can also handle the more complicated cluster setups needed for DDP (WIP).

  • Integration with zetta_utils.builder provides a standard way to specify the full training parametrization in a single spec file. Having a standard single-spec parametrization format simplifies collaboration between team members, and the spec is automatically uploaded to wandb, which simplifies experiment management.

  • It is easy for zetta_utils users to share data augmentations and architectures with each other.

  • Powerful training-inference integration.

All training runs start with a call to zetta_utils.training.lightning.train.lightning_train. Other than the training and validation dataloaders and checkpoint paths, this function is given a training regime and a trainer. The regime defines the specifics of how the given network is to be trained: training loss calculation, validation loss calculation, any actions that need to be taken at the beginning or end of each validation epoch, and so on. Regimes are usually created by the scientist performing the experiments. The trainer defines training loop behavior that is common to all regimes, such as logging, checkpointing, and gradient clipping. The trainer is developed and maintained by the engineering team.
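The regime/trainer split can be sketched as follows. This is a hypothetical, duck-typed illustration of what a regime provides; real regimes subclass pytorch_lightning.LightningModule and operate on torch tensors, and the class and method bodies here are simplified assumptions, not the library's implementation.

```python
class NaiveMSERegime:
    """Hypothetical regime sketch: defines how a given model is trained,
    i.e. how training and validation losses are computed."""

    def __init__(self, model):
        # the model being trained is held as an instance variable,
        # as described in the lightning_train docs below
        self.model = model

    def training_step(self, batch):
        # batch is a (input, target) pair; return the training loss
        x, y = batch
        pred = self.model(x)
        return (pred - y) ** 2

    def validation_step(self, batch):
        # same loss here for simplicity; a real regime may log it differently
        return self.training_step(batch)
```

The trainer, by contrast, never contains experiment-specific logic; it only drives the loop around these steps (logging, checkpointing, gradient clipping).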

You can find existing regimes in [zetta_utils/training/lightning/regimes]. You can find example training specs that use some of those regimes in [specs/examples/training]. To learn more about the spec file format, refer to the zetta_utils.builder documentation.

API reference

zetta_utils.training.lightning

zetta_utils.training.lightning.train.lightning_train(regime, trainer, train_dataloader, val_dataloader=None, full_state_ckpt_path='last', num_nodes=1, retry_count=3, local_run=True, follow_logs=False, image=None, cluster_name=None, cluster_region=None, cluster_project=None, env_vars=None, resource_limits=None, resource_requests=None, provisioning_model='spot', gpu_accelerator_type=None)[source]

Perform neural net training with Zetta’s PyTorch Lightning integration.

Parameters
  • regime (UnionType[LightningModule, dict[str, Any]]) – Training regime. Defines behavior on training and validation steps and epochs. Includes the model being trained as an instance variable.

  • trainer (UnionType[Trainer, dict[str, Any]]) – PyTorch Lightning Trainer object responsible for handling training loop details that are common to all regimes, such as checkpointing behavior, logging behavior, etc. For Zetta training configuration, use zetta_utils.training.lightning.trainers.build_default_trainer.

  • train_dataloader (UnionType[DataLoader, dict[str, Any]]) – Training dataloader.

  • val_dataloader (Union[DataLoader, dict[str, Any], None]) – Validation dataloader.

  • full_state_ckpt_path (str) – Path to the training checkpoint to resume from. Must be a full training state checkpoint created by PyTorch Lightning rather than a model checkpoint. If full_state_ckpt_path=="last", the latest checkpoint for the given experiment will be identified and loaded.

  • num_nodes (int) – Number of GPU nodes for distributed training.

  • retry_count (int) – Max retry count for the master train job; excludes failures due to pod disruptions.

  • local_run (bool) – If True, run the training locally.

  • follow_logs (bool) – If True, eagerly print logs from the pod. If False, wait until the job completes successfully.

  • image (Optional[str]) – Container image to use.

  • cluster_name (Optional[str]) – Name of the cluster to run remote training on.

  • cluster_region (Optional[str]) – Region of the cluster.

  • cluster_project (Optional[str]) – Project that the cluster belongs to.

  • env_vars (Optional[Dict[str, str]]) – Custom env variables to be set on pods.

  • resource_limits (Optional[dict[str, UnionType[int, float, str]]]) – K8s resource limits per pod.

  • resource_requests (Optional[dict[str, UnionType[int, float, str]]]) – K8s resource requests per pod.

  • provisioning_model (Literal[‘standard’, ‘spot’]) – VM provisioning model to use for worker pods.

  • gpu_accelerator_type (UnionType[str, None]) – Schedule on nodes with the given GPU type, e.g. “nvidia-tesla-t4”. See gcloud compute accelerator-types list for available types.

Return type

None
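A hypothetical call sketch: per the parameter types above, regime, trainer, and dataloaders may be passed either as constructed objects or as zetta_utils.builder dict specs. The "@type" values below are illustrative placeholders, not guaranteed registry names.

```python
# Hypothetical builder-style specs for lightning_train arguments.
# "MyRegime" and "ZettaDefaultTrainer" are made-up names for illustration;
# consult the regimes and trainers registered in your zetta_utils version.

regime_spec = {
    "@type": "MyRegime",  # hypothetical registered regime
    "lr": 1e-4,
}
trainer_spec = {
    "@type": "ZettaDefaultTrainer",  # hypothetical; cf. build_default_trainer
    "max_epochs": 100,
}

# With these dicts in hand, a local run would look roughly like:
# lightning_train(
#     regime=regime_spec,
#     trainer=trainer_spec,
#     train_dataloader=train_dataloader_spec,
#     local_run=True,
# )
```

Passing dict specs rather than objects is what lets the whole training parametrization live in a single spec file, as described in the Overview.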

zetta_utils.training.datasets

class zetta_utils.training.datasets.LayerDataset(layer, sample_indexer)[source]

PyTorch dataset wrapper around zetta_utils.layer.Layer component.

Parameters
  • layer (Layer) – Layer which will be used as a source of data.

  • sample_indexer (SampleIndexer) – Indexer which will be used to translate integer sample index to a corresponding index understood by the layer backend.
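Conceptually, LayerDataset composes the two parameters above: the indexer translates an integer sample index into a backend-specific index, which is then used to read from the layer. The toy classes below sketch this flow with a plain dict standing in for the layer; they are illustrative assumptions, not the real implementation.

```python
class ToyIndexer:
    """Hypothetical sample indexer: maps integer sample indices to
    backend-specific keys (the real SampleIndexer returns layer indices)."""

    def __init__(self, keys):
        self.keys = keys

    def __len__(self):
        return len(self.keys)

    def __call__(self, idx):
        return self.keys[idx]


class ToyLayerDataset:
    """Sketch of LayerDataset's behavior: index translation, then layer read."""

    def __init__(self, layer, sample_indexer):
        self.layer = layer  # stands in for a zetta_utils Layer
        self.sample_indexer = sample_indexer

    def __len__(self):
        return len(self.sample_indexer)

    def __getitem__(self, idx):
        # translate the integer index, then fetch the sample from the layer
        return self.layer[self.sample_indexer(idx)]
```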

class zetta_utils.training.datasets.JointDataset(mode, datasets)[source]

PyTorch dataset wrapper to allow using multiple torch.utils.data.Dataset datasets simultaneously.

Parameters
  • mode (Literal[‘horizontal’, ‘vertical’]) – String indicating whether the dataset is horizontally or vertically joined. horizontal means that the LayerDatasets will be sampled all at once and returned in a dictionary. vertical means that the LayerDatasets will be sampled one after the other in the order given during initialization.

  • datasets (Dict[str, Any]) – Dictionary containing the datasets that make up the JointDataset.
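The horizontal/vertical distinction can be sketched in a few lines. This toy class mirrors the described semantics (horizontal: sample all datasets at once and return a dictionary; vertical: concatenate them in order); it is an illustration, not the library's code.

```python
class ToyJointDataset:
    """Sketch of JointDataset's two joining modes."""

    def __init__(self, mode, datasets):
        self.mode = mode          # "horizontal" or "vertical"
        self.datasets = datasets  # dict of name -> sequence-like dataset

    def __len__(self):
        if self.mode == "horizontal":
            # all datasets are sampled at once, so the shortest one limits length
            return min(len(d) for d in self.datasets.values())
        # vertical: datasets are sampled one after the other
        return sum(len(d) for d in self.datasets.values())

    def __getitem__(self, idx):
        if self.mode == "horizontal":
            # one sample from every dataset, keyed by dataset name
            return {name: d[idx] for name, d in self.datasets.items()}
        # vertical: walk the datasets in initialization order
        for d in self.datasets.values():
            if idx < len(d):
                return d[idx]
            idx -= len(d)
        raise IndexError(idx)
```

Horizontal joining is useful when each sample needs aligned data from several layers (e.g. image and segmentation); vertical joining simply pools samples from several sources.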

class zetta_utils.training.datasets.sample_indexers.ChainIndexer(inner_indexer)[source]

Iterates over a sequence of inner indexers.

Parameters

inner_indexer (Sequence[SampleIndexer]) – Sequence of inner indexers.

class zetta_utils.training.datasets.sample_indexers.RandomIndexer(inner_indexer, replacement=False)[source]

Indexer that randomizes the order in which inner_indexer samples are pulled.

Parameters
  • inner_indexer (SampleIndexer) – Indexer whose sample order is to be randomized.

  • replacement (bool) – If True, samples are drawn on demand with replacement.
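The two randomization behaviors can be sketched as below. Without replacement, a fixed random permutation of the inner index space is drawn; with replacement, each call samples an inner index on demand. This is an assumed simplification of the described semantics, not the real implementation.

```python
import random


class ListIndexer:
    """Hypothetical inner indexer backed by a plain list."""

    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __call__(self, idx):
        return self.items[idx]


class ToyRandomIndexer:
    """Sketch of RandomIndexer: randomizes the pull order of an inner indexer."""

    def __init__(self, inner_indexer, replacement=False):
        self.inner = inner_indexer
        self.replacement = replacement
        if not replacement:
            # fixed random permutation: every inner sample appears exactly once
            self.perm = random.sample(range(len(inner_indexer)), len(inner_indexer))

    def __len__(self):
        return len(self.inner)

    def __call__(self, idx):
        if self.replacement:
            # draw on demand with replacement; some samples may repeat
            return self.inner(random.randrange(len(self.inner)))
        return self.inner(self.perm[idx])
```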

class zetta_utils.training.datasets.sample_indexers.VolumetricStridedIndexer(bbox, chunk_size, stride, resolution, mode='expand')[source]

SampleIndexer which takes chunks from a volumetric region at uniform intervals.

Parameters
  • bbox (BBox3D) – Bounding cube representing the whole volume to be indexed.

  • resolution (Sequence[float]) – Resolution at which chunk_size is given, and which will be recorded in the resulting VolumetricIndex objects.

  • chunk_size (Sequence[int]) – Size of a training chunk.

  • stride (Sequence[int]) – Distance between neighboring chunks along each dimension.

  • mode (Literal[‘expand’, ‘shrink’]) – Behavior when bbox cannot be divided evenly: expand grows the indexed region to fit a whole number of chunks, while shrink trims it down.

zetta_utils.training.datasets.sample_indexers.VolumetricStridedIndexer.__call__(self, idx)

Translate a chunk index to a volumetric region in space.

Parameters

idx (int) – Chunk index.

Return type

VolumetricIndex

Returns

VolumetricIndex.
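The stride arithmetic behind this translation can be sketched as follows. The helpers below compute, per dimension, how many chunk start positions fit in the bounding box, and decode a flat chunk index into an (x, y, z) start offset. The dimension ordering and the shrink-style rounding are assumptions for illustration; the real __call__ returns a VolumetricIndex, not a tuple.

```python
def num_chunks_per_dim(bbox_size, chunk_size, stride):
    """Number of chunk start positions along each dimension
    (shrink-style: only chunks fully inside the bbox count)."""
    return [
        (size - chunk) // step + 1
        for size, chunk, step in zip(bbox_size, chunk_size, stride)
    ]


def chunk_start(idx, bbox_size, chunk_size, stride):
    """Decode a flat chunk index into the chunk's start offset.

    Assumes the fastest-varying dimension comes first; the real
    indexer's ordering may differ.
    """
    counts = num_chunks_per_dim(bbox_size, chunk_size, stride)
    start = []
    for count, step in zip(counts, stride):
        start.append((idx % count) * step)
        idx //= count
    return tuple(start)
```

For example, a 4x4x1 region chunked at 2x2x1 with stride 2x2x1 yields 2*2*1 = 4 chunks, and each flat index maps to a distinct start offset on that grid.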