def test_v1_5_0_datamodule_setter():
    """Setting ``LightningModule.datamodule`` is silent; reading it warns.

    NOTE(review): another function with exactly this name is defined later in
    this module. That later definition rebinds the name, so this one is never
    collected by pytest. It looks like a stale copy; consider deleting it.
    """
    deprecation_msg = "The `LightningModule.datamodule`"
    lightning_module = BoringModel()
    data_module = BoringDataModule()

    # Assigning the attribute must not emit the deprecation message...
    with no_deprecated_call(match=deprecation_msg):
        lightning_module.datamodule = data_module

    # ...while reading it back must.
    with pytest.deprecated_call(match=deprecation_msg):
        _ = lightning_module.datamodule
def test_v1_5_0_old_on_validation_epoch_end(tmpdir):
    """Exercise the ``on_validation_epoch_end`` signature deprecation.

    Old-signature hooks (no ``outputs`` parameter) must emit the v1.5 removal
    warning. This is checked on a ``Callback`` and on a ``BoringModel``
    subclass. Hooks that accept ``outputs`` must not emit the v1.3
    signature-change warning.
    """
    removal_msg = "old signature will be removed in v1.5"
    callback_change_msg = "`Callback.on_validation_epoch_end` signature has changed in v1.3."
    hooks_change_msg = "`ModelHooks.on_validation_epoch_end` signature has changed in v1.3."

    # Reset the callback warning cache (presumably it de-duplicates warnings)
    # so an earlier test cannot hide the warning expected below.
    callback_warning_cache.clear()

    # --- Old signatures: the removal warning is expected. ---
    class OldSignature(Callback):
        def on_validation_epoch_end(self, trainer, pl_module):  # noqa
            ...

    lit_model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())
    with pytest.deprecated_call(match=removal_msg):
        trainer.fit(lit_model)

    class OldSignatureModel(BoringModel):
        def on_validation_epoch_end(self):  # noqa
            ...

    lit_model = OldSignatureModel()
    # NOTE(review): the trainer from the previous stage is reused here and
    # still carries the ``OldSignature`` callback.
    with pytest.deprecated_call(match=removal_msg):
        trainer.fit(lit_model)

    # --- New signatures: no signature-change warning is expected. ---
    callback_warning_cache.clear()

    class NewSignature(Callback):
        def on_validation_epoch_end(self, trainer, pl_module, outputs):
            ...

    trainer.callbacks = [NewSignature()]
    # The model fitted here is still ``OldSignatureModel``; only the
    # Callback-specific message is filtered for.
    with no_deprecated_call(match=callback_change_msg):
        trainer.fit(lit_model)

    class NewSignatureModel(BoringModel):
        def on_validation_epoch_end(self, outputs):
            ...

    lit_model = NewSignatureModel()
    with no_deprecated_call(match=hooks_change_msg):
        trainer.fit(lit_model)
def test_v1_5_0_datamodule_setter():
    """Check the deprecation of reading ``LightningModule.datamodule``.

    Assigning the attribute must not emit the deprecation message. Reading it
    must leave that message in the ``warning_cache`` object imported from
    ``pytorch_lightning.core.lightning``. The cache is cleared right before
    the read, so only that read can satisfy the assertion.
    """
    deprecation_msg = "The `LightningModule.datamodule`"
    lightning_module = BoringModel()
    data_module = BoringDataModule()

    with no_deprecated_call(match=deprecation_msg):
        lightning_module.datamodule = data_module

    from pytorch_lightning.core.lightning import warning_cache as lightning_warning_cache

    lightning_warning_cache.clear()
    _ = lightning_module.datamodule  # accessed only for its warning side effect
    assert any(deprecation_msg in cached for cached in lightning_warning_cache)
def test_v1_5_0_old_callback_on_load_checkpoint(tmpdir):
    """Exercise the ``Callback.on_load_checkpoint`` signature deprecation.

    Resuming with an old-signature callback must emit the v1.5 removal
    warning and still call the hook. Runs whose callbacks use accepted
    signatures must not emit that warning.
    """
    removal_msg = "old signature will be removed in v1.5"

    # Stage 1: produce a checkpoint (max_steps=1) to resume from.
    lit_model = BoringModel()
    trainer_kwargs = {"default_root_dir": tmpdir, "max_steps": 1}
    first_ckpt_cb = ModelCheckpoint(save_last=True)
    trainer = Trainer(**trainer_kwargs, callbacks=[OldSignatureOnLoadCheckpoint(), first_ckpt_cb])
    trainer.fit(lit_model)

    # Stage 2: raise max_steps to 2 and resume through the old-signature hook.
    # This change also applies to every Trainer built below.
    trainer_kwargs["max_steps"] = 2
    with pytest.deprecated_call(match=removal_msg):
        old_cb = OldSignatureOnLoadCheckpoint()
        trainer = Trainer(
            **trainer_kwargs,
            callbacks=old_cb,
            resume_from_checkpoint=first_ckpt_cb.last_model_path,
        )
        trainer.fit(lit_model)
        assert old_cb.on_load_checkpoint_called

    # Stage 3: callbacks with accepted signatures must not warn.
    class ValidSignature1(BaseSignatureOnLoadCheckpoint):
        def on_load_checkpoint(self, trainer, *args):
            assert len(args) == 2
            self.on_load_checkpoint_called = True

    lit_model = BoringModel()
    second_ckpt_cb = ModelCheckpoint(save_last=True)
    accepted_callbacks = [
        NewSignatureOnLoadCheckpoint(),
        ValidSignature1(),
        ValidSignature2OnLoadCheckpoint(),
        second_ckpt_cb,
    ]
    trainer = Trainer(**trainer_kwargs, callbacks=accepted_callbacks)
    with no_deprecated_call(match=removal_msg):
        trainer.fit(lit_model)

    # NOTE(review): this resumed Trainer is built without a ``callbacks``
    # argument, so it does not re-run the stage-3 callbacks' load hooks.
    # It only checks that resuming from that checkpoint stays silent.
    trainer = Trainer(**trainer_kwargs, resume_from_checkpoint=second_ckpt_cb.last_model_path)
    with no_deprecated_call(match=removal_msg):
        trainer.fit(lit_model)
def test_deprecated_epoch_outputs_format(tmpdir):
    """Check the v1.8 ``outputs``-format deprecation with TBPTT and two optimizers.

    Each stage opts one more hook into the new format via a ``new_format=True``
    keyword:

    1. ``DeprecationModel`` overrides ``on_train_batch_end`` and
       ``training_epoch_end`` without it, so the batch-end warning is expected.
    2. ``DeprecationModel2`` opts ``on_train_batch_end`` in. Only
       ``training_epoch_end`` remains on the old format, so the epoch-end
       warning is expected.
    3. ``NoDeprecationModel`` opts both hooks in, so no such warning is expected.
    """

    class DeprecationModel(BoringModel):
        def __init__(self):
            super().__init__()
            # Enable truncated BPTT so the step receives ``hiddens``.
            self.truncated_bptt_steps = 1

        def training_step(self, batch, batch_idx, optimizer_idx, hiddens):
            output = super().training_step(batch, batch_idx)
            output["hiddens"] = hiddens
            return output

        def tbptt_split_batch(self, batch, split_size):
            # Two TBPTT splits per batch.
            return [batch, batch]

        def training_epoch_end(self, outputs):
            ...

        def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
            ...

        def configure_optimizers(self):
            # Two optimizers, so outputs carry an n_optimizers dimension.
            return [torch.optim.Adam(self.parameters()), torch.optim.Adam(self.parameters())]

    # Stage 1: old-format batch-end hook warns.
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    model = DeprecationModel()
    batch_match = r"on_train_batch_end.*will change in version v1.8 to \(tbptt_steps, n_optimizers\)"
    with pytest.deprecated_call(match=batch_match):
        trainer.fit(model)

    class DeprecationModel2(DeprecationModel):
        def on_train_batch_end(self, *args, new_format=True):
            ...

    # Stage 2: only the epoch-end hook is still on the old format.
    # Bug fix: this stage previously instantiated ``DeprecationModel``. That
    # left the batch-end hook on the old format too, so the stage did not
    # isolate the epoch-end warning, and ``DeprecationModel2`` was never run.
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    model = DeprecationModel2()
    epoch_match = r"training_epoch_end.*will change in version v1.8 to \(n_batches, tbptt_steps, n_optimizers\)"
    with pytest.deprecated_call(match=epoch_match):
        trainer.fit(model)

    class NoDeprecationModel(DeprecationModel2):
        def training_epoch_end(self, outputs, new_format=True):
            ...

    # Stage 3: both hooks opted in, so no v1.8 format warning is expected.
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    model = NoDeprecationModel()
    with no_deprecated_call(match="will change in version v1.8.*new_format=True"):
        trainer.fit(model)