def _grad_fn(func_graph, args):
  """Computes the gradient of `func_graph` in the current graph.

  This function builds the gradient graph of the corresponding forward-pass
  `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.

  Args:
    func_graph: function.FuncGraph. The corresponding forward-pass function.
    args: The input arguments.
      args[0] - Loop counter
      args[1] - Total number of iterations.
      args[2:] - Incoming gradients for `func_graph.outputs`.

  Returns:
    The output gradient Tensors.
  """
  xs = func_graph.inputs
  ys = func_graph.outputs
  grad_ys = args[2:]

  # Build the gradient graph. Note that this builds the gradient computation of
  # func_graph in the current graph, which requires capturing tensors from
  # func_graph. The captured func_graph tensors are resolved to external tensors
  # in _resolve_grad_inputs.
  # TODO(srbs): Mark GradientsHelper as public?
  grad_outs = gradients_impl._GradientsHelper(
      ys, xs, grad_ys=grad_ys, src_graph=func_graph)

  assert all([g is not None for g in grad_outs])
  counter = args[0]
  total_iters = args[1]
  return [counter + 1, total_iters] + grad_outs
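
Below is a minimal, graph-mode sketch (using the public `tf.compat.v1.gradients` API, not the while_v2 internals) of what the `gradients_impl._GradientsHelper` call above computes: the incoming gradients in `grad_ys` seed the backward pass for `ys`, and the result is d(ys)/d(xs) weighted by those seeds.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

x = tf.placeholder(tf.float32, shape=[])
y = x * x                    # stands in for a single func_graph output
incoming = tf.constant(3.0)  # stands in for args[2:], the incoming gradient for y
# dy/dx weighted by the incoming gradient: (2 * x) * 3.0
grad_x, = tf.gradients([y], [x], grad_ys=[incoming])

with tf.Session() as sess:
  print(sess.run(grad_x, feed_dict={x: 4.0}))  # 24.0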
Example #2
def _grad_fn(ys, xs, args, func_graph):
  """Computes the gradient of `func_graph` in the current graph.

  This function builds the gradient graph of the corresponding forward-pass
  `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    args: The input arguments.
      args[0] - Loop counter
      args[1] - Total number of iterations.
      args[2:] - Incoming gradients for `ys`.
    func_graph: function.FuncGraph. The corresponding forward-pass function.

  Returns:
    The output gradient Tensors.
  """
  grad_ys = args[2:]

  # Build the gradient graph. Note that this builds the gradient computation of
  # func_graph in the current graph, which requires capturing tensors from
  # func_graph. The captured func_graph tensors are resolved to external tensors
  # in _resolve_grad_inputs.
  # TODO(srbs): Mark GradientsHelper as public?
  grad_outs = gradients_impl._GradientsHelper(
      ys, xs, grad_ys=grad_ys, src_graph=func_graph)

  # TODO(b/118712257): Handle the case when grad_outs has None's e.g. when there
  # is a tf.StopGradient in the loop body.
  assert all(g is not None for g in grad_outs)
  counter = args[0]
  total_iters = args[1]
  return [counter + 1, total_iters] + grad_outs
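
The `[counter + 1, total_iters] + grad_outs` return value above follows a while-loop state layout of `[counter, total_iters, *gradients]`: the backward pass itself runs as a loop that steps the counter up to the number of forward iterations. A hypothetical, simplified sketch of that convention with the public `tf.while_loop`; the body here is a stand-in, not the gradient graph that `_grad_fn` actually builds.

import tensorflow as tf

def backward_body(counter, total_iters, grad):
  # Stand-in for backpropagating `grad` through one forward iteration;
  # the real body is the gradient graph built by _grad_fn.
  return counter + 1, total_iters, grad * 2.0

_, _, final_grad = tf.while_loop(
    cond=lambda counter, total_iters, grad: counter < total_iters,
    body=backward_body,
    loop_vars=[tf.constant(0), tf.constant(3), tf.constant(1.0)])
# final_grad == 8.0 after three backward steps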
Example #3
def _grad_fn(func_graph, args):
  """Computes the gradient of `func_graph` in the current graph.

  This function builds the gradient graph of the corresponding forward-pass
  `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.

  Args:
    func_graph: function.FuncGraph. The corresponding forward-pass function.
    args: The input arguments.
      args[0] - Loop counter
      args[1] - Total number of iterations.
      args[2:] - Incoming gradients for `func_graph.outputs`.

  Returns:
    The output gradient Tensors.
  """
  xs = func_graph.inputs
  ys = func_graph.outputs
  grad_ys = args[2:]

  # Build the gradient graph. Note that this builds the gradient computation of
  # func_graph in the current graph, which requires capturing tensors from
  # func_graph. The captured func_graph tensors are resolved to external tensors
  # in _resolve_grad_inputs.
  # TODO(srbs): Mark GradientsHelper as public?
  grad_outs = gradients_impl._GradientsHelper(
      ys, xs, grad_ys=grad_ys, src_graph=func_graph)

  assert all([g is not None for g in grad_outs])
  counter = args[0]
  total_iters = args[1]
  return [counter + 1, total_iters] + grad_outs
Example #4
def _grad_fn(ys, xs, args, func_graph):
  """Computes the gradient of `func_graph` in the current graph.

  This function builds the gradient graph of the corresponding forward-pass
  `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    args: The input arguments.
      args[0] - Loop counter
      args[1] - Total number of iterations.
      args[2:] - Incoming gradients for `ys`.
    func_graph: function.FuncGraph. The corresponding forward-pass function.

  Returns:
    The output gradient Tensors.
  """
  grad_ys = args[2:]

  # Build the gradient graph. Note that this builds the gradient computation of
  # func_graph in the current graph, which requires capturing tensors from
  # func_graph. The captured func_graph tensors are resolved to external tensors
  # in _resolve_grad_inputs.
  # TODO(srbs): Mark GradientsHelper as public?
  grad_outs = gradients_impl._GradientsHelper(
      ys, xs, grad_ys=grad_ys, src_graph=func_graph)

  # TODO(b/118712257): Handle the case when grad_outs has None's e.g. when there
  # is a tf.StopGradient in the loop body.
  assert all(g is not None for g in grad_outs)
  counter = args[0]
  total_iters = args[1]
  return [counter + 1, total_iters] + grad_outs
Example #5
def _grad_fn(func_graph, grads):
  """The gradient function for each conditional branch.

  This function builds the gradient graph of the corresponding forward-pass
  conditional branch in `func_graph`. This is done by differentiating
  func_graph's outputs w.r.t. its inputs.

  Args:
    func_graph: FuncGraph. The corresponding forward-pass function.
    grads: The list of input gradient Tensors.

  Returns:
    The output gradient Tensors.
  """
  # Filter out untrainable function outputs.
  # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes
  # cause _GradientsHelper to raise an exception (e.g. the implementation
  # doesn't expect 'ys' to contain boolean tensors).
  assert len(func_graph.outputs) == len(grads)
  ys = []
  grad_ys = []
  for y, grad_y in zip(func_graph.outputs, grads):
    if not gradients_impl.IsTrainable(y):
      continue
    ys.append(y)
    grad_ys.append(grad_y)

  # Build the gradient graph. Note that this builds the gradient computation of
  # func_graph in the current graph, which requires capturing tensors from
  # func_graph. The captured func_graph tensors are resolved to external tensors
  # in _resolve_grad_inputs.
  result = gradients_impl._GradientsHelper(
      ys, func_graph.inputs, grad_ys=grad_ys, src_graph=func_graph)

  # Functions can't return None; replace Nones with zero tensors.
  # TODO(b/80444525): don't return anything here and make _IfGrad return None if
  # both branches have zero gradient.
  for i in range(len(result)):
    if result[i] is None:
      if func_graph.inputs[i].dtype == dtypes.resource:
        result[i] = array_ops.zeros(
            gen_resource_variable_ops.variable_shape(func_graph.inputs[i]))
      else:
        result[i] = array_ops.zeros_like(func_graph.inputs[i])

  return result
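
The None-handling above exists because `tf.gradients`-style backprop returns None for inputs the outputs do not depend on, while a function must return a tensor for every input. A small illustrative sketch using only the public API (the resource-variable branch, which calls `gen_resource_variable_ops.variable_shape`, is omitted here):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

a = tf.placeholder(tf.float32, shape=[2])
b = tf.placeholder(tf.float32, shape=[2])
y = 2.0 * a                        # y does not depend on b
grads = tf.gradients([y], [a, b])  # [<Tensor>, None]
# Replace missing gradients with zero tensors, as the branch gradient
# function does, so every input gets a gradient tensor.
grads = [g if g is not None else tf.zeros_like(inp)
         for g, inp in zip(grads, [a, b])]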
Example #6
def _grad_fn(func_graph, grads):
  """The gradient function for each conditional branch.

  This function builds the gradient graph of the corresponding forward-pass
  conditional branch in `func_graph`. This is done by differentiating
  func_graph's outputs w.r.t. its inputs.

  Args:
    func_graph: FuncGraph. The corresponding forward-pass function.
    grads: The list of input gradient Tensors.

  Returns:
    The output gradient Tensors.
  """
  # Filter out untrainable function outputs.
  # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes
  # cause _GradientsHelper to raise an exception (e.g. the implementation
  # doesn't expect 'ys' to contain boolean tensors).
  assert len(func_graph.outputs) == len(grads)
  ys = []
  grad_ys = []
  for y, grad_y in zip(func_graph.outputs, grads):
    if not gradients_impl.IsTrainable(y):
      continue
    ys.append(y)
    grad_ys.append(grad_y)

  # Build the gradient graph. Note that this builds the gradient computation of
  # func_graph in the current graph, which requires capturing tensors from
  # func_graph. The captured func_graph tensors are resolved to external tensors
  # in _resolve_grad_inputs.
  result = gradients_impl._GradientsHelper(
      ys, func_graph.inputs, grad_ys=grad_ys,
      src_graph=func_graph)

  # Functions can't return None; replace Nones with zero tensors.
  # TODO(b/80444525): don't return anything here and make _IfGrad return None if
  # both branches have zero gradient.
  for i in range(len(result)):
    if result[i] is None:
      if func_graph.inputs[i].dtype == dtypes.resource:
        result[i] = array_ops.zeros(
            gen_resource_variable_ops.variable_shape(func_graph.inputs[i]))
      else:
        result[i] = array_ops.zeros_like(func_graph.inputs[i])

  return result