def test_any_batch_size(self):
    # Smoke test: the distance op is built once on a placeholder whose
    # leading (batch) dimension is None, then run with several concrete
    # batch sizes. No value is asserted; the run only has to complete.
    images = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
    distance_op = util.mnist_frechet_distance(images, images)
    for n_examples in (4, 16, 30):
        zero_batch = np.zeros([n_examples, 28, 28, 1])
        with self.test_session() as sess:
            sess.run(distance_op, feed_dict={images: zero_batch})
def test_batch_splitting_doesnt_change_value(self):
    # The same pair of 8-element batches must give the same distance
    # (97.8 within 0.2) whichever `num_batches` value is passed.
    for split_count in (1, 2, 4, 8):
        # Each helper is called once per batch and its result repeated,
        # exactly as many times as in the expected-value setup:
        # 6 real + 2 fake versus 2 real + 6 fake, joined along axis 0.
        mostly_real = tf.concat([real_digit()] * 6 + [fake_digit()] * 2, 0)
        mostly_fake = tf.concat([real_digit()] * 2 + [fake_digit()] * 6, 0)
        distance_op = util.mnist_frechet_distance(
            mostly_real, mostly_fake, num_batches=split_count)
        with self.test_session():
            observed = distance_op.eval()
            self.assertNear(97.8, observed, 2e-1)
def test_minibatch_correct(self):
    # Two 3-element batches (real, real, fake) and (real, fake, fake),
    # joined along axis 0, are expected to be 43.5 apart (within 0.2).
    first_batch = tf.concat([real_digit(), real_digit(), fake_digit()], 0)
    second_batch = tf.concat([real_digit(), fake_digit(), fake_digit()], 0)
    distance_op = util.mnist_frechet_distance(first_batch, second_batch)
    with self.test_session():
        observed = distance_op.eval()
        self.assertNear(43.5, observed, 2e-1)
def test_deterministic(self):
    # Evaluating the same distance op repeatedly, both within one session
    # and again in a fresh session, must give values within 0.2 of each
    # other.
    real_pair = tf.concat([real_digit()] * 2, 0)
    fake_pair = tf.concat([fake_digit()] * 2, 0)
    distance_op = util.mnist_frechet_distance(real_pair, fake_pair)

    # Two evaluations inside the same session.
    with self.test_session():
        first_value = distance_op.eval()
        second_value = distance_op.eval()
    self.assertNear(first_value, second_value, 2e-1)

    # One more evaluation in a separate session.
    with self.test_session():
        third_value = distance_op.eval()
    self.assertNear(first_value, third_value, 2e-1)
def test_single_example_correct(self):
    # Both arguments are built the same way (the real_digit() result
    # repeated twice along axis 0), so the distance is expected to be
    # 0.0 within 0.2.
    left_batch = tf.concat([real_digit()] * 2, 0)
    right_batch = tf.concat([real_digit()] * 2, 0)
    distance_op = util.mnist_frechet_distance(left_batch, right_batch)
    with self.test_session():
        observed = distance_op.eval()
        self.assertNear(0.0, observed, 2e-1)