コード例 #1
0
ファイル: test_datasets.py プロジェクト: Afrik/fuel
def test_data_driven_epochs():
    """Check that a dataset can decide for itself where epochs end.

    The local dataset holds two sub-lists and serves one sub-list per
    epoch; running out of sub-lists makes it start over via ``open``.
    """
    class TestDataset(IterableDataset):
        """Dataset whose state is an (epoch iterator, example iterator) pair."""
        sources = ('data',)

        def __init__(self):
            # The parent constructor is not invoked; the attributes are
            # assigned directly.
            self.axis_labels = None
            self.data = [[1, 2, 3, 4],
                         [5, 6, 7, 8]]

        def open(self):
            """Return a fresh state positioned on the first sub-list."""
            sublists = iter(self.data)
            return sublists, iter(next(sublists))

        def next_epoch(self, state):
            """Advance to the next sub-list, reopening once none are left."""
            sublists = state[0]
            try:
                upcoming = next(sublists)
            except StopIteration:
                return self.open()
            return sublists, iter(upcoming)

        def get_data(self, state, request):
            """Pull ``request`` examples from the current sub-list."""
            # StopIteration from an exhausted sub-list propagates to the caller.
            return ([next(state[1]) for _ in range(request)],)

    expected = [
        [([1],), ([2],), ([3],), ([4],)],
        [([5],), ([6],), ([7],), ([8],)],
    ]
    stream = DataStream(TestDataset(), iteration_scheme=ConstantScheme(1))
    # The third epoch is expected to wrap around to the first sub-list.
    for wanted in (expected[0], expected[1], expected[0]):
        assert list(stream.get_epoch_iterator()) == wanted

    stream.reset()
    for wanted, epoch in zip(expected, stream.iterate_epochs()):
        assert list(epoch) == wanted

    # test scheme resetting between epochs
    class TestScheme(BatchSizeScheme):
        """Scheme that hands out the batch sizes 1, 2, 1, 3 on every call."""

        def get_request_iterator(self):
            return iter((1, 2, 1, 3))

    expected = [
        [([1],), ([2, 3],), ([4],)],
        [([5],), ([6, 7],), ([8],)],
    ]
    stream = DataStream(TestDataset(), iteration_scheme=TestScheme())
    for wanted, epoch in zip(expected, stream.iterate_epochs()):
        assert list(epoch) == wanted
コード例 #2
0
ファイル: test_datasets.py プロジェクト: xiaoyexixi/fuel
def test_data_driven_epochs():
    """Verify epoch boundaries that are dictated by the dataset itself.

    A locally defined dataset serves one inner list per epoch and starts
    again from the first inner list when it runs out.
    """
    class TestDataset(IterableDataset):
        """Serves the inner lists of ``self.data`` one epoch at a time."""
        sources = ('data',)

        def __init__(self):
            # Attributes are set by hand; the base initializer is not called.
            self.axis_labels = None
            self.data = [[1, 2, 3, 4],
                         [5, 6, 7, 8]]

        def open(self):
            """Build the initial (outer iterator, inner iterator) state."""
            outer = iter(self.data)
            inner = iter(next(outer))
            return outer, inner

        def next_epoch(self, state):
            """Move on to the following inner list, or reopen at the end."""
            outer = state[0]
            try:
                following = next(outer)
            except StopIteration:
                # No inner lists remain: start over from the beginning.
                return self.open()
            return outer, iter(following)

        def get_data(self, state, request):
            """Collect ``request`` items from the inner iterator."""
            batch = [next(state[1]) for _ in range(request)]
            return (batch,)

    per_example = [
        [([1],), ([2],), ([3],), ([4],)],
        [([5],), ([6],), ([7],), ([8],)],
    ]
    stream = DataStream(TestDataset(), iteration_scheme=ConstantScheme(1))
    # Epochs are expected in the order first, second, then first again.
    for index in (0, 1, 0):
        assert list(stream.get_epoch_iterator()) == per_example[index]

    stream.reset()
    for wanted, epoch in zip(per_example, stream.iterate_epochs()):
        assert list(epoch) == wanted

    # test scheme resetting between epochs
    class TestScheme(BatchSizeScheme):
        """Requests batches of size 1, 2, 1 and 3 from a fresh iterator."""

        def get_request_iterator(self):
            sizes = [1, 2, 1, 3]
            return iter(sizes)

    batched = [
        [([1],), ([2, 3],), ([4],)],
        [([5],), ([6, 7],), ([8],)],
    ]
    stream = DataStream(TestDataset(), iteration_scheme=TestScheme())
    for wanted, epoch in zip(batched, stream.iterate_epochs()):
        assert list(epoch) == wanted
コード例 #3
0
def test_dataset():
    """Exercise a ``DataStream`` built on a three-element ``IterableDataset``."""
    data = [1, 2, 3]
    one_tuples = list(zip(data))

    # The default stream requests an example at a time
    stream = DataStream(IterableDataset(data))
    assert list(stream.get_epoch_iterator()) == one_tuples

    # Check if iterating over multiple epochs works
    for _, epoch in zip(range(2), stream.iterate_epochs()):
        assert list(epoch) == one_tuples

    # Check whether the returning as a dictionary of sources works
    first_as_dict = next(stream.get_epoch_iterator(as_dict=True))
    assert first_as_dict == {"data": 1}
コード例 #4
0
ファイル: test_datasets.py プロジェクト: jfsantos/fuel
def test_dataset():
    """Check single-epoch, multi-epoch and dict-style iteration of a stream."""
    data = [1, 2, 3]
    # Each example is expected to arrive wrapped in a one-element tuple.
    wrapped = [(value,) for value in data]
    stream = DataStream(IterableDataset(data))

    # The default stream requests an example at a time
    single_epoch = list(stream.get_epoch_iterator())
    assert single_epoch == wrapped

    # Check if iterating over multiple epochs works
    for _, epoch in zip(range(2), stream.iterate_epochs()):
        seen = list(epoch)
        assert seen == wrapped

    # Check whether the returning as a dictionary of sources works
    dict_iterator = stream.get_epoch_iterator(as_dict=True)
    assert next(dict_iterator) == {"data": 1}
コード例 #5
0
ファイル: test_datasets.py プロジェクト: zhoujian1210/fuel
class TestDataset(object):
    """Tests for ``Dataset`` / ``IterableDataset`` behaviour seen through a stream."""

    def setUp(self):
        """Create a three-element dataset and a stream over it."""
        self.data = [1, 2, 3]
        self.stream = DataStream(IterableDataset(self.data))

    def test_one_example_at_a_time(self):
        """One epoch yields every element as a one-element tuple."""
        received = list(self.stream.get_epoch_iterator())
        expected = list(zip(self.data))
        assert_equal(received, expected)

    def test_multiple_epochs(self):
        """Two consecutive epochs both yield the full data."""
        for _, epoch in zip(range(2), self.stream.iterate_epochs()):
            assert list(epoch) == list(zip(self.data))

    def test_as_dict(self):
        """With ``as_dict=True`` the first item is keyed by source name."""
        first = next(self.stream.get_epoch_iterator(as_dict=True))
        assert_equal(first, {"data": 1})

    def test_value_error_on_no_provided_sources(self):
        """Expect ``ValueError`` from a subclass defining only ``get_data``."""
        class FaultyDataset(Dataset):
            def get_data(self, state=None, request=None):
                pass

        assert_raises(ValueError, lambda: FaultyDataset(self.data))

    def test_value_error_on_nonexistent_sources(self):
        """Expect ``ValueError`` when the source ``'dummy'`` is requested."""
        # assert_raises forwards the extra arguments to the callable.
        assert_raises(ValueError, IterableDataset, self.data,
                      sources=('dummy', ))

    def test_default_transformer(self):
        """An overridden default transformer is applied to the stream."""
        class DoublingDataset(IterableDataset):
            def apply_default_transformer(self, stream):
                def double_each(sources):
                    return tuple(2 * s for s in sources)

                return Mapping(stream, double_each)

        dataset = DoublingDataset(self.data)
        base_stream = DataStream(dataset)
        stream = dataset.apply_default_transformer(base_stream)
        doubled = list(stream.get_epoch_iterator())
        assert_equal(doubled, [(2, ), (4, ), (6, )])

    def test_no_axis_labels(self):
        """``axis_labels`` is ``None`` when none are passed in."""
        dataset = IterableDataset(self.data)
        assert dataset.axis_labels is None

    def test_axis_labels(self):
        """``axis_labels`` passed to the constructor are readable afterwards."""
        labels = {'data': ('batch', )}
        dataset = IterableDataset(self.data, axis_labels=labels)
        assert dataset.axis_labels == labels

    def test_attribute_error_on_no_example_iteration_scheme(self):
        """Expect ``AttributeError`` when no example scheme is defined."""
        class FaultyDataset(Dataset):
            provides_sources = ('data', )

            def get_data(self, state=None, request=None):
                pass

        # Construction and attribute access both happen inside the callable.
        assert_raises(AttributeError,
                      lambda: FaultyDataset().example_iteration_scheme)

    def test_example_iteration_scheme(self):
        """The class-level ``_example_iteration_scheme`` is what is exposed."""
        expected_scheme = ConstantScheme(2)

        class MinimalDataset(Dataset):
            provides_sources = ('data', )
            _example_iteration_scheme = expected_scheme

            def get_data(self, state=None, request=None):
                pass

        dataset = MinimalDataset()
        assert dataset.example_iteration_scheme is expected_scheme

    def test_filter_sources(self):
        """``filter_sources`` keeps only the data of the selected source."""
        iterables = OrderedDict([('1', [1, 2]), ('2', [3, 4])])
        dataset = IterableDataset(iterables, sources=('1', ))
        kept = dataset.filter_sources(([1, 2], [3, 4]))
        assert_equal(kept, ([1, 2], ))
コード例 #6
0
ファイル: test_datasets.py プロジェクト: Afrik/fuel
class TestDataset(object):
    """Checks on datasets and on a ``DataStream`` wrapping ``[1, 2, 3]``."""

    def setUp(self):
        """Prepare the raw data and a stream that reads from it."""
        self.data = [1, 2, 3]
        dataset = IterableDataset(self.data)
        self.stream = DataStream(dataset)

    def test_one_example_at_a_time(self):
        """A single epoch produces each value in its own tuple."""
        epoch_items = list(self.stream.get_epoch_iterator())
        assert_equal(epoch_items, list(zip(self.data)))

    def test_multiple_epochs(self):
        """The first two epochs each reproduce the whole data list."""
        for _, epoch in zip(range(2), self.stream.iterate_epochs()):
            items = list(epoch)
            assert items == list(zip(self.data))

    def test_as_dict(self):
        """Dictionary mode maps the source name to the first value."""
        dict_iterator = self.stream.get_epoch_iterator(as_dict=True)
        assert_equal(next(dict_iterator), {"data": 1})

    def test_value_error_on_no_provided_sources(self):
        """Instantiating a subclass that only defines ``get_data`` must fail."""
        class FaultyDataset(Dataset):
            def get_data(self, state=None, request=None):
                pass

        def instantiate_faulty():
            return FaultyDataset(self.data)

        assert_raises(ValueError, instantiate_faulty)

    def test_value_error_on_nonexistent_sources(self):
        """Asking for the source ``'dummy'`` must raise ``ValueError``."""
        assert_raises(
            ValueError, lambda: IterableDataset(self.data, sources=('dummy',)))

    def test_default_transformer(self):
        """A subclass's default transformer doubles every streamed value."""
        class DoublingDataset(IterableDataset):
            def apply_default_transformer(self, stream):
                def doubled(sources):
                    return tuple(2 * s for s in sources)
                return Mapping(stream, doubled)

        dataset = DoublingDataset(self.data)
        stream = dataset.apply_default_transformer(DataStream(dataset))
        produced = list(stream.get_epoch_iterator())
        assert_equal(produced, [(2,), (4,), (6,)])

    def test_no_axis_labels(self):
        """Without an ``axis_labels`` argument the attribute is ``None``."""
        labels = IterableDataset(self.data).axis_labels
        assert labels is None

    def test_axis_labels(self):
        """Axis labels given at construction are stored unchanged."""
        given = {'data': ('batch',)}
        stored = IterableDataset(self.data, axis_labels=given).axis_labels
        assert stored == given

    def test_attribute_error_on_no_example_iteration_scheme(self):
        """Reading the example scheme of a scheme-less dataset must fail."""
        class FaultyDataset(Dataset):
            provides_sources = ('data',)

            def get_data(self, state=None, request=None):
                pass

        def read_scheme():
            # The dataset is built here so that construction is also covered
            # by the assert_raises call below.
            return FaultyDataset().example_iteration_scheme

        assert_raises(AttributeError, read_scheme)

    def test_example_iteration_scheme(self):
        """The scheme set on the class is returned by the public attribute."""
        constant_scheme = ConstantScheme(2)

        class MinimalDataset(Dataset):
            provides_sources = ('data',)
            _example_iteration_scheme = constant_scheme

            def get_data(self, state=None, request=None):
                pass

        exposed = MinimalDataset().example_iteration_scheme
        assert exposed is constant_scheme

    def test_filter_sources(self):
        """Only the data belonging to source ``'1'`` survives filtering."""
        named_iterables = OrderedDict([('1', [1, 2]), ('2', [3, 4])])
        dataset = IterableDataset(named_iterables, sources=('1',))
        result = dataset.filter_sources(([1, 2], [3, 4]))
        assert_equal(result, ([1, 2],))