PyTorch

This part of the documentation describes how to train a PyTorch model in PEDL.

There are two steps needed to define a PyTorch model in PEDL using a Standard Model Definition:

  1. Define a make_data_loaders() function. See Data Loading for more information.
  2. Implement the PyTorchTrial interface.

Data Loading

Loading data into PyTorchTrial models is done by defining a make_data_loaders() function. This function should return a pair of objects (one for training and one for validation), each of which implements PyTorch's DataLoader interface.
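As a sketch of what this might look like, the following uses a placeholder in-memory dataset; the dataset contents, the hyperparameter name "batch_size", and the exact function signature are illustrative assumptions, not part of the PEDL API:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_data_loaders(experiment_config, hparams):
    # Placeholder in-memory data; a real model definition would load its
    # actual training and validation datasets here.
    train_ds = TensorDataset(torch.randn(100, 2), torch.randn(100, 1))
    val_ds = TensorDataset(torch.randn(20, 2), torch.randn(20, 1))

    batch_size = hparams.get("batch_size", 32)  # hypothetical hyperparameter name
    # Return a (training, validation) pair of DataLoader objects.
    return (
        DataLoader(train_ds, batch_size=batch_size, shuffle=True),
        DataLoader(val_ds, batch_size=batch_size),
    )
```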

Each DataLoader is expected to return batches in the form (input, target), where input and target are each one of the following types:

  • np.ndarray

    np.array([[0, 0], [0, 0]])
    

  • torch.Tensor

    torch.Tensor([[0, 0], [0, 0]])
    

  • tuple of np.ndarrays or torch.Tensors

    (torch.Tensor([0, 0]), torch.Tensor([[0, 0], [0, 0]]))
    

  • list of np.ndarrays or torch.Tensors

    [torch.Tensor([0, 0]), torch.Tensor([[0, 0], [0, 0]])]
    

  • dictionary mapping strings to np.ndarrays or torch.Tensors

    {"input1": torch.Tensor([[0, 0], [0, 0]]), "input2": torch.Tensor([[1, 1], [1, 1]])}
    

  • combination of the above

    {
        "input1": [
            {"sub_input1": torch.Tensor([[0, 0], [0, 0]])},
            {"sub_input2": torch.Tensor([0, 0])},
        ],
        "input2": (torch.Tensor([0, 0]), torch.Tensor([[0, 0], [0, 0]])),
    }
    

PyTorchTrial Interface

PyTorch trials are created by subclassing the abstract class PyTorchTrial. Users must define the following abstract methods to create the deep learning model associated with a specific trial, and to subsequently train and evaluate it:

  • build_model(self, hparams): Defines the deep learning architecture associated with a trial, which typically depends on the trial's specific hyperparameter settings stored in the hparams dictionary. This method returns the model as an instance of nn.Module (or a subclass of it).

    The input to the model's forward method will be the input portion of each batch produced by this experiment's data loaders (see Data Loading). For simple models, that input will often be a plain tensor, but for multi-input models, it will be in the form of a dictionary of named inputs.

    The output of the model's forward method will be fed directly into the user-defined losses, training_metrics, and validation_metrics methods as predictions.

  • losses(self, predictions, labels): Calculates loss(es) of the model. If the model only returns a single loss, the output of this method can be a scalar tensor which will be used for backpropagation. If the model reports multiple losses, this method must return a dictionary of losses which contains the special key "loss" corresponding to a scalar tensor which will be used for backpropagation. The output of this method is fed directly into the training_metrics and validation_metrics methods.

  • optimizer(self, model): Specifies an instance of torch.optim.Optimizer to be used for training the given model, e.g., torch.optim.SGD(model.parameters(), learning_rate).
  • batch_size(self): Specifies the batch size to use for training.
  • validation_metrics(self, predictions, labels, losses): Calculates and returns a dictionary mapping string names to validation metrics. Metrics may be non-scalar tensors. Results from each batch of validation data will be averaged to compute validation metrics for a given model.

Optional Methods

  • training_metrics(self, predictions, labels, losses): Calculates and returns a dictionary mapping string names to training metrics. If supplied, this method defines a set of metrics to be computed in addition to the training loss. Metrics may be non-scalar tensors.
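Putting these methods together, a minimal trial might look like the following sketch. The PyTorchTrial base class below is a stand-in so the example runs self-contained; in a real model definition you would subclass PEDL's PyTorchTrial instead, and the architecture, hyperparameter names, and metric names are illustrative assumptions:

```python
import torch
import torch.nn as nn

class PyTorchTrial:
    # Stand-in for PEDL's abstract base class, so this sketch is
    # self-contained; subclass the real PyTorchTrial in a model definition.
    pass

class MyTrial(PyTorchTrial):
    def build_model(self, hparams):
        # Architecture driven by a (hypothetical) "hidden_size" hyperparameter.
        hidden = hparams.get("hidden_size", 8)
        return nn.Sequential(nn.Linear(2, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def losses(self, predictions, labels):
        # Single loss: a scalar tensor, used directly for backpropagation.
        # A multi-loss model would instead return a dictionary containing
        # the special key "loss", e.g. {"loss": total, "aux": aux_loss}.
        return nn.functional.mse_loss(predictions, labels)

    def optimizer(self, model):
        return torch.optim.SGD(model.parameters(), lr=0.01)

    def batch_size(self):
        return 32

    def validation_metrics(self, predictions, labels, losses):
        # Per-batch results are averaged across the validation set.
        return {"mse": losses, "mae": (predictions - labels).abs().mean()}

    def training_metrics(self, predictions, labels, losses):
        # Optional: metrics reported in addition to the training loss.
        return {"mae": (predictions - labels).abs().mean()}
```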

Examples