def on_epoch_end(self, state: State) -> None:
    """Re-apply stored weight decay and log lr/momentum for the epoch.

    When ``self.decouple_weight_decay`` is truthy, every value held in
    ``self._optimizer_wd`` is written back into the ``"weight_decay"``
    entry of the optimizer param group with the same index
    (presumably these are the values saved earlier -- confirm against
    the code that fills ``_optimizer_wd``).

    Afterwards the learning rate of the first param group is stored in
    ``state.epoch_metrics`` under ``"lr"`` (or ``"lr/<optimizer_key>"``
    when ``self.optimizer_key`` is set).  The momentum reported by
    ``utils.get_optimizer_momentum`` is stored the same way under
    ``"momentum"`` / ``"momentum/<optimizer_key>"``, but only when it is
    not ``None``.

    Args:
        state (State): current state; its ``epoch_metrics`` mapping is
            updated in place
    """
    if self.decouple_weight_decay:
        param_groups = self._optimizer.param_groups
        for group_idx, stored_wd in enumerate(self._optimizer_wd):
            param_groups[group_idx]["weight_decay"] = stored_wd

    # Only the first param group's lr is reported.
    current_lr = self._optimizer.param_groups[0]["lr"]
    if self.optimizer_key is None:
        lr_metric = "lr"
    else:
        lr_metric = f"lr/{self.optimizer_key}"
    state.epoch_metrics[lr_metric] = current_lr

    current_momentum = utils.get_optimizer_momentum(self._optimizer)
    if current_momentum is None:
        # No momentum value available: nothing more to log.
        return

    if self.optimizer_key is None:
        momentum_metric = "momentum"
    else:
        momentum_metric = f"momentum/{self.optimizer_key}"
    state.epoch_metrics[momentum_metric] = current_momentum
def _update_optimizer(self, optimizer) -> Tuple[float, float]:
    """Push freshly calculated lr and momentum into ``optimizer``.

    ``self.calc_lr()`` and ``self.calc_momentum()`` are each queried
    once.  A non-``None`` result is applied through ``self._update_lr``
    / ``self._update_momentum``; a ``None`` result leaves the
    corresponding optimizer setting untouched.

    Args:
        optimizer: optimizer to update; forwarded unchanged to the
            ``_update_*`` helpers and to
            ``utils.get_optimizer_momentum``

    Returns:
        ``(lr, momentum)``.  ``lr`` is whatever ``calc_lr`` returned and
        may therefore be ``None``.  ``momentum`` is the value from
        ``calc_momentum`` or, when that is ``None``, the value reported
        by ``utils.get_optimizer_momentum(optimizer)``.
    """
    lr = self.calc_lr()
    if lr is not None:
        self._update_lr(optimizer, lr)

    momentum = self.calc_momentum()
    if momentum is None:
        # Nothing to apply: report what the optimizer currently holds.
        return lr, utils.get_optimizer_momentum(optimizer)

    self._update_momentum(optimizer, momentum)
    return lr, momentum
def _scheduler_step(
    scheduler,
    reduced_metric=None,
):
    """Advance ``scheduler`` by one step and report the resulting lr/momentum.

    Args:
        scheduler: learning-rate scheduler to step; it must expose
            ``step`` and an ``optimizer`` attribute
        reduced_metric: value handed to ``scheduler.step`` when the
            scheduler is a ``ReduceLROnPlateau``; ignored for every
            other scheduler type

    Returns:
        ``(lr, momentum)`` where ``lr`` is the learning rate of the
        optimizer's first param group after the step and ``momentum`` is
        whatever ``utils.get_optimizer_momentum`` reports for that
        optimizer.
    """
    if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        # Plateau scheduling needs the monitored metric to decide.
        scheduler.step(reduced_metric)
    else:
        scheduler.step()

    # Read the lr the optimizer actually uses instead of calling
    # ``scheduler.get_lr()``: in recent PyTorch releases ``get_lr`` is an
    # internal helper of ``step()`` and, called from outside, emits a
    # warning and may return a value that differs from the one applied.
    # ``step()`` writes the new lr into the param groups, so this is
    # correct for every scheduler type.
    lr = scheduler.optimizer.param_groups[0]["lr"]
    momentum = utils.get_optimizer_momentum(scheduler.optimizer)
    return lr, momentum
def on_epoch_end(self, runner: IRunner) -> None:
    """Log the optimizer's current lr and momentum as epoch metrics.

    The learning rate of the first param group of ``self._optimizer`` is
    written to ``runner.epoch_metrics`` under ``"lr"``, or under
    ``"lr/<optimizer_key>"`` when ``self.optimizer_key`` is not ``None``.
    The momentum returned by ``utils.get_optimizer_momentum`` is written
    under ``"momentum"`` / ``"momentum/<optimizer_key>"`` in the same
    way, and skipped entirely when it is ``None``.

    Args:
        runner: current runner; its ``epoch_metrics`` mapping is updated
            in place
    """
    optimizer = self._optimizer

    # Only the first param group's lr is reported.
    first_group_lr = optimizer.param_groups[0]["lr"]
    if self.optimizer_key is not None:
        lr_key = f"lr/{self.optimizer_key}"
    else:
        lr_key = "lr"
    runner.epoch_metrics[lr_key] = first_group_lr

    momentum_value = utils.get_optimizer_momentum(optimizer)
    if momentum_value is not None:
        if self.optimizer_key is not None:
            momentum_key = f"momentum/{self.optimizer_key}"
        else:
            momentum_key = "momentum"
        runner.epoch_metrics[momentum_key] = momentum_value