PyTorch Lightning API¶
Overview¶
This document guides you through training a PyTorch Lightning model in Determined. You need to implement a trial class that inherits LightningAdapter and specify it as the entrypoint in the experiment configuration.
PyTorch Lightning Adapter, defined here as LightningAdapter, provides a quick way to train your PyTorch Lightning models with all the Determined features, such as mid-epoch preemption, easy distributed training, simple job submission to the Determined cluster, and so on.
LightningAdapter is built on top of our PyTorch API, which has a built-in training loop that integrates with the Determined features. However, it only supports LightningModule (v1.2.0). To migrate your code from the Trainer, please read more about PyTorch API and Experiment Configuration.
Porting your PyTorch Lightning code is often pretty simple:

1. Bring in your LightningModule and LightningDataModule and initialize them.
2. Create a new trial based on LightningAdapter and initialize it.
3. Define the dataloaders.
Here is an example:
from determined.pytorch import PyTorchTrialContext, DataLoader
from determined.pytorch.lightning import LightningAdapter
# bring in your LightningModule and optionally LightningDataModule
from mnist import LightningMNISTClassifier, MNISTDataModule
class MNISTTrial(LightningAdapter):
    def __init__(self, context: PyTorchTrialContext) -> None:
        # instantiate your LightningModule with hyperparameters from the Determined
        # config file or from the searcher for automatic hyperparameter tuning.
        lm = LightningMNISTClassifier(lr=context.get_hparam("learning_rate"))

        # instantiate your LightningDataModule and make it distributed-training ready.
        data_dir = f"/tmp/data-rank{context.distributed.get_rank()}"
        self.dm = MNISTDataModule(context.get_data_config()["url"], data_dir)

        # initialize LightningAdapter.
        super().__init__(context, lightning_module=lm)
        self.dm.prepare_data()

    def build_training_data_loader(self) -> DataLoader:
        self.dm.setup()
        dl = self.dm.train_dataloader()
        return DataLoader(
            dl.dataset, batch_size=dl.batch_size, num_workers=dl.num_workers
        )

    def build_validation_data_loader(self) -> DataLoader:
        self.dm.setup()
        dl = self.dm.val_dataloader()
        return DataLoader(
            dl.dataset, batch_size=dl.batch_size, num_workers=dl.num_workers
        )
In this approach, the LightningModule is not paired with the PyTorch Lightning Trainer, so some methods and hooks are not supported. Read about those here:
- No separate test-set definition in Determined: test_step, test_step_end, test_epoch_end, on_test_batch_start, on_test_batch_end, on_test_epoch_start, on_test_epoch_end, test_dataloader.
- No fit or pre-train stage: setup, teardown, on_fit_start, on_fit_end, on_pretrain_routine_start, on_pretrain_routine_end.
- Additionally, no: training_step_end & validation_step_end, hiddens parameter in training_step and tbptt_split_batch, transfer_batch_to_device, get_progress_bar_dict, on_train_epoch_end, manual_backward, backward, optimizer_step, optimizer_zero_grad.
In addition, we also patched some LightningModule methods to make porting your code easier:

- log and log_dict are patched to always ship their values to TensorBoard. In the current version, only the first two arguments in log (key and value) and the first argument in log_dict are supported.
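For example, here is a minimal sketch of how the patched methods might be used inside your LightningModule's training_step; the metric names are illustrative, not part of the API:

import torch.nn.functional as F

def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.cross_entropy(logits, y)
    # Only the first two arguments (key and value) are honored; options such as
    # on_step, on_epoch, or prog_bar are not supported in the current version.
    self.log("train_loss", loss)
    # For log_dict, only the first argument (the dictionary) is supported.
    self.log_dict({"train_acc": (logits.argmax(dim=1) == y).float().mean()})
    return loss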
Note

Make sure to return the metric you defined as searcher.metric in your experiment's configuration from your validation_step.
Note

Determined will automatically log the metrics you return from training_step and validation_step to TensorBoard.
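For example, if searcher.metric is set to "val_loss" in the experiment configuration (an illustrative name), your LightningModule's validation_step should return it:

import torch.nn.functional as F

def validation_step(self, batch, batch_idx):
    x, y = batch
    loss = F.cross_entropy(self(x), y)
    # The key must match searcher.metric in the experiment configuration.
    return {"val_loss": loss}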
To learn about this API, you can start by reading the trial definitions from the following examples:
Loading Data¶
Note
Before loading data, read the Prepare Data document to understand how to work with different sources of data.
Loading your dataset when using the PyTorch Lightning API works the same way as it does with the PyTorch API.

If you already have a LightningDataModule, you can bring it in and use it to easily implement the build_training_data_loader and build_validation_data_loader methods. For more information, read PyTorchTrial's section on Data Loading.
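If you do not have a LightningDataModule, you can build the loaders directly from a torch Dataset. Here is a minimal sketch, assuming a hypothetical make_dataset() helper that returns a torch.utils.data.Dataset:

from determined.pytorch import DataLoader

def build_training_data_loader(self) -> DataLoader:
    # make_dataset() is a hypothetical helper that constructs your training Dataset.
    train_ds = make_dataset(self.context.get_data_config()["url"], train=True)
    return DataLoader(
        train_ds,
        batch_size=self.context.get_per_slot_batch_size(),
        shuffle=True,
    )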
API Reference¶
determined.pytorch.lightning.LightningAdapter
¶
- class determined.pytorch.lightning.LightningAdapter(context: determined.pytorch._pytorch_context.PyTorchTrialContext, lightning_module: pytorch_lightning.LightningModule, precision: Union[typing_extensions.Literal[32], typing_extensions.Literal[16]] = 32, amp_backend: Union[typing_extensions.Literal['native'], typing_extensions.Literal['apex']] = 'native', amp_level: typing_extensions.Literal['O0', 'O1', 'O2', 'O3'] = 'O2')¶
PyTorch Lightning Adapter provides a quick way to train your PyTorch Lightning models with all the Determined features, such as mid-epoch preemption, simple distributed training interface, simple job submission to the Determined cluster, and so on.
- __init__(context: determined.pytorch._pytorch_context.PyTorchTrialContext, lightning_module: pytorch_lightning.LightningModule, precision: Union[typing_extensions.Literal[32], typing_extensions.Literal[16]] = 32, amp_backend: Union[typing_extensions.Literal['native'], typing_extensions.Literal['apex']] = 'native', amp_level: typing_extensions.Literal['O0', 'O1', 'O2', 'O3'] = 'O2')¶
This performs the necessary initialization steps to:

- check the compatibility of the provided LightningModule with LightningAdapter.
- define a PyTorchTrial with models, optimizers, and LR schedulers that are provided by LightningModule.
- patch the LightningModule methods that depend on a Trainer.
After inheriting this class, you need to override this function to initialize the adapted PyTorchTrial. Within your __init__, you should instantiate the LightningModule and call super().__init__.

Here is a minimal code example.
def __init__(self, context: PyTorchTrialContext) -> None:
    lm = mnist.LightningMNISTClassifier(lr=context.get_hparam('learning_rate'))
    super().__init__(context, lightning_module=lm)
- Parameters
  context (PyTorchTrialContext) –
  lightning_module (LightningModule) – User-defined lightning module.
  precision (int, default=32) – Precision to use. Accepted values are 16 and 32.
  amp_backend (str) – Automatic mixed precision backend to use. Accepted values are "native" and "apex".
  amp_level (str, optional, default="O2") – Apex amp optimization level. Accepted values are "O0", "O1", "O2", and "O3". See https://nvidia.github.io/apex/amp.html#opt-levels-and-properties
- build_callbacks() Dict[str, determined.pytorch._callback.PyTorchCallback] ¶
build_callbacks defines a set of PyTorchCallbacks that are necessary to support Lightning. Override this method and merge its output with your desired callbacks.
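For example, a sketch of merging a user-defined callback (here a hypothetical MyTimingCallback based on determined.pytorch.PyTorchCallback) might look like:

from determined.pytorch import PyTorchCallback

class MyTimingCallback(PyTorchCallback):
    # hypothetical user-defined callback
    def on_validation_end(self, metrics):
        print("validation ended with metrics:", metrics)

# inside your LightningAdapter subclass:
def build_callbacks(self):
    callbacks = super().build_callbacks()
    callbacks["timing"] = MyTimingCallback()
    return callbacks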
- abstract build_training_data_loader() determined.pytorch._data.DataLoader ¶
Defines the data loader to use during training.
Must return an instance of determined.pytorch.DataLoader.

If you're using a LightningDataModule, this could be as simple as:

self.dm.setup()
dl = self.dm.train_dataloader()
return DataLoader(dl.dataset, batch_size=dl.batch_size, num_workers=dl.num_workers)
- abstract build_validation_data_loader() determined.pytorch._data.DataLoader ¶
Defines the data loader to use during validation.
Must return an instance of determined.pytorch.DataLoader.

If you're using a LightningDataModule, this could be as simple as:

self.dm.setup()
dl = self.dm.val_dataloader()
return DataLoader(dl.dataset, batch_size=dl.batch_size, num_workers=dl.num_workers)
- evaluation_reducer() Union[determined.pytorch._reducer.Reducer, Dict[str, determined.pytorch._reducer.Reducer]] ¶
Return a reducer for all evaluation metrics, or a dict mapping metric names to individual reducers. Defaults to determined.pytorch.Reducer.AVG.
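For example, a sketch that averages a loss metric while taking the maximum of an accuracy metric (the metric names are illustrative and must match what your validation returns):

from determined.pytorch import Reducer

def evaluation_reducer(self):
    # keys must match the metric names produced during validation.
    return {"val_loss": Reducer.AVG, "val_accuracy": Reducer.MAX}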
- get_batch_length(batch: Any) int ¶
Count the number of records in a given batch.
Override this method when you are using custom batch types, as produced when iterating over the DataLoader. For example, when using pytorch_geometric:
# Extra imports:
from determined.pytorch import DataLoader
from torch_geometric.data.dataloader import Collater

# Trial methods:
def build_training_data_loader(self):
    return DataLoader(
        self.train_subset,
        batch_size=self.context.get_per_slot_batch_size(),
        collate_fn=Collater([], []),
    )

def get_batch_length(self, batch):
    # `batch` is `torch_geometric.data.batch.Batch`.
    return batch.num_graphs
- Parameters
batch (Any) – input training or validation data batch object.