Ejemplo n.º 1
0
    def test_discriminator_invalid_input(self):
        """Checks that malformed inputs are rejected with a ValueError.

        Two cases are exercised: a rank-3 tensor handed to the
        discriminator, and a rank-4 placeholder whose second dimension is
        unknown.
        """
        # Rank 3 instead of rank 4.
        rank3_batch = tf.zeros([5, 32, 32])
        with self.assertRaisesRegexp(ValueError, 'Shape must be rank 4'):
            networks.discriminator(rank3_batch)

        # NOTE(review): although this is a *discriminator* test, this case
        # calls `networks.compression_model`, not `networks.discriminator`.
        # It may be a copy-paste slip; confirm which function is expected to
        # reject partially defined shapes before changing it.
        partially_defined = tf.placeholder(tf.float32, [3, None, 32, 3])
        with self.assertRaisesRegexp(ValueError, 'Shape .* is not fully defined'):
            networks.compression_model(partially_defined)
Ejemplo n.º 2
0
    def test_generator_graph(self):
        """Checks the output shapes of `compression_model` for several sizes.

        Exponents 3..6 are paired with batch sizes 3, 5, 7, 9. For each
        pair, both the patch side length and the number of bits are
        `2 ** exponent`.
        """
        exponents = xrange(3, 7)
        batch_sizes = xrange(3, 11, 2)
        for exponent, batch_size in zip(exponents, batch_sizes):
            # Build every case in a fresh default graph.
            tf.reset_default_graph()
            side = 2 ** exponent
            num_bits = 2 ** exponent
            image_shape = [batch_size, side, side, 3]
            code_shape = [batch_size, num_bits]

            uncompressed, binary_codes, prebinary = networks.compression_model(
                tf.ones(image_shape), num_bits)

            # Reconstruction matches the input shape; codes are (batch, bits).
            self.assertAllEqual(image_shape, uncompressed.shape.as_list())
            self.assertEqual(code_shape, binary_codes.shape.as_list())
            self.assertEqual(code_shape, prebinary.shape.as_list())
Ejemplo n.º 3
0
def main(_, run_eval_loop=True):
    """Builds the compression-model evaluation graph and optionally runs it.

    Args:
      _: Unused positional argument (presumably argv from the app runner --
        confirm).
      run_eval_loop: When False, return right after graph construction
        (used by unit tests); otherwise run
        `tf.contrib.training.evaluate_repeatedly` on `FLAGS.checkpoint_dir`.
    """
    with tf.name_scope('inputs'):
        images = data_provider.provide_data(
            'validation',
            FLAGS.batch_size,
            dataset_dir=FLAGS.dataset_dir,
            patch_size=FLAGS.patch_size)

    # Build under the same 'generator' variable scope as the training job so
    # that checkpointed variables can be loaded.
    with tf.variable_scope('generator'):
        reconstructions, _, prebinary = networks.compression_model(
            images,
            num_bits=FLAGS.bits_per_patch,
            depth=FLAGS.model_depth,
            is_training=False)
    summaries.add_reconstruction_summaries(images, reconstructions, prebinary)

    # Mean absolute pixel error: first per example (axes 1-3), then over the
    # batch. Both are exported as summaries.
    abs_error = tf.abs(images - reconstructions)
    per_example_l1 = tf.reduce_mean(abs_error, axis=[1, 2, 3])
    batch_l1 = tf.reduce_mean(per_example_l1)
    tf.summary.histogram('pixel_l1_loss_hist', per_example_l1)
    tf.summary.scalar('pixel_l1_loss', batch_l1)

    # Op that PNG-encodes element 0 of the stacked original/reconstruction
    # output and writes it to <eval_dir>/compression.png.
    originals_u8 = data_provider.float_image_to_uint8(images)
    reconstructions_u8 = data_provider.float_image_to_uint8(reconstructions)
    stacked = summaries.stack_images(originals_u8, reconstructions_u8)
    png_path = '%s/%s' % (FLAGS.eval_dir, 'compression.png')
    png_bytes = tf.image.encode_png(stacked[0])
    image_write_ops = tf.write_file(png_path, png_bytes)

    # Unit tests pass run_eval_loop=False to stop after graph construction.
    if not run_eval_loop:
        return

    eval_hooks = [
        tf.contrib.training.SummaryAtEndHook(FLAGS.eval_dir),
        tf.contrib.training.StopAfterNEvalsHook(1),
    ]
    tf.contrib.training.evaluate_repeatedly(
        FLAGS.checkpoint_dir,
        master=FLAGS.master,
        hooks=eval_hooks,
        eval_ops=image_write_ops,
        max_number_of_evaluations=FLAGS.max_number_of_evaluations)
Ejemplo n.º 4
0
 def test_generator_run(self):
     """Smoke test: builds `compression_model` and evaluates its outputs."""
     zero_images = tf.zeros([3, 16, 16, 3])
     outputs = networks.compression_model(zero_images)
     with self.test_session() as session:
         # Variables must be initialized before the outputs can be fetched.
         init_op = tf.global_variables_initializer()
         session.run(init_op)
         session.run(outputs)
Ejemplo n.º 5
0
def main(_):
    """Builds the compression-GAN training graph and runs training.

    The generator is `networks.compression_model`. Its loss combines the
    least-squares adversarial loss with an L1 pixel loss, weighted by
    `FLAGS.weight_factor`. Returns right after graph construction when
    `FLAGS.max_number_of_steps` is 0 (used by graph-construction tests).
    """
    log_dir = FLAGS.train_log_dir
    if not tf.gfile.Exists(log_dir):
        tf.gfile.MakeDirs(log_dir)

    # The discriminator and its loss summaries are only used when the
    # adversarial term has positive weight.
    adversarial = FLAGS.weight_factor > 0

    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
        # Keep the input pipeline on the CPU so the GPU is left for training.
        with tf.name_scope('inputs'), tf.device('/cpu:0'):
            images = data_provider.provide_data(
                'train',
                FLAGS.batch_size,
                dataset_dir=FLAGS.dataset_dir,
                patch_size=FLAGS.patch_size)

        # The GANModel tuple is assembled by hand (instead of via
        # `tfgan.gan_model`) to show custom variable tracking with TFGAN.
        with tf.variable_scope('generator') as gen_scope:
            reconstructions, _, prebinary = networks.compression_model(
                images, num_bits=FLAGS.bits_per_patch, depth=FLAGS.model_depth)
        gan_model = _get_gan_model(
            generator_inputs=images,
            generated_data=reconstructions,
            real_data=images,
            generator_scope=gen_scope)
        summaries.add_reconstruction_summaries(images, reconstructions,
                                               prebinary)
        tfgan.eval.add_gan_model_summaries(gan_model)

        with tf.name_scope('loss'):
            # Least-squares adversarial losses from the TFGAN library.
            gan_loss = tfgan.gan_loss(
                gan_model,
                generator_loss_fn=tfgan.losses.least_squares_generator_loss,
                discriminator_loss_fn=(
                    tfgan.losses.least_squares_discriminator_loss),
                add_summaries=adversarial)

            # L1 norm (ord=1) of the real/generated difference.
            pixel_diff = gan_model.real_data - gan_model.generated_data
            l1_pixel_loss = tf.norm(pixel_diff, ord=1)

            # Fold the pixel loss into the GANLoss tuple.
            gan_loss = tfgan.losses.combine_adversarial_loss(
                gan_loss,
                gan_model,
                l1_pixel_loss,
                weight_factor=FLAGS.weight_factor)

        # Train ops from the project's learning-rate and optimizer helpers;
        # gradients are summarized and colocated with their ops.
        with tf.name_scope('train_ops'):
            gen_lr, dis_lr = _lr(FLAGS.generator_lr, FLAGS.discriminator_lr)
            gen_opt, dis_opt = _optimizer(gen_lr, dis_lr)
            train_ops = tfgan.gan_train_ops(
                gan_model,
                gan_loss,
                generator_optimizer=gen_opt,
                discriminator_optimizer=dis_opt,
                summarize_gradients=True,
                colocate_gradients_with_ops=True,
                aggregation_method=(
                    tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N))
            tf.summary.scalar('generator_lr', gen_lr)
            tf.summary.scalar('discriminator_lr', dis_lr)

        # One generator step per round; the discriminator takes one step
        # only when the adversarial term is in use, otherwise zero.
        train_steps = tfgan.GANTrainSteps(
            generator_train_steps=1,
            discriminator_train_steps=int(adversarial))

        global_step = tf.train.get_or_create_global_step()
        status_message = tf.string_join(
            ['Starting train step: ', tf.as_string(global_step)],
            name='status_message')

        # Graph-construction tests set max_number_of_steps to 0.
        if FLAGS.max_number_of_steps == 0:
            return

        tfgan.gan_train(
            train_ops,
            log_dir,
            tfgan.get_sequential_train_hooks(train_steps),
            hooks=[
                tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
                tf.train.LoggingTensorHook([status_message], every_n_iter=10),
            ],
            master=FLAGS.master,
            is_chief=FLAGS.task == 0)