# Example No. 1
# 0
def test_v1_5_0_datamodule_setter():
    """Assigning ``LightningModule.datamodule`` is allowed; reading it is deprecated.

    The setter is expected to produce no matching deprecation warning, while
    the getter is expected to emit one whose message mentions
    ``LightningModule.datamodule``.
    """
    expected = "The `LightningModule.datamodule`"
    lit_model = BoringModel()
    dm = BoringDataModule()

    # Writing the attribute must not trigger the deprecation.
    with no_deprecated_call(match=expected):
        lit_model.datamodule = dm

    # Reading it back must trigger it; the returned value is irrelevant here.
    with pytest.deprecated_call(match=expected):
        _ = lit_model.datamodule
def test_v1_5_0_old_on_validation_epoch_end(tmpdir):
    """Old ``on_validation_epoch_end`` signatures warn; the new ones do not.

    Both the ``Callback`` hook and the LightningModule hook are exercised.
    A single ``Trainer`` is reused across all four fits, and
    ``callback_warning_cache`` is cleared at the start and again before the
    new-signature phase (presumably so cached messages from the old-signature
    phase cannot mask or leak into the checks that follow — confirm against
    the cache implementation).
    """
    removal_msg = "old signature will be removed in v1.5"
    callback_msg = "`Callback.on_validation_epoch_end` signature has changed in v1.3."
    hooks_msg = "`ModelHooks.on_validation_epoch_end` signature has changed in v1.3."

    callback_warning_cache.clear()

    # Phase 1: callback using the old signature (no ``outputs`` argument).
    class OldSignature(Callback):
        def on_validation_epoch_end(self, trainer, pl_module):  # noqa
            ...

    boring = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())
    with pytest.deprecated_call(match=removal_msg):
        trainer.fit(boring)

    # Phase 2: LightningModule using the old signature, same trainer.
    class OldSignatureModel(BoringModel):
        def on_validation_epoch_end(self):  # noqa
            ...

    old_model = OldSignatureModel()
    with pytest.deprecated_call(match=removal_msg):
        trainer.fit(old_model)

    callback_warning_cache.clear()

    # Phase 3: replace the callbacks with a new-signature one; the Callback
    # warning must not appear (the module still has the old hook).
    class NewSignature(Callback):
        def on_validation_epoch_end(self, trainer, pl_module, outputs):
            ...

    trainer.callbacks = [NewSignature()]
    with no_deprecated_call(match=callback_msg):
        trainer.fit(old_model)

    # Phase 4: LightningModule using the new signature.
    class NewSignatureModel(BoringModel):
        def on_validation_epoch_end(self, outputs):
            ...

    new_model = NewSignatureModel()
    with no_deprecated_call(match=hooks_msg):
        trainer.fit(new_model)
# Example No. 3
# 0
def test_v1_5_0_datamodule_setter():
    """Setting ``datamodule`` is silent; reading it records a cached warning.

    Rather than capturing the warning through ``pytest``, this variant clears
    ``pytorch_lightning.core.lightning.warning_cache``, reads the attribute,
    and then checks that a matching entry was added to the cache.
    """
    needle = "The `LightningModule.datamodule`"
    module = BoringModel()
    data = BoringDataModule()
    with no_deprecated_call(match=needle):
        module.datamodule = data

    from pytorch_lightning.core.lightning import warning_cache

    # Start from an empty cache so the assertion only sees the getter's entry.
    warning_cache.clear()
    _ = module.datamodule
    assert any(needle in entry for entry in warning_cache)
# Example No. 4
# 0
def test_v1_5_0_old_callback_on_load_checkpoint(tmpdir):
    """Resuming with an old-style ``on_load_checkpoint`` callback is deprecated.

    Flow:
    1. Train one step with ``OldSignatureOnLoadCheckpoint`` plus a
       ``ModelCheckpoint(save_last=True)`` to produce a checkpoint.
    2. Resume from it with a fresh old-signature callback: a deprecation
       warning is expected and the callback's hook must have run.
    3. Train with only new/valid-signature callbacks: no warning expected.
    4. Resume from that run's checkpoint without extra callbacks: no warning.
    """
    removal_msg = "old signature will be removed in v1.5"
    model = BoringModel()
    trainer_kwargs = {"default_root_dir": tmpdir, "max_steps": 1}

    ckpt_cb = ModelCheckpoint(save_last=True)
    trainer = Trainer(**trainer_kwargs, callbacks=[OldSignatureOnLoadCheckpoint(), ckpt_cb])
    trainer.fit(model)

    # Every later Trainer in this test runs with max_steps=2.
    trainer_kwargs["max_steps"] = 2
    with pytest.deprecated_call(match=removal_msg):
        legacy_cb = OldSignatureOnLoadCheckpoint()
        trainer = Trainer(
            **trainer_kwargs,
            callbacks=legacy_cb,
            resume_from_checkpoint=ckpt_cb.last_model_path,
        )
        trainer.fit(model)
        assert legacy_cb.on_load_checkpoint_called

    class ValidSignature1(BaseSignatureOnLoadCheckpoint):
        # Accepts the hook arguments generically; two positional args are
        # expected after ``trainer``.
        def on_load_checkpoint(self, trainer, *args):
            assert len(args) == 2
            self.on_load_checkpoint_called = True

    model = BoringModel()
    ckpt_cb = ModelCheckpoint(save_last=True)
    modern_callbacks = [
        NewSignatureOnLoadCheckpoint(),
        ValidSignature1(),
        ValidSignature2OnLoadCheckpoint(),
        ckpt_cb,
    ]
    trainer = Trainer(**trainer_kwargs, callbacks=modern_callbacks)
    with no_deprecated_call(match=removal_msg):
        trainer.fit(model)

    trainer = Trainer(**trainer_kwargs, resume_from_checkpoint=ckpt_cb.last_model_path)
    with no_deprecated_call(match=removal_msg):
        trainer.fit(model)
# Example No. 5
# 0
def test_deprecated_epoch_outputs_format(tmpdir):
    """Old output formats of ``on_train_batch_end``/``training_epoch_end`` warn.

    The model uses truncated BPTT (``truncated_bptt_steps = 1`` with
    ``tbptt_split_batch`` yielding two splits) and two optimizers, so the hook
    outputs are nested. Hooks that do not opt in via ``new_format=True`` are
    expected to get a deprecation warning describing the v1.8 layout.
    Each phase isolates one hook:

    1. ``DeprecationModel``: both hooks use the old format; the
       ``on_train_batch_end`` warning is checked.
    2. ``DeprecationModel2``: ``on_train_batch_end`` opts in, so only
       ``training_epoch_end`` is still old-style and its warning is checked
       without interference from the batch-end hook.
    3. ``NoDeprecationModel``: both hooks opt in; no warning is expected.
    """

    class DeprecationModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.truncated_bptt_steps = 1

        def training_step(self, batch, batch_idx, optimizer_idx, hiddens):
            output = super().training_step(batch, batch_idx)
            output["hiddens"] = hiddens
            return output

        def tbptt_split_batch(self, batch, split_size):
            # Two splits per batch, so tbptt_steps > 1 in the collected outputs.
            return [batch, batch]

        def training_epoch_end(self, outputs):
            ...

        def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
            ...

        def configure_optimizers(self):
            return [
                torch.optim.Adam(self.parameters()),
                torch.optim.Adam(self.parameters())
            ]

    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    model = DeprecationModel()
    batch_match = r"on_train_batch_end.*will change in version v1.8 to \(tbptt_steps, n_optimizers\)"
    with pytest.deprecated_call(match=batch_match):
        trainer.fit(model)

    class DeprecationModel2(DeprecationModel):
        def on_train_batch_end(self, *args, new_format=True):
            ...

    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    # Use the model whose batch-end hook already opts in, so this phase
    # checks the training_epoch_end warning by itself. Previously this
    # re-instantiated DeprecationModel, which left DeprecationModel2 unused.
    model = DeprecationModel2()
    epoch_match = r"training_epoch_end.*will change in version v1.8 to \(n_batches, tbptt_steps, n_optimizers\)"
    with pytest.deprecated_call(match=epoch_match):
        trainer.fit(model)

    class NoDeprecationModel(DeprecationModel2):
        def training_epoch_end(self, outputs, new_format=True):
            ...

    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    model = NoDeprecationModel()
    with no_deprecated_call(
            match="will change in version v1.8.*new_format=True"):
        trainer.fit(model)