예제 #1
0
    def testConstructorSignature(self):
        """Checks the `__init__` signature generated for a Struct subclass.

        The expected signature has `self` followed by one
        positional-or-keyword parameter per annotated field, carrying the
        field's annotation.  The list default given for `z` is expected to
        be reported as the tuple `(1, 'two', 3)`.
        """
        class MyStruct(tensor_struct.Struct):
            x: ops.Tensor
            y: tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool)
            z: typing.Tuple[typing.Union[int, str], ...] = [1, 'two', 3]

        kind = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
        y_annotation = tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool)
        z_annotation = typing.Tuple[typing.Union[int, str], ...]

        expected_sig = tf_inspect.Signature(
            [
                tf_inspect.Parameter('self', kind),
                tf_inspect.Parameter('x', kind, annotation=ops.Tensor),
                tf_inspect.Parameter('y', kind, annotation=y_annotation),
                tf_inspect.Parameter('z',
                                     kind,
                                     annotation=z_annotation,
                                     default=(1, 'two', 3)),
            ],
            return_annotation=MyStruct)

        actual_sig = tf_inspect.signature(MyStruct.__init__)
        self.assertEqual(expected_sig, actual_sig)
예제 #2
0
    def testForwardReferences(self):
        """Checks field resolution and signature for forward-referencing types.

        The field tuples returned by `_tf_struct_fields()` are compared
        against fields whose types are the actual classes, while the
        expected `__init__` signature of `ForwardRefA` is written with the
        string names as annotations.
        """
        A, B = ForwardRefA, ForwardRefB

        # Fields of A: the annotations are compared against the real classes.
        fields_of_a = A._tf_struct_fields()
        expected_fields_of_a = (
            struct_field.StructField(
                'x', typing.Tuple[typing.Union[A, B], ...]),
            struct_field.StructField('y', B),
        )
        self.assertEqual(fields_of_a, expected_fields_of_a)

        # Fields of B, including its reference to itself.
        fields_of_b = B._tf_struct_fields()
        expected_fields_of_b = (
            struct_field.StructField('z', B),
            struct_field.StructField('n', ops.Tensor),
        )
        self.assertEqual(fields_of_b, expected_fields_of_b)

        # Check the signature: here the annotations are the string names.
        kind = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
        x_annotation = typing.Tuple[typing.Union['ForwardRefA',
                                                 'ForwardRefB'], ...]
        expected_sig = tf_inspect.Signature(
            [
                tf_inspect.Parameter('self', kind),
                tf_inspect.Parameter('x', kind, annotation=x_annotation),
                tf_inspect.Parameter('y', kind, annotation='ForwardRefB'),
            ],
            return_annotation=A)
        self.assertEqual(tf_inspect.signature(A.__init__), expected_sig)
예제 #3
0
def _build_spec_constructor(cls):
  """Generates an `__init__` for ExtensionTypeSpec subclass `cls` and installs it.

  Every field reported by `cls._tf_extension_type_fields()` becomes a
  positional-or-keyword parameter with no default and no annotation.  The
  generated constructor binds its arguments against those parameters, stores
  them in the instance `__dict__`, and then calls
  `_tf_extension_type_convert_fields()` followed by `__validate__()`.

  Args:
    cls: The class that receives the generated `__init__`.
  """
  positional_or_keyword = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
  field_params = [
      tf_inspect.Parameter(spec_field.name, positional_or_keyword)
      for spec_field in cls._tf_extension_type_fields()  # pylint: disable=protected-access
  ]

  # Used for argument binding only, so it has no `self` parameter.
  binder = tf_inspect.Signature(field_params, return_annotation=cls.__name__)

  def __init__(self, *args, **kwargs):  # pylint: disable=invalid-name
    call_args = binder.bind(*args, **kwargs)
    call_args.apply_defaults()
    self.__dict__.update(call_args.arguments)
    self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  # __signature__ is supported by some inspection/documentation tools.
  self_param = tf_inspect.Parameter('self', positional_or_keyword)
  __init__.__signature__ = tf_inspect.Signature(
      [self_param, *field_params], return_annotation=cls)

  cls.__init__ = __init__
예제 #4
0
def _build_struct_constructor(cls):
  """Generates an `__init__` for tf.Struct subclass `cls` and installs it.

  Every field reported by `cls._tf_struct_fields()` becomes a
  positional-or-keyword parameter annotated with the field's `value_type`.
  Fields whose default is `StructField.NO_DEFAULT` get no default; all other
  fields use their own default.  The generated constructor binds its
  arguments, fills in defaults, stores the values in the instance `__dict__`,
  and then calls `_tf_struct_convert_fields()` followed by `__validate__()`.

  Args:
    cls: The class that receives the generated `__init__`.
  """
  positional_or_keyword = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
  field_params = [
      tf_inspect.Parameter(
          member.name,
          positional_or_keyword,
          default=(tf_inspect.Parameter.empty
                   if member.default is struct_field.StructField.NO_DEFAULT
                   else member.default),
          annotation=member.value_type)
      for member in cls._tf_struct_fields()  # pylint: disable=protected-access
  ]

  # Used for argument binding only, so it has no `self` parameter.
  binder = tf_inspect.Signature(field_params, return_annotation=cls.__name__)

  def __init__(self, *args, **kwargs):  # pylint: disable=invalid-name
    call_args = binder.bind(*args, **kwargs)
    call_args.apply_defaults()
    self.__dict__.update(call_args.arguments)
    self._tf_struct_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  # __signature__ is supported by some inspection/documentation tools
  # (but note: typing.get_type_hints does not respect __signature__).
  self_param = tf_inspect.Parameter('self', positional_or_keyword)
  __init__.__signature__ = tf_inspect.Signature(
      [self_param, *field_params], return_annotation=cls)

  cls.__init__ = __init__
예제 #5
0
    def testSignatureOnDecoratorsThatDontProvideFullArgSpec(self):
        """Checks the parameters `tf_inspect.signature` reports for a fixture.

        `test_decorated_function_with_defaults` is expected to expose three
        positional-or-keyword parameters: `a` (no default), `b=2` and
        `c='Hello'`.
        """
        actual_sig = tf_inspect.signature(
            test_decorated_function_with_defaults)

        kind = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
        expected_parameters = [
            tf_inspect.Parameter('a', kind),
            tf_inspect.Parameter('b', kind, default=2),
            tf_inspect.Parameter('c', kind, default='Hello'),
        ]
        self.assertEqual(expected_parameters,
                         list(actual_sig.parameters.values()))
예제 #6
0
def _build_extension_type_constructor(cls):
    """Generates an `__init__` for tf.ExtensionType subclass `cls`.

    Each field reported by `cls._tf_extension_type_fields()` becomes one
    constructor parameter annotated with the field's `value_type`.  Fields
    are positional-or-keyword up to the first no-default field that comes
    after a field with a default; that field and every later one are
    keyword-only.  The generated constructor binds its arguments, fills in
    defaults, stores the values in the instance `__dict__`, and then calls
    `_tf_extension_type_convert_fields()` followed by `__validate__()`.

    Args:
      cls: The class that receives the generated `__init__`.
    """
    fields = cls._tf_extension_type_fields()  # pylint: disable=protected-access

    # Locate the first no-default field that follows a field with a default.
    # If there is none, every field stays positional-or-keyword.
    keyword_only_start = len(fields)
    seen_default = False
    for index, field in enumerate(fields):
        has_default = field.default is not _NO_DEFAULT
        if seen_default and not has_default:
            keyword_only_start = index
            break
        seen_default = seen_default or has_default

    def _to_parameter(index, field):
        """Returns the constructor parameter for the field at `index`."""
        if index < keyword_only_start:
            kind = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
        else:
            kind = tf_inspect.Parameter.KEYWORD_ONLY
        default = (tf_inspect.Parameter.empty
                   if field.default is _NO_DEFAULT else field.default)
        return tf_inspect.Parameter(field.name,
                                    kind,
                                    default=default,
                                    annotation=field.value_type)

    params = [
        _to_parameter(index, field) for index, field in enumerate(fields)
    ]

    # Used for argument binding only, so it has no `self` parameter.
    binder = tf_inspect.Signature(params, return_annotation=cls.__name__)

    def __init__(self, *args, **kwargs):  # pylint: disable=invalid-name
        call_args = binder.bind(*args, **kwargs)
        call_args.apply_defaults()
        self.__dict__.update(call_args.arguments)
        self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
        self.__validate__()

    # __signature__ is supported by some inspection/documentation tools
    # (but note: typing.get_type_hints does not respect __signature__).
    self_param = tf_inspect.Parameter(
        'self', tf_inspect.Parameter.POSITIONAL_OR_KEYWORD)
    __init__.__signature__ = tf_inspect.Signature([self_param, *params],
                                                  return_annotation=cls)

    cls.__init__ = __init__
예제 #7
0
def _build_extension_type_constructor(cls):
    """Generates an `__init__` for tf.ExtensionType subclass `cls`.

    Each field reported by `cls._tf_extension_type_fields()` becomes one
    positional-or-keyword parameter annotated with the field's `value_type`.
    The generated constructor binds its arguments, fills in defaults, stores
    the values in the instance `__dict__`, and then calls
    `_tf_extension_type_convert_fields()` followed by `__validate__()`.

    Args:
      cls: The class that receives the generated `__init__`.

    Raises:
      ValueError: If a field without a default follows a field with one.
    """
    fields = cls._tf_extension_type_fields()  # pylint: disable=protected-access

    def _lacks_default(candidate):
        """Returns True if `candidate` was declared without a default."""
        return (candidate.default is
                extension_type_field.ExtensionTypeField.NO_DEFAULT)

    # Reject a no-default field that comes after a field with a default;
    # no well-formed constructor can be built for that ordering.
    default_fields = []
    for field in fields:
        if not _lacks_default(field):
            default_fields.append(field.name)
            continue
        if default_fields:
            raise ValueError(
                f'In definition for {cls.__name__}: Field without default '
                f'{field.name!r} follows field with default {default_fields[-1]!r}.  '
                f'Either add a default value for {field.name!r}, or move it before '
                f'{default_fields[0]!r} in the field annotations.')

    positional_or_keyword = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
    params = [
        tf_inspect.Parameter(
            field.name,
            positional_or_keyword,
            default=(tf_inspect.Parameter.empty
                     if _lacks_default(field) else field.default),
            annotation=field.value_type) for field in fields
    ]

    # Used for argument binding only, so it has no `self` parameter.
    binder = tf_inspect.Signature(params, return_annotation=cls.__name__)

    def __init__(self, *args, **kwargs):  # pylint: disable=invalid-name
        call_args = binder.bind(*args, **kwargs)
        call_args.apply_defaults()
        self.__dict__.update(call_args.arguments)
        self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
        self.__validate__()

    # __signature__ is supported by some inspection/documentation tools
    # (but note: typing.get_type_hints does not respect __signature__).
    self_param = tf_inspect.Parameter('self', positional_or_keyword)
    __init__.__signature__ = tf_inspect.Signature([self_param, *params],
                                                  return_annotation=cls)

    cls.__init__ = __init__
예제 #8
0
    def testSignatureFollowsNestedDecorators(self):
        """Checks the parameters `tf_inspect.signature` reports for a fixture.

        `test_decorated_function` is expected to expose exactly one
        positional-or-keyword parameter named `x`.
        """
        actual_sig = tf_inspect.signature(test_decorated_function)

        expected_parameters = [
            tf_inspect.Parameter(
                'x', tf_inspect.Parameter.POSITIONAL_OR_KEYWORD),
        ]
        self.assertEqual(expected_parameters,
                         list(actual_sig.parameters.values()))
    def testSpecConstructorSignature(self):
        """Checks the `__init__` signature generated for an ExtensionType's Spec.

        The expected signature has `self` plus one positional-or-keyword
        parameter per field of the value type, with no annotations and no
        defaults, and `MyType.Spec` as the return annotation.
        """
        class MyType(extension_type.ExtensionType):
            x: ops.Tensor
            y: tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool)
            z: typing.Tuple[typing.Union[int, str], ...] = [1, 'two', 3]

        kind = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
        expected_sig = tf_inspect.Signature(
            [
                tf_inspect.Parameter(name, kind)
                for name in ('self', 'x', 'y', 'z')
            ],
            return_annotation=MyType.Spec)

        actual_sig = tf_inspect.signature(MyType.Spec.__init__)
        self.assertEqual(expected_sig, actual_sig)