Exemplo n.º 1
0
def pytorch2onnx(model,
                 input,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False):
    """Export Pytorch model to ONNX model and verify the outputs are same
    between Pytorch and ONNX.

    The ``merged`` and ``trimap`` tensors each get a leading batch dimension
    and are concatenated along dim 1 to form the single ONNX input named
    ``cat_input``.

    Note:
        ``model.forward`` is replaced by ``model.forward_dummy`` and is not
        restored afterwards.

    Args:
        model (nn.Module): Pytorch model we want to export. It must provide a
            ``forward_dummy`` method.
        input (dict): Input used to trace the model; must contain the tensors
            ``'merged'`` and ``'trimap'``.
        opset_version (int): The onnx op version. Default: 11.
        show (bool): Whether to print the computation graph. Default: False.
        output_file (string): The path to where we store the output ONNX model.
            Default: `tmp.onnx`.
        verify (bool): Whether to compare the outputs between Pytorch and
            ONNX. Default: False.

    Raises:
        AssertionError: If ``verify`` is True and the outputs differ.
    """
    model.cpu().eval()
    merged = input['merged'].unsqueeze(0)
    trimap = input['trimap'].unsqueeze(0)
    # Keep the concatenated tensor in its own local instead of rebinding the
    # ``input`` argument.
    data = torch.cat((merged, trimap), 1)
    model.forward = model.forward_dummy
    # pytorch has some bug in pytorch1.3, we have to fix it
    # by replacing these existing op
    register_extra_symbolics(opset_version)
    with torch.no_grad():
        torch.onnx.export(model,
                          data,
                          output_file,
                          input_names=['cat_input'],
                          export_params=True,
                          keep_initializers_as_inputs=True,
                          verbose=show,
                          opset_version=opset_version)
    print(f'Successfully exported ONNX model: {output_file}')
    if not verify:
        return

    # check by onnx
    onnx_model = onnx.load(output_file)
    onnx.checker.check_model(onnx_model)

    # get pytorch output, only concern pred_alpha; inference only, so no
    # autograd graph is needed
    with torch.no_grad():
        pytorch_result = model(data)
    if isinstance(pytorch_result, (tuple, list)):
        pytorch_result = pytorch_result[0]
    pytorch_result = pytorch_result.detach().numpy()

    # get onnx output
    sess = rt.InferenceSession(output_file)
    onnx_result = sess.run(None, {
        'cat_input': data.detach().numpy(),
    })
    # only concern pred_alpha value
    if isinstance(onnx_result, (tuple, list)):
        onnx_result = onnx_result[0]

    # check the numerical value; raise explicitly so the check is not
    # stripped when Python runs with -O
    if not np.allclose(pytorch_result, onnx_result):
        raise AssertionError(
            'The outputs are different between Pytorch and ONNX')
    print('The numerical values are same between Pytorch and ONNX')
Exemplo n.º 2
0
def pytorch2onnx(model: nn.Module,
                 model_type: str,
                 img_path: str,
                 verbose: bool = False,
                 show: bool = False,
                 opset_version: int = 11,
                 output_file: str = 'tmp.onnx',
                 verify: bool = False,
                 dynamic_export: bool = False,
                 device_id: int = 0):
    """Export PyTorch model to ONNX model and verify the outputs are same
    between PyTorch and ONNX.

    Args:
        model (nn.Module): PyTorch model we want to export.
        model_type (str): Model type, ``'det'`` (detection) or ``'recog'``
            (recognition).
        img_path (str): Path of the image used to trace and run the model.
        verbose (bool): Whether to print the computation graph.
            Default: False.
        show (bool): Whether to visualize the final results (only used when
            ``verify`` is True). Default: False.
        opset_version (int): The onnx op version. Default: 11.
        output_file (string): The path to where we store the output ONNX model.
            Default: `tmp.onnx`.
        verify (bool): Whether to compare the outputs between PyTorch and
            ONNX. Default: False.
        dynamic_export (bool): Whether to apply dynamic export.
            Default: False.
        device_id (int): CUDA device id to place model and data.
            Default: 0

    Raises:
        ValueError: If ``model_type`` is not ``'det'`` or ``'recog'``.
    """
    # Later branches treat unknown types inconsistently (as recognition when
    # patching forward, as detection when comparing), so reject them early.
    if model_type not in ('det', 'recog'):
        raise ValueError(
            f"model_type must be 'det' or 'recog', got '{model_type}'")

    device = torch.device(type='cuda', index=device_id)
    model.to(device).eval()
    _convert_batchnorm(model)

    # prepare inputs
    mm_inputs = _prepare_data(cfg=model.cfg, imgs=img_path)
    imgs = mm_inputs.pop('img')
    img_metas = mm_inputs.pop('img_metas')

    if isinstance(imgs, list):
        imgs = imgs[0]

    img_list = [img[None, :].to(device) for img in imgs]

    dynamic_axes = None
    if dynamic_export:
        if model_type == 'det':
            dynamic_axes = {
                'input': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'
                },
                'output': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'
                }
            }
        else:
            dynamic_axes = {
                'input': {
                    0: 'batch',
                    3: 'width'
                },
                'output': {
                    0: 'batch',
                    1: 'seq_len',
                    2: 'num_classes'
                }
            }

    # Temporarily patch forward so torch.onnx.export traces inference only.
    origin_forward = model.forward
    if model_type == 'det':
        model.forward = partial(model.simple_test,
                                img_metas=img_metas,
                                rescale=True)
    else:
        model.forward = partial(model.forward,
                                img_metas=img_metas,
                                return_loss=False,
                                rescale=True)
    try:
        # pytorch has some bug in pytorch1.3, we have to fix it
        # by replacing these existing op
        register_extra_symbolics(opset_version)
        with torch.no_grad():
            torch.onnx.export(model, (img_list[0], ),
                              output_file,
                              input_names=['input'],
                              output_names=['output'],
                              export_params=True,
                              keep_initializers_as_inputs=False,
                              verbose=verbose,
                              opset_version=opset_version,
                              dynamic_axes=dynamic_axes)
    finally:
        # Restore even if export fails or verification is skipped.
        model.forward = origin_forward
    print(f'Successfully exported ONNX model: {output_file}')
    if not verify:
        return

    # check by onnx
    import onnx
    onnx_model = onnx.load(output_file)
    onnx.checker.check_model(onnx_model)

    scale_factor = (0.5, 0.5) if model_type == 'det' else (1, 0.5)
    if dynamic_export:
        # scale image for dynamic shape test
        img_list = [
            nn.functional.interpolate(img, scale_factor=scale_factor)
            for img in img_list
        ]
        if model_type == 'det':
            # ``scale_factor * 2`` repeats the tuple to four elements (tuple
            # repetition, not arithmetic); presumably this matches a
            # 4-element scale factor stored in img_metas -- verify.
            img_metas[0][0]['scale_factor'] = img_metas[0][0][
                'scale_factor'] * (scale_factor * 2)

    # check the numerical value
    # get pytorch output
    with torch.no_grad():
        pytorch_out = model.simple_test(img_list[0],
                                        img_metas[0],
                                        rescale=True)

    # get onnx output
    if model_type == 'det':
        onnx_model = ONNXRuntimeDetector(output_file, model.cfg, device_id)
    else:
        onnx_model = ONNXRuntimeRecognizer(output_file, model.cfg,
                                           device_id)
    onnx_out = onnx_model.simple_test(img_list[0],
                                      img_metas[0],
                                      rescale=True)

    def _allclose(a, b):
        # np.allclose raises ValueError on non-broadcastable shapes (e.g. a
        # different number of points); treat that as "different".
        try:
            return np.allclose(np.array(a),
                               np.array(b),
                               rtol=1e-4,
                               atol=1e-4)
        except ValueError:
            return False

    # compare results; lengths are checked first because zip() silently
    # truncates to the shorter sequence
    if model_type == 'recog':
        same = len(onnx_out) == len(pytorch_out) and all(
            onnx_res['text'] == torch_res['text']
            and _allclose(onnx_res['score'], torch_res['score'])
            for onnx_res, torch_res in zip(onnx_out, pytorch_out))
    else:
        onnx_bounds = onnx_out[0]['boundary_result']
        torch_bounds = pytorch_out[0]['boundary_result']
        same = len(onnx_bounds) == len(torch_bounds) and all(
            _allclose(onnx_bound, torch_bound)
            for onnx_bound, torch_bound in zip(onnx_bounds, torch_bounds))
    same_diff = 'same' if same else 'different'
    print('The outputs are {} between PyTorch and ONNX'.format(same_diff))

    if show:
        onnx_img = onnx_model.show_result(img_path,
                                          onnx_out[0],
                                          out_file='onnx.jpg',
                                          show=False)
        pytorch_img = model.show_result(img_path,
                                        pytorch_out[0],
                                        out_file='pytorch.jpg',
                                        show=False)
        if onnx_img is None:
            onnx_img = cv2.imread(img_path)
        if pytorch_img is None:
            pytorch_img = cv2.imread(img_path)

        cv2.imshow('PyTorch', pytorch_img)
        cv2.imshow('ONNXRuntime', onnx_img)
        cv2.waitKey()
Exemplo n.º 3
0
def pytorch2onnx(model,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False):
    """Export Pytorch model to ONNX model and verify the outputs are same
    between Pytorch and ONNX.

    ``model.forward`` is temporarily patched for tracing and is always
    restored, even if the export fails.

    Args:
        model (nn.Module): Pytorch model we want to export.
        input_shape (tuple): Use this input shape to construct
            the corresponding dummy input and execute the model.
        opset_version (int): The onnx op version. Default: 11.
        show (bool): Whether to print the computation graph. Default: False.
        output_file (string): The path to where we store the output ONNX model.
            Default: `tmp.onnx`.
        verify (bool): Whether to compare the outputs between Pytorch and
            ONNX. Default: False.

    Raises:
        AssertionError: If the exported graph does not have exactly one
            non-initializer input (only checked when ``verify`` is True).
        ValueError: If ``verify`` is True and the outputs differ.
    """
    model.cpu().eval()

    if isinstance(model.decode_head, nn.ModuleList):
        num_classes = model.decode_head[-1].num_classes
    else:
        num_classes = model.decode_head.num_classes

    mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')

    img_list = [img[None, :] for img in imgs]
    img_meta_list = [[img_meta] for img_meta in img_metas]

    # replace original forward function; restored in ``finally`` so the
    # model is never left patched
    origin_forward = model.forward
    model.forward = partial(model.forward,
                            img_metas=img_meta_list,
                            return_loss=False)
    try:
        register_extra_symbolics(opset_version)
        with torch.no_grad():
            torch.onnx.export(model, (img_list, ),
                              output_file,
                              export_params=True,
                              keep_initializers_as_inputs=True,
                              verbose=show,
                              opset_version=opset_version)
            print(f'Successfully exported ONNX model: {output_file}')
    finally:
        model.forward = origin_forward

    if not verify:
        return

    # check by onnx
    import onnx
    onnx_model = onnx.load(output_file)
    onnx.checker.check_model(onnx_model)

    # check the numerical value
    # get pytorch output (inference only, no autograd graph needed)
    with torch.no_grad():
        pytorch_result = model(img_list, img_meta_list,
                               return_loss=False)[0]

    # get onnx output: the feed input is the graph input that is not an
    # initializer (initializers are listed as inputs because
    # keep_initializers_as_inputs=True above)
    graph_inputs = {node.name for node in onnx_model.graph.input}
    initializers = {node.name for node in onnx_model.graph.initializer}
    net_feed_input = list(graph_inputs - initializers)
    if len(net_feed_input) != 1:
        raise AssertionError(
            f'Expected exactly one feed input, got {net_feed_input}')
    sess = rt.InferenceSession(output_file)
    onnx_result = sess.run(
        None, {net_feed_input[0]: img_list[0].detach().numpy()})[0]
    if not np.allclose(pytorch_result, onnx_result):
        raise ValueError(
            'The outputs are different between Pytorch and ONNX')
    print('The outputs are same between Pytorch and ONNX')
Exemplo n.º 4
0
def pytorch2onnx(model,
                 mm_inputs,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 dynamic_export=False):
    """Export Pytorch model to ONNX model and verify the outputs are same
    between Pytorch and ONNX.

    ``model.forward`` is temporarily patched for tracing and is always
    restored, even if the export fails.

    Args:
        model (nn.Module): Pytorch model we want to export.
        mm_inputs (dict): Contain the input tensors (``'imgs'``) and
            ``'img_metas'`` information. The dict is read, not modified.
        opset_version (int): The onnx op version. Default: 11.
        show (bool): Whether to print the computation graph during export
            and, when ``verify`` is True, display the PyTorch and ONNX
            Runtime segmentation results. Default: False.
        output_file (string): The path to where we store the output ONNX model.
            Default: `tmp.onnx`.
        verify (bool): Whether to compare the outputs between Pytorch and
            ONNX. Default: False.
        dynamic_export (bool): Whether to export ONNX with dynamic axis.
            Default: False.

    Raises:
        AssertionError: When ``verify`` is True, if the exported graph does
            not have exactly one non-initializer input or the outputs differ
            (the latter via ``np.testing.assert_allclose``).
    """
    model.cpu().eval()
    test_mode = model.test_cfg.mode

    if isinstance(model.decode_head, nn.ModuleList):
        num_classes = model.decode_head[-1].num_classes
    else:
        num_classes = model.decode_head.num_classes

    # Index rather than pop so the caller's dict is left untouched.
    imgs = mm_inputs['imgs']
    img_metas = mm_inputs['img_metas']

    img_list = [img[None, :] for img in imgs]
    img_meta_list = [[img_meta] for img_meta in img_metas]
    # update img_meta
    img_list, img_meta_list = _update_input_img(img_list, img_meta_list)

    dynamic_axes = None
    if dynamic_export:
        if test_mode == 'slide':
            dynamic_axes = {'input': {0: 'batch'}, 'output': {1: 'batch'}}
        else:
            dynamic_axes = {
                'input': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'
                },
                'output': {
                    1: 'batch',
                    2: 'height',
                    3: 'width'
                }
            }

    # replace original forward function; restored in ``finally`` so the
    # model is never left patched
    origin_forward = model.forward
    model.forward = partial(
        model.forward,
        img_metas=img_meta_list,
        return_loss=False,
        rescale=True)
    try:
        register_extra_symbolics(opset_version)
        with torch.no_grad():
            torch.onnx.export(
                model, (img_list, ),
                output_file,
                input_names=['input'],
                output_names=['output'],
                export_params=True,
                keep_initializers_as_inputs=False,
                verbose=show,
                opset_version=opset_version,
                dynamic_axes=dynamic_axes)
            print(f'Successfully exported ONNX model: {output_file}')
    finally:
        model.forward = origin_forward

    if not verify:
        return

    # check by onnx
    import onnx
    onnx_model = onnx.load(output_file)
    onnx.checker.check_model(onnx_model)

    if dynamic_export and test_mode == 'whole':
        # scale image for dynamic shape test
        img_list = [resize(img, scale_factor=1.5) for img in img_list]
        # concatenate flipped images for the batch test
        flip_img_list = [img.flip(-1) for img in img_list]
        img_list = [
            torch.cat((ori_img, flip_img), 0)
            for ori_img, flip_img in zip(img_list, flip_img_list)
        ]

        # update img_meta
        img_list, img_meta_list = _update_input_img(
            img_list, img_meta_list, test_mode == 'whole')

    # check the numerical value
    # get pytorch output
    with torch.no_grad():
        pytorch_result = model(img_list, img_meta_list, return_loss=False)
        pytorch_result = np.stack(pytorch_result, 0)

    # get onnx output: the feed input is the graph input that is not an
    # initializer
    input_all = [node.name for node in onnx_model.graph.input]
    input_initializer = [
        node.name for node in onnx_model.graph.initializer
    ]
    net_feed_input = list(set(input_all) - set(input_initializer))
    if len(net_feed_input) != 1:
        raise AssertionError(
            f'Expected exactly one feed input, got {net_feed_input}')
    sess = rt.InferenceSession(output_file)
    onnx_result = sess.run(
        None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0]
    # show segmentation results
    if show:
        import cv2
        import os.path as osp
        img = img_meta_list[0][0]['filename']
        if not osp.exists(img):
            img = imgs[0][:3, ...].permute(1, 2, 0) * 255
            img = img.detach().numpy().astype(np.uint8)
            ori_shape = img.shape[:2]
        else:
            ori_shape = LoadImage()({'img': img})['ori_shape']

        # resize onnx_result to ori_shape
        onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
                                  (ori_shape[1], ori_shape[0]))
        show_result_pyplot(
            model,
            img, (onnx_result_, ),
            palette=model.PALETTE,
            block=False,
            title='ONNXRuntime',
            opacity=0.5)

        # resize pytorch_result to ori_shape
        pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8),
                                     (ori_shape[1], ori_shape[0]))
        show_result_pyplot(
            model,
            img, (pytorch_result_, ),
            title='PyTorch',
            palette=model.PALETTE,
            opacity=0.5)
    # compare results; dividing by num_classes scales the compared values
    # down before the tolerance check
    np.testing.assert_allclose(
        pytorch_result.astype(np.float32) / num_classes,
        onnx_result.astype(np.float32) / num_classes,
        rtol=1e-5,
        atol=1e-5,
        err_msg='The outputs are different between Pytorch and ONNX')
    print('The outputs are same between Pytorch and ONNX')
Exemplo n.º 5
0
def pytorch2onnx(model,
                 input,
                 model_type,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 dynamic_export=False):
    """Export Pytorch model to ONNX model and verify the outputs are same
    between Pytorch and ONNX.

    Note:
        ``model.forward`` is replaced by ``model.forward_dummy`` and is not
        restored afterwards.

    Args:
        model (nn.Module): Pytorch model we want to export. It must provide a
            ``forward_dummy`` method.
        input (dict): We need to use this input to execute the model. For
            ``'mattor'`` it must contain ``'merged'`` and ``'trimap'``; for
            ``'restorer'`` it must contain ``'lq'``.
        model_type (str): Either ``'mattor'`` or ``'restorer'``.
        opset_version (int): The onnx op version. Default: 11.
        show (bool): Whether to print the computation graph and, when
            ``verify`` is True, display both outputs. Default: False.
        output_file (string): The path to where we store the output ONNX model.
            Default: `tmp.onnx`.
        verify (bool): Whether to compare the outputs between Pytorch and
            ONNX. Default: False.
        dynamic_export (bool): Whether to export with dynamic batch, height
            and width axes. Default: False.

    Raises:
        ValueError: If ``model_type`` is not supported.
        AssertionError: If ``verify`` is True and the outputs differ.
    """
    model.cpu().eval()

    if model_type == 'mattor':
        merged = input['merged'].unsqueeze(0)
        trimap = input['trimap'].unsqueeze(0)
        data = torch.cat((merged, trimap), 1)
    elif model_type == 'restorer':
        data = input['lq'].unsqueeze(0)
    else:
        # Without this, ``data`` would be unbound and export would fail later
        # with an obscure UnboundLocalError.
        raise ValueError(f"Unsupported model_type '{model_type}', expected "
                         "'mattor' or 'restorer'")
    model.forward = model.forward_dummy
    # pytorch has some bug in pytorch1.3, we have to fix it
    # by replacing these existing op
    register_extra_symbolics(opset_version)
    dynamic_axes = None
    if dynamic_export:
        dynamic_axes = {
            'input': {
                0: 'batch',
                2: 'height',
                3: 'width'
            },
            'output': {
                0: 'batch',
                2: 'height',
                3: 'width'
            }
        }
    with torch.no_grad():
        torch.onnx.export(model,
                          data,
                          output_file,
                          input_names=['input'],
                          output_names=['output'],
                          export_params=True,
                          keep_initializers_as_inputs=False,
                          verbose=show,
                          opset_version=opset_version,
                          dynamic_axes=dynamic_axes)
    print(f'Successfully exported ONNX model: {output_file}')
    if not verify:
        return

    # check by onnx
    onnx_model = onnx.load(output_file)
    onnx.checker.check_model(onnx_model)

    if dynamic_export:
        # scale image for dynamic shape test
        data = torch.nn.functional.interpolate(data, scale_factor=1.1)

        # concatenate a flipped copy for the batch test
        flip_data = data.flip(-1)
        data = torch.cat((data, flip_data), 0)

    # get pytorch output, only concern pred_alpha
    with torch.no_grad():
        pytorch_result = model(data)
    if isinstance(pytorch_result, (tuple, list)):
        pytorch_result = pytorch_result[0]
    pytorch_result = pytorch_result.detach().numpy()
    # get onnx output
    sess = rt.InferenceSession(output_file)
    onnx_result = sess.run(None, {
        'input': data.detach().numpy(),
    })
    # only concern pred_alpha value
    if isinstance(onnx_result, (tuple, list)):
        onnx_result = onnx_result[0]

    if show:
        # CHW -> HWC, clip to [0, 1], then reverse the channel order
        pytorch_visualize = pytorch_result[0].transpose(1, 2, 0)
        pytorch_visualize = np.clip(pytorch_visualize, 0, 1)[:, :, ::-1]
        onnx_visualize = onnx_result[0].transpose(1, 2, 0)
        onnx_visualize = np.clip(onnx_visualize, 0, 1)[:, :, ::-1]

        cv2.imshow('PyTorch', pytorch_visualize)
        cv2.imshow('ONNXRuntime', onnx_visualize)
        cv2.waitKey()

    # check the numerical value; raise explicitly so the check is not
    # stripped when Python runs with -O
    if not np.allclose(pytorch_result, onnx_result, rtol=1e-5, atol=1e-5):
        raise AssertionError(
            'The outputs are different between Pytorch and ONNX')
    print('The numerical values are same between Pytorch and ONNX')
Exemplo n.º 6
0
def pytorch2onnx(model,
                 input_shape,
                 opset_version=11,
                 dynamic_export=False,
                 show=False,
                 output_file='tmp.onnx',
                 do_simplify=False,
                 verify=False):
    """Export Pytorch model to ONNX model and verify the outputs are same
    between Pytorch and ONNX.

    ``model.forward`` is temporarily patched for tracing and is always
    restored, even if the export fails.

    Args:
        model (nn.Module): Pytorch model we want to export.
        input_shape (tuple): Use this input shape to construct
            the corresponding dummy input and execute the model.
        opset_version (int): The onnx op version. Default: 11.
        dynamic_export (bool): Whether to export with dynamic batch, height
            and width input axes. Default: False.
        show (bool): Whether to print the computation graph. Default: False.
        output_file (string): The path to where we store the output ONNX model.
            Default: `tmp.onnx`.
        do_simplify (bool): Whether to simplify the exported model with
            onnx-simplifier (requires version >= 0.3.0). Default: False.
        verify (bool): Whether to compare the outputs between Pytorch and
            ONNX. Default: False.

    Raises:
        AttributeError: If neither ``model.head`` nor ``model.backbone``
            provides a usable ``num_classes``.
        AssertionError: If simplification is requested with a too-old
            onnx-simplifier, or the exported graph does not have exactly one
            non-initializer input during verification.
        ValueError: If ``verify`` is True and the outputs differ.
    """
    model.cpu().eval()

    if hasattr(model.head, 'num_classes'):
        num_classes = model.head.num_classes
    # Some backbones use `num_classes=-1` to disable top classifier.
    elif getattr(model.backbone, 'num_classes', -1) > 0:
        num_classes = model.backbone.num_classes
    else:
        raise AttributeError('Cannot find "num_classes" in both head and '
                             'backbone, please check the config file.')

    def _doubled_spatial_shape():
        # Same shape with dims 2 and 3 doubled, used to exercise the dynamic
        # axes (assumes input_shape is (N, C, H, W)).
        return (input_shape[0], input_shape[1], input_shape[2] * 2,
                input_shape[3] * 2)

    mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    imgs = mm_inputs.pop('imgs')
    img_list = [img[None, :] for img in imgs]

    # support dynamic shape export (NCHW: dim 2 is height, dim 3 is width)
    if dynamic_export:
        dynamic_axes = {
            'input': {
                0: 'batch',
                2: 'height',
                3: 'width'
            },
            'probs': {
                0: 'batch'
            }
        }
    else:
        dynamic_axes = {}

    # replace original forward function; restored in ``finally`` so the
    # model is never left patched
    origin_forward = model.forward
    model.forward = partial(model.forward, img_metas={}, return_loss=False)
    try:
        register_extra_symbolics(opset_version)
        with torch.no_grad():
            torch.onnx.export(model, (img_list, ),
                              output_file,
                              input_names=['input'],
                              output_names=['probs'],
                              export_params=True,
                              keep_initializers_as_inputs=True,
                              dynamic_axes=dynamic_axes,
                              verbose=show,
                              opset_version=opset_version)
            print(f'Successfully exported ONNX model: {output_file}')
    finally:
        model.forward = origin_forward

    if do_simplify:
        import onnx
        import onnxsim
        from mmcv import digit_version

        min_required_version = '0.3.0'
        # The requirement is on onnx-simplifier itself, so check its version.
        if digit_version(onnxsim.__version__) < digit_version(
                min_required_version):
            raise AssertionError(
                f'Requires to install onnx-simplify>={min_required_version}')

        # A separate local keeps ``input_shape`` intact for verification.
        if dynamic_export:
            simplify_shape = _doubled_spatial_shape()
        else:
            simplify_shape = tuple(input_shape[:4])
        simplify_imgs = _demo_mm_inputs(simplify_shape,
                                        num_classes).pop('imgs')
        input_dic = {'input': simplify_imgs.detach().cpu().numpy()}
        input_shape_dic = {'input': list(simplify_shape)}

        model_opt, check_ok = onnxsim.simplify(
            output_file,
            input_shapes=input_shape_dic,
            input_data=input_dic,
            dynamic_input_shape=dynamic_export)
        if check_ok:
            onnx.save(model_opt, output_file)
            print(f'Successfully simplified ONNX model: {output_file}')
        else:
            print('Failed to simplify ONNX model.')

    if not verify:
        return

    # check by onnx
    import onnx
    onnx_model = onnx.load(output_file)
    onnx.checker.check_model(onnx_model)

    # test the dynamic model
    if dynamic_export:
        dynamic_test_inputs = _demo_mm_inputs(_doubled_spatial_shape(),
                                              num_classes)
        imgs = dynamic_test_inputs.pop('imgs')
        img_list = [img[None, :] for img in imgs]

    # check the numerical value
    # get pytorch output (inference only, no autograd graph needed)
    with torch.no_grad():
        pytorch_result = model(img_list, img_metas={},
                               return_loss=False)[0]

    # get onnx output: the feed input is the graph input that is not an
    # initializer
    input_all = [node.name for node in onnx_model.graph.input]
    input_initializer = [
        node.name for node in onnx_model.graph.initializer
    ]
    net_feed_input = list(set(input_all) - set(input_initializer))
    if len(net_feed_input) != 1:
        raise AssertionError(
            f'Expected exactly one feed input, got {net_feed_input}')
    sess = rt.InferenceSession(output_file)
    onnx_result = sess.run(
        None, {net_feed_input[0]: img_list[0].detach().numpy()})[0]
    if not np.allclose(pytorch_result, onnx_result):
        raise ValueError(
            'The outputs are different between Pytorch and ONNX')
    print('The outputs are same between Pytorch and ONNX')
Exemplo n.º 7
0
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import mmcv
import torch

from mmcv.onnx import register_extra_symbolics
from mmseg.models.segmentors import EncoderDecoder

# Register mmcv's extra ONNX symbolic functions for opset 11 once, at import
# time. Presumably this lets ops used by the segmentors below export
# correctly -- verify against mmcv.onnx.
register_extra_symbolics(opset=11)


def _convert_batchnorm(module):
    for name, child in module.named_children():
        if isinstance(child, torch.nn.SyncBatchNorm):
            new_child = torch.nn.BatchNorm2d(child.num_features, child.eps,
                                             child.momentum, child.affine,
                                             child.track_running_stats)
            if child.affine:
                new_child.weight.data = child.weight.data.clone().detach()
                new_child.bias.data = child.bias.data.clone().detach()
                # keep requires_grad unchanged
                new_child.weight.requires_grad = child.weight.requires_grad
                new_child.bias.requires_grad = child.bias.requires_grad
            new_child.running_mean = child.running_mean