示例#1
0
 def sr_parse(self, doc, gold_actions, gold_rels):
     """Parse ``doc`` with the coref-aware shift-reduce model and score it.

     At each step the classifier scores the next transition and relation;
     the highest-ranked transition allowed by the parsing state is applied.
     The collected scores are compared with the gold targets via ``self.loss``.

     :param doc: the document to parse
     :param gold_actions: gold transition targets passed to ``self.loss``;
         presumably one entry per transition step -- TODO confirm
     :param gold_rels: gold relation targets passed to ``self.loss``;
         presumably one entry per non-Shift step -- TODO confirm
     :return: ``(rst_tree, cost)`` -- the predicted RstTree and the summed
         action (+ relation) loss as a Python number (``Tensor.item()``)
     :raises RuntimeError: if no candidate transition is allowed in the
         current parsing state
     """
     # Coreference clusters are only computed for model types 1 and 2;
     # no gradients flow through the coref predictor.
     if self.clf.config[MODEL_TYPE] in [1, 2]:
         with torch.no_grad():
             clusters, _ = self.coref_trainer.predict_clusters(doc)
     else:
         clusters = None

     # Stack/queue parsing state
     conf = ParsingState([], [], self.clf.config)
     conf.init(doc)
     all_action_probs, all_rel_probs = [], []
     # Until the tree is built
     while not conf.end_parsing():
         # Features for the current stack/queue state, plus span boundaries.
         # NOTE(review): this passes self.config while the parsing state uses
         # self.clf.config -- confirm both refer to the same configuration.
         stack, queue = conf.get_status()
         fg = ActionFeatureGenerator(stack, queue, [], doc, self.data_helper, self.config)
         action_feat, span_boundary = fg.gen_features()
         span_embeds = self.clf.get_edus_bert_coref([doc], [clusters], [span_boundary])
         action_probs, rel_probs = self.clf.decode_action_coref(span_embeds, [action_feat])
         all_action_probs.append(action_probs.squeeze())
         sorted_action_idx = torch.argsort(action_probs, descending=True)
         sorted_rel_idx = torch.argsort(rel_probs, descending=True)

         # Select the best-ranked allowed Shift/Reduce action (with nuclearity).
         for candidate in sorted_action_idx[0]:
             pred_action, pred_nuc = xidx_action_map[int(candidate)]
             if conf.is_action_allowed((pred_action, pred_nuc, None), doc):
                 break
         else:
             # Fail loudly instead of indexing past the last candidate.
             raise RuntimeError("no allowed action among the classifier's candidates")

         # Relation label: only predicted (and scored) for non-Shift actions.
         pred_rel = None
         if pred_action != "Shift":
             all_rel_probs.append(rel_probs.squeeze())
             pred_rel = xidx_relation_map[int(sorted_rel_idx[0, 0])]
         assert not (pred_action == "Reduce" and pred_rel is None)

         conf.operate((pred_action, pred_nuc, pred_rel))

     # Shift/Reduce loss
     cost = self.loss(torch.stack(all_action_probs), gold_actions)

     # Relation annotation loss, only when at least one relation was predicted
     if all_rel_probs:
         cost += self.loss(torch.stack(all_rel_probs), gold_rels)

     tree = conf.get_parse_tree()
     rst_tree = RstTree()
     rst_tree.assign_tree(tree)
     rst_tree.assign_doc(doc)
     rst_tree.back_prop(tree, doc)

     return rst_tree, cost.item()
示例#2
0
def _silver_tree_from_actions(doc, actions):
    """Replay ``actions`` on a fresh ParsingState for ``doc``; return the RstTree."""
    sr_parser = ParsingState([], [])
    sr_parser.init(doc)
    for action in actions:
        sr_parser.operate(action)
    tree = sr_parser.get_parse_tree()
    silver_tree = RstTree()
    silver_tree.assign_tree(tree)
    silver_tree.assign_doc(doc)
    silver_tree.back_prop(tree, doc)
    return silver_tree


def test_eval():
    """Check eval_trees scores for hand-built silver parses of wsj_0603.

    NOTE(review): the three numeric arguments of eval_trees are presumably
    the expected scores for the compared trees -- confirm against its
    definition.
    """
    fdis = "../data/data_dir/train_dir/wsj_0603.out.dis"
    fmerge = "../data/data_dir/train_dir/wsj_0603.out.merge"
    gold_tree = RstTree(fdis, fmerge)
    gold_tree.build()
    # The gold tree compared with itself.
    eval_trees(gold_tree, gold_tree, 1, 1, 1)
    sh, rd = 'Shift', 'Reduce'
    nn, ns, sn = 'NN', 'NS', 'SN'

    # Last two nuclearity actions are incorrect, last rel is incorrect
    silver_nuc = [(sh, None, None), (sh, None, None), (rd, ns, 'Elaboration'),
                  (sh, None, None), (sh, None, None), (rd, ns, 'Elaboration'),
                  (rd, nn, 'Same-Unit'), (sh, None, None),
                  (rd, ns, 'Elaboration'), (sh, None, None), (sh, None, None),
                  (rd, sn, 'Attribution'), (rd, ns, 'Attribution')]
    silver_tree = _silver_tree_from_actions(gold_tree.doc, silver_nuc)

    eval_trees(silver_tree, gold_tree, 1, 9 / 12, 8 / 12)
    eval_trees(silver_tree, gold_tree, 1, 4 / 6, 5 / 6, use_parseval=True)

    # Different action sequence, so the resulting tree structure also differs.
    silver_nuc = [(sh, None, None), (sh, None, None), (sh, None, None),
                  (rd, ns, 'Elaboration'), (sh, None, None),
                  (rd, ns, 'Elaboration'), (rd, nn, 'Same-Unit'),
                  (sh, None, None), (rd, ns, 'Elaboration'), (sh, None, None),
                  (sh, None, None), (rd, ns, 'Attribution'),
                  (rd, ns, 'Attribution')]
    silver_tree = _silver_tree_from_actions(gold_tree.doc, silver_nuc)

    eval_trees(silver_tree, gold_tree, 10 / 12, 7 / 12, 5 / 12)
    eval_trees(silver_tree, gold_tree, 4 / 6, 3 / 6, 3 / 6, use_parseval=True)
示例#3
0
    def sr_parse(self, doc, bcvocab=None):
        """Shift-reduce RST parsing driven by the models' predictions.

        The tree structure is built first with the action classifier; then a
        relation label is predicted for every node that has both a left and
        a right child.

        :type doc: Doc
        :param doc: the document instance

        :type bcvocab: dict
        :param bcvocab: brown clusters

        :rtype: RstTree
        :return: the parsed tree with relations assigned

        :raises RuntimeError: if none of the candidate actions returned by the
            action classifier is allowed in the current parsing state
        """
        # Use transition-based parsing to build the tree structure.
        conf = ParsingState([], [])
        conf.init(doc)
        action_hist = []
        while not conf.end_parsing():
            stack, queue = conf.get_status()
            fg = ActionFeatureGenerator(stack, queue, action_hist, doc, bcvocab)
            action_feats = fg.gen_features()
            action_probs = self.action_clf.predict_probs(action_feats)
            # Apply the first allowed action, in the order predict_probs
            # returns them (presumably best-first -- confirm).
            for action, _prob in action_probs:
                if conf.is_action_allowed(action):
                    conf.operate(action)
                    action_hist.append(action)
                    break
            else:
                # Without an applied action the state never changes and this
                # loop would spin forever.
                raise RuntimeError(
                    "no allowed action among the action classifier's candidates")
        tree = conf.get_parse_tree()
        # Assign the node to rst_tree.
        rst_tree = RstTree()
        rst_tree.assign_tree(tree)
        rst_tree.assign_doc(doc)
        rst_tree.down_prop(tree)
        rst_tree.back_prop(tree, doc)
        # Tag relations for every node with both children, in post-order.
        post_nodelist = RstTree.postorder_DFT(rst_tree.tree, [])
        for node in post_nodelist:
            if (node.lnode is not None) and (node.rnode is not None):
                fg = RelationFeatureGenerator(node, rst_tree, node.level, bcvocab)
                relation_feats = fg.gen_features()
                relation = self.relation_clf.predict(relation_feats, node.level)
                node.assign_relation(relation)
        return rst_tree
示例#4
0
    def sr_parse(self, docs, gold_actions, optim, is_train=True):
        """Run batched shift-reduce parsing over ``docs`` and accumulate the loss.

        At every transition step the classifier scores the next action for
        each document. The loss against the gold actions is computed for the
        whole batch. In training mode it is back-propagated right away, so
        the optimizer steps once per transition.

        :param docs: documents to parse; each gets its own ParsingState
        :param gold_actions: per-document gold action targets, indexed as
            ``gold_actions[i][step]``; also used for the loss when not training
        :param optim: optimizer; only used when ``is_train`` is true
        :param is_train: if true, follow the gold actions and update the
            model; otherwise follow the highest-ranked allowed action
        :return: ``(rst_trees, cost_acc)`` -- one RstTree per document and the
            sum of the per-step loss values
        """
        confs = []
        for doc in docs:
            conf = ParsingState([], [], self.config)
            conf.init(doc)
            confs.append(conf)
        # Never appended to: every feature generator sees an empty action history.
        action_hist_nonnum = []
        action_map = self.clf.xidx_action_map
        step = 0
        cost_acc = 0
        # NOTE(review): termination is decided by the first document only, so
        # presumably every document in the batch needs the same number of
        # transitions (e.g. batch size 1) -- confirm with the callers.
        while not confs[0].end_parsing():
            if is_train:
                optim.zero_grad()
            action_feats, neural_feats = [], []
            for conf, doc in zip(confs, docs):
                stack, queue = conf.get_status()
                fg = ActionFeatureGenerator(stack, queue, action_hist_nonnum,
                                            doc, self.clf.data_helper,
                                            self.config)
                action_feat, neural_feat = fg.gen_features()
                action_feats.append(action_feat)
                neural_feats.append(neural_feat)

            spans = self.clf.get_edus_bert_span(docs, neural_feats)
            action_probs = self.clf.decode_action_online(spans, action_feats)
            sorted_action_idx = torch.argsort(action_probs, descending=True)

            curr_gold_actions = []
            for i, (conf, doc) in enumerate(zip(confs, docs)):
                curr_gold_actions.append(gold_actions[i][step].unsqueeze(0))
                if is_train:
                    # Teacher forcing: follow the gold transition.
                    action = action_map[int(gold_actions[i][step])]
                else:
                    # Highest-ranked allowed action. Advance the rank before
                    # fetching the next candidate so each one is checked once.
                    action_idx = 0
                    action = action_map[int(sorted_action_idx[i, action_idx])]
                    while not conf.is_action_allowed(action, doc):
                        action_idx += 1
                        action = action_map[int(sorted_action_idx[i, action_idx])]
                conf.operate(action)
            cost = self.loss(action_probs, torch.cat(curr_gold_actions))
            cost_acc += cost.item()
            if is_train:
                cost.backward()
                nn.utils.clip_grad_norm_(self.clf.parameters(), 0.2)
                optim.step()
            step += 1

        rst_trees = []
        for conf, doc in zip(confs, docs):
            tree = conf.get_parse_tree()
            rst_tree = RstTree()
            rst_tree.assign_tree(tree)
            rst_tree.assign_doc(doc)
            rst_tree.back_prop(tree, doc)
            rst_trees.append(rst_tree)
        return rst_trees, cost_acc