Example #1
def choose_best_threshold(ans_dict, pred_file):
    # Note: N_thresh, thresholds, total_sp_dict, answer_type_dict,
    # answer_type_prob_dict and dev_gold_file are free variables from the
    # enclosing eval function; Example #4 below shows the full context.
    best_joint_f1 = 0
    best_metrics = None
    best_threshold = 0
    #####
    best_threshold_idx = -1
    #####
    for thresh_i in range(N_thresh):
        prediction = {
            'answer': ans_dict,
            'sp': total_sp_dict[thresh_i],
            'type': answer_type_dict,
            'type_prob': answer_type_prob_dict
        }
        tmp_file = os.path.join(os.path.dirname(pred_file), 'tmp.json')
        with open(tmp_file, 'w') as f:
            json.dump(prediction, f)
        metrics = hotpot_eval(tmp_file, dev_gold_file)
        if metrics['joint_f1'] >= best_joint_f1:
            best_joint_f1 = metrics['joint_f1']
            best_threshold = thresholds[thresh_i]
            #####
            best_threshold_idx = thresh_i
            #####
            best_metrics = metrics
            shutil.move(tmp_file, pred_file)

    return best_metrics, best_threshold, best_threshold_idx
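
A minimal, self-contained sketch of the same threshold sweep, with hotpot_eval replaced by a hypothetical evaluate() stub and toy sentence scores (every name and value below is illustrative, not part of the repo's API):

import numpy as np

thresholds = np.arange(0.1, 1.0, 0.025)

def evaluate(sp_pred):
    # Hypothetical stand-in for hotpot_eval: F1 of the predicted supporting
    # sentences against a toy gold set.
    gold = {('Doc A', 0), ('Doc B', 1)}
    pred = set(map(tuple, sp_pred))
    tp = len(gold & pred)
    prec = tp / len(pred) if pred else 0.0
    rec = tp / len(gold)
    f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
    return {'joint_f1': f1}

# Sigmoid scores for candidate supporting sentences, keyed by (title, sent_id).
sent_scores = {('Doc A', 0): 0.9, ('Doc A', 1): 0.4, ('Doc B', 1): 0.55}

best_joint_f1, best_threshold = 0.0, 0.0
for thresh in thresholds:
    sp_pred = [list(key) for key, score in sent_scores.items() if score > thresh]
    metrics = evaluate(sp_pred)
    if metrics['joint_f1'] >= best_joint_f1:
        best_joint_f1, best_threshold = metrics['joint_f1'], float(thresh)

print('best threshold = {:.3f}, joint_f1 = {:.2f}'.format(best_threshold, best_joint_f1))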
Example #2
def choose_best_threshold(ans_dict, pred_file):
    # Note: this variant is evidently a closure inside a PyTorch Lightning
    # module: self.trainer.root_gpu and self.hparams come from the enclosing
    # scope, and the per-GPU tmp file name avoids clashes across processes.
    best_joint_f1 = 0
    best_metrics = None
    best_threshold = 0
    #################
    metric_dict = {}
    #################
    for thresh_i in range(N_thresh):
        prediction = {'answer': ans_dict,
                      'sp': total_sp_dict[thresh_i],
                      'type': answer_type_dict,
                      'type_prob': answer_type_prob_dict}
        tmp_file = os.path.join(os.path.dirname(pred_file),
                                'tmp_{}.json'.format(self.trainer.root_gpu))
        with open(tmp_file, 'w') as f:
            json.dump(prediction, f)
        metrics = hotpot_eval(tmp_file, self.hparams.dev_gold_file)
        if metrics['joint_f1'] >= best_joint_f1:
            best_joint_f1 = metrics['joint_f1']
            best_threshold = thresholds[thresh_i]
            best_metrics = metrics
            shutil.move(tmp_file, pred_file)
        #######
        metric_dict[thresh_i] = (metrics['em'], metrics['f1'],
                                 metrics['sp_em'], metrics['sp_f1'],
                                 metrics['joint_em'], metrics['joint_f1'])
        #######
    return best_metrics, best_threshold, metric_dict
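
The metric_dict returned by this variant exposes the whole sweep, not just the winner. A hypothetical helper for dumping it as a table (the tuple layout follows the assignment inside the loop above):

def print_threshold_table(metric_dict, thresholds):
    header = ('thresh', 'em', 'f1', 'sp_em', 'sp_f1', 'joint_em', 'joint_f1')
    print(('{:>9}' * len(header)).format(*header))
    for thresh_i in sorted(metric_dict):
        row = metric_dict[thresh_i]
        print('{:>9.3f}'.format(thresholds[thresh_i])
              + ('{:>9.3f}' * len(row)).format(*row))

# Toy call:
print_threshold_table({0: (0.50, 0.60, 0.40, 0.55, 0.30, 0.45)}, [0.1])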
Example #3
print('-' * 75)
prediction_score_dict = jd_postprocess_score_prediction(
    args=args,
    model=model,
    data_loader=test_data_loader,
    threshold_category=threshold_category)
with open(prediction_score_file, 'w') as fp:
    json.dump(prediction_score_dict, fp)
print('Saving {} records into {}'.format(len(prediction_score_dict),
                                         prediction_score_file))
# #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
raw_test_data_file = join(args.input_dir, args.raw_test_data)
test_answer_file = join(args.output_dir, args.exp_name,
                        args.test_answer_predict_name)
post_predict_dict = jd_adaptive_threshold_post_process(
    full_file=raw_test_data_file,
    score_dict_file=output_test_score_file,
    prediction_answer_file=test_answer_file,
    threshold_pred_dict_file=prediction_score_file)
post_predict_file = join(args.output_dir, args.exp_name,
                         args.post_test_prediction_name)
with open(post_predict_file, 'w') as fp:
    json.dump(post_predict_dict, fp)
print('Saving {} records into {}'.format(len(post_predict_dict),
                                         post_predict_file))

raw_dev_data_file = join(args.input_dir, args.raw_dev_data)
metrics = hotpot_eval(post_predict_file, raw_dev_data_file)
for key, value in metrics.items():
    print('{}:{}'.format(key, value))
print('-' * 75)
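
For reference, the prediction files that hotpot_eval reads throughout these examples follow the official HotpotQA format: an 'answer' map from question id to answer string, and an 'sp' map from question id to [title, sentence_index] pairs (the 'type' and 'type_prob' keys seen above are extras specific to this repo). A toy example:

import json

prediction = {
    'answer': {'toy_question_id': 'yes'},
    'sp': {'toy_question_id': [['Doc A', 0], ['Doc B', 1]]},
}
with open('pred.json', 'w') as f:
    json.dump(prediction, f)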
Example #4
def jd_eval_model(args, encoder, model, dataloader, example_dict, feature_dict,
                  prediction_file, eval_file, dev_gold_file):
    encoder.eval()
    model.eval()

    answer_dict = {}
    answer_type_dict = {}
    answer_type_prob_dict = {}

    dataloader.refresh()

    thresholds = np.arange(0.1, 1.0, 0.025)
    N_thresh = len(thresholds)
    total_sp_dict = [{} for _ in range(N_thresh)]
    ##++++++++++++++++++++++++++++++++++
    total_para_sp_dict = {}
    ##++++++++++++++++++++++++++++++++++
    best_sp_dict = {}  # note: never filled below, so best_prediction['sp'] stays empty
    threshold_inter_count = 0  # note: never incremented in this snippet
    ##++++++++++++++++++++++++++++++++++

    for batch in tqdm(dataloader):
        with torch.no_grad():
            inputs = {
                'input_ids':
                batch['context_idxs'],
                'attention_mask':
                batch['context_mask'],
                'token_type_ids':
                batch['segment_idxs']
                if args.model_type in ['bert', 'xlnet', 'electra'] else None
            }  # XLM doesn't use segment_ids
            outputs = encoder(**inputs)

            batch['context_encoding'] = outputs[0]
            batch['context_mask'] = batch['context_mask'].float().to(
                args.device)
            start, end, q_type, paras, sent, ent, yp1, yp2 = model(
                batch, return_yp=True)

        type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(
            example_dict, feature_dict, batch['ids'],
            yp1.data.cpu().numpy().tolist(),
            yp2.data.cpu().numpy().tolist(), type_prob)
        para_mask = batch['para_mask']
        sent_mask = batch['sent_mask']
        # print(para_mask.shape, paras.shape)
        answer_type_dict.update(answer_type_dict_)
        answer_type_prob_dict.update(answer_type_prob_dict_)
        answer_dict.update(answer_dict_)
        ##++++++++++++++++++++++++++++++++++++++++
        paras = paras[:, :, 1] - (1 - para_mask) * 1e30
        predict_para_support_np = torch.sigmoid(paras).data.cpu().numpy()
        # predict_para_support_np = torch.sigmoid(paras[:, :, 1]).data.cpu().numpy()
        ##++++++++++++++++++++++++++++++++++++++++
        # print('sent shape {}'.format(sent.shape))
        sent = sent[:, :, 1] - (1 - sent_mask) * 1e30
        # predict_support_np = torch.sigmoid(sent[:, :, 1]).data.cpu().numpy()
        predict_support_np = torch.sigmoid(sent).data.cpu().numpy()
        # print('supp sent np shape {}'.format(predict_support_np.shape))
        for i in range(predict_support_np.shape[0]):
            cur_id = batch['ids'][i]
            predict_para_support_np_ith = predict_para_support_np[i]
            predict_support_np_ith = predict_support_np[i]
            # ################################################
            cur_para_sp_pred = supp_doc_prediction(
                predict_para_support_np_ith=predict_para_support_np_ith,
                example_dict=example_dict,
                batch_ids_ith=cur_id)
            total_para_sp_dict[cur_id] = cur_para_sp_pred
            # ################################################
            cur_sp_pred = supp_sent_prediction(
                predict_support_np_ith=predict_support_np_ith,
                example_dict=example_dict,
                batch_ids_ith=cur_id,
                thresholds=thresholds)
            # ###################################

            for thresh_i in range(N_thresh):
                if cur_id not in total_sp_dict[thresh_i]:
                    total_sp_dict[thresh_i][cur_id] = []

                total_sp_dict[thresh_i][cur_id].extend(cur_sp_pred[thresh_i])

    def choose_best_threshold(ans_dict, pred_file):
        best_joint_f1 = 0
        best_metrics = None
        best_threshold = 0
        #####
        best_threshold_idx = -1
        #####
        for thresh_i in range(N_thresh):
            prediction = {
                'answer': ans_dict,
                'sp': total_sp_dict[thresh_i],
                'type': answer_type_dict,
                'type_prob': answer_type_prob_dict
            }
            tmp_file = os.path.join(os.path.dirname(pred_file), 'tmp.json')
            with open(tmp_file, 'w') as f:
                json.dump(prediction, f)
            metrics = hotpot_eval(tmp_file, dev_gold_file)
            if metrics['joint_f1'] >= best_joint_f1:
                best_joint_f1 = metrics['joint_f1']
                best_threshold = thresholds[thresh_i]
                #####
                best_threshold_idx = thresh_i
                #####
                best_metrics = metrics
                shutil.move(tmp_file, pred_file)

        return best_metrics, best_threshold, best_threshold_idx

    best_metrics, best_threshold, best_threshold_idx = choose_best_threshold(
        answer_dict, prediction_file)
    ##############++++++++++++
    doc_recall_metric = doc_recall_eval(doc_prediction=total_para_sp_dict,
                                        gold_file=dev_gold_file)
    ##############++++++++++++
    with open(eval_file, 'w') as f:
        json.dump(best_metrics, f)
    # -------------------------------------
    best_prediction = {
        'answer': answer_dict,
        'sp': best_sp_dict,
        'type': answer_type_dict,
        'type_prob': answer_type_prob_dict
    }
    print('Number of inter threshold = {}'.format(threshold_inter_count))
    best_tmp_file = os.path.join(os.path.dirname(prediction_file),
                                 'best_tmp.json')
    with open(best_tmp_file, 'w') as f:
        json.dump(best_prediction, f)
    best_th_metrics = hotpot_eval(best_tmp_file, dev_gold_file)
    for key, val in best_th_metrics.items():
        print("{} = {}".format(key, val))
    # -------------------------------------
    return best_metrics, best_threshold, doc_recall_metric
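
doc_recall_eval itself is not shown in these snippets. A minimal sketch of what a document-recall metric over total_para_sp_dict could look like, assuming predictions map each question id to a list of paragraph titles and the gold file is in standard HotpotQA format ('_id' and 'supporting_facts' fields):

import json

def doc_recall(doc_prediction, gold_file):
    with open(gold_file) as f:
        gold = json.load(f)
    recalls = []
    for item in gold:
        # Gold supporting documents are the distinct titles in 'supporting_facts'.
        gold_titles = {title for title, _ in item['supporting_facts']}
        pred_titles = set(doc_prediction.get(item['_id'], []))
        recalls.append(len(gold_titles & pred_titles) / len(gold_titles))
    return sum(recalls) / len(recalls)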
Example #5
output_prediction_file = join(args.exp_name, 'prediction.json')

best_metrics, best_threshold = jd_unified_eval_model(
    args, model, test_data_loader, test_example_dict, test_feature_dict,
    output_pred_file, output_eval_file, args.dev_gold_file)
for key, val in best_metrics.items():
    print("{} = {}".format(key, val))
print('Best threshold = {}'.format(best_threshold))
threshold = best_threshold
predictions = jd_unified_test_model(args, model, test_data_loader,
                                    test_example_dict, test_feature_dict,
                                    threshold, output_test_score_file)
with open(output_prediction_file, 'w') as f:
    json.dump(predictions, f)
if args.dev_gold_file is not None:
    metrics = hotpot_eval(output_prediction_file, args.dev_gold_file)
    for key, value in metrics.items():
        print('{}:{}'.format(key, value))

# best_metrics, best_threshold = jd_postprocess_unified_eval_model(args, model, test_data_loader, test_example_dict, test_feature_dict,
#                                 output_pred_file, output_eval_file, args.dev_gold_file)
# for key, val in best_metrics.items():
#     print("{} = {}".format(key, val))
# print('Best threshold = {}'.format(best_threshold))
# threshold = best_threshold
# output_test_score_file = join(args.exp_name, 'test_score.json')
# output_prediction_file = join(args.exp_name, 'prediction.json')
# predictions = jd_postprecess_unified_test_model(args, model,
#                                 test_data_loader, test_example_dict, test_feature_dict,
#                                 threshold, output_test_score_file)
# with open(output_prediction_file, 'w') as f:
Example #6
                             "than this will be truncated, and sequences shorter than this will be padded.")

    args = parser.parse_args()

    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    cached_examples_file = os.path.join(args.input_dir,
                                        get_cached_filename('examples', args))
    cached_features_file = os.path.join(args.input_dir,
                                        get_cached_filename('features', args))
    cached_graphs_file = os.path.join(args.input_dir,
                                      get_cached_filename('graphs', args))

    examples = pickle.load(gzip.open(cached_examples_file, 'rb'))
    features = pickle.load(gzip.open(cached_features_file, 'rb'))
    graph_dict = pickle.load(gzip.open(cached_graphs_file, 'rb'))

    example_dict = {example.qas_id: example for example in examples}
    feature_dict = {feature.qas_id: feature for feature in features}

    print("Loading examples from: {}".format(cached_examples_file))
    print("Loading features from: {}".format(cached_features_file))
    print("Loading graphs from: {}".format(cached_graphs_file))

    pred_file = join(args.output_dir, 'pred.json')
    predict(example_dict, feature_dict, pred_file, tokenizer, use_ent_ans=False)
    metrics = hotpot_eval(pred_file, args.raw_data)
    for key, val in metrics.items():
        print("{} = {}".format(key, val))
Example #7
def jd_at_eval_model(args, encoder, model, dataloader, example_dict,
                     feature_dict, prediction_file, eval_file, dev_gold_file):
    encoder.eval()
    model.eval()

    answer_dict = {}
    answer_type_dict = {}
    answer_type_prob_dict = {}
    dataloader.refresh()
    ##++++++++++++++++++++++++++++++++++
    total_sent_sp_dict = {}
    ##++++++++++++++++++++++++++++++++++
    total_para_sp_dict = {}
    ##++++++++++++++++++++++++++++++++++
    for batch in tqdm(dataloader):
        with torch.no_grad():
            inputs = {
                'input_ids':
                batch['context_idxs'],
                'attention_mask':
                batch['context_mask'],
                'token_type_ids':
                batch['segment_idxs']
                if args.model_type in ['bert', 'xlnet'] else None
            }  # XLM doesn't use segment_ids
            outputs = encoder(**inputs)

            batch['context_encoding'] = outputs[0]
            batch['context_mask'] = batch['context_mask'].float().to(
                args.device)
            start, end, q_type, paras, sent, ent, yp1, yp2 = model(
                batch, return_yp=True)

        type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(
            example_dict, feature_dict, batch['ids'],
            yp1.data.cpu().numpy().tolist(),
            yp2.data.cpu().numpy().tolist(), type_prob)

        answer_type_dict.update(answer_type_dict_)
        answer_type_prob_dict.update(answer_type_prob_dict_)
        answer_dict.update(answer_dict_)

        ##++++++++++++++++++++++++++++++++++++++++
        predict_para_support_logits = paras
        para_mask = batch['para_mask']
        batch_size = para_mask.shape[0]
        query_para_mask = torch.cat(
            [torch.ones(batch_size, 1).to(para_mask), para_mask], dim=-1)
        para_pred_out = adaptive_threshold_prediction(
            logits=predict_para_support_logits,
            number_labels=2,
            mask=query_para_mask,
            type='topk')
        para_pred_out_np = para_pred_out.data.cpu().numpy()
        ##++++++++++++++++++++++++++++++++++++++++
        predict_sent_support_logits = sent
        sent_mask = batch['sent_mask']
        query_sent_mask = torch.cat(
            [torch.ones(batch_size, 1).to(sent_mask), sent_mask], dim=-1)
        sent_pred_out = adaptive_threshold_prediction(
            logits=predict_sent_support_logits,
            number_labels=2,
            mask=query_sent_mask,
            type='or')
        sent_pred_out_np = sent_pred_out.data.cpu().numpy()
        ##++++++++++++++++++++++++++++++++++++++++
        for i in range(sent_pred_out_np.shape[0]):
            cur_id = batch['ids'][i]
            para_pred_out_np_ith = para_pred_out_np[i]
            sent_pred_out_np_ith = sent_pred_out_np[i]
            cur_para_sp_pred = supp_para_at_prediction(
                predict_para_support_np_ith=para_pred_out_np_ith,
                example_dict=example_dict,
                batch_ids_ith=cur_id)
            cur_sent_sp_pred = supp_sent_at_prediction(
                predict_sent_support_np_ith=sent_pred_out_np_ith,
                example_dict=example_dict,
                batch_ids_ith=cur_id)
            total_para_sp_dict[cur_id] = cur_para_sp_pred
            total_sent_sp_dict[cur_id] = cur_sent_sp_pred

    prediction = {
        'answer': answer_dict,
        'sp': total_sent_sp_dict,
        'type': answer_type_dict,
        'type_prob': answer_type_prob_dict
    }
    tmp_file = os.path.join(os.path.dirname(prediction_file), 'tmp.json')
    with open(tmp_file, 'w') as f:
        json.dump(prediction, f)
    eval_metrics = hotpot_eval(tmp_file, dev_gold_file)
    doc_recall_metric = doc_recall_eval(doc_prediction=total_para_sp_dict,
                                        gold_file=dev_gold_file)
    total_inconsistent_number = supp_doc_sent_consistent_checker(
        predict_para_dict=total_para_sp_dict,
        predicted_supp_sent_dict=total_sent_sp_dict,
        gold_file=dev_gold_file)
    ##++++++++++++++++++++++++++++++++++++++++
    with open(eval_file, 'w') as f:
        json.dump(eval_metrics, f)
    return eval_metrics, doc_recall_metric, total_inconsistent_number
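
adaptive_threshold_prediction is not shown here either, but the ones-column both masks prepend suggests adaptive thresholding in the style of ATLOP: position 0 carries a learned per-example threshold score, and an item is predicted positive when its own score beats that threshold. A simplified sketch with a single logit per item (the real function also takes number_labels and a 'topk'/'or' type argument):

import torch

def adaptive_threshold_predict(logits, mask):
    # logits: (batch, 1 + n_items); position 0 is the learned threshold score.
    scores = logits.masked_fill(mask == 0, float('-inf'))
    threshold = scores[:, :1]  # per-example threshold logit
    return (scores[:, 1:] > threshold).long()

logits = torch.tensor([[0.2, 0.5, -0.1, 0.3]])
mask = torch.ones(1, 4)
print(adaptive_threshold_predict(logits, mask))  # tensor([[1, 0, 1]])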