def attrs_to_dict(attrs):
    """
    Convert a sequence of ONNX attribute protos into an OrderedDict mapping
    attribute name -> Python value, skipping (with a warning) any attribute
    whose type is unsupported or unrecognized.

    NOTE(review): relies on `tensor_map` / `subgraph_tensor_map` from the
    enclosing scope when importing GRAPH attributes — presumably this is
    defined inside the node importer; verify against the caller.
    """
    attr_dict = OrderedDict()
    for attr in attrs:
        def convert(attr_str: str):
            # Pull the raw value off the attribute proto via the proto field
            # name, then coerce it into the corresponding Python object.
            value = getattr(attr, ONNX_PYTHON_ATTR_MAPPING[attr_str])
            if attr_str == "STRING":
                return value.decode()
            if attr_str == "TENSOR":
                return OnnxImporter.import_tensor(value)
            if attr_str == "GRAPH":
                # Subgraphs may reference tensors from the outer graph(s).
                return OnnxImporter.import_graph(value, misc.combine_dicts(tensor_map, subgraph_tensor_map))
            if attr_str in ("FLOATS", "INTS"):
                return list(value)
            if attr_str == "STRINGS":
                return [elem.decode() for elem in value]
            return value

        if attr.type not in ATTR_TYPE_MAPPING:
            G_LOGGER.warning(
                "Attribute type: {:} was not recognized. Was the graph generated with a newer IR version than the installed `onnx` package? Skipping attribute."
                .format(attr.type))
            continue

        attr_str = ATTR_TYPE_MAPPING[attr.type]
        if attr_str in ONNX_PYTHON_ATTR_MAPPING:
            attr_dict[attr.name] = convert(attr_str)
        else:
            G_LOGGER.warning(
                "Attribute of type {:} is currently unsupported. Skipping attribute."
                .format(attr_str))
    return attr_dict
def register_func(func):
    """
    Register `func` on the Graph under the opsets captured from the
    enclosing scope, warning if the name is shadowed by an existing
    Graph attribute.

    Returns:
        The registered function, unchanged (decorator-friendly).
    """
    if hasattr(Graph, func.__name__):
        G_LOGGER.warning("Registered function: {:} is hidden by a Graph attribute or function with the same name. This function will never be called!".format(func.__name__))
    # Fix: the original unconditionally iterated `opsets`, which raises
    # TypeError when `opsets` is None. Match the sibling registration
    # helper: None means "register for all opsets" via the global map.
    if opsets is None:
        Graph.GLOBAL_FUNC_MAP[func.__name__] = func
    else:
        for opset in opsets:
            Graph.OPSET_FUNC_MAP[opset][func.__name__] = func
    return func
def get_opset(model: onnx.ModelProto):
    """
    Return the opset version recorded in the first `opset_import` entry of
    the model, or None (with a warning) if no opset information is present.

    Args:
        model: The ONNX model to inspect.

    Returns:
        int or None: The opset version, or None when unavailable.
    """
    try:
        return model.opset_import[0].version
    except (AttributeError, IndexError):
        # Narrowed from a bare `except:` so unrelated errors (e.g.
        # KeyboardInterrupt, typos elsewhere) are not silently swallowed;
        # only a missing/empty `opset_import` is an expected failure here.
        G_LOGGER.warning(
            "Model does not contain opset information! Using default opset."
        )
        return None
def fold_constants(self):
    """
    Folds constants in-place in the graph. The graph must be topologically
    sorted prior to calling this function (see `toposort()`).

    This function will not remove constants after folding them. In order to
    get rid of these hanging nodes, you can run the `cleanup()` function.

    *Note: Due to how this function is implemented, the graph must be exportable
    to ONNX, and evaluable in ONNX-Runtime. Additionally, ONNX-Runtime must be
    installed.*

    Returns:
        self
    """
    import onnxruntime
    from onnx_graphsurgeon.exporters.onnx_exporter import export_onnx

    # Work on a deep copy so the original graph is untouched until we have
    # the evaluated constant values in hand.
    temp_graph = copy.deepcopy(self)

    # Since the graph is topologically sorted, a single forward pass finds
    # every node whose inputs are (transitively) all constants.
    graph_constants = {
        tensor.name: tensor
        for tensor in temp_graph.tensors().values()
        if isinstance(tensor, Constant)
    }
    for node in temp_graph.nodes:
        # Idiom fix: generator instead of materializing a list for all().
        # NOTE: a node with no inputs also satisfies all(), so its outputs
        # are treated as foldable.
        if all(inp.name in graph_constants for inp in node.inputs):
            graph_constants.update({out.name: out for out in node.outputs})

    # Next, build a graph with just the foldable subgraph and evaluate it.
    # Tensors already Constant need no evaluation — only the Variables
    # produced by foldable nodes do.
    outputs_to_evaluate = [
        tensor for tensor in graph_constants.values()
        if isinstance(tensor, Variable)
    ]
    if not outputs_to_evaluate:
        G_LOGGER.warning(
            "Could not find any operations in this graph that can be folded. This could mean that constant folding has already been run on this graph. Skipping."
        )
        return self

    output_names = [out.name for out in outputs_to_evaluate]
    temp_graph.outputs = outputs_to_evaluate
    temp_graph.cleanup()

    # Determining types is not trivial, and ONNX-RT does its own type
    # inference, so skip the exporter's type check.
    sess = onnxruntime.InferenceSession(
        export_onnx(temp_graph, do_type_check=False).SerializeToString())
    constant_values = sess.run(output_names, {})

    # Finally, replace the Variables in the original graph with constants.
    graph_tensors = self.tensors()
    for name, values in zip(output_names, constant_values):
        graph_tensors[name].to_constant(values)
        graph_tensors[name].inputs.clear()  # Constants do not need inputs
    return self
def register_func(func):
    """
    Register `func` on the Graph, warning if the name collides with an
    existing Graph attribute. With no opsets specified (``opsets is None``),
    the function is registered globally for all opsets; otherwise it is
    registered per requested opset.

    Returns:
        The registered function, unchanged (decorator-friendly).
    """
    name = func.__name__
    if hasattr(Graph, name):
        G_LOGGER.warning("Registered function: {:} is hidden by a Graph attribute or function with the same name. This function will never be called!".format(name))

    # Select the lookup table(s) to populate, then write once per table.
    if opsets is None:
        # Default behavior is to register functions for all opsets.
        targets = [Graph.GLOBAL_FUNC_MAP]
    else:
        targets = [Graph.OPSET_FUNC_MAP[opset] for opset in opsets]
    for table in targets:
        table[name] = func
    return func
def get_opset(model: onnx.ModelProto):
    """
    Return the version of the default ONNX domain opset ("" or "ai.onnx")
    recorded in the model, or None (with a warning) if it cannot be found.

    Args:
        model: The ONNX model to inspect.

    Returns:
        int or None: The ONNX-domain opset version, or None when unavailable.
    """
    try:
        for importer in OnnxImporter.get_import_domains(model):
            if importer.domain == "" or importer.domain == "ai.onnx":
                return importer.version
    except Exception:
        # Narrowed from a bare `except:` so system-exiting exceptions
        # (KeyboardInterrupt, SystemExit) are not swallowed. Anything else
        # is treated as "no opset information available".
        G_LOGGER.warning(
            "Model does not contain opset information! Using default opset."
        )
        return None
    # Reached only when iteration succeeded but no default-domain entry was
    # found; moved out of the `try` so it cannot mask real errors.
    G_LOGGER.warning(
        "Model does not contain ONNX domain opset information! Using default opset."
    )
    return None