def generate_and_save(self, data_util, infile, generate_batch, outfile):
    outfile = codecs.open(outfile, 'w', 'utf-8')
    for batch in data_util.get_test_batches(infile, generate_batch):
        feed = {self.generate_x: batch}
        out_generate = self.sess.run(self.generate_sample, feed_dict=feed)
        out_generate_dealed, _ = deal_generated_samples(out_generate, data_util.dst2idx)
        y_strs = data_util.indices_to_words_del_pad(out_generate_dealed, 'dst')
        for y_str in y_strs:
            outfile.write(y_str + '\n')
    outfile.close()

def gan_train(config):
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    default_graph = tf.Graph()
    with default_graph.as_default():
        sess = tf.Session(config=sess_config, graph=default_graph)
        logger = logging.getLogger('')

        du = DataUtil(config=config)
        du.load_vocab(src_vocab=config.generator.src_vocab,
                      dst_vocab=config.generator.dst_vocab,
                      src_vocab_size=config.src_vocab_size,
                      dst_vocab_size=config.dst_vocab_size)

        generator = Model(config=config, graph=default_graph, sess=sess)
        generator.build_train_model()
        generator.build_generate(max_len=config.generator.max_length,
                                 generate_devices=config.generator.devices,
                                 optimizer=config.generator.optimizer)
        generator.build_rollout_generate(max_len=config.generator.max_length,
                                         roll_generate_devices=config.generator.devices)
        generator.init_and_restore(modelFile=config.generator.modelFile)

        dis_filter_sizes = [i for i in range(1, config.discriminator.dis_max_len, 4)]
        dis_num_filters = [(100 + i * 10) for i in range(1, config.discriminator.dis_max_len, 4)]

        discriminator = DisCNN(
            sess=sess,
            max_len=config.discriminator.dis_max_len,
            num_classes=2,
            vocab_size=config.dst_vocab_size,
            vocab_size_s=config.src_vocab_size,
            batch_size=config.discriminator.dis_batch_size,
            dim_word=config.discriminator.dis_dim_word,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            source_dict=config.discriminator.dis_src_vocab,
            target_dict=config.discriminator.dis_dst_vocab,
            gpu_device=config.discriminator.dis_gpu_devices,
            positive_data=config.discriminator.dis_positive_data,
            negative_data=config.discriminator.dis_negative_data,
            source_data=config.discriminator.dis_source_data,
            dev_positive_data=config.discriminator.dis_dev_positive_data,
            dev_negative_data=config.discriminator.dis_dev_negative_data,
            dev_source_data=config.discriminator.dis_dev_source_data,
            max_epoches=config.discriminator.dis_max_epoches,
            dispFreq=config.discriminator.dis_dispFreq,
            saveFreq=config.discriminator.dis_saveFreq,
            saveto=config.discriminator.dis_saveto,
            reload=config.discriminator.dis_reload,
            clip_c=config.discriminator.dis_clip_c,
            optimizer=config.discriminator.dis_optimizer,
            reshuffle=config.discriminator.dis_reshuffle,
            scope=config.discriminator.dis_scope)

        batch_iter = du.get_training_batches(
            set_train_src_path=config.generator.src_path,
            set_train_dst_path=config.generator.dst_path,
            set_batch_size=config.generator.batch_size,
            set_max_length=config.generator.max_length)

        for epoch in range(1, config.gan_iter_num + 1):
            for gen_iter in range(config.gan_gen_iter_num):
                batch = next(batch_iter)
                x, y_ground = batch[0], batch[1]
                y_sample = generator.generate_step(x)
                logging.info("generate the samples")
                y_sample_dealed, y_sample_mask = deal_generated_samples(y_sample, du.dst2idx)

                # #### for debug
                # print('the sample is ')
                # sample_str = du.indices_to_words(y_sample_dealed, 'dst')
                # print(sample_str)

                x_to_maxlen = extend_sentence_to_maxlen(x, config.generator.max_length)

                logging.info("calculate the reward")
                rewards = generator.get_reward(x=x,
                                               x_to_maxlen=x_to_maxlen,
                                               y_sample=y_sample_dealed,
                                               y_sample_mask=y_sample_mask,
                                               rollnum=config.rollnum,
                                               disc=discriminator,
                                               max_len=config.discriminator.dis_max_len,
                                               bias_num=config.bias_num,
                                               data_util=du)

                loss = generator.generate_step_and_update(x, y_sample_dealed, rewards)
                print("the reward is ", rewards)
                print("the loss is ", loss)

                logging.info("save the model into %s" % config.generator.modelFile)
                generator.saver.save(generator.sess, config.generator.modelFile)

                if config.generator.teacher_forcing:
                    logging.info("start the teacher forcing!")
                    y_ground, y_ground_mask = deal_generated_samples_to_maxlen(
                        y_sample=y_ground,
                        dicts=du.dst2idx,
                        maxlen=config.discriminator.dis_max_len)

                    rewards_ground = np.ones_like(y_ground)
                    rewards_ground = rewards_ground * y_ground_mask

                    loss = generator.generate_step_and_update(x, y_ground, rewards_ground)
                    print("the teacher forcing reward is ", rewards_ground)
                    print("the teacher forcing loss is ", loss)

                    generator.saver.save(generator.sess, config.generator.modelFile)

            logging.info("prepare the gan_dis_data begin")
            data_num = prepare_gan_dis_data(
                train_data_source=config.generator.src_path,
                train_data_target=config.generator.dst_path,
                gan_dis_source_data=config.discriminator.dis_source_data,
                gan_dis_positive_data=config.discriminator.dis_positive_data,
                num=config.generate_num,
                reshuf=True)

            logging.info("generate and save the negative data into %s." % config.discriminator.dis_negative_data)
            generator.generate_and_save(data_util=du,
                                        infile=config.discriminator.dis_source_data,
                                        generate_batch=config.discriminator.dis_batch_size,
                                        outfile=config.discriminator.dis_negative_data)
            logging.info("prepare %d gan_dis_data done!" % data_num)

            logging.info("finetune the discriminator begin")
            discriminator.train(max_epoch=config.gan_dis_iter_num,
                                positive_data=config.discriminator.dis_positive_data,
                                negative_data=config.discriminator.dis_negative_data,
                                source_data=config.discriminator.dis_source_data)
            discriminator.saver.save(discriminator.sess, discriminator.saveto)
            logging.info("finetune the discriminator done!")

        logging.info('reinforcement training done!')

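
# -----------------------------------------------------------------------------------------
# Illustrative sketch only (an assumption, not the repo's actual implementation): the
# REINFORCE-style objective that generate_step_and_update is presumed to optimize above.
# The per-token negative log-likelihood of the sampled sequence is weighted by the rollout
# reward and the padding mask before averaging.
def policy_gradient_loss_sketch(logits, sampled_ids, rewards, mask):
    """logits: [batch, len, vocab]; sampled_ids: [batch, len] int32; rewards, mask: [batch, len]."""
    neg_log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=sampled_ids, logits=logits)      # -log p(y_t | y_<t, x)
    weighted = neg_log_prob * rewards * mask    # reward-weighted, padding masked
    return tf.reduce_sum(weighted) / tf.maximum(tf.reduce_sum(mask), 1.0)
# Assumed usage: loss = policy_gradient_loss_sketch(decoder_logits, y_sample_dealed,
# rewards, y_sample_mask), then minimize `loss` with the generator's optimizer.
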
def train(config):
    """Train a model with a config file."""
    logger = logging.getLogger('')
    du = DataUtil(config=config)
    du.load_vocab(src_vocab=config.src_vocab,
                  dst_vocab=config.dst_vocab,
                  src_vocab_size=config.src_vocab_size_a,
                  dst_vocab_size=config.src_vocab_size_b)

    model = Model(config=config)
    model.build_variational_train_model()

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    with model.graph.as_default():
        saver = tf.train.Saver(var_list=tf.global_variables())
        summary_writer = tf.summary.FileWriter(config.train.logdir, graph=model.graph)
        # saver_partial = tf.train.Saver(var_list=[v for v in tf.trainable_variables() if 'Adam' not in v.name])

        with tf.Session(config=sess_config) as sess:
            # Initialize all variables.
            sess.run(tf.global_variables_initializer())
            reload_pretrain_embedding = False
            try:
                # saver_partial.restore(sess, tf.train.latest_checkpoint(config.train.logdir))
                # print('Restore partial model from %s.' % config.train.logdir)
                saver.restore(sess, tf.train.latest_checkpoint(config.train.logdir))
            except:
                logger.info('Failed to reload model.')
                reload_pretrain_embedding = True

            if reload_pretrain_embedding:
                logger.info('reload the pretrained embeddings for the encoders')
                src_pretrained_embedding = {}
                dst_pretrained_embedding = {}
                try:
                    for l in codecs.open(config.train.src_pretrain_wordemb_path, 'r', 'utf-8'):
                        word_emb = l.strip().split()
                        # print(word_emb)
                        if len(word_emb) == config.hidden_units + 1:
                            word, emb = word_emb[0], np.array(list(map(float, word_emb[1:])))
                            src_pretrained_embedding[word] = emb

                    for l in codecs.open(config.train.dst_pretrain_wordemb_path, 'r', 'utf-8'):
                        word_emb = l.strip().split()
                        if len(word_emb) == config.hidden_units + 1:
                            word, emb = word_emb[0], np.array(list(map(float, word_emb[1:])))
                            dst_pretrained_embedding[word] = emb
                    logger.info('reload the word embedding done')

                    tf.get_variable_scope().reuse_variables()
                    src_embed_a = tf.get_variable('enc_aembedding/src_embedding/kernel')
                    src_embed_b = tf.get_variable('enc_bembedding/src_embedding/kernel')
                    dst_embed_a = tf.get_variable('dec_aembedding/dst_embedding/kernel')
                    dst_embed_b = tf.get_variable('dec_bembedding/dst_embedding/kernel')

                    count_a = 0
                    src_value_a = sess.run(src_embed_a)
                    dst_value_a = sess.run(dst_embed_a)
                    # print(src_value_a)
                    for word in src_pretrained_embedding:
                        if word in du.src2idx:
                            id = du.src2idx[word]
                            # print(id)
                            src_value_a[id] = src_pretrained_embedding[word]
                            dst_value_a[id] = src_pretrained_embedding[word]
                            count_a += 1
                    sess.run(src_embed_a.assign(src_value_a))
                    sess.run(dst_embed_a.assign(dst_value_a))
                    # print(sess.run(src_embed_a))

                    count_b = 0
                    src_value_b = sess.run(src_embed_b)
                    dst_value_b = sess.run(dst_embed_b)
                    for word in dst_pretrained_embedding:
                        if word in du.dst2idx:
                            id = du.dst2idx[word]
                            # print(id)
                            src_value_b[id] = dst_pretrained_embedding[word]
                            dst_value_b[id] = dst_pretrained_embedding[word]
                            count_b += 1
                    sess.run(src_embed_b.assign(src_value_b))
                    sess.run(dst_embed_b.assign(dst_value_b))
                    logger.info('restore %d src_embedding and %d dst_embedding done' % (count_a, count_b))
                except:
                    logger.info('Failed to load the pretrained embeddings')

            # tmp_writer = codecs.open('tmp_test', 'w', 'utf-8')
            for epoch in range(1, config.train.num_epochs + 1):
                for batch in du.get_training_batches_with_buckets():
                    # swap batch[0] and batch[1] according to whether the sequence length is odd or even
                    # batch_swap = []
                    # swap_0 = np.arange(batch[0].shape[1])
                    # swap_1 = np.arange(batch[1].shape[1])
                    #
                    # if len(swap_0) % 2 == 0:
                    #     swap_0[0::2] += 1
                    #     swap_0[1::2] -= 1
                    # else:
                    #     swap_0[0:-1:2] += 1
                    #     swap_0[1::2] -= 1
                    #
                    # if len(swap_1) % 2 == 0:
                    #     swap_1[0::2] += 1
                    #     swap_1[1::2] -= 1
                    # else:
                    #     swap_1[0:-1:2] += 1
                    #     swap_1[1::2] -= 1
                    #
                    # batch_swap.append(batch[0].transpose()[swap_0].transpose())
                    # batch_swap.append(batch[1].transpose()[swap_1].transpose())
                    # print(batch[0])
                    # print(batch_swap[0])

                    # randomly shuffle batch[0] and batch[1]
                    # batch_shuffle = []
                    # shuffle_0_indices = np.random.permutation(np.arange(batch[0].shape[1]))
                    # shuffle_1_indices = np.random.permutation(np.arange(batch[1].shape[1]))
                    # batch_shuffle.append(batch[0].transpose()[shuffle_0_indices].transpose())
                    # batch_shuffle.append(batch[1].transpose()[shuffle_1_indices].transpose())

                    def get_shuffle_k_indices(length, shuffle_k):
                        shuffle_k_indices = []
                        rand_start = np.random.randint(shuffle_k)
                        indices_list_start = list(np.random.permutation(np.arange(0, rand_start)))
                        shuffle_k_indices.extend(indices_list_start)

                        for i in range(rand_start, length, shuffle_k):
                            if i + shuffle_k > length:
                                indices_list_i = list(np.random.permutation(np.arange(i, length)))
                            else:
                                indices_list_i = list(np.random.permutation(np.arange(i, i + shuffle_k)))
                            shuffle_k_indices.extend(indices_list_i)

                        return np.array(shuffle_k_indices)

                    batch_shuffle = []
                    shuffle_0_indices = get_shuffle_k_indices(batch[0].shape[1], config.train.shuffle_k)
                    shuffle_1_indices = get_shuffle_k_indices(batch[1].shape[1], config.train.shuffle_k)
                    # print(shuffle_0_indices)
                    batch_shuffle.append(batch[0].transpose()[shuffle_0_indices].transpose())
                    batch_shuffle.append(batch[1].transpose()[shuffle_1_indices].transpose())

                    start_time = time.time()
                    step = sess.run(model.global_step)

                    step, lr, gnorm_aa, loss_aa, acc_aa, _ = sess.run(
                        [model.global_step, model.learning_rate, model.grads_norm_aa,
                         model.loss_aa, model.acc_aa, model.train_op_aa],
                        feed_dict={model.src_a_pl: batch_shuffle[0], model.dst_a_pl: batch[0]})

                    step, lr, gnorm_bb, loss_bb, acc_bb, _ = sess.run(
                        [model.global_step, model.learning_rate, model.grads_norm_bb,
                         model.loss_bb, model.acc_bb, model.train_op_bb],
                        feed_dict={model.src_b_pl: batch_shuffle[1], model.dst_b_pl: batch[1]})

                    # this step takes too much time
                    generate_ab, generate_ba = sess.run(
                        [model.generate_ab, model.generate_ba],
                        feed_dict={model.src_a_pl: batch[0], model.src_b_pl: batch[1]})

                    generate_ab_dealed, _ = deal_generated_samples(generate_ab, du.dst2idx)
                    generate_ba_dealed, _ = deal_generated_samples(generate_ba, du.src2idx)

                    # for sent in du.indices_to_words(batch[0], o='src'):
                    #     print(sent, file=tmp_writer)
                    # for sent in du.indices_to_words(generate_ab_dealed, o='dst'):
                    #     print(sent, file=tmp_writer)

                    step, acc_ab, loss_ab, _ = sess.run(
                        [model.global_step, model.acc_ab, model.loss_ab, model.train_op_ab],
                        feed_dict={model.src_a_pl: generate_ba_dealed, model.dst_b_pl: batch[1]})

                    step, acc_ba, loss_ba, _ = sess.run(
                        [model.global_step, model.acc_ba, model.loss_ba, model.train_op_ba],
                        feed_dict={model.src_b_pl: generate_ab_dealed, model.dst_a_pl: batch[0]})

                    if step % config.train.disp_freq == 0:
                        logger.info('epoch: {0}\tstep: {1}\tlr: {2:.6f}\tgnorm: {3:.4f}\tloss: {4:.4f}'
                                    '\tacc: {5:.4f}\tcross_loss: {6:.4f}\tcross_acc: {7:.4f}\ttime: {8:.4f}'
                                    .format(epoch, step, lr, gnorm_aa, loss_aa, acc_aa,
                                            loss_ab, acc_ab, time.time() - start_time))

                    # Save model
                    if step % config.train.save_freq == 0:
                        mp = config.train.logdir + '/model_epoch_%d_step_%d' % (epoch, step)
                        saver.save(sess, mp)
                        logger.info('Save model in %s.' % mp)

            logger.info("Finish training.")

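
# -----------------------------------------------------------------------------------------
# Small standalone demo (illustrative only, names not part of the original code) of the
# windowed shuffle used above as denoising noise: indices are permuted only inside windows
# of size shuffle_k, so word order is perturbed locally while remaining a full permutation.
def _demo_windowed_shuffle(length=10, shuffle_k=3):
    perm = []
    rand_start = np.random.randint(shuffle_k)
    perm.extend(np.random.permutation(np.arange(0, rand_start)))
    for i in range(rand_start, length, shuffle_k):
        perm.extend(np.random.permutation(np.arange(i, min(i + shuffle_k, length))))
    perm = np.array(perm, dtype=np.int64)
    assert np.array_equal(np.sort(perm), np.arange(length))                  # still a permutation
    assert all(abs(int(p) - i) < shuffle_k for i, p in enumerate(perm))      # moves stay local
    return perm                                                              # e.g. [1 0 2 4 3 5 ...]
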
def gan_train(config):
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    default_graph = tf.Graph()
    with default_graph.as_default():
        sess = tf.Session(config=sess_config, graph=default_graph)
        logger = logging.getLogger('')

        du = DataUtil(config=config)
        du.load_vocab(src_vocab=config.generator.src_vocab,
                      dst_vocab=config.generator.dst_vocab,
                      src_vocab_size=config.src_vocab_size_a,
                      dst_vocab_size=config.src_vocab_size_b)

        generator = Model(config=config, graph=default_graph, sess=sess)
        generator.build_variational_train_model()
        generator.init_and_restore(modelFile=config.generator.modelFile)

        dis_filter_sizes = [i for i in range(1, config.discriminator.dis_max_len, 4)]
        dis_num_filters = [(100 + i * 10) for i in range(1, config.discriminator.dis_max_len, 4)]

        discriminator = text_DisCNN(
            sess=sess,
            max_len=config.discriminator.dis_max_len,
            num_classes=3,
            vocab_size_s=config.dst_vocab_size_a,
            vocab_size_t=config.dst_vocab_size_b,
            batch_size=config.discriminator.dis_batch_size,
            dim_word=config.discriminator.dis_dim_word,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            source_dict=config.discriminator.dis_src_vocab,
            target_dict=config.discriminator.dis_dst_vocab,
            gpu_device=config.discriminator.dis_gpu_devices,
            s_domain_data=config.discriminator.s_domain_data,
            t_domain_data=config.discriminator.t_domain_data,
            s_domain_generated_data=config.discriminator.s_domain_generated_data,
            t_domain_generated_data=config.discriminator.t_domain_generated_data,
            dev_s_domain_data=config.discriminator.dev_s_domain_data,
            dev_t_domain_data=config.discriminator.dev_t_domain_data,
            dev_s_domain_generated_data=config.discriminator.dev_s_domain_generated_data,
            dev_t_domain_generated_data=config.discriminator.dev_t_domain_generated_data,
            max_epoches=config.discriminator.dis_max_epoches,
            dispFreq=config.discriminator.dis_dispFreq,
            saveFreq=config.discriminator.dis_saveFreq,
            saveto=config.discriminator.dis_saveto,
            reload=config.discriminator.dis_reload,
            clip_c=config.discriminator.dis_clip_c,
            optimizer=config.discriminator.dis_optimizer,
            reshuffle=config.discriminator.dis_reshuffle,
            scope=config.discriminator.dis_scope)

        batch_iter = du.get_training_batches(
            set_train_src_path=config.generator.src_path,
            set_train_dst_path=config.generator.dst_path,
            set_batch_size=config.generator.batch_size,
            set_max_length=config.generator.max_length)

        for epoch in range(1, config.gan_iter_num + 1):
            for gen_iter in range(config.gan_gen_iter_num):
                batch = next(batch_iter)
                x, y = batch[0], batch[1]
                generate_ab, generate_ba = generator.generate_step(x, y)
                logging.info("generate the samples")
                generate_ab_dealed, generate_ab_mask = deal_generated_samples(generate_ab, du.dst2idx)
                generate_ba_dealed, generate_ba_mask = deal_generated_samples(generate_ba, du.src2idx)

                # ## for debug
                # print('the generate_ba_dealed is ')
                # sample_str = du.indices_to_words(generate_ba_dealed, 'src')
                # print(sample_str)
                # print('the generate_ab_dealed is ')
                # sample_str = du.indices_to_words(generate_ab_dealed, 'dst')
                # print(sample_str)

                x_to_maxlen = extend_sentence_to_maxlen(x)
                y_to_maxlen = extend_sentence_to_maxlen(y)

                logging.info("calculate the reward")
                rewards_ab = generator.get_reward(x=x,
                                                  x_to_maxlen=x_to_maxlen,
                                                  y_sample=generate_ab_dealed,
                                                  y_sample_mask=generate_ab_mask,
                                                  rollnum=config.rollnum,
                                                  disc=discriminator,
                                                  max_len=config.discriminator.dis_max_len,
                                                  bias_num=config.bias_num,
                                                  data_util=du,
                                                  direction='ab')

                rewards_ba = generator.get_reward(x=y,
                                                  x_to_maxlen=y_to_maxlen,
                                                  y_sample=generate_ba_dealed,
                                                  y_sample_mask=generate_ba_mask,
                                                  rollnum=config.rollnum,
                                                  disc=discriminator,
                                                  max_len=config.discriminator.dis_max_len,
                                                  bias_num=config.bias_num,
                                                  data_util=du,
                                                  direction='ba')

                loss_ab = generator.generate_step_and_update(x, generate_ab_dealed, rewards_ab)
                loss_ba = generator.generate_step_and_update(y, generate_ba_dealed, rewards_ba)

                print("the reward for ab and ba is ", rewards_ab, rewards_ba)
                print("the loss for ab and ba is ", loss_ab, loss_ba)

                logging.info("save the model into %s" % config.generator.modelFile)
                generator.saver.save(generator.sess, config.generator.modelFile)

            # #### modified to here, next starts from here
            logging.info("prepare the gan_dis_data begin")
            data_num = prepare_gan_dis_data(
                train_data_source=config.generator.src_path,
                train_data_target=config.generator.dst_path,
                gan_dis_source_data=config.discriminator.s_domain_data,
                gan_dis_positive_data=config.discriminator.t_domain_data,
                num=config.generate_num,
                reshuf=True)

            s_domain_data_half = config.discriminator.s_domain_data + '.half'
            t_domain_data_half = config.discriminator.t_domain_data + '.half'

            # integer division so the shell commands get a whole line count
            os.popen('head -n ' + str(config.generate_num // 2) + ' ' +
                     config.discriminator.s_domain_data + ' > ' + s_domain_data_half)
            os.popen('tail -n ' + str(config.generate_num // 2) + ' ' +
                     config.discriminator.t_domain_data + ' > ' + t_domain_data_half)

            logging.info("generate and save the t_domain_generated_data into %s."
                         % config.discriminator.t_domain_generated_data)
            generator.generate_and_save(data_util=du,
                                        infile=s_domain_data_half,
                                        generate_batch=config.discriminator.dis_batch_size,
                                        outfile=config.discriminator.t_domain_generated_data,
                                        direction='ab')

            logging.info("generate and save the s_domain_generated_data into %s."
                         % config.discriminator.s_domain_generated_data)
            generator.generate_and_save(data_util=du,
                                        infile=t_domain_data_half,
                                        generate_batch=config.discriminator.dis_batch_size,
                                        outfile=config.discriminator.s_domain_generated_data,
                                        direction='ba')

            logging.info("prepare %d gan_dis_data done!" % data_num)

            logging.info("finetune the discriminator begin")
            discriminator.train(max_epoch=config.gan_dis_iter_num,
                                s_domain_data=config.discriminator.s_domain_data,
                                t_domain_data=config.discriminator.t_domain_data,
                                s_domain_generated_data=config.discriminator.s_domain_generated_data,
                                t_domain_generated_data=config.discriminator.t_domain_generated_data)
            discriminator.saver.save(discriminator.sess, discriminator.saveto)
            logging.info("finetune the discriminator done!")

        logging.info('reinforcement training done!')

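
# -----------------------------------------------------------------------------------------
# Hedged note and optional alternative (not part of the original code): os.popen starts the
# `head`/`tail` commands without waiting for them to finish, so the *.half files may still be
# incomplete when generate_and_save reads them. A pure-Python split, sketched below with
# illustrative names, avoids both the potential race and the shell dependency.
def _split_half_files(s_domain_data, t_domain_data, n):
    """Write the first n lines of s_domain_data and the last n lines of t_domain_data
    to '<name>.half' files, mirroring the head/tail commands above."""
    with codecs.open(s_domain_data, 'r', 'utf-8') as f, \
            codecs.open(s_domain_data + '.half', 'w', 'utf-8') as out:
        out.writelines(f.readlines()[:n])
    with codecs.open(t_domain_data, 'r', 'utf-8') as f, \
            codecs.open(t_domain_data + '.half', 'w', 'utf-8') as out:
        out.writelines(f.readlines()[-n:])
# Assumed usage: _split_half_files(config.discriminator.s_domain_data,
#                                  config.discriminator.t_domain_data,
#                                  config.generate_num // 2)
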
def gan_train(config):
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    default_graph = tf.Graph()
    with default_graph.as_default():
        sess = tf.Session(config=sess_config, graph=default_graph)
        logger = logging.getLogger('')

        du = DataUtil(config=config)
        du.load_vocab(src_vocab=config.generator.src_vocab,
                      dst_vocab=config.generator.dst_vocab,
                      src_vocab_size=config.src_vocab_size,
                      dst_vocab_size=config.dst_vocab_size)

        generator = Model(config=config, graph=default_graph, sess=sess)
        generator.build_train_model()
        generator.build_generate(max_len=config.generator.max_length,
                                 generate_devices=config.generator.devices,
                                 optimizer=config.generator.optimizer)
        generator.build_rollout_generate(max_len=config.generator.max_length,
                                         roll_generate_devices=config.generator.devices)
        generator.init_and_restore(modelFile=config.generator.modelFile)

        # What is this variable?! filters? (CNN filter widths for the discriminator)
        dis_filter_sizes = [i for i in range(1, config.discriminator.dis_max_len, 4)]
        dis_num_filters = [(100 + i * 10) for i in range(1, config.discriminator.dis_max_len, 4)]

        discriminator = DisCNN(
            sess=sess,
            max_len=config.discriminator.dis_max_len,
            num_classes=2,
            vocab_size=config.dst_vocab_size,
            vocab_size_s=config.src_vocab_size,
            batch_size=config.discriminator.dis_batch_size,
            dim_word=config.discriminator.dis_dim_word,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            source_dict=config.discriminator.dis_src_vocab,
            target_dict=config.discriminator.dis_dst_vocab,
            gpu_device=config.discriminator.dis_gpu_devices,
            positive_data=config.discriminator.dis_positive_data,
            negative_data=config.discriminator.dis_negative_data,
            source_data=config.discriminator.dis_source_data,
            dev_positive_data=config.discriminator.dis_dev_positive_data,
            dev_negative_data=config.discriminator.dis_dev_negative_data,
            dev_source_data=config.discriminator.dis_dev_source_data,
            max_epoches=config.discriminator.dis_max_epoches,
            dispFreq=config.discriminator.dis_dispFreq,
            saveFreq=config.discriminator.dis_saveFreq,
            saveto=config.discriminator.dis_saveto,
            reload=config.discriminator.dis_reload,
            clip_c=config.discriminator.dis_clip_c,
            optimizer=config.discriminator.dis_optimizer,
            reshuffle=config.discriminator.dis_reshuffle,
            scope=config.discriminator.dis_scope)

        batch_train_iter = du.get_training_batches(
            set_train_src_path=config.generator.src_path,
            set_train_dst_path=config.generator.dst_path,
            set_batch_size=config.generator.batch_size,
            set_max_length=config.generator.max_length)

        max_SARI_results = 0.32  #!!
        max_BLEU_results = 0.77  #!!

        def evaluation_result(generator, config):
            nonlocal max_SARI_results
            nonlocal max_BLEU_results
            # Start validation on the test dataset
            logging.info("Max_SARI_results: {}".format(max_SARI_results))
            logging.info("Max_BLEU_results: {}".format(max_BLEU_results))
            output_t = "prepare_data/test.8turkers.clean.out.gan"
            # Beam search on the 8turkers test dataset
            evaluator = Evaluator(config=config, out_file=output_t)
            # logging.info("Evaluate on BLEU and SARI")
            SARI_results, BLEU_results = evaluator.translate()
            logging.info(" Current_SARI is {} \n Current_BLEU is {}".format(SARI_results, BLEU_results))

            if SARI_results >= max_SARI_results or BLEU_results >= max_BLEU_results:
                if SARI_results >= max_SARI_results:
                    max_SARI_results = SARI_results
                    logging.info("SARI Update Successfully !!!")
                if BLEU_results >= max_BLEU_results:
                    logging.info("BLEU Update Successfully !!!")
                    max_BLEU_results = BLEU_results
                return True
            else:
                return False

        for epoch in range(1, config.gan_iter_num + 1):  # 10000
            for gen_iter in range(config.gan_gen_iter_num):  # 1
                batch_train = next(batch_train_iter)
                x, y_ground = batch_train[0], batch_train[1]
                y_sample = generator.generate_step(x)

                logging.info("1. Policy Gradient Training !!!")
                # pad the y_sample index matrix with zeros
                y_sample_dealed, y_sample_mask = deal_generated_samples(y_sample, du.dst2idx)
                # pad the x index matrix with zeros to max length
                x_to_maxlen = extend_sentence_to_maxlen(x, config.generator.max_length)

                x_str = du.indices_to_words(x, 'dst')
                ground_str = du.indices_to_words(y_ground, 'dst')
                sample_str = du.indices_to_words(y_sample, 'dst')

                # Rewards = D (discriminator score) + Q (BLEU scores)
                logging.info("2. Calculate the Reward !!!")
                rewards = generator.get_reward(x=x,
                                               x_to_maxlen=x_to_maxlen,
                                               y_sample=y_sample_dealed,
                                               y_sample_mask=y_sample_mask,
                                               rollnum=config.rollnum,
                                               disc=discriminator,
                                               max_len=config.discriminator.dis_max_len,
                                               bias_num=config.bias_num,
                                               data_util=du)

                # Update the generator model with policy gradient
                logging.info("3. Update the Generator Model !!!")
                loss = generator.generate_step_and_update(x, y_sample_dealed, rewards)
                # logging.info("The reward is ", rewards)
                # logging.info("The loss is ", loss)

                # update_or_not_update = evaluation_result(generator, config)
                # if update_or_not_update:
                # Save the generator model
                logging.info("4. Save the Generator model into %s" % config.generator.modelFile)
                generator.saver.save(generator.sess, config.generator.modelFile)

                if config.generator.teacher_forcing:
                    logging.info("5. Doing the Teacher Forcing begin!")
                    y_ground, y_ground_mask = deal_generated_samples_to_maxlen(
                        y_sample=y_ground,
                        dicts=du.dst2idx,
                        maxlen=config.discriminator.dis_max_len)

                    rewards_ground = np.ones_like(y_ground)
                    rewards_ground = rewards_ground * y_ground_mask

                    loss = generator.generate_step_and_update(x, y_ground, rewards_ground)
                    # logging.info("The teacher forcing reward is ", rewards_ground)
                    # logging.info("The teacher forcing loss is ", loss)

            logging.info("6. Evaluation SARI and BLEU")
            update_or_not_update = evaluation_result(generator, config)
            if update_or_not_update:
                # Save the generator model
                generator.saver.save(generator.sess, config.generator.modelFile)

            data_num = prepare_gan_dis_data(
                train_data_source=config.generator.src_path,
                train_data_target=config.generator.dst_path,
                gan_dis_source_data=config.discriminator.dis_source_data,
                gan_dis_positive_data=config.discriminator.dis_positive_data,
                num=config.generate_num,
                reshuf=True)

            logging.info("8. Generate Negative Dataset for Discriminator !!!")
            # Generate the negative dataset
            generator.generate_and_save(data_util=du,
                                        infile=config.discriminator.dis_source_data,
                                        generate_batch=config.discriminator.dis_batch_size,
                                        outfile=config.discriminator.dis_negative_data)
            logging.info("9. Negative Dataset was saved into %s." % config.discriminator.dis_negative_data)

            logging.info("10. Finetune the discriminator begin !!!!!")
            discriminator.train(max_epoch=config.gan_dis_iter_num,
                                positive_data=config.discriminator.dis_positive_data,
                                negative_data=config.discriminator.dis_negative_data,
                                source_data=config.discriminator.dis_source_data)
            discriminator.saver.save(discriminator.sess, discriminator.saveto)
            logging.info("11. Finetune the discriminator done !!!!")

        logging.info('Reinforcement training done!')

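
# -----------------------------------------------------------------------------------------
# Compact sketch (an assumption based on the SeqGAN-style rollout setup above, not the exact
# get_reward implementation) of how per-step rewards are typically estimated: complete each
# prefix `rollnum` times with the rollout generator and average the discriminator scores.
def _rollout_rewards_sketch(y_sample, rollnum, complete_fn, disc_score_fn):
    """y_sample: [batch, T] sampled ids; complete_fn(prefix, T) -> full [batch, T] rollout;
    disc_score_fn(seq) -> [batch] probability that seq is human-written."""
    batch_size, T = y_sample.shape
    rewards = np.zeros((batch_size, T), dtype=np.float32)
    for t in range(1, T + 1):
        for _ in range(rollnum):
            rollout = complete_fn(y_sample[:, :t], T)   # finish the sentence from prefix y_<=t
            rewards[:, t - 1] += disc_score_fn(rollout)
    return rewards / rollnum
# Assumed usage: rewards = _rollout_rewards_sketch(y_sample_dealed, config.rollnum,
# rollout_generator_fn, discriminator_score_fn), analogous to generator.get_reward above.
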