Esempio n. 1
0
def get_engine(
    onnx_file_path,
    engine_file_path,
    convert_mode,
    dynamic_shapes=False,
    max_batch_size=1,
    calibrator=None,
):
    """Return a TensorRT engine, reusing a serialized copy on disk when present.

    If ``engine_file_path`` exists, its bytes are deserialized with a
    ``trt.Runtime`` and the result is returned. Otherwise the engine is built
    from ``onnx_file_path`` by ``build_engine`` with the remaining arguments
    forwarded unchanged, and that result is returned.
    """
    if not os.path.exists(engine_file_path):
        # No cached engine yet: build one from the ONNX model.
        return build_engine(
            onnx_file_path,
            engine_file_path,
            convert_mode,
            dynamic_shapes,
            max_batch_size,
            calibrator,
        )

    # A serialized engine is available, so skip the expensive build step.
    console.print(f"Reading engine from file {engine_file_path}", style='info')
    with open(engine_file_path, "rb") as engine_fh, \
            trt.Runtime(TRT_LOGGER) as runtime:
        serialized = engine_fh.read()
        return runtime.deserialize_cuda_engine(serialized)
Esempio n. 2
0
def main_convert(opts):
    """Convert the ONNX model described by ``opts`` into a TensorRT engine.

    ``opts`` is the option dict; it is printed to the console as JSON first.
    When ``opts['quant_mode']`` is ``'int8'``, a calibrator of type
    ``opts['calibrator_type']`` is built from the images listed in
    ``opts['file_list']``, using the preprocessing options stored in ``opts``.
    For any other mode no calibrator is used. The engine is built with
    ``opts['bchw_shape'][0]`` as the maximum batch size.
    """
    assert os.path.exists(opts['onnx_file']), 'ONNX File Not Found!.'
    console.print(json.dumps(opts, indent=4, ensure_ascii=False))

    calibrator = None
    if opts['quant_mode'] == 'int8':
        calibrator_type = opts['calibrator_type']
        # The preprocessing options are forwarded under their own names.
        preprocess_keys = (
            'file_list',
            'bchw_shape',
            'mean_value',
            'std_value',
            'resize_type',
            'border_value',
            'channel_order',
            'channel_first',
        )
        batch_source = build_batches_generator(
            **{key: opts[key] for key in preprocess_keys})
        calibrator = build_calibrator(calibrator_type, batch_source,
                                      opts['bchw_shape'])

    onnx_path = opts['onnx_file']
    engine_path = opts['engine_file']
    mode = opts['quant_mode']
    dynamic = opts['shape_dynamic']
    batch_size = opts['bchw_shape'][0]
    build_engine(onnx_path, engine_path, mode, dynamic, batch_size, calibrator)
Esempio n. 3
0
def download_from_json(jsfile: str, branch: str):
    """Download every model listed in a commented JSON config for one branch.

    Lines that start (after optional whitespace) with ``//`` and blank lines
    are dropped. Each remaining line is passed through ``XSTR(line).rmCmt()``
    (presumably strips inline comments — TODO confirm) before the text is
    parsed as JSON. The parsed config is echoed to the console.

    For each top-level entry, ``custom_params[branch]`` is the download URL
    and ``custom_params['model_path']`` is the destination, relative to the
    config file's directory. ``custom_params['model_encrypted']`` is negated
    and passed as the third argument of ``download_from_url``. Entries
    lacking ``branch`` are reported and skipped. Any error for one entry is
    printed and does not stop the remaining entries.

    Args:
        jsfile: path to the JSON config file (UTF-8).
        branch: key inside ``custom_params`` selecting the URL to fetch.

    Raises:
        AssertionError: if ``jsfile`` does not exist or ``branch`` is empty.
    """
    import re
    jsfile = Path(jsfile)
    assert jsfile.exists() and branch

    # Read everything up front so the config file is closed before the
    # (potentially long) downloads start. JSON text is UTF-8 by spec.
    with jsfile.open('r', encoding='utf-8') as jf:
        raw_lines = jf.readlines()

    json_str = ''
    for line in raw_lines:
        if not re.match(r'\s*//', line) and not re.match(r'\s*\n', line):
            json_str += XSTR(line).rmCmt()
    cfg = json.loads(json_str)
    console.print(json.dumps(cfg, indent=4, ensure_ascii=False))

    for name, entry in cfg.items():
        try:
            params = entry['custom_params']
            if branch not in params:
                console.print('{} not in {}'.format(branch, name))
                continue
            url = params[branch]
            dst = jsfile.parent / params['model_path']

            # exist_ok avoids the exists()/makedirs() race of check-then-create.
            dst.parent.mkdir(parents=True, exist_ok=True)

            download_from_url(url, dst, not bool(params['model_encrypted']))
        except Exception:
            # Best effort: report and continue with the next entry.
            console.print_exception()
Esempio n. 4
0
def build_batches_generator(
        file_list,
        bchw_shape,
        mean_value,
        std_value,
        resize_type='centerpad_resize',
        border_value=(0, 0, 0),
        channel_order="bgr",
        channel_first=True,
):
    '''Yield preprocessed calibration batches built from a list of image paths.

    Args:
        file_list: path of a text file with one image path per line; blank
            lines are skipped.
        bchw_shape: target shape; ``bchw_shape[0]`` is the batch size, and the
            whole shape is passed to the resize function.
        mean_value: per-channel mean subtracted after scaling pixels to [0, 1],
            in BGR order (e.g. [0.5, 0.5, 0.5]).
        std_value: per-channel divisor applied after mean subtraction, in BGR
            order (e.g. [0.5, 0.5, 0.5]).
        resize_type: name of a resize function resolvable in this module,
            called as ``fn(image, bchw_shape, border_value)``.
        border_value: padding value forwarded to the resize function.
        channel_order: 'bgr' or 'rgb' (case-insensitive). For 'rgb' the last
            axis is reversed after normalization.
        channel_first: if true, each image is transposed with
            ``transpose(2, 0, 1)`` (HWC -> CHW).

    Yields:
        float32 C-contiguous arrays stacking ``bchw_shape[0]`` images along a
        new leading axis. A trailing incomplete batch is dropped. Images that
        fail to load or preprocess are reported to the console and skipped.
    '''
    channel_order = channel_order.lower()
    assert channel_order in ['bgr', 'rgb']
    batch_size = bchw_shape[0]
    with open(file_list, 'r') as fin:
        paths = [line.strip() for line in fin]

    batch = []
    for path in paths:
        if not path:
            continue  # ignore blank lines in the list file
        try:
            image = cv2.imread(path)
            if image is None:
                # cv2.imread signals failure by returning None, not raising.
                raise IOError(f'failed to read image: {path}')
            # NOTE(review): eval() on a config-supplied string executes
            # arbitrary code; consider a whitelist mapping of resize functions.
            image = eval(resize_type)(image, bchw_shape, border_value)
            image = image.astype(np.float32) / 255.

            # subtract mean, divide by std (both given in BGR order)
            image -= list(mean_value)
            image /= list(std_value)

            if channel_order == 'rgb':
                image = image[..., ::-1]

            if channel_first:
                image = image.transpose(2, 0, 1)
        except Exception as e:
            console.print(e, style='danger')
            continue

        batch.append(image)
        if len(batch) == batch_size:
            # Emit the whole batch (not just the last image) and start a new one.
            yield np.ascontiguousarray(np.stack(batch), dtype=np.float32)
            batch = []
Esempio n. 5
0
def _download(url: str, dst: Path):
    """Download ``url`` into ``dst``, resuming from a partial file if present.

    The remote size is probed first (Content-Length; -1 when absent). If
    ``dst`` already holds exactly that many bytes, nothing is fetched.
    Otherwise an open-ended HTTP Range request starting at the current size
    of ``dst`` is issued and the body is appended to it. If the server
    ignores the Range header and sends the whole file (status != 206), the
    file is rewritten from scratch instead of being appended to.

    Args:
        url: source URL.
        dst: destination file path (must not be a directory).

    Returns:
        True on success; False if any exception occurred, including HTTP
        error statuses. The traceback is printed to the console.
    """
    try:
        assert not dst.is_dir()
        # Close the probe connection immediately; it is only used for headers.
        with urlopen(url) as probe:
            file_size = int(probe.info().get('Content-Length', -1))
        first_byte = os.path.getsize(str(dst)) if dst.exists() else 0
        if file_size >= 0 and first_byte == file_size:
            # Already complete. A Range past EOF would get a 416 whose error
            # body would otherwise be appended to the finished file.
            return True

        # Open-ended range: valid even when Content-Length is unknown.
        headers = {"Range": "bytes={}-".format(first_byte)}
        chunk_size = 1024
        with requests.get(url, headers=headers, stream=True) as req:
            # Don't write an HTTP error page into the destination file.
            req.raise_for_status()
            mode = 'ab'
            if first_byte and req.status_code != 206:
                # Server ignored Range and is sending the full body: restart.
                mode = 'wb'
                first_byte = 0
            with Progress() as progress:
                console.print('[green]GET [yellow]{}'.format(url))
                task = progress.add_task("[bold blue][Downloading...]",
                                         total=file_size)
                progress.update(task, advance=first_byte)
                with dst.open(mode) as fw:
                    for chunk in req.iter_content(chunk_size):
                        if chunk:
                            fw.write(chunk)
                            # Advance by actual bytes; the last chunk is short.
                            progress.update(task, advance=len(chunk))
    except Exception:
        console.print_exception()
        return False
    return True
Esempio n. 6
0
def upload(src_path, encrypt=False, scene='', tag=''):
    """Upload a file, or every file under a directory, to the file server.

    The endpoint is ``'{fs_ip}:{fs_port}/{fs_group}/upload'`` built from
    ``CONFIG``. For a directory, every regular file found recursively is
    uploaded with the same ``encrypt``/``scene``/``tag`` settings.

    For a single file, the content is POSTed as multipart field ``file``. When
    ``encrypt`` is true, it is first passed through
    ``do_encrypt(src_path, CONFIG['ed_secret'])``. The JSON reply is then
    annotated with ``tag``, the file name and the current time, stored via
    ``insert_one`` (a deep copy), and printed.

    Raises:
        AssertionError: if the server's JSON reply is empty.
    """
    src_path = Path(src_path)
    url = '{}:{}/{}/upload'.format(CONFIG['fs_ip'], CONFIG['fs_port'],
                                   CONFIG['fs_group'])
    data = {'output': 'json', 'path': '', 'scene': scene}
    if src_path.is_dir():
        # rglob already walks sub-directories, so upload only regular files;
        # recursing into directories here would upload their files twice.
        for path in src_path.rglob('*'):
            if path.is_file():
                upload(path, encrypt=encrypt, scene=scene, tag=tag)
        return

    # Close the upload file handle as soon as the request completes.
    if encrypt:
        # do_encrypt yields something open() accepts (presumably a temporary
        # encrypted file path — TODO confirm).
        with do_encrypt(src_path, CONFIG['ed_secret']) as ef, \
                open(ef, 'rb') as fh:
            r = requests.post(url=url, data=data, files={'file': fh})
    else:
        with open(src_path, 'rb') as fh:
            r = requests.post(url=url, data=data, files={'file': fh})
    r = r.json()
    assert r
    r.update({'tag': tag, 'name': src_path.name, 'utime': time.time()})
    insert_one(copy.deepcopy(r))
    console.print(json.dumps(r, indent=4, ensure_ascii=False))
Esempio n. 7
0
def build_engine(
    onnx_file_path,
    engine_file_path,
    convert_mode,
    dynamic_shapes=False,
    max_batch_size=1,
    calibrator=None,
):
    """Build a TensorRT engine from an ONNX file and serialize it to disk.

    Args:
        onnx_file_path: path of the ONNX model to parse.
        engine_file_path: where the serialized engine is written.
        convert_mode: precision, one of 'fp32', 'fp16', 'int8'
            (case-insensitive).
        dynamic_shapes: if true, the network is created with the
            EXPLICIT_BATCH creation flag.
        max_batch_size: value assigned to ``builder.max_batch_size``.
        calibrator: INT8 calibrator; required when ``convert_mode`` is 'int8'.

    Returns:
        The built engine, which has also been written to ``engine_file_path``.

    Raises:
        AssertionError: unknown mode, missing calibrator for int8, or the
            platform lacks fast support for the requested precision.
        SystemExit: with status 1 when the ONNX file is missing, parsing fails,
            or the engine build returns None.
    """
    import sys

    convert_mode = convert_mode.lower()
    assert convert_mode in ['fp32', 'fp16', 'int8'], \
        'mode should be in ["fp32", "fp16", "int8"]'

    explicit_batch = (
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    ) if dynamic_shapes else 0
    # NOTE(review): for inputs with truly dynamic dimensions TensorRT needs an
    # optimization profile; none is added here — confirm inputs are static.
    # This function uses the legacy Builder attributes (max_workspace_size,
    # int8_mode, build_cuda_engine) from TensorRT releases before 8.x.
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(explicit_batch) as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:

        builder.max_workspace_size = 1 << 31  # 2 GiB of builder scratch space
        builder.max_batch_size = max_batch_size
        builder.strict_type_constraints = True

        if convert_mode == 'int8':
            assert builder.platform_has_fast_int8, 'platform not support int8'
            builder.int8_mode = True
            assert calibrator is not None, 'int8 mode requires a calibrator'
            builder.int8_calibrator = calibrator
        elif convert_mode == 'fp16':
            assert builder.platform_has_fast_fp16, 'platform not support fp16'
            builder.fp16_mode = True

        # Parse model file. Failures exit with a non-zero status so shells and
        # CI see the conversion as failed.
        if not os.path.exists(onnx_file_path):
            console.print(f'ONNX file {onnx_file_path} not found.',
                          style='danger')
            sys.exit(1)

        console.print(f'Loading ONNX file from path {onnx_file_path}...',
                      style='info')
        with open(onnx_file_path, 'rb') as model:
            console.print('Beginning ONNX file parsing', style='info')
            if not parser.parse(model.read()):
                console.print("Parser ONNX model failed.", style='danger')
                # Report every parser error, not just the first.
                for index in range(parser.num_errors):
                    console.print(parser.get_error(index), style='danger')
                sys.exit(1)

        console.print("Completed parsing of ONNX file", style='info')
        console.print(
            f"Building an engine from file {onnx_file_path}; this may take a while...",
            style='info')

        engine = builder.build_cuda_engine(network)
        if engine is None:
            console.print("Creating engine failed.", style='danger')
            sys.exit(1)

        console.print("Completed creating Engine", style='info')
        with open(engine_file_path, "wb") as f:
            f.write(engine.serialize())
        return engine
Esempio n. 8
0
def show():
    """Print the global CONFIG mapping to the console as indented JSON."""
    rendered = json.dumps(CONFIG, ensure_ascii=False, indent=4)
    console.print(rendered)