Example 1
def test_scatternd():

    def func(data):
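        # in-place slice assignments like these are exported as ONNX ScatterND nodes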
        data[:, :-2] += 1
        data[:2, :] -= 1
        return data

    data = torch.zeros(4, 4).cuda()
    wrapped_model = WrapFunction(func).eval().cuda()

    input_names = ['input']
    output_names = ['output']

    with torch.no_grad():
        torch.onnx.export(
            wrapped_model, (data.clone(), ),
            onnx_file,
            export_params=True,
            keep_initializers_as_inputs=True,
            input_names=input_names,
            output_names=output_names,
            opset_version=11)

    onnx_model = onnx.load(onnx_file)

    # create trt engine and wrapper
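    # each entry maps an input name to its [min_shape, opt_shape, max_shape] profile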
    opt_shape_dict = {
        'input': [list(data.shape),
                  list(data.shape),
                  list(data.shape)],
    }
    # trt config
    fp16_mode = False
    max_workspace_size = 1 << 30

    trt_engine = onnx2trt(
        onnx_model,
        opt_shape_dict,
        fp16_mode=fp16_mode,
        max_workspace_size=max_workspace_size)

    save_trt_engine(trt_engine, trt_file)
    trt_model = TRTWrapper(trt_file, input_names, output_names)

    with torch.no_grad():
        trt_outputs = trt_model({'input': data.clone()})
        trt_results = trt_outputs['output']

    # compute pytorch_output
    with torch.no_grad():
        pytorch_results = wrapped_model(data.clone())

    # allclose
    if os.path.exists(onnx_file):
        os.remove(onnx_file)
    if os.path.exists(trt_file):
        os.remove(trt_file)
    assert torch.allclose(pytorch_results, trt_results)
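WrapFunction above is a small helper that adapts a plain callable into an nn.Module so torch.onnx.export will accept it. A minimal sketch of such a helper, assuming it simply forwards its arguments to the wrapped callable (the real implementation lives in mmcv's test utilities):

import torch.nn as nn


class WrapFunction(nn.Module):
    # Wrap an arbitrary callable as a module so ONNX export accepts it.

    def __init__(self, wrapped_function):
        super().__init__()
        self.wrapped_function = wrapped_function

    def forward(self, *args, **kwargs):
        # Delegate straight to the wrapped callable.
        return self.wrapped_function(*args, **kwargs)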
Example 2
def test_corner_pool(mode):
    try:
        from mmcv.ops import CornerPool
    except (ImportError, ModuleNotFoundError):
        pytest.skip('test requires compilation')

    opset = 11
    # register custom op `mmcv::MMCVCornerPool`
    from mmcv.onnx.symbolic import register_extra_symbolics
    register_extra_symbolics(opset)

    # trt config
    fp16_mode = False
    max_workspace_size = 1 << 30

    inputs = [
        # (n, c, h, w)
        torch.rand((2, 3, 5, 5)),
        torch.rand((1, 2, 4, 6)),
        torch.rand((2, 1, 3, 2)),
    ]

    class CornerPoolWrapper(CornerPool):
        def __init__(self, mode):
            super(CornerPoolWrapper, self).__init__(mode)

        def forward(self, x):
            # use `corner_pool` instead of `torch.cummax` so the op
            # exports consistently across torch versions
            return self.corner_pool.apply(x)

    wrapped_model = CornerPoolWrapper(mode).cuda()
    for input in inputs:
        input = input.cuda()

        with torch.no_grad():
            torch.onnx.export(wrapped_model, (input, ),
                              onnx_file,
                              export_params=True,
                              keep_initializers_as_inputs=True,
                              input_names=['input'],
                              output_names=['output'],
                              opset_version=opset)
        onnx_model = onnx.load(onnx_file)

        # create trt engine and wrapper
        opt_shape_dict = {
            'input': [list(input.shape),
                      list(input.shape),
                      list(input.shape)],
        }
        trt_engine = onnx2trt(onnx_model,
                              opt_shape_dict,
                              fp16_mode=fp16_mode,
                              max_workspace_size=max_workspace_size)
        save_trt_engine(trt_engine, trt_file)
        trt_model = TRTWrapper(trt_file, ['input'], ['output'])

        with torch.no_grad():
            trt_outputs = trt_model({'input': input})
            trt_pool_feat = trt_outputs['output']

        # compute pytorch_output
        with torch.no_grad():
            pytorch_pool_feat = wrapped_model(input)

        # allclose
        if os.path.exists(onnx_file):
            os.remove(onnx_file)
        if os.path.exists(trt_file):
            os.remove(trt_file)
        assert torch.allclose(pytorch_pool_feat, trt_pool_feat, atol=1e-5)
Example 3
def test_instance_norm(dynamic_export, fp16_mode):

    n, c, h, w = 2, 3, 10, 10
    data = torch.randn(n, c, h, w).cuda()
    norm = nn.InstanceNorm2d(c, affine=True)

    wrapped_model = WrapFunction(norm).eval().cuda()

    input_names = ['input']
    output_names = ['output']
    dynamic_axes = None
    if dynamic_export:
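        # mark batch, height and width as dynamic so the exported graph
        # accepts variable input sizes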
        dynamic_axes = {
            'input': {
                0: 'n',
                2: 'h',
                3: 'w',
            },
            'output': {
                0: 'n',
                2: 'h',
                3: 'w',
            },
        }
    with torch.no_grad():
        torch.onnx.export(wrapped_model, (data.clone(), ),
                          onnx_file,
                          export_params=True,
                          keep_initializers_as_inputs=True,
                          input_names=input_names,
                          output_names=output_names,
                          dynamic_axes=dynamic_axes,
                          opset_version=11)

    onnx_model = onnx.load(onnx_file)

    # create trt engine and wrapper
    if dynamic_export:
        opt_shape_dict = {
            'input':
            [list(data.shape),
             list(data.shape), [2 * n, c, 2 * h, 2 * w]],
        }
    else:
        opt_shape_dict = {
            'input': [list(data.shape),
                      list(data.shape),
                      list(data.shape)],
        }
    # trt config
    max_workspace_size = 1 << 30

    trt_engine = onnx2trt(onnx_model,
                          opt_shape_dict,
                          fp16_mode=fp16_mode,
                          max_workspace_size=max_workspace_size)

    save_trt_engine(trt_engine, trt_file)
    trt_model = TRTWrapper(trt_file, input_names, output_names)

    with torch.no_grad():
        trt_outputs = trt_model({'input': data.clone()})
        trt_results = trt_outputs['output']

    # compute pytorch_output
    with torch.no_grad():
        pytorch_results = wrapped_model(data.clone())

    # allclose
    if os.path.exists(onnx_file):
        os.remove(onnx_file)
    if os.path.exists(trt_file):
        os.remove(trt_file)
    assert torch.allclose(pytorch_results, trt_results)
Example 4
def test_cummin_cummax(func: Callable):
    # Note: `cummax`/`cummin` is generally exportable to ONNX as long as
    # the pytorch version >= 1.5.0, since `torch.cummax` is only
    # supported with torch >= 1.5.0. But when `cummax` or `cummin`
    # serves as an intermediate component whose outputs are used as
    # inputs to other modules, pytorch >= 1.7.0 is required. Otherwise
    # an error like the following appears:
    # `RuntimeError: tuple appears in op that does not forward tuples,
    # unsupported kind: prim::PythonOp`
    from packaging import version
    if version.parse(torch.__version__) < version.parse('1.7.0'):
        pytest.skip('test_cummax_cummin should be run with pytorch >= 1.7.0')

    opset = 11
    # register custom op `mmcv::cummax` and `mmcv::cummin`
    from mmcv.onnx.symbolic import register_extra_symbolics
    register_extra_symbolics(opset)

    input_list = [
        # arbitrary shape, e.g. 1-D, 2-D, 3-D, ...
        torch.rand((2, 3, 4, 1, 5)).cuda(),
        torch.rand((1)).cuda()
    ]

    input_names = ['input']
    output_names = ['output', 'indices']

    for input in input_list:
        ndims = input.dim()
        # valid dim range is [-ndims, ndims-1];
        # test every valid `dim` value
        for dim in range(-ndims, ndims):
            cummax_func = partial(func, dim=dim)
            wrapped_model = WrapFunction(cummax_func).eval().cuda()

            with torch.no_grad():
                torch.onnx.export(wrapped_model,
                                  input,
                                  onnx_file,
                                  export_params=True,
                                  keep_initializers_as_inputs=False,
                                  input_names=input_names,
                                  output_names=output_names,
                                  opset_version=opset)

            onnx_model = onnx.load(onnx_file)

            # create trt engine and wrapper
            opt_shape_dict = {
                'input':
                [list(input.shape),
                 list(input.shape),
                 list(input.shape)]
            }
            # trt config
            fp16_mode = False
            max_workspace_size = 1 << 30

            trt_engine = onnx2trt(onnx_model,
                                  opt_shape_dict,
                                  fp16_mode=fp16_mode,
                                  max_workspace_size=max_workspace_size)

            # remove ONNX model after conversion
            if os.path.exists(onnx_file):
                os.remove(onnx_file)

            # save TensorRT model
            save_trt_engine(trt_engine, trt_file)

            # load and wrap TensorRT model
            trt_model = TRTWrapper(trt_file)

            # remove trt model after loading
            if os.path.exists(trt_file):
                os.remove(trt_file)

            # compute trt output
            with torch.no_grad():
                trt_results = trt_model({'input': input.contiguous().clone()})
                trt_output = trt_results['output']
                trt_indices = trt_results['indices']

            # compute pytorch output
            with torch.no_grad():
                pytorch_results = wrapped_model(input.clone())
                pytorch_output = pytorch_results[0]
                pytorch_indices = pytorch_results[1]

            torch.testing.assert_allclose(trt_output, pytorch_output)
            torch.testing.assert_allclose(trt_indices, pytorch_indices)
Example 5
def test_grid_sample(mode, padding_mode, align_corners):
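    # mmcv's extra symbolics include grid_sampler, which makes
    # F.grid_sample exportable at opset 11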
    from mmcv.onnx.symbolic import register_extra_symbolics

    register_extra_symbolics(11)

    input = torch.rand(1, 1, 10, 10).cuda()
    grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
    grid = F.affine_grid(grid, (1, 1, 15, 15)).type_as(input).cuda()

    def func(input, grid):
        return F.grid_sample(input,
                             grid,
                             mode=mode,
                             padding_mode=padding_mode,
                             align_corners=align_corners)

    wrapped_model = WrapFunction(func).eval().cuda()

    input_names = ['input', 'grid']
    output_names = ['output']

    with torch.no_grad():
        torch.onnx.export(wrapped_model, (input.clone(), grid.clone()),
                          onnx_file,
                          export_params=True,
                          keep_initializers_as_inputs=True,
                          input_names=input_names,
                          output_names=output_names,
                          opset_version=11)

    onnx_model = onnx.load(onnx_file)

    # create trt engine and wrapper
    opt_shape_dict = {
        'input': [list(input.shape),
                  list(input.shape),
                  list(input.shape)],
        'grid': [list(grid.shape),
                 list(grid.shape),
                 list(grid.shape)],
    }
    # trt config
    fp16_mode = False
    max_workspace_size = 1 << 30

    trt_engine = onnx2trt(onnx_model,
                          opt_shape_dict,
                          fp16_mode=fp16_mode,
                          max_workspace_size=max_workspace_size)

    save_trt_engine(trt_engine, trt_file)
    trt_model = TRTWrapper(trt_file, input_names, output_names)

    with torch.no_grad():
        trt_outputs = trt_model({'input': input.clone(), 'grid': grid.clone()})
        trt_results = trt_outputs['output']

    # compute pytorch_output
    with torch.no_grad():
        pytorch_results = wrapped_model(input.clone(), grid.clone())

    # allclose
    if os.path.exists(onnx_file):
        os.remove(onnx_file)
    if os.path.exists(trt_file):
        os.remove(trt_file)
    assert torch.allclose(pytorch_results, trt_results)
Example 6
def test_roialign():
    try:
        from mmcv.ops import RoIAlign
    except (ImportError, ModuleNotFoundError):
        pytest.skip('test requires compilation')

    # trt config
    fp16_mode = False
    max_workspace_size = 1 << 30

    # roi align config
    pool_h = 2
    pool_w = 2
    spatial_scale = 1.0
    sampling_ratio = 2

    inputs = [([[[[1., 2.], [3., 4.]]]], [[0., 0., 0., 1., 1.]]),
              ([[[[1., 2.], [3., 4.]], [[4., 3.],
                                        [2., 1.]]]], [[0., 0., 0., 1., 1.]]),
              ([[[[1., 2., 5., 6.], [3., 4., 7., 8.], [9., 10., 13., 14.],
                  [11., 12., 15., 16.]]]], [[0., 0., 0., 3., 3.]])]

    wrapped_model = RoIAlign((pool_w, pool_h), spatial_scale, sampling_ratio,
                             'avg', True).cuda()
    for case in inputs:
        np_input = np.array(case[0], dtype=np.float32)
        np_rois = np.array(case[1], dtype=np.float32)
        input = torch.from_numpy(np_input).cuda()
        rois = torch.from_numpy(np_rois).cuda()

        with torch.no_grad():
            torch.onnx.export(wrapped_model, (input, rois),
                              onnx_file,
                              export_params=True,
                              keep_initializers_as_inputs=True,
                              input_names=['input', 'rois'],
                              output_names=['roi_feat'],
                              opset_version=11)
        onnx_model = onnx.load(onnx_file)

        # create trt engine and wrapper
        opt_shape_dict = {
            'input': [list(input.shape),
                      list(input.shape),
                      list(input.shape)],
            'rois': [list(rois.shape),
                     list(rois.shape),
                     list(rois.shape)]
        }
        trt_engine = onnx2trt(onnx_model,
                              opt_shape_dict,
                              fp16_mode=fp16_mode,
                              max_workspace_size=max_workspace_size)
        save_trt_engine(trt_engine, trt_file)
        trt_model = TRTWrapper(trt_file, ['input', 'rois'], ['roi_feat'])

        with torch.no_grad():
            trt_outputs = trt_model({'input': input, 'rois': rois})
            trt_roi_feat = trt_outputs['roi_feat']

        # compute pytorch_output
        with torch.no_grad():
            pytorch_roi_feat = wrapped_model(input, rois)

        # allclose
        if os.path.exists(onnx_file):
            os.remove(onnx_file)
        if os.path.exists(trt_file):
            os.remove(trt_file)
        assert torch.allclose(pytorch_roi_feat, trt_roi_feat)
Example 7
def test_modulated_deform_conv(with_bias):
    try:
        from mmcv.ops import ModulatedDeformConv2dPack
    except (ImportError, ModuleNotFoundError):
        pytest.skip('test requires compilation')

    input = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]

    x = torch.Tensor(input).cuda()
    model = ModulatedDeformConv2dPack(1,
                                      1,
                                      kernel_size=(2, 2),
                                      stride=1,
                                      padding=1,
                                      deform_groups=1,
                                      bias=with_bias)
    model.weight.data.fill_(1.)
    model.type(torch.float32)
    model = model.cuda().eval()

    input_names = ['input']
    output_names = ['output']

    with torch.no_grad():
        torch.onnx.export(model, (x.clone(), ),
                          onnx_file,
                          export_params=True,
                          keep_initializers_as_inputs=True,
                          input_names=input_names,
                          output_names=output_names,
                          opset_version=11)

    onnx_model = onnx.load(onnx_file)

    # create trt engine and wrapper
    opt_shape_dict = {
        'input': [list(x.shape), list(x.shape),
                  list(x.shape)],
    }
    # trt config
    fp16_mode = False
    max_workspace_size = 1 << 30

    trt_engine = onnx2trt(onnx_model,
                          opt_shape_dict,
                          fp16_mode=fp16_mode,
                          max_workspace_size=max_workspace_size)

    save_trt_engine(trt_engine, trt_file)
    trt_model = TRTWrapper(trt_file, input_names, output_names)

    with torch.no_grad():
        trt_outputs = trt_model({'input': x.clone()})
        trt_results = trt_outputs['output']

    # compute pytorch_output
    with torch.no_grad():
        pytorch_results = model(x.clone())

    # allclose
    if os.path.exists(onnx_file):
        os.remove(onnx_file)
    if os.path.exists(trt_file):
        os.remove(trt_file)
    torch.testing.assert_allclose(pytorch_results, trt_results)
Example 8
def onnx2tensorrt(onnx_file: str,
                  model_type: str,
                  trt_file: str,
                  config: dict,
                  input_config: dict,
                  fp16: bool = False,
                  verify: bool = False,
                  show: bool = False,
                  workspace_size: int = 1,
                  verbose: bool = False):
    import tensorrt as trt
    min_shape = input_config['min_shape']
    max_shape = input_config['max_shape']
    # create trt engine and wrapper
    opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
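    # get_GiB converts the requested workspace size from GiB to bytes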
    max_workspace_size = get_GiB(workspace_size)
    trt_engine = onnx2trt(
        onnx_file,
        opt_shape_dict,
        log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
        fp16_mode=fp16,
        max_workspace_size=max_workspace_size)
    save_dir, _ = osp.split(trt_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_trt_engine(trt_engine, trt_file)
    print(f'Successfully created TensorRT engine: {trt_file}')

    if verify:
        mm_inputs = _prepare_input_img(input_config['input_path'],
                                       config.data.test.pipeline)

        imgs = mm_inputs.pop('img')
        img_metas = mm_inputs.pop('img_metas')

        if isinstance(imgs, list):
            imgs = imgs[0]

        img_list = [img[None, :] for img in imgs]
        # update img_meta
        img_list, img_metas = _update_input_img(img_list, img_metas)

        # Get results from ONNXRuntime
        if model_type == 'det':
            onnx_model = ONNXRuntimeDetector(onnx_file, config, 0)
        else:
            onnx_model = ONNXRuntimeRecognizer(onnx_file, config, 0)
        onnx_out = onnx_model.simple_test(img_list[0],
                                          img_metas[0],
                                          rescale=True)

        # Get results from TensorRT
        if model_type == 'det':
            trt_model = TensorRTDetector(trt_file, config, 0)
        else:
            trt_model = TensorRTRecognizer(trt_file, config, 0)
        img_list[0] = img_list[0].to(torch.device('cuda:0'))
        trt_out = trt_model.simple_test(img_list[0],
                                        img_metas[0],
                                        rescale=True)

        # compare results
        same_diff = 'same'
        if model_type == 'recog':
            for onnx_result, trt_result in zip(onnx_out, trt_out):
                if onnx_result['text'] != trt_result['text'] or \
                     not np.allclose(
                            np.array(onnx_result['score']),
                            np.array(trt_result['score']),
                            rtol=1e-4,
                            atol=1e-4):
                    same_diff = 'different'
                    break
        else:
            for onnx_result, trt_result in zip(onnx_out[0]['boundary_result'],
                                               trt_out[0]['boundary_result']):
                if not np.allclose(np.array(onnx_result),
                                   np.array(trt_result),
                                   rtol=1e-4,
                                   atol=1e-4):
                    same_diff = 'different'
                    break
        print(f'The outputs are {same_diff} between TensorRT and ONNXRuntime')

        if show:
            onnx_img = onnx_model.show_result(input_config['input_path'],
                                              onnx_out[0],
                                              out_file='onnx.jpg',
                                              show=False)
            trt_img = trt_model.show_result(input_config['input_path'],
                                            trt_out[0],
                                            out_file='tensorrt.jpg',
                                            show=False)
            if onnx_img is None:
                onnx_img = cv2.imread(input_config['input_path'])
            if trt_img is None:
                trt_img = cv2.imread(input_config['input_path'])

            cv2.imshow('TensorRT', trt_img)
            cv2.imshow('ONNXRuntime', onnx_img)
            cv2.waitKey()
    return
Example 9
def test_batched_nms():
    try:
        import mmcv
        from mmcv.ops import batched_nms
    except (ImportError, ModuleNotFoundError):
        pytest.skip('test requires compilation')

    # trt config
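    # the MMCVTensorRT backend flag switches mmcv's NMS symbolic to a
    # TensorRT-compatible export path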
    os.environ['ONNX_BACKEND'] = 'MMCVTensorRT'
    fp16_mode = False
    max_workspace_size = 1 << 30
    data = mmcv.load('./tests/data/batched_nms_data.pkl')
    nms_cfg = dict(type='nms', iou_threshold=0.7, score_threshold=0.1)
    boxes = torch.from_numpy(data['boxes']).cuda()
    scores = torch.from_numpy(data['scores']).cuda()
    idxs = torch.from_numpy(data['idxs']).cuda()
    class_agnostic = False

    nms = partial(batched_nms, nms_cfg=nms_cfg, class_agnostic=class_agnostic)
    wrapped_model = WrapFunction(nms)
    wrapped_model.cpu().eval()
    input_data = (boxes.detach().cpu(), scores.detach().cpu(),
                  idxs.detach().cpu())
    input_names = ['boxes', 'scores', 'idxs']
    output_names = ['dets', 'inds']
    with torch.no_grad():
        torch.onnx.export(wrapped_model,
                          input_data,
                          onnx_file,
                          export_params=True,
                          keep_initializers_as_inputs=True,
                          input_names=input_names,
                          output_names=output_names,
                          opset_version=11)
    onnx_model = onnx.load(onnx_file)
    # create trt engine and wrapper
    opt_shape_dict = {
        'boxes': [list(boxes.shape),
                  list(boxes.shape),
                  list(boxes.shape)],
        'scores': [list(scores.shape),
                   list(scores.shape),
                   list(scores.shape)],
        'idxs': [list(idxs.shape),
                 list(idxs.shape),
                 list(idxs.shape)]
    }
    trt_engine = onnx2trt(onnx_model,
                          opt_shape_dict,
                          fp16_mode=fp16_mode,
                          max_workspace_size=max_workspace_size)
    save_trt_engine(trt_engine, trt_file)
    trt_model = TRTWrapper(trt_file, input_names, output_names)

    with torch.no_grad():
        trt_outputs = trt_model({
            'boxes': boxes,
            'scores': scores,
            'idxs': idxs
        })
        trt_dets = trt_outputs['dets']
        trt_inds = trt_outputs['inds']
        trt_inds = trt_inds.long()

    # compute pytorch_output
    with torch.no_grad():
        pytorch_outputs = wrapped_model(boxes, scores, idxs)
        pytorch_dets, pytorch_inds = pytorch_outputs
    # allclose
    if os.path.exists(onnx_file):
        os.remove(onnx_file)
    if os.path.exists(trt_file):
        os.remove(trt_file)
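    # TensorRT NMS returns fixed-size, padded outputs; keep only the valid detections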
    num_boxes = pytorch_dets.shape[0]
    trt_dets = trt_dets[:num_boxes, ...]
    trt_inds = trt_inds[:num_boxes]
    trt_scores = trt_dets[:, 4]
    pytorch_scores = pytorch_dets[:, 4]

    os.environ.pop('ONNX_BACKEND')
    assert torch.allclose(pytorch_scores, trt_scores)
    assert torch.equal(pytorch_inds, trt_inds)
Example 10
def onnx2tensorrt(onnx_file: str,
                  trt_file: str,
                  config: dict,
                  input_config: dict,
                  model_type: str,
                  img_path: str,
                  fp16: bool = False,
                  verify: bool = False,
                  show: bool = False,
                  workspace_size: int = 1,
                  verbose: bool = False):
    """Convert ONNX model to TensorRT model

    Args:
        onnx_file (str): the path of the input ONNX file.
        trt_file (str): the path to output the TensorRT file.
        config (dict): MMCV configuration.
        input_config (dict): contains min_shape, max_shape and \
            input image path.
        fp16 (bool): whether to enable fp16 mode.
        verify (bool): whether to verify the ouputs of TensorRT \
            and ONNX are same.
        show (bool): whether to show the outputs of TensorRT and ONNX.
        verbose (bool): whether to print the log when generating \
            TensorRT model.
    """
    import tensorrt as trt
    min_shape = input_config['min_shape']
    max_shape = input_config['max_shape']
    # create trt engine and wrapper
    opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
    max_workspace_size = get_GiB(workspace_size)
    trt_engine = onnx2trt(
        onnx_file,
        opt_shape_dict,
        log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
        fp16_mode=fp16,
        max_workspace_size=max_workspace_size)
    save_dir, _ = osp.split(trt_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_trt_engine(trt_engine, trt_file)
    print(f'Successfully created TensorRT engine: {trt_file}')

    if verify:
        inputs = _prepare_input_img(model_type=model_type,
                                    img_path=img_path,
                                    config=config)

        imgs = inputs['imgs']
        img_list = [imgs.unsqueeze(0)]

        if max_shape[0] > 1:
            # concatenate flipped images for batch testing
            flip_img_list = [_.flip(-1) for _ in img_list]
            img_list = [
                torch.cat((ori_img, flip_img), 0)
                for ori_img, flip_img in zip(img_list, flip_img_list)
            ]

        # Get results from ONNXRuntime
        ort_custom_op_path = get_onnxruntime_op_path()
        session_options = ort.SessionOptions()
        if osp.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)
        sess = ort.InferenceSession(onnx_file, session_options)
        sess.set_providers(['CPUExecutionProvider'], [{}])  # use cpu mode
        onnx_output = sess.run(['output'],
                               {'input': img_list[0].detach().numpy()})[0][0]

        # Get results from TensorRT
        trt_model = TRTWraper(trt_file, ['input'], ['output'])
        with torch.no_grad():
            trt_outputs = trt_model({'input': img_list[0].contiguous().cuda()})
        trt_output = trt_outputs['output'][0].cpu().detach().numpy()

        if show:
            onnx_visualize = onnx_output.transpose(1, 2, 0)
            onnx_visualize = np.clip(onnx_visualize, 0, 1)[:, :, ::-1]
            trt_visualize = trt_output.transpose(1, 2, 0)
            trt_visualize = np.clip(trt_visualize, 0, 1)[:, :, ::-1]

            cv2.imshow('ONNXRuntime', onnx_visualize)
            cv2.imshow('TensorRT', trt_visualize)
            cv2.waitKey()

        np.testing.assert_allclose(onnx_output,
                                   trt_output,
                                   rtol=1e-03,
                                   atol=1e-05)
        print('TensorRT and ONNXRuntime outputs are all close.')
Example 11
def test_detector_wrapper():
    try:
        import onnxruntime as ort  # noqa: F401
        import tensorrt as trt
        from mmcv.tensorrt import (onnx2trt, save_trt_engine)
    except ImportError:
        pytest.skip('ONNXRuntime or TensorRT is not available.')

    cfg = dict(
        model=dict(
            type='DBNet',
            backbone=dict(
                type='ResNet',
                depth=18,
                num_stages=4,
                out_indices=(0, 1, 2, 3),
                frozen_stages=-1,
                norm_cfg=dict(type='BN', requires_grad=True),
                init_cfg=dict(
                    type='Pretrained', checkpoint='torchvision://resnet18'),
                norm_eval=False,
                style='caffe'),
            neck=dict(
                type='FPNC',
                in_channels=[64, 128, 256, 512],
                lateral_channels=256),
            bbox_head=dict(
                type='DBHead',
                text_repr_type='quad',
                in_channels=256,
                loss=dict(type='DBLoss', alpha=5.0, beta=10.0,
                          bbce_loss=True)),
            train_cfg=None,
            test_cfg=None))

    cfg = mmcv.Config(cfg)

    pytorch_model = build_detector(cfg.model, None, None)

    # prepare data
    inputs = torch.rand(1, 3, 224, 224)
    img_metas = [{
        'img_shape': [1, 3, 224, 224],
        'ori_shape': [1, 3, 224, 224],
        'pad_shape': [1, 3, 224, 224],
        'filename': None,
        'scale_factor': np.array([1, 1, 1, 1])
    }]

    pytorch_model.forward = pytorch_model.forward_dummy
    with tempfile.TemporaryDirectory() as tmpdirname:
        onnx_path = f'{tmpdirname}/tmp.onnx'
        with torch.no_grad():
            torch.onnx.export(
                pytorch_model,
                inputs,
                onnx_path,
                input_names=['input'],
                output_names=['output'],
                export_params=True,
                keep_initializers_as_inputs=False,
                verbose=False,
                opset_version=11)

        # TensorRT part
        def get_GiB(x: int):
            """return x GiB."""
            return x * (1 << 30)

        trt_path = onnx_path.replace('.onnx', '.trt')
        min_shape = [1, 3, 224, 224]
        max_shape = [1, 3, 224, 224]
        # create trt engine and wrapper
        opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
        max_workspace_size = get_GiB(1)
        trt_engine = onnx2trt(
            onnx_path,
            opt_shape_dict,
            log_level=trt.Logger.ERROR,
            fp16_mode=False,
            max_workspace_size=max_workspace_size)
        save_trt_engine(trt_engine, trt_path)
        print(f'Successfully created TensorRT engine: {trt_path}')

        wrap_onnx = ONNXRuntimeDetector(onnx_path, cfg, 0)
        wrap_trt = TensorRTDetector(trt_path, cfg, 0)

    assert isinstance(wrap_onnx, ONNXRuntimeDetector)
    assert isinstance(wrap_trt, TensorRTDetector)

    with torch.no_grad():
        onnx_outputs = wrap_onnx.simple_test(inputs, img_metas, rescale=False)
        trt_outputs = wrap_trt.simple_test(inputs, img_metas, rescale=False)

    assert isinstance(onnx_outputs[0], dict)
    assert isinstance(trt_outputs[0], dict)
    assert 'boundary_result' in onnx_outputs[0]
    assert 'boundary_result' in trt_outputs[0]
Example 12
def test_recognizer_wrapper():
    try:
        import onnxruntime as ort  # noqa: F401
        import tensorrt as trt
        from mmcv.tensorrt import (onnx2trt, save_trt_engine)
    except ImportError:
        pytest.skip('ONNXRuntime or TensorRT is not available.')

    cfg = dict(
        label_convertor=dict(
            type='CTCConvertor',
            dict_type='DICT36',
            with_unknown=False,
            lower=True),
        model=dict(
            type='CRNNNet',
            preprocessor=None,
            backbone=dict(
                type='VeryDeepVgg', leaky_relu=False, input_channels=1),
            encoder=None,
            decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
            loss=dict(type='CTCLoss'),
            label_convertor=dict(
                type='CTCConvertor',
                dict_type='DICT36',
                with_unknown=False,
                lower=True),
            pretrained=None),
        train_cfg=None,
        test_cfg=None)

    cfg = mmcv.Config(cfg)

    pytorch_model = build_detector(cfg.model, None, None)

    # prepare data
    inputs = torch.rand(1, 1, 32, 32)
    img_metas = [{
        'img_shape': [1, 1, 32, 32],
        'ori_shape': [1, 1, 32, 32],
        'pad_shape': [1, 1, 32, 32],
        'filename': None,
        'scale_factor': np.array([1, 1, 1, 1])
    }]

    pytorch_model.forward = partial(
        pytorch_model.forward,
        img_metas=img_metas,
        return_loss=False,
        rescale=True)
    with tempfile.TemporaryDirectory() as tmpdirname:
        onnx_path = f'{tmpdirname}/tmp.onnx'
        with torch.no_grad():
            torch.onnx.export(
                pytorch_model,
                inputs,
                onnx_path,
                input_names=['input'],
                output_names=['output'],
                export_params=True,
                keep_initializers_as_inputs=False,
                verbose=False,
                opset_version=11)

        # TensorRT part
        def get_GiB(x: int):
            """return x GiB."""
            return x * (1 << 30)

        trt_path = onnx_path.replace('.onnx', '.trt')
        min_shape = [1, 1, 32, 32]
        max_shape = [1, 1, 32, 32]
        # create trt engine and wrapper
        opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
        max_workspace_size = get_GiB(1)
        trt_engine = onnx2trt(
            onnx_path,
            opt_shape_dict,
            log_level=trt.Logger.ERROR,
            fp16_mode=False,
            max_workspace_size=max_workspace_size)
        save_trt_engine(trt_engine, trt_path)
        print(f'Successfully created TensorRT engine: {trt_path}')

        wrap_onnx = ONNXRuntimeRecognizer(onnx_path, cfg, 0)
        wrap_trt = TensorRTRecognizer(trt_path, cfg, 0)

    assert isinstance(wrap_onnx, ONNXRuntimeRecognizer)
    assert isinstance(wrap_trt, TensorRTRecognizer)

    with torch.no_grad():
        onnx_outputs = wrap_onnx.simple_test(inputs, img_metas, rescale=False)
        trt_outputs = wrap_trt.simple_test(inputs, img_metas, rescale=False)

    assert isinstance(onnx_outputs[0], dict)
    assert isinstance(trt_outputs[0], dict)
    assert 'text' in onnx_outputs[0]
    assert 'text' in trt_outputs[0]
Example 13
def onnx2tensorrt(onnx_file,
                  trt_file,
                  input_config,
                  verify=False,
                  show=False,
                  dataset='coco',
                  workspace_size=1,
                  verbose=False):
    import tensorrt as trt
    onnx_model = onnx.load(onnx_file)
    input_shape = input_config['input_shape']
    # create trt engine and wrapper
    opt_shape_dict = {'input': [input_shape, input_shape, input_shape]}
    max_workspace_size = get_GiB(workspace_size)
    trt_engine = onnx2trt(
        onnx_model,
        opt_shape_dict,
        log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
        fp16_mode=False,
        max_workspace_size=max_workspace_size)
    save_dir, _ = osp.split(trt_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_trt_engine(trt_engine, trt_file)
    print(f'Successfully created TensorRT engine: {trt_file}')

    if verify:
        one_img, one_meta = preprocess_example_input(input_config)
        input_img_cpu = one_img.detach().cpu().numpy()
        input_img_cuda = one_img.cuda()
        img = one_meta['show_img']

        # Get results from ONNXRuntime
        ort_custom_op_path = get_onnxruntime_op_path()
        session_options = ort.SessionOptions()
        if osp.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)
        sess = ort.InferenceSession(onnx_file, session_options)
        output_names = [_.name for _ in sess.get_outputs()]
        ort_outputs = sess.run(None, {
            'input': input_img_cpu,
        })
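        # a third output tensor means the exported detector also predicts masks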
        with_mask = len(output_names) == 3
        ort_outputs = [_.squeeze(0) for _ in ort_outputs]
        ort_dets, ort_labels = ort_outputs[:2]
        ort_masks = ort_outputs[2] if with_mask else None
        ort_shapes = [_.shape for _ in ort_outputs]
        print(f'ONNX Runtime output names: {output_names}, '
              f'output shapes: {ort_shapes}')

        # Get results from TensorRT
        trt_model = TRTWraper(trt_file, ['input'], output_names)
        with torch.no_grad():
            trt_outputs = trt_model({'input': input_img_cuda})
        trt_outputs = [
            trt_outputs[_].detach().cpu().numpy().squeeze(0)
            for _ in output_names
        ]
        trt_dets, trt_labels = trt_outputs[:2]
        trt_shapes = [_.shape for _ in trt_outputs]
        print(f'TensorRT output names: {output_names}, '
              f'output shapes: {trt_shapes}')
        trt_masks = trt_outputs[2] if with_mask else None

        # Show detection outputs
        if show:
            CLASSES = get_classes(dataset)
            score_thr = 0.35
            imshow_det_bboxes(img.copy(),
                              trt_dets,
                              trt_labels,
                              segms=trt_masks,
                              class_names=CLASSES,
                              score_thr=score_thr,
                              win_name='TensorRT')
            imshow_det_bboxes(img.copy(),
                              ort_dets,
                              ort_labels,
                              segms=ort_masks,
                              class_names=CLASSES,
                              score_thr=score_thr,
                              win_name='ONNXRuntime')
        # Compare results
        np.testing.assert_allclose(ort_dets, trt_dets, rtol=1e-03, atol=1e-05)
        np.testing.assert_allclose(ort_labels, trt_labels)
        if with_mask:
            np.testing.assert_allclose(ort_masks,
                                       trt_masks,
                                       rtol=1e-03,
                                       atol=1e-05)
        print('The numerical values are the same ' +
              'between ONNXRuntime and TensorRT')
Example 14
def onnx2tensorrt(onnx_file,
                  trt_file,
                  input_config,
                  verify=False,
                  show=False,
                  workspace_size=1,
                  verbose=False):
    import tensorrt as trt
    onnx_model = onnx.load(onnx_file)
    max_shape = input_config['max_shape']
    min_shape = input_config['min_shape']
    opt_shape = input_config['opt_shape']
    fp16_mode = False
    # create trt engine and wrapper
    opt_shape_dict = {'input': [min_shape, opt_shape, max_shape]}
    max_workspace_size = get_GiB(workspace_size)
    trt_engine = onnx2trt(
        onnx_model,
        opt_shape_dict,
        log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
        fp16_mode=fp16_mode,
        max_workspace_size=max_workspace_size)
    save_dir, _ = osp.split(trt_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_trt_engine(trt_engine, trt_file)
    print(f'Successfully created TensorRT engine: {trt_file}')

    if verify:
        # prepare input
        one_img, one_meta = preprocess_example_input(input_config)
        img_list, img_meta_list = [one_img], [[one_meta]]
        img_list = [_.cuda().contiguous() for _ in img_list]

        # wrap ONNX and TensorRT model
        onnx_model = ONNXRuntimeDetector(onnx_file, CLASSES, device_id=0)
        trt_model = TensorRTDetector(trt_file, CLASSES, device_id=0)

        # inference with wrapped model
        with torch.no_grad():
            onnx_results = onnx_model(img_list,
                                      img_metas=img_meta_list,
                                      return_loss=False)[0]
            trt_results = trt_model(img_list,
                                    img_metas=img_meta_list,
                                    return_loss=False)[0]

        if show:
            out_file_ort, out_file_trt = None, None
        else:
            out_file_ort, out_file_trt = 'show-ort.png', 'show-trt.png'
        show_img = one_meta['show_img']
        score_thr = 0.3
        onnx_model.show_result(show_img,
                               onnx_results,
                               score_thr=score_thr,
                               show=True,
                               win_name='ONNXRuntime',
                               out_file=out_file_ort)
        trt_model.show_result(show_img,
                              trt_results,
                              score_thr=score_thr,
                              show=True,
                              win_name='TensorRT',
                              out_file=out_file_trt)
        with_mask = trt_model.with_masks
        # compare a part of result
        if with_mask:
            compare_pairs = list(zip(onnx_results, trt_results))
        else:
            compare_pairs = [(onnx_results, trt_results)]
        err_msg = 'The numerical values are different between ONNXRuntime' + \
                  ' and TensorRT, but it does not necessarily mean the' + \
                  ' converted TensorRT engine is problematic.'
        # check the numerical value
        for onnx_res, trt_res in compare_pairs:
            for o_res, t_res in zip(onnx_res, trt_res):
                np.testing.assert_allclose(o_res,
                                           t_res,
                                           rtol=1e-03,
                                           atol=1e-05,
                                           err_msg=err_msg)
        print('The numerical values are the same ' +
              'between ONNXRuntime and TensorRT')
Example 15
def onnx2tensorrt(onnx_file,
                  trt_file,
                  input_config,
                  verify=False,
                  show=False,
                  dataset='coco',
                  workspace_size=1):
    onnx_model = onnx.load(onnx_file)
    input_shape = input_config['input_shape']
    # create trt engine and wrapper
    opt_shape_dict = {'input': [input_shape, input_shape, input_shape]}
    max_workspace_size = get_GiB(workspace_size)
    trt_engine = onnx2trt(onnx_model,
                          opt_shape_dict,
                          fp16_mode=False,
                          max_workspace_size=max_workspace_size)
    save_dir, _ = osp.split(trt_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_trt_engine(trt_engine, trt_file)
    print(f'Successfully created TensorRT engine: {trt_file}')

    if verify:
        one_img, one_meta = preprocess_example_input(input_config)
        input_img_cpu = one_img.detach().cpu().numpy()
        input_img_cuda = one_img.cuda()

        img = one_meta['show_img']

        # Get results from TensorRT
        trt_model = TRTWraper(trt_file, ['input'], ['boxes', 'labels'])
        with torch.no_grad():
            trt_outputs = trt_model({'input': input_img_cuda})
        trt_boxes = trt_outputs['boxes'].detach().cpu().numpy()
        trt_labels = trt_outputs['labels'].detach().cpu().numpy()

        # Get results from ONNXRuntime
        ort_custom_op_path = get_onnxruntime_op_path()
        session_options = ort.SessionOptions()
        if osp.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)
        sess = ort.InferenceSession(onnx_file, session_options)
        onnx_outputs = sess.run(None, {
            'input': input_img_cpu,
        })
        ort_boxes, ort_labels = onnx_outputs

        # Show detection outputs
        if show:
            CLASSES = get_classes(dataset)
            score_thr = 0.35
            imshow_det_bboxes(img.copy(),
                              trt_boxes,
                              trt_labels,
                              CLASSES,
                              score_thr=score_thr,
                              win_name='TensorRT')
            imshow_det_bboxes(img.copy(),
                              ort_boxes,
                              ort_labels,
                              CLASSES,
                              score_thr=score_thr,
                              win_name='ONNXRuntime')
        # Compare results
        np.testing.assert_allclose(ort_boxes,
                                   trt_boxes,
                                   rtol=1e-03,
                                   atol=1e-05)
        np.testing.assert_allclose(ort_labels, trt_labels)
        print('The numerical values are the same ' +
              'between ONNXRuntime and TensorRT')
Example 16
def onnx2tensorrt(onnx_file: str,
                  trt_file: str,
                  config: dict,
                  input_config: dict,
                  fp16: bool = False,
                  verify: bool = False,
                  show: bool = False,
                  dataset: str = 'CityscapesDataset',
                  workspace_size: int = 1,
                  verbose: bool = False):
    import tensorrt as trt
    min_shape = input_config['min_shape']
    max_shape = input_config['max_shape']
    # create trt engine and wrapper
    opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
    max_workspace_size = get_GiB(workspace_size)
    trt_engine = onnx2trt(
        onnx_file,
        opt_shape_dict,
        log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
        fp16_mode=fp16,
        max_workspace_size=max_workspace_size)
    save_dir, _ = osp.split(trt_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_trt_engine(trt_engine, trt_file)
    print(f'Successfully created TensorRT engine: {trt_file}')

    if verify:
        inputs = _prepare_input_img(input_config['input_path'],
                                    config.data.test.pipeline,
                                    shape=min_shape[2:])

        imgs = inputs['imgs']
        img_metas = inputs['img_metas']
        img_list = [img[None, :] for img in imgs]
        img_meta_list = [[img_meta] for img_meta in img_metas]
        # update img_meta
        img_list, img_meta_list = _update_input_img(img_list, img_meta_list)

        if max_shape[0] > 1:
            # concatenate flipped images for batch testing
            flip_img_list = [_.flip(-1) for _ in img_list]
            img_list = [
                torch.cat((ori_img, flip_img), 0)
                for ori_img, flip_img in zip(img_list, flip_img_list)
            ]

        # Get results from ONNXRuntime
        ort_custom_op_path = get_onnxruntime_op_path()
        session_options = ort.SessionOptions()
        if osp.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)
        sess = ort.InferenceSession(onnx_file, session_options)
        sess.set_providers(['CPUExecutionProvider'], [{}])  # use cpu mode
        onnx_output = sess.run(['output'],
                               {'input': img_list[0].detach().numpy()})[0][0]

        # Get results from TensorRT
        trt_model = TRTWraper(trt_file, ['input'], ['output'])
        with torch.no_grad():
            trt_outputs = trt_model({'input': img_list[0].contiguous().cuda()})
        trt_output = trt_outputs['output'][0].cpu().detach().numpy()

        if show:
            dataset = DATASETS.get(dataset)
            assert dataset is not None
            palette = dataset.PALETTE

            show_result_pyplot(input_config['input_path'],
                               (onnx_output[0].astype(np.uint8), ),
                               palette=palette,
                               title='ONNXRuntime',
                               block=False)
            show_result_pyplot(input_config['input_path'],
                               (trt_output[0].astype(np.uint8), ),
                               palette=palette,
                               title='TensorRT')

        np.testing.assert_allclose(onnx_output,
                                   trt_output,
                                   rtol=1e-03,
                                   atol=1e-05)
        print('TensorRT and ONNXRuntime outputs are all close.')
Example 17
def onnx2tensorrt(onnx_file,
                  trt_file,
                  input_shape,
                  max_batch_size,
                  fp16_mode=False,
                  verify=False,
                  workspace_size=1):
    """Create tensorrt engine from onnx model.

    Args:
        onnx_file (str): Filename of the input ONNX model file.
        trt_file (str): Filename of the output TensorRT engine file.
        input_shape (list[int]): Input shape of the model.
            eg [1, 3, 224, 224].
        max_batch_size (int): Max batch size of the model.
        verify (bool, optional): Whether to verify the converted model.
            Defaults to False.
        workspace_size (int, optional): Maximium workspace of GPU.
            Defaults to 1.
    """
    import onnx
    from mmcv.tensorrt import TRTWraper, onnx2trt, save_trt_engine

    onnx_model = onnx.load(onnx_file)
    # create trt engine and wrapper
    assert max_batch_size >= 1
    max_shape = [max_batch_size] + list(input_shape[1:])
    opt_shape_dict = {'input': [input_shape, input_shape, max_shape]}
    max_workspace_size = get_GiB(workspace_size)
    trt_engine = onnx2trt(onnx_model,
                          opt_shape_dict,
                          fp16_mode=fp16_mode,
                          max_workspace_size=max_workspace_size)
    save_dir, _ = osp.split(trt_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_trt_engine(trt_engine, trt_file)
    print(f'Successfully created TensorRT engine: {trt_file}')

    if verify:
        import torch
        import onnxruntime as ort

        input_img = torch.randn(*input_shape)
        input_img_cpu = input_img.detach().cpu().numpy()
        input_img_cuda = input_img.cuda()

        # Get results from ONNXRuntime
        session_options = ort.SessionOptions()
        sess = ort.InferenceSession(onnx_file, session_options)

        # get input and output names
        input_names = [_.name for _ in sess.get_inputs()]
        output_names = [_.name for _ in sess.get_outputs()]

        onnx_outputs = sess.run(None, {
            input_names[0]: input_img_cpu,
        })

        # Get results from TensorRT
        trt_model = TRTWraper(trt_file, input_names, output_names)
        with torch.no_grad():
            trt_outputs = trt_model({input_names[0]: input_img_cuda})
        trt_outputs = [
            trt_outputs[_].detach().cpu().numpy() for _ in output_names
        ]

        # Compare results
        np.testing.assert_allclose(onnx_outputs[0],
                                   trt_outputs[0],
                                   rtol=1e-05,
                                   atol=1e-05)
        print('The numerical values are the same ' +
              'between ONNXRuntime and TensorRT')
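For reference, a call to the helper above might look like the following. The file names, input shape and batch size are illustrative assumptions, not values from the original script:

onnx2tensorrt(
    'model.onnx',                    # hypothetical input ONNX path
    'model.trt',                     # hypothetical output engine path
    input_shape=[1, 3, 224, 224],
    max_batch_size=4,
    fp16_mode=False,
    verify=True,
    workspace_size=1)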
Example 18
def test_deform_conv():
    try:
        from mmcv.ops import DeformConv2dPack
    except (ImportError, ModuleNotFoundError):
        pytest.skip('test requires compilation')

    input = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]
    offset_weight = [[[0.1, 0.4, 0.6, 0.1]], [[0.3, 0.2, 0.1, 0.3]],
                     [[0.5, 0.5, 0.2, 0.8]], [[0.8, 0.3, 0.9, 0.1]],
                     [[0.3, 0.1, 0.2, 0.5]], [[0.3, 0.7, 0.5, 0.3]],
                     [[0.6, 0.2, 0.5, 0.3]], [[0.4, 0.1, 0.8, 0.4]]]
    offset_bias = [0.7, 0.1, 0.8, 0.5, 0.6, 0.5, 0.4, 0.7]
    deform_weight = [[[0.4, 0.2, 0.1, 0.9]]]

    c_in = 1
    c_out = 1
    x = torch.Tensor(input).cuda()
    x.requires_grad = True
    model = DeformConv2dPack(c_in, c_out, 2, stride=1, padding=0)
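    # conv_offset predicts 2 * kh * kw = 8 offset channels for the 2x2 kernel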
    model.conv_offset.weight.data = torch.nn.Parameter(
        torch.Tensor(offset_weight).reshape(8, 1, 2, 2))
    model.conv_offset.bias.data = torch.nn.Parameter(
        torch.Tensor(offset_bias).reshape(8))
    model.weight.data = torch.nn.Parameter(
        torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
    model.cuda().eval()

    input_names = ['input']
    output_names = ['output']

    with torch.no_grad():
        torch.onnx.export(model, (x.clone(), ),
                          onnx_file,
                          export_params=True,
                          keep_initializers_as_inputs=True,
                          input_names=input_names,
                          output_names=output_names,
                          opset_version=11)

    onnx_model = onnx.load(onnx_file)

    # create trt engine and wrapper
    opt_shape_dict = {
        'input': [list(x.shape), list(x.shape),
                  list(x.shape)],
    }
    # trt config
    fp16_mode = False
    max_workspace_size = 1 << 30

    trt_engine = onnx2trt(onnx_model,
                          opt_shape_dict,
                          fp16_mode=fp16_mode,
                          max_workspace_size=max_workspace_size)

    save_trt_engine(trt_engine, trt_file)
    trt_model = TRTWrapper(trt_file, input_names, output_names)

    with torch.no_grad():
        trt_outputs = trt_model({'input': x.clone()})
        trt_results = trt_outputs['output']

    # compute pytorch_output
    with torch.no_grad():
        pytorch_results = model(x.clone())

    # allclose
    if os.path.exists(onnx_file):
        os.remove(onnx_file)
    if os.path.exists(trt_file):
        os.remove(trt_file)
    assert torch.allclose(pytorch_results, trt_results)