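# NOTE: RougeL below is this repo's Rouge-L scorer (imported elsewhere). As
# used in this file it exposes add_inst(cand, ref) (chainable), get_score(),
# calc_score(pred, ref), and the per-instance score lists inst_scores and
# r_scores (recall scores).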
def find_best_match_answer(answer, support_para):
    """Find the start/end indices in support_para with the best coverage of
    the answer (fine-grained)."""
    answer = answer.lower()
    support_para = support_para.lower()

    if answer in support_para:
        best_start = support_para.index(answer)
        best_end = best_start + len(answer) - 1
        return best_start, best_end, 1

    # The answer may carry a trailing punctuation mark the paragraph lacks.
    if (answer.endswith('。') or answer.endswith(';') or answer.endswith(',')
            or answer.endswith('!')) and answer[:-1] in support_para:
        answer = answer[:-1]
        best_start = support_para.index(answer)
        best_end = best_start + len(answer) - 1
        return best_start, best_end, 1

    # Some samples are mislabeled and can only be located after removing spaces.
    if answer.replace(' ', '') in support_para:
        answer = answer.replace(' ', '')
        best_start = support_para.index(answer)
        best_end = best_start + len(answer) - 1
        return best_start, best_end, 1

    # No exact match: search for the span with the highest Rouge-L coverage.
    answer_chars = set(answer)
    best_score = 0
    best_start = -1
    best_end = len(support_para) - 1
    for start_idx in range(0, len(support_para)):
        if support_para[start_idx] not in answer_chars:
            continue
        for end_idx in range(len(support_para) - 1, start_idx - 1, -1):
            if support_para[end_idx] not in answer_chars:
                continue
            sub_para_content = support_para[start_idx:end_idx + 1]
            score = RougeL().add_inst(cand=sub_para_content,
                                      ref=answer).get_score()
            if score > best_score:
                best_score = score
                best_start = start_idx
                best_end = end_idx

    if best_score == 0:
        return -1, -1, 0
    return best_start, best_end, best_score
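# Usage sketch (hypothetical strings; indices are 0-based and inclusive):
#   >>> find_best_match_answer('北京', '他住在北京市海淀区')
#   (3, 4, 1)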
def _find_golden_span(row, article_tokens_col, question_tokens_col,
                      answer_tokens_col):
    article_tokens = row[article_tokens_col]
    answer_tokens = row[answer_tokens_col]
    row['answer_token_start'] = -1
    row['answer_token_end'] = -1
    row['delta_token_starts'] = []
    row['delta_token_ends'] = []
    row['delta_rouges'] = []
    rl = RougeL()
    ground_ans = ''.join(answer_tokens).strip()
    len_p = len(article_tokens)
    len_a = len(answer_tokens)
    s2 = set(ground_ans)
    start_spans = []
    end_spans = []
    # Slide windows of length len_a - 2 .. len_a + 2 over the article; a cheap
    # character-set IoU filter (> 0.3) decides which candidates get Rouge-scored.
    for i in range(len_p - len_a + 1):
        for t_len in range(len_a - 2, len_a + 3):
            if t_len == 0 or i + t_len > len_p:
                continue
            cand_ans = ''.join(article_tokens[i:i + t_len]).strip()
            s1 = set(cand_ans)
            mlen = max(len(s1), len(s2))
            iou = len(s1.intersection(s2)) / mlen if mlen != 0 else 0.0
            if iou > 0.3:
                rl.add_inst(cand_ans, ground_ans)
                start_spans.append(i)
                end_spans.append(i + t_len - 1)
    if len(start_spans) == 0:
        return row
    best_idx = np.argmax(rl.inst_scores)
    row['answer_token_start'] = start_spans[best_idx]
    row['answer_token_end'] = end_spans[best_idx]
    row['delta_token_starts'] = start_spans
    row['delta_token_ends'] = end_spans
    row['delta_rouges'] = rl.inst_scores
    return row
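# Usage sketch: row is typically a pandas Series handed in via DataFrame.apply,
# but any mutable mapping works; the column names below are hypothetical.
#   row = {'a_tokens': ['我', '爱', '北', '京'], 'q_tokens': ['哪', '里'],
#          'ans_tokens': ['北', '京']}
#   row = _find_golden_span(row, 'a_tokens', 'q_tokens', 'ans_tokens')
#   (row['answer_token_start'], row['answer_token_end'])  # -> (2, 3)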
def find_best_match_support_para(support_text, doc_content):
    """Slide a window of len(support_text) over doc_content and return the
    approximate position with the highest Rouge-L (coarse-grained)."""
    if support_text in doc_content:
        best_start = doc_content.index(support_text)
        best_end = best_start + len(support_text) - 1
        return best_start, best_end, 1

    if support_text.endswith('。') and support_text[:-1] in doc_content:
        sub_text = support_text[:-1]
        best_start = doc_content.index(sub_text)
        best_end = best_start + len(sub_text) - 1
        return best_start, best_end, 1

    # Some samples are mislabeled and can only be located after removing spaces.
    if support_text.replace(' ', '') in doc_content:
        sub_text = support_text.replace(' ', '')
        best_start = doc_content.index(sub_text)
        best_end = best_start + len(sub_text) - 1
        return best_start, best_end, 1

    support_para_chars = set(support_text)
    window_len = len(support_text)

    # The doc and the support text cannot be aligned exactly; score each window.
    best_score = 0
    best_start = -1
    best_end = -1

    start = 0
    while start < len(doc_content) - window_len - 1:
        # Skip ahead to the next char that also occurs in the support text.
        while start < len(doc_content) and \
                doc_content[start] not in support_para_chars:
            start += 1
        end = start + window_len
        sub_content = doc_content[start:end + 1]
        score = RougeL().add_inst(cand=sub_content,
                                  ref=support_text).get_score()
        if score > best_score:
            best_score = score
            best_start = start
            best_end = end
        start += 1

    if best_score == 0:
        return -1, -1, 0
    return best_start, best_end, best_score
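# Usage sketch (hypothetical strings): an exact substring match short-circuits
# with score 1.
#   >>> find_best_match_support_para('白菜', '今天白菜特价')
#   (2, 3, 1)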
def calc_ceil_rougel(answer_text, sample):
    """Compute the extracted fake answers and the corresponding ceiling Rouge-L."""
    fake_answers = [
        sample['documents'][answer_label[0]]['content']
        [answer_label[1]:answer_label[2] + 1]
        for answer_label in sample['answer_labels']
    ]
    sample['fake_answers'] = fake_answers

    if len(fake_answers) == 0:
        sample['ceil_rougel'] = 0
    else:
        ceil_rougel = RougeL().add_inst(cand=''.join(fake_answers).lower(),
                                        ref=answer_text.lower()).get_score()
        sample['ceil_rougel'] = ceil_rougel
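# Usage sketch: answer_labels are (doc_id, char_start, char_end) triples with
# an inclusive end, so content[start:end + 1] recovers the labeled span.
#   sample = {'documents': [{'content': '深圳天气晴'}],
#             'answer_labels': [(0, 2, 3)]}
#   calc_ceil_rougel('天气', sample)
#   sample['fake_answers'], sample['ceil_rougel']  # -> (['天气'], 1.0)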
def _find_golden_span_v2(row, article_tokens_col, question_tokens_col,
                         answer_tokens_col):
    article_tokens = row[article_tokens_col]
    question_tokens = row[question_tokens_col]
    answer_tokens = row[answer_tokens_col]
    row['answer_token_start'] = -1
    row['answer_token_end'] = -1
    rl_ans = RougeL()
    rl_q = RougeL()
    ground_ans = ''.join(answer_tokens).strip()
    question_str = ''.join(question_tokens).strip()
    len_p = len(article_tokens)
    len_a = len(answer_tokens)
    s2 = set(ground_ans)
    spans = []
    for i in range(len_p - len_a + 1):
        for t_len in range(len_a - 2, len_a + 3):
            if t_len == 0 or i + t_len > len_p:
                continue
            cand_ans = ''.join(article_tokens[i:i + t_len]).strip()
            s1 = set(cand_ans)
            mlen = max(len(s1), len(s2))
            iou = len(s1.intersection(s2)) / mlen if mlen != 0 else 0.0
            if iou > 0.3:
                # Score the candidate against the answer, and a slightly wider
                # context window (5 tokens each side) against the question.
                s = max(i - 5, 0)
                cand_ctx = ''.join(article_tokens[s:i + t_len + 5]).strip()
                rl_ans.add_inst(cand_ans, ground_ans)
                rl_q.add_inst(cand_ctx, question_str)
                spans.append([i, i + t_len - 1])
    if len(spans) == 0:
        return row
    sim_ans = np.array(rl_ans.inst_scores)
    sim_q = np.array(rl_q.r_scores)
    # Blend answer similarity with question/context similarity.
    total_score = 0.7 * sim_ans + 0.3 * sim_q
    best_idx = total_score.argmax()
    row['answer_token_start'] = spans[best_idx][0]
    row['answer_token_end'] = spans[best_idx][1]
    return row
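# Compared with _find_golden_span above, the v2 variant keeps no delta_* lists:
# it ranks candidate spans by 0.7 * Rouge-L(candidate, answer) plus
# 0.3 * Rouge recall(context window, question) and writes back only the single
# best (start, end) pair.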
def _sample_article(row, article_tokens_col, article_flags_col,
                    question_tokens_col, max_token_len=400):
    """Down-sample an over-long article to at most max_token_len tokens,
    keeping the sentences most relevant to the question.

    :param row: sample row (mutable mapping, e.g. a pandas Series)
    :param article_tokens_col: column holding the article tokens
    :param article_flags_col: column holding the per-token flags
    :param question_tokens_col: column holding the question tokens
    :param max_token_len: maximum number of article tokens to keep
    :return: row with the article tokens/flags columns truncated in place
    """
    article_tokens = row[article_tokens_col]
    article_flags = row[article_flags_col]
    question_tokens = row[question_tokens_col]

    if len(article_tokens) <= max_token_len:
        return row

    sentences, sentences_f = [], []
    cur_s, cur_s_f = [], []
    question = ''.join(question_tokens)
    cand, cand_f = [], []
    rl = RougeL()
    # Split the article into sentences and score each against the question.
    for idx, (token, flag) in enumerate(zip(article_tokens, article_flags)):
        cur_s.append(token)
        cur_s_f.append(flag)
        if token in '\001。' or idx == len(article_tokens) - 1:
            if len(cur_s) >= 2:
                sentences.append(cur_s)
                sentences_f.append(cur_s_f)
                rl.add_inst(''.join(cur_s), question)
            cur_s, cur_s_f = [], []

    scores = rl.r_scores
    s_rank = np.zeros(len(sentences))
    arg_sorted = list(reversed(np.argsort(scores)))
    # Spread each top sentence's score onto its neighbours so the context
    # around a hit is kept as well (full score for the direct neighbours,
    # decaying for the ones further out).
    for i in range(10):
        if i >= len(sentences):
            break
        pos = arg_sorted[i]
        if pos in [0, 1, len(sentences) - 1]:
            continue
        score = scores[pos]
        nb_score = score
        fnb_score = 0.5 * score
        ffnb_score = 0.25 * score
        block_scores = np.array(
            [fnb_score, nb_score, score, nb_score, fnb_score, ffnb_score])
        block = s_rank[pos - 2:pos + 4]
        block_scores = block_scores[:len(block)]
        block_scores = np.max(np.stack([block_scores, block]), axis=0)
        s_rank[pos - 2:pos + 4] = block_scores

    # Always keep the first two sentences and the last one.
    cand.extend(sentences[0])
    cand_f.extend(sentences_f[0])
    cand.extend(sentences[1])
    cand_f.extend(sentences_f[1])
    cand.extend(sentences[-1])
    cand_f.extend(sentences_f[-1])

    rank = list(reversed(np.argsort(s_rank)))
    for pos in rank:
        if pos in [0, 1, len(sentences) - 1]:
            continue
        if s_rank[pos] > 0:
            cand.extend(sentences[pos])
            cand_f.extend(sentences_f[pos])
            if len(cand) > max_token_len:
                break
        else:
            break
    row[article_tokens_col] = cand[:max_token_len]
    row[article_flags_col] = cand_f[:max_token_len]
    return row
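# Note: sentences[0], sentences[1] and sentences[-1] are always kept first
# (in news-style articles the lead and closing sentences often carry the
# answer); the remaining sentences are appended in s_rank order until
# max_token_len is exceeded.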
class TopicHandler:
    def __init__(self, is_train):
        self.details = {}
        self.details[Handler.CONTENT_TYPE_TITLE] = {
            'times': 0,
            'found_times': 0,
        }
        if is_train:
            self.details[Handler.CONTENT_TYPE_TITLE]['rouge_scores'] = []
        self.rouge_eval = RougeL()

    def ans_question(self, content, question, question_ans=None):
        """
        Answer topic/gist questions directly.

        :param content:
        :param question:
        :param content_type: distinguishes article from title (note: topic
            questions are currently answered from the title by default)
        :return:
        """
        topic_key_words = ['主旨', '大意', '内容', '文章说了什么', '介绍了什么']
        for key in topic_key_words:
            if question.find(key) >= 0:
                if (question.find('文') >= 0 or question.find('报道') > 0) \
                        and len(question) <= 12:
                    pred_ans = content
                    # Drop a trailing parenthesized suffix from the title,
                    # checking full-width parentheses first, then ASCII ones.
                    if pred_ans.strip().endswith(')'):
                        try:
                            pred_ans = pred_ans[:pred_ans.rindex('(')] + \
                                pred_ans[pred_ans.rindex(')') + 1:]
                        except ValueError:
                            pass
                    elif pred_ans.strip().endswith(')'):
                        try:
                            pred_ans = pred_ans[:pred_ans.rindex('(')] + \
                                pred_ans[pred_ans.rindex(')') + 1:]
                        except ValueError:
                            pass
                    if question_ans is not None:
                        self.record_found(pred_ans, question_ans)
                    return pred_ans
        self.record_miss()
        return None

    def record_found(self, pred_ans, gt_ans):
        self.details[Handler.CONTENT_TYPE_TITLE]['times'] += 1
        self.details[Handler.CONTENT_TYPE_TITLE]['found_times'] += 1
        score = self.rouge_eval.calc_score(pred_ans, gt_ans)
        self.details[Handler.CONTENT_TYPE_TITLE]['rouge_scores'].append(score)

    def record_miss(self):
        self.details[Handler.CONTENT_TYPE_TITLE]['times'] += 1

    def describe(self):
        """Report how this handler has performed so far."""
        print('===================Topic Handler===================')
        print('Content Type: 【%s】 ' % Handler.CONTENT_TYPE_TITLE)
        print('Total times:',
              self.details[Handler.CONTENT_TYPE_TITLE]['times'], end='\t\t')
        print('Found times:',
              self.details[Handler.CONTENT_TYPE_TITLE]['found_times'],
              end='\t\t')
        score_list = self.details[Handler.CONTENT_TYPE_TITLE].get(
            'rouge_scores', [])
        if len(score_list) == 0:
            rouge_score = -1
        else:
            rouge_score = sum(score_list) / len(score_list)
        print('Rouge-L avg score:', rouge_score)
        print()
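# Usage sketch (hypothetical title and question; assumes the Handler constants
# from this repo):
#   handler = TopicHandler(is_train=False)
#   handler.ans_question('八达岭长城迎来旅游高峰(图)', '这篇文章的主旨是什么?')
#   # -> '八达岭长城迎来旅游高峰' (trailing '(图)' stripped)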
def gen_bridging_entity_mrc_dataset(sample):
    """Build the full-text MRC dataset targeting the bridging entity."""
    # Locate the sub-paragraphs referenced by the supporting paragraph.
    support_para_in_docids = find_answer_in_docid(
        sample['supporting_paragraph'])

    # {doc_id containing the support: [(best-matching support para, start, end)]}
    supported_paras = {}
    for sup_para_in_docid in support_para_in_docids:
        para_strs = sample['supporting_paragraph'].split(
            '@content{}@'.format(sup_para_in_docid))
        for para_str in para_strs:
            if para_str != '' and '@content' not in para_str:
                para_str = para_str.replace(
                    'content{}@'.format(sup_para_in_docid), '')
                sup_start, sup_end, rougel = find_best_match_support_para(
                    para_str,
                    sample['documents'][sup_para_in_docid - 1]['content'])
                found_sup_para = sample['documents'][
                    sup_para_in_docid - 1]['content'][sup_start:sup_end + 1]
                # The same doc may contain several support paras.
                if sup_para_in_docid in supported_paras:
                    supported_paras[sup_para_in_docid].append(
                        (found_sup_para, sup_start, sup_end))
                else:
                    supported_paras[sup_para_in_docid] = [(found_sup_para,
                                                           sup_start,
                                                           sup_end)]

    bridging_entity = sample['bridging_entity']

    # Samples without a bridging entity.
    if bridging_entity is None:
        sample['bridging_entity_labels'] = []
        sample['fake_bridging_entity'] = None
        sample['ceil_rougel'] = -1
        return

    max_rougel = 0
    best_start_in_sup_para = -1
    best_end_in_sup_para = -1
    best_sup_doc_i = None
    best_sup_para_i = None
    bridging_entity_labels = []
    for sup_para_in_docid in support_para_in_docids:
        doc_support_paras = supported_paras[sup_para_in_docid]
        for sup_para_i, doc_support_para in enumerate(doc_support_paras):
            start_in_sup_para, end_in_sup_para, rougel = find_best_match_answer(
                bridging_entity, doc_support_para[0])
            if rougel > max_rougel:
                max_rougel = rougel
                best_start_in_sup_para = start_in_sup_para
                best_end_in_sup_para = end_in_sup_para
                best_sup_doc_i = sup_para_in_docid
                best_sup_para_i = sup_para_i

    if best_start_in_sup_para != -1 and best_end_in_sup_para != -1:
        # Shift the in-paragraph offsets back into document coordinates.
        start_label = best_start_in_sup_para + supported_paras[best_sup_doc_i][
            best_sup_para_i][1]
        end_label = start_label + (best_end_in_sup_para -
                                   best_start_in_sup_para)
        bridging_entity_labels = (best_sup_doc_i - 1, start_label, end_label)

    sample['bridging_entity_labels'] = bridging_entity_labels

    if not bridging_entity_labels:
        sample['fake_bridging_entity'] = ''
    else:
        sample['fake_bridging_entity'] = sample['documents'][bridging_entity_labels[0]]['content'] \
            [bridging_entity_labels[1]: bridging_entity_labels[2] + 1]

    if sample['fake_bridging_entity'] == '':
        sample['ceil_rougel'] = 0
    else:
        ceil_rougel = RougeL().add_inst(
            cand=sample['fake_bridging_entity'].lower(),
            ref=bridging_entity.lower()).get_score()
        sample['ceil_rougel'] = ceil_rougel
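# Label convention: bridging_entity_labels is (doc_index, char_start, char_end),
# where doc_index is 0-based (the doc ids embedded in supporting_paragraph
# markers are 1-based, hence the `- 1`) and char_end is inclusive.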
def sample_train_content(sample, max_train_content_len,
                         min_left_context_len=100, min_right_context_len=50):
    """Down-sample the content of full-length training samples, using a sliding
    window to keep the content short while preserving enough answer coverage.

    Args:
        max_train_content_len: maximum length of the truncated train content
        min_left_context_len: minimum context length left of the answer
        min_right_context_len: minimum context length right of the answer
    """
    al = sample['bridging_entity_labels']
    if al:
        answer_in_docs = {al[0]: (al[1], al[2])}
    else:
        answer_in_docs = {}

    sample['ques_char_pos'] = dense_feature_list(sample['ques_char_pos'])
    sample['ques_char_kw'] = dense_feature_list(sample['ques_char_kw'])
    sample['ques_char_in_que'] = dense_feature_list(sample['ques_char_in_que'])
    sample['ques_char_entity'] = dense_feature_list(
        sample['ques_char_entity'].split(','))

    for doc_id, doc in enumerate(sample['documents']):
        # Docs that do not contain the answer are truncated directly.
        if doc_id not in answer_in_docs:
            # feature update
            split_features(doc, 0, max_train_content_len)
        else:
            # Docs that contain the answer are cut according to where the
            # answer sits relative to max_train_content_len.
            start = answer_in_docs[doc_id][0]
            end = answer_in_docs[doc_id][1]

            # Left context is short enough: the answer fits inside the leading
            # max_train_content_len chars.
            if end <= max_train_content_len - min_right_context_len:
                # feature update
                split_features(doc, 0, max_train_content_len)
            # Right context is short enough: the answer fits inside the
            # trailing max_train_content_len chars.
            elif len(doc['content']) - start + min_left_context_len <= max_train_content_len:
                new_ans_start_idx = start - (len(doc['content']) -
                                             max_train_content_len)
                new_ans_end_idx = new_ans_start_idx + (end - start)
                # feature update
                split_features(doc, len(doc['content']) - max_train_content_len,
                               len(doc['content']))
                # Shift the answer indices accordingly.
                sample['bridging_entity_labels'] = (doc_id, new_ans_start_idx,
                                                    new_ans_end_idx)
            # Both contexts are long: cut a window roughly centered on the answer.
            else:
                cut_doc_where_answer_in(sample, doc_id, answer_in_docs,
                                        max_train_content_len,
                                        min_left_context_len,
                                        min_right_context_len)

    if sample['bridging_entity'] is not None:
        bridging_entity_labels = sample['bridging_entity_labels']
        if len(bridging_entity_labels) > 0:
            sample['bridging_entity_labels'] = bridging_entity_labels
            sample['fake_bridging_entity'] = sample['documents'][bridging_entity_labels[0]]['content'] \
                [bridging_entity_labels[1]: bridging_entity_labels[2] + 1]
        else:
            sample['bridging_entity_labels'] = []
            sample['fake_bridging_entity'] = ''

        if sample['fake_bridging_entity'] == '':
            sample['ceil_rougel'] = 0
        else:
            ceil_rougel = RougeL().add_inst(
                cand=sample['fake_bridging_entity'].lower(),
                ref=sample['bridging_entity'].lower()).get_score()
            sample['ceil_rougel'] = ceil_rougel
    else:
        sample['bridging_entity_labels'] = []
        sample['fake_bridging_entity'] = None
        sample['ceil_rougel'] = -1
def evaluate(self, eval_batches, result_dir=None, result_prefix=None,
             save_full_info=False):
    """
    Evaluates the model performance on eval_batches; results are saved if specified.

    Args:
        eval_batches: iterable batch data
        result_dir: directory to save predicted answers; answers will not be
            saved if None
        result_prefix: prefix of the file for saving predicted answers;
            answers will not be saved if None
        save_full_info: if True, the pred_answers will be added to the raw
            sample and saved
    """
    pred_answers, ref_answers = [], []
    total_mrl, total_pointer_loss, total_num = 0, 0, 0
    rl, bleu = RougeL(), Bleu()
    article_map = {}
    for b_itx, batch in enumerate(eval_batches):
        feed_dict = {
            self.p_t: batch['article_token_ids'],
            self.q_t: batch['question_token_ids'],
            self.p_f: batch['article_flag_ids'],
            self.q_f: batch['question_flag_ids'],
            self.p_e: batch['article_elmo_ids'],
            self.q_e: batch['question_elmo_ids'],
            self.p_pad_len: batch['article_pad_len'],
            self.q_pad_len: batch['question_pad_len'],
            self.p_t_length: batch['article_tokens_len'],
            self.q_t_length: batch['question_tokens_len'],
            self.start_label: batch['start_id'],
            self.end_label: batch['end_id'],
            self.wiqB: batch['wiqB'],
            self.qtype_vec: batch['qtype_vecs'],
            # delta stuff
            self.delta_starts: batch['delta_token_starts'],
            self.delta_ends: batch['delta_token_ends'],
            self.delta_span_idxs: batch['delta_span_idxs'],
            self.delta_rouges: batch['delta_rouges'],
            self.dropout_keep_prob: 1.0
        }
        if self.use_char_emb:
            feed_dict.update({
                self.p_c: batch['article_char_ids'],
                self.q_c: batch['question_char_ids'],
                self.p_c_length: batch['article_c_len'],
                self.q_c_length: batch['question_c_len'],
                self.p_CL: batch['article_CL'],
                self.q_CL: batch['question_CL']
            })

        pred_starts, pred_ends, mrl, pointer_loss = self.sess.run(
            [self.pred_starts, self.pred_ends, self.mrl, self.pointer_loss],
            feed_dict)

        batch_size = len(batch['raw_data'])
        total_mrl += mrl * batch_size
        total_pointer_loss += pointer_loss * batch_size
        total_num += batch_size

        for sample, best_start, best_end in zip(batch['raw_data'],
                                                pred_starts, pred_ends):
            best_answer = ''.join(
                sample['article_tokens'][best_start:best_end + 1])

            if sample['article_id'] not in article_map:
                article_map[sample['article_id']] = len(article_map)
                pred_answers.append({
                    'article_id': sample['article_id'],
                    'questions': []
                })
                ref_answers.append({
                    'article_id': sample['article_id'],
                    'questions': []
                })

            pred_answers[article_map[
                sample['article_id']]]['questions'].append({
                    'question_id': sample['question_id'],
                    'answer': best_answer
                })
            ref_answers[article_map[
                sample['article_id']]]['questions'].append({
                    'question_id': sample['question_id'],
                    'answer': sample['answer']
                })

            rl.add_inst(best_answer, sample['answer'])
            bleu.add_inst(best_answer, sample['answer'])

    # compute the bleu and rouge scores
    rougel = rl.get_score()
    bleu4 = bleu.get_score()
    bleu_rouge = {'Rouge-L': rougel, 'Bleu-4': bleu4}

    if result_dir is not None and result_prefix is not None:
        result_file = os.path.join(result_dir, result_prefix + '.json')
        with open(result_file, 'w') as fout:
            json.dump(pred_answers, fout, ensure_ascii=False)
        self.logger.info('Saving {} results to {}'.format(
            result_prefix, result_file))

    # this average loss is invalid on the test set, since we don't have the
    # true start_id and end_id
    ave_mrl = 1.0 * total_mrl / total_num
    ave_pointer_loss = 1.0 * total_pointer_loss / total_num
    return ave_mrl, ave_pointer_loss, bleu_rouge
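# Usage sketch (hypothetical names; self is the model wrapper that owns the
# TF session and the placeholders fed above):
#   ave_mrl, ave_ploss, bleu_rouge = model.evaluate(dev_batches,
#                                                   result_dir='results',
#                                                   result_prefix='dev')
#   bleu_rouge  # -> {'Rouge-L': ..., 'Bleu-4': ...}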
class BasicTempleteHandler(object):
    def __init__(self, feature_words_path, stop_words_path, is_train):
        self.name = 'Basic Templete Handler'
        with open(feature_words_path, encoding='utf-8') as f:
            self.feature_words = [line.strip() for line in f.readlines()]
        with open(stop_words_path, encoding='utf-8') as f:
            self.stop_words = [line.strip() for line in f.readlines()]

        self.details = {
            'times': 0
        }
        self.details[Handler.CONTENT_TYPE_TITLE] = {
            'found_times': 0,
        }
        self.details[Handler.CONTENT_TYPE_ARTICLE] = {
            'found_times': 0,
        }
        if is_train:
            self.details[Handler.CONTENT_TYPE_TITLE]['rouge_scores'] = []
            self.details[Handler.CONTENT_TYPE_ARTICLE]['rouge_scores'] = []
        self.rouge_eval = RougeL()

    def ans_question(self, title, article, question, question_ans=None):
        self.details['times'] += 1
        templetes = identify_templete(question, self.feature_words,
                                      self.stop_words)
        # Try each template against the title first, then the article body.
        for templete in templetes:
            found_ans = match_content(templete, title, self.feature_words)
            if found_ans is not None:
                if question_ans is not None:
                    self.record_found(Handler.CONTENT_TYPE_TITLE, found_ans,
                                      question_ans)
                return found_ans

            found_ans = match_content(templete, article, self.feature_words)
            if found_ans is not None:
                if question_ans is not None:
                    self.record_found(Handler.CONTENT_TYPE_ARTICLE, found_ans,
                                      question_ans)
                return found_ans
        return None

    def record_found(self, content_type, pred_ans, gt_ans):
        self.details[content_type]['found_times'] += 1
        score = self.rouge_eval.calc_score(pred_ans, gt_ans)
        self.details[content_type]['rouge_scores'].append(score)

    def describe(self):
        title_score_list = self.details[Handler.CONTENT_TYPE_TITLE].get(
            'rouge_scores', [])
        article_score_list = self.details[Handler.CONTENT_TYPE_ARTICLE].get(
            'rouge_scores', [])

        print('===================%s===================' % self.name)
        print('Total times:', self.details['times'], end='\t\t')
        print('Total found times:',
              self.details[Handler.CONTENT_TYPE_TITLE]['found_times'] +
              self.details[Handler.CONTENT_TYPE_ARTICLE]['found_times'],
              end='\t\t')
        print('Total Rouge-L avg score:',
              _get_rouge_avg_socre(title_score_list + article_score_list))

        print('Content Type: 【%s】 ' % Handler.CONTENT_TYPE_TITLE)
        print('Found times:',
              self.details[Handler.CONTENT_TYPE_TITLE]['found_times'],
              end='\t\t')
        print('Rouge-L avg score:', _get_rouge_avg_socre(title_score_list))

        print('Content Type: 【%s】 ' % Handler.CONTENT_TYPE_ARTICLE)
        print('Found times:',
              self.details[Handler.CONTENT_TYPE_ARTICLE]['found_times'],
              end='\t\t')
        print('Rouge-L avg score:', _get_rouge_avg_socre(article_score_list))
        print()
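# Usage sketch (hypothetical paths; identify_templete, match_content,
# _get_rouge_avg_socre and the Handler constants come from elsewhere in this
# repo):
#   handler = BasicTempleteHandler('feature_words.txt', 'stop_words.txt',
#                                  is_train=False)
#   ans = handler.ans_question(title, article, question)
#   handler.describe()  # prints hit counts and average Rouge-L per content type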