Using Trained Checkpoints
Determined provides APIs for downloading trained checkpoints and loading them into memory in a Python process.
This guide discusses:
Querying trained model checkpoints from trials and experiments.
Loading models into memory in a Python process.
Using the Determined CLI to download checkpoints to disk.
Querying Checkpoints
The ExperimentReference class is a reference to an experiment and is retrieved through the Determined class. It provides the top_checkpoint() method. When called without arguments, top_checkpoint() reads the metric and smaller_is_better values from the searcher field of the experiment configuration and uses them to sort the experiment’s checkpoints by validation performance. For example, the following searcher settings from an experiment configuration file sort checkpoints by the loss metric in ascending order:
searcher:
  metric: "loss"
  smaller_is_better: true
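To make the ordering concrete, the following pure-Python sketch mimics how searcher settings like the ones above rank checkpoints. It does not call Determined: the checkpoint records and the rank_checkpoints helper are hypothetical and exist only to illustrate ascending versus descending sorting.

```python
from typing import Dict, List

# Hypothetical stand-in for the searcher section of an experiment configuration.
searcher = {"metric": "loss", "smaller_is_better": True}

# Hypothetical checkpoint records; real Checkpoint objects carry more fields.
checkpoints: List[Dict] = [
    {"uuid": "a", "validation_metrics": {"loss": 0.42}},
    {"uuid": "b", "validation_metrics": {"loss": 0.17}},
    {"uuid": "c", "validation_metrics": {"loss": 0.29}},
]


def rank_checkpoints(records: List[Dict], metric: str, smaller_is_better: bool) -> List[Dict]:
    """Return records ordered best-first by the given validation metric."""
    # Ascending order when smaller values are better, descending otherwise.
    return sorted(
        records,
        key=lambda r: r["validation_metrics"][metric],
        reverse=not smaller_is_better,
    )


ranked = rank_checkpoints(checkpoints, searcher["metric"], searcher["smaller_is_better"])
print([r["uuid"] for r in ranked])  # ['b', 'c', 'a']
```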
The following snippet can be run with a Python interpreter after the specified experiment has generated a checkpoint. It returns an instance of Checkpoint representing the experiment’s checkpoint with the best validation metric.
from determined.experimental import Determined
experiment_id = 1  # placeholder: replace with your experiment's ID
checkpoint = Determined().get_experiment(experiment_id).top_checkpoint()
Checkpoints can be sorted by any metric using two keyword arguments: sort_by, which specifies the metric to use, and smaller_is_better, which specifies whether to sort the checkpoints in ascending or descending order with respect to that metric.
from determined.experimental import Determined
experiment_id = 1  # placeholder: replace with your experiment's ID
checkpoint = Determined().get_experiment(experiment_id).top_checkpoint(sort_by="accuracy", smaller_is_better=False)
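The interplay between explicit arguments and the searcher defaults can be sketched in plain Python. This is an illustrative assumption, not Determined's code: pick_top_checkpoint and its records are hypothetical, and the sketch assumes that an argument left as None falls back to the corresponding searcher setting.

```python
from typing import Dict, List, Optional

# Hypothetical searcher defaults, mirroring the configuration snippet above.
SEARCHER = {"metric": "loss", "smaller_is_better": True}

# Hypothetical checkpoint records with two validation metrics each.
CHECKPOINTS: List[Dict] = [
    {"uuid": "a", "validation_metrics": {"loss": 0.30, "accuracy": 0.91}},
    {"uuid": "b", "validation_metrics": {"loss": 0.25, "accuracy": 0.89}},
    {"uuid": "c", "validation_metrics": {"loss": 0.35, "accuracy": 0.94}},
]


def pick_top_checkpoint(
    records: List[Dict],
    sort_by: Optional[str] = None,
    smaller_is_better: Optional[bool] = None,
) -> Dict:
    """Return the best record, falling back to searcher settings for unset arguments."""
    metric = sort_by if sort_by is not None else SEARCHER["metric"]
    ascending = smaller_is_better if smaller_is_better is not None else SEARCHER["smaller_is_better"]

    def metric_value(record: Dict) -> float:
        return record["validation_metrics"][metric]

    # min() picks the smallest metric value; max() picks the largest.
    return min(records, key=metric_value) if ascending else max(records, key=metric_value)


print(pick_top_checkpoint(CHECKPOINTS)["uuid"])  # b (lowest loss)
print(pick_top_checkpoint(CHECKPOINTS, sort_by="accuracy", smaller_is_better=False)["uuid"])  # c
```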
You can also query multiple checkpoints at once using the top_n_checkpoints() method. Only the single best checkpoint from each trial is considered; of those, the checkpoints with the best validation metric values are returned in sorted order, best first. For example, the following snippet returns the top five checkpoints from distinct trials of a specified experiment.
from determined.experimental import Determined
experiment_id = 1  # placeholder: replace with your experiment's ID
checkpoints = Determined().get_experiment(experiment_id).top_n_checkpoints(5)
This method also accepts the sort_by and smaller_is_better arguments.
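The "one checkpoint per trial" rule is easy to misread, so here is a plain-Python sketch of it. The records and the top_n_per_trial helper are hypothetical; note that t2-a (loss 0.20) beats t3-a (loss 0.30) but is excluded because its trial already contributes a better checkpoint.

```python
from typing import Dict, List

# Hypothetical checkpoint records from three trials of one experiment.
CHECKPOINTS: List[Dict] = [
    {"trial_id": 1, "uuid": "t1-a", "validation_metrics": {"loss": 0.40}},
    {"trial_id": 1, "uuid": "t1-b", "validation_metrics": {"loss": 0.10}},
    {"trial_id": 2, "uuid": "t2-a", "validation_metrics": {"loss": 0.20}},
    {"trial_id": 2, "uuid": "t2-b", "validation_metrics": {"loss": 0.15}},
    {"trial_id": 3, "uuid": "t3-a", "validation_metrics": {"loss": 0.30}},
]


def top_n_per_trial(
    records: List[Dict], n: int, metric: str = "loss", smaller_is_better: bool = True
) -> List[Dict]:
    """Keep each trial's best record, then return the n best of those, best first."""
    sign = 1 if smaller_is_better else -1

    def score(record: Dict) -> float:
        # Lower score is better regardless of the metric's direction.
        return sign * record["validation_metrics"][metric]

    best_per_trial: Dict[int, Dict] = {}
    for record in records:
        current = best_per_trial.get(record["trial_id"])
        if current is None or score(record) < score(current):
            best_per_trial[record["trial_id"]] = record
    return sorted(best_per_trial.values(), key=score)[:n]


print([r["uuid"] for r in top_n_per_trial(CHECKPOINTS, 2)])  # ['t1-b', 't2-b']
```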
TrialReference is used for fine-grained control over checkpoint selection within a trial. It provides a top_checkpoint() method, which mirrors top_checkpoint() for an experiment. It also provides select_checkpoint(), which offers three ways to query checkpoints:
best: Returns the best checkpoint based on validation metrics, as discussed above. When using best, the smaller_is_better and sort_by arguments are also accepted.
latest: Returns the most recent checkpoint for the trial.
uuid: Returns the checkpoint with the specified UUID.
The following snippet showcases how to use the different modes for selecting checkpoints.
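from determined.experimental import Determined
trial_id = 1  # placeholder: replace with your trial's ID
trial = Determined().get_trial(trial_id)
# Best checkpoint according to the experiment's searcher settings.
best_checkpoint = trial.top_checkpoint()
# Best checkpoint according to a different metric.
most_accurate_checkpoint = trial.select_checkpoint(
    best=True,
    sort_by="accuracy",
    smaller_is_better=False,
)
# Most recent checkpoint for the trial.
most_recent_checkpoint = trial.select_checkpoint(latest=True)
# Checkpoint with a specific UUID.
specific_checkpoint = trial.select_checkpoint(uuid="uuid-for-checkpoint")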
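The three selection modes can be modeled with a small dispatcher. This sketch is hypothetical: select and its records are illustrative, "most recent" is approximated by the highest batch number, and requiring exactly one mode per call is an assumption about how the real method behaves.

```python
from typing import Dict, List, Optional

# Hypothetical checkpoint records for a single trial.
TRIAL_CHECKPOINTS: List[Dict] = [
    {"uuid": "u-100", "batch": 100, "validation_metrics": {"loss": 0.50, "accuracy": 0.80}},
    {"uuid": "u-200", "batch": 200, "validation_metrics": {"loss": 0.20, "accuracy": 0.88}},
    {"uuid": "u-300", "batch": 300, "validation_metrics": {"loss": 0.30, "accuracy": 0.90}},
]


def select(
    records: List[Dict],
    best: bool = False,
    latest: bool = False,
    uuid: Optional[str] = None,
    sort_by: str = "loss",
    smaller_is_better: bool = True,
) -> Dict:
    """Return one record using exactly one of the best, latest, or uuid modes."""
    if sum([best, latest, uuid is not None]) != 1:
        raise ValueError("specify exactly one of best, latest, or uuid")
    if latest:
        # Approximate "most recent" by the highest batch number.
        return max(records, key=lambda r: r["batch"])
    if uuid is not None:
        for record in records:
            if record["uuid"] == uuid:
                return record
        raise KeyError(uuid)
    # best mode: lower signed score is better.
    sign = 1 if smaller_is_better else -1
    return min(records, key=lambda r: sign * r["validation_metrics"][sort_by])


print(select(TRIAL_CHECKPOINTS, best=True)["uuid"])  # u-200
print(select(TRIAL_CHECKPOINTS, latest=True)["uuid"])  # u-300
```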
Using the Checkpoint class
The Checkpoint class can both download a checkpoint from persistent storage and load it into memory in a Python process.
The download() method downloads a checkpoint from persistent storage to the local file system. By default, the checkpoint is downloaded to checkpoints/<checkpoint-uuid>/. The optional path parameter changes the download location.
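from determined.experimental import Determined
experiment_id = 1  # placeholder: replace with your experiment's ID
checkpoint = Determined().get_experiment(experiment_id).top_checkpoint()
# Download to the default location, checkpoints/<checkpoint-uuid>/.
checkpoint_path = checkpoint.download()
# Download to a custom location instead.
specific_path = checkpoint.download(path="specific-checkpoint-path")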
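The default location rule can be expressed with pathlib when a script needs to predict where files will land. resolve_download_dir is a hypothetical helper that performs no download; it assumes the default path is relative to the current working directory.

```python
from pathlib import Path
from typing import Optional, Union


def resolve_download_dir(checkpoint_uuid: str, path: Optional[Union[str, Path]] = None) -> Path:
    """Mirror the documented default of checkpoints/<checkpoint-uuid>/ unless path is given."""
    if path is not None:
        return Path(path)
    return Path("checkpoints") / checkpoint_uuid


print(resolve_download_dir("46985143-af68-4d48-ab91-a6447052ca49"))
print(resolve_download_dir("46985143-af68-4d48-ab91-a6447052ca49", path="specific-checkpoint-path"))
```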
The load() method downloads the checkpoint if it does not already exist locally, and then loads it into memory in a Python process, as shown in the following snippet.
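from determined.experimental import Determined
experiment_id = 1  # placeholder: replace with your experiment's ID
checkpoint = Determined().get_experiment(experiment_id).top_checkpoint()
# Downloads the checkpoint if needed, then loads it into memory.
model = checkpoint.load()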
TensorFlow checkpoints are saved in the saved_model format and are loaded as trackable objects (see the documentation for tf.compat.v1.saved_model.load_v2 for details).
PyTorch checkpoints are saved using pickle and loaded as nn.Module objects (see the PyTorch documentation for details).
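The "download only if it does not exist locally" step of load() follows a common caching pattern. The sketch below is an assumption about that control flow, not Determined's implementation: ensure_local and fake_fetch are hypothetical, fake_fetch stands in for the real download, and a non-empty local directory is treated as already downloaded.

```python
import tempfile
from pathlib import Path
from typing import Callable, List


def ensure_local(checkpoint_dir: Path, fetch: Callable[[Path], None]) -> Path:
    """Call fetch only when checkpoint_dir is missing or empty, then return it."""
    if not checkpoint_dir.is_dir() or not any(checkpoint_dir.iterdir()):
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        fetch(checkpoint_dir)
    return checkpoint_dir


calls: List[Path] = []


def fake_fetch(target: Path) -> None:
    # Hypothetical download: record the call and write a placeholder file.
    calls.append(target)
    (target / "model.bin").write_bytes(b"weights")


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "checkpoints" / "example-uuid"
    ensure_local(target, fake_fetch)
    ensure_local(target, fake_fetch)  # second call finds files and skips fetch
    print(len(calls))  # 1
```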
Download Checkpoints via the CLI
Determined offers the following CLI commands for downloading checkpoints locally:
det checkpoint download
det trial download
det experiment download
The det checkpoint download command downloads the checkpoint with the given UUID, as shown below:
# Download a specific checkpoint.
det checkpoint download 46985143-af68-4d48-ab91-a6447052ca49
The command should display output resembling the following upon successfully downloading the checkpoint.
Local checkpoint path:
checkpoints/46985143-af68-4d48-ab91-a6447052ca49
   Batch   | Checkpoint UUID                      | Validation Metrics
-----------+--------------------------------------+---------------------------------------------
    1000   | 46985143-af68-4d48-ab91-a6447052ca49 | {
           |                                      |   "num_inputs": 0,
           |                                      |   "validation_metrics": {
           |                                      |     "loss": 7.906739711761475,
           |                                      |     "accuracy": 0.9646000266075134,
           |                                      |     "global_step": 1000,
           |                                      |     "average_loss": 0.12492649257183075
           |                                      |   }
           |                                      | }
The det trial download command downloads checkpoints for a specified trial. Similar to the TrialReference above, the det trial download command accepts the --best, --latest, and --uuid options.
# Download best checkpoint.
det trial download <trial_id> --best
# Download best checkpoint to a particular directory.
det trial download <trial_id> --best --output-dir local_checkpoint
The command should display output resembling the following upon successfully downloading the checkpoint.
Local checkpoint path:
checkpoints/46985143-af68-4d48-ab91-a6447052ca49
   Batch   | Checkpoint UUID                      | Validation Metrics
-----------+--------------------------------------+---------------------------------------------
    1000   | 46985143-af68-4d48-ab91-a6447052ca49 | {
           |                                      |   "num_inputs": 0,
           |                                      |   "validation_metrics": {
           |                                      |     "loss": 7.906739711761475,
           |                                      |     "accuracy": 0.9646000266075134,
           |                                      |     "global_step": 1000,
           |                                      |     "average_loss": 0.12492649257183075
           |                                      |   }
           |                                      | }
The --latest and --uuid options are used as follows:
# Download the most recent checkpoint.
det trial download <trial_id> --latest
# Download a specific checkpoint.
det trial download <trial_id> --uuid <uuid-for-checkpoint>
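When these commands are driven from a Python script, building the argument list in one place keeps the modes consistent. trial_download_argv is a hypothetical helper that uses only the flags shown above; it assumes --output-dir can be combined with any mode, although the examples here only show it with --best.

```python
import shlex
from typing import List, Optional


def trial_download_argv(
    trial_id: int,
    mode: str,
    uuid: Optional[str] = None,
    output_dir: Optional[str] = None,
) -> List[str]:
    """Build an argument list for det trial download in best, latest, or uuid mode."""
    argv = ["det", "trial", "download", str(trial_id)]
    if mode == "best":
        argv.append("--best")
    elif mode == "latest":
        argv.append("--latest")
    elif mode == "uuid":
        if not uuid:
            raise ValueError("uuid mode requires a checkpoint UUID")
        argv += ["--uuid", uuid]
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    if output_dir is not None:
        # Assumption: --output-dir combines with every mode, not only --best.
        argv += ["--output-dir", output_dir]
    return argv


print(shlex.join(trial_download_argv(7, "best", output_dir="local_checkpoint")))
# det trial download 7 --best --output-dir local_checkpoint
```

On a machine where the det CLI is installed and configured, the resulting list can be passed to subprocess.run(argv, check=True).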
Finally, the det experiment download command provides an experience similar to using the ExperimentReference described above.
# Download the best checkpoint for a given experiment.
det experiment download <experiment_id>
# Download the best 3 checkpoints for a given experiment.
det experiment download <experiment_id> --top-n 3
The command should display output resembling the following upon successfully downloading the checkpoint(s).
Local checkpoint path:
checkpoints/8d45f621-8652-4268-8445-6ae9a735e453
   Batch   | Checkpoint UUID                      | Validation Metrics
-----------+--------------------------------------+------------------------------------------
     400   | 8d45f621-8652-4268-8445-6ae9a735e453 | {
           |                                      |   "num_inputs": 56,
           |                                      |   "validation_metrics": {
           |                                      |     "val_loss": 0.26509127765893936,
           |                                      |     "val_categorical_accuracy": 1
           |                                      |   }
           |                                      | }
Local checkpoint path:
checkpoints/62131ba1-983c-49a8-98ef-36207611d71f
   Batch   | Checkpoint UUID                      | Validation Metrics
-----------+--------------------------------------+------------------------------------------
    1600   | 62131ba1-983c-49a8-98ef-36207611d71f | {
           |                                      |   "num_inputs": 50,
           |                                      |   "validation_metrics": {
           |                                      |     "val_loss": 0.04411194706335664,
           |                                      |     "val_categorical_accuracy": 1
           |                                      |   }
           |                                      | }
Local checkpoint path:
checkpoints/a36d2a61-a384-44f7-a84b-8b30b09cb618
   Batch   | Checkpoint UUID                      | Validation Metrics
-----------+--------------------------------------+------------------------------------------
     400   | a36d2a61-a384-44f7-a84b-8b30b09cb618 | {
           |                                      |   "num_inputs": 46,
           |                                      |   "validation_metrics": {
           |                                      |     "val_loss": 0.07265569269657135,
           |                                      |     "val_categorical_accuracy": 1
           |                                      |   }
           |                                      | }
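When scripting around det experiment download --top-n, the local paths it prints can be recovered from its output. local_checkpoint_paths is a hypothetical parser that assumes each "Local checkpoint path:" label is followed by the path on the next line, as in the sample output above; CLI output is not a stable interface, so treat this as a sketch.

```python
from typing import List

# Abbreviated sample of the CLI output shown above (most table rows omitted).
SAMPLE_OUTPUT = """\
Local checkpoint path:
checkpoints/8d45f621-8652-4268-8445-6ae9a735e453
   Batch   | Checkpoint UUID                      | Validation Metrics
     400   | 8d45f621-8652-4268-8445-6ae9a735e453 | {
Local checkpoint path:
checkpoints/62131ba1-983c-49a8-98ef-36207611d71f
Local checkpoint path:
checkpoints/a36d2a61-a384-44f7-a84b-8b30b09cb618
"""


def local_checkpoint_paths(output: str) -> List[str]:
    """Return the path printed on the line after each 'Local checkpoint path:' label."""
    lines = [line.strip() for line in output.splitlines()]
    paths = []
    # Stop one line early so lines[i + 1] always exists.
    for i, line in enumerate(lines[:-1]):
        if line == "Local checkpoint path:":
            paths.append(lines[i + 1])
    return paths


for path in local_checkpoint_paths(SAMPLE_OUTPUT):
    print(path)
```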
Loading From a Local Path
Checkpoint provides a static method, load_from_path(), for loading a checkpoint from a local path.
Suppose a checkpoint is downloaded using a command like this:
det trial download <trial_id> --best --output-dir local_checkpoint
Then it can be loaded in Python with this code:
from determined.experimental import Checkpoint
model = Checkpoint.load_from_path("local_checkpoint")
Next Steps
determined.experimental: The reference documentation for this API.