Example #1
    def _maybe_wrap_with_training_arg(self, call_fn, match_layer_training_arg):
        """Wraps call function with added training argument if necessary."""
        if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
            # Add training arg to wrapper function.
            arg_spec = tf_inspect.getfullargspec(call_fn)
            args = arg_spec.args + ['training']
            defaults = list(arg_spec.defaults or [])
            defaults.append(False)
            new_arg_spec = tf_inspect.FullArgSpec(
                args=args,
                varargs=arg_spec.varargs,
                varkw=arg_spec.varkw,
                defaults=defaults,
                kwonlyargs=arg_spec.kwonlyargs,
                kwonlydefaults=arg_spec.kwonlydefaults,
                annotations=arg_spec.annotations)

            # Set new training arg index
            self._training_arg_index = len(args) - 1
            if tf_inspect.ismethod(call_fn):
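                # A bound method's argspec still includes `self`, which
                # callers never pass positionally, so the positional index of
                # `training` shifts left by one.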
                self._training_arg_index -= 1

            def wrap_with_training_arg(*args, **kwargs):
                if match_layer_training_arg:
                    # Remove the training value, since the original call_fn does not
                    # expect a training arg. Instead, the training value will be
                    # propagated using the call context created in LayerCall.
                    args = list(args)
                    kwargs = kwargs.copy()
                    utils.remove_training_arg(self._training_arg_index, args,
                                              kwargs)
                return call_fn(*args, **kwargs)

            return tf_decorator.make_decorator(
                target=call_fn,
                decorator_func=wrap_with_training_arg,
                decorator_argspec=new_arg_spec)

        return call_fn
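
As a standalone illustration of the wrapping idea, here is a minimal pure-Python sketch. The helper name `add_training_arg` and the example `call` function are hypothetical; only `inspect` and `functools` are standard library, and the sketch assumes `call_fn` declares no `**kwargs` of its own (otherwise the parameter ordering below would be invalid):

import functools
import inspect


def add_training_arg(call_fn, default=False):
    """Hypothetical, simplified analogue of _maybe_wrap_with_training_arg."""

    @functools.wraps(call_fn)
    def wrapped(*args, training=default, **kwargs):
        # Like the original wrapper, accept and drop `training`, since the
        # underlying call_fn does not expect it.
        del training
        return call_fn(*args, **kwargs)

    # Advertise the extra argument in the wrapper's signature, playing the
    # role of tf_decorator.make_decorator's decorator_argspec.
    sig = inspect.signature(call_fn)
    params = list(sig.parameters.values()) + [
        inspect.Parameter('training', inspect.Parameter.KEYWORD_ONLY,
                          default=default)
    ]
    wrapped.__signature__ = sig.replace(parameters=params)
    return wrapped


def call(inputs):
    return inputs * 2


call_with_training = add_training_arg(call)
print(inspect.signature(call_with_training))  # -> (inputs, *, training=False)
print(call_with_training(3, training=True))   # -> 6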
Example #2
def maybe_add_training_arg(original_call, wrapped_call, expects_training_arg,
                           default_training_value):
    """Decorate call and optionally adds training argument.

  If a layer expects a training argument, this function ensures that 'training'
  is present in the layer args or kwonly args, with the default training value.

  Args:
    original_call: Original call function.
    wrapped_call: Wrapped call function.
    expects_training_arg: Whether to include 'training' argument.
    default_training_value: Default value of the training kwarg to include in
      the arg spec. If `None`, the default is `K.learning_phase()`.

  Returns:
    Tuple of (
      function that calls `wrapped_call` and sets the training arg,
      Argspec of returned function or `None` if the argspec is unchanged)
  """
    if not expects_training_arg:
        return wrapped_call, None

    def wrap_with_training_arg(*args, **kwargs):
        """Wrap the `wrapped_call` function, and set training argument."""
        training_arg_index = get_training_arg_index(original_call)
        training = get_training_arg(training_arg_index, args, kwargs)
        if training is None:
            training = default_training_value or K.learning_phase()

        args = list(args)
        kwargs = kwargs.copy()

        def replace_training_and_call(training):
            set_training_arg(training, training_arg_index, args, kwargs)
            return wrapped_call(*args, **kwargs)

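        # smart_cond executes the chosen branch immediately when `training`
        # is a Python boolean and builds a tf.cond when it is a symbolic
        # tensor, so each traced branch sees a concrete training value.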
        return control_flow_util.smart_cond(
            training, lambda: replace_training_and_call(True),
            lambda: replace_training_and_call(False))

    # Create arg spec for decorated function. If 'training' is not defined in the
    # args of the original arg spec, then add it to kwonlyargs.
    arg_spec = tf_inspect.getfullargspec(original_call)
    defaults = list(arg_spec.defaults) if arg_spec.defaults is not None else []

    kwonlyargs = arg_spec.kwonlyargs
    kwonlydefaults = arg_spec.kwonlydefaults or {}
    # Add training arg if it does not exist, or set the default training value.
    if 'training' not in arg_spec.args:
        kwonlyargs.append('training')
        kwonlydefaults['training'] = default_training_value
    else:
        index = arg_spec.args.index('training')
        training_default_index = len(arg_spec.args) - index
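        # Defaults align with the tail of `args`, so counting from the end
        # locates the default bound to 'training'. E.g. for
        # args=['self', 'inputs', 'training'], index is 2 and
        # training_default_index is 1, i.e. defaults[-1].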
        if (arg_spec.defaults
                and len(arg_spec.defaults) >= training_default_index
                and defaults[-training_default_index] is None):
            defaults[-training_default_index] = default_training_value

    decorator_argspec = tf_inspect.FullArgSpec(
        args=arg_spec.args,
        varargs=arg_spec.varargs,
        varkw=arg_spec.varkw,
        defaults=defaults,
        kwonlyargs=kwonlyargs,
        kwonlydefaults=kwonlydefaults,
        annotations=arg_spec.annotations)
    return wrap_with_training_arg, decorator_argspec
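
To make the default-substitution behavior concrete, here is a hypothetical, self-contained sketch of the same pattern. The names `with_default_training` and `call` are made up, and the sketch handles plain Python booleans rather than deferring to `K.learning_phase()` or dispatching through `smart_cond`:

def with_default_training(call_fn, default_training_value=False):
    """Hypothetical analogue of wrap_with_training_arg above."""

    def wrapped(*args, training=None, **kwargs):
        if training is None:
            # The real code falls back to K.learning_phase() when no default
            # is given; this sketch just substitutes a plain boolean.
            training = default_training_value
        return call_fn(*args, training=training, **kwargs)

    return wrapped


def call(inputs, training=None):
    return ('train' if training else 'infer', inputs)


wrapped = with_default_training(call, default_training_value=True)
print(wrapped([1, 2]))                  # -> ('train', [1, 2])
print(wrapped([1, 2], training=False))  # -> ('infer', [1, 2])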