def test_unflatten_beam_dim(self):
    """Verify the runtime shape produced by `beam_search._unflatten_beam_dim`.

    A ones tensor of shape [28, 2, 5] is passed together with the arguments
    7 and 4 (presumably batch size and beam size -- confirm against the
    helper's signature). The evaluated shape of the result is expected to
    be [7, 4, 2, 5], i.e. the leading dimension of 28 is split into 7 x 4
    while the trailing dimensions are left untouched.
    """
    flat_input = tf.ones([28, 2, 5])
    unflattened = beam_search._unflatten_beam_dim(flat_input, 7, 4)

    # Evaluate the dynamic shape inside the session context so the check
    # covers the value computed at run time.
    with self.test_session() as session:
        actual_shape = session.run(tf.shape(unflattened))

    self.assertAllEqual([7, 4, 2, 5], actual_shape)