def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  body_graph = _get_body_graph(op)

  # Replace None gradients with zeros. This is needed because `grads` could
  # have None incoming gradients for the TensorLists. If we pass None's
  # through, the custom gradient of TensorListPopBack will create an
  # EmptyTensorList inside the FuncGraph which is undesirable.
  # TODO(b/80444525): There might be an issue with treating no gradient as
  # zero gradient in certain cases. Consider replacing None gradients with
  # Zeros for accumulators only.
  grads = [
      g if g is not None else array_ops.zeros_like(output)
      for g, output in zip(grads, op.outputs)
  ]

  body_grad_graph, args = _create_grad_func(
      body_graph, grads, _get_unique_name("%s_grad" % body_graph.name), op)

  intermediate_tensors = _get_intermediates(body_grad_graph)

  for intermediate_tensor in intermediate_tensors:
    tensor_list = list_ops.empty_tensor_list(
        element_dtype=intermediate_tensor.dtype,
        element_shape=_get_tensor_convertible_shape(intermediate_tensor.shape))
    with body_grad_graph.as_default():
      tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
      # Push the intermediate tensor to the tensor list.
      appended_tensor_list = list_ops.tensor_list_push_back(
          tensor_list_ph, intermediate_tensor)
      # Add this modified tensor list to the list of outputs.
      body_grad_graph.outputs.append(appended_tensor_list)

  def grad_cond(counter, max_iters, *unused_args):
    return counter < max_iters

  loop_vars = args + body_grad_graph.external_captures
  cond_grad_graph = function.func_graph_from_py_func(
      _get_unique_name("%s_grad_cond" % op.name), grad_cond, loop_vars, {})

  assert len(loop_vars) == len(body_grad_graph.inputs)
  assert len(loop_vars) == len(body_grad_graph.outputs)
  assert len(loop_vars) == len(cond_grad_graph.inputs)

  outputs = gen_functional_ops._while(
      loop_vars,
      cond_v2._create_new_tf_function(cond_grad_graph),
      cond_v2._create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      name=_get_unique_name("%s_grad" % op.name))

  _copy_handle_data(body_grad_graph.outputs, outputs)
  _maybe_set_lowering_attr(outputs[0].op)

  # outputs[0] is the loop counter.
  # outputs[1] is the total number of loop iterations.
  return outputs[2:2 + len(op.inputs)]
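# A minimal usage sketch, not part of this module: taking a gradient through
# the single While op emitted by `while_loop` (defined below). This assumes
# `_WhileGrad` is registered as the gradient function for the "While" op
# (e.g. via `@ops.RegisterGradient("While")`) and that `gradients_impl` from
# tensorflow.python.ops is in scope:
#
#   x = constant_op.constant(2.)
#   ret = while_loop(lambda v: v < 16., lambda v: v * v, [x])
#   grad = gradients_impl.gradients(ret, [x])  # dispatches to _WhileGrad
#
# The backward pass pops intermediates back out of the TensorList
# accumulators built above; the list ops involved behave roughly like this
# (assuming `dtypes` from tensorflow.python.framework is in scope):
#
#   tl = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
#                                   element_shape=[])
#   tl = list_ops.tensor_list_push_back(tl, constant_op.constant(1.))
#   tl, val = list_ops.tensor_list_pop_back(tl, element_dtype=dtypes.float32)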
def while_loop(cond, body, loop_vars, name=None):
  """Like tf.while_loop, except emits a single While op."""
  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = _get_unique_name(("%scond" % scope).replace("/", "_"))
      body_name = _get_unique_name(("%sbody" % scope).replace("/", "_"))

    flattened_loop_vars = nest.flatten(loop_vars)
    num_outputs = len(flattened_loop_vars)

    # Add loop counter needed for computing gradients.
    flattened_loop_vars = [constant_op.constant(0., name="loop_counter")
                          ] + flattened_loop_vars

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(unused_loop_counter, *loop_vars):
      return cond(*loop_vars)

    cond_graph = function.func_graph_from_py_func(cond_name, wrapped_cond,
                                                  flattened_loop_vars, {})

    # Add external_captures of cond to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures

    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args
          args[:num_outputs] - Args for the original loop body.
          args[num_outputs:] - External captures of cond. These get passed
            through as is.

      Returns:
        A list of tensors the same length as args.
      """
      outputs = body(*args[:num_outputs])
      if not isinstance(outputs, collections.Sequence):
        outputs = [outputs]

      # Return the external_captures of cond_graph as is, i.e., treat them as
      # loop invariants.
      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs) + list(args[num_outputs:])

    body_graph = function.func_graph_from_py_func(body_name, wrapped_body,
                                                  flattened_loop_vars, {})

    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    flattened_loop_vars = flattened_loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture `external_captures` of `body_graph` in `cond_graph` so that it
    # expects to receive those as arguments.
    # TODO(srbs): Dedup tensors that are captured in both the cond and body.
    # This logic already exists in cond_v2.
    with cond_graph.as_default():
      for external_capture in body_graph.external_captures:
        cond_graph.capture(external_capture)

    # Export all tensors in the loop body that may be needed for gradient
    # computation. We do this by accumulating the intermediate values in
    # TensorLists.
    intermediate_tensors = _get_intermediates(body_graph)

    for intermediate_tensor in intermediate_tensors:
      # TODO(srbs): Cache and re-use empty tensor lists.
      tensor_list = list_ops.empty_tensor_list(
          element_dtype=intermediate_tensor.dtype,
          element_shape=_get_tensor_convertible_shape(
              intermediate_tensor.shape))
      flattened_loop_vars.append(tensor_list)
      with cond_graph.as_default():
        # Add a placeholder to cond_graph's inputs corresponding to the
        # tensor_list.
        cond_graph.capture(tensor_list)
      with body_graph.as_default():
        # Push the intermediate tensor to the tensor list. This captures the
        # `tensor_list` as well.
        appended_tensor_list = list_ops.tensor_list_push_back(
            tensor_list, intermediate_tensor)
        # Add this modified tensor list to the list of outputs.
        body_graph.outputs.append(appended_tensor_list)

    outputs = gen_functional_ops._while(
        flattened_loop_vars,
        cond_v2._create_new_tf_function(cond_graph),
        cond_v2._create_new_tf_function(body_graph),
        name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    _maybe_set_lowering_attr(outputs[0].op)

  # First var is loop counter.
  if num_outputs == 1:
    return outputs[1]
  else:
    return nest.pack_sequence_as(loop_vars, outputs[1:1 + num_outputs])
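# A minimal usage sketch, not part of this module, assuming the imports used
# above (`constant_op`, `nest`) are in scope:
#
#   # Doubles `x` until it reaches at least 10; emits exactly one While op.
#   x = constant_op.constant(1.)
#   result = while_loop(lambda v: v < 10., lambda v: 2. * v, [x])
#
# `loop_vars` may be any structure supported by `nest.flatten`; the loop
# outputs are packed back into that same structure via
# `nest.pack_sequence_as`, as in the return above.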