Example #1
 def __str__(self):
     formatted = "\n ".join(
         unicode_(i) + " " + unicode_(s)
         for i, s in zip(self.utterances, self.utterances_speaker))
     mentions = "\n ".join(
         unicode_(i) + " " + unicode_(i.speaker) for i in self.mentions)
     return f"<utterances, speakers> \n {formatted}\n<mentions> \n {mentions}"
Example #2
 def __init__(self, speaker_id, speaker_names=None):
     self.mentions = []
     self.speaker_id = speaker_id
     if speaker_names is None:
         self.speaker_names = [unicode_(speaker_id)]
     elif isinstance(speaker_names, string_types):
         self.speaker_names = [speaker_names]
     elif len(speaker_names) > 1:
         self.speaker_names = speaker_names
     else:
         # single-element (or empty) sequence: keep a list of strings so the
         # token split below iterates over names, not over characters
         self.speaker_names = [unicode_(s) for s in speaker_names]
     self.speaker_tokens = [tok.lower() for s in self.speaker_names for tok in re.split(WHITESPACE_PATTERN, s)]
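
The branches above normalize speaker_names into a list of strings before building the lowercased token list. A small sketch of that last step, assuming WHITESPACE_PATTERN is a whitespace regex and unicode_ is str (both names live at the library's module scope):

import re

WHITESPACE_PATTERN = r"\s+"  # assumption: the library's pattern splits on runs of whitespace
speaker_names = ["John Smith", "Mr. Smith"]
speaker_tokens = [tok.lower() for s in speaker_names for tok in re.split(WHITESPACE_PATTERN, s)]
print(speaker_tokens)  # ['john', 'smith', 'mr.', 'smith']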
Example #3
 def display_clusters(self):
     '''
     Print clusters informations
     '''
     print(self.clusters)
     for key, mentions in self.clusters.items():
         print("cluster", key, "(",
               ", ".join(unicode_(self.data[m]) for m in mentions), ")")
Example #4
 def get_mention_embeddings(self, mention, doc_embedding):
     """ Get span (averaged) and word (single) embeddings of a mention """
     st = mention.sent
     mention_lefts = mention.doc[max(mention.start - 5, st.start):mention.start]
     mention_rights = mention.doc[mention.end:min(mention.end + 5, st.end)]
     head = mention.root.head
     spans = [
         self.get_average_embedding(mention),
         self.get_average_embedding(mention_lefts),
         self.get_average_embedding(mention_rights),
         self.get_average_embedding(st),
         (unicode_(doc_embedding[0:8]) + "...", doc_embedding),
     ]
     words = [
         self.get_word_embedding(mention.root),
         self.get_word_embedding(mention[0]),
         self.get_word_embedding(mention[-1]),
         self.get_word_in_sentence(mention.start - 1, st),
         self.get_word_in_sentence(mention.end, st),
         self.get_word_in_sentence(mention.start - 2, st),
         self.get_word_in_sentence(mention.end + 1, st),
         self.get_word_embedding(head),
     ]
     spans_embeddings_ = {
         "00_Mention": spans[0][0],
         "01_MentionLeft": spans[1][0],
         "02_MentionRight": spans[2][0],
         "03_Sentence": spans[3][0],
         "04_Doc": spans[4][0],
     }
     words_embeddings_ = {
         "00_MentionHead": words[0][0],
         "01_MentionFirstWord": words[1][0],
         "02_MentionLastWord": words[2][0],
         "03_PreviousWord": words[3][0],
         "04_NextWord": words[4][0],
         "05_SecondPreviousWord": words[5][0],
         "06_SecondNextWord": words[6][0],
         "07_MentionRootHead": words[7][0],
     }
     return (
         spans_embeddings_,
         words_embeddings_,
         np.concatenate([em[1] for em in spans], axis=0),
         np.concatenate([em[1] for em in words], axis=0),
     )
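
Each helper above returns a pair whose second element is a NumPy vector, and the return value stacks those vectors into one flat feature array per mention. A minimal sketch of that final concatenation step, with a made-up embedding size d:

import numpy as np

d = 4  # hypothetical embedding size
spans = [("mention text", np.zeros(d)), ("left context", np.ones(d)), ("right context", np.ones(d))]
flat = np.concatenate([em[1] for em in spans], axis=0)  # same pattern as the return statement above
print(flat.shape)  # (12,) -- one vector of length 3 * d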
Example #5
    def build_test_file(self,
                        out_path=OUT_PATH,
                        remove_singleton=True,
                        print_all_mentions=False,
                        debug=None):
        """ Build a test file to supply to the coreference scoring perl script
        """
        print("🌋 Building test file")
        self._prepare_clusters()
        self.dataloader.dataset.no_targets = True
        if not print_all_mentions:
            print("🌋 Build coreference clusters")
            for sample_batched, mentions_idx, n_pairs_l in zip(
                    self.dataloader, self.mentions_idx, self.n_pairs):
                scores, max_i = self.get_max_score(sample_batched)
                for m_idx, ind, n_pairs in zip(mentions_idx, max_i, n_pairs_l):
                    if ind < n_pairs:  # a pair score beats the "no antecedent" score, so we have a coreference match
                        prev_idx = m_idx - n_pairs + ind
                        if debug is not None and (debug == -1
                                                  or debug == prev_idx
                                                  or debug == m_idx):
                            m1_doc, m1_idx = self.flat_m_idx[m_idx]
                            m1 = self.docs[m1_doc]['mentions'][m1_idx]
                            m2_doc, m2_idx = self.flat_m_idx[prev_idx]
                            m2 = self.docs[m2_doc]['mentions'][m2_idx]
                            print("We have a match between:", m1,
                                  "(" + str(m1_idx) + ")", "and:", m2,
                                  "(" + str(m2_idx) + ")")
                        self._merge_coreference_clusters(prev_idx, m_idx)
            if remove_singleton:
                self.remove_singletons_clusters()
        self.dataloader.dataset.no_targets = False

        print("🌋 Construct test file")
        out_str = ""
        for doc, d_tokens, d_lookup, d_m_loc, d_m_to_c in zip(
                self.docs, self.tokens, self.lookup, self.m_loc,
                self.mention_to_cluster):
            out_str += u"#begin document (" + doc['name'] + u"); part " + doc[
                'part'] + u"\n"
            for utt_idx, (c_tokens,
                          c_lookup) in enumerate(zip(d_tokens, d_lookup)):
                for i, (token, lookup) in enumerate(zip(c_tokens, c_lookup)):
                    out_coref = u""
                    for m_str, mention, mention_cluster in zip(
                            doc['mentions'], d_m_loc, d_m_to_c):
                        m_start, m_end, m_utt, m_idx, m_doc = mention
                        if mention_cluster is None:
                            pass
                        elif m_utt == utt_idx:
                            if m_start in lookup:
                                out_coref += u"|" if out_coref else u""
                                out_coref += u"(" + unicode_(mention_cluster)
                                if (m_end - 1) in lookup:
                                    out_coref += u")"
                            elif (m_end - 1) in lookup:
                                out_coref += u"|" if out_coref else u""
                                out_coref += unicode_(mention_cluster) + u")"
                    out_line = doc['name'] + u" " + doc['part'] + u" " + unicode_(i) \
                               + u" " + token + u" "
                    out_line += u"-" if len(out_coref) == 0 else out_coref
                    out_str += out_line + u"\n"
                out_str += u"\n"
            out_str += u"#end document\n"

        # Write test file
        print("Writing in", out_path)
        with io.open(out_path, 'w', encoding='utf-8') as out_file:
            out_file.write(out_str)
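
The coreference column written above follows the CoNLL-2012 convention: "(id" opens a mention, "id)" closes it, "(id)" marks a one-token mention, and "|" separates several annotations on the same token. A toy sketch of the three cases (cluster ids 3 and 7 are arbitrary):

single = "(" + str(3) + ")"      # one-token mention of cluster 3 -> "(3)"
opening = "(" + str(7)           # first token of a longer mention -> "(7"
closing = str(7) + ")"           # last token of that mention      -> "7)"
print(single, opening, closing)  # (3) (7 7)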
Example #6
 def __str__(self):
     return '<utterances, speakers> \n {}\n<mentions> \n {}' \
             .format('\n '.join(unicode_(i) + " " + unicode_(s) for i, s in zip(self.utterances, self.utterances_speaker)),
                     '\n '.join(unicode_(i) + " " + unicode_(i.speaker) for i in self.mentions))
Example #7
    def read_corpus(self, data_path, model=None, debug=False):
        print("🌋 Reading files")
        for dirpath, _, filenames in os.walk(data_path):
            print("In", dirpath, os.path.abspath(dirpath))
            file_list = [
                os.path.join(dirpath, f) for f in filenames
                if f.endswith(".v4_auto_conll") or f.endswith(".v4_gold_conll")
            ]
            cleaned_file_list = []
            for f in file_list:
                fn = f.split(".")
                if fn[1] == "v4_auto_conll":
                    gold = fn[0] + "." + "v4_gold_conll"
                    if gold not in file_list:
                        cleaned_file_list.append(f)
                else:
                    cleaned_file_list.append(f)
            doc_list = parallel_process(cleaned_file_list, load_file)
            for docs in doc_list:  # executor.map(self.load_file, cleaned_file_list):
                for (
                        utts_text,
                        utt_tokens,
                        utts_corefs,
                        utts_speakers,
                        name,
                        part,
                ) in docs:
                    if debug:
                        print("Imported", name)
                        print("utts_text", utts_text)
                        print("utt_tokens", utt_tokens)
                        print("utts_corefs", utts_corefs)
                        print("utts_speakers", utts_speakers)
                        print("name, part", name, part)
                    self.utts_text += utts_text
                    self.utts_tokens += utt_tokens
                    self.utts_corefs += utts_corefs
                    self.utts_speakers += utts_speakers
                    self.utts_doc_idx += [len(self.docs_names)] * len(utts_text)
                    self.docs_names.append((name, part))
        print("utts_text size", len(self.utts_text))
        print("utts_tokens size", len(self.utts_tokens))
        print("utts_corefs size", len(self.utts_corefs))
        print("utts_speakers size", len(self.utts_speakers))
        print("utts_doc_idx size", len(self.utts_doc_idx))
        print("🌋 Building docs")
        for name, part in self.docs_names:
            self.docs.append(
                ConllDoc(
                    name=name,
                    part=part,
                    nlp=None,
                    blacklist=self.blacklist,
                    consider_speakers=True,
                    embedding_extractor=self.embed_extractor,
                    conll=CONLL_GENRES[name[:2]],
                ))
        print("🌋 Loading spacy model")

        if model is None:
            model_options = [
                "en_core_web_lg", "en_core_web_md", "en_core_web_sm", "en"
            ]
            for model_option in model_options:
                if not model:
                    try:
                        spacy.info(model_option)
                        model = model_option
                        print("Loading model", model_option)
                    except Exception:  # spacy.info raises if the model is not installed
                        print("Could not detect model", model_option)
            if not model:
                print("Could not detect any suitable English model")
                return
        else:
            spacy.info(model)
            print("Loading model", model)
        nlp = spacy.load(model)
        print("🌋 Parsing utterances and filling docs with use_gold_mentions=" +
              (str(bool(self.gold_mentions))))
        doc_iter = (s for s in self.utts_text)
        for utt_tuple in tqdm(
                zip(
                    nlp.pipe(doc_iter),
                    self.utts_tokens,
                    self.utts_corefs,
                    self.utts_speakers,
                    self.utts_doc_idx,
                )):
            spacy_tokens, conll_tokens, corefs, speaker, doc_id = utt_tuple
            if debug:
                print(unicode_(self.docs_names[doc_id]), "-", spacy_tokens)
            doc = spacy_tokens
            if debug:
                out_str = ("utterance " + unicode_(doc) + " corefs " +
                           unicode_(corefs) + " speaker " + unicode_(speaker) +
                           "doc_id" + unicode_(doc_id))
                print(out_str.encode("utf-8"))
            self.docs[doc_id].add_conll_utterance(
                doc,
                conll_tokens,
                corefs,
                speaker,
                use_gold_mentions=self.gold_mentions)
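
If no model name is passed, the loop above probes several English spaCy models and keeps the first one it can find. A standalone sketch of the same fallback idea (the function name and model list are illustrative, not part of the library):

import spacy

def load_first_available(names=("en_core_web_lg", "en_core_web_md", "en_core_web_sm")):
    for name in names:
        try:
            return spacy.load(name)  # spacy.load raises OSError when a model is not installed
        except OSError:
            print("Could not load model", name)
    return None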