def build_test_file(self, out_path=OUT_PATH, remove_singleton=True, print_all_mentions=False, debug=None):
    """ Build a test file to supply to the coreference scoring perl script """
    print("🌋 Building test file")
    self._prepare_clusters()
    self.dataloader.dataset.no_targets = True
    if not print_all_mentions:
        print("🌋 Build coreference clusters")
        for sample_batched, mentions_idx, n_pairs_l in zip(self.dataloader, self.mentions_idx, self.n_pairs):
            scores, max_i = self.get_max_score(sample_batched)
            for m_idx, ind, n_pairs in zip(mentions_idx, max_i, n_pairs_l):
                if ind < n_pairs:  # the single score is not the highest: we have a coreference match
                    prev_idx = m_idx - n_pairs + ind
                    if debug is not None and (debug == -1 or debug == prev_idx or debug == m_idx):
                        m1_doc, m1_idx = self.flat_m_idx[m_idx]
                        m1 = self.docs[m1_doc]['mentions'][m1_idx]
                        m2_doc, m2_idx = self.flat_m_idx[prev_idx]
                        m2 = self.docs[m2_doc]['mentions'][m2_idx]
                        print("We have a match between:", m1, "(" + str(m1_idx) + ")",
                              "and:", m2, "(" + str(m2_idx) + ")")
                    self._merge_coreference_clusters(prev_idx, m_idx)
        if remove_singleton:
            self.remove_singletons_clusters()
    self.dataloader.dataset.no_targets = False
    print("🌋 Construct test file")
    out_str = ""
    for doc, d_tokens, d_lookup, d_m_loc, d_m_to_c in zip(self.docs, self.tokens, self.lookup,
                                                          self.m_loc, self.mention_to_cluster):
        out_str += u"#begin document (" + doc['name'] + u"); part " + doc['part'] + u"\n"
        for utt_idx, (c_tokens, c_lookup) in enumerate(zip(d_tokens, d_lookup)):
            for i, (token, lookup) in enumerate(zip(c_tokens, c_lookup)):
                # Build the coreference column for this token
                out_coref = u""
                for m_str, mention, mention_cluster in zip(doc['mentions'], d_m_loc, d_m_to_c):
                    m_start, m_end, m_utt, m_idx, m_doc = mention
                    if mention_cluster is None:
                        pass
                    elif m_utt == utt_idx:
                        if m_start in lookup:
                            out_coref += u"|" if out_coref else u""
                            out_coref += u"(" + unicode_(mention_cluster)
                            if (m_end - 1) in lookup:
                                out_coref += u")"
                        elif (m_end - 1) in lookup:
                            out_coref += u"|" if out_coref else u""
                            out_coref += unicode_(mention_cluster) + u")"
                out_line = doc['name'] + u" " + doc['part'] + u" " + unicode_(i) \
                           + u" " + token + u" "
                out_line += u"-" if len(out_coref) == 0 else out_coref
                out_str += out_line + u"\n"
            out_str += u"\n"
        out_str += u"#end document\n"
    # Write test file
    print("Writing in", out_path)
    with io.open(out_path, 'w', encoding='utf-8') as out_file:
        out_file.write(out_str)
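# Illustrative sketch (assumed example, not taken from this codebase): the file
# written above follows the CoNLL-2012 coreference response format, roughly
#
#   #begin document (bc/cctv/00/cctv_0001); part 000
#   bc/cctv/00/cctv_0001 000 0 Hello -
#   bc/cctv/00/cctv_0001 000 1 Susan (12)
#   ...
#   #end document
#
# where the last column holds cluster ids: "(12" opens a multi-token mention,
# "12)" closes it, "(12)" marks a single-token mention, and "|" separates
# overlapping mentions. A file like this is what the reference CoNLL-2012 perl
# scorer consumes, e.g. `perl scorer.pl muc key_file response_file`.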
def __init__(self, speaker_id, speaker_names=None):
    self.mentions = []
    self.speaker_id = speaker_id
    # Normalize speaker_names so that self.speaker_names is always a list of strings
    if speaker_names is None:
        self.speaker_names = [unicode_(speaker_id)]
    elif isinstance(speaker_names, string_types):
        self.speaker_names = [speaker_names]
    elif len(speaker_names) > 1:
        self.speaker_names = speaker_names
    else:
        self.speaker_names = [unicode_(s) for s in speaker_names]
    self.speaker_tokens = [tok.lower() for s in self.speaker_names
                           for tok in re.split(WHITESPACE_PATTERN, s)]
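# Hedged example (assuming this is the constructor of a Speaker-like class, with
# unicode_ behaving like str and WHITESPACE_PATTERN splitting on whitespace):
#
#   speaker = Speaker(7, "Jane van Dyke")
#   # speaker.speaker_names  -> ["Jane van Dyke"]
#   # speaker.speaker_tokens -> ["jane", "van", "dyke"]
#
# The lowercased tokens are what downstream speaker-matching heuristics compare
# against mention tokens.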
def display_clusters(self):
    ''' Print cluster information '''
    print(self.clusters)
    for key, mentions in self.clusters.items():
        print("cluster", key, "(", ", ".join(unicode_(self.data[m]) for m in mentions), ")")
def get_mention_embeddings(self, mention, doc_embedding):
    ''' Get span (averaged) and word (single) embeddings of a mention '''
    st = mention.sent
    mention_lefts = mention.doc[max(mention.start - 5, st.start):mention.start]
    mention_rights = mention.doc[mention.end:min(mention.end + 5, st.end)]
    head = mention.root.head
    spans = [self.get_average_embedding(mention),
             self.get_average_embedding(mention_lefts),
             self.get_average_embedding(mention_rights),
             self.get_average_embedding(st),
             (unicode_(doc_embedding[0:8]) + "...", doc_embedding)]
    words = [self.get_word_embedding(mention.root),
             self.get_word_embedding(mention[0]),
             self.get_word_embedding(mention[-1]),
             self.get_word_in_sentence(mention.start - 1, st),
             self.get_word_in_sentence(mention.end, st),
             self.get_word_in_sentence(mention.start - 2, st),
             self.get_word_in_sentence(mention.end + 1, st),
             self.get_word_embedding(head)]
    spans_embeddings_ = {"00_Mention": spans[0][0],
                         "01_MentionLeft": spans[1][0],
                         "02_MentionRight": spans[2][0],
                         "03_Sentence": spans[3][0],
                         "04_Doc": spans[4][0]}
    words_embeddings_ = {"00_MentionHead": words[0][0],
                         "01_MentionFirstWord": words[1][0],
                         "02_MentionLastWord": words[2][0],
                         "03_PreviousWord": words[3][0],
                         "04_NextWord": words[4][0],
                         "05_SecondPreviousWord": words[5][0],
                         "06_SecondNextWord": words[6][0],
                         "07_MentionRootHead": words[7][0]}
    return (spans_embeddings_,
            words_embeddings_,
            np.concatenate([em[1] for em in spans], axis=0),
            np.concatenate([em[1] for em in words], axis=0))
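# Shape sketch (assumption: get_average_embedding and get_word_embedding each
# return a (text, vector) tuple whose vectors share a common size embed_dim,
# and `extractor` is an instance of the class defining this method):
#
#   spans_dict, words_dict, spans_vec, words_vec = extractor.get_mention_embeddings(mention, doc_embedding)
#   # spans_vec.shape == (5 * embed_dim,)  -> mention, left ctx, right ctx, sentence, doc
#   # words_vec.shape == (8 * embed_dim,)  -> head, first/last words, +/-1 and +/-2 words, root head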
def __str__(self):
    return '<utterances, speakers> \n {}\n<mentions> \n {}'.format(
        '\n '.join(unicode_(i) + " " + unicode_(s)
                   for i, s in zip(self.utterances, self.utterances_speaker)),
        '\n '.join(unicode_(i) + " " + unicode_(i.speaker)
                   for i in self.mentions))
def read_corpus(self, data_path, debug=False):
    # This function builds the memory module for the CoNLL corpus: it finds the
    # discourse end marker, which delimits the stories used for memory inference.
    print("🌋 Reading files")
    for dirpath, _, filenames in os.walk(data_path):
        print("In", dirpath, os.path.abspath(dirpath))
        file_list = [os.path.join(dirpath, f) for f in filenames
                     if f.endswith(".v4_auto_conll") or f.endswith(".v4_gold_conll")]
        # Prefer the gold annotation when both auto and gold files are present
        cleaned_file_list = []
        for f in file_list:
            fn = f.split('.')
            if fn[1] == "v4_auto_conll":
                gold = fn[0] + "." + "v4_gold_conll"
                if gold not in file_list:
                    cleaned_file_list.append(f)
            else:
                cleaned_file_list.append(f)
        # doc_list = parallel_process(cleaned_file_list, load_file)
        # for docs in doc_list:  # executor.map(self.load_file, cleaned_file_list):
        for file in cleaned_file_list:
            docs = load_file(file)
            for utts_text, utt_tokens, utts_corefs, utts_speakers, name, part in docs:
                print("Imported", name)
                if debug:
                    print("utts_text", utts_text)
                    print("utt_tokens", utt_tokens)
                    print("utts_corefs", utts_corefs)
                    print("utts_speakers", utts_speakers)
                    print("name, part", name, part)
                self.utts_text += utts_text
                self.utts_tokens += utt_tokens
                self.utts_corefs += utts_corefs
                self.utts_speakers += utts_speakers
                self.utts_doc_idx += [len(self.docs_names)] * len(utts_text)
                self.docs_names.append((name, part))
    print("utts_text size", len(self.utts_text))
    print("utts_tokens size", len(self.utts_tokens))
    print("utts_corefs size", len(self.utts_corefs))
    print("utts_speakers size", len(self.utts_speakers))
    print("utts_doc_idx size", len(self.utts_doc_idx))
    print("🌋 Building docs")
    for name, part in self.docs_names:
        self.docs.append(ConllDoc(name=name, part=part, nlp=None,
                                  blacklist=False, consider_speakers=True,
                                  embedding_extractor=self.embed_extractor,
                                  conll=CONLL_GENRES[name[:2]]))
    print("🌋 Loading spacy model")
    try:
        spacy.info('en_core_web_sm')
        model = 'en_core_web_sm'
    except IOError:
        print("No spacy 2 model detected, using spacy1 'en' model")
        spacy.info('en')
        model = 'en'
    nlp = spacy.load(model)
    print("🌋 Parsing utterances and filling docs")
    doc_iter = (s for s in self.utts_text)
    for utt_tuple in tqdm(zip(nlp.pipe(doc_iter),
                              self.utts_tokens, self.utts_corefs,
                              self.utts_speakers, self.utts_doc_idx)):
        spacy_tokens, conll_tokens, corefs, speaker, doc_id = utt_tuple
        if debug:
            print(unicode_(self.docs_names[doc_id]), "-", spacy_tokens)
        doc = spacy_tokens
        if debug:
            out_str = "utterance " + unicode_(doc) + " corefs " + unicode_(corefs) + \
                      " speaker " + unicode_(speaker) + " doc_id " + unicode_(doc_id)
            print(out_str.encode('utf-8'))
        self.docs[doc_id].add_conll_utterance(doc, conll_tokens, corefs, speaker,
                                              use_gold_mentions=self.use_gold_mentions)
        del spacy_tokens, conll_tokens, corefs, speaker, doc_id
    del nlp, doc_iter
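# Hedged usage sketch (names assumed: the method is expected to live on a
# ConllCorpus-style reader that defines the utts_* lists, docs, docs_names,
# embed_extractor and use_gold_mentions attributes used above):
#
#   corpus = ConllCorpus(...)  # hypothetical constructor
#   corpus.read_corpus("conll-2012/v4/data/train/data/english/annotations")
#   # corpus.docs now holds one ConllDoc per (name, part) document, with each
#   # utterance parsed by spaCy and aligned to its CoNLL tokens, corefs and speaker.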