def test_cache():
    """Exercise ``Cache`` on top of a batched stream over ``range(100)``.

    The source stream is batched with ``ConstantScheme(11)`` and then
    wrapped in a ``Cache`` driven by ``ConstantScheme(7)``. Three things
    are checked: the length of ``cache[0]`` after each served batch, the
    state at the end of a full epoch, and the behaviour across two epochs
    when the cache is driven by ``ConstantScheme(7, times=3)``.
    """
    source = DataStream(IterableDataset(range(100)))
    batched_stream = Batch(source, ConstantScheme(11))
    cached_stream = Cache(batched_stream, ConstantScheme(7))

    # Cache fill level expected after each of the first twelve batches.
    # The numbers are consistent with the cache gaining 11 items whenever
    # it holds fewer than 7, and then giving up 7 per request.
    expected_fill = [4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 0, 4]
    epoch_iterator = cached_stream.get_epoch_iterator()
    # zip stops after the twelve listed sizes, so the epoch is deliberately
    # left unfinished here.
    for (batch,), expected_size in zip(epoch_iterator, expected_fill):
        assert len(cached_stream.cache[0]) == expected_size

    # Run a fresh epoch to exhaustion: the last batch holds the remainder
    # and nothing is left in the cache afterwards.
    for (batch,) in cached_stream.get_epoch_iterator():
        pass
    assert len(batch) == 100 % 7
    assert not cached_stream.cache[0]

    # Epoch transition: both epochs must yield the same three batches.
    cached_stream = Cache(batched_stream, ConstantScheme(7, times=3))
    all_examples = list(range(100))
    fill_per_batch = [4, 8, 1]
    for _, epoch in zip(range(2), cached_stream.iterate_epochs()):
        for index, (batch,) in enumerate(epoch):
            assert len(cached_stream.cache[0]) == fill_per_batch[index]
            assert len(batch) == 7
            start = index * 7
            expected = all_examples[start:start + 7]
            assert numpy.all(expected == batch)
        # Exactly three batches (indices 0..2) were seen in this epoch.
        assert index == 2
def test_cache():
    """Check cache fill levels, epoch completion and epoch transitions.

    Builds ``Cache(Batch(DataStream(IterableDataset(range(100))), ...))``
    with an inner batch size of 11 and a cache request size of 7, then
    asserts on ``cached_stream.cache[0]`` and on the batches it yields.
    """
    request_size = 7
    dataset = IterableDataset(range(100))
    batched_stream = Batch(DataStream(dataset), ConstantScheme(11))
    cached_stream = Cache(batched_stream, ConstantScheme(request_size))

    # Part 1: the cache holds the expected number of items after each of
    # the first twelve requests (zip truncates the epoch to that many).
    served = zip(cached_stream.get_epoch_iterator(),
                 [4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 0, 4])
    for (chunk,), fill_level in served:
        assert len(cached_stream.cache[0]) == fill_level

    # Part 2: a complete epoch ends on the remainder-sized batch and
    # leaves the cache empty.
    for (chunk,) in cached_stream.get_epoch_iterator():
        pass
    assert len(chunk) == 100 % 7
    assert not cached_stream.cache[0]

    # Part 3: with the request scheme limited to three batches per epoch,
    # two consecutive epochs must look the same.
    cached_stream = Cache(batched_stream, ConstantScheme(7, times=3))
    epochs = cached_stream.iterate_epochs()
    for _, epoch in zip(range(2), epochs):
        cache_sizes = [4, 8, 1]
        for position, (chunk,) in enumerate(epoch):
            assert len(cached_stream.cache[0]) == cache_sizes[position]
            assert len(chunk) == 7
            lower = position * 7
            upper = (position + 1) * 7
            assert numpy.all(list(range(100))[lower:upper] == chunk)
        # The last index seen confirms the epoch produced three batches.
        assert position == 2
def test_epoch_finishes_correctly(self):
    """Check the last batch size and the emptied cache after a full epoch.

    Two scenarios are covered: ``self.stream`` cached with request size 7,
    and a 3000-example stream batched by 3200 and cached with request
    size 64. NOTE(review): ``self.stream`` is built outside this method;
    the ``100 % 7`` expectation presumably reflects 100 examples.
    """
    # Scenario 1: the fixture stream, served in requests of 7.
    small_cache = Cache(self.stream, ConstantScheme(7))
    batches = list(small_cache.get_epoch_iterator())
    final_features = batches[-1][0]
    assert_equal(len(final_features), 100 % 7)
    assert not small_cache.cache[0]

    # Scenario 2: the inner batch size (3200) is larger than the dataset
    # (3000 examples), and the cache serves requests of 64.
    dataset = IterableDataset(range(3000))
    oversized_batches = Batch(DataStream(dataset), ConstantScheme(3200))
    large_cache = Cache(oversized_batches, ConstantScheme(64))
    batches = list(large_cache.get_epoch_iterator())
    final_features = batches[-1][0]
    assert_equal(len(final_features), 3000 % 64)
    assert not large_cache.cache[0]
def test_epoch_finishes_correctly(self):
    """Verify how an epoch ends for two different cache configurations.

    After consuming a whole epoch, the first source of the final batch
    must have the remainder length and ``cache[0]`` must be empty.
    NOTE(review): ``self.stream`` comes from outside this method;
    confirm it yields 100 examples, as ``100 % 7`` suggests.
    """
    cached_stream = Cache(self.stream, ConstantScheme(7))
    epoch_data = list(cached_stream.get_epoch_iterator())
    last_batch = epoch_data[-1]
    assert_equal(len(last_batch[0]), 100 % 7)
    assert not cached_stream.cache[0]

    # Repeat with 3000 examples behind a single 3200-sized inner batch
    # scheme, pulling 64 examples per request from the cache.
    examples = DataStream(IterableDataset(range(3000)))
    stream = Batch(examples, ConstantScheme(3200))
    cached_stream = Cache(stream, ConstantScheme(64))
    epoch_data = list(cached_stream.get_epoch_iterator())
    last_batch = epoch_data[-1]
    assert_equal(len(last_batch[0]), 3000 % 64)
    assert not cached_stream.cache[0]
def test_epoch_finishes_correctly(self):
    """Check the final batch length and empty cache after one full epoch.

    NOTE(review): ``self.stream`` is defined outside this method; the
    ``100 % 7`` expectation presumably reflects a 100-example stream.
    """
    cache = Cache(self.stream, ConstantScheme(7))
    # Materialise the whole epoch so the last batch can be inspected.
    batches = list(cache.get_epoch_iterator())
    final_features = batches[-1][0]
    assert_equal(len(final_features), 100 % 7)
    # Nothing may be left behind in the cache once the epoch is over.
    assert not cache.cache[0]
def test_cache_fills_correctly(self):
    """Check the length of ``cache[0]`` after each of the first 12 batches.

    NOTE(review): the expected sizes depend on how ``self.stream`` is
    batched, which is set up outside this method.
    """
    expected_fill = [4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 0, 4]
    cache = Cache(self.stream, ConstantScheme(7))
    batches = cache.get_epoch_iterator()
    # zip stops once the listed sizes run out, so only the first twelve
    # batches of the epoch are pulled. Each item is unpacked as a 1-tuple.
    for (_features,), expected_size in zip(batches, expected_fill):
        assert_equal(len(cache.cache[0]), expected_size)
def test_value_error_on_none_request(self):
    """Check that ``get_data(None)`` raises ``ValueError`` on a ``Cache``."""
    cache = Cache(self.stream, ConstantScheme(7))
    # An epoch iterator is created first and its return value discarded;
    # only the call itself matters before get_data is exercised.
    cache.get_epoch_iterator()
    bad_request = None
    assert_raises(ValueError, cache.get_data, bad_request)
def test_value_error_on_none_request(self):
    """Requesting data with ``None`` from a ``Cache`` must fail.

    The cache wraps ``self.stream`` with a request scheme of size 7, and
    an epoch iterator is created before ``get_data`` is called.
    """
    stream_under_test = Cache(self.stream, ConstantScheme(7))
    stream_under_test.get_epoch_iterator()
    # assert_raises invokes get_data(None) and expects a ValueError.
    fetch = stream_under_test.get_data
    assert_raises(ValueError, fetch, None)
def test_epoch_finishes_correctly(self):
    """Consume one epoch and check how it ends.

    The first source of the last batch must contain ``100 % 7`` items and
    ``cache[0]`` must be empty afterwards. NOTE(review): ``self.stream``
    is created outside this method; confirm it provides 100 examples.
    """
    cached_stream = Cache(self.stream, ConstantScheme(7))
    epoch_batches = list(cached_stream.get_epoch_iterator())
    last_batch = epoch_batches[-1]
    remainder = 100 % 7
    assert_equal(len(last_batch[0]), remainder)
    assert not cached_stream.cache[0]
def test_cache_fills_correctly(self):
    """Track the cache fill level while batches are being served.

    After every batch taken from the epoch iterator, the length of
    ``cached_stream.cache[0]`` is compared against a fixed sequence.
    NOTE(review): the sequence depends on the batching of ``self.stream``,
    which is configured outside this method.
    """
    cached_stream = Cache(self.stream, ConstantScheme(7))
    served = zip(cached_stream.get_epoch_iterator(),
                 [4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 0, 4])
    # Only the first twelve batches are examined; zip ends with the list.
    for (batch,), fill_level in served:
        assert_equal(len(cached_stream.cache[0]), fill_level)