def build_test_file(self, out_path=OUT_PATH, remove_singleton=True, print_all_mentions=False, debug=None):
        """ Build a test file to supply to the coreference scoring perl script.

        Unless ``print_all_mentions`` is set, each mention is first linked to its
        highest-scoring antecedent (when that beats the "single" score) and the two
        clusters are merged. Every document is then written to ``out_path`` (UTF-8)
        in CoNLL layout:
        - a "#begin document (name); part P" header;
        - one "name part token_index token coref" line per token;
        - a blank line after each utterance;
        - "#end document" at the end.
        The coref column holds the cluster ids of the mentions that open ("(id"),
        close ("id)") or both open and close ("(id)") on that token, joined with
        "|", or "-" when there are none.

        Args:
            out_path: path of the file to write.
            remove_singleton: after merging, call ``self.remove_singletons_clusters()``.
            print_all_mentions: if True, skip the model-based merging and write the
                clusters as prepared by ``self._prepare_clusters()``.
            debug: None disables debug output; -1 prints every predicted link; any
                other value prints only the links involving that flat mention index.
        """
        print("🌋 Building test file")
        self._prepare_clusters()
        self.dataloader.dataset.no_targets = True
        try:
            if not print_all_mentions:
                print("🌋 Build coreference clusters")
                for sample_batched, mentions_idx, n_pairs_l in zip(self.dataloader, self.mentions_idx, self.n_pairs):
                    _, max_i = self.get_max_score(sample_batched)
                    for m_idx, ind, n_pairs in zip(mentions_idx, max_i, n_pairs_l):
                        if ind < n_pairs:  # the single score is not the highest, we have a match !
                            # ind selects one of the n_pairs flat mention indices just before m_idx.
                            prev_idx = m_idx - n_pairs + ind
                            if debug is not None and (debug == -1 or debug == prev_idx or debug == m_idx):
                                m1_doc, m1_idx = self.flat_m_idx[m_idx]
                                m1 = self.docs[m1_doc]['mentions'][m1_idx]
                                m2_doc, m2_idx = self.flat_m_idx[prev_idx]
                                m2 = self.docs[m2_doc]['mentions'][m2_idx]
                                print("We have a match between:", m1, "(" + str(m1_idx) + ")", "and:", m2, "(" + str(m2_idx) + ")")
                            self._merge_coreference_clusters(prev_idx, m_idx)
                if remove_singleton:
                    self.remove_singletons_clusters()
        finally:
            # Restore the dataset flag even if scoring or merging raises.
            self.dataloader.dataset.no_targets = False

        print("🌋 Construct test file")
        out_lines = []
        for doc, d_tokens, d_lookup, d_m_loc, d_m_to_c in zip(self.docs, self.tokens, self.lookup, self.m_loc, self.mention_to_cluster):
            # (start, end, utterance, cluster) of every mention that belongs to a cluster,
            # in document order.
            clustered = []
            for _, mention, mention_cluster in zip(doc['mentions'], d_m_loc, d_m_to_c):
                m_start, m_end, m_utt, _, _ = mention
                if mention_cluster is not None:
                    clustered.append((m_start, m_end, m_utt, mention_cluster))

            out_lines.append(u"#begin document (" + doc['name'] + u"); part " + doc['part'] + u"\n")
            for utt_idx, (c_tokens, c_lookup) in enumerate(zip(d_tokens, d_lookup)):
                # Select this utterance's mentions once rather than rescanning the whole
                # document for every token; the original mention order is kept.
                utt_mentions = [(m_start, m_end, cluster)
                                for m_start, m_end, m_utt, cluster in clustered
                                if m_utt == utt_idx]
                for i, (token, lookup) in enumerate(zip(c_tokens, c_lookup)):
                    # lookup: the indices this CoNLL token maps to (presumably spaCy token
                    # indices, matching the m_start / m_end - 1 mention bounds).
                    tags = []
                    for m_start, m_end, cluster in utt_mentions:
                        if m_start in lookup:
                            closes = (m_end - 1) in lookup
                            tags.append(u"(" + unicode_(cluster) + (u")" if closes else u""))
                        elif (m_end - 1) in lookup:
                            tags.append(unicode_(cluster) + u")")
                    out_coref = u"|".join(tags) if tags else u"-"
                    out_lines.append(doc['name'] + u" " + doc['part'] + u" " + unicode_(i)
                                     + u" " + token + u" " + out_coref + u"\n")
                out_lines.append(u"\n")
            out_lines.append(u"#end document\n")

        # Write test file
        print("Writing in", out_path)
        with io.open(out_path, 'w', encoding='utf-8') as out_file:
            out_file.write(u"".join(out_lines))
예제 #2
0
파일: document.py 프로젝트: qiqipipioioi/KP
 def __init__(self, speaker_id, speaker_names=None):
     """Create a speaker with its identifier, its names and an empty mention list.

     Args:
         speaker_id: identifier of the speaker; its text form (via unicode_) is used
             as the only name when ``speaker_names`` is None.
         speaker_names: None, a single name string, or a sequence of names.
     """
     self.mentions = []
     self.speaker_id = speaker_id
     if speaker_names is None:
         self.speaker_names = [unicode_(speaker_id)]
     elif isinstance(speaker_names, string_types):
         self.speaker_names = [speaker_names]
     elif len(speaker_names) > 1:
         self.speaker_names = speaker_names
     else:
         # A sequence holding zero or one name: keep a list of names. Converting the
         # whole sequence with unicode_() would store its string form (e.g. "['Bob']",
         # assuming unicode_ behaves like str), and the tokenization below would then
         # iterate over its individual characters.
         self.speaker_names = [unicode_(name) for name in speaker_names]
     # Lower-cased tokens of every speaker name, split with WHITESPACE_PATTERN.
     self.speaker_tokens = [tok.lower() for s in self.speaker_names for tok in re.split(WHITESPACE_PATTERN, s)]
 def display_clusters(self):
     '''
     Print the clusters: first the raw cluster mapping, then one line per cluster
     listing the text form of the data entry of each of its mentions.
     '''
     clusters = self.clusters
     print(clusters)
     for cluster_key in clusters:
         mention_texts = [unicode_(self.data[mention]) for mention in clusters[cluster_key]]
         print("cluster", cluster_key, "(", ", ".join(mention_texts), ")")
예제 #4
0
 def get_mention_embeddings(self, mention, doc_embedding):
     ''' Get span (averaged) and word (single) embeddings of a mention.

     Each feature helper returns a pair whose first element is presumably a
     printable label and whose second element is the vector.

     Returns a 4-tuple:
         - dict of span feature name -> first element of that feature,
         - dict of word feature name -> first element of that feature,
         - the span vectors concatenated along axis 0,
         - the word vectors concatenated along axis 0.
     '''
     sentence = mention.sent
     # Up to 5 tokens on each side of the mention, clipped to the sentence bounds.
     left_context = mention.doc[max(mention.start - 5, sentence.start):mention.start]
     right_context = mention.doc[mention.end:min(mention.end + 5, sentence.end)]

     span_features = [
         ("00_Mention", self.get_average_embedding(mention)),
         ("01_MentionLeft", self.get_average_embedding(left_context)),
         ("02_MentionRight", self.get_average_embedding(right_context)),
         ("03_Sentence", self.get_average_embedding(sentence)),
         ("04_Doc", (unicode_(doc_embedding[0:8]) + "...", doc_embedding)),
     ]
     word_features = [
         ("00_MentionHead", self.get_word_embedding(mention.root)),
         ("01_MentionFirstWord", self.get_word_embedding(mention[0])),
         ("02_MentionLastWord", self.get_word_embedding(mention[-1])),
         ("03_PreviousWord", self.get_word_in_sentence(mention.start - 1, sentence)),
         ("04_NextWord", self.get_word_in_sentence(mention.end, sentence)),
         ("05_SecondPreviousWord", self.get_word_in_sentence(mention.start - 2, sentence)),
         ("06_SecondNextWord", self.get_word_in_sentence(mention.end + 1, sentence)),
         ("07_MentionRootHead", self.get_word_embedding(mention.root.head)),
     ]

     span_labels = {name: feature[0] for name, feature in span_features}
     word_labels = {name: feature[0] for name, feature in word_features}
     span_vector = np.concatenate([feature[1] for _, feature in span_features], axis=0)
     word_vector = np.concatenate([feature[1] for _, feature in word_features], axis=0)
     return span_labels, word_labels, span_vector, word_vector
예제 #5
0
 def __str__(self):
     # The string has two sections. "<utterances, speakers>" holds one
     # "utterance speaker" entry per utterance. "<mentions>" holds one
     # "mention speaker" entry per mention.
     utterance_part = '\n '.join(unicode_(utterance) + " " + unicode_(speaker)
                                 for utterance, speaker in zip(self.utterances, self.utterances_speaker))
     mention_part = '\n '.join(unicode_(m) + " " + unicode_(m.speaker) for m in self.mentions)
     return '<utterances, speakers> \n {}\n<mentions> \n {}'.format(utterance_part, mention_part)
예제 #6
0
    def read_corpus(self, data_path, debug=False):
        """Read every CoNLL file under ``data_path`` and fill ``self.docs``.

        Steps:
        1. Walk ``data_path`` recursively and load each selected CoNLL file (see
           ``_select_conll_files``) with ``load_file``.
        2. Accumulate utterance texts, tokens, coreference annotations, speakers
           and document indices.
        3. Create one ``ConllDoc`` per (name, part).
        4. Parse every utterance with spaCy and add it to its document.

        Args:
            data_path: root directory searched for CoNLL files.
            debug: if True, print the loaded annotations and every parsed utterance.
        """
        print("🌋 Reading files")
        for dirpath, _, filenames in os.walk(data_path):
            print("In", dirpath, os.path.abspath(dirpath))
            for conll_path in self._select_conll_files(dirpath, filenames):
                for utts_text, utt_tokens, utts_corefs, utts_speakers, name, part in load_file(conll_path):
                    print("Imported", name)
                    if debug:
                        print("utts_text", utts_text)
                        print("utt_tokens", utt_tokens)
                        print("utts_corefs", utts_corefs)
                        print("utts_speakers", utts_speakers)
                        print("name, part", name, part)
                    self.utts_text += utts_text
                    self.utts_tokens += utt_tokens
                    self.utts_corefs += utts_corefs
                    self.utts_speakers += utts_speakers
                    # Every utterance of this document points to the index it gets below.
                    self.utts_doc_idx += [len(self.docs_names)] * len(utts_text)
                    self.docs_names.append((name, part))
        print("utts_text size", len(self.utts_text))
        print("utts_tokens size", len(self.utts_tokens))
        print("utts_corefs size", len(self.utts_corefs))
        print("utts_speakers size", len(self.utts_speakers))
        print("utts_doc_idx size", len(self.utts_doc_idx))
        print("🌋 Building docs")
        for name, part in self.docs_names:
            self.docs.append(ConllDoc(name=name, part=part, nlp=None,
                                      blacklist=False, consider_speakers=True,
                                      embedding_extractor=self.embed_extractor,
                                      conll=CONLL_GENRES[name[:2]]))
        print("🌋 Loading spacy model")
        try:
            spacy.info('en_core_web_sm')
            model = 'en_core_web_sm'
        except IOError:
            print("No spacy 2 model detected, using spacy1 'en' model")
            spacy.info('en')
            model = 'en'
        nlp = spacy.load(model)
        print("🌋 Parsing utterances and filling docs")
        for utt_tuple in tqdm(zip(nlp.pipe(iter(self.utts_text)),
                                  self.utts_tokens, self.utts_corefs,
                                  self.utts_speakers, self.utts_doc_idx)):
            spacy_tokens, conll_tokens, corefs, speaker, doc_id = utt_tuple
            if debug:
                print(unicode_(self.docs_names[doc_id]), "-", spacy_tokens)
                out_str = "utterance " + unicode_(spacy_tokens) + " corefs " + unicode_(corefs) + \
                          " speaker " + unicode_(speaker) + " doc_id " + unicode_(doc_id)
                print(out_str.encode('utf-8'))
            self.docs[doc_id].add_conll_utterance(spacy_tokens, conll_tokens, corefs, speaker,
                                                  use_gold_mentions=self.use_gold_mentions)

    @staticmethod
    def _select_conll_files(dirpath, filenames):
        """Return the CoNLL file paths of one directory, preferring gold annotations.

        Every ``*.v4_gold_conll`` file is kept. A ``*.v4_auto_conll`` file is kept
        only when no gold file with the same stem is in ``filenames``. The result
        keeps the ``filenames`` order. Suffixes are matched on the bare file name,
        so dots elsewhere in the path (e.g. ``./data``) do not matter.
        """
        auto_ext, gold_ext = ".v4_auto_conll", ".v4_gold_conll"
        available = set(filenames)
        selected = []
        for fname in filenames:
            if fname.endswith(gold_ext):
                selected.append(os.path.join(dirpath, fname))
            elif fname.endswith(auto_ext) and fname[:-len(auto_ext)] + gold_ext not in available:
                selected.append(os.path.join(dirpath, fname))
        return selected