def jvp(self, primals, unconnected_gradients=UnconnectedGradients.NONE):
        """Fetches the Jacobian-vector product computed for `primals`.

        Note that this method performs no computation; it simply looks up a JVP
        that was already computed (unlike backprop using a `tf.GradientTape`,
        where the computation happens on the call to `tape.gradient`).

        Args:
          primals: A watched Tensor or structure of Tensors to fetch the JVPs
            for.
          unconnected_gradients: Either 'none' or 'zero', controlling the value
            returned when no JVP was computed for `primals`. The possible
            values and their effects are detailed in `tf.UnconnectedGradients`;
            defaults to 'none'.

        Returns:
          Tensors with the same shapes and dtypes as `primals`, or None if no
          JVP is available.
        """
        unconnected_gradients = UnconnectedGradients(unconnected_gradients)
        if self._accumulator is None:
            raise ValueError("Called jvp() without first tracing anything.")

        def _fetch_jvp(tensor):
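            # Variables (and other resources) expose a `handle`; look up the
            # JVP recorded for the underlying resource tensor.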
            if hasattr(tensor, "handle"):
                tensor = ops.convert_to_tensor(tensor.handle)
            result = pywrap_tfe.TFE_Py_ForwardAccumulatorJVP(
                self._accumulator, tensor)
            if result is None and unconnected_gradients == UnconnectedGradients.ZERO:
                return array_ops.zeros_like(tensor)
            return result

        return nest.map_structure(_fetch_jvp, primals)
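# --- Usage sketch (not from the original example) -----------------------------
# A minimal, hedged illustration assuming this jvp() method belongs to
# tf.autodiff.ForwardAccumulator. It shows that jvp() only looks up JVPs that
# were already computed while tracing, and how unconnected_gradients changes
# the fallback for tensors with no recorded JVP.
import tensorflow as tf

x = tf.constant(3.0)
with tf.autodiff.ForwardAccumulator(primals=x, tangents=tf.constant(1.0)) as acc:
    y = x * x                # the JVP of y is pushed forward during this trace
    z = tf.constant(5.0)     # never connected to the watched primal x

print(acc.jvp(y))            # tf.Tensor(6.0, ...): dy/dx * tangent = 2 * 3 * 1
print(acc.jvp(z))            # None: no JVP was recorded for z
print(acc.jvp(z, unconnected_gradients=tf.UnconnectedGradients.ZERO))  # 0.0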
Example #2
def imperative_grad(tape,
                    target,
                    sources,
                    output_gradients=None,
                    sources_raw=None,
                    unconnected_gradients=UnconnectedGradients.NONE):
  """Computes gradients from the imperatively defined tape on top of the stack.

  Works by filtering the tape, counting how many downstream usages each tensor
  and entry has, and repeatedly applying backward functions until we have
  gradients for all sources.

  Args:
   tape: the gradient tape which stores the trace.
   target: either a Tensor or list of Tensors to be differentiated.
   sources: list of Tensors for which we want gradients.
   output_gradients: if not None, a list of gradients, one provided for each
    target, or None if we are to use the target's computed downstream gradient.
   sources_raw: if not None, a list of the source python objects from which the
    sources were generated. Should have the same length as sources. Only needs
    to be populated if unconnected_gradients is 'zero'.
   unconnected_gradients: determines the value returned if the target and
    sources are unconnected. When 'none' the value returned is None, whereas
    when 'zero' a zero tensor with the same shape as the sources is returned.

  Returns:
   the gradient wrt each of the sources.

  Raises:
    ValueError: if the arguments are invalid.
    RuntimeError: if something goes wrong.
  """
  try:
    unconnected_gradients = UnconnectedGradients(unconnected_gradients)
  except ValueError:
    raise ValueError(
        "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

  return pywrap_tfe.TFE_Py_TapeGradient(
      tape._tape,  # pylint: disable=protected-access
      target,
      sources,
      output_gradients,
      sources_raw,
      compat.as_str(unconnected_gradients.value))
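# --- Usage sketch (not from the original example) -----------------------------
# A minimal, hedged illustration: imperative_grad is the internal routine that
# tf.GradientTape.gradient ultimately calls, so the same unconnected_gradients
# behavior is visible through the public tape API.
import tensorflow as tf

x = tf.constant(2.0)
unused = tf.constant(7.0)    # watched, but never used to compute y
with tf.GradientTape() as tape:
    tape.watch([x, unused])
    y = x * x

grads = tape.gradient(y, [x, unused],
                      unconnected_gradients=tf.UnconnectedGradients.ZERO)
print(grads)   # [4.0, 0.0]: the unconnected source yields zeros instead of None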
Example #3
def _GradientsHelper(ys,
                     xs,
                     grad_ys=None,
                     name="gradients",
                     colocate_gradients_with_ops=False,
                     gate_gradients=False,
                     aggregation_method=None,
                     stop_gradients=None,
                     unconnected_gradients=UnconnectedGradients.NONE,
                     src_graph=None):
    """Implementation of gradients()."""
    if context.executing_eagerly():
        raise RuntimeError(
            "tf.gradients is not supported when eager execution "
            "is enabled. Use tf.GradientTape instead.")
    if src_graph is None:
        src_graph = ops.get_default_graph()
    try:
        unconnected_gradients = UnconnectedGradients(unconnected_gradients)
    except ValueError:
        raise ValueError("Unknown value for unconnected_gradients: %r" %
                         unconnected_gradients)

    # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
    # ancestor graphs. This is necessary for correctly handling captured values.
    func_graphs = []
    curr_graph = src_graph
    while _IsFunction(curr_graph):
        func_graphs.append(curr_graph)
        if isinstance(curr_graph, FuncGraph):
            curr_graph = curr_graph.outer_graph
        else:
            assert isinstance(curr_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
            curr_graph = curr_graph._outer_graph  # pylint: disable=protected-access

    ys = _AsList(ys)
    xs = _AsList(xs)
    stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
    if grad_ys is None:
        grad_ys = [None] * len(ys)
    else:
        grad_ys = _AsList(grad_ys)

    with ops.name_scope(
            name, "gradients",
            list(ys) + list(xs) + list(stop_gradients) +
            list(grad_ys)) as grad_scope:
        # Get a uid for this call to gradients that can be used to help
        # cluster ops for compilation.
        gradient_uid = ops.get_default_graph().unique_name("uid")
        ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
        xs = [
            x.handle if resource_variable_ops.is_resource_variable(x) else x
            for x in xs
        ]
        xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs,
                                                                name="x",
                                                                as_ref=True)
        xs_set = object_identity.ObjectIdentitySet(xs)
        grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
                                 gradient_uid)

        # The approach we take here is as follows: Create a list of all ops in the
        # subgraph between the ys and xs.  Visit these ops in reverse order of ids
        # to ensure that when we visit an op the gradients w.r.t its outputs have
        # been collected.  Then aggregate these gradients if needed, call the op's
        # gradient function, and add the generated gradients to the gradients for
        # its input.

        # Initialize the pending count for ops in the connected subgraph from ys
        # to the xs.
        to_ops = [t.op for t in ys]
        from_ops = [t.op for t in xs]
        stop_gradient_ops = [t.op for t in stop_gradients]
        reachable_to_ops, pending_count, loop_state = _PendingCount(
            to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs_set)

        # Iterate over the collected ops.
        #
        # grads: op => list of gradients received on each output endpoint of the
        # op.  The gradients for each endpoint are initially collected as a list.
        # When it is time to call the op's gradient function, for each endpoint we
        # aggregate the list of received gradients into an Add() operation if there
        # is more than one.
        grads = {}

        # Add the initial gradients for the ys.
        for y, grad_y in zip(ys, grad_ys):
            _SetGrad(grads, y, grad_y)

        # Initialize queue with to_ops.
        queue = collections.deque()
        # Add the ops in 'to_ops' into the queue.
        to_ops_set = set()
        for op in to_ops:
            # 'ready' handles the case where one output gradient relies on
            # another output's gradient.
            ready = (pending_count[op] == 0)
            if ready and op not in to_ops_set and op in reachable_to_ops:
                to_ops_set.add(op)
                queue.append(op)

        if loop_state:
            loop_exits = loop_state.ProcessUnusedLoopExits(
                pending_count, to_ops_set)
            for y in loop_exits:
                if backprop_util.IsTrainable(y):
                    _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
                    queue.append(y.op)

        stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set)
        while queue:
            # generate gradient subgraph for op.
            op = queue.popleft()
            with _maybe_colocate_with(op, gradient_uid,
                                      colocate_gradients_with_ops):
                if loop_state:
                    loop_state.EnterGradWhileContext(op, before=True)
                out_grads = _AggregatedGrads(grads, op, gradient_uid,
                                             loop_state, aggregation_method)
                if loop_state:
                    loop_state.ExitGradWhileContext(op, before=True)

                grad_fn = None
                func_call = None
                is_partitioned_call = _IsPartitionedCall(op)
                # pylint: disable=protected-access
                is_func_call = (src_graph._is_function(op.type)
                                or is_partitioned_call)
                # pylint: enable=protected-access
                has_out_grads = any(
                    isinstance(g, ops.Tensor) or g for g in out_grads)
                if has_out_grads and (op not in stop_ops):
                    try:
                        grad_fn = ops.get_gradient_function(op)
                    except LookupError:
                        if is_func_call:
                            if is_partitioned_call:
                                func_call = src_graph._get_function(  # pylint: disable=protected-access
                                    compat.as_bytes(op.get_attr("f").name))
                            else:
                                func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
                            # Note that __defun is not set if the graph is
                            # imported. If it's set, we prefer to access the original
                            # defun.
                            func_call = getattr(op, "__defun", func_call)
                            grad_fn = func_call.python_grad_func
                        else:
                            raise LookupError(
                                "No gradient defined for operation '%s' (op type: %s)"
                                % (op.name, op.type))
                if loop_state:
                    loop_state.EnterGradWhileContext(op, before=False)

                # NOTE(skyewm): We don't support computing gradients wrt a loop variable
                # unless it's within the context of a single iteration (i.e. the
                # gradient is wrt the loop parameter in the body function, not wrt or
                # through the initial value). This means if we're in a while loop
                # context, we should never see a switch node from this context.
                # pylint: disable=protected-access
                if (control_flow_util.IsSwitch(op)
                        and op._control_flow_context is not None
                        and op._control_flow_context.IsWhileContext()
                        and op._control_flow_context ==
                        ops.get_default_graph()._get_control_flow_context()):
                    _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set)
                # pylint: enable=protected-access

                if (grad_fn or is_func_call) and has_out_grads:
                    # NOTE: If _AggregatedGrads didn't compute a value for the i'th
                    # output, it means that the cost does not depend on output[i],
                    # therefore dC/doutput[i] is 0.
                    for i, out_grad in enumerate(out_grads):
                        if (not isinstance(out_grad, ops.Tensor)
                                and not out_grad) and (
                                    (not grad_fn and is_func_call) or
                                    backprop_util.IsTrainable(op.outputs[i])):
                            # Only trainable outputs or outputs for a function call that
                            # will use SymbolicGradient get a zero gradient. Gradient
                            # functions should ignore the gradient for other outputs.
                            # TODO(apassos) gradients of resource handles might be an
                            # issue here because of zeros.
                            if loop_state:
                                out_grads[i] = loop_state.ZerosLike(op, i)
                            elif default_gradient.supports_default_grad(
                                    op.outputs[i]):
                                # TODO(b/143286622): The supports_default_grad check is needed
                                # because While op emits non-differentiable resource tensors
                                # as outputs. Remove this check when that is not the case.
                                out_grads[i] = control_flow_state.ZerosLikeOutsideLoop(
                                    op, i)
                    with ops.name_scope(op.name + "_grad"):
                        # pylint: disable=protected-access
                        with src_graph._original_op(op):
                            # pylint: enable=protected-access
                            if grad_fn:
                                # If grad_fn was found, do not use SymbolicGradient even for
                                # functions.
                                in_grads = _MaybeCompile(
                                    grad_scope, op, func_call,
                                    lambda: grad_fn(op, *out_grads))
                            else:
                                # For function call ops, we add a 'SymbolicGradient'
                                # node to the graph to compute gradients.
                                in_grads = _MaybeCompile(
                                    grad_scope, op, func_call,
                                    lambda: _SymGrad(op, out_grads))
                            in_grads = _AsList(in_grads)
                            _VerifyGeneratedGradients(in_grads, op)
                            if gate_gradients and len(
                                [x for x in in_grads if x is not None]) > 1:
                                with ops.device(None):
                                    with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
                                            None,
                                            gradient_uid,
                                            ignore_existing=True):
                                        in_grads = control_flow_ops.tuple(
                                            in_grads)
                    _LogOpGradients(op, out_grads, in_grads)
                else:
                    # If no grad_fn is defined or none of out_grads is available,
                    # just propagate a list of None backwards.
                    in_grads = [None] * len(_Inputs(op, xs_set))
                # Note: we don't filter out eager inputs here because the inputs need to
                # line up with in_grads.
                for i, (t_in, in_grad) in enumerate(
                        zip(_Inputs(op, xs_set), in_grads)):
                    if in_grad is not None:
                        if (isinstance(in_grad, ops.Tensor)
                                and t_in.dtype != dtypes.resource):
                            try:
                                in_grad.set_shape(t_in.get_shape())
                            except ValueError:
                                raise ValueError(
                                    "Incompatible shapes between op input and calculated "
                                    "input gradient.  Forward operation: %s.  Input index: %d. "
                                    "Original input shape: %s.  "
                                    "Calculated input gradient shape: %s" %
                                    (op.name, i, t_in.shape, in_grad.shape))
                        if not isinstance(t_in, ops.EagerTensor):
                            _SetGrad(grads, t_in, in_grad)
                if loop_state:
                    loop_state.ExitGradWhileContext(op, before=False)

            # Update pending count for the inputs of op and enqueue ready ops.
            _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count,
                                          loop_state, xs_set)

    if loop_state:
        loop_state.PostProcessing()
    return [_GetGrad(grads, x, unconnected_gradients) for x in xs]
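# --- Usage sketch (not from the original example) -----------------------------
# A minimal, hedged illustration: _GradientsHelper backs tf.gradients, which
# (as the RuntimeError above notes) is not supported in eager mode, so the
# call is wrapped in tf.function here.
import tensorflow as tf

@tf.function
def square_grads(x):
    y = x * x
    z = tf.constant(1.0)     # unconnected to x
    # unconnected_gradients='zero' yields a zero tensor for z instead of None.
    return tf.gradients([y], [x, z], unconnected_gradients='zero')

print(square_grads(tf.constant(3.0)))   # [6.0, 0.0]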