def on_stage_start(self, state: RunnerState):
    # grab the optimizer this callback is bound to
    optimizer = state.get_key(key="optimizer", inner_key=self.optimizer_key)
    assert optimizer is not None

    # expose the initial lr/momentum through the runner state
    lr = optimizer.defaults["lr"]
    momentum = get_optimizer_momentum(optimizer)
    state.set_key(lr, "lr", inner_key=self.optimizer_key)
    state.set_key(momentum, "momentum", inner_key=self.optimizer_key)
def on_stage_start(self, state: RunnerState):
    # remember whether the model is wrapped for fp16 training
    self.fp16 = isinstance(state.model, Fp16Wrap)

    optimizer = state.get_key(key="optimizer", inner_key=self.optimizer_key)
    assert optimizer is not None

    # expose the initial lr/momentum through the runner state
    lr = optimizer.defaults["lr"]
    momentum = get_optimizer_momentum(optimizer)
    state.set_key(lr, "lr", inner_key=self.optimizer_key)
    state.set_key(momentum, "momentum", inner_key=self.optimizer_key)
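# A minimal sketch of the ``get_optimizer_momentum`` helper used above,
# assuming it reads momentum from the first param group ("momentum" for
# SGD-style optimizers, ``betas[0]`` for Adam-style ones); the actual
# helper may differ.
from torch.optim import Optimizer

def get_optimizer_momentum(optimizer: Optimizer):
    param_group = optimizer.param_groups[0]
    if "betas" in param_group:
        return param_group["betas"][0]
    return param_group.get("momentum", None)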
def _update_optimizer(self, optimizer):
    # a ``None`` value from the calculators means "leave unchanged"
    new_lr = self.calc_lr()
    if new_lr is not None:
        self._update_lr(optimizer, new_lr)

    new_momentum = self.calc_momentum()
    if new_momentum is not None:
        self._update_momentum(optimizer, new_momentum)
    else:
        # report the optimizer's current momentum if nothing changed
        new_momentum = get_optimizer_momentum(optimizer)

    return new_lr, new_momentum
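# Hedged sketches of the ``_update_lr``/``_update_momentum`` helpers that
# ``_update_optimizer`` relies on, assuming they broadcast the new value to
# every param group; shown for illustration, the real methods may differ.
@staticmethod
def _update_lr(optimizer, new_lr: float) -> None:
    for param_group in optimizer.param_groups:
        param_group["lr"] = new_lr

@staticmethod
def _update_momentum(optimizer, new_momentum: float) -> None:
    # SGD-style optimizers expose "momentum"; Adam-style ones keep it
    # as the first element of "betas"
    for param_group in optimizer.param_groups:
        if "betas" in param_group:
            betas = param_group["betas"]
            param_group["betas"] = (new_momentum, betas[1])
        elif "momentum" in param_group:
            param_group["momentum"] = new_momentum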
def _scheduler_step(
    scheduler,
    valid_metric=None,
):
    if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        # ReduceLROnPlateau steps on a validation metric and has no
        # ``get_lr()``, so read the lr back from the first param group
        scheduler.step(valid_metric)
        lr = safitty.get(scheduler.optimizer.param_groups, 0, "lr")
    else:
        scheduler.step()
        lr = scheduler.get_lr()[0]

    momentum = get_optimizer_momentum(scheduler.optimizer)
    return lr, momentum
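# Hedged usage sketch for ``_scheduler_step``, assuming the surrounding
# module already provides ``safitty`` and ``get_optimizer_momentum``;
# the model/optimizer below are placeholders.
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# plateau scheduler: stepped with the validation metric, lr read back
# from the optimizer's param groups
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
lr, momentum = _scheduler_step(plateau, valid_metric=0.42)

# any other scheduler: stepped unconditionally, lr read via ``get_lr()``
step_lr = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
lr, momentum = _scheduler_step(step_lr)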
def __init__(
    self,
    optimizer: Optimizer,
    num_steps: int,
    lr_range=(1.0, 0.005),
    init_lr: float = None,
    warmup_steps: int = 0,
    warmup_fraction: float = None,
    decay_steps: int = 0,
    decay_fraction: float = None,
    momentum_range=(0.8, 0.99, 0.999),
    init_momentum: float = None,
):
    """
    Args:
        optimizer: PyTorch optimizer
        num_steps (int): total number of steps
        lr_range: tuple with two or three elements
            (max_lr, min_lr, [final_lr])
        init_lr (float, optional): initial lr
        warmup_steps (int): count of steps for the warm-up stage
        warmup_fraction (float, optional): fraction in [0; 1) to calculate
            the number of warmup steps.
            Cannot be set together with ``warmup_steps``
        decay_steps (int): count of steps for the lr decay stage
        decay_fraction (float, optional): fraction in [0; 1) to calculate
            the number of decay steps.
            Cannot be set together with ``decay_steps``
        momentum_range: tuple with two or three elements
            (min_momentum, max_momentum, [final_momentum])
        init_momentum (float, optional): initial momentum
    """
    if len(lr_range) == 2:
        max_lr, min_lr = lr_range
        final_lr = min_lr
    elif len(lr_range) == 3:
        max_lr, min_lr, final_lr = lr_range
    else:
        raise ValueError("lr_range must have two or three elements")

    if len(momentum_range) == 2:
        min_momentum, max_momentum = momentum_range
        final_momentum = max_momentum
    elif len(momentum_range) == 3:
        min_momentum, max_momentum, final_momentum = momentum_range
    else:
        raise ValueError("momentum_range must have two or three elements")

    # fall back to the optimizer's own defaults when no explicit
    # initial values are given
    if init_lr is None:
        init_lr = optimizer.defaults["lr"]
    if init_momentum is None:
        init_momentum = get_optimizer_momentum(optimizer)

    warmup_steps = self._calculate_warmup(
        num_steps, warmup_steps, warmup_fraction
    )
    decay_steps = self._calculate_decay(
        num_steps, decay_steps, decay_fraction
    )
    # whatever is left between warmup and decay is the annealing stage
    lr_annealing_steps = num_steps - (warmup_steps + decay_steps)

    self.warmup_steps = warmup_steps
    self.lr_annealing_steps = lr_annealing_steps
    self.decay_steps = decay_steps
    self.num_steps = warmup_steps + lr_annealing_steps + decay_steps
    self.lr_range = init_lr, max_lr, min_lr, final_lr
    self.momentum_range = \
        init_momentum, min_momentum, max_momentum, final_momentum

    self._calculate_lr_momentum(
        warmup_steps, lr_annealing_steps, decay_steps
    )

    self.total_groups = len(optimizer.param_groups)
    super().__init__(optimizer)
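# Hedged usage sketch, assuming this ``__init__`` belongs to a one-cycle
# scheduler class (named ``OneCycleLRWithWarmup`` here for illustration):
# warm up for 10% of the steps, anneal, then decay over the final 20%.
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

scheduler = OneCycleLRWithWarmup(
    optimizer,
    num_steps=1000,
    lr_range=(1.0, 0.005, 0.0005),  # (max_lr, min_lr, final_lr)
    warmup_fraction=0.1,            # cannot be set together with warmup_steps
    decay_fraction=0.2,             # cannot be set together with decay_steps
    momentum_range=(0.85, 0.95),    # (min_momentum, max_momentum)
)

for step in range(1000):
    ...  # forward / backward / optimizer.step()
    scheduler.step()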