Python API determined.TrialContext
The TrialContext provides useful methods for writing Trial subclasses. All Trial subclasses receive a TrialContext object as an argument to their __init__() method:

- PyTorchTrial subclasses receive a PyTorchTrialContext.
- TFKerasTrial subclasses receive a TFKerasTrialContext.
- EstimatorTrial subclasses receive an EstimatorTrialContext.
determined.TrialContext
- class determined.TrialContext(env: determined._env_context.EnvContext, hvd_config: determined.horovod.HorovodContext, rendezvous_info: determined._info.RendezvousInfo)
TrialContext is the system-provided API to a Trial class.
TrialContext always has an instance of DistributedContext accessible as the distributed attribute, for information related to distributed training.

- classmethod from_config(config: Dict[str, Any]) → determined._train_context.TrialContext
Create a context object suitable for debugging outside of Determined.
An example for a subclass of PyTorchTrial:

    config = { ... }
    context = det.pytorch.PyTorchTrialContext.from_config(config)
    my_trial = MyPyTorchTrial(context)

    train_ds = my_trial.build_training_data_loader()
    for epoch_idx in range(3):
        for batch_idx, batch in enumerate(train_ds):
            metrics = my_trial.train_batch(batch, epoch_idx, batch_idx)
            ...
An example for a subclass of TFKerasTrial:

    config = { ... }
    context = det.keras.TFKerasTrialContext.from_config(config)
    my_trial = tf_keras_one_var_model.OneVarTrial(context)

    model = my_trial.build_model()
    model.fit(my_trial.build_training_data_loader())
    eval_metrics = model.evaluate(my_trial.build_validation_data_loader())
- Parameters
config – An experiment config file, in dictionary form.
- get_experiment_config() → Dict[str, Any]
Return the experiment configuration.
- get_data_config() → Dict[str, Any]
Return the data configuration.
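For example, both dictionaries can be read when setting up a trial, following the from_config() pattern shown above. A minimal sketch; the "url" and "records_per_epoch" keys below are illustrative and depend entirely on your own experiment config:

    import determined as det

    config = { ... }  # an experiment config, in dictionary form
    context = det.pytorch.PyTorchTrialContext.from_config(config)

    # get_data_config() returns the `data:` section of the experiment config;
    # "url" is an illustrative key, not one Determined itself defines.
    data_url = context.get_data_config().get("url")

    # get_experiment_config() returns the entire experiment configuration.
    exp_conf = context.get_experiment_config()
    records_per_epoch = exp_conf.get("records_per_epoch")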
- get_experiment_id() → int
Return the experiment ID of the current trial.
- get_global_batch_size() → int
Return the global batch size.
- get_per_slot_batch_size() → int
Return the per-slot batch size. When a model is trained with a single GPU, this is equal to the global batch size. When multi-GPU training is used, this is equal to the global batch size divided by the number of GPUs used to train the model.
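In other words, the global batch size equals the per-slot batch size times the number of slots. A sketch of how a data loader is typically sized with the per-slot value; the class name, the placeholder dataset, and the omission of the other required PyTorchTrial methods are all for illustration only:

    import torch
    from determined import pytorch

    class MyPyTorchTrial(pytorch.PyTorchTrial):  # other required methods omitted
        def __init__(self, context: pytorch.PyTorchTrialContext) -> None:
            self.context = context

        def build_training_data_loader(self) -> pytorch.DataLoader:
            # With global_batch_size: 128 on 4 slots, each worker gets batches
            # of 32, and 4 * 32 adds back up to the global batch size.
            per_slot_bs = self.context.get_per_slot_batch_size()
            dataset = torch.utils.data.TensorDataset(torch.randn(1024, 8))  # placeholder data
            return pytorch.DataLoader(dataset, batch_size=per_slot_bs)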
- get_trial_id() → int
Return the trial ID of the current trial.
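The experiment and trial IDs can be used to tag artifacts produced outside of Determined's own checkpointing; for example (the directory layout below is purely illustrative, and context is assumed to be the TrialContext passed to __init__()):

    import os

    out_dir = os.path.join(
        "/tmp/scratch",  # illustrative path, not a Determined convention
        f"experiment_{context.get_experiment_id()}",
        f"trial_{context.get_trial_id()}",
    )
    os.makedirs(out_dir, exist_ok=True)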
- get_hparams() → Dict[str, Any]
Return a dictionary of hyperparameter names to values.
- get_hparam(name: str) → Any
Return the current value of the hyperparameter with the given name.
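Hyperparameters come from the hyperparameters: section of the experiment config, with any searcher-chosen values resolved for the current trial. A sketch, assuming context is the TrialContext passed to __init__(); "learning_rate" and "hidden_size" are illustrative hyperparameter names:

    # Fetch a single hyperparameter by name ...
    lr = context.get_hparam("learning_rate")

    # ... or take the whole dictionary at once.
    hparams = context.get_hparams()
    hidden_size = hparams["hidden_size"]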
- get_stop_requested() → bool
Return whether a trial stoppage has been requested.
- set_stop_requested(stop_requested: bool) → None
Set a flag to request a trial stoppage. When this flag is set to True, we finish the step, checkpoint, then exit.
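For example, a trial can ask to be stopped gracefully once some condition is met. A sketch for inside a training loop, where context is the trial's TrialContext and the loss variable and threshold are illustrative:

    if loss < 0.01 and not context.get_stop_requested():
        # Request a graceful stop: the current step is finished and a
        # checkpoint is taken before the trial exits.
        context.set_stop_requested(True)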
determined.TrialContext.distributed
- class determined._train_context.DistributedContext(rank_info: determined._train_context.RankInfo, chief_ip: Optional[str] = None, pub_port: int = 12360, pull_port: int = 12376, port_offset: int = 0, force_tcp: bool = False)
DistributedContext provides useful methods for effective distributed training.
- get_rank() → int
Return the rank of the process in the trial. The rank of a process is a unique ID within the trial; that is, no two processes in the same trial will be assigned the same rank.
- get_local_rank() → int
Return the rank of the process on the agent. The local rank of a process is a unique ID within a given agent and trial; that is, no two processes in the same trial that are executing on the same agent will be assigned the same rank.
- get_size() → int
Return the number of slots this trial is running on.
- get_num_agents() → int
Return the number of agents this trial is running on.
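A common use of these methods is to restrict side effects such as logging or one-off file writes to a single process. A sketch, assuming context is the trial's TrialContext and the file path is illustrative:

    dist = context.distributed

    print(
        f"process {dist.get_rank()} of {dist.get_size()}, "
        f"local rank {dist.get_local_rank()}, "
        f"running across {dist.get_num_agents()} agent(s)"
    )

    # Only the chief process (global rank 0) performs one-off work, so the
    # same file is not written once per slot.
    if dist.get_rank() == 0:
        with open("/tmp/summary.txt", "w") as f:
            f.write("chief-only bookkeeping\n")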