Ejemplo n.º 1
0
    def on_exception(self, runner: IRunner):
        """Best-effort checkpoint dump when the run ends with an exception.

        If ``runner.exception`` is a real exception (per
        ``utils.is_exception``), the runner is packed into a checkpoint.
        The checkpoint is saved under ``<runner.logdir>/checkpoints/`` with
        the suffix ``<checkpoint suffix>.exception_<ExceptionClassName>``.
        ``runner.valid_metrics`` is then recorded under that suffix via
        ``self.save_metric``.

        Any error raised while saving is logged with its traceback and then
        suppressed, so this hook itself never raises from the save step.

        Args:
            runner: the current runner; ``runner.exception``,
                ``runner.device``, ``runner.logdir`` and
                ``runner.valid_metrics`` are read.
        """
        exception = runner.exception
        if not utils.is_exception(exception):
            return

        # Pick the serializer matching the device: torch_xla's ``save`` on
        # XLA devices, plain ``torch.save`` otherwise.
        if runner.device.type == "xla":
            from torch_xla.core.xla_model import save
        else:
            from torch import save

        try:
            checkpoint = _pack_runner(runner)
            suffix = self.get_checkpoint_suffix(checkpoint)
            suffix = f"{suffix}.exception_{exception.__class__.__name__}"
            utils.save_checkpoint(
                logdir=Path(f"{runner.logdir}/checkpoints/"),
                checkpoint=checkpoint,
                suffix=suffix,
                is_best=False,
                is_last=False,
                saver_fn=save,
            )
            # ``metrics`` aliases ``self.metrics``; the new entry is added
            # in place before persisting.
            metrics = self.metrics
            metrics[suffix] = runner.valid_metrics
            self.save_metric(runner.logdir, metrics)
        except Exception:
            # Keep the hook best-effort (no re-raise), but surface the
            # failure instead of silently discarding it. ``logging`` does
            # not propagate handler errors, so this cannot raise here.
            import logging

            logging.getLogger(__name__).warning(
                "Failed to save exception checkpoint", exc_info=True
            )
    def _save_checkpoint(
        self,
        logdir: Union[str, Path],
        suffix: str,
        checkpoint: Dict,
        is_best: bool,
        is_last: bool,
    ) -> Tuple[Union[str, None], str]:
        """
        Save a slim checkpoint and, if enabled, a full checkpoint.

        The full checkpoint (only when ``self.save_full`` is truthy) is the
        unmodified ``checkpoint`` saved with suffix ``<suffix>_full``. The
        slim checkpoint drops every key whose name contains ``"criterion"``,
        ``"optimizer"`` or ``"scheduler"`` and is saved with ``suffix``.
        Both go to ``<logdir>/checkpoints/``.

        Args:
            logdir (str or Path object): root log directory
            suffix (str): checkpoint suffix
            checkpoint (dict): checkpoint data
            is_best (bool): forwarded to ``utils.save_checkpoint``
            is_last (bool): forwarded to ``utils.save_checkpoint``

        Returns:
            tuple: ``(full_checkpoint_path, checkpoint_path)``; the first
            element is ``None`` when ``self.save_full`` is falsy.
        """
        checkpoints_dir = Path(f"{logdir}/checkpoints/")

        if self.save_full:
            full_checkpoint_path = utils.save_checkpoint(
                logdir=checkpoints_dir,
                checkpoint=checkpoint,
                suffix=f"{suffix}_full",
                is_best=is_best,
                is_last=is_last,
                special_suffix="_full",
            )
        else:
            full_checkpoint_path = None

        # Substring match: any key mentioning these components is dropped.
        exclude = ["criterion", "optimizer", "scheduler"]
        checkpoint_path = utils.save_checkpoint(
            checkpoint={
                key: value
                for key, value in checkpoint.items()
                if all(z not in key for z in exclude)
            },
            logdir=checkpoints_dir,
            suffix=suffix,
            is_best=is_best,
            is_last=is_last,
        )
        return full_checkpoint_path, checkpoint_path
Ejemplo n.º 3
0
    def process_checkpoint(
        self,
        logdir: Union[str, Path],
        checkpoint: Dict,
        batch_metrics: Dict[str, float],
    ):
        """
        Persist an intermediate checkpoint and refresh the metrics file.

        Args:
            logdir (str or Path object): root log directory; the checkpoint
                is written into its ``checkpoints/`` subfolder
            checkpoint (dict): checkpoint payload to save
            batch_metrics (dict): metrics computed over recent batches;
                stored next to the checkpoint path and in the history
        """
        checkpoint_suffix = self.get_checkpoint_suffix(checkpoint)
        saved_path = utils.save_checkpoint(
            logdir=Path(f"{logdir}/checkpoints/"),
            checkpoint=checkpoint,
            suffix=checkpoint_suffix,
            is_best=False,
            is_last=False,
            saver_fn=self._save_fn,
        )

        # Register the new file, then let truncate_checkpoints() apply its
        # retention policy (presumably pruning older checkpoints).
        self.last_checkpoints.append((saved_path, batch_metrics))
        self.truncate_checkpoints()

        self.metrics_history.append(batch_metrics)
        self.save_metric(logdir, self.process_metrics())
        print(f"\nSaved checkpoint at {saved_path}")
Ejemplo n.º 4
0
    def _save_checkpoint(
        self,
        logdir: Union[str, Path],
        suffix: str,
        checkpoint: Dict,
        is_best: bool,
        is_last: bool,
    ) -> Tuple[str, str]:
        """
        Write both the full and the slim variant of a checkpoint.

        The full variant is ``checkpoint`` as-is, saved with suffix
        ``<suffix>_full``. The slim variant omits every key whose name
        contains ``"criterion"``, ``"optimizer"`` or ``"scheduler"`` and is
        saved with ``suffix``. Both are written to ``<logdir>/checkpoints/``
        using ``self._save_fn``.

        Args:
            logdir (str or Path object): directory for storing checkpoints
            suffix (str): checkpoint suffix
            checkpoint (dict): dict with checkpoint data
            is_best (bool): indicator to save best checkpoint,
                if true then will be saved two additional checkpoints -
                ``best`` and ``best_full``.
            is_last (bool): indicator to save the last checkpoint,
                if true then will be saved two additional checkpoints -
                ``last`` and ``last_full``.

        Returns:
            tuple: ``(full_checkpoint_path, checkpoint_path)``
        """
        shared_kwargs = dict(
            logdir=Path(f"{logdir}/checkpoints/"),
            is_best=is_best,
            is_last=is_last,
            saver_fn=self._save_fn,
        )
        full_path = utils.save_checkpoint(
            checkpoint=checkpoint,
            suffix=f"{suffix}_full",
            special_suffix="_full",
            **shared_kwargs,
        )

        training_only = ("criterion", "optimizer", "scheduler")
        slim_checkpoint = {
            name: payload
            for name, payload in checkpoint.items()
            if not any(token in name for token in training_only)
        }
        slim_path = utils.save_checkpoint(
            checkpoint=slim_checkpoint,
            suffix=suffix,
            **shared_kwargs,
        )
        return (full_path, slim_path)
Ejemplo n.º 5
0
    def process_checkpoint(
        self,
        logdir: Union[str, Path],
        checkpoint: Dict,
        is_best: bool,
        main_metric: str = "loss",
        minimize_metric: bool = True,
    ):
        """
        Save full and slim checkpoints and update the metric records.

        Both checkpoints go to ``<logdir>/checkpoints/`` with
        ``is_last=True``. The full one is saved with suffix
        ``<suffix>_full``. The slim one omits keys containing
        ``"criterion"``, ``"optimizer"`` or ``"scheduler"``. A record
        ``(slim_path, valid_metrics[main_metric], valid_metrics)`` is
        appended to both ``self.top_best_metrics`` and
        ``self.metrics_history``.

        Args:
            logdir (str or Path object): root log directory
            checkpoint (dict): checkpoint data; must contain
                ``"valid_metrics"`` mapping that includes ``main_metric``
            is_best (bool): forwarded to ``utils.save_checkpoint`` for both
                saves
            main_metric (str): key in ``valid_metrics`` stored as the
                checkpoint's score
            minimize_metric (bool): forwarded to ``truncate_checkpoints``
                (presumably True means lower is better)
        """
        suffix = self.get_checkpoint_suffix(checkpoint)
        save_dir = Path(f"{logdir}/checkpoints/")

        utils.save_checkpoint(
            logdir=save_dir,
            checkpoint=checkpoint,
            suffix=f"{suffix}_full",
            is_best=is_best,
            is_last=True,
            special_suffix="_full",
        )

        dropped = ("criterion", "optimizer", "scheduler")
        slim = {}
        for name, payload in checkpoint.items():
            if any(token in name for token in dropped):
                continue
            slim[name] = payload
        filepath = utils.save_checkpoint(
            checkpoint=slim,
            logdir=save_dir,
            suffix=suffix,
            is_best=is_best,
            is_last=True,
        )

        valid_metrics = slim["valid_metrics"]
        record = (filepath, valid_metrics[main_metric], valid_metrics)
        self.top_best_metrics.append(record)
        self.metrics_history.append(record)
        self.truncate_checkpoints(minimize_metric=minimize_metric)
        self.save_metric(logdir, self.process_metrics(valid_metrics))
Ejemplo n.º 6
0
    def on_exception(self, state: State):
        """Best-effort checkpoint dump when the run ends with an exception.

        If ``state.exception`` is a real exception (per
        ``utils.is_exception``), the state is packed into a checkpoint.
        The checkpoint is saved under ``<state.logdir>/checkpoints/`` with
        the suffix ``<checkpoint suffix>.exception_<ExceptionClassName>``.
        ``state.valid_metrics`` is then recorded under that suffix via
        ``self.save_metric``.

        Any error raised while saving is logged with its traceback and then
        suppressed, so this hook itself never raises from the save step.

        Args:
            state: the current run state; ``state.exception``,
                ``state.logdir`` and ``state.valid_metrics`` are read.
        """
        exception = state.exception
        if not utils.is_exception(exception):
            return

        try:
            checkpoint = _pack_state(state)
            suffix = self.get_checkpoint_suffix(checkpoint)
            suffix = f"{suffix}.exception_{exception.__class__.__name__}"
            utils.save_checkpoint(
                logdir=Path(f"{state.logdir}/checkpoints/"),
                checkpoint=checkpoint,
                suffix=suffix,
                is_best=False,
                is_last=False,
            )
            # ``metrics`` aliases ``self.metrics``; the new entry is added
            # in place before persisting.
            metrics = self.metrics
            metrics[suffix] = state.valid_metrics
            self.save_metric(state.logdir, metrics)
        except Exception:
            # Keep the hook best-effort (no re-raise), but surface the
            # failure instead of silently discarding it.
            import logging

            logging.getLogger(__name__).warning(
                "Failed to save exception checkpoint", exc_info=True
            )
Ejemplo n.º 7
0
    def on_exception(self, runner: IRunner):
        """Best-effort checkpoint dump when the run ends with an exception.

        If ``runner.exception`` is a real exception (per
        ``utils.is_exception``), the runner is packed into a checkpoint.
        The checkpoint is saved under ``<runner.logdir>/checkpoints/`` with
        the suffix ``<checkpoint suffix>.exception_<ExceptionClassName>``.
        ``runner.valid_metrics`` is then recorded under that suffix via
        ``self.save_metric``.

        Any error raised while saving is logged with its traceback and then
        suppressed, so this hook itself never raises from the save step.

        Args:
            runner: the current runner; ``runner.exception``,
                ``runner.logdir`` and ``runner.valid_metrics`` are read.
        """
        exception = runner.exception
        if not utils.is_exception(exception):
            return

        try:
            checkpoint = _pack_runner(runner)
            suffix = self.get_checkpoint_suffix(checkpoint)
            suffix = f"{suffix}.exception_{exception.__class__.__name__}"
            utils.save_checkpoint(
                logdir=Path(f"{runner.logdir}/checkpoints/"),
                checkpoint=checkpoint,
                suffix=suffix,
                is_best=False,
                is_last=False,
            )
            # ``metrics`` aliases ``self.metrics``; the new entry is added
            # in place before persisting.
            metrics = self.metrics
            metrics[suffix] = runner.valid_metrics
            self.save_metric(runner.logdir, metrics)
        except Exception:
            # Keep the hook best-effort (no re-raise), but surface the
            # failure instead of silently discarding it.
            import logging

            logging.getLogger(__name__).warning(
                "Failed to save exception checkpoint", exc_info=True
            )
Ejemplo n.º 8
0
    def process_checkpoint(
        self,
        logdir: Union[str, Path],
        checkpoint: Dict,
        batch_metrics: Dict[str, float],
    ):
        """
        Persist an intermediate checkpoint and refresh the metrics file.

        Args:
            logdir (str or Path object): root log directory; the checkpoint
                is written into its ``checkpoints/`` subfolder
            checkpoint (dict): checkpoint payload to save
            batch_metrics (dict): metrics computed over recent batches;
                stored next to the checkpoint path and in the history
        """
        target_dir = Path(f"{logdir}/checkpoints/")
        checkpoint_suffix = self.get_checkpoint_suffix(checkpoint)
        saved_path = utils.save_checkpoint(
            logdir=target_dir,
            checkpoint=checkpoint,
            suffix=checkpoint_suffix,
            is_best=False,
            is_last=False,
        )

        # Register the new file, then let truncate_checkpoints() apply its
        # retention policy (presumably pruning older checkpoints).
        self.last_checkpoints.append((saved_path, batch_metrics))
        self.truncate_checkpoints()

        self.metrics_history.append(batch_metrics)
        aggregated = self.process_metrics()
        self.save_metric(logdir, aggregated)
        print(f"\nSaved checkpoint at {saved_path}")