def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop.

  Builds a backward While op whose body is the gradient function of the
  forward body. When the gradient body needs intermediate accumulators from the
  forward pass, the forward While op is rewritten in place to output them.

  Args:
    op: The op the gradient was requested for. This is not necessarily the
      forward While op itself; the real op is recovered as `op.outputs[0].op`.
    *grads: Incoming gradients, matched positionally against the outputs of
      the forward While op. Gradients that are None after preprocessing are
      excluded from the gradient sub-graph.

  Returns:
    The value produced by `_get_structured_grad_output` for the outputs of the
    backward While op.
  """
  # Note that op is not always the same as while_op because the gradient tape,
  # for eager mode compatibility, forgets information about the proper op. Since
  # the loop cannot run in eager mode, however, we can safely introspect into
  # the graph here.
  while_op = op.outputs[0].op
  cond_graph = _get_graph(while_op, "cond")
  body_graph = _get_graph(while_op, "body")
  orig_num_params = len(body_graph.outputs)

  maximum_iterations = op.inputs[1]
  parallel_iterations = op.get_attr("parallel_iterations")

  # `_num_original_outputs` is written further down in this function, so it is
  # absent the first time a gradient is taken through `while_op`. `get_attr`
  # raises ValueError for a missing attr; only that case falls back to treating
  # every current output as original.
  try:
    num_original_outputs = while_op.get_attr("_num_original_outputs")
  except ValueError:
    num_original_outputs = len(while_op.outputs)

  # Outputs beyond the original ones get no incoming gradient.
  num_intermediates = len(while_op.outputs) - num_original_outputs
  grads = [
      _preprocess_grad(grad, body_out, while_out)  # pylint: disable=g-complex-comprehension
      for grad, body_out, while_out in zip(
          grads[:num_original_outputs],
          body_graph.outputs[:num_original_outputs],
          while_op.outputs[:num_original_outputs])
  ] + [None] * num_intermediates

  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, cond_graph, body_graph,
      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations)

  if body_grad_graph.while_op_needs_rewrite:
    # Modify 'op' to output the intermediate accumulators needed by the grad
    # function.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    cond_graph.name += "_rewritten"
    body_graph.name += "_rewritten"

    new_inputs = body_grad_graph.empty_tensor_lists
    new_outputs = body_graph.outputs[orig_num_params:]

    while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
    while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
    while_op._set_type_list_attr("T", body_graph.output_types)
    while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
    while_op._add_while_inputs(new_inputs)
    while_op._add_outputs([t.dtype for t in new_outputs],
                          [t.shape for t in new_outputs])
    # The new outputs were added to `while_op`; `op` may be a different object
    # that never received them, so the targets must come from `while_op`.
    _copy_handle_data(new_outputs, while_op.outputs[orig_num_params:])

  # Do not ignore grads wrt extra outputs when computing higher order
  # derivatives.
  while_op._set_attr("_num_original_outputs",
                     attr_value_pb2.AttrValue(i=len(while_op.outputs)))

  captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
                                           while_op)
  loop_vars = args + captured_inputs

  # This modifies body_grad_graph.
  loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
      grads, body_grad_graph, loop_vars, while_op.inputs)

  def grad_cond(counter, unused_maximum_iterations_arg, forward_loop_iters,
                *unused_args):
    # The backward loop runs while its counter is below the third loop
    # variable; the second one is accepted but not consulted.
    return counter < forward_loop_iters

  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      parallel_iterations=parallel_iterations,
      name="%s_grad" % while_op.name)
  grad_op = outputs[0].op

  _copy_handle_data(body_grad_graph.outputs, outputs)
  util.maybe_set_lowering_attr(grad_op)
  util.maybe_propagate_compile_time_consts_in_xla(grad_op)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]
  return _get_structured_grad_output(outputs, grads, body_grad_graph)
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop.

  Builds a backward While op whose body is the gradient function of the
  forward body. When the gradient body needs intermediate accumulators from the
  forward pass, the forward While op is rewritten in place to output them.

  Args:
    op: The op the gradient was requested for. This is not necessarily the
      forward While op itself; the real op is recovered as `op.outputs[0].op`.
    *grads: Incoming gradients, matched positionally against the outputs of
      the forward While op. Gradients that are None after preprocessing are
      excluded from the gradient sub-graph.

  Returns:
    The value produced by `_get_structured_grad_output` for the outputs of the
    backward While op.
  """
  # Note that op is not always the same as while_op because the gradient tape,
  # for eager mode compatibility, forgets information about the proper op. Since
  # the loop cannot run in eager mode, however, we can safely introspect into
  # the graph here.
  while_op = op.outputs[0].op
  cond_graph = _get_graph(while_op, "cond")
  body_graph = _get_graph(while_op, "body")
  orig_num_params = len(body_graph.outputs)

  # The `_maximum_iterations` attr is only read inside an XLA context;
  # everywhere else the value is None.
  maximum_iterations = op.get_attr(
      "_maximum_iterations") if _is_in_xla_context() else None
  parallel_iterations = op.get_attr("parallel_iterations")
  assert not _is_in_xla_context() or maximum_iterations is not None
  maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)

  grads = [_preprocess_grad(grad, body_out, while_out)
           for grad, body_out, while_out
           in zip(grads, body_graph.outputs, while_op.outputs)]

  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, cond_graph, body_graph,
      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations)

  if body_grad_graph.while_op_needs_rewrite:
    # Modify 'op' to output the intermediate accumulators needed by the grad
    # function.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    cond_graph.name += "_rewritten"
    body_graph.name += "_rewritten"

    new_inputs = body_grad_graph.empty_tensor_lists
    new_outputs = body_graph.outputs[orig_num_params:]

    while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
    while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
    while_op._set_type_list_attr("T", body_graph.output_types)
    while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
    while_op._add_while_inputs(new_inputs)
    while_op._add_outputs([t.dtype for t in new_outputs],
                          [t.shape for t in new_outputs])
    # The new outputs were added to `while_op`; `op` may be a different object
    # that never received them, so the targets must come from `while_op`.
    _copy_handle_data(new_outputs, while_op.outputs[orig_num_params:])

  captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
                                           while_op)
  loop_vars = args + captured_inputs

  # This modifies body_grad_graph.
  loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
      grads, body_grad_graph, loop_vars, while_op.inputs)

  def grad_cond(counter, max_iters, *unused_args):
    # The backward loop runs while its counter is below the second loop
    # variable.
    return counter < max_iters

  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      parallel_iterations=parallel_iterations,
      name="%s_grad" % while_op.name)
  grad_op = outputs[0].op

  _copy_handle_data(body_grad_graph.outputs, outputs)
  util.maybe_set_lowering_attr(grad_op)
  _maybe_set_maximum_iterations_attr(grad_op, maximum_iterations)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]
  return _get_structured_grad_output(outputs, grads, body_grad_graph)