def _make_op(inputs): if_op, tensors = util.get_op_and_outputs(op_fn( pred, inputs, [t.dtype for t in true_graph.outputs], util.create_new_tf_function(true_graph), util.create_new_tf_function(false_graph), output_shapes=_get_output_shapes(true_graph.outputs, false_graph.outputs), name=name)) _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs) # `if_op` is None if this is a `StatelessIf` op with no outputs. if if_op is not None: # The true and false graphs have already been created, and we need that # to happen before we know which tensors will be captured and so whether # to wrap the cond in a tf.function. Post-hoc mutation of the branch # `outer_graph` properties seems like the only option if we want to # conditionally wrap in a function. true_graph.outer_graph = ops.get_default_graph() false_graph.outer_graph = ops.get_default_graph() if_op._true_graph = true_graph if_op._false_graph = false_graph util.maybe_set_lowering_attr(if_op) util.maybe_propagate_compile_time_consts_in_xla(if_op) _set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph]) # Prevent fetching since the variant outputs can't be fetched directly. if_op.graph.prevent_fetching(if_op) return tensors
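# Hedged sketch of how the `op_fn` passed into _make_op above is typically
# chosen: a StatelessIf op when neither branch graph contains stateful ops,
# otherwise a (stateful) If. The exact selection logic lives in the caller and
# may differ; `_example_choose_if_op_fn` is illustrative only.
def _example_choose_if_op_fn(true_graph, false_graph):
  branch_ops = true_graph.get_operations() + false_graph.get_operations()
  if any(op._is_stateful for op in branch_ops):  # pylint: disable=protected-access
    return gen_functional_ops._if  # pylint: disable=protected-access
  return gen_functional_ops.stateless_if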
def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
  """Creates a `Case` op from `branch_index`, branch graphs and inputs.

  Note that this modifies `branch_graphs` to make the inputs match, and to
  output all intermediate values so they're available for the gradient
  computation.

  `branch_graphs` need not have the same input types, but they must
  have the same output types.

  Args:
    branch_index: integer Tensor
    branch_graphs: List of FuncGraph
    branch_inputs: List of lists of Tensors to be passed to corresponding
      branch_graph as input.
    name: the name for the Case op.

  Returns:
    A list of Tensors which are the outputs of the Case op. Does not include
    added intermediate outputs.
  """
  _make_indexed_slices_indices_types_match(_CASE, branch_graphs)
  _check_same_outputs(_CASE, branch_graphs)

  # Add inputs to branch_graphs to make them match. Note that this modifies the
  # graphs in `branch_graphs`.
  case_inputs = _make_inputs_match(branch_graphs, branch_inputs)

  # Create the Case op.
  with ops.control_dependencies(
      sum((list(bg.control_captures) for bg in branch_graphs), [])):
    tensors = gen_functional_ops.case(
        branch_index,
        case_inputs, [t.dtype for t in branch_graphs[0].outputs],
        [util.create_new_tf_function(g) for g in branch_graphs],
        output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
        name=name)

  case_op, tensors = _get_op_and_outputs(tensors)

  if case_op is not None:
    util.maybe_set_lowering_attr(case_op)
    util.maybe_propagate_compile_time_consts_in_xla(case_op)
    _set_read_only_resource_inputs_attr(case_op, branch_graphs)
    # Prevent fetching since the variant outputs can't be fetched directly.
    case_op.graph.prevent_fetching(case_op)

  # Return identities for each output of the Case op, rather than the output of
  # the Case op directly. This makes pruning work if the output of switch_case()
  # is fetched: the lowering pass converts the Case outputs into IdentityN
  # outputs, which if fetched will cause all ops in the taken branch to be run
  # (since it takes all merge ops as input). After lowering, each output
  # identity op will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors)
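# A minimal, hedged usage sketch: _build_case is internal, but the comment
# above mentions switch_case(), so this shows the public-API shape that the
# Case op ultimately serves. The exact dispatch path from tf.switch_case down
# to this builder is an assumption rather than something this file asserts.
def _example_switch_case():
  import tensorflow as tf  # assumption: public API available to the caller

  x = tf.constant(3.0)
  branch_index = tf.constant(1)
  # Each callable is traced into one branch FuncGraph; all branches must
  # produce matching output types, mirroring _check_same_outputs above.
  return tf.switch_case(
      branch_index,
      branch_fns=[lambda: x * 2.0, lambda: x + 10.0, lambda: -x])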
def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
                name=None):
  """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include
    added intermediate outputs.
  """
  _check_same_outputs(true_graph, false_graph)

  # Add inputs to true_graph and false_graph to make them match. Note that
  # this modifies true_graph and false_graph.
  cond_inputs = _make_inputs_match(true_graph, false_graph,
                                   true_inputs, false_inputs)

  # Create the If op.
  tensors = gen_functional_ops._if(  # pylint: disable=protected-access
      pred,
      cond_inputs, [t.dtype for t in true_graph.outputs],
      util.create_new_tf_function(true_graph),
      util.create_new_tf_function(false_graph),
      output_shapes=_get_output_shapes(true_graph.outputs,
                                       false_graph.outputs),
      name=name)

  # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
  if_op = tensors[0].op
  util.maybe_set_lowering_attr(if_op)

  # Return identities for each output of the If op, rather than the output of
  # the If op directly. This makes pruning work if the output of cond() is
  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
  # which if fetched will cause all ops in the taken branch to be run (since
  # it takes all merge ops as input). After lowering, each output identity op
  # will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  # Prevent fetching since the variant outputs can't be fetched directly.
  if_op.graph.prevent_fetching(if_op)

  return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                            tensors)
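# Hedged sketch of how a caller such as cond_v2 feeds this builder: each
# branch is first traced into a FuncGraph, and the branches' external captures
# become the candidate If-op inputs that _make_inputs_match unifies. The
# FuncGraph constructor arguments are simplified relative to the real call
# sites and are an assumption here.
def _example_call_build_cond(pred, true_fn, false_fn):
  true_graph = func_graph_module.func_graph_from_py_func(
      "true_fn", true_fn, [], {},
      func_graph=util.CondBranchFuncGraph("true_fn"))
  false_graph = func_graph_module.func_graph_from_py_func(
      "false_fn", false_fn, [], {},
      func_graph=util.CondBranchFuncGraph("false_fn"))
  return _build_cond(pred, true_graph, false_graph,
                     true_graph.external_captures,
                     false_graph.external_captures,
                     name="example_cond")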
def _make_op(inputs): case_op, tensors = util.get_op_and_outputs(op_fn( branch_index, inputs, [t.dtype for t in branch_graphs[0].outputs], [util.create_new_tf_function(g) for g in branch_graphs], output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]), name=name)) _copy_handle_data(tensors, *[g.outputs for g in branch_graphs]) if case_op is not None: util.maybe_set_lowering_attr(case_op, lower_using_switch_merge) util.maybe_propagate_compile_time_consts_in_xla(case_op) _set_read_only_resource_inputs_attr(case_op, branch_graphs) # Prevent fetching since the variant outputs can't be fetched directly. case_op.graph.prevent_fetching(case_op) return tensors
def _WhileGrad(op, *grads): # pylint: disable=invalid-name """The gradient of a While op produced by while_loop.""" # Note that op is not always the same as while_op because the gradient tape, # for eager mode compatibility, forgets information about the proper op. Since # the loop cannot run in eager mode, however, we can safely introspect into # the graph here. while_op = op.outputs[0].op cond_graph = _get_graph(while_op, "cond") body_graph = _get_graph(while_op, "body") orig_num_params = len(body_graph.outputs) maximum_iterations = op.get_attr( "_maximum_iterations") if _is_in_xla_context() else None assert not _is_in_xla_context() or maximum_iterations is not None maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations) # Set the incoming gradient of non-trainable inputs to None. It is possible # that we receive non-None gradients for non-trainable types in nested while # loops because we accumulate outputs of the inner while as variant tensors # which are trainable and hence receive zeros_like tensors in the gradient # pass. The non-trainable tensors then receive the popped zeros tensor from # this zeros variant. The gradient for the loop vars corresponding to these # tensors is None or zeros (this happens only if the loop var is accumulated # as well) in _grad_fn so we reset these. # TODO(b/118712257): Remove the IsTrainable filter once we can handle None # output grads in _grad_fn. grads = [ None if not _is_trainable(output) else grad for grad, output in zip(grads, body_graph.outputs) ] # We compute the gradient for the sub-graph between trainable ys and xs # with non-None incoming gradients. We later pad the None's to the list of # outputs. ys, xs, non_none_grads = zip( *[(y, x, grad) for (y, x, grad) in zip(body_graph.outputs, body_graph.inputs, grads) if grad is not None]) body_grad_graph, args = _create_grad_func( ys, xs, non_none_grads, cond_graph, body_graph, util.unique_grad_fn_name(body_graph.name), op, maximum_iterations) if body_grad_graph.while_op_needs_rewrite: # Modify 'op' to output the intermediate accumulators needed by the grad # function. # NOTE(skyewm): if there are any active sessions, this modification to `op` # may make them unrunnable! 
cond_graph.name += "_rewritten" body_graph.name += "_rewritten" new_inputs = body_grad_graph.empty_tensor_lists new_outputs = body_graph.outputs[orig_num_params:] while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph)) while_op._set_func_attr("body", util.create_new_tf_function(body_graph)) while_op._set_type_list_attr("T", body_graph.output_types) while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes) while_op._add_while_inputs(new_inputs) while_op._add_outputs([t.dtype for t in new_outputs], [t.shape for t in new_outputs]) _copy_handle_data(new_outputs, op.outputs[orig_num_params:]) captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph, while_op) loop_vars = args + captured_inputs def grad_cond(counter, max_iters, *unused_args): return counter < max_iters grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name) cond_grad_graph = func_graph_module.func_graph_from_py_func( grad_cond_name, grad_cond, loop_vars, {}, func_graph=util.WhileCondFuncGraph(grad_cond_name)) _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars)) outputs = gen_functional_ops._while( loop_vars, util.create_new_tf_function(cond_grad_graph), util.create_new_tf_function(body_grad_graph), output_shapes=[t.shape for t in body_grad_graph.outputs], name="%s_grad" % while_op.name) _copy_handle_data(body_grad_graph.outputs, outputs) util.maybe_set_lowering_attr(outputs[0].op) _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations) # See comment in while_loop. outputs = [array_ops.identity(t) for t in outputs] # Set None as the output gradient for tensors with None input gradient. # outputs[0] is the loop counter. # outputs[1] is the total number of loop iterations. index = 2 none_padded_outputs = [] for g in grads: if g is None: none_padded_outputs.append(None) else: none_padded_outputs.append(outputs[index]) index += 1 return none_padded_outputs
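# Hedged sketch of what exercises _WhileGrad: build a loop with the while_loop
# defined in this module, then request gradients through it. The tf.gradients
# import is an assumption about the caller's environment.
def _example_while_grad():
  import tensorflow as tf  # assumption: public API available to the caller

  x = tf.constant(2.0)
  # Computes x ** 5 by repeated multiplication inside a single While op.
  _, y = while_loop(
      cond=lambda i, acc: i < 5,
      body=lambda i, acc: (i + 1, acc * x),
      loop_vars=(tf.constant(0), tf.constant(1.0)),
      return_same_structure=True)
  # Differentiating through the While op invokes _WhileGrad above.
  dy_dx, = tf.gradients(y, x)
  return dy_dx  # 5 * x**4, i.e. 80.0 when evaluated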
def _WhileGrad(op, *grads): # pylint: disable=invalid-name """The gradient of a While op produced by while_loop.""" body_graph = _get_body_graph(op) # Set the incoming gradient of TensorArray handles to None. The gradient # implementation currently assumes all resource tensors correspond to float32 # ResourceVariables, which can lead to runtime shape errors when used with a # TensorArray. This is a workaround until TensorArrays are reimplemented with # TensorLists instead of resources. # Also set the incoming gradient of non-trainable inputs to None. It is # possible that we receive non-None gradients for non-trainable types in # nested while loops because we accumulate outputs of the inner while as # variant tensors which are trainable and hence receive zeros_like tensors in # the gradient pass. The non-trainable tensors then receive the popped zeros # tensor from this zeros variant. The gradient for the loop vars corresponding # to these tensors is None or zeros (this happens only if the loop var is # accumulated as well) in _grad_fn so we reset these. # TODO(b/118712257): Remove the IsTrainable filter once we can handle None # output grads in _grad_fn. grads = [ None if _is_tensor_array_handle(output) or not gradients_impl.IsTrainable(output) else grad for grad, output in zip(grads, op.outputs) ] # Ensure that all non-resource trainable outputs have incoming gradients. assert all(g is not None or o.dtype == dtypes.resource or not gradients_impl.IsTrainable(o) for o, g in zip(op.outputs, grads) ), "All trainable loop vars must receive incoming gradients." # We compute the gradient for the sub-graph between trainable ys and xs # with non-None incoming gradients. We later pad the None's to the list of # outputs. ys, xs, non_none_grads = zip( *[(y, x, grad) for (y, x, grad) in zip(body_graph.outputs, body_graph.inputs, grads) if grad is not None]) body_grad_graph, args = _create_grad_func( ys, xs, non_none_grads, body_graph, util.unique_grad_fn_name(body_graph.name), op) intermediate_tensors = _get_intermediates(body_grad_graph) maximum_iterations = op.get_attr( "_maximum_iterations") if _is_in_xla_context() else None assert not _is_in_xla_context() or maximum_iterations is not None for intermediate_tensor in intermediate_tensors: tensor_list = list_ops.empty_tensor_list( element_dtype=intermediate_tensor.dtype, element_shape=intermediate_tensor.shape, max_num_elements=maximum_iterations) with body_grad_graph.as_default(): tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True) # Push the intermediate tensor to the tensor list. appended_tensor_list = list_ops.tensor_list_push_back( tensor_list_ph, intermediate_tensor) # Add this modified tensor list to the list of outputs. 
body_grad_graph.outputs.append(appended_tensor_list) def grad_cond(counter, max_iters, *unused_args): return counter < max_iters loop_vars = args + body_grad_graph.external_captures grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name) cond_grad_graph = func_graph_module.func_graph_from_py_func( grad_cond_name, grad_cond, loop_vars, {}, func_graph=util.WhileCondFuncGraph(grad_cond_name)) _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars)) outputs = gen_functional_ops._while( loop_vars, util.create_new_tf_function(cond_grad_graph), util.create_new_tf_function(body_grad_graph), output_shapes=[t.shape for t in body_grad_graph.outputs], name="%s_grad" % op.name) _copy_handle_data(body_grad_graph.outputs, outputs) util.maybe_set_lowering_attr(outputs[0].op) _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations) # See comment in while_loop. outputs = [array_ops.identity(t) for t in outputs] # Set None as the output gradient for tensors with None input gradient # e.g. TensorArray handles. # outputs[0] is the loop counter. # outputs[1] is the total number of loop iterations. index = 2 none_padded_outputs = [] for g in grads: if g is None: none_padded_outputs.append(None) else: none_padded_outputs.append(outputs[index]) index += 1 return none_padded_outputs
def cond_v2(pred, true_fn, false_fn, name="cond"): """Like tf.cond, except emits a single If op.""" if isinstance(pred, bool): raise TypeError("pred must not be a Python bool", pred) if not name: name = "cond" with ops.name_scope(name) as scope: true_name = util.unique_fn_name(scope, "true") false_name = util.unique_fn_name(scope, "false") # Automatic control dependencies are added in defuns, but not in v1 # graphs. Propagate that behavior here. add_control_dependencies = util.in_defun() pred = ops.convert_to_tensor(pred) true_graph = func_graph_module.func_graph_from_py_func( true_name, true_fn, [], {}, func_graph=util.CondBranchFuncGraph( true_name, read_only_collections=False), add_control_dependencies=add_control_dependencies, op_return_value=pred) false_graph = func_graph_module.func_graph_from_py_func( false_name, false_fn, [], {}, func_graph=util.CondBranchFuncGraph( false_name, read_only_collections=False), add_control_dependencies=add_control_dependencies, op_return_value=pred) _check_same_outputs(true_graph, false_graph) # Add inputs to true_graph and false_graph to make them match. Note that # this modifies true_graph and false_graph. cond_inputs = _make_inputs_match(true_graph, false_graph, true_graph.external_captures, false_graph.external_captures) # Add all intermediate tensors as function outputs so they're available for # the gradient computation. true_intermediates = _get_intermediates(true_graph) false_intermediates = _get_intermediates(false_graph) # Save the original number of outputs to return to the caller. num_cond_outputs = len(true_graph.outputs) # Make the number/type of new intermediate outputs match. extra_true_outputs, extra_false_outputs = _pad_params( true_graph, false_graph, true_intermediates, false_intermediates) true_graph.outputs.extend(extra_true_outputs) false_graph.outputs.extend(extra_false_outputs) # Create the If op. tensors = gen_functional_ops._if( # pylint: disable=protected-access pred, cond_inputs, [t.dtype for t in true_graph.outputs], util.create_new_tf_function(true_graph), util.create_new_tf_function(false_graph), output_shapes=_get_output_shapes(true_graph.outputs, false_graph.outputs), name=scope) # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output util.maybe_set_lowering_attr(tensors[0].op) # Return identities for each output of the If op, rather than the output of # the If op directly. This makes pruning work if the output of cond() is # fetched: the lowering pass converts the If outputs into IdentityN outputs, # which if fetched will cause all ops in the taken branch to be run (since # it takes all merge ops as input). After lowering, each output identity op # will end up with only the appropriate merge op as input. # TODO(b/79984175): this doesn't have to be a tuple once we covert to the # correct output structure tensors = tuple(array_ops.identity(t) for t in tensors) return func_graph_module.pack_sequence_as(true_graph.structured_outputs, tensors[:num_cond_outputs])
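# Minimal usage sketch for the cond_v2 defined above; the tensor values are
# illustrative only.
def _example_cond_v2():
  import tensorflow as tf  # assumption: public API available to the caller

  x = tf.constant(4.0)
  pred = tf.constant(True)
  # Emits a single If op (plus output identities) instead of the Switch/Merge
  # construct built by the v1 tf.cond implementation.
  return cond_v2(pred, lambda: x * 2.0, lambda: x - 1.0)  # 8.0 when pred is True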
def _IfGrad(op, *grads): # pylint: disable=invalid-name """The gradient of an If op produced by cond_v2.""" true_graph, false_graph = _get_func_graphs(op) # Note: op.graph != ops.get_default_graph() when we are computing the gradient # of a nested cond. assert true_graph.outer_graph == op.graph assert false_graph.outer_graph == op.graph # Create grad functions that compute the gradient of the true/false forward # graphs. These functions will capture tensors from the forward pass # functions. true_grad_graph = _create_grad_func( true_graph, grads, util.unique_grad_fn_name(true_graph.name)) false_grad_graph = _create_grad_func( false_graph, grads, util.unique_grad_fn_name(false_graph.name)) assert ([t.dtype for t in true_grad_graph.outputs] == [t.dtype for t in false_grad_graph.outputs]) # Resolve references to forward graph tensors in grad graphs and ensure # they are in-scope, i.e., belong to one of outer graphs of the grad graph. true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph) false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph) # Make the inputs to true_grad_graph and false_grad_graph match. Note that # this modifies true_grad_graph and false_grad_graph. grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph, true_grad_inputs, false_grad_inputs) # Add all intermediate tensors as function outputs so they're available for # higher-order gradient computations. true_grad_intermediates = _get_intermediates(true_grad_graph) false_grad_intermediates = _get_intermediates(false_grad_graph) # Save the original number of gradient outputs to return. num_grad_outputs = len(true_grad_graph.outputs) # Make the number/type of new intermediate outputs match. extra_true_grad_outputs, extra_false_grad_outputs = _pad_params( true_grad_graph, false_grad_graph, true_grad_intermediates, false_grad_intermediates) true_grad_graph.outputs.extend(extra_true_grad_outputs) false_grad_graph.outputs.extend(extra_false_grad_outputs) # Create the gradient If op. tensors = gen_functional_ops._if( op.inputs[0], grad_inputs, [t.dtype for t in true_grad_graph.outputs], util.create_new_tf_function(true_grad_graph), util.create_new_tf_function(false_grad_graph), output_shapes=_get_output_shapes(true_grad_graph.outputs, false_grad_graph.outputs)) util.maybe_set_lowering_attr(tensors[0].op) # See comment in cond_v2. tensors = [array_ops.identity(t) for t in tensors] # The predicate has no gradient. return [None] + tensors[:num_grad_outputs]
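# Hedged sketch of what exercises _IfGrad: differentiating through the If op
# emitted by cond_v2. The tf.gradients import is an assumption about the
# caller's environment.
def _example_if_grad():
  import tensorflow as tf  # assumption: public API available to the caller

  x = tf.constant(3.0)
  pred = tf.constant(True)
  y = cond_v2(pred, lambda: x * x, lambda: x + 1.0)
  # The gradient request builds the true/false grad FuncGraphs above and emits
  # a second (gradient) If op driven by the same predicate.
  dy_dx, = tf.gradients(y, x)
  return dy_dx  # 2 * x == 6.0 on the true branch, 1.0 on the false branch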
def _WhileGrad(op, *grads): # pylint: disable=invalid-name """The gradient of a While op produced by while_loop.""" cond_graph = _get_graph(op, "cond") body_graph = _get_graph(op, "body") orig_num_params = len(body_graph.outputs) maximum_iterations = op.get_attr( "_maximum_iterations") if _is_in_xla_context() else None assert not _is_in_xla_context() or maximum_iterations is not None # Set the incoming gradient of TensorArray handles to None. The gradient # implementation currently assumes all resource tensors correspond to float32 # ResourceVariables, which can lead to runtime shape errors when used with a # TensorArray. This is a workaround until TensorArrays are reimplemented with # TensorLists instead of resources. # Also set the incoming gradient of non-trainable inputs to None. It is # possible that we receive non-None gradients for non-trainable types in # nested while loops because we accumulate outputs of the inner while as # variant tensors which are trainable and hence receive zeros_like tensors in # the gradient pass. The non-trainable tensors then receive the popped zeros # tensor from this zeros variant. The gradient for the loop vars corresponding # to these tensors is None or zeros (this happens only if the loop var is # accumulated as well) in _grad_fn so we reset these. # TODO(b/118712257): Remove the IsTrainable filter once we can handle None # output grads in _grad_fn. grads = [ None if _is_tensor_array_handle(output) or not _is_trainable(output) else grad for grad, output in zip(grads, body_graph.outputs) ] # Ensure that all non-resource trainable outputs have incoming gradients. assert all( g is not None or o.dtype == dtypes.resource or not _is_trainable(o) for o, g in zip(body_graph.outputs, grads) ), "All trainable loop vars must receive incoming gradients." # We compute the gradient for the sub-graph between trainable ys and xs # with non-None incoming gradients. We later pad the None's to the list of # outputs. ys, xs, non_none_grads = zip( *[(y, x, grad) for (y, x, grad) in zip(body_graph.outputs, body_graph.inputs, grads) if grad is not None]) body_grad_graph, args = _create_grad_func( ys, xs, non_none_grads, cond_graph, body_graph, util.unique_grad_fn_name(body_graph.name), op, maximum_iterations) if body_grad_graph.while_op_needs_rewrite: # Modify 'op' to output the intermediate accumulators needed by the grad # function. # NOTE(skyewm): if there are any active sessions, this modification to `op` # may make them unrunnable! 
cond_graph.name += "_rewritten" body_graph.name += "_rewritten" new_inputs = body_grad_graph.empty_tensor_lists new_outputs = body_graph.outputs[orig_num_params:] op._set_func_attr("cond", util.create_new_tf_function(cond_graph)) op._set_func_attr("body", util.create_new_tf_function(body_graph)) op._set_type_list_attr("T", body_graph.output_types) op._set_shape_list_attr("output_shapes", body_graph.output_shapes) op._add_while_inputs(new_inputs) op._add_outputs([t.dtype for t in new_outputs], [t.shape for t in new_outputs]) _copy_handle_data(new_outputs, op.outputs[orig_num_params:]) captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph, op) loop_vars = args + captured_inputs def grad_cond(counter, max_iters, *unused_args): return counter < max_iters grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name) cond_grad_graph = func_graph_module.func_graph_from_py_func( grad_cond_name, grad_cond, loop_vars, {}, func_graph=util.WhileCondFuncGraph(grad_cond_name)) _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars)) outputs = gen_functional_ops._while( loop_vars, util.create_new_tf_function(cond_grad_graph), util.create_new_tf_function(body_grad_graph), output_shapes=[t.shape for t in body_grad_graph.outputs], name="%s_grad" % op.name) _copy_handle_data(body_grad_graph.outputs, outputs) util.maybe_set_lowering_attr(outputs[0].op) _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations) # See comment in while_loop. outputs = [array_ops.identity(t) for t in outputs] # Set None as the output gradient for tensors with None input gradient # e.g. TensorArray handles. # outputs[0] is the loop counter. # outputs[1] is the total number of loop iterations. index = 2 none_padded_outputs = [] for g in grads: if g is None: none_padded_outputs.append(None) else: none_padded_outputs.append(outputs[index]) index += 1 return none_padded_outputs
def while_loop(cond, body, loop_vars, shape_invariants=None, parallel_iterations=10, maximum_iterations=None, name=None, return_same_structure=True, back_prop=True): """Like tf.while_loop, except emits a single While op.""" # Keep the original loop_vars around to know which args were TensorArrays. orig_loop_vars = loop_vars # Cache its length since we use it at multiple places below. len_orig_loop_vars = len(orig_loop_vars) # Convert TensorArrays to their flow variables. These get converted back to # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and # `wrapped_body` below. loop_vars = list(_tensor_array_to_flow(orig_loop_vars)) loop_vars = nest.map_structure( ops.internal_convert_to_tensor_or_indexed_slices, loop_vars, expand_composites=True) if shape_invariants is not None: nest.assert_same_structure(orig_loop_vars, shape_invariants, expand_composites=False) signature = nest.map_structure( control_flow_ops._shape_invariant_to_type_spec, loop_vars, list(shape_invariants), expand_composites=False) shape_invariants = nest.map_structure( control_flow_ops._get_shape_invariant, loop_vars, list(shape_invariants), expand_composites=False) else: signature = nest.map_structure( type_spec.type_spec_from_value, loop_vars, expand_composites=False) shape_invariants = nest.map_structure( control_flow_ops._get_shape_invariant, loop_vars, expand_composites=False) if not name: name = "while" with ops.name_scope(name) as scope: with ops.name_scope(None): cond_name = util.unique_fn_name(scope, "cond") body_name = util.unique_fn_name(scope, "body") maximum_iterations_loop_var = _build_maximum_iterations_loop_var( maximum_iterations) loop_counter = constant_op.constant( 0, dtype=maximum_iterations_loop_var.dtype if maximum_iterations is not None else None, name="loop_counter") # Add loop counter needed for computing gradients. loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars shape_invariants = [tensor_shape.TensorShape([])] * 2 + shape_invariants signature = ( [tensor_spec.TensorSpec.from_tensor(loop_counter), tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] + signature) # Automatic control dependencies are added in defuns, but not in v1 # graphs. Propagate that behavior here. add_control_dependencies = ops.get_default_graph()._add_control_dependencies def wrapped_cond(loop_counter, maximum_iterations_arg, *args): """Extra `cond` wrapper that can handle the extra counter loop_var.""" # Convert the flow variables in `args` to TensorArrays. `args` should # already have the same structure as `orig_loop_vars` but currently there # is no nest.zip so we call `_pack_sequence_as` which flattens both # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays # and packs it into the structure of `orig_loop_vars`. pred = cond(*_pack_sequence_as(orig_loop_vars, args)) if (tensor_util.is_tensor(pred) and (pred.shape.dims is None or pred.shape.dims)): pred = array_ops.squeeze_v2(pred) if maximum_iterations is None: return pred else: return math_ops.logical_and( loop_counter < maximum_iterations_arg, pred) # NOTE(skyewm): we set collections to the outer graph's collections for # compatibility with TPUEstimator. cond_graph = func_graph_module.func_graph_from_py_func( cond_name, wrapped_cond, [], # We provide signature instead of args. 
{}, signature=signature, func_graph=util.WhileCondFuncGraph( cond_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access add_control_dependencies=add_control_dependencies) def wrapped_body(loop_counter, maximum_iterations_arg, *args): """Loop body augmented with counter update. Args: loop_counter: Loop counter which needs to be incremented in the body. maximum_iterations_arg: Maximum iterations of the loop. *args: List of args Returns: A list of tensors the same length as args. """ # Capture the tensors already captured in cond_graph so that they appear # in the same order in body_graph.external_captures. for t in cond_graph.external_captures: ops.get_default_graph().capture(t) # Convert the flow variables in `args` to TensorArrays. `args` should # already have the same structure as `orig_loop_vars` but currently there # is no nest.zip so we call `_pack_sequence_as` which flattens both # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays # and packs it into the structure of `orig_loop_vars`. outputs = body(*_pack_sequence_as(orig_loop_vars, args)) if not nest.is_sequence_or_composite(outputs): outputs = [outputs] # Compare the structure of input and output of body converting the # top-level tuples to list to be compatible with legacy while_loop. nest.assert_same_structure(list(outputs), list(orig_loop_vars), expand_composites=True) outputs = _tensor_array_to_flow(outputs) # TODO(srbs): Update lowering code to create _Enter nodes with # is_constant=True for inputs that are directly passed to outputs. return [loop_counter + 1, maximum_iterations_arg] + list(outputs) body_graph = func_graph_module.func_graph_from_py_func( body_name, wrapped_body, [], # We provide signature instead of args. {}, signature=signature, func_graph=util.WhileBodyFuncGraph( body_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access add_control_dependencies=add_control_dependencies) # Add external captures of body to the list of loop vars. # Note that external tensors will be treated as loop invariants, i.e., # the value of that tensor in each iteration is the same as it was at the # beginning of the loop execution. loop_vars = loop_vars + body_graph.external_captures # TODO(srbs): Update lowering code to create _Enter nodes with # is_constant=True for inputs that are directly passed to outputs. body_graph.outputs.extend(body_graph.internal_captures) # Capture the extra `external_captures` of `body_graph` in `cond_graph` so # that it expects to receive those as arguments. with cond_graph.as_default(): num_cond_captures = len(cond_graph.external_captures) assert (cond_graph.external_captures == body_graph.external_captures[:num_cond_captures]) cond_graph_captures = object_identity.ObjectIdentitySet( cond_graph.external_captures) for body_capture in body_graph.external_captures[num_cond_captures:]: assert body_capture not in cond_graph_captures cond_graph.capture(body_capture) # Make sure that the shapes of the loop outputs are compatible with the # shape invariants, or the shapes of the loop vars if the invariants are not # specified. num_flattened_outputs = len(nest.flatten(orig_loop_vars, expand_composites=True)) # First var is loop counter and second var is maximum_iterations. 
first_loop_var_index = 2 _check_shapes_compat( body_graph.outputs[first_loop_var_index:first_loop_var_index + num_flattened_outputs], nest.flatten( shape_invariants[first_loop_var_index:first_loop_var_index + len_orig_loop_vars], expand_composites=True), nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index + len_orig_loop_vars], expand_composites=True)) num_original_outputs = len(body_graph.outputs) if back_prop and util.output_all_intermediates(): # Export all tensors in the loop body that may be needed for gradient # computation. We do this by accumulating the intermediate values in # TensorLists. intermediate_tensors = _get_intermediates(body_graph) for intermediate_tensor in intermediate_tensors: tensor_list = list_ops.empty_tensor_list( element_dtype=intermediate_tensor.dtype, element_shape=intermediate_tensor.shape, max_num_elements=maximum_iterations) loop_vars.append(tensor_list) with cond_graph.as_default(): # Add a placeholder to cond_graph's inputs corresponding to the # tensor_list. cond_graph.capture(tensor_list) with body_graph.as_default(): # Push the intermediate tensor to the tensor list. This captures the # `tensor_list` as well. appended_tensor_list = list_ops.tensor_list_push_back( tensor_list, intermediate_tensor) # Add this modified tensor list to the list of outputs. body_graph.outputs.append(appended_tensor_list) flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True) _check_num_inputs_outputs(cond_graph, body_graph, len(flattened_loop_vars)) _check_inputs_outputs_types_match(body_graph, flattened_loop_vars) with ops.control_dependencies( list(cond_graph.control_captures) + list(body_graph.control_captures)): output_shapes = [t.shape for t in body_graph.outputs] orig_loop_vars_range = slice(first_loop_var_index, first_loop_var_index + num_flattened_outputs) output_shapes[orig_loop_vars_range] = nest.flatten( shape_invariants, expand_composites=True)[orig_loop_vars_range] cond_stateful_ops = [ op for op in cond_graph.get_operations() if op._is_stateful ] body_stateful_ops = [ op for op in body_graph.get_operations() if op._is_stateful ] if (cond_stateful_ops or body_stateful_ops): op_fn = gen_functional_ops._while else: op_fn = gen_functional_ops.stateless_while outputs = op_fn( flattened_loop_vars, util.create_new_tf_function(cond_graph), util.create_new_tf_function(body_graph), output_shapes=output_shapes, parallel_iterations=parallel_iterations, name=scope) # This is needed so we do not compute derivative wrt these extra outputs. outputs[0].op._set_attr("_num_original_outputs", attr_value_pb2.AttrValue(i=num_original_outputs)) _copy_handle_data(body_graph.outputs, outputs) util.maybe_set_lowering_attr(outputs[0].op) util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op) # Return identities for each output of the While op, rather than the output # of the While op directly. This makes pruning work if the output of # while_loop() is fetched: the lowering pass converts the While outputs into # IdentityN outputs, which if fetched will cause all ops in the body to be # run (since it takes all exit ops as input). After lowering, each output # identity op will end up with only the appropriate exit op as input. 
outputs = tuple(array_ops.identity(t) for t in outputs) outputs = _pack_sequence_as( orig_loop_vars, outputs[first_loop_var_index:first_loop_var_index + num_flattened_outputs]) if return_same_structure: return outputs flattened_outputs = nest.flatten(outputs, expand_composites=True) if len(flattened_outputs) == 1: return flattened_outputs[0] else: return outputs
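# Minimal usage sketch for the while_loop defined above. It mirrors
# tf.while_loop but always emits a single While (or StatelessWhile) op; the
# numbers are illustrative only.
def _example_while_loop():
  import tensorflow as tf  # assumption: public API available to the caller

  i_final, x_final = while_loop(
      cond=lambda i, x: x < 100.0,
      body=lambda i, x: (i + 1, x * 2.0),
      loop_vars=(tf.constant(0), tf.constant(1.0)),
      maximum_iterations=20,
      return_same_structure=True)
  return i_final, x_final  # 7 iterations; x_final evaluates to 128.0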
def while_loop(cond, body, loop_vars, shape_invariants=None, parallel_iterations=10, maximum_iterations=None, name=None, return_same_structure=True): """Like tf.while_loop, except emits a single While op.""" maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations) # Keep the original loop_vars around to know which args were TensorArrays. orig_loop_vars = loop_vars # Cache its length since we use it at multiple places below. len_orig_loop_vars = len(orig_loop_vars) # Convert TensorArrays to their flow variables. These get converted back to # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and # `wrapped_body` below. loop_vars = list(_tensor_array_to_flow(orig_loop_vars)) loop_vars = nest.map_structure( ops.internal_convert_to_tensor_or_indexed_slices, loop_vars) if shape_invariants is not None: nest.assert_same_structure(orig_loop_vars, shape_invariants) else: shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars) if not name: name = "while" with ops.name_scope(name) as scope: with ops.name_scope(None): cond_name = util.unique_fn_name(scope, "cond") body_name = util.unique_fn_name(scope, "body") loop_counter = constant_op.constant( 0, dtype=maximum_iterations.dtype if maximum_iterations is not None else None, name="loop_counter") # Add loop counter needed for computing gradients. loop_vars = [loop_counter] + loop_vars shape_invariants = type(shape_invariants)([tensor_shape.scalar() ]) + shape_invariants # Automatic control dependencies are added in defuns, but not in v1 # graphs. Propagate that behavior here. add_control_dependencies = ops.get_default_graph()._add_control_dependencies # Build a `cond` wrapper that can handle the extra counter loop_var. def wrapped_cond(loop_counter, *args): # Convert the flow variables in `args` to TensorArrays. `args` should # already have the same structure as `orig_loop_vars` but currently there # is no nest.zip so we call `_pack_sequence_as` which flattens both # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays # and packs it into the structure of `orig_loop_vars`. if maximum_iterations is None: return cond(*_pack_sequence_as(orig_loop_vars, args)) else: return math_ops.logical_and( loop_counter < maximum_iterations, cond(*_pack_sequence_as(orig_loop_vars, args))) # NOTE(skyewm): we set collections to the outer graph's collections for # compatibility with TPUEstimator. cond_graph = func_graph_module.func_graph_from_py_func( cond_name, wrapped_cond, [], # We provide signature instead of args. {}, signature=_build_signature(loop_vars, shape_invariants), func_graph=util.WhileCondFuncGraph( cond_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access add_control_dependencies=add_control_dependencies) def wrapped_body(loop_counter, *args): """Loop body augmented with counter update. Args: loop_counter: Loop counter which needs to be incremented in the body. *args: List of args Returns: A list of tensors the same length as args. """ # Capture the tensors already captured in cond_graph so that they appear # in the same order in body_graph.external_captures. for t in cond_graph.external_captures: ops.get_default_graph().capture(t) # Convert the flow variables in `args` to TensorArrays. 
`args` should # already have the same structure as `orig_loop_vars` but currently there # is no nest.zip so we call `_pack_sequence_as` which flattens both # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays # and packs it into the structure of `orig_loop_vars`. outputs = body(*_pack_sequence_as(orig_loop_vars, args)) if not nest.is_sequence(outputs): outputs = [outputs] # Compare the structure of input and output of body converting the # top-level tuples to list to be compatible with legacy while_loop. nest.assert_same_structure(list(outputs), list(orig_loop_vars)) outputs = _tensor_array_to_flow(outputs) # TODO(srbs): Update lowering code to create _Enter nodes with # is_constant=True for inputs that are directly passed to outputs. return [loop_counter + 1] + list(outputs) body_graph = func_graph_module.func_graph_from_py_func( body_name, wrapped_body, [], # We provide signature instead of args. {}, signature=_build_signature(loop_vars, shape_invariants), func_graph=util.WhileBodyFuncGraph( body_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access add_control_dependencies=add_control_dependencies) # Add external captures of body to the list of loop vars. # Note that external tensors will be treated as loop invariants, i.e., # the value of that tensor in each iteration is the same as it was at the # beginning of the loop execution. loop_vars = loop_vars + body_graph.external_captures # TODO(srbs): Update lowering code to create _Enter nodes with # is_constant=True for inputs that are directly passed to outputs. body_graph.outputs.extend(body_graph.internal_captures) # Capture the extra `external_captures` of `body_graph` in `cond_graph` so # that it expects to receive those as arguments. with cond_graph.as_default(): num_cond_captures = len(cond_graph.external_captures) assert (cond_graph.external_captures == body_graph.external_captures[:num_cond_captures]) for body_capture in body_graph.external_captures[num_cond_captures:]: assert body_capture not in cond_graph.captures cond_graph.capture(body_capture) # Make sure that the shapes of the loop outputs are compatible with the # shape invariants, or the shapes of the loop vars if the invariants are not # specified. num_flattened_outputs = len(nest.flatten(orig_loop_vars)) _check_shapes_compat( body_graph.outputs[1:1 + num_flattened_outputs], nest.flatten(shape_invariants[1:1 + len_orig_loop_vars]), nest.flatten(loop_vars[1:1 + len_orig_loop_vars])) flattened_loop_vars = nest.flatten(loop_vars) _check_num_inputs_outputs(cond_graph, body_graph, len(flattened_loop_vars)) outputs = gen_functional_ops._while( flattened_loop_vars, util.create_new_tf_function(cond_graph), util.create_new_tf_function(body_graph), output_shapes=[t.shape for t in body_graph.outputs], parallel_iterations=parallel_iterations, name=scope) _copy_handle_data(body_graph.outputs, outputs) util.maybe_set_lowering_attr(outputs[0].op) _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations) # Return identities for each output of the While op, rather than the output # of the While op directly. This makes pruning work if the output of # while_loop() is fetched: the lowering pass converts the While outputs into # IdentityN outputs, which if fetched will cause all ops in the body to be # run (since it takes all exit ops as input). After lowering, each output # identity op will end up with only the appropriate exit op as input. 
outputs = tuple(array_ops.identity(t) for t in outputs) # First var is loop counter. outputs = _pack_sequence_as(orig_loop_vars, outputs[1:1 + num_flattened_outputs]) if return_same_structure: return outputs flattened_outputs = nest.flatten(outputs) if len(flattened_outputs) == 1: return flattened_outputs[0] else: return outputs
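# Hedged sketch of the shape_invariants parameter accepted by the while_loop
# above: a loop var whose leading dimension grows across iterations needs a
# relaxed invariant, just as with tf.while_loop. Values are illustrative only.
def _example_growing_loop_var():
  import tensorflow as tf  # assumption: public API available to the caller

  return while_loop(
      cond=lambda i, values: i < 4,
      body=lambda i, values: (i + 1, tf.concat([values, values], axis=0)),
      loop_vars=(tf.constant(0), tf.zeros([1])),
      shape_invariants=(tf.TensorShape([]), tf.TensorShape([None])),
      return_same_structure=True)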
def while_loop(cond, body, loop_vars, shape_invariants=None, maximum_iterations=None, name=None, return_same_structure=True): """Like tf.while_loop, except emits a single While op.""" maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations) # Keep the original loop_vars around to know which args were TensorArrays. orig_loop_vars = loop_vars # Cache its length since we use it at multiple places below. len_orig_loop_vars = len(orig_loop_vars) # Convert TensorArrays to their flow variables. These get converted back to # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and # `wrapped_body` below. loop_vars = list(_tensor_array_to_flow(orig_loop_vars)) loop_vars = nest.map_structure( ops.internal_convert_to_tensor_or_indexed_slices, loop_vars) if shape_invariants is not None: nest.assert_same_structure(orig_loop_vars, shape_invariants) else: shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars) if not name: name = "while" with ops.name_scope(name) as scope: with ops.name_scope(None): cond_name = util.unique_fn_name(scope, "cond") body_name = util.unique_fn_name(scope, "body") loop_counter = constant_op.constant( 0, dtype=maximum_iterations.dtype if maximum_iterations is not None else None, name="loop_counter") # Add loop counter needed for computing gradients. loop_vars = [loop_counter] + loop_vars shape_invariants = type(shape_invariants)([tensor_shape.scalar() ]) + shape_invariants # Automatic control dependencies are added in defuns, but not in v1 # graphs. Propagate that behavior here. add_control_dependencies = util.in_defun() # Build a `cond` wrapper that can handle the extra counter loop_var. def wrapped_cond(loop_counter, *args): # Convert the flow variables in `args` to TensorArrays. `args` should # already have the same structure as `orig_loop_vars` but currently there # is no nest.zip so we call `_pack_sequence_as` which flattens both # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays # and packs it into the structure of `orig_loop_vars`. if maximum_iterations is None: return cond(*_pack_sequence_as(orig_loop_vars, args)) else: return math_ops.logical_and( loop_counter < maximum_iterations, cond(*_pack_sequence_as(orig_loop_vars, args))) cond_graph = func_graph_module.func_graph_from_py_func( cond_name, wrapped_cond, loop_vars, {}, signature=_build_signature(loop_vars, shape_invariants), func_graph=util.WhileCondFuncGraph(cond_name), add_control_dependencies=add_control_dependencies) # Add external_captures of cond to the list of loop vars. # Note that external tensors will be treated as loop invariants, i.e., # the value of that tensor in each iteration is the same as it was at the # beginning of the loop execution. loop_vars = loop_vars + cond_graph.external_captures shape_invariants = shape_invariants + type(shape_invariants)( [t.shape for t in cond_graph.external_captures]) def wrapped_body(loop_counter, *args): """Loop body augmented with counter update. Args: loop_counter: Loop counter which needs to be incremented in the body. *args: List of args args[:len_orig_loop_vars] - Args for the original loop body. args[len_orig_loop_vars:] - External captures of cond. These get passed through as is. Returns: A list of tensors the same length as args. """ # Convert the flow variables in `args` to TensorArrays. 
`args` should # already have the same structure as `orig_loop_vars` but currently there # is no nest.zip so we call `_pack_sequence_as` which flattens both # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays # and packs it into the structure of `orig_loop_vars`. outputs = body( *_pack_sequence_as(orig_loop_vars, args[:len_orig_loop_vars])) if not nest.is_sequence(outputs): outputs = [outputs] # Compare the structure of input and output of body converting the # top-level tuples to list to be compatible with legacy while_loop. nest.assert_same_structure(list(outputs), list(orig_loop_vars)) outputs = _tensor_array_to_flow(outputs) # Return the external_captures of cond_graph as is, i.e., treat them as # loop invariants. # TODO(srbs): Update lowering code to create _Enter nodes with # is_constant=True for inputs that are directly passed to outputs. return [loop_counter + 1] + list(outputs) + list( args[len_orig_loop_vars:]) body_graph = func_graph_module.func_graph_from_py_func( body_name, wrapped_body, loop_vars, {}, signature=_build_signature(loop_vars, shape_invariants), func_graph=util.WhileBodyFuncGraph(body_name), add_control_dependencies=add_control_dependencies) # Add external captures of body to the list of loop vars. # Note that external tensors will be treated as loop invariants, i.e., # the value of that tensor in each iteration is the same as it was at the # beginning of the loop execution. loop_vars = loop_vars + body_graph.external_captures # TODO(srbs): Update lowering code to create _Enter nodes with # is_constant=True for inputs that are directly passed to outputs. body_graph.outputs.extend(body_graph.internal_captures) # Capture `external_captures` of `body_graph` in `cond_graph` so that it # expects to receive those as arguments. # TODO(b/118457764): Dedup tensors that are captured in both the cond and # body. This logic already exists in cond_v2. with cond_graph.as_default(): for external_capture in body_graph.external_captures: assert external_capture not in cond_graph.captures, ( "Looks like both cond and body are capturing the same tensor %s. " "This is not supported yet. For now consider passing," " this as a loop variable." % str(external_capture)) cond_graph.capture(external_capture) # Export all tensors in the loop body that may be needed for gradient # computation. We do this by accumulating the intermediate values in # TensorLists. intermediate_tensors = _get_intermediates(body_graph) for intermediate_tensor in intermediate_tensors: tensor_list = list_ops.empty_tensor_list( element_dtype=intermediate_tensor.dtype, element_shape=intermediate_tensor.shape, max_num_elements=maximum_iterations) loop_vars.append(tensor_list) with cond_graph.as_default(): # Add a placeholder to cond_graph's inputs corresponding to the # tensor_list. cond_graph.capture(tensor_list) with body_graph.as_default(): # Push the intermediate tensor to the tensor list. This captures the # `tensor_list` as well. appended_tensor_list = list_ops.tensor_list_push_back( tensor_list, intermediate_tensor) # Add this modified tensor list to the list of outputs. body_graph.outputs.append(appended_tensor_list) # Make sure that the shapes of the loop outputs are compatible with the # shape invariants, or the shapes of the loop vars if the invariants are not # specified. 
num_flattened_outputs = len(nest.flatten(orig_loop_vars)) _check_shapes_compat( body_graph.outputs[1:1 + num_flattened_outputs], nest.flatten(shape_invariants[1:1 + len_orig_loop_vars]), nest.flatten(loop_vars[1:1 + len_orig_loop_vars])) flattened_loop_vars = nest.flatten(loop_vars) _check_num_inputs_outputs(cond_graph, body_graph, len(flattened_loop_vars)) outputs = gen_functional_ops._while( flattened_loop_vars, util.create_new_tf_function(cond_graph), util.create_new_tf_function(body_graph), output_shapes=[t.shape for t in body_graph.outputs], name=scope) _copy_handle_data(body_graph.outputs, outputs) util.maybe_set_lowering_attr(outputs[0].op) _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations) # Return identities for each output of the While op, rather than the output # of the While op directly. This makes pruning work if the output of # while_loop() is fetched: the lowering pass converts the While outputs into # IdentityN outputs, which if fetched will cause all ops in the body to be # run (since it takes all exit ops as input). After lowering, each output # identity op will end up with only the appropriate exit op as input. outputs = tuple(array_ops.identity(t) for t in outputs) # First var is loop counter. outputs = _pack_sequence_as(orig_loop_vars, outputs[1:1 + num_flattened_outputs]) if return_same_structure: return outputs flattened_outputs = nest.flatten(outputs) if len(flattened_outputs) == 1: return flattened_outputs[0] else: return outputs
def _WhileGrad(op, *grads): # pylint: disable=invalid-name """The gradient of a While op produced by while_loop.""" body_graph = _get_body_graph(op) # Set the incoming gradient of TensorArray handles to None. The gradient # implementation currently assumes all resource tensors correspond to float32 # ResourceVariables, which can lead to runtime shape errors when used with a # TensorArray. This is a workaround until TensorArrays are reimplemented with # TensorLists instead of resources. # Also set the incoming gradient of non-trainable inputs to None. It is # possible that we receive non-None gradients for non-trainable types in # nested while loops because we accumulate outputs of the inner while as # variant tensors which are trainable and hence receive zeros_like tensors in # the gradient pass. The non-trainable tensors then receive the popped zeros # tensor from this zeros variant. The gradient for the loop vars corresponding # to these tensors is None or zeros (this happens only if the loop var is # accumulated as well) in _grad_fn so we reset these. # TODO(b/118712257): Remove the IsTrainable filter once we can handle None # output grads in _grad_fn. grads = [ None if _is_tensor_array_handle(output) or not gradients_impl.IsTrainable(output) else grad for grad, output in zip(grads, op.outputs) ] # Ensure that all non-resource trainable outputs have incoming gradients. assert all(g is not None or o.dtype == dtypes.resource or not gradients_impl.IsTrainable(o) for o, g in zip(op.outputs, grads) ), "All trainable loop vars must receive incoming gradients." # We compute the gradient for the sub-graph between trainable ys and xs # with non-None incoming gradients. We later pad the None's to the list of # outputs. ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip( body_graph.outputs, body_graph.inputs, grads) if grad is not None]) body_grad_graph, args = _create_grad_func( ys, xs, non_none_grads, body_graph, util.unique_grad_fn_name(body_graph.name), op) intermediate_tensors = _get_intermediates(body_grad_graph) maximum_iterations = op.get_attr( "_maximum_iterations") if _is_in_xla_context() else None assert not _is_in_xla_context() or maximum_iterations is not None for intermediate_tensor in intermediate_tensors: tensor_list = list_ops.empty_tensor_list( element_dtype=intermediate_tensor.dtype, element_shape=intermediate_tensor.shape, max_num_elements=maximum_iterations) with body_grad_graph.as_default(): tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True) # Push the intermediate tensor to the tensor list. appended_tensor_list = list_ops.tensor_list_push_back(tensor_list_ph, intermediate_tensor) # Add this modified tensor list to the list of outputs. 
body_grad_graph.outputs.append(appended_tensor_list) def grad_cond(counter, max_iters, *unused_args): return counter < max_iters loop_vars = args + body_grad_graph.external_captures grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name) cond_grad_graph = func_graph_module.func_graph_from_py_func( grad_cond_name, grad_cond, loop_vars, {}, func_graph=util.WhileCondFuncGraph(grad_cond_name)) _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars)) outputs = gen_functional_ops._while( loop_vars, util.create_new_tf_function(cond_grad_graph), util.create_new_tf_function(body_grad_graph), output_shapes=[t.shape for t in body_grad_graph.outputs], name="%s_grad" % op.name) _copy_handle_data(body_grad_graph.outputs, outputs) util.maybe_set_lowering_attr(outputs[0].op) _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations) # See comment in while_loop. outputs = [array_ops.identity(t) for t in outputs] # Set None as the output gradient for tensors with None input gradient # e.g. TensorArray handles. # outputs[0] is the loop counter. # outputs[1] is the total number of loop iterations. index = 2 none_padded_outputs = [] for g in grads: if g is None: none_padded_outputs.append(None) else: none_padded_outputs.append(outputs[index]) index += 1 return none_padded_outputs
def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs, name=None): """Creates an If op from the specified predicate, branch functions and inputs. Note that this modifies true_graph and false_graph to make the inputs match, and to output all intermediates values so they're available for the gradient computation. true_graph and false_graph need not have the same input types, but they must have the same outpute types. Args: pred: boolean Tensor true_graph: FuncGraph false_graph: FuncGraph true_inputs: a list of Tensors to be passed to true_graph as input. false_inputs: a list of Tensors to be passed to false_graph as input. name: the name for the If op. Returns: A list of Tensors which are the outputs of the If op. Does not include added intermediate outputs. """ _check_same_outputs(true_graph, false_graph) # Add inputs to true_graph and false_graph to make them match. Note that # this modifies true_graph and false_graph. cond_inputs = _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs) # Add all intermediate tensors as function outputs so they're available for # the gradient computation. Since the outputs of the two functions must match, # we wrap all the intermediates in optionals. Each intermediate output will # have a value iff its corresponding branch is taken. true_intermediates = _get_intermediates(true_graph) false_intermediates = _get_intermediates(false_graph) # Save the original number of outputs to return to the caller. num_cond_outputs = len(true_graph.outputs) if control_flow_util.InXlaContext(ops.get_default_graph()): # XLA does not yet support optionals, so output intermediates directly and # make them match via FakeParams, which can be converted to zeros in XLA. # TODO(skyewm,jpienaar): can XLA support optionals? extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla( true_graph, false_graph, true_intermediates, false_intermediates) else: # Wrap intermediates in optionals. wrapped_true_intermediates = _wrap_intermediates(true_graph, true_intermediates) wrapped_false_intermediates = _wrap_intermediates(false_graph, false_intermediates) # Make outputs match by adding none optionals. extra_true_outputs, extra_false_outputs = _make_intermediates_match( true_graph, false_graph, wrapped_true_intermediates, wrapped_false_intermediates) true_graph.outputs.extend(extra_true_outputs) false_graph.outputs.extend(extra_false_outputs) # TODO(skyewm): somehow indicate it's a bug if this fails. _check_same_outputs(true_graph, false_graph) # Create the If op. tensors = gen_functional_ops._if( # pylint: disable=protected-access pred, cond_inputs, [t.dtype for t in true_graph.outputs], util.create_new_tf_function(true_graph), util.create_new_tf_function(false_graph), output_shapes=_get_output_shapes(true_graph.outputs, false_graph.outputs), name=name) # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output if_op = tensors[0].op util.maybe_set_lowering_attr(if_op) # Return identities for each output of the If op, rather than the output of # the If op directly. This makes pruning work if the output of cond() is # fetched: the lowering pass converts the If outputs into IdentityN outputs, # which if fetched will cause all ops in the taken branch to be run (since # it takes all merge ops as input). After lowering, each output identity op # will end up with only the appropriate merge op as input. 
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure.
  tensors = [array_ops.identity(t) for t in tensors]
  # Prevent fetching since the variant outputs can't be fetched directly.
  if_op.graph.prevent_fetching(if_op)

  return tensors[:num_cond_outputs]
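
# Illustrative usage sketch (not part of this module), assuming TensorFlow
# 2.x: the _check_same_outputs requirement above surfaces in the public API
# as tf.cond demanding that both branches return the same number and dtypes
# of outputs.
import tensorflow as tf

x = tf.constant(3.0)
y = tf.constant(4.0)
s, p = tf.cond(x < y,
               true_fn=lambda: (x + y, x * y),
               false_fn=lambda: (x - y, x / y))
print(s.numpy(), p.numpy())  # 7.0 12.0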
def while_loop(cond, body, loop_vars, shape_invariants=None, maximum_iterations=None, name=None, return_same_structure=True): """Like tf.while_loop, except emits a single While op.""" maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations) # Keep the original loop_vars around to know which args were TensorArrays. orig_loop_vars = loop_vars # Cache its length since we use it at multiple places below. len_orig_loop_vars = len(orig_loop_vars) # Convert TensorArrays to their flow variables. These get converted back to # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and # `wrapped_body` below. loop_vars = list(_tensor_array_to_flow(orig_loop_vars)) loop_vars = nest.map_structure( ops.internal_convert_to_tensor_or_indexed_slices, loop_vars) if shape_invariants is not None: nest.assert_same_structure(orig_loop_vars, shape_invariants) else: shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars) if not name: name = "while" with ops.name_scope(name) as scope: with ops.name_scope(None): cond_name = util.unique_fn_name(scope, "cond") body_name = util.unique_fn_name(scope, "body") loop_counter = constant_op.constant( 0, dtype=maximum_iterations.dtype if maximum_iterations is not None else None, name="loop_counter") # Add loop counter needed for computing gradients. loop_vars = [loop_counter] + loop_vars shape_invariants = type(shape_invariants)([tensor_shape.scalar() ]) + shape_invariants # Automatic control dependencies are added in defuns, but not in v1 # graphs. Propagate that behavior here. add_control_dependencies = ops.get_default_graph( )._add_control_dependencies # Build a `cond` wrapper that can handle the extra counter loop_var. def wrapped_cond(loop_counter, *args): # Convert the flow variables in `args` to TensorArrays. `args` should # already have the same structure as `orig_loop_vars` but currently there # is no nest.zip so we call `_pack_sequence_as` which flattens both # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays # and packs it into the structure of `orig_loop_vars`. if maximum_iterations is None: return cond(*_pack_sequence_as(orig_loop_vars, args)) else: return math_ops.logical_and( loop_counter < maximum_iterations, cond(*_pack_sequence_as(orig_loop_vars, args))) cond_graph = func_graph_module.func_graph_from_py_func( cond_name, wrapped_cond, loop_vars, {}, signature=_build_signature(loop_vars, shape_invariants), func_graph=util.WhileCondFuncGraph(cond_name), add_control_dependencies=add_control_dependencies) # Add external_captures of cond to the list of loop vars. # Note that external tensors will be treated as loop invariants, i.e., # the value of that tensor in each iteration is the same as it was at the # beginning of the loop execution. loop_vars = loop_vars + cond_graph.external_captures shape_invariants = shape_invariants + type(shape_invariants)( [t.shape for t in cond_graph.external_captures]) def wrapped_body(loop_counter, *args): """Loop body augmented with counter update. Args: loop_counter: Loop counter which needs to be incremented in the body. *args: List of args args[:len_orig_loop_vars] - Args for the original loop body. args[len_orig_loop_vars:] - External captures of cond. These get passed through as is. Returns: A list of tensors the same length as args. """ # Convert the flow variables in `args` to TensorArrays. 
`args` should # already have the same structure as `orig_loop_vars` but currently there # is no nest.zip so we call `_pack_sequence_as` which flattens both # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays # and packs it into the structure of `orig_loop_vars`. outputs = body( *_pack_sequence_as(orig_loop_vars, args[:len_orig_loop_vars])) if not nest.is_sequence(outputs): outputs = [outputs] # Compare the structure of input and output of body converting the # top-level tuples to list to be compatible with legacy while_loop. nest.assert_same_structure(list(outputs), list(orig_loop_vars)) outputs = _tensor_array_to_flow(outputs) # Return the external_captures of cond_graph as is, i.e., treat them as # loop invariants. # TODO(srbs): Update lowering code to create _Enter nodes with # is_constant=True for inputs that are directly passed to outputs. return [loop_counter + 1] + list(outputs) + list( args[len_orig_loop_vars:]) body_graph = func_graph_module.func_graph_from_py_func( body_name, wrapped_body, loop_vars, {}, signature=_build_signature(loop_vars, shape_invariants), func_graph=util.WhileBodyFuncGraph(body_name), add_control_dependencies=add_control_dependencies) # Add external captures of body to the list of loop vars. # Note that external tensors will be treated as loop invariants, i.e., # the value of that tensor in each iteration is the same as it was at the # beginning of the loop execution. loop_vars = loop_vars + body_graph.external_captures # TODO(srbs): Update lowering code to create _Enter nodes with # is_constant=True for inputs that are directly passed to outputs. body_graph.outputs.extend(body_graph.internal_captures) # Capture `external_captures` of `body_graph` in `cond_graph` so that it # expects to receive those as arguments. # TODO(b/118457764): Dedup tensors that are captured in both the cond and # body. This logic already exists in cond_v2. with cond_graph.as_default(): for external_capture in body_graph.external_captures: assert external_capture not in cond_graph.captures, ( "Looks like both cond and body are capturing the same tensor %s. " "This is not supported yet. For now consider passing," " this as a loop variable." % str(external_capture)) cond_graph.capture(external_capture) # Make sure that the shapes of the loop outputs are compatible with the # shape invariants, or the shapes of the loop vars if the invariants are not # specified. num_flattened_outputs = len(nest.flatten(orig_loop_vars)) _check_shapes_compat( body_graph.outputs[1:1 + num_flattened_outputs], nest.flatten(shape_invariants[1:1 + len_orig_loop_vars]), nest.flatten(loop_vars[1:1 + len_orig_loop_vars])) flattened_loop_vars = nest.flatten(loop_vars) _check_num_inputs_outputs(cond_graph, body_graph, len(flattened_loop_vars)) outputs = gen_functional_ops._while( flattened_loop_vars, util.create_new_tf_function(cond_graph), util.create_new_tf_function(body_graph), output_shapes=[t.shape for t in body_graph.outputs], name=scope) _copy_handle_data(body_graph.outputs, outputs) util.maybe_set_lowering_attr(outputs[0].op) _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations) # Return identities for each output of the While op, rather than the output # of the While op directly. This makes pruning work if the output of # while_loop() is fetched: the lowering pass converts the While outputs into # IdentityN outputs, which if fetched will cause all ops in the body to be # run (since it takes all exit ops as input). 
    # After lowering, each output identity op will end up with only the
    # appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

    # First var is loop counter.
    outputs = _pack_sequence_as(orig_loop_vars,
                                outputs[1:1 + num_flattened_outputs])

    if return_same_structure:
      return outputs

    flattened_outputs = nest.flatten(outputs)
    if len(flattened_outputs) == 1:
      return flattened_outputs[0]
    else:
      return outputs
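
# Illustrative usage sketch (not part of this module), assuming TensorFlow
# 2.x: the shape_invariants plumbing above is what lets a loop variable
# change shape between iterations, e.g. a tensor that grows by one row per
# step.
import tensorflow as tf

rows = tf.zeros([1, 2])
_, rows = tf.while_loop(
    cond=lambda i, r: i < 4,
    body=lambda i, r: (i + 1, tf.concat([r, tf.ones([1, 2])], axis=0)),
    loop_vars=(tf.constant(0), rows),
    shape_invariants=(tf.TensorShape([]), tf.TensorShape([None, 2])))
print(rows.shape)  # (5, 2)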
def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs, name=None): """Creates an If op from the specified predicate, branch functions and inputs. Note that this modifies true_graph and false_graph to make the inputs match, and to output all intermediates values so they're available for the gradient computation. true_graph and false_graph need not have the same input types, but they must have the same outpute types. Args: pred: boolean Tensor true_graph: FuncGraph false_graph: FuncGraph true_inputs: a list of Tensors to be passed to true_graph as input. false_inputs: a list of Tensors to be passed to false_graph as input. name: the name for the If op. Returns: A list of Tensors which are the outputs of the If op. Does not include added intermediate outputs. """ _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph]) _check_same_outputs(_COND, [true_graph, false_graph]) # Add inputs to true_graph and false_graph to make them match. Note that # this modifies true_graph and false_graph. cond_inputs = _make_inputs_match([true_graph, false_graph], [true_inputs, false_inputs]) # Create the If op. with ops.control_dependencies( list(true_graph.control_captures) + list(false_graph.control_captures)): true_stateful_ops = [ op for op in true_graph.get_operations() if op._is_stateful ] false_stateful_ops = [ op for op in false_graph.get_operations() if op._is_stateful ] # TODO(srbs): Remove this after July 22, 2019. This is required to abide by # 3-week forward compat window of new TF python op generating code with # stale runtime binaries. if (true_stateful_ops or false_stateful_ops or not compat.forward_compatible(2019, 7, 22)): op_fn = gen_functional_ops._if else: op_fn = gen_functional_ops.stateless_if tensors = op_fn(pred, cond_inputs, [t.dtype for t in true_graph.outputs], util.create_new_tf_function(true_graph), util.create_new_tf_function(false_graph), output_shapes=_get_output_shapes( true_graph.outputs, false_graph.outputs), name=name) # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output if_op = tensors[0].op util.maybe_set_lowering_attr(if_op) util.maybe_propagate_compile_time_consts_in_xla(if_op) # Return identities for each output of the If op, rather than the output of # the If op directly. This makes pruning work if the output of cond() is # fetched: the lowering pass converts the If outputs into IdentityN outputs, # which if fetched will cause all ops in the taken branch to be run (since # it takes all merge ops as input). After lowering, each output identity op # will end up with only the appropriate merge op as input. # TODO(b/79984175): this doesn't have to be a tuple once we covert to the # correct output structure tensors = [array_ops.identity(t) for t in tensors] # Prevent fetching since the variant outputs can't be fetched directly. if_op.graph.prevent_fetching(if_op) return func_graph_module.pack_sequence_as(true_graph.structured_outputs, tensors)
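
# Illustrative sketch (not part of this module), assuming TensorFlow 2.x: one
# way to observe the stateful/stateless op selection made above is to trace a
# tf.function and list the op types in its graph. The emitted op names
# ("StatelessIf" vs. "If") are an assumption about the generated graph, not
# something this snippet guarantees.
import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def stateless_branches(x):
  return tf.cond(x > 0., lambda: x + 1., lambda: x - 1.)

@tf.function
def stateful_branches(x):
  # Assigning to a variable makes the taken branch stateful.
  return tf.cond(x > 0., lambda: v.assign_add(x), lambda: v.assign_sub(x))

for fn in (stateless_branches, stateful_branches):
  graph = fn.get_concrete_function(tf.constant(1.)).graph
  print([op.type for op in graph.get_operations()
         if op.type in ("If", "StatelessIf")])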
def _WhileGrad(op, *grads): # pylint: disable=invalid-name """The gradient of a While op produced by while_loop.""" # Note that op is not always the same as while_op because the gradient tape, # for eager mode compatibility, forgets information about the proper op. Since # the loop cannot run in eager mode, however, we can safely introspect into # the graph here. while_op = op.outputs[0].op cond_graph = _get_graph(while_op, "cond") body_graph = _get_graph(while_op, "body") orig_num_params = len(body_graph.outputs) maximum_iterations = op.get_attr( "_maximum_iterations") if _is_in_xla_context() else None parallel_iterations = op.get_attr("parallel_iterations") assert not _is_in_xla_context() or maximum_iterations is not None maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations) grads = [_preprocess_grad(grad, body_out, while_out) for grad, body_out, while_out in zip(grads, body_graph.outputs, while_op.outputs)] # We compute the gradient for the sub-graph between trainable ys and xs # with non-None incoming gradients. We later pad the None's to the list of # outputs. ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip( body_graph.outputs, body_graph.inputs, grads) if grad is not None]) body_grad_graph, args = _create_grad_func( ys, xs, non_none_grads, cond_graph, body_graph, util.unique_grad_fn_name(body_graph.name), op, maximum_iterations) if body_grad_graph.while_op_needs_rewrite: # Modify 'op' to output the intermediate accumulators needed by the grad # function. # NOTE(skyewm): if there are any active sessions, this modification to `op` # may make them unrunnable! cond_graph.name += "_rewritten" body_graph.name += "_rewritten" new_inputs = body_grad_graph.empty_tensor_lists new_outputs = body_graph.outputs[orig_num_params:] while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph)) while_op._set_func_attr("body", util.create_new_tf_function(body_graph)) while_op._set_type_list_attr("T", body_graph.output_types) while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes) while_op._add_while_inputs(new_inputs) while_op._add_outputs([t.dtype for t in new_outputs], [t.shape for t in new_outputs]) _copy_handle_data(new_outputs, op.outputs[orig_num_params:]) captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph, while_op) loop_vars = args + captured_inputs # This modifies body_grad_graph. loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices( grads, body_grad_graph, loop_vars, while_op.inputs) def grad_cond(counter, max_iters, *unused_args): return counter < max_iters grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name) cond_grad_graph = func_graph_module.func_graph_from_py_func( grad_cond_name, grad_cond, loop_vars, {}, func_graph=util.WhileCondFuncGraph(grad_cond_name)) _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars)) outputs = gen_functional_ops._while( loop_vars, util.create_new_tf_function(cond_grad_graph), util.create_new_tf_function(body_grad_graph), output_shapes=[t.shape for t in body_grad_graph.outputs], parallel_iterations=parallel_iterations, name="%s_grad" % while_op.name) grad_op = outputs[0].op _copy_handle_data(body_grad_graph.outputs, outputs) util.maybe_set_lowering_attr(grad_op) _maybe_set_maximum_iterations_attr(grad_op, maximum_iterations) # See comment in while_loop. outputs = [array_ops.identity(t) for t in outputs] return _get_structured_grad_output(outputs, grads, body_grad_graph)
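
# Illustrative sketch (not part of this module), assuming TensorFlow 2.x:
# tensors captured from outside the loop body are treated as loop invariants,
# and the While gradient built above still routes gradients back to them.
import tensorflow as tf

w = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(w)
  # `w` is captured by the body lambda instead of being passed as a loop var.
  _, y = tf.while_loop(
      cond=lambda i, acc: i < 4,
      body=lambda i, acc: (i + 1, acc + w),
      loop_vars=(tf.constant(0), tf.constant(0.0)))
print(tape.gradient(y, w).numpy())  # 4.0: w is added once per iteration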
def while_loop(cond, body, loop_vars, shape_invariants=None, parallel_iterations=10, maximum_iterations=None, name=None, return_same_structure=True): """Like tf.while_loop, except emits a single While op.""" # Keep the original loop_vars around to know which args were TensorArrays. orig_loop_vars = loop_vars # Cache its length since we use it at multiple places below. len_orig_loop_vars = len(orig_loop_vars) # Convert TensorArrays to their flow variables. These get converted back to # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and # `wrapped_body` below. loop_vars = list(_tensor_array_to_flow(orig_loop_vars)) loop_vars = nest.map_structure( ops.internal_convert_to_tensor_or_indexed_slices, loop_vars) if shape_invariants is not None: nest.assert_same_structure(orig_loop_vars, shape_invariants) else: shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars) if not name: name = "while" with ops.name_scope(name) as scope: with ops.name_scope(None): cond_name = util.unique_fn_name(scope, "cond") body_name = util.unique_fn_name(scope, "body") maximum_iterations_loop_var = _build_maximum_iterations_loop_var( maximum_iterations) loop_counter = constant_op.constant( 0, dtype=maximum_iterations_loop_var.dtype if maximum_iterations is not None else None, name="loop_counter") # Add loop counter needed for computing gradients. loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars shape_invariants = type(shape_invariants)( [tensor_shape.scalar(), tensor_shape.scalar()]) + shape_invariants # Automatic control dependencies are added in defuns, but not in v1 # graphs. Propagate that behavior here. add_control_dependencies = ops.get_default_graph( )._add_control_dependencies # Build a `cond` wrapper that can handle the extra counter loop_var. def wrapped_cond(loop_counter, maximum_iterations_arg, *args): # Convert the flow variables in `args` to TensorArrays. `args` should # already have the same structure as `orig_loop_vars` but currently there # is no nest.zip so we call `_pack_sequence_as` which flattens both # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays # and packs it into the structure of `orig_loop_vars`. if maximum_iterations is None: return cond(*_pack_sequence_as(orig_loop_vars, args)) else: return math_ops.logical_and( loop_counter < maximum_iterations_arg, cond(*_pack_sequence_as(orig_loop_vars, args))) # NOTE(skyewm): we set collections to the outer graph's collections for # compatibility with TPUEstimator. cond_graph = func_graph_module.func_graph_from_py_func( cond_name, wrapped_cond, [], # We provide signature instead of args. {}, signature=_build_signature(loop_vars, shape_invariants), func_graph=util.WhileCondFuncGraph( cond_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access add_control_dependencies=add_control_dependencies) def wrapped_body(loop_counter, maximum_iterations_arg, *args): """Loop body augmented with counter update. Args: loop_counter: Loop counter which needs to be incremented in the body. maximum_iterations_arg: Maximum iterations of the loop. *args: List of args Returns: A list of tensors the same length as args. """ # Capture the tensors already captured in cond_graph so that they appear # in the same order in body_graph.external_captures. for t in cond_graph.external_captures: ops.get_default_graph().capture(t) # Convert the flow variables in `args` to TensorArrays. 
`args` should # already have the same structure as `orig_loop_vars` but currently there # is no nest.zip so we call `_pack_sequence_as` which flattens both # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays # and packs it into the structure of `orig_loop_vars`. outputs = body(*_pack_sequence_as(orig_loop_vars, args)) if not nest.is_sequence(outputs): outputs = [outputs] # Compare the structure of input and output of body converting the # top-level tuples to list to be compatible with legacy while_loop. nest.assert_same_structure(list(outputs), list(orig_loop_vars)) outputs = _tensor_array_to_flow(outputs) # TODO(srbs): Update lowering code to create _Enter nodes with # is_constant=True for inputs that are directly passed to outputs. return [loop_counter + 1, maximum_iterations_arg] + list(outputs) body_graph = func_graph_module.func_graph_from_py_func( body_name, wrapped_body, [], # We provide signature instead of args. {}, signature=_build_signature(loop_vars, shape_invariants), func_graph=util.WhileBodyFuncGraph( body_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access add_control_dependencies=add_control_dependencies) # Add external captures of body to the list of loop vars. # Note that external tensors will be treated as loop invariants, i.e., # the value of that tensor in each iteration is the same as it was at the # beginning of the loop execution. loop_vars = loop_vars + body_graph.external_captures # TODO(srbs): Update lowering code to create _Enter nodes with # is_constant=True for inputs that are directly passed to outputs. body_graph.outputs.extend(body_graph.internal_captures) # Capture the extra `external_captures` of `body_graph` in `cond_graph` so # that it expects to receive those as arguments. with cond_graph.as_default(): num_cond_captures = len(cond_graph.external_captures) assert (cond_graph.external_captures == body_graph.external_captures[:num_cond_captures]) for body_capture in body_graph.external_captures[ num_cond_captures:]: assert body_capture not in cond_graph.captures cond_graph.capture(body_capture) # Make sure that the shapes of the loop outputs are compatible with the # shape invariants, or the shapes of the loop vars if the invariants are not # specified. num_flattened_outputs = len(nest.flatten(orig_loop_vars)) # First var is loop counter and second var is maximum_iterations. first_loop_var_index = 2 _check_shapes_compat( body_graph.outputs[first_loop_var_index:first_loop_var_index + num_flattened_outputs], nest.flatten( shape_invariants[first_loop_var_index:first_loop_var_index + len_orig_loop_vars]), nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index + len_orig_loop_vars])) flattened_loop_vars = nest.flatten(loop_vars) _check_num_inputs_outputs(cond_graph, body_graph, len(flattened_loop_vars)) with ops.control_dependencies( list(cond_graph.control_captures) + list(body_graph.control_captures)): outputs = gen_functional_ops._while( flattened_loop_vars, util.create_new_tf_function(cond_graph), util.create_new_tf_function(body_graph), output_shapes=[t.shape for t in body_graph.outputs], parallel_iterations=parallel_iterations, name=scope) _copy_handle_data(body_graph.outputs, outputs) util.maybe_set_lowering_attr(outputs[0].op) util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op) # Return identities for each output of the While op, rather than the output # of the While op directly. 
    # This makes pruning work if the output of while_loop() is fetched: the
    # lowering pass converts the While outputs into IdentityN outputs, which
    # if fetched will cause all ops in the body to be run (since it takes all
    # exit ops as input). After lowering, each output identity op will end up
    # with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

    outputs = _pack_sequence_as(
        orig_loop_vars,
        outputs[first_loop_var_index:first_loop_var_index +
                num_flattened_outputs])

    if return_same_structure:
      return outputs

    flattened_outputs = nest.flatten(outputs)
    if len(flattened_outputs) == 1:
      return flattened_outputs[0]
    else:
      return outputs
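
# Illustrative sketch (not part of this module), assuming TensorFlow 2.x: the
# maximum_iterations loop variable threaded through above caps the iteration
# count even when `cond` never becomes false.
import tensorflow as tf

i, total = tf.while_loop(
    cond=lambda i, total: tf.constant(True),  # would never stop on its own
    body=lambda i, total: (i + 1, total + 2),
    loop_vars=(tf.constant(0), tf.constant(0)),
    maximum_iterations=10)
print(i.numpy(), total.numpy())  # 10 20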
def _WhileGrad(op, *grads): # pylint: disable=invalid-name """The gradient of a While op produced by while_loop.""" # Note that op is not always the same as while_op because the gradient tape, # for eager mode compatibility, forgets information about the proper op. Since # the loop cannot run in eager mode, however, we can safely introspect into # the graph here. while_op = op.outputs[0].op cond_graph = _get_graph(while_op, "cond") body_graph = _get_graph(while_op, "body") orig_num_params = len(body_graph.outputs) maximum_iterations = op.inputs[1] parallel_iterations = op.get_attr("parallel_iterations") try: num_original_outputs = while_op.get_attr("_num_original_outputs") except: # pylint: disable=bare-except num_original_outputs = len(while_op.outputs) num_intermediates = len(while_op.outputs) - num_original_outputs grads = [ _preprocess_grad(grad, body_out, while_out) # pylint: disable=g-complex-comprehension for grad, body_out, while_out in zip( grads[:num_original_outputs], body_graph.outputs[:num_original_outputs], while_op.outputs[:num_original_outputs]) ] + [None] * num_intermediates # We compute the gradient for the sub-graph between trainable ys and xs # with non-None incoming gradients. We later pad the None's to the list of # outputs. ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip( body_graph.outputs, body_graph.inputs, grads) if grad is not None]) body_grad_graph, args = _create_grad_func( ys, xs, non_none_grads, cond_graph, body_graph, util.unique_grad_fn_name(body_graph.name), op, maximum_iterations) if body_grad_graph.while_op_needs_rewrite: # Modify 'op' to output the intermediate accumulators needed by the grad # function. # NOTE(skyewm): if there are any active sessions, this modification to `op` # may make them unrunnable! cond_graph.name += "_rewritten" body_graph.name += "_rewritten" new_inputs = body_grad_graph.empty_tensor_lists new_outputs = body_graph.outputs[orig_num_params:] while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph)) while_op._set_func_attr("body", util.create_new_tf_function(body_graph)) while_op._set_type_list_attr("T", body_graph.output_types) while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes) while_op._add_while_inputs(new_inputs) while_op._add_outputs([t.dtype for t in new_outputs], [t.shape for t in new_outputs]) _copy_handle_data(new_outputs, op.outputs[orig_num_params:]) # Do not ingore grads wrt extra outputs when computing higher order # derivatives. while_op._set_attr("_num_original_outputs", attr_value_pb2.AttrValue(i=len(while_op.outputs))) captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph, while_op) loop_vars = args + captured_inputs # This modifies body_grad_graph. 
  loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
      grads, body_grad_graph, loop_vars, while_op.inputs)

  def grad_cond(counter, unused_maximum_iterations_arg, forward_loop_iters,
                *unused_args):
    return counter < forward_loop_iters

  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      parallel_iterations=parallel_iterations,
      name="%s_grad" % while_op.name)
  grad_op = outputs[0].op

  _copy_handle_data(body_grad_graph.outputs, outputs)
  util.maybe_set_lowering_attr(grad_op)
  util.maybe_propagate_compile_time_consts_in_xla(grad_op)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]
  return _get_structured_grad_output(outputs, grads, body_grad_graph)
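
# Illustrative sketch (not part of this module), assuming TensorFlow 2.x: the
# _num_original_outputs bookkeeping above exists so that higher-order
# derivatives ignore the extra accumulator outputs. Nested tapes compute a
# second derivative through a loop (whether this exact code path is hit
# depends on graph vs. eager execution).
import tensorflow as tf

x = tf.constant(2.0)
with tf.GradientTape() as outer:
  outer.watch(x)
  with tf.GradientTape() as inner:
    inner.watch(x)
    _, y = tf.while_loop(
        cond=lambda i, acc: i < 3,
        body=lambda i, acc: (i + 1, acc * x),
        loop_vars=(tf.constant(0), tf.constant(1.0)))  # y = x**3
  dy = inner.gradient(y, x)       # 3 * x**2
d2y = outer.gradient(dy, x)       # 6 * x
print(dy.numpy(), d2y.numpy())    # 12.0 12.0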
def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs, building_gradient, name=None): """Creates an If op from the specified predicate, branch functions and inputs. Note that this modifies true_graph and false_graph to make the inputs match, and to output all intermediates values so they're available for the gradient computation. true_graph and false_graph need not have the same input types, but they must have the same outpute types. Args: pred: boolean Tensor true_graph: FuncGraph false_graph: FuncGraph true_inputs: a list of Tensors to be passed to true_graph as input. false_inputs: a list of Tensors to be passed to false_graph as input. building_gradient: Whether this is a gradient If op. name: the name for the If op. Returns: A list of Tensors which are the outputs of the If op. Does not include added intermediate outputs. """ _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph]) _check_same_outputs(_COND, [true_graph, false_graph]) # Add inputs to true_graph and false_graph to make them match. Note that # this modifies true_graph and false_graph. cond_inputs = _make_inputs_match([true_graph, false_graph], [true_inputs, false_inputs]) # Save the original number of outputs to return to the caller. num_cond_outputs = len(true_graph.outputs) # We do not output intermediates of the gradient If op since this is just # for backwards compatibility with existing code. if not building_gradient and util.output_all_intermediates(): # Add all intermediate tensors as function outputs so they're available for # the gradient computation. Since the outputs of the two functions must # match, we wrap all the intermediates in optionals. Each intermediate # output will have a value iff its corresponding branch is taken. true_intermediates = _get_intermediates(true_graph) false_intermediates = _get_intermediates(false_graph) # Wrap intermediates in optionals. wrapped_true_intermediates = _wrap_intermediates( true_graph, true_intermediates) wrapped_false_intermediates = _wrap_intermediates( false_graph, false_intermediates) # Make outputs match by adding none optionals. extra_true_outputs, extra_false_outputs = _make_intermediates_match( # pylint: disable=unbalanced-tuple-unpacking [true_graph, false_graph], [wrapped_true_intermediates, wrapped_false_intermediates]) true_graph.outputs.extend(extra_true_outputs) false_graph.outputs.extend(extra_false_outputs) _check_same_outputs(_COND, [true_graph, false_graph]) # Create the If op. with ops.control_dependencies( list(true_graph.control_captures) + list(false_graph.control_captures)): true_stateful_ops = [ op for op in true_graph.get_operations() if op._is_stateful ] false_stateful_ops = [ op for op in false_graph.get_operations() if op._is_stateful ] if (true_stateful_ops or false_stateful_ops): op_fn = gen_functional_ops._if else: op_fn = gen_functional_ops.stateless_if tensors = op_fn(pred, cond_inputs, [t.dtype for t in true_graph.outputs], util.create_new_tf_function(true_graph), util.create_new_tf_function(false_graph), output_shapes=_get_output_shapes( true_graph.outputs, false_graph.outputs), name=name) # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output if_op = tensors[0].op if_op._true_graph = true_graph if_op._false_graph = false_graph util.maybe_set_lowering_attr(if_op) util.maybe_propagate_compile_time_consts_in_xla(if_op) # Return identities for each output of the If op, rather than the output of # the If op directly. 
  # This makes pruning work if the output of cond() is fetched: the lowering
  # pass converts the If outputs into IdentityN outputs, which if fetched will
  # cause all ops in the taken branch to be run (since it takes all merge ops
  # as input). After lowering, each output identity op will end up with only
  # the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure.
  tensors = [array_ops.identity(t) for t in tensors]
  # Prevent fetching since the variant outputs can't be fetched directly.
  if_op.graph.prevent_fetching(if_op)

  return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                            tensors[:num_cond_outputs])
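
# Illustrative sketch (not part of this module), assuming TensorFlow 2.x: the
# intermediates optionally exported above are what make backprop through a
# cond possible; from the public API this is just differentiating through
# tf.cond.
import tensorflow as tf

x = tf.constant(2.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.cond(x > 0., lambda: tf.square(x), lambda: -x)
print(tape.gradient(y, x).numpy())  # 4.0, i.e. d(x**2)/dx on the true branch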
def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs, name=None): """Creates an If op from the specified predicate, branch functions and inputs. Note that this modifies true_graph and false_graph to make the inputs match, and to output all intermediates values so they're available for the gradient computation. true_graph and false_graph need not have the same input types, but they must have the same outpute types. Args: pred: boolean Tensor true_graph: FuncGraph false_graph: FuncGraph true_inputs: a list of Tensors to be passed to true_graph as input. false_inputs: a list of Tensors to be passed to false_graph as input. name: the name for the If op. Returns: A list of Tensors which are the outputs of the If op. Does not include added intermediate outputs. """ _check_same_outputs(true_graph, false_graph) # Add inputs to true_graph and false_graph to make them match. Note that # this modifies true_graph and false_graph. cond_inputs = _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs) # Create the If op. tensors = gen_functional_ops._if( # pylint: disable=protected-access pred, cond_inputs, [t.dtype for t in true_graph.outputs], util.create_new_tf_function(true_graph), util.create_new_tf_function(false_graph), output_shapes=_get_output_shapes(true_graph.outputs, false_graph.outputs), name=name) # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output if_op = tensors[0].op util.maybe_set_lowering_attr(if_op) # Return identities for each output of the If op, rather than the output of # the If op directly. This makes pruning work if the output of cond() is # fetched: the lowering pass converts the If outputs into IdentityN outputs, # which if fetched will cause all ops in the taken branch to be run (since # it takes all merge ops as input). After lowering, each output identity op # will end up with only the appropriate merge op as input. # TODO(b/79984175): this doesn't have to be a tuple once we covert to the # correct output structure tensors = [array_ops.identity(t) for t in tensors] # Prevent fetching since the variant outputs can't be fetched directly. if_op.graph.prevent_fetching(if_op) return tensors
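
# Illustrative sketch (not part of this module): the identity outputs added
# above are what a v1-style Session actually fetches, since the If op itself
# is marked unfetchable. Assuming TensorFlow 2.x with eager execution
# disabled (this must run in a fresh process, before any other ops are
# created):
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

x = tf.constant(3.0)
y = tf.cond(x > 0., lambda: x + 1., lambda: x - 1.)
with tf.compat.v1.Session() as sess:
  print(sess.run(y))  # 4.0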
def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs, name=None): """Creates an If op from the specified predicate, branch functions and inputs. Note that this modifies true_graph and false_graph to make the inputs match, and to output all intermediates values so they're available for the gradient computation. true_graph and false_graph need not have the same input types, but they must have the same outpute types. Args: pred: boolean Tensor true_graph: FuncGraph false_graph: FuncGraph true_inputs: a list of Tensors to be passed to true_graph as input. false_inputs: a list of Tensors to be passed to false_graph as input. name: the name for the If op. Returns: A list of Tensors which are the outputs of the If op. Does not include added intermediate outputs. """ _check_same_outputs(true_graph, false_graph) # Add inputs to true_graph and false_graph to make them match. Note that # this modifies true_graph and false_graph. cond_inputs = _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs) # Add all intermediate tensors as function outputs so they're available for # the gradient computation. Since the outputs of the two functions must match, # we wrap all the intermediates in optionals. Each intermediate output will # have a value iff its corresponding branch is taken. true_intermediates = _get_intermediates(true_graph) false_intermediates = _get_intermediates(false_graph) # Save the original number of outputs to return to the caller. num_cond_outputs = len(true_graph.outputs) if control_flow_util.InXlaContext(ops.get_default_graph()): # XLA does not yet support optionals, so output intermediates directly and # make them match via FakeParams, which can be converted to zeros in XLA. # TODO(skyewm,jpienaar): can XLA support optionals? extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla( true_graph, false_graph, true_intermediates, false_intermediates) else: # Wrap intermediates in optionals. wrapped_true_intermediates = _wrap_intermediates( true_graph, true_intermediates) wrapped_false_intermediates = _wrap_intermediates( false_graph, false_intermediates) # Make outputs match by adding none optionals. extra_true_outputs, extra_false_outputs = _make_intermediates_match( true_graph, false_graph, wrapped_true_intermediates, wrapped_false_intermediates) true_graph.outputs.extend(extra_true_outputs) false_graph.outputs.extend(extra_false_outputs) # TODO(skyewm): somehow indicate it's a bug if this fails. _check_same_outputs(true_graph, false_graph) # Create the If op. tensors = gen_functional_ops._if( # pylint: disable=protected-access pred, cond_inputs, [t.dtype for t in true_graph.outputs], util.create_new_tf_function(true_graph), util.create_new_tf_function(false_graph), output_shapes=_get_output_shapes(true_graph.outputs, false_graph.outputs), name=name) # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output if_op = tensors[0].op util.maybe_set_lowering_attr(if_op) # Return identities for each output of the If op, rather than the output of # the If op directly. This makes pruning work if the output of cond() is # fetched: the lowering pass converts the If outputs into IdentityN outputs, # which if fetched will cause all ops in the taken branch to be run (since # it takes all merge ops as input). After lowering, each output identity op # will end up with only the appropriate merge op as input. 
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure.
  tensors = [array_ops.identity(t) for t in tensors]
  # Prevent fetching since the variant outputs can't be fetched directly.
  if_op.graph.prevent_fetching(if_op)

  return tensors[:num_cond_outputs]
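
# Illustrative sketch (not part of this module): the XLA branch above swaps
# optionals for FakeParams when the cond is built inside an XLA context. From
# the public API this roughly corresponds to jit-compiling a function that
# contains a tf.cond. The jit_compile flag is an assumption about recent
# TensorFlow releases (older ones spell it experimental_compile), and running
# it requires an XLA-enabled build.
import tensorflow as tf

@tf.function(jit_compile=True)
def compiled_cond(x):
  return tf.cond(x > 0., lambda: x * 2., lambda: x / 2.)

print(compiled_cond(tf.constant(3.0)).numpy())  # 6.0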