示例#1
0
    def on_epoch_end(self, state: _State):
        """Collect this epoch's validation metrics and track the best ones.

        Entries of ``state.epoch_metrics`` whose key starts with
        ``"<state.valid_loader>_"`` are copied into ``state.valid_metrics`` with
        that prefix removed. If the value of ``state.main_metric`` improved
        (strictly lower when ``state.minimize_metric`` is set, strictly higher
        otherwise) compared with ``state.best_valid_metrics``,
        ``state.is_best_valid`` is set to ``True`` and a copy of the current
        validation metrics becomes the new ``state.best_valid_metrics``.
        Stages whose name starts with ``"infer"`` are skipped.

        Raises:
            AssertionError: if ``state.main_metric`` is not among the
                collected validation metrics.
        """
        if state.stage_name.startswith("infer"):
            return

        # Match on the full "<loader>_" prefix and strip it once, from the
        # start only. A bare ``startswith(valid_loader)`` would also pick up
        # metrics of loaders whose name merely begins with the validation
        # loader's name (e.g. "valid2_loss" for loader "valid"), and
        # ``str.replace`` would remove the prefix anywhere inside the key.
        prefix = f"{state.valid_loader}_"
        state.valid_metrics = {
            key[len(prefix):]: value
            for key, value in state.epoch_metrics.items()
            if key.startswith(prefix)
        }
        assert state.main_metric in state.valid_metrics, \
            f"{state.main_metric} value is not available by the epoch end"

        current_valid_metric = state.valid_metrics[state.main_metric]
        if state.minimize_metric:
            # No previous best yet -> any finite value counts as an improvement.
            best_valid_metric = \
                state.best_valid_metrics.get(state.main_metric, float("+inf"))
            is_best = current_valid_metric < best_valid_metric
        else:
            best_valid_metric = \
                state.best_valid_metrics.get(state.main_metric, float("-inf"))
            is_best = current_valid_metric > best_valid_metric

        if is_best:
            state.is_best_valid = True
            # Copy so later mutation of valid_metrics doesn't alter the best.
            state.best_valid_metrics = state.valid_metrics.copy()
示例#2
0
def _load_checkpoint(*, filename, state: _State):
    """Load a checkpoint file into the components held by ``state``.

    The checkpoint is read with ``utils.load_checkpoint`` and unpacked into
    ``state.model``, ``state.criterion``, ``state.optimizer`` and
    ``state.scheduler`` via ``utils.unpack_checkpoint``. Unless the current
    stage name starts with ``"infer"``, ``stage_name``, ``epoch`` and
    ``global_epoch`` are restored on ``state`` as well.

    Args:
        filename: path to the checkpoint file.
        state: run state to restore into.

    Raises:
        FileNotFoundError: if ``filename`` is not an existing regular file.
            This subclasses ``Exception``, so callers that caught the
            previously raised bare ``Exception`` keep working.
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"No checkpoint found at {filename}")

    print(f"=> loading checkpoint {filename}")
    checkpoint = utils.load_checkpoint(filename)

    # Inference stages keep their own stage name / epoch counters.
    if not state.stage_name.startswith("infer"):
        state.stage_name = checkpoint["stage_name"]
        state.epoch = checkpoint["epoch"]
        state.global_epoch = checkpoint["global_epoch"]
        # TODO: should we also load
        # checkpoint_data, main_metric, minimize_metric, valid_loader,
        # epoch_metrics, valid_metrics?

    utils.unpack_checkpoint(
        checkpoint,
        model=state.model,
        criterion=state.criterion,
        optimizer=state.optimizer,
        scheduler=state.scheduler
    )

    print(
        f"loaded checkpoint {filename} "
        f"(global epoch {checkpoint['global_epoch']}, "
        f"epoch {checkpoint['epoch']}, "
        f"stage {checkpoint['stage_name']})"
    )
示例#3
0
    def step_batch(self, state: _State):
        """Run one scheduler step and record the resulting lr and momentum.

        ``self._scheduler_step`` is called on ``self._scheduler``. The two
        values it returns go into ``state.batch_metrics`` under ``"lr"`` and
        ``"momentum"``. When ``self.scheduler_key`` is set, each key gets a
        ``_<scheduler_key>`` suffix.
        """
        lr, momentum = self._scheduler_step(scheduler=self._scheduler)

        # Pick the metric names first, then write both values in one place.
        if self.scheduler_key is not None:
            lr_name = f"lr_{self.scheduler_key}"
            momentum_name = f"momentum_{self.scheduler_key}"
        else:
            lr_name, momentum_name = "lr", "momentum"

        state.batch_metrics[lr_name] = lr
        state.batch_metrics[momentum_name] = momentum
示例#4
0
    def update_optimizer(self, state: _State):
        """Update the stored optimizer and log its learning rate and momentum.

        The ``(lr, momentum)`` pair returned by ``self._update_optimizer`` is
        stored in ``state.batch_metrics``. The keys are plain ``"lr"`` /
        ``"momentum"`` when ``self.optimizer_key`` is ``None``; otherwise they
        are suffixed with ``_<optimizer_key>``.
        """
        new_lr, new_momentum = self._update_optimizer(optimizer=self._optimizer)

        metrics = state.batch_metrics
        if self.optimizer_key is None:
            metrics["lr"] = new_lr
            metrics["momentum"] = new_momentum
            return

        metrics[f"lr_{self.optimizer_key}"] = new_lr
        metrics[f"momentum_{self.optimizer_key}"] = new_momentum
示例#5
0
    def step_epoch(self, state: _State):
        """Step the scheduler once per epoch and record lr and momentum.

        The value of ``self.reduced_metric`` from ``state.valid_metrics`` is
        passed to ``self._scheduler_step`` as ``reduced_metric``. The returned
        lr/momentum are written into ``state.epoch_metrics``, with keys
        suffixed by ``_<scheduler_key>`` when a scheduler key is set.

        Raises:
            KeyError: if ``self.reduced_metric`` is missing from
                ``state.valid_metrics``.
        """
        monitored_value = state.valid_metrics[self.reduced_metric]
        lr, momentum = self._scheduler_step(
            scheduler=self._scheduler,
            reduced_metric=monitored_value,
        )

        epoch_metrics = state.epoch_metrics
        if self.scheduler_key is None:
            epoch_metrics["lr"] = lr
            epoch_metrics["momentum"] = momentum
        else:
            epoch_metrics[f"lr_{self.scheduler_key}"] = lr
            epoch_metrics[f"momentum_{self.scheduler_key}"] = momentum
示例#6
0
 def on_stage_start(self, state: _State):
     """Fetch this stage's criterion and keep a reference to it.

     The criterion is looked up with ``state.get_attr`` using
     ``self.criterion_key``. A ``None`` result fails the assertion, and in
     that case ``self._criterion`` is left untouched.
     """
     stage_criterion = state.get_attr(
         key="criterion",
         inner_key=self.criterion_key,
     )
     assert stage_criterion is not None
     self._criterion = stage_criterion
示例#7
0
    def on_batch_end(self, state: _State):
        """Compute this batch's metric values and store them, scaled.

        ``self._compute_metric(state)`` yields one value per entry of
        ``self.list_args``; pairs are matched with ``zip``, so the shorter of
        the two sequences decides how many metrics are written. Integer args
        produce keys ``"<prefix><arg:02>"`` (no separator). Any other arg
        produces ``"<prefix>_<arg>"``. Each stored value is multiplied by
        ``self.multiplier``.
        """
        values = self._compute_metric(state)

        for arg, value in zip(self.list_args, values):
            key = (
                f"{self.prefix}{arg:02}"
                if isinstance(arg, int)
                else f"{self.prefix}_{arg}"
            )
            state.batch_metrics[key] = value * self.multiplier
示例#8
0
    def on_exception(self, state: _State):
        """Called if an Exception was raised.

        When ``state.exception`` passes ``utils.is_exception`` and is a
        ``KeyboardInterrupt``, "Early exiting" is written through
        ``self.tqdm``. ``state.need_exception_reraise`` is then cleared.
        Any other exception is left alone.
        """
        raised = state.exception
        if utils.is_exception(raised) and isinstance(raised, KeyboardInterrupt):
            self.tqdm.write("Early exiting")
            state.need_exception_reraise = False
示例#9
0
 def on_stage_start(self, state: _State):
     """Fetch this stage's optimizer and keep a reference to it.

     The optimizer is resolved with ``state.get_attr`` via
     ``self.optimizer_key``. A ``None`` result fails the assertion before
     ``self._optimizer`` is assigned.
     """
     stage_optimizer = state.get_attr(
         key="optimizer",
         inner_key=self.optimizer_key,
     )
     assert stage_optimizer is not None
     self._optimizer = stage_optimizer
示例#10
0
    def on_stage_start(self, state: _State):
        """Resolve this stage's scheduler and settle the stepping mode.

        The scheduler comes from ``state.get_attr`` using
        ``self.scheduler_key`` and must not be ``None``. If ``self.mode`` has
        not been set, it becomes ``"batch"`` for a ``BatchScheduler`` and
        ``"epoch"`` for anything else. A ``OneCycleLRWithWarmup`` scheduler
        in batch mode is reset.
        """
        stage_scheduler = state.get_attr(key="scheduler",
                                         inner_key=self.scheduler_key)
        assert stage_scheduler is not None
        self._scheduler = stage_scheduler

        if self.mode is None:
            self.mode = (
                "batch" if isinstance(stage_scheduler, BatchScheduler)
                else "epoch"
            )

        needs_reset = (
            isinstance(stage_scheduler, OneCycleLRWithWarmup)
            and self.mode == "batch"
        )
        if needs_reset:
            stage_scheduler.reset()
示例#11
0
    def on_epoch_end(self, state: _State) -> None:
        """Update the early-stopping counter from the monitored metric.

        Reads ``state.valid_metrics[self.metric]``. The first observed score
        becomes the baseline. After that, a score that
        ``self.is_better(score, self.best_score)`` accepts resets
        ``self.num_bad_epochs`` and becomes the new best; any other score
        increments the counter. Once the counter reaches ``self.patience``,
        ``state.need_early_stop`` is set. Stages whose name starts with
        ``"infer"`` are skipped.
        """
        if state.stage_name.startswith("infer"):
            return

        score = state.valid_metrics[self.metric]
        # The first score only establishes the baseline. It used to be
        # compared against itself, which a strict ``is_better`` rejects, so
        # the very first epoch was counted as a bad epoch (with patience=1
        # training stopped after epoch one). For a non-strict ``is_better``
        # the outcome is unchanged.
        if self.best_score is None or self.is_better(score, self.best_score):
            self.num_bad_epochs = 0
            self.best_score = score
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            print(f"Early stop at {state.epoch} epoch")
            state.need_early_stop = True
示例#12
0
 def on_loader_end(self, state: _State):
     """Publish per-loader averages and mirror them into the epoch metrics.

     Each meter's ``mean`` is stored in ``state.loader_metrics`` under the
     meter's key. Every entry of ``state.loader_metrics``, including any set
     elsewhere, is then copied to ``state.epoch_metrics`` as
     ``"<loader_name>_<key>"``.
     """
     loader_metrics = state.loader_metrics
     for key, meter in self.meters.items():
         loader_metrics[key] = meter.mean

     for key, value in loader_metrics.items():
         state.epoch_metrics[f"{state.loader_name}_{key}"] = value
示例#13
0
 def on_loader_start(self, state: _State):
     """Reset per-loader metric storage before a loader runs.

     ``state.loader_metrics`` becomes an empty dict. ``self.meters`` becomes a
     fresh ``defaultdict`` that creates a ``meters.AverageValueMeter`` the
     first time each key is accessed.
     """
     # ``defaultdict(None)`` has no default factory and behaves exactly like a
     # plain dict (missing keys raise KeyError); ``{}`` says that directly.
     state.loader_metrics = {}
     self.meters = defaultdict(meters.AverageValueMeter)
示例#14
0
 def on_epoch_start(self, state: _State):
     """Reset ``state.epoch_metrics`` to an empty dict at epoch start."""
     # ``defaultdict(None)`` has no default factory and behaves exactly like a
     # plain dict (missing keys raise KeyError); ``{}`` says that directly.
     state.epoch_metrics = {}
示例#15
0
 def on_batch_start(self, state: _State):
     """Store the phase name reported by ``self.phase_manager`` on the state."""
     phase_name = self.phase_manager.get_phase_name(state)
     state.phase = phase_name
示例#16
0
 def on_epoch_start(self, state: _State):
     """Clear validation results carried over from the previous epoch.

     ``state.valid_metrics`` becomes an empty dict and ``state.is_best_valid``
     is reset to ``False``.
     """
     # ``defaultdict(None)`` has no default factory and behaves exactly like a
     # plain dict (missing keys raise KeyError); ``{}`` says that directly.
     state.valid_metrics = {}
     state.is_best_valid = False
示例#17
0
 def on_batch_end(self, state: _State):
     """Request an early stop once ``state.loader_step`` reaches the batch limit.

     The flag is only ever set here, never cleared.
     """
     limit_reached = state.loader_step >= self.num_batch_steps
     if limit_reached:
         state.need_early_stop = True
示例#18
0
 def on_batch_start(self, state: _State):
     """Reset ``state.batch_metrics`` to an empty dict before each batch."""
     # ``defaultdict(None)`` has no default factory and behaves exactly like a
     # plain dict (missing keys raise KeyError); ``{}`` says that directly.
     state.batch_metrics = {}
示例#19
0
 def on_stage_start(self, state: _State):
     """Fetch this stage's optimizer and remember its initial learning rate.

     The optimizer is resolved with ``state.get_attr`` via
     ``self.optimizer_key`` and must not be ``None``. ``self.init_lr`` is
     taken from the optimizer's ``defaults["lr"]``.
     """
     stage_optimizer = state.get_attr(
         key="optimizer",
         inner_key=self.optimizer_key,
     )
     assert stage_optimizer is not None
     self._optimizer = stage_optimizer
     self.init_lr = stage_optimizer.defaults["lr"]
示例#20
0
 def on_batch_end(self, state: _State):
     """Post-process the batch metrics and feed each value to its meter.

     ``state.batch_metrics`` is replaced by the result of
     ``self._process_metrics``. Each resulting value is then added to
     ``self.meters`` under the same key.
     """
     state.batch_metrics = self._process_metrics(state.batch_metrics)
     meter_by_name = self.meters
     for name, metric_value in state.batch_metrics.items():
         meter_by_name[name].add(metric_value)
示例#21
0
 def on_epoch_end(self, state: _State):
     """Request an early stop once ``state.epoch`` reaches the epoch limit.

     The flag is only ever set here, never cleared.
     """
     if not state.epoch >= self.num_epoch_steps:
         return
     state.need_early_stop = True