Example #1
0
    def test_extract_onnx_gs_graph(self, extract_model):
        """Extract from a graph built with ``gs.import_onnx`` and check the result.

        The extracted object must be a ``gs.Graph`` with exactly one node,
        a single input named ``X`` and a single output named ``identity_out_0``.

        NOTE(review): another method named ``test_extract_onnx_gs_graph`` appears
        further down in this listing. If both end up in the same class, the later
        definition replaces this one and pytest will never collect it (pyflakes
        F811). Confirm, then rename or drop one of them.
        """
        onnx_model, in_meta, out_meta = extract_model
        subgraph = extract_subgraph(gs.import_onnx(onnx_model), in_meta, out_meta)

        assert isinstance(subgraph, gs.Graph)
        assert len(subgraph.nodes) == 1

        # A one-element name list pins both the tensor count and the name.
        input_names = [tensor.name for tensor in subgraph.inputs]
        assert input_names == ["X"]

        output_names = [tensor.name for tensor in subgraph.outputs]
        assert output_names == ["identity_out_0"]
    def test_extract_onnx_gs_graph(self, extract_model):
        """Extract from a graph built with ``gs_from_onnx`` and check both graphs.

        The source graph must still have its two nodes afterwards, so extraction
        did not strip nodes from its argument. The extracted ``gs.Graph`` must
        have one node, a single input ``X`` and a single output ``identity_out_0``.
        """
        onnx_model, in_meta, out_meta = extract_model
        full_graph = gs_from_onnx(onnx_model)
        subgraph = extract_subgraph(full_graph, in_meta, out_meta)

        # The source graph must be left as it was before extraction.
        assert len(full_graph.nodes) == 2

        assert isinstance(subgraph, gs.Graph)
        assert len(subgraph.nodes) == 1
        assert [tensor.name for tensor in subgraph.inputs] == ["X"]
        assert [tensor.name for tensor in subgraph.outputs] == ["identity_out_0"]
 def test_extract_passes_no_output_shape(self, extract_model):
     """Extraction with a shapeless output entry still yields a model passing ``check_model``."""
     onnx_model, in_meta, out_meta = extract_model
     # Clear the shape of the requested output before extracting.
     out_meta["identity_out_0"].shape = None
     extracted = extract_subgraph(onnx_model, in_meta, out_meta)
     self.check_model(extracted)
 def test_extract_passes_no_input_dtype(self, extract_model):
     """Extraction with a dtype-less input entry still yields a model passing ``check_model``."""
     onnx_model, in_meta, out_meta = extract_model
     # Clear the dtype of input ``X`` before extracting.
     in_meta["X"].dtype = None
     extracted = extract_subgraph(onnx_model, in_meta, out_meta)
     self.check_model(extracted)
 def test_extract_onnx_model_no_output_meta(self, extract_model):
     """With only input metadata given, the extracted model's first output is ``identity_out_2``."""
     onnx_model, in_meta, _unused_out_meta = extract_model
     extracted = extract_subgraph(onnx_model, input_metadata=in_meta)
     first_output = extracted.graph.output[0]
     assert first_output.name == "identity_out_2"
 def test_extract_onnx_model_no_input_meta(self, extract_model):
     """With only output metadata given, the extracted model still passes ``check_model``."""
     onnx_model, _unused_in_meta, out_meta = extract_model
     extracted = extract_subgraph(onnx_model, output_metadata=out_meta)
     self.check_model(extracted)
    def test_extract_onnx_model(self, extract_model):
        """Extract from a model and check both the source and the result.

        After extraction, the source model must still report ``identity_out_2``
        as its first output. The extracted model must pass ``check_model``.
        """
        source_model, in_meta, out_meta = extract_model
        extracted = extract_subgraph(source_model, in_meta, out_meta)

        # Guards against extraction rewriting the source model's outputs in place.
        assert source_model.graph.output[0].name == "identity_out_2"
        self.check_model(extracted)