def test_data_driven_epochs():
    """Check that a dataset can decide where its own epochs begin and end.

    ``TestDataset`` serves every inner list of ``data`` as one epoch. Its
    state is a pair ``(rows, current_row)`` of iterators, and ``get_data``
    lets ``StopIteration`` escape once the current row is used up. The test
    expects the stream to alternate between the two rows and to wrap back to
    the first row after the last one. It also expects the iteration scheme's
    request sizes to start over at the beginning of every epoch.
    """
    class TestDataset(IterableDataset):
        sources = ('data',)

        def __init__(self):
            self.axis_labels = None
            # Each inner list is served as a separate epoch.
            self.data = [[1, 2, 3, 4], [5, 6, 7, 8]]

        def open(self):
            # State is (iterator over rows, iterator over the current row).
            rows = iter(self.data)
            return (rows, iter(next(rows)))

        def next_epoch(self, state):
            rows = state[0]
            try:
                row = next(rows)
            except StopIteration:
                # Every row has been served: wrap around to the first one.
                return self.open()
            return (rows, iter(row))

        def get_data(self, state, request):
            # A list comprehension (not a generator expression) lets the
            # StopIteration raised by an exhausted row propagate unchanged;
            # the expected epochs below rely on that to end each epoch.
            return ([next(state[1]) for _ in range(request)],)

    # With a constant request size of 1, each epoch is one row, served one
    # example per batch.
    one_at_a_time = [
        [([1],), ([2],), ([3],), ([4],)],
        [([5],), ([6],), ([7],), ([8],)],
    ]
    stream = DataStream(TestDataset(), iteration_scheme=ConstantScheme(1))
    # Successive epoch iterators alternate rows and wrap around after the last.
    for expected in (one_at_a_time[0], one_at_a_time[1], one_at_a_time[0]):
        assert list(stream.get_epoch_iterator()) == expected

    # After reset() the stream is expected to start over from the first row.
    # The expected list is zipped first so exactly two epochs are pulled.
    stream.reset()
    for expected, epoch in zip(one_at_a_time, stream.iterate_epochs()):
        assert list(epoch) == expected

    # Test scheme resetting between epochs. Request sizes must restart at
    # the beginning of each epoch. The trailing request of 3 finds the row
    # already exhausted, so each epoch has only three batches.
    class TestScheme(BatchSizeScheme):
        def get_request_iterator(self):
            return iter([1, 2, 1, 3])

    batched = [
        [([1],), ([2, 3],), ([4],)],
        [([5],), ([6, 7],), ([8],)],
    ]
    stream = DataStream(TestDataset(), iteration_scheme=TestScheme())
    for expected, epoch in zip(batched, stream.iterate_epochs()):
        assert list(epoch) == expected