コード例 #1
0
def test_early_stopping_squeezes():
    """Check the value handed to the stopping criteria for a nested metric.

    The monitored metric ``"foo"`` is stored as a tensor of shape (1, 1, 1).
    With ``_evaluate_stopping_criteria`` patched out, running the early
    stopping check must call it exactly once with a 0-dim tensor equal to 0.
    """
    stopper = EarlyStopping(monitor="foo")
    trainer = Trainer()
    nested_metric = torch.tensor([[[0]]])
    trainer.callback_metrics["foo"] = nested_metric

    # Patch the criteria evaluation so that only the argument passed to it
    # is observed; the patched method reports "do not stop" with no reason.
    criteria_patch = mock.patch(
        "pytorch_lightning.callbacks.EarlyStopping._evaluate_stopping_criteria",
        return_value=(False, ""),
    )
    with criteria_patch as criteria_mock:
        stopper._run_early_stopping_check(trainer)

    expected_argument = torch.tensor(0)
    criteria_mock.assert_called_once_with(expected_argument)
コード例 #2
0
def test_early_stopping_no_extraneous_invocations(tmpdir):
    """Ensure the early stopping check is not invoked outside of the callback handler.

    The callback's ``_run_early_stopping_check`` is replaced by a ``Mock`` and
    the trainer is fit for ``expected_count`` epochs. Afterwards the mock must
    have been called exactly ``expected_count`` times, and the trainer must
    expose the callback through both of its early-stopping accessors.
    """
    expected_count = 4

    model = ClassificationModel()
    dm = ClassifDataModule()

    callback = EarlyStopping(monitor="train_loss")
    # Swap the real check for a mock so that every invocation is counted.
    callback._run_early_stopping_check = Mock()

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        callbacks=[callback],
        limit_train_batches=4,
        limit_val_batches=4,
        max_epochs=expected_count,
        enable_checkpointing=False,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model, datamodule=dm)

    # The trainer should report the callback via both accessors.
    assert trainer.early_stopping_callback == callback
    assert trainer.early_stopping_callbacks == [callback]

    observed_calls = callback._run_early_stopping_check.call_count
    assert observed_calls == expected_count