def on_stage_start(self, state: RunnerState):
    """Publish the optimizer's starting lr and momentum into the runner state."""
    opt = state.get_key(key="optimizer", inner_key=self.optimizer_key)
    assert opt is not None
    start_lr = opt.defaults["lr"]
    start_momentum = get_optimizer_momentum(opt)
    # mirror both values into state under the same inner key
    state.set_key(start_lr, "lr", inner_key=self.optimizer_key)
    state.set_key(start_momentum, "momentum", inner_key=self.optimizer_key)
def step(self, state: RunnerState):
    """Advance the scheduler once and record the resulting lr/momentum."""
    sched = state.get_key(key="scheduler", inner_key=self.scheduler_key)
    # metric to reduce on (e.g. valid loss), looked up by dotted path
    metric_value = safitty.get(
        state.metrics.valid_values, self.reduce_metric
    )
    new_lr, new_momentum = self._scheduler_step(
        scheduler=sched, valid_metric=metric_value
    )
    state.set_key(new_lr, key="lr", inner_key=self.scheduler_key)
    state.set_key(new_momentum, key="momentum", inner_key=self.scheduler_key)
def _add_loss_to_state(self, state: RunnerState, loss):
    """Attach *loss* to ``state.loss``.

    Unkeyed (``self.loss_key is None``): a single value first, then a
    growing list. Keyed: a dict mapping loss_key -> loss.
    """
    current = state.loss
    if self.loss_key is not None:
        # named losses live in a dict keyed by loss_key
        if current is None:
            state.loss = {self.loss_key: loss}
        else:
            assert isinstance(current, dict)
            current[self.loss_key] = loss
        return
    # unnamed losses: scalar -> [old, new] -> append thereafter
    if current is None:
        state.loss = loss
    elif isinstance(current, list):
        current.append(loss)
    else:
        state.loss = [current, loss]
def on_loader_start(self, state: RunnerState):
    """Re-fit a batch-mode OneCycleLR schedule to the incoming train loader."""
    sched = state.get_key(key="scheduler", inner_key=self.scheduler_key)
    if not state.loader_name.startswith("train"):
        return
    if isinstance(sched, OneCycleLR) and self.mode == "batch":
        # the cycle length depends on the loader size, so recompute it here
        sched.recalculate(
            loader_len=state.loader_len, current_step=state.stage_epoch
        )
def _add_loss_to_state(loss_key: Optional[str], state: RunnerState, loss: torch.Tensor):
    """Store *loss* on ``state.loss``.

    With ``loss_key is None``: a single tensor first, then a growing list.
    With a key: a dict mapping loss_key -> tensor.
    """
    stored = state.loss
    if loss_key is not None:
        # keyed losses are collected in a dict
        if stored is None:
            state.loss = {loss_key: loss}
        else:
            assert isinstance(stored, dict)
            stored[loss_key] = loss
        return
    # unkeyed: tensor -> [old, new] -> append thereafter
    if stored is None:
        state.loss = loss
    elif isinstance(stored, list):
        stored.append(loss)
    else:
        state.loss = [stored, loss]
def on_exception(self, state: RunnerState):
    """Handle Ctrl-C gracefully: note it on the progress bar, don't re-raise."""
    exc = state.exception
    if not utils.is_exception(exc):
        return
    if isinstance(exc, KeyboardInterrupt):
        self.tqdm.write("Early exiting")
        # suppress the re-raise so the runner can shut down cleanly
        state.need_reraise_exception = False
def on_batch_end(self, state: RunnerState) -> None:
    """Aggregate the batch losses, log the scalar, and stash it on state."""
    raw = state.get_key(key="loss")
    aggregated = self.loss_fn(self._preprocess_loss(raw))
    state.metrics.add_batch_value(metrics_dict={
        self.prefix: aggregated.item(),
    })
    _add_loss_to_state(self.prefix, state, aggregated)
def on_batch_end(self, state: RunnerState):
    """Compute the (scaled) criterion loss, log it, and append it to state.loss."""
    criterion = state.get_key(key="criterion", inner_key=self.criterion_key)
    batch_loss = self._compute_loss(state, criterion) * self.multiplier
    state.metrics.add_batch_value(metrics_dict={
        self.prefix: batch_loss.item(),
    })
    self._add_loss_to_state(state, batch_loss)
def on_stage_start(self, state: RunnerState):
    """Infer the stepping mode from the scheduler type; reset OneCycleLR."""
    sched = state.get_key(key="scheduler", inner_key=self.scheduler_key)
    assert sched is not None
    if self.mode is None:
        # batch schedulers step every batch, everything else steps per epoch
        self.mode = "batch" if isinstance(sched, BatchScheduler) else "epoch"
    if isinstance(sched, OneCycleLR) and self.mode == "batch":
        sched.reset()
def on_batch_end(self, state: RunnerState):
    """Compute the configured criterion's loss for the batch, log its scalar
    value under ``self.prefix``, and accumulate it on ``state.loss``.

    Fix: previously a hard-coded ``nn.CrossEntropyLoss()`` silently replaced
    the configured criterion on every non-train loader, so valid/infer
    metrics could disagree with the training objective. The registered
    criterion is now used for every loader, matching the sibling
    criterion callback.
    """
    criterion = state.get_key(key="criterion", inner_key=self.criterion_key)
    loss = self._compute_loss(state, criterion) * self.multiplier
    state.metrics.add_batch_value(metrics_dict={
        self.prefix: loss.item(),
    })
    self._add_loss_to_state(state, loss)
def on_batch_start(self, state: RunnerState):
    """Mixup: blend each configured input field with a permuted batch.

    Samples the mixing coefficient ``lam`` from Beta(alpha, alpha) when
    ``alpha > 0`` (otherwise ``lam == 1``, i.e. no mixing), draws one random
    permutation of the batch, and blends every field in ``self.fields``.

    Fix: ``Tensor.to`` is not in-place; the original discarded its return
    value, leaving ``self.index`` on the CPU regardless of ``state.device``.
    The moved tensor is now assigned back.
    """
    if not self.is_needed:
        return
    self.lam = np.random.beta(self.alpha, self.alpha) if self.alpha > 0 else 1
    batch_size = state.input[self.fields[0]].shape[0]
    self.index = torch.randperm(batch_size)
    self.index = self.index.to(state.device)  # .to() returns a new tensor
    for f in self.fields:
        state.input[f] = (
            self.lam * state.input[f]
            + (1 - self.lam) * state.input[f][self.index]
        )
def on_epoch_end(self, state: RunnerState) -> None:
    """Track the monitored valid metric and set the early-stop flag after
    ``self.patience`` consecutive non-improving epochs.

    Fix: the first observed epoch previously seeded ``best_score`` and was
    then compared against itself, so it always counted as a "bad" epoch and
    shortened the effective patience by one. The first epoch now counts as
    an improvement.
    """
    if state.stage.startswith("infer"):  # nothing to stop during inference
        return
    score = state.metrics.valid_values[self.metric]
    # first epoch or genuine improvement -> reset the bad-epoch counter
    if self.best_score is None or self.is_better(score, self.best_score):
        self.num_bad_epochs = 0
        self.best_score = score
    else:
        self.num_bad_epochs += 1
    if self.num_bad_epochs >= self.patience:
        print(f"Early stop at {state.stage_epoch} epoch")
        state.early_stop = True
def load_checkpoint(*, filename, state: RunnerState):
    """Restore model/criterion/optimizer/scheduler and epoch from *filename*.

    Args:
        filename: path of the checkpoint file to load.
        state: runner state whose components receive the unpacked weights.

    Raises:
        FileNotFoundError: if *filename* does not point to an existing file
            (still caught by any pre-existing ``except Exception`` handler).

    Fix: the original f-strings contained no placeholders, so every message
    printed the literal text without the checkpoint path.
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"No checkpoint found at {filename}")
    print(f"=> loading checkpoint {filename}")
    checkpoint = utils.load_checkpoint(filename)
    state.epoch = checkpoint["epoch"]
    utils.unpack_checkpoint(
        checkpoint,
        model=state.model,
        criterion=state.criterion,
        optimizer=state.optimizer,
        scheduler=state.scheduler,
    )
    print(f"loaded checkpoint {filename} (epoch {checkpoint['epoch']})")