def testConstruction(self):
  """Checks RaggedTensorSpec construction with a WrappedTensorSpec.

  Verifies that both the private fields and the public properties reflect
  the constructor arguments, and that a `dtype` that disagrees with
  `flat_values_spec.dtype` is rejected with a ValueError.
  """
  flat_values_spec = WrappedTensorSpec(
      tensor_spec.TensorSpec(shape=(None, 5), dtype=dtypes.float32))
  spec1 = RaggedTensorSpec(
      shape=None,
      dtype=dtypes.float32,
      ragged_rank=1,
      row_splits_dtype=dtypes.int64,
      flat_values_spec=flat_values_spec)

  # Private attributes.
  self.assertIsNone(spec1._shape.rank)
  self.assertEqual(spec1._dtype, dtypes.float32)
  self.assertEqual(spec1._row_splits_dtype, dtypes.int64)
  self.assertEqual(spec1._ragged_rank, 1)
  self.assertEqual(spec1._flat_values_spec, flat_values_spec)

  # Public properties.
  self.assertIsNone(spec1.shape.rank)
  self.assertEqual(spec1.dtype, dtypes.float32)
  self.assertEqual(spec1.row_splits_dtype, dtypes.int64)
  self.assertEqual(spec1.ragged_rank, 1)
  self.assertEqual(spec1.flat_values_spec, flat_values_spec)

  # A dtype (float64) that conflicts with flat_values_spec's dtype (float32)
  # must be rejected. The '.' is escaped so the pattern matches the literal
  # attribute name, and the result is not bound so `spec1` is not rebound.
  with self.assertRaisesRegex(
      ValueError, r'dtype must be the same as flat_values_spec\.dtype'):
    RaggedTensorSpec(
        shape=None,
        dtype=dtypes.float64,
        ragged_rank=1,
        row_splits_dtype=dtypes.int64,
        flat_values_spec=flat_values_spec)
def testFromValue(self):
  """type_spec_from_value on a RaggedTensor whose flat values are wrapped."""
  wrapped_values = WrappedTensor(
      constant_op.constant([[1.0, 2], [4, 5], [7, 8]]))
  splits = constant_op.constant([0, 2, 3, 3, 3], dtypes.int32)
  ragged = RaggedTensor.from_row_splits(wrapped_values, splits)
  actual_spec = type_spec.type_spec_from_value(ragged)

  # Four rows (five row-split entries) over three [2]-wide values.
  expected_flat_values_spec = WrappedTensor.Spec(
      tensor_spec.TensorSpec([None, 2], dtypes.float32))
  expected_spec = RaggedTensorSpec(
      shape=[4, None, 2],
      dtype=dtypes.float32,
      ragged_rank=1,
      row_splits_dtype=dtypes.int32,
      flat_values_spec=expected_flat_values_spec)
  self.assertEqual(actual_spec, expected_spec)

  # The part of the ragged shape starting at ragged_rank must agree with the
  # shape recorded in flat_values_spec.
  trailing_shape = actual_spec.shape[actual_spec.ragged_rank:]
  self.assertEqual(trailing_shape, actual_spec.flat_values_spec.shape)
def testIsCompatibleWith(self):
  """is_compatible_with across specs differing in dtype, rank and values."""

  def wrapped(shape, dtype):
    # Flat-values spec backed by WrappedTensor rather than a plain Tensor.
    return WrappedTensorSpec(tensor_spec.TensorSpec(shape, dtype))

  spec1 = RaggedTensorSpec([32, None, None], dtypes.float32, 2,
                           flat_values_spec=wrapped([None, None],
                                                    dtypes.float32))
  spec2 = RaggedTensorSpec(None, dtypes.float32, 2,
                           flat_values_spec=wrapped(None, dtypes.float32))
  spec3 = RaggedTensorSpec(None, dtypes.int32, 1,
                           flat_values_spec=wrapped(None, dtypes.int32))
  spec4 = RaggedTensorSpec([None], dtypes.int32, 0,
                           flat_values_spec=wrapped(None, dtypes.int32))
  # Same shape/dtype/rank as spec4 but with the default flat_values_spec.
  spec5 = RaggedTensorSpec([None], dtypes.int32, 0)

  self.assertTrue(spec1.is_compatible_with(spec2))
  incompatible_pairs = [
      (spec1, spec3),
      (spec1, spec4),
      (spec2, spec3),
      (spec2, spec4),
      (spec3, spec4),
      (spec4, spec5),
  ]
  for lhs, rhs in incompatible_pairs:
    self.assertFalse(lhs.is_compatible_with(rhs))

  # A plain tensor does not match spec4, but the same tensor wrapped in
  # WrappedTensor does.
  dense = constant_op.constant([1, 2, 3])
  self.assertFalse(spec4.is_compatible_with(dense))
  self.assertTrue(spec4.is_compatible_with(WrappedTensor(dense)))
class RaggedTensorSpecSupportedValuesTest(test_util.TensorFlowTestCase,
                                          parameterized.TestCase):
  """Tests RaggedTensorSpec whose flat values are (mostly) WrappedTensors."""

  def assertAllTensorsEqual(self, list1, list2):
    """Asserts equal length and pairwise assertAllEqual on two sequences."""
    self.assertLen(list1, len(list2))
    for (t1, t2) in zip(list1, list2):
      self.assertAllEqual(t1, t2)

  def testConstruction(self):
    """Checks RaggedTensorSpec construction with a WrappedTensorSpec.

    Verifies that both the private fields and the public properties reflect
    the constructor arguments, and that a `dtype` that disagrees with
    `flat_values_spec.dtype` is rejected with a ValueError.
    """
    flat_values_spec = WrappedTensorSpec(
        tensor_spec.TensorSpec(shape=(None, 5), dtype=dtypes.float32))
    spec1 = RaggedTensorSpec(
        shape=None,
        dtype=dtypes.float32,
        ragged_rank=1,
        row_splits_dtype=dtypes.int64,
        flat_values_spec=flat_values_spec)

    # Private attributes.
    self.assertIsNone(spec1._shape.rank)
    self.assertEqual(spec1._dtype, dtypes.float32)
    self.assertEqual(spec1._row_splits_dtype, dtypes.int64)
    self.assertEqual(spec1._ragged_rank, 1)
    self.assertEqual(spec1._flat_values_spec, flat_values_spec)

    # Public properties.
    self.assertIsNone(spec1.shape.rank)
    self.assertEqual(spec1.dtype, dtypes.float32)
    self.assertEqual(spec1.row_splits_dtype, dtypes.int64)
    self.assertEqual(spec1.ragged_rank, 1)
    self.assertEqual(spec1.flat_values_spec, flat_values_spec)

    # A dtype (float64) that conflicts with flat_values_spec's dtype (float32)
    # must be rejected. The '.' is escaped so the pattern matches the literal
    # attribute name, and the result is not bound so `spec1` is not rebound.
    with self.assertRaisesRegex(
        ValueError, r'dtype must be the same as flat_values_spec\.dtype'):
      RaggedTensorSpec(
          shape=None,
          dtype=dtypes.float64,
          ragged_rank=1,
          row_splits_dtype=dtypes.int64,
          flat_values_spec=flat_values_spec)

  @parameterized.parameters([
      (RaggedTensorSpec(ragged_rank=1,
                        flat_values_spec=tensor_spec.TensorSpec(
                            None, dtypes.float32)),
       (tensor_shape.TensorShape(None), dtypes.float32, 1, dtypes.int64,
        tensor_spec.TensorSpec(None, dtypes.float32))),
      (RaggedTensorSpec(shape=(5, None, 5),
                        ragged_rank=1,
                        dtype=dtypes.float64,
                        flat_values_spec=tensor_spec.TensorSpec(
                            (5,), dtypes.float64)),
       (tensor_shape.TensorShape((5, None, 5)), dtypes.float64, 1,
        dtypes.int64, tensor_spec.TensorSpec((5,), dtypes.float64))),
  ])
  def testSerialize(self, rt_spec, expected):
    """_serialize returns (shape, dtype, ragged_rank, splits dtype, fv spec)."""
    serialization = rt_spec._serialize()
    # TensorShape has an unconventional definition of equality, so we can't use
    # assertEqual directly here. But repr() is deterministic and lossless for
    # the expected values, so we can use that instead.
    self.assertEqual(repr(serialization), repr(expected))

  @parameterized.parameters([
      (RaggedTensorSpec(ragged_rank=0,
                        shape=[5, 3],
                        flat_values_spec=WrappedTensorSpec(
                            tensor_spec.TensorSpec([5, 3], dtypes.float32))),
       [WrappedTensorSpec(tensor_spec.TensorSpec([5, 3], dtypes.float32))]),
      (RaggedTensorSpec(ragged_rank=1,
                        flat_values_spec=WrappedTensorSpec(
                            tensor_spec.TensorSpec([None, 3], dtypes.float32))),
       [
           WrappedTensorSpec(
               tensor_spec.TensorSpec([None, 3], dtypes.float32)),
           tensor_spec.TensorSpec([None], dtypes.int64),
       ]),
      (RaggedTensorSpec(ragged_rank=2,
                        dtype=dtypes.float64,
                        flat_values_spec=WrappedTensorSpec(
                            tensor_spec.TensorSpec([None, 3], dtypes.float64))),
       [
           WrappedTensorSpec(
               tensor_spec.TensorSpec([None, 3], dtypes.float64)),
           tensor_spec.TensorSpec([None], dtypes.int64),
           tensor_spec.TensorSpec([None], dtypes.int64),
       ]),
      (RaggedTensorSpec(shape=[5, None, None],
                        dtype=dtypes.string,
                        flat_values_spec=WrappedTensorSpec(
                            tensor_spec.TensorSpec([None, 3], dtypes.string))),
       [
           WrappedTensorSpec(tensor_spec.TensorSpec([None, 3], dtypes.string)),
           tensor_spec.TensorSpec([6], dtypes.int64),
           tensor_spec.TensorSpec([None], dtypes.int64),
       ]),
  ])
  def testComponentSpecs(self, rt_spec, expected):
    """_component_specs: flat-values spec, then one row-splits spec per rank.

    When the outer dimension is known (5 in the last case), the first
    row-splits spec has static length outer_dim + 1.
    """
    self.assertEqual(rt_spec._component_specs, expected)

  @parameterized.parameters([
      {
          'rt_spec':
              RaggedTensorSpec(shape=[3, None, None],
                               ragged_rank=1,
                               flat_values_spec=WrappedTensorSpec(
                                   tensor_spec.TensorSpec(
                                       None, dtype=dtypes.float32))),
          'flat_values': [[1.0, 2.0], [3.0, 4.0]],
          'nested_row_splits': [[0, 1, 1, 2]],
      },
      {
          'rt_spec':
              RaggedTensorSpec(shape=[2, None, None],
                               flat_values_spec=WrappedTensorSpec(
                                   tensor_spec.TensorSpec(
                                       None, dtype=dtypes.float32))),
          'flat_values': [1.0, 2.0, 3.0, 4.0],
          'nested_row_splits': [[0, 2, 4], [0, 2, 3, 3, 4]],
      },
  ])
  def testToFromComponents(self, rt_spec, flat_values, nested_row_splits):
    """Round-trips through _to_components/_from_components.

    Checks that the WrappedTensor flat values survive both directions and that
    the row splits and dtype are preserved.
    """
    wrapped_tensor = WrappedTensor(constant_op.constant(flat_values))
    rt = RaggedTensor.from_nested_row_splits(wrapped_tensor, nested_row_splits)

    # Component 0 is the flat values; the rest are the nested row splits.
    components = rt_spec._to_components(rt)
    self.assertIsInstance(components[0], WrappedTensor)
    self.assertAllEqual(components[0].value, wrapped_tensor.value)
    self.assertAllTensorsEqual(components[1:], nested_row_splits)

    rt_reconstructed = rt_spec._from_components(components)
    self.assertIsInstance(rt_reconstructed.flat_values, WrappedTensor)
    self.assertAllEqual(rt_reconstructed.flat_values.value,
                        wrapped_tensor.value)
    self.assertAllTensorsEqual(rt_reconstructed.nested_row_splits,
                               rt.nested_row_splits)
    self.assertEqual(rt_reconstructed.dtype, rt.dtype)

  def testIsCompatibleWith(self):
    """is_compatible_with across specs differing in dtype, rank and values."""
    spec1 = RaggedTensorSpec([32, None, None], dtypes.float32, 2,
                             flat_values_spec=WrappedTensorSpec(
                                 tensor_spec.TensorSpec([None, None],
                                                        dtypes.float32)))
    spec2 = RaggedTensorSpec(None, dtypes.float32, 2,
                             flat_values_spec=WrappedTensorSpec(
                                 tensor_spec.TensorSpec(None, dtypes.float32)))
    spec3 = RaggedTensorSpec(None, dtypes.int32, 1,
                             flat_values_spec=WrappedTensorSpec(
                                 tensor_spec.TensorSpec(None, dtypes.int32)))
    spec4 = RaggedTensorSpec([None], dtypes.int32, 0,
                             flat_values_spec=WrappedTensorSpec(
                                 tensor_spec.TensorSpec(None, dtypes.int32)))
    # Same shape/dtype/rank as spec4 but with the default flat_values_spec.
    spec5 = RaggedTensorSpec([None], dtypes.int32, 0)

    self.assertTrue(spec1.is_compatible_with(spec2))
    self.assertFalse(spec1.is_compatible_with(spec3))
    self.assertFalse(spec1.is_compatible_with(spec4))
    self.assertFalse(spec2.is_compatible_with(spec3))
    self.assertFalse(spec2.is_compatible_with(spec4))
    self.assertFalse(spec3.is_compatible_with(spec4))
    self.assertFalse(spec4.is_compatible_with(spec5))

    # A plain tensor does not match spec4; the same tensor wrapped does.
    value = constant_op.constant([1, 2, 3])
    self.assertFalse(spec4.is_compatible_with(value))
    self.assertTrue(spec4.is_compatible_with(WrappedTensor(value)))

  def testToList(self):
    """RaggedTensor.to_list with WrappedTensor-based flat values.

    Plain WrappedTensor flat values raise ValueError; subclasses that add a
    `numpy()` or a `to_list()` method produce the expected nested list.
    """
    with context.eager_mode():
      tensor_values = constant_op.constant(
          ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'])
      row_splits = constant_op.constant([0, 2, 2, 5, 6, 8], dtypes.int64)
      values = WrappedTensor(tensor_values)
      rt = RaggedTensor.from_row_splits(values, row_splits)
      expected = ragged_factory_ops.constant(
          [['a', 'b'], [], ['c', 'd', 'e'], ['f'], ['g', 'h']]).to_list()

      with self.subTest('Raise on unsupported'):
        with self.assertRaisesRegex(
            ValueError,
            'values must be convertible to a list',
        ):
          _ = rt.to_list()

      with self.subTest('Value with numpy method'):

        class WrappedTensorWithNumpy(WrappedTensor):

          def numpy(self):
            return self.value.numpy()

        values = WrappedTensorWithNumpy(tensor_values)
        rt = RaggedTensor.from_row_splits(values, row_splits)
        self.assertEqual(rt.to_list(), expected)

      with self.subTest('Value with to_list method'):

        class WrappedTensorWithToList(WrappedTensor):

          def to_list(self):
            return self.value.numpy().tolist()

        values = WrappedTensorWithToList(tensor_values)
        rt = RaggedTensor.from_row_splits(values, row_splits)
        self.assertEqual(rt.to_list(), expected)