def test_single_example_correct(self, mock_tfhub_load):
  """Identical all-real batches should have a Frechet distance near zero."""
  # Swap the TF-Hub classifier for the locally defined fake logit function.
  mock_tfhub_load.return_value = fake_logit_fn
  first_batch = tf.concat([real_digit()] * 2, 0)
  second_batch = tf.concat([real_digit()] * 2, 0)
  distance_op = util.mnist_frechet_distance(first_batch, second_batch)
  with self.cached_session() as session:
    distance = session.run(distance_op)
    self.assertNear(0.0, distance, 2e-1)
def calculate_fid(opts):
    """Compute and report the MNIST Frechet distance of a saved generator.

    Loads the generator stored at ``opts['--model-path']``. It then samples
    ``NUM_PICS_TO_CALCULATE_FID`` images from standard-normal latent codes of
    width ``LATENT_DIM``. These are compared against the same number of
    randomly shuffled MNIST test images via ``util.mnist_frechet_distance``.

    Args:
        opts: Mapping of command-line options. ``'--model-path'`` is required.
            If ``'--output'`` is present and truthy, the result line is also
            written to that path, replacing any existing file content.
    """
    # Previously a fresh model was built with get_generator_model() and then
    # immediately overwritten; the loaded model is the only one ever used.
    model = load_model(opts['--model-path'])

    # Real images: MNIST test split, scaled to [0, 1], with a trailing
    # channel axis appended.
    (_, _), (x_test, _) = mnist.load_data()
    x_test = x_test.astype(numpy.float32) / 255.0
    x_test = numpy.expand_dims(x_test, axis=3)

    # Keep the original RNG consumption order (latent draw before shuffle) so
    # seeded runs reproduce the same numbers.
    random_code = numpy.random.normal(
        size=(NUM_PICS_TO_CALCULATE_FID, LATENT_DIM))
    predicted_imgs = model.predict(random_code)
    numpy.random.shuffle(x_test)
    real_imgs = x_test[:NUM_PICS_TO_CALCULATE_FID]

    predicted_imgs = tf.convert_to_tensor(predicted_imgs, dtype=tf.float32)
    real_imgs = tf.convert_to_tensor(real_imgs, dtype=tf.float32)
    fid = util.mnist_frechet_distance(real_imgs, predicted_imgs)

    # Format once; float() is what "%f" applies to the (eager) scalar anyway.
    message = "CALCULATED FID: %f" % float(fid)
    print(message)

    output_path = opts.get('--output')
    if output_path:
        # Mode "w" truncates an existing file, so no prior os.remove is needed.
        with open(output_path, "w") as f:
            f.write(message)
def test_batch_splitting_doesnt_change_value(self):
  """The distance must not depend on how many batches the inputs are split into."""
  for split_count in (1, 2, 4, 8):
    mostly_real = tf.concat([real_digit()] * 6 + [fake_digit()] * 2, 0)
    mostly_fake = tf.concat([real_digit()] * 2 + [fake_digit()] * 6, 0)
    distance_op = util.mnist_frechet_distance(
        mostly_real, mostly_fake, num_batches=split_count)
    with self.cached_session() as session:
      observed = session.run(distance_op)
      self.assertNear(97.8, observed, 2e-1)
def test_minibatch_correct(self):
  """Checks the distance between a 2-real/1-fake and a 1-real/2-fake batch."""
  real_heavy = tf.concat([real_digit(), real_digit(), fake_digit()], 0)
  fake_heavy = tf.concat([real_digit(), fake_digit(), fake_digit()], 0)
  distance_op = util.mnist_frechet_distance(real_heavy, fake_heavy)
  with self.cached_session() as session:
    observed = session.run(distance_op)
    self.assertNear(43.5, observed, 2e-1)
def _disabled_test_minibatch_correct(self):
  """Checks mnist_frechet_distance on mixed real/fake minibatches.

  Kept out of test discovery (no `test` prefix) because it needs to load the
  TF-Hub MNIST classifier module.
  """
  real_heavy = tf.concat([real_digit(), real_digit(), fake_digit()], 0)
  fake_heavy = tf.concat([real_digit(), fake_digit(), fake_digit()], 0)
  distance_op = util.mnist_frechet_distance(real_heavy, fake_heavy)
  with self.cached_session() as session:
    observed = session.run(distance_op)
    self.assertNear(43.5, observed, 2e-1)
def test_single_example_correct(self):
  """Identical all-real batches should have a Frechet distance near zero."""
  # `run_image_classifier` doesn't work in eager, so only check in graph mode.
  if not tf.executing_eagerly():
    first_batch = tf.concat([real_digit()] * 2, 0)
    second_batch = tf.concat([real_digit()] * 2, 0)
    distance_op = util.mnist_frechet_distance(first_batch, second_batch)
    with self.cached_session() as session:
      observed = session.run(distance_op)
      self.assertNear(0.0, observed, 2e-1)
def test_minibatch_correct(self):
  """Checks the distance between a 2-real/1-fake and a 1-real/2-fake batch."""
  # `run_image_classifier` doesn't work in eager, so only check in graph mode.
  if not tf.executing_eagerly():
    real_heavy = tf.concat([real_digit(), real_digit(), fake_digit()], 0)
    fake_heavy = tf.concat([real_digit(), fake_digit(), fake_digit()], 0)
    distance_op = util.mnist_frechet_distance(real_heavy, fake_heavy)
    with self.cached_session() as session:
      observed = session.run(distance_op)
      self.assertNear(43.5, observed, 2e-1)
def test_any_batch_size(self):
  """The distance op should run for several fed batch sizes."""
  if tf.executing_eagerly():
    # Placeholders are graph-mode only; nothing to check under eager.
    return
  images_ph = tf.compat.v1.placeholder(tf.float32, shape=[None, 28, 28, 1])
  distance_op = util.mnist_frechet_distance(images_ph, images_ph)
  for size in (4, 16, 30):
    zeros = np.zeros([size, 28, 28, 1])
    with self.cached_session() as session:
      session.run(distance_op, feed_dict={images_ph: zeros})
def test_any_batch_size(self, mock_tfhub_load):
  """The distance op should run for several fed batch sizes (mocked hub)."""
  mock_tfhub_load.return_value = fake_logit_fn
  # Placeholders need graph mode, so build everything in an explicit graph.
  graph = tf.Graph()
  with graph.as_default():
    images_ph = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
    distance_op = util.mnist_frechet_distance(images_ph, images_ph)
    for size in (4, 16, 30):
      zeros = np.zeros([size, 28, 28, 1])
      with self.cached_session() as session:
        session.run(distance_op, feed_dict={images_ph: zeros})
def get_metrics(gan_model):
  """Return metrics for MNIST experiment.

  Args:
    gan_model: Object exposing `real_data` and `generated_data` image tensors.

  Returns:
    Dict mapping 'real_mnist_score', 'mnist_score' and 'frechet_distance' to
    the result of `tf.compat.v1.metrics.mean` over the corresponding value.
  """
  real = gan_model.real_data
  generated = gan_model.generated_data
  # Build the underlying ops in the same order as before: real score,
  # generated score, then the Frechet distance.
  per_metric = (
      ('real_mnist_score', util.mnist_score(real)),
      ('mnist_score', util.mnist_score(generated)),
      ('frechet_distance', util.mnist_frechet_distance(real, generated)),
  )
  return {name: tf.compat.v1.metrics.mean(value) for name, value in per_metric}
def test_deterministic(self):
  """Re-running the distance op, even in a re-entered session, must agree."""
  real_batch = tf.concat([real_digit()] * 2, 0)
  fake_batch = tf.concat([fake_digit()] * 2, 0)
  distance_op = util.mnist_frechet_distance(real_batch, fake_batch)
  with self.cached_session() as session:
    first, second = session.run(distance_op), session.run(distance_op)
    self.assertNear(first, second, 2e-1)
  with self.cached_session() as session:
    third = session.run(distance_op)
    self.assertNear(first, third, 2e-1)
def _disabled_test_batch_splitting_doesnt_change_value(self):
  """Checks the distance is independent of the number of batches.

  Kept out of test discovery (no `test` prefix) because it needs to load the
  TF-Hub MNIST classifier module.
  """
  with tf.Graph().as_default():
    for split_count in (1, 2, 4, 8):
      mostly_real = tf.concat([real_digit()] * 6 + [fake_digit()] * 2, 0)
      mostly_fake = tf.concat([real_digit()] * 2 + [fake_digit()] * 6, 0)
      distance_op = util.mnist_frechet_distance(
          mostly_real, mostly_fake, num_batches=split_count)
      with self.cached_session() as session:
        observed = session.run(distance_op)
        self.assertNear(97.8, observed, 2e-1)
def test_batch_splitting_doesnt_change_value(self):
  """The distance must not depend on how many batches the inputs are split into."""
  # `run_image_classifier` doesn't work in eager, so only check in graph mode.
  if not tf.executing_eagerly():
    for split_count in (1, 2, 4, 8):
      mostly_real = tf.concat([real_digit()] * 6 + [fake_digit()] * 2, 0)
      mostly_fake = tf.concat([real_digit()] * 2 + [fake_digit()] * 6, 0)
      distance_op = util.mnist_frechet_distance(
          mostly_real, mostly_fake, num_batches=split_count)
      with self.cached_session() as session:
        observed = session.run(distance_op)
        self.assertNear(97.8, observed, 2e-1)
def test_deterministic(self, mock_tfhub_load):
  """Re-running the distance op (mocked hub) must give consistent values."""
  mock_tfhub_load.return_value = fake_logit_fn
  real_batch = tf.concat([real_digit()] * 2, 0)
  fake_batch = tf.concat([fake_digit()] * 2, 0)
  distance_op = util.mnist_frechet_distance(real_batch, fake_batch)
  with self.cached_session() as session:
    first, second = session.run(distance_op), session.run(distance_op)
    self.assertNear(first, second, 2e-1)
  with self.cached_session() as session:
    third = session.run(distance_op)
    self.assertNear(first, third, 2e-1)
def evaluate(hparams, run_eval_loop=True):
  """Builds the MNIST evaluation graph and optionally runs the eval loop.

  Args:
    hparams: An HParams instance containing the eval hyperparameters. Fields
      read: num_images_generated, dataset_dir, eval_real_images,
      classifier_filename, noise_dims, write_to_disk, eval_dir,
      checkpoint_dir, max_number_of_evaluations.
    run_eval_loop: Whether to run the full eval loop. Set to False for testing,
      in which case only the graph is built.
  """
  num_images = hparams.num_images_generated

  # Fetch real images.
  with tf.compat.v1.name_scope('inputs'):
    real_images, _ = data_provider.provide_data(
        'train', num_images, hparams.dataset_dir)

  classifier = hparams.classifier_filename
  write_op = None
  if not hparams.eval_real_images:
    # In order for variables to load, use the same variable scope as in the
    # train job.
    with tf.compat.v1.variable_scope('Generator'):
      noise = tf.random.normal([num_images, hparams.noise_dims])
      images = networks.unconditional_generator(noise, is_training=False)

    frechet = util.mnist_frechet_distance(real_images, images, classifier)
    tf.compat.v1.summary.scalar('MNIST_Frechet_distance', frechet)
    generated_score = util.mnist_score(images, classifier)
    tf.compat.v1.summary.scalar('MNIST_Classifier_score', generated_score)

    # Optionally dump a 10x10 grid of the first 100 samples as a PNG.
    if num_images >= 100 and hparams.write_to_disk:
      grid = tfgan.eval.image_reshaper(images[:100, ...], num_cols=10)
      grid_uint8 = data_provider.float_image_to_uint8(grid)
      png_path = '%s/%s' % (hparams.eval_dir, 'unconditional_gan.png')
      write_op = tf.io.write_file(png_path, tf.image.encode_png(grid_uint8[0]))
  else:
    real_score = util.mnist_score(real_images, classifier)
    tf.compat.v1.summary.scalar('MNIST_Classifier_score', real_score)

  # For unit testing, use `run_eval_loop=False`.
  if run_eval_loop:
    eval_hooks = [
        evaluation.SummaryAtEndHook(hparams.eval_dir),
        evaluation.StopAfterNEvalsHook(1),
    ]
    evaluation.evaluate_repeatedly(
        hparams.checkpoint_dir,
        hooks=eval_hooks,
        eval_ops=write_op,
        max_number_of_evaluations=hparams.max_number_of_evaluations)
def test_deterministic(self):
  """Re-running the distance op, even in a re-entered session, must agree."""
  # `run_image_classifier` doesn't work in eager, so only check in graph mode.
  if not tf.executing_eagerly():
    real_batch = tf.concat([real_digit()] * 2, 0)
    fake_batch = tf.concat([fake_digit()] * 2, 0)
    distance_op = util.mnist_frechet_distance(real_batch, fake_batch)
    with self.cached_session() as session:
      first, second = session.run(distance_op), session.run(distance_op)
      self.assertNear(first, second, 2e-1)
    with self.cached_session() as session:
      third = session.run(distance_op)
      self.assertNear(first, third, 2e-1)
def get_mnist_eval_metrics(real, fake):
  """Returns a stacked [frechet_distance, mnist_score] tensor with no gradient.

  Args:
    real: Real image tensor passed to the Frechet distance.
    fake: Generated image tensor, used for both metrics.
  """
  # NOTE(review): the positional `1` is presumably `num_batches`; confirm
  # against the tfgan_mnist signature in use.
  frechet = tfgan_mnist.mnist_frechet_distance(real, fake, 1)
  score = tfgan_mnist.mnist_score(fake, 1)
  return tf.stack([tf.stop_gradient(metric) for metric in (frechet, score)])
def test_single_example_correct(self):
  """Identical all-real batches should have a Frechet distance near zero."""
  first_batch = tf.concat([real_digit()] * 2, 0)
  second_batch = tf.concat([real_digit()] * 2, 0)
  distance_op = util.mnist_frechet_distance(first_batch, second_batch)
  with self.cached_session() as session:
    observed = session.run(distance_op)
    self.assertNear(0.0, observed, 2e-1)