PyTorch Lightning API Reference

determined.pytorch.lightning.LightningAdapter

class determined.pytorch.lightning.LightningAdapter(context: determined.pytorch._pytorch_context.PyTorchTrialContext, lightning_module: pytorch_lightning.LightningModule, precision: Union[typing_extensions.Literal[32], typing_extensions.Literal[16]] = 32, amp_backend: Union[typing_extensions.Literal['native'], typing_extensions.Literal['apex']] = 'native', amp_level: typing_extensions.Literal['O0', 'O1', 'O2', 'O3'] = 'O2')

LightningAdapter provides a quick way to train PyTorch Lightning models with Determined features such as mid-epoch preemption, a simple distributed training interface, and simple job submission to the Determined cluster.

__init__(context: determined.pytorch._pytorch_context.PyTorchTrialContext, lightning_module: pytorch_lightning.LightningModule, precision: Union[typing_extensions.Literal[32], typing_extensions.Literal[16]] = 32, amp_backend: Union[typing_extensions.Literal['native'], typing_extensions.Literal['apex']] = 'native', amp_level: typing_extensions.Literal['O0', 'O1', 'O2', 'O3'] = 'O2')

This performs the necessary initialization steps to:

  1. check the compatibility of the provided LightningModule with LightningAdapter.

  2. define a PyTorchTrial with the models, optimizers, and LR schedulers provided by the LightningModule.

  3. patch the LightningModule methods that depend on a Trainer.

After subclassing LightningAdapter, override this method to initialize the adapted PyTorchTrial. In your __init__, instantiate your LightningModule and then call super().__init__.

Here is a minimal code example.

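from determined.pytorch import PyTorchTrialContext

def __init__(self, context: PyTorchTrialContext) -> None:
    # `mnist` is the example's own module that defines LightningMNISTClassifier.
    lm = mnist.LightningMNISTClassifier(lr=context.get_hparam('learning_rate'))
    super().__init__(context, lightning_module=lm)

Parameters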
  • context (PyTorchTrialContext) – The trial context passed to your trial's __init__.

  • lightning_module (LightningModule) – User-defined lightning module.

  • precision (int, default=32) – Precision to use. Accepted values are 16 and 32.

  • amp_backend (str, default="native") – Automatic mixed precision backend to use. Accepted values are “native” and “apex”.

  • amp_level (str, optional, default="O2") – Apex AMP optimization level. Accepted values are “O0”, “O1”, “O2”, and “O3”. See https://nvidia.github.io/apex/amp.html#opt-levels-and-properties
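The accepted values and defaults listed above can be summarized as a small validation sketch. The helper name check_precision_args is hypothetical and is not part of the Determined API; LightningAdapter's own checks may differ from this.

```python
# Accepted values, copied from the parameter descriptions above.
ACCEPTED_PRECISION = {16, 32}
ACCEPTED_AMP_BACKENDS = {"native", "apex"}
ACCEPTED_AMP_LEVELS = {"O0", "O1", "O2", "O3"}


def check_precision_args(precision: int = 32,
                         amp_backend: str = "native",
                         amp_level: str = "O2") -> None:
    """Hypothetical helper: reject values outside the documented sets.

    Defaults mirror the LightningAdapter signature. This is not the adapter's
    own validation logic, which may check more (or less) than this.
    """
    if precision not in ACCEPTED_PRECISION:
        raise ValueError(f"precision must be 16 or 32, got {precision!r}")
    if amp_backend not in ACCEPTED_AMP_BACKENDS:
        raise ValueError(f"amp_backend must be 'native' or 'apex', got {amp_backend!r}")
    if amp_level not in ACCEPTED_AMP_LEVELS:
        raise ValueError(f"amp_level must be one of O0, O1, O2, O3, got {amp_level!r}")
```

With this sketch, check_precision_args(16, "apex", "O1") passes, while amp_backend="mixed" raises ValueError.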

build_callbacks() → Dict[str, determined.pytorch._callback.PyTorchCallback]

build_callbacks defines the set of PyTorchCallback objects that LightningAdapter needs to support Lightning. If you override it, merge the output of the base build_callbacks with your own callbacks instead of replacing it.
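The merge pattern can be sketched as follows. So that the sketch runs without Determined installed, PyTorchCallback and AdapterBase are local stand-ins for determined.pytorch.PyTorchCallback and LightningAdapter; the key "lightning_support", the class MyCallback, and the key "my_callback" are all hypothetical.

```python
from typing import Dict


class PyTorchCallback:
    """Stand-in for determined.pytorch.PyTorchCallback (assumption for this sketch)."""


class AdapterBase:
    """Stand-in for LightningAdapter; the real base returns its own required callbacks."""

    def build_callbacks(self) -> Dict[str, PyTorchCallback]:
        # Hypothetical key; the real adapter's callback names may differ.
        return {"lightning_support": PyTorchCallback()}


class MyCallback(PyTorchCallback):
    """Hypothetical user-defined callback."""


class MyTrial(AdapterBase):
    def build_callbacks(self) -> Dict[str, PyTorchCallback]:
        # Start from the adapter's callbacks so Lightning support is preserved.
        callbacks = super().build_callbacks()
        # Refuse to silently overwrite a callback the base class registered.
        if "my_callback" in callbacks:
            raise KeyError("callback name 'my_callback' is already in use")
        callbacks["my_callback"] = MyCallback()
        return callbacks
```

Calling super().build_callbacks() first, rather than returning a fresh dict, is what keeps the adapter's own callbacks active. The collision check guards against a user key shadowing one of them.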

abstract build_training_data_loader() → determined.pytorch._data.DataLoader

Defines the data loader to use during training.

Must return an instance of determined.pytorch.DataLoader.

If you’re using a LightningDataModule, this can be as simple as:

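from determined.pytorch import DataLoader

def build_training_data_loader(self) -> DataLoader:
    # self.dm is assumed to be a LightningDataModule created in __init__.
    self.dm.setup()
    dl = self.dm.train_dataloader()
    return DataLoader(dl.dataset, batch_size=dl.batch_size,
                      num_workers=dl.num_workers)

abstract build_validation_data_loader() → determined.pytorch._data.DataLoader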

Defines the data loader to use during validation.

Must return an instance of determined.pytorch.DataLoader.

If you’re using a LightningDataModule, this can be as simple as:

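from determined.pytorch import DataLoader

def build_validation_data_loader(self) -> DataLoader:
    # self.dm is assumed to be a LightningDataModule created in __init__.
    self.dm.setup()
    dl = self.dm.val_dataloader()
    return DataLoader(dl.dataset, batch_size=dl.batch_size,
                      num_workers=dl.num_workers)

evaluation_reducer() → Union[determined.pytorch._reducer.Reducer, Dict[str, determined.pytorch._reducer.Reducer]]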

Return a reducer for all evaluation metrics, or a dict mapping metric names to individual reducers. Defaults to determined.pytorch.Reducer.AVG.
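The difference between returning one reducer and returning a per-metric dict can be sketched in plain Python. The Reducer enum here is a local stand-in for determined.pytorch.Reducer. The function reduce_metrics is hypothetical and models AVG as a simple mean over per-batch values, which is an assumption: Determined's actual aggregation (for example, across slots) may differ. Returning a single reducer corresponds to passing an empty dict with that reducer as the default.

```python
import enum
from statistics import mean
from typing import Dict, List


class Reducer(enum.Enum):
    """Stand-in for determined.pytorch.Reducer (assumption for this sketch)."""
    AVG = "avg"
    SUM = "sum"
    MAX = "max"
    MIN = "min"


def reduce_metrics(per_batch: Dict[str, List[float]],
                   reducers: Dict[str, Reducer],
                   default: Reducer = Reducer.AVG) -> Dict[str, float]:
    """Hypothetical: combine per-batch values, using one reducer per metric name.

    Metrics without an entry in `reducers` fall back to `default`, mirroring
    the documented default of Reducer.AVG.
    """
    ops = {Reducer.AVG: mean, Reducer.SUM: sum, Reducer.MAX: max, Reducer.MIN: min}
    return {name: ops[reducers.get(name, default)](values)
            for name, values in per_batch.items()}
```

For example, reduce_metrics({"val_loss": [0.5, 0.3], "val_errors": [2.0, 3.0]}, {"val_errors": Reducer.SUM}) averages val_loss but sums val_errors to 5.0.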

get_batch_length(batch: Any) → int

Count the number of records in a given batch.

Override this method if iterating over your DataLoader yields custom batch types. For example, when using pytorch_geometric:

# Extra imports:
from determined.pytorch import DataLoader
from torch_geometric.data.dataloader import Collater

# Trial methods:
def build_training_data_loader(self):
    return DataLoader(
        self.train_subset,
        batch_size=self.context.get_per_slot_batch_size(),
        collate_fn=Collater([], []),
    )

def get_batch_length(self, batch):
    # `batch` is `torch_geometric.data.batch.Batch`.
    return batch.num_graphs

Parameters

batch (Any) – An input training or validation batch object.
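Another common custom batch type is a dict of per-record lists, for example tokenized text. The sketch below is a hypothetical override body written as a free function (inside a trial it would also take self). It assumes every entry holds one item per record and that a "labels" key exists; both are assumptions about the user's data, not Determined requirements.

```python
from typing import Any, Dict, List


def get_batch_length(batch: Dict[str, List[Any]]) -> int:
    """Hypothetical override body for dict-shaped batches.

    Counts records via the assumed "labels" entry, after checking that all
    entries agree on the number of records.
    """
    lengths = {len(values) for values in batch.values()}
    if len(lengths) != 1:
        raise ValueError(f"inconsistent entry lengths in batch: {sorted(lengths)}")
    return len(batch["labels"])
```

Checking that all entries agree catches a malformed batch early, instead of silently reporting a record count that matches only one of its fields.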