def load_graph(fname):
    """Read a serialized ONNX model file and build a graph from it.

    Args:
        fname: path of the binary ONNX model file.

    Returns:
        A ``(graph, producer_name)`` tuple: the graph built by
        ``GraphUtil.create_graph_from_onnx_model`` and the model's
        ``producer_name`` field.
    """
    with open(fname, "rb") as model_file:
        serialized = model_file.read()
    proto = onnx.ModelProto()
    proto.ParseFromString(serialized)
    graph = GraphUtil.create_graph_from_onnx_model(proto)
    return graph, proto.producer_name
def test_match_flipped(self):
    """GraphMatcher(allow_reorder=True) matches a Mul whose inputs are in flipped order.

    The graph computes Mul(Sub(i1, i1), Add(i2, i2)). The pattern lists the
    Mul inputs as (Add, Sub), so it only matches when reordering is allowed.
    """
    n1 = helper.make_node("Sub", ["i1", "i1"], ["n1:0"], name="n1")
    n2 = helper.make_node("Add", ["i2", "i2"], ["n2:0"], name="n2")
    n3 = helper.make_node("Mul", ["n1:0", "n2:0"], ["n3:0"], name="n3")
    graph_proto = helper.make_graph(
        nodes=[n1, n2, n3],
        name="test",
        inputs=[
            helper.make_tensor_value_info("i1", TensorProto.FLOAT, [2, 2]),
            helper.make_tensor_value_info("i2", TensorProto.FLOAT, [2, 2]),
        ],
        # The Mul is the final node; expose its result as the graph output so
        # the node under test is not left dangling.
        outputs=[
            helper.make_tensor_value_info("n3:0", TensorProto.FLOAT, [2, 2]),
        ],
        initializer=[])
    g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
    pattern = OpTypePattern(
        'Mul', inputs=[OpTypePattern('Add'), OpTypePattern('Sub')])
    ops = g.get_nodes()

    # Without reordering, the pattern's (Add, Sub) input order must not match
    # the graph's (Sub, Add); this shows allow_reorder is what makes it match.
    strict_matcher = GraphMatcher(pattern)
    self.assertEqual(0, len(list(strict_matcher.match_ops(ops))))

    matcher = GraphMatcher(pattern, allow_reorder=True)
    match_results = list(matcher.match_ops(ops))
    self.assertEqual(1, len(match_results))
def test_rewrite_subgraph(self):
    """Replace each Abs(Add(...)) subgraph of sample_net with a single Sub node.

    For every match, a new Sub node takes the matched Add's inputs. Inputs
    that referenced the Abs output are redirected to the Sub output, and the
    matched nodes are removed. The resulting graph is then compared by its
    graphviz rendering.
    """
    graph_proto = self.sample_net()
    g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
    pattern = OpTypePattern('Abs', name='output', inputs=[
        OpTypePattern('Add', name='input'),
    ])
    matcher = GraphMatcher(pattern)
    # Materialize the matches before the graph is mutated below.
    match_results = list(matcher.match_ops(g.get_nodes()))
    for match in match_results:
        input_node = match.get_op('input')
        output_node = match.get_op('output')
        op_name = utils.make_name("ReplacedOp")
        out_name = utils.port_name(op_name)
        new_node = g.make_node("Sub", inputs=input_node.input, outputs=[out_name], name=op_name)
        g.replace_all_inputs(output_node.output[0], new_node.output[0])
        for n in set(match.get_nodes()):
            g.remove_node(n.name)
    # Sort the graph's current node list (after the rewrite), not a snapshot
    # taken before nodes were added and removed.
    g.topological_sort(g.get_nodes())
    result = onnx_to_graphviz(g)
    expected = (
        'digraph { Placeholder__5 [op_type=Placeholder] n1 [op_type=Abs] '
        'n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__6 [op_type=Sub] '
        'n6 [op_type=Identity] n5_graph_outputs_Identity__4 [op_type=Identity] '
        'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__6 n3:0 -> ReplacedOp__6 '
        'ReplacedOp__6:0 -> n6 ReplacedOp__6:0 -> n5_graph_outputs_Identity__4 }'
    )
    self.assertEqual(expected, result)
def test_node_attr_onnx(self):
    """Check which node attributes get_onnx_attrs() exports.

    On a default-domain Conv, the attribute ``my_attr`` is kept in
    ``node.attr`` but excluded from ``get_onnx_attrs()``. On a node in the
    custom domain ``"my_domain"``, the same attribute is included in both.
    """
    def _load_conv_node(**node_kwargs):
        # Build a one-node Conv graph (X, W -> Y) and return the converted node "n1".
        conv = helper.make_node("Conv", ["X", "W"], ["Y"], name="n1", **node_kwargs)
        graph_proto = helper.make_graph(
            nodes=[conv],
            name="test",
            inputs=[
                helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 2]),
                helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 2]),
            ],
            outputs=[
                helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2]),
            ],
            initializer=[])
        g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
        return g.get_node_by_name("n1")

    n1 = _load_conv_node(my_attr="my_attr")
    self.assertIn("my_attr", n1.attr)
    self.assertNotIn("my_attr", n1.get_onnx_attrs())

    n1 = _load_conv_node(domain="my_domain", my_attr="my_attr")
    self.assertIn("my_attr", n1.attr)
    self.assertIn("my_attr", n1.get_onnx_attrs())
def test_insert_node2(self):
    """Insert an Abs node "n7" on output "n1:0" and check the rewired graph.

    After insertion, n3 and n2 read from n7:0 instead of n1:0.
    """
    g = GraphUtil.create_graph_from_onnx_graph(self.sample_net())
    g.insert_new_node_on_output("Abs", "n1:0", name="n7")
    g.topological_sort(g.get_nodes())
    expected = (
        'digraph { Placeholder__5 [op_type=Placeholder] n1 [op_type=Abs] n7 [op_type=Abs] '
        'n3 [op_type=Abs] n2 [op_type=Abs] n4 [op_type=Add] n5 [op_type=Abs] '
        'n6 [op_type=Identity] n5_graph_outputs_Identity__4 [op_type=Identity] '
        'input -> n1 n1:0 -> n7 n7:0 -> n3 n7:0 -> n2 n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 '
        'n5_raw_output___3:0 -> n6 n5_raw_output___3:0 -> n5_graph_outputs_Identity__4 }'
    )
    self.assertEqual(expected, onnx_to_graphviz(g))
def test_remove_input(self):
    """Remove the second input of node "n4" and check the resulting graph.

    Afterwards n4 only reads n2:0, so the n3:0 -> n4 edge is gone.
    """
    g = GraphUtil.create_graph_from_onnx_graph(self.sample_net())
    add_node = g.get_node_by_name("n4")
    g.remove_input(add_node, add_node.input[1])
    g.topological_sort(g.get_nodes())
    expected = (
        'digraph { Placeholder__5 [op_type=Placeholder] n1 [op_type=Abs] n3 [op_type=Abs] '
        'n2 [op_type=Abs] n4 [op_type=Add] n5 [op_type=Abs] n6 [op_type=Identity] '
        'n5_graph_outputs_Identity__4 [op_type=Identity] input -> n1 n1:0 -> n3 '
        'n1:0 -> n2 n2:0 -> n4 n4:0 -> n5 n5_raw_output___3:0 -> n6 '
        'n5_raw_output___3:0 -> n5_graph_outputs_Identity__4 }'
    )
    self.assertEqual(expected, onnx_to_graphviz(g))
def test_make_const_string(self):
    """make_const round-trips object arrays holding str and bytes values.

    The test expects bytes constants (const3, const4) to read back as str,
    so they are compared against the str arrays arr1 and arr2.
    """
    graph_proto = self.sample_net()
    g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
    # ``np.object`` was only a deprecated alias of the builtin ``object``
    # (removed in NumPy 1.24), so use ``object`` directly.
    arr1 = np.array("test", object)
    arr2 = np.array([["A", "B"], ["C", "D"]], object)
    arr3 = np.array(b"test", object)
    arr4 = np.array([[b"A", b"B"], [b"C", b"D"]], object)
    const1 = g.make_const("const1", arr1)
    const2 = g.make_const("const2", arr2)
    const3 = g.make_const("const3", arr3)
    const4 = g.make_const("const4", arr4)
    np.testing.assert_equal(const1.get_tensor_value(False), arr1)
    np.testing.assert_equal(const2.get_tensor_value(False), arr2)
    np.testing.assert_equal(const3.get_tensor_value(False), arr1)
    np.testing.assert_equal(const4.get_tensor_value(False), arr2)
def test_data_format(self):
    """A Conv built with data_format="NHWC" reports that format.

    Both ``node.data_format`` and ``node.is_nhwc()`` are checked.
    """
    conv = helper.make_node("Conv", ["X", "W"], ["Y"], name="n1", data_format="NHWC")
    input_infos = [
        helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [2, 2])
        for input_name in ("X", "W")
    ]
    output_infos = [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2])]
    graph_proto = helper.make_graph(
        nodes=[conv], name="test", inputs=input_infos,
        outputs=output_infos, initializer=[])
    node = GraphUtil.create_graph_from_onnx_graph(graph_proto).get_node_by_name("n1")
    self.assertEqual(node.data_format, "NHWC")
    self.assertTrue(node.is_nhwc())