def test_data_driven_epochs():
    """Check that a dataset can decide where its own epochs end.

    The local dataset stores two lists, and each list is one epoch. Its
    ``next_epoch`` advances to the next list. Once every list has been
    used, it falls back to ``open`` (the first list again). Consecutive
    epoch iterators should therefore alternate between the two lists.
    """

    class TwoEpochDataset(IterableDataset):
        sources = ('data',)

        def __init__(self):
            self.axis_labels = None
            self.data = [[1, 2, 3, 4], [5, 6, 7, 8]]

        def open(self):
            # State is a pair: (iterator over the epoch lists,
            #                   iterator over the current epoch's items).
            epoch_iter = iter(self.data)
            data_iter = iter(next(epoch_iter))
            return (epoch_iter, data_iter)

        def next_epoch(self, state):
            try:
                next_epoch_data = next(state[0])
            except StopIteration:
                # Every epoch list has been used: start over from the first.
                return self.open()
            return (state[0], iter(next_epoch_data))

        def get_data(self, state, request):
            # ``next`` raises StopIteration once the current epoch's items
            # run out. It is deliberately left to propagate, and the
            # expected epochs below rely on it ending the epoch. (A list
            # comprehension, unlike a generator expression, lets it
            # propagate unchanged.)
            data = [next(state[1]) for _ in range(request)]
            return (data,)

    epochs = [
        [([1],), ([2],), ([3],), ([4],)],
        [([5],), ([6],), ([7],), ([8],)],
    ]
    stream = DataStream(TwoEpochDataset(),
                        iteration_scheme=ConstantScheme(1))
    assert list(stream.get_epoch_iterator()) == epochs[0]
    assert list(stream.get_epoch_iterator()) == epochs[1]
    # After the last epoch the dataset wraps around to the first one.
    assert list(stream.get_epoch_iterator()) == epochs[0]

    # After a reset, iteration is expected to start from the first epoch.
    # zip stops after len(epochs) epochs without pulling a third one.
    stream.reset()
    for expected, epoch in zip(epochs, stream.iterate_epochs()):
        assert list(epoch) == expected

    # test scheme resetting between epochs
    class TestScheme(BatchSizeScheme):
        def get_request_iterator(self):
            return iter([1, 2, 1, 3])

    # The final request (3) finds the epoch already exhausted, so each
    # epoch yields three batches. The second epoch again starts with a
    # 1-item batch, showing the request sequence restarts per epoch.
    epochs = [
        [([1],), ([2, 3],), ([4],)],
        [([5],), ([6, 7],), ([8],)],
    ]
    stream = DataStream(TwoEpochDataset(), iteration_scheme=TestScheme())
    for expected, epoch in zip(epochs, stream.iterate_epochs()):
        assert list(epoch) == expected
def test_dataset():
    """Exercise a default ``DataStream`` over a plain ``IterableDataset``."""
    data = [1, 2, 3]
    expected = list(zip(data))

    # Without an explicit iteration scheme the stream requests a single
    # example at a time.
    stream = DataStream(IterableDataset(data))
    assert list(stream.get_epoch_iterator()) == expected

    # Each of several consecutive epochs must replay the full data.
    for _, epoch in zip(range(2), stream.iterate_epochs()):
        assert list(epoch) == expected

    # Sources can also be returned as a dict keyed by source name.
    first_example = next(stream.get_epoch_iterator(as_dict=True))
    assert first_example == {"data": 1}
# NOTE(review): a second ``class TestDataset`` containing the same tests
# appears later in this module. If both are module-level, the later
# definition replaces this one and these tests never run. Confirm, then
# keep only one copy.
class TestDataset(object):
    """Unit tests for ``Dataset`` and ``IterableDataset``."""

    def setUp(self):
        self.data = [1, 2, 3]
        self.stream = DataStream(IterableDataset(self.data))

    def setup_method(self, method=None):
        # pytest calls ``setup_method`` (never ``setUp``) before each test
        # on classes that do not derive from ``unittest.TestCase``.
        # Delegating builds the fixture under pytest as well as under
        # nose-style runners, which call ``setUp`` themselves.
        self.setUp()

    def test_one_example_at_a_time(self):
        # The default stream yields one single-source example per step.
        assert_equal(list(self.stream.get_epoch_iterator()),
                     list(zip(self.data)))

    def test_multiple_epochs(self):
        for _, epoch in zip(range(2), self.stream.iterate_epochs()):
            assert_equal(list(epoch), list(zip(self.data)))

    def test_as_dict(self):
        assert_equal(next(self.stream.get_epoch_iterator(as_dict=True)),
                     {"data": 1})

    def test_value_error_on_no_provided_sources(self):
        # A subclass that does not declare ``provides_sources``.
        class FaultyDataset(Dataset):
            def get_data(self, state=None, request=None):
                pass
        assert_raises(ValueError, FaultyDataset, self.data)

    def test_value_error_on_nonexistent_sources(self):
        def instantiate_dataset():
            return IterableDataset(self.data, sources=('dummy',))
        assert_raises(ValueError, instantiate_dataset)

    def test_default_transformer(self):
        # A dataset may supply its own default transformer; this one
        # doubles every source value.
        class DoublingDataset(IterableDataset):
            def apply_default_transformer(self, stream):
                return Mapping(
                    stream, lambda sources: tuple(2 * s for s in sources))
        dataset = DoublingDataset(self.data)
        stream = dataset.apply_default_transformer(DataStream(dataset))
        assert_equal(list(stream.get_epoch_iterator()), [(2,), (4,), (6,)])

    def test_no_axis_labels(self):
        assert IterableDataset(self.data).axis_labels is None

    def test_axis_labels(self):
        axis_labels = {'data': ('batch',)}
        dataset = IterableDataset(self.data, axis_labels=axis_labels)
        assert_equal(dataset.axis_labels, axis_labels)

    def test_attribute_error_on_no_example_iteration_scheme(self):
        # No ``_example_iteration_scheme`` is defined on this subclass.
        class FaultyDataset(Dataset):
            provides_sources = ('data',)

            def get_data(self, state=None, request=None):
                pass

        def get_example_iteration_scheme():
            return FaultyDataset().example_iteration_scheme
        assert_raises(AttributeError, get_example_iteration_scheme)

    def test_example_iteration_scheme(self):
        scheme = ConstantScheme(2)

        class MinimalDataset(Dataset):
            provides_sources = ('data',)
            _example_iteration_scheme = scheme

            def get_data(self, state=None, request=None):
                pass
        assert MinimalDataset().example_iteration_scheme is scheme

    def test_filter_sources(self):
        # Only the sources requested at construction ('1') are kept.
        dataset = IterableDataset(
            OrderedDict([('1', [1, 2]), ('2', [3, 4])]), sources=('1',))
        assert_equal(dataset.filter_sources(([1, 2], [3, 4])), ([1, 2],))
class TestDataset(object):
    """Unit tests for ``Dataset`` and ``IterableDataset``."""

    def setUp(self):
        self.data = [1, 2, 3]
        self.stream = DataStream(IterableDataset(self.data))

    def setup_method(self, method=None):
        # pytest calls ``setup_method`` (never ``setUp``) before each test
        # on classes that do not derive from ``unittest.TestCase``.
        # Delegating builds the fixture under pytest as well as under
        # nose-style runners, which call ``setUp`` themselves.
        self.setUp()

    def test_one_example_at_a_time(self):
        # The default stream yields one single-source example per step.
        assert_equal(
            list(self.stream.get_epoch_iterator()), list(zip(self.data)))

    def test_multiple_epochs(self):
        for _, epoch in zip(range(2), self.stream.iterate_epochs()):
            assert_equal(list(epoch), list(zip(self.data)))

    def test_as_dict(self):
        assert_equal(
            next(self.stream.get_epoch_iterator(as_dict=True)), {"data": 1})

    def test_value_error_on_no_provided_sources(self):
        # A subclass that does not declare ``provides_sources``.
        class FaultyDataset(Dataset):
            def get_data(self, state=None, request=None):
                pass
        assert_raises(ValueError, FaultyDataset, self.data)

    def test_value_error_on_nonexistent_sources(self):
        def instantiate_dataset():
            return IterableDataset(self.data, sources=('dummy',))
        assert_raises(ValueError, instantiate_dataset)

    def test_default_transformer(self):
        # A dataset may supply its own default transformer; this one
        # doubles every source value.
        class DoublingDataset(IterableDataset):
            def apply_default_transformer(self, stream):
                return Mapping(
                    stream, lambda sources: tuple(2 * s for s in sources))
        dataset = DoublingDataset(self.data)
        stream = dataset.apply_default_transformer(DataStream(dataset))
        assert_equal(list(stream.get_epoch_iterator()), [(2,), (4,), (6,)])

    def test_no_axis_labels(self):
        assert IterableDataset(self.data).axis_labels is None

    def test_axis_labels(self):
        axis_labels = {'data': ('batch',)}
        dataset = IterableDataset(self.data, axis_labels=axis_labels)
        assert_equal(dataset.axis_labels, axis_labels)

    def test_attribute_error_on_no_example_iteration_scheme(self):
        # No ``_example_iteration_scheme`` is defined on this subclass.
        class FaultyDataset(Dataset):
            provides_sources = ('data',)

            def get_data(self, state=None, request=None):
                pass

        def get_example_iteration_scheme():
            return FaultyDataset().example_iteration_scheme
        assert_raises(AttributeError, get_example_iteration_scheme)

    def test_example_iteration_scheme(self):
        scheme = ConstantScheme(2)

        class MinimalDataset(Dataset):
            provides_sources = ('data',)
            _example_iteration_scheme = scheme

            def get_data(self, state=None, request=None):
                pass
        assert MinimalDataset().example_iteration_scheme is scheme

    def test_filter_sources(self):
        # Only the sources requested at construction ('1') are kept.
        dataset = IterableDataset(
            OrderedDict([('1', [1, 2]), ('2', [3, 4])]), sources=('1',))
        assert_equal(dataset.filter_sources(([1, 2], [3, 4])), ([1, 2],))