def __init__(self, batch_size, max_epoch, model_path, data, latent_dim,
             sample_path, log_dir, learnrate_init, num_vae):
    # Only 2, 3 or 4 stacked VAEs are implemented.
    assert num_vae in [2, 3, 4]
    self.batch_size = batch_size
    self.max_epoch = max_epoch
    self.saved_model_path = model_path
    self.ds_train = data
    self.latent_dim = latent_dim
    self.sample_path = sample_path
    self.log_dir = log_dir
    self.learn_rate_init = learnrate_init
    self.log_vars = []
    self.channel = 3
    self.output_size = CelebA().image_size
    # Real input images and a prior sample z_p fed in at training time.
    self.images = tf.placeholder(
        tf.float32,
        [self.batch_size, self.output_size, self.output_size, self.channel])
    self.z_p = tf.placeholder(tf.float32, [self.batch_size, self.latent_dim])
    # Noise used when sampling the latent code from the encoder.
    self.ep = tf.random_normal(shape=[self.batch_size, self.latent_dim])
    self.num_vae = num_vae
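# `build_model_vaegan` (not shown in this excerpt) is assumed to consume
# `self.ep` through the reparameterization trick. A minimal sketch of that
# step, assuming the encoder yields `z_mean` and `z_log_sigma_sq` tensors
# (both names are hypothetical); not part of the original code:
def _reparameterize_sketch(self, z_mean, z_log_sigma_sq):
    # z = mu + sigma * epsilon, with epsilon ~ N(0, I) drawn in __init__.
    return z_mean + tf.exp(0.5 * z_log_sigma_sq) * self.ep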
def test(self):
    init = tf.global_variables_initializer()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(init)
        # Reload the trained weights before sampling reconstructions.
        self.saver.restore(sess, self.saved_model_path)
        max_iter = len(self.ds_train) // self.batch_size - 1
        train_list = CelebA.getNextBatch(self.ds_train, max_iter, 0,
                                         self.batch_size)
        realbatch_array = CelebA.getShapeForData(train_list)
        real_images, sample_images = sess.run(
            [self.images, self.x_tilde],
            feed_dict={self.images: realbatch_array})
        # Save reconstructions and the corresponding real images as
        # sqrt(batch_size) x sqrt(batch_size) grids.
        save_images(
            sample_images[0:self.batch_size],
            [int(np.sqrt(self.batch_size)), int(np.sqrt(self.batch_size))],
            '{}/train_{:02d}_{:04d}_con.png'.format(self.sample_path, 0, 0))
        save_images(
            real_images[0:self.batch_size],
            [int(np.sqrt(self.batch_size)), int(np.sqrt(self.batch_size))],
            '{}/train_{:02d}_{:04d}_r.png'.format(self.sample_path, 0, 0))
        ri = cv2.imread(
            '{}/train_{:02d}_{:04d}_r.png'.format(self.sample_path, 0, 0), 1)
        fi = cv2.imread(
            '{}/train_{:02d}_{:04d}_con.png'.format(self.sample_path, 0, 0), 1)
        cv2.imshow('real_image', ri)
        cv2.imshow('reconstruction', fi)
        cv2.waitKey(-1)
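# `test` and `train` tile images into an int(sqrt(batch_size)) x
# int(sqrt(batch_size)) grid, so they implicitly assume `batch_size` is a
# perfect square (e.g. 64). A small guard one could use; not part of the
# original code:
@staticmethod
def _grid_shape(batch_size):
    side = int(np.sqrt(batch_size))
    assert side * side == batch_size, "batch_size must be a perfect square"
    return [side, side]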
def train(self):
    global_step = tf.Variable(0, trainable=False)
    add_global = global_step.assign_add(1)
    new_learning_rate = tf.train.exponential_decay(self.learn_rate_init,
                                                   global_step=global_step,
                                                   decay_steps=20000,
                                                   decay_rate=0.98)

    # Discriminator optimizer with gradient clipping.
    trainer_D = tf.train.RMSPropOptimizer(learning_rate=new_learning_rate)
    gradients_D = trainer_D.compute_gradients(self.D_loss,
                                              var_list=self.d_vars)
    clipped_gradients_D = [(tf.clip_by_value(grad, -1.0, 1.0), var)
                           for grad, var in gradients_D]
    opti_D = trainer_D.apply_gradients(clipped_gradients_D)

    # Generator/decoder optimizer with gradient clipping.
    trainer_G = tf.train.RMSPropOptimizer(learning_rate=new_learning_rate)
    gradients_G = trainer_G.compute_gradients(self.G_loss,
                                              var_list=self.g_vars)
    clipped_gradients_G = [(tf.clip_by_value(grad, -1.0, 1.0), var)
                           for grad, var in gradients_G]
    opti_G = trainer_G.apply_gradients(clipped_gradients_G)

    # Encoder optimizer with gradient clipping.
    trainer_E = tf.train.RMSPropOptimizer(learning_rate=new_learning_rate)
    gradients_E = trainer_E.compute_gradients(self.encode_loss,
                                              var_list=self.e_vars)
    clipped_gradients_E = [(tf.clip_by_value(grad, -1.0, 1.0), var)
                           for grad, var in gradients_E]
    opti_E = trainer_E.apply_gradients(clipped_gradients_E)

    init = tf.global_variables_initializer()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        sess.run(init)
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)
        # self.saver.restore(sess, self.saved_model_path)

        batch_num = 0
        epoch = 0
        step = 0

        while epoch <= self.max_epoch:
            max_iter = len(self.ds_train) // self.batch_size - 1
            while batch_num < len(self.ds_train) // self.batch_size:
                step = step + 1
                train_list = CelebA.getNextBatch(self.ds_train, max_iter,
                                                 batch_num, self.batch_size)
                realbatch_array = CelebA.getShapeForData(train_list)
                sample_z = np.random.normal(
                    size=[self.batch_size, self.latent_dim])

                # Optimize encoder, generator and discriminator in turn.
                sess.run(opti_E, feed_dict={self.images: realbatch_array})
                sess.run(opti_G,
                         feed_dict={
                             self.images: realbatch_array,
                             self.z_p: sample_z
                         })
                sess.run(opti_D,
                         feed_dict={
                             self.images: realbatch_array,
                             self.z_p: sample_z
                         })

                summary_str = sess.run(summary_op,
                                       feed_dict={
                                           self.images: realbatch_array,
                                           self.z_p: sample_z
                                       })
                summary_writer.add_summary(summary_str, step)
                batch_num += 1

                # Decay the learning rate until it reaches the 5e-5 floor.
                new_learn_rate = sess.run(new_learning_rate)
                if new_learn_rate > 0.00005:
                    sess.run(add_global)

                if step % 20 == 0:
                    # KL is reported divided by a fixed normalization constant.
                    D_loss, fake_loss, encode_loss, LL_loss, kl_loss = sess.run(
                        [
                            self.D_loss, self.G_loss, self.encode_loss,
                            self.LL_loss, self.kl_loss / (128 * 64)
                        ],
                        feed_dict={
                            self.images: realbatch_array,
                            self.z_p: sample_z
                        })
                    print(
                        "EPOCH %d step %d: D: loss = %.7f G: loss=%.7f "
                        "Encode: loss=%.7f LL loss=%.7f KL=%.7f" %
                        (epoch, step, D_loss, fake_loss, encode_loss,
                         LL_loss, kl_loss))

                if np.mod(step, 200) == 1:
                    # Periodically dump a grid of real images and their
                    # reconstructions, and checkpoint the model.
                    save_images(
                        realbatch_array[0:self.batch_size],
                        [
                            int(np.sqrt(self.batch_size)),
                            int(np.sqrt(self.batch_size))
                        ],
                        '{}/train_{:02d}_{:04d}_r.png'.format(
                            self.sample_path, epoch, step))
                    sample_images = sess.run(
                        self.x_tilde,
                        feed_dict={self.images: realbatch_array})
                    save_images(
                        sample_images[0:self.batch_size],
                        [
                            int(np.sqrt(self.batch_size)),
                            int(np.sqrt(self.batch_size))
                        ],
                        '{}/train_{:02d}_{:04d}.png'.format(
                            self.sample_path, epoch, step))
                    self.saver.save(sess, self.saved_model_path)

            epoch += 1
            batch_num = 0

        save_path = self.saver.save(sess, self.saved_model_path)
        print("Model saved in file: %s" % save_path)
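# The RMSProp-plus-clipping pattern in `train` is repeated for the
# discriminator, generator and encoder; a sketch of how it could be factored
# into a single helper that builds the same ops (not part of the original
# code):
@staticmethod
def _clipped_rmsprop_op(loss, var_list, learning_rate, clip=1.0):
    trainer = tf.train.RMSPropOptimizer(learning_rate=learning_rate)
    grads_and_vars = trainer.compute_gradients(loss, var_list=var_list)
    clipped = [(tf.clip_by_value(grad, -clip, clip), var)
               for grad, var in grads_and_vars]
    return trainer.apply_gradients(clipped)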
mkdir_p(root_log_dir)
mkdir_p(vaegan_checkpoint_dir)
mkdir_p(sample_path)

model_path = vaegan_checkpoint_dir
batch_size = FLAGS.batch_size
max_epoch = FLAGS.max_epoch
latent_dim = FLAGS.latent_dim
learn_rate_init = FLAGS.learn_rate_init

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu

data_list = CelebA().load_celebA(image_path=FLAGS.path)
print("the num of dataset", len(data_list))

vaeGan = stacked_vaegan_celeba(batch_size=batch_size,
                               max_epoch=max_epoch,
                               model_path=model_path,
                               data=data_list,
                               latent_dim=latent_dim,
                               sample_path=sample_path,
                               log_dir=root_log_dir,
                               learnrate_init=learn_rate_init,
                               num_vae=FLAGS.num_vae)
vaeGan.build_model_vaegan()
vaeGan.train()
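# The flags referenced above (FLAGS.batch_size, FLAGS.max_epoch,
# FLAGS.latent_dim, FLAGS.learn_rate_init, FLAGS.gpu, FLAGS.path,
# FLAGS.num_vae) and the directories root_log_dir, vaegan_checkpoint_dir and
# sample_path are assumed to be defined earlier in the script. A sketch of
# what the flag definitions might look like with tf.app.flags (all default
# values below are only illustrative):
#
#     flags = tf.app.flags
#     flags.DEFINE_integer("batch_size", 64, "number of images per batch")
#     flags.DEFINE_integer("max_epoch", 20, "number of training epochs")
#     flags.DEFINE_integer("latent_dim", 128, "dimension of the latent code")
#     flags.DEFINE_float("learn_rate_init", 0.0003, "initial learning rate")
#     flags.DEFINE_string("gpu", "0", "value for CUDA_VISIBLE_DEVICES")
#     flags.DEFINE_string("path", "./data/celebA", "path to the CelebA images")
#     flags.DEFINE_integer("num_vae", 2, "number of stacked VAEs (2, 3 or 4)")
#     FLAGS = flags.FLAGS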