Ejemplo n.º 1
0
def load_native_implementation_controller(
    env: det.EnvContext,
    workloads: workload.Stream,
    load_path: Optional[pathlib.Path],
    rendezvous_info: det.RendezvousInfo,
    hvd_config: horovod.HorovodContext,
) -> det.TrialController:
    """Build the trial controller for an experiment using a native implementation.

    The experiment configuration must report ``native_enabled()``. The native
    implementation is loaded through ``load.load_native_implementation``, which
    yields a ``(context, trial_class, controller_class)`` triple:

    * if a trial class is returned, construction is delegated to
      ``load_controller_from_trial``;
    * otherwise the returned controller class is validated and its
      ``from_native`` constructor is invoked with the loaded context.

    Args:
        env: Environment context; also supplies the experiment configuration.
        workloads: Workload stream forwarded to the controller.
        load_path: Optional path forwarded to the controller.
        rendezvous_info: Rendezvous information forwarded to the controller.
        hvd_config: Horovod configuration used for loading and forwarded to
            the controller.

    Returns:
        The trial controller produced by whichever path was taken.

    Note:
        Validation relies on the ``check`` helpers, which are presumably
        raising on failure -- confirm against the ``check`` module.
    """
    native_enabled = env.experiment_config.native_enabled()
    check.true(
        native_enabled,
        "Experiment configuration does not have an internal.native "
        f"configuration: {env.experiment_config}",
    )

    implementation = load.load_native_implementation(env, hvd_config)
    context, trial_class, controller_class = implementation

    # A trial class takes precedence: reuse the regular trial-based loader.
    if trial_class is not None:
        return load_controller_from_trial(
            trial_class=trial_class,
            env=env,
            workloads=workloads,
            load_path=load_path,
            rendezvous_info=rendezvous_info,
            hvd_config=hvd_config,
        )

    # Framework-specific native implementation: no trial class was provided,
    # so the controller class itself must be present and valid.
    check.is_not_none(
        controller_class,
        "The class attribute `trial_controller_class` is "
        "None; please set it the correct subclass of `det.TrialController`",
    )
    check.is_subclass(
        controller_class,
        det.TrialController,
        "The class attribute `trial_controller_class` is "
        "not a valid subclass of `det.TrialController`",
    )

    logging.info(
        f"Creating {controller_class.__name__} with {type(context).__name__}."
    )

    # The casts only narrow the static types; they do nothing at runtime.
    build_from_native = cast(det.TrialController, controller_class).from_native
    return build_from_native(
        context=cast(det.NativeContext, context),
        env=env,
        workloads=workloads,
        load_path=load_path,
        rendezvous_info=rendezvous_info,
        hvd_config=hvd_config,
    )
Ejemplo n.º 2
0
def load_controller_from_trial(
    trial_class: Type[det.Trial],
    env: det.EnvContext,
    workloads: workload.Stream,
    load_path: Optional[pathlib.Path],
    rendezvous_info: det.RendezvousInfo,
    hvd_config: horovod.HorovodContext,
) -> det.TrialController:
    """Instantiate a trial and wrap it in its declared trial controller.

    The controller class is read from ``trial_class.trial_controller_class``
    and validated. The controller's ``pre_execute_hook`` then runs, the trial
    context and the trial instance are created, and the controller is built
    with ``from_trial``.

    Args:
        trial_class: Trial class to instantiate; must declare
            ``trial_controller_class`` and ``trial_context_class``.
        env: Environment context passed to the hook, context and controller.
        workloads: Workload stream forwarded to the controller.
        load_path: Optional path forwarded to the controller.
        rendezvous_info: Rendezvous information forwarded to the controller.
        hvd_config: Horovod configuration passed to the hook, context and
            controller.

    Returns:
        The controller returned by ``from_trial`` on the validated class.

    Note:
        Validation relies on the ``check`` helpers, which are presumably
        raising on failure -- confirm against the ``check`` module.
    """
    # Validate the model definition before running any framework code.
    declared_controller = trial_class.trial_controller_class
    check.is_not_none(
        declared_controller,
        f"The class attribute `trial_controller_class` of {trial_class.__name__} is "
        "None; please set it the correct subclass of `det.TrialController`",
    )
    check.is_subclass(
        declared_controller,
        det.TrialController,
        f"The class attribute `trial_controller_class` of {trial_class.__name__} is "
        "not a valid subclass of `det.TrialController`",
    )
    # Static narrowing only; the checks above did the runtime validation.
    controller_class = cast(Type[det.TrialController], declared_controller)

    # Framework-specific setup (horovod, random seeds, etc) must run before
    # the trial context is built.
    controller_class.pre_execute_hook(env, hvd_config)
    context = trial_class.trial_context_class(env, hvd_config)

    # Build the user's Trial on top of the freshly created context.
    trial = trial_class(context)

    # Hand everything over to the controller.
    logging.info(
        f"Creating {controller_class.__name__} with {trial_class.__name__}.")
    build_from_trial = controller_class.from_trial
    return build_from_trial(
        trial_inst=trial,
        context=context,
        env=env,
        workloads=workloads,
        load_path=load_path,
        rendezvous_info=rendezvous_info,
        hvd_config=hvd_config,
    )