Example No. 1
            def new_train_batch(batch: pytorch.TorchData, epoch_idx: int,
                                batch_idx: int) -> Any:
                tr_metrics = train_batch(
                    batch=batch,
                    model=model,
                    epoch_idx=epoch_idx,
                    batch_idx=batch_idx,
                )
                if isinstance(tr_metrics, torch.Tensor):
                    tr_metrics = {"loss": tr_metrics}
                check.is_instance(
                    tr_metrics,
                    dict,
                    "train_batch() must return a dictionary "
                    f"mapping string names to Tensor metrics, got {type(tr_metrics)}",
                )
                check.is_in("loss", tr_metrics.keys(),
                            'Please include "loss" in your training metrics.')

                def clip_grads(parameters: Iterator) -> None:
                    for callback in self.callbacks.values():
                        callback.on_before_optimizer_step(parameters)

                self.context.backward(tr_metrics["loss"])
                self.context.step_optimizer(self.context.optimizers[0],
                                            clip_grads=clip_grads)

                return tr_metrics
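
Every example on this page routes its runtime assertions through `check.is_instance` and friends. As a point of reference, here is a minimal sketch of the assumed semantics (an illustration only, not the actual `check` module):

from typing import Any, Tuple, Type, Union

def is_instance(obj: Any, types: Union[Type, Tuple[Type, ...]], msg: str = "") -> None:
    # Assumed behavior: raise with the caller's message when the type check fails.
    if not isinstance(obj, types):
        raise AssertionError(msg or f"expected {types}, got {type(obj)}")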
Example No. 2
 def metrics_result(self) -> Metrics:
     """Identical to result but disallow workload.Skipped responses."""
     check.is_not_none(self._response,
                       "_respond() was not called by the TrialController.")
     check.is_instance(self._response, dict,
                       "unexpected SkippedWorkload response.")
     return cast(Metrics, self._response)
Example No. 3
 def run(self) -> None:
     for w, args, response_func in self.workloads:
         if w.kind == workload.Workload.Kind.RUN_STEP:
             response_func(
                 util.wrap_metrics(
                     self._train_for_step(w.step_id, w.num_batches, w.total_batches_processed),
                     self.context.get_stop_requested(),
                 )
             )
         elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
             response_func(
                 util.wrap_metrics(
                     self._compute_validation_metrics(), self.context.get_stop_requested()
                 )
             )
         elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
             check.eq(len(args), 1)
             check.is_instance(args[0], pathlib.Path)
             path = cast(pathlib.Path, args[0])
             response_func(self._save(path))
         elif w.kind == workload.Workload.Kind.TERMINATE:
             response_func({} if self.is_chief else workload.Skipped())
             break
         else:
             raise AssertionError("Unexpected workload: {}".format(w.kind))
Example No. 4
    def _calculate_batch_sizes(self) -> Tuple[int, int]:
        if "global_batch_size" not in self.hparams.keys():
            raise AssertionError(
                "Please specify `global_batch_size` under `hyperparameters` "
                "in experiment config.")

        if "batch_size" in self.hparams.keys():
            logging.warning(
                "Use `global_batch_size` not `batch_size` under `hyperparameters` "
                "in experiment config.")

        global_batch_size = self.hparams["global_batch_size"]
        check.is_instance(global_batch_size, int,
                          "`global_batch_size` hparam must be an int.")
        global_batch_size = cast(int, global_batch_size)

        if self.experiment_config.native_parallel_enabled():
            return global_batch_size, global_batch_size

        # Configure batch sizes.
        slots_per_trial = self.experiment_config.slots_per_trial()
        if global_batch_size < slots_per_trial:
            raise AssertionError(
                "Please set the `global_batch_size` hyperparameter to be greater than or equal "
                f"to the number of slots. Current global_batch_size: {global_batch_size}, "
                f"slots_per_trial: {slots_per_trial}.")

        per_gpu_batch_size = global_batch_size // slots_per_trial
        effective_batch_size = per_gpu_batch_size * slots_per_trial
        if effective_batch_size != global_batch_size:
            logging.warning(
                f"`global_batch_size` changed from {global_batch_size} to {effective_batch_size} "
                f"to divide equally across {slots_per_trial} slots.")

        return per_gpu_batch_size, effective_batch_size
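
A quick worked example of the arithmetic above: with a global batch size of 50 and 8 slots, integer division floors the per-slot batch size to 6, so the effective batch size drops to 48 and the warning fires.

global_batch_size, slots_per_trial = 50, 8

per_gpu_batch_size = global_batch_size // slots_per_trial    # 50 // 8 == 6
effective_batch_size = per_gpu_batch_size * slots_per_trial  # 6 * 8 == 48

assert (per_gpu_batch_size, effective_batch_size) == (6, 48)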
Example No. 5
    def run(self) -> None:
        """
        A basic control loop of the old-style (callback-based) TrialController
        classes.
        """

        for w, args, response_func in self.workloads:
            try:
                if w.kind == workload.Workload.Kind.RUN_STEP:
                    response = self.train_for_step(
                        w.step_id, w.num_batches)  # type: workload.Response
                elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                    response = self.compute_validation_metrics(w.step_id)
                elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                    check.len_eq(args, 1)
                    check.is_instance(args[0], pathlib.Path)
                    path = cast(pathlib.Path, args[0])
                    self.save(path)
                    response = {}
                elif w.kind == workload.Workload.Kind.TERMINATE:
                    self.terminate()
                    response = workload.Skipped()
                else:
                    raise AssertionError("Unexpected workload: {}".format(
                        w.kind))

            except det.errors.SkipWorkloadException:
                response = workload.Skipped()

            response_func(response)
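
The run() loops in Examples 3 and 5 both consume `self.workloads`, a stream of `(workload, args, response_func)` triples, and answer each message through `response_func`. A toy driver showing the shape of that protocol (the string kinds and tuple layout here are stand-ins for illustration, not the real `workload.Stream` types):

from typing import Any, Callable, List, Tuple

Stream = List[Tuple[str, List[Any], Callable[[Any], None]]]

def drive(stream: Stream) -> None:
    # Mirrors the dispatch shape of the control loops above.
    for kind, _args, respond in stream:
        if kind == "RUN_STEP":
            respond({"metrics": {"loss": 0.0}})
        elif kind == "TERMINATE":
            respond({})
            break

drive([("RUN_STEP", [], print), ("TERMINATE", [], print)])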
Example No. 6
    def barrier(self,
                num_connections: int,
                message: Any = None,
                timeout: Optional[int] = None) -> List[Any]:
        """
        This is a one-sided barrier, where the chief blocks until
        all non-chief trial containers have sent a message.
        """
        check.eq(len(self.sockets), 1)
        messages = []  # type: List[Any]
        start_time = time.time()

        for _ in range(num_connections):
            if timeout:
                message_received, barrier_message = self.receive_non_blocking(
                    send_rank=0, deadline=start_time + timeout)

                if not message_received:
                    return messages

            else:
                barrier_message = self.receive_blocking(0)

            check.is_instance(barrier_message, _OneSidedBarrier)
            messages.append(barrier_message.message)
            self.sockets[0].send_pyobj(_OneSidedBarrier(message=message))

        return messages
Example No. 7
    def __init__(
        self,
        context: Union[keras.TFKerasTrialContext, keras.TFKerasNativeContext],
        train_config: keras.TFKerasTrainConfig,
    ) -> None:
        super().__init__(context=context)

        self._training_cacheable = self._context.experimental.get_train_cacheable()
        self._training_dataset = train_config.training_data

        check.true(
            self._training_cacheable.is_decorator_used(),
            "Please use `@context.experimental.cache_train_dataset(dataset_name, dataset_version)`"
            " for the training dataset.",
        )
        check.false(
            self._context.dataset_initialized,
            "Please do not use: `context.wrap_dataset(dataset)` if using "
            "`@context.experimental.cache_train_dataset()` and "
            "`@context.experimental.cache_validation_dataset()`.",
        )
        check.is_instance(
            train_config.training_data,
            tf.data.Dataset,
            "Pass in a `tf.data.Dataset` object if using "
            "`@context.experimental.cache_train_dataset()`.",
        )
Example No. 8
    def __init__(self, trial_inst: det.Trial, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        check.is_instance(trial_inst, PyTorchTrial, "PyTorchTrialController needs a PyTorchTrial")
        self.trial = cast(PyTorchTrial, trial_inst)
        self._check_evaluate_implementation()

        self._init_model_and_optimizer()

        # Validation loader will be undefined on process ranks > 0
        # when the user defines `validate_full_dataset()`.
        self.validation_loader = None  # type: Optional[torch.utils.data.DataLoader]
        self._set_data_loaders()

        # Track whether a warning logging category has already been issued to the user.
        self.warning_logged = {_WarningLogs.FAILED_MOVING_TO_DEVICE: False}

        self.context.lr_scheduler = self.trial.create_lr_scheduler(self.context.optimizer)

        self.callbacks = self.trial.build_callbacks()

        # If a load path is provided load weights and restore the data location.
        self._load()
        self._configure_amp()

        if self.hvd_config.use:
            hvd.broadcast_parameters(self.context.model.state_dict(), root_rank=0)
            hvd.broadcast_optimizer_state(self.context.optimizer, root_rank=0)

        self.training_iterator = iter(self.training_loader)
Example No. 9
 def run(self) -> None:
     for w, args, response_func in self.workloads:
         if w.kind == workload.Workload.Kind.RUN_STEP:
             metrics = det.util.make_metrics(
                 num_inputs=None,
                 batch_metrics=[{"loss": 1} for _ in range(w.num_batches)],
             )
             response_func({"metrics": metrics})
         elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
             check.len_eq(args, 0)
             response_func({"metrics": {"validation_metrics": self.validation_metrics}})
         elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
             check.len_eq(args, 1)
             check.is_instance(args[0], pathlib.Path)
             path = cast(pathlib.Path, args[0])
             if not path.exists():
                 path.mkdir(parents=True, exist_ok=True)
             with path.joinpath("a_file").open("w") as f:
                 f.write("yup")
             response_func({})
         elif w.kind == workload.Workload.Kind.TERMINATE:
             raise NotImplementedError()
Example No. 10
    def control_loop(self) -> None:
        for wkld, args, response_func in self.estimator_trial_controller.workloads:
            logging.debug(f"Received wkld {wkld.kind} with args {args}.")

            if wkld.kind == workload.Workload.Kind.RUN_STEP:
                # Store values for the training loop.
                self.num_batches = wkld.num_batches
                self.train_response_func = response_func
                # Break out of the control loop so that the train process
                # re-enters the train_and_evaluate() loop.
                break
            elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                response_func(
                    det.util.wrap_metrics(
                        self._compute_validation_metrics(),
                        self.estimator_trial_controller.context.get_stop_requested(),
                    ))
            elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                check.len_eq(args, 1)
                check.is_instance(args[0], pathlib.Path)
                path = cast(pathlib.Path, args[0])
                response_func(self._checkpoint_model(path))
            elif wkld.kind == workload.Workload.Kind.TERMINATE:
                self.estimator_trial_controller.exit_response_func = response_func
                raise det.errors.WorkerFinishedGracefully("Exiting normally.")
            else:
                raise AssertionError(f"Unknown wkld kind {wkld.kind}.")
                exit(1)
Example No. 11
    def from_trial(
        trial_inst: det.Trial,
        context: det.TrialContext,
        env: det.EnvContext,
        *args: Any,
        **kwargs: Any,
    ) -> det.TrialController:
        check.is_instance(
            context,
            estimator.EstimatorTrialContext,
            "EstimatorTrialController needs an EstimatorTrialContext",
        )
        context = cast(estimator.EstimatorTrialContext, context)

        check.is_instance(trial_inst, EstimatorTrial,
                          "EstimatorTrialController needs an EstimatorTrial")
        trial_inst = cast(EstimatorTrial, trial_inst)

        return EstimatorTrialController(
            trial_inst.build_estimator(),
            trial_inst.build_train_spec(),
            trial_inst.build_validation_spec(),
            trial_inst.build_serving_input_receiver_fns(),
            context,
            env,
            *args,
            **kwargs,
        )
Example No. 12
 def _control_loop(self) -> None:
     for wkld, args, response_func in self.workloads:
         logging.debug(f"Received wkld {wkld.kind} with args {args}.")
         if wkld.kind == workload.Workload.Kind.RUN_STEP:
             # Configure the state for a training step.
             self.train_response_func = response_func
             self.train_workload_batches = 0
             self.train_workload_metrics = []
             self.train_workload_len = wkld.num_batches
             self.multiplexer.set_batches_requested(wkld.num_batches)
             break
         elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
             response_func(
                 det.util.wrap_metrics(self._compute_validation_metrics(),
                                       self.context.get_stop_requested()))
         elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
             check.len_eq(args, 1)
             check.is_instance(args[0], pathlib.Path)
             path = cast(pathlib.Path, args[0])
             response_func(self._save_checkpoint(path))
         elif wkld.kind == workload.Workload.Kind.TERMINATE:
             response_func({} if self.is_chief else workload.Skipped())
             self.multiplexer._corrected_train_end()
             raise det.errors.WorkerFinishedGracefully
         else:
             raise AssertionError(f"Unknown workload kind {wkld.kind}.")
Example No. 13
    def _launch_evaluate(self) -> Any:
        validation_data = self.validation_data
        steps = None

        # Support the deprecated SequenceAdapter API.
        if isinstance(validation_data, keras.SequenceAdapter):
            # Ignore these settings and use the same settings as for the fit call.
            validation_data = validation_data.sequence

        if isinstance(validation_data, tf.keras.utils.Sequence):
            # Calculate the length of our validation shard.
            steps = len(validation_data)
            if self.context.distributed.get_size() > 1:
                size = self.context.distributed.get_size()
                rank = self.context.distributed.get_rank()
                steps = steps // size + (1 if steps % size > rank else 0)

            # Handle args from fit(): shuffle, workers, use_multiprocessing, and max_queue_size.
            enqueuer = keras._build_enqueuer(
                sequence=validation_data,
                workers=self.context._fit_workers,
                use_multiprocessing=self.context._fit_use_multiprocessing,
                max_queue_size=self.context._fit_max_queue_size,
                shard_rank=self.context.distributed.get_rank(),
                num_shards=self.context.distributed.get_size(),
                repeat=False,
                shuffle=False,
                shuffle_seed=0,
                prior_batches_trained=0,
            )
            enqueuer.start()
            self.enqueuers.append(enqueuer)
            validation_data = enqueuer.data()

        # Starting in TF 2.2 users may define a custom test_step() that does
        # not use the model metrics.
        use_model_metrics = version.parse(tf.__version__) < version.parse("2.2.0")
        evaluate_kwargs = {} if use_model_metrics else {"return_dict": True}

        metrics_values = self.model.evaluate(
            validation_data,
            callbacks=self.callback_list,
            steps=steps,
            verbose=0,
            workers=0,
            **evaluate_kwargs,
        )
        logging.debug(f"Worker finished model.evaluate() with metrics: {metrics_values}.")

        # If the model was compiled with metrics=None, metrics_values will be a single value.
        if not isinstance(metrics_values, (tuple, list, dict)):
            metrics_values = (metrics_values,)

        if use_model_metrics:
            metrics = make_logs(self.model, {}, metrics_values, ModeKeys.TEST, prefix="val_")
        else:
            check.is_instance(metrics_values, dict)
            metrics = {f"val_{k}": v for k, v in metrics_values.items()}

        return metrics
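
The shard-length expression `steps // size + (1 if steps % size > rank else 0)` hands any remainder batches to the lowest ranks. A worked check: 10 validation batches across 3 workers shard as 4, 3, 3.

steps, size = 10, 3
shards = [steps // size + (1 if steps % size > rank else 0) for rank in range(size)]
assert shards == [4, 3, 3]   # the single leftover batch goes to rank 0
assert sum(shards) == steps  # every batch is evaluated exactly once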
Example No. 14
    def from_native(
        context: det.NativeContext,
        env: det.EnvContext,
        workloads: workload.Stream,
        load_path: Optional[pathlib.Path],
        rendezvous_info: det.RendezvousInfo,
        hvd_config: horovod.HorovodContext,
    ) -> det.TrialController:
        check.is_instance(
            context,
            keras.TFKerasNativeContext,
            "TFKerasTrialController needs a TFKerasSprinkleContext",
        )
        context = cast(keras.TFKerasNativeContext, context)

        check.is_not_none(context.model, "Please call wrap_model(...).")

        check.is_not_none(context.compile_args,
                          "Please call model.compile(...).")
        check.is_not_none(
            context.train_config,
            "Please call model.fit(...) or model.fit_generator(...).",
        )

        # For the Native API, we would break the user's model if we changed the session
        # right now, so we have to trust the user did not modify what we set previously.
        #
        # TODO(ryan): Fix this, probably with a function for configuring the backend session.
        session = tf.compat.v1.keras.backend.get_session()

        compile_args = cast(inspect.BoundArguments, context.compile_args)
        train_config = cast(keras.TFKerasTrainConfig, context.train_config)

        (
            context.model,
            compile_args.arguments["optimizer"],
        ) = keras._get_multi_gpu_model_and_optimizer(
            pre_compiled_model=context.model,
            optimizer=compile_args.arguments["optimizer"],
            env=env,
            hvd_config=hvd_config,
            profile_frequency=env.experiment_config.profile_frequency(),
            profile_filename=DeterminedProfiler.OUTPUT_FILENAME,
        )

        context.model.compile(*compile_args.args, **compile_args.kwargs)

        return TFKerasTrialController(
            context.model,
            session,
            train_config,
            context,
            env,
            workloads,
            load_path,
            rendezvous_info,
            hvd_config,
        )
Example No. 15
    def from_trial(
        trial_inst: det.Trial,
        context: det.TrialContext,
        env: det.EnvContext,
        workloads: workload.Stream,
        load_path: Optional[pathlib.Path],
        rendezvous_info: det.RendezvousInfo,
        hvd_config: horovod.HorovodContext,
    ) -> det.TrialController:
        check.is_instance(
            context, keras.TFKerasTrialContext,
            "TFKerasTrialController needs a TFKerasTrialContext")
        context = cast(keras.TFKerasTrialContext, context)

        check.is_instance(trial_inst, TFKerasTrial,
                          "TFKerasTrialController needs a TFKerasTrial")
        trial = cast(TFKerasTrial, trial_inst)

        session = TFKerasTrialController._configure_session(
            env, hvd_config, trial.session_config())

        training_data = keras._adapt_data_from_data_loader(
            input_data=trial.build_training_data_loader(),
            batch_size=context.get_per_slot_batch_size(),
        )

        validation_data = keras._adapt_data_from_data_loader(
            input_data=trial.build_validation_data_loader(),
            batch_size=context.get_per_slot_batch_size(),
        )

        trial.build_model()
        check.is_not_none(context.model, "Please call wrap_model(...).")

        check.is_not_none(context.compile_args,
                          "Please call model.compile(...).")
        compile_args = cast(inspect.BoundArguments, context.compile_args)

        TFKerasTrialController.compile_model(context=context,
                                             compile_args=compile_args,
                                             env=env,
                                             hvd_config=hvd_config)

        tf_keras_callbacks = trial.keras_callbacks()

        return TFKerasTrialController(
            context.model,
            session,
            keras.TFKerasTrainConfig(training_data, validation_data,
                                     tf_keras_callbacks),
            context,
            env,
            workloads,
            load_path,
            rendezvous_info,
            hvd_config,
        )
Example No. 16
 def barrier(self, message: Any = None) -> Any:
     """
     This is a one-sided barrier, where the chief blocks until
     all non-chief trial containers have sent a message.
     """
     self.socket.send_pyobj(_OneSidedBarrier(message=message))
     barrier_message = self.socket.recv_pyobj()
     check.is_instance(barrier_message, _OneSidedBarrier)
     return barrier_message.message
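
Examples 6 and 16 are the two halves of one handshake: each worker sends a `_OneSidedBarrier` and blocks on the chief's reply, while the chief gathers `num_connections` messages and answers each one. Below is a schematic sketch of that message order using in-process queues in place of ZeroMQ sockets (an illustration of the assumed protocol, not the real transport):

import queue
import threading
from dataclasses import dataclass
from typing import Any

@dataclass
class _OneSidedBarrier:
    message: Any

to_chief: "queue.Queue[_OneSidedBarrier]" = queue.Queue()
to_worker: "queue.Queue[_OneSidedBarrier]" = queue.Queue()

def worker() -> None:
    # Worker side (Example 16): send, then block until the chief replies.
    to_chief.put(_OneSidedBarrier(message="worker-ready"))
    print(to_worker.get().message)  # prints "chief-ack"

def chief(num_connections: int) -> None:
    # Chief side (Example 6): gather one message per worker, reply to each.
    for _ in range(num_connections):
        to_chief.get()
        to_worker.put(_OneSidedBarrier(message="chief-ack"))

t = threading.Thread(target=worker)
t.start()
chief(num_connections=1)
t.join()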
Example No. 17
 def _do_startup_message_sequence(self) -> None:
     # Wait for a ReadyMessage from every worker.
     responses = self.broadcast_server.gather_with_polling(
         self._health_check)
     for response in responses:
         check.is_instance(
             response,
             ipc.ReadyMessage,
             f"Did not receive ReadyMessage from worker. Got: {response}",
         )
Example No. 18
    def _do_startup_message_sequence(self) -> None:
        # Wait for a ReadyMessage from every worker.
        responses, exception_received = self.broadcast_server.gather_with_polling(
            self._health_check)

        if exception_received:
            raise det.errors.WorkerError("Training process died.")

        for response in responses:
            check.is_instance(
                response,
                ipc.ReadyMessage,
                f"Did not receive ReadyMessage from worker. Got: {response}",
            )
Example No. 19
    def __init__(self, trial_inst: det.Trial, *args: Any,
                 **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        check.is_instance(trial_inst, TensorpackTrial,
                          "TensorpackTrialController needs a TensorpackTrial")
        self.trial = cast(TensorpackTrial, trial_inst)

        training_dataflow = self.trial.build_training_dataflow()
        validation_dataflow = self.trial.build_validation_dataflow()

        # Set if model is initialized from scratch.
        self.session_init = None  # type: Optional[Any]

        self._init_model(training_dataflow, validation_dataflow)
Example No. 20
    def from_native(
        context: det.NativeContext,
        env: det.EnvContext,
        workloads: workload.Stream,
        load_path: Optional[pathlib.Path],
        rendezvous_info: det.RendezvousInfo,
        hvd_config: horovod.HorovodContext,
    ) -> det.TrialController:
        check.is_instance(
            context,
            keras.TFKerasNativeContext,
            "TFKerasTrialController needs a TFKerasSprinkleContext",
        )
        context = cast(keras.TFKerasNativeContext, context)

        check.is_not_none(context.model, "Please call wrap_model(...).")

        check.is_not_none(context.compile_args,
                          "Please call model.compile(...).")
        check.is_not_none(
            context.train_config,
            "Please call model.fit(...) or model.fit_generator(...).")

        # For the Native API, we would break the user's model if we changed the session
        # right now, so we have to trust the user did not modify what we set previously.
        #
        # TODO(ryan): Fix this, probably with a function for configuring the backend session.
        session = tf.compat.v1.keras.backend.get_session()

        compile_args = cast(inspect.BoundArguments, context.compile_args)
        train_config = cast(keras.TFKerasTrainConfig, context.train_config)

        TFKerasTrialController.compile_model(context=context,
                                             compile_args=compile_args,
                                             env=env,
                                             hvd_config=hvd_config)

        return TFKerasTrialController(
            context.model,
            session,
            train_config,
            context,
            env,
            workloads,
            load_path,
            rendezvous_info,
            hvd_config,
        )
Example No. 21
    def run(self) -> None:
        for wkld, args, response_func in self.workloads:
            logging.debug(f"Received wkld {wkld.kind} with args {args}.")

            if wkld.kind == workload.Workload.Kind.RUN_STEP:
                # Store the train_response_func for later.
                self.train_response_func = response_func
                self.num_batches = wkld.num_batches

                # There are two possibilities when a RUN_STEP workload is received.
                # 1) This is the first training step seen by the trial
                #    container. In this case, enter the tf.keras fit() training loop.
                # 2) This is _not_ the first training step, meaning that the
                #    tf.keras fit() training loop is already active and paused.
                #    Break to re-enter the training loop.
                if not self.fit_loop_started:
                    try:
                        self._launch_fit()
                    except det.errors.WorkerFinishedGracefully:
                        pass

                    if not self.expect_terminate:
                        raise AssertionError(
                            "Training loop exited unexpectedly but without throwing any errors. "
                            "This is possibly due to a user callback causing the training loop to "
                            "exit, which is not supported at this time."
                        )
                break

            elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                response_func(
                    det.util.wrap_metrics(
                        self.compute_validation_metrics(), self.context.get_stop_requested()
                    )
                )
            elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                check.len_eq(args, 1)
                check.is_instance(args[0], pathlib.Path)
                path = cast(pathlib.Path, args[0])
                response_func(self._save_checkpoint(path))
            elif wkld.kind == workload.Workload.Kind.TERMINATE:
                self.model.stop_training = True
                self.expect_terminate = True
                response_func({} if self.is_chief else workload.Skipped())
                break
            else:
                raise AssertionError(f"Unknown wkld kind {wkld.kind}.")
Example No. 22
 def _control_loop(self) -> None:
     for wkld, args, response_func in self.workloads:
         if wkld.kind == workload.Workload.Kind.RUN_STEP:
             # Move on to the next step.
             self.train_response_func = response_func
             break
         elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
             response_func(self._compute_validation_metrics())
         elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
             check.len_eq(args, 1)
             check.is_instance(args[0], pathlib.Path)
             path = cast(pathlib.Path, args[0])
             response_func(self.save_checkpoint(path))
         elif wkld.kind == workload.Workload.Kind.TERMINATE:
             raise det.errors.WorkerFinishedGracefully("Exiting normally.")
         else:
             raise AssertionError(f"Unknown wkld kind {wkld.kind}")
Example No. 23
 def run(self) -> None:
     for w, args, response_func in self.workloads:
         if w.kind == workload.Workload.Kind.RUN_STEP:
             check.eq(len(args), 1)
             num_batches = cast(int, args[0])
             response_func(self._train_for_step(w.step_id, num_batches))
         elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
             response_func(self._compute_validation_metrics())
         elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
             check.eq(len(args), 1)
             check.is_instance(args[0], pathlib.Path)
             path = cast(pathlib.Path, args[0])
             response_func(self._save(path))
         elif w.kind == workload.Workload.Kind.TERMINATE:
             break
         else:
             raise AssertionError("Unexpected workload: {}".format(w.kind))
Example No. 24
    def _send_recv_workload(self, wkld: workload.Workload,
                            args: List[Any]) -> workload.Response:
        # Broadcast every workload to every worker on this machine.
        self.broadcast_server.broadcast((wkld, args))

        if wkld.kind == workload.Workload.Kind.TERMINATE:
            # Do not perform health checks once workers have been instructed to terminate.
            self._worker_process_ids = []

        try:
            responses, exception_received = self.broadcast_server.gather_with_polling(
                self._health_check)
        except det.errors.WorkerError:
            if wkld.kind == workload.Workload.Kind.TERMINATE:
                return {}
            raise

        if exception_received:
            raise det.errors.WorkerError("Training process died.")

        # Find the response from the chief worker for the trial (the only non-SkippedWorkload). The
        # chief may report to another container, in which case we will only have SkippedWorkloads.
        chief_worker_response = None  # Optional[workload.Metrics]
        for response in responses:
            if isinstance(response, workload.Skipped):
                continue
            # Any other response must be a Dict[str, Any]-like object.
            check.is_instance(
                response, dict,
                f"Received non-metrics object from worker: {response}")
            # There should only be one chief response.
            check.is_none(chief_worker_response,
                          "Received multiple non-SkippedWorkload messages.")
            chief_worker_response = cast(Dict[str, Any], response)

        # Confirm that if we did not see a chief response, then we are not the chief machine.
        if chief_worker_response is None:
            check.gt(
                self.rendezvous_info.get_rank(),
                0,
                "Received SkippedWorkload message from chief worker.",
            )

        if chief_worker_response is None:
            return workload.Skipped()
        return chief_worker_response
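
The gather logic above reduces a list of per-worker responses to the single chief response. A condensed sketch of the same selection rule (using a plain sentinel class for illustration in place of workload.Skipped):

from typing import Any, Dict, List, Optional, Union

class Skipped:
    pass

def select_chief_response(
    responses: List[Union[Skipped, Dict[str, Any]]]
) -> Optional[Dict[str, Any]]:
    chief: Optional[Dict[str, Any]] = None
    for response in responses:
        if isinstance(response, Skipped):
            continue
        # There should only be one non-Skipped (chief) response.
        assert chief is None, "Received multiple non-SkippedWorkload messages."
        chief = response
    return chief

assert select_chief_response([Skipped(), {"loss": 0.1}, Skipped()]) == {"loss": 0.1}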
Example No. 25
    def __init__(self, trial_inst: det.Trial, *args: Any,
                 **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        check.is_instance(trial_inst, PyTorchTrial,
                          "PyTorchTrialController needs an PyTorchTrial")
        self.trial = cast(PyTorchTrial, trial_inst)
        self.context = cast(pytorch.PyTorchTrialContext, self.context)
        self.context.experimental._set_allgather_fn(self.allgather_metrics)
        self.callbacks = self.trial.build_callbacks()

        self._apply_backwards_compatibility()

        check.gt_eq(
            len(self.context.models),
            1,
            "Must have at least one model. "
            "This might be caused by not wrapping your model with wrap_model()",
        )
        check.gt_eq(
            len(self.context.optimizers),
            1,
            "Must have at least one optimizer. "
            "This might be caused by not wrapping your optimizer with wrap_optimizer()",
        )
        self._check_evaluate_implementation()

        # Validation loader will be undefined on process ranks > 0
        # when the user defines `validate_full_dataset()`.
        self.validation_loader = None  # type: Optional[torch.utils.data.DataLoader]
        self._set_data_loaders()

        # We don't want the training_iterator shuffling values after we load state
        self.training_iterator = iter(self.training_loader)

        # If a load path is provided load weights and restore the data location.
        self._load()

        if self.hvd_config.use:
            hvd.broadcast_parameters(self.context._main_model.state_dict(),
                                     root_rank=0)
            for optimizer in self.context.optimizers:
                hvd.broadcast_optimizer_state(optimizer, root_rank=0)
Example No. 26
    def _launch_evaluate(self) -> Any:
        (
            validation_data,
            validation_steps,
        ) = self._validation_input_manager.get_validation_input_and_num_batches()

        # Starting in TF 2.2 users may define a custom test_step() that does
        # not use the model metrics.
        use_model_metrics = version.parse(tf.__version__) < version.parse("2.2.0")
        evaluate_kwargs = {} if use_model_metrics else {"return_dict": True}

        metrics_values = self.model.evaluate(
            validation_data,
            steps=validation_steps,
            verbose=0,
            callbacks=self.callback_list,
            **evaluate_kwargs,
        )
        logging.debug(f"Worker finished model.evaluate() with metrics: {metrics_values}.")

        # If the model was compiled with metrics=None, metrics_values will be a single value.
        if not isinstance(metrics_values, (tuple, list, dict)):
            metrics_values = (metrics_values,)

        if use_model_metrics:
            metrics = make_logs(self.model, {}, metrics_values, ModeKeys.TEST, prefix="val_")
        else:
            check.is_instance(metrics_values, dict)
            metrics = {f"val_{k}": v for k, v in metrics_values.items()}

        _ = self._validation_input_manager.stop_validation_input_and_get_num_inputs()

        return metrics
Example No. 27
    def __init__(self, trial_inst: det.Trial, *args: Any,
                 **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        check.is_instance(trial_inst, PyTorchTrial,
                          "PyTorchTrialController needs an PyTorchTrial")
        self.trial = cast(PyTorchTrial, trial_inst)
        self._check_evaluate_implementation()

        self.model = self.trial.build_model()

        # Validation loader will be undefined on process ranks > 0
        # when the user defines `validate_full_dataset()`.
        self.validation_loader = None  # type: Optional[torch.utils.data.DataLoader]

        self._set_data_loaders()

        # Track whether a warning logging category has already been issued to the user.
        self.warning_logged = {_WarningLogs.FAILED_MOVING_TO_DEVICE: False}

        self._init_model()
Example No. 28
    def _init_model(self) -> None:
        self._init_paths()

        self.estimator = tf.estimator.Estimator(
            model_fn=self.estimator._model_fn,
            config=self._init_run_config(self.estimator.config),
            params=self.estimator.params if self.estimator.params != {} else None,
            warm_start_from=self.estimator._warm_start_settings,
        )

        check.is_instance(
            self.estimator,
            tf.estimator.Estimator,
            "Please modify your model definition's build_estimator() implementation to return "
            "an instance of `tf.estimator.Estimator`.",
        )
        check.is_instance(
            self.user_train_spec,
            tf.estimator.TrainSpec,
            "Please modify your model definition's build_train_spec() implementation to return "
            "an instance of `tf.estimator.TrainSpec`.",
        )
        check.is_instance(
            self.val_spec,
            tf.estimator.EvalSpec,
            "Please modify your model definition's build_validation_spec() implementation "
            "to return an instance of `tf.estimator.EvalSpec`.",
        )

        all_hooks = [*self.user_train_spec.hooks]

        if self.hvd_config.use:
            all_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

        # It is important that this hook is the last in the list so that if
        # any other hooks need to run _before_ the training step ends, they
        # have their chance.
        all_hooks.append(DeterminedControlHook(self))

        # TODO(DET-834): Separate step ID from data loader state.
        #
        # During warm start, we initialize model weights, optimizer state
        # and input state from the checkpoint, and we set the step ID to
        # 1. Trials typically use the step ID as an index into the data
        # sequence, which means there is an inconsistency between the
        # step ID (as data index) and the optimizer state and input state.
        #
        # In the short term, behave like other trials and reset input
        # state if we are warm started. This will create an inconsistency
        # wrt saved optimizer state.

        # Repeat training dataset so we never run out of data.
        repeating_train_fn = self._check_and_repeat_train_input_fn(
            self.user_train_spec.input_fn)

        self.train_spec = tf.estimator.TrainSpec(input_fn=repeating_train_fn,
                                                 hooks=all_hooks)
        self.eval_spec = tf.estimator.EvalSpec(input_fn=self.val_spec.input_fn,
                                               steps=None)
Example No. 29
    def from_native(context: det.NativeContext, *args: Any,
                    **kwargs: Any) -> det.TrialController:
        check.is_instance(
            context,
            estimator.EstimatorNativeContext,
            "EstimatorTrialController needs an EstimatorSprinkleContext",
        )
        context = cast(estimator.EstimatorNativeContext, context)

        check.true(
            hasattr(context, "estimator") and hasattr(context, "train_spec")
            and hasattr(context, "eval_spec"),
            "Please call TFEstimatorExperiment.train_and_evaluate().",
        )

        return EstimatorTrialController(
            context.estimator,
            context.train_spec,
            context.eval_spec,
            context.serving_input_receiver_fns,
            context,
            *args,
            **kwargs,
        )
Example No. 30
    def _init_model(self) -> None:
        self._init_train_hooks()
        self._init_val_hooks()
        self._init_paths()

        self.estimator = tf.estimator.Estimator(
            model_fn=self._set_default_session_before_building_model(self.estimator._model_fn),
            config=self._init_run_config(self.estimator.config),
            params=self.estimator.params if self.estimator.params != {} else None,
            warm_start_from=self.estimator._warm_start_settings,
        )

        check.is_instance(
            self.estimator,
            tf.estimator.Estimator,
            "Please modify your model definition's build_estimator() implementation to return "
            "an instance of `tf.estimator.Estimator`.",
        )
        check.is_instance(
            self.user_train_spec,
            tf.estimator.TrainSpec,
            "Please modify your model definition's build_train_spec() implementation to return "
            "an instance of `tf.estimator.TrainSpec`.",
        )
        check.is_instance(
            self.val_spec,
            tf.estimator.EvalSpec,
            "Please modify your model definition's build_validation_spec() implementation "
            "to return an instance of `tf.estimator.EvalSpec`.",
        )

        # TODO(DET-834): Separate step ID from data loader state.
        #
        # During warm start, we initialize model weights, optimizer state
        # and input state from the checkpoint, and we set the step ID to
        # 1. Trials typically use the step ID as an index into the data
        # sequence, which means there is an inconsistency between the
        # step ID (as data index) and the optimizer state and input state.
        #
        # In the short term, behave like other trials and reset input
        # state if we are warm started. This will create an inconsistency
        # wrt saved optimizer state.

        # Repeat training dataset so we never run out of data.
        repeating_train_fn = self._check_and_repeat_train_input_fn(
            self.user_train_spec.input_fn)

        self.train_spec = tf.estimator.TrainSpec(input_fn=repeating_train_fn,
                                                 hooks=self.train_hooks)

        self.eval_spec = tf.estimator.EvalSpec(input_fn=self.val_spec.input_fn,
                                               hooks=self._init_val_hooks(),
                                               steps=self.val_spec.steps)