PyTorch Lightning API
This document guides you through training a PyTorch Lightning model in Determined. You need to implement a trial class that inherits from LightningAdapter and specify it as the entrypoint in the experiment configuration.
The PyTorch Lightning Adapter, implemented as LightningAdapter, provides a quick way to train your PyTorch Lightning models with Determined features such as mid-epoch preemption, easy distributed training, and simple job submission to the Determined cluster.
LightningAdapter is built on top of our PyTorch API, which has a built-in training loop that integrates with Determined features. However, it supports only the LightningModule (v1.2.0), not the Lightning Trainer. To migrate code that relies on the Trainer, read more about the PyTorch API and the Experiment Configuration Reference.
Port PyTorch Lightning Code
Porting your PyTorch Lightning code is often straightforward:
1. Bring in your LightningModule and LightningDataModule and initialize them.
2. Create a new trial based on LightningAdapter and initialize it.
3. Define the data loaders.
Here is an example:
from determined.pytorch import DataLoader, PyTorchTrialContext
from determined.pytorch.lightning import LightningAdapter

# Bring in your LightningModule and, optionally, your LightningDataModule.
from mnist import LightningMNISTClassifier, MNISTDataModule


class MNISTTrial(LightningAdapter):
    def __init__(self, context: PyTorchTrialContext) -> None:
        # Instantiate your LightningModule with hyperparameters from the Determined
        # config file or from the searcher for automatic hyperparameter tuning.
        lm = LightningMNISTClassifier(lr=context.get_hparam("learning_rate"))

        # Instantiate your LightningDataModule and make it ready for distributed
        # training by giving each worker (rank) its own data directory.
        data_dir = f"/tmp/data-rank{context.distributed.get_rank()}"
        self.dm = MNISTDataModule(context.get_data_config()["url"], data_dir)

        # Initialize LightningAdapter.
        super().__init__(context, lightning_module=lm)
        self.dm.prepare_data()

    def build_training_data_loader(self) -> DataLoader:
        self.dm.setup()
        dl = self.dm.train_dataloader()
        # Re-wrap the Lightning data loader in Determined's DataLoader.
        return DataLoader(
            dl.dataset, batch_size=dl.batch_size, num_workers=dl.num_workers
        )

    def build_validation_data_loader(self) -> DataLoader:
        self.dm.setup()
        dl = self.dm.val_dataloader()
        # Re-wrap the Lightning data loader in Determined's DataLoader.
        return DataLoader(
            dl.dataset, batch_size=dl.batch_size, num_workers=dl.num_workers
        )
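The trial above reads learning_rate and data.url from the experiment configuration, and the entrypoint field tells Determined which class to load. As a rough sketch, assuming the trial lives in a file named model_def.py (a hypothetical name) and that validation_step returns val_loss, the relevant fields might look like the following, written as the equivalent Python dict. The URL and values are placeholders, and other required fields are omitted; consult the Experiment Configuration Reference for the full set.

```python
import json

# Hypothetical experiment configuration for the MNISTTrial example, written as a
# Python dict that mirrors the YAML file. "model_def", the URL and the values are
# placeholders; other required fields (for example the searcher's training
# length) are omitted here.
experiment_config = {
    "name": "mnist-lightning",
    # "<module>:<class>": the trial class Determined should instantiate.
    "entrypoint": "model_def:MNISTTrial",
    "hyperparameters": {
        # Read in the trial with context.get_hparam("learning_rate").
        "learning_rate": 0.001,
    },
    "data": {
        # Read in the trial with context.get_data_config()["url"].
        "url": "https://example.com/mnist-data",
    },
    "searcher": {
        "name": "single",
        # Must be a key in the dict returned by validation_step.
        "metric": "val_loss",
        "smaller_is_better": True,
    },
}

print(json.dumps(experiment_config, indent=2))
```

In practice you would write these fields as YAML in the configuration file you submit; the dict form only makes the structure explicit.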
In this approach, the LightningModule is not paired with the PyTorch Lightning Trainer, so some methods and hooks are not supported:

- No separate test-set definition in Determined: test_step, test_step_end, test_epoch_end, on_test_batch_start, on_test_batch_end, on_test_epoch_start, on_test_epoch_end, test_dataloader.
- No fit or pre-train stage: setup, teardown, on_fit_start, on_fit_end, on_pretrain_routine_start, on_pretrain_routine_end.
- Other unsupported methods and hooks: training_step_end and validation_step_end, the hiddens parameter in training_step, tbptt_split_batch, transfer_batch_to_device, get_progress_bar_dict, on_train_epoch_end, manual_backward, backward, optimizer_step, optimizer_zero_grad.
We also patched some LightningModule methods to make porting your code easier: log and log_dict are patched to always send their values to TensorBoard. In the current version, only the first two arguments of log (key and value) and the first argument of log_dict are supported.
Note
Make sure your validation_step returns the metric you defined as searcher.metric in your experiment’s configuration.
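To catch a mismatch before launching an experiment, you could compare the dict your validation_step returns with the searcher.metric entry of the configuration. This is a hypothetical sketch, not part of Determined; it assumes the configuration has been loaded into a plain dict (for example with a YAML parser) and uses val_loss only as an illustrative metric name.

```python
from typing import Any, Mapping


def check_searcher_metric(
    exp_config: Mapping[str, Any], val_outputs: Mapping[str, Any]
) -> str:
    """Return the searcher metric name, or raise KeyError if the dict returned
    by validation_step does not contain it. Hypothetical helper."""
    metric = exp_config["searcher"]["metric"]
    if metric not in val_outputs:
        raise KeyError(
            f"validation_step must return {metric!r} (searcher.metric); "
            f"it returned keys {sorted(val_outputs)}"
        )
    return metric


# A hand-written config and a dict shaped like validation_step output.
config = {"searcher": {"name": "single", "metric": "val_loss", "smaller_is_better": True}}
print(check_searcher_metric(config, {"val_loss": 0.12, "accuracy": 0.97}))  # prints val_loss
```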
Note
Determined automatically logs the metrics you return from training_step and validation_step to TensorBoard.
Note
Determined environment images no longer contain PyTorch Lightning. To use PyTorch Lightning, add a line similar to the following to the startup-hooks.sh script:
pip install pytorch_lightning==1.5.10 torchmetrics==0.5.1
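Because the library now comes from the startup hook rather than the image, a trial can fail fast if the hook did not install the expected versions. A minimal sketch using the standard-library importlib.metadata (Python 3.8+); version_mismatches and EXPECTED_VERSIONS are hypothetical names, and the pins simply repeat the pip line above.

```python
from importlib import metadata
from typing import Dict, Optional

# Pins copied from the startup-hooks.sh line above.
EXPECTED_VERSIONS = {"pytorch_lightning": "1.5.10", "torchmetrics": "0.5.1"}


def version_mismatches(expected: Dict[str, str]) -> Dict[str, Optional[str]]:
    """Map each package whose installed version differs from `expected` to the
    installed version, or to None if the package is not installed."""
    mismatches: Dict[str, Optional[str]] = {}
    for package, wanted in expected.items():
        try:
            installed: Optional[str] = metadata.version(package)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[package] = installed
    return mismatches
```

In a trial's __init__ you might raise a RuntimeError when version_mismatches(EXPECTED_VERSIONS) returns a non-empty dict, so a missing or mismatched install surfaces before training starts.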
To learn about this API, start by reading the trial definitions from the following examples:
Load Data
Note
Before loading data, read Prepare Data to understand how to work with different data sources.
Loading your dataset with PyTorch Lightning works the same way as it does with the PyTorch API. If you already have a LightningDataModule, you can bring it in and use it to implement the build_training_data_loader and build_validation_data_loader methods. For more information, read PyTorchTrial’s section on Data Loading.
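Since both build_*_data_loader methods repeat the same re-wrapping step, you might factor it into a small helper. The sketch below is hypothetical (rewrap is not part of Determined). It takes the target loader class as an argument, which in a trial would be DataLoader from determined.pytorch, and it copies only the three settings that Determined's MNIST example copies.

```python
from typing import Any, Callable, TypeVar

T = TypeVar("T")


def rewrap(dl: Any, loader_cls: Callable[..., T]) -> T:
    """Rebuild a Lightning-produced loader as `loader_cls`, copying dataset,
    batch_size and num_workers. Other settings (sampler, shuffling, collate_fn,
    ...) are not carried over. Hypothetical helper, not a Determined API."""
    return loader_cls(dl.dataset, batch_size=dl.batch_size, num_workers=dl.num_workers)
```

With this helper, build_validation_data_loader could call self.dm.setup() and then return rewrap(self.dm.val_dataloader(), DataLoader).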