Example #1
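# Example #1: bidirectional GAN training. The generator translates between two domains
# (directions 'ab' and 'ba') and is updated with reward-weighted (REINFORCE-style)
# steps; a three-class text_DisCNN discriminator is then fine-tuned each epoch on
# freshly generated data. The snippet assumes `tensorflow as tf`, `logging` and `os`
# are imported and that the surrounding project provides DataUtil, Model, text_DisCNN,
# deal_generated_samples, extend_sentence_to_maxlen and prepare_gan_dis_data.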
def gan_train(config):
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    default_graph=tf.Graph()
    with default_graph.as_default():
        sess = tf.Session(config=sess_config, graph=default_graph)

        logger = logging.getLogger('')
        du = DataUtil(config=config)
        du.load_vocab(src_vocab=config.generator.src_vocab,
                      dst_vocab=config.generator.dst_vocab,
                      src_vocab_size=config.src_vocab_size_a,
                      dst_vocab_size=config.src_vocab_size_b)

        generator = Model(config=config, graph=default_graph, sess=sess)
        generator.build_variational_train_model()

        generator.init_and_restore(modelFile=config.generator.modelFile)

        dis_filter_sizes = [i for i in range(1, config.discriminator.dis_max_len, 4)]
        dis_num_filters = [(100 + i * 10) for i in range(1, config.discriminator.dis_max_len, 4)]

        discriminator = text_DisCNN(
            sess=sess,
            max_len=config.discriminator.dis_max_len,
            num_classes=3,
            vocab_size_s=config.dst_vocab_size_a,
            vocab_size_t=config.dst_vocab_size_b,
            batch_size=config.discriminator.dis_batch_size,
            dim_word=config.discriminator.dis_dim_word,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            source_dict=config.discriminator.dis_src_vocab,
            target_dict=config.discriminator.dis_dst_vocab,
            gpu_device=config.discriminator.dis_gpu_devices,
            s_domain_data=config.discriminator.s_domain_data,
            t_domain_data=config.discriminator.t_domain_data,
            s_domain_generated_data=config.discriminator.s_domain_generated_data,
            t_domain_generated_data=config.discriminator.t_domain_generated_data,
            dev_s_domain_data=config.discriminator.dev_s_domain_data,
            dev_t_domain_data=config.discriminator.dev_t_domain_data,
            dev_s_domain_generated_data=config.discriminator.dev_s_domain_generated_data,
            dev_t_domain_generated_data=config.discriminator.dev_t_domain_generated_data,
            max_epoches=config.discriminator.dis_max_epoches,
            dispFreq=config.discriminator.dis_dispFreq,
            saveFreq=config.discriminator.dis_saveFreq,
            saveto=config.discriminator.dis_saveto,
            reload=config.discriminator.dis_reload,
            clip_c=config.discriminator.dis_clip_c,
            optimizer=config.discriminator.dis_optimizer,
            reshuffle=config.discriminator.dis_reshuffle,
            scope=config.discriminator.dis_scope
        )

        batch_iter = du.get_training_batches(
            set_train_src_path=config.generator.src_path,
            set_train_dst_path=config.generator.dst_path,
            set_batch_size=config.generator.batch_size,
            set_max_length=config.generator.max_length
        )
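        # Alternating schedule: each outer epoch runs gan_gen_iter_num reward-weighted
        # generator updates, then regenerates the discriminator's training data and
        # fine-tunes the discriminator on it.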

        for epoch in range(1, config.gan_iter_num + 1):
            for gen_iter in range(config.gan_gen_iter_num):
                batch = next(batch_iter)
                x, y = batch[0], batch[1]
                generate_ab, generate_ba = generator.generate_step(x, y)

                logging.info("generate the samples")
                generate_ab_dealed, generate_ab_mask = deal_generated_samples(generate_ab, du.dst2idx)
                generate_ba_dealed, generate_ba_mask = deal_generated_samples(generate_ba, du.src2idx)

                
                ## for debug
                #print('the generate_ba_dealed is ')
                #sample_str=du.indices_to_words(generate_ba_dealed, 'src')
                #print(sample_str)

                #print('the generate_ab_dealed is ')
                #sample_str=du.indices_to_words(generate_ab_dealed, 'dst')
                #print(sample_str)
                

                x_to_maxlen = extend_sentence_to_maxlen(x)
                y_to_maxlen = extend_sentence_to_maxlen(y)

                logging.info("calculate the reward")
                rewards_ab = generator.get_reward(x=x,
                                               x_to_maxlen=x_to_maxlen,
                                               y_sample=generate_ab_dealed,
                                               y_sample_mask=generate_ab_mask,
                                               rollnum=config.rollnum,
                                               disc=discriminator,
                                               max_len=config.discriminator.dis_max_len,
                                               bias_num=config.bias_num,
                                               data_util=du,
                                               direction='ab')

                rewards_ba = generator.get_reward(x=y,
                                               x_to_maxlen=y_to_maxlen,
                                               y_sample=generate_ba_dealed,
                                               y_sample_mask=generate_ba_mask,
                                               rollnum=config.rollnum,
                                               disc=discriminator,
                                               max_len=config.discriminator.dis_max_len,
                                               bias_num=config.bias_num,
                                               data_util=du,
                                               direction='ba')
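                # The two updates below are REINFORCE-style: the sampled outputs act as
                # targets and the rewards weight the generator update for each direction.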
                

                loss_ab = generator.generate_step_and_update(x, generate_ab_dealed, rewards_ab)

                loss_ba = generator.generate_step_and_update(y, generate_ba_dealed, rewards_ba)

                print("the reward for ab and ba is ", rewards_ab, rewards_ba)
                print("the loss is for ab and ba is", loss_ab, loss_ba)

                logging.info("save the model into %s" % config.generator.modelFile)
                generator.saver.save(generator.sess, config.generator.modelFile)


            # generator updates for this epoch are done; next, rebuild the discriminator data and fine-tune it

            logging.info("prepare the gan_dis_data begin")
            data_num = prepare_gan_dis_data(
                train_data_source=config.generator.src_path,
                train_data_target=config.generator.dst_path,
                gan_dis_source_data=config.discriminator.s_domain_data,
                gan_dis_positive_data=config.discriminator.t_domain_data,
                num=config.generate_num,
                reshuf=True
            )
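            # Each half of the freshly sampled real data is translated by the generator
            # ('ab' for the s-domain half, 'ba' for the t-domain half) to produce the
            # generated files the discriminator is fine-tuned against.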
            
            s_domain_data_half = config.discriminator.s_domain_data+'.half'
            t_domain_data_half = config.discriminator.t_domain_data+'.half'

            # integer division so head/tail get a whole line count; os.system blocks until the shell command finishes
            os.system('head -n ' + str(config.generate_num // 2) + ' ' + config.discriminator.s_domain_data + ' > ' + s_domain_data_half)
            os.system('tail -n ' + str(config.generate_num // 2) + ' ' + config.discriminator.t_domain_data + ' > ' + t_domain_data_half)
            
            logging.info("generate and the save t_domain_generated_data in to %s." %config.discriminator.s_domain_generated_data)

            generator.generate_and_save(data_util=du,
                                        infile=s_domain_data_half,
                                        generate_batch=config.discriminator.dis_batch_size,
                                        outfile=config.discriminator.t_domain_generated_data,
                                        direction='ab'
                                      )

            logging.info("generate and the save s_domain_generated_data in to %s." %config.discriminator.t_domain_generated_data)

            generator.generate_and_save(data_util=du,
                                        infile=t_domain_data_half,
                                        generate_batch=config.discriminator.dis_batch_size,
                                        outfile=config.discriminator.s_domain_generated_data,
                                        direction='ba'
                                      )
            
            logging.info("prepare %d gan_dis_data done!" %data_num)
            logging.info("finetuen the discriminator begin")

            discriminator.train(max_epoch=config.gan_dis_iter_num,
                                s_domain_data=config.discriminator.s_domain_data,
                                t_domain_data=config.discriminator.t_domain_data,
                                s_domain_generated_data=config.discriminator.s_domain_generated_data,
                                t_domain_generated_data=config.discriminator.t_domain_generated_data
                                )
            discriminator.saver.save(discriminator.sess, discriminator.saveto)
            logging.info("finetune the discrimiantor done!")

        logging.info('reinforcement training done!')
Example #2
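# Example #2: single-direction translation GAN. The generator samples translations,
# receives rewards from a two-class DisCNN discriminator, and is updated with
# reward-weighted (REINFORCE-style) steps, optionally mixed with a teacher-forcing
# step on the ground truth; the discriminator is then fine-tuned on human (positive)
# vs. generated (negative) data. Same assumptions as Example #1, plus `numpy as np`
# and the project's DisCNN and deal_generated_samples_to_maxlen helpers.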
def gan_train(config):
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    default_graph = tf.Graph()
    with default_graph.as_default():
        sess = tf.Session(config=sess_config, graph=default_graph)

        logger = logging.getLogger('')
        du = DataUtil(config=config)
        du.load_vocab(src_vocab=config.generator.src_vocab,
                      dst_vocab=config.generator.dst_vocab,
                      src_vocab_size=config.src_vocab_size,
                      dst_vocab_size=config.dst_vocab_size)

        generator = Model(config=config, graph=default_graph, sess=sess)
        generator.build_train_model()
        generator.build_generate(max_len=config.generator.max_length,
                                 generate_devices=config.generator.devices,
                                 optimizer=config.generator.optimizer)

        generator.build_rollout_generate(
            max_len=config.generator.max_length,
            roll_generate_devices=config.generator.devices)

        generator.init_and_restore(modelFile=config.generator.modelFile)

        dis_filter_sizes = [
            i for i in range(1, config.discriminator.dis_max_len, 4)
        ]
        dis_num_filters = [
            (100 + i * 10)
            for i in range(1, config.discriminator.dis_max_len, 4)
        ]

        discriminator = DisCNN(
            sess=sess,
            max_len=config.discriminator.dis_max_len,
            num_classes=2,
            vocab_size=config.dst_vocab_size,
            vocab_size_s=config.src_vocab_size,
            batch_size=config.discriminator.dis_batch_size,
            dim_word=config.discriminator.dis_dim_word,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            source_dict=config.discriminator.dis_src_vocab,
            target_dict=config.discriminator.dis_dst_vocab,
            gpu_device=config.discriminator.dis_gpu_devices,
            positive_data=config.discriminator.dis_positive_data,
            negative_data=config.discriminator.dis_negative_data,
            source_data=config.discriminator.dis_source_data,
            dev_positive_data=config.discriminator.dis_dev_positive_data,
            dev_negative_data=config.discriminator.dis_dev_negative_data,
            dev_source_data=config.discriminator.dis_dev_source_data,
            max_epoches=config.discriminator.dis_max_epoches,
            dispFreq=config.discriminator.dis_dispFreq,
            saveFreq=config.discriminator.dis_saveFreq,
            saveto=config.discriminator.dis_saveto,
            reload=config.discriminator.dis_reload,
            clip_c=config.discriminator.dis_clip_c,
            optimizer=config.discriminator.dis_optimizer,
            reshuffle=config.discriminator.dis_reshuffle,
            scope=config.discriminator.dis_scope)

        batch_iter = du.get_training_batches(
            set_train_src_path=config.generator.src_path,
            set_train_dst_path=config.generator.dst_path,
            set_batch_size=config.generator.batch_size,
            set_max_length=config.generator.max_length)

        for epoch in range(1, config.gan_iter_num + 1):
            for gen_iter in range(config.gan_gen_iter_num):
                batch = next(batch_iter)
                x, y_ground = batch[0], batch[1]
                y_sample = generator.generate_step(x)
                logging.info("generate the samples")
                y_sample_dealed, y_sample_mask = deal_generated_samples(
                    y_sample, du.dst2idx)
                #
                #### for debug
                ##print('the sample is ')
                ##sample_str=du.indices_to_words(y_sample_dealed, 'dst')
                ##print(sample_str)
                #
                x_to_maxlen = extend_sentence_to_maxlen(
                    x, config.generator.max_length)
                logging.info("calculate the reward")
                rewards = generator.get_reward(
                    x=x,
                    x_to_maxlen=x_to_maxlen,
                    y_sample=y_sample_dealed,
                    y_sample_mask=y_sample_mask,
                    rollnum=config.rollnum,
                    disc=discriminator,
                    max_len=config.discriminator.dis_max_len,
                    bias_num=config.bias_num,
                    data_util=du)

                loss = generator.generate_step_and_update(
                    x, y_sample_dealed, rewards)

                print("the reward is ", rewards)
                print("the loss is ", loss)

                logging.info("save the model into %s" %
                             config.generator.modelFile)
                generator.saver.save(generator.sess,
                                     config.generator.modelFile)

                if config.generator.teacher_forcing:

                    logging.info("doiong the teacher forcing begin!")
                    y_ground, y_ground_mask = deal_generated_samples_to_maxlen(
                        y_sample=y_ground,
                        dicts=du.dst2idx,
                        maxlen=config.discriminator.dis_max_len)

                    rewards_ground = np.ones_like(y_ground)
                    rewards_ground = rewards_ground * y_ground_mask
                    loss = generator.generate_step_and_update(
                        x, y_ground, rewards_ground)
                    print("the teacher forcing reward is ", rewards_ground)
                    print("the teacher forcing loss is ", loss)

            generator.saver.save(generator.sess, config.generator.modelFile)

            logging.info("prepare the gan_dis_data begin")
            data_num = prepare_gan_dis_data(
                train_data_source=config.generator.src_path,
                train_data_target=config.generator.dst_path,
                gan_dis_source_data=config.discriminator.dis_source_data,
                gan_dis_positive_data=config.discriminator.dis_positive_data,
                num=config.generate_num,
                reshuf=True)

            logging.info("generate and the save in to %s." %
                         config.discriminator.dis_negative_data)
            generator.generate_and_save(
                data_util=du,
                infile=config.discriminator.dis_source_data,
                generate_batch=config.discriminator.dis_batch_size,
                outfile=config.discriminator.dis_negative_data)

            logging.info("prepare %d gan_dis_data done!" % data_num)
            logging.info("finetuen the discriminator begin")

            discriminator.train(
                max_epoch=config.gan_dis_iter_num,
                positive_data=config.discriminator.dis_positive_data,
                negative_data=config.discriminator.dis_negative_data,
                source_data=config.discriminator.dis_source_data)
            discriminator.saver.save(discriminator.sess, discriminator.saveto)
            logging.info("finetune the discrimiantor done!")

        logging.info('reinforcement training done!')
Example #3
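# Example #3: the same GAN loop as Example #2, apparently applied to sentence
# simplification: after each epoch it evaluates SARI and BLEU on the 8-turkers test
# set and re-saves the generator whenever either score improves on the best so far.
# Same assumptions as Example #2, plus an Evaluator class from the surrounding project.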
def gan_train(config):
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    default_graph = tf.Graph()
    with default_graph.as_default():
        sess = tf.Session(config=sess_config, graph=default_graph)

        logger = logging.getLogger('')
        du = DataUtil(config=config)
        du.load_vocab(src_vocab=config.generator.src_vocab,
                      dst_vocab=config.generator.dst_vocab,
                      src_vocab_size=config.src_vocab_size,
                      dst_vocab_size=config.dst_vocab_size)

        generator = Model(config=config, graph=default_graph, sess=sess)
        generator.build_train_model()
        generator.build_generate(max_len=config.generator.max_length,
                                 generate_devices=config.generator.devices,
                                 optimizer=config.generator.optimizer)
        generator.build_rollout_generate(
            max_len=config.generator.max_length,
            roll_generate_devices=config.generator.devices)
        generator.init_and_restore(modelFile=config.generator.modelFile)

        # convolution filter widths for the discriminator text-CNN: 1, 5, 9, ... up to dis_max_len
        dis_filter_sizes = [
            i for i in range(1, config.discriminator.dis_max_len, 4)
        ]
        dis_num_filters = [
            (100 + i * 10)
            for i in range(1, config.discriminator.dis_max_len, 4)
        ]

        discriminator = DisCNN(
            sess=sess,
            max_len=config.discriminator.dis_max_len,
            num_classes=2,
            vocab_size=config.dst_vocab_size,
            vocab_size_s=config.src_vocab_size,
            batch_size=config.discriminator.dis_batch_size,
            dim_word=config.discriminator.dis_dim_word,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            source_dict=config.discriminator.dis_src_vocab,
            target_dict=config.discriminator.dis_dst_vocab,
            gpu_device=config.discriminator.dis_gpu_devices,
            positive_data=config.discriminator.dis_positive_data,
            negative_data=config.discriminator.dis_negative_data,
            source_data=config.discriminator.dis_source_data,
            dev_positive_data=config.discriminator.dis_dev_positive_data,
            dev_negative_data=config.discriminator.dis_dev_negative_data,
            dev_source_data=config.discriminator.dis_dev_source_data,
            max_epoches=config.discriminator.dis_max_epoches,
            dispFreq=config.discriminator.dis_dispFreq,
            saveFreq=config.discriminator.dis_saveFreq,
            saveto=config.discriminator.dis_saveto,
            reload=config.discriminator.dis_reload,
            clip_c=config.discriminator.dis_clip_c,
            optimizer=config.discriminator.dis_optimizer,
            reshuffle=config.discriminator.dis_reshuffle,
            scope=config.discriminator.dis_scope)

        batch_train_iter = du.get_training_batches(
            set_train_src_path=config.generator.src_path,
            set_train_dst_path=config.generator.dst_path,
            set_batch_size=config.generator.batch_size,
            set_max_length=config.generator.max_length)

        max_SARI_results = 0.32  # best SARI so far; the generator is re-saved after evaluation only when SARI or BLEU improves
        max_BLEU_results = 0.77  # best BLEU so far

        def evaluation_result(generator, config):
            nonlocal max_SARI_results
            nonlocal max_BLEU_results

            # evaluate on the test dataset
            logging.info("Max_SARI_results: {}".format(max_SARI_results))
            logging.info("Max_BLEU_results: {}".format(max_BLEU_results))
            output_t = "prepare_data/test.8turkers.clean.out.gan"

            # Beam Search 8.turkers dataset
            evaluator = Evaluator(config=config, out_file=output_t)
            #logging.info("Evaluate on BLEU and SARI")
            SARI_results, BLEU_results = evaluator.translate()
            logging.info(" Current_SARI is {} \n Current_BLEU is {}".format(
                SARI_results, BLEU_results))

            if SARI_results >= max_SARI_results or BLEU_results >= max_BLEU_results:
                if SARI_results >= max_SARI_results:
                    max_SARI_results = SARI_results
                    logging.info("SARI Update Successfully !!!")
                if BLEU_results >= max_BLEU_results:
                    logging.info("BLEU Update Successfully !!!")
                    max_BLEU_results = BLEU_results
                return True
            else:
                return False

        for epoch in range(1, config.gan_iter_num + 1):  #10000
            for gen_iter in range(config.gan_gen_iter_num):  #1
                batch_train = next(batch_train_iter)
                x, y_ground = batch_train[0], batch_train[1]
                y_sample = generator.generate_step(x)
                logging.info("1. Policy Gradient Training !!!")
                y_sample_dealed, y_sample_mask = deal_generated_samples(
                    y_sample, du.dst2idx)  # pad the sampled index matrix with zeros to a fixed length
                x_to_maxlen = extend_sentence_to_maxlen(
                    x, config.generator.max_length)  # pad the source index matrix with zeros to max_length
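                # decode source, reference and sample back to text (handy for inspecting
                # the samples; not used further below)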
                x_str = du.indices_to_words(x, 'dst')
                ground_str = du.indices_to_words(y_ground, 'dst')
                sample_str = du.indices_to_words(y_sample, 'dst')

                # Rewards = D(Discriminator) + Q(BLEU scores)
                logging.info("2. Calculate the Reward !!!")
                rewards = generator.get_reward(
                    x=x,
                    x_to_maxlen=x_to_maxlen,
                    y_sample=y_sample_dealed,
                    y_sample_mask=y_sample_mask,
                    rollnum=config.rollnum,
                    disc=discriminator,
                    max_len=config.discriminator.dis_max_len,
                    bias_num=config.bias_num,
                    data_util=du)
                # policy gradient: update the generator model
                logging.info("3. Update the Generator Model !!!")
                loss = generator.generate_step_and_update(
                    x, y_sample_dealed, rewards)
                #logging.info("The reward is ",rewards)
                #logging.info("The loss is ",loss)

                #update_or_not_update=evaluation_result(generator,config)
                #if update_or_not_update:
                # save the generator model
                logging.info("4. Save the Generator model into %s" %
                             config.generator.modelFile)
                generator.saver.save(generator.sess,
                                     config.generator.modelFile)

                if config.generator.teacher_forcing:

                    logging.info("5. Doing the Teacher Forcing begin!")
                    y_ground, y_ground_mask = deal_generated_samples_to_maxlen(
                        y_sample=y_ground,
                        dicts=du.dst2idx,
                        maxlen=config.discriminator.dis_max_len)

                    rewards_ground = np.ones_like(y_ground)
                    rewards_ground = rewards_ground * y_ground_mask
                    loss = generator.generate_step_and_update(
                        x, y_ground, rewards_ground)
                    #logging.info("The teacher forcing reward is ", rewards_ground)
                    #logging.info("The teacher forcing loss is ", loss)

            logging.info("5. Evaluation SARI and BLEU")
            update_or_not_update = evaluation_result(generator, config)
            if update_or_not_update:
                # save the generator model
                generator.saver.save(generator.sess,
                                     config.generator.modelFile)
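            # Rebuild the discriminator data for this epoch: fresh (source, reference)
            # pairs as positives, with negatives regenerated from the same sources below.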

            data_num = prepare_gan_dis_data(
                train_data_source=config.generator.src_path,
                train_data_target=config.generator.dst_path,
                gan_dis_source_data=config.discriminator.dis_source_data,
                gan_dis_positive_data=config.discriminator.dis_positive_data,
                num=config.generate_num,
                reshuf=True)

            logging.info("8.Generate  Negative Dataset for Discriminator !!!")
            # 生成negative数据集
            generator.generate_and_save(
                data_util=du,
                infile=config.discriminator.dis_source_data,
                generate_batch=config.discriminator.dis_batch_size,
                outfile=config.discriminator.dis_negative_data)

            logging.info("9.Negative Dataset was save in to %s." %
                         config.discriminator.dis_negative_data)
            logging.info("10.Finetuen the discriminator begin !!!!!")

            discriminator.train(
                max_epoch=config.gan_dis_iter_num,
                positive_data=config.discriminator.dis_positive_data,
                negative_data=config.discriminator.dis_negative_data,
                source_data=config.discriminator.dis_source_data)
            discriminator.saver.save(discriminator.sess, discriminator.saveto)
            logging.info("11.Finetune the discrimiantor done !!!!")

        logging.info('Reinforcement training done!')