Example #1
0
    def predict(self, content):
        """Parse ``content`` and return one ``"arc/label"`` string per token.

        The input is first run through ``self.pos_tagger``; the resulting
        ``word/TAG`` tokens are split into a word column and a POS column
        (each prefixed with ``'<BOS>'``) and then fed to ``self.pipeline``.

        :param content: raw input, forwarded unchanged to
            ``self.pos_tagger.predict``.
        :return: the content of the ``'output'`` field, e.g.
            ``[['2/top', '0/root', '4/nn', '2/dep']]``.
        :raises ValueError: if this object has no ``pipeline`` attribute,
            i.e. the model has not been loaded yet.
        """
        if not hasattr(self, 'pipeline'):
            raise ValueError("You have to load model first.")

        # 1. Use the POS tagger to obtain segmentation and POS tagging results.
        pos_out = self.pos_tagger.predict(content)
        # e.g. pos_out = [['word1/NN', 'word2/VB', 'word3/NN', 'word4/NN']]

        # 2. Build the dataset.
        dataset = DataSet()
        dataset.add_field('wp', pos_out)
        # Split on the LAST '/' only: a token whose word itself contains a
        # slash (e.g. '//PU' or '1/2/CD') must keep its full word and its
        # real tag instead of being cut at the first slash.
        dataset.apply(lambda x: ['<BOS>'] + [w.rsplit('/', 1)[0] for w in x['wp']],
                      new_field_name='words')
        dataset.apply(lambda x: ['<BOS>'] + [w.rsplit('/', 1)[1] for w in x['wp']],
                      new_field_name='pos')
        dataset.rename_field("words", "raw_words")

        # 3. Run the pipeline.
        # NOTE(review): 'arc_pred' and 'label_pred_seq' are presumably added
        # to the dataset by the pipeline -- confirm against its processors.
        self.pipeline(dataset)
        # Arcs are stringified so they can be concatenated with the labels.
        dataset.apply(lambda x: [str(arc) for arc in x['arc_pred']],
                      new_field_name='arc_pred')
        # The first element is dropped; it presumably corresponds to the
        # '<BOS>' position prepended above.
        dataset.apply(lambda x: [
            arc + '/' + label
            for arc, label in zip(x['arc_pred'], x['label_pred_seq'])
        ][1:],
                      new_field_name='output')
        # output like: [['2/top', '0/root', '4/nn', '2/dep']]
        return dataset.field_arrays['output'].content
Example #2
0
    def test_rename_field(self):
        """Renaming exposes the field under its new name only; an unknown
        source name raises ``KeyError``."""
        fields = {"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}
        data_set = DataSet(fields)

        # After the rename, the new name is present and the old one is gone.
        data_set.rename_field("x", "xx")
        self.assertTrue("xx" in data_set)
        self.assertFalse("x" in data_set)

        # "yyy" was never added, so renaming it must fail.
        self.assertRaises(KeyError, data_set.rename_field, "yyy", "oo")