Beispiel #1
0
 def testDynamicSliceWithIncorrectSizeIndicesShape(self):
   """Runs dynamic_slice with too few size_indices and expects an error.

   The operand is reshaped to rank 3 (10x10x10) while the size_indices
   argument has only two entries. Running the op must raise
   InvalidArgumentError with a message describing the mismatch.
   """
   with self.session() as session:
     with self.test_scope():
       output = xla.dynamic_slice(
           np.arange(1000, dtype=np.int32).reshape([10, 10, 10]),
           np.array([5, 7, 3]), np.array([2, 3]))
     # Only session.run sits inside assertRaises; the op itself is built above.
     with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error:
       session.run(output)
     # assertRegex is the supported name; the assertRegexpMatches alias is
     # deprecated and no longer exists in Python 3.12+.
     self.assertRegex(
         invalid_arg_error.exception.message,
         (r'size_indices must be a vector with length equal to input rank, '
          r'but input rank is 3 and size_indices has shape \[2\].*'))
Beispiel #2
0
 def testDynamicSliceWithIncorrectSizeIndicesShape(self):
   """Runs dynamic_slice with too few size_indices and expects an error.

   The operand is reshaped to rank 3 (10x10x10) while the size_indices
   argument has only two entries. Running the op must raise
   InvalidArgumentError whose message starts with the expected text.
   """
   # NOTE(review): test_session() is deprecated in newer TensorFlow releases;
   # confirm the base test class provides an equivalent session() before
   # switching.
   with self.test_session() as session:
     with self.test_scope():
       output = xla.dynamic_slice(
           np.arange(1000, dtype=np.int32).reshape([10, 10, 10]),
           np.array([5, 7, 3]), np.array([2, 3]))
     # Only session.run sits inside assertRaises; the op itself is built above.
     with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error:
       session.run(output)
     # assertRegex is the supported name; the assertRegexpMatches alias is
     # deprecated and no longer exists in Python 3.12+.
     self.assertRegex(
         invalid_arg_error.exception.message,
         (r'^size_indices must be a vector with length equal to input rank, '
          r'but input rank is 3 and size_indices has shape \[2\].*'))
 def testDynamicSliceWithIncorrectSizeIndicesShape(self):
     """Runs dynamic_slice with 2 slice sizes against 3 start indices.

     Executing the op must raise InvalidArgumentError with a message that
     reports the mismatched counts.
     """
     operand = np.arange(1000, dtype=np.int32).reshape([10, 10, 10])
     start_indices = np.array([5, 7, 3])
     slice_sizes = np.array([2, 3])
     expected_message = (
         r'op has mismatched number of slice sizes \(2\) and number of start'
         r' indices \(3\)')

     with self.session() as sess:
         with self.test_scope():
             sliced = xla.dynamic_slice(operand, start_indices, slice_sizes)
         # Only the run step sits inside assertRaises; the op is built above.
         with self.assertRaises(errors.InvalidArgumentError) as raised:
             sess.run(sliced)
         self.assertRegex(raised.exception.message, expected_message)
    def testDynamicSlice(self):
        """Checks the static output shape reported for xla.dynamic_slice.

        Each section fixes how much of slice_sizes is statically known and
        asserts the resulting static shape for several operand placeholder
        shapes.
        """
        start_indices = array_ops.placeholder(np.int32, shape=(2, 3, 4))
        operand_shapes = [(2, 3, 4), (None, 3, 4), None]

        def assert_sliced_shape(sizes, shapes, expected_dims):
            # Build a fresh operand placeholder per shape and compare the
            # static shape of the slice against expected_dims.
            for operand_shape in shapes:
                operand = array_ops.placeholder(
                    np.float32, shape=operand_shape)
                sliced = xla.dynamic_slice(operand, start_indices, sizes)
                self.assertEqual(sliced.shape.as_list(), expected_dims)

        # Constant slice sizes: the output shape equals them for every
        # operand shape tried, including a fully unknown one.
        assert_sliced_shape(
            np.array([1, 2, 4], dtype=np.int32), operand_shapes, [1, 2, 4])

        # Only the first two slice sizes are constants; the third is a
        # scalar placeholder, so the last output dimension is unknown.
        unknown_last = array_ops.placeholder(np.int32, [])
        assert_sliced_shape(
            array_ops.stack([1, 2, unknown_last]),
            operand_shapes, [1, 2, None])

        # Slice sizes have a known length of 3 but are not constant: the
        # output has rank 3 with every dimension unknown.
        assert_sliced_shape(
            array_ops.placeholder(np.int32, [3]),
            operand_shapes, [None, None, None])

        # Slice sizes have unknown length; with a ranked operand the output
        # still has the operand's rank, with unknown dimensions.
        assert_sliced_shape(
            array_ops.placeholder(np.int32, [None]),
            operand_shapes[:2], [None, None, None])

        # Neither the operand rank nor the number of slice sizes is known,
        # so the output rank is unknown as well.
        unranked_operand = array_ops.placeholder(np.float32, shape=None)
        unsized = array_ops.placeholder(np.int32, [None])
        result = xla.dynamic_slice(unranked_operand, start_indices, unsized)
        self.assertIsNone(result.shape.rank)