Example no. 1
def _forward_gradient(op_name, attr_tuple, inputs, outputs, tangents):
    """Computes a Jacobian-vector product for an op.

  Note that this function would be wasteful if executed eagerly. It runs the
  backward gradient function and throws away the result just to record its
  operations on a GradientTape. These unused ops are pruned away when this
  function is traced.

  Args:
    op_name: A string, the type of operation being executed.
    attr_tuple: Attributes of the operation.
    inputs: A flat list of input Tensors to the operation.
    outputs: A flat list of output Tensors from the operation.
    tangents: A flat list of Tensors, same shape as `inputs`.

  Returns:
    A flat list of tangents corresponding to `outputs`.
  """
    if not outputs:
        # tape.gradients([], inputs) doesn't make much sense
        return []
    trainable_inputs = []
    trainable_indices = []
    nontrivial_tangents = []
    for input_index, tensor in enumerate(inputs):
        if backprop_util.IsTrainable(tensor):
            trainable_inputs.append(tensor)
            trainable_indices.append(input_index)
            nontrivial_tangents.append(tangents[input_index])

    with backprop.GradientTape() as transpose_tape:
        with backprop.GradientTape() as backfunc_tape:
            backfunc_tape.watch(trainable_inputs)
            execute.record_gradient(op_name, inputs, attr_tuple, outputs,
                                    "forward_op_replay")

        forwardprop_aids = []
        trainable_outputs = []
        nontrivial_output_indices = []
        for output_index, output in enumerate(outputs):
            if backprop_util.IsTrainable(output):
                forwardprop_aids.append(
                    array_ops.ones_like(output, name="unused_forwardprop_aid"))
                trainable_outputs.append(output)
                nontrivial_output_indices.append(output_index)

        transpose_tape.watch(forwardprop_aids)
        grads = backfunc_tape.gradient(
            trainable_outputs,
            trainable_inputs,
            forwardprop_aids,
            unconnected_gradients=UnconnectedGradients.ZERO)
    nontrivial_output_tangents = transpose_tape.gradient(
        grads, forwardprop_aids, output_gradients=nontrivial_tangents)
    output_tangents = [None] * len(outputs)
    for index, tangent in zip(nontrivial_output_indices,
                              nontrivial_output_tangents):
        output_tangents[index] = tangent
    return output_tangents
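
For context, the same double-backprop trick can be written against the public `tf.GradientTape` API. The sketch below is illustrative only (the name `jvp_via_double_backprop` and the single-tensor signature are assumptions, not part of the source): the inner tape computes a VJP with a dummy cotangent, and differentiating that VJP with respect to the dummy cotangent on the outer tape transposes it into a JVP.

# Illustrative sketch, not TensorFlow source.
import tensorflow as tf

def jvp_via_double_backprop(f, primal, tangent):
    """Jacobian-vector product of `f` at `primal` in direction `tangent`."""
    with tf.GradientTape() as transpose_tape:
        with tf.GradientTape() as backfunc_tape:
            backfunc_tape.watch(primal)
            output = f(primal)
        # Dummy cotangent; we only need gradients *with respect to* it below.
        aid = tf.ones_like(output)
        transpose_tape.watch(aid)
        # VJP of f at `primal`, recorded on the outer tape.
        vjp = backfunc_tape.gradient(
            output, primal, output_gradients=aid,
            unconnected_gradients=tf.UnconnectedGradients.ZERO)
    # Differentiating the VJP w.r.t. the dummy cotangent yields the JVP.
    return transpose_tape.gradient(vjp, aid, output_gradients=tangent)

x = tf.constant([1.0, 2.0, 3.0])
v = tf.constant([1.0, 0.0, 0.0])
print(jvp_via_double_backprop(tf.sin, x, v))  # elementwise cos(x) * v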
Example no. 2
def _is_trainable(tensor):
  """Returns whether the given tensor is trainable."""
  if not backprop_util.IsTrainable(tensor):
    return False

  # Special case: untrainable accumulator output. The gradients algorithm
  # doesn't know about tensor lists of untrainable elements. In theory the
  # tensor list gradient functions should return None as appropriate, but
  # because we can't return None from the gradient function we filter out
  # untrainable accumulator output here to avoid computing the gradient at all.
  if tensor.op.type == "TensorListPopBack" and tensor.value_index == 0:
    assert tensor.dtype == dtypes.variant
    element_type = tensor.op.get_attr("element_dtype")
    return backprop_util.IsTrainable(element_type)

  return True
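
`backprop_util.IsTrainable` is internal; a rough public-API approximation of the dtype test it performs (minus the variant/resource special case handled above) could look like the following, where `is_probably_trainable_dtype` is a hypothetical name.

# Illustrative sketch, not TensorFlow source.
import tensorflow as tf

def is_probably_trainable_dtype(dtype):
    """Rough approximation: gradients only flow through floating/complex dtypes."""
    base = tf.as_dtype(dtype).base_dtype
    return base.is_floating or base.is_complex

print(is_probably_trainable_dtype(tf.float32))  # True
print(is_probably_trainable_dtype(tf.int32))    # False
print(is_probably_trainable_dtype(tf.variant))  # False (the real helper may differ here)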
Example no. 3
    def watch(self, tensor):
        """Ensures that `tensor` is being traced by this tape.

    Args:
      tensor: a Tensor or list of Tensors.

    Raises:
      ValueError: if it encounters something that is not a tensor.
    """
        for t in nest.flatten(tensor):
            if not (_pywrap_utils.IsTensor(t) or _pywrap_utils.IsVariable(t)):
                raise ValueError(
                    "Passed in object of type {}, not tf.Tensor".format(
                        type(t)))
            if not backprop_util.IsTrainable(t):
                logging.log_first_n(
                    logging.WARN, "The dtype of the watched tensor must be "
                    "floating (e.g. tf.float32), got %r", 5, t.dtype)
            if hasattr(t, "handle"):
                # There are many variable-like objects, and all of them currently
                # have a `handle` attribute that points to a tensor. If this changes,
                # the internals of watch_variable need to change as well.
                tape.watch_variable(self._tape, t)
            else:
                tape.watch(self._tape, t)
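
For reference, a minimal usage sketch of `watch`: plain tensors are not traced automatically (variables are), so they must be watched explicitly before use.

# Illustrative sketch, not TensorFlow source.
import tensorflow as tf

x = tf.constant(3.0)            # a plain tensor, not tracked by default
with tf.GradientTape() as tape:
    tape.watch(x)               # start tracing operations involving x
    y = x * x
print(tape.gradient(y, x))      # tf.Tensor(6.0, ...)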
Example no. 4
def get_flat_tensors_for_gradients(xs):
    """Returns a flat list of Tensors that should be differentiated for `xs`.

  Args:
    xs: A list of `Tensor`s or `CompositeTensor`s.

  Returns:
    A flat list of `Tensor`s constructed from `xs`, where `Tensor` values are
    left as-is, and `CompositeTensor`s are replaced with
    `_get_tensors_for_gradient(x)`.
  """
    # Note: we could just return
    # nest.flatten([_get_tensors_for_gradient(x) for x in xs]), but we
    # manually walk over the results to give better warning messages.
    result = []
    for x in xs:
        if not isinstance(x, composite_tensor.CompositeTensor):
            result.append(x)
        else:
            x_tensors = nest.flatten(_get_tensors_for_gradient(x))
            for t in x_tensors:
                if not backprop_util.IsTrainable(t):
                    logging.log_first_n(
                        logging.WARN,
                        "The dtype of differentiable component %s in %s "
                        "must be floating (e.g., tf.float32), got %r.", 5, t,
                        x, t.dtype)
            result.extend(x_tensors)
    return result
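
To illustrate the idea (assuming a `SparseTensor` as the composite; the internal `_get_tensors_for_gradient` may choose components differently), a composite tensor flattens into several component tensors, and only the floating-point component is a sensible differentiation target:

# Illustrative sketch, not TensorFlow source.
import tensorflow as tf

sp = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2]],
                            values=[1.0, 2.0],
                            dense_shape=[2, 3])
components = tf.nest.flatten(sp, expand_composites=True)
print([c.dtype.name for c in components])         # ['int64', 'float32', 'int64']
print([c.dtype.is_floating for c in components])  # [False, True, False]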
Example no. 5
def _grad_fn(func_graph, grads):
    """The gradient function for each conditional branch.

  This function builds the gradient graph of the corresponding forward-pass
  conditional branch in `func_graph`. This is done by differentiating
  func_graph's outputs w.r.t. its inputs.

  Args:
    func_graph: FuncGraph. The corresponding forward-pass function.
    grads: The list of input gradient Tensors.

  Returns:
    The output gradient Tensors.
  """
    # Filter out untrainable function outputs.
    # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes
    # cause _GradientsHelper to raise an exception (e.g. the implementation
    # doesn't expect 'ys' to contain boolean tensors).
    assert len(func_graph.outputs) == len(grads)
    ys = []
    grad_ys = []
    for y, grad_y in zip(func_graph.outputs, grads):
        if not backprop_util.IsTrainable(y):
            continue
        ys.append(y)
        grad_ys.append(grad_y)

    # Build the gradient graph. Note that this builds the gradient computation of
    # func_graph in the current graph, which requires capturing tensors from
    # func_graph. The captured func_graph tensors are resolved to external tensors
    # in _resolve_grad_inputs.
    result = gradients_util._GradientsHelper(ys,
                                             func_graph.inputs,
                                             grad_ys=grad_ys,
                                             src_graph=func_graph)

    # Functions can't return None; replace Nones with zero tensors.
    # TODO(b/80444525): don't return anything here and make _IfGrad return None if
    # both branches have zero gradient.
    for i in range(len(result)):
        if result[i] is None:
            if func_graph.inputs[i].dtype == dtypes.resource:
                result[i] = array_ops.zeros(
                    gen_resource_variable_ops.variable_shape(
                        func_graph.inputs[i]),
                    dtype=default_gradient.get_zeros_dtype(
                        func_graph.inputs[i]))
            else:
                result[i] = array_ops.zeros_like(func_graph.inputs[i])

    return result
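
The observable behavior this supports, sketched with public APIs (the bool output stands in for an untrainable branch output): differentiating through a conditional only propagates gradients through the floating outputs of the taken branch.

# Illustrative sketch, not TensorFlow source.
import tensorflow as tf

x = tf.constant(2.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y, flag = tf.cond(x > 0.0,
                      lambda: (x * 3.0, tf.constant(True)),
                      lambda: (x * 5.0, tf.constant(False)))
print(tape.gradient(y, x))  # 3.0; the bool `flag` contributes no gradient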
Example no. 6
    def __call__(self, device, token, args):
        """Calls `self._func` in eager mode, recording the tape if needed."""
        use_tape_cache = (self._support_graph_mode_gradient
                          or tape_lib.could_possibly_record())

        if use_tape_cache:
            with backprop.GradientTape() as tape:
                for tensor in args:
                    for t in nest.flatten(tensor):
                        if backprop_util.IsTrainable(t):
                            tape.watch(t)
                outputs = self._call(device, args)
            tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
        else:
            outputs = self._call(device, args)

        return outputs
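
A hedged sketch of the surrounding pattern with public APIs (the cache dict, `run_and_maybe_record`, and the token string are illustrative): open a tape only when a gradient might be requested, watch the floating inputs, and stash the tape with its inputs and outputs for a later backward pass.

# Illustrative sketch, not TensorFlow source.
import tensorflow as tf

_tape_cache = {}  # token -> (tape, args, outputs)

def run_and_maybe_record(token, fn, args, record):
    if not record:
        return fn(*args)
    with tf.GradientTape() as tape:
        for t in tf.nest.flatten(args):
            if t.dtype.is_floating:        # rough stand-in for IsTrainable
                tape.watch(t)
        outputs = fn(*args)
    _tape_cache[token] = (tape, args, outputs)
    return outputs

x = tf.constant(4.0)
y = run_and_maybe_record("call-0", lambda a: a * a, (x,), record=True)
tape, _, _ = _tape_cache["call-0"]
print(tape.gradient(y, x))  # 8.0, computed later from the cached tape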
Example no. 7
def _MarkReachedOps(from_ops, reached_ops, func_graphs):
  """Mark all ops reached from "from_ops".

  Args:
    from_ops: list of Operations.
    reached_ops: set of Operations.
    func_graphs: list of FuncGraphs. This method will traverse through
      these functions if they capture from_ops or any reachable ops.
  """
  queue = collections.deque()
  queue.extend(from_ops)
  while queue:
    op = queue.popleft()
    if op not in reached_ops:
      reached_ops.add(op)
      for output in op.outputs:
        if backprop_util.IsTrainable(output):
          queue.extend(_Consumers(output, func_graphs))
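
The function above is a plain breadth-first reachability walk that only expands through trainable outputs. Stripped of the TensorFlow specifics (names such as `mark_reached`, `consumers`, and `is_expandable` are stand-ins), the pattern is:

# Illustrative sketch, not TensorFlow source.
import collections

def mark_reached(from_nodes, consumers, is_expandable):
    reached = set()
    queue = collections.deque(from_nodes)
    while queue:
        node = queue.popleft()
        if node not in reached:
            reached.add(node)
            if is_expandable(node):          # e.g. "has a trainable output"
                queue.extend(consumers(node))
    return reached

# Toy graph: a -> b -> c, but b is not expandable, so c is never reached.
edges = {"a": ["b"], "b": ["c"], "c": []}
print(sorted(mark_reached(["a"], edges.__getitem__, lambda n: n != "b")))  # ['a', 'b']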
Example no. 8
def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
                                  xs_set):
    """Update pending count for the inputs of op and enqueue ready ops."""
    for x in _NonEagerInputs(op, xs_set):
        pending_count[x.op] -= 1
        ready = (pending_count[x.op] == 0)
        if loop_state and not ready:
            ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(
                x.op)
        if ready:
            if control_flow_util.IsLoopExit(x.op):
                # If x is an exit without a real gradient, defer processing it.
                grad_state = loop_state.GetGradState(x.op, before=False)
                grad_state.deferred_exits.append(x)
                grad_state.pending_exits_count -= 1
                if grad_state.pending_exits_count == 0:
                    # We now have all the exits so process them.
                    has_not_none_grad = False
                    for y in grad_state.deferred_exits:
                        if _HasAnyNotNoneGrads(grads, y.op):
                            has_not_none_grad = True
                            queue.append(y.op)
                        else:
                            grad_state.unused_exits.append(y)
                    if has_not_none_grad:
                        # For an unused exit, if it has trainable outputs, backprop
                        # a zero gradient. Otherwise, just ignore it.
                        for y in grad_state.unused_exits:
                            if backprop_util.IsTrainable(y):
                                _SetGrad(grads, y,
                                         loop_state.ZerosLikeForExit(y))
                            queue.append(y.op)
                    else:
                        # All exits are "unused" so use None as gradient.
                        for y in grad_state.unused_exits:
                            queue.append(y.op)
            else:
                queue.append(x.op)
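
Underneath the loop-state special cases, this is the standard pending-count (reverse topological) schedule: an op's gradient is processed only after all of its consumers have contributed theirs. A minimal sketch of just that scheduling idea, with hypothetical names:

# Illustrative sketch, not TensorFlow source.
import collections

def backward_schedule(inputs_of, start_ops):
    pending = collections.Counter()           # op -> unprocessed consumers
    for op, producers in inputs_of.items():
        for producer in producers:
            pending[producer] += 1
    order, queue = [], collections.deque(start_ops)
    while queue:
        op = queue.popleft()
        order.append(op)
        for producer in inputs_of.get(op, ()):
            pending[producer] -= 1
            if pending[producer] == 0:        # every consumer handled: ready
                queue.append(producer)
    return order

# x feeds a and b, both feed y; backprop starts at y.
print(backward_schedule({"y": ["a", "b"], "a": ["x"], "b": ["x"], "x": []}, ["y"]))
# ['y', 'a', 'b', 'x']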
Example no. 9
def _grad_fn(func_graph, grads):
    """The gradient function for each conditional branch.

  This function builds the gradient graph of the corresponding forward-pass
  conditional branch in `func_graph`. This is done by differentiating
  func_graph's outputs w.r.t. its inputs.

  Args:
    func_graph: FuncGraph. The corresponding forward-pass function.
    grads: The list of input gradient Tensors.

  Returns:
    The output gradient Tensors.
  """
    # Filter out untrainable function outputs.
    # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes
    # cause _GradientsHelper to raise an exception (e.g. the implementation
    # doesn't expect 'ys' to contain boolean tensors).
    assert len(func_graph.outputs) == len(grads)
    ys = []
    grad_ys = []
    for y, grad_y in zip(func_graph.outputs, grads):
        if not backprop_util.IsTrainable(y):
            continue
        ys.append(y)
        grad_ys.append(grad_y)

    # Build the gradient graph. Note that this builds the gradient computation of
    # func_graph in the current graph, which requires capturing tensors from
    # func_graph. The captured func_graph tensors are resolved to external tensors
    # in _resolve_grad_inputs.
    result = gradients_util._GradientsHelper(ys,
                                             func_graph.inputs,
                                             grad_ys=grad_ys,
                                             src_graph=func_graph)

    return result
Example no. 10
    def gradient(self,
                 target,
                 sources,
                 output_gradients=None,
                 unconnected_gradients=UnconnectedGradients.NONE):
        """Computes the gradient using operations recorded in context of this tape.

    Args:
      target: a list or nested structure of Tensors or Variables to be
        differentiated.
      sources: a list or nested structure of Tensors or Variables. `target`
        will be differentiated against elements in `sources`.
      output_gradients: a list of gradients, one for each element of
        target. Defaults to None.
      unconnected_gradients: a value which can either hold 'none' or 'zero' and
        alters the value which will be returned if the target and sources are
        unconnected. The possible values and effects are detailed in
        'UnconnectedGradients' and it defaults to 'none'.

    Returns:
      a list or nested structure of Tensors (or IndexedSlices, or None),
      one for each element in `sources`. Returned structure is the same as
      the structure of `sources`.

    Raises:
      RuntimeError: if called inside the context of the tape, or if called more
       than once on a non-persistent tape.
      ValueError: if the target is a variable or if unconnected gradients is
       called with an unknown value.
    """
        if self._tape is None:
            raise RuntimeError(
                "GradientTape.gradient can only be called once on "
                "non-persistent tapes.")
        if self._recording:
            if not self._persistent:
                self._pop_tape()
            else:
                logging.log_first_n(
                    logging.WARN,
                    "Calling GradientTape.gradient on a persistent "
                    "tape inside its context is significantly less "
                    "efficient than calling it outside the context (it "
                    "causes the gradient ops to be recorded on the "
                    "tape, leading to increased CPU and memory usage). "
                    "Only call GradientTape.gradient inside the "
                    "context if you actually want to trace the "
                    "gradient in order to compute higher order "
                    "derivatives.", 1)

        flat_targets = []
        for t in nest.flatten(target):
            if not backprop_util.IsTrainable(t):
                logging.vlog(
                    logging.WARN, "The dtype of the target tensor must be "
                    "floating (e.g. tf.float32) when calling GradientTape.gradient, "
                    "got %r", t.dtype)
            if resource_variable_ops.is_resource_variable(t):
                with self:
                    t = ops.convert_to_tensor(t)
            flat_targets.append(t)

        flat_sources = nest.flatten(sources)
        flat_sources_raw = flat_sources
        flat_sources = [_handle_or_self(x) for x in flat_sources]
        for t in flat_sources_raw:
            if not backprop_util.IsTrainable(t):
                logging.vlog(
                    logging.WARN, "The dtype of the source tensor must be "
                    "floating (e.g. tf.float32) when calling GradientTape.gradient, "
                    "got %r", t.dtype)

        if output_gradients is not None:
            output_gradients = [
                None if x is None else ops.convert_to_tensor(x)
                for x in nest.flatten(output_gradients)
            ]

        flat_grad = imperative_grad.imperative_grad(
            self._tape,
            flat_targets,
            flat_sources,
            output_gradients=output_gradients,
            sources_raw=flat_sources_raw,
            unconnected_gradients=unconnected_gradients)

        if not self._persistent:
            self._tape = None

        grad = nest.pack_sequence_as(sources, flat_grad)
        return grad
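
A minimal usage sketch of the documented behavior: the result mirrors the structure of `sources`, and `unconnected_gradients` controls what an unconnected source yields.

# Illustrative sketch, not TensorFlow source.
import tensorflow as tf

x = tf.constant(2.0)
z = tf.constant(1.0)                    # never used below, hence "unconnected"
with tf.GradientTape() as tape:
    tape.watch([x, z])
    y = x ** 3
grads = tape.gradient(y, {"x": x, "z": z},
                      unconnected_gradients=tf.UnconnectedGradients.ZERO)
print(grads["x"])  # 12.0  (3 * x**2)
print(grads["z"])  # 0.0 rather than None, because of ZERO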
Example no. 11
def _GradientsHelper(ys,
                     xs,
                     grad_ys=None,
                     name="gradients",
                     colocate_gradients_with_ops=False,
                     gate_gradients=False,
                     aggregation_method=None,
                     stop_gradients=None,
                     unconnected_gradients=UnconnectedGradients.NONE,
                     src_graph=None):
    """Implementation of gradients()."""
    if context.executing_eagerly():
        raise RuntimeError(
            "tf.gradients is not supported when eager execution "
            "is enabled. Use tf.GradientTape instead.")
    if src_graph is None:
        src_graph = ops.get_default_graph()
    try:
        unconnected_gradients = UnconnectedGradients(unconnected_gradients)
    except ValueError:
        raise ValueError("Unknown value for unconnected_gradients: %r" %
                         unconnected_gradients)

    # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
    # ancestor graphs. This is necessary for correctly handling captured values.
    func_graphs = []
    curr_graph = src_graph
    while _IsFunction(curr_graph):
        func_graphs.append(curr_graph)
        if isinstance(curr_graph, FuncGraph):
            curr_graph = curr_graph.outer_graph
        else:
            assert isinstance(curr_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
            curr_graph = curr_graph._outer_graph  # pylint: disable=protected-access

    ys = _AsList(ys)
    xs = _AsList(xs)
    stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
    if grad_ys is None:
        grad_ys = [None] * len(ys)
    else:
        grad_ys = _AsList(grad_ys)

    with ops.name_scope(
            name, "gradients",
            list(ys) + list(xs) + list(stop_gradients) +
            list(grad_ys)) as grad_scope:
        # Get a uid for this call to gradients that can be used to help
        # cluster ops for compilation.
        gradient_uid = ops.get_default_graph().unique_name("uid")
        ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
        xs = [
            x.handle if resource_variable_ops.is_resource_variable(x) else x
            for x in xs
        ]
        xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs,
                                                                name="x",
                                                                as_ref=True)
        xs_set = object_identity.ObjectIdentitySet(xs)
        grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
                                 gradient_uid)

        # The approach we take here is as follows: Create a list of all ops in the
        # subgraph between the ys and xs.  Visit these ops in reverse order of ids
        # to ensure that when we visit an op the gradients w.r.t its outputs have
        # been collected.  Then aggregate these gradients if needed, call the op's
        # gradient function, and add the generated gradients to the gradients for
        # its input.

        # Initialize the pending count for ops in the connected subgraph from ys
        # to the xs.
        to_ops = [t.op for t in ys]
        from_ops = [t.op for t in xs]
        stop_gradient_ops = [t.op for t in stop_gradients]
        reachable_to_ops, pending_count, loop_state = _PendingCount(
            to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs_set)

        # Iterate over the collected ops.
        #
        # grads: op => list of gradients received on each output endpoint of the
        # op.  The gradients for each endpoint are initially collected as a list.
        # When it is time to call the op's gradient function, for each endpoint we
        # aggregate the list of received gradients into an Add() Operation if there
        # is more than one.
        grads = {}

        # Add the initial gradients for the ys.
        for y, grad_y in zip(ys, grad_ys):
            _SetGrad(grads, y, grad_y)

        # Initialize queue with to_ops.
        queue = collections.deque()
        # Add the ops in 'to_ops' into the queue.
        to_ops_set = set()
        for op in to_ops:
            # 'ready' handles the case where one output gradient relies on
            # another output's gradient.
            ready = (pending_count[op] == 0)
            if ready and op not in to_ops_set and op in reachable_to_ops:
                to_ops_set.add(op)
                queue.append(op)

        if loop_state:
            loop_exits = loop_state.ProcessUnusedLoopExits(
                pending_count, to_ops_set)
            for y in loop_exits:
                if backprop_util.IsTrainable(y):
                    _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
                    queue.append(y.op)

        stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set)
        while queue:
            # generate gradient subgraph for op.
            op = queue.popleft()
            with _maybe_colocate_with(op, gradient_uid,
                                      colocate_gradients_with_ops):
                if loop_state:
                    loop_state.EnterGradWhileContext(op, before=True)
                out_grads = _AggregatedGrads(grads, op, gradient_uid,
                                             loop_state, aggregation_method)
                if loop_state:
                    loop_state.ExitGradWhileContext(op, before=True)

                grad_fn = None
                func_call = None
                is_partitioned_call = _IsPartitionedCall(op)
                # pylint: disable=protected-access
                is_func_call = (src_graph._is_function(op.type)
                                or is_partitioned_call)
                # pylint: enable=protected-access
                has_out_grads = any(
                    isinstance(g, ops.Tensor) or g for g in out_grads)
                if has_out_grads and (op not in stop_ops):
                    try:
                        grad_fn = ops.get_gradient_function(op)
                    except LookupError:
                        if is_func_call:
                            if is_partitioned_call:
                                func_call = src_graph._get_function(  # pylint: disable=protected-access
                                    compat.as_bytes(op.get_attr("f").name))
                            else:
                                func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
                            # Note that __defun is not set if the graph is
                            # imported. If it's set, we prefer to access the original
                            # defun.
                            func_call = getattr(op, "__defun", func_call)
                            grad_fn = func_call.python_grad_func
                        else:
                            raise LookupError(
                                "No gradient defined for operation '%s' (op type: %s)"
                                % (op.name, op.type))
                if loop_state:
                    loop_state.EnterGradWhileContext(op, before=False)

                # NOTE(skyewm): We don't support computing gradients w.r.t. a loop
                # variable unless it's within the context of a single iteration (i.e.
                # the gradient is w.r.t. the loop parameter in the body function, not
                # w.r.t. or through the initial value). This means if we're in a while
                # loop context, we should never see a switch node from this context.
                # pylint: disable=protected-access
                if (control_flow_util.IsSwitch(op)
                        and op._control_flow_context is not None
                        and op._control_flow_context.IsWhileContext()
                        and op._control_flow_context ==
                        ops.get_default_graph()._get_control_flow_context()):
                    _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set)
                # pylint: enable=protected-access

                if (grad_fn or is_func_call) and has_out_grads:
                    # NOTE: If _AggregatedGrads didn't compute a value for the i'th
                    # output, it means that the cost does not depend on output[i],
                    # therefore dC/doutput[i] is 0.
                    for i, out_grad in enumerate(out_grads):
                        if (not isinstance(out_grad, ops.Tensor)
                                and not out_grad) and (
                                    (not grad_fn and is_func_call) or
                                    backprop_util.IsTrainable(op.outputs[i])):
                            # Only trainable outputs or outputs for a function call that
                            # will use SymbolicGradient get a zero gradient. Gradient
                            # functions should ignore the gradient for other outputs.
                            # TODO(apassos) gradients of resource handles might be an
                            # issue here because of zeros.
                            if loop_state:
                                out_grads[i] = loop_state.ZerosLike(op, i)
                            elif default_gradient.supports_default_grad(
                                    op.outputs[i]):
                                # TODO(b/143286622): The supports_default_grad check is needed
                                # because While op emits non-differentiable resource tensors
                                # as outputs. Remove this check when that is not the case.
                                out_grads[i] = (
                                    control_flow_state.ZerosLikeOutsideLoop(
                                        op, i))
                    with ops.name_scope(op.name + "_grad"):
                        # pylint: disable=protected-access
                        with src_graph._original_op(op):
                            # pylint: enable=protected-access
                            if grad_fn:
                                # If grad_fn was found, do not use SymbolicGradient even for
                                # functions.
                                in_grads = _MaybeCompile(
                                    grad_scope, op, func_call,
                                    lambda: grad_fn(op, *out_grads))
                            else:
                                # For function call ops, we add a 'SymbolicGradient'
                                # node to the graph to compute gradients.
                                in_grads = _MaybeCompile(
                                    grad_scope, op, func_call,
                                    lambda: _SymGrad(op, out_grads))
                            in_grads = _AsList(in_grads)
                            _VerifyGeneratedGradients(in_grads, op)
                            if gate_gradients and len(
                                [x for x in in_grads if x is not None]) > 1:
                                with ops.device(None):
                                    with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
                                            None,
                                            gradient_uid,
                                            ignore_existing=True):
                                        in_grads = control_flow_ops.tuple(
                                            in_grads)
                    _LogOpGradients(op, out_grads, in_grads)
                else:
                    # If no grad_fn is defined or none of out_grads is available,
                    # just propagate a list of None backwards.
                    in_grads = [None] * len(_Inputs(op, xs_set))
                # Note: we don't filter out eager inputs here because the inputs need to
                # line up with in_grads.
                for i, (t_in, in_grad) in enumerate(
                        zip(_Inputs(op, xs_set), in_grads)):
                    if in_grad is not None:
                        if (isinstance(in_grad, ops.Tensor)
                                and t_in.dtype != dtypes.resource):
                            try:
                                in_grad.set_shape(t_in.get_shape())
                            except ValueError:
                                raise ValueError(
                                    "Incompatible shapes between op input and calculated "
                                    "input gradient.  Forward operation: %s.  Input index: %d. "
                                    "Original input shape: %s.  "
                                    "Calculated input gradient shape: %s" %
                                    (op.name, i, t_in.shape, in_grad.shape))
                        if not isinstance(t_in, ops.EagerTensor):
                            _SetGrad(grads, t_in, in_grad)
                if loop_state:
                    loop_state.ExitGradWhileContext(op, before=False)

            # Update pending count for the inputs of op and enqueue ready ops.
            _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count,
                                          loop_state, xs_set)

    if loop_state:
        loop_state.PostProcessing()
    return [_GetGrad(grads, x, unconnected_gradients) for x in xs]
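
`_GradientsHelper` backs the symbolic `tf.gradients` API, which (as the RuntimeError above enforces) only works while building a graph. A minimal usage sketch inside `tf.function`:

# Illustrative sketch, not TensorFlow source.
import tensorflow as tf

@tf.function
def graph_mode_grads():
    x = tf.constant(3.0)
    z = tf.constant(1.0)                # not used by y
    y = x * x
    # tf.gradients requires a graph context; in eager code use tf.GradientTape.
    return tf.gradients(y, [x, z],
                        unconnected_gradients=tf.UnconnectedGradients.ZERO)

dx, dz = graph_mode_grads()
print(dx.numpy(), dz.numpy())           # 6.0 0.0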
Example no. 12
def _IsBackpropagatable(tensor):
    if backprop_util.IsTrainable(tensor):
        return True
    dtype = dtypes.as_dtype(tensor.dtype)
    return dtype.base_dtype == dtypes.bfloat16
Example no. 13
  def gradient(self,
               target,
               sources,
               output_gradients=None,
               unconnected_gradients=UnconnectedGradients.NONE):
    """Computes the gradient using operations recorded in context of this tape.

    Note: Unless you set `persistent=True` a GradientTape can only be used to
    compute one set of gradients (or jacobians).

    Args:
      target: a list or nested structure of Tensors or Variables to be
        differentiated.
      sources: a list or nested structure of Tensors or Variables. `target`
        will be differentiated against elements in `sources`.
      output_gradients: a list of gradients, one for each element of
        target. Defaults to None.
      unconnected_gradients: a value which can either hold 'none' or 'zero' and
        alters the value which will be returned if the target and sources are
        unconnected. The possible values and effects are detailed in
        'UnconnectedGradients' and it defaults to 'none'.

    Returns:
      a list or nested structure of Tensors (or IndexedSlices, or None),
      one for each element in `sources`. Returned structure is the same as
      the structure of `sources`.

    Raises:
      RuntimeError: If called on a used, non-persistent tape.
      RuntimeError: If called inside the context of the tape.
      ValueError: If the target is a variable or if unconnected gradients is
       called with an unknown value.
    """
    if self._tape is None:
      raise RuntimeError("A non-persistent GradientTape can only be used to "
                         "compute one set of gradients (or jacobians)")
    if self._recording:
      if not self._persistent:
        self._pop_tape()
      else:
        logging.log_first_n(
            logging.WARN, "Calling GradientTape.gradient on a persistent "
            "tape inside its context is significantly less "
            "efficient than calling it outside the context (it "
            "causes the gradient ops to be recorded on the "
            "tape, leading to increased CPU and memory usage). "
            "Only call GradientTape.gradient inside the "
            "context if you actually want to trace the "
            "gradient in order to compute higher order "
            "derivatives.", 1)

    num_ndarrays = 0
    flat_targets = []
    for t in nest.flatten(target):
      if not backprop_util.IsTrainable(t):
        logging.vlog(
            logging.WARN, "The dtype of the target tensor must be "
            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
            "got %r", t.dtype)
      if resource_variable_ops.is_resource_variable(t):
        with self:
          t = ops.convert_to_tensor(t)
      elif isinstance(t, np_arrays.ndarray):
        t = t.data
        num_ndarrays += 1
      flat_targets.append(t)
    # Only rewrap if all targets are ndarray. If not, prefer tensors.
    rewrap_as_ndarray = num_ndarrays == len(flat_targets)

    flat_sources = nest.flatten(sources)
    flat_sources_raw = flat_sources
    flat_sources = [_handle_or_self(x) for x in flat_sources]
    for t in flat_sources_raw:
      if not backprop_util.IsTrainable(t):
        logging.vlog(
            logging.WARN, "The dtype of the source tensor must be "
            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
            "got %r", t.dtype)
      if getattr(t, "is_packed", False):
        raise ValueError(
            "GradientTape.gradient is not supported on packed EagerTensors yet."
        )

    if output_gradients is not None:
      output_gradients = [None if x is None else ops.convert_to_tensor(x)
                          for x in nest.flatten(output_gradients)]

    flat_grad = imperative_grad.imperative_grad(
        self._tape,
        flat_targets,
        flat_sources,
        output_gradients=output_gradients,
        sources_raw=flat_sources_raw,
        unconnected_gradients=unconnected_gradients)

    if not self._persistent:
      # Keep track of watched variables before setting tape to None
      self._watched_variables = self._tape.watched_variables()
      self._tape = None

    if rewrap_as_ndarray:
      def _tensor_to_ndarray(x):
        if x is not None:
          return np_arrays.tensor_to_ndarray(x)
        return None
      flat_grad = nest.map_structure(_tensor_to_ndarray, flat_grad)

    grad = nest.pack_sequence_as(sources, flat_grad)
    return grad
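
The basic `gradient` call is shown after Example no. 10; this revision additionally records the watched variables before releasing the tape. A short sketch of the related persistent-tape behavior (the guard above rejects a second call only on non-persistent tapes):

# Illustrative sketch, not TensorFlow source.
import tensorflow as tf

v = tf.Variable(1.5)                    # variables are watched automatically
with tf.GradientTape(persistent=True) as tape:
    y = v * v
    z = y * y
print(tape.gradient(y, v))              # 3.0
print(tape.gradient(z, v))              # 13.5  (4 * v**3); a second call is allowed
print([w.name for w in tape.watched_variables()])  # e.g. ['Variable:0']
del tape                                # release resources held by a persistent tape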
Example no. 14
def _jvp_helper(op_name, attr_tuple, inputs, outputs, tangents):
    """Computes a Jacobian-vector product for an op.

  Note that this function would be wasteful if executed eagerly. It runs the
  backward gradient function and throws away the result just to record its
  operations on a GradientTape. These unused ops are pruned away when this
  function is traced.

  Args:
    op_name: A string, the type of operation being executed.
    attr_tuple: Attributes of the operation.
    inputs: A flat list of input Tensors to the operation.
    outputs: A flat list of output Tensors from the operation.
    tangents: A flat list of Tensors, same shape as `inputs`.

  Returns:
    A flat list of tangents corresponding to `outputs`.
  """
    with _TRACE_COUNT_CONSISTENCY_LOCK:
        # Just make sure writes don't clobber each other's increments; reads in
        # _jvp_dispatch do not lock.
        _TRACE_COUNT[op_name] = _TRACE_COUNT.get(op_name, 0) + 1

    special_case = _SPECIAL_CASES.get(op_name, None)
    if special_case is not None:
        return special_case(attr_tuple, inputs, outputs, tangents)
    if not outputs:
        # tape.gradients([], inputs) doesn't make much sense
        return []
    # Generally inner GradientTapes won't function while outer accumulators are
    # recording. We temporarily reset forwardprop state to allow GradientTapes to
    # function here.
    with forwardprop_util.push_forwardprop_state():
        trainable_inputs = []
        trainable_indices = []
        nontrivial_tangents = []
        for input_index, tensor in enumerate(inputs):
            if backprop_util.IsTrainable(tensor):
                trainable_inputs.append(tensor)
                trainable_indices.append(input_index)
                nontrivial_tangents.append(tangents[input_index])

        with backprop.GradientTape() as transpose_tape:
            with backprop.GradientTape() as backfunc_tape:
                backfunc_tape.watch(trainable_inputs)
                execute.record_gradient(op_name, inputs, attr_tuple, outputs)

            forwardprop_aids = []
            trainable_outputs = []
            nontrivial_output_indices = []
            for output_index, output in enumerate(outputs):
                if backprop_util.IsTrainable(output):
                    forwardprop_aids.append(
                        array_ops.ones_like(output,
                                            name="unused_forwardprop_aid"))
                    trainable_outputs.append(output)
                    nontrivial_output_indices.append(output_index)

            transpose_tape.watch(forwardprop_aids)
            grads = backfunc_tape.gradient(
                trainable_outputs,
                trainable_inputs,
                forwardprop_aids,
                unconnected_gradients=UnconnectedGradients.ZERO)
        nontrivial_output_tangents = transpose_tape.gradient(
            grads, forwardprop_aids, output_gradients=nontrivial_tangents)
        output_tangents = [None] * len(outputs)
        for index, tangent in zip(nontrivial_output_indices,
                                  nontrivial_output_tangents):
            output_tangents[index] = tangent
        return output_tangents
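
For comparison, the public forward-mode entry point that exercises this kind of Jacobian-vector product is `tf.autodiff.ForwardAccumulator`:

# Illustrative sketch, not TensorFlow source.
import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])
v = tf.constant([1.0, 0.0, 0.0])        # direction of differentiation
with tf.autodiff.ForwardAccumulator(primals=x, tangents=v) as acc:
    y = tf.reduce_sum(tf.sin(x))
# Tangents are pushed forward alongside the computation; no backward pass needed.
print(acc.jvp(y))                       # ~cos(1.0) ~= 0.5403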