def discriminator(images_and_lbls, unused_conditioning, mode):
    """Discriminator function with the call signature TF-GAN expects.

    Args:
      images_and_lbls: Mapping holding an ``'images'`` tensor and a
        ``'labels'`` tensor.
      unused_conditioning: Ignored; present only to match the TF-GAN
        discriminator interface.
      mode: Ignored.

    Returns:
      The logits tensor, checked to be rank 2 before it is returned.
    """
    del unused_conditioning, mode
    images = images_and_lbls['images']
    labels = images_and_lbls['labels']

    if hparams.debug_params.fake_nets:
        # Debug path: all-zero logits that still involve a discriminator
        # variable and the input images, so the discriminator has variables
        # and depends on the generator's output.
        batch_size = tf.shape(input=images)[0]
        zero_logits = tf.zeros([batch_size, 20])
        dummy_var = tf.compat.v1.get_variable('dummy_d', initializer=2.0)
        logits = zero_logits * dummy_var * tf.reduce_mean(input_tensor=images)
    else:
        vars_before = len(tf.compat.v1.trainable_variables())
        logits, discriminator_vars = dis_module.discriminator(
            images, labels, hparams.df_dim, hparams.num_classes)
        vars_after = len(tf.compat.v1.trainable_variables())
        # This function is called twice (generated data, then real data); only
        # the call that actually creates new variables logs them.
        if vars_after != vars_before:
            eval_lib.log_and_summarize_variables(
                discriminator_vars, 'dvars',
                hparams.tpu_params.use_tpu_estimator)

    logits.shape.assert_is_compatible_with([None, None])
    return logits
Example #2
0
    def test_generator_shapes_and_ranges(self):
        """Tests the discriminator's output shape and value range.

        Builds the discriminator on a random image batch with randomly sampled
        class labels, then checks that it yields one output per example, that
        every output lies in [-1, 1], and that the returned variable
        collection is a list.

        Note: despite the method name, the function under test is
        ``discriminator.discriminator``.
        """
        if tf.executing_eagerly():
            # `compute_spectral_norm` doesn't work when executing eagerly.
            return
        batch_size = 10
        num_classes = 1000
        df_dim = 16
        # Zero logits give a uniform distribution: one class id per example.
        gen_class_logits = tf.zeros((batch_size, num_classes))
        gen_class_ints = tf.random.categorical(gen_class_logits, 1)
        # Squeeze only the sample axis so the labels stay rank 1 even when
        # batch_size == 1.
        gen_sparse_class = tf.squeeze(gen_class_ints, axis=1)
        images = tf.random.normal([batch_size, 32, 32, 3])
        d_out, var_list = discriminator.discriminator(
            images, gen_sparse_class, df_dim, num_classes)
        # The context manager closes the session once the output is fetched;
        # the assertions below only need the resulting NumPy array.
        with tf.train.MonitoredTrainingSession() as sess:
            d_out_np = sess.run(d_out)
        self.assertEqual((batch_size, 1), d_out_np.shape)
        self.assertAllInRange(d_out_np, -1.0, 1.0)
        self.assertIsInstance(var_list, list)