Example #1
0
    def __init__(self, layer):
        """Inspects `layer`'s call method and prepares state for tracing it.

        Args:
          layer: The layer whose call function(s) will be traced.
        """
        self.layer = layer

        # Resolve the callable that will be traced and record how it receives
        # the training flag.
        self.layer_call_method = _get_layer_call_method(layer)
        self._expects_training_arg = utils.layer_uses_training_bool(layer)
        self._training_arg_index = utils.get_training_arg_index(
            self.layer_call_method)

        call_spec = tf_inspect.getfullargspec(self.layer_call_method)
        # The traced function cannot carry an input signature if the call takes
        # anything beyond plain positional arguments: a training flag (per
        # utils.layer_uses_training_bool), defaulted parameters, keyword-only
        # parameters, or **kwargs.
        self._has_kwargs = any((
            self._expects_training_arg,
            call_spec.defaults,
            call_spec.kwonlyargs,
            call_spec.varkw,
        ))

        self._input_signature = self._generate_input_signature(layer)
        self._functions = weakref.WeakValueDictionary()
        # True while this object is in the middle of tracing the layer call
        # functions.
        self.tracing = False

        # The input argument is the first positional parameter, skipping the
        # already-bound receiver when the call method is a method. Fall back to
        # 'inputs' when there is no such parameter.
        if tf_inspect.ismethod(self.layer_call_method):
            positional_names = call_spec.args[1:]
        else:
            positional_names = call_spec.args
        self._input_arg_name = (
            positional_names[0] if positional_names else 'inputs')
Example #2
0
def fn_args(fn):
    """Get argument names for function-like object.

    Args:
      fn: Function, or function-like object (e.g., result of
        `functools.partial`, or an instance whose `__call__` is a method).

    Returns:
      `tuple` of string argument names. For a `functools.partial`, the leading
      names filled by its bound positional arguments and any names bound as
      keywords are omitted; positionally bound arguments are skipped rather
      than rejected. For a bound method, its first positional name (the
      receiver) is omitted.
    """
    if isinstance(fn, functools.partial):
        parent_args = fn_args(fn.func)
        bound_keywords = fn.keywords or {}
        # The partial's positional args fill the leading names; drop those and
        # any names already supplied as keywords.
        args = [a for a in parent_args[len(fn.args):]
                if a not in bound_keywords]
    else:
        if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__):
            # Callable object: inspect its bound `__call__` instead.
            fn = fn.__call__
        # Copy the list: the argspec returned here may be a stored object (for
        # example an argspec attached to a decorator), and popping from it in
        # place would corrupt it for every later caller.
        args = list(tf_inspect.getfullargspec(fn).args)
        if is_bound_method(fn) and args:
            # If it's a bound method, it may or may not have a self/cls first
            # argument; for example, self could be captured in *args.
            # If it does have a positional argument, it is self/cls.
            args.pop(0)
    return tuple(args)
Example #3
0
def get_training_arg_index(call_fn):
    """Returns the index of 'training' in the layer call function arguments.

    Args:
      call_fn: Call function.

    Returns:
      - n: index of 'training' among the positional parameters of `call_fn`
        (the receiver of a bound method is not counted). This is also
        returned when `call_fn` takes `*args`, as long as 'training' is
        declared before it, since it can then still be passed positionally.
      - -1: if 'training' is not a positional parameter, but it is a
        keyword-only parameter or `call_fn` accepts variable keyword
        arguments.
      - None: if layer doesn't expect a training argument.
    """
    argspec = tf_inspect.getfullargspec(call_fn)
    arg_list = argspec.args
    if tf_inspect.ismethod(call_fn):
        # A bound method reports its receiver as the first argument, but the
        # receiver is already bound and is not part of the call arguments.
        arg_list = arg_list[1:]

    # Check the positional parameters regardless of `argspec.varargs`: a
    # 'training' declared before `*args` is still a positional parameter.
    if 'training' in arg_list:
        return arg_list.index('training')
    if 'training' in argspec.kwonlyargs or argspec.varkw:
        return -1
    return None
Example #4
0
def get_training_arg_index(call_fn):
    """Returns the index of 'training' in the layer call function arguments.

    Args:
      call_fn: Call function.

    Returns:
      - n: index of 'training' among the positional parameters of `call_fn`
        (the receiver of a bound method is not counted).
      - -1: in every other case, i.e. whenever 'training' is not a positional
        parameter. This covers call functions that take 'training' only as a
        keyword-only argument, that accept variable keyword arguments, and
        that cannot take a training argument at all. This function never
        returns None.
    """
    arg_list = tf_inspect.getfullargspec(call_fn).args
    if tf_inspect.ismethod(call_fn):
        # A bound method reports its receiver as the first argument, but the
        # receiver is already bound and is not part of the call arguments.
        arg_list = arg_list[1:]
    if 'training' not in arg_list:
        return -1
    return arg_list.index('training')
Example #5
0
    def _maybe_wrap_with_training_arg(self, call_fn, match_layer_training_arg):
        """Wraps call function with added training argument if necessary.

        A wrapper is only built when this object expects a training argument
        but the layer itself does not; otherwise `call_fn` is returned as is.

        Args:
          call_fn: The call function to (maybe) wrap.
          match_layer_training_arg: If true, the wrapper removes the training
            value from the arguments before forwarding them to `call_fn`.

        Returns:
          `call_fn` unchanged, or a decorator around it whose reported argspec
          has an extra trailing `training` parameter defaulting to False. In
          the latter case `self._training_arg_index` is updated to point at
          that new parameter.
        """
        layer_takes_training = self.layer._expects_training_arg  # pylint: disable=protected-access
        if layer_takes_training or not self._expects_training_arg:
            return call_fn

        # Describe the wrapper as call_fn's signature plus `training=False`.
        spec = tf_inspect.getfullargspec(call_fn)
        extended_args = spec.args + ['training']
        extended_defaults = list(spec.defaults or []) + [False]
        extended_spec = tf_inspect.FullArgSpec(
            args=extended_args,
            varargs=spec.varargs,
            varkw=spec.varkw,
            defaults=extended_defaults,
            kwonlyargs=spec.kwonlyargs,
            kwonlydefaults=spec.kwonlydefaults,
            annotations=spec.annotations)

        # The new argument is last; a method's receiver is not counted.
        receiver_offset = 1 if tf_inspect.ismethod(call_fn) else 0
        self._training_arg_index = len(extended_args) - 1 - receiver_offset

        def wrap_with_training_arg(*args, **kwargs):
            if not match_layer_training_arg:
                return call_fn(*args, **kwargs)
            # The original call_fn does not take a training argument, so drop
            # the value here; it is propagated instead through the call
            # context created in LayerCall. Work on copies so the caller's
            # containers are not modified.
            forwarded_args = list(args)
            forwarded_kwargs = dict(kwargs)
            utils.remove_training_arg(self._training_arg_index,
                                      forwarded_args, forwarded_kwargs)
            return call_fn(*forwarded_args, **forwarded_kwargs)

        return tf_decorator.make_decorator(
            target=call_fn,
            decorator_func=wrap_with_training_arg,
            decorator_argspec=extended_spec)
Example #6
0
def is_bound_method(fn):
    """Returns whether `fn`, once TF decorators are unwrapped, is a bound method.

    Args:
      fn: A callable, possibly wrapped by `tf_decorator` decorators.

    Returns:
      True if the unwrapped target is a method whose `__self__` is not None,
      otherwise False.
    """
    _decorators, target = tf_decorator.unwrap(fn)
    if not tf_inspect.ismethod(target):
        return False
    return target.__self__ is not None
Example #7
0
def _is_callable_object(obj):
    return hasattr(obj, "__call__") and tf_inspect.ismethod(obj.__call__)