コード例 #1
0
ファイル: __init__.py プロジェクト: ck624/dependency-parser
    def get_one_stream(self, part, lang=None, batches=True, shuffle=True, add_sources=(),
                       num_examples=None, rng=None, seed=None, num_result=None,
                       soften_distributions=None, only_stream=False):
        """Build a Fuel data stream over a single language's data split.

        Args:
            part: Name of the data split passed to ``self.get_dataset``.
                ``'train'`` selects ``self.batch_size`` for the final
                batches; any other value selects ``self.validation_batch_size``.
            lang: Language to read; must be one of ``self.langs``. The
                default ``None`` is rejected unless ``None`` is in
                ``self.langs``.
            batches: If False, return the per-example stream (after
                renaming) without the final batching step.
            shuffle: Iterate with ``ShuffledExampleScheme`` instead of
                ``SequentialExampleScheme``. Ignored for a secondary
                language unless ``only_stream`` is set.
            add_sources: Extra sources forwarded to ``self.get_dataset``.
            num_examples: Number of examples the iteration scheme covers;
                defaults to ``dataset.num_examples``.
            rng: Random state passed to the shuffled / random schemes.
            seed: Accepted for interface compatibility; currently unused.
            num_result: ``num_result`` for ``RandomExampleScheme`` when a
                secondary language is streamed; defaults to ``num_examples``.
            soften_distributions: If truthy, a ``SoftenResult`` mapping is
                applied with this value.
            only_stream: If True, a secondary language (any ``lang`` other
                than ``self.langs[0]``) uses the same shuffled/sequential
                scheme as the primary language instead of
                ``RandomExampleScheme``.

        Returns:
            A ``(stream, num_examples)`` tuple. When ``batches`` is True the
            stream yields unpadded batches.

        Raises:
            ValueError: If ``lang`` is not in ``self.langs``, or if
                ``self.add_bos`` is set but ``self.bos_label`` is None.
        """
        # Explicit check instead of `assert`, which is stripped under -O.
        if lang not in self.langs:
            raise ValueError('Unknown language {!r}; expected one of {!r}'
                             .format(lang, self.langs))
        dataset = self.get_dataset(part, lang, add_sources=add_sources)
        if num_examples is None:
            num_examples = dataset.num_examples

        # Choose exactly one iteration scheme. A secondary language
        # (not self.langs[0]) goes through RandomExampleScheme unless
        # only_stream is requested.
        if lang != self.langs[0] and not only_stream:
            if num_result is None:
                num_result = num_examples
            iteration_scheme = RandomExampleScheme(
                num_examples, num_result=num_result, rng=rng)
        elif shuffle:
            iteration_scheme = ShuffledExampleScheme(num_examples, rng=rng)
        else:
            iteration_scheme = SequentialExampleScheme(num_examples)

        stream = DataStream(dataset, iteration_scheme=iteration_scheme)

        if soften_distributions:
            stream = Mapping(stream, SoftenResult(self.default_sources,
                                                  soften_distributions))

        # Convert every binary-convertible source present in the default
        # sources with ConvertToMask, sized by self.num_features(source).
        for bconv in self._binary_convertable_data:
            if bconv in self.default_sources:
                stream = Mapping(stream, ConvertToMask(self.default_sources,
                                                       bconv,
                                                       self.num_features(bconv)))

        # sources_map translates logical names ('labels') to the dataset's
        # source names; locate the label source by its position.
        if self.add_eos:
            stream = Mapping(stream, _AddLabel(
                self.eos_label,
                index=stream.sources.index(self.sources_map['labels'])))
        if self.add_bos:
            if self.bos_label is None:
                raise ValueError('No bos label given')
            stream = Mapping(stream, _AddLabel(
                self.bos_label, append=False, times=self.add_bos,
                index=stream.sources.index(self.sources_map['labels'])))

        if self.max_length:
            stream = Filter(stream, self.length_filter)

        if self.sort_k_batches and batches:
            # Gather sort_k_batches batches' worth of examples, sort them by
            # length, then unpack back into single examples.
            stream = Batch(stream,
                           iteration_scheme=ConstantScheme(
                               self.batch_size * self.sort_k_batches))
            #
            # Hardcode 0 for source on which to sort. This will be good, as
            # most source lengths are correlated and, furthermore, the
            # labels will typically be the last source, thus in a single-input
            # case this sorts on input lengths
            #
            stream = Mapping(stream, SortMapping(_Length(index=0)))
            stream = Unpack(stream)

        if self.normalization:
            stream = self.normalization.wrap_stream(stream)
        stream = ForceFloatX(stream)
        # Rename dataset source names back to their logical names using the
        # inverted sources_map; presumably dict_subset keeps only the sources
        # actually present in the stream (must_have=False).
        stream = Rename(stream,
                        names=dict_subset({v: k for (k, v)
                                           in self.sources_map.items()},
                                          stream.sources,
                                          must_have=False))
        if not batches:
            return stream, num_examples

        stream = Batch(
            stream,
            iteration_scheme=ConstantScheme(self.batch_size if part == 'train'
                                            else self.validation_batch_size))
        # Unlike get_stream, no Padding / axis switch is applied here.
        # NOTE(review): presumably the caller pads after combining the
        # per-language streams — confirm.
        # Explicitly flag the stream as batch-producing (private Fuel attribute).
        stream._produces_examples = False
        return stream, num_examples
コード例 #2
0
ファイル: __init__.py プロジェクト: ck624/dependency-parser
    def get_stream(self, part, batches=True, shuffle=True, add_sources=(),
                   num_examples=None, rng=None, seed=None):
        """Build a Fuel data stream over one data split.

        Args:
            part: Name of the data split passed to ``self.get_dataset``.
                ``'train'`` selects ``self.batch_size`` for the final
                batches; any other value selects ``self.validation_batch_size``.
            batches: If False, return the per-example stream (after renaming)
                without batching, padding or axis switching.
            shuffle: Iterate with ``ShuffledExampleScheme`` instead of
                ``SequentialExampleScheme``.
            add_sources: Extra sources forwarded to ``self.get_dataset``.
            num_examples: Number of examples the iteration scheme covers;
                defaults to ``dataset.num_examples``.
            rng: Random state passed to ``ShuffledExampleScheme``.
            seed: Accepted for interface compatibility; currently unused.

        Returns:
            The constructed stream. When ``batches`` is True it is batched,
            padded, passed through ``switch_first_two_axes`` and made
            C-contiguous.

        Raises:
            ValueError: If ``self.add_bos`` is set but ``self.bos_label``
                is None.
        """
        dataset = self.get_dataset(part, add_sources=add_sources)
        if num_examples is None:
            num_examples = dataset.num_examples

        if shuffle:
            iteration_scheme = ShuffledExampleScheme(num_examples, rng=rng)
        else:
            iteration_scheme = SequentialExampleScheme(num_examples)

        stream = DataStream(dataset, iteration_scheme=iteration_scheme)

        # sources_map translates logical names ('labels') to the dataset's
        # source names; locate the label source by its position.
        if self.add_eos:
            stream = Mapping(stream, _AddLabel(
                self.eos_label,
                index=stream.sources.index(self.sources_map['labels'])))
        if self.add_bos:
            if self.bos_label is None:
                raise ValueError('No bos label given')
            stream = Mapping(stream, _AddLabel(
                self.bos_label, append=False, times=self.add_bos,
                index=stream.sources.index(self.sources_map['labels'])))

        if self.max_length:
            stream = Filter(stream, self.length_filter)

        if self.sort_k_batches and batches:
            # Gather sort_k_batches batches' worth of examples, sort them by
            # length, then unpack back into single examples.
            stream = Batch(stream,
                           iteration_scheme=ConstantScheme(
                               self.batch_size * self.sort_k_batches))
            #
            # Hardcode 0 for source on which to sort. This will be good, as
            # most source lengths are correlated and, furthermore, the
            # labels will typically be the last source, thus in a single-input
            # case this sorts on input lengths
            #
            stream = Mapping(stream, SortMapping(_Length(index=0)))
            stream = Unpack(stream)

        if self.normalization:
            stream = self.normalization.wrap_stream(stream)
        stream = ForceFloatX(stream)
        # Rename dataset source names back to their logical names using the
        # inverted sources_map; presumably dict_subset keeps only the sources
        # actually present in the stream (must_have=False).
        stream = Rename(stream,
                        names=dict_subset({v: k for (k, v)
                                           in self.sources_map.items()},
                                          stream.sources,
                                          must_have=False))
        if not batches:
            return stream

        stream = Batch(
            stream,
            iteration_scheme=ConstantScheme(self.batch_size if part == 'train'
                                            else self.validation_batch_size))
        stream = Padding(stream)
        # NOTE(review): judging by the name, switch_first_two_axes makes the
        # arrays time-major (batch axis second) — confirm.
        stream = Mapping(stream, switch_first_two_axes)
        stream = ForceCContiguous(stream)
        # Explicitly flag the stream as batch-producing (private Fuel attribute).
        stream._produces_examples = False
        return stream