Distributed Training

The Core API has special features for running distributed training. Some of the more important features are:

  • Access to all IP addresses of every node in the Trial (through the ClusterInfo API).

  • Communication primitives such as allgather(), gather(), and broadcast() to give you out-of-the-box coordination between workers.

  • Since many distributed training frameworks expect all workers in training to operate in-step, the should_preempt() call is automatically synchronized across workers, so all workers decide to preempt or continue as a unit.
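
The return shapes of the three communication primitives are easy to mix up, so here is a minimal in-process sketch of what each worker would see. It does not use Determined at all: the `simulate_*` functions are hypothetical helpers, and the semantics they model (every worker gets the list from allgather(), only the chief gets it from gather(), every worker gets the chief's value from broadcast()) are stated as an assumption about the Core API rather than taken from its implementation.

```python
from typing import Any, List, Optional

CHIEF_RANK = 0  # assumption: the chief is the worker with rank 0


def simulate_allgather(values: List[Any]) -> List[List[Any]]:
    """Every worker receives the full list of contributions, ordered by rank."""
    return [list(values) for _ in values]


def simulate_gather(values: List[Any]) -> List[Optional[List[Any]]]:
    """Only the chief receives the list; every other worker receives None."""
    return [list(values) if rank == CHIEF_RANK else None for rank in range(len(values))]


def simulate_broadcast(values: List[Any]) -> List[Any]:
    """Every worker receives the chief's value; non-chief inputs are ignored."""
    return [values[CHIEF_RANK] for _ in values]


if __name__ == "__main__":
    per_worker = ["a", "b", "c"]  # index in this list = rank of the worker
    print(simulate_allgather(per_worker))
    print(simulate_gather(per_worker))
    print(simulate_broadcast(per_worker))
```

In each function the input list holds one value per worker, and the output list holds what each worker would get back, again indexed by rank.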

To convert the example into a distributed training job, follow these steps:

  1. Create a 4_distributed.py training script by copying the 3_hpsearch.py script from Hyperparameter Search.

  2. Add launcher logic to execute one worker subprocess per slot.

    Start with a launcher_main() function that starts the workers, waits for them, and stops the remaining workers if one of them fails.

    import queue
    import subprocess
    import threading

    def launcher_main(slots_per_node, num_nodes, cross_rank):
        # Use subprocess to start one worker process per slot on this node.
        procs = []
        for local_rank in range(slots_per_node):
            rank = cross_rank * slots_per_node + local_rank
            cmd = [
                # Use the determined.launch.wrap_rank to wrap the worker process.
                # This ensures logs from each worker can be filtered by rank in the WebUI.
                "python3", "-m", "determined.launch.wrap_rank", str(rank), "--",
                # Re-invoke this script but as a worker.
                "python3", __file__, "worker", str(rank), str(local_rank),
            ]
            procs.append(subprocess.Popen(cmd))

        # A good launcher normally waits for all workers to finish, but cleans up and exits
        # nonzero immediately if any worker fails to prevent distributed training jobs from
        # hanging.  One way to do this is by managing each worker process in a thread and
        # sending exit codes over a Queue as workers complete.
        q = queue.Queue()

        def wait_for_worker(proc):
            worker_exit = proc.wait()
            q.put((proc, worker_exit))

        threads = [threading.Thread(target=wait_for_worker, args=(proc,)) for proc in procs]
        for t in threads:
            t.start()

        first_failed_exit = 0
        for i in range(slots_per_node):
            proc, worker_exit = q.get()
            # This worker has exited, so it no longer needs to be cleaned up.
            procs.remove(proc)
            if worker_exit != 0 and first_failed_exit == 0:
                # When the first worker crashes, preempt the others.
                first_failed_exit = worker_exit
                for proc in procs:
                    proc.kill()

        for t in threads:
            t.join()

        return first_failed_exit

    Typically, you do not have to write your own launcher. Determined provides launchers for Horovod, torch.distributed, and DeepSpeed, and third-party launchers such as mpirun are also available. When using a custom or third-party launcher, wrap your worker script in the python -m determined.launch.wrap_rank wrapper script so the WebUI log viewer can filter logs by rank.
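
    The wait-in-threads pattern used by the launcher can be tried out without a cluster. The sketch below is a standalone illustration, not part of the example script: `run_fail_fast` and `fake_worker` are hypothetical names, and the workers are throwaway Python one-liners that sleep and then exit with a chosen code.

```python
import queue
import subprocess
import sys
import threading
from typing import List


def run_fail_fast(commands: List[List[str]]) -> int:
    """Run all commands; return 0 if all succeed, else the first nonzero exit code."""
    procs = [subprocess.Popen(cmd) for cmd in commands]
    q: queue.Queue = queue.Queue()

    def wait_for(proc: subprocess.Popen) -> None:
        # Block until this process exits, then report its exit code.
        q.put((proc, proc.wait()))

    threads = [threading.Thread(target=wait_for, args=(p,)) for p in procs]
    for t in threads:
        t.start()

    first_failed_exit = 0
    remaining = list(procs)
    for _ in range(len(procs)):
        proc, code = q.get()
        remaining.remove(proc)
        if code != 0 and first_failed_exit == 0:
            first_failed_exit = code
            for other in remaining:
                other.kill()  # stop the survivors so the job cannot hang

    for t in threads:
        t.join()
    return first_failed_exit


def fake_worker(exit_code: int, seconds: float) -> List[str]:
    """Hypothetical stand-in for a real worker command."""
    script = f"import sys, time; time.sleep({seconds}); sys.exit({exit_code})"
    return [sys.executable, "-c", script]


if __name__ == "__main__":
    # Both workers succeed.
    print(run_fail_fast([fake_worker(0, 0.1), fake_worker(0, 0.2)]))
    # One worker fails quickly; the long-running one is killed instead of awaited.
    print(run_fail_fast([fake_worker(3, 0.1), fake_worker(0, 30)]))
```

    In the second call the function returns shortly after the failing worker exits instead of waiting out the 30-second sleep, which is the behavior that keeps a distributed job from hanging on a dead peer.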

    Also add a worker_main() that will run training on each slot:

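    def worker_main(slots_per_node, num_nodes, cross_rank, chief_ip, rank, local_rank):
        # In the absence of a distributed training framework that might define the
        # rank/local_rank/cross_rank, you can derive them from the ClusterInfo API.
        distributed = det.core.DistributedContext(
            rank=rank, size=num_nodes * slots_per_node,
            local_rank=local_rank, local_size=slots_per_node,
            cross_rank=cross_rank, cross_size=num_nodes,
            chief_ip=chief_ip,
        )
        with det.core.init(distributed=distributed) as core_context:
            # Call your existing main() training function here, passing it
            # core_context along with the arguments it took in the previous stage.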
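
    To make the rank arithmetic concrete, the sketch below enumerates the rank, local_rank, and cross_rank of every worker for a given cluster shape, using the same formula as launcher_main(). `WorkerRanks` and `enumerate_ranks` are hypothetical names introduced only for this illustration.

```python
from typing import List, NamedTuple


class WorkerRanks(NamedTuple):
    rank: int        # global index across all workers
    local_rank: int  # index of the worker within its node
    cross_rank: int  # index of the node


def enumerate_ranks(num_nodes: int, slots_per_node: int) -> List[WorkerRanks]:
    """List every worker's ranks, using rank = cross_rank * slots_per_node + local_rank."""
    return [
        WorkerRanks(cross_rank * slots_per_node + local_rank, local_rank, cross_rank)
        for cross_rank in range(num_nodes)
        for local_rank in range(slots_per_node)
    ]


if __name__ == "__main__":
    # Two nodes with four slots each give eight workers in total.
    for worker in enumerate_ranks(num_nodes=2, slots_per_node=4):
        print(worker)
```

    The total number of workers is num_nodes * slots_per_node, which is the value passed as size to the DistributedContext.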

    Then modify your if __name__ == "__main__" block to invoke the correct *_main() based on command-line arguments:

    slots_per_node = len(info.slot_ids)
    num_nodes = len(info.container_addrs)
    cross_rank = info.container_rank
    chief_ip = info.container_addrs[0]

    # NEW: This script is invoked both as a launcher-of-workers, and again as each worker.
    if sys.argv[1] == "launcher":
        # Usage: SCRIPT launcher
        exitcode = launcher_main(slots_per_node, num_nodes, cross_rank)
        sys.exit(exitcode)

    if sys.argv[1] == "worker":
        # Usage: SCRIPT worker $RANK $LOCAL_RANK
        logging.info("worker starting")
        rank = int(sys.argv[2])
        local_rank = int(sys.argv[3])
        exitcode = worker_main(
            slots_per_node, num_nodes, cross_rank, chief_ip, rank, local_rank
        )
        sys.exit(exitcode)

    # Reached only if the first argument was neither "launcher" nor "worker".
    raise ValueError(f"unrecognized first argument: {sys.argv[1]}")
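
    The example dispatches on sys.argv directly, which keeps it short. If you prefer usage messages and type checking of the rank arguments, one possible variant is to use argparse subcommands. This is a sketch of that alternative, not what the example script does, and `parse_role` is a hypothetical helper name.

```python
import argparse
from typing import List, Optional


def parse_role(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse 'launcher' or 'worker RANK LOCAL_RANK' from the command line."""
    parser = argparse.ArgumentParser(description="launcher/worker dispatch sketch")
    sub = parser.add_subparsers(dest="role", required=True)
    sub.add_parser("launcher", help="start one worker per slot")
    worker = sub.add_parser("worker", help="run training on one slot")
    worker.add_argument("rank", type=int)
    worker.add_argument("local_rank", type=int)
    return parser.parse_args(argv)


if __name__ == "__main__":
    print(parse_role(["launcher"]))
    print(parse_role(["worker", "5", "1"]))
```

    An unrecognized first argument makes argparse print a usage message and exit nonzero, which plays the same role as the ValueError in the original dispatch.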

  3. In the training code, use the allgather() primitive to perform a “distributed” increment. This is a simple way to gain experience with the communication primitives:

    all_increment_bys = core_context.distributed.allgather(increment_by)
    x += sum(all_increment_bys)
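
    The effect of this change on x can be checked with a small simulation that needs no cluster. The sketch assumes every worker starts from x = 0 and contributes its own increment_by on every step; `simulate_distributed_increment` is a hypothetical helper, and the allgather is modeled as a plain list copy.

```python
from typing import List


def simulate_distributed_increment(increment_bys: List[int], steps: int) -> List[int]:
    """Return each worker's x after `steps` rounds of x += sum(allgather(increment_by))."""
    xs = [0 for _ in increment_bys]  # one x per worker, indexed by rank
    for _ in range(steps):
        # Modeled allgather: every worker sees every worker's increment_by.
        all_increment_bys = list(increment_bys)
        xs = [x + sum(all_increment_bys) for x in xs]
    return xs


if __name__ == "__main__":
    # Eight workers that each increment by 1 advance x by 8 per step.
    print(simulate_distributed_increment([1] * 8, steps=10))
```

    Because every worker adds the same gathered sum, all workers hold the same x after every step, which is why reporting it from the chief alone is sufficient.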

  4. Trial logs are often easier to read when only the chief worker prints status:

    if core_context.distributed.rank == 0:
        logging.info(f"x is now {x}")

  5. Only the chief worker is permitted to report training metrics, report validation metrics, upload checkpoints, or report completed searcher operations. This rule applies to the steps you take periodically during training:

    if steps_completed % 10 == 0:
        # NEW: only the chief may report training metrics and progress,
        # or upload checkpoints.
        if core_context.distributed.rank == 0:
            core_context.train.report_training_metrics(
                steps_completed=steps_completed, metrics={"x": x}
            )
            checkpoint_metadata = {"steps_completed": steps_completed}
            with core_context.checkpoint.store_path(
                checkpoint_metadata
            ) as (checkpoint_directory, uuid):
                save_state(x, steps_completed, trial_id, checkpoint_directory)
            last_checkpoint_batch = steps_completed
        # Every worker calls should_preempt(), so all workers stop together.
        if core_context.preempt.should_preempt():
            return

    The rule also applies to the steps you take after validating:

    if core_context.distributed.rank == 0:
        core_context.train.report_validation_metrics(
            steps_completed=steps_completed, metrics={"x": x}
        )

    The rule also applies to the conditional save after the main loop completes:

    # NEW: again, only the chief may upload checkpoints.
    if core_context.distributed.rank == 0 and last_checkpoint_batch != steps_completed:
        checkpoint_metadata = {"steps_completed": steps_completed}
        with core_context.checkpoint.store_path(checkpoint_metadata) as (path, uuid):
            save_state(x, steps_completed, trial_id, path)
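
    If the repeated rank checks become noisy, one option is to wrap them in a small guard. The sketch below is an illustration under stated assumptions: `run_on_chief` is a hypothetical helper that is not part of Determined, and `fake_context` stands in for a real core_context by exposing only distributed.rank.

```python
from types import SimpleNamespace
from typing import Any, Callable, Optional


def run_on_chief(core_context: Any, fn: Callable[[], Any]) -> Optional[Any]:
    """Call fn() only when this worker is the chief (rank 0); otherwise return None."""
    if core_context.distributed.rank == 0:
        return fn()
    return None


def fake_context(rank: int) -> SimpleNamespace:
    """Hypothetical stand-in exposing only core_context.distributed.rank."""
    return SimpleNamespace(distributed=SimpleNamespace(rank=rank))


if __name__ == "__main__":
    reported = []
    for rank in range(4):
        # Each simulated worker tries to report; only rank 0 actually does.
        run_on_chief(fake_context(rank), lambda r=rank: reported.append(r))
    print(reported)
```

    Keep collective calls such as allgather() and should_preempt() outside of a guard like this, since every worker has to take part in them.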

  6. Create a 4_distributed.yaml file by copying the 3_hpsearch.yaml file and changing the first couple of lines:

    name: core-api-stage-4
    entrypoint: ./4_distributed.py launcher

    Set the resources.slots_per_trial field to the number of GPUs you want:

      slots_per_trial: 8

    You can return to using the single searcher instead of an adaptive_asha hyperparameter search:

       name: single
       metric: x
       max_length: 100

  7. Run the code using the Determined CLI with the following command:

    det e create 4_distributed.yaml . -f

The complete 4_distributed.py and 4_distributed.yaml listings used in this example can be found in the core_api.tgz download or in the GitHub repository.