def import_onnx_model(model: onnx.ModelProto) -> Function:
    """Convert an in-memory ONNX model into an nGraph Function.

    The model is first validated with ``onnx.checker.check_model``; any error
    raised by that check propagates to the caller. The model is then
    serialized and passed to the Inference Engine as a buffer, together with
    an empty weights blob, so no file on disk is involved.

    :param model: ONNX model to import.
    :return: nGraph Function extracted from the resulting IE network.
    """
    onnx.checker.check_model(model)

    serialized = model.SerializeToString()
    network = IECore().read_network(
        model=serialized,
        weights=b"",
        init_from_buffer=True,
    )
    return Function.from_capsule(network._get_function_capsule())
def test_GetIENetworkFromNGraph():
    """Round-trip an nGraph Function through an IENetwork and back.

    Builds a single-Relu function over a f32 [1, 3, 22, 22] parameter, wraps
    it in an IENetwork via a capsule, then rebuilds a Function from the
    capsule returned by ``IENetwork.get_function()``.
    """
    param = Parameter(Type.f32, Shape([1, 3, 22, 22]))
    relu = Relu(param)
    func = Function([relu], [param], "test")

    cnn_network = IENetwork(Function.to_capsule(func))
    # Identity comparison with None (PEP 8); `!=` could be overridden by the class.
    assert cnn_network is not None

    # Fetch the capsule once and reuse it instead of calling get_function() twice.
    capsule = cnn_network.get_function()
    assert capsule is not None

    func2 = Function.from_capsule(capsule)
    assert func2 is not None
def test_import_onnx_function():
    """Load ``models/add_abc.onnx`` via IECore and run it on the nGraph runtime.

    Feeds the inputs 1.0, 2.0 and 3.0 (float32, one element each) and asserts
    that the computed result is close to 6.
    """
    model_path = os.path.join(os.path.dirname(__file__), "models/add_abc.onnx")
    network = IECore().read_network(model=model_path)
    ng_function = Function.from_capsule(network._get_function_capsule())

    dtype = np.float32
    a, b, c = (np.array([v], dtype=dtype) for v in (1.0, 2.0, 3.0))

    computation = get_runtime().computation(ng_function)
    result = computation(a, b, c)
    assert np.allclose(result, np.array([6], dtype=dtype))
def function_from_cnn(cnn_network: IENetwork) -> Function:
    """Return the nGraph Function wrapped by an Inference Engine network.

    :param cnn_network: network whose function capsule is extracted.
    :return: Function rebuilt from that capsule.
    """
    return Function.from_capsule(cnn_network._get_function_capsule())