def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
    """The legacy two-argument ``on_save_checkpoint`` hook is deprecated; newer signatures are silent."""

    class LegacyHook(Callback):
        def on_save_checkpoint(self, trainer, pl_module):  # noqa
            ...

    model = BoringModel()
    ckpt_file = tmpdir / "test.ckpt"
    trainer = Trainer(
        default_root_dir=tmpdir,
        checkpoint_callback=False,
        max_epochs=1,
        callbacks=[LegacyHook()],
    )
    trainer.fit(model)

    # Saving with a callback using the old (trainer, pl_module) signature must warn.
    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
        trainer.save_checkpoint(ckpt_file)

    # Signatures that can receive the ``checkpoint`` argument must not warn.
    class ExplicitCheckpointArg(Callback):
        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            ...

    class TrainerThenVarArgs(Callback):
        def on_save_checkpoint(self, trainer, *args):
            ...

    class OnlyVarArgs(Callback):
        def on_save_checkpoint(self, *args):
            ...

    trainer.callbacks = [ExplicitCheckpointArg(), TrainerThenVarArgs(), OnlyVarArgs()]
    with no_warning_call(DeprecationWarning):
        trainer.save_checkpoint(ckpt_file)
def test_resume_incomplete_callbacks_list_warning(tmpdir):
    """Resuming without every checkpointed callback warns; the complete set, in any order, does not."""
    model = BoringModel()
    epoch_ckpt = ModelCheckpoint(monitor="epoch")
    step_ckpt = ModelCheckpoint(monitor="global_step")

    def make_trainer(callbacks):
        # All three trainers share everything except their callback list.
        return Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=callbacks)

    trainer = make_trainer([epoch_ckpt, step_ckpt])
    trainer.fit(model)
    resume_from = trainer.checkpoint_callback.best_model_path

    # One callback is missing: the trainer should ask for it by state key.
    trainer = make_trainer([step_ckpt])
    missing_msg = f"Please add the following callbacks: [{epoch_ckpt.state_key!r}]"
    with pytest.warns(UserWarning, match=escape(missing_msg)):
        trainer.fit(model, ckpt_path=resume_from)

    # All callbacks present with their order switched: no warning expected.
    trainer = make_trainer([step_ckpt, epoch_ckpt])
    with no_warning_call(UserWarning, match="Please add the following callbacks:"):
        trainer.fit(model, ckpt_path=resume_from)
def test_with_datamodule_no_overridden(self, hook_name):
    """With a datamodule attached and no override, the model's own hook is selected without warning."""
    model, dm, trainer = self.reset_instances()
    connector = trainer._data_connector
    connector.attach_datamodule(model, datamodule=dm)

    with no_warning_call(match=f"have overridden `{hook_name}` in"):
        selected = connector._datahook_selector.get_hook(hook_name)

    assert selected == getattr(model, hook_name)
def test_trainer_reset_correctly(tmpdir):
    """Check that all trainer parameters are reset correctly after scaling batch size."""
    tutils.reset_seed()
    model = BatchSizeModel(batch_size=2)

    # logger file to get meta
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    # Trainer attributes the batch-size tuner temporarily alters and must restore.
    tracked = (
        "callbacks",
        "checkpoint_callback",
        "current_epoch",
        "limit_train_batches",
        "logger",
        "max_steps",
        "global_step",
    )

    def snapshot():
        return {name: getattr(trainer, name) for name in tracked}

    before = snapshot()
    with no_warning_call(UserWarning, match="Please add the following callbacks"):
        trainer.tuner.scale_batch_size(model, max_trials=5)
    assert snapshot() == before
def test_nn_modules_warning_when_saved_as_hparams():
    """Saving ``nn.Module`` init arguments as hyperparameters warns; saving only a scalar does not."""

    class SubNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(4, 5)

    class SavesAllArgs(BoringModel):
        def __init__(self, encoder, decoder, other_hparam=7):
            super().__init__()
            # No names given: every init argument, including both modules, is saved.
            self.save_hyperparameters()

    class SavesScalarOnly(BoringModel):
        def __init__(self, encoder, decoder, other_hparam=7):
            super().__init__()
            # Only the plain scalar hyperparameter is saved.
            self.save_hyperparameters("other_hparam")

    module_msg = "is an instance of `nn.Module` and is already saved"

    with pytest.warns(UserWarning, match=module_msg):
        model = SavesAllArgs(encoder=SubNet(), decoder=SubNet())
    assert list(model.hparams) == ["encoder", "decoder", "other_hparam"]

    with no_warning_call(UserWarning, match=module_msg):
        model = SavesScalarOnly(encoder=SubNet(), decoder=SubNet())
    assert list(model.hparams) == ["other_hparam"]
def test_train_step_no_return(tmpdir):
    """Only automatic optimization warns when ``training_step`` returns nothing; manual optimization is silent."""

    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            self.training_step_called = True
            loss = self.step(batch[0])
            self.log("a", loss, on_step=True, on_epoch=True)
            # Intentionally returns None.

        def training_epoch_end(self, outputs) -> None:
            assert len(outputs) == 0

        def validation_step(self, batch, batch_idx):
            self.validation_step_called = True

        def validation_epoch_end(self, outputs):
            assert len(outputs) == 0

    def fresh_trainer():
        return Trainer(default_root_dir=tmpdir, fast_dev_run=2)

    none_returned = r"training_step returned None.*"

    # Automatic optimization: the missing return value is reported.
    model = TestModel()
    trainer = fresh_trainer()
    with pytest.warns(UserWarning, match=none_returned):
        trainer.fit(model)
    assert model.training_step_called
    assert model.validation_step_called

    # Manual optimization: returning None is legitimate, so nothing is reported.
    model = TestModel()
    model.automatic_optimization = False
    trainer = fresh_trainer()
    with no_warning_call(UserWarning, match=none_returned):
        trainer.fit(model)
def test_deprecated_class():
    """Instantiating the deprecated ``PastCls`` warns once, with the documented message."""
    expected_msg = (
        '`tests.utilities.test_deprecation.PastCls` was deprecated since v0.2 in favor'
        ' of `tests.utilities.test_deprecation.NewCls`. It will be removed in v0.4.'
    )
    with pytest.deprecated_call(match=expected_msg):
        instance = PastCls(2)
    assert (instance.my_c, instance.my_d) == (2, "efg")

    # The warning is emitted only once per function, so a second construction is silent.
    with no_warning_call(DeprecationWarning):
        assert PastCls(c=2, d="")
def test_deprecated_func():
    """Deprecated functions warn once each, independently of one another."""
    dep_sum_msg = (
        '`tests.utilities.test_deprecation.dep_sum` was deprecated since v0.1 in favor'
        ' of `tests.utilities.test_deprecation.my_sum`. It will be removed in v0.5.'
    )
    dep3_sum_msg = (
        '`tests.utilities.test_deprecation.dep3_sum` was deprecated since v0.1 in favor'
        ' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.'
    )

    with pytest.deprecated_call(match=dep_sum_msg):
        assert dep_sum(2) == 7

    # A second call to the same function must not warn again.
    with no_warning_call(DeprecationWarning):
        assert dep_sum(3) == 8

    # Another deprecated function still emits its own warning.
    with pytest.deprecated_call(match=dep3_sum_msg):
        assert dep3_sum(2, 1) == 3
def test_restarting_mid_epoch_raises_warning(tmpdir, stop_batch_idx):
    """Test that a warning is raised if training is restarted from mid-epoch."""

    class CustomModel(BoringModel):
        def __init__(self, stop_batch_idx):
            super().__init__()
            self.stop_batch_idx = stop_batch_idx

        def training_step(self, batch, batch_idx):
            # Request a stop right after the configured (1-based) batch.
            if batch_idx + 1 == self.stop_batch_idx:
                self.trainer.should_stop = True
            return super().training_step(batch, batch_idx)

    limit_train_batches = 7
    common = dict(
        default_root_dir=tmpdir,
        limit_train_batches=limit_train_batches,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    mid_epoch_msg = "resuming from a checkpoint that ended mid-epoch"

    trainer = Trainer(max_epochs=1, **common)
    model = CustomModel(stop_batch_idx)
    trainer.fit(model)
    ckpt_path = str(tmpdir / "resume.ckpt")
    trainer.save_checkpoint(ckpt_path)

    trainer = Trainer(max_epochs=2, limit_val_batches=0, **common)
    # The epoch ended early unless the stop coincided with the last batch.
    stopped_early = stop_batch_idx != limit_train_batches
    if stopped_early:
        expectation = pytest.warns(UserWarning, match=mid_epoch_msg)
    else:
        expectation = no_warning_call(UserWarning, match=mid_epoch_msg)
    with expectation:
        trainer.fit(model, ckpt_path=ckpt_path)

    if not stopped_early:
        return

    # With fault-tolerant training enabled, the same resume must not warn.
    with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}):
        trainer = Trainer(max_epochs=2, limit_val_batches=0, **common)
        with no_warning_call(UserWarning, match=mid_epoch_msg):
            trainer.fit(model, ckpt_path=ckpt_path)
def test_deprecated_func_incomplete():
    """A deprecated function warns once, and warns again after its one-shot ``warned`` flag is reset."""
    expected_msg = (
        '`tests.utilities.test_deprecation.dep2_sum` was deprecated since v0.1 in favor'
        ' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.'
    )

    # Calling with a missing required argument still raises the normal TypeError.
    with pytest.raises(TypeError, match="missing 1 required positional argument: 'b'"):
        dep2_sum(2)

    # No further DeprecationWarning is expected. The test assumes the one-time
    # warning was already consumed by the failing call above.
    with no_warning_call(DeprecationWarning):
        assert dep2_sum(2, 1) == 3

    # Clear the one-shot flag so the next call (with keyword args) warns again.
    dep2_sum.warned = False
    with pytest.deprecated_call(match=expected_msg):
        assert dep2_sum(b=2, a=1) == 3
def test_restarting_mid_epoch_raises_warning(tmpdir, stop_in_the_middle, model_cls):
    """Test that a warning is raised if training is restarted from mid-epoch."""
    limit_train_batches = 8
    common = dict(
        default_root_dir=tmpdir,
        limit_train_batches=limit_train_batches,
        limit_val_batches=0,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    resume_msg = "resuming from a checkpoint that ended"

    trainer = Trainer(max_epochs=1, **common)
    # -1 disables the early stop; otherwise the model stops halfway through the epoch.
    stop_at = limit_train_batches // 2 if stop_in_the_middle else -1
    model = model_cls(stop_at)
    if stop_in_the_middle:
        with pytest.raises(CustomException):
            trainer.fit(model)
    else:
        trainer.fit(model)

    ckpt_path = str(tmpdir / "resume.ckpt")
    trainer.save_checkpoint(ckpt_path)

    trainer = Trainer(max_epochs=2, **common)
    model.stop_batch_idx = -1
    if stop_in_the_middle:
        expectation = pytest.warns(UserWarning, match=resume_msg)
    else:
        expectation = tutils.no_warning_call(UserWarning, match=resume_msg)
    with expectation:
        trainer.fit(model, ckpt_path=ckpt_path)

    if not stop_in_the_middle:
        return

    # With fault-tolerant training enabled, the same resume must not warn.
    with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}):
        trainer = Trainer(max_epochs=2, **common)
        with tutils.no_warning_call(UserWarning, match=resume_msg):
            trainer.fit(model, ckpt_path=ckpt_path)
def no_deprecated_call(match: Optional[str] = None):
    """Guard a ``with`` body against ``DeprecationWarning`` via ``no_warning_call``.

    Args:
        match: forwarded unchanged to ``no_warning_call`` as its ``match`` argument.
    """
    # NOTE(review): this is a generator function. It is only usable as a ``with``
    # target when wrapped by ``contextlib.contextmanager``, and no such decorator
    # is visible here; confirm it is applied.
    with no_warning_call(match=match, expected_warning=DeprecationWarning):
        yield
def test_v1_5_0_model_checkpoint_period(tmpdir):
    """Passing ``period`` to ModelCheckpoint is deprecated; omitting it is silent."""
    deprecation_msg = "is deprecated in v1.3 and will be removed in v1.5"

    with no_warning_call(DeprecationWarning):
        ModelCheckpoint(dirpath=tmpdir)

    with pytest.deprecated_call(match=deprecation_msg):
        ModelCheckpoint(dirpath=tmpdir, period=1)
def _check_warning_not_raised(data, expected):
    """Assert ``extract_batch_size(data) == expected`` and that no batch-size inference warning fires."""
    with no_warning_call(match="Trying to infer the `batch_size`"):
        inferred = extract_batch_size(data)
        assert inferred == expected
def test_wandb_logger_init(wandb, monkeypatch):
    """Verify that basic functionality of wandb logger works.

    Wandb doesn't work well with pytest so we have to mock it out here.

    NOTE: ``wandb`` is presumably a fixture providing a mocked ``wandb`` module
    (its ``init`` is used as a Mock below); confirm against the fixture definition.
    Call counts on ``wandb.init`` matter: every ``wandb.init()`` expression in this
    test is itself a call on the mock, so statement order must be preserved.
    """
    import pytorch_lightning.loggers.wandb as imports

    # test wandb.init called when there is no W&B run
    wandb.run = None
    logger = WandbLogger(name="test_name", save_dir="test_save_dir", version="test_id", project="test_project", resume="never")
    # Logging forces the experiment to be created, which should call wandb.init
    # exactly once with the constructor arguments (save_dir is passed as ``dir``,
    # version as ``id``).
    logger.log_metrics({"acc": 1.0})
    wandb.init.assert_called_once_with(name="test_name", dir="test_save_dir", id="test_id", project="test_project", resume="never", anonymous=None)
    wandb.init().log.assert_called_once_with({"acc": 1.0})

    # test wandb.init and setting logger experiment externally
    wandb.run = None
    run = wandb.init()
    logger = WandbLogger(experiment=run)
    assert logger.experiment

    # test wandb.init not called if there is a W&B run
    # Clear recorded calls so the call-count assertions below only see what follows.
    wandb.init().log.reset_mock()
    wandb.init.reset_mock()
    # Simulate an active run. This line is itself the single wandb.init call that
    # ``assert_called_once`` below expects, so the logger must not call it again.
    wandb.run = wandb.init()
    # Force the version flag to True. The flag name suggests it gates behaviour on
    # wandb >= 0.12.10 (presumably the "run already in progress" warning path; confirm).
    monkeypatch.setattr(imports, "_WANDB_GREATER_EQUAL_0_12_10", True)
    with pytest.warns(UserWarning, match="There is a wandb run already in progress"):
        logger = WandbLogger()
    # check that no new run is created
    with no_warning_call(UserWarning, match="There is a wandb run already in progress"):
        _ = logger.experiment

    # verify default resume value
    assert logger._wandb_init["resume"] == "allow"

    # Metrics logged with a step are expected to carry it under "trainer/global_step".
    logger.log_metrics({"acc": 1.0}, step=3)
    wandb.init.assert_called_once()
    wandb.init().log.assert_called_once_with({
        "acc": 1.0,
        "trainer/global_step": 3
    })

    # continue training on same W&B run and offset step
    logger.finalize("success")
    logger.log_metrics({"acc": 1.0}, step=6)
    # assert_called_with checks only the most recent log call.
    wandb.init().log.assert_called_with({"acc": 1.0, "trainer/global_step": 6})

    # log hyper parameters
    # Expected: None is logged as the string "None", and nested dicts are flattened
    # with "/" separators.
    logger.log_hyperparams({"test": None, "nested": {"a": 1}, "b": [2, 3, 4]})
    wandb.init().config.update.assert_called_once_with(
        {
            "test": "None",
            "nested/a": 1,
            "b": [2, 3, 4]
        },
        allow_val_change=True)

    # watch a model
    # Positional arguments are expected to be forwarded to run.watch as keywords.
    logger.watch("model", "log", 10, False)
    wandb.init().watch.assert_called_once_with("model", log="log", log_freq=10, log_graph=False)

    # name/version are read from the (mocked) run object.
    assert logger.name == wandb.init().project_name()
    assert logger.version == wandb.init().id