示例#1
0
    def sr_parse(self, doc, bcvocab=None):
        """ Shift-reduce RST parsing based on models prediction

        Builds the tree structure with a transition-based parser driven by
        ``self.action_clf``, wraps it in an ``RstTree``, and then labels every
        internal node (a node with both a left and a right child) with the
        relation predicted by ``self.relation_clf``.

        :type doc: Doc
        :param doc: the document instance

        :type bcvocab: dict
        :param bcvocab: brown clusters

        :rtype: RstTree
        :return: the parsed tree, with relations assigned to internal nodes

        :raises RuntimeError: if none of the actions proposed by the action
            classifier is allowed in the current parsing state
        """
        # use transition-based parsing to build tree structure
        conf = ParsingState([], [])
        conf.init(doc)
        action_hist = []
        while not conf.end_parsing():
            stack, queue = conf.get_status()
            fg = ActionFeatureGenerator(stack, queue, action_hist, doc, bcvocab)
            action_feats = fg.gen_features()
            action_probs = self.action_clf.predict_probs(action_feats)
            # Apply the first allowed action, in the order predict_probs
            # returns them (presumably most probable first -- confirm).
            for action, _ in action_probs:
                if conf.is_action_allowed(action):
                    conf.operate(action)
                    action_hist.append(action)
                    break
            else:
                # No action was applied: the state and action history are
                # unchanged, so the loop would retry the same prediction
                # forever. Fail loudly instead of hanging.
                raise RuntimeError(
                    'No allowed parsing action among the classifier predictions')
        tree = conf.get_parse_tree()
        # assign the node to rst_tree
        rst_tree = RstTree()
        rst_tree.assign_tree(tree)
        rst_tree.assign_doc(doc)
        rst_tree.down_prop(tree)
        rst_tree.back_prop(tree, doc)
        # tag relations for the tree (internal nodes only)
        post_nodelist = RstTree.postorder_DFT(rst_tree.tree, [])
        for node in post_nodelist:
            if (node.lnode is not None) and (node.rnode is not None):
                fg = RelationFeatureGenerator(node, rst_tree, node.level, bcvocab)
                relation_feats = fg.gen_features()
                relation = self.relation_clf.predict(relation_feats, node.level)
                node.assign_relation(relation)
        return rst_tree
示例#2
0
    def eval_parser(self,
                    path='./examples',
                    report=False,
                    bcvocab=None,
                    draw=True):
        """ Parse every ``*.merge`` file under ``path`` and optionally score it.

        For each ``<name>.merge`` file, the document is parsed with
        ``self.parser.sr_parse``. The predicted brackets are written to
        ``<name>.brackets``. When ``draw`` is set, the tree is also drawn to
        ``<name>.ps``. When ``report`` is set, each prediction is compared with
        the gold tree read from ``<name>.dis``. The accumulated span /
        nuclearity / relation metrics are reported at the end.

        :type path: str
        :param path: directory containing the ``*.merge`` files

        :type report: bool
        :param report: evaluate against gold ``*.dis`` trees and report metrics

        :type bcvocab: dict
        :param bcvocab: brown clusters, passed through to the parser

        :type draw: bool
        :param draw: draw each predicted tree to a ``*.ps`` file
        """
        merge_ext = '.merge'
        # Metrics are only needed when reporting.
        met = None
        if report:
            met = Metrics(levels=['span', 'nuclearity', 'relation'])
        # ----------------------------------------
        # Read all files from the given path; sorted because os.listdir
        # returns entries in arbitrary order.
        doclist = [
            os.path.join(path, fname) for fname in sorted(os.listdir(path))
            if fname.endswith(merge_ext)
        ]
        for fmerge in doclist:
            # Swap only the trailing extension: str.replace would also rewrite
            # a '.merge' appearing earlier in the directory part of the path.
            stem = fmerge[:-len(merge_ext)]
            # ----------------------------------------
            # Read *.merge file
            doc = Doc()
            doc.read_from_fmerge(fmerge)
            # ----------------------------------------
            # Parsing
            pred_rst = self.parser.sr_parse(doc, bcvocab)
            if draw:
                pred_rst.draw_rst(stem + '.ps')
            # Get brackets from parsing results and write them into file
            pred_brackets = pred_rst.bracketing()
            Evaluator.writebrackets(stem + '.brackets', pred_brackets)
            # ----------------------------------------
            # Evaluate with gold RST tree
            if report:
                gold_rst = RstTree(stem + '.dis', fmerge)
                gold_rst.build()
                met.eval(gold_rst, pred_rst)
        if report:
            met.report()