def create_default_env_context(experiment_config: Dict[str, Any]) -> det.EnvContext:
    """Build a ``det.EnvContext`` filled with fixed placeholder values.

    Only ``experiment_config`` is supplied by the caller; every other field is
    hard-coded below (empty master address, no TLS, no GPUs, seed 0, ...).

    Args:
        experiment_config: Raw experiment configuration dictionary. It is
            stored on the context as-is and also wrapped in
            ``det.ExperimentConfig`` to obtain ``scheduling_unit()`` for the
            initial workload.

    Returns:
        The constructed ``det.EnvContext``.
    """
    network_interface = constants.AUTO_DETECT_TRIAL_RUNNER_NETWORK_INTERFACE

    # Initial RUN_STEP workload for experiment 1 / trial 1 / step 1. The fifth
    # positional argument is the config's scheduling unit, the sixth is 0.
    first_workload = workload.Workload(
        workload.Workload.Kind.RUN_STEP,
        ExperimentID(1),
        TrialID(1),
        StepID(1),
        det.ExperimentConfig(experiment_config).scheduling_unit(),
        0,
    )

    return det.EnvContext(
        experiment_config=experiment_config,
        initial_workload=first_workload,
        master_addr="",
        master_port=0,
        use_tls=False,
        master_cert_file=None,
        master_cert_name=None,
        container_id="",
        hparams={"global_batch_size": 32},
        latest_checkpoint=None,
        use_gpu=False,
        container_gpus=[],
        slot_ids=[],
        debug=False,
        workload_manager_type="",
        det_rendezvous_ports="",
        det_trial_unique_port_offset=0,
        det_trial_runner_network_interface=network_interface,
        det_trial_id="1",
        det_experiment_id="1",
        det_cluster_id="uuid-123",
        trial_seed=0,
    )
def make_default_env_context(
    hparams: Dict[str, Any], experiment_config: Optional[Dict] = None, trial_seed: int = 0
) -> det.EnvContext:
    """Build a ``det.EnvContext`` for the given hyperparameters.

    Args:
        hparams: Hyperparameters stored on the context; also passed to
            ``make_default_exp_config`` when no experiment config is given.
        experiment_config: Experiment configuration to use. When ``None``, one
            is created with ``make_default_exp_config(hparams, 1)``.
        trial_seed: Seed stored on the context.

    Returns:
        The constructed ``det.EnvContext``.

    Raises:
        ValueError: If the ``DET_USE_GPU`` environment variable is set to a
            value ``distutils.util.strtobool`` does not recognize.
    """
    if experiment_config is None:
        experiment_config = make_default_exp_config(hparams, 1)

    # TODO(ryan): Fix the parameter passing so that this doesn't read from environment variables,
    # and we can get rid of the @expose_gpus fixture.
    # strtobool() returns the ints 1/0 rather than True/False; normalize to a
    # real bool so the flag handed on below has the type its name implies.
    use_gpu = bool(distutils.util.strtobool(os.environ.get("DET_USE_GPU", "false")))
    gpu_uuids = gpu.get_gpu_uuids_and_validate(use_gpu)

    return det.EnvContext(
        experiment_config=experiment_config,
        initial_workload=workload.Workload(
            workload.Workload.Kind.RUN_STEP, ExperimentID(1), TrialID(1), StepID(1)
        ),
        master_addr="",
        master_port=0,
        container_id="",
        hparams=hparams,
        latest_checkpoint=None,
        use_gpu=use_gpu,
        container_gpus=gpu_uuids,
        slot_ids=[],
        debug=False,
        workload_manager_type="",
        det_rendezvous_ports="",
        det_trial_runner_network_interface=constants.AUTO_DETECT_TRIAL_RUNNER_NETWORK_INTERFACE,
        det_trial_id="1",
        det_experiment_id="1",
        det_cluster_id="uuid-123",
        trial_seed=trial_seed,
    )
def checkpoint_workload(step_id: int = 1, exp_id: int = 1, trial_id: int = 1) -> Workload:
    """Create a CHECKPOINT_MODEL workload for the given experiment, trial and step.

    Args:
        step_id: Wrapped in ``StepID``.
        exp_id: Wrapped in ``ExperimentID``.
        trial_id: Wrapped in ``TrialID``.

    Returns:
        A ``Workload`` built from the kind and the three IDs only.
    """
    # NOTE(review): a later definition with the same name exists in this file
    # and would replace this one at import time -- confirm which is intended.
    kind = Workload.Kind.CHECKPOINT_MODEL
    experiment = ExperimentID(exp_id)
    trial = TrialID(trial_id)
    step = StepID(step_id)
    return Workload(kind, experiment, trial, step)
def validation_workload(step_id: int = 1, exp_id: int = 1, trial_id: int = 1) -> Workload:
    """Create a COMPUTE_VALIDATION_METRICS workload for the given IDs.

    Args:
        step_id: Wrapped in ``StepID``.
        exp_id: Wrapped in ``ExperimentID``.
        trial_id: Wrapped in ``TrialID``.

    Returns:
        A ``Workload`` built from the kind and the three IDs only.
    """
    # NOTE(review): a later definition with the same name exists in this file
    # and would replace this one at import time -- confirm which is intended.
    kind = Workload.Kind.COMPUTE_VALIDATION_METRICS
    experiment = ExperimentID(exp_id)
    trial = TrialID(trial_id)
    step = StepID(step_id)
    return Workload(kind, experiment, trial, step)
def terminate_workload(step_id: int = 1, exp_id: int = 1, trial_id: int = 1) -> Workload:
    """Create a TERMINATE workload for the given experiment, trial and step.

    Args:
        step_id: Wrapped in ``StepID``.
        exp_id: Wrapped in ``ExperimentID``.
        trial_id: Wrapped in ``TrialID``.

    Returns:
        A ``Workload`` whose two trailing positional fields are both 0.
    """
    kind = Workload.Kind.TERMINATE
    ids = (ExperimentID(exp_id), TrialID(trial_id), StepID(step_id))
    return Workload(kind, *ids, 0, 0)
def checkpoint_workload(
    step_id: int = 1, exp_id: int = 1, trial_id: int = 1, total_batches_processed: int = 0
) -> Workload:
    """Create a CHECKPOINT_MODEL workload for the given experiment, trial and step.

    Args:
        step_id: Wrapped in ``StepID``.
        exp_id: Wrapped in ``ExperimentID``.
        trial_id: Wrapped in ``TrialID``.
        total_batches_processed: Passed through as the last positional field.

    Returns:
        A ``Workload`` whose fifth positional field is fixed at 0.
    """
    kind = Workload.Kind.CHECKPOINT_MODEL
    ids = (ExperimentID(exp_id), TrialID(trial_id), StepID(step_id))
    return Workload(kind, *ids, 0, total_batches_processed)
def validation_workload(
    step_id: int = 1,
    exp_id: int = 1,
    trial_id: int = 1,
    total_batches_processed: int = 0,
) -> Workload:
    """Create a COMPUTE_VALIDATION_METRICS workload for the given IDs.

    Args:
        step_id: Wrapped in ``StepID``.
        exp_id: Wrapped in ``ExperimentID``.
        trial_id: Wrapped in ``TrialID``.
        total_batches_processed: Passed through as the last positional field.

    Returns:
        A ``Workload`` whose fifth positional field is fixed at 0.
    """
    kind = Workload.Kind.COMPUTE_VALIDATION_METRICS
    ids = (ExperimentID(exp_id), TrialID(trial_id), StepID(step_id))
    return Workload(kind, *ids, 0, total_batches_processed)
def train_workload(
    step_id: int,
    exp_id: int = 1,
    trial_id: int = 1,
    num_batches: int = 1,
    total_batches_processed: int = 0,
) -> Workload:
    """Create a RUN_STEP workload for the given experiment, trial and step.

    Args:
        step_id: Wrapped in ``StepID``; required.
        exp_id: Wrapped in ``ExperimentID``.
        trial_id: Wrapped in ``TrialID``.
        num_batches: Passed through as the fifth positional field.
        total_batches_processed: Passed through as the sixth positional field.

    Returns:
        The constructed ``Workload``.
    """
    # NOTE(review): a later definition of train_workload without the
    # num_batches / total_batches_processed parameters exists in this file and
    # would replace this one at import time -- confirm which is intended.
    kind = Workload.Kind.RUN_STEP
    ids = (ExperimentID(exp_id), TrialID(trial_id), StepID(step_id))
    return Workload(kind, *ids, num_batches, total_batches_processed)
def train_workload(step_id: int, exp_id: int = 1, trial_id: int = 1) -> Workload:
    """Create a RUN_STEP workload for the given experiment, trial and step.

    Args:
        step_id: Wrapped in ``StepID``; required.
        exp_id: Wrapped in ``ExperimentID``.
        trial_id: Wrapped in ``TrialID``.

    Returns:
        A ``Workload`` built from the kind and the three IDs only.
    """
    kind = Workload.Kind.RUN_STEP
    experiment = ExperimentID(exp_id)
    trial = TrialID(trial_id)
    step = StepID(step_id)
    return Workload(kind, experiment, trial, step)