예제 #1
0
 def on_stage_start(self, state: RunnerState):
     """Record the optimizer's starting lr and momentum in ``state``.

     Fetches the optimizer stored in ``state`` under key ``"optimizer"`` and
     ``inner_key=self.optimizer_key``. Its default ``lr`` and its momentum
     (as reported by ``get_optimizer_momentum``) are then written back into
     ``state`` under ``"lr"`` and ``"momentum"`` for the same inner key.

     Raises:
         ValueError: if ``state`` holds no optimizer for ``self.optimizer_key``.
     """
     optimizer = state.get_key(key="optimizer",
                               inner_key=self.optimizer_key)
     # Explicit check rather than ``assert`` so the guard is not stripped by
     # ``python -O``; otherwise the next line would fail with an opaque
     # AttributeError on ``None.defaults``.
     if optimizer is None:
         raise ValueError(
             "No optimizer found in state for optimizer_key={!r}".format(
                 self.optimizer_key
             )
         )
     # ``defaults["lr"]`` is the optimizer-wide default lr; individual
     # param groups may carry different values.
     lr = optimizer.defaults["lr"]
     momentum = get_optimizer_momentum(optimizer)
     state.set_key(lr, "lr", inner_key=self.optimizer_key)
     state.set_key(momentum, "momentum", inner_key=self.optimizer_key)
예제 #2
0
 def on_stage_start(self, state: RunnerState):
     """Detect fp16 mode and record the optimizer's starting lr/momentum.

     Sets ``self.fp16`` according to whether ``state.model`` is an
     ``Fp16Wrap`` instance. It then fetches the optimizer stored in ``state``
     under key ``"optimizer"`` and ``inner_key=self.optimizer_key``. Its
     default ``lr`` and its momentum (as reported by
     ``get_optimizer_momentum``) are written back into ``state`` under
     ``"lr"`` and ``"momentum"`` for the same inner key.

     Raises:
         ValueError: if ``state`` holds no optimizer for ``self.optimizer_key``.
     """
     self.fp16 = isinstance(state.model, Fp16Wrap)
     optimizer = state.get_key(key="optimizer",
                               inner_key=self.optimizer_key)
     # Explicit check rather than ``assert`` so the guard is not stripped by
     # ``python -O``; otherwise the next line would fail with an opaque
     # AttributeError on ``None.defaults``.
     if optimizer is None:
         raise ValueError(
             "No optimizer found in state for optimizer_key={!r}".format(
                 self.optimizer_key
             )
         )
     # ``defaults["lr"]`` is the optimizer-wide default lr; individual
     # param groups may carry different values.
     lr = optimizer.defaults["lr"]
     momentum = get_optimizer_momentum(optimizer)
     state.set_key(lr, "lr", inner_key=self.optimizer_key)
     state.set_key(momentum, "momentum", inner_key=self.optimizer_key)
예제 #3
0
    def _update_optimizer(self, optimizer):
        new_lr = self.calc_lr()
        if new_lr is not None:
            self._update_lr(optimizer, new_lr)

        new_momentum = self.calc_momentum()
        if new_momentum is not None:
            self._update_momentum(optimizer, new_momentum)
        else:
            new_momentum = get_optimizer_momentum(optimizer)

        return new_lr, new_momentum
예제 #4
0
    def _scheduler_step(
        scheduler,
        valid_metric=None,
    ):
        """Advance ``scheduler`` by one step and report the resulting lr/momentum.

        Args:
            scheduler: a ``torch.optim.lr_scheduler`` scheduler.
                ``ReduceLROnPlateau`` is stepped with ``valid_metric``; all
                other schedulers are stepped without arguments.
            valid_metric: metric value passed to ``ReduceLROnPlateau.step``.

        Returns:
            tuple: ``(lr, momentum)``. ``lr`` is the first param group's lr
            after the step; ``momentum`` comes from
            ``get_optimizer_momentum(scheduler.optimizer)``.
        """
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(valid_metric)
            lr = safitty.get(scheduler.optimizer.param_groups, 0, "lr")
        else:
            scheduler.step()
            # For chainable schedulers (torch>=1.4), calling ``get_lr()``
            # after ``step()`` recomputes from the already-updated group lr.
            # It then reports a value one step ahead and emits a warning.
            # ``get_last_lr()`` returns what ``step()`` actually applied;
            # fall back to ``get_lr()`` on torch versions that lack it.
            get_last_lr = getattr(scheduler, "get_last_lr", None)
            if get_last_lr is not None:
                lr = get_last_lr()[0]
            else:
                lr = scheduler.get_lr()[0]

        momentum = get_optimizer_momentum(scheduler.optimizer)

        return lr, momentum
예제 #5
0
    def __init__(
            self,
            optimizer: Optimizer,
            num_steps: int,
            lr_range=(1.0, 0.005),
            init_lr: float = None,
            warmup_steps: int = 0,
            warmup_fraction: float = None,
            decay_steps: int = 0,
            decay_fraction: float = None,
            momentum_range=(0.8, 0.99, 0.999),
            init_momentum: float = None,
    ):
        """Set up a warm-up / annealing / decay schedule for lr and momentum.

        Args:
            optimizer: PyTorch optimizer
            num_steps (int): total number of steps
            lr_range: tuple with two or three elements
                (max_lr, min_lr, [final_lr]); ``final_lr`` defaults to
                ``min_lr`` when omitted
            init_lr (float, optional): initial lr; defaults to
                ``optimizer.defaults["lr"]``
            warmup_steps (int): count of steps for warm-up stage
            warmup_fraction (float, optional): fraction in [0; 1) to calculate
                number of warmup steps.
                Cannot be set together with ``warmup_steps``
            decay_steps (int): count of steps for lr decay stage
            decay_fraction (float, optional): fraction in [0; 1) to calculate
                number of decay steps.
                Cannot be set together with ``decay_steps``
            momentum_range: tuple with two or three elements
                (min_momentum, max_momentum, [final_momentum]);
                ``final_momentum`` defaults to ``max_momentum`` when omitted
            init_momentum (float, optional): initial momentum; defaults to
                ``get_optimizer_momentum(optimizer)``

        Raises:
            ValueError: if ``lr_range`` or ``momentum_range`` does not have
                exactly two or three elements.
        """
        # Reject bad lengths up front: otherwise the unpacked names stay
        # unbound and the error surfaces much later as UnboundLocalError.
        if len(lr_range) == 2:
            max_lr, min_lr = lr_range
            final_lr = min_lr
        elif len(lr_range) == 3:
            max_lr, min_lr, final_lr = lr_range
        else:
            raise ValueError(
                "lr_range must have two or three elements "
                "(max_lr, min_lr, [final_lr]), got {}".format(len(lr_range))
            )

        if len(momentum_range) == 2:
            min_momentum, max_momentum = momentum_range
            final_momentum = max_momentum
        elif len(momentum_range) == 3:
            min_momentum, max_momentum, final_momentum = momentum_range
        else:
            raise ValueError(
                "momentum_range must have two or three elements "
                "(min_momentum, max_momentum, [final_momentum]), "
                "got {}".format(len(momentum_range))
            )

        if init_lr is None:
            init_lr = optimizer.defaults["lr"]
        if init_momentum is None:
            init_momentum = get_optimizer_momentum(optimizer)

        warmup_steps = self._calculate_warmup(
            num_steps, warmup_steps, warmup_fraction
        )

        decay_steps = self._calculate_decay(
            num_steps, decay_steps, decay_fraction
        )

        # Whatever is left after warm-up and decay is the annealing stage.
        lr_annealing_steps = num_steps - (warmup_steps + decay_steps)

        self.warmup_steps = warmup_steps
        self.lr_annealing_steps = lr_annealing_steps
        self.decay_steps = decay_steps
        self.num_steps = warmup_steps + lr_annealing_steps + decay_steps

        self.lr_range = init_lr, max_lr, min_lr, final_lr
        self.momentum_range = \
            init_momentum, min_momentum, max_momentum, final_momentum

        self._calculate_lr_momentum(
            warmup_steps,
            lr_annealing_steps,
            decay_steps
        )

        self.total_groups = len(optimizer.param_groups)
        super().__init__(optimizer)