예제 #1
0
 def testMatrixAndVectorBatchShapesSameRankButPermuted(self):
   """Flip a [2, 3, 4, 6] matrix into a vector with batch shape [6, 3, 2].

   Checks one element's value in addition to the result shape, so an
   incorrect element ordering is caught and not only an incorrect shape.
   Runs with the static batch shape given both as a concrete TensorShape
   and as TensorShape(None).
   """
   batch_shape = [6, 3, 2]
   for static_batch_shape in [
       tf.TensorShape(batch_shape), tf.TensorShape(None)]:
     with self.test_session():
       mat = self._rng.rand(2, 3, 4, 6)
       vec = operator_pd.flip_matrix_to_vector(
           mat, batch_shape, static_batch_shape)
       vec_v = vec.eval()
       self.assertAllEqual((6, 3, 2, 4), vec_v.shape)
       # NOTE(review): the expected index assumes the trailing matrix dim (6)
       # is moved to the front, giving [6, 2, 3, 4], and the leading dims are
       # then reshaped row-major to [6, 3, 2]. mat[1, 2, 3, 4] sits at flat
       # index 4 * 6 + 1 * 3 + 2 = 29 -> [4, 2, 1] in shape [6, 3, 2].
       self.assertAllEqual(mat[1, 2, 3, 4], vec_v[4, 2, 1, 3])
예제 #2
0
 def testVectorBatchShapeLongerThanMatrixBatchShape(self):
   """Flip a [2, 3, 4, 6] matrix into a vector with batch shape [2, 3, 2, 3].

   The requested batch shape has more dims than the matrix batch shape.
   Checks one element's value in addition to the result shape, so an
   incorrect element ordering is caught and not only an incorrect shape.
   Runs with the static batch shape given both as a concrete TensorShape
   and as TensorShape(None).
   """
   batch_shape = [2, 3, 2, 3]
   for static_batch_shape in [
       tf.TensorShape(batch_shape), tf.TensorShape(None)]:
     with self.test_session():
       mat = self._rng.rand(2, 3, 4, 6)
       vec = operator_pd.flip_matrix_to_vector(
           mat, batch_shape, static_batch_shape)
       vec_v = vec.eval()
       self.assertAllEqual((2, 3, 2, 3, 4), vec_v.shape)
       # NOTE(review): the expected index assumes the trailing matrix dim (6)
       # is moved to the front, giving [6, 2, 3, 4], and the leading dims are
       # then reshaped row-major to [2, 3, 2, 3]. mat[1, 2, 3, 4] sits at flat
       # index 4 * 6 + 1 * 3 + 2 = 29 -> [1, 1, 1, 2] in shape [2, 3, 2, 3].
       self.assertAllEqual(mat[1, 2, 3, 4], vec_v[1, 1, 1, 2, 3])
예제 #3
0
 def testVectorBatchShapeLongerThanMatrixBatchShape(self):
   """Flip a [2, 3, 4, 6] matrix into a vector with batch shape [2, 3, 2, 3].

   The requested batch shape has more dims than the matrix batch shape.
   Checks one element's value in addition to the result shape, so an
   incorrect element ordering is caught and not only an incorrect shape.
   Runs with the static batch shape given both as a concrete TensorShape
   and as TensorShape(None).
   """
   batch_shape = [2, 3, 2, 3]
   for static_batch_shape in [
       tf.TensorShape(batch_shape), tf.TensorShape(None)]:
     with self.test_session():
       mat = self._rng.rand(2, 3, 4, 6)
       vec = operator_pd.flip_matrix_to_vector(
           mat, batch_shape, static_batch_shape)
       vec_v = vec.eval()
       self.assertAllEqual((2, 3, 2, 3, 4), vec_v.shape)
       # NOTE(review): the expected index assumes the trailing matrix dim (6)
       # is moved to the front, giving [6, 2, 3, 4], and the leading dims are
       # then reshaped row-major to [2, 3, 2, 3]. mat[1, 2, 3, 4] sits at flat
       # index 4 * 6 + 1 * 3 + 2 = 29 -> [1, 1, 1, 2] in shape [2, 3, 2, 3].
       self.assertAllEqual(mat[1, 2, 3, 4], vec_v[1, 1, 1, 2, 3])
예제 #4
0
 def testMatrixAndVectorBatchShapesSameRankButPermuted(self):
   """Flip a [2, 3, 4, 6] matrix into a vector with batch shape [6, 3, 2].

   Checks one element's value in addition to the result shape, so an
   incorrect element ordering is caught and not only an incorrect shape.
   Runs with the static batch shape given both as a concrete TensorShape
   and as TensorShape(None).
   """
   batch_shape = [6, 3, 2]
   for static_batch_shape in [
       tf.TensorShape(batch_shape), tf.TensorShape(None)]:
     with self.test_session():
       mat = self._rng.rand(2, 3, 4, 6)
       vec = operator_pd.flip_matrix_to_vector(
           mat, batch_shape, static_batch_shape)
       vec_v = vec.eval()
       self.assertAllEqual((6, 3, 2, 4), vec_v.shape)
       # NOTE(review): the expected index assumes the trailing matrix dim (6)
       # is moved to the front, giving [6, 2, 3, 4], and the leading dims are
       # then reshaped row-major to [6, 3, 2]. mat[1, 2, 3, 4] sits at flat
       # index 4 * 6 + 1 * 3 + 2 = 29 -> [4, 2, 1] in shape [6, 3, 2].
       self.assertAllEqual(mat[1, 2, 3, 4], vec_v[4, 2, 1, 3])
예제 #5
0
 def testMatrixBatchShapeHasASingletonThatVecBatchShapeDoesnt(self):
   """Flip a [1, 3, 4, 6] matrix into a vector with batch shape [6, 3].

   The matrix batch shape [1, 3] carries a size-1 dim that the requested
   vector batch shape does not. The flip is run with the static batch shape
   given both as a concrete TensorShape and as TensorShape(None).
   """
   vector_batch_shape = [6, 3]
   static_shapes = (
       tf.TensorShape(vector_batch_shape),
       tf.TensorShape(None),
   )
   for static_shape in static_shapes:
     with self.test_session():
       matrix = self._rng.rand(1, 3, 4, 6)
       flipped = operator_pd.flip_matrix_to_vector(
           matrix, vector_batch_shape, static_shape).eval()
       self.assertAllEqual((6, 3, 4), flipped.shape)
       # Trailing matrix index 4 is expected as the leading vector index.
       self.assertAllEqual(matrix[0, 2, 3, 4], flipped[4, 2, 3])
예제 #6
0
 def test_matrix_and_vector_batch_shapes_the_same(self):
   """Flip a [2, 3, 4, 6] matrix into a vector with batch shape [6, 2, 3].

   The requested batch shape is the trailing matrix dim (6) followed by the
   matrix batch shape [2, 3]. The flip is run with the static batch shape
   given both as a concrete TensorShape and as TensorShape(None).
   """
   target_batch = [6, 2, 3]
   candidate_static_shapes = (
       tf.TensorShape(target_batch),
       tf.TensorShape(None),
   )
   for static_shape in candidate_static_shapes:
     with self.test_session():
       matrix = self._rng.rand(2, 3, 4, 6)
       result = operator_pd.flip_matrix_to_vector(
           matrix, target_batch, static_shape).eval()
       self.assertAllEqual((6, 2, 3, 4), result.shape)
       # Matrix element [1, 2, 3, 4] is expected at vector [4, 1, 2, 3]:
       # the trailing matrix index moves to the front.
       self.assertAllEqual(matrix[1, 2, 3, 4], result[4, 1, 2, 3])
예제 #7
0
 def testMatrixBatchShapeHasASingletonThatVecBatchShapeDoesnt(self):
   """Flip a [1, 3, 4, 6] matrix into a vector with batch shape [6, 3].

   The matrix batch shape [1, 3] carries a size-1 dim that the requested
   vector batch shape does not. The flip is run with the static batch shape
   given both as a concrete TensorShape and as TensorShape(None).
   """
   vector_batch_shape = [6, 3]
   static_shapes = (
       tf.TensorShape(vector_batch_shape),
       tf.TensorShape(None),
   )
   for static_shape in static_shapes:
     with self.test_session():
       matrix = self._rng.rand(1, 3, 4, 6)
       flipped = operator_pd.flip_matrix_to_vector(
           matrix, vector_batch_shape, static_shape).eval()
       self.assertAllEqual((6, 3, 4), flipped.shape)
       # Trailing matrix index 4 is expected as the leading vector index.
       self.assertAllEqual(matrix[0, 2, 3, 4], flipped[4, 2, 3])
예제 #8
0
 def test_matrix_and_vector_batch_shapes_the_same(self):
   """Flip a [2, 3, 4, 6] matrix into a vector with batch shape [6, 2, 3].

   The requested batch shape is the trailing matrix dim (6) followed by the
   matrix batch shape [2, 3]. The flip is run with the static batch shape
   given both as a concrete TensorShape and as TensorShape(None).
   """
   target_batch = [6, 2, 3]
   candidate_static_shapes = (
       tf.TensorShape(target_batch),
       tf.TensorShape(None),
   )
   for static_shape in candidate_static_shapes:
     with self.test_session():
       matrix = self._rng.rand(2, 3, 4, 6)
       result = operator_pd.flip_matrix_to_vector(
           matrix, target_batch, static_shape).eval()
       self.assertAllEqual((6, 2, 3, 4), result.shape)
       # Matrix element [1, 2, 3, 4] is expected at vector [4, 1, 2, 3]:
       # the trailing matrix index moves to the front.
       self.assertAllEqual(matrix[1, 2, 3, 4], result[4, 1, 2, 3])