determined.TrialContext
The TrialContext provides useful methods for writing Trial subclasses. All Trial subclasses receive a TrialContext object as an argument to their __init__() method:

- PyTorchTrial subclasses receive a plain PyTorchTrialContext.
- TFKerasTrial subclasses receive a TFKerasTrialContext.
- EstimatorTrial subclasses receive an EstimatorTrialContext.
determined.TrialContext
class determined.TrialContext(env: determined._env_context.EnvContext, hvd_config: determined.horovod.HorovodContext)
    The base class from which all TrialContexts inherit. The context passed to UserTrial.__init__() when the user's Trial is instantiated must inherit from this class.

    A TrialContext always has an instance of DistributedContext accessible as the distributed attribute, which provides information related to distributed training.
get_data_config() → Dict[str, Any]
    Return the data configuration.

get_experiment_config() → Dict[str, Any]
    Return the experiment configuration.

get_experiment_id() → int
    Return the experiment ID of the current trial.

get_global_batch_size() → int
    Return the global batch size.

get_hparam(name: str) → Any
    Return the current value of the hyperparameter with the given name.

get_hparams() → Dict[str, Any]
    Return a dictionary of hyperparameter names to values.
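As a minimal sketch of how these accessors might be used, the following reads hyperparameters inside a trial's __init__(). StubContext is a hypothetical stand-in that mimics only get_hparam() and get_hparams() so the snippet runs without a cluster. MyTrial and the hyperparameter names learning_rate and hidden_size are likewise assumptions for illustration; a real trial would receive its context from Determined.

```python
from typing import Any, Dict


class StubContext:
    """Hypothetical stand-in mimicking two TrialContext accessors."""

    def __init__(self, hparams: Dict[str, Any]) -> None:
        self._hparams = dict(hparams)

    def get_hparam(self, name: str) -> Any:
        # Return the current value of one hyperparameter by name.
        return self._hparams[name]

    def get_hparams(self) -> Dict[str, Any]:
        # Return a copy of the name-to-value mapping.
        return dict(self._hparams)


class MyTrial:
    """Illustrative trial; a real one would subclass e.g. PyTorchTrial."""

    def __init__(self, context: StubContext) -> None:
        self.context = context
        # Look up individual hyperparameters by name.
        self.learning_rate = context.get_hparam("learning_rate")
        self.hidden_size = context.get_hparam("hidden_size")


trial = MyTrial(StubContext({"learning_rate": 0.01, "hidden_size": 128}))
print(trial.learning_rate, sorted(trial.context.get_hparams()))
```

Storing the context on the trial keeps the other accessors available to later methods.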
get_per_slot_batch_size() → int
    Return the per-slot batch size. When a model is trained with a single GPU, this is equal to the global batch size. When multi-GPU training is used, this is equal to the global batch size divided by the number of GPUs used to train the model.
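The relationship between the two batch sizes can be sketched as plain arithmetic. The helper per_slot_batch_size below is hypothetical and not part of the Determined API. It assumes the global batch size divides evenly across slots, since the handling of uneven divisions is not described here.

```python
def per_slot_batch_size(global_batch_size: int, num_slots: int) -> int:
    """Illustrative only: global batch size divided across slots."""
    if num_slots < 1:
        raise ValueError("num_slots must be at least 1")
    if global_batch_size % num_slots != 0:
        # Assumption: this sketch only covers even divisions.
        raise ValueError("global_batch_size must divide evenly by num_slots")
    return global_batch_size // num_slots


# Single GPU: the per-slot batch size equals the global batch size.
print(per_slot_batch_size(64, 1))  # 64
# Eight GPUs: each slot processes one eighth of the global batch.
print(per_slot_batch_size(64, 8))  # 8
```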
get_stop_requested() → bool
    Return whether a trial stoppage has been requested.

get_trial_id() → int
    Return the trial ID of the current trial.

set_stop_requested(stop_requested: bool) → None
    Set a flag to request a trial stoppage. When this flag is set to True, the trial finishes the current step, checkpoints, and then exits.
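One way a trial might use this flag is to request a stop once validation loss stops improving. This is a sketch under assumptions: StubContext mimics only the two stop-request methods, and maybe_request_stop with its patience rule is a hypothetical helper rather than anything Determined provides.

```python
from typing import List


class StubContext:
    """Hypothetical stand-in for the stop-request methods of a TrialContext."""

    def __init__(self) -> None:
        self._stop_requested = False

    def set_stop_requested(self, stop_requested: bool) -> None:
        self._stop_requested = stop_requested

    def get_stop_requested(self) -> bool:
        return self._stop_requested


def maybe_request_stop(context: StubContext, losses: List[float], patience: int = 3) -> None:
    """Request a stop if the last `patience` losses failed to beat the earlier best."""
    if len(losses) <= patience:
        return  # Not enough history yet.
    best_before = min(losses[:-patience])
    if min(losses[-patience:]) >= best_before:
        context.set_stop_requested(True)


context = StubContext()
maybe_request_stop(context, [0.9, 0.7, 0.6])  # too little history to decide
requested_early = context.get_stop_requested()
maybe_request_stop(context, [0.9, 0.5, 0.6, 0.6, 0.55])  # no gain in 3 steps
print(requested_early, context.get_stop_requested())
```

Setting the flag only requests the stoppage; the current step still finishes and a checkpoint is taken before the trial exits.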
determined.TrialContext.distributed
class determined._train_context.DistributedContext(env: determined._env_context.EnvContext, hvd_config: determined.horovod.HorovodContext)
    DistributedContext extends all TrialContexts and NativeContexts under the context.distributed namespace. It provides useful methods for effective distributed training.
get_rank() → int
    Return the rank of the process in the trial. The rank of a process is a unique ID within the trial; that is, no two processes in the same trial will be assigned the same rank.

get_local_rank() → int
    Return the rank of the process on the agent. The local rank of a process is a unique ID within a given agent and trial; that is, no two processes in the same trial that are executing on the same agent will be assigned the same local rank.

get_size() → int
    Return the number of slots this trial is running on.

get_num_agents() → int
    Return the number of agents this trial is running on.
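A common use of ranks is to have only one process do work that should not be repeated, such as logging once per trial or downloading data once per agent. The sketch below illustrates that pattern with StubDistributedContext, a hypothetical stand-in for context.distributed. It assumes that ranks and local ranks are numbered from 0, which is not stated above; the helper names and the example layout (rank 4 on the second of two agents) are also illustrative.

```python
class StubDistributedContext:
    """Hypothetical stand-in for context.distributed."""

    def __init__(self, rank: int, local_rank: int, size: int, num_agents: int) -> None:
        self._rank, self._local_rank = rank, local_rank
        self._size, self._num_agents = size, num_agents

    def get_rank(self) -> int:
        return self._rank

    def get_local_rank(self) -> int:
        return self._local_rank

    def get_size(self) -> int:
        return self._size

    def get_num_agents(self) -> int:
        return self._num_agents


def is_chief(distributed: StubDistributedContext) -> bool:
    # Assumption: ranks start at 0, so rank 0 can act as the single "chief".
    return distributed.get_rank() == 0


def should_download_data(distributed: StubDistributedContext) -> bool:
    # Assumption: local ranks start at 0, so one process per agent qualifies.
    return distributed.get_local_rank() == 0


# A process on the second of two agents, in an 8-slot trial.
worker = StubDistributedContext(rank=4, local_rank=0, size=8, num_agents=2)
print(is_chief(worker), should_download_data(worker))
```

Because ranks are unique within a trial and local ranks are unique within an agent, each check selects at most one process per trial or per agent, respectively.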