예제 #1
0
    def process_checkpoint(self,
                           logdir: str,
                           checkpoint: Dict,
                           is_best: bool,
                           main_metric: str = "loss",
                           minimize_metric: bool = True):
        """Save a full and a stripped copy of ``checkpoint``, then update top-N state.

        Two files are written through ``utils.save_checkpoint`` under
        ``<logdir>/checkpoints/``:

        * an unfiltered copy with suffix ``"<suffix>_full"``;
        * a stripped copy from which every key whose name contains
          ``"criterion"``, ``"optimizer"`` or ``"scheduler"`` is removed.

        The stripped copy's path, its ``valid_metrics[main_metric]`` value and
        the ``valid_metrics`` dict are appended to ``self.top_best_metrics``.
        After that, ``self.truncate_checkpoints`` and ``self.save_metric`` are
        called.

        Args:
            logdir: experiment directory; files go to ``<logdir>/checkpoints/``.
            checkpoint: checkpoint dict; must contain a ``"valid_metrics"`` key.
            is_best: forwarded to both ``utils.save_checkpoint`` calls.
            main_metric: key of ``valid_metrics`` recorded as the ranking value.
            minimize_metric: forwarded to ``self.truncate_checkpoints``.

        Raises:
            KeyError: if ``"valid_metrics"`` or ``main_metric`` is missing.
        """
        suffix = self.get_checkpoint_suffix(checkpoint)
        checkpoints_dir = f"{logdir}/checkpoints/"

        # Unfiltered copy: every entry of ``checkpoint`` as given.
        utils.save_checkpoint(
            logdir=checkpoints_dir,
            checkpoint=checkpoint,
            suffix=f"{suffix}_full",
            is_best=is_best,
            is_last=True,
            special_suffix="_full",
        )

        # Stripped copy. The test is a substring match, so a key such as
        # "optimizer_state_dict" is dropped as well.
        excluded_parts = ("criterion", "optimizer", "scheduler")
        stripped = {}
        for key, value in checkpoint.items():
            if any(part in key for part in excluded_parts):
                continue
            stripped[key] = value

        saved_path = utils.save_checkpoint(
            logdir=checkpoints_dir,
            checkpoint=stripped,
            suffix=suffix,
            is_best=is_best,
            is_last=True,
        )

        current_metrics = stripped["valid_metrics"]
        self.top_best_metrics.append(
            (saved_path, current_metrics[main_metric], current_metrics))
        self.truncate_checkpoints(minimize_metric=minimize_metric)

        self.save_metric(logdir, self.get_metric(current_metrics))
예제 #2
0
    def on_exception(self, state: RunnerState):
        """Best-effort dump of a checkpoint when the run hits an exception.

        Does nothing unless ``is_exception(state.exception)`` is truthy.
        Otherwise it packs the model, criterion, optimizer, scheduler, metrics,
        stage, epoch and ``state.checkpoint_data`` into a checkpoint. The
        checkpoint is saved under ``<state.logdir>/checkpoints/`` with the
        exception class name appended to the suffix. The validation metrics
        are then stored in ``self.metrics`` under that suffix, and the result
        is passed to ``self.save_metric``.

        Failures while doing this are logged with a traceback instead of being
        re-raised. Presumably this keeps the handler from masking the original
        exception. Previously such failures were silently discarded.
        """
        exception = state.exception
        if not is_exception(exception):
            return

        try:
            valid_metrics = state.metrics.valid_values
            epoch_metrics = state.metrics.epoch_values
            checkpoint = utils.pack_checkpoint(
                model=state.model,
                criterion=state.criterion,
                optimizer=state.optimizer,
                scheduler=state.scheduler,
                epoch_metrics=epoch_metrics,
                valid_metrics=valid_metrics,
                stage=state.stage,
                epoch=state.epoch_log,
                checkpoint_data=state.checkpoint_data
            )
            # Tag the file with the exception type,
            # e.g. "<suffix>.exception_KeyboardInterrupt".
            suffix = self.get_checkpoint_suffix(checkpoint)
            suffix = f"{suffix}.exception_{exception.__class__.__name__}"
            utils.save_checkpoint(
                logdir=f"{state.logdir}/checkpoints/",
                checkpoint=checkpoint,
                suffix=suffix,
                is_best=False,
                is_last=False
            )
            # NOTE: this mutates ``self.metrics`` in place (no copy is made).
            metrics = self.metrics
            metrics[suffix] = valid_metrics
            self.save_metric(state.logdir, metrics)
        except Exception:
            # Stay best-effort (never raise from the exception handler), but
            # leave a trace instead of swallowing the failure silently.
            import logging
            logging.getLogger(__name__).warning(
                "Failed to save checkpoint after exception %s",
                type(exception).__name__,
                exc_info=True,
            )
예제 #3
0
    def save_checkpoint(
        self,
        logdir: str,
        checkpoint: Dict,
        is_best: bool,
        save_n_best: int = 5,
        main_metric: str = "loss",
        minimize_metric: bool = True
    ):
        """Save ``checkpoint`` and keep at most ``save_n_best`` ranked files on disk.

        The checkpoint is written via ``utils.save_checkpoint`` under
        ``<logdir>/checkpoints/`` with suffix ``"<stage>.<epoch>"``. The
        returned path and ``checkpoint["valid_metrics"][main_metric]`` are
        appended to ``self.top_best_metrics``. That list is then sorted best
        first: ascending when ``minimize_metric`` is true, otherwise
        descending. Entries beyond ``save_n_best`` are popped from the end and
        their files are deleted.

        Args:
            logdir: experiment directory.
            checkpoint: dict with ``"stage"``, ``"epoch"`` and ``"valid_metrics"``.
            is_best: forwarded to ``utils.save_checkpoint``.
            save_n_best: number of ranked checkpoints to keep (negative counts as 0).
            main_metric: key of ``valid_metrics`` used for ranking.
            minimize_metric: whether a lower metric value is better.

        Raises:
            KeyError: if a required checkpoint key or ``main_metric`` is missing.
            OSError: from ``os.remove`` if an evicted file cannot be deleted.
        """
        suffix = f"{checkpoint['stage']}.{checkpoint['epoch']}"
        filepath = utils.save_checkpoint(
            logdir=f"{logdir}/checkpoints/",
            checkpoint=checkpoint,
            suffix=suffix,
            is_best=is_best,
            is_last=True
        )

        checkpoint_metric = checkpoint["valid_metrics"][main_metric]
        self.top_best_metrics.append((filepath, checkpoint_metric))
        self.top_best_metrics = sorted(
            self.top_best_metrics,
            key=lambda item: item[1],
            reverse=not minimize_metric
        )
        # Evict in a loop (not a single pop) so the limit still holds if
        # ``save_n_best`` is lowered between calls.
        keep = max(save_n_best, 0)
        while len(self.top_best_metrics) > keep:
            worst_filepath = self.top_best_metrics.pop()[0]
            os.remove(worst_filepath)
예제 #4
0
    def save_checkpoint(self,
                        logdir: str,
                        checkpoint: Dict,
                        is_best: bool,
                        save_n_best: int = 5,
                        main_metric: str = "loss",
                        minimize_metric: bool = True):
        """Save ``checkpoint``, keep the ``save_n_best`` best files, dump a summary.

        The checkpoint is written via ``utils.save_checkpoint`` under
        ``<logdir>/checkpoints/`` with suffix ``"<stage>.<epoch>"``. The tuple
        ``(path, valid_metrics[main_metric], valid_metrics)`` is appended to
        ``self.top_best_metrics``. The list is sorted best first and trimmed to
        ``save_n_best`` entries, and the files of evicted entries are deleted.

        Finally an ordered summary is written to
        ``<logdir>/checkpoints/_metrics.json`` via ``safitty.save``. It holds
        ``"best"``, then one entry per retained checkpoint (keyed by the file
        stem), then ``"last"`` (this call's ``valid_metrics``).

        Args:
            logdir: experiment directory.
            checkpoint: dict with ``"stage"``, ``"epoch"`` and ``"valid_metrics"``.
            is_best: forwarded to ``utils.save_checkpoint``.
            save_n_best: number of ranked checkpoints to keep (negative counts as 0).
            main_metric: key of ``valid_metrics`` used for ranking.
            minimize_metric: whether a lower metric value is better.

        Raises:
            KeyError: if a required checkpoint key or ``main_metric`` is missing.
            OSError: from ``os.remove`` if an evicted file cannot be deleted.
        """
        suffix = f"{checkpoint['stage']}.{checkpoint['epoch']}"
        filepath = utils.save_checkpoint(logdir=f"{logdir}/checkpoints/",
                                         checkpoint=checkpoint,
                                         suffix=suffix,
                                         is_best=is_best,
                                         is_last=True)

        valid_metrics = checkpoint["valid_metrics"]
        checkpoint_metric = valid_metrics[main_metric]
        self.top_best_metrics.append(
            (filepath, checkpoint_metric, valid_metrics))
        self.top_best_metrics = sorted(self.top_best_metrics,
                                       key=lambda item: item[1],
                                       reverse=not minimize_metric)
        # Evict in a loop (not a single pop) so the limit still holds if
        # ``save_n_best`` is lowered between calls.
        keep = max(save_n_best, 0)
        while len(self.top_best_metrics) > keep:
            worst_filepath = self.top_best_metrics.pop()[0]
            os.remove(worst_filepath)

        checkpoints = [
            (Path(kept_path).stem, kept_metrics)
            for kept_path, _, kept_metrics in self.top_best_metrics
        ]
        # When nothing is retained (save_n_best <= 0), fall back to the latest
        # metrics instead of raising IndexError on ``checkpoints[0]``.
        best_valid_metrics = (
            checkpoints[0][1] if checkpoints else valid_metrics
        )
        metrics = OrderedDict([("best", best_valid_metrics)] + checkpoints +
                              [("last", valid_metrics)])
        safitty.save(metrics, f"{logdir}/checkpoints/_metrics.json")
예제 #5
0
    def process_checkpoint(self, logdir: str, checkpoint: Dict,
                           batch_values: Dict[str, float]):
        """Save an intermediate checkpoint and refresh the last-N bookkeeping.

        The file is written by ``utils.save_checkpoint`` under
        ``<logdir>/checkpoints/`` using ``self.get_checkpoint_suffix``. It is
        flagged neither best nor last. ``(path, batch_values)`` is appended to
        ``self.last_checkpoints``. After that, ``self.truncate_checkpoints``
        and ``self.save_metric`` run, and the saved path is printed.

        Args:
            logdir: experiment directory.
            checkpoint: checkpoint dict to persist.
            batch_values: metric values stored alongside the saved path.
        """
        checkpoints_dir = f"{logdir}/checkpoints/"
        suffix = self.get_checkpoint_suffix(checkpoint)
        saved_path = utils.save_checkpoint(
            logdir=checkpoints_dir,
            checkpoint=checkpoint,
            suffix=suffix,
            is_best=False,
            is_last=False,
        )

        self.last_checkpoints.append((saved_path, batch_values))
        self.truncate_checkpoints()

        summary = self.get_metric()
        self.save_metric(logdir, summary)
        print(f"\nSaved checkpoint at {saved_path}")
예제 #6
0
    def save_checkpoint(self, logdir, checkpoint, save_n_last):
        """Save an iteration checkpoint and keep only the ``save_n_last`` newest files.

        The file is written by ``utils.save_checkpoint`` under
        ``<logdir>/checkpoints/`` with suffix
        ``"<stage>.epoch.<epoch>.iter.<self._iteration_counter>"``. Its path is
        appended to ``self.last_checkpoints``. The oldest paths are then popped
        from the front and deleted until at most ``save_n_last`` remain.

        Args:
            logdir: experiment directory.
            checkpoint: dict with ``"stage"`` and ``"epoch"`` keys.
            save_n_last: number of recent checkpoints to keep (negative counts as 0).

        Raises:
            OSError: from ``os.remove`` if an evicted file cannot be deleted.
        """
        suffix = f"{checkpoint['stage']}." \
                 f"epoch.{checkpoint['epoch']}." \
                 f"iter.{self._iteration_counter}"

        filepath = utils.save_checkpoint(logdir=f"{logdir}/checkpoints/",
                                         checkpoint=checkpoint,
                                         suffix=suffix,
                                         is_best=False,
                                         is_last=False)

        self.last_checkpoints.append(filepath)
        # Evict in a loop (not a single pop) so the limit still holds if
        # ``save_n_last`` is lowered between calls.
        keep = max(save_n_last, 0)
        while len(self.last_checkpoints) > keep:
            oldest_filepath = self.last_checkpoints.pop(0)
            os.remove(oldest_filepath)

        print(f"\nSaved checkpoint at {filepath}")
예제 #7
0
    def process_checkpoint(self, logdir: str, checkpoint: Dict,
                           batch_values: Dict[str, float]):
        """Save an iteration checkpoint and refresh the last-N bookkeeping.

        The file is written by ``utils.save_checkpoint`` under
        ``<logdir>/checkpoints/`` with suffix
        ``"<stage>.epoch.<epoch>.iter.<self._iteration_counter>"``. It is
        flagged neither best nor last. ``(path, batch_values)`` is appended to
        ``self.last_checkpoints``. After that, ``self.truncate_checkpoints``
        and ``self.save_metric`` run, and the saved path is printed.

        Args:
            logdir: experiment directory.
            checkpoint: dict with ``"stage"`` and ``"epoch"`` keys.
            batch_values: metric values stored alongside the saved path.
        """
        stage = checkpoint['stage']
        epoch = checkpoint['epoch']
        iteration = self._iteration_counter
        suffix = f"{stage}.epoch.{epoch}.iter.{iteration}"

        saved_path = utils.save_checkpoint(
            logdir=f"{logdir}/checkpoints/",
            checkpoint=checkpoint,
            suffix=suffix,
            is_best=False,
            is_last=False,
        )

        self.last_checkpoints.append((saved_path, batch_values))
        self.truncate_checkpoints()

        summary = self.get_metric()
        self.save_metric(logdir, summary)
        print(f"\nSaved checkpoint at {saved_path}")
예제 #8
0
    def process_checkpoint(self,
                           logdir: str,
                           checkpoint: Dict,
                           is_best: bool,
                           main_metric: str = "loss",
                           minimize_metric: bool = True):
        """Save ``checkpoint`` and update the ranked (top-N) bookkeeping.

        The file is written by ``utils.save_checkpoint`` under
        ``<logdir>/checkpoints/`` using ``self.get_checkpoint_suffix``, with
        ``is_last=True``. The tuple
        ``(path, valid_metrics[main_metric], valid_metrics)`` is appended to
        ``self.top_best_metrics``. After that, ``self.truncate_checkpoints``
        and ``self.save_metric`` are called.

        Args:
            logdir: experiment directory.
            checkpoint: checkpoint dict; must contain ``"valid_metrics"``.
            is_best: forwarded to ``utils.save_checkpoint``.
            main_metric: key of ``valid_metrics`` recorded as the ranking value.
            minimize_metric: forwarded to ``self.truncate_checkpoints``.

        Raises:
            KeyError: if ``"valid_metrics"`` or ``main_metric`` is missing.
        """
        checkpoints_dir = f"{logdir}/checkpoints/"
        suffix = self.get_checkpoint_suffix(checkpoint)
        saved_path = utils.save_checkpoint(
            logdir=checkpoints_dir,
            checkpoint=checkpoint,
            suffix=suffix,
            is_best=is_best,
            is_last=True,
        )

        current_metrics = checkpoint["valid_metrics"]
        score = current_metrics[main_metric]
        self.top_best_metrics.append((saved_path, score, current_metrics))
        self.truncate_checkpoints(minimize_metric=minimize_metric)

        summary = self.get_metric(current_metrics)
        self.save_metric(logdir, summary)