示例#1
0
    def on_exception(self, state: _State):
        """Try to save an emergency checkpoint when an exception is reported.

        Returns immediately when ``utils.is_exception`` rejects
        ``state.exception``. Otherwise the current state is packed into a
        checkpoint and saved under ``{state.logdir}/checkpoints/`` with a
        suffix ending in ``.exception_<ExceptionClassName>``; the validation
        metrics are then stored under that suffix in ``self.metrics`` and
        written out with ``self.save_metric``.

        Saving is best-effort: any ``Exception`` raised while packing or
        saving is logged as a warning and suppressed instead of being raised.

        Args:
            state: run state; this method reads its ``exception``,
                ``metric_manager``, ``model``, ``criterion``, ``optimizer``,
                ``scheduler``, ``stage``, ``epoch_log``, ``checkpoint_data``
                and ``logdir`` attributes.
        """
        exception = state.exception
        if not utils.is_exception(exception):
            return

        try:
            valid_metrics = state.metric_manager.valid_values
            epoch_metrics = state.metric_manager.epoch_values
            checkpoint = utils.pack_checkpoint(
                model=state.model,
                criterion=state.criterion,
                optimizer=state.optimizer,
                scheduler=state.scheduler,
                epoch_metrics=epoch_metrics,
                valid_metrics=valid_metrics,
                stage=state.stage,
                epoch=state.epoch_log,
                checkpoint_data=state.checkpoint_data)
            suffix = self.get_checkpoint_suffix(checkpoint)
            # Tag the file with the exception type so it is recognisable.
            suffix = f"{suffix}.exception_{exception.__class__.__name__}"
            utils.save_checkpoint(logdir=Path(f"{state.logdir}/checkpoints/"),
                                  checkpoint=checkpoint,
                                  suffix=suffix,
                                  is_best=False,
                                  is_last=False)
            metrics = self.metrics
            metrics[suffix] = valid_metrics
            self.save_metric(state.logdir, metrics)
        except Exception:
            # Still best-effort, but no longer silent: report why the
            # emergency checkpoint could not be written.
            import logging

            logging.getLogger(__name__).warning(
                "Failed to save the exception checkpoint", exc_info=True)
示例#2
0
    def process_checkpoint(self,
                           logdir: Union[str, Path],
                           checkpoint: Dict,
                           is_best: bool,
                           main_metric: str = "loss",
                           minimize_metric: bool = True):
        """Save a full and a reduced checkpoint, then update metric records.

        The complete ``checkpoint`` is written first with the suffix
        ``"<suffix>_full"``. A reduced copy, without any key containing
        ``"criterion"``, ``"optimizer"`` or ``"scheduler"``, is written next
        with the plain suffix. The path of the reduced file and its
        validation metrics are appended to ``self.top_best_metrics`` and
        ``self.metrics_history``, after which the stored checkpoints are
        truncated and the processed metrics are saved.

        Args:
            logdir: log directory; files go to ``{logdir}/checkpoints/``.
            checkpoint: checkpoint dictionary; must hold ``"valid_metrics"``.
            is_best: forwarded to ``utils.save_checkpoint``.
            main_metric: key looked up in the validation metrics.
            minimize_metric: forwarded to ``self.truncate_checkpoints``.
        """
        suffix = self.get_checkpoint_suffix(checkpoint)
        checkpoints_dir = Path(f"{logdir}/checkpoints/")

        # 1) complete snapshot, marked with the "_full" suffix
        utils.save_checkpoint(
            logdir=checkpoints_dir,
            checkpoint=checkpoint,
            suffix=f"{suffix}_full",
            is_best=is_best,
            is_last=True,
            special_suffix="_full",
        )

        # 2) reduced snapshot without criterion/optimizer/scheduler entries
        excluded = ("criterion", "optimizer", "scheduler")
        reduced = {}
        for name, payload in checkpoint.items():
            if not any(part in name for part in excluded):
                reduced[name] = payload
        filepath = utils.save_checkpoint(
            checkpoint=reduced,
            logdir=checkpoints_dir,
            suffix=suffix,
            is_best=is_best,
            is_last=True,
        )

        # 3) bookkeeping
        valid_metrics = reduced["valid_metrics"]
        record = (filepath, valid_metrics[main_metric], valid_metrics)
        self.top_best_metrics.append(record)
        self.metrics_history.append(record)
        self.truncate_checkpoints(minimize_metric=minimize_metric)
        processed = self.process_metrics(valid_metrics)
        self.save_metric(logdir, processed)
示例#3
0
    def save_checkpoint(self,
                        logdir: str,
                        checkpoint: Dict,
                        save_n_best: int = 3,
                        minimize_metric: bool = False):
        """Save an agent checkpoint and keep only the ``save_n_best`` best.

        The checkpoint is written to ``{logdir}/checkpoints/`` with its
        ``"epoch"`` value as suffix. Its path and rewards are added to
        ``self.best_agents``, which is re-sorted so that the best entry comes
        first. If the list grows beyond ``save_n_best``, the last entry is
        dropped and its file is removed from disk.

        Args:
            logdir: log directory.
            checkpoint: checkpoint dictionary; must hold ``"rewards"`` and
                ``"epoch"``.
            save_n_best: maximum number of checkpoints to keep.
            minimize_metric: if ``True`` a lower metric is better, otherwise
                a higher one is.
        """
        agent_rewards = checkpoint["rewards"]
        agent_metric = self.rewards2metric(agent_rewards)

        if not self.best_agents:
            is_best = True
        else:
            # best_agents[0] is the current best thanks to the sort below.
            best_metric = self.rewards2metric(self.best_agents[0][1])
            # The comparison direction has to follow ``minimize_metric``;
            # a fixed ``>`` marks the wrong agents as best when minimizing.
            if minimize_metric:
                is_best = agent_metric < best_metric
            else:
                is_best = agent_metric > best_metric

        suffix = f"{checkpoint['epoch']}"
        filepath = utils.save_checkpoint(logdir=f"{logdir}/checkpoints/",
                                         checkpoint=checkpoint,
                                         suffix=suffix,
                                         is_best=is_best,
                                         is_last=True)

        self.best_agents.append((filepath, agent_rewards))
        # NOTE(review): ordering uses the raw rewards while ``is_best``
        # compares ``rewards2metric`` values - confirm both orderings agree.
        self.best_agents = sorted(self.best_agents,
                                  key=lambda x: x[1],
                                  reverse=not minimize_metric)
        if len(self.best_agents) > save_n_best:
            # Drop the worst entry together with its file on disk.
            last_item = self.best_agents.pop(-1)
            last_filepath = last_item[0]
            os.remove(last_filepath)
示例#4
0
    def process_checkpoint(self,
                           logdir: str,
                           checkpoint: Dict,
                           is_best: bool,
                           main_metric: str = "loss",
                           minimize_metric: bool = True):
        """Save a reduced checkpoint and update the metric records.

        Entries whose key contains ``"criterion"``, ``"optimizer"`` or
        ``"scheduler"`` are left out of the saved checkpoint. The saved file
        path and the validation metrics are appended to
        ``self.top_best_metrics`` and ``self.epochs_metrics``; afterwards the
        stored checkpoints are truncated and the metrics are written out.

        Args:
            logdir: log directory; files go to
                ``logdir / self.checkpoints_dir``.
            checkpoint: checkpoint dictionary; must hold ``"valid_metrics"``.
            is_best: forwarded to ``utils.save_checkpoint``.
            main_metric: key looked up in the validation metrics.
            minimize_metric: forwarded to ``self.truncate_checkpoints``.
        """
        suffix = self.get_checkpoint_suffix(checkpoint)

        # Keep everything except criterion/optimizer/scheduler entries.
        excluded = ("criterion", "optimizer", "scheduler")
        reduced = {}
        for name, payload in checkpoint.items():
            if not any(part in name for part in excluded):
                reduced[name] = payload

        target_dir = Path(logdir) / Path(self.checkpoints_dir)
        filepath = utils.save_checkpoint(
            checkpoint=reduced,
            logdir=target_dir,
            suffix=suffix,
            is_best=is_best,
            is_last=True,
        )

        valid_metrics = reduced["valid_metrics"]
        record = (filepath, valid_metrics[main_metric], valid_metrics)
        self.top_best_metrics.append(record)
        self.epochs_metrics.append(record)
        self.truncate_checkpoints(minimize_metric=minimize_metric)
        collected = self.get_metric(valid_metrics)
        self.save_metric(logdir, collected)
示例#5
0
 def _save_checkpoint(self):
     """Save the algorithm checkpoint every ``self.save_period`` epochs.

     Nothing happens unless ``self.epoch`` is a multiple of
     ``self.save_period``. Otherwise the checkpoint packed by
     ``self.algorithm`` gets the current epoch stored under ``"epoch"``, is
     saved to ``self.logdir`` with the epoch as suffix, and the value
     returned by ``utils.save_checkpoint`` is printed.
     """
     is_save_epoch = self.epoch % self.save_period == 0
     if not is_save_epoch:
         return

     checkpoint = self.algorithm.pack_checkpoint()
     checkpoint["epoch"] = self.epoch
     saved_path = utils.save_checkpoint(
         logdir=self.logdir,
         checkpoint=checkpoint,
         suffix=str(self.epoch),
     )
     print(f"Checkpoint saved to: {saved_path}")
示例#6
0
 def _save(self, runner: "IRunner", obj: Any, logprefix: str) -> str:
     """Save ``obj`` to ``"<logprefix>.pth"`` and return that path.

     When ``self.mode`` is not ``"model"``, ``obj`` is unpacked as keyword
     arguments into ``pack_checkpoint`` and the result is saved. In
     ``"model"`` mode a ``torch.nn.Module`` is unwrapped by the runner's
     engine and its ``state_dict()`` is saved through the engine, while a
     ``dict`` is packed as ``model=obj`` and saved as a checkpoint.

     Args:
         runner: runner whose ``engine`` synchronises, unwraps and saves
             the module.
         obj: object to save.
         logprefix: target path without the ``.pth`` extension.

     Returns:
         The path the object was saved to.

     Raises:
         NotImplementedError: in ``"model"`` mode when ``obj`` is neither a
             ``torch.nn.Module`` nor a ``dict``.
     """
     logpath = f"{logprefix}.pth"

     if self.mode != "model":
         save_checkpoint(pack_checkpoint(**obj), logpath)
         return logpath

     if issubclass(obj.__class__, torch.nn.Module):
         engine = runner.engine
         engine.wait_for_everyone()
         unwrapped = engine.unwrap_model(obj)
         engine.save(unwrapped.state_dict(), logpath)
         return logpath

     if isinstance(obj, dict):
         save_checkpoint(pack_checkpoint(model=obj), logpath)
         return logpath

     raise NotImplementedError()
示例#7
0
    def on_exception(self, state: State):
        """Try to save an emergency checkpoint when an exception is reported.

        Returns immediately when ``utils.is_exception`` rejects
        ``state.exception``. Otherwise the state is packed with
        ``_pack_state`` and saved under ``{state.logdir}/checkpoints/`` with
        a suffix ending in ``.exception_<ExceptionClassName>``;
        ``state.valid_metrics`` is then stored under that suffix in
        ``self.metrics`` and written out with ``self.save_metric``.

        Saving is best-effort: any ``Exception`` raised while packing or
        saving is logged as a warning and suppressed instead of being raised.

        Args:
            state: run state; this method reads its ``exception``,
                ``logdir`` and ``valid_metrics`` attributes and passes the
                whole object to ``_pack_state``.
        """
        exception = state.exception
        if not utils.is_exception(exception):
            return

        try:
            checkpoint = _pack_state(state)
            suffix = self.get_checkpoint_suffix(checkpoint)
            # Tag the file with the exception type so it is recognisable.
            suffix = f"{suffix}.exception_{exception.__class__.__name__}"
            utils.save_checkpoint(logdir=Path(f"{state.logdir}/checkpoints/"),
                                  checkpoint=checkpoint,
                                  suffix=suffix,
                                  is_best=False,
                                  is_last=False)
            metrics = self.metrics
            metrics[suffix] = state.valid_metrics
            self.save_metric(state.logdir, metrics)
        except Exception:
            # Still best-effort, but no longer silent: report why the
            # emergency checkpoint could not be written.
            import logging

            logging.getLogger(__name__).warning(
                "Failed to save the exception checkpoint", exc_info=True)
示例#8
0
    def process_checkpoint(self, logdir: str, checkpoint: Dict,
                           batch_values: Dict[str, float]):
        """Save a checkpoint and record the batch values that go with it.

        The checkpoint is written to ``{logdir}/checkpoints/`` and flagged
        neither as best nor as last. Its path is appended, together with
        ``batch_values``, to ``self.last_checkpoints``, which is then
        truncated. ``batch_values`` is also appended to
        ``self.epochs_metrics`` before the collected metrics are saved and
        the checkpoint path is printed.

        Args:
            logdir: log directory.
            checkpoint: checkpoint dictionary to save.
            batch_values: metric values stored alongside the checkpoint.
        """
        suffix = self.get_checkpoint_suffix(checkpoint)
        target_dir = Path(f"{logdir}/checkpoints/")
        filepath = utils.save_checkpoint(
            logdir=target_dir,
            checkpoint=checkpoint,
            suffix=suffix,
            is_best=False,
            is_last=False,
        )

        record = (filepath, batch_values)
        self.last_checkpoints.append(record)
        self.truncate_checkpoints()

        self.epochs_metrics.append(batch_values)

        collected = self.get_metric()
        self.save_metric(logdir, collected)
        print(f"\nSaved checkpoint at {filepath}")
示例#9
0
    def process_checkpoint(
        self,
        logdir: Union[str, Path],
        checkpoint: Dict,
        batch_metrics: Dict[str, float],
    ):
        """Save a checkpoint and record the batch metrics that go with it.

        The checkpoint is written to ``{logdir}/checkpoints/`` and flagged
        neither as best nor as last. Its path is appended, together with
        ``batch_metrics``, to ``self.last_checkpoints``, which is then
        truncated. ``batch_metrics`` is also appended to
        ``self.metrics_history`` before the processed metrics are saved and
        the checkpoint path is printed.

        Args:
            logdir: log directory.
            checkpoint: checkpoint dictionary to save.
            batch_metrics: metric values stored alongside the checkpoint.
        """
        suffix = self.get_checkpoint_suffix(checkpoint)
        target_dir = Path(f"{logdir}/checkpoints/")
        filepath = utils.save_checkpoint(
            logdir=target_dir,
            checkpoint=checkpoint,
            suffix=suffix,
            is_best=False,
            is_last=False,
        )

        record = (filepath, batch_metrics)
        self.last_checkpoints.append(record)
        self.truncate_checkpoints()

        self.metrics_history.append(batch_metrics)

        processed = self.process_metrics()
        self.save_metric(logdir, processed)
        print(f"\nSaved checkpoint at {filepath}")