def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg):
    """Check that a tuner algorithm is skipped, with a warning, when ``fast_dev_run`` is on.

    ``tuner_alg`` names the algorithm under test. It decides which ``auto_*``
    Trainer flag is enabled and is interpolated into the expected warning text.
    """
    # Give the model a fixed learning rate to avoid the no-lr-found exception.
    net = BoringModel()
    net.lr = 0.1

    use_batch_size_scaler = tuner_alg == "batch size scaler"
    use_lr_finder = tuner_alg == "learning rate finder"
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        auto_scale_batch_size=use_batch_size_scaler,
        auto_lr_find=use_lr_finder,
        fast_dev_run=True,
    )

    # `match` is applied as a regex search against the warning message.
    warning_pattern = f"Skipping {tuner_alg} since fast_dev_run is enabled."
    with pytest.warns(UserWarning, match=warning_pattern):
        trainer.tune(net)
def test_init_optimizers_resets_lightning_optimizers(tmpdir):
    """Verify the Trainer resets ``lightning_optimizers`` every time new optimizers are initialized.

    After ``tune``, after a first ``fit`` and after a second ``fit``, the
    strategy's first wrapped optimizer must wrap the Trainer's current first
    optimizer (checked by identity).
    """

    def assert_wrapper_matches_current_optimizer():
        # `trainer` is looked up at call time; it is bound further below.
        wrapped = trainer.strategy._lightning_optimizers[0]
        assert wrapped.optimizer is trainer.optimizers[0]

    model = BoringModel()
    model.lr = 0.2
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_lr_find=True)

    trainer.tune(model)
    assert_wrapper_matches_current_optimizer()

    trainer.fit(model)
    assert_wrapper_matches_current_optimizer()

    # Bump max_epochs before fitting again to simulate multiple fit calls.
    trainer.fit_loop.max_epochs = 2
    trainer.fit(model)
    assert_wrapper_matches_current_optimizer()