def bonsai_parser(model, model_in):
    """Full parsing function, handling the route layers.

    First runs the simple (sequential) parse, then traces the model graph to
    find ``onnx::Concat`` / ``onnx::Add`` nodes and inserts a matching
    ``'route'`` / ``'residual_add'`` module (with a ``layers`` param) into the
    parsed model for each of them.

    Args:
        model: pytorch model to be processed
        model_in: model input; its ``.size()`` is given to the simple parser
            and the tensor itself is used to trace the graph.

    Returns:
        A complete BonsaiParsedModel
    """
    # Simple parsing, without the routing layers
    bonsai_parsed_model = parse_simple_model(model, model_in.size())

    # Getting the graph that represents the underlying network connectivity
    gd = pg.graph(model, args=(model_in,))
    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(gd[0])

    # Convert node numbers to their short weight name
    graph_layers_to_weights = {get_node_name(k): v for k, v in name_to_seq_num.items()}

    # Route layers: full node name -> operation name ('Concat' or 'Add')
    route_layers = {k: v.op.split('::')[1] for k, v in name_to_node.items()
                    if v.op in ['onnx::Concat', 'onnx::Add']}
    route_weight_names = [get_node_name(x) for x in route_layers]
    # Operation keyed by short weight name. The dicts derived below are keyed
    # by short name and may be filtered or collapsed relative to route_layers,
    # so each operation must travel with its key instead of being paired up
    # positionally at the end.
    route_ops = {get_node_name(k): op for k, op in route_layers.items()}

    # matching node (full name, node shortened name, and previous nodes connected by the graph)
    raw_predecessors = {(k, get_node_name(k)): [get_node_name(x) for x in in_names_list]
                        for k, in_names_list in name_to_input_name.items()}
    # removing duplicates and empty strings
    predecessors = {weight: list(set([x for x in weight_list if len(x) > 0]))
                    for (_name, weight), weight_list in raw_predecessors.items()}

    # removing nodes that are intermediate values, they don't correspond to graph layers
    real_nodes = list(set(bonsai_parsed_model.get_weight_names())) + route_weight_names
    real_predecessors = get_real_predecessors(predecessors.copy(), real_nodes)

    # getting the relevant layers for the routing computation
    route_real_predecessors = {k: v for k, v in real_predecessors.items()
                               if k in route_weight_names}
    route_predecessors_layers = {k: bonsai_parsed_model.get_layers_by_weights(v)
                                 for k, v in route_real_predecessors.items()}

    # computing the layer number of the generated route layer
    # here we set it to be 1 after the node that is previous to it in the GraphDef graph
    graph_connection = {graph_layers_to_weights[k]: [graph_layers_to_weights[x] for x in v]
                        for k, v in route_real_predecessors.items()}
    prev_index = {k: v.index(int(k) - 1) for k, v in graph_connection.items()}
    res_layers = {}
    res_ops = {}
    for k, v in route_predecessors_layers.items():
        insert_at = v[prev_index[graph_layers_to_weights[k]]] + 1
        res_layers[insert_at] = v
        res_ops[insert_at] = route_ops[k]

    # keeping in mind that layer indices shift when we add new layers
    shifted_values = {k: [val + len([x for x in res_layers if int(x) < int(val)]) for val in v]
                      for k, v in res_layers.items()}
    final_layers = {}
    final_ops = {}
    for k, v in shifted_values.items():
        final_k = int(k) + len([x for x in shifted_values if int(x) < int(k)])
        final_layers[final_k] = v
        final_ops[final_k] = res_ops[k]

    # adding the layers to the model; route_layers only holds Concat/Add ops
    module_for_op = {'Concat': 'route', 'Add': 'residual_add'}
    for k, v in final_layers.items():
        bonsai_parsed_model.insert_module(int(k), module_for_op[final_ops[k]])
        bonsai_parsed_model.insert_param(int(k), 'layers', str(v))

    return bonsai_parsed_model
def _find_children_hints(call, graph_def):
    """Find all children hints.

    For a given OpHint, we find all children hints inside it, we also copy all
    the nodes inside function defs (if applicable) to the original graph_def,
    they are returned in a list as well.

    Only nodes strictly between the hint's inputs and outputs are inspected,
    and only functional while loops (`While` and `StatelessWhile`) are
    descended into: their `body` function def is looked up in the graph's
    library and searched for children hints.

    Args:
      call: Parent OpHint that contains children ophints.
      graph_def: Original graph def.

    Returns:
      Ordered children hints inside the parent ophint; new graph def that
      contains nodes inside function defs (if applicable); nodes inside
      function defs.
    """
    name_to_input_name, _, _ = _extract_graph_summary(graph_def)
    input_names, output_names = call.flattened_inputs_and_outputs()

    reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
    reachable_by_output = _bfs_for_reachable_nodes(output_names, name_to_input_name)
    output_nodes_set = set(output_names)
    children_hints = []
    out = _graph_pb2.GraphDef()
    out.library.CopyFrom(graph_def.library)
    out.versions.CopyFrom(graph_def.versions)
    function_def_nodes = set()
    for node in graph_def.node:
        out.node.extend([_copy.deepcopy(node)])
        n = _tensor_name_base(node.name)
        if n in reachable_by_output:
            if n not in reachable_by_input and n not in output_nodes_set:
                # special handle for while loop function def; both op types
                # carry the loop body in the "body" func attr.
                if node.op in ("While", "StatelessWhile"):
                    body_name = node.attr["body"].func.name
                    inputs_outside_loop = node.input
                    for function_def in graph_def.library.function:
                        if function_def.signature.name == body_name:
                            function_inputs = function_def.signature.input_arg
                            assert len(inputs_outside_loop) == len(function_inputs)
                            # Map each loop-body argument name to the tensor
                            # that feeds it from outside the loop.
                            nodes_mapping = {}
                            for i, function_input in enumerate(function_inputs):
                                nodes_mapping[function_input.name] = inputs_outside_loop[i]
                            # TODO(b/123050804): Consider use grappler.
                            (children_hints_in_loop,
                             new_nodes) = _find_children_hints_in_while_loop(
                                 function_def, nodes_mapping)
                            function_def_nodes.update([x.name for x in new_nodes])
                            children_hints.extend(children_hints_in_loop)
                            out.node.extend(new_nodes)

    return children_hints, out, function_def_nodes
def _getGraphOpTypes(self, graphdef, output_nodes):
    """Collect the op types of every node in `graphdef` that `output_nodes` need.

    Intended for checking that the expected ops survive the stub
    transformation. NOTE: this is not an exact test that the graph is the
    correct output; it trades precision for a compact sanity check.

    Args:
      graphdef: TensorFlow proto graphdef.
      output_nodes: A list of output node names that we need to reach.

    Returns:
      A set of node types reachable from `output_nodes`.
    """
    inputs_by_name, nodes_by_name, _ = _extract_graph_summary(graphdef)
    # Walk backwards from the outputs to everything they depend on.
    needed_names = _bfs_for_reachable_nodes(output_nodes, inputs_by_name)
    return {nodes_by_name[name].op for name in needed_names}
def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
            output_quantized, op_name, op_type):
    """Fuse subgraph between input_nodes and output_nodes into a single custom op.

    Args:
      graph_def: A graph_pb2.GraphDef proto.
      input_nodes: input nodes to the subgraph to be fused.
      output_nodes: output nodes to the subgraph to be fused.
      output_dtypes: A list of output datatypes for the custom op
      output_quantized: A boolean flag that indicates if output is quantized
      op_name: fused op name.
      op_type: fused op type.

    Returns:
      The GraphDef of the new graph.

    Raises:
      TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto, if
        `input_nodes` / `output_nodes` is a string instead of a list, or if a
        node inside the fused region consumes a node that is upstream of the
        inputs but not listed in `input_nodes`.
    """
    if not isinstance(graph_def, graph_pb2.GraphDef):
        raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

    if isinstance(input_nodes, six.string_types):
        raise TypeError("input_nodes must be a list.")

    if isinstance(output_nodes, six.string_types):
        raise TypeError("output_nodes must be a list.")

    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
        graph_def)
    _assert_nodes_are_present(name_to_node, input_nodes + output_nodes)

    # Nodes upto and including input_nodes
    reachable_by_input = _bfs_for_reachable_nodes(input_nodes, name_to_input_name)
    # Nodes upto and including output_nodes
    reachable_by_output = _bfs_for_reachable_nodes(output_nodes, name_to_input_name)

    # Set of nodes in the list input_nodes
    input_nodes_set = set(input_nodes)

    # Set of nodes in the list output_nodes
    output_nodes_set = set(output_nodes)

    # Nodes already examined by the closed-region check below. Shared across
    # all start nodes so each node is checked at most once; without it the
    # walk re-expands shared ancestors (exponential on diamond-shaped graphs)
    # and never terminates on a cycle that does not pass through an input.
    verified = set()
    nodes_post_output = []
    for node in graph_def.node:
        n = _node_name(node.name)
        if n in reachable_by_output:
            if n not in reachable_by_input and n not in output_nodes_set:
                # n is between input and output, i.e., part of the fused op.
                # Walk its ancestors (stopping at input nodes) to make sure the
                # region only depends on the declared inputs.
                to_visit = [n]
                while to_visit:
                    cur_node = to_visit.pop()
                    if cur_node in verified:
                        continue
                    if cur_node in reachable_by_input and cur_node not in input_nodes_set:
                        raise TypeError("Node %s uses input %s not in input_nodes." %
                                        (n, cur_node))
                    verified.add(cur_node)
                    if cur_node not in input_nodes_set:
                        to_visit.extend(name_to_input_name[cur_node])
        else:
            # Not an ancestor of any output: kept after the fused op.
            nodes_post_output.append(n)

    # Add all nodes upto the input nodes, in original graph order
    out = graph_pb2.GraphDef()
    reachable_by_input_sorted = sorted(
        reachable_by_input, key=lambda name: name_to_seq_num[name])
    for node in reachable_by_input_sorted:
        out.node.extend([copy.deepcopy(name_to_node[node])])

    # Add the custom op. All fields are set before extend() because a
    # repeated message field stores a copy of the appended message.
    new_node = node_def_pb2.NodeDef()
    new_node.input.extend(input_nodes)
    new_node.attr["_output_types"].list.type[:] = output_dtypes
    new_node.attr["_output_quantized"].b = output_quantized
    new_node.op = op_type
    new_node.name = op_name
    out.node.extend([new_node])

    # Add the nodes in the output of the custom op, rewired to read output
    # `index` of the fused op ("op_name" for 0, "op_name:<index>" otherwise).
    for index, n in enumerate(output_nodes):
        assert len(name_to_node[n].input) == 1
        new_node = copy.deepcopy(name_to_node[n])
        del new_node.input[:]
        new_node.input.append(op_name + (":" + str(index) if index != 0 else ""))
        out.node.extend([new_node])

    # Add the nodes post output_nodes
    for n in nodes_post_output:
        out.node.extend([copy.deepcopy(name_to_node[n])])

    out.library.CopyFrom(graph_def.library)
    out.versions.CopyFrom(graph_def.versions)
    return out
def _remove_one_redundant_stack_unstack(in_graph_def):
    """Removes a stack->unstack pattern from in_graph_def in a returned graph.

    For each Pack / "OpHintStack*" node, walks upstream through Identity,
    Pack and "OpHintStack*" nodes. The walk must end only in Unpack /
    "OpHintUnstack*" nodes; any other op aborts the match. If exactly one
    unpack node is found and the pair is removable, a new graph is built.
    That graph contains every node outside the visited chain, and each
    reference to the pack node is rewired to the unpack node's first input.
    At most one chain is removed per call, because the function returns
    right after the first rewrite.

    Args:
      in_graph_def: Graph def to use as input.

    Returns:
      Simplified tuple (graph_def, changed_something) where changed_something
      is true if anything was done. When nothing was changed, `in_graph_def`
      itself is returned.
    """
    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
        in_graph_def)
    del name_to_seq_num

    # TODO(aselle): Make this not hardcoded.
    # When True, plain Pack/Unpack ops match too, not only OpHint-created ones.
    do_generic_pack_unpack = True

    out = _graph_pb2.GraphDef()
    out.library.CopyFrom(in_graph_def.library)
    out.versions.CopyFrom(in_graph_def.versions)
    for n in in_graph_def.node:
        node_name = _tensor_name_base(n.name)
        # Only stack-like nodes can start a candidate chain.
        if not node_name.startswith("OpHintStack") and not n.op.startswith("Pack"):
            continue
        next_to_visit = [node_name]
        visited = set()

        unpack_nodes = set()
        pack_node = node_name

        # Find a pattern of unstack connected to a stack (with identities
        # in between).
        matches_pattern = True
        # Set when an OpHintStack node is seen; cleared by any unpack node that
        # is not an OpHintUnstack. The result depends on visit order.
        is_hint_created_stack = False
        while next_to_visit:
            current_node_name = next_to_visit[0]
            visited.add(current_node_name)
            del next_to_visit[0]
            node = name_to_node[current_node_name]
            is_op_hint_stack = node.name.startswith("OpHintStack")
            is_op_hint_unstack = node.name.startswith("OpHintUnstack")
            if (node.op == "Identity" or is_op_hint_stack or
                    (do_generic_pack_unpack and node.op == "Pack")):
                # Pass-through / stack node: keep walking towards its inputs.
                is_hint_created_stack |= is_op_hint_stack
                next_to_visit += [
                    input_node for input_node in name_to_input_name[current_node_name]
                    if input_node not in visited
                ]
            elif (is_op_hint_unstack or
                  (do_generic_pack_unpack and node.op == "Unpack")):
                # Chain terminator: record it and do not walk past it.
                unpack_nodes.add(node.name)
                is_hint_created_stack &= is_op_hint_unstack
            else:
                # Any other op means this is not a pure stack/unstack chain.
                matches_pattern = False
                break
            visited.add(node.name)

        if matches_pattern and len(unpack_nodes) == 1:
            pack_node = node_name

            # Check to see if anyone depends on the intermediate identity or the
            # Unstacked form (i.e. a node outside the chain consumes any chain
            # node other than the pack node itself).
            no_external_dependency = True
            for other_n in in_graph_def.node:
                if other_n.name in visited:
                    continue
                for input_tensor in name_to_input_name[other_n.name]:
                    input_op = _tensor_name_base(input_tensor)
                    if input_op in visited and input_op != pack_node:
                        no_external_dependency = False
            # Proceed with the substitution if the stack/unstack pair was created
            # through hints, or that it was not, but nobody is consuming things
            # between the stack and unstack.
            if is_hint_created_stack or no_external_dependency:
                end = unpack_nodes.pop()
                end_input = name_to_node[end].input[0]
                # All nodes that depend on the final stack need to be redone to
                # use end_input (the unpack's first input) instead; chain nodes
                # (everything in `visited`) are dropped from the new graph.
                for other_n in in_graph_def.node:
                    node_name = _tensor_name_base(other_n.name)
                    if node_name not in visited:
                        new_node = _copy.deepcopy(other_n)
                        # NOTE(review): assumes name_to_input_name[node_name]
                        # lists the base name of each entry of new_node.input in
                        # the same order (zip truncates otherwise) - confirm
                        # against _extract_graph_summary.
                        new_node.input[:] = [
                            (end_input if stripped == pack_node else non_stripped)
                            for stripped, non_stripped in zip(
                                name_to_input_name[node_name], new_node.input[:])
                        ]
                        out.node.extend([new_node])
                return out, True
    return in_graph_def, False
def _convert_single_op_hint_to_stub(call, graph_def):
    """Given a graph_def, converts `call` into a stub and returns a new graph_def.

    Args:
      call: A single function call to be converted.
      graph_def: A graph_def to use as input (that has call obviously).

    Returns:
      A new transformed graph-def that has call as a stub (single op).

    Note: after this process, the graph_def can no longer be loaded into
      the tensorflow runtime, so all future manipulations are done in graph_def
      level.
    """
    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
        graph_def)
    input_names, output_names = call.flattened_inputs_and_outputs()

    reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
    reachable_by_output = _bfs_for_reachable_nodes(output_names, name_to_input_name)
    input_nodes_set = set(input_names)
    output_nodes_set = set(output_names)
    nodes_after_fuse = []
    nodes_deleted_by_fuse = set()
    # Classify each node. Internal nodes of the hinted subgraph are deleted.
    # Nodes reachable by neither the inputs nor the outputs (downstream or
    # unrelated nodes) are kept after the stub. Nodes reachable by the inputs
    # are copied first, below.
    for node in graph_def.node:
        n = _tensor_name_base(node.name)
        if n in reachable_by_output:
            if n not in reachable_by_input and n not in output_nodes_set:
                # n is an internal node. Check to make sure it is really internal.
                # TODO(aselle): this could be done more efficiently by flooding
                # the graph first.
                _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
                                       name_to_input_name)
                nodes_deleted_by_fuse.add(n)
        elif n not in reachable_by_input:
            # n is a node that comes after all the fusings, so keep it.
            nodes_after_fuse.append(n)
        # Otherwise n is reachable by the inputs only; it is not part of the
        # fused chain and is kept by the reachable_by_input copy below.

    # Make a new graphdef with all the pre-input and input nodes
    out = _graph_pb2.GraphDef()
    reachable_by_input_sorted = sorted(
        reachable_by_input, key=lambda name: name_to_seq_num[name])
    for node in reachable_by_input_sorted:
        out.node.extend([_copy.deepcopy(name_to_node[node])])

    # Create any stacks to aggregate arguments into a single input
    # i.e. for static_rnn's.
    # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1
    sorted_input_indices = sorted(call.inputs.keys())
    sorted_output_indices = sorted(call.outputs.keys())
    new_node = _node_def_pb2.NodeDef()
    # Delegate to each operand to produce the proper new input for this stub node.
    # In particular, an aggregate input will now be a Pack of some previously
    # non-fused things.
    for input_index in sorted_input_indices:
        inputs = call.inputs[input_index]
        new_node.input.append(inputs.aggregate_and_return_name_for_input(out))
    new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices)

    # Create the function
    new_node.op = call.function_name
    new_node.name = call.uuid
    out.node.extend([new_node])
    # extend() on a repeated message field stores a *copy*. Keep editing the
    # stored element so the output attrs set below actually land in `out`.
    new_node = out.node[-1]

    # Now call each output argument to give them a chance to make the proper
    # output type and add it to our new_node.
    output_dtypes = []
    for output_index in sorted_output_indices:
        output = call.outputs[output_index]
        output_dtype = output.aggregate_and_return_name_for_output(
            new_node.name, output_index, out)
        output_dtypes.append(output_dtype)
    new_node.attr["_output_types"].list.type[:] = output_dtypes
    # TODO(aselle): what is right here?
    new_node.attr["_output_quantized"].b = False

    # Add post output nodes that do not depend on the deleted internal nodes
    for n in nodes_after_fuse:
        should_keep = not any(input_name in nodes_deleted_by_fuse
                              for input_name in name_to_input_name[n])
        if should_keep:
            out.node.extend([_copy.deepcopy(name_to_node[n])])

    # Misc. graph_def data that needs copying.
    out.library.CopyFrom(graph_def.library)
    out.versions.CopyFrom(graph_def.versions)

    return out
def _remove_one_redundant_stack_unstack(in_graph_def):
    """Removes a stack->unstack pattern from in_graph_def in a returned graph.

    For each Pack / "OpHintStack*" node, walks upstream through Identity,
    Pack and "OpHintStack*" nodes. The walk must end only in Unpack /
    "OpHintUnstack*" nodes; any other op aborts the match. If exactly one
    unpack node is found and the pair is removable, a new graph is built.
    That graph contains every node outside the visited chain, and each
    reference to the pack node is rewired to the unpack node's first input.
    At most one chain is removed per call, because the function returns
    right after the first rewrite.

    Args:
      in_graph_def: Graph def to use as input.

    Returns:
      Simplified tuple (graph_def, changed_something) where changed_something
      is true if anything was done. When nothing was changed, `in_graph_def`
      itself is returned.
    """
    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
        in_graph_def)
    del name_to_seq_num

    # TODO(aselle): Make this not hardcoded.
    # When True, plain Pack/Unpack ops match too, not only OpHint-created ones.
    do_generic_pack_unpack = True

    out = _graph_pb2.GraphDef()
    out.library.CopyFrom(in_graph_def.library)
    out.versions.CopyFrom(in_graph_def.versions)
    for n in in_graph_def.node:
        node_name = _tensor_name_base(n.name)
        # Only stack-like nodes can start a candidate chain.
        if not node_name.startswith("OpHintStack") and not n.op.startswith(
                "Pack"):
            continue
        next_to_visit = [node_name]
        visited = set()

        unpack_nodes = set()
        pack_node = node_name

        # Find a pattern of unstack connected to a stack (with identities
        # in between).
        matches_pattern = True
        # Set when an OpHintStack node is seen; cleared by any unpack node that
        # is not an OpHintUnstack. The result depends on visit order.
        is_hint_created_stack = False
        while next_to_visit:
            current_node_name = next_to_visit[0]
            visited.add(current_node_name)
            del next_to_visit[0]
            node = name_to_node[current_node_name]
            is_op_hint_stack = node.name.startswith("OpHintStack")
            is_op_hint_unstack = node.name.startswith("OpHintUnstack")
            if (node.op == "Identity" or is_op_hint_stack or
                    (do_generic_pack_unpack and node.op == "Pack")):
                # Pass-through / stack node: keep walking towards its inputs.
                is_hint_created_stack |= is_op_hint_stack
                next_to_visit += [
                    input_node for input_node in name_to_input_name[current_node_name]
                    if input_node not in visited
                ]
            elif (is_op_hint_unstack or
                  (do_generic_pack_unpack and node.op == "Unpack")):
                # Chain terminator: record it and do not walk past it.
                unpack_nodes.add(node.name)
                is_hint_created_stack &= is_op_hint_unstack
            else:
                # Any other op means this is not a pure stack/unstack chain.
                matches_pattern = False
                break
            visited.add(node.name)

        if matches_pattern and len(unpack_nodes) == 1:
            pack_node = node_name

            # Check to see if anyone depends on the intermediate identity or the
            # Unstacked form (i.e. a node outside the chain consumes any chain
            # node other than the pack node itself).
            no_external_dependency = True
            for other_n in in_graph_def.node:
                if other_n.name in visited:
                    continue
                for input_tensor in name_to_input_name[other_n.name]:
                    input_op = _tensor_name_base(input_tensor)
                    if input_op in visited and input_op != pack_node:
                        no_external_dependency = False
            # Proceed with the substitution if the stack/unstack pair was created
            # through hints, or that it was not, but nobody is consuming things
            # between the stack and unstack.
            if is_hint_created_stack or no_external_dependency:
                end = unpack_nodes.pop()
                end_input = name_to_node[end].input[0]
                # All nodes that depend on the final stack need to be redone to
                # use end_input (the unpack's first input) instead; chain nodes
                # (everything in `visited`) are dropped from the new graph.
                for other_n in in_graph_def.node:
                    node_name = _tensor_name_base(other_n.name)
                    if node_name not in visited:
                        new_node = _copy.deepcopy(other_n)
                        # NOTE(review): assumes name_to_input_name[node_name]
                        # lists the base name of each entry of new_node.input in
                        # the same order (zip truncates otherwise) - confirm
                        # against _extract_graph_summary.
                        new_node.input[:] = [(end_input if stripped == pack_node else
                                              non_stripped)
                                             for stripped, non_stripped in zip(
                                                 name_to_input_name[node_name],
                                                 new_node.input[:])]
                        out.node.extend([new_node])
                return out, True
    return in_graph_def, False
def _convert_single_op_hint_to_stub(call, graph_def):
    """Given a graph_def, converts `call` into a stub and returns a new graph_def.

    Args:
      call: A single function call to be converted.
      graph_def: A graph_def to use as input (that has call obviously).

    Returns:
      A new transformed graph-def that has call as a stub (single op).

    Note: after this process, the graph_def can no longer be loaded into
      the tensorflow runtime, so all future manipulations are done in graph_def
      level.
    """
    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
        graph_def)
    input_names, output_names = call.flattened_inputs_and_outputs()

    reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
    reachable_by_output = _bfs_for_reachable_nodes(output_names, name_to_input_name)
    input_nodes_set = set(input_names)
    output_nodes_set = set(output_names)
    nodes_after_fuse = []
    nodes_deleted_by_fuse = set()
    # Classify each node. Internal nodes of the hinted subgraph are deleted.
    # Nodes reachable by neither the inputs nor the outputs (downstream or
    # unrelated nodes) are kept after the stub. Nodes reachable by the inputs
    # are copied first, below.
    for node in graph_def.node:
        n = _tensor_name_base(node.name)
        if n in reachable_by_output:
            if n not in reachable_by_input and n not in output_nodes_set:
                # n is an internal node. Check to make sure it is really internal.
                # TODO(aselle): this could be done more efficiently by flooding
                # the graph first.
                _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
                                       name_to_input_name)
                nodes_deleted_by_fuse.add(n)
        elif n not in reachable_by_input:
            # n is a node that comes after all the fusings, so keep it.
            nodes_after_fuse.append(n)
        # Otherwise n is reachable by the inputs only; it is not part of the
        # fused chain and is kept by the reachable_by_input copy below.

    # Make a new graphdef with all the pre-input and input nodes
    out = _graph_pb2.GraphDef()
    reachable_by_input_sorted = sorted(
        reachable_by_input, key=lambda name: name_to_seq_num[name])
    for node in reachable_by_input_sorted:
        out.node.extend([_copy.deepcopy(name_to_node[node])])

    # Create any stacks to aggregate arguments into a single input
    # i.e. for static_rnn's.
    # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1
    sorted_input_indices = sorted(call.inputs.keys())
    sorted_output_indices = sorted(call.outputs.keys())
    new_node = _node_def_pb2.NodeDef()
    # Delegate to each operand to produce the proper new input for this stub node.
    # In particular, an aggregate input will now be a Pack of some previously
    # non-fused things.
    for input_index in sorted_input_indices:
        inputs = call.inputs[input_index]
        new_node.input.append(inputs.aggregate_and_return_name_for_input(out))
    new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices)

    # Create the function
    new_node.op = call.function_name
    new_node.name = call.uuid
    out.node.extend([new_node])
    # extend() on a repeated message field stores a *copy*. Keep editing the
    # stored element so the output attrs set below actually land in `out`.
    new_node = out.node[-1]

    # Now call each output argument to give them a chance to make the proper
    # output type and add it to our new_node.
    output_dtypes = []
    for output_index in sorted_output_indices:
        output = call.outputs[output_index]
        output_dtype = output.aggregate_and_return_name_for_output(
            new_node.name, output_index, out)
        output_dtypes.append(output_dtype)
    new_node.attr["_output_types"].list.type[:] = output_dtypes
    # TODO(aselle): what is right here?
    new_node.attr["_output_quantized"].b = False

    # Add post output nodes that do not depend on the deleted internal nodes
    for n in nodes_after_fuse:
        should_keep = not any(input_name in nodes_deleted_by_fuse
                              for input_name in name_to_input_name[n])
        if should_keep:
            out.node.extend([_copy.deepcopy(name_to_node[n])])

    # Misc. graph_def data that needs copying.
    out.library.CopyFrom(graph_def.library)
    out.versions.CopyFrom(graph_def.versions)

    return out
def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
            output_quantized, op_name, op_type):
    """Fuse subgraph between input_nodes and output_nodes into a single custom op.

    Args:
      graph_def: A graph_pb2.GraphDef proto.
      input_nodes: input nodes to the subgraph to be fused.
      output_nodes: output nodes to the subgraph to be fused.
      output_dtypes: A list of output datatypes for the custom op
      output_quantized: A boolean flag that indicates if output is quantized
      op_name: fused op name.
      op_type: fused op type.

    Returns:
      The GraphDef of the new graph.

    Raises:
      TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto, if
        `input_nodes` / `output_nodes` is a string instead of a list, or if a
        node inside the fused region consumes a node that is upstream of the
        inputs but not listed in `input_nodes`.
    """
    if not isinstance(graph_def, graph_pb2.GraphDef):
        raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

    if isinstance(input_nodes, six.string_types):
        raise TypeError("input_nodes must be a list.")

    if isinstance(output_nodes, six.string_types):
        raise TypeError("output_nodes must be a list.")

    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
        graph_def)
    _assert_nodes_are_present(name_to_node, input_nodes + output_nodes)

    # Nodes upto and including input_nodes
    reachable_by_input = _bfs_for_reachable_nodes(input_nodes, name_to_input_name)
    # Nodes upto and including output_nodes
    reachable_by_output = _bfs_for_reachable_nodes(output_nodes, name_to_input_name)

    # Set of nodes in the list input_nodes
    input_nodes_set = set(input_nodes)

    # Set of nodes in the list output_nodes
    output_nodes_set = set(output_nodes)

    # Nodes already examined by the closed-region check below. Shared across
    # all start nodes so each node is checked at most once; without it the
    # walk re-expands shared ancestors (exponential on diamond-shaped graphs)
    # and never terminates on a cycle that does not pass through an input.
    verified = set()
    nodes_post_output = []
    for node in graph_def.node:
        n = _node_name(node.name)
        if n in reachable_by_output:
            if n not in reachable_by_input and n not in output_nodes_set:
                # n is between input and output, i.e., part of the fused op.
                # Walk its ancestors (stopping at input nodes) to make sure the
                # region only depends on the declared inputs.
                to_visit = [n]
                while to_visit:
                    cur_node = to_visit.pop()
                    if cur_node in verified:
                        continue
                    if cur_node in reachable_by_input and cur_node not in input_nodes_set:
                        raise TypeError("Node %s uses input %s not in input_nodes." %
                                        (n, cur_node))
                    verified.add(cur_node)
                    if cur_node not in input_nodes_set:
                        to_visit.extend(name_to_input_name[cur_node])
        else:
            # Not an ancestor of any output: kept after the fused op.
            nodes_post_output.append(n)

    # Add all nodes upto the input nodes, in original graph order
    out = graph_pb2.GraphDef()
    reachable_by_input_sorted = sorted(
        reachable_by_input, key=lambda name: name_to_seq_num[name])
    for node in reachable_by_input_sorted:
        out.node.extend([copy.deepcopy(name_to_node[node])])

    # Add the custom op. All fields are set before extend() because a
    # repeated message field stores a copy of the appended message.
    new_node = node_def_pb2.NodeDef()
    new_node.input.extend(input_nodes)
    new_node.attr["_output_types"].list.type[:] = output_dtypes
    new_node.attr["_output_quantized"].b = output_quantized
    new_node.op = op_type
    new_node.name = op_name
    out.node.extend([new_node])

    # Add the nodes in the output of the custom op, rewired to read output
    # `index` of the fused op ("op_name" for 0, "op_name:<index>" otherwise).
    for index, n in enumerate(output_nodes):
        assert len(name_to_node[n].input) == 1
        new_node = copy.deepcopy(name_to_node[n])
        del new_node.input[:]
        new_node.input.append(op_name + (":" + str(index) if index != 0 else ""))
        out.node.extend([new_node])

    # Add the nodes post output_nodes
    for n in nodes_post_output:
        out.node.extend([copy.deepcopy(name_to_node[n])])

    out.library.CopyFrom(graph_def.library)
    out.versions.CopyFrom(graph_def.versions)
    return out