示例#1
0
 def update(self):
     """Snapshot the current model, optimizer and engine state in memory.

     Only taken outside a Ray Tune session; inside a session this is a
     no-op.
     """
     if tune.is_session_enabled():
         return
     self.best_state = {
         "model": copy.deepcopy(self.network.state_dict()),
         "optimizer": copy.deepcopy(self.optimizer.state_dict()),
         # Shallow copy of the engine state, matching the original intent.
         "engine": copy.copy(self.trainer.state),
     }
示例#2
0
    def __init__(self,
                 backend_config: BackendConfig,
                 num_workers: int = 1,
                 num_cpus_per_worker: float = 1,
                 num_gpus_per_worker: float = 0,
                 additional_resources_per_worker: Optional[Dict[str,
                                                                float]] = None,
                 max_retries: int = 3):
        """Configure the executor and create an inactive worker group.

        Args:
            backend_config: Backend configuration; its ``backend_cls``
                is instantiated here.
            num_workers: Number of training workers.
            num_cpus_per_worker: CPUs reserved per worker.
            num_gpus_per_worker: GPUs reserved per worker.
            additional_resources_per_worker: Extra custom resources
                reserved per worker, if any.
            max_retries: Failure budget; negative values are normalized
                to infinity (unbounded retries).
        """
        self._backend_config = backend_config
        self._backend = backend_config.backend_cls()
        self._num_workers = num_workers
        self._num_cpus_per_worker = num_cpus_per_worker
        self._num_gpus_per_worker = num_gpus_per_worker
        self._additional_resources_per_worker = additional_resources_per_worker
        # A negative retry budget means "never stop retrying".
        self._max_failures = float("inf") if max_retries < 0 else max_retries
        self._num_failures = 0
        self._initialization_hook = None

        # Checkpoints go through Tune's manager when a session is active.
        in_tune_session = tune is not None and tune.is_session_enabled()
        self.checkpoint_manager = (TuneCheckpointManager()
                                   if in_tune_session else CheckpointManager())

        self.worker_group = InactiveWorkerGroup()

        self.checkpoint_manager.on_init()
示例#3
0
 def on_epoch_completed(self, engine, train_loader, tune_loader):
     """Print epoch metrics, report to Ray Tune, and early-stop on loss.

     Args:
         engine: Ignite engine driving training; terminated on early stop.
         train_loader: Unused here; train metrics come from
             ``self.trainer.state.metrics``.
         tune_loader: Loader evaluated via ``self.evaluator``.
     """
     train_metrics = self.trainer.state.metrics
     print("Metrics Epoch", engine.state.epoch)
     justify = max(len(k) for k in train_metrics) + 2
     for k, v in train_metrics.items():
         # isinstance is the idiomatic check and also covers float
         # subclasses (e.g. numpy floats).
         if isinstance(v, float):
             print("train {:<{justify}} {:<5f}".format(k,
                                                       v,
                                                       justify=justify))
             continue
         print("train {:<{justify}} {:<5}".format(k, v, justify=justify))
     self.evaluator.run(tune_loader)
     tune_metrics = self.evaluator.state.metrics
     if tune.is_session_enabled():
         tune.report(mean_loss=tune_metrics["loss"])
     justify = max(len(k) for k in tune_metrics) + 2
     for k, v in tune_metrics.items():
         if isinstance(v, float):
             print("tune {:<{justify}} {:<5f}".format(k, v,
                                                      justify=justify))
             continue
         # Fix: non-float tune metrics were silently skipped; print them
         # the same way the train loop above does.
         print("tune {:<{justify}} {:<5}".format(k, v, justify=justify))
     if tune_metrics["loss"] < self.best_loss:
         self.best_loss = tune_metrics["loss"]
         self.counter = 0
         self.update()
     else:
         self.counter += 1
     if self.counter == self.patience:
         self.logger.info(
             "Early Stopping: No improvement for {} epochs".format(
                 self.patience))
         engine.terminate()
示例#4
0
def _try_add_tune_callback(kwargs: Dict):
    """Ensure ``kwargs["callbacks"]`` contains a ``RayTuneReportCallback``.

    No-op unless Tune is installed and a session is active. If a
    ``RayTuneReportCallback`` is already registered, nothing changes;
    otherwise one is appended and written back into *kwargs*.
    """
    if TUNE_INSTALLED and tune.is_session_enabled():
        # Fix: ``callbacks`` may be present but explicitly None, which
        # would crash the iteration below — fall back to an empty list.
        callbacks = kwargs.get("callbacks", []) or []
        for callback in callbacks:
            if isinstance(callback, RayTuneReportCallback):
                return
        callbacks.append(RayTuneReportCallback())
        kwargs["callbacks"] = callbacks
示例#5
0
def _try_add_tune_callback(kwargs: Dict):
    """Normalize the ``callbacks`` entry of *kwargs* for Ray Tune reporting.

    When Tune is installed and a session is active, the callback list is
    rebuilt so it contains at least one xgboost_ray Tune callback:

    - native ``TuneReportCallback`` / ``TuneReportCheckpointCallback``
      instances are kept as-is;
    - callbacks from ``ray.tune.integration.xgboost`` are swapped for
      their xgboost_ray equivalents (a warning is logged);
    - if no Tune callback is found at all, a default
      ``TuneReportCallback`` is appended.

    Returns:
        bool: True if an active Tune session was detected and the
        callbacks were processed, False otherwise.
    """
    if TUNE_INSTALLED and is_session_enabled():
        # ``callbacks`` may be missing or explicitly None.
        callbacks = kwargs.get("callbacks", []) or []
        new_callbacks = []
        has_tune_callback = False

        REPLACE_MSG = "Replaced `{orig}` with `{target}`. If you want to " \
                      "avoid this warning, pass `{target}` as a callback " \
                      "directly in your calls to `xgboost_ray.train()`."

        for cb in callbacks:
            if isinstance(cb,
                          (TuneReportCallback, TuneReportCheckpointCallback)):
                # Already an xgboost_ray Tune callback: keep unchanged.
                has_tune_callback = True
                new_callbacks.append(cb)
            elif isinstance(cb, OrigTuneReportCallback):
                # Upstream ray.tune callback: rebuild the xgboost_ray
                # equivalent with the same metrics.
                replace_cb = TuneReportCallback(metrics=cb._metrics)
                new_callbacks.append(replace_cb)
                logging.warning(
                    REPLACE_MSG.format(
                        orig="ray.tune.integration.xgboost.TuneReportCallback",
                        target="xgboost_ray.tune.TuneReportCallback"))
                has_tune_callback = True
            elif isinstance(cb, OrigTuneReportCheckpointCallback):
                # On TUNE_LEGACY builds the replacement is constructed
                # without ``frequency`` — presumably unsupported there
                # (NOTE(review): confirm against the targeted Tune versions).
                if TUNE_LEGACY:
                    replace_cb = TuneReportCheckpointCallback(
                        metrics=cb._report._metrics,
                        filename=cb._checkpoint._filename)
                else:
                    replace_cb = TuneReportCheckpointCallback(
                        metrics=cb._report._metrics,
                        filename=cb._checkpoint._filename,
                        frequency=cb._checkpoint._frequency)
                new_callbacks.append(replace_cb)
                logging.warning(
                    REPLACE_MSG.format(
                        orig="ray.tune.integration.xgboost."
                        "TuneReportCheckpointCallback",
                        target="xgboost_ray.tune.TuneReportCheckpointCallback")
                )
                has_tune_callback = True
            else:
                # Non-Tune callback: pass through untouched.
                new_callbacks.append(cb)

        if not has_tune_callback:
            # Todo: Maybe add checkpointing callback
            new_callbacks.append(TuneReportCallback())

        kwargs["callbacks"] = new_callbacks
        return True
    else:
        return False
示例#6
0
 def on_training_completed(self, engine, loader):
     """After training (outside a Tune session): save, reload via
     ``self.load()``, evaluate, and print float-valued metrics.
     """
     if tune.is_session_enabled():
         return
     self.save()
     self.load()
     self.evaluator.run(loader)
     final_metrics = self.evaluator.state.metrics
     print("Metrics Epoch", engine.state.epoch)
     width = max(len(name) for name in final_metrics) + 2
     for name, value in final_metrics.items():
         # Only float-valued metrics are printed here.
         if type(value) is float:
             print("best {:<{justify}} {:<5f}".format(
                 name, value, justify=width))
示例#7
0
 def load(self):
     """Restore model, optimizer and engine state from a checkpoint.

     Inside a Tune session the checkpoint path comes from the Tune
     checkpoint dir for the current epoch; otherwise it is
     ``best_checkpoint.pt`` under ``self.job_dir``. When no file exists
     the method logs and returns, keeping the fresh engine.
     """
     if tune.is_session_enabled():
         with tune.checkpoint_dir(
                 step=self.trainer.state.epoch) as checkpoint_dir:
             path = os.path.join(checkpoint_dir, "checkpoint.pt")
     else:
         path = os.path.join(self.job_dir, "best_checkpoint.pt")
     if not os.path.exists(path):
         self.logger.info(
             "Checkpoint {} does not exist, starting a new engine".format(
                 path))
         return
     self.logger.info("Loading saved checkpoint {}".format(path))
     saved_state = torch.load(path)
     self.network.load_state_dict(saved_state["model"])
     self.optimizer.load_state_dict(saved_state["optimizer"])
     self.trainer.state = saved_state["engine"]
示例#8
0
 def train(config, checkpoint_dir=None):
     """Return ``{"active": <session status>}``; report it to Tune when
     a session is active.
     """
     result = {"active": tune.is_session_enabled()}
     if result["active"]:
         tune.report(**result)
     return result
示例#9
0
 def train(config, checkpoint_dir=None):
     """Return the Tune session status, reporting it first when active."""
     session_active = tune.is_session_enabled()
     if not session_active:
         return session_active
     tune.report(active=session_active)
     return session_active
示例#10
0
 def _is_tune_enabled(self):
     """Report whether this Trainer runs inside a Ray Tune session."""
     # `B if A else A` is exactly equivalent to `A and B`.
     return tune.is_session_enabled() if TUNE_INSTALLED else TUNE_INSTALLED
示例#11
0
 def _is_tune_enabled(self):
     """Report whether this Trainer runs inside a Ray Tune session."""
     if tune is None:
         return False
     return tune.is_session_enabled()