# Shared imports assumed by the snippets below; module paths follow the
# Determined codebase and are an assumption, not copied from the source.
import logging
import pathlib
from typing import cast

import torch

import determined as det
from determined import workload
from determined.common import check


def run(self) -> None:
    # Test-only control loop: fabricates a response for each workload kind.
    for w, args, response_func in self.workloads:
        if w.kind == workload.Workload.Kind.RUN_STEP:
            metrics = det.util.make_metrics(
                num_inputs=None,
                batch_metrics=[{"loss": 1} for _ in range(w.num_batches)],
            )
            response_func({"metrics": metrics})
        elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            check.len_eq(args, 0)
            response_func({"metrics": {"validation_metrics": self.validation_metrics}})
        elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            check.len_eq(args, 1)
            check.is_instance(args[0], pathlib.Path)
            path = cast(pathlib.Path, args[0])
            path.mkdir(parents=True, exist_ok=True)
            with path.joinpath("a_file").open("w") as f:
                f.write("yup")
            response_func({})
        elif w.kind == workload.Workload.Kind.TERMINATE:
            raise NotImplementedError()
def run(self) -> None: """ A basic control loop of the old-style (callback-based) TrialController classes. """ for w, args, response_func in self.workloads: try: if w.kind == workload.Workload.Kind.RUN_STEP: response = self.train_for_step( w.step_id, w.num_batches) # type: workload.Response elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS: response = self.compute_validation_metrics(w.step_id) elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL: check.len_eq(args, 1) check.is_instance(args[0], pathlib.Path) path = cast(pathlib.Path, args[0]) self.save(path) response = {} elif w.kind == workload.Workload.Kind.TERMINATE: self.terminate() response = workload.Skipped() else: raise AssertionError("Unexpected workload: {}".format( w.kind)) except det.errors.SkipWorkloadException: response = workload.Skipped() response_func(response)
def error_rate(predictions: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the error rate based on dense predictions and dense labels."""
    check.equal_lengths(predictions, labels)
    check.len_eq(labels.shape, 1, "Labels must be a column vector")
    return (
        1.0
        - float((predictions.argmax(1) == labels.to(torch.long)).sum()) / predictions.shape[0]
    )
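# A quick, illustrative sanity check for error_rate (values invented for the
# example): the argmax of the first row matches its label, the second does
# not, so half of the predictions are wrong.
predictions = torch.tensor([[2.0, 1.0], [0.2, 0.8]])  # argmax over dim 1 -> [0, 1]
labels = torch.tensor([0, 0])
assert error_rate(predictions, labels) == 0.5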
def binary_error_rate(predictions: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the classification error rate for binary classification."""
    check.eq(predictions.shape[0], labels.shape[0])
    # Predictions may be a 1-D vector or an (N, 1) column vector.
    check.is_in(len(predictions.shape), [1, 2])
    if len(predictions.shape) == 2:
        check.eq(predictions.shape[1], 1)
    check.len_eq(labels.shape, 1, "Labels must be a column vector")
    if len(predictions.shape) > 1:
        predictions = torch.squeeze(predictions)
    errors = torch.sum(labels.to(torch.long) != torch.round(predictions).to(torch.long))
    result = float(errors) / predictions.shape[0]  # type: float
    return result
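# A matching illustrative check for binary_error_rate: probabilities are
# rounded to {0, 1} before being compared against the labels.
predictions = torch.tensor([0.9, 0.2, 0.6])  # rounds to [1, 0, 1]
labels = torch.tensor([1.0, 0.0, 0.0])
assert abs(binary_error_rate(predictions, labels) - 1 / 3) < 1e-6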
def control_loop(self) -> None:
    for wkld, args, response_func in self.estimator_trial_controller.workloads:
        logging.debug(f"Received wkld {wkld.kind} with args {args}.")
        if wkld.kind == workload.Workload.Kind.RUN_STEP:
            # Store values for the training loop.
            self.num_batches = wkld.num_batches
            self.train_response_func = response_func
            # Break out of the control loop so that the train process
            # re-enters the train_and_evaluate() loop.
            break
        elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            try:
                response_func(
                    det.util.wrap_metrics(
                        self._compute_validation_metrics(),
                        self.estimator_trial_controller.context.get_stop_requested(),
                        invalid_hp=False,
                        init_invalid_hp=False,
                    )
                )
            except det.InvalidHP as e:
                logging.info(
                    "Invalid hyperparameter exception in trial validation step: {}".format(e)
                )
                response_func(
                    det.util.wrap_metrics(
                        {},
                        self.estimator_trial_controller.context.get_stop_requested(),
                        invalid_hp=True,
                        init_invalid_hp=False,
                    )
                )
        elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            check.len_eq(args, 1)
            check.is_instance(args[0], pathlib.Path)
            path = cast(pathlib.Path, args[0])
            response_func(self._checkpoint_model(path))
        elif wkld.kind == workload.Workload.Kind.TERMINATE:
            self.estimator_trial_controller.exit_response_func = response_func
            raise det.errors.WorkerFinishedGracefully("Exiting normally.")
        else:
            raise AssertionError(f"Unknown wkld kind {wkld.kind}.")
def _control_loop(self) -> None:
    for wkld, args, response_func in self.workloads:
        logging.debug(f"Received wkld {wkld.kind} with args {args}.")
        if wkld.kind == workload.Workload.Kind.RUN_STEP:
            # Configure the state for a training step.
            self.train_response_func = response_func
            self.train_workload_batches = 0
            self.train_workload_metrics = []
            self.train_workload_len = wkld.num_batches
            self.multiplexer.set_batches_requested(wkld.num_batches)
            break
        elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            try:
                response_func(
                    det.util.wrap_metrics(
                        self._compute_validation_metrics(),
                        self.context.get_stop_requested(),
                        invalid_hp=False,
                        init_invalid_hp=False,
                    )
                )
            except det.InvalidHP as e:
                logging.info(
                    "Invalid hyperparameter exception in trial validation step: {}".format(e)
                )
                response_func(
                    det.util.wrap_metrics(
                        {},
                        self.context.get_stop_requested(),
                        invalid_hp=True,
                        init_invalid_hp=False,
                    )
                )
        elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            check.len_eq(args, 1)
            check.is_instance(args[0], pathlib.Path)
            path = cast(pathlib.Path, args[0])
            response_func(self._save_checkpoint(path))
        elif wkld.kind == workload.Workload.Kind.TERMINATE:
            response_func({} if self.is_chief else workload.Skipped())
            self.multiplexer._corrected_train_end()
            raise det.errors.WorkerFinishedGracefully
        else:
            raise AssertionError(f"Unknown workload kind {wkld.kind}.")
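# A note on the two control loops above: both break out of the loop on
# RUN_STEP instead of training inline. The framework (Estimator's
# train_and_evaluate() in the first case, and what appears to be a
# Keras-style fit() loop behind the multiplexer in the second) owns the
# training iteration, so the controller records the requested batch count and
# the response callback, returns control to the framework, and re-enters the
# control loop once the requested batches have run.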