Exemplo n.º 1
0
 def test_export_graph(self, model):
     """
     Round-trip ``model`` through import -> export -> import.

     The graph obtained from the second import (and its opset) must equal the
     graph from the first import. The exported GraphProto is additionally
     compared with the original GraphProto, except for the LSTM model.
     """
     original_proto = model.load().graph
     first_import = OnnxImporter.import_graph(original_proto)
     round_tripped_proto = OnnxExporter.export_graph(first_import)
     second_import = OnnxImporter.import_graph(round_tripped_proto)

     assert first_import == second_import
     assert first_import.opset == second_import.opset

     # ONNX exports the initializers of the LSTM model differently after
     # importing - ONNX GS can't do much about this, so the raw proto
     # comparison is skipped for that model only.
     is_lstm = model.path == lstm_model().path
     if not is_lstm:
         assert original_proto == round_tripped_proto
Exemplo n.º 2
0
 def test_import_graph_tensor_map_preserved(self):
     """
     Import the identity model while passing a caller-owned ``tensor_map``.

     The empty ``OrderedDict`` handed to the importer must still be empty
     afterwards. Presumably this checks that the importer does not populate
     the caller's map; confirm against ``OnnxImporter.import_graph``. The
     imported graph must also satisfy the model's own equality check.
     """
     model = identity_model()
     caller_map = OrderedDict()
     onnx_graph = model.load().graph
     graph = OnnxImporter.import_graph(onnx_graph, tensor_map=caller_map)
     assert not caller_map
     model.assert_equal(graph)
Exemplo n.º 3
0
 def test_import_graph_value_info(self):
     """
     Every tensor imported from a shape-inferred identity model is typed and shaped.

     ``onnx.shape_inference.infer_shapes`` is run on the identity model before
     import. Each tensor in the imported graph must then be a ``Variable`` with
     a non-None dtype and a truthy (non-empty) shape.

     Each property is asserted per tensor, so a failure names the offending
     tensor instead of reporting a bare ``all(...)`` failure.
     """
     model = onnx.shape_inference.infer_shapes(identity_model().load())
     graph = OnnxImporter.import_graph(model.graph)
     for name, tensor in graph.tensors().items():
         assert isinstance(tensor, Variable), (
             f"tensor {name!r} is a {type(tensor).__name__}, expected Variable"
         )
         assert tensor.dtype is not None, f"tensor {name!r} has no dtype"
         # Truthiness check kept from the original: an empty shape fails too.
         assert tensor.shape, f"tensor {name!r} has no shape: {tensor.shape!r}"
Exemplo n.º 4
0
def import_onnx(onnx_model: "onnx.ModelProto") -> Graph:
    """
    Import an onnx-graphsurgeon Graph from the provided ONNX model.

    The model's top-level ``graph`` is handed to ``OnnxImporter.import_graph``
    together with the opset reported by ``OnnxImporter.get_opset`` for the
    whole model.

    Args:
        onnx_model (onnx.ModelProto): The ONNX model.

    Returns:
        Graph: A corresponding onnx-graphsurgeon Graph.
    """
    # Function-local import. Presumably this avoids an import cycle at module
    # load time; TODO confirm.
    from onnx_graphsurgeon.importers.onnx_importer import OnnxImporter

    graph_proto = onnx_model.graph
    model_opset = OnnxImporter.get_opset(onnx_model)
    return OnnxImporter.import_graph(graph_proto, opset=model_opset)
Exemplo n.º 5
0
 def test_import_graph(self, model):
     """Importing ``model``'s GraphProto yields a graph that ``model.assert_equal`` accepts."""
     onnx_graph = model.load().graph
     imported = OnnxImporter.import_graph(onnx_graph)
     model.assert_equal(imported)
Exemplo n.º 6
0
 def test_import_graph_with_dim_param(self):
     """Importing the ``dim_param_model`` fixture yields a graph that its ``assert_equal`` accepts."""
     expected = dim_param_model()
     onnx_graph = expected.load().graph
     expected.assert_equal(OnnxImporter.import_graph(onnx_graph))
Exemplo n.º 7
0
 def test_import_graph_with_initializer(self):
     """Importing the ``lstm_model`` fixture yields a graph that its ``assert_equal`` accepts."""
     expected = lstm_model()
     onnx_graph = expected.load().graph
     expected.assert_equal(OnnxImporter.import_graph(onnx_graph))
Exemplo n.º 8
0
 def test_export(self):
     """
     ``gs.export_onnx`` yields a model whose graph re-imports to the same Graph.

     Expects ``self.imported_graph`` to have been set beforehand (e.g. by a
     ``setup_method``).

     The previous version opened a ``tempfile.NamedTemporaryFile`` that was
     never written or read. It has been removed: the check runs entirely in
     memory, so no temp file is needed.
     """
     onnx_model = gs.export_onnx(self.imported_graph)
     assert onnx_model
     reimported = OnnxImporter.import_graph(onnx_model.graph)
     assert reimported == self.imported_graph
Exemplo n.º 9
0
 def setup_method(self):
     """Import the identity model's graph and store it on ``self.imported_graph``.

     pytest calls ``setup_method`` before each test method of the class.
     """
     identity_graph = identity_model().load().graph
     self.imported_graph = OnnxImporter.import_graph(identity_graph)