def gan_train(config):
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    default_graph = tf.Graph()
    with default_graph.as_default():
        sess = tf.Session(config=sess_config, graph=default_graph)

        logger = logging.getLogger('')
        du = DataUtil(config=config)
        du.load_vocab(src_vocab=config.generator.src_vocab,
                      dst_vocab=config.generator.dst_vocab,
                      src_vocab_size=config.src_vocab_size,
                      dst_vocab_size=config.dst_vocab_size)

        generator = Model(config=config, graph=default_graph, sess=sess)
        generator.build_train_model()
        generator.build_generate(max_len=config.generator.max_length,
                                 generate_devices=config.generator.devices,
                                 optimizer=config.generator.optimizer)
        generator.build_rollout_generate(max_len=config.generator.max_length,
                                         roll_generate_devices=config.generator.devices)
        generator.init_and_restore(modelFile=config.generator.modelFile)

        # CNN filter widths and per-width filter counts for the discriminator.
        dis_filter_sizes = [i for i in range(1, config.discriminator.dis_max_len, 4)]
        dis_num_filters = [(100 + i * 10) for i in range(1, config.discriminator.dis_max_len, 4)]

        discriminator = DisCNN(
            sess=sess,
            max_len=config.discriminator.dis_max_len,
            num_classes=2,
            vocab_size=config.dst_vocab_size,
            vocab_size_s=config.src_vocab_size,
            batch_size=config.discriminator.dis_batch_size,
            dim_word=config.discriminator.dis_dim_word,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            source_dict=config.discriminator.dis_src_vocab,
            target_dict=config.discriminator.dis_dst_vocab,
            gpu_device=config.discriminator.dis_gpu_devices,
            positive_data=config.discriminator.dis_positive_data,
            negative_data=config.discriminator.dis_negative_data,
            source_data=config.discriminator.dis_source_data,
            dev_positive_data=config.discriminator.dis_dev_positive_data,
            dev_negative_data=config.discriminator.dis_dev_negative_data,
            dev_source_data=config.discriminator.dis_dev_source_data,
            max_epoches=config.discriminator.dis_max_epoches,
            dispFreq=config.discriminator.dis_dispFreq,
            saveFreq=config.discriminator.dis_saveFreq,
            saveto=config.discriminator.dis_saveto,
            reload=config.discriminator.dis_reload,
            clip_c=config.discriminator.dis_clip_c,
            optimizer=config.discriminator.dis_optimizer,
            reshuffle=config.discriminator.dis_reshuffle,
            scope=config.discriminator.dis_scope)

        batch_iter = du.get_training_batches(
            set_train_src_path=config.generator.src_path,
            set_train_dst_path=config.generator.dst_path,
            set_batch_size=config.generator.batch_size,
            set_max_length=config.generator.max_length)

        for epoch in range(1, config.gan_iter_num + 1):
            for gen_iter in range(config.gan_gen_iter_num):
                batch = next(batch_iter)
                x, y_ground = batch[0], batch[1]
                y_sample = generator.generate_step(x)
                logging.info("generate the samples")
                y_sample_dealed, y_sample_mask = deal_generated_samples(y_sample, du.dst2idx)

                # for debug:
                # print('the sample is ')
                # sample_str = du.indices_to_words(y_sample_dealed, 'dst')
                # print(sample_str)

                x_to_maxlen = extend_sentence_to_maxlen(x, config.generator.max_length)

                logging.info("calculate the reward")
                rewards = generator.get_reward(
                    x=x,
                    x_to_maxlen=x_to_maxlen,
                    y_sample=y_sample_dealed,
                    y_sample_mask=y_sample_mask,
                    rollnum=config.rollnum,
                    disc=discriminator,
                    max_len=config.discriminator.dis_max_len,
                    bias_num=config.bias_num,
                    data_util=du)

                loss = generator.generate_step_and_update(x, y_sample_dealed, rewards)
                print("the reward is ", rewards)
                print("the loss is ", loss)

                logging.info("save the model into %s" % config.generator.modelFile)
                generator.saver.save(generator.sess, config.generator.modelFile)

                if config.generator.teacher_forcing:
                    logging.info("begin the teacher forcing step!")
                    y_ground, y_ground_mask = deal_generated_samples_to_maxlen(
                        y_sample=y_ground,
                        dicts=du.dst2idx,
                        maxlen=config.discriminator.dis_max_len)
                    # Teacher forcing: the ground-truth target gets a uniform
                    # reward of 1 on every non-padding token.
                    rewards_ground = np.ones_like(y_ground) * y_ground_mask
                    loss = generator.generate_step_and_update(x, y_ground, rewards_ground)
                    print("the teacher forcing reward is ", rewards_ground)
                    print("the teacher forcing loss is ", loss)
                    generator.saver.save(generator.sess, config.generator.modelFile)

            logging.info("prepare the gan_dis_data begin")
            data_num = prepare_gan_dis_data(
                train_data_source=config.generator.src_path,
                train_data_target=config.generator.dst_path,
                gan_dis_source_data=config.discriminator.dis_source_data,
                gan_dis_positive_data=config.discriminator.dis_positive_data,
                num=config.generate_num,
                reshuf=True)

            logging.info("generate and save into %s." % config.discriminator.dis_negative_data)
            generator.generate_and_save(
                data_util=du,
                infile=config.discriminator.dis_source_data,
                generate_batch=config.discriminator.dis_batch_size,
                outfile=config.discriminator.dis_negative_data)
            logging.info("prepare %d gan_dis_data done!" % data_num)

            logging.info("finetune the discriminator begin")
            discriminator.train(
                max_epoch=config.gan_dis_iter_num,
                positive_data=config.discriminator.dis_positive_data,
                negative_data=config.discriminator.dis_negative_data,
                source_data=config.discriminator.dis_source_data)
            discriminator.saver.save(discriminator.sess, discriminator.saveto)
            logging.info("finetune the discriminator done!")

        logging.info('reinforcement training done!')
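
# ---------------------------------------------------------------------------
# Aside: deal_generated_samples and extend_sentence_to_maxlen are helpers
# imported from elsewhere in the repo. The loop above relies only on their
# padding/masking behavior, sketched below under the assumption that index 0
# is the padding symbol and '</S>' the end-of-sentence marker (assumptions,
# not the repo's verified API):

import numpy as np

def extend_sentence_to_maxlen_sketch(x, max_len):
    """Zero-pad a [batch, time] index matrix out to [batch, max_len]."""
    padded = np.zeros((x.shape[0], max_len), dtype=x.dtype)
    length = min(x.shape[1], max_len)
    padded[:, :length] = x[:, :length]
    return padded

def deal_generated_samples_sketch(y_sample, dst2idx):
    """Cut each sampled sentence at its first end-of-sentence symbol,
    zero-pad the batch to its longest sample, and return (padded, mask)."""
    eos_id = dst2idx.get('</S>', 0)  # assumption about the EOS token
    max_len = max(len(sent) for sent in y_sample)
    padded = np.zeros((len(y_sample), max_len), dtype=np.int32)
    mask = np.zeros((len(y_sample), max_len), dtype=np.float32)
    for i, sent in enumerate(y_sample):
        for j, idx in enumerate(sent):
            padded[i, j] = idx
            mask[i, j] = 1.0
            if idx == eos_id:
                break
    return padded, mask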
def main(argv):
    # ------------------------------ create the session ------------------------------
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 1.0
    config.allow_soft_placement = True

    is_gan_train = FLAGS.is_gan_train
    is_decode = FLAGS.is_decode
    is_generator_train = FLAGS.is_generator_train
    is_discriminator_train = FLAGS.is_discriminator_train

    # ------------------------------ pretraining the generator ------------------------------
    batch_size = FLAGS.batch_size
    source_dict = FLAGS.source_dict
    target_dict = FLAGS.target_dict
    train_data_source = FLAGS.train_data_source
    train_data_target = FLAGS.train_data_target
    n_words_src = FLAGS.n_words_src
    n_words_trg = FLAGS.n_words_trg
    gpu_device = FLAGS.gpu_device
    dim_word = FLAGS.dim_word
    dim = FLAGS.dim
    max_len = FLAGS.max_len
    optimizer = FLAGS.optimizer
    precision = FLAGS.precision
    clip_c = FLAGS.clip_c
    max_epoches = FLAGS.max_epoches
    reshuffle = FLAGS.reshuffle
    saveto = FLAGS.saveto
    saveFreq = FLAGS.saveFreq
    dispFreq = FLAGS.dispFreq
    sampleFreq = FLAGS.sampleFreq
    gen_reload = FLAGS.gen_reload
    gan_gen_batch_size = FLAGS.gan_gen_batch_size

    sess = tf.Session(config=config)
    with tf.variable_scope('generate'):
        generator = GenNmt(sess=sess,
                           batch_size=batch_size,
                           source_dict=source_dict,
                           target_dict=target_dict,
                           train_data_source=train_data_source,
                           train_data_target=train_data_target,
                           n_words_src=n_words_src,
                           n_words_trg=n_words_trg,
                           gpu_device=gpu_device,
                           dim_word=dim_word,
                           dim=dim,
                           max_len=max_len,
                           clip_c=clip_c,
                           max_epoches=max_epoches,
                           reshuffle=reshuffle,
                           saveto=saveto,
                           saveFreq=saveFreq,
                           dispFreq=dispFreq,
                           sampleFreq=sampleFreq,
                           optimizer=optimizer,
                           precision=precision,
                           gen_reload=gen_reload)

    if is_decode:
        decode_file = FLAGS.decode_file
        decode_result_file = FLAGS.decode_result_file
        decode_gpu = FLAGS.decode_gpu
        decode_is_print = FLAGS.decode_is_print
        # print('decoding the file %s on %s' % (decode_file, decode_gpu))
        generator.gen_sample(decode_file, decode_result_file, 10,
                             is_print=decode_is_print, gpu_device=decode_gpu)
        return 0
    elif is_generator_train:
        print('train the model and build the generate')
        generator.build_train_model()
        generator.gen_train()
        generator.build_generate(maxlen=max_len,
                                 generate_batch=gan_gen_batch_size,
                                 optimizer='adam')
        generator.rollout_generate(generate_batch=gan_gen_batch_size)
        print('done')
    else:
        print('build the generate without training')
        generator.build_train_model()
        generator.build_generate(maxlen=max_len,
                                 generate_batch=gan_gen_batch_size,
                                 optimizer='adam')
        generator.rollout_generate(generate_batch=gan_gen_batch_size)
        generator.init_and_reload()
        # print('building testing')
        # generator.build_test()
        # print('done')

    # ------------------------------ pretraining the discriminator ------------------------------
    if is_discriminator_train or is_gan_train:
        dis_max_epoches = FLAGS.dis_epoches
        dis_dispFreq = FLAGS.dis_dispFreq
        dis_saveFreq = FLAGS.dis_saveFreq
        dis_devFreq = FLAGS.dis_devFreq
        dis_batch_size = FLAGS.dis_batch_size
        dis_saveto = FLAGS.dis_saveto
        dis_reshuffle = FLAGS.dis_reshuffle
        dis_gpu_device = FLAGS.dis_gpu_device
        dis_max_len = FLAGS.dis_max_len
        positive_data = FLAGS.dis_positive_data
        negative_data = FLAGS.dis_negative_data
        dis_dev_positive_data = FLAGS.dis_dev_positive_data
        dis_dev_negative_data = FLAGS.dis_dev_negative_data
        dis_reload = FLAGS.dis_reload
        dis_filter_sizes = [i for i in range(1, dis_max_len, 4)]
        dis_num_filters = [(100 + i * 10) for i in range(1, dis_max_len, 4)]

        discriminator = DisCNN(sess=sess,
                               max_len=dis_max_len,
                               num_classes=2,
                               vocab_size=n_words_trg,
                               batch_size=dis_batch_size,
                               dim_word=dim_word,
                               filter_sizes=dis_filter_sizes,
                               num_filters=dis_num_filters,
                               source_dict=source_dict,
                               target_dict=target_dict,
                               gpu_device=dis_gpu_device,
                               positive_data=positive_data,
                               negative_data=negative_data,
                               dev_positive_data=dis_dev_positive_data,
                               dev_negative_data=dis_dev_negative_data,
                               max_epoches=dis_max_epoches,
                               dispFreq=dis_dispFreq,
                               saveFreq=dis_saveFreq,
                               devFreq=dis_devFreq,
                               saveto=dis_saveto,
                               reload=dis_reload,
                               clip_c=clip_c,
                               optimizer=optimizer,
                               reshuffle=dis_reshuffle,
                               scope='discnn')

        if is_discriminator_train:
            print('train the discriminator')
            discriminator.train()
            print('done')
        else:
            print('build the discriminator without training')
            print('done')

    # ------------------------------ start reinforcement training ------------------------------
    if is_gan_train:
        gan_total_iter_num = FLAGS.gan_total_iter_num
        gan_gen_iter_num = FLAGS.gan_gen_iter_num
        gan_dis_iter_num = FLAGS.gan_dis_iter_num
        gan_gen_reshuffle = FLAGS.gan_gen_reshuffle
        gan_gen_source_data = FLAGS.gan_gen_source_data
        gan_dis_source_data = FLAGS.gan_dis_source_data
        gan_dis_positive_data = FLAGS.gan_dis_positive_data
        gan_dis_negative_data = FLAGS.gan_dis_negative_data
        gan_dis_reshuffle = FLAGS.gan_dis_reshuffle
        gan_dis_batch_size = FLAGS.gan_dis_batch_size
        gan_dispFreq = FLAGS.gan_dispFreq
        gan_saveFreq = FLAGS.gan_saveFreq
        roll_num = FLAGS.rollnum
        generate_num = FLAGS.generate_num

        print('reinforcement training begin...')
        for gan_iter in range(gan_total_iter_num):
            print('reinforcement training for %d epoch' % gan_iter)
            gen_train_it = gen_train_iter(gan_gen_source_data, gan_gen_reshuffle,
                                          generator.dictionaries[0], n_words_src,
                                          gan_gen_batch_size, max_len)

            print('finetune the generator begin...')
            for gen_iter in range(gan_gen_iter_num):
                x, _ = next(gen_train_it)
                x, x_mask = prepare_multiple_sentence(x, maxlen=max_len)
                y_sample_out = generator.generate_step(x, x_mask)

                # for debug:
                # print(y_sample_out)
                # print(print_string('y', y_sample_out[0], generator.worddicts_r))
                # print(print_string('y', y_sample_out[1], generator.worddicts_r))

                y_input, y_input_mask = deal_generated_y_sentence(
                    y_sample_out, generator.worddicts, precision=precision)
                rewards = generator.get_reward(x, x_mask, y_input, y_input_mask,
                                               roll_num, discriminator)
                print('the reward is ', rewards)
                loss = generator.generate_step_and_update(x, x_mask, y_input, rewards)

                if gen_iter % gan_dispFreq == 0:
                    print('the %d iter, seen %d examples, loss is %f'
                          % (gen_iter, gan_iter * gan_gen_iter_num + gen_iter + 1, loss))
                if gen_iter % gan_saveFreq == 0:
                    generator.saver.save(generator.sess, generator.saveto)
                    print('save the parameters when seen %d examples'
                          % (gan_iter * gan_gen_iter_num + gen_iter + 1))

            generator.saver.save(generator.sess, generator.saveto)
            print('finetune the generator done!')

            # print('self testing')
            # generator.self_test(gan_dis_source_data, gan_dis_negative_data)
            # print('self testing done!')

            print('prepare the gan_dis_data begin')
            data_num = prepare_gan_dis_data(train_data_source, train_data_target,
                                            gan_dis_source_data, gan_dis_positive_data,
                                            num=generate_num, reshuf=True)
            print('prepare the gan_dis_data done, the num of the gan_dis_data is %d' % data_num)

            print('generator generate and save to %s' % gan_dis_negative_data)
            generator.generate_and_save(gan_dis_source_data, gan_dis_negative_data,
                                        generate_batch=gan_gen_batch_size)
            print('done!')

            print('finetune the discriminator begin...')
            discriminator.train(max_epoch=gan_dis_iter_num,
                                positive_data=gan_dis_positive_data,
                                negative_data=gan_dis_negative_data)
            print('finetune the discriminator done!')

        print('reinforcement training done')
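
# ---------------------------------------------------------------------------
# Aside: generate_step_and_update is the REINFORCE-style update at the heart
# of these loops. Its internals live in the model classes; the loss it
# minimizes is plausibly the reward-weighted negative log-likelihood of the
# sampled tokens, as in SeqGAN. A minimal TF1 sketch with hypothetical
# placeholder shapes (the vocabulary size of 30000 is an arbitrary example,
# not the repo's setting):

import tensorflow as tf

def _policy_gradient_loss_sketch(vocab_size=30000):
    logits = tf.placeholder(tf.float32, [None, None, vocab_size])  # [batch, time, vocab]
    y_sampled = tf.placeholder(tf.int32, [None, None])    # sampled target indices
    y_mask = tf.placeholder(tf.float32, [None, None])     # 1 for tokens, 0 for padding
    rewards = tf.placeholder(tf.float32, [None, None])    # per-token rewards

    # log p(y_t | y_<t, x) of the tokens the generator actually sampled
    log_probs = -tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=y_sampled, logits=logits)

    # REINFORCE: maximize reward-weighted log-likelihood over real tokens,
    # i.e. minimize the negated, mask-normalized sum. In the real model this
    # loss would be minimized w.r.t. the generator's parameters, e.g. with
    # tf.train.AdamOptimizer(1e-4).minimize(pg_loss).
    pg_loss = -tf.reduce_sum(log_probs * rewards * y_mask) / tf.reduce_sum(y_mask)
    return pg_loss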
def gan_train(config):
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    default_graph = tf.Graph()
    with default_graph.as_default():
        sess = tf.Session(config=sess_config, graph=default_graph)

        logger = logging.getLogger('')
        du = DataUtil(config=config)
        du.load_vocab(src_vocab=config.generator.src_vocab,
                      dst_vocab=config.generator.dst_vocab,
                      src_vocab_size=config.src_vocab_size_a,
                      dst_vocab_size=config.src_vocab_size_b)

        generator = Model(config=config, graph=default_graph, sess=sess)
        generator.build_variational_train_model()
        generator.init_and_restore(modelFile=config.generator.modelFile)

        dis_filter_sizes = [i for i in range(1, config.discriminator.dis_max_len, 4)]
        dis_num_filters = [(100 + i * 10) for i in range(1, config.discriminator.dis_max_len, 4)]

        # Three-class discriminator over s-domain text, t-domain text, and
        # generated text (as the data arguments below suggest).
        discriminator = text_DisCNN(
            sess=sess,
            max_len=config.discriminator.dis_max_len,
            num_classes=3,
            vocab_size_s=config.dst_vocab_size_a,
            vocab_size_t=config.dst_vocab_size_b,
            batch_size=config.discriminator.dis_batch_size,
            dim_word=config.discriminator.dis_dim_word,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            source_dict=config.discriminator.dis_src_vocab,
            target_dict=config.discriminator.dis_dst_vocab,
            gpu_device=config.discriminator.dis_gpu_devices,
            s_domain_data=config.discriminator.s_domain_data,
            t_domain_data=config.discriminator.t_domain_data,
            s_domain_generated_data=config.discriminator.s_domain_generated_data,
            t_domain_generated_data=config.discriminator.t_domain_generated_data,
            dev_s_domain_data=config.discriminator.dev_s_domain_data,
            dev_t_domain_data=config.discriminator.dev_t_domain_data,
            dev_s_domain_generated_data=config.discriminator.dev_s_domain_generated_data,
            dev_t_domain_generated_data=config.discriminator.dev_t_domain_generated_data,
            max_epoches=config.discriminator.dis_max_epoches,
            dispFreq=config.discriminator.dis_dispFreq,
            saveFreq=config.discriminator.dis_saveFreq,
            saveto=config.discriminator.dis_saveto,
            reload=config.discriminator.dis_reload,
            clip_c=config.discriminator.dis_clip_c,
            optimizer=config.discriminator.dis_optimizer,
            reshuffle=config.discriminator.dis_reshuffle,
            scope=config.discriminator.dis_scope)

        batch_iter = du.get_training_batches(
            set_train_src_path=config.generator.src_path,
            set_train_dst_path=config.generator.dst_path,
            set_batch_size=config.generator.batch_size,
            set_max_length=config.generator.max_length)

        for epoch in range(1, config.gan_iter_num + 1):
            for gen_iter in range(config.gan_gen_iter_num):
                batch = next(batch_iter)
                x, y = batch[0], batch[1]
                # Generate in both directions: a->b and b->a.
                generate_ab, generate_ba = generator.generate_step(x, y)
                logging.info("generate the samples")
                generate_ab_dealed, generate_ab_mask = deal_generated_samples(generate_ab, du.dst2idx)
                generate_ba_dealed, generate_ba_mask = deal_generated_samples(generate_ba, du.src2idx)

                # for debug:
                # print('the generate_ba_dealed is ')
                # print(du.indices_to_words(generate_ba_dealed, 'src'))
                # print('the generate_ab_dealed is ')
                # print(du.indices_to_words(generate_ab_dealed, 'dst'))

                x_to_maxlen = extend_sentence_to_maxlen(x)
                y_to_maxlen = extend_sentence_to_maxlen(y)

                logging.info("calculate the reward")
                rewards_ab = generator.get_reward(x=x,
                                                  x_to_maxlen=x_to_maxlen,
                                                  y_sample=generate_ab_dealed,
                                                  y_sample_mask=generate_ab_mask,
                                                  rollnum=config.rollnum,
                                                  disc=discriminator,
                                                  max_len=config.discriminator.dis_max_len,
                                                  bias_num=config.bias_num,
                                                  data_util=du,
                                                  direction='ab')
                rewards_ba = generator.get_reward(x=y,
                                                  x_to_maxlen=y_to_maxlen,
                                                  y_sample=generate_ba_dealed,
                                                  y_sample_mask=generate_ba_mask,
                                                  rollnum=config.rollnum,
                                                  disc=discriminator,
                                                  max_len=config.discriminator.dis_max_len,
                                                  bias_num=config.bias_num,
                                                  data_util=du,
                                                  direction='ba')

                loss_ab = generator.generate_step_and_update(x, generate_ab_dealed, rewards_ab)
                loss_ba = generator.generate_step_and_update(y, generate_ba_dealed, rewards_ba)
                print("the rewards for ab and ba are ", rewards_ab, rewards_ba)
                print("the losses for ab and ba are ", loss_ab, loss_ba)

                logging.info("save the model into %s" % config.generator.modelFile)
                generator.saver.save(generator.sess, config.generator.modelFile)

            logging.info("prepare the gan_dis_data begin")
            data_num = prepare_gan_dis_data(
                train_data_source=config.generator.src_path,
                train_data_target=config.generator.dst_path,
                gan_dis_source_data=config.discriminator.s_domain_data,
                gan_dis_positive_data=config.discriminator.t_domain_data,
                num=config.generate_num,
                reshuf=True)

            # Split both domain files in half; integer division (//) keeps the
            # argument to 'head -n' an integer under Python 3.
            s_domain_data_half = config.discriminator.s_domain_data + '.half'
            t_domain_data_half = config.discriminator.t_domain_data + '.half'
            os.popen('head -n ' + str(config.generate_num // 2) + ' '
                     + config.discriminator.s_domain_data + ' > ' + s_domain_data_half)
            os.popen('tail -n ' + str(config.generate_num // 2) + ' '
                     + config.discriminator.t_domain_data + ' > ' + t_domain_data_half)

            logging.info("generate and save t_domain_generated_data into %s."
                         % config.discriminator.t_domain_generated_data)
            generator.generate_and_save(data_util=du,
                                        infile=s_domain_data_half,
                                        generate_batch=config.discriminator.dis_batch_size,
                                        outfile=config.discriminator.t_domain_generated_data,
                                        direction='ab')

            logging.info("generate and save s_domain_generated_data into %s."
                         % config.discriminator.s_domain_generated_data)
            generator.generate_and_save(data_util=du,
                                        infile=t_domain_data_half,
                                        generate_batch=config.discriminator.dis_batch_size,
                                        outfile=config.discriminator.s_domain_generated_data,
                                        direction='ba')

            logging.info("prepare %d gan_dis_data done!" % data_num)

            logging.info("finetune the discriminator begin")
            discriminator.train(max_epoch=config.gan_dis_iter_num,
                                s_domain_data=config.discriminator.s_domain_data,
                                t_domain_data=config.discriminator.t_domain_data,
                                s_domain_generated_data=config.discriminator.s_domain_generated_data,
                                t_domain_generated_data=config.discriminator.t_domain_generated_data)
            discriminator.saver.save(discriminator.sess, discriminator.saveto)
            logging.info("finetune the discriminator done!")

        logging.info('reinforcement training done!')
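
# ---------------------------------------------------------------------------
# Aside on the os.popen('head -n ...') / ('tail -n ...') calls above: besides
# needing integer division, they depend on a POSIX shell and ignore errors
# silently. A portable sketch of the same split in plain Python, reusing the
# '.half' naming from the loop above:

def write_first_n_lines(infile, outfile, n):
    """Equivalent of `head -n <n> infile > outfile`."""
    with open(infile, encoding='utf-8') as fin, \
         open(outfile, 'w', encoding='utf-8') as fout:
        for i, line in enumerate(fin):
            if i >= n:
                break
            fout.write(line)

def write_last_n_lines(infile, outfile, n):
    """Equivalent of `tail -n <n> infile > outfile`."""
    with open(infile, encoding='utf-8') as fin:
        lines = fin.readlines()
    with open(outfile, 'w', encoding='utf-8') as fout:
        fout.writelines(lines[-n:])

# usage, mirroring the training loop above:
# write_first_n_lines(config.discriminator.s_domain_data, s_domain_data_half,
#                     config.generate_num // 2)
# write_last_n_lines(config.discriminator.t_domain_data, t_domain_data_half,
#                    config.generate_num // 2)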
def gan_train(config):
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    default_graph = tf.Graph()
    with default_graph.as_default():
        sess = tf.Session(config=sess_config, graph=default_graph)

        logger = logging.getLogger('')
        du = DataUtil(config=config)
        du.load_vocab(src_vocab=config.generator.src_vocab,
                      dst_vocab=config.generator.dst_vocab,
                      src_vocab_size=config.src_vocab_size,
                      dst_vocab_size=config.dst_vocab_size)

        generator = Model(config=config, graph=default_graph, sess=sess)
        generator.build_train_model()
        generator.build_generate(max_len=config.generator.max_length,
                                 generate_devices=config.generator.devices,
                                 optimizer=config.generator.optimizer)
        generator.build_rollout_generate(max_len=config.generator.max_length,
                                         roll_generate_devices=config.generator.devices)
        generator.init_and_restore(modelFile=config.generator.modelFile)

        # CNN filter widths and per-width filter counts for the discriminator.
        dis_filter_sizes = [i for i in range(1, config.discriminator.dis_max_len, 4)]
        dis_num_filters = [(100 + i * 10) for i in range(1, config.discriminator.dis_max_len, 4)]

        discriminator = DisCNN(
            sess=sess,
            max_len=config.discriminator.dis_max_len,
            num_classes=2,
            vocab_size=config.dst_vocab_size,
            vocab_size_s=config.src_vocab_size,
            batch_size=config.discriminator.dis_batch_size,
            dim_word=config.discriminator.dis_dim_word,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            source_dict=config.discriminator.dis_src_vocab,
            target_dict=config.discriminator.dis_dst_vocab,
            gpu_device=config.discriminator.dis_gpu_devices,
            positive_data=config.discriminator.dis_positive_data,
            negative_data=config.discriminator.dis_negative_data,
            source_data=config.discriminator.dis_source_data,
            dev_positive_data=config.discriminator.dis_dev_positive_data,
            dev_negative_data=config.discriminator.dis_dev_negative_data,
            dev_source_data=config.discriminator.dis_dev_source_data,
            max_epoches=config.discriminator.dis_max_epoches,
            dispFreq=config.discriminator.dis_dispFreq,
            saveFreq=config.discriminator.dis_saveFreq,
            saveto=config.discriminator.dis_saveto,
            reload=config.discriminator.dis_reload,
            clip_c=config.discriminator.dis_clip_c,
            optimizer=config.discriminator.dis_optimizer,
            reshuffle=config.discriminator.dis_reshuffle,
            scope=config.discriminator.dis_scope)

        batch_train_iter = du.get_training_batches(
            set_train_src_path=config.generator.src_path,
            set_train_dst_path=config.generator.dst_path,
            set_batch_size=config.generator.batch_size,
            set_max_length=config.generator.max_length)

        # Best scores seen so far on the test set; updated by evaluation_result.
        max_SARI_results = 0.32
        max_BLEU_results = 0.77

        def evaluation_result(generator, config):
            nonlocal max_SARI_results
            nonlocal max_BLEU_results
            # Run evaluation on the test dataset.
            logging.info("Max_SARI_results: {}".format(max_SARI_results))
            logging.info("Max_BLEU_results: {}".format(max_BLEU_results))
            # Beam-search output on the 8-turkers test dataset.
            output_t = "prepare_data/test.8turkers.clean.out.gan"
            evaluator = Evaluator(config=config, out_file=output_t)
            SARI_results, BLEU_results = evaluator.translate()
            logging.info("Current_SARI is {}\nCurrent_BLEU is {}".format(
                SARI_results, BLEU_results))
            if SARI_results >= max_SARI_results or BLEU_results >= max_BLEU_results:
                if SARI_results >= max_SARI_results:
                    max_SARI_results = SARI_results
                    logging.info("SARI Update Successfully !!!")
                if BLEU_results >= max_BLEU_results:
                    max_BLEU_results = BLEU_results
                    logging.info("BLEU Update Successfully !!!")
                return True
            else:
                return False

        for epoch in range(1, config.gan_iter_num + 1):  # e.g. 10000
            for gen_iter in range(config.gan_gen_iter_num):  # e.g. 1
                batch_train = next(batch_train_iter)
                x, y_ground = batch_train[0], batch_train[1]
                y_sample = generator.generate_step(x)

                logging.info("1. Policy Gradient Training !!!")
                # Pad the sampled index matrix with zeros to a fixed length.
                y_sample_dealed, y_sample_mask = deal_generated_samples(y_sample, du.dst2idx)
                # Pad the source index matrix with zeros to the maximum length.
                x_to_maxlen = extend_sentence_to_maxlen(x, config.generator.max_length)

                # Decoded strings, useful for inspection while debugging.
                x_str = du.indices_to_words(x, 'dst')
                ground_str = du.indices_to_words(y_ground, 'dst')
                sample_str = du.indices_to_words(y_sample, 'dst')

                # Reward = D (discriminator score) + Q (BLEU scores).
                logging.info("2. Calculate the Reward !!!")
                rewards = generator.get_reward(
                    x=x,
                    x_to_maxlen=x_to_maxlen,
                    y_sample=y_sample_dealed,
                    y_sample_mask=y_sample_mask,
                    rollnum=config.rollnum,
                    disc=discriminator,
                    max_len=config.discriminator.dis_max_len,
                    bias_num=config.bias_num,
                    data_util=du)

                # Policy-gradient update of the generator.
                logging.info("3. Update the Generator Model !!!")
                loss = generator.generate_step_and_update(x, y_sample_dealed, rewards)
                # logging.info("The reward is {}".format(rewards))
                # logging.info("The loss is {}".format(loss))

                # Save the generator model.
                logging.info("4. Save the Generator model into %s" % config.generator.modelFile)
                generator.saver.save(generator.sess, config.generator.modelFile)

                if config.generator.teacher_forcing:
                    logging.info("5. Teacher Forcing begin!")
                    y_ground, y_ground_mask = deal_generated_samples_to_maxlen(
                        y_sample=y_ground,
                        dicts=du.dst2idx,
                        maxlen=config.discriminator.dis_max_len)
                    # Ground-truth targets get a uniform reward of 1 on every
                    # non-padding token.
                    rewards_ground = np.ones_like(y_ground) * y_ground_mask
                    loss = generator.generate_step_and_update(x, y_ground, rewards_ground)

                logging.info("6. Evaluate SARI and BLEU")
                update_or_not_update = evaluation_result(generator, config)
                if update_or_not_update:
                    # Save the generator model only when a metric improved.
                    generator.saver.save(generator.sess, config.generator.modelFile)

            data_num = prepare_gan_dis_data(
                train_data_source=config.generator.src_path,
                train_data_target=config.generator.dst_path,
                gan_dis_source_data=config.discriminator.dis_source_data,
                gan_dis_positive_data=config.discriminator.dis_positive_data,
                num=config.generate_num,
                reshuf=True)

            logging.info("7. Generate the negative dataset for the discriminator !!!")
            generator.generate_and_save(
                data_util=du,
                infile=config.discriminator.dis_source_data,
                generate_batch=config.discriminator.dis_batch_size,
                outfile=config.discriminator.dis_negative_data)
            logging.info("8. Negative dataset was saved into %s."
                         % config.discriminator.dis_negative_data)

            logging.info("9. Finetune the discriminator begin !!!")
            discriminator.train(
                max_epoch=config.gan_dis_iter_num,
                positive_data=config.discriminator.dis_positive_data,
                negative_data=config.discriminator.dis_negative_data,
                source_data=config.discriminator.dis_source_data)
            discriminator.saver.save(discriminator.sess, discriminator.saveto)
            logging.info("10. Finetune the discriminator done !!!")

        logging.info('Reinforcement training done!')
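
# ---------------------------------------------------------------------------
# Aside: get_reward is where the SeqGAN-style credit assignment happens. For
# each prefix of the sampled sentence, the generator rolls out `rollnum`
# completions, the discriminator scores them, and the average score becomes
# that position's reward (the comment above suggests a BLEU term Q is mixed
# in as well). A schematic sketch of that procedure; rollout_fn and
# disc_score_fn are illustrative callables, not the repo's API:

import numpy as np

def mc_rollout_rewards(x, y_sample, roll_num, rollout_fn, disc_score_fn):
    """SeqGAN-style Monte-Carlo rewards.

    rollout_fn(x, prefix): completes a partial translation (assumed helper).
    disc_score_fn(x, y): discriminator's probability that y is real (assumed).
    Returns one reward per target position, shape [time].
    """
    T = len(y_sample)
    rewards = np.zeros(T, dtype=np.float32)
    for t in range(1, T + 1):
        if t < T:
            # Average the discriminator score over roll_num completed rollouts
            # of the prefix y_sample[:t].
            scores = [disc_score_fn(x, rollout_fn(x, y_sample[:t]))
                      for _ in range(roll_num)]
            rewards[t - 1] = np.mean(scores)
        else:
            # Full sentence: score it directly, no rollout needed.
            rewards[t - 1] = disc_score_fn(x, y_sample)
    return rewards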