def test_checkpoint():
    """Check what ``save_checkpoint`` surfaces for rank 0 vs. non-zero ranks.

    The same training function (two checkpoints, epochs 0 and 1) is run in
    two consecutive sessions:

    * with ``world_rank=0`` each result must be a CHECKPOINT whose data
      carries the saved ``epoch``;
    * with ``world_rank=1`` each result must still be a CHECKPOINT, but with
      empty data.

    Finally, once the session has been shut down, calling
    ``save_checkpoint`` must raise ``ValueError``.
    """

    def train_func():
        for epoch in range(2):
            save_checkpoint(epoch=epoch)

    def next_checkpoint(session):
        """Fetch the next result from *session* and check it is a checkpoint."""
        # Named ``result`` rather than ``next`` to avoid shadowing the builtin.
        result = session.get_next()
        assert result is not None
        assert result.type == TrainingResultType.CHECKPOINT
        return result

    # Rank 0: checkpoint data is passed through.
    init_session(training_func=train_func, world_rank=0, local_rank=0)
    try:
        session = get_session()
        session.start()
        for expected_epoch in range(2):
            assert next_checkpoint(session).data["epoch"] == expected_epoch
        session.finish()
    finally:
        # Always release the session so a failed assertion above does not
        # leave it initialized for the second phase or for later tests.
        shutdown_session()

    # Non-zero rank: a checkpoint result is still produced, but without data.
    init_session(training_func=train_func, world_rank=1, local_rank=1)
    try:
        session = get_session()
        session.start()
        for _ in range(2):
            assert next_checkpoint(session).data == {}
        session.finish()
    finally:
        shutdown_session()

    # With the session shut down, saving a checkpoint is an error.
    with pytest.raises(ValueError):
        save_checkpoint(epoch=2)
def test_warn_once():
    """Checks if session misuse warning is only shown once per function.

    Six calls are made across four distinct session functions; every call
    must return a falsy value, and exactly four ``UserWarning`` records are
    expected in total.
    """
    # (function, keyword arguments) pairs, invoked in this exact order.
    invocations = (
        (load_checkpoint, {}),
        (load_checkpoint, {}),
        (save_checkpoint, {"x": 2}),
        (report, {"x": 2}),
        (report, {"x": 3}),
        (get_dataset_shard, {}),
    )

    with pytest.warns(UserWarning) as record:
        for session_fn, kwargs in invocations:
            assert not session_fn(**kwargs)

        # Should only warn once per function: four distinct functions were
        # called, so four warnings despite the repeated calls.
        assert len(record) == 4
def train_func():
    """Save a checkpoint for epochs 0 and 1, reloading it after each save.

    After every ``save_checkpoint`` call the checkpoint returned by
    ``load_checkpoint`` must report the epoch that was just saved.
    """
    for epoch in range(2):
        save_checkpoint(epoch=epoch)
        restored = load_checkpoint()
        assert restored["epoch"] == epoch
def train_func():
    """Save one checkpoint (epoch 0), then make one report (epoch 1)."""
    checkpoint_epoch, report_epoch = 0, 1
    # Order matters: the checkpoint is emitted before the report.
    save_checkpoint(epoch=checkpoint_epoch)
    report(epoch=report_epoch)
def train_func():
    """Save two checkpoints in sequence, for epoch 0 and then epoch 1."""
    for epoch in (0, 1):
        save_checkpoint(epoch=epoch)