def testConvertFieldsMismatch(self, field_values, error):
   # Each case supplies `field_values` that do not fit the (x: int, y: float)
   # schema below; conversion must fail with a ValueError matching `error`.
   make_field = extension_type_field.ExtensionTypeField
   schema = [make_field('x', int), make_field('y', float)]
   with self.assertRaisesRegex(ValueError, error):
     extension_type_field.convert_fields(schema, field_values)
    def testForwardReferences(self):
        fwd_a = ForwardRefA
        fwd_b = ForwardRefB
        make_field = extension_type_field.ExtensionTypeField

        # With both classes defined, the forward references in the field
        # annotations are expected to resolve to the classes themselves.
        self.assertEqual(
            fwd_a._tf_extension_type_fields(),
            (make_field('x', typing.Tuple[typing.Union[fwd_a, fwd_b], ...]),
             make_field('y', fwd_b)))
        self.assertEqual(
            fwd_b._tf_extension_type_fields(),
            (make_field('z', fwd_b), make_field('n', ops.Tensor)))

        # Check the signature: fwd_a.__init__ is expected to keep the forward
        # references as written (strings), not as resolved classes.
        positional = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
        x_annotation = typing.Tuple[
            typing.Union['ForwardRefA', 'ForwardRefB'], ...]
        expected_sig = tf_inspect.Signature(
            [
                tf_inspect.Parameter('self', positional),
                tf_inspect.Parameter('x', positional, annotation=x_annotation),
                tf_inspect.Parameter('y', positional,
                                     annotation='ForwardRefB'),
            ],
            return_annotation=fwd_a)
        self.assertEqual(tf_inspect.signature(fwd_a.__init__), expected_sig)
 def testConvertFields(self):
   field_cls = extension_type_field.ExtensionTypeField
   fields = [
       field_cls('x', int),
       field_cls('y', typing.Tuple[typing.Union[int, bool], ...]),
       field_cls('z', ops.Tensor),
   ]
   values = {'x': 1, 'y': [1, True, 3], 'z': [[1, 2], [3, 4], [5, 6]]}
   # convert_fields updates `values` in place; nothing is returned to check.
   extension_type_field.convert_fields(fields, values)
   self.assertEqual(set(values), {'x', 'y', 'z'})
   self.assertEqual(values['x'], 1)
   # A list given for a Tuple[...] field comes back as a tuple.
   self.assertEqual(values['y'], (1, True, 3))
   # Nested lists given for an ops.Tensor field come back as a Tensor.
   converted_z = values['z']
   self.assertIsInstance(converted_z, ops.Tensor)
   self.assertAllEqual(converted_z, [[1, 2], [3, 4], [5, 6]])
 def testRepr(self, expected, name, value_type,
              default=extension_type_field.ExtensionTypeField.NO_DEFAULT):
   # repr() of a field built from (name, value_type, default) must match
   # the expected string exactly.
   actual = repr(
       extension_type_field.ExtensionTypeField(name, value_type, default))
   self.assertEqual(actual, expected)
 def testConvertFieldsForSpec(self):
   field_cls = extension_type_field.ExtensionTypeField
   fields = [
       field_cls('x', int),
       field_cls('y', typing.Tuple[typing.Union[int, bool], ...]),
       field_cls('z', ops.Tensor),
   ]
   # In the spec variant, the ops.Tensor field is given a TensorSpec.
   values = dict(x=1, y=[1, True, 3], z=tensor_spec.TensorSpec([5, 3]))
   extension_type_field.convert_fields_for_spec(fields, values)
   self.assertEqual(set(values), {'x', 'y', 'z'})
   self.assertEqual(values['x'], 1)
   # A list given for a Tuple[...] field comes back as a tuple.
   self.assertEqual(values['y'], (1, True, 3))
   # The TensorSpec comes back equal to the one that was passed in.
   self.assertEqual(values['z'], tensor_spec.TensorSpec([5, 3]))
 def testConstruction(
     self,
     name,
     value_type,
     default=extension_type_field.ExtensionTypeField.NO_DEFAULT,
     converted_default=None):
   # Defaults containing tensors arrive as callables so that they are only
   # built here, inside the test (deferred construction).
   default_arg = default() if callable(default) else default
   field = extension_type_field.ExtensionTypeField(name, value_type,
                                                   default_arg)
   # If given, converted_default is what the field should hold after
   # conversion; otherwise the field should hold the default as passed.
   expected_default = (default_arg if converted_default is None
                       else converted_default)
   self.assertEqual(field.name, name)
   self.assertEqual(field.value_type, value_type)
   # Tensor-like defaults need element-wise comparison.
   tensor_like = (ops.Tensor, ragged_tensor.RaggedTensor)
   compare = (self.assertAllEqual if isinstance(field.default, tensor_like)
              else self.assertEqual)
   compare(field.default, expected_default)
    def _tf_extension_type_fields(cls):  # pylint: disable=no-self-argument
        """Returns an ordered description of this ExtensionType's fields.

        Returns:
          A tuple of `ExtensionTypeField` objects, one per type hint of `cls`.
          Forward references are resolved if possible, or left unresolved
          otherwise.  The tuple is cached on `cls` only when every forward
          reference was resolved.
        """
        # Look in the class's own __dict__ (rather than using getattr) so a
        # subclass never reuses the fields cached on one of its bases.
        if '_tf_extension_type_cached_fields' in cls.__dict__:
            return cls._tf_extension_type_cached_fields

        try:
            # include_extras=False replaces every Annotated[T, ...] with T.
            # typing_extensions is used because the standard typing module
            # only supports this from Python 3.9 on.
            hints = typing_extensions.get_type_hints(cls, include_extras=False)
        except (NameError, AttributeError):
            # A forward reference could not be resolved yet:
            # * NameError: an annotation such as `Foo` where class `Foo` has
            #   not been defined yet.
            # * AttributeError: an annotation such as `foo.Bar` where module
            #   `foo` exists but `Bar` has not been defined yet.
            # Fall back to the raw annotations, merged base-first so that a
            # subclass's annotations override its bases'.  If references are
            # still unresolved when the type is instantiated (e.g. a typo or a
            # missing import), the constructor will raise an exception.
            hints = {}
            for klass in reversed(cls.__mro__):
                hints.update(klass.__dict__.get('__annotations__', {}))
            cacheable = False
        else:
            cacheable = True  # Every forward reference was resolved.

        no_default = extension_type_field.ExtensionTypeField.NO_DEFAULT
        # A class attribute named like the field supplies its default value.
        fields = tuple(
            extension_type_field.ExtensionTypeField(
                field_name, field_type, getattr(cls, field_name, no_default))
            for field_name, field_type in hints.items())

        if cacheable:
            cls._tf_extension_type_cached_fields = fields
        return fields
 def _tf_extension_type_fields(cls):
     # Every non-reserved name in the class's own __dict__ (inherited names
     # are not included) becomes a field whose value_type is None.
     field_cls = extension_type_field.ExtensionTypeField
     return [field_cls(attr_name, None)
             for attr_name in cls.__dict__
             if not field_cls.is_reserved_name(attr_name)]
 def testConstructionError(self, name, value_type, default, error):
   # A callable default defers building a tensor until the test body runs.
   default_value = default() if callable(default) else default
   # Building a field from this (name, value_type, default) combination must
   # raise a TypeError whose message matches `error`.
   with self.assertRaisesRegex(TypeError, error):
     extension_type_field.ExtensionTypeField(name, value_type, default_value)