def __init__(self,
             embed_dims,
             num_heads,
             attn_drop=0.,
             proj_drop=0.,
             dropout_layer=None,
             batch_first=True,
             qkv_bias=True,
             norm_cfg=dict(type='LN'),
             sr_ratio=1,
             init_cfg=None):
    """Build the attention module with an optional spatial-reduction stage.

    Args:
        embed_dims (int): Embedding dimension; forwarded to the parent class
            and used as in/out channels of the reduction conv.
        num_heads (int): Number of heads; forwarded to the parent class.
        attn_drop (float): Forwarded to the parent class.
        proj_drop (float): Forwarded to the parent class.
        dropout_layer (dict | None): Forwarded to the parent class.
        batch_first (bool): Forwarded to the parent class.
        qkv_bias (bool): Forwarded to the parent class as ``bias``.
        norm_cfg (dict): Config passed to ``build_norm_layer`` for the norm
            applied after the reduction conv. Only used when
            ``sr_ratio > 1``.
        sr_ratio (int): Kernel size and stride of the reduction conv. When
            it is not greater than 1, no ``sr``/``norm`` submodules are built.
        init_cfg (dict | None): Forwarded to the parent class.
    """
    super().__init__(
        embed_dims,
        num_heads,
        attn_drop,
        proj_drop,
        batch_first=batch_first,
        dropout_layer=dropout_layer,
        bias=qkv_bias,
        init_cfg=init_cfg)

    self.sr_ratio = sr_ratio
    if sr_ratio > 1:
        self.sr = Conv2d(
            in_channels=embed_dims,
            out_channels=embed_dims,
            kernel_size=sr_ratio,
            stride=sr_ratio)
        # The ret[0] of build_norm_layer is norm name.
        self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

    # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
    from mmdet import mmcv_version, digit_version
    if mmcv_version < digit_version('1.3.17'):
        # Each literal ends with a space so the concatenated message reads
        # correctly.
        warnings.warn('The legacy version of forward function in '
                      'SpatialReductionAttention is deprecated in '
                      'mmcv>=1.3.17 and will no longer be supported in the '
                      'future. Please upgrade your mmcv.')
        self.forward = self.legacy_forward
import mmcv import numpy as np import pytest import torch from mmcv.cnn import Scale from mmdet import digit_version from mmdet.models import build_detector from mmdet.models.dense_heads import (FCOSHead, FSAFHead, RetinaHead, SSDHead, YOLOV3Head) from .utils import ort_validate data_path = osp.join(osp.dirname(__file__), 'data') if digit_version(torch.__version__) <= digit_version('1.5.0'): pytest.skip('ort backend does not support version below 1.5.0', allow_module_level=True) def test_cascade_onnx_export(): config_path = './configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py' cfg = mmcv.Config.fromfile(config_path) model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) with torch.no_grad(): model.forward = partial(model.forward, img_metas=[[dict()]]) dynamic_axes = { 'input_img': { 0: 'batch',
def pytorch2onnx(model,
                 input_img,
                 input_shape,
                 normalize_cfg,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 test_img=None,
                 do_simplify=False,
                 dynamic_export=None,
                 skip_postprocess=False):
    """Export a detector to ONNX, optionally simplifying and verifying it.

    Args:
        model: Detector to export. Its ``forward`` is temporarily replaced
            during export and restored afterwards.
        input_img (str): Image used to build the example export input.
        input_shape (tuple): Example input shape ``(1, 3, h, w)``.
        normalize_cfg (dict): Normalization config for the example input.
        opset_version (int): ONNX opset passed to ``torch.onnx.export``.
        show (bool): Verbose export; during verification, show result
            windows instead of writing ``show-*.png`` files.
        output_file (str): Path of the ONNX file to write.
        verify (bool): Compare ONNX Runtime outputs against PyTorch outputs.
            The inputs are moved with ``.cuda()``, so a CUDA device is
            required.
        test_img (str | None): Optional image used for verification instead
            of ``input_img``.
        do_simplify (bool): Run onnx-simplifier (>=0.3.0) and overwrite
            ``output_file`` with the simplified model when its check passes.
        dynamic_export (bool | None): Export with dynamic batch/spatial
            axes and verify at a larger input size plus a 2-image batch.
        skip_postprocess (bool): Export ``model.forward_dummy`` (no post
            processing) and return immediately without verification.
    """
    input_config = {
        'input_shape': input_shape,
        'input_path': input_img,
        'normalize_cfg': normalize_cfg
    }
    # prepare input
    one_img, one_meta = preprocess_example_input(input_config)
    img_list, img_meta_list = [one_img], [[one_meta]]

    if skip_postprocess:
        warnings.warn('Not all models support export onnx without post '
                      'process, especially two stage detectors!')
        model.forward = model.forward_dummy
        torch.onnx.export(
            model,
            one_img,
            output_file,
            input_names=['input'],
            export_params=True,
            keep_initializers_as_inputs=True,
            do_constant_folding=True,
            verbose=show,
            opset_version=opset_version)

        print(f'Successfully exported ONNX model without '
              f'post process: {output_file}')
        return

    output_names = ['dets', 'labels']
    if model.with_mask:
        output_names.append('masks')
    input_name = 'input'
    dynamic_axes = None
    if dynamic_export:
        # NOTE(review): for the (1, 3, h, w) input, dim 2 is h and dim 3 is
        # w, so these symbolic names are swapped. They are only labels and
        # are kept unchanged for compatibility with existing exports.
        dynamic_axes = {
            input_name: {
                0: 'batch',
                2: 'width',
                3: 'height'
            },
            'dets': {
                0: 'batch',
                1: 'num_dets',
            },
            'labels': {
                0: 'batch',
                1: 'num_dets',
            },
        }
        if model.with_mask:
            dynamic_axes['masks'] = {0: 'batch', 1: 'num_dets'}

    # replace original forward function; restore it even if export fails
    origin_forward = model.forward
    model.forward = partial(
        model.forward,
        img_metas=img_meta_list,
        return_loss=False,
        rescale=False)
    try:
        torch.onnx.export(
            model,
            img_list,
            output_file,
            input_names=[input_name],
            output_names=output_names,
            export_params=True,
            keep_initializers_as_inputs=True,
            do_constant_folding=True,
            verbose=show,
            opset_version=opset_version,
            dynamic_axes=dynamic_axes)
    finally:
        model.forward = origin_forward

    # get the custom op path
    ort_custom_op_path = ''
    try:
        from mmcv.ops import get_onnxruntime_op_path
        ort_custom_op_path = get_onnxruntime_op_path()
    except (ImportError, ModuleNotFoundError):
        warnings.warn('If input model has custom op from mmcv, '
                      'you may have to build mmcv with ONNXRuntime from '
                      'source.')

    if do_simplify:
        import onnxsim

        from mmdet import digit_version

        min_required_version = '0.3.0'
        assert digit_version(onnxsim.__version__) >= digit_version(
            min_required_version
        ), f'Requires to install onnx-simplify>={min_required_version}'

        input_dic = {'input': img_list[0].detach().cpu().numpy()}
        # onnxsim.simplify returns the simplified model rather than writing
        # it, so it must be saved explicitly.
        model_opt, check_ok = onnxsim.simplify(
            output_file,
            input_data=input_dic,
            custom_lib=ort_custom_op_path,
            dynamic_input_shape=bool(dynamic_export))
        if check_ok:
            onnx.save(model_opt, output_file)
            print(f'Successfully simplified ONNX model: {output_file}')
        else:
            warnings.warn('Failed to simplify ONNX model.')
    print(f'Successfully exported ONNX model: {output_file}')

    if verify:
        # check by onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        # wrap onnx model (third argument presumably a device id — confirm)
        onnx_model = ONNXRuntimeDetector(output_file, model.CLASSES, 0)
        if dynamic_export:
            # scale up to test dynamic shape
            h, w = [int((_ * 1.5) // 32 * 32) for _ in input_shape[2:]]
            h, w = min(1344, h), min(1344, w)
            input_config['input_shape'] = (1, 3, h, w)

        # verify on the user-provided image when one is given
        if test_img is not None:
            input_config['input_path'] = test_img

        # prepare input once again
        one_img, one_meta = preprocess_example_input(input_config)
        img_list, img_meta_list = [one_img], [[one_meta]]

        # get pytorch output
        with torch.no_grad():
            pytorch_results = model(
                img_list,
                img_metas=img_meta_list,
                return_loss=False,
                rescale=True)[0]

        img_list = [_.cuda().contiguous() for _ in img_list]
        if dynamic_export:
            # feed a 2-image batch (original + horizontally flipped)
            img_list = img_list + [_.flip(-1).contiguous() for _ in img_list]
            img_meta_list = img_meta_list * 2
        # get onnx output
        onnx_results = onnx_model(
            img_list, img_metas=img_meta_list, return_loss=False)[0]

        # visualize predictions
        score_thr = 0.3
        if show:
            out_file_ort, out_file_pt = None, None
        else:
            out_file_ort, out_file_pt = 'show-ort.png', 'show-pt.png'

        show_img = one_meta['show_img']
        model.show_result(
            show_img,
            pytorch_results,
            score_thr=score_thr,
            show=True,
            win_name='PyTorch',
            out_file=out_file_pt)
        onnx_model.show_result(
            show_img,
            onnx_results,
            score_thr=score_thr,
            show=True,
            win_name='ONNXRuntime',
            out_file=out_file_ort)

        # compare a part of result
        if model.with_mask:
            compare_pairs = list(zip(onnx_results, pytorch_results))
        else:
            compare_pairs = [(onnx_results, pytorch_results)]
        err_msg = 'The numerical values are different between Pytorch' + \
                  ' and ONNX, but it does not necessarily mean the' + \
                  ' exported ONNX model is problematic.'
        # check the numerical value
        for onnx_res, pytorch_res in compare_pairs:
            for o_res, p_res in zip(onnx_res, pytorch_res):
                np.testing.assert_allclose(
                    o_res, p_res, rtol=1e-03, atol=1e-05, err_msg=err_msg)
        print('The numerical values are the same between Pytorch and ONNX')
def test_general_data():
    """Exercise the GeneralData container API.

    Covers construction, meta/data fields, immutability rules, dict-like
    access, and tensor/device conversions.
    """

    # test init
    meta_info = dict(
        img_size=[256, 256],
        path='dadfaff',
        scale_factor=np.array([1.5, 1.5]),
        img_shape=torch.rand(4))

    data = dict(
        bboxes=torch.rand(4, 4),
        labels=torch.rand(4),
        masks=np.random.rand(4, 2, 2))

    instance_data = GeneralData(meta_info=meta_info)
    assert 'img_size' in instance_data
    assert instance_data.img_size == [256, 256]
    assert instance_data['img_size'] == [256, 256]
    assert 'path' in instance_data
    assert instance_data.path == 'dadfaff'

    # test nice_repr
    repr_instance_data = instance_data.new(data=data)
    nice_repr = str(repr_instance_data)
    for line in nice_repr.split('\n'):
        if 'masks' in line:
            assert 'shape' in line
            assert '(4, 2, 2)' in line
        if 'bboxes' in line:
            assert 'shape' in line
            assert 'torch.Size([4, 4])' in line
        if 'path' in line:
            assert 'dadfaff' in line
        if 'scale_factor' in line:
            assert '[1.5 1.5]' in line

    instance_data = GeneralData(
        meta_info=meta_info, data=dict(bboxes=torch.rand(5)))
    assert 'bboxes' in instance_data
    assert len(instance_data.bboxes) == 5

    # data should be a dict
    with pytest.raises(AssertionError):
        GeneralData(data=1)

    # test set data
    instance_data = GeneralData()
    instance_data.set_data(data)
    assert 'bboxes' in instance_data
    assert len(instance_data.bboxes) == 4
    assert 'masks' in instance_data
    assert len(instance_data.masks) == 4
    # data should be a dict
    with pytest.raises(AssertionError):
        instance_data.set_data(data=1)

    # test set_meta
    instance_data = GeneralData()
    instance_data.set_meta_info(meta_info)
    assert 'img_size' in instance_data
    assert instance_data.img_size == [256, 256]
    assert instance_data['img_size'] == [256, 256]
    assert 'path' in instance_data
    assert instance_data.path == 'dadfaff'
    # can skip same value when overwrite
    instance_data.set_meta_info(meta_info)

    # meta should be a dict
    with pytest.raises(AssertionError):
        instance_data.set_meta_info(meta_info='fjhka')

    # attribute in `_meta_info_field` is immutable once initialized
    instance_data.set_meta_info(meta_info)
    # meta should be immutable
    with pytest.raises(KeyError):
        instance_data.set_meta_info(dict(img_size=[254, 251]))
    with pytest.raises(KeyError):
        duplicate_meta_info = copy.deepcopy(meta_info)
        duplicate_meta_info['path'] = 'dada'
        instance_data.set_meta_info(duplicate_meta_info)
    with pytest.raises(KeyError):
        duplicate_meta_info = copy.deepcopy(meta_info)
        duplicate_meta_info['scale_factor'] = np.array([1.5, 1.6])
        instance_data.set_meta_info(duplicate_meta_info)

    # test new_instance_data
    instance_data = GeneralData(meta_info)
    new_instance_data = instance_data.new()
    for k, v in instance_data.meta_info_items():
        assert k in new_instance_data
        assert _equal(v, new_instance_data[k])

    instance_data = GeneralData(meta_info, data=data)
    temp_meta = copy.deepcopy(meta_info)
    temp_data = copy.deepcopy(data)
    temp_data['time'] = '12212'
    temp_meta['img_norm'] = np.random.random(3)

    new_instance_data = instance_data.new(meta_info=temp_meta, data=temp_data)
    for k, v in new_instance_data.meta_info_items():
        if k in instance_data:
            assert _equal(v, instance_data[k])
        else:
            assert _equal(v, temp_meta[k])
            assert k == 'img_norm'

    for k, v in new_instance_data.items():
        if k in instance_data:
            assert _equal(v, instance_data[k])
        else:
            assert k == 'time'
            assert _equal(v, temp_data[k])

    # test keys
    instance_data = GeneralData(meta_info, data=dict(bboxes=10))
    assert 'bboxes' in instance_data.keys()
    instance_data.b = 10
    assert 'b' in instance_data

    # test meta keys
    instance_data = GeneralData(meta_info, data=dict(bboxes=10))
    assert 'path' in instance_data.meta_info_keys()
    assert len(instance_data.meta_info_keys()) == len(meta_info)
    instance_data.set_meta_info(dict(workdir='fafaf'))
    assert 'workdir' in instance_data
    assert len(instance_data.meta_info_keys()) == len(meta_info) + 1

    # test values
    instance_data = GeneralData(meta_info, data=dict(bboxes=10))
    assert 10 in instance_data.values()
    assert len(instance_data.values()) == 1

    # test meta values
    instance_data = GeneralData(meta_info, data=dict(bboxes=10))
    # torch 1.3 eq() can not compare str and tensor
    from mmdet import digit_version

    # Compare version against version: digit_version may return a tuple,
    # and a tuple cannot be ordered against a list literal.
    if digit_version(torch.__version__) >= digit_version('1.4'):
        assert 'dadfaff' in instance_data.meta_info_values()
    assert len(instance_data.meta_info_values()) == len(meta_info)

    # test items
    instance_data = GeneralData(data=data)
    for k, v in instance_data.items():
        assert k in data
        assert _equal(v, data[k])

    # test meta_info_items
    instance_data = GeneralData(meta_info=meta_info)
    for k, v in instance_data.meta_info_items():
        assert k in meta_info
        assert _equal(v, meta_info[k])

    # test __setattr__
    new_instance_data = GeneralData(data=data)
    new_instance_data.mask = torch.rand(3, 4, 5)
    new_instance_data.bboxes = torch.rand(2, 4)
    assert 'mask' in new_instance_data
    assert len(new_instance_data.mask) == 3
    assert len(new_instance_data.bboxes) == 2

    # test instance_data_field has been updated
    assert 'mask' in new_instance_data._data_fields
    assert 'bboxes' in new_instance_data._data_fields

    for k in data:
        assert k in new_instance_data._data_fields

    # '_meta_info_field', '_data_fields' is immutable.
    with pytest.raises(AttributeError):
        new_instance_data._data_fields = None
    with pytest.raises(AttributeError):
        new_instance_data._meta_info_fields = None
    with pytest.raises(AttributeError):
        del new_instance_data._data_fields
    with pytest.raises(AttributeError):
        del new_instance_data._meta_info_fields

    # key in _meta_info_field is immutable
    new_instance_data.set_meta_info(meta_info)
    with pytest.raises(KeyError):
        del new_instance_data.img_size
    with pytest.raises(KeyError):
        del new_instance_data.scale_factor
    for k in new_instance_data.meta_info_keys():
        with pytest.raises(AttributeError):
            new_instance_data[k] = None

    # test __delattr__
    # test key can be removed in instance_data_field
    assert 'mask' in new_instance_data._data_fields
    assert 'mask' in new_instance_data.keys()
    assert 'mask' in new_instance_data
    assert hasattr(new_instance_data, 'mask')
    del new_instance_data.mask
    assert 'mask' not in new_instance_data.keys()
    assert 'mask' not in new_instance_data
    assert 'mask' not in new_instance_data._data_fields
    assert not hasattr(new_instance_data, 'mask')

    # test __delitem__
    new_instance_data.mask = torch.rand(1, 2, 3)
    assert 'mask' in new_instance_data._data_fields
    assert 'mask' in new_instance_data
    assert hasattr(new_instance_data, 'mask')
    del new_instance_data['mask']
    assert 'mask' not in new_instance_data
    assert 'mask' not in new_instance_data._data_fields
    assert 'mask' not in new_instance_data.keys()
    assert not hasattr(new_instance_data, 'mask')

    # test __setitem__
    new_instance_data['mask'] = torch.rand(1, 2, 3)
    assert 'mask' in new_instance_data._data_fields
    assert 'mask' in new_instance_data.keys()
    assert hasattr(new_instance_data, 'mask')

    # test data_fields has been updated
    assert 'mask' in new_instance_data.keys()
    assert 'mask' in new_instance_data._data_fields

    # '_meta_info_field', '_data_fields' is immutable.
    with pytest.raises(AttributeError):
        del new_instance_data['_data_fields']
    with pytest.raises(AttributeError):
        del new_instance_data['_meta_info_field']

    # test __getitem__
    assert new_instance_data.mask is new_instance_data['mask']

    # test get
    assert new_instance_data.get('mask') is new_instance_data.mask
    assert new_instance_data.get('none_attribute', None) is None
    assert new_instance_data.get('none_attribute', 1) == 1

    # test pop
    mask = new_instance_data.mask
    assert new_instance_data.pop('mask') is mask
    assert new_instance_data.pop('mask', None) is None
    assert new_instance_data.pop('mask', 1) == 1

    # '_meta_info_field', '_data_fields' is immutable.
    with pytest.raises(KeyError):
        new_instance_data.pop('_data_fields')
    with pytest.raises(KeyError):
        new_instance_data.pop('_meta_info_field')
    # attribute in `_meta_info_field` is immutable
    with pytest.raises(KeyError):
        new_instance_data.pop('img_size')
    # test pop attribute in instance_data_field
    new_instance_data['mask'] = torch.rand(1, 2, 3)
    new_instance_data.pop('mask')
    # test data_field has been updated
    assert 'mask' not in new_instance_data
    assert 'mask' not in new_instance_data._data_fields
    assert 'mask' not in new_instance_data.keys()

    # test_keys
    new_instance_data.mask = torch.ones(1, 2, 3)
    assert 'mask' in new_instance_data.keys()
    has_flag = False
    for key in new_instance_data.keys():
        if key == 'mask':
            has_flag = True
    assert has_flag

    # test values
    assert len(list(new_instance_data.keys())) == len(
        list(new_instance_data.values()))
    mask = new_instance_data.mask
    has_flag = False
    for value in new_instance_data.values():
        if value is mask:
            has_flag = True
    assert has_flag

    # test items
    assert len(list(new_instance_data.keys())) == len(
        list(new_instance_data.items()))
    mask = new_instance_data.mask
    has_flag = False
    for key, value in new_instance_data.items():
        if value is mask:
            assert key == 'mask'
            has_flag = True
    assert has_flag

    # test device
    new_instance_data = GeneralData()
    if torch.cuda.is_available():
        newnew_instance_data = new_instance_data.new()
        devices = ('cpu', 'cuda')
        for i in range(10):
            device = devices[i % 2]
            newnew_instance_data[f'{i}'] = torch.rand(1, 2, 3, device=device)
        newnew_instance_data = newnew_instance_data.cpu()
        for value in newnew_instance_data.values():
            assert not value.is_cuda
        newnew_instance_data = new_instance_data.new()
        devices = ('cuda', 'cpu')
        for i in range(10):
            device = devices[i % 2]
            newnew_instance_data[f'{i}'] = torch.rand(1, 2, 3, device=device)
        newnew_instance_data = newnew_instance_data.cuda()
        for value in newnew_instance_data.values():
            assert value.is_cuda

    # test to
    double_instance_data = instance_data.new()
    double_instance_data.long = torch.LongTensor(1, 2, 3, 4)
    double_instance_data.bool = torch.BoolTensor(1, 2, 3, 4)
    # convert the object that actually holds the long/bool tensors
    double_instance_data = double_instance_data.to(torch.double)
    assert 'long' in double_instance_data
    assert 'bool' in double_instance_data
    for k, v in double_instance_data.items():
        if isinstance(v, torch.Tensor):
            assert v.dtype is torch.double

    # test .cpu() .cuda()
    if torch.cuda.is_available():
        cpu_instance_data = double_instance_data.new()
        cpu_instance_data.mask = torch.rand(1)
        cuda_tensor = torch.rand(1, 2, 3).cuda()
        cuda_instance_data = cpu_instance_data.to(cuda_tensor.device)
        for value in cuda_instance_data.values():
            assert value.is_cuda
        cpu_instance_data = cuda_instance_data.cpu()
        for value in cpu_instance_data.values():
            assert not value.is_cuda
        cuda_instance_data = cpu_instance_data.cuda()
        for value in cuda_instance_data.values():
            assert value.is_cuda

    # test detach
    grad_instance_data = double_instance_data.new()
    grad_instance_data.mask = torch.rand(2, requires_grad=True)
    grad_instance_data.mask_1 = torch.rand(2, requires_grad=True)
    detach_instance_data = grad_instance_data.detach()
    for value in detach_instance_data.values():
        assert not value.requires_grad

    # test numpy
    tensor_instance_data = double_instance_data.new()
    tensor_instance_data.mask = torch.rand(2, requires_grad=True)
    tensor_instance_data.mask_1 = torch.rand(2, requires_grad=True)
    numpy_instance_data = tensor_instance_data.numpy()
    for value in numpy_instance_data.values():
        assert isinstance(value, np.ndarray)
    if torch.cuda.is_available():
        tensor_instance_data = double_instance_data.new()
        tensor_instance_data.mask = torch.rand(2)
        tensor_instance_data.mask_1 = torch.rand(2)
        tensor_instance_data = tensor_instance_data.cuda()
        numpy_instance_data = tensor_instance_data.numpy()
        for value in numpy_instance_data.values():
            assert isinstance(value, np.ndarray)

    instance_data['_c'] = 10000
    assert instance_data.get('dad', None) is None
    assert hasattr(instance_data, '_c')
    del instance_data['_c']
    assert not hasattr(instance_data, '_c')
    instance_data.a = 1000
    instance_data['a'] = 2000
    assert instance_data['a'] == 2000
    assert instance_data.a == 2000
    assert instance_data.get('a') == instance_data['a'] == instance_data.a
    instance_data._meta = 1000
    assert '_meta' in instance_data.keys()
    if torch.cuda.is_available():
        instance_data.bbox = torch.ones(2, 3, 4, 5).cuda()
        instance_data.score = torch.ones(2, 3, 4, 4)
    else:
        instance_data.bbox = torch.ones(2, 3, 4, 5)

    assert len(instance_data.new().keys()) == 0
    with pytest.raises(AttributeError):
        instance_data.img_size = 100

    for k, v in instance_data.items():
        if k == 'bbox':
            assert isinstance(v, torch.Tensor)
    assert 'a' in instance_data
    instance_data.pop('a')
    assert 'a' not in instance_data

    cpu_instance_data = instance_data.cpu()
    for k, v in cpu_instance_data.items():
        if isinstance(v, torch.Tensor):
            assert not v.is_cuda

    assert isinstance(cpu_instance_data.numpy().bbox, np.ndarray)

    if torch.cuda.is_available():
        cuda_results = instance_data.cuda()
        for k, v in cuda_results.items():
            if isinstance(v, torch.Tensor):
                assert v.is_cuda
def pytorch2onnx(config_path,
                 checkpoint_path,
                 input_img,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 normalize_cfg=None,
                 dataset='coco',
                 test_img=None,
                 do_simplify=False,
                 cfg_options=None):
    """Build a detector from a config/checkpoint and export it to ONNX.

    Args:
        config_path (str): Model config file.
        checkpoint_path (str): Checkpoint file.
        input_img (str): Image used to build the example export input.
        input_shape (tuple): Example input shape.
        opset_version (int): ONNX opset passed to ``torch.onnx.export``.
        show (bool): Verbose export; when verifying, also plot results.
        output_file (str): Path of the ONNX file to write.
        verify (bool): Compare ONNX Runtime outputs against PyTorch outputs.
        normalize_cfg (dict | None): Normalization config for the input.
        dataset (str): Dataset name used to look up class names.
        test_img (str | None): Optional image used for verification instead
            of ``input_img``.
        do_simplify (bool): Simplify the exported model via
            ``mmcv.onnx.simplify`` (requires mmcv>=1.2.5).
        cfg_options (dict | None): Config overrides.
    """
    input_config = {
        'input_shape': input_shape,
        'input_path': input_img,
        'normalize_cfg': normalize_cfg
    }

    # prepare original model and meta for verifying the onnx model
    orig_model = build_model_from_cfg(
        config_path, checkpoint_path, cfg_options=cfg_options)
    one_img, one_meta = preprocess_example_input(input_config)
    model, tensor_data = generate_inputs_and_wrap_model(
        config_path, checkpoint_path, input_config, cfg_options=cfg_options)
    output_names = ['boxes']
    if model.with_bbox:
        output_names.append('labels')
    if model.with_mask:
        output_names.append('masks')

    torch.onnx.export(
        model,
        tensor_data,
        output_file,
        input_names=['input'],
        output_names=output_names,
        export_params=True,
        keep_initializers_as_inputs=True,
        do_constant_folding=True,
        verbose=show,
        opset_version=opset_version)

    # the wrapped forward is replaced by the original model's forward
    model.forward = orig_model.forward

    # simplify onnx model
    if do_simplify:
        from mmdet import digit_version
        import mmcv

        min_required_version = '1.2.5'
        assert digit_version(mmcv.__version__) >= digit_version(
            min_required_version
        ), f'Requires to install mmcv>={min_required_version}'
        from mmcv.onnx.simplify import simplify

        input_dic = {'input': one_img.detach().cpu().numpy()}
        _ = simplify(output_file, [input_dic], output_file)
    print(f'Successfully exported ONNX model: {output_file}')
    if verify:
        from mmdet.core import get_classes, bbox2result
        from mmdet.apis import show_result_pyplot

        ort_custom_op_path = ''
        try:
            from mmcv.ops import get_onnxruntime_op_path
            ort_custom_op_path = get_onnxruntime_op_path()
        except (ImportError, ModuleNotFoundError):
            warnings.warn('If input model has custom op from mmcv, '
                          'you may have to build mmcv with ONNXRuntime '
                          'from source.')

        model.CLASSES = get_classes(dataset)
        num_classes = len(model.CLASSES)
        # check by onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)
        if test_img is not None:
            input_config['input_path'] = test_img
            one_img, one_meta = preprocess_example_input(input_config)
            tensor_data = [one_img]

        # check the numerical value
        # get pytorch output (inference only; no autograd graph needed)
        with torch.no_grad():
            pytorch_results = model(
                tensor_data, [[one_meta]], return_loss=False)
        pytorch_results = pytorch_results[0]
        # get onnx output
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        # weights are exported as inputs (keep_initializers_as_inputs), so
        # the real feed input is whatever is not an initializer
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert (len(net_feed_input) == 1)
        session_options = rt.SessionOptions()
        # register custom op for onnxruntime
        if osp.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)
        sess = rt.InferenceSession(output_file, session_options)
        onnx_outputs = sess.run(None,
                                {net_feed_input[0]: one_img.detach().numpy()})
        output_names = [_.name for _ in sess.get_outputs()]
        output_shapes = [_.shape for _ in onnx_outputs]
        print(f'onnxruntime output names: {output_names}, '
              f'output shapes: {output_shapes}')
        nrof_out = len(onnx_outputs)
        assert nrof_out > 0, 'Must have output'
        with_mask = nrof_out == 3
        if nrof_out == 1:
            onnx_results = onnx_outputs[0]
        else:
            det_bboxes, det_labels = onnx_outputs[:2]
            onnx_results = bbox2result(det_bboxes, det_labels, num_classes)
            if with_mask:
                segm_results = onnx_outputs[2].squeeze(1)
                cls_segms = [[] for _ in range(num_classes)]
                for i in range(det_bboxes.shape[0]):
                    cls_segms[det_labels[i]].append(segm_results[i])
                onnx_results = (onnx_results, cls_segms)
        # visualize predictions
        if show:
            show_result_pyplot(
                model, one_meta['show_img'], pytorch_results, title='Pytorch')
            show_result_pyplot(
                model, one_meta['show_img'], onnx_results, title='ONNX')

        # compare a part of result
        if with_mask:
            compare_pairs = list(zip(onnx_results, pytorch_results))
        else:
            compare_pairs = [(onnx_results, pytorch_results)]
        err_msg = 'The numerical values are different between Pytorch' + \
                  ' and ONNX, but it does not necessarily mean the' + \
                  ' exported ONNX model is problematic.'
        for onnx_res, pytorch_res in compare_pairs:
            for o_res, p_res in zip(onnx_res, pytorch_res):
                np.testing.assert_allclose(
                    o_res, p_res, rtol=1e-03, atol=1e-05, err_msg=err_msg)
        print('The numerical values are the same between Pytorch and ONNX')
def test_version_check():
    """Check that ``digit_version`` orders release and rc versions.

    Each entry is a ``(newer, older)`` pair; the parsed newer version must
    compare strictly greater than the parsed older one.
    """
    newer_older_pairs = (
        ('1.0.5', '1.0.5rc0'),
        ('1.0.5', '1.0.4rc0'),
        ('1.0.5', '1.0rc0'),
        ('1.0.0', '0.6.2'),
        ('1.0.0', '0.2.16'),
        ('1.0.5rc0', '1.0.0rc0'),
        ('1.0.0rc1', '1.0.0rc0'),
        ('1.0.0rc2', '1.0.0rc0'),
        ('1.0.0rc2', '1.0.0rc1'),
        ('1.0.1rc1', '1.0.0rc1'),
        ('1.0.0', '1.0.0rc1'),
    )
    for newer, older in newer_older_pairs:
        assert digit_version(newer) > digit_version(older)
def pytorch2onnx(config_path,
                 checkpoint_path,
                 input_img,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 normalize_cfg=None,
                 dataset='coco',
                 test_img=None,
                 do_simplify=False,
                 cfg_options=None,
                 dynamic_export=None):
    """Build a detector from a config/checkpoint and export it to ONNX.

    Args:
        config_path (str): Model config file.
        checkpoint_path (str): Checkpoint file.
        input_img (str): Image used to build the example export input.
        input_shape (tuple): Example input shape ``(1, 3, h, w)``.
        opset_version (int): ONNX opset passed to ``torch.onnx.export``.
        show (bool): Verbose export; when verifying, also plot results.
        output_file (str): Path of the ONNX file to write.
        verify (bool): Compare ONNX Runtime outputs against PyTorch outputs.
        normalize_cfg (dict | None): Normalization config for the input.
        dataset (str): Dataset name used to look up class names.
        test_img (str | None): Optional image used for verification instead
            of ``input_img``.
        do_simplify (bool): Run onnx-simplifier (>=0.3.0) and overwrite
            ``output_file`` with the simplified model when its check passes.
        cfg_options (dict | None): Config overrides.
        dynamic_export (bool | None): Export with dynamic batch/spatial
            axes; verification then uses a larger input and a 2-image batch.
    """
    input_config = {
        'input_shape': input_shape,
        'input_path': input_img,
        'normalize_cfg': normalize_cfg
    }

    # prepare original model and meta for verifying the onnx model
    orig_model = build_model_from_cfg(
        config_path, checkpoint_path, cfg_options=cfg_options)
    one_img, one_meta = preprocess_example_input(input_config)
    model, tensor_data = generate_inputs_and_wrap_model(
        config_path, checkpoint_path, input_config, cfg_options=cfg_options)
    output_names = ['dets', 'labels']
    if model.with_mask:
        output_names.append('masks')
    dynamic_axes = None
    if dynamic_export:
        # NOTE(review): for the (1, 3, h, w) input, dim 2 is h and dim 3 is
        # w, so these symbolic names are swapped; kept unchanged because
        # they are only labels.
        dynamic_axes = {
            'input': {
                0: 'batch',
                2: 'width',
                3: 'height'
            },
            'dets': {
                0: 'batch',
                1: 'num_dets',
            },
            'labels': {
                0: 'batch',
                1: 'num_dets',
            },
        }
        if model.with_mask:
            dynamic_axes['masks'] = {0: 'batch', 1: 'num_dets'}

    torch.onnx.export(
        model,
        tensor_data,
        output_file,
        input_names=['input'],
        output_names=output_names,
        export_params=True,
        keep_initializers_as_inputs=True,
        do_constant_folding=True,
        verbose=show,
        opset_version=opset_version,
        dynamic_axes=dynamic_axes)

    # the wrapped forward is replaced by the original model's forward
    model.forward = orig_model.forward

    # get the custom op path
    ort_custom_op_path = ''
    try:
        from mmcv.ops import get_onnxruntime_op_path
        ort_custom_op_path = get_onnxruntime_op_path()
    except (ImportError, ModuleNotFoundError):
        warnings.warn('If input model has custom op from mmcv, '
                      'you may have to build mmcv with ONNXRuntime from '
                      'source.')

    if do_simplify:
        from mmdet import digit_version
        import onnxsim

        min_required_version = '0.3.0'
        assert digit_version(onnxsim.__version__) >= digit_version(
            min_required_version
        ), f'Requires to install onnx-simplify>={min_required_version}'

        input_dic = {'input': one_img.detach().cpu().numpy()}
        # onnxsim.simplify returns the simplified model rather than writing
        # it, so it must be saved explicitly.
        model_opt, check_ok = onnxsim.simplify(
            output_file,
            input_data=input_dic,
            custom_lib=ort_custom_op_path,
            dynamic_input_shape=bool(dynamic_export))
        if check_ok:
            onnx.save(model_opt, output_file)
            print(f'Successfully simplified ONNX model: {output_file}')
        else:
            warnings.warn('Failed to simplify ONNX model.')
    print(f'Successfully exported ONNX model: {output_file}')

    if verify:
        from mmdet.core import get_classes, bbox2result
        from mmdet.apis import show_result_pyplot

        model.CLASSES = get_classes(dataset)
        num_classes = len(model.CLASSES)
        # check by onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)
        if dynamic_export:
            # scale up to test dynamic shape
            h, w = [int((_ * 1.5) // 32 * 32) for _ in input_shape[2:]]
            input_config['input_shape'] = (1, 3, h, w)
        if test_img is not None:
            input_config['input_path'] = test_img
        one_img, one_meta = preprocess_example_input(input_config)
        tensor_data = [one_img]

        # get pytorch output (inference only; no autograd graph needed)
        with torch.no_grad():
            pytorch_results = model(
                tensor_data, [[one_meta]], return_loss=False)
        pytorch_results = pytorch_results[0]
        # get onnx output
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        # weights are exported as inputs (keep_initializers_as_inputs), so
        # the real feed input is whatever is not an initializer
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert (len(net_feed_input) == 1)
        session_options = rt.SessionOptions()
        # register custom op for ONNX Runtime
        if osp.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)
        feed_input_img = one_img.detach().numpy()
        if dynamic_export:
            # test batch with two input images
            feed_input_img = np.vstack([feed_input_img, feed_input_img])
        sess = rt.InferenceSession(output_file, session_options)
        onnx_outputs = sess.run(None, {net_feed_input[0]: feed_input_img})
        output_names = [_.name for _ in sess.get_outputs()]
        output_shapes = [_.shape for _ in onnx_outputs]
        print(f'ONNX Runtime output names: {output_names}, '
              f'output shapes: {output_shapes}')
        # get last image's outputs
        onnx_outputs = [_[-1] for _ in onnx_outputs]
        ort_dets, ort_labels = onnx_outputs[:2]
        onnx_results = bbox2result(ort_dets, ort_labels, num_classes)
        if model.with_mask:
            segm_results = onnx_outputs[2]
            cls_segms = [[] for _ in range(num_classes)]
            for i in range(ort_dets.shape[0]):
                cls_segms[ort_labels[i]].append(segm_results[i])
            onnx_results = (onnx_results, cls_segms)
        # visualize predictions
        if show:
            show_result_pyplot(
                model, one_meta['show_img'], pytorch_results, title='Pytorch')
            show_result_pyplot(
                model,
                one_meta['show_img'],
                onnx_results,
                title='ONNXRuntime')

        # compare a part of result
        if model.with_mask:
            compare_pairs = list(zip(onnx_results, pytorch_results))
        else:
            compare_pairs = [(onnx_results, pytorch_results)]
        err_msg = 'The numerical values are different between Pytorch' + \
                  ' and ONNX, but it does not necessarily mean the' + \
                  ' exported ONNX model is problematic.'
        # check the numerical value
        for onnx_res, pytorch_res in compare_pairs:
            for o_res, p_res in zip(onnx_res, pytorch_res):
                np.testing.assert_allclose(
                    o_res, p_res, rtol=1e-03, atol=1e-05, err_msg=err_msg)
        print('The numerical values are the same between Pytorch and ONNX')
# get the custom op path ort_custom_op_path = '' try: from mmcv.ops import get_onnxruntime_op_path ort_custom_op_path = get_onnxruntime_op_path() except (ImportError, ModuleNotFoundError): warnings.warn('If input model has custom op from mmcv, \ you may have to build mmcv with ONNXRuntime from source.') if do_simplify: import onnxsim from mmdet import digit_version min_required_version = '0.3.0' assert digit_version(onnxsim.__version__) >= digit_version( min_required_version ), f'Requires to install onnx-simplify>={min_required_version}' input_dic = {'input': img_list[0].detach().cpu().numpy()} model_opt, check_ok = onnxsim.simplify( output_file, input_data=input_dic, custom_lib=ort_custom_op_path, dynamic_input_shape=dynamic_export) if check_ok: onnx.save(model_opt, output_file) print(f'Successfully simplified ONNX model: {output_file}') else: warnings.warn('Failed to simplify ONNX model.') print(f'Successfully exported ONNX model: {output_file}')