def main():
    """Cooperative Training (CoT) driver on synthetic oracle data.

    Pipeline: sample "real" data from a fixed oracle LSTM, optionally
    MLE-pretrain the generator, then alternate generator updates (rewarded
    by a mediator network) with mediator likelihood updates, logging
    oracle-NLL, test-NLL and JSD to files under save/.

    Relies on module-level constants (SEED, BATCH_SIZE, SEQ_LENGTH, EMB_DIM,
    HIDDEN_DIM, START_TOKEN, M_DROPOUT_RATE, PRE_EPOCH_NUM, TOTAL_BATCH,
    generated_num, positive_file, negative_file, eval_file) and helpers
    (generate_samples, mle_epoch, target_loss, jsd_calculate) defined
    elsewhere in this project.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0
    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    val_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)  # For testing
    vocab_size = 5000
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    # Fixed-parameter oracle LSTM defines the "true" data distribution.
    target_params = pickle.load(open('save/target_params_py3.pkl', 'rb'))
    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, 32, 32, SEQ_LENGTH,
                              START_TOKEN, target_params)  # The oracle model
    # Mediator: a second Generator with doubled batch/embedding/hidden sizes,
    # so it consumes 2*BATCH_SIZE sequences per call (real + generated).
    mediator = Generator(vocab_size, BATCH_SIZE * 2, EMB_DIM * 2,
                         HIDDEN_DIM * 2, SEQ_LENGTH, START_TOKEN,
                         name="mediator", dropout_rate=M_DROPOUT_RATE)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    # First, use the oracle model to provide the positive examples, which are
    # sampled from the oracle data distribution.
    generate_samples(sess, target_lstm, BATCH_SIZE, generated_num,
                     positive_file)
    gen_data_loader.create_batches(positive_file)
    generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, eval_file)
    val_data_loader.create_batches(eval_file)
    log = open('save/experiment-log.txt', 'w')
    log_nll = open('save/experiment-log-nll.txt', 'w')
    log_jsd = open('save/experiment-log-jsd.txt', 'w')
    # pre-train generator (default 0 epochs)(not recommended)
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        loss = mle_epoch(sess, generator, gen_data_loader)
        if epoch % 1 == 0:  # evaluate every epoch
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            likelihood_data_loader.create_batches(negative_file)
            # NLL of generator samples under the oracle (lower = closer to oracle).
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print('pre-train epoch ', epoch, 'nll_oracle ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll_oracle:\t' + str(
                test_loss) + '\n'
            log_nll.write(buffer)
        if epoch % 1 == 0:
            # NLL the generator assigns to held-out oracle data.
            test_loss = target_loss(sess, generator, val_data_loader)
            print('pre-train epoch ', epoch, 'nll_test ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll_test:\t' + str(
                test_loss) + '\n'
            log_nll.write(buffer)
    print(
        '#########################################################################'
    )
    print('Start Cooperative Training...')
    for iter_idx in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(2):
            samples = generator.generate(sess)
            # Mediator expects a 2*BATCH_SIZE batch, so the sample batch is
            # duplicated; only the first half of the rewards is used below.
            rewards = mediator.get_reward(
                sess, np.concatenate([samples, samples], axis=0))
            feed = {
                generator.x: samples,
                generator.rewards: rewards[0:BATCH_SIZE]
            }
            _ = sess.run(generator.g_updates, feed_dict=feed)
        # Test
        if iter_idx % 100 == 0 or iter_idx == TOTAL_BATCH - 1:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            likelihood_data_loader.create_batches(negative_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'batch:\t' + str(iter_idx) + '\tnll_oracle:\t' + str(
                test_loss) + '\n'
            print('batch: ', iter_idx, 'nll_oracle: ', test_loss)
            log_nll.write(buffer)
        if iter_idx % 100 == 0:
            test_loss = target_loss(sess, generator, val_data_loader)
            print('batch:\t', iter_idx, 'nll_test ', test_loss)
            buffer = 'batch:\t' + str(iter_idx) + '\tnll_test:\t' + str(
                test_loss) + '\n'
            log_nll.write(buffer)
        # Train the mediator
        for _ in range(1):
            bnll_ = []
            collected_x = []
            ratio = 2  # one real batch + one generated batch per collection
            for it in range(ratio):
                if it % 2 == 0:
                    x_batch = gen_data_loader.next_batch()  # real data
                else:
                    x_batch = generator.generate(sess)  # generated data
                collected_x.append(x_batch)
            # Shuffle real and generated sequences together, then regroup
            # into mediator-sized batches of 2*BATCH_SIZE.
            collected_x = np.reshape(collected_x, [-1, SEQ_LENGTH])
            np.random.shuffle(collected_x)
            collected_x = np.reshape(collected_x,
                                     [-1, BATCH_SIZE * 2, SEQ_LENGTH])
            for it in range(1):
                feed = {
                    mediator.x: collected_x[it],
                }
                bnll = sess.run(mediator.likelihood_loss, feed)
                bnll_.append(bnll)
                # sess.run(mediator.dropout_on)
                _ = sess.run(mediator.likelihood_updates, feed)
                # sess.run(mediator.dropout_off)
            # NOTE(review): the `* 4` factor in this logging cadence is not
            # explained anywhere in this block — presumably tied to the number
            # of mediator steps per iteration; confirm before changing.
            if (iter_idx * 4) % gen_data_loader.num_batch == 0:
                bnll = np.mean(bnll_)
                gnll = sess.run(
                    mediator.likelihood_loss,
                    feed_dict={
                        mediator.x:
                        np.reshape(
                            [generator.generate(sess),
                             generator.generate(sess)],
                            [BATCH_SIZE * 2, SEQ_LENGTH])
                    })
                print("mediator cooptrain iter#%d, balanced_nll %f, g_nll %f" %
                      (iter_idx, bnll, gnll))
                log.write("%d\t%f\n" % (iter_idx, bnll))
        if iter_idx % gen_data_loader.num_batch == 0:
            # Jensen-Shannon divergence between generator and oracle, once
            # per effective epoch.
            jsd = jsd_calculate(sess, generator, target_lstm)
            print('cooptrain epoch#', iter_idx // gen_data_loader.num_batch,
                  'jsd ', jsd)
            log_jsd.write("%d\t%f\n" %
                          (iter_idx // gen_data_loader.num_batch, jsd))
    log.close()
    log_nll.close()
    log_jsd.close()
def main(unused_argv):
    """SeqGAN training driver: MLE-pretrain the generator, pretrain the
    discriminator, then run adversarial training with rollout-based rewards.

    Args:
        unused_argv: ignored; present for tf.app.run()-style entry points.

    Side effects: reads save/target_params.pkl and the configured data files,
    writes generated samples and save/experiment-log.txt. Depends on
    module-level config factories (training_config, generator_config,
    discriminator_config), model classes and helpers (generate_samples,
    target_loss, rollout, StrToBytes) plus `config_hardware`, all defined
    elsewhere in this project.
    """
    config_train = training_config()
    config_gen = generator_config()
    config_dis = discriminator_config()
    np.random.seed(config_train.seed)
    assert config_train.start_token == 0
    # Build dataloaders for generator pretraining, testing and discriminator
    gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    dis_data_loader = Dis_dataloader(config_dis.dis_batch_size)
    # Build generator and its rollout network
    generator = Generator(config=config_gen)
    generator.build()
    rollout_gen = rollout(config=config_gen)
    # Build target LSTM (the oracle model) from a Python-2 era pickle;
    # StrToBytes adapts the text-mode stream for py3 byte decoding.
    target_params = cPickle.load(StrToBytes(open('save/target_params.pkl')),
                                 encoding='bytes')
    target_lstm = TARGET_LSTM(config=config_gen, params=target_params)
    # Build discriminator
    discriminator = Discriminator(config=config_dis)
    discriminator.build_discriminator()
    # Build optimizer op for pretraining, with global-norm gradient clipping.
    pretrained_optimizer = tf.train.AdamOptimizer(
        config_train.gen_learning_rate)
    # Using name 'teller' here to prevent name collision with the target LSTM.
    var_pretrained = [
        v for v in tf.trainable_variables() if 'teller' in v.name
    ]
    gradients, variables = zip(*pretrained_optimizer.compute_gradients(
        generator.pretrained_loss, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    gen_pre_update = pretrained_optimizer.apply_gradients(
        zip(gradients, variables))
    # Initialize all variables
    sess = tf.Session(config=config_hardware)
    sess.run(tf.global_variables_initializer())
    # Initialize data loader of generator. The positive file is assumed to
    # already exist on disk (oracle sampling is intentionally disabled here).
    # generate_samples(sess, target_lstm, config_train.batch_size, config_train.generated_num, config_train.positive_file)
    gen_data_loader.create_batches(config_train.positive_file)
    # Start pretraining
    log = open('save/experiment-log.txt', 'w')
    print('Start pre-training generator...')
    log.write('pre-training...\n')
    for epoch in range(config_train.pretrained_epoch_num):
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            batch = gen_data_loader.next_batch()
            # Supervised MLE step on real data; mask of ones means every
            # position contributes to the loss.
            _, g_loss = sess.run([gen_pre_update, generator.pretrained_loss],
                                 feed_dict={generator.input_seqs_pre: batch,
                                            generator.input_seqs_mask:
                                            np.ones_like(batch)})
        if epoch % config_train.test_per_epoch == 0:
            # generate_samples(sess, generator, config_train.batch_size, config_train.generated_num, config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print('pre-train epoch ', epoch, 'test_loss ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(
                test_loss) + '\n'
            log.write(buffer)
    print('Start pre-training discriminator...')
    for t in range(config_train.dis_update_time_pre):
        print("Times: " + str(t))
        # Fresh negative samples from the generator each round, mixed with
        # the fixed positive file.
        generate_samples(sess, generator, config_train.batch_size,
                         config_train.generated_num,
                         config_train.negative_file)
        dis_data_loader.load_train_data(config_train.positive_file,
                                        config_train.negative_file)
        for _ in range(config_train.dis_update_epoch_pre):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob:
                    config_dis.dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)
    # Build optimizer op for adversarial training (same clipping scheme).
    train_adv_opt = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    gradients, variables = zip(*train_adv_opt.compute_gradients(
        generator.gen_loss_adv, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    train_adv_update = train_adv_opt.apply_gradients(zip(gradients, variables))
    # Initialize optimizer slot variables created after the global init above.
    uninitialized_var = [
        e for e in tf.global_variables() if e not in tf.trainable_variables()
    ]
    init_vars_uninit_op = tf.variables_initializer(uninitialized_var)
    sess.run(init_vars_uninit_op)
    # Start adversarial training
    for total_batch in range(config_train.total_batch):
        for iter_gen in range(config_train.gen_update_time):
            samples = sess.run(generator.sample_word_list_reshape)
            feed = {"pred_seq_rollout:0": samples}
            reward_rollout = []
            # Calculate the reward for each specific step t by rollout.
            for iter_roll in range(config_train.rollout_num):
                rollout_list = sess.run(rollout_gen.sample_rollout_step,
                                        feed_dict=feed)
                # shape: #batch_size * #rollout_step, #sequence_length
                rollout_list_stack = np.vstack(rollout_list)
                reward_rollout_seq = sess.run(
                    discriminator.ypred_for_auc,
                    feed_dict={
                        discriminator.input_x: rollout_list_stack,
                        discriminator.dropout_keep_prob: 1.0
                    })
                reward_last_tok = sess.run(discriminator.ypred_for_auc,
                                           feed_dict={
                                               discriminator.input_x: samples,
                                               discriminator.dropout_keep_prob:
                                               1.0
                                           })
                # Column 1 = probability of "real"; append full-sequence score
                # as the reward for the final token.
                reward_allseq = np.concatenate(
                    (reward_rollout_seq, reward_last_tok), axis=0)[:, 1]
                reward_tmp = []
                # De-interleave: rewards for sequence r sit at stride
                # gen_batch_size in the stacked result.
                for r in range(config_gen.gen_batch_size):
                    reward_tmp.append(reward_allseq[range(
                        r,
                        config_gen.gen_batch_size * config_gen.sequence_length,
                        config_gen.gen_batch_size)])
                reward_rollout.append(np.array(reward_tmp))
            # Average the rewards over all rollouts.
            rewards = np.sum(reward_rollout, axis=0) / config_train.rollout_num
            _, gen_loss = sess.run([train_adv_update, generator.gen_loss_adv],
                                   feed_dict={generator.input_seqs_adv: samples,
                                              generator.rewards: rewards})
        if total_batch % config_train.test_per_epoch == 0 or total_batch == config_train.total_batch - 1:
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(
                test_loss) + '\n'
            print('total_batch: ', total_batch, 'test_loss: ', test_loss)
            log.write(buffer)
        # Discriminator update phase of the adversarial loop.
        for _ in range(config_train.dis_update_time_adv):
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.negative_file)
            dis_data_loader.load_train_data(config_train.positive_file,
                                            config_train.negative_file)
            for _ in range(config_train.dis_update_epoch_adv):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob:
                        config_dis.dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
    log.close()
def main(unused_argv):
    """SeqGAN training driver (annotated variant): pretrain generator and
    discriminator, then adversarial training with rollout rewards.

    Modernized from Python-2-only syntax (`print` statements, `xrange`) to
    Python 3 for consistency with the rest of this file; behavior is
    otherwise unchanged. Comments translated to English.

    Args:
        unused_argv: ignored; present for tf.app.run()-style entry points.
    """
    config_train = training_config()
    config_gen = generator_config()
    config_dis = discriminator_config()
    np.random.seed(config_train.seed)
    assert config_train.start_token == 0
    # Build dataloaders for generator pretraining, testing and discriminator
    gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    dis_data_loader = Dis_dataloader(config_dis.dis_batch_size)
    # Build generator and its rollout network. The rollout network completes
    # partially generated sequences so per-step rewards can be estimated.
    generator = Generator(config=config_gen)
    generator.build()
    rollout_gen = rollout(config=config_gen)
    # Build target LSTM (the oracle model).
    # NOTE(review): this text-mode pickle load is Python-2 style; under
    # Python 3 it needs open(..., 'rb') and an encoding argument — confirm
    # which interpreter this variant is actually run with.
    target_params = cPickle.load(open('save/target_params.pkl'))
    target_lstm = TARGET_LSTM(config=config_gen, params=target_params)
    # Build discriminator
    discriminator = Discriminator(config=config_dis)
    discriminator.build_discriminator()
    # Build optimizer op for pretraining
    pretrained_optimizer = tf.train.AdamOptimizer(
        config_train.gen_learning_rate)
    # Collect all 'teller' variables (shared by generator and rollout nets);
    # the name prevents collision with the target LSTM's variables.
    var_pretrained = [
        v for v in tf.trainable_variables() if 'teller' in v.name
    ]
    gradients, variables = zip(*pretrained_optimizer.compute_gradients(
        generator.pretrained_loss, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    gen_pre_upate = pretrained_optimizer.apply_gradients(
        zip(gradients, variables))
    # Initialize all variables
    sess = tf.Session(config=config_hardware)
    sess.run(tf.global_variables_initializer())
    # Sample "real" data from the oracle target_lstm into positive_file, then
    # build the generator's pretraining batches from it (see utils.py).
    generate_samples(sess, target_lstm, config_train.batch_size,
                     config_train.generated_num, config_train.positive_file)
    gen_data_loader.create_batches(config_train.positive_file)
    # Start pretraining
    log = open('save/experiment-log.txt', 'w')
    print('Start pre-training generator...')
    log.write('pre-training...\n')
    for epoch in range(config_train.pretrained_epoch_num):
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            # Real (oracle) batches; supervised MLE pretraining of the generator.
            batch = gen_data_loader.next_batch()
            _, g_loss = sess.run([gen_pre_upate, generator.pretrained_loss],
                                 feed_dict={generator.input_seqs_pre: batch,
                                            generator.input_seqs_mask:
                                            np.ones_like(batch)})
        if epoch % config_train.test_per_epoch == 0:
            # Evaluate: how likely are the generator's samples under the oracle.
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print('pre-train epoch ', epoch, 'test_loss ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(
                test_loss) + '\n'
            log.write(buffer)
    print('Start pre-training discriminator...')
    for t in range(config_train.dis_update_time_pre):
        print("Times: " + str(t))
        # Mix generator fakes with oracle positives to train the discriminator,
        # which later supplies rewards to the generator.
        generate_samples(sess, generator, config_train.batch_size,
                         config_train.generated_num,
                         config_train.negative_file)
        dis_data_loader.load_train_data(config_train.positive_file,
                                        config_train.negative_file)
        for _ in range(config_train.dis_update_epoch_pre):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob:
                    config_dis.dis_dropout_keep_prob
                }
                # Cross-entropy minimization of the scoring network.
                _ = sess.run(discriminator.train_op, feed)
    # Build optimizer op for adversarial training
    train_adv_opt = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    gradients, variables = zip(*train_adv_opt.compute_gradients(
        generator.gen_loss_adv, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    train_adv_update = train_adv_opt.apply_gradients(zip(gradients, variables))
    # Initialize global variables of the optimizer for adversarial training
    # (slot variables created after the earlier global init).
    uninitialized_var = [
        e for e in tf.global_variables() if e not in tf.trainable_variables()
    ]
    init_vars_uninit_op = tf.variables_initializer(uninitialized_var)
    sess.run(init_vars_uninit_op)
    # Start adversarial training
    for total_batch in range(config_train.total_batch):
        for iter_gen in range(config_train.gen_update_time):
            # Sample complete sequences from the generator's LSTM.
            samples = sess.run(generator.sample_word_list_reshape)
            feed = {"pred_seq_rollout:0": samples}
            reward_rollout = []
            # Calculate the reward for each specific step t by rollout.
            for iter_roll in range(config_train.rollout_num):
                rollout_list = sess.run(rollout_gen.sample_rollout_step,
                                        feed_dict=feed)
                # shape: #batch_size * #rollout_step, #sequence_length
                rollout_list_stack = np.vstack(rollout_list)
                # Monte-Carlo completions scored by the discriminator.
                reward_rollout_seq = sess.run(
                    discriminator.ypred_for_auc,
                    feed_dict={
                        discriminator.input_x: rollout_list_stack,
                        discriminator.dropout_keep_prob: 1.0
                    })
                reward_last_tok = sess.run(discriminator.ypred_for_auc,
                                           feed_dict={
                                               discriminator.input_x: samples,
                                               discriminator.dropout_keep_prob:
                                               1.0
                                           })
                # Column 1 = P(real); full-sequence score is the last token's reward.
                reward_allseq = np.concatenate(
                    (reward_rollout_seq, reward_last_tok), axis=0)[:, 1]
                reward_tmp = []
                for r in range(config_gen.gen_batch_size):
                    reward_tmp.append(reward_allseq[range(
                        r,
                        config_gen.gen_batch_size * config_gen.sequence_length,
                        config_gen.gen_batch_size)])
                reward_rollout.append(np.array(reward_tmp))
            # Average rollout rewards, then take a policy-gradient step.
            rewards = np.sum(reward_rollout, axis=0) / config_train.rollout_num
            _, gen_loss = sess.run([train_adv_update, generator.gen_loss_adv],
                                   feed_dict={generator.input_seqs_adv: samples,
                                              generator.rewards: rewards})
        if total_batch % config_train.test_per_epoch == 0 or total_batch == config_train.total_batch - 1:
            # After adversarial updates, compare fresh generator samples
            # against the oracle (target_lstm) again.
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(
                test_loss) + '\n'
            print('total_batch: ', total_batch, 'test_loss: ', test_loss)
            log.write(buffer)
        for _ in range(config_train.dis_update_time_adv):
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.negative_file)
            dis_data_loader.load_train_data(config_train.positive_file,
                                            config_train.negative_file)
            for _ in range(config_train.dis_update_epoch_adv):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob:
                        config_dis.dis_dropout_keep_prob
                    }
                    # Train the scoring network.
                    _ = sess.run(discriminator.train_op, feed)
    log.close()
def main():
    """Reward-driven text GAN on SST/IMDB data.

    Pretrains a generator on combined IMDB+SST data and a CNN discriminator
    on real-vs-generated samples, then alternates reward-weighted generator
    updates with discriminator retraining. Depends on module-level constants
    (BATCH_SIZE, EMB_DIM, HIDDEN_DIM, MAX_SEQ_LENGTH, TOTAL_BATCH, file ids,
    dis_* hyperparameters) and helpers (load_emb_data, pre_train_epoch,
    generate_samples, generate_infer, produce_samples, build_from_ids)
    defined elsewhere in this project.
    """
    # load embedding info
    vocab_dict, vocab_size, vocab_list = load_emb_data(emb_dict_file)
    # prepare data: pretraining uses IMDB + SST(pos/neg); adversarial phase
    # draws "good" batches from SST only.
    pre_train_data_loader = Gen_Data_loader(BATCH_SIZE, vocab_dict)
    pre_train_data_loader.create_batches(
        [imdb_file_id, sst_pos_file_id, sst_neg_file_id])
    gen_data_loader = Gen_Data_loader(BATCH_SIZE, vocab_dict)
    gen_data_loader.create_batches([sst_pos_file_id, sst_neg_file_id])
    dis_data_loader = Dis_Data_loader(BATCH_SIZE, vocab_dict, MAX_SEQ_LENGTH)
    # build model
    # num_emb, vocab_dict, batch_size, emb_dim, num_units, sequence_length
    generator = Generator(vocab_size, vocab_dict, BATCH_SIZE, EMB_DIM,
                          HIDDEN_DIM, MAX_SEQ_LENGTH)
    discriminator = Discriminator(sequence_length=MAX_SEQ_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    log = open('save/experiment-log.txt', 'w')
    buffer = 'Start pre-training generator...'
    print(buffer)
    log.write(buffer + '\n')
    for epoch in range(150):  #120
        train_loss = pre_train_epoch(sess, generator, pre_train_data_loader)
        if epoch % 5 == 0:
            generate_samples(sess,
                             generator,
                             1,
                             eval_file,
                             vocab_list,
                             if_log=True,
                             epoch=epoch)
            print(' pre-train epoch ', epoch, 'train_loss ', train_loss)
            buffer = ' epoch:\t' + str(epoch) + '\tnll:\t' + str(
                train_loss) + '\n'
            log.write(buffer)
    buffer = 'Start pre-training discriminator...'
    print(buffer)
    log.write(buffer)
    for _ in range(10):  # 10
        # 70 generated batches of negatives vs. SST positives per round.
        generate_samples(sess, generator, 70, negative_file, vocab_list)
        dis_data_loader.load_train_data([sst_pos_file_id, sst_neg_file_id],
                                        [negative_file])
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob,
                }
                d_loss, d_acc, _ = sess.run([
                    discriminator.loss, discriminator.accuracy,
                    discriminator.train_op
                ], feed)
        buffer = "discriminator loss %f acc %f" % (d_loss, d_acc)
        print(buffer)
        log.write(buffer + '\n')
    print("Start Adversarial Training...")
    log.write('adversarial training...')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator
        for it in range(2):
            # print("1")
            samples = generator.generate(sess)
            samples = produce_samples(samples)
            # print("2")
            rewards = generator.get_reward(sess, samples, 16, discriminator)
            # print("3")
            # Log one decoded sample with its ids and rewards for inspection.
            a = str(samples[0])
            b = str(rewards[0])
            # rewards = change_rewards(rewards)
            # c = str(rewards[0])
            d = build_from_ids(samples[0], vocab_list)
            buffer = "%s\n%s\n%s\n\n" % (d, a, b)
            print(buffer)
            log.write(buffer)
            # print("4")
            rewards_loss = generator.update_with_rewards(
                sess, samples, rewards)
            # print("5")
            # good rewards
            # good_samples = gen_data_loader.next_batch()
            # rewards = np.array([[0.0001] * SEQ_LENGTH] * BATCH_SIZE)
            # a = str(good_samples[0])
            # b = str(rewards[0])
            # buffer = "%s\n%s\n\n" % (a, b)
            # print(buffer)
            # log.write(buffer)
            # rewards_loss = generator.update_with_rewards(sess, good_samples, rewards, START_TOKEN)
            # little1 good reward: also reinforce on real SST batches scored
            # by the discriminator.
            little1_samples = gen_data_loader.next_batch()
            rewards = generator.get_reward(sess, little1_samples, 16,
                                           discriminator)
            a = str(little1_samples[0])
            b = str(rewards[0])
            buffer = "%s\n%s\n\n" % (a, b)
            # print(buffer)
            log.write(buffer)
            rewards_loss = generator.update_with_rewards(
                sess, little1_samples, rewards)
        # generate_infer(sess, generator, epoch, vocab_list)
        # Test
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess,
                             generator,
                             120,
                             eval_file,
                             vocab_list,
                             if_log=True)
            generate_infer(sess, generator, total_batch, vocab_list)
            # generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            # likelihood_data_loader.create_batches(eval_file)
            # test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'reward-train epoch %s train loss %s' % (
                str(total_batch), str(rewards_loss))
            print(buffer)
            log.write(buffer + '\n')
            generator.save_model(sess)
        # Train the discriminator; `begin` gates the metrics log so only the
        # first inner epoch of each eval round is reported.
        begin = True
        for _ in range(1):
            generate_samples(sess, generator, 70, negative_file, vocab_list)
            dis_data_loader.load_train_data([sst_pos_file_id, sst_neg_file_id],
                                            [negative_file])
            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob,
                    }
                    d_loss, d_acc, _ = sess.run([
                        discriminator.loss, discriminator.accuracy,
                        discriminator.train_op
                    ], feed)
                if (total_batch % 5 == 0
                        or total_batch == TOTAL_BATCH - 1) and begin:
                    buffer = "discriminator loss %f acc %f\n" % (d_loss, d_acc)
                    print(buffer)
                    log.write(buffer)
                    begin = False
        # pretrain: interleave supervised MLE epochs to keep the generator
        # anchored to the data distribution.
        for _ in range(10):
            train_loss = pre_train_epoch(sess, generator,
                                         pre_train_data_loader)
def main():
    """Image-conditioned sequence GAN (caption-style) training driver.

    Loads or builds a word<->id dictionary, extracts image features with a
    pretrained MobileNet, pretrains a BLEU-CNN discriminator and the
    generator, then runs REINFORCE training with rollout rewards; finally
    decodes one batch of samples into text. Depends on module-level constants
    (SEED, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN,
    PRE_EPOCH_NUM, RL_EPOCH_NUM, DICO, DICO_PKL, CORPUS, IMAGE) and helpers
    (create_dico, create_data) defined elsewhere in this project.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    # Cache the dictionaries as a pickle so they are built only once.
    if os.path.exists(DICO_PKL):
        with open(DICO_PKL, 'rb') as f:
            word_to_id, id_to_word = pickle.load(f)
    else:
        word_to_id, id_to_word = create_dico(DICO)
        with open(DICO_PKL, 'wb') as f:
            pickle.dump([word_to_id, id_to_word], f)
    gen_data_loader = Gen_Data_loader(BATCH_SIZE, word_to_id)
    dis_data_loader = Dis_Data_loader(BATCH_SIZE, word_to_id)
    vocab_size = len(word_to_id)
    assert START_TOKEN == word_to_id['sos']
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    discriminator = BLEUCNN(SEQ_LENGTH, 2, EMB_DIM, generator)
    mobilenet = MobileNet(BATCH_SIZE)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    # MobileNet weights are loaded before the global init; presumably the
    # loader guards its variables from re-initialization — confirm.
    mobilenet.load_pretrained_weights(sess)
    sess.run(tf.global_variables_initializer())
    log = open('experiment-log.txt', 'w', encoding='utf-8')
    # pre-train generator and discriminator
    log.write('pre-training...\n')
    print('Start pre-training discriminator...')
    datas = create_data(DICO, word_to_id)
    gen_data_loader.create_batches(CORPUS, IMAGE)
    # Generate one pass of samples conditioned on MobileNet image features.
    samples = []
    for it in range(gen_data_loader.num_batch):
        inp_batch, image_batch = gen_data_loader.next_batch()
        feed_dict = {mobilenet.X: image_batch, mobilenet.is_training: False}
        hidden_batch = sess.run(mobilenet.y_output, feed_dict=feed_dict)
        samples.extend(generator.generate(sess, hidden_batch).tolist())
    # 3000 real sequences vs. the generated samples.
    dis_data_loader.create_batches(random.sample(datas, 3000), samples)
    for _ in range(PRE_EPOCH_NUM):
        dis_data_loader.reset_pointer()
        for it in range(dis_data_loader.num_batch):
            x_batch, labels = dis_data_loader.next_batch()
            feed = {
                discriminator.input_x: x_batch,
                discriminator.labels: labels,
                discriminator.dropout_keep_prob: 0.75
            }
            _ = sess.run(discriminator.train_op, feed)
    print('Start pre-training generator...')
    for epoch in range(PRE_EPOCH_NUM):
        supervised_g_losses = []
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            inp_batch, image_batch = gen_data_loader.next_batch()
            feed_dict = {
                mobilenet.X: image_batch,
                mobilenet.is_training: False
            }
            hidden_batch = sess.run(mobilenet.y_output, feed_dict=feed_dict)
            # Supervised MLE step conditioned on the image features.
            _, g_loss = generator.pretrain_step(sess, inp_batch, hidden_batch)
            supervised_g_losses.append(g_loss)
        loss = np.mean(supervised_g_losses)
        if epoch % 5 == 0:
            print('pre-train epoch ', epoch, 'train_loss ', loss)
            buffer = 'epoch:\t' + str(epoch) + '\ttrain_loss:\t' + str(
                loss) + '\n'
            log.write(buffer)
    # Rollout network with update rate 0.8 for reward estimation.
    rollout = ROLLOUT(generator, 0.8)
    print(
        '#########################################################################'
    )
    print('Start REINFORCE Training...')
    log.write('REINFORCE training...\n')
    for total_batch in range(RL_EPOCH_NUM):
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            # Randomly toggle batch shuffling each step.
            ra = random.randint(0, 1)
            inp_batch, image_batch = gen_data_loader.next_batch(shuffle=ra)
            feed_dict = {
                mobilenet.X: image_batch,
                mobilenet.is_training: False
            }
            hidden_batch = sess.run(mobilenet.y_output, feed_dict=feed_dict)
            samples = generator.generate(sess, hidden_batch)
            rewards = rollout.get_reward(sess, samples, hidden_batch, 16,
                                         discriminator)
            feed = {
                generator.x: inp_batch,
                generator.rewards: rewards,
                generator.hiddens: hidden_batch
            }
            _ = sess.run(generator.g_updates, feed_dict=feed)
        # Test
        if total_batch % 5 == 0 or total_batch == RL_EPOCH_NUM - 1:
            # Mean last-step reward over the whole data as the progress metric.
            mean_rewards = []
            gen_data_loader.reset_pointer()
            for it in range(gen_data_loader.num_batch):
                inp_batch, image_batch = gen_data_loader.next_batch()
                feed_dict = {
                    mobilenet.X: image_batch,
                    mobilenet.is_training: False
                }
                hidden_batch = sess.run(mobilenet.y_output,
                                        feed_dict=feed_dict)
                samples = generator.generate(sess, hidden_batch)
                rewards = rollout.get_reward(sess, samples, hidden_batch, 16,
                                             discriminator)
                mean_rewards.append(np.mean(rewards[:, -1]))
            reward = np.mean(mean_rewards)
            buffer = 'epoch:\t' + str(total_batch) + '\treward:\t' + str(
                reward) + '\n'
            print('total_batch: ', total_batch, 'reward: ', reward)
            log.write(buffer)
            generator.save_weight(sess)
        # Update roll-out parameters
        rollout.update_params()
        discriminator.update_embedding()
        # Train the discriminator on fresh samples each epoch.
        samples = []
        for it in range(gen_data_loader.num_batch):
            inp_batch, image_batch = gen_data_loader.next_batch()
            feed_dict = {
                mobilenet.X: image_batch,
                mobilenet.is_training: False
            }
            hidden_batch = sess.run(mobilenet.y_output, feed_dict=feed_dict)
            samples.extend(generator.generate(sess, hidden_batch).tolist())
        dis_data_loader.create_batches(random.sample(datas, 3000), samples)
        dis_data_loader.reset_pointer()
        for it in range(dis_data_loader.num_batch):
            x_batch, labels = dis_data_loader.next_batch()
            feed = {
                discriminator.input_x: x_batch,
                discriminator.labels: labels,
                discriminator.dropout_keep_prob: 0.75
            }
            _ = sess.run(discriminator.train_op, feed)
    # final test: decode one generated batch to text and append it to the log.
    gen_data_loader.reset_pointer()
    _, image_batch = gen_data_loader.next_batch()
    feed_dict = {mobilenet.X: image_batch, mobilenet.is_training: False}
    hidden_batch = sess.run(mobilenet.y_output, feed_dict=feed_dict)
    samples = generator.generate(sess, hidden_batch)
    y = samples.tolist()
    sams = []
    for k, sam in enumerate(y):
        sa = [id_to_word[i] for i in sam]
        sa = ''.join(sa)
        sams.append(sa)
    for sam in sams:
        log.write(sam + '\n')
    log.close()
TEST BEGIN @3.29 TEST 1 @4.18 ''' print( '#########################################################################' ) print('Start Adversarial Training...') log.write('adversarial training...\n') sampel_log = open('save/sample-log.txt', 'w') gen_data_loader.reset_pointer() for total_batch in range(TOTAL_BATCH): # Train the generator for one step samples = None for it in range(5): batch, ques_len = gen_data_loader.next_batch() samples = generator.generate(sess, batch, ques_len) rewards = get_reward(sess, samples, 16, generator, discriminator) # print("rewards sample: ", rewards[0]) feed = { generator.x: samples, generator.rewards: rewards, generator.target_sequence_length: ques_len, generator.max_sequence_length_per_batch: max(ques_len) } _, g_loss = sess.run([generator.g_updates, generator.g_loss], feed_dict=feed) buffer = 'epoch:\t' + str(total_batch) + '\tg_loss:\t' + str(g_loss) + '\n' print('total_batch: ', total_batch, 'g_loss: ', g_loss) log.write(buffer)
def main():
    """Conditional VAE-GAN text training driver with checkpointing.

    Pretrains a VAE-style generator (lstm/recon/KL losses) on real data,
    pretrains a CNN discriminator on real-vs-generated sequences, then runs
    SeqGAN-style adversarial training with rollout rewards. Depends on
    module-level constants (SEED, vocab_size, condition_size, FEATURE_NUM,
    BATCH_SIZE, EMB_DIM, COND_DIM, HIDDEN_DIM, Z_DIM, SEQ_LENGTH,
    START_TOKEN, PRE_EPOCH_NUM, TOTAL_BATCH, file paths, dis_* settings) and
    helpers (pre_train_epoch, generate_samples, get_ckpt) defined elsewhere
    in this project.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    # assert START_TOKEN == 0
    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    dis_data_loader = Dis_dataloader(BATCH_SIZE)
    generator = Generator(vocab_size,
                          condition_size,
                          FEATURE_NUM,
                          BATCH_SIZE,
                          EMB_DIM,
                          COND_DIM,
                          HIDDEN_DIM,
                          Z_DIM,
                          SEQ_LENGTH,
                          START_TOKEN,
                          vocab_file,
                          condition_file,
                          word_vec=word_vec)
    # NOTE(review): sequence_length is hard-coded to 20 here rather than
    # SEQ_LENGTH — confirm the two agree in this configuration.
    discriminator = Discriminator(sequence_length=20,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    # Checkpoint: resume from the latest checkpoint if one exists.
    saver = tf.train.Saver()
    ckpt = get_ckpt(ckpt_dir)
    if ckpt is not None:
        print("Load checkpoints from: ", ckpt)
        saver.restore(sess, ckpt)
    # Load true data
    gen_data_loader.create_batches(positive_file)
    log = open('save/experiment-log.txt', 'w')
    # pre-train generator
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        g_loss, lstm_loss, recon_loss, kl_loss = pre_train_epoch(
            sess, generator, gen_data_loader)
        if epoch % 10 == 0:
            log.write(
                'pre-train epoch %d, g_loss: %f, lstm_loss: %f, recon_loss: %f, kl_loss: %f\n'
                % (epoch, g_loss, lstm_loss, recon_loss, kl_loss))
            print(
                'pre-train epoch %d, g_loss: %f, lstm_loss: %f, recon_loss: %f, kl_loss: %f'
                % (epoch, g_loss, lstm_loss, recon_loss, kl_loss))
            generate_samples(sess, generator, gen_data_loader, BATCH_SIZE,
                             generated_num, eval_file)
        if epoch % 20 == 0:
            saver.save(sess, os.path.join(ckpt_dir,
                                          'checkpoint_' + str(epoch)))
    print('Start pre-training discriminator...')
    # Train 3 epoch on the generated data and do this for 50 times
    for _ in range(50):
        generate_samples(sess, generator, gen_data_loader, BATCH_SIZE,
                         generated_num, negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)
    # Rollout network with update rate 0.8 for reward estimation.
    rollout = ROLLOUT(generator, 0.8)
    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step; generation is conditioned on a
        # real batch (conditional model).
        for it in range(1):
            samples = generator.generate(sess, gen_data_loader.next_batch())
            rewards = rollout.get_reward(sess, samples, 16, discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)
        # # Test
        # if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
        #     generate_samples(sess, generator, gen_data_loader, BATCH_SIZE, generated_num, eval_file)
        # Update roll-out parameters
        rollout.update_params()
        # Train the discriminator
        for _ in range(5):
            generate_samples(sess, generator, gen_data_loader, BATCH_SIZE,
                             generated_num, negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)
            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
    log.close()
def main():
    """Cooperative Training (CoT) of a generator and a Mediator on real text.

    Builds the vocabulary from the train/validation corpora, writes
    integer-coded data files, optionally pre-trains the generator with MLE,
    then alternates one generator policy-gradient step (rewards supplied by
    the mediator) with one mediator likelihood step per iteration.
    """
    print('program start')
    from utils.text_process import text_precess, text_to_code  # TODO: move?
    from utils.text_process import get_tokenlized, get_word_list, get_dict
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    # Derive sequence length and vocabulary size from the corpora.
    SEQ_LENGTH, vocab_size = text_precess(true_file, val_file)
    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    # NOTE(review): gan_data_loader is only read by code that is currently
    # disabled; kept so re-enabling that code needs no further changes.
    gan_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    val_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)

    # Tokenize, build the dictionaries, and write integer-coded files.
    tokens = get_tokenlized(true_file)
    val_tokens = get_tokenlized(val_file)
    word_set = get_word_list(tokens + val_tokens)
    [word_index_dict, index_word_dict] = get_dict(word_set)
    with open(oracle_file, 'w') as outfile:
        outfile.write(text_to_code(tokens, word_index_dict, SEQ_LENGTH))
    with open(val_oracle_file, 'w') as outfile:
        outfile.write(text_to_code(val_tokens, word_index_dict, SEQ_LENGTH))

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    mediator = Mediator(vocab_size, BATCH_SIZE, EMB_DIM * 2, HIDDEN_DIM * 2,
                        SEQ_LENGTH, START_TOKEN, name="mediator",
                        dropout_rate=M_DROPOUT_RATE, learning_rate=3e-3,
                        with_professor_forcing=False)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    gen_data_loader.create_batches(oracle_file)
    gan_data_loader.create_batches(oracle_file)
    val_data_loader.create_batches(val_oracle_file)

    log = open('save/experiment-log.txt', 'w')
    log_nll = open('save/experiment-log-nll.txt', 'w')

    # pre-train generator (default 0 epochs)(not recommended)
    print('Start pre-training...')
    log.write('pre-training...\n')
    saver = tf.train.Saver(tf.global_variables())
    if RESTORE:
        saver.restore(sess, "saved_model/CoT")
    for epoch in range(PRE_EPOCH_NUM):
        loss = mle_epoch(sess, generator, gen_data_loader)
        # BUGFIX: the oracle-NLL evaluation that used to run here referenced
        # `target_lstm`, whose construction is commented out above, so any
        # pre-training epoch raised NameError. There is no oracle for real
        # text data; only the validation NLL is reported now.
        if epoch % 1 == 0:
            test_loss = target_loss(sess, generator, val_data_loader)
            print('pre-train epoch ', epoch, 'nll_test ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll_test:\t' + str(
                test_loss) + '\n'
            log_nll.write(buffer)

    print(
        '#########################################################################'
    )
    toc = time.time()
    print('Start Cooperative Training...')
    for iter_idx in range(TOTAL_BATCH):
        print('iteration: ' + str(iter_idx) + '\ntime: ' +
              str(time.time() - toc))
        toc = time.time()
        # Train the generator for one step.
        for it in range(1):
            samples = generator.generate(sess)
            rewards = mediator.get_reward(sess, samples)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)
        # Report validation NLL once per epoch (measured in training batches)
        # and checkpoint the model.
        if iter_idx % gen_data_loader.num_batch == 0:
            test_loss = target_loss(sess, generator, val_data_loader)
            print('epoch:\t', iter_idx // gen_data_loader.num_batch,
                  'nll_test ', test_loss)
            buffer = 'epoch:\t' + str(
                iter_idx // gen_data_loader.num_batch) + '\tnll_test:\t' + str(
                    test_loss) + '\n'
            log_nll.write(buffer)
            saver.save(sess, "saved_model/CoT")
        # Train the mediator on one real batch and one generated batch.
        for _ in range(1):
            bnll_ = []
            for it in range(1):
                feed = {
                    mediator.x0: gen_data_loader.next_batch(),
                    mediator.x1: generator.generate(sess)
                }
                bnll = sess.run(mediator.likelihood_loss, feed)
                bnll_.append(bnll)
                # Dropout is enabled only for the mediator's update step.
                sess.run(mediator.dropout_on)
                _ = sess.run(mediator.likelihood_updates, feed)
                sess.run(mediator.dropout_off)
            if iter_idx % 10 == 0:
                bnll = np.mean(bnll_)
                print("mediator cooptrain iter#%d, balanced_nll %f" %
                      (iter_idx, bnll))
                log.write("%d\t%f\n" % (iter_idx, bnll))
    log.close()
    log_nll.close()
def main(unused_argv):
    """SeqGAN training against a synthetic oracle (TARGET_LSTM).

    Stages: (1) MLE pre-training of the generator on oracle samples,
    (2) discriminator pre-training on real-vs-generated data,
    (3) adversarial training where generator rewards come from Monte-Carlo
    rollouts scored by the discriminator; progress is measured as the
    oracle NLL of generated samples and appended to save/experiment-log.txt.
    """
    config_train = training_config()
    config_gen = generator_config()
    config_dis = discriminator_config()
    np.random.seed(config_train.seed)
    assert config_train.start_token == 0
    gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    dis_data_loader = Dis_dataloader(config_dis.dis_batch_size)
    generator = Generator(config=config_gen)
    generator.build()
    rollout_gen = rollout(config=config_gen)
    #Build target LSTM
    # NOTE(review): pickle.load on a file is safe only for trusted data, and
    # the open() handle is never closed -- consider a `with` block.
    target_params = pickle.load(open('save/target_params.pkl','rb'),encoding='iso-8859-1')
    target_lstm = TARGET_LSTM(config=config_gen, params=target_params) # The oracle model
    # Build discriminator
    discriminator = Discriminator(config=config_dis)
    discriminator.build_discriminator()
    # Build optimizer op for pretraining ('teller' scopes the generator vars)
    pretrained_optimizer = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    var_pretrained = [v for v in tf.trainable_variables() if 'teller' in v.name]
    gradients, variables = zip(
        *pretrained_optimizer.compute_gradients(generator.pretrained_loss, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    gen_pre_update = pretrained_optimizer.apply_gradients(zip(gradients, variables))
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    # Positive examples are drawn from the oracle's distribution.
    generate_samples(sess,target_lstm,config_train.batch_size,config_train.generated_num,config_train.positive_file)
    gen_data_loader.create_batches(config_train.positive_file)
    log = open('save/experiment-log.txt','w')
    print('Start pre-training generator....')
    log.write('pre-training...\n')
    for epoch in range(config_train.pretrained_epoch_num):
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            batch = gen_data_loader.next_batch()
            # Full-length mask: every position contributes to the MLE loss.
            _,g_loss = sess.run([gen_pre_update,generator.pretrained_loss],feed_dict={generator.input_seqs_pre:batch,
                                                                                     generator.input_seqs_mask:np.ones_like(batch)})
        if epoch % config_train.test_per_epoch == 0:
            # Run a test: generate a batch of sequences with the Generator,
            generate_samples(sess,generator,config_train.batch_size,config_train.generated_num,config_train.eval_file)
            # build a data loader over these sequences,
            likelihood_data_loader.create_batches(config_train.eval_file)
            # and use the oracle to compute the cross-entropy loss (NLL).
            test_loss = target_loss(sess,target_lstm,likelihood_data_loader)
            # Print and write to the log.
            print('pre-train ',epoch, ' test_loss ',test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
            log.write(buffer)
    print('Start pre-training discriminator...')
    # NOTE(review): the inner epoch loop reuses dis_update_time_pre as its
    # count as well -- confirm a separate per-round epoch count wasn't intended.
    for t in range(config_train.dis_update_time_pre):
        print("Times: " + str(t))
        generate_samples(sess,generator,config_train.batch_size,config_train.generated_num,config_train.negative_file)
        dis_data_loader.load_train_data(config_train.positive_file,config_train.negative_file)
        for _ in range(config_train.dis_update_time_pre):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch,y_batch = dis_data_loader.next_batch()
                feed_dict = {
                    discriminator.input_x : x_batch,
                    discriminator.input_y : y_batch,
                    discriminator.dropout_keep_prob : config_dis.dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op,feed_dict)
    # Build optimizer op for adversarial training
    train_adv_opt = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    gradients, variables = zip(*train_adv_opt.compute_gradients(generator.gen_loss_adv, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    train_adv_update = train_adv_opt.apply_gradients(zip(gradients, variables))
    # Initialize global variables of optimizer for adversarial training
    # (only the optimizer slots; trained weights keep their values).
    uninitialized_var = [e for e in tf.global_variables() if e not in tf.trainable_variables()]
    init_vars_uninit_op = tf.variables_initializer(uninitialized_var)
    sess.run(init_vars_uninit_op)
    # Start adversarial training
    for total_batch in range(config_train.total_batch):
        for iter_gen in range(config_train.gen_update_time):
            samples = sess.run(generator.sample_word_list_reshpae)
            feed = {'pred_seq_rollout:0':samples}
            reward_rollout = []
            for iter_roll in range(config_train.rollout_num):
                rollout_list = sess.run(rollout_gen.sample_rollout_step,feed_dict=feed)
                # np.vstack stacks the arrays vertically (row-wise).
                rollout_list_stack = np.vstack(rollout_list)
                # Discriminator scores for every rollout prefix...
                reward_rollout_seq = sess.run(discriminator.ypred_for_auc,feed_dict={
                    discriminator.input_x:rollout_list_stack,discriminator.dropout_keep_prob:1.0
                })
                # ...and for the completed sequences themselves.
                reward_last_tok = sess.run(discriminator.ypred_for_auc,feed_dict={
                    discriminator.input_x:samples,discriminator.dropout_keep_prob:1.0
                })
                # Column 1 is the probability of being classed as real.
                reward_allseq = np.concatenate((reward_rollout_seq,reward_last_tok),axis=0)[:,1]
                reward_tmp = []
                # Regroup the flat reward vector into per-example, per-timestep rows.
                for r in range(config_gen.gen_batch_size):
                    reward_tmp.append(reward_allseq[range(r,config_gen.gen_batch_size * config_gen.sequence_length,config_gen.gen_batch_size)])
                reward_rollout.append(np.array(reward_tmp))
            # Average rewards over the rollout_num Monte-Carlo rollouts.
            rewards = np.sum(reward_rollout,axis = 0) / config_train.rollout_num
            _,gen_loss = sess.run([train_adv_update,generator.gen_loss_adv],feed_dict={generator.input_seqs_adv:samples,
                                                                                       generator.rewards:rewards})
        if total_batch % config_train.test_per_epoch == 0 or total_batch == config_train.total_batch - 1:
            # Periodic oracle-NLL evaluation of the current generator.
            generate_samples(sess, generator, config_train.batch_size, config_train.generated_num, config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
            print ('total_batch: ', total_batch, 'test_loss: ', test_loss)
            log.write(buffer)
        # Discriminator updates; see NOTE above about the reused loop count.
        for _ in range(config_train.dis_update_time_adv):
            generate_samples(sess,generator,config_train.batch_size,config_train.generated_num,config_train.negative_file)
            dis_data_loader.load_train_data(config_train.positive_file,config_train.negative_file)
            for _ in range(config_train.dis_update_time_adv):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch,y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x:x_batch,
                        discriminator.input_y:y_batch,
                        discriminator.dropout_keep_prob:config_dis.dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op,feed)
    log.close()
def main():
    """Cooperative Training (CoT) on real text, mediator = 2x-size Generator.

    Builds the vocabulary from the train/validation corpora, writes
    integer-coded data files, optionally pre-trains the generator with MLE,
    then alternates generator policy-gradient steps (rewards from the
    mediator) with mediator likelihood updates on a shuffled half-real,
    half-generated batch. Validation NLL goes to save/experiment-log-nll.txt.
    """
    print('program start')
    from utils.text_process import text_precess, text_to_code  # TODO: move?
    from utils.text_process import get_tokenlized, get_word_list, get_dict
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0
    # Sequence length and vocabulary size are derived from the corpora.
    SEQ_LENGTH, vocab_size = text_precess(true_file, val_file)
    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    val_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    # Create training file and dicts
    tokens = get_tokenlized(true_file)
    val_tokens = get_tokenlized(val_file)
    word_set = get_word_list(tokens + val_tokens)
    [word_index_dict, index_word_dict] = get_dict(word_set)
    with open(oracle_file, 'w') as outfile:
        outfile.write(text_to_code(tokens, word_index_dict, SEQ_LENGTH))
    with open(val_oracle_file, 'w') as outfile:
        outfile.write(text_to_code(val_tokens, word_index_dict, SEQ_LENGTH))
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    #target_params = pickle.load(open('save/target_params_py3.pkl', 'rb'))
    #target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, 32, 32, SEQ_LENGTH, START_TOKEN, target_params)  # The oracle model
    # replace target lstm with true data
    # The mediator is a double-capacity Generator fed double-size batches.
    mediator = Generator(vocab_size, BATCH_SIZE * 2, EMB_DIM * 2,
                         HIDDEN_DIM * 2, SEQ_LENGTH, START_TOKEN,
                         name="mediator", dropout_rate=M_DROPOUT_RATE)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    gen_data_loader.create_batches(oracle_file)
    val_data_loader.create_batches(val_oracle_file)
    log = open('save/experiment-log.txt', 'w')
    log_nll = open('save/experiment-log-nll.txt', 'w')
    # pre-train generator (default 0 epochs)(not recommended)
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        loss = mle_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             generator_file)
            #get_real_test_file(index_word_dict, generator_file, test_file) # only needed in debugging
            test_loss = target_loss(sess, generator, val_data_loader)
            print('pre-train epoch ', epoch, 'nll_test ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll_test:\t' + str(
                test_loss) + '\n'
            log_nll.write(buffer)
    print(
        '#########################################################################'
    )
    toc = time.time()
    print('Start Cooperative Training...')
    for iter_idx in range(TOTAL_BATCH):
        print('iteration: ' + str(iter_idx) + '\ntime: ' +
              str(time.time() - toc))
        toc = time.time()
        # Train the generator for one step
        for it in range(1):
            samples = generator.generate(sess)
            # The mediator consumes 2*BATCH_SIZE rows, so the sample batch is
            # duplicated and only the first half of the rewards is used.
            rewards = mediator.get_reward(
                sess, np.concatenate([samples, samples], axis=0))
            feed = {
                generator.x: samples,
                generator.rewards: rewards[0:BATCH_SIZE]
            }
            loss, _ = sess.run([generator.g_loss, generator.g_updates],
                               feed_dict=feed)
        # Test, removed oracle test; validation NLL once per epoch
        # (epoch measured in units of training batches).
        if iter_idx % gen_data_loader.num_batch == 0:
            test_loss = target_loss(sess, generator, val_data_loader)
            print('epoch:\t', iter_idx // gen_data_loader.num_batch,
                  'nll_test ', test_loss)
            buffer = 'epoch:\t' + str(
                iter_idx // gen_data_loader.num_batch) + '\tnll_test:\t' + str(
                    test_loss) + '\n'
            log_nll.write(buffer)
        if iter_idx == TOTAL_BATCH - 1:
            # Final sample dump, decoded back to words.
            print('generating samples')
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             generator_file)
            get_real_test_file(index_word_dict, generator_file, test_file)
        # Train the mediator
        for _ in range(1):
            print('training mediator...')
            bnll_ = []
            collected_x = []
            ratio = 2
            # Collect `ratio` batches, alternating real data and generator
            # samples, then shuffle them together into 2*BATCH_SIZE chunks.
            for it in range(ratio):
                if it % 2 == 0:
                    x_batch = gen_data_loader.next_batch()
                else:
                    x_batch = generator.generate(sess)
                collected_x.append(x_batch)
            collected_x = np.reshape(collected_x, [-1, SEQ_LENGTH])
            np.random.shuffle(collected_x)
            collected_x = np.reshape(collected_x,
                                     [-1, BATCH_SIZE * 2, SEQ_LENGTH])
            for it in range(1):
                feed = {
                    mediator.x: collected_x[it],
                }
                print('running bnll sess')
                bnll = sess.run(mediator.likelihood_loss, feed)
                bnll_.append(bnll)
                print('running mediator and updating')
                # Dropout is enabled only around the mediator's update step.
                sess.run(mediator.dropout_on)
                _ = sess.run(mediator.likelihood_updates, feed)
                sess.run(mediator.dropout_off)
            if iter_idx % 50 == 0:
                bnll = np.mean(bnll_)
                print("mediator cooptrain iter#%d, balanced_nll %f" %
                      (iter_idx, bnll))
                log.write("%d\t%f\n" % (iter_idx, bnll))
    log.close()
    log_nll.close()
def main():
    """Cooperative Training (CoT) on a real text corpus without an oracle.

    Tokenizes the training set, writes an integer-coded data file, then
    alternates one generator policy-gradient step (rewards from the mediator)
    with one mediator likelihood step per iteration. NLL metrics are appended
    to timestamped log files under save/.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0
    SEQ_LENGTH, vocab_size = text_precess(train_set)
    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)  # For testing
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    # BUGFIX: removed `target_params = pickle.load(open('save/target_params_py3.pkl', 'rb'))`
    # -- its only consumer (the TARGET_LSTM oracle) is commented out below, so
    # the load was dead work that also leaked the open file handle.
    #target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, 32, 32, SEQ_LENGTH, START_TOKEN, target_params)  # The oracle model
    # The mediator is a double-capacity Generator fed double-size batches.
    mediator = Generator(vocab_size, BATCH_SIZE*2, EMB_DIM*2, HIDDEN_DIM*2,
                         SEQ_LENGTH, START_TOKEN, name="mediator",
                         dropout_rate=M_DROPOUT_RATE)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    # create training set indices
    tokens = get_tokenlized(train_set)
    word_set = get_word_list(tokens)
    [word_index_dict, index_word_dict] = get_dict(word_set)
    with open(positive_file, 'w') as outfile:
        outfile.write(text_to_code(tokens, word_index_dict, SEQ_LENGTH))
    # create and load batches from index training set
    gen_data_loader.create_batches(positive_file)
    # Timestamped logs so repeated runs do not overwrite each other.
    log = open('save/experiment-log' + str(time()) + '.txt', 'w')
    log_nll = open('save/experiment-log-nll' + str(time()) + '.txt', 'w')
    print('#########################################################################')
    print('Start Cooperative Training...')
    print('Num batches: ' + str(gen_data_loader.num_batch))
    for iter_idx in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(1):
            print('Training G')
            samples = generator.generate(sess)
            # The mediator consumes 2*BATCH_SIZE rows, so the batch is
            # duplicated and only the first half of the rewards is used.
            rewards = mediator.get_reward(sess, np.concatenate([samples, samples], axis=0))
            feed = {generator.x: samples, generator.rewards: rewards[0:BATCH_SIZE]}
            loss, _ = sess.run([generator.g_loss, generator.g_updates], feed_dict=feed)
        if iter_idx % gen_data_loader.num_batch == 0:
            print('cooptrain epoch#', iter_idx // gen_data_loader.num_batch)
            print('loss: ' + str(loss))
        # Test, removed oracle
        if iter_idx % 100 == 0 or iter_idx == TOTAL_BATCH - 1:
            print('Generating fake samples')
            generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
            likelihood_data_loader.create_batches(negative_file)
            print('Calculating NLL')
            # NOTE(review): NLL is computed over gen_data_loader, which also
            # feeds the mediator below (shared pointer) -- Texygen does the
            # same, but confirm a dedicated validation loader isn't wanted.
            test_loss = target_loss(sess, generator, gen_data_loader)
            print('batch:\t', iter_idx, 'nll_test ', test_loss)
            buffer = 'batch:\t'+ str(iter_idx) + '\tnll_test:\t' + str(test_loss) + '\n'
            log_nll.write(buffer)
        # Train the mediator
        for _ in range(1):
            print('Training M')
            bnll_ = []
            collected_x = []
            ratio = 2
            # Collect `ratio` batches, alternating real data and generator
            # samples, then shuffle them together into 2*BATCH_SIZE chunks.
            for it in range(ratio):
                if it % 2 == 0:
                    x_batch = gen_data_loader.next_batch()
                else:
                    x_batch = generator.generate(sess)
                collected_x.append(x_batch)
            collected_x = np.reshape(collected_x, [-1, SEQ_LENGTH])
            np.random.shuffle(collected_x)
            collected_x = np.reshape(collected_x, [-1, BATCH_SIZE*2, SEQ_LENGTH])
            for it in range(1):
                print('Calculating BNLL')
                feed = {
                    mediator.x: collected_x[it],
                }
                bnll = sess.run(mediator.likelihood_loss, feed)
                bnll_.append(bnll)
                # sess.run(mediator.dropout_on)
                _ = sess.run(mediator.likelihood_updates, feed)
                # sess.run(mediator.dropout_off)
            if (iter_idx * 4) % gen_data_loader.num_batch == 0:
                print('Calculating likelihood loss for M')
                bnll = np.mean(bnll_)
                # Mediator NLL on purely generated data, for comparison.
                gnll = sess.run(mediator.likelihood_loss,
                                feed_dict={mediator.x: np.reshape(
                                    [generator.generate(sess), generator.generate(sess)],
                                    [BATCH_SIZE*2, SEQ_LENGTH])})
                print("mediator cooptrain iter#%d, balanced_nll %f, g_nll %f" % (iter_idx, bnll, gnll))
                log.write("%d\t%f\n" % (iter_idx, bnll))
    log.close()
    log_nll.close()