Example 1
def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig]:
    """Convert each scheduler into `LRSchedulerConfig` structure with relevant information, when using manual
    optimization."""
    lr_scheduler_configs = []
    for scheduler in schedulers:
        if isinstance(scheduler, dict):
            invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"}
            keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]

            if keys_to_warn:
                rank_zero_warn(
                    f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored."
                    " You need to call `lr_scheduler.step()` manually in manual optimization.",
                    category=RuntimeWarning,
                )

            config = LRSchedulerConfig(**{key: scheduler[key] for key in scheduler if key not in invalid_keys})
        else:
            config = LRSchedulerConfig(scheduler)
        lr_scheduler_configs.append(config)
    return lr_scheduler_configs
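
The warning above exists because, under manual optimization, keys such as "interval", "frequency" and "monitor" have no effect: Lightning will not step the scheduler for you. A minimal sketch of the corresponding user-side pattern (the `ManualModel` class and its toy loss are illustrative, not part of the source):

import torch
from pytorch_lightning import LightningModule


class ManualModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # switch to manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(batch).sum()
        self.manual_backward(loss)
        opt.step()
        # schedulers are not stepped automatically in manual optimization
        sch = self.lr_schedulers()
        sch.step()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]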
Example 2
    def _initialize_deepspeed_train(self, model):
        optimizer, scheduler = None, None
        if "optimizer" in self.config:
            rank_zero_info(
                "You have specified an optimizer and/or scheduler within the DeepSpeed config."
                " It is recommended to define it in `LightningModule.configure_optimizers`."
            )
            lr_scheduler = None
        else:
            optimizer, lr_scheduler, _ = self._init_optimizers()
            if lr_scheduler is not None:
                scheduler = lr_scheduler.scheduler

        model, deepspeed_optimizer = self._setup_model_and_optimizer(model, optimizer, scheduler)
        self._set_deepspeed_activation_checkpointing()

        # although we set these here, deepspeed manages the specific optimizer logic
        self.optimizers = [deepspeed_optimizer]

        deepspeed_scheduler = model.lr_scheduler
        if deepspeed_scheduler is not None:
            # disable deepspeed lr scheduling as lightning manages scheduling
            model.lr_scheduler = None
            if lr_scheduler is None:
                lr_scheduler = LRSchedulerConfig(deepspeed_scheduler, interval="step", opt_idx=0)
            else:
                lr_scheduler.scheduler = deepspeed_scheduler
            self.lr_scheduler_configs = [lr_scheduler]
        self.model = model
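
The info message above points users toward defining the optimizer in `configure_optimizers` rather than inside the DeepSpeed config, so that the `_init_optimizers()` branch is taken. A rough sketch of such a setup (the `ConfiguredOptModel` class, the synthetic data, and the device count are assumptions, and DeepSpeed must be installed):

import torch
from pytorch_lightning import LightningModule, Trainer


class ConfiguredOptModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        # optimizer and scheduler live here, not in the DeepSpeed config
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]


data = torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=8)
trainer = Trainer(accelerator="gpu", devices=2, strategy="deepspeed_stage_2", precision=16)
trainer.fit(ConfiguredOptModel(), train_dataloaders=data)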
Example 3
def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]:
    """Convert each scheduler into `LRSchedulerConfig` with relevant information, when using automatic
    optimization."""
    lr_scheduler_configs = []
    for scheduler in schedulers:
        if isinstance(scheduler, dict):
            # check provided keys
            supported_keys = {field.name for field in fields(LRSchedulerConfig)}
            extra_keys = scheduler.keys() - supported_keys
            if extra_keys:
                rank_zero_warn(
                    f"Found unsupported keys in the lr scheduler dict: {extra_keys}."
                    " HINT: remove them from the output of `configure_optimizers`.",
                    category=RuntimeWarning,
                )
                scheduler = {k: v for k, v in scheduler.items() if k in supported_keys}
            if "scheduler" not in scheduler:
                raise MisconfigurationException(
                    'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
                )
            if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"):
                raise MisconfigurationException(
                    'The "interval" key in lr scheduler dict must be "step" or "epoch"'
                    f' but is "{scheduler["interval"]}"'
                )
            scheduler["reduce_on_plateau"] = isinstance(scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau)
            if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None:
                raise MisconfigurationException(
                    "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
                    ' For example: {"optimizer": optimizer, "lr_scheduler":'
                    ' {"scheduler": scheduler, "monitor": "your_loss"}}'
                )
            is_one_cycle = isinstance(scheduler["scheduler"], optim.lr_scheduler.OneCycleLR)
            if is_one_cycle and scheduler.get("interval", "epoch") == "epoch":
                rank_zero_warn(
                    "A `OneCycleLR` scheduler is using 'interval': 'epoch'."
                    " Are you sure you didn't mean 'interval': 'step'?",
                    category=RuntimeWarning,
                )
            config = LRSchedulerConfig(**scheduler)
        elif isinstance(scheduler, ReduceLROnPlateau):
            if monitor is None:
                raise MisconfigurationException(
                    "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
                    " scheduler is used. For example:"
                    ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
                )
            config = LRSchedulerConfig(scheduler, reduce_on_plateau=True, monitor=monitor)
        else:
            config = LRSchedulerConfig(scheduler)
        lr_scheduler_configs.append(config)
    return lr_scheduler_configs
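
The dict validated here is the `lr_scheduler` entry returned from `configure_optimizers` under automatic optimization. A sketch of a conforming return value, written as a `LightningModule` method and assuming a `val_loss` metric is logged elsewhere in the module (names and hyperparameters are illustrative):

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,   # required key, checked above
                "interval": "epoch",      # must be "step" or "epoch"
                "frequency": 1,
                "monitor": "val_loss",    # required because this is ReduceLROnPlateau
                "strict": True,
            },
        }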
Example 4
def test_reducelronplateau_scheduling(tmpdir):
    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            self.log("foo", batch_idx)
            return super().training_step(batch, batch_idx)

        def configure_optimizers(self):
            optimizer = optim.Adam(self.parameters())
            return {
                "optimizer": optimizer,
                "lr_scheduler": optim.lr_scheduler.ReduceLROnPlateau(optimizer),
                "monitor": "foo",
            }

    model = TestModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    lr_scheduler = trainer.lr_scheduler_configs[0]
    assert lr_scheduler == LRSchedulerConfig(
        scheduler=lr_scheduler.scheduler,
        monitor="foo",
        interval="epoch",
        frequency=1,
        reduce_on_plateau=True,
        strict=True,
        opt_idx=0,
        name=None,
    )
Example 5
        def func(trainer: "pl.Trainer") -> None:
            # Decide the structure of the output from _init_optimizers_and_lr_schedulers
            optimizers, _, _ = _init_optimizers_and_lr_schedulers(trainer.lightning_module)

            if len(optimizers) != 1:
                raise MisconfigurationException(
                    f"`model.configure_optimizers()` returned {len(optimizers)}, but"
                    " learning rate finder only works with single optimizer"
                )

            optimizer = optimizers[0]

            new_lrs = [self.lr_min] * len(optimizer.param_groups)
            for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
                param_group["lr"] = new_lr
                param_group["initial_lr"] = new_lr

            args = (optimizer, self.lr_max, self.num_training)
            scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
            scheduler = cast(pl.utilities.types._LRScheduler, scheduler)

            trainer.strategy.optimizers = [optimizer]
            trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
            trainer.strategy.optimizer_frequencies = []
            _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs)
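
This closure is executed by the learning-rate finder; a typical invocation under the PL 1.x Tuner API looks roughly like the following (the model, data, and search bounds are illustrative):

import torch
from pytorch_lightning import Trainer

model = ConfiguredOptModel()  # the toy single-optimizer module sketched under Example 2
data = torch.utils.data.DataLoader(torch.randn(256, 32), batch_size=8)

trainer = Trainer(max_epochs=1)
lr_finder = trainer.tuner.lr_find(model, train_dataloaders=data, min_lr=1e-6, max_lr=1.0, num_training=100)
print(lr_finder.suggestion())  # suggested learning rate from the range test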
Example 6
    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        if trainer.current_epoch == self.swa_start:
            # move average model to the requested device
            self._average_model = self._average_model.to(self._device or pl_module.device)

            optimizer = trainer.optimizers[0]
            if self._swa_lrs is None:
                self._swa_lrs = [
                    param_group["lr"] for param_group in optimizer.param_groups
                ]
            if isinstance(self._swa_lrs, float):
                self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups)

            for lr, group in zip(self._swa_lrs, optimizer.param_groups):
                group["initial_lr"] = lr

            self._swa_scheduler = SWALR(
                optimizer,
                swa_lr=self._swa_lrs,
                anneal_epochs=self._annealing_epochs,
                anneal_strategy=self._annealing_strategy,
                last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
            )
            default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler)
            assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1

            if trainer.lr_scheduler_configs:
                scheduler_cfg = trainer.lr_scheduler_configs[0]
                if scheduler_cfg.interval != "epoch" or scheduler_cfg.frequency != 1:
                    rank_zero_warn(
                        f"SWA is currently only supported every epoch. Found {scheduler_cfg}"
                    )
                rank_zero_info(
                    f"Swapping scheduler `{scheduler_cfg.scheduler.__class__.__name__}`"
                    f" for `{self._swa_scheduler.__class__.__name__}`")
                trainer.lr_scheduler_configs[0] = default_scheduler_cfg
            else:
                trainer.lr_scheduler_configs.append(default_scheduler_cfg)

            self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)

        if self.swa_start <= trainer.current_epoch <= self.swa_end:
            self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn)

        # Note: No > here in case the callback is saved with the model and training continues
        if trainer.current_epoch == self.swa_end + 1:

            # Transfer weights from average model to pl_module
            self.transfer_weights(self._average_model, pl_module)

            # Reset BatchNorm for update
            self.reset_batch_norm_and_save_state(pl_module)

            # There is no need to perform either backward or optimizer.step as we are
            # performing only one pass over the train data-loader to compute activation statistics
            # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
            trainer.num_training_batches += 1
            trainer.fit_loop._skip_backward = True
            self._accumulate_grad_batches = trainer.accumulate_grad_batches

            trainer.accumulate_grad_batches = trainer.num_training_batches
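
This hook belongs to the `StochasticWeightAveraging` callback; enabling it is a matter of passing the callback to the `Trainer`. A short sketch (the hyperparameter values are illustrative, and `model`/`data` are reused from the earlier sketches):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging

swa = StochasticWeightAveraging(swa_lrs=0.05, swa_epoch_start=0.75, annealing_epochs=5, annealing_strategy="cos")
trainer = Trainer(max_epochs=20, callbacks=[swa])
trainer.fit(model, train_dataloaders=data)  # model/data as defined under Example 5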
Example 7
def test_optimizer_return_options(tmpdir):
    trainer = Trainer(default_root_dir=tmpdir)
    model = BoringModel()
    trainer.strategy.connect(model)
    trainer.lightning_module.trainer = trainer

    # optimizers and schedulers shared by the cases below
    opt_a = optim.Adam(model.parameters(), lr=0.002)
    opt_b = optim.SGD(model.parameters(), lr=0.002)
    scheduler_a = optim.lr_scheduler.StepLR(opt_a, 10)
    scheduler_b = optim.lr_scheduler.StepLR(opt_b, 10)

    # single optimizer
    model.configure_optimizers = lambda: opt_a
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == 1 and len(lr_sched) == len(freq) == 0

    # opt tuple
    model.configure_optimizers = lambda: (opt_a, opt_b)
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert opt == [opt_a, opt_b]
    assert len(lr_sched) == len(freq) == 0

    # opt list
    model.configure_optimizers = lambda: [opt_a, opt_b]
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert opt == [opt_a, opt_b]
    assert len(lr_sched) == len(freq) == 0

    ref_lr_sched = LRSchedulerConfig(
        scheduler=scheduler_a,
        interval="epoch",
        frequency=1,
        reduce_on_plateau=False,
        monitor=None,
        strict=True,
        name=None,
        opt_idx=0,
    )

    # opt tuple of 2 lists
    model.configure_optimizers = lambda: ([opt_a], [scheduler_a])
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == len(lr_sched) == 1
    assert len(freq) == 0
    assert opt[0] == opt_a
    assert lr_sched[0] == ref_lr_sched

    # opt tuple of 1 list
    model.configure_optimizers = lambda: ([opt_a], scheduler_a)
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == len(lr_sched) == 1
    assert len(freq) == 0
    assert opt[0] == opt_a
    assert lr_sched[0] == ref_lr_sched

    # opt single dictionary
    model.configure_optimizers = lambda: {"optimizer": opt_a, "lr_scheduler": scheduler_a}
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == len(lr_sched) == 1
    assert len(freq) == 0
    assert opt[0] == opt_a
    assert lr_sched[0] == ref_lr_sched

    # opt multiple dictionaries with frequencies
    model.configure_optimizers = lambda: (
        {"optimizer": opt_a, "lr_scheduler": scheduler_a, "frequency": 1},
        {"optimizer": opt_b, "lr_scheduler": scheduler_b, "frequency": 5},
    )
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == len(lr_sched) == len(freq) == 2
    assert opt[0] == opt_a
    ref_lr_sched.opt_idx = 0
    assert lr_sched[0] == ref_lr_sched
    ref_lr_sched.scheduler = scheduler_b
    ref_lr_sched.opt_idx = 1
    assert lr_sched[1] == ref_lr_sched
    assert freq == [1, 5]