Example #1
def ZerosLikeOutsideLoop(op, index):
    """Create zeros_like for the specified output of an op."""
    val = op.outputs[index]
    if not util.IsSwitch(op):
        if val.dtype == dtypes.resource:
            return array_ops.zeros(
                gen_resource_variable_ops.variable_shape(val),
                dtype=default_gradient.get_zeros_dtype(val))
        return array_ops.zeros_like(val, optimize=False)
    else:
        op_ctxt = op._get_control_flow_context()
        if op_ctxt:
            # We are in a cond context. Use a switch to create zeros only when needed.
            pred = op_ctxt.pred
            branch = op_ctxt.branch
            switch_val = control_flow_ops.switch(op.inputs[0],
                                                 pred)[1 - branch]
            # An op is created along the branch taken, as control dependencies
            # are on the whole op and not on the tensor output.
            pivot = array_ops.identity(switch_val)
            if val.dtype == dtypes.resource:
                with ops.control_dependencies([pivot]):
                    return array_ops.zeros(
                        gen_resource_variable_ops.variable_shape(switch_val),
                        dtype=default_gradient.get_zeros_dtype(val))
            zeros_shape = array_ops.shape_internal(switch_val, optimize=False)
            # Ensure ops created within array_ops.zeros are dominated by switch in
            # cond context.
            with ops.control_dependencies([pivot]):
                return array_ops.zeros(zeros_shape, dtype=val.dtype)
        else:
            return array_ops.zeros_like(val, optimize=False)
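The switch-guarded pattern above can be sketched with public ops. The snippet below is a minimal, hypothetical illustration (graph mode, with tf.raw_ops.Switch standing in for control_flow_ops.switch): a control dependency on the taken branch's output ensures the zeros are only materialized when that branch actually runs.

import tensorflow as tf

# Minimal sketch (assumed setup, not TensorFlow source): guard zeros with a
# Switch so they are only computed on the branch that is actually taken.
with tf.Graph().as_default():
  pred = tf.compat.v1.placeholder(tf.bool, shape=[])
  x = tf.constant([1.0, 2.0, 3.0])
  out_false, out_true = tf.raw_ops.Switch(data=x, pred=pred)
  pivot = tf.identity(out_true)  # control dependencies attach to whole ops
  with tf.control_dependencies([pivot]):
    zeros = tf.zeros(tf.shape(out_true), dtype=x.dtype)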
Example #2
def _ZerosLike(t):
    t_dtype = default_gradient.get_zeros_dtype(t)
    if t.dtype == dtypes.resource:
        return array_ops.zeros(resource_variable_ops.variable_shape(t),
                               dtype=t_dtype)
    else:
        return array_ops.zeros_like(t, dtype=t_dtype)
Example #3
def _zeros_like(op_output):
  """Like array_ops.zeros_like() but also accepts resource var handles."""
  if op_output.dtype == dtypes.resource:
    return array_ops.zeros(
        gen_resource_variable_ops.variable_shape(op_output),
        dtype=default_gradient.get_zeros_dtype(op_output))
  return array_ops.zeros_like(op_output)
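Examples 1-3 share the same special case: a resource variable handle has dtype resource, so zeros_like on the handle itself would not produce a numeric zero tensor; instead the variable's runtime shape is read and zeros of the value dtype are built. A rough public-API analogue (hypothetical setup, eager mode):

import tensorflow as tf

v = tf.Variable([[1.0, 2.0], [3.0, 4.0]])

# Read the variable's runtime shape from its handle and build zeros of the
# value dtype, rather than calling zeros_like on the resource-dtype handle.
zeros = tf.zeros(tf.raw_ops.VariableShape(input=v.handle), dtype=v.dtype)
print(zeros)  # a 2x2 tensor of zeros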
Example #4
def _SymGrad(op, out_grads):
    """Backprop through a function call node op given its outputs' gradients."""
    f_in = [x for x in op.inputs] + out_grads
    f_types = [default_gradient.get_zeros_dtype(x) for x in op.inputs]
    f = attr_value_pb2.NameAttrList()
    if _IsPartitionedCall(op):
        f.name = op.get_attr("f").name
    else:
        f.name = op.type
    for k in op.node_def.attr:
        f.attr[k].CopyFrom(op.node_def.attr[k])
    in_grads = functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f)
    return in_grads
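As a rough illustration of the call-node backprop that _SymGrad implements, the hedged sketch below differentiates through a tf.function call in graph mode; the call op's inputs plus the output gradients play the role of f_in, and the returned tensors correspond to in_grads (public-API analogue only, not the internal code path itself):

import tensorflow as tf

@tf.function
def f(x):
  return tf.reduce_sum(tf.square(x))

with tf.Graph().as_default():
  x = tf.constant([1.0, 2.0, 3.0])
  y = f(x)                       # a (Partitioned)Call node in the graph
  (dx,) = tf.gradients(y, [x])   # backprop through the call node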
Example #5
def _grad_fn(func_graph, grads):
    """The gradient function for each conditional branch.

  This function builds the gradient graph of the corresponding forward-pass
  conditional branch in `func_graph`. This is done by differentiating
  func_graph's outputs w.r.t. its inputs.

  Args:
    func_graph: FuncGraph. The corresponding forward-pass function.
    grads: The list of input gradient Tensors.

  Returns:
    The output gradient Tensors.
  """
    # Filter out untrainable function outputs.
    # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes
    # cause _GradientsHelper to raise an exception (e.g. the implementation
    # doesn't expect 'ys' to contain boolean tensors).
    assert len(func_graph.outputs) == len(grads)
    ys = []
    grad_ys = []
    for y, grad_y in zip(func_graph.outputs, grads):
        if not backprop_util.IsTrainable(y):
            continue
        ys.append(y)
        grad_ys.append(grad_y)

    # Build the gradient graph. Note that this builds the gradient computation of
    # func_graph in the current graph, which requires capturing tensors from
    # func_graph. The captured func_graph tensors are resolved to external tensors
    # in _resolve_grad_inputs.
    result = gradients_util._GradientsHelper(ys,
                                             func_graph.inputs,
                                             grad_ys=grad_ys,
                                             src_graph=func_graph)

    # Functions can't return None; replace Nones with zero tensors.
    # TODO(b/80444525): don't return anything here and make _IfGrad return None if
    # both branches have zero gradient.
    for i in range(len(result)):
        if result[i] is None:
            if func_graph.inputs[i].dtype == dtypes.resource:
                result[i] = array_ops.zeros(
                    gen_resource_variable_ops.variable_shape(
                        func_graph.inputs[i]),
                    dtype=default_gradient.get_zeros_dtype(
                        func_graph.inputs[i]))
            else:
                result[i] = array_ops.zeros_like(func_graph.inputs[i])

    return result
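The filtering of untrainable outputs matters whenever a cond branch returns something like a boolean alongside a differentiable value. A minimal sketch with the public API (assumed example):

import tensorflow as tf

x = tf.Variable(2.0)

@tf.function
def grad_through_cond():
  with tf.GradientTape() as tape:
    # The boolean output is untrainable; the loop at the top of _grad_fn is
    # what keeps such outputs out of the ys passed to _GradientsHelper.
    y, flag = tf.cond(x > 0.0,
                      lambda: (x * 3.0, tf.constant(True)),
                      lambda: (x * 5.0, tf.constant(False)))
  return tape.gradient(y, x)

print(grad_through_cond())  # 3.0, since x > 0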
Example #6
def _GetGrad(grads, t, unconnected_gradients):
    """Gets gradient for tensor "t"."""
    op = t.op
    op_grads = grads.get(op)
    if not op_grads:
        if unconnected_gradients == UnconnectedGradients.ZERO:
            t_dtype = default_gradient.get_zeros_dtype(t)
            return array_ops.zeros_like(t, dtype=t_dtype)
        elif unconnected_gradients == UnconnectedGradients.NONE:
            return None
        else:
            raise ValueError("Unknown value for unconnected_gradients: %r" %
                             unconnected_gradients)

    t_grad = op_grads[t.value_index]
    assert not isinstance(
        t_grad, list), ("gradients list should have been aggregated by now.")
    return t_grad
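The UnconnectedGradients.ZERO branch is what a caller sees when asking for zeros instead of None for sources that do not feed the target. A small sketch with the public API (assumed example):

import tensorflow as tf

x = tf.constant(1.0)
y = tf.constant(2.0)
with tf.GradientTape() as tape:
  tape.watch([x, y])
  z = x * x  # z does not depend on y

# ZERO materializes a zeros_like tensor for y (the path taken in _GetGrad);
# the default NONE would return None for it.
dz_dx, dz_dy = tape.gradient(
    z, [x, y], unconnected_gradients=tf.UnconnectedGradients.ZERO)
print(dz_dy.numpy())  # 0.0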
Example #7
def _ZerosLikeV2(op, index):
  """Branch of ZerosLike for TF2."""
  val = op.outputs[index]
  if val.dtype == dtypes.resource:
    return array_ops.zeros(
        gen_resource_variable_ops.variable_shape(val),
        dtype=default_gradient.get_zeros_dtype(val))
  if (isinstance(val.op.graph, control_flow_v2_func_graphs.WhileBodyFuncGraph)
      and val.dtype != dtypes.variant):
    # In while_v2 we do not want to add a `ZerosLike` op because that will
    # trigger accumulation of `val`. Normally `ZerosLike` is preferred because
    # it helps avoid creating extra nodes (possibly Consts) for the shape.
    # For variants, we must use ZerosLike.
    if val.shape.is_fully_defined():
      return constant_op.constant(0, shape=val.shape.dims, dtype=val.dtype)
    else:
      # Note: Even though we add `Shape` in the default graph, while_v2 is smart
      # enough to place it in the forward graph i.e. `val.graph`.
      zeros_shape = array_ops.shape_internal(val, optimize=False)
      return array_ops.zeros(zeros_shape, val.dtype)
  else:
    return array_ops.zeros_like(val, optimize=False)
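The static-versus-dynamic shape split above can be mirrored with public ops; the helper below is hypothetical and only sketches that decision, not the while_v2 accumulation concern:

import tensorflow as tf

def zeros_for(val):
  # Hypothetical helper: a constant when the shape is static, a runtime
  # Shape plus zeros otherwise.
  if val.shape.is_fully_defined():
    return tf.constant(0, shape=val.shape, dtype=val.dtype)
  return tf.zeros(tf.shape(val), val.dtype)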
Example #8
  def ZerosLikeV1WhileLoop(self, op, index):
    """Create zeros_like for the specified output of an op.

    If op is in a while loop that is part of gradients(), this method
    must be called in its grad loop context.

    Args:
      op: A tensorflow operation.
      index: the index for a specific output of the op.

    Returns:
      A zero tensor of the same shape as op.outputs[index].
    """
    if util.IsLoopSwitch(op):
      return None
    if op.graph.building_function:
      # The optimization here is tricky to apply to functions
      return array_ops.zeros_like(op.outputs[index])
    dead_branch = util.IsSwitch(op)
    forward_ctxt = util.GetWhileContext(op)
    grad_state = self._map.get(forward_ctxt)
    if grad_state is None:
      # op is not in a while loop that is part of gradients().
      return ZerosLike(op, index)
    op_ctxt = op._get_control_flow_context()
    val = ops.convert_to_tensor(op.outputs[index], name="tensor")
    shape = val.get_shape()
    if shape.is_fully_defined():
      # If the shape is known statically, just create a zero tensor with
      # the right shape in the grad loop context.
      if val.dtype == dtypes.resource:
        result = array_ops.zeros(
            resource_variable_ops.variable_shape(val),
            dtype=default_gradient.get_zeros_dtype(val))
      else:
        result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
      if dead_branch:
        # op is a cond switch. Guard the zero tensor with a switch.
        pred = grad_state.history_map.get(op_ctxt.pred.name)
        branch = op_ctxt.branch
        result = control_flow_ops._SwitchRefOrTensor(result, pred)[1 - branch]
    else:
      # Unknown shape so keep a history of the shape at runtime.
      if dead_branch:
        # Need to add a special switch to guard the value.
        pred = op_ctxt.pred
        branch = op_ctxt.branch
        op_ctxt.outer_context.Enter()
        val = control_flow_ops._SwitchRefOrTensor(op.inputs[0],
                                                  pred)[1 - branch]
        zeros_shape = array_ops.shape_internal(val, optimize=False)
        op_ctxt.outer_context.Exit()
        val.op._set_control_flow_context(op_ctxt)
        zeros_shape.op._set_control_flow_context(op_ctxt)
      else:
        op_ctxt.Enter()
        zeros_shape = array_ops.shape_internal(val, optimize=False)
        op_ctxt.Exit()

      # Add forward accumulator for shape.
      grad_state.grad_context.Exit()
      history_zeros_shape = grad_state.AddForwardAccumulator(
          zeros_shape, dead_branch=dead_branch)
      grad_state.grad_context.Enter()

      # Create a zero tensor with the right shape.
      shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
                                                     zeros_shape, dead_branch)
      result = array_ops.zeros(shape, val.dtype)
    return result
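For context, the method above only comes into play when gradients are taken through a v1-style while loop in graph mode. A hedged minimal setup (control flow v2 disabled so the v1 grad-loop machinery is used):

import tensorflow as tf

tf.compat.v1.disable_control_flow_v2()  # force the v1 while-loop machinery

with tf.Graph().as_default():
  x = tf.constant(3.0)

  def body(i, acc):
    return i + 1, acc * x

  _, acc = tf.while_loop(lambda i, acc: i < 4, body,
                         [tf.constant(0), tf.constant(1.0)])
  # Backprop through the v1 loop builds a grad loop context; helpers such as
  # ZerosLikeV1WhileLoop are expected to be called from inside it.
  (dx,) = tf.gradients(acc, [x])  # d(x**4)/dx = 4 * x**3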