def z_real_(bz):
    """Sample a batch of `bz` latent vectors from the prior selected by FLAGS.z_dist."""
    if FLAGS.z_dist == 'u':
        return uniform(bz, FLAGS.z_dim, 0, 2)
    elif FLAGS.z_dist == 'g':
        return gaussian(bz, FLAGS.z_dim, 0, 2)
    elif FLAGS.z_dist == 'mg':
        return gaussian_mixture(bz, FLAGS.z_dim, 10)
    else:
        # default: swiss-roll prior with 10 label types
        return swiss_roll(bz, FLAGS.z_dim, 10)
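# NOTE: the sampler implementations (uniform, gaussian, gaussian_mixture,
# swiss_roll) are defined elsewhere in this repo and are not shown here.
# The sketch below is a minimal NumPy version of a gaussian-mixture prior,
# assuming the signature gaussian_mixture(batch_size, z_dim, num_labels),
# z_dim >= 2, and component means arranged on a ring in the first two latent
# dimensions; the repo's actual sampler may differ.
import numpy as np

def gaussian_mixture_sketch(batch_size, z_dim, num_labels, var=0.5, radius=2.0):
    # pick one mixture component (one per class label) for each sample
    labels = np.random.randint(0, num_labels, size=batch_size)
    angles = 2.0 * np.pi * labels / num_labels
    # component means lie on a circle of the given radius in the first 2 dims
    means = np.zeros((batch_size, z_dim), dtype=np.float32)
    means[:, 0] = radius * np.cos(angles)
    means[:, 1] = radius * np.sin(angles)
    noise = np.random.normal(0, var, size=(batch_size, z_dim)).astype(np.float32)
    return means + noise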
def main(run_load_from_file=False):
    # load MNIST images
    images, labels = dataset.load_test_images()

    # config
    opt = Operation()
    opt.check_dir(config.ckpt_dir, is_restart=False)
    opt.check_dir(config.log_dir, is_restart=True)

    max_epoch = 510
    num_trains_per_epoch = 500
    batch_size_u = 100

    # build the training graph
    with tf.device(config.device):
        h = build_graph()

    sess_config = tf.ConfigProto(allow_soft_placement=True,
                                 log_device_placement=True)
    sess_config.gpu_options.allow_growth = True
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.9
    saver = tf.train.Saver(max_to_keep=2)

    with tf.Session(config=sess_config) as sess:
        # load from checkpoint or start a new session
        if run_load_from_file:
            saver.restore(sess, tf.train.latest_checkpoint(config.ckpt_dir))
            training_epoch_loss, _ = pickle.load(
                open(config.ckpt_dir + '/pickle.pkl', 'rb'))
        else:
            sess.run(tf.global_variables_initializer())
            training_epoch_loss = []  # records losses per epoch

        process = Process()
        for epoch in range(max_epoch):
            process.start_epoch(epoch, max_epoch)

            # learning rate (kept constant here)
            learning_rate = 0.0001

            # running sums of the losses within this epoch
            sum_loss_reconstruction = 0
            sum_loss_discriminator_z = 0
            sum_loss_discriminator_img = 0
            sum_loss_generator_z = 0
            sum_loss_generator_img = 0

            process_iteration = Process()
            for i in range(num_trains_per_epoch):
                process_iteration.start_epoch(i, num_trains_per_epoch)

                # inputs: _l -> labeled, _u -> unlabeled
                images_u = dataset.sample_unlabeled_data(images, batch_size_u)
                if config.distribution_sampler == 'swiss_roll':
                    z_true_u = sampler.swiss_roll(batch_size_u, config.ndim_z,
                                                  config.num_types_of_label)
                elif config.distribution_sampler == 'gaussian_mixture':
                    z_true_u = sampler.gaussian_mixture(
                        batch_size_u, config.ndim_z, config.num_types_of_label)
                elif config.distribution_sampler == 'uniform_desk':
                    z_true_u = sampler.uniform_desk(batch_size_u, config.ndim_z,
                                                    radius=2)
                elif config.distribution_sampler == 'gaussian':
                    z_true_u = sampler.gaussian(batch_size_u, config.ndim_z,
                                                var=1)
                elif config.distribution_sampler == 'uniform':
                    z_true_u = sampler.uniform(batch_size_u, config.ndim_z,
                                               minv=-1, maxv=1)

                # reconstruction phase
                _, loss_reconstruction = sess.run(
                    [h.opt_r, h.loss_r],
                    feed_dict={h.x: images_u, h.lr: learning_rate})

                # adversarial phase for the discriminators
                images_u_s = dataset.sample_unlabeled_data(images, batch_size_u)
                _, loss_discriminator_z = sess.run(
                    [h.opt_dz, h.loss_dz],
                    feed_dict={h.x: images_u, h.z: z_true_u, h.lr: learning_rate})
                _, loss_discriminator_img = sess.run(
                    [h.opt_dimg, h.loss_dimg],
                    feed_dict={h.x: images_u, h.x_s: images_u_s, h.lr: learning_rate})

                # adversarial phase for the generators (encoder and decoder)
                _, loss_generator_z = sess.run(
                    [h.opt_e, h.loss_e],
                    feed_dict={h.x: images_u, h.lr: learning_rate})
                _, loss_generator_img = sess.run(
                    [h.opt_d, h.loss_d],
                    feed_dict={h.x: images_u, h.lr: learning_rate})

                sum_loss_reconstruction += loss_reconstruction
                sum_loss_discriminator_z += loss_discriminator_z
                sum_loss_discriminator_img += loss_discriminator_img
                sum_loss_generator_z += loss_generator_z
                sum_loss_generator_img += loss_generator_img

                # with 500 iterations per epoch this table is only printed at i == 0
                if i % 1000 == 0:
                    process_iteration.show_table_2d(
                        i, num_trains_per_epoch, {
                            'reconstruction': sum_loss_reconstruction / (i + 1),
                            'discriminator_z': sum_loss_discriminator_z / (i + 1),
                            'discriminator_img': sum_loss_discriminator_img / (i + 1),
                            'generator_z': sum_loss_generator_z / (i + 1),
                            'generator_img': sum_loss_generator_img / (i + 1),
                        })

            average_loss_per_epoch = [
                sum_loss_reconstruction / num_trains_per_epoch,
                sum_loss_discriminator_z / num_trains_per_epoch,
                sum_loss_discriminator_img / num_trains_per_epoch,
                sum_loss_generator_z / num_trains_per_epoch,
                sum_loss_generator_img / num_trains_per_epoch,
                (sum_loss_discriminator_z + sum_loss_discriminator_img) / num_trains_per_epoch,
                (sum_loss_generator_z + sum_loss_generator_img) / num_trains_per_epoch
            ]
            training_epoch_loss.append(average_loss_per_epoch)
            training_loss_name = [
                'reconstruction', 'discriminator_z', 'discriminator_img',
                'generator_z', 'generator_img', 'discriminator', 'generator'
            ]

            if epoch % 1 == 0:
                process.show_bar(
                    epoch, max_epoch, {
                        'loss_r': average_loss_per_epoch[0],
                        'loss_d': average_loss_per_epoch[5],
                        'loss_g': average_loss_per_epoch[6]
                    })
                # visualize the latent representation of the first 1000 images
                plt.scatter_labeled_z(
                    sess.run(h.z_r, feed_dict={h.x: images[:1000]}),
                    [int(var) for var in labels[:1000]],
                    dir=config.log_dir,
                    filename='z_representation-{}'.format(epoch))

            if epoch % 10 == 0:
                saver.save(sess,
                           os.path.join(config.ckpt_dir, 'model_checkpoint'),
                           global_step=epoch)
                pickle.dump((training_epoch_loss, training_loss_name),
                            open(config.ckpt_dir + '/pickle.pkl', 'wb'))
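# A possible entry point for this script, resuming from the latest checkpoint
# when a previous run left a pickle.pkl behind; the original repo may wire this
# up differently (e.g. via command-line flags).
if __name__ == '__main__':
    resume = os.path.exists(config.ckpt_dir + '/pickle.pkl')
    main(run_load_from_file=resume)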
def train():
    # hyperparameters and data settings
    data_set = 'cifar10'
    prior = 'uniform'
    x_dim = 32
    x_chl = 3
    y_dim = 10
    z_dim = 64
    batch_size = 100
    num_epochs = 500 * 50
    step_epochs = 100  # int(num_epochs / 100)
    learn_rate = 0.0005

    root_path = 'save/{}/{}'.format(data_set, prior)
    save_path = tools.make_save_directory(root_path)

    # placeholders
    z = tf.placeholder(dtype=tf.float32, shape=[batch_size, z_dim], name='z')
    y = tf.placeholder(dtype=tf.float32, shape=[batch_size, y_dim], name='y')
    x_real = tf.placeholder(dtype=tf.float32,
                            shape=[batch_size, x_dim, x_dim, x_chl], name='x')

    # generator and discriminator networks
    generator = Generator(batch_size=batch_size, z_dim=z_dim, dataset=data_set)
    discriminator = Discriminator(batch_size=batch_size, dataset=data_set)
    x_fake = generator.generate_on_cifar10(z, y, train=True)
    d_out_real = discriminator.discriminator_cifar10(x_real)
    d_out_fake = discriminator.discriminator_cifar10(x_fake)

    # discriminator loss
    D_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(d_out_real), logits=d_out_real))
    D_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(d_out_fake), logits=d_out_fake))
    with tf.control_dependencies([D_loss_fake, D_loss_real]):
        D_loss = D_loss_fake + D_loss_real

    # generator loss
    G_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(d_out_fake), logits=d_out_fake))

    # optimizers
    all_variables = tf.trainable_variables()
    g_var = [var for var in all_variables if 'generator' in var.name]
    d_var = [var for var in all_variables if 'discriminator' in var.name]
    optimizer = tf.train.AdamOptimizer(learn_rate)
    G_solver = optimizer.minimize(G_loss, var_list=g_var)
    D_solver = optimizer.minimize(D_loss, var_list=d_var)

    # read data
    train_data = Cifar10(train=True)
    file = open('{}/train.txt'.format(save_path), 'w')
    sess = tf.Session()

    # train the model
    sess.run(tf.global_variables_initializer())
    ave_loss_list = [0, 0, 0]
    cur_time = datetime.now()

    # fixed noise and labels for saving sample images
    save_step = int(num_epochs / 100)
    z_sample = spl.uniform(100, z_dim)
    y_sample = spl.onehot_categorical(100, y_dim)

    # training process
    for epochs in range(1, num_epochs + 1):
        batch_x, batch_y = train_data.next_batch(batch_size)
        s_z_real = spl.uniform(batch_size, z_dim)

        # one discriminator update followed by two generator updates
        for _ in range(1):
            sess.run(D_solver,
                     feed_dict={z: s_z_real, x_real: batch_x, y: batch_y})
        for _ in range(2):
            sess.run(G_solver, feed_dict={z: s_z_real, y: batch_y})

        loss_list = sess.run([D_loss_fake, D_loss_real, G_loss],
                             feed_dict={z: s_z_real, x_real: batch_x, y: batch_y})
        ave_loss(ave_loss_list, loss_list, step_epochs)

        # periodically save a grid of generated samples
        if epochs % save_step == 0:
            iter_counter = int(epochs / save_step)
            x_sample = sess.run(x_fake, feed_dict={z: z_sample, y: y_sample})
            tools.save_grid_images(x_sample,
                                   '{}/images/{}.png'.format(save_path, iter_counter),
                                   size=x_dim, chl=x_chl)

        # record information
        if epochs % step_epochs == 0:
            time_use = (datetime.now() - cur_time).seconds
            liner = ("Epoch {:d}/{:d}, loss_dis_fake {:9f}, loss_dis_real {:9f}, "
                     "loss_gen {:9f}, time_use {:f}").format(
                         epochs, num_epochs, ave_loss_list[0], ave_loss_list[1],
                         ave_loss_list[2], time_use)
            print(liner)
            file.writelines(liner + '\n')
            ave_loss_list = [0, 0, 0]  # reset the running averages
            cur_time = datetime.now()

    # save model
    global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    saver = tf.train.Saver(var_list=global_vars)
    saver.save(sess, save_path='{}/model'.format(save_path))

    # close all
    file.close()
    sess.close()
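# ave_loss is called above but not defined in this file. A minimal sketch,
# assuming it accumulates an in-place running average over the `step_epochs`
# iterations between two log lines (the list is reset to [0, 0, 0] after each
# log); the repo's own helper may differ.
def ave_loss(ave_loss_list, loss_list, step_epochs):
    for i, loss in enumerate(loss_list):
        # each entry ends up as the mean loss over the last `step_epochs` steps
        ave_loss_list[i] += loss / step_epochs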