def _find_children_hints(call, graph_def):
    """Collect the children hints nested inside a parent OpHint.

    Every node of `graph_def` is copied into a new GraphDef. In addition, for
    each "While" node that sits strictly inside the parent hint, the function
    def named by the node's "body" attribute is looked up in the graph library
    and passed to `_find_children_hints_in_while_loop`; the hints and nodes it
    returns are accumulated and the nodes are appended to the new GraphDef.

    Args:
      call: Parent OpHint that contains children ophints.
      graph_def: Original graph def.

    Returns:
      A 3-tuple of:
        * the ordered children hints found inside the parent ophint;
        * a new graph def holding a copy of every original node plus the
          nodes produced from function defs (if applicable);
        * the set of names of those function-def nodes.
    """
    name_to_input_name, _, _ = _extract_graph_summary(graph_def)
    input_names, output_names = call.flattened_inputs_and_outputs()

    upstream_of_inputs = _bfs_for_reachable_nodes(input_names,
                                                  name_to_input_name)
    upstream_of_outputs = _bfs_for_reachable_nodes(output_names,
                                                   name_to_input_name)
    hint_outputs = set(output_names)

    out = _graph_pb2.GraphDef()
    out.library.CopyFrom(graph_def.library)
    out.versions.CopyFrom(graph_def.versions)

    children_hints = []
    function_def_nodes = set()
    for node in graph_def.node:
        # Every original node is carried over, whatever its role.
        out.node.extend([_copy.deepcopy(node)])

        base_name = _tensor_name_base(node.name)
        inside_hint = (base_name in upstream_of_outputs
                       and base_name not in upstream_of_inputs
                       and base_name not in hint_outputs)
        # Only while loops inside the hint need the special handling below.
        if not inside_hint or node.op != "While":
            continue

        body_name = node.attr["body"].func.name
        outer_inputs = node.input
        for function_def in graph_def.library.function:
            if function_def.signature.name != body_name:
                continue
            function_inputs = function_def.signature.input_arg
            assert len(outer_inputs) == len(function_inputs)
            # Map each argument of the loop body to the tensor feeding the
            # While node at the same position.
            nodes_mapping = {
                arg.name: outer_inputs[position]
                for position, arg in enumerate(function_inputs)
            }
            # TODO(b/123050804): Consider use grappler.
            hints_in_loop, loop_nodes = _find_children_hints_in_while_loop(
                function_def, nodes_mapping)
            function_def_nodes.update(
                loop_node.name for loop_node in loop_nodes)
            children_hints.extend(hints_in_loop)
            out.node.extend(loop_nodes)

    return children_hints, out, function_def_nodes
def _getGraphOpTypes(self, graphdef, output_nodes):
    """Collect the op types of the nodes that `output_nodes` depend on.

    Used to check that the expected nodes are present after the stub
    transformation.

    NOTE: this is not an exact test that the graph is the correct output; it
    trades exactness for a compact way of sanity-checking a test.

    Args:
      graphdef: TensorFlow proto graphdef.
      output_nodes: A list of output node names that we need to reach.

    Returns:
      A set of node types reachable from `output_nodes`.
    """
    name_to_input_name, name_to_node, _ = _extract_graph_summary(graphdef)
    # Names of every node the requested outputs need.
    needed_names = _bfs_for_reachable_nodes(output_nodes, name_to_input_name)
    return {name_to_node[name].op for name in needed_names}
def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
            output_quantized, op_name, op_type):
    """Fuse subgraph between input_nodes and output_nodes into a single custom op.

    The result contains, in order: every node up to and including
    `input_nodes`, the new custom op (fed by `input_nodes`), a copy of each
    output node rewired to read from the custom op, and finally the nodes that
    are not needed by any output node.

    Args:
      graph_def: A graph_pb2.GraphDef proto.
      input_nodes: input nodes to the subgraph to be fused.
      output_nodes: output nodes to the subgraph to be fused.
      output_dtypes: A list of output datatypes for the custom op
      output_quantized: A boolean flag that indicates if output is quantized
      op_name: fused op name.
      op_type: fused op type.

    Returns:
      The GraphDef of the new graph.

    Raises:
      TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto, if
        `input_nodes` or `output_nodes` is a string instead of a list, or if a
        node inside the fused region depends on a node that precedes the
        inputs without being listed in `input_nodes`.
      AssertionError: If an output node does not have exactly one input.
    """
    if not isinstance(graph_def, graph_pb2.GraphDef):
        raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

    if isinstance(input_nodes, six.string_types):
        raise TypeError("input_nodes must be a list.")

    if isinstance(output_nodes, six.string_types):
        raise TypeError("output_nodes must be a list.")

    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
        graph_def)
    _assert_nodes_are_present(name_to_node, input_nodes + output_nodes)

    # Nodes upto and including input_nodes
    reachable_by_input = _bfs_for_reachable_nodes(input_nodes,
                                                  name_to_input_name)
    # Nodes upto and including output_nodes
    reachable_by_output = _bfs_for_reachable_nodes(output_nodes,
                                                   name_to_input_name)

    input_nodes_set = set(input_nodes)
    output_nodes_set = set(output_nodes)

    nodes_post_output = []
    for node in graph_def.node:
        n = _node_name(node.name)
        if n not in reachable_by_output:
            nodes_post_output.append(n)
            continue
        if n in reachable_by_input or n in output_nodes_set:
            continue

        # n is between input and output, i.e., part of the fused op. Walk its
        # dependencies breadth-first and make sure the walk only leaves the
        # fused region through one of the declared input nodes.
        #
        # Each node is enqueued at most once: without the `seen` set, shared
        # ancestors are expanded once per path leading to them (exponential
        # on diamond-shaped graphs) and a cycle would never terminate. The
        # order in which nodes are first visited is unchanged, so the node
        # reported in the error is the same.
        pending = [n]
        seen = {n}
        cursor = 0  # Index cursor: avoids the O(n) cost of popping the head.
        while cursor < len(pending):
            cur_node = pending[cursor]
            cursor += 1
            if cur_node in input_nodes_set:
                # Declared boundary of the fused op; do not look past it.
                continue
            if cur_node in reachable_by_input:
                raise TypeError("Node %s uses input %s not in input_nodes."
                                % (n, cur_node))
            for input_name in name_to_input_name[cur_node]:
                if input_name not in seen:
                    seen.add(input_name)
                    pending.append(input_name)

    # Add all nodes upto the input nodes, in their original graph order
    out = graph_pb2.GraphDef()
    reachable_by_input_sorted = sorted(
        reachable_by_input, key=lambda name: name_to_seq_num[name])
    for name in reachable_by_input_sorted:
        out.node.extend([copy.deepcopy(name_to_node[name])])

    # Add the custom op
    new_node = node_def_pb2.NodeDef()
    new_node.input.extend(input_nodes)
    new_node.attr["_output_types"].list.type[:] = output_dtypes
    new_node.attr["_output_quantized"].b = output_quantized
    new_node.op = op_type
    new_node.name = op_name
    out.node.extend([new_node])

    # Add the nodes in the output of the custom op
    for index, n in enumerate(output_nodes):
        assert len(name_to_node[n].input) == 1
        new_node = copy.deepcopy(name_to_node[n])
        del new_node.input[:]
        # Output 0 is addressed by the bare op name, output k>0 as "name:k".
        new_node.input.append(
            op_name + (":" + str(index) if index != 0 else ""))
        out.node.extend([new_node])

    # Add the nodes post output_nodes
    for n in nodes_post_output:
        out.node.extend([copy.deepcopy(name_to_node[n])])

    out.library.CopyFrom(graph_def.library)
    out.versions.CopyFrom(graph_def.versions)
    return out
def _convert_single_op_hint_to_stub(call, graph_def):
    """Replace the subgraph described by `call` with a single stub node.

    Args:
      call: A single function call to be converted.
      graph_def: A graph_def to use as input (that contains `call`).

    Returns:
      A new transformed graph-def that has call as a stub (single op).

    Note: after this process, the graph_def can no longer be loaded into
        the tensorflow runtime, so all future manipulations are done in
        graph_def level.
    """
    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
        graph_def)
    input_names, output_names = call.flattened_inputs_and_outputs()

    reachable_by_input = _bfs_for_reachable_nodes(input_names,
                                                  name_to_input_name)
    reachable_by_output = _bfs_for_reachable_nodes(output_names,
                                                   name_to_input_name)
    input_nodes_set = set(input_names)
    output_nodes_set = set(output_names)

    # Classify each node. Everything reachable by input is kept below; here we
    # separate the nodes swallowed by the stub from the nodes that come after
    # the fused region.
    nodes_after_fuse = []
    nodes_deleted_by_fuse = set()
    for node in graph_def.node:
        base_name = _tensor_name_base(node.name)
        if base_name not in reachable_by_output:
            if base_name not in reachable_by_input:
                # Comes after all the fusing, so it is a candidate to keep.
                nodes_after_fuse.append(base_name)
            # Otherwise the node is not connected to the chain of
            # dependencies handled here; nothing to record.
            continue
        if base_name in reachable_by_input or base_name in output_nodes_set:
            continue
        # Internal node. Check to make sure it is really internal.
        # TODO(aselle): this could be done more efficiently by flooding
        # the graph first.
        _check_subgraph_closed(base_name, reachable_by_input, input_nodes_set,
                               name_to_input_name)
        nodes_deleted_by_fuse.add(base_name)

    # Make a new graphdef with all the pre-input and input nodes, in their
    # original order.
    out = _graph_pb2.GraphDef()
    for name in sorted(reachable_by_input,
                       key=lambda name: name_to_seq_num[name]):
        out.node.extend([_copy.deepcopy(name_to_node[name])])

    # Create any stacks to aggregate arguments into to a single input
    # i.e. for static_rnn's.
    # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1
    sorted_input_indices = sorted(call.inputs.keys())
    sorted_output_indices = sorted(call.outputs.keys())

    new_node = _node_def_pb2.NodeDef()
    # Delegate to each operand to produce the proper new input for this stub
    # node. In particular, an aggregate input will now be a Pack of some
    # previously non-fused things.
    for input_index in sorted_input_indices:
        operand = call.inputs[input_index]
        new_node.input.append(operand.aggregate_and_return_name_for_input(out))
    new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(
        sorted_input_indices)

    # Create the function
    new_node.op = call.function_name
    new_node.name = call.uuid
    # NOTE(review): if `extend` stores a copy of `new_node` (as protobuf
    # repeated message fields appear to), the attrs assigned further down are
    # not reflected in `out` -- confirm whether that is intended.
    out.node.extend([new_node])

    # Now call each output argument to give them a chance to make the proper
    # output type and add it to our new_node.
    output_dtypes = [
        call.outputs[output_index].aggregate_and_return_name_for_output(
            new_node.name, output_index, out)
        for output_index in sorted_output_indices
    ]
    new_node.attr["_output_types"].list.type[:] = output_dtypes
    # TODO(aselle): what is right here?
    new_node.attr["_output_quantized"].b = False

    # Add post output nodes that do not depend on anything the fuse deleted.
    for name in nodes_after_fuse:
        uses_deleted_node = any(
            input_name in nodes_deleted_by_fuse
            for input_name in name_to_input_name[name])
        if not uses_deleted_node:
            out.node.extend([_copy.deepcopy(name_to_node[name])])

    # Misc. graph_def data that needs copying.
    out.library.CopyFrom(graph_def.library)
    out.versions.CopyFrom(graph_def.versions)

    return out
def _convert_single_op_hint_to_stub(call, graph_def):
    """Replace the subgraph described by `call` with a single stub node.

    Args:
      call: A single function call to be converted.
      graph_def: A graph_def to use as input (that contains `call`).

    Returns:
      A new transformed graph-def that has call as a stub (single op).

    Note: after this process, the graph_def can no longer be loaded into
        the tensorflow runtime, so all future manipulations are done in
        graph_def level.
    """
    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
        graph_def)
    input_names, output_names = call.flattened_inputs_and_outputs()

    reachable_by_input = _bfs_for_reachable_nodes(input_names,
                                                  name_to_input_name)
    reachable_by_output = _bfs_for_reachable_nodes(output_names,
                                                   name_to_input_name)
    input_nodes_set = set(input_names)
    output_nodes_set = set(output_names)

    # Classify each node. Everything reachable by input is kept below; here we
    # separate the nodes swallowed by the stub from the nodes that come after
    # the fused region.
    nodes_after_fuse = []
    nodes_deleted_by_fuse = set()
    for node in graph_def.node:
        base_name = _tensor_name_base(node.name)
        if base_name not in reachable_by_output:
            if base_name not in reachable_by_input:
                # Comes after all the fusing, so it is a candidate to keep.
                nodes_after_fuse.append(base_name)
            # Otherwise the node is not connected to the chain of
            # dependencies handled here; nothing to record.
            continue
        if base_name in reachable_by_input or base_name in output_nodes_set:
            continue
        # Internal node. Check to make sure it is really internal.
        # TODO(aselle): this could be done more efficiently by flooding
        # the graph first.
        _check_subgraph_closed(base_name, reachable_by_input, input_nodes_set,
                               name_to_input_name)
        nodes_deleted_by_fuse.add(base_name)

    # Make a new graphdef with all the pre-input and input nodes, in their
    # original order.
    out = _graph_pb2.GraphDef()
    for name in sorted(reachable_by_input,
                       key=lambda name: name_to_seq_num[name]):
        out.node.extend([_copy.deepcopy(name_to_node[name])])

    # Create any stacks to aggregate arguments into to a single input
    # i.e. for static_rnn's.
    # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1
    sorted_input_indices = sorted(call.inputs.keys())
    sorted_output_indices = sorted(call.outputs.keys())

    new_node = _node_def_pb2.NodeDef()
    # Delegate to each operand to produce the proper new input for this stub
    # node. In particular, an aggregate input will now be a Pack of some
    # previously non-fused things.
    for input_index in sorted_input_indices:
        operand = call.inputs[input_index]
        new_node.input.append(operand.aggregate_and_return_name_for_input(out))
    new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(
        sorted_input_indices)

    # Create the function
    new_node.op = call.function_name
    new_node.name = call.uuid
    # NOTE(review): if `extend` stores a copy of `new_node` (as protobuf
    # repeated message fields appear to), the attrs assigned further down are
    # not reflected in `out` -- confirm whether that is intended.
    out.node.extend([new_node])

    # Now call each output argument to give them a chance to make the proper
    # output type and add it to our new_node.
    output_dtypes = [
        call.outputs[output_index].aggregate_and_return_name_for_output(
            new_node.name, output_index, out)
        for output_index in sorted_output_indices
    ]
    new_node.attr["_output_types"].list.type[:] = output_dtypes
    # TODO(aselle): what is right here?
    new_node.attr["_output_quantized"].b = False

    # Add post output nodes that do not depend on anything the fuse deleted.
    for name in nodes_after_fuse:
        uses_deleted_node = any(
            input_name in nodes_deleted_by_fuse
            for input_name in name_to_input_name[name])
        if not uses_deleted_node:
            out.node.extend([_copy.deepcopy(name_to_node[name])])

    # Misc. graph_def data that needs copying.
    out.library.CopyFrom(graph_def.library)
    out.versions.CopyFrom(graph_def.versions)

    return out
def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
            output_quantized, op_name, op_type):
    """Fuse subgraph between input_nodes and output_nodes into a single custom op.

    The result contains, in order: every node up to and including
    `input_nodes`, the new custom op (fed by `input_nodes`), a copy of each
    output node rewired to read from the custom op, and finally the nodes that
    are not needed by any output node.

    Args:
      graph_def: A graph_pb2.GraphDef proto.
      input_nodes: input nodes to the subgraph to be fused.
      output_nodes: output nodes to the subgraph to be fused.
      output_dtypes: A list of output datatypes for the custom op
      output_quantized: A boolean flag that indicates if output is quantized
      op_name: fused op name.
      op_type: fused op type.

    Returns:
      The GraphDef of the new graph.

    Raises:
      TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto, if
        `input_nodes` or `output_nodes` is a string instead of a list, or if a
        node inside the fused region depends on a node that precedes the
        inputs without being listed in `input_nodes`.
      AssertionError: If an output node does not have exactly one input.
    """
    if not isinstance(graph_def, graph_pb2.GraphDef):
        raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

    if isinstance(input_nodes, six.string_types):
        raise TypeError("input_nodes must be a list.")

    if isinstance(output_nodes, six.string_types):
        raise TypeError("output_nodes must be a list.")

    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
        graph_def)
    _assert_nodes_are_present(name_to_node, input_nodes + output_nodes)

    # Nodes upto and including input_nodes
    reachable_by_input = _bfs_for_reachable_nodes(input_nodes,
                                                  name_to_input_name)
    # Nodes upto and including output_nodes
    reachable_by_output = _bfs_for_reachable_nodes(output_nodes,
                                                   name_to_input_name)

    input_nodes_set = set(input_nodes)
    output_nodes_set = set(output_nodes)

    nodes_post_output = []
    for node in graph_def.node:
        n = _node_name(node.name)
        if n not in reachable_by_output:
            nodes_post_output.append(n)
            continue
        if n in reachable_by_input or n in output_nodes_set:
            continue

        # n is between input and output, i.e., part of the fused op. Walk its
        # dependencies breadth-first and make sure the walk only leaves the
        # fused region through one of the declared input nodes.
        #
        # Each node is enqueued at most once: without the `seen` set, shared
        # ancestors are expanded once per path leading to them (exponential
        # on diamond-shaped graphs) and a cycle would never terminate. The
        # order in which nodes are first visited is unchanged, so the node
        # reported in the error is the same.
        pending = [n]
        seen = {n}
        cursor = 0  # Index cursor: avoids the O(n) cost of popping the head.
        while cursor < len(pending):
            cur_node = pending[cursor]
            cursor += 1
            if cur_node in input_nodes_set:
                # Declared boundary of the fused op; do not look past it.
                continue
            if cur_node in reachable_by_input:
                raise TypeError("Node %s uses input %s not in input_nodes."
                                % (n, cur_node))
            for input_name in name_to_input_name[cur_node]:
                if input_name not in seen:
                    seen.add(input_name)
                    pending.append(input_name)

    # Add all nodes upto the input nodes, in their original graph order
    out = graph_pb2.GraphDef()
    reachable_by_input_sorted = sorted(
        reachable_by_input, key=lambda name: name_to_seq_num[name])
    for name in reachable_by_input_sorted:
        out.node.extend([copy.deepcopy(name_to_node[name])])

    # Add the custom op
    new_node = node_def_pb2.NodeDef()
    new_node.input.extend(input_nodes)
    new_node.attr["_output_types"].list.type[:] = output_dtypes
    new_node.attr["_output_quantized"].b = output_quantized
    new_node.op = op_type
    new_node.name = op_name
    out.node.extend([new_node])

    # Add the nodes in the output of the custom op
    for index, n in enumerate(output_nodes):
        assert len(name_to_node[n].input) == 1
        new_node = copy.deepcopy(name_to_node[n])
        del new_node.input[:]
        # Output 0 is addressed by the bare op name, output k>0 as "name:k".
        new_node.input.append(
            op_name + (":" + str(index) if index != 0 else ""))
        out.node.extend([new_node])

    # Add the nodes post output_nodes
    for n in nodes_post_output:
        out.node.extend([copy.deepcopy(name_to_node[n])])

    out.library.CopyFrom(graph_def.library)
    out.versions.CopyFrom(graph_def.versions)
    return out