def discriminator(images_and_lbls, unused_conditioning, mode):
    """Discriminator callable with the signature TF-GAN expects.

    Args:
      images_and_lbls: Mapping holding an 'images' entry and a 'labels'
        entry.
      unused_conditioning: Ignored.
      mode: Ignored.

    Returns:
      A logits tensor checked to be rank 2. When
      `hparams.debug_params.fake_nets` is set, it is a [batch, 20] tensor
      built from zeros that still creates a variable ('dummy_d') and depends
      on the input images. Otherwise it is the logits produced by
      `dis_module.discriminator`.
    """
    del unused_conditioning, mode
    images = images_and_lbls['images']
    labels = images_and_lbls['labels']

    if not hparams.debug_params.fake_nets:
        vars_before = len(tf.compat.v1.trainable_variables())
        logits, discriminator_vars = dis_module.discriminator(
            images, labels, hparams.df_dim, hparams.num_classes)
        # This callable runs twice (generated data, then real data); only the
        # call that actually creates new trainable variables logs them.
        made_new_vars = len(tf.compat.v1.trainable_variables()) != vars_before
        if made_new_vars:
            eval_lib.log_and_summarize_variables(
                discriminator_vars, 'dvars',
                hparams.tpu_params.use_tpu_estimator)
    else:
        # Fake net: trivial logits, but it must still own a variable and stay
        # connected to the generator output (through `images`).
        batch_size = tf.shape(input=images)[0]
        zero_logits = tf.zeros([batch_size, 20])
        dummy_var = tf.compat.v1.get_variable('dummy_d', initializer=2.0)
        logits = zero_logits * dummy_var * tf.reduce_mean(input_tensor=images)

    logits.shape.assert_is_compatible_with([None, None])
    return logits
Example #2
0
    def generator(noise, mode):
        """TF-GAN compatible generator that samples its own class labels.

        One class id is drawn for every row of `noise`. The logits are all
        zero, so every class is equally likely. Images are then generated
        conditioned on those ids.

        Args:
          noise: Noise tensor whose leading dimension is the batch size.
          mode: A `tf.estimator.ModeKeys` value. `training=True` is passed to
            `gen_module.generator` only in TRAIN mode.

        Returns:
          In PREDICT mode, the generated images, checked to be compatible
          with shape [None, 128, 128, 3]. In every other mode, a dict with
          'images' (the generated images) and 'labels' (the sampled sparse
          class ids, as returned by `eval_lib.print_debug_statistics`).
        """
        batch_size = tf.shape(input=noise)[0]
        is_train = (mode == tf.estimator.ModeKeys.TRAIN)

        # All-zero logits give every class equal probability. Squeezing the
        # trailing num_samples=1 axis makes the labels rank 1.
        gen_class_logits = tf.zeros((batch_size, hparams.num_classes))
        gen_class_ints = tf.random.categorical(logits=gen_class_logits,
                                               num_samples=1)
        gen_sparse_class = tf.squeeze(gen_class_ints, -1)
        gen_sparse_class.shape.assert_is_compatible_with([None])

        if hparams.debug_params.fake_nets:
            # Fake net: zero images that still own a variable. Use the
            # tf.compat.v1 endpoint, as the rest of this file does; the bare
            # `tf.get_variable` is not available in TensorFlow 2.x.
            gen_imgs = tf.zeros([batch_size, 128, 128, 3]) * (
                tf.compat.v1.get_variable('dummy_g', initializer=2.0))
            generator_vars = ()
        else:
            gen_imgs, generator_vars = gen_module.generator(
                noise,
                gen_sparse_class,
                hparams.gf_dim,
                hparams.num_classes,
                training=is_train)
        # Print debug statistics and log the generated variables.
        gen_imgs, gen_sparse_class = eval_lib.print_debug_statistics(
            gen_imgs, gen_sparse_class, 'generator',
            hparams.tpu_params.use_tpu_estimator)
        eval_lib.log_and_summarize_variables(
            generator_vars, 'gvars', hparams.tpu_params.use_tpu_estimator)
        gen_imgs.shape.assert_is_compatible_with([None, 128, 128, 3])

        if mode == tf.estimator.ModeKeys.PREDICT:
            return gen_imgs
        return {'images': gen_imgs, 'labels': gen_sparse_class}
Example #3
0
    def generator(noise_and_lbls, mode):
        """TF-GAN compatible generator conditioned on caller-supplied labels.

        Args:
          noise_and_lbls: Mapping with a 'z' entry and a 'labels' entry.
            'z' is the noise tensor; its leading dimension is the batch size.
            'labels' holds the class labels passed to the generator.
          mode: A `tf.estimator.ModeKeys` value. `training=True` is passed to
            `gen_module.generator` only in TRAIN mode.

        Returns:
          In PREDICT mode with `FLAGS.gen_images_with_margins` off, just the
          generated images. Otherwise, a dict with 'images' and 'labels'.
          The images are checked to be compatible with shape
          [None, FLAGS.image_size, FLAGS.image_size, 3].
        """
        noise, labs = noise_and_lbls['z'], noise_and_lbls['labels']
        batch_size = tf.shape(input=noise)[0]
        is_train = (mode == tf.estimator.ModeKeys.TRAIN)
        image_size = flags.FLAGS.image_size

        # NOTE: `labs` is deliberately not asserted to be rank 1 here. Per the
        # original author, that check does not hold for gen_images.

        if hparams.debug_params.fake_nets:
            # Fake net: zero images that still own a variable.
            gen_imgs = tf.zeros(
                [batch_size, image_size, image_size, 3]
            ) * tf.compat.v1.get_variable('dummy_g', initializer=2.0)
            generator_vars = ()
        else:
            gen_imgs, generator_vars = gen_module.generator(
                noise,
                labs,
                hparams.gf_dim,
                hparams.num_classes,
                training=is_train)
        # Print debug statistics and log the generated variables.
        # Previously the labels returned here were bound and then ignored, and
        # the raw `labs` were returned instead. That silently discarded any
        # debug op print_debug_statistics attaches to the labels. Use the
        # returned labels so that op stays part of the output graph.
        gen_imgs, gen_labels = eval_lib.print_debug_statistics(
            gen_imgs, labs, 'generator', hparams.tpu_params.use_tpu_estimator)
        eval_lib.log_and_summarize_variables(
            generator_vars, 'gvars', hparams.tpu_params.use_tpu_estimator)
        gen_imgs.shape.assert_is_compatible_with(
            [None, image_size, image_size, 3])

        if (mode == tf.estimator.ModeKeys.PREDICT and
                not flags.FLAGS.gen_images_with_margins):
            return gen_imgs
        return {'images': gen_imgs, 'labels': gen_labels}