示例#1
0
def _check_signature(api_signature, func):
    """Checks that a dispatch target's signature is compatible with an API.

    Two signatures are compatible when they have the same number of
    parameters and every pair of corresponding parameters agrees on both
    `name` and `kind`.  Default values and annotations are not compared.
    A target that has `*args` and `**kwargs` but no named positional
    parameters is accepted unconditionally.

    Args:
      api_signature: The signature of the TensorFlow API.
      func: The dispatch target.

    Raises:
      ValueError: if the signatures are incompatible.
    """
    argspec = tf_inspect.getargspec(func)
    accepts_anything = (argspec.varargs is not None and
                        argspec.keywords is not None and
                        not argspec.args)
    if accepts_anything:
        return

    target_signature = tf_inspect.signature(func)
    api_params = list(api_signature.parameters.values())
    target_params = list(target_signature.parameters.values())
    compatible = len(api_params) == len(target_params) and all(
        api_param.name == target_param.name and
        api_param.kind == target_param.kind
        for api_param, target_param in zip(api_params, target_params))
    if not compatible:
        raise ValueError(
            f"Dispatch function's signature {target_signature} does "
            f"not match API's signature {api_signature}.")
示例#2
0
def _add_dispatch_for_unary_elementwise_api(api, x_type,
                                            elementwise_api_handler):
    """Registers a unary elementwise handler as a dispatcher for a given API.

    A dispatch target is registered (via `dispatch_for_api`) for `api`,
    triggered when the API's first parameter matches `x_type`.  The target
    separates out the "name" argument and the first argument `x`, then calls
    `elementwise_api_handler(tensor_api, x)`, where `tensor_api` is a
    one-argument callable applying `api` with any remaining arguments bound.

    Args:
      api: The API to override.  Its first parameter is treated as the
        elementwise input.
      x_type: Type annotation used as the dispatch signature for the first
        parameter.
      elementwise_api_handler: Callable invoked as
        `elementwise_api_handler(tensor_api, x)`.
    """
    api_signature = tf_inspect.signature(api)
    param_names = list(api_signature.parameters)
    x_name = param_names[0]
    name_index = _find_name_index(api_signature)

    # Binding is unnecessary only when the API's parameters are exactly the
    # input plus "name"; otherwise the leftover args must be forwarded.
    need_to_bind_api_args = len(param_names) > 2 or "name" not in param_names

    @dispatch_for_api(api, {x_name: x_type})
    def dispatch_target(*args, **kwargs):
        args, kwargs, name = _extract_name_arg(args, kwargs, name_index)
        if not args:
            x = kwargs.pop(x_name)
        else:
            x = args[0]
            args = args[1:]

        tensor_api = ((lambda v: api(v, *args, **kwargs))
                      if need_to_bind_api_args else api)

        if name is not None:
            with ops.name_scope(name, None, [x]):
                return elementwise_api_handler(tensor_api, x)
        return elementwise_api_handler(tensor_api, x)

    target_name = "elementwise_dispatch_target_for_" + api.__name__
    dispatch_target.__name__ = target_name
    dispatch_target.__qualname__ = target_name
    # Record the registered target so it can be unregistered later.
    _ELEMENTWISE_API_TARGETS.setdefault((x_type,), []).append(
        (api, dispatch_target))
示例#3
0
    def testForwardReferences(self):
        fwd_a, fwd_b = ForwardRefA, ForwardRefB
        pos_or_kw = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD

        # The struct fields refer to the classes themselves.
        a_fields = fwd_a._tf_struct_fields()
        self.assertEqual(a_fields, (
            struct_field.StructField(
                'x', typing.Tuple[typing.Union[fwd_a, fwd_b], ...]),
            struct_field.StructField('y', fwd_b),
        ))
        b_fields = fwd_b._tf_struct_fields()
        self.assertEqual(b_fields, (
            struct_field.StructField('z', fwd_b),
            struct_field.StructField('n', ops.Tensor),
        ))

        # The constructor signature keeps the string forward-reference
        # annotations.
        x_annotation = typing.Tuple[typing.Union['ForwardRefA',
                                                 'ForwardRefB'], ...]
        expected_sig = tf_inspect.Signature(
            [
                tf_inspect.Parameter('self', pos_or_kw),
                tf_inspect.Parameter('x', pos_or_kw, annotation=x_annotation),
                tf_inspect.Parameter('y', pos_or_kw, annotation='ForwardRefB'),
            ],
            return_annotation=fwd_a)
        self.assertEqual(tf_inspect.signature(fwd_a.__init__), expected_sig)
    def testSignatureFollowsNestedDecorators(self):
        # The reported signature should be that of the innermost function.
        actual_params = list(
            tf_inspect.signature(test_decorated_function).parameters.values())
        expected_params = [
            tf_inspect.Parameter('x', tf_inspect.Parameter.POSITIONAL_OR_KEYWORD),
        ]
        self.assertEqual(expected_params, actual_params)
示例#5
0
    def testConstructorSignature(self):
        class MyStruct(tensor_struct.Struct):
            x: ops.Tensor
            y: tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool)
            z: typing.Tuple[typing.Union[int, str], ...] = [1, 'two', 3]

        pos_or_kw = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
        y_annotation = tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool)
        z_annotation = typing.Tuple[typing.Union[int, str], ...]
        expected_sig = tf_inspect.Signature(
            [
                tf_inspect.Parameter('self', pos_or_kw),
                tf_inspect.Parameter('x', pos_or_kw, annotation=ops.Tensor),
                tf_inspect.Parameter('y', pos_or_kw, annotation=y_annotation),
                # The list default from the class body is expected to appear
                # as a tuple in the generated constructor signature.
                tf_inspect.Parameter('z', pos_or_kw,
                                     annotation=z_annotation,
                                     default=(1, 'two', 3)),
            ],
            return_annotation=MyStruct)
        self.assertEqual(expected_sig, tf_inspect.signature(MyStruct.__init__))
示例#6
0
def _update_docstring_with_api_list(target, api_list):
    """Replaces `<<API_LIST>>` in target.__doc__ with the given list of APIs.

    Each API in `api_list` that has a canonical exported name is rendered as
    a bullet line "  * `tf.<name><signature>`".  The bullets are sorted and
    substituted for the indented "  <<API_LIST>>" placeholder.

    Args:
      target: The object whose `__doc__` is rewritten in place.
      api_list: Iterable of API callables to list.  APIs for which
        `tf_export_lib.get_canonical_name_for_symbol` returns None are
        skipped.
    """
    if target.__doc__ is None:
        # Docstrings can be absent (e.g. when running under `python -OO`);
        # there is then no placeholder to fill, and `.replace` would raise.
        return

    lines = []
    for func in api_list:
        name = tf_export_lib.get_canonical_name_for_symbol(
            func, add_prefix_to_v1_names=True)
        if name is not None:
            signature = tf_inspect.signature(func)
            lines.append(f"  * `tf.{name}{signature}`")
    lines.sort()
    target.__doc__ = target.__doc__.replace("  <<API_LIST>>", "\n".join(lines))
    def testSignatureOnDecoratorsThatDontProvideFullArgSpec(self):
        pos_or_kw = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
        actual_params = list(
            tf_inspect.signature(
                test_decorated_function_with_defaults).parameters.values())

        # Defaults of the decorated function must survive into the signature.
        expected_params = [
            tf_inspect.Parameter('a', pos_or_kw),
            tf_inspect.Parameter('b', pos_or_kw, default=2),
            tf_inspect.Parameter('c', pos_or_kw, default='Hello'),
        ]
        self.assertEqual(expected_params, actual_params)
示例#8
0
def _signature_from_annotations(func):
    """Builds a dict mapping from parameter names to type annotations.

    Only the parameters of `func` that carry an annotation are included.

    Args:
      func: The function whose parameter annotations should be read.

    Returns:
      A dict mapping each annotated parameter name to its annotation.

    Raises:
      ValueError: if no parameter of `func` is annotated.
    """
    func_signature = tf_inspect.signature(func)

    # `Parameter.empty` is a sentinel, so test by identity.  Using `!=` would
    # dispatch to the annotation object's own `__ne__`, which may be
    # overloaded (e.g. array- or tensor-like annotations).
    signature = {
        name: param.annotation
        for name, param in func_signature.parameters.items()
        if param.annotation is not tf_inspect.Parameter.empty
    }
    if not signature:
        raise ValueError("The dispatch_for decorator must be called with at "
                         "least one signature, or applied to a function that "
                         "has type annotations on its parameters.")
    return signature
示例#9
0
def _extract_init_kwargs(obj,
                         omit_kwargs=(),
                         limit_to=None,
                         prefer_static_value=()):
    """Extract constructor kwargs to reconstruct `obj`.

    For every parameter of `obj.__init__` other than `self` (minus
    `omit_kwargs`, and restricted to `limit_to` when given), a value is looked
    up on `obj` in this order: the attribute `k`, the attribute `_k`, then
    `obj.parameters[k]`.  For names in `prefer_static_value` the
    `obj.parameters` entry is tried first instead, and the found value is
    converted to a static Python value where possible.

    Args:
      obj: The object whose constructor arguments should be recovered.
      omit_kwargs: Constructor parameter names to skip.
      limit_to: Optional collection of parameter names; if not None, only
        these names are extracted.
      prefer_static_value: Parameter names for which `obj.parameters` is
        consulted first, and whose tensor values are replaced by
        `tf.get_static_value(...)` (when not None) and numpy values by
        `.tolist()`.

    Returns:
      A dict mapping constructor parameter names to the recovered values.

    Raises:
      ValueError: if `obj.__init__` takes `*args` or `**kwargs`, or if no
        value can be found for some parameter.
    """
    sig = tf_inspect.signature(obj.__init__)
    if any(v.kind in (tf_inspect.Parameter.VAR_KEYWORD,
                      tf_inspect.Parameter.VAR_POSITIONAL)
           for v in sig.parameters.values()):
        raise ValueError(
            '*args and **kwargs are not supported. Found `{}`'.format(sig))

    keys = [p for p in sig.parameters if p != 'self' and p not in omit_kwargs]
    if limit_to is not None:
        keys = [k for k in keys if k in limit_to]

    kwargs = {}
    not_found = object()
    for k in keys:
        if k in prefer_static_value:
            srcs = [
                getattr(obj, 'parameters', {}).get(k, not_found),
                getattr(obj, k, not_found),
                getattr(obj, '_' + k, not_found),
            ]
        else:
            srcs = [
                getattr(obj, k, not_found),
                getattr(obj, '_' + k, not_found),
                getattr(obj, 'parameters', {}).get(k, not_found),
            ]
        value = next((v for v in srcs if v is not not_found), not_found)
        if value is not_found:
            # Every piece must be an f-string, otherwise `{obj}`/`{k}` are
            # emitted literally.
            raise ValueError(
                f'Could not determine an appropriate value for field `{k}` in '
                f'object `{obj}`. Looked for \n'
                f' 1. an attr called `{k}`,\n'
                f' 2. an attr called `_{k}`,\n'
                f' 3. an entry in `obj.parameters` with key "{k}".')
        if k in prefer_static_value and value is not None:
            if tf.is_tensor(value):
                static_val = tf.get_static_value(value)
                if static_val is not None:
                    value = static_val
            if isinstance(value, (np.ndarray, np.generic)):
                # Generally, these are shapes or int.
                value = value.tolist()
        kwargs[k] = value
    return kwargs
示例#10
0
def _add_name_scope_wrapper(func, api_signature):
    """Wraps `func` to expect a "name" arg, and use it to call `ops.name_scope`.

    If `func` already expects a "name" arg (or accepts `**kwargs`), or if
    `api_signature` does not expect a "name" arg, then returns `func` as-is.

    Args:
      func: The function to wrap.  Signature must match `api_signature`
        (except the "name" parameter may be missing).
      api_signature: The signature of the original API (used to find the
        index for the "name" parameter).

    Returns:
      The wrapped function (or the original function if no wrapping is
      needed).  The wrapper's `__signature__` is `func`'s signature with the
      API's "name" parameter inserted at the index it has in `api_signature`.
    """
    if "name" not in api_signature.parameters:
        return func  # no wrapping needed (API has no name parameter).

    func_signature = tf_inspect.signature(func)
    func_argspec = tf_inspect.getargspec(func)
    if "name" in func_signature.parameters or func_argspec.keywords is not None:
        return func  # No wrapping needed (already has name parameter).

    name_index = list(api_signature.parameters).index("name")

    def wrapped_func(*args, **kwargs):
        # Remove "name" (positional at `name_index`, or keyword) before
        # forwarding, and open a name scope when it was given.
        if name_index < len(args):
            name = args[name_index]
            args = args[:name_index] + args[name_index + 1:]
        else:
            name = kwargs.pop("name", None)
        if name is None:
            return func(*args, **kwargs)
        else:
            with ops.name_scope(name):
                return func(*args, **kwargs)

    wrapped_func = tf_decorator.make_decorator(func, wrapped_func)
    # Advertise "name" at the same position `wrapped_func` reads it from.
    # Appending it at the end is only right when "name" is the API's last
    # parameter; APIs with parameters after "name" would otherwise fail the
    # signature-compatibility check.
    params = list(func_signature.parameters.values())
    params.insert(name_index, api_signature.parameters["name"])
    wrapped_func.__signature__ = func_signature.replace(parameters=params)
    # NOTE(review): presumably removed so signature inspection uses
    # `__signature__` instead of unwrapping to `func` -- confirm.
    del wrapped_func._tf_decorator
    return wrapped_func
    def testSpecConstructorSignature(self):
        class MyType(extension_type.ExtensionType):
            x: ops.Tensor
            y: tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool)
            z: typing.Tuple[typing.Union[int, str], ...] = [1, 'two', 3]

        # The Spec constructor takes the same field names, without
        # annotations or defaults.
        pos_or_kw = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
        expected_sig = tf_inspect.Signature(
            [tf_inspect.Parameter(param_name, pos_or_kw)
             for param_name in ('self', 'x', 'y', 'z')],
            return_annotation=MyType.Spec)
        actual_sig = tf_inspect.signature(MyType.Spec.__init__)
        self.assertEqual(expected_sig, actual_sig)
示例#12
0
def dispatch_for_api(api, *signatures):
    """Decorator that overrides the default implementation for a TensorFlow API.

  The decorated function (known as the "dispatch target") will override the
  default implementation for the API when the API is called with parameters that
  match a specified type signature.  Signatures are specified using dictionaries
  that map parameter names to type annotations.  E.g., in the following example,
  `masked_add` will be called for `tf.add` if both `x` and `y` are
  `MaskedTensor`s:

  >>> class MaskedTensor(extension_type.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor

  >>> @dispatch_for_api(tf.math.add, {'x': MaskedTensor, 'y': MaskedTensor})
  ... def masked_add(x, y, name=None):
  ...   return MaskedTensor(x.values + y.values, x.mask & y.mask)

  >>> mt = tf.add(MaskedTensor([1, 2], [True, False]), MaskedTensor(10, True))
  >>> print(f"values={mt.values.numpy()}, mask={mt.mask.numpy()}")
  values=[11 12], mask=[ True False]

  If multiple type signatures are specified, then the dispatch target will be
  called if any of the signatures match.  For example, the following code
  registers `masked_add` to be called if `x` is a `MaskedTensor` *or* `y` is
  a `MaskedTensor`.

  >>> @dispatch_for_api(tf.math.add, {'x': MaskedTensor}, {'y':MaskedTensor})
  ... def masked_add(x, y):
  ...   x_values = x.values if isinstance(x, MaskedTensor) else x
  ...   x_mask = x.mask if isinstance(x, MaskedTensor) else True
  ...   y_values = y.values if isinstance(y, MaskedTensor) else y
  ...   y_mask = y.mask if isinstance(y, MaskedTensor) else True
  ...   return MaskedTensor(x_values + y_values, x_mask & y_mask)

  The type annotations in type signatures may be type objects (e.g.,
  `MaskedTensor`), `typing.List` values, or `typing.Union` values.  For
  example, the following will register `masked_concat` to be called if `values`
  is a list of `MaskedTensor` values:

  >>> @dispatch_for_api(tf.concat, {'values': typing.List[MaskedTensor]})
  ... def masked_concat(values, axis):
  ...   return MaskedTensor(tf.concat([v.values for v in values], axis),
  ...                       tf.concat([v.mask for v in values], axis))

  Each type signature must contain at least one subclass of `tf.CompositeTensor`
  (which includes subclasses of `tf.ExtensionType`), and dispatch will only be
  triggered if at least one type-annotated parameter contains a
  `CompositeTensor` value.  This rule avoids invoking dispatch in degenerate
  cases, such as the following examples:

  * `@dispatch_for_api(tf.concat, {'values': List[MaskedTensor]})`: Will not
    dispatch to the decorated dispatch target when the user calls
    `tf.concat([])`.

  * `@dispatch_for_api(tf.add, {'x': Union[MaskedTensor, Tensor], 'y':
    Union[MaskedTensor, Tensor]})`: Will not dispatch to the decorated dispatch
    target when the user calls `tf.add(tf.constant(1), tf.constant(2))`.

  The dispatch target's signature must match the signature of the API that is
  being overridden.  In particular, parameters must have the same names, and
  must occur in the same order.  The dispatch target may optionally elide the
  "name" parameter, in which case it will be wrapped with a call to
  `tf.name_scope` when appropriate.

  Args:
    api: The TensorFlow API to override.
    *signatures: Dictionaries mapping parameter names or indices to type
      annotations, specifying when the dispatch target should be called.  In
      particular, the dispatch target will be called if any signature matches;
      and a signature matches if all of the specified parameters have types that
      match with the indicated type annotations.  If no signatures are
      specified, then a signature will be read from the dispatch target
      function's type annotations.

  Returns:
    A decorator that overrides the default implementation for `api`.

  #### Registered APIs

  The TensorFlow APIs that may be overridden by `@dispatch_for_api` are:

  <<API_LIST>>
  """
    dispatcher = getattr(api, TYPE_BASED_DISPATCH_ATTR, None)
    if dispatcher is None:
        raise ValueError(f"{api} does not support dispatch.")

    api_signature = tf_inspect.signature(api)
    # Checkers are built eagerly so malformed signatures fail at decoration
    # time rather than at the first API call.
    signature_checkers = [
        _make_signature_checker(api_signature, signature)
        for signature in signatures
    ]

    def decorator(dispatch_target):
        """Decorator that registers the given dispatch target."""
        if not callable(dispatch_target):
            raise TypeError("Expected dispatch_target to be callable; "
                            f"got {dispatch_target!r}")
        # Add the "name" parameter (if elided) before checking signatures.
        dispatch_target = _add_name_scope_wrapper(dispatch_target,
                                                  api_signature)
        _check_signature(api_signature, dispatch_target)

        for signature_checker in signature_checkers:
            dispatcher.Register(signature_checker, dispatch_target)
        _TYPE_BASED_DISPATCH_SIGNATURES[api][dispatch_target].extend(
            signatures)

        if not signature_checkers:
            # No explicit signatures: derive one from the target's annotations.
            signature = _signature_from_annotations(dispatch_target)
            checker = _make_signature_checker(api_signature, signature)
            dispatcher.Register(checker, dispatch_target)
            _TYPE_BASED_DISPATCH_SIGNATURES[api][dispatch_target].append(
                signature)

        return dispatch_target

    return decorator
def _cached_signature(f):
    """Returns `tf_inspect.signature(f)`, memoized in `_sig_cache`.

    The signature is computed on the first call for a given `f` and the
    stored object is returned on later calls.
    """
    if f in _sig_cache:
        return _sig_cache[f]
    sig = _sig_cache[f] = tf_inspect.signature(f)
    return sig