Example #1
0
    def evaluate(self, dev, avg_best, BLEU=False):
        """Evaluate the model on the dev set.

        Args:
            dev: iterable of dev batches. data_dev[0..2] feed evaluate_batch;
                data_dev[7] holds the gold response string for each example
                (assumed from the indexing — confirm against the data loader).
            avg_best: best score so far; the model is saved when the new
                score is >= avg_best.
            BLEU: if True, compute and return corpus BLEU; otherwise return
                average exact-match accuracy.

        Returns:
            float: BLEU score when BLEU is True, else average accuracy
            over batches.
        """
        logging.info("STARTING EVALUATION")
        acc_avg = 0.0
        wer_avg = 0.0
        ref = []
        hyp = []
        pbar = tqdm(enumerate(dev), total=len(dev))
        for j, data_dev in pbar:
            words = self.evaluate_batch(len(data_dev[1]), data_dev[0], data_dev[1], data_dev[2])
            acc = 0
            w = 0
            # `words` is time-major (T, B); transpose to iterate one decoded
            # example per row.
            for i, row in enumerate(np.transpose(words)):
                st = ''
                for e in row:
                    if e == '<EOS>':
                        break
                    st += e + ' '
                pred = st.lstrip().rstrip()
                gold = data_dev[7][i].lstrip().rstrip()

                if gold == pred:
                    acc += 1
                w += wer(gold, pred)
                ref.append(str(gold))
                hyp.append(str(pred))

            acc_avg += acc / float(len(data_dev[1]))
            wer_avg += w / float(len(data_dev[1]))
            # NOTE(review): the running averages are divided by the total
            # number of batches, so the bar under-reports until the last
            # batch — kept as in the original.
            pbar.set_description("R:{:.4f},W:{:.4f}".format(acc_avg / float(len(dev)),
                                                            wer_avg / float(len(dev))))

        if BLEU:
            bleu_score = moses_multi_bleu(np.array(hyp), np.array(ref), lowercase=True)
            logging.info("BLEU SCORE:" + str(bleu_score))

            if bleu_score >= avg_best:
                self.save_model(str(self.name) + str(bleu_score))
                logging.info("MODEL SAVED")
            return bleu_score
        else:
            acc_avg = acc_avg / float(len(dev))
            if acc_avg >= avg_best:
                self.save_model(str(self.name) + str(acc_avg))
                logging.info("MODEL SAVED")
            return acc_avg
Example #2
0
    def evaluate(self, out_f, dev, avg_best, BLEU=False):
        """Evaluate the model on the dev set and dump predictions to a file.

        Args:
            out_f: open text file; receives one "ref===hyp" line per example,
                a separator, and the BLEU score.
            dev: iterable of dev batches; fields 0-5 feed evaluate_batch and
                data_dev[6] holds the gold response tokens per example
                (assumed from the indexing — confirm against the data loader).
            avg_best: best score so far; the model is saved when the BLEU
                score is >= avg_best.
            BLEU: unused in this variant (BLEU is always computed/returned).

        Returns:
            float: corpus BLEU score.
        """
        logging.info("STARTING EVALUATION")
        acc_avg = 0.0
        wer_avg = 0.0
        ref = []
        hyp = []
        pbar = tqdm(enumerate(dev), total=len(dev))
        for j, data_dev in pbar:
            words = self.evaluate_batch(batch_size=len(data_dev[1]),
                                        input_batches=data_dev[0],
                                        input_lengths=data_dev[1],
                                        target_batches=data_dev[2],
                                        target_lengths=data_dev[3],
                                        target_index=data_dev[4],
                                        src_plain=data_dev[5])
            acc = 0
            w = 0
            # `words` is time-major (T, B); transpose to iterate per example.
            for i, row in enumerate(np.transpose(words)):
                st = ''
                for e in row:
                    if e == '<EOS>':
                        break
                    st += e + ' '
                correct = " ".join(data_dev[6][i])
                ### IMPORTANT
                ### WE NEED TO COMPARE THE PLAIN STRING, BECAUSE WE COPY THE WORDS FROM THE INPUT
                ### ====>> the index in the output gold can be UNK
                if (correct.lstrip().rstrip() == st.lstrip().rstrip()):
                    acc += 1
                w += wer(correct.lstrip().rstrip(), st.lstrip().rstrip())
                ref.append(str(correct.lstrip().rstrip()))
                hyp.append(str(st.lstrip().rstrip()))

            acc_avg += acc / float(len(data_dev[1]))
            wer_avg += w / float(len(data_dev[1]))
            pbar.set_description("R:{:.4f},W:{:.4f}".format(
                acc_avg / float(len(dev)), wer_avg / float(len(dev))))

        bleu_score = moses_multi_bleu(np.array(hyp),
                                      np.array(ref),
                                      lowercase=True)
        # BUGFIX: the original wrote only the *last* loop value of `correct`
        # paired with the repr of the entire `hyp` list; write each
        # reference/hypothesis pair instead.
        for r, h in zip(ref, hyp):
            out_f.write(str(r) + "===" + str(h) + "\n")
        out_f.write("---------" + "\n")
        logging.info("BLEU SCORE:" + str(bleu_score))
        out_f.write("bleu_score" + str(bleu_score) + "\n")

        if bleu_score >= avg_best:
            self.save_model(str(self.name) + str(bleu_score))
            logging.info("MODEL SAVED")
        return bleu_score
Example #3
0
    def evaluate(self, dev, avg_best, BLEU=False):
        """Evaluate the model: exact-match accuracy, WER, entity F1, BLEU.

        Args:
            dev: iterable of dev batches. data_dev[0..2] feed evaluate_batch;
                data_dev[7] is the gold response string per example, and —
                when a batch carries more than 10 fields — data_dev[8..11]
                hold gold entity lists for overall/calendar/navigation/weather
                F1 (assumed from the indexing — confirm against the loader).
            avg_best: best score so far; the model is saved when matched or
                beaten.
            BLEU: if True, compute and return corpus BLEU; otherwise return
                average exact-match accuracy.

        Returns:
            float: BLEU score when BLEU is True, else average accuracy.
        """
        logging.info("STARTING EVALUATION")
        acc_avg = 0.0
        wer_avg = 0.0
        microF1_PRED, microF1_PRED_cal, microF1_PRED_nav, microF1_PRED_wet = [], [], [], []
        microF1_TRUE, microF1_TRUE_cal, microF1_TRUE_nav, microF1_TRUE_wet = [], [], [], []
        ref = []
        hyp = []
        pbar = tqdm(enumerate(dev), total=len(dev))
        for j, data_dev in pbar:
            words = self.evaluate_batch(len(data_dev[1]), data_dev[0],
                                        data_dev[1], data_dev[2])
            acc = 0
            w = 0
            # `words` is time-major (T, B); transpose to iterate per example.
            for i, row in enumerate(np.transpose(words)):
                st = ''
                for e in row:
                    if e == '<EOS>':
                        break
                    st += e + ' '
                pred = st.lstrip().rstrip()
                gold = data_dev[7][i].lstrip().rstrip()
                ### compute F1 SCORE
                # Entity F1 is only computed for datasets whose batches carry
                # the extra gold-entity fields.
                if len(data_dev) > 10:
                    f1_true, f1_pred = computeF1(data_dev[8][i], pred, gold)
                    microF1_TRUE += f1_true
                    microF1_PRED += f1_pred

                    f1_true, f1_pred = computeF1(data_dev[9][i], pred, gold)
                    microF1_TRUE_cal += f1_true
                    microF1_PRED_cal += f1_pred

                    f1_true, f1_pred = computeF1(data_dev[10][i], pred, gold)
                    microF1_TRUE_nav += f1_true
                    microF1_PRED_nav += f1_pred

                    f1_true, f1_pred = computeF1(data_dev[11][i], pred, gold)
                    microF1_TRUE_wet += f1_true
                    microF1_PRED_wet += f1_pred

                if gold == pred:
                    acc += 1

                w += wer(gold, pred)
                ref.append(str(gold))
                hyp.append(str(pred))

            acc_avg += acc / float(len(data_dev[1]))
            wer_avg += w / float(len(data_dev[1]))
            pbar.set_description("R:{:.4f},W:{:.4f}".format(
                acc_avg / float(len(dev)), wer_avg / float(len(dev))))
        # NOTE(review): `data_dev` below is the last batch from the loop;
        # this assumes every batch exposes the same number of fields.
        if len(data_dev) > 10:
            logging.info(
                "F1 SCORE:\t" +
                str(f1_score(microF1_TRUE, microF1_PRED, average='micro')))
            logging.info("F1 CAL:\t" + str(
                f1_score(microF1_TRUE_cal, microF1_PRED_cal, average='micro')))
            logging.info("F1 WET:\t" + str(
                f1_score(microF1_TRUE_wet, microF1_PRED_wet, average='micro')))
            logging.info("F1 NAV:\t" + str(
                f1_score(microF1_TRUE_nav, microF1_PRED_nav, average='micro')))

        if BLEU:
            bleu_score = moses_multi_bleu(np.array(hyp),
                                          np.array(ref),
                                          lowercase=True)
            logging.info("BLEU SCORE:" + str(bleu_score))

            if bleu_score >= avg_best:
                self.save_model(str(self.name) + str(bleu_score))
                logging.info("MODEL SAVED")
            return bleu_score
        else:
            acc_avg = acc_avg / float(len(dev))
            if acc_avg >= avg_best:
                self.save_model(str(self.name) + str(acc_avg))
                logging.info("MODEL SAVED")
            return acc_avg
Example #4
0
    def evaluate(self, dev, avg_best, BLEU=False):
        """Evaluate the model: accuracy, WER, dialog accuracy (babi),
        entity precision/recall F1 via compute_prf, and corpus BLEU.

        Args:
            dev: iterable of dev batches. Fields 0-6 feed evaluate_batch and
                data_dev[7] is the gold response per example. For 'kvr',
                fields 8-11 hold gold entities and field 14 the KB info;
                for 'babi' task 6, fields 10 and 12 are used and field 11 is
                the dialog ID (assumed from the indexing — confirm against
                the data loader).
            avg_best: best score so far; the model is saved on >=.
            BLEU: if True, the returned/compared score is BLEU, else the
                average exact-match accuracy.

        Returns:
            float: BLEU score when BLEU is True, else average accuracy.
        """
        logging.info("STARTING EVALUATION")
        acc_avg = 0.0
        wer_avg = 0.0
        microF1_PRED, microF1_PRED_cal, microF1_PRED_nav, microF1_PRED_wet = 0, 0, 0, 0
        microF1_TRUE, microF1_TRUE_cal, microF1_TRUE_nav, microF1_TRUE_wet = 0, 0, 0, 0
        ref = []
        hyp = []
        dialog_acc_dict = {}

        # Build the global entity list consumed by compute_prf.
        if args['dataset'] == 'kvr':
            with open('data/KVR/kvret_entities.json') as f:
                global_entity = json.load(f)
                global_entity_list = []
                for key in global_entity.keys():
                    if key != 'poi':
                        global_entity_list += [
                            item.lower().replace(' ', '_')
                            for item in global_entity[key]
                        ]
                    else:
                        # POI entries are dicts; collect every field value.
                        for item in global_entity['poi']:
                            global_entity_list += [
                                item[k].lower().replace(' ', '_')
                                for k in item.keys()
                            ]
                global_entity_list = list(set(global_entity_list))
        else:
            if int(args["task"]) != 6:
                global_entity_list = entityList(
                    'data/dialog-bAbI-tasks/dialog-babi-kb-all.txt',
                    int(args["task"]))
            else:
                global_entity_list = entityList(
                    'data/dialog-bAbI-tasks/dialog-babi-task6-dstc2-kb.txt',
                    int(args["task"]))

        pbar = tqdm(enumerate(dev), total=len(dev))
        for j, data_dev in pbar:
            # The original branched on args['dataset'] here, but both
            # branches passed identical arguments — collapsed to one call.
            words = self.evaluate_batch(len(data_dev[1]), data_dev[0],
                                        data_dev[1], data_dev[2],
                                        data_dev[3], data_dev[4],
                                        data_dev[5], data_dev[6])

            acc = 0
            w = 0

            # `words` is time-major (T, B); transpose to iterate per example.
            for i, row in enumerate(np.transpose(words)):
                st = ''
                for e in row:
                    if e == '<EOS>':
                        break
                    st += e + ' '
                correct = data_dev[7][i]
                ### compute F1 SCORE
                st = st.lstrip().rstrip()
                correct = correct.lstrip().rstrip()
                if args['dataset'] == 'kvr':
                    f1_true, count = self.compute_prf(data_dev[8][i],
                                                      st.split(),
                                                      global_entity_list,
                                                      data_dev[14][i])
                    microF1_TRUE += f1_true
                    microF1_PRED += count
                    f1_true, count = self.compute_prf(data_dev[9][i],
                                                      st.split(),
                                                      global_entity_list,
                                                      data_dev[14][i])
                    microF1_TRUE_cal += f1_true
                    microF1_PRED_cal += count
                    f1_true, count = self.compute_prf(data_dev[10][i],
                                                      st.split(),
                                                      global_entity_list,
                                                      data_dev[14][i])
                    microF1_TRUE_nav += f1_true
                    microF1_PRED_nav += count
                    f1_true, count = self.compute_prf(data_dev[11][i],
                                                      st.split(),
                                                      global_entity_list,
                                                      data_dev[14][i])
                    microF1_TRUE_wet += f1_true
                    microF1_PRED_wet += count
                elif args['dataset'] == 'babi' and int(args["task"]) == 6:
                    f1_true, count = self.compute_prf(data_dev[10][i],
                                                      st.split(),
                                                      global_entity_list,
                                                      data_dev[12][i])
                    microF1_TRUE += f1_true
                    microF1_PRED += count

                if args['dataset'] == 'babi':
                    # Group per-turn hits by dialog so per-dialog accuracy
                    # can be computed after the loop. data_dev[11] is
                    # presumably the dialog ID — confirm against the loader.
                    if data_dev[11][i] not in dialog_acc_dict.keys():
                        dialog_acc_dict[data_dev[11][i]] = []
                    if correct == st:
                        acc += 1
                        dialog_acc_dict[data_dev[11][i]].append(1)
                    else:
                        dialog_acc_dict[data_dev[11][i]].append(0)
                else:
                    if correct == st:
                        acc += 1

                w += wer(correct, st)
                ref.append(str(correct))
                hyp.append(str(st))

            acc_avg += acc / float(len(data_dev[1]))
            wer_avg += w / float(len(data_dev[1]))
            pbar.set_description("R:{:.4f},W:{:.4f}".format(
                acc_avg / float(len(dev)), wer_avg / float(len(dev))))

        # Dialog accuracy: a dialog counts only when every one of its turns
        # matched the gold response exactly.
        if args['dataset'] == 'babi':
            dia_acc = 0
            for k in dialog_acc_dict.keys():
                if len(dialog_acc_dict[k]) == sum(dialog_acc_dict[k]):
                    dia_acc += 1
            logging.info("Dialog Accuracy:\t" +
                         str(dia_acc * 1.0 / len(dialog_acc_dict.keys())))

        # NOTE(review): raises ZeroDivisionError when no entities were
        # counted (microF1_PRED == 0) — kept as in the original.
        if args['dataset'] == 'kvr':
            logging.info("F1 SCORE:\t{}".format(microF1_TRUE /
                                                float(microF1_PRED)))
            logging.info("\tCAL F1:\t{}".format(microF1_TRUE_cal /
                                                float(microF1_PRED_cal)))
            logging.info("\tWET F1:\t{}".format(microF1_TRUE_wet /
                                                float(microF1_PRED_wet)))
            logging.info("\tNAV F1:\t{}".format(microF1_TRUE_nav /
                                                float(microF1_PRED_nav)))
        elif args['dataset'] == 'babi' and int(args["task"]) == 6:
            logging.info("F1 SCORE:\t{}".format(microF1_TRUE /
                                                float(microF1_PRED)))

        bleu_score = moses_multi_bleu(np.array(hyp),
                                      np.array(ref),
                                      lowercase=True)
        logging.info("BLEU SCORE:" + str(bleu_score))
        if BLEU:
            if bleu_score >= avg_best:
                self.save_model(str(self.name) + str(bleu_score))
                logging.info("MODEL SAVED")
            return bleu_score
        else:
            acc_avg = acc_avg / float(len(dev))
            if acc_avg >= avg_best:
                self.save_model(str(self.name) + str(acc_avg))
                logging.info("MODEL SAVED")
            return acc_avg
Example #5
0
File: mem2seq.py  Project: zxsted/mem2seq
    def evaluate(self, dev, avg_best, BLEU=False):
        """Evaluate the model: accuracy, WER, dialog accuracy (babi),
        entity F1, and corpus BLEU.

        Args:
            dev: iterable of dev batches. Fields 0-6 feed evaluate_batch and
                data_dev[7] is the gold response per example; fields 8-11
                hold 'kvr' gold entities. The trailing fields passed to
                evaluate_batch differ per dataset (-2/-1 for 'kvr',
                -4/-3 otherwise), and data_dev[-1] is presumably the dialog
                ID for 'babi' — confirm against the data loader.
            avg_best: best score so far; the model is saved on >=.
            BLEU: if True, the returned/compared score is BLEU, else the
                average exact-match accuracy.

        Returns:
            float: BLEU score when BLEU is True, else average accuracy.
        """
        logging.info("STARTING EVALUATION")
        acc_avg = 0.0
        wer_avg = 0.0
        microF1_PRED, microF1_PRED_cal, microF1_PRED_nav, microF1_PRED_wet = [], [], [], []
        microF1_TRUE, microF1_TRUE_cal, microF1_TRUE_nav, microF1_TRUE_wet = [], [], [], []
        ref = []
        hyp = []
        dialog_acc_dict = {}
        pbar = tqdm(enumerate(dev), total=len(dev))
        for j, data_dev in pbar:
            if args['dataset'] == 'kvr':
                words = self.evaluate_batch(len(data_dev[1]), data_dev[0],
                                            data_dev[1], data_dev[2],
                                            data_dev[3], data_dev[4],
                                            data_dev[5], data_dev[6],
                                            data_dev[-2], data_dev[-1])
            else:
                words = self.evaluate_batch(len(data_dev[1]), data_dev[0],
                                            data_dev[1], data_dev[2],
                                            data_dev[3], data_dev[4],
                                            data_dev[5], data_dev[6],
                                            data_dev[-4], data_dev[-3])
            acc = 0
            w = 0

            # `words` is time-major (T, B); transpose to iterate per example.
            for i, row in enumerate(np.transpose(words)):
                st = ''
                for e in row:
                    if e == '<EOS>':
                        break
                    st += e + ' '
                pred = st.lstrip().rstrip()
                gold = data_dev[7][i].lstrip().rstrip()
                ### compute F1 SCORE
                if args['dataset'] == 'kvr':
                    f1_true, f1_pred = computeF1(data_dev[8][i], pred, gold)
                    microF1_TRUE += f1_true
                    microF1_PRED += f1_pred
                    f1_true, f1_pred = computeF1(data_dev[9][i], pred, gold)
                    microF1_TRUE_cal += f1_true
                    microF1_PRED_cal += f1_pred
                    f1_true, f1_pred = computeF1(data_dev[10][i], pred, gold)
                    microF1_TRUE_nav += f1_true
                    microF1_PRED_nav += f1_pred
                    f1_true, f1_pred = computeF1(data_dev[11][i], pred, gold)
                    microF1_TRUE_wet += f1_true
                    microF1_PRED_wet += f1_pred
                elif args['dataset'] == 'babi' and int(self.task) == 6:
                    f1_true, f1_pred = computeF1(data_dev[-2][i], pred, gold)
                    microF1_TRUE += f1_true
                    microF1_PRED += f1_pred

                if args['dataset'] == 'babi':
                    # Group per-turn hits by dialog ID for the per-dialog
                    # accuracy computed after the loop.
                    if data_dev[-1][i] not in dialog_acc_dict.keys():
                        dialog_acc_dict[data_dev[-1][i]] = []
                    if gold == pred:
                        acc += 1
                        dialog_acc_dict[data_dev[-1][i]].append(1)
                    else:
                        dialog_acc_dict[data_dev[-1][i]].append(0)
                else:
                    if gold == pred:
                        acc += 1

                w += wer(gold, pred)
                ref.append(str(gold))
                hyp.append(str(pred))

            acc_avg += acc / float(len(data_dev[1]))
            wer_avg += w / float(len(data_dev[1]))
            pbar.set_description("R:{:.4f},W:{:.4f}".format(
                acc_avg / float(len(dev)), wer_avg / float(len(dev))))

        # Dialog accuracy: a dialog counts only when every one of its turns
        # matched the gold response exactly.
        if args['dataset'] == 'babi':
            dia_acc = 0
            for k in dialog_acc_dict.keys():
                if len(dialog_acc_dict[k]) == sum(dialog_acc_dict[k]):
                    dia_acc += 1
            logging.info("Dialog Accuracy:\t" +
                         str(dia_acc * 1.0 / len(dialog_acc_dict.keys())))

        if args['dataset'] == 'kvr':
            logging.info(
                "F1 SCORE:\t" +
                str(f1_score(microF1_TRUE, microF1_PRED, average='micro')))
            logging.info("F1 CAL:\t" + str(
                f1_score(microF1_TRUE_cal, microF1_PRED_cal, average='micro')))
            logging.info("F1 WET:\t" + str(
                f1_score(microF1_TRUE_wet, microF1_PRED_wet, average='micro')))
            logging.info("F1 NAV:\t" + str(
                f1_score(microF1_TRUE_nav, microF1_PRED_nav, average='micro')))
        elif args['dataset'] == 'babi' and int(self.task) == 6:
            logging.info(
                "F1 SCORE:\t" +
                str(f1_score(microF1_TRUE, microF1_PRED, average='micro')))

        bleu_score = moses_multi_bleu(np.array(hyp),
                                      np.array(ref),
                                      lowercase=True)
        logging.info("BLEU SCORE:" + str(bleu_score))
        if BLEU:
            if bleu_score >= avg_best:
                self.save_model(str(self.name) + str(bleu_score))
                logging.info("MODEL SAVED")
            return bleu_score
        else:
            acc_avg = acc_avg / float(len(dev))
            if acc_avg >= avg_best:
                self.save_model(str(self.name) + str(acc_avg))
                logging.info("MODEL SAVED")
            return acc_avg
Example #6
0
    def evaluate(self, dev, avg_best, BLEU=False):
        """Evaluate the model on the dev set.

        Computes exact-match accuracy, WER, dialog accuracy ('babi'),
        entity F1 ('kvr' or 'babi' task 6), and corpus BLEU; saves the
        model when the chosen score (BLEU if BLEU is True, else average
        accuracy) is >= avg_best, and returns that score.
        """
        logging.info("STARTING EVALUATION")
        acc_avg = 0.0
        wer_avg = 0.0
        bleu_avg = 0.0  # NOTE(review): assigned but never used below
        acc_P = 0.0  # NOTE(review): unused
        acc_V = 0.0  # NOTE(review): unused
        microF1_PRED,microF1_PRED_cal,microF1_PRED_nav,microF1_PRED_wet = [],[],[],[]
        microF1_TRUE,microF1_TRUE_cal,microF1_TRUE_nav,microF1_TRUE_wet = [],[],[],[]
        # accumulated over the whole eval dataset
        ref = []
        hyp = []
        ref_s = ""  # NOTE(review): built below but never consumed
        hyp_s = ""  # NOTE(review): built below but never consumed
        dialog_acc_dict = {}  # accumulated over the whole eval dataset
        pbar = tqdm(enumerate(dev), total=len(dev))
        for j, data_dev in pbar:
            if args['dataset'] == 'kvr':
                '''
                batch_size,
                input_batches, input_lengths,
                target_batches, target_lengths,
                target_index,target_gate,
                src_plain,
                conv_seqs, conv_lengths'''
                # output shape (T,B)
                words = self.evaluate_batch(len(data_dev[1]), data_dev[0],
                                            data_dev[1], data_dev[2],
                                            data_dev[3], data_dev[4],
                                            data_dev[5], data_dev[6],
                                            data_dev[-2], data_dev[-1])
            else:
                words = self.evaluate_batch(len(data_dev[1]), data_dev[0],
                                            data_dev[1], data_dev[2],
                                            data_dev[3], data_dev[4],
                                            data_dev[5], data_dev[6],
                                            data_dev[-4], data_dev[-3])
            # acc_P += acc_ptr
            # acc_V += acc_vac
            acc = 0  # counted within one batch
            w = 0
            temp_gen = []

            # Permute the dimensions of an array
            for i, row in enumerate(np.transpose(words)):  # shape (B,T)
                st = ''
                for e in row:
                    if e == '<EOS>': break
                    else: st += e + ' '
                temp_gen.append(st)
                # data_dev[7] may be the correct sentences; shape(B,T)
                correct = data_dev[7][i]  # this is response sentences
                # compute F1 SCORE
                if args['dataset'] == 'kvr':
                    f1_true, f1_pred = computeF1(data_dev[8][i],
                                                 st.lstrip().rstrip(),
                                                 correct.lstrip().rstrip())
                    microF1_TRUE += f1_true
                    microF1_PRED += f1_pred  # author's note: all ones? presumably the denominator — possibly redundant
                    f1_true, f1_pred = computeF1(data_dev[9][i],
                                                 st.lstrip().rstrip(),
                                                 correct.lstrip().rstrip())
                    microF1_TRUE_cal += f1_true
                    microF1_PRED_cal += f1_pred
                    f1_true, f1_pred = computeF1(data_dev[10][i],
                                                 st.lstrip().rstrip(),
                                                 correct.lstrip().rstrip())
                    microF1_TRUE_nav += f1_true
                    microF1_PRED_nav += f1_pred
                    f1_true, f1_pred = computeF1(data_dev[11][i],
                                                 st.lstrip().rstrip(),
                                                 correct.lstrip().rstrip())
                    microF1_TRUE_wet += f1_true
                    microF1_PRED_wet += f1_pred
                elif args['dataset'] == 'babi' and int(self.task) == 6:
                    f1_true, f1_pred = computeF1(data_dev[-2][i],
                                                 st.lstrip().rstrip(),
                                                 correct.lstrip().rstrip())
                    microF1_TRUE += f1_true
                    microF1_PRED += f1_pred

                if args['dataset'] == 'babi':
                    # data_dev[-1] presumably holds dialog IDs — confirm against loader
                    if data_dev[-1][i] not in dialog_acc_dict.keys():
                        dialog_acc_dict[data_dev[-1][i]] = []
                    if (correct.lstrip().rstrip() == st.lstrip().rstrip()):
                        acc += 1  # counted within one batch
                        dialog_acc_dict[data_dev[-1][i]].append(1)
                    else:  # accumulated over the whole eval dataset
                        dialog_acc_dict[data_dev[-1][i]].append(0)
                else:
                    if (correct.lstrip().rstrip() == st.lstrip().rstrip()):
                        acc += 1
                #    print("Correct:"+str(correct.lstrip().rstrip()))
                #    print("\tPredict:"+str(st.lstrip().rstrip()))
                #    print("\tFrom:"+str(self.from_whichs[:,i]))

                w += wer(correct.lstrip().rstrip(), st.lstrip().rstrip())
                ref.append(str(correct.lstrip().rstrip()))
                hyp.append(str(st.lstrip().rstrip()))
                ref_s += str(correct.lstrip().rstrip()) + "\n"
                hyp_s += str(st.lstrip().rstrip()) + "\n"

            acc_avg += acc / float(len(
                data_dev[1]))  # len(data_dev[1]) = batch_size
            wer_avg += w / float(len(data_dev[1]))  # len(dev) = num of batches
            # TODO: this looks slightly off; dividing by the number of batches seen so far (j) would be more reasonable
            pbar.set_description("R:{:.4f},W:{:.4f}".format(
                acc_avg / float(len(dev)), wer_avg / float(len(dev))))

        # dialog accuracy
        if args['dataset'] == 'babi':
            dia_acc = 0
            for k in dialog_acc_dict.keys():
                if len(dialog_acc_dict[k]) == sum(dialog_acc_dict[k]):
                    dia_acc += 1
            logging.info("Dialog Accuracy:\t" +
                         str(dia_acc * 1.0 / len(dialog_acc_dict.keys())))

        if args['dataset'] == 'kvr':
            logging.info(
                "F1 SCORE:\t" +
                str(f1_score(microF1_TRUE, microF1_PRED, average='micro')))
            logging.info("F1 CAL:\t" + str(
                f1_score(microF1_TRUE_cal, microF1_PRED_cal, average='micro')))
            logging.info("F1 WET:\t" + str(
                f1_score(microF1_TRUE_wet, microF1_PRED_wet, average='micro')))
            logging.info("F1 NAV:\t" + str(
                f1_score(microF1_TRUE_nav, microF1_PRED_nav, average='micro')))
        elif args['dataset'] == 'babi' and int(self.task) == 6:
            logging.info(
                "F1 SCORE:\t" +
                str(f1_score(microF1_TRUE, microF1_PRED, average='micro')))

        bleu_score = moses_multi_bleu(np.array(hyp),
                                      np.array(ref),
                                      lowercase=True)
        logging.info("BLEU SCORE:" + str(bleu_score))
        if (BLEU):
            if (bleu_score >= avg_best):
                self.save_model(str(self.name) + str(bleu_score))
                logging.info("MODEL SAVED")
            return bleu_score
        else:
            acc_avg = acc_avg / float(len(dev))
            if (acc_avg >= avg_best):
                self.save_model(str(self.name) + str(acc_avg))
                logging.info("MODEL SAVED")
            return acc_avg
Example #7
0
    def evaluate(self, dev, avg_best, BLEU=False):
        """Decode the ``dev`` set, report WER/BLEU, and checkpoint on BLEU.

        Args:
            dev: iterable of evaluation batches. Each ``data_dev`` is an
                indexable batch where indices 0-6 feed ``evaluate_batch``,
                index 6 holds the tokenized user queries, and index 7 holds
                the tokenized gold responses. (Assumed from usage here --
                confirm against the data loader.)
            avg_best: best score seen so far; the model is saved whenever the
                new BLEU score is >= this value.
            BLEU: kept only for interface compatibility -- this evaluator
                always computes and returns BLEU (see note at the bottom).

        Returns:
            float: corpus-level BLEU of the decoded hypotheses vs. references.

        Raises:
            ValueError: if ``args['dataset']`` is not ``'chitchat'``; this
                evaluator only supports the chitchat data.
        """
        logging.info("STARTING EVALUATION")

        query, ref, hyp = [], [], []
        query_s, ref_s, hyp_s = "", "", ""
        wer_avg = 0.0

        # Chitchat-only evaluator: fail fast on any other dataset. Because of
        # this guard, the original kvr/babi branches further down were
        # unreachable and have been removed.
        if args['dataset'] != 'chitchat':
            raise ValueError("Must use chitchat data.")

        pbar = tqdm(enumerate(dev), total=len(dev))
        for j, data_dev in pbar:
            # NOTE(review): the original branched on args['dataset'] == 'kvr'
            # with two byte-identical calls; collapsed into one.
            words = self.evaluate_batch(len(data_dev[1]), data_dev[0], data_dev[1],
                                        data_dev[2], data_dev[3], data_dev[4],
                                        data_dev[5], data_dev[6])

            w = 0
            # Decoder output is (time, batch); transpose to walk one example
            # (column) at a time.
            for i, row in enumerate(np.transpose(words)):
                st = ''
                for e in row:
                    if e == '<EOS>':
                        break
                    st += e + ' '
                st = st.lstrip().rstrip()
                correct = " ".join(data_dev[7][i]).lstrip().rstrip()
                user_query = " ".join(data_dev[6][i]).strip()

                w += wer(correct, st)
                ref.append(str(correct))
                hyp.append(str(st))
                query.append(str(user_query))
                ref_s += str(correct) + "\n"
                hyp_s += str(st) + "\n"
                query_s += str(user_query) + "\n"

            # Per-batch WER, averaged over batches for the progress display.
            wer_avg += w / float(len(data_dev[1]))
            pbar.set_description("W:{:.4f}".format(wer_avg / float(len(dev))))

        bleu_score = moses_multi_bleu(np.array(hyp), np.array(ref), lowercase=True)
        logging.info("BLEU SCORE:" + str(bleu_score))

        # Always select/save by BLEU regardless of the ``BLEU`` flag (the
        # original forced BLEU = True here); the parameter is retained so
        # existing call sites keep working.
        if bleu_score >= avg_best:
            self.save_model(str(self.name) + str(bleu_score))
            logging.info("MODEL SAVED")
        self.save_decode_results(query, ref, hyp, bleu_score)

        return bleu_score
        '''