def predict(self, content):
    """Segment, POS-tag and dependency-parse ``content``.

    Args:
        content: Raw input passed straight to ``self.pos_tagger.predict``.

    Returns:
        The content of the dataset's ``'output'`` field. There is one entry
        per instance, and each entry is a list of ``"<head>/<label>"`` strings
        with the leading ``<BOS>`` position removed, e.g.
        ``[['2/top', '0/root', '4/nn', '2/dep']]``.

    Raises:
        ValueError: If no ``pipeline`` attribute has been loaded yet.
    """
    if not hasattr(self, 'pipeline'):
        raise ValueError("You have to load model first.")

    # 1. Use the POS tagger to obtain word segmentation and POS tags.
    pos_out = self.pos_tagger.predict(content)
    # Example shape: pos_out = ['这里/NN 是/VB 分词/NN 结果/NN'.split()]

    # 2. Build the dataset.
    dataset = DataSet()
    dataset.add_field('wp', pos_out)
    # Split each "word/TAG" token on its LAST '/' only. The word itself may
    # contain '/' (e.g. "1/2/CD"), and a plain split('/') would then yield a
    # truncated word and a wrong tag. Assumes POS tags never contain '/'.
    dataset.apply(lambda x: ['<BOS>'] + [w.rsplit('/', 1)[0] for w in x['wp']],
                  new_field_name='words')
    dataset.apply(lambda x: ['<BOS>'] + [w.rsplit('/', 1)[1] for w in x['wp']],
                  new_field_name='pos')
    # NOTE(review): presumably the pipeline reads words from 'raw_words';
    # confirm against the pipeline's processors.
    dataset.rename_field("words", "raw_words")

    # 3. Run the pipeline. It is expected to add 'arc_pred' and 'label_pred_seq'.
    self.pipeline(dataset)
    dataset.apply(lambda x: [str(arc) for arc in x['arc_pred']], new_field_name='arc_pred')
    # Join head and label per token, then drop the <BOS> slot with [1:].
    dataset.apply(lambda x: [arc + '/' + label for arc, label in
                             zip(x['arc_pred'], x['label_pred_seq'])][1:],
                  new_field_name='output')
    # output like: [['2/top', '0/root', '4/nn', '2/dep']]
    return dataset.field_arrays['output'].content
def test_rename_field(self):
    """rename_field moves a field, keeps its data, and rejects unknown names."""
    x_content = [[1, 2, 3, 4]] * 10
    y_content = [[5, 6]] * 10
    ds = DataSet({"x": x_content, "y": y_content})

    ds.rename_field("x", "xx")
    self.assertTrue("xx" in ds)
    self.assertFalse("x" in ds)
    # Checking the key names alone would not catch a rename that drops or
    # replaces the data, so verify the content travels with the new name.
    self.assertEqual(ds.field_arrays["xx"].content, x_content)
    # Fields not involved in the rename must be left untouched.
    self.assertEqual(ds.field_arrays["y"].content, y_content)

    # Renaming a field that does not exist must raise KeyError.
    with self.assertRaises(KeyError):
        ds.rename_field("yyy", "oo")