def run(self) -> None:
    """Dispatch workloads for old-style (callback-based) TrialControllers.

    Each workload is handled synchronously and its result is handed to the
    workload's response callback.  A ``SkipWorkloadException`` raised by any
    handler turns the response into a ``workload.Skipped`` marker instead.
    """
    kind_enum = workload.Workload.Kind
    for wkld, args, respond in self.workloads:
        try:
            if wkld.kind == kind_enum.RUN_STEP:
                response = self.train_for_step(
                    wkld.step_id, wkld.num_batches
                )  # type: workload.Response
            elif wkld.kind == kind_enum.COMPUTE_VALIDATION_METRICS:
                response = self.compute_validation_metrics(wkld.step_id)
            elif wkld.kind == kind_enum.CHECKPOINT_MODEL:
                # A checkpoint workload carries exactly one arg: the target path.
                check.len_eq(args, 1)
                check.is_instance(args[0], pathlib.Path)
                ckpt_path = cast(pathlib.Path, args[0])
                self.save(ckpt_path)
                response = {}
            elif wkld.kind == kind_enum.TERMINATE:
                self.terminate()
                response = workload.Skipped()
            else:
                raise AssertionError("Unexpected workload: {}".format(wkld.kind))
        except det.errors.SkipWorkloadException:
            response = workload.Skipped()
        respond(response)
def control_loop(self) -> None:
    """Process workloads in between steps of tf.estimator's training loop.

    RUN_STEP breaks out of this loop so the surrounding
    ``train_and_evaluate()`` process can resume training; TERMINATE unwinds
    it via ``WorkerFinishedGracefully``.
    """
    for wkld, args, response_func in self.estimator_trial_controller.workloads:
        logging.debug(f"Received wkld {wkld.kind} with args {args}.")
        if wkld.kind == workload.Workload.Kind.RUN_STEP:
            # Store values for the training loop.
            self.num_batches = wkld.num_batches
            self.train_response_func = response_func
            # Break out of the control loop so that the train process
            # re-enters the train_and_evaluate() loop.
            break
        elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            response_func(
                det.util.wrap_metrics(
                    self._compute_validation_metrics(),
                    self.estimator_trial_controller.context.get_stop_requested(),
                ))
        elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            # A checkpoint workload carries exactly one arg: the target path.
            check.len_eq(args, 1)
            check.is_instance(args[0], pathlib.Path)
            path = cast(pathlib.Path, args[0])
            response_func(self._checkpoint_model(path))
        elif wkld.kind == workload.Workload.Kind.TERMINATE:
            # Defer the response until shutdown completes; the graceful-exit
            # exception unwinds the estimator training loop.
            self.estimator_trial_controller.exit_response_func = response_func
            raise det.errors.WorkerFinishedGracefully("Exiting normally.")
        else:
            raise AssertionError(f"Unknown wkld kind {wkld.kind}.")
    # Only reached if the workload stream is exhausted without a TERMINATE.
    # NOTE(review): this hard-exits the process with status 1 — presumably an
    # error state for this controller; confirm against the workload protocol.
    exit(1)
def run(self) -> None:
    """Serve canned responses for each workload kind (dummy controller).

    RUN_STEP reports a constant loss per batch, validation returns the
    preconfigured metrics, checkpointing writes a single marker file, and
    TERMINATE is not supported.
    """
    for wkld, args, respond in self.workloads:
        kind = wkld.kind
        if kind == workload.Workload.Kind.RUN_STEP:
            # Fabricate a constant loss for every batch in the step.
            fake_batch_metrics = [{"loss": 1} for _ in range(wkld.num_batches)]
            metrics = det.util.make_metrics(
                num_inputs=None,
                batch_metrics=fake_batch_metrics,
            )
            respond({"metrics": metrics})
        elif kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            check.len_eq(args, 0)
            respond({"metrics": {"validation_metrics": self.validation_metrics}})
        elif kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            check.len_eq(args, 1)
            check.is_instance(args[0], pathlib.Path)
            ckpt_dir = cast(pathlib.Path, args[0])
            if not ckpt_dir.exists():
                ckpt_dir.mkdir(parents=True, exist_ok=True)
            # Write a marker file so the checkpoint directory is non-empty.
            with ckpt_dir.joinpath("a_file").open("w") as out:
                out.write("yup")
            respond({})
        elif kind == workload.Workload.Kind.TERMINATE:
            raise NotImplementedError()
def _control_loop(self) -> None:
    """Handle non-training workloads until the next RUN_STEP arrives.

    RUN_STEP primes the per-step training state and breaks out so the
    training loop can resume; TERMINATE responds first and then unwinds via
    ``WorkerFinishedGracefully``.
    """
    for wkld, args, response_func in self.workloads:
        logging.debug(f"Received wkld {wkld.kind} with args {args}.")
        if wkld.kind == workload.Workload.Kind.RUN_STEP:
            # Configure the state for a training step.
            self.train_response_func = response_func
            self.train_workload_batches = 0
            self.train_workload_metrics = []
            self.train_workload_len = wkld.num_batches
            self.multiplexer.set_batches_requested(wkld.num_batches)
            break
        elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            response_func(
                det.util.wrap_metrics(self._compute_validation_metrics(),
                                      self.context.get_stop_requested()))
        elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            # A checkpoint workload carries exactly one arg: the target path.
            check.len_eq(args, 1)
            check.is_instance(args[0], pathlib.Path)
            path = cast(pathlib.Path, args[0])
            response_func(self._save_checkpoint(path))
        elif wkld.kind == workload.Workload.Kind.TERMINATE:
            # Only the chief reports real responses; workers report Skipped.
            response_func({} if self.is_chief else workload.Skipped())
            # NOTE(review): reaches into a private multiplexer hook to signal
            # end-of-training before unwinding — confirm ordering is required.
            self.multiplexer._corrected_train_end()
            raise det.errors.WorkerFinishedGracefully
        else:
            raise AssertionError(f"Unknown workload kind {wkld.kind}.")
def error_rate(predictions: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the error rate based on dense predictions and dense labels.

    ``predictions`` holds per-class scores (argmax over dim 1 picks the
    predicted class); ``labels`` is a 1-D tensor of true class indices.
    """
    check.equal_lengths(predictions, labels)
    check.len_eq(labels.shape, 1, "Labels must be a column vector")
    num_correct = float((predictions.argmax(1) == labels.to(torch.long)).sum())
    return 1.0 - num_correct / predictions.shape[0]  # type: ignore
def get_model(self) -> torch.nn.Module:
    """
    Get the model associated with the trial.

    This function should not be called from:

        * ``__init__``
        * ``build_model()``
    """
    # TODO(DET-3267): deprecate this when releasing pytorch flexible primitives.
    check.len_eq(self.models, 1)
    # Exactly one model is registered (verified above); unpack it.
    (model,) = self.models
    return model
def get_optimizer(self) -> torch.optim.Optimizer:  # type: ignore
    """
    Get the optimizer associated with the trial.

    This function should not be called from:

        * ``__init__``
        * ``build_model()``
        * ``optimizer()``
    """
    # TODO(DET-3267): deprecate this when releasing pytorch flexible primitives.
    check.len_eq(self.optimizers, 1)
    # Exactly one optimizer is registered (verified above); unpack it.
    (optimizer,) = self.optimizers
    return optimizer
def binary_error_rate(predictions: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the classification error rate for binary classification.

    ``predictions`` may be shaped (N,) or (N, 1); values are rounded to the
    nearest integer class before comparison against the 1-D ``labels``.
    """
    rank = len(predictions.shape)
    check.eq(predictions.shape[0], labels.shape[0])
    check.is_in(rank, [1, 2])
    if rank == 2:
        check.eq(predictions.shape[1], 1)
    check.len_eq(labels.shape, 1, "Labels must be a column vector")
    # Flatten (N, 1) predictions down to (N,) before comparing.
    if rank > 1:
        predictions = torch.squeeze(predictions)
    mismatched = labels.to(torch.long) != torch.round(predictions).to(torch.long)
    errors = torch.sum(mismatched)
    result = float(errors) / predictions.shape[0]  # type: float
    return result
def run(self) -> None:
    """Top-level workload loop for the tf.keras trial controller.

    The tf.keras ``fit()`` loop is launched once on the first RUN_STEP and
    then paused between steps; subsequent RUN_STEPs simply break out of this
    loop to resume it.
    """
    for wkld, args, response_func in self.workloads:
        logging.debug(f"Received wkld {wkld.kind} with args {args}.")
        if wkld.kind == workload.Workload.Kind.RUN_STEP:
            # Store the train_response_func for later.
            self.train_response_func = response_func
            self.num_batches = wkld.num_batches
            # There are two possibilities when a RUN_STEP workload is received.
            # 1) This is the first training step seen by the trial
            #    container. In this case, enter the tf.keras fit() training loop.
            # 2) This is _not_ the first training step, meaning that the
            #    tf.keras fit() training loop is already active and paused.
            #    break to re-enter the training loop.
            if not self.fit_loop_started:
                try:
                    self._launch_fit()
                except det.errors.WorkerFinishedGracefully:
                    pass
                # fit() returning without a graceful-exit exception and without
                # an expected terminate means a callback stopped training early.
                if not self.expect_terminate:
                    raise AssertionError(
                        "Training loop exited unexpectedly but without throwing any errors. "
                        "This is possibly due to a user callback causing the training loop to "
                        "exit, which is not supported at this time."
                    )
            break
        elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            response_func(
                det.util.wrap_metrics(
                    self.compute_validation_metrics(), self.context.get_stop_requested()
                )
            )
        elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            # A checkpoint workload carries exactly one arg: the target path.
            check.len_eq(args, 1)
            check.is_instance(args[0], pathlib.Path)
            path = cast(pathlib.Path, args[0])
            response_func(self._save_checkpoint(path))
        elif wkld.kind == workload.Workload.Kind.TERMINATE:
            # Ask tf.keras to stop training; the fit() loop exiting afterwards
            # is expected, so flag it before responding and breaking out.
            self.model.stop_training = True
            self.expect_terminate = True
            response_func({} if self.is_chief else workload.Skipped())
            break
        else:
            raise AssertionError(f"Unknown wkld kind {wkld.kind}.")
def get_model(self) -> torch.nn.Module:
    """
    Get the model associated with the trial.

    This function should not be called from:

        * ``__init__``
        * ``build_model()``

    .. warning::
        This is deprecated.
    """
    # TODO(DET-3262): remove this backward compatibility of old interface.
    logging.warning(
        "PyTorchTrialContext.get_model is deprecated. "
        "Please directly use the model wrapped by context.wrap_model().")
    check.len_eq(self.models, 1)
    # Exactly one model is registered (verified above); unpack it.
    (model,) = self.models
    return model
def _control_loop(self) -> None:
    """Handle non-training workloads until the next RUN_STEP arrives."""
    kinds = workload.Workload.Kind
    for wkld, args, respond in self.workloads:
        if wkld.kind == kinds.RUN_STEP:
            # Remember where to report training results, then resume training.
            self.train_response_func = respond
            break
        if wkld.kind == kinds.COMPUTE_VALIDATION_METRICS:
            respond(self._compute_validation_metrics())
        elif wkld.kind == kinds.CHECKPOINT_MODEL:
            # A checkpoint workload carries exactly one arg: the target path.
            check.len_eq(args, 1)
            check.is_instance(args[0], pathlib.Path)
            ckpt_path = cast(pathlib.Path, args[0])
            respond(self.save_checkpoint(ckpt_path))
        elif wkld.kind == kinds.TERMINATE:
            raise det.errors.WorkerFinishedGracefully("Exiting normally.")
        else:
            raise AssertionError(f"Unknown wkld kind {wkld.kind}")