def pytorch2onnx(model,
                 input,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False):
    """Export Pytorch model to ONNX model and verify the outputs are same
    between Pytorch and ONNX.

    ``input['merged']`` and ``input['trimap']`` each get a leading batch
    dimension and are concatenated along dim 1. The result is fed to the
    exported graph as its single input named ``cat_input``.

    ``model.forward`` is temporarily replaced by ``model.forward_dummy`` for
    the export and the optional verification. It is restored before
    returning, including when an exception is raised.

    Args:
        model (nn.Module): Pytorch model we want to export. Must provide a
            ``forward_dummy`` method.
        input (dict): We need to use this input to execute the model. Must
            contain ``'merged'`` and ``'trimap'`` tensors without a batch
            dimension.
        opset_version (int): The onnx op version. Default: 11.
        show (bool): Whether print the computation graph. Default: False.
        output_file (string): The path to where we store the output ONNX
            model. Default: `tmp.onnx`.
        verify (bool): Whether compare the outputs between Pytorch and ONNX.
            Default: False.

    Raises:
        AssertionError: If ``verify`` is True and the first PyTorch output
            and the first ONNXRuntime output are not ``np.allclose``.
    """
    model.cpu().eval()
    merged = input['merged'].unsqueeze(0)
    trimap = input['trimap'].unsqueeze(0)
    # Keep the example tensor in its own name instead of overwriting `input`.
    data = torch.cat((merged, trimap), 1)

    origin_forward = model.forward
    model.forward = model.forward_dummy
    try:
        # pytorch has some bug in pytorch1.3, we have to fix it
        # by replacing these existing op
        register_extra_symbolics(opset_version)
        with torch.no_grad():
            torch.onnx.export(
                model,
                data,
                output_file,
                input_names=['cat_input'],
                export_params=True,
                keep_initializers_as_inputs=True,
                verbose=show,
                opset_version=opset_version)
        print(f'Successfully exported ONNX model: {output_file}')

        if verify:
            # check by onnx
            onnx_model = onnx.load(output_file)
            onnx.checker.check_model(onnx_model)

            # get pytorch output, only concern pred_alpha
            # (still using forward_dummy; no autograd graph is needed)
            with torch.no_grad():
                pytorch_result = model(data)
            if isinstance(pytorch_result, (tuple, list)):
                pytorch_result = pytorch_result[0]
            pytorch_result = pytorch_result.detach().numpy()

            # get onnx output
            sess = rt.InferenceSession(output_file)
            onnx_result = sess.run(None, {
                'cat_input': data.detach().numpy(),
            })
            # only concern pred_alpha value
            if isinstance(onnx_result, (tuple, list)):
                onnx_result = onnx_result[0]

            # check the numerical value
            assert np.allclose(
                pytorch_result, onnx_result
            ), 'The outputs are different between Pytorch and ONNX'
            print('The numerical values are same between Pytorch and ONNX')
    finally:
        # Do not leave the caller's model permanently patched.
        model.forward = origin_forward
def pytorch2onnx(model: nn.Module,
                 model_type: str,
                 img_path: str,
                 verbose: bool = False,
                 show: bool = False,
                 opset_version: int = 11,
                 output_file: str = 'tmp.onnx',
                 verify: bool = False,
                 dynamic_export: bool = False,
                 device_id: int = 0):
    """Export PyTorch model to ONNX model and verify the outputs are same
    between PyTorch and ONNX.

    The model is moved to ``cuda:<device_id>``, put in eval mode and passed
    through ``_convert_batchnorm`` before tracing. During export,
    ``model.forward`` is temporarily bound as follows:

    * ``'det'``: bound to ``model.simple_test``.
    * any other type: bound to ``model.forward`` with ``return_loss=False``.

    The original forward is restored right after export, including when
    export raises.

    Args:
        model (nn.Module): PyTorch model we want to export.
        model_type (str): Model type, ``'det'`` (detection) or ``'recog'``
            (recognition) model.
        img_path (str): We need to use this input to execute the model.
        verbose (bool): Whether print the computation graph. Default: False.
        show (bool): Whether visualize final results. Only used when
            ``verify`` is True. Default: False.
        opset_version (int): The onnx op version. Default: 11.
        output_file (string): The path to where we store the output ONNX
            model. Default: `tmp.onnx`.
        verify (bool): Whether compare the outputs between PyTorch and ONNX.
            The result is printed, not raised. Default: False.
        dynamic_export (bool): Whether apply dynamic export. Default: False.
        device_id (int): CUDA device id to place model and data. Default: 0.
    """
    device = torch.device(type='cuda', index=device_id)
    model.to(device).eval()
    _convert_batchnorm(model)

    # prepare inputs
    mm_inputs = _prepare_data(cfg=model.cfg, imgs=img_path)
    imgs = mm_inputs.pop('img')
    img_metas = mm_inputs.pop('img_metas')

    if isinstance(imgs, list):
        imgs = imgs[0]

    img_list = [img[None, :].to(device) for img in imgs]

    # pytorch has some bug in pytorch1.3, we have to fix it
    # by replacing these existing op
    register_extra_symbolics(opset_version)

    dynamic_axes = None
    if dynamic_export and model_type == 'det':
        dynamic_axes = {
            'input': {
                0: 'batch',
                2: 'height',
                3: 'width'
            },
            'output': {
                0: 'batch',
                2: 'height',
                3: 'width'
            }
        }
    elif dynamic_export and model_type == 'recog':
        dynamic_axes = {
            'input': {
                0: 'batch',
                3: 'width'
            },
            'output': {
                0: 'batch',
                1: 'seq_len',
                2: 'num_classes'
            }
        }

    origin_forward = model.forward
    if model_type == 'det':
        model.forward = partial(
            model.simple_test, img_metas=img_metas, rescale=True)
    else:
        model.forward = partial(
            model.forward,
            img_metas=img_metas,
            return_loss=False,
            rescale=True)
    try:
        with torch.no_grad():
            torch.onnx.export(
                model, (img_list[0], ),
                output_file,
                input_names=['input'],
                output_names=['output'],
                export_params=True,
                keep_initializers_as_inputs=False,
                verbose=verbose,
                opset_version=opset_version,
                dynamic_axes=dynamic_axes)
    finally:
        # Restore even when verify is False or export raises, so the caller's
        # model is never left with a patched forward.
        model.forward = origin_forward
    print(f'Successfully exported ONNX model: {output_file}')

    if not verify:
        return

    # check by onnx
    import onnx
    onnx_model = onnx.load(output_file)
    onnx.checker.check_model(onnx_model)

    scale_factor = (0.5, 0.5) if model_type == 'det' else (1, 0.5)
    if dynamic_export:
        # scale image for dynamic shape test
        img_list = [
            nn.functional.interpolate(_, scale_factor=scale_factor)
            for _ in img_list
        ]
        if model_type == 'det':
            # `scale_factor * 2` is tuple repetition (four entries).
            img_metas[0][0]['scale_factor'] = img_metas[0][0][
                'scale_factor'] * (scale_factor * 2)

    # check the numerical value
    # get pytorch output
    with torch.no_grad():
        pytorch_out = model.simple_test(
            img_list[0], img_metas[0], rescale=True)

    # get onnx output
    if model_type == 'det':
        onnx_model = ONNXRuntimeDetector(output_file, model.cfg, device_id)
    else:
        onnx_model = ONNXRuntimeRecognizer(output_file, model.cfg, device_id)
    onnx_out = onnx_model.simple_test(img_list[0], img_metas[0], rescale=True)

    def _arrays_close(onnx_value, pytorch_value):
        # Require identical shapes: np.allclose would otherwise broadcast
        # (e.g. (1,) vs (3,)) or raise on incompatible shapes.
        onnx_arr = np.array(onnx_value)
        pytorch_arr = np.array(pytorch_value)
        return onnx_arr.shape == pytorch_arr.shape and np.allclose(
            onnx_arr, pytorch_arr, rtol=1e-4, atol=1e-4)

    # compare results; a different number of results also counts as
    # different (a bare zip would silently ignore the extras)
    if model_type == 'recog':
        same = len(onnx_out) == len(pytorch_out) and all(
            onnx_res['text'] == pytorch_res['text']
            and _arrays_close(onnx_res['score'], pytorch_res['score'])
            for onnx_res, pytorch_res in zip(onnx_out, pytorch_out))
    else:
        onnx_boundaries = onnx_out[0]['boundary_result']
        pytorch_boundaries = pytorch_out[0]['boundary_result']
        same = len(onnx_boundaries) == len(pytorch_boundaries) and all(
            _arrays_close(onnx_res, pytorch_res)
            for onnx_res, pytorch_res in zip(onnx_boundaries,
                                             pytorch_boundaries))
    same_diff = 'same' if same else 'different'
    print('The outputs are {} between PyTorch and ONNX'.format(same_diff))

    if show:
        onnx_img = onnx_model.show_result(
            img_path, onnx_out[0], out_file='onnx.jpg', show=False)
        pytorch_img = model.show_result(
            img_path, pytorch_out[0], out_file='pytorch.jpg', show=False)
        # Fall back to the raw image when show_result returns None.
        if onnx_img is None:
            onnx_img = cv2.imread(img_path)
        if pytorch_img is None:
            pytorch_img = cv2.imread(img_path)

        cv2.imshow('PyTorch', pytorch_img)
        cv2.imshow('ONNXRuntime', onnx_img)
        cv2.waitKey()
def pytorch2onnx(model,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False):
    """Export Pytorch model to ONNX model and verify the outputs are same
    between Pytorch and ONNX.

    A dummy batch is built with ``_demo_mm_inputs(input_shape, num_classes)``.
    ``num_classes`` is read from ``model.decode_head``, or from its last
    element when it is an ``nn.ModuleList``. During export, ``model.forward``
    is bound to the dummy image metas with ``return_loss=False``. The
    original forward is restored afterwards, including when export raises.

    Args:
        model (nn.Module): Pytorch model we want to export.
        input_shape (tuple): Use this input shape to construct
            the corresponding dummy input and execute the model.
        opset_version (int): The onnx op version. Default: 11.
        show (bool): Whether print the computation graph. Default: False.
        output_file (string): The path to where we store the output ONNX
            model. Default: `tmp.onnx`.
        verify (bool): Whether compare the outputs between Pytorch and ONNX.
            Default: False.

    Raises:
        ValueError: If ``verify`` is True and the outputs are not
            ``np.allclose``.
    """
    model.cpu().eval()

    decode_head = model.decode_head
    if isinstance(decode_head, nn.ModuleList):
        decode_head = decode_head[-1]
    num_classes = decode_head.num_classes

    mm_inputs = _demo_mm_inputs(input_shape, num_classes)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    img_list = [img[None, :] for img in imgs]
    img_meta_list = [[img_meta] for img_meta in img_metas]

    # replace original forward function
    origin_forward = model.forward
    model.forward = partial(
        model.forward, img_metas=img_meta_list, return_loss=False)
    try:
        register_extra_symbolics(opset_version)
        with torch.no_grad():
            torch.onnx.export(
                model, (img_list, ),
                output_file,
                export_params=True,
                keep_initializers_as_inputs=True,
                verbose=show,
                opset_version=opset_version)
        print(f'Successfully exported ONNX model: {output_file}')
    finally:
        # Never leave the caller's model with the patched forward.
        model.forward = origin_forward

    if verify:
        # check by onnx
        import onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        # check the numerical value
        # get pytorch output; no autograd graph is needed for inference
        with torch.no_grad():
            pytorch_result = model(
                img_list, img_meta_list, return_loss=False)[0]

        # get onnx output. keep_initializers_as_inputs=True lists the
        # weights as graph inputs too, so the real feed input is the one
        # that is not an initializer.
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert (len(net_feed_input) == 1)
        sess = rt.InferenceSession(output_file)
        onnx_result = sess.run(
            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0]
        if not np.allclose(pytorch_result, onnx_result):
            raise ValueError(
                'The outputs are different between Pytorch and ONNX')
        print('The outputs are same between Pytorch and ONNX')
def pytorch2onnx(model,
                 mm_inputs,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 dynamic_export=False):
    """Export Pytorch model to ONNX model and verify the outputs are same
    between Pytorch and ONNX.

    During export, ``model.forward`` is bound to the image metas with
    ``return_loss=False`` and ``rescale=True``. The original forward is
    restored afterwards, including when export raises. ``mm_inputs`` is
    read but not modified.

    Args:
        model (nn.Module): Pytorch model we want to export.
        mm_inputs (dict): Contain the input tensors (``'imgs'``) and
            img_metas information (``'img_metas'``).
        opset_version (int): The onnx op version. Default: 11.
        show (bool): Whether print the computation graph and, when
            ``verify`` is True, display the PyTorch and ONNXRuntime
            segmentation results. Default: False.
        output_file (string): The path to where we store the output ONNX
            model. Default: `tmp.onnx`.
        verify (bool): Whether compare the outputs between Pytorch and ONNX.
            Default: False.
        dynamic_export (bool): Whether to export ONNX with dynamic axis.
            When ``model.test_cfg.mode == 'slide'`` only the batch axis is
            dynamic. Default: False.

    Raises:
        AssertionError: If ``verify`` is True and the outputs differ
            (raised by ``np.testing.assert_allclose``).
    """
    model.cpu().eval()
    test_mode = model.test_cfg.mode
    if isinstance(model.decode_head, nn.ModuleList):
        num_classes = model.decode_head[-1].num_classes
    else:
        num_classes = model.decode_head.num_classes

    # Index instead of pop() so the caller's dict is left intact and the
    # function can be called again with the same mm_inputs.
    imgs = mm_inputs['imgs']
    img_metas = mm_inputs['img_metas']

    img_list = [img[None, :] for img in imgs]
    img_meta_list = [[img_meta] for img_meta in img_metas]
    # update img_meta
    img_list, img_meta_list = _update_input_img(img_list, img_meta_list)

    dynamic_axes = None
    if dynamic_export:
        if test_mode == 'slide':
            dynamic_axes = {'input': {0: 'batch'}, 'output': {1: 'batch'}}
        else:
            dynamic_axes = {
                'input': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'
                },
                'output': {
                    1: 'batch',
                    2: 'height',
                    3: 'width'
                }
            }

    # replace original forward function
    origin_forward = model.forward
    model.forward = partial(
        model.forward, img_metas=img_meta_list, return_loss=False, rescale=True)
    try:
        register_extra_symbolics(opset_version)
        with torch.no_grad():
            torch.onnx.export(
                model, (img_list, ),
                output_file,
                input_names=['input'],
                output_names=['output'],
                export_params=True,
                keep_initializers_as_inputs=False,
                verbose=show,
                opset_version=opset_version,
                dynamic_axes=dynamic_axes)
        print(f'Successfully exported ONNX model: {output_file}')
    finally:
        # Never leave the caller's model with the patched forward.
        model.forward = origin_forward

    if not verify:
        return

    # check by onnx
    import onnx
    onnx_model = onnx.load(output_file)
    onnx.checker.check_model(onnx_model)

    if dynamic_export and test_mode == 'whole':
        # scale image for dynamic shape test
        img_list = [resize(_, scale_factor=1.5) for _ in img_list]
        # concate flip image for batch test
        flip_img_list = [_.flip(-1) for _ in img_list]
        img_list = [
            torch.cat((ori_img, flip_img), 0)
            for ori_img, flip_img in zip(img_list, flip_img_list)
        ]

        # update img_meta
        img_list, img_meta_list = _update_input_img(
            img_list, img_meta_list, test_mode == 'whole')

    # check the numerical value
    # get pytorch output
    with torch.no_grad():
        pytorch_result = model(img_list, img_meta_list, return_loss=False)
        pytorch_result = np.stack(pytorch_result, 0)

    # get onnx output: the only graph input that is not an initializer
    input_all = [node.name for node in onnx_model.graph.input]
    input_initializer = [node.name for node in onnx_model.graph.initializer]
    net_feed_input = list(set(input_all) - set(input_initializer))
    assert (len(net_feed_input) == 1)
    sess = rt.InferenceSession(output_file)
    onnx_result = sess.run(
        None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0]

    # show segmentation results
    if show:
        import cv2
        import os.path as osp
        img = img_meta_list[0][0]['filename']
        if not osp.exists(img):
            # No image file on disk: render the first input tensor instead
            # (first three channels, CHW -> HWC, scaled by 255).
            img = imgs[0][:3, ...].permute(1, 2, 0) * 255
            img = img.detach().numpy().astype(np.uint8)
            ori_shape = img.shape[:2]
        else:
            ori_shape = LoadImage()({'img': img})['ori_shape']

        # resize onnx_result to ori_shape
        onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
                                  (ori_shape[1], ori_shape[0]))
        show_result_pyplot(
            model,
            img, (onnx_result_, ),
            palette=model.PALETTE,
            block=False,
            title='ONNXRuntime',
            opacity=0.5)

        # resize pytorch_result to ori_shape
        pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8),
                                     (ori_shape[1], ori_shape[0]))
        show_result_pyplot(
            model,
            img, (pytorch_result_, ),
            title='PyTorch',
            palette=model.PALETTE,
            opacity=0.5)

    # compare results; label maps are normalised by num_classes
    np.testing.assert_allclose(
        pytorch_result.astype(np.float32) / num_classes,
        onnx_result.astype(np.float32) / num_classes,
        rtol=1e-5,
        atol=1e-5,
        err_msg='The outputs are different between Pytorch and ONNX')
    print('The outputs are same between Pytorch and ONNX')
def pytorch2onnx(model,
                 input,
                 model_type,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 dynamic_export=False):
    """Export Pytorch model to ONNX model and verify the outputs are same
    between Pytorch and ONNX.

    ``model.forward`` is temporarily replaced by ``model.forward_dummy`` for
    the export and the optional verification. It is restored before
    returning, including when an exception is raised.

    Args:
        model (nn.Module): Pytorch model we want to export. Must provide a
            ``forward_dummy`` method.
        input (dict): We need to use this input to execute the model. For
            ``'mattor'`` it must contain ``'merged'`` and ``'trimap'``; for
            ``'restorer'`` it must contain ``'lq'``. The tensors have no
            batch dimension; one is added here.
        model_type (str): Either ``'mattor'`` or ``'restorer'``.
        opset_version (int): The onnx op version. Default: 11.
        show (bool): Whether print the computation graph and, when
            ``verify`` is True, display the PyTorch and ONNXRuntime outputs
            with OpenCV. Default: False.
        output_file (string): The path to where we store the output ONNX
            model. Default: `tmp.onnx`.
        verify (bool): Whether compare the outputs between Pytorch and ONNX.
            Default: False.
        dynamic_export (bool): Whether to mark dims 0, 2 and 3 of ``input``
            and ``output`` as dynamic (batch, height, width). Default: False.

    Raises:
        ValueError: If ``model_type`` is not supported.
        AssertionError: If ``verify`` is True and the outputs differ.
    """
    # Fail fast: any other value would leave the example tensor undefined.
    if model_type not in ('mattor', 'restorer'):
        raise ValueError("model_type must be 'mattor' or 'restorer', "
                         f'but got {model_type!r}')

    model.cpu().eval()

    if model_type == 'mattor':
        merged = input['merged'].unsqueeze(0)
        trimap = input['trimap'].unsqueeze(0)
        data = torch.cat((merged, trimap), 1)
    else:  # 'restorer'
        data = input['lq'].unsqueeze(0)

    origin_forward = model.forward
    model.forward = model.forward_dummy
    try:
        # pytorch has some bug in pytorch1.3, we have to fix it
        # by replacing these existing op
        register_extra_symbolics(opset_version)
        dynamic_axes = None
        if dynamic_export:
            dynamic_axes = {
                'input': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'
                },
                'output': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'
                }
            }
        with torch.no_grad():
            torch.onnx.export(
                model,
                data,
                output_file,
                input_names=['input'],
                output_names=['output'],
                export_params=True,
                keep_initializers_as_inputs=False,
                verbose=show,
                opset_version=opset_version,
                dynamic_axes=dynamic_axes)
        print(f'Successfully exported ONNX model: {output_file}')

        if verify:
            # check by onnx
            onnx_model = onnx.load(output_file)
            onnx.checker.check_model(onnx_model)

            if dynamic_export:
                # scale image for dynamic shape test
                data = torch.nn.functional.interpolate(data, scale_factor=1.1)

                # concate flip image for batch test
                flip_data = data.flip(-1)
                data = torch.cat((data, flip_data), 0)

            # get pytorch output, only concern pred_alpha
            with torch.no_grad():
                pytorch_result = model(data)
            if isinstance(pytorch_result, (tuple, list)):
                pytorch_result = pytorch_result[0]
            pytorch_result = pytorch_result.detach().numpy()
            # get onnx output
            sess = rt.InferenceSession(output_file)
            onnx_result = sess.run(None, {
                'input': data.detach().numpy(),
            })
            # only concern pred_alpha value
            if isinstance(onnx_result, (tuple, list)):
                onnx_result = onnx_result[0]

            if show:
                # first sample, CHW -> HWC, clipped to [0, 1], channels
                # reversed for cv2.imshow
                pytorch_visualize = pytorch_result[0].transpose(1, 2, 0)
                pytorch_visualize = np.clip(pytorch_visualize, 0,
                                            1)[:, :, ::-1]
                onnx_visualize = onnx_result[0].transpose(1, 2, 0)
                onnx_visualize = np.clip(onnx_visualize, 0, 1)[:, :, ::-1]

                cv2.imshow('PyTorch', pytorch_visualize)
                cv2.imshow('ONNXRuntime', onnx_visualize)
                cv2.waitKey()

            # check the numerical value
            assert np.allclose(
                pytorch_result, onnx_result, rtol=1e-5, atol=1e-5
            ), 'The outputs are different between Pytorch and ONNX'
            print('The numerical values are same between Pytorch and ONNX')
    finally:
        # Do not leave the caller's model permanently patched.
        model.forward = origin_forward
def pytorch2onnx(model,
                 input_shape,
                 opset_version=11,
                 dynamic_export=False,
                 show=False,
                 output_file='tmp.onnx',
                 do_simplify=False,
                 verify=False):
    """Export Pytorch model to ONNX model and verify the outputs are same
    between Pytorch and ONNX.

    ``num_classes`` for the dummy inputs is taken from ``model.head``. If the
    head does not have it, it is taken from ``model.backbone`` when that value
    is positive. During export, ``model.forward`` is bound with
    ``img_metas={}`` and ``return_loss=False``. The original forward is
    restored afterwards, including when export raises.

    Args:
        model (nn.Module): Pytorch model we want to export.
        input_shape (tuple): Use this input shape to construct
            the corresponding dummy input and execute the model.
        opset_version (int): The onnx op version. Default: 11.
        dynamic_export (bool): Whether to mark dims 0, 2 and 3 of ``input``
            and dim 0 of ``probs`` as dynamic. Dynamic verification and
            simplification then use twice the spatial size of
            ``input_shape``. Default: False.
        show (bool): Whether print the computation graph. Default: False.
        output_file (string): The path to where we store the output ONNX
            model. Default: `tmp.onnx`.
        do_simplify (bool): Whether to simplify the exported file with
            onnx-simplifier (requires version >= 0.3.0). Default: False.
        verify (bool): Whether compare the outputs between Pytorch and ONNX.
            Default: False.

    Raises:
        AttributeError: If ``num_classes`` is found in neither the head nor
            the backbone.
        ValueError: If ``verify`` is True and the outputs are not
            ``np.allclose``.
    """
    model.cpu().eval()

    if hasattr(model.head, 'num_classes'):
        num_classes = model.head.num_classes
    # Some backbones use `num_classes=-1` to disable top classifier.
    elif getattr(model.backbone, 'num_classes', -1) > 0:
        num_classes = model.backbone.num_classes
    else:
        raise AttributeError('Cannot find "num_classes" in both head and '
                             'backbone, please check the config file.')

    mm_inputs = _demo_mm_inputs(input_shape, num_classes)
    imgs = mm_inputs.pop('imgs')
    img_list = [img[None, :] for img in imgs]

    register_extra_symbolics(opset_version)

    # support dynamic shape export
    if dynamic_export:
        # PyTorch image tensors are NCHW: axis 2 is height, axis 3 is width.
        dynamic_axes = {
            'input': {
                0: 'batch',
                2: 'height',
                3: 'width'
            },
            'probs': {
                0: 'batch'
            }
        }
    else:
        dynamic_axes = {}

    # replace original forward function
    origin_forward = model.forward
    model.forward = partial(model.forward, img_metas={}, return_loss=False)
    try:
        with torch.no_grad():
            torch.onnx.export(
                model, (img_list, ),
                output_file,
                input_names=['input'],
                output_names=['probs'],
                export_params=True,
                keep_initializers_as_inputs=True,
                dynamic_axes=dynamic_axes,
                verbose=show,
                opset_version=opset_version)
        print(f'Successfully exported ONNX model: {output_file}')
    finally:
        # Never leave the caller's model with the patched forward.
        model.forward = origin_forward

    if do_simplify:
        import onnx
        import onnxsim
        from mmcv import digit_version

        min_required_version = '0.3.0'
        # Check the installed onnx-simplifier, not mmcv.
        assert digit_version(onnxsim.__version__) >= digit_version(
            min_required_version
        ), f'Requires to install onnx-simplify>={min_required_version}'

        # Separate name so `input_shape` stays untouched for verification.
        if dynamic_export:
            simplify_shape = (input_shape[0], input_shape[1],
                              input_shape[2] * 2, input_shape[3] * 2)
        else:
            simplify_shape = (input_shape[0], input_shape[1], input_shape[2],
                              input_shape[3])
        simplify_imgs = _demo_mm_inputs(simplify_shape,
                                        num_classes).pop('imgs')
        input_dic = {'input': simplify_imgs.detach().cpu().numpy()}
        input_shape_dic = {'input': list(simplify_shape)}

        model_opt, check_ok = onnxsim.simplify(
            output_file,
            input_shapes=input_shape_dic,
            input_data=input_dic,
            dynamic_input_shape=dynamic_export)
        if check_ok:
            onnx.save(model_opt, output_file)
            print(f'Successfully simplified ONNX model: {output_file}')
        else:
            print('Failed to simplify ONNX model.')

    if verify:
        # check by onnx
        import onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        # test the dynamic model at twice the spatial size
        if dynamic_export:
            dynamic_test_inputs = _demo_mm_inputs(
                (input_shape[0], input_shape[1], input_shape[2] * 2,
                 input_shape[3] * 2), num_classes)
            imgs = dynamic_test_inputs.pop('imgs')
            img_list = [img[None, :] for img in imgs]

        # check the numerical value
        # get pytorch output; no autograd graph is needed for inference
        with torch.no_grad():
            pytorch_result = model(
                img_list, img_metas={}, return_loss=False)[0]

        # get onnx output. keep_initializers_as_inputs=True lists the
        # weights as graph inputs too, so the real feed input is the one
        # that is not an initializer.
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert (len(net_feed_input) == 1)
        sess = rt.InferenceSession(output_file)
        onnx_result = sess.run(
            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0]
        if not np.allclose(pytorch_result, onnx_result):
            raise ValueError(
                'The outputs are different between Pytorch and ONNX')
        print('The outputs are same between Pytorch and ONNX')
# # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import mmcv import torch from mmcv.onnx import register_extra_symbolics from mmseg.models.segmentors import EncoderDecoder register_extra_symbolics(opset=11) def _convert_batchnorm(module): for name, child in module.named_children(): if isinstance(child, torch.nn.SyncBatchNorm): new_child = torch.nn.BatchNorm2d(child.num_features, child.eps, child.momentum, child.affine, child.track_running_stats) if child.affine: new_child.weight.data = child.weight.data.clone().detach() new_child.bias.data = child.bias.data.clone().detach() # keep requires_grad unchanged new_child.weight.requires_grad = child.weight.requires_grad new_child.bias.requires_grad = child.bias.requires_grad new_child.running_mean = child.running_mean