@pytest.mark.parametrize("override", (False, True))
def test_invalid_lr_scheduler_with_custom_step_method(override):
    """Test that a custom LR scheduler raises an error if it doesn't follow the PyTorch LR scheduler API."""

    class CustomScheduler:
        def __init__(self, optimizer):
            self.optimizer = optimizer

        def step(self, foobar):  # breaks the API, forces user to override `lr_scheduler_step`
            ...

        def state_dict(self):
            ...

        def load_state_dict(self, state_dict):
            ...

    class CustomBoringModel(BoringModel):
        def configure_optimizers(self):
            opt = torch.optim.SGD(self.parameters(), lr=1e-2)
            lr_scheduler = CustomScheduler(opt)
            return {"optimizer": opt, "lr_scheduler": lr_scheduler}

    model = CustomBoringModel()
    model.trainer = Trainer()
    if override:

        def lr_scheduler_step(*_):
            ...

        # the user overrode the hook, so no error is raised
        model.lr_scheduler_step = lr_scheduler_step
        _init_optimizers_and_lr_schedulers(model)
    else:
        with pytest.raises(MisconfigurationException, match="CustomScheduler` doesn't follow"):
            _init_optimizers_and_lr_schedulers(model)
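When the override branch is taken, the user-supplied `lr_scheduler_step` hook is what bridges the non-standard `step` signature. Below is a minimal sketch of such an override in a full module; it assumes the three-argument hook signature (`scheduler`, `optimizer_idx`, `metric`) of the Lightning 1.x line these snippets come from, and the scheduler class is purely illustrative.

import torch
import pytorch_lightning as pl


class TwoArgScheduler:
    """Scheduler whose `step` needs an extra value, so Lightning's default call would fail."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def step(self, metric):
        # Shrink the learning rate only when a metric value is available.
        if metric is not None:
            for group in self.optimizer.param_groups:
                group["lr"] *= 0.9

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        pass


class OverridingModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-2)
        return {"optimizer": optimizer, "lr_scheduler": TwoArgScheduler(optimizer)}

    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
        # The override decides how the non-standard scheduler gets stepped.
        scheduler.step(metric)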
Example #2
        def func(trainer: "pl.Trainer") -> None:
            # Decide the structure of the output from _init_optimizers_and_lr_schedulers
            optimizers, _, _ = _init_optimizers_and_lr_schedulers(trainer.lightning_module)

            if len(optimizers) != 1:
                raise MisconfigurationException(
                    f"`model.configure_optimizers()` returned {len(optimizers)}, but"
                    " learning rate finder only works with single optimizer"
                )

            optimizer = optimizers[0]

            new_lrs = [self.lr_min] * len(optimizer.param_groups)
            for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
                param_group["lr"] = new_lr
                param_group["initial_lr"] = new_lr

            args = (optimizer, self.lr_max, self.num_training)
            scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
            scheduler = cast(pl.utilities.types._LRScheduler, scheduler)

            trainer.strategy.optimizers = [optimizer]
            trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
            trainer.strategy.optimizer_frequencies = []
            _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs)
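The `func` closure above is excerpted from the learning-rate finder: the enclosing object supplies `lr_min`, `lr_max`, `num_training`, and `mode`, the user's optimizer is kept, every param group is reset to `lr_min`, and a step-interval `_LinearLR`/`_ExponentialLR` schedule is installed. From user code this machinery is normally driven through the tuner; the sketch below assumes the `trainer.tuner.lr_find` API of the same Lightning 1.x line, with a toy module and dataset.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)
        self.lr = 1e-3

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)


data = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=8)
trainer = pl.Trainer(max_epochs=1)
lr_finder = trainer.tuner.lr_find(TinyModule(), train_dataloaders=data)
print(lr_finder.suggestion())  # suggested starting learning rate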
Example #3
    def init_optimizers(self, model: Optional["pl.LightningModule"]) -> Tuple[List, List, List]:
        r"""
        .. deprecated:: v1.6
            `TrainerOptimizersMixin.init_optimizers` was deprecated in v1.6 and will be removed in v1.8.
        """
        rank_zero_deprecation(
            "`TrainerOptimizersMixin.init_optimizers` was deprecated in v1.6 and will be removed in v1.8."
        )
        pl_module = self.lightning_module or model
        return _init_optimizers_and_lr_schedulers(pl_module)
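As the shim shows, the deprecated mixin method simply forwards to the protected module-level helper, which returns three parallel lists (optimizers, scheduler configs, optimizer frequencies). Below is a hedged sketch of a drop-in replacement for callers migrating off the deprecated API; the import path assumes the v1.6-era module layout and is internal, so it may change without notice.

from typing import List, Tuple

import pytorch_lightning as pl
# Protected helper; path assumed from the v1.6-era layout of these snippets.
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers


def init_optimizers_compat(pl_module: "pl.LightningModule") -> Tuple[List, List, List]:
    # Same contract as the deprecated `TrainerOptimizersMixin.init_optimizers`:
    # returns (optimizers, lr_scheduler_configs, optimizer_frequencies).
    return _init_optimizers_and_lr_schedulers(pl_module)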
Example #4
    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        """Creates optimizers and schedulers.

        Args:
            trainer: the Trainer, these optimizers should be connected to
        """
        if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
            return
        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = _init_optimizers_and_lr_schedulers(
            self.lightning_module
        )
Example #5
    def _init_optimizers(self) -> Tuple[Optimizer, Optional[LRSchedulerConfig], Optional[int]]:
        optimizers, lr_schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module)
        if len(optimizers) > 1 or len(lr_schedulers) > 1:
            raise MisconfigurationException(
                "DeepSpeed currently only supports single optimizer, single optional scheduler."
            )
        return (
            optimizers[0],
            lr_schedulers[0] if lr_schedulers else None,
            optimizer_frequencies[0] if optimizer_frequencies else None,
        )
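To satisfy the constraint enforced above, `configure_optimizers` should hand the DeepSpeed strategy at most one optimizer and one optional scheduler. A minimal sketch of a compliant configuration (module and hyperparameters are illustrative):

import torch
import pytorch_lightning as pl


class SingleOptimizerModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def configure_optimizers(self):
        # Exactly one optimizer and one (optional) scheduler, as required above.
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}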
Example #6
def test_invalid_scheduler_missing_state_dict():
    """Test that a custom LR scheduler raises an error if it is missing a `state_dict` method."""

    class CustomScheduler:
        def __init__(self, optimizer):
            self.optimizer = optimizer

        def step(self):
            ...

    class CustomBoringModel(BoringModel):
        def configure_optimizers(self):
            opt = torch.optim.SGD(self.parameters(), lr=1e-2)
            lr_scheduler = CustomScheduler(opt)
            return {"optimizer": opt, "lr_scheduler": lr_scheduler}

    model = CustomBoringModel()
    model.trainer = Trainer()
    with pytest.raises(TypeError, match="provided lr scheduler `CustomScheduler` is invalid"):
        _init_optimizers_and_lr_schedulers(model)
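The error above goes away once the custom scheduler also exposes the checkpointing methods; a minimal sketch of a scheduler that provides the full surface checked here (the state layout is illustrative):

class CheckpointableScheduler:
    """Custom scheduler with the `step`/`state_dict`/`load_state_dict` methods checked above."""

    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.decay = 0.95

    def step(self):
        for group in self.optimizer.param_groups:
            group["lr"] *= self.decay

    def state_dict(self):
        # Persist everything except the optimizer reference.
        return {"decay": self.decay}

    def load_state_dict(self, state_dict):
        self.decay = state_dict["decay"]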
Example #7
def test_optimizer_return_options(tmpdir):
    trainer = Trainer(default_root_dir=tmpdir)
    model = BoringModel()
    trainer.strategy.connect(model)
    trainer.lightning_module.trainer = trainer

    # optimizers and schedulers used by the scenarios below
    opt_a = optim.Adam(model.parameters(), lr=0.002)
    opt_b = optim.SGD(model.parameters(), lr=0.002)
    scheduler_a = optim.lr_scheduler.StepLR(opt_a, 10)
    scheduler_b = optim.lr_scheduler.StepLR(opt_b, 10)

    # single optimizer
    model.configure_optimizers = lambda: opt_a
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == 1 and len(lr_sched) == len(freq) == 0

    # opt tuple
    model.configure_optimizers = lambda: (opt_a, opt_b)
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert opt == [opt_a, opt_b]
    assert len(lr_sched) == len(freq) == 0

    # opt list
    model.configure_optimizers = lambda: [opt_a, opt_b]
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert opt == [opt_a, opt_b]
    assert len(lr_sched) == len(freq) == 0

    ref_lr_sched = LRSchedulerConfig(
        scheduler=scheduler_a,
        interval="epoch",
        frequency=1,
        reduce_on_plateau=False,
        monitor=None,
        strict=True,
        name=None,
        opt_idx=0,
    )

    # opt tuple of 2 lists
    model.configure_optimizers = lambda: ([opt_a], [scheduler_a])
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == len(lr_sched) == 1
    assert len(freq) == 0
    assert opt[0] == opt_a
    assert lr_sched[0] == ref_lr_sched

    # opt tuple of 1 list
    model.configure_optimizers = lambda: ([opt_a], scheduler_a)
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == len(lr_sched) == 1
    assert len(freq) == 0
    assert opt[0] == opt_a
    assert lr_sched[0] == ref_lr_sched

    # opt single dictionary
    model.configure_optimizers = lambda: {
        "optimizer": opt_a,
        "lr_scheduler": scheduler_a
    }
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == len(lr_sched) == 1
    assert len(freq) == 0
    assert opt[0] == opt_a
    assert lr_sched[0] == ref_lr_sched

    # opt multiple dictionaries with frequencies
    model.configure_optimizers = lambda: (
        {
            "optimizer": opt_a,
            "lr_scheduler": scheduler_a,
            "frequency": 1
        },
        {
            "optimizer": opt_b,
            "lr_scheduler": scheduler_b,
            "frequency": 5
        },
    )
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == len(lr_sched) == len(freq) == 2
    assert opt[0] == opt_a
    ref_lr_sched.opt_idx = 0
    assert lr_sched[0] == ref_lr_sched
    ref_lr_sched.scheduler = scheduler_b
    ref_lr_sched.opt_idx = 1
    assert lr_sched[1] == ref_lr_sched
    assert freq == [1, 5]
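The scenarios above cover the plain return shapes; the `lr_scheduler` entry can also be a configuration dict that sets the `LRSchedulerConfig` fields explicitly instead of relying on the defaults asserted in `ref_lr_sched`. A short sketch of that form (module and values are illustrative):

import torch
import pytorch_lightning as pl


class ConfiguredSchedulerModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.002)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                # Keys mirror the LRSchedulerConfig fields asserted above.
                "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, 10),
                "interval": "epoch",
                "frequency": 1,
            },
        }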