示例#1
0
def build_dataset(fieldset, prefix='', filter_pred=None, **kwargs):
    """Build a ``QEDataset`` from the fields and files described by ``fieldset``.

    Args:
        fieldset: Object exposing ``fields_and_files(prefix, **kwargs)``,
            which must return a ``(fields, files)`` pair.
        prefix: Forwarded unchanged to ``fieldset.fields_and_files``.
        filter_pred: Forwarded unchanged to ``QEDataset`` as ``filter_pred``
            (presumably a per-example filter predicate; defaults to None).
        **kwargs: Extra keyword arguments forwarded to
            ``fieldset.fields_and_files``.

    Returns:
        The ``QEDataset`` built from the examples that ``Corpus.from_files``
        reads for those fields and files.
    """
    field_spec, file_spec = fieldset.fields_and_files(prefix, **kwargs)
    corpus_examples = Corpus.from_files(fields=field_spec, files=file_spec)
    return QEDataset(
        examples=corpus_examples,
        fields=field_spec,
        filter_pred=filter_pred,
    )
示例#2
0
    def predict(self, examples, batch_size=1):
        """Create Predictions for a list of examples.

           Args:
             examples: A dict mapping field names to the
               list of raw examples (strings). Every field must hold
               the same number of examples.
             batch_size: Batch Size to use. Default 1.

           Returns:
             A dict mapping prediction levels
             (word, sentence ..) to the model predictions
             for each example. An empty ``defaultdict(list)`` is
             returned when ``examples`` is empty.

           Raises:
             Exception: If the predictor has no fields object, or if an
               example has an empty string as `source` or `target` field.
             KeyError: If the `source` or `target` field is missing or
               holds no examples.
             ValueError: If the fields in ``examples`` do not all hold
               the same number of examples.

           Example:
             >>> import kiwi
             >>> predictor = kiwi.load_model('tests/toy-data/models/nuqe.torch')
             >>> src = ['a b c', 'd e f g']
             >>> tgt = ['q w e r', 't y']
             >>> align = ['0-0 1-1 1-2', '1-1 3-0']
             >>> examples = {kiwi.constants.SOURCE: src,
             ...             kiwi.constants.TARGET: tgt,
             ...             kiwi.constants.ALIGNMENTS: align}
             >>> predictor.predict(examples)
             {'tags': [[0.4760947525501251,
                0.47569847106933594,
                0.4948718547821045,
                0.5305878520011902],
               [0.5105430483818054, 0.5252899527549744]]}
        """
        if not examples:
            return defaultdict(list)
        if self.fields is None:
            raise Exception('Missing fields object.')

        if not examples.get(const.SOURCE):
            raise KeyError('Missing required field "{}"'.format(const.SOURCE))
        if not examples.get(const.TARGET):
            raise KeyError('Missing required field "{}"'.format(const.TARGET))

        # Walk both fields lazily instead of concatenating them: this avoids
        # a throwaway list and does not require both fields to be the same
        # sequence type (list + tuple would raise TypeError).
        if not all(
                s.strip()
                for name in (const.SOURCE, const.TARGET)
                for s in examples[name]
        ):
            raise Exception('Empty String in {} or {} field found!'.format(
                const.SOURCE, const.TARGET))

        # zip() below silently truncates to the shortest field, which would
        # drop examples and misalign predictions with the caller's inputs.
        lengths = {name: len(values) for name, values in examples.items()}
        if len(set(lengths.values())) > 1:
            raise ValueError(
                'All fields must hold the same number of examples, '
                'got {}'.format(lengths))

        # `fields` and `examples.values()` both iterate the same dict, so the
        # (name, field) pairs line up with the zipped value tuples.
        fields = [(name, self.fields[name]) for name in examples]

        field_examples = [
            Example.fromlist(values, fields)
            for values in zip(*examples.values())
        ]

        dataset = QEDataset(field_examples, fields=fields)

        return self.run(dataset, batch_size)