コード例 #1
0
def test_early_stopping_log_info(tmpdir, trainer, log_rank_zero_only,
                                 world_size, global_rank, expected_log):
    """Check the message that ``EarlyStopping._log_info`` hands to ``log.info``.

    ``log.info`` in ``pytorch_lightning.callbacks.early_stopping`` is patched
    with a mock. ``EarlyStopping._log_info`` is then invoked with the message
    ``"bar"``, and the mock's calls are inspected.

    Args:
        tmpdir: Not referenced in this body.
        trainer: Trainer-like object, or a falsy value (e.g. ``None``). When
            truthy, its ``strategy.global_rank`` / ``strategy.world_size`` are
            overwritten with the parametrized values before logging.
        log_rank_zero_only: Forwarded unchanged to ``_log_info``.
        world_size: Value assigned to ``trainer.strategy.world_size``.
        global_rank: Value assigned to ``trainer.strategy.global_rank``.
        expected_log: Exact message ``log.info`` must receive exactly once,
            or a falsy value meaning ``log.info`` must not be called. It is
            forced to ``"bar"`` whenever ``trainer`` is falsy.
    """
    message = "bar"

    if not trainer:
        # Without a trainer there is no rank information to apply, so the
        # expectation is overridden to the plain message.
        expected_log = message
    else:
        # Configure the distributed setup that _log_info will observe.
        trainer.strategy.global_rank = global_rank
        trainer.strategy.world_size = world_size

    patch_target = "pytorch_lightning.callbacks.early_stopping.log.info"
    with mock.patch(patch_target) as info_spy:
        EarlyStopping._log_info(trainer, message, log_rank_zero_only)

    # A falsy expectation means the logging call must have been suppressed
    # (presumably a non-zero rank combined with rank-zero-only logging;
    # confirm against the parametrize table).
    if not expected_log:
        info_spy.assert_not_called()
    else:
        info_spy.assert_called_once_with(expected_log)