def eval_entity_f1_camrest(eval_fp, entity_fp):
    """Compute entity F1 over a CamRest eval dump (one JSON dialog per line).

    NOTE(review): a later 3-argument definition of ``eval_entity_f1_camrest``
    in this file shadows this one at import time — confirm which is intended.

    Args:
        eval_fp: path to a JSONL file; each line holds a dialog dict with
            "result" (predicted text), "kb", and "gold_entity" keys.
        entity_fp: path to the CamRest global-entity JSON file.

    Returns:
        Macro-averaged entity F1 (0.0 when no dialog contributes a count).
    """
    test_data = []
    with open(eval_fp, 'r') as fr:
        for line in fr:
            dialog = json.loads(line.strip())
            test_data.append(dialog)
    print("test data: ", len(test_data))
    with open(entity_fp, 'r') as fr:
        global_entity = json.load(fr)
    # Flatten all entity values into one deduplicated, underscore-joined list.
    global_entity_list = []
    for key in global_entity.keys():
        global_entity_list += [
            item.lower().replace(' ', '_') for item in global_entity[key]
        ]
    global_entity_list = list(set(global_entity_list))
    F1_pred, F1_count = 0, 0
    for dialog in test_data:
        # Underscored multiword entities are split back into plain words.
        pred_tokens = dialog["result"].replace('_', ' ').split()
        kb_arrys = dialog["kb"]
        gold_ents = dialog["gold_entity"]
        if len(gold_ents) > 0:
            gold_ents = ' '.join(gold_ents).replace('_', ' ').split()
        single_f1, count = compute_prf(gold_ents, pred_tokens,
                                       global_entity_list, kb_arrys)
        F1_pred += single_f1
        F1_count += count
    # Guard: an empty eval file (or all-empty golds) must not crash.
    F1_score = F1_pred / float(F1_count) if F1_count else 0.0
    return F1_score
def eval_entity_f1_camrest_online(hyps, tasks, gold_entity, kb_word):
    """Compute CamRest entity F1 on in-memory predictions (no eval file).

    Args:
        hyps: list of predicted response strings.
        tasks: list of task labels, parallel to ``hyps``.
        gold_entity: list of gold entity lists, parallel to ``hyps``.
        kb_word: list of per-dialog KB arrays, parallel to ``hyps``.

    Returns:
        Macro-averaged entity F1 (0.0 when no dialog contributes a count).
    """
    entity_fp = "./data/CamRest/camrest676-entities.json"
    with open(entity_fp, 'r') as fr:
        global_entity = json.load(fr)
    global_entity_list = []
    for key in global_entity.keys():
        global_entity_list += [
            item.lower().replace(' ', '_') for item in global_entity[key]
        ]
    global_entity_list = list(set(global_entity_list))
    test_data = []
    for hyp, tk, ent, kb in zip(hyps, tasks, gold_entity, kb_word):
        dialog = {"result": hyp, "task": tk, "gold_entity": ent, "kb": kb}
        test_data.append(dialog)
    F1_pred, F1_count = 0, 0
    for dialog in test_data:
        # TODO: unlike the offline variant, underscores are deliberately kept
        # in both predictions and golds here — confirm tokenization choice.
        pred_tokens = dialog["result"].split()
        kb_arrys = dialog["kb"]
        gold_ents = dialog["gold_entity"]
        single_f1, count = compute_prf(gold_ents, pred_tokens,
                                       global_entity_list, kb_arrys)
        F1_pred += single_f1
        F1_count += count
    # Guard: empty batches must not raise ZeroDivisionError.
    F1_score = F1_pred / float(F1_count) if F1_count else 0.0
    return F1_score
def eval_entity_f1_camrest(eval_fp, entity_fp, average="micro"):
    """Compute entity F1 over a CamRest eval dump, micro- or macro-averaged.

    NOTE(review): this redefines the earlier 2-argument
    ``eval_entity_f1_camrest`` in this file — confirm which is intended.

    Args:
        eval_fp: path to a JSONL file; each line holds a dialog dict with
            "result", "kb", and "gold_entity" keys.
        entity_fp: path to the CamRest global-entity JSON file.
        average: "micro" pools TP/FP/FN across dialogs via ``compute_f1``;
            anything else macro-averages the per-dialog F1 values.

    Returns:
        The entity F1 score (0.0 macro average when no dialog counts).
    """
    test_data = []
    with open(eval_fp, 'r') as fr:
        for line in fr:
            dialog = json.loads(line.strip())
            # Split underscored multiword golds into plain words up front.
            if len(dialog["gold_entity"]) > 0:
                dialog["gold_entity"] = ' '.join(
                    dialog["gold_entity"]).replace('_', ' ').split()
            test_data.append(dialog)
    with open(entity_fp, 'r') as fr:
        global_entity = json.load(fr)
    global_entity_list = []
    for key in global_entity.keys():
        global_entity_list += [
            item.lower().replace(' ', '_') for item in global_entity[key]
        ]
    global_entity_list = list(set(global_entity_list))
    F1_pred, F1_count = 0, 0
    TP_all, FP_all, FN_all = 0, 0, 0
    for dialog in test_data:
        pred_tokens = dialog["result"].replace('_', ' ').split()
        kb_arrys = dialog["kb"]
        gold_ents = dialog["gold_entity"]
        tp, fp, fn, f1, count = compute_prf(gold_ents, pred_tokens,
                                            global_entity_list, kb_arrys)
        F1_pred += f1
        TP_all += tp
        FP_all += fp
        FN_all += fn
        F1_count += count
    if average == "micro":
        F1_score = compute_f1(TP_all, FP_all, FN_all)
    else:
        # Guard: macro average over zero counted dialogs must not crash.
        F1_score = F1_pred / float(F1_count) if F1_count else 0.0
    return F1_score
def eval_entity_f1_kvr_online(hyps, tasks, gold_entity, kb_word):
    """Compute KVRET (in-car assistant) entity F1 on in-memory predictions.

    Args:
        hyps: list of predicted response strings.
        tasks: list of task labels, parallel to ``hyps``.
        gold_entity: list of gold entity lists, parallel to ``hyps``.
        kb_word: list of per-dialog KB arrays, parallel to ``hyps``.

    Returns:
        Macro-averaged entity F1 (0.0 when no dialog contributes a count).
    """
    entity_fp = "./data/KVR/kvret_entities.json"
    with open(entity_fp, 'r') as fr:
        global_entity = json.load(fr)
    global_entity_list = []
    for key in global_entity.keys():
        if key != 'poi':
            global_entity_list += [
                item.lower().replace(' ', '_') for item in global_entity[key]
            ]
        else:
            # POI entries are dicts; every field value counts as an entity.
            for item in global_entity['poi']:
                global_entity_list += [
                    item[k].lower().replace(' ', '_') for k in item.keys()
                ]
    global_entity_list = list(set(global_entity_list))
    test_data = []
    for hyp, tk, ent, kb in zip(hyps, tasks, gold_entity, kb_word):
        dialog = {"result": hyp, "task": tk, "gold_entity": ent, "kb": kb}
        test_data.append(dialog)
    F1_pred, F1_count = 0, 0
    for dialog in test_data:
        pred_tokens = dialog["result"].replace('_', ' ').split()
        kb_arrys = dialog["kb"]
        gold_ents = dialog["gold_entity"]
        if len(gold_ents) > 0:
            gold_ents = ' '.join(gold_ents).replace('_', ' ').split()
        single_f1, count = compute_prf(gold_ents, pred_tokens,
                                       global_entity_list, kb_arrys)
        F1_pred += single_f1
        F1_count += count
    # Guard: empty batches must not raise ZeroDivisionError.
    F1_score = F1_pred / float(F1_count) if F1_count else 0.0
    return F1_score
def eval_entity_f1_kvr(eval_fp, entity_fp):
    """Compute KVRET entity F1 overall and per domain (schedule/weather/nav).

    NOTE(review): a later 3-argument definition of ``eval_entity_f1_kvr`` in
    this file shadows this one at import time — confirm which is intended.

    Args:
        eval_fp: path to a JSONL file; each line holds a dialog dict with
            "result", "kb", "task", and "gold_entity" keys.
        entity_fp: path to the KVRET global-entity JSON file.

    Returns:
        Tuple ``(overall, schedule, weather, navigate)`` macro-averaged F1
        scores; any component with a zero count yields 0.0 instead of
        raising ZeroDivisionError (e.g. when a domain is absent from the
        eval file).
    """
    test_data = []
    with open(eval_fp, 'r') as fr:
        for line in fr:
            ent_idx_sch, ent_idx_wet, ent_idx_nav = [], [], []
            dialog = json.loads(line.strip())
            # Route the gold entities to the bucket of the dialog's task.
            if dialog["task"] == "schedule":
                ent_idx_sch = dialog["gold_entity"]
            elif dialog["task"] == "weather":
                ent_idx_wet = dialog["gold_entity"]
            elif dialog["task"] == "navigate":
                ent_idx_nav = dialog["gold_entity"]
            ent_index = list(set(ent_idx_sch + ent_idx_wet + ent_idx_nav))
            dialog["ent_index"] = ent_index
            dialog["ent_idx_sch"] = list(set(ent_idx_sch))
            dialog["ent_idx_wet"] = list(set(ent_idx_wet))
            dialog["ent_idx_nav"] = list(set(ent_idx_nav))
            test_data.append(dialog)
    print("test data: ", len(test_data))
    with open(entity_fp, 'r') as fr:
        global_entity = json.load(fr)
    global_entity_list = []
    for key in global_entity.keys():
        if key != 'poi':
            global_entity_list += [
                item.lower().replace(' ', '_') for item in global_entity[key]
            ]
        else:
            # POI entries are dicts; every field value counts as an entity.
            for item in global_entity['poi']:
                global_entity_list += [
                    item[k].lower().replace(' ', '_') for k in item.keys()
                ]
    global_entity_list = list(set(global_entity_list))
    F1_pred, F1_sch_pred, F1_nav_pred, F1_wet_pred = 0, 0, 0, 0
    F1_count, F1_sch_count, F1_nav_count, F1_wet_count = 0, 0, 0, 0
    for dialog in test_data:
        # TODO: underscores are deliberately kept in predictions and golds
        # here (unlike the CamRest offline variant) — confirm tokenization.
        pred_tokens = dialog["result"].split()
        kb_arrys = dialog["kb"]
        gold_ents = dialog["ent_index"]
        single_f1, count = compute_prf(gold_ents, pred_tokens,
                                       global_entity_list, kb_arrys)
        F1_pred += single_f1
        F1_count += count
        gold_sch_ents = dialog["ent_idx_sch"]
        single_f1, count = compute_prf(gold_sch_ents, pred_tokens,
                                       global_entity_list, kb_arrys)
        F1_sch_pred += single_f1
        F1_sch_count += count
        gold_wet_ents = dialog["ent_idx_wet"]
        single_f1, count = compute_prf(gold_wet_ents, pred_tokens,
                                       global_entity_list, kb_arrys)
        F1_wet_pred += single_f1
        F1_wet_count += count
        gold_nav_ents = dialog["ent_idx_nav"]
        # NOTE(review): navigate passes an empty global-entity list, unlike
        # the other domains and the newer eval_entity_f1_kvr below — confirm
        # this asymmetry is intentional before relying on the nav score.
        single_f1, count = compute_prf(gold_nav_ents, pred_tokens, [],
                                       kb_arrys)
        F1_nav_pred += single_f1
        F1_nav_count += count
    # Guards: a domain absent from the eval file leaves its count at 0.
    F1_score = F1_pred / float(F1_count) if F1_count else 0.0
    F1_sch_score = F1_sch_pred / float(F1_sch_count) if F1_sch_count else 0.0
    F1_wet_score = F1_wet_pred / float(F1_wet_count) if F1_wet_count else 0.0
    F1_nav_score = F1_nav_pred / float(F1_nav_count) if F1_nav_count else 0.0
    return F1_score, F1_sch_score, F1_wet_score, F1_nav_score
def reward_fn1(self, preds, targets, gold_ents, ptr_index, task_label):
    """reward_fn1: general RL reward for a batch of decoded responses.

    Builds a per-sample reward from sentence BLEU (average of BLEU-1 and
    BLEU-2 with smoothing); entity F1 is also computed and returned for
    reporting, but the F1 term of the compound reward is currently disabled
    (only ``alpha1 * bleu`` is used; ``alpha2`` is kept for when the
    ``alpha2 * f1`` term is re-enabled).

    Args:
        preds: predicted token-id tensor, denumericalized via ``self.tgt_field``.
        targets: gold token-id tensor, same shape conventions as ``preds``.
        gold_ents: per-sample gold entity lists.
        ptr_index: unused here (kept for interface compatibility).
        task_label: unused here (kept for interface compatibility).

    Returns:
        Tuple ``(reward, bleu_score, report_f1)`` — reward has a trailing
        singleton dim; tensors are moved to GPU when ``self.use_gpu``.
    """
    # reward weights
    alpha1 = 1.0
    alpha2 = 0.3  # reserved for the (currently disabled) entity-F1 term
    pred_text = self.tgt_field.denumericalize(preds)
    tgt_text = self.tgt_field.denumericalize(targets)
    batch_size = targets.size(0)
    batch_kb_inputs = self.kbs[:batch_size, :, :]
    kb_plain = self.kb_field.denumericalize(batch_kb_inputs)
    result = Pack()
    result.add(pred_text=pred_text, tgt_text=tgt_text,
               gold_ents=gold_ents, kb_plain=kb_plain)
    result_list = result.flatten()
    # bleu reward: mean of smoothed BLEU-1 and BLEU-2 per sample
    bleu_score = []
    for res in result_list:
        hyp_toks = res.pred_text.split()
        ref_toks = res.tgt_text.split()
        # sentence_bleu can raise on degenerate hypotheses; score 0 then.
        # (Narrowed from a bare except so Ctrl-C etc. still propagate.)
        try:
            bleu_1 = sentence_bleu(
                references=[ref_toks], hypothesis=hyp_toks,
                smoothing_function=SmoothingFunction().method7,
                weights=[1, 0, 0, 0])
        except Exception:
            bleu_1 = 0
        try:
            bleu_2 = sentence_bleu(
                references=[ref_toks], hypothesis=hyp_toks,
                smoothing_function=SmoothingFunction().method7,
                weights=[0.5, 0.5, 0, 0])
        except Exception:
            bleu_2 = 0
        bleu = (bleu_1 + bleu_2) / 2
        bleu_score.append(bleu)
    bleu_score = torch.tensor(bleu_score, dtype=torch.float)
    # entity f1 reward: samples without gold entities score a neutral 1.0
    # and are excluded from report_f1
    f1_score = []
    report_f1 = []
    for res in result_list:
        if len(res.gold_ents) == 0:
            f1_pred = 1.0
        else:
            # TODO: decide whether to split underscored entities here, as
            # the offline evaluators do.
            gold_entity = res.gold_ents
            pred_sent = res.pred_text
            f1_pred, _ = compute_prf(gold_entity, pred_sent,
                                     global_entity_list=[],
                                     kb_plain=res.kb_plain)
            report_f1.append(f1_pred)
        f1_score.append(f1_pred)
    if len(report_f1) == 0:
        report_f1.append(0.0)
    f1_score = torch.tensor(f1_score, dtype=torch.float)
    report_f1 = torch.tensor(report_f1, dtype=torch.float)
    if self.use_gpu:
        bleu_score = bleu_score.cuda()
        f1_score = f1_score.cuda()
        report_f1 = report_f1.cuda()
    # compound reward; the "+ alpha2 * f1_score.unsqueeze(-1)" term is
    # intentionally disabled for now
    reward = alpha1 * bleu_score.unsqueeze(-1)
    return reward, bleu_score, report_f1
def eval_entity_f1_kvr(eval_fp, entity_fp, average="micro"):
    """Compute KVRET entity F1 (overall + per domain), micro- or macro-averaged.

    NOTE(review): this redefines the earlier 2-argument
    ``eval_entity_f1_kvr`` in this file — confirm which is intended.

    Args:
        eval_fp: path to a JSONL file; each line holds a dialog dict with
            "result", "kb", "task", and "gold_entity" keys.
        entity_fp: path to the KVRET global-entity JSON file.
        average: "micro" pools TP/FP/FN via ``compute_f1``; anything else
            macro-averages the per-dialog F1 values.

    Returns:
        Tuple ``(overall, schedule, weather, navigate)`` F1 scores; a macro
        component with a zero count yields 0.0 instead of crashing.
    """
    test_data = []
    with open(eval_fp, 'r') as fr:
        for line in fr:
            ent_idx_sch, ent_idx_wet, ent_idx_nav = [], [], []
            dialog = json.loads(line.strip())
            # Split underscored multiword golds into plain words up front.
            if len(dialog["gold_entity"]) > 0:
                dialog["gold_entity"] = ' '.join(
                    dialog["gold_entity"]).replace('_', ' ').split()
            if dialog["task"] == "schedule":
                ent_idx_sch = dialog["gold_entity"]
            elif dialog["task"] == "weather":
                ent_idx_wet = dialog["gold_entity"]
            elif dialog["task"] == "navigate":
                ent_idx_nav = dialog["gold_entity"]
            ent_index = list(set(ent_idx_sch + ent_idx_wet + ent_idx_nav))
            dialog["ent_index"] = ent_index
            dialog["ent_idx_sch"] = list(set(ent_idx_sch))
            dialog["ent_idx_wet"] = list(set(ent_idx_wet))
            dialog["ent_idx_nav"] = list(set(ent_idx_nav))
            test_data.append(dialog)
    with open(entity_fp, 'r') as fr:
        global_entity = json.load(fr)
    global_entity_list = []
    for key in global_entity.keys():
        if key != 'poi':
            global_entity_list += [
                item.lower().replace(' ', '_') for item in global_entity[key]
            ]
        else:
            # POI entries are dicts; every field value counts as an entity.
            for item in global_entity['poi']:
                global_entity_list += [
                    item[k].lower().replace(' ', '_') for k in item.keys()
                ]
    global_entity_list = list(set(global_entity_list))
    F1_pred, F1_sch_pred, F1_nav_pred, F1_wet_pred = 0, 0, 0, 0
    F1_count, F1_sch_count, F1_nav_count, F1_wet_count = 0, 0, 0, 0
    TP_all, FP_all, FN_all = 0, 0, 0
    TP_sch, FP_sch, FN_sch = 0, 0, 0
    TP_wet, FP_wet, FN_wet = 0, 0, 0
    TP_nav, FP_nav, FN_nav = 0, 0, 0
    for dialog in test_data:
        pred_tokens = dialog["result"].replace('_', ' ').split()
        kb_arrys = dialog["kb"]
        gold_ents = dialog["ent_index"]
        tp, fp, fn, f1, count = compute_prf(gold_ents, pred_tokens,
                                            global_entity_list, kb_arrys)
        TP_all += tp
        FP_all += fp
        FN_all += fn
        F1_pred += f1
        F1_count += count
        gold_sch_ents = dialog["ent_idx_sch"]
        tp, fp, fn, f1, count = compute_prf(gold_sch_ents, pred_tokens,
                                            global_entity_list, kb_arrys)
        TP_sch += tp
        FP_sch += fp
        FN_sch += fn
        F1_sch_pred += f1
        F1_sch_count += count
        gold_wet_ents = dialog["ent_idx_wet"]
        tp, fp, fn, f1, count = compute_prf(gold_wet_ents, pred_tokens,
                                            global_entity_list, kb_arrys)
        TP_wet += tp
        FP_wet += fp
        FN_wet += fn
        F1_wet_pred += f1
        F1_wet_count += count
        gold_nav_ents = dialog["ent_idx_nav"]
        tp, fp, fn, f1, count = compute_prf(gold_nav_ents, pred_tokens,
                                            global_entity_list, kb_arrys)
        TP_nav += tp
        FP_nav += fp
        FN_nav += fn
        F1_nav_pred += f1
        F1_nav_count += count
    if average == "micro":
        F1_score = compute_f1(TP_all, FP_all, FN_all)
        F1_sch_score = compute_f1(TP_sch, FP_sch, FN_sch)
        F1_wet_score = compute_f1(TP_wet, FP_wet, FN_wet)
        F1_nav_score = compute_f1(TP_nav, FP_nav, FN_nav)
    else:
        # Guards: a domain absent from the eval file leaves its count at 0.
        F1_score = F1_pred / float(F1_count) if F1_count else 0.0
        F1_sch_score = (F1_sch_pred / float(F1_sch_count)
                        if F1_sch_count else 0.0)
        F1_wet_score = (F1_wet_pred / float(F1_wet_count)
                        if F1_wet_count else 0.0)
        F1_nav_score = (F1_nav_pred / float(F1_nav_count)
                        if F1_nav_count else 0.0)
    return F1_score, F1_sch_score, F1_wet_score, F1_nav_score
def eval_entity_f1_multiwoz(eval_fp, entity_fp, average="micro"):
    """Compute MultiWOZ entity F1 (overall + per domain), micro- or macro-averaged.

    Args:
        eval_fp: path to a JSONL file; each line holds a dialog dict with
            "result", "kb", "task", and "gold_entity" keys.
        entity_fp: path to the MultiWOZ global-entity JSON file.
        average: "micro" pools TP/FP/FN via ``compute_f1``; anything else
            macro-averages the per-dialog F1 values.

    Returns:
        Tuple ``(overall, restaurant, attraction, hotel)`` F1 scores; a
        macro component with a zero count yields 0.0 instead of crashing.
    """
    test_data = []
    with open(eval_fp, 'r') as fr:
        for line in fr:
            ent_idx_res, ent_idx_att, ent_idx_hotel = [], [], []
            dialog = json.loads(line.strip())
            # Split underscored multiword golds into plain words up front.
            if len(dialog["gold_entity"]) > 0:
                dialog["gold_entity"] = ' '.join(
                    dialog["gold_entity"]).replace('_', ' ').split()
            if dialog["task"] == "restaurant":
                ent_idx_res = dialog["gold_entity"]
            elif dialog["task"] == "attraction":
                ent_idx_att = dialog["gold_entity"]
            elif dialog["task"] == "hotel":
                ent_idx_hotel = dialog["gold_entity"]
            ent_index = list(set(ent_idx_res + ent_idx_att + ent_idx_hotel))
            dialog["ent_index"] = ent_index
            dialog["ent_idx_res"] = list(set(ent_idx_res))
            dialog["ent_idx_att"] = list(set(ent_idx_att))
            dialog["ent_idx_hotel"] = list(set(ent_idx_hotel))
            test_data.append(dialog)
    with open(entity_fp, 'r') as fr:
        global_entity = json.load(fr)
    global_entity_list = []
    for key in global_entity.keys():
        global_entity_list += [
            item.lower().replace(' ', '_') for item in global_entity[key]
        ]
    global_entity_list = list(set(global_entity_list))
    F1_pred, F1_res_pred, F1_att_pred, F1_hotel_pred = 0, 0, 0, 0
    F1_count, F1_res_count, F1_att_count, F1_hotel_count = 0, 0, 0, 0
    TP_all, FP_all, FN_all = 0, 0, 0
    TP_res, FP_res, FN_res = 0, 0, 0
    TP_att, FP_att, FN_att = 0, 0, 0
    TP_hotel, FP_hotel, FN_hotel = 0, 0, 0
    for dialog in test_data:
        pred_tokens = dialog["result"].replace('_', ' ').split()
        kb_arrys = dialog["kb"]
        gold_ents = dialog["ent_index"]
        tp, fp, fn, f1, count = compute_prf(gold_ents, pred_tokens,
                                            global_entity_list, kb_arrys)
        TP_all += tp
        FP_all += fp
        FN_all += fn
        F1_pred += f1
        F1_count += count
        gold_res_ents = dialog["ent_idx_res"]
        tp, fp, fn, f1, count = compute_prf(gold_res_ents, pred_tokens,
                                            global_entity_list, kb_arrys)
        TP_res += tp
        FP_res += fp
        FN_res += fn
        F1_res_pred += f1
        F1_res_count += count
        gold_att_ents = dialog["ent_idx_att"]
        tp, fp, fn, f1, count = compute_prf(gold_att_ents, pred_tokens,
                                            global_entity_list, kb_arrys)
        TP_att += tp
        FP_att += fp
        FN_att += fn
        F1_att_pred += f1
        F1_att_count += count
        gold_hotel_ents = dialog["ent_idx_hotel"]
        tp, fp, fn, f1, count = compute_prf(gold_hotel_ents, pred_tokens,
                                            global_entity_list, kb_arrys)
        TP_hotel += tp
        FP_hotel += fp
        FN_hotel += fn
        F1_hotel_pred += f1
        F1_hotel_count += count
    if average == "micro":
        F1_score = compute_f1(TP_all, FP_all, FN_all)
        F1_res_score = compute_f1(TP_res, FP_res, FN_res)
        F1_att_score = compute_f1(TP_att, FP_att, FN_att)
        F1_hotel_score = compute_f1(TP_hotel, FP_hotel, FN_hotel)
    else:
        # Guards: a domain absent from the eval file leaves its count at 0.
        F1_score = F1_pred / float(F1_count) if F1_count else 0.0
        F1_res_score = (F1_res_pred / float(F1_res_count)
                        if F1_res_count else 0.0)
        F1_att_score = (F1_att_pred / float(F1_att_count)
                        if F1_att_count else 0.0)
        F1_hotel_score = (F1_hotel_pred / float(F1_hotel_count)
                          if F1_hotel_count else 0.0)
    return F1_score, F1_res_score, F1_att_score, F1_hotel_score