def _capture_helper(self, tensor, name):
  if (tensor.graph is not self._forward_graph or
      tensor in self._forward_graph.inputs or
      tensor in self._forward_graph.outputs):
    return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

  if control_flow_util.InXlaContext(ops.get_default_graph()):
    # XLA does not yet support optionals, so capture intermediates directly.
    # TODO(skyewm,jpienaar): can XLA support optionals?
    if tensor not in self.captures:
      self.xla_intermediates.append(tensor)
      self.if_op_needs_rewrite = True
    return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

  captured_tensor = self._indirect_captures.get(tensor)
  if captured_tensor is not None:
    return captured_tensor

  # 'tensor' is an uncaptured intermediate in the forward graph.
  # If it is not a resource, we wrap it in an optional in the forward graph
  # and capture the optional normally. We then unwrap the captured optional
  # value in the gradient graph to get the raw intermediate value.
  # If it is a resource, we trace the resource up to the input in the forward
  # graph and capture that.
  if tensor.dtype == dtypes.resource:
    # Index of the forward graph input corresponding to the resource tensor.
    index = util.resource_input_index(
        tensor.name, [t.name for t in self._forward_graph.inputs],
        {op.name: op.node_def
         for op in self._forward_graph.get_operations()},
        self._forward_graph._functions)
    # This gets mapped to the corresponding If op input in
    # `_resolve_grad_inputs`.
    captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
        self._forward_graph.inputs[index], name)
  else:
    if tensor not in self._wrapped_intermediates:
      # If the gradient has already been computed for this If op, 'tensor'
      # may already be wrapped.
      for consumer in tensor.consumers():
        if (consumer.type == "OptionalFromValue" and
            consumer.outputs[0] in self._forward_graph.outputs):
          optional = consumer.outputs[0]
          break
      else:
        # 'tensor' hasn't been wrapped, do it now.
        with self._forward_graph.as_default():
          optional = gen_dataset_ops.optional_from_value([tensor])
        self.if_op_needs_rewrite = True
      self._wrapped_intermediates[tensor] = optional

    optional = self._wrapped_intermediates[tensor]
    captured_optional = super(_CondGradFuncGraph,
                              self)._capture_helper(optional, name)
    captured_tensor = gen_dataset_ops.optional_get_value(
        captured_optional, [tensor.dtype], [tensor.shape])[0]

  self._indirect_captures[tensor] = captured_tensor
  return captured_tensor
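# Illustrative sketch (added for exposition; not part of the original module).
# It shows the user-level situation the optional-wrapping above exists for:
# the gradient of exp reuses the forward result exp(x), an intermediate that
# lives inside the forward branch function, so the gradient graph must capture
# it. The `_example_*` name is hypothetical.
def _example_intermediate_capture():
  import tensorflow as tf  # local import keeps the sketch self-contained

  @tf.function
  def grad_of_cond(x):
    with tf.GradientTape() as tape:
      tape.watch(x)
      # Inside tf.function, tf.cond builds an If op via cond_v2.
      y = tf.cond(x > 0.0, lambda: tf.exp(x), lambda: -x)
    # d exp(x)/dx == exp(x): the forward intermediate is reused by the
    # gradient graph, which reaches it through a wrapped optional.
    return tape.gradient(y, x)

  return grad_of_cond(tf.constant(1.5))  # == exp(1.5)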
def _IfGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of an If op produced by cond_v2."""
  # Get the If operator (this logic handles the case where op is a MockOp).
  if_op = op.outputs[0].op

  true_graph, false_graph = _get_func_graphs(if_op)
  # Note: op.graph != ops.get_default_graph() when we are computing the
  # gradient of a nested cond.
  assert true_graph.outer_graph == if_op.graph
  assert false_graph.outer_graph == if_op.graph

  # Create grad functions that compute the gradient of the true/false forward
  # graphs. These functions will capture tensors from the forward pass
  # functions.
  true_grad_graph = _create_grad_func(
      true_graph, grads, util.unique_grad_fn_name(true_graph.name))
  false_grad_graph = _create_grad_func(
      false_graph, grads, util.unique_grad_fn_name(false_graph.name))

  assert ([t.dtype for t in true_grad_graph.outputs] ==
          [t.dtype for t in false_grad_graph.outputs])

  if (true_grad_graph.if_op_needs_rewrite or
      false_grad_graph.if_op_needs_rewrite):
    # Modify 'op' to output the intermediates needed by the grad functions.
    # Note that all needed intermediates are wrapped in optionals. Each
    # optional intermediate output will have a value iff its corresponding
    # branch is taken.
    # NOTE(skyewm): if there are any active sessions, this modification to
    # `op` may make them unrunnable!

    if control_flow_util.InXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so output intermediates directly
      # and make them match via FakeParams, which can be converted to zeros
      # in XLA.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      true_intermediates = true_grad_graph.xla_intermediates
      false_intermediates = false_grad_graph.xla_intermediates
      extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
          true_graph, false_graph, true_intermediates, false_intermediates)
    else:
      true_intermediates = true_grad_graph.wrapped_intermediates
      false_intermediates = false_grad_graph.wrapped_intermediates
      # Make outputs match by adding none optionals.
      extra_true_outputs, extra_false_outputs = _make_intermediates_match(
          true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)
    # TODO(skyewm): indicate it's an internal bug if this fails.
    _check_same_outputs(true_graph, false_graph)

    true_graph.name += "_rewritten"
    false_graph.name += "_rewritten"

    if_op._set_func_attr("then_branch",
                         util.create_new_tf_function(true_graph))
    if_op._set_func_attr("else_branch",
                         util.create_new_tf_function(false_graph))
    if_op._set_type_list_attr("Tout", true_graph.output_types)
    if_op._set_shape_list_attr("output_shapes", true_graph.output_shapes)
    if_op._add_outputs([t.dtype for t in extra_true_outputs],
                       [t.shape for t in extra_true_outputs])

  # Resolve references to forward graph tensors in grad graphs and ensure
  # they are in-scope, i.e., belong to one of the outer graphs of the grad
  # graph.
  true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
  false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)

  outputs = _build_cond(if_op.inputs[0], true_grad_graph, false_grad_graph,
                        true_grad_inputs, false_grad_inputs)

  # The predicate has no gradient.
  return [None] + outputs
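# Illustrative sketch (added for exposition; not part of the original module).
# A minimal end-to-end use of the If gradient: differentiate through tf.cond
# and observe that only the taken branch contributes, while the predicate gets
# no gradient (the leading None returned by _IfGrad corresponds to the If op's
# predicate input). The `_example_*` name is hypothetical.
def _example_if_grad():
  import tensorflow as tf  # local import keeps the sketch self-contained

  @tf.function
  def grad_of_cond(x):
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = tf.cond(x > 1.0, lambda: tf.square(x), lambda: x)
    return tape.gradient(y, x)

  return grad_of_cond(tf.constant(2.0))  # true branch: dy/dx == 2*x == 4.0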
def _CaseGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a Case op produced by tf.switch_case."""
  # Get the Case operator (this logic handles the case where op is a MockOp).
  case_op = op.outputs[0].op
  branch_graphs = get_func_graphs(case_op)
  assert branch_graphs
  # Note: op.graph != ops.get_default_graph() when we are computing the
  # gradient of a nested cond.
  for branch_graph in branch_graphs:
    assert branch_graph.outer_graph == case_op.graph

  # Create grad functions that compute the gradient of the branch forward
  # graphs. These functions will capture tensors from the forward pass
  # functions.
  branch_grad_graphs = []
  for branch_graph in branch_graphs:
    branch_grad_graphs.append(
        _create_grad_func(branch_graph, grads,
                          util.unique_grad_fn_name(branch_graph.name)))

  if any(g.op_needs_rewrite for g in branch_grad_graphs):
    # Modify 'op' to output the intermediates needed by the grad functions.
    # Note that all needed intermediates are wrapped in optionals. Each
    # optional intermediate output will have a value iff its corresponding
    # branch is taken.
    # NOTE(bjp): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    if control_flow_util.InXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so output intermediates directly
      # and make them match via FakeParams, which can be converted to zeros
      # in XLA.
      # TODO(bjp,jpienaar): can XLA support optionals?
      branches_intermediates = [
          branch_grad_graph.xla_intermediates
          for branch_grad_graph in branch_grad_graphs
      ]
      extra_branch_outputs = _make_intermediates_match_xla(
          branch_graphs, branches_intermediates)
    else:
      branch_intermediates = [
          g.wrapped_intermediates for g in branch_grad_graphs
      ]
      # Make outputs match by adding none optionals.
      extra_branch_outputs = _make_intermediates_match(branch_graphs,
                                                       branch_intermediates)

    for branch_graph, extra_outputs in zip(branch_graphs,
                                           extra_branch_outputs):
      branch_graph.outputs.extend(extra_outputs)
    # TODO(bjp): indicate it's an internal bug if this fails.
    _check_same_outputs(_CASE, branch_graphs)

    for branch_graph in branch_graphs:
      branch_graph.name += "_rewritten"

    case_op._set_func_list_attr("branches", [
        util.create_new_tf_function(branch_graph)
        for branch_graph in branch_graphs
    ])
    case_op._set_type_list_attr("Tout", branch_graphs[0].output_types)
    case_op._set_shape_list_attr("output_shapes",
                                 branch_graphs[0].output_shapes)
    case_op._add_outputs([t.dtype for t in extra_branch_outputs[0]],
                         [t.shape for t in extra_branch_outputs[0]])

  # Resolve references to forward graph tensors in grad graphs and ensure
  # they are in-scope, i.e., belong to one of the outer graphs of the grad
  # graph.
  branches_grad_inputs = [
      _resolve_grad_inputs(branch_graph, branch_grad_graph)
      for branch_graph, branch_grad_graph in zip(branch_graphs,
                                                 branch_grad_graphs)
  ]

  # This modifies the graphs in branch_grad_graphs.
  _make_output_composite_tensors_match(_CASE, branch_grad_graphs)

  outputs = _build_case(case_op.inputs[0], branch_grad_graphs,
                        branches_grad_inputs, name="gradient")

  # The predicate has no gradient.
  return [None] + outputs
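# Illustrative sketch (added for exposition; not part of the original module).
# Analogous to the If case, but for tf.switch_case, which produces a Case op:
# only the selected branch contributes to the gradient, and the branch index,
# like the If predicate, receives no gradient. The `_example_*` name is
# hypothetical.
def _example_case_grad():
  import tensorflow as tf  # local import keeps the sketch self-contained

  @tf.function
  def grad_of_case(branch_index, x):
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = tf.switch_case(branch_index,
                         branch_fns=[lambda: x * 2.0, lambda: tf.square(x)])
    return tape.gradient(y, x)

  # Branch 1 runs, so dy/dx == 2*x == 6.0.
  return grad_of_case(tf.constant(1), tf.constant(3.0))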
def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
                name=None):
  """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include
    added intermediate outputs.
  """
  _check_same_outputs(true_graph, false_graph)

  # Add inputs to true_graph and false_graph to make them match. Note that
  # this modifies true_graph and false_graph.
  cond_inputs = _make_inputs_match(true_graph, false_graph,
                                   true_inputs, false_inputs)

  # Add all intermediate tensors as function outputs so they're available for
  # the gradient computation. Since the outputs of the two functions must
  # match, we wrap all the intermediates in optionals. Each intermediate
  # output will have a value iff its corresponding branch is taken.

  true_intermediates = _get_intermediates(true_graph)
  false_intermediates = _get_intermediates(false_graph)

  # Save the original number of outputs to return to the caller.
  num_cond_outputs = len(true_graph.outputs)

  if control_flow_util.InXlaContext(ops.get_default_graph()):
    # XLA does not yet support optionals, so output intermediates directly and
    # make them match via FakeParams, which can be converted to zeros in XLA.
    # TODO(skyewm,jpienaar): can XLA support optionals?
    extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
        true_graph, false_graph, true_intermediates, false_intermediates)
  else:
    # Wrap intermediates in optionals.
    wrapped_true_intermediates = _wrap_intermediates(true_graph,
                                                     true_intermediates)
    wrapped_false_intermediates = _wrap_intermediates(false_graph,
                                                      false_intermediates)

    # Make outputs match by adding none optionals.
    extra_true_outputs, extra_false_outputs = _make_intermediates_match(
        true_graph, false_graph, wrapped_true_intermediates,
        wrapped_false_intermediates)

  true_graph.outputs.extend(extra_true_outputs)
  false_graph.outputs.extend(extra_false_outputs)

  # TODO(skyewm): somehow indicate it's a bug if this fails.
  _check_same_outputs(true_graph, false_graph)

  # Create the If op.
  tensors = gen_functional_ops._if(  # pylint: disable=protected-access
      pred,
      cond_inputs, [t.dtype for t in true_graph.outputs],
      util.create_new_tf_function(true_graph),
      util.create_new_tf_function(false_graph),
      output_shapes=_get_output_shapes(true_graph.outputs,
                                       false_graph.outputs),
      name=name)

  # TODO(b/110167197): this approach requires cond_v2 to have at least 1
  # output.
  if_op = tensors[0].op
  util.maybe_set_lowering_attr(if_op)

  # Return identities for each output of the If op, rather than the output of
  # the If op directly. This makes pruning work if the output of cond() is
  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
  # which if fetched will cause all ops in the taken branch to be run (since
  # it takes all merge ops as input). After lowering, each output identity op
  # will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure.
  tensors = [array_ops.identity(t) for t in tensors]

  # Prevent fetching since the variant outputs can't be fetched directly.
  if_op.graph.prevent_fetching(if_op)

  return tensors[:num_cond_outputs]
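# Illustrative sketch (added for exposition; not part of the original module).
# A graph-mode demonstration of why _build_cond returns identities: the If op
# itself is marked unfetchable (its appended optional outputs are variants),
# but the Identity wrappers of its outputs can be fetched normally. The
# `_example_*` name is hypothetical.
def _example_fetch_cond_output():
  import tensorflow as tf  # local import keeps the sketch self-contained

  g = tf.Graph()
  with g.as_default():
    pred = tf.compat.v1.placeholder(tf.bool, [])
    # In graph mode, tf.cond goes through cond_v2 and hence _build_cond.
    out = tf.cond(pred, lambda: tf.constant(1.0), lambda: tf.constant(2.0))
  with tf.compat.v1.Session(graph=g) as sess:
    # Fetching `out` works because it is an Identity of the If output.
    return sess.run(out, feed_dict={pred: True})  # == 1.0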