def discriminator(images_and_lbls, unused_conditioning, mode):
  """TF-GAN compatible discriminator.

  Args:
    images_and_lbls: Mapping holding the input batch under the 'images' key
      and its class labels under the 'labels' key.
    unused_conditioning: Ignored; accepted only to match the TF-GAN
      discriminator signature.
    mode: Ignored.

  Returns:
    A rank-2 logits tensor. With `hparams.debug_params.fake_nets` enabled it is
    an all-zero tensor of shape [batch, 20]; otherwise it is whatever
    `dis_module.discriminator` produces.
  """
  del unused_conditioning, mode
  images = images_and_lbls['images']
  labels = images_and_lbls['labels']

  if not hparams.debug_params.fake_nets:
    vars_before_call = len(tf.compat.v1.trainable_variables())
    logits, discriminator_vars = dis_module.discriminator(
        images, labels, hparams.df_dim, hparams.num_classes)
    # Only the call that actually creates variables logs them. The function
    # runs twice (generated data, then real data), and the second call reuses
    # the first call's variables.
    created_variables = vars_before_call != len(
        tf.compat.v1.trainable_variables())
    if created_variables:
      eval_lib.log_and_summarize_variables(
          discriminator_vars, 'dvars', hparams.tpu_params.use_tpu_estimator)
  else:
    # Debug path: the output is always zero, but it still owns a variable
    # ('dummy_d') and is data-dependent on `images` (i.e. on the generator).
    batch_size = tf.shape(input=images)[0]
    zero_logits = tf.zeros([batch_size, 20])
    dummy_variable = tf.compat.v1.get_variable('dummy_d', initializer=2.0)
    logits = zero_logits * dummy_variable * tf.reduce_mean(
        input_tensor=images)

  logits.shape.assert_is_compatible_with([None, None])
  return logits
def test_generator_shapes_and_ranges(self):
  """Tests the discriminator's output shape, value range and variable list.

  Despite the method name, this exercises `discriminator.discriminator`. It
  builds the graph for a random batch of 32x32x3 images with uniformly random
  class labels and evaluates the output once. It then checks that the output
  has shape `(batch_size, 1)`, that every value lies in [-1, 1], and that the
  returned variable collection is a list.
  """
  if tf.executing_eagerly():
    # `compute_spectral_norm` doesn't work when executing eagerly, so the
    # graph-mode session below is unusable. Report a skip rather than letting
    # the test pass without checking anything.
    self.skipTest('`compute_spectral_norm` does not support eager execution.')
  batch_size = 10
  num_classes = 1000
  df_dim = 16
  # All-zero logits give a uniform distribution over classes. The result of
  # `tf.random.categorical` has shape [batch_size, num_samples].
  gen_class_logits = tf.zeros((batch_size, num_classes))
  gen_class_ints = tf.random.categorical(
      logits=gen_class_logits, num_samples=1)
  # Squeeze only the sample axis so the labels stay rank 1 for any batch size.
  gen_sparse_class = tf.squeeze(gen_class_ints, axis=1)
  images = tf.random.normal([batch_size, 32, 32, 3])
  d_out, var_list = discriminator.discriminator(
      images, gen_sparse_class, df_dim, num_classes)
  # The context manager closes the session when evaluation is done.
  with tf.compat.v1.train.MonitoredTrainingSession() as sess:
    d_out_np = sess.run(d_out)
  self.assertEqual((batch_size, 1), d_out_np.shape)
  self.assertAllInRange(d_out_np, -1.0, 1.0)
  self.assertIsInstance(var_list, list)