def test_emnist_frechet_distance(self):
    """Checks `eeu.emnist_frechet_distance` against fixed expectations.

    First compares `self.real_images` with `self.fake_images` against a pinned
    reference value, then checks that an image set compared with itself has a
    distance of 0.
    """
    # Load the trained classifier once and reuse it for both comparisons
    # instead of calling the loader a second time.
    classifier = ecm.get_trained_emnist_classifier_model()

    distance = eeu.emnist_frechet_distance(
        self.real_images, self.fake_images, classifier)
    self.assertAllClose(distance, 568.6883, rtol=0.0001, atol=0.0001)

    # Identical image sets are expected to be at zero distance.
    distance = eeu.emnist_frechet_distance(
        self.real_images, self.real_images, classifier)
    self.assertAllClose(distance, 0.0)
Example #2
0
def _compute_eval_metrics(generator, discriminator, gen_inputs, real_images,
                          gan_loss_fns, emnist_classifier):
    """Evaluates the GAN on one batch and gathers summary metrics.

    Args:
      generator: Generator model, invoked as `generator(x, training=False)`.
      discriminator: Discriminator model, invoked on both the real images and
        the generated images with `training=False`.
      gen_inputs: Inputs fed to the generator (and to the loss functions).
      real_images: Batch of real images.
      gan_loss_fns: Object providing `generator_loss` and `discriminator_loss`.
      emnist_classifier: Classifier handed to `eeu.emnist_score` and
        `eeu.emnist_frechet_distance`.

    Returns:
      A `collections.OrderedDict` mapping metric names to their computed
      values, in a fixed insertion order.
    """
    fake_images = generator(gen_inputs, training=False)
    real_disc_out = discriminator(real_images, training=False)
    fake_disc_out = discriminator(fake_images, training=False)

    # Mean discriminator output on each kind of input.
    mean_real_logit = tf.reduce_mean(real_disc_out)
    mean_fake_logit = tf.reduce_mean(fake_disc_out)

    generator_loss = gan_loss_fns.generator_loss(
        generator, discriminator, gen_inputs)
    discriminator_loss = gan_loss_fns.discriminator_loss(
        generator, discriminator, gen_inputs, real_images)

    # Classifier-based quality measures of the generated images.
    score = eeu.emnist_score(fake_images, emnist_classifier)
    fcd = eeu.emnist_frechet_distance(
        real_images, fake_images, emnist_classifier)

    # Weight-magnitude diagnostics for both networks.
    gen_trainable = tf.linalg.global_norm(generator.trainable_variables)
    disc_trainable = tf.linalg.global_norm(discriminator.trainable_variables)
    gen_frozen = tf.linalg.global_norm(generator.non_trainable_variables)
    disc_frozen = tf.linalg.global_norm(discriminator.non_trainable_variables)

    entries = [
        ('real_data_logits', mean_real_logit),
        ('gen_data_logits', mean_fake_logit),
        ('gen trainable norm', gen_trainable),
        ('disc trainable norm', disc_trainable),
        ('gen non-trainable norm', gen_frozen),
        ('disc non-trainable norm', disc_frozen),
        ('gen_loss', generator_loss),
        ('disc_loss', discriminator_loss),
        ('classifier_score', score),
        ('frechet_classifier_distance', fcd),
    ]
    return collections.OrderedDict(entries)
Example #3
0
def _float32_global_norm(tensors):
  """Returns the global norm of `tensors` after casting each one to float32.

  The cast lets the norm be taken over lists whose entries are not all
  float32 (presumably e.g. integer optimizer counters -- confirm).
  """
  return tf.linalg.global_norm([tf.cast(t, tf.float32) for t in tensors])


def _compute_eval_metrics(generator, discriminator, gen_inputs, real_images,
                          gan_loss_fns, emnist_classifier, server_state):
  """Computes eval metrics for the GAN.

  Args:
    generator: Generator model, invoked as `generator(x, training=False)`.
    discriminator: Discriminator model, invoked on real and generated images.
    gen_inputs: Inputs fed to the generator (and to the loss functions).
    real_images: Batch of real images.
    gan_loss_fns: Object providing `generator_loss` and `discriminator_loss`.
    emnist_classifier: Classifier handed to `eeu.emnist_score` and
      `eeu.emnist_frechet_distance`.
    server_state: Object exposing `state_gen_optimizer_weights`,
      `state_disc_optimizer_weights`, `generator_diff` and
      `discriminator_diff`, each an iterable of tensors whose norms are
      reported.

  Returns:
    A `collections.OrderedDict` mapping metric names to their computed
    values, in a fixed insertion order.
  """
  # NOTE(review): only the generator forward pass is pinned to CPU; the
  # reason is not visible in this function.
  with tf.device("/cpu:0"):
    gen_images = generator(gen_inputs, training=False)
  disc_on_real_images = discriminator(real_images, training=False)
  disc_on_gen_outputs = discriminator(gen_images, training=False)
  real_data_logits = tf.reduce_mean(disc_on_real_images)
  gen_data_logits = tf.reduce_mean(disc_on_gen_outputs)

  gen_loss = gan_loss_fns.generator_loss(generator, discriminator, gen_inputs)
  disc_loss = gan_loss_fns.discriminator_loss(generator, discriminator,
                                              gen_inputs, real_images)
  classifier_score = eeu.emnist_score(gen_images, emnist_classifier)
  frechet_classifier_distance = eeu.emnist_frechet_distance(
      real_images, gen_images, emnist_classifier)

  metrics = collections.OrderedDict([
      ('real_data_logits', real_data_logits),
      ('gen_data_logits', gen_data_logits),
      ('gen trainable norm',
       tf.linalg.global_norm(generator.trainable_variables)),
      ('disc trainable norm',
       tf.linalg.global_norm(discriminator.trainable_variables)),
      ('gen non-trainable norm',
       tf.linalg.global_norm(generator.non_trainable_variables)),
      ('disc non-trainable norm',
       tf.linalg.global_norm(discriminator.non_trainable_variables)),
      ('gen_loss', gen_loss),
      ('disc_loss', disc_loss),
      ('classifier_score', classifier_score),
      ('frechet_classifier_distance', frechet_classifier_distance),
      # Norms of server-side optimizer state and model updates.
      ('gen_opt_state',
       _float32_global_norm(server_state.state_gen_optimizer_weights)),
      ('disc_opt_state',
       _float32_global_norm(server_state.state_disc_optimizer_weights)),
      ('gen_diff', _float32_global_norm(server_state.generator_diff)),
      ('disc_diff', _float32_global_norm(server_state.discriminator_diff)),
  ])
  return metrics