Example 1
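
import pickle

import numpy as np
import tensorflow as tf

# Assumed, not shown in this snippet: the Char_Seq2Seq model class, the
# next_batch helper, and the Bleu / Rouge / Meteor / Cider scorers
# (coco-caption style metric classes).
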
def main():
    SOS = 0
    EOS = 1
    PAD = 2
    UNK = 3
    epoch = 20
    batches_in_epoch = 100
    batch_size = 64
    num_units = 200
    char_dim = 50  ## size of the character embedding
    char_hidden_units = 100  ## hidden size of the char LSTM (word embedding size: 2 * char_hidden_units)
    epoch_print = 2
    Bidirection = False
    Attention = False
    Embd_train = False
    model_type = 'testing'

    Dataset_name = "YAHOO"
    # Alternative Dataset_name1 settings:
    # Dataset_name1 = "wrongorder1"
    # Dataset_name1 = "back"
    Dataset_name1 = "threeopt"

    print "this is the " + Dataset_name1 + " char"

    logs_path = "/mnt/WDRed4T/ye/Qrefine/ckpt/" + Dataset_name + "/seq2seq_BIA_board"
    # FileName = "/mnt/WDRed4T/ye/DataR/" + Dataset_name + "/wrongword_Id_1"
    FileName = '/mnt/WDRed4T/ye/DataR/' + Dataset_name + '/' + Dataset_name1 + "_final"
    input_file = open(FileName, 'rb')  # renamed to avoid shadowing the built-in input()
    data_com = pickle.load(input_file)

    train_noisy_Id = data_com["train_noisy_Id"]
    test_noisy_Id = data_com["test_noisy_Id"]
    eval_noisy_Id = data_com["eval_noisy_Id"]

    train_noisy_len = data_com["train_noisy_len"]
    test_noisy_len = data_com["test_noisy_len"]
    eval_noisy_len = data_com["eval_noisy_len"]

    train_noisy_char_Id = data_com["train_noisy_char_Id"]
    test_noisy_char_Id = data_com["test_noisy_char_Id"]
    eval_noisy_char_Id = data_com["eval_noisy_char_Id"]

    train_noisy_char_len = data_com["train_noisy_char_len"]
    test_noisy_char_len = data_com["test_noisy_char_len"]
    eval_noisy_char_len = data_com["eval_noisy_char_len"]

    train_target_Id = data_com["train_ground_truth"]
    test_target_Id = data_com["test_ground_truth"]
    eval_target_Id = data_com["eval_ground_truth"]

    train_input_Id = data_com["train_input_Id"]
    test_input_Id = data_com["test_input_Id"]
    eval_input_Id = data_com["eval_input_Id"]

    train_clean_len = data_com["train_clean_len"]
    test_clean_len = data_com["test_clean_len"]
    eval_clean_len = data_com["eval_clean_len"]

    max_char = data_com['max_char']
    char_num = 44
    max_word = data_com['max_word']

    FileName = '/mnt/WDRed4T/ye/DataR/' + Dataset_name + '/' + Dataset_name1 + '_vocab_embd'
    input_file = open(FileName, 'rb')
    vocab_embd = pickle.load(input_file)
    vocab = vocab_embd['vocab']
    embd = vocab_embd['embd']
    embd = np.asarray(embd)
    vocdic = zip(vocab, range(len(vocab)))
    index_voc = dict((index, word) for word, index in vocdic)  ## index_voc[index] = word
    voc_index = dict((word, index) for word, index in vocdic)  ## voc_index[word] = index

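    # Override the default special-token ids with the ids found in the loaded vocabulary.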
    SOS = voc_index[u"<SOS>"]
    EOS = voc_index[u"<EOS>"]
    PAD = voc_index[u"<PAD>"]

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    print "the number of training question is:", len(train_noisy_Id), len(
        train_noisy_char_Id), len(train_noisy_char_len)

    print "the number of eval question is:", len(eval_noisy_Id), len(
        eval_noisy_char_Id), len(eval_noisy_char_len)

    vocab_size = len(vocab)

    Bleu_obj = Bleu()
    Rouge_obj = Rouge()
    Meteor_obj = Meteor()
    cIDER_obj = Cider()

    if model_type == 'training':
        Seq2Seq_model = Char_Seq2Seq(batch_size, vocab_size, num_units, embd,
                                     model_type, Bidirection, Embd_train,
                                     char_hidden_units, Attention, char_num,
                                     char_dim)

        # print "the number of evaluate question is:", len(eval_noisy_Id)

        patience_cnt = 0

        summary_writer = tf.summary.FileWriter(logs_path,
                                               graph=tf.get_default_graph())
        merge = tf.summary.merge_all()
        saver = tf.train.Saver(sharded=False)

        with tf.Session(config=config) as sess:
            # sess.run(tf.global_variables_initializer())
            saver.restore(
                sess, "/mnt/WDRed4T/ye/Qrefine/ckpt/CharS2S/" + Dataset_name +
                "/" + Dataset_name1 + "_" + str(Bidirection) + "_Att_" +
                str(Attention) + "_Emb_" + str(Embd_train) + "_hiddenunits_" +
                str(num_units) + "_bleu_0.491625950768")

            val_loss_epo = []
            # print [v for v in tf.trainable_variables()]
            for epo in range(epoch):
                idx = np.arange(0, len(train_noisy_Id))
                idx = list(np.random.permutation(idx))

                print " Epoch {}".format(epo)

                Bleu_score1 = []
                Bleu_score2 = []
                Bleu_score3 = []
                Bleu_score4 = []
                Rouge_score = []
                Meteor_score = []
                Cider_score = []

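                # One pass over the shuffled training set, one mini-batch at a time.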
                for batch in range(len(train_noisy_Id) / batch_size):
                    source_shuffle, source_len, char_Id, char_len, train_shuffle, target_shuffle, target_len = next_batch(
                        train_noisy_Id, train_noisy_len, train_noisy_char_Id,
                        train_noisy_char_len, train_input_Id, train_target_Id,
                        train_clean_len, batch_size, batch, idx)
                    source_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                        source_shuffle,
                        maxlen=max_word,
                        padding='post',
                        value=EOS)
                    train_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                        train_shuffle,
                        maxlen=max_word + 1,
                        padding='post',
                        value=EOS)

                    target_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                        target_shuffle,
                        maxlen=max_word + 1,
                        padding='post',
                        value=EOS)

                    whole_noisy_char_Id = []
                    for sen_char in char_Id:
                        sen_len = len(sen_char)
                        for i in range(len(source_shuffle[0]) - sen_len):
                            sen_char.append(
                                [0] * max_char
                            )  ## pad char ids so every sentence has max_word word slots
                        whole_noisy_char_Id.append(sen_char)

                    whole_noisy_char_Id = np.asarray(whole_noisy_char_Id)
                    # print whole_noisy_char_Id.shape
                    #
                    # print len(char_len[0])

                    # target_len = np.tile(max(target_len) + 1, batch_size)

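                    # Feed one padded training batch into the graph.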
                    fd = {
                        Seq2Seq_model.encoder_inputs: source_shuffle,
                        Seq2Seq_model.encoder_inputs_length: source_len,
                        Seq2Seq_model.encoder_char_ids: whole_noisy_char_Id,
                        Seq2Seq_model.encoder_char_len: char_len,
                        Seq2Seq_model.decoder_inputs: train_shuffle,
                        Seq2Seq_model.decoder_length: target_len,
                        Seq2Seq_model.ground_truth: target_shuffle,
                        Seq2Seq_model.dropout_rate: 1
                    }

                    mask, logit_id, logit_id_pad, l_fun, pro_op = sess.run([
                        Seq2Seq_model.mask, Seq2Seq_model.ids,
                        Seq2Seq_model.logits_padded,
                        Seq2Seq_model.loss_seq2seq, Seq2Seq_model.train_op
                    ], fd)

                    # mask, logit_id, logit_id_pad= sess.run(
                    #     [Seq2Seq_model.mask, Seq2Seq_model.ids, Seq2Seq_model.logits_padded], fd)
                    # l_fun =0

                    for t in range(batch_size):
                        ref = []
                        hyp = []
                        for tar_id in target_shuffle[t]:
                            if tar_id != EOS and tar_id != PAD:
                                ref.append(vocab[tar_id])
                                # print vocab[tar_id]
                        for pre_id in logit_id[t]:
                            if pre_id != EOS and pre_id != PAD:
                                hyp.append(vocab[pre_id])

                        hyp_sen = u" ".join(hyp).encode('utf-8')
                        ref_sen = u" ".join(ref).encode('utf-8')
                        dic_hyp = {}
                        dic_hyp[0] = [hyp_sen]
                        dic_ref = {}
                        dic_ref[0] = [ref_sen]
                        sen_bleu, _ = Bleu_obj.compute_score(dic_ref, dic_hyp)
                        sen_rouge = Rouge_obj.compute_score(dic_ref, dic_hyp)
                        sen_meteor, _ = Meteor_obj.compute_score(
                            dic_ref, dic_hyp)
                        sen_cider, _ = cIDER_obj.compute_score(
                            dic_ref, dic_hyp)
                        Bleu_score1.append(sen_bleu[0])
                        Bleu_score2.append(sen_bleu[1])
                        Bleu_score3.append(sen_bleu[2])
                        Bleu_score4.append(sen_bleu[3])
                        Rouge_score.append(sen_rouge[0])
                        Meteor_score.append(sen_meteor)
                        Cider_score.append(sen_cider)

                    if batch == 0 or batch % batches_in_epoch == 0:
                        ##print the training
                        print('batch {}'.format(batch))
                        print('   minibatch loss: {}'.format(l_fun))

                        for t in xrange(3):
                            print 'Training Question {}'.format(t)
                            print "NQ: " + " ".join(
                                map(lambda i: vocab[i], list(
                                    source_shuffle[t]))).strip().replace(
                                        "<EOS>", " ")
                            print "CQ: " + " ".join(
                                map(lambda i: vocab[i], list(
                                    target_shuffle[t]))).strip().replace(
                                        "<EOS>", " ")
                            print "GQ: " + " ".join(
                                map(lambda i: vocab[i], list(
                                    logit_id[t]))).strip().replace(
                                        "<EOS>", " ")
                        print(
                            "\n Bleu_score1, Bleu_score2, Bleu_score3, Bleu_score4, Rouge, Meteor, Cider\n"
                        )
                        print(
                            "   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f}"
                            .format(
                                sum(Bleu_score1) / float(len(Bleu_score1)),
                                sum(Bleu_score2) / float(len(Bleu_score2)),
                                sum(Bleu_score3) / float(len(Bleu_score3)),
                                sum(Bleu_score4) / float(len(Bleu_score4)),
                                sum(Rouge_score) / float(len(Rouge_score)),
                                sum(Meteor_score) / float(len(Meteor_score)),
                                sum(Cider_score) / float(len(Cider_score))))

                        val_loss = []
                        val_Bleu_score1 = []
                        val_Bleu_score2 = []
                        val_Bleu_score3 = []
                        val_Bleu_score4 = []
                        val_Rouge_score = []
                        val_Meteor_score = []
                        val_Cider_score = []

                        for batch_val in range(
                                len(eval_noisy_Id) / batch_size):
                            idx_v = np.arange(0, len(eval_noisy_Id))
                            idx_v = list(np.random.permutation(idx_v))

                            val_source_shuffle, val_source_len, val_char_Id, val_char_len, val_train_shuffle, val_target_shuffle, val_target_len = next_batch(
                                eval_noisy_Id, eval_noisy_len,
                                eval_noisy_char_Id, eval_noisy_char_len,
                                eval_input_Id, eval_target_Id, eval_clean_len,
                                batch_size, batch_val, idx_v)
                            val_source_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                                val_source_shuffle,
                                maxlen=max_word,
                                padding='post',
                                value=EOS)
                            val_target_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                                val_target_shuffle,
                                maxlen=max_word + 1,
                                padding='post',
                                value=EOS)
                            val_train_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                                val_train_shuffle,
                                maxlen=max_word + 1,
                                padding='post',
                                value=EOS)

                            val_batch_noisy_char_Id = []
                            for sen_char in val_char_Id:
                                sen_len = len(sen_char)
                                for i in range(
                                        len(val_source_shuffle[0]) - sen_len):
                                    sen_char.append(
                                        [0] * max_char
                                    )  ## pad char ids so every sentence has max_word word slots
                                val_batch_noisy_char_Id.append(sen_char)

                            val_batch_noisy_char_Id = np.asarray(
                                val_batch_noisy_char_Id)

                            # val_target_len = np.tile(max(val_target_len) + 1, batch_size)

                            fd_val = {
                                Seq2Seq_model.encoder_inputs:
                                val_source_shuffle,
                                Seq2Seq_model.encoder_inputs_length:
                                val_source_len,
                                Seq2Seq_model.encoder_char_ids:
                                val_batch_noisy_char_Id,
                                Seq2Seq_model.encoder_char_len: val_char_len,
                                Seq2Seq_model.decoder_length: val_target_len,
                                Seq2Seq_model.decoder_inputs:
                                val_train_shuffle,
                                Seq2Seq_model.ground_truth: val_target_shuffle,
                                Seq2Seq_model.dropout_rate: 1
                            }
                            val_ids, val_loss_seq = sess.run([
                                Seq2Seq_model.ids, Seq2Seq_model.loss_seq2seq
                            ], fd_val)
                            val_loss.append(val_loss_seq)

                            for t in range(batch_size):
                                ref = []
                                hyp = []
                                for tar_id in val_target_shuffle[t]:
                                    if tar_id != EOS and tar_id != PAD:
                                        ref.append(vocab[tar_id])
                                for pre_id in val_ids[t]:
                                    if pre_id != EOS and pre_id != PAD:
                                        hyp.append(vocab[pre_id])
                                        # print vocab[pre_id]
                                # sen_bleu1 = bleu([ref], hyp, weights=(1, 0, 0, 0))
                                # sen_bleu2 = bleu([ref], hyp, weights=(0.5, 0.5, 0, 0))
                                # val_Bleu_score1.append(sen_bleu1)
                                # val_Bleu_score2.append(sen_bleu2)
                                hyp_sen = u" ".join(hyp).encode('utf-8')
                                ref_sen = u" ".join(ref).encode('utf-8')
                                dic_hyp = {}
                                dic_hyp[0] = [hyp_sen]
                                dic_ref = {}
                                dic_ref[0] = [ref_sen]
                                sen_bleu, _ = Bleu_obj.compute_score(
                                    dic_ref, dic_hyp)
                                sen_rouge = Rouge_obj.compute_score(
                                    dic_ref, dic_hyp)
                                sen_meteor, _ = Meteor_obj.compute_score(
                                    dic_ref, dic_hyp)
                                sen_cider, _ = cIDER_obj.compute_score(
                                    dic_ref, dic_hyp)
                                val_Bleu_score1.append(sen_bleu[0])
                                val_Bleu_score2.append(sen_bleu[1])
                                val_Bleu_score3.append(sen_bleu[2])
                                val_Bleu_score4.append(sen_bleu[3])
                                val_Rouge_score.append(sen_rouge[0])
                                val_Meteor_score.append(sen_meteor)
                                val_Cider_score.append(sen_cider)

                        for t in xrange(3):
                            print 'Validation Question {}'.format(t)
                            print "NQ: " + " ".join(
                                map(lambda i: vocab[i],
                                    list(val_source_shuffle[t]))).strip(
                                    ).replace("<EOS>", " ")
                            print "CQ: " + " ".join(
                                map(lambda i: vocab[i],
                                    list(val_target_shuffle[t]))).strip(
                                    ).replace("<EOS>", " ")
                            print "GQ: " + " ".join(
                                map(lambda i: vocab[i], list(
                                    val_ids[t]))).strip().replace(
                                        "<EOS>", " ")

                        avg_val_loss = sum(val_loss) / float(len(val_loss))

                        print('   minibatch loss of validation: {}'.format(
                            avg_val_loss))

                        print(
                            "\n Bleu_score1, Bleu_score2, Bleu_score3, Bleu_score4, Rouge, Meteor, Cider\n"
                        )
                        print(
                            "   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f}"
                            .format(
                                sum(val_Bleu_score1) /
                                float(len(val_Bleu_score1)),
                                sum(val_Bleu_score2) /
                                float(len(val_Bleu_score2)),
                                sum(val_Bleu_score3) /
                                float(len(val_Bleu_score3)),
                                sum(val_Bleu_score4) /
                                float(len(val_Bleu_score4)),
                                sum(val_Rouge_score) /
                                float(len(val_Rouge_score)),
                                sum(val_Meteor_score) /
                                float(len(val_Meteor_score)),
                                sum(val_Cider_score) /
                                float(len(val_Cider_score))))

                val_loss_epo.append(avg_val_loss)

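                # Early stopping on validation loss: reset patience when the loss improves,
                # otherwise increase it; save the checkpoint and stop once patience exceeds 5.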
                if epo > 0:
                    for loss in val_loss_epo:
                        print "the val_loss_epo:", loss
                    print "the loss difference:", val_loss_epo[
                        -2] - val_loss_epo[-1]

                    if val_loss_epo[-2] - val_loss_epo[-1] > 0:
                        patience_cnt = 0
                    else:
                        patience_cnt += 1

                    print "patience count:", patience_cnt

                    if patience_cnt > 5:
                        print("early stopping...")
                        saver.save(
                            sess, "/mnt/WDRed4T/ye/Qrefine/ckpt/CharS2S/" +
                            Dataset_name + "/" + Dataset_name1 + "_" +
                            str(Bidirection) + "_Att_" + str(Attention) +
                            "_Emb_" + str(Embd_train) + "_hiddenunits_" +
                            str(num_units))

                        break

                # if epo % epoch_print == 0:
                #     for t in range(10):
                #         print 'Question {}'.format(t)
                #         print " ".join(map(lambda i: vocab[i], list(target_shuffle[t]))).strip()
                #         print " ".join(map(lambda i: vocab[i], list(logit_id[t, :]))).strip()
                avg_bleu_val = sum(val_Bleu_score1) / float(
                    len(val_Bleu_score1))
                saver.save(
                    sess, "/mnt/WDRed4T/ye/Qrefine/ckpt/CharS2S/" +
                    Dataset_name + "/" + Dataset_name1 + "_" +
                    str(Bidirection) + "_Att_" + str(Attention) + "_Emb_" +
                    str(Embd_train) + "_hiddenunits_" + str(num_units) +
                    "_bleu_" + str(avg_bleu_val))

    elif model_type == 'testing':
        Seq2Seq_model = Char_Seq2Seq(batch_size, vocab_size, num_units, embd,
                                     model_type, Bidirection, Embd_train,
                                     char_hidden_units, Attention, char_num,
                                     char_dim)

        with tf.Session(config=config) as sess:
            saver_word_rw = tf.train.Saver()
            saver_word_rw.restore(
                sess, "/mnt/WDRed4T/ye/Qrefine/ckpt/CharS2S/" + Dataset_name +
                "/" + Dataset_name1 + str(Bidirection) + "_Att_" +
                str(Attention) + "_Emb_" + str(Embd_train) + "_hiddenunits_" +
                str(num_units) + "_bleu_0.286359797692")

            max_batches = len(test_noisy_Id) / batch_size

            idx = np.arange(0, len(test_noisy_Id))
            # idx = list(np.random.permutation(idx))
            test_Bleu_score1 = []
            test_Bleu_score2 = []
            test_Bleu_score3 = []
            test_Bleu_score4 = []
            test_Rouge_score = []
            test_Meteor_score = []
            test_Cider_score = []

            generated_test_sen = []
            test_answer_sen = []
            test_noisy_sen = []
            test_clean_sen = []

            for batch in range(max_batches):

                source_shuffle, source_len, test_char_Id, char_len, train_shuffle, target_shuffle, target_len = next_batch(
                    test_noisy_Id, test_noisy_len, test_noisy_char_Id,
                    test_noisy_char_len, test_input_Id, test_target_Id,
                    test_clean_len, batch_size, batch, idx)

                for no in source_shuffle:
                    test_noisy_sen.append([vocab[si] for si in no])
                for cl in target_shuffle:
                    test_clean_sen.append([vocab[ci] for ci in cl])

                source_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                    source_shuffle, maxlen=max_word, padding='post', value=EOS)

                test_batch_noisy_char_Id = []
                for sen_char in test_char_Id:
                    sen_len = len(sen_char)
                    for i in range(len(source_shuffle[0]) - sen_len):
                        sen_char.append(
                            [0] * max_char
                        )  ## pad char ids so every sentence has max_word word slots
                    test_batch_noisy_char_Id.append(sen_char)
                test_batch_noisy_char_Id = np.asarray(test_batch_noisy_char_Id)

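                # At test time only the encoder side is fed; decoding runs in inference mode.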
                fd = {
                    Seq2Seq_model.encoder_inputs: source_shuffle,
                    Seq2Seq_model.encoder_inputs_length: source_len,
                    Seq2Seq_model.encoder_char_ids: test_batch_noisy_char_Id,
                    Seq2Seq_model.encoder_char_len: char_len,
                    Seq2Seq_model.dropout_rate: 1
                }
                ids = sess.run([Seq2Seq_model.ids], fd)
                for t in range(batch_size):
                    ref = []
                    hyp = []
                    for tar_id in target_shuffle[t]:
                        if tar_id != 2:
                            ref.append(vocab[tar_id])
                    for pre_id in ids[0][t]:
                        if pre_id != 2 and pre_id != 0:
                            hyp.append(vocab[pre_id])
                    generated_test_sen.append(hyp)
                    # sen_bleu1 = bleu([ref], hyp, weights=(1, 0, 0, 0))
                    # sen_bleu2 = bleu([ref], hyp, weights=(0.5, 0.5, 0, 0))
                    # sen_bleu3 = bleu([ref], hyp, weights=(0.333, 0.333, 0.333, 0))
                    # sen_bleu4 = bleu([ref], hyp, weights=(0.25, 0.25, 0.25, 0.25))
                    # Bleu_score1.append(sen_bleu1)
                    # Bleu_score2.append(sen_bleu2)
                    # Bleu_score3.append(sen_bleu3)
                    # Bleu_score4.append(sen_bleu4)

                    hyp_sen = u" ".join(hyp).encode('utf-8')
                    ref_sen = u" ".join(ref).encode('utf-8')
                    dic_hyp = {}
                    dic_hyp[0] = [hyp_sen]
                    dic_ref = {}
                    dic_ref[0] = [ref_sen]
                    sen_bleu, _ = Bleu_obj.compute_score(dic_ref, dic_hyp)
                    sen_rouge = Rouge_obj.compute_score(dic_ref, dic_hyp)
                    sen_meteor, _ = Meteor_obj.compute_score(dic_ref, dic_hyp)
                    sen_cider, _ = cIDER_obj.compute_score(dic_ref, dic_hyp)
                    test_Bleu_score1.append(sen_bleu[0])
                    test_Bleu_score2.append(sen_bleu[1])
                    test_Bleu_score3.append(sen_bleu[2])
                    test_Bleu_score4.append(sen_bleu[3])
                    test_Rouge_score.append(sen_rouge[0])
                    test_Meteor_score.append(sen_meteor)
                    test_Cider_score.append(sen_cider)

                for t in xrange(5):
                    print 'Test Question {}'.format(t)
                    print "NQ: " + " ".join(
                        map(lambda i: vocab[i], list(
                            source_shuffle[t]))).strip().replace(
                                "<EOS>", "  ")
                    print "CQ: " + " ".join(
                        map(lambda i: vocab[i], list(
                            target_shuffle[t]))).strip().replace(
                                "<EOS>", "  ")
                    print "GQ: " + " ".join(
                        map(lambda i: vocab[i], list(
                            ids[0][t]))).strip().replace("<EOS>", "  ")

                bleu_score1 = sum(test_Bleu_score1) / float(
                    len(test_Bleu_score1))
                bleu_score2 = sum(test_Bleu_score2) / float(
                    len(test_Bleu_score2))
                bleu_score3 = sum(test_Bleu_score3) / float(
                    len(test_Bleu_score3))
                bleu_score4 = sum(test_Bleu_score4) / float(
                    len(test_Bleu_score4))

                print(
                    "\n Bleu_score1, Bleu_score2, Bleu_score3, Bleu_score4, Rouge, Meteor, Cider\n"
                )
                print(
                    "   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f}"
                    .format(
                        bleu_score1, bleu_score2, bleu_score3, bleu_score4,
                        sum(test_Rouge_score) / float(len(test_Rouge_score)),
                        sum(test_Meteor_score) / float(len(test_Meteor_score)),
                        sum(test_Cider_score) / float(len(test_Cider_score))))

            #  len(test_answer_sen)
            fname = "/mnt/WDRed4T/ye/Qrefine/ckpt/CharS2S/" + Dataset_name + "/" + Dataset_name1 + "_bleu1_" + str(
                bleu_score1)
            f = open(fname, "wb")
            print "the length test set is:", len(generated_test_sen), len(
                test_noisy_sen), len(test_clean_sen)
            f.write(
                "\n Bleu_score1, Bleu_score2, Bleu_score3, Bleu_score4, Rouge, Meteor  Cider\n"
            )
            f.write(
                "   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f} \n"
                .format(bleu_score1, bleu_score2, bleu_score3, bleu_score4,
                        sum(test_Rouge_score) / float(len(test_Rouge_score)),
                        sum(test_Meteor_score) / float(len(test_Meteor_score)),
                        sum(test_Cider_score) / float(len(test_Cider_score))))
            for i in range(len(generated_test_sen)):
                f.write("question" + str(i) + "\n")
                # f.write("answer: " + " ".join(test_answer_sen[i]) + "\n")
                f.write("noisy question: " + " ".join(test_noisy_sen[i]) +
                        "\n")
                f.write("clean question: " + " ".join(test_clean_sen[i]) +
                        "\n")
                f.write("generated question: " +
                        " ".join(generated_test_sen[i]) + "\n")

            f.close()
Example 2
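
import os

import numpy as np
import tensorflow as tf

# Assumed, not shown in this snippet: the load_data and next_batch helpers, the
# Seq2Seq model class, and the Bleu / Rouge / Meteor / Cider scorers.
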
def seq_run_BA(model_type, config):

    SOS = 0
    EOS = 1
    PAD = 2
    UNK = 3
    data_file = config.data_dir
    emd_file = [config.nemd_dir, config.cemd_dir]
    train_data, test_data, eval_data, vocab, embd = load_data(
        data_file, emd_file, "BA")
    batch_size = config.batch_size
    num_units = config.num_units
    Bidirection = config.Bidirection
    Attention = config.Attention
    Embd_train = config.Embd_train
    char_hidden_units = config.char_hidden_units
    char_num = config.char_num
    char_dim = config.char_dim
    batches_in_epoch = config.batches_in_epoch
    epoch = config.epoch
    S2S_ckp_dir = config.S2S_ckp_dir

    train_noisy_Id, train_noisy_len, train_noisy_char_Id, train_noisy_char_len, train_nemd, train_target_Id, train_input_Id, train_clean_Id, train_clean_len, train_answer_Id, train_answer_len, max_char, max_word = train_data
    test_noisy_Id, test_noisy_len, test_noisy_char_Id, test_noisy_char_len, test_nemd, test_target_Id, test_input_Id, test_clean_Id, test_clean_len, test_answer_Id, test_answer_len = test_data
    eval_noisy_Id, eval_noisy_len, eval_noisy_char_Id, eval_noisy_char_len, eval_nemd, eval_target_Id, eval_input_Id, eval_clean_Id, eval_clean_len, eval_answer_Id, eval_answer_len = eval_data

    # config = tf.ConfigProto()
    # config.gpu_options.allow_growth = True

    print "the number of training question is:", len(train_noisy_Id), len(
        train_noisy_char_Id), len(train_noisy_char_len)

    print "the number of eval question is:", len(eval_noisy_Id), len(
        eval_noisy_char_Id), len(eval_noisy_char_len)

    vocab_size = len(vocab)

    Bleu_obj = Bleu()
    Rouge_obj = Rouge()
    Meteor_obj = Meteor()
    cIDER_obj = Cider()

    if model_type == 'training':
        Seq2Seq_model = Seq2Seq(batch_size, vocab_size, num_units, embd,
                                model_type, Bidirection, Embd_train,
                                char_hidden_units, Attention, char_num,
                                char_dim)

        # print "the number of evaluate question is:", len(eval_noisy_Id)

        patience_cnt = 0

        # summary_writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())
        # merge = tf.summary.merge_all()
        saver = tf.train.Saver(sharded=False)

        f = open(config.perf_path, "w")

        # config_gpu = tf.ConfigProto()
        # config_gpu.gpu_options.allow_growth = True
        # config=config_gpu
        with tf.Session() as sess:
            if os.path.isfile(S2S_ckp_dir):
                saver.restore(sess, S2S_ckp_dir)
            else:
                sess.run(tf.global_variables_initializer())
            # sess.run(tf.global_variables_initializer())
            # saver.restore(sess, config.S2S_ckp_dir)

            val_loss_epo = []
            # print [v for v in tf.trainable_variables()]
            for epo in range(epoch):
                idx = np.arange(0, len(train_noisy_Id))
                idx = list(np.random.permutation(idx))

                print " Epoch {}".format(epo)

                Bleu_score1 = []
                Bleu_score2 = []
                Bleu_score3 = []
                Bleu_score4 = []
                Rouge_score = []
                Meteor_score = []
                Cider_score = []

                for batch in range(len(train_noisy_Id) / batch_size):
                    source_shuffle, source_len, source_nemd, train_shuffle, target_shuffle, target_len = next_batch(
                        train_noisy_Id, train_noisy_len, train_nemd,
                        train_input_Id, train_target_Id, train_clean_len,
                        batch_size, batch, idx)

                    source_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                        source_shuffle,
                        maxlen=max_word,
                        padding='post',
                        value=EOS)
                    source_emd = []
                    for n in source_nemd:
                        source_emd.append(n[0:source_shuffle.shape[1]])
                    source_emd = np.asarray(source_emd)

                    train_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                        train_shuffle,
                        maxlen=max_word + 1,
                        padding='post',
                        value=EOS)

                    # target_emd = []
                    # for c in train_cemd:
                    #     target_emd.append(c[0:train_shuffle.shape[1]])
                    # target_emd = np.asarray(target_emd)

                    target_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                        target_shuffle,
                        maxlen=max_word + 1,
                        padding='post',
                        value=EOS)

                    # whole_noisy_char_Id =[]
                    # for sen_char in char_Id:
                    #     sen_len = len(sen_char)
                    #     for i in range(len(source_shuffle[0]) - sen_len):
                    #         sen_char.append([0] * max_char)  ## fix the char with the length of max_word
                    #     whole_noisy_char_Id.append(sen_char)
                    #
                    # whole_noisy_char_Id = np.asarray(whole_noisy_char_Id)
                    # print whole_noisy_char_Id.shape
                    #
                    # print len(char_len[0])

                    # target_len = np.tile(max(target_len) + 1, batch_size)

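                    # Feed one padded batch; encoder_emb carries the pre-computed
                    # embeddings for the noisy source sentences.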
                    fd = {
                        Seq2Seq_model.encoder_inputs:
                        source_shuffle,
                        Seq2Seq_model.encoder_inputs_length:
                        source_len,
                        Seq2Seq_model.encoder_emb:
                        source_emd,
                        Seq2Seq_model.decoder_inputs:
                        train_shuffle,
                        Seq2Seq_model.decoder_length:
                        target_len,
                        # Seq2Seq_model.target_emb: target_emd,
                        Seq2Seq_model.ground_truth:
                        target_shuffle,
                        Seq2Seq_model.dropout_rate:
                        1
                    }

                    mask, logit_id, logit_id_pad, l_fun, pro_op = sess.run([
                        Seq2Seq_model.mask, Seq2Seq_model.ids,
                        Seq2Seq_model.logits_padded,
                        Seq2Seq_model.loss_seq2seq, Seq2Seq_model.train_op
                    ], fd)

                    # mask, logit_id, logit_id_pad= sess.run(
                    #     [Seq2Seq_model.mask, Seq2Seq_model.ids, Seq2Seq_model.logits_padded], fd)
                    # l_fun =0

                    for t in range(batch_size):
                        ref = []
                        hyp = []
                        for tar_id in target_shuffle[t]:
                            if tar_id != EOS and tar_id != PAD:
                                ref.append(vocab[tar_id])
                                # print vocab[tar_id]
                        for pre_id in logit_id[t]:
                            if pre_id != EOS and pre_id != PAD:
                                hyp.append(vocab[pre_id])

                        hyp_sen = u" ".join(hyp).encode('utf-8')
                        ref_sen = u" ".join(ref).encode('utf-8')
                        dic_hyp = {}
                        dic_hyp[0] = [hyp_sen]
                        dic_ref = {}
                        dic_ref[0] = [ref_sen]
                        sen_bleu, _ = Bleu_obj.compute_score(dic_ref, dic_hyp)
                        sen_rouge = Rouge_obj.compute_score(dic_ref, dic_hyp)
                        sen_meteor, _ = Meteor_obj.compute_score(
                            dic_ref, dic_hyp)
                        sen_cider, _ = cIDER_obj.compute_score(
                            dic_ref, dic_hyp)
                        Bleu_score1.append(sen_bleu[0])
                        Bleu_score2.append(sen_bleu[1])
                        Bleu_score3.append(sen_bleu[2])
                        Bleu_score4.append(sen_bleu[3])
                        Rouge_score.append(sen_rouge[0])
                        Meteor_score.append(sen_meteor)
                        Cider_score.append(sen_cider)

                    if batch == 0 or batch % batches_in_epoch == 0:
                        ##print the training
                        print('batch {}'.format(batch))
                        print('   minibatch loss: {}'.format(l_fun))
                        print(" the loss_epo:", epo)

                        for t in xrange(3):
                            print 'Training Question {}'.format(t)
                            print "NQ: " + " ".join(
                                map(lambda i: vocab[i], list(
                                    source_shuffle[t]))).strip().replace(
                                        "<EOS>", " ")
                            print "CQ: " + " ".join(
                                map(lambda i: vocab[i], list(
                                    target_shuffle[t]))).strip().replace(
                                        "<EOS>", " ")
                            print "GQ: " + " ".join(
                                map(lambda i: vocab[i], list(
                                    logit_id[t]))).strip().replace(
                                        "<EOS>", " ")
                        print(
                            "\n Bleu_score1, Bleu_score2, Bleu_score3, Bleu_score4, Rouge, Meteor, Cider\n"
                        )
                        print(
                            "   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f}"
                            .format(
                                sum(Bleu_score1) / float(len(Bleu_score1)),
                                sum(Bleu_score2) / float(len(Bleu_score2)),
                                sum(Bleu_score3) / float(len(Bleu_score3)),
                                sum(Bleu_score4) / float(len(Bleu_score4)),
                                sum(Rouge_score) / float(len(Rouge_score)),
                                sum(Meteor_score) / float(len(Meteor_score)),
                                sum(Cider_score) / float(len(Cider_score))))

                        val_loss = []
                        val_Bleu_score1 = []
                        val_Bleu_score2 = []
                        val_Bleu_score3 = []
                        val_Bleu_score4 = []
                        val_Rouge_score = []
                        val_Meteor_score = []
                        val_Cider_score = []

                        for batch_val in range(
                                len(eval_noisy_Id) / batch_size):
                            idx_v = np.arange(0, len(eval_noisy_Id))
                            idx_v = list(np.random.permutation(idx_v))

                            val_source_shuffle, val_source_len, val_nemd, val_train_shuffle, val_target_shuffle, val_target_len = next_batch(
                                eval_noisy_Id, eval_noisy_len, eval_nemd,
                                eval_input_Id, eval_target_Id, eval_clean_len,
                                batch_size, batch_val, idx_v)
                            val_source_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                                val_source_shuffle,
                                maxlen=max_word,
                                padding='post',
                                value=EOS)
                            val_target_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                                val_target_shuffle,
                                maxlen=max_word + 1,
                                padding='post',
                                value=EOS)
                            val_train_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                                val_train_shuffle,
                                maxlen=max_word + 1,
                                padding='post',
                                value=EOS)
                            val_source_emd = []
                            for n in val_nemd:
                                val_source_emd.append(
                                    n[0:val_source_shuffle.shape[1]])
                            val_source_emd = np.asarray(val_source_emd)

                            # val_target_emd = []
                            # for c in val_train_cemd:
                            #     val_target_emd.append(c[0:val_train_shuffle.shape[1]])
                            # val_target_emd = np.asarray(val_target_emd)

                            # val_batch_noisy_char_Id = []
                            # for sen_char in val_char_Id:
                            #     sen_len = len(sen_char)
                            #     for i in range(len(val_source_shuffle[0]) - sen_len):
                            #         sen_char.append([0] * max_char) ## fix the char with the length of max_word
                            #     val_batch_noisy_char_Id.append(sen_char)
                            #
                            # val_batch_noisy_char_Id = np.asarray(val_batch_noisy_char_Id)

                            # val_target_len = np.tile(max(val_target_len) + 1, batch_size)

                            fd_val = {
                                Seq2Seq_model.encoder_inputs:
                                val_source_shuffle,
                                Seq2Seq_model.encoder_inputs_length:
                                val_source_len,
                                Seq2Seq_model.encoder_emb:
                                val_source_emd,
                                # Seq2Seq_model.target_emb: val_target_emd,
                                Seq2Seq_model.decoder_length:
                                val_target_len,
                                Seq2Seq_model.decoder_inputs:
                                val_train_shuffle,
                                Seq2Seq_model.ground_truth:
                                val_target_shuffle,
                                Seq2Seq_model.dropout_rate:
                                1
                            }
                            val_ids, val_loss_seq = sess.run([
                                Seq2Seq_model.ids, Seq2Seq_model.loss_seq2seq
                            ], fd_val)
                            val_loss.append(val_loss_seq)

                            for t in range(batch_size):
                                ref = []
                                hyp = []
                                for tar_id in val_target_shuffle[t]:
                                    if tar_id != EOS and tar_id != PAD:
                                        ref.append(vocab[tar_id])
                                for pre_id in val_ids[t]:
                                    if pre_id != EOS and pre_id != PAD:
                                        hyp.append(vocab[pre_id])
                                        # print vocab[pre_id]
                                # sen_bleu1 = bleu([ref], hyp, weights=(1, 0, 0, 0))
                                # sen_bleu2 = bleu([ref], hyp, weights=(0.5, 0.5, 0, 0))
                                # val_Bleu_score1.append(sen_bleu1)
                                # val_Bleu_score2.append(sen_bleu2)
                                hyp_sen = u" ".join(hyp).encode('utf-8')
                                ref_sen = u" ".join(ref).encode('utf-8')
                                dic_hyp = {}
                                dic_hyp[0] = [hyp_sen]
                                dic_ref = {}
                                dic_ref[0] = [ref_sen]
                                sen_bleu, _ = Bleu_obj.compute_score(
                                    dic_ref, dic_hyp)
                                sen_rouge = Rouge_obj.compute_score(
                                    dic_ref, dic_hyp)
                                sen_meteor, _ = Meteor_obj.compute_score(
                                    dic_ref, dic_hyp)
                                sen_cider, _ = cIDER_obj.compute_score(
                                    dic_ref, dic_hyp)
                                val_Bleu_score1.append(sen_bleu[0])
                                val_Bleu_score2.append(sen_bleu[1])
                                val_Bleu_score3.append(sen_bleu[2])
                                val_Bleu_score4.append(sen_bleu[3])
                                val_Rouge_score.append(sen_rouge[0])
                                val_Meteor_score.append(sen_meteor)
                                val_Cider_score.append(sen_cider)

                        for t in xrange(3):
                            print 'Validation Question {}'.format(t)
                            print "NQ: " + " ".join(
                                map(lambda i: vocab[i],
                                    list(val_source_shuffle[t]))).strip(
                                    ).replace("<EOS>", " ")
                            print "CQ: " + " ".join(
                                map(lambda i: vocab[i],
                                    list(val_target_shuffle[t]))).strip(
                                    ).replace("<EOS>", " ")
                            print "GQ: " + " ".join(
                                map(lambda i: vocab[i], list(
                                    val_ids[t]))).strip().replace(
                                        "<EOS>", " ")

                        avg_val_loss = sum(val_loss) / float(len(val_loss))

                        print('   minibatch loss of validation: {}'.format(
                            avg_val_loss))

                        print(
                            "\n Bleu_score1, Bleu_score2, Bleu_score3, Bleu_score4, Rouge, Meteor, Cider\n"
                        )
                        print(
                            "   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f}"
                            .format(
                                sum(val_Bleu_score1) /
                                float(len(val_Bleu_score1)),
                                sum(val_Bleu_score2) /
                                float(len(val_Bleu_score2)),
                                sum(val_Bleu_score3) /
                                float(len(val_Bleu_score3)),
                                sum(val_Bleu_score4) /
                                float(len(val_Bleu_score4)),
                                sum(val_Rouge_score) /
                                float(len(val_Rouge_score)),
                                sum(val_Meteor_score) /
                                float(len(val_Meteor_score)),
                                sum(val_Cider_score) /
                                float(len(val_Cider_score))))

                        val_loss_epo.append(avg_val_loss)

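                # Early stopping on validation loss: reset patience when the loss improves,
                # otherwise increase it; save the checkpoint and stop once patience exceeds 5.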
                if epo > 0:
                    for loss in val_loss_epo:
                        print "the val_loss_epo:", loss
                    print "the loss difference:", val_loss_epo[
                        -2] - val_loss_epo[-1]

                    if val_loss_epo[-2] - val_loss_epo[-1] > 0:
                        patience_cnt = 0
                    else:
                        patience_cnt += 1

                    print "patience count:", patience_cnt

                    if patience_cnt > 5:
                        print("early stopping...")
                        saver.save(sess, S2S_ckp_dir)

                        break

                # if epo % epoch_print == 0:
                #     for t in range(10):
                #         print 'Question {}'.format(t)
                #         print " ".join(map(lambda i: vocab[i], list(target_shuffle[t]))).strip()
                #         print " ".join(map(lambda i: vocab[i], list(logit_id[t, :]))).strip()
                avg_bleu_val = sum(val_Bleu_score1) / float(
                    len(val_Bleu_score1))
                if epo % 5 == 0:
                    saver.save(sess, S2S_ckp_dir)

                f.write("the performance: " + S2S_ckp_dir + "_epoch_" +
                        str(epo) + "avg_eval_bleu_" + str(avg_bleu_val))
            f.close()
            print "save file"

    elif model_type == 'testing':

        Seq2Seq_model = Seq2Seq(batch_size, vocab_size, num_units, embd,
                                model_type, Bidirection, Embd_train,
                                char_hidden_units, Attention, char_num,
                                char_dim)

        with tf.Session() as sess:
            saver_word_rw = tf.train.Saver()
            saver_word_rw.restore(sess, S2S_ckp_dir)

            max_batches = len(test_noisy_Id) / batch_size

            idx = np.arange(0, len(test_noisy_Id))
            # idx = list(np.random.permutation(idx))
            test_Bleu_score1 = []
            test_Bleu_score2 = []
            test_Bleu_score3 = []
            test_Bleu_score4 = []
            test_Rouge_score = []
            test_Meteor_score = []
            test_Cider_score = []

            generated_test_sen = []
            test_answer_sen = []
            test_noisy_sen = []
            test_clean_sen = []

            for batch in range(max_batches):

                test_source_shuffle, test_source_len, test_source_nemd, test_train_shuffle, test_target_shuffle, test_target_len = next_batch(
                    test_noisy_Id, test_noisy_len, test_nemd, test_input_Id,
                    test_target_Id, test_clean_len, batch_size, batch, idx)

                for no in test_source_shuffle:
                    test_noisy_sen.append([vocab[si] for si in no])
                for cl in test_target_shuffle:
                    test_clean_sen.append([vocab[ci] for ci in cl])

                test_source_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                    test_source_shuffle,
                    maxlen=max_word,
                    padding='post',
                    value=EOS)

                test_source_emd = []
                for n in test_source_nemd:
                    test_source_emd.append(n[0:test_source_shuffle.shape[1]])
                test_source_emd = np.asarray(test_source_emd)

                # test_batch_noisy_char_Id = []
                # for sen_char in test_char_Id:
                #     sen_len = len(sen_char)
                #     for i in range(len(test_source_shuffle[0]) - sen_len):
                #         sen_char.append([0] * max_char)  ## fix the char with the length of max_word
                #     test_batch_noisy_char_Id.append(sen_char)
                # test_batch_noisy_char_Id = np.asarray(test_batch_noisy_char_Id)

                fd = {
                    Seq2Seq_model.encoder_inputs: test_source_shuffle,
                    Seq2Seq_model.encoder_inputs_length: test_source_len,
                    Seq2Seq_model.encoder_emb: test_source_emd,
                    Seq2Seq_model.dropout_rate: 1
                }
                ids = sess.run([Seq2Seq_model.ids], fd)
                for t in range(batch_size):
                    ref = []
                    hyp = []
                    for tar_id in test_target_shuffle[t]:
                        if tar_id != 2:
                            ref.append(vocab[tar_id])
                    for pre_id in ids[0][t]:
                        if pre_id != 2 and pre_id != 0:
                            hyp.append(vocab[pre_id])
                    generated_test_sen.append(hyp)

                    hyp_sen = u" ".join(hyp).encode('utf-8')
                    ref_sen = u" ".join(ref).encode('utf-8')
                    dic_hyp = {}
                    dic_hyp[0] = [hyp_sen]
                    dic_ref = {}
                    dic_ref[0] = [ref_sen]
                    sen_bleu, _ = Bleu_obj.compute_score(dic_ref, dic_hyp)
                    sen_rouge = Rouge_obj.compute_score(dic_ref, dic_hyp)
                    sen_meteor, _ = Meteor_obj.compute_score(dic_ref, dic_hyp)
                    sen_cider, _ = cIDER_obj.compute_score(dic_ref, dic_hyp)
                    test_Bleu_score1.append(sen_bleu[0])
                    test_Bleu_score2.append(sen_bleu[1])
                    test_Bleu_score3.append(sen_bleu[2])
                    test_Bleu_score4.append(sen_bleu[3])
                    test_Rouge_score.append(sen_rouge[0])
                    test_Meteor_score.append(sen_meteor)
                    test_Cider_score.append(sen_cider)

                for t in xrange(5):
                    print 'Test Question {}'.format(t)
                    print "NQ: " + " ".join(
                        map(lambda i: vocab[i], list(
                            test_source_shuffle[t]))).strip().replace(
                                "<EOS>", "  ")
                    print "CQ: " + " ".join(
                        map(lambda i: vocab[i], list(
                            test_target_shuffle[t]))).strip().replace(
                                "<EOS>", "  ")
                    print "GQ: " + " ".join(
                        map(lambda i: vocab[i], list(
                            ids[0][t]))).strip().replace("<EOS>", "  ")

                bleu_score1 = sum(test_Bleu_score1) / float(
                    len(test_Bleu_score1))
                bleu_score2 = sum(test_Bleu_score2) / float(
                    len(test_Bleu_score2))
                bleu_score3 = sum(test_Bleu_score3) / float(
                    len(test_Bleu_score3))
                bleu_score4 = sum(test_Bleu_score4) / float(
                    len(test_Bleu_score4))

                print(
                    "\n Bleu_score1, Bleu_score2, Bleu_score3, Bleu_score4, Rouge, Meteor, Cider\n"
                )
                print(
                    "   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f}"
                    .format(
                        bleu_score1, bleu_score2, bleu_score3, bleu_score4,
                        sum(test_Rouge_score) / float(len(test_Rouge_score)),
                        sum(test_Meteor_score) / float(len(test_Meteor_score)),
                        sum(test_Cider_score) / float(len(test_Cider_score))))

            #  len(test_answer_sen)
            fname = S2S_ckp_dir + "_bleu1_" + str(bleu_score1)
            f = open(fname, "wb")
            print "the length test set is:", len(generated_test_sen), len(
                test_noisy_sen), len(test_clean_sen)
            f.write(
                "\n Bleu_score1, Bleu_score2, Bleu_score3, Bleu_score4, Rouge, Meteor  Cider\n"
            )
            f.write(
                "   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f} \n"
                .format(bleu_score1, bleu_score2, bleu_score3, bleu_score4,
                        sum(test_Rouge_score) / float(len(test_Rouge_score)),
                        sum(test_Meteor_score) / float(len(test_Meteor_score)),
                        sum(test_Cider_score) / float(len(test_Cider_score))))
            for i in range(len(generated_test_sen)):
                f.write("question" + str(i) + "\n")
                # f.write("answer: " + " ".join(test_answer_sen[i]) + "\n")
                f.write("noisy question: " + " ".join(test_noisy_sen[i]) +
                        "\n")
                f.write("clean question: " + " ".join(test_clean_sen[i]) +
                        "\n")
                f.write("generated question: " +
                        " ".join(generated_test_sen[i]) + "\n")

            f.close()
Example 3
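
import numpy as np
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

# Assumed, not shown in this snippet: the project-local BA_Seq2Seq and
# QA_similiarity modules.
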
def RL_tuning_model(
    data_comb,
    epoch,
    batch_size=20,
    num_units=300,
    beam_width=5,
    discount_factor=0.1,
    sen_reward_rate=2,
    dropout=1,
    dec_L=1,
    Bidirection=False,
    Embd_train=False,
    Attention=True,
):
    max_noisy_len = data_comb[3]["max_noisy_len"]
    max_clean_len = data_comb[3]["max_clean_len"]
    max_answer_len = data_comb[3]["max_answer_len"]

    L = max_clean_len

    noisy_Id = data_comb[0]['noisy_Id']
    noisy_len = data_comb[0]['noisy_len']
    ground_truth = data_comb[0]['ground_truth']
    clean_len = data_comb[0]['clean_len']
    answer_Id = data_comb[0]['answer_Id']
    answer_len = data_comb[0]['answer_len']
    train_Id = data_comb[0]['train_Id']
    vocab = data_comb[0]['vocab']
    embd = data_comb[0]['embd']
    embd = np.array(embd)

    val_noisy_Id = data_comb[1]['noisy_Id']
    val_noisy_len = data_comb[1]['noisy_len']
    val_ground_truth = data_comb[1]['ground_truth']
    val_clean_len = data_comb[1]['clean_len']
    val_answer_Id = data_comb[1]['answer_Id']
    val_answer_len = data_comb[1]['answer_len']
    val_train_Id = data_comb[1]['train_Id']

    test_noisy_Id = data_comb[2]['noisy_Id']
    test_noisy_len = data_comb[2]['noisy_len']
    test_ground_truth = data_comb[2]['ground_truth']
    test_clean_len = data_comb[2]['clean_len']
    test_answer_Id = data_comb[2]['answer_Id']
    test_answer_len = data_comb[2]['answer_len']
    test_train_Id = data_comb[2]['train_Id']
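
    # A minimal sketch (assumption) of the next_batch helper used throughout this
    # function; the real helper is defined elsewhere in the project. It is expected
    # to gather one mini-batch of every field, following the shuffled index list.
    # The sketch is illustrative only and is not called below.
    def _next_batch_sketch(noisy_Id, noisy_len, train_Id, ground_truth, clean_len,
                           answer_Id, answer_len, batch_size, batch, idx):
        take = idx[batch * batch_size:(batch + 1) * batch_size]
        pick = lambda field: [field[i] for i in take]
        return (pick(noisy_Id), pick(noisy_len), pick(train_Id),
                pick(ground_truth), pick(clean_len), pick(answer_Id),
                pick(answer_len))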

    vocab_size = len(vocab)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    model_name = "BS_"
    values = {}
    checkpoint_path = "/home/ye/PycharmProjects/Qrefine/Seq_ckpt/Wiki/ALL_seq2seq_Bi_" + str(
        Bidirection) + "_Att_" + str(Attention) + "_Emb_" + str(
            Embd_train) + "_hiddenunits_" + str(num_units)
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        if 'loss_fun' not in key:
            values[model_name + key + ':0'] = reader.get_tensor(key)

    model_cond = "training"
    G_Seq2Seq = tf.Graph()

    sess_word_rw = tf.Session(config=config, graph=G_Seq2Seq)
    with G_Seq2Seq.as_default():
        Seq2Seq_model = BA_Seq2Seq.Bi_Att_Seq2Seq(batch_size, vocab_size,
                                                  num_units, embd, model_cond,
                                                  Bidirection, Embd_train,
                                                  Attention)
        saver_word_rw = tf.train.Saver()
        # saver_word_rw.restore(sess_word_rw,
        #                       "Seq_ckpt/pretrainALL-seq2seq_Bi_" + str(Bidirection) + "_Att_" + str(
        #                           Attention) + "_Emb_" + str(
        #                           Embd_train))
        saver_word_rw.restore(
            sess_word_rw,
            "/home/ye/PycharmProjects/Qrefine/Seq_ckpt/Wiki/ALL_seq2seq_Bi_" +
            str(Bidirection) + "_Att_" + str(Attention) + "_Emb_" +
            str(Embd_train) + "_hiddenunits_" + str(num_units))

    model_type = "testing"
    G_QA_similiarity = tf.Graph()
    sess_QA_rw = tf.Session(config=config, graph=G_QA_similiarity)
    with G_QA_similiarity.as_default():
        QA_simi_model = QA_similiarity.QA_similiarity(batch_size, num_units,
                                                      embd, model_type)
        saver_sen_rw = tf.train.Saver()
        # saver_sen_rw.restore(sess_QA_rw, "Seq_ckpt/Wiki_qa_similiarity")
        saver_sen_rw.restore(
            sess_QA_rw,
            "/home/ye/PycharmProjects/Qrefine/Seq_ckpt/Wiki/qa_similiarity")

    G_BeamSearch = tf.Graph()
    with G_BeamSearch.as_default():
        BeamSearch_seq2seq = seq_last.BeamSearch_Seq2seq(
            vocab_size=vocab_size,
            num_units=num_units,
            beam_width=beam_width,
            model_name=model_name,
            embd=embd,
            Bidirection=Bidirection,
            Embd_train=Embd_train,
            Attention=Attention,
            max_target_length=L)

    seq2seq_len = L
    RL_len = 0
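    # Curriculum over the decoder: seq2seq_len starts at the full clean length L and
    # is reduced by dec_L after every epoch, so the RL decoder is asked to generate
    # the remaining RL_len = L - seq2seq_len positions of each question.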

    Bleu_obj = Bleu()
    Rouge_obj = Rouge()
    Meteor_obj = Meteor()
    cIDER_obj = Cider()
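
    # A minimal sketch (assumption) of what discounted_rewards_cal is taken to do:
    # every position receives its own reward plus the discounted sum of the rewards
    # that follow it (a standard REINFORCE-style return). The project's real helper
    # is defined elsewhere; this sketch is illustrative only and is not called.
    def _discounted_rewards_sketch(rewards, discount_factor):
        rewards = np.asarray(rewards, dtype=np.float64)
        returns = np.zeros_like(rewards)
        running = np.zeros(rewards.shape[0])
        for t in reversed(range(rewards.shape[1])):
            running = rewards[:, t] + discount_factor * running
            returns[:, t] = running
        return returns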

    with tf.Session(config=config, graph=G_BeamSearch) as sess_beamsearch:
        sess_beamsearch.run(tf.global_variables_initializer())
        for v in tf.trainable_variables():
            if v.name in values.keys():
                v.load(values[v.name], sess_beamsearch)

        val_loss_epo = []
        patience_cnt = 0
        for epo in range(epoch):
            print(" epoch: {}".format(epo))
            # RL_len = epo
            RL_len = L - seq2seq_len

            idx = np.arange(0, len(noisy_Id))
            idx = list(np.random.permutation(idx))

            Bleu_score1 = []

            for batch in range(len(noisy_Id) / batch_size):

                if seq2seq_len < 0:
                    seq2seq_len = 0
                    source_shuffle, source_len, train_shuffle, target_shuffle, target_len, answer_shuffle, ans_len = next_batch(
                        noisy_Id, noisy_len, train_Id, ground_truth, clean_len,
                        answer_Id, answer_len, batch_size, batch, idx)

                    source_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                        source_shuffle,
                        maxlen=None,
                        padding='post',
                        truncating="post",
                        value=EOS)

                    target_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                        target_shuffle,
                        maxlen=seq2seq_len,
                        padding='post',
                        truncating="post",
                        value=EOS)

                    train_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                        train_shuffle,
                        maxlen=seq2seq_len,
                        padding='post',
                        truncating="post",
                        value=EOS)

                    initial_input_in = [[EOS] for i in range(batch_size)]

                    target_len = np.tile(0, batch_size)

                    fd = {
                        BeamSearch_seq2seq.encoder_inputs: source_shuffle,
                        BeamSearch_seq2seq.encoder_inputs_length: source_len,
                        BeamSearch_seq2seq.decoder_length: target_len,
                        BeamSearch_seq2seq.decoder_inputs: train_shuffle,
                        BeamSearch_seq2seq.decoder_targets: target_shuffle,
                        BeamSearch_seq2seq.initial_input: initial_input_in
                    }

                    generated_que, policy = sess_beamsearch.run([
                        BeamSearch_seq2seq.RL_ids,
                        BeamSearch_seq2seq.max_policy
                    ], fd)

                    generated_que_input = np.insert(generated_que,
                                                    0,
                                                    SOS,
                                                    axis=1)[:, 0:-1]
                    generated_target_len = np.tile(
                        generated_que_input.shape[1], batch_size)

                    fd_seq = {
                        Seq2Seq_model.encoder_inputs: source_shuffle,
                        Seq2Seq_model.encoder_inputs_length: source_len,
                        Seq2Seq_model.decoder_inputs: generated_que_input,
                        Seq2Seq_model.decoder_length: generated_target_len,
                        Seq2Seq_model.ground_truth: generated_que,
                        Seq2Seq_model.dropout_rate: dropout
                    }
                    logits_pro = sess_word_rw.run(Seq2Seq_model.softmax_logits,
                                                  fd_seq)

                    answer_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                        answer_shuffle,
                        maxlen=None,
                        padding='post',
                        truncating="post",
                        value=EOS)

                    generated_len = np.tile(generated_que.shape[1], batch_size)

                    fd_qa = {
                        QA_simi_model.answer_inputs: answer_shuffle,
                        QA_simi_model.answer_inputs_length: ans_len,
                        QA_simi_model.question1_inputs: source_shuffle,
                        QA_simi_model.question1_inputs_length: source_len,
                        QA_simi_model.question2_inputs: generated_que,
                        QA_simi_model.question2_inputs_length: generated_len
                    }
                    QA_similiarity_rd = sess_QA_rw.run(
                        [QA_simi_model.two_distance], fd_qa)

                    beam_QA_similar.append(QA_similiarity_rd[0])
                    # print QA_similiarity_rd

                    reward = np.zeros((batch_size, RL_len))
                    for i in range(generated_que.shape[0]):
                        for j in range(seq2seq_len, generated_que.shape[1]):
                            max_index = generated_que[i][j]
                            reward[i][L - 1 - j] = logits_pro[i, j, max_index]
                            if j == generated_que.shape[1] - 1:
                                reward[i][L - 1 - j] = reward[i][
                                    L - 1 - j] + sen_reward_rate * (
                                        QA_similiarity_rd[0][i])

                    discounted_rewards = discounted_rewards_cal(
                        reward, discount_factor)

                    RL_rewards = discounted_rewards

                    fd = {
                        BeamSearch_seq2seq.encoder_inputs: source_shuffle,
                        BeamSearch_seq2seq.encoder_inputs_length: source_len,
                        BeamSearch_seq2seq.decoder_length: target_len,
                        BeamSearch_seq2seq.decoder_inputs: train_shuffle,
                        BeamSearch_seq2seq.decoder_targets: target_shuffle,
                        BeamSearch_seq2seq.initial_input: initial_input_in,
                        BeamSearch_seq2seq.dis_rewards: RL_rewards
                    }

                    _, rl_reward = sess_beamsearch.run([
                        BeamSearch_seq2seq.RL_train_op,
                        BeamSearch_seq2seq.rl_reward
                    ], fd)
                    for t in range(batch_size):
                        ref = []
                        hyp = []
                        bleu_s = []
                        for tar_id in target_shuffle[t]:
                            if tar_id != 2:
                                ref.append(vocab[tar_id])
                        for pre_id in generated_que[t, :]:
                            if pre_id != 0 and pre_id != -1 and pre_id != 2:
                                hyp.append(vocab[pre_id])
                        sen_bleu1 = bleu([ref], hyp, weights=(1, 0, 0, 0))
                        bleu_s.append(sen_bleu1)
                        Bleu_score1.append(bleu_s)

                    top1_bleu = []
                    for i in range(len(Bleu_score1)):
                        top1_bleu.append(Bleu_score1[i][0])

                    if batch == 0 or batch % 100 == 0:
                        print(' batch {}'.format(batch))
                        print('   minibatch reward of training: {}'.format(
                            -rl_reward))
                        print("the cumulative bleu_score1 so far is:",
                              sum(top1_bleu) / float(len(top1_bleu)))

                beam_QA_similar = []
                # beam_cnQA_similar = []
                source_shuffle, source_len, train_shuffle, target_shuffle, target_len, answer_shuffle, ans_len = next_batch(
                    noisy_Id, noisy_len, train_Id, ground_truth, clean_len,
                    answer_Id, answer_len, batch_size, batch, idx)

                source_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                    source_shuffle,
                    maxlen=None,
                    padding='post',
                    truncating="post",
                    value=EOS)

                target_shuffle_in = tf.keras.preprocessing.sequence.pad_sequences(
                    target_shuffle,
                    maxlen=seq2seq_len,
                    padding='post',
                    truncating="post",
                    value=EOS)

                train_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                    train_shuffle,
                    maxlen=seq2seq_len,
                    padding='post',
                    truncating="post",
                    value=EOS)

                initial_input_in = [
                    target_shuffle_in[i][-1] for i in range(batch_size)
                ]

                target_len = np.tile(seq2seq_len, batch_size)

                fd = {
                    BeamSearch_seq2seq.encoder_inputs: source_shuffle,
                    BeamSearch_seq2seq.encoder_inputs_length: source_len,
                    BeamSearch_seq2seq.decoder_length: target_len,
                    BeamSearch_seq2seq.decoder_inputs: train_shuffle,
                    BeamSearch_seq2seq.decoder_targets: target_shuffle_in,
                    BeamSearch_seq2seq.initial_input: initial_input_in
                }

                cl_loss, _, S2S_ids = sess_beamsearch.run([
                    BeamSearch_seq2seq.loss_seq2seq,
                    BeamSearch_seq2seq.train_op, BeamSearch_seq2seq.ids
                ], fd)

                # tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto')
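                # NOTE (assumption): batches_in_epoch is not a parameter of this
                # function; it is assumed to be defined at module scope.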

                if batch == 0 or batch % batches_in_epoch == 0:
                    print(' batch {}'.format(batch))
                    print('   minibatch loss of training: {}'.format(cl_loss))
                    val_loss = []
                    for batch_val in range(len(val_noisy_Id) / batch_size):
                        idx_v = np.arange(0, len(val_noisy_Id))
                        idx_v = list(np.random.permutation(idx_v))

                        val_source_shuffle, val_source_len, val_train_shuffle, val_target_shuffle, val_target_len, val_answer_shuffle, val_ans_len = next_batch(
                            val_noisy_Id, val_noisy_len, val_train_Id,
                            val_ground_truth, val_clean_len, val_answer_Id,
                            val_answer_len, batch_size, batch_val, idx_v)
                        val_source_shuffle_in = tf.keras.preprocessing.sequence.pad_sequences(
                            val_source_shuffle,
                            maxlen=None,
                            padding='post',
                            value=EOS)
                        val_target_shuffle_in = tf.keras.preprocessing.sequence.pad_sequences(
                            val_target_shuffle,
                            maxlen=seq2seq_len,
                            padding='post',
                            value=EOS)
                        val_train_shuffle_in = tf.keras.preprocessing.sequence.pad_sequences(
                            val_train_shuffle,
                            maxlen=seq2seq_len,
                            padding='post',
                            value=EOS)

                        val_initial_input = [
                            val_target_shuffle_in[i][-1]
                            for i in range(batch_size)
                        ]

                        val_target_len = np.tile(seq2seq_len, batch_size)

                        fd_val = {
                            BeamSearch_seq2seq.encoder_inputs:
                            val_source_shuffle_in,
                            BeamSearch_seq2seq.encoder_inputs_length:
                            val_source_len,
                            BeamSearch_seq2seq.decoder_length: val_target_len,
                            BeamSearch_seq2seq.decoder_inputs:
                            val_train_shuffle_in,
                            BeamSearch_seq2seq.decoder_targets:
                            val_target_shuffle_in,
                            BeamSearch_seq2seq.initial_input: val_initial_input
                        }
                        val_loss.append(
                            sess_beamsearch.run(
                                BeamSearch_seq2seq.loss_seq2seq, fd_val))
                        avg_val_loss = sum(val_loss) / float(len(val_loss))

                    print('   minibatch loss of validation: {}'.format(
                        avg_val_loss))

                # val_loss_epo.append(avg_val_loss)

                gc.collect()

                if L - seq2seq_len == 0: continue

                RL_logits, policy = sess_beamsearch.run(
                    [BeamSearch_seq2seq.RL_ids, BeamSearch_seq2seq.max_policy],
                    fd)
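                # RL_logits holds the token ids the RL decoder produces for the
                # positions beyond the seq2seq prefix; policy holds the matching
                # policy outputs returned by the graph.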

                # print "The size of Rl:", RL_logits.shape[0], RL_logits.shape[1]
                # print "The size of Policy:", policy.shape[0], policy.shape[1]

                max_target_length = RL_logits.shape[1]

                # for batch in range(batch_size):
                #     label = 0
                #     for i in range(RL_logits.shape[1]):
                #         if RL_logits[batch][i]==2 and label ==0:
                #             label = 1
                #             for id in range(i+1, max_target_length):
                #                 RL_logits[batch][id]=2
                #             continue

                Sequnce_len = seq2seq_len + max_target_length
                # print "the max_target_length is:", max_target_length

                RL_rewards = np.zeros((batch_size, max_target_length))

                generated_que = []

                for i in range(batch_size):
                    generated_que.append(
                        list(np.append(S2S_ids[i], RL_logits[i])))

                generated_que = np.asarray(generated_que)
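                # generated_que now holds, per example, the teacher-forced prefix
                # (S2S_ids) followed by the RL continuation (RL_logits).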

                logits_batch = []
                # for i in range(generated_que.shape[0]):
                #     for j in range(generated_que.shape[1]):
                #         if generated_que[i][j] == 2 or j == generated_que.shape[1] - 1:
                #             logits_batch.append(j)
                #             continue
                #
                # reward = np.zeros((batch_size, max_target_length))
                # # generated_que = SOS + generated_que - PAD
                generated_que_input = np.insert(generated_que, 0, SOS,
                                                axis=1)[:, 0:-1]
                # target_shuffle = tf.keras.preprocessing.sequence.pad_sequences(target_shuffle, maxlen=None,
                #                                                                padding='post',truncating="post",  value=EOS)

                generated_target_len = np.tile(generated_que_input.shape[1],
                                               batch_size)

                fd_seq = {
                    Seq2Seq_model.encoder_inputs: source_shuffle,
                    Seq2Seq_model.encoder_inputs_length: source_len,
                    Seq2Seq_model.decoder_inputs: generated_que_input,
                    Seq2Seq_model.decoder_length: generated_target_len,
                    Seq2Seq_model.ground_truth: generated_que,
                    Seq2Seq_model.dropout_rate: dropout
                }
                logits_pro = sess_word_rw.run(Seq2Seq_model.softmax_logits,
                                              fd_seq)
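                # The pretrained word-level seq2seq acts as a reward model here:
                # logits_pro[i, j, w] is the probability it assigns to token w at
                # step j, and is used below as the word-level reward.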

                # logits_pro = logits_pro[:, :, :]
                # print "the logits_pro is:", logits_pro.shape[1]

                answer_shuffle = tf.keras.preprocessing.sequence.pad_sequences(
                    answer_shuffle,
                    maxlen=None,
                    padding='post',
                    truncating="post",
                    value=EOS)

                generated_len = np.tile(generated_que.shape[1], batch_size)

                fd_qa = {
                    QA_simi_model.answer_inputs: answer_shuffle,
                    QA_simi_model.answer_inputs_length: ans_len,
                    QA_simi_model.question1_inputs: source_shuffle,
                    QA_simi_model.question1_inputs_length: source_len,
                    QA_simi_model.question2_inputs: generated_que,
                    QA_simi_model.question2_inputs_length: generated_len
                }
                QA_similiarity_rd = sess_QA_rw.run(
                    [QA_simi_model.two_distance], fd_qa)
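                # Sentence-level signal: two_distance compares the generated
                # question with the answer and the noisy question; it is added to
                # the reward at the first <EOS>, scaled by sen_reward_rate.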

                beam_QA_similar.append(QA_similiarity_rd[0])
                # print QA_similiarity_rd

                # reward = np.zeros((batch_size, RL_len))

                # only use word reward
                for i in range(generated_que.shape[0]):
                    label = 0
                    for j in range(seq2seq_len, generated_que.shape[1]):
                        max_index = generated_que[i][j]
                        RL_rewards[i][j - seq2seq_len] = logits_pro[i, j,
                                                                    max_index]

                # only use QA similiarity
                # for i in range(generated_que.shape[0]):
                #     label = 0
                #     for j in range(seq2seq_len, generated_que.shape[1]):
                #         max_index = generated_que[i][j]
                #         RL_rewards[i][j - seq2seq_len] = 0
                #         if max_index == 2 and label == 0:
                #             label = 1
                #             RL_rewards[i][j - seq2seq_len] = 0 + sen_reward_rate * (QA_similiarity_rd[0][i])
                #             break

                # use both QA similarity and word reward
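                # this loop supersedes the word-only rewards assigned above: the
                # QA-similarity bonus is added at the first <EOS> token (id 2).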
                for i in range(generated_que.shape[0]):
                    label = 0
                    for j in range(seq2seq_len, generated_que.shape[1]):
                        max_index = generated_que[i][j]
                        RL_rewards[i][j - seq2seq_len] = logits_pro[i, j,
                                                                    max_index]
                        if max_index == 2 and label == 0:
                            label = 1
                            RL_rewards[i][j - seq2seq_len] = logits_pro[
                                i, j, max_index] + sen_reward_rate * (
                                    QA_similiarity_rd[0][i])
                            break

                # fd = {an_Lstm.answer_inputs: answer_shuffle, an_Lstm.answer_inputs_length: ans_len}
                # reward_similiarity = sess_sen_rw.run([an_Lstm.answer_state], fd)

                discounted_rewards = discounted_rewards_cal(
                    RL_rewards, discount_factor)

                RL_rewards = discounted_rewards

                fd = {
                    BeamSearch_seq2seq.encoder_inputs: source_shuffle,
                    BeamSearch_seq2seq.encoder_inputs_length: source_len,
                    BeamSearch_seq2seq.decoder_length: target_len,
                    BeamSearch_seq2seq.decoder_inputs: train_shuffle,
                    BeamSearch_seq2seq.decoder_targets: target_shuffle_in,
                    BeamSearch_seq2seq.initial_input: initial_input_in,
                    BeamSearch_seq2seq.dis_rewards: RL_rewards
                }
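                # REINFORCE-style update: the discounted returns are fed in and
                # RL_train_op is run; rl_reward is the (negated) objective that is
                # printed below as the minibatch reward.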

                _, rl_reward, word_policy, policy = sess_beamsearch.run([
                    BeamSearch_seq2seq.RL_train_op,
                    BeamSearch_seq2seq.rl_reward,
                    BeamSearch_seq2seq.word_log_prob,
                    BeamSearch_seq2seq.softmax_policy
                ], fd)

                for t in range(batch_size):
                    ref = []
                    hyp = []
                    bleu_s = []
                    for tar_id in target_shuffle[t]:
                        if tar_id != 2:
                            ref.append(vocab[tar_id])
                    for pre_id in generated_que[t, :]:
                        if pre_id != 0 and pre_id != -1 and pre_id != 2:
                            hyp.append(vocab[pre_id])
                    sen_bleu1 = bleu([ref], hyp, weights=(1, 0, 0, 0))
                    bleu_s.append(sen_bleu1)
                    Bleu_score1.append(bleu_s)

                top1_bleu = []
                for i in range(len(Bleu_score1)):
                    top1_bleu.append(Bleu_score1[i][0])

                if batch == 0 or batch % 100 == 0:
                    print(' batch {}'.format(batch))
                    print('   minibatch reward of training: {}'.format(
                        -rl_reward))
                    # print the result
                    for t in range(5):  ## five sentences
                        print('Question {}'.format(t))
                        print("noisy question:")
                        print(" ".join(
                            map(lambda i: vocab[i],
                                list(source_shuffle[t]))).strip().replace(
                                    "<EOS>", " "))
                        print("clean question:")
                        print(" ".join(
                            map(lambda i: vocab[i],
                                list(target_shuffle[t]))).strip().replace(
                                    "<EOS>", " "))
                        print("the generated question:")
                        pre_sen = []
                        for log_id in generated_que[t, :]:
                            if log_id != -1 and log_id != 2 and log_id != 0:
                                pre_sen.append(log_id)
                        print(" ".join(map(lambda i: vocab[i],
                                           pre_sen)).strip())
                        print("\n")
                    print("the cumulative bleu_score1 so far is:",
                          sum(top1_bleu) / float(len(top1_bleu)))
                gc.collect()

            seq2seq_len = seq2seq_len - dec_L
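            # hand dec_L more decoding positions over to the RL decoder in the
            # next epoch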

            ##testing set result:
            test_Bleu_score1 = []
            test_Bleu_score2 = []
            test_Bleu_score3 = []
            test_Bleu_score4 = []
            test_Rouge_score = []
            test_Meteor_score = []
            test_Cider_score = []

            generated_test_sen = []
            test_answer_sen = []
            test_noisy_sen = []
            test_clean_sen = []

            for batch_test in range(len(test_noisy_Id) / batch_size):
                idx_t = np.arange(0, len(test_noisy_Id))
                # idx_t = list(np.random.permutation(idx_t))

                val_source_shuffle, val_source_len, val_train_shuffle, val_target_shuffle, val_target_len, val_answer_shuffle, val_ans_len = next_batch(
                    test_noisy_Id, test_noisy_len, test_train_Id,
                    test_ground_truth, test_clean_len, test_answer_Id,
                    test_answer_len, batch_size, batch_test, idx_t)

                # keep plain token lists (not one-shot generators) so the
                # sentences can be re-used when the results are written out
                for an in val_answer_shuffle:
                    test_answer_sen.append([vocab[anum] for anum in an])
                for no in val_source_shuffle:
                    test_noisy_sen.append([vocab[si] for si in no])
                for cl in val_target_shuffle:
                    test_clean_sen.append([vocab[ci] for ci in cl])

                val_source_shuffle_in = tf.keras.preprocessing.sequence.pad_sequences(
                    val_source_shuffle,
                    maxlen=None,
                    padding='post',
                    truncating="post",
                    value=EOS)

                val_target_shuffle_in = []
                for val in val_target_shuffle:
                    val_target_shuffle_in.append([val[0]])

                val_train_shuffle_in = []
                for tra in val_train_shuffle:
                    val_train_shuffle_in.append([tra[0]])

                initial_input_in = [SOS for i in range(batch_size)]
                # [SOS for i in range(batch_size)]

                val_target_len_in = np.tile(1, batch_size)
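                # at test time only <SOS> is fed with a decoder length of 1, so the
                # whole refined question is produced by the decoder itself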

                fd = {
                    BeamSearch_seq2seq.encoder_inputs: val_source_shuffle_in,
                    BeamSearch_seq2seq.encoder_inputs_length: val_source_len,
                    BeamSearch_seq2seq.decoder_length: val_target_len_in,
                    BeamSearch_seq2seq.decoder_inputs: val_train_shuffle_in,
                    BeamSearch_seq2seq.decoder_targets: val_target_shuffle_in,
                    BeamSearch_seq2seq.initial_input: initial_input_in
                }

                val_id = sess_beamsearch.run([BeamSearch_seq2seq.RL_ids], fd)

                final_id = val_id[0]

                for t in range(batch_size):
                    ref = []
                    hyp = []
                    for tar_id in val_target_shuffle[t]:
                        if tar_id != 2:
                            ref.append(vocab[tar_id])
                    for pre_id in final_id[t, :]:
                        if pre_id != 2 and pre_id != 0 and pre_id != -1:
                            hyp.append(vocab[pre_id])
                    generated_test_sen.append(hyp)

                    hyp_sen = " ".join(hyp)
                    ref_sen = " ".join(ref)
                    dic_hyp = {}
                    dic_hyp[0] = [hyp_sen]
                    dic_ref = {}
                    dic_ref[0] = [ref_sen]
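                    # per-sentence scoring: each reference/hypothesis pair is
                    # wrapped as a one-entry {id: [sentence]} dict, the format
                    # these scorers expect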
                    sen_bleu, _ = Bleu_obj.compute_score(dic_ref, dic_hyp)
                    sen_rouge = Rouge_obj.compute_score(dic_ref, dic_hyp)
                    sen_meteor, _ = Meteor_obj.compute_score(dic_ref, dic_hyp)
                    sen_cider, _ = cIDER_obj.compute_score(dic_ref, dic_hyp)
                    test_Bleu_score1.append(sen_bleu[0])
                    test_Bleu_score2.append(sen_bleu[1])
                    test_Bleu_score3.append(sen_bleu[2])
                    test_Bleu_score4.append(sen_bleu[3])
                    test_Rouge_score.append(sen_rouge[0])
                    test_Meteor_score.append(sen_meteor)
                    test_Cider_score.append(sen_cider)

                bleu_score1 = sum(test_Bleu_score1) / float(
                    len(test_Bleu_score1))
                bleu_score2 = sum(test_Bleu_score2) / float(
                    len(test_Bleu_score2))
                bleu_score3 = sum(test_Bleu_score3) / float(
                    len(test_Bleu_score3))
                bleu_score4 = sum(test_Bleu_score4) / float(
                    len(test_Bleu_score4))

                print(
                    "\n Bleu_score1, Bleu_score2, Bleu_score3, Bleu_score4, Rouge, Meteor, Cider\n"
                )
                print(
                    "   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f}"
                    .format(
                        bleu_score1, bleu_score2, bleu_score3, bleu_score4,
                        sum(test_Rouge_score) / float(len(test_Rouge_score)),
                        sum(test_Meteor_score) / float(len(test_Meteor_score)),
                        sum(test_Cider_score) / float(len(test_Cider_score))))

                print "the bleu score on test is:", bleu_score1

                if bleu_score1 > 0.55:
                    fname = "Wiki_result_whole/wiki_bleu1_" + str(bleu_score1)
                    f = open(fname, "wb")
                    f.write(
                        "   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f},   {:.4f} \n"
                        .format(
                            bleu_score1, bleu_score2, bleu_score3, bleu_score4,
                            sum(test_Rouge_score) /
                            float(len(test_Rouge_score)),
                            sum(test_Meteor_score) /
                            float(len(test_Meteor_score)),
                            sum(test_Cider_score) /
                            float(len(test_Cider_score))))

                    f.write("\n" + "reward_rate " + str(sen_reward_rate) +
                            " discount_factor " + str(discount_factor) +
                            " dec_L " + str(dec_L) + " batch_size " +
                            str(batch_size) + " epoch " + str(epoch) + "\n")

                    for i in range(len(generated_test_sen)):
                        f.write("question " + str(i) + "\n")
                        f.write("answer: " + " ".join(test_answer_sen[i]) +
                                "\n")
                        f.write("noisy question: " +
                                " ".join(test_noisy_sen[i]) + "\n")
                        f.write("clean question: " +
                                " ".join(test_clean_sen[i]) + "\n")
                        f.write("generated question: " +
                                " ".join(generated_test_sen[i]) + "\n")