Beispiel #1
0
    def on_epoch_end(self, state: State) -> None:
        """Epoch end hook: log the optimizer's lr and momentum.

        When ``self.decouple_weight_decay`` is set, the values held in
        ``self._optimizer_wd`` are first written back into the optimizer's
        param groups.

        Args:
            state (State): current state
        """
        if self.decouple_weight_decay:
            for group_idx, weight_decay in enumerate(self._optimizer_wd):
                group = self._optimizer.param_groups[group_idx]
                group["weight_decay"] = weight_decay

        # Only the first param group's learning rate is reported.
        first_group_lr = self._optimizer.param_groups[0]["lr"]

        # Metric names get a "/<optimizer_key>" suffix only when a key is set.
        optimizer_key = self.optimizer_key
        suffix = "" if optimizer_key is None else f"/{optimizer_key}"
        state.epoch_metrics[f"lr{suffix}"] = first_group_lr

        momentum = utils.get_optimizer_momentum(self._optimizer)
        if momentum is None:
            return
        state.epoch_metrics[f"momentum{suffix}"] = momentum
Beispiel #2
0
def _load_checkpoint(*, filename, state: State):
    """Load a checkpoint file and restore ``state`` from it.

    Outside of stages whose name starts with ``"infer"``, the stage name,
    epoch and global epoch stored in the checkpoint overwrite the ones on
    ``state``. The model, criterion, optimizer and scheduler held by
    ``state`` are then handed to ``utils.unpack_checkpoint``.

    Args:
        filename: path to the checkpoint file
        state (State): current state

    Raises:
        FileNotFoundError: if ``filename`` is not an existing file
    """
    if not os.path.isfile(filename):
        # FileNotFoundError is a subclass of Exception, so callers that
        # catch the broad type keep working.
        raise FileNotFoundError(f"No checkpoint found at {filename}")

    print(f"=> loading checkpoint {filename}")
    checkpoint = utils.load_checkpoint(filename)

    if not state.stage_name.startswith("infer"):
        state.stage_name = checkpoint["stage_name"]
        state.epoch = checkpoint["epoch"]
        state.global_epoch = checkpoint["global_epoch"]
        # @TODO: should we also load,
        # checkpoint_data, main_metric, minimize_metric, valid_loader ?
        # epoch_metrics, valid_metrics ?

    utils.unpack_checkpoint(
        checkpoint,
        model=state.model,
        criterion=state.criterion,
        optimizer=state.optimizer,
        scheduler=state.scheduler,
    )

    print(f"loaded checkpoint {filename} "
          f"(global epoch {checkpoint['global_epoch']}, "
          f"epoch {checkpoint['epoch']}, "
          f"stage {checkpoint['stage_name']})")
Beispiel #3
0
    def on_epoch_end(self, state: State) -> None:
        """Epoch end hook.

        Collects the validation loader's entries from ``state.epoch_metrics``
        into ``state.valid_metrics`` (with the ``"<valid_loader>_"`` prefix
        removed) and flags the epoch as best when the main metric improved.
        Does nothing for stages whose name starts with ``"infer"``.

        Args:
            state (State): current state
        """
        if state.stage_name.startswith("infer"):
            return

        # Match on the full "<loader>_" prefix and strip only that leading
        # prefix: a bare startswith(loader) would also pick up other loaders
        # sharing the same name prefix (e.g. "valid2_loss"), and
        # str.replace would remove every occurrence inside the metric name.
        prefix = f"{state.valid_loader}_"
        state.valid_metrics = {
            name[len(prefix):]: value
            for name, value in state.epoch_metrics.items()
            if name.startswith(prefix)
        }
        assert (
            state.main_metric in state.valid_metrics
        ), f"{state.main_metric} value is not available by the epoch end"

        current_valid_metric = state.valid_metrics[state.main_metric]
        if state.minimize_metric:
            # No previous best: any finite value counts as an improvement.
            best_valid_metric = state.best_valid_metrics.get(
                state.main_metric, float("+inf"))
            is_best = current_valid_metric < best_valid_metric
        else:
            best_valid_metric = state.best_valid_metrics.get(
                state.main_metric, float("-inf"))
            is_best = current_valid_metric > best_valid_metric

        if is_best:
            state.is_best_valid = True
            state.best_valid_metrics = state.valid_metrics.copy()
Beispiel #4
0
    def on_epoch_start(self, state: State) -> None:
        """Epoch start hook: reset the per-epoch validation bookkeeping.

        Args:
            state (State): current state
        """
        # ``defaultdict(None)`` has no default factory, so a missing key
        # still raises ``KeyError`` just like a plain dict.
        fresh_valid_metrics = defaultdict(None)
        state.valid_metrics = fresh_valid_metrics
        state.is_best_valid = False
Beispiel #5
0
    def update_optimizer(self, state: State):
        """Update the optimizer and record lr/momentum as batch metrics.

        The values come from ``self._update_optimizer``; metric names get a
        ``_<optimizer_key>`` suffix when ``self.optimizer_key`` is set.

        Args:
            state (State): current state
        """
        lr, momentum = self._update_optimizer(optimizer=self._optimizer)

        key = self.optimizer_key
        if key is None:
            lr_name, momentum_name = "lr", "momentum"
        else:
            lr_name, momentum_name = f"lr_{key}", f"momentum_{key}"

        state.batch_metrics[lr_name] = lr
        state.batch_metrics[momentum_name] = momentum
Beispiel #6
0
def _load_checkpoint(*,
                     filename,
                     state: State,
                     load_full: bool = True) -> None:
    """
    Load a checkpoint from a file into ``state``.

    Arguments:
        filename (str): path to the checkpoint file
        state (State): training state
        load_full (bool): if true (default), the model, criterion,
            optimizer and scheduler of ``state`` are all passed to
            ``utils.unpack_checkpoint`` and, outside of ``infer*`` stages,
            ``stage_name``, ``epoch`` and ``global_epoch`` are restored
            from the checkpoint. If false, only the model is unpacked.

    Raises:
        FileNotFoundError: when ``filename`` does not point to
            an existing file.
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"No checkpoint found at {filename}!")

    print(f"=> Loading checkpoint {filename}")
    checkpoint = utils.load_checkpoint(filename)

    # The stage name is inspected for both full and model-only loads.
    is_infer_stage = state.stage_name.startswith("infer")

    if not load_full:
        utils.unpack_checkpoint(
            checkpoint,
            model=state.model,
        )
        print(f"loaded model checkpoint {filename}")
        return

    if not is_infer_stage:
        state.stage_name = checkpoint["stage_name"]
        state.epoch = checkpoint["epoch"]
        state.global_epoch = checkpoint["global_epoch"]
        # @TODO: should we also load,
        # checkpoint_data, main_metric, minimize_metric, valid_loader ?
        # epoch_metrics, valid_metrics ?

    utils.unpack_checkpoint(
        checkpoint,
        model=state.model,
        criterion=state.criterion,
        optimizer=state.optimizer,
        scheduler=state.scheduler,
    )

    print(f"loaded state checkpoint {filename} "
          f"(global epoch {checkpoint['global_epoch']}, "
          f"epoch {checkpoint['epoch']}, "
          f"stage {checkpoint['stage_name']})")
Beispiel #7
0
    def step_batch(self, state: State):
        """Run a scheduler step and record the result as batch metrics.

        ``lr`` is always recorded; ``momentum`` only when the step returned
        one. Names get a ``/<scheduler_key>`` suffix when a key is set.

        Args:
            state (State): current state
        """
        lr, momentum = self._scheduler_step(scheduler=self._scheduler)

        key = self.scheduler_key
        if key is None:
            lr_name, momentum_name = "lr", "momentum"
        else:
            lr_name, momentum_name = f"lr/{key}", f"momentum/{key}"

        state.batch_metrics[lr_name] = lr
        if momentum is not None:
            state.batch_metrics[momentum_name] = momentum
Beispiel #8
0
    def step_epoch(self, state: State):
        """Run a scheduler step for the epoch and record lr/momentum.

        The validation metric named by ``self.reduced_metric`` is passed to
        ``self._scheduler_step``. ``lr`` is always stored in
        ``state.epoch_metrics``; ``momentum`` only when it is not ``None``.

        Args:
            state (State): current state
        """
        metric_value = state.valid_metrics[self.reduced_metric]
        lr, momentum = self._scheduler_step(
            scheduler=self._scheduler,
            reduced_metric=metric_value,
        )

        key = self.scheduler_key
        if key is None:
            lr_name, momentum_name = "lr", "momentum"
        else:
            lr_name, momentum_name = f"lr/{key}", f"momentum/{key}"

        state.epoch_metrics[lr_name] = lr
        if momentum is not None:
            state.epoch_metrics[momentum_name] = momentum
Beispiel #9
0
    def update_optimizer(self, state: State) -> None:
        """Update the optimizer and log the resulting lr and momentum.

        Both values returned by ``self._update_optimizer`` are written to
        ``state.batch_metrics``, suffixed with ``_<optimizer_key>`` when
        ``self.optimizer_key`` is set.

        Args:
            state (State): current state
        """
        lr, momentum = self._update_optimizer(optimizer=self._optimizer)

        suffix = (
            "" if self.optimizer_key is None else f"_{self.optimizer_key}"
        )
        state.batch_metrics[f"lr{suffix}"] = lr
        state.batch_metrics[f"momentum{suffix}"] = momentum
Beispiel #10
0
 def on_stage_start(self, state: State):
     """
     Fetch the optimizer for this stage and check that it exists.
     """
     stage_optimizer = state.get_attr(
         key="optimizer", inner_key=self.optimizer_key
     )
     # Stored before the check, so the attribute is set even if it fails.
     self._optimizer = stage_optimizer
     assert stage_optimizer is not None
Beispiel #11
0
    def on_batch_start(self, state: State) -> None:
        """Batch start hook: start the batch with an empty metrics mapping.

        Args:
            state (State): current state
        """
        # No default factory is given, so missing keys still raise KeyError.
        empty_metrics = defaultdict(None)
        state.batch_metrics = empty_metrics
Beispiel #12
0
    def on_batch_start(self, state: State):
        """Batch start hook: set ``state.phase`` from the phase manager.

        Args:
            state (State): current state
        """
        phase_name = self.phase_manager.get_phase_name(state)
        state.phase = phase_name
Beispiel #13
0
    def on_stage_start(self, state: State) -> None:
        """Stage start hook: bind the scheduler and settle the stepping mode.

        Falls back to ``state.main_metric`` when no reduced metric is
        configured. If ``self.mode`` is unset, it becomes ``"batch"`` for a
        ``BatchScheduler`` and ``"epoch"`` otherwise. A
        ``OneCycleLRWithWarmup`` scheduler running in batch mode is reset.

        Args:
            state (State): current state
        """
        if not self.reduced_metric:
            self.reduced_metric = state.main_metric

        stage_scheduler = state.get_attr(
            key="scheduler", inner_key=self.scheduler_key
        )
        assert stage_scheduler is not None
        self._scheduler = stage_scheduler

        if self.mode is None:
            self.mode = (
                "batch"
                if isinstance(stage_scheduler, BatchScheduler)
                else "epoch"
            )

        if isinstance(stage_scheduler, OneCycleLRWithWarmup):
            if self.mode == "batch":
                stage_scheduler.reset()
        assert self.mode is not None
Beispiel #14
0
 def on_stage_start(self, state: State):
     """Bind the stage optimizer and remember its default learning rate."""
     stage_optimizer = state.get_attr(
         key="optimizer",
         inner_key=self.optimizer_key,
     )
     assert stage_optimizer is not None
     self._optimizer = stage_optimizer
     # The initial lr is taken from the optimizer's ``defaults`` mapping.
     self.init_lr = stage_optimizer.defaults["lr"]
Beispiel #15
0
    def step_batch(self, state: State) -> None:
        """Run a per-batch scheduler step and log its lr and momentum.

        The learning rate returned by ``self._scheduler_step`` is always
        written to ``state.batch_metrics``; momentum is written only when it
        is not ``None``.

        Args:
            state (State): current state
        """
        lr, momentum = self._scheduler_step(scheduler=self._scheduler)

        # Names carry a "/<scheduler_key>" suffix only when a key is set.
        suffix = (
            "" if self.scheduler_key is None else f"/{self.scheduler_key}"
        )
        state.batch_metrics[f"lr{suffix}"] = lr
        if momentum is None:
            return
        state.batch_metrics[f"momentum{suffix}"] = momentum
Beispiel #16
0
 def on_stage_start(self, state: State):
     """
     Fetch the criterion for this stage and check that it exists.
     """
     stage_criterion = state.get_attr(
         key="criterion",
         inner_key=self.criterion_key,
     )
     # Only a criterion that passed the check is stored on the callback.
     assert stage_criterion is not None
     self._criterion = stage_criterion
Beispiel #17
0
    def on_loader_start(self, state: State) -> None:
        """Loader start hook: reset loader metrics and the value meters.

        Args:
            state (State): current state
        """
        # No default factory: missing loader metrics still raise KeyError.
        empty_loader_metrics = defaultdict(None)
        state.loader_metrics = empty_loader_metrics

        # A new AverageValueMeter is created on first access of each key.
        fresh_meters = defaultdict(meters.AverageValueMeter)
        self.meters = fresh_meters
Beispiel #18
0
    def on_batch_end(self, state: State):
        """Compute the batch metrics and store one value per argument.

        Integer arguments yield keys like ``<prefix>05``; any other argument
        yields ``<prefix>_<arg>``. Each value is scaled by
        ``self.multiplier``.

        Args:
            state (State): current state
        """
        computed = self._compute_metric(state)

        for arg, value in zip(self.list_args, computed):
            name_tail = f"{arg:02}" if isinstance(arg, int) else f"_{arg}"
            state.batch_metrics[f"{self.prefix}{name_tail}"] = (
                value * self.multiplier
            )
Beispiel #19
0
    def on_batch_end(self, state: State) -> None:
        """Batch end hook: process batch metrics and feed them to the meters.

        Args:
            state (State): current state
        """
        processed = self._process_metrics(state.batch_metrics)
        state.batch_metrics = processed

        # Every processed value is added to the meter stored under its name.
        for name, value in state.batch_metrics.items():
            meter = self.meters[name]
            meter.add(value)
Beispiel #20
0
    def on_exception(self, state: State):
        """Exception hook: swallow a ``KeyboardInterrupt`` quietly.

        For a ``KeyboardInterrupt`` that passes ``utils.is_exception``, a
        short message is written through ``self.tqdm`` and the re-raise flag
        on ``state`` is cleared. Anything else is left untouched.

        Args:
            state (State): current state
        """
        exception = state.exception
        is_interrupt = (
            utils.is_exception(exception)
            and isinstance(exception, KeyboardInterrupt)
        )
        if is_interrupt:
            self.tqdm.write("Early exiting")
            state.need_exception_reraise = False
Beispiel #21
0
    def step_epoch(self, state: State) -> None:
        """Run a per-epoch scheduler step and log its lr and momentum.

        The value of ``state.valid_metrics[self.reduced_metric]`` is passed
        to ``self._scheduler_step`` as ``reduced_metric``. The returned
        learning rate is always stored in ``state.epoch_metrics``; momentum
        only when it is not ``None``.

        Args:
            state (State): current state
        """
        lr, momentum = self._scheduler_step(
            scheduler=self._scheduler,
            reduced_metric=state.valid_metrics[self.reduced_metric],
        )

        # Names carry a "/<scheduler_key>" suffix only when a key is set.
        suffix = (
            "" if self.scheduler_key is None else f"/{self.scheduler_key}"
        )
        state.epoch_metrics[f"lr{suffix}"] = lr
        if momentum is None:
            return
        state.epoch_metrics[f"momentum{suffix}"] = momentum
Beispiel #22
0
    def on_stage_start(self, state: State) -> None:
        """Stage start hook: bind the optimizer and remember its default lr.

        Args:
            state (State): current state
        """
        stage_optimizer = state.get_attr(
            key="optimizer",
            inner_key=self.optimizer_key,
        )
        assert stage_optimizer is not None
        self._optimizer = stage_optimizer
        # The initial lr is taken from the optimizer's ``defaults`` mapping.
        self.init_lr = stage_optimizer.defaults["lr"]
Beispiel #23
0
    def on_loader_end(self, state: State) -> None:
        """Loader end hook: publish meter means as loader and epoch metrics.

        Args:
            state (State): current state
        """
        # Each meter contributes its mean under the meter's own name.
        for name, meter in self.meters.items():
            state.loader_metrics[name] = meter.mean

        # Every loader metric is copied to the epoch metrics with a
        # "<loader_name>_" prefix.
        for name, metric_value in state.loader_metrics.items():
            epoch_key = f"{state.loader_name}_{name}"
            state.epoch_metrics[epoch_key] = metric_value
Beispiel #24
0
    def on_loader_end(self, state: State):
        """Compute loader-level metrics and store one value per suffix.

        The metric is computed under ``torch.no_grad()``; a tensor result is
        detached, moved to CPU and converted to a NumPy array. Integer
        suffixes yield keys like ``<prefix>05``; any other suffix yields
        ``<prefix>_<suffix>``. Each value is scaled by ``self.multiplier``.

        Args:
            state (State): current state
        """
        with torch.no_grad():
            computed = self._compute_metric(state)
            if isinstance(computed, torch.Tensor):
                computed = computed.detach().cpu().numpy()

        for suffix, value in zip(self.suffixes, computed):
            name_tail = (
                f"{suffix:02}" if isinstance(suffix, int) else f"_{suffix}"
            )
            state.loader_metrics[f"{self.prefix}{name_tail}"] = (
                value * self.multiplier
            )
Beispiel #25
0
    def on_batch_end(self, state: State):
        """Batch end hook: publish timings and restart the batch timers.

        Stops the model and batch timers, derives ``_timer/_fps`` as
        ``state.batch_size`` divided by the elapsed batch time, copies all
        elapsed values into ``state.batch_metrics``, then resets the timer
        and starts the batch and data timers again.

        Args:
            state (State): current state
        """
        timer = self.timer
        for finished in ("_timer/model_time", "_timer/batch_time"):
            timer.stop(finished)

        # @TODO: just a trick
        fps = state.batch_size / timer.elapsed["_timer/batch_time"]
        timer.elapsed["_timer/_fps"] = fps
        for name, seconds in timer.elapsed.items():
            state.batch_metrics[name] = seconds

        timer.reset()
        for restarted in ("_timer/batch_time", "_timer/data_time"):
            timer.start(restarted)
Beispiel #26
0
    def on_loader_end(self, state: State):
        """Apply ``self.transform_fn`` to stored inputs and save the result.

        A single tensor result is stored in ``state.memory`` under
        ``self.transform_out_key``. A list/tuple of tensors is stored element
        by element, keyed by ``self.suffixes`` (prefixed with
        ``<transform_out_key>_`` when an output key is set).

        Args:
            state (State): current state

        Raises:
            NotImplementedError: if the transform returns anything other
                than a tensor, list or tuple
        """
        fn_kwargs = self._get_input_key(state.memory, self.transform_in_key)
        result = self.transform_fn(**fn_kwargs)

        if isinstance(result, torch.Tensor):
            state.memory[self.transform_out_key] = result
            return

        if not isinstance(result, (list, tuple)):
            raise NotImplementedError()

        for suffix, tensor in zip(self.suffixes, result):
            assert isinstance(tensor, torch.Tensor)

            if self.transform_out_key is None:
                memory_key = suffix
            else:
                memory_key = f"{self.transform_out_key}_{suffix}"
            state.memory[memory_key] = tensor
Beispiel #27
0
    def on_batch_end(self, state: State) -> None:
        """Batch end hook: compute the metric and log one value per argument.

        Values from ``self._compute_metric`` are paired with
        ``self.list_args``. An integer argument produces the key
        ``<prefix>NN`` (zero-padded to two digits); any other argument
        produces ``<prefix>_<arg>``. Each value is multiplied by
        ``self.multiplier`` before being stored.

        Args:
            state (State): current state
        """
        computed = self._compute_metric(state)

        for arg, value in zip(self.list_args, computed):
            if not isinstance(arg, int):
                metric_name = f"{self.prefix}_{arg}"
            else:
                metric_name = f"{self.prefix}{arg:02}"
            scaled = value * self.multiplier
            state.batch_metrics[metric_name] = scaled
Beispiel #28
0
    def on_epoch_end(self, state: State) -> None:
        """Epoch end hook: track the monitored metric and request early stop.

        Reads ``state.valid_metrics[self.metric]``. An improved score (per
        ``self.is_better``) resets the bad-epoch counter; otherwise the
        counter grows. Once it reaches ``self.patience``, a message is
        printed and ``state.need_early_stop`` is set. Does nothing for
        stages whose name starts with ``"infer"``.

        Args:
            state (State): current state
        """
        if state.stage_name.startswith("infer"):
            return

        score = state.valid_metrics[self.metric]
        # The first observed score is the baseline. It must not be compared
        # against itself, which would count the very first epoch as a bad
        # one and stop after a single epoch when patience is 1.
        if self.best_score is None or self.is_better(score, self.best_score):
            self.num_bad_epochs = 0
            self.best_score = score
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            print(f"Early stop at {state.epoch} epoch")
            state.need_early_stop = True
Beispiel #29
0
    def on_batch_end(self, state: State) -> None:
        """Batch end hook: record timer readings and restart the timers.

        After stopping the model and batch timers, ``_timer/_fps`` is set to
        ``state.batch_size`` divided by the elapsed batch time. All elapsed
        values are copied into ``state.batch_metrics``; the timer is then
        reset and the batch and data timers are started for the next batch.

        Args:
            state (State): current state
        """
        stop_timer = self.timer.stop
        stop_timer("_timer/model_time")
        stop_timer("_timer/batch_time")

        # @TODO: just a trick
        elapsed = self.timer.elapsed
        elapsed["_timer/_fps"] = (
            state.batch_size / elapsed["_timer/batch_time"]
        )
        for timer_name, reading in self.timer.elapsed.items():
            state.batch_metrics[timer_name] = reading

        self.timer.reset()
        start_timer = self.timer.start
        start_timer("_timer/batch_time")
        start_timer("_timer/data_time")
Beispiel #30
0
 def on_batch_end(self, state: State):
     """Request an early stop once the loader step reaches the limit."""
     limit_reached = state.loader_step >= self.num_batch_steps
     if limit_reached:
         state.need_early_stop = True