Пример #1
0
    def evaluate(self, dev, avg_best, BLEU=False):
        """Decode ``dev``, report exact-match accuracy / WER, and checkpoint on improvement.

        Each batch is decoded with ``self.evaluate_batch(len(data_dev[1]),
        data_dev[0], data_dev[1], data_dev[2])``. Every decoded column is
        truncated at '<EOS>' and compared (whitespace-stripped) with the gold
        response ``data_dev[7][i]``.

        Args:
            dev: sized iterable of batches; ``len(data_dev[1])`` is used as the
                batch size.
            avg_best: best score so far; the model is saved when the new score
                is >= this value.
            BLEU: if true, select on (and return) the BLEU score instead of
                accuracy.

        Returns:
            The BLEU score from ``moses_multi_bleu`` when ``BLEU`` is true,
            otherwise the mean over batches of per-batch exact-match accuracy.
        """
        logging.info("STARTING EVALUATION")
        acc_avg = 0.0
        wer_avg = 0.0
        ref = []
        hyp = []
        pbar = tqdm(enumerate(dev), total=len(dev))
        for j, data_dev in pbar:
            batch_size = len(data_dev[1])
            words = self.evaluate_batch(batch_size, data_dev[0], data_dev[1], data_dev[2])
            acc = 0
            w = 0
            for i, row in enumerate(np.transpose(words)):
                tokens = []
                for e in row:
                    if e == '<EOS>':
                        break
                    tokens.append(e)
                pred = ' '.join(tokens).strip()
                correct = data_dev[7][i].strip()

                if correct == pred:
                    acc += 1
                w += wer(correct, pred)
                ref.append(correct)
                hyp.append(pred)

            acc_avg += acc / float(batch_size)
            wer_avg += w / float(batch_size)
            # Show the running mean over the batches processed so far.
            seen = float(j + 1)
            pbar.set_description("R:{:.4f},W:{:.4f}".format(acc_avg / seen,
                                                            wer_avg / seen))

        if BLEU:
            bleu_score = moses_multi_bleu(np.array(hyp), np.array(ref), lowercase=True)
            logging.info("BLEU SCORE:" + str(bleu_score))

            if bleu_score >= avg_best:
                self.save_model(str(self.name) + str(bleu_score))
                logging.info("MODEL SAVED")
            return bleu_score

        acc_avg = acc_avg / float(len(dev))
        if acc_avg >= avg_best:
            self.save_model(str(self.name) + str(acc_avg))
            logging.info("MODEL SAVED")
        return acc_avg
Пример #2
0
    def evaluate(self, out_f, dev, avg_best, BLEU=False):
        """Decode ``dev``, log and write the BLEU score, and checkpoint on improvement.

        Each batch is decoded with ``self.evaluate_batch`` using batch fields
        0-5. Every decoded column is truncated at '<EOS>' and compared
        (whitespace-stripped) with the gold response
        ``" ".join(data_dev[6][i])``.

        Args:
            out_f: writable text file. It receives one ``gold===prediction``
                line per example, a separator line, and the BLEU score.
            dev: sized iterable of batches; ``len(data_dev[1])`` is the batch size.
            avg_best: best BLEU so far; the model is saved when the new score
                is >= this value.
            BLEU: not read by this method; BLEU is always the selection metric.

        Returns:
            The BLEU score returned by ``moses_multi_bleu``.
        """
        logging.info("STARTING EVALUATION")
        acc_avg = 0.0
        wer_avg = 0.0
        ref = []
        hyp = []
        pbar = tqdm(enumerate(dev), total=len(dev))
        for j, data_dev in pbar:
            batch_size = len(data_dev[1])
            words = self.evaluate_batch(batch_size=batch_size,
                                        input_batches=data_dev[0],
                                        input_lengths=data_dev[1],
                                        target_batches=data_dev[2],
                                        target_lengths=data_dev[3],
                                        target_index=data_dev[4],
                                        src_plain=data_dev[5])
            acc = 0
            w = 0
            for i, row in enumerate(np.transpose(words)):
                tokens = []
                for e in row:
                    if e == '<EOS>':
                        break
                    tokens.append(e)
                pred = ' '.join(tokens).strip()
                # Compare against the plain gold string, not gold indices:
                # words copied from the input can be UNK in the gold indices.
                correct = " ".join(data_dev[6][i]).strip()
                if correct == pred:
                    acc += 1
                w += wer(correct, pred)
                ref.append(correct)
                hyp.append(pred)

            acc_avg += acc / float(batch_size)
            wer_avg += w / float(batch_size)
            # Show the running mean over the batches processed so far.
            seen = float(j + 1)
            pbar.set_description("R:{:.4f},W:{:.4f}".format(acc_avg / seen,
                                                            wer_avg / seen))

        bleu_score = moses_multi_bleu(np.array(hyp),
                                      np.array(ref),
                                      lowercase=True)
        # One "gold===prediction" line per example, in evaluation order.
        out_f.writelines(gold + "===" + pred + "\n" for gold, pred in zip(ref, hyp))
        out_f.write("---------" + "\n")
        logging.info("BLEU SCORE:" + str(bleu_score))
        out_f.write("bleu_score" + str(bleu_score) + "\n")

        if bleu_score >= avg_best:
            self.save_model(str(self.name) + str(bleu_score))
            logging.info("MODEL SAVED")
        return bleu_score
Пример #3
0
def eval_bleu(data):
    """Score each record's "result" against its "target" with ``moses_multi_bleu``.

    Args:
        data: iterable of dict-like records, each with a "result" (hypothesis)
            and a "target" (reference) entry.

    Returns:
        The value returned by ``moses_multi_bleu`` (called with
        ``lowercase=True``) on the hypothesis and reference arrays.
    """
    pairs = [(dialog["result"], dialog["target"]) for dialog in data]
    hyps = [pred for pred, _ in pairs]
    refs = [gold for _, gold in pairs]
    assert len(hyps) == len(refs)
    return moses_multi_bleu(np.array(hyps), np.array(refs), lowercase=True)
Пример #4
0
    def evaluate(self, dev, matric_best, early_stop=None):
        """Decode ``dev`` in eval mode, print metrics, and checkpoint on improvement.

        Each decoded response is truncated at 'EOS' and compared with the
        whitespace-stripped ``data_dev['response_plain']``. Printed metrics:
          - per-response accuracy, always;
          - for ``args['dataset'] == 'kvr'``: entity F1 (overall and CAL/WET/NAV)
            and BLEU;
          - otherwise: dialog accuracy, where a dialog (``data_dev['ID']``)
            counts only if every one of its responses matches exactly.

        The encoder, extKnow and decoder are put into ``train(False)`` while
        decoding and are restored to ``train(True)`` afterwards, even if
        decoding raises.

        Args:
            dev: sized iterable of dict batches.
            matric_best: best value of the early-stop metric so far; the model
                is saved when the current value is >= this value.
            early_stop: 'BLEU', 'ENTF1' (kvr only), or anything else for accuracy.

        Returns:
            The BLEU score, entity F1, or per-response accuracy, depending on
            ``early_stop``.

        Raises:
            ValueError: if ``early_stop == 'ENTF1'`` but the dataset is not 'kvr'
                (no entity F1 is computed then).
        """
        print("STARTING EVALUATION")
        is_kvr = args['dataset'] == 'kvr'

        def _until_eos(tokens):
            # Space-join decoded tokens up to (not including) the first 'EOS'.
            words = []
            for tok in tokens:
                if tok == 'EOS':
                    break
                words.append(tok)
            return ' '.join(words).strip()

        def _safe_div(num, den):
            # Ratio that yields 0 instead of raising when nothing was counted.
            return num / float(den) if den else 0

        global_entity_list = []
        if is_kvr:
            with open('data/KVR/kvret_entities.json') as f:
                global_entity = json.load(f)
            for key, values in global_entity.items():
                if key != 'poi':
                    global_entity_list += [item.lower().replace(' ', '_') for item in values]
                else:
                    # 'poi' entries are dicts; every field value is an entity.
                    for item in values:
                        global_entity_list += [item[k].lower().replace(' ', '_') for k in item.keys()]
            global_entity_list = list(set(global_entity_list))

        # (score name, batch key) pairs for kvr entity F1; 'all' is the overall score.
        f1_keys = (('all', 'ent_index'), ('cal', 'ent_idx_cal'),
                   ('nav', 'ent_idx_nav'), ('wet', 'ent_idx_wet'))
        f1_sum = {name: 0 for name, _ in f1_keys}
        f1_count = {name: 0 for name, _ in f1_keys}

        ref, hyp = [], []
        acc, total = 0, 0
        dialog_acc_dict = {}

        # Set to not-training mode to disable dropout
        self.encoder.train(False)
        self.extKnow.train(False)
        self.decoder.train(False)
        try:
            pbar = tqdm(enumerate(dev), total=len(dev))
            for _, data_dev in pbar:
                # Encode and Decode
                _, _, decoded_fine, decoded_coarse, _ = self.encode_and_decode(
                    data_dev, self.max_resp_len, False, True)
                decoded_coarse = np.transpose(decoded_coarse)
                decoded_fine = np.transpose(decoded_fine)
                for bi, row in enumerate(decoded_fine):
                    pred_sent = _until_eos(row)
                    pred_sent_coarse = _until_eos(decoded_coarse[bi])
                    gold_sent = data_dev['response_plain'][bi].strip()
                    ref.append(gold_sent)
                    hyp.append(pred_sent)
                    is_match = gold_sent == pred_sent

                    if is_kvr:
                        # compute F1 SCORE, overall and per domain
                        for name, key in f1_keys:
                            single_f1, count = self.compute_prf(
                                data_dev[key][bi], pred_sent.split(),
                                global_entity_list, data_dev['kb_arr_plain'][bi])
                            f1_sum[name] += single_f1
                            f1_count[name] += count
                    else:
                        # compute Dialogue Accuracy Score
                        dialog_acc_dict.setdefault(data_dev['ID'][bi], []).append(
                            1 if is_match else 0)

                    # compute Per-response Accuracy Score
                    total += 1
                    if is_match:
                        acc += 1

                    if args['genSample']:
                        self.print_examples(bi, data_dev, pred_sent, pred_sent_coarse, gold_sent)
        finally:
            # Set back to training mode, even if decoding failed.
            self.encoder.train(True)
            self.extKnow.train(True)
            self.decoder.train(True)

        bleu_score = moses_multi_bleu(np.array(hyp), np.array(ref), lowercase=True)
        acc_score = acc / float(total)
        print("ACC SCORE:\t"+str(acc_score))

        F1_score = None
        if is_kvr:
            F1_score = _safe_div(f1_sum['all'], f1_count['all'])
            print("F1 SCORE:\t{}".format(F1_score))
            print("\tCAL F1:\t{}".format(_safe_div(f1_sum['cal'], f1_count['cal'])))
            print("\tWET F1:\t{}".format(_safe_div(f1_sum['wet'], f1_count['wet'])))
            print("\tNAV F1:\t{}".format(_safe_div(f1_sum['nav'], f1_count['nav'])))
            print("BLEU SCORE:\t"+str(bleu_score))
        else:
            # A dialog is correct only if all of its responses matched.
            dia_acc = sum(1 for results in dialog_acc_dict.values() if all(results))
            print("Dialog Accuracy:\t"+str(_safe_div(dia_acc, len(dialog_acc_dict))))

        if early_stop == 'BLEU':
            if bleu_score >= matric_best:
                self.save_model('BLEU-'+str(bleu_score))
                print("MODEL SAVED")
            return bleu_score
        elif early_stop == 'ENTF1':
            if F1_score is None:
                raise ValueError("early_stop='ENTF1' requires args['dataset'] == 'kvr'")
            if F1_score >= matric_best:
                self.save_model('ENTF1-{:.4f}'.format(F1_score))
                print("MODEL SAVED")
            return F1_score
        else:
            if acc_score >= matric_best:
                self.save_model('ACC-{:.4f}'.format(acc_score))
                print("MODEL SAVED")
            return acc_score
Пример #5
0
    def evaluate(self, dev, avg_best, BLEU=False):
        """Decode ``dev``; report accuracy, WER and entity micro-F1; checkpoint on improvement.

        Each batch is decoded with ``self.evaluate_batch(len(data_dev[1]),
        data_dev[0], data_dev[1], data_dev[2])``. Every decoded column is
        truncated at '<EOS>' and compared (whitespace-stripped) with the gold
        response ``data_dev[7][i]``.

        Batches with more than 10 fields also carry entity annotations in
        columns 8-11 (overall, CAL, NAV, WET). These are scored with
        ``computeF1`` and logged as micro F1 when at least one such batch was
        seen.

        Args:
            dev: sized iterable of batches; ``len(data_dev[1])`` is the batch size.
            avg_best: best score so far; the model is saved when the new score
                is >= this value.
            BLEU: if true, select on (and return) the BLEU score instead of
                accuracy.

        Returns:
            The BLEU score from ``moses_multi_bleu`` when ``BLEU`` is true,
            otherwise the mean over batches of per-batch exact-match accuracy.
        """
        logging.info("STARTING EVALUATION")
        acc_avg = 0.0
        wer_avg = 0.0
        ref = []
        hyp = []
        # Entity columns of a batch: 8 = overall, 9 = CAL, 10 = NAV, 11 = WET.
        entity_cols = (8, 9, 10, 11)
        f1_true = {col: [] for col in entity_cols}
        f1_pred = {col: [] for col in entity_cols}
        has_entities = False
        pbar = tqdm(enumerate(dev), total=len(dev))
        for j, data_dev in pbar:
            batch_size = len(data_dev[1])
            words = self.evaluate_batch(batch_size, data_dev[0],
                                        data_dev[1], data_dev[2])
            with_entities = len(data_dev) > 10
            has_entities = has_entities or with_entities
            acc = 0
            w = 0
            for i, row in enumerate(np.transpose(words)):
                tokens = []
                for e in row:
                    if e == '<EOS>':
                        break
                    tokens.append(e)
                pred = ' '.join(tokens).strip()
                correct = data_dev[7][i].strip()

                if with_entities:
                    # compute F1 SCORE
                    for col in entity_cols:
                        col_true, col_pred = computeF1(data_dev[col][i], pred, correct)
                        f1_true[col] += col_true
                        f1_pred[col] += col_pred

                if correct == pred:
                    acc += 1
                w += wer(correct, pred)
                ref.append(correct)
                hyp.append(pred)

            acc_avg += acc / float(batch_size)
            wer_avg += w / float(batch_size)
            # Show the running mean over the batches processed so far.
            seen = float(j + 1)
            pbar.set_description("R:{:.4f},W:{:.4f}".format(
                acc_avg / seen, wer_avg / seen))

        if has_entities:
            for label, col in (("F1 SCORE:\t", 8), ("F1 CAL:\t", 9),
                               ("F1 WET:\t", 11), ("F1 NAV:\t", 10)):
                logging.info(label + str(
                    f1_score(f1_true[col], f1_pred[col], average='micro')))

        if BLEU:
            bleu_score = moses_multi_bleu(np.array(hyp),
                                          np.array(ref),
                                          lowercase=True)
            logging.info("BLEU SCORE:" + str(bleu_score))

            if bleu_score >= avg_best:
                self.save_model(str(self.name) + str(bleu_score))
                logging.info("MODEL SAVED")
            return bleu_score

        acc_avg = acc_avg / float(len(dev))
        if acc_avg >= avg_best:
            self.save_model(str(self.name) + str(acc_avg))
            logging.info("MODEL SAVED")
        return acc_avg
Пример #6
0
    def evaluate(self, dev, matric_best, output=False, early_stop=None):
        """Decode ``dev`` in eval mode, report metrics, and checkpoint on improvement.

        Each decoded response is truncated at 'EOS' and compared with the
        whitespace-stripped ``data_dev['response_plain']``. Reported metrics:
          - per-response accuracy;
          - BLEU;
          - entity F1 for the dataset chosen by ``args['dataset']`` ('kvr',
            'woz' or 'cam'), overall and (kvr/woz) per domain.

        Entity F1 comes in two forms:
          - macro: mean of the per-response F1 returned by ``self.compute_prf``;
          - micro: ``self.compute_F1`` applied to precision and recall from the
            summed TP/FP/FN.

        The encoder, extKnow and decoder are put into ``train(False)`` while
        decoding and are restored to ``train(True)`` afterwards, even if
        decoding raises.

        Args:
            dev: sized iterable of dict batches.
            matric_best: best value of the early-stop metric so far; the model
                is saved when the current value is >= this value.
            output: if true, the scores are also written to ``args['output']``.
            early_stop: 'BLEU', 'ENTF1', or anything else for accuracy.

        Returns:
            The BLEU score, micro entity F1, or per-response accuracy,
            depending on ``early_stop``.

        Raises:
            ValueError: if ``args['dataset']`` is not 'kvr', 'woz' or 'cam'.
        """
        print("STARTING EVALUATION")
        dataset = args['dataset']
        entity_paths = {
            'kvr': '../dataset/KVR/kvret_entities.json',
            'woz': '../dataset/MULTIWOZ2.1/global_entities.json',
            'cam': '../dataset/CamRest676/CamRest676_entities.json',
        }
        if dataset not in entity_paths:
            raise ValueError("unsupported dataset: {!r}".format(dataset))

        # (report label, batch key) for per-domain entity scores, in evaluation
        # order. The overall score ('all') always uses data_dev['ent_index'].
        domain_keys = {
            'kvr': (('sche', 'ent_idx_cal'), ('nav', 'ent_idx_nav'),
                    ('wea', 'ent_idx_wet')),
            'woz': (('restaurant', 'ent_idx_restaurant'),
                    ('attraction', 'ent_idx_attraction'),
                    ('hotel', 'ent_idx_hotel')),
            'cam': (),
        }[dataset]
        # Order in which per-domain scores are reported.
        report_order = {
            'kvr': ('sche', 'wea', 'nav'),
            'woz': ('restaurant', 'attraction', 'hotel'),
            'cam': (),
        }[dataset]
        entity_keys = (('all', 'ent_index'),) + domain_keys
        # Per label:
        #   f1, count   -> summed per-response F1 and its count (macro F1)
        #   tp, fp, fn  -> summed counts (micro F1)
        stats = {label: {'f1': 0, 'count': 0, 'tp': 0, 'fp': 0, 'fn': 0}
                 for label, _ in entity_keys}

        def _safe_div(num, den):
            # Ratio that yields 0 instead of raising when nothing was counted.
            return num / float(den) if den != 0 else 0

        def _until_eos(tokens):
            # Space-join decoded tokens up to (not including) the first 'EOS'.
            words = []
            for tok in tokens:
                if tok == 'EOS':
                    break
                words.append(tok)
            return ' '.join(words).strip()

        def _load_entities(path):
            # Returns (deduplicated entity list, entity -> type map).
            # Entities are lower-cased with spaces replaced by '_'. 'poi'
            # entries are dicts whose field values are entities typed by
            # their field name.
            with open(path) as f:
                global_entity = json.load(f)
            entity_type = {}
            entity_list = []
            for key in global_entity.keys():
                if key != 'poi':
                    entity_arr = [item.lower().replace(' ', '_')
                                  for item in global_entity[key]]
                    entity_list += entity_arr
                    for entity in entity_arr:
                        entity_type[entity] = key
                else:
                    for item in global_entity['poi']:
                        for field in item:
                            entity = item[field].lower().replace(' ', '_')
                            entity_list.append(entity)
                            entity_type[entity] = field
            return list(set(entity_list)), entity_type

        global_entity_list, global_entity_type = _load_entities(entity_paths[dataset])

        ref, hyp = [], []
        acc, total = 0, 0

        # Set to not-training mode to disable dropout
        self.encoder.train(False)
        self.extKnow.train(False)
        self.decoder.train(False)
        try:
            pbar = tqdm(enumerate(dev), total=len(dev))
            for _, data_dev in pbar:
                # Encode and Decode
                _, _, decoded_fine, decoded_coarse, _, _, _, _, _ = self.encode_and_decode(
                    data_dev, self.max_resp_len, False, True, global_entity_type)
                decoded_coarse = np.transpose(decoded_coarse)
                decoded_fine = np.transpose(decoded_fine)
                for bi, row in enumerate(decoded_fine):
                    pred_sent = _until_eos(row)
                    pred_sent_coarse = _until_eos(decoded_coarse[bi])
                    gold_sent = data_dev['response_plain'][bi].strip()
                    ref.append(gold_sent)
                    hyp.append(pred_sent)

                    # compute F1 SCORE, overall and per domain
                    for label, key in entity_keys:
                        tp, fp, fn, single_f1, count = self.compute_prf(
                            data_dev[key][bi], pred_sent.split(),
                            global_entity_list, data_dev['kb_arr_plain'][bi])
                        s = stats[label]
                        s['f1'] += single_f1
                        s['count'] += count
                        s['tp'] += tp
                        s['fp'] += fp
                        s['fn'] += fn

                    # compute Per-response Accuracy Score
                    total += 1
                    if gold_sent == pred_sent:
                        acc += 1

                    if args['genSample']:
                        self.print_examples(bi, data_dev, pred_sent,
                                            pred_sent_coarse, gold_sent)
        finally:
            # Set back to training mode, even if decoding failed.
            self.encoder.train(True)
            self.extKnow.train(True)
            self.decoder.train(True)

        bleu_score = moses_multi_bleu(np.array(hyp),
                                      np.array(ref),
                                      lowercase=True)
        acc_score = acc / float(total)
        print("ACC SCORE:\t" + str(acc_score))

        macro = {label: _safe_div(s['f1'], s['count']) for label, s in stats.items()}
        micro = {
            label: self.compute_F1(_safe_div(s['tp'], s['tp'] + s['fp']),
                                   _safe_div(s['tp'], s['tp'] + s['fn']))
            for label, s in stats.items()
        }
        F1_score = micro['all']

        print("BLEU SCORE:\t" + str(bleu_score))
        print("F1-macro SCORE:\t{}".format(macro['all']))
        for label in report_order:
            print("F1-macro-{} SCORE:\t{}".format(label, macro[label]))
        print("F1-micro SCORE:\t{}".format(F1_score))
        for label in report_order:
            print("F1-micro-{} SCORE:\t{}".format(label, micro[label]))
        if dataset == 'kvr':
            # Short summary repeated at the end of the KVR report.
            print("BLEU SCORE:" + str(bleu_score))
            print("F1-macro SCORE:\t{}".format(macro['all']))
            print("F1-micro SCORE:\t{}".format(F1_score))

        if output:
            print('Test Finish!')
            with open(args['output'], 'w') as f:
                print("ACC SCORE:\t" + str(acc_score), file=f)
                print("BLEU SCORE:\t" + str(bleu_score), file=f)
                print("F1-macro SCORE:\t{}".format(macro['all']), file=f)
                print("F1-micro SCORE:\t{}".format(F1_score), file=f)
                for label in report_order:
                    print("F1-macro-{} SCORE:\t{}".format(label, macro[label]),
                          file=f)
                for label in report_order:
                    print("F1-micro-{} SCORE:\t{}".format(label, micro[label]),
                          file=f)

        if early_stop == 'BLEU':
            if bleu_score >= matric_best:
                self.save_model('BLEU-' + str(bleu_score) + 'F1-' +
                                str(F1_score))
                print("MODEL SAVED")
            return bleu_score
        elif early_stop == 'ENTF1':
            if F1_score >= matric_best:
                self.save_model('ENTF1')
                print("MODEL SAVED")
            return F1_score
        else:
            if acc_score >= matric_best:
                self.save_model('ACC-{:.4f}'.format(acc_score))
                print("MODEL SAVED")
            return acc_score
Пример #7
0
    def evaluate(self, dev, avg_best, BLEU=False):
        """Evaluate the model on ``dev`` and save it when the metric improves.

        Every batch is decoded with ``self.evaluate_batch``; each prediction
        (tokens up to the first ``'<EOS>'``) is compared with the reference
        sentence in ``data_dev[7]``. Per-turn exact-match accuracy, WER and
        BLEU are always computed; entity F1 (via ``self.compute_prf``) is
        logged for KVR (overall plus calendar / weather / navigation) and for
        bAbI task 6, and per-dialogue accuracy is logged for bAbI.

        Args:
            dev: sized iterable of batches indexed positionally
                (``data_dev[0]`` ... ``data_dev[14]``).
            avg_best: best metric so far; the model is saved when the current
                metric is ``>=`` this value.
            BLEU: select on BLEU if true, otherwise on mean per-turn accuracy.

        Returns:
            The BLEU score if ``BLEU`` is true, else the per-turn accuracy
            averaged over batches (0.0 for an empty ``dev``).
        """
        logging.info("STARTING EVALUATION")
        dataset = args['dataset']
        # Short-circuit keeps args["task"] unread for non-bAbI datasets.
        is_babi_task6 = dataset == 'babi' and int(args["task"]) == 6

        def ratio(num, den):
            # 0.0 instead of ZeroDivisionError when nothing was counted
            # (empty dev, a domain with no gold entities, no dialogues).
            return num / float(den) if den else 0.0

        def decode(row):
            # Join one sample's predicted tokens up to the first '<EOS>'.
            tokens = []
            for token in row:
                if token == '<EOS>':
                    break
                tokens.append(token)
            return ' '.join(tokens).strip()

        if dataset == 'kvr':
            with open('data/KVR/kvret_entities.json') as f:
                global_entity = json.load(f)
            entities = set()
            for key, values in global_entity.items():
                if key != 'poi':
                    entities.update(item.lower().replace(' ', '_') for item in values)
                else:
                    # Each POI entry is a dict; every field value is an entity.
                    for poi in values:
                        entities.update(poi[k].lower().replace(' ', '_') for k in poi)
            global_entity_list = list(entities)
        else:
            task = int(args["task"])
            if task != 6:
                kb_file = 'data/dialog-bAbI-tasks/dialog-babi-kb-all.txt'
            else:
                kb_file = 'data/dialog-bAbI-tasks/dialog-babi-task6-dstc2-kb.txt'
            global_entity_list = entityList(kb_file, task)

        # Entity-F1 accumulators: summed F1 and the counts returned by
        # compute_prf, overall ('all') and per KVR domain.
        domains = ('all', 'cal', 'nav', 'wet')
        f1_sum = dict.fromkeys(domains, 0)
        f1_count = dict.fromkeys(domains, 0)
        # (domain, batch field holding that domain's gold entity indices)
        kvr_fields = (('all', 8), ('cal', 9), ('nav', 10), ('wet', 11))

        acc_avg = 0.0
        wer_avg = 0.0
        ref, hyp = [], []
        dialog_acc_dict = {}

        pbar = tqdm(enumerate(dev), total=len(dev))
        for j, data_dev in pbar:
            batch_size = len(data_dev[1])
            words = self.evaluate_batch(batch_size, data_dev[0],
                                        data_dev[1], data_dev[2],
                                        data_dev[3], data_dev[4],
                                        data_dev[5], data_dev[6])
            acc = 0
            w = 0
            # After the transpose each row is one sample's token sequence.
            for i, row in enumerate(np.transpose(words)):
                st = decode(row)
                correct = data_dev[7][i].strip()

                if dataset == 'kvr':
                    for domain, field in kvr_fields:
                        f1_true, count = self.compute_prf(data_dev[field][i],
                                                          st.split(),
                                                          global_entity_list,
                                                          data_dev[14][i])
                        f1_sum[domain] += f1_true
                        f1_count[domain] += count
                elif is_babi_task6:
                    f1_true, count = self.compute_prf(data_dev[10][i],
                                                      st.split(),
                                                      global_entity_list,
                                                      data_dev[12][i])
                    f1_sum['all'] += f1_true
                    f1_count['all'] += count

                hit = correct == st
                if hit:
                    acc += 1
                if dataset == 'babi':
                    # data_dev[11] holds the dialogue id of each sample.
                    dialog_acc_dict.setdefault(data_dev[11][i], []).append(1 if hit else 0)

                w += wer(correct, st)
                ref.append(str(correct))
                hyp.append(str(st))

            acc_avg += acc / float(batch_size)
            wer_avg += w / float(batch_size)
            # Running means over the batches processed so far.
            pbar.set_description("R:{:.4f},W:{:.4f}".format(
                acc_avg / (j + 1), wer_avg / (j + 1)))

        # Dialogue accuracy: a dialogue counts only if every turn is correct.
        if dataset == 'babi':
            dia_acc = sum(1 for turns in dialog_acc_dict.values()
                          if len(turns) == sum(turns))
            logging.info("Dialog Accuracy:\t" +
                         str(ratio(dia_acc, len(dialog_acc_dict))))

        if dataset == 'kvr':
            logging.info("F1 SCORE:\t{}".format(ratio(f1_sum['all'], f1_count['all'])))
            logging.info("\tCAL F1:\t{}".format(ratio(f1_sum['cal'], f1_count['cal'])))
            logging.info("\tWET F1:\t{}".format(ratio(f1_sum['wet'], f1_count['wet'])))
            logging.info("\tNAV F1:\t{}".format(ratio(f1_sum['nav'], f1_count['nav'])))
        elif is_babi_task6:
            logging.info("F1 SCORE:\t{}".format(ratio(f1_sum['all'], f1_count['all'])))

        bleu_score = moses_multi_bleu(np.array(hyp),
                                      np.array(ref),
                                      lowercase=True)
        logging.info("BLEU SCORE:" + str(bleu_score))
        if BLEU:
            if bleu_score >= avg_best:
                self.save_model(str(self.name) + str(bleu_score))
                logging.info("MODEL SAVED")
            return bleu_score

        acc_avg = ratio(acc_avg, len(dev))
        if acc_avg >= avg_best:
            self.save_model(str(self.name) + str(acc_avg))
            logging.info("MODEL SAVED")
        return acc_avg
Пример #8
0
    def evaluate(self, dev, avg_best, BLEU=False):
        """Evaluate the model on ``dev`` and save it when the metric improves.

        Every batch is decoded with ``self.evaluate_batch``; each prediction
        (tokens up to the first ``'<EOS>'``) is compared with the reference in
        ``data_dev[7]``. Per-turn exact-match accuracy, WER and BLEU are always
        computed; micro entity F1 (``computeF1`` + ``f1_score``) is logged for
        KVR (overall plus calendar / weather / navigation) and for bAbI task 6,
        and per-dialogue accuracy is logged for bAbI.

        Args:
            dev: sized iterable of positionally indexed batches.
            avg_best: best metric so far; the model is saved when the current
                metric is ``>=`` this value.
            BLEU: select on BLEU if true, otherwise on mean per-turn accuracy.

        Returns:
            The BLEU score if ``BLEU`` is true, else the per-turn accuracy
            averaged over batches (0.0 for an empty ``dev``).
        """
        logging.info("STARTING EVALUATION")
        dataset = args['dataset']
        is_babi_task6 = dataset == 'babi' and int(self.task) == 6

        def ratio(num, den):
            # 0.0 instead of ZeroDivisionError when nothing was counted.
            return num / float(den) if den else 0.0

        def decode(row):
            # Join one sample's predicted tokens up to the first '<EOS>'.
            tokens = []
            for token in row:
                if token == '<EOS>':
                    break
                tokens.append(token)
            return ' '.join(tokens).strip()

        # Gold / predicted label lists fed to f1_score, overall and per domain.
        domains = ('all', 'cal', 'nav', 'wet')
        micro_true = {d: [] for d in domains}
        micro_pred = {d: [] for d in domains}
        # (domain, batch field holding that domain's gold entity indices)
        kvr_fields = (('all', 8), ('cal', 9), ('nav', 10), ('wet', 11))

        acc_avg = 0.0
        wer_avg = 0.0
        ref, hyp = [], []
        dialog_acc_dict = {}

        pbar = tqdm(enumerate(dev), total=len(dev))
        for j, data_dev in pbar:
            batch_size = len(data_dev[1])
            # The last two evaluate_batch inputs sit at different tail
            # positions depending on the dataset's batch layout.
            if dataset == 'kvr':
                extra_a, extra_b = data_dev[-2], data_dev[-1]
            else:
                extra_a, extra_b = data_dev[-4], data_dev[-3]
            words = self.evaluate_batch(batch_size, data_dev[0],
                                        data_dev[1], data_dev[2],
                                        data_dev[3], data_dev[4],
                                        data_dev[5], data_dev[6],
                                        extra_a, extra_b)
            acc = 0
            w = 0
            # After the transpose each row is one sample's token sequence.
            for i, row in enumerate(np.transpose(words)):
                st = decode(row)
                correct = data_dev[7][i].strip()

                if dataset == 'kvr':
                    for domain, field in kvr_fields:
                        f1_true, f1_pred = computeF1(data_dev[field][i], st, correct)
                        micro_true[domain] += f1_true
                        micro_pred[domain] += f1_pred
                elif is_babi_task6:
                    f1_true, f1_pred = computeF1(data_dev[-2][i], st, correct)
                    micro_true['all'] += f1_true
                    micro_pred['all'] += f1_pred

                hit = correct == st
                if hit:
                    acc += 1
                if dataset == 'babi':
                    # data_dev[-1] holds the dialogue id of each sample.
                    dialog_acc_dict.setdefault(data_dev[-1][i], []).append(1 if hit else 0)

                w += wer(correct, st)
                ref.append(str(correct))
                hyp.append(str(st))

            acc_avg += acc / float(batch_size)
            wer_avg += w / float(batch_size)
            # Running means over the batches processed so far.
            pbar.set_description("R:{:.4f},W:{:.4f}".format(
                acc_avg / (j + 1), wer_avg / (j + 1)))

        # Dialogue accuracy: a dialogue counts only if every turn is correct.
        if dataset == 'babi':
            dia_acc = sum(1 for turns in dialog_acc_dict.values()
                          if len(turns) == sum(turns))
            logging.info("Dialog Accuracy:\t" +
                         str(ratio(dia_acc, len(dialog_acc_dict))))

        if dataset == 'kvr':
            logging.info("F1 SCORE:\t" + str(
                f1_score(micro_true['all'], micro_pred['all'], average='micro')))
            logging.info("F1 CAL:\t" + str(
                f1_score(micro_true['cal'], micro_pred['cal'], average='micro')))
            logging.info("F1 WET:\t" + str(
                f1_score(micro_true['wet'], micro_pred['wet'], average='micro')))
            logging.info("F1 NAV:\t" + str(
                f1_score(micro_true['nav'], micro_pred['nav'], average='micro')))
        elif is_babi_task6:
            logging.info("F1 SCORE:\t" + str(
                f1_score(micro_true['all'], micro_pred['all'], average='micro')))

        bleu_score = moses_multi_bleu(np.array(hyp),
                                      np.array(ref),
                                      lowercase=True)
        logging.info("BLEU SCORE:" + str(bleu_score))
        if BLEU:
            if bleu_score >= avg_best:
                self.save_model(str(self.name) + str(bleu_score))
                logging.info("MODEL SAVED")
            return bleu_score

        acc_avg = ratio(acc_avg, len(dev))
        if acc_avg >= avg_best:
            self.save_model(str(self.name) + str(acc_avg))
            logging.info("MODEL SAVED")
        return acc_avg
Пример #9
0
    def evaluate(self, dev, metric_best, early_stop=None):
        """Evaluate the model on ``dev`` and save it when the metric improves.

        Runs ``self.encode_and_decode`` in evaluation mode on every batch,
        turns the fine and coarse decoder outputs into sentences (tokens up to
        the first ``'EOS'``) and scores the fine one against
        ``data_item['response_plain']``. For KVR, entity F1 (overall and per
        calendar / navigation / weather domain) is computed with
        ``self.compute_prf``; otherwise per-turn and per-dialogue exact-match
        accuracy is computed. Training mode is restored on the three
        sub-modules even if decoding raises.

        Args:
            dev: sized iterable of batch dicts.
            metric_best: best metric so far; the model is saved when the
                current metric is ``>=`` this value.
            early_stop: ``'BLEU'`` to select on BLEU; anything else selects on
                per-turn accuracy.

        Returns:
            The BLEU score when ``early_stop == 'BLEU'``, else the per-turn
            accuracy. The value is returned whether or not the model was saved.
        """
        print('\nSTARTING EVALUATING...')
        is_kvr = args['dataset'] == 'kvr'

        def ratio(num, den):
            # 0.0 instead of ZeroDivisionError when nothing was counted.
            return num / float(den) if den else 0.0

        def decode(tokens):
            # Join one sample's tokens up to (not including) the first 'EOS'.
            words = []
            for token in tokens:
                if token == 'EOS':
                    break
                words.append(token)
            return ' '.join(words).strip()

        label, pred = [], []
        acc, total = 0, 0
        # KVR entity-F1 accumulators: summed F1 and the counts returned by
        # compute_prf, overall ('all') and per domain.
        domains = ('all', 'cal', 'nav', 'wet')
        f1_sum = dict.fromkeys(domains, 0)
        f1_count = dict.fromkeys(domains, 0)
        # (domain, batch key holding that domain's gold entity indices)
        kvr_fields = (('all', 'ent_index'), ('cal', 'ent_idx_cal'),
                      ('nav', 'ent_idx_nav'), ('wet', 'ent_idx_wet'))
        dialogue_acc_dict = defaultdict(list)

        self.encoder.train(False)
        self.ext_know.train(False)
        self.decoder.train(False)
        try:
            if is_kvr:
                with open('../data/KVR/kvret_entities.json') as f:
                    global_entity = json.load(f)
                entities = set()
                for key, values in global_entity.items():
                    if key != 'poi':
                        entities.update(item.lower().replace(' ', '_') for item in values)
                    else:
                        # Each POI entry is a dict; every field value is an entity.
                        for poi in values:
                            entities.update(poi[x].lower().replace(' ', '_') for x in poi)
                global_entity_list = list(entities)

            for data_item in tqdm(dev, total=len(dev)):
                max_target_length = max(data_item['response_lengths'])
                _, _, decoded_fine, decoded_coarse, _ = self.encode_and_decode(
                    data_item, max_target_length, evaluating=True)
                # Per the original author, the decoder outputs hold one token
                # per batch sample at each decoding step; transpose so that
                # each row is one sample's full predicted sentence.
                decoded_fine = np.transpose(np.array(decoded_fine))
                decoded_coarse = np.transpose(np.array(decoded_coarse))

                for bi, fine_tokens in enumerate(decoded_fine):
                    pred_sentence = decode(fine_tokens)
                    pred_sentence_coarse = decode(decoded_coarse[bi])
                    pred.append(pred_sentence)
                    # NOTE(original author): ``bi`` was once observed to run
                    # past the end of ``response_plain`` here.
                    label_sentence = data_item['response_plain'][bi].strip()
                    label.append(label_sentence)

                    if is_kvr:
                        kb_plain = data_item['kb_info_plain'][bi]
                        for domain, key in kvr_fields:
                            single_f1, count = self.compute_prf(
                                data_item[key][bi], pred_sentence.split(),
                                global_entity_list, kb_plain)
                            f1_sum[domain] += single_f1
                            f1_count[domain] += count
                    else:
                        # NOTE(review): accuracy is only tracked for non-KVR
                        # data, so for KVR acc_score is always 0.0.
                        hit = pred_sentence == label_sentence
                        if hit:
                            acc += 1
                        dialogue_acc_dict[data_item['ID'][bi]].append(1 if hit else 0)
                    total += 1
                    if args['genSample']:
                        self.print_examples(bi, data_item, pred_sentence,
                                            pred_sentence_coarse, label_sentence)
        finally:
            # Restore training mode even if evaluation raised.
            self.encoder.train(True)
            self.ext_know.train(True)
            self.decoder.train(True)

        acc_score = ratio(acc, total)
        print('TRAIN ACC SCORE:\t{}'.format(acc_score))
        # NOTE(original author): BLEU was marked "temporarily unusable";
        # confirm moses_multi_bleu works in this environment.
        bleu_score = moses_multi_bleu(np.array(pred),
                                      np.array(label),
                                      lowercase=True)

        if is_kvr:
            print("F1 SCORE:\t{}".format(ratio(f1_sum['all'], f1_count['all'])))
            print("\tCAL F1:\t{}".format(ratio(f1_sum['cal'], f1_count['cal'])))
            print("\tWET F1:\t{}".format(ratio(f1_sum['wet'], f1_count['wet'])))
            print("\tNAV F1:\t{}".format(ratio(f1_sum['nav'], f1_count['nav'])))
            print("BLEU SCORE:\t" + str(bleu_score))
        else:
            # A dialogue counts only if every one of its turns is correct.
            dialogue_acc = sum(1 for turns in dialogue_acc_dict.values()
                               if len(turns) == sum(turns))
            print("Dialog Accuracy:\t{}".format(
                ratio(dialogue_acc, len(dialogue_acc_dict))))

        # Always return the metric; previously None was returned whenever the
        # model did not improve, breaking callers that compare the result.
        if early_stop == 'BLEU':
            if bleu_score >= metric_best:
                self.save_model('BLEU-' + str(bleu_score))
                print('MODEL SAVED')
            return bleu_score
        if acc_score >= metric_best:
            self.save_model('ACC-' + str(acc_score))
            print('MODEL SAVED')
        return acc_score
Пример #10
0
    def evaluate(self, dev, avg_best, BLEU=False):
        """Evaluate the model on ``dev`` and save it when the metric improves.

        Every batch is decoded with ``self.evaluate_batch`` (output rows become
        samples after a transpose); each prediction (tokens up to the first
        ``'<EOS>'``) is compared with the response sentence in ``data_dev[7]``.
        Per-turn exact-match accuracy, WER and BLEU are always computed; micro
        entity F1 (``computeF1`` + ``f1_score``) is logged for KVR (overall
        plus calendar / weather / navigation) and for bAbI task 6, and
        per-dialogue accuracy is logged for bAbI. Accuracy, dialogue accuracy,
        F1 and BLEU are computed over the whole evaluation set.

        Args:
            dev: sized iterable of positionally indexed batches.
            avg_best: best metric so far; the model is saved when the current
                metric is ``>=`` this value.
            BLEU: select on BLEU if true, otherwise on mean per-turn accuracy.

        Returns:
            The BLEU score if ``BLEU`` is true, else the per-turn accuracy
            averaged over batches (0.0 for an empty ``dev``).
        """
        logging.info("STARTING EVALUATION")
        dataset = args['dataset']
        is_babi_task6 = dataset == 'babi' and int(self.task) == 6

        def ratio(num, den):
            # 0.0 instead of ZeroDivisionError when nothing was counted.
            return num / float(den) if den else 0.0

        def decode(row):
            # Join one sample's predicted tokens up to the first '<EOS>'.
            tokens = []
            for token in row:
                if token == '<EOS>':
                    break
                tokens.append(token)
            return ' '.join(tokens).strip()

        # Gold / predicted label lists fed to f1_score, overall and per domain.
        # NOTE(original author): the predicted list looked like all ones,
        # possibly only acting as a denominator -- verify against computeF1.
        domains = ('all', 'cal', 'nav', 'wet')
        micro_true = {d: [] for d in domains}
        micro_pred = {d: [] for d in domains}
        # (domain, batch field holding that domain's gold entity indices)
        kvr_fields = (('all', 8), ('cal', 9), ('nav', 10), ('wet', 11))

        acc_avg = 0.0
        wer_avg = 0.0
        ref, hyp = [], []
        dialog_acc_dict = {}

        pbar = tqdm(enumerate(dev), total=len(dev))
        for j, data_dev in pbar:
            # evaluate_batch arguments, per the original author's notes:
            # batch_size, input_batches, input_lengths, target_batches,
            # target_lengths, target_index, target_gate, src_plain,
            # conv_seqs, conv_lengths. The last two sit at different tail
            # positions depending on the dataset's batch layout.
            batch_size = len(data_dev[1])
            if dataset == 'kvr':
                conv_seqs, conv_lengths = data_dev[-2], data_dev[-1]
            else:
                conv_seqs, conv_lengths = data_dev[-4], data_dev[-3]
            words = self.evaluate_batch(batch_size, data_dev[0],
                                        data_dev[1], data_dev[2],
                                        data_dev[3], data_dev[4],
                                        data_dev[5], data_dev[6],
                                        conv_seqs, conv_lengths)
            acc = 0  # exact matches within this batch
            w = 0
            # The transpose turns the (T, B) output into one row per sample.
            for i, row in enumerate(np.transpose(words)):
                st = decode(row)
                correct = data_dev[7][i].strip()  # reference response

                if dataset == 'kvr':
                    for domain, field in kvr_fields:
                        f1_true, f1_pred = computeF1(data_dev[field][i], st, correct)
                        micro_true[domain] += f1_true
                        micro_pred[domain] += f1_pred
                elif is_babi_task6:
                    f1_true, f1_pred = computeF1(data_dev[-2][i], st, correct)
                    micro_true['all'] += f1_true
                    micro_pred['all'] += f1_pred

                hit = correct == st
                if hit:
                    acc += 1
                if dataset == 'babi':
                    # data_dev[-1] holds the dialogue id of each sample.
                    dialog_acc_dict.setdefault(data_dev[-1][i], []).append(1 if hit else 0)

                w += wer(correct, st)
                ref.append(str(correct))
                hyp.append(str(st))

            # len(data_dev[1]) is the batch size; len(dev) the batch count.
            acc_avg += acc / float(batch_size)
            wer_avg += w / float(batch_size)
            # Running means over the batches processed so far (the previous
            # code divided by len(dev), which under-reported mid-epoch).
            pbar.set_description("R:{:.4f},W:{:.4f}".format(
                acc_avg / (j + 1), wer_avg / (j + 1)))

        # Dialogue accuracy: a dialogue counts only if every turn is correct.
        if dataset == 'babi':
            dia_acc = sum(1 for turns in dialog_acc_dict.values()
                          if len(turns) == sum(turns))
            logging.info("Dialog Accuracy:\t" +
                         str(ratio(dia_acc, len(dialog_acc_dict))))

        if dataset == 'kvr':
            logging.info("F1 SCORE:\t" + str(
                f1_score(micro_true['all'], micro_pred['all'], average='micro')))
            logging.info("F1 CAL:\t" + str(
                f1_score(micro_true['cal'], micro_pred['cal'], average='micro')))
            logging.info("F1 WET:\t" + str(
                f1_score(micro_true['wet'], micro_pred['wet'], average='micro')))
            logging.info("F1 NAV:\t" + str(
                f1_score(micro_true['nav'], micro_pred['nav'], average='micro')))
        elif is_babi_task6:
            logging.info("F1 SCORE:\t" + str(
                f1_score(micro_true['all'], micro_pred['all'], average='micro')))

        bleu_score = moses_multi_bleu(np.array(hyp),
                                      np.array(ref),
                                      lowercase=True)
        logging.info("BLEU SCORE:" + str(bleu_score))
        if BLEU:
            if bleu_score >= avg_best:
                self.save_model(str(self.name) + str(bleu_score))
                logging.info("MODEL SAVED")
            return bleu_score

        acc_avg = ratio(acc_avg, len(dev))
        if acc_avg >= avg_best:
            self.save_model(str(self.name) + str(acc_avg))
            logging.info("MODEL SAVED")
        return acc_avg
Пример #11
0
    def evaluate(self,
                 dev,
                 avg_best,
                 epoch,
                 BLEU=False,
                 Analyse=False,
                 type='dev'):
        assert type == 'dev' or type == 'test'
        logging.info("STARTING EVALUATION:{}".format(type))
        acc_avg = 0.0
        wer_avg = 0.0
        bleu_avg = 0.0
        acc_P = 0.0
        acc_V = 0.0
        microF1_PRED,microF1_PRED_cal,microF1_PRED_nav,microF1_PRED_wet = [],[],[],[]
        microF1_TRUE,microF1_TRUE_cal,microF1_TRUE_nav,microF1_TRUE_wet = [],[],[],[]
        ref = []
        hyp = []
        ref_s = ""
        hyp_s = ""
        dialog_acc_dict = {}
        pbar = tqdm(enumerate(dev), total=len(dev))

        if Analyse == True:
            write_fp = write_to_disk('./vanilla-seq-generate.txt')

        for j, data_dev in pbar:
            words = self.evaluate_batch(len(data_dev[1]), data_dev[0],
                                        data_dev[1], data_dev[2])
            acc = 0
            w = 0
            temp_gen = []
            #print(words)
            for i, row in enumerate(np.transpose(words)):
                st = ''
                for e in row:
                    if e == '<EOS>':
                        break
                    else:
                        st += e + ' '
                temp_gen.append(st)
                correct = data_dev[7][i]
                ### compute F1 SCORE
                if args['dataset'] == 'kvr':
                    f1_true, f1_pred = computeF1(data_dev[8][i],
                                                 st.lstrip().rstrip(),
                                                 correct.lstrip().rstrip())
                    microF1_TRUE += f1_true
                    microF1_PRED += f1_pred

                    f1_true, f1_pred = computeF1(data_dev[9][i],
                                                 st.lstrip().rstrip(),
                                                 correct.lstrip().rstrip())
                    microF1_TRUE_cal += f1_true
                    microF1_PRED_cal += f1_pred

                    f1_true, f1_pred = computeF1(data_dev[10][i],
                                                 st.lstrip().rstrip(),
                                                 correct.lstrip().rstrip())
                    microF1_TRUE_nav += f1_true
                    microF1_PRED_nav += f1_pred

                    f1_true, f1_pred = computeF1(data_dev[11][i],
                                                 st.lstrip().rstrip(),
                                                 correct.lstrip().rstrip())
                    microF1_TRUE_wet += f1_true
                    microF1_PRED_wet += f1_pred
                elif args['dataset'] == 'babi' and int(self.task) == 6:

                    f1_true, f1_pred = computeF1(data_dev[-2][i],
                                                 st.lstrip().rstrip(),
                                                 correct.lstrip().rstrip())
                    microF1_TRUE += f1_true
                    microF1_PRED += f1_pred

                if args['dataset'] == 'babi':
                    if data_dev[-1][i] not in dialog_acc_dict.keys():
                        dialog_acc_dict[data_dev[-1][i]] = []
                    if (correct.lstrip().rstrip() == st.lstrip().rstrip()):
                        acc += 1
                        dialog_acc_dict[data_dev[-1][i]].append(1)
                    else:
                        dialog_acc_dict[data_dev[-1][i]].append(0)
                else:
                    if (correct.lstrip().rstrip() == st.lstrip().rstrip()):
                        acc += 1
                #else:
                #    print("Correct:"+str(correct.lstrip().rstrip()))
                #    print("\tPredict:"+str(st.lstrip().rstrip()))
                #    print("\tFrom:"+str(self.from_whichs[:,i]))

                # w += wer(correct.lstrip().rstrip(),st.lstrip().rstrip())
                ref.append(str(correct.lstrip().rstrip()))
                hyp.append(str(st.lstrip().rstrip()))
                ref_s += str(correct.lstrip().rstrip()) + "\n"
                hyp_s += str(st.lstrip().rstrip()) + "\n"

            # write batch data to disk
            if Analyse == True:
                for gen, gold in zip(temp_gen, data_dev[7]):
                    write_fp.write(gen + '\t' + gold + '\n')

            acc_avg += acc / float(len(data_dev[1]))
            wer_avg += w / float(len(data_dev[1]))
            pbar.set_description("R:{:.4f},W:{:.4f}".format(
                acc_avg / float(len(dev)), wer_avg / float(len(dev))))
        if Analyse == True:
            write_fp.close()

        if args['dataset'] == 'babi':
            # TODO:计算平均的对话准确度
            dia_acc = 0
            for k in dialog_acc_dict.keys():
                if len(dialog_acc_dict[k]) == sum(dialog_acc_dict[k]):
                    dia_acc += 1
            logging.info("Dialog Accuracy:\t" +
                         str(dia_acc * 1.0 / len(dialog_acc_dict.keys())))
            self._send_metrics(epoch,
                               type,
                               acc=dia_acc * 1.0 / len(dialog_acc_dict.keys()))
            self._send_metrics(epoch,
                               type,
                               acc_response=acc_avg / float(len(dev)))

        if args['dataset'] == 'kvr':
            f1 = f1_score(microF1_TRUE, microF1_PRED, average='micro')
            f1_cal = f1_score(microF1_TRUE_cal,
                              microF1_PRED_cal,
                              average='micro')
            f1_wet = f1_score(microF1_TRUE_wet,
                              microF1_PRED_wet,
                              average='micro')
            f1_nav = f1_score(microF1_TRUE_nav,
                              microF1_PRED_nav,
                              average='micro')

            logging.info("F1 SCORE:\t" + str(f1))
            logging.info("F1 CAL:\t" + str(f1_cal))
            logging.info("F1 WET:\t" + str(f1_wet))
            logging.info("F1 NAV:\t" + str(f1_nav))
            self._send_metrics(epoch,
                               type,
                               f1=f1,
                               f1_cal=f1_cal,
                               f1_wet=f1_wet,
                               f1_nav=f1_nav)

        elif args['dataset'] == 'babi' and int(self.task) == 6:
            f1 = f1_score(microF1_TRUE, microF1_PRED, average='micro')
            logging.info("F1 SCORE:\t" + str(f1))
            self._send_metrics(epoch, type, babi_6_f1=f1)

        self._send_metrics(epoch, type, total_loss=self.print_loss_avg)

        bleu_score = moses_multi_bleu(np.array(hyp),
                                      np.array(ref),
                                      lowercase=True)
        logging.info("BLEU SCORE:" + str(bleu_score))
        self._send_metrics(epoch, type, bleu=bleu_score)

        if (BLEU):
            if (bleu_score >= avg_best):
                if type == 'dev':
                    self.save_model(str(self.name) + str(bleu_score))
                    logging.info("MODEL SAVED")
            return bleu_score
        else:
            acc_avg = acc_avg / float(len(dev))
            if (acc_avg >= avg_best):
                if type == 'dev':
                    self.save_model(str(self.name) + str(acc_avg))
                    logging.info("MODEL SAVED")
            return acc_avg
Пример #12
0
    def evaluate(self, dev, avg_best, epoch, BLEU=False, Analyse=False, type='dev'):
        """Decode every batch in ``dev``, report metrics and maybe checkpoint.

        Args:
            dev: sized iterable of batch tuples. Index 7 holds the gold response
                strings; the other indices are forwarded to
                ``self.evaluate_batch`` exactly as below.
            avg_best: best score seen so far. On a ``'dev'`` run, reaching it
                saves the model under a name suffixed with the new score.
            epoch: epoch number forwarded to ``self._send_metrics``.
            BLEU: if true, BLEU is the tracked/returned metric; otherwise the
                mean per-response (exact-match) accuracy over batches is.
            Analyse: analysis mode (used when inspecting the memory attention
                scores). Writes "generated<TAB>gold" lines to
                ``./multi-mem-generate.txt`` and skips loss reporting and
                checkpointing.
            type: ``'dev'`` or ``'test'``; only ``'dev'`` runs save the model.

        Returns:
            The BLEU score if ``BLEU`` else the mean per-response accuracy.
            The accuracy is normalised by the number of batches in both
            analysis and normal mode.
        """
        assert type == 'dev' or type == 'test'
        logging.info("STARTING EVALUATION:{}".format(type))
        dataset = args['dataset']
        # Short-circuits, so ``self.task`` is only read for bAbI.
        is_babi_task6 = dataset == 'babi' and int(self.task) == 6
        acc_avg = 0.0  # sum of per-batch accuracies; normalised after the loop
        # WER is not currently computed; kept at 0 so the progress-bar format
        # ("R:...,W:...") is unchanged.
        wer_avg = 0.0
        microF1_PRED, microF1_PRED_cal, microF1_PRED_nav, microF1_PRED_wet = [], [], [], []
        microF1_TRUE, microF1_TRUE_cal, microF1_TRUE_nav, microF1_TRUE_wet = [], [], [], []
        ref = []
        hyp = []
        dialog_acc_dict = {}  # dialogue key -> list of 0/1 per-turn exact matches
        num_batches = len(dev)
        pbar = tqdm(enumerate(dev), total=num_batches)

        write_fp = None
        if Analyse:
            write_fp = write_to_disk('./multi-mem-generate.txt')
        try:
            for j, data_dev in pbar:
                words = self.evaluate_batch(batch_size=len(data_dev[1]),
                                            input_batches=data_dev[0],
                                            input_lengths=data_dev[1],
                                            target_batches=data_dev[2],
                                            target_lengths=data_dev[3],
                                            target_index=data_dev[4],
                                            target_gate=data_dev[5],
                                            src_plain=data_dev[6],
                                            conv_seqs=data_dev[-6],
                                            conv_lengths=data_dev[-5],
                                            kb_seqs=data_dev[-4],
                                            kb_lengths=data_dev[-3],
                                            kb_target_index=data_dev[-2],
                                            kb_plain=data_dev[-1],
                                            step=j,
                                            Analyse=Analyse)
                acc = 0
                temp_gen = []  # raw generated strings (with trailing space) for the analysis file
                for i, row in enumerate(np.transpose(words)):
                    st = ''
                    for e in row:
                        if e == '<EOS>':
                            break
                        st += e + ' '
                    temp_gen.append(st)
                    pred = st.strip()
                    gold = data_dev[7][i].strip()

                    # Entity F1.
                    if dataset == 'kvr':
                        f1_true, f1_pred = computeF1(data_dev[8][i], pred, gold)
                        microF1_TRUE += f1_true
                        microF1_PRED += f1_pred
                        f1_true, f1_pred = computeF1(data_dev[9][i], pred, gold)
                        microF1_TRUE_cal += f1_true
                        microF1_PRED_cal += f1_pred
                        f1_true, f1_pred = computeF1(data_dev[10][i], pred, gold)
                        microF1_TRUE_nav += f1_true
                        microF1_PRED_nav += f1_pred
                        f1_true, f1_pred = computeF1(data_dev[11][i], pred, gold)
                        microF1_TRUE_wet += f1_true
                        microF1_PRED_wet += f1_pred
                    elif is_babi_task6:
                        # NOTE(review): data_dev[-6] / data_dev[-5] are also passed
                        # to evaluate_batch as conv_seqs / conv_lengths above;
                        # confirm they really hold the entity list / dialogue id
                        # used here and for dialogue accuracy below.
                        f1_true, f1_pred = computeF1(data_dev[-6][i], pred, gold)
                        microF1_TRUE += f1_true
                        microF1_PRED += f1_pred

                    # Per-response accuracy (and, for bAbI, per-dialogue tracking).
                    is_correct = gold == pred
                    if is_correct:
                        acc += 1
                    if dataset == 'babi':
                        dialog_acc_dict.setdefault(data_dev[-5][i], []).append(
                            1 if is_correct else 0)
                    ref.append(str(gold))
                    hyp.append(str(pred))

                # Write this batch's generations to disk (analysis mode only).
                if write_fp is not None:
                    for gen, gold_resp in zip(temp_gen, data_dev[7]):
                        write_fp.write(gen + '\t' + gold_resp + '\n')

                acc_avg += acc / float(len(data_dev[1]))
                pbar.set_description("R:{:.4f},W:{:.4f}".format(
                    acc_avg / float(num_batches), wer_avg / float(num_batches)))
        finally:
            # Close the analysis file even if decoding raises.
            if write_fp is not None:
                write_fp.close()

        acc_response = acc_avg / float(num_batches) if num_batches else 0.0

        # Dialogue accuracy: a dialogue counts only if every turn matched exactly.
        if dataset == 'babi':
            dia_acc = sum(1 for turns in dialog_acc_dict.values() if all(turns))
            dialog_acc = (dia_acc * 1.0 / len(dialog_acc_dict)
                          if dialog_acc_dict else 0.0)
            logging.info("Dialog Accuracy:\t" + str(dialog_acc))
            self._send_metrics(epoch, type, acc=dialog_acc)
        if dataset == 'kvr':
            f1 = f1_score(microF1_TRUE, microF1_PRED, average='micro')
            f1_cal = f1_score(microF1_TRUE_cal, microF1_PRED_cal, average='micro')
            f1_wet = f1_score(microF1_TRUE_wet, microF1_PRED_wet, average='micro')
            f1_nav = f1_score(microF1_TRUE_nav, microF1_PRED_nav, average='micro')
            logging.info("F1 SCORE:\t" + str(f1))
            logging.info("F1 CAL:\t" + str(f1_cal))
            logging.info("F1 WET:\t" + str(f1_wet))
            logging.info("F1 NAV:\t" + str(f1_nav))
            self._send_metrics(epoch, type, f1=f1, f1_cal=f1_cal, f1_wet=f1_wet, f1_nav=f1_nav)
        elif is_babi_task6:
            f1 = f1_score(microF1_TRUE, microF1_PRED, average='micro')
            logging.info("F1 SCORE:\t" + str(f1))
            self._send_metrics(epoch, type, babi_6_f1=f1)

        bleu_score = moses_multi_bleu(np.array(hyp), np.array(ref), lowercase=True)
        self._send_metrics(epoch, type, acc_response=acc_response)
        logging.info("BLEU SCORE:" + str(bleu_score))

        if Analyse:
            # Analysis mode: no loss reporting, no checkpointing.
            return bleu_score if BLEU else acc_response

        self._send_metrics(epoch, type, total_loss=self.print_loss_avg,
                           ptr_loss=self.print_loss_kb,
                           vocab_loss=self.print_loss_vocabulary,
                           his_loss=self.print_loss_memory,
                           bleu_score=bleu_score)

        if BLEU:
            if bleu_score >= avg_best and bleu_score != 0:
                if type == 'dev':
                    self.save_model(str(self.name) + str(bleu_score))
                    logging.info("MODEL SAVED")
            return bleu_score

        if acc_response >= avg_best:
            if type == 'dev':
                self.save_model(str(self.name) + str(acc_response))
                logging.info("MODEL SAVED")
        return acc_response
Пример #13
0
    def evaluate(self, dev, dev_length, matric_best, early_stop=None):
        """Decode ``dev``, print evaluation metrics and checkpoint on improvement.

        Args:
            dev: dataset exposing ``take(-1)`` whose elements are the batch
                tuples consumed by ``self.encode_and_decode`` (presumably a
                ``tf.data.Dataset`` -- TODO confirm). Index 8 holds the gold
                responses and 9 the plain KB arrays. Indices 14/16-18 (kvr) and
                14/28-31 (multiwoz) hold entity indices. Index 22 holds the
                dialogue ids used for dialogue accuracy on other datasets.
            dev_length: number of batches; only used for the progress bar.
            matric_best: best value so far of the metric selected by
                ``early_stop``; reaching it saves the model.
            early_stop: ``'BLEU'``, ``'ENTF1'`` (entity F1), or anything else
                for per-response accuracy.

        Returns:
            The value of the metric selected by ``early_stop``.

        Raises:
            ValueError: if ``early_stop == 'ENTF1'`` but ``args['dataset']`` is
                neither ``'kvr'`` nor ``'multiwoz'`` (no entity F1 is computed).
        """

        def _safe_div(numerator, denominator):
            # Float ratio, or 0.0 when nothing was counted (avoids ZeroDivisionError).
            return numerator / float(denominator) if denominator else 0.0

        def _join_until_eos(tokens):
            # Space-join decoded tokens up to (excluding) the first 'EOS'.
            words = []
            for token in tokens:
                if token == 'EOS':
                    break
                words.append(token)
            return ' '.join(words).strip()

        def _dialog_key(value):
            # Eager tensors are unhashable in TF2, so turn the id into a plain
            # Python value before using it as a dict key. Values that are already
            # hashable (str, bytes, ints, numpy scalars) pass through unchanged.
            if hasattr(value, 'numpy'):
                value = value.numpy()
            if isinstance(value, np.ndarray):
                value = value.tolist()
            if isinstance(value, list):
                return tuple(_dialog_key(v) for v in value)
            return value

        def _load_global_entities(path, nested_keys=()):
            # Unique lower-cased, underscore-joined entity names from a JSON file.
            # Keys listed in ``nested_keys`` map to lists of dicts whose values
            # are the entities; every other key maps to a list of entity strings.
            with open(path, encoding='utf-8') as f:
                global_entity = json.load(f)
            entities = set()
            for key, values in global_entity.items():
                if key in nested_keys:
                    for item in values:
                        entities.update(item[k].lower().replace(' ', '_') for k in item)
                else:
                    entities.update(item.lower().replace(' ', '_') for item in values)
            return list(entities)

        print('STARTING EVALUATION:')
        dataset = args['dataset']
        ref, hyp = [], []
        acc, total = 0, 0
        dialog_acc_dict = {}  # dialogue id -> list of 0/1 per-turn exact matches

        # (accumulator name, batch index holding that domain's entity indices)
        if dataset == 'kvr':
            global_entity_list = _load_global_entities(
                'data/KVR/kvret_entities.json', nested_keys=('poi',))
            f1_fields = [('all', 14), ('cal', 16), ('nav', 17), ('wet', 18)]
        elif dataset == 'multiwoz':
            global_entity_list = _load_global_entities(
                'data/MULTIWOZ2.1/multiwoz_entities.json')
            f1_fields = [('all', 14), ('restaurant', 28), ('hotel', 29),
                         ('attraction', 30), ('train', 31)]
        else:
            global_entity_list = []
            f1_fields = []
        f1_sum = {name: 0 for name, _ in f1_fields}
        f1_count = {name: 0 for name, _ in f1_fields}

        pbar = tqdm(enumerate(dev.take(-1)), total=dev_length)
        for _, data_dev in pbar:
            # Encode and decode; only the fine/coarse token matrices are used here.
            _, _, decoded_fine, decoded_coarse, _, _ = self.encode_and_decode(
                data_dev, self.max_resp_len, False, True, False)
            decoded_coarse = np.transpose(decoded_coarse)
            decoded_fine = np.transpose(decoded_fine)
            for bi, row in enumerate(decoded_fine):
                pred_sent = _join_until_eos(row)
                pred_sent_coarse = _join_until_eos(decoded_coarse[bi])
                # data[8]: response_plain.
                gold_sent = data_dev[8][bi][0].numpy().decode().strip()
                ref.append(gold_sent)
                hyp.append(pred_sent)

                if f1_fields:
                    # Entity F1 per domain; data[9]: kb_arr_plain.
                    for name, field in f1_fields:
                        single_f1, count = self.compute_prf(
                            data_dev[field][bi], pred_sent.split(),
                            global_entity_list, data_dev[9][bi])
                        f1_sum[name] += single_f1
                        f1_count[name] += count
                else:
                    # Dialogue accuracy bookkeeping.
                    current_id = _dialog_key(data_dev[22][bi])
                    dialog_acc_dict.setdefault(current_id, []).append(
                        1 if gold_sent == pred_sent else 0)

                # Per-response accuracy.
                total += 1
                if gold_sent == pred_sent:
                    acc += 1

                if args['genSample']:
                    self.print_examples(bi, data_dev, pred_sent,
                                        pred_sent_coarse, gold_sent)

        bleu_score = moses_multi_bleu(np.array(hyp),
                                      np.array(ref),
                                      lowercase=True)
        acc_score = _safe_div(acc, total)

        F1_score = None  # only defined for datasets with entity annotations
        if dataset == 'kvr':
            F1_score = _safe_div(f1_sum['all'], f1_count['all'])
            print("F1 SCORE:\t{}".format(F1_score))
            print("\tCAL F1:\t{}".format(_safe_div(f1_sum['cal'], f1_count['cal'])))
            print("\tWET F1:\t{}".format(_safe_div(f1_sum['wet'], f1_count['wet'])))
            print("\tNAV F1:\t{}".format(_safe_div(f1_sum['nav'], f1_count['nav'])))
            print("BLEU SCORE:\t" + str(bleu_score))
        elif dataset == 'multiwoz':
            F1_score = _safe_div(f1_sum['all'], f1_count['all'])
            print("F1 SCORE:\t{}".format(F1_score))
            print("\tRES F1:\t{}".format(
                _safe_div(f1_sum['restaurant'], f1_count['restaurant'])))
            print("\tHOT F1:\t{}".format(
                _safe_div(f1_sum['hotel'], f1_count['hotel'])))
            print("\tATT F1:\t{}".format(
                _safe_div(f1_sum['attraction'], f1_count['attraction'])))
            print("\tTRA F1:\t{}".format(
                _safe_div(f1_sum['train'], f1_count['train'])))
            print("BLEU SCORE:\t" + str(bleu_score))
        else:
            # A dialogue is correct only if every one of its turns matched exactly.
            dia_acc = sum(1 for turns in dialog_acc_dict.values() if all(turns))
            print("Dialog Accuracy:\t" +
                  str(_safe_div(dia_acc * 1.0, len(dialog_acc_dict))))

        if early_stop == 'BLEU':
            if bleu_score >= matric_best:
                self.save_model('BLEU-' + str(bleu_score))
                print("MODEL SAVED")
            return bleu_score
        elif early_stop == 'ENTF1':
            if F1_score is None:
                raise ValueError(
                    "early_stop='ENTF1' requires an entity-annotated dataset "
                    "('kvr' or 'multiwoz'), got {!r}".format(dataset))
            if F1_score >= matric_best:
                self.save_model('ENTF1-{:.4f}'.format(F1_score))
                print("MODEL SAVED")
            return F1_score
        else:
            if acc_score >= matric_best:
                self.save_model('ACC-{:.4f}'.format(acc_score))
                print("MODEL SAVED")
            return acc_score
Пример #14
0
    def evaluate(self, dev, avg_best, BLEU=False):
        logging.info("STARTING EVALUATION")
        acc_avg = 0.0
        wer_avg = 0.0
        bleu_avg = 0.0
        acc_P = 0.0
        acc_V = 0.0
        microF1_PRED, microF1_PRED_cal, microF1_PRED_nav, microF1_PRED_wet = 0, 0, 0, 0
        microF1_TRUE, microF1_TRUE_cal, microF1_TRUE_nav, microF1_TRUE_wet = 0, 0, 0, 0
        query = []
        ref = []
        hyp = []
        ref_s = ""
        hyp_s = ""
        query_s = ""
        dialog_acc_dict = {}

        # only use for chitchat data
        if args['dataset'] == 'chitchat':
            global_entity_list = []
        else:
            raise ValueError("Must use chitchat data.")

        pbar = tqdm(enumerate(dev), total=len(dev))
        for j, data_dev in pbar:
            # ???
            if args['dataset'] == 'kvr':
                words = self.evaluate_batch(len(data_dev[1]), data_dev[0], data_dev[1],
                                            data_dev[2], data_dev[3], data_dev[4], data_dev[5], data_dev[6])
            else:
                words = self.evaluate_batch(len(data_dev[1]), data_dev[0], data_dev[1],
                                            data_dev[2], data_dev[3], data_dev[4], data_dev[5], data_dev[6])

            acc = 0
            w = 0
            temp_gen = []

            for i, row in enumerate(np.transpose(words)):
                st = ''
                for e in row:
                    if e == '<EOS>':
                        break
                    else:
                        st += e + ' '
                temp_gen.append(st)
                # correct = " ".join([self.lang.index2word[item] for item in data_dev[2]])
                correct = " ".join(data_dev[7][i])
                user_query = " ".join(data_dev[6][i])
                ### compute F1 SCOR
                st = st.lstrip().rstrip()
                correct = correct.lstrip().rstrip()
                user_query = user_query.strip()
                '''
                if args['dataset'] == 'kvr':
                    f1_true, count = self.compute_prf(data_dev[8][i], st.split(), global_entity_list, data_dev[14][i])
                    microF1_TRUE += f1_true`
                    microF1_PRED += count
                    f1_true, count = self.compute_prf(data_dev[9][i], st.split(), global_entity_list, data_dev[14][i])
                    microF1_TRUE_cal += f1_true
                    microF1_PRED_cal += count
                    f1_true, count = self.compute_prf(data_dev[10][i], st.split(), global_entity_list, data_dev[14][i])
                    microF1_TRUE_nav += f1_true
                    microF1_PRED_nav += count
                    f1_true, count = self.compute_prf(data_dev[11][i], st.split(), global_entity_list, data_dev[14][i])
                    microF1_TRUE_wet += f1_true
                    microF1_PRED_wet += count
                elif args['dataset'] == 'babi' and int(args["task"]) == 6:
                    f1_true, count = self.compute_prf(data_dev[10][i], st.split(), global_entity_list, data_dev[12][i])
                    microF1_TRUE += f1_true
                    microF1_PRED += count

                if args['dataset'] == 'babi':
                    if data_dev[11][i] not in dialog_acc_dict.keys():
                        dialog_acc_dict[data_dev[11][i]] = []
                    if (correct == st):
                        acc += 1
                        dialog_acc_dict[data_dev[11][i]].append(1)
                    else:
                        dialog_acc_dict[data_dev[11][i]].append(0)
                else:
                
                    if (correct == st):
                        acc += 1
                #    print("Correct:"+str(correct))
                #    print("\tPredict:"+str(st))
                #    print("\tFrom:"+str(self.from_whichs[:,i]))
                '''

                w += wer(correct, st)
                ref.append(str(correct))
                hyp.append(str(st))
                query.append(str(user_query))
                ref_s += str(correct) + "\n"
                hyp_s += str(st) + "\n"
                query_s += str(user_query) + "\n"

            # acc_avg += acc / float(len(data_dev[1]))
            wer_avg += w / float(len(data_dev[1]))
            pbar.set_description("W:{:.4f}".format(wer_avg / float(len(dev))))

        # dialog accuracy
        if args['dataset'] == 'babi':
            dia_acc = 0
            for k in dialog_acc_dict.keys():
                if len(dialog_acc_dict[k]) == sum(dialog_acc_dict[k]):
                    dia_acc += 1
            logging.info("Dialog Accuracy:\t" + str(dia_acc * 1.0 / len(dialog_acc_dict.keys())))

        '''
        if args['dataset'] == 'kvr':
            logging.info("F1 SCORE:\t{}".format(microF1_TRUE / float(microF1_PRED)))
            logging.info("\tCAL F1:\t{}".format(microF1_TRUE_cal / float(microF1_PRED_cal)))
            logging.info("\tWET F1:\t{}".format(microF1_TRUE_wet / float(microF1_PRED_wet)))
            logging.info("\tNAV F1:\t{}".format(microF1_TRUE_nav / float(microF1_PRED_nav)))
        elif args['dataset'] == 'babi' and int(args["task"]) == 6:
            logging.info("F1 SCORE:\t{}".format(microF1_TRUE / float(microF1_PRED)))
        '''

        # save decoding results

        bleu_score = moses_multi_bleu(np.array(hyp), np.array(ref), lowercase=True)
        logging.info("BLEU SCORE:" + str(bleu_score))
        # always return with bleu score.
        BLEU = True
        if (BLEU):
            if (bleu_score >= avg_best):
                self.save_model(str(self.name) + str(bleu_score))
                logging.info("MODEL SAVED")
            self.save_decode_results(query, ref, hyp, bleu_score)

            return bleu_score
        '''