Example #1
    def generate_and_save(self, data_util, infile, generate_batch, outfile):
        outfile = codecs.open(outfile, 'w', 'utf-8')
        for batch in data_util.get_test_batches(infile, generate_batch):
            feed = {self.generate_x: batch}
            out_generate = self.sess.run(self.generate_sample, feed_dict=feed)
            out_generate_dealed, _ = deal_generated_samples(
                out_generate, data_util.dst2idx)

            y_strs = data_util.indices_to_words_del_pad(
                out_generate_dealed, 'dst')
            for y_str in y_strs:
                outfile.write(y_str + '\n')
        outfile.close()
Example #2
def gan_train(config):
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    default_graph = tf.Graph()
    with default_graph.as_default():
        sess = tf.Session(config=sess_config, graph=default_graph)

        logger = logging.getLogger('')
        du = DataUtil(config=config)
        du.load_vocab(src_vocab=config.generator.src_vocab,
                      dst_vocab=config.generator.dst_vocab,
                      src_vocab_size=config.src_vocab_size,
                      dst_vocab_size=config.dst_vocab_size)

        generator = Model(config=config, graph=default_graph, sess=sess)
        generator.build_train_model()
        generator.build_generate(max_len=config.generator.max_length,
                                 generate_devices=config.generator.devices,
                                 optimizer=config.generator.optimizer)

        generator.build_rollout_generate(
            max_len=config.generator.max_length,
            roll_generate_devices=config.generator.devices)

        generator.init_and_restore(modelFile=config.generator.modelFile)

        dis_filter_sizes = [
            i for i in range(1, config.discriminator.dis_max_len, 4)
        ]
        dis_num_filters = [
            (100 + i * 10)
            for i in range(1, config.discriminator.dis_max_len, 4)
        ]

        discriminator = DisCNN(
            sess=sess,
            max_len=config.discriminator.dis_max_len,
            num_classes=2,
            vocab_size=config.dst_vocab_size,
            vocab_size_s=config.src_vocab_size,
            batch_size=config.discriminator.dis_batch_size,
            dim_word=config.discriminator.dis_dim_word,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            source_dict=config.discriminator.dis_src_vocab,
            target_dict=config.discriminator.dis_dst_vocab,
            gpu_device=config.discriminator.dis_gpu_devices,
            positive_data=config.discriminator.dis_positive_data,
            negative_data=config.discriminator.dis_negative_data,
            source_data=config.discriminator.dis_source_data,
            dev_positive_data=config.discriminator.dis_dev_positive_data,
            dev_negative_data=config.discriminator.dis_dev_negative_data,
            dev_source_data=config.discriminator.dis_dev_source_data,
            max_epoches=config.discriminator.dis_max_epoches,
            dispFreq=config.discriminator.dis_dispFreq,
            saveFreq=config.discriminator.dis_saveFreq,
            saveto=config.discriminator.dis_saveto,
            reload=config.discriminator.dis_reload,
            clip_c=config.discriminator.dis_clip_c,
            optimizer=config.discriminator.dis_optimizer,
            reshuffle=config.discriminator.dis_reshuffle,
            scope=config.discriminator.dis_scope)

        batch_iter = du.get_training_batches(
            set_train_src_path=config.generator.src_path,
            set_train_dst_path=config.generator.dst_path,
            set_batch_size=config.generator.batch_size,
            set_max_length=config.generator.max_length)

        for epoch in range(1, config.gan_iter_num + 1):
            for gen_iter in range(config.gan_gen_iter_num):
                batch = next(batch_iter)
                x, y_ground = batch[0], batch[1]
                y_sample = generator.generate_step(x)
                logging.info("generate the samples")
                y_sample_dealed, y_sample_mask = deal_generated_samples(
                    y_sample, du.dst2idx)
                # for debug:
                # print('the sample is ')
                # sample_str = du.indices_to_words(y_sample_dealed, 'dst')
                # print(sample_str)
                x_to_maxlen = extend_sentence_to_maxlen(
                    x, config.generator.max_length)
                logging.info("calculate the reward")
                rewards = generator.get_reward(
                    x=x,
                    x_to_maxlen=x_to_maxlen,
                    y_sample=y_sample_dealed,
                    y_sample_mask=y_sample_mask,
                    rollnum=config.rollnum,
                    disc=discriminator,
                    max_len=config.discriminator.dis_max_len,
                    bias_num=config.bias_num,
                    data_util=du)

                loss = generator.generate_step_and_update(
                    x, y_sample_dealed, rewards)

                print("the reward is ", rewards)
                print("the loss is ", loss)

                logging.info("save the model into %s" %
                             config.generator.modelFile)
                generator.saver.save(generator.sess,
                                     config.generator.modelFile)

                if config.generator.teacher_forcing:

                    logging.info("doiong the teacher forcing begin!")
                    y_ground, y_ground_mask = deal_generated_samples_to_maxlen(
                        y_sample=y_ground,
                        dicts=du.dst2idx,
                        maxlen=config.discriminator.dis_max_len)

                    rewards_ground = np.ones_like(y_ground)
                    rewards_ground = rewards_ground * y_ground_mask
                    loss = generator.generate_step_and_update(
                        x, y_ground, rewards_ground)
                    print("the teacher forcing reward is ", rewards_ground)
                    print("the teacher forcing loss is ", loss)

            generator.saver.save(generator.sess, config.generator.modelFile)

            logging.info("prepare the gan_dis_data begin")
            data_num = prepare_gan_dis_data(
                train_data_source=config.generator.src_path,
                train_data_target=config.generator.dst_path,
                gan_dis_source_data=config.discriminator.dis_source_data,
                gan_dis_positive_data=config.discriminator.dis_positive_data,
                num=config.generate_num,
                reshuf=True)

            logging.info("generate and the save in to %s." %
                         config.discriminator.dis_negative_data)
            generator.generate_and_save(
                data_util=du,
                infile=config.discriminator.dis_source_data,
                generate_batch=config.discriminator.dis_batch_size,
                outfile=config.discriminator.dis_negative_data)

            logging.info("prepare %d gan_dis_data done!" % data_num)
            logging.info("finetuen the discriminator begin")

            discriminator.train(
                max_epoch=config.gan_dis_iter_num,
                positive_data=config.discriminator.dis_positive_data,
                negative_data=config.discriminator.dis_negative_data,
                source_data=config.discriminator.dis_source_data)
            discriminator.saver.save(discriminator.sess, discriminator.saveto)
            logging.info("finetune the discrimiantor done!")

        logging.info('reinforcement training done!')
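The teacher-forcing branch above reduces the policy-gradient update to plain maximum likelihood on the ground-truth targets: every real token receives a reward of 1 and every padding position receives 0. A minimal numpy sketch of that reward construction (the token ids are made up, and 0 is assumed to be the padding index):

import numpy as np

y_ground = np.array([[4, 7, 2, 0, 0],      # token ids, 0 = padding
                     [5, 9, 8, 3, 0]])
y_ground_mask = (y_ground != 0).astype(np.float32)

# every real token is rewarded with 1, padding contributes nothing
rewards_ground = np.ones_like(y_ground) * y_ground_mask
print(rewards_ground)
# [[1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 0.]]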
Example #3
def train(config):
    """Train a model with a config file."""
    logger = logging.getLogger('')

    du = DataUtil(config=config)
    du.load_vocab(src_vocab=config.src_vocab,
                  dst_vocab=config.dst_vocab,
                  src_vocab_size=config.src_vocab_size_a,
                  dst_vocab_size=config.src_vocab_size_b)

    model = Model(config=config)
    model.build_variational_train_model()

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    with model.graph.as_default():
        saver = tf.train.Saver(var_list=tf.global_variables())
        summary_writer = tf.summary.FileWriter(config.train.logdir, graph=model.graph)
        # saver_partial = tf.train.Saver(var_list=[v for v in tf.trainable_variables() if 'Adam' not in v.name])

        with tf.Session(config=sess_config) as sess:
            # Initialize all variables.
            sess.run(tf.global_variables_initializer())
            reload_pretrain_embedding = False
            try:
                # saver_partial.restore(sess, tf.train.latest_checkpoint(config.train.logdir))
                # print('Restore partial model from %s.' % config.train.logdir)
                saver.restore(sess, tf.train.latest_checkpoint(config.train.logdir))
            except Exception:
                logger.info('Failed to reload model.')
                reload_pretrain_embedding = True

            if reload_pretrain_embedding:
                logger.info('reload the pretrained embeddings for the encoders')
                src_pretrained_embedding = {}
                dst_pretrained_embedding = {}
                try:
                    for l in codecs.open(config.train.src_pretrain_wordemb_path, 'r', 'utf-8'):
                        word_emb = l.strip().split()
                        # print(word_emb)
                        if len(word_emb) == config.hidden_units + 1:
                            # np.array(map(...)) yields a 0-d object array in Python 3,
                            # so parse the floats explicitly
                            word, emb = word_emb[0], np.array(word_emb[1:], dtype=float)
                            src_pretrained_embedding[word] = emb

                    for l in codecs.open(config.train.dst_pretrain_wordemb_path, 'r', 'utf-8'):
                        word_emb = l.strip().split()
                        if len(word_emb) == config.hidden_units + 1:
                            word, emb = word_emb[0], np.array(word_emb[1:], dtype=float)
                            dst_pretrained_embedding[word] = emb

                    logger.info('reload the word embedding done')

                    tf.get_variable_scope().reuse_variables()
                    src_embed_a=tf.get_variable('enc_aembedding/src_embedding/kernel')
                    src_embed_b=tf.get_variable('enc_bembedding/src_embedding/kernel')

                    dst_embed_a=tf.get_variable('dec_aembedding/dst_embedding/kernel')
                    dst_embed_b=tf.get_variable('dec_bembedding/dst_embedding/kernel')

                    count_a=0
                    src_value_a=sess.run(src_embed_a)
                    dst_value_a=sess.run(dst_embed_a)
                    # print(src_value_a)
                    for word in src_pretrained_embedding:
                        if word in du.src2idx:
                            id = du.src2idx[word]
                            # print(id)
                            src_value_a[id] = src_pretrained_embedding[word]
                            dst_value_a[id] = src_pretrained_embedding[word]
                            count_a += 1
                    sess.run(src_embed_a.assign(src_value_a))
                    sess.run(dst_embed_a.assign(dst_value_a))
                    # print(sess.run(src_embed_a))


                    count_b=0
                    src_value_b = sess.run(src_embed_b)
                    dst_value_b = sess.run(dst_embed_b)
                    for word in dst_pretrained_embedding:
                        if word in du.dst2idx:
                            id = du.dst2idx[word]
                            # print(id)
                            src_value_b[id] = dst_pretrained_embedding[word]
                            dst_value_b[id] = dst_pretrained_embedding[word]
                            count_b += 1
                    sess.run(src_embed_b.assign(src_value_b))
                    sess.run(dst_embed_b.assign(dst_value_b))

                    logger.info('restore %d src_embedding and %d dst_embedding done' %(count_a, count_b))

                except Exception:
                    logger.info('Failed to load the pretrained embeddings')

            # tmp_writer = codecs.open('tmp_test', 'w', 'utf-8')

            for epoch in range(1, config.train.num_epochs+1):
                for batch in du.get_training_batches_with_buckets():
                    # swap batch[0] and batch[1] according to whether the length of the sequence is odd or even
                    # batch_swap=[]
                    # swap_0 = np.arange(batch[0].shape[1])
                    # swap_1 = np.arange(batch[1].shape[1])
                    #
                    # if len(swap_0) % 2 == 0:
                    #     swap_0[0::2]+=1
                    #     swap_0[1::2]-=1
                    # else:
                    #     swap_0[0:-1:2]+=1
                    #     swap_0[1::2]-=1
                    #
                    # if len(swap_1) % 2 == 0:
                    #     swap_1[0::2]+=1
                    #     swap_1[1::2]-=1
                    # else:
                    #     swap_1[0:-1:2] += 1
                    #     swap_1[1::2] -= 1
                    #
                    # batch_swap.append(batch[0].transpose()[swap_0].transpose())
                    # batch_swap.append(batch[1].transpose()[swap_1].transpose())

                    # print(batch[0])
                    # print(batch_swap[0])

                    # randomly shuffle the batch[0] and batch[1]
                    #batch_shuffle=[]
                    #shuffle_0_indices = np.random.permutation(np.arange(batch[0].shape[1]))
                    #shuffle_1_indices = np.random.permutation(np.arange(batch[1].shape[1]))
                    #batch_shuffle.append(batch[0].transpose()[shuffle_0_indices].transpose())
                    #batch_shuffle.append(batch[1].transpose()[shuffle_1_indices].transpose())


                    def get_shuffle_k_indices(length, shuffle_k):
                        shuffle_k_indices = []
                        rand_start = np.random.randint(shuffle_k)

                        indices_list_start = list(np.random.permutation(np.arange(0, rand_start)))
                        shuffle_k_indices.extend(indices_list_start)

                        for i in range(rand_start, length, shuffle_k):
                            if i + shuffle_k > length:
                                indices_list_i = list(np.random.permutation(np.arange(i, length)))
                            else:
                                indices_list_i = list(np.random.permutation(np.arange(i, i + shuffle_k)))

                            shuffle_k_indices.extend(indices_list_i)

                        return np.array(shuffle_k_indices)

                    batch_shuffle=[]
                    shuffle_0_indices = get_shuffle_k_indices(batch[0].shape[1], config.train.shuffle_k)
                    shuffle_1_indices = get_shuffle_k_indices(batch[1].shape[1], config.train.shuffle_k)
                    #print(shuffle_0_indices)
                    batch_shuffle.append(batch[0].transpose()[shuffle_0_indices].transpose())
                    batch_shuffle.append(batch[1].transpose()[shuffle_1_indices].transpose())

                    start_time = time.time()
                    step = sess.run(model.global_step)

                    step, lr, gnorm_aa, loss_aa, acc_aa, _ = sess.run(
                        [model.global_step, model.learning_rate, model.grads_norm_aa,
                         model.loss_aa, model.acc_aa, model.train_op_aa],
                        feed_dict={model.src_a_pl: batch_shuffle[0], model.dst_a_pl: batch[0]})

                    step, lr, gnorm_bb, loss_bb, acc_bb, _ = sess.run(
                        [model.global_step, model.learning_rate, model.grads_norm_bb,
                         model.loss_bb, model.acc_bb, model.train_op_bb],
                        feed_dict={model.src_b_pl: batch_shuffle[1], model.dst_b_pl: batch[1]})


                    # this step takes too much time
                    generate_ab, generate_ba = sess.run(
                        [model.generate_ab, model.generate_ba],
                        feed_dict={model.src_a_pl: batch[0], model.src_b_pl: batch[1]})

                    generate_ab_dealed, _ = deal_generated_samples(generate_ab, du.dst2idx)
                    generate_ba_dealed, _ = deal_generated_samples(generate_ba, du.src2idx)

                    #for sent in du.indices_to_words(batch[0], o='src'):
                    #    print(sent, file=tmp_writer)
                    #for sent in du.indices_to_words(generate_ab_dealed, o='dst'):
                    #    print(sent, file=tmp_writer)

                    step, acc_ab, loss_ab, _ = sess.run(
                        [model.global_step, model.acc_ab, model.loss_ab, model.train_op_ab],
                        feed_dict={model.src_a_pl:generate_ba_dealed, model.dst_b_pl: batch[1]})

                    step, acc_ba, loss_ba, _ = sess.run(
                        [model.global_step, model.acc_ba, model.loss_ba, model.train_op_ba],
                        feed_dict={model.src_b_pl:generate_ab_dealed, model.dst_a_pl: batch[0]})

                    if step % config.train.disp_freq == 0:
                        logger.info('epoch: {0}\tstep: {1}\tlr: {2:.6f}\tgnorm: {3:.4f}\tloss: {4:.4f}'
                                    '\tacc: {5:.4f}\tcross_loss: {6:.4f}\tcross_acc: {7:.4f}\ttime: {8:.4f}'
                                    .format(epoch, step, lr, gnorm_aa, loss_aa, acc_aa, loss_ab, acc_ab,
                                            time.time() - start_time))

                    # Save model
                    if step % config.train.save_freq == 0:
                        mp = config.train.logdir + '/model_epoch_%d_step_%d' % (epoch, step)
                        saver.save(sess, mp)
                        logger.info('Save model in %s.' % mp)

            logger.info("Finish training.")
Example #4
def gan_train(config):
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    default_graph=tf.Graph()
    with default_graph.as_default():
        sess = tf.Session(config=sess_config, graph=default_graph)

        logger = logging.getLogger('')
        du = DataUtil(config=config)
        du.load_vocab(src_vocab=config.generator.src_vocab,
                      dst_vocab=config.generator.dst_vocab,
                      src_vocab_size=config.src_vocab_size_a,
                      dst_vocab_size=config.src_vocab_size_b)

        generator = Model(config=config, graph=default_graph, sess=sess)
        generator.build_variational_train_model()

        generator.init_and_restore(modelFile=config.generator.modelFile)

        dis_filter_sizes = [i for i in range(1, config.discriminator.dis_max_len, 4)]
        dis_num_filters = [(100 + i * 10) for i in range(1, config.discriminator.dis_max_len, 4)]

        discriminator = text_DisCNN(
            sess=sess,
            max_len=config.discriminator.dis_max_len,
            num_classes=3,
            vocab_size_s=config.dst_vocab_size_a,
            vocab_size_t=config.dst_vocab_size_b,
            batch_size=config.discriminator.dis_batch_size,
            dim_word=config.discriminator.dis_dim_word,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            source_dict=config.discriminator.dis_src_vocab,
            target_dict=config.discriminator.dis_dst_vocab,
            gpu_device=config.discriminator.dis_gpu_devices,
            s_domain_data=config.discriminator.s_domain_data,
            t_domain_data=config.discriminator.t_domain_data,
            s_domain_generated_data=config.discriminator.s_domain_generated_data,
            t_domain_generated_data=config.discriminator.t_domain_generated_data,
            dev_s_domain_data=config.discriminator.dev_s_domain_data,
            dev_t_domain_data=config.discriminator.dev_t_domain_data,
            dev_s_domain_generated_data=config.discriminator.dev_s_domain_generated_data,
            dev_t_domain_generated_data=config.discriminator.dev_t_domain_generated_data,
            max_epoches=config.discriminator.dis_max_epoches,
            dispFreq=config.discriminator.dis_dispFreq,
            saveFreq=config.discriminator.dis_saveFreq,
            saveto=config.discriminator.dis_saveto,
            reload=config.discriminator.dis_reload,
            clip_c=config.discriminator.dis_clip_c,
            optimizer=config.discriminator.dis_optimizer,
            reshuffle=config.discriminator.dis_reshuffle,
            scope=config.discriminator.dis_scope
        )

        batch_iter = du.get_training_batches(
            set_train_src_path=config.generator.src_path,
            set_train_dst_path=config.generator.dst_path,
            set_batch_size=config.generator.batch_size,
            set_max_length=config.generator.max_length
        )

        for epoch in range(1, config.gan_iter_num + 1):
            for gen_iter in range(config.gan_gen_iter_num):
                batch = next(batch_iter)
                x, y = batch[0], batch[1]
                generate_ab, generate_ba = generator.generate_step(x, y)

                logging.info("generate the samples")
                generate_ab_dealed, generate_ab_mask = deal_generated_samples(generate_ab, du.dst2idx)
                generate_ba_dealed, generate_ba_mask = deal_generated_samples(generate_ba, du.src2idx)

                
                # for debug:
                # print('the generate_ba_dealed is ')
                # sample_str = du.indices_to_words(generate_ba_dealed, 'src')
                # print(sample_str)

                # print('the generate_ab_dealed is ')
                # sample_str = du.indices_to_words(generate_ab_dealed, 'dst')
                # print(sample_str)
                

                x_to_maxlen = extend_sentence_to_maxlen(x)
                y_to_maxlen = extend_sentence_to_maxlen(y)

                logging.info("calculate the reward")
                rewards_ab = generator.get_reward(x=x,
                                               x_to_maxlen=x_to_maxlen,
                                               y_sample=generate_ab_dealed,
                                               y_sample_mask=generate_ab_mask,
                                               rollnum=config.rollnum,
                                               disc=discriminator,
                                               max_len=config.discriminator.dis_max_len,
                                               bias_num=config.bias_num,
                                               data_util=du,
                                               direction='ab')

                rewards_ba = generator.get_reward(x=y,
                                               x_to_maxlen=y_to_maxlen,
                                               y_sample=generate_ba_dealed,
                                               y_sample_mask=generate_ba_mask,
                                               rollnum=config.rollnum,
                                               disc=discriminator,
                                               max_len=config.discriminator.dis_max_len,
                                               bias_num=config.bias_num,
                                               data_util=du,
                                               direction='ba')
                

                loss_ab = generator.generate_step_and_update(x, generate_ab_dealed, rewards_ab)

                loss_ba = generator.generate_step_and_update(y, generate_ba_dealed, rewards_ba)

                print("the reward for ab and ba is ", rewards_ab, rewards_ba)
                print("the loss is for ab and ba is", loss_ab, loss_ba)

                logging.info("save the model into %s" % config.generator.modelFile)
                generator.saver.save(generator.sess, config.generator.modelFile)


            ####  modified to here, next starts from here

            logging.info("prepare the gan_dis_data begin")
            data_num = prepare_gan_dis_data(
                train_data_source=config.generator.src_path,
                train_data_target=config.generator.dst_path,
                gan_dis_source_data=config.discriminator.s_domain_data,
                gan_dis_positive_data=config.discriminator.t_domain_data,
                num=config.generate_num,
                reshuf=True
            )
            
            s_domain_data_half = config.discriminator.s_domain_data + '.half'
            t_domain_data_half = config.discriminator.t_domain_data + '.half'

            # use integer division so head/tail receive a whole line count
            os.popen('head -n ' + str(config.generate_num // 2) + ' ' + config.discriminator.s_domain_data + ' > ' + s_domain_data_half)
            os.popen('tail -n ' + str(config.generate_num // 2) + ' ' + config.discriminator.t_domain_data + ' > ' + t_domain_data_half)
            
            logging.info("generate and the save t_domain_generated_data in to %s." %config.discriminator.s_domain_generated_data)

            generator.generate_and_save(data_util=du,
                                        infile=s_domain_data_half,
                                        generate_batch=config.discriminator.dis_batch_size,
                                        outfile=config.discriminator.t_domain_generated_data,
                                        direction='ab'
                                      )

            logging.info("generate and the save s_domain_generated_data in to %s." %config.discriminator.t_domain_generated_data)

            generator.generate_and_save(data_util=du,
                                        infile=t_domain_data_half,
                                        generate_batch=config.discriminator.dis_batch_size,
                                        outfile=config.discriminator.s_domain_generated_data,
                                        direction='ba'
                                      )
            
            logging.info("prepare %d gan_dis_data done!" %data_num)
            logging.info("finetuen the discriminator begin")

            discriminator.train(max_epoch=config.gan_dis_iter_num,
                                s_domain_data=config.discriminator.s_domain_data,
                                t_domain_data=config.discriminator.t_domain_data,
                                s_domain_generated_data=config.discriminator.s_domain_generated_data,
                                t_domain_generated_data=config.discriminator.t_domain_generated_data
                                )
            discriminator.saver.save(discriminator.sess, discriminator.saveto)
            logging.info("finetune the discrimiantor done!")

        logging.info('reinforcement training done!')
Example #5
def gan_train(config):
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    default_graph = tf.Graph()
    with default_graph.as_default():
        sess = tf.Session(config=sess_config, graph=default_graph)

        logger = logging.getLogger('')
        du = DataUtil(config=config)
        du.load_vocab(src_vocab=config.generator.src_vocab,
                      dst_vocab=config.generator.dst_vocab,
                      src_vocab_size=config.src_vocab_size,
                      dst_vocab_size=config.dst_vocab_size)

        generator = Model(config=config, graph=default_graph, sess=sess)
        generator.build_train_model()
        generator.build_generate(max_len=config.generator.max_length,
                                 generate_devices=config.generator.devices,
                                 optimizer=config.generator.optimizer)
        generator.build_rollout_generate(
            max_len=config.generator.max_length,
            roll_generate_devices=config.generator.devices)
        generator.init_and_restore(modelFile=config.generator.modelFile)

        # What is this variable?! The filter sizes?
        dis_filter_sizes = [
            i for i in range(1, config.discriminator.dis_max_len, 4)
        ]
        dis_num_filters = [
            (100 + i * 10)
            for i in range(1, config.discriminator.dis_max_len, 4)
        ]
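        # For example, with a hypothetical dis_max_len of 50 these comprehensions
        # yield dis_filter_sizes = [1, 5, 9, ..., 49] and
        # dis_num_filters = [110, 150, 190, ..., 590]: one convolution width
        # every 4 tokens, with wider filters getting more output channels.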

        discriminator = DisCNN(
            sess=sess,
            max_len=config.discriminator.dis_max_len,
            num_classes=2,
            vocab_size=config.dst_vocab_size,
            vocab_size_s=config.src_vocab_size,
            batch_size=config.discriminator.dis_batch_size,
            dim_word=config.discriminator.dis_dim_word,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            source_dict=config.discriminator.dis_src_vocab,
            target_dict=config.discriminator.dis_dst_vocab,
            gpu_device=config.discriminator.dis_gpu_devices,
            positive_data=config.discriminator.dis_positive_data,
            negative_data=config.discriminator.dis_negative_data,
            source_data=config.discriminator.dis_source_data,
            dev_positive_data=config.discriminator.dis_dev_positive_data,
            dev_negative_data=config.discriminator.dis_dev_negative_data,
            dev_source_data=config.discriminator.dis_dev_source_data,
            max_epoches=config.discriminator.dis_max_epoches,
            dispFreq=config.discriminator.dis_dispFreq,
            saveFreq=config.discriminator.dis_saveFreq,
            saveto=config.discriminator.dis_saveto,
            reload=config.discriminator.dis_reload,
            clip_c=config.discriminator.dis_clip_c,
            optimizer=config.discriminator.dis_optimizer,
            reshuffle=config.discriminator.dis_reshuffle,
            scope=config.discriminator.dis_scope)

        batch_train_iter = du.get_training_batches(
            set_train_src_path=config.generator.src_path,
            set_train_dst_path=config.generator.dst_path,
            set_batch_size=config.generator.batch_size,
            set_max_length=config.generator.max_length)

        max_SARI_results = 0.32  #!!
        max_BLEU_results = 0.77  #!!

        def evaluation_result(generator, config):
            nonlocal max_SARI_results
            nonlocal max_BLEU_results

            # start validation on the test dataset
            logging.info("Max_SARI_results: {}".format(max_SARI_results))
            logging.info("Max_BLEU_results: {}".format(max_BLEU_results))
            output_t = "prepare_data/test.8turkers.clean.out.gan"

            # beam search on the 8turkers test dataset
            evaluator = Evaluator(config=config, out_file=output_t)
            #logging.info("Evaluate on BLEU and SARI")
            SARI_results, BLEU_results = evaluator.translate()
            logging.info(" Current_SARI is {} \n Current_BLEU is {}".format(
                SARI_results, BLEU_results))

            if SARI_results >= max_SARI_results or BLEU_results >= max_BLEU_results:
                if SARI_results >= max_SARI_results:
                    max_SARI_results = SARI_results
                    logging.info("SARI Update Successfully !!!")
                if BLEU_results >= max_BLEU_results:
                    logging.info("BLEU Update Successfully !!!")
                    max_BLEU_results = BLEU_results
                return True
            else:
                return False

        for epoch in range(1, config.gan_iter_num + 1):  #10000
            for gen_iter in range(config.gan_gen_iter_num):  #1
                batch_train = next(batch_train_iter)
                x, y_ground = batch_train[0], batch_train[1]
                y_sample = generator.generate_step(x)
                logging.info("1. Policy Gradient Training !!!")
                y_sample_dealed, y_sample_mask = deal_generated_samples(
                    y_sample, du.dst2idx)  # pad the y_sample index matrix with 0s to the max length
                x_to_maxlen = extend_sentence_to_maxlen(
                    x, config.generator.max_length)  # pad the x index matrix with 0s to the max length
                x_str = du.indices_to_words(x, 'dst')
                ground_str = du.indices_to_words(y_ground, 'dst')
                sample_str = du.indices_to_words(y_sample, 'dst')

                # Rewards = D (Discriminator) + Q (BLEU scores)
                logging.info("2. Calculate the Reward !!!")
                rewards = generator.get_reward(
                    x=x,
                    x_to_maxlen=x_to_maxlen,
                    y_sample=y_sample_dealed,
                    y_sample_mask=y_sample_mask,
                    rollnum=config.rollnum,
                    disc=discriminator,
                    max_len=config.discriminator.dis_max_len,
                    bias_num=config.bias_num,
                    data_util=du)
                # update the Generator model with policy gradient
                logging.info("3. Update the Generator Model !!!")
                loss = generator.generate_step_and_update(
                    x, y_sample_dealed, rewards)
                #logging.info("The reward is ",rewards)
                #logging.info("The loss is ",loss)

                # update_or_not_update = evaluation_result(generator, config)
                # if update_or_not_update:
                # save the Generator model
                logging.info("4. Save the Generator model into %s" %
                             config.generator.modelFile)
                generator.saver.save(generator.sess,
                                     config.generator.modelFile)

                if config.generator.teacher_forcing:

                    logging.info("5. Doing the Teacher Forcing begin!")
                    y_ground, y_ground_mask = deal_generated_samples_to_maxlen(
                        y_sample=y_ground,
                        dicts=du.dst2idx,
                        maxlen=config.discriminator.dis_max_len)

                    rewards_ground = np.ones_like(y_ground)
                    rewards_ground = rewards_ground * y_ground_mask
                    loss = generator.generate_step_and_update(
                        x, y_ground, rewards_ground)
                    #logging.info("The teacher forcing reward is ", rewards_ground)
                    #logging.info("The teacher forcing loss is ", loss)

            logging.info("5. Evaluation SARI and BLEU")
            update_or_not_update = evaluation_result(generator, config)
            if update_or_not_update:
                # save the Generator model
                generator.saver.save(generator.sess,
                                     config.generator.modelFile)

            data_num = prepare_gan_dis_data(
                train_data_source=config.generator.src_path,
                train_data_target=config.generator.dst_path,
                gan_dis_source_data=config.discriminator.dis_source_data,
                gan_dis_positive_data=config.discriminator.dis_positive_data,
                num=config.generate_num,
                reshuf=True)

            logging.info("8.Generate  Negative Dataset for Discriminator !!!")
            # 生成negative数据集
            generator.generate_and_save(
                data_util=du,
                infile=config.discriminator.dis_source_data,
                generate_batch=config.discriminator.dis_batch_size,
                outfile=config.discriminator.dis_negative_data)

            logging.info("9.Negative Dataset was save in to %s." %
                         config.discriminator.dis_negative_data)
            logging.info("10.Finetuen the discriminator begin !!!!!")

            discriminator.train(
                max_epoch=config.gan_dis_iter_num,
                positive_data=config.discriminator.dis_positive_data,
                negative_data=config.discriminator.dis_negative_data,
                source_data=config.discriminator.dis_source_data)
            discriminator.saver.save(discriminator.sess, discriminator.saveto)
            logging.info("11.Finetune the discrimiantor done !!!!")

        logging.info('Reinforcement training done!')