def setup_datastream(path, batch_size, sort_batch_count, valid=False):
    """Build a shuffled, size-balanced, padded stream of raw inputs and phone outputs.

    Three ``.npy`` files are loaded from ``path``: the raw inputs, a phone table,
    and a per-sequence array whose first two entries give a row range into the
    phone table. Column 2 of each sequence's row range becomes that example's
    ``output``.

    Args:
        path: directory holding the ``*_x_raw.npy``, ``*_phn.npy`` and
            ``*_seq_to_phn.npy`` files.
        batch_size: number of examples per emitted batch.
        sort_batch_count: how many batches are read ahead and sorted together.
        valid: load the ``valid_*`` files when true, the ``train_*`` files otherwise.

    Returns:
        ``(dataset, stream)``: the ``IndexableDataset`` and the padded stream
        (``input`` and ``output`` both get masks).
    """
    def _load(valid_name, train_name):
        # Pick the split-specific file name and load it from ``path``.
        return numpy.load(os.path.join(path, valid_name if valid else train_name))

    inputs = _load('valid_x_raw.npy', 'train_x_raw.npy')
    phones = _load('valid_phn.npy', 'train_phn.npy')
    seq_to_phn = _load('valid_seq_to_phn.npy', 'train_seq_to_phn.npy')

    # Slice each sequence's rows out of the phone table and keep column 2.
    outputs = [phones[bounds[0]:bounds[1], 2] for bounds in seq_to_phn]
    num_examples = len(inputs)

    dataset = IndexableDataset({'input': inputs, 'output': outputs})
    stream = DataStream(dataset,
                        iteration_scheme=ShuffledExampleScheme(num_examples))

    # Sort sets of multiple batches to make batches of similar sizes.
    read_ahead = Batch(
        stream,
        iteration_scheme=ConstantScheme(batch_size * sort_batch_count))
    sort_key = _balanced_batch_helper(read_ahead.sources.index('input'))
    examples = Unpack(Mapping(read_ahead, SortMapping(sort_key)))

    batches = Batch(
        examples,
        iteration_scheme=ConstantScheme(batch_size, num_examples=num_examples))
    return dataset, Padding(batches, mask_sources=['input', 'output'])
def setup_squad_ranker_datastream(path, vocab_file, config,
                                  example_count=1836975):
    """Build a shuffled, size-balanced, padded stream over ``SQuADRankerDataset``.

    Args:
        path: dataset location passed to ``SQuADRankerDataset``.
        vocab_file: vocabulary file passed to ``SQuADRankerDataset``.
        config: object providing ``batch_size`` and ``sort_batch_count``.
        example_count: number of examples the shuffled scheme iterates over.

    Returns:
        ``(dataset, stream)`` where every listed source is padded with an
        ``int32`` mask.
    """
    dataset = SQuADRankerDataset(path, vocab_file)
    scheme = ShuffledExampleScheme(examples=example_count)
    stream = DataStream(dataset, iteration_scheme=scheme)

    # Sort sets of multiple batches to make batches of similar sizes
    read_ahead_size = config.batch_size * config.sort_batch_count
    stream = Batch(stream, iteration_scheme=ConstantScheme(read_ahead_size))
    question_index = stream.sources.index('question')
    stream = Unpack(
        Mapping(stream, SortMapping(_balanced_batch_helper(question_index))))

    stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size))
    padded_sources = ['question', 'answer', 'better', 'worse',
                      'b_left', 'b_right', 'w_left', 'w_right']
    stream = Padding(stream, mask_sources=padded_sources, mask_dtype='int32')
    return dataset, stream
def setup_datastream(path, vocab_file, config):
    """Build a size-balanced, padded stream of QA examples.

    Args:
        path: dataset location passed to ``QADataset`` and ``QAIterator``.
        vocab_file: vocabulary file passed to ``QADataset``.
        config: object providing ``n_entities``, ``concat_ctx_and_question``,
            ``concat_question_before``, ``shuffle_questions``, ``batch_size``
            and ``sort_batch_count``.

    Returns:
        ``(dataset, stream)``; ``context``, ``question`` and ``candidates`` are
        padded with ``int32`` masks.
    """
    # Source names are reported through logging at DEBUG level instead of
    # being printed unconditionally to stdout.
    import logging
    logger = logging.getLogger(__name__)

    ds = QADataset(path, vocab_file, config.n_entities,
                   need_sep_token=config.concat_ctx_and_question)
    it = QAIterator(path, shuffle=config.shuffle_questions)
    stream = DataStream(ds, iteration_scheme=it)

    if config.concat_ctx_and_question:
        stream = ConcatCtxAndQuestion(stream, config.concat_question_before,
                                      ds.reverse_vocab['<SEP>'])

    # Sort sets of multiple batches to make batches of similar sizes
    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(config.batch_size *
                                                   config.sort_batch_count))
    # Sort on 'question' when context and question were concatenated,
    # otherwise on 'context'.
    comparison = _balanced_batch_helper(
        stream.sources.index(
            'question' if config.concat_ctx_and_question else 'context'))
    stream = Mapping(stream, SortMapping(comparison))
    stream = Unpack(stream)
    logger.debug('sources: %s', stream.sources)

    stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size))
    stream = Padding(stream, mask_sources=['context', 'question', 'candidates'],
                     mask_dtype='int32')
    logger.debug('sources2: %s', stream.sources)
    return ds, stream
def train(self, req_vars):
    """Build the multiprocess training stream restricted to ``req_vars``.

    Training trips are iterated with ``TaxiTimeCutScheme`` when
    ``config.use_cuts_for_training`` is set and truthy, otherwise in shuffled
    order. Outside ``data.tvt`` mode, the validation set's trip ids are
    excluded. Trips are split, optionally shuffled in chunks of
    ``config.shuffle_batch_size``, enriched with datetime and first/last-point
    features, projected onto ``req_vars`` and batched.
    """
    dataset = TaxiDataset('train', data.traintest_ds)

    use_cuts = (hasattr(self.config, 'use_cuts_for_training')
                and self.config.use_cuts_for_training)
    if use_cuts:
        scheme = TaxiTimeCutScheme()
    else:
        scheme = ShuffledExampleScheme(dataset.num_examples)
    stream = DataStream(dataset, iteration_scheme=scheme)

    if not data.tvt:
        valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
        valid_trip_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
        # Exclude the validation set's trip ids from the training stream.
        stream = transformers.TaxiExcludeTrips(stream, valid_trip_ids)

    stream = transformers.TaxiGenerateSplits(stream,
                                             max_splits=self.config.max_splits)

    if hasattr(self.config, 'shuffle_batch_size'):
        # Read a chunk, reorder it by a random key, then flatten it back.
        chunk = ConstantScheme(self.config.shuffle_batch_size)
        stream = Unpack(Mapping(
            transformers.Batch(stream, iteration_scheme=chunk),
            SortMapping(key=UniformGenerator())))

    stream = transformers.taxi_add_datetime(stream)
    stream = transformers.taxi_add_first_last_len(stream,
                                                  self.config.n_begin_end_pts)
    stream = transformers.Select(stream, tuple(req_vars))
    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(self.config.batch_size))
    return MultiProcessing(stream)
def get_tr_stream(src_vocab, trg_vocab, src_data, trg_data,
                  src_vocab_size=30000, trg_vocab_size=30000,
                  unk_id=0, eos_id=1, bos_id=2, train_noise=0,
                  seq_len=50, batch_size=80, sort_k_batches=12, **kwargs):
    """Build the padded (source, target) training stream.

    The two sides come from ``get_stream``. The source side receives
    ``train_noise`` and the target side receives ``0`` in the same position.
    Pairs rejected by ``_not_too_long(seq_len)`` are dropped. The remaining pairs
    are read ``sort_k_batches`` batches ahead, sorted by ``_length``,
    re-batched to ``batch_size``, and padded with ``eos_id`` on both sides.
    Extra ``**kwargs`` are accepted and ignored.
    """
    source_side = get_stream(src_vocab, src_data, src_vocab_size,
                             unk_id, eos_id, bos_id, train_noise)
    target_side = get_stream(trg_vocab, trg_data, trg_vocab_size,
                             unk_id, eos_id, bos_id, 0)

    pairs = Merge([source_side, target_side], ('source', 'target'))
    pairs = Filter(pairs, predicate=_not_too_long(seq_len))

    # Read k batches ahead and sort them so each final batch is length-homogeneous.
    read_ahead = Batch(pairs,
                       iteration_scheme=ConstantScheme(batch_size * sort_k_batches))
    examples = Unpack(Mapping(read_ahead, SortMapping(_length)))

    batches = Batch(examples, iteration_scheme=ConstantScheme(batch_size))
    return PaddingWithEOS(batches, [eos_id, eos_id])
def setup_squad_datastream(path, vocab_file, config):
    """Build a size-balanced, padded stream over ``SQuADDataset``.

    Args:
        path: dataset location passed to ``SQuADDataset`` and ``SQuADIterator``.
        vocab_file: vocabulary file passed to ``SQuADDataset``.
        config: object providing ``concat_ctx_and_question``,
            ``concat_question_before``, ``batch_size`` and ``sort_batch_count``.

    Returns:
        ``(dataset, stream)`` with the listed sources padded using ``int32``
        masks.
    """
    dataset = SQuADDataset(path, vocab_file)
    stream = DataStream(dataset, iteration_scheme=SQuADIterator(path))

    if config.concat_ctx_and_question:
        stream = ConcatCtxAndQuestion(stream, config.concat_question_before,
                                      dataset.reverse_vocab['<DUMMY>'])

    # Sort sets of multiple batches to make batches of similar sizes
    read_ahead = Batch(stream,
                       iteration_scheme=ConstantScheme(config.batch_size *
                                                       config.sort_batch_count))
    # NOTE(review): always sorts on 'context', even after concatenation; the
    # QADataset pipeline switches to 'question' in that case — confirm
    # ConcatCtxAndQuestion keeps a 'context' source here.
    sort_key = _balanced_batch_helper(read_ahead.sources.index('context'))
    examples = Unpack(Mapping(read_ahead, SortMapping(sort_key)))

    batches = Batch(examples, iteration_scheme=ConstantScheme(config.batch_size))
    padded = Padding(batches,
                     mask_sources=['context', 'question', 'answer',
                                   'ans_indices', 'ans_boundaries'],
                     mask_dtype='int32')
    return dataset, padded
def test_cache():
    """Check ``Cache`` fill levels, end-of-epoch behaviour and epoch transitions.

    A 100-element stream is batched by 11 and then served through ``Cache`` in
    requests of 7, so the cache holds whatever is left of the 11-element
    batches after each request.
    """
    dataset = IterableDataset(range(100))
    stream = DataStream(dataset)
    batched_stream = Batch(stream, ConstantScheme(11))
    cached_stream = Cache(batched_stream, ConstantScheme(7))
    epoch = cached_stream.get_epoch_iterator()

    # Make sure that cache is filled as expected
    # (expected leftovers: 11-7=4, 4+11-7=8, 8-7=1, 1+11-7=5, ... — a new
    # batch of 11 is pulled whenever fewer than 7 remain).
    for (features, ), cache_size in zip(epoch, [4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 0, 4]):
        assert len(cached_stream.cache[0]) == cache_size

    # Make sure that the epoch finishes correctly
    for (features, ) in cached_stream.get_epoch_iterator():
        pass
    # ``features`` deliberately leaks out of the loop above: it holds the last
    # batch of the epoch, which must be the 100 % 7 remainder.
    assert len(features) == 100 % 7
    assert not cached_stream.cache[0]

    # Ensure that the epoch transition is correct
    # With ``times=3`` each epoch issues three requests of 7; two epochs are
    # run and each must restart from the beginning of the data.
    cached_stream = Cache(batched_stream, ConstantScheme(7, times=3))
    for _, epoch in zip(range(2), cached_stream.iterate_epochs()):
        cache_sizes = [4, 8, 1]
        for i, (features, ) in enumerate(epoch):
            assert len(cached_stream.cache[0]) == cache_sizes[i]
            assert len(features) == 7
            assert numpy.all(list(range(100))[i * 7:(i + 1) * 7] == features)
        # Exactly three batches (indices 0..2) were produced in this epoch.
        assert i == 2
def setUp(self):
    """Create two streams of 0..9 in batches of 2: from a ``range`` and from a numpy array."""
    def batched_pairs(values):
        # Wrap ``values`` in a DataStream and group it into batches of two.
        return Batch(DataStream(IterableDataset(values)),
                     iteration_scheme=ConstantScheme(2))

    self.stream = batched_pairs(range(10))
    self.stream_np = batched_pairs(numpy.arange(10))
def test_concatenated_scheme():
    """``ConcatenatedScheme`` yields every sub-scheme's requests in order."""
    sub_schemes = [ConstantScheme(batch_size=size, times=repeats)
                   for size, repeats in [(10, 5), (20, 3), (30, 1)]]
    expected = [10] * 5 + [20] * 3 + [30]
    concatenated = ConcatenatedScheme(schemes=sub_schemes)
    assert list(concatenated.get_request_iterator()) == expected
def get_train_stream(configuration, sfiles, tfiles, svocab_dict, tvocab_dict):
    """Build the padded (source, target) training stream from word-level text files.

    Args:
        configuration: mapping providing ``'seq_len'``, ``'batch_size'`` and
            ``'sort_k_batches'``.
        sfiles, tfiles: source / target text files for ``TextFile``.
        svocab_dict, tvocab_dict: source / target vocabularies.

    Returns:
        A ``Padding`` stream of length-sorted batches.
    """
    def word_dataset(files, vocab):
        # Word-level UTF-8 text with '<unk>' for OOV and no BOS/EOS tokens.
        return TextFile(files, vocab, bos_token=None, eos_token=None,
                        unk_token='<unk>', level='word', preprocess=None,
                        encoding='utf8')

    src_dataset = word_dataset(sfiles, svocab_dict)
    trg_dataset = word_dataset(tfiles, tvocab_dict)

    pairs = Merge([src_dataset.get_example_stream(),
                   trg_dataset.get_example_stream()],
                  ('source', 'target'))
    # TODO: revisit the length filter.
    pairs = Filter(pairs, predicate=_too_long(seq_len=configuration['seq_len']))

    batch_size = configuration['batch_size']
    read_ahead = Batch(pairs, iteration_scheme=ConstantScheme(
        batch_size * configuration['sort_k_batches']))
    examples = Unpack(Mapping(read_ahead, SortMapping(_length)))
    batches = Batch(examples, iteration_scheme=ConstantScheme(batch_size))

    # Note that </s>=0. Fuel only allows padding 0 by default
    return Padding(batches)
def balanced_batch(stream, key, batch_size, batch_sort_size):
    """Re-batch ``stream`` so that each batch holds examples of similar size.

    ``batch_size * batch_sort_size`` examples are read at a time and sorted with
    ``_balanced_batch_helper`` on source ``key``. They are then flattened back
    into single examples and grouped into batches of ``batch_size``.
    """
    read_ahead = Batch(stream,
                       iteration_scheme=ConstantScheme(batch_size * batch_sort_size))
    sort_key = _balanced_batch_helper(read_ahead.sources.index(key))
    examples = Unpack(Mapping(read_ahead, SortMapping(sort_key)))
    return Batch(examples, iteration_scheme=ConstantScheme(batch_size))
def test_concatenated_scheme_infers_request_type():
    """Batch sub-schemes give a batch scheme; example sub-schemes give an example scheme."""
    batch_schemes = [ConstantScheme(batch_size=10, times=5) for _ in range(2)]
    assert not ConcatenatedScheme(schemes=batch_schemes).requests_examples

    example_schemes = [SequentialExampleScheme(examples=10) for _ in range(2)]
    assert ConcatenatedScheme(schemes=example_schemes).requests_examples
def get_tr_stream(src_vocab, trg_vocab, src_data, trg_data,
                  src_vocab_size=30000, trg_vocab_size=30000, unk_id=1,
                  seq_len=50, batch_size=80, sort_k_batches=12, **kwargs):
    """Prepares the training data stream.

    Args:
        src_vocab, trg_vocab: a vocabulary ``dict`` or a path to a pickled one.
        src_data, trg_data: source / target text files.
        src_vocab_size, trg_vocab_size: vocabulary sizes; index ``size - 1``
            is used as EOS and as the padding value.
        unk_id: index used for UNK and for out-of-vocabulary replacement.
        seq_len: sequences rejected by ``_too_long(seq_len)`` are dropped.
        batch_size: examples per emitted batch.
        sort_k_batches: number of batches read ahead and sorted by length.
        **kwargs: accepted and ignored.

    Returns:
        A ``PaddingWithEOS`` stream of (source, target) batches.
    """
    def _load_vocab(vocab):
        # Accept an in-memory dict as-is; otherwise unpickle it from a path.
        # NOTE: unpickling can execute arbitrary code -- only load trusted files.
        if isinstance(vocab, dict):
            return vocab
        # Close the handle deterministically instead of leaking it.
        with open(vocab, 'rb') as vocab_file:
            return cPickle.load(vocab_file)

    # Load dictionaries and ensure special tokens exist
    src_vocab = _ensure_special_tokens(
        _load_vocab(src_vocab),
        bos_idx=0, eos_idx=src_vocab_size - 1, unk_idx=unk_id)
    trg_vocab = _ensure_special_tokens(
        _load_vocab(trg_vocab),
        bos_idx=0, eos_idx=trg_vocab_size - 1, unk_idx=unk_id)

    # Get text files from both source and target
    src_dataset = TextFile([src_data], src_vocab, None)
    trg_dataset = TextFile([trg_data], trg_vocab, None)

    # Merge them to get a source, target pair
    stream = Merge([src_dataset.get_example_stream(),
                    trg_dataset.get_example_stream()],
                   ('source', 'target'))

    # Filter sequences that are too long
    stream = Filter(stream, predicate=_too_long(seq_len=seq_len))

    # Replace out of vocabulary tokens with unk token
    stream = Mapping(stream,
                     _oov_to_unk(src_vocab_size=src_vocab_size,
                                 trg_vocab_size=trg_vocab_size,
                                 unk_id=unk_id))

    # Build a batched version of stream to read k batches ahead
    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(batch_size * sort_k_batches))

    # Sort all samples in the read-ahead batch
    stream = Mapping(stream, SortMapping(_length))

    # Convert it into a stream again
    stream = Unpack(stream)

    # Construct batches from the stream with specified batch size
    stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size))

    # Pad sequences that are short, using each side's EOS index
    masked_stream = PaddingWithEOS(
        stream, [src_vocab_size - 1, trg_vocab_size - 1])

    return masked_stream
def get_sgnmt_tr_stream(src_data, trg_data,
                        src_vocab_size=30000, trg_vocab_size=30000,
                        unk_id=1, seq_len=50, batch_size=80,
                        sort_k_batches=12, **kwargs):
    """Prepares the unshuffled training data stream.

    This corresponds to ``get_sgnmt_tr_stream`` in
    ``machine_translation/stream`` in the blocks examples.

    Note: the ``unk_id`` argument is accepted but not used. OOV tokens are
    mapped to ``utils.UNK_ID``, and padding uses ``utils.EOS_ID`` on both
    sides. Extra ``**kwargs`` are ignored.
    """
    # Build dummy vocabulary to make TextFile happy
    src_vocab = add_special_ids(dict((str(i), i) for i in xrange(src_vocab_size)))
    trg_vocab = add_special_ids(dict((str(i), i) for i in xrange(trg_vocab_size)))

    src_dataset = TextFile([src_data], src_vocab, None)
    trg_dataset = TextFile([trg_data], trg_vocab, None)

    # ``stream`` is the helper module here, so the pipeline uses another name.
    pipeline = Merge([src_dataset.get_example_stream(),
                      trg_dataset.get_example_stream()],
                     ('source', 'target'))
    pipeline = Filter(pipeline, predicate=stream._too_long(seq_len=seq_len))
    pipeline = Mapping(pipeline,
                       stream._oov_to_unk(src_vocab_size=src_vocab_size,
                                          trg_vocab_size=trg_vocab_size,
                                          unk_id=utils.UNK_ID))

    # Read k batches ahead, sort them by length, and flatten back to examples.
    pipeline = Batch(pipeline,
                     iteration_scheme=ConstantScheme(batch_size * sort_k_batches))
    pipeline = Unpack(Mapping(pipeline, SortMapping(stream._length)))

    pipeline = Batch(pipeline, iteration_scheme=ConstantScheme(batch_size))
    return stream.PaddingWithEOS(pipeline, [utils.EOS_ID, utils.EOS_ID])
def test_epoch_finishes_correctly(self):
    """The last request of an epoch yields the remainder and leaves the cache empty."""
    def check_remainder(source, request_size, total):
        # Drain one epoch through a Cache and inspect its final batch.
        cached = Cache(source, ConstantScheme(request_size))
        batches = list(cached.get_epoch_iterator())
        assert_equal(len(batches[-1][0]), total % request_size)
        assert not cached.cache[0]

    check_remainder(self.stream, 7, 100)
    # Upstream batch (3200) larger than the whole dataset (3000).
    check_remainder(Batch(DataStream(IterableDataset(range(3000))),
                          ConstantScheme(3200)),
                    64, 3000)
def setUp(self):
    """Build English/French single-example and batched streams plus their Merges."""
    def single(text):
        # One-example stream.
        return DataStream(IterableDataset([text]))

    def pair_batch(texts):
        # Stream of ``texts`` grouped into batches of two.
        return Batch(DataStream(IterableDataset(texts)),
                     iteration_scheme=ConstantScheme(2))

    self.streams = (single('Hello world!'), single('Bonjour le monde!'))
    self.batch_streams = (pair_batch(['Hello world!', 'Hi!']),
                          pair_batch(['Bonjour le monde!', 'Salut!']))

    source_names = ('english', 'french')
    self.transformer = Merge(self.streams, source_names)
    self.batch_transformer = Merge(self.batch_streams, source_names)
def get_stream(self, part, batches=True, shuffle=True, add_sources=(),
               num_examples=None, rng=None, seed=None):
    """Build the data stream for dataset split ``part``.

    Args:
        part: split name. ``'train'`` batches with ``self.batch_size``; any
            other value uses ``self.validation_batch_size``.
        batches: when false, return the per-example stream right after
            ``ForceFloatX`` (no read-ahead sort, batching or padding).
        shuffle: iterate with ``ShuffledExampleScheme`` (using ``rng``)
            instead of ``SequentialExampleScheme``.
        add_sources: extra sources kept alongside recordings and labels.
        num_examples: examples to iterate; defaults to the dataset's count.
        rng: random generator forwarded to ``ShuffledExampleScheme``.
        seed: accepted but not used by this method.
    """
    dataset = self.get_dataset(part, add_sources=add_sources)
    if num_examples is None:
        num_examples = dataset.num_examples

    if shuffle:
        scheme = ShuffledExampleScheme(num_examples, rng=rng)
    else:
        scheme = SequentialExampleScheme(num_examples)
    stream = DataStream(dataset, iteration_scheme=scheme)

    kept_sources = ((self.recordings_source, self.labels_source)
                    + tuple(add_sources))
    stream = FilterSources(stream, kept_sources)

    # Optional label markers: EOS appended, BOS added with append=False
    # and repeated ``self.add_bos`` times.
    if self.add_eos:
        stream = Mapping(stream, _AddLabel(self.eos_label))
    if self.add_bos:
        stream = Mapping(stream, _AddLabel(self.bos_label, append=False,
                                           times=self.add_bos))
    if self.preprocess_text:
        stream = Mapping(stream, lvsr.datasets.wsj.preprocess_text)
    stream = Filter(stream, self.length_filter)

    # Read-ahead length sorting only applies when producing batches.
    if self.sort_k_batches and batches:
        read_ahead = ConstantScheme(self.batch_size * self.sort_k_batches)
        stream = Unpack(Mapping(Batch(stream, iteration_scheme=read_ahead),
                                SortMapping(_length)))

    if self.preprocess_features == 'log_spectrogram':
        stream = Mapping(stream,
                         functools.partial(apply_preprocessing, log_spectrogram))
    if self.normalization:
        stream = self.normalization.wrap_stream(stream)
    stream = ForceFloatX(stream)

    if not batches:
        return stream

    final_batch_size = (self.batch_size if part == 'train'
                        else self.validation_batch_size)
    stream = Batch(stream, iteration_scheme=ConstantScheme(final_batch_size))
    stream = Padding(stream)
    stream = Mapping(stream, switch_first_two_axes)
    return ForceCContiguous(stream)
def get_sgnmt_shuffled_tr_stream(src_data, trg_data,
                                 src_vocab_size=30000, trg_vocab_size=30000,
                                 unk_id=1, seq_len=50, batch_size=80,
                                 sort_k_batches=12, **kwargs):
    """Prepares the shuffled training data stream.

    This is similar to ``get_sgnmt_tr_stream`` but uses ``ParallelTextFile``
    in combination with ``ShuffledExampleScheme`` to support reshuffling.

    Note: the ``unk_id`` argument is accepted but not used. OOV tokens are
    mapped to ``utils.UNK_ID``, and padding uses ``utils.EOS_ID`` on both
    sides. Extra ``**kwargs`` are ignored.
    """
    # Build dummy vocabulary to make TextFile happy
    src_vocab = add_special_ids(dict((str(i), i) for i in xrange(src_vocab_size)))
    trg_vocab = add_special_ids(dict((str(i), i) for i in xrange(trg_vocab_size)))

    parallel_dataset = ParallelTextFile(src_data, trg_data,
                                        src_vocab, trg_vocab, None)
    # Shuffled example order (a SequentialExampleScheme would give a fixed order).
    scheme = ShuffledExampleScheme(parallel_dataset.num_examples)

    # ``stream`` is the helper module here, so the pipeline uses another name.
    pipeline = DataStream(parallel_dataset, iteration_scheme=scheme)
    pipeline = Filter(pipeline, predicate=stream._too_long(seq_len=seq_len))
    pipeline = Mapping(pipeline,
                       stream._oov_to_unk(src_vocab_size=src_vocab_size,
                                          trg_vocab_size=trg_vocab_size,
                                          unk_id=utils.UNK_ID))

    # Read k batches ahead, sort them by length, and flatten back to examples.
    pipeline = Batch(pipeline,
                     iteration_scheme=ConstantScheme(batch_size * sort_k_batches))
    pipeline = Unpack(Mapping(pipeline, SortMapping(stream._length)))

    pipeline = Batch(pipeline, iteration_scheme=ConstantScheme(batch_size))
    return stream.PaddingWithEOS(pipeline, [utils.EOS_ID, utils.EOS_ID])
def create_data_generator(path, vocab_file, config):
    """Return a zero-argument generator function over padded QA batches.

    Each call of the returned function runs one epoch of the stream. Mask arrays
    are cast to ``'float32'``. The tuple layout depends on
    ``config.concat_ctx_and_question``, which is read when the generator runs:

    * false: ``(seq_cont, seq_cont_mask, seq_quest, seq_quest_mask, tg,
      candidates, candidates_mask)``
    * true: ``(seq, seq_mask, tg, candidates, candidates_mask)``
    """
    dataset = QADataset(path, vocab_file, config.n_entities,
                        need_sep_token=config.concat_ctx_and_question)
    scheme = QAIterator(path, shuffle=config.shuffle_questions)
    stream = DataStream(dataset, iteration_scheme=scheme)

    if config.concat_ctx_and_question:
        stream = ConcatCtxAndQuestion(stream, config.concat_question_before,
                                      dataset.reverse_vocab['<SEP>'])

    # Sort sets of multiple batches to make batches of similar sizes
    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(config.batch_size *
                                                   config.sort_batch_count))
    sort_source = 'question' if config.concat_ctx_and_question else 'context'
    sort_key = _balanced_batch_helper(stream.sources.index(sort_source))
    stream = Unpack(Mapping(stream, SortMapping(sort_key)))

    stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size))
    stream = Padding(stream, mask_sources=['context', 'question', 'candidates'],
                     mask_dtype='int32')

    def gen():
        if config.concat_ctx_and_question:
            for (seq, seq_mask, tg, candidates,
                 candidates_mask) in stream.get_epoch_iterator():
                yield (seq, seq_mask.astype('float32'), tg, candidates,
                       candidates_mask.astype('float32'))
        else:
            for (seq_cont, seq_cont_mask, seq_quest, seq_quest_mask, tg,
                 candidates, candidates_mask) in stream.get_epoch_iterator():
                yield (seq_cont, seq_cont_mask.astype('float32'),
                       seq_quest, seq_quest_mask.astype('float32'),
                       tg, candidates, candidates_mask.astype('float32'))

    return gen
def setup_sorter_datastream(path, config):
    """Build a shuffled, size-balanced, padded stream over ``SorterDataset``.

    Args:
        path: dataset location passed to ``SorterDataset``.
        config: object providing ``example_count``, ``batch_size`` and
            ``sort_batch_count``.

    Returns:
        ``(dataset, stream)``; ``answer`` and ``unsorted`` are padded with
        ``int32`` masks.
    """
    dataset = SorterDataset(path)
    scheme = ShuffledExampleScheme(examples=config.example_count)

    read_ahead = Batch(DataStream(dataset, iteration_scheme=scheme),
                       iteration_scheme=ConstantScheme(config.batch_size *
                                                       config.sort_batch_count))
    sort_key = _balanced_batch_helper(read_ahead.sources.index('unsorted'))
    examples = Unpack(Mapping(read_ahead, SortMapping(sort_key)))

    batches = Batch(examples, iteration_scheme=ConstantScheme(config.batch_size))
    padded = Padding(batches, mask_sources=['answer', 'unsorted'],
                     mask_dtype='int32')
    return dataset, padded
def test_adds_batch_to_axis_labels(self):
    """Batching prepends a ``'batch'`` axis label to each source's labels."""
    dataset = IterableDataset({'features': [1, 2, 3, 4, 5]},
                              axis_labels={'features': ('index',)})
    batched = Batch(DataStream(dataset), ConstantScheme(2), strictness=0)
    assert_equal(batched.axis_labels, {'features': ('batch', 'index')})
def test_value_error_on_request(self):
    """``Padding.get_data`` raises ``ValueError`` when given an explicit request."""
    ragged = dict(features=[[1], [2, 3]], targets=[[4, 5, 6], [7]])
    batched = Batch(DataStream(IterableDataset(ragged)), ConstantScheme(2))
    padding = Padding(batched)
    assert_raises(ValueError, padding.get_data, [0, 1])
def test_two_sources(self):
    """Padding two ragged sources yields four outputs per batch (presumably data + mask each)."""
    ragged = dict(features=[[1], [2, 3]], targets=[[4, 5, 6], [7]])
    batched = Batch(DataStream(IterableDataset(ragged)), ConstantScheme(2))
    first_batch = next(Padding(batched).get_epoch_iterator())
    assert len(first_batch) == 4
def get_stream(self, part, batch_size, seed=None, raw_text=False):
    """Build a shuffled, batched, padded stream for dataset split ``part``.

    Args:
        part: split name passed to ``self.get_dataset``.
        batch_size: examples per batch.
        seed: seed for the ``numpy.random.RandomState`` driving the shuffle.
        raw_text: when false, ``sentence1``/``sentence2`` are mapped through
            ``digitize`` with ``self.vocab``.

    Returns:
        A ``Padding`` stream with masks for ``sentence1`` and ``sentence2``.
        When ``self._retrieval`` is set, definition sources are added first.
    """
    dataset = self.get_dataset(part)
    print("Dataset with {} examples".format(dataset.num_examples))

    scheme = ShuffledExampleScheme(dataset.num_examples,
                                   rng=numpy.random.RandomState(seed))
    stream = Batch(DataStream(dataset, iteration_scheme=scheme),
                   iteration_scheme=ConstantScheme(batch_size))

    if self._retrieval:
        # add_sources is a tuple on purpose: this is because there is a bug in
        # Fuel that cannot concatenate a tuple and a list.
        stream = FixedMapping(
            stream,
            functools.partial(retrieve_and_pad_snli, self._retrieval),
            add_sources=("defs", "def_mask",
                         "sentence1_def_map", "sentence2_def_map"))

    if not raw_text:
        stream = SourcewiseMapping(stream,
                                   functools.partial(digitize, self.vocab),
                                   which_sources=('sentence1', 'sentence2'))

    # Adds a mask per padded source, increasing the number of outputs by x2
    # for those sources.
    return Padding(stream, mask_sources=('sentence1', 'sentence2'))
def train(self, req_vars):
    """Build the training stream: batched trip prefixes merged with candidate batches.

    Prefixes come from shuffled training trips. Outside ``data.tvt`` mode the
    validation trips are excluded, and empty trips are always excluded. Trips
    are then split and enriched with datetime and first/last-point features.
    Candidate sources are renamed ``candidate_<name>``. The merged stream is
    restricted to ``req_vars`` and wrapped in ``MultiProcessing``.
    """
    prefixes = DataStream(self.train_dataset,
                          iteration_scheme=ShuffledExampleScheme(
                              self.train_dataset.num_examples))
    if not data.tvt:
        prefixes = transformers.TaxiExcludeTrips(prefixes, self.valid_trips_ids)
    prefixes = transformers.TaxiExcludeEmptyTrips(prefixes)
    prefixes = transformers.TaxiGenerateSplits(prefixes,
                                               max_splits=self.config.max_splits)
    prefixes = transformers.taxi_add_datetime(prefixes)
    prefixes = transformers.taxi_add_first_last_len(prefixes,
                                                    self.config.n_begin_end_pts)
    prefixes = Batch(prefixes,
                     iteration_scheme=ConstantScheme(self.config.batch_size))

    candidates = self.candidate_stream(self.config.train_candidate_size)
    merged_names = prefixes.sources + tuple(
        'candidate_%s' % name for name in candidates.sources)
    merged = Merge((prefixes, candidates), merged_names)

    return MultiProcessing(transformers.Select(merged, tuple(req_vars)))
def test(self, req_vars):
    """Build the test stream: batched test prefixes merged with candidate batches.

    Test trips are visited sequentially and enriched with datetime and
    first/last-point features. Outside ``data.tvt`` mode,
    ``taxi_remove_test_only_clients`` is applied. Candidate sources are renamed
    ``candidate_<name>``. The merged stream is restricted to ``req_vars`` and
    wrapped in ``MultiProcessing``.
    """
    prefixes = DataStream(self.test_dataset,
                          iteration_scheme=SequentialExampleScheme(
                              self.test_dataset.num_examples))
    prefixes = transformers.taxi_add_datetime(prefixes)
    prefixes = transformers.taxi_add_first_last_len(prefixes,
                                                    self.config.n_begin_end_pts)
    if not data.tvt:
        prefixes = transformers.taxi_remove_test_only_clients(prefixes)
    prefixes = Batch(prefixes,
                     iteration_scheme=ConstantScheme(self.config.batch_size))

    candidates = self.candidate_stream(self.config.test_candidate_size)
    merged_names = prefixes.sources + tuple(
        'candidate_%s' % name for name in candidates.sources)
    merged = Merge((prefixes, candidates), merged_names)

    return MultiProcessing(transformers.Select(merged, tuple(req_vars)))
def get_data_stream(iterable):
    """Stream ``iterable`` as ``numbers`` plus a derived ``roots`` source, in batches of 20.

    ``_data_sqrt`` produces the added ``roots`` source. ``_array_tuple`` is then
    applied to every example before batching.
    """
    numbers = IterableDataset({'numbers': iterable})
    with_roots = Mapping(numbers.get_example_stream(), _data_sqrt,
                         add_sources=('roots', ))
    as_arrays = Mapping(with_roots, _array_tuple)
    return Batch(as_arrays, ConstantScheme(20))
def candidate_stream(self, n_candidates, sortmap=True):
    """Build a stream of padded batches of ``n_candidates`` shuffled training trips.

    Args:
        n_candidates: batch size of the candidate stream.
        sortmap: when true, batches are built with ``transformers.balanced_batch``
            keyed on ``'latitude'`` (read-ahead of ``config.batch_sort_size``
            batches). Otherwise a plain ``Batch`` is used.

    Returns:
        A ``Padding`` stream with masks for ``latitude`` and ``longitude``.
    """
    candidates = DataStream(self.train_dataset,
                            iteration_scheme=ShuffledExampleScheme(
                                self.train_dataset.num_examples))
    if not data.tvt:
        candidates = transformers.TaxiExcludeTrips(candidates,
                                                   self.valid_trips_ids)
    candidates = transformers.TaxiExcludeEmptyTrips(candidates)
    candidates = transformers.taxi_add_datetime(candidates)
    if not data.tvt:
        candidates = transformers.add_destination(candidates)

    if not sortmap:
        candidates = Batch(candidates,
                           iteration_scheme=ConstantScheme(n_candidates))
    else:
        candidates = transformers.balanced_batch(
            candidates, key='latitude', batch_size=n_candidates,
            batch_sort_size=self.config.batch_sort_size)

    return Padding(candidates, mask_sources=['latitude', 'longitude'])
def test(self, req_vars):
    """Build the test stream: padded test-prefix batches merged with candidate batches.

    Test trips are visited sequentially and given datetime features. Outside
    ``data.tvt`` mode, ``taxi_remove_test_only_clients`` is applied. Prefix
    batches are padded on ``latitude``/``longitude``. Candidates are built with
    ``sortmap=False`` and their sources renamed ``candidate_<name>``. The result
    is restricted to ``req_vars``.
    """
    prefixes = DataStream(self.test_dataset,
                          iteration_scheme=SequentialExampleScheme(
                              self.test_dataset.num_examples))
    prefixes = transformers.taxi_add_datetime(prefixes)
    if not data.tvt:
        prefixes = transformers.taxi_remove_test_only_clients(prefixes)
    prefixes = Batch(prefixes,
                     iteration_scheme=ConstantScheme(self.config.batch_size))
    prefixes = Padding(prefixes, mask_sources=['latitude', 'longitude'])

    candidates = self.candidate_stream(self.config.test_candidate_size, False)
    merged_names = prefixes.sources + tuple(
        'candidate_%s' % name for name in candidates.sources)
    merged = Merge((prefixes, candidates), merged_names)

    # NOTE(review): this stream is not wrapped in MultiProcessing; confirm
    # that is intended before enabling it.
    return transformers.Select(merged, tuple(req_vars))
def _get_align_stream(src_data, trg_data, src_vocab_size, trg_vocab_size,
                      seq_len, **kwargs):
    """Creates the stream which is used for the main loop.

    Args:
        src_data (string): Path to the source sentences
        trg_data (string): Path to the target sentences
        src_vocab_size (int): Size of the source vocabulary in the NMT model
        trg_vocab_size (int): Size of the target vocabulary in the NMT model
        seq_len (int): Maximum length of any source or target sentence

    Returns:
        ExplicitNext. Alignment data stream which can be iterated explicitly.
        Each batch holds a single (source, target) pair padded with
        ``utils.EOS_ID``. Extra ``**kwargs`` are ignored.
    """
    def dummy_vocab(size):
        # Build dummy vocabulary ("0" -> 0, "1" -> 1, ...) to make TextFile happy.
        return _add_special_ids(dict((str(i), i) for i in xrange(size)))

    src_vocab = dummy_vocab(src_vocab_size)
    trg_vocab = dummy_vocab(trg_vocab_size)

    src_dataset = TextFile([src_data], src_vocab, None)
    trg_dataset = TextFile([trg_data], trg_vocab, None)

    # ``stream`` is the helper module here, so the pipeline uses another name.
    pairs = Merge([src_dataset.get_example_stream(),
                   trg_dataset.get_example_stream()],
                  ('source', 'target'))
    pairs = Filter(pairs, predicate=stream._too_long(seq_len=seq_len))
    singletons = Batch(pairs, iteration_scheme=ConstantScheme(1))
    padded = stream.PaddingWithEOS(singletons, [utils.EOS_ID, utils.EOS_ID])
    return ExplicitNext(padded)