def prediction_score_analysis_adaptive_threshold(raw_data, predictions, prediction_scores, at_thresholds):
    def positive_neg_score(scores, mask, names, gold_names, threshold_score, pred_names):
        assert len(scores) == len(mask)
        mask_sum_num = int(sum(mask))
        prune_names = names[:mask_sum_num]
        gold_name_set = set(gold_names)
        flag = gold_name_set.issubset(set(prune_names))
        positive_scores = []
        negative_scores = []
        for idx in range(mask_sum_num):
            name_i = prune_names[idx]
            if name_i in gold_name_set:
                positive_scores.append(scores[idx])
            else:
                negative_scores.append(scores[idx])
        if len(positive_scores) > 0:
            min_positive = min(positive_scores)
        else:
            min_positive = 0.0
        if len(negative_scores) == 0:
            max_negative = 1.0
        else:
            max_negative = max(negative_scores)
        num_candidates = mask_sum_num
        num_golds = len(gold_name_set)
        min_p_names = []
        max_n_names = []
        threshold_names = []
        # Binarize the adaptive threshold once, before the scan; reassigning it
        # inside the loop would collapse it to 0.35 after the first iteration.
        threshold_score = 0.45 if threshold_score > 0.45 else 0.35
        for i in range(mask_sum_num):
            if scores[i] >= min_positive:
                min_p_names.append(names[i])
            if scores[i] > max_negative:
                max_n_names.append(names[i])
            if scores[i] > threshold_score:
                threshold_names.append(names[i])
        return flag, min_positive, max_negative, num_candidates, num_golds, \
            min_p_names, max_n_names, threshold_names

    threshold_metric_dict = {}
    threshold_metric_dict['pred'] = []
    threshold_metric_dict['min_p'] = []
    threshold_metric_dict['max_n'] = []
    threshold_metric_dict['threshold_n'] = []
    prune_gold_num = 0
    analysis_result_list = []
    for row in raw_data:
        qid = row['_id']
        threshold_score = at_thresholds[qid]
        question_type = row['type']
        answer_type = row['answer']
        if answer_type.strip().lower() not in ['yes', 'no']:
            answer_type = 'span'
        sp_predictions = predictions['sp'][qid]
        sp_predictions = [(x[0], x[1]) for x in sp_predictions]
        sp_para_predictions = list(set([x[0] for x in sp_predictions]))
        sp_golds = row['supporting_facts']
        sp_golds = [(x[0], x[1]) for x in sp_golds]
        sp_para_golds = list(set([_[0] for _ in sp_golds]))
        if qid == '5a8a4a4055429930ff3c0d77':
            print(sp_predictions)
            print(sp_golds)
        res_scores = prediction_scores[qid]
        sp_scores = res_scores['sp_score']
        sp_mask = res_scores['sp_mask']
        sp_names = res_scores['sp_names']
        sp_names = [(x[0], x[1]) for x in sp_names]
        flag, min_positive, max_negative, num_candidates, num_golds, min_p_names, max_n_names, threshold_names = \
            positive_neg_score(scores=sp_scores, mask=sp_mask, names=sp_names,
                               gold_names=sp_golds, threshold_score=threshold_score,
                               pred_names=sp_predictions)
        ans_prediction = predictions['answer'][qid]
        raw_answer = row['answer']
        raw_answer = normalize_answer(raw_answer)
        ans_prediction = normalize_answer(ans_prediction)
        ans_metrics = update_answer(prediction=ans_prediction, gold=raw_answer)
        predict_metrics = update_sp(prediction=sp_predictions, gold=sp_golds)
        threshold_metric_dict['pred'].append((ans_metrics, predict_metrics))
        min_p_metrics = update_sp(prediction=min_p_names, gold=sp_golds)
        threshold_metric_dict['min_p'].append((ans_metrics, min_p_metrics))
        max_n_metrics = update_sp(prediction=max_n_names, gold=sp_golds)
        threshold_metric_dict['max_n'].append((ans_metrics, max_n_metrics))
        threshold_metrics = update_sp(prediction=threshold_names, gold=sp_golds)
        threshold_metric_dict['threshold_n'].append((ans_metrics, threshold_metrics))
        if not flag:
            prune_gold_num += 1
        sp_sent_type = set_comparison(prediction_list=sp_predictions, true_list=sp_golds)
        # for key, value in sp_scores.items():
        #     print(key, value)
        # print('{}\t{}\t{}\t{:.5f}\t{:.5f}'.format(question_type, sp_sent_type, flag, min_positive, max_negative))
        analysis_result_list.append(
            (qid, question_type, sp_sent_type, flag, min_positive, max_negative,
             threshold_score, num_candidates, num_golds, answer_type))
    for key, value in threshold_metric_dict.items():
        print('threshold type = {}'.format(key))
        answer_em, answer_prec, answer_recall, answer_f1 = 0.0, 0.0, 0.0, 0.0
        sp_em, sp_prec, sp_recall, sp_f1 = 0.0, 0.0, 0.0, 0.0
        type_count = len(value)
        all_joint_em, all_joint_f1 = 0.0, 0.0
        for ans_tup, sp_tup in value:
            answer_em += ans_tup[0]
            answer_prec += ans_tup[1]
            answer_recall += ans_tup[2]
            answer_f1 += ans_tup[3]
            sp_em += sp_tup[0]
            sp_prec += sp_tup[1]
            sp_recall += sp_tup[2]
            sp_f1 += sp_tup[3]
            joint_prec = ans_tup[1] * sp_tup[1]
            joint_recall = ans_tup[2] * sp_tup[2]
            if joint_prec + joint_recall > 0:
                joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall)
            else:
                joint_f1 = 0.
            joint_em = ans_tup[0] * sp_tup[0]
            all_joint_f1 += joint_f1
            all_joint_em += joint_em
        print('ans {}\t{}\t{}\t{}'.format(answer_em / type_count, answer_recall / type_count,
                                          answer_prec / type_count, answer_f1 / type_count))
        print('sup {}\t{}\t{}\t{}'.format(sp_em / type_count, sp_recall / type_count,
                                          sp_prec / type_count, sp_f1 / type_count))
        print('joint em ', all_joint_em / type_count)
        print('joint f1 ', all_joint_f1 / type_count)
    df = pd.DataFrame(analysis_result_list,
                      columns=['id', 'q_type', 'sp_sent_type', 'flag', 'min_p', 'max_n',
                               'threshold', 'cand_num', 'gold_num', 'ans_type'])
    print('prune = {}, complete = {}'.format(prune_gold_num, len(raw_data) - prune_gold_num))
    return df
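##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Minimal, self-contained sketch (illustration only, not called by the pipeline): it
## replays the core of positive_neg_score on invented toy scores, showing how the
## minimum gold ("positive") score and the maximum non-gold ("negative") score bound
## any threshold that would recover the gold sentences exactly.
def _demo_positive_negative_split():
    scores = [0.91, 0.80, 0.42, 0.15]
    names = [('doc_a', 0), ('doc_a', 1), ('doc_b', 0), ('doc_b', 1)]
    golds = {('doc_a', 0), ('doc_a', 1)}
    positive_scores = [s for s, n in zip(scores, names) if n in golds]
    negative_scores = [s for s, n in zip(scores, names) if n not in golds]
    min_positive = min(positive_scores) if positive_scores else 0.0
    max_negative = max(negative_scores) if negative_scores else 1.0
    # Any threshold t with max_negative < t <= min_positive (here 0.42 < t <= 0.80)
    # selects exactly the gold sentences; min_positive <= max_negative would mean no
    # single threshold can separate gold from non-gold for this question.
    return min_positive, max_negative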
def error_analysis_question_type(raw_data, predictions, tokenizer, use_ent_ans=False):
    type_metric_dict = {}
    for row in raw_data:
        question_type = row['type']
        if question_type not in type_metric_dict:
            type_metric_dict[question_type] = []
        qid = row['_id']
        sp_predictions = predictions['sp'][qid]
        sp_predictions = [(x[0], x[1]) for x in sp_predictions]
        sp_golds = row['supporting_facts']
        sp_golds = [(x[0], x[1]) for x in sp_golds]
        sp_metrics = update_sp(prediction=sp_predictions, gold=sp_golds)
        if qid == '5add114a5542994734353826':
            for x in row['context']:
                print('title ', x[0])
                for y_idx, y in enumerate(x[1]):
                    print('sentence', y_idx, y)
        ans_prediction = predictions['answer'][qid]
        raw_answer = row['answer']
        raw_answer = normalize_answer(raw_answer)
        ans_prediction = normalize_answer(ans_prediction)
        ans_metrics = update_answer(prediction=ans_prediction, gold=raw_answer)
        type_metric_dict[question_type].append((ans_metrics, sp_metrics))
    for key, value in type_metric_dict.items():
        print('question type = {}'.format(key))
        answer_em, answer_prec, answer_recall, answer_f1 = 0.0, 0.0, 0.0, 0.0
        sp_em, sp_prec, sp_recall, sp_f1 = 0.0, 0.0, 0.0, 0.0
        type_count = len(value)
        all_joint_em, all_joint_f1 = 0.0, 0.0
        for ans_tup, sp_tup in value:
            answer_em += ans_tup[0]
            answer_prec += ans_tup[1]
            answer_recall += ans_tup[2]
            answer_f1 += ans_tup[3]
            sp_em += sp_tup[0]
            sp_prec += sp_tup[1]
            sp_recall += sp_tup[2]
            sp_f1 += sp_tup[3]
            joint_prec = ans_tup[1] * sp_tup[1]
            joint_recall = ans_tup[2] * sp_tup[2]
            if joint_prec + joint_recall > 0:
                joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall)
            else:
                joint_f1 = 0.
            joint_em = ans_tup[0] * sp_tup[0]
            all_joint_f1 += joint_f1
            all_joint_em += joint_em
        print('ans {}\t{}\t{}\t{}'.format(answer_em / type_count, answer_recall / type_count,
                                          answer_prec / type_count, answer_f1 / type_count))
        print('sup {}\t{}\t{}\t{}'.format(sp_em / type_count, sp_recall / type_count,
                                          sp_prec / type_count, sp_f1 / type_count))
        print('joint em ', all_joint_em / type_count)
        print('joint f1 ', all_joint_f1 / type_count)
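##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Worked example (invented numbers) of the joint-metric arithmetic used in the
## aggregation loops above: per question, joint precision/recall are the products of
## the answer and supporting-fact values, and joint F1 is their harmonic mean.
def _demo_joint_f1(ans_prec=0.8, ans_recall=1.0, sp_prec=0.9, sp_recall=0.75):
    joint_prec = ans_prec * sp_prec        # 0.8 * 0.9  = 0.72
    joint_recall = ans_recall * sp_recall  # 1.0 * 0.75 = 0.75
    if joint_prec + joint_recall > 0:
        joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall)
    else:
        joint_f1 = 0.
    return joint_f1  # 2 * 0.72 * 0.75 / 1.47 ~= 0.7347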
def prediction_score_gap_train_analysis(raw_data, predictions, prediction_scores, train_type=None):
    def score_gap_split(scores, mask, names):
        assert len(scores) == len(mask)
        mask_sum_num = int(sum(mask))
        prune_names = names[:mask_sum_num]
        prune_scores = np.array(scores[:mask_sum_num])
        sorted_idxes = np.argsort(prune_scores)[::-1]
        largest_gap = -1
        max_gap_idx = -1
        for i in range(1, mask_sum_num - 1):
            gap = prune_scores[sorted_idxes[i]] - prune_scores[sorted_idxes[i + 1]]
            if gap > largest_gap:
                largest_gap = gap
                max_gap_idx = i
        pred_idxes = sorted_idxes[:(max_gap_idx + 1)]
        gap_names = [prune_names[_] for _ in pred_idxes]
        return gap_names

    def positive_neg_score(scores, mask, names, gold_names, pred_names):
        assert len(scores) == len(mask)
        mask_sum_num = int(sum(mask))
        prune_names = names[:mask_sum_num]
        gold_name_set = set(gold_names)
        flag = gold_name_set.issubset(set(prune_names))
        positive_scores = []
        negative_scores = []
        for idx in range(mask_sum_num):
            name_i = prune_names[idx]
            if name_i in gold_name_set:
                positive_scores.append(scores[idx])
            else:
                negative_scores.append(scores[idx])
        if len(positive_scores) > 0:
            min_positive = min(positive_scores)
        else:
            min_positive = 0.0
        if len(negative_scores) == 0:
            max_negative = 1.0
        else:
            max_negative = max(negative_scores)
        num_candidates = mask_sum_num
        num_golds = len(gold_name_set)
        min_p_names = []
        max_n_names = []
        for i in range(mask_sum_num):
            if scores[i] >= min_positive:
                min_p_names.append(names[i])
            if scores[i] > max_negative:
                max_n_names.append(names[i])
        return flag, min_positive, max_negative, num_candidates, num_golds, min_p_names, max_n_names

    threshold_metric_dict = {}
    threshold_metric_dict['pred'] = []
    threshold_metric_dict['min_p'] = []
    threshold_metric_dict['max_n'] = []
    threshold_metric_dict['gap'] = []
    prune_gold_num = 0
    analysis_result_list = []
    # print(predictions['sp'])
    for row in raw_data:
        qid = row['_id']
        question_type = row['type']
        answer_type = row['answer']
        if train_type is not None:
            qid = qid + '_' + train_type
            # print(qid)
        if answer_type.strip().lower() not in ['yes', 'no']:
            answer_type = 'span'
        if qid not in predictions['sp']:
            continue
        sp_predictions = predictions['sp'][qid]
        sp_predictions = [(x[0], x[1]) for x in sp_predictions]
        sp_para_predictions = list(set([x[0] for x in sp_predictions]))
        sp_golds = row['supporting_facts']
        sp_golds = [(x[0], x[1]) for x in sp_golds]
        sp_para_golds = list(set([_[0] for _ in sp_golds]))
        res_scores = prediction_scores[qid]
        sp_scores = res_scores['sp_score']
        sp_mask = res_scores['sp_mask']
        sp_names = res_scores['sp_names']
        sp_names = [(x[0], x[1]) for x in sp_names]
        flag, min_positive, max_negative, num_candidates, num_golds, min_p_names, max_n_names = \
            positive_neg_score(scores=sp_scores, mask=sp_mask, names=sp_names,
                               gold_names=sp_golds, pred_names=sp_predictions)
        ##++++
        gap_names = score_gap_split(scores=sp_scores, mask=sp_mask, names=sp_names)
        ##++++
        ans_prediction = predictions['answer'][qid]
        raw_answer = row['answer']
        raw_answer = normalize_answer(raw_answer)
        ans_prediction = normalize_answer(ans_prediction)
        ans_metrics = update_answer(prediction=ans_prediction, gold=raw_answer)
        predict_metrics = update_sp(prediction=sp_predictions, gold=sp_golds)
        threshold_metric_dict['pred'].append((ans_metrics, predict_metrics))
        min_p_metrics = update_sp(prediction=min_p_names, gold=sp_golds)
        threshold_metric_dict['min_p'].append((ans_metrics, min_p_metrics))
        max_n_metrics = update_sp(prediction=max_n_names, gold=sp_golds)
        threshold_metric_dict['max_n'].append((ans_metrics, max_n_metrics))
        gap_metrics = update_sp(prediction=gap_names, gold=sp_golds)
        threshold_metric_dict['gap'].append((ans_metrics, gap_metrics))
        if not flag:
            prune_gold_num += 1
        sp_sent_type = set_comparison(prediction_list=sp_predictions, true_list=sp_golds)
        # for key, value in sp_scores.items():
        #     print(key, value)
        # print('{}\t{}\t{}\t{:.5f}\t{:.5f}'.format(question_type, sp_sent_type, flag, min_positive, max_negative))
        analysis_result_list.append(
            (qid, question_type, sp_sent_type, flag, min_positive, max_negative,
             num_candidates, num_golds, answer_type))
    for key, value in threshold_metric_dict.items():
        print('threshold type = {}'.format(key))
        answer_em, answer_prec, answer_recall, answer_f1 = 0.0, 0.0, 0.0, 0.0
        sp_em, sp_prec, sp_recall, sp_f1 = 0.0, 0.0, 0.0, 0.0
        type_count = len(value)
        all_joint_em, all_joint_f1 = 0.0, 0.0
        for ans_tup, sp_tup in value:
            answer_em += ans_tup[0]
            answer_prec += ans_tup[1]
            answer_recall += ans_tup[2]
            answer_f1 += ans_tup[3]
            sp_em += sp_tup[0]
            sp_prec += sp_tup[1]
            sp_recall += sp_tup[2]
            sp_f1 += sp_tup[3]
            joint_prec = ans_tup[1] * sp_tup[1]
            joint_recall = ans_tup[2] * sp_tup[2]
            if joint_prec + joint_recall > 0:
                joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall)
            else:
                joint_f1 = 0.
            joint_em = ans_tup[0] * sp_tup[0]
            all_joint_f1 += joint_f1
            all_joint_em += joint_em
        print('ans\t{}\t{}\t{}\t{}'.format(answer_em / type_count, answer_recall / type_count,
                                           answer_prec / type_count, answer_f1 / type_count))
        print('sup\t{}\t{}\t{}\t{}'.format(sp_em / type_count, sp_recall / type_count,
                                           sp_prec / type_count, sp_f1 / type_count))
        print('joint_em\t', all_joint_em / type_count)
        print('joint_f1\t', all_joint_f1 / type_count)
    df = pd.DataFrame(analysis_result_list,
                      columns=['id', 'q_type', 'sp_sent_type', 'flag', 'min_p', 'max_n',
                               'cand_num', 'gold_num', 'ans_type'])
    print('prune = {}, complete = {}'.format(prune_gold_num, len(raw_data) - prune_gold_num))
    return df
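##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Toy illustration (invented scores) of the largest-gap heuristic in score_gap_split:
## sort descending, find the biggest score drop at or after rank 2, and keep every
## candidate before the drop.
def _demo_score_gap_split():
    import numpy as np
    scores = np.array([0.95, 0.90, 0.88, 0.30, 0.10])
    sorted_idxes = np.argsort(scores)[::-1]
    largest_gap, max_gap_idx = -1.0, -1
    # i starts at 1 so that at least the top-2 candidates are always kept.
    for i in range(1, len(scores) - 1):
        gap = scores[sorted_idxes[i]] - scores[sorted_idxes[i + 1]]
        if gap > largest_gap:
            largest_gap, max_gap_idx = gap, i
    return sorted_idxes[:max_gap_idx + 1]  # indices of the top-3 scores here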
def error_analysis(raw_data, predictions, tokenizer, use_ent_ans=False):
    yes_no_span_predictions = []
    yes_no_span_true = []
    prediction_ans_type_counter = Counter()
    prediction_sent_type_counter = Counter()
    prediction_para_type_counter = Counter()
    pred_ans_type_list = []
    pred_sent_type_list = []
    pred_doc_type_list = []
    pred_sent_count_list = []
    pred_para_count_list = []
    ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    for row in raw_data:
        qid = row['_id']
        sp_predictions = predictions['sp'][qid]
        sp_predictions = [(x[0], x[1]) for x in sp_predictions]
        ans_prediction = predictions['answer'][qid]
        raw_answer = row['answer']
        raw_answer = normalize_answer(raw_answer)
        ans_prediction = normalize_answer(ans_prediction)
        sp_golds = row['supporting_facts']
        sp_golds = [(x[0], x[1]) for x in sp_golds]
        sp_para_golds = list(set([_[0] for _ in sp_golds]))
        ##+++++++++++
        # sp_predictions = [x for x in sp_predictions if x[0] in sp_para_golds]
        print("{}\t{}\t{}".format(qid, len(set(sp_golds)), len(set(sp_predictions))))
        sp_para_predictions = list(set([x[0] for x in sp_predictions]))
        pred_para_count_list.append(len(sp_para_predictions))
        # +++++++++++
        if len(set(sp_golds)) > len(set(sp_predictions)):
            pred_sent_count_list.append('less')
        elif len(set(sp_golds)) < len(set(sp_predictions)):
            pred_sent_count_list.append('more')
        else:
            pred_sent_count_list.append('equal')
        ##+++++++++++
        sp_sent_type = set_comparison(prediction_list=sp_predictions, true_list=sp_golds)
        ###+++++++++
        prediction_sent_type_counter[sp_sent_type] += 1
        pred_sent_type_list.append(sp_sent_type)
        ###+++++++++
        sp_para_preds = list(set([_[0] for _ in sp_predictions]))
        para_type = set_comparison(prediction_list=sp_para_preds, true_list=sp_para_golds)
        prediction_para_type_counter[para_type] += 1
        pred_doc_type_list.append(para_type)
        ###+++++++++
        if raw_answer not in ['yes', 'no']:
            yes_no_span_true.append('span')
        else:
            yes_no_span_true.append(raw_answer)
        if ans_prediction not in ['yes', 'no']:
            yes_no_span_predictions.append('span')
        else:
            yes_no_span_predictions.append(ans_prediction)
        ans_type = 'em'
        if raw_answer not in ['yes', 'no']:
            if raw_answer == ans_prediction:
                ans_type = 'em'
            elif raw_answer in ans_prediction:
                # print('{}: {} |{}'.format(qid, raw_answer, ans_prediction))
                # print('-'*75)
                ans_type = 'super_of_gold'
            elif ans_prediction in raw_answer:
                # print('{}: {} |{}'.format(qid, raw_answer, ans_prediction))
                # print('-'*75)
                ans_type = 'sub_of_gold'
            else:
                ans_pred_tokens = ans_prediction.split(' ')
                ans_raw_tokens = raw_answer.split(' ')
                is_empty_set = len(set(ans_pred_tokens).intersection(set(ans_raw_tokens))) == 0
                if is_empty_set:
                    ans_type = 'no_over_lap'
                else:
                    ans_type = 'others'
        else:
            if raw_answer == ans_prediction:
                ans_type = 'em'
            else:
                ans_type = 'others'
        prediction_ans_type_counter[ans_type] += 1
        pred_ans_type_list.append(ans_type)
        # print('{} | {} | {}'.format(ans_type, raw_answer, ans_prediction))
    print(len(pred_sent_type_list), len(pred_ans_type_list), len(pred_doc_type_list))
    supp_sent_compare_type = ['equal', 'less', 'more']
    result_types = ['em', 'sub_of_gold', 'super_of_gold', 'no_over_lap', 'others']
    supp_sent_comp_dict = dict([(y, x) for x, y in enumerate(supp_sent_compare_type)])
    supp_sent_type_dict = dict([(y, x) for x, y in enumerate(result_types)])
    assert len(pred_sent_type_list) == len(pred_sent_count_list)
    print(len(pred_sent_type_list), len(pred_sent_count_list))
    conf_supp_sent_matrix = np.zeros((len(supp_sent_compare_type), len(result_types)), dtype=np.int64)
    for idx in range(len(pred_sent_type_list)):
        comp_type_i = pred_sent_count_list[idx]
        supp_sent_type_i = pred_sent_type_list[idx]
        comp_idx_i = supp_sent_comp_dict[comp_type_i]
        supp_sent_idx_i = supp_sent_type_dict[supp_sent_type_i]
        conf_supp_sent_matrix[comp_idx_i][supp_sent_idx_i] += 1
    print('Sent Type vs Sent Count conf matrix:\n{}'.format(conf_supp_sent_matrix))
    print('Sum of matrix = {}'.format(conf_supp_sent_matrix.sum()))
    conf_matrix = confusion_matrix(yes_no_span_true, yes_no_span_predictions,
                                   labels=["yes", "no", "span"])
    conf_ans_sent_matrix = confusion_matrix(pred_sent_type_list, pred_ans_type_list,
                                            labels=result_types)
    print('*' * 75)
    print('Ans type conf matrix:\n{}'.format(conf_matrix))
    print('*' * 75)
    print('Sent vs ans conf matrix:\n{}'.format(conf_ans_sent_matrix))
    print('*' * 75)
    print("Ans prediction type: {}".format(prediction_ans_type_counter))
    print("Sent prediction type: {}".format(prediction_sent_type_counter))
    print("Para prediction type: {}".format(prediction_para_type_counter))
    print('*' * 75)
    conf_matrix_para_vs_sent = confusion_matrix(pred_doc_type_list, pred_sent_type_list,
                                                labels=result_types)
    print('Para Type vs Sent Type conf matrix:\n{}'.format(conf_matrix_para_vs_sent))
    print('*' * 75)
    conf_matrix_para_vs_ans = confusion_matrix(pred_doc_type_list, pred_ans_type_list,
                                               labels=result_types)
    print('Para Type vs Ans Type conf matrix:\n{}'.format(conf_matrix_para_vs_ans))
    para_counter = Counter(pred_para_count_list)
    print('Para counter : {}'.format(para_counter))
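##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Reading guide (toy labels) for the sklearn confusion matrices printed above: rows
## follow the first argument, columns the second, both in the order given by `labels`.
def _demo_confusion_matrix():
    from sklearn.metrics import confusion_matrix
    true_types = ['yes', 'span', 'span', 'no']
    pred_types = ['yes', 'span', 'no', 'no']
    # Entry (i, j) counts items whose first-argument label is labels[i] and whose
    # second-argument label is labels[j].
    return confusion_matrix(true_types, pred_types, labels=['yes', 'no', 'span'])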
def read_hotpot_examples(para_file, full_file, ner_file, doc_link_file):
    with open(para_file, 'r', encoding='utf-8') as reader:
        para_data = json.load(reader)
    with open(full_file, 'r', encoding='utf-8') as reader:
        full_data = json.load(reader)
    with open(ner_file, 'r', encoding='utf-8') as reader:
        ner_data = json.load(reader)
    with open(doc_link_file, 'r', encoding='utf-8') as reader:
        doc_link_data = json.load(reader)

    def split_sent(sent, offset=0):
        nlp_doc = nlp(sent)
        words, word_start_idx, char_to_word_offset = [], [], []
        for token in nlp_doc:
            # token match a-b, then split further
            words.append(token.text)
            word_start_idx.append(token.idx)
        word_offset = 0
        for c in range(len(sent)):
            if word_offset >= len(word_start_idx) - 1 or c < word_start_idx[word_offset + 1]:
                char_to_word_offset.append(word_offset + offset)
            else:
                char_to_word_offset.append(word_offset + offset + 1)
                word_offset += 1
        return words, char_to_word_offset, word_start_idx

    max_sent_cnt, max_entity_cnt = 0, 0
    examples = []
    for case in tqdm(full_data):
        key = case['_id']
        qas_type = case['type']
        sup_facts = set([(sp[0], sp[1]) for sp in case['supporting_facts']])
        context = dict(case['context'])
        doc_tokens = []
        sent_names = []
        sup_facts_sent_id = []
        sup_para_id = set()
        sent_start_end_position = []
        para_start_end_position = []
        ques_entity_start_end_position = []
        ques_entities_text = []
        ctx_entity_start_end_position = []
        ctx_entities_text = []
        ctx_text = ""
        ans_start_position, ans_end_position = [], []
        ques_answer_ids, ctx_answer_ids = [], []
        title_to_id, title_id = {}, 0
        sent_to_id, sent_id = {}, 0
        s_e_edges = []
        s_s_edges = []
        p_s_edges = []
        ctx_answer_candidates = []
        ctx_char_to_word_offset = []  # Accumulated along all sentences
        ctx_word_to_char_idx = []
        # process question entity span
        question_text = case['question']
        question_tokens, ques_char_to_word_offset, ques_word_to_char_idx = split_sent(question_text)
        answer_norm = normalize_answer(case['answer'])
        q_e_edges = []
        for q_ent, q_start, q_end, q_type in ner_data[key]['question']:
            q_ent_text = question_text[q_start:q_end]
            if q_type != 'CONTEXT' and q_ent_text not in ques_entities_text:
                if len(ques_answer_ids) == 0 and normalize_answer(q_ent_text) == answer_norm:
                    ques_answer_ids.append(len(ques_entities_text))
                ques_entities_text.append(q_ent_text)
                q_e_edges.append((0, len(ques_entity_start_end_position)))  # Q -> P; the id of Q is 0
                ques_entity_start_end_position.append(
                    (ques_char_to_word_offset[q_start], ques_char_to_word_offset[q_end - 1]))
        sel_paras = para_data[key]
        ner_context = dict(ner_data[key]['context'])
        for title in itertools.chain.from_iterable(sel_paras):
            stripped_title = re.sub(r' \(.*?\)$', '', title)
            stripped_title_norm = normalize_answer(stripped_title)
            sents = context[title]
            sents_ner = ner_context[title]
            assert len(sents) == len(sents_ner)
            title_to_id[title] = title_id
            para_start_position = len(doc_tokens)
            prev_sent_id = None
            ctx_answer_set = set()
            for local_sent_id, (sent, sent_ner) in enumerate(zip(sents, sents_ner)):
                # Determine the global sent id for supporting facts
                local_sent_name = (title, local_sent_id)
                sent_to_id[local_sent_name] = sent_id
                sent_names.append(local_sent_name)
                # P -> S
                p_s_edges.append((title_id, sent_id))
                if prev_sent_id is not None:
                    # S -> S
                    s_s_edges.append((prev_sent_id, sent_id))
                sent += " "
                ctx_text += sent
                sent_start_word_id = len(doc_tokens)
                sent_start_char_id = len(ctx_char_to_word_offset)
                cur_sent_words, cur_sent_char_to_word_offset, cur_sent_words_start_idx = split_sent(
                    sent, offset=len(doc_tokens))
                doc_tokens.extend(cur_sent_words)
                ctx_char_to_word_offset.extend(cur_sent_char_to_word_offset)
                for cur_sent_word in cur_sent_words_start_idx:
                    ctx_word_to_char_idx.append(sent_start_char_id + cur_sent_word)
                assert len(doc_tokens) == len(ctx_word_to_char_idx)
                sent_start_end_position.append((sent_start_word_id, len(doc_tokens) - 1))
                for sent_ner_id, (_, ent_start_char, ent_end_char, _) in enumerate(sent_ner):
                    if (ent_start_char, ent_end_char) in ctx_answer_set:
                        continue
                    s_ent_text = sent[ent_start_char:ent_end_char]
                    s_ent_text_norm = normalize_answer(s_ent_text)
                    if s_ent_text_norm == stripped_title_norm:
                        ctx_answer_candidates.append(len(ctx_entities_text))
                    if local_sent_name in sup_facts:
                        if len(ctx_answer_ids) == 0 and s_ent_text_norm == answer_norm:
                            ctx_answer_ids.append(len(ctx_entities_text))
                    ctx_entities_text.append(s_ent_text)
                    s_e_edges.append((sent_id, len(ctx_entity_start_end_position)))
                    ctx_entity_start_end_position.append(
                        (ctx_char_to_word_offset[sent_start_char_id + ent_start_char],
                         ctx_char_to_word_offset[sent_start_char_id + ent_end_char - 1]))
                    ctx_answer_set.add((ent_start_char, ent_end_char))
                # Find answer position
                if local_sent_name in sup_facts:
                    sup_para_id.add(title_id)
                    sup_facts_sent_id.append(sent_id)
                    answer_offsets = []
                    # find word offset
                    for cur_word_start_idx in cur_sent_words_start_idx:
                        if sent[cur_word_start_idx:cur_word_start_idx + len(case['answer'])] == case['answer']:
                            answer_offsets.append(cur_word_start_idx)
                    if len(answer_offsets) == 0:
                        answer_offset = sent.find(case['answer'])
                        if answer_offset != -1:
                            answer_offsets.append(answer_offset)
                    if case['answer'] not in ['yes', 'no'] and len(answer_offsets) > 0:
                        for answer_offset in answer_offsets:
                            start_char_position = sent_start_char_id + answer_offset
                            end_char_position = start_char_position + len(case['answer']) - 1
                            ans_start_position.append(ctx_char_to_word_offset[start_char_position])
                            ans_end_position.append(ctx_char_to_word_offset[end_char_position])
                prev_sent_id = sent_id
                sent_id += 1
            para_end_position = len(doc_tokens) - 1
            para_start_end_position.append((para_start_position, para_end_position, title))
            title_id += 1
        p_p_edges = []
        s_p_edges = []
        for _l in sel_paras[0]:
            for _r in sel_paras[1]:
                # edges: P -> P
                p_p_edges.append((title_to_id[_l], title_to_id[_r]))
                # edges: S -> P
                for local_sent_id, link_titles in enumerate(doc_link_data[_l]['hyperlink_titles']):
                    inter_titles = set(link_titles) & set(title_to_id.keys())
                    if len(inter_titles) > 0 and _r in inter_titles:
                        s_p_edges.append((sent_to_id[(_l, local_sent_id)], title_to_id[_r]))
        q_p_edges = [(0, title_to_id[para]) for para in sel_paras[0]]
        edges = {
            'ques_para': q_p_edges,
            'para_para': p_p_edges,
            'sent_sent': s_s_edges,
            'para_sent': p_s_edges,
            'sent_para': s_p_edges,
            'ques_ent': q_e_edges,
            'sent_ent': s_e_edges
        }
        max_sent_cnt = max(max_sent_cnt, len(sent_start_end_position))
        max_entity_cnt = max(max_entity_cnt, len(ctx_entity_start_end_position))
        if len(ans_start_position) > 1:
            # take the exact match for answer to avoid case of partial match
            start_position, end_position = [], []
            for _start_pos, _end_pos in zip(ans_start_position, ans_end_position):
                if normalize_answer(" ".join(doc_tokens[_start_pos:_end_pos + 1])) == normalize_answer(case['answer']):
                    start_position.append(_start_pos)
                    end_position.append(_end_pos)
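##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Sketch of the char-to-word bookkeeping that split_sent builds. The real function
## tokenizes with the module-level spaCy pipeline (nlp); this stand-in uses a plain
## whitespace split so it runs anywhere. char_to_word_offset maps every character
## position to the index of the word covering it, which is what later turns answer
## character spans into word spans.
def _demo_char_to_word_offset(sent='Barack Obama was born', offset=0):
    words, word_start_idx, pos = [], [], 0
    for w in sent.split(' '):
        words.append(w)
        word_start_idx.append(pos)
        pos += len(w) + 1  # +1 for the single separating space
    char_to_word_offset, word_offset = [], 0
    for c in range(len(sent)):
        if word_offset >= len(word_start_idx) - 1 or c < word_start_idx[word_offset + 1]:
            char_to_word_offset.append(word_offset + offset)
        else:
            char_to_word_offset.append(word_offset + offset + 1)
            word_offset += 1
    return words, char_to_word_offset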
def read_hotpot_examples(para_file, full_file, ner_file, doc_link_file, data_source_type=None):
    with open(para_file, 'r', encoding='utf-8') as reader:
        para_data = json.load(reader)
    with open(full_file, 'r', encoding='utf-8') as reader:
        full_data = json.load(reader)
    with open(ner_file, 'r', encoding='utf-8') as reader:
        ner_data = json.load(reader)
    with open(doc_link_file, 'r', encoding='utf-8') as reader:
        doc_link_data = json.load(reader)

    def split_sent(sent, offset=0):
        nlp_doc = nlp(sent)
        words, word_start_idx, char_to_word_offset = [], [], []
        for token in nlp_doc:
            # token match a-b, then split further
            words.append(token.text)
            word_start_idx.append(token.idx)
        word_offset = 0
        for c in range(len(sent)):
            if word_offset >= len(word_start_idx) - 1 or c < word_start_idx[word_offset + 1]:
                char_to_word_offset.append(word_offset + offset)
            else:
                char_to_word_offset.append(word_offset + offset + 1)
                word_offset += 1
        return words, char_to_word_offset, word_start_idx

    max_sent_cnt, max_entity_cnt = 0, 0
    examples = []
    for case in tqdm(full_data):
        key = case['_id']
        qas_type = case['type']
        sup_facts = set([(sp[0], sp[1]) for sp in case['supporting_facts']])
        context = dict(case['context'])
        doc_tokens = []  ## spacy tokenized results
        sent_names = []  ## list of (title, local index) pairs
        sup_facts_sent_id = []  ## sent_id (absolute sent index, index in the concat text)
        sup_para_id = set()  ## support paragraph ids --> for para ranking
        sent_start_end_position = []  ## list of (start, end) position tuples
        para_start_end_position = []  ## list of (start, end, title) tuples
        ques_entity_start_end_position = []  ## entity (start, end) pairs in the question
        ques_entities_text = []  ## question entities
        ctx_entity_start_end_position = []  ## entity (start, end) pairs in the context
        ctx_entities_text = []  ## context entities
        ctx_text = ""  ## concatenated context text
        ans_start_position, ans_end_position = [], []  ## answer start/end positions
        ques_answer_ids, ctx_answer_ids = [], []
        title_to_id, title_id = {}, 0
        sent_to_id, sent_id = {}, 0
        s_e_edges = []  ### 1) sentence2entity edges
        s_s_edges = []  ### 2) sentence2sentence edges (in single paragraph)
        p_s_edges = []  ### 3) para2sentence edges
        ctx_answer_candidates = []
        ctx_char_to_word_offset = []  # Accumulated along all sentences
        ctx_word_to_char_idx = []
        # process question entity span
        question_text = case['question']
        question_tokens, ques_char_to_word_offset, ques_word_to_char_idx = split_sent(question_text)
        answer_norm = normalize_answer(case['answer'])
        q_e_edges = []  ### 4) question2entity edges
        for q_ent, q_start, q_end, q_type in ner_data[key]['question']:
            q_ent_text = question_text[q_start:q_end]
            if q_type != 'CONTEXT' and q_ent_text not in ques_entities_text:
                if len(ques_answer_ids) == 0 and normalize_answer(q_ent_text) == answer_norm:
                    ques_answer_ids.append(len(ques_entities_text))
                ques_entities_text.append(q_ent_text)
                q_e_edges.append((0, len(ques_entity_start_end_position)))  # Q -> P; the id of Q is 0
                ques_entity_start_end_position.append(
                    (ques_char_to_word_offset[q_start], ques_char_to_word_offset[q_end - 1]))
        sel_paras = para_data[key]
        ner_context = dict(ner_data[key]['context'])
        ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        para_names = []  ## for paragraph evaluation and checking
        ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        for title in itertools.chain.from_iterable(sel_paras):
            stripped_title = re.sub(r' \(.*?\)$', '', title)
            stripped_title_norm = normalize_answer(stripped_title)
            ####+++++++++++++++++++++++++
            para_names.append(title)
            ####+++++++++++++++++++++++++
            sents = context[title]
            sents_ner = ner_context[title]
            assert len(sents) == len(sents_ner)
            title_to_id[title] = title_id
            para_start_position = len(doc_tokens)
            prev_sent_id = None
            ctx_answer_set = set()
            for local_sent_id, (sent, sent_ner) in enumerate(zip(sents, sents_ner)):
                # Determine the global sent id for supporting facts
                local_sent_name = (title, local_sent_id)
                sent_to_id[local_sent_name] = sent_id
                sent_names.append(local_sent_name)
                # P -> S
                p_s_edges.append((title_id, sent_id))
                if prev_sent_id is not None:
                    # S -> S
                    s_s_edges.append((prev_sent_id, sent_id))
                sent += " "
                ctx_text += sent
                sent_start_word_id = len(doc_tokens)
                sent_start_char_id = len(ctx_char_to_word_offset)
                cur_sent_words, cur_sent_char_to_word_offset, cur_sent_words_start_idx = split_sent(
                    sent, offset=len(doc_tokens))
                doc_tokens.extend(cur_sent_words)
                ctx_char_to_word_offset.extend(cur_sent_char_to_word_offset)
                for cur_sent_word in cur_sent_words_start_idx:
                    ctx_word_to_char_idx.append(sent_start_char_id + cur_sent_word)
                assert len(doc_tokens) == len(ctx_word_to_char_idx)
                sent_start_end_position.append((sent_start_word_id, len(doc_tokens) - 1))
                for sent_ner_id, (_, ent_start_char, ent_end_char, _) in enumerate(sent_ner):
                    if (ent_start_char, ent_end_char) in ctx_answer_set:
                        continue
                    s_ent_text = sent[ent_start_char:ent_end_char]
                    s_ent_text_norm = normalize_answer(s_ent_text)
                    if s_ent_text_norm == stripped_title_norm:
                        ctx_answer_candidates.append(len(ctx_entities_text))
                    if local_sent_name in sup_facts:
                        if len(ctx_answer_ids) == 0 and s_ent_text_norm == answer_norm:
                            ctx_answer_ids.append(len(ctx_entities_text))
                    ctx_entities_text.append(s_ent_text)
                    s_e_edges.append((sent_id, len(ctx_entity_start_end_position)))
                    ctx_entity_start_end_position.append(
                        (ctx_char_to_word_offset[sent_start_char_id + ent_start_char],
                         ctx_char_to_word_offset[sent_start_char_id + ent_end_char - 1]))
                    ctx_answer_set.add((ent_start_char, ent_end_char))
                # Find answer position
                if local_sent_name in sup_facts:
                    sup_para_id.add(title_id)
                    sup_facts_sent_id.append(sent_id)
                    answer_offsets = []
                    # find word offset
                    for cur_word_start_idx in cur_sent_words_start_idx:
                        if sent[cur_word_start_idx:cur_word_start_idx + len(case['answer'])] == case['answer']:
                            answer_offsets.append(cur_word_start_idx)
                    if len(answer_offsets) == 0:
                        answer_offset = sent.find(case['answer'])
                        if answer_offset != -1:
                            answer_offsets.append(answer_offset)
                    if case['answer'] not in ['yes', 'no'] and len(answer_offsets) > 0:
                        for answer_offset in answer_offsets:
                            start_char_position = sent_start_char_id + answer_offset
                            end_char_position = start_char_position + len(case['answer']) - 1
                            ans_start_position.append(ctx_char_to_word_offset[start_char_position])
                            ans_end_position.append(ctx_char_to_word_offset[end_char_position])
                prev_sent_id = sent_id
                sent_id += 1
            para_end_position = len(doc_tokens) - 1
            para_start_end_position.append((para_start_position, para_end_position, title))
            title_id += 1
        p_p_edges = []  ## 5) paragraph2paragraph edges
        s_p_edges = []  ## 6) sentence2paragraph edges
        for _l in sel_paras[0]:
            for _r in sel_paras[1]:
                # edges: P -> P
                p_p_edges.append((title_to_id[_l], title_to_id[_r]))
                # edges: S -> P
                for local_sent_id, link_titles in enumerate(doc_link_data[_l]['hyperlink_titles']):
                    inter_titles = set(link_titles) & set(title_to_id.keys())
                    if len(inter_titles) > 0 and _r in inter_titles:
                        s_p_edges.append((sent_to_id[(_l, local_sent_id)], title_to_id[_r]))
        # print('selected paragraphs {}'.format(sel_paras))
        q_p_edges = [(0, title_to_id[para]) for para in sel_paras[0]]  ### 7) question2paragraph edges
        edges = {
            'ques_para': q_p_edges,
            'para_para': p_p_edges,
            'sent_sent': s_s_edges,
            'para_sent': p_s_edges,
            'sent_para': s_p_edges,
            'ques_ent': q_e_edges,
            'sent_ent': s_e_edges
        }

        ###########+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        def sae_graph_edges(edges, ques_entities_text, ctx_entities_text):
            def tuple_to_dict(tuple_list):
                res = {}
                for tup in tuple_list:
                    if tup[0] not in res:
                        res[tup[0]] = [tup[1]]
                    else:
                        res[tup[0]].append(tup[1])
                return res

            #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            para_sent_edges = edges['para_sent']
            sents_in_para_dict = tuple_to_dict(tuple_list=para_sent_edges)
            sent_to_sent_in_doc_edges = []
            for key, sent_list in sents_in_para_dict.items():
                sent_list = sorted(sent_list)  ### increasing order
                if len(sent_list) > 1:
                    for i in range(len(sent_list) - 1):
                        for j in range(i + 1, len(sent_list)):
                            sent_to_sent_in_doc_edges.append((sent_list[i], sent_list[j]))
            # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            query_ent_edges = edges['ques_ent']
            assert len(query_ent_edges) == len(ques_entities_text)  ### equal to number of entities in query
            # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            sent_ent_edges = edges['sent_ent']
            assert len(sent_ent_edges) == len(ctx_entities_text)
            norm_ctx_entities_text = [normalize_text(_) for _ in ctx_entities_text]
            norm_ctx_ent_pair = [(w[0], w[1][0])
                                 for w in zip(norm_ctx_entities_text, sent_ent_edges)]  ## (normed entity, sent id) tuples
            sents_for_norm_ent_dict = tuple_to_dict(tuple_list=norm_ctx_ent_pair)  ## key: normed entity, value: sent ids
            ents_in_sent_dict = tuple_to_dict(tuple_list=sent_ent_edges)  ## key: sentence, value: entities
            for key in sents_for_norm_ent_dict.keys():
                sents_for_norm_ent_dict[key] = sorted(list(set(sents_for_norm_ent_dict[key])))  ### distinct
            # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            norm_ques_entities_text = [normalize_text(_) for _ in ques_entities_text]
            norm_ques_entities_text = list(set(norm_ques_entities_text))  ## distinct normalized entities

            def shared_query_entity_sent_edges(norm_ques_entities_text, ents_in_sent_dict,
                                               sents_for_norm_ent_dict, para_sent_edges):
                sent_to_sent_shared_edges = []
                norm_ques_entities_text_filter = [_ for _ in norm_ques_entities_text
                                                  if _ in sents_for_norm_ent_dict]
                for i in range(len(norm_ques_entities_text_filter) - 1):
                    sent_list_i = sents_for_norm_ent_dict[norm_ques_entities_text_filter[i]]
                    for j in range(i + 1, len(norm_ques_entities_text_filter)):
                        sent_list_j = sents_for_norm_ent_dict[norm_ques_entities_text_filter[j]]
                        for l, r in zip(sent_list_i, sent_list_j):
                            sent_pair = (l, r) if l < r else (r, l)
                            if para_sent_edges[sent_pair[0]][0] != para_sent_edges[sent_pair[1]][0]:
                                ents_l = set(ents_in_sent_dict[sent_pair[0]])
                                ents_r = set(ents_in_sent_dict[sent_pair[1]])
                                if (sent_pair not in sent_to_sent_shared_edges) and \
                                        (len(ents_l.intersection(ents_r)) == 0):
                                    sent_to_sent_shared_edges.append(sent_pair)
                return sent_to_sent_shared_edges

            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            sent_to_sent_query_cross_edges = shared_query_entity_sent_edges(
                norm_ques_entities_text, ents_in_sent_dict, sents_for_norm_ent_dict, para_sent_edges)

            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            def doc_cross_entity_sent_edges(sents_for_norm_ent_dict, para_sent_edges):
                sent_to_sent_cross_edges = []
                sents_for_norm_ent_filter = [(key, value)
                                             for key, value in sents_for_norm_ent_dict.items()
                                             if len(value) > 1]
                for key, sent_list in sents_for_norm_ent_filter:
                    sent_list = sorted(sent_list)
                    for i in range(len(sent_list) - 1):
                        for j in range(i + 1, len(sent_list)):
                            if para_sent_edges[sent_list[i]][0] != para_sent_edges[sent_list[j]][0]:
                                sent_pair = (sent_list[i], sent_list[j])
                                if sent_pair not in sent_to_sent_cross_edges:
                                    sent_to_sent_cross_edges.append(sent_pair)
                return sent_to_sent_cross_edges

            sent_to_sent_para_cross_edges = doc_cross_entity_sent_edges(sents_for_norm_ent_dict,
                                                                        para_sent_edges)
            return sent_to_sent_in_doc_edges, sent_to_sent_query_cross_edges, sent_to_sent_para_cross_edges
        ###########+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        s_s_edges, s_s_q_edges, s_s_p_edges = sae_graph_edges(
            edges=edges, ctx_entities_text=ctx_entities_text, ques_entities_text=ques_entities_text)
        ###########+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        edges['sent_sent'] = s_s_p_edges
        edges['sent_sent_cross'] = s_s_q_edges + s_s_p_edges  ### updating edges
        ###########+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        max_sent_cnt = max(max_sent_cnt, len(sent_start_end_position))
        max_entity_cnt = max(max_entity_cnt, len(ctx_entity_start_end_position))
        if len(ans_start_position) > 1:
            # take the exact match for answer to avoid case of partial match
            start_position, end_position = [], []
            for _start_pos, _end_pos in zip(ans_start_position, ans_end_position):
                if normalize_answer(" ".join(doc_tokens[_start_pos:_end_pos + 1])) == normalize_answer(case['answer']):
                    start_position.append(_start_pos)
                    end_position.append(_end_pos)
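##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Toy walk-through (invented edge lists) of the grouping trick inside sae_graph_edges:
## para->sent edges are bucketed by paragraph, and every ordered pair of sentences in
## the same bucket becomes a sentence-to-sentence edge.
def _demo_in_doc_sent_edges():
    para_sent_edges = [(0, 0), (0, 1), (0, 2), (1, 3)]
    buckets = {}
    for para_id, sent_id in para_sent_edges:
        buckets.setdefault(para_id, []).append(sent_id)
    in_doc_edges = []
    for sent_list in buckets.values():
        sent_list = sorted(sent_list)
        for i in range(len(sent_list) - 1):
            for j in range(i + 1, len(sent_list)):
                in_doc_edges.append((sent_list[i], sent_list[j]))
    return in_doc_edges  # [(0, 1), (0, 2), (1, 2)]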
def error_analysis(raw_data, examples, features, predictions, tokenizer, use_ent_ans=False):
    yes_no_span_predictions = []
    yes_no_span_true = []
    prediction_ans_type_counter = Counter()
    prediction_sent_type_counter = Counter()
    prediction_para_type_counter = Counter()
    pred_ans_type_list = []
    pred_sent_type_list = []
    pred_doc_type_list = []
    ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    for row in raw_data:
        qid = row['_id']
        sp_predictions = predictions['sp'][qid]
        sp_predictions = [(x[0], x[1]) for x in sp_predictions]
        ans_prediction = predictions['answer'][qid]
        raw_answer = row['answer']
        raw_answer = normalize_answer(raw_answer)
        ans_prediction = normalize_answer(ans_prediction)
        sp_golds = row['supporting_facts']
        sp_golds = [(x[0], x[1]) for x in sp_golds]
        sp_para_golds = list(set([_[0] for _ in sp_golds]))
        ##+++++++++++
        sp_predictions = [x for x in sp_predictions if x[0] in sp_para_golds]
        # print(len(set(sp_predictions)))
        ##+++++++++++
        sp_sent_type = set_comparison(prediction_list=sp_predictions, true_list=sp_golds)
        ###+++++++++
        prediction_sent_type_counter[sp_sent_type] += 1
        pred_sent_type_list.append(sp_sent_type)
        ###+++++++++
        sp_para_preds = list(set([_[0] for _ in sp_predictions]))
        para_type = set_comparison(prediction_list=sp_para_preds, true_list=sp_para_golds)
        prediction_para_type_counter[para_type] += 1
        pred_doc_type_list.append(para_type)
        ###+++++++++
        if raw_answer not in ['yes', 'no']:
            yes_no_span_true.append('span')
        else:
            yes_no_span_true.append(raw_answer)
        if ans_prediction not in ['yes', 'no']:
            yes_no_span_predictions.append('span')
        else:
            yes_no_span_predictions.append(ans_prediction)
        ans_type = 'em'
        if raw_answer not in ['yes', 'no']:
            if raw_answer == ans_prediction:
                ans_type = 'em'
            elif raw_answer in ans_prediction:
                # print('{}: {} |{}'.format(qid, raw_answer, ans_prediction))
                # print('-'*75)
                ans_type = 'super_of_gold'
            elif ans_prediction in raw_answer:
                # print('{}: {} |{}'.format(qid, raw_answer, ans_prediction))
                # print('-'*75)
                ans_type = 'sub_of_gold'
            else:
                ans_pred_tokens = ans_prediction.split(' ')
                ans_raw_tokens = raw_answer.split(' ')
                is_empty_set = len(set(ans_pred_tokens).intersection(set(ans_raw_tokens))) == 0
                if is_empty_set:
                    ans_type = 'no_over_lap'
                else:
                    ans_type = 'others'
        else:
            if raw_answer == ans_prediction:
                ans_type = 'em'
            else:
                ans_type = 'others'
        prediction_ans_type_counter[ans_type] += 1
        pred_ans_type_list.append(ans_type)
    print(len(pred_sent_type_list), len(pred_ans_type_list), len(pred_doc_type_list))
    result_types = ['em', 'sub_of_gold', 'super_of_gold', 'no_over_lap', 'others']
    conf_matrix = confusion_matrix(yes_no_span_true, yes_no_span_predictions,
                                   labels=["yes", "no", "span"])
    conf_ans_sent_matrix = confusion_matrix(pred_sent_type_list, pred_ans_type_list,
                                            labels=result_types)
    print('*' * 75)
    print('Ans type conf matrix:\n{}'.format(conf_matrix))
    print('*' * 75)
    print('Type conf matrix:\n{}'.format(conf_ans_sent_matrix))
    print('*' * 75)
    print("Ans prediction type: {}".format(prediction_ans_type_counter))
    print("Sent prediction type: {}".format(prediction_sent_type_counter))
    print("Para prediction type: {}".format(prediction_para_type_counter))
    print('*' * 75)
    conf_matrix_para_vs_sent = confusion_matrix(pred_doc_type_list, pred_sent_type_list,
                                                labels=result_types)
    print('Para Type vs Sent Type conf matrix:\n{}'.format(conf_matrix_para_vs_sent))
    print('*' * 75)
    conf_matrix_para_vs_ans = confusion_matrix(pred_doc_type_list, pred_ans_type_list,
                                               labels=result_types)
    print('Para Type vs Ans Type conf matrix:\n{}'.format(conf_matrix_para_vs_ans))
def predict(examples, features, pred_file, tokenizer, use_ent_ans=False):
    answer_dict = dict()
    sp_dict = dict()
    ids = list(examples.keys())
    max_sent_num = 0
    max_entity_num = 0
    q_type_counter = Counter()
    answer_no_match_cnt = 0
    for i, qid in enumerate(ids):
        feature = features[qid]
        example = examples[qid]
        q_type = feature.ans_type
        max_sent_num = max(max_sent_num, len(feature.sent_spans))
        max_entity_num = max(max_entity_num, len(feature.entity_spans))
        q_type_counter[q_type] += 1

        def get_ans_from_pos(y1, y2):
            tok_to_orig_map = feature.token_to_orig_map
            final_text = " "
            if y1 < len(tok_to_orig_map) and y2 < len(tok_to_orig_map):
                orig_tok_start = tok_to_orig_map[y1]
                orig_tok_end = tok_to_orig_map[y2]
                ques_tok_len = len(example.question_tokens)
                if orig_tok_start < ques_tok_len and orig_tok_end < ques_tok_len:
                    ques_start_idx = example.question_word_to_char_idx[orig_tok_start]
                    ques_end_idx = example.question_word_to_char_idx[orig_tok_end] + \
                        len(example.question_tokens[orig_tok_end])
                    final_text = example.question_text[ques_start_idx:ques_end_idx]
                else:
                    orig_tok_start -= len(example.question_tokens)
                    orig_tok_end -= len(example.question_tokens)
                    ctx_start_idx = example.ctx_word_to_char_idx[orig_tok_start]
                    ctx_end_idx = example.ctx_word_to_char_idx[orig_tok_end] + \
                        len(example.doc_tokens[orig_tok_end])
                    final_text = example.ctx_text[ctx_start_idx:ctx_end_idx]
            return final_text
            # return tokenizer.convert_tokens_to_string(tok_tokens)

        answer_text = ''
        if q_type == 0 or q_type == 3:
            if len(feature.start_position) == 0 or len(feature.end_position) == 0:
                answer_text = ""
            else:
                # st, ed = example.start_position[0], example.end_position[0]
                # answer_text = example.ctx_text[example.ctx_word_to_char_idx[st]:example.ctx_word_to_char_idx[ed] + len(example.doc_tokens[example.end_position[0]])]
                answer_text = get_ans_from_pos(feature.start_position[0], feature.end_position[0])
                if normalize_answer(answer_text) != normalize_answer(example.orig_answer_text):
                    print("{} | {} | {} | {} | {}".format(qid, answer_text, example.orig_answer_text,
                                                          feature.start_position[0],
                                                          feature.end_position[0]))
                    answer_no_match_cnt += 1
            if q_type == 3 and use_ent_ans:
                ans_id = feature.answer_in_entity_ids[0]
                st, ed = feature.entity_spans[ans_id]
                answer_text = get_ans_from_pos(st, ed)
        elif q_type == 1:
            answer_text = 'yes'
        elif q_type == 2:
            answer_text = 'no'
        answer_dict[qid] = answer_text
        cur_sp = []
        for sent_id in feature.sup_fact_ids:
            cur_sp.append(example.sent_names[sent_id])
        sp_dict[qid] = cur_sp
    final_pred = {'answer': answer_dict, 'sp': sp_dict}
    with open(pred_file, 'w') as writer:
        json.dump(final_pred, writer)
    print("Maximum sentence num: {}".format(max_sent_num))
    print("Maximum entity num: {}".format(max_entity_num))
    print("Question type: {}".format(q_type_counter))
    print("Answer does not match: {}".format(answer_no_match_cnt))
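##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Minimal sketch (toy data, not a real example/feature pair) of the span recovery in
## get_ans_from_pos: subtoken positions map to original word indices via
## token_to_orig_map, then to character offsets via ctx_word_to_char_idx, and the
## answer is sliced out of the raw context. The real code first checks whether the
## span falls inside the question tokens; that branch is omitted here.
def _demo_span_recovery():
    ctx_text = 'Paris is the capital of France.'
    doc_tokens = ctx_text.split(' ')
    ctx_word_to_char_idx = [0, 6, 9, 13, 21, 24]  # start char of each word
    token_to_orig_map = [0, 1, 2, 3, 4, 5]        # 1:1 here; subwords collapse in practice
    y1, y2 = 0, 0                                 # predicted subtoken span
    start_word, end_word = token_to_orig_map[y1], token_to_orig_map[y2]
    start_char = ctx_word_to_char_idx[start_word]
    end_char = ctx_word_to_char_idx[end_word] + len(doc_tokens[end_word])
    return ctx_text[start_char:end_char]          # -> 'Paris'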