def on_stage_start(self, state: _State):
    """On stage start event.

    Captures the base lr and momentum of the optimizer registered under
    ``self.optimizer_key`` and stores them back into the state.
    """
    optimizer = state.get_key(key="optimizer", inner_key=self.optimizer_key)
    assert optimizer is not None

    base_lr = optimizer.defaults["lr"]
    base_momentum = utils.get_optimizer_momentum(optimizer)

    state.set_key(base_lr, "lr", inner_key=self.optimizer_key)
    state.set_key(base_momentum, "momentum", inner_key=self.optimizer_key)
def on_stage_start(self, state: RunnerState):
    """Record the optimizer's initial lr and momentum into the runner state."""
    optimizer = state.get_key(key="optimizer", inner_key=self.optimizer_key)
    assert optimizer is not None

    initial_lr = optimizer.defaults["lr"]
    initial_momentum = get_optimizer_momentum(optimizer)

    for name, value in (("lr", initial_lr), ("momentum", initial_momentum)):
        state.set_key(value, name, inner_key=self.optimizer_key)
def _update_optimizer(self, optimizer): new_lr = self.calc_lr() if new_lr is not None: self._update_lr(optimizer, new_lr) new_momentum = self.calc_momentum() if new_momentum is not None: self._update_momentum(optimizer, new_momentum) else: new_momentum = get_optimizer_momentum(optimizer) return new_lr, new_momentum
def _scheduler_step(
    scheduler,
    valid_metric=None,
):
    """Advance ``scheduler`` one step and report the resulting lr/momentum.

    ``ReduceLROnPlateau`` consumes the monitored metric to decide whether
    to reduce the lr; all other schedulers are stepped unconditionally.

    Args:
        scheduler: torch lr scheduler to step
        valid_metric: value of the monitored metric
            (used only by ``ReduceLROnPlateau``)

    Returns:
        tuple: (lr, momentum) currently set on the scheduler's optimizer
    """
    if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        scheduler.step(valid_metric)
        # ReduceLROnPlateau exposes no get_lr(); read the lr straight from
        # the optimizer. This matches the sibling implementation and drops
        # the third-party ``safitty`` lookup (param_groups always carry "lr").
        lr = scheduler.optimizer.param_groups[0]["lr"]
    else:
        scheduler.step()
        lr = scheduler.get_lr()[0]

    momentum = get_optimizer_momentum(scheduler.optimizer)
    return lr, momentum
def _scheduler_step(
    scheduler,
    reduced_metric=None,
):
    """Step the scheduler and return the optimizer's new (lr, momentum)."""
    is_plateau = isinstance(
        scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
    )
    if is_plateau:
        # ReduceLROnPlateau needs the monitored metric and has no get_lr(),
        # so the lr is read from the optimizer directly.
        scheduler.step(reduced_metric)
        lr = scheduler.optimizer.param_groups[0]["lr"]
    else:
        scheduler.step()
        lr = scheduler.get_lr()[0]
    return lr, utils.get_optimizer_momentum(scheduler.optimizer)
def on_epoch_end(self, state: State):
    """On epoch end event.

    Restores decoupled weight-decay values and logs the optimizer's current
    lr (and momentum, when available) into the epoch metrics.
    """
    if self.decouple_weight_decay:
        # put back the weight-decay values that were stripped from the
        # param groups while the epoch ran
        for idx, wd in enumerate(self._optimizer_wd):
            self._optimizer.param_groups[idx]["weight_decay"] = wd

    suffix = "" if self.optimizer_key is None else f"/{self.optimizer_key}"
    state.epoch_metrics["lr" + suffix] = (
        self._optimizer.param_groups[0]["lr"]
    )

    momentum = utils.get_optimizer_momentum(self._optimizer)
    if momentum is not None:
        state.epoch_metrics["momentum" + suffix] = momentum
def __init__(
    self,
    optimizer: Optimizer,
    num_steps: int,
    lr_range=(1.0, 0.005),
    init_lr: float = None,
    warmup_steps: int = 0,
    warmup_fraction: float = None,
    decay_steps: int = 0,
    decay_fraction: float = None,
    momentum_range=(0.8, 0.99, 0.999),
    init_momentum: float = None,
):
    """
    Args:
        optimizer: PyTorch optimizer
        num_steps (int): total number of steps
        lr_range: tuple with two or three elements
            (max_lr, min_lr, [final_lr])
        init_lr (float, optional): initial lr
        warmup_steps (int): count of steps for warm-up stage
        warmup_fraction (float, optional): fraction in [0; 1) to calculate
            number of warmup steps. Cannot be set together with
            ``warmup_steps``
        decay_steps (int): count of steps for lr decay stage
        decay_fraction (float, optional): fraction in [0; 1) to calculate
            number of decay steps. Cannot be set together with
            ``decay_steps``
        momentum_range: tuple with two or three elements
            (min_momentum, max_momentum, [final_momentum])
        init_momentum (float, optional): initial momentum

    Raises:
        ValueError: if ``lr_range`` or ``momentum_range`` does not have
            exactly two or three elements
    """
    # A mis-sized range used to fall through both branches and surface
    # later as a confusing NameError; fail fast with a clear message.
    if len(lr_range) == 2:
        max_lr, min_lr = lr_range
        final_lr = min_lr
    elif len(lr_range) == 3:
        max_lr, min_lr, final_lr = lr_range
    else:
        raise ValueError("lr_range must have 2 or 3 elements")

    if len(momentum_range) == 2:
        min_momentum, max_momentum = momentum_range
        final_momentum = max_momentum
    elif len(momentum_range) == 3:
        min_momentum, max_momentum, final_momentum = momentum_range
    else:
        raise ValueError("momentum_range must have 2 or 3 elements")

    # fall back to the optimizer's own defaults when not given explicitly
    if init_lr is None:
        init_lr = optimizer.defaults["lr"]
    if init_momentum is None:
        init_momentum = get_optimizer_momentum(optimizer)

    warmup_steps = self._calculate_warmup(
        num_steps, warmup_steps, warmup_fraction
    )
    decay_steps = self._calculate_decay(
        num_steps, decay_steps, decay_fraction
    )
    lr_annealing_steps = num_steps - (warmup_steps + decay_steps)

    self.warmup_steps = warmup_steps
    self.lr_annealing_steps = lr_annealing_steps
    self.decay_steps = decay_steps
    self.num_steps = warmup_steps + lr_annealing_steps + decay_steps

    self.lr_range = init_lr, max_lr, min_lr, final_lr
    self.momentum_range = (
        init_momentum, min_momentum, max_momentum, final_momentum
    )

    self._calculate_lr_momentum(
        warmup_steps, lr_annealing_steps, decay_steps
    )

    self.total_groups = len(optimizer.param_groups)
    super().__init__(optimizer)