def test_custom_classes_override_default(dataset, classes):
    """Check that an explicit ``classes`` argument replaces default CLASSES.

    The dataset registered under ``dataset`` is built several times in test
    mode with mocked ``img_dir``/``split``. The ``classes`` override is given
    as a tuple, as a list and as a single-element list, and finally as
    ``None`` to check the fallback behaviour.

    Args:
        dataset (str): Registry name looked up in ``DATASETS``.
        classes (tuple): Override class names. They must contain at least
            one element and differ from the dataset's default ``CLASSES``.
    """
    dataset_class = DATASETS.get(dataset)
    default_classes = dataset_class.CLASSES

    def build(override):
        # Each construction gets its own fresh mocks for img_dir and split.
        return dataset_class(
            pipeline=[],
            img_dir=MagicMock(),
            split=MagicMock(),
            classes=override,
            test_mode=True)

    # Tuple override: stored exactly as passed.
    built = build(classes)
    assert built.CLASSES != default_classes
    assert built.CLASSES == classes

    # List override: stored as a list.
    built = build(list(classes))
    assert built.CLASSES != default_classes
    assert built.CLASSES == list(classes)

    # Single-class override: a one-element list also replaces the defaults.
    built = build([classes[0]])
    assert built.CLASSES != default_classes
    assert built.CLASSES == [classes[0]]

    # No override (classes=None).
    if dataset_class is not CustomDataset:
        # Any other dataset falls back to its class-level CLASSES.
        built = build(None)
        assert built.CLASSES == default_classes
    else:
        # The base CustomDataset is expected to reject the call.
        with pytest.raises(AssertionError):
            build(None)
def onnx2tensorrt(onnx_file: str,
                  trt_file: str,
                  config: dict,
                  input_config: dict,
                  fp16: bool = False,
                  verify: bool = False,
                  show: bool = False,
                  dataset: str = 'CityscapesDataset',
                  workspace_size: int = 1,
                  verbose: bool = False):
    """Convert an ONNX segmentation model into a serialized TensorRT engine.

    The engine is built from ``onnx_file`` with a dynamic-shape profile for
    the ``'input'`` tensor taken from ``input_config``. It is saved to
    ``trt_file``, creating the parent directory if needed. When ``verify``
    is set, the image at ``input_config['input_path']`` is run through both
    ONNXRuntime (CPU) and the new TensorRT engine (CUDA), and the two
    outputs are compared with ``np.testing.assert_allclose``.

    Args:
        onnx_file: Path of the source ONNX model.
        trt_file: Destination path of the serialized TensorRT engine.
        config: Model config. Only ``config.data.test.pipeline`` is read, to
            preprocess the verification image.
        input_config: Dict with keys ``'min_shape'``, ``'max_shape'`` and
            ``'input_path'``, plus an optional ``'opt_shape'`` key that
            defaults to ``min_shape``. Shapes are indexed as ``shape[0]``
            (batch) and ``shape[2:]`` (spatial size), presumably NCHW --
            confirm against callers.
        fp16: Build the engine in FP16 mode.
        verify: Compare TensorRT output against ONNXRuntime output.
        show: Only used with ``verify``. Also display both results with the
            palette of ``dataset``.
        dataset: ``DATASETS`` registry name that supplies the palette.
        workspace_size: Maximum TensorRT workspace size, converted with
            ``get_GiB`` (presumably a size in GiB).
        verbose: Use TensorRT's VERBOSE log level instead of ERROR.

    Raises:
        AssertionError: If ``show`` is set and ``dataset`` is not registered,
            or if the ONNXRuntime and TensorRT outputs are not close.
    """
    import tensorrt as trt

    min_shape = input_config['min_shape']
    max_shape = input_config['max_shape']
    # onnx2trt takes a [min, opt, max] shape triple per input. The "opt"
    # shape keeps its previous default (min_shape) unless one is given.
    opt_shape = input_config.get('opt_shape', min_shape)

    # Build the TensorRT engine and save it.
    opt_shape_dict = {'input': [min_shape, opt_shape, max_shape]}
    trt_engine = onnx2trt(
        onnx_file,
        opt_shape_dict,
        log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
        fp16_mode=fp16,
        max_workspace_size=get_GiB(workspace_size))
    save_dir = osp.dirname(trt_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_trt_engine(trt_engine, trt_file)
    print(f'Successfully created TensorRT engine: {trt_file}')

    if not verify:
        return

    # Preprocess the verification image with the model's test pipeline.
    inputs = _prepare_input_img(
        input_config['input_path'],
        config.data.test.pipeline,
        shape=min_shape[2:])
    img_list = [img[None, :] for img in inputs['imgs']]
    img_meta_list = [[img_meta] for img_meta in inputs['img_metas']]
    img_list, img_meta_list = _update_input_img(img_list, img_meta_list)

    if max_shape[0] > 1:
        # Exercise the batch dimension by concatenating each image with a
        # copy flipped along its last axis.
        img_list = [torch.cat((img, img.flip(-1)), 0) for img in img_list]

    input_tensor = img_list[0]

    # Reference output from ONNXRuntime on CPU. The providers are given to
    # the constructor: since ORT 1.9, builds with several execution
    # providers raise ValueError if none are set explicitly.
    session_options = ort.SessionOptions()
    ort_custom_op_path = get_onnxruntime_op_path()
    if osp.exists(ort_custom_op_path):
        session_options.register_custom_ops_library(ort_custom_op_path)
    sess = ort.InferenceSession(
        onnx_file, session_options, providers=['CPUExecutionProvider'])
    onnx_output = sess.run(['output'],
                           {'input': input_tensor.detach().numpy()})[0][0]

    # Output from the freshly built TensorRT engine on the GPU.
    trt_model = TRTWraper(trt_file, ['input'], ['output'])
    with torch.no_grad():
        trt_outputs = trt_model({'input': input_tensor.contiguous().cuda()})
    trt_output = trt_outputs['output'][0].cpu().detach().numpy()

    if show:
        # Look up the dataset class without shadowing the ``dataset`` name.
        dataset_cls = DATASETS.get(dataset)
        assert dataset_cls is not None, \
            f'Dataset {dataset} is not registered in DATASETS.'
        palette = dataset_cls.PALETTE
        show_result_pyplot(
            input_config['input_path'],
            (onnx_output[0].astype(np.uint8), ),
            palette=palette,
            title='ONNXRuntime',
            block=False)
        show_result_pyplot(
            input_config['input_path'],
            (trt_output[0].astype(np.uint8), ),
            palette=palette,
            title='TensorRT')

    np.testing.assert_allclose(
        onnx_output, trt_output, rtol=1e-03, atol=1e-05)
    print('TensorRT and ONNXRuntime output all close.')
if __name__ == '__main__': assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.' args = parse_args() if not args.input_img: args.input_img = osp.join(osp.dirname(__file__), '../demo/demo.png') # check arguments assert osp.exists(args.config), 'Config {} not found.'.format(args.config) assert osp.exists(args.model), \ 'ONNX model {} not found.'.format(args.model) assert args.workspace_size >= 0, 'Workspace size less than 0.' assert DATASETS.get(args.dataset) is not None, \ 'Dataset {} does not found.'.format(args.dataset) for max_value, min_value in zip(args.max_shape, args.min_shape): assert max_value >= min_value, \ 'max_shape should be larger than min shape' input_config = { 'min_shape': args.min_shape, 'max_shape': args.max_shape, 'input_path': args.input_img } cfg = mmcv.Config.fromfile(args.config) onnx2tensorrt(args.model, args.trt_file, cfg,