def test_extract_onnx_gs_graph_from_import_onnx(self, extract_model):
    """Extract a subgraph from a graph built directly with ``gs.import_onnx``.

    This was previously named ``test_extract_onnx_gs_graph``, which collided
    with another test method of the same name later in the class. The later
    definition silently replaced this one, so it never ran. The distinct name
    keeps both tests active.

    Args:
        extract_model: Fixture yielding ``(model, input_meta, output_meta)``.
    """
    model, input_meta, output_meta = extract_model
    graph = gs.import_onnx(model)
    subgraph = extract_subgraph(graph, input_meta, output_meta)

    # Extraction must not mutate the graph it was given.
    assert len(graph.nodes) == 2

    assert isinstance(subgraph, gs.Graph)
    assert len(subgraph.nodes) == 1
    assert len(subgraph.inputs) == 1
    assert subgraph.inputs[0].name == "X"
    assert len(subgraph.outputs) == 1
    assert subgraph.outputs[0].name == "identity_out_0"
def test_extract_onnx_gs_graph(self, extract_model):
    """Extracting from a GraphSurgeon graph yields a new one-node subgraph.

    Also verifies that the source graph keeps both of its nodes afterwards.
    """
    onnx_model, inputs, outputs = extract_model
    full_graph = gs_from_onnx(onnx_model)
    sub = extract_subgraph(full_graph, inputs, outputs)

    # The source graph must be left untouched by extraction.
    assert len(full_graph.nodes) == 2

    assert isinstance(sub, gs.Graph)
    assert len(sub.nodes) == 1
    # Exactly one input named "X" and exactly one output named "identity_out_0".
    assert [tensor.name for tensor in sub.inputs] == ["X"]
    assert [tensor.name for tensor in sub.outputs] == ["identity_out_0"]
def test_extract_passes_no_output_shape(self, extract_model):
    """Extraction still succeeds when an output's shape is left unspecified."""
    onnx_model, inputs, outputs = extract_model
    # Clear the shape on the requested output before extracting.
    outputs["identity_out_0"].shape = None
    extracted = extract_subgraph(onnx_model, inputs, outputs)
    self.check_model(extracted)
def test_extract_passes_no_input_dtype(self, extract_model):
    """Extraction still succeeds when an input's dtype is left unspecified."""
    onnx_model, inputs, outputs = extract_model
    # Clear the dtype on the requested input before extracting.
    inputs["X"].dtype = None
    extracted = extract_subgraph(onnx_model, inputs, outputs)
    self.check_model(extracted)
def test_extract_onnx_model_no_output_meta(self, extract_model):
    """With only input metadata, the extracted model's first output is ``identity_out_2``."""
    onnx_model, inputs, _unused_outputs = extract_model
    extracted = extract_subgraph(onnx_model, input_metadata=inputs)
    first_output = extracted.graph.output[0]
    assert first_output.name == "identity_out_2"
def test_extract_onnx_model_no_input_meta(self, extract_model):
    """Extraction with only output metadata produces a model that passes ``check_model``."""
    onnx_model, _unused_inputs, outputs = extract_model
    extracted = extract_subgraph(onnx_model, output_metadata=outputs)
    self.check_model(extracted)
def test_extract_onnx_model(self, extract_model):
    """Extract from an ONNX model without altering the source model's outputs."""
    source_model, inputs, outputs = extract_model
    extracted = extract_subgraph(source_model, inputs, outputs)
    # The input model should still expose its original first output.
    source_first_output = source_model.graph.output[0].name
    assert source_first_output == "identity_out_2"
    self.check_model(extracted)