예제 #1
0
 def testVQNearestNeighbors(self):
   """Check that vq_nearest_neighbor one-hot encodes each row's nearest code.

   Row [0, 0.9, 0] is closest (squared distance 0.01) to code 1 = [0, 1, 0],
   and row [0.8, 0, 0] is closest (squared distance 0.04) to code 0 =
   [1, 0, 0]. Code 3 = [9, 9, 9] is far from both rows and must never be
   selected.
   """
   x = tf.constant([[0, 0.9, 0], [0.8, 0., 0.]], dtype=tf.float32)
   means = tf.constant(
       [[1, 0, 0], [0, 1, 0], [0, 0, 1], [9, 9, 9]], dtype=tf.float32)
   # Only the one-hot assignment is checked; the other two outputs are unused.
   x_means_hot, _, _ = discretization.vq_nearest_neighbor(x, means)
   x_means_hot_test = np.array([[0, 1, 0, 0], [1, 0, 0, 0]])
   x_means_hot_eval = self.evaluate(x_means_hot)
   self.assertEqual(np.shape(x_means_hot_eval), (2, 4))
   # assert_array_equal reports which entries differ on failure, whereas
   # assertTrue(np.all(...)) only reports "False is not true".
   np.testing.assert_array_equal(x_means_hot_eval, x_means_hot_test)
예제 #2
0
 def testVQNearestNeighbors(self):
   """Check that vq_nearest_neighbor one-hot encodes each row's nearest code.

   Row [0, 0.9, 0] is closest (squared distance 0.01) to code 1 = [0, 1, 0],
   and row [0.8, 0, 0] is closest (squared distance 0.04) to code 0 =
   [1, 0, 0]. Code 3 = [9, 9, 9] is far from both rows and must never be
   selected.
   """
   x = tf.constant([[0, 0.9, 0], [0.8, 0., 0.]], dtype=tf.float32)
   means = tf.constant(
       [[1, 0, 0], [0, 1, 0], [0, 0, 1], [9, 9, 9]], dtype=tf.float32)
   # Only the one-hot assignment is checked; the other two outputs are unused.
   x_means_hot, _, _ = discretization.vq_nearest_neighbor(x, means)
   x_means_hot_test = np.array([[0, 1, 0, 0], [1, 0, 0, 0]])
   x_means_hot_eval = self.evaluate(x_means_hot)
   self.assertEqual(np.shape(x_means_hot_eval), (2, 4))
   # assert_array_equal reports which entries differ on failure, whereas
   # assertTrue(np.all(...)) only reports "False is not true".
   np.testing.assert_array_equal(x_means_hot_eval, x_means_hot_test)
 def testVQNearestNeighbors(self):
   """Check that vq_nearest_neighbor one-hot encodes each row's nearest code.

   Row [0, 0.9, 0] is closest (squared distance 0.01) to code 1 = [0, 1, 0],
   and row [0.8, 0, 0] is closest (squared distance 0.04) to code 0 =
   [1, 0, 0]. Code 3 = [9, 9, 9] is far from both rows and must never be
   selected.
   """
   x = tf.constant([[0, 0.9, 0], [0.8, 0., 0.]], dtype=tf.float32)
   means = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [9, 9, 9]],
                       dtype=tf.float32)
   # This variant of the API is unpacked as two return values; only the
   # one-hot assignment is checked here.
   x_means_hot, _ = discretization.vq_nearest_neighbor(x, means)
   x_means_hot_test = np.array([[0, 1, 0, 0], [1, 0, 0, 0]])
   # NOTE(review): test_session() is deprecated in newer TensorFlow in favour
   # of cached_session()/self.evaluate; confirm the targeted TF version
   # before switching.
   with self.test_session() as sess:
     # Initialize any graph variables before evaluating; the test body
     # itself only creates constants.
     tf.global_variables_initializer().run()
     x_means_hot_eval = sess.run(x_means_hot)
     self.assertEqual(np.shape(x_means_hot_eval), (2, 4))
     # assert_array_equal reports which entries differ on failure, whereas
     # assertTrue(np.all(...)) only reports "False is not true".
     np.testing.assert_array_equal(x_means_hot_eval, x_means_hot_test)
예제 #4
0
 def testVQNearestNeighbors(self):
   """Check that vq_nearest_neighbor one-hot encodes each row's nearest code.

   Row [0, 0.9, 0] is closest (squared distance 0.01) to code 1 = [0, 1, 0],
   and row [0.8, 0, 0] is closest (squared distance 0.04) to code 0 =
   [1, 0, 0]. Code 3 = [9, 9, 9] is far from both rows and must never be
   selected.
   """
   x = tf.constant([[0, 0.9, 0], [0.8, 0., 0.]], dtype=tf.float32)
   means = tf.constant(
       [[1, 0, 0], [0, 1, 0], [0, 0, 1], [9, 9, 9]], dtype=tf.float32)
   # This variant of the API is unpacked as two return values; only the
   # one-hot assignment is checked here.
   x_means_hot, _ = discretization.vq_nearest_neighbor(x, means)
   x_means_hot_test = np.array([[0, 1, 0, 0], [1, 0, 0, 0]])
   # NOTE(review): test_session() is deprecated in newer TensorFlow in favour
   # of cached_session()/self.evaluate; confirm the targeted TF version
   # before switching.
   with self.test_session() as sess:
     # Initialize any graph variables before evaluating; the test body
     # itself only creates constants.
     tf.global_variables_initializer().run()
     x_means_hot_eval = sess.run(x_means_hot)
     self.assertEqual(np.shape(x_means_hot_eval), (2, 4))
     # assert_array_equal reports which entries differ on failure, whereas
     # assertTrue(np.all(...)) only reports "False is not true".
     np.testing.assert_array_equal(x_means_hot_eval, x_means_hot_test)