def on_stage_start(self, state: RunnerState):
    """Copy the optimizer's starting lr and momentum into the runner state.

    The optimizer is looked up under ``self.optimizer_key``. Its
    ``defaults["lr"]`` and the value returned by ``get_optimizer_momentum``
    are stored back into the state under the keys ``"lr"`` and
    ``"momentum"`` with the same inner key.
    """
    opt = state.get_key(key="optimizer", inner_key=self.optimizer_key)
    assert opt is not None

    starting_values = (
        ("lr", opt.defaults["lr"]),
        ("momentum", get_optimizer_momentum(opt)),
    )
    for name, value in starting_values:
        state.set_key(value, name, inner_key=self.optimizer_key)
def on_stage_start(self, state: RunnerState):
    """Prepare per-stage optimizer bookkeeping.

    Records in ``self.fp16`` whether ``state.model`` is an ``Fp16Wrap``.
    Then copies the optimizer's ``defaults["lr"]`` and the momentum returned
    by ``get_optimizer_momentum`` into the state under ``"lr"`` and
    ``"momentum"``, keyed by ``self.optimizer_key``.
    """
    self.fp16 = isinstance(state.model, Fp16Wrap)

    opt = state.get_key(key="optimizer", inner_key=self.optimizer_key)
    assert opt is not None

    initial_lr = opt.defaults["lr"]
    initial_momentum = get_optimizer_momentum(opt)

    for name, value in (("lr", initial_lr), ("momentum", initial_momentum)):
        state.set_key(value, name, inner_key=self.optimizer_key)
def step(self, state: RunnerState):
    """Advance the scheduler once and store the resulting lr / momentum.

    The validation value named by ``self.reduce_metric`` is read from
    ``state.metrics.valid_values`` via ``safitty.get`` and handed to
    ``self._scheduler_step``. Presumably only metric-driven schedulers use
    it; confirm in ``_scheduler_step``. The returned lr and momentum are
    written into the state under ``self.scheduler_key``.
    """
    sched = state.get_key(key="scheduler", inner_key=self.scheduler_key)
    metric_value = safitty.get(
        state.metrics.valid_values, self.reduce_metric
    )

    new_lr, new_momentum = self._scheduler_step(
        scheduler=sched, valid_metric=metric_value
    )

    for name, value in (("lr", new_lr), ("momentum", new_momentum)):
        state.set_key(value, key=name, inner_key=self.scheduler_key)
def _init_state(
        self,
        *,
        mode: str,
        stage: str = None,
        **kwargs) -> RunnerState:
    """
    Build the RunnerState for a stage; hook for subclasses to customize.

    If a previous state exists, its ``step``, ``epoch`` (incremented by one)
    and ``best_metrics`` are carried over so numbering continues.

    :param mode: run mode; accepted for interface compatibility, not used
        by this implementation
    :param stage: stage name for the new state; falls back to
        ``self.stage`` when not given
    :param kwargs: extra keyword arguments forwarded to ``RunnerState``
        (must not repeat any key set here, or the call raises TypeError)
    :return: RunnerState with all necessary parameters.
    """
    additional_kwargs = {}

    # transfer previous counters from old state
    if self.state is not None:
        additional_kwargs = {
            "step": self.state.step,
            "epoch": self.state.epoch + 1,
            "best_metrics": self.state.best_metrics
        }

    # Honour an explicitly passed ``stage``; otherwise keep the runner's own.
    if stage is None:
        stage = self.stage

    return RunnerState(
        device=self.device,
        model=self.model,
        stage=stage,
        criterion=self.criterion,
        optimizer=self.optimizer,
        scheduler=self.scheduler,
        **kwargs,
        **additional_kwargs)
def on_loader_start(self, state: RunnerState):
    """Recalculate a batch-mode ``OneCycleLR`` for a training loader.

    Only acts when the loader name starts with ``"train"``, the scheduler
    under ``self.scheduler_key`` is a ``OneCycleLR``, and ``self.mode`` is
    ``"batch"``. In that case ``recalculate`` receives the loader length and
    ``state.stage_epoch`` as the current step.
    """
    sched = state.get_key(key="scheduler", inner_key=self.scheduler_key)

    if not state.loader_name.startswith("train"):
        return
    if not isinstance(sched, OneCycleLR) or self.mode != "batch":
        return

    sched.recalculate(
        loader_len=state.loader_len, current_step=state.stage_epoch
    )
def on_stage_start(self, state: RunnerState):
    """Check the scheduler and settle the stepping mode for this stage.

    If ``self.mode`` was left unset, it becomes ``"batch"`` for
    ``BatchScheduler`` instances and ``"epoch"`` otherwise. A ``OneCycleLR``
    stepped per batch is reset at the start of every stage.
    """
    sched = state.get_key(key="scheduler", inner_key=self.scheduler_key)
    assert sched is not None

    if self.mode is None:
        self.mode = (
            "batch" if isinstance(sched, BatchScheduler) else "epoch"
        )

    if self.mode == "batch" and isinstance(sched, OneCycleLR):
        sched.reset()
def on_epoch_end(self, state: RunnerState) -> None:
    """Track the monitored metric and request an early stop when stalled.

    Does nothing for stages whose name starts with ``"infer"``. Reads
    ``state.metrics.valid_values[self.metric]``. If it improves on
    ``self.best_score`` (per ``self.is_better``), the bad-epoch counter is
    reset; otherwise the counter is incremented. Once the counter reaches
    ``self.patience``, a message is printed and ``state.early_stop`` is set
    to True.
    """
    if state.stage.startswith("infer"):
        return

    score = state.metrics.valid_values[self.metric]

    # The first observed score only establishes the baseline. Comparing it
    # against itself with a strict ``is_better`` would count the very first
    # epoch as "bad" and, with patience=1, stop training immediately.
    if self.best_score is None or self.is_better(score, self.best_score):
        self.num_bad_epochs = 0
        self.best_score = score
    else:
        self.num_bad_epochs += 1

    if self.num_bad_epochs >= self.patience:
        print(f"Early stop at {state.epoch} epoch")
        state.early_stop = True
def batch_handler(
        self,
        *,
        dct: Dict,
        model: nn.Module,
        state: RunnerState = None) -> Dict:
    """
    Move a batch to the runner's device and run the model on it.

    :param dct: mapping of input names to objects exposing ``.to(device)``
    :param model: model to predict with
    :param state: optional runner state; when given, its ``input`` is set
        to the moved batch
    :return: whatever ``self._batch_handler`` returns for the moved batch
    """
    moved = {}
    for name, tensor in dct.items():
        moved[name] = tensor.to(self.device)

    if state is not None:
        state.input = moved

    return self._batch_handler(dct=moved, model=model)
def _prepare_state(self, stage: str):
    """Create ``self.state`` for ``stage`` from the experiment's components.

    Model and device from ``self._get_experiment_components(stage)`` are
    stored on the runner. If a previous state exists, its ``step`` and
    ``epoch + 1`` are carried into the new state, alongside the stage's
    parameters from ``self.experiment.get_state_params(stage)``.
    """
    previous = self.state
    carried_over = {}
    if previous is not None:
        carried_over["step"] = previous.step
        carried_over["epoch"] = previous.epoch + 1

    components = self._get_experiment_components(stage)
    model, criterion, optimizer, scheduler, device = components
    self.model = model
    self.device = device

    self.state = RunnerState(
        stage=stage,
        model=self.model,
        device=self.device,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        **self.experiment.get_state_params(stage),
        **carried_over
    )
def batch_handler(
        self,
        *,
        dct: Dict,
        model: nn.Module,
        state: RunnerState = None) -> Dict:
    """
    Move a batch to the target device and run the model on its features.

    :param dct: key-value storage with input tensors, or a 2-element
        tuple/list which is converted to ``{"features": ..., "targets": ...}``
    :param model: model to predict with; called on ``dct["features"]``
    :param state: optional runner state; when given, it supplies the target
        device and its ``input`` is set to the moved batch
    :return: ``{"logits": model(features)}``
    """
    if isinstance(dct, (tuple, list)):
        assert len(dct) == 2
        dct = {"features": dct[0], "targets": dct[1]}

    # ``state`` defaults to None, so it cannot be dereferenced
    # unconditionally; fall back to the runner's own device.
    # NOTE(review): relies on the runner defining ``self.device`` (as
    # ``_prepare_state`` in this file does) - confirm for this class.
    device = state.device if state is not None else self.device
    dct = {key: value.to(device) for key, value in dct.items()}

    if state is not None:
        state.input = dct

    logits = model(dct["features"])
    return {"logits": logits}