예제 #1
0
class TestMerge(object):
    """Tests for ``Merge``, which combines several data streams into one."""

    def setUp(self):
        english = DataStream(IterableDataset(['Hello world!']))
        french = DataStream(IterableDataset(['Bonjour le monde!']))
        self.streams = (english, french)
        self.transformer = Merge(self.streams, ('english', 'french'))

    @staticmethod
    def _flag_streams():
        # Three FlagDataStreams over distinct integer datasets; the lifecycle
        # tests below check that Merge forwards calls to every one of them.
        return [FlagDataStream(IterableDataset(values))
                for values in ([1, 2, 3], [4, 5, 6], [7, 8, 9])]

    def test_sources(self):
        expected_sources = ('english', 'french')
        assert_equal(self.transformer.sources, expected_sources)

    def test_merge(self):
        first_item = next(self.transformer.get_epoch_iterator())
        assert_equal(first_item, ('Hello world!', 'Bonjour le monde!'))

    def test_as_dict(self):
        expected = {'english': 'Hello world!', 'french': 'Bonjour le monde!'}
        first_item = next(self.transformer.get_epoch_iterator(as_dict=True))
        assert_equal(first_item, expected)

    def test_error_on_wrong_number_of_sources(self):
        # setUp builds two streams, so a single source name must be rejected.
        too_few_sources = ('english',)
        assert_raises(ValueError, Merge, self.streams, too_few_sources)

    def test_value_error_on_different_stream_output_type(self):
        # ``spanish_stream`` is given an iteration scheme while the setUp
        # streams are not; Merge is expected to reject the mixture.
        spanish_stream = DataStream(IndexableDataset(['Hola mundo!']),
                                    iteration_scheme=SequentialScheme(2, 2))
        all_streams = self.streams + (spanish_stream,)
        assert_raises(ValueError, Merge, all_streams,
                      ('english', 'french', 'spanish'))

    def test_close_calls_close_on_all_streams(self):
        streams = self._flag_streams()
        Merge(streams, ('1', '2', '3')).close()
        assert all(stream.close_called for stream in streams)

    def test_next_epoch_calls_next_epoch_on_all_streams(self):
        streams = self._flag_streams()
        Merge(streams, ('1', '2', '3')).next_epoch()
        assert all(stream.next_epoch_called for stream in streams)

    def test_reset_calls_reset_on_all_streams(self):
        streams = self._flag_streams()
        Merge(streams, ('1', '2', '3')).reset()
        assert all(stream.reset_called for stream in streams)
class TestMerge(object):
    """Unit tests for the ``Merge`` transformer on two example streams."""

    def setUp(self):
        greetings = (['Hello world!'], ['Bonjour le monde!'])
        self.streams = tuple(DataStream(IterableDataset(sentences))
                             for sentences in greetings)
        self.transformer = Merge(self.streams, ('english', 'french'))

    @staticmethod
    def _flag_streams():
        # Fresh FlagDataStreams so each lifecycle test can inspect its flags.
        return [FlagDataStream(IterableDataset(values))
                for values in ([1, 2, 3], [4, 5, 6], [7, 8, 9])]

    def test_sources(self):
        actual_sources = self.transformer.sources
        assert_equal(actual_sources, ('english', 'french'))

    def test_merge(self):
        epoch = self.transformer.get_epoch_iterator()
        assert_equal(next(epoch), ('Hello world!', 'Bonjour le monde!'))

    def test_as_dict(self):
        epoch = self.transformer.get_epoch_iterator(as_dict=True)
        merged = next(epoch)
        assert_equal(merged,
                     {'english': 'Hello world!',
                      'french': 'Bonjour le monde!'})

    def test_error_on_wrong_number_of_sources(self):
        # Fewer source names than streams must raise.
        assert_raises(ValueError, Merge, self.streams, ('english',))

    def test_value_error_on_different_stream_output_type(self):
        # The third stream carries an iteration scheme the others lack, so
        # its output type differs and Merge is expected to refuse it.
        spanish_stream = DataStream(IndexableDataset(['Hola mundo!']),
                                    iteration_scheme=SequentialScheme(2, 2))
        mixed_streams = self.streams + (spanish_stream,)
        source_names = ('english', 'french', 'spanish')
        assert_raises(ValueError, Merge, mixed_streams, source_names)

    def test_close_calls_close_on_all_streams(self):
        children = self._flag_streams()
        merged = Merge(children, ('1', '2', '3'))
        merged.close()
        assert all(child.close_called for child in children)

    def test_next_epoch_calls_next_epoch_on_all_streams(self):
        children = self._flag_streams()
        merged = Merge(children, ('1', '2', '3'))
        merged.next_epoch()
        assert all(child.next_epoch_called for child in children)

    def test_reset_calls_reset_on_all_streams(self):
        children = self._flag_streams()
        merged = Merge(children, ('1', '2', '3'))
        merged.reset()
        assert all(child.reset_called for child in children)
예제 #3
0
 def test_as_dict(self):
     """With ``as_dict=True`` the merged item is keyed by source name."""
     merged = Merge(self.streams, ('english', 'french'))
     first_item = next(merged.get_epoch_iterator(as_dict=True))
     expected = {'english': 'Hello world!', 'french': 'Bonjour le monde!'}
     assert_equal(first_item, expected)
예제 #4
0
def test_merge():
    """Merge reports its source names and pairs the streams' first examples."""
    english = IterableDataset(['Hello world!'])
    french = IterableDataset(['Bonjour le monde!'])
    merged_stream = Merge(
        (english.get_example_stream(), french.get_example_stream()),
        ('english', 'french'))
    assert merged_stream.sources == ('english', 'french')
    first_example = next(merged_stream.get_epoch_iterator())
    assert first_example == ('Hello world!', 'Bonjour le monde!')
예제 #5
0
def test_merge():
    """Merging an English and a French example stream pairs their examples."""
    source_names = ('english', 'french')
    datasets = (IterableDataset(['Hello world!']),
                IterableDataset(['Bonjour le monde!']))
    streams = tuple(dataset.get_example_stream() for dataset in datasets)
    merged_stream = Merge(streams, source_names)
    assert merged_stream.sources == source_names
    first_example = next(merged_stream.get_epoch_iterator())
    assert first_example == ('Hello world!', 'Bonjour le monde!')
예제 #6
0
 def test_merge(self):
     """The first merged item is the expected English/French pair."""
     merged = Merge(self.streams, ('english', 'french'))
     first_item = next(merged.get_epoch_iterator())
     assert_equal(first_item, ('Hello world!', 'Bonjour le monde!'))
예제 #7
0
class TestMerge(object):
    """Tests for ``Merge`` over plain example streams and batched streams."""

    def setUp(self):
        self.streams = (
            DataStream(IterableDataset(['Hello world!'])),
            DataStream(IterableDataset(['Bonjour le monde!'])))
        # Batched counterparts: each ``Batch`` wraps a two-sentence stream
        # with ``ConstantScheme(2)``.
        self.batch_streams = (
            Batch(DataStream(IterableDataset(['Hello world!', 'Hi!'])),
                  iteration_scheme=ConstantScheme(2)),
            Batch(DataStream(IterableDataset(['Bonjour le monde!', 'Salut!'])),
                  iteration_scheme=ConstantScheme(2)))
        self.transformer = Merge(
            self.streams, ('english', 'french'))
        self.batch_transformer = Merge(
            self.batch_streams, ('english', 'french'))

    @staticmethod
    def _assert_single_item_epochs(transformer, expected, epochs=2):
        """Assert each of ``epochs`` epochs yields ``expected`` and then stops.

        There used to be problems with resetting ``Merge``, so the check is
        deliberately repeated over consecutive epochs as a regression test.
        """
        for _ in range(epochs):
            it = transformer.get_epoch_iterator()
            assert_equal(next(it), expected)
            assert_raises(StopIteration, next, it)

    @staticmethod
    def _flag_streams():
        """Return three ``FlagDataStream``s over distinct integer datasets."""
        return [FlagDataStream(IterableDataset(values))
                for values in ([1, 2, 3], [4, 5, 6], [7, 8, 9])]

    def test_sources(self):
        assert_equal(self.transformer.sources, ('english', 'french'))

    def test_merge(self):
        self._assert_single_item_epochs(
            self.transformer, ('Hello world!', 'Bonjour le monde!'))

    def test_batch_merge(self):
        self._assert_single_item_epochs(
            self.batch_transformer,
            (('Hello world!', 'Hi!'), ('Bonjour le monde!', 'Salut!')))

    def test_merge_batch_streams(self):
        # Exercise the *batch* transformer in dict mode. This test used to be
        # a verbatim copy of ``test_merge`` and never touched the batch
        # streams. Two epochs are checked as a regression test for resetting.
        expected = {'english': ('Hello world!', 'Hi!'),
                    'french': ('Bonjour le monde!', 'Salut!')}
        for _ in range(2):
            it = self.batch_transformer.get_epoch_iterator(as_dict=True)
            assert_equal(next(it), expected)
            assert_raises(StopIteration, next, it)

    def test_as_dict(self):
        assert_equal(
            next(self.transformer.get_epoch_iterator(as_dict=True)),
            {'english': 'Hello world!', 'french': 'Bonjour le monde!'})

    def test_error_on_wrong_number_of_sources(self):
        # Two streams but only one source name must be rejected.
        assert_raises(ValueError, Merge, self.streams, ('english',))

    def test_value_error_on_different_stream_output_type(self):
        # ``spanish_stream`` is given an iteration scheme while the setUp
        # example streams are not; Merge is expected to reject the mixture.
        spanish_stream = DataStream(IndexableDataset(['Hola mundo!']),
                                    iteration_scheme=SequentialScheme(2, 2))
        assert_raises(ValueError, Merge, self.streams + (spanish_stream,),
                      ('english', 'french', 'spanish'))

    def test_close_calls_close_on_all_streams(self):
        streams = self._flag_streams()
        transformer = Merge(streams, ('1', '2', '3'))
        transformer.close()
        assert all(stream.close_called for stream in streams)

    def test_next_epoch_calls_next_epoch_on_all_streams(self):
        streams = self._flag_streams()
        transformer = Merge(streams, ('1', '2', '3'))
        transformer.next_epoch()
        assert all(stream.next_epoch_called for stream in streams)

    def test_reset_calls_reset_on_all_streams(self):
        streams = self._flag_streams()
        transformer = Merge(streams, ('1', '2', '3'))
        transformer.reset()
        assert all(stream.reset_called for stream in streams)
예제 #8
0
class TestMerge(object):
    """Tests for ``Merge`` over plain example streams and batched streams."""

    def setUp(self):
        self.streams = (
            DataStream(IterableDataset(['Hello world!'])),
            DataStream(IterableDataset(['Bonjour le monde!'])))
        # Batched counterparts: each ``Batch`` wraps a two-sentence stream
        # with ``ConstantScheme(2)``.
        self.batch_streams = (
            Batch(DataStream(IterableDataset(['Hello world!', 'Hi!'])),
                  iteration_scheme=ConstantScheme(2)),
            Batch(DataStream(IterableDataset(['Bonjour le monde!', 'Salut!'])),
                  iteration_scheme=ConstantScheme(2)))
        self.transformer = Merge(
            self.streams, ('english', 'french'))
        self.batch_transformer = Merge(
            self.batch_streams, ('english', 'french'))

    @staticmethod
    def _assert_single_item_epochs(transformer, expected, epochs=2):
        """Assert each of ``epochs`` epochs yields ``expected`` and then stops.

        There used to be problems with resetting ``Merge``, so the check is
        deliberately repeated over consecutive epochs as a regression test.
        """
        for _ in range(epochs):
            it = transformer.get_epoch_iterator()
            assert_equal(next(it), expected)
            assert_raises(StopIteration, next, it)

    @staticmethod
    def _flag_streams():
        """Return three ``FlagDataStream``s over distinct integer datasets."""
        return [FlagDataStream(IterableDataset(values))
                for values in ([1, 2, 3], [4, 5, 6], [7, 8, 9])]

    def test_sources(self):
        assert_equal(self.transformer.sources, ('english', 'french'))

    def test_merge(self):
        self._assert_single_item_epochs(
            self.transformer, ('Hello world!', 'Bonjour le monde!'))

    def test_batch_merge(self):
        self._assert_single_item_epochs(
            self.batch_transformer,
            (('Hello world!', 'Hi!'), ('Bonjour le monde!', 'Salut!')))

    def test_merge_batch_streams(self):
        # Exercise the *batch* transformer in dict mode. This test used to be
        # a verbatim copy of ``test_merge`` and never touched the batch
        # streams. Two epochs are checked as a regression test for resetting.
        expected = {'english': ('Hello world!', 'Hi!'),
                    'french': ('Bonjour le monde!', 'Salut!')}
        for _ in range(2):
            it = self.batch_transformer.get_epoch_iterator(as_dict=True)
            assert_equal(next(it), expected)
            assert_raises(StopIteration, next, it)

    def test_as_dict(self):
        assert_equal(
            next(self.transformer.get_epoch_iterator(as_dict=True)),
            {'english': 'Hello world!', 'french': 'Bonjour le monde!'})

    def test_error_on_wrong_number_of_sources(self):
        # Two streams but only one source name must be rejected.
        assert_raises(ValueError, Merge, self.streams, ('english',))

    def test_value_error_on_different_stream_output_type(self):
        # ``spanish_stream`` is given an iteration scheme while the setUp
        # example streams are not; Merge is expected to reject the mixture.
        spanish_stream = DataStream(IndexableDataset(['Hola mundo!']),
                                    iteration_scheme=SequentialScheme(2, 2))
        assert_raises(ValueError, Merge, self.streams + (spanish_stream,),
                      ('english', 'french', 'spanish'))

    def test_close_calls_close_on_all_streams(self):
        streams = self._flag_streams()
        transformer = Merge(streams, ('1', '2', '3'))
        transformer.close()
        assert all(stream.close_called for stream in streams)

    def test_next_epoch_calls_next_epoch_on_all_streams(self):
        streams = self._flag_streams()
        transformer = Merge(streams, ('1', '2', '3'))
        transformer.next_epoch()
        assert all(stream.next_epoch_called for stream in streams)

    def test_reset_calls_reset_on_all_streams(self):
        streams = self._flag_streams()
        transformer = Merge(streams, ('1', '2', '3'))
        transformer.reset()
        assert all(stream.reset_called for stream in streams)
예제 #9
0
 def test_as_dict(self):
     """Requesting dicts keys the merged item by the given source names."""
     expected = {'english': 'Hello world!', 'french': 'Bonjour le monde!'}
     merged = Merge(self.streams, ('english', 'french'))
     epoch = merged.get_epoch_iterator(as_dict=True)
     assert_equal(next(epoch), expected)
예제 #10
0
 def test_merge(self):
     """The first merged tuple holds the expected English and French items."""
     expected = ('Hello world!', 'Bonjour le monde!')
     merged = Merge(self.streams, ('english', 'french'))
     epoch = merged.get_epoch_iterator()
     assert_equal(next(epoch), expected)
 def test_as_dict(self):
     """With ``as_dict=True`` the merged item maps source name to value."""
     merged = Merge(self.streams, ("english", "french"))
     first_item = next(merged.get_epoch_iterator(as_dict=True))
     expected = {"english": "Hello world!", "french": "Bonjour le monde!"}
     assert_equal(first_item, expected)
 def test_merge(self):
     """The first merged item is the expected English/French pair."""
     merged = Merge(self.streams, ("english", "french"))
     first_item = next(merged.get_epoch_iterator())
     assert_equal(first_item, ("Hello world!", "Bonjour le monde!"))