Example #1
    def AddWhileContext(self, op, between_op_list, between_ops):
        """Add the grad state for the while loop that op belongs to.

    Note that op is an Exit, and this method must be called in
    the control flow context where gradients() is called.

    Note that this method modifies `between_op_list` and `between_ops`.
    """
        forward_ctxt = util.GetWhileContext(op)
        grad_state = self._map.get(forward_ctxt)
        if grad_state is None:
            # This is a new while loop so create a grad state for it.
            outer_forward_ctxt = forward_ctxt.outer_context
            if outer_forward_ctxt:
                outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
            outer_grad_state = None
            if outer_forward_ctxt:
                outer_grad_state = self._map.get(outer_forward_ctxt)
            grad_state = _GradLoopState(forward_ctxt, outer_grad_state)
            self._map[forward_ctxt] = grad_state

            # We need to include all exits of a loop for backprop.
            for loop_exit in grad_state.forward_loop_exits:
                if loop_exit.op not in between_ops:
                    between_ops.add(loop_exit.op)
                    between_op_list.append(loop_exit.op)
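A minimal usage sketch, assuming the TF 1.x graph API via tf.compat.v1 (AddWhileContext itself is internal and is invoked for you): calling tf.gradients() on a while loop is what exercises this bookkeeping, since the backprop walk reaches the loop's Exit ops.

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

i0 = tf.constant(0)
x0 = tf.constant(1.0)
# Forward loop: x doubles five times.
_, x_final = tf.while_loop(lambda i, x: i < 5,
                           lambda i, x: (i + 1, 2.0 * x),
                           [i0, x0])
# gradients() walks the ops between x_final and x0; on reaching the loop's
# Exit it calls AddWhileContext, which creates the _GradLoopState and pulls
# the loop's remaining exits into between_op_list/between_ops.
dx, = tf.gradients(x_final, [x0])
with tf.Session() as sess:
    print(sess.run(dx))  # 32.0 == 2.0 ** 5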
Example #2
    def GetGradState(self, op, before):
        """Return the grad state for this op if it's in a forward loop context."""
        if before and util.IsLoopExit(op):
            forward_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
            forward_ctxt = forward_ctxt.outer_context
            if forward_ctxt:
                forward_ctxt = forward_ctxt.GetWhileContext()
        else:
            forward_ctxt = util.GetWhileContext(op)
        if forward_ctxt:
            return self._map.get(forward_ctxt)
        return None
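A hypothetical, simplified model (all names invented for illustration; this is not the TF API) of the boundary rule GetGradState encodes: an Exit op belongs to the loop it leaves, so when looking "before" it on the backprop walk, the relevant grad state is the outer loop's.

class FakeWhileCtxt:  # hypothetical stand-in for a WhileContext
    def __init__(self, outer=None):
        self.outer_context = outer

    def GetWhileContext(self):
        return self

outer = FakeWhileCtxt()
inner = FakeWhileCtxt(outer=outer)
grad_state_map = {outer: "outer grad state", inner: "inner grad state"}

def get_grad_state(ctxt, is_loop_exit, before):
    # Mirrors the branch structure above: step out to the enclosing loop
    # when the op is an Exit and we want the state "before" it.
    if before and is_loop_exit:
        ctxt = ctxt.outer_context
        if ctxt:
            ctxt = ctxt.GetWhileContext()
    return grad_state_map.get(ctxt) if ctxt else None

print(get_grad_state(inner, is_loop_exit=True, before=True))    # outer grad state
print(get_grad_state(inner, is_loop_exit=False, before=False))  # inner grad state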
Example #3
    def ZerosLike(self, op, index):
        """Create zeros_like for the specified output of an op.

    If op is in a while loop that is part of gradients(), this method
    must be called in its grad loop context.

    Args:
      op: A tensorflow operation.
      index: the index for a specific output of the op.

    Returns:
      A zero tensor of the same shape of op.outputs[index].
    """
        if util.IsLoopSwitch(op):
            return None
        if op.graph._building_function:  # pylint: disable=protected-access
            # The optimization here is tricky to apply to functions
            return array_ops.zeros_like(op.outputs[index])
        dead_branch = util.IsSwitch(op)
        forward_ctxt = util.GetWhileContext(op)
        grad_state = self._map.get(forward_ctxt)
        if grad_state is None:
            # op is not in a while loop that is part of gradients().
            return ZerosLikeOutsideLoop(op, index)
        op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
        val = ops.convert_to_tensor(op.outputs[index], name="tensor")
        shape = val.get_shape()
        if shape.is_fully_defined():
            # If the shape is known statically, just create a zero tensor with
            # the right shape in the grad loop context.
            result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
            if dead_branch:
                # op is a cond switch. Guard the zero tensor with a switch.
                pred = grad_state.history_map.get(op_ctxt.pred.name)
                branch = op_ctxt.branch
                result = control_flow_ops._SwitchRefOrTensor(result,
                                                             pred)[1 - branch]
        else:
            # Unknown shape so keep a history of the shape at runtime.
            if dead_branch:
                # Need to add a special switch to guard the value.
                pred = op_ctxt.pred
                branch = op_ctxt.branch
                op_ctxt.outer_context.Enter()
                val = control_flow_ops._SwitchRefOrTensor(op.inputs[0],
                                                          pred)[1 - branch]
                zeros_shape = array_ops.shape_internal(val, optimize=False)
                op_ctxt.outer_context.Exit()
                val.op._set_control_flow_context(op_ctxt)
                zeros_shape.op._set_control_flow_context(op_ctxt)
            else:
                op_ctxt.Enter()
                zeros_shape = array_ops.shape_internal(val, optimize=False)
                op_ctxt.Exit()

            # Add forward accumulator for shape.
            grad_state.grad_context.Exit()
            history_zeros_shape = grad_state.AddForwardAccumulator(
                zeros_shape, dead_branch=dead_branch)
            grad_state.grad_context.Enter()

            # Create a zero tensor with the right shape.
            shape = grad_state.AddBackpropAccumulatedValue(
                history_zeros_shape, zeros_shape, dead_branch)
            result = array_ops.zeros(shape, val.dtype)
        return result
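A minimal sketch, assuming TF 1.x via tf.compat.v1 and ignoring any loop context, of the two shape strategies this method picks between: a statically known shape yields a constant zero tensor at graph-construction time, while an unknown shape must be read back at runtime with a shape op.

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

static_val = tf.ones([2, 3])
shape = static_val.get_shape()
if shape.is_fully_defined():
    # Shape known statically: emit a constant zero tensor directly.
    z_static = tf.constant(0, shape=shape.as_list(), dtype=static_val.dtype)

dynamic_val = tf.placeholder(tf.float32, shape=[None, 3])
# First dimension only known at runtime: compute the shape with a shape op
# and hand it to tf.zeros instead of baking in a constant.
z_dynamic = tf.zeros(tf.shape(dynamic_val), dtype=dynamic_val.dtype)
print(z_static.shape)   # fully defined: (2, 3)
print(z_dynamic.shape)  # first dimension unknown until runtime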