Example #1
def setup_datastream(path, batch_size, sort_batch_count, valid=False):
    # Raw input features, phoneme annotations, and the (start, end) ranges
    # mapping each sequence to its phonemes.
    features = numpy.load(
        os.path.join(path,
                     ('valid_x_raw.npy' if valid else 'train_x_raw.npy')))
    phonemes = numpy.load(
        os.path.join(path, ('valid_phn.npy' if valid else 'train_phn.npy')))
    seq_to_phn = numpy.load(
        os.path.join(
            path,
            ('valid_seq_to_phn.npy' if valid else 'train_seq_to_phn.npy')))

    # Phoneme label sequence for each utterance.
    targets = [phonemes[r[0]:r[1], 2] for r in seq_to_phn]

    ds = IndexableDataset({'input': features, 'output': targets})
    stream = DataStream(ds,
                        iteration_scheme=ShuffledExampleScheme(len(features)))

    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(batch_size *
                                                   sort_batch_count))
    comparison = _balanced_batch_helper(stream.sources.index('input'))
    stream = Mapping(stream, SortMapping(comparison))
    stream = Unpack(stream)

    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(batch_size,
                                                   num_examples=len(features)))
    stream = Padding(stream, mask_sources=['input', 'output'])

    return ds, stream
def setup_squad_ranker_datastream(path,
                                  vocab_file,
                                  config,
                                  example_count=1836975):
    ds = SQuADRankerDataset(path, vocab_file)
    it = ShuffledExampleScheme(examples=example_count)
    stream = DataStream(ds, iteration_scheme=it)

    # Sort sets of multiple batches to make batches of similar sizes
    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(config.batch_size *
                                                   config.sort_batch_count))
    comparison = _balanced_batch_helper(stream.sources.index('question'))
    stream = Mapping(stream, SortMapping(comparison))
    stream = Unpack(stream)

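    # Batch at the real batch size and add 0/1 masks for the padded sources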
    stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size))
    stream = Padding(stream,
                     mask_sources=[
                         'question', 'answer', 'better', 'worse', 'b_left',
                         'b_right', 'w_left', 'w_right'
                     ],
                     mask_dtype='int32')

    return ds, stream
Example #3
def setup_datastream(path, vocab_file, config):
    ds = QADataset(path,
                   vocab_file,
                   config.n_entities,
                   need_sep_token=config.concat_ctx_and_question)
    it = QAIterator(path, shuffle=config.shuffle_questions)

    stream = DataStream(ds, iteration_scheme=it)

    if config.concat_ctx_and_question:
        stream = ConcatCtxAndQuestion(stream, config.concat_question_before,
                                      ds.reverse_vocab['<SEP>'])

    # Sort sets of multiple batches to make batches of similar sizes
    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(config.batch_size *
                                                   config.sort_batch_count))
    comparison = _balanced_batch_helper(
        stream.sources.index(
            'question' if config.concat_ctx_and_question else 'context'))
    stream = Mapping(stream, SortMapping(comparison))
    stream = Unpack(stream)

    print('sources')
    print(stream.sources)

    stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size))
    stream = Padding(stream,
                     mask_sources=['context', 'question', 'candidates'],
                     mask_dtype='int32')

    print('sources2')
    print(stream.sources)

    return ds, stream
Example #4
    def train(self, req_vars):
        stream = TaxiDataset('train', data.traintest_ds)

        if hasattr(self.config, 'use_cuts_for_training') and self.config.use_cuts_for_training:
            stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
        else:
            stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))

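        # Exclude trips that appear in the validation set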
        if not data.tvt:
            valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
            valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
            stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)

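        # Derive up to max_splits partial trips from each full trip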
        stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)

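        # Optionally shuffle examples within larger batches using a random sort key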
        if hasattr(self.config, 'shuffle_batch_size'):
            stream = transformers.Batch(stream, iteration_scheme=ConstantScheme(self.config.shuffle_batch_size))
            stream = Mapping(stream, SortMapping(key=UniformGenerator()))
            stream = Unpack(stream)

        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
        stream = transformers.Select(stream, tuple(req_vars))
        
        stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))

        stream = MultiProcessing(stream)

        return stream
Example #5
def get_tr_stream(src_vocab, trg_vocab, src_data, trg_data,
                  src_vocab_size=30000, trg_vocab_size=30000,
                  unk_id=0, eos_id=1, bos_id=2, train_noise=0,
                  seq_len=50, batch_size=80, sort_k_batches=12, **kwargs):
    src_stream = get_stream(src_vocab, src_data, src_vocab_size, unk_id, eos_id, bos_id, train_noise)
    trg_stream = get_stream(trg_vocab, trg_data, trg_vocab_size, unk_id, eos_id, bos_id, 0)

    # Merge them to get a source, target pair
    stream = Merge([src_stream, trg_stream], ('source', 'target'))

    # Filter sequences that are too long
    stream = Filter(stream, predicate=_not_too_long(seq_len))

    # Build a batched version of stream to read k batches ahead
    stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size * sort_k_batches))

    # Sort all samples in the read-ahead batch
    stream = Mapping(stream, SortMapping(_length))

    # Convert it into a stream again
    stream = Unpack(stream)

    # Construct batches from the stream with specified batch size
    stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size))

    # Pad sequences that are short
    return PaddingWithEOS(stream, [eos_id, eos_id])
def setup_squad_datastream(path, vocab_file, config):
    ds = SQuADDataset(path, vocab_file)
    it = SQuADIterator(path)
    stream = DataStream(ds, iteration_scheme=it)

    if config.concat_ctx_and_question:
        stream = ConcatCtxAndQuestion(stream, config.concat_question_before,
                                      ds.reverse_vocab['<DUMMY>'])

    # Sort sets of multiple batches to make batches of similar sizes
    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(config.batch_size *
                                                   config.sort_batch_count))
    comparison = _balanced_batch_helper(stream.sources.index('context'))
    stream = Mapping(stream, SortMapping(comparison))
    stream = Unpack(stream)

    stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size))
    stream = Padding(stream,
                     mask_sources=[
                         'context', 'question', 'answer', 'ans_indices',
                         'ans_boundaries'
                     ],
                     mask_dtype='int32')

    return ds, stream
def test_cache():
    dataset = IterableDataset(range(100))
    stream = DataStream(dataset)
    batched_stream = Batch(stream, ConstantScheme(11))
    cached_stream = Cache(batched_stream, ConstantScheme(7))
    epoch = cached_stream.get_epoch_iterator()

    # Make sure that cache is filled as expected
    for (features, ), cache_size in zip(epoch,
                                        [4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 0, 4]):
        assert len(cached_stream.cache[0]) == cache_size

    # Make sure that the epoch finishes correctly
    for (features, ) in cached_stream.get_epoch_iterator():
        pass
    assert len(features) == 100 % 7
    assert not cached_stream.cache[0]

    # Ensure that the epoch transition is correct
    cached_stream = Cache(batched_stream, ConstantScheme(7, times=3))
    for _, epoch in zip(range(2), cached_stream.iterate_epochs()):
        cache_sizes = [4, 8, 1]
        for i, (features, ) in enumerate(epoch):
            assert len(cached_stream.cache[0]) == cache_sizes[i]
            assert len(features) == 7
            assert numpy.all(list(range(100))[i * 7:(i + 1) * 7] == features)
        assert i == 2
Example #8
    def setUp(self):
        data = range(10)
        self.stream = Batch(DataStream(IterableDataset(data)),
                            iteration_scheme=ConstantScheme(2))
        data_np = numpy.arange(10)
        self.stream_np = Batch(DataStream(IterableDataset(data_np)),
                               iteration_scheme=ConstantScheme(2))
Example #9
def test_concatenated_scheme():
    sch = ConcatenatedScheme(schemes=[
        ConstantScheme(batch_size=10, times=5),
        ConstantScheme(batch_size=20, times=3),
        ConstantScheme(batch_size=30, times=1)
    ])
    assert (list(sch.get_request_iterator()) == ([10] * 5) + ([20] * 3) + [30])
Example #10
def get_train_stream(configuration, sfiles, tfiles, svocab_dict, tvocab_dict):

    s_dataset = TextFile(sfiles, svocab_dict, bos_token=None, eos_token=None,
                         unk_token='<unk>', level='word', preprocess=None,
                         encoding='utf8')
    t_dataset = TextFile(tfiles, tvocab_dict, bos_token=None, eos_token=None,
                         unk_token='<unk>', level='word', preprocess=None,
                         encoding='utf8')

    # Merge
    stream = Merge([s_dataset.get_example_stream(),
                    t_dataset.get_example_stream()],
                   ('source', 'target'))
    # Filter -- TODO
    stream = Filter(stream,
                    predicate=_too_long(seq_len=configuration['seq_len']))

    # Map - no need

    # Batch - Sort
    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(
                       configuration['batch_size'] *
                       configuration['sort_k_batches']))
    stream = Mapping(stream, SortMapping(_length))
    stream = Unpack(stream)
    stream = Batch(
        stream, iteration_scheme=ConstantScheme(configuration['batch_size']))

    # Pad
    # Note that </s>=0. Fuel only allows padding 0 by default
    masked_stream = Padding(stream)

    return masked_stream
def balanced_batch(stream, key, batch_size, batch_sort_size):
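    """Sort a window of batch_sort_size batches by the `key` source so that
    each batch of batch_size groups examples of similar size (less padding)."""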
    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(batch_size *
                                                   batch_sort_size))
    comparison = _balanced_batch_helper(stream.sources.index(key))
    stream = Mapping(stream, SortMapping(comparison))
    stream = Unpack(stream)
    return Batch(stream, iteration_scheme=ConstantScheme(batch_size))
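The sort-then-batch pattern that `balanced_batch` and most examples above implement can be tried end to end on toy data. The following is a minimal sketch assuming only a Fuel installation; the `sentences` list, the `length_key` helper, and the batch sizes are made up for illustration and do not come from any snippet on this page.

# Minimal sketch (toy data, hypothetical names) of sorted "balanced" batching.
import random

from fuel.datasets import IndexableDataset
from fuel.schemes import ConstantScheme, ShuffledExampleScheme
from fuel.streams import DataStream
from fuel.transformers import Batch, Mapping, Padding, SortMapping, Unpack

# 100 toy integer sequences of varying length under a single 'words' source.
sentences = [[1] * random.randint(1, 15) for _ in range(100)]
dataset = IndexableDataset({'words': sentences})
stream = DataStream(dataset,
                    iteration_scheme=ShuffledExampleScheme(len(sentences)))


def length_key(example):
    # SortMapping passes one example, here a 1-tuple holding a 'words' list.
    return len(example[0])


# Read five batches worth of examples at once, sort them by length,
# then split back into examples and regroup into the real batches.
stream = Batch(stream, iteration_scheme=ConstantScheme(10 * 5))
stream = Mapping(stream, SortMapping(length_key))
stream = Unpack(stream)
stream = Batch(stream, iteration_scheme=ConstantScheme(10))

# Pad each batch to its longest sequence and add a 'words_mask' source.
stream = Padding(stream, mask_sources=('words',))

for words, words_mask in stream.get_epoch_iterator():
    print(words.shape, words_mask.shape)

Because each window of five batches is sorted before being re-batched, every final batch of ten holds sequences of similar length and Padding adds little filler.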
Example #12
def test_concatenated_scheme_infers_request_type():
    assert not ConcatenatedScheme(schemes=[
        ConstantScheme(batch_size=10, times=5),
        ConstantScheme(batch_size=10, times=5)
    ]).requests_examples
    assert ConcatenatedScheme(schemes=[
        SequentialExampleScheme(examples=10),
        SequentialExampleScheme(examples=10)
    ]).requests_examples
Example #13
def get_tr_stream(src_vocab, trg_vocab, src_data, trg_data,
                  src_vocab_size=30000, trg_vocab_size=30000, unk_id=1,
                  seq_len=50, batch_size=80, sort_k_batches=12, **kwargs):
    """Prepares the training data stream."""

    # Load dictionaries and ensure special tokens exist
    src_vocab = _ensure_special_tokens(
        src_vocab if isinstance(src_vocab, dict)
        else cPickle.load(open(src_vocab, 'rb')),
        bos_idx=0, eos_idx=src_vocab_size - 1, unk_idx=unk_id)
    trg_vocab = _ensure_special_tokens(
        trg_vocab if isinstance(trg_vocab, dict) else
        cPickle.load(open(trg_vocab, 'rb')),
        bos_idx=0, eos_idx=trg_vocab_size - 1, unk_idx=unk_id)

    # Get text files from both source and target
    src_dataset = TextFile([src_data], src_vocab, None)
    trg_dataset = TextFile([trg_data], trg_vocab, None)

    # Merge them to get a source, target pair
    stream = Merge([src_dataset.get_example_stream(),
                    trg_dataset.get_example_stream()],
                   ('source', 'target'))

    # Filter sequences that are too long
    stream = Filter(stream,
                    predicate=_too_long(seq_len=seq_len))

    # Replace out of vocabulary tokens with unk token
    stream = Mapping(stream,
                     _oov_to_unk(src_vocab_size=src_vocab_size,
                                 trg_vocab_size=trg_vocab_size,
                                 unk_id=unk_id))

    # Build a batched version of stream to read k batches ahead
    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(
                       batch_size*sort_k_batches))

    # Sort all samples in the read-ahead batch
    stream = Mapping(stream, SortMapping(_length))

    # Convert it into a stream again
    stream = Unpack(stream)

    # Construct batches from the stream with specified batch size
    stream = Batch(
        stream, iteration_scheme=ConstantScheme(batch_size))

    # Pad sequences that are short
    masked_stream = PaddingWithEOS(
        stream, [src_vocab_size - 1, trg_vocab_size - 1])

    return masked_stream
Example #14
def get_sgnmt_tr_stream(src_data,
                        trg_data,
                        src_vocab_size=30000,
                        trg_vocab_size=30000,
                        unk_id=1,
                        seq_len=50,
                        batch_size=80,
                        sort_k_batches=12,
                        **kwargs):
    """Prepares the unshuffled training data stream. This corresponds 
    to ``get_sgnmt_tr_stream`` in ``machine_translation/stream`` in the
    blocks examples."""

    # Build dummy vocabulary to make TextFile happy
    src_vocab = add_special_ids({str(i): i for i in xrange(src_vocab_size)})
    trg_vocab = add_special_ids({str(i): i for i in xrange(trg_vocab_size)})

    # Get text files from both source and target
    src_dataset = TextFile([src_data], src_vocab, None)
    trg_dataset = TextFile([trg_data], trg_vocab, None)

    # Merge them to get a source, target pair
    s = Merge(
        [src_dataset.get_example_stream(),
         trg_dataset.get_example_stream()], ('source', 'target'))

    # Filter sequences that are too long
    s = Filter(s, predicate=stream._too_long(seq_len=seq_len))

    # Replace out of vocabulary tokens with unk token
    s = Mapping(
        s,
        stream._oov_to_unk(src_vocab_size=src_vocab_size,
                           trg_vocab_size=trg_vocab_size,
                           unk_id=utils.UNK_ID))

    # Build a batched version of stream to read k batches ahead
    s = Batch(s, iteration_scheme=ConstantScheme(batch_size * sort_k_batches))

    # Sort all samples in the read-ahead batch
    s = Mapping(s, SortMapping(stream._length))

    # Convert it into a stream again
    s = Unpack(s)

    # Construct batches from the stream with specified batch size
    s = Batch(s, iteration_scheme=ConstantScheme(batch_size))

    # Pad sequences that are short
    masked_stream = stream.PaddingWithEOS(s, [utils.EOS_ID, utils.EOS_ID])

    return masked_stream
    def test_epoch_finishes_correctly(self):
        cached_stream = Cache(self.stream, ConstantScheme(7))
        data = list(cached_stream.get_epoch_iterator())
        assert_equal(len(data[-1][0]), 100 % 7)
        assert not cached_stream.cache[0]

        stream = Batch(DataStream(IterableDataset(range(3000))),
                       ConstantScheme(3200))

        cached_stream = Cache(stream, ConstantScheme(64))
        data = list(cached_stream.get_epoch_iterator())
        assert_equal(len(data[-1][0]), 3000 % 64)
        assert not cached_stream.cache[0]
    def setUp(self):
        self.streams = (
            DataStream(IterableDataset(['Hello world!'])),
            DataStream(IterableDataset(['Bonjour le monde!'])))
        self.batch_streams = (
            Batch(DataStream(IterableDataset(['Hello world!', 'Hi!'])),
                  iteration_scheme=ConstantScheme(2)),
            Batch(DataStream(IterableDataset(['Bonjour le monde!', 'Salut!'])),
                  iteration_scheme=ConstantScheme(2)))
        self.transformer = Merge(
            self.streams, ('english', 'french'))
        self.batch_transformer = Merge(
            self.batch_streams, ('english', 'french'))
Example #17
    def get_stream(self, part, batches=True, shuffle=True, add_sources=(),
                   num_examples=None, rng=None, seed=None):

        dataset = self.get_dataset(part, add_sources=add_sources)
        if num_examples is None:
            num_examples = dataset.num_examples

        if shuffle:
            iteration_scheme = ShuffledExampleScheme(num_examples, rng=rng)
        else:
            iteration_scheme = SequentialExampleScheme(num_examples)

        stream = DataStream(
            dataset, iteration_scheme=iteration_scheme)

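        # Keep only the recordings and labels, plus any extra requested sources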
        stream = FilterSources(stream, (self.recordings_source,
                                        self.labels_source)+tuple(add_sources))
        if self.add_eos:
            stream = Mapping(stream, _AddLabel(self.eos_label))
        if self.add_bos:
            stream = Mapping(stream, _AddLabel(self.bos_label, append=False,
                                               times=self.add_bos))
        if self.preprocess_text:
            stream = Mapping(stream, lvsr.datasets.wsj.preprocess_text)
        stream = Filter(stream, self.length_filter)
        if self.sort_k_batches and batches:
            stream = Batch(stream,
                           iteration_scheme=ConstantScheme(
                               self.batch_size * self.sort_k_batches))
            stream = Mapping(stream, SortMapping(_length))
            stream = Unpack(stream)

        if self.preprocess_features == 'log_spectrogram':
            stream = Mapping(
                stream, functools.partial(apply_preprocessing,
                                          log_spectrogram))
        if self.normalization:
            stream = self.normalization.wrap_stream(stream)
        stream = ForceFloatX(stream)
        if not batches:
            return stream

        stream = Batch(
            stream,
            iteration_scheme=ConstantScheme(self.batch_size if part == 'train'
                                            else self.validation_batch_size))
        stream = Padding(stream)
        stream = Mapping(stream, switch_first_two_axes)
        stream = ForceCContiguous(stream)
        return stream
Example #18
def get_sgnmt_shuffled_tr_stream(src_data,
                                 trg_data,
                                 src_vocab_size=30000,
                                 trg_vocab_size=30000,
                                 unk_id=1,
                                 seq_len=50,
                                 batch_size=80,
                                 sort_k_batches=12,
                                 **kwargs):
    """Prepares the shuffled training data stream. This is similar to 
    ``get_sgnmt_tr_stream`` but uses ``ParallelTextFile`` in combination
    with ``ShuffledExampleScheme`` to support reshuffling."""

    # Build dummy vocabulary to make TextFile happy
    src_vocab = add_special_ids({str(i): i for i in xrange(src_vocab_size)})
    trg_vocab = add_special_ids({str(i): i for i in xrange(trg_vocab_size)})

    parallel_dataset = ParallelTextFile(src_data, trg_data, src_vocab,
                                        trg_vocab, None)
    #iter_scheme = SequentialExampleScheme(parallel_dataset.num_examples)
    iter_scheme = ShuffledExampleScheme(parallel_dataset.num_examples)
    s = DataStream(parallel_dataset, iteration_scheme=iter_scheme)

    # Filter sequences that are too long
    s = Filter(s, predicate=stream._too_long(seq_len=seq_len))

    # Replace out of vocabulary tokens with unk token
    s = Mapping(
        s,
        stream._oov_to_unk(src_vocab_size=src_vocab_size,
                           trg_vocab_size=trg_vocab_size,
                           unk_id=utils.UNK_ID))

    # Build a batched version of stream to read k batches ahead
    s = Batch(s, iteration_scheme=ConstantScheme(batch_size * sort_k_batches))

    # Sort all samples in the read-ahead batch
    s = Mapping(s, SortMapping(stream._length))

    # Convert it into a stream again
    s = Unpack(s)

    # Construct batches from the stream with specified batch size
    s = Batch(s, iteration_scheme=ConstantScheme(batch_size))

    # Pad sequences that are short
    masked_stream = stream.PaddingWithEOS(s, [utils.EOS_ID, utils.EOS_ID])

    return masked_stream
Example #19
def create_data_generator(path, vocab_file, config):
    ds = QADataset(path,
                   vocab_file,
                   config.n_entities,
                   need_sep_token=config.concat_ctx_and_question)
    it = QAIterator(path, shuffle=config.shuffle_questions)

    stream = DataStream(ds, iteration_scheme=it)

    if config.concat_ctx_and_question:
        stream = ConcatCtxAndQuestion(stream, config.concat_question_before,
                                      ds.reverse_vocab['<SEP>'])

    # Sort sets of multiple batches to make batches of similar sizes
    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(config.batch_size *
                                                   config.sort_batch_count))
    comparison = _balanced_batch_helper(
        stream.sources.index(
            'question' if config.concat_ctx_and_question else 'context'))
    stream = Mapping(stream, SortMapping(comparison))
    stream = Unpack(stream)

    stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size))
    stream = Padding(stream,
                     mask_sources=['context', 'question', 'candidates'],
                     mask_dtype='int32')

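    # Expose the stream as a plain Python generator, casting masks to float32.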
    def gen():

        if not config.concat_ctx_and_question:
            for (seq_cont, seq_cont_mask, seq_quest, seq_quest_mask, tg,
                 candidates, candidates_mask) in stream.get_epoch_iterator():
                seq_cont_mask = seq_cont_mask.astype('float32')
                seq_quest_mask = seq_quest_mask.astype('float32')
                candidates_mask = candidates_mask.astype('float32')

                yield (seq_cont, seq_cont_mask, seq_quest, seq_quest_mask, tg,
                       candidates, candidates_mask)
        else:

            for (seq, seq_mask, tg, candidates, candidates_mask) \
                    in stream.get_epoch_iterator():
                seq_mask = seq_mask.astype('float32')
                candidates_mask = candidates_mask.astype('float32')

                yield (seq, seq_mask, tg, candidates, candidates_mask)

    return gen
Example #20
def setup_sorter_datastream(path, config):
    ds = SorterDataset(path)
    it = ShuffledExampleScheme(examples=config.example_count)
    stream = DataStream(ds, iteration_scheme=it)
    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(config.batch_size *
                                                   config.sort_batch_count))
    comparison = _balanced_batch_helper(stream.sources.index('unsorted'))
    stream = Mapping(stream, SortMapping(comparison))
    stream = Unpack(stream)
    stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size))
    stream = Padding(stream,
                     mask_sources=['answer', 'unsorted'],
                     mask_dtype='int32')
    return ds, stream
    def test_adds_batch_to_axis_labels(self):
        stream = DataStream(
            IterableDataset(
                {'features': [1, 2, 3, 4, 5]},
                axis_labels={'features': ('index',)}))
        transformer = Batch(stream, ConstantScheme(2), strictness=0)
        assert_equal(transformer.axis_labels, {'features': ('batch', 'index')})
Example #22
    def test_value_error_on_request(self):
        transformer = Padding(Batch(
            DataStream(
                IterableDataset(
                    dict(features=[[1], [2, 3]], targets=[[4, 5, 6], [7]]))),
            ConstantScheme(2)))
        assert_raises(ValueError, transformer.get_data, [0, 1])
Example #23
    def test_two_sources(self):
        transformer = Padding(Batch(
            DataStream(
                IterableDataset(
                    dict(features=[[1], [2, 3]], targets=[[4, 5, 6], [7]]))),
            ConstantScheme(2)))
        assert len(next(transformer.get_epoch_iterator())) == 4
Example #24
    def get_stream(self, part, batch_size, seed=None, raw_text=False):
        d = self.get_dataset(part)
        print("Dataset with {} examples".format(d.num_examples))
        it = ShuffledExampleScheme(d.num_examples,
                                   rng=numpy.random.RandomState(seed))
        stream = DataStream(d, iteration_scheme=it)
        stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size))

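        # Optionally retrieve definitions ('defs') and attach them as sources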
        if self._retrieval:
            stream = FixedMapping(
                stream,
                functools.partial(retrieve_and_pad_snli, self._retrieval),
                add_sources=("defs", "def_mask", "sentence1_def_map",
                             "sentence2_def_map")
            )  # This is because there is bug in Fuel :( Cannot concatenate tuple and list

        if not raw_text:
            stream = SourcewiseMapping(stream,
                                       functools.partial(digitize, self.vocab),
                                       which_sources=('sentence1',
                                                      'sentence2'))

        stream = Padding(
            stream,
            mask_sources=('sentence1',
                          'sentence2'))  # Increases amount of outputs by x2

        return stream
Example #25
    def train(self, req_vars):
        prefix_stream = DataStream(self.train_dataset,
                                   iteration_scheme=ShuffledExampleScheme(
                                       self.train_dataset.num_examples))

        if not data.tvt:
            prefix_stream = transformers.TaxiExcludeTrips(
                prefix_stream, self.valid_trips_ids)
        prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)
        prefix_stream = transformers.TaxiGenerateSplits(
            prefix_stream, max_splits=self.config.max_splits)
        prefix_stream = transformers.taxi_add_datetime(prefix_stream)
        prefix_stream = transformers.taxi_add_first_last_len(
            prefix_stream, self.config.n_begin_end_pts)
        prefix_stream = Batch(prefix_stream,
                              iteration_scheme=ConstantScheme(
                                  self.config.batch_size))

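        # Merge each prefix batch with a batch of candidate trips; candidate
        # sources are renamed with a 'candidate_' prefix before merging.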
        candidate_stream = self.candidate_stream(
            self.config.train_candidate_size)

        sources = prefix_stream.sources + tuple(
            'candidate_%s' % k for k in candidate_stream.sources)
        stream = Merge((prefix_stream, candidate_stream), sources)
        stream = transformers.Select(stream, tuple(req_vars))
        stream = MultiProcessing(stream)
        return stream
Example #26
    def test(self, req_vars):
        prefix_stream = DataStream(self.test_dataset,
                                   iteration_scheme=SequentialExampleScheme(
                                       self.test_dataset.num_examples))
        prefix_stream = transformers.taxi_add_datetime(prefix_stream)
        prefix_stream = transformers.taxi_add_first_last_len(
            prefix_stream, self.config.n_begin_end_pts)

        if not data.tvt:
            prefix_stream = transformers.taxi_remove_test_only_clients(
                prefix_stream)

        prefix_stream = Batch(prefix_stream,
                              iteration_scheme=ConstantScheme(
                                  self.config.batch_size))

        candidate_stream = self.candidate_stream(
            self.config.test_candidate_size)

        sources = prefix_stream.sources + tuple(
            'candidate_%s' % k for k in candidate_stream.sources)
        stream = Merge((prefix_stream, candidate_stream), sources)
        stream = transformers.Select(stream, tuple(req_vars))
        stream = MultiProcessing(stream)
        return stream
Example #27
def get_data_stream(iterable):
    dataset = IterableDataset({'numbers': iterable})
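    # Add a 'roots' source computed from 'numbers', then batch 20 at a time.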
    data_stream = Mapping(dataset.get_example_stream(),
                          _data_sqrt,
                          add_sources=('roots', ))
    data_stream = Mapping(data_stream, _array_tuple)
    return Batch(data_stream, ConstantScheme(20))
Example #28
    def candidate_stream(self, n_candidates, sortmap=True):
        candidate_stream = DataStream(self.train_dataset,
                                      iteration_scheme=ShuffledExampleScheme(
                                          self.train_dataset.num_examples))
        if not data.tvt:
            candidate_stream = transformers.TaxiExcludeTrips(
                candidate_stream, self.valid_trips_ids)
        candidate_stream = transformers.TaxiExcludeEmptyTrips(candidate_stream)
        candidate_stream = transformers.taxi_add_datetime(candidate_stream)

        if not data.tvt:
            candidate_stream = transformers.add_destination(candidate_stream)

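        # Sorted "balanced" batches of candidates need less padding than
        # plain fixed-size batches.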
        if sortmap:
            candidate_stream = transformers.balanced_batch(
                candidate_stream,
                key='latitude',
                batch_size=n_candidates,
                batch_sort_size=self.config.batch_sort_size)
        else:
            candidate_stream = Batch(
                candidate_stream,
                iteration_scheme=ConstantScheme(n_candidates))

        candidate_stream = Padding(candidate_stream,
                                   mask_sources=['latitude', 'longitude'])

        return candidate_stream
Example #29
    def test(self, req_vars):
        prefix_stream = DataStream(self.test_dataset,
                                   iteration_scheme=SequentialExampleScheme(
                                       self.test_dataset.num_examples))

        prefix_stream = transformers.taxi_add_datetime(prefix_stream)
        if not data.tvt:
            prefix_stream = transformers.taxi_remove_test_only_clients(
                prefix_stream)

        prefix_stream = Batch(prefix_stream,
                              iteration_scheme=ConstantScheme(
                                  self.config.batch_size))
        prefix_stream = Padding(prefix_stream,
                                mask_sources=['latitude', 'longitude'])

        candidate_stream = self.candidate_stream(
            self.config.test_candidate_size, False)

        sources = prefix_stream.sources + tuple(
            'candidate_%s' % k for k in candidate_stream.sources)
        stream = Merge((prefix_stream, candidate_stream), sources)

        stream = transformers.Select(stream, tuple(req_vars))
        # stream = MultiProcessing(stream)

        return stream
Example #30
def _get_align_stream(src_data, trg_data, src_vocab_size, trg_vocab_size,
                      seq_len, **kwargs):
    """Creates the stream which is used for the main loop.
    
    Args:
        src_data (string): Path to the source sentences
        trg_data (string): Path to the target sentences
        src_vocab_size (int): Size of the source vocabulary in the NMT
                              model
        trg_vocab_size (int): Size of the target vocabulary in the NMT
                              model
        seq_len (int): Maximum length of any source or target sentence
    
    Returns:
        ExplicitNext. Alignment data stream which can be iterated
        explicitly
    """
    # Build dummy vocabulary to make TextFile happy
    src_vocab = _add_special_ids({str(i): i for i in xrange(src_vocab_size)})
    trg_vocab = _add_special_ids({str(i): i for i in xrange(trg_vocab_size)})
    # Get text files from both source and target
    src_dataset = TextFile([src_data], src_vocab, None)
    trg_dataset = TextFile([trg_data], trg_vocab, None)
    # Merge them to get a source, target pair
    s = Merge(
        [src_dataset.get_example_stream(),
         trg_dataset.get_example_stream()], ('source', 'target'))
    s = Filter(s, predicate=stream._too_long(seq_len=seq_len))
    s = Batch(s, iteration_scheme=ConstantScheme(1))
    masked_stream = stream.PaddingWithEOS(s, [utils.EOS_ID, utils.EOS_ID])
    return ExplicitNext(masked_stream)