def testConvertFieldsMismatch(self, field_values, error):
  """Checks that convert_fields rejects values that don't fit the fields.

  Args:
    field_values: Candidate values handed to `convert_fields`.
    error: Regex that the raised ValueError's message must match.
  """
  x_field = extension_type_field.ExtensionTypeField('x', int)
  y_field = extension_type_field.ExtensionTypeField('y', float)
  with self.assertRaisesRegex(ValueError, error):
    extension_type_field.convert_fields([x_field, y_field], field_values)
def testForwardReferences(self):
  """Checks forward-referenced annotations in fields and in __init__.

  Field value types are expected to hold the resolved classes. The generated
  `__init__` signature is expected to keep the forward references as strings.
  """
  cls_a, cls_b = ForwardRefA, ForwardRefB
  make_field = extension_type_field.ExtensionTypeField

  actual_a_fields = cls_a._tf_extension_type_fields()
  expected_a_fields = (
      make_field('x', typing.Tuple[typing.Union[cls_a, cls_b], ...]),
      make_field('y', cls_b),
  )
  self.assertEqual(actual_a_fields, expected_a_fields)

  actual_b_fields = cls_b._tf_extension_type_fields()
  expected_b_fields = (
      make_field('z', cls_b),
      make_field('n', ops.Tensor),
  )
  self.assertEqual(actual_b_fields, expected_b_fields)

  # Check the signature: annotations stay exactly as written (string refs).
  param = tf_inspect.Parameter
  pos_or_kw = param.POSITIONAL_OR_KEYWORD
  x_annotation = typing.Tuple[typing.Union['ForwardRefA', 'ForwardRefB'], ...]
  expected_sig = tf_inspect.Signature(
      [
          param('self', pos_or_kw),
          param('x', pos_or_kw, annotation=x_annotation),
          param('y', pos_or_kw, annotation='ForwardRefB'),
      ],
      return_annotation=cls_a)
  self.assertEqual(tf_inspect.signature(cls_a.__init__), expected_sig)
def testConvertFields(self):
  """Checks that convert_fields rewrites values in place to match field types."""
  make_field = extension_type_field.ExtensionTypeField
  fields = [
      make_field('x', int),
      make_field('y', typing.Tuple[typing.Union[int, bool], ...]),
      make_field('z', ops.Tensor),
  ]
  field_values = {'x': 1, 'y': [1, True, 3], 'z': [[1, 2], [3, 4], [5, 6]]}
  extension_type_field.convert_fields(fields, field_values)

  self.assertEqual(set(field_values), {'x', 'y', 'z'})
  self.assertEqual(field_values['x'], 1)
  # The list given for the Tuple[...] field comes back as a tuple.
  self.assertEqual(field_values['y'], (1, True, 3))
  # The nested list given for the Tensor field comes back as a Tensor.
  converted_z = field_values['z']
  self.assertIsInstance(converted_z, ops.Tensor)
  self.assertAllEqual(converted_z, [[1, 2], [3, 4], [5, 6]])
def testRepr(self,
             expected,
             name,
             value_type,
             default=extension_type_field.ExtensionTypeField.NO_DEFAULT):
  """Checks that repr() of an ExtensionTypeField equals `expected`."""
  self.assertEqual(
      repr(extension_type_field.ExtensionTypeField(name, value_type, default)),
      expected)
def testConvertFieldsForSpec(self):
  """Checks convert_fields_for_spec: values are converted, specs kept as-is."""
  make_field = extension_type_field.ExtensionTypeField
  fields = [
      make_field('x', int),
      make_field('y', typing.Tuple[typing.Union[int, bool], ...]),
      make_field('z', ops.Tensor),
  ]
  field_values = dict(
      x=1,
      y=[1, True, 3],
      z=tensor_spec.TensorSpec([5, 3]),
  )
  extension_type_field.convert_fields_for_spec(fields, field_values)

  self.assertEqual(set(field_values), {'x', 'y', 'z'})
  self.assertEqual(field_values['x'], 1)
  self.assertEqual(field_values['y'], (1, True, 3))
  # The Tensor-typed field holds a TensorSpec, which is expected unchanged.
  self.assertEqual(field_values['z'], tensor_spec.TensorSpec([5, 3]))
def testConstruction(
    self,
    name,
    value_type,
    default=extension_type_field.ExtensionTypeField.NO_DEFAULT,
    converted_default=None):
  """Checks the attributes of a freshly built ExtensionTypeField.

  Args:
    name: Field name to construct with.
    value_type: Field value type to construct with.
    default: Field default, or a zero-argument callable producing it
      (deferred construction, used when the default contains a tensor).
    converted_default: If not None, the value `field.default` is expected to
      equal instead of `default`.
  """
  if callable(default):
    default = default()  # deferred construction (contains tensor)
  field = extension_type_field.ExtensionTypeField(name, value_type, default)
  expected_default = default if converted_default is None else converted_default

  self.assertEqual(field.name, name)
  self.assertEqual(field.value_type, value_type)
  # Tensor-like defaults need element-wise comparison.
  is_tensor_like = isinstance(field.default,
                              (ops.Tensor, ragged_tensor.RaggedTensor))
  check = self.assertAllEqual if is_tensor_like else self.assertEqual
  check(field.default, expected_default)
def _tf_extension_type_fields(cls):  # pylint: disable=no-self-argument
  """Describes the fields of this ExtensionType, in order.

  Returns:
    A tuple of `ExtensionTypeField` objects, one per annotated attribute.
    Each field's default is the class attribute of the same name, or
    `ExtensionTypeField.NO_DEFAULT` if there is none. Forward references are
    resolved if possible, or left unresolved otherwise. The tuple is cached on
    `cls` only when every forward reference resolved.
  """
  # Look in this class's own __dict__ only: the cache must not be inherited.
  if '_tf_extension_type_cached_fields' in cls.__dict__:
    return cls._tf_extension_type_cached_fields

  try:
    # include_extras=False replaces every Annotated[T, ...] with T. The
    # typing_extensions module is used because this is only supported by
    # `typing` in Python 3.9.
    type_hints = typing_extensions.get_type_hints(cls, include_extras=False)
  except (NameError, AttributeError):
    # Unresolved forward reference, so gather the raw annotations manually:
    # * NameError comes from an annotation like `Foo` where class `Foo`
    #   hasn't been defined yet.
    # * AttributeError comes from an annotation like `foo.Bar`, where the
    #   module `foo` exists but `Bar` hasn't been defined yet.
    # If a user instantiates an `ExtensionType` that still has unresolved
    # forward references (e.g. from a typo or a missing import), the
    # constructor will raise an exception.
    type_hints = {}
    for base in reversed(cls.__mro__):
      type_hints.update(base.__dict__.get('__annotations__', {}))
    ok_to_cache = False
  else:
    ok_to_cache = True  # All forward references have been resolved.

  no_default = extension_type_field.ExtensionTypeField.NO_DEFAULT
  fields = tuple(
      extension_type_field.ExtensionTypeField(
          field_name, field_type, getattr(cls, field_name, no_default))
      for field_name, field_type in type_hints.items())

  if ok_to_cache:
    cls._tf_extension_type_cached_fields = fields
  return fields
def _tf_extension_type_fields(cls):
  """Builds a field with value_type None for each non-reserved name.

  Returns:
    A list of `ExtensionTypeField`, one per key of `cls.__dict__` that
    `ExtensionTypeField.is_reserved_name` does not flag, in `__dict__` order.
  """
  fields = []
  for attr_name in cls.__dict__:
    if extension_type_field.ExtensionTypeField.is_reserved_name(attr_name):
      continue
    fields.append(extension_type_field.ExtensionTypeField(attr_name, None))
  return fields
def testConstructionError(self, name, value_type, default, error):
  """Checks that an invalid ExtensionTypeField raises a matching TypeError.

  Args:
    name: Field name to construct with.
    value_type: Field value type to construct with.
    default: Field default, or a zero-argument callable producing it
      (deferred construction, used when the default contains a tensor).
    error: Regex that the TypeError's message must match.
  """
  resolved_default = default() if callable(default) else default
  with self.assertRaisesRegex(TypeError, error):
    extension_type_field.ExtensionTypeField(name, value_type, resolved_default)