def main():
    """Train a SeqGAN-style generator against a fixed oracle LSTM.

    Pipeline: (1) sample 'real' data from the oracle `TARGET_LSTM`,
    (2) MLE-pretrain the generator, (3) pretrain the discriminator on
    real-vs-generated data, (4) adversarial training with policy-gradient
    rewards. NLL against the oracle is logged to `save/experiment-log.txt`
    and `learning_cure.txt`.

    Relies on module-level globals (SEED, BATCH_SIZE, EMB_DIM, ...) and
    helpers (generate_samples, target_loss, ...) defined elsewhere.
    """
    random.seed(SEED)
    np.random.seed(SEED)

    # prepare data loaders (likelihood loader is only used for evaluation)
    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # For testing
    dis_data_loader = Dis_Data_loader(BATCH_SIZE)

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    # target_params's size: [15 * 5000 * 32]
    # FIX: use a context manager so the pickle file handle is closed.
    with open('./save/target_params_py3.pkl', 'rb') as params_file:
        target_params = pickle.load(params_file)
    # The oracle model that defines the "true" data distribution.
    target_lstm = TARGET_LSTM(5000, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                              20, 0, target_params)
    discriminator = Discriminator(sequence_length=20,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # Positive examples are drawn from the oracle.
    generate_samples_from_target(sess, target_lstm, BATCH_SIZE,
                                 generated_num, positive_file)
    gen_data_loader.create_batches(positive_file)

    log = open('save/experiment-log.txt', 'w')
    # pre-train generator with maximum likelihood
    print('Start pre-training...')
    log.write('pre-training...\n')
    ans_file = open("learning_cure.txt", 'w')
    for epoch in range(120):
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 1 == 0:  # evaluate every epoch
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm,
                                    likelihood_data_loader)
            print('pre-train epoch ', epoch, 'test_loss ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
            log.write(buffer)
            ans_file.write("%s\n" % str(test_loss))

    buffer = 'Start pre-training discriminator...'
    print(buffer)
    log.write(buffer)
    # Train 3 epochs on each freshly generated negative set, 10 rounds.
    for _ in range(10):
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob,
                }
                d_loss, d_acc, _ = sess.run(
                    [discriminator.loss, discriminator.accuracy,
                     discriminator.train_op], feed)
        # NOTE(review): reports the loss/accuracy of the last batch only.
        buffer = "discriminator loss %f acc %f\n" % (d_loss, d_acc)
        print(buffer)
        log.write(buffer)

    ans_file.write("==========\n")

    print("Start Adversarial Training...")
    log.write('adversarial training...')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator with policy-gradient rewards.
        for it in range(1):
            samples = generator.generate(sess)
            rewards = generator.get_reward(sess, samples, 16, discriminator,
                                           START_TOKEN)
            a = str(samples[0])
            b = str(rewards[0])
            buffer = "%s\n%s\n\n" % (a, b)
            log.write(buffer)
            rewards_loss = generator.update_with_rewards(sess, samples,
                                                         rewards, START_TOKEN)

        # Test against the oracle every batch (and at the very end).
        if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm,
                                    likelihood_data_loader)
            buffer = 'reward-train epoch %s train loss %s test_loss %s\n' % (
                str(total_batch), str(rewards_loss), str(test_loss))
            print(buffer)
            log.write(buffer)
            ans_file.write("%s\n" % str(test_loss))

        # Train the discriminator on fresh negatives.
        for _ in range(1):
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)
            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob,
                    }
                    d_loss, d_acc, _ = sess.run(
                        [discriminator.loss, discriminator.accuracy,
                         discriminator.train_op], feed)
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            buffer = "discriminator loss %f acc %f\n" % (d_loss, d_acc)
            print(buffer)
            log.write(buffer)

    # FIX: close log files (they were leaked in the original).
    log.close()
    ans_file.close()
def main():
    """Train a text GAN on real sentiment data (IMDB + SST).

    Pipeline: load embedding vocabulary, MLE-pretrain the generator on
    IMDB+SST, pretrain the discriminator on real-vs-generated text, then
    alternate adversarial updates, discriminator updates, and a short
    teacher-forcing phase each round. Progress goes to
    `save/experiment-log.txt`.

    Relies on module-level globals (BATCH_SIZE, EMB_DIM, file ids, ...) and
    helpers (load_emb_data, produce_samples, build_from_ids, ...) defined
    elsewhere.
    """
    # load embedding info
    vocab_dict, vocab_size, vocab_list = load_emb_data(emb_dict_file)

    # prepare data: pretraining uses IMDB + SST, adversarial uses SST only
    pre_train_data_loader = Gen_Data_loader(BATCH_SIZE, vocab_dict)
    pre_train_data_loader.create_batches(
        [imdb_file_id, sst_pos_file_id, sst_neg_file_id])
    gen_data_loader = Gen_Data_loader(BATCH_SIZE, vocab_dict)
    gen_data_loader.create_batches([sst_pos_file_id, sst_neg_file_id])
    dis_data_loader = Dis_Data_loader(BATCH_SIZE, vocab_dict, MAX_SEQ_LENGTH)

    # build model
    # num_emb, vocab_dict, batch_size, emb_dim, num_units, sequence_length
    generator = Generator(vocab_size, vocab_dict, BATCH_SIZE, EMB_DIM,
                          HIDDEN_DIM, MAX_SEQ_LENGTH)
    discriminator = Discriminator(sequence_length=MAX_SEQ_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    log = open('save/experiment-log.txt', 'w')

    buffer = 'Start pre-training generator...'
    print(buffer)
    log.write(buffer + '\n')
    for epoch in range(150):
        train_loss = pre_train_epoch(sess, generator, pre_train_data_loader)
        if epoch % 5 == 0:
            generate_samples(sess, generator, 1, eval_file, vocab_list,
                             if_log=True, epoch=epoch)
            print(' pre-train epoch ', epoch, 'train_loss ', train_loss)
            buffer = ' epoch:\t' + str(epoch) + '\tnll:\t' + str(
                train_loss) + '\n'
            log.write(buffer)

    buffer = 'Start pre-training discriminator...'
    print(buffer)
    log.write(buffer)
    # 10 rounds of fresh negatives, 3 epochs each.
    for _ in range(10):
        generate_samples(sess, generator, 70, negative_file, vocab_list)
        dis_data_loader.load_train_data([sst_pos_file_id, sst_neg_file_id],
                                        [negative_file])
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob,
                }
                d_loss, d_acc, _ = sess.run([
                    discriminator.loss, discriminator.accuracy,
                    discriminator.train_op
                ], feed)
        # NOTE(review): reports the loss/accuracy of the last batch only.
        buffer = "discriminator loss %f acc %f" % (d_loss, d_acc)
        print(buffer)
        log.write(buffer + '\n')

    print("Start Adversarial Training...")
    log.write('adversarial training...')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator: policy-gradient update on its own samples,
        # then a reward-weighted update on a batch of real data.
        for it in range(2):
            samples = generator.generate(sess)
            samples = produce_samples(samples)
            rewards = generator.get_reward(sess, samples, 16, discriminator)
            a = str(samples[0])
            b = str(rewards[0])
            d = build_from_ids(samples[0], vocab_list)
            buffer = "%s\n%s\n%s\n\n" % (d, a, b)
            print(buffer)
            log.write(buffer)
            rewards_loss = generator.update_with_rewards(
                sess, samples, rewards)

            # reward-weighted update on real samples ("little1 good reward")
            little1_samples = gen_data_loader.next_batch()
            rewards = generator.get_reward(sess, little1_samples, 16,
                                           discriminator)
            a = str(little1_samples[0])
            b = str(rewards[0])
            buffer = "%s\n%s\n\n" % (a, b)
            log.write(buffer)
            rewards_loss = generator.update_with_rewards(
                sess, little1_samples, rewards)

        # Test / checkpoint every 5 batches and at the very end.
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess, generator, 120, eval_file, vocab_list,
                             if_log=True)
            generate_infer(sess, generator, total_batch, vocab_list)
            buffer = 'reward-train epoch %s train loss %s' % (
                str(total_batch), str(rewards_loss))
            print(buffer)
            log.write(buffer + '\n')
            generator.save_model(sess)

        # Train the discriminator; log its stats once per logging round.
        begin = True
        for _ in range(1):
            generate_samples(sess, generator, 70, negative_file, vocab_list)
            dis_data_loader.load_train_data([sst_pos_file_id, sst_neg_file_id],
                                            [negative_file])
            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob,
                    }
                    d_loss, d_acc, _ = sess.run([
                        discriminator.loss, discriminator.accuracy,
                        discriminator.train_op
                    ], feed)
                    if (total_batch % 5 == 0
                            or total_batch == TOTAL_BATCH - 1) and begin:
                        buffer = "discriminator loss %f acc %f\n" % (d_loss,
                                                                     d_acc)
                        print(buffer)
                        log.write(buffer)
                        begin = False

        # Teacher-forcing phase: keep the generator close to real data.
        for _ in range(10):
            train_loss = pre_train_epoch(sess, generator,
                                         pre_train_data_loader)

    # FIX: close the log file (it was leaked in the original).
    log.close()
def main(unused_argv):
    """Run SeqGAN training driven by config objects, with tqdm progress.

    Pipeline: MLE-pretrain the generator against an oracle `TARGET_LSTM`,
    pretrain the discriminator, then adversarial training where per-step
    rewards come from Monte-Carlo rollouts scored by the discriminator.
    NLL against the oracle is written to `save/experiment-log.txt`.

    Relies on module-level config factories (training_config, ...) and
    helpers (generate_samples, target_loss, config_hardware, ...) defined
    elsewhere.
    """
    config_train = training_config()
    config_gen = generator_config()
    config_dis = discriminator_config()
    np.random.seed(config_train.seed)
    assert config_train.start_token == 0

    # Build dataloader for generator, testing and discriminator
    gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    dis_data_loader = Dis_Data_loader(config_dis.dis_batch_size)

    # Build generator and its rollout
    generator = Generator(config=config_gen)
    generator.build()
    rollout_gen = Rollout(config=config_gen)

    # Build target LSTM (the oracle model)
    # FIX: use a context manager so the pickle file handle is closed.
    with open('save/target_params_py3.pkl', 'rb') as params_file:
        target_params = pickle.load(params_file)
    target_lstm = TARGET_LSTM(config=config_gen, params=target_params)

    # Build discriminator
    discriminator = Discriminator(config=config_dis)
    discriminator.build_discriminator()

    # Build optimizer op for pretraining. Variable name 'teller' is used
    # to prevent name collision with the target LSTM's variables.
    pretrained_optimizer = tf.train.AdamOptimizer(
        config_train.gen_learning_rate)
    var_pretrained = [v for v in tf.trainable_variables()
                      if 'teller' in v.name]
    gradients, variables = zip(*pretrained_optimizer.compute_gradients(
        generator.pretrain_loss, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    gen_pre_update = pretrained_optimizer.apply_gradients(
        zip(gradients, variables))

    # Initialize all variables
    sess = tf.Session(config=config_hardware)
    sess.run(tf.global_variables_initializer())

    # Initialize data loader of generator with oracle samples
    generate_samples(sess, target_lstm, config_train.batch_size,
                     config_train.generated_num, config_train.positive_file)
    gen_data_loader.create_batches(config_train.positive_file)

    # Start pretraining
    log = open('save/experiment-log.txt', 'w')
    print('Start pre-training generator...')
    log.write('pre-training...\n')
    for epoch in tqdm(range(config_train.pretrained_epoch_num),
                      desc='Pre-Training(Generator)'):
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            batch = gen_data_loader.next_batch()
            _, g_loss = sess.run(
                [gen_pre_update, generator.pretrain_loss],
                feed_dict={generator.input_seqs_pre: batch,
                           generator.input_seqs_mask: np.ones_like(batch)})
        if epoch % config_train.test_per_epoch == 0:
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm,
                                    likelihood_data_loader)
            loss_info = ('pre-train epoch ' + str(epoch)
                         + ' test_loss ' + str(test_loss))
            tqdm.write(loss_info)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
            log.write(buffer)

    print('Start pre-training discriminator...')
    for _ in tqdm(range(config_train.dis_update_time_pre),
                  desc='Pre-Training(Discriminator)'):
        generate_samples(sess, generator, config_train.batch_size,
                         config_train.generated_num,
                         config_train.negative_file)
        dis_data_loader.load_train_data(config_train.positive_file,
                                        config_train.negative_file)
        for _ in range(config_train.dis_update_epoch_pre):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob:
                        config_dis.dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    # Build optimizer op for adversarial training
    train_adv_opt = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    gradients, variables = zip(*train_adv_opt.compute_gradients(
        generator.gen_loss_adv, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    train_adv_update = train_adv_opt.apply_gradients(zip(gradients, variables))

    # Initialize global variables of optimizer for adversarial training
    uninitialized_var = [e for e in tf.global_variables()
                         if e not in tf.trainable_variables()]
    init_vars_uninit_op = tf.variables_initializer(uninitialized_var)
    sess.run(init_vars_uninit_op)

    # Start adversarial training
    for total_batch in tqdm(range(config_train.total_batch),
                            desc='Adversarial-Training'):
        for _ in tqdm(range(config_train.gen_update_time),
                      desc='Adversarial Generate Update'):
            samples = sess.run(generator.sample_word_list_reshape)
            feed = {"pred_seq_rollout:0": samples}
            reward_rollout = []
            # calculate the reward given at the specific step t by rollout
            for iter_roll in range(config_train.rollout_num):
                rollout_list = sess.run(rollout_gen.sample_rollout_step,
                                        feed_dict=feed)
                # shape: #batch_size * #rollout_step, #sequence_length
                rollout_list_stack = np.vstack(rollout_list)
                reward_rollout_seq = sess.run(
                    discriminator.ypred_for_auc,
                    feed_dict={discriminator.input_x: rollout_list_stack,
                               discriminator.dropout_keep_prob: 1.0})
                reward_last_tok = sess.run(
                    discriminator.ypred_for_auc,
                    feed_dict={discriminator.input_x: samples,
                               discriminator.dropout_keep_prob: 1.0})
                reward_allseq = np.concatenate(
                    (reward_rollout_seq, reward_last_tok), axis=0)[:, 1]
                reward_tmp = []
                # De-interleave per-example rewards across rollout steps.
                for r in range(config_gen.gen_batch_size):
                    reward_tmp.append(reward_allseq[range(
                        r,
                        config_gen.gen_batch_size * config_gen.sequence_length,
                        config_gen.gen_batch_size)])
                reward_rollout.append(np.array(reward_tmp))
            rewards = np.sum(reward_rollout, axis=0) / config_train.rollout_num
            _, gen_loss = sess.run(
                [train_adv_update, generator.gen_loss_adv],
                feed_dict={generator.input_seqs_adv: samples,
                           generator.rewards: rewards})

        if (total_batch % config_train.test_per_epoch == 0
                or total_batch == config_train.total_batch - 1):
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm,
                                    likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
            # FIX: original string lacked a separator before 'test_loss'.
            loss_info = ('total_batch: ' + str(total_batch)
                         + ' test_loss: ' + str(test_loss))
            tqdm.write(loss_info)
            log.write(buffer)

        for _ in tqdm(range(config_train.dis_update_time_adv),
                      desc='Adversarial Discriminator Update'):
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.negative_file)
            dis_data_loader.load_train_data(config_train.positive_file,
                                            config_train.negative_file)
            for _ in range(config_train.dis_update_epoch_adv):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob:
                            config_dis.dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
    log.close()
def main():
    """Original SeqGAN training loop (oracle LSTM + rollout rewards).

    Pipeline: sample positives from the oracle, MLE-pretrain the generator,
    pretrain the discriminator, then adversarial training where rewards come
    from a `ROLLOUT` of the generator scored by the discriminator.

    FIX: the original body used Python 2 `print` statements (a syntax error
    under Python 3, which the rest of this file targets) and loaded the
    pickle via an unclosed text-mode handle; both are corrected here.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # For testing
    vocab_size = 5000
    dis_data_loader = Dis_Data_loader(BATCH_SIZE)

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    # FIX: open in binary mode and close the handle (py3 pickle needs 'rb').
    with open('save/target_params.pkl', 'rb') as params_file:
        target_params = pickle.load(params_file)
    # The oracle model
    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                              SEQ_LENGTH, START_TOKEN, target_params)

    discriminator = Discriminator(sequence_length=20,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # First, use the oracle model to provide the positive examples, which
    # are sampled from the oracle data distribution
    generate_samples(sess, target_lstm, BATCH_SIZE, generated_num,
                     positive_file)
    gen_data_loader.create_batches(positive_file)

    log = open('save/experiment-log.txt', 'w')
    # pre-train generator
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm,
                                    likelihood_data_loader)
            print('pre-train epoch ', epoch, 'test_loss ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
            log.write(buffer)

    print('Start pre-training discriminator...')
    # Train 3 epoch on the generated data and do this for 50 times
    for _ in range(50):
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    rollout = ROLLOUT(generator, 0.8)

    print('#########################################################################')
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(1):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, 16, discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # Test against the oracle every 5 batches and at the very end.
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm,
                                    likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
            print('total_batch: ', total_batch, 'test_loss: ', test_loss)
            log.write(buffer)

        # Update roll-out parameters (exponential moving average of weights)
        rollout.update_params()

        # Train the discriminator
        for _ in range(5):
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)
            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

    log.close()