Example #1
def _MergeGrad(op, grad, _):
    """Gradients for a Merge op are calculated using a Switch op."""
    real_op = GetRealOp(op)
    input_op = real_op.inputs[0].op
    # pylint: disable=protected-access
    ctxt = input_op._get_control_flow_context()
    # pylint: enable=protected-access
    if isinstance(ctxt, WhileContext):
        grad_ctxt = op.grad_state.grad_context
        # pylint: disable=protected-access
        return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot)
        # pylint: enable=protected-access
    elif isinstance(ctxt, CondContext):
        pred = ctxt.pred
        if isinstance(op, ControlFlowOpWrapper):
            # This Merge node is part of a cond within a loop.
            # The backprop needs to have the value of this predicate for every
            # iteration. So we must have its values accumulated in the forward, and
            # use the accumulated values as the predicate for this backprop switch.
            grad_state = op.grad_state
            real_pred = grad_state.history_map.get(pred.name)
            if real_pred is None:
                # Remember the value of pred for every iteration.
                grad_ctxt = grad_state.grad_context
                grad_ctxt.Exit()
                history_pred = grad_state.AddForwardAccumulator(pred)
                grad_ctxt.Enter()

                # Add the stack pop op. If pred.op is in an (outer) CondContext,
                # the stack pop will be guarded with a switch.
                real_pred = grad_state.AddBackPropAccumulatedValue(
                    history_pred, pred)
                grad_state.history_map[pred.name] = real_pred
            pred = real_pred
        # pylint: disable=protected-access
        return control_flow_ops._SwitchRefOrTensor(grad,
                                                   pred,
                                                   name="cond_grad")
        # pylint: enable=protected-access
    else:
        num_inputs = len(real_op.inputs)
        cond = [
            math_ops.equal(real_op.outputs[1], i) for i in xrange(num_inputs)
        ]
        # pylint: disable=protected-access
        return [
            control_flow_ops._SwitchRefOrTensor(grad, cond[i])[1]
            for i in xrange(num_inputs)
        ]
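
The CondContext branch above is the path exercised whenever a gradient is taken
through tf.cond: each forward branch ends in a Merge op, and _MergeGrad routes
the incoming gradient back through a Switch on the same predicate (the
"cond_grad" name scope). A minimal illustration, assuming TensorFlow 1.x graph
mode; it is a sketch of the surrounding behavior, not part of the code above.

import tensorflow as tf

x = tf.constant(3.0)
pred = tf.placeholder(tf.bool, shape=[])
y = tf.cond(pred, lambda: tf.square(x), lambda: 5.0 * x)
# Backprop through the cond's Merge node inserts a Switch guarded by `pred`.
dy_dx = tf.gradients(y, x)[0]

with tf.Session() as sess:
  print(sess.run(dy_dx, feed_dict={pred: True}))   # 6.0, d/dx x**2
  print(sess.run(dy_dx, feed_dict={pred: False}))  # 5.0, d/dx 5*x
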
  def testRefSwitch(self):
    with self.test_session():
      v = tf.Variable(7)

      p = tf.constant(True)
      v1 = control_flow_ops._SwitchRefOrTensor(v, p)
      v2 = tf.assign(v1[1], 9)
      tf.initialize_all_variables().run()
      self.assertEqual(9, v2.eval())
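
The test above exercises the ref-variable path of _SwitchRefOrTensor. For a
plain (non-ref) tensor it dispatches to the regular Switch op, which returns
the pair (output_false, output_true); only the branch selected by the predicate
carries a live value. A small sketch of that dataflow, assuming the TF 1.x
control_flow_ops.switch and merge helpers:

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

x = tf.constant(7)
p = tf.constant(True)
out_false, out_true = control_flow_ops.switch(x, p)
# merge forwards whichever input is alive; here that is the true branch.
live = control_flow_ops.merge([out_false, out_true])[0]

with tf.Session() as sess:
  print(sess.run(live))  # 7
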
def _MergeGrad(op, grad, _):
  """Gradients for a Merge op are calculated using a Switch op."""
  input_op = op.inputs[0].op
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = control_flow_util.GetOutputContext(input_op)
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot)
    # pylint: enable=protected-access
  elif isinstance(op_ctxt, CondContext):
    pred = op_ctxt.pred
    if grad_ctxt and grad_ctxt.grad_state:
      # This Merge node is part of a cond within a loop.
      # The backprop needs to have the value of this predicate for every
      # iteration. So we must have its values accumulated in the forward, and
      # use the accumulated values as the predicate for this backprop switch.
      grad_state = grad_ctxt.grad_state
      real_pred = grad_state.history_map.get(pred.name)
      if real_pred is None:
        # Remember the value of pred for every iteration.
        grad_ctxt = grad_state.grad_context
        grad_ctxt.Exit()
        history_pred = grad_state.AddForwardAccumulator(pred)
        grad_ctxt.Enter()

        # Add the stack pop op. If pred.op is in an (outer) CondContext,
        # the stack pop will be guarded with a switch.
        real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred)
        grad_state.history_map[pred.name] = real_pred
      pred = real_pred
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, pred, name="cond_grad")
    # pylint: enable=protected-access
  else:
    num_inputs = len(op.inputs)
    cond = [math_ops.equal(op.outputs[1], i) for i in xrange(num_inputs)]
    # pylint: disable=protected-access
    return [control_flow_ops._SwitchRefOrTensor(grad, cond[i])[1]
            for i in xrange(num_inputs)]
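
The WhileContext branch above is what a plain gradient through tf.while_loop
hits: the loop-variable Merge at the head of the forward loop has its gradient
switched on the backprop loop's pivot. A minimal end-to-end sketch, again
assuming TensorFlow 1.x graph mode rather than anything specific to this file:

import tensorflow as tf

x = tf.constant(2.0)
# Computes y = x**3 by multiplying the accumulator by x three times.
_, y = tf.while_loop(lambda i, acc: i < 3,
                     lambda i, acc: (i + 1, acc * x),
                     [tf.constant(0), tf.constant(1.0)])
dy_dx = tf.gradients(y, x)[0]  # 3 * x**2

with tf.Session() as sess:
  print(sess.run(dy_dx))  # 12.0
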
    def AddBackpropAccumulatedValue(self,
                                    history_value,
                                    value,
                                    dead_branch=False):
        """Add the getter for an accumulated value in the grad context.

    This is added to the backprop loop. Called in the grad context to
    get the value of an accumulated value. The stack pop op must be guarded
    by the pred of the controlling cond.

    Args:
      history_value: The history (a stack) of a value.
      value: The value that is pushed onto the stack.
      dead_branch: True iff the tensor is on a dead branch of a cond.

    Returns:
      The current value (the top of the stack).
    """
        history_ctxt = history_value.op._get_control_flow_context()
        # Find the cond context that controls history_value if any.
        cond_ctxt = None
        value_ctxt = value.op._get_control_flow_context()
        while value_ctxt and value_ctxt != history_ctxt:
            if isinstance(value_ctxt, control_flow_ops.CondContext):
                cond_ctxt = value_ctxt
                break
            value_ctxt = value_ctxt.outer_context
        with ops.control_dependencies(None):
            self.grad_context.Enter()
            if cond_ctxt:
                # Guard stack pop with a switch if it is controlled by a cond.
                grad_state = self
                pred = None
                while pred is None and grad_state:
                    pred = grad_state.history_map.get(cond_ctxt.pred.name)
                    grad_state = grad_state.outer_grad_state
                if pred is None:
                    pred = cond_ctxt.pred
                branch = (
                    1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
                history_value = control_flow_ops._SwitchRefOrTensor(
                    history_value, pred)[branch]
            pop = gen_data_flow_ops.stack_pop_v2(history_value,
                                                 value.dtype.base_dtype)
            pop.set_shape(value.get_shape())
            self.grad_context.Exit()
        parallel_iterations = self.grad_context.parallel_iterations
        if parallel_iterations > 1:
            # All pops are ordered after pivot_for_body and before grad_sync.
            self.grad_sync._add_control_input(pop.op)
        return pop
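
This method is the "pop" half of a push/pop pair: AddForwardAccumulator (called
from _MergeGrad above) pushes a value onto a stack on every forward iteration,
and this method pops it back inside the backprop loop, adding a Switch guard
when the value lives under a cond. A stripped-down sketch of just the stack
dataflow, with no loop contexts, assuming the stack_v2/stack_push_v2/
stack_pop_v2 ops exposed by tensorflow.python.ops.gen_data_flow_ops in TF 1.x:

import tensorflow as tf
from tensorflow.python.ops import gen_data_flow_ops

value = tf.constant(42.0)
# Forward pass: create the stack and push the value to remember.
stack = gen_data_flow_ops.stack_v2(max_size=-1, elem_type=tf.float32,
                                   stack_name="history")
push = gen_data_flow_ops.stack_push_v2(stack, value)
# Backprop pass: pop the remembered value (in the real code this pop is
# guarded by a Switch when the value sits inside a cond).
with tf.control_dependencies([push]):
  pop = gen_data_flow_ops.stack_pop_v2(stack, tf.float32)

with tf.Session() as sess:
  print(sess.run(pop))  # 42.0
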
    def ZerosLike(self, op, index):
        """Create zeros_like for the specified output of an op.

    If op is in a while loop that is part of gradients(), this method
    must be called in its grad loop context.

    Args:
      op: A tensorflow operation.
      index: the index for a specific output of the op.

    Returns:
      A zero tensor of the same shape of op.outputs[index].
    """
        if util.IsLoopSwitch(op):
            return None
        if op.graph._building_function:  # pylint: disable=protected-access
            # The optimization here is tricky to apply to functions
            return array_ops.zeros_like(op.outputs[index])
        dead_branch = util.IsSwitch(op)
        forward_ctxt = util.GetWhileContext(op)
        grad_state = self._map.get(forward_ctxt)
        if grad_state is None:
            # op is not in a while loop that is part of gradients().
            return ZerosLikeOutsideLoop(op, index)
        op_ctxt = op._get_control_flow_context()
        val = ops.convert_to_tensor(op.outputs[index], name="tensor")
        shape = val.get_shape()
        if shape.is_fully_defined():
            # If the shape is known statically, just create a zero tensor with
            # the right shape in the grad loop context.
            result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
            if dead_branch:
                # op is a cond switch. Guard the zero tensor with a switch.
                pred = grad_state.history_map.get(op_ctxt.pred.name)
                branch = op_ctxt.branch
                result = control_flow_ops._SwitchRefOrTensor(result,
                                                             pred)[1 - branch]
        else:
            # Unknown shape so keep a history of the shape at runtime.
            if dead_branch:
                # Need to add a special switch to guard the value.
                pred = op_ctxt.pred
                branch = op_ctxt.branch
                op_ctxt.outer_context.Enter()
                val = control_flow_ops._SwitchRefOrTensor(op.inputs[0],
                                                          pred)[1 - branch]
                zeros_shape = array_ops.shape_internal(val, optimize=False)
                op_ctxt.outer_context.Exit()
                val.op._set_control_flow_context(op_ctxt)
                zeros_shape.op._set_control_flow_context(op_ctxt)
            else:
                op_ctxt.Enter()
                zeros_shape = array_ops.shape_internal(val, optimize=False)
                op_ctxt.Exit()

            # Add forward accumulator for shape.
            grad_state.grad_context.Exit()
            history_zeros_shape = grad_state.AddForwardAccumulator(
                zeros_shape, dead_branch=dead_branch)
            grad_state.grad_context.Enter()

            # Create a zero tensor with the right shape.
            shape = grad_state.AddBackpropAccumulatedValue(
                history_zeros_shape, zeros_shape, dead_branch)
            result = array_ops.zeros(shape, val.dtype)
        return result
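
Outside a gradient while loop, the same static/dynamic split shows up directly:
a fully defined shape lets the zero tensor be built as a constant, while an
unknown shape forces it to be built from the runtime shape, which is what the
accumulated zeros_shape supplies inside the loop. A minimal sketch of the two
cases, assuming TensorFlow 1.x and no control flow at all:

import tensorflow as tf

static = tf.constant([[1.0, 2.0], [3.0, 4.0]])          # shape (2, 2) is known
dynamic = tf.placeholder(tf.float32, shape=[None, 2])   # leading dim unknown

# Static shape: a plain constant of zeros with the known dims.
zeros_static = tf.constant(0.0, shape=static.get_shape().as_list(),
                           dtype=static.dtype)
# Dynamic shape: build zeros from the shape computed at runtime.
zeros_dynamic = tf.zeros(tf.shape(dynamic), dtype=dynamic.dtype)

with tf.Session() as sess:
  print(sess.run(zeros_static))
  print(sess.run(zeros_dynamic, feed_dict={dynamic: [[1.0, 2.0]]}))
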