Example #1
File: eval.py  Project: siat-nlp/TTOS
def eval_entity_f1_camrest(eval_fp, entity_fp):
    test_data = []
    with open(eval_fp, 'r') as fr:
        for line in fr:
            dialog = json.loads(line.strip())
            test_data.append(dialog)

    print("test data: ", len(test_data))

    with open(entity_fp, 'r') as fr:
        global_entity = json.load(fr)
        global_entity_list = []
        for key in global_entity.keys():
            global_entity_list += [item.lower().replace(' ', '_') for item in global_entity[key]]
    global_entity_list = list(set(global_entity_list))

    F1_pred, F1_count = 0, 0
    for dialog in test_data:
        pred_tokens = dialog["result"].replace('_', ' ').split()
        kb_arrys = dialog["kb"]

        gold_ents = dialog["gold_entity"]
        if len(gold_ents) > 0:
            gold_ents = ' '.join(gold_ents).replace('_', ' ').split()
        single_f1, count = compute_prf(gold_ents, pred_tokens, global_entity_list, kb_arrys)
        F1_pred += single_f1
        F1_count += count

    F1_score = F1_pred / float(F1_count)

    return F1_score
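
Every example on this page delegates the per-response scoring to a compute_prf helper (and assumes `import json` at module level), but its definition is not part of the listing. Below is a minimal sketch of the two-value variant used by the TTOS-style examples (#1, #2, #4, #5), under the usual convention: recall is measured over the gold entities, precision over predicted tokens that are recognizable entities (they occur in the global entity list or in the turn's local KB), and count flags whether the turn has any gold entity at all. The k[0] indexing is an assumption about the KB entry format; the real helper may differ.

def compute_prf(gold, pred, global_entity_list, kb_plain):
    # Sketch only: the actual helper lives alongside these evaluators.
    # Assumes each KB entry is a (subject, relation, object)-style sequence.
    local_kb_word = [k[0] for k in kb_plain]
    TP, FP, FN = 0, 0, 0
    if len(gold) != 0:
        count = 1
        for g in gold:
            if g in pred:
                TP += 1
            else:
                FN += 1
        for p in set(pred):
            if (p in global_entity_list or p in local_kb_word) and p not in gold:
                FP += 1
        precision = TP / float(TP + FP) if (TP + FP) != 0 else 0
        recall = TP / float(TP + FN) if (TP + FN) != 0 else 0
        F1 = 2 * precision * recall / float(precision + recall) \
            if (precision + recall) != 0 else 0
    else:
        F1, count = 0, 0
    return F1, count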
Example #2
def eval_entity_f1_camrest_online(hyps, tasks, gold_entity, kb_word):

    entity_fp = "./data/CamRest/camrest676-entities.json"
    with open(entity_fp, 'r') as fr:
        global_entity = json.load(fr)
        global_entity_list = []
        for key in global_entity.keys():
            global_entity_list += [
                item.lower().replace(' ', '_') for item in global_entity[key]
            ]
    global_entity_list = list(set(global_entity_list))

    test_data = []
    for hyp, tk, ent, kb in zip(hyps, tasks, gold_entity, kb_word):
        dialog = {"result": hyp, "task": tk, "gold_entity": ent, "kb": kb}
        test_data.append(dialog)

    F1_pred, F1_count = 0, 0
    for dialog in test_data:
        # TODO: change the result
        pred_tokens = dialog["result"].split()
        # pred_tokens = dialog["result"].replace('_', ' ').split()
        kb_arrys = dialog["kb"]

        gold_ents = dialog["gold_entity"]
        # if len(gold_ents) > 0:
        #     gold_ents = ' '.join(gold_ents).replace('_', ' ').split()
        single_f1, count = compute_prf(gold_ents, pred_tokens,
                                       global_entity_list, kb_arrys)
        F1_pred += single_f1
        F1_count += count

    F1_score = F1_pred / float(F1_count)

    return F1_score
Example #3
File: eval.py  Project: siat-nlp/DDMN
def eval_entity_f1_camrest(eval_fp, entity_fp, average="micro"):
    test_data = []
    with open(eval_fp, 'r') as fr:
        for line in fr:
            dialog = json.loads(line.strip())
            if len(dialog["gold_entity"]) > 0:
                dialog["gold_entity"] = ' '.join(
                    dialog["gold_entity"]).replace('_', ' ').split()
            test_data.append(dialog)

    with open(entity_fp, 'r') as fr:
        global_entity = json.load(fr)
        global_entity_list = []
        for key in global_entity.keys():
            global_entity_list += [
                item.lower().replace(' ', '_') for item in global_entity[key]
            ]
    global_entity_list = list(set(global_entity_list))

    F1_pred, F1_count = 0, 0
    TP_all, FP_all, FN_all = 0, 0, 0
    for dialog in test_data:
        pred_tokens = dialog["result"].replace('_', ' ').split()
        kb_arrys = dialog["kb"]
        gold_ents = dialog["gold_entity"]
        tp, fp, fn, f1, count = compute_prf(gold_ents, pred_tokens,
                                            global_entity_list, kb_arrys)
        F1_pred += f1
        TP_all += tp
        FP_all += fp
        FN_all += fn
        F1_count += count
    if average == "micro":
        F1_score = compute_f1(TP_all, FP_all, FN_all)
    else:
        F1_score = F1_pred / float(F1_count)

    return F1_score
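
The DDMN-style evaluators (#3, #7, #8) expect a five-value compute_prf that also returns the raw true-positive / false-positive / false-negative counts, plus a compute_f1 helper that turns the pooled counts into the micro-averaged score. compute_f1 is not shown on this page either; under the standard definition it reduces to the sketch below.

def compute_f1(tp, fp, fn):
    # Micro-averaged F1 from pooled counts, guarding against empty denominators.
    precision = tp / float(tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / float(tp + fn) if (tp + fn) > 0 else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)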
Example #4
File: eval.py  Project: siat-nlp/TTOS
def eval_entity_f1_kvr_online(hyps, tasks, gold_entity, kb_word):

    entity_fp = "./data/KVR/kvret_entities.json"

    with open(entity_fp, 'r') as fr:
        global_entity = json.load(fr)
        global_entity_list = []
        for key in global_entity.keys():
            if key != 'poi':
                global_entity_list += [item.lower().replace(' ', '_') for item in global_entity[key]]
            else:
                for item in global_entity['poi']:
                    global_entity_list += [item[k].lower().replace(' ', '_') for k in item.keys()]
    global_entity_list = list(set(global_entity_list))

    test_data = []
    for hyp, tk, ent, kb in zip(hyps, tasks, gold_entity, kb_word):
        dialog = {"result": hyp,
                  "task": tk,
                  "gold_entity": ent,
                  "kb": kb}
        test_data.append(dialog)

    F1_pred, F1_count = 0, 0
    for dialog in test_data:
        pred_tokens = dialog["result"].replace('_', ' ').split()
        kb_arrys = dialog["kb"]

        gold_ents = dialog["gold_entity"]
        if len(gold_ents) > 0:
            gold_ents = ' '.join(gold_ents).replace('_', ' ').split()
        single_f1, count = compute_prf(gold_ents, pred_tokens, global_entity_list, kb_arrys)
        F1_pred += single_f1
        F1_count += count

    F1_score = F1_pred / float(F1_count)

    return F1_score
Example #5
def eval_entity_f1_kvr(eval_fp, entity_fp):
    test_data = []
    with open(eval_fp, 'r') as fr:
        for line in fr:
            ent_idx_sch, ent_idx_wet, ent_idx_nav = [], [], []
            dialog = json.loads(line.strip())
            if dialog["task"] == "schedule":
                ent_idx_sch = dialog["gold_entity"]
            elif dialog["task"] == "weather":
                ent_idx_wet = dialog["gold_entity"]
            elif dialog["task"] == "navigate":
                ent_idx_nav = dialog["gold_entity"]
            ent_index = list(set(ent_idx_sch + ent_idx_wet + ent_idx_nav))
            dialog["ent_index"] = ent_index
            dialog["ent_idx_sch"] = list(set(ent_idx_sch))
            dialog["ent_idx_wet"] = list(set(ent_idx_wet))
            dialog["ent_idx_nav"] = list(set(ent_idx_nav))
            test_data.append(dialog)

    print("test data: ", len(test_data))

    with open(entity_fp, 'r') as fr:
        global_entity = json.load(fr)
        global_entity_list = []
        for key in global_entity.keys():
            if key != 'poi':
                global_entity_list += [
                    item.lower().replace(' ', '_')
                    for item in global_entity[key]
                ]
            else:
                for item in global_entity['poi']:
                    global_entity_list += [
                        item[k].lower().replace(' ', '_') for k in item.keys()
                    ]
    global_entity_list = list(set(global_entity_list))

    F1_pred, F1_sch_pred, F1_nav_pred, F1_wet_pred = 0, 0, 0, 0
    F1_count, F1_sch_count, F1_nav_count, F1_wet_count = 0, 0, 0, 0
    for dialog in test_data:
        # TODO: change the result
        pred_tokens = dialog["result"].split()
        # pred_tokens = dialog["result"].replace('_', ' ').split()
        kb_arrys = dialog["kb"]

        gold_ents = dialog["ent_index"]
        # if len(gold_ents) > 0:
        #     gold_ents = ' '.join(gold_ents).replace('_', ' ').split()
        single_f1, count = compute_prf(gold_ents, pred_tokens,
                                       global_entity_list, kb_arrys)
        F1_pred += single_f1
        F1_count += count

        gold_sch_ents = dialog["ent_idx_sch"]
        # if len(gold_sch_ents) > 0:
        #     gold_sch_ents = ' '.join(gold_sch_ents).replace('_', ' ').split()
        single_f1, count = compute_prf(gold_sch_ents, pred_tokens,
                                       global_entity_list, kb_arrys)
        F1_sch_pred += single_f1
        F1_sch_count += count

        gold_wet_ents = dialog["ent_idx_wet"]
        # if len(gold_wet_ents) > 0:
        #     gold_wet_ents = ' '.join(gold_wet_ents).replace('_', ' ').split()
        single_f1, count = compute_prf(gold_wet_ents, pred_tokens,
                                       global_entity_list, kb_arrys)
        F1_wet_pred += single_f1
        F1_wet_count += count

        gold_nav_ents = dialog["ent_idx_nav"]
        # if len(gold_nav_ents) > 0:
        #     gold_nav_ents = ' '.join(gold_nav_ents).replace('_', ' ').split()
        single_f1, count = compute_prf(gold_nav_ents, pred_tokens,
                                       global_entity_list, kb_arrys)
        F1_nav_pred += single_f1
        F1_nav_count += count

    F1_score = F1_pred / float(F1_count)
    F1_sch_score = F1_sch_pred / float(F1_sch_count)
    F1_wet_score = F1_wet_pred / float(F1_wet_count)
    F1_nav_score = F1_nav_pred / float(F1_nav_count)
    return F1_score, F1_sch_score, F1_wet_score, F1_nav_score
Example #6
File: rewards.py  Project: leehamw/CRMN
def reward_fn1(self, preds, targets, gold_ents, ptr_index, task_label):
    """
    reward_fn1
    General reward
    """
    # parameters
    alpha1 = 1.0
    alpha2 = 0.3

    # acc reward
    '''
    # get the weighted mask
    no_padding_mask = preds.ne(self.padding_idx).float()
    trues = (preds == targets).float()
    if self.padding_idx is not None:
        weights = no_padding_mask
        acc = (weights * trues).sum(dim=1) / weights.sum(dim=1)
    else:
        acc = trues.mean(dim=1)
    '''

    pred_text = self.tgt_field.denumericalize(preds)
    tgt_text = self.tgt_field.denumericalize(targets)
    batch_size = targets.size(0)
    batch_kb_inputs = self.kbs[:batch_size, :, :]
    kb_plain = self.kb_field.denumericalize(batch_kb_inputs)

    result = Pack()
    result.add(pred_text=pred_text, tgt_text=tgt_text, gold_ents=gold_ents, kb_plain=kb_plain)
    result_list = result.flatten()

    # bleu reward
    bleu_score = []
    for res in result_list:
        hyp_toks = res.pred_text.split()
        ref_toks = res.tgt_text.split()
        try:
            bleu_1 = sentence_bleu(references=[ref_toks], hypothesis=hyp_toks,
                                   smoothing_function=SmoothingFunction().method7,
                                   weights=[1, 0, 0, 0])
        except Exception:
            bleu_1 = 0
        try:
            bleu_2 = sentence_bleu(references=[ref_toks], hypothesis=hyp_toks,
                                   smoothing_function=SmoothingFunction().method7,
                                   weights=[0.5, 0.5, 0, 0])
        except Exception:
            bleu_2 = 0
        bleu = (bleu_1 + bleu_2) / 2
        bleu_score.append(bleu)
    bleu_score = torch.tensor(bleu_score, dtype=torch.float)

    # entity f1 reward
    f1_score = []
    report_f1 = []
    for res in result_list:
        if len(res.gold_ents) == 0:
            f1_pred = 1.0
        else:
            # TODO: change the way
            #gold_entity = ' '.join(res.gold_ents).replace('_', ' ').split()
            #pred_sent = res.pred_text.replace('_', ' ')
            gold_entity = res.gold_ents
            pred_sent = res.pred_text
            f1_pred, _ = compute_prf(gold_entity, pred_sent,
                                     global_entity_list=[], kb_plain=res.kb_plain)
            report_f1.append(f1_pred)
        f1_score.append(f1_pred)
    if len(report_f1) == 0:
        report_f1.append(0.0)
    f1_score = torch.tensor(f1_score, dtype=torch.float)
    report_f1 = torch.tensor(report_f1, dtype=torch.float)

    if self.use_gpu:
        bleu_score = bleu_score.cuda()
        f1_score = f1_score.cuda()
        report_f1 = report_f1.cuda()

    # compound reward
    #reward = alpha1 * bleu_score.unsqueeze(-1) + alpha2 * f1_score.unsqueeze(-1)
    reward = alpha1 * bleu_score.unsqueeze(-1)

    return reward, bleu_score, report_f1
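
The BLEU part of the reward in Example #6 comes from NLTK's sentence_bleu with smoothing method 7 (the except blocks guard against the errors NLTK can raise on very short hypotheses); the method also assumes torch and the NLTK BLEU utilities are imported at module level. A standalone illustration of that call, with made-up token lists:

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# Hypothetical hypothesis/reference pair, just to exercise the call.
hyp_toks = "the curry garden serves indian food".split()
ref_toks = "the curry garden serves cheap indian food".split()

smooth = SmoothingFunction().method7
bleu_1 = sentence_bleu(references=[ref_toks], hypothesis=hyp_toks,
                       smoothing_function=smooth, weights=[1, 0, 0, 0])
bleu_2 = sentence_bleu(references=[ref_toks], hypothesis=hyp_toks,
                       smoothing_function=smooth, weights=[0.5, 0.5, 0, 0])
reward = (bleu_1 + bleu_2) / 2  # same averaging as reward_fn1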
Example #7
File: eval.py  Project: siat-nlp/DDMN
def eval_entity_f1_kvr(eval_fp, entity_fp, average="micro"):
    test_data = []
    with open(eval_fp, 'r') as fr:
        for line in fr:
            ent_idx_sch, ent_idx_wet, ent_idx_nav = [], [], []
            dialog = json.loads(line.strip())
            if len(dialog["gold_entity"]) > 0:
                dialog["gold_entity"] = ' '.join(
                    dialog["gold_entity"]).replace('_', ' ').split()
            if dialog["task"] == "schedule":
                ent_idx_sch = dialog["gold_entity"]
            elif dialog["task"] == "weather":
                ent_idx_wet = dialog["gold_entity"]
            elif dialog["task"] == "navigate":
                ent_idx_nav = dialog["gold_entity"]
            ent_index = list(set(ent_idx_sch + ent_idx_wet + ent_idx_nav))
            dialog["ent_index"] = ent_index
            dialog["ent_idx_sch"] = list(set(ent_idx_sch))
            dialog["ent_idx_wet"] = list(set(ent_idx_wet))
            dialog["ent_idx_nav"] = list(set(ent_idx_nav))
            test_data.append(dialog)

    with open(entity_fp, 'r') as fr:
        global_entity = json.load(fr)
        global_entity_list = []
        for key in global_entity.keys():
            if key != 'poi':
                global_entity_list += [
                    item.lower().replace(' ', '_')
                    for item in global_entity[key]
                ]
            else:
                for item in global_entity['poi']:
                    global_entity_list += [
                        item[k].lower().replace(' ', '_') for k in item.keys()
                    ]
    global_entity_list = list(set(global_entity_list))

    F1_pred, F1_sch_pred, F1_nav_pred, F1_wet_pred = 0, 0, 0, 0
    F1_count, F1_sch_count, F1_nav_count, F1_wet_count = 0, 0, 0, 0
    TP_all, FP_all, FN_all = 0, 0, 0
    TP_sch, FP_sch, FN_sch = 0, 0, 0
    TP_wet, FP_wet, FN_wet = 0, 0, 0
    TP_nav, FP_nav, FN_nav = 0, 0, 0

    for dialog in test_data:
        pred_tokens = dialog["result"].replace('_', ' ').split()
        kb_arrys = dialog["kb"]

        gold_ents = dialog["ent_index"]
        tp, fp, fn, f1, count = compute_prf(gold_ents, pred_tokens,
                                            global_entity_list, kb_arrys)
        TP_all += tp
        FP_all += fp
        FN_all += fn
        F1_pred += f1
        F1_count += count

        gold_sch_ents = dialog["ent_idx_sch"]
        tp, fp, fn, f1, count = compute_prf(gold_sch_ents, pred_tokens,
                                            global_entity_list, kb_arrys)
        TP_sch += tp
        FP_sch += fp
        FN_sch += fn
        F1_sch_pred += f1
        F1_sch_count += count

        gold_wet_ents = dialog["ent_idx_wet"]
        tp, fp, fn, f1, count = compute_prf(gold_wet_ents, pred_tokens,
                                            global_entity_list, kb_arrys)
        TP_wet += tp
        FP_wet += fp
        FN_wet += fn
        F1_wet_pred += f1
        F1_wet_count += count

        gold_nav_ents = dialog["ent_idx_nav"]
        tp, fp, fn, f1, count = compute_prf(gold_nav_ents, pred_tokens,
                                            global_entity_list, kb_arrys)
        TP_nav += tp
        FP_nav += fp
        FN_nav += fn
        F1_nav_pred += f1
        F1_nav_count += count

    if average == "micro":
        F1_score = compute_f1(TP_all, FP_all, FN_all)
        F1_sch_score = compute_f1(TP_sch, FP_sch, FN_sch)
        F1_wet_score = compute_f1(TP_wet, FP_wet, FN_wet)
        F1_nav_score = compute_f1(TP_nav, FP_nav, FN_nav)
    else:
        F1_score = F1_pred / float(F1_count)
        F1_sch_score = F1_sch_pred / float(F1_sch_count)
        F1_wet_score = F1_wet_pred / float(F1_wet_count)
        F1_nav_score = F1_nav_pred / float(F1_nav_count)

    return F1_score, F1_sch_score, F1_wet_score, F1_nav_score
Example #8
File: eval.py  Project: siat-nlp/DDMN
def eval_entity_f1_multiwoz(eval_fp, entity_fp, average="micro"):
    test_data = []
    with open(eval_fp, 'r') as fr:
        for line in fr:
            ent_idx_res, ent_idx_att, ent_idx_hotel = [], [], []
            dialog = json.loads(line.strip())
            if len(dialog["gold_entity"]) > 0:
                dialog["gold_entity"] = ' '.join(
                    dialog["gold_entity"]).replace('_', ' ').split()
            if dialog["task"] == "restaurant":
                ent_idx_res = dialog["gold_entity"]
            elif dialog["task"] == "attraction":
                ent_idx_att = dialog["gold_entity"]
            elif dialog["task"] == "hotel":
                ent_idx_hotel = dialog["gold_entity"]
            ent_index = list(set(ent_idx_res + ent_idx_att + ent_idx_hotel))
            dialog["ent_index"] = ent_index
            dialog["ent_idx_res"] = list(set(ent_idx_res))
            dialog["ent_idx_att"] = list(set(ent_idx_att))
            dialog["ent_idx_hotel"] = list(set(ent_idx_hotel))
            test_data.append(dialog)

    with open(entity_fp, 'r') as fr:
        global_entity = json.load(fr)
        global_entity_list = []
        for key in global_entity.keys():
            global_entity_list += [
                item.lower().replace(' ', '_') for item in global_entity[key]
            ]
    global_entity_list = list(set(global_entity_list))

    F1_pred, F1_res_pred, F1_att_pred, F1_hotel_pred = 0, 0, 0, 0
    F1_count, F1_res_count, F1_att_count, F1_hotel_count = 0, 0, 0, 0
    TP_all, FP_all, FN_all = 0, 0, 0
    TP_res, FP_res, FN_res = 0, 0, 0
    TP_att, FP_att, FN_att = 0, 0, 0
    TP_hotel, FP_hotel, FN_hotel = 0, 0, 0

    for dialog in test_data:
        pred_tokens = dialog["result"].replace('_', ' ').split()
        kb_arrys = dialog["kb"]

        gold_ents = dialog["ent_index"]
        tp, fp, fn, f1, count = compute_prf(gold_ents, pred_tokens,
                                            global_entity_list, kb_arrys)
        TP_all += tp
        FP_all += fp
        FN_all += fn
        F1_pred += f1
        F1_count += count

        gold_res_ents = dialog["ent_idx_res"]
        tp, fp, fn, f1, count = compute_prf(gold_res_ents, pred_tokens,
                                            global_entity_list, kb_arrys)
        TP_res += tp
        FP_res += fp
        FN_res += fn
        F1_res_pred += f1
        F1_res_count += count

        gold_att_ents = dialog["ent_idx_att"]
        tp, fp, fn, f1, count = compute_prf(gold_att_ents, pred_tokens,
                                            global_entity_list, kb_arrys)
        TP_att += tp
        FP_att += fp
        FN_att += fn
        F1_att_pred += f1
        F1_att_count += count

        gold_hotel_ents = dialog["ent_idx_hotel"]
        tp, fp, fn, f1, count = compute_prf(gold_hotel_ents, pred_tokens,
                                            global_entity_list, kb_arrys)
        TP_hotel += tp
        FP_hotel += fp
        FN_hotel += fn
        F1_hotel_pred += f1
        F1_hotel_count += count

    if average == "micro":
        F1_score = compute_f1(TP_all, FP_all, FN_all)
        F1_res_score = compute_f1(TP_res, FP_res, FN_res)
        F1_att_score = compute_f1(TP_att, FP_att, FN_att)
        F1_hotel_score = compute_f1(TP_hotel, FP_hotel, FN_hotel)
    else:
        F1_score = F1_pred / float(F1_count)
        F1_res_score = F1_res_pred / float(F1_res_count)
        F1_att_score = F1_att_pred / float(F1_att_count)
        F1_hotel_score = F1_hotel_pred / float(F1_hotel_count)

    return F1_score, F1_res_score, F1_att_score, F1_hotel_score
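
Assuming model outputs have been dumped one JSON object per line with the fields these loaders expect ("result", "task", "gold_entity", "kb"), the evaluators can be driven as in the sketch below. The paths are placeholders, and the call targets the three-argument DDMN variant of eval_entity_f1_kvr from Example #7.

# Placeholder paths; point these at your own result dump and entity file.
eval_fp = "./outputs/kvr_test_results.json"
entity_fp = "./data/KVR/kvret_entities.json"

f1, f1_sch, f1_wet, f1_nav = eval_entity_f1_kvr(eval_fp, entity_fp, average="micro")
print("Entity F1: %.4f | schedule %.4f | weather %.4f | navigate %.4f"
      % (f1, f1_sch, f1_wet, f1_nav))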