def testConstructorSignature(self):
  # The generated MyStruct.__init__ should expose `self` plus one
  # positional-or-keyword parameter per field, each carrying the field's
  # annotation (and default, if any), and should report MyStruct as its
  # return annotation.

  class MyStruct(tensor_struct.Struct):
    x: ops.Tensor
    y: tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool)
    z: typing.Tuple[typing.Union[int, str], ...] = [1, 'two', 3]

  param = tf_inspect.Parameter
  pos_or_kw = param.POSITIONAL_OR_KEYWORD
  y_spec = tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool)
  z_type = typing.Tuple[typing.Union[int, str], ...]

  expected_sig = tf_inspect.Signature(
      [
          param('self', pos_or_kw),
          param('x', pos_or_kw, annotation=ops.Tensor),
          param('y', pos_or_kw, annotation=y_spec),
          # The list default declared on the class is expected to show up in
          # the signature as a tuple.
          param('z', pos_or_kw, annotation=z_type, default=(1, 'two', 3)),
      ],
      return_annotation=MyStruct)
  self.assertEqual(expected_sig, tf_inspect.signature(MyStruct.__init__))
def testForwardReferences(self):
  # The resolved struct fields are expected to use the actual classes
  # ForwardRefA / ForwardRefB, while the constructor signature is expected to
  # keep the string annotations 'ForwardRefA' / 'ForwardRefB'.
  A, B = ForwardRefA, ForwardRefB

  expected_a_fields = (
      struct_field.StructField('x', typing.Tuple[typing.Union[A, B], ...]),
      struct_field.StructField('y', B),
  )
  self.assertEqual(A._tf_struct_fields(), expected_a_fields)

  expected_b_fields = (
      struct_field.StructField('z', B),
      struct_field.StructField('n', ops.Tensor),
  )
  self.assertEqual(B._tf_struct_fields(), expected_b_fields)

  # Check the signature.
  pos_or_kw = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
  x_annotation = typing.Tuple[typing.Union['ForwardRefA', 'ForwardRefB'], ...]
  expected_sig = tf_inspect.Signature(
      [
          tf_inspect.Parameter('self', pos_or_kw),
          tf_inspect.Parameter('x', pos_or_kw, annotation=x_annotation),
          tf_inspect.Parameter('y', pos_or_kw, annotation='ForwardRefB'),
      ],
      return_annotation=A)
  self.assertEqual(tf_inspect.signature(A.__init__), expected_sig)
def _build_spec_constructor(cls):
  """Builds a constructor for ExtensionTypeSpec subclass `cls`.

  The generated `__init__` takes one positional-or-keyword argument per field
  returned by `cls._tf_extension_type_fields()` (no defaults, no
  annotations). It binds the arguments, copies them into the instance
  `__dict__`, and then calls `_tf_extension_type_convert_fields()` followed
  by `__validate__()`.

  Args:
    cls: The spec class whose `__init__` is replaced.
  """
  pos_or_kw = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
  params = [
      tf_inspect.Parameter(field.name, pos_or_kw)
      for field in cls._tf_extension_type_fields()  # pylint: disable=protected-access
  ]
  # Signature used only for binding call arguments; it excludes `self`.
  binding_sig = tf_inspect.Signature(params, return_annotation=cls.__name__)

  def __init__(self, *args, **kwargs):  # pylint: disable=invalid-name
    bound = binding_sig.bind(*args, **kwargs)
    bound.apply_defaults()
    self.__dict__.update(bound.arguments)
    self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  # __signature__ is supported by some inspection/documentation tools.
  self_param = tf_inspect.Parameter('self', pos_or_kw)
  __init__.__signature__ = tf_inspect.Signature(
      [self_param, *params], return_annotation=cls)
  cls.__init__ = __init__
def _build_struct_constructor(cls):
  """Builds a constructor for tf.Struct subclass `cls`.

  Each field from `cls._tf_struct_fields()` becomes a positional-or-keyword
  parameter annotated with the field's `value_type`. A field whose default is
  not `StructField.NO_DEFAULT` gets that value as its default. The generated
  `__init__` binds its arguments, applies defaults, copies the values into the
  instance `__dict__`, and then calls `_tf_struct_convert_fields()` and
  `__validate__()`.

  Args:
    cls: The Struct subclass whose `__init__` is replaced.
  """
  pos_or_kw = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD

  def to_param(field):
    # Map the struct's "no default" sentinel onto inspect's own marker.
    if field.default is struct_field.StructField.NO_DEFAULT:
      default = tf_inspect.Parameter.empty
    else:
      default = field.default
    return tf_inspect.Parameter(
        field.name, pos_or_kw, default=default, annotation=field.value_type)

  params = [to_param(field) for field in cls._tf_struct_fields()]  # pylint: disable=protected-access
  # Signature used only for binding call arguments; it excludes `self`.
  binding_sig = tf_inspect.Signature(params, return_annotation=cls.__name__)

  def __init__(self, *args, **kwargs):  # pylint: disable=invalid-name
    bound = binding_sig.bind(*args, **kwargs)
    bound.apply_defaults()
    self.__dict__.update(bound.arguments)
    self._tf_struct_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  # __signature__ is supported by some inspection/documentation tools
  # (but note: typing.get_type_hints does not respect __signature__).
  self_param = tf_inspect.Parameter('self', pos_or_kw)
  __init__.__signature__ = tf_inspect.Signature(
      [self_param, *params], return_annotation=cls)
  cls.__init__ = __init__
def testSignatureOnDecoratorsThatDontProvideFullArgSpec(self):
  # The decorated helper is expected to report its underlying parameters,
  # including their defaults: (a, b=2, c='Hello').
  pos_or_kw = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
  sig = tf_inspect.signature(test_decorated_function_with_defaults)
  actual_params = list(sig.parameters.values())
  expected_params = [
      tf_inspect.Parameter('a', pos_or_kw),
      tf_inspect.Parameter('b', pos_or_kw, default=2),
      tf_inspect.Parameter('c', pos_or_kw, default='Hello'),
  ]
  self.assertEqual(expected_params, actual_params)
def _build_extension_type_constructor(cls):
  """Builds a constructor for tf.ExtensionType subclass `cls`.

  One parameter is created per field in `cls._tf_extension_type_fields()`.
  Each parameter is annotated with the field's `value_type` and carries the
  field's default; fields whose default is `_NO_DEFAULT` get no default.

  A positional signature cannot place a parameter without a default after one
  that has a default. To handle this, the first no-default field that follows
  a defaulted field is made keyword-only, and so is every field after it.
  Earlier fields stay positional-or-keyword.

  The generated `__init__` binds its arguments against that signature and
  applies defaults. It then stores the values in the instance `__dict__` and
  calls `_tf_extension_type_convert_fields()` followed by `__validate__()`.

  Args:
    cls: The ExtensionType subclass whose `__init__` is replaced.
  """
  fields = cls._tf_extension_type_fields()  # pylint: disable=protected-access

  # Mark any no-default fields that follow default fields as keyword_only.
  # `keyword_only_start` is the index of the first such field; it stays at
  # len(fields) when there is none.
  keyword_only_start = len(fields)
  got_default = False
  for i, field in enumerate(fields):
    if field.default is not _NO_DEFAULT:
      got_default = True
    elif got_default:
      keyword_only_start = i
      break

  params = []
  for i, field in enumerate(fields):
    if i < keyword_only_start:
      kind = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
    else:
      kind = tf_inspect.Parameter.KEYWORD_ONLY
    if field.default is _NO_DEFAULT:
      default = tf_inspect.Parameter.empty
    else:
      default = field.default
    params.append(
        tf_inspect.Parameter(
            field.name, kind, default=default, annotation=field.value_type))
  # Signature used only for binding call arguments; it excludes `self`.
  signature = tf_inspect.Signature(params, return_annotation=cls.__name__)

  def __init__(self, *args, **kwargs):  # pylint: disable=invalid-name
    bound_args = signature.bind(*args, **kwargs)
    bound_args.apply_defaults()
    self.__dict__.update(bound_args.arguments)
    self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  # __signature__ is supported by some inspection/documentation tools
  # (but note: typing.get_type_hints does not respect __signature__).
  __init__.__signature__ = tf_inspect.Signature(
      [tf_inspect.Parameter('self', tf_inspect.Parameter.POSITIONAL_OR_KEYWORD)]
      + params,
      return_annotation=cls)
  cls.__init__ = __init__
def _build_extension_type_constructor(cls):
  """Builds a constructor for tf.ExtensionType subclass `cls`.

  Every field from `cls._tf_extension_type_fields()` becomes a
  positional-or-keyword parameter annotated with its `value_type`, and gets
  the field's default when it has one. The generated `__init__` binds its
  arguments, applies defaults, and copies them into the instance `__dict__`.
  It then calls `_tf_extension_type_convert_fields()` and `__validate__()`.

  Args:
    cls: The ExtensionType subclass whose `__init__` is replaced.

  Raises:
    ValueError: If a field without a default is declared after a field that
      has a default.
  """
  fields = cls._tf_extension_type_fields()  # pylint: disable=protected-access
  no_default = extension_type_field.ExtensionTypeField.NO_DEFAULT
  pos_or_kw = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD

  # Check that no-default fields don't follow default fields; otherwise no
  # well-formed constructor can be built. The first and most recent defaulted
  # field names are kept for the error message.
  first_default = last_default = None
  for field in fields:
    if field.default is not no_default:
      if first_default is None:
        first_default = field.name
      last_default = field.name
    elif last_default is not None:
      raise ValueError(
          f'In definition for {cls.__name__}: Field without default '
          f'{field.name!r} follows field with default {last_default!r}. '
          f'Either add a default value for {field.name!r}, or move it before '
          f'{first_default!r} in the field annotations.')

  def to_param(field):
    # Map the extension type's "no default" sentinel onto inspect's marker.
    if field.default is no_default:
      default = tf_inspect.Parameter.empty
    else:
      default = field.default
    return tf_inspect.Parameter(
        field.name, pos_or_kw, default=default, annotation=field.value_type)

  params = [to_param(field) for field in fields]
  # Signature used only for binding call arguments; it excludes `self`.
  binding_sig = tf_inspect.Signature(params, return_annotation=cls.__name__)

  def __init__(self, *args, **kwargs):  # pylint: disable=invalid-name
    bound = binding_sig.bind(*args, **kwargs)
    bound.apply_defaults()
    self.__dict__.update(bound.arguments)
    self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  # __signature__ is supported by some inspection/documentation tools
  # (but note: typing.get_type_hints does not respect __signature__).
  self_param = tf_inspect.Parameter('self', pos_or_kw)
  __init__.__signature__ = tf_inspect.Signature(
      [self_param, *params], return_annotation=cls)
  cls.__init__ = __init__
def testSignatureFollowsNestedDecorators(self):
  # test_decorated_function is expected to report only the parameter `x`
  # of the function it wraps.
  sig = tf_inspect.signature(test_decorated_function)
  actual_params = list(sig.parameters.values())
  only_x = tf_inspect.Parameter(
      'x', tf_inspect.Parameter.POSITIONAL_OR_KEYWORD)
  self.assertEqual([only_x], actual_params)
def testSpecConstructorSignature(self):
  # MyType.Spec.__init__ is expected to take `self` plus every field as a
  # plain positional-or-keyword parameter, with no annotations and no
  # defaults, and to report MyType.Spec as its return annotation.

  class MyType(extension_type.ExtensionType):
    x: ops.Tensor
    y: tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool)
    z: typing.Tuple[typing.Union[int, str], ...] = [1, 'two', 3]

  pos_or_kw = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
  expected_sig = tf_inspect.Signature(
      [tf_inspect.Parameter(name, pos_or_kw)
       for name in ('self', 'x', 'y', 'z')],
      return_annotation=MyType.Spec)
  self.assertEqual(expected_sig, tf_inspect.signature(MyType.Spec.__init__))