Exemple #1
0
def _init_cluster_mode(
    trial_def: Optional[Type[det.Trial]] = None,
    controller_cls: Optional[Type[det.TrialController]] = None,
    native_context_cls: Optional[Type[det.NativeContext]] = None,
    config: Optional[Dict[str, Any]] = None,
    test: bool = False,
    context_dir: str = "",
    command: Optional[List[str]] = None,
    master_url: Optional[str] = None,
) -> Any:
    """Initialize cluster mode for either the Native or the Trial API.

    The behavior depends on which arguments are supplied and on whether
    ``load.RunpyGlobals`` has already been initialized:

    * Native API (``controller_cls`` and ``native_context_cls`` both given):
      when ``RunpyGlobals`` is initialized, run the controller's
      ``pre_execute_hook``, build a native context from the stored
      environment, register it with ``RunpyGlobals`` and return it.
      Otherwise submit the experiment and terminate the process with
      ``sys.exit(0)``.
    * Trial API (``trial_def`` given): when ``RunpyGlobals`` is initialized,
      register the trial class together with its controller class and call
      ``_stop_loading_implementation``. Otherwise submit the experiment and
      return ``None``.

    Args:
        trial_def: Trial class to register or submit (Trial API).
        controller_cls: Trial controller class (Native API).
        native_context_cls: Native context class to instantiate (Native API).
        config: Experiment configuration forwarded to ``_submit_experiment``.
        test: Forwarded to ``_submit_experiment`` for both APIs.
        context_dir: Forwarded to ``_submit_experiment``.
        command: Forwarded to ``_submit_experiment``.
        master_url: Forwarded to ``_submit_experiment``.

    Returns:
        The native context in the initialized Native case; ``None`` in the
        Trial case when ``_stop_loading_implementation`` /
        ``_submit_experiment`` return normally.

    Raises:
        errors.InternalException: If neither a ``trial_def`` nor a
            ``controller_cls``/``native_context_cls`` pair was provided.
        SystemExit: After submitting a Native API experiment.
    """
    if controller_cls is not None and native_context_cls is not None:
        # Case 1: initialize Native implementation.
        if load.RunpyGlobals.is_initialized():
            runpy_globals = load.RunpyGlobals.get_instance()
            controller_cls.pre_execute_hook(
                env=runpy_globals.env,
                hvd_config=runpy_globals.hvd_config,
            )
            context = native_context_cls(
                env=runpy_globals.env,
                hvd_config=runpy_globals.hvd_config,
                rendezvous_info=runpy_globals.rendezvous_info,
            )
            load.RunpyGlobals.set_runpy_native_result(context, controller_cls)
            context._set_train_fn(_stop_loading_implementation)
            return context

        else:
            # Forward `test` here as well so that the flag is honored the
            # same way as in the Trial API branch below.
            _submit_experiment(
                config=config,
                test=test,
                context_dir=context_dir,
                command=command,
                master_url=master_url,
            )
            logging.info(
                "Exiting the program after submitting the experiment.")
            sys.exit(0)

    elif trial_def is not None:
        # Case 2: initialize Trial implementation.
        if load.RunpyGlobals.is_initialized():
            load.RunpyGlobals.set_runpy_trial_result(
                trial_def,
                cast(Type[det.TrialController],
                     trial_def.trial_controller_class))
            _stop_loading_implementation()

        else:
            _submit_experiment(
                config=config,
                test=test,
                context_dir=context_dir,
                command=command,
                master_url=master_url,
            )

    else:
        raise errors.InternalException(
            "Must provide a trial_def if using Trial API or "
            "a controller_cls and a native_context_cls if using Native API.")
Exemple #2
0
def test_one_batch(
    controller_cls: Optional[Type[det.TrialController]] = None,
    native_context_cls: Optional[Type[det.NativeContext]] = None,
    trial_class: Optional[Type[det.Trial]] = None,
    config: Optional[Dict[str, Any]] = None,
) -> Any:
    """Run a minimal local test experiment for the Native or Trial API.

    The given ``config`` is copied with ``scheduling_unit`` forced to 1, a
    local execution environment is created via
    ``det._make_local_execution_env`` and test workloads are built that
    checkpoint into a temporary directory.

    * Native API (``native_context_cls`` and ``controller_cls`` both given):
      the controller's ``pre_execute_hook`` is run, a native context is
      created, and a training function is attached to it that builds the
      controller with ``from_native`` and runs it. The context is returned;
      the temporary directory is removed when the training function
      finishes, whether it succeeds or raises.
    * Trial API (``trial_class`` given): the controller is loaded with
      ``load.load_controller_from_trial`` and run immediately.

    Args:
        controller_cls: Trial controller class (Native API).
        native_context_cls: Native context class to instantiate (Native API).
        trial_class: Trial class to test (Trial API).
        config: Experiment configuration; not mutated by this function.

    Returns:
        The native context in the Native case, otherwise ``None``.

    Raises:
        errors.InternalException: If neither a ``trial_class`` nor a
            ``controller_cls``/``native_context_cls`` pair was provided.
    """
    # Override the scheduling_unit value to 1.
    config = {**(config or {}), "scheduling_unit": 1}

    logging.info("Running a minimal test experiment locally")
    checkpoint_dir = tempfile.TemporaryDirectory()
    # Set once the Native training function has taken over responsibility
    # for removing the temporary directory.
    cleanup_delegated = False
    try:
        env, rendezvous_info, hvd_config = det._make_local_execution_env(
            managed_training=True, test_mode=True, config=config, limit_gpus=1)
        workloads = _make_test_workloads(
            pathlib.Path(checkpoint_dir.name).joinpath("checkpoint"),
            env.experiment_config)
        logging.info(f"Using hyperparameters: {env.hparams}.")
        logging.debug(
            f"Using a test experiment config: {env.experiment_config}.")

        if native_context_cls is not None and controller_cls is not None:
            # Case 1: test one batch for Native implementation.
            controller_cls.pre_execute_hook(env=env, hvd_config=hvd_config)
            context = native_context_cls(
                env=env,
                hvd_config=hvd_config,
                rendezvous_info=rendezvous_info,
            )

            def train_fn() -> None:
                try:
                    controller = cast(Type[det.TrialController],
                                      controller_cls).from_native(
                                          context=context,
                                          env=env,
                                          workloads=workloads,
                                          load_path=None,
                                          rendezvous_info=rendezvous_info,
                                          hvd_config=hvd_config,
                                      )
                    controller.run()
                finally:
                    # Remove the checkpoint directory even if the run fails.
                    checkpoint_dir.cleanup()

            context._set_train_fn(train_fn)
            logging.info(
                "Note: to submit an experiment to the cluster, change local parameter to False"
            )
            cleanup_delegated = True
            return context

        elif trial_class is not None:
            # Case 2: test one batch for Trial implementation.
            controller = load.load_controller_from_trial(
                trial_class=trial_class,
                env=env,
                workloads=workloads,
                load_path=None,
                rendezvous_info=rendezvous_info,
                hvd_config=hvd_config,
            )
            controller.run()
            logging.info(
                "Note: to submit an experiment to the cluster, change local parameter to False"
            )

        else:
            raise errors.InternalException(
                "Must provide a trial_class if using Trial API or "
                "a controller_cls and a native_context_cls if using Native API.")
    finally:
        # Covers the Trial path, the invalid-argument path, and any failure
        # during setup; the Native path cleans up inside train_fn instead.
        if not cleanup_delegated:
            checkpoint_dir.cleanup()