import itertools
import json
import re
from collections import Counter

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

# `nlp` (a spaCy pipeline) and the HotpotQA evaluation helpers
# (normalize_answer, update_answer, update_sp, set_comparison,
# normalize_text) are assumed to be defined elsewhere in the project.


def prediction_score_analysis_adaptive_threshold(raw_data, predictions,
                                                 prediction_scores,
                                                 at_thresholds):
    def positive_neg_score(scores, mask, names, gold_names, threshold_score,
                           pred_names):
        assert len(scores) == len(mask)
        mask_sum_num = int(sum(mask))
        prune_names = names[:mask_sum_num]
        gold_name_set = set(gold_names)
        flag = gold_name_set.issubset(set(prune_names))
        positive_scores = []
        negative_scores = []
        for idx in range(mask_sum_num):
            name_i = prune_names[idx]
            if name_i in gold_name_set:
                positive_scores.append(scores[idx])
            else:
                negative_scores.append(scores[idx])

        if len(positive_scores) > 0:
            min_positive = min(positive_scores)
        else:
            min_positive = 0.0
        if len(negative_scores) == 0:
            max_negative = 1.0
        else:
            max_negative = max(negative_scores)
        num_candidates = mask_sum_num
        num_golds = len(gold_name_set)

        min_p_names = []
        max_n_names = []
        threshold_names = []
        # quantize the adaptive threshold once before scanning candidates
        if threshold_score > 0.45:
            threshold_score = 0.45
        else:
            threshold_score = 0.35
        for i in range(mask_sum_num):
            if scores[i] >= min_positive:
                min_p_names.append(names[i])
            if scores[i] > max_negative:
                max_n_names.append(names[i])
            if scores[i] > threshold_score:
                threshold_names.append(names[i])
        return flag, min_positive, max_negative, num_candidates, num_golds, min_p_names, max_n_names, threshold_names
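    # positive_neg_score returns two oracle cut points over the candidate
    # scores: `min_positive` (the lowest score of any gold sentence, so
    # keeping everything >= it gives full recall) and `max_negative` (the
    # highest score of any non-gold sentence, so keeping everything above it
    # gives full precision), plus the name lists induced by each cut.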

    threshold_metric_dict = {}
    threshold_metric_dict['pred'] = []
    threshold_metric_dict['min_p'] = []
    threshold_metric_dict['max_n'] = []
    threshold_metric_dict['threshold_n'] = []
    prune_gold_num = 0
    analysis_result_list = []
    for row in raw_data:
        qid = row['_id']
        threshold_score = at_thresholds[qid]
        question_type = row['type']
        answer_type = row['answer']
        if answer_type.strip().lower() not in ['yes', 'no']:
            answer_type = 'span'
        sp_predictions = predictions['sp'][qid]
        sp_predictions = [(x[0], x[1]) for x in sp_predictions]
        sp_para_predictions = list(set([x[0] for x in sp_predictions]))

        sp_golds = row['supporting_facts']
        sp_golds = [(x[0], x[1]) for x in sp_golds]
        sp_para_golds = list(set([_[0] for _ in sp_golds]))

        if qid == '5a8a4a4055429930ff3c0d77':
            print(sp_predictions)
            print(sp_golds)

        res_scores = prediction_scores[qid]
        sp_scores = res_scores['sp_score']
        sp_mask = res_scores['sp_mask']
        sp_names = res_scores['sp_names']
        sp_names = [(x[0], x[1]) for x in sp_names]
        flag, min_positive, max_negative, num_candidates, num_golds, min_p_names, max_n_names, threshold_names = \
            positive_neg_score(scores=sp_scores, mask=sp_mask, names=sp_names, gold_names=sp_golds, threshold_score=threshold_score, pred_names=sp_predictions)

        ans_prediction = predictions['answer'][qid]
        raw_answer = row['answer']
        raw_answer = normalize_answer(raw_answer)
        ans_prediction = normalize_answer(ans_prediction)
        ans_metrics = update_answer(prediction=ans_prediction, gold=raw_answer)

        predict_metrics = update_sp(prediction=sp_predictions, gold=sp_golds)
        threshold_metric_dict['pred'].append((ans_metrics, predict_metrics))
        min_p_metrics = update_sp(prediction=min_p_names, gold=sp_golds)
        threshold_metric_dict['min_p'].append((ans_metrics, min_p_metrics))
        max_n_metrics = update_sp(prediction=max_n_names, gold=sp_golds)
        threshold_metric_dict['max_n'].append((ans_metrics, max_n_metrics))
        threshold_metrics = update_sp(prediction=threshold_names,
                                      gold=sp_golds)
        threshold_metric_dict['threshold_n'].append(
            (ans_metrics, threshold_metrics))

        if not flag:
            prune_gold_num += 1

        sp_sent_type = set_comparison(prediction_list=sp_predictions,
                                      true_list=sp_golds)

        # for key, value in sp_scores.items():
        #     print(key, value)
        # print('{}\t{}\t{}\t{:.5f}\t{:.5f}'.format(question_type, sp_sent_type, flag, min_positive, max_negative))
        analysis_result_list.append(
            (qid, question_type, sp_sent_type, flag, min_positive,
             max_negative, threshold_score, num_candidates, num_golds,
             answer_type))

    for key, value in threshold_metric_dict.items():
        print('threshold type = {}'.format(key))
        answer_em, answer_prec, answer_recall, answer_f1 = 0.0, 0.0, 0.0, 0.0
        sp_em, sp_prec, sp_recall, sp_f1 = 0.0, 0.0, 0.0, 0.0
        type_count = len(value)
        all_joint_em, all_joint_f1 = 0.0, 0.0
        for ans_tup, sp_tup in value:
            answer_em += ans_tup[0]
            answer_prec += ans_tup[1]
            answer_recall += ans_tup[2]
            answer_f1 += ans_tup[3]

            sp_em += sp_tup[0]
            sp_prec += sp_tup[1]
            sp_recall += sp_tup[2]
            sp_f1 += sp_tup[3]

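            # HotpotQA joint metric: per-example products of answer and
            # supporting-fact precision/recall, combined into an F1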
            joint_prec = ans_tup[1] * sp_tup[1]
            joint_recall = ans_tup[2] * sp_tup[2]
            if joint_prec + joint_recall > 0:
                joint_f1 = 2 * joint_prec * joint_recall / (joint_prec +
                                                            joint_recall)
            else:
                joint_f1 = 0.
            joint_em = ans_tup[0] * sp_tup[0]

            all_joint_f1 += joint_f1
            all_joint_em += joint_em

        print('ans {}\t{}\t{}\t{}'.format(answer_em / type_count,
                                          answer_recall / type_count,
                                          answer_prec / type_count,
                                          answer_f1 / type_count))
        print('sup {}\t{}\t{}\t{}'.format(sp_em / type_count,
                                          sp_recall / type_count,
                                          sp_prec / type_count,
                                          sp_f1 / type_count))
        print('joint em ', all_joint_em / type_count)
        print('joint f1 ', all_joint_f1 / type_count)

    df = pd.DataFrame(analysis_result_list,
                      columns=[
                          'id', 'q_type', 'sp_sent_type', 'flag', 'min_p',
                          'max_n', 'threshold', 'cand_num', 'gold_num',
                          'ans_type'
                      ])

    print('prune = {}, complete = {}'.format(prune_gold_num,
                                             len(raw_data) - prune_gold_num))
    return df
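def _demo_adaptive_threshold_analysis():
    # Minimal usage sketch. The file names and the constant per-question
    # threshold below are placeholders/assumptions, not part of the original
    # pipeline: `raw_data` is the HotpotQA dev json, `predictions` and
    # `prediction_scores` come from a model's prediction dump, and
    # `at_thresholds` maps each question id to its adaptive threshold.
    with open('hotpot_dev_distractor_v1.json', 'r', encoding='utf-8') as f:
        raw_data = json.load(f)
    with open('dev_pred.json', 'r', encoding='utf-8') as f:
        predictions = json.load(f)
    with open('dev_pred_scores.json', 'r', encoding='utf-8') as f:
        prediction_scores = json.load(f)
    at_thresholds = {row['_id']: 0.45 for row in raw_data}
    df = prediction_score_analysis_adaptive_threshold(
        raw_data, predictions, prediction_scores, at_thresholds)
    df.to_csv('adaptive_threshold_analysis.csv', index=False)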
def error_analysis_question_type(raw_data,
                                 predictions,
                                 tokenizer,
                                 use_ent_ans=False):
    type_metric_dict = {}
    for row in raw_data:
        question_type = row['type']
        if question_type not in type_metric_dict:
            type_metric_dict[question_type] = []

        qid = row['_id']
        sp_predictions = predictions['sp'][qid]
        sp_predictions = [(x[0], x[1]) for x in sp_predictions]
        sp_golds = row['supporting_facts']
        sp_golds = [(x[0], x[1]) for x in sp_golds]
        sp_metrics = update_sp(prediction=sp_predictions, gold=sp_golds)

        if qid == '5add114a5542994734353826':
            for x in row['context']:
                print('title ', x[0])
                for y_idx, y in enumerate(x[1]):
                    print('sentence', y_idx, y)

        ans_prediction = predictions['answer'][qid]
        raw_answer = row['answer']
        raw_answer = normalize_answer(raw_answer)
        ans_prediction = normalize_answer(ans_prediction)
        ans_metrics = update_answer(prediction=ans_prediction, gold=raw_answer)

        type_metric_dict[question_type].append((ans_metrics, sp_metrics))

    for key, value in type_metric_dict.items():
        print('question type = {}'.format(key))
        answer_em, answer_prec, answer_recall, answer_f1 = 0.0, 0.0, 0.0, 0.0
        sp_em, sp_prec, sp_recall, sp_f1 = 0.0, 0.0, 0.0, 0.0
        type_count = len(value)
        all_joint_em, all_joint_f1 = 0.0, 0.0
        for ans_tup, sp_tup in value:
            answer_em += ans_tup[0]
            answer_prec += ans_tup[1]
            answer_recall += ans_tup[2]
            answer_f1 += ans_tup[3]

            sp_em += sp_tup[0]
            sp_prec += sp_tup[1]
            sp_recall += sp_tup[2]
            sp_f1 += sp_tup[3]

            joint_prec = ans_tup[1] * sp_tup[1]
            joint_recall = ans_tup[2] * sp_tup[2]
            if joint_prec + joint_recall > 0:
                joint_f1 = 2 * joint_prec * joint_recall / (joint_prec +
                                                            joint_recall)
            else:
                joint_f1 = 0.
            joint_em = ans_tup[0] * sp_tup[0]

            all_joint_f1 += joint_f1
            all_joint_em += joint_em

        print('ans {}\t{}\t{}\t{}'.format(answer_em / type_count,
                                          answer_recall / type_count,
                                          answer_prec / type_count,
                                          answer_f1 / type_count))
        print('sup {}\t{}\t{}\t{}'.format(sp_em / type_count,
                                          sp_recall / type_count,
                                          sp_prec / type_count,
                                          sp_f1 / type_count))
        print('joint em ', all_joint_em / type_count)
        print('joint f1 ', all_joint_f1 / type_count)
def prediction_score_gap_train_analysis(raw_data,
                                        predictions,
                                        prediction_scores,
                                        train_type=None):
    def score_gap_split(scores, mask, names):
        assert len(scores) == len(mask)
        mask_sum_num = int(sum(mask))
        prune_names = names[:mask_sum_num]

        prune_scores = np.array(scores[:mask_sum_num])
        sorted_idxes = np.argsort(prune_scores)[::-1]
        largest_gap = -1
        max_gap_idx = -1
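        # the loop below starts at i = 1, so the top-2 scored sentences are
        # always kept together (pred_idxes = sorted_idxes[:max_gap_idx + 1])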
        for i in range(1, mask_sum_num - 1):
            gap = prune_scores[sorted_idxes[i]] - prune_scores[sorted_idxes[i +
                                                                            1]]
            if gap > largest_gap:
                largest_gap = gap
                max_gap_idx = i
        pred_idxes = sorted_idxes[:(max_gap_idx + 1)]
        gap_names = [prune_names[_] for _ in pred_idxes]
        return gap_names

    def positive_neg_score(scores, mask, names, gold_names, pred_names):
        assert len(scores) == len(mask)
        mask_sum_num = int(sum(mask))
        prune_names = names[:mask_sum_num]
        gold_name_set = set(gold_names)
        flag = gold_name_set.issubset(set(prune_names))
        positive_scores = []
        negative_scores = []
        for idx in range(mask_sum_num):
            name_i = prune_names[idx]
            if name_i in gold_name_set:
                positive_scores.append(scores[idx])
            else:
                negative_scores.append(scores[idx])

        if len(positive_scores) > 0:
            min_positive = min(positive_scores)
        else:
            min_positive = 0.0
        if len(negative_scores) == 0:
            max_negative = 1.0
        else:
            max_negative = max(negative_scores)
        num_candidates = mask_sum_num
        num_golds = len(gold_name_set)

        min_p_names = []
        max_n_names = []
        for i in range(mask_sum_num):
            if scores[i] >= min_positive:
                min_p_names.append(names[i])
            if scores[i] > max_negative:
                max_n_names.append(names[i])

        return flag, min_positive, max_negative, num_candidates, num_golds, min_p_names, max_n_names

    threshold_metric_dict = {}
    threshold_metric_dict['pred'] = []
    threshold_metric_dict['min_p'] = []
    threshold_metric_dict['max_n'] = []
    threshold_metric_dict['gap'] = []
    prune_gold_num = 0
    analysis_result_list = []
    # print(predictions['sp'])
    for row in raw_data:
        qid = row['_id']
        question_type = row['type']
        answer_type = row['answer']
        if train_type is not None:
            qid = qid + '_' + train_type
        # print(qid)
        if answer_type.strip().lower() not in ['yes', 'no']:
            answer_type = 'span'
        if qid not in predictions['sp']:
            continue
        sp_predictions = predictions['sp'][qid]
        sp_predictions = [(x[0], x[1]) for x in sp_predictions]
        sp_para_predictions = list(set([x[0] for x in sp_predictions]))

        sp_golds = row['supporting_facts']
        sp_golds = [(x[0], x[1]) for x in sp_golds]
        sp_para_golds = list(set([_[0] for _ in sp_golds]))

        res_scores = prediction_scores[qid]
        sp_scores = res_scores['sp_score']
        sp_mask = res_scores['sp_mask']
        sp_names = res_scores['sp_names']
        sp_names = [(x[0], x[1]) for x in sp_names]
        flag, min_positive, max_negative, num_candidates, num_golds, min_p_names, max_n_names = \
            positive_neg_score(scores=sp_scores, mask=sp_mask, names=sp_names, gold_names=sp_golds, pred_names=sp_predictions)

        ##++++
        gap_names = score_gap_split(scores=sp_scores,
                                    mask=sp_mask,
                                    names=sp_names)
        ##++++

        ans_prediction = predictions['answer'][qid]
        raw_answer = row['answer']
        raw_answer = normalize_answer(raw_answer)
        ans_prediction = normalize_answer(ans_prediction)
        ans_metrics = update_answer(prediction=ans_prediction, gold=raw_answer)

        predict_metrics = update_sp(prediction=sp_predictions, gold=sp_golds)
        threshold_metric_dict['pred'].append((ans_metrics, predict_metrics))
        min_p_metrics = update_sp(prediction=min_p_names, gold=sp_golds)
        threshold_metric_dict['min_p'].append((ans_metrics, min_p_metrics))
        max_n_metrics = update_sp(prediction=max_n_names, gold=sp_golds)
        threshold_metric_dict['max_n'].append((ans_metrics, max_n_metrics))

        gap_metrics = update_sp(prediction=gap_names, gold=sp_golds)
        threshold_metric_dict['gap'].append((ans_metrics, gap_metrics))

        if not flag:
            prune_gold_num += 1

        sp_sent_type = set_comparison(prediction_list=sp_predictions,
                                      true_list=sp_golds)

        # for key, value in sp_scores.items():
        #     print(key, value)
        # print('{}\t{}\t{}\t{:.5f}\t{:.5f}'.format(question_type, sp_sent_type, flag, min_positive, max_negative))
        analysis_result_list.append(
            (qid, question_type, sp_sent_type, flag, min_positive,
             max_negative, num_candidates, num_golds, answer_type))

    for key, value in threshold_metric_dict.items():
        print('threshold type = {}'.format(key))
        answer_em, answer_prec, answer_recall, answer_f1 = 0.0, 0.0, 0.0, 0.0
        sp_em, sp_prec, sp_recall, sp_f1 = 0.0, 0.0, 0.0, 0.0
        type_count = len(value)
        all_joint_em, all_joint_f1 = 0.0, 0.0
        for ans_tup, sp_tup in value:
            answer_em += ans_tup[0]
            answer_prec += ans_tup[1]
            answer_recall += ans_tup[2]
            answer_f1 += ans_tup[3]

            sp_em += sp_tup[0]
            sp_prec += sp_tup[1]
            sp_recall += sp_tup[2]
            sp_f1 += sp_tup[3]

            joint_prec = ans_tup[1] * sp_tup[1]
            joint_recall = ans_tup[2] * sp_tup[2]
            if joint_prec + joint_recall > 0:
                joint_f1 = 2 * joint_prec * joint_recall / (joint_prec +
                                                            joint_recall)
            else:
                joint_f1 = 0.
            joint_em = ans_tup[0] * sp_tup[0]

            all_joint_f1 += joint_f1
            all_joint_em += joint_em

        print('ans\t{}\t{}\t{}\t{}'.format(answer_em / type_count,
                                           answer_recall / type_count,
                                           answer_prec / type_count,
                                           answer_f1 / type_count))
        print('sup\t{}\t{}\t{}\t{}'.format(sp_em / type_count,
                                           sp_recall / type_count,
                                           sp_prec / type_count,
                                           sp_f1 / type_count))
        print('joint_em\t', all_joint_em / type_count)
        print('joint_f1\t', all_joint_f1 / type_count)

    df = pd.DataFrame(analysis_result_list,
                      columns=[
                          'id', 'q_type', 'sp_sent_type', 'flag', 'min_p',
                          'max_n', 'cand_num', 'gold_num', 'ans_type'
                      ])

    print('prune = {}, complete = {}'.format(prune_gold_num,
                                             len(raw_data) - prune_gold_num))
    return df
def error_analysis(raw_data, predictions, tokenizer, use_ent_ans=False):
    yes_no_span_predictions = []
    yes_no_span_true = []
    prediction_ans_type_counter = Counter()
    prediction_sent_type_counter = Counter()
    prediction_para_type_counter = Counter()

    pred_ans_type_list = []
    pred_sent_type_list = []
    pred_doc_type_list = []
    pred_sent_count_list = []

    pred_para_count_list = []
    ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    for row in raw_data:
        qid = row['_id']
        sp_predictions = predictions['sp'][qid]
        sp_predictions = [(x[0], x[1]) for x in sp_predictions]
        ans_prediction = predictions['answer'][qid]

        raw_answer = row['answer']
        raw_answer = normalize_answer(raw_answer)
        ans_prediction = normalize_answer(ans_prediction)
        sp_golds = row['supporting_facts']
        sp_golds = [(x[0], x[1]) for x in sp_golds]
        sp_para_golds = list(set([_[0] for _ in sp_golds]))
        ##+++++++++++
        # sp_predictions = [x for x in sp_predictions if x[0] in sp_para_golds]
        # sp_predictions
        print("{}\t{}\t{}".format(qid, len(set(sp_golds)),
                                  len(set(sp_predictions))))
        sp_para_predictions = list(set([x[0] for x in sp_predictions]))
        pred_para_count_list.append(len(sp_para_predictions))
        # +++++++++++
        if len(set(sp_golds)) > len(set(sp_predictions)):
            pred_sent_count_list.append('less')
        elif len(set(sp_golds)) < len(set(sp_predictions)):
            pred_sent_count_list.append('more')
        else:
            pred_sent_count_list.append('equal')
        ##+++++++++++
        sp_sent_type = set_comparison(prediction_list=sp_predictions,
                                      true_list=sp_golds)
        ###+++++++++
        prediction_sent_type_counter[sp_sent_type] += 1
        pred_sent_type_list.append(sp_sent_type)
        ###+++++++++
        sp_para_preds = list(set([_[0] for _ in sp_predictions]))
        para_type = set_comparison(prediction_list=sp_para_preds,
                                   true_list=sp_para_golds)
        prediction_para_type_counter[para_type] += 1
        pred_doc_type_list.append(para_type)
        ###+++++++++
        if raw_answer not in ['yes', 'no']:
            yes_no_span_true.append('span')
        else:
            yes_no_span_true.append(raw_answer)

        if ans_prediction not in ['yes', 'no']:
            yes_no_span_predictions.append('span')
        else:
            yes_no_span_predictions.append(ans_prediction)

        ans_type = 'em'
        if raw_answer not in ['yes', 'no']:
            if raw_answer == ans_prediction:
                ans_type = 'em'
            elif raw_answer in ans_prediction:
                # print('{}: {} |{}'.format(qid, raw_answer, ans_prediction))
                # print('-'*75)
                ans_type = 'super_of_gold'
            elif ans_prediction in raw_answer:
                # print('{}: {} |{}'.format(qid, raw_answer, ans_prediction))
                # print('-'*75)
                ans_type = 'sub_of_gold'
            else:
                ans_pred_tokens = ans_prediction.split(' ')
                ans_raw_tokens = raw_answer.split(' ')
                is_empty_set = len(
                    set(ans_pred_tokens).intersection(
                        set(ans_raw_tokens))) == 0
                if is_empty_set:
                    ans_type = 'no_over_lap'
                else:
                    ans_type = 'others'
        else:
            if raw_answer == ans_prediction:
                ans_type = 'em'
            else:
                ans_type = 'others'

        prediction_ans_type_counter[ans_type] += 1
        pred_ans_type_list.append(ans_type)

        # print('{} | {} | {}'.format(ans_type, raw_answer, ans_prediction))

    print(len(pred_sent_type_list), len(pred_ans_type_list),
          len(pred_doc_type_list))

    supp_sent_compare_type = ['equal', 'less', 'more']
    result_types = [
        'em', 'sub_of_gold', 'super_of_gold', 'no_over_lap', 'others'
    ]
    supp_sent_comp_dict = {y: x for x, y in enumerate(supp_sent_compare_type)}
    supp_sent_type_dict = {y: x for x, y in enumerate(result_types)}
    assert len(pred_sent_type_list) == len(pred_sent_count_list)
    print(len(pred_sent_type_list), len(pred_sent_count_list))
    conf_supp_sent_matrix = np.zeros(
        (len(supp_sent_compare_type), len(result_types)), dtype=np.int64)
    for idx in range(len(pred_sent_type_list)):
        comp_type_i = pred_sent_count_list[idx]
        supp_sent_type_i = pred_sent_type_list[idx]
        comp_idx_i = supp_sent_comp_dict[comp_type_i]
        supp_sent_idx_i = supp_sent_type_dict[supp_sent_type_i]
        conf_supp_sent_matrix[comp_idx_i][supp_sent_idx_i] += 1
    print('Sent Type vs Sent Count conf matrix:\n{}'.format(
        conf_supp_sent_matrix))
    print('Sum of matrix = {}'.format(conf_supp_sent_matrix.sum()))

    conf_matrix = confusion_matrix(yes_no_span_true,
                                   yes_no_span_predictions,
                                   labels=["yes", "no", "span"])
    conf_ans_sent_matrix = confusion_matrix(pred_sent_type_list,
                                            pred_ans_type_list,
                                            labels=result_types)
    print('*' * 75)
    print('Ans type conf matrix:\n{}'.format(conf_matrix))
    print('*' * 75)
    print('Sent vs ans conf matrix:\n{}'.format(conf_ans_sent_matrix))
    print('*' * 75)
    print("Ans prediction type: {}".format(prediction_ans_type_counter))
    print("Sent prediction type: {}".format(prediction_sent_type_counter))
    print("Para prediction type: {}".format(prediction_para_type_counter))
    print('*' * 75)

    conf_matrix_para_vs_sent = confusion_matrix(pred_doc_type_list,
                                                pred_sent_type_list,
                                                labels=result_types)
    print('Para Type vs Sent Type conf matrix:\n{}'.format(
        conf_matrix_para_vs_sent))
    print('*' * 75)
    conf_matrix_para_vs_ans = confusion_matrix(pred_doc_type_list,
                                               pred_ans_type_list,
                                               labels=result_types)
    print('Para Type vs ans Type conf matrix:\n{}'.format(
        conf_matrix_para_vs_ans))
    para_counter = Counter(pred_para_count_list)
    print('Para counter : {}'.format(para_counter))
def read_hotpot_examples(para_file, full_file, ner_file, doc_link_file):

    with open(para_file, 'r', encoding='utf-8') as reader:
        para_data = json.load(reader)

    with open(full_file, 'r', encoding='utf-8') as reader:
        full_data = json.load(reader)

    with open(ner_file, 'r', encoding='utf-8') as reader:
        ner_data = json.load(reader)

    with open(doc_link_file, 'r', encoding='utf-8') as reader:
        doc_link_data = json.load(reader)

    def split_sent(sent, offset=0):
        nlp_doc = nlp(sent)
        words, word_start_idx, char_to_word_offset = [], [], []
        for token in nlp_doc:
            # token match a-b, then split further
            words.append(token.text)
            word_start_idx.append(token.idx)

        word_offset = 0
        for c in range(len(sent)):
            if word_offset >= len(word_start_idx) - 1 or c < word_start_idx[
                    word_offset + 1]:
                char_to_word_offset.append(word_offset + offset)
            else:
                char_to_word_offset.append(word_offset + offset + 1)
                word_offset += 1
        return words, char_to_word_offset, word_start_idx
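    # e.g. split_sent("Hi there") ->
    #   words               = ['Hi', 'there']
    #   char_to_word_offset = [0, 0, 0, 1, 1, 1, 1, 1]
    #   word_start_idx      = [0, 3]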

    max_sent_cnt, max_entity_cnt = 0, 0

    examples = []
    for case in tqdm(full_data):
        key = case['_id']
        qas_type = case['type']
        sup_facts = set([(sp[0], sp[1]) for sp in case['supporting_facts']])
        context = dict(case['context'])

        doc_tokens = []
        sent_names = []
        sup_facts_sent_id = []
        sup_para_id = set()
        sent_start_end_position = []
        para_start_end_position = []
        ques_entity_start_end_position = []
        ques_entities_text = []
        ctx_entity_start_end_position = []
        ctx_entities_text = []
        ctx_text = ""
        ans_start_position, ans_end_position = [], []
        ques_answer_ids, ctx_answer_ids = [], []

        title_to_id, title_id = {}, 0
        sent_to_id, sent_id = {}, 0
        s_e_edges = []
        s_s_edges = []
        p_s_edges = []

        ctx_answer_candidates = []
        ctx_char_to_word_offset = []  # Accumulated along all sentences
        ctx_word_to_char_idx = []

        # process question entity span
        question_text = case['question']
        question_tokens, ques_char_to_word_offset, ques_word_to_char_idx = split_sent(
            question_text)
        answer_norm = normalize_answer(case['answer'])

        q_e_edges = []
        for q_ent, q_start, q_end, q_type in ner_data[key]['question']:
            q_ent_text = question_text[q_start:q_end]
            if q_type != 'CONTEXT' and q_ent_text not in ques_entities_text:
                if len(ques_answer_ids) == 0 and normalize_answer(
                        q_ent_text) == answer_norm:
                    ques_answer_ids.append(len(ques_entities_text))

                ques_entities_text.append(q_ent_text)
                q_e_edges.append((0, len(ques_entity_start_end_position)
                                  ))  # Q -> E; the id of Q is 0
                ques_entity_start_end_position.append(
                    (ques_char_to_word_offset[q_start],
                     ques_char_to_word_offset[q_end - 1]))

        sel_paras = para_data[key]
        ner_context = dict(ner_data[key]['context'])

        for title in itertools.chain.from_iterable(sel_paras):
            stripped_title = re.sub(r' \(.*?\)$', '', title)
            stripped_title_norm = normalize_answer(stripped_title)

            sents = context[title]
            sents_ner = ner_context[title]
            assert len(sents) == len(sents_ner)

            title_to_id[title] = title_id

            para_start_position = len(doc_tokens)
            prev_sent_id = None

            ctx_answer_set = set()
            for local_sent_id, (sent,
                                sent_ner) in enumerate(zip(sents, sents_ner)):
                # Determine the global sent id for supporting facts
                local_sent_name = (title, local_sent_id)
                sent_to_id[local_sent_name] = sent_id
                sent_names.append(local_sent_name)

                # P -> S
                p_s_edges.append((title_id, sent_id))
                if prev_sent_id is not None:
                    # S -> S
                    s_s_edges.append((prev_sent_id, sent_id))

                sent += " "
                ctx_text += sent
                sent_start_word_id = len(doc_tokens)
                sent_start_char_id = len(ctx_char_to_word_offset)

                cur_sent_words, cur_sent_char_to_word_offset, cur_sent_words_start_idx = split_sent(
                    sent, offset=len(doc_tokens))
                doc_tokens.extend(cur_sent_words)
                ctx_char_to_word_offset.extend(cur_sent_char_to_word_offset)
                for cur_sent_word in cur_sent_words_start_idx:
                    ctx_word_to_char_idx.append(sent_start_char_id +
                                                cur_sent_word)
                assert len(doc_tokens) == len(ctx_word_to_char_idx)

                sent_start_end_position.append(
                    (sent_start_word_id, len(doc_tokens) - 1))

                for sent_ner_id, (_, ent_start_char, ent_end_char,
                                  _) in enumerate(sent_ner):
                    if (ent_start_char, ent_end_char) in ctx_answer_set:
                        continue
                    s_ent_text = sent[ent_start_char:ent_end_char]
                    s_ent_text_norm = normalize_answer(s_ent_text)

                    if s_ent_text_norm == stripped_title_norm:
                        ctx_answer_candidates.append(len(ctx_entities_text))

                        if local_sent_name in sup_facts:
                            if len(ctx_answer_ids
                                   ) == 0 and s_ent_text_norm == answer_norm:
                                ctx_answer_ids.append(len(ctx_entities_text))

                    ctx_entities_text.append(s_ent_text)
                    s_e_edges.append(
                        (sent_id, len(ctx_entity_start_end_position)))
                    ctx_entity_start_end_position.append(
                        (ctx_char_to_word_offset[sent_start_char_id +
                                                 ent_start_char],
                         ctx_char_to_word_offset[sent_start_char_id +
                                                 ent_end_char - 1]))

                    ctx_answer_set.add((ent_start_char, ent_end_char))

                # Find answer position
                if local_sent_name in sup_facts:
                    sup_para_id.add(title_id)
                    sup_facts_sent_id.append(sent_id)

                    answer_offsets = []
                    # find word offset
                    for cur_word_start_idx in cur_sent_words_start_idx:
                        if sent[cur_word_start_idx:cur_word_start_idx +
                                len(case['answer'])] == case['answer']:
                            answer_offsets.append(cur_word_start_idx)
                    if len(answer_offsets) == 0:
                        answer_offset = sent.find(case['answer'])
                        if answer_offset != -1:
                            answer_offsets.append(answer_offset)
                    if case['answer'] not in ['yes', 'no'
                                              ] and len(answer_offsets) > 0:
                        for answer_offset in answer_offsets:
                            start_char_position = sent_start_char_id + answer_offset
                            end_char_position = start_char_position + len(
                                case['answer']) - 1
                            ans_start_position.append(
                                ctx_char_to_word_offset[start_char_position])
                            ans_end_position.append(
                                ctx_char_to_word_offset[end_char_position])
                prev_sent_id = sent_id
                sent_id += 1
            para_end_position = len(doc_tokens) - 1
            para_start_end_position.append(
                (para_start_position, para_end_position, title))

            title_id += 1

        p_p_edges = []
        s_p_edges = []
        for _l in sel_paras[0]:
            for _r in sel_paras[1]:
                # edges: P -> P
                p_p_edges.append((title_to_id[_l], title_to_id[_r]))

                # edges: S -> P
                for local_sent_id, link_titles in enumerate(
                        doc_link_data[_l]['hyperlink_titles']):
                    inter_titles = set(link_titles) & set(title_to_id.keys())
                    if len(inter_titles) > 0 and _r in inter_titles:
                        s_p_edges.append(
                            (sent_to_id[(_l, local_sent_id)], title_to_id[_r]))
        q_p_edges = [(0, title_to_id[para]) for para in sel_paras[0]]

        edges = {
            'ques_para': q_p_edges,
            'para_para': p_p_edges,
            'sent_sent': s_s_edges,
            'para_sent': p_s_edges,
            'sent_para': s_p_edges,
            'ques_ent': q_e_edges,
            'sent_ent': s_e_edges
        }

        max_sent_cnt = max(max_sent_cnt, len(sent_start_end_position))
        max_entity_cnt = max(max_entity_cnt,
                             len(ctx_entity_start_end_position))

        if len(ans_start_position) > 1:
            # take the exact match for answer to avoid case of partial match
            start_position, end_position = [], []
            for _start_pos, _end_pos in zip(ans_start_position,
                                            ans_end_position):
                if normalize_answer(" ".join(
                        doc_tokens[_start_pos:_end_pos +
                                   1])) == normalize_answer(case['answer']):
                    start_position.append(_start_pos)
                    end_position.append(_end_pos)
def read_hotpot_examples(para_file,
                         full_file,
                         ner_file,
                         doc_link_file,
                         data_source_type=None):
    with open(para_file, 'r', encoding='utf-8') as reader:
        para_data = json.load(reader)

    with open(full_file, 'r', encoding='utf-8') as reader:
        full_data = json.load(reader)

    with open(ner_file, 'r', encoding='utf-8') as reader:
        ner_data = json.load(reader)

    with open(doc_link_file, 'r', encoding='utf-8') as reader:
        doc_link_data = json.load(reader)

    def split_sent(sent, offset=0):
        nlp_doc = nlp(sent)
        words, word_start_idx, char_to_word_offset = [], [], []
        for token in nlp_doc:
            # token match a-b, then split further
            words.append(token.text)
            word_start_idx.append(token.idx)

        word_offset = 0
        for c in range(len(sent)):
            if word_offset >= len(word_start_idx) - 1 or c < word_start_idx[
                    word_offset + 1]:
                char_to_word_offset.append(word_offset + offset)
            else:
                char_to_word_offset.append(word_offset + offset + 1)
                word_offset += 1
        return words, char_to_word_offset, word_start_idx

    max_sent_cnt, max_entity_cnt = 0, 0

    examples = []
    for case in tqdm(full_data):
        key = case['_id']
        qas_type = case['type']
        sup_facts = set([(sp[0], sp[1]) for sp in case['supporting_facts']])
        context = dict(case['context'])

        doc_tokens = []  ## spacy tokenized results
        sent_names = []  ## list of (title, local sentence index)
        sup_facts_sent_id = []  ## sent_id (absolute index in the concatenated text)
        sup_para_id = set()  ## supporting paragraph ids --> for para ranking
        sent_start_end_position = []  ## list of (start, end) position tuples
        para_start_end_position = []  ## list of (start, end, title) tuples
        ques_entity_start_end_position = []  ## entity (start, end) positions in the question
        ques_entities_text = []  ## question entities
        ctx_entity_start_end_position = []  ## entity (start, end) positions in the context
        ctx_entities_text = []  ## context entities
        ctx_text = ""  ## concatenated context text
        ans_start_position, ans_end_position = [], []  ## answer start/end word positions
        ques_answer_ids, ctx_answer_ids = [], []  ## answer entity ids in question / context

        title_to_id, title_id = {}, 0
        sent_to_id, sent_id = {}, 0
        s_e_edges = []  ### 1) sentence2entity edges
        s_s_edges = []  ### 2) sentence2sentence edges (in single paragraph)
        p_s_edges = []  ### 3) para2sentence

        ctx_answer_candidates = []
        ctx_char_to_word_offset = []  # Accumulated along all sentences
        ctx_word_to_char_idx = []

        # process question entity span
        question_text = case['question']
        question_tokens, ques_char_to_word_offset, ques_word_to_char_idx = split_sent(
            question_text)
        answer_norm = normalize_answer(case['answer'])

        q_e_edges = []  ### 4) question2entity edges
        for q_ent, q_start, q_end, q_type in ner_data[key]['question']:
            q_ent_text = question_text[q_start:q_end]
            if q_type != 'CONTEXT' and q_ent_text not in ques_entities_text:
                if len(ques_answer_ids) == 0 and normalize_answer(
                        q_ent_text) == answer_norm:
                    ques_answer_ids.append(len(ques_entities_text))

                ques_entities_text.append(q_ent_text)
                q_e_edges.append((0, len(ques_entity_start_end_position)
                                  ))  # Q -> E; the id of Q is 0
                ques_entity_start_end_position.append(
                    (ques_char_to_word_offset[q_start],
                     ques_char_to_word_offset[q_end - 1]))

        sel_paras = para_data[key]
        ner_context = dict(ner_data[key]['context'])

        ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        para_names = []  ## for paragraph evaluation and checking
        ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

        for title in itertools.chain.from_iterable(sel_paras):
            stripped_title = re.sub(r' \(.*?\)$', '', title)
            stripped_title_norm = normalize_answer(stripped_title)
            ####+++++++++++++++++++++++++
            para_names.append(title)
            ####+++++++++++++++++++++++++

            sents = context[title]
            sents_ner = ner_context[title]
            assert len(sents) == len(sents_ner)

            title_to_id[title] = title_id

            para_start_position = len(doc_tokens)
            prev_sent_id = None

            ctx_answer_set = set()
            for local_sent_id, (sent,
                                sent_ner) in enumerate(zip(sents, sents_ner)):
                # Determine the global sent id for supporting facts
                local_sent_name = (title, local_sent_id)
                sent_to_id[local_sent_name] = sent_id
                sent_names.append(local_sent_name)

                # P -> S
                p_s_edges.append((title_id, sent_id))  ###
                if prev_sent_id is not None:
                    # S -> S
                    s_s_edges.append((prev_sent_id, sent_id))

                sent += " "
                ctx_text += sent
                sent_start_word_id = len(doc_tokens)
                sent_start_char_id = len(ctx_char_to_word_offset)

                cur_sent_words, cur_sent_char_to_word_offset, cur_sent_words_start_idx = split_sent(
                    sent, offset=len(doc_tokens))
                doc_tokens.extend(cur_sent_words)
                ctx_char_to_word_offset.extend(cur_sent_char_to_word_offset)
                for cur_sent_word in cur_sent_words_start_idx:
                    ctx_word_to_char_idx.append(sent_start_char_id +
                                                cur_sent_word)
                assert len(doc_tokens) == len(ctx_word_to_char_idx)

                sent_start_end_position.append(
                    (sent_start_word_id, len(doc_tokens) - 1))

                for sent_ner_id, (_, ent_start_char, ent_end_char,
                                  _) in enumerate(sent_ner):
                    if (ent_start_char, ent_end_char) in ctx_answer_set:
                        continue
                    s_ent_text = sent[ent_start_char:ent_end_char]
                    s_ent_text_norm = normalize_answer(s_ent_text)

                    if s_ent_text_norm == stripped_title_norm:
                        ctx_answer_candidates.append(len(ctx_entities_text))

                        if local_sent_name in sup_facts:
                            if len(ctx_answer_ids
                                   ) == 0 and s_ent_text_norm == answer_norm:
                                ctx_answer_ids.append(len(ctx_entities_text))

                    ctx_entities_text.append(s_ent_text)
                    s_e_edges.append(
                        (sent_id, len(ctx_entity_start_end_position)))
                    ctx_entity_start_end_position.append(
                        (ctx_char_to_word_offset[sent_start_char_id +
                                                 ent_start_char],
                         ctx_char_to_word_offset[sent_start_char_id +
                                                 ent_end_char - 1]))

                    ctx_answer_set.add((ent_start_char, ent_end_char))

                # Find answer position
                if local_sent_name in sup_facts:
                    sup_para_id.add(title_id)
                    sup_facts_sent_id.append(sent_id)

                    answer_offsets = []
                    # find word offset
                    for cur_word_start_idx in cur_sent_words_start_idx:
                        if sent[cur_word_start_idx:cur_word_start_idx +
                                len(case['answer'])] == case['answer']:
                            answer_offsets.append(cur_word_start_idx)
                    if len(answer_offsets) == 0:
                        answer_offset = sent.find(case['answer'])
                        if answer_offset != -1:
                            answer_offsets.append(answer_offset)
                    if case['answer'] not in ['yes', 'no'
                                              ] and len(answer_offsets) > 0:
                        for answer_offset in answer_offsets:
                            start_char_position = sent_start_char_id + answer_offset
                            end_char_position = start_char_position + len(
                                case['answer']) - 1
                            ans_start_position.append(
                                ctx_char_to_word_offset[start_char_position])
                            ans_end_position.append(
                                ctx_char_to_word_offset[end_char_position])
                prev_sent_id = sent_id
                sent_id += 1
            para_end_position = len(doc_tokens) - 1
            para_start_end_position.append(
                (para_start_position, para_end_position, title))

            title_id += 1

        p_p_edges = []  ## 5) paragraph2paragraph edges
        s_p_edges = []  ## 6) sentence2paragraph edges
        for _l in sel_paras[0]:
            for _r in sel_paras[1]:
                # edges: P -> P
                p_p_edges.append((title_to_id[_l], title_to_id[_r]))

                # edges: S -> P
                for local_sent_id, link_titles in enumerate(
                        doc_link_data[_l]['hyperlink_titles']):
                    inter_titles = set(link_titles) & set(title_to_id.keys())
                    if len(inter_titles) > 0 and _r in inter_titles:
                        s_p_edges.append(
                            (sent_to_id[(_l, local_sent_id)], title_to_id[_r]))
        # print('selected paragraphs {}'.format(sel_paras))
        q_p_edges = [(0, title_to_id[para])
                     for para in sel_paras[0]]  ### 7) question2paragraph edges

        edges = {
            'ques_para': q_p_edges,
            'para_para': p_p_edges,
            'sent_sent': s_s_edges,
            'para_sent': p_s_edges,
            'sent_para': s_p_edges,
            'ques_ent': q_e_edges,
            'sent_ent': s_e_edges
        }

        ###########+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        def sae_graph_edges(edges, ques_entities_text, ctx_entities_text):
            def tuple_to_dict(tuple_list):
                res = {}
                for tup in tuple_list:
                    if tup[0] not in res:
                        res[tup[0]] = [tup[1]]
                    else:
                        res[tup[0]].append(tup[1])
                return res
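            # e.g. tuple_to_dict([(0, 1), (0, 2), (1, 3)]) -> {0: [1, 2], 1: [3]}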

            #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            para_sent_edges = edges['para_sent']
            sents_in_para_dict = tuple_to_dict(tuple_list=para_sent_edges)
            sent_to_sent_in_doc_edges = []
            for key, sent_list in sents_in_para_dict.items():
                sent_list = sorted(sent_list)  ### increasing order
                if len(sent_list) > 1:
                    for i in range(len(sent_list) - 1):
                        for j in range(i + 1, len(sent_list)):
                            sent_to_sent_in_doc_edges.append(
                                (sent_list[i], sent_list[j]))
            # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            query_ent_edges = edges['ques_ent']
            assert len(query_ent_edges) == len(
                ques_entities_text)  ### equal to number of entities in query
            # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            sent_ent_edges = edges['sent_ent']
            assert len(sent_ent_edges) == len(ctx_entities_text)
            norm_ctx_entities_text = [
                normalize_text(_) for _ in ctx_entities_text
            ]
            norm_ctx_ent_pair = [
                (w[0], w[1][0])
                for w in zip(norm_ctx_entities_text, sent_ent_edges)
            ]  ## tuple (normed entity, sent id)
            sents_for_norm_ent_dict = tuple_to_dict(
                tuple_list=norm_ctx_ent_pair
            )  ## key: normed entity, value: sent ids
            ents_in_sent_dict = tuple_to_dict(
                tuple_list=sent_ent_edges)  ## key: sentence, value: entities
            for key in sents_for_norm_ent_dict.keys():
                sents_for_norm_ent_dict[key] = sorted(
                    list(set(sents_for_norm_ent_dict[key])))  ### distinct
            # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            norm_ques_entities_text = [
                normalize_text(_) for _ in ques_entities_text
            ]
            norm_ques_entities_text = list(
                set(norm_ques_entities_text))  ## distinct normalized entities

            def shared_query_entity_sent_edges(norm_ques_entities_text,
                                               ents_in_sent_dict,
                                               sents_for_norm_ent_dict,
                                               para_sent_edges):
                sent_to_sent_shared_edges = []
                norm_ques_entities_text_filter = [
                    _ for _ in norm_ques_entities_text
                    if _ in sents_for_norm_ent_dict
                ]
                for i in range(len(norm_ques_entities_text_filter) - 1):
                    sent_list_i = sents_for_norm_ent_dict[
                        norm_ques_entities_text_filter[i]]
                    for j in range(i + 1, len(norm_ques_entities_text_filter)):
                        sent_list_j = sents_for_norm_ent_dict[
                            norm_ques_entities_text_filter[j]]
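                        # zip pairs the two sentence lists positionally, so
                        # only min(len(sent_list_i), len(sent_list_j))
                        # candidate pairs are examined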
                        for l, r in zip(sent_list_i, sent_list_j):
                            sent_pair = (l, r) if l < r else (r, l)
                            if para_sent_edges[sent_pair[0]][
                                    0] != para_sent_edges[sent_pair[1]][0]:
                                ents_l = set(ents_in_sent_dict[sent_pair[0]])
                                ents_r = set(ents_in_sent_dict[sent_pair[1]])
                                if (sent_pair not in sent_to_sent_shared_edges
                                    ) and (len(ents_l.intersection(ents_r))
                                           == 0):
                                    sent_to_sent_shared_edges.append(sent_pair)
                return sent_to_sent_shared_edges

            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            sent_to_sent_query_cross_edges = shared_query_entity_sent_edges(
                norm_ques_entities_text, ents_in_sent_dict,
                sents_for_norm_ent_dict, para_sent_edges)

            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            def doc_cross_entity_sent_edges(sents_for_norm_ent_dict,
                                            para_sent_edges):
                sent_to_sent_cross_edges = []
                sents_for_norm_ent_filter = [
                    (key, value)
                    for key, value in sents_for_norm_ent_dict.items()
                    if len(value) > 1
                ]
                for key, sent_list in sents_for_norm_ent_filter:
                    sent_list = sorted(sent_list)
                    for i in range(len(sent_list) - 1):
                        for j in range(i + 1, len(sent_list)):
                            if para_sent_edges[sent_list[i]][
                                    0] != para_sent_edges[sent_list[j]][0]:
                                sent_pair = (sent_list[i], sent_list[j])
                                if sent_pair not in sent_to_sent_cross_edges:
                                    sent_to_sent_cross_edges.append(sent_pair)
                return sent_to_sent_cross_edges

            sent_to_sent_para_cross_edges = doc_cross_entity_sent_edges(
                sents_for_norm_ent_dict, para_sent_edges)
            return sent_to_sent_in_doc_edges, sent_to_sent_query_cross_edges, sent_to_sent_para_cross_edges

        ###########+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        s_s_edges, s_s_q_edges, s_s_p_edges = sae_graph_edges(
            edges=edges,
            ctx_entities_text=ctx_entities_text,
            ques_entities_text=ques_entities_text)
        ###########+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        edges['sent_sent'] = s_s_p_edges
        edges[
            'sent_sent_cross'] = s_s_q_edges + s_s_p_edges  ### updating edges
        ###########+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        max_sent_cnt = max(max_sent_cnt, len(sent_start_end_position))
        max_entity_cnt = max(max_entity_cnt,
                             len(ctx_entity_start_end_position))

        if len(ans_start_position) > 1:
            # take the exact match for answer to avoid case of partial match
            start_position, end_position = [], []
            for _start_pos, _end_pos in zip(ans_start_position,
                                            ans_end_position):
                if normalize_answer(" ".join(
                        doc_tokens[_start_pos:_end_pos +
                                   1])) == normalize_answer(case['answer']):
                    start_position.append(_start_pos)
                    end_position.append(_end_pos)
def error_analysis(raw_data, examples, features, predictions, tokenizer, use_ent_ans=False):
    yes_no_span_predictions = []
    yes_no_span_true = []
    prediction_ans_type_counter = Counter()
    prediction_sent_type_counter = Counter()
    prediction_para_type_counter = Counter()
    pred_ans_type_list = []
    pred_sent_type_list = []
    pred_doc_type_list = []
    ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    for row in raw_data:
        qid = row['_id']
        sp_predictions = predictions['sp'][qid]
        sp_predictions = [(x[0], x[1]) for x in sp_predictions]
        ans_prediction = predictions['answer'][qid]

        raw_answer = row['answer']
        raw_answer = normalize_answer(raw_answer)
        ans_prediction = normalize_answer(ans_prediction)
        sp_golds = row['supporting_facts']
        sp_golds = [(x[0], x[1]) for x in sp_golds]
        sp_para_golds = list(set([_[0] for _ in sp_golds]))
        ##+++++++++++
        sp_predictions = [x for x in sp_predictions if x[0] in sp_para_golds]
        # print(len(set(sp_predictions)))
        ##+++++++++++
        sp_sent_type = set_comparison(prediction_list=sp_predictions, true_list=sp_golds)
        ###+++++++++
        prediction_sent_type_counter[sp_sent_type] +=1
        pred_sent_type_list.append(sp_sent_type)
        ###+++++++++
        sp_para_preds = list(set([_[0] for _ in sp_predictions]))
        para_type = set_comparison(prediction_list=sp_para_preds, true_list=sp_para_golds)
        prediction_para_type_counter[para_type] += 1
        pred_doc_type_list.append(para_type)
        ###+++++++++
        if raw_answer not in ['yes', 'no']:
            yes_no_span_true.append('span')
        else:
            yes_no_span_true.append(raw_answer)

        if ans_prediction not in ['yes', 'no']:
            yes_no_span_predictions.append('span')
        else:
            yes_no_span_predictions.append(ans_prediction)

        ans_type = 'em'
        if raw_answer not in ['yes', 'no']:
            if raw_answer == ans_prediction:
                ans_type = 'em'
            elif raw_answer in ans_prediction:
                # print('{}: {} |{}'.format(qid, raw_answer, ans_prediction))
                # print('-'*75)
                ans_type = 'super_of_gold'
            elif ans_prediction in raw_answer:
                # print('{}: {} |{}'.format(qid, raw_answer, ans_prediction))
                # print('-'*75)
                ans_type = 'sub_of_gold'
            else:
                ans_pred_tokens = ans_prediction.split(' ')
                ans_raw_tokens = raw_answer.split(' ')
                is_empty_set = len(set(ans_pred_tokens).intersection(set(ans_raw_tokens))) == 0
                if is_empty_set:
                    ans_type = 'no_over_lap'
                else:
                    ans_type = 'others'
        else:
            if raw_answer == ans_prediction:
                ans_type = 'em'
            else:
                ans_type = 'others'

        prediction_ans_type_counter[ans_type] += 1
        pred_ans_type_list.append(ans_type)


    print(len(pred_sent_type_list), len(pred_ans_type_list), len(pred_doc_type_list))

    result_types = ['em', 'sub_of_gold', 'super_of_gold', 'no_over_lap', 'others']
    conf_matrix = confusion_matrix(yes_no_span_true, yes_no_span_predictions, labels=["yes", "no", "span"])
    conf_ans_sent_matrix = confusion_matrix(pred_sent_type_list, pred_ans_type_list, labels=result_types)
    print('*' * 75)
    print('Ans type conf matrix:\n{}'.format(conf_matrix))
    print('*' * 75)
    print('Sent vs Ans conf matrix:\n{}'.format(conf_ans_sent_matrix))
    print('*' * 75)
    print("Ans prediction type: {}".format(prediction_ans_type_counter))
    print("Sent prediction type: {}".format(prediction_sent_type_counter))
    print("Para prediction type: {}".format(prediction_para_type_counter))
    print('*' * 75)

    conf_matrix_para_vs_sent = confusion_matrix(pred_doc_type_list, pred_sent_type_list, labels=result_types)
    print('Para Type vs Sent Type conf matrix:\n{}'.format(conf_matrix_para_vs_sent))
    print('*' * 75)
    conf_matrix_para_vs_ans = confusion_matrix(pred_doc_type_list, pred_ans_type_list, labels=result_types)
    print('Para Type vs Ans Type conf matrix:\n{}'.format(conf_matrix_para_vs_ans))
def predict(examples, features, pred_file, tokenizer, use_ent_ans=False):
    answer_dict = dict()
    sp_dict = dict()
    ids = list(examples.keys())

    max_sent_num = 0
    max_entity_num = 0
    q_type_counter = Counter()

    answer_no_match_cnt = 0
    for i, qid in enumerate(ids):
        feature = features[qid]
        example = examples[qid]
        q_type = feature.ans_type

        max_sent_num = max(max_sent_num, len(feature.sent_spans))
        max_entity_num = max(max_entity_num, len(feature.entity_spans))
        q_type_counter[q_type] += 1

        def get_ans_from_pos(y1, y2):
            tok_to_orig_map = feature.token_to_orig_map

            final_text = " "
            if y1 < len(tok_to_orig_map) and y2 < len(tok_to_orig_map):
                orig_tok_start = tok_to_orig_map[y1]
                orig_tok_end = tok_to_orig_map[y2]

                ques_tok_len = len(example.question_tokens)
                if orig_tok_start < ques_tok_len and orig_tok_end < ques_tok_len:
                    ques_start_idx = example.question_word_to_char_idx[orig_tok_start]
                    ques_end_idx = example.question_word_to_char_idx[orig_tok_end] + len(example.question_tokens[orig_tok_end])
                    final_text = example.question_text[ques_start_idx:ques_end_idx]
                else:
                    orig_tok_start -= len(example.question_tokens)
                    orig_tok_end -= len(example.question_tokens)
                    ctx_start_idx = example.ctx_word_to_char_idx[orig_tok_start]
                    ctx_end_idx = example.ctx_word_to_char_idx[orig_tok_end] + len(example.doc_tokens[orig_tok_end])
                    final_text = example.ctx_text[ctx_start_idx:ctx_end_idx]

            return final_text
            #return tokenizer.convert_tokens_to_string(tok_tokens)

        answer_text = ''
        if q_type == 0 or q_type == 3:
            if len(feature.start_position) == 0 or len(feature.end_position) == 0:
                answer_text = ""
            else:
                #st, ed = example.start_position[0], example.end_position[0]
                #answer_text = example.ctx_text[example.ctx_word_to_char_idx[st]:example.ctx_word_to_char_idx[ed]+len(example.doc_tokens[example.end_position[0]])]
                answer_text = get_ans_from_pos(feature.start_position[0], feature.end_position[0])
                if normalize_answer(answer_text) != normalize_answer(example.orig_answer_text):
                    print("{} | {} | {} | {} | {}".format(qid, answer_text, example.orig_answer_text, feature.start_position[0], feature.end_position[0]))
                    answer_no_match_cnt += 1
            if q_type == 3 and use_ent_ans:
                ans_id = feature.answer_in_entity_ids[0]
                st, ed = feature.entity_spans[ans_id]
                answer_text = get_ans_from_pos(st, ed)
        elif q_type == 1:
            answer_text = 'yes'
        elif q_type == 2:
            answer_text = 'no'

        answer_dict[qid] = answer_text
        cur_sp = []
        for sent_id in feature.sup_fact_ids:
            cur_sp.append(example.sent_names[sent_id])
        sp_dict[qid] = cur_sp

    final_pred = {'answer': answer_dict, 'sp': sp_dict}
    json.dump(final_pred, open(pred_file, 'w'))

    print("Maximum sentence num: {}".format(max_sent_num))
    print("Maximum entity num: {}".format(max_entity_num))
    print("Question type: {}".format(q_type_counter))
    print("Answer doesnot match: {}".format(answer_no_match_cnt))