Exemplo n.º 1
0
def prepare_data_seq(task, batch_size=100, shuffle=True):
    """Read the dialog bAbI splits for ``task`` and wrap each one with ``get_seq``.

    Args:
        task: Task number (anything ``int()`` accepts). It is formatted into
            the data file names and passed to ``entityList`` as an int.
        batch_size: Forwarded unchanged to every ``get_seq`` call.
        shuffle: Accepted for interface compatibility only; this function
            never reads it. The train split is always passed ``True`` and
            the other splits ``False`` as ``get_seq``'s fourth argument.

    Returns:
        A tuple ``(train, dev, test, testOOV, lang, max_len, max_r)``.
        ``testOOV`` is an empty list for task 6, which has no OOV test file.
        ``max_len`` and ``max_r`` are the largest values reported by
        ``read_langs`` over all splits that were read, plus one.
    """
    task_id = int(task)
    # Task 6 uses its own knowledge-base file and has no "tst-OOV" split.
    has_oov_split = task_id != 6

    prefix = 'data/dialog-bAbI-tasks/dialog-babi-task{}'.format(task)
    file_train = prefix + 'trn.txt'
    file_dev = prefix + 'dev.txt'
    file_test = prefix + 'tst.txt'

    if has_oov_split:
        kb_path = 'data/dialog-bAbI-tasks/dialog-babi-kb-all.txt'
    else:
        kb_path = 'data/dialog-bAbI-tasks/dialog-babi-task6-dstc2-kb.txt'
    ent = entityList(kb_path, task_id)

    pair_train, max_len_train, max_r_train = read_langs(
        file_train, ent, max_line=None)
    pair_dev, max_len_dev, max_r_dev = read_langs(
        file_dev, ent, max_line=None)
    pair_test, max_len_test, max_r_test = read_langs(
        file_test, ent, max_line=None)

    # Zero defaults keep the max() calls below valid when no OOV split exists.
    pair_test_OOV = None
    max_len_test_OOV = 0
    max_r_test_OOV = 0
    if has_oov_split:
        pair_test_OOV, max_len_test_OOV, max_r_test_OOV = read_langs(
            prefix + 'tst-OOV.txt', ent, max_line=None)

    max_len = max(max_len_train, max_len_dev, max_len_test,
                  max_len_test_OOV) + 1
    max_r = max(max_r_train, max_r_dev, max_r_test, max_r_test_OOV) + 1

    # One Lang instance is shared by every split.
    # NOTE(review): the boolean 4th argument of get_seq presumably marks the
    # training split -- confirm against get_seq before tying it to `shuffle`.
    lang = Lang()
    train = get_seq(pair_train, lang, batch_size, True, max_len)
    dev = get_seq(pair_dev, lang, batch_size, False, max_len)
    test = get_seq(pair_test, lang, batch_size, False, max_len)
    if has_oov_split:
        testOOV = get_seq(pair_test_OOV, lang, batch_size, False, max_len)
    else:
        testOOV = []

    logging.info("Read %s sentence pairs train", len(pair_train))
    logging.info("Read %s sentence pairs dev", len(pair_dev))
    logging.info("Read %s sentence pairs test", len(pair_test))
    if has_oov_split:
        # Labelled separately so it cannot be confused with the regular test split.
        logging.info("Read %s sentence pairs test OOV", len(pair_test_OOV))
    logging.info("Max len Input %s ", max_len)
    logging.info("Vocab_size %s ", lang.n_words)
    logging.info("USE_CUDA={}".format(USE_CUDA))

    return train, dev, test, testOOV, lang, max_len, max_r
Exemplo n.º 2
0
    return lang, max_len, max_r

if __name__=='__main__':
    '''-lr=0.001 -layer=1 -hdd=12 -dr=0.0 -dec=Mem2Seq -bsz=2 -ds=babi -t=1 '''
    # args['task'] = 5
    # args['hidden'] = 12
    # args['learn'] = 0.001
    # args['layer'] = 1
    # args['drop'] = 0.0
    # args['decoder'] = 'Men2Seq'
    # args['batch'] = 2
    # args['dataset'] = 'babi'
    task = 5
    file_train = '../data/dialog-bAbI-tasks/dialog-babi-task{}trn.txt'.format(task)
    ent = entityList('../data/dialog-bAbI-tasks/dialog-babi-kb-all.txt', int(task))

    pair_train,max_len_train, max_r_train = read_langs(file_train, ent, max_line=None)

    max_len = max_len_train + 1
    batch_size = 100
    lang = Lang()

    train = get_seq(pair_train, lang, batch_size, True, max_len)

    epoch = 1
    logging.info("Epoch:{}".format(epoch))
    # Run the train function
    import tqdm
    pbar = tqdm(enumerate(train), total=len(train))
    for i, data in pbar:
Exemplo n.º 3
0
    def evaluate(self, dev, avg_best, BLEU=False):
        """Score the model on ``dev`` and save it when the chosen metric is at least ``avg_best``.

        For every batch the generated words (from ``self.evaluate_batch``) are
        cut at the first ``'<EOS>'`` token, joined with spaces and compared to
        the reference string in ``data_dev[7]``. Per-response accuracy, WER
        and BLEU are computed for every dataset. Entity F1 is logged for
        ``'kvr'`` and for bAbI task 6, and per-dialog accuracy is logged
        for ``'babi'``.

        Args:
            dev: Sized iterable of batches; each batch is indexed positionally
                (fields 0-6 feed ``evaluate_batch``, field 7 holds the
                reference strings; higher indices are dataset specific).
            avg_best: Best metric value seen so far. The model is saved when
                the new value is greater than or equal to it.
            BLEU: When true, the BLEU score is the selection metric and the
                return value; otherwise the mean per-batch accuracy is.

        Returns:
            The BLEU score if ``BLEU`` is true, else the accuracy averaged
            over ``len(dev)`` batches.

        Raises:
            ZeroDivisionError: If ``dev`` is empty and accuracy (or bAbI
                dialog accuracy) has to be averaged.
        """
        logging.info("STARTING EVALUATION")

        def _ratio(numerator, denominator):
            # An F1 bucket may have no counted examples; report 0.0 rather
            # than abort the whole evaluation with a ZeroDivisionError.
            return numerator / float(denominator) if denominator else 0.0

        dataset = args['dataset']
        is_kvr = dataset == 'kvr'
        is_babi = dataset == 'babi'
        # Short-circuits so args["task"] is only read for the bAbI dataset.
        is_babi_task6 = is_babi and int(args["task"]) == 6

        acc_avg = 0.0
        wer_avg = 0.0
        # Accumulated first / second return values of compute_prf, per bucket.
        f1_total = {'all': 0, 'cal': 0, 'nav': 0, 'wet': 0}
        f1_count = {'all': 0, 'cal': 0, 'nav': 0, 'wet': 0}
        # (bucket, batch field passed as compute_prf's first argument) for KVR.
        kvr_f1_fields = (('all', 8), ('cal', 9), ('nav', 10), ('wet', 11))
        ref = []
        hyp = []
        dialog_acc_dict = {}

        if is_kvr:
            with open('data/KVR/kvret_entities.json') as f:
                global_entity = json.load(f)
            global_entity_list = []
            for key in global_entity.keys():
                if key != 'poi':
                    global_entity_list += [
                        item.lower().replace(' ', '_')
                        for item in global_entity[key]
                    ]
                else:
                    # 'poi' entries are mappings; every value is an entity.
                    for item in global_entity['poi']:
                        global_entity_list += [
                            item[k].lower().replace(' ', '_')
                            for k in item.keys()
                        ]
            global_entity_list = list(set(global_entity_list))
        elif int(args["task"]) != 6:
            global_entity_list = entityList(
                'data/dialog-bAbI-tasks/dialog-babi-kb-all.txt',
                int(args["task"]))
        else:
            global_entity_list = entityList(
                'data/dialog-bAbI-tasks/dialog-babi-task6-dstc2-kb.txt',
                int(args["task"]))

        pbar = tqdm(enumerate(dev), total=len(dev))
        for _, data_dev in pbar:
            batch_len = len(data_dev[1])
            words = self.evaluate_batch(batch_len, data_dev[0],
                                        data_dev[1], data_dev[2],
                                        data_dev[3], data_dev[4],
                                        data_dev[5], data_dev[6])

            acc = 0
            w = 0

            # Transposing puts one generated response on each row.
            for i, row in enumerate(np.transpose(words)):
                tokens = []
                for e in row:
                    if e == '<EOS>':
                        break
                    tokens.append(e)
                st = ' '.join(tokens).strip()
                correct = data_dev[7][i].strip()

                ### compute F1 SCORE
                if is_kvr:
                    for bucket, field in kvr_f1_fields:
                        f1_true, count = self.compute_prf(data_dev[field][i],
                                                          st.split(),
                                                          global_entity_list,
                                                          data_dev[14][i])
                        f1_total[bucket] += f1_true
                        f1_count[bucket] += count
                elif is_babi_task6:
                    f1_true, count = self.compute_prf(data_dev[10][i],
                                                      st.split(),
                                                      global_entity_list,
                                                      data_dev[12][i])
                    f1_total['all'] += f1_true
                    f1_count['all'] += count

                is_match = correct == st
                if is_match:
                    acc += 1
                if is_babi:
                    # data_dev[11][i] is the key that groups responses into
                    # dialogs; a dialog counts as correct only if every one
                    # of its responses matched.
                    dialog_acc_dict.setdefault(data_dev[11][i], []).append(
                        1 if is_match else 0)

                w += wer(correct, st)
                ref.append(str(correct))
                hyp.append(str(st))

            acc_avg += acc / float(batch_len)
            wer_avg += w / float(batch_len)
            pbar.set_description("R:{:.4f},W:{:.4f}".format(
                acc_avg / float(len(dev)), wer_avg / float(len(dev))))

        # dialog accuracy
        if is_babi:
            dia_acc = sum(1 for results in dialog_acc_dict.values()
                          if len(results) == sum(results))
            logging.info("Dialog Accuracy:\t" +
                         str(dia_acc * 1.0 / len(dialog_acc_dict)))

        if is_kvr:
            logging.info("F1 SCORE:\t{}".format(
                _ratio(f1_total['all'], f1_count['all'])))
            logging.info("\tCAL F1:\t{}".format(
                _ratio(f1_total['cal'], f1_count['cal'])))
            logging.info("\tWET F1:\t{}".format(
                _ratio(f1_total['wet'], f1_count['wet'])))
            logging.info("\tNAV F1:\t{}".format(
                _ratio(f1_total['nav'], f1_count['nav'])))
        elif is_babi_task6:
            logging.info("F1 SCORE:\t{}".format(
                _ratio(f1_total['all'], f1_count['all'])))

        bleu_score = moses_multi_bleu(np.array(hyp),
                                      np.array(ref),
                                      lowercase=True)
        logging.info("BLEU SCORE:" + str(bleu_score))
        if BLEU:
            if bleu_score >= avg_best:
                self.save_model(str(self.name) + str(bleu_score))
                logging.info("MODEL SAVED")
            return bleu_score

        acc_avg = acc_avg / float(len(dev))
        if acc_avg >= avg_best:
            self.save_model(str(self.name) + str(acc_avg))
            logging.info("MODEL SAVED")
        return acc_avg