def _initialize_deepspeed_train(self, model):
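        """Initialize the DeepSpeed engine for training and wire its optimizer/scheduler into Lightning."""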
        optimizer, scheduler = None, None
        if "optimizer" in self.config:
            rank_zero_info(
                "You have specified an optimizer and/or scheduler within the DeepSpeed config."
                " It is recommended to define it in `LightningModule.configure_optimizers`."
            )
            lr_scheduler = None
        else:
            optimizer, lr_scheduler, _ = self._init_optimizers()
            if lr_scheduler is not None:
                scheduler = lr_scheduler.scheduler

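        # wrap the model in a DeepSpeed engine and retrieve the optimizer the engine will step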
        model, deepspeed_optimizer = self._setup_model_and_optimizer(model, optimizer, scheduler)
        self._set_deepspeed_activation_checkpointing()

        # although we set these here, deepspeed manages the specific optimizer logic
        self.optimizers = [deepspeed_optimizer]

        deepspeed_scheduler = model.lr_scheduler
        if deepspeed_scheduler is not None:
            # disable deepspeed lr scheduling as lightning manages scheduling
            model.lr_scheduler = None
            if lr_scheduler is None:
                lr_scheduler = LRSchedulerConfig(deepspeed_scheduler, interval="step", opt_idx=0)
            else:
                lr_scheduler.scheduler = deepspeed_scheduler
            self.lr_scheduler_configs = [lr_scheduler]
        self.model = model
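The `"optimizer" in self.config` branch above fires when the user's DeepSpeed config defines its own optimizer. A minimal sketch of such a config, with placeholder values following standard DeepSpeed JSON conventions:

# Illustrative only: a placeholder DeepSpeed config with an "optimizer" section.
# With a config like this, the branch above logs a hint, skips
# `configure_optimizers`, and lets DeepSpeed build both optimizer and scheduler.
deepspeed_config = {
    "train_micro_batch_size_per_gpu": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    # surfaced back to Lightning via `model.lr_scheduler` after initialization
    "scheduler": {"type": "WarmupLR", "params": {"warmup_num_steps": 100}},
}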
Example #2
from torch import optim

from pytorch_lightning import Trainer
# import paths below follow recent releases and may differ across versions;
# e.g. older versions keep BoringModel under the repo's tests helpers
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.utilities.types import LRSchedulerConfig


def test_optimizer_return_options(tmpdir):
    trainer = Trainer(default_root_dir=tmpdir)
    model = BoringModel()
    trainer.strategy.connect(model)
    trainer.lightning_module.trainer = trainer

    # shared optimizers and schedulers for the cases below
    opt_a = optim.Adam(model.parameters(), lr=0.002)
    opt_b = optim.SGD(model.parameters(), lr=0.002)
    scheduler_a = optim.lr_scheduler.StepLR(opt_a, 10)
    scheduler_b = optim.lr_scheduler.StepLR(opt_b, 10)

    # single optimizer
    model.configure_optimizers = lambda: opt_a
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == 1 and len(lr_sched) == len(freq) == 0

    # opt tuple
    model.configure_optimizers = lambda: (opt_a, opt_b)
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert opt == [opt_a, opt_b]
    assert len(lr_sched) == len(freq) == 0

    # opt list
    model.configure_optimizers = lambda: [opt_a, opt_b]
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert opt == [opt_a, opt_b]
    assert len(lr_sched) == len(freq) == 0

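    # the config that _init_optimizers_and_lr_schedulers is expected to build for scheduler_a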
    ref_lr_sched = LRSchedulerConfig(
        scheduler=scheduler_a,
        interval="epoch",
        frequency=1,
        reduce_on_plateau=False,
        monitor=None,
        strict=True,
        name=None,
        opt_idx=0,
    )

    # opt tuple of 2 lists
    model.configure_optimizers = lambda: ([opt_a], [scheduler_a])
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == len(lr_sched) == 1
    assert len(freq) == 0
    assert opt[0] == opt_a
    assert lr_sched[0] == ref_lr_sched

    # opt tuple with a single-optimizer list and a bare scheduler
    model.configure_optimizers = lambda: ([opt_a], scheduler_a)
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == len(lr_sched) == 1
    assert len(freq) == 0
    assert opt[0] == opt_a
    assert lr_sched[0] == ref_lr_sched

    # opt single dictionary
    model.configure_optimizers = lambda: {
        "optimizer": opt_a,
        "lr_scheduler": scheduler_a
    }
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == len(lr_sched) == 1
    assert len(freq) == 0
    assert opt[0] == opt_a
    assert lr_sched[0] == ref_lr_sched

    # opt multiple dictionaries with frequencies
    model.configure_optimizers = lambda: (
        {
            "optimizer": opt_a,
            "lr_scheduler": scheduler_a,
            "frequency": 1
        },
        {
            "optimizer": opt_b,
            "lr_scheduler": scheduler_b,
            "frequency": 5
        },
    )
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == len(lr_sched) == len(freq) == 2
    assert opt[0] == opt_a
    ref_lr_sched.opt_idx = 0
    assert lr_sched[0] == ref_lr_sched
    ref_lr_sched.scheduler = scheduler_b
    ref_lr_sched.opt_idx = 1
    assert lr_sched[1] == ref_lr_sched
    assert freq == [1, 5]
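For reference, the return shapes exercised by this test map one-to-one onto what user code returns from `LightningModule.configure_optimizers`. A minimal sketch of the single-dictionary form (the module and hyperparameters here are illustrative only):

from torch import nn, optim
from pytorch_lightning import LightningModule


class SketchModule(LightningModule):
    """Illustrative module; only `configure_optimizers` matters here."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=0.002)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)
        # the single-dictionary form from the test above; "lr_scheduler" can also
        # be a dict carrying keys such as "interval", "frequency", and "monitor"
        return {"optimizer": optimizer, "lr_scheduler": scheduler}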