def test_doc():
    doc = Doc()
    fmerge = "../data/data_dir/train_dir/wsj_0603.out.merge"
    doc.read_from_fmerge(fmerge)

    # Correct number of tokens/EDUs
    assert len(doc.edu_dict) == 7
    edu_lens = [10, 3, 9, 7, 9, 13, 7]
    for i in range(1, 8):
        # Reconstruct the EDU's words (handy when debugging a length mismatch)
        new_edu = []
        for tok_idx in doc.edu_dict[i]:
            new_word = doc.token_dict[tok_idx].word
            new_edu.append(new_word)
        assert len(doc.edu_dict[i]) == edu_lens[i - 1]
    num_tokens = sum(edu_lens)

    # Sentence (sidx), paragraph (pidx) and EDU (eduidx) indices
    assert len(doc.token_dict) == num_tokens
    assert doc.token_dict[0].sidx == 0 and doc.token_dict[num_tokens - 1].sidx == 2
    assert doc.token_dict[37].sidx == 1 and doc.token_dict[38].sidx == 2
    assert doc.token_dict[0].pidx == 1 and doc.token_dict[num_tokens - 1].pidx == 2
    assert doc.token_dict[0].eduidx == 1 and doc.token_dict[num_tokens - 1].eduidx == 7

    # Bigger doc
    doc = Doc()
    fmerge = "../data/data_dir/train_dir/wsj_0606.out.merge"
    doc.read_from_fmerge(fmerge)
    assert len(doc.edu_dict) == 57
    assert doc.token_dict[127].pidx == 2 and doc.token_dict[128].pidx == 3
    assert doc.token_dict[0].eduidx == 1 and doc.token_dict[581].eduidx == 57
    assert doc.token_dict[362].eduidx == 38 and doc.token_dict[363].eduidx == 39
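
# A minimal sketch of exercising the same reader interactively; this helper is
# an assumption (not part of the repo), built only from the Doc attributes
# used in the test above:
def example_inspect_doc(fmerge="../data/data_dir/train_dir/wsj_0603.out.merge"):
    # Hypothetical helper: print each EDU's text to sanity-check
    # read_from_fmerge, which populates token_dict (token index -> token)
    # and edu_dict (1-based EDU index -> list of token indices).
    doc = Doc()
    doc.read_from_fmerge(fmerge)
    for edu_idx, tok_idxs in doc.edu_dict.items():
        print(edu_idx, " ".join(doc.token_dict[t].word for t in tok_idxs))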

def pred_parser(self, output_dir='./examples', parse_type=None, bcvocab=None, draw=True):
    """ Parse preprocessed documents and dump per-document relation counts """
    # Evaluation
    # met = Metrics(levels=['span', 'nuclearity', 'relation'])
    # ----------------------------------------
    # Read all files from the given path
    with open(os.path.join(output_dir, parse_type, "processed_data.p"), 'rb') as file:
        doclist = pickle.load(file)
    relations = list(other.class2rel.keys())
    results = []
    for lines in doclist:
        # ----------------------------------------
        # Read *.merge file
        doc = Doc()
        relation_d = {rel: 0.0 for rel in relations}
        if len(lines) >= 2:
            doc.read_from_fmerge(lines)
            # ----------------------------------------
            # Parsing
            pred_rst = self.parser.sr_parse(doc, self.isFlat, bcvocab)
            # if draw:
            #     pred_rst.draw_rst(fmerge.replace(".merge", ".ps"))
            # Get brackets from parsing results
            pred_brackets = pred_rst.bracketing(self.isFlat)
            for brack in pred_brackets:
                relation_d[brack[2]] += 1
            total = sum(relation_d.values())
            if total:
                relation_d = {k: "{}/{}".format(v, total) for k, v in relation_d.items()}
            print(relation_d)
            results.append(relation_d)
    with open(os.path.join(output_dir, parse_type, "result.p"), 'wb') as file:
        pickle.dump(results, file)
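
# Sketch of consuming the pickled output of pred_parser; this helper and the
# parse_type value are assumptions mirroring the defaults above:
def example_load_results(output_dir="./examples", parse_type="TEST"):
    # Hypothetical helper: each entry is one document's relation distribution,
    # either raw counts or "count/total" strings, as built in pred_parser.
    with open(os.path.join(output_dir, parse_type, "result.p"), "rb") as file:
        for relation_d in pickle.load(file):
            print(relation_d)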

def build(self):
    """ Build binary (or flat) RST tree """
    with open(self.fdis) as fin:
        text = fin.read()
    # Build RST tree from the annotation
    self.tree = RstTree.build_tree(text)
    # Binarize or flatten it
    if self.isFlat:
        self.tree = RstTree.flat_tree(self.tree)
    else:
        self.tree = RstTree.binarize_tree(self.tree)
    # Read doc file
    if self.fmerge:
        doc = Doc()
        doc.read_from_fmerge(self.fmerge)
        self.doc = doc
    else:
        raise IOError("File doesn't exist: {}".format(self.fmerge))
    if self.isFlat:
        RstTree.down_flat_prop(self.tree)
        RstTree.flat_back_prop(self.tree, self.doc)
    else:
        RstTree.down_prop(self.tree)
        RstTree.back_prop(self.tree, self.doc)
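
# Usage sketch for build(). The two-argument constructor mirrors the
# RstTree(fdis, fmerge) calls elsewhere in this file; the flat variant
# presumably also takes an isFlat flag, so this helper is an assumption:
def example_build_tree(fdis="../data/data_dir/train_dir/wsj_0603.out.dis"):
    # Hypothetical call: builds the annotation tree, binarizes (or flattens)
    # it, reads the companion *.merge file, and propagates node attributes.
    fmerge = fdis.replace(".dis", ".merge")
    gold_rst = RstTree(fdis, fmerge)
    gold_rst.build()
    return gold_rst.doc  # the Doc populated during build()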

def eval_parser(self, dev_data=None, path='./examples', use_parseval=False):
    """ Test the parsing performance """
    # Evaluation
    met = Metrics(use_parseval, levels=['span', 'nuclearity'])
    # ----------------------------------------
    # Read all files from the given path
    if dev_data is None:
        eval_list = [
            os.path.join(path, fname) for fname in os.listdir(path)
            if fname.endswith('.merge')
        ]
    else:
        eval_list = dev_data
    total_cost = 0
    for eval_instance in eval_list:
        # ----------------------------------------
        # Read *.merge file
        if dev_data is not None:
            gold_rst = eval_instance
            doc = eval_instance.doc
        else:
            doc = Doc()
            doc.read_from_fmerge(eval_instance)
            fdis = eval_instance.replace('.merge', '.dis')
            gold_rst = RstTree(fdis, eval_instance)
            gold_rst.build()
        _, gold_action_seq = gold_rst.generate_action_samples()
        gold_action_seq = list(
            map(lambda x: self.data_helper.action_map[x], gold_action_seq))
        pred_rst, cost = self.parser.sr_parse(
            [doc], [torch.cuda.LongTensor(gold_action_seq)], None, is_train=False)
        total_cost += cost
        pred_rst = pred_rst[0]
        # ----------------------------------------
        # Evaluate with gold RST tree
        met.eval(gold_rst, pred_rst)
    print("Total cost: ", total_cost)
    met.report()

def build(self):
    """ Build BINARY RST tree """
    with open(self.fdis) as fin:
        text = fin.read()
    print(f"Processing {self.fdis}")
    # Build RST tree from the annotation
    self.tree = RstTree.build_tree(text)
    # Binarize it
    self.tree = RstTree.binarize_tree(self.tree)
    # Read doc file
    if isfile(self.fmerge):
        doc = Doc()
        doc.read_from_fmerge(self.fmerge)
        self.doc = doc
    else:
        raise IOError("File doesn't exist: {}".format(self.fmerge))
    RstTree.back_prop(self.tree, self.doc)

def __getitem__(self, idx):
    """
    Returns:
        In training mode: (doc, y), where doc is the Doc instance read from
        the *.merge file and y is the sequence of gold actions for that tree
        as a torch.cuda.LongTensor.
        Otherwise: the gold RstTree built from the *.dis/*.merge pair.
    """
    doc = Doc()
    eval_instance = self.X[idx].replace('.dis', '.merge')
    doc.read_from_fmerge(eval_instance)
    if self.is_train:
        return doc, torch.cuda.LongTensor(self.y[idx])
    fdis = eval_instance.replace('.merge', '.dis')
    gold_rst = RstTree(fdis, eval_instance)
    gold_rst.build()
    return gold_rst
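
# Sketch of wrapping this Dataset in a DataLoader; the identity collate_fn is
# an assumption, needed because each training item is a (Doc, LongTensor) pair
# that default collation cannot batch:
def example_dataloader(dataset):
    # Hypothetical usage; assumes the surrounding class also defines __len__
    # and was constructed with is_train=True.
    from torch.utils.data import DataLoader
    loader = DataLoader(dataset, batch_size=4, shuffle=True,
                        collate_fn=lambda batch: batch)  # keep items as-is
    for batch in loader:
        for doc, gold_actions in batch:
            print(len(doc.edu_dict), gold_actions.shape)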

def create_data_helper(self, data_dir, config, coref_trainer):
    print("Parsing trees")
    # Read train data
    all_feats_list, self.feats_list = [], []
    all_actions_numeric, self.actions_numeric = [], []
    all_relations_numeric, self.relations_numeric = [], []
    self.docs = []
    self.val_trees = []
    print("Generating features")
    for i, rst_tree in enumerate(self.read_rst_trees(data_dir=data_dir)):
        feats, actions, relations = rst_tree.generate_action_relation_samples(config)
        fdis = feats[0][0]
        # Old doc instance for storing sentence/paragraph/document features
        doc = Doc()
        eval_instance = fdis.replace('.dis', '.merge')
        doc.read_from_fmerge(eval_instance)
        # tok_edus = [nltk.word_tokenize(edu) for edu in doc.doc_edus]
        tok_edus = [edu.split(" ") for edu in doc.doc_edus]
        tokens = flatten(tok_edus)
        # Coreference-resolver document instance for coreference functionality
        # (converting tokens to wordpieces, getting corresponding coref
        # boundaries, etc.)
        coref_document = Document(raw_text=None,
                                  tokens=tokens,
                                  sents=tok_edus,
                                  corefs=[],
                                  speakers=["0"] * len(tokens),
                                  genre="nw",
                                  filename=fdis)
        # Duplicate for convenience
        coref_document.token_dict = doc.token_dict
        coref_document.edu_dict = doc.edu_dict
        coref_document.old_doc = doc
        for (feat, action, relation) in zip(feats, actions, relations):
            feat[0] = i
            all_feats_list.append(feat)
            all_actions_numeric.append(action)
            all_relations_numeric.append(relation)
        self.docs.append(coref_document)
        if i % 50 == 0:
            print("Processed", i + 1, "trees")
    all_actions_numeric = [action_map[x] for x in all_actions_numeric]
    all_relations_numeric = [
        relation_map[x if x is None else x.lower()] for x in all_relations_numeric
    ]
    # Stratify by number of EDUs in the document
    stratified = get_stratify_classes(
        [len(coref_document.edu_dict) for coref_document in self.docs])
    train_indexes, val_indexes = train_test_split(np.arange(len(self.docs)),
                                                  test_size=0.1,
                                                  random_state=1,
                                                  stratify=stratified)
    # Keep only the stack-queue actions that belong to trees in the train set
    for i, feat in enumerate(all_feats_list):
        if feat[0] in train_indexes:
            self.feats_list.append(feat)
            self.actions_numeric.append(all_actions_numeric[i])
            self.relations_numeric.append(all_relations_numeric[i])
    self.val_trees = [self.docs[index].filename for index in val_indexes]
    self.all_clusters = []
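
# Usage sketch; the DataHelper class name is an assumption based on this
# method's name and the self.* attributes it populates:
#
#     helper = DataHelper()                   # hypothetical class name
#     helper.create_data_helper("../data/data_dir/train_dir", config, coref_trainer)
#     # Afterwards helper.feats_list / actions_numeric / relations_numeric
#     # hold the train-split samples, and helper.val_trees names the
#     # held-out validation trees.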

def eval_parser(self, dev_data=None, path='./examples', save_preds=True, use_parseval=False):
    """ Test the parsing performance """
    # Evaluation
    met = Metrics(levels=['span', 'nuclearity', 'relation'], use_parseval=use_parseval)
    # ----------------------------------------
    # Read all files from the given path
    if dev_data is None:
        dev_data = [
            os.path.join(path, fname) for fname in os.listdir(path)
            if fname.endswith('.dis')
        ]
    total_cost = 0
    for eval_instance in dev_data:
        # ----------------------------------------
        # Read *.merge file
        fmerge = eval_instance.replace('.dis', '.merge')
        doc = Doc()
        doc.read_from_fmerge(fmerge)
        gold_rst = RstTree(eval_instance, fmerge)
        gold_rst.build()
        # tok_edus = [nltk.word_tokenize(edu) for edu in doc.doc_edus]
        tok_edus = [edu.split(" ") for edu in doc.doc_edus]
        tokens = flatten(tok_edus)
        coref_document = Document(raw_text=None,
                                  tokens=tokens,
                                  sents=tok_edus,
                                  corefs=[],
                                  speakers=["0"] * len(tokens),
                                  genre="nw",
                                  filename=None)
        coref_document.token_dict = doc.token_dict
        coref_document.edu_dict = doc.edu_dict
        doc = coref_document
        gold_action_seq, gold_rel_seq = gold_rst.decode_rst_tree()
        gold_action_seq = [action_map[x] for x in gold_action_seq]
        gold_relation_seq = [
            relation_map[x.lower()] for x in gold_rel_seq if x is not None
        ]
        pred_rst, cost = self.parser.sr_parse(
            doc, torch.cuda.LongTensor(gold_action_seq),
            torch.cuda.LongTensor(gold_relation_seq))
        total_cost += cost
        if save_preds:
            if not os.path.isdir('../data/predicted_trees'):
                os.mkdir('../data/predicted_trees')
            filename = eval_instance.split(os.sep)[-1]
            filepath = f'../data/predicted_trees/{self.config[MODEL_NAME]}_{filename}'
            # Write predicted brackets into file
            pred_brackets = pred_rst.bracketing()
            Evaluator.writebrackets(filepath, pred_brackets)
        # ----------------------------------------
        # Evaluate with gold RST tree
        met.eval(gold_rst, pred_rst)
    print("Total cost: ", total_cost)
    if use_parseval:
        print("Reporting original Parseval metric.")
    else:
        print("Reporting RST Parseval metric.")
    met.report()

def eval_parser(self, path='./examples', report=False, bcvocab=None, draw=True):
    """ Test the parsing performance """
    # Evaluation
    met = Metrics(levels=['span', 'nuclearity', 'relation'])
    # ----------------------------------------
    # Read all files from the given path
    doclist = [
        os.path.join(path, fname) for fname in os.listdir(path)
        if fname.endswith('.merge')
    ]
    pred_forms = []
    gold_forms = []
    depth_per_relation = {}
    for fmerge in doclist:
        # ----------------------------------------
        # Read *.merge file
        doc = Doc()
        doc.read_from_fmerge(fmerge)
        # ----------------------------------------
        # Parsing
        pred_rst = self.parser.sr_parse(doc, bcvocab)
        if draw:
            pred_rst.draw_rst(fmerge.replace(".merge", ".ps"))
        # Get brackets from parsing results
        pred_brackets = pred_rst.bracketing()
        fbrackets = fmerge.replace('.merge', '.brackets')
        # Write brackets into file
        Evaluator.writebrackets(fbrackets, pred_brackets)
        # ----------------------------------------
        # Evaluate with gold RST tree
        if report:
            fdis = fmerge.replace('.merge', '.dis')
            gold_rst = RstTree(fdis, fmerge)
            gold_rst.build()
            met.eval(gold_rst, pred_rst)
            for node in pred_rst.postorder_DFT(pred_rst.tree, []):
                pred_forms.append(node.form)
            for node in gold_rst.postorder_DFT(gold_rst.tree, []):
                gold_forms.append(node.form)
            nodes = gold_rst.postorder_DFT(gold_rst.tree, [])
            inner_nodes = [
                node for node in nodes
                if node.lnode is not None and node.rnode is not None
            ]
            for idx, node in enumerate(inner_nodes):
                relation = node.rnode.relation if node.form == 'NS' else node.lnode.relation
                rela_class = RstTree.extract_relation(relation)
                if rela_class in depth_per_relation:
                    depth_per_relation[rela_class].append(node.depth)
                else:
                    depth_per_relation[rela_class] = [node.depth]
                lnode_text = ' '.join(
                    [gold_rst.doc.token_dict[tid].word for tid in node.lnode.text])
                lnode_lemmas = ' '.join(
                    [gold_rst.doc.token_dict[tid].lemma for tid in node.lnode.text])
                rnode_text = ' '.join(
                    [gold_rst.doc.token_dict[tid].word for tid in node.rnode.text])
                rnode_lemmas = ' '.join(
                    [gold_rst.doc.token_dict[tid].lemma for tid in node.rnode.text])
                # Kept for ad-hoc inspection of specific relation classes:
                # if rela_class == 'Topic-Change':
                #     print(fmerge)
                #     print(relation)
                #     print(lnode_text)
                #     print(rnode_text)
                #     print()
    if report:
        met.report()
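
# Usage sketch; the Evaluator construction is an assumption (writebrackets is
# already referenced as Evaluator.writebrackets above):
#
#     ev = Evaluator(parser)                  # hypothetical constructor
#     ev.eval_parser(path='./examples', report=True, draw=False)
#
# With report=True this writes *.brackets files next to each *.merge input
# and prints span/nuclearity/relation scores via Metrics.report().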

def eval_parser(self, data_dir, output_dir='./examples', report=False,
                bcvocab=None, draw=True, isFlat=False):
    """ Test the parsing performance """
    # Evaluation
    met = Metrics(levels=['span', 'nuclearity', 'relation'])
    # ----------------------------------------
    # Read all files from the given path
    with open(os.path.join(output_dir, "Treebank/TEST", "processed_data.p"), 'rb') as file:
        doclist = pickle.load(file)
    fnames = [fn for fn in os.listdir(data_dir) if fn.endswith(".out")]
    pred_forms = []
    gold_forms = []
    depth_per_relation = {}
    for lines, fname in zip(doclist, fnames):
        # ----------------------------------------
        # Read *.merge file
        doc = Doc()
        doc.read_from_fmerge(lines)
        fout = os.path.join(data_dir, fname)
        print(fout)
        # ----------------------------------------
        # Parsing
        print("************************ predict rst ************************")
        pred_rst = self.parser.sr_parse(doc, self.isFlat, bcvocab)
        if draw:
            pred_rst.draw_rst(fout + '.ps')
        # Get brackets from parsing results
        pred_brackets = pred_rst.bracketing(self.isFlat)
        fbrackets = fout + '.brackets'
        # Write brackets into file
        Evaluator.writebrackets(fbrackets, pred_brackets)
        # ----------------------------------------
        # Evaluate with gold RST tree
        if report:
            print("************************ gold rst ************************")
            fdis = fout + '.dis'
            gold_rst = RstTree(fdis, lines, isFlat)
            gold_rst.build()
            met.eval(gold_rst, pred_rst, self.isFlat)
            if isFlat:
                for node in pred_rst.postorder_flat_DFT(pred_rst.tree, []):
                    pred_forms.append(node.form)
                for node in gold_rst.postorder_flat_DFT(gold_rst.tree, []):
                    gold_forms.append(node.form)
                nodes = gold_rst.postorder_flat_DFT(gold_rst.tree, [])
            else:
                for node in pred_rst.postorder_DFT(pred_rst.tree, []):
                    pred_forms.append(node.form)
                for node in gold_rst.postorder_DFT(gold_rst.tree, []):
                    gold_forms.append(node.form)
                nodes = gold_rst.postorder_DFT(gold_rst.tree, [])
            inner_nodes = [
                node for node in nodes
                if node.lnode is not None and node.rnode is not None
            ]
            for idx, node in enumerate(inner_nodes):
                relation = node.rnode.relation if node.form == 'NS' else node.lnode.relation
                rela_class = RstTree.extract_relation(relation)
                if rela_class in depth_per_relation:
                    depth_per_relation[rela_class].append(node.depth)
                else:
                    depth_per_relation[rela_class] = [node.depth]
    if report:
        met.report()