def sr_parse(self, doc, gold_actions, gold_rels):
    # Generate coref clusters for the document
    if self.clf.config[MODEL_TYPE] in [1, 2]:
        with torch.no_grad():
            clusters, _ = self.coref_trainer.predict_clusters(doc)
    else:
        clusters = None
    # Stack/Queue state
    conf = ParsingState([], [], self.clf.config)
    conf.init(doc)
    all_action_probs, all_rel_probs = [], []
    # Until the tree is built
    while not conf.end_parsing():
        # Get features for the current stack/queue state, and span boundaries
        stack, queue = conf.get_status()
        fg = ActionFeatureGenerator(stack, queue, [], doc, self.data_helper, self.config)
        action_feat, span_boundary = fg.gen_features()
        span_embeds = self.clf.get_edus_bert_coref([doc], [clusters], [span_boundary])
        action_probs, rel_probs = self.clf.decode_action_coref(span_embeds, [action_feat])
        all_action_probs.append(action_probs.squeeze())
        sorted_action_idx = torch.argsort(action_probs, descending=True)
        sorted_rel_idx = torch.argsort(rel_probs, descending=True)
        # Select Shift/Reduce action (shift/reduce-nn/...)
        action_idx = 0
        pred_action, pred_nuc = xidx_action_map[int(sorted_action_idx[0, action_idx])]
        while not conf.is_action_allowed((pred_action, pred_nuc, None), doc):
            action_idx += 1
            pred_action, pred_nuc = xidx_action_map[int(sorted_action_idx[0, action_idx])]
        # Select Relation annotation
        pred_rel = None
        if pred_action != "Shift":
            all_rel_probs.append(rel_probs.squeeze())
            pred_rel_idx = int(sorted_rel_idx[0, 0])
            pred_rel = xidx_relation_map[pred_rel_idx]
        assert not (pred_action == "Reduce" and pred_rel is None)
        predictions = (pred_action, pred_nuc, pred_rel)
        conf.operate(predictions)
    # Shift/Reduce loss
    cost = self.loss(torch.stack(all_action_probs), gold_actions)
    # Relation annotation loss
    if all_rel_probs:
        cost_relation = self.loss(torch.stack(all_rel_probs), gold_rels)
        cost += cost_relation
    tree = conf.get_parse_tree()
    rst_tree = RstTree()
    rst_tree.assign_tree(tree)
    rst_tree.assign_doc(doc)
    rst_tree.back_prop(tree, doc)
    return rst_tree, cost.item()
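# NOTE: a minimal, self-contained sketch of the constrained greedy selection used
# above, pulled out for clarity. `idx_to_action` and `is_allowed` are hypothetical
# stand-ins for xidx_action_map and conf.is_action_allowed; this is not parser code.
import torch

def pick_allowed_action(action_scores, idx_to_action, is_allowed):
    """Walk down the ranked (action, nuclearity) list until a legal transition is found."""
    ranked = torch.argsort(action_scores, descending=True)
    for idx in ranked.tolist():
        action, nuc = idx_to_action[int(idx)]
        if is_allowed((action, nuc, None)):
            return action, nuc
    raise ValueError("no legal transition for the current parser state")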
def test_eval():
    fdis = "../data/data_dir/train_dir/wsj_0603.out.dis"
    fmerge = "../data/data_dir/train_dir/wsj_0603.out.merge"
    gold_tree = RstTree(fdis, fmerge)
    gold_tree.build()
    # The gold tree compared against itself scores 1.0 on span, nuclearity and relation
    eval_trees(gold_tree, gold_tree, 1, 1, 1)

    sr_parser = ParsingState([], [])
    sr_parser.init(gold_tree.doc)
    sh, rd = 'Shift', 'Reduce'
    nn, ns, sn = 'NN', 'NS', 'SN'
    # Last two nuclearity actions are incorrect, last rel is incorrect
    silver_nuc = [(sh, None, None), (sh, None, None), (rd, ns, 'Elaboration'),
                  (sh, None, None), (sh, None, None), (rd, ns, 'Elaboration'),
                  (rd, nn, 'Same-Unit'), (sh, None, None), (rd, ns, 'Elaboration'),
                  (sh, None, None), (sh, None, None), (rd, sn, 'Attribution'),
                  (rd, ns, 'Attribution')]
    for action in silver_nuc:
        sr_parser.operate(action)
    tree = sr_parser.get_parse_tree()
    silver_tree = RstTree()
    silver_tree.assign_tree(tree)
    silver_tree.assign_doc(gold_tree.doc)
    silver_tree.back_prop(tree, gold_tree.doc)
    eval_trees(silver_tree, gold_tree, 1, 9 / 12, 8 / 12)
    eval_trees(silver_tree, gold_tree, 1, 4 / 6, 5 / 6, use_parseval=True)

    sr_parser = ParsingState([], [])
    sr_parser.init(gold_tree.doc)
    # Here the bracketing also deviates from gold (a third Shift precedes the
    # first Reduce), in addition to nuclearity and relation errors
    silver_nuc = [(sh, None, None), (sh, None, None), (sh, None, None),
                  (rd, ns, 'Elaboration'), (sh, None, None), (rd, ns, 'Elaboration'),
                  (rd, nn, 'Same-Unit'), (sh, None, None), (rd, ns, 'Elaboration'),
                  (sh, None, None), (sh, None, None), (rd, ns, 'Attribution'),
                  (rd, ns, 'Attribution')]
    for action in silver_nuc:
        sr_parser.operate(action)
    tree = sr_parser.get_parse_tree()
    silver_tree = RstTree()
    silver_tree.assign_tree(tree)
    silver_tree.assign_doc(gold_tree.doc)
    silver_tree.back_prop(tree, gold_tree.doc)
    eval_trees(silver_tree, gold_tree, 10 / 12, 7 / 12, 5 / 12)
    eval_trees(silver_tree, gold_tree, 4 / 6, 3 / 6, 3 / 6, use_parseval=True)
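# NOTE: an illustrative stand-in for ParsingState, showing how the
# (Shift/Reduce, nuclearity, relation) triples in the test above encode a binary
# tree over EDUs. `replay` is a hypothetical helper and builds nested tuples
# rather than the repo's node objects.
def replay(actions, n_edus):
    queue = list(range(1, n_edus + 1))  # EDU ids still to be shifted
    stack = []
    for act, nuc, rel in actions:
        if act == 'Shift':
            stack.append(queue.pop(0))
        else:  # Reduce: merge the two topmost subtrees under a labelled parent
            right, left = stack.pop(), stack.pop()
            stack.append((left, right, nuc, rel))
    assert len(stack) == 1 and not queue, "sequence must shift every EDU and reduce to one tree"
    return stack[0]

# Both test sequences contain 7 Shifts and 6 Reduces, i.e. 7 EDUs,
# so e.g. replay(silver_nuc, 7) reproduces the silver bracketing being scored.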
def sr_parse(self, doc, bcvocab=None):
    """ Shift-reduce RST parsing based on model predictions

    :type doc: Doc
    :param doc: the document instance

    :type bcvocab: dict
    :param bcvocab: brown clusters
    """
    # use transition-based parsing to build tree structure
    conf = ParsingState([], [])
    conf.init(doc)
    action_hist = []
    while not conf.end_parsing():
        stack, queue = conf.get_status()
        fg = ActionFeatureGenerator(stack, queue, action_hist, doc, bcvocab)
        action_feats = fg.gen_features()
        action_probs = self.action_clf.predict_probs(action_feats)
        # Apply the first legal action from the classifier's ranked predictions
        for action, cur_prob in action_probs:
            if conf.is_action_allowed(action):
                conf.operate(action)
                action_hist.append(action)
                break
    tree = conf.get_parse_tree()
    # assign the node to rst_tree
    rst_tree = RstTree()
    rst_tree.assign_tree(tree)
    rst_tree.assign_doc(doc)
    rst_tree.down_prop(tree)
    rst_tree.back_prop(tree, doc)
    # tag relations for the tree
    post_nodelist = RstTree.postorder_DFT(rst_tree.tree, [])
    for node in post_nodelist:
        if (node.lnode is not None) and (node.rnode is not None):
            fg = RelationFeatureGenerator(node, rst_tree, node.level, bcvocab)
            relation_feats = fg.gen_features()
            relation = self.relation_clf.predict(relation_feats, node.level)
            node.assign_relation(relation)
    return rst_tree
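# NOTE: a usage sketch under assumptions. `parser` stands for a trained instance of
# the class defining the sr_parse above (with action_clf and relation_clf loaded);
# RstTree, postorder_DFT and the lnode/rnode attributes come from the snippets in
# this file. `parse_one_file` itself is a hypothetical helper.
def parse_one_file(parser, fdis, fmerge, bcvocab=None):
    """Parse one pre-segmented document, e.g. the wsj_0603 files used in test_eval."""
    gold_tree = RstTree(fdis, fmerge)
    gold_tree.build()
    pred_tree = parser.sr_parse(gold_tree.doc, bcvocab=bcvocab)
    # Internal nodes are the labelled spans produced by the relation-tagging pass
    internal = [node for node in RstTree.postorder_DFT(pred_tree.tree, [])
                if node.lnode is not None and node.rnode is not None]
    return pred_tree, len(internal)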
def sr_parse(self, docs, gold_actions, optim, is_train=True):
    # One ParsingState (stack/queue) per document in the batch
    confs = []
    action_hists = []
    for doc in docs:
        conf = ParsingState([], [], self.config)
        conf.init(doc)
        confs.append(conf)
        action_hists.append([])
    action_hist_nonnum = []
    iter = 0
    cost_acc = 0
    # All parser states in the batch are advanced in lockstep until the trees are built
    while not confs[0].end_parsing():
        if is_train:
            optim.zero_grad()
        action_feats, neural_feats, spans, curr_gold_actions = [], [], [], []
        # Collect features for the current state of every parser in the batch
        for conf, doc in zip(confs, docs):
            stack, queue = conf.get_status()
            fg = ActionFeatureGenerator(stack, queue, action_hist_nonnum, doc,
                                        self.clf.data_helper, self.config)
            action_feat, neural_feat = fg.gen_features()
            action_feats.append(action_feat)
            neural_feats.append(neural_feat)
        # Score all candidate actions for the whole batch in one forward pass
        spans = self.clf.get_edus_bert_span(docs, neural_feats)
        action_probs = self.clf.decode_action_online(spans, action_feats)
        sorted_action_idx = torch.argsort(action_probs, descending=True)
        for i, conf in enumerate(confs):
            action_hists[i].append(action_probs[i])
            action_idx = 0
            curr_gold_actions.append(gold_actions[i][iter].unsqueeze(0))
            if is_train:
                # Teacher forcing: follow the gold action during training
                action = self.clf.xidx_action_map[int(gold_actions[i][iter])]
            else:
                # Greedy decoding at inference time
                action = self.clf.xidx_action_map[int(sorted_action_idx[i, action_idx])]
            # Fall back down the ranked action list until a legal transition is found
            while not conf.is_action_allowed(action, docs[i]):
                action = self.clf.xidx_action_map[int(sorted_action_idx[i, action_idx])]
                action_idx += 1
            conf.operate(action)
        # Step loss against the gold actions for this transition
        cost = self.loss(action_probs, torch.cat(curr_gold_actions))
        cost_acc += cost.item()
        if is_train:
            cost.backward()
            nn.utils.clip_grad_norm_(self.clf.parameters(), 0.2)
            optim.step()
        iter += 1
    # Convert the finished parser states into RstTree objects
    rst_trees = []
    for conf, doc in zip(confs, docs):
        tree = conf.get_parse_tree()
        rst_tree = RstTree()
        rst_tree.assign_tree(tree)
        rst_tree.assign_doc(doc)
        rst_tree.back_prop(tree, doc)
        rst_trees.append(rst_tree)
    return rst_trees, cost_acc
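# NOTE: a driver sketch under stated assumptions, not the repo's training script.
# `trainer` stands for an instance of the class defining the batched sr_parse above;
# `train_batches` is assumed to yield (docs, gold_action_tensors) pairs whose
# documents require the same number of transitions, since the loop above only
# checks confs[0].end_parsing(). Learning rate and epoch count are placeholders.
import torch

def train(trainer, train_batches, num_epochs=5, lr=1e-5):
    optim = torch.optim.Adam(trainer.clf.parameters(), lr=lr)
    for epoch in range(num_epochs):
        epoch_cost = 0.0
        for docs, gold_actions in train_batches:
            # sr_parse zeroes gradients, backprops and steps the optimizer per transition
            _, batch_cost = trainer.sr_parse(docs, gold_actions, optim, is_train=True)
            epoch_cost += batch_cost
        print(f"epoch {epoch}: accumulated loss {epoch_cost:.3f}")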