Example #1
 def test_as_dict(self):
     transformer = Merge(self.streams, ('english', 'french'))
     assert_equal(next(transformer.get_epoch_iterator(as_dict=True)),
                  ({
                      'english': 'Hello world!',
                      'french': 'Bonjour le monde!'
                  }))
Example #2
 def test_reset_calls_reset_on_all_streams(self):
     streams = [FlagDataStream(IterableDataset([1, 2, 3])),
                FlagDataStream(IterableDataset([4, 5, 6])),
                FlagDataStream(IterableDataset([7, 8, 9]))]
     transformer = Merge(streams, ('1', '2', '3'))
     transformer.reset()
     assert all(stream.reset_called for stream in streams)
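These tests rely on a small FlagDataStream helper that the listing does not show. As a hedged sketch (an assumption about its shape, not the actual Fuel test fixture), it is simply a DataStream subclass that records which lifecycle methods were called:

from fuel.streams import DataStream


class FlagDataStream(DataStream):
    """Hypothetical test helper: records calls to close/next_epoch/reset."""

    def __init__(self, dataset, **kwargs):
        super(FlagDataStream, self).__init__(dataset, **kwargs)
        self.close_called = False
        self.next_epoch_called = False
        self.reset_called = False

    def close(self):
        self.close_called = True
        super(FlagDataStream, self).close()

    def next_epoch(self):
        self.next_epoch_called = True
        return super(FlagDataStream, self).next_epoch()

    def reset(self):
        self.reset_called = True
        super(FlagDataStream, self).reset()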
Example #4
    def getFinalStream(X, Y, sources, sources_k, batch_size=128, embedding_dim=300,
        shuffle=False):
        """Despite horrible variable names, this method
        gives back the final stream for both train or test data

        batch_size:
        embedding_dim: for glove vects
        min_df and max_features: for Tokenizer

        Returns
        -------
        merged stream with sources = sources + sources_k
        """
        trX, trY = (X, Y)
        trX_k, trY_k = (X, Y)

        # Transforms
        trXt = lambda x: floatX(x)
        Yt = lambda y: intX(SeqPadded(vect.transform(sampleCaptions(y)), 'back'))

        # Foxhound Iterators
        # RCL: Write own iterator to sample positive examples/captions, since there are 5 for each image.
        train_iterator = iterators.Linear(
            trXt=trXt, trYt=Yt, size=batch_size, shuffle=shuffle
            )
        train_iterator_k = iterators.Linear(
            trXt=trXt, trYt=Yt, size=batch_size, shuffle=shuffle
            )

        # FoxyDataStreams
        train_stream = FoxyDataStream(
              (trX, trY)
            , sources
            , train_iterator
            , FoxyIterationScheme(len(trX), batch_size)
            )

        train_stream_k = FoxyDataStream(
              (trX_k, trY_k)
            , sources_k
            , train_iterator_k
            , FoxyIterationScheme(len(trX), batch_size)
            )
        glove_version = "glove.6B.%sd.txt.gz" % embedding_dim
        train_transformer = GloveTransformer(
            glove_version, data_stream=train_stream, vectorizer=vect
            )
        train_transformer_k = GloveTransformer(
            glove_version, data_stream=train_stream_k, vectorizer=vect
            )

        # Final Data Streams w/ contrastive examples
        final_train_stream = Merge(
              (train_transformer, ShuffleBatch(train_transformer_k))
            , sources + sources_k
            )
        final_train_stream.iteration_scheme = FoxyIterationScheme(len(trX), batch_size)

        return final_train_stream
Example #5
def test_merge():
    english = IterableDataset(['Hello world!'])
    french = IterableDataset(['Bonjour le monde!'])
    streams = (english.get_example_stream(), french.get_example_stream())
    merged_stream = Merge(streams, ('english', 'french'))
    assert merged_stream.sources == ('english', 'french')
    assert (next(merged_stream.get_epoch_iterator()) == ('Hello world!',
                                                         'Bonjour le monde!'))
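The snippets above omit their imports. A minimal, self-contained sketch of the same idea, assuming a standard Fuel installation:

# Merge two example streams into one stream with named sources.
from fuel.datasets import IterableDataset
from fuel.streams import DataStream
from fuel.transformers import Merge

english = DataStream(IterableDataset(['Hello world!']))
french = DataStream(IterableDataset(['Bonjour le monde!']))
merged = Merge((english, french), ('english', 'french'))

print(merged.sources)
# ('english', 'french')
print(next(merged.get_epoch_iterator(as_dict=True)))
# {'english': 'Hello world!', 'french': 'Bonjour le monde!'}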
Example #6
class TestMerge(object):
    def setUp(self):
        self.streams = (DataStream(IterableDataset(['Hello world!'])),
                        DataStream(IterableDataset(['Bonjour le monde!'])))
        self.transformer = Merge(self.streams, ('english', 'french'))

    def test_sources(self):
        assert_equal(self.transformer.sources, ('english', 'french'))

    def test_merge(self):
        assert_equal(next(self.transformer.get_epoch_iterator()),
                     ('Hello world!', 'Bonjour le monde!'))

    def test_as_dict(self):
        assert_equal(next(self.transformer.get_epoch_iterator(as_dict=True)),
                     ({
                         'english': 'Hello world!',
                         'french': 'Bonjour le monde!'
                     }))

    def test_error_on_wrong_number_of_sources(self):
        assert_raises(ValueError, Merge, self.streams, ('english', ))

    def test_value_error_on_different_stream_output_type(self):
        spanish_stream = DataStream(IndexableDataset(['Hola mundo!']),
                                    iteration_scheme=SequentialScheme(2, 2))
        assert_raises(ValueError, Merge, self.streams + (spanish_stream, ),
                      ('english', 'french', 'spanish'))

    def test_close_calls_close_on_all_streams(self):
        streams = [
            FlagDataStream(IterableDataset([1, 2, 3])),
            FlagDataStream(IterableDataset([4, 5, 6])),
            FlagDataStream(IterableDataset([7, 8, 9]))
        ]
        transformer = Merge(streams, ('1', '2', '3'))
        transformer.close()
        assert all(stream.close_called for stream in streams)

    def test_next_epoch_calls_next_epoch_on_all_streams(self):
        streams = [
            FlagDataStream(IterableDataset([1, 2, 3])),
            FlagDataStream(IterableDataset([4, 5, 6])),
            FlagDataStream(IterableDataset([7, 8, 9]))
        ]
        transformer = Merge(streams, ('1', '2', '3'))
        transformer.next_epoch()
        assert all(stream.next_epoch_called for stream in streams)

    def test_reset_calls_reset_on_all_streams(self):
        streams = [
            FlagDataStream(IterableDataset([1, 2, 3])),
            FlagDataStream(IterableDataset([4, 5, 6])),
            FlagDataStream(IterableDataset([7, 8, 9]))
        ]
        transformer = Merge(streams, ('1', '2', '3'))
        transformer.reset()
        assert all(stream.reset_called for stream in streams)
Example #7
def test_merge():
    english = IterableDataset(['Hello world!'])
    french = IterableDataset(['Bonjour le monde!'])
    streams = (english.get_example_stream(),
               french.get_example_stream())
    merged_stream = Merge(streams, ('english', 'french'))
    assert merged_stream.sources == ('english', 'french')
    assert (next(merged_stream.get_epoch_iterator()) ==
            ('Hello world!', 'Bonjour le monde!'))
Example #8
    def getFinalStream(X,
                       Y,
                       sources,
                       sources_k,
                       batch_size=128,
                       embedding_dim=300,
                       shuffle=False):
        """
        Returns
        -------
        merged stream with sources = sources + sources_k
        """
        trX, trY = (X, Y)
        trX_k, trY_k = (X, Y)

        # Transforms
        trXt = lambda x: floatX(x)
        Yt = lambda y: intX(
            SeqPadded(vect.transform(sampleCaptions(y)), 'back'))

        # Foxhound Iterators
        # RCL: Write own iterator to sample positive examples/captions, since there are 5 for each image.
        train_iterator = iterators.Linear(trXt=trXt,
                                          trYt=Yt,
                                          size=batch_size,
                                          shuffle=shuffle)
        train_iterator_k = iterators.Linear(trXt=trXt,
                                            trYt=Yt,
                                            size=batch_size,
                                            shuffle=True)

        # FoxyDataStreams
        train_stream = FoxyDataStream(
            (trX, trY), sources, train_iterator,
            FoxyIterationScheme(len(trX), batch_size))

        train_stream_k = FoxyDataStream(
            (trX_k, trY_k), sources_k, train_iterator_k,
            FoxyIterationScheme(len(trX), batch_size))
        glove_version = "glove.6B.%sd.txt.gz" % embedding_dim
        train_transformer = GloveTransformer(glove_version,
                                             data_stream=train_stream,
                                             vectorizer=vect)
        train_transformer_k = GloveTransformer(glove_version,
                                               data_stream=train_stream_k,
                                               vectorizer=vect)

        # Final Data Streams w/ contrastive examples
        final_train_stream = Merge(
            (train_transformer, ShuffleBatch(train_transformer_k)),
            sources + sources_k)
        final_train_stream.iteration_scheme = FoxyIterationScheme(
            len(trX), batch_size)

        return final_train_stream
Example #9
 def setUp(self):
     self.streams = (
         DataStream(IterableDataset(['Hello world!'])),
         DataStream(IterableDataset(['Bonjour le monde!'])))
     self.batch_streams = (
         Batch(DataStream(IterableDataset(['Hello world!', 'Hi!'])),
               iteration_scheme=ConstantScheme(2)),
         Batch(DataStream(IterableDataset(['Bonjour le monde!', 'Salut!'])),
               iteration_scheme=ConstantScheme(2)))
     self.transformer = Merge(
         self.streams, ('english', 'french'))
     self.batch_transformer = Merge(
         self.batch_streams, ('english', 'french'))
Example #11
class TestMerge(object):
    def setUp(self):
        self.streams = (
            DataStream(IterableDataset(['Hello world!'])),
            DataStream(IterableDataset(['Bonjour le monde!'])))
        self.transformer = Merge(self.streams, ('english', 'french'))

    def test_sources(self):
        assert_equal(self.transformer.sources, ('english', 'french'))

    def test_merge(self):
        assert_equal(next(self.transformer.get_epoch_iterator()),
                     ('Hello world!', 'Bonjour le monde!'))

    def test_as_dict(self):
        assert_equal(
            next(self.transformer.get_epoch_iterator(as_dict=True)),
            ({'english': 'Hello world!', 'french': 'Bonjour le monde!'}))

    def test_error_on_wrong_number_of_sources(self):
        assert_raises(ValueError, Merge, self.streams, ('english',))

    def test_value_error_on_different_stream_output_type(self):
        spanish_stream = DataStream(IndexableDataset(['Hola mundo!']),
                                    iteration_scheme=SequentialScheme(2, 2))
        assert_raises(ValueError, Merge, self.streams + (spanish_stream,),
                      ('english', 'french', 'spanish'))

    def test_close_calls_close_on_all_streams(self):
        streams = [FlagDataStream(IterableDataset([1, 2, 3])),
                   FlagDataStream(IterableDataset([4, 5, 6])),
                   FlagDataStream(IterableDataset([7, 8, 9]))]
        transformer = Merge(streams, ('1', '2', '3'))
        transformer.close()
        assert all(stream.close_called for stream in streams)

    def test_next_epoch_calls_next_epoch_on_all_streams(self):
        streams = [FlagDataStream(IterableDataset([1, 2, 3])),
                   FlagDataStream(IterableDataset([4, 5, 6])),
                   FlagDataStream(IterableDataset([7, 8, 9]))]
        transformer = Merge(streams, ('1', '2', '3'))
        transformer.next_epoch()
        assert all(stream.next_epoch_called for stream in streams)

    def test_reset_calls_reset_on_all_streams(self):
        streams = [FlagDataStream(IterableDataset([1, 2, 3])),
                   FlagDataStream(IterableDataset([4, 5, 6])),
                   FlagDataStream(IterableDataset([7, 8, 9]))]
        transformer = Merge(streams, ('1', '2', '3'))
        transformer.reset()
        assert all(stream.reset_called for stream in streams)
Example #12
def get_tr_stream(src_vocab, trg_vocab, src_data, trg_data,
                  src_vocab_size=30000, trg_vocab_size=30000,
                  unk_id=0, eos_id=1, bos_id=2, train_noise=0,
                  seq_len=50, batch_size=80, sort_k_batches=12, **kwargs):
    src_stream = get_stream(src_vocab, src_data, src_vocab_size, unk_id, eos_id, bos_id, train_noise)
    trg_stream = get_stream(trg_vocab, trg_data, trg_vocab_size, unk_id, eos_id, bos_id, 0)

    # Merge them to get a source, target pair
    stream = Merge([src_stream, trg_stream], ('source', 'target'))

    # Filter sequences that are too long
    stream = Filter(stream, predicate=_not_too_long(seq_len))

    # Build a batched version of stream to read k batches ahead
    stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size * sort_k_batches))

    # Sort all samples in the read-ahead batch
    stream = Mapping(stream, SortMapping(_length))

    # Convert it into a stream again
    stream = Unpack(stream)

    # Construct batches from the stream with specified batch size
    stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size))

    # Pad sequences that are short
    return PaddingWithEOS(stream, [eos_id, eos_id])
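The helpers used above (_not_too_long, _length, and the _too_long/_oov_to_unk variants in later examples) are defined elsewhere in their respective projects. Roughly following the blocks-examples machine-translation stream module, they can be sketched as follows; treat this as an assumption about their shape rather than the exact code:

# The length-filter predicate keeps pairs whose sentences fit within seq_len;
# despite the name "_too_long" used in some forks, it returns True for pairs
# that should be kept. "_not_too_long" is the clearer name for the same thing.
class _not_too_long(object):
    def __init__(self, seq_len=50):
        self.seq_len = seq_len

    def __call__(self, sentence_pair):
        return all(len(sentence) <= self.seq_len
                   for sentence in sentence_pair)


def _length(sentence_pair):
    """Sort key for the read-ahead batch: length of the target sentence."""
    return len(sentence_pair[1])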
Example #13
    def train(self, req_vars):
        prefix_stream = DataStream(self.train_dataset,
                                   iteration_scheme=ShuffledExampleScheme(
                                       self.train_dataset.num_examples))

        if not data.tvt:
            prefix_stream = transformers.TaxiExcludeTrips(
                prefix_stream, self.valid_trips_ids)
        prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)
        prefix_stream = transformers.TaxiGenerateSplits(
            prefix_stream, max_splits=self.config.max_splits)

        prefix_stream = transformers.taxi_add_datetime(prefix_stream)

        prefix_stream = transformers.balanced_batch(
            prefix_stream,
            key='latitude',
            batch_size=self.config.batch_size,
            batch_sort_size=self.config.batch_sort_size)

        prefix_stream = Padding(prefix_stream,
                                mask_sources=['latitude', 'longitude'])

        candidate_stream = self.candidate_stream(
            self.config.train_candidate_size)

        sources = prefix_stream.sources + tuple(
            'candidate_%s' % k for k in candidate_stream.sources)
        stream = Merge((prefix_stream, candidate_stream), sources)

        stream = transformers.Select(stream, tuple(req_vars))
        # stream = MultiProcessing(stream)
        return stream
Example #14
def get_comb_stream(fea2obj,
                    which_set,
                    batch_size=None,
                    shuffle=True,
                    num_examples=None):
    streams = []
    for fea in fea2obj:
        obj = fea2obj[fea]
        dataset = H5PYDataset(obj.fuelfile,
                              which_sets=(which_set, ),
                              load_in_memory=True)
        if batch_size is None:
            batch_size = dataset.num_examples
        if num_examples is None:
            num_examples = dataset.num_examples
        if shuffle:
            iterschema = ShuffledScheme(examples=num_examples,
                                        batch_size=batch_size,
                                        rng=numpy.random.RandomState(seed))
        else:
            iterschema = SequentialScheme(examples=num_examples,
                                          batch_size=batch_size)
        stream = DataStream(dataset=dataset, iteration_scheme=iterschema)
        if fea in seq_features:
            stream = CutInput(stream, obj.max_len)
            if obj.rec == True:
                logger.info('transforming data for recursive input')
                # Required because recurrent layers expect their input shaped
                # as [sequence, batch, features]
                stream = LettersTransposer(stream, which_sources=fea)
        streams.append(stream)
    stream = Merge(streams, tuple(fea2obj.keys()))
    return stream, num_examples
Example #15
    def valid(self, req_vars):
        prefix_stream = DataStream(self.valid_dataset,
                                   iteration_scheme=SequentialExampleScheme(
                                       self.valid_dataset.num_examples))

        #prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)

        prefix_stream = transformers.taxi_add_datetime(prefix_stream)

        prefix_stream = transformers.balanced_batch(
            prefix_stream,
            key='latitude',
            batch_size=self.config.batch_size,
            batch_sort_size=self.config.batch_sort_size)

        prefix_stream = Padding(prefix_stream,
                                mask_sources=['latitude', 'longitude'])

        candidate_stream = self.candidate_stream(
            self.config.valid_candidate_size)

        sources = prefix_stream.sources + tuple(
            'candidate_%s' % k for k in candidate_stream.sources)
        stream = Merge((prefix_stream, candidate_stream), sources)

        stream = transformers.Select(stream, tuple(req_vars))
        # stream = MultiProcessing(stream)

        return stream
Example #16
    def test(self, req_vars):
        prefix_stream = DataStream(self.test_dataset,
                                   iteration_scheme=SequentialExampleScheme(
                                       self.test_dataset.num_examples))
        prefix_stream = transformers.taxi_add_datetime(prefix_stream)
        prefix_stream = transformers.taxi_add_first_last_len(
            prefix_stream, self.config.n_begin_end_pts)

        if not data.tvt:
            prefix_stream = transformers.taxi_remove_test_only_clients(
                prefix_stream)

        prefix_stream = Batch(prefix_stream,
                              iteration_scheme=ConstantScheme(
                                  self.config.batch_size))

        candidate_stream = self.candidate_stream(
            self.config.test_candidate_size)

        sources = prefix_stream.sources + tuple(
            'candidate_%s' % k for k in candidate_stream.sources)
        stream = Merge((prefix_stream, candidate_stream), sources)
        stream = transformers.Select(stream, tuple(req_vars))
        stream = MultiProcessing(stream)
        return stream
Example #17
def get_train_stream(configuration, sfiles, tfiles, svocab_dict, tvocab_dict):

    s_dataset = TextFile(sfiles, svocab_dict, bos_token=None, eos_token=None,
                         unk_token='<unk>', level='word', preprocess=None,
                         encoding='utf8')
    t_dataset = TextFile(tfiles, tvocab_dict, bos_token=None, eos_token=None,
                         unk_token='<unk>', level='word', preprocess=None,
                         encoding='utf8')

    # Merge source and target into (source, target) pairs
    stream = Merge([s_dataset.get_example_stream(),
                    t_dataset.get_example_stream()],
                   ('source', 'target'))

    # Filter -- drop sequence pairs that are too long
    stream = Filter(stream,
                    predicate=_too_long(seq_len=configuration['seq_len']))

    # Map - not needed here

    # Batch - Sort: read k batches ahead, sort by length, then re-batch
    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(
                       configuration['batch_size'] *
                       configuration['sort_k_batches']))
    stream = Mapping(stream, SortMapping(_length))
    stream = Unpack(stream)
    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(configuration['batch_size']))

    # Pad
    # Note that </s>=0; Fuel only pads with 0 by default
    masked_stream = Padding(stream)

    return masked_stream
Example #18
def get_dev_stream_with_context_features(val_context_features=None, val_set=None, src_vocab=None,
                                         src_vocab_size=30000, unk_id=1, **kwargs):
    """Setup development set stream if necessary."""

    def _get_np_array(filename):
        return numpy.load(filename)['arr_0']


    dev_stream = None
    if val_set is not None and src_vocab is not None:
        src_vocab = _ensure_special_tokens(
            src_vocab if isinstance(src_vocab, dict) else
            cPickle.load(open(src_vocab)),
            bos_idx=0, eos_idx=src_vocab_size - 1, unk_idx=unk_id)

        dev_dataset = TextFile([val_set], src_vocab, None)

        # now add the source with the image features
        # create the image datastream (iterate over a file line-by-line)
        con_features = _get_np_array(val_context_features)
        con_feature_dataset = IterableDataset(con_features)
        valid_image_stream = DataStream(con_feature_dataset)

        # dev_stream = DataStream(dev_dataset)
        dev_stream = Merge([dev_dataset.get_example_stream(),
                            valid_image_stream], ('source', 'initial_context'))
    #         dev_stream = dev_stream.get_example_stream()

    return dev_stream
Example #19
    def train(self, req_vars):
        prefix_stream = DataStream(self.train_dataset,
                                   iteration_scheme=ShuffledExampleScheme(
                                       self.train_dataset.num_examples))

        if not data.tvt:
            prefix_stream = transformers.TaxiExcludeTrips(
                prefix_stream, self.valid_trips_ids)
        prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)
        prefix_stream = transformers.TaxiGenerateSplits(
            prefix_stream, max_splits=self.config.max_splits)
        prefix_stream = transformers.taxi_add_datetime(prefix_stream)
        prefix_stream = transformers.taxi_add_first_last_len(
            prefix_stream, self.config.n_begin_end_pts)
        prefix_stream = Batch(prefix_stream,
                              iteration_scheme=ConstantScheme(
                                  self.config.batch_size))

        candidate_stream = self.candidate_stream(
            self.config.train_candidate_size)

        sources = prefix_stream.sources + tuple(
            'candidate_%s' % k for k in candidate_stream.sources)
        stream = Merge((prefix_stream, candidate_stream), sources)
        stream = transformers.Select(stream, tuple(req_vars))
        stream = MultiProcessing(stream)
        return stream
Example #20
def _get_align_stream(src_data, trg_data, src_vocab_size, trg_vocab_size,
                      seq_len, **kwargs):
    """Creates the stream which is used for the main loop.
    
    Args:
        src_data (string): Path to the source sentences
        trg_data (string): Path to the target sentences
        src_vocab_size (int): Size of the source vocabulary in the NMT
                              model
        trg_vocab_size (int): Size of the target vocabulary in the NMT
                              model
        seq_len (int): Maximum length of any source or target sentence
    
    Returns:
        ExplicitNext. Alignment data stream which can be iterated
        explicitly
    """
    # Build dummy vocabulary to make TextFile happy
    src_vocab = _add_special_ids({str(i): i for i in xrange(src_vocab_size)})
    trg_vocab = _add_special_ids({str(i): i for i in xrange(trg_vocab_size)})
    # Get text files from both source and target
    src_dataset = TextFile([src_data], src_vocab, None)
    trg_dataset = TextFile([trg_data], trg_vocab, None)
    # Merge them to get a source, target pair
    s = Merge(
        [src_dataset.get_example_stream(),
         trg_dataset.get_example_stream()], ('source', 'target'))
    s = Filter(s, predicate=stream._too_long(seq_len=seq_len))
    s = Batch(s, iteration_scheme=ConstantScheme(1))
    masked_stream = stream.PaddingWithEOS(s, [utils.EOS_ID, utils.EOS_ID])
    return ExplicitNext(masked_stream)
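The dummy vocabularies above are built with _add_special_ids, which is not shown on this page. A hedged sketch of what such a helper presumably does; the exact reserved ids are an assumption based on the other streams here, which use 0 for the beginning-of-sentence token and the project's utils constants for </S> and <UNK>:

def _add_special_ids(vocab):
    """Assumed helper: inject the reserved token ids into a {token: id} map."""
    vocab['<S>'] = 0              # assumed beginning-of-sentence id
    vocab['</S>'] = utils.EOS_ID  # constants from the project's utils module
    vocab['<UNK>'] = utils.UNK_ID
    return vocab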
Example #21
def get_dev_stream(val_set=None,
                   valid_sent_dict=None,
                   src_vocab=None,
                   trg_vocab=None,
                   src_vocab_size=30000,
                   trg_vocab_size=30000,
                   unk_id=1,
                   **kwargs):
    """Setup development set stream if necessary."""

    dev_stream = None
    if val_set is not None and src_vocab is not None:
        # Load dictionaries and ensure special tokens exist
        src_vocab = ensure_special_tokens(src_vocab if isinstance(
            src_vocab, dict) else cPickle.load(open(src_vocab)),
                                          bos_idx=0,
                                          eos_idx=src_vocab_size - 1,
                                          unk_idx=unk_id)

        trg_vocab = ensure_special_tokens(trg_vocab if isinstance(
            trg_vocab, dict) else cPickle.load(open(trg_vocab)),
                                          bos_idx=0,
                                          eos_idx=trg_vocab_size - 1,
                                          unk_idx=unk_id)

        dev_dataset = TextFile([val_set], src_vocab, None)
        dev_dictset = TextFile([valid_sent_dict], trg_vocab, None)
        #dev_stream = DataStream(dev_dataset)
        # Merge them to get a source, target pair
        dev_stream = Merge([
            dev_dataset.get_example_stream(),
            dev_dictset.get_example_stream()
        ], ('source', 'valid_sent_trg_dict'))
    return dev_stream
Example #22
def get_dev_stream_with_topicalq(test_set=None,
                                 src_vocab=None,
                                 src_vocab_size=30000,
                                 topical_test_set=None,
                                 topical_vocab=None,
                                 topical_vocab_size=2000,
                                 unk_id=1,
                                 **kwargs):
    """Setup development set stream if necessary."""
    dev_stream = None
    if test_set is not None and src_vocab is not None:
        src_vocab = _ensure_special_tokens(src_vocab if isinstance(
            src_vocab, dict) else cPickle.load(open(src_vocab, 'rb')),
                                           bos_idx=0,
                                           eos_idx=src_vocab_size - 1,
                                           unk_idx=unk_id)
        print test_set, type(src_vocab)
        topical_vocab = cPickle.load(open(topical_vocab, 'rb'))
        # special tokens are not ensured for the topical vocabulary
        topical_dataset = TextFile([topical_test_set], topical_vocab, None,
                                   None, '10')
        dev_dataset = TextFile([test_set], src_vocab, None)
        #dev_stream = DataStream(dev_dataset)
        # Merge them to get a source, target pair
        dev_stream = Merge([
            dev_dataset.get_example_stream(),
            topical_dataset.get_example_stream()
        ], ('source', 'source_topical'))
    return dev_stream
Example #23
    def test(self, req_vars):
        prefix_stream = DataStream(self.test_dataset,
                                   iteration_scheme=SequentialExampleScheme(
                                       self.test_dataset.num_examples))

        prefix_stream = transformers.taxi_add_datetime(prefix_stream)
        if not data.tvt:
            prefix_stream = transformers.taxi_remove_test_only_clients(
                prefix_stream)

        prefix_stream = Batch(prefix_stream,
                              iteration_scheme=ConstantScheme(
                                  self.config.batch_size))
        prefix_stream = Padding(prefix_stream,
                                mask_sources=['latitude', 'longitude'])

        candidate_stream = self.candidate_stream(
            self.config.test_candidate_size, False)

        sources = prefix_stream.sources + tuple(
            'candidate_%s' % k for k in candidate_stream.sources)
        stream = Merge((prefix_stream, candidate_stream), sources)

        stream = transformers.Select(stream, tuple(req_vars))
        # stream = MultiProcessing(stream)

        return stream
Example #24
def get_test_stream(sfiles, svocab_dict):
    dataset = TextFile(sfiles, svocab_dict, bos_token=None, eos_token=None,
                       unk_token='<unk>', level='word', preprocess=None,
                       encoding='utf8')
    stream = Merge([dataset.get_example_stream()], ('source',))
    stream = Batch(stream, iteration_scheme=ConstantScheme(10))
    stream = Padding(stream)
    return stream
Example #25
def get_tr_stream(src_vocab, trg_vocab, src_data, trg_data,
                  src_vocab_size=30000, trg_vocab_size=30000, unk_id=1,
                  seq_len=50, batch_size=80, sort_k_batches=12, **kwargs):
    """Prepares the training data stream."""

    # Load dictionaries and ensure special tokens exist
    src_vocab = _ensure_special_tokens(
        src_vocab if isinstance(src_vocab, dict)
        else cPickle.load(open(src_vocab, 'rb')),
        bos_idx=0, eos_idx=src_vocab_size - 1, unk_idx=unk_id)
    trg_vocab = _ensure_special_tokens(
        trg_vocab if isinstance(trg_vocab, dict) else
        cPickle.load(open(trg_vocab, 'rb')),
        bos_idx=0, eos_idx=trg_vocab_size - 1, unk_idx=unk_id)

    # Get text files from both source and target
    src_dataset = TextFile([src_data], src_vocab, None)
    trg_dataset = TextFile([trg_data], trg_vocab, None)

    # Merge them to get a source, target pair
    stream = Merge([src_dataset.get_example_stream(),
                    trg_dataset.get_example_stream()],
                   ('source', 'target'))

    # Filter sequences that are too long
    stream = Filter(stream,
                    predicate=_too_long(seq_len=seq_len))

    # Replace out of vocabulary tokens with unk token
    stream = Mapping(stream,
                     _oov_to_unk(src_vocab_size=src_vocab_size,
                                 trg_vocab_size=trg_vocab_size,
                                 unk_id=unk_id))

    # Build a batched version of stream to read k batches ahead
    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(
                       batch_size*sort_k_batches))

    # Sort all samples in the read-ahead batch
    stream = Mapping(stream, SortMapping(_length))

    # Convert it into a stream again
    stream = Unpack(stream)

    # Construct batches from the stream with specified batch size
    stream = Batch(
        stream, iteration_scheme=ConstantScheme(batch_size))

    # Pad sequences that are short
    masked_stream = PaddingWithEOS(
        stream, [src_vocab_size - 1, trg_vocab_size - 1])

    return masked_stream
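The _oov_to_unk mapping used above clips out-of-vocabulary ids back to the unknown token. A sketch of its usual shape in the blocks-examples stream module, given as an approximation rather than this repository's exact code:

# Any id at or beyond the truncated vocabulary size becomes unk_id,
# separately for the source and the target sentence.
class _oov_to_unk(object):
    def __init__(self, src_vocab_size=30000, trg_vocab_size=30000, unk_id=1):
        self.src_vocab_size = src_vocab_size
        self.trg_vocab_size = trg_vocab_size
        self.unk_id = unk_id

    def __call__(self, sentence_pair):
        return ([x if x < self.src_vocab_size else self.unk_id
                 for x in sentence_pair[0]],
                [x if x < self.trg_vocab_size else self.unk_id
                 for x in sentence_pair[1]])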
Example #26
def get_sgnmt_tr_stream(src_data,
                        trg_data,
                        src_vocab_size=30000,
                        trg_vocab_size=30000,
                        unk_id=1,
                        seq_len=50,
                        batch_size=80,
                        sort_k_batches=12,
                        **kwargs):
    """Prepares the unshuffled training data stream. This corresponds 
    to ``get_sgnmt_tr_stream`` in ``machine_translation/stream`` in the
    blocks examples."""

    # Build dummy vocabulary to make TextFile happy
    src_vocab = add_special_ids({str(i): i for i in xrange(src_vocab_size)})
    trg_vocab = add_special_ids({str(i): i for i in xrange(trg_vocab_size)})

    # Get text files from both source and target
    src_dataset = TextFile([src_data], src_vocab, None)
    trg_dataset = TextFile([trg_data], trg_vocab, None)

    # Merge them to get a source, target pair
    s = Merge(
        [src_dataset.get_example_stream(),
         trg_dataset.get_example_stream()], ('source', 'target'))

    # Filter sequences that are too long
    s = Filter(s, predicate=stream._too_long(seq_len=seq_len))

    # Replace out of vocabulary tokens with unk token
    s = Mapping(
        s,
        stream._oov_to_unk(src_vocab_size=src_vocab_size,
                           trg_vocab_size=trg_vocab_size,
                           unk_id=utils.UNK_ID))

    # Build a batched version of stream to read k batches ahead
    s = Batch(s, iteration_scheme=ConstantScheme(batch_size * sort_k_batches))

    # Sort all samples in the read-ahead batch
    s = Mapping(s, SortMapping(stream._length))

    # Convert it into a stream again
    s = Unpack(s)

    # Construct batches from the stream with specified batch size
    s = Batch(s, iteration_scheme=ConstantScheme(batch_size))

    # Pad sequences that are short
    masked_stream = stream.PaddingWithEOS(s, [utils.EOS_ID, utils.EOS_ID])

    return masked_stream
Example #27
def get_test_stream_withContext_grdTruth(test_ctx_datas=None,
                                         test_set_source=None,
                                         test_set_target=None,
                                         src_vocab=None,
                                         src_vocab_size=30000,
                                         trg_vocab=None,
                                         trg_vocab_size=30000,
                                         batch_size=128,
                                         unk_id=1,
                                         ctx_num=3,
                                         **kwargs):
    """Setup development set stream if necessary."""
    masked_stream = None
    if test_set_source is not None and src_vocab is not None:
        src_vocab = _ensure_special_tokens(src_vocab if isinstance(
            src_vocab, dict) else cPickle.load(open(src_vocab, 'rb')),
                                           bos_idx=0,
                                           eos_idx=src_vocab_size - 1,
                                           unk_idx=unk_id)
        trg_vocab = _ensure_special_tokens(trg_vocab if isinstance(
            trg_vocab, dict) else cPickle.load(open(trg_vocab, 'rb')),
                                           bos_idx=0,
                                           eos_idx=trg_vocab_size - 1,
                                           unk_idx=unk_id)
        print test_set_source, type(src_vocab)
        # Get text files from both source and target
        ctx_datasets = []
        for i in range(ctx_num):
            ctx_datasets.append(TextFile([test_ctx_datas[i]], src_vocab, None))
        dev_dataset = TextFile([test_set_source], src_vocab, None)
        dev_target = TextFile([test_set_target], trg_vocab, None)
        dev_stream = Merge([i.get_example_stream() for i in ctx_datasets] + [
            dev_dataset.get_example_stream(),
            dev_target.get_example_stream()
        ],
                           tuple('context_' + str(i) for i in range(ctx_num)) +
                           ('source', 'target'))
        stream = Mapping(
            dev_stream,
            _oov_to_unk(ctx_num=ctx_num,
                        src_vocab_size=src_vocab_size,
                        trg_vocab_size=trg_vocab_size,
                        unk_id=unk_id))

        # Build a batched version of stream to read k batches ahead
        stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size))
        masked_stream = PaddingWithEOSContext(
            stream, [src_vocab_size - 1
                     for i in range(ctx_num + 1)] + [trg_vocab_size - 1])

    return masked_stream
Example #28
def pair_data_stream(dataset, batch_size):
    data_streams = [
        Rename(_data_stream(dataset=dataset, batch_size=batch_size),
               names={
                   source: '{}_{}'.format(source, i)
                   for source in dataset.sources
               }) for i in [1, 2]
    ]
    data_stream = Merge(data_streams=data_streams,
                        sources=data_streams[0].sources +
                        data_streams[1].sources)
    _ = data_streams[0].get_epoch_iterator()  # make sure not same order

    return data_stream
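Merge requires every source name to be unique across its input streams, which is why pair_data_stream renames each copy's sources with a _1/_2 suffix before merging. A toy sketch of the same pattern; the dataset and source names here are made up for illustration:

# Rename the sources of two copies of a stream so Merge can expose both
# sides of a pair under distinct names.
from fuel.datasets import IterableDataset
from fuel.streams import DataStream
from fuel.transformers import Merge, Rename

dataset = IterableDataset([1, 2, 3])   # single iterable -> default source 'data'
pair = [Rename(DataStream(dataset), names={'data': 'data_{}'.format(i)})
        for i in (1, 2)]
merged = Merge(pair, sources=pair[0].sources + pair[1].sources)
print(next(merged.get_epoch_iterator()))
# (1, 1)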
Example #29
def get_dev_stream_withContext_withPosTag(test_ctx_datas=None,
                                          test_posTag_datas=None,
                                          test_set_source=None,
                                          src_vocab=None,
                                          src_vocab_size=30000,
                                          unk_id=1,
                                          ctx_num=3,
                                          **kwargs):
    """Setup development set stream if necessary."""
    dev_stream = None
    masked_stream = None
    if test_set_source is not None and src_vocab is not None:
        src_vocab = _ensure_special_tokens(src_vocab if isinstance(
            src_vocab, dict) else cPickle.load(open(src_vocab, 'rb')),
                                           bos_idx=0,
                                           eos_idx=src_vocab_size - 1,
                                           unk_idx=unk_id)
        print test_set_source, type(src_vocab)
        # Get text files from both source and target
        ctx_datasets = []
        posTag_datasets = []
        for i in range(ctx_num):
            ctx_datasets.append(TextFile([test_ctx_datas[i]], src_vocab, None))
            posTag_datasets.append(
                TextFile([test_posTag_datas[i]], src_vocab, None))
        posTag_datasets.append(
            TextFile([test_posTag_datas[ctx_num]], src_vocab, None))
        src_dataset = TextFile([test_set_source], src_vocab, None)

        # Merge them to get a source, target pair
        dev_stream = Merge(
            [i.get_example_stream() for i in ctx_datasets] +
            [i.get_example_stream()
             for i in posTag_datasets] + [src_dataset.get_example_stream()],
            tuple('context_' + str(i) for i in range(ctx_num)) +
            tuple('context_posTag_' + str(i)
                  for i in range(ctx_num)) + ('source_posTag', 'source'))

        stream = Mapping(
            dev_stream,
            _oov_to_unk_posTag_dev(ctx_num=ctx_num,
                                   src_vocab_size=src_vocab_size,
                                   unk_id=unk_id))

        # Build a batched version of stream to read k batches ahead
        stream = Batch(stream, iteration_scheme=ConstantScheme(1))
        masked_stream = PaddingWithEOSContext(
            stream, [src_vocab_size - 1 for i in range(2 * ctx_num + 2)])

    return masked_stream
Example #30
def get_dev_stream_with_grdTruth(val_set_source=None,
                                 val_set_target=None,
                                 src_vocab=None,
                                 src_vocab_size=30000,
                                 trg_vocab=None,
                                 trg_vocab_size=30000,
                                 batch_size=128,
                                 unk_id=1,
                                 seq_len=50,
                                 **kwargs):
    """Setup development set stream if necessary."""
    dev_stream = None
    masked_stream = None
    if val_set_source is not None and src_vocab is not None:
        src_vocab = _ensure_special_tokens(src_vocab if isinstance(
            src_vocab, dict) else cPickle.load(open(src_vocab, 'rb')),
                                           bos_idx=0,
                                           eos_idx=src_vocab_size - 1,
                                           unk_idx=unk_id)
        trg_vocab = _ensure_special_tokens(trg_vocab if isinstance(
            trg_vocab, dict) else cPickle.load(open(trg_vocab, 'rb')),
                                           bos_idx=0,
                                           eos_idx=trg_vocab_size - 1,
                                           unk_idx=unk_id)

        print val_set_source, type(src_vocab)
        dev_dataset = TextFile([val_set_source], src_vocab, None)
        trg_dataset = TextFile([val_set_target], trg_vocab, None)
        # Merge them to get a source, target pair
        dev_stream = Merge([
            dev_dataset.get_example_stream(),
            trg_dataset.get_example_stream()
        ], ('dev_source', 'dev_target'))
        # Filter sequences that are too long
        stream = Filter(dev_stream, predicate=_too_long(seq_len=seq_len))

        # Replace out of vocabulary tokens with unk token
        stream = Mapping(
            stream,
            _oov_to_unk(src_vocab_size=src_vocab_size,
                        trg_vocab_size=trg_vocab_size,
                        unk_id=unk_id))

        # Build a batched version of stream to read k batches ahead
        stream = Batch(stream, iteration_scheme=ConstantScheme(1))
        # Pad sequences that are short
        masked_stream = PaddingWithEOS(
            stream, [src_vocab_size - 1, trg_vocab_size - 1])
    return masked_stream
Example #31
def get_log_prob_stream(cg, config):
    eid, did = p_(cg)
    dataset = config['log_prob_sets'][cg]

    # Prepare source vocabs and files, make sure special tokens are there
    src_vocab = cPickle.load(open(config['src_vocabs'][eid]))
    src_vocab['<S>'] = 0
    src_vocab['</S>'] = config['src_eos_idxs'][eid]
    src_vocab['<UNK>'] = config['unk_id']

    # Prepare target vocabs and files, make sure special tokens are there
    trg_vocab = cPickle.load(open(config['trg_vocabs'][did]))
    trg_vocab['<S>'] = 0
    trg_vocab['</S>'] = config['trg_eos_idxs'][did]
    trg_vocab['<UNK>'] = config['unk_id']

    # Build the preprocessing pipeline for individual streams
    logger.info('Building logprob stream for cg:[{}]'.format(cg))
    src_dataset = TextFile([dataset[0]], src_vocab, None)
    trg_dataset = TextFile([dataset[1]], trg_vocab, None)
    stream = Merge(
        [src_dataset.get_example_stream(),
         trg_dataset.get_example_stream()], ('source', 'target'))

    stream = Mapping(
        stream,
        _oov_to_unk(src_vocab_size=config['src_vocab_sizes'][eid],
                    trg_vocab_size=config['trg_vocab_sizes'][did],
                    unk_id=config['unk_id']))
    bs = 100
    if 'log_prob_bs' in config:
        if isinstance(config['log_prob_bs'], dict):
            bs = config['log_prob_bs'][cg]
        else:
            bs = config['log_prob_bs']
    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(bs,
                                                   num_examples=get_num_lines(
                                                       dataset[0])))

    masked_stream = Padding(stream)
    masked_stream = Mapping(
        masked_stream,
        _remapWordIdx([(0, 0, config['src_eos_idxs'][eid]),
                       (2, 0, config['trg_eos_idxs'][did])]))

    return masked_stream
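_remapWordIdx is specific to this codebase. After Padding the sources are (source, source_mask, target, target_mask), so the mapping presumably rewrites the 0s inserted by Padding into the configured </S> ids for sources 0 and 2. A hedged sketch of such a helper, assuming the padded sources are numpy arrays:

class _remapWordIdx(object):
    """Assumed helper: rewrite word ids in selected sources of a padded batch.

    Each mapping is a (source_index, old_id, new_id) triple.
    """
    def __init__(self, mappings):
        self.mappings = mappings

    def __call__(self, batch):
        batch = list(batch)
        for source_idx, old_id, new_id in self.mappings:
            batch[source_idx][batch[source_idx] == old_id] = new_id
        return tuple(batch)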
Example #32
def get_dev_stream_with_prefix_file(val_set=None, val_set_grndtruth=None, val_set_prefixes=None, val_set_suffixes=None,
                                    src_vocab=None, src_vocab_size=30000, trg_vocab=None, trg_vocab_size=30000, unk_id=1,
                                    return_vocab=False, **kwargs):
    """Setup development stream with user-provided source, target, prefixes, and suffixes"""

    dev_stream = None
    if val_set is not None and val_set_grndtruth is not None and val_set_prefixes is not None and val_set_suffixes is not None:
        src_vocab = _ensure_special_tokens(
            src_vocab if isinstance(src_vocab, dict) else
            cPickle.load(open(src_vocab)),
            bos_idx=0, eos_idx=src_vocab_size - 1, unk_idx=unk_id)

        trg_vocab = _ensure_special_tokens(
            trg_vocab if isinstance(trg_vocab, dict) else
            cPickle.load(open(trg_vocab)),
            bos_idx=0, eos_idx=trg_vocab_size - 1, unk_idx=unk_id)

        # Note: user should have already provided the EOS token in the data representation for the suffix
        # Note: The reason that we need EOS tokens in the reference file is that IMT systems need to evaluate metrics
        # Note: which count prediction of the </S> token, and evaluation scripts are called on the files
        dev_source_dataset = TextFile([val_set], src_vocab,
                                      bos_token='<S>',
                                      eos_token='</S>',
                                      unk_token='<UNK>')
        dev_target_dataset = TextFile([val_set_grndtruth], trg_vocab,
                                      bos_token='<S>',
                                      eos_token='</S>',
                                      unk_token='<UNK>')
        dev_prefix_dataset = TextFile([val_set_prefixes], trg_vocab,
                                      bos_token='<S>',
                                      eos_token=None,
                                      unk_token='<UNK>')
        dev_suffix_dataset = TextFile([val_set_suffixes], trg_vocab,
                                      bos_token=None,
                                      eos_token=None,
                                      unk_token='<UNK>')

        dev_stream = Merge([dev_source_dataset.get_example_stream(),
                            dev_target_dataset.get_example_stream(),
                            dev_prefix_dataset.get_example_stream(),
                            dev_suffix_dataset.get_example_stream()],
                           ('source', 'target','target_prefix','target_suffix'))

    if return_vocab:
        return dev_stream, src_vocab, trg_vocab
    else:
        return dev_stream
Example #33
def load_parallel_data(src_file,
                       tgt_file,
                       batch_size,
                       sort_k_batches,
                       dictionary,
                       training=False):
    def preproc(s):
        s = s.replace('``', '"')
        s = s.replace('\'\'', '"')
        return s

    enc_dset = TextFile(files=[src_file],
                        dictionary=dictionary,
                        bos_token=None,
                        eos_token=None,
                        unk_token=CHAR_UNK_TOK,
                        level='character',
                        preprocess=preproc)
    dec_dset = TextFile(files=[tgt_file],
                        dictionary=dictionary,
                        bos_token=CHAR_SOS_TOK,
                        eos_token=CHAR_EOS_TOK,
                        unk_token=CHAR_UNK_TOK,
                        level='character',
                        preprocess=preproc)
    # NOTE merge encoder and decoder setup together
    stream = Merge(
        [enc_dset.get_example_stream(),
         dec_dset.get_example_stream()], ('source', 'target'))
    if training:
        # filter sequences that are too long
        stream = Filter(stream, predicate=TooLong(seq_len=CHAR_MAX_SEQ_LEN))
        # batch and read k batches ahead
        stream = Batch(stream,
                       iteration_scheme=ConstantScheme(batch_size *
                                                       sort_k_batches))
        # sort all samples in read-ahead batch
        stream = Mapping(stream, SortMapping(lambda x: len(x[1])))
        # turn back into stream
        stream = Unpack(stream)
    # batch again
    stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size))
    masked_stream = Padding(stream)
    return masked_stream
Example #34
def get_dev_tr_stream_with_topic_target(val_set_source=None, val_set_target=None,
                                        src_vocab=None, trg_vocab=None,
                                        src_vocab_size=30000, trg_vocab_size=30000,
                                        trg_topic_vocab_size=2000,
                                        source_topic_vocab_size=2000,
                                        topical_dev_set=None,
                                        topic_vocab_input=None,
                                        topic_vocab_output=None,
                                        topical_vocab_size=2000,
                                        unk_id=1, **kwargs):
    """Prepares the development data stream."""

    masked_stream = None
    if val_set_source is not None and src_vocab is not None:
        src_vocab = _ensure_special_tokens(
            src_vocab if isinstance(src_vocab, dict)
            else cPickle.load(open(src_vocab, 'rb')),
            bos_idx=0, eos_idx=src_vocab_size - 1, unk_idx=unk_id)
        trg_vocab = _ensure_special_tokens(
            trg_vocab if isinstance(trg_vocab, dict)
            else cPickle.load(open(trg_vocab, 'rb')),
            bos_idx=0, eos_idx=trg_vocab_size - 1, unk_idx=unk_id)
        topic_vocab_input = cPickle.load(open(topic_vocab_input, 'rb'))
        # topic_vocab_output already has <UNK> and </S> in it
        topic_vocab_output = cPickle.load(open(topic_vocab_output, 'rb'))
        topic_binary_vocab = {}
        for k in topic_vocab_output:
            topic_binary_vocab[k] = 0 if k == '<UNK>' else 1

        # Get text files from both source and target
        src_dataset = TextFile([val_set_source], src_vocab, None)
        trg_dataset = TextFile([val_set_target], trg_vocab, None)
        src_topic_input = TextFile([topical_dev_set], topic_vocab_input,
                                   None, None, 'rt')
        trg_topic_dataset = TextFile([val_set_target], topic_vocab_output, None)
        trg_topic_binary_dataset = TextFile([val_set_target],
                                            topic_binary_vocab, None)

        # Merge them to get a source, target pair
        dev_stream = Merge([src_dataset.get_example_stream(),
                            trg_dataset.get_example_stream(),
                            src_topic_input.get_example_stream(),
                            trg_topic_dataset.get_example_stream(),
                            trg_topic_binary_dataset.get_example_stream()],
                           ('source', 'target', 'source_topical',
                            'target_topic', 'target_binary_topic'))
        stream = Batch(dev_stream, iteration_scheme=ConstantScheme(1))
        masked_stream = PaddingWithEOS(
            stream, [src_vocab_size - 1, trg_vocab_size - 1,
                     source_topic_vocab_size - 1, trg_topic_vocab_size - 1,
                     trg_topic_vocab_size - 1])

    return masked_stream
Example #35
def get_dev_stream(sfiles, tfiles, svocab_dict, tvocab_dict):

    s_dataset = TextFile(sfiles, svocab_dict, bos_token=None, eos_token=None,
                         unk_token='<unk>', level='word', preprocess=None,
                         encoding='utf8')
    t_dataset = TextFile(tfiles, tvocab_dict, bos_token=None, eos_token=None,
                         unk_token='<unk>', level='word', preprocess=None,
                         encoding='utf8')

    # Merge
    stream = Merge([s_dataset.get_example_stream(),
                    t_dataset.get_example_stream()],
                   ('source', 'target'))
    # Batch - Sort
    stream = Batch(stream, iteration_scheme=ConstantScheme(1006))
    # Pad
    # Note that </s>=0. Fuel only allows padding 0 by default
    masked_stream = Padding(stream)

    return masked_stream
Example #36
class TestMerge(object):
    def setUp(self):
        self.streams = (
            DataStream(IterableDataset(['Hello world!'])),
            DataStream(IterableDataset(['Bonjour le monde!'])))
        self.batch_streams = (
            Batch(DataStream(IterableDataset(['Hello world!', 'Hi!'])),
                  iteration_scheme=ConstantScheme(2)),
            Batch(DataStream(IterableDataset(['Bonjour le monde!', 'Salut!'])),
                  iteration_scheme=ConstantScheme(2)))
        self.transformer = Merge(
            self.streams, ('english', 'french'))
        self.batch_transformer = Merge(
            self.batch_streams, ('english', 'french'))

    def test_sources(self):
        assert_equal(self.transformer.sources, ('english', 'french'))

    def test_merge(self):
        it = self.transformer.get_epoch_iterator()
        assert_equal(next(it), ('Hello world!', 'Bonjour le monde!'))
        assert_raises(StopIteration, next, it)
        # There used to be problems with resetting Merge, for which
        # reason we regression-test it as follows:
        it = self.transformer.get_epoch_iterator()
        assert_equal(next(it), ('Hello world!', 'Bonjour le monde!'))
        assert_raises(StopIteration, next, it)

    def test_batch_merge(self):
        it = self.batch_transformer.get_epoch_iterator()
        assert_equal(next(it),
                     (('Hello world!', 'Hi!'),
                      ('Bonjour le monde!', 'Salut!')))
        assert_raises(StopIteration, next, it)
        # There used to be problems with resetting Merge, for which
        # reason we regression-test it as follows:
        it = self.batch_transformer.get_epoch_iterator()
        assert_equal(next(it),
                     (('Hello world!', 'Hi!'),
                      ('Bonjour le monde!', 'Salut!')))
        assert_raises(StopIteration, next, it)

    def test_merge_batch_streams(self):
        it = self.transformer.get_epoch_iterator()
        assert_equal(next(it), ('Hello world!', 'Bonjour le monde!'))
        assert_raises(StopIteration, next, it)
        # There used to be problems with resetting Merge, for which
        # reason we regression-test it as follows:
        it = self.transformer.get_epoch_iterator()
        assert_equal(next(it), ('Hello world!', 'Bonjour le monde!'))
        assert_raises(StopIteration, next, it)

    def test_as_dict(self):
        assert_equal(
            next(self.transformer.get_epoch_iterator(as_dict=True)),
            ({'english': 'Hello world!', 'french': 'Bonjour le monde!'}))

    def test_error_on_wrong_number_of_sources(self):
        assert_raises(ValueError, Merge, self.streams, ('english',))

    def test_value_error_on_different_stream_output_type(self):
        spanish_stream = DataStream(IndexableDataset(['Hola mundo!']),
                                    iteration_scheme=SequentialScheme(2, 2))
        assert_raises(ValueError, Merge, self.streams + (spanish_stream,),
                      ('english', 'french', 'spanish'))

    def test_close_calls_close_on_all_streams(self):
        streams = [FlagDataStream(IterableDataset([1, 2, 3])),
                   FlagDataStream(IterableDataset([4, 5, 6])),
                   FlagDataStream(IterableDataset([7, 8, 9]))]
        transformer = Merge(streams, ('1', '2', '3'))
        transformer.close()
        assert all(stream.close_called for stream in streams)

    def test_next_epoch_calls_next_epoch_on_all_streams(self):
        streams = [FlagDataStream(IterableDataset([1, 2, 3])),
                   FlagDataStream(IterableDataset([4, 5, 6])),
                   FlagDataStream(IterableDataset([7, 8, 9]))]
        transformer = Merge(streams, ('1', '2', '3'))
        transformer.next_epoch()
        assert all(stream.next_epoch_called for stream in streams)

    def test_reset_calls_reset_on_all_streams(self):
        streams = [FlagDataStream(IterableDataset([1, 2, 3])),
                   FlagDataStream(IterableDataset([4, 5, 6])),
                   FlagDataStream(IterableDataset([7, 8, 9]))]
        transformer = Merge(streams, ('1', '2', '3'))
        transformer.reset()
        assert all(stream.reset_called for stream in streams)
Example #37
 def test_merge(self):
     transformer = Merge(self.streams, ("english", "french"))
     assert_equal(next(transformer.get_epoch_iterator()), ("Hello world!", "Bonjour le monde!"))
Example #38
 def setUp(self):
     self.streams = (
         DataStream(IterableDataset(['Hello world!'])),
         DataStream(IterableDataset(['Bonjour le monde!'])))
     self.transformer = Merge(self.streams, ('english', 'french'))
Example #39
 def test_as_dict(self):
     transformer = Merge(self.streams, ("english", "french"))
     assert_equal(
         next(transformer.get_epoch_iterator(as_dict=True)),
         ({"english": "Hello world!", "french": "Bonjour le monde!"}),
     )
Example #40
 def test_merge(self):
     transformer = Merge(self.streams, ('english', 'french'))
     assert_equal(next(transformer.get_epoch_iterator()),
                  ('Hello world!', 'Bonjour le monde!'))
Example #41
 def test_as_dict(self):
     transformer = Merge(self.streams, ('english', 'french'))
     assert_equal(
         next(transformer.get_epoch_iterator(as_dict=True)),
         ({'english': 'Hello world!', 'french': 'Bonjour le monde!'}))