コード例 #1
0
        def attrs_to_dict(attrs):
            """Convert a sequence of ONNX attribute protos into an OrderedDict.

            Each attribute's ``type`` is looked up in ATTR_TYPE_MAPPING to get a
            type string. That string selects the proto field to read via
            ONNX_PYTHON_ATTR_MAPPING. Attributes with an unrecognized type, or
            with a type that has no field mapping, are skipped with a warning.

            Returns:
                OrderedDict mapping attribute name -> converted Python value,
                in the order the attributes were given.
            """

            def convert(attribute, kind: str):
                # Pull the raw value out of the proto field that holds this kind.
                value = getattr(attribute, ONNX_PYTHON_ATTR_MAPPING[kind])
                if kind == "STRING":
                    return value.decode()
                if kind == "TENSOR":
                    return OnnxImporter.import_tensor(value)
                if kind == "GRAPH":
                    # Subgraphs see both tensor maps from the enclosing scope.
                    return OnnxImporter.import_graph(
                        value,
                        misc.combine_dicts(tensor_map, subgraph_tensor_map))
                if kind in ("FLOATS", "INTS"):
                    return list(value)
                if kind == "STRINGS":
                    return [item.decode() for item in value]
                return value

            result = OrderedDict()
            for attribute in attrs:
                if attribute.type not in ATTR_TYPE_MAPPING:
                    G_LOGGER.warning(
                        "Attribute type: {:} was not recognized. Was the graph generated with a newer IR version than the installed `onnx` package? Skipping attribute."
                        .format(attribute.type))
                    continue

                kind = ATTR_TYPE_MAPPING[attribute.type]
                if kind not in ONNX_PYTHON_ATTR_MAPPING:
                    G_LOGGER.warning(
                        "Attribute of type {:} is currently unsupported. Skipping attribute."
                        .format(kind))
                    continue

                result[attribute.name] = convert(attribute, kind)
            return result
コード例 #2
0
ファイル: graph.py プロジェクト: npanpaliya/TensorRT-1
        def register_func(func):
            """Register ``func`` under its name for each opset in ``opsets``.

            ``opsets`` comes from the enclosing scope. A warning is logged when
            Graph already has an attribute with the same name, because that
            attribute shadows the registered function.

            Returns:
                ``func`` itself, unchanged.
            """
            name = func.__name__
            if hasattr(Graph, name):
                G_LOGGER.warning("Registered function: {:} is hidden by a Graph attribute or function with the same name. This function will never be called!".format(name))

            for opset in opsets:
                Graph.OPSET_FUNC_MAP[opset][name] = func
            return func
コード例 #3
0
 def get_opset(model: onnx.ModelProto):
     """Return the version of the first opset import declared by ``model``.

     Args:
         model: The ONNX model to inspect.

     Returns:
         ``model.opset_import[0].version``, or None (after logging a warning)
         when the model carries no opset imports.
     """
     try:
         return model.opset_import[0].version
     except (AttributeError, IndexError):
         # Missing field -> AttributeError; empty opset_import -> IndexError.
         # A bare ``except:`` would also swallow KeyboardInterrupt/SystemExit.
         G_LOGGER.warning(
             "Model does not contain opset information! Using default opset."
         )
         return None
コード例 #4
0
    def fold_constants(self):
        """
        Folds constants in-place in the graph. The graph must be topologically sorted prior to
        calling this function (see `toposort()`).

        This function will not remove constants after folding them. In order to get rid of
        these hanging nodes, you can run the `cleanup()` function.

        The foldable subgraph is evaluated on a deep copy of this graph. Only the resulting
        values are written back into this graph's tensors.

        *Note: Due to how this function is implemented, the graph must be exportable to ONNX,
        and evaluable in ONNX-Runtime. Additionally, ONNX-Runtime must be installed.*

        Returns:
            self
        """
        import onnxruntime
        from onnx_graphsurgeon.exporters.onnx_exporter import export_onnx

        # Work on a copy so that re-pointing outputs and pruning below leave `self` untouched.
        temp_graph = copy.deepcopy(self)

        # Seed the known-constant set with every Constant tensor in the graph.
        graph_constants = {
            tensor.name: tensor
            for tensor in temp_graph.tensors().values()
            if isinstance(tensor, Constant)
        }
        # Since the graph is topologically sorted, a single forward pass finds every node whose
        # inputs are all constant; that node's outputs are then constant too.
        # NOTE(review): a node with no inputs satisfies all() vacuously and is treated as
        # foldable. Confirm this is intended for non-deterministic ops.
        for node in temp_graph.nodes:
            if all(inp.name in graph_constants for inp in node.inputs):
                graph_constants.update({out.name: out for out in node.outputs})

        # Next build a graph with just the constants, and evaluate - no need to evaluate constants
        outputs_to_evaluate = [
            tensor for tensor in graph_constants.values()
            if isinstance(tensor, Variable)
        ]

        if not outputs_to_evaluate:
            G_LOGGER.warning(
                "Could not find any operations in this graph that can be folded. This could mean that constant folding has already been run on this graph. Skipping."
            )
            return self

        output_names = [out.name for out in outputs_to_evaluate]

        # Make the foldable tensors the copy's outputs. cleanup() is relied on to drop the
        # parts of the copy that do not feed them (presumably including unused inputs).
        temp_graph.outputs = outputs_to_evaluate
        temp_graph.cleanup()

        # Determining types is not trivial, and ONNX-RT does its own type inference.
        sess = onnxruntime.InferenceSession(
            export_onnx(temp_graph, do_type_check=False).SerializeToString())
        # No feeds are passed: the pruned graph is expected to depend only on constants.
        constant_values = sess.run(output_names, {})

        # Finally, replace the Variables in the original graph with constants.
        graph_tensors = self.tensors()
        for name, values in zip(output_names, constant_values):
            graph_tensors[name].to_constant(values)
            graph_tensors[name].inputs.clear()  # Constants do not need inputs

        return self
コード例 #5
0
        def register_func(func):
            """Register ``func`` under its name, globally or per opset.

            When ``opsets`` (from the enclosing scope) is None, ``func`` is
            stored in Graph.GLOBAL_FUNC_MAP. Otherwise it is stored in
            Graph.OPSET_FUNC_MAP for every listed opset. A warning is logged
            when Graph already has an attribute with the same name, because
            that attribute shadows the registered function.

            Returns:
                ``func`` itself, unchanged.
            """
            name = func.__name__
            if hasattr(Graph, name):
                G_LOGGER.warning("Registered function: {:} is hidden by a Graph attribute or function with the same name. This function will never be called!".format(name))

            if opsets is not None:
                for opset in opsets:
                    Graph.OPSET_FUNC_MAP[opset][name] = func
            else:
                # Default behavior is to register functions for all opsets.
                Graph.GLOBAL_FUNC_MAP[name] = func
            return func
コード例 #6
0
 def get_opset(model: onnx.ModelProto):
     """Return the opset version ``model`` imports for the default ONNX domain.

     The default ONNX domain is matched by a domain string of "" or "ai.onnx".

     Args:
         model: The ONNX model to inspect.

     Returns:
         The matching opset version. Returns None, after logging a warning,
         when no ONNX-domain import is found or the imports cannot be read.
     """
     try:
         for importer in OnnxImporter.get_import_domains(model):
             if importer.domain in ("", "ai.onnx"):
                 return importer.version
     except Exception:
         # Best-effort lookup. Catch Exception rather than using a bare
         # ``except:`` so KeyboardInterrupt/SystemExit still propagate.
         G_LOGGER.warning(
             "Model does not contain opset information! Using default opset."
         )
         return None

     G_LOGGER.warning(
         "Model does not contain ONNX domain opset information! Using default opset."
     )
     return None