def testSplitV1(self, input, expected, input_is_ragged=False, **kwargs):  # pylint: disable=redefined-builtin
  """Checks `strings_split_v1` against `expected` for both result types.

  The split is requested as a RaggedTensor three ways (tensor passed
  positionally, as `input=`, and as `source=`) and once as a SparseTensor;
  every result is compared with `expected`.

  Args:
    input: Python value converted to a string tensor before splitting.
    expected: Nested Python list; converted with
      `ragged_factory_ops.constant` to form the expected result.
    input_is_ragged: If true, `input` is built with
      `ragged_factory_ops.constant`; otherwise with `constant_op.constant`.
    **kwargs: Forwarded unchanged to every `strings_split_v1` call.
  """
  # Wrap the raw Python input in a string tensor (ragged or dense).
  make_tensor = (
      ragged_factory_ops.constant if input_is_ragged
      else constant_op.constant)
  input = make_tensor(input, dtype=dtypes.string)

  want_ragged = ragged_factory_ops.constant(expected)
  split_v1 = ragged_string_ops.strings_split_v1

  # RaggedTensor results: the tensor is supplied positionally, then through
  # the `input` keyword, then through the `source` keyword.
  ragged_results = [
      split_v1(input, result_type="RaggedTensor", **kwargs),
      split_v1(input=input, result_type="RaggedTensor", **kwargs),
      split_v1(source=input, result_type="RaggedTensor", **kwargs),
  ]
  for got_ragged in ragged_results:
    self.assertAllEqual(want_ragged, got_ragged)

  # SparseTensor result: compare component by component against the sparse
  # form of the expected ragged value.
  want_sparse = self.evaluate(want_ragged.to_sparse())
  got_sparse = split_v1(input, result_type="SparseTensor", **kwargs)
  for field in ("indices", "values", "dense_shape"):
    self.assertEqual(
        getattr(want_sparse, field).tolist(),
        self.evaluate(getattr(got_sparse, field)).tolist())
def testSplitV2(self, input, expected, input_is_ragged=False, **kwargs):  # pylint: disable=redefined-builtin
  """Checks the v1 and v2 split ops against `expected`.

  NOTE(review): another `testSplitV2` definition appears later in the
  visible source. If both live in the same class, the later definition
  replaces this one -- confirm which is intended.

  Args:
    input: Python value; first passed to `self._py_split`, then converted
      to a string tensor.
    expected: Nested Python list holding the expected split result.
    input_is_ragged: If true, `input` is built with
      `ragged_factory_ops.constant`; otherwise with `constant_op.constant`.
    **kwargs: Forwarded unchanged to `self._py_split` and to every split op.
  """
  # The expectation itself must match the behavior of Python's str.split.
  py_result = self._py_split(input, **kwargs)
  self.assertEqual(expected, py_result)

  # Wrap the raw Python input in a string tensor (ragged or dense).
  make_tensor = (
      ragged_factory_ops.constant if input_is_ragged
      else constant_op.constant)
  input = make_tensor(input, dtype=dtypes.string)

  # Public versions (which return a RaggedTensor).
  want_ragged = ragged_factory_ops.constant(
      expected, ragged_rank=input.shape.ndims)
  got_ragged_v1 = ragged_string_ops.strings_split_v1(
      input, result_type="RaggedTensor", **kwargs)
  got_ragged_v2 = ragged_string_ops.string_split_v2(input, **kwargs)
  for got_ragged in (got_ragged_v1, got_ragged_v2):
    self.assertRaggedEqual(want_ragged, got_ragged)

  # Internal versions (which return a SparseTensor). Note: the internal
  # version only supports vector inputs, so other ranks stop here.
  if input.shape.ndims != 1:
    return

  want_sparse = self.evaluate(want_ragged.to_sparse())
  got_sparse_v1 = ragged_string_ops.strings_split_v1(
      input, result_type="SparseTensor", **kwargs)
  got_sparse_v2 = string_ops.string_split_v2(input, **kwargs)
  for got_sparse in (got_sparse_v1, got_sparse_v2):
    for field in ("indices", "values", "dense_shape"):
      self.assertEqual(
          getattr(want_sparse, field).tolist(),
          self.evaluate(getattr(got_sparse, field)).tolist())
def testSplitV2(self, input, expected, input_is_ragged=False, **kwargs):  # pylint: disable=redefined-builtin
  """Checks the v1 and v2 split ops, including keyword-argument call forms.

  NOTE(review): an earlier `testSplitV2` definition is also visible in the
  source. If both live in the same class, this later definition is the one
  that takes effect -- confirm the earlier one is not still needed.

  Args:
    input: Python value; first passed to `self._py_split`, then converted
      to a string tensor.
    expected: Nested Python list holding the expected split result.
    input_is_ragged: If true, `input` is built with
      `ragged_factory_ops.constant`; otherwise with `constant_op.constant`.
    **kwargs: Forwarded unchanged to `self._py_split` and to every split op.
  """
  # The expectation itself must match the behavior of Python's str.split.
  py_result = self._py_split(input, **kwargs)
  self.assertEqual(expected, py_result)

  # Wrap the raw Python input in a string tensor (ragged or dense).
  make_tensor = (
      ragged_factory_ops.constant if input_is_ragged
      else constant_op.constant)
  input = make_tensor(input, dtype=dtypes.string)

  # Public versions (which return a RaggedTensor). Each op is called with
  # the tensor passed positionally and by keyword; v1 is additionally called
  # with the `source` keyword.
  want_ragged = ragged_factory_ops.constant(
      expected, ragged_rank=input.shape.ndims)
  split_v1 = ragged_string_ops.strings_split_v1
  split_v2 = ragged_string_ops.string_split_v2
  ragged_results = [
      split_v1(input, result_type="RaggedTensor", **kwargs),
      split_v1(input=input, result_type="RaggedTensor", **kwargs),
      split_v1(source=input, result_type="RaggedTensor", **kwargs),
      split_v2(input, **kwargs),
      split_v2(input=input, **kwargs),
  ]
  for got_ragged in ragged_results:
    self.assertRaggedEqual(want_ragged, got_ragged)

  # Internal versions (which return a SparseTensor). Note: the internal
  # version only supports vector inputs, so other ranks stop here.
  if input.shape.ndims != 1:
    return

  want_sparse = self.evaluate(want_ragged.to_sparse())
  got_sparse_v1 = split_v1(input, result_type="SparseTensor", **kwargs)
  got_sparse_v2 = string_ops.string_split_v2(input, **kwargs)
  for got_sparse in (got_sparse_v1, got_sparse_v2):
    for field in ("indices", "values", "dense_shape"):
      self.assertEqual(
          getattr(want_sparse, field).tolist(),
          self.evaluate(getattr(got_sparse, field)).tolist())
def testSplitV1BadResultType(self):
  """Checks that `strings_split_v1` raises ValueError for a bad result_type."""
  # "BouncyTensor" is not "RaggedTensor" or "SparseTensor", the two values
  # the other tests in this file pass as `result_type`.
  error_pattern = "result_type must be .*"
  with self.assertRaisesRegex(ValueError, error_pattern):
    ragged_string_ops.strings_split_v1("foo", result_type="BouncyTensor")