Example #1
    def run(self) -> None:
        """
        A basic control loop of the old-style (callback-based) TrialController
        classes.
        """

        for w, args, response_func in self.workloads:
            try:
                if w.kind == workload.Workload.Kind.RUN_STEP:
                    response = self.train_for_step(
                        w.step_id, w.num_batches)  # type: workload.Response
                elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                    response = self.compute_validation_metrics(w.step_id)
                elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                    check.len_eq(args, 1)
                    check.is_instance(args[0], pathlib.Path)
                    path = cast(pathlib.Path, args[0])
                    self.save(path)
                    response = {}
                elif w.kind == workload.Workload.Kind.TERMINATE:
                    self.terminate()
                    response = workload.Skipped()
                else:
                    raise AssertionError("Unexpected workload: {}".format(
                        w.kind))

            except det.errors.SkipWorkloadException:
                response = workload.Skipped()

            response_func(response)
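The examples on this page all guard workload arguments with `check.len_eq` and `check.is_instance` before unpacking them. For illustration only, here is a minimal, self-contained sketch of what such helpers might look like; `CheckFailedError`, `len_eq`, and `is_instance` below are hypothetical stand-ins, not the library's actual definitions.

from typing import Any, Sized, Type


class CheckFailedError(Exception):
    """Illustrative error type raised when a precondition check fails."""


def len_eq(value: Sized, expected: int, msg: str = "") -> None:
    # Fail fast if the value does not have exactly the expected length.
    if len(value) != expected:
        raise CheckFailedError(f"expected len == {expected}, got {len(value)}. {msg}")


def is_instance(value: Any, type_: Type, msg: str = "") -> None:
    # Fail fast if the value is not an instance of the expected type.
    if not isinstance(value, type_):
        raise CheckFailedError(f"expected {type_.__name__}, got {type(value).__name__}. {msg}")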
Example #2
    def control_loop(self) -> None:
        for wkld, args, response_func in self.estimator_trial_controller.workloads:
            logging.debug(f"Received wkld {wkld.kind} with args {args}.")

            if wkld.kind == workload.Workload.Kind.RUN_STEP:
                # Store values for the training loop.
                self.num_batches = wkld.num_batches
                self.train_response_func = response_func
                # Break out of the control loop so that the train process
                # re-enters the train_and_evaluate() loop.
                break
            elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                response_func(
                    det.util.wrap_metrics(
                        self._compute_validation_metrics(),
                        self.estimator_trial_controller.context.get_stop_requested(),
                    )
                )
            elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                check.len_eq(args, 1)
                check.is_instance(args[0], pathlib.Path)
                path = cast(pathlib.Path, args[0])
                response_func(self._checkpoint_model(path))
            elif wkld.kind == workload.Workload.Kind.TERMINATE:
                self.estimator_trial_controller.exit_response_func = response_func
                raise det.errors.WorkerFinishedGracefully("Exiting normally.")
            else:
                raise AssertionError(f"Unknown wkld kind {wkld.kind}.")
Example #3
    def run(self) -> None:
        for w, args, response_func in self.workloads:
            if w.kind == workload.Workload.Kind.RUN_STEP:
                metrics = det.util.make_metrics(
                    num_inputs=None,
                    batch_metrics=[{"loss": 1} for _ in range(w.num_batches)],
                )
                response_func({"metrics": metrics})
            elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                check.len_eq(args, 0)
                response_func(
                    {"metrics": {"validation_metrics": self.validation_metrics}}
                )
            elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                check.len_eq(args, 1)
                check.is_instance(args[0], pathlib.Path)
                path = cast(pathlib.Path, args[0])
                if not path.exists():
                    path.mkdir(parents=True, exist_ok=True)
                with path.joinpath("a_file").open("w") as f:
                    f.write("yup")
                response_func({})
            elif w.kind == workload.Workload.Kind.TERMINATE:
                raise NotImplementedError()
Example #4
    def _control_loop(self) -> None:
        for wkld, args, response_func in self.workloads:
            logging.debug(f"Received wkld {wkld.kind} with args {args}.")
            if wkld.kind == workload.Workload.Kind.RUN_STEP:
                # Configure the state for a training step.
                self.train_response_func = response_func
                self.train_workload_batches = 0
                self.train_workload_metrics = []
                self.train_workload_len = wkld.num_batches
                self.multiplexer.set_batches_requested(wkld.num_batches)
                break
            elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                response_func(
                    det.util.wrap_metrics(
                        self._compute_validation_metrics(),
                        self.context.get_stop_requested(),
                    )
                )
            elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                check.len_eq(args, 1)
                check.is_instance(args[0], pathlib.Path)
                path = cast(pathlib.Path, args[0])
                response_func(self._save_checkpoint(path))
            elif wkld.kind == workload.Workload.Kind.TERMINATE:
                response_func({} if self.is_chief else workload.Skipped())
                self.multiplexer._corrected_train_end()
                raise det.errors.WorkerFinishedGracefully
            else:
                raise AssertionError(f"Unknown workload kind {wkld.kind}.")
Example #5
def error_rate(predictions: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the error rate based on dense predictions and dense labels."""
    check.equal_lengths(predictions, labels)
    check.len_eq(labels.shape, 1, "Labels must be a column vector")

    return (  # type: ignore
        1.0 - float((predictions.argmax(1) == labels.to(torch.long)).sum()) /
        predictions.shape[0])
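As a quick sanity check of the computation above, a hypothetical call (assuming `error_rate` and the `check` helpers it uses are in scope) with four dense predictions, one of which is wrong, yields an error rate of 0.25:

import torch

# Hypothetical inputs: 4 samples, 3 classes; the third row's argmax (class 1)
# disagrees with its label (class 0), so 1 of 4 predictions is wrong.
predictions = torch.tensor(
    [[0.9, 0.05, 0.05], [0.1, 0.8, 0.1], [0.2, 0.7, 0.1], [0.3, 0.3, 0.4]]
)
labels = torch.tensor([0, 1, 0, 2])
print(error_rate(predictions, labels))  # 0.25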
Example #6
    def get_model(self) -> torch.nn.Module:
        """
        Get the model associated with the trial. This function should not be
        called from:

            * ``__init__``
            * ``build_model()``
        """
        # TODO(DET-3267): deprecate this when releasing pytorch flexible primitives.
        check.len_eq(self.models, 1)
        return self.models[0]
Example #7
    def get_optimizer(self) -> torch.optim.Optimizer:  # type: ignore
        """
        Get the optimizer associated with the trial. This function should not be
        called from:

            * ``__init__``
            * ``build_model()``
            * ``optimizer()``
        """
        # TODO(DET-3267): deprecate this when releasing pytorch flexible primitives.
        check.len_eq(self.optimizers, 1)
        return self.optimizers[0]
Example #8
def binary_error_rate(predictions: torch.Tensor,
                      labels: torch.Tensor) -> float:
    """Return the classification error rate for binary classification."""
    check.eq(predictions.shape[0], labels.shape[0])
    check.is_in(len(predictions.shape), [1, 2])
    if len(predictions.shape) == 2:
        check.eq(predictions.shape[1], 1)
    check.len_eq(labels.shape, 1, "Labels must be a column vector")

    if len(predictions.shape) > 1:
        predictions = torch.squeeze(predictions)

    errors = torch.sum(
        labels.to(torch.long) != torch.round(predictions).to(torch.long))
    result = float(errors) / predictions.shape[0]  # type: float
    return result
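A hypothetical call, assuming `binary_error_rate` from the example above is in scope: with four sigmoid outputs, one of which rounds to the wrong class, the error rate is 0.25.

import torch

# Hypothetical inputs: probabilities for 4 samples; round(0.4) == 0 disagrees
# with its label 1, so 1 of 4 predictions is wrong.
predictions = torch.tensor([0.1, 0.9, 0.4, 0.8])
labels = torch.tensor([0, 1, 1, 1])
print(binary_error_rate(predictions, labels))  # 0.25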
Example #9
    def run(self) -> None:
        for wkld, args, response_func in self.workloads:
            logging.debug(f"Received wkld {wkld.kind} with args {args}.")

            if wkld.kind == workload.Workload.Kind.RUN_STEP:
                # Store the train_response_func for later.
                self.train_response_func = response_func
                self.num_batches = wkld.num_batches

                # There are two possibilities when a RUN_STEP workload is received.
                # 1) This is the first training step seen by the trial
                #    container. In this case, enter the tf.keras fit() training loop.
                # 2) This is _not_ the first training step, meaning that the
                #    tf.keras fit() training loop is already active and paused.
                #    Break to re-enter the training loop.
                if not self.fit_loop_started:
                    try:
                        self._launch_fit()
                    except det.errors.WorkerFinishedGracefully:
                        pass

                    if not self.expect_terminate:
                        raise AssertionError(
                            "Training loop exited unexpectedly but without throwing any errors. "
                            "This is possibly due to a user callback causing the training loop to "
                            "exit, which is not supported at this time."
                        )
                break

            elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                response_func(
                    det.util.wrap_metrics(
                        self.compute_validation_metrics(), self.context.get_stop_requested()
                    )
                )
            elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                check.len_eq(args, 1)
                check.is_instance(args[0], pathlib.Path)
                path = cast(pathlib.Path, args[0])
                response_func(self._save_checkpoint(path))
            elif wkld.kind == workload.Workload.Kind.TERMINATE:
                self.model.stop_training = True
                self.expect_terminate = True
                response_func({} if self.is_chief else workload.Skipped())
                break
            else:
                raise AssertionError(f"Unknown wkld kind {wkld.kind}.")
Example #10
    def get_model(self) -> torch.nn.Module:
        """
        Get the model associated with the trial. This function should not be
        called from:

            * ``__init__``
            * ``build_model()``

        .. warning::
            This is deprecated.
        """
        # TODO(DET-3262): remove this backward compatibility of old interface.
        logging.warning(
            "PyTorchTrialContext.get_model is deprecated. "
            "Please directly use the model wrapped by context.wrap_model().")
        check.len_eq(self.models, 1)
        return self.models[0]
Example #11
    def _control_loop(self) -> None:
        for wkld, args, response_func in self.workloads:
            if wkld.kind == workload.Workload.Kind.RUN_STEP:
                # Move on to the next step.
                self.train_response_func = response_func
                break
            elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                response_func(self._compute_validation_metrics())
            elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                check.len_eq(args, 1)
                check.is_instance(args[0], pathlib.Path)
                path = cast(pathlib.Path, args[0])
                response_func(self.save_checkpoint(path))
            elif wkld.kind == workload.Workload.Kind.TERMINATE:
                raise det.errors.WorkerFinishedGracefully("Exiting normally.")
            else:
                raise AssertionError(f"Unknown wkld kind {wkld.kind}")
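Taken together, the controllers above all consume the same stream of (workload, args, response_func) tuples and answer each one through its callback. Below is a minimal, self-contained sketch of that handshake for illustration only; `Kind`, `FakeWorkload`, and `run_loop` are hypothetical stand-ins, not Determined's actual workload API.

import enum
import pathlib
from typing import Any, Callable, List, Tuple


class Kind(enum.Enum):
    RUN_STEP = "RUN_STEP"
    CHECKPOINT_MODEL = "CHECKPOINT_MODEL"
    TERMINATE = "TERMINATE"


class FakeWorkload:
    def __init__(self, kind: Kind, num_batches: int = 0) -> None:
        self.kind = kind
        self.num_batches = num_batches


def run_loop(
    workloads: List[Tuple[FakeWorkload, List[Any], Callable[[Any], None]]]
) -> None:
    # Dispatch each workload and hand its result back through response_func,
    # mirroring the if/elif dispatch used by the controllers above.
    for w, args, response_func in workloads:
        if w.kind == Kind.RUN_STEP:
            response_func({"metrics": {"batches_trained": w.num_batches}})
        elif w.kind == Kind.CHECKPOINT_MODEL:
            (path,) = args  # exactly one argument: the checkpoint directory
            assert isinstance(path, pathlib.Path)
            response_func({"checkpoint": str(path)})
        elif w.kind == Kind.TERMINATE:
            response_func(None)
            break


responses: List[Any] = []
run_loop(
    [
        (FakeWorkload(Kind.RUN_STEP, num_batches=100), [], responses.append),
        (FakeWorkload(Kind.CHECKPOINT_MODEL), [pathlib.Path("/tmp/ckpt")], responses.append),
        (FakeWorkload(Kind.TERMINATE), [], responses.append),
    ]
)
print(responses)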