
Using Checkpoints

Determined provides APIs for downloading checkpoints and loading them into memory in a Python process.

This guide discusses:

  1. Querying model checkpoints from trials and experiments.

  2. Loading model checkpoints in Python.

  3. Storing additional user-defined metadata in a checkpoint.

  4. Using the Determined CLI to download checkpoints to disk.

Querying Checkpoints

The ExperimentReference class is a reference to an experiment; you obtain one from the Determined class. Its top_checkpoint() method, when called without arguments, reads the metric and smaller_is_better values from the searcher field of the experiment configuration and uses them to sort the experiment’s checkpoints by validation performance. For example, the searcher settings in the following snippet from an experiment configuration file sort checkpoints by the loss metric in ascending order.

searcher:
  metric: "loss"
  smaller_is_better: true

The following snippet of Python code can be run after the specified experiment has generated a checkpoint. It returns an instance of Checkpoint representing the checkpoint that has the best validation metric.

from determined.experimental import Determined

checkpoint = Determined().get_experiment(id).top_checkpoint()

Checkpoints can be sorted by any metric using two keyword arguments: sort_by, which names the metric to use, and smaller_is_better, which determines whether checkpoints are sorted in ascending or descending order of that metric.

from determined.experimental import Determined

checkpoint = Determined().get_experiment(id).top_checkpoint(sort_by="accuracy", smaller_is_better=False)
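To make the sorting rule concrete, the following sketch reproduces the selection logic in plain Python over hypothetical in-memory records. The CheckpointRecord class and best_checkpoint() function are illustrative assumptions, not part of the Determined API; the real top_checkpoint() asks the Determined master for this information.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class CheckpointRecord:
    """Hypothetical stand-in for a checkpoint and its validation metrics."""
    uuid: str
    validation_metrics: Dict[str, float] = field(default_factory=dict)


def best_checkpoint(
    records: List[CheckpointRecord], sort_by: str, smaller_is_better: bool
) -> CheckpointRecord:
    """Return the record with the best value of `sort_by`.

    Records that did not report `sort_by` are skipped (a choice made for this sketch).
    """
    candidates = [r for r in records if sort_by in r.validation_metrics]
    if not candidates:
        raise ValueError(f"no checkpoint reports metric {sort_by!r}")

    def metric(r: CheckpointRecord) -> float:
        return r.validation_metrics[sort_by]

    # Pick the minimum when smaller is better, the maximum otherwise.
    return min(candidates, key=metric) if smaller_is_better else max(candidates, key=metric)


records = [
    CheckpointRecord("a", {"loss": 0.30, "accuracy": 0.91}),
    CheckpointRecord("b", {"loss": 0.25, "accuracy": 0.89}),
    CheckpointRecord("c", {"loss": 0.40, "accuracy": 0.93}),
]
print(best_checkpoint(records, "loss", smaller_is_better=True).uuid)  # b
print(best_checkpoint(records, "accuracy", smaller_is_better=False).uuid)  # c
```

Note how the same set of checkpoints yields a different winner depending on which metric and direction are chosen.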

You may also query multiple checkpoints at the same time using the top_n_checkpoints() method. Only the single best checkpoint from each trial is considered; out of those, the checkpoints with the best validation metric values are returned in sorted order, with the best one first. For example, the following snippet returns the top five checkpoints from distinct trials of a specified experiment.

from determined.experimental import Determined

checkpoints = Determined().get_experiment(id).top_n_checkpoints(5)

This method also accepts sort_by and smaller_is_better arguments.
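The "one checkpoint per trial" rule means the result can differ from simply taking the n best checkpoints overall. The sketch below models the two-step selection over hypothetical records; CheckpointRecord and top_n() are illustrative names, not Determined APIs.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class CheckpointRecord:
    """Hypothetical stand-in for a checkpoint and the trial that produced it."""
    uuid: str
    trial_id: int
    validation_metrics: Dict[str, float]


def top_n(
    records: List[CheckpointRecord], n: int, sort_by: str, smaller_is_better: bool
) -> List[CheckpointRecord]:
    """Keep only the best checkpoint of each trial, then return the n best, best first."""

    def metric(r: CheckpointRecord) -> float:
        return r.validation_metrics[sort_by]

    def better(a: CheckpointRecord, b: CheckpointRecord) -> bool:
        return metric(a) < metric(b) if smaller_is_better else metric(a) > metric(b)

    # Step 1: reduce to a single (best) checkpoint per trial.
    best_per_trial: Dict[int, CheckpointRecord] = {}
    for r in records:
        current = best_per_trial.get(r.trial_id)
        if current is None or better(r, current):
            best_per_trial[r.trial_id] = r

    # Step 2: rank the per-trial winners and keep the first n.
    ranked = sorted(best_per_trial.values(), key=metric, reverse=not smaller_is_better)
    return ranked[:n]


records = [
    CheckpointRecord("t1-a", 1, {"loss": 0.15}),
    CheckpointRecord("t1-b", 1, {"loss": 0.20}),
    CheckpointRecord("t2-a", 2, {"loss": 0.30}),
    CheckpointRecord("t3-a", 3, {"loss": 0.10}),
]
print([r.uuid for r in top_n(records, 3, "loss", smaller_is_better=True)])
# ['t3-a', 't1-a', 't2-a']
```

Here t1-b (loss 0.20) is left out even though it beats t2-a (loss 0.30), because trial 1 is already represented by its better checkpoint t1-a.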

The TrialReference class provides fine-grained control over checkpoint selection within a trial. Its top_checkpoint() method mirrors top_checkpoint() for an experiment, and its select_checkpoint() method offers three ways to query checkpoints:

  1. best: Returns the best checkpoint based on validation metrics as discussed above. When using best, smaller_is_better and sort_by are also accepted.

  2. latest: Returns the most recent checkpoint for the trial.

  3. uuid: Returns the checkpoint with the specified UUID.

The following snippet showcases how to use the different modes for selecting checkpoints.

from determined.experimental import Determined

trial = Determined().get_trial(id)

best_checkpoint = trial.top_checkpoint()

most_accurate_checkpoint = trial.select_checkpoint(
    best=True,
    sort_by="accuracy",
    smaller_is_better=False
)

most_recent_checkpoint = trial.select_checkpoint(latest=True)

specific_checkpoint = trial.select_checkpoint(uuid="uuid-for-checkpoint")
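The latest and uuid modes can be pictured with the small sketch below, which works on hypothetical records rather than a live trial. The select() function and CheckpointRecord class are illustrative only, and the sketch assumes that "most recent" means the checkpoint taken at the highest batch number.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class CheckpointRecord:
    """Hypothetical record: a checkpoint UUID and the batch it was taken at."""
    uuid: str
    batch: int


def select(
    records: List[CheckpointRecord], latest: bool = False, uuid: Optional[str] = None
) -> CheckpointRecord:
    """Mimic the `latest` and `uuid` modes; exactly one of them must be used."""
    if latest == (uuid is not None):
        raise ValueError("pass exactly one of latest=True or uuid=...")
    if latest:
        # Assumption: "most recent" means the highest batch number.
        return max(records, key=lambda r: r.batch)
    for r in records:
        if r.uuid == uuid:
            return r
    raise KeyError(f"no checkpoint with UUID {uuid!r}")


records = [CheckpointRecord("u1", 100), CheckpointRecord("u2", 300), CheckpointRecord("u3", 200)]
print(select(records, latest=True).uuid)  # u2
print(select(records, uuid="u3").batch)  # 200
```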

Using the Checkpoint Class

The Checkpoint class can both download the checkpoint from persistent storage and load it into memory in a Python process.

The download() method downloads a checkpoint from persistent storage to a directory on the local file system. By default, checkpoints are downloaded to checkpoints/<checkpoint-uuid>/ (relative to the current working directory). The download() method accepts path as an optional parameter, which changes the checkpoint download location.

from determined.experimental import Determined

checkpoint = Determined().get_experiment(id).top_checkpoint()
checkpoint_path = checkpoint.download()

specific_path = checkpoint.download(path="specific-checkpoint-path")

The load() method downloads the checkpoint, if it does not already exist locally, and then loads it into memory in a Python process, as shown in the following snippet.

from determined.experimental import Determined

checkpoint = Determined().get_experiment(id).top_checkpoint()
model = checkpoint.load()
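Downloading only when the files are not already present is a common caching pattern. The sketch below shows that pattern in isolation: ensure_local() and its download callback are hypothetical names, and treating a non-empty directory as "already downloaded" is an assumption of this sketch, not a description of how load() decides.

```python
from pathlib import Path
from typing import Callable, Optional


def ensure_local(
    checkpoint_uuid: str,
    download: Callable[[Path], None],
    path: Optional[str] = None,
) -> Path:
    """Return a local checkpoint directory, downloading only when needed.

    `download` is a hypothetical callback that writes the checkpoint files
    into the directory it receives. Without `path`, the documented default
    checkpoints/<checkpoint-uuid>/ (relative to the working directory) is used.
    """
    target = Path(path) if path is not None else Path("checkpoints") / checkpoint_uuid
    # Treat a missing or empty directory as "not downloaded yet".
    already_present = target.is_dir() and any(target.iterdir())
    if not already_present:
        target.mkdir(parents=True, exist_ok=True)
        download(target)
    return target
```

Calling ensure_local() twice with the same path invokes the download callback only once; the second call finds the files and returns immediately.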

TensorFlow checkpoints are saved in either the saved_model or the h5 format and are loaded as trackable objects (see the documentation for tf.compat.v1.saved_model.load_v2 for details).

PyTorch checkpoints are saved using pickle and loaded as determined.pytorch.PyTorchTrial objects (see the PyTorch documentation for details).

User-Defined Checkpoint Metadata

You can add arbitrary user-defined metadata to a checkpoint via the Python API. This feature is useful for storing post-training metrics, labels, information related to deployment, etc.

from determined.experimental import Determined

checkpoint = Determined().get_experiment(id).top_checkpoint()
checkpoint.add_metadata({"environment": "production"})

# Metadata will be stored in Determined and accessible on the checkpoint object.
print(checkpoint.metadata)

The add_metadata() method accepts an arbitrarily nested dictionary. If a top-level key already exists, the entire tree beneath it is overwritten.

from determined.experimental import Determined

checkpoint = Determined().get_experiment(id).top_checkpoint()
checkpoint.add_metadata({"metrics": {"loss": 0.12}})
checkpoint.add_metadata({"metrics": {"acc": 0.92}})

print(checkpoint.metadata)  # Output: {'metrics': {'acc': 0.92}}
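The overwrite rule behaves like a shallow dictionary update. The following sketch models it with a plain dict; the add_metadata() function here is a local stand-in, not the Checkpoint method, and it does not contact Determined. It also shows one way to keep existing nested values: merge them client-side before writing.

```python
from typing import Any, Dict


def add_metadata(metadata: Dict[str, Any], new: Dict[str, Any]) -> Dict[str, Any]:
    """Shallow merge: each top-level key in `new` replaces the old subtree."""
    merged = dict(metadata)
    merged.update(new)
    return merged


meta: Dict[str, Any] = {}
meta = add_metadata(meta, {"metrics": {"loss": 0.12}})
meta = add_metadata(meta, {"metrics": {"acc": 0.92}})
print(meta)  # {'metrics': {'acc': 0.92}}

# To keep "loss" as well, merge the nested dictionary before writing it back.
meta = add_metadata(meta, {"metrics": {**meta["metrics"], "loss": 0.12}})
print(meta)  # {'metrics': {'acc': 0.92, 'loss': 0.12}}
```

Against a real Checkpoint, the same client-side merge would read checkpoint.metadata before calling add_metadata(); this assumes metadata is a plain dictionary, as the printed output above suggests.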

You can remove metadata with the remove_metadata() method, which accepts a list of top-level keys and deletes the entire tree beneath each of them.

from determined.experimental import Determined

checkpoint = Determined().get_experiment(id).top_checkpoint()
checkpoint.remove_metadata(["metrics"])
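The removal rule can be modeled the same way. In this sketch, remove_metadata() is a local stand-in operating on a plain dict, not the Checkpoint method; it silently ignores keys that are not present, which is a choice made here rather than documented behavior.

```python
from typing import Any, Dict, Iterable


def remove_metadata(metadata: Dict[str, Any], keys: Iterable[str]) -> Dict[str, Any]:
    """Drop each listed top-level key together with everything nested under it."""
    to_remove = set(keys)
    return {k: v for k, v in metadata.items() if k not in to_remove}


meta = {"environment": "production", "metrics": {"loss": 0.12, "acc": 0.92}}
print(remove_metadata(meta, ["metrics"]))  # {'environment': 'production'}
```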

Download Checkpoints via the CLI

Determined offers the following CLI commands for downloading checkpoints locally:

  1. det checkpoint download

  2. det trial download

  3. det experiment download

The det checkpoint download command downloads a checkpoint for the given UUID as shown below:

# Download a specific checkpoint.
det checkpoint download 46985143-af68-4d48-ab91-a6447052ca49

The command should display output resembling the following upon successfully downloading the checkpoint.

Local checkpoint path:
checkpoints/46985143-af68-4d48-ab91-a6447052ca49

     Batch | Checkpoint UUID                      | Validation Metrics
-----------+--------------------------------------+---------------------------------------------
      1000 | 46985143-af68-4d48-ab91-a6447052ca49 | {
           |                                      |     "num_inputs": 0,
           |                                      |     "validation_metrics": {
           |                                      |         "loss": 7.906739711761475,
           |                                      |         "accuracy": 0.9646000266075134,
           |                                      |         "global_step": 1000,
           |                                      |         "average_loss": 0.12492649257183075
           |                                      |     }
           |                                      | }

The det trial download command downloads checkpoints for a specified trial. Similar to the TrialReference API, the det trial download command accepts --best, --latest, and --uuid options.

# Download best checkpoint.
det trial download <trial_id> --best
# Download best checkpoint to a particular directory.
det trial download <trial_id> --best --output-dir local_checkpoint

The command should display output resembling the following upon successfully downloading the checkpoint.

Local checkpoint path:
checkpoints/46985143-af68-4d48-ab91-a6447052ca49

     Batch | Checkpoint UUID                      | Validation Metrics
-----------+--------------------------------------+---------------------------------------------
      1000 | 46985143-af68-4d48-ab91-a6447052ca49 | {
           |                                      |     "num_inputs": 0,
           |                                      |     "validation_metrics": {
           |                                      |         "loss": 7.906739711761475,
           |                                      |         "accuracy": 0.9646000266075134,
           |                                      |         "global_step": 1000,
           |                                      |         "average_loss": 0.12492649257183075
           |                                      |     }
           |                                      | }

The --latest and --uuid options are used as follows:

# Download the most recent checkpoint.
det trial download <trial_id> --latest

# Download a specific checkpoint.
det trial download <trial_id> --uuid <uuid-for-checkpoint>
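When these commands are driven from a script, building the argument list in one place helps avoid mixing up the three modes. The trial_download_cmd() helper below is hypothetical and uses only the options shown above; it assumes --output-dir can be combined with any of the three modes, although the examples above only show it with --best.

```python
import shlex
from typing import List, Optional


def trial_download_cmd(
    trial_id: int,
    best: bool = False,
    latest: bool = False,
    uuid: Optional[str] = None,
    output_dir: Optional[str] = None,
) -> List[str]:
    """Build argv for `det trial download` from the documented options."""
    if sum([best, latest, uuid is not None]) != 1:
        raise ValueError("choose exactly one of best, latest, or uuid")
    cmd = ["det", "trial", "download", str(trial_id)]
    if best:
        cmd.append("--best")
    elif latest:
        cmd.append("--latest")
    else:
        cmd += ["--uuid", str(uuid)]
    if output_dir is not None:
        cmd += ["--output-dir", output_dir]
    return cmd


print(shlex.join(trial_download_cmd(7, best=True, output_dir="local_checkpoint")))
# det trial download 7 --best --output-dir local_checkpoint
```

To run the command, pass the list to subprocess.run(cmd, check=True) on a machine where the det CLI is installed and configured.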

Finally, the det experiment download command provides a similar experience to using the ExperimentReference Python API.

# Download the best checkpoint for a given experiment.
det experiment download <experiment_id>

# Download the best 3 checkpoints for a given experiment.
det experiment download <experiment_id> --top-n 3

The command should display output resembling the following upon successfully downloading the checkpoints.

Local checkpoint path:
checkpoints/8d45f621-8652-4268-8445-6ae9a735e453

     Batch | Checkpoint UUID                      | Validation Metrics
-----------+--------------------------------------+------------------------------------------
       400 | 8d45f621-8652-4268-8445-6ae9a735e453 | {
           |                                      |     "num_inputs": 56,
           |                                      |     "validation_metrics": {
           |                                      |         "val_loss": 0.26509127765893936,
           |                                      |         "val_categorical_accuracy": 1
           |                                      |     }
           |                                      | }

Local checkpoint path:
checkpoints/62131ba1-983c-49a8-98ef-36207611d71f

     Batch | Checkpoint UUID                      | Validation Metrics
-----------+--------------------------------------+------------------------------------------
      1600 | 62131ba1-983c-49a8-98ef-36207611d71f | {
           |                                      |     "num_inputs": 50,
           |                                      |     "validation_metrics": {
           |                                      |         "val_loss": 0.04411194706335664,
           |                                      |         "val_categorical_accuracy": 1
           |                                      |     }
           |                                      | }

Local checkpoint path:
checkpoints/a36d2a61-a384-44f7-a84b-8b30b09cb618

     Batch | Checkpoint UUID                      | Validation Metrics
-----------+--------------------------------------+------------------------------------------
       400 | a36d2a61-a384-44f7-a84b-8b30b09cb618 | {
           |                                      |     "num_inputs": 46,
           |                                      |     "validation_metrics": {
           |                                      |         "val_loss": 0.07265569269657135,
           |                                      |         "val_categorical_accuracy": 1
           |                                      |     }
           |                                      | }
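When scripting around det experiment download --top-n, it can be handy to recover the local directories from the printed output. The local_checkpoint_paths() helper below is hypothetical and assumes the output keeps the layout shown above, with each path on the line after "Local checkpoint path:"; human-readable CLI output is not guaranteed to stay stable, so treat this as a convenience rather than a contract.

```python
import re
from typing import List

# Assumes each checkpoint is announced by a "Local checkpoint path:" line
# followed by the path on the next line, as in the sample output above.
_PATH_RE = re.compile(r"^Local checkpoint path:[ \t]*\n(\S+)", re.MULTILINE)


def local_checkpoint_paths(cli_output: str) -> List[str]:
    """Extract every downloaded checkpoint directory from the CLI output."""
    return _PATH_RE.findall(cli_output)
```

For example, capture the output with subprocess.run(cmd, capture_output=True, text=True, check=True) and pass result.stdout to local_checkpoint_paths().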

Loading From a Local Path

The Checkpoint class provides a static method, load_from_path(), that loads a checkpoint from a path on the local file system.

Suppose a checkpoint is downloaded using a command like this:

det trial download <trial_id> --best --output-dir local_checkpoint

The checkpoint can then be loaded in Python with this code:

from determined.experimental import Checkpoint

model = Checkpoint.load_from_path("local_checkpoint")
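Because load_from_path() takes a plain directory path, a small guard can give a clearer error when the download step was skipped. The load_local() wrapper below is a hypothetical helper: only Checkpoint.load_from_path comes from the text above, and the sketch assumes it accepts the path as a string, as in the example. The loader is injectable so the check can be exercised without Determined installed.

```python
from pathlib import Path
from typing import Any, Callable, Optional


def load_local(path: str, loader: Optional[Callable[[str], Any]] = None) -> Any:
    """Check that `path` is a directory, then hand it to a checkpoint loader.

    By default the loader is Checkpoint.load_from_path, imported lazily so the
    directory check itself works without Determined installed.
    """
    directory = Path(path)
    if not directory.is_dir():
        raise FileNotFoundError(f"{directory} is not a directory; download the checkpoint first")
    if loader is None:
        from determined.experimental import Checkpoint

        loader = Checkpoint.load_from_path
    return loader(str(directory))
```

With Determined installed, model = load_local("local_checkpoint") behaves like the snippet above but fails early with a descriptive message if the directory is missing.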