Example #1
0
class HubModuleTokenizerTest(parameterized.TestCase, test.TestCase):
    """Tests for `hub_module_tokenizer.HubModuleTokenizer`.

    The segmenter is loaded from the hub module stored in the package test
    data. Expected offsets are byte offsets: every CJK character used below
    encodes to 3 UTF-8 bytes, hence the multiples of 3.
    """

    @parameterized.parameters([
        # Test scalar input.
        dict(text_input=_Utf8(u"新华社北京"),
             expected_tokens=[_Utf8(u"新华社"), _Utf8(u"北京")],
             expected_starts=[0, 9],
             expected_ends=[9, 15]),
        # Test rank 1 input.
        dict(text_input=[_Utf8(u"新华社北京"), _Utf8(u"中文测试")],
             expected_tokens=[[_Utf8(u"新华社"), _Utf8(u"北京")],
                              [_Utf8(u"中文"), _Utf8(u"测试")]],
             expected_starts=[[0, 9], [0, 6]],
             expected_ends=[[9, 15], [6, 12]]),
        # Test rank 2 ragged input.
        dict(text_input=ragged_factory_ops.constant_value(
            [[_Utf8(u"新华社北京"), _Utf8(u"中文测试")], [_Utf8(u"新华社上海")]]),
             expected_tokens=[[[_Utf8(u"新华社"), _Utf8(u"北京")],
                               [_Utf8(u"中文"), _Utf8(u"测试")]],
                              [[_Utf8(u"新华社"), _Utf8(u"上海")]]],
             expected_starts=[[[0, 9], [0, 6]], [[0, 9]]],
             expected_ends=[[[9, 15], [6, 12]], [[9, 15]]]),
        # Test rank 2 dense input.
        dict(text_input=ragged_factory_ops.constant_value(
            [[_Utf8(u"新华社北京"), _Utf8(u"中文测试")],
             [_Utf8(u"新华社上海"), _Utf8(u"英国交通")]]),
             expected_tokens=[[[_Utf8(u"新华社"), _Utf8(u"北京")],
                               [_Utf8(u"中文"), _Utf8(u"测试")]],
                              [[_Utf8(u"新华社"), _Utf8(u"上海")],
                               [_Utf8(u"英国"), _Utf8(u"交通")]]],
             expected_starts=[[[0, 9], [0, 6]], [[0, 9], [0, 6]]],
             expected_ends=[[[9, 15], [6, 12]], [[9, 15], [6, 12]]]),
        # Test ragged input with rank higher than 2.
        dict(text_input=ragged_factory_ops.constant_value(
            [[[_Utf8(u"新华社北京")], [_Utf8(u"中文测试")]],
             [[_Utf8(u"新华社上海")]]]),
             expected_tokens=[[[[_Utf8(u"新华社"), _Utf8(u"北京")]],
                               [[_Utf8(u"中文"), _Utf8(u"测试")]]],
                              [[[_Utf8(u"新华社"), _Utf8(u"上海")]]]],
             expected_starts=[[[[0, 9]], [[0, 6]]], [[[0, 9]]]],
             expected_ends=[[[[9, 15]], [[6, 12]]], [[[9, 15]]]])
    ])
    def testTokenize(self, text_input, expected_tokens, expected_starts,
                     expected_ends):
        """Checks `tokenize` and `tokenize_with_offsets` on `text_input`.

        Args:
          text_input: a scalar, list, or (possibly ragged) nested string input.
          expected_tokens: expected tokens, nested like `text_input` with one
            extra innermost dimension.
          expected_starts: expected start offset of each token.
          expected_ends: expected end offset of each token.
        """
        hub_module_handle = os.path.join(
            TF_TEXT_PACKAGE, "python/ops/test_data/segmenter_hub_module")
        segmenter = hub_module_tokenizer.HubModuleTokenizer(hub_module_handle)
        tokens, starts, ends = segmenter.tokenize_with_offsets(text_input)
        tokens_no_offset = segmenter.tokenize(text_input)
        # Initialize the module's lookup tables and variables before
        # evaluating any of the outputs.
        self.evaluate(lookup_ops.tables_initializer())
        self.evaluate(variables_lib.global_variables_initializer())

        # Arguments are passed as (expected, actual). That order gives more
        # readable assertAllEqual failure messages.
        # TODO(salcianu): use the same ordering elsewhere in this package.
        self.assertAllEqual(expected_tokens, tokens)
        self.assertAllEqual(expected_starts, starts)
        self.assertAllEqual(expected_ends, ends)
        self.assertAllEqual(expected_tokens, tokens_no_offset)
class StringsToBytesOpTest(ragged_test_util.RaggedTensorTestCase,
                           parameterized.TestCase):
    """Tests for `ragged_string_ops.string_bytes_split`."""

    @parameterized.parameters(
        # Scalar input -> vector output
        (b'hello', [b'h', b'e', b'l', b'l', b'o']),
        # Vector input -> 2D ragged output
        ([b'hello', b'123'],
         [[b'h', b'e', b'l', b'l', b'o'], [b'1', b'2', b'3']]),
        # 2D tensor input -> 3D ragged output
        ([[b'abc', b'de'], [b'fgh', b'']],
         [[[b'a', b'b', b'c'], [b'd', b'e']], [[b'f', b'g', b'h'], []]]),
        # 2D ragged input -> 3D ragged output
        (ragged_factory_ops.constant_value([[b'abc', b'de'], [b'f']]),
         [[[b'a', b'b', b'c'], [b'd', b'e']], [[b'f']]]),
        # 3D input -> 4D ragged output
        (ragged_factory_ops.constant_value(
            [[[b'big', b'small'], [b'red']], [[b'cat', b'dog'], [b'ox']]]),
         [[[[b'b', b'i', b'g'], [b's', b'm', b'a', b'l', b'l']],
           [[b'r', b'e', b'd']]],
          [[[b'c', b'a', b't'], [b'd', b'o', b'g']],
           [[b'o', b'x']]]]),
        # Empty string
        (b'', []),
        # Null byte
        (b'\x00', [b'\x00']),
        # Unicode: four characters, three UTF-8 bytes each
        (u'仅今年前'.encode('utf-8'),
         [b'\xe4', b'\xbb', b'\x85',
          b'\xe4', b'\xbb', b'\x8a',
          b'\xe5', b'\xb9', b'\xb4',
          b'\xe5', b'\x89', b'\x8d']),
    )
    def testStringToBytes(self, source, expected):
        """Checks that `source` splits into the single bytes in `expected`."""
        want = ragged_factory_ops.constant_value(expected, dtype=object)
        got = ragged_string_ops.string_bytes_split(source)
        self.assertRaggedEqual(want, got)
  def testFromTensorSlicesMixedRagged(self):
    """Slices dense, sparse and ragged components together along axis 0.

    The dataset must yield three elements, where element ``i`` holds row
    ``i`` of every component, and must then raise OutOfRangeError.
    """
    components = (np.tile(np.array([[1], [2], [3]]), 20),
                  np.tile(np.array([[12], [13], [14]]), 22),
                  np.array([37.0, 38.0, 39.0]),
                  sparse_tensor.SparseTensorValue(
                      indices=np.array([[0, 0], [1, 0], [2, 0]]),
                      values=np.array([0, 0, 0]),
                      dense_shape=np.array([3, 1])),
                  sparse_tensor.SparseTensorValue(
                      indices=np.array([[0, 0], [1, 1], [2, 2]]),
                      values=np.array([1, 2, 3]),
                      dense_shape=np.array([3, 3])),
                  ragged_factory_ops.constant_value([[[0]], [[1]], [[2]]]))

    dataset = dataset_ops.Dataset.from_tensor_slices(components)
    get_next = self.getNext(dataset)

    # Expected slices of the two sparse components and the ragged component,
    # one tuple per dataset element.
    expected = [
        (sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0]),
            dense_shape=np.array([1])),
         sparse_tensor.SparseTensorValue(
             indices=np.array([[0]]),
             values=np.array([1]),
             dense_shape=np.array([3])),
         ragged_factory_ops.constant_value([[0]])),
        (sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0]),
            dense_shape=np.array([1])),
         sparse_tensor.SparseTensorValue(
             indices=np.array([[1]]),
             values=np.array([2]),
             dense_shape=np.array([3])),
         ragged_factory_ops.constant_value([[1]])),
        (sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0]),
            dense_shape=np.array([1])),
         sparse_tensor.SparseTensorValue(
             indices=np.array([[2]]),
             values=np.array([3]),
             dense_shape=np.array([3])),
         ragged_factory_ops.constant_value([[2]])),
    ]
    # Row i of each of the three dense components. This is computed once,
    # and the element count follows the data instead of a hard-coded 3.
    dense_rows = list(zip(*components[:3]))
    for dense_row, expected_row in zip(dense_rows, expected):
      results = self.evaluate(get_next())
      for component, result_component in zip(dense_row + expected_row,
                                             results):
        if sparse_tensor.is_sparse(component):
          self.assertSparseValuesEqual(component, result_component)
        elif ragged_tensor.is_ragged(component):
          self.assertRaggedEqual(component, result_component)
        else:
          self.assertAllEqual(component, result_component)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
  def testFromTensorSlicesMixedRagged(self):
    """`from_tensor_slices` over dense, sparse and ragged components.

    Element ``i`` of the dataset must contain row ``i`` of every component.
    After three elements the iterator must be exhausted.
    """
    dense_a = np.tile(np.array([[1], [2], [3]]), 20)
    dense_b = np.tile(np.array([[12], [13], [14]]), 22)
    dense_c = np.array([37.0, 38.0, 39.0])
    sparse_a = sparse_tensor.SparseTensorValue(
        indices=np.array([[0, 0], [1, 0], [2, 0]]),
        values=np.array([0, 0, 0]),
        dense_shape=np.array([3, 1]))
    sparse_b = sparse_tensor.SparseTensorValue(
        indices=np.array([[0, 0], [1, 1], [2, 2]]),
        values=np.array([1, 2, 3]),
        dense_shape=np.array([3, 3]))
    ragged = ragged_factory_ops.constant_value([[[0]], [[1]], [[2]]])
    components = (dense_a, dense_b, dense_c, sparse_a, sparse_b, ragged)

    get_next = self.getNext(dataset_ops.Dataset.from_tensor_slices(components))

    # Row i of sparse_a is the value 0 at position 0. Row i of sparse_b is
    # the value i + 1 at position i. Row i of `ragged` is [[i]].
    expected = [
        (sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0]),
            dense_shape=np.array([1])),
         sparse_tensor.SparseTensorValue(
             indices=np.array([[i]]),
             values=np.array([i + 1]),
             dense_shape=np.array([3])),
         ragged_factory_ops.constant_value([[i]]))
        for i in range(3)
    ]
    dense_rows = list(zip(dense_a, dense_b, dense_c))

    def check(want, got):
      # Pick the comparison that matches the kind of expected value.
      if sparse_tensor.is_sparse(want):
        self.assertSparseValuesEqual(want, got)
      elif ragged_tensor.is_ragged(want):
        self.assertRaggedEqual(want, got)
      else:
        self.assertAllEqual(want, got)

    for i in range(3):
      results = self.evaluate(get_next())
      for want, got in zip(dense_rows[i] + expected[i], results):
        check(want, got)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
Example #5
0
    def testFromTensorsRagged(self):
        """`from_tensors` emits the tuple of ragged values as one element."""
        first = ragged_factory_ops.constant_value([[[0]], [[1]], [[2]]])
        second = ragged_factory_ops.constant_value([[[3]], [[4]], [[5]]])
        components = (first, second)

        self.assertDatasetProduces(
            dataset_ops.Dataset.from_tensors(components),
            expected_output=[components])
  def testFromTensorsRagged(self):
    """A tuple of ragged values round-trips through `from_tensors`."""
    # Builds [[[0]], [[1]], [[2]]] and [[[3]], [[4]], [[5]]].
    components = tuple(
        ragged_factory_ops.constant_value(
            [[[k]] for k in range(start, start + 3)])
        for start in (0, 3))

    dataset = dataset_ops.Dataset.from_tensors(components)
    self.assertDatasetProduces(dataset, expected_output=[components])
  def testSerializedContainingRaggedFeatureWithNoPartitions(self):
    """Parses `RaggedFeature`s that have no row partitions.

    An example that lacks a key, or stores an empty or valueless feature
    under it, contributes an empty row to that key's ragged output.
    """
    examples = [
        example(features=features({
            "rt_c": float_feature([3, 4, 5, 6, 7, 8]),
            "rt_f_values": float_feature([0, 1, 2, 3, 4]),
        })),
        example(features=features({
            "rt_c": float_feature([]),  # empty float list
        })),
        example(features=features({
            "rt_d": feature(),  # feature with nothing in it
        })),
        example(features=features({
            "rt_c": float_feature([1, 2, -1]),
            "rt_d": bytes_feature([b"hi"]),
            "rt_f_values": float_feature([0, 1, 2]),
        })),
    ]
    serialized = [ex.SerializeToString() for ex in examples]

    expected_output = {
        "rt_c": ragged_factory_ops.constant_value(
            [[3.0, 4.0, 5.0, 6.0, 7.0, 8.0], [], [], [1.0, 2.0, -1.0]],
            row_splits_dtype=dtypes.int32),
        "rt_d": ragged_factory_ops.constant_value(
            [[], [], [], [b"hi"]], row_splits_dtype=dtypes.int64),
        "rt_f": ragged_factory_ops.constant_value(
            [[0.0, 1.0, 2.0, 3.0, 4.0], [], [], [0.0, 1.0, 2.0]],
            row_splits_dtype=dtypes.int32),
    }

    serialized_tensor = ops.convert_to_tensor(serialized)
    feature_spec = {
        "rt_c": parsing_ops.RaggedFeature(dtypes.float32),
        "rt_d": parsing_ops.RaggedFeature(
            dtypes.string, row_splits_dtype=dtypes.int64),
        # Values for "rt_f" are read from the "rt_f_values" key.
        "rt_f": parsing_ops.RaggedFeature(
            dtypes.float32, value_key="rt_f_values"),
    }
    self._test(
        serialized_tensor,
        feature_spec,
        expected_values=expected_output,
        create_iterator_twice=True)
Example #8
0
 def testUnbatchDatasetWithRaggedTensor(self):
   """Unbatch, batch(5), batch(2), unbatch regroups ragged rows in fives."""
   rt = ragged_factory_ops.constant_value([[[i]] for i in range(10)])
   data = (
       dataset_ops.Dataset.from_tensors(rt)
       .unbatch()
       .batch(5)
       .batch(2)
       .unbatch())
   # Two elements, each holding five consecutive rows.
   expected_output = [
       ragged_factory_ops.constant_value([[[i]] for i in range(lo, lo + 5)])
       for lo in (0, 5)
   ]
   self.assertDatasetProduces(data, expected_output=expected_output)
Example #9
0
 def testUnbatchDatasetWithRaggedTensor(self):
   """`batching.unbatch` regroups ragged rows correctly after batching."""
   rt = ragged_factory_ops.constant_value([[[i]] for i in range(10)])
   data = (
       dataset_ops.Dataset.from_tensors(rt)
       .apply(batching.unbatch())
       .batch(5)
       .batch(2)
       .apply(batching.unbatch()))
   # Two elements, each holding five consecutive rows.
   expected_output = [
       ragged_factory_ops.constant_value([[[i]] for i in range(lo, lo + 5)])
       for lo in (0, 5)
   ]
   self.assertDatasetProduces(data, expected_output=expected_output)
Example #10
0
 def testUnbatchDatasetWithDenseSparseAndRaggedTensor(self):
   """Unbatching a mixed dense/sparse/ragged tuple yields one row per element."""
   num_rows = 10
   st = sparse_tensor.SparseTensorValue(
       indices=[[i, i] for i in range(num_rows)],
       values=list(range(num_rows)),
       dense_shape=[num_rows, num_rows])
   rt = ragged_factory_ops.constant_value([[[i]] for i in range(num_rows)])
   data = (
       dataset_ops.Dataset.from_tensors((list(range(num_rows)), st, rt))
       .unbatch()
       .batch(5)
       .unbatch())
   # Row i of the diagonal sparse matrix is the value i at column i.
   expected_output = [
       (i,
        sparse_tensor.SparseTensorValue([[i]], [i], [num_rows]),
        ragged_factory_ops.constant_value([[i]]))
       for i in range(num_rows)
   ]
   self.assertDatasetProduces(data, expected_output=expected_output)
Example #11
0
 def testUnbatchDatasetWithDenseSparseAndRaggedTensor(self):
   """`batching.unbatch` on a mixed dense/sparse/ragged tuple."""
   num_rows = 10
   st = sparse_tensor.SparseTensorValue(
       indices=[[i, i] for i in range(num_rows)],
       values=list(range(num_rows)),
       dense_shape=[num_rows, num_rows])
   rt = ragged_factory_ops.constant_value([[[i]] for i in range(num_rows)])
   data = (
       dataset_ops.Dataset.from_tensors((list(range(num_rows)), st, rt))
       .apply(batching.unbatch())
       .batch(5)
       .apply(batching.unbatch()))
   # Row i of the diagonal sparse matrix is the value i at column i.
   expected_output = [
       (i,
        sparse_tensor.SparseTensorValue([[i]], [i], [num_rows]),
        ragged_factory_ops.constant_value([[i]]))
       for i in range(num_rows)
   ]
   self.assertDatasetProduces(data, expected_output=expected_output)
Example #12
0
 def testSplitWithPaddedOutput(self, texts, expected, ragged_rank=None):
     """Splits UTF-8 strings into characters, padding short rows with ""."""
     encoded = _nested_encode(texts, "UTF-8")
     input_tensor = ragged_factory_ops.constant_value(
         encoded, ragged_rank=ragged_rank, dtype=bytes)
     split = ragged_string_ops.unicode_split(input_tensor, "UTF-8")
     padded = split.to_tensor(default_value="")
     self.assertAllEqual(np.array(expected, dtype=bytes), padded)
Example #13
0
 def testBasicSplit(self, texts, ragged_rank=None):
     """Splits UTF-8 strings into per-character substrings."""
     encoding = "UTF-8"
     input_tensor = ragged_factory_ops.constant_value(
         _nested_encode(texts, encoding), ragged_rank=ragged_rank, dtype=bytes)
     actual = ragged_string_ops.unicode_split(input_tensor, encoding)
     self.assertAllEqual(_nested_splitchars(texts, encoding), actual)
Example #14
0
 def testBasicSplitWithOffsets(self, texts, ragged_rank=None):
   """Checks characters and offsets from `unicode_split_with_offsets`."""
   encoded = _nested_encode(texts, "UTF-8")
   input_tensor = ragged_factory_ops.constant_value(
       encoded, ragged_rank=ragged_rank, dtype=bytes)
   result = ragged_string_ops.unicode_split_with_offsets(input_tensor, "UTF-8")
   codepoints, offsets = result[0], result[1]
   want_codepoints = _nested_splitchars(texts, "UTF-8")
   want_offsets = _nested_offsets(texts, "UTF-8")
   self.assertAllEqual(want_codepoints, codepoints)
   self.assertAllEqual(want_offsets, offsets)
 def testBasicSplitWithOffsets(self, texts, ragged_rank=None):
   """Ragged comparison of characters and offsets from the split op."""
   encoding = "UTF-8"
   input_tensor = ragged_factory_ops.constant_value(
       _nested_encode(texts, encoding), ragged_rank=ragged_rank, dtype=bytes)
   result = ragged_string_ops.unicode_split_with_offsets(input_tensor, encoding)
   codepoints, offsets = result[0], result[1]
   want_codepoints = _nested_splitchars(texts, encoding)
   want_offsets = _nested_offsets(texts, encoding)
   self.assertRaggedEqual(want_codepoints, codepoints)
   self.assertRaggedEqual(want_offsets, offsets)
class StringsToBytesOpTest(test_util.TensorFlowTestCase,
                           parameterized.TestCase):
  """Tests for `ragged_string_ops.string_bytes_split`."""

  @parameterized.parameters(
      # Scalar input -> vector output
      (b'hello', [b'h', b'e', b'l', b'l', b'o']),
      # Vector input -> 2D ragged output
      ([b'hello', b'123'],
       [[b'h', b'e', b'l', b'l', b'o'], [b'1', b'2', b'3']]),
      # 2D tensor input -> 3D ragged output
      ([[b'abc', b'de'], [b'fgh', b'']],
       [[[b'a', b'b', b'c'], [b'd', b'e']], [[b'f', b'g', b'h'], []]]),
      # 2D ragged input -> 3D ragged output
      (ragged_factory_ops.constant_value([[b'abc', b'de'], [b'f']]),
       [[[b'a', b'b', b'c'], [b'd', b'e']], [[b'f']]]),
      # 3D input -> 4D ragged output
      (ragged_factory_ops.constant_value(
          [[[b'big', b'small'], [b'red']], [[b'cat', b'dog'], [b'ox']]]),
       [[[[b'b', b'i', b'g'], [b's', b'm', b'a', b'l', b'l']],
         [[b'r', b'e', b'd']]],
        [[[b'c', b'a', b't'], [b'd', b'o', b'g']],
         [[b'o', b'x']]]]),
      # Empty string
      (b'', []),
      # Null byte
      (b'\x00', [b'\x00']),
      # Unicode
      (u'仅今年前'.encode('utf-8'),
       [b'\xe4', b'\xbb', b'\x85', b'\xe4', b'\xbb', b'\x8a', b'\xe5',
        b'\xb9', b'\xb4', b'\xe5', b'\x89', b'\x8d']),
  )
  def testStringToBytes(self, source, expected):
    """Checks that `source` splits into the single bytes in `expected`."""
    expected = ragged_factory_ops.constant_value(expected, dtype=object)
    result = ragged_string_ops.string_bytes_split(source)
    self.assertAllEqual(expected, result)

  def testUnknownInputRankError(self):
    """Inputs whose rank is not statically known raise ValueError."""
    # Use a tf.function that erases shape information.
    @def_function.function(input_signature=[tensor_spec.TensorSpec(None)])
    def f(v):
      return ragged_string_ops.string_bytes_split(v)

    # Use assertRaisesRegex. The deprecated assertRaisesRegexp alias was
    # removed from unittest in Python 3.12.
    with self.assertRaisesRegex(ValueError,
                                'input must have a statically-known rank'):
      f(['foo'])
Example #17
0
 def testFlattenRaggedValue(self):
   """`nest.flatten` treats a ragged value as a leaf. Lists stay whole."""
   rt = ragged_factory_ops.constant_value([[[0]], [[1]]])
   # (structure, expected flattening) pairs, checked in order.
   cases = [
       (rt, [rt]),
       # A list is not unpacked. It comes back as a single leaf.
       ([rt, rt, rt], [[rt, rt, rt]]),
       ((rt, (rt, rt)), [rt, rt, rt]),
       ({"foo": rt, "bar": rt, "baz": rt}, [rt, rt, rt]),
   ]
   for structure, expected_flat in cases:
     self.assertEqual(expected_flat, nest.flatten(structure))
Example #18
0
File: nest_test.py  Project: MFChunga/poo
 def testFlattenRaggedValue(self):
     """Checks how `nest.flatten` handles ragged values in various containers."""
     rt = ragged_factory_ops.constant_value([[[0]], [[1]]])
     triple = [rt, rt, rt]
     flatten = nest.flatten
     self.assertEqual([rt], flatten(rt))
     # A list is not unpacked. It is returned as one leaf.
     self.assertEqual([triple], flatten([rt, rt, rt]))
     self.assertEqual(triple, flatten((rt, (rt, rt))))
     self.assertEqual(triple, flatten({"foo": rt, "bar": rt, "baz": rt}))
    def testRaggedValues(self,
                         pylist,
                         dtype=None,
                         ragged_rank=None,
                         inner_shape=None,
                         expected_shape=None,
                         expected_dtype=None):
        """Tests that `ragged_value(pylist).to_list() == pylist`.

        The optional arguments add checks on the resulting value's dtype,
        ragged rank, inner (flat_values) shape and overall shape.
        """
        rt = ragged_factory_ops.constant_value(
            pylist, dtype=dtype, ragged_rank=ragged_rank,
            inner_shape=inner_shape)
        # Convert np.arrays inside pylist to lists, e.g.
        # [np.array((1, 2))] --> [[1, 2]].
        pylist = self._normalize_pylist(pylist)
        is_ragged_value = isinstance(rt, ragged_tensor_value.RaggedTensorValue)

        # The dtype must match both the explicit dtype and the expected dtype.
        for wanted_dtype in (dtype, expected_dtype):
            if wanted_dtype is not None:
                self.assertEqual(rt.dtype, wanted_dtype)

        # A result that is not a RaggedTensorValue must have ragged_rank 0.
        if ragged_rank is not None:
            if is_ragged_value:
                self.assertEqual(rt.ragged_rank, ragged_rank)
            else:
                self.assertEqual(0, ragged_rank)

        if inner_shape is not None:
            actual_inner_shape = (rt.flat_values.shape[1:] if is_ragged_value
                                  else rt.shape)
            self.assertEqual(actual_inner_shape, inner_shape)

        if expected_shape is not None:
            self.assertEqual(tuple(rt.shape), expected_shape)

        if not rt.shape:
            # Zero-dimensional result: compare the value itself.
            self.assertEqual(rt, pylist)
            if expected_shape is not None:
                self.assertEqual((), expected_shape)
            return

        as_list = rt.to_list() if is_ragged_value else rt.tolist()
        self.assertEqual(as_list, pylist)
        if expected_shape is not None:
            self.assertEqual(rt.shape, expected_shape)
Example #20
0
 def testIsNested(self):
   """Checks which values `nest.is_nested` reports as nested structures.

   Tuples and dicts are nested. Strings, lists, sets, tensors, numpy arrays,
   sparse values and ragged values are not.
   """
   ones = array_ops.ones([2, 3])
   # (value, expected result) pairs, checked in this order.
   cases = [
       ("1234", False),
       ([1, 3, [4, 5]], False),
       (((7, 8), (5, 6)), True),
       ([], False),
       (set([1, 2]), False),
       (ones, False),
       (math_ops.tanh(ones), False),
       (np.ones((4, 5)), False),
       ({"foo": 1, "bar": 2}, True),
       (sparse_tensor.SparseTensorValue([[0]], [0], [1]), False),
       (ragged_factory_ops.constant_value([[[0]], [[1]]]), False),
   ]
   for value, expected in cases:
     check = self.assertTrue if expected else self.assertFalse
     check(nest.is_nested(value))
Example #21
0
    def testFromTensorsMixedRagged(self):
        """`from_tensors` keeps dense, sparse and ragged components intact."""
        dense = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
        sparse = (
            sparse_tensor.SparseTensorValue(
                indices=np.array([[0]]),
                values=np.array([0]),
                dense_shape=np.array([1])),
            sparse_tensor.SparseTensorValue(
                indices=np.array([[0, 0], [1, 1]]),
                values=np.array([-1, 1]),
                dense_shape=np.array([2, 2])),
        )
        ragged = ragged_factory_ops.constant_value([[[0]], [[1]], [[2]]])
        components = dense + sparse + (ragged,)

        dataset = dataset_ops.Dataset.from_tensors(components)
        self.assertDatasetProduces(dataset, expected_output=[components])
Example #22
0
 def testIsSequence(self):
   """Checks which values `nest.is_sequence` reports as sequences.

   Tuples and dicts count. Strings, lists, sets, tensors, numpy arrays,
   sparse values and ragged values do not.
   """
   ones = array_ops.ones([2, 3])
   # (value, expected result) pairs, checked in this order.
   cases = [
       ("1234", False),
       ([1, 3, [4, 5]], False),
       (((7, 8), (5, 6)), True),
       ([], False),
       (set([1, 2]), False),
       (ones, False),
       (math_ops.tanh(ones), False),
       (np.ones((4, 5)), False),
       ({"foo": 1, "bar": 2}, True),
       (sparse_tensor.SparseTensorValue([[0]], [0], [1]), False),
       (ragged_factory_ops.constant_value([[[0]], [[1]]]), False),
   ]
   for value, expected in cases:
     check = self.assertTrue if expected else self.assertFalse
     check(nest.is_sequence(value))
Example #23
0
  def testFromTensorsMixedRagged(self):
    """A mixed dense/sparse/ragged tuple round-trips through `from_tensors`."""
    scalar_int = np.array(1)
    vector = np.array([1, 2, 3])
    scalar_float = np.array(37.0)
    sparse_vector = sparse_tensor.SparseTensorValue(
        indices=np.array([[0]]),
        values=np.array([0]),
        dense_shape=np.array([1]))
    sparse_matrix = sparse_tensor.SparseTensorValue(
        indices=np.array([[0, 0], [1, 1]]),
        values=np.array([-1, 1]),
        dense_shape=np.array([2, 2]))
    ragged = ragged_factory_ops.constant_value([[[0]], [[1]], [[2]]])
    components = (scalar_int, vector, scalar_float, sparse_vector,
                  sparse_matrix, ragged)

    self.assertDatasetProduces(
        dataset_ops.Dataset.from_tensors(components),
        expected_output=[components])
  def testRaggedValues(self,
                       pylist,
                       dtype=None,
                       ragged_rank=None,
                       inner_shape=None,
                       expected_shape=None,
                       expected_dtype=None):
    """Tests that `ragged_value(pylist).to_list() == pylist`.

    The optional arguments add checks on dtype, ragged rank, inner
    (flat_values) shape and overall shape of the constructed value.
    """
    rt = ragged_factory_ops.constant_value(
        pylist, dtype=dtype, ragged_rank=ragged_rank, inner_shape=inner_shape)
    # Convert np.arrays inside pylist to lists, e.g.
    # [np.array((1, 2))] --> [[1, 2]].
    pylist = self._normalize_pylist(pylist)
    is_ragged_value = isinstance(rt, ragged_tensor_value.RaggedTensorValue)

    # The dtype must match both the explicit dtype and the expected dtype.
    for wanted_dtype in (dtype, expected_dtype):
      if wanted_dtype is not None:
        self.assertEqual(rt.dtype, wanted_dtype)

    # A result that is not a RaggedTensorValue must have ragged_rank 0.
    if ragged_rank is not None:
      if is_ragged_value:
        self.assertEqual(rt.ragged_rank, ragged_rank)
      else:
        self.assertEqual(0, ragged_rank)

    if inner_shape is not None:
      actual_inner_shape = (rt.flat_values.shape[1:] if is_ragged_value
                            else rt.shape)
      self.assertEqual(actual_inner_shape, inner_shape)

    if expected_shape is not None:
      self.assertEqual(tuple(rt.shape), expected_shape)

    if not rt.shape:
      # Zero-dimensional result: compare the value itself.
      self.assertEqual(rt, pylist)
      if expected_shape is not None:
        self.assertEqual((), expected_shape)
      return

    as_list = rt.to_list() if is_ragged_value else rt.tolist()
    self.assertEqual(as_list, pylist)
    if expected_shape is not None:
      self.assertEqual(rt.shape, expected_shape)
    def testRaggedValues(self,
                         pylist,
                         dtype=None,
                         ragged_rank=None,
                         inner_shape=None,
                         expected_shape=None,
                         expected_dtype=None):
        """Tests that `ragged_value(pylist).to_list() == pylist`.

        The optional arguments add checks on dtype, ragged rank, inner
        (flat_values) shape and overall shape of the constructed value.
        """
        rt = ragged_factory_ops.constant_value(
            pylist, dtype=dtype, ragged_rank=ragged_rank,
            inner_shape=inner_shape)
        is_ragged_value = isinstance(rt, ragged_tensor_value.RaggedTensorValue)

        # The dtype must match both the explicit dtype and the expected dtype.
        for wanted_dtype in (dtype, expected_dtype):
            if wanted_dtype is not None:
                self.assertEqual(rt.dtype, wanted_dtype)

        # A result that is not a RaggedTensorValue must have ragged_rank 0.
        if ragged_rank is not None:
            if is_ragged_value:
                self.assertEqual(rt.ragged_rank, ragged_rank)
            else:
                self.assertEqual(0, ragged_rank)

        if inner_shape is not None:
            actual_inner_shape = (rt.flat_values.shape[1:] if is_ragged_value
                                  else rt.shape)
            self.assertEqual(actual_inner_shape, inner_shape)

        if expected_shape is not None:
            self.assertEqual(tuple(rt.shape), expected_shape)

        if not rt.shape:
            # Zero-dimensional result: compare the value itself.
            self.assertEqual(rt, pylist)
            if expected_shape is not None:
                self.assertEqual((), expected_shape)
            return

        as_list = rt.to_list() if is_ragged_value else rt.tolist()
        self.assertEqual(as_list, pylist)
        if expected_shape is not None:
            self.assertEqual(rt.shape, expected_shape)
  def testRaggedValues(self,
                       pylist,
                       dtype=None,
                       ragged_rank=None,
                       inner_shape=None,
                       expected_shape=None,
                       expected_dtype=None):
    """Tests that `ragged_value(pylist).to_list() == pylist`.

    The optional arguments add checks on dtype, ragged rank, inner
    (flat_values) shape and overall shape of the constructed value.
    """
    rt = ragged_factory_ops.constant_value(
        pylist, dtype=dtype, ragged_rank=ragged_rank, inner_shape=inner_shape)
    is_ragged_value = isinstance(rt, ragged_tensor_value.RaggedTensorValue)

    # The dtype must match both the explicit dtype and the expected dtype.
    for wanted_dtype in (dtype, expected_dtype):
      if wanted_dtype is not None:
        self.assertEqual(rt.dtype, wanted_dtype)

    # A result that is not a RaggedTensorValue must have ragged_rank 0.
    if ragged_rank is not None:
      if is_ragged_value:
        self.assertEqual(rt.ragged_rank, ragged_rank)
      else:
        self.assertEqual(0, ragged_rank)

    if inner_shape is not None:
      actual_inner_shape = (rt.flat_values.shape[1:] if is_ragged_value
                            else rt.shape)
      self.assertEqual(actual_inner_shape, inner_shape)

    if expected_shape is not None:
      self.assertEqual(tuple(rt.shape), expected_shape)

    if not rt.shape:
      # Zero-dimensional result: compare the value itself.
      self.assertEqual(rt, pylist)
      if expected_shape is not None:
        self.assertEqual((), expected_shape)
      return

    as_list = rt.to_list() if is_ragged_value else rt.tolist()
    self.assertEqual(as_list, pylist)
    if expected_shape is not None:
      self.assertEqual(rt.shape, expected_shape)
Example #27
0
    def testPruneRagged(self):
        """Prunes a wrapped function down to its ragged `x -> x * x` path."""
        traced_inputs = []
        traced_outputs = []

        def square_both(x, y):
            # Record the traced ragged input and output so they can be used
            # as prune endpoints.
            traced_inputs.append(x)
            squared = x * x
            traced_outputs.append(squared)
            return squared, y * y

        signature = [
            ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32),
            tensor_spec.TensorSpec((), dtypes.float32),
        ]
        wrapped = wrap_function.wrap_function(square_both, signature)
        pruned = wrapped.prune(traced_inputs[0], traced_outputs[0])

        rt = ragged_factory_ops.constant([[1.0, 2.0], [3.0]])
        expected = ragged_factory_ops.constant_value([[1.0, 4.0], [9.0]])

        # Concrete functions currently receive structured (ragged) inputs as
        # their component tensors, so values and row_splits go in separately.
        self.assertAllEqual(pruned(rt.values, rt.row_splits), expected)
 def testBasicSplit(self, texts, ragged_rank=None):
   """Ragged comparison of per-character UTF-8 splits."""
   encoding = "UTF-8"
   input_tensor = ragged_factory_ops.constant_value(
       _nested_encode(texts, encoding), ragged_rank=ragged_rank, dtype=bytes)
   actual = ragged_string_ops.unicode_split(input_tensor, encoding)
   self.assertRaggedEqual(_nested_splitchars(texts, encoding), actual)
class RaggedDispatchTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests that TensorFlow ops dispatch to ragged implementations.

  Elementwise ops are checked by comparing the ragged result against the same
  op applied to the dense `flat_values`. Other ops are checked against
  hand-written expected values.
  """

  def assertSameShape(self, x, y):
    """Checks that x and y have the same shape (including ragged shapes).

    If `x` is ragged, then `y` must also be ragged, with the same ragged_rank,
    the same nested row splits, and flat_values of the same shape. Otherwise
    `y` must be an `ops.Tensor` with the same shape as `x`.
    """
    if ragged_tensor.is_ragged(x):
      self.assertTrue(ragged_tensor.is_ragged(y))
      self.assertEqual(x.ragged_rank, y.ragged_rank)
      for (x_splits, y_splits) in zip(x.nested_row_splits, y.nested_row_splits):
        self.assertAllEqual(x_splits, y_splits)
      self.assertAllEqual(
          array_ops.shape(x.flat_values), array_ops.shape(y.flat_values))
    else:
      self.assertIsInstance(y, ops.Tensor)
      self.assertAllEqual(array_ops.shape(x), array_ops.shape(y))

  @parameterized.parameters(
      #=========================================================================
      # Test different input shapes.
      #=========================================================================
      [
          # 0-dimensional input
          {'x': 12},
          # 1-dimensional input
          {'x': [1, -2, 3]},
          # 2-dimensional input
          {'x': [[-2, 3], [-3, 4]]},
          {'x': ragged_factory_ops.constant_value(
              [[-2, 3], [-3]], ragged_rank=1)},
          # 3-dimensional inputs
          {'x': [[[-2, 3], [3, 4]], [[7, 6], [5, 4]]]},
          {'x': ragged_factory_ops.constant_value(
              [[[-2, 3], [3, 4]], [[7, 6]]],
              ragged_rank=1)},
          {'x': ragged_factory_ops.constant_value(
              [[[-2, 3, 4], []], [[7, 6]], []],
              ragged_rank=2)},
          ] +
      #=========================================================================
      # Test each unary op.
      #=========================================================================
      [{'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]), 'op': op}
       for op in test_ops.UNARY_FLOAT_OPS] +
      [{'x': ragged_factory_ops.constant_value([[True, False], [True]]),
        'op': op}
       for op in test_ops.UNARY_BOOL_OPS] +
      [{'x': ragged_factory_ops.constant_value([[18, 512], [12412]], np.int32),
        'op': op}
       for op in test_ops.UNARY_INT_OPS] +
      [{'x': ragged_factory_ops.constant_value([['abcd', 'efgh'],
                                                ['aabbccdd']]),
        'op': op}
       for op in test_ops.UNARY_STRING_OPS] +
      [
          {'op': clip_ops.clip_by_value,
           'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
           'clip_value_min': 0.1, 'clip_value_max': 4.0},
          {'op': math_ops.cast,
           'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
           'dtype': dtypes.int32},
          {'op': math_ops.saturate_cast,
           'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
           'dtype': dtypes.int32},
          {'op': string_ops.string_to_hash_bucket,
           'x': ragged_factory_ops.constant_value(
               [['abcd', 'efgh'], ['aabbccdd']]),
           'num_buckets': 1000},
          {'op': string_ops.string_to_hash_bucket_fast,
           'x': ragged_factory_ops.constant_value(
               [['abcd', 'efgh'], ['aabbccdd']]),
           'num_buckets': 1000},
          {'op': string_ops.string_to_hash_bucket_strong,
           'x': ragged_factory_ops.constant_value(
               [['abcd', 'efgh'], ['aabbccdd']]),
           'num_buckets': 1000,
           'key': [1231, 12512]},
          {'op': string_ops.string_to_number,
           'x': ragged_factory_ops.constant_value([['-2.0', '3.0'], ['-3.0']])},
          {'op': string_ops.regex_full_match,
           'x': ragged_factory_ops.constant_value([['hello', '123'], ['1+1']]),
           'pattern': r'\w+'},
          {'op': string_ops.regex_replace,
           'x': ragged_factory_ops.constant_value([['hello', '123'], ['1+1']]),
           'pattern': r'\d',
           'rewrite': '#'},
          {'op': string_ops.substr,
           'x': ragged_factory_ops.constant_value([['hello', '123'], ['1+1']]),
           'pos': 2, 'len': 3},
          {'op': array_ops.check_numerics,
           'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
           'message': 'check-numerics'},
          {'op': nn_ops.dropout,
           'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
           'rate': 0.5,
           'seed': 1},
      ]
      )  # pyformat: disable
  def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args):
    """Checks a unary op on (possibly ragged) `x` against its dense version.

    The result must have the same (ragged) shape as `x`. Its flattened values
    must equal `op` applied to `x.flat_values` (or to `x` if it is dense).
    """
    result = op(x, **extra_args)

    # Run the wrapped op on the dense values, for comparison.
    dense_x = x.flat_values if ragged_tensor.is_ragged(x) else x
    expected_flat_values = array_ops.reshape(op(dense_x, **extra_args), [-1])

    # Check that the result has the expected shape.
    self.assertSameShape(x, result)

    # Check that the result has the expected (flattened) values.
    if ragged_tensor.is_ragged(result):
      result_flat_values = array_ops.reshape(result.flat_values, [-1])
    else:
      result_flat_values = array_ops.reshape(result, [-1])
    self.assertAllEqual(expected_flat_values, result_flat_values)

  @parameterized.parameters(
      [
          #=====================================================================
          # Without broadcasting -- i.e., shapes match exactly.
          #=====================================================================
          # Shapes: x:(), y:()
          {'x': 12,
           'y': 8},
          # Shapes: x:(3,), y:(3,)
          {'x': [7, 8, 9],
           'y': [1, -2, 3]},
          # Shapes: x:(2, 2), y:(2, 2)
          {'x': [[-2, 3], [-3, -4]],
           'y': [[1, 2], [3, 4]]},
          # Shapes: x:(2, None), y:(2, None)
          {'x': ragged_factory_ops.constant_value([[-2, 3], [-3]]),
           'y': ragged_factory_ops.constant_value([[5, 6], [7]])},
          # Shapes: x:(2, 2, 2), y:(2, 2, 2)
          {'x': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
           'y': [[[9, 3], [3, 4]], [[5, 2], [7, 6]]]},
          # Shapes: x:(2, None, None), y: (2, None, None)
          {'x': ragged_factory_ops.constant_value(
              [[[1, 2], [3], [4]], [[], [5, 7, 8]]]),
           'y': ragged_factory_ops.constant_value(
               [[[3, 8], [2], [5]], [[], [1, 9, 8]]])},
          # Shapes: x:(2, None, 2), y: (2, None, 2)
          {'x': ragged_factory_ops.constant_value(
              [[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
              ragged_rank=1),
           'y': ragged_factory_ops.constant_value(
               [[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
               ragged_rank=1)},

          #=====================================================================
          # With broadcasting
          #=====================================================================
          # Shapes: x:(), y:(3,)
          {'x': 12,                                 # Broadcast () -> (3,)
           'y': [1, -2, 3]},
          # Shapes: x:(1,), y:(3,)
          {'x': [12],                               # Broadcast (1,) -> (3,)
           'y': [1, -2, 3]},
          # Shapes: x:(), y:(2, 2)
          {'x': 12,                                 # Broadcast () -> (2, 2)
           'y': [[1, 2], [3, 4]]},
          # Shapes: x:(1,), y:(2, 2)
          {'x': [12],                               # Broadcast (1,) -> (2, 2)
           'y': [[1, 2], [3, 4]]},
          # Shapes: x:(2, 1), y:(2, 2)
          {'x': [[10], [20]],                       # Broadcast (2, 1) -> (2, 2)
           'y': [[1, 2], [3, 4]]},
          # Shapes: x:(), y:(2, None)
          {'x': 10,                                 # Broadcast () -> (2, None)
           'y': ragged_factory_ops.constant_value(
               [[1, 2], [3]], dtype=np.int32)},
          # TODO(edloper): Add tests for more advanced broadcasting, once we add
          # support for it.

          #=====================================================================
          # Keyword Args
          #=====================================================================
          {'x': ragged_factory_ops.constant_value(
              [[[1, 2], [3], [4]], [[], [5, 7, 8]]]),
           'y': ragged_factory_ops.constant_value(
               [[[3, 8], [2], [5]], [[], [1, 9, 8]]]),
           'use_kwargs': ('x', 'y')},
          {'x': ragged_factory_ops.constant_value(
              [[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
              ragged_rank=1),
           'y': ragged_factory_ops.constant_value(
               [[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
               ragged_rank=1),
           'use_kwargs': ('x', 'y')},
          # Only `y` by keyword: `x` cannot be passed by keyword while `y` is
          # positional, so this is the only single-keyword combination.
          {'x': ragged_factory_ops.constant_value(
              [[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
              ragged_rank=1),
           'y': ragged_factory_ops.constant_value(
               [[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
               ragged_rank=1),
           'use_kwargs': ('y',)},
      ] +
      #=========================================================================
      # Test each binary op.
      #=========================================================================
      [{'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
        'y': ragged_factory_ops.constant_value([[5.0, 1.0], [12.0]]),
        'op': op}
       for op in test_ops.BINARY_FLOAT_OPS] +
      [{'x': ragged_factory_ops.constant_value([[-2, 3], [-3]]),
        'y': ragged_factory_ops.constant_value([[5, 1], [12]]),
        'op': op}
       for op in test_ops.BINARY_INT_OPS] +
      [{'x': ragged_factory_ops.constant_value([[True, True], [False]]),
        'y': ragged_factory_ops.constant_value([[False, True], [False]]),
        'op': op}
       for op in test_ops.BINARY_BOOL_OPS]
      )  # pyformat: disable
  def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args):
    """Checks a binary op on (possibly ragged) inputs against its dense version.

    The result's shape is compared against `y`'s, so broadcasting test cases
    place the broadcast target in `y`. The optional `use_kwargs` entry of
    `extra_args` names which operands are passed by keyword.
    """
    use_kwargs = extra_args.pop('use_kwargs', ())
    # `x` cannot be passed by keyword unless `y` is too (a positional argument
    # may not follow a keyword argument), so three call forms cover all cases.
    if 'x' in use_kwargs and 'y' in use_kwargs:
      result = op(x=x, y=y, **extra_args)
    elif 'y' in use_kwargs:
      result = op(x, y=y, **extra_args)
    else:
      result = op(x, y, **extra_args)

    # Run the wrapped op on the dense values, for comparison.
    dense_x = x.flat_values if ragged_tensor.is_ragged(x) else x
    dense_y = y.flat_values if ragged_tensor.is_ragged(y) else y
    expected_flat_values = array_ops.reshape(
        op(dense_x, dense_y, **extra_args), [-1])

    # Check that the result has the expected shape.
    self.assertSameShape(y, result)

    # Check that the result has the expected (flattened) values.
    if ragged_tensor.is_ragged(result):
      result_flat_values = array_ops.reshape(result.flat_values, [-1])
    else:
      result_flat_values = array_ops.reshape(result, [-1])
    self.assertAllEqual(expected_flat_values, result_flat_values)

  @parameterized.parameters(
      [
          {'inputs': (12, 8, 3)},
          {'inputs': ([1, 2, 3], [7, 8, 9], [3, 6, 9])},
          {'inputs': ([[1, 2]], [[3, 4]], [[5, 6]])},
          {'inputs': (ragged_factory_ops.constant_value([[1, 3], [-3]]),
                      ragged_factory_ops.constant_value([[4, 7], [88]]),
                      ragged_factory_ops.constant_value([[2, 9], [12]]))},
          {'inputs': (ragged_factory_ops.constant_value(
              [[[1, 3], [-3]], [[1]]]),
                      ragged_factory_ops.constant_value(
                          [[[4, 7], [88]], [[2]]]),
                      ragged_factory_ops.constant_value(
                          [[[2, 9], [12]], [[8]]]))},
          {'inputs': (
              ragged_factory_ops.constant_value([[[1, 3], [3, 4]], [[1, 5]]],
                                                ragged_rank=1),
              ragged_factory_ops.constant_value([[[4, 7], [1, 2]], [[2, 2]]],
                                                ragged_rank=1),
              ragged_factory_ops.constant_value([[[2, 9], [5, 2]], [[8, 0]]],
                                                ragged_rank=1))},
          {'inputs': (
              ragged_factory_ops.constant_value([[[1, 3], [-3]], [[1]]]),
              ragged_factory_ops.constant_value([[[4, 7], [88]], [[2]]]),
              ragged_factory_ops.constant_value([[[2, 9], [12]], [[8]]])),
           'use_kwargs': True},
      ] + [
          {'op': math_ops.add_n,
           'inputs': (ragged_factory_ops.constant_value([[1, 3], [-3]]),
                      ragged_factory_ops.constant_value([[4, 7], [88]]),
                      ragged_factory_ops.constant_value([[2, 9], [12]]))},
          {'op': string_ops.string_join,
           'inputs': (
               ragged_factory_ops.constant_value([['a', 'b'], ['c']]),
               ragged_factory_ops.constant_value([['foo', 'bar'], ['baz']]),
               ragged_factory_ops.constant_value([['2', '9'], ['12']]))},
      ])  # pyformat: disable
  def testListValuedElementwiseOp(self, inputs, op=math_ops.add_n,
                                  **extra_args):
    """Checks a list-valued op on (possibly ragged) inputs against dense.

    The result must have the same (ragged) shape as `inputs[0]`. Its flattened
    values must equal `op` applied to the inputs' flat values.
    """
    use_kwargs = extra_args.pop('use_kwargs', False)
    if use_kwargs:
      result = op(inputs=inputs, **extra_args)
    else:
      result = op(inputs, **extra_args)

    # Run the wrapped op on the dense values, for comparison.
    dense_inputs = [
        x.flat_values if ragged_tensor.is_ragged(x) else x for x in inputs
    ]
    expected_flat_values = array_ops.reshape(
        op(dense_inputs, **extra_args), [-1])

    # Check that the result has the expected shape.
    self.assertSameShape(inputs[0], result)

    # Check that the result has the expected (flattened) values.
    if ragged_tensor.is_ragged(result):
      result_flat_values = array_ops.reshape(result.flat_values, [-1])
    else:
      result_flat_values = array_ops.reshape(result, [-1])
    self.assertAllEqual(expected_flat_values, result_flat_values)

  def testElementwiseOpUnknownRankError(self):
    """Broadcasting against flat_values of unknown rank raises ValueError."""
    # An unknown static rank only arises in graph mode (the placeholder below
    # has shape=None), so there is nothing to check when executing eagerly.
    if context.executing_eagerly():
      return
    x = ragged_factory_ops.constant([[1, 2], [3]])
    y = ragged_tensor.RaggedTensor.from_row_splits(
        array_ops.placeholder_with_default([1, 2, 3], shape=None), x.row_splits)
    with self.assertRaisesRegex(ValueError,
                                r'Unable to broadcast: unknown rank'):
      math_ops.add(x, y)

  @parameterized.parameters([
      dict(
          x=ragged_factory_ops.constant_value([[1, 2], [3]]),
          y=[[10]],
          expected=[[11, 12], [13]]),
      dict(
          x=ragged_factory_ops.constant_value([[[1, 2], [3, 4]], [[5]]],
                                              ragged_rank=2),
          y=ragged_factory_ops.constant_value([[[10], [20]], [[30]]],
                                              ragged_rank=1),
          expected=[[[11, 12], [23, 24]], [[35]]]),
      dict(
          x=ragged_factory_ops.constant_value([[[1]]]),
          y=ragged_factory_ops.constant_value([[1]]),
          expected=[[[2]]]),
  ])
  def testElementwiseOpBroadcast(self, x, y, expected):
    """Checks `x + y` when one operand must be broadcast to the other."""
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, dtype=dtypes.int32)
    y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, dtype=dtypes.int32)
    result = x + y
    self.assertAllEqual(result, expected)

  def testElementwiseOpShapeMismatch(self):
    """Adding ragged tensors with incompatible row lengths raises an error."""
    x = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])
    y = ragged_factory_ops.constant([[1, 2, 3], [4, 5, 6]])
    with self.assertRaises((ValueError, errors.InvalidArgumentError)):
      self.evaluate(math_ops.add(x, y))

  def testBinaryOpSparseAndRagged(self):
    """Mixing a SparseTensor with a RaggedTensor is rejected."""
    x = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])
    y = sparse_tensor.SparseTensor([[0, 0], [0, 1], [2, 0]], [1, 2, 3], [3, 2])
    with self.assertRaises((TypeError, ValueError)):
      self.evaluate(math_ops.add(x, y))

    with self.assertRaises((TypeError, ValueError)):
      self.evaluate(math_ops.add_n([x, y]))

  @parameterized.parameters([
      dict(
          op=array_ops.batch_gather,
          args=(ragged_factory_ops.constant_value([[5, 6, 7], [8, 9]]),
                ragged_factory_ops.constant_value([[2, 1, 0], [1]])),
          expected=ragged_factory_ops.constant_value([[7, 6, 5], [9]])),
      dict(
          op=array_ops.concat,
          args=([
              ragged_factory_ops.constant_value([[1, 2, 3], [4]],
                                                dtype=np.int32),
              np.array([[5, 6]], dtype=np.int32)
          ],),
          kwargs={'axis': 0},
          expected=ragged_factory_ops.constant_value([[1, 2, 3], [4], [5, 6]])),
      dict(
          op=array_ops.expand_dims,
          kwargs={
              'input': ragged_factory_ops.constant_value([[1, 2], [3]]),
              'axis': 0
          },
          expected=ragged_factory_ops.constant_value([[[1, 2], [3]]])),
      dict(
          op=array_ops.expand_dims_v2,
          kwargs={
              'input': ragged_factory_ops.constant_value([[1, 2], [3]]),
              'axis': -1
          },
          expected=ragged_factory_ops.constant_value([[[1], [2]], [[3]]],
                                                     ragged_rank=1),
      ),
      dict(
          op=array_ops.gather,
          kwargs={
              'params': ragged_factory_ops.constant_value([[1, 2], [3]]),
              'indices': [1, 0, 1]
          },
          expected=ragged_factory_ops.constant_value([[3], [1, 2], [3]])),
      dict(
          op=array_ops.gather_v2,
          kwargs={
              'params': ragged_factory_ops.constant_value([[1, 2], [3]]),
              'indices': ragged_factory_ops.constant_value([[1, 0], [1]])
          },
          expected=ragged_factory_ops.constant_value([[[3], [1, 2]], [[3]]])),
      dict(
          op=array_ops.gather_nd,
          kwargs={
              'params': ragged_factory_ops.constant_value([[7, 8], [9]]),
              'indices': [[0, 1], [1, 0], [0, 0]]
          },
          expected=ragged_factory_ops.constant_value([8, 9, 7])),
      dict(
          op=array_ops.one_hot,
          kwargs={
              'indices':
                  ragged_factory_ops.constant_value([[1, 2, 3], [0]],
                                                    dtype=np.int32),
              'depth':
                  4,
              'axis':
                  -1
          },
          expected=ragged_factory_ops.constant_value(
              [[[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], [[1, 0, 0, 0]]],
              ragged_rank=1)),
      dict(
          op=array_ops.stack,
          args=([
              ragged_factory_ops.constant_value([[1, 2, 3], [4]],
                                                dtype=np.int32),
              np.array([[5, 6]], dtype=np.int32)
          ],),
          expected=ragged_factory_ops.constant_value([[[1, 2, 3], [4]],
                                                      [[5, 6]]])),
      dict(
          op=array_ops.tile,
          args=([
              ragged_factory_ops.constant_value([[1, 2], [3]], dtype=np.int32),
              [2, 3]
          ]),
          expected=ragged_factory_ops.constant_value([[1, 2, 1, 2, 1, 2],
                                                      [3, 3, 3],
                                                      [1, 2, 1, 2, 1, 2],
                                                      [3, 3, 3]])),
      dict(
          op=array_ops.where,
          args=(ragged_factory_ops.constant_value([[True, False], [True]]),
                ragged_factory_ops.constant_value([[b'A', b'B'], [b'C']]),
                ragged_factory_ops.constant_value([[b'a', b'b'], [b'c']])),
          expected=ragged_factory_ops.constant_value([[b'A', b'b'], [b'C']])),
      dict(
          op=array_ops.where,
          args=(ragged_factory_ops.constant_value([[True, False], [True]]),),
          expected=[[0, 0], [1, 0]]),
      dict(
          op=math_ops.unsorted_segment_sum,
          kwargs={
              'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
              'segment_ids': ragged_factory_ops.constant_value([[0, 2], [0]]),
              'num_segments': 3
          },
          expected=[4, 0, 2]),
      dict(
          op=math_ops.unsorted_segment_prod,
          kwargs={
              'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
              'segment_ids': ragged_factory_ops.constant_value([[0, 2], [0]]),
              'num_segments': 3
          },
          expected=[3, 1, 2]),
      dict(
          op=math_ops.unsorted_segment_min,
          kwargs={
              'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
              'segment_ids': ragged_factory_ops.constant_value([[0, 1], [0]]),
              'num_segments': 2
          },
          expected=[1, 2]),
      dict(
          op=math_ops.unsorted_segment_max,
          kwargs={
              'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
              'segment_ids': ragged_factory_ops.constant_value([[0, 1], [0]]),
              'num_segments': 2
          },
          expected=[3, 2]),
      dict(
          op=math_ops.unsorted_segment_mean,
          kwargs={
              'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
              'segment_ids': ragged_factory_ops.constant_value([[0, 1], [0]]),
              'num_segments': 2
          },
          expected=[2, 2]),
      dict(
          op=math_ops.unsorted_segment_sqrt_n,
          kwargs={
              'data':
                  ragged_factory_ops.constant_value([[1.0, 2.0],
                                                     [3.0, 4.0, 6.0]]),
              'segment_ids':
                  ragged_factory_ops.constant_value([[0, 1], [0, 0, 0]]),
              'num_segments':
                  2
          },
          expected=[7.0, 2.0]),
      dict(
          op=math_ops.reduce_sum,
          kwargs={
              'input_tensor':
                  ragged_factory_ops.constant_value([[1, 2], [3, 4, 5]]),
              'axis':
                  1
          },
          expected=[3, 12]),
      dict(
          op=math_ops.reduce_prod,
          kwargs={
              'input_tensor':
                  ragged_factory_ops.constant_value([[1, 2], [3, 4, 5]]),
              'axis':
                  1
          },
          expected=[2, 60]),
      dict(
          op=math_ops.reduce_min,
          kwargs={
              'input_tensor':
                  ragged_factory_ops.constant_value([[1, 2], [3, 4, 5]]),
              'axis':
                  1
          },
          expected=[1, 3]),
      dict(
          op=math_ops.reduce_max,
          kwargs={
              'input_tensor':
                  ragged_factory_ops.constant_value([[1, 2], [3, 4, 5]]),
              'axis':
                  1
          },
          expected=[2, 5]),
      dict(
          op=math_ops.reduce_mean,
          kwargs={
              'input_tensor':
                  ragged_factory_ops.constant_value([[1, 3], [3, 4, 5]]),
              'axis':
                  1
          },
          expected=[2, 4]),
      dict(
          op=math_ops.reduce_any,
          kwargs={
              'input_tensor':
                  ragged_factory_ops.constant_value([[True, False],
                                                     [True, True, True]]),
              'axis':
                  1
          },
          expected=[True, True]),
      dict(
          op=string_ops.reduce_join,
          kwargs={
              'inputs':
                  ragged_factory_ops.constant_value([[
                      b'this', b'is', b'a', b'test', b'for', b'ragged',
                      b'tensors'
                  ], [b'please', b'do', b'not', b'panic', b'!']]),
              'axis':
                  0,
              'keepdims':
                  False,
              'separator':
                  ''
          },
          expected=[
              b'thisplease', b'isdo', b'anot', b'testpanic', b'for!', b'ragged',
              b'tensors'
          ]),
      dict(
          op=math_ops.reduce_all,
          kwargs={
              'input_tensor':
                  ragged_factory_ops.constant_value([[True, False],
                                                     [True, True, True]]),
              'axis':
                  1
          },
          expected=[False, True]),
      dict(
          op=array_ops.rank,
          kwargs={'input': ragged_factory_ops.constant_value([[8, 3], [5]])},
          expected=2),
      dict(
          op=array_ops.size,
          kwargs={'input': ragged_factory_ops.constant_value([[8, 3], [5]])},
          expected=3),
      dict(
          op=array_ops.size_v2,
          kwargs={'input': ragged_factory_ops.constant_value([[8, 3], [5]])},
          expected=3),
      dict(
          op=array_ops.squeeze,
          kwargs={
              'input': ragged_factory_ops.constant_value([[[1, 2, 3], [4, 5]]]),
              'axis': [0]
          },
          expected=ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]])),
      dict(
          op=array_ops.squeeze_v2,
          kwargs={
              'input': ragged_factory_ops.constant_value([[[1, 2, 3], [4, 5]]]),
              'axis': [0]
          },
          expected=ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]])),
      dict(
          op=data_flow_ops.dynamic_partition,
          kwargs={
              'data': ragged_factory_ops.constant_value([[1], [2, 3, 4], [5]]),
              'partitions': [2, 1, 1],
              'num_partitions': 3
          },
          expected=[
              ragged_factory_ops.constant_value([], ragged_rank=1),
              ragged_factory_ops.constant_value([[2, 3, 4], [5]]),
              ragged_factory_ops.constant_value([[1]])
          ],
          result_is_list=True),
      dict(
          op=array_ops.reverse,
          kwargs={
              'tensor': ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]]),
              'axis': [0, -1]
          },
          expected=ragged_factory_ops.constant_value([[5, 4], [3, 2, 1]])),
      dict(
          op=string_ops.string_format,
          kwargs={'template': 'Hi {}',
                  'inputs': [ragged_factory_ops.constant_value([[1, 2], [3]])]},
          expected='Hi [[1, 2], [3]]'),
  ])
  def testRaggedDispatch(self, op, expected, args=(), result_is_list=False,
                         kwargs=None):
    """Calls `op(*args, **kwargs)` and compares the result with `expected`.

    When `result_is_list` is true, `op` returns a list, and each element is
    compared with the corresponding element of `expected`.
    """
    if kwargs is None:
      kwargs = {}
    result = op(*args, **kwargs)
    if result_is_list:
      self.assertLen(result, len(expected))
      for (r, e) in zip(result, expected):
        self.assertAllEqual(r, e)
    else:
      self.assertAllEqual(result, expected)

  def testUnaryElementwiseOpsPreserveUniformRowLength(self):
    """Elementwise ops keep the input's uniform_row_length partitioning."""
    # Unary elementwise op
    rt = ragged_tensor.RaggedTensor.from_uniform_row_length(
        ragged_factory_ops.constant([[1, 2], [3]]),
        uniform_row_length=2)
    self.assertAllEqual(rt.uniform_row_length,
                        array_ops.zeros_like(rt).uniform_row_length)

    # Unary-list elementwise op
    rt = ragged_tensor.RaggedTensor.from_uniform_row_length(
        ragged_factory_ops.constant([[1, 2], [3]]),
        uniform_row_length=2)
    self.assertAllEqual(rt.uniform_row_length,
                        math_ops.add_n([rt, rt]).uniform_row_length)

  def test_ragged_op_list(self):
    """Checks which ops `ragged_op_list` reports as supported in v1 and v2."""
    # Ops that should be listed as supported in both v1 and v2.
    supported_ops = [
        'bitwise.bitwise_and', 'bitwise.bitwise_or', 'bitwise.bitwise_xor',
        'bitwise.invert', 'bitwise.left_shift', 'bitwise.right_shift',
        'clip_by_value', 'concat', 'debugging.check_numerics', 'cast',
        'dtypes.complex', 'dtypes.saturate_cast', 'expand_dims', 'gather_nd',
        'gather', 'identity', 'io.decode_base64', 'io.decode_compressed',
        'io.encode_base64', 'math.abs', 'math.acos', 'math.acosh', 'math.add_n',
        'math.add', 'math.angle', 'math.asin', 'math.asinh', 'math.atan2',
        'math.atan', 'math.atanh', 'math.ceil', 'math.conj', 'math.cos',
        'math.cosh', 'math.digamma', 'math.divide_no_nan', 'math.divide',
        'math.equal', 'math.erf', 'math.erfc', 'math.exp', 'math.expm1',
        'math.floor', 'math.floordiv', 'math.floormod', 'math.greater_equal',
        'math.greater', 'math.imag', 'math.is_finite', 'math.is_inf',
        'math.is_nan', 'math.less_equal', 'math.less', 'math.lgamma',
        'math.log1p', 'math.log_sigmoid', 'math.log', 'math.logical_and',
        'math.logical_not', 'math.logical_or', 'math.logical_xor',
        'math.maximum', 'math.minimum', 'math.multiply', 'math.negative',
        'math.not_equal', 'math.pow', 'math.real', 'math.reciprocal',
        'math.reduce_any', 'math.reduce_max', 'math.reduce_mean',
        'math.reduce_min', 'math.reduce_prod', 'math.reduce_sum', 'math.rint',
        'math.round', 'math.rsqrt', 'math.sign', 'math.sin', 'math.sinh',
        'math.sqrt', 'math.square', 'math.squared_difference', 'math.subtract',
        'math.tan', 'math.truediv', 'math.unsorted_segment_max',
        'math.unsorted_segment_mean', 'math.unsorted_segment_min',
        'math.unsorted_segment_prod', 'math.unsorted_segment_sqrt_n',
        'math.unsorted_segment_sum', 'one_hot', 'ones_like', 'rank', 'realdiv',
        'math.reduce_all', 'size', 'squeeze', 'stack', 'strings.as_string',
        'strings.join', 'strings.length', 'strings.reduce_join',
        'strings.regex_full_match', 'strings.regex_replace', 'strings.strip',
        'strings.substr', 'strings.to_hash_bucket_fast',
        'strings.to_hash_bucket_strong', 'strings.to_hash_bucket',
        'strings.to_number', 'strings.unicode_script', 'tile', 'truncatediv',
        'truncatemod', 'zeros_like', 'dynamic_partition', 'reverse',
        'nn.dropout', 'strings.format', 'print'
    ]

    # Ops that should be listed as supported in v1 only.
    # TODO(edloper): Add a dispatch for where_v2.
    supported_ops_v1 = ['batch_gather', 'where']

    # Ops that should be listed as supported in v2 only.
    supported_ops_v2 = []

    v1_ragged_ops = ragged_dispatch.ragged_op_list(tf_version=1)
    for element in supported_ops + supported_ops_v1:
      self.assertIn('`tf.' + element + '`', v1_ragged_ops)
    for element in supported_ops_v2:
      self.assertNotIn('`tf.' + element + '`', v1_ragged_ops)

    v2_ragged_ops = ragged_dispatch.ragged_op_list(tf_version=2)
    for element in supported_ops + supported_ops_v2:
      self.assertIn('`tf.' + element + '`', v2_ragged_ops)
    for element in supported_ops_v1:
      self.assertNotIn('`tf.' + element + '`', v2_ragged_ops)
示例#30
0
class RaggedWhereV1OpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests for `ragged_where_op.where` (the v1 `tf.where` semantics).

  Three modes are exercised: coordinate retrieval (x and y omitted),
  elementwise value selection (condition, x and y share a shape), and row
  selection (condition is a vector choosing whole rows of x or y).
  """

  @parameterized.parameters([
      #=========================================================================
      # Docstring Examples
      #=========================================================================
      dict(  # shape=[D1, (D2)]
          condition=ragged_factory_ops.constant_value(
              [[True, False, True], [False, True]]),
          expected=[[0, 0], [0, 2], [1, 1]]),
      dict(  # shape=[D1, (D2)]
          condition=ragged_factory_ops.constant_value(
              [[True, False, True], [False, True]]),
          x=ragged_factory_ops.constant_value(
              [['A', 'B', 'C'], ['D', 'E']]),
          y=ragged_factory_ops.constant_value(
              [['a', 'b', 'c'], ['d', 'e']]),
          expected=ragged_factory_ops.constant_value(
              [[b'A', b'b', b'C'], [b'd', b'E']])),
      dict(  # c.shape=[D1], x.shape=[D1, (D2)], y.shape=[D1, (D2)]
          condition=ragged_factory_ops.constant_value([True, False]),
          x=ragged_factory_ops.constant_value([['A', 'B', 'C'], ['D', 'E']]),
          y=ragged_factory_ops.constant_value([['a', 'b', 'c'], ['d', 'e']]),
          expected=ragged_factory_ops.constant_value(
              [[b'A', b'B', b'C'], [b'd', b'e']])),
      #=========================================================================
      # Coordinate-retrieval mode
      # (x and y omitted; the result lists the index of every True element.)
      #=========================================================================
      dict(  # shape=[D1]
          condition=[True, False, True, False, True],
          expected=[[0], [2], [4]]),
      dict(  # shape=[D1, D2]
          condition=[[True, False], [False, True]],
          expected=[[0, 0], [1, 1]]),
      dict(  # shape=[D1, (D2)]
          condition=ragged_factory_ops.constant_value(
              [[True, False, True], [False, True]]),
          expected=[[0, 0], [0, 2], [1, 1]]),
      dict(  # shape=[D1, (D2), (D3)]
          condition=ragged_factory_ops.constant_value([
              [[True, False, True], [False, True]],
              [[True], [], [False], [False, True, False]]
          ]),
          expected=[[0, 0, 0], [0, 0, 2], [0, 1, 1],
                    [1, 0, 0], [1, 3, 1]]),
      dict(  # shape=[D1, (D2), D3]
          condition=ragged_factory_ops.constant_value([
              [[True, False], [False, True]],
              [[True, False], [False, False], [True, False], [False, True]]
          ], ragged_rank=1),
          expected=[[0, 0, 0], [0, 1, 1],
                    [1, 0, 0], [1, 2, 0], [1, 3, 1]]),
      dict(  # shape=[D1, (D2), (D3), (D4)]
          condition=ragged_factory_ops.constant_value([
              [[[], [True]]],
              [[[True, False, True], [False, True]],
               [[True], [], [False], [False, True, False]]]
          ]),
          expected=[[0, 0, 1, 0],
                    [1, 0, 0, 0], [1, 0, 0, 2], [1, 0, 1, 1],
                    [1, 1, 0, 0], [1, 1, 3, 1]]),

      #=========================================================================
      # Elementwise value-selection mode
      # (condition, x and y all have the same shape.)
      #=========================================================================
      dict(  # shape=[]
          condition=True, x='A', y='a', expected=b'A'),
      dict(  # shape=[]
          condition=False, x='A', y='a', expected=b'a'),
      dict(  # shape=[D1]
          condition=[True, False, True],
          x=['A', 'B', 'C'],
          y=['a', 'b', 'c'],
          expected=[b'A', b'b', b'C']),
      dict(  # shape=[D1, D2]
          condition=[[True, False], [False, True]],
          x=[['A', 'B'], ['D', 'E']],
          y=[['a', 'b'], ['d', 'e']],
          expected=[[b'A', b'b'], [b'd', b'E']]),
      dict(  # shape=[D1, (D2)]
          condition=ragged_factory_ops.constant_value(
              [[True, False, True], [False, True]]),
          x=ragged_factory_ops.constant_value([['A', 'B', 'C'], ['D', 'E']]),
          y=ragged_factory_ops.constant_value([['a', 'b', 'c'], ['d', 'e']]),
          expected=ragged_factory_ops.constant_value(
              [[b'A', b'b', b'C'], [b'd', b'E']])),
      dict(  # shape=[D1, (D2), D3]
          condition=ragged_factory_ops.constant_value([
              [[True, False], [False, True]],
              [[True, False], [False, False], [True, False], [False, True]]
          ], ragged_rank=1),
          x=ragged_factory_ops.constant_value([
              [['A', 'B'], ['C', 'D']],
              [['E', 'F'], ['G', 'H'], ['I', 'J'], ['K', 'L']]
          ], ragged_rank=1),
          y=ragged_factory_ops.constant_value([
              [['a', 'b'], ['c', 'd']],
              [['e', 'f'], ['g', 'h'], ['i', 'j'], ['k', 'l']]
          ], ragged_rank=1),
          expected=ragged_factory_ops.constant_value([
              [[b'A', b'b'], [b'c', b'D']],
              [[b'E', b'f'], [b'g', b'h'], [b'I', b'j'], [b'k', b'L']]
          ], ragged_rank=1)),
      dict(  # shape=[D1, (D2), (D3), (D4)]
          condition=ragged_factory_ops.constant_value([
              [[[], [True]]],
              [[[True, False, True], [False, True]],
               [[True], [], [False], [False, True, False]]]
          ]),
          x=ragged_factory_ops.constant_value([
              [[[], ['A']]],
              [[['B', 'C', 'D'], ['E', 'F']],
               [['G'], [], ['H'], ['I', 'J', 'K']]]
          ]),
          y=ragged_factory_ops.constant_value([
              [[[], ['a']]],
              [[['b', 'c', 'd'], ['e', 'f']],
               [['g'], [], ['h'], ['i', 'j', 'k']]]
          ]),
          expected=ragged_factory_ops.constant_value([
              [[[], [b'A']]],
              [[[b'B', b'c', b'D'], [b'e', b'F']],
               [[b'G'], [], [b'h'], [b'i', b'J', b'k']]]
          ])),

      #=========================================================================
      # Elementwise row-selection mode
      # (condition is a vector; each entry picks a whole row of x or y, so
      # x and y may differ in their ragged inner dimensions.)
      #=========================================================================
      dict(  # c.shape=[D1], x.shape=[D1, D2], y.shape=[D1, D2]
          condition=[True, False, True],
          x=[['A', 'B'], ['C', 'D'], ['E', 'F']],
          y=[['a', 'b'], ['c', 'd'], ['e', 'f']],
          expected=[[b'A', b'B'], [b'c', b'd'], [b'E', b'F']]),
      dict(  # c.shape=[D1], x.shape=[D1, D2], y.shape=[D1, (D2)]
          condition=[True, False, True],
          x=[['A', 'B'], ['C', 'D'], ['E', 'F']],
          y=ragged_factory_ops.constant_value(
              [['a', 'b'], ['c'], ['d', 'e']]),
          expected=ragged_factory_ops.constant_value(
              [[b'A', b'B'], [b'c'], [b'E', b'F']])),
      dict(  # c.shape=[D1], x.shape=[D1, (D2)], y.shape=[D1, (D2)]
          condition=[True, False, True],
          x=ragged_factory_ops.constant_value(
              [['A', 'B', 'C'], ['D', 'E'], ['F', 'G']]),
          y=ragged_factory_ops.constant_value(
              [['a', 'b'], ['c'], ['d', 'e']]),
          expected=ragged_factory_ops.constant_value(
              [[b'A', b'B', b'C'], [b'c'], [b'F', b'G']])),
      dict(  # c.shape=[D1], x/y.shape=[D1, (D2), (D3), (D4)]
          condition=ragged_factory_ops.constant_value([True, False]),
          x=ragged_factory_ops.constant_value([
              [[[], ['A']]],
              [[['B', 'C', 'D'], ['E', 'F']],
               [['G'], [], ['H'], ['I', 'J', 'K']]]
          ]),
          y=ragged_factory_ops.constant_value([[[['a']]], [[['b']]]]),
          expected=ragged_factory_ops.constant_value(
              [[[[], [b'A']]], [[[b'b']]]])),
  ])   # pyformat: disable
  def testRaggedWhere(self, condition, expected, x=None, y=None):
    """Checks `ragged_where_op.where(condition, x, y)` equals `expected`."""
    result = ragged_where_op.where(condition, x, y)
    self.assertAllEqual(result, expected)

  @parameterized.parameters([
      # Only one of x / y supplied.
      dict(
          condition=[True, False],
          x=[1, 2],
          error=ValueError,
          message='x and y must be either both None or both non-None'),
      # Ragged x and dense y with incompatible shapes.
      dict(
          condition=ragged_factory_ops.constant_value([[True, False, True],
                                                       [False, True]]),
          x=ragged_factory_ops.constant_value([['A', 'B', 'C'], ['D', 'E']]),
          y=[['a', 'b'], ['d', 'e']],
          error=ValueError,
          message='Input shapes do not match.'),
  ])
  def testRaggedWhereErrors(self, condition, error, message, x=None, y=None):
    """Checks that invalid arguments to `where` raise `error` matching `message`."""
    with self.assertRaisesRegex(error, message):
      ragged_where_op.where(condition, x, y)
示例#31
0
class RaggedWhereV2OpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests for `ragged_where_op.where_v2` (the v2 `tf.where` semantics).

  Covers coordinate retrieval (x and y omitted), elementwise multiplexing with
  identically-shaped inputs, and broadcasting of condition/x/y against each
  other.
  """

  @parameterized.parameters([
      #=========================================================================
      # Coordinate-retrieval mode
      # (x and y omitted; the result lists the index of every True element.)
      #=========================================================================
      dict(  # shape=[D1]
          condition=[True, False, True, False, True],
          expected=[[0], [2], [4]]),
      dict(  # shape=[D1, D2]
          condition=[[True, False], [False, True]],
          expected=[[0, 0], [1, 1]]),
      dict(  # shape=[D1, (D2)]
          condition=ragged_factory_ops.constant_value(
              [[True, False, True], [False, True]]),
          expected=[[0, 0], [0, 2], [1, 1]]),
      dict(  # shape=[D1, (D2), (D3)]
          condition=ragged_factory_ops.constant_value([
              [[True, False, True], [False, True]],
              [[True], [], [False], [False, True, False]]
          ]),
          expected=[[0, 0, 0], [0, 0, 2], [0, 1, 1],
                    [1, 0, 0], [1, 3, 1]]),
      dict(  # shape=[D1, (D2), D3]
          condition=ragged_factory_ops.constant_value([
              [[True, False], [False, True]],
              [[True, False], [False, False], [True, False], [False, True]]
          ], ragged_rank=1),
          expected=[[0, 0, 0], [0, 1, 1],
                    [1, 0, 0], [1, 2, 0], [1, 3, 1]]),
      dict(  # shape=[D1, (D2), (D3), (D4)]
          condition=ragged_factory_ops.constant_value([
              [[[], [True]]],
              [[[True, False, True], [False, True]],
               [[True], [], [False], [False, True, False]]]
          ]),
          expected=[[0, 0, 1, 0],
                    [1, 0, 0, 0], [1, 0, 0, 2], [1, 0, 1, 1],
                    [1, 1, 0, 0], [1, 1, 3, 1]]),

      #=========================================================================
      # Elementwise multiplexing
      # (condition, x and y all have the same shape.)
      #=========================================================================
      dict(  # shape=[]
          condition=True, x='A', y='a', expected=b'A'),
      dict(  # shape=[]
          condition=False, x='A', y='a', expected=b'a'),
      dict(  # shape=[D1]
          condition=[True, False, True],
          x=['A', 'B', 'C'],
          y=['a', 'b', 'c'],
          expected=[b'A', b'b', b'C']),
      dict(  # shape=[D1, D2]
          condition=[[True, False], [False, True]],
          x=[['A', 'B'], ['D', 'E']],
          y=[['a', 'b'], ['d', 'e']],
          expected=[[b'A', b'b'], [b'd', b'E']]),
      dict(  # shape=[D1, (D2)]
          condition=ragged_factory_ops.constant_value(
              [[True, False, True], [False, True]]),
          x=ragged_factory_ops.constant_value([['A', 'B', 'C'], ['D', 'E']]),
          y=ragged_factory_ops.constant_value([['a', 'b', 'c'], ['d', 'e']]),
          expected=ragged_factory_ops.constant_value(
              [[b'A', b'b', b'C'], [b'd', b'E']])),
      dict(  # shape=[D1, (D2), D3]
          condition=ragged_factory_ops.constant_value([
              [[True, False], [False, True]],
              [[True, False], [False, False], [True, False], [False, True]]
          ], ragged_rank=1),
          x=ragged_factory_ops.constant_value([
              [['A', 'B'], ['C', 'D']],
              [['E', 'F'], ['G', 'H'], ['I', 'J'], ['K', 'L']]
          ], ragged_rank=1),
          y=ragged_factory_ops.constant_value([
              [['a', 'b'], ['c', 'd']],
              [['e', 'f'], ['g', 'h'], ['i', 'j'], ['k', 'l']]
          ], ragged_rank=1),
          expected=ragged_factory_ops.constant_value([
              [[b'A', b'b'], [b'c', b'D']],
              [[b'E', b'f'], [b'g', b'h'], [b'I', b'j'], [b'k', b'L']]
          ], ragged_rank=1)),
      dict(  # shape=[D1, (D2), (D3), (D4)]
          condition=ragged_factory_ops.constant_value([
              [[[], [True]]],
              [[[True, False, True], [False, True]],
               [[True], [], [False], [False, True, False]]]
          ]),
          x=ragged_factory_ops.constant_value([
              [[[], ['A']]],
              [[['B', 'C', 'D'], ['E', 'F']],
               [['G'], [], ['H'], ['I', 'J', 'K']]]
          ]),
          y=ragged_factory_ops.constant_value([
              [[[], ['a']]],
              [[['b', 'c', 'd'], ['e', 'f']],
               [['g'], [], ['h'], ['i', 'j', 'k']]]
          ]),
          expected=ragged_factory_ops.constant_value([
              [[[], [b'A']]],
              [[[b'B', b'c', b'D'], [b'e', b'F']],
               [[b'G'], [], [b'h'], [b'i', b'J', b'k']]]
          ])),

      #=========================================================================
      # Broadcasting
      # (size-1 and missing dimensions of condition/x/y broadcast together.)
      #=========================================================================
      dict(  # c.shape=[D1, 1], x.shape=[D1, D2], y.shape=[D1, D2]
          condition=[[True], [False], [True]],
          x=[['A', 'B'], ['C', 'D'], ['E', 'F']],
          y=[['a', 'b'], ['c', 'd'], ['e', 'f']],
          expected=[[b'A', b'B'], [b'c', b'd'], [b'E', b'F']]),
      dict(  # c.shape=[D1, 1], x.shape=[D1, (D2)], y.shape=[D1, (D2)]
          condition=[[True], [False], [True]],
          x=ragged_factory_ops.constant_value(
              [['A', 'B', 'C'], ['D', 'E'], ['F', 'G']]),
          y=ragged_factory_ops.constant_value(
              [['a', 'b', 'c'], ['d', 'e'], ['f', 'g']]),
          expected=ragged_factory_ops.constant_value(
              [[b'A', b'B', b'C'], [b'd', b'e'], [b'F', b'G']])),
      dict(  # c.shape=[D1, (D2)], x.shape=[], y.shape=[]
          condition=ragged_factory_ops.constant_value(
              [[True, False, True, True], [True, False]]),
          x=10,
          y=20,
          expected=ragged_factory_ops.constant_value(
              [[10, 20, 10, 10], [10, 20]])),
      dict(  # c.shape=[D1, D2], x.shape=[D1, 1], y.shape=[1, D2]
          condition=[[True, False], [True, False], [False, True]],
          x=[[10], [20], [30]],
          y=[[40, 50]],
          expected=[[10, 50], [20, 50], [40, 30]]),
      dict(  # c.shape=[D1, (D2), D3], x.shape=[D1, (D2), 1], y.shape=[1, 1, D3]
          condition=ragged_factory_ops.constant_value(
              [[[True, False], [False, True]], [[True, True]]],
              ragged_rank=1),
          x=ragged_factory_ops.constant_value([[[10], [20]], [[30]]],
                                              ragged_rank=1),
          y=np.array([[[40, 50]]]),
          expected=[[[10, 50], [40, 20]], [[30, 30]]]),
  ])   # pyformat: disable
  def testRaggedWhere(self, condition, expected, x=None, y=None):
    """Checks `ragged_where_op.where_v2(condition, x, y)` equals `expected`."""
    result = ragged_where_op.where_v2(condition, x, y)
    self.assertAllEqual(result, expected)

  @parameterized.parameters([
      # Only one of x / y supplied.
      dict(
          condition=[True, False],
          x=[1, 2],
          error=ValueError,
          message='x and y must be either both None or both non-None'),
      # Ragged x and dense y whose shapes cannot be broadcast together.
      dict(
          condition=ragged_factory_ops.constant_value([[True, False, True],
                                                       [False, True]]),
          x=ragged_factory_ops.constant_value([['A', 'B', 'C'], ['D', 'E']]),
          y=[['a', 'b'], ['d', 'e']],
          error=errors.InvalidArgumentError,
          message=r'must be broadcastable|Unable to broadcast'),
  ])
  def testRaggedWhereErrors(self, condition, error, message, x=None, y=None):
    """Checks that invalid arguments to `where_v2` raise `error`.

    The op is wrapped in `self.evaluate`, presumably so that errors which only
    surface when the op runs (e.g. the broadcast failure) are also caught.
    """
    with self.assertRaisesRegex(error, message):
      self.evaluate(ragged_where_op.where_v2(condition, x, y))
示例#32
0
class StructuredTensorSpecTest(test_util.TensorFlowTestCase,
                               parameterized.TestCase):
    """Tests for `StructuredTensorSpec`.

    Covers construction (and its validation errors), `value_type`,
    `_serialize`, `_component_specs`, conversion to/from components, and
    batching/unbatching of both specs and values.
    """

    # TODO(edloper): Add a subclass of TensorFlowTestCase that overrides
    # assertAllEqual etc to work with StructuredTensors.
    def assertAllEqual(self, a, b, msg=None):
        """Like the base `assertAllEqual`, but also handles StructuredTensors.

        If neither argument is a StructuredTensor, defers to the base class.
        Two StructuredTensors are compared by the `repr` of their shapes, their
        sets of field names, and (recursively) the value of every field.

        Raises:
          ValueError: if exactly one of `a` and `b` is a StructuredTensor.
        """
        if not (isinstance(a, structured_tensor.StructuredTensor)
                or isinstance(b, structured_tensor.StructuredTensor)):
            return super(StructuredTensorSpecTest,
                         self).assertAllEqual(a, b, msg)
        if not (isinstance(a, structured_tensor.StructuredTensor)
                and isinstance(b, structured_tensor.StructuredTensor)):
            # TODO(edloper) Add support for this once structured_factory_ops is added.
            raise ValueError('Not supported yet')

        self.assertEqual(repr(a.shape), repr(b.shape))
        self.assertEqual(set(a.field_names()), set(b.field_names()))
        for field in a.field_names():
            self.assertAllEqual(a.field_value(field), b.field_value(field))

    def assertAllTensorsEqual(self, x, y):
        """Asserts `x` and `y` are dicts with equal keys and equal values.

        Uses `assertIsInstance` rather than a bare `assert`: bare asserts are
        stripped when Python runs with `-O`, which would silently skip the
        type check.
        """
        self.assertIsInstance(x, dict)
        self.assertIsInstance(y, dict)
        self.assertEqual(set(x), set(y))
        for key in x:
            self.assertAllEqual(x[key], y[key])

    def testConstruction(self):
        """Checks the constructor stores the shape and field specs verbatim."""
        spec1_fields = dict(a=T_1_2_3_4)
        spec1 = StructuredTensorSpec([1, 2, 3], spec1_fields)
        self.assertEqual(spec1._shape, (1, 2, 3))
        self.assertEqual(spec1._field_specs, spec1_fields)

        spec2_fields = dict(a=T_1_2, b=T_1_2_8, c=R_1_N, d=R_1_N_N, s=spec1)
        spec2 = StructuredTensorSpec([1, 2], spec2_fields)
        self.assertEqual(spec2._shape, (1, 2))
        self.assertEqual(spec2._field_specs, spec2_fields)

    @parameterized.parameters([
        (None, {}, r"StructuredTensor's shape must have known rank\."),
        ([], None, r'field_specs must be a dictionary\.'),
        ([], {
            1: tensor_spec.TensorSpec(None)
        }, r'field_specs must be a dictionary with string keys\.'),
        ([], {
            'x': 0
        }, r'field_specs must be a dictionary with TypeSpec values\.'),
    ])
    def testConstructionErrors(self, shape, field_specs, error):
        """Checks invalid shapes/field specs raise TypeError matching `error`."""
        with self.assertRaisesRegex(TypeError, error):
            structured_tensor.StructuredTensorSpec(shape, field_specs)

    def testValueType(self):
        """Checks `value_type` is `StructuredTensor`."""
        spec1 = StructuredTensorSpec([1, 2, 3], dict(a=T_1_2))
        self.assertEqual(spec1.value_type, StructuredTensor)

    @parameterized.parameters([
        (StructuredTensorSpec([1, 2, 3],
                              {}), (tensor_shape.TensorShape([1, 2, 3]), {})),
        (StructuredTensorSpec([],
                              {'a': T_1_2}), (tensor_shape.TensorShape([]), {
                                  'a': T_1_2
                              })),
        (StructuredTensorSpec([1, 2], {
            'a': T_1_2,
            'b': R_1_N
        }), (tensor_shape.TensorShape([1, 2]), {
            'a': T_1_2,
            'b': R_1_N
        })),
    ])  # pyformat: disable
    def testSerialize(self, spec, expected):
        """Checks `_serialize` returns a `(TensorShape, field_specs)` pair."""
        serialization = spec._serialize()
        # Note that we can only use assertEqual because none of our cases include
        # a None dimension. A TensorShape with a None dimension is never equal
        # to another TensorShape.
        self.assertEqual(serialization, expected)

    @parameterized.parameters([
        (StructuredTensorSpec([1, 2, 3], {}),
         ({}, NROWS_SPEC, (PARTITION_SPEC, PARTITION_SPEC))),
        (StructuredTensorSpec([], {'a': T_1_2}), ({
            'a': T_1_2
        }, (), ())),
        (StructuredTensorSpec([1, 2], {
            'a': T_1_2,
            'b': R_1_N
        }), ({
            'a': T_1_2,
            'b': R_1_N
        }, NROWS_SPEC, (PARTITION_SPEC, ))),
    ])  # pyformat: disable
    def testComponentSpecs(self, spec, expected):
        """Checks `_component_specs` is `(fields, nrows, row_partitions)` specs."""
        self.assertEqual(spec._component_specs, expected)

    @parameterized.parameters([
        {
            'shape': [],
            'fields': dict(x=[[1.0, 2.0]]),
            'field_specs': dict(x=T_1_2),
        },
        {
            'shape': [2],
            'fields':
            dict(a=ragged_factory_ops.constant_value([[1.0], [2.0, 3.0]]),
                 b=[[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),
            'field_specs':
            dict(a=R_1_N, b=T_2_3),
        },
    ])  # pyformat: disable
    def testToFromComponents(self, shape, fields, field_specs):
        """Round-trips a StructuredTensor through `_to/_from_components`."""
        struct = StructuredTensor.from_fields(fields, shape)
        spec = StructuredTensorSpec(shape, field_specs)
        actual_components = spec._to_components(struct)
        # Components are (fields, nrows, row_partitions).
        self.assertLen(actual_components, 3)
        self.assertAllTensorsEqual(actual_components[0], fields)
        rt_reconstructed = spec._from_components(actual_components)
        self.assertAllEqual(struct, rt_reconstructed)

    def testToFromComponentsEmptyScalar(self):
        """Round-trips a field-less scalar StructuredTensor."""
        struct = StructuredTensor.from_fields(fields={}, shape=[])
        spec = struct._type_spec
        components = spec._to_components(struct)
        rt_reconstructed = spec._from_components(components)
        self.assertAllEqual(struct, rt_reconstructed)
        self.assertEqual(components, ({}, (), ()))

    def testToFromComponentsEmptyTensor(self):
        """Round-trips a field-less rank-3 StructuredTensor.

        Also checks the nrows component and the row-partition splits derived
        from the shape `[1, 2, 3]`.
        """
        struct = StructuredTensor.from_fields(fields={}, shape=[1, 2, 3])
        spec = struct._type_spec
        components = spec._to_components(struct)
        rt_reconstructed = spec._from_components(components)
        self.assertAllEqual(struct, rt_reconstructed)
        self.assertLen(components, 3)
        fields, nrows, row_partitions = components
        self.assertEmpty(fields)
        self.assertAllEqual(nrows, 1)
        self.assertLen(row_partitions, 2)
        self.assertIsInstance(row_partitions[0], row_partition.RowPartition)
        self.assertIsInstance(row_partitions[1], row_partition.RowPartition)
        self.assertAllEqual(row_partitions[0].row_splits(), [0, 2])
        self.assertAllEqual(row_partitions[1].row_splits(), [0, 3, 6])

    @parameterized.parameters([{
        'unbatched': StructuredTensorSpec([], {}),
        'batch_size': 5,
        'batched': StructuredTensorSpec([5], {}),
    }, {
        'unbatched': StructuredTensorSpec([1, 2], {}),
        'batch_size': 5,
        'batched': StructuredTensorSpec([5, 1, 2], {}),
    }, {
        'unbatched':
        StructuredTensorSpec([], dict(a=T_3, b=R_1_N)),
        'batch_size':
        2,
        'batched':
        StructuredTensorSpec([2], dict(a=T_2_3, b=R_2_1_N)),
    }])  # pyformat: disable
    def testBatchUnbatch(self, unbatched, batch_size, batched):
        """Checks spec-level `_batch` and `_unbatch` are inverses."""
        self.assertEqual(unbatched._batch(batch_size), batched)
        self.assertEqual(batched._unbatch(), unbatched)

    @parameterized.parameters([
        {
            'unbatched':
            lambda: [
                StructuredTensor.from_fields({
                    'a': 1,
                    'b': [5, 6]
                }),
                StructuredTensor.from_fields({
                    'a': 2,
                    'b': [7, 8]
                })
            ],
            'batch_size':
            2,
            'batched':
            lambda: StructuredTensor.from_fields(shape=[2],
                                                 fields={
                                                     'a': [1, 2],
                                                     'b': [[5, 6], [7, 8]]
                                                 }),
        },
        {
            'unbatched':
            lambda: [
                StructuredTensor.from_fields(shape=[3],
                                             fields={
                                                 'a': [1, 2, 3],
                                                 'b': [[5, 6], [6, 7], [7, 8]]
                                             }),
                StructuredTensor.from_fields(shape=[3],
                                             fields={
                                                 'a': [2, 3, 4],
                                                 'b': [[2, 2], [3, 3], [4, 4]]
                                             })
            ],
            'batch_size':
            2,
            'batched':
            lambda: StructuredTensor.from_fields(
                shape=[2, 3],
                fields={
                    'a': [[1, 2, 3], [2, 3, 4]],
                    'b': [[[5, 6], [6, 7], [7, 8]], [[2, 2], [3, 3], [4, 4]]]
                }),
        },
        {
            'unbatched':
            lambda: [
                StructuredTensor.from_fields(
                    shape=[],
                    fields={
                        'a': 1,
                        'b': StructuredTensor.from_fields({'x': [5]})
                    }),
                StructuredTensor.from_fields(
                    shape=[],
                    fields={
                        'a': 2,
                        'b': StructuredTensor.from_fields({'x': [6]})
                    })
            ],
            'batch_size':
            2,
            'batched':
            lambda: StructuredTensor.from_fields(
                shape=[2],
                fields={
                    'a': [1, 2],
                    'b':
                    StructuredTensor.from_fields(shape=[2],
                                                 fields={'x': [[5], [6]]})
                }),
        },
        {
            'unbatched':
            lambda: [
                StructuredTensor.from_fields(
                    shape=[],
                    fields={
                        'Ragged3d':
                        ragged_factory_ops.constant_value([[1, 2], [3]]),
                        'Ragged2d':
                        ragged_factory_ops.constant_value([1]),
                    }),
                StructuredTensor.from_fields(
                    shape=[],
                    fields={
                        'Ragged3d': ragged_factory_ops.constant_value([[1]]),
                        'Ragged2d': ragged_factory_ops.constant_value([2, 3]),
                    })
            ],
            'batch_size':
            2,
            'batched':
            lambda: StructuredTensor.from_fields(
                shape=[2],
                fields={
                    'Ragged3d':
                    ragged_factory_ops.constant_value([[[1, 2], [3]], [[1]]]),
                    'Ragged2d':
                    ragged_factory_ops.constant_value([[1], [2, 3]]),
                }),
            'use_only_batched_spec':
            True,
        },
    ])  # pyformat: disable
    def testBatchUnbatchValues(self,
                               unbatched,
                               batch_size,
                               batched,
                               use_only_batched_spec=False):
        """Batches/unbatches values via tensor lists and compares results.

        Args:
          unbatched: zero-arg callable returning the list of unbatched values.
          batch_size: number of values in the batch.
          batched: zero-arg callable returning the expected batched value.
          use_only_batched_spec: if True, derive the unbatched spec by
            unbatching the batched value's spec rather than from
            `unbatched[0]` (the two individual values' specs differ in that
            case).
        """
        batched = batched()  # Deferred init because it creates tensors.
        unbatched = unbatched()  # Deferred init because it creates tensors.

        # Test batching.
        if use_only_batched_spec:
            unbatched_spec = type_spec.type_spec_from_value(batched)._unbatch()
        else:
            unbatched_spec = type_spec.type_spec_from_value(unbatched[0])
        unbatched_tensor_lists = [
            unbatched_spec._to_tensor_list(st) for st in unbatched
        ]
        # Stack the i-th tensor of every value to form the batched list.
        batched_tensor_list = [
            array_ops.stack(tensors)
            for tensors in zip(*unbatched_tensor_lists)
        ]
        actual_batched = unbatched_spec._batch(batch_size)._from_tensor_list(
            batched_tensor_list)
        self.assertTrue(
            unbatched_spec._batch(batch_size).is_compatible_with(
                actual_batched))
        self.assertAllEqual(actual_batched, batched)

        # Test unbatching
        batched_spec = type_spec.type_spec_from_value(batched)
        batched_tensor_list = batched_spec._to_batched_tensor_list(batched)
        # Unstack every batched tensor and regroup per batch element.
        unbatched_tensor_lists = zip(
            *[array_ops.unstack(tensor) for tensor in batched_tensor_list])
        actual_unbatched = [
            batched_spec._unbatch()._from_tensor_list(tensor_list)
            for tensor_list in unbatched_tensor_lists
        ]
        self.assertLen(actual_unbatched, len(unbatched))
        for st in actual_unbatched:
            self.assertTrue(batched_spec._unbatch().is_compatible_with(st))
        for (actual, expected) in zip(actual_unbatched, unbatched):
            self.assertAllEqual(actual, expected)
示例#33
0
class StructuredTensorTest(test_util.TensorFlowTestCase,
                           parameterized.TestCase):
  """Tests for constructing `StructuredTensor` via `from_fields`."""

  def assertAllEqual(self, a, b, msg=None):
    """Like the base `assertAllEqual`, but also handles StructuredTensors.

    If neither argument is a StructuredTensor, defers to the base class. Two
    StructuredTensors are compared by the `repr` of their shapes, their sets of
    field names, and (recursively) the value of every field.

    Raises:
      ValueError: if exactly one of `a` and `b` is a StructuredTensor.
    """
    if not (isinstance(a, structured_tensor.StructuredTensor) or
            isinstance(b, structured_tensor.StructuredTensor)):
      return super(StructuredTensorTest, self).assertAllEqual(a, b, msg)
    if not (isinstance(a, structured_tensor.StructuredTensor) and
            isinstance(b, structured_tensor.StructuredTensor)):
      # TODO(edloper) Add support for this once structured_factory_ops is added.
      raise ValueError("Not supported yet")

    self.assertEqual(repr(a.shape), repr(b.shape))
    self.assertEqual(set(a.field_names()), set(b.field_names()))
    for field in a.field_names():
      self.assertAllEqual(a.field_value(field), b.field_value(field))

  @parameterized.parameters([
      {
          "shape": [],
          "fields": {},
      },
      {
          "shape": [None],
          "fields": {},
      },
      {
          "shape": [1, 5, 3],
          "fields": {},
      },
      {
          "shape": [],
          "fields": {"Foo": 5, "Bar": [1, 2, 3]},
      },
      {
          "shape": [2],
          "fields": {"x": [1, 2], "y": [[1, 2], [3, 4]]},
      },
      {
          "shape": [None],
          "fields": {"x": [1, 2], "y": [[1, 2], [3, 4]]},
          "expected_shape": [2],  # inferred from field values.
      },
      {
          "shape": [],
          "fields": {
              "r": ragged_factory_ops.constant_value([[1, 2], [3]]),
          },
      },
      {
          "shape": [2],
          "fields": {
              "r": ragged_factory_ops.constant_value([[1, 2], [3]]),
          },
      },
  ])  # pyformat: disable
  def testConstruction(self, shape, fields, expected_shape=None):
    """Checks shape, rank, field order and field values after construction.

    `expected_shape` defaults to `shape`; it is given explicitly only when
    unknown dimensions are expected to be inferred from the field values.
    """
    struct = structured_tensor.StructuredTensor.from_fields(shape, fields)
    if expected_shape is None:
      expected_shape = shape
    self.assertEqual(struct.shape.as_list(), expected_shape)
    self.assertLen(expected_shape, struct.rank)
    self.assertEqual(struct.field_names(), tuple(fields.keys()))
    for field, value in fields.items():
      self.assertIsInstance(
          struct.field_value(field),
          (ops.Tensor, structured_tensor.StructuredTensor,
           ragged_tensor.RaggedTensor))
      self.assertAllEqual(struct.field_value(field), value)

  def testNestedStructConstruction(self):
    """Checks StructuredTensors can be nested as field values."""
    rt = ragged_factory_ops.constant([[1, 2], [3]])
    struct1 = structured_tensor.StructuredTensor.from_fields([], {"x": [1, 2]})
    struct2 = structured_tensor.StructuredTensor.from_fields([2], {"x": [1, 2]})
    struct3 = structured_tensor.StructuredTensor.from_fields([], {
        "r": rt,
        "s": struct1
    })
    struct4 = structured_tensor.StructuredTensor.from_fields([2], {
        "r": rt,
        "s": struct2
    })

    self.assertEqual(struct3.shape.as_list(), [])
    self.assertEqual(struct3.rank, 0)
    self.assertEqual(set(struct3.field_names()), set(["r", "s"]))
    self.assertAllEqual(struct3.field_value("r"), rt)
    self.assertAllEqual(struct3.field_value("s"), struct1)

    self.assertEqual(struct4.shape.as_list(), [2])
    self.assertEqual(struct4.rank, 1)
    self.assertEqual(set(struct4.field_names()), set(["r", "s"]))
    self.assertAllEqual(struct4.field_value("r"), rt)
    self.assertAllEqual(struct4.field_value("s"), struct2)

  @parameterized.parameters([
      (object(), {}, TypeError),
      ([], object(), TypeError, "fields must be a dictionary"),
      ([], {1: 2}, TypeError, "Unexpected type for key"),
      ([], {"x": object()}, TypeError, "Unexpected type for value"),
      (None, {}, ValueError, "StructuredTensor's shape must have known rank"),
      ([5], {"f": 5}, ValueError, r"Shapes \(5,\) and \(\) are not compatible"),
      ([None], {"x": [1], "y": []}, ValueError,
       r"Shapes \([01],\) and \([01],\) are not compatible"),
      ([], {"": 5}, ValueError, "Field name '' is not currently allowed."),
      ([], {"_": 5}, ValueError, "Field name '_' is not currently allowed."),
  ])  # pyformat: disable
  def testConstructionErrors(self, shape, fields, err, msg=None):
    """Checks invalid `shape`/`fields` raise `err`, matching `msg` if given."""
    # `assertRaisesRegexp` is a deprecated alias that was removed in
    # Python 3.12; use `assertRaisesRegex`, or a plain type check when no
    # message pattern is supplied.
    if msg is None:
      context = self.assertRaises(err)
    else:
      context = self.assertRaisesRegex(err, msg)
    with context:
      structured_tensor.StructuredTensor.from_fields(shape, fields)
示例#34
0
class RaggedTensorShapeTest(test_util.TensorFlowTestCase,
                            parameterized.TestCase):
    """Tests for RaggedTensorDynamicShape and ragged broadcasting helpers."""

    def assertShapeEq(self, x, y):
        """Asserts that two `RaggedTensorDynamicShape`s describe the same shape.

        Compares the partitioned dimension sizes level by level, then the inner
        dimension sizes.

        Args:
          x: A `RaggedTensorDynamicShape`.
          y: A `RaggedTensorDynamicShape`.
        """
        # Use assertIsInstance rather than a bare `assert`: bare asserts are
        # stripped under `python -O`, which would silently skip this check.
        self.assertIsInstance(x, RaggedTensorDynamicShape)
        self.assertIsInstance(y, RaggedTensorDynamicShape)
        self.assertLen(x.partitioned_dim_sizes, len(y.partitioned_dim_sizes))
        for x_dims, y_dims in zip(x.partitioned_dim_sizes,
                                  y.partitioned_dim_sizes):
            self.assertAllEqual(x_dims, y_dims)
        self.assertAllEqual(x.inner_dim_sizes, y.inner_dim_sizes)

    @parameterized.parameters([
        dict(value='x', expected_dim_sizes=[]),
        dict(value=['a', 'b', 'c'], expected_dim_sizes=[3]),
        dict(value=[['a', 'b', 'c'], ['d', 'e', 'f']],
             expected_dim_sizes=[2, 3]),
        dict(value=[[['a', 'b', 'c'], ['d', 'e', 'f']]],
             expected_dim_sizes=[1, 2, 3]),
        dict(value=ragged_factory_ops.constant_value([['a', 'b', 'c'],
                                                      ['d', 'e']]),
             expected_dim_sizes=[2, [3, 2]]),
        dict(value=ragged_factory_ops.constant_value([[['a', 'b', 'c'],
                                                       ['d', 'e']]]),
             expected_dim_sizes=[1, [2], [3, 2]]),
        dict(value=ragged_factory_ops.constant_value(
            [[['a', 'b', 'c'], ['d', 'e', 'f']]], ragged_rank=1),
             expected_dim_sizes=[1, [2], 3]),
        dict(value=ragged_factory_ops.constant_value(
            [[[[1], [2]], [[3], [4]]], [[[5], [6]]]], ragged_rank=1),
             expected_dim_sizes=[2, [2, 1], 2, 1]),
        dict(value=ragged_factory_ops.constant_value([[10, 20], [30]]),
             expected_dim_sizes=[2, [2, 1]]),
        # Docstring examples:
        dict(value=[[1, 2, 3], [4, 5, 6]], expected_dim_sizes=[2, 3]),
        dict(value=ragged_factory_ops.constant_value([[1, 2], [], [3, 4, 5]]),
             expected_dim_sizes=[3, [2, 0, 3]]),
        dict(value=ragged_factory_ops.constant_value(
            [[[1, 2], [3, 4]], [[5, 6]]], ragged_rank=1),
             expected_dim_sizes=[2, [2, 1], 2]),
        dict(value=ragged_factory_ops.constant_value([[[1, 2], [3]],
                                                      [[4, 5]]]),
             expected_dim_sizes=[2, [2, 1], [2, 1, 2]]),
    ])
    def testFromTensor(self, value, expected_dim_sizes):
        """Checks `from_tensor` against a shape built from explicit sizes."""
        shape = RaggedTensorDynamicShape.from_tensor(value)
        expected = RaggedTensorDynamicShape.from_dim_sizes(expected_dim_sizes)
        self.assertShapeEq(shape, expected)

    @parameterized.parameters([
        dict(dim_sizes=[], rank=0, expected_dim_sizes=[]),
        dict(dim_sizes=[], rank=3, expected_dim_sizes=[1, 1, 1]),
        dict(dim_sizes=[3], rank=1, expected_dim_sizes=[3]),
        dict(dim_sizes=[3], rank=3, expected_dim_sizes=[1, 1, 3]),
        dict(dim_sizes=[2, 3], rank=3, expected_dim_sizes=[1, 2, 3]),
        dict(dim_sizes=[3, [3, 2, 4]],
             rank=2,
             expected_dim_sizes=[3, [3, 2, 4]]),
        dict(dim_sizes=[3, [3, 2, 4]],
             rank=4,
             expected_dim_sizes=[1, 1, 3, [3, 2, 4]]),
        dict(dim_sizes=[3, [3, 2, 4], 2, 3],
             rank=5,
             expected_dim_sizes=[1, 3, [3, 2, 4], 2, 3]),
    ])
    def testBroadcastToRank(self, dim_sizes, rank, expected_dim_sizes):
        """Checks that `broadcast_to_rank` prepends size-1 dimensions."""
        shape = RaggedTensorDynamicShape.from_dim_sizes(dim_sizes)
        expected = RaggedTensorDynamicShape.from_dim_sizes(expected_dim_sizes)
        broadcasted_shape = shape.broadcast_to_rank(rank)
        self.assertShapeEq(broadcasted_shape, expected)
        self.assertEqual(broadcasted_shape.rank, rank)

    @parameterized.parameters([
        #=========================================================================
        # dimension[axis] is uniform inner; and row_lengths is a scalar
        #=========================================================================
        # shape: [BROADCAST(UNIFORM), UNIFORM, UNIFORM]
        dict(axis=0,
             row_length=3,
             original_dim_sizes=[1, 4, 5],
             broadcast_dim_sizes=[3, 4, 5]),

        # shape: [UNIFORM, UNIFORM, BROADCAST(UNIFORM)]
        dict(axis=2,
             row_length=5,
             original_dim_sizes=[3, 4, 1],
             broadcast_dim_sizes=[3, 4, 5]),

        # shape: [UNIFORM, RAGGED, BROADCAST(UNIFORM)]
        dict(axis=2,
             row_length=5,
             original_dim_sizes=[3, [3, 2, 8], 1],
             broadcast_dim_sizes=[3, [3, 2, 8], 5]),

        # shape: [UNIFORM, RAGGED, RAGGED, UNIFORM, UNIFORM, BROADCAST(UNIFORM)]
        dict(axis=5,
             row_length=5,
             original_dim_sizes=[2, [2, 1], [3, 2, 8], 3, 4, 1],
             broadcast_dim_sizes=[2, [2, 1], [3, 2, 8], 3, 4, 5]),

        #=========================================================================
        # dimension[axis] is uniform inner; and row_lengths is a vector
        #=========================================================================
        # shape: [UNIFORM, BROADCAST(UNIFORM)]
        dict(axis=1,
             row_length=[2, 0, 1],
             original_dim_sizes=[3, 1],
             broadcast_dim_sizes=[3, [2, 0, 1]]),
        # shape: [UNIFORM, BROADCAST(UNIFORM), UNIFORM]
        dict(axis=1,
             row_length=[2, 0, 1],
             original_dim_sizes=[3, 1, 5],
             broadcast_dim_sizes=[3, [2, 0, 1], 5]),

        # shape: [UNIFORM, UNIFORM, BROADCAST(UNIFORM)]
        dict(axis=2,
             row_length=[2, 0, 1, 3, 8, 2, 3, 4, 1, 8, 7, 0],
             original_dim_sizes=[4, 3, 1],
             broadcast_dim_sizes=[4, 3, [2, 0, 1, 3, 8, 2, 3, 4, 1, 8, 7, 0]]),

        # shape: [UNIFORM, RAGGED, BROADCAST(UNIFORM)]
        dict(axis=2,
             row_length=[2, 5, 3],
             original_dim_sizes=[2, [2, 1], 1],
             broadcast_dim_sizes=[2, [2, 1], [2, 5, 3]]),

        # shape: [UNIFORM, RAGGED, UNIFORM, UNIFORM, BROADCAST(UNIFORM), UNIFORM]
        dict(axis=4,
             row_length=list(range(18)),
             original_dim_sizes=[2, [2, 1], 3, 2, 1, 8],
             broadcast_dim_sizes=[2, [2, 1], 3, 2,
                                  list(range(18)), 8]),

        #=========================================================================
        # dimension[axis] is uniform partitioned; and row_lengths is a scalar
        #=========================================================================
        # shape: [BROADCAST(UNIFORM), RAGGED]
        dict(axis=0,
             row_length=3,
             original_dim_sizes=[1, [5]],
             broadcast_dim_sizes=[3, [5, 5, 5]]),

        # shape: [BROADCAST(UNIFORM), UNIFORM, RAGGED]
        dict(axis=0,
             row_length=2,
             original_dim_sizes=[1, 3, [3, 0, 2]],
             broadcast_dim_sizes=[2, 3, [3, 0, 2, 3, 0, 2]]),

        # shape: [BROADCAST(UNIFORM), RAGGED, RAGGED, UNIFORM, UNIFORM]
        dict(axis=0,
             row_length=3,
             original_dim_sizes=[1, [3], [3, 5, 2], 9, 4, 5],
             broadcast_dim_sizes=[
                 3, [3, 3, 3], [3, 5, 2, 3, 5, 2, 3, 5, 2], 9, 4, 5
             ]),

        # shape: [BROADCAST(UNIFORM), UNIFORM, RAGGED, UNIFORM]
        dict(axis=0,
             row_length=2,
             original_dim_sizes=[1, 2, [2, 1], [3, 5, 2], 2],
             broadcast_dim_sizes=[2, 2, [2, 1, 2, 1], [3, 5, 2, 3, 5, 2], 2]),

        # shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED, UNIFORM]
        dict(axis=1,
             row_length=2,
             original_dim_sizes=[3, 1, [4, 0, 2], 5],
             broadcast_dim_sizes=[3, 2, [4, 0, 2, 4, 0, 2], 5]),

        # shape: [UNIFORM, UNIFORM, RAGGED]; row_length=1 leaves it unchanged.
        dict(axis=1,
             row_length=1,
             original_dim_sizes=[2, 3, (1, 2, 3, 4, 5, 6)],
             broadcast_dim_sizes=[2, 3, (1, 2, 3, 4, 5, 6)]),

        #=========================================================================
        # dimension[axis] is uniform partitioned; and row_lengths is a vector
        #=========================================================================
        # shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED, UNIFORM]
        dict(
            axis=1,
            row_length=[4, 1, 2],
            original_dim_sizes=[
                3,  # axis=0
                1,  # axis=1 (broadcast)
                [3, 1, 2],  # axis=2
                5,  # axis=3
            ],
            broadcast_dim_sizes=[
                3,  # axis=0
                [4, 1, 2],  # axis=1 (broadcast)
                [3, 3, 3, 3, 1, 2, 2],  # axis=2
                5,  # axis=3
            ]),

        # shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED, RAGGED]
        dict(
            axis=1,
            row_length=[2, 0, 3],
            original_dim_sizes=[
                3,  # axis=0
                1,  # axis=1 (broadcast)
                [3, 1, 2],  # axis=2
                [3, 1, 4, 1, 5, 9],  # axis=3
            ],
            broadcast_dim_sizes=[
                3,  # axis=0
                [2, 0, 3],  # axis=1 (broadcast)
                [3, 3, 2, 2, 2],  # axis=2
                [3, 1, 4, 3, 1, 4, 5, 9, 5, 9, 5, 9],  # axis=3
            ]),

        # shape: [UNIFORM, RAGGED, BROADCAST(UNIFORM), RAGGED, RAGGED, UNIFORM]
        dict(
            axis=2,
            row_length=[4, 1, 2],
            original_dim_sizes=[
                3,  # axis=0
                [2, 0, 1],  # axis=1
                1,  # axis=2 (broadcast)
                [3, 2, 1],  # axis=3
                [1, 0, 1, 0, 2, 3],  # axis=4
                5,  # axis=5
            ],
            broadcast_dim_sizes=[
                3,  # axis=0
                [2, 0, 1],  # axis=1
                [4, 1, 2],  # axis=2 (broadcast)
                [3, 3, 3, 3, 2, 1, 1],  # axis=3
                [1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 2, 3, 3],  # axis=4
                5,  # axis=5
            ]),
        dict(axis=0,
             row_length=2,
             original_dim_sizes=[1, 1, 2, (2, 1)],
             broadcast_dim_sizes=[2, 1, 2, (2, 1, 2, 1)]),
        dict(axis=1,
             row_length=(2, 1),
             original_dim_sizes=[2, 1, 2, (2, 1, 2, 1)],
             broadcast_dim_sizes=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
        dict(axis=2,
             row_length=2,
             original_dim_sizes=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)],
             broadcast_dim_sizes=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
        dict(axis=3,
             row_length=(2, 1, 2, 1, 2, 1),
             original_dim_sizes=[2, (2, 1), 2, 1],
             broadcast_dim_sizes=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
    ])  # pyformat: disable
    def testBroadcastDimension(self, axis, row_length, original_dim_sizes,
                               broadcast_dim_sizes):
        """Tests for the broadcast_dimension method.

        Verifies that:

        * `original.broadcast_dimension(axis, row_length) == broadcast`
        * `broadcast.broadcast_dimension(axis, row_length) == broadcast`
        * `broadcast.broadcast_dimension(axis, 1) == broadcast`

        Args:
          axis: The axis to broadcast.
          row_length: The slice lengths to broadcast to.
          original_dim_sizes: The dimension sizes before broadcasting.
            original_dim_sizes[axis] should be equal to `1` or `row_length`.
          broadcast_dim_sizes: The dimension sizes after broadcasting.
        """
        original_shape = RaggedTensorDynamicShape.from_dim_sizes(
            original_dim_sizes)
        bcast_shape = RaggedTensorDynamicShape.from_dim_sizes(
            broadcast_dim_sizes)
        self.assertEqual(original_shape.rank, bcast_shape.rank)
        # shape[axis].value == 1 and row_length > 1:
        bcast1 = original_shape.broadcast_dimension(axis, row_length)
        # shape[axis].value > 1 and row_length == shape[axis].value:
        bcast2 = bcast_shape.broadcast_dimension(axis, row_length)
        # shape[axis].value > 1 and row_length == 1:
        bcast3 = bcast_shape.broadcast_dimension(axis, 1)

        self.assertShapeEq(bcast1, bcast_shape)
        self.assertShapeEq(bcast2, bcast_shape)
        self.assertShapeEq(bcast3, bcast_shape)

    @parameterized.parameters([
        # Broadcast scalar
        dict(x_dims=[], y_dims=[], expected_dims=[]),
        dict(x_dims=[], y_dims=[2], expected_dims=[2]),
        dict(x_dims=[], y_dims=[2, 3], expected_dims=[2, 3]),
        dict(x_dims=[],
             y_dims=[2, (2, 3), (5, 7, 2, 0, 9)],
             expected_dims=[2, (2, 3), (5, 7, 2, 0, 9)]),
        # Broadcast vector
        dict(x_dims=[3], y_dims=[4, 2, 3], expected_dims=[4, 2, 3]),
        dict(x_dims=[1], y_dims=[4, 2, 3], expected_dims=[4, 2, 3]),
        dict(x_dims=[3], y_dims=[4, 2, 1], expected_dims=[4, 2, 3]),
        dict(x_dims=[3],
             y_dims=[3, (2, 3, 1), 1],
             expected_dims=[3, (2, 3, 1), 3]),
        dict(x_dims=[1], y_dims=[3, (2, 1, 3)], expected_dims=[3, (2, 1, 3)]),
        dict(x_dims=[1],
             y_dims=[3, (2, 1, 3), 8],
             expected_dims=[3, (2, 1, 3), 8]),
        dict(x_dims=[1],
             y_dims=[2, (2, 3), (5, 7, 2, 0, 9)],
             expected_dims=[2, (2, 3), (5, 7, 2, 0, 9)]),
        # Mixed broadcasting
        dict(
            x_dims=[
                1,  # axis=0
                3,  # axis=1
                (3, 0, 2),  # axis=2
                1,  # axis=3
                2,  # axis=4
            ],
            y_dims=[
                2,  # axis=0
                1,  # axis=1
                1,  # axis=2
                (7, 2),  # axis=3
                1,  # axis=4
            ],
            expected_dims=[
                2,  # axis=0
                3,  # axis=1
                (3, 0, 2, 3, 0, 2),  # axis=2
                (7, 7, 7, 7, 7, 2, 2, 2, 2, 2),  # axis=3
                2,  # axis=4
            ]),
        dict(x_dims=[2, (2, 1), 2, 1],
             y_dims=[1, 1, 2, (2, 1)],
             expected_dims=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
    ])
    def testBroadcastDynamicShape(self, x_dims, y_dims, expected_dims):
        """Checks that `broadcast_dynamic_shape` is correct and symmetric."""
        x_shape = RaggedTensorDynamicShape.from_dim_sizes(x_dims)
        y_shape = RaggedTensorDynamicShape.from_dim_sizes(y_dims)
        expected = RaggedTensorDynamicShape.from_dim_sizes(expected_dims)
        result1 = ragged_tensor_shape.broadcast_dynamic_shape(x_shape, y_shape)
        result2 = ragged_tensor_shape.broadcast_dynamic_shape(y_shape, x_shape)
        self.assertShapeEq(expected, result1)
        self.assertShapeEq(expected, result2)

    def testRepr(self):
        """Checks the general layout of `repr(RaggedTensorDynamicShape)`."""
        shape = RaggedTensorDynamicShape.from_dim_sizes([2, (2, 1), 2, 1])
        self.assertRegex(
            repr(shape), r'RaggedTensorDynamicShape\('
            r'partitioned_dim_sizes=\(<[^>]+>, <[^>]+>\), '
            r'inner_dim_sizes=<[^>]+>\)')

    @parameterized.parameters([
        dict(
            x=[[10], [20], [30]],  # shape=[3, 1]
            dim_sizes=[3, 2],
            expected=[[10, 10], [20, 20], [30, 30]]),
        dict(
            x=[[10], [20], [30]],  # shape=[3, 1]
            dim_sizes=[3, [3, 0, 2]],
            expected=ragged_factory_ops.constant_value(
                [[10, 10, 10], [], [30, 30]], dtype=np.int32)),
        dict(
            x=[[[1, 2, 3]], [[4, 5, 6]]],  # shape = [2, 1, 3]
            dim_sizes=[2, [2, 3], 3],
            expected=ragged_factory_ops.constant_value(
                [[[1, 2, 3], [1, 2, 3]], [[4, 5, 6], [4, 5, 6], [4, 5, 6]]],
                dtype=np.int32,
                ragged_rank=1)),
        dict(
            x=[[[1]], [[2]]],  # shape = [2, 1, 1]
            dim_sizes=[2, [2, 3], [0, 2, 1, 2, 0]],
            expected=ragged_factory_ops.constant_value(
                [[[], [1, 1]], [[2], [2, 2], []]],
                dtype=np.int32,
                ragged_rank=2)),
        dict(x=10,
             dim_sizes=[3, [3, 0, 2]],
             expected=ragged_factory_ops.constant_value([[10, 10, 10], [],
                                                         [10, 10]])),
    ])
    def testRaggedBroadcastTo(self, x, dim_sizes, expected):
        """Checks `broadcast_to` values and ragged rank against `expected`."""
        shape = RaggedTensorDynamicShape.from_dim_sizes(dim_sizes)
        result = ragged_tensor_shape.broadcast_to(x, shape)
        # Dense values have no `ragged_rank` attribute; treat them as rank 0.
        self.assertEqual(getattr(result, 'ragged_rank', 0),
                         getattr(expected, 'ragged_rank', 0))
        self.assertAllEqual(result, expected)

    @parameterized.parameters([
        dict(doc='x.shape=[3, (D1)]; y.shape=[3, 1]; bcast.shape=[3, (D1)]',
             x=ragged_factory_ops.constant_value([[1, 2, 3], [], [4, 5]],
                                                 dtype=np.int32),
             y=[[10], [20], [30]],
             expected=ragged_factory_ops.constant_value([[11, 12, 13], [],
                                                         [34, 35]])),
        dict(doc='x.shape=[3, (D1)]; y.shape=[]; bcast.shape=[3, (D1)]',
             x=ragged_factory_ops.constant_value([[1, 2, 3], [], [4, 5]],
                                                 dtype=np.int32),
             y=10,
             expected=ragged_factory_ops.constant_value([[11, 12, 13], [],
                                                         [14, 15]])),
        dict(doc='x.shape=[1, (D1)]; y.shape=[3, 1]; bcast.shape=[3, (D1)]',
             x=ragged_factory_ops.constant_value([[1, 2, 3]], dtype=np.int32),
             y=[[10], [20], [30]],
             expected=ragged_factory_ops.constant_value(
                 [[11, 12, 13], [21, 22, 23], [31, 32, 33]], dtype=np.int32)),
        dict(doc=('x.shape=[2, (D1), 1]; y.shape=[1, (D2)]; '
                  'bcast.shape=[2, (D1), (D2)]'),
             x=ragged_factory_ops.constant_value([[[1], [2], [3]], [[4]]],
                                                 ragged_rank=1),
             y=ragged_factory_ops.constant_value([[10, 20, 30]]),
             expected=ragged_factory_ops.constant_value([[[11, 21, 31],
                                                          [12, 22, 32],
                                                          [13, 23, 33]],
                                                         [[14, 24, 34]]])),
        dict(doc=('x.shape=[2, (D1), 1]; y.shape=[1, 1, 4]; '
                  'bcast.shape=[2, (D1), 4]'),
             x=ragged_factory_ops.constant_value([[[10], [20]], [[30]]],
                                                 ragged_rank=1),
             y=[[[1, 2, 3, 4]]],
             expected=ragged_factory_ops.constant_value(
                 [[[11, 12, 13, 14], [21, 22, 23, 24]], [[31, 32, 33, 34]]],
                 ragged_rank=1)),
        dict(doc=('x.shape=[2, (D1), 2, 1]; y.shape=[2, (D2)]; '
                  'bcast.shape=[2, (D1), 2, (D2)]'),
             x=ragged_factory_ops.constant_value(
                 [[[[1], [2]], [[3], [4]]], [[[5], [6]]]], ragged_rank=1),
             y=ragged_factory_ops.constant_value([[10, 20], [30]]),
             expected=ragged_factory_ops.constant_value([[[[11, 21], [32]],
                                                          [[13, 23], [34]]],
                                                         [[[15, 25], [36]]]])),
    ])
    def testRaggedAddWithBroadcasting(self, x, y, expected, doc):
        """Checks `x + y` for ragged/dense operands that need broadcasting.

        Args:
          x: Left operand (converted to an int32 tensor or ragged tensor).
          y: Right operand (converted to an int32 tensor or ragged tensor).
          expected: Expected sum.
          doc: Human-readable description of the shapes; not used by the test.
        """
        expected_rrank = getattr(expected, 'ragged_rank', 0)
        x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            x, dtype=dtypes.int32)
        y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            y, dtype=dtypes.int32)
        result = x + y
        result_rrank = getattr(result, 'ragged_rank', 0)
        self.assertEqual(expected_rrank, result_rrank)
        if hasattr(expected, 'tolist'):
            expected = expected.tolist()
        self.assertAllEqual(result, expected)
示例#35
0
class RaggedElementwiseOpsTest(ragged_test_util.RaggedTensorTestCase,
                               parameterized.TestCase):
    def assertSameShape(self, x, y):
        """Asserts that `x` and `y` have identical shapes, ragged parts included.

        If `x` is a RaggedTensor, `y` must be one too, with the same ragged
        rank, the same row splits at every level, and flat_values of the same
        shape. Otherwise `y` must be a dense Tensor with the same shape as `x`.
        """
        if not isinstance(x, ragged_tensor.RaggedTensor):
            self.assertIsInstance(y, ops.Tensor)
            self.assertAllEqual(array_ops.shape(x), array_ops.shape(y))
            return

        self.assertIsInstance(y, ragged_tensor.RaggedTensor)
        self.assertEqual(x.ragged_rank, y.ragged_rank)
        paired_splits = zip(x.nested_row_splits, y.nested_row_splits)
        for lhs_splits, rhs_splits in paired_splits:
            self.assertAllEqual(lhs_splits, rhs_splits)
        lhs_inner_shape = array_ops.shape(x.flat_values)
        rhs_inner_shape = array_ops.shape(y.flat_values)
        self.assertAllEqual(lhs_inner_shape, rhs_inner_shape)

    @parameterized.parameters(
        #=========================================================================
        # Test different input shapes.
        #=========================================================================
        [
            # 0-dimensional input
            {'x': 12},
            # 1-dimensional input
            {'x': [1, -2, 3]},
            # 2-dimensional input
            {'x': [[-2, 3], [-3, 4]]},
            {'x': ragged_factory_ops.constant_value([[-2, 3], [-3]],
                                                    ragged_rank=1)},
            # 3-dimensional inputs
            {'x': [[[-2, 3], [3, 4]], [[7, 6], [5, 4]]]},
            {'x': ragged_factory_ops.constant_value(
                [[[-2, 3], [3, 4]], [[7, 6]]], ragged_rank=1)},
            {'x': ragged_factory_ops.constant_value(
                [[[-2, 3, 4], []], [[7, 6]], []], ragged_rank=2)},
        ] +
        #=========================================================================
        # Test each unary op.
        #=========================================================================
        [{'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
          'op': op} for op in UNARY_FLOAT_OPS] +
        [{'x': ragged_factory_ops.constant_value([[True, False], [True]]),
          'op': op} for op in UNARY_BOOL_OPS] +
        [{'x': ragged_factory_ops.constant_value([[18, 512], [12412]],
                                                 np.int32),
          'op': op} for op in UNARY_INT_OPS] +
        [{'x': ragged_factory_ops.constant_value([['abcd', 'efgh'],
                                                  ['aabbccdd']]),
          'op': op} for op in UNARY_STRING_OPS] +
        [
            {'op': clip_ops.clip_by_value,
             'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
             'clip_value_min': 0.1,
             'clip_value_max': 4.0},
            {'op': math_ops.cast,
             'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
             'dtype': dtypes.int32},
            {'op': math_ops.saturate_cast,
             'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
             'dtype': dtypes.int32},
            {'op': string_ops.string_to_hash_bucket,
             'x': ragged_factory_ops.constant_value([['abcd', 'efgh'],
                                                     ['aabbccdd']]),
             'num_buckets': 1000},
            {'op': string_ops.string_to_hash_bucket_fast,
             'x': ragged_factory_ops.constant_value([['abcd', 'efgh'],
                                                     ['aabbccdd']]),
             'num_buckets': 1000},
            {'op': string_ops.string_to_hash_bucket_strong,
             'x': ragged_factory_ops.constant_value([['abcd', 'efgh'],
                                                     ['aabbccdd']]),
             'num_buckets': 1000,
             'key': [1231, 12512]},
            {'op': string_ops.string_to_number,
             'x': ragged_factory_ops.constant_value([['-2.0', '3.0'],
                                                     ['-3.0']])},
            {'op': string_ops.regex_full_match,
             'x': ragged_factory_ops.constant_value([['hello', '123'],
                                                     ['1+1']]),
             'pattern': r'\w+'},
            {'op': string_ops.regex_replace,
             'x': ragged_factory_ops.constant_value([['hello', '123'],
                                                     ['1+1']]),
             'pattern': r'\d',
             'rewrite': '#'},
            {'op': string_ops.substr,
             'x': ragged_factory_ops.constant_value([['hello', '123'],
                                                     ['1+1']]),
             'pos': 2,
             'len': 3},
            {'op': array_ops.check_numerics,
             'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
             'message': 'check-numerics'},
        ])  # pyformat: disable
    def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args):
        """Checks that a unary op on a (possibly ragged) input matches dense.

        The op is applied to the converted input and, separately, to its dense
        values (`flat_values` for ragged inputs). The result must keep the
        input's shape, and its flattened values must equal the dense result.

        Args:
          x: The input value, converted with convert_to_tensor_or_ragged_tensor.
          op: The elementwise op under test.
          **extra_args: Additional keyword arguments forwarded to `op`.
        """
        x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x)
        result = op(x, **extra_args)

        # Reference computation: apply the op directly to the dense values.
        input_is_ragged = isinstance(x, ragged_tensor.RaggedTensor)
        dense_input = x.flat_values if input_is_ragged else x
        reference = op(dense_input, **extra_args)
        expected_flat_values = array_ops.reshape(reference, [-1])

        # The op must preserve the (ragged) shape of its input.
        self.assertSameShape(x, result)

        # Flatten the result and compare it with the reference values.
        if isinstance(result, ragged_tensor.RaggedTensor):
            dense_result = result.flat_values
        else:
            dense_result = result
        result_flat_values = array_ops.reshape(dense_result, [-1])
        self.assertAllEqual(expected_flat_values, result_flat_values)

    @parameterized.parameters([
        #=====================================================================
        # Without broadcasting -- i.e., shapes match exactly.
        #=====================================================================
        # Shapes: x:(), y:()
        {
            'x': 12,
            'y': 8
        },
        # Shapes: x:(3,), y:(3,)
        {
            'x': [7, 8, 9],
            'y': [1, -2, 3]
        },
        # Shapes: x:(2, 2), y:(2, 2)
        {
            'x': [[-2, 3], [-3, -4]],
            'y': [[1, 2], [3, 4]]
        },
        # Shapes: x:(2, None), y:(2, None)
        {
            'x': ragged_factory_ops.constant_value([[-2, 3], [-3]]),
            'y': ragged_factory_ops.constant_value([[5, 6], [7]])
        },
        # Shapes: x:(2, 2, 2), y:(2, 2, 2)
        {
            'x': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
            'y': [[[9, 3], [3, 4]], [[5, 2], [7, 6]]]
        },
        # Shapes: x:(2, None, None), y: (2, None, None)
        {
            'x':
            ragged_factory_ops.constant_value([[[1, 2], [3], [4]],
                                               [[], [5, 7, 8]]]),
            'y':
            ragged_factory_ops.constant_value([[[3, 8], [2], [5]],
                                               [[], [1, 9, 8]]])
        },
        # Shapes: x:(2, None, 2), y: (2, None, 2)
        {
            'x':
            ragged_factory_ops.constant_value(
                [[[1, 2]], [[3, 4], [5, 6], [7, 8]]], ragged_rank=1),
            'y':
            ragged_factory_ops.constant_value(
                [[[9, 3]], [[5, 2], [3, 4], [7, 6]]], ragged_rank=1)
        },

        #=====================================================================
        # With broadcasting
        #=====================================================================
        # Shapes: x:(), y:(3,)
        {
            'x': 12,  # Broadcast () -> (3,)
            'y': [1, -2, 3]
        },
        # Shapes: x:(1,), y:(3,)
        {
            'x': [12],  # Broadcast (1,) -> (3,)
            'y': [1, -2, 3]
        },
        # Shapes: x:(), y:(2, 2)
        {
            'x': 12,  # Broadcast () -> (2, 2)
            'y': [[1, 2], [3, 4]]
        },
        # Shapes: x:(1,), y:(2, 2)
        {
            'x': 12,  # Broadcast (1,) -> (2, 2)
            'y': [[1, 2], [3, 4]]
        },
        # Shapes: x:(2, 1), y:(2, 2)
        {
            'x': [[10], [20]],  # Broadcast (2, 1) -> (2, 2)
            'y': [[1, 2], [3, 4]]
        },
        # Shapes: x:(), y:(2, None)
        {
            'x': 10,  # Broadcast () -> (2, None)
            'y': ragged_factory_ops.constant_value([[1, 2], [3]],
                                                   dtype=np.int32)
        },
        # TODO(edloper): Add tests for more advanced broadcasting, once we add
        # support for it.

        #=====================================================================
        # Keyword Args
        #=====================================================================
        {
            'x':
            ragged_factory_ops.constant_value([[[1, 2], [3], [4]],
                                               [[], [5, 7, 8]]]),
            'y':
            ragged_factory_ops.constant_value([[[3, 8], [2], [5]],
                                               [[], [1, 9, 8]]]),
            'use_kwargs': ('x', 'y')
        },
        {
            'x':
            ragged_factory_ops.constant_value(
                [[[1, 2]], [[3, 4], [5, 6], [7, 8]]], ragged_rank=1),
            'y':
            ragged_factory_ops.constant_value(
                [[[9, 3]], [[5, 2], [3, 4], [7, 6]]], ragged_rank=1),
            'use_kwargs': ('x', 'y')
        },
        {
            'x':
            ragged_factory_ops.constant_value(
                [[[1, 2]], [[3, 4], [5, 6], [7, 8]]], ragged_rank=1),
            'y':
            ragged_factory_ops.constant_value(
                [[[9, 3]], [[5, 2], [3, 4], [7, 6]]], ragged_rank=1),
            'use_kwargs': ('x', )
        },
    ] +
                              #=========================================================================
                              # Test each unary op.
                              #=========================================================================
                              [{
                                  'x':
                                  ragged_factory_ops.constant_value(
                                      [[-2.0, 3.0], [-3.0]]),
                                  'y':
                                  ragged_factory_ops.constant_value(
                                      [[5.0, 1.0], [12.0]]),
                                  'op':
                                  op
                              } for op in BINARY_FLOAT_OPS] + [{
                                  'x':
                                  ragged_factory_ops.constant_value([[-2, 3],
                                                                     [-3]]),
                                  'y':
                                  ragged_factory_ops.constant_value([[5, 1],
                                                                     [12]]),
                                  'op':
                                  op
                              } for op in BINARY_INT_OPS] + [{
                                  'x':
                                  ragged_factory_ops.constant_value(
                                      [[True, True], [False]]),
                                  'y':
                                  ragged_factory_ops.constant_value(
                                      [[False, True], [False]]),
                                  'op':
                                  op
                              } for op in BINARY_BOOL_OPS]
                              )  # pyformat: disable
    def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args):
        """Checks that a binary elementwise op handles ragged operands.

        The op is run once on the (possibly ragged) operands and once on their
        flat values. The flattened results must match, and the result must
        have the same shape as `y`.

        `extra_args` may carry a `use_kwargs` tuple naming the operands to pass
        by keyword. Any remaining entries are forwarded to `op` unchanged.
        """
        keyword_operands = extra_args.pop('use_kwargs', ())
        lhs = ragged_tensor.convert_to_tensor_or_ragged_tensor(x)
        rhs = ragged_tensor.convert_to_tensor_or_ragged_tensor(y)

        # Python cannot pass `x` by keyword while `y` stays positional, so a
        # request to pass only `x` by keyword falls back to a positional call.
        if 'y' not in keyword_operands:
            result = op(lhs, rhs, **extra_args)
        elif 'x' in keyword_operands:
            result = op(x=lhs, y=rhs, **extra_args)
        else:
            result = op(lhs, y=rhs, **extra_args)

        def _flat(t):
            # Ragged tensors are reduced to their innermost `flat_values`;
            # dense tensors are used as-is.
            if isinstance(t, ragged_tensor.RaggedTensor):
                return t.flat_values
            return t

        # Reference result: the same op applied directly to the dense values.
        expected_flat_values = array_ops.reshape(
            op(_flat(lhs), _flat(rhs), **extra_args), [-1])

        self.assertSameShape(rhs, result)
        self.assertAllEqual(expected_flat_values,
                            array_ops.reshape(_flat(result), [-1]))

    @parameterized.parameters([
        {
            'inputs': (12, 8, 3)
        },
        {
            'inputs': ([1, 2, 3], [7, 8, 9], [3, 6, 9])
        },
        {
            'inputs': ([[1, 2]], [[3, 4]], [[5, 6]])
        },
        {
            'inputs': (ragged_factory_ops.constant_value([[1, 3], [-3]]),
                       ragged_factory_ops.constant_value([[4, 7], [88]]),
                       ragged_factory_ops.constant_value([[2, 9], [12]]))
        },
        {
            'inputs':
            (ragged_factory_ops.constant_value([[[1, 3], [-3]], [[1]]]),
             ragged_factory_ops.constant_value([[[4, 7], [88]], [[2]]]),
             ragged_factory_ops.constant_value([[[2, 9], [12]], [[8]]]))
        },
        {
            'inputs':
            (ragged_factory_ops.constant_value([[[1, 3], [3, 4]], [[1, 5]]],
                                               ragged_rank=1),
             ragged_factory_ops.constant_value([[[4, 7], [1, 2]], [[2, 2]]],
                                               ragged_rank=1),
             ragged_factory_ops.constant_value([[[2, 9], [5, 2]], [[8, 0]]],
                                               ragged_rank=1))
        },
        {
            'inputs':
            (ragged_factory_ops.constant_value([[[1, 3], [-3]], [[1]]]),
             ragged_factory_ops.constant_value([[[4, 7], [88]], [[2]]]),
             ragged_factory_ops.constant_value([[[2, 9], [12]], [[8]]])),
            'use_kwargs':
            True
        },
    ] + [
        {
            'op':
            math_ops.add_n,
            'inputs': (ragged_factory_ops.constant_value([[1, 3], [-3]]),
                       ragged_factory_ops.constant_value([[4, 7], [88]]),
                       ragged_factory_ops.constant_value([[2, 9], [12]]))
        },
        {
            'op':
            string_ops.string_join,
            'inputs':
            (ragged_factory_ops.constant_value([['a', 'b'], ['c']]),
             ragged_factory_ops.constant_value([['foo', 'bar'], ['baz']]),
             ragged_factory_ops.constant_value([['2', '9'], ['12']]))
        },
    ])  # pyformat: disable
    def testListValuedElementwiseOp(self,
                                    inputs,
                                    op=math_ops.add_n,
                                    **extra_args):
        """Checks a list-valued elementwise op on (possibly ragged) inputs.

        The op is applied to the converted inputs and, separately, to their
        flat values. The flattened results must agree, and the result must
        have the same shape as the first input. A truthy `use_kwargs` entry in
        `extra_args` passes the list as `inputs=`; the rest is forwarded to
        `op`.
        """
        pass_as_kwarg = extra_args.pop('use_kwargs', False)
        converted = [
            ragged_tensor.convert_to_tensor_or_ragged_tensor(t) for t in inputs
        ]

        if pass_as_kwarg:
            result = op(inputs=converted, **extra_args)
        else:
            result = op(converted, **extra_args)

        def _flat(t):
            # Ragged tensors contribute their innermost values; dense ones
            # are used unchanged.
            if isinstance(t, ragged_tensor.RaggedTensor):
                return t.flat_values
            return t

        # Reference result: the same op run on the dense flat values.
        expected_flat_values = array_ops.reshape(
            op([_flat(t) for t in converted], **extra_args), [-1])

        self.assertSameShape(converted[0], result)
        self.assertAllEqual(expected_flat_values,
                            array_ops.reshape(_flat(result), [-1]))

    def testElementwiseOpUnknownRankError(self):
        """Adding a ragged tensor whose values have unknown rank must fail.

        `y` wraps a `placeholder_with_default(..., shape=None)`, whose static
        rank is only unknown while building a graph. In eager mode the test
        therefore returns early.
        """
        if context.executing_eagerly():
            return
        x = ragged_factory_ops.constant([[1, 2], [3]])
        y = ragged_tensor.RaggedTensor.from_row_splits(
            array_ops.placeholder_with_default([1, 2, 3], shape=None),
            x.row_splits)
        # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
        with self.assertRaisesRegex(ValueError,
                                    r'Unable to broadcast: unknown rank'):
            math_ops.add(x, y)

    @parameterized.parameters([
        dict(x=ragged_factory_ops.constant_value([[1, 2], [3]]),
             y=[[10]],
             expected=[[11, 12], [13]]),
        dict(x=ragged_factory_ops.constant_value([[[1, 2], [3, 4]], [[5]]],
                                                 ragged_rank=2),
             y=ragged_factory_ops.constant_value([[[10], [20]], [[30]]],
                                                 ragged_rank=1),
             expected=[[[11, 12], [23, 24]], [[35]]]),
        dict(x=ragged_factory_ops.constant_value([[[1]]]),
             y=ragged_factory_ops.constant_value([[1]]),
             expected=[[[2]]]),
    ])
    def testElementwiseOpBroadcast(self, x, y, expected):
        """Checks that `x + y` broadcasts ragged/dense operands as expected.

        Both operands are converted to int32 (ragged or dense) tensors first.
        """
        lhs, rhs = (
            ragged_tensor.convert_to_tensor_or_ragged_tensor(
                operand, dtype=dtypes.int32)
            for operand in (x, y))
        self.assertRaggedEqual(lhs + rhs, expected)

    def testElementwiseOpShapeMismatch(self):
        """Adding ragged tensors whose second rows differ in length must fail."""
        ragged_a = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])
        ragged_b = ragged_factory_ops.constant([[1, 2, 3], [4, 5, 6]])
        with self.assertRaises(errors.InvalidArgumentError):
            summed = math_ops.add(ragged_a, ragged_b)
            self.evaluate(summed)

    def testBinaryOpSparseAndRagged(self):
        """Mixing a SparseTensor with a RaggedTensor is rejected by add/add_n."""
        ragged = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])
        sparse = sparse_tensor.SparseTensor([[0, 0], [0, 1], [2, 0]], [1, 2, 3],
                                            [3, 2])
        # Each attempt builds the op lazily so any construction-time error is
        # raised inside its own assertRaises context.
        attempts = (lambda: math_ops.add(ragged, sparse),
                    lambda: math_ops.add_n([ragged, sparse]))
        for attempt in attempts:
            with self.assertRaises((TypeError, ValueError)):
                self.evaluate(attempt())

    @parameterized.parameters([
        dict(op=array_ops.batch_gather,
             args=(ragged_factory_ops.constant_value([[5, 6, 7], [8, 9]]),
                   ragged_factory_ops.constant_value([[2, 1, 0], [1]])),
             expected=ragged_factory_ops.constant_value([[7, 6, 5], [9]])),
        dict(op=array_ops.concat,
             args=([
                 ragged_factory_ops.constant_value([[1, 2, 3], [4]],
                                                   dtype=np.int32),
                 np.array([[5, 6]], dtype=np.int32)
             ], ),
             kwargs={'axis': 0},
             expected=ragged_factory_ops.constant_value([[1, 2, 3], [4],
                                                         [5, 6]])),
        dict(op=array_ops.expand_dims,
             kwargs={
                 'input': ragged_factory_ops.constant_value([[1, 2], [3]]),
                 'axis': 0
             },
             expected=ragged_factory_ops.constant_value([[[1, 2], [3]]])),
        dict(
            op=array_ops.expand_dims_v2,
            kwargs={
                'input': ragged_factory_ops.constant_value([[1, 2], [3]]),
                'axis': -1
            },
            expected=ragged_factory_ops.constant_value([[[1], [2]], [[3]]],
                                                       ragged_rank=1),
        ),
        dict(op=array_ops.gather,
             kwargs={
                 'params': ragged_factory_ops.constant_value([[1, 2], [3]]),
                 'indices': [1, 0, 1]
             },
             expected=ragged_factory_ops.constant_value([[3], [1, 2], [3]])),
        dict(op=array_ops.gather_v2,
             kwargs={
                 'params': ragged_factory_ops.constant_value([[1, 2], [3]]),
                 'indices': ragged_factory_ops.constant_value([[1, 0], [1]])
             },
             expected=ragged_factory_ops.constant_value([[[3], [1, 2]],
                                                         [[3]]])),
        dict(op=array_ops.gather_nd,
             kwargs={
                 'params': ragged_factory_ops.constant_value([[7, 8], [9]]),
                 'indices': [[0, 1], [1, 0], [0, 0]]
             },
             expected=ragged_factory_ops.constant_value([8, 9, 7])),
        dict(op=array_ops.stack,
             args=([
                 ragged_factory_ops.constant_value([[1, 2, 3], [4]],
                                                   dtype=np.int32),
                 np.array([[5, 6]], dtype=np.int32)
             ], ),
             expected=ragged_factory_ops.constant_value([[[1, 2, 3], [4]],
                                                         [[5, 6]]])),
        dict(op=array_ops.tile,
             args=([
                 ragged_factory_ops.constant_value([[1, 2], [3]],
                                                   dtype=np.int32), [2, 3]
             ]),
             expected=ragged_factory_ops.constant_value([[1, 2, 1, 2, 1, 2],
                                                         [3, 3, 3],
                                                         [1, 2, 1, 2, 1, 2],
                                                         [3, 3, 3]])),
        dict(op=array_ops.where,
             args=(ragged_factory_ops.constant_value([[True, False], [True]]),
                   ragged_factory_ops.constant_value([[b'A', b'B'], [b'C']]),
                   ragged_factory_ops.constant_value([[b'a', b'b'], [b'c']])),
             expected=ragged_factory_ops.constant_value([[b'A', b'b'],
                                                         [b'C']])),
        dict(op=array_ops.where,
             args=(ragged_factory_ops.constant_value([[True, False],
                                                      [True]]), ),
             expected=[[0, 0], [1, 0]]),
        dict(op=math_ops.unsorted_segment_sum,
             kwargs={
                 'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
                 'segment_ids': ragged_factory_ops.constant_value([[0, 2],
                                                                   [0]]),
                 'num_segments': 3
             },
             expected=[4, 0, 2]),
        dict(op=math_ops.unsorted_segment_prod,
             kwargs={
                 'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
                 'segment_ids': ragged_factory_ops.constant_value([[0, 2],
                                                                   [0]]),
                 'num_segments': 3
             },
             expected=[3, 1, 2]),
        dict(op=math_ops.unsorted_segment_min,
             kwargs={
                 'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
                 'segment_ids': ragged_factory_ops.constant_value([[0, 1],
                                                                   [0]]),
                 'num_segments': 2
             },
             expected=[1, 2]),
        dict(op=math_ops.unsorted_segment_max,
             kwargs={
                 'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
                 'segment_ids': ragged_factory_ops.constant_value([[0, 1],
                                                                   [0]]),
                 'num_segments': 2
             },
             expected=[3, 2]),
        dict(op=math_ops.unsorted_segment_mean,
             kwargs={
                 'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
                 'segment_ids': ragged_factory_ops.constant_value([[0, 1],
                                                                   [0]]),
                 'num_segments': 2
             },
             expected=[2, 2]),
        dict(op=math_ops.unsorted_segment_sqrt_n,
             kwargs={
                 'data':
                 ragged_factory_ops.constant_value([[1.0, 2.0],
                                                    [3.0, 4.0, 6.0]]),
                 'segment_ids':
                 ragged_factory_ops.constant_value([[0, 1], [0, 0, 0]]),
                 'num_segments':
                 2
             },
             expected=[7.0, 2.0]),
        dict(op=math_ops.reduce_sum,
             kwargs={
                 'input_tensor':
                 ragged_factory_ops.constant_value([[1, 2], [3, 4, 5]]),
                 'axis':
                 1
             },
             expected=[3, 12]),
        dict(op=math_ops.reduce_prod,
             kwargs={
                 'input_tensor':
                 ragged_factory_ops.constant_value([[1, 2], [3, 4, 5]]),
                 'axis':
                 1
             },
             expected=[2, 60]),
        dict(op=math_ops.reduce_min,
             kwargs={
                 'input_tensor':
                 ragged_factory_ops.constant_value([[1, 2], [3, 4, 5]]),
                 'axis':
                 1
             },
             expected=[1, 3]),
        dict(op=math_ops.reduce_max,
             kwargs={
                 'input_tensor':
                 ragged_factory_ops.constant_value([[1, 2], [3, 4, 5]]),
                 'axis':
                 1
             },
             expected=[2, 5]),
        dict(op=math_ops.reduce_mean,
             kwargs={
                 'input_tensor':
                 ragged_factory_ops.constant_value([[1, 3], [3, 4, 5]]),
                 'axis':
                 1
             },
             expected=[2, 4]),
        dict(op=math_ops.reduce_any,
             kwargs={
                 'input_tensor':
                 ragged_factory_ops.constant_value([[True, False],
                                                    [True, True, True]]),
                 'axis':
                 1
             },
             expected=[True, True]),
        dict(op=math_ops.reduce_all,
             kwargs={
                 'input_tensor':
                 ragged_factory_ops.constant_value([[True, False],
                                                    [True, True, True]]),
                 'axis':
                 1
             },
             expected=[False, True]),
    ])
    def testRaggedDispatch(self, op, expected, args=(), kwargs=None):
        """Checks that a standard TF op called with ragged arguments gives `expected`.

        `args` and `kwargs` are forwarded to `op` as positional and keyword
        arguments. A missing `kwargs` means no keyword arguments.
        """
        call_kwargs = {} if kwargs is None else kwargs
        self.assertRaggedEqual(op(*args, **call_kwargs), expected)
示例#36
0
class RaggedGatherNdOpTest(ragged_test_util.RaggedTensorTestCase,
                           parameterized.TestCase):
    """Tests for `ragged_array_ops.gather_nd` on ragged and dense inputs."""

    # Each leaf string spells out its own index path into this nested list,
    # e.g. '112' is DOCSTRING_PARAMS[1][1][2] and '210' is [2][1][0].
    DOCSTRING_PARAMS = [[['000', '001'], ['010']],
                        [['100'], ['110', '111', '112'], ['120']],
                        [[], ['210']]]  # pyformat: disable

    @parameterized.parameters([
        #=========================================================================
        # Docstring Examples
        #=========================================================================
        dict(descr='Docstring example 1',
             params=ragged_factory_ops.constant_value(DOCSTRING_PARAMS),
             indices=[[2], [0]],
             expected=ragged_factory_ops.constant_value([[[], [b'210']],
                                                         [[b'000', b'001'],
                                                          [b'010']]])),
        dict(descr='Docstring example 2',
             params=ragged_factory_ops.constant_value(DOCSTRING_PARAMS),
             indices=[[2, 1], [0, 0]],
             expected=ragged_factory_ops.constant_value([[b'210'],
                                                         [b'000', b'001']])),
        dict(descr='Docstring example 3',
             params=ragged_factory_ops.constant_value(DOCSTRING_PARAMS),
             indices=[[0, 0, 1], [1, 1, 2]],
             expected=[b'001', b'112']),
        #=========================================================================
        # Indices with 0 values (selects the entire params)
        #=========================================================================
        dict(descr='params: [B1, (B2)], indices: [0], result: [B1, (B2)]',
             params=ragged_factory_ops.constant_value([['a', 'b', 'c'],
                                                       ['d']]),
             indices=np.zeros([0], dtype=np.int32),
             expected=ragged_factory_ops.constant_value([[b'a', b'b', b'c'],
                                                         [b'd']])),
        dict(descr=
             'params: [B1, (B2)], indices: [A1, 0], result: [A1, B1, (B2)]',
             params=ragged_factory_ops.constant_value([['a', 'b', 'c'],
                                                       ['d']]),
             indices=np.zeros([3, 0], dtype=np.int32),
             expected=ragged_factory_ops.constant_value([[[b'a', b'b', b'c'],
                                                          [b'd']],
                                                         [[b'a', b'b', b'c'],
                                                          [b'd']],
                                                         [[b'a', b'b', b'c'],
                                                          [b'd']]])),
        dict(descr=('params: [B1, (B2)], indices: [A1, A2, 0], '
                    'result: [A1, A2, B1, (B2)]'),
             params=ragged_factory_ops.constant_value([['a', 'b', 'c'],
                                                       ['d']]),
             indices=np.zeros([1, 3, 0], dtype=np.int32),
             expected=ragged_factory_ops.constant_value([[[[b'a', b'b', b'c'],
                                                           [b'd']],
                                                          [[b'a', b'b', b'c'],
                                                           [b'd']],
                                                          [[b'a', b'b', b'c'],
                                                           [b'd']]]])),
        dict(descr=
             'params: [B1], indices: [A1, (A2), 0], result: [A1, (A2), B1]',
             params=['a'],
             indices=ragged_factory_ops.constant_value([[[], []], [[]]],
                                                       ragged_rank=1,
                                                       dtype=np.int32),
             expected=ragged_factory_ops.constant_value(
                 [[[b'a'], [b'a']], [[b'a']]], ragged_rank=1)),
        #=========================================================================
        # Indices with 1 value (selects row from params)
        #=========================================================================
        dict(descr='params: [B1, (B2)], indices: [A1, 1], result: [A1, (B2)]',
             params=ragged_factory_ops.constant_value([['a', 'b', 'c'],
                                                       ['d']]),
             indices=[[1], [0]],
             expected=ragged_factory_ops.constant_value([[b'd'],
                                                         [b'a', b'b', b'c']])),
        dict(descr=('params: [B1, (B2), (B3)], indices: [A1, 1], '
                    'result: [A1, (B2), (B3)]'),
             params=ragged_factory_ops.constant_value([[['a', 'b', 'c'],
                                                        ['d']], [['e', 'f']]]),
             indices=[[1], [1]],
             expected=ragged_factory_ops.constant_value([[[b'e', b'f']],
                                                         [[b'e', b'f']]])),
        dict(descr=('params: [B1, B2, B3], indices: [A1, (A2), 1], '
                    'result: [A1, (A2), B2, B3]'),
             params=[[['a']], [['b']]],
             indices=ragged_factory_ops.constant_value([[[0]]], ragged_rank=1),
             expected=ragged_factory_ops.constant_value([[[[b'a']]]],
                                                        ragged_rank=1)),
        #=========================================================================
        # Indices with 2 values (selects row & col from params)
        #=========================================================================
        dict(descr='params: [B1, (B2)], indices: [A1, 2], result: [A1]',
             params=ragged_factory_ops.constant_value([['a', 'b', 'c'],
                                                       ['d']]),
             indices=[[1, 0], [0, 0], [0, 2]],
             expected=ragged_factory_ops.constant_value([b'd', b'a', b'c'])),
        dict(descr=('params: [B1, (B2), (B3)], indices: [A1, 2], '
                    'result: [A1, (B3)]'),
             params=ragged_factory_ops.constant_value([[['a', 'b', 'c'],
                                                        ['d']], [['e', 'f']]]),
             indices=[[1, 0], [0, 1], [0, 0]],
             expected=ragged_factory_ops.constant_value([[b'e', b'f'], [b'd'],
                                                         [b'a', b'b', b'c']])),
        dict(descr=('params: [B1, (B2), (B3)], indices: [A1, A2, 2], '
                    'result: [A1, (A2), (B3)]'),
             params=ragged_factory_ops.constant_value([[['a', 'b', 'c'],
                                                        ['d']], [['e', 'f']]]),
             indices=[[[1, 0], [0, 1], [0, 0]]],
             expected=ragged_factory_ops.constant_value([[[b'e', b'f'], [b'd'],
                                                          [b'a', b'b',
                                                           b'c']]])),
        dict(descr=('params: [B1, (B2), B3], indices: [A1, A2, 2], '
                    'result: [A1, A2, B3]'),
             params=ragged_factory_ops.constant_value(
                 [[['a', 'b'], ['c', 'd']], [['e', 'f']]], ragged_rank=1),
             indices=[[[1, 0], [0, 1], [0, 0]]],
             expected=[[[b'e', b'f'], [b'c', b'd'], [b'a', b'b']]]),
        dict(descr=('params: [B1, (B2), B3], indices: [A1, A2, A3, 2], '
                    'result: [A1, A2, A3, B3]'),
             params=ragged_factory_ops.constant_value(
                 [[['a', 'b'], ['c', 'd']], [['e', 'f']]], ragged_rank=1),
             indices=[[[[1, 0], [0, 1], [0, 0]]]],
             expected=[[[[b'e', b'f'], [b'c', b'd'], [b'a', b'b']]]]),
        dict(descr=('params: [B1, (B2), (B3)], indices: [A1, (A2), 2], '
                    'result: [A1, (A2), (B3)]'),
             params=ragged_factory_ops.constant_value([[['a', 'b', 'c'],
                                                        ['d']], [['e', 'f']]]),
             indices=ragged_factory_ops.constant_value(
                 [[[1, 0], [0, 1]], [[0, 0]]], ragged_rank=1),
             expected=ragged_factory_ops.constant_value([[[b'e', b'f'],
                                                          [b'd']],
                                                         [[b'a', b'b',
                                                           b'c']]])),
        #=========================================================================
        # Indices with 3 values
        #=========================================================================
        dict(descr=('params: [B1, (B2), (B3)], indices: [A1, 3], '
                    'result: [A1]'),
             params=ragged_factory_ops.constant_value([[['a', 'b', 'c'],
                                                        ['d']], [['e', 'f']]]),
             indices=[[1, 0, 1], [0, 0, 0], [0, 1, 0]],
             expected=[b'f', b'a', b'd']),
        dict(descr=('params: [B1, (B2), B3], indices: [A1, 3], '
                    'result: [A1]'),
             params=ragged_factory_ops.constant_value(
                 [[['a', 'b'], ['c', 'd']], [['e', 'f']]], ragged_rank=1),
             indices=[[1, 0, 1], [0, 0, 0], [0, 1, 1]],
             expected=[b'f', b'a', b'd']),
        dict(descr=('params: [B1, (B2), (B3), B4], indices: [A1, 3], '
                    'result: [A1, B4]'),
             params=ragged_factory_ops.constant_value(
                 [[[['a', 'b'], ['c', 'd']], [['e', 'f']]]], ragged_rank=2),
             indices=[[0, 0, 1], [0, 0, 0], [0, 1, 0]],
             expected=[[b'c', b'd'], [b'a', b'b'], [b'e', b'f']]),
    ])  # pyformat: disable
    def testRaggedGatherNd(self, descr, params, indices, expected):
        """Checks `gather_nd` against the expected (possibly ragged) result.

        `descr` is not used in the body; it only labels the parameterized case.
        """
        result = ragged_array_ops.gather_nd(params, indices)
        self.assertRaggedEqual(result, expected)

    def testRaggedGatherNdUnknownRankError(self):
        """`gather_nd` rejects indices whose rank or last dim is not static.

        Placeholders with unknown shape only exist in graph mode, so the test
        returns early when executing eagerly.
        """
        if context.executing_eagerly():
            return
        params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd']])
        indices1 = array_ops.placeholder(dtypes.int32, shape=None)
        indices2 = array_ops.placeholder(dtypes.int32, shape=[None])

        # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
        with self.assertRaisesRegex(ValueError,
                                    'indices.rank be statically known.'):
            ragged_array_ops.gather_nd(params, indices1)
        with self.assertRaisesRegex(
                ValueError, r'indices.shape\[-1\] must be statically known.'):
            ragged_array_ops.gather_nd(params, indices2)

    @parameterized.parameters([
        dict(params=['a'],
             indices=0,
             error=(ValueError, errors.InvalidArgumentError)),
        dict(params=ragged_factory_ops.constant_value([['a']]),
             indices=0,
             message='indices.rank must be at least 1.'),
        dict(params=['a', 'b', 'c'],
             indices=ragged_factory_ops.constant_value([[0]]),
             message='The innermost dimension of indices may not be ragged'),
    ])
    def testRaggedGatherNdStaticError(self,
                                      params,
                                      indices,
                                      message=None,
                                      error=ValueError):
        """Invalid `params`/`indices` combinations raise `error`.

        Args:
          params: Value passed as `params` to `gather_nd`.
          indices: Value passed as `indices` to `gather_nd`.
          message: Regex the error message must match. When None, only the
            exception type is checked.
          error: Exception type, or tuple of types, that must be raised.
        """
        with self.assertRaisesRegex(error, message):
            ragged_array_ops.gather_nd(params, indices)
class RaggedBooleanMaskOpTest(test_util.TensorFlowTestCase,
                              parameterized.TestCase):
  """Tests for `ragged_array_ops.boolean_mask` with dense and ragged inputs."""

  # Define short constants for true & false, so the data & mask can be lined
  # up in the examples below.  This makes it easier to read the examples, to
  # see which values should be kept vs. masked.
  T = True
  F = False

  @parameterized.parameters([
      #=========================================================================
      # Docstring examples
      #=========================================================================
      dict(
          descr='Docstring example 1',
          data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
          mask=[[T, F, T], [F, F, F], [T, F, F]],
          expected=ragged_factory_ops.constant_value([[1, 3], [], [7]])),
      dict(
          descr='Docstring example 2',
          data=ragged_factory_ops.constant_value([[1, 2, 3], [4], [5, 6]]),
          mask=ragged_factory_ops.constant_value([[F, F, T], [F], [T, T]]),
          expected=ragged_factory_ops.constant_value([[3], [], [5, 6]])),
      dict(
          descr='Docstring example 3',
          data=ragged_factory_ops.constant_value([[1, 2, 3], [4], [5, 6]]),
          mask=[True, False, True],
          expected=ragged_factory_ops.constant_value([[1, 2, 3], [5, 6]])),
      #=========================================================================
      # Uniform data and uniform mask.
      #=========================================================================
      dict(
          descr='data.shape=[7]; mask.shape=[7]',
          data=[1, 2, 3, 4, 5, 6, 7],
          mask=[T, F, T, T, F, F, F],
          expected=[1, 3, 4]),
      dict(
          descr='data.shape=[5, 3]; mask.shape=[5]',
          data=[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]],
          mask=[True, False, True, True, False],
          expected=[[1, 2, 3], [7, 8, 9], [10, 11, 12]]),
      dict(
          descr='data.shape=[5, 3]; mask.shape=[5, 3]',
          data=[[1, 2, 3], [4, 5, 6], [7, 8, 9], [0, 1, 2], [3, 4, 5]],
          mask=[[F, F, F], [T, F, T], [T, T, T], [F, F, F], [T, T, F]],
          expected=ragged_factory_ops.constant_value(
              [[], [4, 6], [7, 8, 9], [], [3, 4]])),
      dict(
          descr='data.shape=[3, 2, 2]; mask.shape=[3]',
          data=[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[2, 4], [6, 8]]],
          mask=[F, F, T],
          expected=[[[2, 4], [6, 8]]]),
      dict(
          descr='data.shape=[3, 2, 2]; mask.shape=[3, 2]',
          data=[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[2, 4], [6, 8]]],
          mask=[[T, F], [T, T], [F, F]],
          expected=ragged_factory_ops.constant_value(
              [[[1, 2]], [[5, 6], [7, 8]], []],
              ragged_rank=1)),
      dict(
          descr='data.shape=[3, 2, 2]; mask.shape=[3, 2, 2]',
          data=[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[2, 4], [6, 8]]],
          mask=[[[T, T], [F, T]], [[F, F], [F, F]], [[T, F], [T, T]]],
          expected=ragged_factory_ops.constant_value(
              [[[1, 2], [4]], [[], []], [[2], [6, 8]]])),
      dict(
          descr='data.shape=mask.shape=[2, 2, 2, 2]',
          data=[[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                [[[2, 4], [6, 8]], [[1, 3], [5, 7]]]],
          mask=[[[[T, T], [F, F]], [[T, F], [F, F]]],
                [[[F, F], [F, F]], [[T, T], [T, F]]]],
          expected=ragged_factory_ops.constant_value(
              [[[[1, 2], []], [[5], []]], [[[], []], [[1, 3], [5]]]])),

      #=========================================================================
      # Ragged data and ragged mask.
      #=========================================================================
      dict(
          descr='data.shape=[5, (D2)]; mask.shape=[5, (D2)]',
          data=ragged_factory_ops.constant_value(
              [[1, 2], [3, 4, 5, 6], [7, 8, 9], [], [1, 2, 3]]),
          mask=ragged_factory_ops.constant_value(
              [[F, F], [F, T, F, T], [F, F, F], [], [T, F, T]]),
          expected=ragged_factory_ops.constant_value(
              [[], [4, 6], [], [], [1, 3]])),
      dict(
          descr='data.shape=[3, (D2), (D3)]; mask.shape=[3, (D2)]',
          data=ragged_factory_ops.constant_value(
              [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[2, 4], [6, 8]]]),
          mask=ragged_factory_ops.constant_value([[T, F], [T, T], [F, F]]),
          expected=ragged_factory_ops.constant_value(
              [[[1, 2]], [[5, 6], [7, 8]], []])),
      dict(
          descr='data.shape=[3, (D2), D3]; mask.shape=[3, (D2)]',
          data=ragged_factory_ops.constant_value(
              [[[1, 2], [3, 4]], [[5, 6], [7, 8], [2, 4]], [[6, 8]]],
              ragged_rank=1),
          mask=ragged_factory_ops.constant_value([[T, F], [T, T, F], [F]]),
          expected=ragged_factory_ops.constant_value(
              [[[1, 2]], [[5, 6], [7, 8]], []],
              ragged_rank=1)),
      dict(
          descr='data.shape=[3, (D2), (D3)]; mask.shape=[3, (D2), (D3)]',
          data=ragged_factory_ops.constant_value(
              [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[2, 4]]]),
          mask=ragged_factory_ops.constant_value(
              [[[T, T], [F, T]], [[F, F], [F, F]], [[T, F]]]),
          expected=ragged_factory_ops.constant_value(
              [[[1, 2], [4]], [[], []], [[2]]])),
      dict(
          descr=('data.shape=[3, (D2), (D3), (D4)]; '
                 'mask.shape=[3, (D2), (D3), (D4)]'),
          data=ragged_factory_ops.constant_value(
              [[[[1, 2], [3, 4]], [[5, 6]]], [[[2, 4], [6, 8]]]]),
          mask=ragged_factory_ops.constant_value(
              [[[[T, T], [F, F]], [[T, F]]], [[[F, F], [T, T]]]]),
          expected=ragged_factory_ops.constant_value(
              [[[[1, 2], []], [[5]]], [[[], [6, 8]]]])),

      #=========================================================================
      # Ragged mask and uniform data
      #=========================================================================
      dict(
          descr='data.shape=[2, 3]; mask.shape=[2, (3)]',
          data=[[1, 2, 3], [4, 5, 6]],
          mask=ragged_factory_ops.constant_value([[T, F, F], [F, T, T]]),
          expected=ragged_factory_ops.constant_value([[1], [5, 6]])),
      dict(
          descr='data.shape=[2, 3, 2]; mask.shape=[2, (3)]',
          data=[[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 0], [2, 4]]],
          mask=ragged_factory_ops.constant_value([[T, F, F], [F, T, T]]),
          expected=ragged_factory_ops.constant_value(
              [[[1, 2]], [[9, 0], [2, 4]]],
              ragged_rank=1)),
      dict(
          descr='data.shape=[2, 3, 2]; mask.shape=[2, (3), 2]',
          data=[[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 0], [2, 4]]],
          mask=ragged_factory_ops.constant_value(
              [[[T, F], [F, F], [T, T]], [[T, F], [F, T], [F, F]]],
              ragged_rank=1),
          expected=ragged_factory_ops.constant_value(
              [[[1], [], [5, 6]], [[7], [0], []]])),

      #=========================================================================
      # Ragged data and uniform mask.
      #=========================================================================
      dict(
          descr='data.shape=[4, (D2)]; mask.shape=[4]',
          data=ragged_factory_ops.constant_value([[1, 2, 3], [4], [], [5, 6]]),
          mask=[T, F, T, F],
          expected=ragged_factory_ops.constant_value([[1, 2, 3], []])),
      dict(
          descr='data.shape=[4, (D2), (D3)]; mask.shape=[4]',
          data=ragged_factory_ops.constant_value(
              [[[1, 2, 3]], [[4], []], [[5, 6]], []]),
          mask=[T, F, T, T],
          expected=ragged_factory_ops.constant_value(
              [[[1, 2, 3]], [[5, 6]], []])),
      dict(
          descr='data.shape=[4, (D2), 2]; mask.shape=[4]',
          data=ragged_factory_ops.constant_value(
              [[[1, 2], [3, 4]], [], [[5, 6]], [[7, 8], [9, 0], [1, 2]]],
              ragged_rank=1),
          mask=[T, F, F, T],
          expected=ragged_factory_ops.constant_value(
              [[[1, 2], [3, 4]], [[7, 8], [9, 0], [1, 2]]],
              ragged_rank=1)),
      dict(
          descr='data.shape=[1, (2)]; mask.shape=[1, 2]',
          data=ragged_factory_ops.constant_value([[1, 2]]),
          mask=[[T, F]],
          expected=ragged_factory_ops.constant_value([[1]])),
      dict(
          descr='data.shape=[2, (2), (D3)]; mask.shape=[2, 2]',
          data=ragged_factory_ops.constant_value(
              [[[1], [2, 3]], [[], [4, 5, 6]]]),
          mask=[[T, F], [T, T]],
          expected=ragged_factory_ops.constant_value([[[1]], [[], [4, 5, 6]]])),
      dict(
          descr='data.shape=[2, (2), 3]; mask.shape=[2, 2]',
          data=ragged_factory_ops.constant_value(
              [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [2, 4, 6]]],
              ragged_rank=1),
          mask=[[T, F], [T, T]],
          expected=ragged_factory_ops.constant_value(
              [[[1, 2, 3]], [[7, 8, 9], [2, 4, 6]]],
              ragged_rank=1)),
      dict(
          descr='data.shape=[2, (2), 3]; mask.shape=[2, 2, 3]',
          data=ragged_factory_ops.constant_value(
              [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [2, 4, 6]]],
              ragged_rank=1),
          mask=[[[T, F, F], [T, F, T]], [[T, F, T], [F, F, F]]],
          expected=ragged_factory_ops.constant_value(
              [[[1], [4, 6]], [[7, 9], []]])),
  ])  # pyformat: disable
  def testBooleanMask(self, descr, data, mask, expected):
    """Checks `boolean_mask(data, mask)` against `expected`.

    `descr` only labels the parameterized case; it is not read here.
    """
    actual = ragged_array_ops.boolean_mask(data, mask)
    self.assertAllEqual(actual, expected)

  def testErrors(self):
    """Checks the errors `boolean_mask` raises for invalid masks."""
    # A placeholder of unknown rank can only be built in graph mode.
    if not context.executing_eagerly():
      self.assertRaisesRegex(ValueError,
                             r'mask\.shape\.ndims must be known statically',
                             ragged_array_ops.boolean_mask, [[1, 2]],
                             array_ops.placeholder(dtypes.bool))

    # Non-boolean masks, dense and ragged.
    self.assertRaises(TypeError, ragged_array_ops.boolean_mask, [[1, 2]],
                      [[0, 1]])
    self.assertRaisesRegex(
        ValueError, 'Tensor conversion requested dtype bool for '
        'RaggedTensor with dtype int32', ragged_array_ops.boolean_mask,
        ragged_factory_ops.constant([[1, 2]]),
        ragged_factory_ops.constant([[0, 0]]))

    # Mask shape incompatible with the data shape.
    self.assertRaisesRegex(ValueError,
                           r'Shapes \(1, 2\) and \(1, 3\) are incompatible',
                           ragged_array_ops.boolean_mask, [[1, 2]],
                           [[True, False, True]])

    # Ragged data and ragged mask whose row lengths disagree.
    self.assertRaisesRegex(errors.InvalidArgumentError,
                           r'Inputs must have identical ragged splits',
                           ragged_array_ops.boolean_mask,
                           ragged_factory_ops.constant([[1, 2]]),
                           ragged_factory_ops.constant([[True, False, True]]))

    # Scalar masks are rejected for both dense and ragged data.
    self.assertRaisesRegex(ValueError, 'mask cannot be scalar',
                           ragged_array_ops.boolean_mask, [[1, 2]], True)

    self.assertRaisesRegex(ValueError, 'mask cannot be scalar',
                           ragged_array_ops.boolean_mask,
                           ragged_factory_ops.constant([[1, 2]]), True)
 def testSplitWithPaddedOutput(self, texts, expected, ragged_rank=None):
   """Splits UTF-8 text with unicode_split and checks the padded dense result.

   Rows are padded with empty strings via `to_tensor(default_value="")`.
   """
   # `_nested_encode` presumably encodes the nested Python strings to UTF-8
   # bytes -- confirm against its definition.
   encoded_texts = _nested_encode(texts, "UTF-8")
   input_tensor = ragged_factory_ops.constant_value(
       encoded_texts, ragged_rank=ragged_rank, dtype=bytes)
   pieces = ragged_string_ops.unicode_split(input_tensor, "UTF-8")
   padded = pieces.to_tensor(default_value="")
   self.assertAllEqual(np.array(expected, dtype=bytes), padded)
示例#39
0
class WordpieceOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
    """Tests for `WordpieceTokenizer` subword output and offset computation."""

    # testWordPieceOpAndVerifyOffsets runs each case once under every one of
    # these `compat.forward_compatibility_horizon` dates.
    _FORWARD_COMPATIBILITY_HORIZONS = [
        (2019, 7, 1),
        (2019, 10, 10),
        (2525, 1, 1),  # future behavior
    ]

    @parameterized.parameters([
        # Basic case
        dict(
            tokens=[[_Utf8(u"купиха")]],
            expected_subwords=[[[
                _Utf8(u"к"),
                _Utf8(u"##уп"),
                _Utf8(u"##иха"),
            ]]],
            vocab=_RUSSIAN_VOCAB,
        ),
        dict(
            tokens=[[b"don't", b"treadness"]],
            expected_subwords=[[[b"don", b"##'", b"##t"],
                                [b"tread", b"##ness"]]],
            vocab=_ENGLISH_VOCAB,
        ),
        dict(
            tokens=[[b"hello", b"there", b"my", b"name", b"is", b"terry"],
                    [b"whatchamacallit?", b"you", b"said"]],
            expected_subwords=[[[b"hel", b"##lo"], [b"there"], [b"my"],
                                [b"na", b"##me"], [b"is"], [b"ter", b"##ry"]],
                               [[
                                   b"what", b"##cha", b"##ma", b"##call",
                                   b"##it?"
                               ], [b"you"], [b"said"]]],
            vocab=_ENGLISH_VOCAB,
        ),
        # Basic case w/ unknown token
        dict(
            tokens=[[b"don't", b"tread", b"cantfindme", b"treadcantfindme"]],
            expected_subwords=[[[b"don", b"##'", b"##t"], [b"tread"],
                                [b"[UNK]"], [b"[UNK]"]]],
            vocab=_ENGLISH_VOCAB,
        ),
        # Basic case w/o unknown token
        dict(
            tokens=[[b"don't", b"tread", b"cantfindme", b"treadcantfindme"]],
            expected_subwords=[[[b"don", b"##'", b"##t"], [b"tread"],
                                [b"cantfindme"], [b"treadcantfindme"]]],
            unknown_token=None,
            vocab=_ENGLISH_VOCAB,
        ),
        # Basic case w/ int id lookup
        dict(
            tokens=[[b"don't", b"tread", b"cantfindme", b"treadcantfindme"]],
            token_out_type=dtypes.int64,
            expected_subwords=[[[0, 1, 2], [3], [21], [21]]],
            vocab=_ENGLISH_VOCAB,
        ),
        # Chinese test case
        dict(
            tokens=[[
                _Utf8(u"貿"),
                _Utf8(u"易"),
                _Utf8(u"戰"),
                _Utf8(u"最"),
                _Utf8(u"大"),
                _Utf8(u"受"),
                _Utf8(u"益"),
                _Utf8(u"者")
            ],
                    [
                        _Utf8(u"越"),
                        _Utf8(u"南"),
                        _Utf8(u"總"),
                        _Utf8(u"理"),
                        _Utf8(u"阮"),
                        _Utf8(u"春"),
                        _Utf8(u"福")
                    ]],
            expected_subwords=[[[_Utf8(u"貿")], [_Utf8(u"易")], [_Utf8(u"戰")],
                                [_Utf8(u"最")], [_Utf8(u"大")], [_Utf8(u"受")],
                                [_Utf8(u"益")], [_Utf8(u"者")]],
                               [[_Utf8(u"越")], [_Utf8(u"南")], [_Utf8(u"總")],
                                [_Utf8(u"理")], [_Utf8(u"阮")], [_Utf8(u"春")],
                                [_Utf8(u"福")]]],
            vocab=_CHINESE_VOCAB,
        ),
        # Mixed lang test cases
        dict(
            tokens=[
                [
                    _Utf8(u"貿"),
                    _Utf8(u"易"),
                    _Utf8(u"戰"),
                    _Utf8(u"最"),
                    _Utf8(u"大"),
                    _Utf8(u"受"),
                    _Utf8(u"益"),
                    _Utf8(u"者")
                ],
                [
                    _Utf8(u"越"),
                    _Utf8(u"南"),
                    _Utf8(u"總"),
                    _Utf8(u"理"),
                    _Utf8(u"阮"),
                    _Utf8(u"春"),
                    _Utf8(u"福")
                ],
                [b"don't", b"treadness"],
            ],
            expected_subwords=[
                [[_Utf8(u"貿")], [_Utf8(u"易")], [_Utf8(u"戰")], [_Utf8(u"最")],
                 [_Utf8(u"大")], [_Utf8(u"受")], [_Utf8(u"益")], [_Utf8(u"者")]],
                [[_Utf8(u"越")], [_Utf8(u"南")], [_Utf8(u"總")], [_Utf8(u"理")],
                 [_Utf8(u"阮")], [_Utf8(u"春")], [_Utf8(u"福")]],
                [[b"don", b"##'", b"##t"], [b"tread", b"##ness"]],
            ],
            vocab=_MIXED_LANG_VOCAB,
        ),
        # Test token whose size is > max_bytes_per_word
        dict(
            tokens=[[b"don't", b"treadness"]],
            expected_subwords=[[[b"don", b"##'", b"##t"], [b"[UNK]"]]],
            vocab=_ENGLISH_VOCAB,
            max_bytes_per_word=5,
            # Explicitly specify the offsets here because the current way of
            # testing offsets would require '[UNK]' to be part of tokens.
            expected_start=[[[0, 3, 4], [0]]],
            expected_limit=[[[3, 4, 5], [5]]],
        ),
        # Test the token of death usecase.
        dict(
            tokens=[[_Utf8(u"करें*👇👇")]],
            token_out_type=dtypes.string,
            expected_subwords=[[[
                _Utf8(u"क"),
                _Utf8(u"##र"),
                _Utf8(u"##े"),
                _Utf8(u"##ं"), b"##*",
                _Utf8(u"##👇"),
                _Utf8(u"##👇")
            ]]],
            vocab=_DEATH_VOCAB,
            max_bytes_per_word=40,
        ),
        # Test not splitting out unknown characters.
        # (p and ! are unknown)
        dict(
            tokens=[[b"nap", b"hello!me"]],
            expected_subwords=[[[b"[UNK]"], [b"[UNK]"]]],
            unknown_token="[UNK]",
            vocab=_ENGLISH_VOCAB,
        ),
        # Test splitting out unknown characters.
        dict(
            tokens=[[b"nap", b"hello!me"]],
            expected_subwords=[[[b"na", b"##[UNK]"],
                                [b"hel", b"##lo", b"##[UNK]", b"##me"]]],
            unknown_token="[UNK]",
            vocab=_ENGLISH_VOCAB,
            split_unknown_characters=True,
        ),
        # Test splitting out unknown characters, with unknown_token set to None.
        dict(
            tokens=[[b"nap", b"hello!me"]],
            expected_subwords=[[[b"na", b"##p"],
                                [b"hel", b"##lo", b"##!", b"##me"]]],
            unknown_token=None,
            vocab=_ENGLISH_VOCAB,
            split_unknown_characters=True,
        ),
    ])
    def testWordPieceOpAndVerifyOffsets(self,
                                        tokens,
                                        expected_subwords,
                                        vocab,
                                        expected_start=None,
                                        expected_limit=None,
                                        use_unknown_token=True,
                                        unknown_token="[UNK]",
                                        token_out_type=dtypes.string,
                                        max_bytes_per_word=100,
                                        split_unknown_characters=False):
        """Checks subwords and begin/end offsets under every compat horizon.

        When `expected_start`/`expected_limit` are both given, the offsets are
        compared against them directly. Otherwise the offsets are used to
        re-extract pieces of the original `tokens`, which must reproduce them.
        `use_unknown_token` is accepted but not read by this test.
        """
        for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
            with compat.forward_compatibility_horizon(*horizon):
                tokens_t = ragged_factory_ops.constant(tokens)
                vocab_table = _CreateTable(vocab)
                self.evaluate(vocab_table.initializer)
                tokenizer = WordpieceTokenizer(
                    vocab_table,
                    unknown_token=unknown_token,
                    token_out_type=token_out_type,
                    max_bytes_per_word=max_bytes_per_word,
                    split_unknown_characters=split_unknown_characters,
                )
                subwords_t, begin_t, end_t = tokenizer.tokenize_with_offsets(
                    tokens_t)
                self.assertAllEqual(subwords_t, expected_subwords)

                begin, end = self.evaluate((begin_t, end_t))

                # If expected start/limit offsets were provided, check them
                # explicitly. Otherwise extract substrings of the original
                # `tokens` using the offsets and check that joining them
                # reproduces the original tokens.
                if expected_start is None or expected_limit is None:
                    extracted_tokens = _GetTokensFromWordpieceOffsets(
                        tokens, begin, end)
                    self.assertAllEqual(extracted_tokens, tokens)
                else:
                    self.assertAllEqual(begin, expected_start)
                    self.assertAllEqual(end, expected_limit)

    @parameterized.parameters([
        dict(
            tokens=[[[b"don't"], [b"treadness"],
                     [b"whatchamacallit?", b"you", b"hello"]],
                    [[b"treadness"]]],
            expected_subwords=[[[[b"don", b"##'", b"##t"]],
                                [[b"tread", b"##ness"]],
                                [[
                                    b"what", b"##cha", b"##ma", b"##call",
                                    b"##it?"
                                ], [b"you"], [b"hel", b"##lo"]]],
                               [[[b"tread", b"##ness"]]]],
            vocab=_ENGLISH_VOCAB,
        ),
    ])
    def testWordPieceOpWithMultipleRaggedRank(self,
                                              tokens,
                                              expected_subwords,
                                              vocab,
                                              expected_start=None,
                                              expected_limit=None,
                                              use_unknown_token=True,
                                              token_out_type=dtypes.string):
        """Checks tokenization of multi-ragged input for int32/int64 splits.

        `expected_start`, `expected_limit` and `use_unknown_token` are
        accepted but not read by this test.
        """
        for row_splits_dtype in (dtypes.int32, dtypes.int64):
            ragged_tokens = ragged_factory_ops.constant(
                tokens, row_splits_dtype=row_splits_dtype)
            vocab_table = _CreateTable(vocab)
            self.evaluate(vocab_table.initializer)
            tokenizer = WordpieceTokenizer(vocab_table,
                                           token_out_type=token_out_type)
            subwords = tokenizer.tokenize(ragged_tokens)
            self.assertAllEqual(subwords, expected_subwords)

    def _assertIdLookupDeferredToTable(self, token_out_type):
        """Tokenizes to ids with `unknown_token=None`, checking values/dtype.

        Helper shared by the int64 and int32 id tests, which differ only in
        `token_out_type`.
        """
        tokens = ragged_factory_ops.constant(
            [[b"don't", b"tread", b"cantfindme", b"treadcantfindme"]])
        vocab_table = _CreateTable(
            _ENGLISH_VOCAB,
            100  # OOV values
        )
        self.evaluate(vocab_table.initializer)
        tokenizer = WordpieceTokenizer(vocab_table,
                                       unknown_token=None,
                                       token_out_type=token_out_type)
        subwords, _, _ = tokenizer.tokenize_with_offsets(tokens)

        self.assertAllEqual(subwords, [[[0, 1, 2], [3], [96], [46]]])
        self.assertEqual(subwords.dtype, token_out_type)

    def testWordPieceOpWithIdReturned(self):
        """Let the table determine how to do a lookup on unknown tokens."""
        self._assertIdLookupDeferredToTable(dtypes.int64)

    def testWordPieceOpWithInt32IdReturned(self):
        """Let the table determine how to do a lookup on unknown tokens."""
        self._assertIdLookupDeferredToTable(dtypes.int32)

    @parameterized.parameters([
        dict(
            tokens=[[b"don't", b"treadness", b"whatchamacallit?"]],
            expected_subwords=[[[b"don", b"##'", b"##t"], [
                b"tread", b"##ness"
            ], [b"what", b"##cha", b"##ma", b"##call", b"##it?"]]],
            vocab=_ENGLISH_VOCAB,
        ),
        dict(
            tokens=[[[b"don't"], [b"treadness"], [b"whatchamacallit?"]]],
            expected_subwords=[[[[b"don", b"##'", b"##t"]],
                                [[b"tread", b"##ness"]],
                                [[
                                    b"what", b"##cha", b"##ma", b"##call",
                                    b"##it?"
                                ]]]],
            vocab=_ENGLISH_VOCAB,
        ),
        dict(
            tokens=[[[b"don't", _Utf8(u"貿")], [b"treadness",
                                               _Utf8(u"大")],
                     [b"whatchamacallit?", _Utf8(u"福")]]],
            expected_subwords=[[[[b"don", b"##'", b"##t"], [_Utf8(u"貿")]],
                                [[b"tread", b"##ness"], [_Utf8(u"大")]],
                                [[
                                    b"what", b"##cha", b"##ma", b"##call",
                                    b"##it?"
                                ], [_Utf8(u"福")]]]],
            vocab=_MIXED_LANG_VOCAB,
        ),
        # Vector input
        dict(
            tokens=[_Utf8(u"купиха")],
            expected_subwords=[[
                _Utf8(u"к"),
                _Utf8(u"##уп"),
                _Utf8(u"##иха"),
            ]],
            vocab=_RUSSIAN_VOCAB,
        ),
        # Scalar input
        dict(
            tokens=_Utf8(u"купиха"),
            expected_subwords=[
                _Utf8(u"к"),
                _Utf8(u"##уп"),
                _Utf8(u"##иха"),
            ],
            vocab=_RUSSIAN_VOCAB,
        ),
        # 3D input with 1 ragged dimension.
        dict(
            tokens=ragged_factory_ops.constant_value(
                [[[b"don't"], [b"treadness"], [b"whatchamacallit?"]]],
                ragged_rank=1),
            expected_subwords=[[[[b"don", b"##'", b"##t"]],
                                [[b"tread", b"##ness"]],
                                [[
                                    b"what", b"##cha", b"##ma", b"##call",
                                    b"##it?"
                                ]]]],
            vocab=_ENGLISH_VOCAB,
        ),
        # Specifying max_chars_per_token.
        dict(
            tokens=[[b"don't", b"treadness"]],
            max_chars_per_token=5,
            expected_subwords=[[[b"don", b"##'", b"##t"],
                                [b"tread", b"##ness"]]],
            vocab=_ENGLISH_VOCAB + [b"trea", b"##d"],
        ),
        # Specifying max_chars_per_token to 4, so that "tread" is not found, and
        # is split into "trea", "##d".
        dict(
            tokens=[[b"don't", b"treadness"]],
            max_chars_per_token=4,
            expected_subwords=[[[b"don", b"##'", b"##t"],
                                [b"trea", b"##d", b"##ness"]]],
            vocab=_ENGLISH_VOCAB + [b"trea", b"##d"],
        ),
        # Specifying max_chars_per_token where characters are multiple bytes.
        dict(
            tokens=[[_Utf8(u"大"), _Utf8(u"易")]],
            max_chars_per_token=1,
            expected_subwords=[[[_Utf8(u"大")], [_Utf8(u"易")]]],
            vocab=_CHINESE_VOCAB,
        ),
    ])
    def testTensors(self,
                    tokens,
                    expected_subwords,
                    vocab,
                    max_chars_per_token=None,
                    expected_start=None,
                    expected_limit=None,
                    use_unknown_token=True,
                    token_out_type=dtypes.string):
        """Checks `tokenize` on scalar, vector and higher-rank inputs.

        `expected_start`, `expected_limit` and `use_unknown_token` are
        accepted but not read by this test.
        """
        vocab_table = _CreateTable(vocab)
        self.evaluate(vocab_table.initializer)
        tokenizer = WordpieceTokenizer(
            vocab_table,
            token_out_type=token_out_type,
            max_chars_per_token=max_chars_per_token,
        )
        subwords = tokenizer.tokenize(tokens)
        self.assertAllEqual(subwords, expected_subwords)
class RaggedBatchGatherOpTest(test_util.TensorFlowTestCase,
                              parameterized.TestCase):
  """Tests for ragged `batch_gather` and `batch_gather_with_default`.

  Covers `ragged_batch_gather_ops.batch_gather` on ragged and dense inputs with
  0 to 3 batch dimensions, `batch_gather_with_default` with out-of-range
  indices (which are expected to yield the default value), and the error paths
  of both ops.
  """

  @parameterized.parameters([
      #=========================================================================
      # Docstring Example
      #=========================================================================
      dict(
          descr='Docstring example',
          params=ragged_factory_ops.constant_value([['a', 'b', 'c'], ['d'], [],
                                                    ['e']]),
          indices=ragged_factory_ops.constant_value([[1, 2, 0], [], [], [0,
                                                                         0]]),
          expected=ragged_factory_ops.constant_value([[b'b', b'c', b'a'], [],
                                                      [], [b'e', b'e']])),
      #=========================================================================
      # 0 Batch Dimensions
      #=========================================================================
      dict(
          descr='params: [P1], indices: [I], result: [I]',
          params=['a', 'b', 'c', 'd'],
          indices=[3, 2],
          expected=[b'd', b'c']),
      dict(
          descr='params: [P1, (P2)], indices: [I], result: [I, (P2)]',
          params=ragged_factory_ops.constant_value([['a', 'b'], [], ['c'],
                                                    ['d', 'e']]),
          indices=[3, 2],
          expected=ragged_factory_ops.constant_value([[b'd', b'e'], [b'c']])),
      #=========================================================================
      # 1 Batch Dimension
      #=========================================================================
      dict(
          descr='params: [B1, P1], indices: [B1, I], result: [B1, I]',
          params=[['a', 'b', 'c'], ['d', 'e', 'f'], ['g', 'h', 'i']],
          indices=[[2, 0], [0, 1], [1, 0]],
          expected=[[b'c', b'a'], [b'd', b'e'], [b'h', b'g']]),
      dict(
          descr='params: [B1, (P1)], indices: [B1, I], result: [B1, I]',
          params=ragged_factory_ops.constant_value([['a', 'b', 'c'], ['d', 'e'],
                                                    ['g']]),
          indices=[[2, 0], [0, 1], [0, 0]],
          expected=[[b'c', b'a'], [b'd', b'e'], [b'g', b'g']]),
      dict(
          descr='params: [B1, P1], indices: [B1, (I)], result: [B1, (I)]',
          params=[['a', 'b', 'c'], ['d', 'e', 'f'], ['g', 'h', 'i']],
          indices=ragged_factory_ops.constant_value([[2, 0, 2], [0], [1]]),
          expected=ragged_factory_ops.constant_value([[b'c', b'a', b'c'],
                                                      [b'd'], [b'h']])),
      dict(
          descr=('params: [B1, (P1), (P2), P3], indices: [B1, I], '
                 'result: [B1, I, (P2), P3]'),
          params=ragged_factory_ops.constant_value(
              [[[['a']], [['b'], ['c']]], [[['d'], ['e']], [['f']]], [[['g']]]],
              ragged_rank=2),
          indices=[[1, 0], [0, 1], [0, 0]],
          expected=ragged_factory_ops.constant_value(
              [[[[b'b'], [b'c']], [[b'a']]], [[[b'd'], [b'e']], [[b'f']]],
               [[[b'g']], [[b'g']]]],
              ragged_rank=2)),
      #=========================================================================
      # 2 Batch Dimensions
      #=========================================================================
      dict(
          descr=('params: [B1, B2, P1], indices: [B1, B2, I], '
                 'result: [B1, B2, I]'),
          params=[[['a', 'b', 'c']], [['d', 'e', 'f']], [['g', 'h', 'i']]],
          indices=[[[2, 0]], [[0, 1]], [[1, 0]]],
          expected=[[[b'c', b'a']], [[b'd', b'e']], [[b'h', b'g']]]),
      dict(
          descr=('params: [B1, (B2), P1], indices: [B1, (B2), I], '
                 'result: [B1, (B2), I]'),
          params=ragged_factory_ops.constant_value(
              [[['a', 'b', 'c'], ['d', 'e', 'f']], [['g', 'h', 'i']]],
              ragged_rank=1),
          indices=ragged_factory_ops.constant_value(
              [[[2, 0], [0, 1]], [[1, 0]]], ragged_rank=1),
          expected=ragged_factory_ops.constant_value(
              [[[b'c', b'a'], [b'd', b'e']], [[b'h', b'g']]], ragged_rank=1)),
      dict(
          descr=('params: [B1, (B2), (P1)], indices: [B1, (B2), I], '
                 'result: [B1, (B2), I]'),
          params=ragged_factory_ops.constant_value(
              [[['a', 'b', 'c'], ['d']], [['e', 'f']]], ragged_rank=2),
          indices=ragged_factory_ops.constant_value(
              [[[2, 0], [0, 0]], [[1, 0]]], ragged_rank=1),
          expected=ragged_factory_ops.constant_value(
              [[[b'c', b'a'], [b'd', b'd']], [[b'f', b'e']]], ragged_rank=1)),
      dict(
          descr=('params: [B1, (B2), P1], indices: [B1, (B2), (I)], '
                 'result: [B1, (B2), (I)]'),
          params=ragged_factory_ops.constant_value(
              [[['a', 'b', 'c'], ['d', 'e', 'f']], [['g', 'h', 'i']]],
              ragged_rank=1),
          indices=ragged_factory_ops.constant_value(
              [[[2, 1, 0], [0]], [[1, 1]]], ragged_rank=2),
          expected=ragged_factory_ops.constant_value(
              [[[b'c', b'b', b'a'], [b'd']], [[b'h', b'h']]], ragged_rank=2)),
      #=========================================================================
      # 3 Batch Dimensions
      #=========================================================================
      dict(
          descr=(
              'params: [B1, (B2), (B3), (P1)], indices: [B1, (B2), (B3), I], '
              'result: [B1, (B2), (B3), I]'),
          params=ragged_factory_ops.constant_value(
              [[[['a', 'b', 'c'], ['d']], [['e', 'f']]]], ragged_rank=3),
          indices=ragged_factory_ops.constant_value(
              [[[[2, 0], [0, 0]], [[1, 0]]]], ragged_rank=2),
          expected=ragged_factory_ops.constant_value(
              [[[[b'c', b'a'], [b'd', b'd']], [[b'f', b'e']]]], ragged_rank=2)),
  ])
  def testRaggedBatchGather(self, descr, params, indices, expected):
    """Checks `batch_gather(params, indices)` against `expected`."""
    result = ragged_batch_gather_ops.batch_gather(params, indices)
    self.assertAllEqual(result, expected)

  @parameterized.parameters([
      # Docstring example:
      dict(
          descr='Docstring example',
          params=[['a', 'b', 'c'], ['d'], [], ['e']],
          indices=[[1, 2, -1], [], [], [0, 10]],
          expected=[['b', 'c', 'FOO'], [], [], ['e', 'FOO']],
          default_value='FOO',
      ),
      # Dimensions:
      # indices: [4]
      # params: [2, (d1), (d2)]
      dict(
          descr='params: [2, (d1), (d2)], indices: [4]',
          indices=[1, 100, 0, -1],
          params=[[['The', 'deal', 'came', 'about', '18', 'months', 'after',
                    'Yahoo', '!', 'rejected', 'a', '47.5', '-', 'billion', '-',
                    'dollar', 'takeover', 'offer', 'from', 'Microsoft', '.'],
                   ['Trumpty', 'Dumpty', 'sat', 'on', 'a', 'wall']],
                  [["It's", 'always', 'darkest', 'before', 'the', 'dawn']]],
          expected=[[["It's", 'always', 'darkest', 'before', 'the', 'dawn']],
                    [['$NONE^']],
                    [['The', 'deal', 'came', 'about', '18', 'months', 'after',
                      'Yahoo', '!', 'rejected', 'a', '47.5', '-', 'billion',
                      '-', 'dollar', 'takeover', 'offer', 'from', 'Microsoft',
                      '.'],
                     ['Trumpty', 'Dumpty', 'sat', 'on', 'a', 'wall']],
                    [['$NONE^']]],
      ),
      # Dimensions:
      # params: [1, (d1)]
      # indices: [3]
      dict(
          descr='params: rank 2, indices: rank 1',
          params=[
              ['Bruce', 'Wayne'],
          ],
          indices=[-1, 0, 1000],
          expected=[['$NONE^'], ['Bruce', 'Wayne'], ['$NONE^']]
      ),
      # Dimensions:
      # params: [1, (d1)]
      # indices: [1, (d2)]
      dict(
          descr='Test underbound indices of shape [1, (d2)]',
          params=[
              ['The', 'deal', 'came', 'about', '18', 'months', 'after', 'Yahoo',
               '!', 'rejected', 'a', '47.5', '-', 'billion', '-', 'dollar',
               'takeover', 'offer', 'from', 'Microsoft', '.'],
          ],
          indices=[[8, -1]],
          expected=[['!', '$NONE^']],
      ),
      dict(
          descr='Test underbound indices of shape [2, (d2)]',
          params=[
              ['The', 'deal', 'came', 'about', '18', 'months', 'after', 'Yahoo',
               '!', 'rejected', 'a', '47.5', '-', 'billion', '-', 'dollar',
               'takeover', 'offer', 'from', 'Microsoft', '.'],
              ['Who', 'let', 'the', 'dogs', 'out', '?'],
          ],
          indices=[[8, -1], [1, 100]],
          expected=[['!', '$NONE^'], ['let', '$NONE^']],
      ),
      # Dimensions:
      # params: [2, (d1)]
      # indices: [2, (d2)]
      dict(
          descr='Test underbound indices of rank 2',
          params=[
              ['The', 'deal', 'came', 'about', '18', 'months', 'after', 'Yahoo',
               '!', 'rejected', 'a', '47.5', '-', 'billion', '-', 'dollar',
               'takeover', 'offer', 'from', 'Microsoft', '.'],
              ['He', 'left', 'us', '.', 'Little', 'boys', 'crowded', 'together',
               'on', 'long', 'wooden', 'benches', ',', 'and', 'in', 'the',
               'center', 'of', 'the', 'room', 'sat', 'the', 'teacher', '.',
               'His', 'black', 'beard', 'dripped', 'down', 'over', 'the',
               'front', 'of', 'his', 'coat', '.', 'One', 'white', 'hand',
               'poised', 'a', 'stick', 'above', 'his', 'desk', '.', 'He',
               'turned', 'his', 'surly', ',', 'half', '-', 'closed', 'eyes',
               'toward', 'us', ',', 'stared', 'for', 'a', 'second', ',', 'then',
               'shouted', 'in', 'Yiddish', ',', '``', 'One', ',', 'two', ',',
               'three', "''", '!', '!', 'Rapping', 'the', 'stick', 'against',
               'the', 'desk', '.', 'The', 'little', 'boys', 'shrilled', 'out',
               'a', 'Yiddish', 'translation', 'or', 'interpretation', 'of',
               'the', 'Five', 'Books', 'of', 'Moses', ',', 'which', 'they',
               'had', 'previously', 'chanted', 'in', 'Hebrew', '.']],
          indices=[[8, -1], [3, 23, 35, 45, 75, 83, -121]],
          expected=[['!', '$NONE^'], ['.', '.', '.', '.', '!', '.', '$NONE^']],
      ),
      dict(
          descr='Test overbound indices of rank 2',
          params=[
              ['The', 'deal', 'came', 'about', '18', 'months', 'after', 'Yahoo',
               '!', 'rejected', 'a', '47.5', '-', 'billion', '-', 'dollar',
               'takeover', 'offer', 'from', 'Microsoft', '.'],
              ['He', 'left', 'us', '.', 'Little', 'boys', 'crowded', 'together',
               'on', 'long', 'wooden', 'benches', ',', 'and', 'in', 'the',
               'center', 'of', 'the', 'room', 'sat', 'the', 'teacher', '.',
               'His', 'black', 'beard', 'dripped', 'down', 'over', 'the',
               'front', 'of', 'his', 'coat', '.', 'One', 'white', 'hand',
               'poised', 'a', 'stick', 'above', 'his', 'desk', '.', 'He',
               'turned', 'his', 'surly', ',', 'half', '-', 'closed', 'eyes',
               'toward', 'us', ',', 'stared', 'for', 'a', 'second', ',', 'then',
               'shouted', 'in', 'Yiddish', ',', '``', 'One', ',', 'two', ',',
               'three', "''", '!', '!', 'Rapping', 'the', 'stick', 'against',
               'the', 'desk', '.', 'The', 'little', 'boys', 'shrilled', 'out',
               'a', 'Yiddish', 'translation', 'or', 'interpretation', 'of',
               'the', 'Five', 'Books', 'of', 'Moses', ',', 'which', 'they',
               'had', 'previously', 'chanted', 'in', 'Hebrew', '.']],
          indices=[[8, 8823], [3, 23, 35, 45, 75, 83, 1234]],
          expected=[['!', '$NONE^'], ['.', '.', '.', '.', '!', '.', '$NONE^']],
      ),
      # Dimensions:
      # params: [2, (d1), 2]
      # indices: [2, (d2)]
      dict(
          descr='params: rank 3, indices: rank 2',
          params=[
              [['The', 'deal'], ['takeover', 'offer'], ['from', 'Microsoft']],
              [['Who', 'let'], ['the', 'dogs'], ['out', '?']],
          ],
          ragged_rank=1,
          indices=[[1, -1, 2, 30], [1, 100]],
          indices_ragged_rank=1,
          expected=[[['takeover', 'offer'],
                     ['$NONE^', '$NONE^'],
                     ['from', 'Microsoft'],
                     ['$NONE^', '$NONE^']],
                    [['the', 'dogs'],
                     ['$NONE^', '$NONE^']]],
          expected_ragged_rank=1,
          default_value=['$NONE^', '$NONE^'],
      ),
      # Dimensions:
      # params: [2, (d1), (d2)]
      # indices: [2, (d3)]
      dict(
          descr='params: [2, (d1), (d2)], indices: [2, (d3)]',
          params=[
              [['The', 'deal', 'came', 'about', '18', 'months', 'after',
                'Yahoo', '!', 'rejected', 'a', '47.5', '-', 'billion', '-',
                'dollar', 'takeover', 'offer', 'from', 'Microsoft', '.'],
               ['Trumpty', 'Dumpty', 'sat', 'on', 'a', 'wall'],
              ],
              [['It\'s', 'always', 'darkest', 'before', 'the', 'dawn']]
          ],
          indices=[[1, 100], [0, -1]],
          expected=[[['Trumpty', 'Dumpty', 'sat', 'on', 'a', 'wall'],
                     ['$NONE^']],
                    [["It's", 'always', 'darkest', 'before', 'the', 'dawn'],
                     ['$NONE^']]]
      ),
      # Dimensions:
      # params: [2, (d1), (d2)]
      # indices: [2, (d1), (d3)]
      dict(
          descr='Test overbound indices of rank 3',
          params=[
              [['The', 'deal', 'came', 'about', '18', 'months', 'after',
                'Yahoo', '!', 'rejected', 'a', '47.5', '-', 'billion', '-',
                'dollar', 'takeover', 'offer', 'from', 'Microsoft', '.'],
               ['Foo', 'bar', 'mar']],
              [['He', 'left', 'us', '.', 'Little', 'boys', 'crowded',
                'together', 'on', 'long', 'wooden', 'benches', ',', 'and', 'in',
                'the', 'center', 'of', 'the', 'room', 'sat', 'the', 'teacher',
                '.', 'His', 'black', 'beard', 'dripped', 'down', 'over', 'the',
                'front', 'of', 'his', 'coat', '.', 'One', 'white', 'hand',
                'poised', 'a', 'stick', 'above', 'his', 'desk', '.', 'He',
                'turned', 'his', 'surly', ',', 'half', '-', 'closed', 'eyes',
                'toward', 'us', ',', 'stared', 'for', 'a', 'second', ',',
                'then', 'shouted', 'in', 'Yiddish', ',', '``', 'One', ',',
                'two', ',',
                'three', "''", '!', '!', 'Rapping', 'the', 'stick', 'against',
                'the', 'desk', '.', 'The', 'little', 'boys', 'shrilled', 'out',
                'a', 'Yiddish', 'translation', 'or', 'interpretation', 'of',
                'the', 'Five', 'Books', 'of', 'Moses', ',', 'which', 'they',
                'had', 'previously', 'chanted', 'in', 'Hebrew', '.'],
               ['I', 'too', 'was', 'hustled', 'scammed', 'bamboozled', 'hood',
                'winked', 'lead', 'astray']]
          ],
          indices=[[[8, 8823], [0, 100]], [[3, 23, 35, 45, 75, 83, 1234], [5]]],
          expected=[[['!', '$NONE^'], ['Foo', '$NONE^']],
                    [['.', '.', '.', '.', '!', '.', '$NONE^'],
                     ['bamboozled']]],
      ),
      # params.shape = [2, (d1), 8]
      # indices.shape = [2, (d1), 3]
      dict(
          descr='params = [2, (2, 1), 8], indices = [2, (2, 1), 3]',
          params=[[['h'] * 8, ['w'] * 8], [['b'] * 8]],
          ragged_rank=1,
          indices=[[[0, 100, 1], [0, 1, 0]], [[1, 0, 0]]],
          indices_ragged_rank=1,
          expected=[[['h', '$NONE^', 'h'], ['w', 'w', 'w']], [['b', 'b', 'b']]],
          expected_ragged_rank=1,
      ),
  ])
  def testRaggedBatchGatherWithDefault(
      self, descr, params, indices, expected, indices_ragged_rank=None,
      expected_ragged_rank=None, ragged_rank=None, default_value='$NONE^'):
    """Checks `batch_gather_with_default` on ragged inputs.

    `params`, `indices` and `expected` are nested Python lists that are
    converted with `ragged_factory_ops.constant`. `indices_ragged_rank` and
    `expected_ragged_rank` fall back to `ragged_rank` when they are falsy
    (None or 0). Out-of-range indices are expected to produce `default_value`.
    """
    params = ragged_factory_ops.constant(params, ragged_rank=ragged_rank)
    indices = ragged_factory_ops.constant(
        indices, ragged_rank=indices_ragged_rank or ragged_rank)
    expected = ragged_factory_ops.constant(
        expected, ragged_rank=expected_ragged_rank or ragged_rank)
    result = ragged_batch_gather_with_default_op.batch_gather_with_default(
        params, indices, default_value)
    self.assertAllEqual(result, expected)

  @parameterized.parameters([
      # Dimensions:
      #  params: dims [2, 5], indices: [2, 2]
      dict(
          descr='params: dims [2, 5], indices: [2, 2]',
          params=[
              ['The', 'deal', 'came', 'about', '18'],
              ['He', 'left', 'us', '.', 'Little']],
          indices=[[0, -1], [3, 121]],
          expected=[['The', '$NONE^'], ['.', '$NONE^']],
          default_value='$NONE^',
      ),
      # Dimensions:
      #  params: dims [2, 2, 5], indices: [2, 2]
      dict(
          descr='params: dims [2, 2, 5], indices: [2, 2]',
          params=[
              [['The', 'deal', 'came', 'about', '18'],
               ['The', 'deal', 'came', 'about', '19'],
              ],
              [['He', 'left', 'us', '.', 'Little'],
               ['The', 'deal', 'came', 'about', '20'],
              ]
          ],
          indices=[[0, -1], [0, 121]],
          expected=[[['The', 'deal', 'came', 'about', '18'],
                     ['$NONE^', '$NONE^', '$NONE^', '$NONE^', '$NONE^']],
                    [['He', 'left', 'us', '.', 'Little'],
                     ['$NONE^', '$NONE^', '$NONE^', '$NONE^', '$NONE^']]],
          default_value='$NONE^',
      ),
      # Test default_value with shape [5]
      dict(
          descr=('params: dims [2, 2, 5], indices: [2, 2], '
                 'default_value: [5]'),
          params=[
              [['The', 'deal', 'came', 'about', '18'],
               ['The', 'deal', 'came', 'about', '19'],
              ],
              [['He', 'left', 'us', '.', 'Little'],
               ['The', 'deal', 'came', 'about', '20'],
              ]
          ],
          indices=[[0, -1], [0, 121]],
          expected=[[['The', 'deal', 'came', 'about', '18'],
                     [':FOO:', ':FOO:', ':FOO:', ':FOO:', ':FOO:']],
                    [['He', 'left', 'us', '.', 'Little'],
                     [':FOO:', ':FOO:', ':FOO:', ':FOO:', ':FOO:']]],
          default_value=[':FOO:', ':FOO:', ':FOO:', ':FOO:', ':FOO:'],
      ),
  ])
  def testRaggedBatchGatherWithDefaultOnTensors(
      self, descr, params, indices, expected, default_value):
    """Checks `batch_gather_with_default` when all inputs are dense tensors."""
    params = constant_op.constant(params)
    indices = constant_op.constant(indices)
    expected = constant_op.constant(expected)
    result = ragged_batch_gather_with_default_op.batch_gather_with_default(
        params, indices, default_value)
    self.assertAllEqual(expected, result)

  @parameterized.parameters([
      dict(
          params=[['The', 'deal', 'came', 'about', '18', 'months', 'after',
                   'Yahoo', '!', 'rejected', 'a', '47.5', '-', 'billion', '-',
                   'dollar', 'takeover', 'offer', 'from', 'Microsoft', '.']],
          indices=[[[8, -1]]],
          # Exception here because different errors are thrown in eager vs
          # graph mode.
          error=Exception,
          default_value='$NONE^',
      ),
  ])
  def testRankMismatch(
      self, params, indices, default_value, error):
    """Checks that indices of higher rank than params raise `error`."""
    params = ragged_factory_ops.constant(params)
    indices = ragged_factory_ops.constant(indices)
    with self.assertRaises(error):
      _ = ragged_batch_gather_with_default_op.batch_gather_with_default(
          params, indices, default_value)

  @parameterized.parameters([
      # Dimensions:
      # params: [2, (d1), 2]
      # indices: [2, (d2)]
      # default_value: []
      dict(
          descr='params: rank 3, indices: rank 2, default: rank = [], but'
          ' should be [2]',
          params=[
              [['The', 'deal'], ['takeover', 'offer'], ['from', 'Microsoft']],
              [['Who', 'let'], ['the', 'dogs'], ['out', '?']],
          ],
          ragged_rank=1,
          indices=[[1, -1, 2, 30], [1, 100]],
          indices_ragged_rank=1,
          default_value='$NONE^',
          error=Exception,
      )
  ])
  def testInvalidDefaultValueRank(
      self, descr, params, indices, default_value, error, ragged_rank=None,
      indices_ragged_rank=None):
    """Checks that a `default_value` of the wrong shape raises `error`."""
    params = ragged_factory_ops.constant(params, ragged_rank=ragged_rank)
    indices = ragged_factory_ops.constant(
        indices, ragged_rank=indices_ragged_rank)
    with self.assertRaises(error):
      _ = ragged_batch_gather_with_default_op.batch_gather_with_default(
          params, indices, default_value)

  def testRaggedBatchGatherUnknownRankError(self):
    """Checks that `batch_gather` rejects indices of statically unknown rank.

    Runs only in graph mode, where a placeholder with `shape=None` can be
    built; returns immediately when executing eagerly.
    """
    if context.executing_eagerly():
      return
    params = [['a', 'b'], ['c', 'd']]
    indices = array_ops.placeholder(dtypes.int32, shape=None)
    ragged_indices = ragged_tensor.RaggedTensor.from_row_splits(
        indices, [0, 2, 4])

    with self.assertRaisesRegex(
        ValueError, r'batch_dims=-1 may only be negative '
        r'if rank\(indices\) is statically known.'):
      ragged_batch_gather_ops.batch_gather(params, indices)

    with self.assertRaisesRegex(
        ValueError, r'batch_dims=-1 may only be negative '
        r'if rank\(indices\) is statically known.'):
      ragged_batch_gather_ops.batch_gather(params, ragged_indices)

  @parameterized.parameters(
      [
          dict(
              params=ragged_factory_ops.constant_value([['a'], ['b'], ['c']]),
              indices=ragged_factory_ops.constant_value([[0], [0]]),
              message=(r'batch shape from indices .* does not match params')),
          dict(
              params=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
              indices=ragged_factory_ops.constant_value([[[0, 0], [0, 0, 0]],
                                                         [[0]]]),
              message='batch shape from indices does not match params shape'),
          dict(  # rank mismatch
              params=ragged_factory_ops.constant_value([[[0, 0], [0, 0, 0]],
                                                        [[0]]]),
              indices=ragged_factory_ops.constant_value([[[0, 0]], [[0, 0, 0]],
                                                         [[0]]]),
              error=(ValueError, errors.InvalidArgumentError)),
          dict(
              params=ragged_factory_ops.constant_value([[[0, 0], [0, 0, 0]],
                                                        [[0]], [[0]]]),
              indices=ragged_factory_ops.constant_value([[[0, 0]], [[0, 0, 0]],
                                                         [[0]]]),
              error=(ValueError, errors.InvalidArgumentError),
              message=(r'batch shape from indices .* does not match '
                       r'params shape|dimension size mismatch')),
          dict(
              params=ragged_factory_ops.constant_value(['a', 'b', 'c']),
              indices=ragged_factory_ops.constant_value([[0], [0]]),
              message=r'batch_dims must be less than rank\(params\)'),
          dict(
              params=ragged_factory_ops.constant_value([['a']]),
              indices=0,
              message='batch_dims=-1 out of bounds: expected 0<=batch_dims<0'),
          dict(
              params=ragged_factory_ops.constant_value([['a']]),
              indices=[[[0]]],
              message=r'batch_dims must be less than rank\(params\)'),
      ])
  def testRaggedBatchGatherStaticError(self,
                                       params,
                                       indices,
                                       message=None,
                                       error=ValueError):
    """Checks that `batch_gather` raises `error` matching `message`."""
    with self.assertRaisesRegex(error, message):
      ragged_batch_gather_ops.batch_gather(params, indices)
class StructuredTensorTest(test_util.TensorFlowTestCase,
                           parameterized.TestCase):
    """Tests for `structured_tensor.StructuredTensor` construction."""

    def assertAllEqual(self, a, b, msg=None):
        """Extends `assertAllEqual` to compare StructuredTensors.

        Non-StructuredTensor arguments are delegated to the base class. Two
        StructuredTensors are compared by shape repr, by field-name set, and
        recursively by field value. Comparing a StructuredTensor with any other
        type raises ValueError.
        """
        if not (isinstance(a, structured_tensor.StructuredTensor)
                or isinstance(b, structured_tensor.StructuredTensor)):
            return super(StructuredTensorTest, self).assertAllEqual(a, b, msg)
        if not (isinstance(a, structured_tensor.StructuredTensor)
                and isinstance(b, structured_tensor.StructuredTensor)):
            # TODO(edloper) Add support for this once structured_factory_ops is added.
            raise ValueError("Not supported yet")

        self.assertEqual(repr(a.shape), repr(b.shape))
        self.assertEqual(set(a.field_names()), set(b.field_names()))
        for field in a.field_names():
            self.assertAllEqual(a.field_value(field), b.field_value(field))

    @parameterized.parameters([
        {
            "shape": [],
            "fields": {},
        },
        {
            "shape": [None],
            "fields": {},
        },
        {
            "shape": [1, 5, 3],
            "fields": {},
        },
        {
            "shape": [],
            "fields": {
                "Foo": 5,
                "Bar": [1, 2, 3]
            },
        },
        {
            "shape": [2],
            "fields": {
                "x": [1, 2],
                "y": [[1, 2], [3, 4]]
            },
        },
        {
            "shape": [None],
            "fields": {
                "x": [1, 2],
                "y": [[1, 2], [3, 4]]
            },
            "expected_shape": [2],  # inferred from field values.
        },
        {
            "shape": [],
            "fields": {
                "r": ragged_factory_ops.constant_value([[1, 2], [3]]),
            },
        },
        {
            "shape": [2],
            "fields": {
                "r": ragged_factory_ops.constant_value([[1, 2], [3]]),
            },
        },
        {
            "shape": [2, None],
            "fields": {
                "r":
                ragged_factory_ops.constant_value([[[1, 2], [3]],
                                                   [[4, 5, 6], [7], [8, 9]]]),
            },
        },
        {
            # Note: fields must have identical row_splits.
            "shape": [2, None],
            "fields": {
                "a": ragged_factory_ops.constant_value([[1, 2], [3]]),
                "b": ragged_factory_ops.constant_value([[4, 5], [6]]),
            },
        },
        {
            # Note: fields must have identical outer row_splits.
            "shape": [2, None],
            "fields": {
                "a":
                ragged_factory_ops.constant_value([[[1, 2], [3]],
                                                   [[4, 5, 6], [7], [8, 9]]]),
                "b":
                ragged_factory_ops.constant_value([[[1], []],
                                                   [[2, 3], [4, 5, 6], [7,
                                                                        8]]]),
            },
        },
    ])  # pyformat: disable
    def testFromFields(self, shape, fields, expected_shape=None):
        """Checks shape, rank, field names and values from `from_fields`."""
        struct = structured_tensor.StructuredTensor.from_fields(shape, fields)
        if expected_shape is None:
            expected_shape = shape
        self.assertEqual(struct.shape.as_list(), expected_shape)
        self.assertLen(expected_shape, struct.rank)
        self.assertEqual(struct.field_names(), tuple(fields.keys()))
        for field, value in fields.items():
            self.assertIsInstance(
                struct.field_value(field),
                (ops.Tensor, structured_tensor.StructuredTensor,
                 ragged_tensor.RaggedTensor))
            self.assertAllEqual(struct.field_value(field), value)

    def testNestedStructConstruction(self):
        """Checks StructuredTensors used as field values of other ones."""
        rt = ragged_factory_ops.constant([[1, 2], [3]])
        struct1 = structured_tensor.StructuredTensor.from_fields([],
                                                                 {"x": [1, 2]})
        struct2 = structured_tensor.StructuredTensor.from_fields([2],
                                                                 {"x": [1, 2]})
        struct3 = structured_tensor.StructuredTensor.from_fields([], {
            "r": rt,
            "s": struct1
        })
        struct4 = structured_tensor.StructuredTensor.from_fields([2], {
            "r": rt,
            "s": struct2
        })

        self.assertEqual(struct3.shape.as_list(), [])
        self.assertEqual(struct3.rank, 0)
        self.assertEqual(set(struct3.field_names()), set(["r", "s"]))
        self.assertAllEqual(struct3.field_value("r"), rt)
        self.assertAllEqual(struct3.field_value("s"), struct1)

        self.assertEqual(struct4.shape.as_list(), [2])
        self.assertEqual(struct4.rank, 1)
        self.assertEqual(set(struct4.field_names()), set(["r", "s"]))
        self.assertAllEqual(struct4.field_value("r"), rt)
        self.assertAllEqual(struct4.field_value("s"), struct2)

    @parameterized.parameters([
        (object(), {}, TypeError),
        ([], object(), TypeError, "fields must be a dictionary"),
        ([], {
            1: 2
        }, TypeError, "Unexpected type for key"),
        ([], {
            "x": object()
        }, TypeError, "Unexpected type for value"),
        (None, {}, ValueError,
         "StructuredTensor's shape must have known rank"),
        ([5], {
            "f": 5
        }, ValueError, r"Shapes \(5,\) and \(\) are not compatible"),
        ([None], {
            "x": [1],
            "y": []
        }, ValueError, r"Shapes \([01],\) and \([01],\) are not compatible"),
        ([], {
            "": 5
        }, ValueError, "Field name '' is not currently allowed."),
        ([], {
            "_": 5
        }, ValueError, "Field name '_' is not currently allowed."),
        {
            # Note: fields must have identical outer row_splits.
            "shape": [2, None],
            "fields": {
                "r1": ragged_factory_ops.constant_value([[1, 2], [3]]),
                "r2": ragged_factory_ops.constant_value([[1, 2, 3], [4]]),
            },
            "err": errors.InvalidArgumentError,
            "msg": r"`fields` are not consistent in the outer 2 dimension"
        },
    ])  # pyformat: disable
    def testFromFieldsErrors(self, shape, fields, err, msg=None):
        """Checks that invalid `from_fields` arguments raise `err` (`msg`).

        The first field is also evaluated inside the assertion context,
        presumably so that errors raised only at execution time (such as the
        `InvalidArgumentError` case) are caught as well.
        """
        # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
        with self.assertRaisesRegex(err, msg):
            struct = structured_tensor.StructuredTensor.from_fields(
                shape, fields)
            self.evaluate(struct.field_value(struct.field_names()[0]))

    @parameterized.parameters([
        {
            "shape": [3],
            "fields": {
                "x": [1, 2, 3],
                "y": [[1, 2], [3, 4], [5, 6]]
            },
            "row_splits": [0, 2, 3],
        },
    ])  # pyformat: disable
    def testFromRowSplits(self,
                          shape,
                          fields,
                          row_splits,
                          expected_shape=None):
        """Checks `from_row_splits` against per-field ragged partitioning.

        Each field of the result must equal
        `RaggedTensor.from_row_splits(value, row_splits)` for the original
        field value.
        """
        values = structured_tensor.StructuredTensor.from_fields(shape, fields)
        struct = structured_tensor.StructuredTensor.from_row_splits(
            values, row_splits)
        if expected_shape is None:
            expected_shape = tensor_shape.TensorShape(
                [None, None]).concatenate(shape[1:])
            struct.shape.assert_is_compatible_with(expected_shape)
        else:
            self.assertEqual(struct.shape.as_list(), expected_shape)
        self.assertEqual(struct.shape.rank, struct.rank)
        self.assertEqual(struct.field_names(), tuple(fields.keys()))
        for field, value in fields.items():
            self.assertIsInstance(
                struct.field_value(field),
                (ops.Tensor, structured_tensor.StructuredTensor,
                 ragged_tensor.RaggedTensor))
            self.assertAllEqual(
                struct.field_value(field),
                ragged_tensor.RaggedTensor.from_row_splits(value, row_splits))

    @parameterized.parameters([
        ([], {}, ["x"], ValueError, r"Shape \(\) must have rank at least 1"),
        ([0], {}, ["x"], ValueError,
         r"Row-partitioning tensors must have dtype int32 or int64"),
        ([0], {}, [[0]], ValueError, r"Shape \(1, 1\) must have rank 1"),
        ([0], {}, np.array([], np.int32), ValueError,
         r"row_splits may not be empty"),
    ])  # pyformat: disable
    def testFromRowSplitsErrors(self,
                                shape,
                                fields,
                                row_splits,
                                err,
                                msg=None):
        """Checks that invalid `row_splits` raise `err` matching `msg`."""
        with self.assertRaisesRegex(err, msg):
            values = structured_tensor.StructuredTensor.from_fields(
                shape, fields)
            structured_tensor.StructuredTensor.from_row_splits(
                values, row_splits)

    def testFromRowSplitsBadValueType(self):
        """Checks that `from_row_splits` rejects non-StructuredTensor values."""
        with self.assertRaisesRegex(TypeError,
                                    "values must be a StructuredTensor"):
            structured_tensor.StructuredTensor.from_row_splits([1, 2], [0, 2])
示例#42
0
class RaggedConvertToTensorOrRaggedTensorTest(
        ragged_test_util.RaggedTensorTestCase, parameterized.TestCase):

    #=============================================================================
    # Tests where the 'value' param is a RaggedTensor
    #=============================================================================
    @parameterized.parameters([
        dict(pylist=[[1, 2], [3]]),
        dict(pylist=[[1, 2], [3]], preferred_dtype=dtypes.float32),
        dict(pylist=[[1, 2], [3]], preferred_dtype=dtypes.string),
        # Note: Conversion of a single np.array is tested below. These tests
        # check nestings consisting of multiple or irregularly-shaped np.arrays.
        dict(pylist=[np.array([1, 2]), np.array([3])],
             preferred_dtype=dtypes.string),
        # Ragged nested lists need an explicit dtype=object: NumPy >= 1.24
        # raises ValueError for inhomogeneous input instead of silently
        # building an object array. These parameters are evaluated when the
        # class is defined, so without it the whole module fails to import.
        dict(pylist=np.array([[1, 2], [3]], dtype=object),
             preferred_dtype=dtypes.float32),
        dict(pylist=np.array([[1, 2], [3]], dtype=object),
             preferred_dtype=dtypes.string),
        dict(pylist=[np.array([[1], np.array([2])]), [np.array([3])]],
             preferred_dtype=dtypes.float32),
        dict(pylist=[np.array(1)], preferred_dtype=dtypes.string),
    ])
    def testConvertRaggedTensor(self,
                                pylist,
                                dtype=None,
                                preferred_dtype=None):
        """Checks that a RaggedTensor input is returned as the same object.

        Args:
          pylist: Nested value used to build the input RaggedTensor.
          dtype: Requested dtype passed to the conversion (None here).
          preferred_dtype: Soft dtype hint passed to the conversion.
        """
        rt = ragged_factory_ops.constant(pylist)
        converted = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            rt, dtype, preferred_dtype)
        self.assertIs(converted, rt)

    @parameterized.parameters([
        dict(pylist=[[1, 2], [3, 4]],
             dtype=dtypes.float32,
             message=('Tensor conversion requested dtype float32 for '
                      'RaggedTensor with dtype int32')),
        dict(pylist=np.array([[1, 2], [3, 4]]),
             dtype=dtypes.float32,
             message=('Tensor conversion requested dtype float32 for '
                      'RaggedTensor with dtype int32')),
        dict(pylist=[[1, 2], [3, 4]],
             dtype=dtypes.string,
             message=('Tensor conversion requested dtype string for '
                      'RaggedTensor with dtype .*')),
    ])
    def testConvertRaggedTensorError(self,
                                     pylist,
                                     message,
                                     dtype=None,
                                     preferred_dtype=None):
        """Checks that requesting a mismatched dtype for a RaggedTensor fails.

        Args:
          pylist: Nested value used to build the input RaggedTensor.
          message: Regex the raised ValueError's message must match.
          dtype: Requested dtype that differs from the RaggedTensor's dtype.
          preferred_dtype: Soft dtype hint passed to the conversion.
        """
        rt = ragged_factory_ops.constant(pylist)

        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
        # alias, which was removed from unittest in Python 3.12.
        with self.assertRaisesRegex(ValueError, message):
            ragged_tensor.convert_to_tensor_or_ragged_tensor(
                rt, dtype, preferred_dtype)

    #=============================================================================
    # Tests where the 'value' param is a RaggedTensorValue
    #=============================================================================
    @parameterized.parameters([
        {'value': ragged_factory_ops.constant_value([[1, 2], [3]],
                                                    dtype=np.int32),
         'expected_dtype': dtypes.int32},
        {'value': ragged_factory_ops.constant_value([[b'a', b'b'], [b'c']]),
         'expected_dtype': dtypes.string},
        {'value': ragged_factory_ops.constant_value([[1, 2], [3]],
                                                    dtype=np.int32),
         'dtype': dtypes.float32,
         'expected_dtype': dtypes.float32},
        {'value': ragged_factory_ops.constant_value([[1, 2], [3]],
                                                    dtype=np.int32),
         'preferred_dtype': dtypes.float32,
         'expected_dtype': dtypes.float32},
        {'value': ragged_factory_ops.constant_value([[1, 2], [3]],
                                                    dtype=np.int32),
         'preferred_dtype': dtypes.string,
         'expected_dtype': dtypes.int32},
    ])
    def testConvertRaggedTensorValue(self,
                                     value,
                                     dtype=None,
                                     preferred_dtype=None,
                                     expected_dtype=None):
        """Converts a RaggedTensorValue and verifies rank, dtype and contents.

        If `expected_dtype` is omitted, it defaults to the requested `dtype`
        when one was given, and otherwise to `value.dtype`.
        """
        if expected_dtype is None:
            expected_dtype = dtype if dtype is not None else value.dtype
        result = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            value, dtype, preferred_dtype)
        # Ragged structure, element dtype and values must all be preserved.
        self.assertEqual(value.ragged_rank, result.ragged_rank)
        self.assertEqual(dtypes.as_dtype(expected_dtype), result.dtype)
        self.assertEqual(value.to_list(), self.eval_to_list(result))

    @parameterized.parameters([
        dict(value=ragged_factory_ops.constant_value([['a', 'b'], ['c']],
                                                     dtype=str),
             dtype=dtypes.int32,
             message=r"invalid literal for int\(\) with base 10: 'a'"),
    ])
    def testConvertRaggedTensorValueError(self,
                                          value,
                                          message,
                                          dtype=None,
                                          preferred_dtype=None):
        """Checks that an unconvertible RaggedTensorValue raises ValueError.

        Args:
          value: RaggedTensorValue whose elements cannot be cast to `dtype`.
          message: Regex the raised ValueError's message must match.
          dtype: Requested dtype for the conversion.
          preferred_dtype: Soft dtype hint passed to the conversion.
        """
        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
        # alias, which was removed from unittest in Python 3.12.
        with self.assertRaisesRegex(ValueError, message):
            ragged_tensor.convert_to_tensor_or_ragged_tensor(
                value, dtype, preferred_dtype)

    #=============================================================================
    # Tests where the 'value' param is a Tensor
    #=============================================================================
    @parameterized.parameters([
        {'pylist': [[1, 2], [3, 4]]},
        {'pylist': [[1, 2], [3, 4]], 'preferred_dtype': dtypes.float32},
        {'pylist': [[1, 2], [3, 4]], 'preferred_dtype': dtypes.string},
    ])
    def testConvertTensor(self, pylist, dtype=None, preferred_dtype=None):
        """A dense Tensor input must come back as the very same object."""
        original = constant_op.constant(pylist)
        result = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            original, dtype, preferred_dtype)
        # Identity (not just equality): no copy or cast may be introduced.
        self.assertIs(original, result)

    @parameterized.parameters([
        dict(pylist=[[1, 2], [3, 4]],
             dtype=dtypes.float32,
             message=('Tensor conversion requested dtype float32 for '
                      'Tensor with dtype int32')),
        dict(pylist=[[1, 2], [3, 4]],
             dtype=dtypes.string,
             message=('Tensor conversion requested dtype string for '
                      'Tensor with dtype int32')),
    ])
    def testConvertTensorError(self,
                               pylist,
                               message,
                               dtype=None,
                               preferred_dtype=None):
        """Checks that requesting a mismatched dtype for a Tensor fails.

        Args:
          pylist: Nested list used to build the input dense Tensor.
          message: Regex the raised ValueError's message must match.
          dtype: Requested dtype that differs from the Tensor's dtype.
          preferred_dtype: Soft dtype hint passed to the conversion.
        """
        tensor = constant_op.constant(pylist)
        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
        # alias, which was removed from unittest in Python 3.12.
        with self.assertRaisesRegex(ValueError, message):
            ragged_tensor.convert_to_tensor_or_ragged_tensor(
                tensor, dtype, preferred_dtype)

    #=============================================================================
    # Tests where the 'value' param is a np.array
    #=============================================================================
    @parameterized.parameters([
        {'value': np.array([[1, 2], [3, 4]], dtype=np.int32),
         'expected_dtype': dtypes.int32},
        {'value': np.array([[b'a', b'b'], [b'c', b'd']]),
         'expected_dtype': dtypes.string},
        {'value': np.array([[1, 2], [3, 4]], dtype=np.int32),
         'dtype': dtypes.float32,
         'expected_dtype': dtypes.float32},
        {'value': np.array([[1, 2], [3, 4]], dtype=np.int32),
         'preferred_dtype': dtypes.float32,
         'expected_dtype': dtypes.float32},
        {'value': np.array([[1, 2], [3, 4]], dtype=np.int32),
         'preferred_dtype': dtypes.string,
         'expected_dtype': dtypes.int32},
    ])
    def testConvertNumpyArray(self,
                              value,
                              dtype=None,
                              preferred_dtype=None,
                              expected_dtype=None):
        """Converts a NumPy array and verifies the resulting dtype and values.

        If `expected_dtype` is omitted, it defaults to the requested `dtype`
        when one was given, and otherwise to `value.dtype`.
        """
        if expected_dtype is None:
            expected_dtype = dtype if dtype is not None else value.dtype
        result = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            value, dtype, preferred_dtype)
        want_dtype = dtypes.as_dtype(expected_dtype)
        self.assertEqual(want_dtype, result.dtype)
        self.assertAllEqual(value, result)

    @parameterized.parameters([
        dict(value=np.array([['a', 'b'], ['c', 'd']], dtype=str),
             dtype=dtypes.int32,
             message=r"invalid literal for int\(\) with base 10: 'a'"),
    ])
    def testConvertNumpyArrayError(self,
                                   value,
                                   message,
                                   dtype=None,
                                   preferred_dtype=None):
        """Checks that an unconvertible NumPy array raises ValueError.

        Args:
          value: NumPy array whose elements cannot be cast to `dtype`.
          message: Regex the raised ValueError's message must match.
          dtype: Requested dtype for the conversion.
          preferred_dtype: Soft dtype hint passed to the conversion.
        """
        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
        # alias, which was removed from unittest in Python 3.12.
        with self.assertRaisesRegex(ValueError, message):
            ragged_tensor.convert_to_tensor_or_ragged_tensor(
                value, dtype, preferred_dtype)