Example #1
def gan_train(config):
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    default_graph = tf.Graph()
    with default_graph.as_default():
        sess = tf.Session(config=sess_config, graph=default_graph)

        logger = logging.getLogger('')
        du = DataUtil(config=config)
        du.load_vocab(src_vocab=config.generator.src_vocab,
                      dst_vocab=config.generator.dst_vocab,
                      src_vocab_size=config.src_vocab_size,
                      dst_vocab_size=config.dst_vocab_size)

        generator = Model(config=config, graph=default_graph, sess=sess)
        generator.build_train_model()
        generator.build_generate(max_len=config.generator.max_length,
                                 generate_devices=config.generator.devices,
                                 optimizer=config.generator.optimizer)

        generator.build_rollout_generate(
            max_len=config.generator.max_length,
            roll_generate_devices=config.generator.devices)

        generator.init_and_restore(modelFile=config.generator.modelFile)

        dis_filter_sizes = [
            i for i in range(1, config.discriminator.dis_max_len, 4)
        ]
        dis_num_filters = [
            (100 + i * 10)
            for i in range(1, config.discriminator.dis_max_len, 4)
        ]

        discriminator = DisCNN(
            sess=sess,
            max_len=config.discriminator.dis_max_len,
            num_classes=2,
            vocab_size=config.dst_vocab_size,
            vocab_size_s=config.src_vocab_size,
            batch_size=config.discriminator.dis_batch_size,
            dim_word=config.discriminator.dis_dim_word,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            source_dict=config.discriminator.dis_src_vocab,
            target_dict=config.discriminator.dis_dst_vocab,
            gpu_device=config.discriminator.dis_gpu_devices,
            positive_data=config.discriminator.dis_positive_data,
            negative_data=config.discriminator.dis_negative_data,
            source_data=config.discriminator.dis_source_data,
            dev_positive_data=config.discriminator.dis_dev_positive_data,
            dev_negative_data=config.discriminator.dis_dev_negative_data,
            dev_source_data=config.discriminator.dis_dev_source_data,
            max_epoches=config.discriminator.dis_max_epoches,
            dispFreq=config.discriminator.dis_dispFreq,
            saveFreq=config.discriminator.dis_saveFreq,
            saveto=config.discriminator.dis_saveto,
            reload=config.discriminator.dis_reload,
            clip_c=config.discriminator.dis_clip_c,
            optimizer=config.discriminator.dis_optimizer,
            reshuffle=config.discriminator.dis_reshuffle,
            scope=config.discriminator.dis_scope)

        batch_iter = du.get_training_batches(
            set_train_src_path=config.generator.src_path,
            set_train_dst_path=config.generator.dst_path,
            set_batch_size=config.generator.batch_size,
            set_max_length=config.generator.max_length)

        for epoch in range(1, config.gan_iter_num + 1):
            for gen_iter in range(config.gan_gen_iter_num):
                batch = next(batch_iter)
                x, y_ground = batch[0], batch[1]
                y_sample = generator.generate_step(x)
                logging.info("generate the samples")
                y_sample_dealed, y_sample_mask = deal_generated_samples(
                    y_sample, du.dst2idx)
                # for debugging:
                # print('the sample is ')
                # sample_str = du.indices_to_words(y_sample_dealed, 'dst')
                # print(sample_str)
                x_to_maxlen = extend_sentence_to_maxlen(
                    x, config.generator.max_length)
                logging.info("calculate the reward")
                rewards = generator.get_reward(
                    x=x,
                    x_to_maxlen=x_to_maxlen,
                    y_sample=y_sample_dealed,
                    y_sample_mask=y_sample_mask,
                    rollnum=config.rollnum,
                    disc=discriminator,
                    max_len=config.discriminator.dis_max_len,
                    bias_num=config.bias_num,
                    data_util=du)

                loss = generator.generate_step_and_update(
                    x, y_sample_dealed, rewards)

                print("the reward is ", rewards)
                print("the loss is ", loss)

                logging.info("save the model into %s" %
                             config.generator.modelFile)
                generator.saver.save(generator.sess,
                                     config.generator.modelFile)

                if config.generator.teacher_forcing:

                    logging.info("doiong the teacher forcing begin!")
                    y_ground, y_ground_mask = deal_generated_samples_to_maxlen(
                        y_sample=y_ground,
                        dicts=du.dst2idx,
                        maxlen=config.discriminator.dis_max_len)

                    rewards_ground = np.ones_like(y_ground)
                    rewards_ground = rewards_ground * y_ground_mask
                    loss = generator.generate_step_and_update(
                        x, y_ground, rewards_ground)
                    print("the teacher forcing reward is ", rewards_ground)
                    print("the teacher forcing loss is ", loss)

            generator.saver.save(generator.sess, config.generator.modelFile)

            logging.info("prepare the gan_dis_data begin")
            data_num = prepare_gan_dis_data(
                train_data_source=config.generator.src_path,
                train_data_target=config.generator.dst_path,
                gan_dis_source_data=config.discriminator.dis_source_data,
                gan_dis_positive_data=config.discriminator.dis_positive_data,
                num=config.generate_num,
                reshuf=True)

            logging.info("generate and the save in to %s." %
                         config.discriminator.dis_negative_data)
            generator.generate_and_save(
                data_util=du,
                infile=config.discriminator.dis_source_data,
                generate_batch=config.discriminator.dis_batch_size,
                outfile=config.discriminator.dis_negative_data)

            logging.info("prepare %d gan_dis_data done!" % data_num)
            logging.info("finetuen the discriminator begin")

            discriminator.train(
                max_epoch=config.gan_dis_iter_num,
                positive_data=config.discriminator.dis_positive_data,
                negative_data=config.discriminator.dis_negative_data,
                source_data=config.discriminator.dis_source_data)
            discriminator.saver.save(discriminator.sess, discriminator.saveto)
            logging.info("finetune the discrimiantor done!")

        logging.info('reinforcement training done!')
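
The helpers deal_generated_samples and extend_sentence_to_maxlen are called above but not defined in this listing; judging from their call sites (and from the padding comments in Example #4), they pad token-id sequences with zeros and return a matching mask. A minimal NumPy sketch of that assumed behaviour follows; the `_sketch` names and the '</S>' end-of-sentence symbol are assumptions, not part of the original code.

import numpy as np

def deal_generated_samples_sketch(samples, dst2idx, pad_id=0):
    """Pad sampled token-id sequences with zeros and build a 0/1 mask.

    `samples` is a batch of token-id sequences from the generator; the real
    helper may also cut everything after the end-of-sentence symbol, which
    is assumed to be '</S>' here.
    """
    eos_id = dst2idx.get('</S>', pad_id)
    seqs = [list(s) for s in samples]
    max_len = max(len(s) for s in seqs)
    padded = np.full((len(seqs), max_len), pad_id, dtype=np.int32)
    mask = np.zeros((len(seqs), max_len), dtype=np.float32)
    for i, seq in enumerate(seqs):
        if eos_id in seq:                      # keep tokens up to the first EOS
            seq = seq[:seq.index(eos_id) + 1]
        padded[i, :len(seq)] = seq
        mask[i, :len(seq)] = 1.0
    return padded, mask

def extend_sentence_to_maxlen_sketch(x, max_length, pad_id=0):
    """Right-pad a [batch, time] id matrix with zeros up to max_length."""
    x = np.asarray(x)
    batch, time = x.shape
    out = np.full((batch, max_length), pad_id, dtype=x.dtype)
    out[:, :min(time, max_length)] = x[:, :max_length]
    return out
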
Example #2
def main(argv):
    #####################################   create the session  ##################################################################

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 1.0
    config.allow_soft_placement = True

    is_gan_train = FLAGS.is_gan_train
    is_decode = FLAGS.is_decode
    is_generator_train = FLAGS.is_generator_train
    is_discriminator_train = FLAGS.is_discriminator_train

    ######################################  pretraining  the generator ######################################################################
    batch_size = FLAGS.batch_size
    source_dict = FLAGS.source_dict
    target_dict = FLAGS.target_dict
    train_data_source = FLAGS.train_data_source
    train_data_target = FLAGS.train_data_target
    n_words_src = FLAGS.n_words_src
    n_words_trg = FLAGS.n_words_trg
    gpu_device = FLAGS.gpu_device
    dim_word = FLAGS.dim_word
    dim = FLAGS.dim
    max_len = FLAGS.max_len
    optimizer = FLAGS.optimizer
    precision = FLAGS.precision
    clip_c = FLAGS.clip_c
    max_epoches = FLAGS.max_epoches
    reshuffle = FLAGS.reshuffle
    saveto = FLAGS.saveto
    saveFreq = FLAGS.saveFreq
    dispFreq = FLAGS.dispFreq
    sampleFreq = FLAGS.sampleFreq
    gen_reload = FLAGS.gen_reload

    gan_gen_batch_size = FLAGS.gan_gen_batch_size

    sess = tf.Session(config=config)
    with tf.variable_scope('generate'):
        generator = GenNmt(sess=sess,
                           batch_size=batch_size,
                           source_dict=source_dict,
                           target_dict=target_dict,
                           train_data_source=train_data_source,
                           train_data_target=train_data_target,
                           n_words_src=n_words_src,
                           n_words_trg=n_words_trg,
                           gpu_device=gpu_device,
                           dim_word=dim_word,
                           dim=dim,
                           max_len=max_len,
                           clip_c=clip_c,
                           max_epoches=max_epoches,
                           reshuffle=reshuffle,
                           saveto=saveto,
                           saveFreq=saveFreq,
                           dispFreq=dispFreq,
                           sampleFreq=sampleFreq,
                           optimizer=optimizer,
                           precision=precision,
                           gen_reload=gen_reload)

        if is_decode:
            decode_file = FLAGS.decode_file
            decode_result_file = FLAGS.decode_result_file
            decode_gpu = FLAGS.decode_gpu
            decode_is_print = FLAGS.decode_is_print
            #print('decoding the file %s on %s' % (decode_file, decode_gpu))
            generator.gen_sample(decode_file,
                                 decode_result_file,
                                 10,
                                 is_print=decode_is_print,
                                 gpu_device=decode_gpu)

            return 0

        elif is_generator_train:
            print('train the model and build the generate')
            generator.build_train_model()
            generator.gen_train()
            generator.build_generate(maxlen=max_len,
                                     generate_batch=gan_gen_batch_size,
                                     optimizer='adam')
            generator.rollout_generate(generate_batch=gan_gen_batch_size)
            print('done')

        else:
            print('build the generate without training')
            generator.build_train_model()
            generator.build_generate(maxlen=max_len,
                                     generate_batch=gan_gen_batch_size,
                                     optimizer='adam')
            generator.rollout_generate(generate_batch=gan_gen_batch_size)
            generator.init_and_reload()

            #print('building testing ')
            #generator.build_test()
            #print('done')

## #################################################### pretraining the discriminator ##################################################################

    if is_discriminator_train or is_gan_train:

        dis_max_epoches = FLAGS.dis_epoches
        dis_dispFreq = FLAGS.dis_dispFreq
        dis_saveFreq = FLAGS.dis_saveFreq
        dis_devFreq = FLAGS.dis_devFreq
        dis_batch_size = FLAGS.dis_batch_size
        dis_saveto = FLAGS.dis_saveto
        dis_reshuffle = FLAGS.dis_reshuffle
        dis_gpu_device = FLAGS.dis_gpu_device
        dis_max_len = FLAGS.dis_max_len
        positive_data = FLAGS.dis_positive_data
        negative_data = FLAGS.dis_negative_data
        dis_dev_positive_data = FLAGS.dis_dev_positive_data
        dis_dev_negative_data = FLAGS.dis_dev_negative_data
        dis_reload = FLAGS.dis_reload

        dis_filter_sizes = [i for i in range(1, dis_max_len, 4)]
        dis_num_filters = [(100 + i * 10) for i in range(1, dis_max_len, 4)]

        discriminator = DisCNN(sess=sess,
                               max_len=dis_max_len,
                               num_classes=2,
                               vocab_size=n_words_trg,
                               batch_size=dis_batch_size,
                               dim_word=dim_word,
                               filter_sizes=dis_filter_sizes,
                               num_filters=dis_num_filters,
                               source_dict=source_dict,
                               target_dict=target_dict,
                               gpu_device=dis_gpu_device,
                               positive_data=positive_data,
                               negative_data=negative_data,
                               dev_positive_data=dis_dev_positive_data,
                               dev_negative_data=dis_dev_negative_data,
                               max_epoches=dis_max_epoches,
                               dispFreq=dis_dispFreq,
                               saveFreq=dis_saveFreq,
                               devFreq=dis_devFreq,
                               saveto=dis_saveto,
                               reload=dis_reload,
                               clip_c=clip_c,
                               optimizer=optimizer,
                               reshuffle=dis_reshuffle,
                               scope='discnn')

        if is_discriminator_train:
            print('train the discriminator')
            discriminator.train()
            print('done')

        else:
            print('build the discriminator without training')
            print('done')

#   ####################################### Start Reinforcement Training #######################################################################################
        if is_gan_train:

            gan_total_iter_num = FLAGS.gan_total_iter_num
            gan_gen_iter_num = FLAGS.gan_gen_iter_num
            gan_dis_iter_num = FLAGS.gan_dis_iter_num

            gan_gen_reshuffle = FLAGS.gan_gen_reshuffle
            gan_gen_source_data = FLAGS.gan_gen_source_data

            gan_dis_source_data = FLAGS.gan_dis_source_data
            gan_dis_positive_data = FLAGS.gan_dis_positive_data
            gan_dis_negative_data = FLAGS.gan_dis_negative_data
            gan_dis_reshuffle = FLAGS.gan_dis_reshuffle
            gan_dis_batch_size = FLAGS.gan_dis_batch_size
            gan_dispFreq = FLAGS.gan_dispFreq
            gan_saveFreq = FLAGS.gan_saveFreq
            roll_num = FLAGS.rollnum
            generate_num = FLAGS.generate_num

            print('reinforcement training begin...')

            for gan_iter in range(gan_total_iter_num):

                print('reinforcement training for %d epoch' % gan_iter)
                gen_train_it = gen_train_iter(gan_gen_source_data,
                                              gan_gen_reshuffle,
                                              generator.dictionaries[0],
                                              n_words_src, gan_gen_batch_size,
                                              max_len)

                print('finetune the generator begin...')
                for gen_iter in range(gan_gen_iter_num):
                    x, _ = next(gen_train_it)
                    x, x_mask = prepare_multiple_sentence(x, maxlen=max_len)
                    y_sample_out = generator.generate_step(x, x_mask)
                    #print(y_sample_out)
                    #tmp_str=print_string('y', y_sample_out[0], generator.worddicts_r)
                    #tmp_str_2=print_string('y', y_sample_out[1], generator.worddicts_r)
                    #print tmp_str
                    #print tmp_str_2

                    y_input, y_input_mask = deal_generated_y_sentence(
                        y_sample_out, generator.worddicts, precision=precision)
                    rewards = generator.get_reward(x, x_mask, y_input,
                                                   y_input_mask, roll_num,
                                                   discriminator)
                    print('the reward is ', rewards)
                    loss = generator.generate_step_and_update(
                        x, x_mask, y_input, rewards)
                    if gen_iter % gan_dispFreq == 0:
                        print('the %d iter, seen %d examples, loss is %f ' %
                              (gen_iter,
                               ((gan_iter) * gan_gen_iter_num + gen_iter + 1),
                               loss))
                    if gen_iter % gan_saveFreq == 0:
                        generator.saver.save(generator.sess, generator.saveto)
                        print('save the parameters when seen %d examples ' %
                              (gan_iter * gan_gen_iter_num + gen_iter + 1))

                generator.saver.save(generator.sess, generator.saveto)
                print('finetune the generator done!')

                #print('self testing')
                #generator.self_test(gan_dis_source_data, gan_dis_negative_data)
                #print('self testing done!')

                print('prepare the gan_dis_data begin ')
                data_num = prepare_gan_dis_data(train_data_source,
                                                train_data_target,
                                                gan_dis_source_data,
                                                gan_dis_positive_data,
                                                num=generate_num,
                                                reshuf=True)
                print(
                    'prepare the gan_dis_data done, the num of the gan_dis_data is %d'
                    % data_num)

                print('generator generate and save to %s' %
                      gan_dis_negative_data)
                generator.generate_and_save(gan_dis_source_data,
                                            gan_dis_negative_data,
                                            generate_batch=gan_gen_batch_size)
                print('done!')

                print('finetune the discriminator begin...')
                discriminator.train(max_epoch=gan_dis_iter_num,
                                    positive_data=gan_dis_positive_data,
                                    negative_data=gan_dis_negative_data)
                print('finetune the discriminator done!')

            print('reinforcement training done')
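
generate_step_and_update performs the policy-gradient update in all of these listings, but its body is not shown. Below is a minimal NumPy sketch of the REINFORCE-style objective such an update is usually assumed to minimize; the function name, and the assumption that per-token log-probabilities, rewards, and a padding mask are available, are illustrative only.

import numpy as np

def policy_gradient_loss_sketch(log_probs, rewards, mask):
    """REINFORCE-style objective assumed to sit behind generate_step_and_update.

    log_probs: [batch, time] log p(y_t | y_<t, x) of the sampled tokens
    rewards:   [batch, time] per-token rewards from the discriminator rollouts
    mask:      [batch, time] 1.0 for real tokens, 0.0 for padding
    Minimizing this raises the probability of tokens that earned high rewards.
    """
    weighted = -log_probs * rewards * mask
    return weighted.sum() / max(mask.sum(), 1.0)

Note that the teacher-forcing branches in Examples #1 and #4 pass rewards_ground = np.ones_like(y_ground) * y_ground_mask, which turns this weighted objective back into ordinary maximum-likelihood training on the ground-truth target.
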
Example #3
def gan_train(config):
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    default_graph = tf.Graph()
    with default_graph.as_default():
        sess = tf.Session(config=sess_config, graph=default_graph)

        logger = logging.getLogger('')
        du = DataUtil(config=config)
        du.load_vocab(src_vocab=config.generator.src_vocab,
                      dst_vocab=config.generator.dst_vocab,
                      src_vocab_size=config.src_vocab_size_a,
                      dst_vocab_size=config.src_vocab_size_b)

        generator = Model(config=config, graph=default_graph, sess=sess)
        generator.build_variational_train_model()

        generator.init_and_restore(modelFile=config.generator.modelFile)

        dis_filter_sizes = [i for i in range(1, config.discriminator.dis_max_len, 4)]
        dis_num_filters = [(100 + i * 10) for i in range(1, config.discriminator.dis_max_len, 4)]

        discriminator = text_DisCNN(
            sess=sess,
            max_len=config.discriminator.dis_max_len,
            num_classes=3,
            vocab_size_s=config.dst_vocab_size_a,
            vocab_size_t=config.dst_vocab_size_b,
            batch_size=config.discriminator.dis_batch_size,
            dim_word=config.discriminator.dis_dim_word,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            source_dict=config.discriminator.dis_src_vocab,
            target_dict=config.discriminator.dis_dst_vocab,
            gpu_device=config.discriminator.dis_gpu_devices,
            s_domain_data=config.discriminator.s_domain_data,
            t_domain_data=config.discriminator.t_domain_data,
            s_domain_generated_data=config.discriminator.s_domain_generated_data,
            t_domain_generated_data=config.discriminator.t_domain_generated_data,
            dev_s_domain_data=config.discriminator.dev_s_domain_data,
            dev_t_domain_data=config.discriminator.dev_t_domain_data,
            dev_s_domain_generated_data=config.discriminator.dev_s_domain_generated_data,
            dev_t_domain_generated_data=config.discriminator.dev_t_domain_generated_data,
            max_epoches=config.discriminator.dis_max_epoches,
            dispFreq=config.discriminator.dis_dispFreq,
            saveFreq=config.discriminator.dis_saveFreq,
            saveto=config.discriminator.dis_saveto,
            reload=config.discriminator.dis_reload,
            clip_c=config.discriminator.dis_clip_c,
            optimizer=config.discriminator.dis_optimizer,
            reshuffle=config.discriminator.dis_reshuffle,
            scope=config.discriminator.dis_scope
        )

        batch_iter = du.get_training_batches(
            set_train_src_path=config.generator.src_path,
            set_train_dst_path=config.generator.dst_path,
            set_batch_size=config.generator.batch_size,
            set_max_length=config.generator.max_length
        )

        for epoch in range(1, config.gan_iter_num + 1):
            for gen_iter in range(config.gan_gen_iter_num):
                batch = next(batch_iter)
                x, y = batch[0], batch[1]
                generate_ab, generate_ba = generator.generate_step(x, y)

                logging.info("generate the samples")
                generate_ab_dealed, generate_ab_mask = deal_generated_samples(generate_ab, du.dst2idx)
                generate_ba_dealed, generate_ba_mask = deal_generated_samples(generate_ba, du.src2idx)

                
                ## for debug
                #print('the generate_ba_dealed is ')
                #sample_str=du.indices_to_words(generate_ba_dealed, 'src')
                #print(sample_str)

                #print('the generate_ab_dealed is ')
                #sample_str=du.indices_to_words(generate_ab_dealed, 'dst')
                #print(sample_str)
                

                x_to_maxlen = extend_sentence_to_maxlen(x)
                y_to_maxlen = extend_sentence_to_maxlen(y)

                logging.info("calculate the reward")
                rewards_ab = generator.get_reward(x=x,
                                               x_to_maxlen=x_to_maxlen,
                                               y_sample=generate_ab_dealed,
                                               y_sample_mask=generate_ab_mask,
                                               rollnum=config.rollnum,
                                               disc=discriminator,
                                               max_len=config.discriminator.dis_max_len,
                                               bias_num=config.bias_num,
                                               data_util=du,
                                               direction='ab')

                rewards_ba = generator.get_reward(x=y,
                                               x_to_maxlen=y_to_maxlen,
                                               y_sample=generate_ba_dealed,
                                               y_sample_mask=generate_ba_mask,
                                               rollnum=config.rollnum,
                                               disc=discriminator,
                                               max_len=config.discriminator.dis_max_len,
                                               bias_num=config.bias_num,
                                               data_util=du,
                                               direction='ba')
                

                loss_ab = generator.generate_step_and_update(x, generate_ab_dealed, rewards_ab)

                loss_ba = generator.generate_step_and_update(y, generate_ba_dealed, rewards_ba)

                print("the reward for ab and ba is ", rewards_ab, rewards_ba)
                print("the loss is for ab and ba is", loss_ab, loss_ba)

                logging.info("save the model into %s" % config.generator.modelFile)
                generator.saver.save(generator.sess, config.generator.modelFile)


            ####  modified up to here; the next part starts from here

            logging.info("prepare the gan_dis_data begin")
            data_num = prepare_gan_dis_data(
                train_data_source=config.generator.src_path,
                train_data_target=config.generator.dst_path,
                gan_dis_source_data=config.discriminator.s_domain_data,
                gan_dis_positive_data=config.discriminator.t_domain_data,
                num=config.generate_num,
                reshuf=True
            )
            
            s_domain_data_half = config.discriminator.s_domain_data+'.half'
            t_domain_data_half = config.discriminator.t_domain_data+'.half'

            os.popen('head -n ' + str(config.generate_num // 2) + ' ' +
                     config.discriminator.s_domain_data + ' > ' + s_domain_data_half)
            os.popen('tail -n ' + str(config.generate_num // 2) + ' ' +
                     config.discriminator.t_domain_data + ' > ' + t_domain_data_half)
            
            logging.info("generate and the save t_domain_generated_data in to %s." %config.discriminator.s_domain_generated_data)

            generator.generate_and_save(data_util=du,
                                        infile=s_domain_data_half,
                                        generate_batch=config.discriminator.dis_batch_size,
                                        outfile=config.discriminator.t_domain_generated_data,
                                        direction='ab'
                                      )

            logging.info("generate and the save s_domain_generated_data in to %s." %config.discriminator.t_domain_generated_data)

            generator.generate_and_save(data_util=du,
                                        infile=t_domain_data_half,
                                        generate_batch=config.discriminator.dis_batch_size,
                                        outfile=config.discriminator.s_domain_generated_data,
                                        direction='ba'
                                      )
            
            logging.info("prepare %d gan_dis_data done!" %data_num)
            logging.info("finetuen the discriminator begin")

            discriminator.train(max_epoch=config.gan_dis_iter_num,
                                s_domain_data=config.discriminator.s_domain_data,
                                t_domain_data=config.discriminator.t_domain_data,
                                s_domain_generated_data=config.discriminator.s_domain_generated_data,
                                t_domain_generated_data=config.discriminator.t_domain_generated_data
                                )
            discriminator.saver.save(discriminator.sess, discriminator.saveto)
            logging.info("finetune the discrimiantor done!")

        logging.info('reinforcement training done!')
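
The os.popen('head -n ...') and os.popen('tail -n ...') calls above shell out to split the domain data in half and never check the exit status. Below is a pure-Python sketch of the same split; the helper name is hypothetical and not part of the original code.

def copy_lines_sketch(src_path, dst_path, num, from_end=False):
    """Copy the first (or last) `num` lines of src_path into dst_path.

    A pure-Python stand-in for the `head -n` / `tail -n` pipelines above;
    it avoids os.popen and works regardless of whether the caller passes
    an int or a float for the line count.
    """
    num = int(num)
    with open(src_path) as f:
        lines = f.readlines()
    if from_end:
        selected = lines[-num:] if num > 0 else []
    else:
        selected = lines[:num]
    with open(dst_path, 'w') as f:
        f.writelines(selected)

# e.g., instead of the two os.popen calls:
# copy_lines_sketch(config.discriminator.s_domain_data, s_domain_data_half,
#                   config.generate_num // 2)
# copy_lines_sketch(config.discriminator.t_domain_data, t_domain_data_half,
#                   config.generate_num // 2, from_end=True)
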
Example #4
def gan_train(config):
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    default_graph = tf.Graph()
    with default_graph.as_default():
        sess = tf.Session(config=sess_config, graph=default_graph)

        logger = logging.getLogger('')
        du = DataUtil(config=config)
        du.load_vocab(src_vocab=config.generator.src_vocab,
                      dst_vocab=config.generator.dst_vocab,
                      src_vocab_size=config.src_vocab_size,
                      dst_vocab_size=config.dst_vocab_size)

        generator = Model(config=config, graph=default_graph, sess=sess)
        generator.build_train_model()
        generator.build_generate(max_len=config.generator.max_length,
                                 generate_devices=config.generator.devices,
                                 optimizer=config.generator.optimizer)
        generator.build_rollout_generate(
            max_len=config.generator.max_length,
            roll_generate_devices=config.generator.devices)
        generator.init_and_restore(modelFile=config.generator.modelFile)

        # What is this variable?! filters?
        dis_filter_sizes = [
            i for i in range(1, config.discriminator.dis_max_len, 4)
        ]
        dis_num_filters = [
            (100 + i * 10)
            for i in range(1, config.discriminator.dis_max_len, 4)
        ]

        discriminator = DisCNN(
            sess=sess,
            max_len=config.discriminator.dis_max_len,
            num_classes=2,
            vocab_size=config.dst_vocab_size,
            vocab_size_s=config.src_vocab_size,
            batch_size=config.discriminator.dis_batch_size,
            dim_word=config.discriminator.dis_dim_word,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            source_dict=config.discriminator.dis_src_vocab,
            target_dict=config.discriminator.dis_dst_vocab,
            gpu_device=config.discriminator.dis_gpu_devices,
            positive_data=config.discriminator.dis_positive_data,
            negative_data=config.discriminator.dis_negative_data,
            source_data=config.discriminator.dis_source_data,
            dev_positive_data=config.discriminator.dis_dev_positive_data,
            dev_negative_data=config.discriminator.dis_dev_negative_data,
            dev_source_data=config.discriminator.dis_dev_source_data,
            max_epoches=config.discriminator.dis_max_epoches,
            dispFreq=config.discriminator.dis_dispFreq,
            saveFreq=config.discriminator.dis_saveFreq,
            saveto=config.discriminator.dis_saveto,
            reload=config.discriminator.dis_reload,
            clip_c=config.discriminator.dis_clip_c,
            optimizer=config.discriminator.dis_optimizer,
            reshuffle=config.discriminator.dis_reshuffle,
            scope=config.discriminator.dis_scope)

        batch_train_iter = du.get_training_batches(
            set_train_src_path=config.generator.src_path,
            set_train_dst_path=config.generator.dst_path,
            set_batch_size=config.generator.batch_size,
            set_max_length=config.generator.max_length)

        max_SARI_results = 0.32  #!!
        max_BLEU_results = 0.77  #!!

        def evaluation_result(generator, config):
            nonlocal max_SARI_results
            nonlocal max_BLEU_results

            # start validation on the test dataset
            logging.info("Max_SARI_results: {}".format(max_SARI_results))
            logging.info("Max_BLEU_results: {}".format(max_BLEU_results))
            output_t = "prepare_data/test.8turkers.clean.out.gan"

            # Beam Search 8.turkers dataset
            evaluator = Evaluator(config=config, out_file=output_t)
            #logging.info("Evaluate on BLEU and SARI")
            SARI_results, BLEU_results = evaluator.translate()
            logging.info(" Current_SARI is {} \n Current_BLEU is {}".format(
                SARI_results, BLEU_results))

            if SARI_results >= max_SARI_results or BLEU_results >= max_BLEU_results:
                if SARI_results >= max_SARI_results:
                    max_SARI_results = SARI_results
                    logging.info("SARI Update Successfully !!!")
                if BLEU_results >= max_BLEU_results:
                    logging.info("BLEU Update Successfully !!!")
                    max_BLEU_results = BLEU_results
                return True
            else:
                return False

        for epoch in range(1, config.gan_iter_num + 1):  #10000
            for gen_iter in range(config.gan_gen_iter_num):  #1
                batch_train = next(batch_train_iter)
                x, y_ground = batch_train[0], batch_train[1]
                y_sample = generator.generate_step(x)
                logging.info("1. Policy Gradient Training !!!")
                y_sample_dealed, y_sample_mask = deal_generated_samples(
                    y_sample, du.dst2idx)  # pad the y_sample index matrix with zeros
                x_to_maxlen = extend_sentence_to_maxlen(
                    x, config.generator.max_length)  # pad the x index matrix with zeros
                x_str = du.indices_to_words(x, 'dst')
                ground_str = du.indices_to_words(y_ground, 'dst')
                sample_str = du.indices_to_words(y_sample, 'dst')

                # Rewards = D (discriminator score) + Q (BLEU scores)
                logging.info("2. Calculate the Reward !!!")
                rewards = generator.get_reward(
                    x=x,
                    x_to_maxlen=x_to_maxlen,
                    y_sample=y_sample_dealed,
                    y_sample_mask=y_sample_mask,
                    rollnum=config.rollnum,
                    disc=discriminator,
                    max_len=config.discriminator.dis_max_len,
                    bias_num=config.bias_num,
                    data_util=du)
                # update the generator model via policy gradient
                logging.info("3. Update the Generator Model !!!")
                loss = generator.generate_step_and_update(
                    x, y_sample_dealed, rewards)
                #logging.info("The reward is ",rewards)
                #logging.info("The loss is ",loss)

                #update_or_not_update=evaluation_result(generator,config)
                #if update_or_not_update:
                # save the generator model
                logging.info("4. Save the Generator model into %s" %
                             config.generator.modelFile)
                generator.saver.save(generator.sess,
                                     config.generator.modelFile)

                if config.generator.teacher_forcing:

                    logging.info("5. Doing the Teacher Forcing begin!")
                    y_ground, y_ground_mask = deal_generated_samples_to_maxlen(
                        y_sample=y_ground,
                        dicts=du.dst2idx,
                        maxlen=config.discriminator.dis_max_len)

                    rewards_ground = np.ones_like(y_ground)
                    rewards_ground = rewards_ground * y_ground_mask
                    loss = generator.generate_step_and_update(
                        x, y_ground, rewards_ground)
                    #logging.info("The teacher forcing reward is ", rewards_ground)
                    #logging.info("The teacher forcing loss is ", loss)

            logging.info("5. Evaluation SARI and BLEU")
            update_or_not_update = evaluation_result(generator, config)
            if update_or_not_update:
                # save the generator model
                generator.saver.save(generator.sess,
                                     config.generator.modelFile)

            data_num = prepare_gan_dis_data(
                train_data_source=config.generator.src_path,
                train_data_target=config.generator.dst_path,
                gan_dis_source_data=config.discriminator.dis_source_data,
                gan_dis_positive_data=config.discriminator.dis_positive_data,
                num=config.generate_num,
                reshuf=True)

            logging.info("8.Generate  Negative Dataset for Discriminator !!!")
            # 生成negative数据集
            generator.generate_and_save(
                data_util=du,
                infile=config.discriminator.dis_source_data,
                generate_batch=config.discriminator.dis_batch_size,
                outfile=config.discriminator.dis_negative_data)

            logging.info("9.Negative Dataset was save in to %s." %
                         config.discriminator.dis_negative_data)
            logging.info("10.Finetuen the discriminator begin !!!!!")

            discriminator.train(
                max_epoch=config.gan_dis_iter_num,
                positive_data=config.discriminator.dis_positive_data,
                negative_data=config.discriminator.dis_negative_data,
                source_data=config.discriminator.dis_source_data)
            discriminator.saver.save(discriminator.sess, discriminator.saveto)
            logging.info("11.Finetune the discrimiantor done !!!!")

        logging.info('Reinforcement training done!')
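
generator.get_reward(..., rollnum=config.rollnum, disc=discriminator, ...) is not defined in these listings. Given the rollnum argument and the build_rollout_generate graph, it presumably computes SeqGAN-style Monte Carlo rollout rewards, with bias_num acting as a baseline subtracted from the discriminator score. A minimal sketch of that assumed computation follows, with hypothetical rollout_fn and disc_score_fn hooks standing in for the rollout generator and the DisCNN scorer.

import numpy as np

def get_reward_sketch(x, y_sample, rollnum, rollout_fn, disc_score_fn):
    """Monte Carlo rollout rewards in the style of SeqGAN.

    For every prefix y_<=t of the sampled target, the generator completes the
    sentence rollnum times; the discriminator's probability that each
    completion is human-written is averaged and used as the reward for step t.
    rollout_fn(x, prefix) and disc_score_fn(x, y) are hypothetical hooks.
    """
    batch, max_len = y_sample.shape
    rewards = np.zeros((batch, max_len), dtype=np.float32)
    for t in range(1, max_len):
        acc = np.zeros(batch, dtype=np.float32)
        for _ in range(rollnum):
            completed = rollout_fn(x, y_sample[:, :t])   # finish the sentence
            acc += disc_score_fn(x, completed)           # P(human-written) per sentence
        rewards[:, t - 1] = acc / rollnum
    rewards[:, max_len - 1] = disc_score_fn(x, y_sample)  # full sample: no rollout needed
    return rewards
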