Shortcuts

determined.TrialContext

The TrialContext provides useful methods for writing Trial subclasses. Every Trial subclass receives a TrialContext object as an argument to its __init__() method.
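The constructor pattern can be sketched as follows. This is a minimal sketch, assuming a trial that reads a hypothetical `learning_rate` hyperparameter. In real code the class would subclass a framework-specific Trial class such as `determined.pytorch.PyTorchTrial`; here `StubContext` is a hypothetical stand-in that implements only the TrialContext methods used, so the sketch runs without Determined installed.

```python
from typing import Any, Dict


class StubContext:
    """Hypothetical stand-in for determined.TrialContext (not part of Determined)."""

    def __init__(self, hparams: Dict[str, Any]) -> None:
        self._hparams = hparams

    def get_hparam(self, name: str) -> Any:
        return self._hparams[name]

    def get_hparams(self) -> Dict[str, Any]:
        return dict(self._hparams)


class MyTrial:
    """In a real experiment this would subclass a Determined Trial class."""

    def __init__(self, context: StubContext) -> None:
        # Keep a reference so other methods can query the context later.
        self.context = context
        # Read a single hyperparameter by name.
        self.learning_rate = float(context.get_hparam("learning_rate"))


trial = MyTrial(StubContext({"learning_rate": 0.01, "hidden_size": 64}))
print(trial.learning_rate)  # 0.01
```

Storing the context on `self` lets later methods (for example, the one that builds the data loaders) call the other getters without passing the context around.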

determined.TrialContext

class determined.TrialContext(env: determined._env_context.EnvContext, hvd_config: determined.horovod.HorovodContext)

The base class from which all TrialContexts inherit. The context passed to UserTrial.__init__() when the user's Trial is instantiated must inherit from this class.

TrialContext always has an instance of DistributedContext, accessible as the distributed attribute, which provides information related to distributed training.

get_data_config() → Dict[str, Any]

Return the data configuration.

get_experiment_config() → Dict[str, Any]

Return the experiment configuration.
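The data configuration corresponds to the user-defined data section of the experiment configuration, so its keys depend entirely on the experiment. A minimal sketch of reading a setting from it, assuming a hypothetical `url` key and a hypothetical `StubContext` stand-in that mirrors the data section of a plain dictionary:

```python
from typing import Any, Dict


class StubContext:
    """Hypothetical stand-in for determined.TrialContext (not part of Determined)."""

    def __init__(self, experiment_config: Dict[str, Any]) -> None:
        self._config = experiment_config

    def get_experiment_config(self) -> Dict[str, Any]:
        return self._config

    def get_data_config(self) -> Dict[str, Any]:
        # Assumption: mirrors the user-defined "data" section of the config.
        return self._config.get("data", {})


def dataset_url(context: StubContext, default: str = "file:///tmp/data") -> str:
    # "url" is a hypothetical user-defined key in the data section.
    return str(context.get_data_config().get("url", default))


ctx = StubContext({"data": {"url": "s3://bucket/train"}, "searcher": {"name": "single"}})
print(dataset_url(ctx))  # s3://bucket/train
print(dataset_url(StubContext({})))  # file:///tmp/data
```

Falling back to a default keeps the trial usable when an experiment configuration omits the data section.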

get_experiment_id() → int

Return the experiment ID of the current trial.

get_global_batch_size() → int

Return the global batch size.

get_hparam(name: str) → Any

Return the current value of the hyperparameter with the given name.

get_hparams() → Dict[str, Any]

Return a dictionary of hyperparameter names to values.

get_per_slot_batch_size() → int

Return the per-slot batch size. When a model is trained with a single GPU, this is equal to the global batch size. When multi-GPU training is used, this is equal to the global batch size divided by the number of GPUs used to train the model.
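The relationship between the two batch sizes is plain division. The hypothetical helper below sketches it; the divisibility check is an assumption of this sketch, not documented Determined behavior.

```python
def per_slot_batch_size(global_batch_size: int, num_slots: int) -> int:
    """Split a global batch evenly across slots (hypothetical helper)."""
    if num_slots < 1:
        raise ValueError("num_slots must be at least 1")
    # Assumption: the global batch size divides evenly across slots.
    if global_batch_size % num_slots != 0:
        raise ValueError("global_batch_size must be divisible by num_slots")
    return global_batch_size // num_slots


print(per_slot_batch_size(256, 1))  # 256: single GPU, same as the global size
print(per_slot_batch_size(256, 8))  # 32: each of 8 GPUs gets 256 / 8
```

With 8 GPUs, each training step still covers 256 examples in total; each GPU processes a 32-example share of that batch.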

get_trial_id() → int

Return the trial ID of the current trial.
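The experiment and trial IDs are handy for keeping per-trial output apart. A minimal sketch, assuming a hypothetical directory layout and a hypothetical `StubContext` stand-in for the two getters:

```python
from pathlib import PurePosixPath


class StubContext:
    """Hypothetical stand-in for determined.TrialContext (not part of Determined)."""

    def __init__(self, experiment_id: int, trial_id: int) -> None:
        self._experiment_id = experiment_id
        self._trial_id = trial_id

    def get_experiment_id(self) -> int:
        return self._experiment_id

    def get_trial_id(self) -> int:
        return self._trial_id


def output_dir(context: StubContext, root: PurePosixPath) -> PurePosixPath:
    # Hypothetical layout: group each trial under its experiment's directory.
    experiment = f"experiment-{context.get_experiment_id()}"
    trial = f"trial-{context.get_trial_id()}"
    return root / experiment / trial


print(output_dir(StubContext(12, 345), PurePosixPath("/outputs")))
# /outputs/experiment-12/trial-345
```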

determined.TrialContext.distributed

class determined._train_context.DistributedContext(env: determined._env_context.EnvContext, hvd_config: determined.horovod.HorovodContext)

DistributedContext extends every TrialContext and NativeContext through the context.distributed namespace. It provides methods for multi-slot (parallel and distributed) training.

get_rank() → int

Return the rank of the process in the trial. This rank is unique across all processes of the trial, on every agent.

get_local_rank() → int

Return the rank of the process on its agent. This rank is unique only among the processes running on the same agent.

get_size() → int

Return the number of slots this trial is running on.
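The rank and size are typically used together so that each process handles a distinct part of the work. A minimal sketch of a strided split, where in a trial `rank` and `size` would come from `context.distributed.get_rank()` and `context.distributed.get_size()`. The `shard` helper is hypothetical, and Determined's framework integrations may already shard training data for you.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard(items: Sequence[T], rank: int, size: int) -> List[T]:
    """Return the items assigned to one process (strided, round-robin split)."""
    if not 0 <= rank < size:
        raise ValueError("rank must be in [0, size)")
    # Process `rank` takes every `size`-th item starting at its own rank,
    # so each item goes to exactly one process.
    return list(items[rank::size])


records = list(range(10))
print([shard(records, r, 4) for r in range(4)])
# [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```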

get_num_agents() → int

Return the number of agents this trial is running on.
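The local rank and agent count support per-machine coordination. A minimal sketch using a hypothetical `StubDistributed` stand-in for `context.distributed`. The helpers are hypothetical, and `slots_per_agent` assumes every agent runs the same number of slots.

```python
from dataclasses import dataclass


@dataclass
class StubDistributed:
    """Hypothetical stand-in for context.distributed (not part of Determined)."""

    rank: int
    local_rank: int
    size: int
    num_agents: int

    def get_rank(self) -> int:
        return self.rank

    def get_local_rank(self) -> int:
        return self.local_rank

    def get_size(self) -> int:
        return self.size

    def get_num_agents(self) -> int:
        return self.num_agents


def slots_per_agent(dist: StubDistributed) -> int:
    # Assumption: every agent runs the same number of slots.
    return dist.get_size() // dist.get_num_agents()


def is_agent_leader(dist: StubDistributed) -> bool:
    # One process per agent (local rank 0) can do per-machine work,
    # such as downloading a dataset to local disk.
    return dist.get_local_rank() == 0


def is_chief(dist: StubDistributed) -> bool:
    # Exactly one process in the whole trial has rank 0.
    return dist.get_rank() == 0


# 2 agents x 4 slots; this process is the second one on the second agent.
dist = StubDistributed(rank=5, local_rank=1, size=8, num_agents=2)
print(slots_per_agent(dist), is_agent_leader(dist), is_chief(dist))  # 4 False False
```

Checking the local rank gives one leader per machine, while checking the global rank gives a single leader for the whole trial.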