def __init__(self, z_dim, crop_image_size, resized_image_size, batch_size, data_dir):
    """Store model hyper-parameters and build the training-image input queue.

    Args:
        z_dim: Stored as ``self.z_dim`` (presumably the size of the generator's
            noise vector -- TODO confirm).
        crop_image_size: Stored as ``self.crop_image_size``.
        resized_image_size: Stored as ``self.resized_image_size``.
        batch_size: Stored as ``self.batch_size``.
        data_dir: Directory handed to ``celebA.read_dataset``.

    Side effects:
        Sets ``self.images`` to whatever ``self._read_input_queue`` returns for
        a string-input queue over the dataset's ``train_images``.
    """
    dataset = celebA.read_dataset(data_dir)

    self.z_dim = z_dim
    self.crop_image_size = crop_image_size
    self.resized_image_size = resized_image_size
    self.batch_size = batch_size

    # Queue of training-image filenames; the reader turns it into image tensors.
    train_files = dataset.train_images
    self.images = self._read_input_queue(
        tf.train.string_input_producer(train_files))
def __init__(self, z_dim, num_cls, crop_image_size, resized_image_size,
             batch_size, data_dir):
    """Store hyper-parameters and build a shuffled (image, label) input queue.

    Args:
        z_dim: Stored as ``self.z_dim`` (presumably the noise-vector size --
            TODO confirm).
        num_cls: Stored as ``self.num_cls`` (presumably the number of label
            classes -- TODO confirm).
        crop_image_size: Stored as ``self.crop_image_size``.
        resized_image_size: Stored as ``self.resized_image_size``.
        batch_size: Stored as ``self.batch_size``.
        data_dir: Directory handed to ``celebA.read_dataset`` and
            ``celebA.create_label_dict``.

    Side effects:
        Sets ``self.images`` and ``self.labels`` from ``self._read_input_queue``
        applied to a shuffled slice queue over (filename, label) pairs.
    """
    self.num_cls = num_cls
    # The returned dataset object is not used here. The call is kept in case
    # read_dataset prepares files that create_label_dict relies on
    # (NOTE(review): confirm whether it has such side effects).
    celebA.read_dataset(data_dir)
    self.z_dim = z_dim
    self.crop_image_size = crop_image_size
    self.resized_image_size = resized_image_size
    self.batch_size = batch_size

    label_dict = celebA.create_label_dict(data_dir)
    # Materialize the keys once into a real list. On Python 3, dict.keys() is
    # a view that tf.convert_to_tensor cannot consume. Building both lists
    # from the same key list also keeps filenames and labels index-aligned.
    imgfn_list = list(label_dict)
    label_list = [label_dict[imgfn] for imgfn in imgfn_list]
    images = tf.convert_to_tensor(imgfn_list)
    labels = tf.convert_to_tensor(label_list, dtype=np.int32)
    # slice_input_producer emits matching slices of both tensors, so each
    # filename stays paired with its label even when shuffled.
    input_queue = tf.train.slice_input_producer([images, labels], shuffle=True)
    self.images, self.labels = self._read_input_queue(input_queue)
def __init__(self, z_dim, crop_image_size, resized_image_size, batch_size,
             data_dir, critic_iterations=5, root_scope_name=''):
    """Store model hyper-parameters and build the training-batch input queue.

    Args:
        z_dim: Stored as ``self.z_dim`` (presumably the noise-vector size --
            TODO confirm).
        crop_image_size: Stored as ``self.crop_image_size``.
        resized_image_size: Stored as ``self.resized_image_size``.
        batch_size: Stored as ``self.batch_size``.
        data_dir: Directory handed to ``celebA.read_dataset``.
        critic_iterations: Stored as ``self.critic_iterations`` (default 5);
            presumably critic updates per generator update -- TODO confirm.
        root_scope_name: Stored as ``self.root_scope_name``. When non-empty,
            ``self.summary_collections`` is ``[root_scope_name]``; when empty
            it is ``None``.

    Side effects:
        Sets ``self.training_batch_images`` from ``self._read_input_queue``
        applied to a string-input queue over the dataset's ``train_images``.
    """
    dataset = celebA.read_dataset(data_dir)

    self.root_scope_name = root_scope_name
    if root_scope_name:
        self.summary_collections = [root_scope_name]
    else:
        self.summary_collections = None

    self.z_dim = z_dim
    self.crop_image_size = crop_image_size
    self.resized_image_size = resized_image_size
    self.batch_size = batch_size
    self.critic_iterations = critic_iterations

    # Queue of training-image filenames; the reader turns it into batches.
    train_files = dataset.train_images
    self.training_batch_images = self._read_input_queue(
        tf.train.string_input_producer(train_files))
def main(argv=None):
    """Build the GAN graph and run the alternating training loop.

    Reads the training split via ``celebA.read_dataset(FLAGS.data_dir)`` and
    builds the generator and a weight-shared discriminator, plus losses,
    summaries and one train op per network. It then runs one discriminator
    step and one generator step per iteration, for up to ``MAX_ITERATIONS``
    iterations. Losses and summaries are logged every 10 steps, and a
    checkpoint is saved in ``FLAGS.logs_dir`` every 500 steps. Training
    resumes from the latest checkpoint in ``FLAGS.logs_dir`` if one exists.

    Args:
        argv: Unused (presumably supplied by ``tf.app.run`` -- TODO confirm).
    """
    import os  # used to build the checkpoint path inside FLAGS.logs_dir

    print("Setting up image reader...")
    train_images, _valid_images, _test_images = celebA.read_dataset(
        FLAGS.data_dir)
    filename_queue = tf.train.string_input_producer(train_images)
    images = read_input_queue(filename_queue)

    train_phase = tf.placeholder(tf.bool)
    z_vec = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.z_dim],
                           name="z")

    print("Setting up network model...")
    tf.histogram_summary("z", z_vec)
    tf.image_summary("image_real", images, max_images=2)
    gen_images = generator(z_vec, train_phase)
    tf.image_summary("image_generated", gen_images, max_images=2)

    with tf.variable_scope("discriminator") as scope:
        discriminator_real_prob, logits_real, feature_real = discriminator(
            images, train_phase)
        utils.add_activation_summary(
            tf.identity(discriminator_real_prob, name='disc_real_prob'))
        # Reuse the same discriminator weights for the generated images.
        scope.reuse_variables()
        discriminator_fake_prob, logits_fake, feature_fake = discriminator(
            gen_images, train_phase)
        utils.add_activation_summary(
            tf.identity(discriminator_fake_prob, name='disc_fake_prob'))

    # Discriminator: real images labelled 1, generated images labelled 0.
    discriminator_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits_real,
                                                tf.ones_like(logits_real)))
    discriminator_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits_fake,
                                                tf.zeros_like(logits_fake)))
    discriminator_loss = discriminator_loss_fake + discriminator_loss_real

    # Generator: push fake logits toward label 1. Also add 0.1 times an L2
    # distance between the discriminator's returned features for real and
    # generated images, normalized by IMAGE_SIZE ** 2.
    gen_loss_1 = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits_fake,
                                                tf.ones_like(logits_fake)))
    gen_loss_2 = tf.reduce_mean(
        tf.nn.l2_loss(feature_real - feature_fake)) / (IMAGE_SIZE * IMAGE_SIZE)
    gen_loss = gen_loss_1 + 0.1 * gen_loss_2

    tf.scalar_summary("Discriminator_loss_real", discriminator_loss_real)
    # NOTE(review): tag spelling kept unchanged so existing TensorBoard runs
    # remain comparable.
    tf.scalar_summary("Discrimintator_loss_fake", discriminator_loss_fake)
    tf.scalar_summary("Discriminator_loss", discriminator_loss)
    tf.scalar_summary("Generator_loss", gen_loss)

    # Split trainable variables by name prefix so each optimizer only
    # updates its own network.
    train_variables = tf.trainable_variables()
    generator_variables = [
        v for v in train_variables if v.name.startswith("generator")
    ]
    print([v.op.name for v in generator_variables])
    discriminator_variables = [
        v for v in train_variables if v.name.startswith("discriminator")
    ]
    print([v.op.name for v in discriminator_variables])
    generator_train_op = train(gen_loss, generator_variables)
    discriminator_train_op = train(discriminator_loss, discriminator_variables)

    for v in train_variables:
        utils.add_to_regularization_and_summary(var=v)

    sess = tf.Session()
    summary_op = tf.merge_all_summaries()
    saver = tf.train.Saver()
    summary_writer = tf.train.SummaryWriter(FLAGS.logs_dir, sess.graph)

    sess.run(tf.initialize_all_variables())
    ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")

    # Save inside logs_dir whether or not it has a trailing separator, so
    # get_checkpoint_state(FLAGS.logs_dir) above can find the checkpoints.
    checkpoint_path = os.path.join(FLAGS.logs_dir, "model.ckpt")

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    try:
        for itr in range(MAX_ITERATIONS):
            batch_z = np.random.uniform(
                -1.0, 1.0,
                size=[FLAGS.batch_size, FLAGS.z_dim]).astype(np.float32)
            feed_dict = {z_vec: batch_z, train_phase: True}

            sess.run(discriminator_train_op, feed_dict=feed_dict)
            sess.run(generator_train_op, feed_dict=feed_dict)

            if itr % 10 == 0:
                g_loss_val, d_loss_val, summary_str = sess.run(
                    [gen_loss, discriminator_loss, summary_op],
                    feed_dict=feed_dict)
                print("Step: %d, generator loss: %g, discriminator_loss: %g"
                      % (itr, g_loss_val, d_loss_val))
                summary_writer.add_summary(summary_str, itr)

            if itr % 500 == 0:
                saver.save(sess, checkpoint_path, global_step=itr)
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    except KeyboardInterrupt:
        print("Ending Training...")
    finally:
        coord.request_stop()
        # Wait for the queue-runner threads before releasing the session
        # they run against.
        coord.join(threads)
        # Flush pending summary events and free session resources.
        summary_writer.close()
        sess.close()