Exemplo n.º 1
0
 def test_when_x_has_two_larger_larger_batch_rank_than_batch_rank_arg(self):
   """Output shape when x has two more leading dims than batch_shape.

   x has shape (2, 3, 4, 5, 6) and batch_shape is [4, 5]. The result of
   _flip_vector_to_matrix must have shape (4, 5, 6, 2*3), both when the
   static batch shape is tf.TensorShape(batch_shape) and when it is
   tf.TensorShape(None).
   """
   x = self._rng.rand(2, 3, 4, 5, 6)
   batch_shape = [4, 5]
   expected_shape = (4, 5, 6, 2*3)
   static_shapes = (tf.TensorShape(batch_shape), tf.TensorShape(None))
   for static_shape in static_shapes:
     with self.test_session():
       flipped = operator_pd._flip_vector_to_matrix(
           x, batch_shape, static_shape)
       self.assertAllEqual(expected_shape, flipped.eval().shape)
Exemplo n.º 2
0
 def test_when_batch_shape_requires_reshape_of_vector_batch_shape(self):
   """Output shape when x's middle dims are batch_shape in swapped order.

   x has shape (3, 4, 5, 6) while batch_shape is [5, 4]. The result of
   _flip_vector_to_matrix must have shape (5, 4, 6, 3), both when the
   static batch shape is tf.TensorShape(batch_shape) and when it is
   tf.TensorShape(None).
   """
   # Note x has (4,5) and batch_shape is (5, 4)
   x = self._rng.rand(3, 4, 5, 6)
   batch_shape = [5, 4]
   expected_shape = (5, 4, 6, 3)
   static_shapes = (tf.TensorShape(batch_shape), tf.TensorShape(None))
   for static_shape in static_shapes:
     with self.test_session():
       flipped = operator_pd._flip_vector_to_matrix(
           x, batch_shape, static_shape)
       self.assertAllEqual(expected_shape, flipped.eval().shape)
Exemplo n.º 3
0
 def test_when_x_batch_rank_is_same_as_batch_rank_arg(self):
   """Result equals x with one trailing size-1 dim when batch dims match.

   x has shape (4, 5, 6) and batch_shape is [4, 5]. The evaluated result of
   _flip_vector_to_matrix must equal x reshaped to (4, 5, 6, 1), both when
   the static batch shape is tf.TensorShape(batch_shape) and when it is
   tf.TensorShape(None).
   """
   x = self._rng.rand(4, 5, 6)
   batch_shape = [4, 5]
   # Same values as x, with a size-1 axis appended at the end.
   expected = x[..., None]
   static_shapes = (tf.TensorShape(batch_shape), tf.TensorShape(None))
   for static_shape in static_shapes:
     with self.test_session():
       flipped = operator_pd._flip_vector_to_matrix(
           x, batch_shape, static_shape)
       self.assertAllEqual(expected, flipped.eval())
Exemplo n.º 4
0
 def test_when_x_has_two_larger_larger_batch_rank_than_batch_rank_arg(self):
     """Output shape when x has two more leading dims than batch_shape.

     x has shape (2, 3, 4, 5, 6) and batch_shape is [4, 5]. The result of
     _flip_vector_to_matrix must have shape (4, 5, 6, 2 * 3), both when
     the static batch shape is tf.TensorShape(batch_shape) and when it is
     tf.TensorShape(None).
     """
     x = self._rng.rand(2, 3, 4, 5, 6)
     batch_shape = [4, 5]
     expected_shape = (4, 5, 6, 2 * 3)
     static_shapes = (tf.TensorShape(batch_shape), tf.TensorShape(None))
     for static_shape in static_shapes:
         with self.test_session():
             flipped = operator_pd._flip_vector_to_matrix(
                 x, batch_shape, static_shape)
             self.assertAllEqual(expected_shape, flipped.eval().shape)
Exemplo n.º 5
0
 def test_when_batch_shape_requires_reshape_of_vector_batch_shape(self):
     """Output shape when x's middle dims are batch_shape in swapped order.

     x has shape (3, 4, 5, 6) while batch_shape is [5, 4]. The result of
     _flip_vector_to_matrix must have shape (5, 4, 6, 3), both when the
     static batch shape is tf.TensorShape(batch_shape) and when it is
     tf.TensorShape(None).
     """
     # Note x has (4,5) and batch_shape is (5, 4)
     x = self._rng.rand(3, 4, 5, 6)
     batch_shape = [5, 4]
     expected_shape = (5, 4, 6, 3)
     static_shapes = (tf.TensorShape(batch_shape), tf.TensorShape(None))
     for static_shape in static_shapes:
         with self.test_session():
             flipped = operator_pd._flip_vector_to_matrix(
                 x, batch_shape, static_shape)
             self.assertAllEqual(expected_shape, flipped.eval().shape)
Exemplo n.º 6
0
 def test_when_x_batch_rank_is_same_as_batch_rank_arg(self):
     """Result equals x with one trailing size-1 dim when batch dims match.

     x has shape (4, 5, 6) and batch_shape is [4, 5]. The evaluated result
     of _flip_vector_to_matrix must equal x reshaped to (4, 5, 6, 1), both
     when the static batch shape is tf.TensorShape(batch_shape) and when
     it is tf.TensorShape(None).
     """
     x = self._rng.rand(4, 5, 6)
     batch_shape = [4, 5]
     # Same values as x, with a size-1 axis appended at the end.
     expected = x[..., None]
     static_shapes = (tf.TensorShape(batch_shape), tf.TensorShape(None))
     for static_shape in static_shapes:
         with self.test_session():
             flipped = operator_pd._flip_vector_to_matrix(
                 x, batch_shape, static_shape)
             self.assertAllEqual(expected, flipped.eval())