Exemplo n.º 1
0
    def on_epoch_end(self, state: RunnerState):
        """Pack the current training state into a checkpoint at epoch end.

        Stages whose name starts with ``"infer"`` are skipped entirely.
        Otherwise the model, criterion, optimizer and scheduler are packed
        together with shallow copies of the validation and epoch metric
        mappings, the stage name, ``state.epoch_log`` and
        ``state.checkpoint_data`` via ``utils.pack_checkpoint``. The result
        is handed to ``self.process_checkpoint`` (defined elsewhere), which
        also receives the log directory, the ``is_best`` flag and the main
        metric settings.

        Args:
            state: current runner state.
        """
        is_inference_stage = state.stage.startswith("infer")
        if is_inference_stage:
            return

        metric_manager = state.metrics
        # dict(...) takes shallow copies, presumably so that later updates to
        # the metric manager do not alter what was packed for this epoch.
        valid_snapshot = dict(metric_manager.valid_values)
        epoch_snapshot = dict(metric_manager.epoch_values)

        packed = utils.pack_checkpoint(
            model=state.model,
            criterion=state.criterion,
            optimizer=state.optimizer,
            scheduler=state.scheduler,
            epoch_metrics=epoch_snapshot,
            valid_metrics=valid_snapshot,
            stage=state.stage,
            epoch=state.epoch_log,
            checkpoint_data=state.checkpoint_data,
        )

        self.process_checkpoint(
            logdir=state.logdir,
            checkpoint=packed,
            is_best=metric_manager.is_best,
            main_metric=state.main_metric,
            minimize_metric=state.minimize_metric,
        )
Exemplo n.º 2
0
    def on_exception(self, state: RunnerState):
        """Best-effort save of a checkpoint when the run raises an exception.

        Does nothing unless ``is_exception(state.exception)`` is truthy.
        Otherwise it first packs the model, criterion, optimizer and scheduler
        with copies of the metric values. It then saves the checkpoint under
        ``{logdir}/checkpoints/`` with a suffix ending in
        ``.exception_<ExceptionClassName>`` (``is_best=False``,
        ``is_last=False``). Finally it records the validation metrics under
        that suffix in ``self.metrics`` and persists them via
        ``self.save_metric``.

        Any error raised while doing this is logged (with traceback) and
        suppressed rather than propagated.

        Args:
            state: current runner state.
        """
        import logging

        exception = state.exception
        if not is_exception(exception):
            return

        def _snapshot(values):
            # Shallow copy, matching on_epoch_end, so the saved checkpoint and
            # ``self.metrics`` do not alias the metric manager's live dicts.
            # ``None`` is passed through unchanged, as the old code did.
            return None if values is None else dict(values)

        try:
            valid_metrics = _snapshot(state.metrics.valid_values)
            epoch_metrics = _snapshot(state.metrics.epoch_values)
            checkpoint = utils.pack_checkpoint(
                model=state.model,
                criterion=state.criterion,
                optimizer=state.optimizer,
                scheduler=state.scheduler,
                epoch_metrics=epoch_metrics,
                valid_metrics=valid_metrics,
                stage=state.stage,
                epoch=state.epoch_log,
                checkpoint_data=state.checkpoint_data,
            )
            suffix = self.get_checkpoint_suffix(checkpoint)
            suffix = f"{suffix}.exception_{exception.__class__.__name__}"
            utils.save_checkpoint(
                logdir=f"{state.logdir}/checkpoints/",
                checkpoint=checkpoint,
                suffix=suffix,
                is_best=False,
                is_last=False,
            )
            metrics = self.metrics
            metrics[suffix] = valid_metrics
            self.save_metric(state.logdir, metrics)
        except Exception:
            # Deliberately best-effort: a failure here must not replace the
            # exception already being handled, but it should not vanish
            # silently either.
            logging.getLogger(__name__).warning(
                "Failed to save checkpoint after exception %s",
                exception.__class__.__name__,
                exc_info=True,
            )
Exemplo n.º 3
0
 def on_batch_end(self, state):
     """Save an iteration checkpoint every ``self.num_iters`` batches.

     ``self._iteration_counter`` is incremented on every call. When the new
     value is a multiple of ``self.num_iters``, a checkpoint is packed via
     ``utils.pack_checkpoint`` with no epoch/validation metrics. It is then
     handed to ``self.process_checkpoint`` together with the current batch
     metric values. Otherwise ``state`` is not touched.
     """
     self._iteration_counter += 1
     # Non-zero remainder: not a save step yet.
     if self._iteration_counter % self.num_iters:
         return

     packed = utils.pack_checkpoint(
         model=state.model,
         criterion=state.criterion,
         optimizer=state.optimizer,
         scheduler=state.scheduler,
         epoch_metrics=None,
         valid_metrics=None,
         stage=state.stage,
         epoch=state.epoch_log,
     )
     self.process_checkpoint(
         logdir=state.logdir,
         checkpoint=packed,
         batch_values=state.metrics.batch_values,
     )
Exemplo n.º 4
0
 def pack_checkpoint(self, **kwargs):
     """Build a checkpoint by delegating to ``utils.pack_checkpoint``.

     Every keyword argument is forwarded unchanged, and the helper's result
     is returned as-is.
     """
     packed = utils.pack_checkpoint(**kwargs)
     return packed