示例#1
0
 def test_any_batch_size(self):
   """Run the distance op on inputs whose batch size is only known at run time.

   The placeholder leaves the leading dimension as None, and one graph is
   evaluated on all-zero batches of size 4, 16 and 30. The test passes as long
   as every `run` call completes without raising; the value is not checked.
   """
   images = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
   distance_op = util.mnist_frechet_distance(images, images)
   for size in (4, 16, 30):
     zero_batch = np.zeros([size, 28, 28, 1])
     with self.test_session() as session:
       session.run(distance_op, feed_dict={images: zero_batch})
示例#2
0
 def test_any_batch_size(self):
   """Evaluate the distance on zero-filled batches of several sizes.

   The input placeholder has an unspecified (None) leading dimension, so a
   single graph is reused for batch sizes 4, 16 and 30. Nothing is asserted
   beyond the session run finishing without an error.
   """
   placeholder = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
   # The input is compared against itself; only the ability to run matters.
   fdistance = util.mnist_frechet_distance(placeholder, placeholder)
   batch_sizes = [4, 16, 30]
   for batch_size in batch_sizes:
     feed = {placeholder: np.zeros([batch_size, 28, 28, 1])}
     with self.test_session() as sess:
       sess.run(fdistance, feed_dict=feed)
示例#3
0
 def test_batch_splitting_doesnt_change_value(self):
   """Check that the `num_batches` setting does not move the computed distance.

   For each `num_batches` in 1, 2, 4 and 8, a fresh distance op is built on the
   same inputs: 6 real + 2 fake digit tensors against 2 real + 6 fake digit
   tensors. Each op must evaluate to within 0.2 of the hard-coded expected
   value 97.8.
   """
   for split_count in [1, 2, 4, 8]:
     mostly_real = tf.concat([real_digit()] * 6 + [fake_digit()] * 2, 0)
     mostly_fake = tf.concat([real_digit()] * 2 + [fake_digit()] * 6, 0)
     fdistance = util.mnist_frechet_distance(
         mostly_real, mostly_fake, num_batches=split_count)
     with self.test_session():
       self.assertNear(97.8, fdistance.eval(), 2e-1)
示例#4
0
 def test_batch_splitting_doesnt_change_value(self):
   """Verify the distance is the same for several `num_batches` values.

   The two inputs concatenate 6 real + 2 fake and 2 real + 6 fake digit
   tensors. The distance op is rebuilt for `num_batches` 1, 2, 4 and 8, and
   each result is compared with the hard-coded value 97.8 (tolerance 0.2).
   """
   expected, tolerance = 97.8, 2e-1
   for num_batches in (1, 2, 4, 8):
     first = tf.concat([real_digit()] * 6 + [fake_digit()] * 2, 0)
     second = tf.concat([real_digit()] * 2 + [fake_digit()] * 6, 0)
     fdistance = util.mnist_frechet_distance(
         first, second, num_batches=num_batches)
     with self.test_session():
       actual = fdistance.eval()
       self.assertNear(expected, actual, tolerance)
示例#5
0
 def test_minibatch_correct(self):
   """Check the distance between two 3-element mixes of real and fake digits.

   The first input concatenates (real, real, fake) and the second
   (real, fake, fake). The result must be within 0.2 of the hard-coded
   expected value 43.5.
   """
   first_mix = tf.concat([real_digit(), real_digit(), fake_digit()], 0)
   second_mix = tf.concat([real_digit(), fake_digit(), fake_digit()], 0)
   fdistance = util.mnist_frechet_distance(first_mix, second_mix)
   with self.test_session():
     self.assertNear(43.5, fdistance.eval(), 2e-1)
示例#6
0
    def test_deterministic(self):
        """Check that re-evaluating the same distance op gives a stable value.

        The op is evaluated twice in one session and once more in a new
        session. Both later values must be within 0.2 of the first.
        """
        real_pair = tf.concat([real_digit()] * 2, 0)
        fake_pair = tf.concat([fake_digit()] * 2, 0)
        fdistance = util.mnist_frechet_distance(real_pair, fake_pair)

        with self.test_session():
            same_session = [fdistance.eval() for _ in range(2)]
        baseline = same_session[0]
        self.assertNear(baseline, same_session[1], 2e-1)

        # A second session must reproduce the value from the first one.
        with self.test_session():
            fresh_session = fdistance.eval()
        self.assertNear(baseline, fresh_session, 2e-1)
示例#7
0
  def test_deterministic(self):
    """Evaluate one distance op repeatedly and require consistent results.

    Two evaluations happen inside a single session and a third in a separate
    session. Each later reading is compared with the first to within 0.2.
    """
    fdistance = util.mnist_frechet_distance(
        tf.concat([real_digit()] * 2, 0),
        tf.concat([fake_digit()] * 2, 0))
    readings = []
    with self.test_session():
      readings.append(fdistance.eval())
      readings.append(fdistance.eval())
    self.assertNear(readings[0], readings[1], 2e-1)

    with self.test_session():
      readings.append(fdistance.eval())
    self.assertNear(readings[0], readings[2], 2e-1)
示例#8
0
 def test_minibatch_correct(self):
   """Compare (real, real, fake) against (real, fake, fake) digit stacks.

   The two inputs are concatenated along axis 0. The distance is expected to
   be within 0.2 of the hard-coded value 43.5.
   """
   left = tf.concat([real_digit(), real_digit(), fake_digit()], 0)
   right = tf.concat([real_digit(), fake_digit(), fake_digit()], 0)
   fdistance = util.mnist_frechet_distance(left, right)
   with self.test_session():
     observed = fdistance.eval()
     self.assertNear(43.5, observed, 2e-1)
示例#9
0
 def test_single_example_correct(self):
   """Check that comparing matching real-digit inputs yields a near-zero distance.

   Each input is the tensor from one `real_digit()` call, repeated twice and
   concatenated along axis 0. The distance must be within 0.2 of 0.0.
   """
   def two_real_digits():
     return tf.concat([real_digit()] * 2, 0)

   fdistance = util.mnist_frechet_distance(two_real_digits(), two_real_digits())
   with self.test_session():
     self.assertNear(0.0, fdistance.eval(), 2e-1)
示例#10
0
 def test_single_example_correct(self):
   """Expect ~0 distance when both inputs are built the same way from real digits.

   Both inputs repeat a single `real_digit()` tensor twice and concatenate
   along axis 0. The result is checked against 0.0 with a 0.2 tolerance.
   """
   real_pair_a = tf.concat([real_digit()] * 2, 0)
   real_pair_b = tf.concat([real_digit()] * 2, 0)
   fdistance = util.mnist_frechet_distance(real_pair_a, real_pair_b)
   with self.test_session():
     value = fdistance.eval()
     self.assertNear(0.0, value, 2e-1)