def test_op_is_in_context(self):
  """Tests that XLACompileContext is recognized as an XLA context."""
  op1 = constant_op.constant(1)
  context = self.create_test_xla_compile_context()
  context.Enter()
  op2 = constant_op.constant(2)
  context.Exit()
  self.assertFalse(control_flow_util.IsInXLAContext(op1.op))
  self.assertTrue(control_flow_util.IsInXLAContext(op2.op))

def testIsInContext(self):
  """Test that control_flow_util can check that we're in a TPU context."""
  z1 = array_ops.identity(1)
  context = tpu.TPUReplicateContext(b"context")
  context.Enter()
  z2 = array_ops.identity(1)
  context.Exit()
  self.assertFalse(control_flow_util.IsInXLAContext(z1.op))
  self.assertTrue(control_flow_util.IsInXLAContext(z2.op))

def testIsInContext(self):
  """Test that control_flow_util can check that we're in a TPU context."""
  with ops.Graph().as_default():
    z1 = array_ops.identity(1)
    pivot = control_flow_ops.no_op()
    context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot)
    context.Enter()
    z2 = array_ops.identity(1)
    context.Exit()
    self.assertFalse(control_flow_util.IsInXLAContext(z1.op))
    self.assertTrue(control_flow_util.IsInXLAContext(z2.op))

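# Both tests above follow the same Enter()/Exit() pattern: ops created while
# a context is active adopt it as their control flow context, and
# IsInXLAContext then finds it by walking outward through enclosing contexts.
# Below is a minimal, self-contained sketch of that walk; the _Ctx class and
# _is_in_xla_context helper are hypothetical stand-ins, not TensorFlow APIs.
class _Ctx(object):
  """Stand-in for a control flow context; `is_xla` marks XLA-style ones."""

  def __init__(self, outer=None, is_xla=False):
    self.outer_context = outer
    self.is_xla = is_xla


def _is_in_xla_context(ctx):
  # Walk outward through enclosing contexts, mirroring how IsInXLAContext
  # traverses an op's control flow context chain.
  while ctx is not None:
    if ctx.is_xla:
      return True
    ctx = ctx.outer_context
  return False


outer = _Ctx()
xla = _Ctx(outer=outer, is_xla=True)
inner = _Ctx(outer=xla)
assert not _is_in_xla_context(outer)
assert _is_in_xla_context(inner)  # nested inside an XLA context
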
def _maybe_set_maximum_iterations_attr(op, maximum_iterations):
  if maximum_iterations is not None and control_flow_util.IsInXLAContext(op):
    # Store the maximum_iterations to use in the gradient pass.
    op._set_attr(  # pylint: disable=protected-access
        "_maximum_iterations",
        attr_value_pb2.AttrValue(
            i=tensor_util.constant_value(maximum_iterations)))

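# The attribute stored above can later be read back in the gradient pass via
# Operation.get_attr. A hedged sketch of that read side; the helper name is
# hypothetical, not part of TensorFlow:
def _get_maximum_iterations(op):
  """Hypothetical gradient-pass helper: recover the bound stored above."""
  try:
    return op.get_attr("_maximum_iterations")  # a plain Python int
  except ValueError:
    # Attr absent: the op was not in an XLA context, or no
    # maximum_iterations was supplied.
    return None
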
def maybe_set_lowering_attr(op):
  """Sets the flag to enable lowering on `op` if necessary.

  Lowering allows cond_v2 and while_v2 to avoid some of the limitations of
  Functions, allowing users to specify devices & colocation inside of cond_v2
  and while_v2 input functions, and enabling non-strict evaluation & partial
  pruning. This brings v2 control flow closer to feature parity with v1
  control flow.

  However, we do not lower in the following cases:
    - When the `If` or `While` ops are in the XLA context. Because it is
      easier for XLA to apply its own optimizations when dealing with
      un-lowered control flow operators than with low-level control flow
      primitives.
    - When the eager execution context specifies the executor of functions to
      be the single threaded executor (see context.function_executor_type()).
      Because the single threaded executor does not support v1 control flow
      ops.

  Args:
    op: An `If` or `While` Operation.
  """
  if (not control_flow_util.IsInXLAContext(op) and
      context.context().get_function_call_options().executor_type
      != "SINGLE_THREADED_EXECUTOR"):
    # pylint: disable=protected-access
    op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))

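# A hedged sketch of how this helper replaces the inline attr-setting seen in
# the cond_v2 implementations further below. It assumes `tensors` holds the
# outputs of a freshly built If op (the wrapper name is hypothetical):
def _finalize_if_op(tensors):
  """Hypothetical wiring: opt a freshly built If op into lowering."""
  if_op = tensors[0].op
  # No-op inside an XLA context or under the single threaded executor;
  # otherwise sets "_lower_using_switch_merge" on the op.
  maybe_set_lowering_attr(if_op)
  return if_op
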
def _is_xla_tensor(tensor):
  try:
    op = tensor.op
  except AttributeError:
    return False
  if control_flow_util.IsInXLAContext(op):
    return True
  return False

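# The try/except guard makes the helper safe on arbitrary inputs, not just
# tensors. A quick illustration (plain Python, no graph needed):
assert not _is_xla_tensor(42)    # no .op attribute -> False, no raise
assert not _is_xla_tensor(None)  # likewise
# For a real tf.Tensor the result depends on whether its producing op was
# created inside an XLA control flow context.
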
def _maybe_set_lowering_attr(op):
  """Sets the flag to enable lowering on the `While` op if necessary.

  Lowering allows while_v2 to avoid some of the limitations of Functions,
  allowing users to specify devices & colocation inside of while_v2 branches,
  and enabling non-strict evaluation & partial pruning of while_v2 branches.
  This brings while_v2 closer to feature parity with tf.while_loop.

  However, we do not lower `While` in the XLA context because it is easier
  for XLA to apply its own optimizations when dealing with un-lowered `While`
  operators than with low-level control flow primitives.

  Args:
    op: The While op.
  """
  if not control_flow_util.IsInXLAContext(op):
    # pylint: disable=protected-access
    op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))

def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      # Find the outer most graph for uniquing function names.
      # TODO(jpienaar): Make this work in eager mode.
      graph = ops.get_default_graph()
      while isinstance(graph, _function.FuncGraph):
        graph = graph.outer_graph
      true_name = graph.unique_name(("%strue" % scope).replace("/", "_"))
      false_name = graph.unique_name(("%sfalse" % scope).replace("/", "_"))

    true_graph = _function.func_graph_from_py_func(true_name, true_fn, [], {})
    false_graph = _function.func_graph_from_py_func(
        false_name, false_fn, [], {})
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.external_captures,
                                     false_graph.external_captures)

    # Add all intermediate tensors as function outputs so they're available
    # for the gradient computation.
    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred, cond_inputs, [t.dtype for t in true_graph.outputs],
        _create_new_tf_function(true_graph),
        _create_new_tf_function(false_graph),
        name=scope)

    # Set the flag to enable lowering on the `if` op if necessary.
    # Lowering allows cond_v2 to avoid some of the limitations of Functions,
    # allowing users to specify devices & colocation inside of cond_v2
    # branches, and enabling non-strict evaluation & partial pruning of
    # cond_v2 branches. This brings cond_v2 closer to feature parity with
    # tf.cond.
    #
    # However, we do not lower `If` in the XLA context because it is easier
    # for XLA to apply its own optimizations when dealing with un-lowered
    # `If` operators than with lowered switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1
    # output
    if_op = tensors[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      # pylint: disable=protected-access
      if_op._set_attr("_lower_using_switch_merge",
                      attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access

    return tuple(tensors[:num_cond_outputs])

def AddForwardAccumulator(self, value, dead_branch=False):
  """Add an accumulator for each forward tensor that is needed in backprop.

  This is added to the forward loop at the first time when a tensor in the
  forward loop is used by the backprop gradient computation loop. We create
  an accumulator that accumulates the value of the tensor at each iteration.
  Called in the control flow context where gradients() is called.

  The pseudocode is:
  ```
  acc = stack();
  while (_pivot) {
    acc = stack_push(acc, value);
  }
  ```

  We make sure that the stack push op in one iteration is executed before
  next iteration. This is achieved by adding a control edge from
  `forward_index.op.inputs[0].op` to the push op, and another control edge
  from the push op to either `forward_index.op` or `forward_sync`.

  Args:
    value: The source tensor in forward that is to be accumulated.
    dead_branch: True iff the tensor is on a dead branch of a cond.

  Returns:
    The stack that contains the accumulated history of the tensor.

  Raises:
    TypeError: For internal errors involving the value condition context.
    ValueError: If `value` is inside an XLA scope and a valid max size
      for the stack can't be found.
  """
  # curr_ctxt is the context that tf.gradients was called in.
  with self._forward_index.graph.as_default():
    curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
    with ops.control_dependencies(None):
      if curr_ctxt:
        curr_ctxt.Enter()
      with ops.colocate_with(value):
        # We only need to pass maximum_iterations to the stack if
        # we're inside an XLA context.
        if not util.IsInXLAContext(value.op):
          max_size = constant_op.constant(-1, dtypes.int32)
        else:
          max_size = _GetMaxSizeFromNestedMaximumIterations(
              value, self.forward_context)
        acc = gen_data_flow_ops.stack_v2(
            max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
      if curr_ctxt:
        curr_ctxt.Exit()

      # Make acc available in the forward context.
      enter_acc = self.forward_context.AddValue(acc)

      # Add the stack_push op in the context of value.op.
      swap_enabled = self.forward_context.swap_memory
      value_ctxt = util.GetOutputContext(value.op)
      if value_ctxt == self.forward_context:
        # value is not nested in the forward context.
        self.forward_context.Enter()
        push = gen_data_flow_ops.stack_push_v2(
            enter_acc, value, swap_memory=swap_enabled)
        self.forward_context.Exit()
        # Protect stack push and order it before forward_index.
        self.forward_index.op._add_control_input(push.op)
      else:
        # value is in a cond context within the forward context.
        if not isinstance(value_ctxt, control_flow_ops.CondContext):
          raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
        if dead_branch:
          # The special case for creating a zero tensor for a dead
          # branch of a switch. See _ControlFlowState.ZerosLike().
          value_ctxt.outer_context.Enter()
          push = gen_data_flow_ops.stack_push_v2(
              enter_acc, value, swap_memory=swap_enabled)
          value_ctxt.outer_context.Exit()
          push.op._set_control_flow_context(value_ctxt)
        else:
          value_ctxt.Enter()
          push = gen_data_flow_ops.stack_push_v2(
              enter_acc, value, swap_memory=swap_enabled)
          value_ctxt.Exit()
        # Protect stack push and order it before forward_sync.
        self.forward_sync._add_control_input(push.op)
      # Order stack push after the successor of forward_index.
      add_op = self.forward_index.op.inputs[0].op
      push.op._add_control_input(add_op)
      return acc

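# The docstring's pseudocode boils down to a push-per-forward-iteration
# discipline, with backprop popping the history in reverse. A pure-Python
# analogue of that pairing (names hypothetical; plain lists stand in for the
# stack_v2/stack_push_v2 resource ops):
def _forward(values):
  acc = []           # acc = stack()
  for v in values:   # while (_pivot)
    acc.append(v)    #   acc = stack_push(acc, value)
  return acc


def _backward(acc):
  while acc:
    yield acc.pop()  # gradients consume the history last-in, first-out


history = _forward([1, 2, 3])
assert list(_backward(history)) == [3, 2, 1]
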
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      # Find the outer most graph for uniquing function names.
      # TODO(jpienaar): Make this work in eager mode.
      graph = ops.get_default_graph()
      while isinstance(graph, function.FuncGraph):
        graph = graph.outer_graph
      true_name = graph.unique_name(("%strue" % scope).replace("/", "_"))
      false_name = graph.unique_name(("%sfalse" % scope).replace("/", "_"))

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()

    true_graph = function.func_graph_from_py_func(
        true_name, true_fn, [], {},
        func_graph=util.CondBranchFuncGraph(true_name),
        add_control_dependencies=add_control_dependencies)
    false_graph = function.func_graph_from_py_func(
        false_name, false_fn, [], {},
        func_graph=util.CondBranchFuncGraph(false_name),
        add_control_dependencies=add_control_dependencies)
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.external_captures,
                                     false_graph.external_captures)

    # Add all intermediate tensors as function outputs so they're available
    # for the gradient computation.
    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred, cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=scope)

    # Set the flag to enable lowering on the `if` op if necessary.
    # Lowering allows cond_v2 to avoid some of the limitations of Functions,
    # allowing users to specify devices & colocation inside of cond_v2
    # branches, and enabling non-strict evaluation & partial pruning of
    # cond_v2 branches. This brings cond_v2 closer to feature parity with
    # tf.cond.
    #
    # However, we do not lower `If` in the XLA context because it is easier
    # for XLA to apply its own optimizations when dealing with un-lowered
    # `If` operators than with lowered switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1
    # output
    if_op = tensors[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      # pylint: disable=protected-access
      if_op._set_attr("_lower_using_switch_merge",
                      attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access

    # Return identities for each output of the If op, rather than the output
    # of the If op directly. This makes pruning work if the output of cond()
    # is fetched: the lowering pass converts the If outputs into IdentityN
    # outputs, which if fetched will cause all ops in the taken branch to be
    # run (since it takes all merge ops as input). After lowering, each
    # output identity op will end up with only the appropriate merge op as
    # input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to
    # the correct output structure
    tensors = tuple(array_ops.identity(t) for t in tensors)

    result = tuple(tensors[:num_cond_outputs])
    if len(result) == 1:
      return result[0]
    else:
      return result

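# A call site looks like ordinary tf.cond usage. A minimal hedged example
# against the signature above; it assumes graph mode and the module imports
# used throughout these snippets (constant_op, etc.). With a single output
# this version returns the tensor itself rather than a one-element tuple.
x = constant_op.constant(3.0)
pred = constant_op.constant(True)
out = cond_v2(pred,
              true_fn=lambda: x * 2.0,   # emits one If op, not switch/merge
              false_fn=lambda: x + 1.0)
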
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    # Identify if there is a caller device, & get the innermost if possible.
    device_stack = ops.get_default_graph()._device_function_stack
    caller_device = device_stack[-1] if device_stack else None

    caller_colocation_stack = ops.get_default_graph()._colocation_stack
    caller_container = ops.get_default_graph()._container
    caller_collection_ref = ops.get_default_graph()._collections

    func_name_prefix = scope.replace("/", "_")

    true_graph = _function.func_graph_from_py_func(
        true_fn, [], [],
        name="%strue" % func_name_prefix,
        device=caller_device,
        colocation_stack=caller_colocation_stack,
        collections_ref=caller_collection_ref,
        container=caller_container)
    false_graph = _function.func_graph_from_py_func(
        false_fn, [], [],
        name="%sfalse" % func_name_prefix,
        device=caller_device,
        colocation_stack=caller_colocation_stack,
        collections_ref=caller_collection_ref,
        container=caller_container)
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.extra_inputs,
                                     false_graph.extra_inputs)

    # Add all intermediate tensors as function outputs so they're available
    # for the gradient computation.
    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(pred, cond_inputs,
                                     [t.dtype for t in true_graph.outputs],
                                     _create_new_tf_function(true_graph),
                                     _create_new_tf_function(false_graph),
                                     name=scope)

    # Set the flag to enable lowering on the `if` op if necessary.
    # Lowering allows cond_v2 to avoid some of the limitations of Functions,
    # allowing users to specify devices & colocation inside of cond_v2
    # branches, and enabling non-strict evaluation & partial pruning of
    # cond_v2 branches. This brings cond_v2 closer to feature parity with
    # tf.cond.
    #
    # However, we do not lower `If` in the XLA context because it is easier
    # for XLA to apply its own optimizations when dealing with un-lowered
    # `If` operators than with lowered switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1
    # output
    if_op = tensors[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      if_op._set_attr("_lower_using_switch_merge",
                      attr_value_pb2.AttrValue(b=True))

    return tensors[:num_cond_outputs]