Code example #1 (score: 0)
File: pvt.py — Project: myownskyW7/mmdetection
    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=None,
                 batch_first=True,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1,
                 init_cfg=None):
        """Spatial-reduction attention used in PVT.

        Args:
            embed_dims (int): Embedding dimension of the attention module.
            num_heads (int): Number of attention heads.
            attn_drop (float): Dropout rate on the attention weights.
            proj_drop (float): Dropout rate after the output projection.
            dropout_layer (dict, optional): Dropout-layer config forwarded to
                the parent attention class.
            batch_first (bool): Whether inputs are batch-first.
            qkv_bias (bool): Whether to add a learnable bias to q, k, v.
            norm_cfg (dict): Config of the norm layer applied after the
                spatial-reduction convolution.
            sr_ratio (int): Spatial-reduction ratio; values > 1 enable the
                reduction convolution (kernel and stride both ``sr_ratio``).
            init_cfg (dict, optional): Initialization config.
        """
        super().__init__(
            embed_dims,
            num_heads,
            attn_drop,
            proj_drop,
            batch_first=batch_first,
            dropout_layer=dropout_layer,
            bias=qkv_bias,
            init_cfg=init_cfg)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # Reduce the spatial size of k/v by sr_ratio before attention.
            self.sr = Conv2d(
                in_channels=embed_dims,
                out_channels=embed_dims,
                kernel_size=sr_ratio,
                stride=sr_ratio)
            # The ret[0] of build_norm_layer is norm name.
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

        # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
        from mmdet import mmcv_version, digit_version
        if mmcv_version < digit_version('1.3.17'):
            # Fixed: the concatenated message parts were missing separating
            # spaces ("inSpatialReductionAttention", "thefuture", ...).
            warnings.warn('The legacy version of forward function in '
                          'SpatialReductionAttention is deprecated in '
                          'mmcv>=1.3.17 and will no longer support in the '
                          'future. Please upgrade your mmcv.')
            self.forward = self.legacy_forward
Code example #2 (score: 0)
# Fixed: `osp` (used below) and `partial` (used in the tests) were
# referenced without being imported, causing a NameError at import time.
import os.path as osp
from functools import partial

import mmcv
import numpy as np
import pytest
import torch
from mmcv.cnn import Scale

from mmdet import digit_version
from mmdet.models import build_detector
from mmdet.models.dense_heads import (FCOSHead, FSAFHead, RetinaHead, SSDHead,
                                      YOLOV3Head)
from .utils import ort_validate

data_path = osp.join(osp.dirname(__file__), 'data')

# onnxruntime export requires torch > 1.5.0; skip the whole module otherwise.
if digit_version(torch.__version__) <= digit_version('1.5.0'):
    pytest.skip('ort backend does not support version below 1.5.0',
                allow_module_level=True)


def test_cascade_onnx_export():

    config_path = './configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py'
    cfg = mmcv.Config.fromfile(config_path)
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    with torch.no_grad():
        model.forward = partial(model.forward, img_metas=[[dict()]])

        dynamic_axes = {
            'input_img': {
                0: 'batch',
Code example #3 (score: 0)
File: pytorch2onnx.py — Project: xzjzsa/mmdetection
def pytorch2onnx(model,
                 input_img,
                 input_shape,
                 normalize_cfg,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 test_img=None,
                 do_simplify=False,
                 dynamic_export=None,
                 skip_postprocess=False):
    """Export a detector to ONNX and optionally verify the exported model.

    Args:
        model (nn.Module): Detector to export. Its ``forward`` is temporarily
            replaced for tracing and restored afterwards.
        input_img (str): Path of the image used to trace the model.
        input_shape (tuple): Input shape, e.g. ``(1, 3, H, W)``.
        normalize_cfg (dict): Image normalization config for preprocessing.
        opset_version (int): ONNX opset version. Default: 11.
        show (bool): If True, print the export graph and display detection
            results in windows instead of writing image files.
        output_file (str): Path of the resulting ONNX file.
        verify (bool): If True, run both PyTorch and ONNX models and compare
            their outputs numerically.
        test_img (str, optional): Image used for verification; falls back to
            ``input_img`` when not given.
        do_simplify (bool): If True, simplify the exported model with
            onnx-simplifier (requires onnxsim >= 0.3.0).
        dynamic_export (bool, optional): If True, export with dynamic axes
            for batch / height / width and number of detections.
        skip_postprocess (bool): If True, export ``forward_dummy`` without any
            post-processing (not supported by all models).
    """

    input_config = {
        'input_shape': input_shape,
        'input_path': input_img,
        'normalize_cfg': normalize_cfg
    }
    # prepare input
    one_img, one_meta = preprocess_example_input(input_config)
    img_list, img_meta_list = [one_img], [[one_meta]]

    if skip_postprocess:
        warnings.warn('Not all models support export onnx without post '
                      'process, especially two stage detectors!')
        model.forward = model.forward_dummy
        torch.onnx.export(model,
                          one_img,
                          output_file,
                          input_names=['input'],
                          export_params=True,
                          keep_initializers_as_inputs=True,
                          do_constant_folding=True,
                          verbose=show,
                          opset_version=opset_version)

        print(f'Successfully exported ONNX model without '
              f'post process: {output_file}')
        return

    # replace original forward function
    origin_forward = model.forward
    model.forward = partial(model.forward,
                            img_metas=img_meta_list,
                            return_loss=False,
                            rescale=False)

    output_names = ['dets', 'labels']
    if model.with_mask:
        output_names.append('masks')
    input_name = 'input'
    dynamic_axes = None
    if dynamic_export:
        dynamic_axes = {
            input_name: {
                0: 'batch',
                2: 'width',
                3: 'height'
            },
            'dets': {
                0: 'batch',
                1: 'num_dets',
            },
            'labels': {
                0: 'batch',
                1: 'num_dets',
            },
        }
        if model.with_mask:
            dynamic_axes['masks'] = {0: 'batch', 1: 'num_dets'}

    torch.onnx.export(model,
                      img_list,
                      output_file,
                      input_names=[input_name],
                      output_names=output_names,
                      export_params=True,
                      keep_initializers_as_inputs=True,
                      do_constant_folding=True,
                      verbose=show,
                      opset_version=opset_version,
                      dynamic_axes=dynamic_axes)

    model.forward = origin_forward

    # get the custom op path
    ort_custom_op_path = ''
    try:
        from mmcv.ops import get_onnxruntime_op_path
        ort_custom_op_path = get_onnxruntime_op_path()
    except (ImportError, ModuleNotFoundError):
        warnings.warn('If input model has custom op from mmcv, \
            you may have to build mmcv with ONNXRuntime from source.')

    if do_simplify:
        import onnxsim

        from mmdet import digit_version

        min_required_version = '0.3.0'
        assert digit_version(onnxsim.__version__) >= digit_version(
            min_required_version
        ), f'Requires to install onnx-simplify>={min_required_version}'

        input_dic = {'input': img_list[0].detach().cpu().numpy()}
        onnxsim.simplify(output_file,
                         input_data=input_dic,
                         custom_lib=ort_custom_op_path)
    print(f'Successfully exported ONNX model: {output_file}')

    if verify:
        # check by onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        # wrap onnx model
        onnx_model = ONNXRuntimeDetector(output_file, model.CLASSES, 0)
        if dynamic_export:
            # scale up to test dynamic shape
            h, w = [int((_ * 1.5) // 32 * 32) for _ in input_shape[2:]]
            h, w = min(1344, h), min(1344, w)
            input_config['input_shape'] = (1, 3, h, w)

        # Fixed: `test_img` was previously dead — the old branch
        # (`if test_img is None: input_path = input_img`) was a no-op since
        # input_path already equals input_img, so a supplied test image was
        # silently ignored.
        if test_img is not None:
            input_config['input_path'] = test_img

        # prepare input once again
        one_img, one_meta = preprocess_example_input(input_config)
        img_list, img_meta_list = [one_img], [[one_meta]]

        # get pytorch output
        with torch.no_grad():
            pytorch_results = model(img_list,
                                    img_metas=img_meta_list,
                                    return_loss=False,
                                    rescale=True)[0]

        img_list = [_.cuda().contiguous() for _ in img_list]
        if dynamic_export:
            # duplicate (flipped) inputs to exercise the dynamic batch axis
            img_list = img_list + [_.flip(-1).contiguous() for _ in img_list]
            img_meta_list = img_meta_list * 2
        # get onnx output
        onnx_results = onnx_model(img_list,
                                  img_metas=img_meta_list,
                                  return_loss=False)[0]
        # visualize predictions
        score_thr = 0.3
        if show:
            out_file_ort, out_file_pt = None, None
        else:
            out_file_ort, out_file_pt = 'show-ort.png', 'show-pt.png'

        show_img = one_meta['show_img']
        model.show_result(show_img,
                          pytorch_results,
                          score_thr=score_thr,
                          show=True,
                          win_name='PyTorch',
                          out_file=out_file_pt)
        onnx_model.show_result(show_img,
                               onnx_results,
                               score_thr=score_thr,
                               show=True,
                               win_name='ONNXRuntime',
                               out_file=out_file_ort)

        # compare a part of result
        if model.with_mask:
            compare_pairs = list(zip(onnx_results, pytorch_results))
        else:
            compare_pairs = [(onnx_results, pytorch_results)]
        err_msg = 'The numerical values are different between Pytorch' + \
                  ' and ONNX, but it does not necessarily mean the' + \
                  ' exported ONNX model is problematic.'
        # check the numerical value
        for onnx_res, pytorch_res in compare_pairs:
            for o_res, p_res in zip(onnx_res, pytorch_res):
                np.testing.assert_allclose(o_res,
                                           p_res,
                                           rtol=1e-03,
                                           atol=1e-05,
                                           err_msg=err_msg)
        print('The numerical values are the same between Pytorch and ONNX')
Code example #4 (score: 0)
def test_general_data():
    """Exercise the GeneralData container.

    Covers construction, meta-info immutability, dict/attribute access,
    ``new()``, device moves, and detach/numpy conversions.
    """

    # test init
    meta_info = dict(img_size=[256, 256],
                     path='dadfaff',
                     scale_factor=np.array([1.5, 1.5]),
                     img_shape=torch.rand(4))

    data = dict(bboxes=torch.rand(4, 4),
                labels=torch.rand(4),
                masks=np.random.rand(4, 2, 2))

    instance_data = GeneralData(meta_info=meta_info)
    assert 'img_size' in instance_data
    assert instance_data.img_size == [256, 256]
    assert instance_data['img_size'] == [256, 256]
    assert 'path' in instance_data
    assert instance_data.path == 'dadfaff'

    # test nice_repr
    repr_instance_data = instance_data.new(data=data)
    nice_repr = str(repr_instance_data)
    for line in nice_repr.split('\n'):
        if 'masks' in line:
            assert 'shape' in line
            assert '(4, 2, 2)' in line
        if 'bboxes' in line:
            assert 'shape' in line
            assert 'torch.Size([4, 4])' in line
        if 'path' in line:
            assert 'dadfaff' in line
        if 'scale_factor' in line:
            assert '[1.5 1.5]' in line

    instance_data = GeneralData(meta_info=meta_info,
                                data=dict(bboxes=torch.rand(5)))
    assert 'bboxes' in instance_data
    assert len(instance_data.bboxes) == 5

    # data should be a dict
    with pytest.raises(AssertionError):
        GeneralData(data=1)

    # test set data
    instance_data = GeneralData()
    instance_data.set_data(data)
    assert 'bboxes' in instance_data
    assert len(instance_data.bboxes) == 4
    assert 'masks' in instance_data
    assert len(instance_data.masks) == 4
    # data should be a dict
    with pytest.raises(AssertionError):
        instance_data.set_data(data=1)

    # test set_meta
    instance_data = GeneralData()
    instance_data.set_meta_info(meta_info)
    assert 'img_size' in instance_data
    assert instance_data.img_size == [256, 256]
    assert instance_data['img_size'] == [256, 256]
    assert 'path' in instance_data
    assert instance_data.path == 'dadfaff'
    # can skip same value when overwrite
    instance_data.set_meta_info(meta_info)

    # meta should be a dict
    with pytest.raises(AssertionError):
        instance_data.set_meta_info(meta_info='fjhka')

    # attribute in `_meta_info_field` is immutable once initialized
    instance_data.set_meta_info(meta_info)
    # meta should be immutable
    with pytest.raises(KeyError):
        instance_data.set_meta_info(dict(img_size=[254, 251]))
    with pytest.raises(KeyError):
        duplicate_meta_info = copy.deepcopy(meta_info)
        duplicate_meta_info['path'] = 'dada'
        instance_data.set_meta_info(duplicate_meta_info)
    with pytest.raises(KeyError):
        duplicate_meta_info = copy.deepcopy(meta_info)
        duplicate_meta_info['scale_factor'] = np.array([1.5, 1.6])
        instance_data.set_meta_info(duplicate_meta_info)

    # test new_instance_data
    instance_data = GeneralData(meta_info)
    new_instance_data = instance_data.new()
    for k, v in instance_data.meta_info_items():
        assert k in new_instance_data
        # Fixed: the comparison result was previously discarded (no assert).
        assert _equal(v, new_instance_data[k])

    instance_data = GeneralData(meta_info, data=data)
    temp_meta = copy.deepcopy(meta_info)
    temp_data = copy.deepcopy(data)
    temp_data['time'] = '12212'
    temp_meta['img_norm'] = np.random.random(3)

    new_instance_data = instance_data.new(meta_info=temp_meta, data=temp_data)
    for k, v in new_instance_data.meta_info_items():
        if k in instance_data:
            # Fixed: missing assert.
            assert _equal(v, instance_data[k])
        else:
            assert _equal(v, temp_meta[k])
            assert k == 'img_norm'

    for k, v in new_instance_data.items():
        if k in instance_data:
            # Fixed: missing assert.
            assert _equal(v, instance_data[k])
        else:
            assert k == 'time'
            assert _equal(v, temp_data[k])

    # test keys
    instance_data = GeneralData(meta_info, data=dict(bboxes=10))
    assert 'bboxes' in instance_data.keys()
    instance_data.b = 10
    assert 'b' in instance_data

    # test meta keys
    instance_data = GeneralData(meta_info, data=dict(bboxes=10))
    assert 'path' in instance_data.meta_info_keys()
    assert len(instance_data.meta_info_keys()) == len(meta_info)
    instance_data.set_meta_info(dict(workdir='fafaf'))
    assert 'workdir' in instance_data
    assert len(instance_data.meta_info_keys()) == len(meta_info) + 1

    # test values
    instance_data = GeneralData(meta_info, data=dict(bboxes=10))
    assert 10 in instance_data.values()
    assert len(instance_data.values()) == 1

    # test meta values
    instance_data = GeneralData(meta_info, data=dict(bboxes=10))
    # torch 1.3 eq() can not compare str and tensor
    from mmdet import digit_version
    # NOTE(review): comparing against a plain list assumes digit_version
    # returns a list here — a tuple return would raise TypeError; confirm.
    if digit_version(torch.__version__) >= [1, 4]:
        assert 'dadfaff' in instance_data.meta_info_values()
    assert len(instance_data.meta_info_values()) == len(meta_info)

    # test items
    instance_data = GeneralData(data=data)
    for k, v in instance_data.items():
        assert k in data
        assert _equal(v, data[k])

    # test meta_info_items
    instance_data = GeneralData(meta_info=meta_info)
    for k, v in instance_data.meta_info_items():
        assert k in meta_info
        assert _equal(v, meta_info[k])

    # test __setattr__
    new_instance_data = GeneralData(data=data)
    new_instance_data.mask = torch.rand(3, 4, 5)
    new_instance_data.bboxes = torch.rand(2, 4)
    assert 'mask' in new_instance_data
    assert len(new_instance_data.mask) == 3
    assert len(new_instance_data.bboxes) == 2

    # test instance_data_field has been updated
    assert 'mask' in new_instance_data._data_fields
    assert 'bboxes' in new_instance_data._data_fields

    for k in data:
        assert k in new_instance_data._data_fields

    # '_meta_info_field', '_data_fields' is immutable.
    with pytest.raises(AttributeError):
        new_instance_data._data_fields = None
    with pytest.raises(AttributeError):
        new_instance_data._meta_info_fields = None
    with pytest.raises(AttributeError):
        del new_instance_data._data_fields
    with pytest.raises(AttributeError):
        del new_instance_data._meta_info_fields

    # key in _meta_info_field is immutable
    new_instance_data.set_meta_info(meta_info)
    with pytest.raises(KeyError):
        del new_instance_data.img_size
    with pytest.raises(KeyError):
        del new_instance_data.scale_factor
    for k in new_instance_data.meta_info_keys():
        with pytest.raises(AttributeError):
            new_instance_data[k] = None

    # test __delattr__
    # test key can be removed in instance_data_field
    assert 'mask' in new_instance_data._data_fields
    assert 'mask' in new_instance_data.keys()
    assert 'mask' in new_instance_data
    assert hasattr(new_instance_data, 'mask')
    del new_instance_data.mask
    assert 'mask' not in new_instance_data.keys()
    assert 'mask' not in new_instance_data
    assert 'mask' not in new_instance_data._data_fields
    assert not hasattr(new_instance_data, 'mask')

    # test __delitem__
    new_instance_data.mask = torch.rand(1, 2, 3)
    assert 'mask' in new_instance_data._data_fields
    assert 'mask' in new_instance_data
    assert hasattr(new_instance_data, 'mask')
    del new_instance_data['mask']
    assert 'mask' not in new_instance_data
    assert 'mask' not in new_instance_data._data_fields
    assert 'mask' not in new_instance_data
    assert not hasattr(new_instance_data, 'mask')

    # test __setitem__
    new_instance_data['mask'] = torch.rand(1, 2, 3)
    assert 'mask' in new_instance_data._data_fields
    assert 'mask' in new_instance_data.keys()
    assert hasattr(new_instance_data, 'mask')

    # test data_fields has been updated
    assert 'mask' in new_instance_data.keys()
    assert 'mask' in new_instance_data._data_fields

    # '_meta_info_field', '_data_fields' is immutable.
    with pytest.raises(AttributeError):
        del new_instance_data['_data_fields']
    with pytest.raises(AttributeError):
        del new_instance_data['_meta_info_field']

    #  test __getitem__
    # Fixed: missing assert — the identity check was a bare expression.
    assert new_instance_data.mask is new_instance_data['mask']

    # test get
    assert new_instance_data.get('mask') is new_instance_data.mask
    assert new_instance_data.get('none_attribute', None) is None
    assert new_instance_data.get('none_attribute', 1) == 1

    # test pop
    mask = new_instance_data.mask
    assert new_instance_data.pop('mask') is mask
    assert new_instance_data.pop('mask', None) is None
    assert new_instance_data.pop('mask', 1) == 1

    # '_meta_info_field', '_data_fields' is immutable.
    with pytest.raises(KeyError):
        new_instance_data.pop('_data_fields')
    with pytest.raises(KeyError):
        new_instance_data.pop('_meta_info_field')
    # attribute in `_meta_info_field` is immutable
    with pytest.raises(KeyError):
        new_instance_data.pop('img_size')
    # test pop attribute in instance_data_filed
    new_instance_data['mask'] = torch.rand(1, 2, 3)
    new_instance_data.pop('mask')
    # test data_field has been updated
    assert 'mask' not in new_instance_data
    assert 'mask' not in new_instance_data._data_fields
    assert 'mask' not in new_instance_data

    # test_keys
    new_instance_data.mask = torch.ones(1, 2, 3)
    # Fixed: missing assert on the membership check.
    assert 'mask' in new_instance_data.keys()
    has_flag = False
    for key in new_instance_data.keys():
        if key == 'mask':
            has_flag = True
    assert has_flag

    # test values
    assert len(list(new_instance_data.keys())) == len(
        list(new_instance_data.values()))
    mask = new_instance_data.mask
    has_flag = False
    for value in new_instance_data.values():
        if value is mask:
            has_flag = True
    assert has_flag

    # test items
    assert len(list(new_instance_data.keys())) == len(
        list(new_instance_data.items()))
    mask = new_instance_data.mask
    has_flag = False
    for key, value in new_instance_data.items():
        if value is mask:
            assert key == 'mask'
            has_flag = True
    assert has_flag

    # test device
    new_instance_data = GeneralData()
    if torch.cuda.is_available():
        newnew_instance_data = new_instance_data.new()
        devices = ('cpu', 'cuda')
        for i in range(10):
            device = devices[i % 2]
            newnew_instance_data[f'{i}'] = torch.rand(1, 2, 3, device=device)
        newnew_instance_data = newnew_instance_data.cpu()
        for value in newnew_instance_data.values():
            assert not value.is_cuda
        newnew_instance_data = new_instance_data.new()
        devices = ('cuda', 'cpu')
        for i in range(10):
            device = devices[i % 2]
            newnew_instance_data[f'{i}'] = torch.rand(1, 2, 3, device=device)
        newnew_instance_data = newnew_instance_data.cuda()
        for value in newnew_instance_data.values():
            assert value.is_cuda
    # test to
    double_instance_data = instance_data.new()
    double_instance_data.long = torch.LongTensor(1, 2, 3, 4)
    double_instance_data.bool = torch.BoolTensor(1, 2, 3, 4)
    # Fixed: previously called `instance_data.to(...)`, which discarded the
    # long/bool tensors set above and made the dtype check below vacuous.
    double_instance_data = double_instance_data.to(torch.double)
    for k, v in double_instance_data.items():
        if isinstance(v, torch.Tensor):
            assert v.dtype is torch.double

    # test .cpu() .cuda()
    if torch.cuda.is_available():
        cpu_instance_data = double_instance_data.new()
        cpu_instance_data.mask = torch.rand(1)
        cuda_tensor = torch.rand(1, 2, 3).cuda()
        cuda_instance_data = cpu_instance_data.to(cuda_tensor.device)
        for value in cuda_instance_data.values():
            assert value.is_cuda
        cpu_instance_data = cuda_instance_data.cpu()
        for value in cpu_instance_data.values():
            assert not value.is_cuda
        cuda_instance_data = cpu_instance_data.cuda()
        for value in cuda_instance_data.values():
            assert value.is_cuda

    # test detach
    grad_instance_data = double_instance_data.new()
    grad_instance_data.mask = torch.rand(2, requires_grad=True)
    grad_instance_data.mask_1 = torch.rand(2, requires_grad=True)
    detach_instance_data = grad_instance_data.detach()
    for value in detach_instance_data.values():
        assert not value.requires_grad

    # test numpy
    tensor_instance_data = double_instance_data.new()
    tensor_instance_data.mask = torch.rand(2, requires_grad=True)
    tensor_instance_data.mask_1 = torch.rand(2, requires_grad=True)
    numpy_instance_data = tensor_instance_data.numpy()
    for value in numpy_instance_data.values():
        assert isinstance(value, np.ndarray)
    if torch.cuda.is_available():
        tensor_instance_data = double_instance_data.new()
        tensor_instance_data.mask = torch.rand(2)
        tensor_instance_data.mask_1 = torch.rand(2)
        tensor_instance_data = tensor_instance_data.cuda()
        numpy_instance_data = tensor_instance_data.numpy()
        for value in numpy_instance_data.values():
            assert isinstance(value, np.ndarray)

    instance_data['_c'] = 10000
    # Fixed: missing assert on the get() check.
    assert instance_data.get('dad', None) is None
    assert hasattr(instance_data, '_c')
    del instance_data['_c']
    assert not hasattr(instance_data, '_c')
    instance_data.a = 1000
    instance_data['a'] = 2000
    assert instance_data['a'] == 2000
    assert instance_data.a == 2000
    assert instance_data.get('a') == instance_data['a'] == instance_data.a
    instance_data._meta = 1000
    assert '_meta' in instance_data.keys()
    if torch.cuda.is_available():
        instance_data.bbox = torch.ones(2, 3, 4, 5).cuda()
        instance_data.score = torch.ones(2, 3, 4, 4)
    else:
        instance_data.bbox = torch.ones(2, 3, 4, 5)

    assert len(instance_data.new().keys()) == 0
    with pytest.raises(AttributeError):
        instance_data.img_size = 100

    for k, v in instance_data.items():
        if k == 'bbox':
            assert isinstance(v, torch.Tensor)
    assert 'a' in instance_data
    instance_data.pop('a')
    assert 'a' not in instance_data

    cpu_instance_data = instance_data.cpu()
    for k, v in cpu_instance_data.items():
        if isinstance(v, torch.Tensor):
            assert not v.is_cuda

    assert isinstance(cpu_instance_data.numpy().bbox, np.ndarray)

    if torch.cuda.is_available():
        cuda_results = instance_data.cuda()
        for k, v in cuda_results.items():
            if isinstance(v, torch.Tensor):
                assert v.is_cuda
Code example #5 (score: 0)
def pytorch2onnx(config_path,
                 checkpoint_path,
                 input_img,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 normalize_cfg=None,
                 dataset='coco',
                 test_img=None,
                 do_simplify=False,
                 cfg_options=None):
    """Export a detector built from config/checkpoint to ONNX and optionally
    verify the exported model against the PyTorch model with onnxruntime.

    Args:
        config_path (str): Path of the model config file.
        checkpoint_path (str): Path of the model checkpoint.
        input_img (str): Path of the image used for tracing.
        input_shape (tuple): Input shape, e.g. ``(1, 3, H, W)``.
        opset_version (int): ONNX opset version. Default: 11.
        show (bool): If True, print the export graph and show the
            PyTorch/ONNX results with matplotlib.
        output_file (str): Path of the resulting ONNX file.
        verify (bool): If True, run both models and compare outputs.
        normalize_cfg (dict, optional): Image normalization config.
        dataset (str): Dataset name used to look up class names.
        test_img (str, optional): Image used for verification; falls back to
            ``input_img`` when not given.
        do_simplify (bool): If True, simplify the model via mmcv's
            ``onnx.simplify`` (requires mmcv >= 1.2.5).
        cfg_options (dict, optional): Overrides applied to the loaded config.
    """

    input_config = {
        'input_shape': input_shape,
        'input_path': input_img,
        'normalize_cfg': normalize_cfg
    }

    # prepare original model and meta for verifying the onnx model
    orig_model = build_model_from_cfg(config_path,
                                      checkpoint_path,
                                      cfg_options=cfg_options)
    one_img, one_meta = preprocess_example_input(input_config)
    # `model` has its forward wrapped for tracing; it is restored below.
    model, tensor_data = generate_inputs_and_wrap_model(
        config_path, checkpoint_path, input_config, cfg_options=cfg_options)
    output_names = ['boxes']
    if model.with_bbox:
        output_names.append('labels')
    if model.with_mask:
        output_names.append('masks')

    torch.onnx.export(model,
                      tensor_data,
                      output_file,
                      input_names=['input'],
                      output_names=output_names,
                      export_params=True,
                      keep_initializers_as_inputs=True,
                      do_constant_folding=True,
                      verbose=show,
                      opset_version=opset_version)

    # restore the original (unwrapped) forward for later verification
    model.forward = orig_model.forward

    # simplify onnx model
    if do_simplify:
        from mmdet import digit_version
        import mmcv

        min_required_version = '1.2.5'
        assert digit_version(mmcv.__version__) >= digit_version(
            min_required_version
        ), f'Requires to install mmcv>={min_required_version}'
        from mmcv.onnx.simplify import simplify

        input_dic = {'input': one_img.detach().cpu().numpy()}
        _ = simplify(output_file, [input_dic], output_file)
    print(f'Successfully exported ONNX model: {output_file}')
    if verify:
        from mmdet.core import get_classes, bbox2result
        from mmdet.apis import show_result_pyplot

        # custom mmcv ops (e.g. NMS) need a registered onnxruntime library
        ort_custom_op_path = ''
        try:
            from mmcv.ops import get_onnxruntime_op_path
            ort_custom_op_path = get_onnxruntime_op_path()
        except (ImportError, ModuleNotFoundError):
            warnings.warn('If input model has custom op from mmcv, \
                you may have to build mmcv with ONNXRuntime from source.')
        model.CLASSES = get_classes(dataset)
        num_classes = len(model.CLASSES)
        # check by onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)
        if test_img is not None:
            # verify on the dedicated test image instead of the traced one
            input_config['input_path'] = test_img
            one_img, one_meta = preprocess_example_input(input_config)
            tensor_data = [one_img]
        # check the numerical value
        # get pytorch output
        pytorch_results = model(tensor_data, [[one_meta]], return_loss=False)
        pytorch_results = pytorch_results[0]
        # get onnx output
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        # the single real input is whatever is not a weight initializer
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert (len(net_feed_input) == 1)
        session_options = rt.SessionOptions()
        # register custom op for onnxruntime
        if osp.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)
        sess = rt.InferenceSession(output_file, session_options)
        onnx_outputs = sess.run(None,
                                {net_feed_input[0]: one_img.detach().numpy()})
        output_names = [_.name for _ in sess.get_outputs()]
        output_shapes = [_.shape for _ in onnx_outputs]
        print(f'onnxruntime output names: {output_names}, \
            output shapes: {output_shapes}')
        nrof_out = len(onnx_outputs)
        assert nrof_out > 0, 'Must have output'
        # outputs are (boxes[, labels[, masks]]) per the export above
        with_mask = nrof_out == 3
        if nrof_out == 1:
            onnx_results = onnx_outputs[0]
        else:
            det_bboxes, det_labels = onnx_outputs[:2]
            onnx_results = bbox2result(det_bboxes, det_labels, num_classes)
            if with_mask:
                # regroup per-detection masks into per-class lists
                segm_results = onnx_outputs[2].squeeze(1)
                cls_segms = [[] for _ in range(num_classes)]
                for i in range(det_bboxes.shape[0]):
                    cls_segms[det_labels[i]].append(segm_results[i])
                onnx_results = (onnx_results, cls_segms)
        # visualize predictions

        if show:
            show_result_pyplot(model,
                               one_meta['show_img'],
                               pytorch_results,
                               title='Pytorch')
            show_result_pyplot(model,
                               one_meta['show_img'],
                               onnx_results,
                               title='ONNX')

        # compare a part of result

        if with_mask:
            compare_pairs = list(zip(onnx_results, pytorch_results))
        else:
            compare_pairs = [(onnx_results, pytorch_results)]
        for onnx_res, pytorch_res in compare_pairs:
            for o_res, p_res in zip(onnx_res, pytorch_res):
                np.testing.assert_allclose(
                    o_res,
                    p_res,
                    rtol=1e-03,
                    atol=1e-05,
                )
        print('The numerical values are the same between Pytorch and ONNX')
Code example #6 (score: 0)
def test_version_check():
    """Sanity-check digit_version ordering, including release candidates."""
    ordered_pairs = [
        ('1.0.5', '1.0.5rc0'),
        ('1.0.5', '1.0.4rc0'),
        ('1.0.5', '1.0rc0'),
        ('1.0.0', '0.6.2'),
        ('1.0.0', '0.2.16'),
        ('1.0.5rc0', '1.0.0rc0'),
        ('1.0.0rc1', '1.0.0rc0'),
        ('1.0.0rc2', '1.0.0rc0'),
        ('1.0.0rc2', '1.0.0rc1'),
        ('1.0.1rc1', '1.0.0rc1'),
        ('1.0.0', '1.0.0rc1'),
    ]
    # Each left-hand version must compare strictly greater than the right.
    for newer, older in ordered_pairs:
        assert digit_version(newer) > digit_version(older)
Code example #7 (score: 0)
File: pytorch2onnx.py — Project: zyzlimit/mmdetection
def pytorch2onnx(config_path,
                 checkpoint_path,
                 input_img,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 normalize_cfg=None,
                 dataset='coco',
                 test_img=None,
                 do_simplify=False,
                 cfg_options=None,
                 dynamic_export=None):
    """Export a detection model to ONNX and optionally verify the export.

    Args:
        config_path (str): Path of the model config file.
        checkpoint_path (str): Path of the model checkpoint.
        input_img (str): Path of the image used to trace the model.
        input_shape (tuple): Input shape as ``(N, C, H, W)``.
        opset_version (int): ONNX opset version to export with.
            Default: 11.
        show (bool): If True, export verbosely and visualize predictions
            during verification. Default: False.
        output_file (str): Destination path of the exported ONNX model.
            Default: ``'tmp.onnx'``.
        verify (bool): If True, run the PyTorch model and the exported
            ONNX model on an example input and compare outputs
            numerically. Default: False.
        normalize_cfg (dict | None): Image normalization config forwarded
            to input preprocessing.
        dataset (str): Dataset name used to look up class names.
            Default: ``'coco'``.
        test_img (str | None): Optional different image used for the
            verification pass.
        do_simplify (bool): If True, run onnx-simplifier on the exported
            model and overwrite ``output_file`` with the simplified graph.
            Default: False.
        cfg_options (dict | None): Overrides merged into the config.
        dynamic_export (bool | None): If truthy, export with dynamic batch
            and spatial axes.
    """
    input_config = {
        'input_shape': input_shape,
        'input_path': input_img,
        'normalize_cfg': normalize_cfg
    }

    # prepare original model and meta for verifying the onnx model
    orig_model = build_model_from_cfg(config_path,
                                      checkpoint_path,
                                      cfg_options=cfg_options)
    one_img, one_meta = preprocess_example_input(input_config)
    model, tensor_data = generate_inputs_and_wrap_model(
        config_path, checkpoint_path, input_config, cfg_options=cfg_options)
    output_names = ['dets', 'labels']
    if model.with_mask:
        output_names.append('masks')
    dynamic_axes = None
    if dynamic_export:
        # batch and spatial dims of the input, plus the number of
        # detections in each output, are left symbolic
        dynamic_axes = {
            'input': {
                0: 'batch',
                2: 'width',
                3: 'height'
            },
            'dets': {
                0: 'batch',
                1: 'num_dets',
            },
            'labels': {
                0: 'batch',
                1: 'num_dets',
            },
        }
        if model.with_mask:
            dynamic_axes['masks'] = {0: 'batch', 1: 'num_dets'}

    torch.onnx.export(model,
                      tensor_data,
                      output_file,
                      input_names=['input'],
                      output_names=output_names,
                      export_params=True,
                      keep_initializers_as_inputs=True,
                      do_constant_folding=True,
                      verbose=show,
                      opset_version=opset_version,
                      dynamic_axes=dynamic_axes)

    # restore the original forward so verification below runs the normal
    # (non-export) inference path
    model.forward = orig_model.forward

    # get the custom op path
    ort_custom_op_path = ''
    try:
        from mmcv.ops import get_onnxruntime_op_path
        ort_custom_op_path = get_onnxruntime_op_path()
    except (ImportError, ModuleNotFoundError):
        warnings.warn('If input model has custom op from mmcv, \
            you may have to build mmcv with ONNXRuntime from source.')

    if do_simplify:
        from mmdet import digit_version
        import onnxsim

        min_required_version = '0.3.0'
        assert digit_version(onnxsim.__version__) >= digit_version(
            min_required_version
        ), f'Requires to install onnx-simplify>={min_required_version}'

        input_dic = {'input': one_img.detach().cpu().numpy()}
        # BUGFIX: onnxsim.simplify returns (model, check_ok); the result
        # was previously discarded, so the simplified graph was never
        # written back to ``output_file``.
        model_opt, check_ok = onnxsim.simplify(
            output_file,
            input_data=input_dic,
            custom_lib=ort_custom_op_path)
        if check_ok:
            onnx.save(model_opt, output_file)
            print(f'Successfully simplified ONNX model: {output_file}')
        else:
            warnings.warn('Failed to simplify ONNX model.')
    print(f'Successfully exported ONNX model: {output_file}')

    if verify:
        from mmdet.core import get_classes, bbox2result
        from mmdet.apis import show_result_pyplot

        model.CLASSES = get_classes(dataset)
        num_classes = len(model.CLASSES)
        # check by onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)
        if dynamic_export:
            # scale up to test dynamic shape; keep dims multiples of 32
            h, w = [int((_ * 1.5) // 32 * 32) for _ in input_shape[2:]]
            input_config['input_shape'] = (1, 3, h, w)
        if test_img is not None:
            input_config['input_path'] = test_img
        one_img, one_meta = preprocess_example_input(input_config)
        tensor_data = [one_img]

        # get pytorch output
        pytorch_results = model(tensor_data, [[one_meta]], return_loss=False)
        pytorch_results = pytorch_results[0]
        # get onnx output: feed inputs are graph inputs minus initializers
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert (len(net_feed_input) == 1)
        session_options = rt.SessionOptions()
        # register custom op for ONNX Runtime
        if osp.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)
        feed_input_img = one_img.detach().numpy()
        if dynamic_export:
            # test batch with two input images
            feed_input_img = np.vstack([feed_input_img, feed_input_img])
        sess = rt.InferenceSession(output_file, session_options)
        onnx_outputs = sess.run(None, {net_feed_input[0]: feed_input_img})
        output_names = [_.name for _ in sess.get_outputs()]
        output_shapes = [_.shape for _ in onnx_outputs]
        print(f'ONNX Runtime output names: {output_names}, \
            output shapes: {output_shapes}')
        # get last image's outputs
        onnx_outputs = [_[-1] for _ in onnx_outputs]
        ort_dets, ort_labels = onnx_outputs[:2]
        onnx_results = bbox2result(ort_dets, ort_labels, num_classes)
        if model.with_mask:
            segm_results = onnx_outputs[2]
            # group per-instance masks by predicted class, mirroring the
            # detector's segmentation result format
            cls_segms = [[] for _ in range(num_classes)]
            for i in range(ort_dets.shape[0]):
                cls_segms[ort_labels[i]].append(segm_results[i])
            onnx_results = (onnx_results, cls_segms)
        # visualize predictions
        if show:
            show_result_pyplot(model,
                               one_meta['show_img'],
                               pytorch_results,
                               title='Pytorch')
            show_result_pyplot(model,
                               one_meta['show_img'],
                               onnx_results,
                               title='ONNXRuntime')

        # compare a part of result
        if model.with_mask:
            compare_pairs = list(zip(onnx_results, pytorch_results))
        else:
            compare_pairs = [(onnx_results, pytorch_results)]
        err_msg = 'The numerical values are different between Pytorch' + \
                  ' and ONNX, but it does not necessarily mean the' + \
                  ' exported ONNX model is problematic.'
        # check the numerical value
        for onnx_res, pytorch_res in compare_pairs:
            for o_res, p_res in zip(onnx_res, pytorch_res):
                np.testing.assert_allclose(o_res,
                                           p_res,
                                           rtol=1e-03,
                                           atol=1e-05,
                                           err_msg=err_msg)
        print('The numerical values are the same between Pytorch and ONNX')
コード例 #8
0
ファイル: pytorch2onnx.py プロジェクト: taofuyu/mmdetection
    # get the custom op path
    ort_custom_op_path = ''
    try:
        from mmcv.ops import get_onnxruntime_op_path
        ort_custom_op_path = get_onnxruntime_op_path()
    except (ImportError, ModuleNotFoundError):
        warnings.warn('If input model has custom op from mmcv, \
            you may have to build mmcv with ONNXRuntime from source.')

    if do_simplify:
        import onnxsim

        from mmdet import digit_version

        min_required_version = '0.3.0'
        assert digit_version(onnxsim.__version__) >= digit_version(
            min_required_version
        ), f'Requires to install onnx-simplify>={min_required_version}'

        input_dic = {'input': img_list[0].detach().cpu().numpy()}
        model_opt, check_ok = onnxsim.simplify(
            output_file,
            input_data=input_dic,
            custom_lib=ort_custom_op_path,
            dynamic_input_shape=dynamic_export)
        if check_ok:
            onnx.save(model_opt, output_file)
            print(f'Successfully simplified ONNX model: {output_file}')
        else:
            warnings.warn('Failed to simplify ONNX model.')
    print(f'Successfully exported ONNX model: {output_file}')