@pytest.mark.parametrize("override", (False, True))
def test_invalid_lr_scheduler_with_custom_step_method(override):
    """Test that a custom lr scheduler raises an error if it doesn't follow the PyTorch LR Scheduler API."""

    class CustomScheduler:
        def __init__(self, optimizer):
            self.optimizer = optimizer

        def step(self, foobar):  # breaks the API, forces user to override `lr_scheduler_step`
            ...

        def state_dict(self):
            ...

        def load_state_dict(self, state_dict):
            ...

    class CustomBoringModel(BoringModel):
        def configure_optimizers(self):
            opt = torch.optim.SGD(self.parameters(), lr=1e-2)
            lr_scheduler = CustomScheduler(opt)
            return {"optimizer": opt, "lr_scheduler": lr_scheduler}

    model = CustomBoringModel()
    model.trainer = Trainer()
    if override:

        def lr_scheduler_step(*_):
            ...

        # the user did override the hook, no error
        model.lr_scheduler_step = lr_scheduler_step
        _init_optimizers_and_lr_schedulers(model)
    else:
        with pytest.raises(MisconfigurationException, match="CustomScheduler` doesn't follow"):
            _init_optimizers_and_lr_schedulers(model)
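# Illustrative sketch (not part of the test above): a scheduler whose `step()` takes an extra
# argument is usable only if the LightningModule overrides `lr_scheduler_step` and calls the
# scheduler itself. The hook signature below is the Lightning 1.6-era one assumed by this code,
# and `WarmupScheduler` / `ModelWithCustomSchedulerStep` are hypothetical names for illustration.
class WarmupScheduler:
    def __init__(self, optimizer):
        self.optimizer = optimizer

    def step(self, current_epoch):  # non-standard signature: requires `lr_scheduler_step`
        ...

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        ...


class ModelWithCustomSchedulerStep(BoringModel):
    def configure_optimizers(self):
        opt = torch.optim.SGD(self.parameters(), lr=1e-2)
        return {"optimizer": opt, "lr_scheduler": WarmupScheduler(opt)}

    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
        # step the scheduler with whatever arguments it expects
        scheduler.step(self.current_epoch)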
def func(trainer: "pl.Trainer") -> None:
    # Decide the structure of the output from _init_optimizers_and_lr_schedulers
    optimizers, _, _ = _init_optimizers_and_lr_schedulers(trainer.lightning_module)

    if len(optimizers) != 1:
        raise MisconfigurationException(
            f"`model.configure_optimizers()` returned {len(optimizers)}, but"
            " learning rate finder only works with single optimizer"
        )

    optimizer = optimizers[0]

    new_lrs = [self.lr_min] * len(optimizer.param_groups)
    for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
        param_group["lr"] = new_lr
        param_group["initial_lr"] = new_lr

    args = (optimizer, self.lr_max, self.num_training)
    scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
    scheduler = cast(pl.utilities.types._LRScheduler, scheduler)

    trainer.strategy.optimizers = [optimizer]
    trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
    trainer.strategy.optimizer_frequencies = []
    _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs)
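# Illustrative sketch (not Lightning's `_ExponentialLR`): conceptually, the sweep scheduler built
# above only needs to interpolate each param group's lr from its initial value up to `end_lr`
# over `num_iter` steps. `ExponentialSweepLR` is a hypothetical name; the constructor mirrors the
# `(optimizer, self.lr_max, self.num_training)` arguments passed in the code above.
from torch.optim.lr_scheduler import _LRScheduler


class ExponentialSweepLR(_LRScheduler):
    def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1):
        self.end_lr = end_lr
        self.num_iter = num_iter
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # geometric interpolation between each group's initial lr and `end_lr`
        r = self.last_epoch / max(1, self.num_iter)
        return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]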
def init_optimizers(self, model: Optional["pl.LightningModule"]) -> Tuple[List, List, List]:
    r"""
    .. deprecated:: v1.6
        `TrainerOptimizersMixin.init_optimizers` was deprecated in v1.6 and will be removed in v1.8.
    """
    rank_zero_deprecation(
        "`TrainerOptimizersMixin.init_optimizers` was deprecated in v1.6 and will be removed in v1.8."
    )
    pl_module = self.lightning_module or model
    return _init_optimizers_and_lr_schedulers(pl_module)
def setup_optimizers(self, trainer: "pl.Trainer") -> None:
    """Creates optimizers and schedulers.

    Args:
        trainer: the Trainer to which these optimizers should be connected
    """
    if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
        return
    self.optimizers, self.lr_schedulers, self.optimizer_frequencies = _init_optimizers_and_lr_schedulers(
        self.lightning_module
    )
def _init_optimizers(self) -> Tuple[Optimizer, Optional[LRSchedulerConfig], Optional[int]]:
    optimizers, lr_schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module)
    if len(optimizers) > 1 or len(lr_schedulers) > 1:
        raise MisconfigurationException(
            "DeepSpeed currently only supports single optimizer, single optional scheduler."
        )
    return (
        optimizers[0],
        lr_schedulers[0] if lr_schedulers else None,
        optimizer_frequencies[0] if optimizer_frequencies else None,
    )
def test_invalid_scheduler_missing_state_dict():
    """Test that custom lr scheduler raises an error if it's missing the state dict."""

    class CustomScheduler:
        def __init__(self, optimizer):
            self.optimizer = optimizer

        def step(self):
            ...

    class CustomBoringModel(BoringModel):
        def configure_optimizers(self):
            opt = torch.optim.SGD(self.parameters(), lr=1e-2)
            lr_scheduler = CustomScheduler(opt)
            return {"optimizer": opt, "lr_scheduler": lr_scheduler}

    model = CustomBoringModel()
    model.trainer = Trainer()
    with pytest.raises(TypeError, match="provided lr scheduler `CustomScheduler` is invalid"):
        _init_optimizers_and_lr_schedulers(model)
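# Illustrative sketch (not part of the test above): the simplest way to satisfy the
# `state_dict`/`load_state_dict` requirement checked here is to subclass PyTorch's
# `_LRScheduler`, which provides both. `ConstantFactorLR` is a hypothetical name.
from torch.optim.lr_scheduler import _LRScheduler


class ConstantFactorLR(_LRScheduler):
    def __init__(self, optimizer, factor=0.5, last_epoch=-1):
        self.factor = factor
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # scale every param group's lr by a constant factor each step
        return [base_lr * self.factor ** self.last_epoch for base_lr in self.base_lrs]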
def test_optimizer_return_options(tmpdir):
    trainer = Trainer(default_root_dir=tmpdir)
    model = BoringModel()
    trainer.strategy.connect(model)
    trainer.lightning_module.trainer = trainer

    # create two optimizers and a scheduler for each, reused by the cases below
    opt_a = optim.Adam(model.parameters(), lr=0.002)
    opt_b = optim.SGD(model.parameters(), lr=0.002)
    scheduler_a = optim.lr_scheduler.StepLR(opt_a, 10)
    scheduler_b = optim.lr_scheduler.StepLR(opt_b, 10)

    # single optimizer
    model.configure_optimizers = lambda: opt_a
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == 1 and len(lr_sched) == len(freq) == 0

    # opt tuple
    model.configure_optimizers = lambda: (opt_a, opt_b)
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert opt == [opt_a, opt_b]
    assert len(lr_sched) == len(freq) == 0

    # opt list
    model.configure_optimizers = lambda: [opt_a, opt_b]
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert opt == [opt_a, opt_b]
    assert len(lr_sched) == len(freq) == 0

    ref_lr_sched = LRSchedulerConfig(
        scheduler=scheduler_a,
        interval="epoch",
        frequency=1,
        reduce_on_plateau=False,
        monitor=None,
        strict=True,
        name=None,
        opt_idx=0,
    )

    # opt tuple of 2 lists
    model.configure_optimizers = lambda: ([opt_a], [scheduler_a])
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == len(lr_sched) == 1
    assert len(freq) == 0
    assert opt[0] == opt_a
    assert lr_sched[0] == ref_lr_sched

    # opt tuple of 1 list
    model.configure_optimizers = lambda: ([opt_a], scheduler_a)
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == len(lr_sched) == 1
    assert len(freq) == 0
    assert opt[0] == opt_a
    assert lr_sched[0] == ref_lr_sched

    # opt single dictionary
    model.configure_optimizers = lambda: {"optimizer": opt_a, "lr_scheduler": scheduler_a}
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == len(lr_sched) == 1
    assert len(freq) == 0
    assert opt[0] == opt_a
    assert lr_sched[0] == ref_lr_sched

    # opt multiple dictionaries with frequencies
    model.configure_optimizers = lambda: (
        {"optimizer": opt_a, "lr_scheduler": scheduler_a, "frequency": 1},
        {"optimizer": opt_b, "lr_scheduler": scheduler_b, "frequency": 5},
    )
    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
    assert len(opt) == len(lr_sched) == len(freq) == 2
    assert opt[0] == opt_a
    ref_lr_sched.opt_idx = 0
    assert lr_sched[0] == ref_lr_sched
    ref_lr_sched.scheduler = scheduler_b
    ref_lr_sched.opt_idx = 1
    assert lr_sched[1] == ref_lr_sched
    assert freq == [1, 5]
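# Illustrative sketch (not part of the test above): beyond the return shapes exercised here, the
# "lr_scheduler" value returned by `configure_optimizers` can itself be a dictionary carrying the
# LRSchedulerConfig fields that _init_optimizers_and_lr_schedulers parses (interval, frequency,
# monitor, ...), per the documented Lightning contract. `PlateauBoringModel` and "val_loss" are
# assumed names for illustration.
class PlateauBoringModel(BoringModel):
    def configure_optimizers(self):
        opt = optim.Adam(self.parameters(), lr=0.002)
        return {
            "optimizer": opt,
            "lr_scheduler": {
                "scheduler": optim.lr_scheduler.ReduceLROnPlateau(opt),
                "interval": "epoch",    # step the scheduler once per epoch
                "frequency": 1,
                "monitor": "val_loss",  # required for reduce-on-plateau style schedulers
            },
        }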