예제 #1
0
def load_graph(fname):
    """Read a serialized ONNX model from disk and build a graph from it.

    Args:
        fname: Path of the binary model file to read.

    Returns:
        A ``(graph, producer_name)`` tuple: the graph produced by
        ``GraphUtil.create_graph_from_onnx_model`` and the ``producer_name``
        field stored in the parsed ``onnx.ModelProto``.
    """
    # Slurp the raw bytes first so the file handle is released before parsing.
    with open(fname, "rb") as model_file:
        serialized = model_file.read()

    proto = onnx.ModelProto()
    proto.ParseFromString(serialized)
    graph = GraphUtil.create_graph_from_onnx_model(proto)
    return graph, proto.producer_name
예제 #2
0
    def test_match_flipped(self):
        """Match a Mul pattern whose inputs are listed in the opposite order.

        The graph computes ``n3 = Mul(Sub(i1, i1), Add(i2, i2))`` while the
        pattern is written as ``Mul(Add, Sub)``; the matcher is therefore
        built with ``allow_reorder=True`` and exactly one match is expected.
        """
        n1 = helper.make_node("Sub", ["i1", "i1"], ["n1:0"], name="n1")
        n2 = helper.make_node("Add", ["i2", "i2"], ["n2:0"], name="n2")
        n3 = helper.make_node("Mul", ["n1:0", "n2:0"], ["n3:0"], name="n3")

        graph_proto = helper.make_graph(
            nodes=[n1, n2, n3],
            name="test",
            inputs=[
                helper.make_tensor_value_info("i1", TensorProto.FLOAT, [2, 2]),
                helper.make_tensor_value_info("i2", TensorProto.FLOAT, [2, 2])
            ],
            # The Mul node is the graph's only sink, so its output is the
            # graph output; declaring an intermediate here would leave the
            # node under test with a dangling, unconsumed output.
            outputs=[
                helper.make_tensor_value_info("n3:0", TensorProto.FLOAT,
                                              [2, 2])
            ],
            initializer=[])
        g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
        # Pattern inputs are deliberately flipped relative to n3's inputs.
        pattern = OpTypePattern(
            'Mul', inputs=[OpTypePattern('Add'),
                           OpTypePattern('Sub')])
        ops = g.get_nodes()
        matcher = GraphMatcher(pattern, allow_reorder=True)
        match_results = list(matcher.match_ops(ops))
        self.assertEqual(1, len(match_results))
예제 #3
0
 def test_rewrite_subgraph(self):
     """Replace every ``Abs(Add(...))`` subgraph of the sample net with a Sub.

     For each match a new Sub node is created on the Add node's inputs,
     ``replace_all_inputs`` redirects uses of the Abs output to the Sub
     output, and the matched nodes are removed. The sorted graph's graphviz
     rendering is then compared with the expected string.
     """
     graph_proto = self.sample_net()
     g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
     pattern = OpTypePattern('Abs', name='output', inputs=[
         OpTypePattern('Add', name='input')
     ])
     matcher = GraphMatcher(pattern)
     # Materialize all matches before the loop below mutates the graph.
     match_results = list(matcher.match_ops(g.get_nodes()))
     for match in match_results:
         input_node = match.get_op('input')
         output_node = match.get_op('output')
         op_name = utils.make_name("ReplacedOp")
         out_name = utils.port_name(op_name)
         new_node = g.make_node("Sub",
                                inputs=input_node.input,
                                outputs=[out_name],
                                name=op_name)
         g.replace_all_inputs(output_node.output[0], new_node.output[0])
         for n in set(match.get_nodes()):
             g.remove_node(n.name)
     # Sort the graph's *current* node list: the rewrite added the Sub node
     # and removed the matched ones, so a list fetched before the loop is
     # only correct if get_nodes() happens to return a live reference.
     g.topological_sort(g.get_nodes())
     result = onnx_to_graphviz(g)
     expected = (
         'digraph { Placeholder__5 [op_type=Placeholder] n1 [op_type=Abs] '
         'n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__6 [op_type=Sub] '
         'n6 [op_type=Identity] n5_graph_outputs_Identity__4 [op_type=Identity] '
         'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__6 n3:0 -> ReplacedOp__6 '
         'ReplacedOp__6:0 -> n6 ReplacedOp__6:0 -> n5_graph_outputs_Identity__4 }'
     )
     self.assertEqual(expected, result)
예제 #4
0
    def test_node_attr_onnx(self):
        """Check how ``get_onnx_attrs()`` treats a non-standard Conv attribute.

        A Conv node carrying ``my_attr`` is built twice. In the default
        domain the attribute is kept on ``node.attr`` but left out of
        ``get_onnx_attrs()``. In ``my_domain`` it appears in both.
        """
        def build_conv_node(**node_kwargs):
            # Build a one-Conv graph and return its "n1" node; node_kwargs
            # are forwarded to helper.make_node (e.g. ``domain``).
            conv = helper.make_node("Conv", ["X", "W"], ["Y"],
                                    name="n1", my_attr="my_attr",
                                    **node_kwargs)
            proto = helper.make_graph(
                nodes=[conv],
                name="test",
                inputs=[
                    helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 2]),
                    helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 2]),
                ],
                outputs=[
                    helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2]),
                ],
                initializer=[],
            )
            graph = GraphUtil.create_graph_from_onnx_graph(proto)
            return graph.get_node_by_name("n1")

        default_node = build_conv_node()
        self.assertIn("my_attr", default_node.attr)
        self.assertNotIn("my_attr", default_node.get_onnx_attrs())

        custom_node = build_conv_node(domain="my_domain")
        self.assertIn("my_attr", custom_node.attr)
        self.assertIn("my_attr", custom_node.get_onnx_attrs())
예제 #5
0
 def test_insert_node2(self):
     """Insert an Abs node ``n7`` on output ``n1:0`` of the sample net.

     After a topological sort the graphviz rendering must show ``n7`` fed by
     ``n1`` and feeding ``n3`` and ``n2``.
     """
     g = GraphUtil.create_graph_from_onnx_graph(self.sample_net())
     g.insert_new_node_on_output("Abs", "n1:0", name="n7")
     g.topological_sort(g.get_nodes())
     expected = (
         'digraph { Placeholder__5 [op_type=Placeholder] n1 [op_type=Abs] n7 [op_type=Abs] '
         'n3 [op_type=Abs] n2 [op_type=Abs] n4 [op_type=Add] n5 [op_type=Abs] '
         'n6 [op_type=Identity] n5_graph_outputs_Identity__4 [op_type=Identity] '
         'input -> n1 n1:0 -> n7 n7:0 -> n3 n7:0 -> n2 n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 '
         'n5_raw_output___3:0 -> n6 n5_raw_output___3:0 -> n5_graph_outputs_Identity__4 }'
     )
     self.assertEqual(expected, onnx_to_graphviz(g))
예제 #6
0
 def test_remove_input(self):
     """Remove the second input of node ``n4`` in the sample net.

     After ``remove_input`` and a topological sort, the graphviz rendering
     shows ``n4`` fed only by ``n2``.
     """
     g = GraphUtil.create_graph_from_onnx_graph(self.sample_net())
     add_node = g.get_node_by_name("n4")
     g.remove_input(add_node, add_node.input[1])
     g.topological_sort(g.get_nodes())
     expected = (
         'digraph { Placeholder__5 [op_type=Placeholder] n1 [op_type=Abs] n3 [op_type=Abs] '
         'n2 [op_type=Abs] n4 [op_type=Add] n5 [op_type=Abs] n6 [op_type=Identity] '
         'n5_graph_outputs_Identity__4 [op_type=Identity] input -> n1 n1:0 -> n3 '
         'n1:0 -> n2 n2:0 -> n4 n4:0 -> n5 n5_raw_output___3:0 -> n6 '
         'n5_raw_output___3:0 -> n5_graph_outputs_Identity__4 }'
     )
     self.assertEqual(expected, onnx_to_graphviz(g))
예제 #7
0
 def test_make_const_string(self):
     """String constants round-trip through ``make_const``.

     ``str`` and ``bytes`` object arrays (0-d and 2-D) are stored as Const
     nodes and read back with ``get_tensor_value(False)``. The test expects
     the ``bytes`` inputs to read back equal to their ``str`` counterparts.
     """
     graph_proto = self.sample_net()
     g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
     # Use the builtin ``object`` dtype: the ``np.object`` alias was
     # deprecated in NumPy 1.20 and removed in 1.24 (AttributeError).
     arr1 = np.array("test", object)
     arr2 = np.array([["A", "B"], ["C", "D"]], object)
     arr3 = np.array(b"test", object)
     arr4 = np.array([[b"A", b"B"], [b"C", b"D"]], object)
     const1 = g.make_const("const1", arr1)
     const2 = g.make_const("const2", arr2)
     const3 = g.make_const("const3", arr3)
     const4 = g.make_const("const4", arr4)
     np.testing.assert_equal(const1.get_tensor_value(False), arr1)
     np.testing.assert_equal(const2.get_tensor_value(False), arr2)
     # bytes constants are expected to come back as str.
     np.testing.assert_equal(const3.get_tensor_value(False), arr1)
     np.testing.assert_equal(const4.get_tensor_value(False), arr2)
예제 #8
0
 def test_data_format(self):
     """A ``data_format`` attribute on the ONNX node is visible on the graph node.

     Builds a single Conv node tagged ``data_format="NHWC"`` and checks both
     the node's ``data_format`` value and ``is_nhwc()``.
     """
     value_info = helper.make_tensor_value_info
     conv = helper.make_node("Conv", ["X", "W"], ["Y"], name="n1", data_format="NHWC")
     graph_proto = helper.make_graph(
         nodes=[conv],
         name="test",
         inputs=[value_info("X", TensorProto.FLOAT, [2, 2]),
                 value_info("W", TensorProto.FLOAT, [2, 2])],
         outputs=[value_info("Y", TensorProto.FLOAT, [2, 2])],
         initializer=[],
     )
     node = GraphUtil.create_graph_from_onnx_graph(graph_proto).get_node_by_name("n1")
     self.assertEqual(node.data_format, "NHWC")
     self.assertTrue(node.is_nhwc())