Beispiel #1
0
def test_doc():
    """Check token/EDU bookkeeping of ``Doc.read_from_fmerge`` on two WSJ files.

    Reads ``wsj_0603.out.merge`` and ``wsj_0606.out.merge`` (paths are relative
    to the current working directory). It asserts the EDU counts, the per-EDU
    token counts and the sentence (``sidx``), paragraph (``pidx``) and EDU
    (``eduidx``) indices of selected boundary tokens.
    """
    # --- Small document: wsj_0603 ---
    small = Doc()
    small.read_from_fmerge("../data/data_dir/train_dir/wsj_0603.out.merge")

    expected_lengths = [10, 3, 9, 7, 9, 13, 7]
    assert len(small.edu_dict) == 7
    # EDU ids are 1-based; every token id of an EDU must resolve to a token.
    for edu_id, expected_len in enumerate(expected_lengths, start=1):
        token_ids = small.edu_dict[edu_id]
        _words = [small.token_dict[tid].word for tid in token_ids]
        assert len(token_ids) == expected_len

    n_tokens = sum(expected_lengths)

    # Sentence / paragraph / EDU indices of boundary tokens.
    assert len(small.token_dict) == n_tokens
    first, last = small.token_dict[0], small.token_dict[n_tokens - 1]
    assert first.sidx == 0 and last.sidx == 2
    assert small.token_dict[37].sidx == 1 and small.token_dict[38].sidx == 2
    assert first.pidx == 1 and last.pidx == 2
    assert first.eduidx == 1 and last.eduidx == 7

    # --- Larger document: wsj_0606 ---
    big = Doc()
    big.read_from_fmerge("../data/data_dir/train_dir/wsj_0606.out.merge")
    assert len(big.edu_dict) == 57
    tokens = big.token_dict
    assert tokens[127].pidx == 2 and tokens[128].pidx == 3
    assert tokens[0].eduidx == 1 and tokens[581].eduidx == 57
    assert tokens[362].eduidx == 38 and tokens[363].eduidx == 39
Beispiel #2
0
    def pred_parser(self, output_dir='./examples', parse_type=None, bcvocab=None, draw=True):
        """Parse pre-processed documents and count predicted relations per document.

        Loads ``<output_dir>/<parse_type>/processed_data.p``, a pickled list whose
        items are passed to ``Doc.read_from_fmerge``. Each item with at least two
        lines is parsed with ``self.parser.sr_parse``. The code counts how often
        each relation label from ``other.class2rel`` appears as ``brack[2]`` in
        the predicted brackets.

        Non-zero tallies are rendered as ``"<count>/<total>"`` strings. Tallies are
        printed per document and the list is pickled to
        ``<output_dir>/<parse_type>/result.p``.

        :param output_dir: root directory containing the ``parse_type`` sub-directory.
        :param parse_type: sub-directory name; must be given (``os.path.join``
            raises ``TypeError`` on None).
        :param bcvocab: forwarded to ``self.parser.sr_parse``.
        :param draw: currently unused (drawing is disabled in this method).
        """
        data_dir = os.path.join(output_dir, parse_type)
        # NOTE(review): pickle.load can execute arbitrary code; only load files
        # produced by this project's own preprocessing.
        with open(os.path.join(data_dir, "processed_data.p"), 'rb') as file:
            doclist = pickle.load(file)
        relations = list(other.class2rel.keys())
        results = []
        for lines in doclist:
            relation_d = {rel: 0.0 for rel in relations}
            # Items with fewer than two lines are not parsed and keep all-zero counts.
            if len(lines) >= 2:
                doc = Doc()
                doc.read_from_fmerge(lines)
                pred_rst = self.parser.sr_parse(doc, self.isFlat, bcvocab)
                for brack in pred_rst.bracketing(self.isFlat):
                    relation_d[brack[2]] += 1
            # Compute the total once instead of once per relation key.
            total = sum(relation_d.values())
            if total:
                relation_d = {k: str(v) + "/" + str(total) for k, v in relation_d.items()}
            print(relation_d)
            results.append(relation_d)
        with open(os.path.join(data_dir, "result.p"), 'wb') as file:
            pickle.dump(results, file)
Beispiel #3
0
 def build(self):
     """Build the RST tree for ``self.fdis`` and attach its document.

     The ``.dis`` annotation is parsed with ``RstTree.build_tree``, then
     flattened (when ``self.isFlat``) or binarized. The document is read from
     ``self.fmerge`` via ``Doc.read_from_fmerge`` and stored in ``self.doc``.
     Finally the tree is post-processed:

     - flat trees: ``down_flat_prop`` and ``flat_back_prop``;
     - binary trees: ``down_prop`` and ``back_prop``.

     :raises IOError: if ``self.fmerge`` is falsy (e.g. None, '' or an empty list).
     """
     with open(self.fdis) as fin:
         text = fin.read()
     # Build RST as annotated, then flatten or binarize it.
     self.tree = RstTree.build_tree(text)
     if self.isFlat:
         self.tree = RstTree.flat_tree(self.tree)
     else:
         self.tree = RstTree.binarize_tree(self.tree)
     # Only emptiness is checked: self.fmerge is handed to read_from_fmerge
     # unchanged and presumably may be a path or pre-loaded lines (confirm).
     if not self.fmerge:
         raise IOError("No merge input given for {}: {!r}".format(self.fdis, self.fmerge))
     doc = Doc()
     doc.read_from_fmerge(self.fmerge)
     self.doc = doc
     if self.isFlat:
         RstTree.down_flat_prop(self.tree)
         RstTree.flat_back_prop(self.tree, self.doc)
     else:
         RstTree.down_prop(self.tree)
         RstTree.back_prop(self.tree, self.doc)
Beispiel #4
0
    def eval_parser(self,
                    dev_data=None,
                    path='./examples',
                    use_parseval=False):
        """Evaluate the parser against gold RST trees and report span/nuclearity scores.

        :param dev_data: iterable of already-built gold ``RstTree`` objects, each
            exposing ``.doc``. When None, every ``*.merge`` file in ``path`` is read
            and its gold tree is built from the ``.dis`` file with the same stem.
        :param path: directory scanned when ``dev_data`` is None.
        :param use_parseval: forwarded to ``Metrics``.

        Prints the summed parser cost and calls ``met.report()``. Gold action
        sequences are moved to ``torch.cuda`` tensors, so this presumably
        requires a CUDA device.
        """
        met = Metrics(use_parseval, levels=['span', 'nuclearity'])
        if dev_data is None:
            eval_list = [
                os.path.join(path, fname) for fname in os.listdir(path)
                if fname.endswith('.merge')
            ]
        else:
            eval_list = dev_data
        action_map = self.data_helper.action_map
        total_cost = 0
        for eval_instance in eval_list:
            if dev_data is not None:
                gold_rst = eval_instance
                doc = eval_instance.doc
            else:
                fmerge = eval_instance
                doc = Doc()
                doc.read_from_fmerge(fmerge)
                # Swap only the extension. str.replace would also rewrite any
                # '.merge'/'.dis' occurring elsewhere in the path.
                fdis = os.path.splitext(fmerge)[0] + '.dis'
                gold_rst = RstTree(fdis, fmerge)
                gold_rst.build()

            _, gold_action_seq = gold_rst.generate_action_samples()
            gold_action_ids = [action_map[action] for action in gold_action_seq]
            pred_rst, cost = self.parser.sr_parse(
                [doc], [torch.cuda.LongTensor(gold_action_ids)],
                None,
                is_train=False)
            total_cost += cost
            # sr_parse works on a batch of one document; score its single result.
            met.eval(gold_rst, pred_rst[0])

        print("Total cost: ", total_cost)
        met.report()
Beispiel #5
0
 def build(self):
     """Construct a binarized RST tree for ``self.fdis`` and attach its document.

     Prints a progress line, stores the binarized tree in ``self.tree`` and the
     ``Doc`` read from ``self.fmerge`` in ``self.doc``, then applies
     ``RstTree.back_prop`` to the tree.

     :raises IOError: when ``self.fmerge`` is not an existing file.
     """
     with open(self.fdis) as dis_file:
         dis_text = dis_file.read()
         print(f"Processing {self.fdis}")
     # Parse the annotation as written, then convert it to binary form.
     self.tree = RstTree.build_tree(dis_text)
     self.tree = RstTree.binarize_tree(self.tree)
     if not isfile(self.fmerge):
         raise IOError("File doesn't exist: {}".format(self.fmerge))
     merged_doc = Doc()
     merged_doc.read_from_fmerge(self.fmerge)
     self.doc = merged_doc
     RstTree.back_prop(self.tree, self.doc)
 def __getitem__(self, idx):
     """Load the ``idx``-th example.

     ``self.X[idx]`` is presumably a ``.dis`` path. Its ``.merge`` counterpart is
     derived by replacing ``'.dis'`` with ``'.merge'``.

     Returns:
         If ``self.is_train``: a tuple ``(doc, actions)``, where ``doc`` is the
             ``Doc`` read from the ``.merge`` file and ``actions`` is
             ``torch.cuda.LongTensor(self.y[idx])`` (the gold action sequence).
         Otherwise: the built gold ``RstTree`` for that file.
     """
     fmerge = self.X[idx].replace('.dis', '.merge')

     if self.is_train:
         doc = Doc()
         doc.read_from_fmerge(fmerge)
         return doc, torch.cuda.LongTensor(self.y[idx])

     # Evaluation: no separate Doc is needed here. gold_rst.build() is given
     # the .merge path and is expected to read it itself.
     fdis = fmerge.replace('.merge', '.dis')
     gold_rst = RstTree(fdis, fmerge)
     gold_rst.build()
     return gold_rst
Beispiel #7
0
    def create_data_helper(self, data_dir, config, coref_trainer):
        """Build training samples and a train/validation split of the documents.

        For every RST tree yielded by ``self.read_rst_trees(data_dir=data_dir)``:

        - generates (feature, action, relation) samples;
        - reads the matching ``.merge`` file into a ``Doc``;
        - wraps the ``Doc``'s whitespace-split EDUs in a coreference ``Document``,
          which is appended to ``self.docs``.

        The documents are then split 90/10, stratified by EDU count, with
        ``random_state=1``. Samples from training documents are stored in
        ``self.feats_list``, ``self.actions_numeric`` and
        ``self.relations_numeric``. File names of validation documents are stored
        in ``self.val_trees``.

        :param data_dir: forwarded to ``self.read_rst_trees``.
        :param config: forwarded to ``generate_action_relation_samples``.
        :param coref_trainer: not used by this method.
        """
        print("Parsing trees")

        # read train data
        all_feats_list, self.feats_list = [], []
        all_actions_numeric, self.actions_numeric = [], []
        all_relations_numeric, self.relations_numeric = [], []
        self.docs = []
        self.val_trees = []

        print("Generating features")
        for i, rst_tree in enumerate(self.read_rst_trees(data_dir=data_dir)):
            feats, actions, relations = rst_tree.generate_action_relation_samples(
                config)
            # Until it is overwritten below, the first field of each feature
            # holds the tree's .dis path.
            fdis = feats[0][0]

            # Old doc instance for storing sentence/paragraph/document features
            doc = Doc()
            eval_instance = fdis.replace('.dis', '.merge')
            doc.read_from_fmerge(eval_instance)

            # Plain whitespace tokenization (not e.g. nltk.word_tokenize); each
            # EDU is passed to the coref Document as one "sentence".
            tok_edus = [edu.split(" ") for edu in doc.doc_edus]
            tokens = flatten(tok_edus)

            # Coreference resolver document instance for coreference functionality
            # (converting tokens to wordpieces and getting corresponding coref boundaries etc)
            coref_document = Document(raw_text=None,
                                      tokens=tokens,
                                      sents=tok_edus,
                                      corefs=[],
                                      speakers=["0"] * len(tokens),
                                      genre="nw",
                                      filename=fdis)
            # Duplicate for convenience
            coref_document.token_dict = doc.token_dict
            coref_document.edu_dict = doc.edu_dict
            coref_document.old_doc = doc

            for feat, action, relation in zip(feats, actions, relations):
                # Replace the path with the document index so each sample can
                # be assigned to the train/validation split below.
                feat[0] = i
                all_feats_list.append(feat)
                all_actions_numeric.append(action)
                all_relations_numeric.append(relation)

            self.docs.append(coref_document)
            if i % 50 == 0:
                print("Processed ", i + 1, " trees")

        all_actions_numeric = [action_map[x] for x in all_actions_numeric]
        # A None relation is looked up as-is; other labels are lower-cased first.
        all_relations_numeric = [
            relation_map[x if x is None else x.lower()]
            for x in all_relations_numeric
        ]

        # Stratify by number of EDUs in the document
        stratified = get_stratify_classes(
            [len(d.edu_dict) for d in self.docs])
        train_indexes, val_indexes = train_test_split(np.arange(len(self.docs)),
                                                      test_size=0.1,
                                                      random_state=1,
                                                      stratify=stratified)

        # Select only those stack-queue actions that belong to trees in the train set.
        # A set gives O(1) membership; `x in ndarray` scans the whole array per feature.
        train_set = set(train_indexes.tolist())
        for feat, action, relation in zip(all_feats_list, all_actions_numeric,
                                          all_relations_numeric):
            if feat[0] in train_set:
                self.feats_list.append(feat)
                self.actions_numeric.append(action)
                self.relations_numeric.append(relation)

        self.val_trees = [self.docs[index].filename for index in val_indexes]
        self.all_clusters = []
Beispiel #8
0
    def eval_parser(self,
                    dev_data=None,
                    path='./examples',
                    save_preds=True,
                    use_parseval=False):
        """Evaluate the parser on gold ``.dis`` trees and report span/nuclearity/relation scores.

        :param dev_data: list of ``.dis`` paths. When None, all ``*.dis`` files in
            ``path`` are used. Each needs a matching ``.merge`` file.
        :param path: directory scanned when ``dev_data`` is None.
        :param save_preds: when true, write predicted brackets to
            ``../data/predicted_trees/<config[MODEL_NAME]>_<file name>``. The
            directory (and missing parents) is created if needed.
        :param use_parseval: forwarded to ``Metrics``; also selects which metric
            name is printed.

        Gold sequences are moved to ``torch.cuda`` tensors, so this presumably
        requires a CUDA device.
        """
        met = Metrics(levels=['span', 'nuclearity', 'relation'],
                      use_parseval=use_parseval)
        # ----------------------------------------
        # Read all files from the given path
        if dev_data is None:
            dev_data = [
                os.path.join(path, fname) for fname in os.listdir(path)
                if fname.endswith('.dis')
            ]
        pred_dir = '../data/predicted_trees'
        total_cost = 0
        for eval_instance in dev_data:
            # ----------------------------------------
            fmerge = eval_instance.replace('.dis', '.merge')
            doc = Doc()
            doc.read_from_fmerge(fmerge)
            gold_rst = RstTree(eval_instance, fmerge)
            gold_rst.build()

            # Plain whitespace tokenization; each EDU is one "sentence".
            tok_edus = [edu.split(" ") for edu in doc.doc_edus]
            tokens = flatten(tok_edus)

            coref_document = Document(raw_text=None,
                                      tokens=tokens,
                                      sents=tok_edus,
                                      corefs=[],
                                      speakers=["0"] * len(tokens),
                                      genre="nw",
                                      filename=None)
            coref_document.token_dict = doc.token_dict
            coref_document.edu_dict = doc.edu_dict

            gold_action_seq, gold_rel_seq = gold_rst.decode_rst_tree()

            gold_action_seq = [action_map[x] for x in gold_action_seq]
            # NOTE(review): None relations are dropped here, whereas
            # create_data_helper maps them via relation_map[None]; confirm the
            # parser expects this shorter sequence at evaluation time.
            gold_relation_seq = [
                relation_map[x.lower()] for x in gold_rel_seq if x is not None
            ]
            pred_rst, cost = self.parser.sr_parse(
                coref_document, torch.cuda.LongTensor(gold_action_seq),
                torch.cuda.LongTensor(gold_relation_seq))
            total_cost += cost

            if save_preds:
                # makedirs(exist_ok=True): no check-then-create race, and it
                # also creates a missing '../data' parent (os.mkdir would fail).
                os.makedirs(pred_dir, exist_ok=True)
                filename = os.path.basename(eval_instance)
                filepath = f'{pred_dir}/{self.config[MODEL_NAME]}_{filename}'
                # Write brackets into file
                Evaluator.writebrackets(filepath, pred_rst.bracketing())
            # ----------------------------------------
            # Evaluate with gold RST tree
            met.eval(gold_rst, pred_rst)

        print("Total cost: ", total_cost)
        if use_parseval:
            print("Reporting original Parseval metric.")
        else:
            print("Reporting RST Parseval metric.")
        met.report()
Beispiel #9
0
    def eval_parser(self,
                    path='./examples',
                    report=False,
                    bcvocab=None,
                    draw=True):
        """Parse every ``*.merge`` file in ``path`` and write predicted brackets.

        For each ``<name>.merge`` file:

        - the predicted tree is drawn to ``<name>.ps`` when ``draw`` is true;
        - its brackets are written to ``<name>.brackets``;
        - when ``report`` is true, the prediction is scored against the gold tree
          built from ``<name>.dis``, and the metrics are reported after all files.
        """
        met = Metrics(levels=['span', 'nuclearity', 'relation'])
        # ----------------------------------------
        # Read all files from the given path
        doclist = [
            os.path.join(path, fname) for fname in os.listdir(path)
            if fname.endswith('.merge')
        ]
        # NOTE(review): these statistics are collected but never reported or
        # returned by this method.
        pred_forms = []
        gold_forms = []
        depth_per_relation = {}
        for fmerge in doclist:
            doc = Doc()
            doc.read_from_fmerge(fmerge)
            pred_rst = self.parser.sr_parse(doc, bcvocab)
            if draw:
                pred_rst.draw_rst(fmerge.replace(".merge", ".ps"))
            pred_brackets = pred_rst.bracketing()
            fbrackets = fmerge.replace('.merge', '.brackets')
            Evaluator.writebrackets(fbrackets, pred_brackets)
            if not report:
                continue

            # Evaluate with gold RST tree
            fdis = fmerge.replace('.merge', '.dis')
            gold_rst = RstTree(fdis, fmerge)
            gold_rst.build()
            met.eval(gold_rst, pred_rst)
            pred_forms.extend(
                node.form for node in pred_rst.postorder_DFT(pred_rst.tree, []))
            # Traverse the gold tree once and reuse the node list below.
            gold_nodes = gold_rst.postorder_DFT(gold_rst.tree, [])
            gold_forms.extend(node.form for node in gold_nodes)

            for node in gold_nodes:
                # Only inner nodes (both children present) carry a relation.
                if node.lnode is None or node.rnode is None:
                    continue
                # For 'NS' nodes the relation is read from the right child,
                # otherwise from the left child.
                relation = node.rnode.relation if node.form == 'NS' else node.lnode.relation
                rela_class = RstTree.extract_relation(relation)
                depth_per_relation.setdefault(rela_class, []).append(node.depth)

        if report:
            met.report()
Beispiel #10
0
    def eval_parser(self, data_dir, output_dir='./examples', report=False, bcvocab=None, draw=True, isFlat=False):
        """Parse pre-processed test documents, write brackets and optionally score them.

        Loads ``<output_dir>/Treebank/TEST/processed_data.p`` and pairs its items,
        in order, with the ``*.out`` files of ``data_dir``. For each document:

        - the predicted tree is drawn to ``<file>.out.ps`` when ``draw`` is true;
        - its brackets are written to ``<file>.out.brackets``;
        - when ``report`` is true, the prediction is scored against the gold tree
          built from ``<file>.out.dis``.

        :param isFlat: selects flat vs. binary gold-tree construction and
            traversal. Parsing, bracketing and ``met.eval`` use ``self.isFlat``.
        """
        met = Metrics(levels=['span', 'nuclearity', 'relation'])
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(os.path.join(output_dir, "Treebank/TEST", "processed_data.p"), 'rb') as file:
            doclist = pickle.load(file)
        # NOTE(review): the pairing below relies on os.listdir returning files in
        # the same order used when processed_data.p was written (listdir order is
        # arbitrary). zip() also silently drops extras if the counts differ.
        fnames = [fn for fn in os.listdir(data_dir) if fn.endswith(".out")]

        pred_forms = []
        gold_forms = []
        depth_per_relation = {}
        for lines, fname in zip(doclist, fnames):
            doc = Doc()
            doc.read_from_fmerge(lines)
            fout = os.path.join(data_dir, fname)
            print(fout)
            print("************************ predict rst ************************")
            pred_rst = self.parser.sr_parse(doc, self.isFlat, bcvocab)
            if draw:
                pred_rst.draw_rst(fout + '.ps')
            pred_brackets = pred_rst.bracketing(self.isFlat)
            Evaluator.writebrackets(fout + '.brackets', pred_brackets)
            if not report:
                continue

            # Evaluate with gold RST tree
            print("************************ gold rst ************************")
            gold_rst = RstTree(fout + '.dis', lines, isFlat)
            gold_rst.build()
            met.eval(gold_rst, pred_rst, self.isFlat)
            # NOTE(review): traversal follows the ``isFlat`` argument while the
            # prediction was produced with ``self.isFlat``; confirm they agree.
            if isFlat:
                pred_nodes = pred_rst.postorder_flat_DFT(pred_rst.tree, [])
                gold_nodes = gold_rst.postorder_flat_DFT(gold_rst.tree, [])
            else:
                pred_nodes = pred_rst.postorder_DFT(pred_rst.tree, [])
                gold_nodes = gold_rst.postorder_DFT(gold_rst.tree, [])
            pred_forms.extend(node.form for node in pred_nodes)
            gold_forms.extend(node.form for node in gold_nodes)

            for node in gold_nodes:
                # Only inner nodes (both children present) carry a relation.
                if node.lnode is None or node.rnode is None:
                    continue
                # For 'NS' nodes the relation is read from the right child,
                # otherwise from the left child.
                relation = node.rnode.relation if node.form == 'NS' else node.lnode.relation
                rela_class = RstTree.extract_relation(relation)
                depth_per_relation.setdefault(rela_class, []).append(node.depth)

        if report:
            met.report()