Example #1
    def run(self) -> None:
        """Answer each workload with canned data: fabricated per-batch training
        metrics, stored validation metrics, and a placeholder checkpoint file."""
        for w, args, response_func in self.workloads:
            if w.kind == workload.Workload.Kind.RUN_STEP:
                metrics = det.util.make_metrics(
                    num_inputs=None,
                    batch_metrics=[{"loss": 1} for _ in range(w.num_batches)],
                )
                response_func({"metrics": metrics})
            elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                check.len_eq(args, 0)
                response_func({"metrics": {"validation_metrics": self.validation_metrics}})
            elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                check.len_eq(args, 1)
                check.is_instance(args[0], pathlib.Path)
                path = cast(pathlib.Path, args[0])
                # mkdir with exist_ok=True already tolerates an existing
                # directory, so no separate exists() check is needed.
                path.mkdir(parents=True, exist_ok=True)
                with path.joinpath("a_file").open("w") as f:
                    f.write("yup")
                response_func({})
            elif w.kind == workload.Workload.Kind.TERMINATE:
                raise NotImplementedError()
Example #2
    def run(self) -> None:
        """
        A basic control loop of the old-style (callback-based) TrialController
        classes.
        """

        for w, args, response_func in self.workloads:
            try:
                if w.kind == workload.Workload.Kind.RUN_STEP:
                    response = self.train_for_step(w.step_id, w.num_batches)  # type: workload.Response
                elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                    response = self.compute_validation_metrics(w.step_id)
                elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                    check.len_eq(args, 1)
                    check.is_instance(args[0], pathlib.Path)
                    path = cast(pathlib.Path, args[0])
                    self.save(path)
                    response = {}
                elif w.kind == workload.Workload.Kind.TERMINATE:
                    self.terminate()
                    response = workload.Skipped()
                else:
                    raise AssertionError("Unexpected workload: {}".format(w.kind))

            except det.errors.SkipWorkloadException:
                response = workload.Skipped()

            response_func(response)
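
Both run() loops above consume the same stream shape: an iterable of (workload, args, response_func) triples, and each workload must be answered by calling its response_func exactly once. The sketch below illustrates just that contract in isolation; its Kind enum and Workload tuple are hypothetical stand-ins, not the real det.workload API.

import enum
from typing import Any, Callable, List, NamedTuple, Tuple

class Kind(enum.Enum):  # stand-in for workload.Workload.Kind
    RUN_STEP = "RUN_STEP"
    COMPUTE_VALIDATION_METRICS = "COMPUTE_VALIDATION_METRICS"

class Workload(NamedTuple):  # stand-in for workload.Workload
    kind: Kind
    step_id: int = 0
    num_batches: int = 0

responses: List[Any] = []
stream: List[Tuple[Workload, List[Any], Callable[[Any], None]]] = [
    (Workload(Kind.RUN_STEP, step_id=1, num_batches=5), [], responses.append),
    (Workload(Kind.COMPUTE_VALIDATION_METRICS, step_id=1), [], responses.append),
]

for w, args, respond in stream:
    if w.kind == Kind.RUN_STEP:
        # One metrics dict per batch, mirroring Example #1's fabricated loss.
        respond({"metrics": {"batch_metrics": [{"loss": 1}] * w.num_batches}})
    elif w.kind == Kind.COMPUTE_VALIDATION_METRICS:
        respond({"metrics": {"validation_metrics": {"accuracy": 0.5}}})

assert len(responses) == len(stream)  # exactly one response per workload
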
Example #3
def error_rate(predictions: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the error rate based on dense predictions and dense labels."""
    check.equal_lengths(predictions, labels)
    check.len_eq(labels.shape, 1, "Labels must be a column vector")

    # Fraction of rows whose argmax class agrees with the label, subtracted from 1.
    num_correct = float((predictions.argmax(1) == labels.to(torch.long)).sum())
    return 1.0 - num_correct / predictions.shape[0]
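
As a quick sanity check (assuming error_rate above and its check helper are importable), here is a hypothetical 4-sample, 3-class batch:

import torch

# argmax per row picks classes [2, 0, 1, 1]; against labels [2, 0, 0, 1]
# that is 3 correct out of 4, so the error rate is 0.25.
logits = torch.tensor(
    [[0.1, 0.2, 0.7],
     [0.9, 0.05, 0.05],
     [0.3, 0.4, 0.3],
     [0.2, 0.6, 0.2]]
)
labels = torch.tensor([2, 0, 0, 1])
print(error_rate(logits, labels))  # 0.25
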
Example #4
def binary_error_rate(predictions: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the classification error rate for binary classification."""
    check.eq(predictions.shape[0], labels.shape[0])
    check.is_in(len(predictions.shape), [1, 2])
    if len(predictions.shape) == 2:
        check.eq(predictions.shape[1], 1)
    check.len_eq(labels.shape, 1, "Labels must be a column vector")

    # Squeeze an (N, 1) prediction tensor down to shape (N,).
    if len(predictions.shape) > 1:
        predictions = torch.squeeze(predictions)

    errors = torch.sum(labels.to(torch.long) != torch.round(predictions).to(torch.long))
    result = float(errors) / predictions.shape[0]  # type: float
    return result
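
And a hypothetical binary case (same importability assumption); an (N, 1) column of probabilities is squeezed and gives the same result:

import torch

# Rounding the probabilities gives [1, 0, 1, 0]; against labels [1, 0, 0, 0]
# that is one mistake out of four, so the error rate is 0.25.
probs = torch.tensor([0.9, 0.2, 0.7, 0.4])
labels = torch.tensor([1, 0, 0, 0])
print(binary_error_rate(probs, labels))                 # 0.25
print(binary_error_rate(probs.reshape(-1, 1), labels))  # same after squeeze
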
Example #5
    def control_loop(self) -> None:
        for wkld, args, response_func in self.estimator_trial_controller.workloads:
            logging.debug(f"Received wkld {wkld.kind} with args {args}.")

            if wkld.kind == workload.Workload.Kind.RUN_STEP:
                # Store values for the training loop.
                self.num_batches = wkld.num_batches
                self.train_response_func = response_func
                # Break out of the control loop so that the train process
                # re-enters the train_and_evaluate() loop.
                break
            elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                try:
                    response_func(
                        det.util.wrap_metrics(
                            self._compute_validation_metrics(),
                            self.estimator_trial_controller.context.get_stop_requested(),
                            invalid_hp=False,
                            init_invalid_hp=False,
                        )
                    )
                except det.InvalidHP as e:
                    logging.info(f"Invalid hyperparameter exception in trial validation step: {e}")
                    response_func(
                        det.util.wrap_metrics(
                            {},
                            self.estimator_trial_controller.context.get_stop_requested(),
                            invalid_hp=True,
                            init_invalid_hp=False,
                        )
                    )
            elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                check.len_eq(args, 1)
                check.is_instance(args[0], pathlib.Path)
                path = cast(pathlib.Path, args[0])
                response_func(self._checkpoint_model(path))
            elif wkld.kind == workload.Workload.Kind.TERMINATE:
                self.estimator_trial_controller.exit_response_func = response_func
                raise det.errors.WorkerFinishedGracefully("Exiting normally.")
            else:
                raise AssertionError(f"Unknown wkld kind {wkld.kind}.")
Example #6
    def _control_loop(self) -> None:
        for wkld, args, response_func in self.workloads:
            logging.debug(f"Received wkld {wkld.kind} with args {args}.")
            if wkld.kind == workload.Workload.Kind.RUN_STEP:
                # Configure the state for a training step.
                self.train_response_func = response_func
                self.train_workload_batches = 0
                self.train_workload_metrics = []
                self.train_workload_len = wkld.num_batches
                self.multiplexer.set_batches_requested(wkld.num_batches)
                # Break out so the caller re-enters its training loop.
                break
            elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                try:
                    response_func(
                        det.util.wrap_metrics(
                            self._compute_validation_metrics(),
                            self.context.get_stop_requested(),
                            invalid_hp=False,
                            init_invalid_hp=False,
                        )
                    )
                except det.InvalidHP as e:
                    logging.info(f"Invalid hyperparameter exception in trial validation step: {e}")
                    response_func(
                        det.util.wrap_metrics(
                            {},
                            self.context.get_stop_requested(),
                            invalid_hp=True,
                            init_invalid_hp=False,
                        )
                    )
            elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                check.len_eq(args, 1)
                check.is_instance(args[0], pathlib.Path)
                path = cast(pathlib.Path, args[0])
                response_func(self._save_checkpoint(path))
            elif wkld.kind == workload.Workload.Kind.TERMINATE:
                response_func({} if self.is_chief else workload.Skipped())
                self.multiplexer._corrected_train_end()
                raise det.errors.WorkerFinishedGracefully
            else:
                raise AssertionError(f"Unknown workload kind {wkld.kind}.")