コード例 #1
0
 def testSplitWithPaddedOutput(self, texts, expected, ragged_rank=None):
     """Checks `ragged.unicode_split` after converting its result to a tensor.

     Args:
       texts: Text input; encoded as UTF-8 via `_nested_encode` before use.
       expected: Expected values, compared as a `bytes` numpy array.
       ragged_rank: Optional `ragged_rank` forwarded to
         `ragged.constant_value`.
     """
     encoded = _nested_encode(texts, "UTF-8")
     input_tensor = ragged.constant_value(
         encoded, ragged_rank=ragged_rank, dtype=bytes)
     split = ragged.unicode_split(input_tensor, "UTF-8")
     # Convert to a tensor, using the empty string as the default value.
     dense = split.to_tensor(default_value="")
     wanted = np.array(expected, dtype=bytes)
     self.assertAllEqual(wanted, dense)
コード例 #2
0
 def testBasicSplit(self, texts, ragged_rank=None):
     """Checks `ragged.unicode_split` against `_nested_splitchars`.

     Args:
       texts: Text input; encoded as UTF-8 via `_nested_encode` before use.
       ragged_rank: Optional `ragged_rank` forwarded to
         `ragged.constant_value`.
     """
     encoded = _nested_encode(texts, "UTF-8")
     input_tensor = ragged.constant_value(
         encoded, ragged_rank=ragged_rank, dtype=bytes)
     actual = ragged.unicode_split(input_tensor, "UTF-8")
     # The reference value is computed from the original, unencoded texts.
     wanted = _nested_splitchars(texts, "UTF-8")
     self.assertRaggedEqual(wanted, actual)
コード例 #3
0
 def testBasicSplitWithOffsets(self, texts, ragged_rank=None):
     """Checks both outputs of `ragged.unicode_split_with_offsets`.

     Element 0 of the result is compared with `_nested_splitchars(texts)` and
     element 1 with `_nested_offsets(texts)`.

     Args:
       texts: Text input; encoded as UTF-8 via `_nested_encode` before use.
       ragged_rank: Optional `ragged_rank` forwarded to
         `ragged.constant_value`.
     """
     encoded = _nested_encode(texts, "UTF-8")
     input_tensor = ragged.constant_value(
         encoded, ragged_rank=ragged_rank, dtype=bytes)
     outputs = ragged.unicode_split_with_offsets(input_tensor, "UTF-8")

     wanted_chars = _nested_splitchars(texts, "UTF-8")
     wanted_offsets = _nested_offsets(texts, "UTF-8")

     self.assertRaggedEqual(wanted_chars, outputs[0])
     self.assertRaggedEqual(wanted_offsets, outputs[1])
コード例 #4
0
    def testRaggedValues(self,
                         pylist,
                         dtype=None,
                         ragged_rank=None,
                         inner_shape=None,
                         expected_shape=None,
                         expected_dtype=None):
        """Round-trips `pylist` through `ragged.constant_value` and checks it.

        Every optional argument that is not None triggers the matching
        assertion on the constructed value.

        Args:
          pylist: Python value handed to `ragged.constant_value`.
          dtype: Optional dtype forwarded to the constructor and then checked.
          ragged_rank: Optional ragged rank forwarded and then checked.
          inner_shape: Optional inner shape forwarded and then checked.
          expected_shape: Optional expected shape of the result.
          expected_dtype: Optional expected dtype of the result.
        """
        rt = ragged.constant_value(
            pylist,
            dtype=dtype,
            ragged_rank=ragged_rank,
            inner_shape=inner_shape)
        # The result may be a RaggedTensorValue or some other value type;
        # several checks below read different attributes depending on which.
        is_ragged = isinstance(rt, ragged.RaggedTensorValue)

        # Explicitly requested dtype first, then the expected dtype.
        for wanted_dtype in (dtype, expected_dtype):
            if wanted_dtype is not None:
                self.assertEqual(rt.dtype, wanted_dtype)

        # A non-ragged result is only consistent with ragged_rank == 0.
        if ragged_rank is not None:
            actual_rank = rt.ragged_rank if is_ragged else 0
            self.assertEqual(actual_rank, ragged_rank)

        # For ragged results the inner shape excludes the leading dimension
        # of the flat values.
        if inner_shape is not None:
            actual_inner = rt.flat_values.shape[1:] if is_ragged else rt.shape
            self.assertEqual(actual_inner, inner_shape)

        if expected_shape is not None:
            self.assertEqual(tuple(rt.shape), expected_shape)

        # A falsy shape is handled separately: the value itself is compared
        # with `pylist` and the shape is expected to be ().
        has_empty_shape = not rt.shape
        if has_empty_shape:
            self.assertEqual(rt, pylist)
        else:
            as_list = rt.to_list() if is_ragged else rt.tolist()
            self.assertEqual(as_list, pylist)

        if expected_shape is not None:
            self.assertEqual(() if has_empty_shape else rt.shape,
                             expected_shape)
コード例 #5
0
  def testRaggedValues(self,
                       pylist,
                       dtype=None,
                       ragged_rank=None,
                       inner_shape=None,
                       expected_shape=None,
                       expected_dtype=None):
    """Round-trips `pylist` through `ragged.constant_value` and checks it.

    Every optional argument that is not None triggers the matching assertion
    on the constructed value.

    Args:
      pylist: Python value handed to `ragged.constant_value`.
      dtype: Optional dtype forwarded to the constructor and then checked.
      ragged_rank: Optional ragged rank forwarded and then checked.
      inner_shape: Optional inner shape forwarded and then checked.
      expected_shape: Optional expected shape of the result.
      expected_dtype: Optional expected dtype of the result.
    """
    rt = ragged.constant_value(
        pylist,
        dtype=dtype,
        ragged_rank=ragged_rank,
        inner_shape=inner_shape)
    # The result may be a RaggedTensorValue or some other value type; the
    # rank and inner-shape checks read different attributes for each.
    is_ragged = isinstance(rt, ragged.RaggedTensorValue)

    # Explicitly requested dtype first, then the expected dtype.
    for wanted_dtype in (dtype, expected_dtype):
      if wanted_dtype is not None:
        self.assertEqual(rt.dtype, wanted_dtype)

    # A non-ragged result is only consistent with ragged_rank == 0.
    if ragged_rank is not None:
      actual_rank = rt.ragged_rank if is_ragged else 0
      self.assertEqual(actual_rank, ragged_rank)

    # For ragged results the inner shape excludes the leading dimension of
    # `inner_values`.
    if inner_shape is not None:
      actual_inner = rt.inner_values.shape[1:] if is_ragged else rt.shape
      self.assertEqual(actual_inner, inner_shape)

    if expected_shape is not None:
      self.assertEqual(tuple(rt.shape), expected_shape)

    # A falsy shape is handled separately: the value itself is compared with
    # `pylist` and the shape is expected to be ().
    has_empty_shape = not rt.shape
    if has_empty_shape:
      self.assertEqual(rt, pylist)
    else:
      self.assertEqual(rt.tolist(), pylist)

    if expected_shape is not None:
      self.assertEqual(() if has_empty_shape else rt.shape, expected_shape)
コード例 #6
0
class RaggedWhereOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
    """Parameterized tests for `ragged.where`."""

    @parameterized.parameters([
        #=========================================================================
        # Docstring Examples
        #=========================================================================
        dict(  # shape=[D1, (D2)]
            condition=ragged.constant_value([[True, False, True],
                                             [False, True]]),
            expected=[[0, 0], [0, 2], [1, 1]]),
        dict(  # shape=[D1, (D2)]
            condition=ragged.constant_value([[True, False, True],
                                             [False, True]]),
            x=ragged.constant_value([['A', 'B', 'C'], ['D', 'E']]),
            y=ragged.constant_value([['a', 'b', 'c'], ['d', 'e']]),
            expected=ragged.constant_value([[b'A', b'b', b'C'], [b'd',
                                                                 b'E']])),
        dict(  # shape=[D1, (D2)]
            condition=ragged.constant_value([True, False]),
            x=ragged.constant_value([['A', 'B', 'C'], ['D', 'E']]),
            y=ragged.constant_value([['a', 'b', 'c'], ['d', 'e']]),
            expected=ragged.constant_value([[b'A', b'B', b'C'], [b'd',
                                                                 b'e']])),
        #=========================================================================
        # Coordinate-retrieval mode
        #=========================================================================
        dict(  # shape=[D1]
            condition=[True, False, True, False, True],
            expected=[[0], [2], [4]]),
        dict(  # shape=[D1, D2]
            condition=[[True, False], [False, True]],
            expected=[[0, 0], [1, 1]]),
        dict(  # shape=[D1, (D2)]
            condition=ragged.constant_value([[True, False, True],
                                             [False, True]]),
            expected=[[0, 0], [0, 2], [1, 1]]),
        dict(  # shape=[D1, (D2), (D3)]
            condition=ragged.constant_value([[[True, False, True],
                                              [False, True]],
                                             [[True], [], [False],
                                              [False, True, False]]]),
            expected=[[0, 0, 0], [0, 0, 2], [0, 1, 1], [1, 0, 0], [1, 3, 1]]),
        dict(  # shape=[D1, (D2), D3]
            condition=ragged.constant_value([[[True, False], [False, True]],
                                             [[True, False], [False, False],
                                              [True, False], [False, True]]],
                                            ragged_rank=1),
            expected=[[0, 0, 0], [0, 1, 1], [1, 0, 0], [1, 2, 0], [1, 3, 1]]),
        dict(  # shape=[D1, (D2), (D3), (D4)]
            condition=ragged.constant_value([[[[], [True]]],
                                             [[[True, False, True],
                                               [False, True]],
                                              [[True], [], [False],
                                               [False, True, False]]]]),
            expected=[[0, 0, 1, 0], [1, 0, 0, 0], [1, 0, 0, 2], [1, 0, 1, 1],
                      [1, 1, 0, 0], [1, 1, 3, 1]]),

        #=========================================================================
        # Elementwise value-selection mode
        #=========================================================================
        dict(  # shape=[]
            condition=True, x='A', y='a', expected=b'A'),
        dict(  # shape=[]
            condition=False, x='A', y='a', expected=b'a'),
        dict(  # shape=[D1]
            condition=[True, False, True],
            x=['A', 'B', 'C'],
            y=['a', 'b', 'c'],
            expected=[b'A', b'b', b'C']),
        dict(  # shape=[D1, D2]
            condition=[[True, False], [False, True]],
            x=[['A', 'B'], ['D', 'E']],
            y=[['a', 'b'], ['d', 'e']],
            expected=[[b'A', b'b'], [b'd', b'E']]),
        dict(  # shape=[D1, (D2)]
            condition=ragged.constant_value([[True, False, True],
                                             [False, True]]),
            x=ragged.constant_value([['A', 'B', 'C'], ['D', 'E']]),
            y=ragged.constant_value([['a', 'b', 'c'], ['d', 'e']]),
            expected=ragged.constant_value([[b'A', b'b', b'C'], [b'd',
                                                                 b'E']])),
        dict(  # shape=[D1, (D2), D3]
            condition=ragged.constant_value([[[True, False], [False, True]],
                                             [[True, False], [False, False],
                                              [True, False], [False, True]]],
                                            ragged_rank=1),
            x=ragged.constant_value(
                [[['A', 'B'], ['C', 'D']],
                 [['E', 'F'], ['G', 'H'], ['I', 'J'], ['K', 'L']]],
                ragged_rank=1),
            y=ragged.constant_value(
                [[['a', 'b'], ['c', 'd']],
                 [['e', 'f'], ['g', 'h'], ['i', 'j'], ['k', 'l']]],
                ragged_rank=1),
            expected=ragged.constant_value(
                [[[b'A', b'b'], [b'c', b'D']],
                 [[b'E', b'f'], [b'g', b'h'], [b'I', b'j'], [b'k', b'L']]],
                ragged_rank=1)),
        dict(  # shape=[D1, (D2), (D3), (D4)]
            condition=ragged.constant_value([[[[], [True]]],
                                             [[[True, False, True],
                                               [False, True]],
                                              [[True], [], [False],
                                               [False, True, False]]]]),
            x=ragged.constant_value([[[[], ['A']]],
                                     [[['B', 'C', 'D'], ['E', 'F']],
                                      [['G'], [], ['H'], ['I', 'J', 'K']]]]),
            y=ragged.constant_value([[[[], ['a']]],
                                     [[['b', 'c', 'd'], ['e', 'f']],
                                      [['g'], [], ['h'], ['i', 'j', 'k']]]]),
            expected=ragged.constant_value([[[[], [b'A']]],
                                            [[[b'B', b'c', b'D'], [b'e',
                                                                   b'F']],
                                             [[b'G'], [], [b'h'],
                                              [b'i', b'J', b'k']]]])),

        #=========================================================================
        # Elementwise row-selection mode
        #=========================================================================
        dict(  # shape=[D1, D2]
            condition=[True, False, True],
            x=[['A', 'B'], ['C', 'D'], ['E', 'F']],
            y=[['a', 'b'], ['c', 'd'], ['e', 'f']],
            expected=[[b'A', b'B'], [b'c', b'd'], [b'E', b'F']]),
        dict(  # shape=[D1, (D2)]
            condition=[True, False, True],
            x=ragged.constant_value([['A', 'B', 'C'], ['D', 'E'], ['F', 'G']]),
            y=ragged.constant_value([['a', 'b'], ['c'], ['d', 'e']]),
            expected=ragged.constant_value([[b'A', b'B', b'C'], [b'c'],
                                            [b'F', b'G']])),
        dict(  # shape=[D1, (D2), (D3), (D4)]
            condition=ragged.constant_value([True, False]),
            x=ragged.constant_value([[[[], ['A']]],
                                     [[['B', 'C', 'D'], ['E', 'F']],
                                      [['G'], [], ['H'], ['I', 'J', 'K']]]]),
            y=ragged.constant_value([[[['a']]], [[['b']]]]),
            expected=ragged.constant_value([[[[], [b'A']]], [[[b'b']]]])),
    ])  # pyformat: disable
    def testRaggedWhere(self, condition, expected, x=None, y=None):
        """Compares `ragged.where(condition, x, y)` against `expected`.

        The ragged rank of the result must match that of `expected`; an
        object without a `ragged_rank` attribute counts as rank 0. The
        evaluated result and `expected` are then compared, each converted
        with `tolist()` when it offers that method.

        Args:
          condition: Condition argument for `ragged.where`.
          expected: Expected result.
          x: Optional `x` argument for `ragged.where`.
          y: Optional `y` argument for `ragged.where`.
        """
        actual = ragged.where(condition, x, y)
        actual_rank = getattr(actual, 'ragged_rank', 0)
        expected_rank = getattr(expected, 'ragged_rank', 0)
        self.assertEqual(actual_rank, expected_rank)

        with self.test_session():
            actual_value = self.evaluate(actual)
            # Normalize both sides to plain Python values where possible.
            if hasattr(actual_value, 'tolist'):
                actual_value = actual_value.tolist()
            expected_value = (expected.tolist()
                              if hasattr(expected, 'tolist') else expected)
            self.assertEqual(actual_value, expected_value)

    @parameterized.parameters([
        dict(condition=[True, False],
             x=[1, 2],
             error=ValueError,
             message='x and y must be either both None or both non-None'),
        dict(condition=ragged.constant_value([[True, False, True],
                                              [False, True]]),
             x=ragged.constant_value([['A', 'B', 'C'], ['D', 'E']]),
             y=[['a', 'b'], ['d', 'e']],
             error=ValueError,
             message='Input shapes do not match.'),
    ])
    def testRaggedWhereErrors(self, condition, error, message, x=None, y=None):
        """Checks that `ragged.where` raises `error` matching `message`.

        Args:
          condition: Condition argument for `ragged.where`.
          error: Exception type expected from the call.
          message: Regular expression the exception message must match.
          x: Optional `x` argument for `ragged.where`.
          y: Optional `y` argument for `ragged.where`.
        """
        expect_failure = self.assertRaisesRegexp(error, message)
        with expect_failure:
            ragged.where(condition, x, y)
コード例 #7
0
class RaggedGatherNdOpTest(test_util.TensorFlowTestCase,
                           parameterized.TestCase):
    """Parameterized tests for `ragged.gather_nd`."""

    # Shared `params` value for the docstring-example cases below.
    DOCSTRING_PARAMS = [[['000', '001'], ['010']],
                        [['100'], ['110', '111', '112'], ['120']],
                        [[], ['210']]]  # pyformat: disable

    @parameterized.parameters([
        #=========================================================================
        # Docstring Examples
        #=========================================================================
        dict(descr='Docstring example 1',
             params=ragged.constant_value(DOCSTRING_PARAMS),
             indices=[[2], [0]],
             expected=ragged.constant_value([[[], [b'210']],
                                             [[b'000', b'001'], [b'010']]])),
        dict(descr='Docstring example 2',
             params=ragged.constant_value(DOCSTRING_PARAMS),
             indices=[[2, 1], [0, 0]],
             expected=ragged.constant_value([[b'210'], [b'000', b'001']])),
        dict(descr='Docstring example 3',
             params=ragged.constant_value(DOCSTRING_PARAMS),
             indices=[[0, 0, 1], [1, 1, 2]],
             expected=[b'001', b'112']),
        #=========================================================================
        # Indices with 0 values (selects the entire params)
        #=========================================================================
        dict(descr='params: [B1, (B2)], indices: [0], result: [B1, (B2)]',
             params=ragged.constant_value([['a', 'b', 'c'], ['d']]),
             indices=np.zeros([0], dtype=np.int32),
             expected=ragged.constant_value([[b'a', b'b', b'c'], [b'd']])),
        dict(descr=
             'params: [B1, (B2)], indices: [A1, 0], result: [A1, B1, (B2)]',
             params=ragged.constant_value([['a', 'b', 'c'], ['d']]),
             indices=np.zeros([3, 0], dtype=np.int32),
             expected=ragged.constant_value([[[b'a', b'b', b'c'], [b'd']],
                                             [[b'a', b'b', b'c'], [b'd']],
                                             [[b'a', b'b', b'c'], [b'd']]])),
        dict(descr=('params: [B1, (B2)], indices: [A1, A2, 0], '
                    'result: [A1, A2, B1, (B2)]'),
             params=ragged.constant_value([['a', 'b', 'c'], ['d']]),
             indices=np.zeros([1, 3, 0], dtype=np.int32),
             expected=ragged.constant_value([[[[b'a', b'b', b'c'], [b'd']],
                                              [[b'a', b'b', b'c'], [b'd']],
                                              [[b'a', b'b', b'c'], [b'd']]]])),
        dict(descr=
             'params: [B1], indices: [A1, (A2), 0], result: [A1, (A2), B1]',
             params=['a'],
             indices=ragged.constant_value([[[], []], [[]]],
                                           ragged_rank=1,
                                           dtype=np.int32),
             expected=ragged.constant_value([[[b'a'], [b'a']], [[b'a']]],
                                            ragged_rank=1)),
        #=========================================================================
        # Indices with 1 value (selects row from params)
        #=========================================================================
        dict(descr='params: [B1, (B2)], indices: [A1, 1], result: [A1, (B2)]',
             params=ragged.constant_value([['a', 'b', 'c'], ['d']]),
             indices=[[1], [0]],
             expected=ragged.constant_value([[b'd'], [b'a', b'b', b'c']])),
        dict(descr=('params: [B1, (B2), (B3)], indices: [A1, 1], '
                    'result: [A1, (B2), (B3)]'),
             params=ragged.constant_value([[['a', 'b', 'c'], ['d']],
                                           [['e', 'f']]]),
             indices=[[1], [1]],
             expected=ragged.constant_value([[[b'e', b'f']], [[b'e', b'f']]])),
        dict(descr=('params: [B1, B2, B3], indices: [A1, (A2), 1], '
                    'result: [A1, (A2), B2, B3]'),
             params=[[['a']], [['b']]],
             indices=ragged.constant_value([[[0]]], ragged_rank=1),
             expected=ragged.constant_value([[[[b'a']]]], ragged_rank=1)),
        #=========================================================================
        # Indices with 2 values (selects row & col from params)
        #=========================================================================
        dict(descr='params: [B1, (B2)], indices: [A1, 2], result: [A1]',
             params=ragged.constant_value([['a', 'b', 'c'], ['d']]),
             indices=[[1, 0], [0, 0], [0, 2]],
             expected=ragged.constant_value([b'd', b'a', b'c'])),
        dict(descr=('params: [B1, (B2), (B3)], indices: [A1, 2], '
                    'result: [A1, (B3)]'),
             params=ragged.constant_value([[['a', 'b', 'c'], ['d']],
                                           [['e', 'f']]]),
             indices=[[1, 0], [0, 1], [0, 0]],
             expected=ragged.constant_value([[b'e', b'f'], [b'd'],
                                             [b'a', b'b', b'c']])),
        dict(descr=('params: [B1, (B2), (B3)], indices: [A1, A2, 2], '
                    'result: [A1, (A2), (B3)]'),
             params=ragged.constant_value([[['a', 'b', 'c'], ['d']],
                                           [['e', 'f']]]),
             indices=[[[1, 0], [0, 1], [0, 0]]],
             expected=ragged.constant_value([[[b'e', b'f'], [b'd'],
                                              [b'a', b'b', b'c']]])),
        dict(descr=('params: [B1, (B2), B3], indices: [A1, A2, 2], '
                    'result: [A1, A2, B3]'),
             params=ragged.constant_value(
                 [[['a', 'b'], ['c', 'd']], [['e', 'f']]], ragged_rank=1),
             indices=[[[1, 0], [0, 1], [0, 0]]],
             expected=[[[b'e', b'f'], [b'c', b'd'], [b'a', b'b']]]),
        dict(descr=('params: [B1, (B2), B3], indices: [A1, A2, A3, 2], '
                    'result: [A1, A2, A3, B3]'),
             params=ragged.constant_value(
                 [[['a', 'b'], ['c', 'd']], [['e', 'f']]], ragged_rank=1),
             indices=[[[[1, 0], [0, 1], [0, 0]]]],
             expected=[[[[b'e', b'f'], [b'c', b'd'], [b'a', b'b']]]]),
        dict(descr=('params: [B1, (B2), (B3)], indices: [A1, (A2), 2], '
                    'result: [A1, (A2), (B3)]'),
             params=ragged.constant_value([[['a', 'b', 'c'], ['d']],
                                           [['e', 'f']]]),
             indices=ragged.constant_value([[[1, 0], [0, 1]], [[0, 0]]],
                                           ragged_rank=1),
             expected=ragged.constant_value([[[b'e', b'f'], [b'd']],
                                             [[b'a', b'b', b'c']]])),
        #=========================================================================
        # Indices with 3 values
        #=========================================================================
        dict(descr=('params: [B1, (B2), (B3)], indices: [A1, 3], '
                    'result: [A1]'),
             params=ragged.constant_value([[['a', 'b', 'c'], ['d']],
                                           [['e', 'f']]]),
             indices=[[1, 0, 1], [0, 0, 0], [0, 1, 0]],
             expected=[b'f', b'a', b'd']),
        dict(descr=('params: [B1, (B2), B3], indices: [A1, 3], '
                    'result: [A1]'),
             params=ragged.constant_value(
                 [[['a', 'b'], ['c', 'd']], [['e', 'f']]], ragged_rank=1),
             indices=[[1, 0, 1], [0, 0, 0], [0, 1, 1]],
             expected=[b'f', b'a', b'd']),
        dict(descr=('params: [B1, (B2), (B3), B4], indices: [A1, 3], '
                    'result: [A1, B4]'),
             params=ragged.constant_value(
                 [[[['a', 'b'], ['c', 'd']], [['e', 'f']]]], ragged_rank=2),
             indices=[[0, 0, 1], [0, 0, 0], [0, 1, 0]],
             expected=[[b'c', b'd'], [b'a', b'b'], [b'e', b'f']]),
    ])  # pyformat: disable
    @test_util.run_deprecated_v1
    def testRaggedGatherNd(self, descr, params, indices, expected):
        """Compares `ragged.gather_nd(params, indices)` against `expected`.

        The ragged rank of the result must match that of `expected`; an
        object without a `ragged_rank` attribute counts as rank 0. The
        evaluated result is then compared as a Python list.

        Args:
          descr: Human-readable label for the case; not used by the test body.
          params: `params` argument for `ragged.gather_nd`.
          indices: `indices` argument for `ragged.gather_nd`.
          expected: Expected result; converted with `tolist()` when it offers
            that method.
        """
        result = ragged.gather_nd(params, indices)
        self.assertEqual(getattr(result, 'ragged_rank', 0),
                         getattr(expected, 'ragged_rank', 0))
        # The session object is never referenced directly (evaluation goes
        # through `self.evaluate`), so it is not bound to a name.
        with self.test_session():
            if hasattr(expected, 'tolist'):
                expected = expected.tolist()
            self.assertEqual(self.evaluate(result).tolist(), expected)

    @test_util.run_deprecated_v1
    def testRaggedGatherNdUnknownRankError(self):
        """Checks the errors raised for placeholder indices.

        One placeholder is created with `shape=None` and one with
        `shape=[None]`; each must make `ragged.gather_nd` raise a ValueError
        matching the corresponding message.
        """
        params = ragged.constant([['a', 'b'], ['c', 'd']])
        indices1 = array_ops.placeholder(dtypes.int32, shape=None)
        indices2 = array_ops.placeholder(dtypes.int32, shape=[None])

        with self.assertRaisesRegexp(ValueError,
                                     'indices.rank be statically known.'):
            ragged.gather_nd(params, indices1)
        with self.assertRaisesRegexp(
                ValueError, r'indices.shape\[-1\] must be statically known.'):
            ragged.gather_nd(params, indices2)

    @parameterized.parameters([
        dict(params=['a'],
             indices=0,
             message='Shape must be at least rank 1 but is rank 0'
             " for 'GatherNd'"),
        dict(params=ragged.constant_value([['a']]),
             indices=0,
             message='indices.rank must be at least 1.'),
        dict(params=['a', 'b', 'c'],
             indices=ragged.constant([[0]]),
             message='The innermost dimension of indices may not be ragged'),
    ])
    @test_util.run_deprecated_v1
    def testRaggedGatherNdStaticError(self,
                                      params,
                                      indices,
                                      message,
                                      error=ValueError):
        """Checks that `ragged.gather_nd` raises `error` matching `message`.

        Args:
          params: `params` argument for `ragged.gather_nd`.
          indices: `indices` argument for `ragged.gather_nd`.
          message: Regular expression the exception message must match.
          error: Exception type expected from the call; ValueError by default.
        """
        with self.assertRaisesRegexp(error, message):
            ragged.gather_nd(params, indices)
コード例 #8
0
class RaggedConvertToTensorOrRaggedTensorTest(
        ragged_test_util.RaggedTensorTestCase, parameterized.TestCase):
    """Parameterized tests for `ragged.convert_to_tensor_or_ragged_tensor`.

    The tests are grouped by the kind of object passed as the value to
    convert: RaggedTensor, RaggedTensorValue, Tensor, and numpy array.
    """

    #=============================================================================
    # Tests where the 'value' param is a RaggedTensor
    #=============================================================================
    @parameterized.parameters([
        dict(pylist=[[1, 2], [3]]),
        dict(pylist=[[1, 2], [3]], preferred_dtype=dtypes.float32),
        dict(pylist=[[1, 2], [3]], preferred_dtype=dtypes.string),
    ])
    def testConvertRaggedTensor(self,
                                pylist,
                                dtype=None,
                                preferred_dtype=None):
        """Checks that the input built by `ragged.constant` is returned as-is.

        Args:
          pylist: Python list used to build the input with `ragged.constant`.
          dtype: Optional dtype passed to the conversion.
          preferred_dtype: Optional preferred dtype passed to the conversion.
        """
        original = ragged.constant(pylist)
        result = ragged.convert_to_tensor_or_ragged_tensor(
            original, dtype, preferred_dtype)
        # Identity, not just equality: the very same object must come back.
        self.assertIs(result, original)

    @parameterized.parameters([
        dict(pylist=[[1, 2], [3, 4]],
             dtype=dtypes.float32,
             message=('Tensor conversion requested dtype float32 for '
                      'RaggedTensor with dtype int32')),
        dict(pylist=[[1, 2], [3, 4]],
             dtype=dtypes.string,
             message=('Tensor conversion requested dtype string for '
                      'RaggedTensor with dtype .*')),
    ])
    def testConvertRaggedTensorError(self,
                                     pylist,
                                     message,
                                     dtype=None,
                                     preferred_dtype=None):
        """Checks the ValueError raised when converting a `ragged.constant`.

        Args:
          pylist: Python list used to build the input with `ragged.constant`.
          message: Regular expression the exception message must match.
          dtype: Optional dtype passed to the conversion.
          preferred_dtype: Optional preferred dtype passed to the conversion.
        """
        original = ragged.constant(pylist)
        expect_failure = self.assertRaisesRegexp(ValueError, message)
        with expect_failure:
            ragged.convert_to_tensor_or_ragged_tensor(
                original, dtype, preferred_dtype)

    #=============================================================================
    # Tests where the 'value' param is a RaggedTensorValue
    #=============================================================================
    @parameterized.parameters([
        dict(value=ragged.constant_value([[1, 2], [3]], dtype=np.int32),
             expected_dtype=dtypes.int32),
        dict(value=ragged.constant_value([[b'a', b'b'], [b'c']]),
             expected_dtype=dtypes.string),
        dict(value=ragged.constant_value([[1, 2], [3]], dtype=np.int32),
             dtype=dtypes.float32,
             expected_dtype=dtypes.float32),
        dict(value=ragged.constant_value([[1, 2], [3]], dtype=np.int32),
             preferred_dtype=dtypes.float32,
             expected_dtype=dtypes.float32),
        dict(value=ragged.constant_value([[1, 2], [3]], dtype=np.int32),
             preferred_dtype=dtypes.string,
             expected_dtype=dtypes.int32),
    ])
    def testConvertRaggedTensorValue(self,
                                     value,
                                     dtype=None,
                                     preferred_dtype=None,
                                     expected_dtype=None):
        """Checks rank, dtype and contents after converting `value`.

        Args:
          value: Input built with `ragged.constant_value`.
          dtype: Optional dtype passed to the conversion.
          preferred_dtype: Optional preferred dtype passed to the conversion.
          expected_dtype: Expected dtype of the result. When None, `dtype` is
            used if given, otherwise `value.dtype`.
        """
        if expected_dtype is None:
            expected_dtype = dtype if dtype is not None else value.dtype
        result = ragged.convert_to_tensor_or_ragged_tensor(
            value, dtype, preferred_dtype)

        self.assertEqual(value.ragged_rank, result.ragged_rank)
        wanted_dtype = dtypes.as_dtype(expected_dtype)
        self.assertEqual(wanted_dtype, result.dtype)
        self.assertEqual(value.to_list(), self.eval_to_list(result))

    @parameterized.parameters([
        dict(value=ragged.constant_value([['a', 'b'], ['c']], dtype=str),
             dtype=dtypes.int32,
             message=r"invalid literal for int\(\) with base 10: 'a'"),
    ])
    def testConvertRaggedTensorValueError(self,
                                          value,
                                          message,
                                          dtype=None,
                                          preferred_dtype=None):
        """Checks the ValueError raised when converting `value`.

        Args:
          value: Input built with `ragged.constant_value`.
          message: Regular expression the exception message must match.
          dtype: Optional dtype passed to the conversion.
          preferred_dtype: Optional preferred dtype passed to the conversion.
        """
        expect_failure = self.assertRaisesRegexp(ValueError, message)
        with expect_failure:
            ragged.convert_to_tensor_or_ragged_tensor(
                value, dtype, preferred_dtype)

    #=============================================================================
    # Tests where the 'value' param is a Tensor
    #=============================================================================
    @parameterized.parameters([
        dict(pylist=[[1, 2], [3, 4]]),
        dict(pylist=[[1, 2], [3, 4]], preferred_dtype=dtypes.float32),
        dict(pylist=[[1, 2], [3, 4]], preferred_dtype=dtypes.string),
    ])
    def testConvertTensor(self, pylist, dtype=None, preferred_dtype=None):
        """Checks that the input built by `constant_op.constant` is returned.

        Args:
          pylist: Python list used to build the input tensor.
          dtype: Optional dtype passed to the conversion.
          preferred_dtype: Optional preferred dtype passed to the conversion.
        """
        original = constant_op.constant(pylist)
        result = ragged.convert_to_tensor_or_ragged_tensor(
            original, dtype, preferred_dtype)
        # Identity, not just equality: the very same object must come back.
        self.assertIs(original, result)

    @parameterized.parameters([
        dict(pylist=[[1, 2], [3, 4]],
             dtype=dtypes.float32,
             message=('Tensor conversion requested dtype float32 for '
                      'Tensor with dtype int32')),
        dict(pylist=[[1, 2], [3, 4]],
             dtype=dtypes.string,
             message=('Tensor conversion requested dtype string for '
                      'Tensor with dtype int32')),
    ])
    def testConvertTensorError(self,
                               pylist,
                               message,
                               dtype=None,
                               preferred_dtype=None):
        """Checks the ValueError raised when converting a constant tensor.

        Args:
          pylist: Python list used to build the input tensor.
          message: Regular expression the exception message must match.
          dtype: Optional dtype passed to the conversion.
          preferred_dtype: Optional preferred dtype passed to the conversion.
        """
        original = constant_op.constant(pylist)
        expect_failure = self.assertRaisesRegexp(ValueError, message)
        with expect_failure:
            ragged.convert_to_tensor_or_ragged_tensor(
                original, dtype, preferred_dtype)

    #=============================================================================
    # Tests where the 'value' param is a np.array
    #=============================================================================
    @parameterized.parameters([
        dict(value=np.array([[1, 2], [3, 4]], dtype=np.int32),
             expected_dtype=dtypes.int32),
        dict(value=np.array([[b'a', b'b'], [b'c', b'd']]),
             expected_dtype=dtypes.string),
        dict(value=np.array([[1, 2], [3, 4]], dtype=np.int32),
             dtype=dtypes.float32,
             expected_dtype=dtypes.float32),
        dict(value=np.array([[1, 2], [3, 4]], dtype=np.int32),
             preferred_dtype=dtypes.float32,
             expected_dtype=dtypes.float32),
        dict(value=np.array([[1, 2], [3, 4]], dtype=np.int32),
             preferred_dtype=dtypes.string,
             expected_dtype=dtypes.int32),
    ])
    def testConvertNumpyArray(self,
                              value,
                              dtype=None,
                              preferred_dtype=None,
                              expected_dtype=None):
        """Checks dtype and contents after converting a numpy array.

        Args:
          value: Numpy array to convert.
          dtype: Optional dtype passed to the conversion.
          preferred_dtype: Optional preferred dtype passed to the conversion.
          expected_dtype: Expected dtype of the result. When None, `dtype` is
            used if given, otherwise `value.dtype`.
        """
        if expected_dtype is None:
            expected_dtype = dtype if dtype is not None else value.dtype
        result = ragged.convert_to_tensor_or_ragged_tensor(
            value, dtype, preferred_dtype)

        wanted_dtype = dtypes.as_dtype(expected_dtype)
        self.assertEqual(wanted_dtype, result.dtype)
        self.assertAllEqual(value, result)

    @parameterized.parameters([
        dict(value=np.array([['a', 'b'], ['c', 'd']], dtype=str),
             dtype=dtypes.int32,
             message=r"invalid literal for int\(\) with base 10: 'a'"),
    ])
    def testConvertNumpyArrayError(self,
                                   value,
                                   message,
                                   dtype=None,
                                   preferred_dtype=None):
        """Checks the ValueError raised when converting a numpy array.

        Args:
          value: Numpy array to convert.
          message: Regular expression the exception message must match.
          dtype: Optional dtype passed to the conversion.
          preferred_dtype: Optional preferred dtype passed to the conversion.
        """
        expect_failure = self.assertRaisesRegexp(ValueError, message)
        with expect_failure:
            ragged.convert_to_tensor_or_ragged_tensor(
                value, dtype, preferred_dtype)
コード例 #9
0
class RaggedBooleanMaskOpTest(test_util.TensorFlowTestCase,
                              parameterized.TestCase):
    """Tests for `ragged.boolean_mask`.

    `testBooleanMask` is table-driven: each entry supplies `data`, `mask`,
    `keepdims` and the `expected` result.  `testErrors` covers invalid
    arguments.
    """
    # Define short constants for true & false, so the data & mask can be lined
    # up in the examples below.  This makes it easier to read the examples, to
    # see which values should be kept vs. masked.
    T = True
    F = False

    @parameterized.parameters([
        #=========================================================================
        # Docstring examples
        #=========================================================================
        dict(descr='Docstring example 1',
             data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
             mask=[[T, F, T], [F, F, F], [T, F, F]],
             keepdims=False,
             expected=[1, 3, 7]),
        dict(descr='Docstring example 2',
             data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
             mask=[[T, F, T], [F, F, F], [T, F, F]],
             keepdims=True,
             expected=ragged.constant_value([[1, 3], [], [7]])),
        dict(descr='Docstring example 3',
             data=ragged.constant_value([[1, 2, 3], [4], [5, 6]]),
             mask=ragged.constant_value([[F, F, T], [F], [T, T]]),
             keepdims=False,
             expected=[3, 5, 6]),
        dict(descr='Docstring example 4',
             data=ragged.constant_value([[1, 2, 3], [4], [5, 6]]),
             mask=ragged.constant_value([[F, F, T], [F], [T, T]]),
             keepdims=True,
             expected=ragged.constant_value([[3], [], [5, 6]])),
        dict(descr='Docstring example 5',
             data=ragged.constant_value([[1, 2, 3], [4], [5, 6]]),
             mask=[True, False, True],
             keepdims=False,
             expected=ragged.constant_value([[1, 2, 3], [5, 6]])),
        #=========================================================================
        # Uniform data and uniform mask.
        #=========================================================================
        dict(descr='data.shape=[7]; mask.shape=[7]; keepdims=True',
             data=[1, 2, 3, 4, 5, 6, 7],
             mask=[T, F, T, T, F, F, F],
             keepdims=True,
             expected=[1, 3, 4]),
        dict(descr='data.shape=[5, 3]; mask.shape=[5]; keepdims=True',
             data=[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14,
                                                                   15]],
             mask=[True, False, True, True, False],
             keepdims=True,
             expected=[[1, 2, 3], [7, 8, 9], [10, 11, 12]]),
        dict(descr='data.shape=[5, 3]; mask.shape=[5, 3]; keepdims=True',
             data=[[1, 2, 3], [4, 5, 6], [7, 8, 9], [0, 1, 2], [3, 4, 5]],
             mask=[[F, F, F], [T, F, T], [T, T, T], [F, F, F], [T, T, F]],
             keepdims=True,
             expected=ragged.constant_value([[], [4, 6], [7, 8, 9], [], [3,
                                                                         4]])),
        dict(descr='data.shape=[3, 2, 2]; mask.shape=[3]; keepdims=True',
             data=[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[2, 4], [6, 8]]],
             mask=[F, F, T],
             keepdims=True,
             expected=[[[2, 4], [6, 8]]]),
        dict(descr='data.shape=[3, 2, 2]; mask.shape=[3]; keepdims=False',
             data=[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[2, 4], [6, 8]]],
             mask=[F, F, T],
             keepdims=False,
             expected=[[[2, 4], [6, 8]]]),
        dict(descr='data.shape=[3, 2, 2]; mask.shape=[3, 2]; keepdims=True',
             data=[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[2, 4], [6, 8]]],
             mask=[[T, F], [T, T], [F, F]],
             keepdims=True,
             expected=ragged.constant_value([[[1, 2]], [[5, 6], [7, 8]], []],
                                            ragged_rank=1)),
        dict(descr='data.shape=[3, 2, 2]; mask.shape=[3, 2]; keepdims=False',
             data=[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[2, 4], [6, 8]]],
             mask=[[T, F], [T, T], [F, F]],
             keepdims=False,
             expected=[[1, 2], [5, 6], [7, 8]]),
        dict(descr='data.shape=[3, 2, 2]; mask.shape=[3, 2, 2]; keepdims=True',
             data=[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[2, 4], [6, 8]]],
             mask=[[[T, T], [F, T]], [[F, F], [F, F]], [[T, F], [T, T]]],
             keepdims=True,
             expected=ragged.constant_value([[[1, 2], [4]], [[], []],
                                             [[2], [6, 8]]])),
        dict(descr='data.shape=mask.shape=[2, 2, 2, 2]; keepdims=True',
             data=[[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                   [[[2, 4], [6, 8]], [[1, 3], [5, 7]]]],
             mask=[[[[T, T], [F, F]], [[T, F], [F, F]]],
                   [[[F, F], [F, F]], [[T, T], [T, F]]]],
             keepdims=True,
             expected=ragged.constant_value([[[[1, 2], []], [[5], []]],
                                             [[[], []], [[1, 3], [5]]]])),
        dict(descr='data.shape=mask.shape=[2, 2, 2, 2]; keepdims=False',
             data=[[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                   [[[2, 4], [6, 8]], [[1, 3], [5, 7]]]],
             mask=[[[[T, T], [F, F]], [[T, F], [F, F]]],
                   [[[F, F], [F, F]], [[T, T], [T, F]]]],
             keepdims=False,
             expected=[1, 2, 5, 1, 3, 5]),

        #=========================================================================
        # Ragged data and ragged mask.
        #=========================================================================
        dict(descr='data.shape=[5, (D2)]; mask.shape=[5, (D2)]',
             data=ragged.constant_value([[1, 2], [3, 4, 5, 6], [7, 8, 9], [],
                                         [1, 2, 3]]),
             mask=ragged.constant_value([[F, F], [F, T, F, T], [F, F, F], [],
                                         [T, F, T]]),
             keepdims=True,
             expected=ragged.constant_value([[], [4, 6], [], [], [1, 3]])),
        dict(descr='data.shape=[3, (D2), (D3)]; mask.shape=[3, (D2)]',
             data=ragged.constant_value([[[1, 2], [3, 4]], [[5, 6], [7, 8]],
                                         [[2, 4], [6, 8]]]),
             mask=ragged.constant_value([[T, F], [T, T], [F, F]]),
             keepdims=True,
             expected=ragged.constant_value([[[1, 2]], [[5, 6], [7, 8]], []])),
        dict(descr='data.shape=[3, (D2), (D3)]; mask.shape=[3, (D2)]',
             data=ragged.constant_value([[[1, 2], [3, 4]], [[5, 6], [7, 8]],
                                         [[2, 4], [6, 8]]]),
             mask=ragged.constant_value([[T, F], [T, T], [F, F]]),
             keepdims=False,
             expected=ragged.constant_value([[1, 2], [5, 6], [7, 8]])),
        dict(descr='data.shape=[3, (D2), D3]; mask.shape=[3, (D2)]',
             data=ragged.constant_value(
                 [[[1, 2], [3, 4]], [[5, 6], [7, 8], [2, 4]], [[6, 8]]],
                 ragged_rank=1),
             mask=ragged.constant_value([[T, F], [T, T, F], [F]]),
             keepdims=True,
             expected=ragged.constant_value([[[1, 2]], [[5, 6], [7, 8]], []],
                                            ragged_rank=1)),
        dict(descr='data.shape=[3, (D2), D3]; mask.shape=[3, (D2)]',
             data=ragged.constant_value(
                 [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[2, 4], [6, 8]]],
                 ragged_rank=1),
             mask=ragged.constant_value([[T, F], [T, T], [F, F]]),
             keepdims=False,
             expected=[[1, 2], [5, 6], [7, 8]]),
        dict(descr='data.shape=[3, (D2), (D3)]; mask.shape=[3, (D2), (D3)]',
             data=ragged.constant_value([[[1, 2], [3, 4]], [[5, 6], [7, 8]],
                                         [[2, 4]]]),
             mask=ragged.constant_value([[[T, T], [F, T]], [[F, F], [F, F]],
                                         [[T, F]]]),
             keepdims=True,
             expected=ragged.constant_value([[[1, 2], [4]], [[], []], [[2]]])),
        dict(descr=('data.shape=[3, (D2), (D3), (D4)]; '
                    'mask.shape=[3, (D2), (D3), (D4)]'),
             data=ragged.constant_value([[[[1, 2], [3, 4]], [[5, 6]]],
                                         [[[2, 4], [6, 8]]]]),
             mask=ragged.constant_value([[[[T, T], [F, F]], [[T, F]]],
                                         [[[F, F], [T, T]]]]),
             keepdims=True,
             expected=ragged.constant_value([[[[1, 2], []], [[5]]],
                                             [[[], [6, 8]]]])),

        #=========================================================================
        # Ragged mask and uniform data
        #=========================================================================
        dict(descr='data.shape=[2, 3]; mask.shape=[2, (3)]',
             data=[[1, 2, 3], [4, 5, 6]],
             mask=ragged.constant_value([[T, F, F], [F, T, T]]),
             keepdims=True,
             expected=ragged.constant_value([[1], [5, 6]])),
        dict(descr='data.shape=[2, 3, 2]; mask.shape=[2, (3)]',
             data=[[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 0], [2, 4]]],
             mask=ragged.constant_value([[T, F, F], [F, T, T]]),
             keepdims=True,
             expected=ragged.constant_value([[[1, 2]], [[9, 0], [2, 4]]],
                                            ragged_rank=1)),
        dict(descr='data.shape=[2, 3, 2]; mask.shape=[2, (3), 2]',
             data=[[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 0], [2, 4]]],
             mask=ragged.constant_value(
                 [[[T, F], [F, F], [T, T]], [[T, F], [F, T], [F, F]]],
                 ragged_rank=1),
             keepdims=True,
             expected=ragged.constant_value([[[1], [], [5, 6]], [[7], [0],
                                                                 []]])),

        #=========================================================================
        # Ragged data and uniform mask.
        #=========================================================================
        dict(descr='data.shape=[4, (D2)]; mask.shape=[4]',
             data=ragged.constant_value([[1, 2, 3], [4], [], [5, 6]]),
             mask=[T, F, T, F],
             keepdims=False,
             expected=ragged.constant_value([[1, 2, 3], []])),
        dict(descr='data.shape=[4, (D2), (D3)]; mask.shape=[4]',
             data=ragged.constant_value([[[1, 2, 3]], [[4], []], [[5, 6]],
                                         []]),
             mask=[T, F, T, T],
             keepdims=False,
             expected=ragged.constant_value([[[1, 2, 3]], [[5, 6]], []])),
        dict(descr='data.shape=[4, (D2), 2]; mask.shape=[4]',
             data=ragged.constant_value(
                 [[[1, 2], [3, 4]], [], [[5, 6]], [[7, 8], [9, 0], [1, 2]]],
                 ragged_rank=1),
             mask=[T, F, F, T],
             keepdims=False,
             expected=ragged.constant_value(
                 [[[1, 2], [3, 4]], [[7, 8], [9, 0], [1, 2]]], ragged_rank=1)),
        dict(descr='data.shape=[4, (D2), 2]; mask.shape=[4]',
             data=ragged.constant_value(
                 [[[1, 2], [3, 4]], [], [[5, 6]], [[7, 8], [9, 0], [1, 2]]],
                 ragged_rank=1),
             mask=[T, F, F, T],
             keepdims=True,
             expected=ragged.constant_value(
                 [[[1, 2], [3, 4]], [[7, 8], [9, 0], [1, 2]]], ragged_rank=1)),
        dict(descr='data.shape=[1, (2)]; mask.shape=[1, 2]',
             data=ragged.constant_value([[1, 2]]),
             mask=[[T, F]],
             keepdims=True,
             expected=ragged.constant_value([[1]])),
        dict(descr='data.shape=[2, (2), (D3)]; mask.shape=[2, 2]',
             data=ragged.constant_value([[[1], [2, 3]], [[], [4, 5, 6]]]),
             mask=[[T, F], [T, T]],
             keepdims=True,
             expected=ragged.constant_value([[[1]], [[], [4, 5, 6]]])),
        dict(descr='data.shape=[2, (2), 3]; mask.shape=[2, 2]',
             data=ragged.constant_value(
                 [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [2, 4, 6]]],
                 ragged_rank=1),
             mask=[[T, F], [T, T]],
             keepdims=True,
             expected=ragged.constant_value(
                 [[[1, 2, 3]], [[7, 8, 9], [2, 4, 6]]], ragged_rank=1)),
        dict(descr='data.shape=[2, (2), 3]; mask.shape=[2, 2, 3]',
             data=ragged.constant_value(
                 [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [2, 4, 6]]],
                 ragged_rank=1),
             mask=[[[T, F, F], [T, F, T]], [[T, F, T], [F, F, F]]],
             keepdims=True,
             expected=ragged.constant_value([[[1], [4, 6]], [[7, 9], []]])),
    ])  # pyformat: disable
    @test_util.run_deprecated_v1
    def testBooleanMask(self, descr, data, mask, keepdims, expected):
        """Checks `ragged.boolean_mask(data, mask, keepdims)` against `expected`.

        Args:
          descr: Human-readable label for the case; not read by the test body.
          data: Value to mask; a nested list or the result of
            `ragged.constant_value`.
          mask: Boolean mask; a nested list or the result of
            `ragged.constant_value`.
          keepdims: Forwarded to `ragged.boolean_mask` as `keepdims`.
          expected: Expected result; a nested list or the result of
            `ragged.constant_value`.
        """
        actual = ragged.boolean_mask(data, mask, keepdims=keepdims)
        # Values without a `ragged_rank` attribute (e.g. plain lists) count as
        # rank 0 here via the getattr default.
        self.assertEqual(getattr(actual, 'ragged_rank', 0),
                         getattr(expected, 'ragged_rank', 0))
        with self.test_session():
            # Compare as nested Python lists so ragged and dense expectations
            # go through the same equality check.
            if isinstance(expected, ragged.RaggedTensorValue):
                expected = expected.tolist()
            self.assertEqual(actual.eval().tolist(), expected)

    @test_util.run_deprecated_v1
    def testErrors(self):
        """Checks the exception type and message for invalid arguments."""
        # Mask is a placeholder created without a shape.
        # NOTE(review): 'kown' (sic) is part of the regex and has to match the
        # text raised by the library -- presumably the library message has the
        # same spelling; confirm before "correcting" it here.
        self.assertRaisesRegexp(ValueError,
                                r'mask\.shape\.ndims must be kown statically',
                                ragged.boolean_mask, [[1, 2]],
                                array_ops.placeholder(dtypes.bool))

        # Dense mask holding ints instead of bools.
        self.assertRaisesRegexp(TypeError,
                                "Expected bool, got 0 of type 'int' instead.",
                                ragged.boolean_mask, [[1, 2]], [[0, 1]])
        # Ragged mask holding ints instead of bools.
        self.assertRaisesRegexp(
            ValueError, 'Tensor conversion requested dtype bool for '
            'RaggedTensor with dtype int32', ragged.boolean_mask,
            ragged.constant([[1, 2]]), ragged.constant([[0, 0]]))

        # Dense data and dense mask whose shapes differ.
        self.assertRaisesRegexp(
            ValueError, r'Shapes \(1, 2\) and \(1, 3\) are incompatible',
            ragged.boolean_mask, [[1, 2]], [[True, False, True]])

        # NOTE(review): the two assertions below are disabled.  The parameter
        # table for testBooleanMask contains passing cases that mix ragged and
        # non-ragged data/mask, so presumably these errors are no longer
        # raised -- confirm, then delete the dead code.
        # self.assertRaisesRegexp(ValueError,
        #                         r'data=.* is non-ragged but mask=.* is ragged',
        #                         ragged.boolean_mask, [[1, 2]],
        #                         ragged.constant([[True, False]]))

        # self.assertRaisesRegexp(
        #     ValueError, r'data=.* is ragged but mask=.* is non-ragged',
        #     ragged.boolean_mask, ragged.constant([[1, 2]]), [[True, False]])

        # Ragged data and ragged mask whose rows have different lengths.
        self.assertRaisesRegexp(errors.InvalidArgumentError,
                                r'Inputs must have identical ragged splits',
                                ragged.boolean_mask, ragged.constant([[1, 2]]),
                                ragged.constant([[True, False, True]]))

        # Scalar mask with dense data.
        self.assertRaisesRegexp(ValueError, 'mask cannot be scalar',
                                ragged.boolean_mask, [[1, 2]], True)

        # Scalar mask with ragged data.
        self.assertRaisesRegexp(ValueError,
                                'mask cannot be scalar', ragged.boolean_mask,
                                ragged.constant([[1, 2]]), True)
コード例 #10
0
class RaggedBatchGatherOpTest(ragged_test_util.RaggedTensorTestCase,
                              parameterized.TestCase):
    """Tests for `ragged.batch_gather`.

    Covers table-driven success cases grouped by number of batch dimensions,
    the unknown-shape error in graph mode, and a table of invalid
    params/indices combinations.
    """
    @parameterized.parameters([
        #=========================================================================
        # Docstring Example
        #=========================================================================
        dict(descr='Docstring example',
             params=ragged.constant_value([['a', 'b', 'c'], ['d'], [], ['e']]),
             indices=ragged.constant_value([[1, 2, 0], [], [], [0, 0]]),
             expected=ragged.constant_value([[b'b', b'c', b'a'], [], [],
                                             [b'e', b'e']])),
        #=========================================================================
        # 0 Batch Dimensions
        #=========================================================================
        dict(descr='params: [P1], indices: [I], result: [I]',
             params=['a', 'b', 'c', 'd'],
             indices=[3, 2],
             expected=[b'd', b'c']),
        dict(descr='params: [P1, (P2)], indices: [I], result: [I, (P2)]',
             params=ragged.constant_value([['a', 'b'], [], ['c'], ['d', 'e']]),
             indices=[3, 2],
             expected=ragged.constant_value([[b'd', b'e'], [b'c']])),
        #=========================================================================
        # 1 Batch Dimension
        #=========================================================================
        dict(descr='params: [B1, P1], indices: [B1, I], result: [B1, I]',
             params=[['a', 'b', 'c'], ['d', 'e', 'f'], ['g', 'h', 'i']],
             indices=[[2, 0], [0, 1], [1, 0]],
             expected=[[b'c', b'a'], [b'd', b'e'], [b'h', b'g']]),
        dict(descr='params: [B1, (P1)], indices: [B1, I], result: [B1, I]',
             params=ragged.constant_value([['a', 'b', 'c'], ['d', 'e'],
                                           ['g']]),
             indices=[[2, 0], [0, 1], [0, 0]],
             expected=[[b'c', b'a'], [b'd', b'e'], [b'g', b'g']]),
        dict(descr='params: [B1, P1], indices: [B1, (I)], result: [B1, (I)]',
             params=[['a', 'b', 'c'], ['d', 'e', 'f'], ['g', 'h', 'i']],
             indices=ragged.constant_value([[2, 0, 2], [0], [1]]),
             expected=ragged.constant_value([[b'c', b'a', b'c'], [b'd'],
                                             [b'h']])),
        dict(descr=('params: [B1, (P1), (P2), P3], indices: [B1, I], '
                    'result: [B1, I, (P2), P3]'),
             params=ragged.constant_value(
                 [[[['a']], [['b'], ['c']]], [[['d'], ['e']], [['f']]],
                  [[['g']]]],
                 ragged_rank=2),
             indices=[[1, 0], [0, 1], [0, 0]],
             expected=ragged.constant_value(
                 [[[[b'b'], [b'c']], [[b'a']]], [[[b'd'], [b'e']], [[b'f']]],
                  [[[b'g']], [[b'g']]]],
                 ragged_rank=2)),
        #=========================================================================
        # 2 Batch Dimensions
        #=========================================================================
        dict(descr=('params: [B1, B2, P1], indices: [B1, B2, I], '
                    'result: [B1, B2, I]'),
             params=[[['a', 'b', 'c']], [['d', 'e', 'f']], [['g', 'h', 'i']]],
             indices=[[[2, 0]], [[0, 1]], [[1, 0]]],
             expected=[[[b'c', b'a']], [[b'd', b'e']], [[b'h', b'g']]]),
        dict(descr=('params: [B1, (B2), P1], indices: [B1, (B2), I], '
                    'result: [B1, (B2), I]'),
             params=ragged.constant_value(
                 [[['a', 'b', 'c'], ['d', 'e', 'f']], [['g', 'h', 'i']]],
                 ragged_rank=1),
             indices=ragged.constant_value([[[2, 0], [0, 1]], [[1, 0]]],
                                           ragged_rank=1),
             expected=ragged.constant_value(
                 [[[b'c', b'a'], [b'd', b'e']], [[b'h', b'g']]],
                 ragged_rank=1)),
        dict(descr=('params: [B1, (B2), (P1)], indices: [B1, (B2), I], '
                    'result: [B1, (B2), I]'),
             params=ragged.constant_value(
                 [[['a', 'b', 'c'], ['d']], [['e', 'f']]], ragged_rank=2),
             indices=ragged.constant_value([[[2, 0], [0, 0]], [[1, 0]]],
                                           ragged_rank=1),
             expected=ragged.constant_value(
                 [[[b'c', b'a'], [b'd', b'd']], [[b'f', b'e']]],
                 ragged_rank=1)),
        dict(descr=('params: [B1, (B2), P1], indices: [B1, (B2), (I)], '
                    'result: [B1, (B2), (I)]'),
             params=ragged.constant_value(
                 [[['a', 'b', 'c'], ['d', 'e', 'f']], [['g', 'h', 'i']]],
                 ragged_rank=1),
             indices=ragged.constant_value([[[2, 1, 0], [0]], [[1, 1]]],
                                           ragged_rank=2),
             expected=ragged.constant_value(
                 [[[b'c', b'b', b'a'], [b'd']], [[b'h', b'h']]],
                 ragged_rank=2)),
        #=========================================================================
        # 3 Batch Dimensions
        #=========================================================================
        dict(descr=(
            'params: [B1, (B2), (B3), (P1)], indices: [B1, (B2), (B3), I], '
            'result: [B1, (B2), (B3), I]'),
             params=ragged.constant_value(
                 [[[['a', 'b', 'c'], ['d']], [['e', 'f']]]], ragged_rank=3),
             indices=ragged.constant_value([[[[2, 0], [0, 0]], [[1, 0]]]],
                                           ragged_rank=2),
             expected=ragged.constant_value(
                 [[[[b'c', b'a'], [b'd', b'd']], [[b'f', b'e']]]],
                 ragged_rank=2)),
    ])
    def testRaggedBatchGather(self, descr, params, indices, expected):
        """Checks that `ragged.batch_gather(params, indices)` equals `expected`.

        Args:
          descr: Human-readable label for the case; not read by the test body.
          params: Values to gather from; a nested list or the result of
            `ragged.constant_value`.
          indices: Indices to gather; a nested list or the result of
            `ragged.constant_value`.
          expected: Expected result, compared with `assertRaggedEqual`.
        """
        result = ragged.batch_gather(params, indices)
        self.assertRaggedEqual(result, expected)

    def testRaggedBatchGatherUnknownRankError(self):
        """Checks the error raised when `indices` has an unknown shape.

        Returns early without asserting anything when executing eagerly.
        """
        if context.executing_eagerly():
            return
        params = [['a', 'b'], ['c', 'd']]
        # Placeholder with shape=None, used both directly and as the values of
        # a RaggedTensor built from row splits.
        indices = array_ops.placeholder(dtypes.int32, shape=None)
        ragged_indices = ragged.RaggedTensor.from_row_splits(
            indices, [0, 2, 4])

        with self.assertRaisesRegexp(
                ValueError,
                'batch_gather does not allow indices with unknown shape.'):
            ragged.batch_gather(params, indices)

        with self.assertRaisesRegexp(
                ValueError,
                'batch_gather does not allow indices with unknown shape.'):
            ragged.batch_gather(params, ragged_indices)

    @parameterized.parameters([
        dict(params=ragged.constant_value([['a'], ['b'], ['c']]),
             indices=ragged.constant_value([[0], [0]]),
             message='Dimensions 3 and 2 are not compatible'),
        dict(params=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
             indices=ragged.constant_value([[[0, 0], [0, 0, 0]], [[0]]]),
             message='batch shape from indices does not match params shape'),
        dict(  # rank mismatch
            params=ragged.constant_value([[[0, 0], [0, 0, 0]], [[0]]]),
            indices=ragged.constant_value([[[0, 0]], [[0, 0, 0]], [[0]]]),
            error=(ValueError, errors.InvalidArgumentError)),
        dict(params=ragged.constant_value([[[0, 0], [0, 0, 0]], [[0]], [[0]]]),
             indices=ragged.constant_value([[[0, 0]], [[0, 0, 0]], [[0]]]),
             error=errors.InvalidArgumentError,
             message='.*Condition x == y did not hold.*'),
        dict(params=ragged.constant_value(['a', 'b', 'c']),
             indices=ragged.constant_value([[0], [0]]),
             message='batch shape from indices does not match params shape'),
        dict(params=ragged.constant_value([['a']]),
             indices=0,
             message='indices.rank must be at least 1.'),
        dict(params=ragged.constant_value([['a']]),
             indices=[[[0]]],
             message='batch shape from indices does not match params shape'),
    ])
    def testRaggedBatchGatherStaticError(self,
                                         params,
                                         indices,
                                         message=None,
                                         error=ValueError):
        """Checks that `ragged.batch_gather(params, indices)` raises `error`.

        Args:
          params: Values to gather from.
          indices: Indices to gather.
          message: Regular expression passed to `assertRaisesRegexp`.  The
            "rank mismatch" case supplies none, so `None` is passed through
            (presumably only the exception type is checked then -- confirm).
          error: Expected exception type, or tuple of types; defaults to
            `ValueError`.
        """
        with self.assertRaisesRegexp(error, message):
            ragged.batch_gather(params, indices)
コード例 #11
0
class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
                               parameterized.TestCase):

  def assertSameShape(self, x, y):
    """Asserts that `x` and `y` agree in shape, ragged structure included.

    For a ragged `x`, `y` must also be ragged with the same `ragged_rank`,
    equal row splits at every nesting level, and inner values of equal shape.
    Otherwise `y` must be an `ops.Tensor` whose shape equals that of `x`.

    Args:
      x: Reference value.
      y: Value whose shape is compared against `x`.
    """
    # Dense case first, so the ragged checks below read without nesting.
    if not isinstance(x, ragged.RaggedTensor):
      self.assertIsInstance(y, ops.Tensor)
      self.assertAllEqual(array_ops.shape(x), array_ops.shape(y))
      return

    self.assertIsInstance(y, ragged.RaggedTensor)
    self.assertEqual(x.ragged_rank, y.ragged_rank)
    split_pairs = zip(x.nested_row_splits, y.nested_row_splits)
    for lhs_splits, rhs_splits in split_pairs:
      self.assertAllEqual(lhs_splits, rhs_splits)
    lhs_inner_shape = array_ops.shape(x.inner_values)
    rhs_inner_shape = array_ops.shape(y.inner_values)
    self.assertAllEqual(lhs_inner_shape, rhs_inner_shape)

  @parameterized.parameters(
      #=========================================================================
      # Test different input shapes.
      #=========================================================================
      [
          # 0-dimensional input
          dict(x=12),
          # 1-dimensional input
          dict(x=[1, -2, 3]),
          # 2-dimensional input
          dict(x=[[-2, 3], [-3, 4]]),
          dict(x=ragged.constant_value([[-2, 3], [-3]], ragged_rank=1)),
          # 3-dimensional inputs
          dict(x=[[[-2, 3], [3, 4]], [[7, 6], [5, 4]]]),
          dict(x=ragged.constant_value([[[-2, 3], [3, 4]], [[7, 6]]],
                                       ragged_rank=1)),
          dict(x=ragged.constant_value([[[-2, 3, 4], []], [[7, 6]], []],
                                       ragged_rank=2)),
      ] +
      #=========================================================================
      # Test each unary op.
      #=========================================================================
      [dict(x=ragged.constant_value([[-2.0, 3.0], [-3.0]]), op=op)
       for op in UNARY_FLOAT_OPS] +
      [dict(x=ragged.constant_value([[True, False], [True]]), op=op)
       for op in UNARY_BOOL_OPS] +
      [dict(x=ragged.constant_value([[18, 512], [12412]], np.int32), op=op)
       for op in UNARY_INT_OPS] +
      [dict(x=ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]), op=op)
       for op in UNARY_STRING_OPS] +
      #=========================================================================
      # Ops that need extra keyword arguments.
      #=========================================================================
      [
          dict(op=ragged.clip_by_value,
               x=ragged.constant_value([[-2.0, 3.0], [-3.0]]),
               clip_value_min=0.1,
               clip_value_max=4.0),
          dict(op=ragged.cast,
               x=ragged.constant_value([[-2.0, 3.0], [-3.0]]),
               dtype=dtypes.int32),
          dict(op=ragged.saturate_cast,
               x=ragged.constant_value([[-2.0, 3.0], [-3.0]]),
               dtype=dtypes.int32),
          dict(op=ragged.string_to_hash_bucket,
               x=ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
               num_buckets=1000),
          dict(op=ragged.string_to_hash_bucket_fast,
               x=ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
               num_buckets=1000),
          dict(op=ragged.string_to_hash_bucket_strong,
               x=ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
               num_buckets=1000,
               key=[1231, 12512]),
          dict(op=ragged.string_to_number,
               x=ragged.constant_value([['-2.0', '3.0'], ['-3.0']])),
          dict(op=ragged.regex_full_match,
               x=ragged.constant_value([['hello', '123'], ['1+1']]),
               pattern=r'\w+'),
          dict(op=ragged.regex_replace,
               x=ragged.constant_value([['hello', '123'], ['1+1']]),
               pattern=r'\d',
               rewrite='#'),
          dict(op=ragged.substr,
               x=ragged.constant_value([['hello', '123'], ['1+1']]),
               pos=2,
               len=3),
          dict(op=ragged.check_numerics,
               x=ragged.constant_value([[-2.0, 3.0], [-3.0]]),
               message='check-numerics'),
      ]
      )  # pyformat: disable
  def testUnaryOp(self, x, op=ragged.abs, **extra_args):
    """Compares a ragged unary op with its wrapped op on the flat values.

    Args:
      x: Input value; converted with
        `ragged.convert_to_tensor_or_ragged_tensor` before use.
      op: Elementwise op under test; must expose the underlying op as
        `__wrapped__`.  Defaults to `ragged.abs`.
      **extra_args: Extra keyword arguments passed to both `op` and
        `op.__wrapped__`.
    """
    x = ragged.convert_to_tensor_or_ragged_tensor(x)
    result = op(x, **extra_args)

    # Reference result: the wrapped op applied to the inner (non-ragged)
    # values of `x`, flattened to one dimension.
    if isinstance(x, ragged.RaggedTensor):
      dense_x = x.inner_values
    else:
      dense_x = x
    dense_result = op.__wrapped__(dense_x, **extra_args)
    expected_flat_values = array_ops.reshape(dense_result, [-1])

    with self.test_session():
      # The op must keep the shape of its input.
      self.assertSameShape(x, result)

      # Flatten the actual result the same way and compare element by element.
      if isinstance(result, ragged.RaggedTensor):
        result_values = result.inner_values
      else:
        result_values = result
      result_flat_values = array_ops.reshape(result_values, [-1])
      self.assertAllEqual(expected_flat_values, result_flat_values)

  @parameterized.parameters(
      [
          #=====================================================================
          # Without broadcasting -- i.e., shapes match exactly.
          #=====================================================================
          # Shapes: x:(), y:()
          {'x': 12,
           'y': 8},
          # Shapes: x:(3,), y:(3,)
          {'x': [7, 8, 9],
           'y': [1, -2, 3]},
          # Shapes: x:(2, 2), y:(2, 2)
          {'x': [[-2, 3], [-3, -4]],
           'y': [[1, 2], [3, 4]]},
          # Shapes: x:(2, None), y:(2, None)
          {'x': ragged.constant_value([[-2, 3], [-3]]),
           'y': ragged.constant_value([[5, 6], [7]])},
          # Shapes: x:(2, 2, 2), y:(2, 2, 2)
          {'x': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
           'y': [[[9, 3], [3, 4]], [[5, 2], [7, 6]]]},
          # Shapes: x:(2, None, None), y: (2, None, None)
          {'x': ragged.constant_value([[[1, 2], [3], [4]], [[], [5, 7, 8]]]),
           'y': ragged.constant_value([[[3, 8], [2], [5]], [[], [1, 9, 8]]])},
          # Shapes: x:(2, None, 2), y: (2, None, 2)
          {'x': ragged.constant_value([[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
                                      ragged_rank=1),
           'y': ragged.constant_value([[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
                                      ragged_rank=1)},

          #=====================================================================
          # With broadcasting
          #=====================================================================
          # Shapes: x:(), y:(3,)
          {'x': 12,                                 # Broadcast () -> (3,)
           'y': [1, -2, 3]},
          # Shapes: x:(1,), y:(3,)
          {'x': [12],                               # Broadcast (1,) -> (3,)
           'y': [1, -2, 3]},
          # Shapes: x:(), y:(2, 2)
          {'x': 12,                                 # Broadcast () -> (2, 2)
           'y': [[1, 2], [3, 4]]},
          # Shapes: x:(1,), y:(2, 2)
          {'x': 12,                                 # Broadcast (1,) -> (2, 2)
           'y': [[1, 2], [3, 4]]},
          # Shapes: x:(2, 1), y:(2, 2)
          {'x': [[10], [20]],                       # Broadcast (2, 1) -> (2, 2)
           'y': [[1, 2], [3, 4]]},
          # Shapes: x:(), y:(2, None)
          {'x': 10,                                 # Broadcast () -> (2, None)
           'y': ragged.constant_value([[1, 2], [3]], dtype=np.int32)},
          # TODO(edloper): Add tests for more advanced broadcasting, once we add
          # support for it.

          #=====================================================================
          # Keyword Args
          #=====================================================================
          {'x': ragged.constant_value([[[1, 2], [3], [4]], [[], [5, 7, 8]]]),
           'y': ragged.constant_value([[[3, 8], [2], [5]], [[], [1, 9, 8]]]),
           'use_kwargs': True},
          {'x': ragged.constant_value([[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
                                      ragged_rank=1),
           'y': ragged.constant_value([[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
                                      ragged_rank=1),
           'use_kwargs': True},
      ] +
      #=========================================================================
      # Test each unary op.
      #=========================================================================
      [{'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
        'y': ragged.constant_value([[5.0, 1.0], [12.0]]),
        'op': op}
       for op in BINARY_FLOAT_OPS] +
      [{'x': ragged.constant_value([[-2, 3], [-3]]),
        'y': ragged.constant_value([[5, 1], [12]]),
        'op': op}
       for op in BINARY_INT_OPS] +
      [{'x': ragged.constant_value([[True, True], [False]]),
        'y': ragged.constant_value([[False, True], [False]]),
        'op': op}
       for op in BINARY_BOOL_OPS] +
      [
      ]
      )  # pyformat: disable
  def testBinaryOp(self, x, y, op=ragged.add, **extra_args):
    """Checks a ragged binary elementwise op against its wrapped dense op.

    Args:
      x: First operand; converted with `convert_to_tensor_or_ragged_tensor`.
      y: Second operand; converted the same way.  The result is asserted to
        have the same shape as `y`.
      op: The ragged elementwise op under test.
      **extra_args: Keyword arguments forwarded to `op`.  The key
        `use_kwargs` is popped first; when true, `x` and `y` are passed to
        `op` by keyword instead of positionally.
    """
    pass_by_keyword = extra_args.pop('use_kwargs', False)
    x = ragged.convert_to_tensor_or_ragged_tensor(x)
    y = ragged.convert_to_tensor_or_ragged_tensor(y)
    result = (op(x=x, y=y, **extra_args) if pass_by_keyword
              else op(x, y, **extra_args))

    # Reference values: apply the wrapped op to the inner (dense) values.
    dense_x, dense_y = [
        operand.inner_values
        if isinstance(operand, ragged.RaggedTensor) else operand
        for operand in (x, y)
    ]
    reference = op.__wrapped__(dense_x, dense_y, **extra_args)
    expected_flat_values = array_ops.reshape(reference, [-1])

    with self.test_session():
      # The result must be shaped like `y`.
      self.assertSameShape(y, result)

      # The flattened result values must match the flattened reference.
      actual = (result.inner_values
                if isinstance(result, ragged.RaggedTensor) else result)
      self.assertAllEqual(expected_flat_values,
                          array_ops.reshape(actual, [-1]))

  @parameterized.parameters(
      [
          {'inputs': (12, 8, 3)},
          {'inputs': ([1, 2, 3], [7, 8, 9], [3, 6, 9])},
          {'inputs': ([[1, 2]], [[3, 4]], [[5, 6]])},
          {'inputs': (ragged.constant_value([[1, 3], [-3]]),
                      ragged.constant_value([[4, 7], [88]]),
                      ragged.constant_value([[2, 9], [12]]))},
          {'inputs': (ragged.constant_value([[[1, 3], [-3]], [[1]]]),
                      ragged.constant_value([[[4, 7], [88]], [[2]]]),
                      ragged.constant_value([[[2, 9], [12]], [[8]]]))},
          {'inputs': (ragged.constant_value([[[1, 3], [3, 4]], [[1, 5]]],
                                            ragged_rank=1),
                      ragged.constant_value([[[4, 7], [1, 2]], [[2, 2]]],
                                            ragged_rank=1),
                      ragged.constant_value([[[2, 9], [5, 2]], [[8, 0]]],
                                            ragged_rank=1))},
          {'inputs': (ragged.constant_value([[[1, 3], [-3]], [[1]]]),
                      ragged.constant_value([[[4, 7], [88]], [[2]]]),
                      ragged.constant_value([[[2, 9], [12]], [[8]]])),
           'use_kwargs': True},
      ] + [
          {'op': ragged.add_n,
           'inputs': (ragged.constant_value([[1, 3], [-3]]),
                      ragged.constant_value([[4, 7], [88]]),
                      ragged.constant_value([[2, 9], [12]]))},
          {'op': ragged.string_join,
           'inputs': (ragged.constant_value([['a', 'b'], ['c']]),
                      ragged.constant_value([['foo', 'bar'], ['baz']]),
                      ragged.constant_value([['2', '9'], ['12']]))},
      ])  # pyformat: disable
  def testListValuedOp(self, inputs, op=ragged.add_n, **extra_args):
    """Checks a list-valued ragged op against its wrapped dense op.

    Args:
      inputs: Sequence of operands; each is converted with
        `convert_to_tensor_or_ragged_tensor`.  The result is asserted to have
        the same shape as the first one.
      op: The ragged op under test; it receives the whole list of inputs.
      **extra_args: Keyword arguments forwarded to `op`.  The key
        `use_kwargs` is popped first; when true, the list is passed as the
        `inputs` keyword instead of positionally.
    """
    pass_by_keyword = extra_args.pop('use_kwargs', False)
    inputs = [ragged.convert_to_tensor_or_ragged_tensor(x) for x in inputs]
    result = (op(inputs=inputs, **extra_args) if pass_by_keyword
              else op(inputs, **extra_args))

    # Reference values: apply the wrapped op to the inner (dense) values.
    dense_inputs = []
    for tensor in inputs:
      if isinstance(tensor, ragged.RaggedTensor):
        dense_inputs.append(tensor.inner_values)
      else:
        dense_inputs.append(tensor)
    reference = op.__wrapped__(dense_inputs, **extra_args)
    expected_flat_values = array_ops.reshape(reference, [-1])

    with self.test_session():
      # The result must be shaped like the first input.
      self.assertSameShape(inputs[0], result)

      # The flattened result values must match the flattened reference.
      actual = (result.inner_values
                if isinstance(result, ragged.RaggedTensor) else result)
      self.assertAllEqual(expected_flat_values,
                          array_ops.reshape(actual, [-1]))

  def testUnknownRankError(self):
    """`ragged.add` raises ValueError when an input's values have shape=None."""
    known = ragged.constant([[1, 2], [3]])
    # A placeholder with shape=None serves as the values of the second operand.
    shapeless_values = array_ops.placeholder_with_default([1, 2, 3], shape=None)
    unknown = ragged.from_row_splits(shapeless_values, known.row_splits)
    expected_message = (r'Ragged elementwise ops require that rank \(number '
                        r'of dimensions\) be statically known.')
    with self.assertRaisesRegexp(ValueError, expected_message):
      ragged.add(known, unknown)

  def testBroadcastError1(self):
    """Adding the nested list [[12]] to a ragged tensor raises ValueError."""
    ragged_operand = ragged.constant([[1, 2], [3]])
    dense_operand = [[12]]
    expected_message = 'Ragged elementwise ops do not support broadcasting yet'
    with self.assertRaisesRegexp(ValueError, expected_message):
      ragged.add(ragged_operand, dense_operand)

  def testBroadcastError2(self):
    """Adding ragged_rank=2 and ragged_rank=1 operands raises ValueError."""
    lhs = ragged.constant([[[1, 2], [3, 4]], [[5]]], ragged_rank=2)
    rhs = ragged.constant([[[8], [3]], [[2]]], ragged_rank=1)
    expected_message = 'Inputs must have identical ragged splits'
    with self.assertRaisesRegexp(ValueError, expected_message):
      ragged.add(lhs, rhs)

  def testBroadcastError3(self):
    """Adding a 3-level and a 2-level ragged constant raises ValueError."""
    deeper = ragged.constant([[[1, 2], [3]], [[4, 5], [6]]], ragged_rank=2)
    shallower = ragged.constant([[7, 8], [9]], ragged_rank=1)
    expected_message = 'Ragged elementwise ops do not support broadcasting yet'
    with self.assertRaisesRegexp(ValueError, expected_message):
      ragged.add(deeper, shallower)

  def testBroadcastError4(self):
    """Adding [[[1]]] and [[1]] as ragged constants raises ValueError."""
    deeper = ragged.constant([[[1]]])
    shallower = ragged.constant([[1]])
    expected_message = 'Ragged elementwise ops do not support broadcasting yet'
    with self.assertRaisesRegexp(ValueError, expected_message):
      ragged.add(deeper, shallower)

  def testShapeMismatch(self):
    """Adding ragged tensors with different row lengths is rejected."""
    # The second rows differ in length: two elements versus three.
    lhs = ragged.constant([[1, 2, 3], [4, 5]])
    rhs = ragged.constant([[1, 2, 3], [4, 5, 6]])
    expected_message = 'Inputs must have identical ragged splits'
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 expected_message):
      ragged.add(lhs, rhs)

  def testDocstring(self):
    """Checks the `__doc__` and `__name__` exposed by `ragged.add`."""
    add_op = ragged.add
    expected_doc_pattern = (
        'Ragged version of the elementwise operation `tf.math.add`')
    self.assertRegexpMatches(add_op.__doc__, expected_doc_pattern)
    self.assertEqual(add_op.__name__, 'add')
コード例 #12
0
class RaggedTensorBoundingShapeOp(ragged_test_util.RaggedTensorTestCase,
                                  parameterized.TestCase):

  def assertShapeEq(self, x, y):
    """Asserts that two `RaggedTensorDynamicShape` objects are equal.

    The partitioned dimension sizes are compared after converting each one
    with `self.eval_to_list`; the inner dimension sizes are compared with
    `assertAllEqual`.

    Args:
      x: A `ragged.RaggedTensorDynamicShape`.
      y: A `ragged.RaggedTensorDynamicShape`.
    """
    # Use unittest assertions rather than bare `assert` statements so the type
    # checks still run under `python -O` and report a message on failure.
    self.assertIsInstance(x, ragged.RaggedTensorDynamicShape)
    self.assertIsInstance(y, ragged.RaggedTensorDynamicShape)
    x_partitioned_dim_sizes = [
        self.eval_to_list(splits) for splits in x.partitioned_dim_sizes
    ]
    y_partitioned_dim_sizes = [
        self.eval_to_list(splits) for splits in y.partitioned_dim_sizes
    ]
    self.assertEqual(x_partitioned_dim_sizes, y_partitioned_dim_sizes)
    self.assertAllEqual(x.inner_dim_sizes, y.inner_dim_sizes)

  @parameterized.parameters([
      dict(value='x', expected_dim_sizes=[]),
      dict(value=['a', 'b', 'c'], expected_dim_sizes=[3]),
      dict(value=[['a', 'b', 'c'], ['d', 'e', 'f']], expected_dim_sizes=[2, 3]),
      dict(
          value=[[['a', 'b', 'c'], ['d', 'e', 'f']]],
          expected_dim_sizes=[1, 2, 3]),
      dict(
          value=ragged.constant_value([['a', 'b', 'c'], ['d', 'e']]),
          expected_dim_sizes=[2, [3, 2]]),
      dict(
          value=ragged.constant_value([[['a', 'b', 'c'], ['d', 'e']]]),
          expected_dim_sizes=[1, [2], [3, 2]]),
      dict(
          value=ragged.constant_value([[['a', 'b', 'c'], ['d', 'e', 'f']]],
                                      ragged_rank=1),
          expected_dim_sizes=[1, [2], 3]),
      dict(
          value=ragged.constant_value([[[[1], [2]], [[3], [4]]],
                                       [[[5], [6]]]], ragged_rank=1),
          expected_dim_sizes=[2, [2, 1], 2, 1]),
      dict(
          value=ragged.constant_value([[10, 20], [30]]),
          expected_dim_sizes=[2, [2, 1]]),
      # Docstring examples:
      dict(value=[[1, 2, 3], [4, 5, 6]], expected_dim_sizes=[2, 3]),
      dict(
          value=ragged.constant_value([[1, 2], [], [3, 4, 5]]),
          expected_dim_sizes=[3, [2, 0, 3]]),
      dict(
          value=ragged.constant_value([[[1, 2], [3, 4]], [[5, 6]]],
                                      ragged_rank=1),
          expected_dim_sizes=[2, [2, 1], 2]),
      dict(
          value=ragged.constant_value([[[1, 2], [3]], [[4, 5]]]),
          expected_dim_sizes=[2, [2, 1], [2, 1, 2]]),
  ])
  def testFromTensor(self, value, expected_dim_sizes):
    """`from_tensor(value)` equals `from_dim_sizes(expected_dim_sizes)`."""
    shape_cls = ragged.RaggedTensorDynamicShape
    actual_shape = shape_cls.from_tensor(value)
    expected_shape = shape_cls.from_dim_sizes(expected_dim_sizes)
    self.assertShapeEq(actual_shape, expected_shape)

  @parameterized.parameters([
      dict(dim_sizes=[], rank=0, expected_dim_sizes=[]),
      dict(dim_sizes=[], rank=3, expected_dim_sizes=[1, 1, 1]),
      dict(dim_sizes=[3], rank=1, expected_dim_sizes=[3]),
      dict(dim_sizes=[3], rank=3, expected_dim_sizes=[1, 1, 3]),
      dict(dim_sizes=[2, 3], rank=3, expected_dim_sizes=[1, 2, 3]),
      dict(dim_sizes=[3, [3, 2, 4]], rank=2, expected_dim_sizes=[3, [3, 2, 4]]),
      dict(
          dim_sizes=[3, [3, 2, 4]],
          rank=4,
          expected_dim_sizes=[1, 1, 3, [3, 2, 4]]),
      dict(
          dim_sizes=[3, [3, 2, 4], 2, 3],
          rank=5,
          expected_dim_sizes=[1, 3, [3, 2, 4], 2, 3]),
  ])
  def testBroadcastToRank(self, dim_sizes, rank, expected_dim_sizes):
    """`broadcast_to_rank(rank)` yields the expected shape and rank.

    Args:
      dim_sizes: Dimension sizes of the shape to broadcast.
      rank: The rank passed to `broadcast_to_rank`.
      expected_dim_sizes: Dimension sizes of the expected result.
    """
    shape_cls = ragged.RaggedTensorDynamicShape
    original = shape_cls.from_dim_sizes(dim_sizes)
    expected = shape_cls.from_dim_sizes(expected_dim_sizes)
    broadcasted = original.broadcast_to_rank(rank)
    self.assertShapeEq(broadcasted, expected)
    self.assertEqual(broadcasted.rank, rank)

  @parameterized.parameters([
      #=========================================================================
      # dimension[axis] is uniform inner; and row_lengths is a scalar
      #=========================================================================
      # shape: [BROADCAST(UNIFORM), UNIFORM, UNIFORM]
      dict(axis=0,
           row_length=3,
           original_dim_sizes=[1, 4, 5],
           broadcast_dim_sizes=[3, 4, 5]),

      # shape: [UNIFORM, UNIFORM, BROADCAST(UNIFORM)]
      dict(axis=2,
           row_length=5,
           original_dim_sizes=[3, 4, 1],
           broadcast_dim_sizes=[3, 4, 5]),

      # shape: [UNIFORM, RAGGED, BROADCAST(UNIFORM)]
      dict(axis=2,
           row_length=5,
           original_dim_sizes=[3, [3, 2, 8], 1],
           broadcast_dim_sizes=[3, [3, 2, 8], 5]),

      # shape: [UNIFORM, RAGGED, RAGGED, UNIFORM, UNIFORM, BROADCAST(UNIFORM)]
      dict(axis=5,
           row_length=5,
           original_dim_sizes=[2, [2, 1], [3, 2, 8], 3, 4, 1],
           broadcast_dim_sizes=[2, [2, 1], [3, 2, 8], 3, 4, 5]),

      #=========================================================================
      # dimension[axis] is uniform inner; and row_lengths is a vector
      #=========================================================================
      # shape: [UNIFORM, BROADCAST(UNIFORM)]
      dict(axis=1,
           row_length=[2, 0, 1],
           original_dim_sizes=[3, 1],
           broadcast_dim_sizes=[3, [2, 0, 1]]),
      # shape: [UNIFORM, BROADCAST(UNIFORM), UNIFORM]
      dict(axis=1,
           row_length=[2, 0, 1],
           original_dim_sizes=[3, 1, 5],
           broadcast_dim_sizes=[3, [2, 0, 1], 5]),

      # shape: [UNIFORM, UNIFORM, BROADCAST(UNIFORM)]
      dict(axis=2,
           row_length=[2, 0, 1, 3, 8, 2, 3, 4, 1, 8, 7, 0],
           original_dim_sizes=[4, 3, 1],
           broadcast_dim_sizes=[4, 3, [2, 0, 1, 3, 8, 2, 3, 4, 1, 8, 7, 0]]),

      # shape: [UNIFORM, RAGGED, BROADCAST(UNIFORM)]
      dict(axis=2,
           row_length=[2, 5, 3],
           original_dim_sizes=[2, [2, 1], 1],
           broadcast_dim_sizes=[2, [2, 1], [2, 5, 3]]),

      # shape: [UNIFORM, RAGGED, UNIFORM, UNIFORM, BROADCAST(UNIFORM), UNIFORM]
      dict(axis=4,
           row_length=list(range(18)),
           original_dim_sizes=[2, [2, 1], 3, 2, 1, 8],
           broadcast_dim_sizes=[2, [2, 1], 3, 2, list(range(18)), 8]),

      #=========================================================================
      # dimension[axis] is uniform partitioned; and row_lengths is a scalar
      #=========================================================================
      # shape: [BROADCAST(UNIFORM), RAGGED]
      dict(axis=0,
           row_length=3,
           original_dim_sizes=[1, [5]],
           broadcast_dim_sizes=[3, [5, 5, 5]]),

      # shape: [BROADCAST(UNIFORM), UNIFORM, RAGGED]
      dict(axis=0,
           row_length=2,
           original_dim_sizes=[1, 3, [3, 0, 2]],
           broadcast_dim_sizes=[2, 3, [3, 0, 2, 3, 0, 2]]),

      # shape: [BROADCAST(UNIFORM), RAGGED, RAGGED, UNIFORM, UNIFORM, UNIFORM]
      dict(axis=0,
           row_length=3,
           original_dim_sizes=[1, [3], [3, 5, 2], 9, 4, 5],
           broadcast_dim_sizes=[3, [3, 3, 3], [3, 5, 2, 3, 5, 2, 3, 5, 2],
                                9, 4, 5]),

      # shape: [BROADCAST(UNIFORM), UNIFORM, RAGGED, RAGGED, UNIFORM]
      dict(axis=0,
           row_length=2,
           original_dim_sizes=[1, 2, [2, 1], [3, 5, 2], 2],
           broadcast_dim_sizes=[2, 2, [2, 1, 2, 1], [3, 5, 2, 3, 5, 2], 2]),

      # shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED, UNIFORM]
      dict(axis=1,
           row_length=2,
           original_dim_sizes=[3, 1, [4, 0, 2], 5],
           broadcast_dim_sizes=[3, 2, [4, 0, 2, 4, 0, 2], 5]),

      # shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED]
      dict(axis=1,
           row_length=1,
           original_dim_sizes=[2, 3, (1, 2, 3, 4, 5, 6)],
           broadcast_dim_sizes=[2, 3, (1, 2, 3, 4, 5, 6)]),

      #=========================================================================
      # dimension[axis] is uniform partitioned; and row_lengths is a vector
      #=========================================================================
      # shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED, UNIFORM]
      dict(axis=1,
           row_length=[4, 1, 2],
           original_dim_sizes=[
               3,                          # axis=0
               1,                          # axis=1 (broadcast)
               [3, 1, 2],                  # axis=2
               5],                         # axis=3
           broadcast_dim_sizes=[
               3,                          # axis=0
               [4, 1, 2],                  # axis=1 (broadcast)
               [3, 3, 3, 3, 1, 2, 2],      # axis=2
               5]),                        # axis=3

      # shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED, RAGGED]
      dict(axis=1,
           row_length=[2, 0, 3],
           original_dim_sizes=[
               3,                                         # axis=0
               1,                                         # axis=1 (broadcast)
               [3, 1, 2],                                 # axis=2
               [3, 1, 4, 1, 5, 9]],                       # axis=3
           broadcast_dim_sizes=[
               3,                                         # axis=0
               [2, 0, 3],                                 # axis=1 (broadcast)
               [3, 3, 2, 2, 2],                           # axis=2
               [3, 1, 4, 3, 1, 4, 5, 9, 5, 9, 5, 9]]),    # axis=3

      # shape: [UNIFORM, RAGGED, BROADCAST(UNIFORM), RAGGED, RAGGED, UNIFORM]
      dict(axis=2,
           row_length=[4, 1, 2],
           original_dim_sizes=[
               3,                                         # axis=0
               [2, 0, 1],                                 # axis=1
               1,                                         # axis=2 (broadcast)
               [3, 2, 1],                                 # axis=3
               [1, 0, 1, 0, 2, 3],                        # axis=4
               5],                                        # axis=5
           broadcast_dim_sizes=[
               3,                                         # axis=0
               [2, 0, 1],                                 # axis=1
               [4, 1, 2],                                 # axis=2 (broadcast)
               [3, 3, 3, 3, 2, 1, 1],                     # axis=3
               [1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0,    # axis=4
                2, 3, 3],
               5]),                                       # axis=5

      dict(axis=0,
           row_length=2,
           original_dim_sizes=[1, 1, 2, (2, 1)],
           broadcast_dim_sizes=[2, 1, 2, (2, 1, 2, 1)]),
      dict(axis=1,
           row_length=(2, 1),
           original_dim_sizes=[2, 1, 2, (2, 1, 2, 1)],
           broadcast_dim_sizes=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
      dict(axis=2,
           row_length=2,
           original_dim_sizes=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)],
           broadcast_dim_sizes=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
      dict(axis=3,
           row_length=(2, 1, 2, 1, 2, 1),
           original_dim_sizes=[2, (2, 1), 2, 1],
           broadcast_dim_sizes=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
  ])  # pyformat: disable
  def testBroadcastDimension(self, axis, row_length, original_dim_sizes,
                             broadcast_dim_sizes):
    """Tests the `broadcast_dimension` method.

    Verifies that each of the following equals the broadcast shape:

    * `original.broadcast_dimension(axis, row_length)`
    * `broadcast.broadcast_dimension(axis, row_length)`
    * `broadcast.broadcast_dimension(axis, 1)`

    Args:
      axis: The axis to broadcast.
      row_length: The slice lengths to broadcast to.
      original_dim_sizes: The dimension sizes before broadcasting.
        original_dim_sizes[axis] should be equal to `1` or `row_length`.
      broadcast_dim_sizes: The dimension sizes after broadcasting.
    """
    from_dim_sizes = ragged.RaggedTensorDynamicShape.from_dim_sizes
    original_shape = from_dim_sizes(original_dim_sizes)
    broadcast_shape = from_dim_sizes(broadcast_dim_sizes)
    self.assertEqual(original_shape.rank, broadcast_shape.rank)

    broadcast_results = [
        # shape[axis].value == 1 and row_length > 1:
        original_shape.broadcast_dimension(axis, row_length),
        # shape[axis].value > 1 and row_length == shape[axis].value:
        broadcast_shape.broadcast_dimension(axis, row_length),
        # shape[axis].value > 1 and row_length == 1:
        broadcast_shape.broadcast_dimension(axis, 1),
    ]
    for broadcast_result in broadcast_results:
      self.assertShapeEq(broadcast_result, broadcast_shape)

  @parameterized.parameters(
      [
          # Broadcast scalar
          dict(x_dims=[], y_dims=[], expected_dims=[]),
          dict(x_dims=[], y_dims=[2], expected_dims=[2]),
          dict(x_dims=[], y_dims=[2, 3], expected_dims=[2, 3]),
          dict(
              x_dims=[],
              y_dims=[2, (2, 3), (5, 7, 2, 0, 9)],
              expected_dims=[2, (2, 3), (5, 7, 2, 0, 9)]),
          # Broadcast vector
          dict(x_dims=[3], y_dims=[4, 2, 3], expected_dims=[4, 2, 3]),
          dict(x_dims=[1], y_dims=[4, 2, 3], expected_dims=[4, 2, 3]),
          dict(x_dims=[3], y_dims=[4, 2, 1], expected_dims=[4, 2, 3]),
          dict(
              x_dims=[3],
              y_dims=[3, (2, 3, 1), 1],
              expected_dims=[3, (2, 3, 1), 3]),
          dict(x_dims=[1], y_dims=[3, (2, 1, 3)], expected_dims=[3, (2, 1, 3)]),
          dict(
              x_dims=[1],
              y_dims=[3, (2, 1, 3), 8],
              expected_dims=[3, (2, 1, 3), 8]),
          dict(
              x_dims=[1],
              y_dims=[2, (2, 3), (5, 7, 2, 0, 9)],
              expected_dims=[2, (2, 3), (5, 7, 2, 0, 9)]),
          # Mixed broadcasting
          dict(
              x_dims=[
                  1,  # axis=0
                  3,  # axis=1
                  (3, 0, 2),  # axis=2
                  1,  # axis=3
                  2,  # axis=4
              ],
              y_dims=[
                  2,  # axis=0
                  1,  # axis=1
                  1,  # axis=2
                  (7, 2),  # axis=3
                  1,  # axis=4
              ],
              expected_dims=[
                  2,  # axis=0
                  3,  # axis=1
                  (3, 0, 2, 3, 0, 2),  # axis=2
                  (7, 7, 7, 7, 7, 2, 2, 2, 2, 2),  # axis=3
                  2,  # axis=4
              ]),
          dict(
              x_dims=[2, (2, 1), 2, 1],
              y_dims=[1, 1, 2, (2, 1)],
              expected_dims=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
      ])
  def testBroadcastDynamicShape(self, x_dims, y_dims, expected_dims):
    """`broadcast_dynamic_shape` gives the expected shape in both orders.

    Args:
      x_dims: Dimension sizes of the first shape.
      y_dims: Dimension sizes of the second shape.
      expected_dims: Dimension sizes of the expected broadcast shape.
    """
    from_dim_sizes = ragged.RaggedTensorDynamicShape.from_dim_sizes
    x_shape = from_dim_sizes(x_dims)
    y_shape = from_dim_sizes(y_dims)
    expected = from_dim_sizes(expected_dims)
    # Broadcast with the arguments in both orders before asserting on either.
    forward = ragged.broadcast_dynamic_shape(x_shape, y_shape)
    backward = ragged.broadcast_dynamic_shape(y_shape, x_shape)
    for broadcast_result in (forward, backward):
      self.assertShapeEq(expected, broadcast_result)

  def testRepr(self):
    """`repr()` of a dynamic shape matches the expected pattern."""
    shape = ragged.RaggedTensorDynamicShape.from_dim_sizes([2, (2, 1), 2, 1])
    # Two partitioned dim sizes followed by the inner dim sizes, each
    # rendered as an angle-bracketed object repr.
    expected_pattern = (r'RaggedTensorDynamicShape\('
                        r'partitioned_dim_sizes=\(<[^>]+>, <[^>]+>\), '
                        r'inner_dim_sizes=<[^>]+>\)')
    self.assertRegexpMatches(repr(shape), expected_pattern)

  @parameterized.parameters([
      dict(
          x=[[10], [20], [30]],  # shape=[3, 1]
          dim_sizes=[3, 2],
          expected=[[10, 10], [20, 20], [30, 30]]),
      dict(
          x=[[10], [20], [30]],  # shape=[3, 1]
          dim_sizes=[3, [3, 0, 2]],
          expected=ragged.constant_value([[10, 10, 10], [], [30, 30]],
                                         dtype=np.int32)),
      dict(
          x=[[[1, 2, 3]], [[4, 5, 6]]],  # shape = [2, 1, 3]
          dim_sizes=[2, [2, 3], 3],
          expected=ragged.constant_value(
              [[[1, 2, 3], [1, 2, 3]], [[4, 5, 6], [4, 5, 6], [4, 5, 6]]],
              dtype=np.int32,
              ragged_rank=1)),
      dict(
          x=[[[1]], [[2]]],  # shape = [2, 1, 1]
          dim_sizes=[2, [2, 3], [0, 2, 1, 2, 0]],
          expected=ragged.constant_value([[[], [1, 1]], [[2], [2, 2], []]],
                                         dtype=np.int32,
                                         ragged_rank=2)),
      dict(
          x=10,
          dim_sizes=[3, [3, 0, 2]],
          expected=ragged.constant_value([[10, 10, 10], [], [10, 10]])),
  ])
  def testRaggedBroadcastTo(self, x, dim_sizes, expected):
    """`ragged.broadcast_to(x, shape)` gives the expected value.

    Args:
      x: The value to broadcast.
      dim_sizes: Dimension sizes used to build the target shape.
      expected: The expected result of the broadcast.
    """
    target_shape = ragged.RaggedTensorDynamicShape.from_dim_sizes(dim_sizes)
    result = ragged.broadcast_to(x, target_shape)
    # Objects without a `ragged_rank` attribute count as ragged_rank 0.
    result_rrank = getattr(result, 'ragged_rank', 0)
    expected_rrank = getattr(expected, 'ragged_rank', 0)
    self.assertEqual(result_rrank, expected_rrank)
    self.assertRaggedEqual(result, expected)

  @parameterized.parameters([
      dict(
          doc='x.shape=[3, (D1)]; y.shape=[3, 1]; bcast.shape=[3, (D1)]',
          x=ragged.constant_value([[1, 2, 3], [], [4, 5]], dtype=np.int32),
          y=[[10], [20], [30]],
          expected=ragged.constant_value([[11, 12, 13], [], [34, 35]])),
      dict(
          doc='x.shape=[3, (D1)]; y.shape=[]; bcast.shape=[3, (D1)]',
          x=ragged.constant_value([[1, 2, 3], [], [4, 5]], dtype=np.int32),
          y=10,
          expected=ragged.constant_value([[11, 12, 13], [], [14, 15]])),
      dict(
          doc='x.shape=[1, (D1)]; y.shape=[3, 1]; bcast.shape=[3, (D1)]',
          x=ragged.constant_value([[1, 2, 3]], dtype=np.int32),
          y=[[10], [20], [30]],
          expected=ragged.constant_value(
              [[11, 12, 13], [21, 22, 23], [31, 32, 33]], dtype=np.int32)),
      dict(
          doc=('x.shape=[2, (D1), 1]; y.shape=[1, (D2)]; '
               'bcast.shape=[2, (D1), (D2)]'),
          x=ragged.constant_value([[[1], [2], [3]], [[4]]], ragged_rank=1),
          y=ragged.constant_value([[10, 20, 30]]),
          expected=ragged.constant_value([[[11, 21, 31], [12, 22, 32],
                                           [13, 23, 33]], [[14, 24, 34]]])),
      dict(
          doc=('x.shape=[2, (D1), 1]; y.shape=[1, 1, 4]; '
               'bcast.shape=[2, (D1), 4]'),
          x=ragged.constant_value([[[10], [20]], [[30]]], ragged_rank=1),
          y=[[[1, 2, 3, 4]]],
          expected=ragged.constant_value(
              [[[11, 12, 13, 14], [21, 22, 23, 24]], [[31, 32, 33, 34]]],
              ragged_rank=1)),
      dict(
          doc=('x.shape=[2, (D1), 2, 1]; y.shape=[2, (D2)]; '
               'bcast.shape=[2, (D1), (2), (D2)'),
          x=ragged.constant_value([[[[1], [2]], [[3], [4]]],
                                   [[[5], [6]]]],
                                  ragged_rank=1),
          y=ragged.constant_value([[10, 20], [30]]),
          expected=ragged.constant_value(
              [[[[11, 21], [32]], [[13, 23], [34]]],
               [[[15, 25], [36]]]])),
  ])
  def testRaggedAddWithBroadcasting(self, x, y, expected, doc):
    """`x + y` broadcasts its operands and gives the expected value.

    Args:
      x: First operand; converted to an int32 tensor or ragged tensor.
      y: Second operand; converted the same way.
      expected: The expected sum.
      doc: Description of the shapes involved; not used by the test body.
    """
    x = ragged.convert_to_tensor_or_ragged_tensor(x, dtype=dtypes.int32)
    y = ragged.convert_to_tensor_or_ragged_tensor(y, dtype=dtypes.int32)
    result = x + y

    # Objects without a `ragged_rank` attribute count as ragged_rank 0.
    self.assertEqual(
        getattr(expected, 'ragged_rank', 0), getattr(result, 'ragged_rank', 0))

    # Compare against a plain nested list when `expected` offers `tolist`.
    expected_value = (
        expected.tolist() if hasattr(expected, 'tolist') else expected)
    self.assertRaggedEqual(result, expected_value)
コード例 #13
0
class RaggedBatchGatherOpTest(test_util.TensorFlowTestCase,
                              parameterized.TestCase):

  @parameterized.parameters([
      #=========================================================================
      # Docstring Example
      #=========================================================================
      dict(
          descr='Docstring example',
          params=ragged.constant_value([['a', 'b', 'c'], ['d'], [], ['e']]),
          indices=ragged.constant_value([[1, 2, 0], [], [], [0, 0]]),
          expected=ragged.constant_value([[b'b', b'c', b'a'], [], [],
                                          [b'e', b'e']])),
      #=========================================================================
      # 0 Batch Dimensions
      #=========================================================================
      dict(
          descr='params: [P1], indices: [I], result: [I]',
          params=['a', 'b', 'c', 'd'],
          indices=[3, 2],
          expected=[b'd', b'c']),
      dict(
          descr='params: [P1, (P2)], indices: [I], result: [I, (P2)]',
          params=ragged.constant_value([['a', 'b'], [], ['c'], ['d', 'e']]),
          indices=[3, 2],
          expected=ragged.constant_value([[b'd', b'e'], [b'c']])),
      #=========================================================================
      # 1 Batch Dimension
      #=========================================================================
      dict(
          descr='params: [B1, P1], indices: [B1, I], result: [B1, I]',
          params=[['a', 'b', 'c'], ['d', 'e', 'f'], ['g', 'h', 'i']],
          indices=[[2, 0], [0, 1], [1, 0]],
          expected=[[b'c', b'a'], [b'd', b'e'], [b'h', b'g']]),
      dict(
          descr='params: [B1, (P1)], indices: [B1, I], result: [B1, I]',
          params=ragged.constant_value([['a', 'b', 'c'], ['d', 'e'], ['g']]),
          indices=[[2, 0], [0, 1], [0, 0]],
          expected=[[b'c', b'a'], [b'd', b'e'], [b'g', b'g']]),
      dict(
          descr='params: [B1, P1], indices: [B1, (I)], result: [B1, (I)]',
          params=[['a', 'b', 'c'], ['d', 'e', 'f'], ['g', 'h', 'i']],
          indices=ragged.constant_value([[2, 0, 2], [0], [1]]),
          expected=ragged.constant_value([[b'c', b'a', b'c'], [b'd'], [b'h']])),
      dict(
          descr=('params: [B1, (P1), (P2), P3], indices: [B1, I], '
                 'result: [B1, I, (P2), P3]'),
          params=ragged.constant_value(
              [[[['a']], [['b'], ['c']]], [[['d'], ['e']], [['f']]], [[['g']]]],
              ragged_rank=2),
          indices=[[1, 0], [0, 1], [0, 0]],
          expected=ragged.constant_value(
              [[[[b'b'], [b'c']], [[b'a']]], [[[b'd'], [b'e']], [[b'f']]],
               [[[b'g']], [[b'g']]]],
              ragged_rank=2)),
      #=========================================================================
      # 2 Batch Dimensions
      #=========================================================================
      dict(
          descr=('params: [B1, B2, P1], indices: [B1, B2, I], '
                 'result: [B1, B2, I]'),
          params=[[['a', 'b', 'c']], [['d', 'e', 'f']], [['g', 'h', 'i']]],
          indices=[[[2, 0]], [[0, 1]], [[1, 0]]],
          expected=[[[b'c', b'a']], [[b'd', b'e']], [[b'h', b'g']]]),
      dict(
          descr=('params: [B1, (B2), P1], indices: [B1, (B2), I], '
                 'result: [B1, (B2), I]'),
          params=ragged.constant_value(
              [[['a', 'b', 'c'], ['d', 'e', 'f']], [['g', 'h', 'i']]],
              ragged_rank=1),
          indices=ragged.constant_value([[[2, 0], [0, 1]], [[1, 0]]],
                                        ragged_rank=1),
          expected=ragged.constant_value(
              [[[b'c', b'a'], [b'd', b'e']], [[b'h', b'g']]], ragged_rank=1)),
      dict(
          descr=('params: [B1, (B2), (P1)], indices: [B1, (B2), I], '
                 'result: [B1, (B2), I]'),
          params=ragged.constant_value([[['a', 'b', 'c'], ['d']], [['e', 'f']]],
                                       ragged_rank=2),
          indices=ragged.constant_value([[[2, 0], [0, 0]], [[1, 0]]],
                                        ragged_rank=1),
          expected=ragged.constant_value(
              [[[b'c', b'a'], [b'd', b'd']], [[b'f', b'e']]], ragged_rank=1)),
      dict(
          descr=('params: [B1, (B2), P1], indices: [B1, (B2), (I)], '
                 'result: [B1, (B2), (I)]'),
          params=ragged.constant_value(
              [[['a', 'b', 'c'], ['d', 'e', 'f']], [['g', 'h', 'i']]],
              ragged_rank=1),
          indices=ragged.constant_value([[[2, 1, 0], [0]], [[1, 1]]],
                                        ragged_rank=2),
          expected=ragged.constant_value(
              [[[b'c', b'b', b'a'], [b'd']], [[b'h', b'h']]], ragged_rank=2)),
      #=========================================================================
      # 3 Batch Dimensions
      #=========================================================================
      dict(
          descr=(
              'params: [B1, (B2), (B3), (P1)], indices: [B1, (B2), (B3), I], '
              'result: [B1, (B2), (B3), I]'),
          params=ragged.constant_value(
              [[[['a', 'b', 'c'], ['d']], [['e', 'f']]]], ragged_rank=3),
          indices=ragged.constant_value([[[[2, 0], [0, 0]], [[1, 0]]]],
                                        ragged_rank=2),
          expected=ragged.constant_value(
              [[[[b'c', b'a'], [b'd', b'd']], [[b'f', b'e']]]], ragged_rank=2)),
  ])
  @test_util.run_deprecated_v1
  def testRaggedBatchGather(self, descr, params, indices, expected):
    """`ragged.batch_gather(params, indices)` gives the expected value.

    Args:
      descr: Description of the test case; not used by the test body.
      params: The values to gather from.
      indices: The indices to gather.
      expected: The expected result.
    """
    result = ragged.batch_gather(params, indices)
    # Objects without a `ragged_rank` attribute count as ragged_rank 0.
    result_rrank = getattr(result, 'ragged_rank', 0)
    expected_rrank = getattr(expected, 'ragged_rank', 0)
    self.assertEqual(result_rrank, expected_rrank)
    with self.test_session():
      # Compare against a plain nested list when `expected` offers `tolist`.
      expected_value = (
          expected.tolist() if hasattr(expected, 'tolist') else expected)
      self.assertEqual(result.eval().tolist(), expected_value)

  @test_util.run_deprecated_v1
  def testRaggedBatchGatherUnknownRankError(self):
    """`batch_gather` rejects indices built from a shape=None placeholder."""
    params = [['a', 'b'], ['c', 'd']]
    indices = array_ops.placeholder(dtypes.int32, shape=None)
    ragged_indices = ragged.from_row_splits(indices, [0, 2, 4])
    expected_message = 'batch_gather does not allow indices with unknown shape.'

    # Both the dense placeholder and the ragged tensor wrapping it must fail.
    for bad_indices in (indices, ragged_indices):
      with self.assertRaisesRegexp(ValueError, expected_message):
        ragged.batch_gather(params, bad_indices)

  @parameterized.parameters([
      dict(
          params=ragged.constant([['a'], ['b'], ['c']]),
          indices=ragged.constant([[0], [0]]),
          message='Dimensions 3 and 2 are not compatible'),
      dict(
          params=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
          indices=ragged.constant([[[0, 0], [0, 0, 0]], [[0]]]),
          message='batch shape from indices does not match params shape'),
      dict(
          params=ragged.constant([[[0, 0], [0, 0, 0]], [[0]]]),
          indices=ragged.constant([[[0, 0]], [[0, 0, 0]], [[0]]]),
          message='Dimensions must be equal, but are 3 and 4'),
      dict(
          params=ragged.constant([[[0, 0], [0, 0, 0]], [[0]], [[0]]]),
          indices=ragged.constant([[[0, 0]], [[0, 0, 0]], [[0]]]),
          error=errors.InvalidArgumentError,
          message='Condition x == y did not hold element-wise'),
      dict(
          params=ragged.constant(['a', 'b', 'c']),
          indices=ragged.constant([[0], [0]]),
          message='batch shape from indices does not match params shape'),
      dict(params=ragged.constant_value([['a']]),
           indices=0,
           message='indices.rank must be at least 1.'),
      dict(params=ragged.constant_value([['a']]),
           indices=[[[0]]],
           message='batch shape from indices does not match params shape'),
  ])
  @test_util.run_deprecated_v1
  def testRaggedBatchGatherStaticError(self,
                                       params,
                                       indices,
                                       message,
                                       error=ValueError):
    """`batch_gather` raises `error` matching `message` for bad inputs.

    Args:
      params: The values to gather from.
      indices: The indices to gather.
      message: Regular expression the raised error's message must match.
      error: The expected exception type.
    """
    expectation = self.assertRaisesRegexp(error, message)
    with expectation:
      ragged.batch_gather(params, indices)
コード例 #14
0
class RaggedGatherNdOpTest(ragged_test_util.RaggedTensorTestCase,
                           parameterized.TestCase):
  """Parameterized tests for `ragged.gather_nd`."""

  # Nested string list shared by the 'Docstring example' cases below.
  DOCSTRING_PARAMS = [[['000', '001'], ['010']],
                      [['100'], ['110', '111', '112'], ['120']],
                      [[], ['210']]]  # pyformat: disable

  @parameterized.parameters([
      #=========================================================================
      # Docstring Examples
      #=========================================================================
      dict(
          descr='Docstring example 1',
          params=ragged.constant_value(DOCSTRING_PARAMS),
          indices=[[2], [0]],
          expected=ragged.constant_value([[[], [b'210']],
                                          [[b'000', b'001'], [b'010']]])),
      dict(
          descr='Docstring example 2',
          params=ragged.constant_value(DOCSTRING_PARAMS),
          indices=[[2, 1], [0, 0]],
          expected=ragged.constant_value([[b'210'], [b'000', b'001']])),
      dict(
          descr='Docstring example 3',
          params=ragged.constant_value(DOCSTRING_PARAMS),
          indices=[[0, 0, 1], [1, 1, 2]],
          expected=[b'001', b'112']),
      #=========================================================================
      # Indices with 0 values (selects the entire params)
      #=========================================================================
      dict(
          descr='params: [B1, (B2)], indices: [0], result: [B1, (B2)]',
          params=ragged.constant_value([['a', 'b', 'c'], ['d']]),
          indices=np.zeros([0], dtype=np.int32),
          expected=ragged.constant_value([[b'a', b'b', b'c'], [b'd']])),
      dict(
          descr='params: [B1, (B2)], indices: [A1, 0], result: [A1, B1, (B2)]',
          params=ragged.constant_value([['a', 'b', 'c'], ['d']]),
          indices=np.zeros([3, 0], dtype=np.int32),
          expected=ragged.constant_value([[[b'a', b'b', b'c'], [b'd']],
                                          [[b'a', b'b', b'c'], [b'd']],
                                          [[b'a', b'b', b'c'], [b'd']]])),
      dict(
          descr=('params: [B1, (B2)], indices: [A1, A2, 0], '
                 'result: [A1, A2, B1, (B2)]'),
          params=ragged.constant_value([['a', 'b', 'c'], ['d']]),
          indices=np.zeros([1, 3, 0], dtype=np.int32),
          expected=ragged.constant_value([[[[b'a', b'b', b'c'], [b'd']],
                                           [[b'a', b'b', b'c'], [b'd']],
                                           [[b'a', b'b', b'c'], [b'd']]]])),
      dict(
          descr='params: [B1], indices: [A1, (A2), 0], result: [A1, (A2), B1]',
          params=['a'],
          indices=ragged.constant_value([[[], []], [[]]],
                                        ragged_rank=1,
                                        dtype=np.int32),
          expected=ragged.constant_value([[[b'a'], [b'a']], [[b'a']]],
                                         ragged_rank=1)),
      #=========================================================================
      # Indices with 1 value (selects row from params)
      #=========================================================================
      dict(
          descr='params: [B1, (B2)], indices: [A1, 1], result: [A1, (B2)]',
          params=ragged.constant_value([['a', 'b', 'c'], ['d']]),
          indices=[[1], [0]],
          expected=ragged.constant_value([[b'd'], [b'a', b'b', b'c']])),
      dict(
          descr=('params: [B1, (B2), (B3)], indices: [A1, 1], '
                 'result: [A1, (B2), (B3)]'),
          params=ragged.constant_value([[['a', 'b', 'c'], ['d']],
                                        [['e', 'f']]]),
          indices=[[1], [1]],
          expected=ragged.constant_value([[[b'e', b'f']], [[b'e', b'f']]])),
      dict(
          descr=('params: [B1, B2, B3], indices: [A1, (A2), 1], '
                 'result: [A1, (A2), B2, B3]'),
          params=[[['a']], [['b']]],
          indices=ragged.constant_value([[[0]]], ragged_rank=1),
          expected=ragged.constant_value([[[[b'a']]]], ragged_rank=1)),
      #=========================================================================
      # Indices with 2 values (selects row & col from params)
      #=========================================================================
      dict(
          descr='params: [B1, (B2)], indices: [A1, 2], result: [A1]',
          params=ragged.constant_value([['a', 'b', 'c'], ['d']]),
          indices=[[1, 0], [0, 0], [0, 2]],
          expected=ragged.constant_value([b'd', b'a', b'c'])),
      dict(
          descr=('params: [B1, (B2), (B3)], indices: [A1, 2], '
                 'result: [A1, (B3)]'),
          params=ragged.constant_value([[['a', 'b', 'c'], ['d']],
                                        [['e', 'f']]]),
          indices=[[1, 0], [0, 1], [0, 0]],
          expected=ragged.constant_value([[b'e', b'f'], [b'd'],
                                          [b'a', b'b', b'c']])),
      dict(
          descr=('params: [B1, (B2), (B3)], indices: [A1, A2, 2], '
                 'result: [A1, (A2), (B3)]'),
          params=ragged.constant_value([[['a', 'b', 'c'], ['d']],
                                        [['e', 'f']]]),
          indices=[[[1, 0], [0, 1], [0, 0]]],
          expected=ragged.constant_value([[[b'e', b'f'], [b'd'],
                                           [b'a', b'b', b'c']]])),
      dict(
          descr=('params: [B1, (B2), B3], indices: [A1, A2, 2], '
                 'result: [A1, A2, B3]'),
          params=ragged.constant_value([[['a', 'b'], ['c', 'd']],
                                        [['e', 'f']]],
                                       ragged_rank=1),
          indices=[[[1, 0], [0, 1], [0, 0]]],
          expected=[[[b'e', b'f'], [b'c', b'd'], [b'a', b'b']]]),
      dict(
          descr=('params: [B1, (B2), B3], indices: [A1, A2, A3, 2], '
                 'result: [A1, A2, A3, B3]'),
          params=ragged.constant_value([[['a', 'b'], ['c', 'd']],
                                        [['e', 'f']]],
                                       ragged_rank=1),
          indices=[[[[1, 0], [0, 1], [0, 0]]]],
          expected=[[[[b'e', b'f'], [b'c', b'd'], [b'a', b'b']]]]),
      dict(
          descr=('params: [B1, (B2), (B3)], indices: [A1, (A2), 2], '
                 'result: [A1, (A2), (B3)]'),
          params=ragged.constant_value([[['a', 'b', 'c'], ['d']],
                                        [['e', 'f']]]),
          indices=ragged.constant_value([[[1, 0], [0, 1]], [[0, 0]]],
                                        ragged_rank=1),
          expected=ragged.constant_value([[[b'e', b'f'], [b'd']],
                                          [[b'a', b'b', b'c']]])),
      #=========================================================================
      # Indices with 3 values
      #=========================================================================
      dict(
          descr=('params: [B1, (B2), (B3)], indices: [A1, 3], '
                 'result: [A1]'),
          params=ragged.constant_value([[['a', 'b', 'c'], ['d']],
                                        [['e', 'f']]]),
          indices=[[1, 0, 1], [0, 0, 0], [0, 1, 0]],
          expected=[b'f', b'a', b'd']),
      dict(
          descr=('params: [B1, (B2), B3], indices: [A1, 3], '
                 'result: [A1]'),
          params=ragged.constant_value([[['a', 'b'], ['c', 'd']],
                                        [['e', 'f']]],
                                       ragged_rank=1),
          indices=[[1, 0, 1], [0, 0, 0], [0, 1, 1]],
          expected=[b'f', b'a', b'd']),
      dict(
          descr=('params: [B1, (B2), (B3), B4], indices: [A1, 3], '
                 'result: [A1, B4]'),
          params=ragged.constant_value([[[['a', 'b'], ['c', 'd']],
                                         [['e', 'f']]]],
                                       ragged_rank=2),
          indices=[[0, 0, 1], [0, 0, 0], [0, 1, 0]],
          expected=[[b'c', b'd'], [b'a', b'b'], [b'e', b'f']]),
  ])  # pyformat: disable
  def testRaggedGatherNd(self, descr, params, indices, expected):
    """Checks `ragged.gather_nd(params, indices)` against `expected`.

    Args:
      descr: Human-readable label for the case; not used by the test body.
      params: Value passed as the `params` argument of `ragged.gather_nd`.
      indices: Value passed as the `indices` argument of `ragged.gather_nd`.
      expected: Value the result is compared with via `assertRaggedEqual`.
    """
    result = ragged.gather_nd(params, indices)
    self.assertRaggedEqual(result, expected)

  def testRaggedGatherNdUnknownRankError(self):
    """Checks the `ValueError`s raised for under-specified index shapes.

    Returns immediately when executing eagerly; the cases below are built
    from `array_ops.placeholder` (presumably graph-mode only -- confirm).
    """
    if context.executing_eagerly():
      return
    params = ragged.constant([['a', 'b'], ['c', 'd']])
    # indices1 is created with shape=None, indices2 with shape=[None].
    indices1 = array_ops.placeholder(dtypes.int32, shape=None)
    indices2 = array_ops.placeholder(dtypes.int32, shape=[None])

    # NOTE(review): the pattern below reads 'indices.rank be statically
    # known.' (no 'must'); presumably it mirrors the library's message --
    # confirm before rewording.
    with self.assertRaisesRegexp(ValueError,
                                 'indices.rank be statically known.'):
      ragged.gather_nd(params, indices1)
    with self.assertRaisesRegexp(
        ValueError, r'indices.shape\[-1\] must be statically known.'):
      ragged.gather_nd(params, indices2)

  @parameterized.parameters([
      dict(
          params=['a'],
          indices=0,
          error=(ValueError, errors.InvalidArgumentError)),
      dict(
          params=ragged.constant_value([['a']]),
          indices=0,
          message='indices.rank must be at least 1.'),
      dict(
          params=['a', 'b', 'c'],
          indices=ragged.constant_value([[0]]),
          message='The innermost dimension of indices may not be ragged'),
  ])
  def testRaggedGatherNdStaticError(self,
                                    params,
                                    indices,
                                    message=None,
                                    error=ValueError):
    """Checks that `ragged.gather_nd` rejects the given inputs.

    Args:
      params: Value passed as the `params` argument of `ragged.gather_nd`.
      indices: Value passed as the `indices` argument of `ragged.gather_nd`.
      message: Regular expression passed to `assertRaisesRegexp`. The first
        case above leaves it as None; presumably only the exception type is
        checked then -- confirm against the test framework.
      error: Exception type, or tuple of types, expected to be raised.
    """
    with self.assertRaisesRegexp(error, message):
      ragged.gather_nd(params, indices)
コード例 #15
0
class RaggedElementwiseOpsTest(ragged_test_util.RaggedTensorTestCase,
                               parameterized.TestCase):
    def assertSameShape(self, x, y):
        """Asserts that `x` and `y` have matching (possibly ragged) shapes.

        When `x` is a `RaggedTensor`, `y` must be one too, with the same
        ragged rank, equal nested row splits and equally shaped flat values.
        Otherwise `y` must be an `ops.Tensor` whose shape equals that of `x`.
        """
        if not isinstance(x, ragged.RaggedTensor):
            # Dense case: comparing the two shape tensors is enough.
            self.assertIsInstance(y, ops.Tensor)
            self.assertAllEqual(array_ops.shape(x), array_ops.shape(y))
            return

        self.assertIsInstance(y, ragged.RaggedTensor)
        self.assertEqual(x.ragged_rank, y.ragged_rank)
        split_pairs = zip(x.nested_row_splits, y.nested_row_splits)
        for lhs_splits, rhs_splits in split_pairs:
            self.assertAllEqual(lhs_splits, rhs_splits)
        lhs_flat_shape = array_ops.shape(x.flat_values)
        rhs_flat_shape = array_ops.shape(y.flat_values)
        self.assertAllEqual(lhs_flat_shape, rhs_flat_shape)

    @parameterized.parameters(
        #=========================================================================
        # Test different input shapes.
        #=========================================================================
        [
            # 0-dimensional input
            {
                'x': 12
            },
            # 1-dimensional input
            {
                'x': [1, -2, 3]
            },
            # 2-dimensional input
            {
                'x': [[-2, 3], [-3, 4]]
            },
            {
                'x': ragged.constant_value([[-2, 3], [-3]], ragged_rank=1)
            },
            # 3-dimensional inputs
            {
                'x': [[[-2, 3], [3, 4]], [[7, 6], [5, 4]]]
            },
            {
                'x':
                ragged.constant_value([[[-2, 3], [3, 4]], [[7, 6]]],
                                      ragged_rank=1)
            },
            {
                'x':
                ragged.constant_value([[[-2, 3, 4], []], [[7, 6]], []],
                                      ragged_rank=2)
            },
        ] +
        #=========================================================================
        # Test each unary op.
        #=========================================================================
        [{
            'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
            'op': op
        } for op in UNARY_FLOAT_OPS] + [{
            'x':
            ragged.constant_value([[True, False], [True]]),
            'op':
            op
        } for op in UNARY_BOOL_OPS] + [{
            'x':
            ragged.constant_value([[18, 512], [12412]], np.int32),
            'op':
            op
        } for op in UNARY_INT_OPS] + [{
            'x':
            ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
            'op':
            op
        } for op in UNARY_STRING_OPS] + [
            {
                'op': clip_ops.clip_by_value,
                'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
                'clip_value_min': 0.1,
                'clip_value_max': 4.0
            },
            {
                'op': math_ops.cast,
                'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
                'dtype': dtypes.int32
            },
            {
                'op': math_ops.saturate_cast,
                'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
                'dtype': dtypes.int32
            },
            {
                'op': string_ops.string_to_hash_bucket,
                'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
                'num_buckets': 1000
            },
            {
                'op': string_ops.string_to_hash_bucket_fast,
                'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
                'num_buckets': 1000
            },
            {
                'op': string_ops.string_to_hash_bucket_strong,
                'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
                'num_buckets': 1000,
                'key': [1231, 12512]
            },
            {
                'op': string_ops.string_to_number,
                'x': ragged.constant_value([['-2.0', '3.0'], ['-3.0']])
            },
            {
                'op': string_ops.regex_full_match,
                'x': ragged.constant_value([['hello', '123'], ['1+1']]),
                'pattern': r'\w+'
            },
            {
                'op': string_ops.regex_replace,
                'x': ragged.constant_value([['hello', '123'], ['1+1']]),
                'pattern': r'\d',
                'rewrite': '#'
            },
            {
                'op': string_ops.substr,
                'x': ragged.constant_value([['hello', '123'], ['1+1']]),
                'pos': 2,
                'len': 3
            },
            {
                'op': array_ops.check_numerics,
                'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
                'message': 'check-numerics'
            },
        ])  # pyformat: disable
    def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args):
        """Compares `op` applied to `x` with `op` applied to its flat values.

        Args:
          x: Input value; converted with
            `ragged.convert_to_tensor_or_ragged_tensor`.
          op: The unary operation under test.
          **extra_args: Extra keyword arguments forwarded to every `op` call.
        """
        x = ragged.convert_to_tensor_or_ragged_tensor(x)
        actual = op(x, **extra_args)

        # Reference result: apply the same op directly to the dense values.
        if isinstance(x, ragged.RaggedTensor):
            dense_input = x.flat_values
        else:
            dense_input = x
        dense_output = op(dense_input, **extra_args)
        expected_flat = array_ops.reshape(dense_output, [-1])

        # The op must leave the (ragged) shape of its input untouched.
        self.assertSameShape(x, actual)

        # Compare values after flattening both sides to one dimension.
        actual_values = (
            actual.flat_values
            if isinstance(actual, ragged.RaggedTensor) else actual)
        actual_flat = array_ops.reshape(actual_values, [-1])
        self.assertAllEqual(expected_flat, actual_flat)

    @parameterized.parameters(
        [
            #=====================================================================
            # Without broadcasting -- i.e., shapes match exactly.
            #=====================================================================
            # Shapes: x:(), y:()
            {
                'x': 12,
                'y': 8
            },
            # Shapes: x:(3,), y:(3,)
            {
                'x': [7, 8, 9],
                'y': [1, -2, 3]
            },
            # Shapes: x:(2, 2), y:(2, 2)
            {
                'x': [[-2, 3], [-3, -4]],
                'y': [[1, 2], [3, 4]]
            },
            # Shapes: x:(2, None), y:(2, None)
            {
                'x': ragged.constant_value([[-2, 3], [-3]]),
                'y': ragged.constant_value([[5, 6], [7]])
            },
            # Shapes: x:(2, 2, 2), y:(2, 2, 2)
            {
                'x': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                'y': [[[9, 3], [3, 4]], [[5, 2], [7, 6]]]
            },
            # Shapes: x:(2, None, None), y: (2, None, None)
            {
                'x': ragged.constant_value([[[1, 2], [3], [4]], [[], [5,
                                                                      7, 8]]]),
                'y': ragged.constant_value([[[3, 8], [2], [5]], [[], [1,
                                                                      9, 8]]])
            },
            # Shapes: x:(2, None, 2), y: (2, None, 2)
            {
                'x':
                ragged.constant_value([[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
                                      ragged_rank=1),
                'y':
                ragged.constant_value([[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
                                      ragged_rank=1)
            },

            #=====================================================================
            # With broadcasting
            #=====================================================================
            # Shapes: x:(), y:(3,)
            {
                'x': 12,  # Broadcast () -> (3,)
                'y': [1, -2, 3]
            },
            # Shapes: x:(1,), y:(3,)
            {
                'x': [12],  # Broadcast (1,) -> (3,)
                'y': [1, -2, 3]
            },
            # Shapes: x:(), y:(2, 2)
            {
                'x': 12,  # Broadcast () -> (2, 2)
                'y': [[1, 2], [3, 4]]
            },
            # Shapes: x:(1,), y:(2, 2)
            {
                'x': 12,  # Broadcast (1,) -> (2, 2)
                'y': [[1, 2], [3, 4]]
            },
            # Shapes: x:(2, 1), y:(2, 2)
            {
                'x': [[10], [20]],  # Broadcast (2, 1) -> (2, 2)
                'y': [[1, 2], [3, 4]]
            },
            # Shapes: x:(), y:(2, None)
            {
                'x': 10,  # Broadcast () -> (2, None)
                'y': ragged.constant_value([[1, 2], [3]], dtype=np.int32)
            },
            # TODO(edloper): Add tests for more advanced broadcasting, once we add
            # support for it.

            #=====================================================================
            # Keyword Args
            #=====================================================================
            {
                'x': ragged.constant_value([[[1, 2], [3], [4]], [[], [5,
                                                                      7, 8]]]),
                'y': ragged.constant_value([[[3, 8], [2], [5]], [[], [1,
                                                                      9, 8]]]),
                'use_kwargs': ('x', 'y')
            },
            {
                'x':
                ragged.constant_value([[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
                                      ragged_rank=1),
                'y':
                ragged.constant_value([[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
                                      ragged_rank=1),
                'use_kwargs': ('x', 'y')
            },
            {
                'x':
                ragged.constant_value([[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
                                      ragged_rank=1),
                'y':
                ragged.constant_value([[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
                                      ragged_rank=1),
                'use_kwargs': ('x', )
            },
        ] +
        #=========================================================================
        # Test each binary op.
        #=========================================================================
        [{
            'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
            'y': ragged.constant_value([[5.0, 1.0], [12.0]]),
            'op': op
        } for op in BINARY_FLOAT_OPS] + [
            {
                'x': ragged.constant_value([[-2, 3], [-3]]),
                'y': ragged.constant_value([[5, 1], [12]]),
                'op': op
            } for op in BINARY_INT_OPS
        ] + [{
            'x': ragged.constant_value([[True, True], [False]]),
            'y': ragged.constant_value([[False, True], [False]]),
            'op': op
        } for op in BINARY_BOOL_OPS])  # pyformat: disable
    def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args):
        """Compares `op(x, y)` with `op` applied to the operands' flat values.

        Args:
          x: First operand; converted with
            `ragged.convert_to_tensor_or_ragged_tensor`.
          y: Second operand; converted the same way.
          op: The binary operation under test.
          **extra_args: Extra keyword arguments forwarded to `op`. The
            optional `use_kwargs` entry is removed first; it names the
            operands to pass by keyword. `y` is passed by keyword when
            'y' is listed, and `x` as well when both are listed; any other
            value results in a fully positional call.
        """
        keyword_operands = extra_args.pop('use_kwargs', ())
        x = ragged.convert_to_tensor_or_ragged_tensor(x)
        y = ragged.convert_to_tensor_or_ragged_tensor(y)
        if 'y' not in keyword_operands:
            actual = op(x, y, **extra_args)
        elif 'x' in keyword_operands:
            actual = op(x=x, y=y, **extra_args)
        else:
            actual = op(x, y=y, **extra_args)

        # Reference result: apply the same op directly to the dense values.
        if isinstance(x, ragged.RaggedTensor):
            dense_lhs = x.flat_values
        else:
            dense_lhs = x
        if isinstance(y, ragged.RaggedTensor):
            dense_rhs = y.flat_values
        else:
            dense_rhs = y
        dense_output = op(dense_lhs, dense_rhs, **extra_args)
        expected_flat = array_ops.reshape(dense_output, [-1])

        # The result shape is checked against `y`; in the parameterized
        # broadcasting cases it is `x` that gets broadcast.
        self.assertSameShape(y, actual)

        # Compare values after flattening both sides to one dimension.
        actual_values = (
            actual.flat_values
            if isinstance(actual, ragged.RaggedTensor) else actual)
        actual_flat = array_ops.reshape(actual_values, [-1])
        self.assertAllEqual(expected_flat, actual_flat)

    @parameterized.parameters([
        {
            'inputs': (12, 8, 3)
        },
        {
            'inputs': ([1, 2, 3], [7, 8, 9], [3, 6, 9])
        },
        {
            'inputs': ([[1, 2]], [[3, 4]], [[5, 6]])
        },
        {
            'inputs': (ragged.constant_value(
                [[1, 3], [-3]]), ragged.constant_value(
                    [[4, 7], [88]]), ragged.constant_value([[2, 9], [12]]))
        },
        {
            'inputs': (ragged.constant_value([[[1, 3], [-3]], [[1]]]),
                       ragged.constant_value([[[4, 7], [88]], [[2]]]),
                       ragged.constant_value([[[2, 9], [12]], [[8]]]))
        },
        {
            'inputs': (ragged.constant_value([[[1, 3], [3, 4]], [[1, 5]]],
                                             ragged_rank=1),
                       ragged.constant_value([[[4, 7], [1, 2]], [[2, 2]]],
                                             ragged_rank=1),
                       ragged.constant_value([[[2, 9], [5, 2]], [[8, 0]]],
                                             ragged_rank=1))
        },
        {
            'inputs': (ragged.constant_value([[[1, 3], [-3]], [[1]]]),
                       ragged.constant_value([[[4, 7], [88]], [[2]]]),
                       ragged.constant_value([[[2, 9], [12]], [[8]]])),
            'use_kwargs':
            True
        },
    ] + [
        {
            'op':
            math_ops.add_n,
            'inputs': (ragged.constant_value(
                [[1, 3], [-3]]), ragged.constant_value(
                    [[4, 7], [88]]), ragged.constant_value([[2, 9], [12]]))
        },
        {
            'op':
            string_ops.string_join,
            'inputs': (ragged.constant_value([['a', 'b'], ['c']]),
                       ragged.constant_value([['foo', 'bar'], ['baz']]),
                       ragged.constant_value([['2', '9'], ['12']]))
        },
    ])  # pyformat: disable
    def testListValuedElementwiseOp(self, inputs, op=math_ops.add_n,
                                    **extra_args):
        """Compares `op(inputs)` with `op` applied to the inputs' flat values.

        Args:
          inputs: Sequence of operands; each one is converted with
            `ragged.convert_to_tensor_or_ragged_tensor`.
          op: The list-valued operation under test.
          **extra_args: Extra keyword arguments forwarded to `op`. The
            optional `use_kwargs` entry is removed first; when truthy the
            operand list is passed as the `inputs` keyword argument.
        """
        pass_by_keyword = extra_args.pop('use_kwargs', False)
        tensors = []
        for value in inputs:
            tensors.append(ragged.convert_to_tensor_or_ragged_tensor(value))
        if pass_by_keyword:
            actual = op(inputs=tensors, **extra_args)
        else:
            actual = op(tensors, **extra_args)

        # Reference result: apply the same op directly to the dense values.
        dense_tensors = []
        for tensor in tensors:
            if isinstance(tensor, ragged.RaggedTensor):
                tensor = tensor.flat_values
            dense_tensors.append(tensor)
        dense_output = op(dense_tensors, **extra_args)
        expected_flat = array_ops.reshape(dense_output, [-1])

        # The result shape is checked against the first operand.
        self.assertSameShape(tensors[0], actual)

        # Compare values after flattening both sides to one dimension.
        actual_values = (
            actual.flat_values
            if isinstance(actual, ragged.RaggedTensor) else actual)
        actual_flat = array_ops.reshape(actual_values, [-1])
        self.assertAllEqual(expected_flat, actual_flat)

    def testElementwiseOpUnknownRankError(self):
        """Checks that adding a ragged tensor of unknown rank raises.

        Does nothing when executing eagerly.
        """
        if not context.executing_eagerly():
            known_rank = ragged.constant([[1, 2], [3]])
            # shape=None leaves the static shape of the values unspecified.
            unknown_values = array_ops.placeholder_with_default(
                [1, 2, 3], shape=None)
            unknown_rank = ragged.RaggedTensor.from_row_splits(
                unknown_values, known_rank.row_splits)
            with self.assertRaisesRegexp(
                    ValueError, r'Unable to broadcast: unknown rank'):
                math_ops.add(known_rank, unknown_rank)

    @parameterized.parameters([
        dict(x=ragged.constant_value([[1, 2], [3]]),
             y=[[10]],
             expected=[[11, 12], [13]]),
        dict(x=ragged.constant_value([[[1, 2], [3, 4]], [[5]]], ragged_rank=2),
             y=ragged.constant_value([[[10], [20]], [[30]]], ragged_rank=1),
             expected=[[[11, 12], [23, 24]], [[35]]]),
        dict(x=ragged.constant_value([[[1]]]),
             y=ragged.constant_value([[1]]),
             expected=[[[2]]]),
    ])
    def testElementwiseOpBroadcast(self, x, y, expected):
        """Checks that `x + y` equals `expected`.

        Both operands are first converted to int32 with
        `ragged.convert_to_tensor_or_ragged_tensor`.
        """
        lhs, rhs = (
            ragged.convert_to_tensor_or_ragged_tensor(
                operand, dtype=dtypes.int32)
            for operand in (x, y))
        self.assertRaggedEqual(lhs + rhs, expected)

    def testElementwiseOpShapeMismatch(self):
        """Checks that adding ragged tensors with unequal row lengths fails.

        The second rows hold two and three values respectively; evaluating
        the sum must raise `errors.InvalidArgumentError`.
        """
        shorter = ragged.constant([[1, 2, 3], [4, 5]])
        longer = ragged.constant([[1, 2, 3], [4, 5, 6]])
        expect_failure = self.assertRaises(errors.InvalidArgumentError)
        with expect_failure:
            self.evaluate(math_ops.add(shorter, longer))

    def testBinaryOpSparseAndRagged(self):
        """Checks that mixing a ragged and a sparse operand raises an error.

        Both `math_ops.add` and `math_ops.add_n` are expected to raise
        `TypeError` or `ValueError` when evaluated.
        """
        x = ragged.constant([[1, 2, 3], [4, 5]])
        # SparseTensor built from positional (indices, values, dense shape).
        y = sparse_tensor.SparseTensor([[0, 0], [0, 1], [2, 0]], [1, 2, 3],
                                       [3, 2])
        with self.assertRaises((TypeError, ValueError)):
            self.evaluate(math_ops.add(x, y))

        with self.assertRaises((TypeError, ValueError)):
            self.evaluate(math_ops.add_n([x, y]))