def load_graph(fname, target):
    """Parse the serialized ONNX model stored at *fname* and build a graph from it.

    Args:
        fname: Path to a binary-serialized ONNX ``ModelProto`` file.
        target: Forwarded unchanged to ``GraphUtil.create_graph_from_onnx_model``.

    Returns:
        A ``(graph, model_proto)`` tuple: the graph built by
        ``GraphUtil.create_graph_from_onnx_model`` and the parsed proto itself.
    """
    proto = onnx.ModelProto()
    with open(fname, "rb") as model_file:
        proto.ParseFromString(model_file.read())
    graph = GraphUtil.create_graph_from_onnx_model(proto, target)
    return graph, proto
def test_extra_opset(self):
    """Extra opsets given to ``process_tf_graph`` must survive a graph -> model proto -> graph round trip."""
    custom_opsets = [
        utils.make_opsetid(constants.MICROSOFT_DOMAIN, 1),
        utils.make_opsetid("my.domain", 1024),
    ]

    def check_opsets(graph):
        # Both the primary opset and the list of extra opsets must match what was requested.
        self.assertEqual(graph.opset, self.config.opset)
        self.assertEqual(graph.extra_opset, custom_opsets)

    with tf.Session() as sess:
        inp = tf.placeholder(tf.float32, [2, 3], name="input1")
        doubled = tf.add(inp, inp)
        tf.identity(doubled, name="output")

        converted = process_tf_graph(sess.graph, opset=self.config.opset, extra_opset=custom_opsets)
        check_opsets(converted)

        # Serialize to a model proto, optimize it, and rebuild a graph; the extra opsets must be kept.
        proto = GraphUtil.optimize_model_proto(converted.make_model("test"))
        rebuilt = GraphUtil.create_graph_from_onnx_model(proto)
        check_opsets(rebuilt)
def load_graph(fname):
    """Load the serialized ONNX model at *fname* and build a graph from it.

    Previously ``fname`` was ignored: the graph was built from an empty,
    freshly constructed ``ModelProto``, so the result never reflected the file.

    Args:
        fname: Path to a binary-serialized ONNX ``ModelProto`` file.

    Returns:
        A ``(graph, producer_name)`` tuple: the graph built by
        ``GraphUtil.create_graph_from_onnx_model`` and the ``producer_name``
        field of the parsed model proto.
    """
    model_proto = onnx.ModelProto()
    # Open in binary mode: the protobuf wire format is bytes, not text.
    with open(fname, "rb") as f:
        model_proto.ParseFromString(f.read())
    g = GraphUtil.create_graph_from_onnx_model(model_proto)
    return g, model_proto.producer_name