def load_node_tests(data_dir=os.path.join(DATA_DIR, 'node')):
    '''Load node test cases from on-disk data files.

    Every sub-directory of ``data_dir`` is one test case. It contains a
    serialized ``onnx.NodeProto`` in ``node.pb`` plus serialized
    ``onnx.TensorProto`` files ``input_<k>.pb`` / ``output_<k>.pb``.

    Returns:
        list of ``NodeTestCase(node, inputs, outputs, test_name)``, one per
        case directory, in ``os.listdir`` order.
    '''
    def read_tensors(case_dir, pattern, template):
        # The number of files matching ``pattern`` decides how many indices
        # are read, so the files are expected to be numbered 0..count-1.
        count = len(glob.glob(os.path.join(case_dir, pattern)))
        tensors = []
        for i in range(count):
            tensor = onnx.TensorProto()
            with open(os.path.join(case_dir, template.format(i)), 'rb') as f:
                tensor.ParseFromString(f.read())
            tensors.append(tensor)
        return tensors

    testcases = []
    for test_name in os.listdir(data_dir):
        case_dir = os.path.join(data_dir, test_name)
        # Skip non-directory entries such as a generated __init__.py.
        # Without this check, opening "<file>/node.pb" raises.
        if not os.path.isdir(case_dir):
            continue
        node = onnx.NodeProto()
        with open(os.path.join(case_dir, 'node.pb'), 'rb') as f:
            node.ParseFromString(f.read())
        inputs = read_tensors(case_dir, 'input_*.pb', 'input_{}.pb')
        outputs = read_tensors(case_dir, 'output_*.pb', 'output_{}.pb')
        testcases.append(NodeTestCase(node, inputs, outputs, test_name))
    return testcases
def load_node_tests(data_dir=os.path.join(DATA_DIR, 'node')):
    '''Collect one NodeTestCase per case directory found in ``data_dir``.

    A case directory holds ``node.pb`` (a serialized ``onnx.NodeProto``) and
    numbered ``input_<k>.pb`` / ``output_<k>.pb`` files (serialized
    ``onnx.TensorProto``). Entries that are not directories, such as a
    generated ``__init__.py``, are ignored.
    '''
    def parse_into(proto, path):
        # Fill ``proto`` from the binary file at ``path`` and hand it back.
        with open(path, 'rb') as stream:
            proto.ParseFromString(stream.read())
        return proto

    def read_numbered(case_dir, pattern, template):
        # Indices 0..total-1 are read, where total = number of glob matches.
        total = len(glob.glob(os.path.join(case_dir, pattern)))
        return [
            parse_into(onnx.TensorProto(), os.path.join(case_dir, template.format(k)))
            for k in range(total)
        ]

    cases = []
    for entry in os.listdir(data_dir):
        case_dir = os.path.join(data_dir, entry)
        if os.path.isdir(case_dir):
            node = parse_into(onnx.NodeProto(), os.path.join(case_dir, 'node.pb'))
            inputs = read_numbered(case_dir, 'input_*.pb', 'input_{}.pb')
            outputs = read_numbered(case_dir, 'output_*.pb', 'output_{}.pb')
            cases.append(NodeTestCase(node, inputs, outputs, entry))
    return cases
def trim_unused_outputs(node, graph):
    """Return a copy of ``node`` whose unused outputs are replaced by ''.

    An output counts as used when it is a graph output or appears in the
    ``input`` list of any node in ``graph.node``. Inputs referenced from inside
    subgraph attributes (e.g. If/Loop bodies) are not scanned. The positions of
    the remaining outputs are preserved; ``node`` itself is not modified.
    """
    trimmed = onnx.NodeProto()
    trimmed.CopyFrom(node)
    # Build the set of consumed value names once instead of re-scanning every
    # node for every output. Graph outputs are always "used", even when the
    # graph has no nodes; the previous per-node test missed that case.
    used_names = {o.name for o in graph.output}
    for other in graph.node:
        used_names.update(other.input)
    for o_idx, o in enumerate(node.output):
        if o not in used_names:
            trimmed.output[o_idx] = ''
    return trimmed
def main(args):
    """Replace custom-layer atomic operations with single Hardshrink nodes.

    The node indices removed and the insertion points used below are
    hard-coded for one specific ONNX model; its origin was never recorded
    ("TODO" in the original docstring). On any other model the result is not
    expected to be meaningful.

    Args:
        args: parsed command-line namespace. Reads ``model_path`` (model to
            load), ``save_path`` (where to write the result) and ``verbose``
            (when > 0, the graph is printed via ``print_graph`` after each
            stage).
    """
    # Load model
    onnx_model = onnx.load(args.model_path)
    graph = onnx_model.graph
    if args.verbose > 0:
        print('1. Before removal: ')
        print_graph(graph, args.verbose)

    # Remove atomic operations: nodes 1-10, 12-21 and 25-34.
    node_indices_to_remove = [*range(1, 11), *range(12, 22), *range(25, 35)]
    # Delete back-to-front so the still-pending indices remain valid. Delete
    # by position: `remove(node)` does a linear value-equality search and
    # drops the first *equal* node, not necessarily the one at `index`.
    for index in reversed(node_indices_to_remove):
        del graph.node[index]
    if args.verbose > 0:
        print('2. After removal: ')
        print_graph(graph, args.verbose)

    # Insert a Hardshrink node in front of nodes 5, 2 and 1. Go from the
    # highest index down so earlier insertion points are not shifted. Each new
    # node reads output 0 of node i-1 and feeds input 0 of node i.
    for i in [5, 2, 1]:
        node_hs = onnx.NodeProto()
        node_hs.op_type = 'Hardshrink'
        node_hs.name = f'hs_{i}'
        node_hs.output.append(f'hs_output_{i}')
        node_hs.input.append(graph.node[i - 1].output[0])
        graph.node[i].input[0] = f'hs_output_{i}'
        graph.node.insert(i, node_hs)
    if args.verbose > 0:
        print('3. After insertion: ')
        print_graph(graph, args.verbose)

    # Save model
    onnx.save(onnx_model, args.save_path)
def main():
    """Rewrite an ONNX model so its first two Conv nodes become CoordConv nodes.

    Steps:
      1. Walk the nodes until two Conv nodes have been converted. Each Conv
         gets op_type ``COORD_CONV_OP_TYPE``, and its input 0 is rewired to the
         model's first input or to the output of the most recent Relu. Every
         other non-Conv/non-Relu node visited before that point is deleted.
      2. Insert a ``CoordConvAC`` node in front of every CoordConv node; it
         takes over the CoordConv's input 0.
      3. Wrap the graph in a new model, print it, and save it to ``--output``.
    """
    # Configurable parameters from command line
    parser = argparse.ArgumentParser(description='ONNX Modifying Example')
    parser.add_argument('--onnx', required=True, help='onnx file to modify')
    parser.add_argument(
        '--output', default="output.onnx",
        help='path to write the modified onnx file (default: output.onnx)')
    args = parser.parse_args()

    # Load ONNX file
    model = onnx.load(args.onnx)
    # Retrieve graph_def
    graph = model.graph

    # Value fed to the next converted Conv: initially the first node's input 0,
    # later the output of the latest Relu.
    node_input_new = graph.node[0].input[0] if len(graph.node) else None
    counter_conv_nodes_updated = 0
    nodes_to_delete = []

    # Iterate through all the nodes
    for i, node in enumerate(graph.node):
        if counter_conv_nodes_updated == 2:
            break
        if node.op_type == 'Conv':
            # Rewire input 0 and convert Conv -> CoordConv
            node.input[0] = node_input_new
            node.op_type = COORD_CONV_OP_TYPE
            counter_conv_nodes_updated += 1
        elif node.op_type == 'Relu':
            # Saving output of previous node
            node_input_new = node.output[0]
        else:
            # Add node to list of removable nodes
            nodes_to_delete.append(i)

    # Remove unnecessary nodes by position, back-to-front so indices stay valid.
    for i in reversed(nodes_to_delete):
        del graph.node[i]

    # Insert one CoordConvAC node in front of each CoordConv node.
    i = 0
    while i < len(graph.node):
        if graph.node[i].op_type == COORD_CONV_OP_TYPE:
            node_ac = onnx.NodeProto()
            node_ac.op_type = "CoordConvAC"
            node_ac.output.append(f"ac_output_{i}")
            node_ac.input.append(graph.node[i].input[0])
            graph.node[i].input[0] = f"ac_output_{i}"
            graph.node.insert(i, node_ac)
            # Step over the freshly inserted AC node as well.
            i += 1
        i += 1

    # Generate model_cropped from modified graph.
    # NOTE(review): make_model does not carry over the source model's
    # opset_import/metadata; confirm the defaults are acceptable.
    model_cropped = onnx.helper.make_model(graph)
    print(onnx.helper.printable_graph(model_cropped.graph))
    print("Inputs:", model_cropped.graph.node[0].input, "Outputs:",
          model_cropped.graph.node[-1].output)

    # Save the serialized model
    onnx.save(model_cropped, args.output)
def generate_proto_nodes(
    self,
    g: torch._C.Graph,
    onnx_vars: Dict[TorchValueID, onnx.TensorProto],
    val_tab: Dict[TorchValueID, ONNXValueID],
) -> Tuple[List[onnx.NodeProto], Dict[TorchValueID, onnx.TensorProto], Dict[TorchValueID, ONNXValueID],]:
    """Translate the nodes of TorchScript graph ``g`` into ONNX NodeProtos.

    ``prim::If`` nodes are converted recursively: each of their two blocks
    becomes a ``then_branch`` / ``else_branch`` subgraph attribute.

    Args:
        g: graph to convert (when recursing, a block cast to a graph).
        onnx_vars: torch value id -> initializer TensorProto. Updated in place
            when an attribute-backed tensor input is first encountered.
        val_tab: torch value id -> ONNX value name. Updated in place as values
            are registered.

    Returns:
        ``(onnx_nodes, onnx_vars, val_tab)``: the generated nodes and the
        (same, mutated) mapping objects.
    """
    node_name_counter: int = 0

    def node_name(n: torch._C.Node) -> str:
        # "<op>_<k>" with k counting up from 0 per call of this method.
        nonlocal node_name_counter
        op = n.kind().split("::")[-1]
        node_name_counter += 1
        return f"{op}_{node_name_counter - 1}"

    # Reverse lookup, kept in sync with val_tab, to detect name collisions.
    val_tab_rev: Dict[ONNXValueID, TorchValueID] = {v: k for k, v in val_tab.items()}

    def register_val_name(value_id: TorchValueID, name: ONNXValueID, shadow: bool = False) -> ONNXValueID:
        # Bind value_id to name. With shadow=True, a taken name gets a
        # "_<c>" suffix instead of failing the uniqueness assert.
        assert value_id not in val_tab, f"{value_id} already registered in {g}"
        if shadow:
            new_name = name
            c = 1
            while new_name in val_tab_rev:
                new_name = ONNXValueID(f"{name}_{c}")
                c += 1
            name = new_name
        else:
            assert name not in val_tab_rev, f"{name} already registered in {g}"
        val_tab_rev[name] = value_id
        val_tab[value_id] = name
        assert len(val_tab_rev) == len(val_tab)
        return name

    def value_name(v: torch._C.Value) -> ONNXValueID:
        # Attribute-backed values use their attribute name. Everything else is
        # "<scope>.<debugName>", using the last "/" segment of the scope with
        # the "__module." and ignore-scope prefixes stripped.
        if _unique_id(v) in self.attrs:
            return self.attrs[_unique_id(v)]
        n: torch._C.Node = v.node() or v.uses()[0].user
        scope: str = self.node_scope.get(n, n.scopeName())
        if len(scope) > 0:
            scope += "."
        scope = _remove_prefix(scope.split("/")[-1], "__module.")
        scope = _remove_prefix(scope, f"{_ppe_ignore_scope}.")
        return ONNXValueID(f"{scope}{v.debugName()}")

    def block2subgraph(name: str, b: torch._C.Block, doc_string: str) -> onnx.GraphProto:
        # Convert block ``b`` into a standalone GraphProto named ``name``.
        branch_nodes, _, _ = self.generate_proto_nodes(cast(torch._C.Graph, b), onnx_vars, val_tab)
        branch_inputs: List[onnx.ValueInfoProto] = []
        for i in b.inputs():
            branch_inputs.append(onnx.ValueInfoProto())
            branch_inputs[-1].name = val_tab[_unique_id(i)]
            if not self.strip_doc_string:
                branch_inputs[-1].doc_string = repr(i)
        branch_outputs: List[onnx.ValueInfoProto] = []
        for i in b.outputs():
            branch_outputs.append(onnx.ValueInfoProto())
            branch_outputs[-1].name = val_tab[_unique_id(i)]
            if not self.strip_doc_string:
                branch_outputs[-1].doc_string = repr(i)
        branch_graph: onnx.GraphProto = onnx.helper.make_graph(
            name=name,
            nodes=branch_nodes,
            # TODO(twata): Support initializers if needed
            inputs=branch_inputs,
            outputs=branch_outputs,
            doc_string=doc_string,
        )
        return branch_graph

    # Nodes and initializers
    onnx_nodes: List[onnx.NodeProto] = []
    self_count: int = 0
    # Run only in root graph: bind user-specified input/output names first.
    if self.g == g:
        if self.input_names is not None:
            for idx, v in enumerate(g.inputs()):
                if self.is_self(v):
                    # Skip module's self input
                    self_count += 1
                    continue
                register_val_name(_unique_id(v), ONNXValueID(self.input_names[idx - self_count]))
            assert (len(list(g.inputs())) - self_count) == len(self.input_names)
        if self.output_names is not None:
            if len(self.output_names) != len(list(g.outputs())):
                warnings.warn(f"Specified output_names ({self.output_names}) count and graph outputs ({list(g.outputs())}) count differ")
            for idx, v in enumerate(g.outputs()):
                if idx >= len(self.output_names):
                    break
                register_val_name(_unique_id(v), ONNXValueID(self.output_names[idx]))

    # Nodes that must be None; uses of their outputs become "" in the protos.
    none_nodes: List[torch._C.Node] = []

    def assign_onnx_values(
        onnx_values: List[str],
        prefix: str,
        torch_values: Iterator[torch._C.Value],
    ) -> None:
        # Fill an empty repeated name field with the ONNX names of
        # torch_values, registering any value not yet in val_tab.
        assert len(onnx_values) == 0
        for v in torch_values:
            if v.node() is not None and v.node() in none_nodes:
                onnx_values.append("")
                continue
            k: ONNXValueID = val_tab.get(_unique_id(v), value_name(v))
            if _unique_id(v) not in val_tab:
                register_val_name(_unique_id(v), k)
            onnx_values.append(k)

    for n in g.nodes():
        # Skip None value node
        if n.mustBeNone():
            none_nodes.append(n)
            continue
        if n.kind() == "prim::GetAttr":
            continue
        if n.kind() == "onnx::Constant":
            if len(n.output().uses()) == 0:
                warnings.warn(f"Unused constant left: {n}")
                continue
            # Skip constant folded initializers
            if _unique_id(n.output()) in self.attrs:
                continue

        for i in n.inputs():
            if self.is_self(i):
                continue
            if i.node() is not None and i.node() in none_nodes:
                continue
            if _unique_id(i) in self.attrs and _unique_id(i) not in onnx_vars:
                # First use of an attribute tensor: emit it as an initializer.
                k: ONNXValueID = self.attrs[_unique_id(i)]
                assert isinstance(self.vars[k], torch.Tensor)
                t: torch.Tensor = cast(torch.Tensor, self.vars[k])
                onnx_vars[_unique_id(i)] = _tensor_to_proto(t, name=k)
                register_val_name(_unique_id(i), value_name(i), shadow=True)
                continue
            if _unique_id(i) not in val_tab:
                # Register the input value itself. This previously used
                # `_unique_id(v)`, a stale or unbound variable left over from
                # the root-graph input/output loops.
                register_val_name(_unique_id(i), value_name(i))
        for o in n.outputs():
            if _unique_id(o) not in val_tab:
                register_val_name(_unique_id(o), value_name(o), shadow=True)

        new_nd = onnx.NodeProto()
        new_nd.name = node_name(n)
        new_nd.op_type = n.kind().split("::")[-1]
        if n.kind() == "prim::If":
            if n in self.node_doc_string:
                new_nd.doc_string = f"""## Symbolic node {n} {self.node_doc_string[n]}"""
            blocks: List[torch._C.Block] = list(n.blocks())
            assert len(blocks) == 2
            for attr_name, block in zip(["then_branch", "else_branch"], blocks):
                sub_g = block2subgraph(f"{new_nd.name}_{attr_name}", block, new_nd.doc_string)
                new_nd.attribute.append(onnx.helper.make_attribute(attr_name, sub_g))
        else:
            assert len(list(n.blocks())) == 0, f"Node with block needs to be handled separately: {n}"
            if n in self.node_doc_string:
                new_nd.doc_string = self.node_doc_string[n]
            for attr_name in n.attributeNames():
                if n.kindOf(attr_name) == "t":
                    attr = onnx.helper.make_attribute(attr_name, _tensor_to_proto(n.t(attr_name)))
                else:
                    attr = onnx.helper.make_attribute(attr_name, n[attr_name])
                new_nd.attribute.append(attr)
        assign_onnx_values(new_nd.input, new_nd.name, n.inputs())
        assign_onnx_values(new_nd.output, new_nd.name, n.outputs())
        onnx_nodes.append(new_nd)

    return onnx_nodes, onnx_vars, val_tab
def _duplicate_dq_nodes_with_multiple_consumers(graph: onnx.GraphProto, **kwargs):
    """Give every consumer of a shared DequantizeLinear node its own copy.

    Keyword Args:
        updated_graphs: list; ``graph`` is appended to it when it is rewritten.
        node_to_consumers: mapping from a node to the nodes consuming it.
        validate_updates: when true, only check that no DequantizeLinear node
            has more than one consumer (raising ValueError if one does) and
            never modify the graph.

    Raises:
        IndexError: a shared DequantizeLinear output is consumed by a node not
            in ``graph.node`` (i.e. inside a subgraph), which is unsupported.
        ValueError: in validation mode, a shared DequantizeLinear remains.
    """
    updated_graphs = kwargs["updated_graphs"]
    node_to_consumers = kwargs["node_to_consumers"]
    validate_updates = kwargs["validate_updates"]

    shared_dq_nodes = []
    dq_nodes = (candidate for candidate in graph.node if candidate.op_type == "DequantizeLinear")
    for dq_node in dq_nodes:
        # A DQ node that only provides a graph output has no consumer entry.
        node_consumers = node_to_consumers.get(dq_node, [])
        if len(node_consumers) <= 1:
            continue
        if any(consumer not in graph.node for consumer in node_consumers):
            # TODO: if the value is consumed in only one subgraph we could leave
            # it as is (no recursion into the subgraph needed) and update just
            # the consumers in this graph.
            raise IndexError(
                "DequantizeLinear node output is consumed by a subgraph. "
                "This is not currently supported."
            )
        shared_dq_nodes.append(dq_node)

    if validate_updates and shared_dq_nodes:
        # Internal error: the first pass (validate_updates false) missed an update.
        raise ValueError("Graph still has DequantizeLinear nodes with multiple consumers.")
    if validate_updates or not shared_dq_nodes:
        return

    output_names = {value.name for value in graph.output}
    rebuilt = onnx.GraphProto()
    dup_counter = 0  # numbering shared by all duplicates created in this call
    for current in graph.node:
        rebuilt.node.append(current)
        if current not in shared_dq_nodes:
            continue
        source_value = current.output[0]
        # When the DQ output is also a graph output, every consumer needs a
        # duplicate. Otherwise the first consumer keeps the original node.
        skip = 0 if source_value in output_names else 1
        for local_idx, consumer in enumerate(list(node_to_consumers[current])[skip:]):
            clone = onnx.NodeProto()
            clone.CopyFrom(current)
            # Node names use the call-wide counter; value names only need to
            # be unique per original output, so they use the local index.
            clone.name += f"/qdq_utils_dup_{dup_counter}"
            renamed = f"{source_value}/qdq_utils_dup_{local_idx}"
            clone.output[0] = renamed
            # Point the consumer at its private copy.
            for pos, input_name in enumerate(consumer.input):
                if input_name == source_value:
                    consumer.input[pos] = renamed
            rebuilt.node.append(clone)
            dup_counter += 1

    # Swap in the rebuilt node list.
    del graph.node[:]
    graph.node.extend(rebuilt.node)
    updated_graphs.append(graph)