コード例 #1
0
 def run(self) -> None:
     """Process workloads from ``self.workloads`` until TERMINATE or the stream ends.

     Each item is unpacked as ``(workload, args, response_func)`` and the outcome
     of handling the workload is handed to ``response_func``:

     * RUN_STEP: metrics from ``self._train_for_step`` wrapped with the stop flag.
     * COMPUTE_VALIDATION_METRICS: validation metrics wrapped with the stop flag.
     * CHECKPOINT_MODEL: the result of ``self._save(path)`` for the single
       ``pathlib.Path`` argument.
     * TERMINATE: ``{}`` on the chief, ``workload.Skipped()`` otherwise; then
       this method returns.

     Raises:
         AssertionError: if a workload of an unrecognised kind is received.
     """
     for wkld, wkld_args, respond in self.workloads:
         kind = wkld.kind
         kinds = workload.Workload.Kind

         if kind == kinds.RUN_STEP:
             # The stop flag is read only after the step has finished.
             step_metrics = self._train_for_step(
                 wkld.step_id, wkld.num_batches, wkld.total_batches_processed
             )
             respond(util.wrap_metrics(step_metrics, self.context.get_stop_requested()))
         elif kind == kinds.COMPUTE_VALIDATION_METRICS:
             val_metrics = self._compute_validation_metrics()
             respond(util.wrap_metrics(val_metrics, self.context.get_stop_requested()))
         elif kind == kinds.CHECKPOINT_MODEL:
             # Exactly one argument is expected: the checkpoint path.
             check.eq(len(wkld_args), 1)
             check.is_instance(wkld_args[0], pathlib.Path)
             respond(self._save(cast(pathlib.Path, wkld_args[0])))
         elif kind == kinds.TERMINATE:
             if self.is_chief:
                 respond({})
             else:
                 respond(workload.Skipped())
             return
         else:
             raise AssertionError("Unexpected workload: {}".format(kind))
コード例 #2
0
ファイル: _pytorch_trial.py プロジェクト: shiyuann/determined
 def _run(self) -> None:
     """Process workloads from ``self.workloads`` until TERMINATE or the stream ends.

     Each item is unpacked as ``(workload, args, response_func)`` and exactly one
     response is sent through ``response_func`` per handled workload:

     * RUN_STEP / COMPUTE_VALIDATION_METRICS: the computed metrics wrapped with the
       stop flag. If the trial code raises ``det.InvalidHP``, empty metrics flagged
       ``invalid_hp=True`` are sent instead and the loop continues.
     * CHECKPOINT_MODEL: the result of ``self._save(path)`` for the single
       ``pathlib.Path`` argument.
     * TERMINATE: ``{}`` on the chief, ``workload.Skipped()`` otherwise; then the
       loop exits.

     Raises:
         AssertionError: if a workload of an unrecognised kind is received.
     """

     def respond_with_metrics(response_func, phase, compute_metrics, *compute_args):
         # Run ``compute_metrics(*compute_args)`` and send the wrapped result.
         # Only the trial-code call is guarded: response_func stays outside the
         # try, so an InvalidHP raised while responding can never cause the same
         # workload to be answered a second time.
         try:
             metrics = compute_metrics(*compute_args)
         except det.InvalidHP as e:
             logging.info(
                 "Invalid hyperparameter exception in trial %s step: %s", phase, e
             )
             metrics, invalid_hp = {}, True
         else:
             invalid_hp = False
         response_func(
             util.wrap_metrics(
                 metrics,
                 self.context.get_stop_requested(),
                 invalid_hp=invalid_hp,
                 init_invalid_hp=False,
             )
         )

     for w, args, response_func in self.workloads:
         if w.kind == workload.Workload.Kind.RUN_STEP:
             respond_with_metrics(
                 response_func,
                 "train",
                 self._train_for_step,
                 w.step_id,
                 w.num_batches,
                 w.total_batches_processed,
             )
         elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
             respond_with_metrics(
                 response_func, "validation", self._compute_validation_metrics
             )
         elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
             # Exactly one argument is expected: the checkpoint path.
             check.eq(len(args), 1)
             check.is_instance(args[0], pathlib.Path)
             path = cast(pathlib.Path, args[0])
             response_func(self._save(path))
         elif w.kind == workload.Workload.Kind.TERMINATE:
             response_func({} if self.is_chief else workload.Skipped())
             break
         else:
             raise AssertionError("Unexpected workload: {}".format(w.kind))
コード例 #3
0
ファイル: _execution.py プロジェクト: wbwatkinson/determined
def _catch_init_invalid_hp(workloads: Iterator[Any]) -> Any:
    try:
        yield
    except InvalidHP as e:
        logging.info("Invalid hyperparameter exception in trial __init__: {}".format(e))
        wkld, args, response_func = next(workloads)
        response_func(util.wrap_metrics({}, stop_requested=False, invalid_hp=True))
        raise
コード例 #4
0
 def _control_loop(self) -> None:
     """Handle workloads until a training step is requested or the trial terminates.

     RUN_STEP does not train here. It stores the step's ``response_func`` and
     batch budget on ``self`` and ``self.multiplexer``, then breaks out of the
     loop (presumably so the surrounding training loop can run the step; confirm
     against the caller). Validation and checkpoint workloads are answered in place.

     Raises:
         det.errors.WorkerFinishedGracefully: after responding to TERMINATE.
         AssertionError: for an unknown workload kind.
     """
     for wkld, args, response_func in self.workloads:
         logging.debug(f"Received wkld {wkld.kind} with args {args}.")
         if wkld.kind == workload.Workload.Kind.RUN_STEP:
             # Configure the state for a training step.
             self.train_response_func = response_func
             self.train_workload_batches = 0
             self.train_workload_metrics = []
             self.train_workload_len = wkld.num_batches
             self.multiplexer.set_batches_requested(wkld.num_batches)
             break
         elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
             # Guard only the user-code call so response_func runs exactly once.
             try:
                 metrics = self._compute_validation_metrics()
             except det.InvalidHP as e:
                 logging.info(
                     "Invalid hyperparameter exception in trial validation step: %s", e
                 )
                 metrics, invalid_hp = {}, True
             else:
                 invalid_hp = False
             # Both outcomes go through det.util.wrap_metrics; `det` is known to
             # be in scope here, whereas a bare `util` name is not guaranteed.
             response_func(
                 det.util.wrap_metrics(
                     metrics,
                     self.context.get_stop_requested(),
                     invalid_hp=invalid_hp,
                     init_invalid_hp=False,
                 )
             )
         elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
             # Exactly one argument is expected: the checkpoint path.
             check.len_eq(args, 1)
             check.is_instance(args[0], pathlib.Path)
             path = cast(pathlib.Path, args[0])
             response_func(self._save_checkpoint(path))
         elif wkld.kind == workload.Workload.Kind.TERMINATE:
             response_func({} if self.is_chief else workload.Skipped())
             self.multiplexer._corrected_train_end()
             raise det.errors.WorkerFinishedGracefully
         else:
             raise AssertionError(f"Unknown workload kind {wkld.kind}.")