Exemplo n.º 1
0
 def test_minibatch_correct(self):
   """Cross entropy over a minibatch does not depend on example order.

   Both minibatches contain the same multiset of (image, label) pairs:
   (real_digit, one_hot_real), (real_digit, one_hot1) and
   (fake_digit, one_hot1). Only their order differs, so both evaluations
   must agree, and `xent1` must match a known reference value.
   """
   xent1 = util.mnist_cross_entropy(
       tf.concat([real_digit(), real_digit(), fake_digit()], 0),
       tf.concat([one_hot_real(), one_hot1(), one_hot1()], 0))
   # Only the images are permuted here. The label column is unchanged, but
   # because both non-first labels are `one_hot1()`, the set of
   # (image, label) pairs is identical to the one used for `xent1`.
   xent2 = util.mnist_cross_entropy(
       tf.concat([real_digit(), fake_digit(), real_digit()], 0),
       tf.concat([one_hot_real(), one_hot1(), one_hot1()], 0))
   with self.test_session():
     # Evaluate each tensor once and compare the values, instead of
     # re-running the graph for every assertion.
     xent1_value = xent1.eval()
     xent2_value = xent2.eval()
     self.assertNear(6.972539, xent1_value, 1e-5)
     self.assertNear(xent1_value, xent2_value, 1e-5)
Exemplo n.º 2
0
 def test_minibatch_correct(self):
   """Cross entropy over a minibatch does not depend on example order.

   Both minibatches contain the same multiset of (image, label) pairs:
   (real_digit, one_hot_real), (real_digit, one_hot1) and
   (fake_digit, one_hot1). Only their order differs, so both evaluations
   must agree, and `xent1` must match a known reference value.
   """
   xent1 = util.mnist_cross_entropy(
       tf.concat([real_digit(), real_digit(), fake_digit()], 0),
       tf.concat([one_hot_real(), one_hot1(), one_hot1()], 0))
   # Only the images are permuted here. The label column is unchanged, but
   # because both non-first labels are `one_hot1()`, the set of
   # (image, label) pairs is identical to the one used for `xent1`.
   xent2 = util.mnist_cross_entropy(
       tf.concat([real_digit(), fake_digit(), real_digit()], 0),
       tf.concat([one_hot_real(), one_hot1(), one_hot1()], 0))
   with self.test_session():
     # Evaluate each tensor once and compare the values, instead of
     # re-running the graph for every assertion.
     xent1_value = xent1.eval()
     xent2_value = xent2.eval()
     self.assertNear(6.972539, xent1_value, 1e-5)
     self.assertNear(xent1_value, xent2_value, 1e-5)
Exemplo n.º 3
0
 def test_single_example_correct(self):
   """Checks cross entropy for single real and fake digits against references.

   A real digit paired with `one_hot_real()` or `one_hot1()` must match tight
   reference values. A fake digit must land near 2.2 for either label.
   """
   # The correct label should have low cross entropy.
   correct_xent = util.mnist_cross_entropy(real_digit(), one_hot_real())
   # The incorrect label should have high cross entropy.
   wrong_xent = util.mnist_cross_entropy(real_digit(), one_hot1())
   # A fake digit should have medium cross entropy for any label. Each local
   # is suffixed with the label helper it is computed against.
   fake_xent_real = util.mnist_cross_entropy(fake_digit(), one_hot_real())
   fake_xent1 = util.mnist_cross_entropy(fake_digit(), one_hot1())
   with self.test_session():
     self.assertNear(0.00996, correct_xent.eval(), 1e-5)
     self.assertNear(18.63073, wrong_xent.eval(), 1e-5)
     # Looser tolerance: for fake input only the rough magnitude is pinned.
     self.assertNear(2.2, fake_xent_real.eval(), 1e-1)
     self.assertNear(2.2, fake_xent1.eval(), 1e-1)
Exemplo n.º 4
0
 def test_single_example_correct(self):
   """Checks cross entropy for single real and fake digits against references.

   A real digit paired with `one_hot_real()` or `one_hot1()` must match tight
   reference values. A fake digit must land near 2.2 for either label.
   """
   # The correct label should have low cross entropy.
   correct_xent = util.mnist_cross_entropy(real_digit(), one_hot_real())
   # The incorrect label should have high cross entropy.
   wrong_xent = util.mnist_cross_entropy(real_digit(), one_hot1())
   # A fake digit should have medium cross entropy for any label. Each local
   # is suffixed with the label helper it is computed against.
   fake_xent_real = util.mnist_cross_entropy(fake_digit(), one_hot_real())
   fake_xent1 = util.mnist_cross_entropy(fake_digit(), one_hot1())
   with self.test_session():
     self.assertNear(0.00996, correct_xent.eval(), 1e-5)
     self.assertNear(18.63073, wrong_xent.eval(), 1e-5)
     # Looser tolerance: for fake input only the rough magnitude is pinned.
     self.assertNear(2.2, fake_xent_real.eval(), 1e-1)
     self.assertNear(2.2, fake_xent1.eval(), 1e-1)
Exemplo n.º 5
0
  def test_deterministic(self):
    """Repeated evaluations of the cross entropy tensor give equal values.

    The tensor is evaluated twice inside one `test_session()` block and once
    more inside a second block. All three results must compare equal.
    """
    cross_entropy = util.mnist_cross_entropy(real_digit(), one_hot_real())

    # Two back-to-back evaluations within the same session block.
    with self.test_session():
      first_value = cross_entropy.eval()
      second_value = cross_entropy.eval()
    self.assertEqual(first_value, second_value)

    # NOTE(review): depending on the TF version, test_session() may return a
    # cached session, so this block may not exercise a fresh session -- confirm.
    with self.test_session():
      third_value = cross_entropy.eval()
    self.assertEqual(first_value, third_value)
Exemplo n.º 6
0
  def test_deterministic(self):
    """Repeated evaluations of the cross entropy tensor give equal values.

    The tensor is evaluated twice inside one `test_session()` block and once
    more inside a second block. All three results must compare equal.
    """
    cross_entropy = util.mnist_cross_entropy(real_digit(), one_hot_real())

    # Two back-to-back evaluations within the same session block.
    with self.test_session():
      first_value = cross_entropy.eval()
      second_value = cross_entropy.eval()
    self.assertEqual(first_value, second_value)

    # NOTE(review): depending on the TF version, test_session() may return a
    # cached session, so this block may not exercise a fresh session -- confirm.
    with self.test_session():
      third_value = cross_entropy.eval()
    self.assertEqual(first_value, third_value)
Exemplo n.º 7
0
 def test_any_batch_size(self):
   """The cross entropy graph works for feeds of different batch sizes.

   One graph is built on placeholders with an unknown leading dimension. It
   is then run on all-zero images for several batch sizes, and every run must
   produce a finite result.
   """
   num_classes = 10
   # A single one-hot row selecting class 0, shape (1, num_classes).
   one_label = np.array([[1] + [0] * (num_classes - 1)])
   inputs = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
   one_hot_label = tf.placeholder(tf.int32, shape=[None, num_classes])
   entropy = util.mnist_cross_entropy(inputs, one_hot_label)
   for batch_size in [4, 16, 30]:
     with self.test_session() as sess:
       result = sess.run(entropy, feed_dict={
           inputs: np.zeros([batch_size, 28, 28, 1]),
           one_hot_label: np.concatenate([one_label] * batch_size)})
     # Running without error is not enough: also require a well-defined
     # (finite) value so NaN/inf regressions are caught.
     self.assertTrue(np.all(np.isfinite(result)),
                     'Non-finite cross entropy for batch size %d' % batch_size)
Exemplo n.º 8
0
 def test_any_batch_size(self):
   """The cross entropy graph works for feeds of different batch sizes.

   One graph is built on placeholders with an unknown leading dimension. It
   is then run on all-zero images for several batch sizes, and every run must
   produce a finite result.
   """
   num_classes = 10
   # A single one-hot row selecting class 0, shape (1, num_classes).
   one_label = np.array([[1] + [0] * (num_classes - 1)])
   inputs = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
   one_hot_label = tf.placeholder(tf.int32, shape=[None, num_classes])
   entropy = util.mnist_cross_entropy(inputs, one_hot_label)
   for batch_size in [4, 16, 30]:
     with self.test_session() as sess:
       result = sess.run(entropy, feed_dict={
           inputs: np.zeros([batch_size, 28, 28, 1]),
           one_hot_label: np.concatenate([one_label] * batch_size)})
     # Running without error is not enough: also require a well-defined
     # (finite) value so NaN/inf regressions are caught.
     self.assertTrue(np.all(np.isfinite(result)),
                     'Non-finite cross entropy for batch size %d' % batch_size)