def testMatrixAndVectorBatchShapesSameRankButPermuted(self):
    """Flip to a vector batch shape that permutes the matrix batch dims.

    `mat` has shape (2, 3, 4, 6) and the requested vector batch shape is
    (6, 3, 2). Both a fully defined static batch shape and an unknown one
    (`TensorShape(None)`) are exercised.

    Asserting only the output shape cannot tell a correct flip apart from
    one that scrambles elements into the same shape, so the placement of
    one element is pinned as well.
    """
    batch_shape = [6, 3, 2]
    for static_batch_shape in [
            tf.TensorShape(batch_shape), tf.TensorShape(None)]:
        with self.test_session():
            mat = self._rng.rand(2, 3, 4, 6)
            vec = operator_pd.flip_matrix_to_vector(
                mat, batch_shape, static_batch_shape)
            vec_v = vec.eval()
            self.assertAllEqual((6, 3, 2, 4), vec_v.shape)
            # NOTE(review): the expected index assumes flip_matrix_to_vector
            # moves the trailing axis of `mat` to the front, then reshapes
            # row-major to batch_shape + [4]. Under that assumption
            # mat[1, 2, 3, 4] -> (4, 1, 2, 3) in shape (6, 2, 3, 4). The
            # offset of (1, 2, 3) within (2, 3, 4) is 1*12 + 2*4 + 3 = 23,
            # which is index (2, 1, 3) within (3, 2, 4).
            self.assertAllEqual(mat[1, 2, 3, 4], vec_v[4, 2, 1, 3])
def testVectorBatchShapeLongerThanMatrixBatchShape(self):
    """Flip to a vector batch shape with more dims than the matrix batch.

    `mat` has shape (2, 3, 4, 6) and the requested vector batch shape is
    (2, 3, 2, 3), which has rank 4. Both a fully defined static batch shape
    and an unknown one (`TensorShape(None)`) are exercised.

    Asserting only the output shape cannot tell a correct flip apart from
    one that scrambles elements into the same shape, so the placement of
    one element is pinned as well.
    """
    batch_shape = [2, 3, 2, 3]
    for static_batch_shape in [
            tf.TensorShape(batch_shape), tf.TensorShape(None)]:
        with self.test_session():
            mat = self._rng.rand(2, 3, 4, 6)
            vec = operator_pd.flip_matrix_to_vector(
                mat, batch_shape, static_batch_shape)
            vec_v = vec.eval()
            self.assertAllEqual((2, 3, 2, 3, 4), vec_v.shape)
            # NOTE(review): the expected index assumes flip_matrix_to_vector
            # moves the trailing axis of `mat` to the front, then reshapes
            # row-major to batch_shape + [4]. Under that assumption
            # mat[1, 2, 3, 4] -> (4, 1, 2, 3) in shape (6, 2, 3, 4), whose
            # flat offset is 4*24 + 1*12 + 2*4 + 3 = 119. With strides
            # (72, 24, 12, 4, 1) for shape (2, 3, 2, 3, 4), offset 119 is
            # index (1, 1, 1, 2, 3).
            self.assertAllEqual(mat[1, 2, 3, 4], vec_v[1, 1, 1, 2, 3])
def testMatrixBatchShapeHasASingletonThatVecBatchShapeDoesnt(self):
    """Flip a matrix whose batch shape carries a singleton dimension.

    The input has shape (1, 3, 4, 6) and the requested vector batch shape
    is (6, 3). The result is expected to have shape (6, 3, 4), so the
    singleton does not survive. Both a fully defined static batch shape and
    an unknown one (`TensorShape(None)`) are exercised.
    """
    target_batch = [6, 3]
    shape_hints = (tf.TensorShape(target_batch), tf.TensorShape(None))
    for hint in shape_hints:
        with self.test_session():
            # A fresh random draw is taken for every shape hint.
            matrix = self._rng.rand(1, 3, 4, 6)
            flipped = operator_pd.flip_matrix_to_vector(
                matrix, target_batch, hint)
            result = flipped.eval()
            expected_shape = (6, 3, 4)
            self.assertAllEqual(expected_shape, result.shape)
            # The trailing index (4) of `matrix` shows up as the leading
            # index of the flipped result.
            self.assertAllEqual(matrix[0, 2, 3, 4], result[4, 2, 3])
def test_matrix_and_vector_batch_shapes_the_same(self):
    """Flip when the vector batch shape is the trailing dim plus matrix batch.

    The input has shape (2, 3, 4, 6) and the requested vector batch shape
    is (6, 2, 3). That is the trailing size 6 followed by the matrix batch
    shape (2, 3). Both a fully defined static batch shape and an unknown
    one (`TensorShape(None)`) are exercised.
    """
    target_batch = [6, 2, 3]
    shape_hints = (tf.TensorShape(target_batch), tf.TensorShape(None))
    for hint in shape_hints:
        with self.test_session():
            # A fresh random draw is taken for every shape hint.
            matrix = self._rng.rand(2, 3, 4, 6)
            flipped = operator_pd.flip_matrix_to_vector(
                matrix, target_batch, hint)
            result = flipped.eval()
            expected_shape = (6, 2, 3, 4)
            self.assertAllEqual(expected_shape, result.shape)
            # The trailing index (4) moves to the front. The remaining
            # indices (1, 2, 3) keep their order.
            self.assertAllEqual(matrix[1, 2, 3, 4], result[4, 1, 2, 3])