Example #1
def validation_step(self, batch, batch_idx):
    start, end, q_type, paras, sents, ents, yp1, yp2 = self.forward(batch=batch)
    loss_list = compute_loss(self.hparams, batch, start, end, paras, sents, ents, q_type)
    loss, loss_span, loss_type, loss_sup, loss_ent, loss_para = loss_list
    dict_for_log = {'span_loss': loss_span, 'type_loss': loss_type,
                    'sent_loss': loss_sup, 'ent_loss': loss_ent,
                    'para_loss': loss_para,
                    'step': batch_idx + 1}
    #######################################################################
    type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
    answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(self.dev_example_dict,
                                                                                self.dev_feature_dict,
                                                                                batch['ids'],
                                                                                yp1.data.cpu().numpy().tolist(),
                                                                                yp2.data.cpu().numpy().tolist(),
                                                                                type_prob)
    predict_support_np = torch.sigmoid(sents[:, :, 1]).data.cpu().numpy()
    valid_dict = {'answer': answer_dict_, 'ans_type': answer_type_dict_, 'ids': batch['ids'],
                  'ans_type_pro': answer_type_prob_dict_, 'supp_np': predict_support_np}
    #######################################################################
    output = {'valid_loss': loss, 'log': dict_for_log, 'valid_dict_output': valid_dict}
    # output = {'valid_dict_output': valid_dict}
    return output
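Note: the per-step dictionaries returned above are only useful once they are merged across validation batches. Below is a minimal sketch of that merge, assuming only the `valid_dict` keys shown in the code; the function name and everything else are illustrative:

import numpy as np

def merge_validation_outputs(outputs):
    # Merge the 'valid_dict_output' payloads collected from validation_step.
    answer_dict, type_dict, type_prob_dict, supp_scores = {}, {}, {}, {}
    for out in outputs:
        v = out['valid_dict_output']
        answer_dict.update(v['answer'])            # example id -> answer text
        type_dict.update(v['ans_type'])            # example id -> answer type
        type_prob_dict.update(v['ans_type_pro'])   # example id -> type probabilities
        for i, ex_id in enumerate(v['ids']):
            supp_scores[ex_id] = np.asarray(v['supp_np'][i])  # per-sentence scores
    return answer_dict, type_dict, type_prob_dict, supp_scores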
Example #2
def jd_train_eval_model(args,
                        encoder,
                        model,
                        dataloader,
                        example_dict,
                        feature_dict,
                        prediction_file,
                        eval_file,
                        train_gold_file,
                        train_type,
                        output_score_file=None):
    encoder.eval()
    model.eval()

    answer_dict = {}
    answer_type_dict = {}
    answer_type_prob_dict = {}
    ##++++++
    prediction_res_score_dict = {}
    ##++++++
    # dataloader.refresh()
    #++++++
    cut_sentence_count = 0
    #++++++

    thresholds = np.arange(0.1, 1.0, 0.05)
    N_thresh = len(thresholds)
    total_sp_dict = [{} for _ in range(N_thresh)]

    for batch in tqdm(dataloader):
        #++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        for key, value in batch.items():
            if key not in {'ids'}:
                batch[key] = value.to(args.device)
        #++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        with torch.no_grad():
            inputs = {
                'input_ids': batch['context_idxs'],
                'attention_mask': batch['context_mask'],
                'token_type_ids': (batch['segment_idxs']
                                   if args.model_type in ['bert', 'xlnet', 'electra']
                                   else None)
            }  # XLM does not use segment_ids
            outputs = encoder(**inputs)
            ####++++++++++++++++++++++++++++++++++++++
            if args.model_type == 'electra':
                batch['context_encoding'] = outputs.last_hidden_state
            else:
                batch['context_encoding'] = outputs[0]
            ####++++++++++++++++++++++++++++++++++++++
            batch['context_mask'] = batch['context_mask'].float().to(args.device)
            start, end, q_type, paras, sent, ent, yp1, yp2, cls_emb = model(
                batch, return_yp=True, return_cls=True)

        type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(
            example_dict, feature_dict, batch['ids'],
            yp1.data.cpu().numpy().tolist(),
            yp2.data.cpu().numpy().tolist(), type_prob)
        ##++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # print('ent_prediction', ent.shape)
        # print('ent_mask', batch['ans_cand_mask'])
        # print('gold_ent', batch['is_gold_ent'])
        ent_pre_prob = torch.sigmoid(ent).data.cpu().numpy()
        ent_mask_np = batch['ent_mask'].data.cpu().numpy()
        ans_cand_mask_np = batch['ans_cand_mask'].data.cpu().numpy()
        is_gold_ent_np = batch['is_gold_ent'].data.cpu().numpy()

        _, _, answer_sent_name_dict_ = convert_answer_to_sent_names(
            example_dict, feature_dict, batch,
            yp1.data.cpu().numpy().tolist(),
            yp2.data.cpu().numpy().tolist(), type_prob, ent_pre_prob)
        ##++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

        answer_type_dict.update(answer_type_dict_)
        answer_type_prob_dict.update(answer_type_prob_dict_)
        answer_dict.update(answer_dict_)

        predict_support_np = torch.sigmoid(sent[:, :, 1]).data.cpu().numpy()
        ####################################################################
        support_sent_mask_np = batch['sent_mask'].data.cpu().numpy()
        predict_support_para_np = torch.sigmoid(paras[:, :, 1]).data.cpu().numpy()
        support_para_mask_np = batch['para_mask'].data.cpu().numpy()
        cls_emb_np = cls_emb.data.cpu().numpy()
        ####################################################################
        predict_support_logit_np = sent[:, :, 1].data.cpu().numpy()
        predict_support_para_logit_np = paras[:, :, 1].data.cpu().numpy()
        ent_pre_logit_np = ent.data.cpu().numpy()
        ####################################################################

        for i in range(predict_support_np.shape[0]):
            cur_sp_pred = [[] for _ in range(N_thresh)]
            cur_id = batch['ids'][i]
            ##+++++++++++++++++++++++++
            orig_supp_fact_id = example_dict[cur_id].sup_fact_id
            prune_supp_fact_id = feature_dict[cur_id].sup_fact_ids
            # print('origi supp fact id {}'.format(orig_supp_fact_id))
            # print('prune supp fact id {}'.format(prune_supp_fact_id))
            ##+++++++++++++++++++++++++
            topk_score_ref, cut_sent_flag, topk_pred_sent_names, diff_para_sent_names, topk_pred_paras = \
                post_process_sent_para(cur_id=cur_id, example_dict=example_dict, feature_dict=feature_dict,
                                       sent_scores_np_i=predict_support_np[i], sent_mask_np_i=support_sent_mask_np[i],
                                       para_scores_np_i=predict_support_para_np[i], para_mask_np_i=support_para_mask_np[i])
            ans_sent_name = answer_sent_name_dict_[cur_id]
            if cut_sent_flag:
                cut_sentence_count += 1
            ##+++++++++++++++++++++++++
            sent_pred_ = {
                'sp_score': predict_support_logit_np[i].tolist(),
                'sp_mask': support_sent_mask_np[i].tolist(),
                'sp_names': example_dict[cur_id].sent_names,
                'sup_fact_id': orig_supp_fact_id,
                'trim_sup_fact_id': prune_supp_fact_id
            }
            para_pred_ = {
                'para_score': predict_support_para_logit_np[i].tolist(),
                'para_mask': support_para_mask_np[i].tolist(),
                'para_names': example_dict[cur_id].para_names
            }
            ans_pred_ = {
                'ans_type': type_prob[i].tolist(),
                'ent_score': ent_pre_logit_np[i].tolist(),
                'ent_mask': ent_mask_np[i].tolist(),
                'query_entity': example_dict[cur_id].ques_entities_text,
                'ctx_entity': example_dict[cur_id].ctx_entities_text,
                'ans_ent_mask': ans_cand_mask_np[i].tolist(),
                'is_gold_ent': is_gold_ent_np[i].tolist(),
                'answer': answer_dict[cur_id]
            }
            cls_emb_ = {'cls_emb': cls_emb_np[i].tolist()}
            res_pred = {**sent_pred_, **para_pred_, **ans_pred_, **cls_emb_}
            prediction_res_score_dict[cur_id] = res_pred
            ##+++++++++++++++++++++++++

            for j in range(predict_support_np.shape[1]):
                if j >= len(example_dict[cur_id].sent_names):
                    break

                for thresh_i in range(N_thresh):
                    if predict_support_np[i, j] > thresholds[thresh_i] * topk_score_ref:
                        cur_sp_pred[thresh_i].append(
                            example_dict[cur_id].sent_names[j])

            for thresh_i in range(N_thresh):
                if cur_id not in total_sp_dict[thresh_i]:
                    total_sp_dict[thresh_i][cur_id] = []
                ##+++++
                # +++++++++++++++++++++++++++
                post_process_thresh_i_sp_pred = post_process_technique(
                    cur_sp_pred=cur_sp_pred[thresh_i],
                    topk_pred_paras=topk_pred_paras,
                    topk_pred_sent_names=topk_pred_sent_names,
                    diff_para_sent_names=diff_para_sent_names,
                    ans_sent_name=ans_sent_name)
                total_sp_dict[thresh_i][cur_id].extend(
                    post_process_thresh_i_sp_pred)

    def choose_best_threshold(ans_dict, pred_file):
        best_joint_f1 = 0
        best_metrics = None
        best_threshold = 0
        for thresh_i in range(N_thresh):
            prediction = {
                'answer': ans_dict,
                'sp': total_sp_dict[thresh_i],
                'type': answer_type_dict,
                'type_prob': answer_type_prob_dict
            }
            tmp_file = os.path.join(os.path.dirname(pred_file),
                                    'tmp_train.json')
            with open(tmp_file, 'w') as f:
                json.dump(prediction, f)
            metrics = train_eval(tmp_file, train_gold_file, train_type)
            if metrics['joint_f1'] >= best_joint_f1:
                best_joint_f1 = metrics['joint_f1']
                best_threshold = thresholds[thresh_i]
                best_metrics = metrics
                shutil.move(tmp_file, pred_file)

        return best_metrics, best_threshold

    best_metrics, best_threshold = choose_best_threshold(
        answer_dict, prediction_file)
    with open(eval_file, 'w') as f:
        json.dump(best_metrics, f)

    if output_score_file is not None:
        with open(output_score_file, 'w') as f:
            json.dump(prediction_res_score_dict, f)

    #####+++++++++++
    with open(train_gold_file) as f:
        gold = json.load(f)
    for row in gold:
        key = row['_id']
        print('supporting_facts = {}'.format(row['supporting_facts']))
        if key in prediction_res_score_dict:
            score_case = prediction_res_score_dict[key]
            sp_names = score_case['sp_names']
            sup_fact_id = score_case['sup_fact_id']
            trim_sup_fact_id = score_case['trim_sup_fact_id']
            print('orig', [sp_names[_] for _ in sup_fact_id])
            print('trim', [sp_names[_] for _ in trim_sup_fact_id])
    #####+++++++++++

    print('Number of examples with cut sentences = {}'.format(
        cut_sentence_count))
    return best_metrics, best_threshold
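The sweep above keeps a sentence when its score exceeds `thresholds[thresh_i] * topk_score_ref`, i.e. a cutoff relative to a top-k reference score rather than an absolute one. A self-contained toy illustration of that relative rule (all values fabricated):

import numpy as np

sent_scores = np.array([0.92, 0.85, 0.40, 0.10])   # sigmoid scores for one example
sent_names = [('DocA', 0), ('DocA', 1), ('DocB', 0), ('DocB', 1)]
topk_score_ref = np.sort(sent_scores)[-2]          # e.g. the second-highest score

relative_preds = {}
for threshold in np.arange(0.1, 1.0, 0.05):
    # keep sentences whose score beats threshold * reference, not threshold alone
    relative_preds[round(float(threshold), 2)] = [
        name for name, score in zip(sent_names, sent_scores)
        if score > threshold * topk_score_ref]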
Example #3
def jd_unified_post_feature_collection_model(args,
                                             model,
                                             dataloader,
                                             example_dict,
                                             feature_dict,
                                             prediction_file,
                                             eval_file,
                                             dev_gold_file=None,
                                             output_score_file=None):
    model.eval()

    answer_dict = {}
    answer_type_dict = {}
    answer_type_prob_dict = {}

    ##++++++
    prediction_res_score_dict = {}
    ##++++++

    thresholds = np.arange(0.1, 1.0, 0.05)
    N_thresh = len(thresholds)
    total_sp_dict = [{} for _ in range(N_thresh)]

    for batch in tqdm(dataloader):
        #++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        for key, value in batch.items():
            if key not in {'ids'}:
                batch[key] = value.to(args.device)
        #++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        with torch.no_grad():
            start, end, q_type, paras, sent, ent, yp1, yp2, cls_emb = model(
                batch, return_yp=True)

        type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(
            example_dict, feature_dict, batch['ids'],
            yp1.data.cpu().numpy().tolist(),
            yp2.data.cpu().numpy().tolist(), type_prob)

        answer_type_dict.update(answer_type_dict_)
        answer_type_prob_dict.update(answer_type_prob_dict_)
        answer_dict.update(answer_dict_)

        predict_support_np = torch.sigmoid(sent[:, :, 1]).data.cpu().numpy()

        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        predict_support_logit_np = sent[:, :, 1].data.cpu().numpy()
        support_sent_mask_np = batch['sent_mask'].data.cpu().numpy()
        cls_emb_np = cls_emb.data.cpu().numpy()
        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

        for i in range(predict_support_np.shape[0]):
            cur_sp_pred = [[] for _ in range(N_thresh)]
            cur_id = batch['ids'][i]
            # print(cur_id)
            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            orig_supp_fact_id = example_dict[cur_id].sup_fact_id
            prune_supp_fact_id = feature_dict[cur_id].sup_fact_ids
            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

            sent_pred_ = {
                'sp_score': predict_support_logit_np[i].tolist(),
                'sp_mask': support_sent_mask_np[i].tolist(),
                'sp_names': example_dict[cur_id].sent_names,
                'sup_fact_id': orig_supp_fact_id,
                'trim_sup_fact_id': prune_supp_fact_id
            }
            cls_emb_ = {'cls_emb': cls_emb_np[i].tolist()}
            res_score = {**sent_pred_, **cls_emb_}
            prediction_res_score_dict[cur_id] = res_score
            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

            for j in range(predict_support_np.shape[1]):
                if j >= len(example_dict[cur_id].sent_names):
                    break

                for thresh_i in range(N_thresh):
                    if predict_support_np[i, j] > thresholds[thresh_i]:
                        cur_sp_pred[thresh_i].append(
                            example_dict[cur_id].sent_names[j])

            for thresh_i in range(N_thresh):
                if cur_id not in total_sp_dict[thresh_i]:
                    total_sp_dict[thresh_i][cur_id] = []

                total_sp_dict[thresh_i][cur_id].extend(cur_sp_pred[thresh_i])

    def choose_best_threshold(ans_dict, pred_file):
        best_joint_f1 = 0
        best_metrics = None
        best_threshold = 0
        for thresh_i in range(N_thresh):
            prediction = {
                'answer': ans_dict,
                'sp': total_sp_dict[thresh_i],
                'type': answer_type_dict,
                'type_prob': answer_type_prob_dict
            }

            tmp_file = os.path.join(os.path.dirname(pred_file), 'tmp.json')
            with open(tmp_file, 'w') as f:
                json.dump(prediction, f)
            metrics = hotpot_eval(tmp_file, dev_gold_file)
            if metrics['joint_f1'] >= best_joint_f1:
                best_joint_f1 = metrics['joint_f1']
                best_threshold = thresholds[thresh_i]
                best_metrics = metrics
                shutil.move(tmp_file, pred_file)

        return best_metrics, best_threshold

    best_metrics, best_threshold = choose_best_threshold(
        answer_dict, prediction_file)
    with open(eval_file, 'w') as f:
        json.dump(best_metrics, f)
    if output_score_file is not None:
        with open(output_score_file, 'w') as f:
            json.dump(prediction_res_score_dict, f)
        print('Saving {} score records into {}'.format(
            len(prediction_res_score_dict), output_score_file))

    return best_metrics, best_threshold
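Stripped of the file I/O, choose_best_threshold above is an argmax-by-joint-F1 sweep. Here is a hedged sketch of the bare pattern; `evaluate_fn` stands in for hotpot_eval/train_eval and is an assumption of this sketch:

def pick_best_threshold(sp_dicts_per_threshold, thresholds, evaluate_fn):
    # Score each candidate supporting-fact set and keep the best by joint F1.
    best_metrics, best_threshold, best_joint_f1 = None, None, -1.0
    for sp_dict, threshold in zip(sp_dicts_per_threshold, thresholds):
        metrics = evaluate_fn(sp_dict)             # must return {'joint_f1': ...}
        if metrics['joint_f1'] >= best_joint_f1:
            best_joint_f1 = metrics['joint_f1']
            best_metrics, best_threshold = metrics, threshold
    return best_metrics, best_threshold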
Example #4
def jd_eval_model(args, encoder, model, dataloader, example_dict, feature_dict,
                  prediction_file, eval_file, dev_gold_file):
    encoder.eval()
    model.eval()

    answer_dict = {}
    answer_type_dict = {}
    answer_type_prob_dict = {}

    dataloader.refresh()

    thresholds = np.arange(0.1, 1.0, 0.025)
    N_thresh = len(thresholds)
    total_sp_dict = [{} for _ in range(N_thresh)]
    ##++++++++++++++++++++++++++++++++++
    total_para_sp_dict = {}
    ##++++++++++++++++++++++++++++++++++
    best_sp_dict = {}
    threshold_inter_count = 0
    ##++++++++++++++++++++++++++++++++++

    for batch in tqdm(dataloader):
        with torch.no_grad():
            inputs = {
                'input_ids': batch['context_idxs'],
                'attention_mask': batch['context_mask'],
                'token_type_ids': (batch['segment_idxs']
                                   if args.model_type in ['bert', 'xlnet', 'electra']
                                   else None)
            }  # XLM does not use segment_ids
            outputs = encoder(**inputs)

            batch['context_encoding'] = outputs[0]
            batch['context_mask'] = batch['context_mask'].float().to(args.device)
            start, end, q_type, paras, sent, ent, yp1, yp2 = model(
                batch, return_yp=True)

        type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(
            example_dict, feature_dict, batch['ids'],
            yp1.data.cpu().numpy().tolist(),
            yp2.data.cpu().numpy().tolist(), type_prob)
        para_mask = batch['para_mask']
        sent_mask = batch['sent_mask']
        # print(para_mask.shape, paras.shape)
        answer_type_dict.update(answer_type_dict_)
        answer_type_prob_dict.update(answer_type_prob_dict_)
        answer_dict.update(answer_dict_)
        ##++++++++++++++++++++++++++++++++++++++++
        paras = paras[:, :, 1] - (1 - para_mask) * 1e30
        predict_para_support_np = torch.sigmoid(paras).data.cpu().numpy()
        ##++++++++++++++++++++++++++++++++++++++++
        sent = sent[:, :, 1] - (1 - sent_mask) * 1e30
        predict_support_np = torch.sigmoid(sent).data.cpu().numpy()
        for i in range(predict_support_np.shape[0]):
            cur_id = batch['ids'][i]
            predict_para_support_np_ith = predict_para_support_np[i]
            predict_support_np_ith = predict_support_np[i]
            # ################################################
            cur_para_sp_pred = supp_doc_prediction(
                predict_para_support_np_ith=predict_para_support_np_ith,
                example_dict=example_dict,
                batch_ids_ith=cur_id)
            total_para_sp_dict[cur_id] = cur_para_sp_pred
            # ################################################
            cur_sp_pred = supp_sent_prediction(
                predict_support_np_ith=predict_support_np_ith,
                example_dict=example_dict,
                batch_ids_ith=cur_id,
                thresholds=thresholds)
            # ###################################

            for thresh_i in range(N_thresh):
                if cur_id not in total_sp_dict[thresh_i]:
                    total_sp_dict[thresh_i][cur_id] = []

                total_sp_dict[thresh_i][cur_id].extend(cur_sp_pred[thresh_i])

    def choose_best_threshold(ans_dict, pred_file):
        best_joint_f1 = 0
        best_metrics = None
        best_threshold = 0
        #####
        best_threshold_idx = -1
        #####
        for thresh_i in range(N_thresh):
            prediction = {
                'answer': ans_dict,
                'sp': total_sp_dict[thresh_i],
                'type': answer_type_dict,
                'type_prob': answer_type_prob_dict
            }
            tmp_file = os.path.join(os.path.dirname(pred_file), 'tmp.json')
            with open(tmp_file, 'w') as f:
                json.dump(prediction, f)
            metrics = hotpot_eval(tmp_file, dev_gold_file)
            if metrics['joint_f1'] >= best_joint_f1:
                best_joint_f1 = metrics['joint_f1']
                best_threshold = thresholds[thresh_i]
                #####
                best_threshold_idx = thresh_i
                #####
                best_metrics = metrics
                shutil.move(tmp_file, pred_file)

        return best_metrics, best_threshold, best_threshold_idx

    best_metrics, best_threshold, best_threshold_idx = choose_best_threshold(
        answer_dict, prediction_file)
    ##############++++++++++++
    doc_recall_metric = doc_recall_eval(doc_prediction=total_para_sp_dict,
                                        gold_file=dev_gold_file)
    ##############++++++++++++
    with open(eval_file, 'w') as f:
        json.dump(best_metrics, f)
    # -------------------------------------
    best_sp_dict = total_sp_dict[best_threshold_idx]  # supporting facts at the best threshold
    best_prediction = {
        'answer': answer_dict,
        'sp': best_sp_dict,
        'type': answer_type_dict,
        'type_prob': answer_type_prob_dict
    }
    print('Number of inter threshold = {}'.format(threshold_inter_count))
    best_tmp_file = os.path.join(os.path.dirname(prediction_file),
                                 'best_tmp.json')
    with open(best_tmp_file, 'w') as f:
        json.dump(best_prediction, f)
    best_th_metrics = hotpot_eval(best_tmp_file, dev_gold_file)
    for key, val in best_th_metrics.items():
        print("{} = {}".format(key, val))
    # -------------------------------------
    return best_metrics, best_threshold, doc_recall_metric
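The `- (1 - mask) * 1e30` step above drives logits at padded positions toward negative infinity before the sigmoid, so padded sentences and paragraphs can never cross any threshold. A toy demonstration:

import torch

logits = torch.tensor([2.0, 0.5, -1.0])
mask = torch.tensor([1.0, 1.0, 0.0])               # last position is padding
masked_probs = torch.sigmoid(logits - (1 - mask) * 1e30)
# -> tensor([0.8808, 0.6225, 0.0000]); the padded slot is exactly zero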
Example #5
def jd_unified_eval_model(args,
                          model,
                          dataloader,
                          example_dict,
                          feature_dict,
                          prediction_file,
                          eval_file,
                          dev_gold_file=None):
    model.eval()

    answer_dict = {}
    answer_type_dict = {}
    answer_type_prob_dict = {}

    # dataloader.refresh()

    thresholds = np.arange(0.1, 1.0, 0.05)
    N_thresh = len(thresholds)
    total_sp_dict = [{} for _ in range(N_thresh)]

    for batch in tqdm(dataloader):
        #++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        for key, value in batch.items():
            if key not in {'ids'}:
                batch[key] = value.to(args.device)
        #++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        with torch.no_grad():
            start, end, q_type, paras, sent, ent, yp1, yp2 = model(
                batch, return_yp=True)

        type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(
            example_dict, feature_dict, batch['ids'],
            yp1.data.cpu().numpy().tolist(),
            yp2.data.cpu().numpy().tolist(), type_prob)

        answer_type_dict.update(answer_type_dict_)
        answer_type_prob_dict.update(answer_type_prob_dict_)
        answer_dict.update(answer_dict_)

        predict_support_np = torch.sigmoid(sent[:, :, 1]).data.cpu().numpy()

        for i in range(predict_support_np.shape[0]):
            cur_sp_pred = [[] for _ in range(N_thresh)]
            cur_id = batch['ids'][i]

            for j in range(predict_support_np.shape[1]):
                if j >= len(example_dict[cur_id].sent_names):
                    break

                for thresh_i in range(N_thresh):
                    if predict_support_np[i, j] > thresholds[thresh_i]:
                        cur_sp_pred[thresh_i].append(
                            example_dict[cur_id].sent_names[j])

            for thresh_i in range(N_thresh):
                if cur_id not in total_sp_dict[thresh_i]:
                    total_sp_dict[thresh_i][cur_id] = []

                total_sp_dict[thresh_i][cur_id].extend(cur_sp_pred[thresh_i])

    def choose_best_threshold(ans_dict, pred_file):
        best_joint_f1 = 0
        best_metrics = None
        best_threshold = 0
        for thresh_i in range(N_thresh):
            prediction = {
                'answer': ans_dict,
                'sp': total_sp_dict[thresh_i],
                'type': answer_type_dict,
                'type_prob': answer_type_prob_dict
            }
            tmp_file = os.path.join(os.path.dirname(pred_file), 'tmp.json')
            with open(tmp_file, 'w') as f:
                json.dump(prediction, f)
            metrics = hotpot_eval(tmp_file, dev_gold_file)
            if metrics['joint_f1'] >= best_joint_f1:
                best_joint_f1 = metrics['joint_f1']
                best_threshold = thresholds[thresh_i]
                best_metrics = metrics
                shutil.move(tmp_file, pred_file)

        return best_metrics, best_threshold

    best_metrics, best_threshold = choose_best_threshold(
        answer_dict, prediction_file)
    with open(eval_file, 'w') as f:
        json.dump(best_metrics, f)

    return best_metrics, best_threshold
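The nested loops above maintain one {example id: [sentence names]} mapping per candidate threshold. The same accumulation, reduced to a self-contained toy version:

import numpy as np

thresholds = np.arange(0.1, 1.0, 0.05)
scores = {'ex1': np.array([0.9, 0.3]), 'ex2': np.array([0.6])}   # toy sigmoid scores
sent_names = {'ex1': [('A', 0), ('A', 1)], 'ex2': [('B', 0)]}

total_sp = [{} for _ in thresholds]                # one prediction dict per threshold
for ex_id, ex_scores in scores.items():
    for t_i, threshold in enumerate(thresholds):
        total_sp[t_i][ex_id] = [sent_names[ex_id][j]
                                for j in range(len(ex_scores))
                                if ex_scores[j] > threshold]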
Example #6
def jd_at_eval_model(args, encoder, model, dataloader, example_dict,
                     feature_dict, prediction_file, eval_file, dev_gold_file):
    encoder.eval()
    model.eval()

    answer_dict = {}
    answer_type_dict = {}
    answer_type_prob_dict = {}
    dataloader.refresh()
    ##++++++++++++++++++++++++++++++++++
    total_sent_sp_dict = {}
    ##++++++++++++++++++++++++++++++++++
    total_para_sp_dict = {}
    ##++++++++++++++++++++++++++++++++++
    for batch in tqdm(dataloader):
        with torch.no_grad():
            inputs = {
                'input_ids': batch['context_idxs'],
                'attention_mask': batch['context_mask'],
                'token_type_ids': (batch['segment_idxs']
                                   if args.model_type in ['bert', 'xlnet']
                                   else None)
            }  # XLM does not use segment_ids
            outputs = encoder(**inputs)

            batch['context_encoding'] = outputs[0]
            batch['context_mask'] = batch['context_mask'].float().to(args.device)
            start, end, q_type, paras, sent, ent, yp1, yp2 = model(
                batch, return_yp=True)

        type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(
            example_dict, feature_dict, batch['ids'],
            yp1.data.cpu().numpy().tolist(),
            yp2.data.cpu().numpy().tolist(), type_prob)

        answer_type_dict.update(answer_type_dict_)
        answer_type_prob_dict.update(answer_type_prob_dict_)
        answer_dict.update(answer_dict_)

        ##++++++++++++++++++++++++++++++++++++++++
        predict_para_support_logits = paras
        para_mask = batch['para_mask']
        batch_size = para_mask.shape[0]
        query_para_mask = torch.cat(
            [torch.ones(batch_size, 1).to(para_mask), para_mask], dim=-1)
        para_pred_out = adaptive_threshold_prediction(
            logits=predict_para_support_logits,
            number_labels=2,
            mask=query_para_mask,
            type='topk')
        para_pred_out_np = para_pred_out.data.cpu().numpy()
        ##++++++++++++++++++++++++++++++++++++++++
        predict_sent_support_logits = sent
        sent_mask = batch['sent_mask']
        query_sent_mask = torch.cat(
            [torch.ones(batch_size, 1).to(sent_mask), sent_mask], dim=-1)
        sent_pred_out = adaptive_threshold_prediction(
            logits=predict_sent_support_logits,
            number_labels=2,
            mask=query_sent_mask,
            type='or')
        sent_pred_out_np = sent_pred_out.data.cpu().numpy()
        ##++++++++++++++++++++++++++++++++++++++++
        for i in range(sent_pred_out_np.shape[0]):
            cur_id = batch['ids'][i]
            para_pred_out_np_ith = para_pred_out_np[i]
            sent_pred_out_np_ith = sent_pred_out_np[i]
            cur_para_sp_pred = supp_para_at_prediction(
                predict_para_support_np_ith=para_pred_out_np_ith,
                example_dict=example_dict,
                batch_ids_ith=cur_id)
            cur_sent_sp_pred = supp_sent_at_prediction(
                predict_sent_support_np_ith=sent_pred_out_np_ith,
                example_dict=example_dict,
                batch_ids_ith=cur_id)
            total_para_sp_dict[cur_id] = cur_para_sp_pred
            total_sent_sp_dict[cur_id] = cur_sent_sp_pred

    prediction = {
        'answer': answer_dict,
        'sp': total_sent_sp_dict,
        'type': answer_type_dict,
        'type_prob': answer_type_prob_dict
    }
    tmp_file = os.path.join(os.path.dirname(prediction_file), 'tmp.json')
    with open(tmp_file, 'w') as f:
        json.dump(prediction, f)
    eval_metrics = hotpot_eval(tmp_file, dev_gold_file)
    doc_recall_metric = doc_recall_eval(doc_prediction=total_para_sp_dict,
                                        gold_file=dev_gold_file)
    total_inconsistent_number = supp_doc_sent_consistent_checker(
        predict_para_dict=total_para_sp_dict,
        predicted_supp_sent_dict=total_sent_sp_dict,
        gold_file=dev_gold_file)
    ##++++++++++++++++++++++++++++++++++++++++
    with open(eval_file, 'w') as f:
        json.dump(eval_metrics, f)
    return eval_metrics, doc_recall_metric, total_inconsistent_number
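Before the adaptive-threshold calls, the code prepends an always-on column to the paragraph and sentence masks, reserving slot 0 for the query. A shape-level sketch with a toy mask (the shapes are assumptions):

import torch

para_mask = torch.tensor([[1., 1., 0.],
                          [1., 0., 0.]])           # (batch, num_paras), 0 = padding
batch_size = para_mask.shape[0]
query_para_mask = torch.cat(
    [torch.ones(batch_size, 1).to(para_mask), para_mask], dim=-1)
# -> shape (batch, num_paras + 1); column 0 is the always-valid query slot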
Example #7
def lightnHGN_test_procedure(model, test_data_loader, dev_feature_dict, dev_example_dict, args, device):
    model.freeze()
    outputs = []
    start_time = time()
    total_steps = len(test_data_loader)
    with torch.no_grad():
        for batch_idx, batch in tqdm(enumerate(test_data_loader)):
            batch = batch2device(batch=batch, device=device)
            start, end, q_type, paras, sents, ents, yp1, yp2 = model.forward(batch=batch)
            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
            answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(dev_example_dict,
                                                                                        dev_feature_dict,
                                                                                        batch['ids'],
                                                                                        yp1.data.cpu().numpy().tolist(),
                                                                                        yp2.data.cpu().numpy().tolist(),
                                                                                        type_prob)
            predict_support_np = torch.sigmoid(sents[:, :, 1]).data.cpu().numpy()
            valid_dict = {'answer': answer_dict_, 'ans_type': answer_type_dict_, 'ids': batch['ids'],
                          'ans_type_pro': answer_type_prob_dict_, 'supp_np': predict_support_np}
            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            if (batch_idx + 1) % args.eval_batch_size == 0:
                print('Evaluating the model... {}/{} in {:.4f} seconds'.format(batch_idx + 1, total_steps, time()-start_time))
            outputs.append(valid_dict)
            del batch
    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    answer_dict = {}
    answer_type_dict = {}
    answer_type_prob_dict = {}

    thresholds = np.arange(0.1, 1.0, 0.025)
    N_thresh = len(thresholds)
    total_sp_dict = [{} for _ in range(N_thresh)]
    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    for batch_idx, valid_dict in tqdm(enumerate(outputs)):
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = valid_dict['answer'], valid_dict['ans_type'], \
                                                                  valid_dict['ans_type_pro']
        answer_type_dict.update(answer_type_dict_)
        answer_type_prob_dict.update(answer_type_prob_dict_)
        answer_dict.update(answer_dict_)

        predict_support_np = valid_dict['supp_np']
        batch_ids = valid_dict['ids']
        for i in range(predict_support_np.shape[0]):
            cur_sp_pred = [[] for _ in range(N_thresh)]
            cur_id = batch_ids[i]

            for j in range(predict_support_np.shape[1]):
                if j >= len(dev_example_dict[cur_id].sent_names):
                    break
                for thresh_i in range(N_thresh):
                    if predict_support_np[i, j] > thresholds[thresh_i]:
                        cur_sp_pred[thresh_i].append(dev_example_dict[cur_id].sent_names[j])

            for thresh_i in range(N_thresh):
                if cur_id not in total_sp_dict[thresh_i]:
                    total_sp_dict[thresh_i][cur_id] = []
                total_sp_dict[thresh_i][cur_id].extend(cur_sp_pred[thresh_i])

    def choose_best_threshold(ans_dict, pred_file):
        best_joint_f1 = 0
        best_metrics = None
        best_threshold = 0
        #################
        metric_dict = {}
        #################
        for thresh_i in range(N_thresh):
            prediction = {'answer': ans_dict,
                          'sp': total_sp_dict[thresh_i],
                          'type': answer_type_dict,
                          'type_prob': answer_type_prob_dict}
            tmp_file = os.path.join(os.path.dirname(pred_file), 'tmp.json')
            with open(tmp_file, 'w') as f:
                json.dump(prediction, f)
            metrics = hotpot_eval(tmp_file, args.dev_gold_file)
            if metrics['joint_f1'] >= best_joint_f1:
                best_joint_f1 = metrics['joint_f1']
                best_threshold = thresholds[thresh_i]
                best_metrics = metrics
                shutil.move(tmp_file, pred_file)
            #######
            metric_dict[thresh_i] = (
                metrics['em'], metrics['f1'],
                metrics['sp_em'], metrics['sp_f1'],
                metrics['joint_em'], metrics['joint_f1'])
            #######
        return best_metrics, best_threshold, metric_dict

    output_pred_file = os.path.join(args.exp_name, 'pred.json')
    output_eval_file = os.path.join(args.exp_name, 'eval.txt')
    ####+++++
    best_metrics, best_threshold, metric_dict = choose_best_threshold(answer_dict, output_pred_file)
    ####++++++
    logging.info('Leader board evaluation completed with threshold = {:.4f}'.format(best_threshold))
    log_metrics(mode='Evaluation', metrics=best_metrics)
    logging.info('*' * 75)
    ####++++++
    for key, value in metric_dict.items():
        str_value = ['{:.4f}'.format(_) for _ in value]
        logging.info('threshold {:.4f}: \t metrics: {}'.format(thresholds[key], str_value))
    ####++++++
    with open(output_eval_file, 'w') as f:
        json.dump(best_metrics, f)
    #############################################################################
    return best_metrics, best_threshold
Example #8
def jd_postprecess_unified_test_model(args, model, dataloader, example_dict, feature_dict,
                                      threshold=0.35, output_score_file=None):
    model.eval()
    answer_dict = {}
    answer_type_dict = {}
    answer_type_prob_dict = {}
    sp_dict = {}
    ##++++++
    prediction_res_score_dict = {}
    ##++++++
    for batch in tqdm(dataloader):
        #++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        for key, value in batch.items():
            if key not in {'ids'}:
                batch[key] = value.to(args.device)
        #++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        with torch.no_grad():
            start, end, q_type, paras, sent, ent, yp1, yp2, cls_emb = model(batch, return_yp=True)
        type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(example_dict, feature_dict, batch['ids'],
                                                                                    yp1.data.cpu().numpy().tolist(),
                                                                                    yp2.data.cpu().numpy().tolist(),
                                                                                    type_prob)
        ##++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        _, _, answer_sent_name_dict_ = convert_answer_to_sent_names(example_dict, feature_dict, batch,
                                                                    yp1.data.cpu().numpy().tolist(),
                                                                    yp2.data.cpu().numpy().tolist(),
                                                                    type_prob)
        ##++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        answer_type_dict.update(answer_type_dict_)
        answer_type_prob_dict.update(answer_type_prob_dict_)
        answer_dict.update(answer_dict_)
        predict_support_np = torch.sigmoid(sent[:, :, 1]).data.cpu().numpy()
        predict_support_para_np = torch.sigmoid(paras[:, :, 1]).data.cpu().numpy()
        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        predict_support_logit_np = sent[:, :, 1].data.cpu().numpy()
        support_sent_mask_np = batch['sent_mask'].data.cpu().numpy()
        predict_support_para_logit_np = paras[:, :, 1].data.cpu().numpy()
        support_para_mask_np = batch['para_mask'].data.cpu().numpy()
        cls_emb_np = cls_emb.data.cpu().numpy()
        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

        for i in range(predict_support_np.shape[0]):
            cur_id = batch['ids'][i]
            cur_sp_pred = []
            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            sent_pred_ = {'sp_score': predict_support_logit_np[i].tolist(), 'sp_mask': support_sent_mask_np[i].tolist(),
                          'sp_names': example_dict[cur_id].sent_names}
            para_pred_ = {'sp_para_score': predict_support_para_logit_np[i].tolist(), 'sp_para_mask': support_para_mask_np[i].tolist(),
                          'sp_para_names': example_dict[cur_id].para_names}

            cls_emb_ = {'cls_emb': cls_emb_np[i].tolist()}
            res_score = {**sent_pred_, **cls_emb_, **para_pred_}
            prediction_res_score_dict[cur_id] = res_score
            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            topk_score_ref, cut_sent_flag, topk_pred_sent_names, diff_para_sent_names, topk_pred_paras = \
                post_process_sent_para(cur_id=cur_id, example_dict=example_dict, feature_dict=feature_dict,
                                       sent_scores_np_i=predict_support_np[i], sent_mask_np_i=support_sent_mask_np[i],
                                       para_scores_np_i=predict_support_para_np[i], para_mask_np_i=support_para_mask_np[i])
            ans_sent_name = answer_sent_name_dict_[cur_id]

            for j in range(predict_support_np.shape[1]):
                if j >= len(example_dict[cur_id].sent_names):
                    break
                if predict_support_np[i, j] > topk_score_ref * threshold:
                    cur_sp_pred.append(example_dict[cur_id].sent_names[j])

            post_process_sp_pred = post_process_technique(
                cur_sp_pred=cur_sp_pred,
                topk_pred_paras=topk_pred_paras,
                topk_pred_sent_names=topk_pred_sent_names,
                diff_para_sent_names=diff_para_sent_names,
                ans_sent_name=ans_sent_name)
            sp_dict[cur_id] = post_process_sp_pred

    prediction = {'answer': answer_dict,
                  'sp': sp_dict,
                  'type': answer_type_dict,
                  'type_prob': answer_type_prob_dict}
    if output_score_file is not None:
        with open(output_score_file, 'w') as f:
            json.dump(prediction_res_score_dict, f)
        print('Saving {} score records into {}'.format(len(prediction_res_score_dict), output_score_file))
    return prediction
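The returned `prediction` follows the usual HotpotQA submission layout. A minimal sketch of that structure with fabricated content (only the four top-level keys come from the code above; every value and the file name are illustrative):

import json

prediction = {
    'answer': {'ex1': 'yes'},                      # example id -> answer string
    'sp': {'ex1': [['DocA', 0], ['DocB', 2]]},     # example id -> [title, sent idx] pairs
    'type': {'ex1': 2},                            # example id -> predicted answer type
    'type_prob': {'ex1': [0.1, 0.2, 0.7]},         # example id -> type probabilities
}
with open('pred.json', 'w') as f:
    json.dump(prediction, f)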
Example #9
def jd_postprocess_unified_eval_model(args, model, dataloader, example_dict, feature_dict,
                                      prediction_file, eval_file, dev_gold_file=None):
    model.eval()
    answer_dict = {}
    answer_type_dict = {}
    answer_type_prob_dict = {}

    thresholds = np.arange(0.1, 1.0, 0.025)
    N_thresh = len(thresholds)
    total_sp_dict = [{} for _ in range(N_thresh)]

    for batch in tqdm(dataloader):
        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        for key, value in batch.items():
            if key not in {'ids'}:
                batch[key] = value.to(args.device)
        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        with torch.no_grad():
            start, end, q_type, paras, sent, ent, yp1, yp2, cls_emb = model(batch, return_yp=True)
        type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(example_dict, feature_dict, batch['ids'],
                                                                                    yp1.data.cpu().numpy().tolist(),
                                                                                    yp2.data.cpu().numpy().tolist(),
                                                                                    type_prob)
        ##++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        _, _, answer_sent_name_dict_ = convert_answer_to_sent_names(example_dict, feature_dict, batch,
                                                                    yp1.data.cpu().numpy().tolist(),
                                                                    yp2.data.cpu().numpy().tolist(),
                                                                    type_prob)
        ##++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        answer_type_dict.update(answer_type_dict_)
        answer_type_prob_dict.update(answer_type_prob_dict_)
        answer_dict.update(answer_dict_)
        predict_support_np = torch.sigmoid(sent[:, :, 1]).data.cpu().numpy()
        ####################################################################
        support_sent_mask_np = batch['sent_mask'].data.cpu().numpy()
        predict_support_para_np = torch.sigmoid(paras[:, :, 1]).data.cpu().numpy()
        support_para_mask_np = batch['para_mask'].data.cpu().numpy()
        ####################################################################
        for i in range(predict_support_np.shape[0]):
            cur_sp_pred = [[] for _ in range(N_thresh)]
            cur_id = batch['ids'][i]
            ##+++++++++++++++++++++++++
            topk_score_ref, cut_sent_flag, topk_pred_sent_names, diff_para_sent_names, topk_pred_paras = \
                post_process_sent_para(cur_id=cur_id, example_dict=example_dict, feature_dict=feature_dict,
                                       sent_scores_np_i=predict_support_np[i], sent_mask_np_i=support_sent_mask_np[i],
                                       para_scores_np_i=predict_support_para_np[i], para_mask_np_i=support_para_mask_np[i])
            ans_sent_name = answer_sent_name_dict_[cur_id]
            ##+++++++++++++++++++++++++
            for j in range(predict_support_np.shape[1]):
                if j >= len(example_dict[cur_id].sent_names):
                    break
                for thresh_i in range(N_thresh):
                    if predict_support_np[i, j] > thresholds[thresh_i] * topk_score_ref:
                        cur_sp_pred[thresh_i].append(example_dict[cur_id].sent_names[j])

            for thresh_i in range(N_thresh):
                if cur_id not in total_sp_dict[thresh_i]:
                    total_sp_dict[thresh_i][cur_id] = []
                ##+++++
                post_process_thresh_i_sp_pred = post_process_technique(
                    cur_sp_pred=cur_sp_pred[thresh_i],
                    topk_pred_paras=topk_pred_paras,
                    topk_pred_sent_names=topk_pred_sent_names,
                    diff_para_sent_names=diff_para_sent_names,
                    ans_sent_name=ans_sent_name)
                total_sp_dict[thresh_i][cur_id].extend(post_process_thresh_i_sp_pred)

    def choose_best_threshold(ans_dict, pred_file):
        best_joint_f1 = 0
        best_metrics = None
        best_threshold = 0
        for thresh_i in range(N_thresh):
            prediction = {'answer': ans_dict,
                          'sp': total_sp_dict[thresh_i],
                          'type': answer_type_dict,
                          'type_prob': answer_type_prob_dict}
            tmp_file = os.path.join(os.path.dirname(pred_file), 'tmp.json')
            with open(tmp_file, 'w') as f:
                json.dump(prediction, f)
            metrics = hotpot_eval(tmp_file, dev_gold_file)
            if metrics['joint_f1'] >= best_joint_f1:
                best_joint_f1 = metrics['joint_f1']
                best_threshold = thresholds[thresh_i]
                best_metrics = metrics
                shutil.move(tmp_file, pred_file)
        return best_metrics, best_threshold

    best_metrics, best_threshold = choose_best_threshold(answer_dict, prediction_file)
    with open(eval_file, 'w') as f:
        json.dump(best_metrics, f)
    return best_metrics, best_threshold