import pytest

# Assumed import paths (PL 1.x): the Trainer flags `auto_scale_batch_size` and
# `auto_lr_find` were removed in Lightning 2.0, so these tests target 1.x.
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel


@pytest.mark.parametrize("tuner_alg", ["batch size scaler", "learning rate finder"])
def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg):
    """Test that tuner algorithms are skipped if fast dev run is enabled."""

    model = BoringModel()
    model.lr = 0.1  # avoid no-lr-found exception
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        auto_scale_batch_size=(tuner_alg == "batch size scaler"),
        auto_lr_find=(tuner_alg == "learning rate finder"),
        fast_dev_run=True,
    )
    expected_message = f"Skipping {tuner_alg} since fast_dev_run is enabled."
    with pytest.warns(UserWarning, match=expected_message):
        trainer.tune(model)
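
# For contrast, a minimal sketch (assuming the same PL 1.x tuner API) of what
# trainer.tune(model) does when fast_dev_run is NOT set: the learning rate
# finder actually runs and overwrites model.lr with its suggestion. The name
# `tune_example` is hypothetical and not part of the test suite.
def tune_example(tmpdir):
    model = BoringModel()
    model.lr = 0.1  # the finder looks for an `lr`/`learning_rate` attribute
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_lr_find=True)
    result = trainer.tune(model)  # PL 1.x returns a dict, e.g. {"lr_find": <_LRFinder>}
    print(result["lr_find"].suggestion(), model.lr)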
def test_init_optimizers_resets_lightning_optimizers(tmpdir):
    """Test that the Trainer resets the `lightning_optimizers` list everytime new optimizers get initialized."""
    def compare_optimizers():
        assert trainer.strategy._lightning_optimizers[0].optimizer is trainer.optimizers[0]

    model = BoringModel()
    model.lr = 0.2
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_lr_find=True)

    trainer.tune(model)
    compare_optimizers()

    trainer.fit(model)
    compare_optimizers()

    trainer.fit_loop.max_epochs = 2  # simulate multiple fit calls
    trainer.fit(model)
    compare_optimizers()
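
# Sketch of the wrapping relation the assertion above relies on (assumed PL 1.x
# internals): Lightning wraps each user optimizer in a LightningOptimizer whose
# `.optimizer` attribute is the original optimizer object, which is exactly the
# identity that compare_optimizers() checks.
def lightning_optimizer_example():
    import torch
    from pytorch_lightning.core.optimizer import LightningOptimizer

    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=0.2)
    wrapped = LightningOptimizer(opt)
    assert wrapped.optimizer is opt  # same identity check as compare_optimizers()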