def pretrain(sess, generator, target_lstm, train_discriminator):
    """MLE pre-train the generator, then pre-train the discriminator.

    Args:
        sess: live tf.Session holding all model variables.
        generator: Generator model being pre-trained by MLE.
        target_lstm: oracle model used to score generated samples (NLL).
        train_discriminator: zero-arg closure running one discriminator pass.

    Uses module-level globals: positive_samples, likelihood_data_loader,
    train_samples, mm, ord_dict and the hyper-parameter constants
    (BATCH_SIZE, PRE_EPOCH_NUM, SAMPLE_NUM, dis_alter_epoch, PREFIX).
    """
    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    gen_data_loader.create_batches(positive_samples)
    results = OrderedDict({'exp_name': PREFIX})

    # pre-train generator
    print('Start pre-training...')
    start = time.time()
    for epoch in tqdm(range(PRE_EPOCH_NUM)):
        print(' gen pre-train')
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        # Periodically evaluate NLL under the oracle and the sample metrics.
        if epoch == 10 or epoch % 40 == 0:
            samples = generate_samples(sess, generator, BATCH_SIZE, SAMPLE_NUM)
            likelihood_data_loader.create_batches(samples)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print('\t test_loss {}, train_loss {}'.format(test_loss, loss))
            mm.compute_results(samples, train_samples, ord_dict, results)

    # FIX: removed two redundant post-loop sampling passes — one computed a
    # test_loss that was never reported, the other created batches that were
    # never consumed. They only burned compute and advanced the RNG.

    print('Start training discriminator...')
    for i in tqdm(range(dis_alter_epoch)):
        print(' discriminator pre-train')
        # FIX: do not unpack the return value — it is unused here, and
        # train_discriminator's D_WEIGHT == 0 early exit historically
        # returned a 2-tuple, which crashed a 3-way unpack.
        train_discriminator()

    end = time.time()
    print('Total time was {:.4f}s'.format(end - start))
    return
def main():
    """Entry point: build the models, restore or pre-train from checkpoints,
    then run adversarial (policy-gradient) training of the generator
    against the CNN discriminator.

    Relies on module-level globals for hyper-parameters (SEED, NUM_EMB,
    BATCH_SIZE, EMB_DIM, ...), data (positive_samples, train_samples),
    the run configuration dict ``params`` and helpers (generate_samples,
    make_reward, save_results, print_rewards, mm, ord_dict, ...).
    """
    random.seed(SEED)
    np.random.seed(SEED)
    # assert START_TOKEN == 0

    vocab_size = NUM_EMB
    dis_data_loader = Dis_dataloader()
    best_score = 1000  # lower oracle NLL is better; start above any real loss
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          MAX_LENGTH, START_TOKEN)
    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                              MAX_LENGTH, 0)

    with tf.variable_scope('discriminator'):
        cnn = TextCNN(sequence_length=MAX_LENGTH,
                      num_classes=2,
                      vocab_size=vocab_size,
                      embedding_size=dis_embedding_dim,
                      filter_sizes=dis_filter_sizes,
                      num_filters=dis_num_filters,
                      l2_reg_lambda=dis_l2_reg_lambda)
    cnn_params = [
        param for param in tf.trainable_variables()
        if 'discriminator' in param.name
    ]

    # Define Discriminator Training procedure
    dis_global_step = tf.Variable(0, name="global_step", trainable=False)
    dis_optimizer = tf.train.AdamOptimizer(1e-4)
    dis_grads_and_vars = dis_optimizer.compute_gradients(
        cnn.loss, cnn_params, aggregation_method=2)
    dis_train_op = dis_optimizer.apply_gradients(
        dis_grads_and_vars, global_step=dis_global_step)

    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    def train_discriminator():
        """Run one discriminator training pass on positive vs. generated
        samples; returns (loss, accuracy, mean ypred)."""
        if D_WEIGHT == 0:
            # FIX: was `return 0, 0`, which raised ValueError in every
            # caller that unpacks three values.
            return 0, 0, 0

        negative_samples = generate_samples(sess, generator, BATCH_SIZE,
                                            POSITIVE_NUM)

        # train discriminator
        dis_x_train, dis_y_train = dis_data_loader.load_train_data(
            positive_samples, negative_samples)
        # FIX: materialize the pairs — a bare zip() is a one-shot iterator
        # in Python 3 and would be exhausted after the first epoch when
        # batch_iter re-iterates for dis_num_epochs.
        dis_batches = dis_data_loader.batch_iter(
            list(zip(dis_x_train, dis_y_train)), dis_batch_size,
            dis_num_epochs)

        ypred = 0
        counter = 0
        for batch in dis_batches:
            x_batch, y_batch = zip(*batch)
            feed = {
                cnn.input_x: x_batch,
                cnn.input_y: y_batch,
                cnn.dropout_keep_prob: dis_dropout_keep_prob
            }
            _, step, loss, accuracy, ypred_for_auc = sess.run([
                dis_train_op, dis_global_step, cnn.loss, cnn.accuracy,
                cnn.ypred_for_auc
            ], feed)
            # Mean probability assigned to the "real" class for this batch.
            ypred_vect = np.array([item[1] for item in ypred_for_auc])
            ypred += np.mean(ypred_vect)
            counter += 1
        # FIX: guard the average against an empty batch iterator.
        ypred = float(ypred) / counter if counter else 0.0
        print('\tD loss : {}'.format(loss))
        print('\tAccuracy: {}'.format(accuracy))
        print('\tMean ypred: {}'.format(ypred))
        return loss, accuracy, ypred

    # Pretrain is checkpointed and only executes if we don't find a checkpoint.
    # FIX: the Saver was constructed twice back-to-back; once is enough.
    saver = tf.train.Saver()

    # Check for a previous (reinforcement) session to resume.
    prev_sess = False
    ckpt_dir = 'checkpoints/mingan'
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    # ckpt_file = os.path.join(ckpt_dir, ckpt_dir + '_model')  # old checkpoint
    ckpt_file = os.path.join(ckpt_dir, 'drd2_new' + '_model_')  # new checkpoint
    # Iterate over checkpoints to find the largest completed batch index.
    nbatches_max = 0
    for i in range(500):  # maximal number of batch iterations is 500
        if os.path.isfile(ckpt_file + str(i) + '.meta'):
            nbatches_max = i
    ckpt_file = ckpt_file + str(nbatches_max) + '.meta'

    # FIX: restored the file-existence check (it had been commented out),
    # which made the "no previous session data" branch below unreachable
    # and let saver.restore(sess, None) crash when no checkpoint exists.
    if params["LOAD_PREV_SESS"] and os.path.isfile(ckpt_file):
        saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
        print('Previous session loaded from previous checkpoint {}'.format(
            ckpt_file))
        prev_sess = True
    else:
        if params["LOAD_PREV_SESS"]:
            print('\t* No previous session data found as {:s}.'.format(
                ckpt_file))
        else:
            print('\t* LOAD_PREV_SESS was set to false.')

    if prev_sess == False:
        # No RL session to resume: check for a pre-training checkpoint.
        ckpt_dir = 'checkpoints/{}_pretrain'.format(PREFIX)
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)
        ckpt_file = os.path.join(ckpt_dir, 'pretrain_ckpt')
        if os.path.isfile(ckpt_file + '.meta') and params["LOAD_PRETRAIN"]:
            saver.restore(sess, ckpt_file)
            print('Pretrain loaded from previous checkpoint {}'.format(
                ckpt_file))
        else:
            if params["LOAD_PRETRAIN"]:
                print('\t* No pre-training data found as {:s}.'.format(
                    ckpt_file))
            else:
                print('\t* LOAD_PRETRAIN was set to false.')
            sess.run(tf.global_variables_initializer())
            pretrain(sess, generator, target_lstm, train_discriminator)
            path = saver.save(sess, ckpt_file)
            print('Pretrain finished and saved at {}'.format(path))
    # end loading previous session or pre-training

    # create reward function
    batch_reward = make_reward(train_samples)
    rollout = ROLLOUT(generator, 0.8)

    print(
        '#########################################################################'
    )
    print('Start Reinforcement Training Generator...')
    results_rows = []
    if nbatches_max + 1 > TOTAL_BATCH:
        print(
            ' We already trained that many batches: Check the Checkpoints folder or take a larger TOTAL_BATCH'
        )
    else:
        # Resume RL training right after the last checkpointed batch.
        for nbatch in tqdm(range(nbatches_max + 1, TOTAL_BATCH)):
            results = OrderedDict({'exp_name': PREFIX})
            if nbatch % 1 == 0 or nbatch == TOTAL_BATCH - 1:
                print('* Making samples')
                # Every 10th batch, draw a larger sample set for evaluation.
                if nbatch % 10 == 0:
                    gen_samples = generate_samples(sess, generator,
                                                   BATCH_SIZE, BIG_SAMPLE_NUM)
                else:
                    gen_samples = generate_samples(sess, generator,
                                                   BATCH_SIZE, SAMPLE_NUM)
                likelihood_data_loader.create_batches(gen_samples)
                test_loss = target_loss(sess, target_lstm,
                                        likelihood_data_loader)
                print('batch_num: {}'.format(nbatch))
                print('test_loss: {}'.format(test_loss))
                results['Batch'] = nbatch
                results['test_loss'] = test_loss
                if test_loss < best_score:
                    best_score = test_loss
                    print('best score: %f' % test_loss)
                # results
                mm.compute_results(gen_samples, train_samples, ord_dict,
                                   results)

            print(
                '#########################################################################'
            )
            print('-> Training generator with RL.')
            print('G Epoch {}'.format(nbatch))
            for it in range(TRAIN_ITER):
                samples = generator.generate(sess)
                rewards = rollout.get_reward(sess, samples, 16, cnn,
                                             batch_reward, D_WEIGHT)
                nll = generator.generator_step(sess, samples, rewards)
                # results
                print_rewards(rewards)
                print('neg-loglike: {}'.format(nll))
                results['neg-loglike'] = nll
            rollout.update_params()

            # generate for discriminator
            print('-> Training Discriminator')
            for i in range(D):
                print('D_Epoch {}'.format(i))
                d_loss, accuracy, ypred = train_discriminator()
                results['D_loss_{}'.format(i)] = d_loss
                results['Accuracy_{}'.format(i)] = accuracy
                results['Mean_ypred_{}'.format(i)] = ypred
            print('results')
            results_rows.append(results)
            if nbatch % params["EPOCH_SAVES"] == 0:
                save_results(sess, PREFIX, PREFIX + '_model_' + str(nbatch),
                             results_rows)

        # write results
        save_results(sess, PREFIX, PREFIX + '_model_' + str(nbatch),
                     results_rows)
    print('\n:*** FINISHED ***')
    return