import os
import pickle

import torch
from allennlp.commands.elmo import ElmoEmbedder
from allennlp.data.tokenizers.word_tokenizer import WordTokenizer


class ElmoEncoder:

    default_max_layer = 2

    def __init__(self):
        self.model = ElmoEmbedder(PRETRAINED_ELMO_OPTION_URL, PRETRAINED_ELMO_WEIGHT_URL)
        self.pretrained_name = 'elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights'
        self.tokenizer = WordTokenizer()

    def encode(self, data, file_prefix=DEFAULT_FILE_PREFIX):
        emb_pkl_path = embed_elements_path(file_prefix, self.pretrained_name)

        # Reuse cached embeddings if they were already computed for this prefix.
        if os.path.isfile(emb_pkl_path):
            with open(emb_pkl_path, 'rb') as f:
                entities, relations = pickle.load(f)
            return entities, relations

        ent_embs = []
        rel_embs = []
        for triple in data.triples:
            # Token counts delimit the subject / relation / object spans in the
            # concatenated token sequence.
            sent_tokens = self.tokenizer.batch_tokenize(triple)
            sbj_p, rel_p, obj_p = [len(tokens) for tokens in sent_tokens]
            tokens = [token.text for token in sum(sent_tokens, [])]

            hid_states = self.model.embed_sentence(tokens)
            hid_states = torch.from_numpy(hid_states)

            # Mean-pool each span over the token dimension, one vector per ELMo layer.
            sbj_states = hid_states[:, :sbj_p, :]
            ent_embs.append(torch.mean(sbj_states, dim=1))

            rel_p_padded = sbj_p + rel_p
            rel_states = hid_states[:, sbj_p:rel_p_padded, :]
            rel_embs.append(torch.mean(rel_states, dim=1))

            obj_p_padded = rel_p_padded + obj_p
            obj_states = hid_states[:, rel_p_padded:obj_p_padded, :]
            ent_embs.append(torch.mean(obj_states, dim=1))

        ent_embs = torch.stack(ent_embs, dim=0)
        rel_embs = torch.stack(rel_embs, dim=0)

        with open(emb_pkl_path, 'wb') as f:
            pickle.dump((ent_embs, rel_embs), f, protocol=pickle.HIGHEST_PROTOCOL)

        return ent_embs, rel_embs
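# --- Usage sketch (added for illustration; not part of the original module) ---
# Assumes the module-level names used above (PRETRAINED_ELMO_OPTION_URL,
# PRETRAINED_ELMO_WEIGHT_URL, DEFAULT_FILE_PREFIX, embed_elements_path) are defined.
# TripleData is a hypothetical stand-in for whatever object exposes a .triples attribute.
from collections import namedtuple

TripleData = namedtuple("TripleData", ["triples"])

def _elmo_encoder_demo():
    # Each triple is a (subject, relation, object) string tuple; batch_tokenize splits
    # every element into word tokens so the span boundaries can be recovered.
    data = TripleData(triples=[("barack obama", "was born in", "hawaii")])
    encoder = ElmoEncoder()
    ent_embs, rel_embs = encoder.encode(data, file_prefix="demo")
    # ent_embs: (2 * num_triples, 3 ELMo layers, 1024); rel_embs: (num_triples, 3, 1024)
    print(ent_embs.shape, rel_embs.shape)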
class ProcessStory:
    def __init__(self, sentence_splitter):
        self._sentence_splitter = sentence_splitter
        self._word_tokenizer = WordTokenizer()

    def __call__(self, lines: List[List[str]], story_ids: List[int]) -> Tuple:
        story_metrics = []
        story_sentences_to_save = []

        story_sentences_split = self._sentence_splitter.batch_split_sentences(lines)
        for sentences, story_id in zip(story_sentences_split, story_ids):
            try:
                tokenized_sentences = self._word_tokenizer.batch_tokenize(sentences)

                total_story_tokens = 0
                for i, sentence in enumerate(tokenized_sentences):
                    start_span = total_story_tokens
                    sentence_len = len(sentence)
                    total_story_tokens += sentence_len
                    end_span = total_story_tokens

                    text = " ".join([s.text for s in sentence])
                    story_sentences_to_save.append(
                        dict(sentence_num=i, story_id=story_id, text=text,
                             sentence_len=sentence_len, start_span=start_span,
                             end_span=end_span))

                story_metrics.append(
                    dict(sentence_num=len(sentences), tokens_num=total_story_tokens,
                         id=story_id))
            except Exception as e:
                logging.error(e)

        return story_ids, story_sentences_to_save, story_metrics
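# --- Usage sketch (added for illustration; not part of the original module) ---
# Assumes AllenNLP's SpacySentenceSplitter as the sentence splitter; any object
# exposing batch_split_sentences works.
from allennlp.data.tokenizers.sentence_splitter import SpacySentenceSplitter

def _process_story_demo():
    process_story = ProcessStory(SpacySentenceSplitter())
    stories = ["The dog barked. The cat ran away.", "It rained all day."]
    story_ids, sentences, metrics = process_story(stories, story_ids=[1, 2])
    # Each sentence dict records cumulative token offsets within its story, so the
    # second sentence of story 1 starts where the first one ended.
    print(sentences[1]["start_span"], sentences[1]["end_span"])
    print(metrics)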
async def save_coreferences(coreference_model: Model,
                            dataset_db: str,
                            cuda_device: Union[List[int], int] = None,
                            save_batch_size: int = 25,
                            sentence_chunks: int = 200):
    with dataset.connect(dataset_db, engine_kwargs=engine_kwargs) as db:
        coref_table = db.create_table('coreference')
        coref_table.create_column('story_id', db.types.bigint)
        coref_table.create_column('coref_id', db.types.integer)
        coref_table.create_column('start_span', db.types.integer)
        coref_table.create_column('end_span', db.types.integer)
        coref_table.create_column('mention_text', db.types.string)
        coref_table.create_column('context_text', db.types.string)
        coref_table.create_index(['story_id'])
        coref_table.create_index(['start_span'])
        coref_table.create_index(['end_span'])

        # One worker thread per GPU; a single worker if no device list is given.
        gpu_max_workers = 1
        if isinstance(cuda_device, (list, tuple)):
            gpu_max_workers = len(cuda_device)
            gpus = cuda_device
        else:
            gpus = [cuda_device]

        word_tokenizer = WordTokenizer()
        loop = asyncio.get_event_loop()

        with ThreadPoolExecutor(max_workers=gpu_max_workers) as executor:
            processors = []
            for gpu in gpus:
                processors.append(CoreferenceProcessor(coreference_model, dataset_db, cuda_device=gpu))
            processors_cycle = itertools.cycle(processors)

            tasks = []
            # Order by shortest to longest so possible failures are at the end.
            for story in db['story'].find(order_by=['sentence_num', 'id']):
                sentence_list = [s["text"] for s in
                                 db["sentence"].find(story_id=story["id"], order_by='id')]
                sentence_tokens = word_tokenizer.batch_tokenize(sentence_list)

                for sentence_chunk in more_itertools.chunked(sentence_tokens, n=sentence_chunks):
                    sentence_chunk_flat = list(more_itertools.flatten(sentence_chunk))
                    if len(sentence_chunk_flat) < 10:
                        continue
                    sentence_chunk_text = [t.text for t in sentence_chunk_flat]

                    tasks.append(loop.run_in_executor(executor, next(processors_cycle),
                                                      sentence_chunk_text, story["id"]))

                    # Flush completed chunks to the database in batches.
                    if len(tasks) == save_batch_size:
                        results = await asyncio.gather(*tasks)
                        for coref_to_save in results:
                            try:
                                db["coreference"].insert_many(copy.deepcopy(coref_to_save))
                            except Exception as e:
                                logging.error(e)
                        tasks = []

            # Save any remaining chunks.
            results = await asyncio.gather(*tasks)
            for coref_to_save in results:
                try:
                    db["coreference"].insert_many(coref_to_save)
                except Exception as e:
                    logging.error(e)

    logger.info("Coreferences Saved")
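# --- Usage sketch (added for illustration; not part of the original module) ---
# Assumes an AllenNLP coreference model archive (the path below is a placeholder) and a
# database that already contains the story and sentence tables read above, plus the
# CoreferenceProcessor class defined elsewhere in this module.
def _save_coreferences_demo():
    from allennlp.models.archival import load_archive

    archive = load_archive("/path/to/coref-model.tar.gz")  # placeholder path
    loop = asyncio.get_event_loop()
    loop.run_until_complete(
        save_coreferences(archive.model,
                          dataset_db="sqlite:///stories.db",
                          cuda_device=[0, 1],  # one worker thread per GPU
                          save_batch_size=25))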
class DocumentOracleDerivation(object):
    def __init__(self,
                 min_combination_num: int = 3,
                 max_combination_num: int = 5,
                 rm_stop_word: bool = True,
                 synonyms: bool = True,
                 stem: bool = False,
                 tokenization: bool = True,
                 beam_sz: int = 5,
                 candidate_percent: float = 1.0):
        self.min_combination_num = min_combination_num
        self.max_combination_num = max_combination_num
        self.rm_stop_word = rm_stop_word
        self.stem = stem
        self.tokenization = tokenization
        self.beam_sz = beam_sz
        self.candidate_percent = candidate_percent
        if self.stem:
            self.stemmer = PorterStemmer().stem_word
        else:
            self.stemmer = lambda x: x
        self.synonyms = synonyms
        if self.tokenization:
            from allennlp.data.tokenizers.word_tokenizer import WordTokenizer
            self.tokenizer = WordTokenizer()
        if self.rm_stop_word:
            self.stop_words = list(set(stopwords.words('english'))) \
                              + [x for x in string.punctuation] + ['``', '\'\'']
        else:
            self.stop_words = []

    def get_rouge_w_annotation_ready_to_use(self, gold_tokens: List[str], pred_tokens: List[str]):
        gold_lower = list(set([x.lower() for x in gold_tokens]))
        gold_wo_stop = [x for x in gold_lower if x not in self.stop_words]  # indices shift after stop-word removal
        gold_wo_stop = replace_w_morphy(gold_wo_stop)
        gold_stem = [ps.stem(x) for x in gold_wo_stop]

        pred_lower = list([x.lower() for x in pred_tokens])
        pred_lower = replace_w_morphy(pred_lower)
        pred_lower = remove_duplicate_tok(pred_lower)
        pred_stem = [ps.stem(x) for x in pred_lower]
        pred_stem = remove_duplicate_tok(pred_stem)

        size_of_gold = len(gold_stem)
        size_of_pred = len(pred_stem)

        gold_key, gold_value = [], []
        for idx, word in enumerate(gold_wo_stop):
            # For each gold word, build a mini group of acceptable matches.
            _tmp = []
            if word in pred_lower:
                _tmp.append(word)
            elif word in pred_stem:
                _tmp.append(word)
            elif gold_stem[idx] in pred_lower:
                _tmp.append(gold_stem[idx])
            elif gold_stem[idx] in pred_stem:
                _tmp.append(gold_stem[idx])
            # If the word or its stem already matches, there is no need to search synonyms.
            if _tmp != []:
                _tmp = _tmp[0]
                gold_key.append(_tmp)
                gold_value.append(1)
            else:
                if word not in cache_for_th:
                    try:
                        cache_for_th[word] = flatten(th.Word(word).synonyms('all', relevance=[3]))
                    except Exception:
                        cache_for_th[word] = []
                if gold_stem[idx] not in cache_for_th:
                    try:
                        cache_for_th[gold_stem[idx]] = flatten(
                            th.Word(gold_stem[idx]).synonyms('all', relevance=[3]))
                    except Exception:
                        cache_for_th[gold_stem[idx]] = []
                syn = cache_for_th[word]
                syn_stem = cache_for_th[gold_stem[idx]]
                syn = list(set(syn + syn_stem))
                l_syn = len(syn)
                if l_syn != 0:
                    gold_key += syn
                    gold_value += [float(1 / l_syn)] * l_syn

        gold_tokens = [ps.stem(x) for x in gold_key]

        # Compute the weighted intersection between the prediction and the gold keys.
        vs = 0
        key_index = []
        for p_idx in range(len(pred_lower)):
            p_word = pred_lower[p_idx]
            p_stem_word = pred_stem[p_idx]
            if p_word in gold_key:
                idx = gold_key.index(p_word)
                v = gold_value[idx]
                vs += v
                key_index.append(p_idx)
            elif p_stem_word in gold_tokens:
                idx = gold_tokens.index(p_stem_word)
                v = gold_value[idx]
                vs += v
                key_index.append(p_idx)

        rouge_recall_1 = 0
        if size_of_gold != 0:
            rouge_recall_1 = vs / float(size_of_gold)
        rouge_pre_1 = 0
        if size_of_pred != 0:
            rouge_pre_1 = vs / float(size_of_pred)

        if random.random() < 0.00001:
            print("Recall: {}\tPre: {}".format(rouge_recall_1, rouge_pre_1))
            print(pred_tokens)
        customed_recall = rouge_recall_1 + rouge_pre_1 * 0.01 - 0.01
        f1 = 0 if (rouge_recall_1 + rouge_pre_1 == 0) else \
            2 * (rouge_recall_1 * rouge_pre_1) / (rouge_recall_1 + rouge_pre_1)
        return customed_recall, f1, key_index

    def comp_num_seg_out_of_p_sent_beam(self, _filtered_doc_list, num_sent_in_combination,
                                        target_ref_sum_list, map_from_new_to_ori_idx) -> dict:
        beam: List[dict] = []
        if len(_filtered_doc_list) < num_sent_in_combination:
            return {"nlabel": num_sent_in_combination,
                    "data": {},
                    "best": None}
        combs = list(range(0, len(_filtered_doc_list)))
        cur_beam = {"in": [],
                    "todo": combs,
                    "val": 0}
        beam.append(cur_beam)
        for _ in range(num_sent_in_combination):
            dict_pattern = {}
            # Expand every beam entry and keep the top beam_sz candidates.
            global_board = []
            for b in beam:
                already_in_beam = b['in']
                todo = b['todo']
                leaderboard = {}
                for to_add in todo:
                    after_add = already_in_beam + [to_add]
                    candidate_doc_list = list(
                        itertools.chain.from_iterable([_filtered_doc_list[i] for i in after_add]))
                    _, average_f_score, _ = self.get_rouge_w_annotation_ready_to_use(
                        gold_tokens=target_ref_sum_list, pred_tokens=candidate_doc_list)
                    leaderboard[to_add] = average_f_score
                sorted_beam = [(k, leaderboard[k])
                               for k in sorted(leaderboard, key=leaderboard.get, reverse=True)]
                for it in sorted_beam:
                    new_in = already_in_beam + [it[0]]
                    new_in.sort()
                    str_new_in = [str(x) for x in new_in]
                    if '_'.join(str_new_in) in dict_pattern:
                        continue
                    else:
                        dict_pattern['_'.join(str_new_in)] = True
                    new_list = todo.copy()
                    new_list.remove(it[0])
                    _beam = {"in": new_in,
                             "todo": new_list,
                             "val": it[1]}
                    global_board.append(_beam)

            # Merge all expansions and keep the global top beam_sz.
            sorted_global_board = sorted(global_board, key=lambda x: x["val"], reverse=True)
            _cnt = 0
            check_dict = []
            beam_waitlist = []
            for it in sorted_global_board:
                str_in = sorted(it['in'])
                str_in = [str(x) for x in str_in]
                _tmp_key = '_'.join(str_in)
                if _tmp_key in check_dict:
                    continue
                else:
                    beam_waitlist.append(it)
                    check_dict.append(_tmp_key)
                _cnt += 1
                if _cnt >= self.beam_sz:
                    break
            beam = beam_waitlist

        # Collect the final beam entries, mapping F1 -> selected sentence indices.
        _comb_bag = {}
        for it in beam:
            n_comb = it['in']
            n_comb.sort()
            n_comb_original = [map_from_new_to_ori_idx[a] for a in n_comb]
            n_comb_original.sort()
            # Ensure plain ints so the labels are JSON serialisable.
            n_comb_original = [int(x) for x in n_comb_original]
            candidate_doc_list = list(
                itertools.chain.from_iterable([_filtered_doc_list[i] for i in n_comb]))
            _, f1, _ = self.get_rouge_w_annotation_ready_to_use(target_ref_sum_list, candidate_doc_list)
            _comb_bag[f1] = {"label": n_comb_original,
                             "R1": f1,
                             "nlabel": num_sent_in_combination}

        if len(_comb_bag) == 0:
            return {"nlabel": num_sent_in_combination,
                    "data": {},
                    "best": None}
        else:
            best_key = sorted(_comb_bag.keys(), reverse=True)[0]
            rt_dict = {"nlabel": num_sent_in_combination,
                       "data": _comb_bag,
                       "best": _comb_bag[best_key]}
            return rt_dict

    def derive_doc_oracle(self,
                          doc_list: List[str],
                          ref_sum: str,
                          prefix_summary: str = ""):
        processed_doc_list = []
        if self.tokenization:
            token_doc_list = self.tokenizer.batch_tokenize(doc_list)
            for doc in token_doc_list:
                processed_doc_list.append([word.text for word in doc])
            processed_ref_sum_list = [w.text for w in self.tokenizer.tokenize(ref_sum)]
            processed_prefix_sum_list = [w.text for w in self.tokenizer.tokenize(prefix_summary)]
        else:
            processed_doc_list = [d.split(" ") for d in doc_list]
            processed_ref_sum_list = ref_sum.split(" ")
            processed_prefix_sum_list = prefix_summary.split(" ")

        processed_doc_list = [[x.lower() for x in sent] for sent in processed_doc_list]
        processed_ref_sum_list = [x.lower() for x in processed_ref_sum_list]
        processed_prefix_sum_list = [x.lower() for x in processed_prefix_sum_list]

        if self.rm_stop_word:
            processed_doc_list = [[x for x in sent if x not in self.stop_words]
                                  for sent in processed_doc_list]
            processed_ref_sum_list = [x for x in processed_ref_sum_list if x not in self.stop_words]
            processed_prefix_sum_list = [x for x in processed_prefix_sum_list
                                         if x not in self.stop_words]

        # Words already covered by the prefix summary are removed from the target.
        target_ref_sum_list = [x for x in processed_ref_sum_list
                               if x not in processed_prefix_sum_list]

        # Preprocessing finished.
        filtered_doc_list, map_from_new_to_ori_idx = self.pre_prune(processed_doc_list,
                                                                    target_ref_sum_list)
        combination_data_dict = {}
        for num_sent_in_combination in range(self.min_combination_num, self.max_combination_num):
            combination_data = self.comp_num_seg_out_of_p_sent_beam(
                _filtered_doc_list=filtered_doc_list,
                num_sent_in_combination=num_sent_in_combination,
                target_ref_sum_list=target_ref_sum_list,
                map_from_new_to_ori_idx=map_from_new_to_ori_idx)
            combination_data_dict = {**combination_data_dict, **combination_data['data']}
            combination_data_dict[num_sent_in_combination] = combination_data
        return combination_data_dict

    def pre_prune(self,
                  list_of_doc: List[List[str]],
                  ref_sum: List[str]):
        keep_candidate_num = math.ceil(len(list_of_doc) * self.candidate_percent)
        f_score_list = [self.get_rouge_w_annotation_ready_to_use(ref_sum, x)[1] for x in list_of_doc]
        top_p_sent_idx = numpy.argsort(f_score_list)[-keep_candidate_num:]
        map_from_new_to_ori_idx = []
        filtered_doc_list = []
        for i in range(len(top_p_sent_idx)):
            filtered_doc_list.append(list_of_doc[top_p_sent_idx[i]])
            map_from_new_to_ori_idx.append(top_p_sent_idx[i])
        return filtered_doc_list, map_from_new_to_ori_idx
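# --- Usage sketch (added for illustration; not part of the original module) ---
# A minimal call to the oracle derivation above, on toy sentences. It assumes the
# module-level helpers referenced in get_rouge_w_annotation_ready_to_use
# (ps, th, cache_for_th, flatten, replace_w_morphy, remove_duplicate_tok) and the
# NLTK stopword list are available, as in the original module.
def _oracle_beam_demo():
    oracle = DocumentOracleDerivation(min_combination_num=1, max_combination_num=3)
    doc_sentences = [
        "the quick brown fox jumps over the lazy dog .",
        "a storm hit the coast on friday .",
        "officials said the damage was limited .",
    ]
    reference = "a storm hit the coast but officials said damage was limited"
    result = oracle.derive_doc_oracle(doc_sentences, reference)
    # result[k] holds the beam-search output for combinations of k sentences; its
    # "best" entry carries the selected sentence indices ("label") and their F1 ("R1").
    print(result[1]["best"])
    print(result[2]["best"])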
class DocumentOracleDerivation(object):
    def __init__(self,
                 mixed_combination: bool,
                 min_combination_num: int = 1,
                 max_combination_num: int = 8,
                 rm_stop_word: bool = True,
                 stem: bool = False,
                 morphy: bool = False,
                 tokenization: bool = True,
                 beam_sz: int = 5,
                 prune_candidate_percent: float = 0.4):
        self.mixed_combination = mixed_combination
        self.min_combination_num = min_combination_num
        self.max_combination_num = max_combination_num
        self.rm_stop_word = rm_stop_word
        self.stem = stem
        self.tokenization = tokenization
        self.beam_sz = beam_sz
        self.prune_candidate_percent = prune_candidate_percent
        if self.stem:
            self.stemmer = PorterStemmer().stem_word
        else:
            self.stemmer = lambda x: x
        self.morphy = morphy
        if self.tokenization:
            from allennlp.data.tokenizers.word_tokenizer import WordTokenizer
            self.tokenizer = WordTokenizer()
        if self.rm_stop_word:
            self.stop_words = list(set(stopwords.words('english'))) \
                              + [x for x in string.punctuation] + ['``', '\'\'']
        else:
            self.stop_words = []

    def derive_doc_oracle(self,
                          doc_list: List[str],
                          ref_sum: str,
                          prefix_summary: str = ""):
        # Returns a dict keyed by ROUGE F1 whose values are the indices of the
        # document sentences selected for that score.
        processed_doc_list = [self._rouge_clean(x) for x in doc_list]
        processed_ref_sum_str = self._rouge_clean(ref_sum)
        processed_prefix_sum_str = self._rouge_clean(prefix_summary)
        if self.tokenization:
            new_processed_doc_list = []
            token_doc_list = self.tokenizer.batch_tokenize(processed_doc_list)
            for doc in token_doc_list:
                new_processed_doc_list.append([word.text for word in doc])
            processed_doc_list = new_processed_doc_list
            processed_ref_sum_list = [w.text for w in self.tokenizer.tokenize(processed_ref_sum_str)]
            processed_prefix_sum_list = [w.text
                                         for w in self.tokenizer.tokenize(processed_prefix_sum_str)]
        else:
            processed_doc_list = [d.split(" ") for d in processed_doc_list]
            processed_ref_sum_list = processed_ref_sum_str.split(" ")
            processed_prefix_sum_list = processed_prefix_sum_str.split(" ")

        # Lowercasing is required; stop-word filtering of these lists is intentionally
        # skipped here (it happens inside get_rouge_ready_to_use).
        processed_doc_list = [[x.lower() for x in sent] for sent in processed_doc_list]
        processed_ref_sum_list = [x.lower() for x in processed_ref_sum_list]
        processed_prefix_sum_list = [x.lower() for x in processed_prefix_sum_list]

        # Words already covered by the prefix summary are removed from the target.
        target_ref_sum_list = [x for x in processed_ref_sum_list
                               if x not in processed_prefix_sum_list]

        # TODO: f_score_list and score_matrix are computed but not yet used downstream.
        f_score_list, score_matrix = self.iter_rouge(processed_doc_list, target_ref_sum_list)

        # Preprocessing finished.
        filtered_doc_list, map_from_new_to_ori_idx = self.pre_prune(processed_doc_list,
                                                                    target_ref_sum_list)
        combination_data_dict = {}
        for num_sent_in_combination in range(self.min_combination_num, self.max_combination_num):
            combination_data = self.comp_num_seg_out_of_p_sent_beam(
                _filtered_doc_list=filtered_doc_list,
                num_sent_in_combination=num_sent_in_combination,
                target_ref_sum_list=target_ref_sum_list,
                map_from_new_to_ori_idx=map_from_new_to_ori_idx)
            if combination_data['best'] is None:
                break
            best_rouge_of_this_batch = combination_data['best']['R1']
            # Stop early once this combination size cannot beat anything already in the bag.
            if len(combination_data_dict) >= self.beam_sz:
                rouge_in_bag = [float(k) for k, v in combination_data_dict.items()]
                if best_rouge_of_this_batch < min(rouge_in_bag):
                    break
            combination_data_dict = {**combination_data_dict, **combination_data['data']}
            combination_data_dict = collections.OrderedDict(
                sorted(combination_data_dict.items(), reverse=True))
            sliced = islice(combination_data_dict.items(), self.beam_sz)
            combination_data_dict = collections.OrderedDict(sliced)

        # Prepare the return data.
        return_dict = {}
        for k, v in combination_data_dict.items():
            return_dict[k] = v['label']
        return return_dict

    def iter_rouge(self, list_of_doc, ref_sum):
        # Individual sentence scores plus a pairwise score matrix over sentence pairs.
        f_score_list = [self.get_rouge_ready_to_use(ref_sum, x) for x in list_of_doc]
        score_matrix = [[0 for _ in range(len(list_of_doc))] for _ in range(len(list_of_doc))]
        for idx, x in enumerate(list_of_doc):
            for jdx, y in enumerate(list_of_doc):
                score_matrix[idx][jdx] = self.get_rouge_ready_to_use(ref_sum, x + y)
        return f_score_list, score_matrix

    def comp_num_seg_out_of_p_sent_beam(self, _filtered_doc_list, num_sent_in_combination,
                                        target_ref_sum_list, map_from_new_to_ori_idx) -> dict:
        beam: List[dict] = []
        if len(_filtered_doc_list) < num_sent_in_combination:
            return {"nlabel": num_sent_in_combination,
                    "data": {},
                    "best": None}
        combs = list(range(0, len(_filtered_doc_list)))
        cur_beam = {"in": [], "todo": combs, "val": 0}
        beam.append(cur_beam)
        for _ in range(num_sent_in_combination):
            dict_pattern = {}
            # Expand every beam entry and keep the top beam_sz candidates.
            global_board = []
            for b in beam:
                already_in_beam = b['in']
                todo = b['todo']
                leaderboard = {}
                for to_add in todo:
                    after_add = already_in_beam + [to_add]
                    candidate_doc_list = list(
                        itertools.chain.from_iterable([_filtered_doc_list[i] for i in after_add]))
                    average_f_score = self.get_rouge_ready_to_use(
                        gold_tokens=target_ref_sum_list, pred_tokens=candidate_doc_list)
                    leaderboard[to_add] = average_f_score
                sorted_beam = [(k, leaderboard[k])
                               for k in sorted(leaderboard, key=leaderboard.get, reverse=True)]
                for it in sorted_beam:
                    new_in = already_in_beam + [it[0]]
                    new_in.sort()
                    str_new_in = [str(x) for x in new_in]
                    if '_'.join(str_new_in) in dict_pattern:
                        continue
                    else:
                        dict_pattern['_'.join(str_new_in)] = True
                    new_list = todo.copy()
                    new_list.remove(it[0])
                    _beam = {"in": new_in, "todo": new_list, "val": it[1]}
                    global_board.append(_beam)

            # Merge all expansions and keep the global top beam_sz.
            sorted_global_board = sorted(global_board, key=lambda x: x["val"], reverse=True)
            _cnt = 0
            check_dict = []
            beam_waitlist = []
            for it in sorted_global_board:
                str_in = sorted(it['in'])
                str_in = [str(x) for x in str_in]
                _tmp_key = '_'.join(str_in)
                if _tmp_key in check_dict:
                    continue
                else:
                    beam_waitlist.append(it)
                    check_dict.append(_tmp_key)
                _cnt += 1
                if _cnt >= self.beam_sz:
                    break
            beam = beam_waitlist

        # Collect the final beam entries, mapping F1 -> selected sentence indices.
        _comb_bag = {}
        for it in beam:
            n_comb = it['in']
            n_comb.sort()
            n_comb_original = [map_from_new_to_ori_idx[a] for a in n_comb]
            n_comb_original.sort()
            # Ensure plain ints so the labels are JSON serialisable.
            n_comb_original = [int(x) for x in n_comb_original]
            candidate_doc_list = list(
                itertools.chain.from_iterable([_filtered_doc_list[i] for i in n_comb]))
            f1 = self.get_rouge_ready_to_use(target_ref_sum_list, candidate_doc_list)
            _comb_bag[f1] = {"label": n_comb_original,
                             "R1": f1,
                             "nlabel": num_sent_in_combination}

        if len(_comb_bag) == 0:
            return {"nlabel": num_sent_in_combination,
                    "data": {},
                    "best": None}
        else:
            best_key = sorted(_comb_bag.keys(), reverse=True)[0]
            rt_dict = {"nlabel": num_sent_in_combination,
                       "data": _comb_bag,
                       "best": _comb_bag[best_key]}
            return rt_dict

    @staticmethod
    def _rouge_clean(s):
        return re.sub(r'[^a-zA-Z0-9 ]', '', s)

    def get_rouge_ready_to_use_w_index(self, gold_tokens: List[str], pred_tokens: List[str], idx, jdx):
        return self.get_rouge_ready_to_use(gold_tokens, pred_tokens), idx, jdx

    # Standard version without synonym matching.
    def get_rouge_ready_to_use(self, gold_tokens: List[str], pred_tokens: List[str]):
        len_gold = len(gold_tokens)
        len_pred = len(pred_tokens)
        gold_bigram = _get_ngrams(2, gold_tokens)
        pred_bigram = _get_ngrams(2, pred_tokens)
        if self.rm_stop_word:
            gold_unigram = set([x for x in gold_tokens if x not in self.stop_words])
            pred_unigram = set([x for x in pred_tokens if x not in self.stop_words])
        else:
            gold_unigram = set(gold_tokens)
            pred_unigram = set(pred_tokens)
        rouge_1 = cal_rouge(pred_unigram, gold_unigram, len_pred, len_gold)['f']
        rouge_2 = cal_rouge(pred_bigram, gold_bigram, len_pred, len_gold)['f']
        rouge_score = (rouge_1 + rouge_2) / 2
        return rouge_score

    def pre_prune(self, list_of_doc: List[List[str]], ref_sum: List[str]):
        keep_candidate_num = math.ceil(len(list_of_doc) * self.prune_candidate_percent)
        f_score_list = [self.get_rouge_ready_to_use(ref_sum, x) for x in list_of_doc]
        top_p_sent_idx = numpy.argsort(f_score_list)[-keep_candidate_num:]
        map_from_new_to_ori_idx = []
        filtered_doc_list = []
        for i in range(len(top_p_sent_idx)):
            filtered_doc_list.append(list_of_doc[top_p_sent_idx[i]])
            map_from_new_to_ori_idx.append(top_p_sent_idx[i])
        return filtered_doc_list, map_from_new_to_ori_idx
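# --- Usage sketch (added for illustration; not part of the original module) ---
# A minimal call to the beam-search oracle above. It assumes the module-level helpers
# referenced in get_rouge_ready_to_use (_get_ngrams, cal_rouge) and the NLTK stopword
# list are available, as in the original module.
def _oracle_label_demo():
    oracle = DocumentOracleDerivation(mixed_combination=False,
                                      min_combination_num=1,
                                      max_combination_num=3)
    doc_sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A storm hit the coast on Friday.",
        "Officials said the damage was limited.",
    ]
    reference = "A storm hit the coast, but officials said damage was limited."
    labels = oracle.derive_doc_oracle(doc_sentences, reference)
    # Keys are averaged ROUGE-1/ROUGE-2 F1 scores; values are the original indices of
    # the document sentences selected for that score.
    for f1, sentence_indices in labels.items():
        print(round(f1, 3), sentence_indices)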