PyTorch Lightning API Reference
determined.pytorch.lightning.LightningAdapter
- class determined.pytorch.lightning.LightningAdapter(context: determined.pytorch._pytorch_context.PyTorchTrialContext, lightning_module: pytorch_lightning.LightningModule, precision: Union[typing_extensions.Literal[32], typing_extensions.Literal[16]] = 32, amp_backend: Union[typing_extensions.Literal['native'], typing_extensions.Literal['apex']] = 'native', amp_level: typing_extensions.Literal['O0', 'O1', 'O2', 'O3'] = 'O2')
The PyTorch Lightning adapter provides a quick way to train your PyTorch Lightning models with Determined features such as mid-epoch preemption, a simple distributed training interface, and simple job submission to the Determined cluster.
- __init__(context: determined.pytorch._pytorch_context.PyTorchTrialContext, lightning_module: pytorch_lightning.LightningModule, precision: Union[typing_extensions.Literal[32], typing_extensions.Literal[16]] = 32, amp_backend: Union[typing_extensions.Literal['native'], typing_extensions.Literal['apex']] = 'native', amp_level: typing_extensions.Literal['O0', 'O1', 'O2', 'O3'] = 'O2')
This performs the necessary initialization steps to:
  - check the compatibility of the provided LightningModule with LightningAdapter;
  - define a PyTorchTrial with the models, optimizers, and LR schedulers provided by the LightningModule;
  - patch the LightningModule methods that depend on a Trainer.
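The third step, patching Trainer-dependent methods, can be illustrated in plain Python. This is not Determined's implementation: ToyLightningModule, patch_log, and the dict used as a metrics sink are hypothetical. The sketch only shows how a method can be replaced on a single instance so that code calling self.log() no longer needs a Trainer.

```python
import types


class ToyLightningModule:
    """Hypothetical stand-in for a LightningModule whose log() needs a Trainer."""

    trainer = None  # outside Lightning's Trainer loop, nothing sets this

    def log(self, name, value):
        if self.trainer is None:
            raise RuntimeError("log() requires an attached Trainer")
        self.trainer.record(name, value)

    def training_step(self, batch):
        loss = sum(batch) / len(batch)
        self.log("train_loss", loss)  # depends on log(), and therefore on the Trainer
        return loss


def patch_log(module, sink):
    """Replace log() on this one instance so values are written into `sink`."""

    def log(self, name, value):
        sink[name] = value

    # MethodType binds the function to this instance only; the class is unchanged.
    module.log = types.MethodType(log, module)
    return module


metrics = {}
lm = patch_log(ToyLightningModule(), metrics)
lm.training_step([1.0, 2.0, 3.0])
print(metrics)  # {'train_loss': 2.0}
```

Patching the instance rather than the class keeps other, unpatched module instances behaving exactly as Lightning expects.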
After inheriting this class, override __init__ to initialize the adapted PyTorchTrial. Within your __init__, instantiate the LightningModule and call super().__init__. Here is a minimal code example:

```python
from determined.pytorch import PyTorchTrialContext

def __init__(self, context: PyTorchTrialContext) -> None:
    lm = mnist.LightningMNISTClassifier(lr=context.get_hparam('learning_rate'))  # mnist: your model module
    super().__init__(context, lightning_module=lm)
```
- Parameters
context (PyTorchTrialContext) – The trial context passed to your trial's __init__.
lightning_module (LightningModule) – User-defined Lightning module.
precision (int, default=32) – Precision to use. Accepted values are 16 and 32.
amp_backend (str, default="native") – Automatic mixed precision backend to use. Accepted values are “native” and “apex”.
amp_level (str, optional, default="O2") – Apex AMP optimization level. Accepted values are “O0”, “O1”, “O2”, and “O3”. https://nvidia.github.io/apex/amp.html#opt-levels-and-properties
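As a quick reference for the accepted values, here is a minimal sketch of a pre-flight check you might run on hyperparameters before passing them to LightningAdapter. The function check_amp_args is hypothetical and not part of Determined; it mirrors the allowed values from the signature and returns whether mixed precision (precision=16) is requested.

```python
ACCEPTED_PRECISION = (16, 32)
ACCEPTED_BACKENDS = ("native", "apex")
ACCEPTED_LEVELS = ("O0", "O1", "O2", "O3")


def check_amp_args(precision=32, amp_backend="native", amp_level="O2"):
    """Validate LightningAdapter-style AMP arguments; return True if mixed precision is requested."""
    if precision not in ACCEPTED_PRECISION:
        raise ValueError(f"precision must be 16 or 32, got {precision!r}")
    if amp_backend not in ACCEPTED_BACKENDS:
        raise ValueError(f"amp_backend must be 'native' or 'apex', got {amp_backend!r}")
    if amp_level not in ACCEPTED_LEVELS:
        raise ValueError(f"amp_level must be one of {ACCEPTED_LEVELS}, got {amp_level!r}")
    # In this sketch, precision=16 is treated as "mixed precision requested".
    return precision == 16


print(check_amp_args())                  # False
print(check_amp_args(16, "apex", "O1"))  # True
```

Failing early on a typo such as amp_backend="mixed" is cheaper than discovering it after the trial has been scheduled.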
- build_callbacks() → Dict[str, determined.pytorch._callback.PyTorchCallback]
build_callbacks defines the set of PyTorchCallback instances needed to support Lightning. To add your own callbacks, override this method and merge the dictionary returned by the parent build_callbacks with your callbacks rather than replacing it.
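The merge can be sketched with plain dictionaries. To keep the example self-contained, PyTorchCallback and AdapterStub are local stand-ins for determined.pytorch.PyTorchCallback and LightningAdapter, and the callback names are hypothetical; in a real trial you would subclass LightningAdapter and return your own PyTorchCallback instances in the same way.

```python
from typing import Dict


class PyTorchCallback:
    """Local stand-in for determined.pytorch.PyTorchCallback."""


class AdapterStub:
    """Stand-in for LightningAdapter; the callback name it returns is hypothetical."""

    def build_callbacks(self) -> Dict[str, PyTorchCallback]:
        return {"lightning_internal": PyTorchCallback()}


class MetricsPrinter(PyTorchCallback):
    """Hypothetical user callback."""


class MyTrial(AdapterStub):
    def build_callbacks(self) -> Dict[str, PyTorchCallback]:
        mine = {"metrics_printer": MetricsPrinter()}
        # Unpack the adapter's callbacks last: on a name collision they win,
        # so a user callback cannot silently replace one that Lightning support needs.
        return {**mine, **super().build_callbacks()}


print(sorted(MyTrial().build_callbacks()))  # ['lightning_internal', 'metrics_printer']
```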
- abstract build_training_data_loader() → determined.pytorch._data.DataLoader
Defines the data loader to use during training.
Must return an instance of determined.pytorch.DataLoader.
If you’re using a LightningDataModule, this could be as simple as:

```python
from determined.pytorch import DataLoader

def build_training_data_loader(self) -> DataLoader:
    self.dm.setup()
    dl = self.dm.train_dataloader()
    return DataLoader(dl.dataset, batch_size=dl.batch_size, num_workers=dl.num_workers)
```
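The short example above passes only batch_size and num_workers, so any other loader settings, such as shuffling or a custom collate_fn, fall back to the DataLoader defaults. Below is a minimal sketch of a helper that copies a few more settings. It assumes determined.pytorch.DataLoader accepts the same keyword arguments as torch.utils.data.DataLoader; loader_kwargs is a hypothetical name.

```python
from typing import Any, Dict


def loader_kwargs(dl: Any) -> Dict[str, Any]:
    """Copy settings from a torch-style DataLoader that the short example drops."""
    return {
        "batch_size": dl.batch_size,
        "num_workers": dl.num_workers,
        "collate_fn": dl.collate_fn,
        "drop_last": dl.drop_last,
        # torch's DataLoader keeps no `shuffle` attribute: shuffle=True is stored
        # as a RandomSampler, so infer the flag from the sampler's class name.
        "shuffle": type(dl.sampler).__name__ == "RandomSampler",
    }
```

Inside build_training_data_loader, the return statement would then become `return DataLoader(dl.dataset, **loader_kwargs(dl))`. Custom samplers are deliberately not copied in this sketch.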
- abstract build_validation_data_loader() → determined.pytorch._data.DataLoader
Defines the data loader to use during validation.
Must return an instance of determined.pytorch.DataLoader.
If you’re using a LightningDataModule, this could be as simple as:

```python
from determined.pytorch import DataLoader

def build_validation_data_loader(self) -> DataLoader:
    self.dm.setup()
    dl = self.dm.val_dataloader()
    return DataLoader(dl.dataset, batch_size=dl.batch_size, num_workers=dl.num_workers)
```
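Both data loader examples call self.dm.setup(), and depending on your Lightning version, setup() may not guard against repeated calls, so expensive dataset preparation could run twice. A minimal sketch of a wrapper that runs the real setup once per stage follows; SetupOnce is a hypothetical helper, not part of Determined or Lightning.

```python
from typing import Any, Optional, Set


class SetupOnce:
    """Wrap a data module so setup(stage) runs its real setup only once per stage."""

    def __init__(self, dm: Any) -> None:
        self.dm = dm
        self._done: Set[Optional[str]] = set()

    def setup(self, stage: Optional[str] = None) -> None:
        if stage not in self._done:
            self.dm.setup(stage)
            self._done.add(stage)

    def __getattr__(self, name: str) -> Any:
        # Called only for attributes not found on the wrapper itself, so
        # train_dataloader(), val_dataloader(), etc. go to the wrapped module.
        return getattr(self.dm, name)
```

In the trial's __init__, you would then store `self.dm = SetupOnce(MyDataModule(...))`, where MyDataModule stands for your own LightningDataModule.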
- evaluation_reducer() → Union[determined.pytorch._reducer.Reducer, Dict[str, determined.pytorch._reducer.Reducer]]
Return a reducer for all evaluation metrics, or a dict mapping metric names to individual reducers. Defaults to determined.pytorch.Reducer.AVG.
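With the default AVG, a count-style metric such as the number of correct predictions is averaged across batches instead of totaled. The stdlib sketch below shows the difference; Reducer here is a local stand-in for determined.pytorch.Reducer, and the unweighted mean is an assumption of the sketch, since Determined's own aggregation (for example, across workers) may differ in detail.

```python
from enum import Enum
from typing import Dict, List


class Reducer(Enum):
    """Local stand-in for determined.pytorch.Reducer; only AVG and SUM are modeled."""

    AVG = "avg"
    SUM = "sum"


def reduce_metrics(per_batch: List[Dict[str, float]], reducers: Dict[str, Reducer]) -> Dict[str, float]:
    """Combine per-batch metric dicts into one value per metric (unweighted mean for AVG)."""
    result = {}
    for name, reducer in reducers.items():
        values = [metrics[name] for metrics in per_batch]
        result[name] = sum(values) / len(values) if reducer is Reducer.AVG else sum(values)
    return result


batches = [{"val_loss": 0.5, "num_correct": 30}, {"val_loss": 0.25, "num_correct": 28}]
print(reduce_metrics(batches, {"val_loss": Reducer.AVG, "num_correct": Reducer.SUM}))
# {'val_loss': 0.375, 'num_correct': 58}
```

In a trial, the equivalent choice would be returning a dict such as {"val_loss": pytorch.Reducer.AVG, "num_correct": pytorch.Reducer.SUM} from evaluation_reducer().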
- get_batch_length(batch: Any) → int
Count the number of records in a given batch.
Override this method when iterating over your DataLoader produces a custom batch type. For example, when using pytorch_geometric:

```python
# Extra imports:
from determined.pytorch import DataLoader
from torch_geometric.data.dataloader import Collater

# Trial methods:
def build_training_data_loader(self):
    return DataLoader(
        self.train_subset,
        batch_size=self.context.get_per_slot_batch_size(),
        collate_fn=Collater([], []),
    )

def get_batch_length(self, batch):
    # `batch` is `torch_geometric.data.batch.Batch`.
    return batch.num_graphs
```
- Parameters
batch (Any) – input training or validation data batch object.
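The same idea applies to any custom batch produced by a collate_fn: count records, not tokens or tensor elements. A minimal sketch, where TokenBatch is a hypothetical batch type with one label per record:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class TokenBatch:
    """Hypothetical custom batch built by a collate_fn: parallel per-record fields."""

    input_ids: List[List[int]]  # one variable-length token list per record
    labels: List[int]  # one label per record


def get_batch_length(batch: TokenBatch) -> int:
    # One label per record, so len(labels) is the record count. Counting tokens
    # (the total length of input_ids) would overstate the batch size.
    return len(batch.labels)


batch = TokenBatch(input_ids=[[101, 7592, 102], [101, 102]], labels=[1, 0])
print(get_batch_length(batch))  # 2
```

As a trial method, this becomes `def get_batch_length(self, batch): return len(batch.labels)`.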