def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop.

  Builds a second While loop whose body is the traced gradient of the forward
  body and whose condition keeps iterating while the gradient loop counter is
  below the forward loop's iteration count.

  Args:
    op: The forward While op.
    *grads: Incoming gradients, one per output of `op`. `None` entries are
      replaced by zeros shaped like the matching output.

  Returns:
    The slice of the gradient While op's outputs that lines up with
    `op.inputs` (everything after the loop counter and the iteration count).
  """
  body_graph = _get_body_graph(op)

  # Replace None gradients with zeros. This is needed because `grads` could have
  # None incoming gradients for the TensorLists. If we pass None's through, the
  # custom gradient of TensorListPopBack will create an EmptyTensorList inside
  # the FuncGraph which is undesirable.
  # TODO(b/80444525): There might be an issue with treating no gradient as zero
  # gradient in certain cases. Consider replacing None gradients with Zeros
  # for accumulators only.
  grads = [
      g if g is not None else array_ops.zeros_like(output)
      for g, output in zip(grads, op.outputs)
  ]

  body_grad_graph, args = _create_grad_func(
      body_graph, grads,
      _get_unique_name("%s_grad" % body_graph.name), op)

  intermediate_tensors = _get_intermediates(body_grad_graph)

  # Thread every intermediate of the gradient body through the loop in its
  # own TensorList. The empty list is created in the current (outer) graph
  # and captured into the gradient body, which pushes onto it each iteration.
  for intermediate_tensor in intermediate_tensors:
    tensor_list = list_ops.empty_tensor_list(
        element_dtype=intermediate_tensor.dtype,
        element_shape=_get_tensor_convertible_shape(intermediate_tensor.shape))
    with body_grad_graph.as_default():
      # NOTE(review): `whitelisted=True` presumably bypasses a capture
      # restriction of the gradient FuncGraph; confirm in its `capture`.
      tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
      # Push the intermediate tensor to the tensor list.
      appended_tensor_list = list_ops.tensor_list_push_back(tensor_list_ph,
                                                            intermediate_tensor)
      # Add this modified tensor list to the list of outputs.
      body_grad_graph.outputs.append(appended_tensor_list)

  # Loop while the gradient counter (loop var 0) is below the forward
  # iteration count (loop var 1); the remaining loop vars are ignored here.
  def grad_cond(counter, max_iters, *unused_args):
    return counter < max_iters

  # The gradient body expects its explicit args followed by its captures.
  loop_vars = args + body_grad_graph.external_captures
  grad_cond_name = _get_unique_name("%s_grad_cond" % op.name)
  cond_grad_graph = function.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  # Sanity checks: cond and body must agree on the loop-var signature.
  assert len(loop_vars) == len(body_grad_graph.inputs)
  assert len(loop_vars) == len(body_grad_graph.outputs)
  assert len(loop_vars) == len(cond_grad_graph.inputs)

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      name=_get_unique_name("%s_grad" % op.name))

  _copy_handle_data(body_grad_graph.outputs, outputs)
  _maybe_set_lowering_attr(outputs[0].op)

  # outputs[0] is the loop counter.
  # outputs[1] is the total number of loop iterations.
  return outputs[2:2 + len(op.inputs)]
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop.

  Builds a second While loop whose body is the traced gradient of the forward
  body and whose condition keeps iterating while the gradient loop counter is
  below the forward loop's iteration count.

  Args:
    op: The forward While op.
    *grads: Incoming gradients, one per output of `op`. `None` entries are
      replaced by zeros shaped like the matching output.

  Returns:
    The slice of the gradient While op's outputs that lines up with
    `op.inputs` (everything after the loop counter and the iteration count).
  """
  body_graph = _get_body_graph(op)

  # Replace None gradients with zeros. This is needed because `grads` could have
  # None incoming gradients for the TensorLists. If we pass None's through, the
  # custom gradient of TensorListPopBack will create an EmptyTensorList inside
  # the FuncGraph which is undesirable.
  # TODO(b/80444525): There might be an issue with treating no gradient as zero
  # gradient in certain cases. Consider replacing None gradients with Zeros
  # for accumulators only.
  grads = [
      g if g is not None else array_ops.zeros_like(output)
      for g, output in zip(grads, op.outputs)
  ]

  body_grad_graph, args = _create_grad_func(
      body_graph, grads,
      _get_unique_name("%s_grad" % body_graph.name), op)

  intermediate_tensors = _get_intermediates(body_grad_graph)

  # Thread every intermediate of the gradient body through the loop in its
  # own TensorList. The empty list is created in the current (outer) graph
  # and captured into the gradient body, which pushes onto it each iteration.
  for intermediate_tensor in intermediate_tensors:
    tensor_list = list_ops.empty_tensor_list(
        element_dtype=intermediate_tensor.dtype,
        element_shape=_get_tensor_convertible_shape(intermediate_tensor.shape))
    with body_grad_graph.as_default():
      # NOTE(review): `whitelisted=True` presumably bypasses a capture
      # restriction of the gradient FuncGraph; confirm in its `capture`.
      tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
      # Push the intermediate tensor to the tensor list.
      appended_tensor_list = list_ops.tensor_list_push_back(tensor_list_ph,
                                                            intermediate_tensor)
      # Add this modified tensor list to the list of outputs.
      body_grad_graph.outputs.append(appended_tensor_list)

  # Loop while the gradient counter (loop var 0) is below the forward
  # iteration count (loop var 1); the remaining loop vars are ignored here.
  def grad_cond(counter, max_iters, *unused_args):
    return counter < max_iters

  # The gradient body expects its explicit args followed by its captures.
  loop_vars = args + body_grad_graph.external_captures
  cond_grad_graph = function.func_graph_from_py_func(
      _get_unique_name("%s_grad_cond" % op.name),
      grad_cond, loop_vars, {})

  # Sanity checks: cond and body must agree on the loop-var signature.
  assert len(loop_vars) == len(body_grad_graph.inputs)
  assert len(loop_vars) == len(body_grad_graph.outputs)
  assert len(loop_vars) == len(cond_grad_graph.inputs)

  # NOTE(review): `_create_new_tf_function` is a private helper of the cond_v2
  # module; consider exposing a shared public helper instead.
  outputs = gen_functional_ops._while(
      loop_vars,
      cond_v2._create_new_tf_function(cond_grad_graph),
      cond_v2._create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      name=_get_unique_name("%s_grad" % op.name))

  _copy_handle_data(body_grad_graph.outputs, outputs)
  _maybe_set_lowering_attr(outputs[0].op)

  # outputs[0] is the loop counter.
  # outputs[1] is the total number of loop iterations.
  return outputs[2:2 + len(op.inputs)]
def wrap_function(fn, signature, name=None):
  """Traces a TF 1.x-style Python function into a callable graph function.

  `fn` is invoked exactly once with symbolic arguments described by
  `signature`. The trace becomes a graph function that can later be called
  with tensors matching that signature. Variables created by `fn` during the
  trace are owned by the object this function returns.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Each call to tf.compat.v1.wrap_function produces a fresh trace, a fresh set
  # of variables, and may use different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Args:
    fn: The Python function to trace.
    signature: The placeholders and Python values passed to `fn` while
      tracing.
    name: Optional name for the traced function.

  Returns:
    The wrapped graph function.
  """
  variable_holder = VariableHolder(fn)
  traced_graph = function.func_graph_from_py_func(
      name,
      variable_holder,
      args=None,
      kwargs=None,
      signature=signature,
      add_control_dependencies=False)
  wrapped = function.Function(traced_graph, signature=signature)
  # Keep the holder on the result so the traced variables are owned by it.
  wrapped._variable_holder = variable_holder  # pylint: disable=protected-access
  return wrapped
def wrap_function(fn, signature, name=None):
  """Wraps the TF 1.x function fn into a graph function.

  The python function `fn` will be called once with symbolic arguments specified
  in the `signature`, traced, and turned into a graph function. Any variables
  created by `fn` will be owned by the object returned by `wrap_function`. The
  resulting graph function can be called with tensors which match the
  signature.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Args:
    fn: python function to be wrapped
    signature: the placeholder and python arguments to be passed to the
      wrapped function
    name: Optional. The name of the function.

  Returns:
    the wrapped graph function.
  """
  # The holder wraps `fn` and is what actually gets traced.
  holder = VariableHolder(fn)
  # Note: `fn` is rebound here; from now on the original Python callable is
  # reachable only through `holder`.
  fn = function.Function(
      function.func_graph_from_py_func(
          name,
          holder,
          args=None,
          kwargs=None,
          signature=signature,
          add_control_dependencies=False),
      signature=signature)
  # Attach the holder to the returned function so that the variables created
  # while tracing are owned by the returned object, as documented above.
  fn._variable_holder = holder  # pylint: disable=protected-access
  return fn
def _create_grad_func(func_graph, grads, name, while_op):
  """Traces the gradient body of a forward While loop.

  The returned gradient FuncGraph must be invoked with the returned args
  followed by `grad_func_graph.captures`.

  Args:
    func_graph: FuncGraph of the forward loop body.
    grads: Incoming gradients, one per output of `func_graph`. Must be a list,
      because it is concatenated onto the counter/iteration-count prefix.
    name: Name to give the gradient function.
    while_op: The forward While op.

  Returns:
    A `(grad_func_graph, args)` pair, where `args` is
    `[counter initialized to 0., while_op.outputs[0]] + grads`.
  """
  assert len(func_graph.outputs) == len(grads)

  counter_init = constant_op.constant(0.)
  # TODO(srbs): For nested while loops will need to lookup this value from
  # the accumulator of the enclosing while loop. For now use as is assuming
  # there is no nesting.
  iteration_count = while_op.outputs[0]
  loop_args = [counter_init, iteration_count] + grads

  # `loop_args` become explicit inputs of the traced function, so they are not
  # part of its `external_captures`.
  grad_graph = function.func_graph_from_py_func(
      name,
      lambda *traced_args: _grad_fn(func_graph, traced_args),
      loop_args,
      {},
      func_graph=_WhileBodyGradFuncGraph(name, func_graph))

  # Expose each popped accumulator (keyed by its internal capture) as an
  # extra output of the gradient body.
  for captured in grad_graph.internal_captures:
    grad_graph.outputs.append(grad_graph.popped_tensor_lists[captured])

  return grad_graph, loop_args
def _create_grad_func(func_graph, grads, name, while_op):
  """Builds and returns the gradient FuncGraph of `func_graph` and its args.

  The returned grad_func_graph must be called with the returned
  args + grad_func_graph.captures.

  Args:
    func_graph: FuncGraph for the forward body function.
    grads: The incoming grads for `func_graph`'s outputs. Must be a list, since
      it is concatenated onto a list below.
    name: Name of the returned gradient function.
    while_op: The forward While op.

  Returns:
    2-tuple of (grad_func_graph, args), where args is
    `[loop counter initialized to 0., while_op.outputs[0]] + grads`.
  """
  assert len(func_graph.outputs) == len(grads)

  loop_counter = constant_op.constant(0.)
  # TODO(srbs): For nested while loops will need to lookup this value from
  # the accumulator of the enclosing while loop. For now use as is assuming
  # there is no nesting.
  num_iters_t = while_op.outputs[0]

  args = [loop_counter, num_iters_t] + grads

  # Note: The returned function does not have `args` in the list of
  # `external_captures`.
  # The lambda's `*args` shadows the outer `args`; it receives the traced
  # placeholders, not the outer tensors.
  grad_func_graph = function.func_graph_from_py_func(
      name,
      lambda *args: _grad_fn(func_graph, args),
      args, {},
      func_graph=_WhileBodyGradFuncGraph(name, func_graph))

  # Add the popped accumulators to the list of outputs.
  for internal_capture in grad_func_graph.internal_captures:
    grad_func_graph.outputs.append(
        grad_func_graph.popped_tensor_lists[internal_capture])

  return grad_func_graph, args
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  Args:
    pred: Predicate selecting which branch runs. Must not be a Python bool.
    true_fn: Callable (no arguments) traced into the "true" branch function.
    false_fn: Callable (no arguments) traced into the "false" branch function.
    name: Optional name scope for the op; falls back to "cond" when falsy.

  Returns:
    The user-visible outputs of the If op: a single tensor when the branches
    return exactly one output, otherwise a tuple. Extra outputs added for the
    gradient are not returned.

  Raises:
    TypeError: If `pred` is a Python bool.
  """
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      # Find the outer most graph for uniquing function names.
      # TODO(jpienaar): Make this work in eager mode.
      graph = ops.get_default_graph()
      while isinstance(graph, function.FuncGraph):
        graph = graph.outer_graph

      # Branch function names come from the op's scope with "/" replaced by
      # "_", uniquified in the outermost graph found above.
      true_name = graph.unique_name(("%strue" % scope).replace("/", "_"))
      false_name = graph.unique_name(("%sfalse" % scope).replace("/", "_"))

    # Trace each branch, with no explicit inputs, into its own FuncGraph.
    true_graph = function.func_graph_from_py_func(
        true_name, true_fn, [], {})
    false_graph = function.func_graph_from_py_func(
        false_name, false_fn, [], {})
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.external_captures,
                                     false_graph.external_captures)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred, cond_inputs, [t.dtype for t in true_graph.outputs],
        _create_new_tf_function(true_graph),
        _create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=scope)

    # Set the flag to enable lowering on the `if` op if necessary
    # Lowering allows cond_v2 to avoid some of the limitations of Functions,
    # allowing users to specify devices & colocation inside of cond_v2 branches,
    # and enabling non-strict evaluation & partial pruning of cond_v2 branches.
    # This brings cond_v2 closer to feature parity with tf.cond.
    #
    # However, we do not lower `If` in the XLA context because it is easier for
    # XLA to apply its own optimizations when dealing with un-lowered `If`
    # operators than with lowered switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      # pylint: disable=protected-access
      if_op._set_attr("_lower_using_switch_merge",
                      attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access

    # Drop the padded intermediate outputs; only the branch outputs are
    # returned to the caller.
    result = tuple(tensors[:num_cond_outputs])
    if len(result) == 1:
      return result[0]
    else:
      return result
def _create_grad_func(func_graph, grads, name):
  """Returns the FuncGraph representation of _grad_fn.

  Args:
    func_graph: The branch FuncGraph being differentiated.
    grads: The incoming gradients, closed over by the traced function.
    name: Name for the resulting gradient FuncGraph.

  Returns:
    The FuncGraph obtained by tracing `_grad_fn(func_graph, grads)` with no
    explicit inputs.
  """
  return function.func_graph_from_py_func(
      name, lambda: _grad_fn(func_graph, grads), [], {})
def while_loop(cond, body, loop_vars, shape_invariants=None, name=None):
  """Like tf.while_loop, except emits a single While op.

  Args:
    cond: Callable invoked with the flattened loop vars; its result is the
      loop predicate.
    body: Callable invoked with the flattened loop vars; returns either a
      single tensor or a sequence of new values for the loop vars.
    loop_vars: (Possibly nested) structure of initial loop values.
    shape_invariants: Optional structure matching `loop_vars` with the shape
      each loop var may take. Defaults to the initial shapes.
    name: Optional name scope; defaults to "while".

  Returns:
    The final value of the loop var when there is exactly one, otherwise the
    final values packed into the structure of `loop_vars`.
  """
  # `collections.Sequence` was a deprecated alias that Python 3.10 removed;
  # use `collections.abc` when it exists and fall back for Python 2.
  try:
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
  except ImportError:
    collections_abc = collections

  flattened_loop_vars = nest.flatten(loop_vars)
  if shape_invariants is not None:
    nest.assert_same_structure(loop_vars, shape_invariants)
    flattened_shapes = nest.flatten(shape_invariants)
  else:
    flattened_shapes = [t.shape for t in flattened_loop_vars]

  del shape_invariants

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")

    num_outputs = len(flattened_loop_vars)

    # Add loop counter needed for computing gradients.
    flattened_loop_vars = [constant_op.constant(0., name="loop_counter")
                          ] + flattened_loop_vars

    flattened_shapes = [tensor_shape.scalar()] + flattened_shapes

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(unused_loop_counter, *loop_vars):
      return cond(*loop_vars)

    signature = [
        tensor_spec.TensorSpec(shape, t.dtype)
        for shape, t in zip(flattened_shapes, flattened_loop_vars)
    ]
    cond_graph = function.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        flattened_loop_vars, {},
        signature=signature,
        func_graph=util.WhileCondFuncGraph(cond_name))

    # Add external_captures of cond to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures
    flattened_shapes = flattened_shapes + [
        t.shape for t in cond_graph.external_captures
    ]

    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args
          args[:num_outputs] - Args for the original loop body.
          args[num_outputs:] - External captures of cond. These get passed
            through as is.

      Returns:
        A list with one tensor per input: the incremented counter, the body's
        new loop values, then the passed-through cond captures.
      """
      outputs = body(*args[:num_outputs])
      # A body that returns a single tensor is treated as a 1-element list.
      if not isinstance(outputs, collections_abc.Sequence):
        outputs = [outputs]

      # Return the external_captures of cond_graph as is, i.e., treat them as
      # loop invariants.
      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs) + list(
          args[num_outputs:])

    signature = [
        tensor_spec.TensorSpec(shape, t.dtype)
        for shape, t in zip(flattened_shapes, flattened_loop_vars)
    ]
    body_graph = function.func_graph_from_py_func(
        body_name,
        wrapped_body,
        flattened_loop_vars, {},
        signature=signature,
        func_graph=util.WhileBodyFuncGraph(body_name))
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    flattened_loop_vars = flattened_loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture `external_captures` of `body_graph` in `cond_graph` so that it
    # expects to receive those as arguments.
    # TODO(srbs): Dedup tensors that are captured in both the cond and body.
    # This logic already exists in cond_v2.
    with cond_graph.as_default():
      for external_capture in body_graph.external_captures:
        cond_graph.capture(external_capture)

    # Export all tensors in the loop body that may be needed for gradient
    # computation. We do this by accumulating the intermediate values in
    # TensorLists.
    intermediate_tensors = _get_intermediates(body_graph)

    for intermediate_tensor in intermediate_tensors:
      # TODO(srbs): Cache and re-use empty tensor lists.
      tensor_list = list_ops.empty_tensor_list(
          element_dtype=intermediate_tensor.dtype,
          element_shape=_get_tensor_convertible_shape(
              intermediate_tensor.shape))
      flattened_loop_vars.append(tensor_list)
      with cond_graph.as_default():
        # Add a placeholder to cond_graph's inputs corresponding to the
        # tensor_list.
        cond_graph.capture(tensor_list)
      with body_graph.as_default():
        # Push the intermediate tensor to the tensor list. This captures the
        # `tensor_list` as well.
        appended_tensor_list = list_ops.tensor_list_push_back(
            tensor_list, intermediate_tensor)
        # Add this modified tensor list to the list of outputs.
        body_graph.outputs.append(appended_tensor_list)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    _check_shapes_compat(body_graph.outputs[1:1 + num_outputs],
                         flattened_shapes[1:1 + num_outputs],
                         flattened_loop_vars[1:1 + num_outputs])
    outputs = gen_functional_ops._while(
        flattened_loop_vars,
        util.create_new_tf_function(cond_graph),
        util.create_new_tf_function(body_graph),
        output_shapes=[t.shape for t in body_graph.outputs],
        name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    _maybe_set_lowering_attr(outputs[0].op)

  # First var is loop counter.
  if num_outputs == 1:
    return outputs[1]
  else:
    return nest.pack_sequence_as(loop_vars, outputs[1:1 + num_outputs])
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  Args:
    pred: Predicate selecting which branch runs. Must not be a Python bool.
    true_fn: Callable (no arguments) traced into the "true" branch function.
    false_fn: Callable (no arguments) traced into the "false" branch function.
    name: Optional name scope for the op; falls back to "cond" when falsy.

  Returns:
    The user-visible outputs of the If op: a single tensor when the branches
    return exactly one output, otherwise a tuple. Extra outputs added for the
    gradient are not returned.

  Raises:
    TypeError: If `pred` is a Python bool.
  """
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      # Find the outer most graph for uniquing function names.
      # TODO(jpienaar): Make this work in eager mode.
      graph = ops.get_default_graph()
      while isinstance(graph, function.FuncGraph):
        graph = graph.outer_graph

      # Branch function names come from the op's scope with "/" replaced by
      # "_", uniquified in the outermost graph found above.
      true_name = graph.unique_name(("%strue" % scope).replace("/", "_"))
      false_name = graph.unique_name(
          ("%sfalse" % scope).replace("/", "_"))

    # Trace each branch, with no explicit inputs, into a CondBranchFuncGraph.
    true_graph = function.func_graph_from_py_func(
        true_name, true_fn, [], {},
        func_graph=util.CondBranchFuncGraph(true_name))
    false_graph = function.func_graph_from_py_func(
        false_name, false_fn, [], {},
        func_graph=util.CondBranchFuncGraph(false_name))
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.external_captures,
                                     false_graph.external_captures)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred, cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=scope)

    # Set the flag to enable lowering on the `if` op if necessary
    # Lowering allows cond_v2 to avoid some of the limitations of Functions,
    # allowing users to specify devices & colocation inside of cond_v2 branches,
    # and enabling non-strict evaluation & partial pruning of cond_v2 branches.
    # This brings cond_v2 closer to feature parity with tf.cond.
    #
    # However, we do not lower `If` in the XLA context because it is easier for
    # XLA to apply its own optimizations when dealing with un-lowered `If`
    # operators than with lowered switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      # pylint: disable=protected-access
      if_op._set_attr("_lower_using_switch_merge",
                      attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access

    # Drop the padded intermediate outputs; only the branch outputs are
    # returned to the caller.
    result = tuple(tensors[:num_cond_outputs])
    if len(result) == 1:
      return result[0]
    else:
      return result
def _create_grad_func(func_graph, grads, name):
  """Returns the FuncGraph representation of _grad_fn.

  Args:
    func_graph: The branch FuncGraph being differentiated.
    grads: The incoming gradients, closed over by the traced function.
    name: Name for the resulting gradient FuncGraph.

  Returns:
    A CondBranchFuncGraph obtained by tracing `_grad_fn(func_graph, grads)`
    with no explicit inputs.
  """
  return function.func_graph_from_py_func(
      name, lambda: _grad_fn(func_graph, grads), [], {},
      func_graph=util.CondBranchFuncGraph(name))
def while_loop(cond, body, loop_vars, shape_invariants=None, name=None):
  """Like tf.while_loop, except emits a single While op.

  Args:
    cond: Callable invoked with the flattened loop vars; its result is the
      loop predicate.
    body: Callable invoked with the flattened loop vars; returns either a
      single tensor or a sequence of new values for the loop vars.
    loop_vars: (Possibly nested) structure of initial loop values.
    shape_invariants: Optional structure matching `loop_vars` with the shape
      each loop var may take. Defaults to the initial shapes.
    name: Optional name scope; defaults to "while".

  Returns:
    The final value of the loop var when there is exactly one, otherwise the
    final values packed into the structure of `loop_vars`.
  """
  # `collections.Sequence` was a deprecated alias that Python 3.10 removed;
  # use `collections.abc` when it exists and fall back for Python 2.
  try:
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
  except ImportError:
    collections_abc = collections

  flattened_loop_vars = nest.flatten(loop_vars)
  if shape_invariants is not None:
    nest.assert_same_structure(loop_vars, shape_invariants)
    flattened_shapes = nest.flatten(shape_invariants)
  else:
    flattened_shapes = [t.shape for t in flattened_loop_vars]

  del shape_invariants

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = _get_unique_name(("%scond" % scope).replace("/", "_"))
      body_name = _get_unique_name(("%sbody" % scope).replace("/", "_"))

    num_outputs = len(flattened_loop_vars)

    # Add loop counter needed for computing gradients.
    flattened_loop_vars = [constant_op.constant(0., name="loop_counter")
                          ] + flattened_loop_vars

    flattened_shapes = [tensor_shape.scalar()] + flattened_shapes

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(unused_loop_counter, *loop_vars):
      return cond(*loop_vars)

    signature = [
        tensor_spec.TensorSpec(shape, t.dtype)
        for shape, t in zip(flattened_shapes, flattened_loop_vars)
    ]
    cond_graph = function.func_graph_from_py_func(
        cond_name, wrapped_cond, flattened_loop_vars, {}, signature=signature)

    # Add external_captures of cond to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures
    flattened_shapes = flattened_shapes + [
        t.shape for t in cond_graph.external_captures
    ]

    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args
          args[:num_outputs] - Args for the original loop body.
          args[num_outputs:] - External captures of cond. These get passed
            through as is.

      Returns:
        A list with one tensor per input: the incremented counter, the body's
        new loop values, then the passed-through cond captures.
      """
      outputs = body(*args[:num_outputs])
      # A body that returns a single tensor is treated as a 1-element list.
      if not isinstance(outputs, collections_abc.Sequence):
        outputs = [outputs]

      # Return the external_captures of cond_graph as is, i.e., treat them as
      # loop invariants.
      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs) + list(args[num_outputs:])

    signature = [
        tensor_spec.TensorSpec(shape, t.dtype)
        for shape, t in zip(flattened_shapes, flattened_loop_vars)
    ]
    body_graph = function.func_graph_from_py_func(
        body_name, wrapped_body, flattened_loop_vars, {}, signature=signature)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    flattened_loop_vars = flattened_loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture `external_captures` of `body_graph` in `cond_graph` so that it
    # expects to receive those as arguments.
    # TODO(srbs): Dedup tensors that are captured in both the cond and body.
    # This logic already exists in cond_v2.
    with cond_graph.as_default():
      for external_capture in body_graph.external_captures:
        cond_graph.capture(external_capture)

    # Export all tensors in the loop body that may be needed for gradient
    # computation. We do this by accumulating the intermediate values in
    # TensorLists.
    intermediate_tensors = _get_intermediates(body_graph)

    for intermediate_tensor in intermediate_tensors:
      # TODO(srbs): Cache and re-use empty tensor lists.
      tensor_list = list_ops.empty_tensor_list(
          element_dtype=intermediate_tensor.dtype,
          element_shape=_get_tensor_convertible_shape(
              intermediate_tensor.shape))
      flattened_loop_vars.append(tensor_list)
      with cond_graph.as_default():
        # Add a placeholder to cond_graph's inputs corresponding to the
        # tensor_list.
        cond_graph.capture(tensor_list)
      with body_graph.as_default():
        # Push the intermediate tensor to the tensor list. This captures the
        # `tensor_list` as well.
        appended_tensor_list = list_ops.tensor_list_push_back(
            tensor_list, intermediate_tensor)
        # Add this modified tensor list to the list of outputs.
        body_graph.outputs.append(appended_tensor_list)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    _check_shapes_compat(body_graph.outputs[1:1 + num_outputs],
                         flattened_shapes[1:1 + num_outputs],
                         flattened_loop_vars[1:1 + num_outputs])
    # NOTE(review): relies on a private helper of the cond_v2 module.
    outputs = gen_functional_ops._while(
        flattened_loop_vars,
        cond_v2._create_new_tf_function(cond_graph),
        cond_v2._create_new_tf_function(body_graph),
        output_shapes=[t.shape for t in body_graph.outputs],
        name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    _maybe_set_lowering_attr(outputs[0].op)

  # First var is loop counter.
  if num_outputs == 1:
    return outputs[1]
  else:
    return nest.pack_sequence_as(loop_vars, outputs[1:1 + num_outputs])
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  Args:
    pred: Predicate selecting which branch runs. Must not be a Python bool.
    true_fn: Callable (no arguments) traced into the "true" branch function.
    false_fn: Callable (no arguments) traced into the "false" branch function.
    name: Optional name scope for the op; falls back to "cond" when falsy.

  Returns:
    Identity tensors of the user-visible If outputs: a single tensor when the
    branches return exactly one output, otherwise a tuple. Extra outputs added
    for the gradient are not returned.

  Raises:
    TypeError: If `pred` is a Python bool.
  """
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()
    true_graph = function.func_graph_from_py_func(
        true_name,
        true_fn, [], {},
        func_graph=util.CondBranchFuncGraph(true_name),
        add_control_dependencies=add_control_dependencies)
    false_graph = function.func_graph_from_py_func(
        false_name,
        false_fn, [], {},
        func_graph=util.CondBranchFuncGraph(false_name),
        add_control_dependencies=add_control_dependencies)
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.external_captures,
                                     false_graph.external_captures)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred, cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=scope)

    # Set the flag to enable lowering on the `if` op if necessary
    # Lowering allows cond_v2 to avoid some of the limitations of Functions,
    # allowing users to specify devices & colocation inside of cond_v2 branches,
    # and enabling non-strict evaluation & partial pruning of cond_v2 branches.
    # This brings cond_v2 closer to feature parity with tf.cond.
    #
    # However, we do not lower `If` in the XLA context because it is easier for
    # XLA to apply its own optimizations when dealing with un-lowered `If`
    # operators than with lowered switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      # pylint: disable=protected-access
      if_op._set_attr("_lower_using_switch_merge",
                      attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = tuple(array_ops.identity(t) for t in tensors)

    # Drop the padded intermediate outputs; only the branch outputs are
    # returned to the caller.
    result = tuple(tensors[:num_cond_outputs])
    if len(result) == 1:
      return result[0]
    else:
      return result