Example #1
  def _defun_with_scope(self, scope):
    """Creates a defun wrapped inside a variable creator scope."""

    weak_wrapped_fn = None
    def wrapped_fn(*args, **kwds):
      """Wraps `self._python_function` in a variable creator scope."""
      # We register a variable creator with reduced priority. If an outer
      # variable creator is just modifying keyword arguments to the variable
      # constructor, this will work harmoniously. Since the `scope` registered
      # here actually creates the variable, it taking priority would otherwise
      # ignore the outer creator.
      #
      # If an outer variable creator calls the variable constructor manually,
      # for example creating a MirroredVariable, then they won't call our
      # creator. This means we won't be able to trace the initialization graph,
      # and so variable initializers can't depend on function arguments. This is
      # better than the alternative, tracing the initialization graph but giving
      # the user a variable type they didn't want.
      with ops.get_default_graph()._variable_creator_scope(scope, priority=50):  # pylint: disable=protected-access
        # __wrapped__ allows AutoGraph to swap in a converted function. We give
        # the function a weak reference to itself to avoid a reference cycle.
        return weak_wrapped_fn().__wrapped__(*args, **kwds)
    weak_wrapped_fn = weakref.ref(wrapped_fn)

    # TODO(mdan): Pipe self._experimental_autograph_options through.
    return function_lib.defun(
        tf_decorator.make_decorator(self._python_function, wrapped_fn),
        input_signature=self._input_signature,
        autograph=self._autograph,
        experimental_autograph_options=self._experimental_autograph_options)
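
The weak self-reference in Example #1 exists to break the cycle that would otherwise form between `wrapped_fn` and its own closure. Below is a minimal, TensorFlow-free sketch of the same pattern (all names are illustrative, not taken from the snippet above):

```python
import weakref

def make_self_calling_wrapper(fn):
  """Builds a wrapper that only reaches itself through a weak reference."""
  weak_wrapped_fn = None  # rebound below, once wrapped_fn exists

  def wrapped_fn(*args, **kwds):
    # Dereference the weak ref instead of closing over wrapped_fn directly,
    # so the wrapper does not keep a strong reference to itself alive.
    return weak_wrapped_fn().__wrapped__(*args, **kwds)

  wrapped_fn.__wrapped__ = fn  # stand-in for what make_decorator attaches
  weak_wrapped_fn = weakref.ref(wrapped_fn)
  return wrapped_fn

double = make_self_calling_wrapper(lambda x: 2 * x)
assert double(3) == 6
```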
Example #2
 def decorated(function):
   try:
     name = function.__name__
   except AttributeError:
     name = "function"
   return tf_decorator.make_decorator(
       function, named_defun(function, name, compiled=compiled))
Example #3
def update_state_wrapper(update_state_fn):
  """Decorator to wrap metric `update_state()` with `defun()`, `add_update()`.

  Args:
    update_state_fn: function that accumulates metric statistics.

  Returns:
    Decorated function that wraps `update_state_fn()`. When graph execution is
    enabled, the decorated function returns an update op that should be
    executed to update the metric state with the given inputs; when eager
    execution is enabled, it returns None.
  """

  def decorated(metric_obj, *args, **kwargs):
    """Decorated function with `defun()` and `add_update()`."""

    # Converting update_state_fn() into a graph function, so that
    # we can return a single op that performs all of the variable updates.
    # Assigning to a different method name to avoid reference cycle.
    defuned_update_state_fn = function.defun(update_state_fn)
    update_op = defuned_update_state_fn(*args, **kwargs)
    if update_op is not None:  # update_op will be None in eager execution.
      metric_obj.add_update(update_op, inputs=True)
      check_is_tensor_or_operation(
          update_op, 'Metric {0}\'s update'.format(metric_obj.name))
    return update_op

  return tf_decorator.make_decorator(update_state_fn, decorated)
Example #4
def kwarg_only(f):
  """A wrapper that throws away all non-kwarg arguments."""
  def wrapper(**kwargs):
    return f(**kwargs)

  return tf_decorator.make_decorator(
      f, wrapper, decorator_argspec=tf_inspect.getargspec(f))
Example #5
def must_use_result_or_fatal(fn):
  """Function wrapper that ensures the function's output is used.

  If the output is not used, a `tf.compat.v1.logging.fatal` error is raised.

  An output is marked as used if any of its attributes are read, modified, or
  updated.  Examples when the output is a `Tensor` include:

  - Using it in any capacity (e.g. `y = t + 0`, `sess.run(t)`)
  - Accessing a property (e.g. getting `t.name` or `t.op`).

  Note, certain behaviors cannot be tracked - for these the object may not
  be marked as used.  Examples include:

  - `t != 0`.  In this case, comparison is done on types / ids.
  - `isinstance(t, tf.Tensor)`.  Similar to above.

  Args:
    fn: The function to wrap.

  Returns:
    The wrapped function.
  """
  def wrapped(*args, **kwargs):
    return _add_should_use_warning(fn(*args, **kwargs), fatal_error=True)
  return tf_decorator.make_decorator(
      fn, wrapped, 'must_use_result_or_fatal',
      ((fn.__doc__ or '') +
       ('\n\n  '
        '**NOTE** The output of this function must be used.  If it is not, '
        'a fatal error will be raised.  To mark the output as used, '
        'call its .mark_used() method.')))
Example #6
  def testCompatibleWithNamelessCallables(self):

    class Callable(object):

      def __call__(self):
        pass

    callable_object = Callable()
    # Smoke test: This should not raise an exception, even though
    # `callable_object` does not have a `__name__` attribute.
    _ = tf_decorator.make_decorator(callable_object, test_wrapper)

    partial = functools.partial(test_function, x=1)
    # Smoke test: This should not raise an exception, even though `partial` does
    # not have `__name__`, `__module__`, and `__doc__` attributes.
    _ = tf_decorator.make_decorator(partial, test_wrapper)
Example #7
def custom_gradient(f):
  """Decorator to define a function with a custom gradient.

  The input function is expected to return the tuple
    (results, gradient_function).

  The output function will return results while possibly recording the
  gradient_function and inputs in the tape.

  Args:
    f: function to be decorated.

  Returns:
    decorated function.
  """

  def decorated(*args, **kwargs):
    """Decorated function with custom gradient."""
    if context.in_graph_mode():
      if kwargs:
        raise ValueError(
            "custom_gradient in graph mode doesn't support keyword arguments.")
      name = "CustomGradient-%s" % tf_ops.uid()
      args = [tf_ops.convert_to_tensor(x) for x in args]
      result, grad_fn = f(*args)
      flat_result = nest.flatten(result)
      all_tensors = flat_result + args

      @tf_ops.RegisterGradient(name)
      def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
        gradients = nest.flatten(grad_fn(*result_grads[:len(flat_result)]))
        # Need to return one value per input to the IdentityN, so pad the
        # gradients of the inputs of the custom_gradient function with the
        # gradients of the outputs as well.
        return ([None] * len(flat_result)) + gradients

      with tf_ops.get_default_graph().gradient_override_map(
          {"IdentityN": name}):
        all_tensors = array_ops.identity_n(all_tensors)
      return nest.pack_sequence_as(
          structure=result, flat_sequence=all_tensors[:len(flat_result)])

    input_tensors = [tf_ops.convert_to_tensor(x) for x in args]

    with tape.stop_recording():
      result, grad_fn = f(*args, **kwargs)

    def actual_grad_fn(*outputs):
      return nest.flatten(grad_fn(*outputs))

    flat_result = nest.flatten(result)
    tape.record_operation(
        f.__name__,
        flat_result,
        input_tensors,
        actual_grad_fn)
    flat_result = list(flat_result)
    return result

  return tf_decorator.make_decorator(f, decorated)
Example #8
def no_automatic_dependency_tracking(method):
  """Disables automatic dependency tracking on attribute assignment.

  Use to decorate any method of a Checkpointable object. Attribute assignment in
  that method will not add dependencies (also respected in Model). Harmless if
  used in a class which does not do automatic dependency tracking (which means
  it's safe to use in base classes which may have subclasses which also inherit
  from Checkpointable).

  Args:
    method: The method to decorate.
  Returns:
    A decorated method which sets and un-sets automatic dependency tracking for
    the object the method is called on (not thread safe).
  """

  def _method_wrapper(self, *args, **kwargs):
    previous_value = getattr(self, "_setattr_tracking", True)
    self._setattr_tracking = False  # pylint: disable=protected-access
    try:
      method(self, *args, **kwargs)
    finally:
      self._setattr_tracking = previous_value  # pylint: disable=protected-access

  return tf_decorator.make_decorator(
      target=method, decorator_func=_method_wrapper)
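
The wrapper above is an instance of a save/restore-a-flag pattern: stash the current value, overwrite it for the duration of the call, and restore it in a `finally` block. A dependency-free sketch of that pattern follows (the `tf_decorator.make_decorator` call is omitted and the names are illustrative):

```python
def no_tracking(method):
  """Temporarily disables `_setattr_tracking` while `method` runs."""
  def _method_wrapper(self, *args, **kwargs):
    previous_value = getattr(self, "_setattr_tracking", True)
    self._setattr_tracking = False
    try:
      return method(self, *args, **kwargs)
    finally:
      self._setattr_tracking = previous_value  # restored even if method raises
  return _method_wrapper

class Tracked(object):
  @no_tracking
  def build(self):
    return self._setattr_tracking  # False only while build() is running

obj = Tracked()
assert obj.build() is False
assert obj._setattr_tracking is True  # restored to the previous default
```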
Example #9
  def with_name_scope(cls, method):
    """Decorator to automatically enter the module name scope.

    ```
    class MyModule(tf.Module):
      @tf.Module.with_name_scope
      def __call__(self, x):
        if not hasattr(self, 'w'):
          self.w = tf.Variable(tf.random.normal([x.shape[1], 64]))
        return tf.matmul(x, self.w)
    ```

    Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose
    names include the module name:

    ```
    mod = MyModule()
    mod(tf.ones([8, 32]))
    # ==> <tf.Tensor: ...>
    mod.w
    # ==> <tf.Variable ...'my_module/w:0'>
    ```

    Args:
      method: The method to wrap.

    Returns:
      The original method wrapped such that it enters the module's name scope.
    """
    def method_with_name_scope(self, *args, **kwargs):
      with self.name_scope:
        return method(self, *args, **kwargs)

    return tf_decorator.make_decorator(method, method_with_name_scope)
Example #10
def with_name_scope(unbound_method):
  """Patches the given method so it enters the modules name scope."""
  def enter_name_scope(self, *args, **kwargs):
    """Decorator that calls the given function in the module name scope.

    Args:
      self: Module instance.
      *args: Positional arguments to `unbound_method`.
      **kwargs: Keyword arguments to `unbound_method`.

    Returns:
      `with self.name_scope: return unbound_method(self, *args, **kwargs)`
    """
    try:
      module_name_scope = self.name_scope
    except AttributeError as exc_value_from:
      exc_value = AttributeError(
          "The super constructor must be called before any other methods in "
          "your constructor. If this is not possible then annotate all the "
          "methods called with `@no_module_name_scope`.")
      six.raise_from(exc_value, exc_value_from)

    with module_name_scope:
      # tf.Module enters the module name scope for all methods. To disable this
      # for a particular method annotate it with `@no_module_name_scope`.
      return unbound_method(self, *args, **kwargs)

  return tf_decorator.make_decorator(unbound_method, enter_name_scope)
Example #11
 def testUpdatesDict_doesNotOverridePresentEntries(self):
   test_function.foobar = True
   test_wrapper.foobar = False
   decorated = tf_decorator.make_decorator(test_function, test_wrapper)
   self.assertFalse(decorated.foobar)
   del test_function.foobar
   del test_wrapper.foobar
Example #12
def _defun_with_scope(scope, fn, input_signature):

  def wrapped_fn(*args, **kwds):
    with variable_scope.variable_creator_scope(scope):
      return fn(*args, **kwds)

  return function_lib.defun(tf_decorator.make_decorator(fn, wrapped_fn),
                            input_signature=input_signature)
Example #13
def defun(func):
  """Decorator to compile func into graph_mode.

  `defun` converts a function that constructs a TensorFlow graph into a function
  that executes the graph. TensorFlow graphs typically execute faster and with a
  lower memory-footprint than executing each of the operations that make up the
  function individually as the TensorFlow runtime can optimize the graph and
  execute sub-operations in parallel.

  func must be a Python function that constructs a TensorFlow graph,
  typically using functions in the tensorflow module.

  Arguments to func can be either Tensor objects or Python
  objects. Non-Tensor python objects are treated as constants, and new function
  definitions are created internally based on their values.

  func must return a tf.Tensor (NOT a Tensor) or a list of tf.Tensor (NOT a
  Tensor).

  Control flow constructs (e.g., `if`, `while`) are not yet compatible with
  `defun`.

  Example:
  ```python
  def f(x, y):
    return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)

  @tfe.defun
  def g(x, y):
    return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)

  x = tf.constant([[2.0, 3.0]])
  y = tf.constant([[3.0, -2.0]])
  # The plain function and defun-compiled function should return the same value.
  assert f(x, y).numpy() == g(x, y).numpy()

  # After the first invocation, the defun-compiled (graph) function runs faster
  # than the plain function because the defun-compiled function does not involve
  # Python interpreter overhead during the execution.
  %time print(f(x, y))
  %time print(g(x, y))
  ```

  Args:
    func: function to be compiled.

  Returns:
     A callable that will execute the compiled function (and return zero
     or more Tensor objects).
  """
  # TODO(apassos): deal with captured global state. Deal with control flow.
  try:
    name = func.__name__
  except AttributeError:
    name = "function"
  return tf_decorator.make_decorator(func, named_defun(func, name))
Example #14
 def testSetsTFDecoratorArgSpec(self):
   argspec = tf_inspect.ArgSpec(
       args=['a', 'b', 'c'],
       varargs=None,
       keywords=None,
       defaults=(1, 'hello'))
   decorated = tf_decorator.make_decorator(test_function, test_wrapper, '', '',
                                           argspec)
   decorator = getattr(decorated, '_tf_decorator')
   self.assertEqual(argspec, decorator.decorator_argspec)
Example #15
 def decorated(inner_function):
   try:
     name = inner_function.__name__
   except AttributeError:
     name = "function"
   return tf_decorator.make_decorator(
       inner_function,
       PolymorphicFunction(
           inner_function,
           name,
           input_signature=input_signature))
Example #16
def contextmanager(target):
  """A tf_decorator-aware wrapper for `contextlib.contextmanager`.

  Usage is identical to `contextlib.contextmanager`.

  Args:
    target: A callable to be wrapped in a contextmanager.
  Returns:
    A callable that can be used inside of a `with` statement.
  """
  context_manager = _contextlib.contextmanager(target)
  return tf_decorator.make_decorator(target, context_manager, 'contextmanager')
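
A hedged usage sketch for Example #16. It assumes a TensorFlow installation that exposes the internal `tf_contextlib` and `tf_inspect` modules; the decorated generator is illustrative:

```python
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect

@tf_contextlib.contextmanager
def scoped(name, verbose=False):
  # Behaves exactly like a contextlib.contextmanager-decorated generator.
  yield name.upper()

with scoped("demo") as value:
  assert value == "DEMO"

# Because the wrapper went through tf_decorator.make_decorator, TensorFlow's
# introspection utilities should still see the original signature rather than
# the generic helper produced by contextlib.
print(tf_inspect.getfullargspec(scoped).args)  # expected: ['name', 'verbose']
```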
Example #17
def wrap_keras_model_for_export(model, batch_input_shape,
                                set_hparams, default_hparams):
  """Wraps `model` for saving and loading as SavedModel."""
  if default_hparams is None: default_hparams = {}
  hparam_keys = list(default_hparams.keys())
  hparam_defaults = tuple(default_hparams.values())
  # The goal is to save a function with this argspec...
  argspec = tf_inspect.FullArgSpec(
      args=(['inputs', 'training'] + hparam_keys),
      defaults=((False,) + hparam_defaults),
      varargs=None, varkw=None,
      kwonlyargs=[], kwonlydefaults=None,
      annotations={})
  # ...and this behavior:
  def call_fn(inputs, training, *args):
    if FLAGS.export_print_hparams:
      args = [tf.keras.backend.print_tensor(args[i], 'training=%s and %s='
                                            % (training, hparam_keys[i]))
              for i in range(len(args))]
    kwargs = dict(zip(hparam_keys, args))
    if kwargs: set_hparams(model, **kwargs)
    return model(inputs, training=training)
  # We cannot spell out `args` in def statement for call_fn, but since
  # tf.function uses tf_inspect, we can use tf_decorator to wrap it with
  # the desired argspec.
  def wrapped(*args, **kwargs):  # TODO(arnoegw): Can we use call_fn itself?
    return call_fn(*args, **kwargs)
  traced_call_fn = tf.function(autograph=False)(
      tf_decorator.make_decorator(call_fn, wrapped, decorator_argspec=argspec))
  # Now we need to trigger traces for
  # - training set to Python values True or False (hence two traces),
  # - tensor inputs of the expected nesting, shape and dtype,
  # - tensor-valued kwargs for hparams, with caller-side defaults.
  # Tracing with partially determined shapes requires an input signature,
  # so we initiate tracing from a helper function with only tensor inputs.
  @tf.function(autograph=False)
  def trigger_traces(inputs, **kwargs):
    return tuple(traced_call_fn(inputs, training=training, **kwargs)
                 for training in (True, False))
  inputs_spec = tf.TensorSpec(shape=batch_input_shape, dtype=tf.float32)
  hparams_spec = {name: tf.TensorSpec.from_tensor(tf.constant(value))
                  for name, value in default_hparams.items()}
  _ = trigger_traces.get_concrete_function(inputs_spec, **hparams_spec)

  # Assemble the output object.
  obj = tf.train.Checkpoint()
  obj.__call__ = traced_call_fn
  obj.trainable_variables = model.trainable_variables
  obj.variables = model.trainable_variables + model.non_trainable_variables
  obj.regularization_losses = [_get_traced_loss(model, i)
                               for i in range(len(model.losses))]
  return obj
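
The key trick in the example above is `decorator_argspec`: `call_fn` is written with `*args`, yet tooling that relies on `tf_inspect` (including `tf.function`) sees the named parameters attached to the wrapper. Here is a hedged, stand-alone sketch of just that mechanism, with made-up parameter names and assuming a TensorFlow install:

```python
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect

def call_fn(inputs, training, *args):
  return (inputs, training) + args

# The argspec we want callers and tf_inspect-based tooling to see.
argspec = tf_inspect.FullArgSpec(
    args=['inputs', 'training', 'dropout_rate'],
    varargs=None, varkw=None,
    defaults=(False, 0.1),
    kwonlyargs=[], kwonlydefaults=None, annotations={})

def wrapped(*args, **kwargs):
  return call_fn(*args, **kwargs)

exported = tf_decorator.make_decorator(
    call_fn, wrapped, decorator_argspec=argspec)

print(exported([1.0, 2.0], True, 0.2))           # behaves like call_fn
print(tf_inspect.getfullargspec(exported).args)  # expected: ['inputs', 'training', 'dropout_rate']
```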
Example #18
  def decorator(f):
    """Decorator implementation."""

    @wraps(f)
    def wrapper(*args, **kwargs):
      return converted_call(f, recursive, verbose, arg_types, *args, **kwargs)

    wrapper = tf_decorator.make_decorator(f, wrapper)

    # Sometimes the decorator is just desugared, making it impossible to detect.
    # This attribute makes detection easier.
    setattr(wrapper, '__pyct_is_compile_decorator', True)
    return wrapper
Example #19
 def decorated(inner_function):
   try:
     name = inner_function.__name__
   except AttributeError:
     name = "function"
   return tf_decorator.make_decorator(
       inner_function,
       Function(
           inner_function,
           name,
           input_signature=input_signature,
           autograph=autograph,
           experimental_autograph_options=experimental_autograph_options))
Example #20
  def _defun_with_scope(self, scope):
    """Creates a defun wrapped inside a variable creator scope."""

    def wrapped_fn(*args, **kwds):
      with variable_scope.variable_creator_scope(scope):
        # __wrapped__ allows AutoGraph to swap in a converted function.
        return wrapped_fn.__wrapped__(*args, **kwds)

    # TODO(mdan): Pipe self._experimental_autograph_options through.
    return function_lib.defun(
        tf_decorator.make_decorator(self._python_function, wrapped_fn),
        input_signature=self._input_signature,
        autograph=self._autograph)
Example #21
def kwarg_only(f):
  """A wrapper that throws away all non-kwarg arguments."""
  f_argspec = tf_inspect.getargspec(f)

  def wrapper(*args, **kwargs):
    if args:
      raise TypeError(
          '{f} only takes keyword args (possible keys: {kwargs}). '
          'Please pass these args as kwargs instead.'
          .format(f=f.__name__, kwargs=f_argspec.args))
    return f(**kwargs)

  return tf_decorator.make_decorator(f, wrapper, decorator_argspec=f_argspec)
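
For clarity, here is a plain-Python sketch of how the `kwarg_only` wrapper above behaves; the `tf_decorator`/`tf_inspect` plumbing is replaced by the standard library, and `resize` is an illustrative function:

```python
import inspect

def kwarg_only(f):
  """Rejects positional arguments, mirroring the wrapper above."""
  f_argspec = inspect.getfullargspec(f)

  def wrapper(*args, **kwargs):
    if args:
      raise TypeError(
          '{f} only takes keyword args (possible keys: {kwargs}). '
          'Please pass these args as kwargs instead.'
          .format(f=f.__name__, kwargs=f_argspec.args))
    return f(**kwargs)

  return wrapper

@kwarg_only
def resize(width, height):
  return (width, height)

assert resize(width=2, height=3) == (2, 3)
try:
  resize(2, 3)  # positional arguments are rejected
except TypeError as e:
  print(e)
```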
Example #22
def result_wrapper(result_fn):
  """Decorator to wrap metric `result()` function in `merge_call()`.

  Result computation is an idempotent operation that simply calculates the
  metric value using the state variables.

  If metric state variables are distributed across replicas/devices and
  `result()` is requested from the context of one device, this function wraps
  `result()` in a distribution strategy `merge_call()`. With this,
  the metric state variables will be aggregated across devices.

  Args:
    result_fn: function that computes the metric result.

  Returns:
    Decorated function that wraps `result_fn()` in distribution strategy
    `merge_call()`.
  """

  def decorated(_, *args):
    """Decorated function with merge_call."""
    replica_context = distribution_strategy_context.get_replica_context()
    if replica_context is None:  # if in cross replica context already
      result_t = array_ops.identity(result_fn(*args))
    else:
      # TODO(psv): Test distribution of metrics using different distribution
      # strategies.

      # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn
      # with distribution object as the first parameter. We create a wrapper
      # here so that the result function need not have that parameter.
      def merge_fn_wrapper(distribution, merge_fn, *args):
        # We will get `PerReplica` merge function. Taking the first one as all
        # are identical copies of the function that we had passed below.
        merged_result_fn = (
            distribution.experimental_local_results(merge_fn)[0](*args))

        # Wrapping result in identity so that control dependency between
        # update_op from `update_state` and result works in case result returns
        # a tensor.
        return array_ops.identity(merged_result_fn)

      # Wrapping result in merge_call. merge_call is used when we want to leave
      # replica mode and compute a value in cross replica mode.
      result_t = replica_context.merge_call(
          merge_fn_wrapper, args=(result_fn,) + args)
    return result_t

  return tf_decorator.make_decorator(result_fn, decorated)
Example #23
def test_decorator_increment_first_int_arg(target):
  """This test decorator skips past `self` as args[0] in the bound case."""

  def wrapper(*args, **kwargs):
    new_args = []
    found = False
    for arg in args:
      if not found and isinstance(arg, int):
        new_args.append(arg + 1)
        found = True
      else:
        new_args.append(arg)
    return target(*new_args, **kwargs)

  return tf_decorator.make_decorator(target, wrapper)
Example #24
  def _defun_with_scope(self, scope):
    """Creates a defun wrapped inside a variable creator scope."""

    weak_wrapped_fn = None
    def wrapped_fn(*args, **kwds):
      with variable_scope.variable_creator_scope(scope):
        # __wrapped__ allows AutoGraph to swap in a converted function. We give
        # the function a weak reference to itself to avoid a reference cycle.
        return weak_wrapped_fn().__wrapped__(*args, **kwds)
    weak_wrapped_fn = weakref.ref(wrapped_fn)

    # TODO(mdan): Pipe self._experimental_autograph_options through.
    return function_lib.defun(
        tf_decorator.make_decorator(self._python_function, wrapped_fn),
        input_signature=self._input_signature,
        autograph=self._autograph)
Example #25
 def deprecated_wrapper(func):
   """Deprecation wrapper."""
   decorator_utils.validate_callable(func, 'deprecated')
   @functools.wraps(func)
   def new_func(*args, **kwargs):
     logging.warning(
         'From %s: %s (from %s) is deprecated and will be removed %s.\n'
         'Instructions for updating:\n%s',
         _call_location(), decorator_utils.get_qualified_name(func),
         func.__module__,
         'in a future version' if date is None else ('after %s' % date),
         instructions)
     return func(*args, **kwargs)
   return tf_decorator.make_decorator(
       func, new_func, 'deprecated',
       _add_deprecated_function_notice_to_docstring(func.__doc__, date,
                                                    instructions))
Example #26
def result_wrapper(result_fn):
  """Decorator to wrap metric `result()` function in `merge_call()`.

  Result computation is an idempotent operation that simply calculates the
  metric value using the state variables.

  If metric state variables are distributed across replicas/devices and
  `result()` is requested from the context of one device, this function wraps
  `result()` in a distribution strategy `merge_call()`. With this,
  the metric state variables will be aggregated across devices.

  Args:
    result_fn: function that computes the metric result.

  Returns:
    Decorated function that wraps `result_fn()` in distribution strategy
    `merge_call()`.
  """

  def decorated(metric_obj, *args):
    """Decorated function with merge_call."""
    replica_context = distribution_strategy_context.get_replica_context()
    if replica_context is None:  # if in cross replica context already
      result_t = result_fn(*args)
    else:
      # TODO(psv): Test distribution of metrics using different distribution
      # strategies.

      # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn
      # with distribution object as the first parameter. We create a wrapper
      # here so that the result function need not have that parameter.
      def merge_fn_wrapper(distribution, merge_fn, *args):
        # We will get `PerDevice` merge function. Taking the first one as all
        # are identical copies of the function that we had passed below.
        return distribution.unwrap(merge_fn)[0](*args)

      # Wrapping result in merge_call. merge_call is used when we want to leave
      # replica mode and compute a value in cross replica mode.
      result_t = replica_context.merge_call(
          merge_fn_wrapper, args=(result_fn,) + args)
    check_is_tensor_or_operation(result_t,
                                 'Metric {0}\'s result'.format(metric_obj.name))
    return result_t

  return tf_decorator.make_decorator(result_fn, decorated)
Example #27
def _wrap_define_function(original_function):
  """Wraps absl.flags's define functions so tf.flags accepts old names."""

  def wrapper(*args, **kwargs):
    """Wrapper function that turns old keyword names to new ones."""
    has_old_names = False
    for old_name, new_name in _six.iteritems(_RENAMED_ARGUMENTS):
      if old_name in kwargs:
        has_old_names = True
        value = kwargs.pop(old_name)
        kwargs[new_name] = value
    if has_old_names:
      _logging.warning(
          'Use of the keyword argument names (flag_name, default_value, '
          'docstring) is deprecated, please use (name, default, help) instead.')
    return original_function(*args, **kwargs)

  return tf_decorator.make_decorator(original_function, wrapper)
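
The wrapper above simply translates deprecated keyword names before delegating. A minimal sketch of the same pattern without absl or `tf_decorator` follows; the mapping and function names are illustrative only:

```python
_RENAMED_ARGUMENTS = {'flag_name': 'name', 'default_value': 'default',
                      'docstring': 'help'}

def accept_old_names(original_function):
  def wrapper(*args, **kwargs):
    for old_name, new_name in list(_RENAMED_ARGUMENTS.items()):
      if old_name in kwargs:
        kwargs[new_name] = kwargs.pop(old_name)  # translate old -> new
    return original_function(*args, **kwargs)
  return wrapper

@accept_old_names
def define_flag(name, default, help):
  return {'name': name, 'default': default, 'help': help}

# Old-style keyword names are accepted and silently translated.
result = define_flag(flag_name='batch_size', default_value=32, docstring='size')
assert result == {'name': 'batch_size', 'default': 32, 'help': 'size'}
```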
Example #28
def update_state_wrapper(update_state_fn):
  """Decorator to wrap metric `update_state()` with `add_update()`.

  Args:
    update_state_fn: function that accumulates metric statistics.

  Returns:
    Decorated function that wraps `update_state_fn()` with `add_update()`.
  """

  def decorated(metric_obj, *args, **kwargs):
    """Decorated function with `add_update()`."""

    update_op = update_state_fn(*args, **kwargs)
    if update_op is not None:  # update_op will be None in eager execution.
      metric_obj.add_update(update_op, inputs=True)
    return update_op

  return tf_decorator.make_decorator(update_state_fn, decorated)
Example #29
  def decorator(f):
    """Decorator implementation."""

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
      return converted_call(
          f, None,
          converter.ConversionOptions(
              recursive=recursive,
              verbose=verbose,
              force_conversion=True,
          ), *args, **kwargs)

    wrapper = tf_decorator.make_decorator(f, wrapper)

    # Sometimes the decorator is just desugared, making it impossible to detect.
    # This attribute makes detection easier.
    setattr(wrapper, '__pyct_is_compile_decorator', True)
    return wrapper
Example #30
 def deprecated_wrapper(func):
   """Deprecation decorator."""
   decorator_utils.validate_callable(func, 'deprecated_arg_values')
   @functools.wraps(func)
   def new_func(*args, **kwargs):
     """Deprecation wrapper."""
     named_args = tf_inspect.getcallargs(func, *args, **kwargs)
     for arg_name, arg_value in deprecated_kwargs.items():
       if arg_name in named_args and named_args[arg_name] == arg_value:
         logging.warning(
             'From %s: calling %s (from %s) with %s=%s is deprecated and will '
             'be removed %s.\nInstructions for updating:\n%s',
             _call_location(), decorator_utils.get_qualified_name(func),
             func.__module__, arg_name, arg_value,
             'in a future version' if date is None else ('after %s' % date),
             instructions)
     return func(*args, **kwargs)
   return tf_decorator.make_decorator(func, new_func, 'deprecated',
                                      _add_deprecated_arg_notice_to_docstring(
                                          func.__doc__, date, instructions))
Example #31
    def deprecated_wrapper(func):
        """Deprecation decorator."""
        decorator_utils.validate_callable(func, 'deprecated_args')
        deprecated_arg_names = _get_arg_names_to_ok_vals()

        arg_spec = tf_inspect.getargspec(func)
        deprecated_positions = _get_deprecated_positional_arguments(
            deprecated_arg_names, arg_spec)

        is_varargs_deprecated = arg_spec.varargs in deprecated_arg_names
        is_kwargs_deprecated = arg_spec.keywords in deprecated_arg_names

        if (len(deprecated_positions) + is_varargs_deprecated +
                is_kwargs_deprecated != len(deprecated_arg_names_or_tuples)):
            known_args = arg_spec.args + [arg_spec.varargs, arg_spec.keywords]
            missing_args = [
                arg_name for arg_name in deprecated_arg_names
                if arg_name not in known_args
            ]
            raise ValueError(
                'The following deprecated arguments are not present '
                'in the function signature: %s. '
                'Found next arguments: %s.' % (missing_args, known_args))

        def _same_value(a, b):
            """A comparison operation that works for multiple object types.

            Returns True for two empty lists, two numeric values with the
            same value, etc.

            Returns False for (pd.DataFrame, None), and other pairs which
            should not be considered equivalent.

            Args:
              a: value one of the comparison.
              b: value two of the comparison.

            Returns:
              A boolean indicating whether the two inputs are the same value
              for the purposes of deprecation.
            """
            if a is b:
                return True
            try:
                equality = a == b
                if isinstance(equality, bool):
                    return equality
            except TypeError:
                return False
            return False

        @functools.wraps(func)
        def new_func(*args, **kwargs):
            """Deprecation wrapper."""
            if _PRINT_DEPRECATION_WARNINGS:
                invalid_args = []
                named_args = tf_inspect.getcallargs(func, *args, **kwargs)
                for arg_name, spec in iter(deprecated_positions.items()):
                    if (spec.position < len(args)
                            and not (spec.has_ok_value and _same_value(
                                named_args[arg_name], spec.ok_value))):
                        invalid_args.append(arg_name)
                if is_varargs_deprecated and len(args) > len(arg_spec.args):
                    invalid_args.append(arg_spec.varargs)
                if is_kwargs_deprecated and kwargs:
                    invalid_args.append(arg_spec.keywords)
                for arg_name in deprecated_arg_names:
                    if (arg_name in kwargs and not (
                            deprecated_positions[arg_name].has_ok_value
                            and _same_value(
                                named_args[arg_name],
                                deprecated_positions[arg_name].ok_value))):
                        invalid_args.append(arg_name)
                for arg_name in invalid_args:
                    logging.warning(
                        'From %s: calling %s (from %s) with %s is deprecated and will '
                        'be removed %s.\nInstructions for updating:\n%s',
                        _call_location(),
                        decorator_utils.get_qualified_name(func),
                        func.__module__, arg_name,
                        'in a future version' if date is None else
                        ('after %s' % date), instructions)
            return func(*args, **kwargs)

        return tf_decorator.make_decorator(
            func, new_func, 'deprecated',
            _add_deprecated_arg_notice_to_docstring(func.__doc__, date,
                                                    instructions))
Example #32
def test_injectable_decorator_square(target):
    def wrapper(x):
        return wrapper.__wrapped__(x)**2

    return tf_decorator.make_decorator(target, wrapper)
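
This test decorator works because `make_decorator` attaches the target as `wrapper.__wrapped__`, so the wrapper can call the target indirectly and a caller (or AutoGraph) can later swap it out. A hedged usage sketch, assuming a TensorFlow install for the internal `tf_decorator` module:

```python
from tensorflow.python.util import tf_decorator

def square_of(target):
  def wrapper(x):
    return wrapper.__wrapped__(x) ** 2
  return tf_decorator.make_decorator(target, wrapper)

f = square_of(lambda x: x + 1)
print(f(3))                    # (3 + 1) ** 2 == 16
f.__wrapped__ = lambda x: x    # inject a different target after the fact
print(f(3))                    # 3 ** 2 == 9
```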
Example #33
def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True):
    """Deprecate a symbol in favor of a new name with identical semantics.

  This function is meant to be used when defining a backwards-compatibility
  alias for a symbol which has been moved. For example:

  module1.py:
  ```python
  class NewNameForClass: pass
  ```

  module2.py:
  ```python
  import module1

  DeprecatedNameForClass = deprecated_alias(
    deprecated_name='module2.DeprecatedNameForClass',
    name='module1.NewNameForClass',
    func_or_class=module1.NewNameForClass)
  ```

  This function works for classes and functions.

  For classes, it creates a new class which is functionally identical (it
  inherits from the original, and overrides its constructor), but which prints
  a deprecation warning when an instance is created. It also adds a deprecation
  notice to the class' docstring.

  For functions, it returns a function wrapped by `tf_decorator.make_decorator`.
  That function prints a warning when used, and has a deprecation notice in its
  docstring. This is more or less equivalent (the deprecation warning has
  slightly different text) to writing:

  ```python
  @deprecated
  def deprecated_alias(original_args):
    real_function(original_args)
  ```

  Args:
    deprecated_name: The name of the symbol that is being deprecated, to be used
      in the warning message. This should be its fully qualified name to avoid
      confusion.
    name: The name of the symbol that is to be used instead of the deprecated
      name. This should be a fully qualified name to avoid confusion.
    func_or_class: The (non-deprecated) class or function for which a deprecated
      alias should be created.
    warn_once: If True (the default), only print a deprecation warning the first
      time this function is used, or the class is instantiated.

  Returns:
    A wrapped version of `func_or_class` which prints a deprecation warning on
    use and has a modified docstring.
  """
    if tf_inspect.isclass(func_or_class):

        # Make a new class with __init__ wrapped in a warning.
        class _NewClass(func_or_class):  # pylint: disable=missing-docstring
            __doc__ = decorator_utils.add_notice_to_docstring(
                func_or_class.__doc__, 'Please use %s instead.' % name,
                'DEPRECATED CLASS', '(deprecated)', [
                    'THIS CLASS IS DEPRECATED. '
                    'It will be removed in a future version. '
                ])
            __name__ = func_or_class.__name__
            __module__ = _call_location(outer=True)

            @_wrap_decorator(func_or_class.__init__)
            def __init__(self, *args, **kwargs):
                if hasattr(_NewClass.__init__, '__func__'):
                    # Python 2
                    _NewClass.__init__.__func__.__doc__ = func_or_class.__init__.__doc__
                else:
                    # Python 3
                    _NewClass.__init__.__doc__ = func_or_class.__init__.__doc__

                if _PRINT_DEPRECATION_WARNINGS:
                    # We're making the alias as we speak. The original may have other
                    # aliases, so we cannot use it to check for whether it's already been
                    # warned about.
                    if _NewClass.__init__ not in _PRINTED_WARNING:
                        if warn_once:
                            _PRINTED_WARNING[_NewClass.__init__] = True
                        logging.warning(
                            'From %s: The name %s is deprecated. Please use %s instead.\n',
                            _call_location(), deprecated_name, name)
                super(_NewClass, self).__init__(*args, **kwargs)

        return _NewClass
    else:
        decorator_utils.validate_callable(func_or_class, 'deprecated')

        # Make a wrapper for the original
        @functools.wraps(func_or_class)
        def new_func(*args, **kwargs):  # pylint: disable=missing-docstring
            if _PRINT_DEPRECATION_WARNINGS:
                # We're making the alias as we speak. The original may have other
                # aliases, so we cannot use it to check for whether it's already been
                # warned about.
                if new_func not in _PRINTED_WARNING:
                    if warn_once:
                        _PRINTED_WARNING[new_func] = True
                    logging.warning(
                        'From %s: The name %s is deprecated. Please use %s instead.\n',
                        _call_location(), deprecated_name, name)
            return func_or_class(*args, **kwargs)

        return tf_decorator.make_decorator(
            func_or_class, new_func, 'deprecated',
            _add_deprecated_function_notice_to_docstring(
                func_or_class.__doc__, None, 'Please use %s instead.' % name))
Example #34
 def testSetsTFDecoratorNameToDecoratorNameArg(self):
     decorated = tf_decorator.make_decorator(test_function, test_wrapper,
                                             'test decorator name')
     decorator = getattr(decorated, '_tf_decorator')
     self.assertEqual('test decorator name', decorator.decorator_name)
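
Taken together, the tests above exercise the bookkeeping `make_decorator` performs: it copies the target's metadata onto the wrapper, merges `__dict__` entries without overriding the wrapper's own, and attaches a `_tf_decorator` record plus `__wrapped__`. A hedged sketch with illustrative names, assuming a TensorFlow install:

```python
from tensorflow.python.util import tf_decorator

def greet(name):
  """Says hello."""
  return 'hello ' + name

def noisy_greet(*args, **kwargs):
  print('calling greet')
  return greet(*args, **kwargs)

decorated = tf_decorator.make_decorator(greet, noisy_greet, 'noisy')

print(decorated('ada'))                        # runs noisy_greet
print(decorated.__name__)                      # expected: 'greet'
print(decorated.__wrapped__ is greet)          # expected: True
print(decorated._tf_decorator.decorator_name)  # expected: 'noisy'
```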
Example #35
def recreate_function(saved_function, concrete_functions):
  """Creates a `Function` from a `SavedFunction`.

  Args:
    saved_function: `SavedFunction` proto.
    concrete_functions: map from function name to `ConcreteFunction`.

  Returns:
    A `Function`.
  """
  # TODO(andresp): Construct a `Function` with the cache populated
  # instead of creating a new `Function` backed by a Python layer to
  # glue things together. Current approach is nesting functions deeper for each
  # serialization cycle.

  coder = nested_structure_coder.StructureCoder()

  # Note: handling method functions is tricky since make_decorator does not
  # allow control of "ismethod". Additionally, since restored functions do
  # not behave as methods, i.e. they always use the same captured tensors
  # independent of the object they are bound to, there is little value in
  # propagating that correctly.
  #
  # Ideally this conversion should happen at serialization time. But since
  # there are SavedModels which have "ismethod" populated and have an extra
  # argument that they expect to be ignored, we do it at deserialization.
  function_spec = _deserialize_function_spec_as_nonmethod(
      saved_function.function_spec,
      coder)

  def restored_function_body(*args, **kwargs):
    """Calls a restored function."""
    # This is the format of function.graph.structured_input_signature. At this
    # point, the args and kwargs have already been canonicalized.
    inputs = (args, kwargs)

    # First try to find a concrete function that can be called without input
    # conversions. This allows one to pick a more specific trace in case there
    # was also a more expensive one that supported tensors.
    for allow_conversion in [False, True]:
      for function_name in saved_function.concrete_functions:
        function = concrete_functions[function_name]
        if _concrete_function_callable_with(function, inputs, allow_conversion):
          return _call_concrete_function(function, inputs)

    signature_descriptions = []

    def _pretty_format_positional(positional):
      return "Positional arguments ({} total):\n    * {}".format(
          len(positional),
          "\n    * ".join([str(a) for a in positional]))

    for index, function_name in enumerate(saved_function.concrete_functions):
      concrete_function = concrete_functions[function_name]
      positional, keyword = concrete_function.structured_input_signature
      signature_descriptions.append(
          "Option {}:\n  {}\n  Keyword arguments: {}"
          .format(index + 1, _pretty_format_positional(positional), keyword))
    raise ValueError(
        "Could not find matching function to call loaded from the SavedModel. "
        "Got:\n  {}\n  Keyword arguments: {}\n\nExpected "
        "these arguments to match one of the following {} option(s):\n\n{}"
        .format(_pretty_format_positional(args), kwargs,
                len(saved_function.concrete_functions),
                "\n\n".join(signature_descriptions)))

  concrete_function_objects = []
  for concrete_function_name in saved_function.concrete_functions:
    concrete_function_objects.append(concrete_functions[concrete_function_name])

  restored_function = RestoredFunction(
      restored_function_body,
      restored_function_body.__name__,
      function_spec,
      concrete_function_objects)

  return tf_decorator.make_decorator(
      restored_function_body,
      restored_function,
      decorator_argspec=function_spec.fullargspec)
Example #36
 def wrapper(wrapper_func):
     return tf_decorator.make_decorator(wrapped_function, wrapper_func)
Example #37
def custom_gradient(f=None):
    """Decorator to define a function with a custom gradient.

  This decorator allows fine grained control over the gradients of a sequence
  for operations.  This may be useful for multiple reasons, including providing
  a more efficient or numerically stable gradient for a sequence of operations.

  For example, consider the following function that commonly occurs in the
  computation of cross entropy and log likelihoods:

  ```python
  def log1pexp(x):
    return tf.math.log(1 + tf.exp(x))
  ```

  Due to numerical instability, the gradient of this function evaluated at x=100
  is NaN.  For example:

  ```python
  x = tf.constant(100.)
  y = log1pexp(x)
  dy_dx = tf.gradients(y, x) # Will be NaN when evaluated.
  ```

  The gradient expression can be analytically simplified to provide numerical
  stability:

  ```python
  @tf.custom_gradient
  def log1pexp(x):
    e = tf.exp(x)
    def grad(upstream):
      return upstream * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad
  ```

  With this definition, the gradient `dy_dx` at `x = 100` will be correctly
  evaluated as 1.0.

  The variable `upstream` is defined as the upstream gradient. i.e. the gradient
  from all the layers or functions originating from this layer. The above
  example has no upstream functions, therefore `upstream = dy/dy = 1.0`.

  Assume that `x_i` is `log1pexp` in the forward pass `x_1 = x_1(x_0)`,
  `x_2 = x_2(x_1)`, ..., `x_i = x_i(x_i-1)`, ..., `x_n = x_n(x_n-1)`. By
  chain rule we know that `dx_n/dx_0 = dx_n/dx_n-1 * dx_n-1/dx_n-2 * ... *
  dx_i/dx_i-1 * ... * dx_1/dx_0`.

  In this case the gradient of our current function defined as
  `dx_i/dx_i-1 = (1 - 1 / (1 + e))`. The upstream gradient `upstream` would be
  `dx_n/dx_n-1 * dx_n-1/dx_n-2 * ... * dx_i+1/dx_i`. The upstream gradient
  multiplied by the current gradient is then passed downstream.

  In case the function takes multiple variables as input, the `grad`
  function must also return  the same number of variables.
  We take the function `z = x * y` as an example.

  >>> @tf.custom_gradient
  ... def bar(x, y):
  ...   def grad(upstream):
  ...     dz_dx = y
  ...     dz_dy = x
  ...     return upstream * dz_dx, upstream * dz_dy
  ...   z = x * y
  ...   return z, grad
  >>> x = tf.constant(2.0, dtype=tf.float32)
  >>> y = tf.constant(3.0, dtype=tf.float32)
  >>> with tf.GradientTape(persistent=True) as tape:
  ...   tape.watch(x)
  ...   tape.watch(y)
  ...   z = bar(x, y)
  >>> z
  <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
  >>> tape.gradient(z, x)
  <tf.Tensor: shape=(), dtype=float32, numpy=3.0>
  >>> tape.gradient(z, y)
  <tf.Tensor: shape=(), dtype=float32, numpy=2.0>

  Nesting custom gradients can lead to unintuitive results. The default
  behavior does not correspond to n-th order derivatives. For example

  ```python
  @tf.custom_gradient
  def op(x):
    y = op1(x)
    @tf.custom_gradient
    def grad_fn(dy):
      gdy = op2(x, y, dy)
      def grad_grad_fn(ddy):  # Not the 2nd order gradient of op w.r.t. x.
        return op3(x, y, dy, ddy)
      return gdy, grad_grad_fn
    return y, grad_fn
  ```

  The function `grad_grad_fn` will be calculating the first order gradient
  of `grad_fn` with respect to `dy`, which is used to generate forward-mode
  gradient graphs from backward-mode gradient graphs, but is not the same as
  the second order gradient of `op` with respect to `x`.

  Instead, wrap nested `@tf.custom_gradients` in another function:

  ```python
  @tf.custom_gradient
  def op_with_fused_backprop(x):
    y, x_grad = fused_op(x)
    def first_order_gradient(dy):
      @tf.custom_gradient
      def first_order_custom(unused_x):
        def second_order_and_transpose(ddy):
          return second_order_for_x(...), gradient_wrt_dy(...)
        return x_grad, second_order_and_transpose
      return dy * first_order_custom(x)
    return y, first_order_gradient
  ```

  Additional arguments to the inner `@tf.custom_gradient`-decorated function
  control the expected return values of the innermost function.

  The examples above illustrate how to specify custom gradients for functions
  which do not read from variables. The following example uses variables, which
  require special handling because they are effectively inputs of the forward
  function.

  >>> weights = tf.Variable(tf.ones([2]))  # Trainable variable weights
  >>> @tf.custom_gradient
  ... def linear_poly(x):
  ...   # Creating polynomial
  ...   poly = weights[1] * x + weights[0]
  ...
  ...   def grad_fn(dpoly, variables):
  ...     # dy/dx = weights[1] and we need to left multiply dpoly
  ...     grad_xs = dpoly * weights[1]  # Scalar gradient
  ...
  ...     grad_vars = []  # To store gradients of passed variables
  ...     assert variables is not None
  ...     assert len(variables) == 1
  ...     assert variables[0] is weights
  ...     # Manually computing dy/dweights
  ...     dy_dw = dpoly * tf.stack([x ** 1, x ** 0])
  ...     grad_vars.append(
  ...         tf.reduce_sum(tf.reshape(dy_dw, [2, -1]), axis=1)
  ...     )
  ...     return grad_xs, grad_vars
  ...   return poly, grad_fn
  >>> x = tf.constant([1., 2., 3.])
  >>> with tf.GradientTape(persistent=True) as tape:
  ...   tape.watch(x)
  ...   poly = linear_poly(x)
  >>> poly # poly = x + 1
  <tf.Tensor: shape=(3,),
    dtype=float32,
    numpy=array([2., 3., 4.], dtype=float32)>
  >>> tape.gradient(poly, x)  # conventional scalar gradient dy/dx
  <tf.Tensor: shape=(3,),
    dtype=float32,
    numpy=array([1., 1., 1.], dtype=float32)>
  >>> tape.gradient(poly, weights)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 3.], dtype=float32)>

  The above example illustrates usage of the trainable variable `weights`.
  In the example, the inner `grad_fn` accepts an extra `variables` input
  parameter and also returns an extra `grad_vars` output. That extra argument
  is passed if the forward function reads any variables. You need to
  compute the gradient w.r.t. each of those `variables` and output it as a list
  of `grad_vars`. Note here that default value of `variables` is set to `None`
  when no variables are used in the forward function.

  It should be noted `tf.GradientTape` is still watching the forward pass of a
  `tf.custom_gradient`, and will use the ops it watches. As a consequence,
  calling `tf.function` while the tape is still watching leads
  to a gradient graph being built. If an op is used in `tf.function` without
  a registered gradient, a `LookupError` will be raised.

  Users can insert `tf.stop_gradient` to customize this behavior. This
  is demonstrated in the example below. `tf.random.shuffle` does not have a
  registered gradient. As a result `tf.stop_gradient` is used to avoid the
  `LookupError`.

  ```python
  x = tf.constant([0.3, 0.5], dtype=tf.float32)

  @tf.custom_gradient
  def test_func_with_stop_grad(x):
    @tf.function
    def _inner_func():
      # Avoid exception during the forward pass
      return tf.stop_gradient(tf.random.shuffle(x))
      # return tf.random.shuffle(x)  # This will raise

    res = _inner_func()
    def grad(upstream):
      return upstream  # Arbitrarily defined custom gradient
    return res, grad

  with tf.GradientTape() as g:
    g.watch(x)
    res = test_func_with_stop_grad(x)

  g.gradient(res, x)
  ```

  See also `tf.RegisterGradient` which registers a gradient function for a
  primitive TensorFlow operation. `tf.custom_gradient` on the other hand allows
  for fine grained control over the gradient computation of a sequence of
  operations.

  Note that if the decorated function uses `Variable`s, the enclosing variable
  scope must be using `ResourceVariable`s.

  Args:
    f: function `f(*x)` that returns a tuple `(y, grad_fn)` where:
       - `x` is a sequence of (nested structures of) `Tensor` inputs to the
         function.
       - `y` is a (nested structure of) `Tensor` outputs of applying TensorFlow
         operations in `f` to `x`.
       - `grad_fn` is a function with the signature `g(*grad_ys)` which returns
         a list of `Tensor`s the same size as (flattened) `x` - the derivatives
         of `Tensor`s in `y` with respect to the `Tensor`s in `x`.  `grad_ys` is
         a sequence of `Tensor`s the same size as (flattened) `y` holding the
         initial value gradients for each `Tensor` in `y`.

         In a pure mathematical sense, a vector-argument vector-valued function
         `f`'s derivatives should be its Jacobian matrix `J`. Here we are
         expressing the Jacobian `J` as a function `grad_fn` which defines how
         `J` will transform a vector `grad_ys` when left-multiplied with it
         (`grad_ys * J`, the vector-Jacobian product, or VJP). This functional
         representation of a matrix is convenient to use for chain-rule
         calculation (in e.g. the back-propagation algorithm).

         If `f` uses `Variable`s (that are not part of the
         inputs), i.e. through `get_variable`, then `grad_fn` should have
         signature `g(*grad_ys, variables=None)`, where `variables` is a list of
         the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where
         `grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>`
         with the derivatives of `Tensor`s in `y` with respect to the variables
         (that is, grad_vars has one Tensor per variable in variables).

  Returns:
    A function `h(x)` which returns the same value as `f(x)[0]` and whose
    gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`.
  """

    if f is None:
        return lambda f: custom_gradient(f=f)

    @Bind.decorator
    def decorated(wrapped, args, kwargs):
        """Decorated function with custom gradient."""
        if context.executing_eagerly():
            return _eager_mode_decorator(wrapped, args, kwargs)
        else:
            return _graph_mode_decorator(wrapped, args, kwargs)

    return tf_decorator.make_decorator(f, decorated(f))  # pylint: disable=no-value-for-parameter
Example #38
def use_wrapped_call(layer,
                     call_fn,
                     default_training_value=None,
                     return_method=False):
    """Creates fn that adds the losses returned by call_fn & returns the outputs.

  Args:
    layer: A Keras layer object
    call_fn: tf.function that takes layer inputs (and possibly a training arg),
      and returns a tuple of (outputs, list of losses).
    default_training_value: Default value of the training kwarg. If `None`, the
      default is `K.learning_phase()`.
    return_method: Whether to return a method bound to the layer.

  Returns:
    function that calls call_fn and returns the outputs. Losses returned by
    call_fn are added to the layer losses.
  """
    expects_training_arg = layer_uses_training_bool(layer)
    if hasattr(call_fn, 'original_call'):  # call_fn is a LayerCall object
        original_call = call_fn.original_call
        # In Python 3, callable objects are not compatible with inspect.getargspec
        call_fn = call_fn.__call__
    else:
        original_call = call_fn
    fn, arg_spec = maybe_add_training_arg(original_call, call_fn,
                                          expects_training_arg,
                                          default_training_value)

    def return_outputs_and_add_losses(*args, **kwargs):
        """Returns the outputs from the call_fn, and adds the losses."""
        inputs_arg_index = 1 if return_method else 0
        inputs = args[inputs_arg_index]
        args = args[inputs_arg_index + 1:]
        outputs, losses = fn(inputs, *args, **kwargs)
        layer.add_loss(losses, inputs=inputs)

        # TODO(kathywu): This is a temporary hack. When a network of layers is
        # revived from SavedModel, only the top-level layer will have losses. This
        # causes issues in eager mode because the child layers may have graph losses
        # (thus model.losses returns a mix of Eager and graph tensors). To fix this,
        # whenever eager losses are added to one layer, add eager losses to all
        # child layers. This causes `.losses` to only return eager losses.
        # pylint: disable=protected-access
        if context.executing_eagerly():
            for i in layer._flatten_layers():
                if i is not layer:
                    i._eager_losses = [
                        base_layer_utils.REVIVED_LOSS_PLACEHOLDER
                    ]
        # pylint: enable=protected-access
        return outputs

    decorated = tf_decorator.make_decorator(
        target=call_fn,
        decorator_func=return_outputs_and_add_losses,
        decorator_argspec=arg_spec)

    if return_method:
        return types.MethodType(decorated, layer)
    else:
        return decorated
Example #39
def make_elementwise_op(op, *elementwise_args):
    """Returns a ragged-tensor version of the elementwise operation `op`.

  The returned operation will:

  1. Broadcast the elementwise arguments to have a compatible shape.
     An exception is raised if the tensors are not broadcast-compatible.
  2. Call `op`, substituting the dense values of the broadcasted tensor for
     each elementwise argument.
  3. Return a potentially ragged tensor constructed from the output of `op`
     and the broadcasted tensors' nested row splits.

  For example, you can construct a ragged-tensor version of the standard
  operation `tf.add` by calling `make_elementwise_op(tf.add, 'x', 'y')`.

  Args:
    op: The operation to wrap.
    *elementwise_args: The names of arguments to `op` that are treated as
      elementwise.  Arguments that take a list of tensors should have their
      names wrapped in square brackets (e.g. "[inputs]").

  Raises:
    ValueError: If any name specified in `elementwise_args` is not the name
      of an argument to `op`.
  """
    elementwise_arg_infos = _get_arg_infos(op, elementwise_args)

    def ragged_op(*args, **kwargs):
        """Ragged version of `op`."""
        args = list(args)

        # Collect all of the elementwise arguments, and put them in a single
        # dict whose values are the (potentially ragged) tensors that need to
        # be broadcast to a common shape.  The keys of this dict are tuples
        # (argkey, index), where argkey is an int for positional args or a string
        # for keyword args; and index is None for non-list args and the index of the
        # tensor for list args.
        elementwise_args = {}
        for (name, position, is_list) in elementwise_arg_infos.values():
            if position < len(args):
                if is_list:
                    args[position] = list(args[position])
                    for (index, arg) in enumerate(args[position]):
                        elementwise_args[position, index] = arg
                else:
                    elementwise_args[position, None] = args[position]
            elif name in kwargs:
                if is_list:
                    kwargs[name] = list(kwargs[name])
                    for (i, arg) in enumerate(kwargs[name]):
                        elementwise_args[name, i] = arg
                else:
                    elementwise_args[name, None] = kwargs[name]

        with ops.name_scope(None, op.__name__, elementwise_args.values()):
            # Convert all inputs to tensors or ragged tensors.
            for ((key, index), tensor) in elementwise_args.items():
                argname = elementwise_arg_infos[key].name
                converted = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
                    tensor, name=argname)
                elementwise_args[key, index] = converted

            # Broadcast tensors to have compatible shapes.
            broadcast_args, result_splits, broadcast_check_ops = \
                _broadcast_elementwise_args(elementwise_args)

            # Replace tensor arguments with their dense values.
            for ((key, index), tensor) in broadcast_args.items():
                if ragged_tensor.is_ragged(tensor):
                    if isinstance(key, int) and index is None:
                        args[key] = tensor.inner_values
                    elif isinstance(key, int) and index is not None:
                        args[key][index] = tensor.inner_values
                    elif isinstance(key, str) and index is None:
                        kwargs[key] = tensor.inner_values
                    else:
                        assert isinstance(key, str) and index is not None
                        kwargs[key][index] = tensor.inner_values

            # Call the elementwise op on the broadcasted dense values.
            with ops.control_dependencies(broadcast_check_ops):
                result_values = op(*args, **kwargs)

            # Restore any ragged dimensions that we stripped off, and return the
            # result.
            return ragged_factory_ops.from_nested_row_splits(
                result_values, result_splits)

    # Construct the docstring.
    op_name = tf_export.get_canonical_name_for_symbol(op)
    assert op_name is not None, op
    argnames = ', '.join('`%s`' % s.strip('[]') for s in elementwise_args)
    docstring = _ELEMENTWISE_DOCSTRING % dict(op_name=op_name,
                                              argnames=argnames)

    # Update name, docstring, signature, etc., for the wrapper, and return it.
    return tf_decorator.make_decorator(op, ragged_op, decorator_doc=docstring)
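# A minimal usage sketch of `make_elementwise_op` (added for illustration; the
# import names are assumptions about this module's surroundings): wrapping
# `math_ops.add` so the wrapper broadcasts and accepts ragged inputs for its
# `x` and `y` arguments.
#
#   ragged_add = make_elementwise_op(math_ops.add, 'x', 'y')
#   rt = ragged_factory_ops.constant([[1, 2], [3]])
#   ragged_add(rt, 10)  # expected: <RaggedTensor [[11, 12], [13]]>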
def custom_gradient(f):
    """Decorator to define a function with a custom gradient.

  The input function is expected to return the tuple
    (results, gradient_function).

  The output function will return results while possibly recording the
  gradient_function and inputs in the tape.

  Args:
    f: function to be decorated.

  Returns:
    decorated function.
  """
    def decorated(*args, **kwargs):
        """Decorated function with custom gradient."""
        if context.in_graph_mode():
            if kwargs:
                raise ValueError(
                    "custom_gradient in graph mode doesn't support keyword arguments."
                )
            name = "CustomGradient-%s" % tf_ops.uid()
            args = [tf_ops.convert_to_tensor(x) for x in args]
            result, grad_fn = f(*args)
            flat_result = nest.flatten(result)
            all_tensors = flat_result + args

            @tf_ops.RegisterGradient(name)
            def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
                gradients = nest.flatten(
                    grad_fn(*result_grads[:len(flat_result)]))
                # Need to return one value per input to the IdentityN, so pad the
                # gradients of the inputs of the custom_gradient function with the
                # gradients of the outputs as well.
                return ([None] * len(flat_result)) + gradients

            with tf_ops.get_default_graph().gradient_override_map(
                {"IdentityN": name}):
                all_tensors = array_ops.identity_n(all_tensors)
            return nest.pack_sequence_as(
                structure=result, flat_sequence=all_tensors[:len(flat_result)])

        input_tensors = [x for x in args if isinstance(x, tf_ops.Tensor)]

        with tape.stop_recording():
            result, grad_fn = f(*args, **kwargs)

        # TODO(apassos): naive uses of custom_gradient will not get the correct
        # second derivative this way if they capture any output tensors. Change the
        # signature of custom_gradient.
        def actual_grad_fn(*outputs):
            return nest.flatten(grad_fn(*outputs))

        flat_result = nest.flatten(result)
        tape.record_operation(f.__name__, flat_result, input_tensors, [],
                              actual_grad_fn)
        flat_result = list(flat_result)
        return result

    return tf_decorator.make_decorator(f, decorated)
Example #41
0
def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            autograph=False,
                            add_control_dependencies=True,
                            arg_names=None,
                            op_return_value=None,
                            collections=None):
    """Returns a `FuncGraph` generated from `python_func`.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should be called;
      ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should be called;
      ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
      and dtypes of the arguments. When a signature is provided, `args` and
      `kwargs` are ignored, and `python_func` is traced with Tensors conforming
      to `signature`. If `None`, the shapes and dtypes are inferred from the
      inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, we will use
      this graph; otherwise a new one is built and returned.
    autograph: whether to use autograph to compile `python_func`.
      See https://www.tensorflow.org/guide/autograph for more information.
    add_control_dependencies: If True, automatically adds control dependencies
      to ensure program order matches execution order and stateful ops always
      execute.
    arg_names: Optional list of argument names, used to give input placeholders
      recognizable names.
    op_return_value: Optional. A Tensor. If set and `python_func` returns
      Operations, those return values will be replaced with this value. If not
      set, returning an Operation triggers an error.
    collections: a dictionary of collections this FuncGraph should start
      with. If not specified (None), the FuncGraph will read (but not write to)
      the outer graph's collections that are not whitelisted, and both
      read and write to the outer graph's collections that are whitelisted.
      The current whitelisted collections are the global variables, the
      local variables, and the trainable variables.
      Defaults to None.

  Returns:
    A FuncGraph.

  Raises:
    TypeError: If any of `python_func`'s return values is neither `None` nor a
      `Tensor`.
  """
    if op_return_value is not None:
        assert isinstance(op_return_value, ops.Tensor), op_return_value
    if func_graph is None:
        func_graph = FuncGraph(name, collections=collections)
    assert isinstance(func_graph, FuncGraph)
    if add_control_dependencies:
        control_manager = AutomaticControlDependencies
    else:
        control_manager = ops.NullContextmanager
    with func_graph.as_default(), control_manager() as a:
        current_scope = variable_scope.get_variable_scope()
        default_use_resource = current_scope.use_resource
        current_scope.set_use_resource(True)

        if signature is not None:
            args = signature
            kwargs = {}

        # Creates and names placeholders for all arguments.
        func_args = _get_defun_inputs_from_args(args, arg_names)
        func_kwargs = _get_defun_inputs_from_kwargs(kwargs)

        # Convert all Tensors into TensorSpecs before saving the structured inputs.
        # If storing pure concrete functions that are not called through polymorphic
        # functions, we don't have access to FunctionSpec, so we need to call the
        # TensorSpecs by their `arg_names` for later binding.
        func_graph.structured_input_signature = (
            convert_structure_to_signature(func_args, arg_names),
            convert_structure_to_signature(func_kwargs))

        # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
        # Variables to help check whether mutation happens in calling the function
        # Copy the recursive list, tuple and map structure, but not base objects
        func_args_before = nest.pack_sequence_as(func_args,
                                                 nest.flatten(func_args))
        func_kwargs_before = nest.pack_sequence_as(func_kwargs,
                                                   nest.flatten(func_kwargs))

        def convert(x):
            """Converts a function output to a Tensor."""
            if x is None:
                return None
            if op_return_value is not None and isinstance(x, ops.Operation):
                # TODO(b/79881896): we currently can't capture external control deps, so
                # this won't work if x needs to be captured (i.e. if python_func returns
                # captured Operations).
                with ops.control_dependencies([x]):
                    x = array_ops.identity(op_return_value)
            elif not isinstance(x, tensor_array_ops.TensorArray):
                try:
                    x = ops.convert_to_tensor_or_composite(x)
                except (ValueError, TypeError):
                    raise TypeError(
                        "To be compatible with tf.contrib.eager.defun, Python functions "
                        "must return zero or more Tensors; in compilation of %s, found "
                        "return value of type %s, which is not a Tensor." %
                        (str(python_func), type(x)))
            if add_control_dependencies:
                x = a.mark_as_return(x)
            return x

        this_tape = tape.push_new_tape()
        try:
            if autograph:
                from tensorflow.python import autograph  # pylint: disable=g-import-not-at-top
                _, original_func = tf_decorator.unwrap(python_func)

                def wrapper(*args, **kwargs):
                    # Note: functions annotated with @tf.function should always be
                    # converted even though they would meet autograph's whitelisting
                    # criteria.
                    # If this assumption is ever broken, converted_call will need to
                    # handle the possibility of original_func still being a shim, e.g.
                    # bound to WeakrefSelf.
                    return autograph.converted_call(
                        original_func, None,
                        autograph.ConversionOptions(
                            verbose=autograph.Verbosity.BRIEF,
                            recursive=True,
                            strip_decorators=(def_function.function, ),
                            optional_features=(),
                            force_conversion=True,
                        ), *args, **kwargs)

                # Wrapping around a decorator allows checks like tf_inspect.getargspec
                # to be accurate.
                converted_func = tf_decorator.make_decorator(
                    original_func, wrapper)
                tf_decorator.rewrap(python_func, original_func, converted_func)

            func_outputs = python_func(*func_args, **func_kwargs)

            # invariant: `func_outputs` contains only Tensors, IndexedSlices,
            # SparseTensors, TensorArrays and `None`s.
            func_outputs = nest.map_structure(convert, func_outputs)

            check_mutation(func_args_before, func_args)
            check_mutation(func_kwargs_before, func_kwargs)
        finally:
            tape.pop_tape(this_tape)
            current_scope.set_use_resource(default_use_resource)

        # Variables in `func_args`, `func_kwargs` should be explicit inputs
        # to the function, not captured inputs.
        tape_variables = this_tape.watched_variables()
        arg_variables = set()
        inputs = []
        for arg in nest.flatten(func_args) + nest.flatten(func_kwargs):
            if isinstance(arg, resource_variable_ops.ResourceVariable):
                # Even if an argument variable was not used in the function, we've
                # already manually captured the resource Tensor when creating argument
                # placeholders.
                resource_placeholder = func_graph.captures.pop(arg.handle)
                arg_variables.add(arg)
                inputs.append(resource_placeholder)
            elif isinstance(arg, ops.Tensor):
                inputs.append(arg)
        variables = [v for v in tape_variables if v not in arg_variables]
        func_graph.inputs = inputs + list(func_graph.captures.values())

        func_graph.structured_outputs = func_outputs
        # Returning a closed-over tensor does not trigger convert_to_tensor.
        func_graph.outputs.extend(
            func_graph.capture(x)
            for x in flatten(func_graph.structured_outputs) if x is not None)

        func_graph.variables = variables

    # Register any other functions defined in the graph.
    with ops.init_scope():
        if context.executing_eagerly():
            for f in func_graph._functions.values():  # pylint: disable=protected-access
                # TODO(ashankar): What about the gradient registry?
                context.add_function(f._c_func.func)  # pylint: disable=protected-access

    return func_graph
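# Hedged usage sketch (added for illustration; `func_graph_from_py_func` is an
# internal helper, so the exact call pattern is an assumption): tracing a
# simple Python function into a FuncGraph.
#
#   from tensorflow.python.framework import constant_op
#
#   def double(x):
#       return x * 2.0
#
#   graph = func_graph_from_py_func(
#       "double", double, args=(constant_op.constant(1.0),), kwargs={})
#   # graph.inputs and graph.outputs now describe the traced computation.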
Example #42
0
def result_wrapper(result_fn):
    """Decorator to wrap metric `result()` function in `merge_call()`.

  Result computation is an idempotent operation that simply calculates the
  metric value using the state variables.

  If metric state variables are distributed across replicas/devices and
  `result()` is requested from the context of one device, this function wraps
  `result()` in a distribution strategy `merge_call()`. With this, the metric
  state variables will be aggregated across devices.

  Args:
    result_fn: function that computes the metric result.

  Returns:
    Decorated function that wraps `result_fn()` in distribution strategy
    `merge_call()`.
  """
    def decorated(metric_obj, *args):
        """Decorated function with merge_call."""
        has_strategy = distribution_strategy_context.has_strategy()
        replica_context = distribution_strategy_context.get_replica_context()
        if not has_strategy or replica_context is None:
            raw_result = result_fn(*args)
            # Results need to be wrapped in a `tf.identity` op to ensure
            # correct execution order.
            if isinstance(raw_result,
                          (ops.Tensor, variables_module.Variable, float, int)):
                result_t = array_ops.identity(raw_result)
            elif isinstance(raw_result, dict):
                result_t = {
                    key: array_ops.identity(value)
                    for key, value in raw_result.items()
                }
            else:
                try:
                    result_t = array_ops.identity(raw_result)
                except (ValueError, TypeError):
                    raise RuntimeError(
                        'The output of `metric.result()` can only be a single '
                        'Tensor/Variable, or a dict of Tensors/Variables. '
                        'For metric %s, got result %s.' %
                        (metric_obj.name, raw_result))
        else:
            # TODO(psv): Test distribution of metrics using different distribution
            # strategies.

            # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn
            # with distribution object as the first parameter. We create a wrapper
            # here so that the result function need not have that parameter.
            def merge_fn_wrapper(distribution, merge_fn, *args):
                # We will get a `PerReplica` merge function. Take the first one, as all
                # are identical copies of the function that we passed below.
                result = distribution.experimental_local_results(merge_fn)[0](
                    *args)

                # Wrapping result in identity so that control dependency between
                # update_op from `update_state` and result works in case result returns
                # a tensor.
                return array_ops.identity(result)

            # Wrapping result in merge_call. merge_call is used when we want to leave
            # replica mode and compute a value in cross replica mode.
            result_t = replica_context.merge_call(merge_fn_wrapper,
                                                  args=(result_fn, ) + args)

        # We save the result op here so it can be used in the train/test
        # execution functions. This gives those workflows a result op that
        # carries a control dependency on the state updates.
        metric_obj._call_result = result_t
        return result_t

    return tf_decorator.make_decorator(result_fn, decorated)
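# Hedged sketch (added for illustration) of how a metric base class might
# install this wrapper on instances, mirroring the pattern Keras uses when
# constructing `Metric` objects; the attribute names are illustrative.
#
#   obj.update_state = types.MethodType(
#       update_state_wrapper(obj.update_state), obj)
#   obj.result = types.MethodType(result_wrapper(obj.result), obj)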
Example #43
0
def _multi_worker_test(test_method):
    """Decorate test_method so that it runs in each worker.

  We use `multi_process_runner` to simulate multiple workers. Since we run
  this function in the main process and all worker processes, this decoration
  behaves differently in the main process and worker processes. In the main
  process, it spawns subprocesses and runs the test on each of them; in a worker
  process, it executes the test in the same way as a normal test, e.g.
  setUp()/tearDown() are called before/after the test.

  Args:
    test_method: a function which must be a test method.

  Returns:
    Decorated `test_method`. Note that the decorated function has additional
    arguments.
  """
    def decorator(self, has_chief, num_workers, runner, **kwargs):
        if _num_total_workers(has_chief,
                              num_workers) == 1 or _running_in_worker:
            # We're in a worker process or the test is for a single worker. In
            # either case we execute the test method directly instead of
            # spawning subprocesses.

            # For MultiWorkerMirroredStrategy(CollectiveAllReduceStrategy), install a
            # session that connects to the local server. This is necessary for multi
            # worker graph mode tests to work. Those tests cannot use their graphs or
            # sessions, including the one returned by self.cached_session(). Since
            # existing tests may already be doing so, we only install the session for
            # multi worker tests.
            with _multi_worker_session(kwargs):
                test_method(self, **kwargs)
            return

        # We're in the main process. We spawn subprocesses and run the *test* on
        # each of them. Note that we're not directly executing test_method passed to
        # _multi_worker_test, because we need setUp()/tearDown() to be called and
        # all the decorations on the test method. The conceptual call stack is:
        #   [main process]test.main()
        #     [main process]test_runner.run(test)
        #       [main process]wrapper by combinations.generate()
        #         [main process]_multi_worker_test.decorator()
        #           # A sub process goes through the same code path as the main
        #           # process.
        #           [sub process]_test_runner()
        #             [sub process]test_runner.run(test)
        #               [sub process]wrapper by combinations.generate()
        #                 [sub process]_multi_worker_test.decorator()
        #                   # _running_in_worker is True
        #                   [sub process]test_method()
        test_id = self.id()
        if runner:
            results = runner.run(_test_runner, args=(test_id, _env))
        else:
            cluster_spec = multi_worker_test_base.create_cluster_spec(
                has_chief=has_chief,
                num_workers=num_workers,
                num_ps=0,
                has_eval=False)
            results = multi_process_runner.run(_test_runner,
                                               cluster_spec,
                                               args=(test_id,
                                                     _env)).return_value

        skip_reason = None
        for result in results:
            if result.status == "failure":
                # We can't tell which worker the return value comes from, so we
                # fail on the first error.
                self.fail(result.message)
                break
            elif result.status == "skipped":
                # Record the skip reason, but do not actually skip the test in case some
                # processes fail instead.
                skip_reason = result.message
        if skip_reason is not None:
            self.skipTest(skip_reason)

    argspec = tf_inspect.getfullargspec(test_method)
    decorator_args = (argspec.args
                      or []) + ["has_chief", "num_workers", "runner"]
    decorator_argspec = argspec._replace(args=decorator_args)
    return tf_decorator.make_decorator(test_method,
                                       decorator,
                                       decorator_argspec=decorator_argspec)
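# Hedged usage sketch (hypothetical): the decorated test gains `has_chief`,
# `num_workers` and `runner` parameters, which a combinations-style generator
# is expected to supply. The combination helpers named below are assumptions,
# not taken from this example.
#
#   class MyDistributedTest(multi_worker_test_base.MultiWorkerTestBase):
#
#       @combinations.generate(combinations.combine(
#           has_chief=[False], num_workers=[2], runner=[None]))
#       @_multi_worker_test
#       def test_all_reduce(self, **kwargs):
#           ...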
Example #44
0
def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            autograph=False,
                            autograph_options=None,
                            add_control_dependencies=True,
                            arg_names=None,
                            op_return_value=None,
                            collections=None,
                            capture_by_value=None,
                            override_flat_arg_shapes=None):
    """Returns a `FuncGraph` generated from `python_func`.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should be called;
      ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should be called;
      ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
      and dtypes of the arguments. When a signature is provided, `args` and
      `kwargs` are ignored, and `python_func` is traced with Tensors conforming
      to `signature`. If `None`, the shapes and dtypes are inferred from the
      inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, we will use
      this graph; otherwise a new one is built and returned.
    autograph: whether to use autograph to compile `python_func`.
      See https://www.tensorflow.org/guide/autograph for more information.
    autograph_options: additional knobs to control when `autograph=True`.
      See https://www.tensorflow.org/guide/autograph for more information.
    add_control_dependencies: If True, automatically adds control dependencies
      to ensure program order matches execution order and stateful ops always
      execute.
    arg_names: Optional list of argument names, used to give input placeholders
      recognizable names.
    op_return_value: Optional. A Tensor. If set and `python_func` returns
      Operations, those return values will be replaced with this value. If not
      set, returning an Operation triggers an error.
    collections: a dictionary of collections this FuncGraph should start
      with. If not specified (None), the FuncGraph will read (but not write to)
      the outer graph's collections that are not whitelisted, and both
      read and write to the outer graph's collections that are whitelisted.
      The current whitelisted collections are the global variables, the
      local variables, and the trainable variables.
      Defaults to None.
    capture_by_value: An optional boolean. If True, the func graph will capture
      Variables by value instead of reference. By default inherit from outer
      graphs, and failing that will default to False.
    override_flat_arg_shapes: An optional list of instances that are either
      `None` or `TensorShape`.  The length must match that of
      `nest.flatten((args, kwargs), expand_composites=True)`.  The entries
      containing value `None` must match entries in flattened arguments
      containing non-tensors, while entries containing a `TensorShape` must
      match entries in the flattened arguments containing tensors.

  Returns:
    A FuncGraph.

  Raises:
    TypeError: If any of `python_func`'s return values is neither `None` nor a
      `Tensor`.
    ValueError: If both `signature` and `override_flat_arg_shapes` are
      passed in.
  """
    if op_return_value is not None:
        assert isinstance(op_return_value, ops.Tensor), op_return_value
    if func_graph is None:
        func_graph = FuncGraph(name,
                               collections=collections,
                               capture_by_value=capture_by_value)
    assert isinstance(func_graph, FuncGraph)
    if add_control_dependencies:
        control_manager = AutomaticControlDependencies()
    else:
        control_manager = ops.NullContextmanager()
    with func_graph.as_default(), control_manager as a:
        current_scope = variable_scope.get_variable_scope()
        default_use_resource = current_scope.use_resource
        current_scope.set_use_resource(True)

        if signature is not None and override_flat_arg_shapes is not None:
            raise ValueError(
                "Passed both signature and override_flat_arg_shapes: %s and %s."
                % (signature, override_flat_arg_shapes))

        if signature is not None:
            args = signature
            kwargs = {}

        # Creates and names placeholders for all arguments.
        if override_flat_arg_shapes is not None:
            flat_args = nest.flatten(args, expand_composites=True)
            arg_shapes = override_flat_arg_shapes[:len(flat_args)]
            kwarg_shapes = override_flat_arg_shapes[len(flat_args):]
        else:
            arg_shapes = None
            kwarg_shapes = None
        func_args = _get_defun_inputs_from_args(args,
                                                arg_names,
                                                flat_shapes=arg_shapes)
        func_kwargs = _get_defun_inputs_from_kwargs(kwargs,
                                                    flat_shapes=kwarg_shapes)

        # Convert all Tensors into TensorSpecs before saving the structured inputs.
        # If storing pure concrete functions that are not called through polymorphic
        # functions, we don't have access to FunctionSpec, so we need to call the
        # TensorSpecs by their `arg_names` for later binding.
        func_graph.structured_input_signature = (
            convert_structure_to_signature(func_args, arg_names),
            convert_structure_to_signature(func_kwargs))

        flat_func_args = nest.flatten(func_args, expand_composites=True)
        flat_func_kwargs = nest.flatten(func_kwargs, expand_composites=True)
        # Temporarily set inputs to allow graph building code to inspect
        # them. Reassigned below.
        func_graph.inputs = [
            arg for arg in flat_func_args + flat_func_kwargs
            if isinstance(arg, ops.Tensor)
        ]

        # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
        # Variables to help check whether mutation happens in calling the function
        # Copy the recursive list, tuple and map structure, but not base objects
        func_args_before = nest.pack_sequence_as(func_args,
                                                 flat_func_args,
                                                 expand_composites=True)
        func_kwargs_before = nest.pack_sequence_as(func_kwargs,
                                                   flat_func_kwargs,
                                                   expand_composites=True)

        def convert(x):
            """Converts a function output to a Tensor."""
            if x is None:
                return None
            if op_return_value is not None and isinstance(x, ops.Operation):
                # TODO(b/79881896): we currently can't capture external control deps, so
                # this won't work if x needs to be captured (i.e. if python_func returns
                # captured Operations).
                with ops.control_dependencies([x]):
                    x = array_ops.identity(op_return_value)
            elif not isinstance(x, tensor_array_ops.TensorArray):
                try:
                    x = ops.convert_to_tensor_or_composite(x)
                except (ValueError, TypeError):
                    raise TypeError(
                        "To be compatible with tf.contrib.eager.defun, Python functions "
                        "must return zero or more Tensors; in compilation of %s, found "
                        "return value of type %s, which is not a Tensor." %
                        (str(python_func), type(x)))
            if add_control_dependencies:
                x = a.mark_as_return(x)
            return x

        try:
            if autograph:
                from tensorflow.python import autograph  # pylint: disable=g-import-not-at-top
                _, original_func = tf_decorator.unwrap(python_func)

                def wrapper(*args, **kwargs):
                    """Calls a converted version of original_func."""
                    # TODO(mdan): Push this block higher in tf.function's call stack.
                    try:
                        return autograph.converted_call(
                            original_func, None,
                            autograph.ConversionOptions(
                                recursive=True,
                                optional_features=autograph_options,
                                force_conversion=True,
                            ), args, kwargs)
                    except Exception as e:  # pylint:disable=broad-except
                        if hasattr(e, "ag_error_metadata"):
                            raise e.ag_error_metadata.to_exception(type(e))
                        else:
                            raise

                # Wrapping around a decorator allows checks like tf_inspect.getargspec
                # to be accurate.
                converted_func = tf_decorator.make_decorator(
                    original_func, wrapper)
                python_func = tf_decorator.rewrap(python_func, original_func,
                                                  converted_func)

            func_outputs = python_func(*func_args, **func_kwargs)

            # invariant: `func_outputs` contains only Tensors, CompositeTensors,
            # TensorArrays and `None`s.
            func_outputs = nest.map_structure(convert,
                                              func_outputs,
                                              expand_composites=True)

            check_mutation(func_args_before, func_args)
            check_mutation(func_kwargs_before, func_kwargs)
        finally:
            current_scope.set_use_resource(default_use_resource)

        # Variables in `func_args`, `func_kwargs` should be explicit inputs
        # to the function, not captured inputs.
        graph_variables = list(func_graph._watched_variables)  # pylint: disable=protected-access
        arg_variables = set()
        inputs = []
        for arg in (nest.flatten(func_args, expand_composites=True) +
                    nest.flatten(func_kwargs, expand_composites=True)):
            if isinstance(arg, resource_variable_ops.BaseResourceVariable):
                # Even if an argument variable was not used in the function, we've
                # already manually captured the resource Tensor when creating argument
                # placeholders.
                resource_placeholder = func_graph.captures.pop(
                    arg.handle, None)
                if resource_placeholder is None:
                    continue
                arg_variables.add(arg)
                inputs.append(resource_placeholder)
            elif isinstance(arg, ops.Tensor):
                inputs.append(arg)
        variables = [v for v in graph_variables if v not in arg_variables]
        func_graph.inputs = inputs + list(func_graph.captures.values()) + [
            x[1] for x in func_graph.deferred_captures.values()
        ]

        func_graph.structured_outputs = func_outputs
        # Returning a closed-over tensor does not trigger convert_to_tensor.
        func_graph.outputs.extend(
            func_graph.capture(x)
            for x in flatten(func_graph.structured_outputs) if x is not None)

        func_graph.variables = variables

    if add_control_dependencies:
        func_graph.control_outputs.extend(control_manager.ops_which_must_run)

    return func_graph
 def test_decorator_name(wrapper):
     return tf_decorator.make_decorator(test_function, wrapper)
def custom_gradient(f=None):
  """Decorator to define a function with a custom gradient.

  This decorator allows fine grained control over the gradients of a sequence
  of operations.  This may be useful for multiple reasons, including providing
  a more efficient or numerically stable gradient for a sequence of operations.

  For example, consider the following function that commonly occurs in the
  computation of cross entropy and log likelihoods:

  ```python
  def log1pexp(x):
    return tf.math.log(1 + tf.exp(x))
  ```

  Due to numerical instability, the gradient of this function evaluated at x=100
  is NaN.  For example:

  ```python
  x = tf.constant(100.)
  y = log1pexp(x)
  dy = tf.gradients(y, x) # Will be NaN when evaluated.
  ```

  The gradient expression can be analytically simplified to provide numerical
  stability:

  ```python
  @tf.custom_gradient
  def log1pexp(x):
    e = tf.exp(x)
    def grad(dy):
      return dy * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad
  ```

  With this definition, the gradient at x=100 will be correctly evaluated as
  1.0.

  Nesting custom gradients can lead to unintuitive results. The default
  behavior does not correspond to n-th order derivatives. For example

  ```python
  @tf.custom_gradient
  def op(x):
    y = op1(x)
    @tf.custom_gradient
    def grad_fn(dy):
      gdy = op2(x, y, dy)
      def grad_grad_fn(ddy):  # Not the 2nd order gradient of op w.r.t. x.
        return op3(x, y, dy, ddy)
      return gdy, grad_grad_fn
    return y, grad_fn
  ```

  The function `grad_grad_fn` will be calculating the first order gradient
  of `grad_fn` with respect to `dy`, which is used to generate forward-mode
  gradient graphs from backward-mode gradient graphs, but is not the same as
  the second order gradient of `op` with respect to `x`.

  Instead, wrap nested `@tf.custom_gradients` in another function:

  ```python
  @tf.custom_gradient
  def op_with_fused_backprop(x):
    y, x_grad = fused_op(x)
    def first_order_gradient(dy):
      @tf.custom_gradient
      def first_order_custom(unused_x):
        def second_order_and_transpose(ddy):
          return second_order_for_x(...), gradient_wrt_dy(...)
        return x_grad, second_order_and_transpose
      return dy * first_order_custom(x)
    return y, first_order_gradient
  ```

  Additional arguments to the inner `@tf.custom_gradient`-decorated function
  control the expected return values of the innermost function.

  See also `tf.RegisterGradient` which registers a gradient function for a
  primitive TensorFlow operation. `tf.custom_gradient` on the other hand allows
  for fine grained control over the gradient computation of a sequence of
  operations.

  Note that if the decorated function uses `Variable`s, the enclosing variable
  scope must be using `ResourceVariable`s.

  Args:
    f: function `f(*x)` that returns a tuple `(y, grad_fn)` where:
       - `x` is a sequence of (nested structures of) `Tensor` inputs to the
         function.
       - `y` is a (nested structure of) `Tensor` outputs of applying TensorFlow
         operations in `f` to `x`.
       - `grad_fn` is a function with the signature `g(*grad_ys)` which returns
         a list of `Tensor`s the same size as (flattened) `x` - the derivatives
         of `Tensor`s in `y` with respect to the `Tensor`s in `x`.  `grad_ys` is
         a sequence of `Tensor`s the same size as (flattened) `y` holding the
         initial value gradients for each `Tensor` in `y`.

         In a pure mathematical sense, a vector-argument vector-valued function
         `f`'s derivatives should be its Jacobian matrix `J`. Here we are
         expressing the Jacobian `J` as a function `grad_fn` which defines how
         `J` will transform a vector `grad_ys` when left-multiplied with it
         (`grad_ys * J`, the vector-Jacobian product, or VJP). This functional
         representation of a matrix is convenient to use for chain-rule
         calculation (in e.g. the back-propagation algorithm).

         If `f` uses `Variable`s (that are not part of the
         inputs), i.e. through `get_variable`, then `grad_fn` should have
         signature `g(*grad_ys, variables=None)`, where `variables` is a list of
         the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where
         `grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>`
         with the derivatives of `Tensor`s in `y` with respect to the variables
         (that is, grad_vars has one Tensor per variable in variables).

  Returns:
    A function `h(x)` which returns the same value as `f(x)[0]` and whose
    gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`.
  """

  if f is None:
    return lambda f: custom_gradient(f=f)

  @Bind.decorator
  def decorated(wrapped, args, kwargs):
    """Decorated function with custom gradient."""
    # raise ValueError("PW: trap")

    if context.executing_eagerly():
      return _eager_mode_decorator(wrapped, args, kwargs)
    else:
      return _graph_mode_decorator(wrapped, args, kwargs)

  return tf_decorator.make_decorator(f, decorated(f))  # pylint: disable=no-value-for-parameter
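# Hedged sketch (added for illustration) of the `variables` keyword described
# in the docstring above: when the decorated function uses a captured
# `tf.Variable`, `grad_fn` receives the variables and must return a pair
# (grad_xs, grad_vars).
#
#   v = tf.Variable(2.0)
#
#   @tf.custom_gradient
#   def scale(x):
#       y = v * x
#       def grad(dy, variables=None):
#           grad_x = dy * v          # d(y)/d(x) = v
#           grad_vars = [dy * x]     # one gradient per variable in `variables`
#           return grad_x, grad_vars
#       return y, grad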
 def testUpdatesDictWithMissingEntries(self):
     test_function.foobar = True
     decorated = tf_decorator.make_decorator(test_function, test_wrapper)
     self.assertTrue(decorated.foobar)
     del test_function.foobar
Example #48
0
        def get_wrapper(func):
            def wrapper(*unused_args, **unused_kwargs):
                pass

            return tf_decorator.make_decorator(func, wrapper)
Example #49
0
 def decorated(function):
   return tf_decorator.make_decorator(function, _ModelFnWrapper(function))
Example #50
0
def _multi_worker_test(test_method):
    """Decorate test_method so that it runs in each worker.

  We use `multi_process_runner` to simulate multiple workers. Since we run
  this function in the main process and all worker processes, this decoration
  behaves differently in the main process and worker processes. In the main
  process, it spawns subprocesses and runs the test on each of them; in a worker
  process, it executes the test in the same way as a normal test, e.g.
  setUp()/tearDown() are called before/after the test.

  Args:
    test_method: a function which must be a test method.

  Returns:
    Decorated `test_method`. Note that the decorated function has additional
    arguments.
  """
    def decorator(self, has_chief, num_workers, **kwargs):
        if _num_total_workers(has_chief,
                              num_workers) == 1 or _running_in_worker:
            # We're in a worker process or the test is for a single worker. In
            # either case we execute the test method directly instead of
            # spawning subprocesses.
            test_method(self, **kwargs)
            return

        # We're in the main process. We spawn subprocesses and run the *test* on
        # each of them. Note that we're not directly executing test_method passed to
        # _multi_worker_test, because we need setUp()/tearDown() to be called and
        # all the decorations on the test method. The conceptual call stack is:
        #   [main process]test.main()
        #     [main process]test_runner.run(test)
        #       [main process]wrapper by combinations.generate()
        #         [main process]_multi_worker_test.decorator()
        #           # A sub process goes through the same code path as the main
        #           # process.
        #           [sub process]_test_runner()
        #             [sub process]test_runner.run(test)
        #               [sub process]wrapper by combinations.generate()
        #                 [sub process]_multi_worker_test.decorator()
        #                   # _running_in_worker is True
        #                   [sub process]test_method()
        test_id = self.id()
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=has_chief,
            num_workers=num_workers,
            num_ps=0,
            has_eval=False)
        result = multi_process_runner.run(_test_runner,
                                          cluster_spec,
                                          args=(test_id, ))
        for was_successful in result.return_value:
            if not was_successful:
                raise AssertionError(
                    "some worker failed, see logs for details")

    argspec = tf_inspect.getfullargspec(test_method)
    decorator_args = (argspec.args or []) + ["has_chief", "num_workers"]
    decorator_argspec = argspec._replace(args=decorator_args)
    return tf_decorator.make_decorator(test_method,
                                       decorator,
                                       decorator_argspec=decorator_argspec)
 def testSetsTFDecoratorDocToDecoratorDocArg(self):
     decorated = tf_decorator.make_decorator(
         test_function, test_wrapper, decorator_doc='test decorator doc')
     decorator = getattr(decorated, '_tf_decorator')
     self.assertEqual('test decorator doc', decorator.decorator_doc)
Example #52
0
 def __get__(self, instance, owner):
     if instance is not None:
         f = self._f.__get__(instance, owner)
         return tf_decorator.make_decorator(f, Bind(f, self._d))
     else:
         return self
 def testAttachesWrappedAttr(self):
     decorated = tf_decorator.make_decorator(test_function, test_wrapper)
     wrapped_attr = getattr(decorated, '__wrapped__')
     self.assertIs(test_function, wrapped_attr)
def recreate_function(saved_function, concrete_functions):
    """Creates a `Function` from a `SavedFunction`.

  Args:
    saved_function: `SavedFunction` proto.
    concrete_functions: map from function name to `ConcreteFunction`.

  Returns:
    A `Function`.
  """
    # TODO(andresp): Construct a `Function` with the cache populated
    # instead of creating a new `Function` backed by a Python layer to
    # glue things together. The current approach nests functions one level
    # deeper with each serialization cycle.

    coder = nested_structure_coder.StructureCoder()
    function_spec = _deserialize_function_spec(saved_function.function_spec,
                                               coder)

    def restored_function_body(*args, **kwargs):
        """Calls a restored function."""
        # TODO(allenl): Functions saved with input_signatures should revive with
        # input_signatures.
        try:
            canonicalized_inputs = function_spec.canonicalize_function_inputs(
                *args, **kwargs)
        except ValueError as e:
            raise ValueError(
                "Cannot canonicalize input args %r and kwargs %r. Error: %r." %
                (args, kwargs, e))

        debug_considered_signatures = []
        for concrete_function_name in saved_function.concrete_functions:
            function_obj = concrete_functions[concrete_function_name]
            canonicalized_original_inputs = (
                function_obj.graph.structured_input_signature)
            debug_considered_signatures.append(canonicalized_original_inputs)

            if _inputs_compatible(canonicalized_inputs,
                                  canonicalized_original_inputs):
                flattened_inputs = nest.flatten(canonicalized_inputs)
                filtered_inputs = [
                    t for t in flattened_inputs if _is_tensor(t)
                ]

                result = function_obj._call_flat(filtered_inputs)  # pylint: disable=protected-access
                if isinstance(result, ops.Operation):
                    return None
                return result

        raise AssertionError(
            "Could not find matching function to call for canonicalized inputs %r. "
            "Only existing signatures are %r." %
            (canonicalized_inputs, debug_considered_signatures))

    concrete_function_objects = []
    for concrete_function_name in saved_function.concrete_functions:
        concrete_function_objects.append(
            concrete_functions[concrete_function_name])

    restored_function = RestoredFunction(restored_function_body,
                                         restored_function_body.__name__,
                                         function_spec,
                                         concrete_function_objects)

    return tf_decorator.make_decorator(
        restored_function_body,
        restored_function,
        decorator_argspec=function_spec.fullargspec)
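# Hedged note (added for illustration): this helper runs during SavedModel
# loading, so user code typically reaches `restored_function_body` indirectly,
# e.g.:
#
#   loaded = tf.saved_model.load("/tmp/my_model")   # path is illustrative
#   loaded.f(tf.constant(1.0))  # dispatches to a matching ConcreteFunction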
 def testAttachesATFDecoratorAttr(self):
     decorated = tf_decorator.make_decorator(test_function, test_wrapper)
     decorator = getattr(decorated, '_tf_decorator')
     self.assertIsInstance(decorator, tf_decorator.TFDecorator)
Example #56
0
def _right(operator):
    """Right-handed version of an operator: swap args x and y."""
    return tf_decorator.make_decorator(operator, lambda y, x: operator(x, y))
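# Hedged usage sketch (added for illustration): `_right` builds a reflected
# version of a binary op, e.g. for installing `__rsub__`-style dunder methods.
#
#   rsub = _right(math_ops.subtract)
#   rsub(y, x)  # equivalent to math_ops.subtract(x, y)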
def test_injectable_decorator_increment(target):
    def wrapper(x):
        return wrapper.__wrapped__(x) + 1

    return tf_decorator.make_decorator(target, wrapper)
Example #58
0
def custom_gradient(f):
    """Decorator to define a function with a custom gradient.

  This decorator allows fine grained control over the gradients of a sequence
  of operations.  This may be useful for multiple reasons, including providing
  a more efficient or numerically stable gradient for a sequence of operations.

  For example, consider the following function that commonly occurs in the
  computation of cross entropy and log likelihoods:

  ```python
  def log1pexp(x):
    return tf.math.log(1 + tf.exp(x))
  ```

  Due to numerical instability, the gradient of this function evaluated at x=100 is
  NaN.  For example:

  ```python
  x = tf.constant(100.)
  y = log1pexp(x)
  dy = tf.gradients(y, x) # Will be NaN when evaluated.
  ```

  The gradient expression can be analytically simplified to provide numerical
  stability:

  ```python
  @tf.custom_gradient
  def log1pexp(x):
    e = tf.exp(x)
    def grad(dy):
      return dy * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad
  ```

  With this definition, the gradient at x=100 will be correctly evaluated as
  1.0.

  See also `tf.RegisterGradient` which registers a gradient function for a
  primitive TensorFlow operation. `tf.custom_gradient` on the other hand allows
  for fine grained control over the gradient computation of a sequence of
  operations.

  Note that if the decorated function uses `Variable`s, the enclosing variable
  scope must be using `ResourceVariable`s.

  Args:
    f: function `f(*x)` that returns a tuple `(y, grad_fn)` where:
       - `x` is a sequence of `Tensor` inputs to the function.
       - `y` is a `Tensor` or sequence of `Tensor` outputs of applying
         TensorFlow operations in `f` to `x`.
       - `grad_fn` is a function with the signature `g(*grad_ys)` which returns
         a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect
         to the `Tensor`s in `x`.  `grad_ys` is a `Tensor` or sequence of
         `Tensor`s the same size as `y` holding the initial value gradients for
         each `Tensor` in `y`. In a pure mathematical sense, a vector-argument
         vector-valued function `f`'s derivatives should be its Jacobian matrix
         `J`. Here we are expressing the Jacobian `J` as a function `grad_fn`
         which defines how `J` will transform a vector `grad_ys` when
         left-multiplied with it (`grad_ys * J`). This functional representation
         of a matrix is convenient to use for chain-rule calculation
         (in e.g. the back-propagation algorithm).

         If `f` uses `Variable`s (that are not part of the
         inputs), i.e. through `get_variable`, then `grad_fn` should have
         signature `g(*grad_ys, variables=None)`, where `variables` is a list of
         the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where
         `grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>`
         with the derivatives of `Tensor`s in `y` with respect to the variables
         (that is, grad_vars has one Tensor per variable in variables).

  Returns:
    A function `h(x)` which returns the same value as `f(x)[0]` and whose
    gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`.
  """
    def decorated(*args, **kwargs):
        """Decorated function with custom gradient."""
        if context.executing_eagerly():
            return _eager_mode_decorator(f, *args, **kwargs)
        else:
            return _graph_mode_decorator(f, *args, **kwargs)

    return tf_decorator.make_decorator(f, decorated)
Example #59
0
def make_template_internal(name_,
                           func_,
                           create_scope_now_=False,
                           unique_name_=None,
                           custom_getter_=None,
                           create_graph_function_=False,
                           **kwargs):
    """Make a template, optionally compiling func_ into a graph function.

  See `make_template` for full documentation.

  Args:
    name_: A name for the scope created by this template. If necessary, the name
      will be made unique by appending `_N` to the name.
    func_: The function to wrap.
    create_scope_now_: Boolean controlling whether the scope should be created
      when the template is constructed or when the template is called. Default
      is False, meaning the scope is created when the template is called.
    unique_name_: When used, it overrides name_ and is not made unique. If a
      template of the same scope/unique_name already exists and reuse is false,
      an error is raised. Defaults to None. If executing eagerly, must be None.
    custom_getter_: Optional custom getter for variables used in `func_`. See
      the @{tf.get_variable} `custom_getter` documentation for
      more information.
    create_graph_function_: When True, `func_` will be executed as a graph
      function. This implies that `func_` must satisfy the properties that
      `function.defun` requires of functions: See the documentation of
      `function.defun` for details. When executing eagerly, setting this flag to
      True can improve performance. Regardless of whether eager execution is
      enabled, enabling this flag gives the caller access to graph-function
      semantics, i.e., accesses to variables are totally ordered and
      side-effecting ops are not pruned.
    **kwargs: Keyword arguments to apply to `func_`.

  Returns:
    A function to encapsulate a set of variables which should be created once
    and reused. An enclosing scope will be created either when `make_template`
    is called or when the result is called, depending on the value of
    `create_scope_now_`. Regardless of the value, the first time the template
    is called it will enter the scope with no reuse, and call `func_` to create
    variables, which are guaranteed to be unique. All subsequent calls will
    re-enter the scope and reuse those variables.

  Raises:
    ValueError: if `name_` is None.
    ValueError: if `unique_name_` is not None and eager execution is enabled.
  """

    if kwargs:
        func_ = tf_decorator.make_decorator(func_,
                                            functools.partial(func_, **kwargs))
    if context.executing_eagerly():
        if unique_name_ is not None:
            raise ValueError(
                "unique_name_ cannot be used when eager exeuction is enabled.")
        return EagerTemplate(name_,
                             func_,
                             create_scope_now=create_scope_now_,
                             custom_getter=custom_getter_,
                             create_graph_function=create_graph_function_)
    return Template(name_,
                    func_,
                    create_scope_now=create_scope_now_,
                    unique_name=unique_name_,
                    custom_getter=custom_getter_,
                    create_graph_function=create_graph_function_)
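# Hedged usage sketch via the public wrapper `tf.compat.v1.make_template`,
# which forwards to this helper (the variable-creating function below is
# illustrative):
#
#   def scale_by_y(x):
#       y = tf.compat.v1.get_variable(
#           "y", shape=[], initializer=tf.ones_initializer())
#       return x * y
#
#   scale = tf.compat.v1.make_template("scale_by_y", scale_by_y)
#   a = scale(tf.constant(2.0))  # first call creates variable `y`
#   b = scale(tf.constant(3.0))  # later calls reuse the same `y`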
 def decorator(func):
     return tf_decorator.make_decorator(
         func, _graph_callable_internal(func, shape_and_dtypes))