def load_native_implementation_controller(
    env: det.EnvContext,
    workloads: workload.Stream,
    load_path: Optional[pathlib.Path],
    rendezvous_info: det.RendezvousInfo,
    hvd_config: horovod.HorovodContext,
) -> det.TrialController:
    """Build the trial controller for an experiment using a native implementation.

    The native implementation is loaded via ``load.load_native_implementation``,
    which yields a ``(context, trial_class, controller_class)`` triple.  If a
    trial class is returned, construction is delegated to
    ``load_controller_from_trial``; otherwise the returned controller class is
    validated and instantiated through its ``from_native`` constructor.

    Args:
        env: Environment context; its ``experiment_config`` must report
            ``native_enabled()``.
        workloads: Workload stream forwarded to the controller.
        load_path: Optional path forwarded to the controller.
        rendezvous_info: Rendezvous information forwarded to the controller.
        hvd_config: Horovod configuration forwarded to the loader and the
            controller.

    Returns:
        The value returned by ``load_controller_from_trial`` or by
        ``controller_class.from_native``.

    Note:
        Validation is delegated to the ``check`` helpers; presumably they raise
        when a condition fails -- confirm against the ``check`` module.
    """
    check.true(
        env.experiment_config.native_enabled(),
        "Experiment configuration does not have an internal.native "
        f"configuration: {env.experiment_config}",
    )

    context, trial_class, controller_class = load.load_native_implementation(env, hvd_config)

    if trial_class is not None:
        # Trial-based native implementation: reuse the regular trial path.
        return load_controller_from_trial(
            trial_class=trial_class,
            env=env,
            workloads=workloads,
            load_path=load_path,
            rendezvous_info=rendezvous_info,
            hvd_config=hvd_config,
        )

    # Framework-specific native implementation.
    check.is_not_none(
        controller_class,
        "The class attribute `trial_controller_class` is "
        "None; please set it the correct subclass of `det.TrialController`",
    )
    check.is_subclass(
        controller_class,
        det.TrialController,
        "The class attribute `trial_controller_class` is "
        "not a valid subclass of `det.TrialController`",
    )
    # `controller_class` is a class object, not an instance, so narrow it to
    # `Type[det.TrialController]`.  `cast` has no runtime effect; this only
    # informs the type checker for the attribute accesses below.
    controller_class = cast(Type[det.TrialController], controller_class)

    logging.info(f"Creating {controller_class.__name__} with {type(context).__name__}.")

    return controller_class.from_native(
        context=cast(det.NativeContext, context),
        env=env,
        workloads=workloads,
        load_path=load_path,
        rendezvous_info=rendezvous_info,
        hvd_config=hvd_config,
    )
def load_controller_from_trial(
    trial_class: Type[det.Trial],
    env: det.EnvContext,
    workloads: workload.Stream,
    load_path: Optional[pathlib.Path],
    rendezvous_info: det.RendezvousInfo,
    hvd_config: horovod.HorovodContext,
) -> det.TrialController:
    """Instantiate a trial and wrap it in the controller its class declares.

    The controller class is read from ``trial_class.trial_controller_class``
    and validated.  Its ``pre_execute_hook`` is then run, the trial context
    and the trial itself are constructed, and the controller is finally
    produced through ``from_trial``.

    Args:
        trial_class: The trial class to instantiate; it supplies both
            ``trial_controller_class`` and ``trial_context_class``.
        env: Environment context passed to the hook, the trial context and
            the controller.
        workloads: Workload stream forwarded to the controller.
        load_path: Optional path forwarded to the controller.
        rendezvous_info: Rendezvous information passed to the trial context
            and the controller.
        hvd_config: Horovod configuration passed to the hook, the trial
            context and the controller.

    Returns:
        The value returned by the controller class's ``from_trial``.
    """
    # Step 1: Validate model definition.
    declared_controller = trial_class.trial_controller_class
    message_prefix = (
        f"The class attribute `trial_controller_class` of {trial_class.__name__} is "
    )
    check.is_not_none(
        declared_controller,
        message_prefix + "None; please set it the correct subclass of `det.TrialController`",
    )
    check.is_subclass(
        declared_controller,
        det.TrialController,
        message_prefix + "not a valid subclass of `det.TrialController`",
    )
    # Narrowing for the type checker only; `cast` does nothing at runtime.
    controller_cls = cast(Type[det.TrialController], declared_controller)

    # Step 2: Initialize framework-specific details (horovod, random seeds, etc).
    controller_cls.pre_execute_hook(env, hvd_config)
    context = trial_class.trial_context_class(env, hvd_config, rendezvous_info)

    # Step 3: Instantiate the user's Trial.
    trial = trial_class(context)

    # Step 4: Return the TrialController.
    logging.info(f"Creating {controller_cls.__name__} with {trial_class.__name__}.")
    controller = controller_cls.from_trial(
        trial_inst=trial,
        context=context,
        env=env,
        workloads=workloads,
        load_path=load_path,
        rendezvous_info=rendezvous_info,
        hvd_config=hvd_config,
    )
    return controller