def cal_class_distribution(data_dir, level):
    """ Calculate the class distribution of discourse relations
    :param data_dir: directory containing the *.dis / *.merge file pairs
    :param level: 0 for inner-sentence, 1 for inter-sentence but inner-paragraph,
                  2 for inter-paragraph, 3 for a per-depth breakdown
    :return: a Counter over relation classes (levels 0-2),
             or a dict mapping tree depth to a Counter (level 3)
    """
    rst_trees = DataHelper.read_rst_trees(data_dir)
    all_nodes = [node for rst_tree in rst_trees
                 for node in rst_tree.postorder_DFT(rst_tree.tree, [])]
    if level in [0, 1, 2]:
        valid_relations = [RstTree.extract_relation(node.child_relation)
                           for node in all_nodes
                           if node.level == level and node.child_relation is not None]
        distribution = Counter(valid_relations)
        for cla in class2rel:
            if cla not in distribution:
                distribution[cla] = 0
        return distribution
    if level == 3:
        depth_relation_distributions = {}
        for node in all_nodes:
            # Leaf nodes (EDUs) carry no child relation
            if node.lnode is None and node.rnode is None:
                continue
            relation = RstTree.extract_relation(node.child_relation)
            if node.depth not in depth_relation_distributions:
                depth_relation_distributions[node.depth] = Counter()
            depth_relation_distributions[node.depth][relation] += 1
        for depth, distribution in depth_relation_distributions.items():
            for cla in class2rel:
                if cla not in distribution:
                    distribution[cla] = 0
        return depth_relation_distributions

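# Hedged usage sketch (not part of the original code): assumes data_dir holds matching
# *.dis / *.merge pairs and that DataHelper, RstTree, class2rel and Counter are importable
# exactly as in the function above.
def print_relation_distributions(data_dir):
    # Levels 0-2 each return a single Counter over relation classes
    for level in (0, 1, 2):
        distribution = cal_class_distribution(data_dir, level)
        print("level", level, distribution.most_common(5))
    # Level 3 returns a dict mapping tree depth to a Counter
    by_depth = cal_class_distribution(data_dir, 3)
    for depth in sorted(by_depth):
        print("depth", depth, by_depth[depth].most_common(3))
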
def read_rst_trees(data_dir):
    # Read RST trees from the *.dis files, pairing each with its *.merge file
    files = [os.path.join(data_dir, fname) for fname in os.listdir(data_dir)
             if fname.endswith('.dis')]
    for fdis in files:
        fmerge = fdis.replace('.dis', '.merge')
        if not os.path.isfile(fmerge):
            print("Corresponding .merge file does not exist. Skipping the file.")
            continue
        rst_tree = RstTree(fdis, fmerge)
        rst_tree.build()
        yield rst_tree

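# Hedged usage sketch (not part of the original code): the generator variant above can be
# consumed lazily, e.g. to count how many trees build successfully; the directory path is
# only an illustration.
n_trees = sum(1 for _ in read_rst_trees("../data/data_dir/train_dir"))
print("built", n_trees, "RST trees")
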
def sr_parse(self, doc, gold_actions, gold_rels):
    """ Shift-reduce parse one document, returning the predicted tree and the
    combined action/relation loss against the gold sequences. """
    # Generate coref clusters for the document
    if self.clf.config[MODEL_TYPE] in [1, 2]:
        with torch.no_grad():
            clusters, _ = self.coref_trainer.predict_clusters(doc)
    else:
        clusters = None
    # Stack/Queue state
    conf = ParsingState([], [], self.clf.config)
    conf.init(doc)
    all_action_probs, all_rel_probs = [], []
    # Until the tree is built
    while not conf.end_parsing():
        # Get features for the current stack/queue state, and span boundaries
        stack, queue = conf.get_status()
        fg = ActionFeatureGenerator(stack, queue, [], doc, self.data_helper, self.config)
        action_feat, span_boundary = fg.gen_features()
        span_embeds = self.clf.get_edus_bert_coref([doc], [clusters], [span_boundary])
        action_probs, rel_probs = self.clf.decode_action_coref(span_embeds, [action_feat])
        all_action_probs.append(action_probs.squeeze())
        sorted_action_idx = torch.argsort(action_probs, descending=True)
        sorted_rel_idx = torch.argsort(rel_probs, descending=True)
        # Select the highest-scoring allowed Shift/Reduce action (shift/reduce-nn/...)
        action_idx = 0
        pred_action, pred_nuc = xidx_action_map[int(sorted_action_idx[0, action_idx])]
        while not conf.is_action_allowed((pred_action, pred_nuc, None), doc):
            action_idx += 1
            pred_action, pred_nuc = xidx_action_map[int(sorted_action_idx[0, action_idx])]
        # Select the relation annotation (only Reduce actions carry a relation)
        pred_rel = None
        if pred_action != "Shift":
            all_rel_probs.append(rel_probs.squeeze())
            pred_rel_idx = int(sorted_rel_idx[0, 0])
            pred_rel = xidx_relation_map[pred_rel_idx]
        assert not (pred_action == "Reduce" and pred_rel is None)
        predictions = (pred_action, pred_nuc, pred_rel)
        conf.operate(predictions)
    # Shift/Reduce loss
    cost = self.loss(torch.stack(all_action_probs), gold_actions)
    # Relation annotation loss
    if all_rel_probs:
        cost_relation = self.loss(torch.stack(all_rel_probs), gold_rels)
        cost += cost_relation
    tree = conf.get_parse_tree()
    rst_tree = RstTree()
    rst_tree.assign_tree(tree)
    rst_tree.assign_doc(doc)
    rst_tree.back_prop(tree, doc)
    return rst_tree, cost.item()

def __getitem__(self, idx):
    """
    Returns:
        If train: features for that tree (corefs, organizational features, spans, etc.)
        If test: gold tree
    """
    if self.is_train:
        return self.X[idx], torch.LongTensor(self.y[idx])
    doc = self.data_helper.docs[self.X[idx][0]]
    gold_rst = RstTree(doc.filename, doc.filename.replace('.dis', '.merge'))
    gold_rst.build()
    return gold_rst

def eval_parser(self, dev_data=None, path='./examples', use_parseval=False):
    """ Test the parsing performance """
    # Evaluation
    met = Metrics(use_parseval, levels=['span', 'nuclearity'])
    # ----------------------------------------
    # Read all files from the given path
    if dev_data is None:
        eval_list = [os.path.join(path, fname) for fname in os.listdir(path)
                     if fname.endswith('.merge')]
    else:
        eval_list = dev_data
    total_cost = 0
    for eval_instance in eval_list:
        # ----------------------------------------
        # Read *.merge file
        doc = Doc()
        if dev_data is not None:
            gold_rst = eval_instance
            doc = eval_instance.doc
        else:
            doc.read_from_fmerge(eval_instance)
            fdis = eval_instance.replace('.merge', '.dis')
            gold_rst = RstTree(fdis, eval_instance)
            gold_rst.build()
        _, gold_action_seq = gold_rst.generate_action_samples()
        gold_action_seq = [self.data_helper.action_map[x] for x in gold_action_seq]
        pred_rst, cost = self.parser.sr_parse(
            [doc], [torch.cuda.LongTensor(gold_action_seq)], None, is_train=False)
        total_cost += cost
        pred_rst = pred_rst[0]
        # ----------------------------------------
        # Evaluate with gold RST tree
        met.eval(gold_rst, pred_rst)
    print("Total cost: ", total_cost)
    met.report()

def read_rst_trees(data_dir, output_dir, parse_type, isFlat):
    # Read RST tree files
    files = [os.path.join(data_dir, fname) for fname in os.listdir(data_dir)
             if fname.endswith('.dis')]
    # Note: assumes the pickled *.merge representations are stored in the same
    # order as os.listdir(data_dir) returns the *.dis files
    with open(os.path.join(output_dir, parse_type, "processed_data.p"), 'rb') as file:
        fmerges = pickle.load(file)
    rst_trees = []
    for fdis, fmerge in zip(files, fmerges):
        print("building tree " + fdis)
        rst_tree = RstTree(fdis, fmerge, isFlat)
        rst_tree.build()
        rst_trees.append(rst_tree)
    return rst_trees

def read_rst_trees(data_dir):
    # Read RST tree files
    files = [os.path.join(data_dir, fname) for fname in os.listdir(data_dir)
             if fname.endswith('.dis')]
    rst_trees = []
    for fdis in files:
        fmerge = fdis.replace('.dis', '.merge')
        if not os.path.isfile(fmerge):
            raise FileNotFoundError(
                'Corresponding .merge file does not exist. Run preprocessing first.')
        rst_tree = RstTree(fdis, fmerge)
        rst_tree.build()
        rst_trees.append(rst_tree)
    return rst_trees

def read_rst_trees(data_dir):
    # Read RST tree files
    files = [os.path.join(data_dir, fname) for fname in os.listdir(data_dir)
             if fname.endswith('.dis')]
    for fdis in files:
        fmerge = fdis.replace('.dis', '.merge')
        if not os.path.isfile(fmerge):
            print(fmerge + " - Corresponding .merge file does not exist. Skipping the file.")
            continue
        rst_tree = RstTree(fdis, fmerge)
        rst_tree.build()
        if rst_tree.get_parse().startswith(" ( EDU"):
            print("Skipping document - tree is corrupt.")
            continue
        yield rst_tree

def __getitem__(self, idx):
    """
    Returns:
        If train: (doc, y) - the Doc instance read from the *.merge file and the
                  sequence of gold actions for that tree
        If test: the gold RstTree built from the *.dis / *.merge pair
    """
    doc = Doc()
    eval_instance = self.X[idx].replace('.dis', '.merge')
    doc.read_from_fmerge(eval_instance)
    if self.is_train:
        return doc, torch.cuda.LongTensor(self.y[idx])
    fdis = eval_instance.replace('.merge', '.dis')
    gold_rst = RstTree(fdis, eval_instance)
    gold_rst.build()
    return gold_rst

def test_eval():
    fdis = "../data/data_dir/train_dir/wsj_0603.out.dis"
    fmerge = "../data/data_dir/train_dir/wsj_0603.out.merge"
    gold_tree = RstTree(fdis, fmerge)
    gold_tree.build()
    eval_trees(gold_tree, gold_tree, 1, 1, 1, 1)

    sr_parser = ParsingState([], [])
    sr_parser.init(gold_tree.doc)
    sh, rd = 'Shift', 'Reduce'
    nn, ns, sn = 'NN', 'NS', 'SN'
    # Last two nuclearity actions are incorrect
    silver_nuc = [(sh, None), (sh, None), (rd, ns), (sh, None), (sh, None),
                  (rd, ns), (rd, nn), (sh, None), (rd, ns), (sh, None),
                  (sh, None), (rd, sn), (rd, ns)]
    for action in silver_nuc:
        sr_parser.operate(action)
    tree = sr_parser.get_parse_tree()
    silver_tree = RstTree()
    silver_tree.assign_tree(tree)
    silver_tree.assign_doc(gold_tree.doc)
    silver_tree.back_prop(tree, gold_tree.doc)
    eval_trees(silver_tree, gold_tree, 1, 9/12, 1, 4/6)

    # Tree structure differs from the gold tree, with additional nuclearity errors
    sr_parser = ParsingState([], [])
    sr_parser.init(gold_tree.doc)
    silver_nuc = [(sh, None), (sh, None), (sh, None), (rd, ns), (sh, None),
                  (rd, ns), (rd, nn), (sh, None), (rd, ns), (sh, None),
                  (sh, None), (rd, ns), (rd, ns)]
    for action in silver_nuc:
        sr_parser.operate(action)
    tree = sr_parser.get_parse_tree()
    silver_tree = RstTree()
    silver_tree.assign_tree(tree)
    silver_tree.assign_doc(gold_tree.doc)
    silver_tree.back_prop(tree, gold_tree.doc)
    eval_trees(silver_tree, gold_tree, 10/12, 7/12, 4/6, 3/6)

def test_tree():
    fdis = "../data/data_dir/train_dir/wsj_0603.out.dis"
    fmerge = "../data/data_dir/train_dir/wsj_0603.out.merge"
    tree = RstTree(fdis, fmerge)
    tree.build()
    bftbin = RstTree.BFTbin(tree.tree)
    assert len(bftbin) == 13
    edu_spans = [(1, 7), (1, 5), (6, 7), (1, 4), (5, 5), (6, 6), (7, 7),
                 (1, 2), (3, 4), (1, 1), (2, 2), (3, 3), (4, 4)]
    r, n, s = 'Root', 'Nucleus', 'Satellite'
    props = [r, n, n, n, s, n, s, n, n, n, s, n, s]
    rels = [None, 'TextualOrganization', 'TextualOrganization', 'span',
            'elaboration-additional', 'span', 'attribution', 'Same-Unit',
            'Same-Unit', 'span', 'elaboration-object-attribute-e', 'span',
            'elaboration-object-attribute-e']
    for i, span in enumerate(bftbin):
        assert span.edu_span == edu_spans[i]
        assert span.prop == props[i]
        assert span.relation == rels[i]

    # TODO: Check where relation conversion happens? In RST-Coref paper, in generate_action_samples
    sample_fdis, action_seq = tree.generate_action_samples()
    assert sample_fdis == fdis
    sh, rd = 'Shift', 'Reduce'
    nn, ns, sn = 'NN', 'NS', 'SN'
    gold_action_seq = [(sh, None), (sh, None), (rd, ns), (sh, None), (sh, None),
                       (rd, ns), (rd, nn), (sh, None), (rd, ns), (sh, None),
                       (sh, None), (rd, ns), (rd, nn)]
    assert len(action_seq) == len(gold_action_seq)
    for i in range(len(action_seq)):
        assert action_seq[i] == gold_action_seq[i]

    # Check that tree binarization is done correctly
    fdis = "../data/data_dir/train_dir/file2.dis"
    fmerge = "../data/data_dir/train_dir/file2.merge"
    tree = RstTree(fdis, fmerge)
    tree.build()
    spans = [span.edu_span for span in RstTree.BFTbin(tree.tree)]
    nonbinary_spans = [(22, 28), (22, 22), (23, 28), (23, 23), (24, 28), (24, 24),
                       (25, 28), (25, 25), (26, 28), (26, 26), (27, 28), (27, 27),
                       (28, 28), (11, 15), (11, 11), (12, 15), (12, 12), (13, 15),
                       (13, 14), (13, 13), (14, 14), (15, 15), (45, 57), (45, 47),
                       (48, 57), (48, 53), (54, 57), (54, 56), (57, 57)]
    for span in nonbinary_spans:
        assert span in spans

def eval_parser(self, dev_data=None, path='./examples', save_preds=True, use_parseval=False):
    """ Test the parsing performance """
    # Evaluation
    met = Metrics(levels=['span', 'nuclearity', 'relation'], use_parseval=use_parseval)
    # ----------------------------------------
    # Read all files from the given path
    if dev_data is None:
        dev_data = [os.path.join(path, fname) for fname in os.listdir(path)
                    if fname.endswith('.dis')]
    total_cost = 0
    for eval_instance in dev_data:
        # ----------------------------------------
        # Read the *.dis / *.merge pair and build the gold tree
        fmerge = eval_instance.replace('.dis', '.merge')
        doc = Doc()
        doc.read_from_fmerge(fmerge)
        gold_rst = RstTree(eval_instance, fmerge)
        gold_rst.build()

        # tok_edus = [nltk.word_tokenize(edu) for edu in doc.doc_edus]
        tok_edus = [edu.split(" ") for edu in doc.doc_edus]
        tokens = flatten(tok_edus)
        # Wrap the document in the coreference model's Document format
        coref_document = Document(raw_text=None, tokens=tokens, sents=tok_edus,
                                  corefs=[], speakers=["0"] * len(tokens),
                                  genre="nw", filename=None)
        coref_document.token_dict = doc.token_dict
        coref_document.edu_dict = doc.edu_dict
        doc = coref_document

        gold_action_seq, gold_rel_seq = gold_rst.decode_rst_tree()
        gold_action_seq = [action_map[x] for x in gold_action_seq]
        gold_relation_seq = [relation_map[x.lower()] for x in gold_rel_seq if x is not None]
        pred_rst, cost = self.parser.sr_parse(doc,
                                              torch.cuda.LongTensor(gold_action_seq),
                                              torch.cuda.LongTensor(gold_relation_seq))
        total_cost += cost

        if save_preds:
            if not os.path.isdir('../data/predicted_trees'):
                os.mkdir('../data/predicted_trees')
            filename = eval_instance.split(os.sep)[-1]
            filepath = f'../data/predicted_trees/{self.config[MODEL_NAME]}_{filename}'
            pred_brackets = pred_rst.bracketing()
            # Write brackets into file
            Evaluator.writebrackets(filepath, pred_brackets)
        # ----------------------------------------
        # Evaluate with gold RST tree
        met.eval(gold_rst, pred_rst)
    print("Total cost: ", total_cost)
    if use_parseval:
        print("Reporting original Parseval metric.")
    else:
        print("Reporting RST Parseval metric.")
    met.report()

def sr_parse(self, doc, bcvocab=None):
    """ Shift-reduce RST parsing based on model predictions
    :type doc: Doc
    :param doc: the document instance
    :type bcvocab: dict
    :param bcvocab: brown clusters
    """
    # Use transition-based parsing to build the tree structure
    conf = ParsingState([], [])
    conf.init(doc)
    action_hist = []
    while not conf.end_parsing():
        stack, queue = conf.get_status()
        fg = ActionFeatureGenerator(stack, queue, action_hist, doc, bcvocab)
        action_feats = fg.gen_features()
        action_probs = self.action_clf.predict_probs(action_feats)
        # Take the most probable action that is allowed in the current state
        for action, cur_prob in action_probs:
            if conf.is_action_allowed(action):
                conf.operate(action)
                action_hist.append(action)
                break
    tree = conf.get_parse_tree()
    # RstTree.down_prop(tree)
    # Assign the parsed tree to an RstTree instance
    rst_tree = RstTree()
    rst_tree.assign_tree(tree)
    rst_tree.assign_doc(doc)
    rst_tree.down_prop(tree)
    rst_tree.back_prop(tree, doc)
    # Tag relations for the tree
    post_nodelist = RstTree.postorder_DFT(rst_tree.tree, [])
    for node in post_nodelist:
        if (node.lnode is not None) and (node.rnode is not None):
            fg = RelationFeatureGenerator(node, rst_tree, node.level, bcvocab)
            relation_feats = fg.gen_features()
            relation = self.relation_clf.predict(relation_feats, node.level)
            node.assign_relation(relation)
    return rst_tree

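# Hedged usage sketch (not part of the original code): assuming `parser` is an instance of
# the class that defines sr_parse above and bcvocab is a loaded brown-cluster dict (or None,
# as in eval_parser below), a single *.merge file can be parsed and bracketed like this.
doc = Doc()
doc.read_from_fmerge("../data/data_dir/train_dir/wsj_0603.out.merge")
pred_rst = parser.sr_parse(doc, bcvocab=None)
print(pred_rst.bracketing())
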
def eval_parser(self, path='./examples', report=False, bcvocab=None, draw=True):
    """ Test the parsing performance """
    # Evaluation
    met = Metrics(levels=['span', 'nuclearity', 'relation'])
    # ----------------------------------------
    # Read all files from the given path
    doclist = [os.path.join(path, fname) for fname in os.listdir(path)
               if fname.endswith('.merge')]
    pred_forms = []
    gold_forms = []
    depth_per_relation = {}
    for fmerge in doclist:
        # ----------------------------------------
        # Read *.merge file
        doc = Doc()
        doc.read_from_fmerge(fmerge)
        # ----------------------------------------
        # Parsing
        pred_rst = self.parser.sr_parse(doc, bcvocab)
        if draw:
            pred_rst.draw_rst(fmerge.replace(".merge", ".ps"))
        # Get brackets from parsing results
        pred_brackets = pred_rst.bracketing()
        fbrackets = fmerge.replace('.merge', '.brackets')
        # Write brackets into file
        Evaluator.writebrackets(fbrackets, pred_brackets)
        # ----------------------------------------
        # Evaluate with gold RST tree
        if report:
            fdis = fmerge.replace('.merge', '.dis')
            gold_rst = RstTree(fdis, fmerge)
            gold_rst.build()
            met.eval(gold_rst, pred_rst)
            for node in pred_rst.postorder_DFT(pred_rst.tree, []):
                pred_forms.append(node.form)
            for node in gold_rst.postorder_DFT(gold_rst.tree, []):
                gold_forms.append(node.form)

            nodes = gold_rst.postorder_DFT(gold_rst.tree, [])
            inner_nodes = [node for node in nodes
                           if node.lnode is not None and node.rnode is not None]
            for idx, node in enumerate(inner_nodes):
                relation = node.rnode.relation if node.form == 'NS' else node.lnode.relation
                rela_class = RstTree.extract_relation(relation)
                if rela_class in depth_per_relation:
                    depth_per_relation[rela_class].append(node.depth)
                else:
                    depth_per_relation[rela_class] = [node.depth]
                lnode_text = ' '.join([gold_rst.doc.token_dict[tid].word
                                       for tid in node.lnode.text])
                lnode_lemmas = ' '.join([gold_rst.doc.token_dict[tid].lemma
                                         for tid in node.lnode.text])
                rnode_text = ' '.join([gold_rst.doc.token_dict[tid].word
                                       for tid in node.rnode.text])
                rnode_lemmas = ' '.join([gold_rst.doc.token_dict[tid].lemma
                                         for tid in node.rnode.text])
                # if rela_class == 'Topic-Change':
                #     print(fmerge)
                #     print(relation)
                #     print(lnode_text)
                #     print(rnode_text)
                #     print()
    if report:
        met.report()

def sr_parse(self, docs, gold_actions, optim, is_train=True):
    """ Batched shift-reduce parsing; when is_train is True, an optimizer step is
    taken after every parsing action. Returns the parsed trees and the accumulated loss. """
    confs = []
    action_hists = []
    for doc in docs:
        conf = ParsingState([], [], self.config)
        conf.init(doc)
        confs.append(conf)
        action_hists.append([])
    action_hist_nonnum = []
    step = 0
    cost_acc = 0
    while not confs[0].end_parsing():
        if is_train:
            optim.zero_grad()
        action_feats, neural_feats, curr_gold_actions = [], [], []
        for conf, doc in zip(confs, docs):
            stack, queue = conf.get_status()
            fg = ActionFeatureGenerator(stack, queue, action_hist_nonnum, doc,
                                        self.clf.data_helper, self.config)
            action_feat, neural_feat = fg.gen_features()
            action_feats.append(action_feat)
            neural_feats.append(neural_feat)
        spans = self.clf.get_edus_bert_span(docs, neural_feats)
        action_probs = self.clf.decode_action_online(spans, action_feats)
        sorted_action_idx = torch.argsort(action_probs, descending=True)
        for i, conf in enumerate(confs):
            action_hists[i].append(action_probs[i])
            action_idx = 0
            curr_gold_actions.append(gold_actions[i][step].unsqueeze(0))
            if is_train:
                # Teacher forcing: follow the gold action during training
                action = self.clf.xidx_action_map[int(gold_actions[i][step])]
            else:
                action = self.clf.xidx_action_map[int(sorted_action_idx[i, action_idx])]
            # Fall back to the highest-probability action that is allowed
            while not conf.is_action_allowed(action, docs[i]):
                action = self.clf.xidx_action_map[int(sorted_action_idx[i, action_idx])]
                action_idx += 1
            conf.operate(action)
        cost = self.loss(action_probs, torch.cat(curr_gold_actions))
        cost_acc += cost.item()
        if is_train:
            cost.backward()
            nn.utils.clip_grad_norm_(self.clf.parameters(), 0.2)
            optim.step()
        step += 1
    rst_trees = []
    for conf, doc in zip(confs, docs):
        tree = conf.get_parse_tree()
        rst_tree = RstTree()
        rst_tree.assign_tree(tree)
        rst_tree.assign_doc(doc)
        rst_tree.back_prop(tree, doc)
        rst_trees.append(rst_tree)
    return rst_trees, cost_acc

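# Hedged training-step sketch (not part of the original code): assumes `parser` exposes the
# batched sr_parse above, `docs` is a list of Document instances, and `gold_actions` holds
# the matching gold action-index tensors; the optimizer choice and learning rate are
# illustrative only.
optim = torch.optim.Adam(parser.clf.parameters(), lr=1e-3)
rst_trees, cost_acc = parser.sr_parse(docs, gold_actions, optim, is_train=True)
print("accumulated shift-reduce loss:", cost_acc)
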