def get_one_stream(self, part, lang=None, batches=True, shuffle=True,
                   add_sources=(), num_examples=None, rng=None, seed=None,
                   num_result=None, soften_distributions=None,
                   only_stream=False):
    """Build the data stream for a single language of split `part`.

    The stream is read from ``self.get_dataset(part, lang, ...)`` and then
    passed through the optional stages configured on ``self``: distribution
    softening, binary-to-mask conversion, EOS/BOS label insertion, length
    filtering, sort-within-k-batches and normalization. It is then forced
    to floatX and its sources are renamed through ``self.sources_map``.
    When ``batches`` is true it is finally grouped into batches of
    ``self.batch_size`` (for ``part == 'train'``) or
    ``self.validation_batch_size`` (any other part). No padding is applied
    to the batches here.

    Parameters
    ----------
    part : str
        Dataset split name; ``'train'`` selects ``self.batch_size``.
    lang :
        Language to load; must be one of ``self.langs``.
    batches : bool
        Whether to batch the final stream.
    shuffle : bool
        Use ``ShuffledExampleScheme`` (True) or ``SequentialExampleScheme``
        (False). Not consulted when ``RandomExampleScheme`` is selected
        (see ``only_stream``).
    add_sources : tuple
        Extra sources forwarded to ``get_dataset``.
    num_examples : int, optional
        Number of examples to iterate over; defaults to
        ``dataset.num_examples``.
    rng :
        Random state forwarded to the iteration scheme.
    seed :
        Unused; kept for signature compatibility.
    num_result : int, optional
        ``num_result`` passed to ``RandomExampleScheme``; defaults to
        ``num_examples``.
    soften_distributions :
        When truthy, the stream is mapped through ``SoftenResult`` with
        this value.
    only_stream : bool
        If false and ``lang`` is not ``self.langs[0]``, examples are drawn
        with ``RandomExampleScheme`` instead of the shuffled/sequential
        scheme.

    Returns
    -------
    tuple
        ``(stream, num_examples)``.

    Raises
    ------
    ValueError
        If ``lang`` is not in ``self.langs``, or if ``self.add_bos`` is set
        while ``self.bos_label`` is None.
    """
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if lang not in self.langs:
        raise ValueError('Unknown language {!r}; expected one of {!r}'
                         .format(lang, self.langs))
    # Fail before opening the dataset rather than halfway through building
    # the pipeline. ValueError subclasses Exception, so callers that caught
    # the previous bare Exception still catch this.
    if self.add_bos and self.bos_label is None:
        raise ValueError('No bos label given')

    dataset = self.get_dataset(part, lang, add_sources=add_sources)
    if num_examples is None:
        num_examples = dataset.num_examples

    # Pick exactly one iteration scheme (the original built a
    # shuffled/sequential scheme and then discarded it for non-primary
    # languages).
    if lang != self.langs[0] and not only_stream:
        if num_result is None:
            num_result = num_examples
        iteration_scheme = RandomExampleScheme(
            num_examples, num_result=num_result, rng=rng)
    elif shuffle:
        iteration_scheme = ShuffledExampleScheme(num_examples, rng=rng)
    else:
        iteration_scheme = SequentialExampleScheme(num_examples)

    stream = DataStream(dataset, iteration_scheme=iteration_scheme)

    if soften_distributions:
        stream = Mapping(stream, SoftenResult(self.default_sources,
                                              soften_distributions))

    # Only convert the binary-convertible sources actually present among
    # the default sources.
    for bconv in self._binary_convertable_data:
        if bconv in self.default_sources:
            stream = Mapping(stream, ConvertToMask(self.default_sources,
                                                   bconv,
                                                   self.num_features(bconv)))

    if self.add_eos:
        stream = Mapping(stream, _AddLabel(
            self.eos_label,
            index=stream.sources.index(self.sources_map['labels'])))
    if self.add_bos:
        stream = Mapping(stream, _AddLabel(
            self.bos_label, append=False, times=self.add_bos,
            index=stream.sources.index(self.sources_map['labels'])))

    if self.max_length:
        stream = Filter(stream, self.length_filter)

    if self.sort_k_batches and batches:
        # Gather sort_k_batches batches' worth of examples, sort that group
        # by the length of source 0, then unpack back into single examples
        # so the final Batch below groups length-sorted neighbours.
        stream = Batch(stream, iteration_scheme=ConstantScheme(
            self.batch_size * self.sort_k_batches))
        #
        # Hardcode 0 for source on which to sort. This will be good, as
        # most source lengths are correlated and, furthermore, the
        # labels will typically be the last source, thus in a single-input
        # case this sorts on input lengths
        #
        stream = Mapping(stream, SortMapping(_Length(index=0)))
        stream = Unpack(stream)

    if self.normalization:
        stream = self.normalization.wrap_stream(stream)
    stream = ForceFloatX(stream)
    # Map internal source names back to their public keys in sources_map.
    stream = Rename(stream,
                    names=dict_subset({v: k for (k, v)
                                       in self.sources_map.items()},
                                      stream.sources,
                                      must_have=False))
    if not batches:
        return stream, num_examples

    batch_size = (self.batch_size if part == 'train'
                  else self.validation_batch_size)
    stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size))
    stream._produces_examples = False
    return stream, num_examples
def get_stream(self, part, batches=True, shuffle=True, add_sources=(),
               num_examples=None, rng=None, seed=None):
    """Build the (optionally batched and padded) data stream for `part`.

    Examples are read from ``self.get_dataset(part, ...)`` and passed
    through the optional stages configured on ``self``: EOS/BOS label
    insertion, length filtering, sort-within-k-batches and normalization.
    The stream is then forced to floatX and its sources are renamed through
    ``self.sources_map``. When ``batches`` is true it is grouped into
    batches of ``self.batch_size`` (for ``part == 'train'``) or
    ``self.validation_batch_size`` (any other part). The batches are then
    padded, mapped through ``switch_first_two_axes`` and made C-contiguous.

    Parameters
    ----------
    part : str
        Dataset split name; ``'train'`` selects ``self.batch_size``.
    batches : bool
        Whether to batch, pad and reorder the final stream.
    shuffle : bool
        Use ``ShuffledExampleScheme`` (True) or ``SequentialExampleScheme``
        (False).
    add_sources : tuple
        Extra sources forwarded to ``get_dataset``.
    num_examples : int, optional
        Number of examples to iterate over; defaults to
        ``dataset.num_examples``.
    rng :
        Random state forwarded to ``ShuffledExampleScheme``.
    seed :
        Unused; kept for signature compatibility.

    Returns
    -------
    stream
        An example stream when ``batches`` is false, otherwise the batched
        and padded stream.

    Raises
    ------
    ValueError
        If ``self.add_bos`` is set while ``self.bos_label`` is None.
    """
    # Validate configuration before opening the dataset. ValueError
    # subclasses Exception, so callers catching the old bare Exception
    # still work.
    if self.add_bos and self.bos_label is None:
        raise ValueError('No bos label given')

    dataset = self.get_dataset(part, add_sources=add_sources)
    if num_examples is None:
        num_examples = dataset.num_examples

    if shuffle:
        iteration_scheme = ShuffledExampleScheme(num_examples, rng=rng)
    else:
        iteration_scheme = SequentialExampleScheme(num_examples)

    stream = DataStream(dataset, iteration_scheme=iteration_scheme)

    if self.add_eos:
        stream = Mapping(stream, _AddLabel(
            self.eos_label,
            index=stream.sources.index(self.sources_map['labels'])))
    if self.add_bos:
        stream = Mapping(stream, _AddLabel(
            self.bos_label, append=False, times=self.add_bos,
            index=stream.sources.index(self.sources_map['labels'])))

    if self.max_length:
        stream = Filter(stream, self.length_filter)

    if self.sort_k_batches and batches:
        # Gather sort_k_batches batches' worth of examples, sort that group
        # by the length of source 0, then unpack back into single examples
        # so the final Batch below groups length-sorted neighbours.
        stream = Batch(stream, iteration_scheme=ConstantScheme(
            self.batch_size * self.sort_k_batches))
        #
        # Hardcode 0 for source on which to sort. This will be good, as
        # most source lengths are correlated and, furthermore, the
        # labels will typically be the last source, thus in a single-input
        # case this sorts on input lengths
        #
        stream = Mapping(stream, SortMapping(_Length(index=0)))
        stream = Unpack(stream)

    if self.normalization:
        stream = self.normalization.wrap_stream(stream)
    stream = ForceFloatX(stream)
    # Map internal source names back to their public keys in sources_map.
    stream = Rename(stream,
                    names=dict_subset({v: k for (k, v)
                                       in self.sources_map.items()},
                                      stream.sources,
                                      must_have=False))
    if not batches:
        return stream

    batch_size = (self.batch_size if part == 'train'
                  else self.validation_batch_size)
    stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size))
    stream = Padding(stream)
    stream = Mapping(stream, switch_first_two_axes)
    stream = ForceCContiguous(stream)
    stream._produces_examples = False
    return stream