Ejemplo n.º 1
0
def prepare_metric_arguments(generator_inputs, generated_data, real_data,
                             discriminator_real_outputs,
                             discriminator_gen_outputs):
  """Builds the activation tensors that are later turned into metrics.

  Real and generated images are each passed through
  `eval_lib.get_activations` (one batch, with `get_logits=True`), and the
  resulting logits and pools are returned under fixed keys.

  When training on TPUs, this function should be executed on TPU.

  Args:
    generator_inputs: Inputs to the generator fn. Unused.
    generated_data: Output from the generator. Either the images themselves
      or a dict holding them under the 'images' key.
    real_data: A sample of real data. Either the images themselves or a dict
      holding them under the 'images' key.
    discriminator_real_outputs: Discriminator output on real data. Unused.
    discriminator_gen_outputs: Discriminator output on generated data. Unused.

  Returns:
    A dict with the keys 'real_logits', 'real_pools', 'fake_logits' and
    'fake_pools'.
  """
  # These arguments are part of the expected signature but play no role here.
  del generator_inputs, discriminator_real_outputs, discriminator_gen_outputs

  # Unwrap the image batches when the data arrives as a feature dict.
  if isinstance(real_data, dict):
    real_batch = real_data['images']
  else:
    real_batch = real_data

  if isinstance(generated_data, dict):
    fake_batch = generated_data['images']
  else:
    fake_batch = generated_data

  # Both calls share the same settings; real images are processed first.
  activation_kwargs = dict(num_batches=1, get_logits=True)
  logits_real, pools_real = eval_lib.get_activations(
      lambda: real_batch, **activation_kwargs)
  logits_fake, pools_fake = eval_lib.get_activations(
      lambda: fake_batch, **activation_kwargs)

  metric_arguments = {}
  metric_arguments['real_logits'] = logits_real
  metric_arguments['real_pools'] = pools_real
  metric_arguments['fake_logits'] = logits_fake
  metric_arguments['fake_pools'] = pools_fake
  return metric_arguments
Ejemplo n.º 2
0
def get_metrics(generator_inputs, generated_data, real_data,
                discriminator_real_outputs, discriminator_gen_outputs,
                hparams):
    """Return metrics for SAGAN experiment on TPU, CPU, or GPU.

    Real and generated images are passed through `eval_lib.get_activations`
    and the resulting logits / pools are fed to the TF-GAN classifier-score
    and Frechet-distance functions. Which variant of those functions is used
    depends on `hparams.eval_on_tpu`.

    Args:
      generator_inputs: Inputs to the generator fn. Unused.
      generated_data: Output from the generator. Either the images themselves
        or a dict holding them under the 'images' key.
      real_data: A sample of real data. Either the images themselves or a
        dict holding them under the 'images' key.
      discriminator_real_outputs: Discriminator output on real data. Unused.
      discriminator_gen_outputs: Discriminator output on generated data.
        Unused.
      hparams: An hparams object. Only `hparams.eval_on_tpu` is read.

    Returns:
      A metric dictionary with the keys 'eval/real_incscore', 'eval/incscore'
      and 'eval/fid'. When `hparams.eval_on_tpu` is truthy, each value is the
      corresponding score tiled to length `batch_size` (the leading dimension
      of the generated images). Otherwise each value is whatever the
      `*_streaming` TF-GAN function returns, and the dictionary additionally
      contains the entries produced by `_generator_summary_ops`.
    """
    # These arguments are part of the expected signature but play no role here.
    del generator_inputs, discriminator_real_outputs, discriminator_gen_outputs

    # Accept both a feature dict and bare images, matching
    # prepare_metric_arguments. Dict inputs behave exactly as before.
    real_images = (real_data['images'] if isinstance(real_data, dict) else
                   real_data)
    gen_images = (generated_data['images']
                  if isinstance(generated_data, dict) else generated_data)

    # Get logits and pools for real and generated images.
    real_logits, real_pools = eval_lib.get_activations(lambda: real_images,
                                                       num_batches=1,
                                                       get_logits=True)
    fake_logits, fake_pools = eval_lib.get_activations(lambda: gen_images,
                                                       num_batches=1,
                                                       get_logits=True)

    if hparams.eval_on_tpu:
        # TODO(dyoel): Rewrite once b/135664219 is resolved.
        real_iscore = tfgan.eval.classifier_score_from_logits(real_logits)
        generated_iscore = tfgan.eval.classifier_score_from_logits(fake_logits)
        fid = tfgan.eval.frechet_classifier_distance_from_activations(
            real_pools, fake_pools)
        # Tile metrics because TPU requires metric outputs to be [batch_size, ...].
        batch_size = tf.shape(input=gen_images)[0]
        real_iscore_tiled = tf.tile([real_iscore], [batch_size])
        generated_iscore_tiled = tf.tile([generated_iscore], [batch_size])
        frechet_distance_tiled = tf.tile([fid], [batch_size])
        return {
            'eval/real_incscore': real_iscore_tiled,
            'eval/incscore': generated_iscore_tiled,
            'eval/fid': frechet_distance_tiled,
        }

    # Off-TPU path: use the streaming variants and attach the generator
    # summary ops.
    metric_dict = {
        'eval/real_incscore':
        tfgan.eval.classifier_score_from_logits_streaming(real_logits),
        'eval/incscore':
        tfgan.eval.classifier_score_from_logits_streaming(fake_logits),
        'eval/fid':
        tfgan.eval.frechet_classifier_distance_from_activations_streaming(
            real_pools, fake_pools),
    }
    metric_dict.update(_generator_summary_ops(gen_images, real_images))
    return metric_dict