Ejemplo n.º 1
0
 def on_loader_start(self, state: RunnerState):
     """Re-sync a batch-mode ``OneCycleLR`` schedule when a train loader starts.

     The scheduler is looked up under ``self.scheduler_key``. Nothing happens
     unless the loader name starts with ``"train"``, the scheduler is a
     ``OneCycleLR`` and ``self.mode`` is ``"batch"``. In that case
     ``recalculate`` receives the loader length and the current stage epoch.
     """
     one_cycle = state.get_key(key="scheduler", inner_key=self.scheduler_key)

     # Guard clauses, checked in the same order as the original condition.
     if not state.loader_name.startswith("train"):
         return
     if not isinstance(one_cycle, OneCycleLR):
         return
     if self.mode != "batch":
         return

     one_cycle.recalculate(
         loader_len=state.loader_len,
         current_step=state.stage_epoch,
     )
Ejemplo n.º 2
0
 def on_stage_start(self, state: RunnerState):
     """Copy the optimizer's starting lr and momentum into the runner state.

     The optimizer registered under ``self.optimizer_key`` must exist.
     Its ``defaults["lr"]`` and the value from ``get_optimizer_momentum``
     are stored back in state under the same inner key.
     """
     opt = state.get_key(key="optimizer", inner_key=self.optimizer_key)
     assert opt is not None

     initial_lr = opt.defaults["lr"]
     initial_momentum = get_optimizer_momentum(opt)

     for value, name in ((initial_lr, "lr"), (initial_momentum, "momentum")):
         state.set_key(value, name, inner_key=self.optimizer_key)
Ejemplo n.º 3
0
 def on_stage_start(self, state: RunnerState):
     """Detect fp16 wrapping and record the optimizer's starting lr/momentum.

     ``self.fp16`` is set to whether ``state.model`` is an ``Fp16Wrap``.
     The optimizer under ``self.optimizer_key`` must exist. Its
     ``defaults["lr"]`` and the value from ``get_optimizer_momentum`` are
     written back to state under the same inner key.
     """
     self.fp16 = isinstance(state.model, Fp16Wrap)

     opt = state.get_key(key="optimizer", inner_key=self.optimizer_key)
     assert opt is not None

     initial_lr = opt.defaults["lr"]
     initial_momentum = get_optimizer_momentum(opt)

     for value, name in ((initial_lr, "lr"), (initial_momentum, "momentum")):
         state.set_key(value, name, inner_key=self.optimizer_key)
Ejemplo n.º 4
0
    def step(self, state: RunnerState):
        """Run one scheduler step and publish the resulting lr/momentum.

        The validation metric named by ``self.reduce_metric`` is read from
        ``state.metrics.valid_values`` via ``safitty.get``. It is passed to
        ``self._scheduler_step`` together with the scheduler stored under
        ``self.scheduler_key``. The returned ``(lr, momentum)`` pair is
        written back to state under that same inner key.
        """
        sched = state.get_key(key="scheduler", inner_key=self.scheduler_key)

        metric_value = safitty.get(
            state.metrics.valid_values, self.reduce_metric
        )
        new_lr, new_momentum = self._scheduler_step(
            scheduler=sched, valid_metric=metric_value
        )

        for name, value in (("lr", new_lr), ("momentum", new_momentum)):
            state.set_key(value, key=name, inner_key=self.scheduler_key)
Ejemplo n.º 5
0
    def on_stage_start(self, state: RunnerState):
        """Resolve the stepping mode and reset batch-mode ``OneCycleLR``.

        The scheduler under ``self.scheduler_key`` must exist. If
        ``self.mode`` is still ``None``, it is inferred from the scheduler
        type: ``"batch"`` for ``BatchScheduler`` instances, ``"epoch"``
        otherwise. A ``OneCycleLR`` scheduler running in batch mode is then
        reset.
        """
        sched = state.get_key(key="scheduler", inner_key=self.scheduler_key)
        assert sched is not None

        if self.mode is None:
            self.mode = (
                "batch" if isinstance(sched, BatchScheduler) else "epoch"
            )

        needs_reset = isinstance(sched, OneCycleLR) and self.mode == "batch"
        if needs_reset:
            sched.reset()