示例#1
0
def cal_class_distribution(data_dir, level):
    """
    Count how often each relation class occurs in the RST trees read from ``data_dir``.

    :param data_dir: directory handed to ``DataHelper.read_rst_trees``
    :param level: 0 for inner-sentence, 1 for inter-sentence but inner paragraph, 2 for inter-paragraph, 3 for different depth
    :return: for level 0/1/2 a ``Counter`` of relation classes over the nodes at that level;
             for level 3 a dict mapping node depth to a ``Counter`` of relation classes;
             ``None`` for any other level. Every entry obtained by iterating ``class2rel``
             is present in each returned ``Counter`` (with count 0 when it never occurred).
    """

    def pad_missing_classes(counts):
        # Make sure every known class appears, even if it was never observed.
        for known_class in class2rel:
            counts.setdefault(known_class, 0)
        return counts

    nodes = []
    for rst_tree in DataHelper.read_rst_trees(data_dir):
        nodes.extend(rst_tree.postorder_DFT(rst_tree.tree, []))

    if level in (0, 1, 2):
        counts = Counter(
            RstTree.extract_relation(node.child_relation)
            for node in nodes
            if node.level == level and node.child_relation is not None
        )
        return pad_missing_classes(counts)

    if level == 3:
        per_depth = {}
        for node in nodes:
            is_leaf = node.lnode is None and node.rnode is None
            if is_leaf:
                continue
            depth_counts = per_depth.setdefault(node.depth, Counter())
            depth_counts[RstTree.extract_relation(node.child_relation)] += 1
        for depth_counts in per_depth.values():
            pad_missing_classes(depth_counts)
        return per_depth
示例#2
0
 def read_rst_trees(data_dir):
     """Lazily yield a built ``RstTree`` for each ``.dis`` file found in ``data_dir``.

     A ``.dis`` file whose ``.merge`` counterpart is not a regular file is
     reported on stdout and skipped.
     """
     dis_paths = [os.path.join(data_dir, name) for name in os.listdir(data_dir) if name.endswith('.dis')]
     for dis_path in dis_paths:
         merge_path = dis_path.replace('.dis', '.merge')
         if os.path.isfile(merge_path):
             tree = RstTree(dis_path, merge_path)
             tree.build()
             yield tree
         else:
             print("Corresponding .fmerge file does not exist. Skipping the file.")
示例#3
0
 def sr_parse(self, doc, gold_actions, gold_rels):
     """Greedily shift-reduce parse ``doc`` and score the predictions against gold.

     :param doc: the document to parse
     :param gold_actions: target handed to ``self.loss`` together with the stacked action scores
     :param gold_rels: target handed to ``self.loss`` together with the stacked relation scores
     :return: ``(rst_tree, cost)`` - the predicted ``RstTree`` and ``cost.item()`` of the summed loss
     """
     # Coreference clusters are only predicted for model types 1 and 2
     clusters = None
     if self.clf.config[MODEL_TYPE] in [1, 2]:
         with torch.no_grad():
             clusters, _ = self.coref_trainer.predict_clusters(doc)

     # Stack/Queue state
     state = ParsingState([], [], self.clf.config)
     state.init(doc)
     action_scores, relation_scores = [], []

     # Until the tree is built
     while not state.end_parsing():
         # Features and span boundaries for the current stack/queue state
         stack, queue = state.get_status()
         feature_gen = ActionFeatureGenerator(stack, queue, [], doc, self.data_helper, self.config)
         action_feat, span_boundary = feature_gen.gen_features()
         span_embeds = self.clf.get_edus_bert_coref([doc], [clusters], [span_boundary])
         action_probs, rel_probs = self.clf.decode_action_coref(span_embeds, [action_feat])
         action_scores.append(action_probs.squeeze())
         ranked_actions = torch.argsort(action_probs, descending=True)
         ranked_rels = torch.argsort(rel_probs, descending=True)

         # Walk down the action ranking until the parser state accepts one
         rank = 0
         while True:
             pred_action, pred_nuc = xidx_action_map[int(ranked_actions[0, rank])]
             if state.is_action_allowed((pred_action, pred_nuc, None), doc):
                 break
             rank += 1

         # A relation is only predicted for non-Shift actions
         if pred_action == "Shift":
             pred_rel = None
         else:
             relation_scores.append(rel_probs.squeeze())
             pred_rel = xidx_relation_map[int(ranked_rels[0, 0])]
         assert not (pred_action == "Reduce" and pred_rel is None)

         state.operate((pred_action, pred_nuc, pred_rel))

     # Shift/Reduce loss, plus the relation loss when any relation was predicted
     cost = self.loss(torch.stack(action_scores), gold_actions)
     if relation_scores:
         cost += self.loss(torch.stack(relation_scores), gold_rels)

     parsed = state.get_parse_tree()
     rst_tree = RstTree()
     rst_tree.assign_tree(parsed)
     rst_tree.assign_doc(doc)
     rst_tree.back_prop(parsed, doc)

     return rst_tree, cost.item()
示例#4
0
 def __getitem__(self, idx):
     """
         Returns:
             If train: ``(self.X[idx], torch.LongTensor(self.y[idx]))``
             If test: the gold ``RstTree`` built from the ``.dis`` file of the sample's
                      document and the matching ``.merge`` file
     """
     if not self.is_train:
         # The first element of the sample indexes into the data helper's documents
         dis_path = self.data_helper.docs[self.X[idx][0]].filename
         gold_rst = RstTree(dis_path, dis_path.replace('.dis', '.merge'))
         gold_rst.build()
         return gold_rst

     return self.X[idx], torch.LongTensor(self.y[idx])
示例#5
0
    def eval_parser(self,
                    dev_data=None,
                    path='./examples',
                    use_parseval=False):
        """Parse every evaluation document and report span / nuclearity metrics.

        :param dev_data: iterable of already built gold trees, each exposing ``.doc``;
            when ``None`` every ``*.merge`` file under ``path`` is evaluated instead
        :param path: directory scanned for ``*.merge`` files when ``dev_data`` is ``None``
        :param use_parseval: forwarded to ``Metrics``
        :return: None; the accumulated cost is printed and ``Metrics.report`` is called
        """
        met = Metrics(use_parseval, levels=['span', 'nuclearity'])

        read_from_disk = dev_data is None
        if read_from_disk:
            eval_list = [os.path.join(path, fname)
                         for fname in os.listdir(path)
                         if fname.endswith('.merge')]
        else:
            eval_list = dev_data

        total_cost = 0
        for eval_instance in eval_list:
            doc = Doc()
            if read_from_disk:
                # Read *.merge file and build the gold tree from the sibling *.dis file
                doc.read_from_fmerge(eval_instance)
                fmerge = eval_instance.replace('.dis', '.merge')
                gold_rst = RstTree(fmerge.replace('.merge', '.dis'), fmerge)
                gold_rst.build()
            else:
                gold_rst = eval_instance
                doc = eval_instance.doc

            _, gold_action_seq = gold_rst.generate_action_samples()
            gold_action_ids = [self.data_helper.action_map[action]
                               for action in gold_action_seq]
            pred_rsts, cost = self.parser.sr_parse(
                [doc], [torch.cuda.LongTensor(gold_action_ids)],
                None,
                is_train=False)
            total_cost += cost

            # Evaluate the single predicted tree against the gold RST tree
            met.eval(gold_rst, pred_rsts[0])

        print("Total cost: ", total_cost)
        met.report()
示例#6
0
 def read_rst_trees(data_dir, output_dir, parse_type, isFlat):
     """Build an ``RstTree`` for every ``.dis`` file in ``data_dir``.

     Each ``.dis`` path is paired positionally (``zip``) with an entry of the
     object unpickled from ``<output_dir>/<parse_type>/processed_data.p``;
     surplus entries on either side are ignored.

     NOTE(review): the pairing depends on ``os.listdir`` order matching the
     pickled order - confirm against the code that writes the pickle.
     NOTE(review): ``pickle.load`` must only be pointed at trusted files.

     :return: list of built ``RstTree`` objects
     """
     dis_paths = [os.path.join(data_dir, name) for name in os.listdir(data_dir) if name.endswith('.dis')]
     pickle_path = os.path.join(output_dir, parse_type, "processed_data.p")
     with open(pickle_path, 'rb') as handle:
         fmerges = pickle.load(handle)

     def build_tree(fdis, fmerge):
         # Announce, construct and build a single tree
         print("building tree " + fdis)
         tree = RstTree(fdis, fmerge, isFlat)
         tree.build()
         return tree

     return [build_tree(fdis, fmerge) for fdis, fmerge in zip(dis_paths, fmerges)]
示例#7
0
 def read_rst_trees(data_dir):
     """Build and return an ``RstTree`` for every ``.dis`` file in ``data_dir``.

     :raises FileNotFoundError: when a ``.dis`` file has no ``.merge`` counterpart
     :return: list of built ``RstTree`` objects
     """
     dis_paths = [os.path.join(data_dir, name) for name in os.listdir(data_dir) if name.endswith('.dis')]

     rst_trees = []
     for dis_path in dis_paths:
         merge_path = dis_path.replace('.dis', '.merge')
         if os.path.isfile(merge_path):
             tree = RstTree(dis_path, merge_path)
             tree.build()
             rst_trees.append(tree)
             continue
         raise FileNotFoundError(
             'Corresponding .fmerge file does not exist. You should do preprocessing first.'
         )
     return rst_trees
示例#8
0
 def read_rst_trees(data_dir):
     """Lazily yield a built ``RstTree`` for each usable ``.dis`` file in ``data_dir``.

     Two kinds of documents are reported on stdout and skipped: those whose
     ``.merge`` counterpart is not a regular file, and those whose parse
     string (``get_parse()``) starts with ``" ( EDU"``.
     """
     dis_paths = [os.path.join(data_dir, name) for name in os.listdir(data_dir) if name.endswith('.dis')]
     for dis_path in dis_paths:
         merge_path = dis_path.replace('.dis', '.merge')
         if not os.path.isfile(merge_path):
             print(merge_path + " - Corresponding .fmerge file does not exist. Skipping the file.")
             continue

         tree = RstTree(dis_path, merge_path)
         tree.build()
         if not tree.get_parse().startswith(" ( EDU"):
             yield tree
         else:
             print("Skipping document - tree is corrupt.")
示例#9
0
 def __getitem__(self, idx):
     """
         Returns:
             If train: ``(doc, y)`` - the ``Doc`` read from the sample's ``.merge``
                       file and ``torch.cuda.LongTensor(self.y[idx])``
             If test:  the gold ``RstTree`` built from the sample's ``.dis`` and
                       ``.merge`` files
     """
     doc = Doc()
     fmerge = self.X[idx].replace('.dis', '.merge')
     # The document is read in both modes, although only training returns it
     doc.read_from_fmerge(fmerge)

     if not self.is_train:
         gold_rst = RstTree(fmerge.replace('.merge', '.dis'), fmerge)
         gold_rst.build()
         return gold_rst

     return doc, torch.cuda.LongTensor(self.y[idx])
示例#10
0
def test_eval():
    """Run ``eval_trees`` on the gold tree itself and on two hand-built silver trees."""
    fdis = "../data/data_dir/train_dir/wsj_0603.out.dis"
    fmerge = "../data/data_dir/train_dir/wsj_0603.out.merge"
    gold_tree = RstTree(fdis, fmerge)
    gold_tree.build()
    # Gold against itself, with all four expected values set to 1
    eval_trees(gold_tree, gold_tree, 1, 1, 1, 1)

    def build_silver_tree(actions):
        # Replay the given shift-reduce actions over the gold document
        state = ParsingState([], [])
        state.init(gold_tree.doc)
        for action in actions:
            state.operate(action)
        parsed = state.get_parse_tree()
        silver = RstTree()
        silver.assign_tree(parsed)
        silver.assign_doc(gold_tree.doc)
        silver.back_prop(parsed, gold_tree.doc)
        return silver

    sh, rd = 'Shift', 'Reduce'
    nn, ns, sn = 'NN', 'NS', 'SN'

    # Last two nuclearity actions are incorrect
    silver_tree = build_silver_tree([
        (sh, None), (sh, None), (rd, ns), (sh, None), (sh, None),
        (rd, ns), (rd, nn), (sh, None), (rd, ns), (sh, None),
        (sh, None), (rd, sn), (rd, ns),
    ])
    eval_trees(silver_tree, gold_tree, 1, 9/12, 1, 4/6)

    # This sequence starts with three Shifts before the first Reduce
    silver_tree = build_silver_tree([
        (sh, None), (sh, None), (sh, None), (rd, ns), (sh, None),
        (rd, ns), (rd, nn), (sh, None), (rd, ns), (sh, None), (sh, None),
        (rd, ns), (rd, ns),
    ])
    eval_trees(silver_tree, gold_tree, 10/12, 7/12, 4/6, 3/6)
示例#11
0
def test_tree():
    """Check tree construction, action-sequence generation and binarization of ``RstTree``."""
    fdis = "../data/data_dir/train_dir/wsj_0603.out.dis"
    fmerge = "../data/data_dir/train_dir/wsj_0603.out.merge"
    tree = RstTree(fdis, fmerge)
    tree.build()
    bftbin = RstTree.BFTbin(tree.tree)
    assert len(bftbin) == 13
    edu_spans = [(1, 7), (1, 5), (6, 7), (1, 4), (5, 5), (6, 6), (7, 7), (1, 2), (3, 4), (1, 1), (2, 2), (3, 3), (4, 4)]
    r, n, s = 'Root', 'Nucleus', 'Satellite'
    props = [r, n, n, n, s, n, s, n, n, n, s, n, s]
    rels = [None, 'TextualOrganization', 'TextualOrganization',
            'span', 'elaboration-additional', 'span', 'attribution', 'Same-Unit',
            'Same-Unit', 'span', 'elaboration-object-attribute-e', 'span',
            'elaboration-object-attribute-e',]
    # The length assertion above guarantees zip does not silently drop nodes
    for span, expected_span, expected_prop, expected_rel in zip(bftbin, edu_spans, props, rels):
        assert span.edu_span == expected_span
        assert span.prop == expected_prop
        assert span.relation == expected_rel

    # TODO: Check where relation conversion happens? In RST-Coref paper, in generate_action_samples
    # Unpack into a separate name: rebinding ``fdis`` here made the old
    # ``assert fdis == fdis`` check vacuous.
    # NOTE(review): assumes the first returned value is the .dis path given to RstTree - confirm.
    returned_fdis, action_seq = tree.generate_action_samples()
    assert returned_fdis == fdis
    sh, rd = 'Shift', 'Reduce'
    nn, ns, sn = 'NN', 'NS', 'SN'
    gold_action_seq = [(sh, None), (sh, None), (rd, ns), (sh, None), (sh, None),
                       (rd, ns), (rd, nn), (sh, None), (rd, ns), (sh, None),
                       (sh, None), (rd, ns), (rd, nn)]
    assert len(action_seq) == len(gold_action_seq)
    for predicted, expected in zip(action_seq, gold_action_seq):
        assert predicted == expected

    # Check that tree binarization is done correctly
    fdis = "../data/data_dir/train_dir/file2.dis"
    fmerge = "../data/data_dir/train_dir/file2.merge"
    tree = RstTree(fdis, fmerge)
    tree.build()
    spans = [span.edu_span for span in RstTree.BFTbin(tree.tree)]
    nonbinary_spans = [(22, 28), (22, 22), (23, 28), (23, 23), (24, 28),
                       (24, 24), (25, 28), (25, 25), (26, 28), (26, 26),
                       (27, 28), (27, 27), (28, 28), (11, 15), (11, 11),
                       (12, 15), (12, 12), (13, 15), (13, 14), (13, 13),
                       (14, 14), (15, 15), (45, 57), (45, 47), (48, 57),
                       (48, 53), (54, 57), (54, 56), (57, 57)]
    for span in nonbinary_spans:
        assert span in spans
示例#12
0
    def eval_parser(self,
                    dev_data=None,
                    path='./examples',
                    save_preds=True,
                    use_parseval=False):
        """ Test the parsing performance

        :param dev_data: iterable of ``.dis`` file paths; when ``None`` every
            ``*.dis`` file under ``path`` is evaluated
        :param path: directory scanned for ``*.dis`` files when ``dev_data`` is ``None``
        :param save_preds: when true, the predicted brackets of each document are
            written below ``../data/predicted_trees``
        :param use_parseval: forwarded to ``Metrics``; also selects which metric
            name is printed before the report
        :return: None; the total cost and the metric report are printed
        """
        # Evaluation
        met = Metrics(levels=['span', 'nuclearity', 'relation'],
                      use_parseval=use_parseval)
        # ----------------------------------------
        # Read all files from the given path
        if dev_data is None:
            dev_data = [
                os.path.join(path, fname) for fname in os.listdir(path)
                if fname.endswith('.dis')
            ]
        pred_dir = '../data/predicted_trees'
        total_cost = 0
        for eval_instance in dev_data:
            # ----------------------------------------
            fmerge = eval_instance.replace('.dis', '.merge')
            doc = Doc()
            doc.read_from_fmerge(fmerge)
            gold_rst = RstTree(eval_instance, fmerge)
            gold_rst.build()

            # EDUs are tokenized by splitting on single spaces
            tok_edus = [edu.split(" ") for edu in doc.doc_edus]
            tokens = flatten(tok_edus)

            coref_document = Document(raw_text=None,
                                      tokens=tokens,
                                      sents=tok_edus,
                                      corefs=[],
                                      speakers=["0"] * len(tokens),
                                      genre="nw",
                                      filename=None)

            # Carry the token / EDU lookups over to the document handed to the parser
            coref_document.token_dict = doc.token_dict
            coref_document.edu_dict = doc.edu_dict
            doc = coref_document

            gold_action_seq, gold_rel_seq = gold_rst.decode_rst_tree()

            gold_action_seq = [action_map[x] for x in gold_action_seq]
            # Entries of the relation sequence that are None are dropped
            gold_relation_seq = [
                relation_map[x.lower()] for x in gold_rel_seq if x is not None
            ]
            pred_rst, cost = self.parser.sr_parse(
                doc, torch.cuda.LongTensor(gold_action_seq),
                torch.cuda.LongTensor(gold_relation_seq))
            total_cost += cost

            if save_preds:
                # makedirs(exist_ok=True) also creates a missing parent directory
                # and does not fail if the directory appears between check and creation
                os.makedirs(pred_dir, exist_ok=True)

                filename = eval_instance.split(os.sep)[-1]
                filepath = f'{pred_dir}/{self.config[MODEL_NAME]}_{filename}'

                pred_brackets = pred_rst.bracketing()
                # Write brackets into file
                Evaluator.writebrackets(filepath, pred_brackets)
            # ----------------------------------------
            # Evaluate with gold RST tree
            met.eval(gold_rst, pred_rst)

        print("Total cost: ", total_cost)
        if use_parseval:
            print("Reporting original Parseval metric.")
        else:
            print("Reporting RST Parseval metric.")
        met.report()
示例#13
0
    def sr_parse(self, doc, bcvocab=None):
        """ Shift-reduce RST parsing based on models prediction

        :type doc: Doc
        :param doc: the document instance

        :type bcvocab: dict
        :param bcvocab: brown clusters

        :return: the predicted ``RstTree``; every node that has both children
            gets a relation assigned by ``self.relation_clf``
        """
        # Transition-based parsing builds the tree structure first
        state = ParsingState([], [])
        state.init(doc)
        history = []
        while not state.end_parsing():
            stack, queue = state.get_status()
            feats = ActionFeatureGenerator(stack, queue, history, doc, bcvocab).gen_features()
            # Apply the first candidate, in the classifier's order, that the state allows
            for candidate, _ in self.action_clf.predict_probs(feats):
                if not state.is_action_allowed(candidate):
                    continue
                state.operate(candidate)
                history.append(candidate)
                break

        parsed = state.get_parse_tree()
        rst_tree = RstTree()
        rst_tree.assign_tree(parsed)
        rst_tree.assign_doc(doc)
        rst_tree.down_prop(parsed)
        rst_tree.back_prop(parsed, doc)

        # Tag relations on the nodes that have both children, in post-order
        for node in RstTree.postorder_DFT(rst_tree.tree, []):
            if node.lnode is None or node.rnode is None:
                continue
            rel_feats = RelationFeatureGenerator(node, rst_tree, node.level, bcvocab).gen_features()
            node.assign_relation(self.relation_clf.predict(rel_feats, node.level))
        return rst_tree
示例#14
0
    def eval_parser(self,
                    path='./examples',
                    report=False,
                    bcvocab=None,
                    draw=True):
        """ Test the parsing performance

        Parses every ``*.merge`` file under ``path`` and writes the predicted
        brackets next to it as ``*.brackets``.

        :param path: directory scanned for ``*.merge`` files
        :param report: when true, each prediction is scored against the gold tree
            built from the sibling ``*.dis`` file and the metrics are reported
        :param bcvocab: forwarded to ``self.parser.sr_parse``
        :param draw: when true, each predicted tree is drawn to a sibling ``*.ps`` file
        :return: None
        """
        # Evaluation
        met = Metrics(levels=['span', 'nuclearity', 'relation'])
        # ----------------------------------------
        # Read all files from the given path
        doclist = [
            os.path.join(path, fname) for fname in os.listdir(path)
            if fname.endswith('.merge')
        ]
        for fmerge in doclist:
            # ----------------------------------------
            # Read *.merge file
            doc = Doc()
            doc.read_from_fmerge(fmerge)
            # ----------------------------------------
            # Parsing
            pred_rst = self.parser.sr_parse(doc, bcvocab)
            if draw:
                pred_rst.draw_rst(fmerge.replace(".merge", ".ps"))
            # Get brackets from parsing results
            pred_brackets = pred_rst.bracketing()
            fbrackets = fmerge.replace('.merge', '.brackets')
            # Write brackets into file
            Evaluator.writebrackets(fbrackets, pred_brackets)
            # ----------------------------------------
            # Evaluate with gold RST tree
            if report:
                fdis = fmerge.replace('.merge', '.dis')
                gold_rst = RstTree(fdis, fmerge)
                gold_rst.build()
                met.eval(gold_rst, pred_rst)
                # The per-node form lists, depth-per-relation table and joined
                # text/lemma strings that used to be computed here were never
                # read afterwards and have been removed.

        if report:
            met.report()
示例#15
0
    def sr_parse(self, docs, gold_actions, optim, is_train=True):
        """Shift-reduce parse a batch of documents, one action per document per step.

        In training mode the gold action is applied at every step and the
        optimizer is stepped after each step's loss. Otherwise the
        highest-ranked action accepted by the parsing state is applied.

        :param docs: documents to parse in lockstep
        :param gold_actions: per document, an indexable sequence of gold action
            tensors (each is ``unsqueeze``d and concatenated into the loss target)
        :param optim: optimizer; only used when ``is_train`` is true
        :param is_train: selects gold-action training versus greedy decoding
        :return: ``(rst_trees, cost_acc)`` - one ``RstTree`` per document and the
            sum of ``cost.item()`` over all steps
        """
        confs = []
        for doc in docs:
            conf = ParsingState([], [], self.config)
            conf.init(doc)
            confs.append(conf)
        # Shared, never-populated action history handed to the feature generator
        action_hist_nonnum = []
        step = 0
        cost_acc = 0
        # NOTE(review): only the first document's state decides when parsing
        # stops, so all documents presumably need the same number of steps - confirm.
        while not confs[0].end_parsing():
            if is_train:
                optim.zero_grad()
            action_feats, neural_feats, curr_gold_actions = [], [], []
            for conf, doc in zip(confs, docs):
                stack, queue = conf.get_status()
                fg = ActionFeatureGenerator(stack, queue, action_hist_nonnum,
                                            doc, self.clf.data_helper,
                                            self.config)
                action_feat, neural_feat = fg.gen_features()
                action_feats.append(action_feat)
                neural_feats.append(neural_feat)

            spans = self.clf.get_edus_bert_span(docs, neural_feats)
            action_probs = self.clf.decode_action_online(spans, action_feats)
            sorted_action_idx = torch.argsort(action_probs, descending=True)

            for i, conf in enumerate(confs):
                gold_action = gold_actions[i][step]
                curr_gold_actions.append(gold_action.unsqueeze(0))
                if is_train:
                    action = self.clf.xidx_action_map[int(gold_action)]
                else:
                    # Advance the rank before each new lookup so a rejected
                    # action is not checked a second time
                    action_idx = 0
                    action = self.clf.xidx_action_map[int(
                        sorted_action_idx[i, action_idx])]
                    while not conf.is_action_allowed(action, docs[i]):
                        action_idx += 1
                        action = self.clf.xidx_action_map[int(
                            sorted_action_idx[i, action_idx])]
                conf.operate(action)
            cost = self.loss(action_probs, torch.cat(curr_gold_actions))
            cost_acc += cost.item()
            if is_train:
                cost.backward()
                nn.utils.clip_grad_norm_(self.clf.parameters(), 0.2)
                optim.step()
            step += 1

        rst_trees = []
        for conf, doc in zip(confs, docs):
            tree = conf.get_parse_tree()
            rst_tree = RstTree()
            rst_tree.assign_tree(tree)
            rst_tree.assign_doc(doc)
            rst_tree.back_prop(tree, doc)
            rst_trees.append(rst_tree)
        return rst_trees, cost_acc