Code Example #1
    def PostProcessing(self):
        """Perform postprocessing at the end of gradients().

    We have created the gradient graph at this point. So this function
    can be used to perform any postprocessing on the gradient graph.
    We currently perform the following postprocessing:
      1. Patch the gradient graph if the output of a loop variable
         doesn't depend on its input.
    """
        for _, grad_state in self._map.items():
            for _, b_merge in grad_state.switch_map.items():
                if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
                    # The value of this loop variable at iteration i+1 doesn't
                    # depend on its value at iteration i. So use zeros as the
                    # gradients for all iterations > 0.
                    dtype = b_merge.op.inputs[0].dtype
                    shape = b_merge.op.inputs[0].get_shape()
                    # pylint: disable=protected-access
                    if shape.is_fully_defined():
                        grad_state.grad_context.Enter()
                        # Create a zeros tensor and use it for iterations > 0.
                        grad_val = constant_op.constant(0,
                                                        dtype=dtype,
                                                        shape=shape)
                        next_grad_val = control_flow_ops._NextIteration(
                            grad_val)
                        grad_state.grad_context.Exit()
                    else:
                        # Create a zeros tensor in the outer grad context.
                        outer_grad_ctxt = grad_state.grad_context.outer_context
                        if outer_grad_ctxt:
                            outer_grad_ctxt.Enter()
                        enter_grad_op = b_merge.op.inputs[0].op
                        enter_grad = enter_grad_op.inputs[0]
                        grad_shape = array_ops.shape_internal(enter_grad,
                                                              optimize=False)
                        grad_val = array_ops.zeros(grad_shape)
                        if outer_grad_ctxt:
                            outer_grad_ctxt.Exit()
                        # Use the zeros for iterations > 0.
                        grad_state.grad_context.Enter()
                        next_grad_val = control_flow_ops._NextIteration(
                            grad_val)
                        grad_state.grad_context.Exit()
                    b_merge.op._update_input(1, next_grad_val)
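
A minimal sketch (assuming TensorFlow 1.x graph mode and the Session API) of
the case this postprocessing patches: the loop variable x below is overwritten
on every iteration with a value that ignores its previous value, so the
gradient for iterations > 0 is zero and only the final iteration contributes.

import tensorflow as tf

i0 = tf.constant(0)
x0 = tf.constant(2.0)


def body(i, x):
    # The new value of x ignores the old one, so x at iteration i+1
    # doesn't depend on x at iteration i.
    return i + 1, x0 * 3.0


_, x_final = tf.while_loop(lambda i, x: i < 5, body, [i0, x0])
grad = tf.gradients(x_final, x0)[0]

with tf.Session() as sess:
    print(sess.run(grad))  # Expected 3.0: only the last iteration's use of x0
                           # reaches x_final; earlier uses get zero gradients.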
Code Example #2
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops.control_flow_ops import CondContext
from tensorflow.python.ops.control_flow_ops import WhileContext
from tensorflow.python.ops.control_flow_ops import merge
from tensorflow.python.ops.control_flow_ops import switch


def _SwitchGrad(op, *grad):
  """Gradients for a Switch op are calculated using a Merge op.

  If the switch is a loop switch, it will be visited twice. We create
  the merge on the first visit, and update the other input of the merge
  on the second visit. A next_iteration is also added on the second visit.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    merge_op = grad_ctxt.grad_state.switch_map.get(op)
    if merge_op:
      # This is the second time this Switch is visited. It comes from
      # the non-exit branch of the Switch, so update the second input
      # to the Merge.
      # TODO: Perform shape inference with this new input.
      # pylint: disable=protected-access
      merge_op._update_input(1, control_flow_ops._NextIteration(grad[1]))
      # pylint: enable=protected-access
      return None, None
    else:
      # This is the first time this Switch is visited. It always comes
      # from the Exit branch, which is grad[0]. grad[1] is empty at this point.
      # Use grad[0] for both inputs to the Merge for now, and update the
      # second input of the Merge when we see this Switch the second time.
      merge_fn = control_flow_ops._Merge  # pylint: disable=protected-access
      merge_op = merge_fn([grad[0], grad[0]], name="b_switch")[0]
      grad_ctxt.grad_state.switch_map[op] = merge_op.op
      return merge_op, None
  elif isinstance(op_ctxt, CondContext):
    good_grad = grad[op_ctxt.branch]
    zero_grad = grad[1 - op_ctxt.branch]
    # If we are in a grad context, this switch is part of a cond within a
    # loop. In this case, we have called ControlFlowState.ZeroLike() so grad
    # is ready for merge. Otherwise, we need a switch to control zero_grad.
    if not (grad_ctxt and grad_ctxt.grad_state):
      dtype = good_grad.dtype
      branch = op_ctxt.branch
      zero_grad = switch(zero_grad, op_ctxt.pred, dtype=dtype)[1 - branch]
    return merge([good_grad, zero_grad], name="cond_grad")[0], None
  else:
    false_grad = switch(grad[0], op.inputs[1])[0]
    true_grad = switch(grad[1], op.inputs[1])[1]
    return merge([false_grad, true_grad])[0], None
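
A minimal sketch (assuming TensorFlow 1.x graph mode) of the CondContext case
above: differentiating through tf.cond inserts Switch ops into the forward
graph, and _SwitchGrad turns each one into a Merge of the taken branch's
gradient with a zero gradient for the untaken branch, so dy/dx follows
whichever branch actually ran.

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[])
pred = tf.placeholder(tf.bool, shape=[])
y = tf.cond(pred, lambda: x * x, lambda: 3.0 * x)

# Backprop through the cond builds merge([good_grad, zero_grad]) per Switch.
dy_dx = tf.gradients(y, x)[0]

with tf.Session() as sess:
    print(sess.run(dy_dx, {x: 4.0, pred: True}))   # Expected 8.0, d(x*x)/dx
    print(sess.run(dy_dx, {x: 4.0, pred: False}))  # Expected 3.0, d(3x)/dx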