Exemplo n.º 1
0
 def on_stage_start(self, state: RunnerState):
     """Record the optimizer's starting lr and momentum in the runner state."""
     optimizer = state.get_key(key="optimizer",
                               inner_key=self.optimizer_key)
     assert optimizer is not None

     initial_lr = optimizer.defaults["lr"]
     initial_momentum = get_optimizer_momentum(optimizer)

     state.set_key(initial_lr, "lr", inner_key=self.optimizer_key)
     state.set_key(initial_momentum, "momentum", inner_key=self.optimizer_key)
Exemplo n.º 2
0
    def step(self, state: RunnerState):
        """Advance the scheduler once and publish the resulting lr/momentum."""
        scheduler = state.get_key(key="scheduler",
                                  inner_key=self.scheduler_key)

        metric_value = safitty.get(
            state.metrics.valid_values, self.reduce_metric)
        new_lr, new_momentum = self._scheduler_step(
            scheduler=scheduler, valid_metric=metric_value)

        state.set_key(new_lr, key="lr", inner_key=self.scheduler_key)
        state.set_key(new_momentum, key="momentum",
                      inner_key=self.scheduler_key)
Exemplo n.º 3
0
 def _add_loss_to_state(self, state: RunnerState, loss):
     if self.loss_key is None:
         if state.loss is not None:
             if isinstance(state.loss, list):
                 state.loss.append(loss)
             else:
                 state.loss = [state.loss, loss]
         else:
             state.loss = loss
     else:
         if state.loss is not None:
             assert isinstance(state.loss, dict)
             state.loss[self.loss_key] = loss
         else:
             state.loss = {self.loss_key: loss}
Exemplo n.º 4
0
 def on_loader_start(self, state: RunnerState):
     """On train loaders, re-fit a batch-wise OneCycleLR to the loader length."""
     scheduler = state.get_key(key="scheduler",
                               inner_key=self.scheduler_key)

     is_train_loader = state.loader_name.startswith("train")
     if is_train_loader and self.mode == "batch" \
             and isinstance(scheduler, OneCycleLR):
         scheduler.recalculate(loader_len=state.loader_len,
                               current_step=state.stage_epoch)
Exemplo n.º 5
0
def _add_loss_to_state(loss_key: Optional[str], state: RunnerState,
                       loss: torch.Tensor):
    if loss_key is None:
        if state.loss is not None:
            if isinstance(state.loss, list):
                state.loss.append(loss)
            else:
                state.loss = [state.loss, loss]
        else:
            state.loss = loss
    else:
        if state.loss is not None:
            assert isinstance(state.loss, dict)
            state.loss[loss_key] = loss
        else:
            state.loss = {loss_key: loss}
Exemplo n.º 6
0
    def on_exception(self, state: RunnerState):
        """Swallow KeyboardInterrupt so a manual stop exits gracefully."""
        exc = state.exception
        if not utils.is_exception(exc):
            return

        if isinstance(exc, KeyboardInterrupt):
            self.tqdm.write("Early exiting")
            # Suppress re-raising: the interrupt was handled here.
            state.need_reraise_exception = False
Exemplo n.º 7
0
    def on_batch_end(self, state: RunnerState) -> None:
        """Aggregate the batch loss, log it, and stash it under self.prefix."""
        raw_loss = state.get_key(key="loss")
        processed = self._preprocess_loss(raw_loss)
        value = self.loss_fn(processed)

        state.metrics.add_batch_value(
            metrics_dict={self.prefix: value.item()})

        _add_loss_to_state(self.prefix, state, value)
Exemplo n.º 8
0
    def on_batch_end(self, state: RunnerState):
        """Compute the criterion loss, scale it, log it and add it to the state."""
        criterion = state.get_key(
            key="criterion", inner_key=self.criterion_key)

        scaled_loss = self._compute_loss(state, criterion) * self.multiplier

        state.metrics.add_batch_value(
            metrics_dict={self.prefix: scaled_loss.item()})

        self._add_loss_to_state(state, scaled_loss)
Exemplo n.º 9
0
    def on_stage_start(self, state: RunnerState):
        """Resolve the stepping mode from the scheduler type; reset OneCycleLR."""
        scheduler = state.get_key(
            key="scheduler", inner_key=self.scheduler_key)
        assert scheduler is not None

        if self.mode is None:
            # Batch schedulers step every batch; everything else per epoch.
            self.mode = ("batch" if isinstance(scheduler, BatchScheduler)
                         else "epoch")

        if self.mode == "batch" and isinstance(scheduler, OneCycleLR):
            scheduler.reset()
Exemplo n.º 10
0
    def on_batch_end(self, state: RunnerState):
        """Score the batch: configured criterion on train, cross-entropy otherwise."""
        if state.loader_name.startswith("train"):
            criterion = state.get_key(
                key="criterion", inner_key=self.criterion_key)
        else:
            # Non-train loaders are always evaluated with plain cross-entropy.
            criterion = nn.CrossEntropyLoss()

        loss = self._compute_loss(state, criterion) * self.multiplier

        state.metrics.add_batch_value(
            metrics_dict={self.prefix: loss.item()})

        self._add_loss_to_state(state, loss)
Exemplo n.º 11
0
    def on_batch_start(self, state: RunnerState):
        """Apply mixup to the configured input fields of this batch.

        Samples the mixing coefficient ``lam`` from Beta(alpha, alpha)
        (or uses 1 when ``alpha <= 0``, i.e. no mixing), draws a random
        permutation of the batch indices, and blends every field with
        its permuted counterpart in place.
        """
        if not self.is_needed:
            return

        if self.alpha > 0:
            self.lam = np.random.beta(self.alpha, self.alpha)
        else:
            self.lam = 1

        batch_size = state.input[self.fields[0]].shape[0]
        self.index = torch.randperm(batch_size)
        # BUGFIX: Tensor.to() is NOT in-place — it returns a new tensor,
        # so the result must be reassigned or the index stays on the CPU.
        self.index = self.index.to(state.device)

        for f in self.fields:
            state.input[f] = self.lam * state.input[f] + \
                (1 - self.lam) * state.input[f][self.index]
Exemplo n.º 12
0
    def on_epoch_end(self, state: RunnerState) -> None:
        """Track the validation metric and flag early stopping after
        ``patience`` consecutive epochs without improvement."""
        if state.stage.startswith("infer"):
            return  # nothing to track during inference stages

        score = state.metrics.valid_values[self.metric]
        if self.best_score is None:
            self.best_score = score

        if self.is_better(score, self.best_score):
            self.best_score = score
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            print(f"Early stop at {state.stage_epoch} epoch")
            state.early_stop = True
Exemplo n.º 13
0
    def load_checkpoint(*, filename, state: RunnerState):
        """Load the checkpoint at ``filename`` into the runner state.

        Restores the epoch counter plus model, criterion, optimizer and
        scheduler weights via ``utils.unpack_checkpoint``.

        Raises:
            FileNotFoundError: if ``filename`` does not exist.
        """
        if not os.path.isfile(filename):
            # BUGFIX: FileNotFoundError is precise and is still caught by
            # callers handling Exception; the message now names the path
            # (the original f-strings had no placeholders at all).
            raise FileNotFoundError(f"No checkpoint found at {filename}")

        print(f"=> loading checkpoint {filename}")
        checkpoint = utils.load_checkpoint(filename)

        state.epoch = checkpoint["epoch"]

        utils.unpack_checkpoint(checkpoint,
                                model=state.model,
                                criterion=state.criterion,
                                optimizer=state.optimizer,
                                scheduler=state.scheduler)

        print(f"loaded checkpoint {filename} (epoch {checkpoint['epoch']})")