Example #1
0
 def on_loader_start(self, state: RunnerState):
     """Re-sync a batch-mode ``OneCycleLR`` when a training loader starts.

     Does nothing unless all of the following hold: the loader name starts
     with ``"train"``, the scheduler stored under ``self.scheduler_key`` is
     a ``OneCycleLR``, and ``self.mode`` is ``"batch"``.

     Args:
         state: runner state supplying the scheduler, the loader name, the
             loader length and the current stage epoch.
     """
     scheduler = state.get_key(key="scheduler", inner_key=self.scheduler_key)

     if not state.loader_name.startswith("train"):
         return
     if not isinstance(scheduler, OneCycleLR) or self.mode != "batch":
         return

     # Hand the scheduler this loader's batch count and the stage epoch;
     # presumably so it can rebuild its per-batch schedule for this loader.
     scheduler.recalculate(
         loader_len=state.loader_len,
         current_step=state.stage_epoch,
     )
Example #2
0
 def on_stage_start(self, state: RunnerState):
     """Store the optimizer's default learning rate and momentum on state.

     The optimizer registered under ``self.optimizer_key`` is looked up; its
     ``defaults["lr"]`` and its momentum (as reported by
     ``get_optimizer_momentum``) are written back to ``state`` under the
     keys ``"lr"`` and ``"momentum"`` with the same inner key.

     Args:
         state: runner state holding the optimizer.

     Raises:
         AssertionError: if the optimizer lookup returns ``None`` (only when
             assertions are enabled).
     """
     key = self.optimizer_key
     optimizer = state.get_key(key="optimizer", inner_key=key)
     assert optimizer is not None

     initial_lr = optimizer.defaults["lr"]
     initial_momentum = get_optimizer_momentum(optimizer)

     for value, name in ((initial_lr, "lr"), (initial_momentum, "momentum")):
         state.set_key(value, name, inner_key=key)
Example #3
0
    def on_batch_end(self, state: RunnerState) -> None:
        """Combine the batch loss(es) into one tensor, log it and store it.

        The ``"loss"`` entry of ``state`` is passed through
        ``self._preprocess_loss`` and then ``self.loss_fn``; the result's
        scalar value is recorded in the batch metrics under ``self.prefix``
        and the tensor itself is attached to ``state`` via
        ``_add_loss_to_state``.

        Args:
            state: runner state holding the loss(es) for the current batch.
        """
        raw_loss = state.get_key(key="loss")
        # NOTE(review): loss_fn presumably reduces several losses to a scalar
        # tensor; ``.item()`` below requires a single-element result.
        combined = self.loss_fn(self._preprocess_loss(raw_loss))

        batch_metrics = {self.prefix: combined.item()}
        state.metrics.add_batch_value(metrics_dict=batch_metrics)

        _add_loss_to_state(self.prefix, state, combined)
Example #4
0
    def step(self, state: RunnerState):
        """Step the scheduler and publish the returned lr and momentum.

        The validation metric named by ``self.reduce_metric`` is fetched from
        ``state.metrics.valid_values`` with ``safitty.get`` and passed, along
        with the scheduler stored under ``self.scheduler_key``, to
        ``self._scheduler_step``. The learning rate and momentum it returns
        are written back to ``state`` under that same inner key.

        Args:
            state: runner state holding the scheduler and validation metrics.
        """
        scheduler = state.get_key(
            key="scheduler", inner_key=self.scheduler_key
        )

        metric_value = safitty.get(
            state.metrics.valid_values, self.reduce_metric
        )
        new_lr, new_momentum = self._scheduler_step(
            scheduler=scheduler, valid_metric=metric_value
        )

        state.set_key(new_lr, key="lr", inner_key=self.scheduler_key)
        state.set_key(
            new_momentum, key="momentum", inner_key=self.scheduler_key
        )
Example #5
0
    def on_batch_end(self, state: RunnerState):
        """Compute the scaled criterion loss for the batch and record it.

        The criterion stored under ``self.criterion_key`` is applied through
        ``self._compute_loss``; the result is multiplied by
        ``self.multiplier``, logged under ``self.prefix`` in the batch
        metrics, and handed to ``self._add_loss_to_state``.

        Args:
            state: runner state for the current batch.
        """
        criterion = state.get_key(
            key="criterion", inner_key=self.criterion_key
        )

        unscaled = self._compute_loss(state, criterion)
        loss = unscaled * self.multiplier

        state.metrics.add_batch_value(metrics_dict={self.prefix: loss.item()})
        self._add_loss_to_state(state, loss)
Example #6
0
    def on_stage_start(self, state: RunnerState):
        """Pick the scheduler step mode; reset a batch-mode ``OneCycleLR``.

        When ``self.mode`` is ``None`` it is inferred from the scheduler
        type: ``"batch"`` for a ``BatchScheduler`` instance, ``"epoch"``
        otherwise. A ``OneCycleLR`` running in ``"batch"`` mode then has
        ``reset()`` called on it.

        Args:
            state: runner state holding the scheduler under
                ``self.scheduler_key``.

        Raises:
            AssertionError: if the scheduler lookup returns ``None`` (only
                when assertions are enabled).
        """
        scheduler = state.get_key(
            key="scheduler", inner_key=self.scheduler_key
        )
        assert scheduler is not None

        if self.mode is None:
            self.mode = (
                "batch" if isinstance(scheduler, BatchScheduler) else "epoch"
            )

        if self.mode == "batch" and isinstance(scheduler, OneCycleLR):
            scheduler.reset()
Example #7
0
    def on_batch_end(self, state: RunnerState):
        """Compute the scaled batch loss, using plain CE outside training.

        On loaders whose name starts with ``"train"`` the criterion stored
        under ``self.criterion_key`` is used; on every other loader a fresh
        ``nn.CrossEntropyLoss()`` is used instead, regardless of the
        configured criterion. The loss is multiplied by ``self.multiplier``,
        logged under ``self.prefix`` in the batch metrics, and handed to
        ``self._add_loss_to_state``.

        Args:
            state: runner state for the current batch.
        """
        is_train = state.loader_name.startswith("train")
        # NOTE(review): a new CrossEntropyLoss is constructed on every
        # non-train batch; hoisting it to the instance would avoid that.
        criterion = (
            state.get_key(key="criterion", inner_key=self.criterion_key)
            if is_train
            else nn.CrossEntropyLoss()
        )

        unscaled = self._compute_loss(state, criterion)
        loss = unscaled * self.multiplier

        state.metrics.add_batch_value(metrics_dict={self.prefix: loss.item()})
        self._add_loss_to_state(state, loss)