# Imports required by this excerpt. Project-local helpers (compute_loss, convert_to_tokens,
# convert_answer_to_sent_names, post_process_sent_para, post_process_technique, hotpot_eval,
# train_eval, doc_recall_eval, supp_doc_prediction, supp_sent_prediction,
# supp_para_at_prediction, supp_sent_at_prediction, supp_doc_sent_consistent_checker,
# adaptive_threshold_prediction, batch2device, log_metrics) are assumed to be imported
# from elsewhere in the repository.
import json
import logging
import os
import shutil
from time import time

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm


# validation_step is a method of the Lightning module; the enclosing class is omitted
# in this excerpt.
def validation_step(self, batch, batch_idx):
    start, end, q_type, paras, sents, ents, yp1, yp2 = self.forward(batch=batch)
    loss_list = compute_loss(self.hparams, batch, start, end, paras, sents, ents, q_type)
    loss, loss_span, loss_type, loss_sup, loss_ent, loss_para = loss_list
    dict_for_log = {'span_loss': loss_span, 'type_loss': loss_type, 'sent_loss': loss_sup,
                    'ent_loss': loss_ent, 'para_loss': loss_para, 'step': batch_idx + 1}
    #######################################################################
    type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
    answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(
        self.dev_example_dict, self.dev_feature_dict, batch['ids'],
        yp1.data.cpu().numpy().tolist(), yp2.data.cpu().numpy().tolist(), type_prob)
    predict_support_np = torch.sigmoid(sents[:, :, 1]).data.cpu().numpy()
    valid_dict = {'answer': answer_dict_, 'ans_type': answer_type_dict_, 'ids': batch['ids'],
                  'ans_type_pro': answer_type_prob_dict_, 'supp_np': predict_support_np}
    #######################################################################
    output = {'valid_loss': loss, 'log': dict_for_log, 'valid_dict_output': valid_dict}
    return output
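
# The 'supp_np' scores collected in validation_step are later swept over a grid of
# thresholds to select supporting sentences (see the evaluation loops below). A minimal,
# self-contained sketch of that selection step; the score row and sentence names are
# illustrative stand-ins, not values from this repo.
def select_supporting_sents(score_row, sent_names, thresholds):
    """Return, for each threshold, the sentence names whose score exceeds it."""
    picks = [[] for _ in thresholds]
    for j, name in enumerate(sent_names):
        for t_i, t in enumerate(thresholds):
            if score_row[j] > t:
                picks[t_i].append(name)
    return picks

# Example:
# select_supporting_sents(np.array([0.8, 0.3]), [('doc1', 0), ('doc1', 1)],
#                         np.arange(0.1, 1.0, 0.05))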
def jd_train_eval_model(args, encoder, model, dataloader, example_dict, feature_dict,
                        prediction_file, eval_file, train_gold_file, train_type,
                        output_score_file=None):
    """Evaluate on the training split, sweeping relative thresholds for supporting facts."""
    encoder.eval()
    model.eval()
    answer_dict = {}
    answer_type_dict = {}
    answer_type_prob_dict = {}
    prediction_res_score_dict = {}
    cut_sentence_count = 0
    thresholds = np.arange(0.1, 1.0, 0.05)
    N_thresh = len(thresholds)
    total_sp_dict = [{} for _ in range(N_thresh)]

    for batch in tqdm(dataloader):
        # move every tensor in the batch to the target device ('ids' stays on CPU)
        for key, value in batch.items():
            if key not in {'ids'}:
                batch[key] = value.to(args.device)
        with torch.no_grad():
            inputs = {'input_ids': batch['context_idxs'],
                      'attention_mask': batch['context_mask'],
                      'token_type_ids': batch['segment_idxs']
                      if args.model_type in ['bert', 'xlnet', 'electra'] else None}  # XLM doesn't use segment_ids
            outputs = encoder(**inputs)
            if args.model_type == 'electra':
                batch['context_encoding'] = outputs.last_hidden_state
            else:
                batch['context_encoding'] = outputs[0]
            batch['context_mask'] = batch['context_mask'].float().to(args.device)
            start, end, q_type, paras, sent, ent, yp1, yp2, cls_emb = model(
                batch, return_yp=True, return_cls=True)

        type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(
            example_dict, feature_dict, batch['ids'],
            yp1.data.cpu().numpy().tolist(), yp2.data.cpu().numpy().tolist(), type_prob)
        ent_pre_prob = torch.sigmoid(ent).data.cpu().numpy()
        ent_mask_np = batch['ent_mask'].data.cpu().numpy()
        ans_cand_mask_np = batch['ans_cand_mask'].data.cpu().numpy()
        is_gold_ent_np = batch['is_gold_ent'].data.cpu().numpy()
        _, _, answer_sent_name_dict_ = convert_answer_to_sent_names(
            example_dict, feature_dict, batch,
            yp1.data.cpu().numpy().tolist(), yp2.data.cpu().numpy().tolist(),
            type_prob, ent_pre_prob)
        answer_type_dict.update(answer_type_dict_)
        answer_type_prob_dict.update(answer_type_prob_dict_)
        answer_dict.update(answer_dict_)
        predict_support_np = torch.sigmoid(sent[:, :, 1]).data.cpu().numpy()
        ####################################################################
        support_sent_mask_np = batch['sent_mask'].data.cpu().numpy()
        predict_support_para_np = torch.sigmoid(paras[:, :, 1]).data.cpu().numpy()
        support_para_mask_np = batch['para_mask'].data.cpu().numpy()
        cls_emb_np = cls_emb.data.cpu().numpy()
        ####################################################################
        predict_support_logit_np = sent[:, :, 1].data.cpu().numpy()
        predict_support_para_logit_np = paras[:, :, 1].data.cpu().numpy()
        ent_pre_logit_np = ent.data.cpu().numpy()
        ####################################################################
        for i in range(predict_support_np.shape[0]):
            cur_sp_pred = [[] for _ in range(N_thresh)]
            cur_id = batch['ids'][i]
            orig_supp_fact_id = example_dict[cur_id].sup_fact_id
            prune_supp_fact_id = feature_dict[cur_id].sup_fact_ids
            topk_score_ref, cut_sent_flag, topk_pred_sent_names, diff_para_sent_names, topk_pred_paras = \
                post_process_sent_para(cur_id=cur_id, example_dict=example_dict,
                                       feature_dict=feature_dict,
                                       sent_scores_np_i=predict_support_np[i],
                                       sent_mask_np_i=support_sent_mask_np[i],
                                       para_scores_np_i=predict_support_para_np[i],
                                       para_mask_np_i=support_para_mask_np[i])
            ans_sent_name = answer_sent_name_dict_[cur_id]
            if cut_sent_flag:
                cut_sentence_count += 1
            sent_pred_ = {'sp_score': predict_support_logit_np[i].tolist(),
                          'sp_mask': support_sent_mask_np[i].tolist(),
                          'sp_names': example_dict[cur_id].sent_names,
                          'sup_fact_id': orig_supp_fact_id,
                          'trim_sup_fact_id': prune_supp_fact_id}
            para_pred_ = {'para_score': predict_support_para_logit_np[i].tolist(),
                          'para_mask': support_para_mask_np[i].tolist(),
                          'para_names': example_dict[cur_id].para_names}
            ans_pred_ = {'ans_type': type_prob[i].tolist(),
                         'ent_score': ent_pre_logit_np[i].tolist(),
                         'ent_mask': ent_mask_np[i].tolist(),
                         'query_entity': example_dict[cur_id].ques_entities_text,
                         'ctx_entity': example_dict[cur_id].ctx_entities_text,
                         'ans_ent_mask': ans_cand_mask_np[i].tolist(),
                         'is_gold_ent': is_gold_ent_np[i].tolist(),
                         'answer': answer_dict[cur_id]}
            cls_emb_ = {'cls_emb': cls_emb_np[i].tolist()}
            res_pred = {**sent_pred_, **para_pred_, **ans_pred_, **cls_emb_}
            prediction_res_score_dict[cur_id] = res_pred

            for j in range(predict_support_np.shape[1]):
                if j >= len(example_dict[cur_id].sent_names):
                    break
                for thresh_i in range(N_thresh):
                    # relative threshold: compare against a fraction of the top-k reference score
                    if predict_support_np[i, j] > thresholds[thresh_i] * topk_score_ref:
                        cur_sp_pred[thresh_i].append(example_dict[cur_id].sent_names[j])

            for thresh_i in range(N_thresh):
                if cur_id not in total_sp_dict[thresh_i]:
                    total_sp_dict[thresh_i][cur_id] = []
                post_process_thresh_i_sp_pred = post_process_technique(
                    cur_sp_pred=cur_sp_pred[thresh_i],
                    topk_pred_paras=topk_pred_paras,
                    topk_pred_sent_names=topk_pred_sent_names,
                    diff_para_sent_names=diff_para_sent_names,
                    ans_sent_name=ans_sent_name)
                total_sp_dict[thresh_i][cur_id].extend(post_process_thresh_i_sp_pred)

    def choose_best_threshold(ans_dict, pred_file):
        best_joint_f1 = 0
        best_metrics = None
        best_threshold = 0
        for thresh_i in range(N_thresh):
            prediction = {'answer': ans_dict, 'sp': total_sp_dict[thresh_i],
                          'type': answer_type_dict, 'type_prob': answer_type_prob_dict}
            tmp_file = os.path.join(os.path.dirname(pred_file), 'tmp_train.json')
            with open(tmp_file, 'w') as f:
                json.dump(prediction, f)
            metrics = train_eval(tmp_file, train_gold_file, train_type)
            if metrics['joint_f1'] >= best_joint_f1:
                best_joint_f1 = metrics['joint_f1']
                best_threshold = thresholds[thresh_i]
                best_metrics = metrics
                shutil.move(tmp_file, pred_file)
        return best_metrics, best_threshold

    best_metrics, best_threshold = choose_best_threshold(answer_dict, prediction_file)
    json.dump(best_metrics, open(eval_file, 'w'))
    if output_score_file is not None:
        with open(output_score_file, 'w') as f:
            json.dump(prediction_res_score_dict, f)
    # debug block: compare gold supporting facts with original vs. pruned predictions
    with open(train_gold_file) as f:
        gold = json.load(f)
    for row in gold:
        key = row['_id']
        print('supp = {}'.format(row['supporting_facts']))
        if key in prediction_res_score_dict:
            score_case = prediction_res_score_dict[key]
            sp_names = score_case['sp_names']
            sup_fact_id = score_case['sup_fact_id']
            trim_sup_fact_id = score_case['trim_sup_fact_id']
            print('orig', [sp_names[_] for _ in sup_fact_id])
            print('trim', [sp_names[_] for _ in trim_sup_fact_id])
    print('Number of examples with cut sentences = {}'.format(cut_sentence_count))
    return best_metrics, best_threshold
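
# Note on the sweep in jd_train_eval_model: each grid point is scaled by topk_score_ref,
# a per-example reference score returned by post_process_sent_para, so the cut-off is
# relative rather than absolute. With made-up numbers: at grid point 0.5, an example with
# topk_score_ref = 0.9 keeps sentences scoring above 0.45, while one with
# topk_score_ref = 0.4 keeps everything above 0.20.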
def jd_unified_post_feature_collection_model(args, model, dataloader, example_dict,
                                             feature_dict, prediction_file, eval_file,
                                             dev_gold_file=None, output_score_file=None):
    model.eval()
    answer_dict = {}
    answer_type_dict = {}
    answer_type_prob_dict = {}
    prediction_res_score_dict = {}
    thresholds = np.arange(0.1, 1.0, 0.05)
    N_thresh = len(thresholds)
    total_sp_dict = [{} for _ in range(N_thresh)]

    for batch in tqdm(dataloader):
        for key, value in batch.items():
            if key not in {'ids'}:
                batch[key] = value.to(args.device)
        with torch.no_grad():
            start, end, q_type, paras, sent, ent, yp1, yp2, cls_emb = model(batch, return_yp=True)
        type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(
            example_dict, feature_dict, batch['ids'],
            yp1.data.cpu().numpy().tolist(), yp2.data.cpu().numpy().tolist(), type_prob)
        answer_type_dict.update(answer_type_dict_)
        answer_type_prob_dict.update(answer_type_prob_dict_)
        answer_dict.update(answer_dict_)
        predict_support_np = torch.sigmoid(sent[:, :, 1]).data.cpu().numpy()
        predict_support_logit_np = sent[:, :, 1].data.cpu().numpy()
        support_sent_mask_np = batch['sent_mask'].data.cpu().numpy()
        cls_emb_np = cls_emb.data.cpu().numpy()

        for i in range(predict_support_np.shape[0]):
            cur_sp_pred = [[] for _ in range(N_thresh)]
            cur_id = batch['ids'][i]
            orig_supp_fact_id = example_dict[cur_id].sup_fact_id
            prune_supp_fact_id = feature_dict[cur_id].sup_fact_ids
            sent_pred_ = {'sp_score': predict_support_logit_np[i].tolist(),
                          'sp_mask': support_sent_mask_np[i].tolist(),
                          'sp_names': example_dict[cur_id].sent_names,
                          'sup_fact_id': orig_supp_fact_id,
                          'trim_sup_fact_id': prune_supp_fact_id}
            cls_emb_ = {'cls_emb': cls_emb_np[i].tolist()}
            res_score = {**sent_pred_, **cls_emb_}
            prediction_res_score_dict[cur_id] = res_score
            for j in range(predict_support_np.shape[1]):
                if j >= len(example_dict[cur_id].sent_names):
                    break
                for thresh_i in range(N_thresh):
                    if predict_support_np[i, j] > thresholds[thresh_i]:
                        cur_sp_pred[thresh_i].append(example_dict[cur_id].sent_names[j])
            for thresh_i in range(N_thresh):
                if cur_id not in total_sp_dict[thresh_i]:
                    total_sp_dict[thresh_i][cur_id] = []
                total_sp_dict[thresh_i][cur_id].extend(cur_sp_pred[thresh_i])

    def choose_best_threshold(ans_dict, pred_file):
        best_joint_f1 = 0
        best_metrics = None
        best_threshold = 0
        for thresh_i in range(N_thresh):
            prediction = {'answer': ans_dict, 'sp': total_sp_dict[thresh_i],
                          'type': answer_type_dict, 'type_prob': answer_type_prob_dict}
            tmp_file = os.path.join(os.path.dirname(pred_file), 'tmp.json')
            with open(tmp_file, 'w') as f:
                json.dump(prediction, f)
            metrics = hotpot_eval(tmp_file, dev_gold_file)
            if metrics['joint_f1'] >= best_joint_f1:
                best_joint_f1 = metrics['joint_f1']
                best_threshold = thresholds[thresh_i]
                best_metrics = metrics
                shutil.move(tmp_file, pred_file)
        return best_metrics, best_threshold

    best_metrics, best_threshold = choose_best_threshold(answer_dict, prediction_file)
    json.dump(best_metrics, open(eval_file, 'w'))
    if output_score_file is not None:
        with open(output_score_file, 'w') as f:
            json.dump(prediction_res_score_dict, f)
        print('Saving {} score records into {}'.format(
            len(prediction_res_score_dict), output_score_file))
    return best_metrics, best_threshold
def jd_eval_model(args, encoder, model, dataloader, example_dict, feature_dict,
                  prediction_file, eval_file, dev_gold_file):
    """Dev-set evaluation with masked paragraph/sentence scores and a threshold sweep."""
    encoder.eval()
    model.eval()
    answer_dict = {}
    answer_type_dict = {}
    answer_type_prob_dict = {}
    dataloader.refresh()
    thresholds = np.arange(0.1, 1.0, 0.025)
    N_thresh = len(thresholds)
    total_sp_dict = [{} for _ in range(N_thresh)]
    total_para_sp_dict = {}
    # best_sp_dict / threshold_inter_count are kept from an earlier variant of this
    # function; they are initialized but never updated in the loop below.
    best_sp_dict = {}
    threshold_inter_count = 0

    for batch in tqdm(dataloader):
        with torch.no_grad():
            inputs = {'input_ids': batch['context_idxs'],
                      'attention_mask': batch['context_mask'],
                      'token_type_ids': batch['segment_idxs']
                      if args.model_type in ['bert', 'xlnet', 'electra'] else None}  # XLM doesn't use segment_ids
            outputs = encoder(**inputs)
            batch['context_encoding'] = outputs[0]
            batch['context_mask'] = batch['context_mask'].float().to(args.device)
            start, end, q_type, paras, sent, ent, yp1, yp2 = model(batch, return_yp=True)

        type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(
            example_dict, feature_dict, batch['ids'],
            yp1.data.cpu().numpy().tolist(), yp2.data.cpu().numpy().tolist(), type_prob)
        para_mask = batch['para_mask']
        sent_mask = batch['sent_mask']
        answer_type_dict.update(answer_type_dict_)
        answer_type_prob_dict.update(answer_type_prob_dict_)
        answer_dict.update(answer_dict_)
        # mask out padded paragraphs/sentences before the sigmoid so their scores go to ~0
        paras = paras[:, :, 1] - (1 - para_mask) * 1e30
        predict_para_support_np = torch.sigmoid(paras).data.cpu().numpy()
        sent = sent[:, :, 1] - (1 - sent_mask) * 1e30
        predict_support_np = torch.sigmoid(sent).data.cpu().numpy()

        for i in range(predict_support_np.shape[0]):
            cur_id = batch['ids'][i]
            predict_para_support_np_ith = predict_para_support_np[i]
            predict_support_np_ith = predict_support_np[i]
            cur_para_sp_pred = supp_doc_prediction(
                predict_para_support_np_ith=predict_para_support_np_ith,
                example_dict=example_dict, batch_ids_ith=cur_id)
            total_para_sp_dict[cur_id] = cur_para_sp_pred
            cur_sp_pred = supp_sent_prediction(
                predict_support_np_ith=predict_support_np_ith,
                example_dict=example_dict, batch_ids_ith=cur_id, thresholds=thresholds)
            for thresh_i in range(N_thresh):
                if cur_id not in total_sp_dict[thresh_i]:
                    total_sp_dict[thresh_i][cur_id] = []
                total_sp_dict[thresh_i][cur_id].extend(cur_sp_pred[thresh_i])

    def choose_best_threshold(ans_dict, pred_file):
        best_joint_f1 = 0
        best_metrics = None
        best_threshold = 0
        best_threshold_idx = -1
        for thresh_i in range(N_thresh):
            prediction = {'answer': ans_dict, 'sp': total_sp_dict[thresh_i],
                          'type': answer_type_dict, 'type_prob': answer_type_prob_dict}
            tmp_file = os.path.join(os.path.dirname(pred_file), 'tmp.json')
            with open(tmp_file, 'w') as f:
                json.dump(prediction, f)
            metrics = hotpot_eval(tmp_file, dev_gold_file)
            if metrics['joint_f1'] >= best_joint_f1:
                best_joint_f1 = metrics['joint_f1']
                best_threshold = thresholds[thresh_i]
                best_threshold_idx = thresh_i
                best_metrics = metrics
                shutil.move(tmp_file, pred_file)
        return best_metrics, best_threshold, best_threshold_idx

    best_metrics, best_threshold, best_threshold_idx = choose_best_threshold(
        answer_dict, prediction_file)
    doc_recall_metric = doc_recall_eval(doc_prediction=total_para_sp_dict,
                                        gold_file=dev_gold_file)
    json.dump(best_metrics, open(eval_file, 'w'))
    # -------------------------------------
    best_prediction = {'answer': answer_dict, 'sp': best_sp_dict,
                       'type': answer_type_dict, 'type_prob': answer_type_prob_dict}
    print('Number of inter threshold = {}'.format(threshold_inter_count))
    best_tmp_file = os.path.join(os.path.dirname(prediction_file), 'best_tmp.json')
    with open(best_tmp_file, 'w') as f:
        json.dump(best_prediction, f)
    best_th_metrics = hotpot_eval(best_tmp_file, dev_gold_file)
    for key, val in best_th_metrics.items():
        print("{} = {}".format(key, val))
    # -------------------------------------
    return best_metrics, best_threshold, doc_recall_metric
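
# A minimal demonstration of the masking trick in jd_eval_model: subtracting 1e30 at
# padded positions drives their logits toward -inf, so sigmoid maps them to ~0 and padding
# can never cross any threshold. The values here are illustrative only.
def _masked_sigmoid_demo():
    logits = torch.tensor([2.0, 1.5, 0.7])
    mask = torch.tensor([1.0, 1.0, 0.0])  # third position is padding
    masked = logits - (1 - mask) * 1e30
    return torch.sigmoid(masked)  # ~= tensor([0.8808, 0.8176, 0.0000])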
def jd_unified_eval_model(args, model, dataloader, example_dict, feature_dict,
                          prediction_file, eval_file, dev_gold_file=None):
    model.eval()
    answer_dict = {}
    answer_type_dict = {}
    answer_type_prob_dict = {}
    thresholds = np.arange(0.1, 1.0, 0.05)
    N_thresh = len(thresholds)
    total_sp_dict = [{} for _ in range(N_thresh)]

    for batch in tqdm(dataloader):
        for key, value in batch.items():
            if key not in {'ids'}:
                batch[key] = value.to(args.device)
        with torch.no_grad():
            start, end, q_type, paras, sent, ent, yp1, yp2 = model(batch, return_yp=True)
        type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(
            example_dict, feature_dict, batch['ids'],
            yp1.data.cpu().numpy().tolist(), yp2.data.cpu().numpy().tolist(), type_prob)
        answer_type_dict.update(answer_type_dict_)
        answer_type_prob_dict.update(answer_type_prob_dict_)
        answer_dict.update(answer_dict_)
        predict_support_np = torch.sigmoid(sent[:, :, 1]).data.cpu().numpy()

        for i in range(predict_support_np.shape[0]):
            cur_sp_pred = [[] for _ in range(N_thresh)]
            cur_id = batch['ids'][i]
            for j in range(predict_support_np.shape[1]):
                if j >= len(example_dict[cur_id].sent_names):
                    break
                for thresh_i in range(N_thresh):
                    if predict_support_np[i, j] > thresholds[thresh_i]:
                        cur_sp_pred[thresh_i].append(example_dict[cur_id].sent_names[j])
            for thresh_i in range(N_thresh):
                if cur_id not in total_sp_dict[thresh_i]:
                    total_sp_dict[thresh_i][cur_id] = []
                total_sp_dict[thresh_i][cur_id].extend(cur_sp_pred[thresh_i])

    def choose_best_threshold(ans_dict, pred_file):
        best_joint_f1 = 0
        best_metrics = None
        best_threshold = 0
        for thresh_i in range(N_thresh):
            prediction = {'answer': ans_dict, 'sp': total_sp_dict[thresh_i],
                          'type': answer_type_dict, 'type_prob': answer_type_prob_dict}
            tmp_file = os.path.join(os.path.dirname(pred_file), 'tmp.json')
            with open(tmp_file, 'w') as f:
                json.dump(prediction, f)
            metrics = hotpot_eval(tmp_file, dev_gold_file)
            if metrics['joint_f1'] >= best_joint_f1:
                best_joint_f1 = metrics['joint_f1']
                best_threshold = thresholds[thresh_i]
                best_metrics = metrics
                shutil.move(tmp_file, pred_file)
        return best_metrics, best_threshold

    best_metrics, best_threshold = choose_best_threshold(answer_dict, prediction_file)
    json.dump(best_metrics, open(eval_file, 'w'))
    return best_metrics, best_threshold
def jd_at_eval_model(args, encoder, model, dataloader, example_dict, feature_dict,
                     prediction_file, eval_file, dev_gold_file):
    """Dev-set evaluation using adaptive-threshold prediction instead of a grid sweep."""
    encoder.eval()
    model.eval()
    answer_dict = {}
    answer_type_dict = {}
    answer_type_prob_dict = {}
    dataloader.refresh()
    total_sent_sp_dict = {}
    total_para_sp_dict = {}

    for batch in tqdm(dataloader):
        with torch.no_grad():
            inputs = {'input_ids': batch['context_idxs'],
                      'attention_mask': batch['context_mask'],
                      'token_type_ids': batch['segment_idxs']
                      if args.model_type in ['bert', 'xlnet'] else None}  # XLM doesn't use segment_ids
            outputs = encoder(**inputs)
            batch['context_encoding'] = outputs[0]
            batch['context_mask'] = batch['context_mask'].float().to(args.device)
            start, end, q_type, paras, sent, ent, yp1, yp2 = model(batch, return_yp=True)

        type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(
            example_dict, feature_dict, batch['ids'],
            yp1.data.cpu().numpy().tolist(), yp2.data.cpu().numpy().tolist(), type_prob)
        answer_type_dict.update(answer_type_dict_)
        answer_type_prob_dict.update(answer_type_prob_dict_)
        answer_dict.update(answer_dict_)
        # prepend a column of ones for the query position before adaptive thresholding
        predict_para_support_logits = paras
        para_mask = batch['para_mask']
        batch_size = para_mask.shape[0]
        query_para_mask = torch.cat([torch.ones(batch_size, 1).to(para_mask), para_mask], dim=-1)
        para_pred_out = adaptive_threshold_prediction(
            logits=predict_para_support_logits, number_labels=2,
            mask=query_para_mask, type='topk')
        para_pred_out_np = para_pred_out.data.cpu().numpy()
        predict_sent_support_logits = sent
        sent_mask = batch['sent_mask']
        query_sent_mask = torch.cat([torch.ones(batch_size, 1).to(sent_mask), sent_mask], dim=-1)
        sent_pred_out = adaptive_threshold_prediction(
            logits=predict_sent_support_logits, number_labels=2,
            mask=query_sent_mask, type='or')
        sent_pred_out_np = sent_pred_out.data.cpu().numpy()

        for i in range(sent_pred_out_np.shape[0]):
            cur_id = batch['ids'][i]
            para_pred_out_np_ith = para_pred_out_np[i]
            sent_pred_out_np_ith = sent_pred_out_np[i]
            cur_para_sp_pred = supp_para_at_prediction(
                predict_para_support_np_ith=para_pred_out_np_ith,
                example_dict=example_dict, batch_ids_ith=cur_id)
            cur_sent_sp_pred = supp_sent_at_prediction(
                predict_sent_support_np_ith=sent_pred_out_np_ith,
                example_dict=example_dict, batch_ids_ith=cur_id)
            total_para_sp_dict[cur_id] = cur_para_sp_pred
            total_sent_sp_dict[cur_id] = cur_sent_sp_pred

    prediction = {'answer': answer_dict, 'sp': total_sent_sp_dict,
                  'type': answer_type_dict, 'type_prob': answer_type_prob_dict}
    tmp_file = os.path.join(os.path.dirname(prediction_file), 'tmp.json')
    with open(tmp_file, 'w') as f:
        json.dump(prediction, f)
    eval_metrics = hotpot_eval(tmp_file, dev_gold_file)
    doc_recall_metric = doc_recall_eval(doc_prediction=total_para_sp_dict,
                                        gold_file=dev_gold_file)
    total_inconsistent_number = supp_doc_sent_consistent_checker(
        predict_para_dict=total_para_sp_dict,
        predicted_supp_sent_dict=total_sent_sp_dict, gold_file=dev_gold_file)
    json.dump(eval_metrics, open(eval_file, 'w'))
    return eval_metrics, doc_recall_metric, total_inconsistent_number
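
# adaptive_threshold_prediction is defined elsewhere in the repo. Given the column of ones
# prepended for the query position above, one plausible reading is ATLOP-style adaptive
# thresholding, where the score at position 0 serves as a per-example threshold. This is a
# hedged sketch of that idea under stated assumptions, not the repo's actual implementation.
def adaptive_threshold_sketch(scores, mask):
    """scores, mask: (batch, 1 + n_candidates); column 0 holds the threshold score."""
    threshold = scores[:, 0:1]                   # per-example adaptive threshold
    preds = (scores[:, 1:] > threshold).float()  # candidates must beat it
    return preds * mask[:, 1:]                   # zero out padded candidates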
def lightnHGN_test_procedure(model, test_data_loader, dev_feature_dict, dev_example_dict,
                             args, device):
    model.freeze()
    out_puts = []
    start_time = time()
    total_steps = len(test_data_loader)
    with torch.no_grad():
        for batch_idx, batch in tqdm(enumerate(test_data_loader)):
            batch = batch2device(batch=batch, device=device)
            start, end, q_type, paras, sents, ents, yp1, yp2 = model.forward(batch=batch)
            type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
            answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(
                dev_example_dict, dev_feature_dict, batch['ids'],
                yp1.data.cpu().numpy().tolist(), yp2.data.cpu().numpy().tolist(), type_prob)
            predict_support_np = torch.sigmoid(sents[:, :, 1]).data.cpu().numpy()
            valid_dict = {'answer': answer_dict_, 'ans_type': answer_type_dict_,
                          'ids': batch['ids'], 'ans_type_pro': answer_type_prob_dict_,
                          'supp_np': predict_support_np}
            # progress log every args.eval_batch_size batches
            if (batch_idx + 1) % args.eval_batch_size == 0:
                print('Evaluating the model... {}/{} in {:.4f} seconds'.format(
                    batch_idx + 1, total_steps, time() - start_time))
            out_puts.append(valid_dict)
            del batch

    answer_dict = {}
    answer_type_dict = {}
    answer_type_prob_dict = {}
    thresholds = np.arange(0.1, 1.0, 0.025)
    N_thresh = len(thresholds)
    total_sp_dict = [{} for _ in range(N_thresh)]

    for batch_idx, valid_dict in tqdm(enumerate(out_puts)):
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = \
            valid_dict['answer'], valid_dict['ans_type'], valid_dict['ans_type_pro']
        answer_type_dict.update(answer_type_dict_)
        answer_type_prob_dict.update(answer_type_prob_dict_)
        answer_dict.update(answer_dict_)
        predict_support_np = valid_dict['supp_np']
        batch_ids = valid_dict['ids']
        for i in range(predict_support_np.shape[0]):
            cur_sp_pred = [[] for _ in range(N_thresh)]
            cur_id = batch_ids[i]
            for j in range(predict_support_np.shape[1]):
                if j >= len(dev_example_dict[cur_id].sent_names):
                    break
                for thresh_i in range(N_thresh):
                    if predict_support_np[i, j] > thresholds[thresh_i]:
                        cur_sp_pred[thresh_i].append(dev_example_dict[cur_id].sent_names[j])
            for thresh_i in range(N_thresh):
                if cur_id not in total_sp_dict[thresh_i]:
                    total_sp_dict[thresh_i][cur_id] = []
                total_sp_dict[thresh_i][cur_id].extend(cur_sp_pred[thresh_i])

    def choose_best_threshold(ans_dict, pred_file):
        best_joint_f1 = 0
        best_metrics = None
        best_threshold = 0
        metric_dict = {}
        for thresh_i in range(N_thresh):
            prediction = {'answer': ans_dict, 'sp': total_sp_dict[thresh_i],
                          'type': answer_type_dict, 'type_prob': answer_type_prob_dict}
            tmp_file = os.path.join(os.path.dirname(pred_file), 'tmp.json')
            with open(tmp_file, 'w') as f:
                json.dump(prediction, f)
            metrics = hotpot_eval(tmp_file, args.dev_gold_file)
            if metrics['joint_f1'] >= best_joint_f1:
                best_joint_f1 = metrics['joint_f1']
                best_threshold = thresholds[thresh_i]
                best_metrics = metrics
                shutil.move(tmp_file, pred_file)
            metric_dict[thresh_i] = (metrics['em'], metrics['f1'], metrics['sp_em'],
                                     metrics['sp_f1'], metrics['joint_em'],
                                     metrics['joint_f1'])
        return best_metrics, best_threshold, metric_dict

    output_pred_file = os.path.join(args.exp_name, 'pred.json')
    output_eval_file = os.path.join(args.exp_name, 'eval.txt')
    best_metrics, best_threshold, metric_dict = choose_best_threshold(answer_dict,
                                                                      output_pred_file)
    logging.info('Leader board evaluation completed with threshold = {:.4f}'.format(best_threshold))
    log_metrics(mode='Evaluation', metrics=best_metrics)
    logging.info('*' * 75)
    for key, value in metric_dict.items():
        str_value = ['{:.4f}'.format(_) for _ in value]
        logging.info('threshold {:.4f}: \t metrics: {}'.format(thresholds[key], str_value))
    json.dump(best_metrics, open(output_eval_file, 'w'))
    return best_metrics, best_threshold
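
# batch2device, used in lightnHGN_test_procedure, is defined elsewhere in the repo. A
# minimal sketch consistent with the device-transfer loops in this file, assuming every
# value except 'ids' is a tensor:
def batch2device_sketch(batch, device):
    return {key: value if key == 'ids' else value.to(device) for key, value in batch.items()}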
def jd_postprecess_unified_test_model(args, model, dataloader, example_dict, feature_dict,
                                      threshold=0.35, output_score_file=None):
    model.eval()
    answer_dict = {}
    answer_type_dict = {}
    answer_type_prob_dict = {}
    sp_dict = {}
    prediction_res_score_dict = {}

    for batch in tqdm(dataloader):
        for key, value in batch.items():
            if key not in {'ids'}:
                batch[key] = value.to(args.device)
        with torch.no_grad():
            start, end, q_type, paras, sent, ent, yp1, yp2, cls_emb = model(batch, return_yp=True)
        type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(
            example_dict, feature_dict, batch['ids'],
            yp1.data.cpu().numpy().tolist(), yp2.data.cpu().numpy().tolist(), type_prob)
        _, _, answer_sent_name_dict_ = convert_answer_to_sent_names(
            example_dict, feature_dict, batch,
            yp1.data.cpu().numpy().tolist(), yp2.data.cpu().numpy().tolist(), type_prob)
        answer_type_dict.update(answer_type_dict_)
        answer_type_prob_dict.update(answer_type_prob_dict_)
        answer_dict.update(answer_dict_)
        predict_support_np = torch.sigmoid(sent[:, :, 1]).data.cpu().numpy()
        predict_support_para_np = torch.sigmoid(paras[:, :, 1]).data.cpu().numpy()
        predict_support_logit_np = sent[:, :, 1].data.cpu().numpy()
        support_sent_mask_np = batch['sent_mask'].data.cpu().numpy()
        predict_support_para_logit_np = paras[:, :, 1].data.cpu().numpy()
        support_para_mask_np = batch['para_mask'].data.cpu().numpy()
        cls_emb_np = cls_emb.data.cpu().numpy()

        for i in range(predict_support_np.shape[0]):
            cur_id = batch['ids'][i]
            cur_sp_pred = []
            sent_pred_ = {'sp_score': predict_support_logit_np[i].tolist(),
                          'sp_mask': support_sent_mask_np[i].tolist(),
                          'sp_names': example_dict[cur_id].sent_names}
            para_pred_ = {'sp_para_score': predict_support_para_logit_np[i].tolist(),
                          'sp_para_mask': support_para_mask_np[i].tolist(),
                          'sp_para_names': example_dict[cur_id].para_names}
            cls_emb_ = {'cls_emb': cls_emb_np[i].tolist()}
            res_score = {**sent_pred_, **cls_emb_, **para_pred_}
            prediction_res_score_dict[cur_id] = res_score
            topk_score_ref, cut_sent_flag, topk_pred_sent_names, diff_para_sent_names, topk_pred_paras = \
                post_process_sent_para(cur_id=cur_id, example_dict=example_dict,
                                       feature_dict=feature_dict,
                                       sent_scores_np_i=predict_support_np[i],
                                       sent_mask_np_i=support_sent_mask_np[i],
                                       para_scores_np_i=predict_support_para_np[i],
                                       para_mask_np_i=support_para_mask_np[i])
            ans_sent_name = answer_sent_name_dict_[cur_id]
            for j in range(predict_support_np.shape[1]):
                if j >= len(example_dict[cur_id].sent_names):
                    break
                if predict_support_np[i, j] > topk_score_ref * threshold:
                    cur_sp_pred.append(example_dict[cur_id].sent_names[j])
            post_process_sp_pred = post_process_technique(
                cur_sp_pred=cur_sp_pred,
                topk_pred_paras=topk_pred_paras,
                topk_pred_sent_names=topk_pred_sent_names,
                diff_para_sent_names=diff_para_sent_names,
                ans_sent_name=ans_sent_name)
            sp_dict[cur_id] = post_process_sp_pred

    prediction = {'answer': answer_dict, 'sp': sp_dict,
                  'type': answer_type_dict, 'type_prob': answer_type_prob_dict}
    if output_score_file is not None:
        with open(output_score_file, 'w') as f:
            json.dump(prediction_res_score_dict, f)
        print('Saving {} score records into {}'.format(
            len(prediction_res_score_dict), output_score_file))
    return prediction
def jd_postprocess_unified_eval_model(args, model, dataloader, example_dict, feature_dict,
                                      prediction_file, eval_file, dev_gold_file=None):
    model.eval()
    answer_dict = {}
    answer_type_dict = {}
    answer_type_prob_dict = {}
    thresholds = np.arange(0.1, 1.0, 0.025)
    N_thresh = len(thresholds)
    total_sp_dict = [{} for _ in range(N_thresh)]

    for batch in tqdm(dataloader):
        for key, value in batch.items():
            if key not in {'ids'}:
                batch[key] = value.to(args.device)
        with torch.no_grad():
            start, end, q_type, paras, sent, ent, yp1, yp2, cls_emb = model(batch, return_yp=True)
        type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(
            example_dict, feature_dict, batch['ids'],
            yp1.data.cpu().numpy().tolist(), yp2.data.cpu().numpy().tolist(), type_prob)
        _, _, answer_sent_name_dict_ = convert_answer_to_sent_names(
            example_dict, feature_dict, batch,
            yp1.data.cpu().numpy().tolist(), yp2.data.cpu().numpy().tolist(), type_prob)
        answer_type_dict.update(answer_type_dict_)
        answer_type_prob_dict.update(answer_type_prob_dict_)
        answer_dict.update(answer_dict_)
        predict_support_np = torch.sigmoid(sent[:, :, 1]).data.cpu().numpy()
        ####################################################################
        support_sent_mask_np = batch['sent_mask'].data.cpu().numpy()
        predict_support_para_np = torch.sigmoid(paras[:, :, 1]).data.cpu().numpy()
        support_para_mask_np = batch['para_mask'].data.cpu().numpy()
        ####################################################################
        for i in range(predict_support_np.shape[0]):
            cur_sp_pred = [[] for _ in range(N_thresh)]
            cur_id = batch['ids'][i]
            topk_score_ref, cut_sent_flag, topk_pred_sent_names, diff_para_sent_names, topk_pred_paras = \
                post_process_sent_para(cur_id=cur_id, example_dict=example_dict,
                                       feature_dict=feature_dict,
                                       sent_scores_np_i=predict_support_np[i],
                                       sent_mask_np_i=support_sent_mask_np[i],
                                       para_scores_np_i=predict_support_para_np[i],
                                       para_mask_np_i=support_para_mask_np[i])
            ans_sent_name = answer_sent_name_dict_[cur_id]
            for j in range(predict_support_np.shape[1]):
                if j >= len(example_dict[cur_id].sent_names):
                    break
                for thresh_i in range(N_thresh):
                    if predict_support_np[i, j] > thresholds[thresh_i] * topk_score_ref:
                        cur_sp_pred[thresh_i].append(example_dict[cur_id].sent_names[j])
            for thresh_i in range(N_thresh):
                if cur_id not in total_sp_dict[thresh_i]:
                    total_sp_dict[thresh_i][cur_id] = []
                post_process_thresh_i_sp_pred = post_process_technique(
                    cur_sp_pred=cur_sp_pred[thresh_i],
                    topk_pred_paras=topk_pred_paras,
                    topk_pred_sent_names=topk_pred_sent_names,
                    diff_para_sent_names=diff_para_sent_names,
                    ans_sent_name=ans_sent_name)
                total_sp_dict[thresh_i][cur_id].extend(post_process_thresh_i_sp_pred)

    def choose_best_threshold(ans_dict, pred_file):
        best_joint_f1 = 0
        best_metrics = None
        best_threshold = 0
        for thresh_i in range(N_thresh):
            prediction = {'answer': ans_dict, 'sp': total_sp_dict[thresh_i],
                          'type': answer_type_dict, 'type_prob': answer_type_prob_dict}
            tmp_file = os.path.join(os.path.dirname(pred_file), 'tmp.json')
            with open(tmp_file, 'w') as f:
                json.dump(prediction, f)
            metrics = hotpot_eval(tmp_file, dev_gold_file)
            if metrics['joint_f1'] >= best_joint_f1:
                best_joint_f1 = metrics['joint_f1']
                best_threshold = thresholds[thresh_i]
                best_metrics = metrics
                shutil.move(tmp_file, pred_file)
        return best_metrics, best_threshold

    best_metrics, best_threshold = choose_best_threshold(answer_dict, prediction_file)
    json.dump(best_metrics, open(eval_file, 'w'))
    return best_metrics, best_threshold
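
# A hedged usage sketch for the evaluation entry points above; the file paths and the
# prepared args/model/dataloader objects are illustrative assumptions, not values from
# this repo.
# best_metrics, best_threshold = jd_postprocess_unified_eval_model(
#     args, model, dev_dataloader, dev_example_dict, dev_feature_dict,
#     prediction_file='preds/pred.json', eval_file='preds/eval.json',
#     dev_gold_file='data/hotpot_dev_distractor_v1.json')
# print('best joint F1 = {:.4f} at threshold {:.2f}'.format(
#     best_metrics['joint_f1'], best_threshold))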