Report Checkpoints
By checkpointing periodically during training and reporting those checkpoints to the master, you can stop and restart training in two different ways: either by pausing and reactivating training using the WebUI, or by clicking the Continue Trial button after the experiment completes.
These two types of continuations have different behaviors. While you always want to preserve the value you are incrementing (the "model weight"), you do not always want to preserve the batch index. When you pause and reactivate, you want training to continue from the same batch index, but when starting a fresh experiment, you want training to start with a fresh batch index. You can save the trial ID in the checkpoint and use it to distinguish between the two types of continuation.
Create a new `2_checkpoints.py` training script by copying the `1_metrics.py` script from Report Metrics.

Write save and load methods for your model:
```python
def save_state(x, steps_completed, trial_id, checkpoint_directory):
    with checkpoint_directory.joinpath("state").open("w") as f:
        f.write(f"{x},{steps_completed},{trial_id}")
```
```python
def load_state(trial_id, checkpoint_directory):
    checkpoint_directory = pathlib.Path(checkpoint_directory)
    with checkpoint_directory.joinpath("state").open("r") as f:
        x, steps_completed, ckpt_trial_id = [int(field) for field in f.read().split(",")]
    if ckpt_trial_id == trial_id:
        return x, steps_completed
    else:
        # This is a new trial; load the "model weight" but not the batch count.
        return x, 0
```
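If you want to sanity-check these helpers outside of a training job, a quick round trip in a temporary directory is enough. This sketch assumes the two functions above are already defined; it is not part of the tutorial script:

```python
import pathlib
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = pathlib.Path(tmp)
    save_state(x=5, steps_completed=30, trial_id=1, checkpoint_directory=ckpt_dir)
    # Same trial ID (pause-and-reactivate): weight AND batch index are restored.
    assert load_state(trial_id=1, checkpoint_directory=ckpt_dir) == (5, 30)
    # Different trial ID (Continue Trial): weight is kept, batch index resets.
    assert load_state(trial_id=2, checkpoint_directory=ckpt_dir) == (5, 0)
```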
In your `if __name__ == "__main__"` block, use the ClusterInfo API to gather additional information about the task running on the cluster, specifically a checkpoint to load from and the trial ID, which you also pass to `main()`:

```python
info = det.get_cluster_info()
assert info is not None, "this example only runs on-cluster"
latest_checkpoint = info.latest_checkpoint
trial_id = info.trial.trial_id

with det.core.init() as core_context:
    main(
        core_context=core_context,
        latest_checkpoint=latest_checkpoint,
        trial_id=trial_id,
        increment_by=1,
    )
```
It is recommended that you always follow this pattern: extract values from the ClusterInfo API at the top level and pass them down to lower layers of your code, instead of accessing the ClusterInfo API directly in the lower layers. This way, the lower layers can be written to run either on or off the Determined cluster.
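For instance, the entrypoint could branch on whether cluster information is available, falling back to placeholder values for local debugging. This is a sketch of the pattern only; the fallback values are illustrative, and whether the rest of the script runs cleanly off-cluster depends on your setup and Determined version:

```python
info = det.get_cluster_info()
if info is not None:
    # On-cluster: resume information comes from the master.
    latest_checkpoint = info.latest_checkpoint
    trial_id = info.trial.trial_id
else:
    # Off-cluster (e.g. local debugging): start fresh.
    # A trial_id of -1 is an arbitrary placeholder, not a real trial ID.
    latest_checkpoint = None
    trial_id = -1

with det.core.init() as core_context:
    main(
        core_context=core_context,
        latest_checkpoint=latest_checkpoint,
        trial_id=trial_id,
        increment_by=1,
    )
```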
Within `main()`, add logic to continue from a checkpoint when a checkpoint is provided:

```python
def main(core_context, latest_checkpoint, trial_id, increment_by):
    x = 0

    # NEW: load a checkpoint if one was provided.
    starting_batch = 0
    if latest_checkpoint is not None:
        with core_context.checkpoint.restore_path(latest_checkpoint) as path:
            x, starting_batch = load_state(trial_id, path)

    for batch in range(starting_batch, 100):
```
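Note that `core_context.checkpoint.restore_path()` is a context manager: depending on the checkpoint storage backend, it may download the checkpoint files to a temporary local directory and clean them up when the `with` block exits, so any reads of the checkpoint should happen inside the block.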
You can checkpoint your model as frequently as you like. For this exercise, save a checkpoint after each training report, and check for a preemption signal after each checkpoint:
```python
        if steps_completed % 10 == 0:
            core_context.train.report_training_metrics(
                steps_completed=steps_completed, metrics={"x": x}
            )
            # NEW: write checkpoints at regular intervals to limit lost progress
            # in case of a crash during training.
            checkpoint_metadata = {"steps_completed": steps_completed}
            with core_context.checkpoint.store_path(checkpoint_metadata) as (path, uuid):
                save_state(x, steps_completed, trial_id, path)
            # NEW: check for a preemption signal. This could originate from a
            # higher-priority task bumping us off the cluster, or from a user pausing
            # the experiment via the WebUI or CLI.
            if core_context.preempt.should_preempt():
                # At this point, a checkpoint was just saved, so training can exit
                # immediately and resume when the trial is reactivated.
                return
```
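Assembled, `main()` now looks roughly like this. This is a sketch pieced together from the snippets above; the `x += increment_by` and `steps_completed = batch + 1` lines are assumed to carry over unchanged from the `1_metrics.py` script:

```python
def main(core_context, latest_checkpoint, trial_id, increment_by):
    x = 0

    # Load a checkpoint if one was provided.
    starting_batch = 0
    if latest_checkpoint is not None:
        with core_context.checkpoint.restore_path(latest_checkpoint) as path:
            x, starting_batch = load_state(trial_id, path)

    for batch in range(starting_batch, 100):
        x += increment_by
        steps_completed = batch + 1
        if steps_completed % 10 == 0:
            core_context.train.report_training_metrics(
                steps_completed=steps_completed, metrics={"x": x}
            )
            checkpoint_metadata = {"steps_completed": steps_completed}
            with core_context.checkpoint.store_path(checkpoint_metadata) as (path, uuid):
                save_state(x, steps_completed, trial_id, path)
            if core_context.preempt.should_preempt():
                return
```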
Create a `2_checkpoints.yaml` file by copying the `0_start.yaml` file and changing the first couple of lines:

```yaml
name: core-api-stage-2
entrypoint: python3 2_checkpoints.py
```
Run the code using the command:
```bash
det e create 2_checkpoints.yaml . -f
```
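The `-f` (`--follow`) flag streams the logs of the first trial to your terminal, which makes it easy to watch the checkpoint and preemption behavior as it happens.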
You can navigate to the experiment in the WebUI and pause it mid-training. The trial shuts down and stops producing logs. If you reactivate training, it resumes where it stopped. After training has completed, click Continue Trial to see that fresh training is started, but that the model weight continues from where the previous training finished.
The complete `2_checkpoints.py` and `2_checkpoints.yaml` listings used in this example can be found in the `core_api.tgz` download or in the GitHub repository.