def eval_parser(self, dev_data=None, path='./examples', use_parseval=False):
    """Evaluate the parser against gold RST trees and print a report.

    Args:
        dev_data: Optional iterable of gold trees. Each item is used directly
            as the gold tree and must provide a ``doc`` attribute and a
            ``generate_action_samples()`` method. When ``None``, every
            ``*.merge`` file found in ``path`` is loaded instead.
        path: Directory scanned for ``*.merge`` files when ``dev_data`` is
            ``None``. The gold tree for each file is read from the path
            obtained by replacing ``.merge`` with ``.dis``.
        use_parseval: Passed through to ``Metrics``.

    Returns:
        None. The accumulated cost and the metric report are printed.
    """
    met = Metrics(use_parseval, levels=['span', 'nuclearity'])

    # ----------------------------------------
    # Collect the evaluation instances: either the caller-supplied trees or
    # the paths of all *.merge files under ``path``.
    if dev_data is None:
        eval_list = [
            os.path.join(path, fname)
            for fname in os.listdir(path)
            if fname.endswith('.merge')
        ]
    else:
        eval_list = dev_data

    # Use CUDA when it is available (as the previous torch.cuda.LongTensor
    # call did) but fall back to the CPU instead of failing on hosts
    # without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    total_cost = 0
    for eval_instance in eval_list:
        if dev_data is not None:
            # Pre-built gold tree: it carries its own document.
            gold_rst = eval_instance
            doc = eval_instance.doc
        else:
            # ----------------------------------------
            # Read *.merge file and the matching *.dis gold annotation.
            doc = Doc()
            doc.read_from_fmerge(eval_instance)
            # NOTE(review): entries collected above already end in '.merge',
            # so this normalisation looks redundant; kept to preserve the
            # existing path handling.
            eval_instance = eval_instance.replace('.dis', '.merge')
            fdis = eval_instance.replace('.merge', '.dis')
            gold_rst = RstTree(fdis, eval_instance)
            gold_rst.build()

        # Map the gold actions to their ids; they are handed to the parser
        # alongside the document.
        _, gold_action_seq = gold_rst.generate_action_samples()
        action_map = self.data_helper.action_map
        gold_action_ids = [action_map[action] for action in gold_action_seq]

        pred_rst, cost = self.parser.sr_parse(
            [doc],
            [torch.tensor(gold_action_ids, dtype=torch.long, device=device)],
            None,
            is_train=False)
        total_cost += cost
        # sr_parse was given a single-document batch; take its only result.
        pred_rst = pred_rst[0]

        # ----------------------------------------
        # Evaluate with gold RST tree
        met.eval(gold_rst, pred_rst)

    print("Total cost: ", total_cost)
    met.report()
def test_tree():
    """Check RstTree building, BFTbin traversal and action generation.

    Two fixture documents are loaded from ``../data/data_dir/train_dir``:
    ``wsj_0603.out`` is checked node by node (EDU span, nuclearity property,
    relation) and against its expected shift-reduce action sequence;
    ``file2`` is checked for the presence of the spans expected after
    binarization.
    """
    fdis = "../data/data_dir/train_dir/wsj_0603.out.dis"
    fmerge = "../data/data_dir/train_dir/wsj_0603.out.merge"
    tree = RstTree(fdis, fmerge)
    tree.build()
    bftbin = RstTree.BFTbin(tree.tree)
    assert len(bftbin) == 13

    edu_spans = [(1, 7), (1, 5), (6, 7), (1, 4), (5, 5), (6, 6), (7, 7),
                 (1, 2), (3, 4), (1, 1), (2, 2), (3, 3), (4, 4)]
    r, n, s = 'Root', 'Nucleus', 'Satellite'
    props = [r, n, n, n, s, n, s, n, n, n, s, n, s]
    rels = [None, 'TextualOrganization', 'TextualOrganization', 'span',
            'elaboration-additional', 'span', 'attribution', 'Same-Unit',
            'Same-Unit', 'span', 'elaboration-object-attribute-e', 'span',
            'elaboration-object-attribute-e']
    # The length of bftbin was asserted above, so zip covers every node.
    expected_nodes = zip(edu_spans, props, rels)
    for i, (span, expected) in enumerate(zip(bftbin, expected_nodes)):
        edu_span, prop, relation = expected
        assert span.edu_span == edu_span, "edu_span mismatch at node %d" % i
        assert span.prop == prop, "prop mismatch at node %d" % i
        assert span.relation == relation, "relation mismatch at node %d" % i

    # TODO: Check where relation conversion happens? In RST-Coref paper, in generate_action_samples
    # Only the action sequence is checked; the first return value is unused.
    # (It used to be bound to ``fdis`` and compared with itself, which
    # verified nothing.)
    _, action_seq = tree.generate_action_samples()
    sh, rd = 'Shift', 'Reduce'
    nn, ns, sn = 'NN', 'NS', 'SN'
    gold_action_seq = [(sh, None), (sh, None), (rd, ns), (sh, None),
                       (sh, None), (rd, ns), (rd, nn), (sh, None), (rd, ns),
                       (sh, None), (sh, None), (rd, ns), (rd, nn)]
    assert len(action_seq) == len(gold_action_seq)
    for i, (action, gold_action) in enumerate(zip(action_seq, gold_action_seq)):
        assert action == gold_action, "action mismatch at step %d" % i

    # Check that tree binarization is done correctly
    fdis = "../data/data_dir/train_dir/file2.dis"
    fmerge = "../data/data_dir/train_dir/file2.merge"
    tree = RstTree(fdis, fmerge)
    tree.build()
    spans = [span.edu_span for span in RstTree.BFTbin(tree.tree)]
    nonbinary_spans = [(22, 28), (22, 22), (23, 28), (23, 23), (24, 28),
                       (24, 24), (25, 28), (25, 25), (26, 28), (26, 26),
                       (27, 28), (27, 27), (28, 28), (11, 15), (11, 11),
                       (12, 15), (12, 12), (13, 15), (13, 14), (13, 13),
                       (14, 14), (15, 15), (45, 57), (45, 47), (48, 57),
                       (48, 53), (54, 57), (54, 56), (57, 57)]
    for span in nonbinary_spans:
        assert span in spans, "missing span %r" % (span,)