Python API determined.pytorch.lightning

determined.pytorch.lightning.LightningAdapter

PyTorch Lightning Adapter, defined here as LightningAdapter, provides a quick way to train your PyTorch Lightning models with all the Determined features, such as mid-epoch preemption, easy distributed training, simple job submission to the Determined cluster, and so on.

LightningAdapter is built on top of our PyTorchTrial API, which has a built-in training loop that integrates with the Determined features. However, it only supports LightningModule (v1.2.0). To migrate your code away from the Lightning Trainer, read more about PyTorchTrial and Experiment Configuration.

Porting your PyTorch Lightning code is often straightforward:

  1. Bring in your LightningModule and LightningDataModule and initialize them.

  2. Create a new trial based on LightningAdapter and initialize it.

  3. Define the dataloaders.

Here is an example:

from determined.pytorch import PyTorchTrialContext, DataLoader
from determined.pytorch.lightning import LightningAdapter

# bring in your LightningModule and optionally LightningDataModule
from mnist import LightningMNISTClassifier, MNISTDataModule


class MNISTTrial(LightningAdapter):
    def __init__(self, context: PyTorchTrialContext) -> None:
        # instantiate your LightningModule with hyperparameters from the Determined
        # config file or from the searcher for automatic hyperparameter tuning.
        lm = LightningMNISTClassifier(lr=context.get_hparam("learning_rate"))

        # instantiate your LightningDataModule and make it distributed training ready.
        data_dir = f"/tmp/data-rank{context.distributed.get_rank()}"
        self.dm = MNISTDataModule(context.get_data_config()["url"], data_dir)

        # initialize LightningAdapter.
        super().__init__(context, lightning_module=lm)
        self.dm.prepare_data()

    def build_training_data_loader(self) -> DataLoader:
        self.dm.setup()
        dl = self.dm.train_dataloader()
        return DataLoader(
            dl.dataset, batch_size=dl.batch_size, num_workers=dl.num_workers
        )

    def build_validation_data_loader(self) -> DataLoader:
        self.dm.setup()
        dl = self.dm.val_dataloader()
        return DataLoader(
            dl.dataset, batch_size=dl.batch_size, num_workers=dl.num_workers
        )

class determined.pytorch.lightning.LightningAdapter(context: determined.pytorch._pytorch_context.PyTorchTrialContext, lightning_module: pytorch_lightning.LightningModule, precision: Union[typing_extensions.Literal[32], typing_extensions.Literal[16]] = 32, amp_backend: Union[typing_extensions.Literal['native'], typing_extensions.Literal['apex']] = 'native', amp_level: typing_extensions.Literal['O0', 'O1', 'O2', 'O3'] = 'O2')

PyTorch Lightning Adapter provides a quick way to train your PyTorch Lightning models with all the Determined features, such as mid-epoch preemption, a simple distributed training interface, simple job submission to the Determined cluster, and so on.

__init__(context: determined.pytorch._pytorch_context.PyTorchTrialContext, lightning_module: pytorch_lightning.LightningModule, precision: Union[typing_extensions.Literal[32], typing_extensions.Literal[16]] = 32, amp_backend: Union[typing_extensions.Literal['native'], typing_extensions.Literal['apex']] = 'native', amp_level: typing_extensions.Literal['O0', 'O1', 'O2', 'O3'] = 'O2')

This performs the necessary initialization steps to:

  1. check the compatibility of the provided LightningModule with LightningAdapter.

  2. define a PyTorchTrial with models, optimizers, and LR schedulers that are provided by LightningModule.

  3. patch the LightningModule methods that depend on a Trainer.

After inheriting this class, you need to override this function to initialize the adapted PyTorchTrial. Within your __init__, you should instantiate the LightningModule and call super().__init__.

Here is a minimal code example.

def __init__(self, context: PyTorchTrialContext) -> None:
    lm = mnist.LightningMNISTClassifier(lr=context.get_hparam('learning_rate'))
    super().__init__(context, lightning_module=lm)

Parameters
  • context (PyTorchTrialContext) – The trial context provided to your trial's __init__.

  • lightning_module (LightningModule) – User-defined lightning module.

  • precision (int, default=32) – Precision to use. Accepted values are 16 and 32.

  • amp_backend (str) – Automatic mixed precision backend to use. Accepted values are “native” and “apex”.

  • amp_level (str, optional, default="O2") – Apex amp optimization level. Accepted values are “O0”, “O1”, “O2”, and “O3”. See https://nvidia.github.io/apex/amp.html#opt-levels-and-properties.
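
For example, to train in 16-bit precision with the native AMP backend, forward these arguments when calling super().__init__. This is a minimal sketch reusing the MNIST example above:

class MNISTTrial(LightningAdapter):
    def __init__(self, context: PyTorchTrialContext) -> None:
        lm = LightningMNISTClassifier(lr=context.get_hparam("learning_rate"))
        # train with PyTorch's native AMP in 16-bit precision; amp_level
        # is only consulted when amp_backend="apex".
        super().__init__(
            context,
            lightning_module=lm,
            precision=16,
            amp_backend="native",
        )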

build_callbacks() → Dict[str, determined.pytorch._callback.PyTorchCallback]

build_callbacks defines a set of PyTorchCallback objects that are necessary to support Lightning. Override this method and merge its output with your desired callbacks.
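
A minimal sketch of such an override; MyCustomCallback stands in for any PyTorchCallback of your own:

from typing import Dict

from determined.pytorch import PyTorchCallback

def build_callbacks(self) -> Dict[str, PyTorchCallback]:
    # keep the callbacks LightningAdapter needs, then add your own.
    callbacks = super().build_callbacks()
    callbacks["my_callback"] = MyCustomCallback()  # hypothetical callback
    return callbacks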

abstract build_training_data_loader() → determined.pytorch._data.DataLoader

Defines the data loader to use during training.

Must return an instance of determined.pytorch.DataLoader.

If you’re using LightningDataModule this could be as simple as:

self.dm.setup()
dl = self.dm.train_dataloader()
return DataLoader(dl.dataset, batch_size=dl.batch_size,
                  num_workers=dl.num_workers)

abstract build_validation_data_loader() → determined.pytorch._data.DataLoader

Defines the data loader to use during validation.

Must return an instance of determined.pytorch.DataLoader.

If you’re using LightningDataModule this could be as simple as:

self.dm.setup()
dl = self.dm.val_dataloader()
return DataLoader(dl.dataset, batch_size=dl.batch_size,
                  num_workers=dl.num_workers)

evaluation_reducer() → Union[determined.pytorch._reducer.Reducer, Dict[str, determined.pytorch._reducer.Reducer]]

Return a reducer for all evaluation metrics, or a dict mapping metric names to individual reducers. Defaults to determined.pytorch.Reducer.AVG.
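
For example, to average a loss metric but sum a sample count across batches, you could override it as follows (a sketch; the metric names are illustrative):

from typing import Dict

from determined.pytorch import Reducer

def evaluation_reducer(self) -> Dict[str, Reducer]:
    # average "val_loss" across batches, but sum "num_samples".
    return {"val_loss": Reducer.AVG, "num_samples": Reducer.SUM}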

get_batch_length(batch: Any) → int

Count the number of records in a given batch.

Override this method when you are using custom batch types, as produced when iterating over the DataLoader. For example, when using pytorch_geometric:

# Extra imports:
from determined.pytorch import DataLoader
from torch_geometric.data.dataloader import Collater

# Trial methods:
def build_training_data_loader(self):
    return DataLoader(
        self.train_subset,
        batch_size=self.context.get_per_slot_batch_size(),
        collate_fn=Collater([], []),
    )

def get_batch_length(self, batch):
    # `batch` is `torch_geometric.data.batch.Batch`.
    return batch.num_graphs

Parameters
  • batch (Any) – Input training or validation data batch object.

In this approach, the LightningModule is not paired with the PyTorch Lightning Trainer, so some methods and hooks are not supported. These are listed below:

  • No separate test-set definition in Determined: test_step, test_step_end, test_epoch_end, on_test_batch_start, on_test_batch_end, on_test_epoch_start, on_test_epoch_end, test_dataloader.

  • No fit or pre-train stage: setup, teardown, on_fit_start, on_fit_end, on_pretrain_routine_start, on_pretrain_routine_end.

  • Additionally, no: training_step_end and validation_step_end, the hiddens parameter in training_step and tbptt_split_batch, transfer_batch_to_device, get_progress_bar_dict, on_train_epoch_end, manual_backward, backward, optimizer_step, optimizer_zero_grad.

In addition, some LightningModule methods are patched to make porting your code easier:

  • log and log_dict are patched to always ship their values to TensorBoard. In the current version, only the first two arguments of log (key and value) and the first argument of log_dict are supported; see the sketch below.
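
A minimal sketch of logging from a LightningModule's validation_step under these constraints (the metric name val_loss is illustrative):

import torch.nn.functional as F

def validation_step(self, batch, batch_idx):
    x, y = batch
    loss = F.cross_entropy(self(x), y)
    # the patched self.log ships the value to TensorBoard; only the
    # key and value arguments are honored here.
    self.log("val_loss", loss)
    return {"val_loss": loss}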

Note

Make sure your validation_step returns the metric you defined as searcher.metric in your experiment's configuration.

Note

Determined will automatically log the metrics you return from training_step and validation_step to TensorBoard.

Data Loading

Loading your dataset when using LightningAdapter works the same way as it does with PyTorchTrial.

If you already have a LightningDataModule, you can bring it in and use it to implement the build_training_data_loader and build_validation_data_loader methods easily. For more information, read PyTorchTrial's section on Data Loading.

Debugging

Please see Training: Debug.