def If(cond, inputs, then_branch, else_branch, name=None):
  r"""output = Cond(inputs) ? then_branch(inputs) : else_branch(inputs).

  Args:
    cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
      converted to a boolean according to the following rule: if the scalar is
      a numerical value, non-zero means True and zero means False; if the
      scalar is a string, non-empty means True and empty means False.
    inputs: A list of input tensors.
    then_branch: A function that takes 'inputs' and returns a list of tensors
      whose types are the same as what else_branch returns.
    else_branch: A function that takes 'inputs' and returns a list of tensors
      whose types are the same as what then_branch returns.
    name: A name for the operation (optional).

  Returns:
    A list of tensors returned by either then_branch(inputs) or
    else_branch(inputs).
  """
  # The op's output types are read from then_branch's signature only;
  # else_branch is expected to agree (see the Args section above).
  # pylint: disable=protected-access
  output_types = [
      out_arg.type for out_arg in then_branch.definition.signature.output_arg
  ]
  return gen_functional_ops._if(
      cond, inputs, output_types, then_branch, else_branch, name=name)
def If(cond, inputs, then_branch, else_branch, name=None):
  r"""output = Cond(inputs) ? then_branch(inputs) : else_branch(inputs).

  Args:
    cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
      converted to a boolean according to the following rule: if the scalar is
      a numerical value, non-zero means True and zero means False; if the
      scalar is a string, non-empty means True and empty means False.
    inputs: A list of input tensors.
    then_branch: A function that takes 'inputs' and returns a list of tensors
      whose types are the same as what else_branch returns.
    else_branch: A function that takes 'inputs' and returns a list of tensors
      whose types are the same as what then_branch returns.
    name: A name for the operation (optional).

  Returns:
    A list of tensors returned by either then_branch(inputs) or
    else_branch(inputs).
  """
  # Only then_branch's signature is consulted for the output types;
  # else_branch must return matching types per the contract above.
  # pylint: disable=protected-access
  then_signature = then_branch.definition.signature
  result_types = [out_arg.type for out_arg in then_signature.output_arg]
  return gen_functional_ops._if(
      cond, inputs, result_types, then_branch, else_branch, name=name)
def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
                name=None):
  """Builds an If op from a predicate, two branch graphs and their inputs.

  Note that true_graph and false_graph are modified so that their inputs
  match. The two graphs need not start with the same input types, but they
  must have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    The outputs of the If op, each routed through an identity op and packed
    into the structure of true_graph.structured_outputs.
  """
  _check_same_outputs(true_graph, false_graph)

  # Give both branches one common input signature. Note that this modifies
  # true_graph and false_graph.
  merged_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_inputs, false_inputs)

  # Gather the pieces of the If op, then create it.
  output_dtypes = [out.dtype for out in true_graph.outputs]
  then_fn = util.create_new_tf_function(true_graph)
  else_fn = util.create_new_tf_function(false_graph)
  merged_shapes = _get_output_shapes(true_graph.outputs, false_graph.outputs)
  if_outputs = gen_functional_ops._if(  # pylint: disable=protected-access
      pred,
      merged_inputs,
      output_dtypes,
      then_fn,
      else_fn,
      output_shapes=merged_shapes,
      name=name)

  # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
  if_op = if_outputs[0].op
  util.maybe_set_lowering_attr(if_op)

  # Hand back an identity of every If output instead of the raw outputs. This
  # makes pruning work if the output of cond() is fetched: the lowering pass
  # converts the If outputs into IdentityN outputs, which if fetched will
  # cause all ops in the taken branch to be run (since it takes all merge ops
  # as input). After lowering, each output identity op will end up with only
  # the appropriate merge op as input.
  # TODO(b/79984175): revisit once we convert to the correct output structure
  identities = [array_ops.identity(out) for out in if_outputs]

  # Prevent fetching since the variant outputs can't be fetched directly.
  if_op.graph.prevent_fetching(if_op)
  return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                            identities)
def _IfGrad(op, *grads):  # pylint: disable=invalid-name
  """Computes the gradient of an If op produced by cond_v2.

  Args:
    op: the forward If op.
    *grads: incoming gradients, forwarded to _create_grad_func for both
      branches.

  Returns:
    A list starting with None (the predicate has no gradient), followed by
    the leading outputs of the gradient If op. The intermediates appended for
    higher-order gradients are not included.
  """
  fwd_true, fwd_false = _get_func_graphs(op)
  # Note: op.graph != ops.get_default_graph() when we are computing the
  # gradient of a nested cond.
  assert fwd_true.outer_graph == op.graph
  assert fwd_false.outer_graph == op.graph

  # Build one gradient graph per forward branch. These graphs capture tensors
  # from the forward pass functions.
  true_grad_name = _get_grad_fn_name(fwd_true)
  bwd_true = _create_grad_func(fwd_true, grads, true_grad_name)
  false_grad_name = _get_grad_fn_name(fwd_false)
  bwd_false = _create_grad_func(fwd_false, grads, false_grad_name)

  assert ([out.dtype for out in bwd_true.outputs] ==
          [out.dtype for out in bwd_false.outputs])

  # Resolve references to forward graph tensors in grad graphs and ensure
  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
  bwd_true_inputs = _resolve_grad_inputs(fwd_true, bwd_true)
  bwd_false_inputs = _resolve_grad_inputs(fwd_false, bwd_false)

  # Give both gradient graphs the same inputs. Note that this modifies
  # bwd_true and bwd_false.
  shared_inputs = _make_inputs_match(bwd_true, bwd_false,
                                     bwd_true_inputs, bwd_false_inputs)

  # Intermediates become function outputs so that higher-order gradient
  # computations can reach them.
  bwd_true_intermediates = _get_intermediates(bwd_true)
  bwd_false_intermediates = _get_intermediates(bwd_false)

  # Remember how many real gradient outputs exist before padding.
  num_grad_outputs = len(bwd_true.outputs)

  # Make the number/type of new intermediate outputs match.
  padded = _pad_params(bwd_true, bwd_false,
                       bwd_true_intermediates, bwd_false_intermediates)
  extra_true_outputs, extra_false_outputs = padded
  bwd_true.outputs.extend(extra_true_outputs)
  bwd_false.outputs.extend(extra_false_outputs)

  # Create the gradient If op, reusing the forward predicate.
  predicate = op.inputs[0]
  grad_dtypes = [out.dtype for out in bwd_true.outputs]
  then_fn = _create_new_tf_function(bwd_true)
  else_fn = _create_new_tf_function(bwd_false)
  grad_shapes = _get_output_shapes(bwd_true.outputs, bwd_false.outputs)
  grad_outputs = gen_functional_ops._if(
      predicate,
      grad_inputs if False else shared_inputs,
      grad_dtypes,
      then_fn,
      else_fn,
      output_shapes=grad_shapes)

  # The predicate has no gradient.
  return [None] + grad_outputs[:num_grad_outputs]
def _IfGrad(op, *grads):  # pylint: disable=invalid-name
  """Computes the gradient of an If op produced by cond_v2.

  Args:
    op: the forward If op.
    *grads: incoming gradients, forwarded to _create_grad_func for both
      branches.

  Returns:
    A list starting with None (the predicate has no gradient), followed by
    the leading outputs of the gradient If op. The intermediates appended for
    higher-order gradients are not included.
  """
  fwd_true, fwd_false = _get_func_graphs(op)
  # Note: op.graph != ops.get_default_graph() when we are computing the
  # gradient of a nested cond.
  assert fwd_true.outer_graph == op.graph
  assert fwd_false.outer_graph == op.graph

  # Build one gradient graph per forward branch. These graphs capture tensors
  # from the forward pass functions.
  true_grad_name = _get_grad_fn_name(fwd_true)
  bwd_true = _create_grad_func(fwd_true, grads, true_grad_name)
  false_grad_name = _get_grad_fn_name(fwd_false)
  bwd_false = _create_grad_func(fwd_false, grads, false_grad_name)

  assert ([out.dtype for out in bwd_true.outputs] ==
          [out.dtype for out in bwd_false.outputs])

  # Resolve references to forward graph tensors in grad graphs and ensure
  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
  bwd_true_inputs = _resolve_grad_inputs(fwd_true, bwd_true)
  bwd_false_inputs = _resolve_grad_inputs(fwd_false, bwd_false)

  # Give both gradient graphs the same inputs. Note that this modifies
  # bwd_true and bwd_false.
  shared_inputs = _make_inputs_match(bwd_true, bwd_false,
                                     bwd_true_inputs, bwd_false_inputs)

  # Intermediates become function outputs so that higher-order gradient
  # computations can reach them.
  bwd_true_intermediates = _get_intermediates(bwd_true)
  bwd_false_intermediates = _get_intermediates(bwd_false)

  # Remember how many real gradient outputs exist before padding.
  num_grad_outputs = len(bwd_true.outputs)

  # Make the number/type of new intermediate outputs match.
  padded = _pad_params(bwd_true, bwd_false,
                       bwd_true_intermediates, bwd_false_intermediates)
  extra_true_outputs, extra_false_outputs = padded
  bwd_true.outputs.extend(extra_true_outputs)
  bwd_false.outputs.extend(extra_false_outputs)

  # Create the gradient If op, reusing the forward predicate.
  predicate = op.inputs[0]
  grad_dtypes = [out.dtype for out in bwd_true.outputs]
  then_fn = _create_new_tf_function(bwd_true)
  else_fn = _create_new_tf_function(bwd_false)
  grad_shapes = _get_output_shapes(bwd_true.outputs, bwd_false.outputs)
  grad_outputs = gen_functional_ops._if(
      predicate,
      shared_inputs,
      grad_dtypes,
      then_fn,
      else_fn,
      output_shapes=grad_shapes)

  # The predicate has no gradient.
  return [None] + grad_outputs[:num_grad_outputs]
def If(cond, inputs, then_branch, else_branch, name=None):
  r"""output = Cond(inputs) ? then_branch(inputs) : else_branch(inputs).

  Args:
    cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
      converted to a boolean according to the following rule: if the scalar is
      a numerical value, non-zero means True and zero means False; if the
      scalar is a string, non-empty means True and empty means False.
    inputs: A list of input tensors.
    then_branch: A function that takes 'inputs' and returns a list of tensors
      whose types are the same as what else_branch returns.
    else_branch: A function that takes 'inputs' and returns a list of tensors
      whose types are the same as what then_branch returns.
    name: A name for the operation (optional).

  Returns:
    A list of tensors returned by either then_branch(inputs) or
    else_branch(inputs). When then_branch is not a Defun, the flat outputs
    are re-packed into the structure of then_branch.structured_outputs.
  """
  # pylint: disable=protected-access
  # Handle the Defun case until users have transitioned to tf.function. Note
  # that composites may need to be re-packed by the caller.
  if isinstance(then_branch, function._DefinedFunction):
    defun_output_types = [
        out_arg.type
        for out_arg in then_branch.definition.signature.output_arg
    ]
    return gen_functional_ops._if(
        cond, inputs, defun_output_types, then_branch, else_branch, name=name)

  # From here on `then_branch` is assumed to be a ConcreteFunction.
  then_structure = then_branch.structured_outputs
  else_structure = else_branch.structured_outputs

  # Both branches must produce the same kind of composites; otherwise the
  # pack_sequence_as call below would be invalid.
  nest.assert_same_structure(
      then_structure, else_structure, expand_composites=True)

  flat_dtypes = nest.flatten(then_branch.output_dtypes)
  flat_outputs = gen_functional_ops._if(
      cond, inputs, flat_dtypes, then_branch, else_branch, name=name)

  # Re-pack the outputs to restore any CompositeTensors.
  return nest.pack_sequence_as(
      then_structure, flat_outputs, expand_composites=True)
def _IfGrad(op, *grads):  # pylint: disable=invalid-name
  """Computes the gradient of an If op produced by cond_v2.

  The forward branch graphs are read from `op._true_graph` and
  `op._false_graph`; the gradient branch graphs are stored the same way on
  the newly created gradient If op.

  Args:
    op: the forward If op.
    *grads: incoming gradients, forwarded to _create_grad_func for both
      branches.

  Returns:
    A list starting with None (the predicate has no gradient), followed by
    the leading outputs of the gradient If op. The intermediates appended for
    higher-order gradients are not included.
  """
  fwd_true = op._true_graph
  fwd_false = op._false_graph

  # Build one gradient graph per forward branch. These graphs capture tensors
  # from the forward pass functions.
  bwd_true = _create_grad_func(fwd_true, grads, "%sgrad" % fwd_true.name)
  bwd_false = _create_grad_func(fwd_false, grads, "%sgrad" % fwd_false.name)

  assert ([out.dtype for out in bwd_true.outputs] ==
          [out.dtype for out in bwd_false.outputs])

  # Match up the captured grad function inputs with outputs of 'op' and other
  # external tensors.
  bwd_true_inputs = _get_grad_inputs(op, fwd_true, bwd_true)
  bwd_false_inputs = _get_grad_inputs(op, fwd_false, bwd_false)

  # Give both gradient graphs the same inputs. Note that this modifies
  # bwd_true and bwd_false.
  shared_inputs = _make_inputs_match(bwd_true, bwd_false,
                                     bwd_true_inputs, bwd_false_inputs)

  # Intermediates become function outputs so that higher-order gradient
  # computations can reach them.
  bwd_true_intermediates = _get_intermediates(bwd_true)
  bwd_false_intermediates = _get_intermediates(bwd_false)

  # Remember how many real gradient outputs exist before padding.
  num_grad_outputs = len(bwd_true.outputs)

  # Make the number/type of new intermediate outputs match.
  padded = _pad_params(bwd_true, bwd_false,
                       bwd_true_intermediates, bwd_false_intermediates)
  extra_true_outputs, extra_false_outputs = padded
  bwd_true.outputs.extend(extra_true_outputs)
  bwd_false.outputs.extend(extra_false_outputs)

  # Create the gradient If op, reusing the forward predicate.
  predicate = op.inputs[0]
  grad_dtypes = [out.dtype for out in bwd_true.outputs]
  then_fn = _create_new_tf_function(bwd_true)
  else_fn = _create_new_tf_function(bwd_false)
  grad_outputs = gen_functional_ops._if(
      predicate, shared_inputs, grad_dtypes, then_fn, else_fn)

  # Keep the gradient graphs on the new op, mirroring what is read from the
  # forward op at the top of this function.
  grad_if_op = grad_outputs[0].op
  grad_if_op._true_graph = bwd_true
  grad_if_op._false_graph = bwd_false

  # The predicate has no gradient.
  return [None] + grad_outputs[:num_grad_outputs]
def _IfGrad(op, *grads):  # pylint: disable=invalid-name
  """Computes the gradient of an If op produced by cond_v2.

  Args:
    op: the forward If op.
    *grads: incoming gradients, forwarded to _create_grad_func for both
      branches.

  Returns:
    A list starting with None (the predicate has no gradient), followed by
    the leading outputs of the gradient If op. The intermediates appended for
    higher-order gradients are not included.
  """
  fwd_true, fwd_false = _get_func_graphs(op)

  # Build one gradient graph per forward branch. These graphs capture tensors
  # from the forward pass functions.
  true_grad_name = _get_grad_fn_name(fwd_true)
  bwd_true = _create_grad_func(fwd_true, grads, true_grad_name)
  false_grad_name = _get_grad_fn_name(fwd_false)
  bwd_false = _create_grad_func(fwd_false, grads, false_grad_name)

  assert ([out.dtype for out in bwd_true.outputs] ==
          [out.dtype for out in bwd_false.outputs])

  # Match up the captured grad function inputs with outputs of 'op' and other
  # external tensors.
  bwd_true_inputs = _get_grad_inputs(op, fwd_true, bwd_true)
  bwd_false_inputs = _get_grad_inputs(op, fwd_false, bwd_false)

  # Give both gradient graphs the same inputs. Note that this modifies
  # bwd_true and bwd_false.
  shared_inputs = _make_inputs_match(bwd_true, bwd_false,
                                     bwd_true_inputs, bwd_false_inputs)

  # Intermediates become function outputs so that higher-order gradient
  # computations can reach them.
  bwd_true_intermediates = _get_intermediates(bwd_true)
  bwd_false_intermediates = _get_intermediates(bwd_false)

  # Remember how many real gradient outputs exist before padding.
  num_grad_outputs = len(bwd_true.outputs)

  # Make the number/type of new intermediate outputs match.
  padded = _pad_params(bwd_true, bwd_false,
                       bwd_true_intermediates, bwd_false_intermediates)
  extra_true_outputs, extra_false_outputs = padded
  bwd_true.outputs.extend(extra_true_outputs)
  bwd_false.outputs.extend(extra_false_outputs)

  # Create the gradient If op, reusing the forward predicate.
  predicate = op.inputs[0]
  grad_dtypes = [out.dtype for out in bwd_true.outputs]
  then_fn = _create_new_tf_function(bwd_true)
  else_fn = _create_new_tf_function(bwd_false)
  grad_outputs = gen_functional_ops._if(
      predicate, shared_inputs, grad_dtypes, then_fn, else_fn)

  # The predicate has no gradient.
  return [None] + grad_outputs[:num_grad_outputs]
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  Args:
    pred: the predicate passed to the If op.
    true_fn: Python callable traced (with no arguments) into the true branch.
    false_fn: Python callable traced (with no arguments) into the false
      branch.
    name: name scope for the op; the branch functions are named
      "<scope>_true" and "<scope>_false".

  Returns:
    The outputs of the If op that correspond to the branch functions' own
    outputs; the intermediates added for the gradient are left out.
  """
  with ops.name_scope(name) as scope:
    true_fn_name = "%s_true" % scope
    true_graph = function.func_graph_from_py_func(
        true_fn, [], [], name=true_fn_name)
    false_fn_name = "%s_false" % scope
    false_graph = function.func_graph_from_py_func(
        false_fn, [], [], name=false_fn_name)
    _check_same_outputs(true_graph, false_graph)

    # Give both branches the same inputs. Note that this modifies true_graph
    # and false_graph.
    shared_inputs = _make_inputs_match(true_graph, false_graph,
                                       true_graph.extra_inputs,
                                       false_graph.extra_inputs)

    # Intermediates become function outputs so that the gradient computation
    # can reach them.
    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Remember how many outputs the caller should get back.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    padded = _pad_params(true_graph, false_graph,
                         true_intermediates, false_intermediates)
    extra_true_outputs, extra_false_outputs = padded
    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    output_dtypes = [out.dtype for out in true_graph.outputs]
    then_fn = _create_new_tf_function(true_graph)
    else_fn = _create_new_tf_function(false_graph)
    if_outputs = gen_functional_ops._if(
        pred, shared_inputs, output_dtypes, then_fn, else_fn, name=scope)

    # TODO(b/79883549): if we could make Graphs from FunctionDefs, we wouldn't
    # need this extra state. Requiring extra state also prevents the ability
    # to take the gradient of deserialized If ops.
    if_op = if_outputs[0].op
    if_op._true_graph = true_graph
    if_op._false_graph = false_graph

    return if_outputs[:num_cond_outputs]
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  Args:
    pred: the predicate passed to the If op.
    true_fn: Python callable traced (with no arguments) into the true branch.
    false_fn: Python callable traced (with no arguments) into the false
      branch.
    name: name scope for the op; the branch functions are named
      "<scope>_true" and "<scope>_false".

  Returns:
    The outputs of the If op that correspond to the branch functions' own
    outputs; the intermediates added for the gradient are left out.
  """
  with ops.name_scope(name) as scope:
    true_fn_name = "%s_true" % scope
    true_graph = function.func_graph_from_py_func(
        true_fn, [], [], name=true_fn_name)
    false_fn_name = "%s_false" % scope
    false_graph = function.func_graph_from_py_func(
        false_fn, [], [], name=false_fn_name)
    _check_same_outputs(true_graph, false_graph)

    # Give both branches the same inputs. Note that this modifies true_graph
    # and false_graph.
    shared_inputs = _make_inputs_match(true_graph, false_graph,
                                       true_graph.extra_inputs,
                                       false_graph.extra_inputs)

    # Intermediates become function outputs so that the gradient computation
    # can reach them.
    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Remember how many outputs the caller should get back.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    padded = _pad_params(true_graph, false_graph,
                         true_intermediates, false_intermediates)
    extra_true_outputs, extra_false_outputs = padded
    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    output_dtypes = [out.dtype for out in true_graph.outputs]
    then_fn = _create_new_tf_function(true_graph)
    else_fn = _create_new_tf_function(false_graph)
    if_outputs = gen_functional_ops._if(
        pred, shared_inputs, output_dtypes, then_fn, else_fn, name=scope)

    return if_outputs[:num_cond_outputs]
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  Args:
    pred: the predicate passed to the If op.
    true_fn: Python callable traced (with no arguments) into the true branch.
    false_fn: Python callable traced (with no arguments) into the false
      branch.
    name: name scope for the op; the branch functions are named
      "<scope>_true" and "<scope>_false".

  Returns:
    The outputs of the If op that correspond to the branch functions' own
    outputs; the intermediates added for the gradient are left out.
  """
  with ops.name_scope(name) as scope:
    # Trace each branch into its own function graph.
    true_fn_name = "%s_true" % scope
    true_graph = function.func_graph_from_py_func(
        true_fn, [], [], name=true_fn_name)
    false_fn_name = "%s_false" % scope
    false_graph = function.func_graph_from_py_func(
        false_fn, [], [], name=false_fn_name)
    _check_same_outputs(true_graph, false_graph)

    # Unify the branch inputs. Note that this modifies true_graph and
    # false_graph.
    branch_inputs = _make_inputs_match(true_graph, false_graph,
                                       true_graph.extra_inputs,
                                       false_graph.extra_inputs)

    # Expose every intermediate tensor as a function output so the gradient
    # computation can use it.
    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Number of outputs that belong to the caller, recorded before padding.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    true_padding, false_padding = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)
    true_graph.outputs.extend(true_padding)
    false_graph.outputs.extend(false_padding)

    # Create the If op.
    branch_dtypes = [out.dtype for out in true_graph.outputs]
    then_fn = _create_new_tf_function(true_graph)
    else_fn = _create_new_tf_function(false_graph)
    if_outputs = gen_functional_ops._if(
        pred, branch_inputs, branch_dtypes, then_fn, else_fn, name=scope)

    return if_outputs[:num_cond_outputs]
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  Args:
    pred: the predicate passed to the If op. Must not be a Python bool.
    true_fn: Python callable traced (with no arguments) into the true branch.
    false_fn: Python callable traced (with no arguments) into the false
      branch.
    name: name scope for the op; a falsy value is replaced by "cond".

  Returns:
    The single output tensor when the branches produce exactly one output,
    otherwise a tuple of the output tensors. The intermediates added for the
    gradient are left out.

  Raises:
    TypeError: if pred is a Python bool.
  """
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  name = name or "cond"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      # Find the outer most graph for uniquing function names.
      # TODO(jpienaar): Make this work in eager mode.
      outermost_graph = ops.get_default_graph()
      while isinstance(outermost_graph, function.FuncGraph):
        outermost_graph = outermost_graph.outer_graph

      true_name = outermost_graph.unique_name(
          ("%strue" % scope).replace("/", "_"))
      false_name = outermost_graph.unique_name(
          ("%sfalse" % scope).replace("/", "_"))

    true_graph = function.func_graph_from_py_func(true_name, true_fn, [], {})
    false_graph = function.func_graph_from_py_func(
        false_name, false_fn, [], {})
    _check_same_outputs(true_graph, false_graph)

    # Give both branches the same inputs. Note that this modifies true_graph
    # and false_graph.
    shared_inputs = _make_inputs_match(true_graph, false_graph,
                                       true_graph.external_captures,
                                       false_graph.external_captures)

    # Intermediates become function outputs so that the gradient computation
    # can reach them.
    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Remember how many outputs the caller should get back.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    padded = _pad_params(true_graph, false_graph,
                         true_intermediates, false_intermediates)
    extra_true_outputs, extra_false_outputs = padded
    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    output_dtypes = [out.dtype for out in true_graph.outputs]
    then_fn = _create_new_tf_function(true_graph)
    else_fn = _create_new_tf_function(false_graph)
    merged_shapes = _get_output_shapes(true_graph.outputs,
                                       false_graph.outputs)
    if_outputs = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        shared_inputs,
        output_dtypes,
        then_fn,
        else_fn,
        output_shapes=merged_shapes,
        name=scope)

    # Set the flag to enable lowering on the `if` op if necessary.
    # Lowering allows cond_v2 to avoid some of the limitations of Functions,
    # allowing users to specify devices & colocation inside of cond_v2
    # branches, and enabling non-strict evaluation & partial pruning of
    # cond_v2 branches. This brings cond_v2 closer to feature parity with
    # tf.cond.
    #
    # However, we do not lower `If` in the XLA context because it is easier
    # for XLA to apply its own optimizations when dealing with un-lowered
    # `If` operators than with lowered switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1
    # output
    if_op = if_outputs[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      # pylint: disable=protected-access
      if_op._set_attr("_lower_using_switch_merge",
                      attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access

    # A lone output is returned bare rather than as a 1-tuple.
    result = tuple(if_outputs[:num_cond_outputs])
    return result[0] if len(result) == 1 else result
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  Args:
    pred: the predicate; converted to a tensor. Must not be a Python bool.
    true_fn: Python callable traced (with no arguments) into the true branch.
    false_fn: Python callable traced (with no arguments) into the false
      branch.
    name: name scope for the op; a falsy value is replaced by "cond".

  Returns:
    The single output tensor when the branches produce exactly one output,
    otherwise a tuple of the output tensors. Each is an identity of the
    corresponding If output; the intermediates added for the gradient are
    left out.

  Raises:
    TypeError: if pred is a Python bool.
  """
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  name = name or "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()
    pred = ops.convert_to_tensor(pred)

    true_func_graph = util.CondBranchFuncGraph(
        true_name, read_only_collections=False)
    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn,
        [],
        {},
        func_graph=true_func_graph,
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    false_func_graph = util.CondBranchFuncGraph(
        false_name, read_only_collections=False)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn,
        [],
        {},
        func_graph=false_func_graph,
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    _check_same_outputs(true_graph, false_graph)

    # Give both branches the same inputs. Note that this modifies true_graph
    # and false_graph.
    shared_inputs = _make_inputs_match(true_graph, false_graph,
                                       true_graph.external_captures,
                                       false_graph.external_captures)

    # Intermediates become function outputs so that the gradient computation
    # can reach them.
    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Remember how many outputs the caller should get back.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    padded = _pad_params(true_graph, false_graph,
                         true_intermediates, false_intermediates)
    extra_true_outputs, extra_false_outputs = padded
    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    output_dtypes = [out.dtype for out in true_graph.outputs]
    then_fn = util.create_new_tf_function(true_graph)
    else_fn = util.create_new_tf_function(false_graph)
    merged_shapes = _get_output_shapes(true_graph.outputs,
                                       false_graph.outputs)
    if_outputs = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        shared_inputs,
        output_dtypes,
        then_fn,
        else_fn,
        output_shapes=merged_shapes,
        name=scope)

    # Set the flag to enable lowering on the `if` op if necessary.
    # Lowering allows cond_v2 to avoid some of the limitations of Functions,
    # allowing users to specify devices & colocation inside of cond_v2
    # branches, and enabling non-strict evaluation & partial pruning of
    # cond_v2 branches. This brings cond_v2 closer to feature parity with
    # tf.cond.
    #
    # However, we do not lower `If` in the XLA context because it is easier
    # for XLA to apply its own optimizations when dealing with un-lowered
    # `If` operators than with lowered switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1
    # output
    if_op = if_outputs[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      # pylint: disable=protected-access
      if_op._set_attr("_lower_using_switch_merge",
                      attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access

    # Hand back an identity of every If output instead of the raw outputs.
    # This makes pruning work if the output of cond() is fetched: the
    # lowering pass converts the If outputs into IdentityN outputs, which if
    # fetched will cause all ops in the taken branch to be run (since it
    # takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to
    # the correct output structure
    identities = tuple(array_ops.identity(out) for out in if_outputs)

    # A lone output is returned bare rather than as a 1-tuple.
    result = tuple(identities[:num_cond_outputs])
    return result[0] if len(result) == 1 else result
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  Args:
    pred: the predicate; converted to a tensor. Must not be a Python bool.
    true_fn: Python callable traced (with no arguments) into the true branch.
    false_fn: Python callable traced (with no arguments) into the false
      branch.
    name: name scope for the op; a falsy value is replaced by "cond".

  Returns:
    Identities of the If op's outputs (without the intermediates added for
    the gradient), packed into the structure of
    true_graph.structured_outputs.

  Raises:
    TypeError: if pred is a Python bool.
  """
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  name = name or "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()
    pred = ops.convert_to_tensor(pred)

    true_func_graph = util.CondBranchFuncGraph(
        true_name, read_only_collections=False)
    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn,
        [],
        {},
        func_graph=true_func_graph,
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    false_func_graph = util.CondBranchFuncGraph(
        false_name, read_only_collections=False)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn,
        [],
        {},
        func_graph=false_func_graph,
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    _check_same_outputs(true_graph, false_graph)

    # Give both branches the same inputs. Note that this modifies true_graph
    # and false_graph.
    shared_inputs = _make_inputs_match(true_graph, false_graph,
                                       true_graph.external_captures,
                                       false_graph.external_captures)

    # Intermediates become function outputs so that the gradient computation
    # can reach them.
    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Remember how many outputs the caller should get back.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    padded = _pad_params(true_graph, false_graph,
                         true_intermediates, false_intermediates)
    extra_true_outputs, extra_false_outputs = padded
    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    output_dtypes = [out.dtype for out in true_graph.outputs]
    then_fn = util.create_new_tf_function(true_graph)
    else_fn = util.create_new_tf_function(false_graph)
    merged_shapes = _get_output_shapes(true_graph.outputs,
                                       false_graph.outputs)
    if_outputs = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        shared_inputs,
        output_dtypes,
        then_fn,
        else_fn,
        output_shapes=merged_shapes,
        name=scope)

    # TODO(b/110167197) this approach requires cond_v2 to have at least 1
    # output
    util.maybe_set_lowering_attr(if_outputs[0].op)

    # Hand back an identity of every If output instead of the raw outputs.
    # This makes pruning work if the output of cond() is fetched: the
    # lowering pass converts the If outputs into IdentityN outputs, which if
    # fetched will cause all ops in the taken branch to be run (since it
    # takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to
    # the correct output structure
    identities = tuple(array_ops.identity(out) for out in if_outputs)

    caller_outputs = identities[:num_cond_outputs]
    return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                              caller_outputs)
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  The caller's innermost device function (if any), colocation stack,
  container and collections are read from the default graph and handed to
  both branch graphs.

  Args:
    pred: the predicate passed to the If op.
    true_fn: Python callable traced (with no arguments) into the true branch.
    false_fn: Python callable traced (with no arguments) into the false
      branch.
    name: name scope for the op; a falsy value is replaced by "cond".

  Returns:
    The outputs of the If op that correspond to the branch functions' own
    outputs; the intermediates added for the gradient are left out.
  """
  name = name or "cond"

  with ops.name_scope(name) as scope:
    # Identify if there is a caller device, & get the innermost if possible.
    caller_graph = ops.get_default_graph()
    device_stack = caller_graph._device_function_stack
    if device_stack:
      caller_device = device_stack[-1]
    else:
      caller_device = None

    # Settings shared by both branch graphs.
    caller_context = {
        "device": caller_device,
        "colocation_stack": caller_graph._colocation_stack,
        "collections_ref": caller_graph._collections,
        "container": caller_graph._container,
    }

    func_name_prefix = scope.replace("/", "_")

    true_graph = _function.func_graph_from_py_func(
        true_fn, [], [], name="%strue" % func_name_prefix, **caller_context)
    false_graph = _function.func_graph_from_py_func(
        false_fn, [], [], name="%sfalse" % func_name_prefix, **caller_context)
    _check_same_outputs(true_graph, false_graph)

    # Give both branches the same inputs. Note that this modifies true_graph
    # and false_graph.
    shared_inputs = _make_inputs_match(true_graph, false_graph,
                                       true_graph.extra_inputs,
                                       false_graph.extra_inputs)

    # Intermediates become function outputs so that the gradient computation
    # can reach them.
    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Remember how many outputs the caller should get back.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    padded = _pad_params(true_graph, false_graph,
                         true_intermediates, false_intermediates)
    extra_true_outputs, extra_false_outputs = padded
    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    output_dtypes = [out.dtype for out in true_graph.outputs]
    then_fn = _create_new_tf_function(true_graph)
    else_fn = _create_new_tf_function(false_graph)
    if_outputs = gen_functional_ops._if(
        pred, shared_inputs, output_dtypes, then_fn, else_fn, name=scope)

    # Set the flag to enable lowering on the `if` op if necessary.
    # Lowering allows cond_v2 to avoid some of the limitations of Functions,
    # allowing users to specify devices & colocation inside of cond_v2
    # branches, and enabling non-strict evaluation & partial pruning of
    # cond_v2 branches. This brings cond_v2 closer to feature parity with
    # tf.cond.
    #
    # However, we do not lower `If` in the XLA context because it is easier
    # for XLA to apply its own optimizations when dealing with un-lowered
    # `If` operators than with lowered switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1
    # output
    if_op = if_outputs[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      if_op._set_attr("_lower_using_switch_merge",
                      attr_value_pb2.AttrValue(b=True))

    return if_outputs[:num_cond_outputs]
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  The branch functions are given names made unique in the outermost
  non-function graph, traced, made to agree on inputs and (padded)
  intermediate outputs, and then combined into one `If` op.

  Args:
    pred: Predicate passed as the condition input of the `If` op.
    true_fn: Callable traced for the then-branch.
    false_fn: Callable traced for the else-branch.
    name: Name scope for the conditional; a falsy value means "cond".

  Returns:
    A tuple of the leading outputs of the `If` op, i.e. as many as the true
    branch had before intermediates were appended.
  """
  with ops.name_scope(name or "cond") as scope:
    with ops.name_scope(None):
      # Walk out to the outermost graph so function names are unique there.
      # TODO(jpienaar): Make this work in eager mode.
      outermost = ops.get_default_graph()
      while isinstance(outermost, _function.FuncGraph):
        outermost = outermost.outer_graph

      true_name = outermost.unique_name(
          ("%strue" % scope).replace("/", "_"))
      false_name = outermost.unique_name(
          ("%sfalse" % scope).replace("/", "_"))

    true_graph = _function.func_graph_from_py_func(true_name, true_fn, [], {})
    false_graph = _function.func_graph_from_py_func(
        false_name, false_fn, [], {})
    _check_same_outputs(true_graph, false_graph)

    # Mutates both graphs so that they accept the same inputs.
    cond_inputs = _make_inputs_match(
        true_graph, false_graph,
        true_graph.external_captures, false_graph.external_captures)

    # Intermediates become extra function outputs so the gradient can use them.
    true_mids = _get_intermediates(true_graph)
    false_mids = _get_intermediates(false_graph)

    # Remember how many outputs the caller actually asked for.
    num_cond_outputs = len(true_graph.outputs)

    # Pad so the number/type of the new outputs agree between branches.
    true_extras, false_extras = _pad_params(
        true_graph, false_graph, true_mids, false_mids)
    true_graph.outputs.extend(true_extras)
    false_graph.outputs.extend(false_extras)

    # Create the If op.
    output_dtypes = [t.dtype for t in true_graph.outputs]
    then_branch = _create_new_tf_function(true_graph)
    else_branch = _create_new_tf_function(false_graph)
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred, cond_inputs, output_dtypes, then_branch, else_branch,
        name=scope)

    # Lowering to switch/merge lets cond_v2 avoid some limitations of
    # Functions (devices & colocation inside branches, non-strict evaluation,
    # partial pruning) and brings it closer to feature parity with tf.cond.
    # It is skipped in an XLA context, where XLA optimizes the un-lowered
    # `If` more easily than switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      if_op._set_attr(  # pylint: disable=protected-access
          "_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))

    return tuple(tensors[:num_cond_outputs])
def _IfGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of an If op produced by cond_v2.

  Builds gradient function graphs for both branches of the forward `If` op,
  rewrites the forward op to expose any intermediates those gradient graphs
  need, and then emits a gradient `If` op via `_build_cond`.

  Args:
    op: The forward `If` operation (or a stand-in exposing its outputs).
    *grads: Gradients with respect to the outputs of `op`.

  Returns:
    A list whose first entry is `None` (the predicate has no gradient),
    followed by the tensors returned by `_build_cond` for the gradient
    branches.
  """
  # Get the if operator (this logic handles the case where op is a MockOp)
  if_op = op.outputs[0].op
  true_graph, false_graph = _get_func_graphs(if_op)
  # Note: op.graph != ops.get_default_graph() when we are computing the gradient
  # of a nested cond.
  assert true_graph.outer_graph == if_op.graph
  assert false_graph.outer_graph == if_op.graph

  # Create grad functions that compute the gradient of the true/false forward
  # graphs. These functions will capture tensors from the forward pass
  # functions.
  true_grad_graph = _create_grad_func(
      true_graph, grads, _get_grad_fn_name(true_graph))
  false_grad_graph = _create_grad_func(
      false_graph, grads, _get_grad_fn_name(false_graph))

  if (true_grad_graph.if_op_needs_rewrite or
      false_grad_graph.if_op_needs_rewrite):
    # Modify 'op' to output the intermediates needed by the grad functions. Note
    # that all needed intermediates are wrapped in optionals. Each optional
    # intermediate output will have a value iff its corresponding branch is
    # taken.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    if control_flow_util.InXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so output intermediates directly and
      # make them match via FakeParams, which can be converted to zeros in XLA.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      true_intermediates = true_grad_graph.xla_intermediates
      false_intermediates = false_grad_graph.xla_intermediates
      extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
          true_graph, false_graph, true_intermediates, false_intermediates)
    else:
      true_intermediates = true_grad_graph.wrapped_intermediates
      false_intermediates = false_grad_graph.wrapped_intermediates
      # Make outputs match by adding none optionals.
      extra_true_outputs, extra_false_outputs = _make_intermediates_match(
          true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)
    # TODO(skyewm): indicate it's an internal bug if this fails.
    _check_same_outputs(true_graph, false_graph)

    true_graph.name += "_rewritten"
    false_graph.name += "_rewritten"

    # Point the existing forward op at the rewritten branch functions and
    # extend its output signature to cover the newly exposed intermediates.
    # pylint: disable=protected-access
    if_op._set_func_attr("then_branch", util.create_new_tf_function(true_graph))
    if_op._set_func_attr("else_branch",
                         util.create_new_tf_function(false_graph))
    if_op._set_type_list_attr("Tout", true_graph.output_types)
    if_op._set_shape_list_attr("output_shapes", true_graph.output_shapes)
    if_op._add_outputs(
        [t.dtype for t in extra_true_outputs],
        [t.shape for t in extra_true_outputs])
    # pylint: enable=protected-access

  # Resolve references to forward graph tensors in grad graphs and ensure
  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
  true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
  false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)

  # This modifies true_grad_graph and false_grad_graph.
  _make_output_composite_tensors_match(true_grad_graph, false_grad_graph)

  # `_build_cond` creates the gradient If op. The previous code went on to
  # build a second If op from an undefined `grad_inputs` and never returned
  # anything; that stale tail is removed.
  outputs = _build_cond(if_op.inputs[0], true_grad_graph, false_grad_graph,
                        true_grad_inputs, false_grad_inputs)

  # The predicate has no gradient.
  return [None] + outputs
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  Each branch callable is traced into a function graph that is handed the
  caller's device, colocation stack, container and collections. The two
  graphs are then made to agree on inputs and padded intermediate outputs
  before being combined into one `If` op.

  Args:
    pred: Predicate passed as the condition input of the `If` op.
    true_fn: Callable traced (with empty argument lists) for the then-branch.
    false_fn: Callable traced (with empty argument lists) for the else-branch.
    name: Name scope for the conditional; a falsy value means "cond".

  Returns:
    The leading outputs of the `If` op, i.e. as many as the true branch had
    before intermediates were appended.
  """
  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    # pylint: disable=protected-access
    # Snapshot the caller's placement context; the innermost device function
    # wins when there is one.
    caller_graph = ops.get_default_graph()
    device_fns = caller_graph._device_function_stack
    innermost_device = device_fns[-1] if device_fns else None
    colocation = caller_graph._colocation_stack
    container = caller_graph._container
    collections = caller_graph._collections
    # pylint: enable=protected-access

    def _trace_branch(branch_fn, graph_name):
      """Traces one branch callable under the caller's context."""
      return function.func_graph_from_py_func(
          branch_fn, [], [],
          name=graph_name,
          device=innermost_device,
          colocation_stack=colocation,
          collections_ref=collections,
          container=container)

    name_prefix = scope.replace("/", "_")
    true_graph = _trace_branch(true_fn, "%strue" % name_prefix)
    false_graph = _trace_branch(false_fn, "%sfalse" % name_prefix)
    _check_same_outputs(true_graph, false_graph)

    # Mutates both graphs so that they accept the same inputs.
    cond_inputs = _make_inputs_match(
        true_graph, false_graph,
        true_graph.extra_inputs, false_graph.extra_inputs)

    # Intermediates become extra function outputs so the gradient can use them.
    true_mids = _get_intermediates(true_graph)
    false_mids = _get_intermediates(false_graph)

    # Remember how many outputs the caller actually asked for.
    num_cond_outputs = len(true_graph.outputs)

    # Pad so the number/type of the new outputs agree between branches.
    padded = _pad_params(true_graph, false_graph, true_mids, false_mids)
    true_extras, false_extras = padded
    true_graph.outputs.extend(true_extras)
    false_graph.outputs.extend(false_extras)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs,
        [out.dtype for out in true_graph.outputs],
        _create_new_tf_function(true_graph),
        _create_new_tf_function(false_graph),
        name=scope)

    return tensors[:num_cond_outputs]
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  Each branch callable is traced into a function graph that is handed the
  caller's device, colocation stack, container and collections. The graphs
  are made to agree on inputs and padded intermediate outputs, combined into
  one `If` op, and that op is flagged for switch/merge lowering unless it is
  in an XLA context.

  Args:
    pred: Predicate passed as the condition input of the `If` op.
    true_fn: Callable traced (with empty argument lists) for the then-branch.
    false_fn: Callable traced (with empty argument lists) for the else-branch.
    name: Name scope for the conditional; a falsy value means "cond".

  Returns:
    The leading outputs of the `If` op, i.e. as many as the true branch had
    before intermediates were appended.
  """
  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    # pylint: disable=protected-access
    # Snapshot the caller's placement context; the innermost device function
    # wins when there is one.
    caller_graph = ops.get_default_graph()
    device_fns = caller_graph._device_function_stack
    innermost_device = device_fns[-1] if device_fns else None
    colocation = caller_graph._colocation_stack
    container = caller_graph._container
    collections = caller_graph._collections
    # pylint: enable=protected-access

    def _trace_branch(branch_fn, graph_name):
      """Traces one branch callable under the caller's context."""
      return _function.func_graph_from_py_func(
          branch_fn, [], [],
          name=graph_name,
          device=innermost_device,
          colocation_stack=colocation,
          collections_ref=collections,
          container=container)

    name_prefix = scope.replace("/", "_")
    true_graph = _trace_branch(true_fn, "%strue" % name_prefix)
    false_graph = _trace_branch(false_fn, "%sfalse" % name_prefix)
    _check_same_outputs(true_graph, false_graph)

    # Mutates both graphs so that they accept the same inputs.
    cond_inputs = _make_inputs_match(
        true_graph, false_graph,
        true_graph.extra_inputs, false_graph.extra_inputs)

    # Intermediates become extra function outputs so the gradient can use them.
    true_mids = _get_intermediates(true_graph)
    false_mids = _get_intermediates(false_graph)

    # Remember how many outputs the caller actually asked for.
    num_cond_outputs = len(true_graph.outputs)

    # Pad so the number/type of the new outputs agree between branches.
    padded = _pad_params(true_graph, false_graph, true_mids, false_mids)
    true_extras, false_extras = padded
    true_graph.outputs.extend(true_extras)
    false_graph.outputs.extend(false_extras)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs,
        [out.dtype for out in true_graph.outputs],
        _create_new_tf_function(true_graph),
        _create_new_tf_function(false_graph),
        name=scope)

    # Lowering to switch/merge lets cond_v2 avoid some limitations of
    # Functions (devices & colocation inside branches, non-strict evaluation,
    # partial pruning) and brings it closer to feature parity with tf.cond.
    # It is skipped in an XLA context, where XLA optimizes the un-lowered
    # `If` more easily than switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      lowering_flag = attr_value_pb2.AttrValue(b=True)
      if_op._set_attr(  # pylint: disable=protected-access
          "_lower_using_switch_merge", lowering_flag)

    return tensors[:num_cond_outputs]
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  The branch functions are given names made unique in the outermost
  non-function graph, traced into `CondBranchFuncGraph`s, made to agree on
  inputs and padded intermediate outputs, and combined into one `If` op whose
  outputs are returned through identity ops.

  Args:
    pred: Predicate passed as the condition input of the `If` op. Must not be
      a Python bool.
    true_fn: Callable traced for the then-branch.
    false_fn: Callable traced for the else-branch.
    name: Name scope for the conditional; a falsy value means "cond".

  Returns:
    The identity-wrapped branch outputs: a single tensor when there is exactly
    one, otherwise a tuple of them. Added intermediate outputs are excluded.

  Raises:
    TypeError: If `pred` is a Python bool.
  """
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  with ops.name_scope(name or "cond") as scope:
    with ops.name_scope(None):
      # Walk out to the outermost graph so function names are unique there.
      # TODO(jpienaar): Make this work in eager mode.
      outermost = ops.get_default_graph()
      while isinstance(outermost, function.FuncGraph):
        outermost = outermost.outer_graph
      true_name = outermost.unique_name(
          ("%strue" % scope).replace("/", "_"))
      false_name = outermost.unique_name(
          ("%sfalse" % scope).replace("/", "_"))

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    auto_deps = util.in_defun()

    def _trace_branch(branch_name, branch_fn):
      """Traces one branch callable into its own CondBranchFuncGraph."""
      return function.func_graph_from_py_func(
          branch_name, branch_fn, [], {},
          func_graph=util.CondBranchFuncGraph(branch_name),
          add_control_dependencies=auto_deps)

    true_graph = _trace_branch(true_name, true_fn)
    false_graph = _trace_branch(false_name, false_fn)
    _check_same_outputs(true_graph, false_graph)

    # Mutates both graphs so that they accept the same inputs.
    cond_inputs = _make_inputs_match(
        true_graph, false_graph,
        true_graph.external_captures, false_graph.external_captures)

    # Intermediates become extra function outputs so the gradient can use them.
    true_mids = _get_intermediates(true_graph)
    false_mids = _get_intermediates(false_graph)

    # Remember how many outputs the caller actually asked for.
    num_cond_outputs = len(true_graph.outputs)

    # Pad so the number/type of the new outputs agree between branches.
    true_extras, false_extras = _pad_params(
        true_graph, false_graph, true_mids, false_mids)
    true_graph.outputs.extend(true_extras)
    false_graph.outputs.extend(false_extras)

    # Create the If op.
    output_dtypes = [t.dtype for t in true_graph.outputs]
    then_branch = util.create_new_tf_function(true_graph)
    else_branch = util.create_new_tf_function(false_graph)
    shapes = _get_output_shapes(true_graph.outputs, false_graph.outputs)
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred, cond_inputs, output_dtypes, then_branch, else_branch,
        output_shapes=shapes, name=scope)

    # Lowering to switch/merge lets cond_v2 avoid some limitations of
    # Functions (devices & colocation inside branches, non-strict evaluation,
    # partial pruning) and brings it closer to feature parity with tf.cond.
    # It is skipped in an XLA context, where XLA optimizes the un-lowered
    # `If` more easily than switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      if_op._set_attr(  # pylint: disable=protected-access
          "_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    identities = tuple(array_ops.identity(t) for t in tensors)

    result = identities[:num_cond_outputs]
    return result[0] if len(result) == 1 else result
def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
                name=None):
  """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    A list of Tensors which are identities of the outputs of the If op. Does
    not include added intermediate outputs.
  """
  _check_same_outputs(true_graph, false_graph)

  # Mutates both graphs so that they accept the same inputs.
  cond_inputs = _make_inputs_match(
      true_graph, false_graph, true_inputs, false_inputs)

  # Intermediates are exported as extra function outputs for the gradient.
  # Because both functions must have matching outputs, they are wrapped in
  # optionals: each one has a value iff its branch is the one taken.
  true_mids = _get_intermediates(true_graph)
  false_mids = _get_intermediates(false_graph)

  # Remember how many outputs the caller actually asked for.
  num_cond_outputs = len(true_graph.outputs)

  if not control_flow_util.InXlaContext(ops.get_default_graph()):
    # Wrap intermediates in optionals, then make the outputs match by adding
    # none optionals.
    true_wrapped = _wrap_intermediates(true_graph, true_mids)
    false_wrapped = _wrap_intermediates(false_graph, false_mids)
    true_extras, false_extras = _make_intermediates_match(
        true_graph, false_graph, true_wrapped, false_wrapped)
  else:
    # XLA does not yet support optionals, so output intermediates directly and
    # make them match via FakeParams, which can be converted to zeros in XLA.
    # TODO(skyewm,jpienaar): can XLA support optionals?
    true_extras, false_extras = _make_intermediates_match_xla(
        true_graph, false_graph, true_mids, false_mids)

  true_graph.outputs.extend(true_extras)
  false_graph.outputs.extend(false_extras)
  # TODO(skyewm): somehow indicate it's a bug if this fails.
  _check_same_outputs(true_graph, false_graph)

  # Create the If op.
  output_dtypes = [t.dtype for t in true_graph.outputs]
  then_branch = util.create_new_tf_function(true_graph)
  else_branch = util.create_new_tf_function(false_graph)
  shapes = _get_output_shapes(true_graph.outputs, false_graph.outputs)
  tensors = gen_functional_ops._if(  # pylint: disable=protected-access
      pred, cond_inputs, output_dtypes, then_branch, else_branch,
      output_shapes=shapes, name=name)

  # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
  if_op = tensors[0].op
  util.maybe_set_lowering_attr(if_op)

  # Return identities for each output of the If op, rather than the output of
  # the If op directly. This makes pruning work if the output of cond() is
  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
  # which if fetched will cause all ops in the taken branch to be run (since
  # it takes all merge ops as input). After lowering, each output identity op
  # will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  identities = [array_ops.identity(t) for t in tensors]

  # Prevent fetching since the variant outputs can't be fetched directly.
  if_op.graph.prevent_fetching(if_op)
  return identities[:num_cond_outputs]
def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
                name=None):
  """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    A list of Tensors which are identities of the outputs of the If op.
  """
  _check_same_outputs(true_graph, false_graph)

  # Mutates both graphs so that they accept the same inputs.
  cond_inputs = _make_inputs_match(
      true_graph, false_graph, true_inputs, false_inputs)

  # Create the If op.
  output_dtypes = [t.dtype for t in true_graph.outputs]
  then_branch = util.create_new_tf_function(true_graph)
  else_branch = util.create_new_tf_function(false_graph)
  shapes = _get_output_shapes(true_graph.outputs, false_graph.outputs)
  if_outputs = gen_functional_ops._if(  # pylint: disable=protected-access
      pred, cond_inputs, output_dtypes, then_branch, else_branch,
      output_shapes=shapes, name=name)

  # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
  if_op = if_outputs[0].op
  util.maybe_set_lowering_attr(if_op)

  # Return identities for each output of the If op, rather than the output of
  # the If op directly. This makes pruning work if the output of cond() is
  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
  # which if fetched will cause all ops in the taken branch to be run (since
  # it takes all merge ops as input). After lowering, each output identity op
  # will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  identities = [array_ops.identity(t) for t in if_outputs]

  # Prevent fetching since the variant outputs can't be fetched directly.
  if_op.graph.prevent_fetching(if_op)
  return identities
def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
                name=None):
  """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    A list of Tensors which are identities of the outputs of the If op. Does
    not include added intermediate outputs.
  """
  _check_same_outputs(true_graph, false_graph)

  # Mutates both graphs so that they accept the same inputs.
  cond_inputs = _make_inputs_match(
      true_graph, false_graph, true_inputs, false_inputs)

  # Intermediates are exported as extra function outputs for the gradient.
  # Because both functions must have matching outputs, they are wrapped in
  # optionals: each one has a value iff its branch is the one taken.
  true_intermediates = _get_intermediates(true_graph)
  false_intermediates = _get_intermediates(false_graph)

  # Remember how many outputs the caller actually asked for.
  num_cond_outputs = len(true_graph.outputs)

  in_xla = control_flow_util.InXlaContext(ops.get_default_graph())
  if in_xla:
    # XLA does not yet support optionals, so output intermediates directly and
    # make them match via FakeParams, which can be converted to zeros in XLA.
    # TODO(skyewm,jpienaar): can XLA support optionals?
    matched = _make_intermediates_match_xla(
        true_graph, false_graph, true_intermediates, false_intermediates)
  else:
    # Wrap intermediates in optionals.
    true_optionals = _wrap_intermediates(true_graph, true_intermediates)
    false_optionals = _wrap_intermediates(false_graph, false_intermediates)

    # Make outputs match by adding none optionals.
    matched = _make_intermediates_match(
        true_graph, false_graph, true_optionals, false_optionals)
  extra_true_outputs, extra_false_outputs = matched

  true_graph.outputs.extend(extra_true_outputs)
  false_graph.outputs.extend(extra_false_outputs)
  # TODO(skyewm): somehow indicate it's a bug if this fails.
  _check_same_outputs(true_graph, false_graph)

  # Create the If op.
  if_outputs = gen_functional_ops._if(  # pylint: disable=protected-access
      pred,
      cond_inputs,
      [out.dtype for out in true_graph.outputs],
      util.create_new_tf_function(true_graph),
      util.create_new_tf_function(false_graph),
      output_shapes=_get_output_shapes(true_graph.outputs,
                                       false_graph.outputs),
      name=name)

  # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
  if_op = if_outputs[0].op
  util.maybe_set_lowering_attr(if_op)

  # Return identities for each output of the If op, rather than the output of
  # the If op directly. This makes pruning work if the output of cond() is
  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
  # which if fetched will cause all ops in the taken branch to be run (since
  # it takes all merge ops as input). After lowering, each output identity op
  # will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  identities = [array_ops.identity(out) for out in if_outputs]

  # Prevent fetching since the variant outputs can't be fetched directly.
  if_op.graph.prevent_fetching(if_op)
  return identities[:num_cond_outputs]