Example #1
0
def fn_args(fn):
    """Get argument names for function-like object.

    Args:
      fn: Function, or function-like object (e.g., result of `functools.partial`).

    Returns:
      `tuple` of string argument names. For a `functools.partial`, the names
      consumed by its positionally bound arguments (the first `len(fn.args)`
      names) and the names it binds by keyword are omitted. For a bound
      method, the leading positional argument (self/cls), if any, is omitted.
    """
    if isinstance(fn, functools.partial):
        # Recurse so nested partials / methods wrapped by a partial are handled.
        args = fn_args(fn.func)
        bound_keywords = fn.keywords or {}
        # Skip the names filled positionally by the partial, then drop any
        # names it has already bound by keyword.
        args = [a for a in args[len(fn.args):] if a not in bound_keywords]
    else:
        if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__):
            # Callable object whose __call__ is a method: inspect __call__.
            fn = fn.__call__
        args = tf_inspect.getfullargspec(fn).args
        if is_bound_method(fn) and args:
            # If it's a bound method, it may or may not have a self/cls first
            # argument; for example, self could be captured in *args.
            # If it does have a positional argument, it is self/cls.
            # Slice instead of pop(0): the argspec returned above may be an
            # object shared with other callers (NOTE(review): e.g. an argspec
            # stored on a decorator), so its `args` list must not be mutated.
            args = args[1:]
    return tuple(args)
Example #2
0
def call_is_method(call_fn):
    """Tells whether `call_fn` is a method, whatever the decoration order.

    Two situations are covered:

    * Decoration applied after the method was bound to an instance, e.g.
      `self.call = @tf.function(self.call)`; `tf_inspect.ismethod` recognizes
      this (as well as undecorated bound methods).
    * Decoration applied in the class body, before binding::

          class Foo(Layer):
            @decorator
            def call(self, ...):
              pass

      `_inspect.ismethod` reports the bound result as a method in this case.

    Args:
      call_fn: The fn to check.

    Returns:
      The truthy result of `tf_inspect.ismethod(call_fn)` if it has one,
      otherwise the result of `_inspect.ismethod(call_fn)`.
    """
    # First try the decorator-aware check (post-binding decoration).
    via_tf_inspect = tf_inspect.ismethod(call_fn)
    if via_tf_inspect:
        return via_tf_inspect
    # Fall back to the plain check, which catches pre-binding decoration.
    return _inspect.ismethod(call_fn)
Example #3
0
  def __init__(self, layer):
    """Records `layer`'s call method and the metadata used to trace it.

    Args:
      layer: The layer whose call method is collected.
    """
    self.layer = layer

    # Resolve the call method and how it deals with the `training` argument.
    self.layer_call_method = _get_layer_call_method(layer)
    self._expects_training_arg = utils.layer_uses_training_bool(layer)
    self._training_arg_index = utils.get_training_arg_index(
        self.layer_call_method)

    # A traced function cannot have an input signature when the layer call
    # function takes kwargs; treat a training arg, defaulted args,
    # keyword-only args and **kwargs all as "has kwargs".
    call_spec = tf_inspect.getfullargspec(self.layer_call_method)
    self._has_kwargs = any((self._expects_training_arg,
                            call_spec.defaults,
                            call_spec.kwonlyargs,
                            call_spec.varkw))

    self._input_signature = self._generate_input_signature(layer)
    # Weak-valued mapping: entries disappear once nothing else references them.
    self._functions = weakref.WeakValueDictionary()
    # Set while this object is tracing the layer call functions.
    self.tracing = False

    # The first positional parameter (ignoring self/cls for methods) names the
    # inputs; default to 'inputs' when there is none.
    positional_names = call_spec.args
    if tf_inspect.ismethod(self.layer_call_method):
      positional_names = positional_names[1:]
    self._input_arg_name = next(iter(positional_names), 'inputs')
Example #4
0
def get_training_arg_index(call_fn):
    """Returns the index of 'training' in the layer call function arguments.

    Args:
      call_fn: Call function.

    Returns:
      - n: index of 'training' among the named positional arguments (for a
        method, the leading self/cls argument is not counted). This is only
        possible when `call_fn` has no `*args`.
      - -1: if 'training' is not found positionally, but `call_fn` accepts it
        as a keyword-only argument or accepts variable keyword arguments.
      - None: otherwise, i.e. the layer doesn't expect a training argument.
        NOTE(review): this also covers a `call_fn` with `*args` whose named
        positional arguments include 'training'.
    """
    spec = tf_inspect.getfullargspec(call_fn)
    takes_training_by_keyword = 'training' in spec.kwonlyargs or spec.varkw

    if not spec.varargs:
        # Search the named positional parameters, skipping self/cls.
        positional = spec.args
        if tf_inspect.ismethod(call_fn):
            positional = positional[1:]
        if 'training' in positional:
            return positional.index('training')

    # With variable args (or no positional match), 'training' can only be a
    # keyword argument.
    if takes_training_by_keyword:
        return -1
    return None
Example #5
0
    def __init__(self, layer):
        """Records `layer`'s call method, its inputs and training-arg metadata.

        Args:
          layer: The layer whose call method is collected.
        """
        self.layer = layer

        # Resolve the call method and how it deals with `training`.
        self.layer_call_method = _get_layer_call_method(layer)
        self._expects_training_arg = utils.layer_uses_training_bool(layer)
        self._training_arg_index = utils.get_training_arg_index(
            self.layer_call_method)

        self._layer_inputs = self._get_layer_inputs(layer)
        # Weak-valued mapping: entries disappear once unreferenced elsewhere.
        self._functions = weakref.WeakValueDictionary()

        # The first positional parameter (ignoring self/cls for methods)
        # names the inputs; default to 'inputs' when there is none.
        call_spec = tf_inspect.getfullargspec(self.layer_call_method)
        positional_names = call_spec.args
        if tf_inspect.ismethod(self.layer_call_method):
            positional_names = positional_names[1:]
        self._input_arg_name = next(iter(positional_names), 'inputs')
Example #6
0
def get_training_arg_index(call_fn):
    """Returns the index of 'training' in the layer call function arguments.

    Only the named positional arguments of `call_fn` are searched; keyword-only
    arguments, `*args` and `**kwargs` are not inspected.

    Args:
      call_fn: Call function.

    Returns:
      - n: index of 'training' in the call function's named positional
        arguments (for a method, the leading self/cls argument is not counted).
      - -1: if 'training' is not a named positional argument. This is returned
        whether or not `call_fn` can actually accept a `training` keyword, so
        -1 does not prove a training argument is accepted. None is never
        returned.
    """
    arg_list = tf_inspect.getfullargspec(call_fn).args
    if tf_inspect.ismethod(call_fn):
        # Drop the leading self/cls so the index counts caller-passed args.
        arg_list = arg_list[1:]
    if 'training' in arg_list:
        return arg_list.index('training')
    # Keep the integer sentinel (not None): NOTE(review) callers may compare
    # the index numerically.
    return -1
Example #7
0
    def _maybe_wrap_with_training_arg(self, call_fn, match_layer_training_arg):
        """Wraps call function with added training argument if necessary.

        A wrapper is built only when this object expects a training argument
        (`self._expects_training_arg`) while the wrapped layer does not
        (`self.layer._expects_training_arg`); otherwise `call_fn` is returned
        as-is.

        Args:
          call_fn: The call function to wrap.
          match_layer_training_arg: If True, the wrapper removes the training
            value from the incoming arguments before invoking `call_fn`.

        Returns:
          `call_fn` unchanged, or the result of
          `tf.__internal__.decorator.make_decorator` whose advertised argspec is
          `call_fn`'s argspec with a trailing `training` argument defaulting to
          False. When wrapping, `self._training_arg_index` is set to the
          position of that new argument.
        """
        if self.layer._expects_training_arg or not self._expects_training_arg:  # pylint: disable=protected-access
            return call_fn

        # Advertise call_fn's signature plus a trailing `training=False`.
        original_spec = tf_inspect.getfullargspec(call_fn)
        new_args = original_spec.args + ['training']
        new_defaults = list(original_spec.defaults or []) + [False]
        new_arg_spec = tf_inspect.FullArgSpec(
            args=new_args,
            varargs=original_spec.varargs,
            varkw=original_spec.varkw,
            defaults=new_defaults,
            kwonlyargs=original_spec.kwonlyargs,
            kwonlydefaults=original_spec.kwonlydefaults,
            annotations=original_spec.annotations)

        # `training` is the last positional argument; for a method the leading
        # self/cls entry of the argspec is not counted.
        offset = 2 if tf_inspect.ismethod(call_fn) else 1
        self._training_arg_index = len(new_args) - offset

        def wrap_with_training_arg(*args, **kwargs):
            if not match_layer_training_arg:
                return call_fn(*args, **kwargs)
            # Remove the training value, since the original call_fn does not
            # expect a training arg. Instead, the training value will be
            # propagated using the call context created in LayerCall.
            trimmed_args = list(args)
            trimmed_kwargs = kwargs.copy()
            # The index is read at call time, not captured at wrap time.
            utils.remove_training_arg(self._training_arg_index, trimmed_args,
                                      trimmed_kwargs)
            return call_fn(*trimmed_args, **trimmed_kwargs)

        return tf.__internal__.decorator.make_decorator(
            target=call_fn,
            decorator_func=wrap_with_training_arg,
            decorator_argspec=new_arg_spec)
Example #8
0
def is_bound_method(fn):
    """Returns True if `fn`'s undecorated target is a method bound to an object.

    `fn` is first unwrapped with `tf.__internal__.decorator.unwrap`; the
    resulting target must satisfy `tf_inspect.ismethod` and have a non-None
    `__self__`.
    """
    _decorators, target = tf.__internal__.decorator.unwrap(fn)
    if not tf_inspect.ismethod(target):
        return False
    return target.__self__ is not None


def _is_callable_object(obj):
    return hasattr(obj, "__call__") and tf_inspect.ismethod(obj.__call__)