Use Distributed Training with Sharded Checkpointing

In this tutorial, we’ll show you how to manage sharded checkpoints using detached mode.

We will guide you through a process that includes setting up PyTorch for distributed training, sharding data between different processes, and saving sharded checkpoints.

For the full script, visit the GitHub repository.


These step-by-step instructions will cover:

  • Initializing communications libraries and distributed context

  • Implementing sharding for batches across processes

  • Reporting training and validation metrics

  • Storing sharded checkpoints

  • Running distributed code with PyTorch and the appropriate cluster topology arguments

By the end of this guide, you’ll:

  • Understand how distributed training functions in detached mode

  • Know how to shard checkpoints effectively

  • Understand how to employ the Core API for managing distributed training sessions



Prerequisites

  • A Determined cluster

  • PyTorch library for distributed training


Step 1: Initialize Communications Library and Distributed Context

Import necessary libraries, initialize the communications library, and set up the distributed context:

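import logging
import random
import torch.distributed as dist
import determined
import determined.core
from determined.experimental import core_v2

def main():
    # Initialize the communications library ("gloo" is an assumed backend; "nccl" is typical on GPUs).
    dist.init_process_group("gloo")
    # Build the distributed context from the torch.distributed process group.
    distributed = core_v2.DistributedContext.from_torch_distributed()
    # Start the Core API context (the experiment name below is a placeholder).
    core_v2.init(defaults=core_v2.DefaultConfig(name="sharded-checkpoint-example"), distributed=distributed)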

Step 2: Shard Batches Across Processes

Shard the batches between processes and report training metrics:

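size = dist.get_world_size()
for i in range(100):
    # Each rank handles only the batches whose index maps to it.
    if i % size == dist.get_rank():
        core_v2.train.report_training_metrics(
            steps_completed=i,
            metrics={"loss": random.random(), "rank": dist.get_rank()},
        )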
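The modulo check in this step assigns batches to processes round-robin. As a minimal, framework-free sketch of the same rule, the hypothetical helper `shard_indices` below (not part of PyTorch or Determined) lists the batch indices a given rank would process:

```python
from typing import List


def shard_indices(num_batches: int, world_size: int, rank: int) -> List[int]:
    """Return the batch indices that `rank` handles under round-robin sharding.

    Hypothetical helper for illustration; it mirrors `i % size == rank`.
    """
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    return [i for i in range(num_batches) if i % world_size == rank]


if __name__ == "__main__":
    # Show the split of 10 batches between 2 simulated ranks.
    for rank in range(2):
        print(rank, shard_indices(10, world_size=2, rank=rank))
```

Every index lands on exactly one rank, so no batch is processed twice and none is skipped.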

Step 3: Report Validation Metrics Periodically

Report validation metrics periodically, adding rank as a metric in addition to loss:

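if (i + 1) % 10 == 0:
    # Report validation metrics every 10th step.
    core_v2.train.report_validation_metrics(
        steps_completed=i, metrics={"loss": random.random(), "rank": dist.get_rank()})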
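To see which steps trigger validation, and which process would report them, the schedule can be worked out without any framework. The sketch below assumes the validation check sits inside the per-rank branch of Step 2, so that step `i` is owned by rank `i % world_size`. Both function names are hypothetical.

```python
from typing import Dict, List


def validation_steps(num_steps: int, period: int) -> List[int]:
    """Steps `i` for which `(i + 1) % period == 0`."""
    return [i for i in range(num_steps) if (i + 1) % period == 0]


def validation_steps_by_rank(
    num_steps: int, period: int, world_size: int
) -> Dict[int, List[int]]:
    """Group validation steps by the rank that owns step `i` (`i % world_size`)."""
    owners: Dict[int, List[int]] = {rank: [] for rank in range(world_size)}
    for i in validation_steps(num_steps, period):
        owners[i % world_size].append(i)
    return owners


if __name__ == "__main__":
    print(validation_steps_by_rank(100, period=10, world_size=2))
```

With two processes and a period of 10, every validation step (9, 19, 29, and so on) is odd, so under this assumption only rank 1 reports validation metrics. The `rank` metric makes that visible in the reported data.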

Step 4: Store Sharded Checkpoints

Save the sharded checkpoints:

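ckpt_metadata = {"steps_completed": i, f"rank_{dist.get_rank()}": "ok"}
# shard=True lets every rank contribute files to the same checkpoint.
with core_v2.checkpoint.store_path(ckpt_metadata, shard=True) as (path, uuid):
    with (path / f"state_{dist.get_rank()}").open("w") as fout:
        # Illustrative payload: the step index and this process's rank.
        fout.write(f"{i},{dist.get_rank()}")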
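The file layout of a sharded checkpoint can be illustrated without a cluster. The sketch below simulates several ranks writing `state_<rank>` files into one directory and combining their metadata. It is not Determined's implementation: the helpers `write_shard`, `merge_metadata`, and `simulate` are hypothetical, and the dictionary merge is an assumption about how per-rank metadata could be combined.

```python
import tempfile
from pathlib import Path
from typing import Dict, List, Tuple


def write_shard(ckpt_dir: Path, rank: int, step: int) -> Dict[str, object]:
    """Write this rank's state file and return its metadata contribution."""
    (ckpt_dir / f"state_{rank}").write_text(f"{step},{rank}")
    return {"steps_completed": step, f"rank_{rank}": "ok"}


def merge_metadata(parts: List[Dict[str, object]]) -> Dict[str, object]:
    """Combine per-rank metadata; rank-specific keys never collide."""
    merged: Dict[str, object] = {}
    for part in parts:
        merged.update(part)
    return merged


def simulate(world_size: int, step: int) -> Tuple[List[str], Dict[str, object]]:
    """Run every simulated rank in turn against one temporary directory."""
    with tempfile.TemporaryDirectory() as tmp:
        ckpt_dir = Path(tmp)
        parts = [write_shard(ckpt_dir, rank, step) for rank in range(world_size)]
        files = sorted(p.name for p in ckpt_dir.iterdir())
    return files, merge_metadata(parts)


if __name__ == "__main__":
    print(simulate(world_size=2, step=99))
```

Because each rank uses a distinct file name and a distinct metadata key, the shards can share one checkpoint without overwriting each other.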

Step 5: Retrieve Web Server Address and Close Context

Get the address of the web server where our metrics will be sent, and close the core context:

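if dist.get_rank() == 0:
    print("See the experiment at:", core_v2.url_reverse_webui_exp_view())
# Close the core context on every rank.
core_v2.close()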

Step 6: Run Code with PyTorch

Run the code with PyTorch and the appropriate arguments for cluster topology (number of nodes, processes per node, chief worker’s address, port, etc.):

python3 -m torch.distributed.run --nnodes=1 --nproc_per_node=2 \
  --master_addr <MASTER_ADDR> --master_port 29400 --max_restarts 0 \
  <YOUR_SCRIPT.py>

Navigate to <DET_MASTER_IP:PORT> in your web browser to see the experiment.
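The PyTorch launcher passes the cluster topology to each process through environment variables (`RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `MASTER_ADDR`, `MASTER_PORT`). As a rough sketch of how a script could inspect them for debugging, the hypothetical `read_topology` helper below parses those variables. The fallback values for a single-process run are assumptions, not launcher defaults.

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class Topology:
    rank: int
    local_rank: int
    world_size: int
    master_addr: str
    master_port: int


def read_topology(env: Mapping[str, str] = os.environ) -> Topology:
    """Parse launcher-provided variables; fallbacks are assumed single-process values."""
    return Topology(
        rank=int(env.get("RANK", "0")),
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env.get("WORLD_SIZE", "1")),
        master_addr=env.get("MASTER_ADDR", "127.0.0.1"),
        master_port=int(env.get("MASTER_PORT", "29400")),
    )


if __name__ == "__main__":
    print(read_topology())
```

Passing a plain dictionary instead of `os.environ` makes the helper easy to exercise outside a launched job.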

Next Steps

Now that you’ve successfully used detached mode for distributed training with sharded checkpointing, you can try more examples using detached mode or learn more about Determined by visiting the tutorials.