def testGetVQBottleneck(self):
    """Check that zeroing the VQ codebook means is visible to a second lookup.

    `discretization.get_vq_codebook` is called twice with the same sizes.
    After all variables are initialized and an op overwrites the first
    call's `means` with zeros, the test expects the `means` returned by the
    second call to read back as all zeros. Presumably both calls resolve to
    the same underlying variable (confirm in `get_vq_codebook`). The test
    also expects the first call's `ema_count` to be all zeros.
    """
    num_bits = 2
    codebook_size = 1 << num_bits  # same value as 2**num_bits
    dim = 3
    zeros_shape = [codebook_size, dim]

    # First lookup: keep the means and EMA-count handles.
    first_means, _, first_ema_count = discretization.get_vq_codebook(
        codebook_size, dim)
    zero_means_op = first_means.assign(tf.zeros(shape=zeros_shape))

    # Second lookup, built after the assign op; only its means are checked.
    second_means, _, _ = discretization.get_vq_codebook(codebook_size, dim)

    with self.test_session() as session:
        tf.global_variables_initializer().run()
        session.run(zero_means_op)
        means_value = session.run(second_means)
        self.assertTrue(np.all(means_value == 0))
        ema_value = session.run(first_ema_count)
        self.assertTrue(np.all(ema_value == 0))