def make_trial_controller_from_trial_implementation(
    trial_class: Type[det.Trial],
    hparams: Dict,
    workloads: workload.Stream,
    scheduling_unit: int = 1,
    load_path: Optional[pathlib.Path] = None,
    trial_seed: int = 0,
    exp_config: Optional[Dict] = None,
) -> det.TrialController:
    """Build a TrialController for ``trial_class`` from default test fixtures.

    When ``exp_config`` is falsy (``None`` or empty), a default experiment
    config is generated from ``hparams`` and ``scheduling_unit``. Default env
    context, rendezvous info and Horovod config are created, and the controller
    is loaded via ``load.load_controller_from_trial``.

    Side effect: sets the process-wide ``DET_HPARAMS`` environment variable to
    the JSON encoding of ``hparams``.

    Args:
        trial_class: Trial implementation to wrap.
        hparams: Hyperparameters for the trial.
        workloads: Workload stream handed to the controller.
        scheduling_unit: Used only when a default experiment config is built.
        load_path: Optional checkpoint path to load from.
        trial_seed: Seed stored in the env context.
        exp_config: Explicit experiment config; overrides the default when truthy.

    Returns:
        The controller returned by ``load.load_controller_from_trial``.
    """
    experiment_config = exp_config or make_default_exp_config(hparams, scheduling_unit)
    env_context = make_default_env_context(
        hparams=hparams,
        experiment_config=experiment_config,
        trial_seed=trial_seed,
    )
    default_rendezvous = make_default_rendezvous_info()
    default_hvd = make_default_hvd_config()

    # TODO(ryan): remove all global APIs that read from environment variables
    os.environ["DET_HPARAMS"] = json.dumps(hparams)

    return load.load_controller_from_trial(
        trial_class=trial_class,
        env=env_context,
        workloads=workloads,
        load_path=load_path,
        rendezvous_info=default_rendezvous,
        hvd_config=default_hvd,
    )
def make_trial_controller(
    trial_class: Type[det.Trial],
    hparams: Dict[str, Any],
    workloads: workload.Stream,
    env: Optional[det.EnvContext] = None,
    load_path: Optional[pathlib.Path] = None,
) -> det.TrialController:
    """Create a TrialController for ``trial_class`` the way the harness would.

    Delegates to ``load.load_controller_from_trial``. Per the original
    documentation, that uses the ``Trial.get_trial_controller_class()`` static
    method. Default rendezvous info and Horovod config are always used.

    Args:
        trial_class: Trial implementation to wrap.
        hparams: Used only to build a default env context when ``env`` is None.
        workloads: Workload stream handed to the controller.
        env: Explicit env context; a default is built from ``hparams`` if omitted.
        load_path: Optional checkpoint path to load from.

    Returns:
        The controller returned by ``load.load_controller_from_trial``.
    """
    env_context = env if env is not None else make_default_env_context(hparams=hparams)
    rendezvous_info = make_default_rendezvous_info()
    hvd_config = make_default_hvd_config()
    return load.load_controller_from_trial(
        trial_class=trial_class,
        env=env_context,
        workloads=workloads,
        load_path=load_path,
        rendezvous_info=rendezvous_info,
        hvd_config=hvd_config,
    )
def test_one_batch(
    context_path: pathlib.Path,
    trial_class: Optional[Type[det.Trial]] = None,
    config: Optional[Dict[str, Any]] = None,
) -> None:
    """Run a minimal local test experiment with one batch per step.

    ``config`` is copied with ``batches_per_step`` overridden to 1. It is passed,
    together with a temporary checkpoint directory, to
    ``make_test_experiment_env``. If ``trial_class`` is not given, it is loaded
    from ``env.experiment_config["entrypoint"]``. The controller is built and
    run inside ``local_execution_manager(context_path)``.

    The temporary checkpoint directory is removed before this function
    returns, including when the run raises.

    Args:
        context_path: Passed to ``local_execution_manager``.
        trial_class: Trial implementation to test; loaded from the experiment
            config's entrypoint when omitted.
        config: Experiment config overrides; the caller's dict is not mutated.
    """
    # Override the batches_per_step value to 1.
    # TODO(DET-2931): Make the validation step a single batch as well.
    config = {**(config or {}), "batches_per_step": 1}

    print("Running a minimal test experiment locally")
    # The context manager guarantees the checkpoint dir is deleted even if
    # env setup, trial loading, or the run itself raises.
    with tempfile.TemporaryDirectory() as checkpoint_dir:
        env, workloads, rendezvous_info, hvd_config = make_test_experiment_env(
            checkpoint_dir=pathlib.Path(checkpoint_dir), config=config
        )
        print(f"Using hyperparameters: {env.hparams}")
        if util.debug_mode():
            print(f"Using a test experiment config: {env.experiment_config}")

        with local_execution_manager(context_path):
            if trial_class is None:
                if util.debug_mode():
                    print("Loading trial class from experiment configuration")
                trial_class = load.load_trial_implementation(
                    env.experiment_config["entrypoint"]
                )

            controller = load.load_controller_from_trial(
                trial_class=trial_class,
                env=env,
                workloads=workloads,
                load_path=None,
                rendezvous_info=rendezvous_info,
                hvd_config=hvd_config,
            )
            controller.run()

    print(
        "Note: to submit an experiment to the cluster, change mode argument to Mode.CLUSTER"
    )
def test_one_batch(
    controller_cls: Optional[Type[det.TrialController]] = None,
    native_context_cls: Optional[Type[det.NativeContext]] = None,
    trial_class: Optional[Type[det.Trial]] = None,
    config: Optional[Dict[str, Any]] = None,
) -> Any:
    """Run a minimal local test experiment with ``scheduling_unit`` forced to 1.

    Native API (both ``controller_cls`` and ``native_context_cls`` given):
    calls ``controller_cls.pre_execute_hook`` and builds a
    ``native_context_cls`` instance. It then installs a training function on
    that instance via ``context._set_train_fn`` and returns the context. The
    controller only runs when that training function is invoked.

    Trial API (otherwise, with ``trial_class`` given): loads a controller for
    ``trial_class``, runs it immediately, and returns None.

    In both cases checkpoints are written under a temporary directory. That
    directory is removed once the controller run finishes, whether it
    succeeds or raises.

    Args:
        controller_cls: Controller class for the Native API.
        native_context_cls: Native context class for the Native API.
        trial_class: Trial implementation for the Trial API.
        config: Experiment config overrides; the caller's dict is not mutated.

    Returns:
        The native context in the Native API case, otherwise None.

    Raises:
        errors.InternalException: if neither the Native API arguments nor
            ``trial_class`` were provided.
    """
    # Override the scheduling_unit value to 1.
    config = {**(config or {}), "scheduling_unit": 1}

    logging.info("Running a minimal test experiment locally")
    checkpoint_dir = tempfile.TemporaryDirectory()
    env, rendezvous_info, hvd_config = det._make_local_execution_env(
        managed_training=True, test_mode=True, config=config, limit_gpus=1
    )
    workloads = _make_test_workloads(
        pathlib.Path(checkpoint_dir.name).joinpath("checkpoint"), env.experiment_config
    )
    logging.info("Using hyperparameters: %s.", env.hparams)
    logging.debug("Using a test experiment config: %s.", env.experiment_config)

    if native_context_cls is not None and controller_cls is not None:
        # Case 1: test one batch for Native implementation.
        controller_cls.pre_execute_hook(env=env, hvd_config=hvd_config)
        context = native_context_cls(
            env=env,
            hvd_config=hvd_config,
            rendezvous_info=rendezvous_info,
        )

        def train_fn() -> None:
            # The closure does not inherit the narrowing above, hence the cast.
            try:
                controller = cast(Type[det.TrialController], controller_cls).from_native(
                    context=context,
                    env=env,
                    workloads=workloads,
                    load_path=None,
                    rendezvous_info=rendezvous_info,
                    hvd_config=hvd_config,
                )
                controller.run()
            finally:
                # Remove the checkpoint dir even if the run fails.
                checkpoint_dir.cleanup()

        context._set_train_fn(train_fn)
        logging.info(
            "Note: to submit an experiment to the cluster, change local parameter to False"
        )
        return context

    elif trial_class is not None:
        # Case 2: test one batch for Trial implementation.
        try:
            controller = load.load_controller_from_trial(
                trial_class=trial_class,
                env=env,
                workloads=workloads,
                load_path=None,
                rendezvous_info=rendezvous_info,
                hvd_config=hvd_config,
            )
            controller.run()
        finally:
            # Remove the checkpoint dir even if the run fails.
            checkpoint_dir.cleanup()
        logging.info(
            "Note: to submit an experiment to the cluster, change local parameter to False"
        )

    else:
        # Nothing will ever use the checkpoint dir; don't leave it behind.
        checkpoint_dir.cleanup()
        raise errors.InternalException(
            "Must provide a trial_class if using Trial API or "
            "a controller_cls and a native_context_cls if using Native API."
        )