Esempio n. 1
0
 def test_op_is_in_context(self):
     """Checks that ops built in an XLACompileContext count as XLA ops.

     One constant is created before the context is entered and one while
     it is active; only the second is expected to be in an XLA context.
     """
     outside = constant_op.constant(1)

     xla_ctx = self.create_test_xla_compile_context()
     xla_ctx.Enter()
     inside = constant_op.constant(2)
     xla_ctx.Exit()

     in_xla = control_flow_util.IsInXLAContext
     self.assertFalse(in_xla(outside.op))
     self.assertTrue(in_xla(inside.op))
Esempio n. 2
0
 def testIsInContext(self):
     """Verifies IsInXLAContext is true only for ops made in a TPU context.

     An identity op is built before the TPUReplicateContext is entered and
     another while it is active; only the latter should be reported.
     """
     before = array_ops.identity(1)

     tpu_ctx = tpu.TPUReplicateContext(b"context")
     tpu_ctx.Enter()
     during = array_ops.identity(1)
     tpu_ctx.Exit()

     in_xla = control_flow_util.IsInXLAContext
     self.assertFalse(in_xla(before.op))
     self.assertTrue(in_xla(during.op))
Esempio n. 3
0
 def testIsInContext(self):
     """Verifies IsInXLAContext is true only for ops made in a TPU context.

     Runs inside a fresh graph: one identity op is built before the
     TPUReplicateContext is entered and one while it is active; only the
     latter should be reported as being in an XLA context.
     """
     with ops.Graph().as_default():
         before = array_ops.identity(1)

         pivot_op = control_flow_ops.no_op()
         tpu_ctx = tpu.TPUReplicateContext(b"context", 1, pivot=pivot_op)
         tpu_ctx.Enter()
         during = array_ops.identity(1)
         tpu_ctx.Exit()

         in_xla = control_flow_util.IsInXLAContext
         self.assertFalse(in_xla(before.op))
         self.assertTrue(in_xla(during.op))
Esempio n. 4
0
def _maybe_set_maximum_iterations_attr(op, maximum_iterations):
    if maximum_iterations is not None and control_flow_util.IsInXLAContext(op):
        # Store the maximum_iterations to use in the gradient pass.
        op._set_attr(  # pylint: disable=protected-access
            "_maximum_iterations",
            attr_value_pb2.AttrValue(
                i=tensor_util.constant_value(maximum_iterations)))
def maybe_set_lowering_attr(op):
    """Flags `op` for lowering unless one of the opt-out conditions holds.

    Lowering lets cond_v2 and while_v2 sidestep some limitations of
    Functions: users can specify devices & colocation inside the cond_v2 and
    while_v2 input functions, and non-strict evaluation & partial pruning
    become possible. This moves v2 control flow closer to feature parity
    with v1 control flow.

    The flag is not set in either of these cases:
      - The `If` or `While` op is in the XLA context, because XLA can apply
        its own optimizations more easily to un-lowered control flow
        operators than to low-level control flow primitives.
      - The eager execution context selects the single threaded executor for
        functions (see context.function_executor_type()), because that
        executor does not support v1 control flow ops.

    Args:
      op: An `If` or `While` Operation.
    """
    if control_flow_util.IsInXLAContext(op):
        return

    call_options = context.context().get_function_call_options()
    if call_options.executor_type == "SINGLE_THREADED_EXECUTOR":
        return

    lowering_flag = attr_value_pb2.AttrValue(b=True)
    op._set_attr("_lower_using_switch_merge", lowering_flag)  # pylint: disable=protected-access
Esempio n. 6
0
 def _is_xla_tensor(tensor):
     try:
         op = tensor.op
     except AttributeError:
         return False
     if control_flow_util.IsInXLAContext(op):
         return True
     return False
Esempio n. 7
0
def _maybe_set_lowering_attr(op):
  """Flags the `While` op for lowering unless it is in an XLA context.

  Lowering lets while_v2 sidestep some limitations of Functions: users can
  specify devices & colocation inside while_v2 branches, and non-strict
  evaluation & partial pruning of those branches become possible. This
  moves while_v2 closer to feature parity with tf.while_loop.

  `While` is left un-lowered in the XLA context because XLA can apply its
  own optimizations more easily to un-lowered `While` operators than to
  low-level control flow primitives.

  Args:
    op: The While op.
  """
  if control_flow_util.IsInXLAContext(op):
    return

  lowering_flag = attr_value_pb2.AttrValue(b=True)
  op._set_attr("_lower_using_switch_merge", lowering_flag)  # pylint: disable=protected-access
Esempio n. 8
0
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  Each branch callable is turned into its own function graph, the two graphs
  are made to agree on inputs and on extra (intermediate) outputs, and a
  single `If` op is created that references both.

  Args:
    pred: The predicate handed to the `If` op.
    true_fn: Callable used to build the true-branch function graph. It is
      passed to `func_graph_from_py_func` with empty args and kwargs.
    false_fn: Callable used to build the false-branch function graph; handled
      the same way as `true_fn`.
    name: Name scope for the op. A falsy value is replaced with "cond".

  Returns:
    A tuple of the first `num_cond_outputs` outputs of the `If` op, where
    `num_cond_outputs` is the number of true-branch outputs counted before
    the intermediate tensors are appended as extra outputs.
  """
  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    # The branch function names are derived from `scope` but computed inside
    # name_scope(None).
    # NOTE(review): presumably this keeps unique_name() from prefixing the
    # current scope a second time -- confirm.
    with ops.name_scope(None):
      # Find the outer most graph for uniquing function names.
      # TODO(jpienaar): Make this work in eager mode.
      graph = ops.get_default_graph()
      while isinstance(graph, _function.FuncGraph):
        graph = graph.outer_graph

      true_name = graph.unique_name(("%strue" % scope).replace("/", "_"))
      false_name = graph.unique_name(("%sfalse" % scope).replace("/", "_"))

    true_graph = _function.func_graph_from_py_func(
        true_name, true_fn, [], {})
    false_graph = _function.func_graph_from_py_func(
        false_name, false_fn, [], {})
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.external_captures,
                                     false_graph.external_captures)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred, cond_inputs, [t.dtype for t in true_graph.outputs],
        _create_new_tf_function(true_graph),
        _create_new_tf_function(false_graph),
        name=scope)

    # Set the flag to enable lowering on the `if` op if necessary
    # Lowering allows cond_v2 to avoid some of the limitations of Functions,
    # allowing users to specify devices & colocation inside of cond_v2 branches,
    # and enabling non-strict evaluation & partial pruning of cond_v2 branches.
    # This brings cond_v2 closer to feature parity with tf.cond.
    #
    # However, we do not lower `If` in the XLA context because it is easier for
    # XLA to apply its own optimizations when dealing with un-lowered `If`
    # operators than with lowered switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    # The op is recovered from the first output tensor, so an `If` with no
    # outputs would fail on the indexing below.
    if_op = tensors[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      # pylint: disable=protected-access
      if_op._set_attr("_lower_using_switch_merge",
                      attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access

    # Only the caller-visible outputs are returned; the appended intermediates
    # stay on the op.
    return tuple(tensors[:num_cond_outputs])
    def AddForwardAccumulator(self, value, dead_branch=False):
        """Add an accumulator for each forward tensor that is needed in backprop.

        This is added to the forward loop at the first time when a tensor
        in the forward loop is used by backprop gradient computation loop.
        We create an accumulator that accumulates the value of tensor at each
        iteration. Called in the control flow context where gradients() is
        called.

        The pseudocode is:
        ```
          acc = stack();
          while (_pivot) {
            acc = stack_push(acc, value);
          }
        ```

        We make sure that the stack push op in one iteration is executed
        before next iteration. This is achieved by adding a control edge from
        `forward_index.op.inputs[0].op` to the push op, and another control
        edge from the push op to either `forward_index.op` or `forward_sync`.

        Args:
          value: The source tensor in forward that is to be accumulated.
          dead_branch: True iff the tensor is on a dead branch of a cond.

        Returns:
          The stack that contains the accumulated history of the tensor.

        Raises:
          TypeError: For internal errors involving the value condition
            context.
          ValueError: If `value` is inside an XLA scope and a valid max size
            for the stack can't be found.
        """
        # curr_ctxt is the context that tf.gradients was called in.
        with self._forward_index.graph.as_default():
            curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
            with ops.control_dependencies(None):
                # The stack itself is created with curr_ctxt entered (when
                # there is one); that context is exited again before the
                # push op is built below.
                if curr_ctxt:
                    curr_ctxt.Enter()
                with ops.colocate_with(value):
                    # We only need to pass maximum_iterations to the stack if
                    # we're inside an XLA context.
                    if not util.IsInXLAContext(value.op):
                        # Outside XLA the stack is created with max_size -1.
                        max_size = constant_op.constant(-1, dtypes.int32)
                    else:
                        max_size = _GetMaxSizeFromNestedMaximumIterations(
                            value, self.forward_context)
                    acc = gen_data_flow_ops.stack_v2(
                        max_size=max_size,
                        elem_type=value.dtype.base_dtype,
                        name="f_acc")
                if curr_ctxt:
                    curr_ctxt.Exit()

                # Make acc available in the forward context.
                enter_acc = self.forward_context.AddValue(acc)

                # Add the stack_push op in the context of value.op.
                swap_enabled = self.forward_context.swap_memory
                value_ctxt = util.GetOutputContext(value.op)
                if value_ctxt == self.forward_context:
                    # value is not nested in the forward context.
                    self.forward_context.Enter()
                    push = gen_data_flow_ops.stack_push_v2(
                        enter_acc, value, swap_memory=swap_enabled)
                    self.forward_context.Exit()
                    # Protect stack push and order it before forward_index.
                    self.forward_index.op._add_control_input(push.op)
                else:
                    # value is in a cond context within the forward context.
                    if not isinstance(value_ctxt,
                                      control_flow_ops.CondContext):
                        raise TypeError("value_ctxt is not a CondContext: %s" %
                                        value_ctxt)
                    if dead_branch:
                        # The special case for creating a zero tensor for a dead
                        # branch of a switch. See _ControlFlowState.ZerosLike().
                        # The push is built in value_ctxt's outer context and
                        # afterwards re-tagged as belonging to value_ctxt.
                        value_ctxt.outer_context.Enter()
                        push = gen_data_flow_ops.stack_push_v2(
                            enter_acc, value, swap_memory=swap_enabled)
                        value_ctxt.outer_context.Exit()
                        push.op._set_control_flow_context(value_ctxt)
                    else:
                        # Live branch: build the push directly inside the
                        # cond context that produced value.
                        value_ctxt.Enter()
                        push = gen_data_flow_ops.stack_push_v2(
                            enter_acc, value, swap_memory=swap_enabled)
                        value_ctxt.Exit()
                    # Protect stack push and order it before forward_sync.
                    self.forward_sync._add_control_input(push.op)
                # Order stack push after the successor of forward_index
                add_op = self.forward_index.op.inputs[0].op
                push.op._add_control_input(add_op)
                return acc
Esempio n. 10
0
def cond_v2(pred, true_fn, false_fn, name="cond"):
    """Like tf.cond, except emits a single If op.

    Each branch callable is traced into its own function graph, the two
    graphs are made to agree on inputs and on extra (intermediate) outputs,
    and a single `If` op is created that references both.

    Args:
      pred: The predicate handed to the `If` op. Must not be a Python bool.
      true_fn: Callable used to build the true-branch function graph. It is
        passed to `func_graph_from_py_func` with empty args and kwargs.
      false_fn: Callable used to build the false-branch function graph;
        handled the same way as `true_fn`.
      name: Name scope for the op. A falsy value is replaced with "cond".

    Returns:
      The caller-visible outputs of the `If` op, each wrapped in an identity
      op: a single tensor when there is exactly one such output, otherwise a
      tuple of tensors.

    Raises:
      TypeError: If `pred` is a Python bool.
    """
    if isinstance(pred, bool):
        # Format the offending value into the message. Passing it as a second
        # exception argument made str(exc) render as a tuple repr.
        raise TypeError("pred must not be a Python bool: %r" % (pred,))

    if not name:
        name = "cond"

    with ops.name_scope(name) as scope:
        with ops.name_scope(None):
            # Find the outer most graph for uniquing function names.
            # TODO(jpienaar): Make this work in eager mode.
            graph = ops.get_default_graph()
            while isinstance(graph, function.FuncGraph):
                graph = graph.outer_graph

            true_name = graph.unique_name(("%strue" % scope).replace("/", "_"))
            false_name = graph.unique_name(
                ("%sfalse" % scope).replace("/", "_"))

        # Automatic control dependencies are added in defuns, but not in v1
        # graphs. Propagate that behavior here.
        add_control_dependencies = util.in_defun()

        true_graph = function.func_graph_from_py_func(
            true_name,
            true_fn, [], {},
            func_graph=util.CondBranchFuncGraph(true_name),
            add_control_dependencies=add_control_dependencies)
        false_graph = function.func_graph_from_py_func(
            false_name,
            false_fn, [], {},
            func_graph=util.CondBranchFuncGraph(false_name),
            add_control_dependencies=add_control_dependencies)
        _check_same_outputs(true_graph, false_graph)

        # Add inputs to true_graph and false_graph to make them match. Note that
        # this modifies true_graph and false_graph.
        cond_inputs = _make_inputs_match(true_graph, false_graph,
                                         true_graph.external_captures,
                                         false_graph.external_captures)

        # Add all intermediate tensors as function outputs so they're available for
        # the gradient computation.

        true_intermediates = _get_intermediates(true_graph)
        false_intermediates = _get_intermediates(false_graph)

        # Save the original number of outputs to return to the caller.
        num_cond_outputs = len(true_graph.outputs)

        # Make the number/type of new intermediate outputs match.
        extra_true_outputs, extra_false_outputs = _pad_params(
            true_graph, false_graph, true_intermediates, false_intermediates)

        true_graph.outputs.extend(extra_true_outputs)
        false_graph.outputs.extend(extra_false_outputs)

        # Create the If op.
        tensors = gen_functional_ops._if(  # pylint: disable=protected-access
            pred,
            cond_inputs, [t.dtype for t in true_graph.outputs],
            util.create_new_tf_function(true_graph),
            util.create_new_tf_function(false_graph),
            output_shapes=_get_output_shapes(true_graph.outputs,
                                             false_graph.outputs),
            name=scope)

        # Set the flag to enable lowering on the `if` op if necessary
        # Lowering allows cond_v2 to avoid some of the limitations of Functions,
        # allowing users to specify devices & colocation inside of cond_v2 branches,
        # and enabling non-strict evaluation & partial pruning of cond_v2 branches.
        # This brings cond_v2 closer to feature parity with tf.cond.
        #
        # However, we do not lower `If` in the XLA context because it is easier for
        # XLA to apply its own optimizations when dealing with un-lowered `If`
        # operators than with lowered switch/merge control flow.
        #
        # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
        # The op is recovered from the first output tensor, so an `If` with no
        # outputs would fail on the indexing below.
        if_op = tensors[0].op
        if not control_flow_util.IsInXLAContext(if_op):
            # pylint: disable=protected-access
            if_op._set_attr("_lower_using_switch_merge",
                            attr_value_pb2.AttrValue(b=True))
            # pylint: enable=protected-access

        # Return identities for each output of the If op, rather than the output of
        # the If op directly. This makes pruning work if the output of cond() is
        # fetched: the lowering pass converts the If outputs into IdentityN outputs,
        # which if fetched will cause all ops in the taken branch to be run (since
        # it takes all merge ops as input). After lowering, each output identity op
        # will end up with only the appropriate merge op as input.
        # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
        # correct output structure
        tensors = tuple(array_ops.identity(t) for t in tensors)

        # Drop the appended intermediates; unwrap a lone output so callers get
        # the tensor itself rather than a 1-tuple.
        result = tuple(tensors[:num_cond_outputs])
        if len(result) == 1:
            return result[0]
        else:
            return result
Esempio n. 11
0
def cond_v2(pred, true_fn, false_fn, name="cond"):
    """Builds a conditional as one `If` op, in the manner of tf.cond.

    Both branch callables are traced into function graphs that inherit the
    caller graph's innermost device function, colocation stack, container and
    collections. The graphs are then made to agree on inputs and on extra
    (intermediate) outputs before the `If` op is created.

    Args:
      pred: The predicate handed to the `If` op.
      true_fn: Callable used to build the true-branch function graph.
      false_fn: Callable used to build the false-branch function graph.
      name: Name scope for the op. A falsy value is replaced with "cond".

    Returns:
      The leading `If` outputs, i.e. as many as the true branch produced
      before intermediate tensors were appended as extra outputs.
    """
    name = name or "cond"

    with ops.name_scope(name) as scope:
        # pylint: disable=protected-access
        caller_graph = ops.get_default_graph()

        # Identify if there is a caller device, & get the innermost if possible.
        device_stack = caller_graph._device_function_stack
        if device_stack:
            caller_device = device_stack[-1]
        else:
            caller_device = None

        # Settings both branch graphs inherit from the calling graph.
        inherited = dict(
            device=caller_device,
            colocation_stack=caller_graph._colocation_stack,
            container=caller_graph._container,
            collections_ref=caller_graph._collections)

        prefix = scope.replace("/", "_")

        true_graph = _function.func_graph_from_py_func(
            true_fn, [], [], name="%strue" % prefix, **inherited)
        false_graph = _function.func_graph_from_py_func(
            false_fn, [], [], name="%sfalse" % prefix, **inherited)
        _check_same_outputs(true_graph, false_graph)

        # Give both graphs the same inputs. Note that this modifies
        # true_graph and false_graph.
        branch_inputs = _make_inputs_match(
            true_graph, false_graph,
            true_graph.extra_inputs, false_graph.extra_inputs)

        # Expose every intermediate tensor as a function output so the
        # gradient computation can reach it.
        true_extras = _get_intermediates(true_graph)
        false_extras = _get_intermediates(false_graph)

        # Remember how many outputs the caller should get back.
        num_cond_outputs = len(true_graph.outputs)

        # Make the number/type of new intermediate outputs match.
        padded_true, padded_false = _pad_params(
            true_graph, false_graph, true_extras, false_extras)

        true_graph.outputs.extend(padded_true)
        false_graph.outputs.extend(padded_false)

        # Create the If op.
        if_outputs = gen_functional_ops._if(
            pred,
            branch_inputs,
            [out.dtype for out in true_graph.outputs],
            _create_new_tf_function(true_graph),
            _create_new_tf_function(false_graph),
            name=scope)

        # Enable lowering on the `If` op where appropriate. Lowering lets
        # cond_v2 avoid some limitations of Functions: users can specify
        # devices & colocation inside cond_v2 branches, and non-strict
        # evaluation & partial pruning of the branches become possible,
        # which moves cond_v2 closer to feature parity with tf.cond.
        #
        # `If` is left un-lowered in the XLA context because XLA can apply
        # its own optimizations more easily to un-lowered `If` operators
        # than to lowered switch/merge control flow.
        #
        # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
        if_op = if_outputs[0].op
        if not control_flow_util.IsInXLAContext(if_op):
            lowering_flag = attr_value_pb2.AttrValue(b=True)
            if_op._set_attr("_lower_using_switch_merge", lowering_flag)
        # pylint: enable=protected-access

        return if_outputs[:num_cond_outputs]