def testGetVQBottleneck(self):
  """Checks codebook values seen through two get_vq_codebook calls.

  The means returned by the first call are overwritten with zeros. The means
  returned by a second call with the same sizes must then evaluate to all
  zeros, and so must the ema_count returned by the first call.
  """
  num_bits = 2
  num_codes = 1 << num_bits
  depth = 3
  codebook_shape = [num_codes, depth]

  first_means, _, ema_count = discretization.get_vq_codebook(num_codes, depth)
  zero_means = first_means.assign(tf.zeros(shape=codebook_shape))
  # Second lookup with identical sizes; its means are the ones inspected below.
  second_means, _, _ = discretization.get_vq_codebook(num_codes, depth)

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    # Zero the means only after initialization so the zeros are what is read.
    sess.run(zero_means)
    means_value = sess.run(second_means)
    self.assertTrue(np.all(means_value == 0))
    count_value = sess.run(ema_count)
    self.assertTrue(np.all(count_value == 0))
Beispiel #2
0
 def testGetVQBottleneck(self):
   """Verifies zeroed means and zero ema_count from get_vq_codebook.

   Builds the codebook twice with the same sizes, assigns zeros to the means
   from the first build, and asserts that the means from the second build and
   the ema_count from the first build both evaluate to all zeros.
   """
   bits = 2
   codebook_entries = 2**bits
   embedding_dim = 3

   initial = discretization.get_vq_codebook(codebook_entries, embedding_dim)
   means_var, _, count_var = initial
   reset_means = means_var.assign(
       tf.zeros(shape=[codebook_entries, embedding_dim]))
   # Only the means of the second build are needed for the check.
   means_again, _, _ = discretization.get_vq_codebook(
       codebook_entries, embedding_dim)

   with self.test_session() as session:
     init_op = tf.global_variables_initializer()
     session.run(init_op)
     # Run the reset after initialization so it is not overwritten.
     session.run(reset_means)
     fetched_means = session.run(means_again)
     self.assertTrue(np.all(fetched_means == 0))
     fetched_count = session.run(count_var)
     self.assertTrue(np.all(fetched_count == 0))