Example #1
def _SetGrad(grads, t, grad):
  """Sets gradient "grad" in "grads" for tensor "t"."""
  op = t.op
  op_grads = grads.get(op)
  if not op_grads:
    op_grads = [[] for _ in xrange(len(op.outputs))]
    grads[op] = op_grads
  t_grads = op_grads[t.value_index]
  if isinstance(t_grads, list):
    t_grads.append(grad)
  else:
    assert control_flow_ops.IsLoopSwitch(op)
    op_grads[t.value_index] = grad
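For intuition, here is a minimal plain-Python sketch (no TensorFlow; the fake op class and the gradient strings are hypothetical) of the `grads` layout that `_SetGrad` maintains: one bucket per op output, with each bucket collecting the gradient terms contributed by that output's consumers.

class FakeOp(object):
  """Hypothetical stand-in for an op with a fixed number of outputs."""

  def __init__(self, num_outputs):
    self.outputs = list(range(num_outputs))

op = FakeOp(num_outputs=2)
grads = {}
# One empty bucket per output, mirroring [[] for _ in xrange(len(op.outputs))].
grads[op] = [[] for _ in op.outputs]
# Two consumers of output 0 each contribute a gradient term.
grads[op][0].append("grad_from_consumer_a")
grads[op][0].append("grad_from_consumer_b")
print(grads[op])  # [['grad_from_consumer_a', 'grad_from_consumer_b'], []]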
Example #2
def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state):
    """Update pending count for the inputs of op and enqueue ready ops."""
    for x in op.inputs:
        # pylint: disable=protected-access
        pending_count[x.op._id] -= 1
        ready = (pending_count[x.op._id] == 0)
        if loop_state and not ready:
            ready = (pending_count[x.op._id] > 0
                     and control_flow_ops.IsLoopSwitch(x.op))
        # pylint: enable=protected-access
        if ready:
            if control_flow_ops.IsLoopExit(x.op):
                # If x is an exit without a real gradient, defer processing it.
                grad_state = loop_state.GetGradState(x.op, before=False)
                grad_state.deferred_exits.append(x)
                grad_state.pending_exits_count -= 1
                if grad_state.pending_exits_count == 0:
                    # We now have all the exits so process them.
                    has_real_grad = False
                    for y in grad_state.deferred_exits:
                        if _HasAnyNotNoneGrads(grads, y.op):
                            has_real_grad = True
                            queue.append(y.op)
                        else:
                            grad_state.unused_exits.append(y)
                    if has_real_grad:
                        # For an unused exit, if it has floating-point outputs, backprop
                        # a zero gradient. Otherwise, just ignore it.
                        for y in grad_state.unused_exits:
                            if _IsTrainable(y):
                                _SetGrad(grads, y,
                                         loop_state.ZerosLikeForExit(y))
                            queue.append(y.op)
                    else:
                        # All exits are "unused" so use None as gradient.
                        for y in grad_state.unused_exits:
                            queue.append(y.op)
            else:
                queue.append(x.op)
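Ignoring the while-loop special cases, the pending-count bookkeeping above is an ordinary reverse-topological traversal: an op becomes ready once the gradients for all of its consumers have been processed. A minimal standalone sketch over a hypothetical mini-graph (the op names are illustrative, not TensorFlow ops):

import collections

# Hypothetical graph: op name -> names of its input ops.
op_inputs = {"loss": ["matmul"], "matmul": ["x", "w"], "x": [], "w": []}

# pending_count[op] = number of consumers whose gradient is not yet processed.
pending_count = collections.Counter()
for op, inputs in op_inputs.items():
  for x in inputs:
    pending_count[x] += 1

queue = collections.deque(op for op in op_inputs if pending_count[op] == 0)
visit_order = []
while queue:
  op = queue.popleft()
  visit_order.append(op)
  for x in op_inputs[op]:
    pending_count[x] -= 1
    if pending_count[x] == 0:  # all consumers of x have been visited
      queue.append(x)

print(visit_order)  # ['loss', 'matmul', 'x', 'w']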
Example #3
def _AggregatedGrads(grads, op, loop_state, aggregation_method=None):
  """Get the aggregated gradients for op.

  Args:
    grads: The map of memoized gradients.
    op: The op to get gradients for.
    loop_state: An object for maintaining the state of the while loops in the
                graph. It is of type ControlFlowState. None if the graph
                contains no while loops.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.

  Returns:
    A list of gradients, one per output of `op`. If the gradients for a
      particular output are a list, this function aggregates them before
      returning.

  Raises:
    TypeError: if the incoming grads are not Tensors or IndexedSlices.
    ValueError: if the arguments are invalid.

  """
  if aggregation_method is None:
    aggregation_method = AggregationMethod.DEFAULT
  if aggregation_method not in [
      AggregationMethod.ADD_N, AggregationMethod.EXPERIMENTAL_TREE,
      AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
  ]:
    raise ValueError("Invalid aggregation_method specified %s." %
                     aggregation_method)
  out_grads = _GetGrads(grads, op)
  for i, out_grad in enumerate(out_grads):
    if loop_state:
      if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
        assert control_flow_ops.IsLoopSwitch(op)
        continue
    # Grads have to be Tensors or IndexedSlices
    if (isinstance(out_grad, collections.Sequence) and not all([
        isinstance(g, (ops.Tensor, ops.IndexedSlices)) for g in out_grad
        if g is not None
    ])):
      raise TypeError("gradients have to be either all Tensors "
                      "or all IndexedSlices")
    # Aggregate multiple gradients, and convert [] to None.
    if out_grad:
      if len(out_grad) < 2:
        used = "nop"
        out_grads[i] = out_grad[0]
      elif all([isinstance(g, ops.Tensor) for g in out_grad if g is not None]):
        tensor_shape = _AccumulatorShape(out_grad)
        if (aggregation_method == AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
            and len(out_grad) > 2 and tensor_shape.is_fully_defined()):
          # The benefit of using AccumulateN is that its inputs can be combined
          # in any order and this can allow the expression to be evaluated with
          # a smaller memory footprint.  When used with gpu_allocator_retry,
          # it is possible to compute a sum of terms which are much larger than
          # total GPU memory.
          # AccumulateN can currently only be used if we know the shape for
          # an accumulator variable.  If this is not known, or if we only have
          # 2 grads then we fall through to the "tree" case below.
          used = "accumulate_n"
          out_grads[i] = math_ops.accumulate_n(out_grad)
        elif aggregation_method in [
            AggregationMethod.EXPERIMENTAL_TREE,
            AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
        ]:
          # Aggregate all gradients by doing pairwise sums: this may
          # reduce performance, but it can improve memory because the
          # gradients can be released earlier.
          #
          # TODO(vrv): Consider replacing this with a version of
          # tf.AddN() that eagerly frees its inputs as soon as they are
          # ready, so the order of this tree does not become a problem.
          used = "tree"
          with ops.name_scope(op.name + "_gradient_sum"):
            running_sum = out_grad[0]
            for grad in out_grad[1:]:
              running_sum = math_ops.add_n([running_sum, grad])
            out_grads[i] = running_sum
        else:
          used = "add_n"
          out_grads[i] = _MultiDeviceAddN(out_grad)
        logging.vlog(2, "  _AggregatedGrads %d x %s using %s",
                     len(out_grad), tensor_shape, used)
      else:
        out_grad = math_ops._as_indexed_slices_list(
            [g for g in out_grad if g is not None])
        out_grad = [_HandleNestedIndexedSlices(x) for x in out_grad]
        # Form IndexedSlices out of the concatenated values and
        # indices.
        out_grads[i] = ops.IndexedSlices(
            array_ops.concat_v2([x.values for x in out_grad], 0),
            array_ops.concat_v2([x.indices for x in out_grad], 0),
            out_grad[0].dense_shape)
    else:
      out_grads[i] = []
  return out_grads
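From user code, the aggregation strategy validated above is chosen through the public `aggregation_method` argument of `tf.gradients`. A short sketch, assuming graph-mode TensorFlow 1.x (the tensor names are illustrative):

import tensorflow as tf  # assumes graph-mode TensorFlow 1.x

x = tf.placeholder(tf.float32, shape=[8, 8], name="x")
# x feeds the matmul twice, so its gradient is the sum of two terms and
# _AggregatedGrads actually has something to combine.
y = tf.reduce_sum(tf.matmul(x, x))

dx = tf.gradients(
    [y], [x],
    aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE)[0]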
Example #4
def gradients(ys,
              xs,
              grad_ys=None,
              name="gradients",
              colocate_gradients_with_ops=False,
              gate_gradients=False,
              aggregation_method=None):
    """Constructs symbolic partial derivatives of `ys` w.r.t. x in `xs`.

  `ys` and `xs` are each a `Tensor` or a list of tensors.  `grad_ys`
  is a list of `Tensor`, holding the gradients received by the
  `ys`. The list must be the same length as `ys`.

  `gradients()` adds ops to the graph to output the partial
  derivatives of `ys` with respect to `xs`.  It returns a list of
  `Tensor` of length `len(xs)` where each tensor is the `sum(dy/dx)`
  for y in `ys`.

  `grad_ys` is a list of tensors of the same length as `ys` that holds
  the initial gradients for each y in `ys`.  When `grad_ys` is None,
  we fill in a tensor of '1's of the shape of y for each y in `ys`.  A
  user can provide their own initial `grad_ys` to compute the
  derivatives using a different initial gradient for each y (e.g., if
  one wanted to weight the gradient differently for each value in
  each y).

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grad_ys: Optional. A `Tensor` or list of tensors the same size as
      `ys` and holding the gradients computed for each y in `ys`.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'gradients'.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    gate_gradients: If True, add a tuple around the gradients returned
      for an operation.  This avoids some race conditions.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.

  Returns:
    A list of `sum(dy/dx)` for each x in `xs`.

  Raises:
    LookupError: if one of the operations between `x` and `y` does not
      have a registered gradient function.
    ValueError: if the arguments are invalid.

  """
    ys = _AsList(ys)
    xs = _AsList(xs)
    if grad_ys is None:
        grad_ys = [None] * len(ys)
    else:
        grad_ys = _AsList(grad_ys)
    with ops.op_scope(ys + xs + grad_ys, name, "gradients"):
        ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
        xs = ops.convert_n_to_tensor_or_indexed_slices(xs, name="x")
        grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops)

        # The approach we take here is as follows: Create a list of all ops in the
        # subgraph between the ys and xs.  Visit these ops in reverse order of ids
        # to ensure that when we visit an op the gradients w.r.t its outputs have
        # been collected.  Then aggregate these gradients if needed, call the op's
        # gradient function, and add the generated gradients to the gradients for
        # its input.

        # Initialize the pending count for ops in the connected subgraph from ys
        # to the xs.
        to_ops = [t.op for t in ys]
        from_ops = [t.op for t in xs]
        pending_count, loop_state = _PendingCount(ops.get_default_graph(),
                                                  to_ops, from_ops)

        # Iterate over the collected ops.
        #
        # grads: op => list of gradients received on each output endpoint of the
        # op.  The gradients for each endpoint are initially collected as a list.
        # When it is time to call the op's gradient function, for each endpoint we
        # aggregate the list of received gradients into an Add() Operation if there
        # is more than one.
        grads = {}

        # Add the initial gradients for the ys.
        for y, grad_y in zip(ys, grad_ys):
            _SetGrad(grads, y, grad_y)

        # Initialize queue with to_ops.
        queue = collections.deque()
        # Add the ops in 'to_ops' into the queue.
        to_ops_set = set()
        for op in to_ops:
            # 'ready' handles the case where one output gradient relies on
            # another output's gradient.
            # pylint: disable=protected-access
            ready = (pending_count[op._id] == 0)
            if ready and op._id not in to_ops_set:
                to_ops_set.add(op._id)
                queue.append(op)

        if loop_state:
            # The "unused" exits of the loops are added to ys. As an example,
            # people often write:
            #         v1, _ = While(p, b, [x1, x2])
            #         result = gradients(v1, x1)
            # The exit node of x2 is not included by the betweenness analysis.
            # But we need it if x2 is involved in computing v1. So we add it
            # back in backprop with a zeros_like gradient.
            loop_exits = loop_state.GetAllLoopExits()
            for y in loop_exits:
                if pending_count[y.op._id] == 0 and y.op._id not in to_ops_set:
                    if _IsFloat(y):
                        # Floating-point outputs get a zero gradient.
                        _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
                    queue.append(y.op)

        # The set of 'from_ops'.
        stop_ops = _StopOps(from_ops, pending_count)
        while queue:
            # generate gradient subgraph for op.
            op = queue.popleft()
            with ops.device(_GetGradsDevice(op, colocate_gradients_with_ops)):
                if loop_state:
                    loop_state.EnterGradWhileContext(op)
                out_grads = _AggregatedGrads(grads, op, loop_state,
                                             aggregation_method)
                grad_fn = None

                # pylint: disable=protected-access
                is_func_call = ops.get_default_graph()._is_function(op.type)

                if not is_func_call and any(out_grads) and op._id not in stop_ops:
                    # pylint: enable=protected-access
                    # A grad_fn must be defined, either as a function or as None
                    # for ops that do not have gradients.
                    try:
                        grad_fn = ops.get_gradient_function(op)
                    except LookupError:
                        raise LookupError(
                            "No gradient defined for operation '%s' (op type: %s)"
                            % (op.name, op.type))
                if (grad_fn or is_func_call) and any(out_grads):
                    # NOTE: If _AggregatedGrads didn't compute a value for the i'th
                    # output, it means that the cost does not depend on output[i],
                    # therefore dC/doutput[i] is 0.
                    for i, out_grad in enumerate(out_grads):
                        if not out_grad and _IsFloat(op.outputs[i]):
                            # Only floating-point outputs get a zero gradient. Gradient
                            # functions should ignore the gradient for other outputs.
                            if loop_state:
                                out_grads[i] = loop_state.ZerosLike(op, i)
                            else:
                                out_grads[i] = array_ops.zeros_like(
                                    op.outputs[i])
                    with ops.name_scope(op.name + "_grad"):
                        # pylint: disable=protected-access
                        with ops.get_default_graph()._original_op(op):
                            # pylint: enable=protected-access
                            wrapped_op = op
                            if loop_state:
                                wrapped_op = loop_state.MakeWrapper(op)
                            if is_func_call:
                                # For function call ops, we add a 'SymbolicGradient'
                                # node to the graph to compute gradients.
                                f_in = [x for x in op.inputs] + out_grads
                                f_types = [x.dtype for x in op.inputs]
                                # pylint: disable=protected-access
                                in_grads = _AsList(
                                    functional_ops._symbolic_gradient(
                                        f_in, f_types, op.type))
                                # pylint: enable=protected-access
                            else:
                                in_grads = _AsList(
                                    grad_fn(wrapped_op, *out_grads))
                            _VerifyGeneratedGradients(in_grads, op)
                            if gate_gradients and len(
                                    tuple(filter(None, in_grads))) > 1:
                                in_grads = control_flow_ops.tuple(in_grads)
                    logging.vlog(1, "Gradient for '" + op.name + "'")
                    logging.vlog(1, "  in  --> %s",
                                 ", ".join([x.name for x in out_grads if x]))
                    logging.vlog(1, "  out --> %s",
                                 ", ".join([x.name for x in in_grads if x]))
                else:
                    # If no grad_fn is defined or none of out_grads is available,
                    # just propagate a list of None backwards.
                    in_grads = [None] * len(op.inputs)
                for t_in, in_grad in zip(op.inputs, in_grads):
                    if in_grad:
                        _SetGrad(grads, t_in, in_grad)
                if loop_state:
                    loop_state.ExitGradWhileContext(op)

            # update pending count for the inputs of op.
            # pylint: disable=protected-access
            for x in op.inputs:
                pending_count[x.op._id] -= 1
                ready = (pending_count[x.op._id] == 0)
                if loop_state and not ready:
                    ready = (pending_count[x.op._id] > 0
                             and control_flow_ops.IsLoopSwitch(x.op))
                if ready:
                    queue.append(x.op)
            for x in op.control_inputs:
                pending_count[x._id] -= 1
                if pending_count[x._id] == 0:
                    queue.append(x)
            # pylint: enable=protected-access
    return [_GetGrad(grads, x) for x in xs]
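A short usage sketch for the function above, assuming graph-mode TensorFlow 1.x (values and names are illustrative). It shows the default behaviour, where `grad_ys` is implicitly filled with ones, next to a custom `grad_ys` that weights each element of y differently, as the docstring describes:

import tensorflow as tf  # assumes graph-mode TensorFlow 1.x

x = tf.constant([1.0, 2.0, 3.0])
y = 3.0 * x * x  # elementwise, so dy/dx = 6 * x

# Default: grad_ys is implicitly a tensor of ones with the shape of y.
dx = tf.gradients([y], [x])[0]

# Custom initial gradients weight each element of y differently.
weights = tf.constant([1.0, 0.5, 0.0])
dx_weighted = tf.gradients([y], [x], grad_ys=[weights])[0]

with tf.Session() as sess:
  print(sess.run(dx))           # [ 6. 12. 18.]
  print(sess.run(dx_weighted))  # [6. 6. 0.]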