def sr_parse(self, doc, bcvocab=None):
    """ Shift-reduce RST parsing based on the model's predictions.

    The tree structure is first built with a transition-based loop driven by
    ``self.action_clf``. Every node of the resulting tree that has both a
    left and a right child is then labelled with a relation predicted by
    ``self.relation_clf``.

    :type doc: Doc
    :param doc: the document instance

    :type bcvocab: dict
    :param bcvocab: brown clusters

    :rtype: RstTree
    :return: the parsed tree, bound to ``doc``, with relations assigned

    :raises RuntimeError: if none of the actions yielded by
        ``self.action_clf.predict_probs`` is allowed in the current
        parsing state
    """
    # Use transition-based parsing to build the tree structure.
    conf = ParsingState([], [])
    conf.init(doc)
    action_hist = []
    while not conf.end_parsing():
        stack, queue = conf.get_status()
        fg = ActionFeatureGenerator(stack, queue, action_hist, doc, bcvocab)
        action_feats = fg.gen_features()
        # Candidates are tried in the order predict_probs yields them
        # (presumably most probable first -- confirm against the classifier);
        # the first one the parsing state allows is applied.
        action_probs = self.action_clf.predict_probs(action_feats)
        for action, _prob in action_probs:
            if conf.is_action_allowed(action):
                conf.operate(action)
                action_hist.append(action)
                break
        else:
            # Nothing was applied, so the state and the action history are
            # unchanged. The next iteration would rebuild the same features
            # and (for a deterministic classifier) get the same prediction,
            # so fail loudly instead of looping forever.
            raise RuntimeError(
                'sr_parse: no predicted action is allowed in the current '
                'parsing state (after %d actions)' % len(action_hist))
    tree = conf.get_parse_tree()
    # Wrap the bare tree in an RstTree bound to the document, then run the
    # RstTree down/back propagation passes over it.
    rst_tree = RstTree()
    rst_tree.assign_tree(tree)
    rst_tree.assign_doc(doc)
    rst_tree.down_prop(tree)
    rst_tree.back_prop(tree, doc)
    # Tag relations: every node with both children gets a predicted label,
    # visited in post-order.
    post_nodelist = RstTree.postorder_DFT(rst_tree.tree, [])
    for node in post_nodelist:
        if (node.lnode is not None) and (node.rnode is not None):
            fg = RelationFeatureGenerator(node, rst_tree, node.level, bcvocab)
            relation_feats = fg.gen_features()
            relation = self.relation_clf.predict(relation_feats, node.level)
            node.assign_relation(relation)
    return rst_tree
def eval_parser(self, path='./examples', report=False, bcvocab=None, draw=True):
    """ Test the parsing performance.

    Parses every ``*.merge`` file in ``path`` with ``self.parser.sr_parse``
    and writes the predicted brackets next to it as ``<name>.brackets``.

    :type path: str
    :param path: directory containing the ``*.merge`` files

    :type report: bool
    :param report: if True, also build the gold tree from the matching
        ``<name>.dis`` file, score the prediction against it with
        ``Metrics.eval`` and call ``Metrics.report()`` once all files are done

    :type bcvocab: dict
    :param bcvocab: brown clusters, passed through to ``sr_parse``

    :type draw: bool
    :param draw: if True, draw each predicted tree to ``<name>.ps``
    """
    suffix = '.merge'
    met = Metrics(levels=['span', 'nuclearity', 'relation'])
    # os.listdir order is arbitrary; sort so runs are reproducible.
    fnames = sorted(fname for fname in os.listdir(path) if fname.endswith(suffix))
    for fname in fnames:
        fmerge = os.path.join(path, fname)
        # Strip only the trailing '.merge'. str.replace would also rewrite a
        # '.merge' occurring earlier in the path (e.g. 'x.merged.merge').
        base = fmerge[:-len(suffix)]
        # Read the *.merge file and parse it.
        doc = Doc()
        doc.read_from_fmerge(fmerge)
        pred_rst = self.parser.sr_parse(doc, bcvocab)
        if draw:
            pred_rst.draw_rst(base + '.ps')
        # Write the predicted brackets into <name>.brackets.
        pred_brackets = pred_rst.bracketing()
        Evaluator.writebrackets(base + '.brackets', pred_brackets)
        # Evaluate against the gold RST tree.
        if report:
            gold_rst = RstTree(base + '.dis', fmerge)
            gold_rst.build()
            met.eval(gold_rst, pred_rst)
    if report:
        met.report()