def run(self) -> None:
    """Serve workloads from the master until a TERMINATE workload arrives.

    Each workload is answered through its ``response_func``:
    - RUN_STEP: train for the requested batches and report wrapped metrics.
    - COMPUTE_VALIDATION_METRICS: validate and report wrapped metrics.
    - CHECKPOINT_MODEL: save a checkpoint to the provided path.
    - TERMINATE: acknowledge (chief reports ``{}``, workers report Skipped)
      and exit the loop.
    """
    for wkld, wkld_args, respond in self.workloads:
        kind = wkld.kind
        if kind == workload.Workload.Kind.RUN_STEP:
            # Train first; stop_requested is sampled after the step completes.
            metrics = self._train_for_step(
                wkld.step_id, wkld.num_batches, wkld.total_batches_processed
            )
            respond(util.wrap_metrics(metrics, self.context.get_stop_requested()))
        elif kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            respond(
                util.wrap_metrics(
                    self._compute_validation_metrics(),
                    self.context.get_stop_requested(),
                )
            )
        elif kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            # Exactly one argument is expected: the checkpoint destination.
            check.eq(len(wkld_args), 1)
            check.is_instance(wkld_args[0], pathlib.Path)
            respond(self._save(cast(pathlib.Path, wkld_args[0])))
        elif kind == workload.Workload.Kind.TERMINATE:
            respond({} if self.is_chief else workload.Skipped())
            break
        else:
            raise AssertionError("Unexpected workload: {}".format(kind))
def _run(self) -> None:
    """Serve workloads from the master until a TERMINATE workload arrives.

    Like the plain dispatch loop, but training and validation are guarded
    against ``det.InvalidHP``: on that exception the workload is answered
    with empty metrics flagged ``invalid_hp=True`` instead of propagating.
    """

    def _ok(metrics: Any) -> Any:
        # Wrap a successful result; stop_requested is sampled after the
        # workload's computation has already run.
        return util.wrap_metrics(
            metrics,
            self.context.get_stop_requested(),
            invalid_hp=False,
            init_invalid_hp=False,
        )

    def _invalid() -> Any:
        # Empty response marking the step as an invalid-hyperparameter outcome.
        return util.wrap_metrics(
            {},
            self.context.get_stop_requested(),
            invalid_hp=True,
            init_invalid_hp=False,
        )

    for wkld, wkld_args, respond in self.workloads:
        kind = wkld.kind
        if kind == workload.Workload.Kind.RUN_STEP:
            try:
                respond(
                    _ok(
                        self._train_for_step(
                            wkld.step_id,
                            wkld.num_batches,
                            wkld.total_batches_processed,
                        )
                    )
                )
            except det.InvalidHP as e:
                logging.info(
                    "Invalid hyperparameter exception in trial train step: {}".format(e)
                )
                respond(_invalid())
        elif kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            try:
                respond(_ok(self._compute_validation_metrics()))
            except det.InvalidHP as e:
                logging.info(
                    "Invalid hyperparameter exception in trial validation step: {}".format(e)
                )
                respond(_invalid())
        elif kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            # Exactly one argument is expected: the checkpoint destination.
            check.eq(len(wkld_args), 1)
            check.is_instance(wkld_args[0], pathlib.Path)
            respond(self._save(cast(pathlib.Path, wkld_args[0])))
        elif kind == workload.Workload.Kind.TERMINATE:
            respond({} if self.is_chief else workload.Skipped())
            break
        else:
            raise AssertionError("Unexpected workload: {}".format(kind))
def _catch_init_invalid_hp(workloads: Iterator[Any]) -> Any:
    """Generator-based guard around trial ``__init__``.

    Yields once so the caller can run trial construction inside the guard.
    If construction raises ``InvalidHP``, the next workload is consumed and
    answered with empty metrics flagged as an init-time invalid
    hyperparameter, then the exception is re-raised.

    NOTE(review): presumably used via ``contextlib.contextmanager`` (or
    driven manually with ``next``) — confirm the decorator is applied at the
    definition/call site, as it is not visible here.
    """
    try:
        yield
    except InvalidHP as e:
        logging.info("Invalid hyperparameter exception in trial __init__: {}".format(e))
        wkld, args, response_func = next(workloads)
        # Bug fix: this failure happens during __init__, so report it with
        # init_invalid_hp=True (not invalid_hp=True) — that is how the step
        # handlers distinguish init-time from step-time invalid HPs.
        response_func(
            util.wrap_metrics({}, stop_requested=False, invalid_hp=False, init_invalid_hp=True)
        )
        # Re-raise so the harness still sees the failure after responding.
        raise
def _control_loop(self) -> None:
    """Dispatch one round of workloads for a callback-driven (keras-style) trial.

    Unlike the blocking loops above, RUN_STEP does not train inline: it
    records the pending training state (response func, batch counters,
    requested batch count) and ``break``s so the framework's fit loop can
    produce the batches; the response is sent later from the batch callbacks.
    TERMINATE raises ``det.errors.WorkerFinishedGracefully`` to unwind out of
    the framework's training loop.
    """
    for wkld, args, response_func in self.workloads:
        logging.debug(f"Received wkld {wkld.kind} with args {args}.")
        if wkld.kind == workload.Workload.Kind.RUN_STEP:
            # Configure the state for a training step; the actual training is
            # driven by the framework, so leave the control loop here.
            self.train_response_func = response_func
            self.train_workload_batches = 0
            self.train_workload_metrics = []
            self.train_workload_len = wkld.num_batches
            self.multiplexer.set_batches_requested(wkld.num_batches)
            break
        elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            try:
                response_func(
                    # Consistency fix: use util.wrap_metrics (as the except
                    # branch and every sibling loop do) instead of
                    # det.util.wrap_metrics.
                    util.wrap_metrics(
                        self._compute_validation_metrics(),
                        self.context.get_stop_requested(),
                        invalid_hp=False,
                        init_invalid_hp=False,
                    )
                )
            except det.InvalidHP as e:
                logging.info(
                    "Invalid hyperparameter exception in trial validation step: {}".format(e)
                )
                response_func(
                    util.wrap_metrics(
                        {},
                        self.context.get_stop_requested(),
                        invalid_hp=True,
                        init_invalid_hp=False,
                    )
                )
        elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            # Exactly one argument is expected: the checkpoint destination.
            check.len_eq(args, 1)
            check.is_instance(args[0], pathlib.Path)
            path = cast(pathlib.Path, args[0])
            response_func(self._save_checkpoint(path))
        elif wkld.kind == workload.Workload.Kind.TERMINATE:
            response_func({} if self.is_chief else workload.Skipped())
            self.multiplexer._corrected_train_end()
            raise det.errors.WorkerFinishedGracefully
        else:
            raise AssertionError(f"Unknown workload kind {wkld.kind}.")