def testDiscreteBottleneckVQCond(self):
    """Run the VQ discrete bottleneck with a False `cond` and check its outputs.

    Builds a zero-initialized codebook variable (`means`) plus single-element
    lists holding an EMA count variable and an EMA means variable. It then calls
    `discretization.discrete_bottleneck` on an all-zero input with `cond` set to
    a constant False boolean tensor, and asserts that:
      * the first (dense) output has shape (batch, 1, hidden_dim),
      * the second (hot) output has shape (batch, 1),
      * `means` still equals its all-zero initial value after evaluation.

    NOTE(review): `cond` presumably gates the EMA codebook update inside
    `discrete_bottleneck`; confirm against that function's implementation.
    """
    batch = 100
    hidden_dim = 60
    z_size = 4
    num_codes = 2**z_size
    inputs = tf.zeros(shape=[batch, 1, hidden_dim], dtype=tf.float32)

    with tf.variable_scope("test2", reuse=tf.AUTO_REUSE):
        means = tf.get_variable(
            "means",
            shape=[1, 1, num_codes, hidden_dim],
            initializer=tf.constant_initializer(0.),
            dtype=tf.float32)

        # Single-element lists, matching num_blocks=1 below (presumably one
        # entry per block -- verify against discrete_bottleneck).
        ema_count = [
            tf.get_variable(
                "ema_count",
                [1, num_codes],
                initializer=tf.constant_initializer(0),
                trainable=False),
        ]

        # Place the EMA means alongside the codebook and seed them from the
        # codebook's initial value (first slice along axis 0).
        with tf.colocate_with(means):
            ema_means = [
                tf.get_variable(
                    "ema_means",
                    initializer=means.initialized_value()[0],
                    trainable=False),
            ]

        cond = tf.cast(0.0, tf.bool)  # constant False boolean tensor
        dense_out, hot_out, _, _, _ = discretization.discrete_bottleneck(
            inputs,
            hidden_dim,
            z_size,
            32,
            means=means,
            num_blocks=1,
            cond=cond,
            ema_means=ema_means,
            ema_count=ema_count,
            name="test2")

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            dense_eval, hot_eval = sess.run([dense_out, hot_out])
            # Read the codebook in a separate run, after the bottleneck outputs
            # have been evaluated.
            means_eval = sess.run(means)

        self.assertEqual(dense_eval.shape, (batch, 1, hidden_dim))
        self.assertEqual(hot_eval.shape, (batch, 1))
        self.assertAllClose(means_eval,
                            np.zeros((1, 1, num_codes, hidden_dim)))