def testConvertFieldsMismatch(self, field_values, error):
  """Checks that `convert_fields` raises `ValueError` for `field_values`.

  The values are converted against two declared fields: `x` (an `int`) and
  `y` (a `float`).

  Args:
    field_values: The field values handed to `struct_field.convert_fields`.
    error: Regular expression that the `ValueError` message must match.
  """
  x_field = struct_field.StructField('x', int)
  y_field = struct_field.StructField('y', float)
  declared_fields = [x_field, y_field]
  with self.assertRaisesRegex(ValueError, error):
    struct_field.convert_fields(declared_fields, field_values)
def testForwardReferences(self):
  """Checks struct fields and the `__init__` signature with forward references.

  The fields reported by `ForwardRefA` and `ForwardRefB` are compared against
  `StructField`s built from the actual classes, while the signature of
  `ForwardRefA.__init__` is compared against annotations that still use the
  string names.
  """
  ref_a, ref_b = ForwardRefA, ForwardRefB

  # Fields reported by ForwardRefA.
  actual_a_fields = ref_a._tf_struct_fields()
  expected_a_fields = (
      struct_field.StructField(
          'x', typing.Tuple[typing.Union[ref_a, ref_b], ...]),
      struct_field.StructField('y', ref_b),
  )
  self.assertEqual(actual_a_fields, expected_a_fields)

  # Fields reported by ForwardRefB.
  actual_b_fields = ref_b._tf_struct_fields()
  expected_b_fields = (
      struct_field.StructField('z', ref_b),
      struct_field.StructField('n', ops.Tensor),
  )
  self.assertEqual(actual_b_fields, expected_b_fields)

  # Check the signature.
  self_param = tf_inspect.Parameter(
      'self', tf_inspect.Parameter.POSITIONAL_OR_KEYWORD)
  x_param = tf_inspect.Parameter(
      'x',
      tf_inspect.Parameter.POSITIONAL_OR_KEYWORD,
      annotation=typing.Tuple[typing.Union['ForwardRefA', 'ForwardRefB'],
                              ...])
  y_param = tf_inspect.Parameter(
      'y',
      tf_inspect.Parameter.POSITIONAL_OR_KEYWORD,
      annotation='ForwardRefB')
  expected_sig = tf_inspect.Signature(
      [self_param, x_param, y_param], return_annotation=ref_a)
  self.assertEqual(tf_inspect.signature(ref_a.__init__), expected_sig)
def testStructFieldRepr(self,
                        expected,
                        name,
                        value_type,
                        default=struct_field.StructField.NO_DEFAULT):
  """Checks that `repr()` of a `StructField` equals `expected`.

  Args:
    expected: The expected `repr` string.
    name: Field name passed to `StructField`.
    value_type: Field type passed to `StructField`.
    default: Field default passed to `StructField`; `NO_DEFAULT` when the
      test case does not supply one.
  """
  actual = repr(struct_field.StructField(name, value_type, default))
  self.assertEqual(actual, expected)
def testConvertFieldsForSpec(self):
  """Checks the in-place result of `convert_fields_for_spec`.

  After conversion the `y` list is expected to compare equal to a tuple, and
  the `TensorSpec` given for the `ops.Tensor` field `z` is expected to
  compare equal to a freshly built `TensorSpec` of the same shape.
  """
  x_field = struct_field.StructField('x', int)
  y_field = struct_field.StructField(
      'y', typing.Tuple[typing.Union[int, bool], ...])
  z_field = struct_field.StructField('z', ops.Tensor)
  fields = [x_field, y_field, z_field]

  field_values = {
      'x': 1,
      'y': [1, True, 3],
      'z': tensor_spec.TensorSpec([5, 3]),
  }
  # Return value is ignored; the assertions below inspect `field_values`.
  struct_field.convert_fields_for_spec(fields, field_values)

  self.assertEqual(set(field_values), {'x', 'y', 'z'})
  self.assertEqual(field_values['x'], 1)
  self.assertEqual(field_values['y'], (1, True, 3))
  self.assertEqual(field_values['z'], tensor_spec.TensorSpec([5, 3]))
def testConvertFields(self):
  """Checks the in-place result of `convert_fields`.

  After conversion the `y` list is expected to compare equal to a tuple, and
  the nested list given for the `ops.Tensor` field `z` is expected to be an
  `ops.Tensor` holding the same values.
  """
  x_field = struct_field.StructField('x', int)
  y_field = struct_field.StructField(
      'y', typing.Tuple[typing.Union[int, bool], ...])
  z_field = struct_field.StructField('z', ops.Tensor)
  fields = [x_field, y_field, z_field]

  field_values = {
      'x': 1,
      'y': [1, True, 3],
      'z': [[1, 2], [3, 4], [5, 6]],
  }
  # Return value is ignored; the assertions below inspect `field_values`.
  struct_field.convert_fields(fields, field_values)

  self.assertEqual(set(field_values), {'x', 'y', 'z'})
  self.assertEqual(field_values['x'], 1)
  self.assertEqual(field_values['y'], (1, True, 3))
  self.assertIsInstance(field_values['z'], ops.Tensor)
  self.assertAllEqual(field_values['z'], [[1, 2], [3, 4], [5, 6]])
def testStructFieldConstruction(
    self,
    name,
    value_type,
    default=struct_field.StructField.NO_DEFAULT,
    converted_default=None):
  """Checks the attributes of a newly constructed `StructField`.

  Args:
    name: Field name passed to `StructField`.
    value_type: Field type passed to `StructField`.
    default: Field default passed to `StructField`. A callable is treated as
      a factory and invoked first, so that defaults containing tensors are
      only built when the test runs.
    converted_default: If not `None`, the value `field.default` is expected
      to equal, instead of `default`.
  """
  if callable(default):
    # Deferred construction (the default contains a tensor).
    default = default()
  field = struct_field.StructField(name, value_type, default)

  expected_default = default if converted_default is None else converted_default

  self.assertEqual(field.name, name)
  self.assertEqual(field.value_type, value_type)

  # Tensor-valued defaults need element-wise comparison.
  tensor_types = (ops.Tensor, ragged_tensor.RaggedTensor)
  if not isinstance(field.default, tensor_types):
    self.assertEqual(field.default, expected_default)
  else:
    self.assertAllEqual(field.default, expected_default)
def _tf_struct_fields(cls):  # pylint: disable=no-self-argument
  """Returns the `StructField`s describing the fields of this struct.

  Each type hint on the class becomes one field. A class attribute with the
  same name as a hint supplies that field's default; otherwise the default
  is `StructField.NO_DEFAULT`. The result is cached on the class only when
  every forward reference could be resolved.

  Returns:
    An ordered tuple of `StructField` objects. Forward references are
    resolved if possible, or left unresolved otherwise.
  """
  cached = cls._tf_struct_cached_fields
  if cached is not None:
    return cached

  try:
    hints = typing.get_type_hints(cls)
  except (NameError, AttributeError):
    # Unresolved forward reference -- collect the raw annotations by hand.
    #   * NameError: an annotation such as `Foo`, where class `Foo` has not
    #     been defined yet.
    #   * AttributeError: an annotation such as `foo.Bar`, where module
    #     `foo` exists but `Bar` has not been defined yet.
    # Note: if a user attempts to instantiate a `Struct` type that still has
    # unresolved forward references (e.g., because of a typo or a missing
    # import), then the constructor will raise an exception.
    fully_resolved = False
    hints = {}
    # Walk the MRO base-first so that subclasses override their bases.
    for klass in reversed(cls.__mro__):
      hints.update(klass.__dict__.get('__annotations__', {}))
  else:
    # All forward references have been resolved, so caching is safe.
    fully_resolved = True

  fields = tuple(
      struct_field.StructField(
          name, value_type,
          getattr(cls, name, struct_field.StructField.NO_DEFAULT))
      for name, value_type in hints.items())

  if fully_resolved:
    cls._tf_struct_cached_fields = fields
  return fields
def _tf_struct_fields(cls):
  """Returns a `StructField` for each non-reserved name in `cls.__dict__`.

  Every field is created with `None` as its value type.

  Returns:
    A list of `StructField` objects, in `cls.__dict__` iteration order.
  """
  fields = []
  for name in cls.__dict__:
    if struct_field.StructField.is_reserved_name(name):
      continue
    fields.append(struct_field.StructField(name, None))
  return fields
def testStructFieldConstructionError(self, name, value_type, default, error):
  """Checks that constructing a `StructField` raises `TypeError`.

  Args:
    name: Field name passed to `StructField`.
    value_type: Field type passed to `StructField`.
    default: Field default passed to `StructField`. A callable is treated as
      a factory and invoked first, so that defaults containing tensors are
      only built when the test runs.
    error: Regular expression that the `TypeError` message must match.
  """
  # Deferred construction (the default contains a tensor).
  resolved_default = default() if callable(default) else default
  with self.assertRaisesRegex(TypeError, error):
    struct_field.StructField(name, value_type, resolved_default)