def on_epoch_end(self, state: RunnerState):
    """Pack a checkpoint from the runner state and pass it on for processing.

    Nothing happens when the stage name starts with ``"infer"``.

    Args:
        state: runner state; this method reads its stage, metrics, model,
            criterion, optimizer, scheduler, epoch_log, checkpoint_data,
            logdir, main_metric and minimize_metric attributes.
    """
    if state.stage.startswith("infer"):
        return

    # Copy the metric mappings so the checkpoint holds its own dicts
    # rather than references to the ones living on the runner state.
    valid_snapshot = dict(state.metrics.valid_values)
    epoch_snapshot = dict(state.metrics.epoch_values)

    pack_kwargs = {
        "model": state.model,
        "criterion": state.criterion,
        "optimizer": state.optimizer,
        "scheduler": state.scheduler,
        "epoch_metrics": epoch_snapshot,
        "valid_metrics": valid_snapshot,
        "stage": state.stage,
        "epoch": state.epoch_log,
        "checkpoint_data": state.checkpoint_data,
    }
    packed = utils.pack_checkpoint(**pack_kwargs)

    self.process_checkpoint(
        logdir=state.logdir,
        checkpoint=packed,
        is_best=state.metrics.is_best,
        main_metric=state.main_metric,
        minimize_metric=state.minimize_metric,
    )
def on_exception(self, state: RunnerState):
    """Try to save an emergency checkpoint when the run hits an exception.

    Returns immediately when ``is_exception(state.exception)`` is falsy.
    Otherwise a checkpoint is packed from the runner state and saved under
    ``{state.logdir}/checkpoints/`` with the suffix
    ``<checkpoint suffix>.exception_<ExceptionClassName>``, flagged as
    neither best nor last. The validation metrics are stored in
    ``self.metrics`` under that suffix and written via ``self.save_metric``.

    Saving is best-effort: a failure in any of these steps is logged as a
    warning (with traceback) and never raised from this method.

    Args:
        state: runner state carrying the exception, metrics, model,
            criterion, optimizer, scheduler and logging directory.
    """
    import logging

    exception = state.exception
    if not is_exception(exception):
        return

    try:
        valid_metrics = state.metrics.valid_values
        epoch_metrics = state.metrics.epoch_values
        checkpoint = utils.pack_checkpoint(
            model=state.model,
            criterion=state.criterion,
            optimizer=state.optimizer,
            scheduler=state.scheduler,
            epoch_metrics=epoch_metrics,
            valid_metrics=valid_metrics,
            stage=state.stage,
            epoch=state.epoch_log,
            checkpoint_data=state.checkpoint_data,
        )

        # Tag the file name with the exception class so it is easy to tell
        # apart from regular checkpoints.
        suffix = self.get_checkpoint_suffix(checkpoint)
        suffix = f"{suffix}.exception_{exception.__class__.__name__}"
        utils.save_checkpoint(
            logdir=f"{state.logdir}/checkpoints/",
            checkpoint=checkpoint,
            suffix=suffix,
            is_best=False,
            is_last=False,
        )

        metrics = self.metrics
        metrics[suffix] = valid_metrics
        self.save_metric(state.logdir, metrics)
    except Exception:
        # Best-effort by design: a failure here must not raise from this
        # handler, but it should not vanish silently either.
        logging.getLogger(__name__).warning(
            "Failed to save checkpoint on exception",
            exc_info=True,
        )
def on_batch_end(self, state):
    """Count the finished batch and checkpoint every ``self.num_iters`` batches.

    The iteration counter is incremented on every call. Only when it is a
    multiple of ``self.num_iters`` is a checkpoint packed from the runner
    state (with no epoch or validation metrics) and passed to
    ``self.process_checkpoint`` together with the current batch metrics.

    Args:
        state: runner state; read only when a checkpoint is due.
    """
    self._iteration_counter += 1

    # Guard clause: nothing to do until the counter hits the period.
    if self._iteration_counter % self.num_iters != 0:
        return

    packed = utils.pack_checkpoint(
        model=state.model,
        criterion=state.criterion,
        optimizer=state.optimizer,
        scheduler=state.scheduler,
        epoch_metrics=None,
        valid_metrics=None,
        stage=state.stage,
        epoch=state.epoch_log,
    )
    self.process_checkpoint(
        logdir=state.logdir,
        checkpoint=packed,
        batch_values=state.metrics.batch_values,
    )
def pack_checkpoint(self, **kwargs):
    """Forward all keyword arguments to ``utils.pack_checkpoint``.

    Args:
        **kwargs: passed through unchanged.

    Returns:
        Whatever ``utils.pack_checkpoint`` returns for these arguments.
    """
    packed = utils.pack_checkpoint(**kwargs)
    return packed