# Example #1
0
    def testConstruction(self):
        """Checks RaggedTensorSpec construction with a wrapped flat_values_spec.

        Verifies that the private attributes and the public properties hold the
        constructor arguments, and that an explicit dtype that disagrees with
        flat_values_spec.dtype raises ValueError.
        """
        flat_values_spec = WrappedTensorSpec(
            tensor_spec.TensorSpec(shape=(None, 5), dtype=dtypes.float32))
        spec1 = RaggedTensorSpec(shape=None,
                                 dtype=dtypes.float32,
                                 ragged_rank=1,
                                 row_splits_dtype=dtypes.int64,
                                 flat_values_spec=flat_values_spec)
        # Private attributes as stored by the constructor.
        self.assertIsNone(spec1._shape.rank)
        self.assertEqual(spec1._dtype, dtypes.float32)
        self.assertEqual(spec1._row_splits_dtype, dtypes.int64)
        self.assertEqual(spec1._ragged_rank, 1)
        self.assertEqual(spec1._flat_values_spec, flat_values_spec)

        # Public accessors must expose the same values.
        self.assertIsNone(spec1.shape.rank)
        self.assertEqual(spec1.dtype, dtypes.float32)
        self.assertEqual(spec1.row_splits_dtype, dtypes.int64)
        self.assertEqual(spec1.ragged_rank, 1)
        self.assertEqual(spec1.flat_values_spec, flat_values_spec)

        # float64 conflicts with flat_values_spec's float32, so construction
        # must fail. The message is matched as a regular expression, so the
        # dot is escaped. The result is not bound because the call raises.
        with self.assertRaisesRegex(
                ValueError,
                r'dtype must be the same as flat_values_spec\.dtype'):
            RaggedTensorSpec(shape=None,
                             dtype=dtypes.float64,
                             ragged_rank=1,
                             row_splits_dtype=dtypes.int64,
                             flat_values_spec=flat_values_spec)

    def testFromValue(self):
        """Infers a RaggedTensorSpec from a RaggedTensor of WrappedTensor rows."""
        # Three rows of two floats each, wrapped in the custom value type.
        raw_rows = constant_op.constant([[1.0, 2], [4, 5], [7, 8]])
        wrapped_rows = WrappedTensor(raw_rows)

        # Four ragged rows holding 2, 1, 0 and 0 of the wrapped rows.
        splits = constant_op.constant([0, 2, 3, 3, 3], dtypes.int32)
        ragged = RaggedTensor.from_row_splits(wrapped_rows, splits)

        inferred_spec = type_spec.type_spec_from_value(ragged)
        expected_flat_spec = WrappedTensor.Spec(
            tensor_spec.TensorSpec([None, 2], dtypes.float32))
        expected_spec = RaggedTensorSpec(shape=[4, None, 2],
                                         dtype=dtypes.float32,
                                         ragged_rank=1,
                                         row_splits_dtype=dtypes.int32,
                                         flat_values_spec=expected_flat_spec)
        self.assertEqual(inferred_spec, expected_spec)

        # The dimensions from ragged_rank onward must match the shape stored
        # in flat_values_spec.
        inner_dims = inferred_spec.shape[inferred_spec.ragged_rank:]
        self.assertEqual(inner_dims, inferred_spec.flat_values_spec.shape)
# Example #3
0
    def testIsCompatibleWith(self):
        """Checks is_compatible_with for specs carrying a wrapped flat spec."""
        def wrapped(shape, dtype):
            # Builds the WrappedTensorSpec used as flat_values_spec below.
            return WrappedTensorSpec(tensor_spec.TensorSpec(shape, dtype))

        spec1 = RaggedTensorSpec([32, None, None], dtypes.float32, 2,
                                 flat_values_spec=wrapped([None, None],
                                                          dtypes.float32))
        spec2 = RaggedTensorSpec(None, dtypes.float32, 2,
                                 flat_values_spec=wrapped(None,
                                                          dtypes.float32))
        spec3 = RaggedTensorSpec(None, dtypes.int32, 1,
                                 flat_values_spec=wrapped(None, dtypes.int32))
        spec4 = RaggedTensorSpec([None], dtypes.int32, 0,
                                 flat_values_spec=wrapped(None, dtypes.int32))
        # Same arguments as spec4 but with no flat_values_spec.
        spec5 = RaggedTensorSpec([None], dtypes.int32, 0)

        # Same dtype and ragged_rank; spec2 leaves the shape unknown.
        self.assertTrue(spec1.is_compatible_with(spec2))

        # Each pair differs in dtype, ragged_rank, shape, or whether a
        # flat_values_spec was given.
        incompatible_pairs = [
            (spec1, spec3),
            (spec1, spec4),
            (spec2, spec3),
            (spec2, spec4),
            (spec3, spec4),
            (spec4, spec5),
        ]
        for lhs, rhs in incompatible_pairs:
            self.assertFalse(lhs.is_compatible_with(rhs))

        # A plain tensor is rejected; the same tensor wrapped is accepted.
        plain = constant_op.constant([1, 2, 3])
        self.assertFalse(spec4.is_compatible_with(plain))
        self.assertTrue(spec4.is_compatible_with(WrappedTensor(plain)))
# Example #4
0
class RaggedTensorSpecSupportedValuesTest(test_util.TensorFlowTestCase,
                                          parameterized.TestCase):
    """Tests RaggedTensorSpec whose flat values are a WrappedTensor type.

    WrappedTensor wraps a tensor (exposed as `.value`) and WrappedTensorSpec
    is its spec; both are used here as the flat values of a RaggedTensor.
    """

    def assertAllTensorsEqual(self, list1, list2):
        """Asserts list1 and list2 have equal length and equal elements."""
        self.assertLen(list1, len(list2))
        for (t1, t2) in zip(list1, list2):
            self.assertAllEqual(t1, t2)

    def testConstruction(self):
        """Checks RaggedTensorSpec construction with a wrapped flat_values_spec.

        Verifies that the private attributes and the public properties hold the
        constructor arguments, and that an explicit dtype that disagrees with
        flat_values_spec.dtype raises ValueError.
        """
        flat_values_spec = WrappedTensorSpec(
            tensor_spec.TensorSpec(shape=(None, 5), dtype=dtypes.float32))
        spec1 = RaggedTensorSpec(shape=None,
                                 dtype=dtypes.float32,
                                 ragged_rank=1,
                                 row_splits_dtype=dtypes.int64,
                                 flat_values_spec=flat_values_spec)
        # Private attributes as stored by the constructor.
        self.assertIsNone(spec1._shape.rank)
        self.assertEqual(spec1._dtype, dtypes.float32)
        self.assertEqual(spec1._row_splits_dtype, dtypes.int64)
        self.assertEqual(spec1._ragged_rank, 1)
        self.assertEqual(spec1._flat_values_spec, flat_values_spec)

        # Public accessors must expose the same values.
        self.assertIsNone(spec1.shape.rank)
        self.assertEqual(spec1.dtype, dtypes.float32)
        self.assertEqual(spec1.row_splits_dtype, dtypes.int64)
        self.assertEqual(spec1.ragged_rank, 1)
        self.assertEqual(spec1.flat_values_spec, flat_values_spec)

        # float64 conflicts with flat_values_spec's float32, so construction
        # must fail. The message is matched as a regular expression, so the
        # dot is escaped. The result is not bound because the call raises.
        with self.assertRaisesRegex(
                ValueError,
                r'dtype must be the same as flat_values_spec\.dtype'):
            RaggedTensorSpec(shape=None,
                             dtype=dtypes.float64,
                             ragged_rank=1,
                             row_splits_dtype=dtypes.int64,
                             flat_values_spec=flat_values_spec)

    @parameterized.parameters([
        # Unknown shape; row_splits_dtype is not given, and the expected tuple
        # shows int64 for it.
        (RaggedTensorSpec(ragged_rank=1,
                          flat_values_spec=tensor_spec.TensorSpec(
                              None, dtypes.float32)),
         (tensor_shape.TensorShape(None), dtypes.float32, 1, dtypes.int64,
          tensor_spec.TensorSpec(None, dtypes.float32))),
        # NOTE(review): flat_values_spec has shape (5,), yet shape (5, None, 5)
        # with ragged_rank=1 suggests rank-2 flat values such as (None, 5).
        # The expected tuple mirrors the input, so this only checks that the
        # spec is carried through _serialize; consider using (None, 5).
        (RaggedTensorSpec(shape=(5, None, 5),
                          ragged_rank=1,
                          dtype=dtypes.float64,
                          flat_values_spec=tensor_spec.TensorSpec(
                              (5, ), dtypes.float64)),
         (tensor_shape.TensorShape(
             (5, None, 5)), dtypes.float64, 1, dtypes.int64,
          tensor_spec.TensorSpec((5, ), dtypes.float64))),
    ])
    def testSerialize(self, rt_spec, expected):
        """Checks _serialize() output, including the trailing flat spec."""
        serialization = rt_spec._serialize()
        # TensorShape has an unconventional definition of equality, so we can't use
        # assertEqual directly here.  But repr() is deterministic and lossless for
        # the expected values, so we can use that instead.
        self.assertEqual(repr(serialization), repr(expected))

    @parameterized.parameters([
        # ragged_rank=0: the flat_values_spec is the only component.
        (RaggedTensorSpec(ragged_rank=0,
                          shape=[5, 3],
                          flat_values_spec=WrappedTensorSpec(
                              tensor_spec.TensorSpec([5, 3], dtypes.float32))),
         [WrappedTensorSpec(tensor_spec.TensorSpec([5, 3], dtypes.float32))]),
        # ragged_rank=1: flat values followed by one row_splits spec.
        (RaggedTensorSpec(ragged_rank=1,
                          flat_values_spec=WrappedTensorSpec(
                              tensor_spec.TensorSpec([None, 3],
                                                     dtypes.float32))),
         [
             WrappedTensorSpec(
                 tensor_spec.TensorSpec([None, 3], dtypes.float32)),
             tensor_spec.TensorSpec([None], dtypes.int64),
         ]),
        # ragged_rank=2: flat values followed by two row_splits specs.
        (RaggedTensorSpec(ragged_rank=2,
                          dtype=dtypes.float64,
                          flat_values_spec=WrappedTensorSpec(
                              tensor_spec.TensorSpec([None, 3],
                                                     dtypes.float64))),
         [
             WrappedTensorSpec(
                 tensor_spec.TensorSpec([None, 3], dtypes.float64)),
             tensor_spec.TensorSpec([None], dtypes.int64),
             tensor_spec.TensorSpec([None], dtypes.int64),
         ]),
        # Known outer dimension 5: the outermost row_splits has length 6.
        # NOTE(review): the flat spec [None, 3] is rank 2, while a rank-3 shape
        # with two row_splits components suggests rank-1 flat values; the spec
        # appears to be returned unchanged, so the test still passes.
        (RaggedTensorSpec(shape=[5, None, None],
                          dtype=dtypes.string,
                          flat_values_spec=WrappedTensorSpec(
                              tensor_spec.TensorSpec([None, 3],
                                                     dtypes.string))),
         [
             WrappedTensorSpec(tensor_spec.TensorSpec([None, 3],
                                                      dtypes.string)),
             tensor_spec.TensorSpec([6], dtypes.int64),
             tensor_spec.TensorSpec([None], dtypes.int64),
         ]),
    ])
    def testComponentSpecs(self, rt_spec, expected):
        """Checks _component_specs: flat values spec, then row_splits specs."""
        self.assertEqual(rt_spec._component_specs, expected)

    @parameterized.parameters([
        {
            'rt_spec':
            RaggedTensorSpec(shape=[3, None, None],
                             ragged_rank=1,
                             flat_values_spec=WrappedTensorSpec(
                                 tensor_spec.TensorSpec(
                                     None, dtype=dtypes.float32))),
            'flat_values': [[1.0, 2.0], [3.0, 4.0]],
            'nested_row_splits': [[0, 1, 1, 2]],
        },
        {
            'rt_spec':
            RaggedTensorSpec(shape=[2, None, None],
                             flat_values_spec=WrappedTensorSpec(
                                 tensor_spec.TensorSpec(
                                     None, dtype=dtypes.float32))),
            'flat_values': [1.0, 2.0, 3.0, 4.0],
            'nested_row_splits': [[0, 2, 4], [0, 2, 3, 3, 4]],
        },
    ])
    def testToFromComponents(self, rt_spec, flat_values, nested_row_splits):
        """Round-trips a RaggedTensor through _to/_from_components.

        Args:
          rt_spec: Spec used to decompose and rebuild the ragged tensor.
          flat_values: Python values wrapped as the WrappedTensor flat values.
          nested_row_splits: Python lists of row splits, outermost first.
        """
        wrapped_tensor = WrappedTensor(constant_op.constant(flat_values))
        rt = RaggedTensor.from_nested_row_splits(wrapped_tensor,
                                                 nested_row_splits)
        # Component 0 is the (still wrapped) flat values; the rest are the
        # row splits.
        components = rt_spec._to_components(rt)
        self.assertIsInstance(components[0], WrappedTensor)
        self.assertAllEqual(components[0].value, wrapped_tensor.value)
        self.assertAllTensorsEqual(components[1:], nested_row_splits)
        # Rebuilding must preserve the wrapper type, splits and dtype.
        rt_reconstructed = rt_spec._from_components(components)
        self.assertIsInstance(rt_reconstructed.flat_values, WrappedTensor)
        self.assertAllEqual(rt_reconstructed.flat_values.value,
                            wrapped_tensor.value)
        self.assertAllTensorsEqual(rt_reconstructed.nested_row_splits,
                                   rt.nested_row_splits)
        self.assertEqual(rt_reconstructed.dtype, rt.dtype)

    def testIsCompatibleWith(self):
        """Checks is_compatible_with for specs carrying a wrapped flat spec."""
        spec1 = RaggedTensorSpec([32, None, None],
                                 dtypes.float32,
                                 2,
                                 flat_values_spec=WrappedTensorSpec(
                                     tensor_spec.TensorSpec([None, None],
                                                            dtypes.float32)))
        spec2 = RaggedTensorSpec(None,
                                 dtypes.float32,
                                 2,
                                 flat_values_spec=WrappedTensorSpec(
                                     tensor_spec.TensorSpec(
                                         None, dtypes.float32)))
        spec3 = RaggedTensorSpec(None,
                                 dtypes.int32,
                                 1,
                                 flat_values_spec=WrappedTensorSpec(
                                     tensor_spec.TensorSpec(
                                         None, dtypes.int32)))
        spec4 = RaggedTensorSpec([None],
                                 dtypes.int32,
                                 0,
                                 flat_values_spec=WrappedTensorSpec(
                                     tensor_spec.TensorSpec(
                                         None, dtypes.int32)))
        # Same arguments as spec4 but with no flat_values_spec.
        spec5 = RaggedTensorSpec([None], dtypes.int32, 0)

        # Same dtype and ragged_rank; spec2 leaves the shape unknown.
        self.assertTrue(spec1.is_compatible_with(spec2))
        # Each remaining pair differs in dtype, ragged_rank, shape, or whether
        # a flat_values_spec was given.
        self.assertFalse(spec1.is_compatible_with(spec3))
        self.assertFalse(spec1.is_compatible_with(spec4))
        self.assertFalse(spec2.is_compatible_with(spec3))
        self.assertFalse(spec2.is_compatible_with(spec4))
        self.assertFalse(spec3.is_compatible_with(spec4))
        self.assertFalse(spec4.is_compatible_with(spec5))
        # A plain tensor is rejected; the same tensor wrapped is accepted.
        value = constant_op.constant([1, 2, 3])
        self.assertFalse(spec4.is_compatible_with(value))
        self.assertTrue(spec4.is_compatible_with(WrappedTensor(value)))

    def testToList(self):
        """Checks RaggedTensor.to_list() with WrappedTensor flat values.

        Plain WrappedTensor values raise ValueError. Subclasses that add a
        numpy() or to_list() method convert to the same nested list as an
        equivalent ragged string constant.
        """
        # Presumably eager mode is needed for the .numpy() calls below.
        with context.eager_mode():
            tensor_values = constant_op.constant(
                ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'])
            row_splits = constant_op.constant([0, 2, 2, 5, 6, 8], dtypes.int64)
            values = WrappedTensor(tensor_values)
            rt = RaggedTensor.from_row_splits(values, row_splits)
            # Reference result: the same rows built as an ordinary ragged
            # constant.
            expected = ragged_factory_ops.constant([['a', 'b'], [],
                                                    ['c', 'd', 'e'], ['f'],
                                                    ['g', 'h']]).to_list()

            with self.subTest('Raise on unsupported'):
                with self.assertRaisesRegex(
                        ValueError,
                        'values must be convertible to a list',
                ):
                    _ = rt.to_list()

            with self.subTest('Value with numpy method'):

                class WrappedTensorWithNumpy(WrappedTensor):
                    def numpy(self):
                        return self.value.numpy()

                values = WrappedTensorWithNumpy(tensor_values)
                rt = RaggedTensor.from_row_splits(values, row_splits)
                self.assertEqual(rt.to_list(), expected)

            with self.subTest('Value with to_list method'):

                class WrappedTensorWithToList(WrappedTensor):
                    def to_list(self):
                        return self.value.numpy().tolist()

                values = WrappedTensorWithToList(tensor_values)
                rt = RaggedTensor.from_row_splits(values, row_splits)
                self.assertEqual(rt.to_list(), expected)