def test_discriminator_invalid_input(self):
    """Malformed input shapes must be rejected with a ValueError."""
    # Case 1: a rank-3 tensor handed to the discriminator.
    rank3_batch = tf.zeros([5, 32, 32])
    with self.assertRaisesRegexp(ValueError, 'Shape must be rank 4'):
        networks.discriminator(rank3_batch)

    # Case 2: a placeholder whose second dimension is unknown (None).
    # NOTE(review): despite the test's name, this case exercises
    # networks.compression_model, not the discriminator — confirm intended.
    partially_known = tf.placeholder(tf.float32, [3, None, 32, 3])
    with self.assertRaisesRegexp(ValueError, 'Shape .* is not fully defined'):
        networks.compression_model(partially_known)
def test_generator_graph(self):
    """Check the compression model's static output shapes for several sizes.

    For exponents 3..6 (paired with batch sizes 3, 5, 7, 9) the patch side and
    the bit count are both 2**exponent. The reconstruction must match the
    input shape. The binary codes and pre-binarization values must both have
    shape [batch_size, bits].
    """
    for exponent, batch_size in zip(xrange(3, 7), xrange(3, 11, 2)):
        # Start every configuration from an empty default graph.
        tf.reset_default_graph()
        side = 2 ** exponent
        num_bits = side
        image_batch = tf.ones([batch_size, side, side, 3])
        decoded, codes, pre_codes = networks.compression_model(
            image_batch, num_bits)

        expected_image_shape = [batch_size, side, side, 3]
        expected_code_shape = [batch_size, num_bits]
        self.assertAllEqual(expected_image_shape, decoded.shape.as_list())
        self.assertEqual(expected_code_shape, codes.shape.as_list())
        self.assertEqual(expected_code_shape, pre_codes.shape.as_list())
def main(_, run_eval_loop=True):
    """Build the compression-model evaluation graph and optionally run it.

    Reads validation patches, reconstructs them with the generator, adds
    reconstruction and pixel-L1 summaries, and builds an op that writes the
    first stacked original/reconstruction image to
    ``<FLAGS.eval_dir>/compression.png``.

    Args:
      _: Unused positional argument (presumably argv from the app launcher —
        TODO confirm).
      run_eval_loop: When False, return right after graph construction (used
        for unit testing). Otherwise run
        ``tf.contrib.training.evaluate_repeatedly`` on ``FLAGS.checkpoint_dir``.
    """
    with tf.name_scope('inputs'):
        real_images = data_provider.provide_data(
            'validation',
            FLAGS.batch_size,
            dataset_dir=FLAGS.dataset_dir,
            patch_size=FLAGS.patch_size)

    # Reuse the train job's variable scope name so checkpointed variables load.
    with tf.variable_scope('generator'):
        decoded, _, prebinary_codes = networks.compression_model(
            real_images,
            num_bits=FLAGS.bits_per_patch,
            depth=FLAGS.model_depth,
            is_training=False)
    summaries.add_reconstruction_summaries(real_images, decoded, prebinary_codes)

    # Per-example mean absolute pixel error (averaged over axes 1-3), then the
    # mean of that over the batch.
    abs_error = tf.abs(real_images - decoded)
    per_example_l1 = tf.reduce_mean(abs_error, axis=[1, 2, 3])
    mean_l1 = tf.reduce_mean(per_example_l1)
    tf.summary.histogram('pixel_l1_loss_hist', per_example_l1)
    tf.summary.scalar('pixel_l1_loss', mean_l1)

    # Convert to uint8, stack originals with reconstructions, and PNG-encode
    # element 0 of the stacked result for writing to disk.
    real_uint8 = data_provider.float_image_to_uint8(real_images)
    decoded_uint8 = data_provider.float_image_to_uint8(decoded)
    stacked = summaries.stack_images(real_uint8, decoded_uint8)
    output_path = '%s/%s' % (FLAGS.eval_dir, 'compression.png')
    png_bytes = tf.image.encode_png(stacked[0])
    write_op = tf.write_file(output_path, png_bytes)

    # For unit testing, use `run_eval_loop=False` to stop after graph building.
    if run_eval_loop:
        eval_hooks = [
            tf.contrib.training.SummaryAtEndHook(FLAGS.eval_dir),
            tf.contrib.training.StopAfterNEvalsHook(1),
        ]
        tf.contrib.training.evaluate_repeatedly(
            FLAGS.checkpoint_dir,
            master=FLAGS.master,
            hooks=eval_hooks,
            eval_ops=write_op,
            max_number_of_evaluations=FLAGS.max_number_of_evaluations)
def test_generator_run(self):
    """Smoke-test that the compression model graph executes in a session."""
    zero_batch = tf.zeros([3, 16, 16, 3])
    outputs = networks.compression_model(zero_batch)
    with self.test_session() as session:
        init_op = tf.global_variables_initializer()
        session.run(init_op)
        session.run(outputs)
def main(_):
    """Build the image-compression GAN training graph and run training.

    Creates ``FLAGS.train_log_dir`` if it does not exist, then builds:
      * the training input pipeline;
      * the generator (``networks.compression_model``) and a GANModel;
      * least-squares GAN losses combined with an L1 pixel loss;
      * the train ops.
    It then runs ``tfgan.gan_train``. If ``FLAGS.max_number_of_steps == 0``
    it returns right after graph construction (used for graph construction
    tests).

    Args:
      _: Unused positional argument (presumably argv from the app launcher —
        TODO confirm).
    """
    if not tf.gfile.Exists(FLAGS.train_log_dir):
        tf.gfile.MakeDirs(FLAGS.train_log_dir)

    # Ops built below are placed by tf.train.replica_device_setter(FLAGS.ps_tasks).
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
        # Put input pipeline on CPU to reserve GPU for training.
        with tf.name_scope('inputs'), tf.device('/cpu:0'):
            images = data_provider.provide_data(
                'train', FLAGS.batch_size, dataset_dir=FLAGS.dataset_dir,
                patch_size=FLAGS.patch_size)

        # Manually define a GANModel tuple. This is useful when we have custom
        # code to track variables. Note that we could replace all of this with a
        # call to `tfgan.gan_model`, but we don't in order to demonstrate some of
        # TFGAN's flexibility.
        with tf.variable_scope('generator') as gen_scope:
            reconstructions, _, prebinary = networks.compression_model(
                images, num_bits=FLAGS.bits_per_patch, depth=FLAGS.model_depth)
        # The same `images` tensor is passed both as the generator input and as
        # the real data.
        gan_model = _get_gan_model(generator_inputs=images,
                                   generated_data=reconstructions,
                                   real_data=images,
                                   generator_scope=gen_scope)
        summaries.add_reconstruction_summaries(images, reconstructions, prebinary)
        tfgan.eval.add_gan_model_summaries(gan_model)

        # Define the GANLoss tuple using standard library functions.
        with tf.name_scope('loss'):
            gan_loss = tfgan.gan_loss(
                gan_model,
                generator_loss_fn=tfgan.losses.least_squares_generator_loss,
                discriminator_loss_fn=tfgan.losses.least_squares_discriminator_loss,
                # Loss summaries are only added when FLAGS.weight_factor > 0.
                add_summaries=FLAGS.weight_factor > 0)

            # Define the standard pixel loss.
            # NOTE(review): per the tf.norm docs, ord=1 with no `axis` treats the
            # whole tensor as one vector. That makes this a SUM of absolute
            # differences over the entire batch, not a per-example mean — its
            # scale depends on batch and patch size. Confirm intended.
            l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data,
                                    ord=1)

            # Modify the loss tuple to include the pixel loss. Add summaries as well.
            gan_loss = tfgan.losses.combine_adversarial_loss(
                gan_loss, gan_model, l1_pixel_loss,
                weight_factor=FLAGS.weight_factor)

        # Get the GANTrain ops using the custom optimizers and optional
        # discriminator weight clipping.
        # NOTE(review): no weight clipping is visible in this function;
        # presumably it lives in `_optimizer` or was removed — confirm.
        with tf.name_scope('train_ops'):
            gen_lr, dis_lr = _lr(FLAGS.generator_lr, FLAGS.discriminator_lr)
            gen_opt, dis_opt = _optimizer(gen_lr, dis_lr)
            train_ops = tfgan.gan_train_ops(
                gan_model,
                gan_loss,
                generator_optimizer=gen_opt,
                discriminator_optimizer=dis_opt,
                summarize_gradients=True,
                colocate_gradients_with_ops=True,
                aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
            tf.summary.scalar('generator_lr', gen_lr)
            tf.summary.scalar('discriminator_lr', dis_lr)

        # Determine the number of generator vs discriminator steps.
        # The discriminator takes 1 step per generator step when
        # FLAGS.weight_factor > 0, and 0 steps otherwise.
        train_steps = tfgan.GANTrainSteps(
            generator_train_steps=1,
            discriminator_train_steps=int(FLAGS.weight_factor > 0))

        # Run the alternating training loop. Skip it if no steps should be taken
        # (used for graph construction tests).
        status_message = tf.string_join(
            ['Starting train step: ',
             tf.as_string(tf.train.get_or_create_global_step())],
            name='status_message')
        if FLAGS.max_number_of_steps == 0:
            return
        tfgan.gan_train(
            train_ops,
            FLAGS.train_log_dir,
            tfgan.get_sequential_train_hooks(train_steps),
            hooks=[
                # Stop at FLAGS.max_number_of_steps; log the status every 10 iters.
                tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
                tf.train.LoggingTensorHook([status_message], every_n_iter=10)
            ],
            master=FLAGS.master,
            # Only task 0 runs as chief.
            is_chief=FLAGS.task == 0)