def test_custom_classes_override_default(dataset, classes):
    """Check that a ``classes`` argument replaces the dataset's ``CLASSES``.

    The dataset class is looked up by name in ``DATASETS`` and instantiated
    several times with different ``classes`` values. Every non-``None``
    override must show up verbatim in ``CLASSES`` and differ from the
    class-level default. With ``classes=None`` the class-level default must
    be kept, except for ``CustomDataset``, whose construction is expected to
    raise ``AssertionError``.

    Args:
        dataset: Name used to look the dataset class up in ``DATASETS``.
        classes: Indexable sequence of class names used as the override.
    """
    dataset_cls = DATASETS.get(dataset)
    default_classes = dataset_cls.CLASSES

    def build(class_override):
        # Each instance gets its own mocks standing in for ``img_dir`` and
        # ``split``; the pipeline is left empty.
        return dataset_cls(
            pipeline=[],
            img_dir=MagicMock(),
            split=MagicMock(),
            classes=class_override,
            test_mode=True)

    # Override passed through exactly as received (the tuple case).
    overridden = build(classes)
    assert overridden.CLASSES != default_classes
    assert overridden.CLASSES == classes

    # Same class names, but handed over as a list.
    overridden = build(list(classes))
    assert overridden.CLASSES != default_classes
    assert overridden.CLASSES == list(classes)

    # Override containing only the first class name.
    overridden = build([classes[0]])
    assert overridden.CLASSES != default_classes
    assert overridden.CLASSES == [classes[0]]

    # No override at all.
    if dataset_cls is CustomDataset:
        with pytest.raises(AssertionError):
            build(None)
        return

    fallback = build(None)
    assert fallback.CLASSES == default_classes
コード例 #2
0
def onnx2tensorrt(onnx_file: str,
                  trt_file: str,
                  config: dict,
                  input_config: dict,
                  fp16: bool = False,
                  verify: bool = False,
                  show: bool = False,
                  dataset: str = 'CityscapesDataset',
                  workspace_size: int = 1,
                  verbose: bool = False):
    """Build a TensorRT engine from an ONNX file and optionally verify it.

    The engine is created with ``onnx2trt`` and written to ``trt_file``
    (its parent directory is created when needed). When ``verify`` is set,
    the same prepared input is run through ONNXRuntime (CPU provider) and
    through the saved TensorRT engine, and the two outputs are compared with
    ``np.testing.assert_allclose``.

    Args:
        onnx_file: Path of the ONNX model to convert.
        trt_file: Path the TensorRT engine is saved to.
        config: Config object; ``config.data.test.pipeline`` is read by
            attribute access when preparing the verification input.
        input_config: Mapping with the keys ``'min_shape'``, ``'max_shape'``
            and ``'input_path'``.
        fp16: Forwarded to ``onnx2trt`` as ``fp16_mode``.
        verify: Whether to compare ONNXRuntime and TensorRT outputs.
        show: Whether to display both outputs with ``show_result_pyplot``.
            Only consulted when ``verify`` is true.
        dataset: Name looked up in ``DATASETS``; the ``PALETTE`` of the
            resulting class is used for display.
        workspace_size: Converted with ``get_GiB`` and forwarded as
            ``max_workspace_size`` (presumably a size in GiB).
        verbose: Selects ``trt.Logger.VERBOSE`` instead of
            ``trt.Logger.ERROR`` as the TensorRT log level.

    Raises:
        AssertionError: If ``show`` is set and ``dataset`` is not found in
            ``DATASETS``, or if the two outputs differ beyond the tolerance.
    """
    import tensorrt as trt

    min_shape = input_config['min_shape']
    max_shape = input_config['max_shape']

    # Build the engine; the minimum shape is used for both the first and
    # second entries of the shape triple.
    shape_spec = {'input': [min_shape, min_shape, max_shape]}
    workspace_bytes = get_GiB(workspace_size)
    engine = onnx2trt(
        onnx_file,
        shape_spec,
        log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
        fp16_mode=fp16,
        max_workspace_size=workspace_bytes)

    # Make sure the target directory exists before serializing the engine.
    out_dir = osp.dirname(trt_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    save_trt_engine(engine, trt_file)
    print(f'Successfully created TensorRT engine: {trt_file}')

    if not verify:
        return

    img_path = input_config['input_path']
    # NOTE(review): min_shape[2:] is presumably the (H, W) part of the
    # input shape -- confirm against _prepare_input_img.
    prepared = _prepare_input_img(
        img_path, config.data.test.pipeline, shape=min_shape[2:])

    # Give every image a leading dimension and wrap every meta in a list.
    batches = [img[None, :] for img in prepared['imgs']]
    metas = [[meta] for meta in prepared['img_metas']]
    batches, metas = _update_input_img(batches, metas)

    if max_shape[0] > 1:
        # Append a copy flipped along the last axis so the batch holds two
        # samples.
        batches = [torch.cat((img, img.flip(-1)), 0) for img in batches]

    first_batch = batches[0]

    # Reference result from ONNXRuntime, forced onto the CPU provider.
    custom_op_lib = get_onnxruntime_op_path()
    ort_options = ort.SessionOptions()
    if osp.exists(custom_op_lib):
        ort_options.register_custom_ops_library(custom_op_lib)
    ort_session = ort.InferenceSession(onnx_file, ort_options)
    ort_session.set_providers(['CPUExecutionProvider'], [{}])
    ort_feed = {'input': first_batch.detach().numpy()}
    onnx_output = ort_session.run(['output'], ort_feed)[0][0]

    # Result from the engine that was just written to disk.
    trt_model = TRTWraper(trt_file, ['input'], ['output'])
    with torch.no_grad():
        trt_outputs = trt_model({'input': first_batch.contiguous().cuda()})
    trt_output = trt_outputs['output'][0].cpu().detach().numpy()

    if show:
        dataset_cls = DATASETS.get(dataset)
        assert dataset_cls is not None
        palette = dataset_cls.PALETTE

        # The first window is non-blocking so both results end up on screen.
        show_result_pyplot(
            img_path,
            (onnx_output[0].astype(np.uint8), ),
            palette=palette,
            title='ONNXRuntime',
            block=False)
        show_result_pyplot(
            img_path,
            (trt_output[0].astype(np.uint8), ),
            palette=palette,
            title='TensorRT')

    np.testing.assert_allclose(
        onnx_output, trt_output, rtol=1e-03, atol=1e-05)
    print('TensorRT and ONNXRuntime output all close.')
コード例 #3
0

if __name__ == '__main__':

    assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.'
    args = parse_args()

    if not args.input_img:
        args.input_img = osp.join(osp.dirname(__file__), '../demo/demo.png')

    # check arguments
    assert osp.exists(args.config), 'Config {} not found.'.format(args.config)
    assert osp.exists(args.model), \
        'ONNX model {} not found.'.format(args.model)
    assert args.workspace_size >= 0, 'Workspace size less than 0.'
    assert DATASETS.get(args.dataset) is not None, \
        'Dataset {} does not found.'.format(args.dataset)
    for max_value, min_value in zip(args.max_shape, args.min_shape):
        assert max_value >= min_value, \
            'max_shape should be larger than min shape'

    input_config = {
        'min_shape': args.min_shape,
        'max_shape': args.max_shape,
        'input_path': args.input_img
    }

    cfg = mmcv.Config.fromfile(args.config)
    onnx2tensorrt(args.model,
                  args.trt_file,
                  cfg,