Example #1
0
    def testSplitV1(self, input, expected, input_is_ragged=False, **kwargs):  # pylint: disable=redefined-builtin
        """Checks `strings_split_v1` against `expected` for both result types.

        The RaggedTensor result is checked for each way the input can be
        passed (positionally, as `input=`, and as `source=`). The
        SparseTensor result is checked one component at a time.

        Args:
          input: Python list (possibly nested) of strings to split.
          expected: Expected split result as nested Python lists.
          input_is_ragged: If True, `input` is built with
            `ragged_factory_ops.constant`; otherwise with
            `constant_op.constant`.
          **kwargs: Extra keyword arguments forwarded to every
            `strings_split_v1` call.
        """
        make_tensor = (ragged_factory_ops.constant if input_is_ragged
                       else constant_op.constant)
        input_tensor = make_tensor(input, dtype=dtypes.string)

        expected_ragged = ragged_factory_ops.constant(expected)

        # The same input may be passed positionally or under either keyword.
        ragged_results = [
            ragged_string_ops.strings_split_v1(
                input_tensor, result_type="RaggedTensor", **kwargs),
            ragged_string_ops.strings_split_v1(
                input=input_tensor, result_type="RaggedTensor", **kwargs),
            ragged_string_ops.strings_split_v1(
                source=input_tensor, result_type="RaggedTensor", **kwargs),
        ]
        for actual_ragged in ragged_results:
            self.assertAllEqual(expected_ragged, actual_ragged)

        # Compare the SparseTensor result one component at a time.
        expected_sparse = self.evaluate(expected_ragged.to_sparse())
        actual_sparse = ragged_string_ops.strings_split_v1(
            input_tensor, result_type="SparseTensor", **kwargs)
        for component in ("indices", "values", "dense_shape"):
            self.assertEqual(
                getattr(expected_sparse, component).tolist(),
                self.evaluate(getattr(actual_sparse, component)).tolist())
Example #2
0
    def testSplitV2(self, input, expected, input_is_ragged=False, **kwargs):  # pylint: disable=redefined-builtin
        """Checks the v1 and v2 string-split ops against `expected`.

        First verifies that `expected` matches Python's `str.split` (via
        `self._py_split`). Then it compares the RaggedTensor-returning public
        ops with `expected`. For vector inputs only, it also compares the
        SparseTensor-returning ops.

        Args:
          input: Python list (possibly nested) of strings to split.
          expected: Expected split result as nested Python lists.
          input_is_ragged: If True, `input` is built with
            `ragged_factory_ops.constant`; otherwise with
            `constant_op.constant`.
          **kwargs: Extra keyword arguments forwarded to every split op and
            to `self._py_split`.
        """
        # `expected` itself must agree with Python's str.split.
        self.assertEqual(expected, self._py_split(input, **kwargs))

        if input_is_ragged:
            input_tensor = ragged_factory_ops.constant(
                input, dtype=dtypes.string)
        else:
            input_tensor = constant_op.constant(input, dtype=dtypes.string)
        rank = input_tensor.shape.ndims

        # Public ops: both return RaggedTensors.
        expected_ragged = ragged_factory_ops.constant(
            expected, ragged_rank=rank)
        ragged_results = (
            ragged_string_ops.strings_split_v1(
                input_tensor, result_type="RaggedTensor", **kwargs),
            ragged_string_ops.string_split_v2(input_tensor, **kwargs),
        )
        for actual_ragged in ragged_results:
            self.assertRaggedEqual(expected_ragged, actual_ragged)

        # Internal ops return SparseTensors and only support vector inputs,
        # so nothing more is checked for higher-rank inputs.
        if rank != 1:
            return
        expected_sparse = self.evaluate(expected_ragged.to_sparse())
        sparse_results = (
            ragged_string_ops.strings_split_v1(
                input_tensor, result_type="SparseTensor", **kwargs),
            string_ops.string_split_v2(input_tensor, **kwargs),
        )
        for actual_sparse in sparse_results:
            for component in ("indices", "values", "dense_shape"):
                self.assertEqual(
                    getattr(expected_sparse, component).tolist(),
                    self.evaluate(getattr(actual_sparse, component)).tolist())
  def testSplitV2(self,
                  input,
                  expected,
                  input_is_ragged=False,
                  **kwargs):  # pylint: disable=redefined-builtin
    """Checks the v1 and v2 string-split ops against `expected`.

    First verifies that `expected` matches Python's `str.split` (via
    `self._py_split`). Then it checks that the RaggedTensor-returning public
    ops give `expected` when the input is passed positionally and under each
    keyword name they accept. For vector inputs only, it also compares the
    SparseTensor-returning ops.

    Args:
      input: Python list (possibly nested) of strings to split.
      expected: Expected split result as nested Python lists.
      input_is_ragged: If True, `input` is built with
        `ragged_factory_ops.constant`; otherwise with `constant_op.constant`.
      **kwargs: Extra keyword arguments forwarded to every split op and to
        `self._py_split`.
    """
    # `expected` itself must agree with Python's str.split.
    self.assertEqual(expected, self._py_split(input, **kwargs))

    make_tensor = (
        ragged_factory_ops.constant if input_is_ragged else constant_op.constant)
    input_tensor = make_tensor(input, dtype=dtypes.string)
    rank = input_tensor.shape.ndims

    # Public ops (RaggedTensor results), called with each supported way of
    # passing the input argument.
    expected_ragged = ragged_factory_ops.constant(expected, ragged_rank=rank)
    ragged_results = [
        ragged_string_ops.strings_split_v1(
            input_tensor, result_type="RaggedTensor", **kwargs),
        ragged_string_ops.strings_split_v1(
            input=input_tensor, result_type="RaggedTensor", **kwargs),
        ragged_string_ops.strings_split_v1(
            source=input_tensor, result_type="RaggedTensor", **kwargs),
        ragged_string_ops.string_split_v2(input_tensor, **kwargs),
        ragged_string_ops.string_split_v2(input=input_tensor, **kwargs),
    ]
    for actual_ragged in ragged_results:
      self.assertRaggedEqual(expected_ragged, actual_ragged)

    # Internal ops return SparseTensors and only support vector inputs, so
    # nothing more is checked for higher-rank inputs.
    if rank != 1:
      return
    expected_sparse = self.evaluate(expected_ragged.to_sparse())
    sparse_results = (
        ragged_string_ops.strings_split_v1(
            input_tensor, result_type="SparseTensor", **kwargs),
        string_ops.string_split_v2(input_tensor, **kwargs),
    )
    for actual_sparse in sparse_results:
      for component in ("indices", "values", "dense_shape"):
        self.assertEqual(
            getattr(expected_sparse, component).tolist(),
            self.evaluate(getattr(actual_sparse, component)).tolist())
Example #4
0
 def testSplitV1BadResultType(self):
     """An unsupported `result_type` makes strings_split_v1 raise ValueError."""
     # Callable form of assertRaisesRegex: the op is invoked with the given
     # positional and keyword arguments inside the assertion.
     self.assertRaisesRegex(
         ValueError, "result_type must be .*",
         ragged_string_ops.strings_split_v1, "foo",
         result_type="BouncyTensor")