def test_early_stopping_squeezes():
    """Check that the monitored metric reaches the stopping criteria as a 0-dim tensor.

    The metric is stored in ``trainer.callback_metrics`` with shape ``(1, 1, 1)``;
    ``_evaluate_stopping_criteria`` is patched out so that the argument it receives
    from ``_run_early_stopping_check`` can be inspected.
    """
    early_stopping = EarlyStopping(monitor="foo")
    trainer = Trainer()
    trainer.callback_metrics["foo"] = torch.tensor([[[0]]])

    with mock.patch(
        "pytorch_lightning.callbacks.EarlyStopping._evaluate_stopping_criteria",
        return_value=(False, ""),
    ) as es_mock:
        early_stopping._run_early_stopping_check(trainer)

    es_mock.assert_called_once_with(torch.tensor(0))

    # The mock compares arguments with ``==``, and a one-element tensor of any
    # shape compares truthy against ``torch.tensor(0)`` through broadcasting.
    # Check the dimensionality explicitly so a missing squeeze is detected.
    monitored = es_mock.call_args[0][0]
    assert monitored.dim() == 0
def test_early_stopping_no_extraneous_invocations(tmpdir):
    """Verify the early-stopping check is only triggered through the callback handler.

    ``_run_early_stopping_check`` is swapped for a ``Mock`` so its invocations can be
    counted; after fitting, the count must equal the ``max_epochs`` given to the
    trainer.
    """
    num_epochs = 4

    model = ClassificationModel()
    datamodule = ClassifDataModule()

    stopper = EarlyStopping(monitor="train_loss")
    # Replace the real check with a mock that only records calls.
    stopper._run_early_stopping_check = Mock()

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[stopper],
        limit_train_batches=4,
        limit_val_batches=4,
        max_epochs=num_epochs,
        enable_checkpointing=False,
    )
    trainer.fit(model, datamodule=datamodule)

    # The trainer exposes the callback both as a single object and as a list.
    assert trainer.early_stopping_callback == stopper
    assert trainer.early_stopping_callbacks == [stopper]

    check_mock = stopper._run_early_stopping_check
    assert check_mock.call_count == num_epochs