示例#1
0
    def eval_parser(self,
                    dev_data=None,
                    path='./examples',
                    use_parseval=False):
        """Evaluate parsing performance against gold RST trees.

        Every document is parsed with ``self.parser.sr_parse`` in
        non-training mode, the predicted tree is scored against its gold
        tree on the 'span' and 'nuclearity' levels, and finally the summed
        parsing cost and the metric report are printed.

        Args:
            dev_data: Optional iterable of gold trees. Each item must expose
                a ``doc`` attribute and ``generate_action_samples()``. When
                ``None``, every ``*.merge`` file found in ``path`` is
                evaluated instead.
            path: Directory scanned for ``*.merge`` files; the gold tree of
                each one is built from the ``*.dis`` file sharing its stem.
                Only used when ``dev_data`` is ``None``.
            use_parseval: Passed through unchanged to ``Metrics``.
        """
        met = Metrics(use_parseval, levels=['span', 'nuclearity'])

        # Decide what to evaluate: caller-supplied trees or files on disk.
        if dev_data is None:
            eval_list = [
                os.path.join(path, fname) for fname in os.listdir(path)
                if fname.endswith('.merge')
            ]
        else:
            eval_list = dev_data

        total_cost = 0
        for eval_instance in eval_list:
            if dev_data is not None:
                # The instance already is a gold tree carrying its document.
                gold_rst = eval_instance
                doc = gold_rst.doc
            else:
                # The instance is the path of a *.merge file.
                fmerge = eval_instance
                doc = Doc()
                doc.read_from_fmerge(fmerge)
                # Swap only the trailing extension; a global str.replace
                # would also rewrite '.merge'/'.dis' inside directory names.
                fdis = os.path.splitext(fmerge)[0] + '.dis'
                gold_rst = RstTree(fdis, fmerge)
                gold_rst.build()

            _, gold_action_seq = gold_rst.generate_action_samples()
            gold_action_ids = [
                self.data_helper.action_map[action]
                for action in gold_action_seq
            ]
            # NOTE: torch.cuda.LongTensor allocates on the GPU, so this
            # evaluation path requires a CUDA device.
            pred_rst, cost = self.parser.sr_parse(
                [doc], [torch.cuda.LongTensor(gold_action_ids)],
                None,
                is_train=False)
            total_cost += cost

            # sr_parse was given a single document; score its one tree.
            met.eval(gold_rst, pred_rst[0])

        print("Total cost: ", total_cost)
        met.report()
示例#2
0
def test_tree():
    """Check RstTree building, binarization and action-sequence generation.

    Depends on the fixture files under ``../data/data_dir/train_dir``.
    """
    fdis = "../data/data_dir/train_dir/wsj_0603.out.dis"
    fmerge = "../data/data_dir/train_dir/wsj_0603.out.merge"
    tree = RstTree(fdis, fmerge)
    tree.build()

    # Breadth-first traversal of the binarized tree: 13 nodes expected.
    bftbin = RstTree.BFTbin(tree.tree)
    assert len(bftbin) == 13
    edu_spans = [(1, 7), (1, 5), (6, 7), (1, 4), (5, 5), (6, 6), (7, 7),
                 (1, 2), (3, 4), (1, 1), (2, 2), (3, 3), (4, 4)]
    r, n, s = 'Root', 'Nucleus', 'Satellite'
    props = [r, n, n, n, s, n, s, n, n, n, s, n, s]
    rels = [None, 'TextualOrganization', 'TextualOrganization',
            'span', 'elaboration-additional', 'span', 'attribution',
            'Same-Unit', 'Same-Unit', 'span',
            'elaboration-object-attribute-e', 'span',
            'elaboration-object-attribute-e']
    for span, edu_span, prop, rel in zip(bftbin, edu_spans, props, rels):
        assert span.edu_span == edu_span
        assert span.prop == prop
        assert span.relation == rel

    # TODO: Check where relation conversion happens? In RST-Coref paper, in generate_action_samples
    # Keep the returned value under its own name: rebinding `fdis` made the
    # original check (`fdis == fdis`) a tautology that could never fail.
    # NOTE(review): presumably the first return value is the .dis path the
    # tree was built from -- confirm against generate_action_samples.
    returned_fdis, action_seq = tree.generate_action_samples()
    assert returned_fdis == fdis
    sh, rd = 'Shift', 'Reduce'
    nn, ns = 'NN', 'NS'
    gold_action_seq = [(sh, None), (sh, None), (rd, ns), (sh, None),
                       (sh, None), (rd, ns), (rd, nn), (sh, None),
                       (rd, ns), (sh, None), (sh, None), (rd, ns),
                       (rd, nn)]
    assert len(action_seq) == len(gold_action_seq)
    for predicted, expected in zip(action_seq, gold_action_seq):
        assert predicted == expected

    # Check that tree binarization is done correctly
    fdis = "../data/data_dir/train_dir/file2.dis"
    fmerge = "../data/data_dir/train_dir/file2.merge"
    tree = RstTree(fdis, fmerge)
    tree.build()
    spans = [span.edu_span for span in RstTree.BFTbin(tree.tree)]
    # Spans that must be present once the non-binary nodes are binarized.
    nonbinary_spans = [(22, 28), (22, 22), (23, 28), (23, 23), (24, 28),
                       (24, 24), (25, 28), (25, 25), (26, 28), (26, 26),
                       (27, 28), (27, 27), (28, 28), (11, 15), (11, 11),
                       (12, 15), (12, 12), (13, 15), (13, 14), (13, 13),
                       (14, 14), (15, 15), (45, 57), (45, 47), (48, 57),
                       (48, 53), (54, 57), (54, 56), (57, 57)]
    for span in nonbinary_spans:
        assert span in spans