Пример #1
0
 def metrics_result(self) -> Metrics:
     """Return the stored response, requiring it to be a metrics dict.

     Asserts that a response has been recorded and that it is a ``dict``;
     any other kind of response (per the assertion message, a skipped
     workload) is rejected. The stored response is left in place.
     """
     response = self._response
     check.is_not_none(
         response, "_respond() was not called by the TrialController.")
     check.is_instance(response, dict, "unexpected SkippedWorkload response.")
     return cast(Metrics, response)
Пример #2
0
    def _compute_validation_metrics(self) -> Any:
        """
        Computes validation metrics using either Evaluator() or CustomInferenceRunner().

        Exactly one of `self.evaluator` and `self.validation_metrics_names` must be set.
        Metrics are computed on every worker, but only the chief returns them (wrapped as
        `{"validation_metrics": ...}`); all other workers return `workload.Skipped()`.
        """
        if self.evaluator:
            check.is_none(self.validation_metrics_names)
            metrics = self.evaluator.compute_validation_metrics()
        else:
            check.is_not_none(self.validation_metrics_names)
            # Find our custom Inference callback.
            custom_inference_callback = None  # type: Optional[CustomInferenceRunner]
            for callback in self.trainer._callbacks.cbs:
                if isinstance(callback, CustomInferenceRunner):
                    custom_inference_callback = callback
                    break
            # cast() performs no runtime check, so verify the callback was actually found
            # instead of failing below with an opaque AttributeError on None.
            check.is_not_none(
                custom_inference_callback,
                "No CustomInferenceRunner callback is registered with the trainer.",
            )
            custom_inference_callback = cast(CustomInferenceRunner,
                                             custom_inference_callback)
            self.validation_metrics_names = cast(List[str],
                                                 self.validation_metrics_names)
            metrics = custom_inference_callback.trigger_on_validation_step(
                self.validation_metrics_names)

        # Non-chief workers still ran the computation above; they just don't report it.
        if not self.is_chief:
            return workload.Skipped()

        return {"validation_metrics": metrics}
Пример #3
0
 def __getattr__(self, attr: str) -> Any:
     """Resolve attribute lookups on the loaded horovod module.

     Asserts first that a horovod type has been required and then that
     the corresponding module was successfully stored on this object.
     """
     check.is_not_none(
         self._poly_hvd_type,
         "You must call det.horovod.hvd.require_horovod_type() before any other calls.",
     )
     hvd_module = self._poly_hvd_module
     check.is_not_none(hvd_module, "Horovod could not be imported in this process.")
     return getattr(hvd_module, attr)
Пример #4
0
    def on_train_batch_end(self, _: int, logs: Any = None) -> None:
        """Record this batch's metrics and report them once a full step has been processed.

        After `batches_per_step` batches have been seen, the accumulated metrics are sent
        through the controller's pending response function (only the chief sends real
        metrics; other workers send `workload.Skipped()`), the per-step state is reset,
        and the controller's `run()` is re-entered.
        """
        check.is_in("loss", logs)

        # Keras adds entries such as "batch" and "size" that are not metrics we
        # care about; leave them out of the recorded values.
        ignored_keys = {"batch", "size"}
        batch_metrics = {}
        for name, value in logs.items():
            if name not in ignored_keys:
                batch_metrics[name] = value
        self.metrics.append(batch_metrics)
        self.batches_processed += 1

        controller = self.tf_keras_trial_controller
        if self.batches_processed != controller.batches_per_step:
            return

        check.is_not_none(
            controller.train_response_func,
            "Callback should avoid calling model.predict() or change model.stop_training "
            "as this will affect Determined training behavior",
        )
        respond = cast(workload.ResponseFunc, controller.train_response_func)

        # TODO(DET-1278): Average training metrics across GPUs when using Horovod.
        num_inputs = controller.batches_per_step * controller.batch_size

        if not controller.is_chief:
            respond(workload.Skipped())
        else:
            respond(det.util.make_metrics(num_inputs, self.metrics))

        # Clear the per-step state before handing control back to the controller.
        controller.train_response_func = None
        self.metrics = []
        self.batches_processed = 0

        controller.run()
Пример #5
0
 def result(self) -> Response:
     """Hand back the response recorded by the TrialController and clear it.

     Because the stored response is reset to None, this may only be called
     once per send.
     """
     check.is_not_none(
         self._response, "_respond() was not called by the TrialController.")
     response, self._response = self._response, None
     return cast(Response, response)
Пример #6
0
    def _save_model(self) -> None:
        """Write a checkpoint for the current global step if training has advanced.

        Does nothing when no global step has been recorded yet, or when the last
        checkpoint was already written at the current global step.
        """
        current_step = self._current_global_step
        last_saved_step = self._global_step_of_last_checkpoint

        # No training has happened yet, so there is nothing to save.
        if current_step is None:
            return
        # The most recent checkpoint already covers this step.
        if last_saved_step is not None and last_saved_step == current_step:
            return

        estimator_dir = self.estimator_trial_controller.estimator_dir
        logging.info(
            f"Saving checkpoints for step: {current_step} into {estimator_dir}.")

        check.is_not_none(self._session)
        check.is_not_none(current_step)
        self._current_global_step = cast(int, current_step)

        saver = self._get_saver()
        checkpoint_prefix = str(estimator_dir.joinpath("model.ckpt"))
        saver.save(
            self._session,
            checkpoint_prefix,
            global_step=self._current_global_step,
        )
        self._global_step_of_last_checkpoint = self._current_global_step
Пример #7
0
    def _select_optimizers(self) -> None:
        """
        Decide which optimizers will be used and store them in `self._optimizers`.

        Optimizers wrapped through the context take precedence. Otherwise the optimizer
        captured from the compile() call is used, which keeps older code working that
        passed its optimizer only to compile().
        """
        check.check_len(
            self._optimizers,
            0,
            "context._select_optimizers() called multiple times. Should be only called "
            "once by TFKerasTrialController.",
        )

        wrapped = self._wrapped_optimizers
        if len(wrapped) == 0:
            # Fall back to the optimizer given to compile().
            check.is_not_none(
                self._compiled_optimizer,
                "Please use `optimizer = self.context.wrap_optimizer(optimizer)` to wrap your "
                "optimizer. If using multiple optimizer, you should wrap your optimizer "
                "separately (calling wrap_optimizer() once for each optimizer).",
            )
            compiled = self._compiled_optimizer
            if not compiled:
                return
            logging.info(
                "Please switch over to using `optimizer = self.context.wrap_optimizer()`."
            )
            logging.debug(f"Using compiled optimizer: {compiled}.")
            self._optimizers = [compiled]
            return

        logging.debug(f"Using wrapped optimizers: {wrapped}.")
        self._optimizers = wrapped
Пример #8
0
    def from_trial(
        trial_inst: det.Trial,
        context: det.TrialContext,
        env: det.EnvContext,
        workloads: workload.Stream,
        load_path: Optional[pathlib.Path],
        rendezvous_info: det.RendezvousInfo,
        hvd_config: horovod.HorovodContext,
    ) -> det.TrialController:
        """Construct a TFKerasTrialController from a user-defined TFKerasTrial.

        Validates the context and trial types, configures the session, adapts the
        trial's training and validation data loaders, builds and compiles the model,
        and finally instantiates the controller.
        """
        check.is_instance(
            context, keras.TFKerasTrialContext,
            "TFKerasTrialController needs a TFKerasTrialContext")
        context = cast(keras.TFKerasTrialContext, context)

        check.is_instance(
            trial_inst, TFKerasTrial, "TFKerasTrialController needs a TFKerasTrial")
        trial = cast(TFKerasTrial, trial_inst)

        session_config = trial.session_config()
        session = TFKerasTrialController._configure_session(
            env, hvd_config, session_config)

        # Data loaders are built and adapted before the model is built.
        train_loader = trial.build_training_data_loader()
        training_data = keras._adapt_data_from_data_loader(
            input_data=train_loader,
            batch_size=context.get_per_slot_batch_size(),
        )
        val_loader = trial.build_validation_data_loader()
        validation_data = keras._adapt_data_from_data_loader(
            input_data=val_loader,
            batch_size=context.get_per_slot_batch_size(),
        )

        # Building the model is expected to populate context.model and
        # context.compile_args; both are asserted below.
        trial.build_model()
        check.is_not_none(context.model, "Please call wrap_model(...).")
        check.is_not_none(
            context.compile_args, "Please call model.compile(...).")
        bound_compile_args = cast(inspect.BoundArguments, context.compile_args)

        TFKerasTrialController.compile_model(
            context=context,
            compile_args=bound_compile_args,
            env=env,
            hvd_config=hvd_config,
        )

        callbacks = trial.keras_callbacks()
        train_config = keras.TFKerasTrainConfig(
            training_data, validation_data, callbacks)

        return TFKerasTrialController(
            context.model,
            session,
            train_config,
            context,
            env,
            workloads,
            load_path,
            rendezvous_info,
            hvd_config,
        )
Пример #9
0
    def after_run(self, run_context: tf.estimator.SessionRunContext,
                  run_values: tf.estimator.SessionRunValues) -> None:
        """Collect metrics for the batch just run and, at the end of a step, report them.

        Once `num_batches` batches have been processed, the step's metrics are sent
        through the pending response function (the chief sends metrics, other workers
        send `workload.Skipped()`), per-step state is cleared, and the control loop is
        re-entered.
        """
        controller = self.estimator_trial_controller

        # The optimizer check lives here because, when model_fn is passed in as a
        # closure, the optimizer is not initialized until the first training step.
        check.true(
            controller.context.optimizer_initialized,
            "Please pass your optimizer into `det.estimator.wrap_optimizer(optimizer)` "
            "right after creating it.",
        )
        self._session = run_context.session
        self._current_global_step = run_values.results["global_step"]

        self.num_batches = cast(int, self.num_batches)
        self._collect_batch_metrics(run_values)
        self.batches_processed_in_step += 1
        if self.batches_processed_in_step < self.num_batches:
            return

        # TODO: Average training results across GPUs. This might
        # degrade performance due to an increase in communication.

        # The loss metric is sometimes reported as `loss_1`; mirror it to `loss`.
        for batch_metrics in self.step_metrics:
            if "loss_1" in batch_metrics and "loss" not in batch_metrics:
                batch_metrics["loss"] = batch_metrics["loss_1"]

        # Send the result of the training step back to the main process.
        check.is_not_none(
            self.train_response_func, "no response_func at end of train_for_step")
        self.train_response_func = cast(
            workload.ResponseFunc, self.train_response_func)
        if controller.is_chief:
            response = {
                "metrics": det.util.make_metrics(
                    self.batches_processed_in_step, self.step_metrics),
                "stop_requested": controller.context.get_stop_requested(),
                "invalid_hp": False,
            }
        else:
            response = workload.Skipped()
        self.train_response_func(response)

        # Reset step counter and clear the step metrics from memory.
        self.train_response_func = None
        self.batches_processed_in_step = 0
        self.step_metrics = []

        estimator._cleanup_after_train_step(controller.estimator_dir)

        # Re-enter the control loop (block on receiving the next instruction)
        self.control_loop()
Пример #10
0
    def _get_amp_setting(self) -> str:
        """Return the `mixed_precision` value from the experiment's `optimizations` config.

        Asserts that the value is set and that `amp` is not present in the
        hyperparameters.
        """
        optimizations = self.env.experiment_config.get("optimizations", {})
        amp_setting = optimizations.get("mixed_precision", None)
        check.is_not_none(amp_setting)
        check.not_in(
            "amp",
            self.env.hparams,
            "Please move `amp` setting from `hyperparameters` to "
            "`optimizations[`mixed_precision`]`.",
        )
        return cast(str, amp_setting)
Пример #11
0
 def _init_device(self) -> None:
     """Pick the torch device for this process and store it on `self.device`.

     With horovod, the device is selected by the local rank and made the
     current CUDA device. Without it, the first GPU is used when any is
     available and the CPU otherwise.
     """
     gpu_count = len(self.env.container_gpus)
     self.n_gpus = gpu_count

     if not self.hvd_config.use:
         if gpu_count > 0:
             self.device = torch.device("cuda", 0)
         else:
             self.device = torch.device("cpu")
     else:
         check.gt(gpu_count, 0)
         # One horovod process is launched per GPU, and each process
         # has to bind to its own GPU.
         self.device = torch.device(hvd.local_rank())
         torch.cuda.set_device(self.device)

     check.is_not_none(self.device)
Пример #12
0
def load_native_implementation_controller(
    env: det.EnvContext,
    workloads: workload.Stream,
    load_path: Optional[pathlib.Path],
    rendezvous_info: det.RendezvousInfo,
    hvd_config: horovod.HorovodContext,
) -> det.TrialController:
    """Load a native implementation and return the TrialController for it.

    If loading yields a trial class, construction is delegated to
    `load_controller_from_trial`. Otherwise the loaded controller class is
    validated and its `from_native` constructor is called with the loaded context.
    """
    check.true(
        env.experiment_config.native_enabled(),
        "Experiment configuration does not have an internal.native "
        f"configuration: {env.experiment_config}",
    )

    native_context, trial_cls, controller_cls = load.load_native_implementation(
        env, hvd_config)

    # Arguments shared by both construction paths.
    common_kwargs = dict(
        env=env,
        workloads=workloads,
        load_path=load_path,
        rendezvous_info=rendezvous_info,
        hvd_config=hvd_config,
    )

    if trial_cls is not None:
        return load_controller_from_trial(trial_class=trial_cls, **common_kwargs)

    # Framework-specific native implementation.
    check.is_not_none(
        controller_cls,
        "The class attribute `trial_controller_class` is "
        "None; please set it the correct subclass of `det.TrialController`",
    )
    check.is_subclass(
        controller_cls,
        det.TrialController,
        "The class attribute `trial_controller_class` is "
        "not a valid subclass of `det.TrialController`",
    )
    logging.info(
        f"Creating {controller_cls.__name__} with {type(native_context).__name__}."
    )
    typed_controller = cast(det.TrialController, controller_cls)
    return typed_controller.from_native(
        context=cast(det.NativeContext, native_context), **common_kwargs)
Пример #13
0
    def _launch_python_subprocess(self) -> subprocess.Popen:
        """
        Start the training process directly with `python -m`, bypassing horovodrun.
        Only used internally when testing.
        """
        check.is_not_none(self._python_subprocess_entrypoint)
        entrypoint = cast(str, self._python_subprocess_entrypoint)
        self._python_subprocess_entrypoint = entrypoint

        # Command line for the non-horovod training subprocess: the entrypoint
        # module followed by the path to the worker process environment.
        env_path = str(self._worker_process_env_path)
        return subprocess.Popen(["python", "-m", entrypoint, env_path])
Пример #14
0
    def on_train_batch_end(self, _: int, logs: Any = None) -> None:
        """Record this batch's metrics and report them once a full step has been processed.

        After `num_batches` batches have been seen, the accumulated metrics and the
        stop-requested flag are sent through the controller's pending response function
        (other workers send `workload.Skipped()`), per-step state is reset, and the
        controller's `run()` is re-entered. Afterwards, on TF >= 2.2.0, a set
        `model.stop_training` ends the worker via `WorkerFinishedGracefully`.
        """
        check.is_in("loss", logs)

        # Keras adds entries such as "batch" and "size" that are not metrics we
        # care about; leave them out of the recorded values.
        ignored_keys = {"batch", "size"}
        batch_metrics = {}
        for name, value in logs.items():
            if name not in ignored_keys:
                batch_metrics[name] = value
        self.metrics.append(batch_metrics)
        self.batches_processed += 1

        controller = self.tf_keras_trial_controller
        if self.batches_processed != controller.num_batches:
            return

        check.is_not_none(
            controller.train_response_func,
            "Callback should avoid calling model.predict(), "
            "as this will affect Determined training behavior",
        )
        respond = cast(workload.ResponseFunc, controller.train_response_func)

        # TODO(DET-1278): Average training metrics across GPUs when using Horovod.
        num_inputs = controller.num_batches * controller.batch_size

        if not controller.is_chief:
            respond(workload.Skipped())
        else:
            respond({
                "metrics": det.util.make_metrics(num_inputs, self.metrics),
                "stop_requested": controller.context.get_stop_requested(),
            })

        # Clear the per-step state before handing control back to the controller.
        controller.train_response_func = None
        self.metrics = []
        self.batches_processed = 0

        controller.run()

        if not self.model.stop_training:
            return
        if version.parse(tf.__version__) >= version.parse("2.2.0"):
            # Starting with TF 2.2, `model.stop_training` is only checked at the end of epochs.
            raise det.errors.WorkerFinishedGracefully
Пример #15
0
def load_controller_from_trial(
    trial_class: Type[det.Trial],
    env: det.EnvContext,
    workloads: workload.Stream,
    load_path: Optional[pathlib.Path],
    rendezvous_info: det.RendezvousInfo,
    hvd_config: horovod.HorovodContext,
) -> det.TrialController:
    """Instantiate the given Trial class and return the TrialController built from it.

    The trial's `trial_controller_class` is validated, its `pre_execute_hook` is run,
    a trial context and trial instance are created, and the controller's `from_trial`
    constructor is called.
    """
    trial_name = trial_class.__name__

    # Step 1: Validate model definition.
    declared_controller = trial_class.trial_controller_class
    check.is_not_none(
        declared_controller,
        f"The class attribute `trial_controller_class` of {trial_name} is "
        "None; please set it the correct subclass of `det.TrialController`",
    )
    check.is_subclass(
        declared_controller,
        det.TrialController,
        f"The class attribute `trial_controller_class` of {trial_name} is "
        "not a valid subclass of `det.TrialController`",
    )
    controller_cls = cast(Type[det.TrialController], declared_controller)

    # Step 2: Initialize framework-specific details (horovod, random seeds, etc).
    controller_cls.pre_execute_hook(env, hvd_config)
    trial_context = trial_class.trial_context_class(env, hvd_config)

    # Step 3: Instantiate the user's Trial.
    trial_inst = trial_class(trial_context)

    # Step 4: Return the TrialController.
    logging.info(f"Creating {controller_cls.__name__} with {trial_name}.")
    controller_kwargs = dict(
        trial_inst=trial_inst,
        context=trial_context,
        env=env,
        workloads=workloads,
        load_path=load_path,
        rendezvous_info=rendezvous_info,
        hvd_config=hvd_config,
    )
    return controller_cls.from_trial(**controller_kwargs)
Пример #16
0
    def _trigger_epoch(self) -> None:
        """
        Runs at the end of each training step: reports the collected batch metrics
        through the pending response function, clears them, and re-enters the control
        loop to decide what to do next.
        """
        check.is_not_none(
            self.train_response_func, "no response_func at end of train_for_step")
        self.train_response_func = cast(
            workload.ResponseFunc, self.train_response_func)

        # Only the chief reports real metrics; everyone else reports a skip.
        if not self.is_chief:
            self.train_response_func(workload.Skipped())
        else:
            step_metrics = det.util.make_metrics(None, self.batch_metrics)
            self.train_response_func(step_metrics)

        self.train_response_func = None
        self.batch_metrics = []

        self._control_loop()
Пример #17
0
    def __init__(
        self,
        num_connections: Optional[int] = None,
        ports: Optional[List[int]] = None,
        port_range: Optional[Tuple[int, int]] = None,
    ) -> None:
        """Create a ZMQ context and bind sockets to explicit or randomly chosen ports.

        Either pass `ports` (and no `port_range`), or pass both `num_connections`
        and `port_range` so that ports are picked from the range.
        """
        self.context = zmq.Context()
        self.sockets = []  # type: List[zmq.Socket]
        self.ports = []  # type: List[int]

        if not ports:
            # No explicit ports given: bind the requested number of sockets
            # somewhere inside the given range.
            check.is_not_none(num_connections)
            check.is_not_none(port_range)
            num_connections = cast(int, num_connections)
            port_range = cast(Tuple[int, int], port_range)
            self._bind_to_random_ports(
                port_range=port_range, num_connections=num_connections)
            check.eq(len(self.ports), num_connections)
            return

        check.is_none(port_range)
        self._bind_to_specified_ports(ports=ports)
        check.eq(len(self.ports), len(ports))
Пример #18
0
    def from_native(
        context: det.NativeContext,
        env: det.EnvContext,
        workloads: workload.Stream,
        load_path: Optional[pathlib.Path],
        rendezvous_info: det.RendezvousInfo,
        hvd_config: horovod.HorovodContext,
    ) -> det.TrialController:
        """Construct a TFKerasTrialController from a Native API context.

        The context must be a `keras.TFKerasNativeContext` on which the model,
        the compile arguments and the train config have already been recorded.
        The model and optimizer are passed through
        `keras._get_multi_gpu_model_and_optimizer`, the model is compiled with the
        recorded arguments, and the controller is instantiated with the current
        Keras backend session.
        """
        # The message names the type that is actually checked.
        check.is_instance(
            context,
            keras.TFKerasNativeContext,
            "TFKerasTrialController needs a TFKerasNativeContext",
        )
        context = cast(keras.TFKerasNativeContext, context)

        check.is_not_none(context.model, "Please call wrap_model(...).")

        check.is_not_none(context.compile_args,
                          "Please call model.compile(...).")
        check.is_not_none(
            context.train_config,
            "Please call model.fit(...) or model.fit_generator(...).",
        )

        # For the Native API, we would break the user's model if we changed the session
        # right now, so we have to trust the user did not modify what we set previously.
        #
        # TODO(ryan): Fix this, probably with a function for configuring the backend session.
        session = tf.compat.v1.keras.backend.get_session()

        compile_args = cast(inspect.BoundArguments, context.compile_args)
        train_config = cast(keras.TFKerasTrainConfig, context.train_config)

        # Replace both the model and the optimizer argument with the versions
        # returned by the helper before compiling.
        (
            context.model,
            compile_args.arguments["optimizer"],
        ) = keras._get_multi_gpu_model_and_optimizer(
            pre_compiled_model=context.model,
            optimizer=compile_args.arguments["optimizer"],
            env=env,
            hvd_config=hvd_config,
            profile_frequency=env.experiment_config.profile_frequency(),
            profile_filename=DeterminedProfiler.OUTPUT_FILENAME,
        )

        context.model.compile(*compile_args.args, **compile_args.kwargs)

        return TFKerasTrialController(
            context.model,
            session,
            train_config,
            context,
            env,
            workloads,
            load_path,
            rendezvous_info,
            hvd_config,
        )
Пример #19
0
    def from_native(
        context: det.NativeContext,
        env: det.EnvContext,
        workloads: workload.Stream,
        load_path: Optional[pathlib.Path],
        rendezvous_info: det.RendezvousInfo,
        hvd_config: horovod.HorovodContext,
    ) -> det.TrialController:
        """Construct a TFKerasTrialController from a Native API context.

        The context must be a `keras.TFKerasNativeContext` on which the model,
        the compile arguments and the train config have already been recorded.
        The model is compiled through `TFKerasTrialController.compile_model` and
        the controller is instantiated with the current Keras backend session.
        """
        # The message names the type that is actually checked.
        check.is_instance(
            context,
            keras.TFKerasNativeContext,
            "TFKerasTrialController needs a TFKerasNativeContext",
        )
        context = cast(keras.TFKerasNativeContext, context)

        check.is_not_none(context.model, "Please call wrap_model(...).")

        check.is_not_none(context.compile_args,
                          "Please call model.compile(...).")
        check.is_not_none(
            context.train_config,
            "Please call model.fit(...) or model.fit_generator(...).")

        # For the Native API, we would break the user's model if we changed the session
        # right now, so we have to trust the user did not modify what we set previously.
        #
        # TODO(ryan): Fix this, probably with a function for configuring the backend session.
        session = tf.compat.v1.keras.backend.get_session()

        compile_args = cast(inspect.BoundArguments, context.compile_args)
        train_config = cast(keras.TFKerasTrainConfig, context.train_config)

        TFKerasTrialController.compile_model(context=context,
                                             compile_args=compile_args,
                                             env=env,
                                             hvd_config=hvd_config)

        return TFKerasTrialController(
            context.model,
            session,
            train_config,
            context,
            env,
            workloads,
            load_path,
            rendezvous_info,
            hvd_config,
        )
Пример #20
0
 def get_instance(cls) -> "RunpyGlobals":
     """Return the instance stored on the class, asserting that one has been set."""
     instance = cls._instance
     check.is_not_none(instance, "Please initialize RunpyGlobals context first.")
     return cast(RunpyGlobals, instance)
Пример #21
0
    def _init_model(self, training_dataflow: Any,
                    validation_dataflow: Any) -> None:
        """Build the model, create the Tensorpack trainer, set up its graph, and assemble callbacks.

        Args:
            training_dataflow: dataflow wrapped in a `tp.QueueInput` (and, unless
                disabled via the `disable_staging_area` hyperparameter, a
                `tp.StagingInput`) to feed training.
            validation_dataflow: dataflow handed to `CustomInferenceRunner`; must not
                be None when the trial's `validation_metrics()` returns a list.

        Raises:
            AssertionError: if gradient aggregation (aggregation_frequency > 1) is
                requested while `is_determined_ai_tensorpack()` is falsy.
        """
        self._load()

        logging.info("Calling build_model")
        if self.hvd_config.use:
            trainer_type = "horovod"
        else:
            trainer_type = "replicated"
        model = self.trial.build_model(trainer_type)
        logging.info("Finished build_model")

        determined_ai_tensorpack = is_determined_ai_tensorpack()

        if not determined_ai_tensorpack and self.hvd_config.aggregation_frequency > 1:
            raise AssertionError(
                "Gradient aggregation is only supported for custom DAI version of tensorpack"
            )

        if self.hvd_config.use:
            self.trainer = tp.HorovodTrainer(
                average=False,
                compression=hvd.compression.Compression.fp16,
                aggregation_frequency=self.hvd_config.aggregation_frequency,
            )
        else:
            num_gpus = len(self.env.container_gpus)
            self.trainer = tp.SyncMultiGPUTrainerReplicated(num_gpus,
                                                            average=False,
                                                            mode="nccl")

        inp = tp.QueueInput(training_dataflow)

        # StagingInput causes deadlocks in some code, so allow it to be disabled.
        # TODO: Figure out why.
        if not self.env.hparams.get("disable_staging_area"):
            inp = tp.StagingInput(inp, 1)

        logging.info("Calling setup_graph")
        self.trainer.setup_graph(model.get_input_signature(), inp,
                                 model.build_graph, model.get_optimizer)
        logging.info("Finished setup_graph")

        # For validation we support users specifying an Evaluator(), or passing in
        # the validation metrics they want to track. If they pass in validation
        # metrics, we create a custom InferenceRunner() callback. FasterRCNN uses the
        # Evaluator(), while all other Tensorpack example models use InferenceRunner.
        evaluator = None  # type: Optional[Evaluator]
        validation_metrics_names = None  # type: Optional[List[str]]
        inference_runner_callback = None  # type: Optional[CustomInferenceRunner]
        evaluator_or_validation_metrics = self.trial.validation_metrics()
        if isinstance(evaluator_or_validation_metrics, list):
            check.is_not_none(validation_dataflow)
            validation_scalar_stats = CustomScalarStats(
                evaluator_or_validation_metrics, prefix="val")
            validation_metrics_names = validation_scalar_stats.names_with_prefix(
            )
            inference_runner_callback = CustomInferenceRunner(
                self.rendezvous_info.get_rank(), validation_dataflow,
                validation_scalar_stats)
        else:
            evaluator = evaluator_or_validation_metrics

        # "loss" is always tracked, followed by the trial's own training metrics.
        metrics = ["loss", *self.trial.training_metrics()]

        # Optionally also track the tensors behind the graph's summary ops.
        if self.env.hparams.get("include_summary_metrics"):
            metrics.extend(t.op.inputs[1].name
                           for t in tf.get_collection(tf.GraphKeys.SUMMARIES))

        manager_cb = ManagerCallback(
            metrics,
            evaluator,
            validation_metrics_names,
            self.workloads,
            self.is_chief,
            self.rendezvous_info.get_rank(),
        )

        # TODO: check to make sure users don't pass in InferenceRunner
        # because that will run validation after every RUN_STEP.
        self.cbs = [manager_cb, *self.trial.tensorpack_callbacks()]
        if inference_runner_callback:
            self.cbs.append(inference_runner_callback)
Пример #22
0
    def from_trial(
        trial_inst: det.Trial,
        context: det.TrialContext,
        env: det.EnvContext,
        workloads: workload.Stream,
        load_path: Optional[pathlib.Path],
        rendezvous_info: det.RendezvousInfo,
        hvd_config: horovod.HorovodContext,
    ) -> det.TrialController:
        """Construct a TFKerasTrialController from a user-defined TFKerasTrial.

        Validates the context and trial types, configures the session, builds the
        data loaders and the model, adapts the data for Keras, compiles the model,
        and finally instantiates the controller.
        """
        check.is_instance(
            context, keras.TFKerasTrialContext,
            "TFKerasTrialController needs a TFKerasTrialContext")
        context = cast(keras.TFKerasTrialContext, context)

        check.is_instance(
            trial_inst, TFKerasTrial, "TFKerasTrialController needs a TFKerasTrial")
        trial = cast(TFKerasTrial, trial_inst)

        session = TFKerasTrialController._configure_session(
            env, hvd_config, trial.session_config())

        # Both loaders are created before the model is built.
        train_loader = trial.build_training_data_loader()
        val_loader = trial.build_validation_data_loader()

        trial.build_model()
        check.is_not_none(context.model, "Please call wrap_model(...).")

        def _adapt(loader: Any, drop_leftovers: bool) -> Any:
            # Split the loader into x / y / sample_weight and adapt it with the
            # fit settings recorded on the context.
            x, y, sample_weight = keras._get_x_y_and_sample_weight(
                input_data=loader)
            return keras._adapt_keras_data(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=context.get_per_slot_batch_size(),
                use_multiprocessing=context._fit_use_multiprocessing,
                workers=context._fit_workers,
                max_queue_size=context._fit_max_queue_size,
                drop_leftovers=drop_leftovers,
            )

        # Leftovers are dropped for training but kept for validation.
        training_data = _adapt(train_loader, True)
        validation_data = _adapt(val_loader, False)

        check.is_not_none(
            context.compile_args, "Please call model.compile(...).")
        bound_compile_args = cast(inspect.BoundArguments, context.compile_args)

        TFKerasTrialController.compile_model(
            context=context,
            compile_args=bound_compile_args,
            env=env,
            hvd_config=hvd_config,
        )

        callbacks = trial.keras_callbacks()
        train_config = keras.TFKerasTrainConfig(
            training_data, validation_data, callbacks)

        return TFKerasTrialController(
            context.model,
            session,
            train_config,
            context,
            env,
            workloads,
            load_path,
            rendezvous_info,
            hvd_config,
        )
Пример #23
0
    def from_trial(
        trial_inst: det.Trial,
        context: det.TrialContext,
        env: det.EnvContext,
        workloads: workload.Stream,
        load_path: Optional[pathlib.Path],
        rendezvous_info: det.RendezvousInfo,
        hvd_config: horovod.HorovodContext,
    ) -> det.TrialController:
        """Construct a TFKerasTrialController from a user-defined TFKerasTrial.

        Validates the context and trial types, configures the session, adapts the
        training and validation data, builds the model, passes model and optimizer
        through `keras._get_multi_gpu_model_and_optimizer`, compiles the model, and
        finally instantiates the controller.
        """
        check.is_instance(
            context,
            keras.TFKerasTrialContext,
            "TFKerasTrialController needs a TFKerasTrialContext",
        )
        context = cast(keras.TFKerasTrialContext, context)

        check.is_instance(
            trial_inst, TFKerasTrial, "TFKerasTrialController needs a TFKerasTrial")
        trial = cast(TFKerasTrial, trial_inst)

        session = TFKerasTrialController._configure_session(
            env, hvd_config, trial.session_config())

        def _adapt(loader: Any, drop_leftovers: bool) -> Any:
            # Split the loader into x / y / sample_weight and adapt it for Keras.
            x, y, sample_weight = keras._get_x_y_and_sample_weight(
                input_data=loader)
            return keras._adapt_keras_data(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=context.get_per_slot_batch_size(),
                drop_leftovers=drop_leftovers,
            )

        # Leftovers are dropped for training but kept for validation.
        training_data = _adapt(trial.build_training_data_loader(), True)
        validation_data = _adapt(trial.build_validation_data_loader(), False)

        trial.build_model()
        check.is_not_none(context.model, "Please call wrap_model(...).")

        check.is_not_none(
            context.compile_args, "Please call model.compile(...).")
        compile_args = cast(inspect.BoundArguments, context.compile_args)

        # Swap in the model and optimizer returned by the helper before compiling.
        new_model, new_optimizer = keras._get_multi_gpu_model_and_optimizer(
            pre_compiled_model=context.model,
            optimizer=compile_args.arguments["optimizer"],
            env=env,
            hvd_config=hvd_config,
            profile_frequency=env.experiment_config.profile_frequency(),
            profile_filename=DeterminedProfiler.OUTPUT_FILENAME,
        )
        context.model = new_model
        compile_args.arguments["optimizer"] = new_optimizer

        disable_tf_function = hvd_config.use and version.parse(
            tf.__version__) >= version.parse("2.0.0")
        if not disable_tf_function:
            context.model.compile(*compile_args.args, **compile_args.kwargs)
        else:
            logging.info(
                "Calling `model.compile(...)` with `experimental_run_tf_function=False` to ensure "
                "TensorFlow calls `optimizer.get_gradients()` to compute gradients."
            )
            context.model.compile(*compile_args.args,
                                  **compile_args.kwargs,
                                  experimental_run_tf_function=False)

        callbacks = trial.keras_callbacks()
        train_config = keras.TFKerasTrainConfig(
            training_data, validation_data, callbacks)

        return TFKerasTrialController(
            context.model,
            session,
            train_config,
            context,
            env,
            workloads,
            load_path,
            rendezvous_info,
            hvd_config,
        )
Пример #24
0
 def get_dataset_length(self) -> int:
     """Return the stored dataset length, asserting that it has been set."""
     length = self._dataset_length
     check.is_not_none(length, "Dataset length not yet initialized.")
     return cast(int, length)