def testGumbelSoftmaxDiscreteBottleneck(self):
    """Shape check for `discretization.gumbel_softmax_discrete_bottleneck`.

    Builds a 2x3 float32 input, runs it through the bottleneck with
    two bottleneck bits, initializes all global variables, evaluates the
    first returned tensor, and asserts its shape is (2, 2 ** 2) == (2, 4).
    """
    bits = 2
    num_rows = 2
    inputs = tf.constant(
        [[0, 0.9, 0],
         [0.8, 0., 0.]],
        dtype=tf.float32)
    # NOTE(review): presumably the bottleneck looks up the global step
    # (e.g. for a schedule), so it is registered before the bottleneck
    # graph is built -- TODO confirm against discretization.
    tf.add_to_collection(tf.GraphKeys.GLOBAL_STEP, tf.constant(1))
    hot_tensor, _unused = discretization.gumbel_softmax_discrete_bottleneck(
        inputs, bottleneck_bits=bits)
    # The bottleneck may create variables; initialize before evaluating.
    self.evaluate(tf.global_variables_initializer())
    hot_value = self.evaluate(hot_tensor)
    expected_shape = (num_rows, 2 ** bits)
    self.assertEqual(np.shape(hot_value), expected_shape)