def test_warn_once():
    """Verify the session-misuse warning is emitted at most once per function.

    Every session API below is invoked outside of an active session, some of
    them more than once. Repeat calls to an already-warned function must not
    produce further warnings.
    """
    # (callable, keyword arguments) in the exact order they are exercised.
    invocations = (
        (load_checkpoint, {}),
        (load_checkpoint, {}),
        (save_checkpoint, {"x": 2}),
        (report, {"x": 2}),
        (report, {"x": 3}),
        (get_dataset_shard, {}),
    )

    with pytest.warns(UserWarning) as caught:
        for session_fn, kwargs in invocations:
            assert not session_fn(**kwargs)

    # Four distinct functions were called, so exactly four warnings are expected.
    assert len(caught) == 4
def train_func():
    """Save a checkpoint per epoch and confirm it can be loaded back right away."""
    for epoch_idx in range(2):
        save_checkpoint(epoch=epoch_idx)
        latest = load_checkpoint()
        # The checkpoint just saved should be the one returned by the load.
        assert latest["epoch"] == epoch_idx