Report Checkpoints

By checkpointing periodically during training and reporting those checkpoints to the master, you can stop and restart training in two different ways: either by pausing and reactivating training using the WebUI, or by clicking the Continue Trial button after the experiment completes.

These two ways of continuing behave differently. You always want to preserve the value you are incrementing (the “model weight”), but you do not always want to preserve the batch index. When you pause and reactivate, training should continue from the same batch index; when you start a fresh experiment with Continue Trial, training should start from a fresh batch index. To distinguish the two cases, save the trial ID in the checkpoint and compare it with the current trial ID when loading.
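
The rule described above can be sketched in isolation, without any Determined APIs. This is only an illustration: the `SavedState` type and the `resume_point` function are hypothetical names, not part of the tutorial scripts or of the Determined library.

```python
from typing import NamedTuple, Tuple


class SavedState(NamedTuple):
    """Hypothetical record of what a checkpoint stores in this tutorial."""
    x: int                # the "model weight"
    steps_completed: int  # batch index at checkpoint time
    trial_id: int         # ID of the trial that wrote the checkpoint


def resume_point(saved: SavedState, current_trial_id: int) -> Tuple[int, int]:
    """Return (x, starting_batch) for the trial that loads the checkpoint."""
    if saved.trial_id == current_trial_id:
        # Pause and reactivate: same trial, so keep the batch index.
        return saved.x, saved.steps_completed
    # Continue Trial: a different trial, so keep the weight but reset the batch index.
    return saved.x, 0


saved = SavedState(x=40, steps_completed=40, trial_id=7)
same_trial = resume_point(saved, current_trial_id=7)
new_trial = resume_point(saved, current_trial_id=8)
print(same_trial, new_trial)
```

With the same trial ID, both the weight and the batch index are restored; with a different trial ID, only the weight carries over.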

  1. Create a new training script called 2_checkpoints.py by copying the 1_metrics.py script from Report Metrics.

  2. Write save and load methods for your model:

    # load_state() below uses pathlib; place this import at the top of the script.
    import pathlib

    def save_state(x, steps_completed, trial_id, checkpoint_directory):
        with checkpoint_directory.joinpath("state").open("w") as f:
            f.write(f"{x},{steps_completed},{trial_id}")
    
    def load_state(trial_id, checkpoint_directory):
        checkpoint_directory = pathlib.Path(checkpoint_directory)
        with checkpoint_directory.joinpath("state").open("r") as f:
            x, steps_completed, ckpt_trial_id = [int(field) for field in f.read().split(",")]
        if ckpt_trial_id == trial_id:
            return x, steps_completed
        else:
            # This is a new trial; load the "model weight" but not the batch count.
            return x, 0
    
  3. In your if __name__ == "__main__" block, use the ClusterInfo API to gather additional information about the task running on the cluster, specifically a checkpoint to load from and the trial ID, which you also pass to main().

        info = det.get_cluster_info()
        assert info is not None, "this example only runs on-cluster"
        latest_checkpoint = info.latest_checkpoint
        trial_id = info.trial.trial_id
    
        with det.core.init() as core_context:
            main(
                core_context=core_context,
                latest_checkpoint=latest_checkpoint,
                trial_id=trial_id,
                increment_by=1
            )
    

    Always follow this pattern: extract values from the ClusterInfo API at the top level and pass them to the lower layers of your code, instead of accessing the ClusterInfo API directly in the lower layers. This way, the lower layers can be written to run either on or off a Determined cluster.

  4. Within main(), add logic to continue from a checkpoint, when a checkpoint is provided:

    def main(core_context, latest_checkpoint, trial_id, increment_by):
        x = 0
    
        # NEW: load a checkpoint if one was provided.
        starting_batch = 0
        if latest_checkpoint is not None:
            with core_context.checkpoint.restore_path(latest_checkpoint) as path:
                x, starting_batch = load_state(trial_id, path)
    
        for batch in range(starting_batch, 100):
    
  5. You can checkpoint your model as frequently as you like. For this exercise, save a checkpoint after each training report, and check for a preemption signal after each checkpoint:

    if steps_completed % 10 == 0:
        core_context.train.report_training_metrics(
            steps_completed=steps_completed, metrics={"x": x}
        )
    
        # NEW: write checkpoints at regular intervals to limit lost progress
        # in case of a crash during training.
        checkpoint_metadata = {"steps_completed": steps_completed}
        with core_context.checkpoint.store_path(checkpoint_metadata) as (path, uuid):
            save_state(x, steps_completed, trial_id, path)
    
        # NEW: check for a preemption signal.  This could originate from a
        # higher-priority task bumping us off the cluster, or from a user
        # pausing the experiment via the WebUI or CLI.
        if core_context.preempt.should_preempt():
            # At this point, a checkpoint was just saved, so training can exit
            # immediately and resume when the trial is reactivated.
            return
    
  6. Create a 2_checkpoints.yaml file by copying the 0_start.yaml file and changing the first couple of lines:

    name: core-api-stage-2
    entrypoint: python3 2_checkpoints.py
    
  7. Run the code using the command:

    det e create 2_checkpoints.yaml . -f
    
  8. Navigate to the experiment in the WebUI and pause it mid-training. The trial shuts down and stops producing logs. If you reactivate training, it resumes from where it stopped. After training completes, click Continue Trial: a fresh training run starts from batch index zero, but the model weight continues from where the previous run finished.
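
To see the expected pause, reactivate, and Continue Trial behavior without a cluster, the steps above can be approximated with a small off-cluster simulation. This is a sketch under stated assumptions, not the tutorial's 2_checkpoints.py: `train()` is a hypothetical stand-in for main() that uses a plain directory instead of `core_context.checkpoint`, a `pause_at` argument instead of `core_context.preempt.should_preempt()`, and assumes the loop body increments `x` by `increment_by` once per batch.

```python
import pathlib
import tempfile


def save_state(x, steps_completed, trial_id, checkpoint_directory):
    checkpoint_directory = pathlib.Path(checkpoint_directory)
    with checkpoint_directory.joinpath("state").open("w") as f:
        f.write(f"{x},{steps_completed},{trial_id}")


def load_state(trial_id, checkpoint_directory):
    checkpoint_directory = pathlib.Path(checkpoint_directory)
    with checkpoint_directory.joinpath("state").open("r") as f:
        x, steps_completed, ckpt_trial_id = [int(field) for field in f.read().split(",")]
    if ckpt_trial_id == trial_id:
        return x, steps_completed
    # Different trial: keep the "model weight" but not the batch count.
    return x, 0


def train(ckpt_dir, trial_id, increment_by=1, total_batches=100, pause_at=None):
    """Hypothetical stand-in for main(); returns (x, steps_completed)."""
    x, starting_batch = 0, 0
    if ckpt_dir.joinpath("state").exists():
        x, starting_batch = load_state(trial_id, ckpt_dir)

    for batch in range(starting_batch, total_batches):
        x += increment_by  # assumed training step
        steps_completed = batch + 1
        if steps_completed % 10 == 0:
            save_state(x, steps_completed, trial_id, ckpt_dir)
            # Simulated preemption check, made right after a checkpoint is saved.
            if pause_at is not None and steps_completed >= pause_at:
                return x, steps_completed
    return x, total_batches


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = pathlib.Path(tmp)
    paused = train(ckpt_dir, trial_id=1, pause_at=30)  # pause mid-training
    resumed = train(ckpt_dir, trial_id=1)              # reactivate the same trial
    continued = train(ckpt_dir, trial_id=2)            # Continue Trial: new trial ID

print(paused, resumed, continued)
```

The reactivated run picks up at batch 30 and finishes at 100, while the run with a new trial ID starts again at batch 0 but begins from the previously saved weight.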

The complete 2_checkpoints.py and 2_checkpoints.yaml listings used in this example can be found in the core_api.tgz download or in the GitHub repository.