Python API determined.TrialContext

The TrialContext provides useful methods for writing Trial subclasses. All Trial subclasses receive a TrialContext object as an argument to their __init__() method:

determined.TrialContext

class determined.TrialContext(generic_context: determined._generic._context.Context, env: determined._env_context.EnvContext, hvd_config: determined.horovod.HorovodContext)

TrialContext is the system-provided API to a Trial class.

TrialContext always has an instance of DistributedContext accessible as the distributed attribute for information related to distributed training.

classmethod from_config(config: Dict[str, Any]) → determined._trial_context.TrialContext

Create a context object suitable for debugging outside of Determined.

An example for a subclass of PyTorchTrial:

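import determined as det
import determined.pytorch  # ensures the det.pytorch submodule is loaded

config = { ... }  # an experiment config, in dictionary form
context = det.pytorch.PyTorchTrialContext.from_config(config)
my_trial = MyPyTorchTrial(context)  # MyPyTorchTrial: your PyTorchTrial subclass

# Drive the training loop manually for a few epochs.
train_ds = my_trial.build_training_data_loader()
for epoch_idx in range(3):
    for batch_idx, batch in enumerate(train_ds):
        metrics = my_trial.train_batch(batch, epoch_idx, batch_idx)
        ...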

An example for a subclass of TFKerasTrial:

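import determined as det
import determined.keras  # ensures the det.keras submodule is loaded

config = { ... }  # an experiment config, in dictionary form
context = det.keras.TFKerasTrialContext.from_config(config)
# OneVarTrial is a TFKerasTrial subclass defined in the tf_keras_one_var_model module.
my_trial = tf_keras_one_var_model.OneVarTrial(context)

# Build the model, then drive training and evaluation with plain Keras calls.
model = my_trial.build_model()
model.fit(my_trial.build_training_data_loader())
eval_metrics = model.evaluate(my_trial.build_validation_data_loader())

Parameters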

config – An experiment config file, in dictionary form.

get_experiment_config() → Dict[str, Any]

Return the experiment configuration.

get_data_config() → Dict[str, Any]

Return the data configuration.
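As a rough offline illustration, the sketch below treats the experiment configuration as a plain dictionary (the form from_config accepts) and assumes that the data configuration corresponds to its optional `data` section. The helper `data_config_from` and all config values are hypothetical; this is not the Determined implementation.

```python
from typing import Any, Dict

# Hypothetical experiment config in dictionary form.
experiment_config: Dict[str, Any] = {
    "hyperparameters": {"global_batch_size": 64, "lr": 0.01},
    "data": {"dataset_dir": "/path/to/data"},
}


def data_config_from(config: Dict[str, Any]) -> Dict[str, Any]:
    """Return the `data` section of an experiment config, or {} if it is absent.

    Assumption: this mirrors how get_data_config() relates to
    get_experiment_config(); check the Determined source before relying on it.
    """
    return config.get("data", {})


print(data_config_from(experiment_config))  # {'dataset_dir': '/path/to/data'}
print(data_config_from({}))  # {}
```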

get_experiment_id() → int

Return the experiment ID of the current trial.

get_global_batch_size() → int

Return the global batch size.

get_per_slot_batch_size() → int

Return the per-slot batch size. When a model is trained with a single GPU, this is equal to the global batch size. When multi-GPU training is used, this is equal to the global batch size divided by the number of GPUs used to train the model.
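The per-slot arithmetic described above can be sketched as follows. `per_slot_batch_size` is a hypothetical helper, not part of the Determined API; it assumes the global batch size divides evenly by the number of slots and rejects other values rather than guessing how Determined handles them.

```python
def per_slot_batch_size(global_batch_size: int, num_slots: int) -> int:
    """Split a global batch size evenly across slots (e.g. GPUs).

    Hypothetical helper: raises ValueError instead of rounding when the
    division is not exact.
    """
    if num_slots < 1:
        raise ValueError("num_slots must be at least 1")
    if global_batch_size % num_slots != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible by {num_slots} slots"
        )
    return global_batch_size // num_slots


# Single GPU: the per-slot batch size equals the global batch size.
print(per_slot_batch_size(64, 1))  # 64
# Four GPUs: each slot processes a quarter of the global batch.
print(per_slot_batch_size(64, 4))  # 16
```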

get_trial_id() → int

Return the trial ID of the current trial.

get_hparams() → Dict[str, Any]

Return a dictionary of hyperparameter names to values.

get_hparam(name: str) → Any

Return the current value of the hyperparameter with the given name.
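When unit-testing trial code outside a cluster, it can help to replace the context with a small stand-in that exposes only the methods the code calls. The `FakeHParamContext` class and the `learning_rate_from` consumer below are hypothetical; they only mimic the documented names of get_hparams() and get_hparam(), and raising KeyError for an unknown name is an assumption of this sketch, not documented Determined behavior.

```python
from typing import Any, Dict


class FakeHParamContext:
    """Hypothetical stand-in exposing get_hparams()/get_hparam() for tests."""

    def __init__(self, hparams: Dict[str, Any]) -> None:
        self._hparams = dict(hparams)

    def get_hparams(self) -> Dict[str, Any]:
        # Return a copy so callers cannot mutate the stored values.
        return dict(self._hparams)

    def get_hparam(self, name: str) -> Any:
        # This sketch raises KeyError for unknown names.
        return self._hparams[name]


def learning_rate_from(context: Any) -> float:
    """Example consumer: read a hypothetical "lr" hyperparameter."""
    return float(context.get_hparam("lr"))


ctx = FakeHParamContext({"lr": 0.001, "hidden_size": 128})
print(learning_rate_from(ctx))  # 0.001
print(sorted(ctx.get_hparams()))  # ['hidden_size', 'lr']
```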

get_stop_requested() → bool

Return whether a trial stoppage has been requested.

set_stop_requested(stop_requested: bool) → None

Set a flag to request a trial stoppage. When this flag is set to True, the trial finishes the current step, saves a checkpoint, and then exits.
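One possible use of set_stop_requested() is custom early stopping: trial code inspects its own metrics and asks for the trial to stop. The sketch below uses a hypothetical `FakeStopContext` with the same two method names so it runs without Determined; the loss-threshold rule and the `run` loop are illustrative assumptions, and in real trial code the platform, not the user's loop, carries out the finish-checkpoint-exit sequence.

```python
from typing import List


class FakeStopContext:
    """Hypothetical stand-in for the stop-request methods of TrialContext."""

    def __init__(self) -> None:
        self._stop_requested = False

    def get_stop_requested(self) -> bool:
        return self._stop_requested

    def set_stop_requested(self, stop_requested: bool) -> None:
        self._stop_requested = stop_requested


def run(context: FakeStopContext, losses: List[float], target_loss: float) -> int:
    """Consume per-batch losses, requesting a stop once the target is reached.

    Returns the number of batches processed before the loop noticed the request.
    """
    processed = 0
    for loss in losses:
        if context.get_stop_requested():
            break
        processed += 1
        if loss <= target_loss:
            # Ask for a stop; in Determined the current step would finish,
            # a checkpoint would be saved, and the trial would exit.
            context.set_stop_requested(True)
    return processed


ctx = FakeStopContext()
print(run(ctx, [0.9, 0.5, 0.09, 0.08, 0.07], target_loss=0.1))  # 3
print(ctx.get_stop_requested())  # True
```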

determined.TrialContext.distributed

class determined._generic._distributed.DistributedContext(*, rank: int, size: int, local_rank: int, local_size: int, cross_rank: int, cross_size: int, chief_ip: Optional[str] = None, pub_port: int = 12360, pull_port: int = 12376, port_offset: int = 0, force_tcp: bool = False)

DistributedContext provides useful methods for effective distributed training.

A DistributedContext has the following required args:
  • rank: the index of this worker in the entire job

  • size: the number of workers in the entire job

  • local_rank: the index of this worker on this machine

  • local_size: the number of workers on this machine

  • cross_rank: the index of this machine in the entire job

  • cross_size: the number of machines in the entire job

Additionally, any time that cross_size > 1 (that is, whenever the job spans more than one machine), you must also provide:
  • chief_ip: the IP address to reach the chief worker (where rank==0)
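To see how the six required arguments listed above relate, the sketch below enumerates them for a hypothetical job of 2 machines with 4 workers each. It assumes every machine runs the same number of workers and that ranks are numbered machine by machine (rank = cross_rank * local_size + local_rank); that layout is an illustrative assumption, not a guarantee about how Determined assigns ranks. `WorkerLayout` and `enumerate_layout` are hypothetical names.

```python
from typing import List, NamedTuple


class WorkerLayout(NamedTuple):
    rank: int
    size: int
    local_rank: int
    local_size: int
    cross_rank: int
    cross_size: int


def enumerate_layout(num_machines: int, workers_per_machine: int) -> List[WorkerLayout]:
    """List DistributedContext-style arguments for every worker.

    Assumes a homogeneous job where ranks are numbered machine by machine.
    """
    size = num_machines * workers_per_machine
    layout = []
    for cross_rank in range(num_machines):
        for local_rank in range(workers_per_machine):
            layout.append(
                WorkerLayout(
                    rank=cross_rank * workers_per_machine + local_rank,
                    size=size,
                    local_rank=local_rank,
                    local_size=workers_per_machine,
                    cross_rank=cross_rank,
                    cross_size=num_machines,
                )
            )
    return layout


workers = enumerate_layout(num_machines=2, workers_per_machine=4)
print(len(workers))  # 8
print(workers[5])
# WorkerLayout(rank=5, size=8, local_rank=1, local_size=4, cross_rank=1, cross_size=2)
```

Under this layout the chief (rank 0) is also local rank 0 on machine 0 (cross_rank 0).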

get_rank() → int

Return the rank of the process in the trial. The rank of a process is a unique ID within the trial; that is, no two processes in the same trial will be assigned the same rank.

get_local_rank() → int

Return the rank of the process on the agent. The local rank of a process is a unique ID within a given agent and trial; that is, no two processes in the same trial that are executing on the same agent will be assigned the same local rank.
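A common pattern built on these two methods is to let only the chief (rank 0) perform trial-wide work such as writing a summary, and only one process per agent (local rank 0) perform per-machine work such as downloading a dataset to local disk. The sketch below passes plain integers in place of get_rank() and get_local_rank() so it runs without Determined; the role names and the two-workers-per-agent layout are hypothetical.

```python
from typing import List


def roles_for(rank: int, local_rank: int) -> List[str]:
    """Return the hypothetical duties a worker should take on.

    rank and local_rank stand in for distributed.get_rank() and
    distributed.get_local_rank().
    """
    roles = ["train"]  # every worker trains
    if local_rank == 0:
        roles.append("download_data")  # once per agent
    if rank == 0:
        roles.append("write_summary")  # once per trial
    return roles


# Two agents with two workers each: ranks 0-3, local ranks 0-1.
for rank in range(4):
    print(rank, roles_for(rank, local_rank=rank % 2))
# 0 ['train', 'download_data', 'write_summary']
# 1 ['train']
# 2 ['train', 'download_data']
# 3 ['train']
```

In real distributed code the other workers on each agent would also need to wait until the download finishes before reading the data; that synchronization is outside this sketch.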

get_size() → int

Return the number of slots this trial is running on.

get_num_agents() → int

Return the number of agents this trial is running on.
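get_size() and get_num_agents() are often combined with get_rank() to split work across slots. The sketch below gives each rank a disjoint, strided slice of a list of example indices and computes slots per agent. `shard_for_rank` and `slots_per_agent` are hypothetical helpers, and the even spread of slots across agents is an assumption of this sketch.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_for_rank(items: Sequence[T], rank: int, size: int) -> List[T]:
    """Return every size-th item starting at rank (strided sharding).

    rank and size stand in for distributed.get_rank() and
    distributed.get_size(); the shards of all ranks are disjoint and
    together cover every item.
    """
    if not 0 <= rank < size:
        raise ValueError("rank must satisfy 0 <= rank < size")
    return list(items[rank::size])


def slots_per_agent(size: int, num_agents: int) -> int:
    """Slots per agent, assuming slots are spread evenly across agents."""
    if size % num_agents != 0:
        raise ValueError("slots are not evenly divided across agents")
    return size // num_agents


indices = list(range(10))
print([shard_for_rank(indices, r, 4) for r in range(4)])
# [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
print(slots_per_agent(size=8, num_agents=2))  # 4
```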