Exemple #1
0
 def on_stage_start(self, state: RunnerState):
     """
     Publish the optimizer's starting learning rate and momentum to the state.

     Fetches the optimizer stored under ``self.optimizer_key``, then writes
     the ``"lr"`` entry of ``optimizer.defaults`` and the value returned by
     ``get_optimizer_momentum`` back into ``state`` under the same inner key.

     :param state: runner state that holds the optimizer
     """
     opt_key = self.optimizer_key
     opt = state.get_key(key="optimizer", inner_key=opt_key)
     assert opt is not None

     initial_lr = opt.defaults["lr"]
     initial_momentum = get_optimizer_momentum(opt)

     state.set_key(initial_lr, "lr", inner_key=opt_key)
     state.set_key(initial_momentum, "momentum", inner_key=opt_key)
Exemple #2
0
 def on_stage_start(self, state: RunnerState):
     """
     Record fp16 mode and seed the state with the optimizer's lr and momentum.

     ``self.fp16`` is set to whether ``state.model`` is an ``Fp16Wrap``
     instance. The optimizer stored under ``self.optimizer_key`` then has
     its ``defaults["lr"]`` and its momentum (via ``get_optimizer_momentum``)
     written back into ``state`` under the same inner key.

     :param state: runner state that holds the model and optimizer
     """
     self.fp16 = isinstance(state.model, Fp16Wrap)

     opt_key = self.optimizer_key
     opt = state.get_key(key="optimizer", inner_key=opt_key)
     assert opt is not None

     initial_lr = opt.defaults["lr"]
     initial_momentum = get_optimizer_momentum(opt)

     state.set_key(initial_lr, "lr", inner_key=opt_key)
     state.set_key(initial_momentum, "momentum", inner_key=opt_key)
Exemple #3
0
    def step(self, state: RunnerState):
        """
        Advance the scheduler once and store the resulting lr and momentum.

        The validation metric named by ``self.reduce_metric`` is read from
        ``state.metrics.valid_values`` via ``safitty.get`` and handed to
        ``self._scheduler_step`` along with the scheduler stored under
        ``self.scheduler_key``. The returned ``(lr, momentum)`` pair is
        written back into ``state`` under that same inner key.

        :param state: runner state that holds the scheduler and metrics
        """
        sched = state.get_key(key="scheduler", inner_key=self.scheduler_key)

        metric_value = safitty.get(
            state.metrics.valid_values, self.reduce_metric
        )
        new_lr, new_momentum = self._scheduler_step(
            scheduler=sched, valid_metric=metric_value
        )

        state.set_key(new_lr, key="lr", inner_key=self.scheduler_key)
        state.set_key(
            new_momentum, key="momentum", inner_key=self.scheduler_key
        )
Exemple #4
0
 def _init_state(self,
                 *,
                 mode: str,
                 stage: str = None,
                 **kwargs) -> RunnerState:
     """
     Inner method for children's classes for state specific initialization.

     When a previous state exists, its ``step``, ``epoch + 1`` and
     ``best_metrics`` are carried over into the new state.

     :param mode: accepted for interface compatibility; not used here
     :param stage: stage name for the new state; when omitted (``None``)
         the runner's current ``self.stage`` is used
     :param kwargs: extra keyword arguments forwarded to ``RunnerState``
     :return: RunnerState with all necessary parameters.
     """
     additional_kwargs = {}
     # transfer previous counters from old state
     if self.state is not None:
         additional_kwargs = {
             "step": self.state.step,
             "epoch": self.state.epoch + 1,
             "best_metrics": self.state.best_metrics
         }
     # Honour an explicitly requested stage instead of silently dropping the
     # argument; fall back to the runner's current stage otherwise.
     resolved_stage = stage if stage is not None else self.stage
     # NOTE(review): a key present in both ``kwargs`` and
     # ``additional_kwargs`` (e.g. "step") makes this call raise TypeError
     # for a duplicate keyword argument.
     return RunnerState(device=self.device,
                        model=self.model,
                        stage=resolved_stage,
                        criterion=self.criterion,
                        optimizer=self.optimizer,
                        scheduler=self.scheduler,
                        **kwargs,
                        **additional_kwargs)
Exemple #5
0
 def on_loader_start(self, state: RunnerState):
     """
     Re-sync a batch-mode ``OneCycleLR`` scheduler when a train loader starts.

     Only acts for loaders whose name starts with ``"train"``, when the
     scheduler under ``self.scheduler_key`` is a ``OneCycleLR`` and this
     callback runs in ``"batch"`` mode. In that case the scheduler is given
     the loader length and the current stage epoch.

     :param state: runner state that holds the scheduler and loader info
     """
     sched = state.get_key(key="scheduler", inner_key=self.scheduler_key)

     if not state.loader_name.startswith("train"):
         return
     if not isinstance(sched, OneCycleLR):
         return
     if self.mode != "batch":
         return

     sched.recalculate(loader_len=state.loader_len,
                       current_step=state.stage_epoch)
Exemple #6
0
    def on_stage_start(self, state: RunnerState):
        """
        Resolve the scheduler stepping mode and reset a batch-mode OneCycleLR.

        When ``self.mode`` is unset (``None``) it is inferred from the
        scheduler under ``self.scheduler_key``: ``"batch"`` for a
        ``BatchScheduler`` instance, ``"epoch"`` otherwise. A ``OneCycleLR``
        scheduler running in ``"batch"`` mode is then reset.

        :param state: runner state that holds the scheduler
        """
        sched = state.get_key(key="scheduler", inner_key=self.scheduler_key)
        assert sched is not None

        if self.mode is None:
            is_batch_scheduler = isinstance(sched, BatchScheduler)
            self.mode = "batch" if is_batch_scheduler else "epoch"

        needs_reset = isinstance(sched, OneCycleLR) and self.mode == "batch"
        if needs_reset:
            sched.reset()
Exemple #7
0
    def on_epoch_end(self, state: RunnerState) -> None:
        """
        Track the monitored validation metric and request early stopping.

        The score ``state.metrics.valid_values[self.metric]`` is compared
        with the best score seen so far using ``self.is_better``. An
        improving epoch resets ``self.num_bad_epochs`` and updates
        ``self.best_score``; any other epoch increments the counter. Once
        ``self.num_bad_epochs`` reaches ``self.patience``,
        ``state.early_stop`` is set to ``True``. Stages whose name starts
        with ``"infer"`` are skipped entirely.

        :param state: runner state that holds the stage name and metrics
        """
        if state.stage.startswith("infer"):
            return

        score = state.metrics.valid_values[self.metric]
        # The first observed score is the best so far by definition. Comparing
        # it with itself would fail whenever ``is_better`` is a strict
        # comparison, wrongly counting the first epoch as "bad" (and stopping
        # immediately when patience is 1).
        if self.best_score is None or self.is_better(score, self.best_score):
            self.num_bad_epochs = 0
            self.best_score = score
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            print(f"Early stop at {state.epoch} epoch")
            state.early_stop = True
Exemple #8
0
    def batch_handler(self,
                      *,
                      dct: Dict,
                      model: nn.Module,
                      state: RunnerState = None) -> Dict:
        """
        Move a batch onto the runner's device and delegate to the handler.

        Every value of ``dct`` is moved with ``.to(self.device)``. The moved
        mapping is stored on ``state.input`` when a state is supplied, and is
        then passed to ``self._batch_handler`` together with ``model``.

        :param dct: key-value storage with input tensors
        :param model: model to predict with
        :param state: runner state (optional)
        :return: whatever ``self._batch_handler`` returns for the moved batch
        """
        moved = {}
        for name, tensor in dct.items():
            moved[name] = tensor.to(self.device)

        if state is not None:
            state.input = moved

        return self._batch_handler(dct=moved, model=model)
Exemple #9
0
    def _prepare_state(self, stage: str):
        """
        Build a fresh ``RunnerState`` for ``stage`` and store it on the runner.

        When a previous state exists, its ``step`` and ``epoch + 1`` are
        carried over. Model, criterion, optimizer, scheduler and device come
        from ``self._get_experiment_components(stage)``; the model and device
        are also kept on the runner. Additional state parameters come from
        ``self.experiment.get_state_params(stage)``.

        :param stage: name of the stage to prepare
        """
        carried_over = {}
        previous = self.state
        if previous is not None:
            carried_over["step"] = previous.step
            carried_over["epoch"] = previous.epoch + 1

        components = self._get_experiment_components(stage)
        self.model, criterion, optimizer, scheduler, self.device = components

        state_params = self.experiment.get_state_params(stage)
        self.state = RunnerState(
            stage=stage,
            model=self.model,
            device=self.device,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            **state_params,
            **carried_over
        )
Exemple #10
0
    def batch_handler(self,
                      *,
                      dct: Dict,
                      model: nn.Module,
                      state: RunnerState = None) -> Dict:
        """
        Batch handler wrapper with main statistics and device management.

        A two-element tuple/list is first converted into a dict with
        ``"features"`` and ``"targets"`` keys. Every value is moved to
        ``state.device`` (or to ``self.device`` when no state is given).
        The moved batch is stored on ``state.input`` when a state is given,
        and the model is applied to ``dct["features"]``.

        :param dct: key-value storage with input tensors, or a
            ``(features, targets)`` pair
        :param model: model to predict with
        :param state: runner state (optional)
        :return: dict with the model output under ``"logits"``
        """
        if isinstance(dct, (tuple, list)):
            assert len(dct) == 2
            dct = {"features": dct[0], "targets": dct[1]}
        # ``state`` defaults to None, so it must not be dereferenced
        # unconditionally; fall back to the runner's own device.
        # NOTE(review): assumes the runner defines ``self.device`` — confirm.
        device = state.device if state is not None else self.device
        dct = {key: value.to(device) for key, value in dct.items()}
        if state is not None:
            state.input = dct
        logits = model(dct["features"])
        return {"logits": logits}