def testAttributeAccessors(self, fields):
     """Checks that every constructor field can be read back as an attribute.

     NOTE(review): this looks like a copy of
     `AnonymousExtensionTypeTest.testAttributeAccessors` without its
     `@parameterized.parameters` decorator; confirm how `fields` is supplied.

     Args:
       fields: A dict mapping field names to values, or a zero-argument
         callable that returns such a dict.
     """
     field_values = fields() if callable(fields) else fields
     anon = extension_type.AnonymousExtensionType(**field_values)
     for field_name, expected in field_values.items():
         got = getattr(anon, field_name)
         # Tensor / RaggedTensor values go through assertAllEqual; all other
         # values use plain assertEqual.
         tensor_like = (ops.Tensor, ragged_tensor.RaggedTensor)
         check = (self.assertAllEqual if isinstance(got, tensor_like)
                  else self.assertEqual)
         check(got, expected)
 def testAttributeAccessorsAreImmutable(self):
     """Checks that fields cannot be reassigned, deleted, or mutated in place.

     NOTE(review): appears to duplicate the method of the same name on
     `AnonymousExtensionTypeTest`; confirm this copy is intended.
     """
     anon = extension_type.AnonymousExtensionType(x=12, y={'x': 55})
     # Rebinding an existing field must be rejected.
     with self.assertRaisesRegex(AttributeError,
                                 "cannot assign to field 'x'"):
         setattr(anon, 'x', 22)
     # Removing a field must be rejected.
     with self.assertRaisesRegex(AttributeError, "cannot delete field 'y'"):
         delattr(anon, 'y')
     # The dict-valued field must reject item assignment as well.
     with self.assertRaisesRegex(TypeError,
                                 'does not support item assignment'):
         anon.y['x'] = 66
 def testConstructionErrors(self, fields, error):
     """Checks that constructing from `fields` raises a matching ValueError.

     NOTE(review): appears to duplicate the parameterized method of the same
     name on `AnonymousExtensionTypeTest`; confirm how arguments are supplied.

     Args:
       fields: Keyword arguments passed to the AnonymousExtensionType
         constructor.
       error: Regex that the ValueError message must match.
     """
     make_anon = extension_type.AnonymousExtensionType
     with self.assertRaisesRegex(ValueError, error):
         make_anon(**fields)
 def testConstruction(self, fields):
     """Checks that an AnonymousExtensionType can be built from `fields`.

     NOTE(review): appears to duplicate the parameterized method of the same
     name on `AnonymousExtensionTypeTest`; confirm how `fields` is supplied.

     Args:
       fields: A dict of constructor keyword arguments, or a zero-argument
         callable that returns one.
     """
     ctor_kwargs = fields() if callable(fields) else fields
     extension_type.AnonymousExtensionType(**ctor_kwargs)
class AnonymousExtensionTypeTest(test_util.TensorFlowTestCase,
                                 parameterized.TestCase):
    @parameterized.parameters([
        [dict(i=5, f=3.2, b=True, n=None)],
        [dict(x=(1, 2), y={3: 4, 5: 6})],
        [lambda: dict(t=constant_op.constant(123))],
        [lambda: dict(r=ragged_factory_ops.constant([[1, 2], [3]]))],
    ])
    def testConstruction(self, fields):
        """Constructing from supported field values succeeds.

        Args:
          fields: A dict of constructor keyword arguments, or a zero-argument
            callable returning one. The tensor-valued cases are wrapped in
            lambdas so the tensors are created when the test body runs rather
            than when the decorator arguments are evaluated.
        """
        ctor_kwargs = fields() if callable(fields) else fields
        extension_type.AnonymousExtensionType(**ctor_kwargs)

    @parameterized.parameters([
        # Lists and sets are rejected, including when nested in a tuple/dict.
        [dict(x=[1, 2, 3]), 'Unsupported field value'],
        [dict(x={1, 2}), 'Unsupported field value'],
        [dict(x=(1, {2: []})), 'Unsupported field value'],
        # Field names using the internal prefix are rejected.
        [
            dict(_tf_extension_type_xyz=5),
            "The field name '_tf_extension_type_xyz' is reserved"
        ],
    ])
    def testConstructionErrors(self, fields, error):
        """Unsupported field values and reserved field names raise ValueError.

        Args:
          fields: Keyword arguments passed to the AnonymousExtensionType
            constructor.
          error: Regex that the ValueError message must match.
        """
        make_anon = extension_type.AnonymousExtensionType
        with self.assertRaisesRegex(ValueError, error):
            make_anon(**fields)

    @parameterized.parameters([
        [dict(i=5, f=3.2, b=True, n=None)],
        [dict(x=(1, 2), y={3: 4, 5: 6})],
        [lambda: dict(t=constant_op.constant(123))],
        [lambda: dict(r=ragged_factory_ops.constant([[1, 2], [3]]))],
    ])
    def testAttributeAccessors(self, fields):
        """Every constructor field can be read back as an attribute.

        Args:
          fields: A dict mapping field names to values, or a zero-argument
            callable returning one (used for the tensor-valued cases).
        """
        field_values = fields() if callable(fields) else fields
        anon = extension_type.AnonymousExtensionType(**field_values)
        for field_name, expected in field_values.items():
            got = getattr(anon, field_name)
            # Tensor / RaggedTensor values go through assertAllEqual; all
            # other values use plain assertEqual.
            tensor_like = (ops.Tensor, ragged_tensor.RaggedTensor)
            check = (self.assertAllEqual if isinstance(got, tensor_like)
                     else self.assertEqual)
            check(got, expected)

    def testAttributeAccessorsAreImmutable(self):
        """Fields cannot be reassigned, deleted, or mutated in place."""
        anon = extension_type.AnonymousExtensionType(x=12, y={'x': 55})
        # Rebinding an existing field must be rejected.
        with self.assertRaisesRegex(AttributeError,
                                    "cannot assign to field 'x'"):
            setattr(anon, 'x', 22)
        # Removing a field must be rejected.
        with self.assertRaisesRegex(AttributeError, "cannot delete field 'y'"):
            delattr(anon, 'y')
        # The dict-valued field must reject item assignment as well.
        with self.assertRaisesRegex(TypeError,
                                    'does not support item assignment'):
            anon.y['x'] = 66

    def testReinterpret(self):
        """Reinterprets a MaskedTensorV2 as anonymous and back again.

        The anonymous value is converted back to MaskedTensorV2 and also to
        MaskedTensorV1. Every result must expose the original `values` and
        `mask`.
        """
        def check_fields(value):
            self.assertAllEqual(value.values, [4, 5])
            self.assertAllEqual(value.mask, [True, False])

        source = MaskedTensorV2([4, 5], [True, False])
        anon = extension_type.reinterpret(
            source, extension_type.AnonymousExtensionType)
        check_fields(anon)

        # Round-trip to the source type first, then convert to MaskedTensorV1.
        for target_type in (MaskedTensorV2, MaskedTensorV1):
            check_fields(extension_type.reinterpret(anon, target_type))

    # pylint: disable=g-long-lambda
    @parameterized.parameters([
        # Target type requires a field the anonymous value lacks.
        [
            lambda: extension_type.AnonymousExtensionType(
                values=constant_op.constant([1, 2, 3])), MaskedTensorV2,
            "Missing required fields: {'mask'}"
        ],
        # Field present but of the wrong kind (None / RaggedTensor).
        [
            lambda: extension_type.AnonymousExtensionType(values=(1, 2, 3),
                                                          mask=None),
            MaskedTensorV2, 'mask: expected a tf.bool Tensor, got None'
        ],
        [
            lambda: extension_type.AnonymousExtensionType(
                values=constant_op.constant([[1, 2], [3, 4]]),
                mask=ragged_factory_ops.constant([[1, 2], [3]])),
            MaskedTensorV2, 'mask: expected a tf.bool Tensor'
        ],
        # Fields of the right kind but with mismatched shapes.
        [
            lambda: extension_type.AnonymousExtensionType(
                values=constant_op.constant([1, 2, 3]),
                mask=constant_op.constant([True, False])), MaskedTensorV2,
            'Shapes .* are incompatible'
        ],
        # Target / source that are not ExtensionTypes at all.
        [
            lambda: extension_type.AnonymousExtensionType(
                values=constant_op.constant([1, 2, 3])), ops.Tensor,
            'Expected `new_type` to be a subclass of tf.ExtensionType'
        ],
        [
            lambda: constant_op.constant([1, 2, 3]),
            extension_type.AnonymousExtensionType,
            'Expected `value` to be a tf.ExtensionType'
        ],
    ])
    def testReinterpretErrors(self, value, new_type, error):
        """reinterpret() rejects incompatible values and target types.

        Args:
          value: The value to reinterpret, or a zero-argument callable that
            builds it (lets tensors be created inside the test body).
          new_type: The type passed to `extension_type.reinterpret`.
          error: Regex that the raised TypeError/ValueError must match.
        """
        source = value() if callable(value) else value
        with self.assertRaisesRegex((TypeError, ValueError), error):
            extension_type.reinterpret(source, new_type)

    def testLoadSavedModelWithUnregisteredExtensionType(self):
        """Loads a SavedModel whose ExtensionType is no longer registered.

        The model is saved while `MaskedTensorV1.Spec` is temporarily
        registered under 'tf.test.MaskedTensorV1.Spec'. It is then loaded
        after that registration has been removed. The restored functions must
        return `AnonymousExtensionType` values and accept them as inputs.
        `extension_type.reinterpret` must convert between
        `AnonymousExtensionType` and `MaskedTensorV1`.
        """
        def split_masked(v):
            # Plain tensors are treated as fully unmasked (mask=True).
            if isinstance(v, MaskedTensorV1):
                return v.values, v.mask
            return v, True

        def f(x, y):
            x_values, x_mask = split_masked(x)
            y_values, y_mask = split_masked(y)
            return MaskedTensorV1(x_values + y_values, x_mask & y_mask)

        t_spec = tensor_spec.TensorSpec(None, dtypes.int32)
        b_spec = tensor_spec.TensorSpec(None, dtypes.bool)
        mt_spec = MaskedTensorV1.Spec(values=t_spec, mask=b_spec)
        model = module.Module()
        model.f = def_function.function(f)
        # Trace every (Tensor | MaskedTensorV1) x (Tensor | MaskedTensorV1)
        # combination so each one is exported with the model.
        model.f.get_concrete_function(t_spec, t_spec)
        model.f.get_concrete_function(t_spec, mt_spec)
        model.f.get_concrete_function(mt_spec, t_spec)
        model.f.get_concrete_function(mt_spec, mt_spec)

        # `dir=` creates the export directory inside the per-test temp dir.
        # `prefix=` would only prepend that path's text to the generated name.
        path = tempfile.mkdtemp(dir=test.get_temp_dir())
        with temporarily_register_type_spec('tf.test.MaskedTensorV1.Spec',
                                            MaskedTensorV1.Spec):
            save.save(model, path)
        loaded_model = load.load(path)

        # The temporary registration must be gone by now. Look up the name
        # that was actually registered above: 'tf.test.MaskedTensorV1' was
        # never registered, so looking it up could not detect a leftover
        # registration.
        with self.assertRaises(ValueError):
            type_spec.lookup('tf.test.MaskedTensorV1.Spec')

        t = constant_op.constant([10, 20, 30])
        v1 = loaded_model.f(t, t)
        self.assertIsInstance(v1, extension_type.AnonymousExtensionType)
        self.assertAllEqual(v1.values, [20, 40, 60])
        self.assertAllEqual(v1.mask, True)

        # Anonymous values returned by the loaded model can be fed back in.
        v2 = loaded_model.f(v1, v1)
        self.assertIsInstance(v2, extension_type.AnonymousExtensionType)
        self.assertAllEqual(v2.values, [40, 80, 120])
        self.assertAllEqual(v2.mask, True)

        # A live MaskedTensorV1 must be reinterpreted as anonymous to be passed
        # to the loaded model.
        mt = MaskedTensorV1([1, 2, 3], [True, True, False])
        v3 = loaded_model.f(
            t,
            extension_type.reinterpret(mt,
                                       extension_type.AnonymousExtensionType))
        self.assertIsInstance(v3, extension_type.AnonymousExtensionType)
        self.assertAllEqual(v3.values, [11, 22, 33])
        self.assertAllEqual(v3.mask, [True, True, False])

        # And the anonymous result can be turned back into a MaskedTensorV1.
        v4 = extension_type.reinterpret(v3, MaskedTensorV1)
        self.assertIsInstance(v4, MaskedTensorV1)
        self.assertAllEqual(v4.values, [11, 22, 33])
        self.assertAllEqual(v4.mask, [True, True, False])