예제 #1
0
def test_checkpoint():
    """Checkpoints saved inside the training function surface via get_next().

    With ``world_rank=0`` each CHECKPOINT result carries the saved payload.
    With ``world_rank=1`` each CHECKPOINT result carries an empty dict.
    Once the session has been shut down, ``save_checkpoint`` must raise
    ``ValueError``.
    """

    def train_func():
        for i in range(2):
            save_checkpoint(epoch=i)

    def validate_zero(expected):
        # Named ``result`` (not ``next``) so the builtin is not shadowed.
        result = session.get_next()
        assert result is not None
        assert result.type == TrainingResultType.CHECKPOINT
        assert result.data["epoch"] == expected

    init_session(training_func=train_func, world_rank=0, local_rank=0)
    try:
        session = get_session()
        session.start()
        validate_zero(0)
        validate_zero(1)
        session.finish()
    finally:
        # Always tear the session down, even if an assertion above failed,
        # so the next phase (and presumably later tests) starts clean.
        shutdown_session()

    def validate_nonzero():
        result = session.get_next()
        assert result is not None
        assert result.type == TrainingResultType.CHECKPOINT
        # Non-zero ranks report the checkpoint event without its payload.
        assert result.data == {}

    init_session(training_func=train_func, world_rank=1, local_rank=1)
    try:
        session = get_session()
        session.start()
        validate_nonzero()
        validate_nonzero()
        session.finish()
    finally:
        shutdown_session()

    # With no active session, saving a checkpoint is an error.
    with pytest.raises(ValueError):
        save_checkpoint(epoch=2)
예제 #2
0
def test_warn_once():
    """Misusing a session API outside a session warns once per API, not per call."""
    # (function, keyword arguments) in call order. load_checkpoint and report
    # are invoked twice on purpose to exercise the de-duplication.
    misuse_calls = (
        (load_checkpoint, {}),
        (load_checkpoint, {}),
        (save_checkpoint, {"x": 2}),
        (report, {"x": 2}),
        (report, {"x": 3}),
        (get_dataset_shard, {}),
    )

    with pytest.warns(UserWarning) as captured:
        for func, kwargs in misuse_calls:
            # Outside a session each of these calls returns a falsy value.
            assert not func(**kwargs)

    # Four distinct APIs were misused, so exactly four warnings are expected.
    assert len(captured) == 4
예제 #3
0
 def train_func():
     """Save a checkpoint for epochs 0 and 1, reading each one straight back."""
     for epoch in range(2):
         save_checkpoint(epoch=epoch)
         restored = load_checkpoint()
         # The checkpoint just loaded must be the one just saved.
         assert restored["epoch"] == epoch
예제 #4
0
 def train_func():
     """Save a checkpoint tagged epoch=0, then report with epoch=1."""
     checkpoint_epoch = 0
     save_checkpoint(epoch=checkpoint_epoch)
     report(epoch=checkpoint_epoch + 1)
예제 #5
0
 def train_func():
     """Save two checkpoints, tagged with epochs 0 and 1 in that order."""
     for epoch in (0, 1):
         save_checkpoint(epoch=epoch)