def test_logging_manager_no_checkpointing(caplog):
    """Unit test of logging_manager (no checkpointing).

    Configures an epoch-based ``LoggingManager`` (``evaluation_freq=1``,
    ``n_batches_per_epoch=2``) with ``checkpointing`` disabled, then feeds it
    four ``update`` calls. Evaluation is expected to trigger on every second
    call, while checkpointing must never trigger even though a
    ``checkpoint_freq`` is configured.
    """

    caplog.set_level(logging.INFO)

    # Start from a clean Meta state (test_logging_manager_sample does the
    # same) so configuration left behind by earlier tests cannot leak in.
    Meta.reset()

    emmental.init()
    Meta.update_config(
        config={
            "meta_config": {"verbose": False},
            "logging_config": {
                "counter_unit": "epoch",
                "evaluation_freq": 1,
                "checkpointing": False,
                # Set on purpose: it must be ignored while checkpointing is off.
                "checkpointer_config": {"checkpoint_freq": 2},
                "writer_config": {"writer": "json"},
            },
        }
    )

    logging_manager = LoggingManager(n_batches_per_epoch=2)

    # 1st batch (half an epoch at 2 batches/epoch): too early to evaluate.
    logging_manager.update(5)
    assert logging_manager.trigger_evaluation() is False
    assert logging_manager.trigger_checkpointing() is False

    # 2nd batch completes the first epoch: evaluation fires.
    logging_manager.update(5)
    assert logging_manager.trigger_evaluation() is True
    assert logging_manager.trigger_checkpointing() is False

    # 3rd batch: half-way through the second epoch.
    logging_manager.update(10)
    assert logging_manager.trigger_evaluation() is False
    assert logging_manager.trigger_checkpointing() is False

    # 4th batch completes the second epoch: evaluation fires again.
    logging_manager.update(5)
    assert logging_manager.trigger_evaluation() is True
    assert logging_manager.trigger_checkpointing() is False

    # The per-interval epoch counter is expected to be back at 0 right
    # after the evaluation that just triggered.
    assert logging_manager.epoch_count == 0

    # Running totals: 5 + 5 + 10 + 5 samples over 4 batches (2 epochs).
    assert logging_manager.sample_total == 25
    assert logging_manager.batch_total == 4
    assert logging_manager.epoch_total == 2

    model = EmmentalModel()

    logging_manager.close(model)
# Example #2
def test_logging_manager_sample(caplog):
    """Unit test of logging_manager (sample).

    Drives a ``LoggingManager`` whose counters are measured in samples
    (``evaluation_freq=10``, ``checkpoint_freq=2``, ``n_batches_per_epoch=10``)
    through four ``update`` calls of 5, 5, 10 and 5 samples. It checks when
    evaluation and checkpointing trigger, and checks the running sample,
    batch and epoch totals.
    """

    caplog.set_level(logging.INFO)

    # Presumably clears Meta state left by earlier tests before
    # re-initializing; confirm against Meta.reset.
    Meta.reset()

    emmental.init()
    Meta.update_config(
        config={
            "logging_config": {
                "counter_unit": "sample",
                "evaluation_freq": 10,
                "checkpointing": True,
                "checkpointer_config": {
                    "checkpoint_freq": 2
                },
            }
        })

    logging_manager = LoggingManager(n_batches_per_epoch=10)

    # 5 samples so far: below evaluation_freq (10), so nothing triggers.
    logging_manager.update(5)
    assert logging_manager.trigger_evaluation() is False
    assert logging_manager.trigger_checkpointing() is False

    # 10 samples: first evaluation point; checkpoint not yet due.
    logging_manager.update(5)
    assert logging_manager.trigger_evaluation() is True
    assert logging_manager.trigger_checkpointing() is False

    # 20 samples: second evaluation point. The assertions expect a checkpoint
    # here, which suggests checkpoint_freq counts evaluation intervals.
    # NOTE(review): confirm this against LoggingManager.
    logging_manager.update(10)
    assert logging_manager.trigger_evaluation() is True
    assert logging_manager.trigger_checkpointing() is True

    # 25 samples: only 5 since the last evaluation, so nothing triggers.
    logging_manager.update(5)
    assert logging_manager.trigger_evaluation() is False
    assert logging_manager.trigger_checkpointing() is False

    # sample_count == 5 implies the per-interval counter restarts after a
    # triggered evaluation; sample_total keeps the running sum.
    assert logging_manager.sample_count == 5
    assert logging_manager.sample_total == 25

    # Four update calls are counted as four batches. 0.4 epochs is
    # consistent with 4 batches / 10 batches per epoch.
    # NOTE(review): exact float equality; pytest.approx would be more robust.
    assert logging_manager.batch_total == 4
    assert logging_manager.epoch_total == 0.4