Code Example #1
def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
    class OldSignature(Callback):
        def on_save_checkpoint(self, trainer, pl_module):  # noqa
            ...

    model = BoringModel()
    trainer_kwargs = {
        "default_root_dir": tmpdir,
        "checkpoint_callback": False,
        "max_epochs": 1,
    }
    filepath = tmpdir / "test.ckpt"

    trainer = Trainer(**trainer_kwargs, callbacks=[OldSignature()])
    trainer.fit(model)

    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
        trainer.save_checkpoint(filepath)

    class NewSignature(Callback):
        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            ...

    class ValidSignature1(Callback):
        def on_save_checkpoint(self, trainer, *args):
            ...

    class ValidSignature2(Callback):
        def on_save_checkpoint(self, *args):
            ...

    trainer.callbacks = [NewSignature(), ValidSignature1(), ValidSignature2()]
    with no_warning_call(DeprecationWarning):
        trainer.save_checkpoint(filepath)
Code Example #2
def test_resume_incomplete_callbacks_list_warning(tmpdir):
    model = BoringModel()
    callback0 = ModelCheckpoint(monitor="epoch")
    callback1 = ModelCheckpoint(monitor="global_step")
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
        callbacks=[callback0, callback1],
    )
    trainer.fit(model)
    ckpt_path = trainer.checkpoint_callback.best_model_path

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
        callbacks=[callback1],  # one callback is missing!
    )
    with pytest.warns(
            UserWarning,
            match=escape(
                f"Please add the following callbacks: [{repr(callback0.state_key)}]"
            )):
        trainer.fit(model, ckpt_path=ckpt_path)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
        callbacks=[callback1, callback0],  # all callbacks here, order switched
    )
    with no_warning_call(UserWarning,
                         match="Please add the following callbacks:"):
        trainer.fit(model, ckpt_path=ckpt_path)
Code Example #3
    def test_with_datamodule_no_overridden(self, hook_name):
        model, dm, trainer = self.reset_instances()
        trainer._data_connector.attach_datamodule(model, datamodule=dm)
        with no_warning_call(match=f"have overridden `{hook_name}` in"):
            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)

        assert hook == getattr(model, hook_name)
Code Example #4
def test_trainer_reset_correctly(tmpdir):
    """Check that all trainer parameters are reset correctly after scaling batch size."""
    tutils.reset_seed()

    model = BatchSizeModel(batch_size=2)

    # logger file to get meta
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    changed_attributes = [
        "callbacks",
        "checkpoint_callback",
        "current_epoch",
        "limit_train_batches",
        "logger",
        "max_steps",
        "global_step",
    ]
    expected = {ca: getattr(trainer, ca) for ca in changed_attributes}

    with no_warning_call(UserWarning,
                         match="Please add the following callbacks"):
        trainer.tuner.scale_batch_size(model, max_trials=5)

    actual = {ca: getattr(trainer, ca) for ca in changed_attributes}
    assert actual == expected
Code Example #5
def test_nn_modules_warning_when_saved_as_hparams():
    class TorchModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(4, 5)

    class CustomBoringModelWarn(BoringModel):
        def __init__(self, encoder, decoder, other_hparam=7):
            super().__init__()
            self.save_hyperparameters()

    with pytest.warns(
            UserWarning,
            match="is an instance of `nn.Module` and is already saved"):
        model = CustomBoringModelWarn(encoder=TorchModule(),
                                      decoder=TorchModule())

    assert list(model.hparams) == ["encoder", "decoder", "other_hparam"]

    class CustomBoringModelNoWarn(BoringModel):
        def __init__(self, encoder, decoder, other_hparam=7):
            super().__init__()
            self.save_hyperparameters("other_hparam")

    with no_warning_call(
            UserWarning,
            match="is an instance of `nn.Module` and is already saved"):
        model = CustomBoringModelNoWarn(encoder=TorchModule(),
                                        decoder=TorchModule())

    assert list(model.hparams) == ["other_hparam"]
Code Example #6
def test_train_step_no_return(tmpdir):
    """Tests that only training_step raises a warning when nothing is returned in case of automatic_optimization."""
    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            self.training_step_called = True
            loss = self.step(batch[0])
            self.log("a", loss, on_step=True, on_epoch=True)

        def training_epoch_end(self, outputs) -> None:
            assert len(outputs) == 0

        def validation_step(self, batch, batch_idx):
            self.validation_step_called = True

        def validation_epoch_end(self, outputs):
            assert len(outputs) == 0

    model = TestModel()
    trainer_args = dict(default_root_dir=tmpdir, fast_dev_run=2)

    trainer = Trainer(**trainer_args)

    with pytest.warns(UserWarning, match=r"training_step returned None.*"):
        trainer.fit(model)

    assert model.training_step_called
    assert model.validation_step_called

    model = TestModel()
    model.automatic_optimization = False
    trainer = Trainer(**trainer_args)

    with no_warning_call(UserWarning, match=r"training_step returned None.*"):
        trainer.fit(model)
Code Example #7
def test_deprecated_class():
    with pytest.deprecated_call(
        match='`tests.utilities.test_deprecation.PastCls` was deprecated since v0.2 in favor'
              ' of `tests.utilities.test_deprecation.NewCls`. It will be removed in v0.4.'
    ):
        past = PastCls(2)
    assert past.my_c == 2
    assert past.my_d == "efg"

    # check that the warning is raised only once per function
    with no_warning_call(DeprecationWarning):
        assert PastCls(c=2, d="")
Code Example #8
def test_deprecated_func():
    with pytest.deprecated_call(
        match='`tests.utilities.test_deprecation.dep_sum` was deprecated since v0.1 in favor'
              ' of `tests.utilities.test_deprecation.my_sum`. It will be removed in v0.5.'
    ):
        assert dep_sum(2) == 7

    # check that the warning is raised only once per function
    with no_warning_call(DeprecationWarning):
        assert dep_sum(3) == 8

    # and does not affect other functions
    with pytest.deprecated_call(
        match='`tests.utilities.test_deprecation.dep3_sum` was deprecated since v0.1 in favor'
              ' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.'
    ):
        assert dep3_sum(2, 1) == 3
Code Example #9
def test_restarting_mid_epoch_raises_warning(tmpdir, stop_batch_idx):
    """Test that a warning is raised if training is restarted from mid-epoch."""
    class CustomModel(BoringModel):
        def __init__(self, stop_batch_idx):
            super().__init__()
            self.stop_batch_idx = stop_batch_idx

        def training_step(self, batch, batch_idx):
            if (batch_idx + 1) == self.stop_batch_idx:
                self.trainer.should_stop = True

            return super().training_step(batch, batch_idx)

    limit_train_batches = 7
    trainer_kwargs = {
        "default_root_dir": tmpdir,
        "limit_train_batches": limit_train_batches,
        "enable_progress_bar": False,
        "enable_model_summary": False,
    }
    trainer = Trainer(max_epochs=1, **trainer_kwargs)
    model = CustomModel(stop_batch_idx)
    trainer.fit(model)

    ckpt_path = str(tmpdir / "resume.ckpt")
    trainer.save_checkpoint(ckpt_path)

    trainer = Trainer(max_epochs=2, limit_val_batches=0, **trainer_kwargs)

    warning_raised = limit_train_batches != stop_batch_idx
    context_manager = pytest.warns if warning_raised else no_warning_call
    with context_manager(
            UserWarning,
            match="resuming from a checkpoint that ended mid-epoch"):
        trainer.fit(model, ckpt_path=ckpt_path)

    if warning_raised:
        with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}):
            trainer = Trainer(max_epochs=2,
                              limit_val_batches=0,
                              **trainer_kwargs)
            with no_warning_call(
                    UserWarning,
                    match="resuming from a checkpoint that ended mid-epoch"):
                trainer.fit(model, ckpt_path=ckpt_path)
Code Example #10
def test_deprecated_func_incomplete():

    # missing required argument
    with pytest.raises(TypeError, match="missing 1 required positional argument: 'b'"):
        dep2_sum(2)

    # check that the warning is raised only once per function
    with no_warning_call(DeprecationWarning):
        assert dep2_sum(2, 1) == 3

    # reset the warning
    dep2_sum.warned = False
    # does not affect other functions
    with pytest.deprecated_call(
        match='`tests.utilities.test_deprecation.dep2_sum` was deprecated since v0.1 in favor'
              ' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.'
    ):
        assert dep2_sum(b=2, a=1) == 3
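Examples #8 and #10 exercise a deprecation wrapper that emits its DeprecationWarning only once per function and records that fact on a `warned` attribute (hence the `dep2_sum.warned = False` reset above). The following is a minimal sketch of that idea; the `deprecated` name, its signature, and the plain argument forwarding are assumptions and may differ from the real helper used by these tests.

import functools
import warnings
from typing import Callable


def deprecated(target: Callable, deprecated_in: str, remove_in: str) -> Callable:
    """Route calls to `target`, warning once that the decorated function is deprecated."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not wrapper.warned:
                warnings.warn(
                    f"`{func.__module__}.{func.__qualname__}` was deprecated since {deprecated_in}"
                    f" in favor of `{target.__module__}.{target.__qualname__}`."
                    f" It will be removed in {remove_in}.",
                    DeprecationWarning,
                )
                wrapper.warned = True  # subsequent calls stay silent
            return target(*args, **kwargs)

        wrapper.warned = False  # tests can reset this flag to re-trigger the warning
        return wrapper

    return decorator

A function decorated this way (for instance a hypothetical `@deprecated(my_sum, "v0.1", "v0.5")` applied to `dep_sum`) warns on its first call only, which is why the test above resets `dep2_sum.warned` before expecting the warning to be raised again.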
Code Example #11
def test_restarting_mid_epoch_raises_warning(tmpdir, stop_in_the_middle,
                                             model_cls):
    """Test that a warning is raised if training is restarted from mid-epoch."""
    limit_train_batches = 8
    trainer_kwargs = {
        "default_root_dir": tmpdir,
        "limit_train_batches": limit_train_batches,
        "limit_val_batches": 0,
        "enable_progress_bar": False,
        "enable_model_summary": False,
    }
    trainer = Trainer(max_epochs=1, **trainer_kwargs)
    model = model_cls(limit_train_batches // 2 if stop_in_the_middle else -1)

    if stop_in_the_middle:
        with pytest.raises(CustomException):
            trainer.fit(model)
    else:
        trainer.fit(model)

    ckpt_path = str(tmpdir / "resume.ckpt")
    trainer.save_checkpoint(ckpt_path)

    trainer = Trainer(max_epochs=2, **trainer_kwargs)
    model.stop_batch_idx = -1

    context_manager = pytest.warns if stop_in_the_middle else tutils.no_warning_call
    with context_manager(UserWarning,
                         match="resuming from a checkpoint that ended"):
        trainer.fit(model, ckpt_path=ckpt_path)

    if stop_in_the_middle:
        with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}):
            trainer = Trainer(max_epochs=2, **trainer_kwargs)
            with tutils.no_warning_call(
                    UserWarning,
                    match="resuming from a checkpoint that ended"):
                trainer.fit(model, ckpt_path=ckpt_path)
Code Example #12
@contextmanager
def no_deprecated_call(match: Optional[str] = None):
    # Usable as a `with` block that fails if a matching DeprecationWarning is emitted.
    with no_warning_call(expected_warning=DeprecationWarning, match=match):
        yield
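Every example in this list relies on the `no_warning_call` test helper. For reference, here is a minimal sketch of how such a helper could be written with the standard `warnings` module; it is an illustrative assumption, not necessarily the exact implementation used by the test suite above.

import re
import warnings
from contextlib import contextmanager
from typing import Optional, Type


@contextmanager
def no_warning_call(expected_warning: Type[Warning] = Warning, match: Optional[str] = None):
    """Fail if a warning of type `expected_warning` (optionally matching `match`) is emitted."""
    # Record every warning raised inside the `with` block.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        yield

    # After the block exits, fail if any recorded warning matches the criteria.
    for w in record:
        if not issubclass(w.category, expected_warning):
            continue
        if match is None or re.search(match, str(w.message)):
            raise AssertionError(f"Unexpected warning raised: {w.message!r}")

With this sketch, `with no_warning_call(UserWarning, match="Please add the following callbacks"):` passes silently when no matching UserWarning is raised inside the block and raises an AssertionError otherwise, which mirrors how the helper is used in the examples above.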
Code Example #13
def test_v1_5_0_model_checkpoint_period(tmpdir):
    with no_warning_call(DeprecationWarning):
        ModelCheckpoint(dirpath=tmpdir)
    with pytest.deprecated_call(
            match="is deprecated in v1.3 and will be removed in v1.5"):
        ModelCheckpoint(dirpath=tmpdir, period=1)
Code Example #14
def _check_warning_not_raised(data, expected):
    with no_warning_call(match="Trying to infer the `batch_size`"):
        assert extract_batch_size(data) == expected
Code Example #15
def test_wandb_logger_init(wandb, monkeypatch):
    """Verify that basic functionality of wandb logger works.

    Wandb doesn't work well with pytest so we have to mock it out here.
    """
    import pytorch_lightning.loggers.wandb as imports

    # test wandb.init called when there is no W&B run
    wandb.run = None
    logger = WandbLogger(name="test_name",
                         save_dir="test_save_dir",
                         version="test_id",
                         project="test_project",
                         resume="never")
    logger.log_metrics({"acc": 1.0})
    wandb.init.assert_called_once_with(name="test_name",
                                       dir="test_save_dir",
                                       id="test_id",
                                       project="test_project",
                                       resume="never",
                                       anonymous=None)
    wandb.init().log.assert_called_once_with({"acc": 1.0})

    # test wandb.init and setting logger experiment externally
    wandb.run = None
    run = wandb.init()
    logger = WandbLogger(experiment=run)
    assert logger.experiment

    # test wandb.init not called if there is a W&B run
    wandb.init().log.reset_mock()
    wandb.init.reset_mock()
    wandb.run = wandb.init()

    monkeypatch.setattr(imports, "_WANDB_GREATER_EQUAL_0_12_10", True)
    with pytest.warns(UserWarning,
                      match="There is a wandb run already in progress"):
        logger = WandbLogger()
    # check that no new run is created
    with no_warning_call(UserWarning,
                         match="There is a wandb run already in progress"):
        _ = logger.experiment

    # verify default resume value
    assert logger._wandb_init["resume"] == "allow"

    logger.log_metrics({"acc": 1.0}, step=3)
    wandb.init.assert_called_once()
    wandb.init().log.assert_called_once_with({
        "acc": 1.0,
        "trainer/global_step": 3
    })

    # continue training on same W&B run and offset step
    logger.finalize("success")
    logger.log_metrics({"acc": 1.0}, step=6)
    wandb.init().log.assert_called_with({"acc": 1.0, "trainer/global_step": 6})

    # log hyper parameters
    logger.log_hyperparams({"test": None, "nested": {"a": 1}, "b": [2, 3, 4]})
    wandb.init().config.update.assert_called_once_with(
        {
            "test": "None",
            "nested/a": 1,
            "b": [2, 3, 4]
        }, allow_val_change=True)

    # watch a model
    logger.watch("model", "log", 10, False)
    wandb.init().watch.assert_called_once_with("model",
                                               log="log",
                                               log_freq=10,
                                               log_graph=False)

    assert logger.name == wandb.init().project_name()
    assert logger.version == wandb.init().id