def test_cache():
    """Check that ``Cache`` re-chunks batches and tracks its buffer correctly.

    Batches of 11 examples drawn from ``range(100)`` are re-served as
    requests of 7 examples. The test checks:

    1. the number of examples left in the cache after each request;
    2. that the last request of an epoch gets the ``100 % 7`` left-over
       examples and leaves the cache empty;
    3. that, with a ``times=3`` scheme, each of two consecutive epochs
       yields exactly three requests starting from the beginning of the data.
    """
    dataset = IterableDataset(range(100))
    stream = DataStream(dataset)
    batched_stream = Batch(stream, ConstantScheme(11))
    cached_stream = Cache(batched_stream, ConstantScheme(7))
    epoch = cached_stream.get_epoch_iterator()

    # Make sure that cache is filled as expected. The sizes are consistent
    # with a new batch of 11 being fetched whenever fewer than 7 examples
    # remain cached before a request of 7.
    expected_cache_sizes = [4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 0, 4]
    num_checked = 0
    for (features, ), cache_size in zip(epoch, expected_cache_sizes):
        assert len(cached_stream.cache[0]) == cache_size
        num_checked += 1
    # zip() stops silently at the shorter input; make sure every expected
    # cache size was actually compared.
    assert num_checked == len(expected_cache_sizes)

    # Make sure that the epoch finishes correctly: the final request only
    # receives the left-over examples and the cache ends up empty.
    last_features = None
    for (features, ) in cached_stream.get_epoch_iterator():
        last_features = features
    assert last_features is not None
    assert len(last_features) == 100 % 7
    assert not cached_stream.cache[0]

    # Ensure that the epoch transition is correct.
    cached_stream = Cache(batched_stream, ConstantScheme(7, times=3))
    cache_sizes = [4, 8, 1]
    num_epochs = 0
    for _, epoch in zip(range(2), cached_stream.iterate_epochs()):
        # Count batches per epoch rather than reading the loop index after
        # the loop: an empty epoch would otherwise reuse the stale index
        # left over from the previous epoch and pass vacuously.
        num_batches = 0
        for i, (features, ) in enumerate(epoch):
            assert len(cached_stream.cache[0]) == cache_sizes[i]
            assert len(features) == 7
            assert numpy.all(list(range(100))[i * 7:(i + 1) * 7] == features)
            num_batches += 1
        assert num_batches == 3
        num_epochs += 1
    assert num_epochs == 2
示例#2
0
def test_cache():
    """Check that ``Cache`` re-chunks batches and tracks its buffer correctly.

    Batches of 11 examples drawn from ``range(100)`` are re-served as
    requests of 7 examples. The test checks:

    1. the number of examples left in the cache after each request;
    2. that the last request of an epoch gets the ``100 % 7`` left-over
       examples and leaves the cache empty;
    3. that, with a ``times=3`` scheme, each of two consecutive epochs
       yields exactly three requests starting from the beginning of the data.
    """
    dataset = IterableDataset(range(100))
    stream = DataStream(dataset)
    batched_stream = Batch(stream, ConstantScheme(11))
    cached_stream = Cache(batched_stream, ConstantScheme(7))
    epoch = cached_stream.get_epoch_iterator()

    # Make sure that cache is filled as expected. The sizes are consistent
    # with a new batch of 11 being fetched whenever fewer than 7 examples
    # remain cached before a request of 7.
    expected_cache_sizes = [4, 8, 1, 5, 9, 2,
                            6, 10, 3, 7, 0, 4]
    num_checked = 0
    for (features,), cache_size in zip(epoch, expected_cache_sizes):
        assert len(cached_stream.cache[0]) == cache_size
        num_checked += 1
    # zip() stops silently at the shorter input; make sure every expected
    # cache size was actually compared.
    assert num_checked == len(expected_cache_sizes)

    # Make sure that the epoch finishes correctly: the final request only
    # receives the left-over examples and the cache ends up empty.
    last_features = None
    for (features,) in cached_stream.get_epoch_iterator():
        last_features = features
    assert last_features is not None
    assert len(last_features) == 100 % 7
    assert not cached_stream.cache[0]

    # Ensure that the epoch transition is correct.
    cached_stream = Cache(batched_stream, ConstantScheme(7, times=3))
    cache_sizes = [4, 8, 1]
    num_epochs = 0
    for _, epoch in zip(range(2), cached_stream.iterate_epochs()):
        # Count batches per epoch rather than reading the loop index after
        # the loop: an empty epoch would otherwise reuse the stale index
        # left over from the previous epoch and pass vacuously.
        num_batches = 0
        for i, (features,) in enumerate(epoch):
            assert len(cached_stream.cache[0]) == cache_sizes[i]
            assert len(features) == 7
            assert numpy.all(list(range(100))[i * 7:(i + 1) * 7] == features)
            num_batches += 1
        assert num_batches == 3
        num_epochs += 1
    assert num_epochs == 2
示例#3
0
 def test_epoch_transition(self):
     """Check that each epoch of a ``times=3`` ``Cache`` restarts the data.

     Over two epochs, every epoch must yield exactly three requests of 7
     examples matching ``range(100)`` from the start, with 4, 8 and 1
     examples left in the cache after the respective requests.
     """
     cached_stream = Cache(self.stream, ConstantScheme(7, times=3))
     cache_sizes = [4, 8, 1]
     num_epochs = 0
     for _, epoch in zip(range(2), cached_stream.iterate_epochs()):
         # Count batches per epoch rather than reading the loop index after
         # the loop: an empty epoch would otherwise reuse the stale index
         # left over from the previous epoch and pass vacuously.
         num_batches = 0
         for i, (features,) in enumerate(epoch):
             assert_equal(len(cached_stream.cache[0]), cache_sizes[i])
             assert_equal(len(features), 7)
             assert_equal(list(range(100))[i * 7:(i + 1) * 7], features)
             num_batches += 1
         assert_equal(num_batches, 3)
         num_epochs += 1
     # zip() stops silently if fewer epochs are produced; make sure both
     # epochs were actually checked.
     assert_equal(num_epochs, 2)
示例#4
0
 def test_epoch_transition(self):
     """Check that each epoch of a ``times=3`` ``Cache`` restarts the data.

     Over two epochs, every epoch must yield exactly three requests of 7
     examples matching ``range(100)`` from the start, with 4, 8 and 1
     examples left in the cache after the respective requests.
     """
     cached_stream = Cache(self.stream, ConstantScheme(7, times=3))
     cache_sizes = [4, 8, 1]
     num_epochs = 0
     for _, epoch in zip(range(2), cached_stream.iterate_epochs()):
         # Count batches per epoch rather than reading the loop index after
         # the loop: an empty epoch would otherwise reuse the stale index
         # left over from the previous epoch and pass vacuously.
         num_batches = 0
         for i, (features,) in enumerate(epoch):
             assert_equal(len(cached_stream.cache[0]), cache_sizes[i])
             assert_equal(len(features), 7)
             assert_equal(list(range(100))[i * 7:(i + 1) * 7], features)
             num_batches += 1
         assert_equal(num_batches, 3)
         num_epochs += 1
     # zip() stops silently if fewer epochs are produced; make sure both
     # epochs were actually checked.
     assert_equal(num_epochs, 2)