Code example #1
 def metrics_result(self) -> Metrics:
     """Identical to result but disallow workload.Skipped responses."""
     check.is_not_none(self._response,
                       "_respond() was not called by the TrialController.")
     check.is_instance(self._response, dict,
                       "unexpected SkippedWorkload response.")
     return cast(Metrics, self._response)
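All of these examples rely on runtime assertion helpers imported as `check`. As a rough mental model only (an assumption for illustration, not the library's actual implementation, which may use different exception types), helpers with these signatures could be sketched like this:

# Minimal sketch of assertion helpers with the signatures used in these examples.
# This is an assumed illustration, not the real determined check module.
from typing import Any, Optional, Tuple, Type, Union

class CheckFailedError(AssertionError):
    pass

def is_instance(obj: Any, types: Union[Type, Tuple[Type, ...]], msg: Optional[str] = None) -> None:
    if not isinstance(obj, types):
        raise CheckFailedError(msg or f"expected an instance of {types}, got {type(obj)}")

def is_not_none(obj: Any, msg: Optional[str] = None) -> None:
    if obj is None:
        raise CheckFailedError(msg or "expected a value that is not None")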
Code example #2
File: ipc.py  Project: wbwatkinson/determined
    def barrier(self,
                num_connections: int,
                message: Any = None,
                timeout: Optional[int] = None) -> List[Any]:
        """
        This is a one-sided barrier, where the chief blocks until
        all non-chief trial containers have sent a message.
        """
        check.eq(len(self.sockets), 1)
        messages = []  # type: List[Any]
        start_time = time.time()

        for _ in range(num_connections):
            if timeout:
                message_received, barrier_message = self.receive_non_blocking(
                    send_rank=0, deadline=start_time + timeout)

                if not message_received:
                    return messages

            else:
                barrier_message = self.receive_blocking(0)

            check.is_instance(barrier_message, _OneSidedBarrier)
            messages.append(barrier_message.message)
            self.sockets[0].send_pyobj(
                _OneSidedBarrier(message=message))  # type: ignore

        return messages
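For intuition, here is a toy sketch of the same one-sided barrier protocol using plain pyzmq. The socket types, addresses, and payloads are assumptions for illustration, not the library's own wiring: the chief blocks on one receive per expected worker and unblocks each worker with a reply.

import zmq

def chief_barrier(bind_addr: str, num_connections: int) -> list:
    # Chief side: wait for one check-in per worker, replying to each to release it.
    sock = zmq.Context.instance().socket(zmq.REP)
    sock.bind(bind_addr)
    messages = []
    for _ in range(num_connections):
        messages.append(sock.recv_pyobj())  # blocks until some worker checks in
        sock.send_pyobj("release")          # unblock that worker
    return messages

def worker_barrier(connect_addr: str, message: object) -> object:
    # Worker side: send a message, then block until the chief replies.
    sock = zmq.Context.instance().socket(zmq.REQ)
    sock.connect(connect_addr)
    sock.send_pyobj(message)
    return sock.recv_pyobj()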
Code example #3
 def run(self) -> None:
     for w, args, response_func in self.workloads:
         if w.kind == workload.Workload.Kind.RUN_STEP:
             metrics = det.util.make_metrics(
                 num_inputs=None,
                 batch_metrics=[{
                     "loss": 1
                 } for _ in range(w.num_batches)],
             )
             response_func({"metrics": metrics})
         elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
             check.len_eq(args, 0)
             response_func({
                 "metrics": {
                     "validation_metrics": self.validation_metrics
                 }
             })
         elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
             check.len_eq(args, 1)
             check.is_instance(args[0], pathlib.Path)
             path = cast(pathlib.Path, args[0])
             if not path.exists():
                 path.mkdir(parents=True, exist_ok=True)
             with path.joinpath("a_file").open("w") as f:
                 f.write("yup")
             response_func({})
         elif w.kind == workload.Workload.Kind.TERMINATE:
             raise NotImplementedError()
Code example #4
    def __init__(self, trial_inst: det.Trial, *args: Any,
                 **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        check.is_instance(trial_inst, PyTorchTrial,
                          "PyTorchTrialController needs an PyTorchTrial")
        self.trial = cast(PyTorchTrial, trial_inst)
        self.context = cast(pytorch.PyTorchTrialContext, self.context)
        self.callbacks = self.trial.build_callbacks()

        check.gt_eq(
            len(self.context.models),
            1,
            "Must have at least one model. "
            "This might be caused by not wrapping your model with wrap_model()",
        )
        check.gt_eq(
            len(self.context.optimizers),
            1,
            "Must have at least one optimizer. "
            "This might be caused by not wrapping your optimizer with wrap_optimizer()",
        )
        self._check_evaluate_implementation()

        # Validation loader will be undefined on process ranks > 0
        # when the user defines `validate_full_dataset()`.
        self.validation_loader = None  # type: Optional[torch.utils.data.DataLoader]
        self._set_data_loaders()
Code example #5
    def run(self) -> None:
        """
        A basic control loop of the old-style (callback-based) TrialController
        classes.
        """

        for w, args, response_func in self.workloads:
            try:
                if w.kind == workload.Workload.Kind.RUN_STEP:
                    response = self.train_for_step(
                        w.step_id, w.num_batches)  # type: workload.Response
                elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                    response = self.compute_validation_metrics(w.step_id)
                elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                    check.len_eq(args, 1)
                    check.is_instance(args[0], pathlib.Path)
                    path = cast(pathlib.Path, args[0])
                    self.save(path)
                    response = {}
                elif w.kind == workload.Workload.Kind.TERMINATE:
                    self.terminate()
                    response = workload.Skipped()
                else:
                    raise AssertionError("Unexpected workload: {}".format(
                        w.kind))

            except det.errors.SkipWorkloadException:
                response = workload.Skipped()

            response_func(response)
Code example #6
File: determined.py  Project: wbwatkinson/determined
    def create_experiment(
        self,
        config: Union[str, pathlib.Path, Dict],
        model_dir: str,
    ) -> experiment.ExperimentReference:
        check.is_instance(config, (str, pathlib.Path, dict),
                          "config parameter must be dictionary or path")
        if isinstance(config, str):
            with open(config) as f:
                experiment_config = yaml.safe_load(f)
        elif isinstance(config, pathlib.Path):
            with config.open() as f:
                experiment_config = yaml.safe_load(f)
        elif isinstance(config, Dict):
            experiment_config = config

        model_context = _path_to_files(pathlib.Path(model_dir))

        experiment_request = V1CreateExperimentRequest(
            model_definition=model_context,
            config=yaml.safe_dump(experiment_config),
        )
        experiment_response = self._internal.determined_create_experiment(
            experiment_request)
        return experiment.ExperimentReference(
            experiment_response.experiment.id,
            self._session._master,
            self._experiments,
        )
Code example #7
File: _pytorch_trial.py  Project: shiyuann/determined
 def _run(self) -> None:
     for w, args, response_func in self.workloads:
         if w.kind == workload.Workload.Kind.RUN_STEP:
             try:
                 response_func(
                     util.wrap_metrics(
                         self._train_for_step(
                             w.step_id,
                             w.num_batches,
                             w.total_batches_processed,
                         ),
                         self.context.get_stop_requested(),
                         invalid_hp=False,
                         init_invalid_hp=False,
                     )
                 )
             except det.InvalidHP as e:
                 logging.info(
                     "Invalid hyperparameter exception in trial train step: {}".format(e)
                 )
                 response_func(
                     util.wrap_metrics(
                         {},
                         self.context.get_stop_requested(),
                         invalid_hp=True,
                         init_invalid_hp=False,
                     )
                 )
         elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
             try:
                 response_func(
                     util.wrap_metrics(
                         self._compute_validation_metrics(),
                         self.context.get_stop_requested(),
                         invalid_hp=False,
                         init_invalid_hp=False,
                     )
                 )
             except det.InvalidHP as e:
                 logging.info(
                     "Invalid hyperparameter exception in trial validation step: {}".format(e)
                 )
                 response_func(
                     util.wrap_metrics(
                         {},
                         self.context.get_stop_requested(),
                         invalid_hp=True,
                         init_invalid_hp=False,
                     )
                 )
         elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
             check.eq(len(args), 1)
             check.is_instance(args[0], pathlib.Path)
             path = cast(pathlib.Path, args[0])
             response_func(self._save(path))
         elif w.kind == workload.Workload.Kind.TERMINATE:
             response_func({} if self.is_chief else workload.Skipped())
             break
         else:
             raise AssertionError("Unexpected workload: {}".format(w.kind))
Code example #8
    def safe_start(self, health_check: Callable[[], None]) -> None:
        """
        Broadcast Hello messages over and over until all clients respond with a Hello message.

        The reason for this is that the only way to be 100% confident that a subscriber has
        connected is for it to actually receive a message over the pub/sub connection.

        After each client sees its first Hello, it will send a single Hello message to the
        server.

        After all connections have been made, the server will broadcast a FinalHello.
        """

        connections_made = 0
        while connections_made < self._num_connections:
            # Send a Hello.
            self._pub_socket.send_pyobj(_HelloMessage())

            # Check for an incoming connection.
            if self._pull_socket.poll(50) == 0:
                health_check()
                continue

            obj = self._pull_socket.recv_pyobj()
            check.is_instance(obj, _HelloMessage, "got non-_HelloMessage in server safe_start")
            connections_made += 1

        self._pub_socket.send_pyobj(_FinalHelloMessage())
Code example #9
 def as_batches(
     self,
     batches: Optional[int] = None,
     records: Optional[int] = None,
     epochs: Optional[int] = None,
 ) -> int:
     if sum((batches is not None, records is not None, epochs
             is not None)) != 1:
         raise ValueError(
             f"invalid length: batches={batches} records={records} epochs={epochs}"
         )
     if batches is not None:
         return batches
     if records is not None:
         check.gt(self.global_batch_size, 0,
                  "global_batch_size must be positive")
         return max(records // self.global_batch_size, 1)
     if epochs is not None:
         check.is_instance(self.records_per_epoch, int,
                           "length must be an integer")
         assert self.records_per_epoch is not None
         check.gt(self.global_batch_size, 0,
                  "global_batch_size must be positive")
         return max(
             (epochs * self.records_per_epoch) // self.global_batch_size, 1)
     # Make mypy happy.
     raise ValueError("invalid length")
Code example #10
    def from_trial(
        trial_inst: det.Trial,
        context: det.TrialContext,
        env: det.EnvContext,
        *args: Any,
        **kwargs: Any,
    ) -> det.TrialController:
        check.is_instance(
            context,
            estimator.EstimatorTrialContext,
            "EstimatorTrialController needs an EstimatorTrialContext",
        )
        context = cast(estimator.EstimatorTrialContext, context)

        check.is_instance(trial_inst, EstimatorTrial,
                          "EstimatorTrialController needs an EstimatorTrial")
        trial_inst = cast(EstimatorTrial, trial_inst)

        return EstimatorTrialController(
            trial_inst.build_estimator(),
            trial_inst.build_train_spec(),
            trial_inst.build_validation_spec(),
            trial_inst.build_serving_input_receiver_fns(),
            context,
            env,
            *args,
            **kwargs,
        )
Code example #11
    def _calculate_batch_sizes(self) -> Tuple[int, int]:
        if "global_batch_size" not in self.hparams.keys():
            raise AssertionError(
                "Please specify `global_batch_size` under `hyperparameters` "
                "in experiment config.")

        if "batch_size" in self.hparams.keys():
            logging.warning(
                "Use `global_batch_size` not `batch_size` under `hyperparameters` "
                "in experiment config.")

        global_batch_size = self.hparams["global_batch_size"]
        check.is_instance(global_batch_size, int,
                          "`global_batch_size` hparam must be an int.")
        global_batch_size = cast(int, global_batch_size)

        if self.experiment_config.native_parallel_enabled():
            return global_batch_size, global_batch_size

        # Configure batch sizes.
        slots_per_trial = self.experiment_config.slots_per_trial()
        if global_batch_size < slots_per_trial:
            raise AssertionError(
                "Please set the `global_batch_size` hyperparameter to be greater or equal to the "
                f"number of slots. Current batch_size: {global_batch_size}, slots_per_trial: "
                f"{slots_per_trial}.")

        per_gpu_batch_size = global_batch_size // slots_per_trial
        effective_batch_size = per_gpu_batch_size * slots_per_trial
        if effective_batch_size != global_batch_size:
            logging.warning(
                f"`global_batch_size` changed from {global_batch_size} to {effective_batch_size} "
                f"to divide equally across {slots_per_trial} slots.")

        return per_gpu_batch_size, effective_batch_size
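As a concrete illustration of the split computed above (the numbers are made up): with global_batch_size=10 and slots_per_trial=4, the per-slot batch size rounds down and the effective batch size shrinks, which is exactly the case that triggers the warning.

# Illustrative values only.
global_batch_size = 10
slots_per_trial = 4

per_gpu_batch_size = global_batch_size // slots_per_trial    # 2
effective_batch_size = per_gpu_batch_size * slots_per_trial  # 8 != 10, so a warning is logged

assert (per_gpu_batch_size, effective_batch_size) == (2, 8)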
Code example #12
File: _tf_keras_trial.py  Project: hkang1/determined
    def from_trial(
        trial_inst: det.Trial,
        context: det.TrialContext,
        env: det.EnvContext,
        workloads: workload.Stream,
        load_path: Optional[pathlib.Path],
        rendezvous_info: det.RendezvousInfo,
        hvd_config: horovod.HorovodContext,
    ) -> det.TrialController:
        check.is_instance(
            context, keras.TFKerasTrialContext,
            "TFKerasTrialController needs a TFKerasTrialContext")
        context = cast(keras.TFKerasTrialContext, context)

        check.is_instance(trial_inst, TFKerasTrial,
                          "TFKerasTrialController needs a TFKerasTrial")
        trial = cast(TFKerasTrial, trial_inst)

        session = TFKerasTrialController._configure_session(
            env, hvd_config, trial.session_config())

        training_data = keras._adapt_data_from_data_loader(
            input_data=trial.build_training_data_loader(),
            batch_size=context.get_per_slot_batch_size(),
        )

        validation_data = keras._adapt_data_from_data_loader(
            input_data=trial.build_validation_data_loader(),
            batch_size=context.get_per_slot_batch_size(),
        )

        trial.build_model()
        check.is_not_none(context.model, "Please call wrap_model(...).")

        check.is_not_none(context.compile_args,
                          "Please call model.compile(...).")
        compile_args = cast(inspect.BoundArguments, context.compile_args)

        TFKerasTrialController.compile_model(context=context,
                                             compile_args=compile_args,
                                             env=env,
                                             hvd_config=hvd_config)

        tf_keras_callbacks = trial.keras_callbacks()

        return TFKerasTrialController(
            context.model,
            session,
            keras.TFKerasTrainConfig(training_data, validation_data,
                                     tf_keras_callbacks),
            context,
            env,
            workloads,
            load_path,
            rendezvous_info,
            hvd_config,
        )
Code example #13
File: ipc.py  Project: wbwatkinson/determined
 def barrier(self, message: Any = None) -> Any:
     """
     This is a one-sided barrier, where the chief blocks until
     all non-chief trial containers have sent a message.
     """
     self.socket.send_pyobj(_OneSidedBarrier(message=message))
     barrier_message = self.socket.recv_pyobj()
     check.is_instance(barrier_message, _OneSidedBarrier)
     return barrier_message.message
Code example #14
    def _send_recv_workload(self, wkld: workload.Workload,
                            args: List[Any]) -> workload.Response:
        # Broadcast every workload to every worker on this machine.
        self.broadcast_server.broadcast((wkld, args))

        if wkld.kind == workload.Workload.Kind.TERMINATE:
            # Do not perform health checks once workers have been instructed to terminate.
            self._worker_process_ids = []

        try:
            responses, exception_received = self.broadcast_server.gather_with_polling(
                self._health_check)
        except det.errors.WorkerError:
            if wkld.kind == workload.Workload.Kind.TERMINATE:
                return {}
            raise

        if exception_received:
            raise det.errors.WorkerError("Training process died.")

        # Find the response from the chief worker for the trial (the only non-SkippedWorkload). The
        # chief may report to another container, in which case we will only have SkippedWorkloads.
        chief_worker_response = None  # Optional[workload.Metrics]
        for response in responses:
            if isinstance(response, workload.Skipped):
                continue
            # Any other response must be a Dict[str, Any]-like object.
            check.is_instance(
                response, dict,
                f"Received non-metrics object from worker: {response}")
            # There should only be one chief response.
            # Special-case InvalidHP messages.
            if chief_worker_response != {
                    "metrics": {},
                    "stop_requested": False,
                    "invalid_hp": True,
                    "init_invalid_hp": False,
            }:
                check.is_none(
                    chief_worker_response,
                    "Received multiple non-SkippedWorkload messages.")
            chief_worker_response = cast(Dict[str, Any], response)

        # Confirm that if we did not see a chief response, then we are not the chief machine.
        if chief_worker_response is None:
            check.gt(
                self.rendezvous_info.get_rank(),
                0,
                "Received SkippedWorkload message from chief worker.",
            )

        if chief_worker_response is None:
            return workload.Skipped()
        return chief_worker_response
Code example #15
    def from_trial(
        cls: Type["TFKerasTrialController"],
        trial_inst: det.Trial,
        context: det.TrialContext,
        env: det.EnvContext,
        workloads: Optional[workload.Stream] = None,
    ) -> det.TrialController:
        check.is_instance(
            context, keras.TFKerasTrialContext, "TFKerasTrialController needs a TFKerasTrialContext"
        )
        context = cast(keras.TFKerasTrialContext, context)

        check.is_instance(trial_inst, TFKerasTrial, "TFKerasTrialController needs a TFKerasTrial")
        trial = cast(TFKerasTrial, trial_inst)

        # Keras only supports horovod backend for distributed training
        session = cls._configure_session(
            env, trial.session_config(), use_horovod=context.distributed.size > 1
        )

        training_data = keras._adapt_data_from_data_loader(
            input_data=trial.build_training_data_loader(),
            batch_size=context.get_per_slot_batch_size(),
        )

        validation_data = keras._adapt_data_from_data_loader(
            input_data=trial.build_validation_data_loader(),
            batch_size=context.get_per_slot_batch_size(),
        )

        trial.build_model()
        check.is_not_none(context.model, "Please call wrap_model(...).")

        check.is_not_none(context.compile_args, "Please call model.compile(...).")
        compile_args = cast(inspect.BoundArguments, context.compile_args)

        cls.compile_model(context=context, compile_args=compile_args, env=env)

        tf_keras_callbacks = trial.keras_callbacks()

        return cls(
            context.model,
            session,
            keras.TFKerasTrainConfig(training_data, validation_data, tf_keras_callbacks),
            trial,
            context,
            env,
            workloads,
        )
Code example #16
File: _tf_keras_trial.py  Project: hkang1/determined
    def from_native(
        context: det.NativeContext,
        env: det.EnvContext,
        workloads: workload.Stream,
        load_path: Optional[pathlib.Path],
        rendezvous_info: det.RendezvousInfo,
        hvd_config: horovod.HorovodContext,
    ) -> det.TrialController:
        check.is_instance(
            context,
            keras.TFKerasNativeContext,
            "TFKerasTrialController needs a TFKerasSprinkleContext",
        )
        context = cast(keras.TFKerasNativeContext, context)

        check.is_not_none(context.model, "Please call wrap_model(...).")

        check.is_not_none(context.compile_args,
                          "Please call model.compile(...).")
        check.is_not_none(
            context.train_config,
            "Please call model.fit(...) or model.fit_generator(...).")

        # For the Native API, we would break the user's model if we changed the session
        # right now, so we have to trust the user did not modify what we set previously.
        #
        # TODO(ryan): Fix this, probably with a function for configuring the backend session.
        session = tf.compat.v1.keras.backend.get_session()

        compile_args = cast(inspect.BoundArguments, context.compile_args)
        train_config = cast(keras.TFKerasTrainConfig, context.train_config)

        TFKerasTrialController.compile_model(context=context,
                                             compile_args=compile_args,
                                             env=env,
                                             hvd_config=hvd_config)

        return TFKerasTrialController(
            context.model,
            session,
            train_config,
            context,
            env,
            workloads,
            load_path,
            rendezvous_info,
            hvd_config,
        )
Code example #17
    def _do_startup_message_sequence(self) -> None:
        # Wait for a ConnectedMessage from every worker.
        responses, exception_received = self.broadcast_server.gather_with_polling(
            self._health_check)

        if exception_received:
            raise det.errors.WorkerError("Training process died.")

        for response in responses:
            check.is_instance(
                response,
                ipc.ConnectedMessage,
                f"Did not receive ConnectedMessage from worker. Got: {response}",
            )
            response = cast(ipc.ConnectedMessage, response)
            self._worker_process_ids.append(response.process_id)
Code example #18
    def safe_start(self) -> None:
        """
        See ZMQBroadcastServer.safe_start().
        """

        # Get the first HelloMessage to guarantee our SUB socket is connected.
        obj = self._sub_socket.recv_pyobj()
        check.is_instance(obj, _HelloMessage, "got non-HelloMessage in client.safe_start()")

        # Send our own _HelloMessage.
        self._push_socket.send_pyobj(_HelloMessage())

        while True:
            # Discard all further Hellos until the FinalHello.
            obj = self._sub_socket.recv_pyobj()
            if isinstance(obj, _FinalHelloMessage):
                break
            check.is_instance(obj, _HelloMessage, "got non-HelloMessage in client.safe_start()")
Code example #19
    def create_experiment(
        self,
        config: Union[str, pathlib.Path, Dict],
        model_dir: Union[str, pathlib.Path],
    ) -> experiment.ExperimentReference:
        """
        Create an experiment with config parameters and model directory. The function
        returns :class:`~determined.experimental.ExperimentReference` of the experiment.

        Arguments:
            config(string, pathlib.Path, dictionary): experiment config filename (.yaml)
                or a dict.
            model_dir(string): directory containing model definition.
        """
        check.is_instance(
            config, (str, pathlib.Path, dict), "config parameter must be dictionary or path"
        )
        if isinstance(config, str):
            with open(config) as f:
                experiment_config = util.safe_load_yaml_with_exceptions(f)
        elif isinstance(config, pathlib.Path):
            with config.open() as f:
                experiment_config = util.safe_load_yaml_with_exceptions(f)
        elif isinstance(config, Dict):
            experiment_config = config

        if isinstance(model_dir, str):
            model_dir = pathlib.Path(model_dir)

        model_context, _ = context.read_context(model_dir)

        resp = self._session.post(
            "/api/v1/experiments",
            body={
                "config": yaml.safe_dump(experiment_config),
                "model_definition": model_context,
            },
        )

        exp_id = _CreateExperimentResponse(resp.json()).id
        exp = experiment.ExperimentReference(exp_id, self._session)
        exp.activate()

        return exp
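A hypothetical call of this method, assuming `client` is an already-constructed instance of the class above and that the config file and model directory shown here (both made-up names) exist:

# Hypothetical usage; `client`, "const.yaml", and "./model_def" are assumptions for illustration.
exp = client.create_experiment(config="const.yaml", model_dir="./model_def")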
Code example #20
    def control_loop(self) -> None:
        for wkld, args, response_func in self.estimator_trial_controller.workloads:
            logging.debug(f"Received wkld {wkld.kind} with args {args}.")

            if wkld.kind == workload.Workload.Kind.RUN_STEP:
                # Store values for the training loop.
                self.num_batches = wkld.num_batches
                self.train_response_func = response_func
                # Break out of the control loop so that the train process
                # re-enters the train_and_evaluate() loop.
                break
            elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                try:
                    response_func(
                        det.util.wrap_metrics(
                            self._compute_validation_metrics(),
                            self.estimator_trial_controller.context.
                            get_stop_requested(),
                            invalid_hp=False,
                            init_invalid_hp=False,
                        ))
                except det.InvalidHP as e:
                    logging.info(
                        "Invalid hyperparameter exception in trial validation step: {}"
                        .format(e))
                    response_func(
                        det.util.wrap_metrics(
                            {},
                            self.estimator_trial_controller.context.
                            get_stop_requested(),
                            invalid_hp=True,
                            init_invalid_hp=False,
                        ))
            elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                check.len_eq(args, 1)
                check.is_instance(args[0], pathlib.Path)
                path = cast(pathlib.Path, args[0])
                response_func(self._checkpoint_model(path))
            elif wkld.kind == workload.Workload.Kind.TERMINATE:
                self.estimator_trial_controller.exit_response_func = response_func
                raise det.errors.WorkerFinishedGracefully("Exiting normally.")
            else:
                raise AssertionError(f"Unknown wkld kind {wkld.kind}.")
Code example #21
 def _control_loop(self) -> None:
     for wkld, args, response_func in self.workloads:
         logging.debug(f"Received wkld {wkld.kind} with args {args}.")
         if wkld.kind == workload.Workload.Kind.RUN_STEP:
             # Configure the state for a training step.
             self.train_response_func = response_func
             self.train_workload_batches = 0
             self.train_workload_metrics = []
             self.train_workload_len = wkld.num_batches
             self.multiplexer.set_batches_requested(wkld.num_batches)
             break
         elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
             try:
                 response_func(
                     det.util.wrap_metrics(
                         self._compute_validation_metrics(),
                         self.context.get_stop_requested(),
                         invalid_hp=False,
                         init_invalid_hp=False,
                     ))
             except det.InvalidHP as e:
                 logging.info(
                     "Invalid hyperparameter exception in trial validation step: {}"
                     .format(e))
                 response_func(
                     util.wrap_metrics(
                         {},
                         self.context.get_stop_requested(),
                         invalid_hp=True,
                         init_invalid_hp=False,
                     ))
         elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
             check.len_eq(args, 1)
             check.is_instance(args[0], pathlib.Path)
             path = cast(pathlib.Path, args[0])
             response_func(self._save_checkpoint(path))
         elif wkld.kind == workload.Workload.Kind.TERMINATE:
             response_func({} if self.is_chief else workload.Skipped())
             self.multiplexer._corrected_train_end()
             raise det.errors.WorkerFinishedGracefully
         else:
             raise AssertionError(f"Unknown workload kind {wkld.kind}.")
Code example #22
    def __init__(self, trial_inst: det.Trial, *args: Any,
                 **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        check.is_instance(trial_inst, PyTorchTrial,
                          "PyTorchTrialController needs an PyTorchTrial")
        self.trial = cast(PyTorchTrial, trial_inst)
        self.context = cast(pytorch.PyTorchTrialContext, self.context)
        self.context.experimental._set_allgather_fn(self.allgather_metrics)
        self.callbacks = self.trial.build_callbacks()

        check.gt_eq(
            len(self.context.models),
            1,
            "Must have at least one model. "
            "This might be caused by not wrapping your model with wrap_model()",
        )
        check.gt_eq(
            len(self.context.optimizers),
            1,
            "Must have at least one optimizer. "
            "This might be caused by not wrapping your optimizer with wrap_optimizer()",
        )
        self._check_evaluate_implementation()

        # Validation loader will be undefined on process ranks > 0
        # when the user defines `validate_full_dataset()`.
        self.validation_loader = None  # type: Optional[torch.utils.data.DataLoader]
        self._set_data_loaders()

        # We don't want the training_iterator shuffling values after we load state
        self.training_iterator = iter(self.training_loader)

        # If a load path is provided load weights and restore the data location.
        self._load()

        if self.hvd_config.use:
            hvd.broadcast_parameters(self.context._main_model.state_dict(),
                                     root_rank=0)
            for optimizer in self.context.optimizers:
                hvd.broadcast_optimizer_state(optimizer, root_rank=0)
Code example #23
    def __init__(self, trial_inst: det.Trial, *args: Any,
                 **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        check.is_instance(trial_inst, PyTorchTrial,
                          "PyTorchTrialController needs an PyTorchTrial")
        self.trial = cast(PyTorchTrial, trial_inst)
        self.context = cast(pytorch.PyTorchTrialContext, self.context)
        self.context._set_determined_profiler(self.prof)
        if torch.cuda.is_available():
            self.prof._set_sync_device(self._sync_device)
        self.callbacks = self.trial.build_callbacks()

        check.gt_eq(
            len(self.context.models),
            1,
            "Must have at least one model. "
            "This might be caused by not wrapping your model with wrap_model()",
        )
        check.gt_eq(
            len(self.context.optimizers),
            1,
            "Must have at least one optimizer. "
            "This might be caused by not wrapping your optimizer with wrap_optimizer()",
        )
        self._check_evaluate_implementation()

        self.wlsq = None  # type: Optional[layers.WorkloadSequencer]
        if self.workloads is None:
            self.workloads, self.wlsq = layers.make_compatibility_workloads(
                self.context._core,
                self.env,
                self.context.get_global_batch_size(),
            )

        self.steps_completed = self.env.steps_completed

        # Currently only horovod and torch backends are supported for distributed training
        if self.context.distributed.size > 1:
            assert (self.use_horovod or self.use_torch
                    ), "Must use horovod or torch for distributed training"
Code example #24
    def create_experiment(
        self,
        config: Union[str, pathlib.Path, Dict],
        model_dir: str,
    ) -> experiment.ExperimentReference:
        """
        Create an experiment with config parameters and model directory. The function
        returns :class:`~determined.experimental.ExperimentReference` of the experiment.

        Arguments:
            config(string, pathlib.Path, dictionary): experiment config filename (.yaml)
                or a dict.
            model_dir(string): directory containing model definition.
        """
        check.is_instance(config, (str, pathlib.Path, dict),
                          "config parameter must be dictionary or path")
        if isinstance(config, str):
            with open(config) as f:
                experiment_config = util.safe_load_yaml_with_exceptions(f)
        elif isinstance(config, pathlib.Path):
            with config.open() as f:
                experiment_config = util.safe_load_yaml_with_exceptions(f)
        elif isinstance(config, Dict):
            experiment_config = config

        model_context = _path_to_files(pathlib.Path(model_dir))

        experiment_request = V1CreateExperimentRequest(
            model_definition=model_context,
            config=yaml.safe_dump(experiment_config),
        )
        experiment_response = self._internal.determined_create_experiment(
            experiment_request)
        return experiment.ExperimentReference(
            experiment_response.experiment.id,
            self._session._master,
            self._experiments,
        )
Code example #25
    def from_native(context: det.NativeContext, *args: Any,
                    **kwargs: Any) -> det.TrialController:
        check.is_instance(
            context,
            estimator.EstimatorNativeContext,
            "EstimatorTrialController needs an EstimatorSprinkleContext",
        )
        context = cast(estimator.EstimatorNativeContext, context)

        check.true(
            hasattr(context, "estimator") and hasattr(context, "train_spec")
            and hasattr(context, "eval_spec"),
            "Please call TFEstimatorExperiment.train_and_evaluate().",
        )

        return EstimatorTrialController(
            context.estimator,
            context.train_spec,
            context.eval_spec,
            context.serving_input_receiver_fns,
            context,
            *args,
            **kwargs,
        )
Code example #26
    def _init_model(self) -> None:
        self._init_train_hooks()
        self._init_val_hooks()
        self._init_paths()

        self.estimator = tf.estimator.Estimator(
            model_fn=self._set_default_session_before_building_model(
                self.estimator._model_fn),
            config=self._init_run_config(self.estimator.config),
            params=self.estimator.params
            if self.estimator.params != {} else None,
            warm_start_from=self.estimator._warm_start_settings,
        )

        check.is_instance(
            self.estimator,
            tf.estimator.Estimator,
            "Please modify your model definition's build_estimator() implementation to return "
            "an instance of `tf.estimator.Estimator`.",
        )
        check.is_instance(
            self.user_train_spec,
            tf.estimator.TrainSpec,
            "Please modify your model definition's build_train_spec() implementation to return "
            "an instance of `tf.estimator.TrainSpec`.",
        )
        check.is_instance(
            self.val_spec,
            tf.estimator.EvalSpec,
            "Please modify your model definition's build_validation_spec() implementation "
            "to return an instance of `tf.estimator.EvalSpec`.",
        )

        # TODO(DET-834): Separate step ID from data loader state.
        #
        # During warm start, we initialize model weights, optimizer state
        # and input state from the checkpoint, and we set the step ID to
        # 1. Trials typically use the step ID as an index into the data
        # sequence, which means there is an inconsistency between the
        # step ID (as data index) and the optimizer state and input state.
        #
        # In the short term, behave like other trials and reset input
        # state if we are warm started. This will create an inconsistency
        # wrt saved optimizer state.

        # Repeat training dataset so we never run out of data.
        repeating_train_fn = self._check_and_repeat_train_input_fn(
            self.user_train_spec.input_fn)

        self.train_spec = tf.estimator.TrainSpec(input_fn=repeating_train_fn,
                                                 hooks=self.train_hooks)

        self.eval_spec = tf.estimator.EvalSpec(input_fn=self.val_spec.input_fn,
                                               hooks=self._init_val_hooks(),
                                               steps=self.val_spec.steps)
Code example #27
 def from_file(path: pathlib.Path) -> "WorkerProcessContext":
     with path.open(mode="rb") as f:
         obj = pickle.load(f)
     check.is_instance(obj, WorkerProcessContext, "did not find WorkerProcessContext in file")
     return cast(WorkerProcessContext, obj)
Code example #28
File: _tf_keras_trial.py  Project: hkang1/determined
    def _launch_evaluate(self) -> Any:
        validation_data = self.validation_data
        steps = None

        if isinstance(validation_data, tf.keras.utils.Sequence):
            # Calculate the length of our validation shard.
            steps = len(validation_data)
            if self.context.distributed.get_size() > 1:
                size = self.context.distributed.get_size()
                rank = self.context.distributed.get_rank()
                steps = steps // size + (1 if steps % size > rank else 0)

            # Handle args from fit(): shuffle, workers, use_multiprocessing, and max_queue_size.
            enqueuer = keras._build_enqueuer(
                sequence=validation_data,
                workers=self.context._fit_workers,
                use_multiprocessing=self.context._fit_use_multiprocessing,
                max_queue_size=self.context._fit_max_queue_size,
                shard_rank=self.context.distributed.get_rank(),
                num_shards=self.context.distributed.get_size(),
                repeat=False,
                shuffle=False,
                shuffle_seed=0,
                prior_batches_trained=0,
            )
            enqueuer.start()
            self.enqueuers.append(enqueuer)
            validation_data = enqueuer.data()

        if isinstance(validation_data, tf.data.Dataset):
            # Handle validation_steps, which in Keras only applies to tf.data.Datasets.
            steps = self.context._fit_validation_steps

        # Starting in TF 2.2, users may define a custom test_step() that does
        # not use the model metrics.
        use_model_metrics = not (
            version.parse(tf.__version__) >= version.parse("2.2.0")
            and is_tf2_enabled() and tf.executing_eagerly())
        evaluate_kwargs = {} if use_model_metrics else {"return_dict": True}

        if self.env.test_mode:
            steps = 1

        metrics_values = self.model.evaluate(
            validation_data,
            callbacks=self.callback_list,
            steps=steps,
            verbose=0,
            workers=0,
            **evaluate_kwargs,
        )
        logging.debug(
            f"Worker finished model.evaluate() with metrics: {metrics_values}."
        )

        # Clean up the enqueuer if we started one.
        if isinstance(self.validation_data, tf.keras.utils.Sequence):
            enqueuer.stop()
            self.enqueuers.remove(enqueuer)

            # A special side-effect of converting the keras sequence to a generator and passing
            # steps explicitly is that keras will exit our generator after N steps and the
            # Sequence.on_epoch_end() that normally runs after the last yield won't run at all
            # because the fit loop will call next() exactly `steps` times.  So we try to match the
            # exact keras behavior by manually calling on_epoch_end() here.
            self.validation_data.on_epoch_end()

        # If the model was compiled with metrics=None, metrics_values will be a single value.
        if not isinstance(metrics_values, (tuple, list, dict)):
            metrics_values = (metrics_values, )

        if use_model_metrics:
            metrics = make_logs(self.model, {},
                                metrics_values,
                                ModeKeys.TEST,
                                prefix="val_")
        else:
            check.is_instance(metrics_values, dict)
            metrics = {f"val_{k}": v for k, v in metrics_values.items()}

        return metrics
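The shard-size expression near the top of this example, steps // size + (1 if steps % size > rank else 0), splits the validation steps as evenly as possible across workers; a quick check with made-up numbers:

# Illustration of the sharding formula with illustrative numbers only.
steps, size = 10, 4
shard_sizes = [steps // size + (1 if steps % size > rank else 0) for rank in range(size)]
assert shard_sizes == [3, 3, 2, 2] and sum(shard_sizes) == steps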
Code example #29
File: _tf_keras_trial.py  Project: hkang1/determined
    def _configure_callbacks(self, user_callbacks: Optional[List]) -> None:
        """
        If we pass a callbacks parameter to model.fit() or model.evaluate() which is a
        pre-constructed CallbackList, Keras will not alter it.  We can use this property to
        configure the exact callback order that we want in our system.

        The implementation is based closely on the real
        tf.keras.callbacks.configure_callbacks(), with the following differences:

          - We always assume we have the original Callbacks list.
          - We prepend and append additional Determined and Horovod callbacks.
          - We create a det.keras.CallbackList instead of the normal tf.keras one.
        """

        callbacks = user_callbacks or []
        check.is_instance(
            callbacks,
            list,
            "the callbacks parameter of model.fit() or model.eval() must be a list of Callbacks",
        )

        if self.env.experiment_config.get_records_per_epoch() is None:
            for cb in callbacks:
                if util.is_overridden(
                        cb.on_epoch_end,
                        tf.keras.callbacks.Callback) and not getattr(
                            cb, "_skip_epoch_end_check", False):
                    if isinstance(cb, keras.callbacks.Callback):
                        # New callbacks must obey the rules.
                        raise AssertionError(
                            "it is unsupported to use a Callback that defines on_epoch_end "
                            f"({type(cb).__name__}) without setting the records_per_epoch value "
                            "in the experiment config")
                    else:
                        # Pre-existing callbacks only get a warning.
                        logging.warning(
                            "It is unsupported to use a Callback that defines on_epoch_end "
                            f"({type(cb).__name__})without setting the records_per_epoch value in "
                            "the experiment config. Training will continue but on_epoch_end will "
                            "never be called.")

        # Standard post-callback from the real configure_callbacks().
        # Note that we are not including BaseLogger since it is only for averaging metrics over an
        # entire epoch, and we don't report any metrics in on_epoch_end at all.
        self.model.history = keras.callbacks._DeterminedHistory()
        callbacks = callbacks + [self.model.history]

        if self.context._fit_verbose:
            # Our implementation of verbose=True.
            callbacks = [keras.callbacks._DeterminedProgress()] + callbacks

        # Calculate batches per epoch.  We can only handle batches per epoch, not records per epoch,
        # because we would have to communicate after every batch to know how many records were in
        # each batch on each worker in order to trigger on_epoch_end callbacks correctly.
        batches_per_epoch = None
        records_per_epoch = self.env.experiment_config.get_records_per_epoch()
        if records_per_epoch is not None:
            batches_per_epoch = records_per_epoch // self.context.get_global_batch_size()

        # We wrap all of the callbacks in a single Multiplexer.
        self.multiplexer = TrialControllerMultiplexer(
            self,
            callbacks,
            self.is_chief,
            self.batch_size,
            batches_per_epoch,
            self.multiplexer_load_state,
        )
        callbacks = [self.multiplexer]

        if self.hvd_config.use:
            # Horovod synchronization of initial variables should happen even before we enter our
            # control loop, in case we have an initial validation requested.
            callbacks = [hvd.callbacks.BroadcastGlobalVariablesCallback(0)
                         ] + callbacks

        # The remainder of Determined control logic is done with a custom CallbackList
        self.callback_list = CallbackList(callbacks)

        # Disable timing of callbacks in some versions of keras. This can fail in some corner-cases
        # because CallbackList is not designed to allow some callbacks to call other callbacks, and
        # they can interact very poorly.
        if hasattr(self.callback_list, "_timing"):
            self.callback_list._timing["on_train_batch_begin"] = True
            self.callback_list._timing["on_train_batch_end"] = True
            self.callback_list._timing["on_test_batch_begin"] = True
            self.callback_list._timing["on_test_batch_end"] = True
            self.callback_list._timing["on_predict_batch_begin"] = True
            self.callback_list._timing["on_predict_batch_end"] = True

        # callback_model is the model given to callbacks, where we should be checking for
        # stop_training.  In horovod dtrain or non-dtrain, it should always be self.model.
        callback_model = self.model._get_callback_model()
        self.callback_list.set_model(callback_model)

        # Fill in bogus values for most of these... some of them are very complex to calculate.
        set_callback_parameters(
            self.callback_list,
            self.model,
            do_validation=False,
            batch_size=self.batch_size,
            epochs=None,
            steps_per_epoch=None,
            samples=None,
            verbose=False,
            mode=ModeKeys.TRAIN,
        )

        self.callback_list.model.stop_training = False
Code example #30
    def _compute_validation_metrics(self) -> workload.Response:
        self.context.reset_reducers()
        # Set the behavior of certain layers (e.g., dropout) that are
        # different between training and inference.
        for model in self.context.models:
            model.eval()

        for callback in self.callbacks.values():
            if util.is_overridden(callback.on_validation_step_start,
                                  pytorch.PyTorchCallback):
                logging.warning("on_validation_step_start is now deprecated, "
                                "please use on_validation_start instead")
                callback.on_validation_step_start()

        for callback in self.callbacks.values():
            callback.on_validation_start()

        num_inputs = 0
        metrics = {}  # type: Dict[str, Any]

        if self._evaluate_batch_defined():
            keys = None
            batch_metrics = []

            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            check.gt(len(self.validation_loader), 0)
            for callback in self.callbacks.values():
                callback.on_validation_epoch_start()
            for idx, batch in enumerate(self.validation_loader):
                batch = self.context.to_device(batch)
                num_inputs += self.trial.get_batch_length(batch)

                if has_param(self.trial.evaluate_batch, "batch_idx", 2):
                    vld_metrics = self.trial.evaluate_batch(batch=batch,
                                                            batch_idx=idx)
                else:
                    vld_metrics = self.trial.evaluate_batch(
                        batch=batch)  # type: ignore
                # Verify validation metric names are the same across batches.
                if keys is None:
                    keys = vld_metrics.keys()
                else:
                    check.eq(
                        keys,
                        vld_metrics.keys(),
                        "Validation metric names must match across all batches of data.",
                    )
                check.is_instance(
                    vld_metrics,
                    dict,
                    "validation_metrics() must return a "
                    "dictionary of string names to Tensor "
                    "metrics",
                )
                # TODO: For performance perform -> cpu() only at the end of validation.
                batch_metrics.append(
                    self._convert_metrics_to_numpy(vld_metrics))
                if self.env.test_mode:
                    break

            for callback in self.callbacks.values():
                callback.on_validation_epoch_end(batch_metrics)

            metrics = self._reduce_metrics(
                batch_metrics=batch_metrics,
                keys=keys,
                metrics_reducers=self._prepare_metrics_reducers(keys=keys),
            )

            if self.hvd_config.use:
                num_inputs *= hvd.size()

        else:
            check.true(self._evaluate_full_dataset_defined())
            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            if self.is_chief:
                metrics = self.trial.evaluate_full_dataset(
                    data_loader=self.validation_loader)

                check.is_instance(
                    metrics, dict,
                    f"eval() must return a dictionary, got {type(metrics)}.")

                metrics = self._convert_metrics_to_numpy(metrics)
                num_inputs = self.context.get_per_slot_batch_size() * len(
                    self.validation_loader)

        metrics.update(
            self._convert_metrics_to_numpy(
                self.context.reduce_metrics(for_training=False)))

        if self.hvd_config.use and any(
                map(
                    lambda c: util.is_overridden(
                        c.on_validation_end, pytorch.
                        PyTorchCallback) or util.is_overridden(
                            c.on_validation_step_end, pytorch.PyTorchCallback),
                    self.callbacks.values(),
                )):
            logging.debug(
                "Broadcasting metrics to all worker processes to execute a "
                "validation step end callback")
            metrics = hvd.broadcast_object(metrics, root_rank=0)

        for callback in self.callbacks.values():
            if util.is_overridden(callback.on_validation_step_end,
                                  pytorch.PyTorchCallback):
                logging.warning(
                    "on_validation_step_end is now deprecated, please use on_validation_end instead"
                )
                callback.on_validation_step_end(metrics)

        for callback in self.callbacks.values():
            callback.on_validation_end(metrics)

        if not self.is_chief:
            return workload.Skipped()

        return {"num_inputs": num_inputs, "validation_metrics": metrics}