예제 #1
0
def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
  """Converts `value` into a tensor usable while building a graph function.

  Args:
    value: The Tensor to capture. It must support `.numpy()` when capturing
      is disabled.
    dtype: Optional dtype for the placeholder; defaults to `value.dtype`.
    name: Optional name for the placeholder node.
    as_ref: Unused; present only to satisfy the
      `register_tensor_conversion_function` signature.

  Returns:
    `value` itself in eager mode. Otherwise, when no capture map is active,
    a graph constant holding the tensor's current value; when one is active,
    a placeholder (shared by repeated captures of the same tensor) that
    receives the tensor's value at runtime.
  """
  if context.in_eager_mode():
    return value
  _ = as_ref  # Intentionally ignored.

  captures = _scoped_captures.tensors
  if captures is None:
    # Capturing is disabled: snapshot the current value as a constant.
    return constant_op.constant(value.numpy())

  key = ops.tensor_id(value)
  entry = captures.get(key, None)
  if entry is not None:
    # Already captured: reuse the placeholder stored with it.
    placeholder = entry[1]
  else:
    placeholder = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if placeholder.dtype == dtypes.resource:
      # Resource handles carry extra handle data; forward it.
      placeholder._handle_data = value._handle_data  # pylint: disable=protected-access
    captures[key] = (value, placeholder)

  tape.record_operation("captured_value", [placeholder], [value], [],
                        lambda x: x)
  return placeholder
예제 #2
0
def capture_value(tensor_map, value, dtype, name):
  """Captures `value` from outside the function as an extra placeholder arg.

  For resource tensors whose handle data is set, the handle's shapes and
  dtypes are also written onto the placeholder's graph output.

  Args:
    tensor_map: Dict-like map from `ops.tensor_id(value)` to the
      `(value, placeholder)` pair stored on first capture.
    value: The tensor to capture.
    dtype: Placeholder dtype; `value.dtype` is used when this is falsy.
    name: Name for a newly created placeholder.

  Returns:
    The placeholder standing in for `value` (reused on repeated captures).
  """
  key = ops.tensor_id(value)
  cached = tensor_map.get(key, None)
  if cached is not None:
    placeholder = cached[1]
  else:
    placeholder = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if placeholder.dtype == dtypes_module.resource:
      handle_data = value._handle_data  # pylint: disable=protected-access
      placeholder._handle_data = handle_data  # pylint: disable=protected-access
      if handle_data is not None and handle_data.is_set:
        # Ensure that shapes and dtypes are propagated.
        shape_protos, types = zip(*[(entry.shape, entry.dtype)
                                    for entry in handle_data.shape_and_type])
        ranks = [-1 if proto.unknown_rank else len(proto.dim)
                 for proto in shape_protos]
        dims = [None if proto.unknown_rank else [d.size for d in proto.dim]
                for proto in shape_protos]
        with errors.raise_exception_on_not_ok_status() as status:
          pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
              placeholder._op._graph._c_graph,  # pylint: disable=protected-access
              placeholder._as_tf_output(),  # pylint: disable=protected-access
              dims, ranks, types, status)
    tensor_map[key] = (value, placeholder)
  tape.record_operation("captured_value", [placeholder], [value],
                        lambda x: [x])
  return placeholder
예제 #3
0
 def _capture_helper(self, tensor, name):
   """Returns the placeholder that stands in for `tensor` in this graph.

   On the first capture of `tensor`, a substitute placeholder is created,
   stored in `self.captures`, and appended to `self.inputs`. Later captures
   reuse it. Every call records a "captured_value" op on the tape.
   """
   placeholder = self.captures.get(tensor, None)
   if placeholder is None:
     placeholder = _create_substitute_placeholder(
         tensor, name=name, dtype=tensor.dtype)
     self.captures[tensor] = placeholder
     self.inputs.append(placeholder)
   tape.record_operation(
       "captured_value", [placeholder], [tensor], lambda x: [x])
   return placeholder
예제 #4
0
def _record_gradient(op_name, inputs, attrs, results, ctx, name):
  """Records gradients for a TensorFlow operation on the tape.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    inputs: A flat list of Tensor object inputs to the operation.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    results: The results of the operation (as a flat list).
    ctx: The value of context.context().
    name: Customized name for the operation.

  Returns:
    None. The operation is recorded as a side effect via
    `tape.record_operation`. Nothing is recorded when
    `tape.could_possibly_record()` is False.

  Raises:
    Any error raised while converting `inputs` to tensors or while recording
    the operation is propagated unchanged.
  """
  if not tape.could_possibly_record():
    return

  # Ops listed in these sets do not capture their outputs/inputs, presumably
  # so the backward function does not keep those tensors alive needlessly.
  if op_name in _ops_which_dont_need_outputs:
    op_outputs = None
  else:
    # TODO(apassos) this line creates a weak circular reference where the
    # backprop function keeps an output alive which in turn keeps the tape entry
    # alive which keeps the backprop function alive. Figure out how to break
    # this up without breaking second derivatives of ops like Exp whose
    # gradients depend only on the outputs.
    op_outputs = results

  if op_name in _ops_which_dont_need_inputs:
    op_inputs = None
  else:
    op_inputs = inputs

  num_inputs = len(inputs)
  # Name used only in the _tracing diagnostics below.
  display_name = name or op_name

  def grad_fn(*orig_outputs):
    """Generated gradient function."""
    result = _magic_gradient_function(op_name, attrs, num_inputs,
                                      op_inputs, op_outputs, orig_outputs)
    if _tracing:
      print("Gradient for", display_name, "inputs", op_inputs,
            "output_grads", orig_outputs, "gradients", result)
    return result

  inputs = [ops.internal_convert_to_tensor(x, ctx=ctx) for x in inputs]
  tape.record_operation(op_name, results, inputs, [], grad_fn)
  if _tracing:
    print("Computed op", display_name, "inputs", inputs,
          "outputs", results)
예제 #5
0
def capture_value(tensor_map, value, dtype, name):
  """Captures `value` from outside the function as an extra placeholder arg.

  Args:
    tensor_map: Dict-like map from `ops.tensor_id(value)` to the
      `(value, placeholder)` pair recorded on first capture.
    value: The tensor to capture.
    dtype: Placeholder dtype; `value.dtype` is used when this is falsy.
    name: Name for a newly created placeholder.

  Returns:
    The placeholder standing in for `value` (reused on repeated captures).
  """
  key = ops.tensor_id(value)
  cached = tensor_map.get(key, None)
  if cached is not None:
    placeholder = cached[1]
  else:
    placeholder = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if placeholder.dtype == dtypes.resource:
      # Carry the resource handle's handle data over to the placeholder.
      placeholder._handle_data = value._handle_data  # pylint: disable=protected-access
    tensor_map[key] = (value, placeholder)
  tape.record_operation("captured_value", [placeholder], [value],
                        lambda x: [x])
  return placeholder
예제 #6
0
  def decorated(*args, **kwargs):
    """Calls `f` and attaches its custom gradient to the result.

    `f` must return a `(result, grad_fn)` pair. `grad_fn` receives the
    gradients of the flattened `result` and returns gradients for the inputs.

    In graph mode:
      - keyword arguments are rejected;
      - the positional args are converted to tensors;
      - the flattened result plus the args are passed through an `IdentityN`
        op whose gradient is overridden to call `grad_fn`.

    Otherwise, `f` runs with tape recording stopped, and a single operation
    is recorded from the Tensor positional args to the flattened result.

    Returns:
      `result` (in graph mode, rebuilt from the `IdentityN` outputs).

    Raises:
      ValueError: In graph mode, if keyword arguments are passed.
    """
    if context.in_graph_mode():
      if kwargs:
        raise ValueError(
            "custom_gradient in graph mode doesn't support keyword arguments.")
      name = "CustomGradient-%s" % tf_ops.uid()
      args = [tf_ops.convert_to_tensor(x) for x in args]
      result, grad_fn = f(*args)
      flat_result = nest.flatten(result)
      all_tensors = flat_result + args

      @tf_ops.RegisterGradient(name)
      def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
        gradients = nest.flatten(grad_fn(*result_grads[:len(flat_result)]))
        # Need to return one value per input to the IdentityN, so pad the
        # gradients of the inputs of the custom_gradient function with the
        # gradients of the outputs as well.
        return ([None] * len(flat_result)) + gradients

      with tf_ops.get_default_graph().gradient_override_map(
          {"IdentityN": name}):
        all_tensors = array_ops.identity_n(all_tensors)
      return nest.pack_sequence_as(
          structure=result, flat_sequence=all_tensors[:len(flat_result)])

    input_tensors = [x for x in args
                     if isinstance(x, tf_ops.Tensor)]

    with tape.stop_recording():
      result, grad_fn = f(*args, **kwargs)

    # TODO(apassos): naive uses of custom_gradient will not get the correct
    # second derivative this way if they capture any output tensors. Change the
    # signature of custom_gradient.
    def actual_grad_fn(*outputs):
      # Flatten so the tape always receives a flat list of input gradients,
      # even when grad_fn returns a single tensor or a nested structure. This
      # matches the flattening done in the graph-mode path above.
      return nest.flatten(grad_fn(*outputs))

    flat_result = nest.flatten(result)
    tape.record_operation(
        f.__name__,
        flat_result,
        input_tensors,
        [],
        actual_grad_fn)
    return result
예제 #7
0
  def decorated(*args, **kwargs):
    """Runs `f` and connects its custom gradient to the returned values.

    `f` must return `(result, grad_fn)`, where `grad_fn` maps gradients of
    the flattened result to gradients of the positional args.

    Eager path: every result leaf goes through an identity op, and one
    operation is recorded on the tape from the converted positional args to
    those identities.

    Graph path: keyword arguments are rejected, and the flattened result plus
    the args flow through an `IdentityN` op whose gradient is overridden to
    call `grad_fn`.

    Returns:
      `result`, with the same structure, rebuilt from the new output tensors.

    Raises:
      ValueError: In graph mode, if keyword arguments are passed.
    """
    if not context.in_graph_mode():
      arg_tensors = [tf_ops.convert_to_tensor(arg) for arg in args]
      result, grad_fn = f(*args, **kwargs)
      # TODO(apassos) consider removing the identity below.
      outputs = [gen_array_ops.identity(leaf) for leaf in nest.flatten(result)]

      def eager_grad_fn(*output_grads):
        return nest.flatten(grad_fn(*output_grads))

      tape.record_operation(f.__name__, outputs, arg_tensors, eager_grad_fn)
      return nest.pack_sequence_as(result, outputs)

    if kwargs:
      raise ValueError(
          "custom_gradient in graph mode doesn't support keyword arguments.")
    name = "CustomGradient-%s" % tf_ops.uid()
    args = [tf_ops.convert_to_tensor(x) for x in args]
    result, grad_fn = f(*args)
    flat_result = nest.flatten(result)
    num_results = len(flat_result)

    @tf_ops.RegisterGradient(name)
    def _override_grad(unused_op, *result_grads):  # pylint: disable=unused-variable
      input_grads = nest.flatten(grad_fn(*result_grads[:num_results]))
      # IdentityN needs one gradient per input. Its leading inputs are the
      # results themselves, which receive None.
      return [None] * num_results + input_grads

    with tf_ops.get_default_graph().gradient_override_map(
        {"IdentityN": name}):
      identities = array_ops.identity_n(flat_result + args)
    return nest.pack_sequence_as(
        structure=result, flat_sequence=identities[:num_results])
예제 #8
0
  def _backprop_call(self, args):
    """Runs the forward function and records its backward pass on the tape.

    Args:
      args: The explicit call inputs; `self._extra_inputs` are appended.

    Returns:
      `self._build_call_outputs` applied to the first `len(self._returns)`
      outputs of the forward function.
    """
    all_args = args + self._extra_inputs
    signature = self._forward_fdef.definition.signature
    ctx = context.context()
    if not ctx.in_graph_mode():
      outputs = execute.execute(
          str(signature.name),
          num_outputs=len(signature.output_arg),
          inputs=all_args,
          attrs=None,
          ctx=ctx)
    else:
      graph = ops.get_default_graph()
      graph._add_function(self._forward_fdef)  # pylint: disable=protected-access

      def to_tensor(arg):
        return (arg if isinstance(arg, ops.Tensor)
                else ops.internal_convert_to_tensor(arg, ctx=ctx))

      call_op = graph.create_op(
          signature.name,
          [to_tensor(arg) for arg in all_args],
          [dtypes.DType(out_arg.type) for out_arg in signature.output_arg],
          op_def=signature,
          name="FunctionCall",
          compute_shapes=False)
      outputs = call_op.outputs
      if isinstance(outputs, (ops.Tensor, type(None))):
        outputs = [outputs]
      else:
        outputs = list(outputs)
      for idx, shape in enumerate(self._output_shapes):
        outputs[idx].set_shape(shape)

    # Leading outputs are the user-visible returns; the rest are side outputs
    # fed back into the backward function.
    real_outputs = outputs[:len(self._returns)]
    side_outputs = outputs[len(self._returns):]

    def backward_function(*output_grads):
      return self._backward_function(*(list(output_grads) + side_outputs))

    tape.record_operation(
        signature.name,
        real_outputs,
        args + self._extra_inputs,
        backward_function)
    return self._build_call_outputs(real_outputs)
예제 #9
0
def _eager_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for eager mode.

  Runs `f` under a `GradientTape` to find the variables it reads. It then
  records a single tape operation mapping the positional args and those
  variables to the (identity-wrapped) flattened result, with `f`'s `grad_fn`
  as the gradient.

  Args:
    f: Callable returning `(result, grad_fn)`.
    *args: Positional arguments for `f`; one gradient is expected per arg.
    **kwargs: Keyword arguments for `f`. They are excluded from the variables
      list but are not recorded as inputs and receive no gradients.

  Returns:
    `result`, with every flattened element passed through an identity op.

  Raises:
    TypeError: If `f` uses variables but `grad_fn` cannot accept a
      `variables` keyword argument.
    ValueError: When the gradient is computed, if `grad_fn` returns the
      wrong number of input or variable gradients.
  """
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args, **kwargs)
  all_inputs = list(args) + list(kwargs.values())
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = [v for v in set(tape.watched_variables()) if v not in all_inputs]
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  # `variables` may be accepted as a regular parameter, as a keyword-only
  # parameter, or via **kwargs.
  accepts_variables = ("variables" in grad_argspec.args or
                       "variables" in grad_argspec.kwonlyargs or
                       grad_argspec.varkw)
  if variables and not accepts_variables:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  flat_result = nest.flatten(result)
  # TODO(apassos) consider removing the identity below.
  flat_result = [gen_array_ops.identity(x) for x in flat_result]

  input_tensors = [ops.convert_to_tensor(x) for x
                   in list(args) + list(variables)]
  arg_count = len(args)

  def actual_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
      # grad_fn may return a tuple; normalize so it concatenates with a list.
      variable_grads = list(variable_grads)
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []
    flat_grads = nest.flatten(input_grads)
    if len(flat_grads) != arg_count:
      raise ValueError(
          "custom_gradient function expected to return %d gradients but "
          "returned %d instead." % (arg_count, len(flat_grads)))
    return flat_grads + variable_grads

  tape_lib.record_operation(f.__name__, flat_result, input_tensors,
                            actual_grad_fn)
  flat_result = list(flat_result)
  return nest.pack_sequence_as(result, flat_result)
예제 #10
0
def capture_value(tensor_map, value, dtype, name):
  """Capture a value from outside the function, to pass in as an extra arg.

  A graph placeholder is created the first time `value` is captured. It is
  cached in `tensor_map` under `ops.tensor_id(value)` as a
  `(value, placeholder)` pair, and later captures reuse it. For resource
  dtypes, the handle's shape/dtype metadata is copied onto the placeholder.
  Every call records a "captured_value" op on the tape.

  Args:
    tensor_map: Dict-like cache keyed by `ops.tensor_id(value)`.
    value: The tensor to capture.
    dtype: Placeholder dtype; falls back to `value.dtype` when falsy.
    name: Name for a newly created placeholder.

  Returns:
    The placeholder standing in for `value`.
  """
  captured_value = tensor_map.get(ops.tensor_id(value), None)
  if captured_value is None:
    captured_value = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if captured_value.dtype == dtypes_module.resource:
      # Where the handle metadata is read from depends on the shape-inference
      # mode and on whether `value` is an eager tensor.
      if ops._USE_C_SHAPES:  # pylint: disable=protected-access
        if isinstance(value, ops.EagerTensor):
          handle_data = value._handle_data  # pylint: disable=protected-access
        else:
          handle_data = resource_variable_ops.get_resource_handle_data(value)
      else:
        handle_data = value._handle_data  # pylint: disable=protected-access
      if handle_data is not None and handle_data.is_set:
        # pylint: disable=protected-access
        if ops._USE_C_SHAPES:
          pywrap_tensorflow.SetResourceHandleShapeAndType(
              captured_value.graph._c_graph, captured_value._as_tf_output(),
              handle_data.SerializeToString())
        else:
          captured_value._handle_data = handle_data
        # pylint: enable=protected-access
        # Ensure that shapes and dtypes are propagated.
        # NOTE(review): this runs for both branches above, so with C shapes
        # the metadata is set twice. Also, `zip(*[])` raises ValueError if
        # `shape_and_type` is empty. Confirm neither case matters.
        shapes, types = zip(*[(pair.shape, pair.dtype)
                              for pair in handle_data.shape_and_type])
        # Rank -1 and shape None encode an unknown rank for the C API.
        ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
        shapes = [[d.size for d in s.dim]
                  if not s.unknown_rank else None for s in shapes]
        pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
            captured_value._op._graph._c_graph,  # pylint: disable=protected-access
            captured_value._as_tf_output(),  # pylint: disable=protected-access
            shapes, ranks, types)

    tensor_map[ops.tensor_id(value)] = (value, captured_value)
  else:
    # Cache entries are (value, placeholder) pairs.
    captured_value = captured_value[1]
  tape.record_operation("captured_value", [captured_value], [value],
                        lambda x: [x])
  return captured_value
예제 #11
0
  def decorated(*args, **kwargs):
    """Calls `f` and records its custom gradient on the tape.

    `f` runs with tape recording stopped. Afterwards, one operation is
    recorded linking the Tensor positional args to the flattened result,
    with `f`'s returned `grad_fn` as the backward function.

    Returns:
      The `result` part of `f`'s `(result, grad_fn)` return value.
    """
    tensor_args = [arg for arg in args if isinstance(arg, tf_ops.Tensor)]

    with tape.stop_recording():
      result, grad_fn = f(*args, **kwargs)

    # TODO(apassos): naive uses of custom_gradient will not get the correct
    # second derivative this way if they capture any output tensors. Change the
    # signature of custom_gradient.
    def wrapped_grad_fn(*outputs):
      return grad_fn(*outputs)

    tape.record_operation(nest.flatten(result), tensor_args, [],
                          wrapped_grad_fn)
    return result
예제 #12
0
  def decorated(*args, **kwargs):
    """Runs `f` and records its custom gradient on the tape.

    `f` must return `(result, grad_fn)`. `grad_fn` is used directly as the
    backward function of the recorded operation.

    Returns:
      `result`, repacked from the values `tape.record_operation` returns.
    """
    watched_inputs = [
        _watch_value_from_tape(arg)
        for arg in args
        if (isinstance(arg, (_tensor.Tensor, tf_ops.Tensor))
            or ag_core.isnode(arg))
    ]
    result, grad_fn = f(*args, **kwargs)

    leaves = [ag_core.getval(leaf) for leaf in nest.flatten(result)]
    recorded = tape.record_operation(leaves, watched_inputs, [], grad_fn)
    return nest.pack_sequence_as(structure=result,
                                 flat_sequence=list(recorded))
예제 #13
0
def _record_gradient(op_name, inputs, attrs, results, name):
  """Records gradients for a TensorFlow operation.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    inputs: A flat list of Tensor object inputs to the operation.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    results: The results of the operation (as a flat list).
    name: Customized name for the operation.

  Returns:
    `results` unchanged when no input is a tape node or there are no outputs.
    Otherwise, the (maybe-wrapped) results returned by
    `tape.record_operation`: Tensors or TensorNodes.

  Raises:
    An exception on error.
  """
  # Nothing to record unless at least one input is being traced.
  if not any(ag_core.isnode(x) for x in inputs):
    return results
  num_outputs = len(results)
  if num_outputs == 0:
    return results
  # Lists inside attrs become tuples (presumably to make attrs immutable and
  # hashable for the gradient-function lookup; verify).
  if attrs is not None:
    attrs = tuple(tuple(x) if isinstance(x, list) else x for x in attrs)

  # Only the number of results is captured by `grad_fn`, not `results`
  # itself, presumably to avoid a reference cycle through the recorded
  # outputs that could delay their garbage collection.
  results_size = len(results) if isinstance(results, (list, tuple)) else 1

  def grad_fn(*orig_outputs):
    """Generated gradient function."""
    # The inputs followed by everything the tape passes in are forwarded to
    # the generic gradient function as one sequence.
    tensors = inputs + list(orig_outputs)
    tensors = container_types.make_sequence(tape.EagerList, *tensors)
    result = _magic_gradient_function(op_name, attrs, len(inputs),
                                      num_outputs, *(tensors))
    if _tracing:
      print("Gradient for", (name if name else op_name), "inputs", inputs,
            "output_grads", orig_outputs[results_size:], "gradients", result)
    return result

  results = tape.record_operation(results, inputs, [], grad_fn)
  if _tracing:
    print("Computed op", (name if name else op_name), "inputs", inputs,
          "outputs", results)
  return results
예제 #14
0
  def _backprop_call(self, args):
    """Calls the wrapped function and records the result on a tape.

    Args:
      args: Explicit call inputs; `self._extra_inputs` are appended to them.

    Returns:
      `self._build_call_outputs(self._returns, ...)` applied to the first
      `len(self._returns)` outputs, as returned by `tape.record_operation`.
    """
    all_args = args + self._extra_inputs
    signature = self._forward_fdef.definition.signature
    if context.in_graph_mode():
      g = ops.get_default_graph()
      g._add_function(self._forward_fdef)  # pylint: disable=protected-access
      unwrapped_args = [ag_core.getval(x) for x in all_args]
      op = g.create_op(
          signature.name, [ops.convert_to_tensor(x) for x in unwrapped_args],
          [dtypes.DType(x.type) for x in signature.output_arg],
          op_def=signature,
          name="FunctionCall",
          compute_shapes=False)
      outputs = op.outputs
      outputs = [outputs] if isinstance(
          outputs, (tensor.Tensor, ops.Tensor, type(None))) else list(outputs)
      for i, s in enumerate(self._output_shapes):
        outputs[i].set_shape(s)
    else:
      outputs = execute.execute(
          signature.name,
          num_outputs=len(signature.output_arg),
          inputs=all_args)
    real_outputs = outputs[:len(self._returns)]
    side_outputs = outputs[len(self._returns):]

    # For each extra input, use the watched value from the first tape on the
    # stack that tracks it. Otherwise, fall back to the tensor itself.
    watched_extra_inputs = []
    for extra_input in self._extra_inputs:
      tid = ops.tensor_id(extra_input)
      # The inner loop uses its own variable so that the for-else below
      # appends the extra input, not the last tape entry visited.
      for tape_entry in tape._tape_stack.stack:  # pylint: disable=protected-access
        watched = tape_entry.value.tensors.get(tid, None)
        if watched is not None:
          watched_extra_inputs.append(watched)
          break
      else:  # No tape watches this tensor; for-else is intentional.
        watched_extra_inputs.append(extra_input)

    # Pin the count now: `real_outputs` is rebound below to whatever
    # `tape.record_operation` returns.
    num_real_outputs = len(real_outputs)

    def backward_function_wrapper(*outputs):
      # Drop the leading `num_real_outputs` tape-supplied arguments before
      # calling the backward function.
      outputs = outputs[num_real_outputs:]
      return self._backward_function(*outputs)
    real_outputs = tape.record_operation(
        real_outputs,
        (args + watched_extra_inputs),
        side_outputs,
        backward_function_wrapper)

    return self._build_call_outputs(self._returns, real_outputs)
예제 #15
0
 def testSpecialForwardFunctionUsed(self):
   """Checks that record_operation's forward function drives the JVPs."""
   source = constant_op.constant(1.)
   middle = constant_op.constant(2.)
   sink = constant_op.constant(3.)
   with forwardprop.ForwardAccumulator(source, 10.) as acc:
     # Record `middle` as an output of `source`; the forward function
     # scales the incoming tangent by -2.
     tape_lib.record_operation("ForwardIsSpecial", [middle], [source], None,
                               lambda jvp: [-2. * jvp])
     self.assertAllClose(-20., acc.jvp(middle))
     # Record an op with no inputs or outputs between the two checks.
     tape_lib.record_operation("ForwardIsSpecial2", [], [], None,
                               lambda: [])
     # `sink` passes `middle`'s tangent through unchanged.
     tape_lib.record_operation("ForwardIsSpecial3", [sink], [middle], None,
                               lambda x: [x])
     self.assertAllClose(-20., acc.jvp(sink))
예제 #16
0
    def _backprop_call(self, args):
        """Calls the wrapped function and records the result on a tape.

        Args:
            args: Explicit call inputs; `self._extra_inputs` are appended.

        Returns:
            `self._build_call_outputs(self._returns, ...)` applied to the first
            `len(self._returns)` outputs, as returned by
            `tape.record_operation`.
        """
        all_args = args + self._extra_inputs
        signature = self._forward_fdef.definition.signature
        if context.in_graph_mode():
            g = ops.get_default_graph()
            g._add_function(self._forward_fdef)  # pylint: disable=protected-access
            unwrapped_args = [ag_core.getval(x) for x in all_args]
            op = g.create_op(
                signature.name,
                [ops.convert_to_tensor(x) for x in unwrapped_args],
                [dtypes.DType(x.type) for x in signature.output_arg],
                op_def=signature,
                name="FunctionCall",
                compute_shapes=False)
            outputs = op.outputs
            outputs = [outputs] if isinstance(outputs,
                                              (tensor.Tensor, ops.Tensor,
                                               type(None))) else list(outputs)
            for i, s in enumerate(self._output_shapes):
                outputs[i].set_shape(s)
        else:
            outputs = execute.execute(signature.name,
                                      num_outputs=len(signature.output_arg),
                                      inputs=all_args)
        real_outputs = outputs[:len(self._returns)]
        side_outputs = outputs[len(self._returns):]

        # For each extra input, use the watched value from the first tape on
        # the stack that tracks it. Otherwise, fall back to the tensor itself.
        watched_extra_inputs = []
        for extra_input in self._extra_inputs:
            tid = ops.tensor_id(extra_input)
            # The inner loop uses its own variable so that the for-else below
            # appends the extra input, not the last tape entry visited.
            for tape_entry in tape._tape_stack.stack:  # pylint: disable=protected-access
                watched = tape_entry.value.tensors.get(tid, None)
                if watched is not None:
                    watched_extra_inputs.append(watched)
                    break
            else:  # No tape watches this tensor; for-else is intentional.
                watched_extra_inputs.append(extra_input)
        real_outputs = tape.record_operation(real_outputs,
                                             (args + watched_extra_inputs),
                                             side_outputs,
                                             self._backward_function)

        return self._build_call_outputs(self._returns, real_outputs)
예제 #17
0
    def decorated(*args, **kwargs):
        """Calls `f` and attaches its custom gradient to the result.

        `f` must return `(result, grad_fn)`. The flattened result is recorded
        on the tape as the output of an operation whose inputs are the
        watched Tensor/node positional args.

        Returns:
            `result`, repacked from the values returned by
            `tape.record_operation`.
        """
        input_tensors = [
            _watch_value_from_tape(x) for x in args
            if isinstance(x, (_tensor.Tensor,
                              tf_ops.Tensor)) or ag_core.isnode(x)
        ]
        result, grad_fn = f(*args, **kwargs)
        flat_result = nest.flatten(result)
        # Count outputs on the *flattened* result, since that is what gets
        # recorded. len(result) undercounts nested structures and treats a
        # dict as a single output.
        result_size = len(flat_result)

        # TODO(apassos): naive uses of custom_gradient will not get the correct
        # second derivative this way if they capture any output tensors. Change the
        # signature of custom_gradient.
        def actual_grad_fn(*outputs):
            # Skip the leading `result_size` arguments, which line up with the
            # recorded outputs. The rest (presumably the output gradients) go
            # to grad_fn.
            outputs = outputs[result_size:]
            return grad_fn(*outputs)

        flat_result = [ag_core.getval(x) for x in flat_result]
        flat_result = tape.record_operation(flat_result, input_tensors, [],
                                            actual_grad_fn)
        flat_result = list(flat_result)
        return nest.pack_sequence_as(structure=result,
                                     flat_sequence=flat_result)
예제 #18
0
  def decorated(*args, **kwargs):
    """Calls `f` and attaches its custom gradient to the result.

    `f` must return `(result, grad_fn)`. The flattened result is recorded on
    the tape as the output of an operation whose inputs are the watched
    Tensor/node positional args.

    Returns:
      `result`, repacked from the values returned by `tape.record_operation`.
    """
    input_tensors = [_watch_value_from_tape(x) for x in args
                     if isinstance(x, (_tensor.Tensor, tf_ops.Tensor))
                     or ag_core.isnode(x)]
    result, grad_fn = f(*args, **kwargs)
    flat_result = nest.flatten(result)
    # Count outputs on the *flattened* result, since that is what gets
    # recorded. len(result) undercounts nested structures and treats a dict
    # as a single output.
    result_size = len(flat_result)

    # TODO(apassos): naive uses of custom_gradient will not get the correct
    # second derivative this way if they capture any output tensors. Change the
    # signature of custom_gradient.
    def actual_grad_fn(*outputs):
      # Skip the leading `result_size` arguments, which line up with the
      # recorded outputs. The rest (presumably the output gradients) go to
      # grad_fn.
      outputs = outputs[result_size:]
      return grad_fn(*outputs)

    flat_result = [ag_core.getval(x) for x in flat_result]
    flat_result = tape.record_operation(
        flat_result,
        input_tensors,
        [],
        actual_grad_fn)
    flat_result = list(flat_result)
    return nest.pack_sequence_as(structure=result, flat_sequence=flat_result)
예제 #19
0
def _graph_mode_decorator(f, args, kwargs):
    """Implement custom gradient decorator for graph mode.

    Steps:
      1. Converts the args to tensors, capturing them into the default graph
         when it is building a function, and calls `f(*args)`.
      2. Collects the resource variables `f` uses.
      3. Routes the flattened result, the flattened args and those variables
         through an `IdentityN` op whose gradient is overridden to call the
         user's `grad_fn`.
      4. Records the same mapping on the tape.

    Args:
        f: Callable returning `(result, grad_fn)`.
        args: Positional arguments for `f`, converted element-wise via
            `nest.map_structure`.
        kwargs: Keyword arguments; must be empty.

    Returns:
        `result`, rebuilt from the leading `IdentityN` outputs.

    Raises:
        ValueError: If `kwargs` is non-empty, or (when the gradient is
            computed) if `grad_fn` returns the wrong number of variable
            gradients.
        TypeError: If a non-resource variable is added to the current
            variable scope's global/local variables while running `f`, or if
            `f` uses variables and `grad_fn` cannot accept `variables`.
    """
    # TODO(rsepassi): Add support for kwargs
    if kwargs:
        raise ValueError(
            "The custom_gradient decorator currently supports keywords "
            "arguments only when eager execution is enabled.")
    name = "CustomGradient-%s" % ops.uid()

    default_graph = ops.get_default_graph()

    def convert_arg(x):
        x = ops.convert_to_tensor(x)
        # If graph building, be sure to capture all inputs
        if default_graph.building_function and x.graph != default_graph:
            x = default_graph.capture(x)
        return x

    args = nest.map_structure(convert_arg, args)

    # Checking global and local variables attempts to ensure that no non-resource
    # Variables are added to the graph.
    current_var_scope = variable_scope.get_variable_scope()
    before_vars = set([
        v.ref() for v in current_var_scope.global_variables() +
        current_var_scope.local_variables()
    ])
    with tape_lib.VariableWatcher() as variable_watcher:
        result, grad_fn = f(*args)

    # From here on, `args` is the flat list of converted arguments.
    args = nest.flatten(args)
    flat_result = nest.flatten(result)
    flat_result_len = len(flat_result)

    after_vars = set([
        v.ref() for v in current_var_scope.global_variables() +
        current_var_scope.local_variables()
    ])
    new_vars = after_vars - before_vars
    new_vars_list = [v.deref() for v in new_vars]
    for v in new_vars_list:
        if not resource_variable_ops.is_resource_variable(v):
            raise TypeError(
                "All variables used by a function wrapped with @custom_gradient must "
                "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
                "with `use_resource=False`.")

    # The variables that grad_fn needs to return gradients for are the set of
    # variables used that are *not* part of the inputs.
    # NOTE(review): the code below takes the union of the watched variables
    # and the variables found between args and outputs, without explicitly
    # removing any that are args. Confirm this matches the comment above.
    variables_in_tape = frozenset(
        [v.ref() for v in variable_watcher.watched_variables()])
    variables_in_subgraph = frozenset([
        v.ref() for v in _get_dependent_variables(input_ops=args,
                                                  output_ops=flat_result)
    ])
    variables = list(
        [v.deref() for v in variables_in_subgraph.union(variables_in_tape)])

    # grad_fn may accept `variables` as a regular parameter, as a
    # keyword-only parameter, or via **kwargs.
    grad_argspec = tf_inspect.getfullargspec(grad_fn)
    variables_in_signature = ("variables" in grad_argspec.args
                              or "variables" in grad_argspec.kwonlyargs
                              or grad_argspec.varkw)
    if variables and not variables_in_signature:
        raise TypeError(
            "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
            "since function uses variables: {}".format(variables))
    if variables_in_signature and not variables:
        # User seems to intend to use variables but none were captured.
        logging.warn(
            "@custom_gradient grad_fn has 'variables' in signature, but "
            "no ResourceVariables were used on the forward pass.")

    # Order matters: the IdentityN inputs (and thus its gradients) are laid
    # out as [results..., args..., variables...].
    all_tensors = flat_result + args + variables

    def tape_grad_fn(*result_grads):
        """Custom grad fn wrapper."""
        # Only the gradients of the results are passed to the user's grad_fn.
        result_grads = result_grads[:flat_result_len]
        if variables:
            input_grads, variable_grads = grad_fn(*result_grads,
                                                  variables=variables)
            if len(variable_grads) != len(variables):
                raise ValueError("Must return gradient for each variable from "
                                 "@custom_gradient grad_fn.")
        else:
            input_grads = grad_fn(*result_grads)
            variable_grads = []

        # Need to return one value per input to the IdentityN, so pad the
        # gradients of the inputs of the custom_gradient function with the
        # gradients of the outputs as well.
        input_grads = nest.flatten(input_grads)
        return ([None] * flat_result_len) + input_grads + variable_grads

    @ops.RegisterGradient(name)
    def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
        """Custom grad fn wrapper."""
        return tape_grad_fn(*result_grads)

    original_tensors = all_tensors
    with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
        all_tensors = array_ops.identity_n(all_tensors)

    # Convert the pre-IdentityN values (including variables) to tensors so
    # they can be recorded as tape inputs below.
    original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

    # Propagate handle data for happier shape inference for resource variables.
    # NOTE(review): `copy_handle_data` below is also applied to the same
    # tensor pairs; this loop may be redundant. Confirm before removing.
    for i, t in enumerate(original_tensors):
        if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
            all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
    # Also record the IdentityN outputs on the tape, with `tape_grad_fn` as
    # their gradient function.
    tape_lib.record_operation(f.__name__, all_tensors, original_tensors,
                              tape_grad_fn)
    for ot, t in zip(original_tensors, all_tensors):
        copy_handle_data(ot, t)
    return nest.pack_sequence_as(structure=result,
                                 flat_sequence=all_tensors[:flat_result_len])
예제 #20
0
def _graph_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for graph mode.

  Calls `f(*args)`, which must return a `(result, grad_fn)` pair. It then
  passes the flattened `result`, the inputs and every variable watched during
  the forward pass through one `IdentityN` op, whose registered gradient calls
  `grad_fn`. The same gradient function is also recorded on the tape.

  Args:
    f: Function returning `(result, grad_fn)`.
    *args: Positional inputs to `f`, each converted with
      `ops.convert_to_tensor`.
    **kwargs: Not supported in graph mode; must be empty.

  Returns:
    `result` with its flattened leaves replaced by the corresponding
    `IdentityN` outputs.

  Raises:
    ValueError: If `kwargs` is non-empty. Also raised later, when gradients
      are computed, if `grad_fn` does not return one gradient per variable.
    TypeError: Raised in any of these cases:
      * `f` creates a variable that is not a resource variable.
      * Variables were used but `grad_fn` does not accept `variables`.
      * `grad_fn` accepts `variables`, none were captured, and the current
        variable scope does not have `use_resource` set.
  """
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  args = [ops.convert_to_tensor(x) for x in args]

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set(current_var_scope.global_variables() +
                    current_var_scope.local_variables())
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args)
  after_vars = set(current_var_scope.global_variables() +
                   current_var_scope.local_variables())
  new_vars = after_vars - before_vars
  for v in new_vars:
    # Use the generic predicate rather than `isinstance(v, ResourceVariable)`.
    # The strict class check rejects resource-backed variable types that are
    # not `ResourceVariable` subclasses.
    # NOTE(review): confirm the accepted set against
    # `resource_variable_ops.is_resource_variable`.
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs. Sort them by name so
  # that the order does not depend on set iteration order, which varies
  # between runs for identity-hashed objects. This order fixes the IdentityN
  # inputs (and hence the graph) and the `variables` list given to grad_fn.
  variables = sorted(set(tape.watched_variables()) - set(args),
                     key=lambda v: v.name)
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  # A **kwargs catch-all also counts as accepting `variables`.
  variables_in_signature = ("variables" in grad_argspec.args or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    if not variable_scope.get_variable_scope().use_resource:
      raise TypeError("If using @custom_gradient with a function that "
                      "uses variables, the enclosing variable scope must "
                      "have use_resource=True.")
    else:
      logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                   "no ResourceVariables were used on the forward pass.")
  flat_result = nest.flatten(result)
  all_tensors = flat_result + args + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    # Only the gradients w.r.t. the forward outputs are meaningful to grad_fn.
    result_grads = result_grads[:len(flat_result)]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * len(flat_result)) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  # Route everything through IdentityN so graph-mode gradients hit `name`.
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)
  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:len(flat_result)])
    def inner(*args, _watch_vars=None, num_checkpoints=0, **kwargs):
        r"""Performs a forward pass while storing only the checkpoint activations.

        The layers of `model` are applied to `x` in order without recording on
        the tape. Only the input and the outputs of `num_checkpoints` uniformly
        spread layers are kept. A custom backward function is recorded that
        recomputes the missing activations layer by layer.

        Args:
            *args: Must be exactly `(model, x)`. `model` exposes a `layers`
                sequence of callables, each with `trainable_variables`; `x`
                is the input to the first layer.
            _watch_vars: Variables registered as the inputs of the recorded
                operation. The backward function returns one gradient per
                trainable variable of every layer, first layer first. This is
                presumably `model.trainable_variables` in that same order;
                TODO confirm with callers.
            num_checkpoints: How many intermediate layer outputs to keep;
                0 keeps only the input.
            **kwargs: Accepted and ignored.

        Returns:
            The output of the last layer, each leaf passed through
            `tf.identity`.

        Raises:
            ValueError: If `num_checkpoints >= len(model.layers)`.
        """
        if _watch_vars is None:
            _watch_vars = []
        tensor_watches = [tf.convert_to_tensor(v) for v in _watch_vars]
        model, x = args
        # Layer outputs cached during the forward pass, keyed by layer index.
        saved_tensors = {}
        # Checkpointed layer indices in ascending order; -1 stands for the
        # model input `x`. Kept as a plain list so `grad` can copy and pop it.
        idx_ckpt = [-1]
        num_layers = len(model.layers)
        # Naive scheme: distribute checkpoints uniformly across the layers.
        if num_checkpoints:
            if num_checkpoints >= num_layers:
                raise ValueError(
                    "The number of checkpoints is {} and should be less than "
                    "number of layers in the model, which is {} .".format(
                        num_checkpoints, num_layers))
            idx_start, idx_end = 0, num_layers - 1
            # Use offset to avoid checkpointing the start and end layers of the model
            offset = idx_end // num_checkpoints
            start = (idx_start + offset) // 2
            end = (idx_end - offset + idx_end) // 2
            idx_tmp = np.linspace(start, end, num_checkpoints, dtype=np.uint32)
            idx_ckpt += idx_tmp.tolist()

        x = tf.convert_to_tensor(x)
        with tape_lib.stop_recording():
            # Forward pass without recording, caching checkpoint layer outputs.
            result = x
            saved_tensors[-1] = result
            for idx_layer, layer in enumerate(model.layers):
                result = layer(result)
                if idx_layer in idx_ckpt:
                    saved_tensors[idx_layer] = result
            flat_result = [tf.identity(t) for t in nest.flatten(result)]
            output = nest.pack_sequence_as(result, flat_result)

        def grad(*grads_output):
            r"""Backward pass that recomputes activations from the nearest checkpoint.

            Returns the gradients for every layer's trainable variables,
            ordered first layer to last.
            """
            # Work on a copy. Popping from the closed-over list would consume
            # the checkpoints, so a second call would recompute every layer
            # from the input.
            ckpts = list(idx_ckpt)
            grads = []
            for idx_back in reversed(range(num_layers)):
                back_layer = model.layers[idx_back]
                # back_layer needs its *input*, so drop checkpoints at or after
                # idx_back. `while` also copes with duplicated indices. The
                # loop stops at the -1 sentinel because idx_back >= 0.
                while idx_back <= ckpts[-1]:
                    ckpts.pop()
                idx_last_ckpt = ckpts[-1]
                prev_output = saved_tensors[idx_last_ckpt]
                # Recompute the layers between the checkpoint and back_layer.
                for idx_layer in range(idx_last_ckpt + 1, idx_back):
                    prev_output = model.layers[idx_layer](prev_output)
                with tf.GradientTape(watch_accessed_variables=False) as layer_tape:
                    layer_tape.watch(back_layer.trainable_variables)
                    layer_tape.watch(prev_output)
                    recomputed_output = back_layer(prev_output)
                    # identity necessary for grad propagation across 'dead' layers.
                    # The previous per-element comprehension iterated the tensor
                    # along its first axis. That fails for scalars and when not
                    # executing eagerly.
                    recomputed_output = tf.identity(
                        tf.convert_to_tensor(recomputed_output))
                    prev_output = nest.flatten(prev_output)
                    sources = prev_output + back_layer.trainable_variables
                grads_intermediate = layer_tape.gradient(
                    recomputed_output, sources, output_gradients=grads_output)
                # The first entries are the gradients w.r.t. back_layer's input;
                # they become the output gradients of the previous layer.
                grads_output = grads_intermediate[:len(prev_output)]
                grads_vars = grads_intermediate[len(prev_output):]
                grads.extend(grads_vars[::-1])
                del layer_tape
            return grads[::-1]

        tape_lib.record_operation(str(f), flat_result, tensor_watches, grad)

        return output
예제 #22
0
    def inner(*args,
              _checkpoint=False,
              _watch_vars=None,
              _force_seed=False,
              **kwargs):
        """Call `f`, optionally with gradient checkpointing and a fixed seed.

        Args:
            *args: Forwarded to `f`.
            _checkpoint: If truthy, `f` runs without being recorded on the
                tape. A backward function is then recorded that re-runs `f`
                under a fresh `tf.GradientTape` to compute gradients.
            _watch_vars: Extra variables, beyond the float32 tensor inputs,
                that the recorded operation produces gradients for.
            _force_seed: If truthy, `tf.random.set_seed` is called before the
                forward pass and before the recomputation. If it is an
                `Iterator`, the seed is `next(_force_seed)`; otherwise it is
                drawn with `random.randint`.
            **kwargs: Forwarded to `f`.

        Returns:
            Whatever `f(*args, **kwargs)` returns. With `_checkpoint`, each
            flattened leaf is passed through `tf.identity`.
        """
        if _force_seed:
            if isinstance(_force_seed, Iterator):
                seed = next(_force_seed)
            else:
                seed = random.randint(1, 1 << 31)

        if not _checkpoint:
            if _force_seed:
                tf.random.set_seed(seed)
            return f(*args, **kwargs)

        if _watch_vars is None:
            _watch_vars = []

        flat_inputs = nest.flatten(args) + nest.flatten(list(kwargs.values()))
        # Only float32 tensor inputs are differentiated.
        flat_inputs = [
            x for x in flat_inputs if tf.is_tensor(x) and x.dtype == tf.float32
        ]
        # Refs compare by the identity of the wrapped object, so they double as
        # hashable, de-duplicating keys.
        input_refs = set(x.experimental_ref() for x in flat_inputs)
        unique_inputs = [ref.deref() for ref in input_refs]

        # Exclude watched variables that were also passed in as inputs.
        # Previously the filter compared the Reference itself to the inputs
        # with `is`, which never matched. Such a variable was then watched
        # twice and its gradient returned twice.
        unique_vars = [
            ref.deref()
            for ref in set(v.experimental_ref() for v in _watch_vars)
            if ref not in input_refs
        ]

        watches = unique_inputs + unique_vars
        tensor_watches = [tf.convert_to_tensor(x) for x in watches]

        with tape.stop_recording():
            if _force_seed:
                tf.random.set_seed(seed)

            result = f(*args, **kwargs)

            flat_result = nest.flatten(result)
            # Mirrors tf.custom_gradient. Presumably this makes the recorded
            # op's outputs new tensors rather than, e.g., an input returned
            # unchanged.
            flat_result = [tf.identity(x) for x in flat_result]
            output = nest.pack_sequence_as(result, flat_result)

        def grad(*output_grads):
            # Re-run `f` (with the same seed) on a fresh tape and
            # differentiate it w.r.t. the watched inputs/variables.
            with tf.GradientTape() as g:
                g.watch(watches)
                if _force_seed:
                    tf.random.set_seed(seed)
                recomputed_output = f(*args, **kwargs)
                recomputed_output = [
                    tf.identity(x) for x in nest.flatten(recomputed_output)
                ]

            grads = g.gradient(recomputed_output,
                               watches,
                               output_gradients=output_grads)
            del g
            return grads

        tape.record_operation(str(f), flat_result, tensor_watches, grad)

        return output
예제 #23
0
def _graph_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for graph mode.

  Calls `f(*args)`, which must return a `(result, grad_fn)` pair. It then
  passes the flattened `result`, the flattened inputs and the relevant
  variables through a single `IdentityN` op, whose registered gradient calls
  `grad_fn`. The relevant variables are those watched while running `f`, plus
  those returned by `get_dependent_variables` for the inputs and outputs. The
  same gradient function is recorded on the tape.

  Args:
    f: Function returning `(result, grad_fn)`.
    args: Positional inputs to `f` (may be nested); every leaf is converted
      with `ops.convert_to_tensor`.
    kwargs: Keyword arguments; must be empty in graph mode.

  Returns:
    `result` with its flattened leaves replaced by the corresponding
    `IdentityN` outputs.

  Raises:
    ValueError: If `kwargs` is non-empty, or if the graph outputs come from
      more than one graph. Also raised later, when gradients are computed, if
      `grad_fn` does not return one gradient per variable.
    TypeError: If `f` creates a non-resource variable, or if variables were
      used but `grad_fn` does not accept a `variables` argument.
  """
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  args = nest.map_structure(ops.convert_to_tensor, args)

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph. `.ref()` provides hashable set keys.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = {
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  }
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args)
  after_vars = {
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  }
  new_vars = after_vars - before_vars
  new_vars_list = [v.deref() for v in new_vars]
  for v in new_vars_list:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")

  # It is possible for the caller to pass in an input that is from a different
  # graph. Even though this is not valid we filter these out if they are not
  # from the output graph to make it easier for some code to migrate to custom
  # gradients.
  inputs = nest.flatten(args)
  outputs = nest.flatten(result)
  graphs = {getattr(o, "graph", None) for o in outputs}
  # Not all results may be tensors. However, we want to ensure that all outputs
  # are from the same graph and use that to filter the inputs.
  graphs.discard(None)  # Discard non-graph outputs
  if graphs:
    if len(graphs) > 1:
      raise ValueError("All graph outputs should be from the same graph")
    output_graph = graphs.pop()
    filtered_inputs = []
    for i in inputs:
      if i.graph != output_graph:
        logging.warn("%s does not belong to output graph %s", i, output_graph)
      else:
        filtered_inputs.append(i)

    inputs = filtered_inputs

  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables_in_tape = frozenset([
      v.ref() for v in variable_watcher.watched_variables()
  ]) - frozenset(v.ref() for v in inputs)
  variables_in_subgraph = frozenset([
      v.ref()
      for v in get_dependent_variables(input_ops=inputs, output_ops=outputs)
  ])
  # Sort by name so that the order does not depend on frozenset iteration
  # order, which is not stable across runs for identity-hashed keys. This
  # order fixes the IdentityN inputs (and hence the graph) and the
  # `variables` list given to grad_fn.
  variables = sorted(
      [v.deref() for v in variables_in_subgraph.union(variables_in_tape)],
      key=lambda v: v.name)

  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  # A **kwargs catch-all also counts as accepting `variables`.
  variables_in_signature = ("variables" in grad_argspec.args or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                 "no ResourceVariables were used on the forward pass.")
  flat_result = outputs
  flat_result_len = len(flat_result)

  all_tensors = flat_result + inputs + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    # Only the gradients w.r.t. the forward outputs are meaningful to grad_fn.
    result_grads = result_grads[:flat_result_len]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * flat_result_len) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  # Route everything through IdentityN so graph-mode gradients hit `name`.
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:flat_result_len])
예제 #24
0
 def capture_distributed_variable(self, variable, placeholder):
   """Register `placeholder` as the capture of the distributed `variable`.

   The `(variable, placeholder)` pair is stored in `self._captures`, keyed by
   `ops.tensor_id(variable)`. A "captured_value" operation from `variable` to
   `placeholder` is then recorded on the tape. The callback passed to it wraps
   its argument in a list (presumably the backward function, routing the
   placeholder's gradient to `variable`).
   """
   capture_key = ops.tensor_id(variable)
   self._captures[capture_key] = (variable, placeholder)

   def _pass_through(grad):
     return [grad]

   tape.record_operation("captured_value", [placeholder], [variable],
                         _pass_through)
예제 #25
0
def _graph_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for graph mode.

  Calls `f(*args)`, which must return a `(result, grad_fn)` pair. It then
  passes the flattened `result`, the inputs and every variable watched during
  the forward pass through one `IdentityN` op, whose registered gradient calls
  `grad_fn`. The same gradient function is also recorded on the tape.

  Args:
    f: Function returning `(result, grad_fn)`.
    *args: Positional inputs to `f`, each converted with
      `ops.convert_to_tensor`.
    **kwargs: Not supported in graph mode; must be empty.

  Returns:
    `result` with its flattened leaves replaced by the corresponding
    `IdentityN` outputs.

  Raises:
    ValueError: If `kwargs` is non-empty. Also raised later, when gradients
      are computed, if `grad_fn` does not return one gradient per variable.
    TypeError: Raised in any of these cases:
      * `f` creates a variable that is not a resource variable.
      * Variables were used but `grad_fn` does not accept `variables`.
      * `grad_fn` accepts `variables`, none were captured, and the current
        variable scope does not have `use_resource` set.
  """
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  args = [ops.convert_to_tensor(x) for x in args]

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set(current_var_scope.global_variables() +
                    current_var_scope.local_variables())
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args)
  after_vars = set(current_var_scope.global_variables() +
                   current_var_scope.local_variables())
  new_vars = after_vars - before_vars
  for v in new_vars:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs. Sort them by name so
  # that the order does not depend on set iteration order, which varies
  # between runs for identity-hashed objects. This order fixes the IdentityN
  # inputs (and hence the graph) and the `variables` list given to grad_fn.
  variables = sorted(set(tape.watched_variables()) - set(args),
                     key=lambda v: v.name)
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  # A **kwargs catch-all also counts as accepting `variables`.
  variables_in_signature = ("variables" in grad_argspec.args or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    if not variable_scope.get_variable_scope().use_resource:
      raise TypeError("If using @custom_gradient with a function that "
                      "uses variables, the enclosing variable scope must "
                      "have use_resource=True.")
    else:
      logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                   "no ResourceVariables were used on the forward pass.")
  flat_result = nest.flatten(result)
  all_tensors = flat_result + args + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    # Only the gradients w.r.t. the forward outputs are meaningful to grad_fn.
    result_grads = result_grads[:len(flat_result)]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * len(flat_result)) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  # Route everything through IdentityN so graph-mode gradients hit `name`.
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)
  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:len(flat_result)])