Example #1
0
 def test_single_example_correct(self, mock_tfhub_load):
   """Checks that two identical real-digit batches give a distance near 0."""
   # Make the injected mock hand back the fake logit function.
   mock_tfhub_load.return_value = fake_logit_fn
   first_batch = tf.concat([real_digit()] * 2, 0)
   second_batch = tf.concat([real_digit()] * 2, 0)
   distance = util.mnist_frechet_distance(first_batch, second_batch)
   with self.cached_session() as session:
     observed = session.run(distance)
     self.assertNear(0.0, observed, 2e-1)
Example #2
0
def calculate_fid(opts):
    """Compute the MNIST Frechet distance for a saved generator and report it.

    Loads the model from ``opts['--model-path']``, generates
    ``NUM_PICS_TO_CALCULATE_FID`` images from normally distributed latent
    codes, and compares them with the same number of shuffled MNIST test
    images. The result is printed and, when ``opts['--output']`` is set,
    also written to that file (replacing any existing file).

    Args:
        opts: Mapping of command-line options. ``'--model-path'`` is
            required; ``'--output'`` is optional.
    """
    weights_path = opts['--model-path']

    # NOTE(review): this first model is replaced on the very next line; the
    # call is kept in case get_generator_model() has side effects -- confirm.
    model = get_generator_model()
    model = load_model(weights_path)

    (_, _), (test_images, _) = mnist.load_data()

    # Scale pixel values by 1/255 and append a trailing axis.
    test_images = numpy.expand_dims(
        test_images.astype(numpy.float32) / 255.0, axis=3)

    latent_shape = (NUM_PICS_TO_CALCULATE_FID, LATENT_DIM)
    latent_codes = numpy.random.normal(size=latent_shape)

    generated = model.predict(latent_codes)
    # Shuffle in place so the slice below is a random subset of the test set.
    numpy.random.shuffle(test_images)
    real = test_images[:NUM_PICS_TO_CALCULATE_FID]

    generated = tf.convert_to_tensor(generated, dtype=tf.float32)
    real = tf.convert_to_tensor(real, dtype=tf.float32)

    fid = util.mnist_frechet_distance(real, generated)
    report = "CALCULATED FID: %f" % fid
    print(report)

    output_path = opts.get('--output', None)
    if not output_path:
        return

    if os.path.exists(output_path):
        os.remove(output_path)

    with open(output_path, "w") as out_file:
        out_file.write(report)
Example #3
0
 def test_batch_splitting_doesnt_change_value(self):
     """Checks the distance stays near 97.8 for each tested `num_batches`."""
     for batch_count in (1, 2, 4, 8):
         mostly_real = tf.concat([real_digit()] * 6 + [fake_digit()] * 2, 0)
         mostly_fake = tf.concat([real_digit()] * 2 + [fake_digit()] * 6, 0)
         distance = util.mnist_frechet_distance(
             mostly_real, mostly_fake, num_batches=batch_count)
         with self.cached_session() as session:
             observed = session.run(distance)
             self.assertNear(97.8, observed, 2e-1)
Example #4
0
 def test_minibatch_correct(self):
     """Checks the distance between two mixed real/fake batches is ~43.5."""
     real_heavy = tf.concat([real_digit(), real_digit(), fake_digit()], 0)
     fake_heavy = tf.concat([real_digit(), fake_digit(), fake_digit()], 0)
     distance = util.mnist_frechet_distance(real_heavy, fake_heavy)
     with self.cached_session() as session:
         observed = session.run(distance)
         self.assertNear(43.5, observed, 2e-1)
Example #5
0
 def _disabled_test_minibatch_correct(self):
   """Checks mnist_frechet_distance on mixed real/fake batches (~43.5)."""
   # Disabled since it requires loading the tfhub MNIST module.
   real_heavy = tf.concat([real_digit(), real_digit(), fake_digit()], 0)
   fake_heavy = tf.concat([real_digit(), fake_digit(), fake_digit()], 0)
   distance = util.mnist_frechet_distance(real_heavy, fake_heavy)
   with self.cached_session() as session:
     observed = session.run(distance)
     self.assertNear(43.5, observed, 2e-1)
Example #6
0
 def test_single_example_correct(self):
   """Checks that two identical real-digit batches give a distance near 0.

   Skipped under eager execution, where `run_image_classifier` doesn't work.
   """
   if tf.executing_eagerly():
     # Report a skip instead of returning, so the test runner does not count
     # an unexecuted test as a pass.
     self.skipTest("`run_image_classifier` doesn't work in eager.")
   fdistance = util.mnist_frechet_distance(
       tf.concat([real_digit()] * 2, 0),
       tf.concat([real_digit()] * 2, 0))
   with self.cached_session() as sess:
     self.assertNear(0.0, sess.run(fdistance), 2e-1)
Example #7
0
 def test_minibatch_correct(self):
   """Checks the distance between two mixed real/fake batches is ~43.5.

   Skipped under eager execution, where `run_image_classifier` doesn't work.
   """
   if tf.executing_eagerly():
     # Report a skip instead of returning, so the test runner does not count
     # an unexecuted test as a pass.
     self.skipTest("`run_image_classifier` doesn't work in eager.")
   fdistance = util.mnist_frechet_distance(
       tf.concat([real_digit(), real_digit(), fake_digit()], 0),
       tf.concat([real_digit(), fake_digit(), fake_digit()], 0))
   with self.cached_session() as sess:
     self.assertNear(43.5, sess.run(fdistance), 2e-1)
Example #8
0
 def test_any_batch_size(self):
   """Runs the distance op with several batch sizes fed to one placeholder.

   Only checks that evaluation succeeds; no value is asserted. Skipped under
   eager execution because placeholders are unavailable there.
   """
   if tf.executing_eagerly():
     # Report a skip instead of returning, so the test runner does not count
     # an unexecuted test as a pass.
     self.skipTest("Placeholders don't work in eager execution mode.")
   # The leading dimension is left as None so any batch size can be fed.
   inputs = tf.compat.v1.placeholder(tf.float32, shape=[None, 28, 28, 1])
   fdistance = util.mnist_frechet_distance(inputs, inputs)
   for batch_size in [4, 16, 30]:
     with self.cached_session() as sess:
       sess.run(fdistance,
                feed_dict={inputs: np.zeros([batch_size, 28, 28, 1])})
Example #9
0
 def test_any_batch_size(self, mock_tfhub_load):
   """Runs the distance op with several batch sizes fed to one placeholder.

   Only checks that evaluation succeeds; no value is asserted.
   """
   # Make the injected mock hand back the fake logit function.
   mock_tfhub_load.return_value = fake_logit_fn
   # Create a graph since placeholders don't work in eager execution mode.
   with tf.Graph().as_default():
     images = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
     distance = util.mnist_frechet_distance(images, images)
     for n in (4, 16, 30):
       with self.cached_session() as session:
         zeros = np.zeros([n, 28, 28, 1])
         session.run(distance, feed_dict={images: zeros})
Example #10
0
def get_metrics(gan_model):
    """Return metrics for MNIST experiment.

    Computes the MNIST score of the real and generated data and the Frechet
    distance between them, then wraps each in `tf.compat.v1.metrics.mean`.

    Args:
        gan_model: Object exposing `real_data` and `generated_data`.

    Returns:
        A dict with keys 'real_mnist_score', 'mnist_score' and
        'frechet_distance', each mapped to the result of
        `tf.compat.v1.metrics.mean`.
    """
    real = gan_model.real_data
    generated = gan_model.generated_data
    # Build all raw values first, then the mean metrics, in a fixed order.
    named_values = (
        ('real_mnist_score', util.mnist_score(real)),
        ('mnist_score', util.mnist_score(generated)),
        ('frechet_distance', util.mnist_frechet_distance(real, generated)),
    )
    mean_metric = tf.compat.v1.metrics.mean
    return {name: mean_metric(value) for name, value in named_values}
Example #11
0
    def test_deterministic(self):
        """Checks repeated evaluations of one distance op agree within 0.2."""
        real_batch = tf.concat([real_digit()] * 2, 0)
        fake_batch = tf.concat([fake_digit()] * 2, 0)
        distance = util.mnist_frechet_distance(real_batch, fake_batch)

        # Two evaluations inside the same session context.
        with self.cached_session() as session:
            first_run, second_run = [session.run(distance) for _ in range(2)]
        self.assertNear(first_run, second_run, 2e-1)

        # One more evaluation after re-entering the session context.
        with self.cached_session() as session:
            third_run = session.run(distance)
        self.assertNear(first_run, third_run, 2e-1)
Example #12
0
 def _disabled_test_batch_splitting_doesnt_change_value(self):
     """Checks mnist_frechet_distance is ~97.8 for each `num_batches`."""
     # Disabled since it requires loading the tfhub MNIST module.
     with tf.Graph().as_default():
         for batch_count in (1, 2, 4, 8):
             mostly_real = tf.concat(
                 [real_digit()] * 6 + [fake_digit()] * 2, 0)
             mostly_fake = tf.concat(
                 [real_digit()] * 2 + [fake_digit()] * 6, 0)
             distance = util.mnist_frechet_distance(
                 mostly_real, mostly_fake, num_batches=batch_count)
             with self.cached_session() as session:
                 observed = session.run(distance)
                 self.assertNear(97.8, observed, 2e-1)
Example #13
0
 def test_batch_splitting_doesnt_change_value(self):
   """Checks the distance stays near 97.8 for each tested `num_batches`.

   Skipped under eager execution, where `run_image_classifier` doesn't work.
   """
   if tf.executing_eagerly():
     # Report a skip instead of returning, so the test runner does not count
     # an unexecuted test as a pass.
     self.skipTest("`run_image_classifier` doesn't work in eager.")
   for num_batches in [1, 2, 4, 8]:
     fdistance = util.mnist_frechet_distance(
         tf.concat([real_digit()] * 6 + [fake_digit()] * 2, 0),
         tf.concat([real_digit()] * 2 + [fake_digit()] * 6, 0),
         num_batches=num_batches)
     with self.cached_session() as sess:
       self.assertNear(97.8, sess.run(fdistance), 2e-1)
Example #14
0
    def test_deterministic(self, mock_tfhub_load):
        """Checks repeated evaluations of one distance op agree within 0.2."""
        # Make the injected mock hand back the fake logit function.
        mock_tfhub_load.return_value = fake_logit_fn
        real_batch = tf.concat([real_digit()] * 2, 0)
        fake_batch = tf.concat([fake_digit()] * 2, 0)
        distance = util.mnist_frechet_distance(real_batch, fake_batch)

        # Two evaluations inside the same session context.
        with self.cached_session() as session:
            first_run, second_run = [session.run(distance) for _ in range(2)]
        self.assertNear(first_run, second_run, 2e-1)

        # One more evaluation after re-entering the session context.
        with self.cached_session() as session:
            third_run = session.run(distance)
        self.assertNear(first_run, third_run, 2e-1)
Example #15
0
def evaluate(hparams, run_eval_loop=True):
    """Builds the MNIST eval summaries and optionally runs the eval loop.

    When `hparams.eval_real_images` is set, only the classifier score of the
    real images is summarized. Otherwise images are generated, their Frechet
    distance and classifier score are summarized, and (if at least 100 images
    are generated and `hparams.write_to_disk` is set) a PNG grid of the first
    100 is written to `hparams.eval_dir`.

    Args:
        hparams: An HParams instance containing the eval hyperparameters.
        run_eval_loop: Whether to run the full eval loop. Set to False for
            testing.
    """
    add_scalar = tf.compat.v1.summary.scalar

    # Fetch real images.
    with tf.compat.v1.name_scope('inputs'):
        real_images, _ = data_provider.provide_data(
            'train', hparams.num_images_generated, hparams.dataset_dir)

    image_write_ops = None
    if hparams.eval_real_images:
        real_score = util.mnist_score(real_images,
                                      hparams.classifier_filename)
        add_scalar('MNIST_Classifier_score', real_score)
    else:
        # In order for variables to load, use the same variable scope as in
        # the train job.
        with tf.compat.v1.variable_scope('Generator'):
            noise = tf.random.normal(
                [hparams.num_images_generated, hparams.noise_dims])
            images = networks.unconditional_generator(noise,
                                                      is_training=False)

        frechet = util.mnist_frechet_distance(real_images, images,
                                              hparams.classifier_filename)
        add_scalar('MNIST_Frechet_distance', frechet)

        generated_score = util.mnist_score(images,
                                           hparams.classifier_filename)
        add_scalar('MNIST_Classifier_score', generated_score)

        if hparams.num_images_generated >= 100 and hparams.write_to_disk:
            # Arrange the first 100 images into a grid with 10 columns.
            grid = tfgan.eval.image_reshaper(images[:100, ...], num_cols=10)
            grid_uint8 = data_provider.float_image_to_uint8(grid)
            png_path = '%s/%s' % (hparams.eval_dir, 'unconditional_gan.png')
            image_write_ops = tf.io.write_file(
                png_path, tf.image.encode_png(grid_uint8[0]))

    # For unit testing, use `run_eval_loop=False`.
    if run_eval_loop:
        eval_hooks = [
            evaluation.SummaryAtEndHook(hparams.eval_dir),
            evaluation.StopAfterNEvalsHook(1),
        ]
        evaluation.evaluate_repeatedly(
            hparams.checkpoint_dir,
            hooks=eval_hooks,
            eval_ops=image_write_ops,
            max_number_of_evaluations=hparams.max_number_of_evaluations)
Example #16
0
  def test_deterministic(self):
    """Checks repeated evaluations of one distance op agree within 0.2.

    Skipped under eager execution, where `run_image_classifier` doesn't work.
    """
    if tf.executing_eagerly():
      # Report a skip instead of returning, so the test runner does not count
      # an unexecuted test as a pass.
      self.skipTest("`run_image_classifier` doesn't work in eager.")
    fdistance = util.mnist_frechet_distance(
        tf.concat([real_digit()] * 2, 0),
        tf.concat([fake_digit()] * 2, 0))
    # Two evaluations inside the same session context.
    with self.cached_session() as sess:
      fdistance1 = sess.run(fdistance)
      fdistance2 = sess.run(fdistance)
    self.assertNear(fdistance1, fdistance2, 2e-1)

    # One more evaluation after re-entering the session context.
    with self.cached_session() as sess:
      fdistance3 = sess.run(fdistance)
    self.assertNear(fdistance1, fdistance3, 2e-1)
Example #17
0
 def get_mnist_eval_metrics(real, fake):
     """Return the MNIST Frechet distance and score stacked into one tensor.

     Both values are passed through `tf.stop_gradient` before stacking, with
     the Frechet distance first and the score of `fake` second.
     """
     frechet_distance = tfgan_mnist.mnist_frechet_distance(real, fake, 1)
     fake_score = tfgan_mnist.mnist_score(fake, 1)
     return tf.stack(
         [tf.stop_gradient(frechet_distance), tf.stop_gradient(fake_score)])
Example #18
0
 def test_single_example_correct(self):
     """Checks that two identical real-digit batches give a distance near 0."""
     first_batch = tf.concat([real_digit()] * 2, 0)
     second_batch = tf.concat([real_digit()] * 2, 0)
     distance = util.mnist_frechet_distance(first_batch, second_batch)
     with self.cached_session() as session:
         observed = session.run(distance)
         self.assertNear(0.0, observed, 2e-1)