Example #1
    def on_epoch_end(self, state: _State):
        if state.stage_name.startswith("infer"):
            return

        state.valid_metrics = {
            k.replace(f"{state.valid_loader}_", ""): v
            for k, v in state.epoch_metrics.items()
            if k.startswith(state.valid_loader)
        }
        assert state.main_metric in state.valid_metrics, \
            f"{state.main_metric} value is not available by the epoch end"

        current_valid_metric = state.valid_metrics[state.main_metric]
        if state.minimize_metric:
            best_valid_metric = \
                state.best_valid_metrics.get(state.main_metric, float("+inf"))
            is_best = current_valid_metric < best_valid_metric
        else:
            best_valid_metric = \
                state.best_valid_metrics.get(state.main_metric, float("-inf"))
            is_best = current_valid_metric > best_valid_metric

        if is_best:
            state.is_best_valid = True
            state.best_valid_metrics = state.valid_metrics.copy()
Example #2
def _load_checkpoint(*, filename, state: _State):
    if os.path.isfile(filename):
        print(f"=> loading checkpoint {filename}")
        checkpoint = utils.load_checkpoint(filename)

        if not state.stage_name.startswith("infer"):
            state.stage_name = checkpoint["stage_name"]
            state.epoch = checkpoint["epoch"]
            state.global_epoch = checkpoint["global_epoch"]
            # @TODO: should we also load,
            # checkpoint_data, main_metric, minimize_metric, valid_loader ?
            # epoch_metrics, valid_metrics ?

        utils.unpack_checkpoint(checkpoint,
                                model=state.model,
                                criterion=state.criterion,
                                optimizer=state.optimizer,
                                scheduler=state.scheduler)

        print(f"loaded checkpoint {filename} "
              f"(global epoch {checkpoint['global_epoch']}, "
              f"epoch {checkpoint['epoch']}, "
              f"stage {checkpoint['stage_name']})")
    else:
        raise Exception(f"No checkpoint found at {filename}")
Example #3
    def load_checkpoint(*, filename, state: _State):
        if os.path.isfile(filename):
            print(f"=> loading checkpoint {filename}")
            checkpoint = utils.load_checkpoint(filename)

            if not state.stage.startswith("infer"):
                state.epoch = checkpoint["epoch"]
                state.stage_epoch = checkpoint["stage_epoch"]
                state.stage = checkpoint["stage"]

            utils.unpack_checkpoint(
                checkpoint,
                model=state.model,
                criterion=state.criterion,
                optimizer=state.optimizer,
                scheduler=state.scheduler
            )

            print(
                f"loaded checkpoint {filename} "
                f"(epoch {checkpoint['epoch']}, "
                f"stage_epoch {checkpoint['stage_epoch']}, "
                f"stage {checkpoint['stage']})"
            )
        else:
            raise Exception(f"No checkpoint found at {filename}")
Example #4
    def update_optimizer(self, state: _State):
        if not state.need_backward:
            return

        optimizer = state.get_key(key="optimizer",
                                  inner_key=self.optimizer_key)
        lr, momentum = self._update_optimizer(optimizer=optimizer)
        state.set_key(lr, key="lr", inner_key=self.optimizer_key)
        state.set_key(momentum, key="momentum", inner_key=self.optimizer_key)
Example #5
    def on_stage_start(self, state: _State):
        """On stage start event"""
        optimizer = state.get_key(key="optimizer",
                                  inner_key=self.optimizer_key)
        assert optimizer is not None
        lr = optimizer.defaults["lr"]
        momentum = utils.get_optimizer_momentum(optimizer)
        state.set_key(lr, "lr", inner_key=self.optimizer_key)
        state.set_key(momentum, "momentum", inner_key=self.optimizer_key)
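
A rough sketch of how the learning rate and momentum can be read back from a stock `torch.optim` optimizer; this is only an assumption about what `utils.get_optimizer_momentum` does, not its actual implementation.

def get_lr_and_momentum_sketch(optimizer):
    # Only the first param group is inspected; multi-group setups may differ.
    group = optimizer.param_groups[0]
    lr = group["lr"]
    # SGD/RMSprop expose "momentum" directly; Adam-style optimizers keep the
    # analogous quantity as the first beta coefficient.
    if "momentum" in group:
        momentum = group["momentum"]
    elif "betas" in group:
        momentum = group["betas"][0]
    else:
        momentum = None
    return lr, momentum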
Example #6
    def step_batch(self, state: _State):
        lr, momentum = self._scheduler_step(scheduler=self._scheduler)

        if self.scheduler_key is not None:
            state.batch_metrics[f"lr_{self.scheduler_key}"] = lr
            state.batch_metrics[f"momentum_{self.scheduler_key}"] = momentum
        else:
            state.batch_metrics["lr"] = lr
            state.batch_metrics["momentum"] = momentum
Example #7
    def update_optimizer(self, state: _State):
        lr, momentum = self._update_optimizer(optimizer=self._optimizer)

        if self.optimizer_key is not None:
            state.batch_metrics[f"lr_{self.optimizer_key}"] = lr
            state.batch_metrics[f"momentum_{self.optimizer_key}"] = momentum
        else:
            state.batch_metrics["lr"] = lr
            state.batch_metrics["momentum"] = momentum
Example #8
    def step(self, state: _State):
        scheduler = state.get_key(key="scheduler",
                                  inner_key=self.scheduler_key)

        valid_metric = \
            safitty.get(state.metric_manager.valid_values, self.reduce_metric)
        lr, momentum = self._scheduler_step(scheduler=scheduler,
                                            valid_metric=valid_metric)

        state.set_key(lr, key="lr", inner_key=self.scheduler_key)
        state.set_key(momentum, key="momentum", inner_key=self.scheduler_key)
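
For context, a hedged sketch of the distinction `_scheduler_step` presumably handles: `ReduceLROnPlateau` consumes the monitored validation metric, while other schedulers step without arguments. The function name is illustrative.

from torch.optim.lr_scheduler import ReduceLROnPlateau

def scheduler_step_sketch(scheduler, valid_metric=None):
    if isinstance(scheduler, ReduceLROnPlateau):
        # Plateau-based schedulers decide based on the monitored metric.
        scheduler.step(valid_metric)
    else:
        # Everything else simply advances by one step.
        scheduler.step()
    # Read the current learning rate back from the underlying optimizer.
    return scheduler.optimizer.param_groups[0]["lr"]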
Example #9
    def step_epoch(self, state: _State):
        reduced_metric = state.valid_metrics[self.reduced_metric]
        lr, momentum = self._scheduler_step(
            scheduler=self._scheduler, reduced_metric=reduced_metric
        )

        if self.scheduler_key is not None:
            state.epoch_metrics[f"lr_{self.scheduler_key}"] = lr
            state.epoch_metrics[f"momentum_{self.scheduler_key}"] = momentum
        else:
            state.epoch_metrics["lr"] = lr
            state.epoch_metrics["momentum"] = momentum
Example #10
    def on_stage_start(self, state: _State):
        optimizer = state.get_attr(
            key="optimizer", inner_key=self.optimizer_key
        )
        assert optimizer is not None
        self._optimizer = optimizer
        self.init_lr = optimizer.defaults["lr"]
Example #11
def _add_loss_to_state(loss_key: Optional[str], state: _State,
                       loss: torch.Tensor):
    if loss_key is None:
        if state.loss is not None:
            if isinstance(state.loss, list):
                state.loss.append(loss)
            else:
                state.loss = [state.loss, loss]
        else:
            state.loss = loss
    else:
        if state.loss is not None:
            assert isinstance(state.loss, dict)
            state.loss[loss_key] = loss
        else:
            state.loss = {loss_key: loss}
Example #12
    def on_stage_start(self, state: _State):
        """
        Checks that the current stage has the correct optimizer
        """
        optimizer = state.get_attr(key="optimizer",
                                   inner_key=self.optimizer_key)
        assert optimizer is not None
        self._optimizer = optimizer
Example #13
    def on_loader_start(self, state: _State):
        scheduler = state.get_key(key="scheduler",
                                  inner_key=self.scheduler_key)
        if state.loader_name.startswith("train") and \
                isinstance(scheduler, OneCycleLRWithWarmup) and \
                self.mode == "batch":
            scheduler.recalculate(loader_len=state.loader_len,
                                  current_step=state.stage_epoch)
Example #14
    def on_batch_end(self, state: _State):
        metrics_ = self._compute_metric(state)

        for arg, metric in zip(self.list_args, metrics_):
            if isinstance(arg, int):
                key = f"{self.prefix}{arg:02}"
            else:
                key = f"{self.prefix}_{arg}"
            state.batch_metrics[key] = metric * self.multiplier
Example #15
    def on_exception(self, state: _State):
        """Called if an Exception was raised"""
        exception = state.exception
        if not utils.is_exception(exception):
            return

        if isinstance(exception, KeyboardInterrupt):
            self.tqdm.write("Early exiting")
            state.need_exception_reraise = False
Example #16
    def on_stage_start(self, state: _State):
        """
        Checks that the current stage has the correct criterion
        """
        criterion = state.get_attr(
            key="criterion", inner_key=self.criterion_key
        )
        assert criterion is not None
        self._criterion = criterion
Example #17
    def on_batch_end(self, state: _State):
        """On batch end event"""
        if not state.need_backward:
            return

        loss = self._get_loss(state)

        self._accumulation_counter += 1
        model = state.model
        optimizer = state.get_key(key="optimizer",
                                  inner_key=self.optimizer_key)

        need_gradient_step = \
            self._accumulation_counter % self.accumulation_steps == 0

        # This is a very hacky check for whether we have an AMP optimizer,
        # and it may change in the future.
        # An alternative solution would be an AmpOptimizerCallback
        # or exposing another constructor argument.
        if hasattr(optimizer, "_amp_stash"):
            from apex import amp
            # Need to set ``delay_unscale``
            # according to
            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
            delay_unscale = not need_gradient_step
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if need_gradient_step:
            self.grad_step(optimizer=optimizer,
                           optimizer_wds=self._optimizer_wd,
                           grad_clip_fn=self.grad_clip_fn)

            if self.save_model_grads:
                for tag, value in model.named_parameters():
                    tag = tag.replace(".", "/")
                    state.model_grads[tag] = value.grad.cpu().numpy()

            utils.maybe_recursive_call(model, "zero_grad")

            self._accumulation_counter = 0
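
The same gradient-accumulation idea as a self-contained plain-PyTorch sketch, without the AMP branch; the names are illustrative and not part of the callback above.

def accumulation_step_sketch(batch_index, loss, optimizer,
                             accumulation_steps=1):
    # Scale the loss so the accumulated gradient matches a full-batch update.
    (loss / accumulation_steps).backward()

    # Step and clear gradients only once every `accumulation_steps` batches.
    if (batch_index + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()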
Example #18
    def on_epoch_end(self, state: _State):
        """On epoch end event"""
        if self.decouple_weight_decay:
            optimizer = state.get_key(key="optimizer",
                                      inner_key=self.optimizer_key)
            for i, wd in enumerate(self._optimizer_wd):
                safitty.set(optimizer.param_groups,
                            i,
                            "weight_decay",
                            value=wd)
Example #19
    def on_batch_end(self, state: _State):
        self.timer.stop("_timer/model_time")
        self.timer.stop("_timer/batch_time")

        # @TODO: just a trick
        self.timer.elapsed["_timer/_fps"] = \
            state.batch_size / self.timer.elapsed["_timer/batch_time"]
        for key, value in self.timer.elapsed.items():
            state.batch_metrics[key] = value

        self.timer.reset()
        self.timer.start("_timer/batch_time")
        self.timer.start("_timer/data_time")
Example #20
    def on_batch_end(self, state: _State) -> None:
        """
        Computes the loss and adds it to the metrics
        """
        loss = state.get_key(key="loss")
        loss = self._preprocess_loss(loss)
        loss = self.loss_fn(loss)

        state.metric_manager.add_batch_value(metrics_dict={
            self.prefix: loss.item(),
        })

        _add_loss_to_state(self.prefix, state, loss)
Example #21
    def on_stage_start(self, state: _State):
        scheduler = state.get_key(key="scheduler",
                                  inner_key=self.scheduler_key)
        assert scheduler is not None

        if self.mode is None:
            if isinstance(scheduler, BatchScheduler):
                self.mode = "batch"
            else:
                self.mode = "epoch"

        if isinstance(scheduler, OneCycleLRWithWarmup) and \
                self.mode == "batch":
            scheduler.reset()
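
For comparison, a minimal sketch of per-batch ("batch" mode) scheduler stepping with the stock `torch.optim.lr_scheduler.OneCycleLR`, not the `OneCycleLRWithWarmup` used above; the training loop is illustrative.

from torch.optim.lr_scheduler import OneCycleLR

def train_with_onecycle_sketch(model, optimizer, loader, loss_fn, epochs):
    # One scheduler step per batch, so the schedule spans
    # epochs * len(loader) steps in total.
    scheduler = OneCycleLR(
        optimizer, max_lr=1e-2, epochs=epochs, steps_per_epoch=len(loader)
    )
    for _ in range(epochs):
        for features, targets in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(features), targets)
            loss.backward()
            optimizer.step()
            scheduler.step()  # advance the schedule every batch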
Example #22
    def on_batch_end(self, state: _State):
        """
        Computes the loss and adds it to the metrics
        """
        criterion = state.get_key(key="criterion",
                                  inner_key=self.criterion_key)

        loss = self._compute_loss(state, criterion) * self.multiplier

        state.metric_manager.add_batch_value(metrics_dict={
            self.prefix: loss.item(),
        })

        _add_loss_to_state(self.prefix, state, loss)
Example #23
    def on_epoch_start(self, state: _State):
        """On epoch start event"""
        optimizer = state.get_key(key="optimizer",
                                  inner_key=self.optimizer_key)
        if self.decouple_weight_decay:
            self._optimizer_wd = [
                group.get("weight_decay", 0.0)
                for group in optimizer.param_groups
            ]
            for i in range(len(optimizer.param_groups)):
                safitty.set(optimizer.param_groups,
                            i,
                            "weight_decay",
                            value=0.0)
        else:
            self._optimizer_wd = [0.0] * len(optimizer.param_groups)
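
A minimal sketch of the decoupled weight decay idea itself (AdamW-style): the decay is removed from the optimizer's param groups, as in the example above, and applied directly to the parameters after the gradient step. The helper name and exact update form are illustrative.

import torch

def decoupled_weight_decay_step_sketch(optimizer, saved_wds):
    # `saved_wds` holds the weight_decay originally set for each param group;
    # the groups themselves carry weight_decay == 0.0, so the optimizer update
    # does not include the decay term.
    optimizer.step()
    with torch.no_grad():
        for group, wd in zip(optimizer.param_groups, saved_wds):
            for param in group["params"]:
                # Apply the decay directly: p <- p - lr * wd * p
                param.mul_(1.0 - group["lr"] * wd)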
Example #24
    def on_epoch_end(self, state: _State) -> None:
        if state.stage_name.startswith("infer"):
            return

        score = state.valid_metrics[self.metric]
        if self.best_score is None:
            self.best_score = score
        if self.is_better(score, self.best_score):
            self.num_bad_epochs = 0
            self.best_score = score
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            print(f"Early stop at {state.epoch} epoch")
            state.need_early_stop = True
Example #25
    def load_checkpoint(*, filename, state: _State):
        if os.path.isfile(filename):
            print(f"=> loading checkpoint {filename}")
            checkpoint = utils.load_checkpoint(filename)

            state.epoch = checkpoint["epoch"]

            utils.unpack_checkpoint(checkpoint,
                                    model=state.model,
                                    criterion=state.criterion,
                                    optimizer=state.optimizer,
                                    scheduler=state.scheduler)

            print(
                f"loaded checkpoint {filename} (epoch {checkpoint['epoch']})")
        else:
            raise Exception(f"No checkpoint found at {filename}")
Example #26
    def _get_loss(self, state: _State) -> torch.Tensor:
        loss = state.get_key(key="loss", inner_key=self.loss_key)

        if isinstance(loss, list):
            raise ValueError(
                f"Loss is a list. "
                f"Only the last value will be used for `backward`."
                f"To aggregate losses into "
                "one value use `CriterionAggregatorCallback`")
        if isinstance(loss, dict):
            error = f"Loss is a dict: {list(loss.keys())}, " \
                    f"to aggregate losses into " \
                    "one value use `CriterionAggregatorCallback`."
            if self.loss_key is None:
                error = error + " Or try to pass `loss_key` " \
                                "in the OptimizerCallback init"
            raise ValueError(error)
        return loss
Example #27
    def on_loader_end(self, state: _State):
        for key, value in self.meters.items():
            value = value.mean
            state.loader_metrics[key] = value
        for key, value in state.loader_metrics.items():
            state.epoch_metrics[f"{state.loader_name}_{key}"] = value
Example #28
    def on_batch_start(self, state: _State):
        state.batch_metrics = defaultdict(None)
Example #29
    def on_batch_start(self, state: _State):
        """On batch start event"""
        state.loss = None
Example #30
    def on_batch_end(self, state: _State):
        state.batch_metrics = self._process_metrics(state.batch_metrics)
        for key, value in state.batch_metrics.items():
            self.meters[key].add(value)