Esempio n. 1
0
def run_tnn_model_check(proto_path,
                        model_path,
                        input_path,
                        reference_output_path,
                        is_tflite=False,
                        align_batch=False):
    """Run the ``bin/model_check`` tool and report whether the models align.

    The command line is assembled from the given paths (device ``NAIVE``),
    with ``-b`` appended when ``align_batch`` is truthy, and executed through
    ``cmd.run``.  A return code of 0 prints the "align" message; any other
    value prints the "not align" message.

    :param proto_path: path of the TNN proto file (``-p``).
    :param model_path: path of the TNN model file (``-m``).
    :param input_path: path of the input data file (``-i``).
    :param reference_output_path: path of the reference output file (``-f``).
    :param is_tflite: forwarded to the message printers.
    :param align_batch: when truthy, adds the ``-b`` switch.
    :return: None
    """
    cmd.run("pwd")
    checker_binary = parse_path.parse_path("bin/model_check")
    checker.check_file_exist(checker_binary)

    pieces = [
        checker_binary,
        " -e -p  ", proto_path,
        " -m ", model_path,
        " -i ", input_path,
        " -f ", reference_output_path,
        " -d NAIVE",
    ]
    if align_batch:
        pieces.append(" -b ")
    command = "".join(pieces)

    logging.debug(command)
    if cmd.run(command) == 0:
        print_align_message(is_tflite)
        return
    print_not_align_message(None, is_tflite)
Esempio n. 2
0
def convert(tf_path,  output_dir, version, half, align=False,
            input_path=None, refer_path=None, debug_mode: bool = False):
    """Convert a TensorFlow Lite model to TNN and optionally align outputs.

    :param tf_path: path of the ``.tflite`` file; must exist.
    :param output_dir: target directory; falls back to the directory of
        ``tf_path`` when it is None or not an existing directory.
    :param version: model version; defaults to "v1.0" when None (not used
        further inside this function).
    :param half: passed as the third argument of ``tflite2tnn``.
    :param align: alignment is run only when this equals ``'output'``.
    :param input_path: optional input data file for the alignment.
    :param refer_path: optional reference output file for the alignment.
    :param debug_mode: forwarded to ``align_model.align_model``.
    :return: None.  Exits with ``return_code.CONVERT_FAILED`` when
        ``tflite2tnn`` returns False.
    """
    checker.check_file_exist(tf_path)
    tflite_file_name = os.path.basename(tf_path)
    if output_dir is None or not os.path.isdir(output_dir):
        output_dir = os.path.dirname(tf_path)
    checker.check_file_exist(output_dir)
    # The extension is removed by length, i.e. the last 7 characters go away.
    model_name = tflite_file_name[:-len(".tflite")]

    if tflite2tnn(tf_path, output_dir, half) is False:
        logging.error("Oh No, tflite2tnn failed :(\n")
        sys.exit(return_code.CONVERT_FAILED)
    logging.info("Convert TensorFlowLite to TNN model succeed!\n")

    if version is None:
        version = "v1.0"
    if align != 'output':
        return

    tnn_proto_path = os.path.join(output_dir, model_name + '.tnnproto')
    tnn_model_path = os.path.join(output_dir, model_name + '.tnnmodel')
    align_model.align_model(tf_path, tnn_proto_path, tnn_model_path,
                            input_path, refer_path, None, True,
                            debug_mode=debug_mode)
Esempio n. 3
0
File: data.py Progetto: yfz912/TNN
def gene_random_data(input_info: dict) -> str:
    """Write a random input file for model alignment and return its path.

    The file ``temp_data/input.txt`` (located two directories above this
    module) starts with the number of inputs.  Each input then contributes a
    header line ``<name> <rank> <dim ...> <data_type>`` followed by its
    flattened values, one per line.

    :param input_info: maps each input name to a dict with the keys
        ``'shape'`` (sequence of dimensions) and ``'data_type'``.  Data type
        0 is written as random floats (``%0.6f``) and data type 3 as random
        integers; any other data type only gets its header line.
    :return: path of the generated ``input.txt``.
    """
    current_dir = pathlib.Path(__file__).parent.parent
    data_dir = os.path.join(current_dir, "temp_data")
    command = "mkdir -p " + data_dir

    logging.debug(command)

    cmd.run(command)
    checker.check_file_exist(data_dir)
    data_path = os.path.join(data_dir, "input.txt")
    # The context manager closes the file even when a write below fails.
    with open(data_path, "w") as data_file:
        data_file.write(str(len(input_info)) + '\n')
        for name, info in input_info.items():
            shape = info['shape']
            data_type = info['data_type']
            data_file.write(name + ' ' + str(len(shape)) + ' ' +
                            ' '.join([str(dim) for dim in shape]) + ' ' +
                            str(data_type) + '\n')
            if data_type == 0:
                values = np.random.rand(*shape)
                np.savetxt(data_file, values.reshape(-1), fmt="%0.6f")
            elif data_type == 3:
                # randint's upper bound is exclusive: high=1 could only ever
                # produce zeros, so use high=2 to get a real mix of 0 and 1.
                values = np.random.randint(low=0, high=2, size=shape)
                np.savetxt(data_file, values.reshape(-1), fmt="%i")
            # NOTE(review): other data types get a header but no values --
            # confirm whether that is intended.
    return data_path
Esempio n. 4
0
def align_model(onnx_path: str, tnn_proto_path: str, tnn_model_path: str, input_file_path: str=None,
                refer_path: str = None, input_names: str = None, is_tflite: bool=False ) -> bool:
    """
    Align the original (ONNX or TFLite) model with the converted TNN model.
    Currently supported models: single input with single output; single input
    with multiple outputs.
    :param onnx_path: path of the original model (a tflite file when is_tflite is set)
    :param tnn_proto_path: path of the TNN proto file; must exist
    :param tnn_model_path: path of the TNN model file; must exist
    :param input_file_path: optional input data file; random data is generated when None
    :param refer_path: optional reference output file; the original model is run when None
    :param input_names: optional input description; when given it replaces the
        input info read from the two models
    :param is_tflite: selects the tflite helpers instead of the onnx ones
    :return: True; exits with return_code.ALIGN_FAILED when a supplied
        input/reference path does not exist
    """
    banner = "-" * 10
    logging.info("{}  align model (tflite or ONNX vs TNN),please wait a moment {}\n" .format(banner, banner))

    checker.check_file_exist(tnn_proto_path)
    checker.check_file_exist(tnn_model_path)

    from_tflite = is_tflite == True

    # Collect the input description of both models.
    if input_names is not None:
        onnx_input_info = tnn_input_info = parse_input_names(input_names)
    else:
        tnn_input_info = get_input_shape_from_tnn(tnn_proto_path)
        read_original_inputs = get_input_shape_from_tflite if from_tflite else get_input_shape_from_onnx
        onnx_input_info = read_original_inputs(onnx_path)

    verify_inputs = check_input_lite_info if from_tflite else check_input_info
    verify_inputs(onnx_input_info, tnn_input_info)

    # Input data: generated unless the caller supplied a file.
    if input_file_path is None:
        input_path = data.gene_random_data(onnx_input_info)
    elif os.path.exists(input_file_path):
        input_path = input_file_path
    else:
        logging.error("Invalid input_file_path")
        sys.exit(return_code.ALIGN_FAILED)

    # Reference output: produced by the original model unless supplied.
    if refer_path is None:
        run_original = run_tflite if from_tflite else run_onnx
        reference_output_path = run_original(onnx_path, input_path, onnx_input_info)
    elif os.path.exists(refer_path):
        reference_output_path = refer_path
    else:
        logging.error("Invalid refer_path")
        sys.exit(return_code.ALIGN_FAILED)

    run_tnn_model_check(tnn_proto_path, tnn_model_path, input_path, reference_output_path, is_tflite)

    # Only files created by this function are cleaned up.
    if input_file_path is None and os.path.exists(input_path):
        data.clean_temp_data(os.path.dirname(input_path))
    if refer_path is None and os.path.exists(reference_output_path):
        data.clean_temp_data(reference_output_path)

    return True
Esempio n. 5
0
def align_model(onnx_path: str,
                tnn_proto_path: str,
                tnn_model_path: str,
                input_file_path: str = None,
                refer_path: str = None,
                input_names: str = None) -> bool:
    """
    Align the ONNX model with the converted TNN model.
    Currently supported models: single input with single output; single input
    with multiple outputs.
    :param onnx_path: path of the ONNX model
    :param tnn_proto_path: path of the TNN proto file; must exist
    :param tnn_model_path: path of the TNN model file; must exist
    :param input_file_path: optional input data file; random data is generated when None
    :param refer_path: optional reference output file; the ONNX model is run when None
    :param input_names: optional input description; when given it replaces the
        input info read from the two models
    :return: True; exits with status -1 when a supplied input/reference path
        does not exist
    """
    # Local import: the bare ``exit()`` used before is injected by the
    # ``site`` module for interactive sessions and is not guaranteed to
    # exist in every runtime; ``sys.exit`` always is.
    import sys

    checker.check_file_exist(tnn_proto_path)
    checker.check_file_exist(tnn_model_path)

    # check input
    if input_names is not None:
        input_info = parse_input_names(input_names)
        tnn_input_info = input_info
        onnx_input_info = input_info
    else:
        tnn_input_info = get_input_shape_from_tnn(tnn_proto_path)
        onnx_input_info = get_input_shape_from_onnx(onnx_path)
    check_input_info(onnx_input_info, tnn_input_info)

    if input_file_path is None:
        # generate data
        input_path = data.gene_random_data(onnx_input_info)
    elif os.path.exists(input_file_path):
        input_path = input_file_path
    else:
        print("invalid input_file_path")
        sys.exit(-1)

    if refer_path is None:
        reference_output_path = run_onnx(onnx_path, input_path,
                                         onnx_input_info)
    elif os.path.exists(refer_path):
        reference_output_path = refer_path
    else:
        print("invalid refer_path")
        sys.exit(-1)

    run_tnn_model_check(tnn_proto_path, tnn_model_path, input_path,
                        reference_output_path)

    # Only remove files that were generated by this function.
    if input_file_path is None and os.path.exists(input_path):
        data.clean_temp_data(os.path.dirname(input_path))
    if refer_path is None and os.path.exists(reference_output_path):
        data.clean_temp_data(reference_output_path)

    return True
Esempio n. 6
0
def run_tnn_model_check(proto_path, model_path, input_path,
                        reference_output_path):
    """Run ``bin/model_check`` on the given files with the NAIVE device.

    The assembled command is printed and then executed through ``cmd.run``;
    its return code is not inspected.

    :param proto_path: path of the TNN proto file (``-p``).
    :param model_path: path of the TNN model file (``-m``).
    :param input_path: path of the input data file (``-i``).
    :param reference_output_path: path of the reference output file (``-f``).
    :return: None
    """
    cmd.run("pwd")
    checker_binary = parse_path.parse_path("bin/model_check")
    checker.check_file_exist(checker_binary)
    command = "".join([
        checker_binary,
        " -p  ", proto_path,
        " -m ", model_path,
        " -i ", input_path,
        " -f ", reference_output_path,
        " -d NAIVE",
    ])

    print(command)
    cmd.run(command)
Esempio n. 7
0
def gene_random_data(input_info: dict) -> str:
    """Write random float input data to ``temp_data/input.txt``.

    The ``temp_data`` directory is created two directories above this
    module.  For every entry of ``input_info`` a random array of the given
    shape is generated and appended to the file, flattened, one value per
    line with 18 decimal places.

    :param input_info: maps each input name to its shape (sequence of
        dimensions unpacked into ``np.random.rand``).
    :return: path of the generated ``input.txt``.
    """
    current_dir = pathlib.Path(__file__).parent.parent
    data_dir = os.path.join(current_dir, "temp_data")
    command = "mkdir -p " + data_dir
    print(command)
    cmd.run("pwd")
    cmd.run(command)
    checker.check_file_exist(data_dir)
    data_path = os.path.join(data_dir, "input.txt")
    # The context manager closes the file even when generating or writing
    # one of the arrays fails.
    with open(data_path, "w") as data_file:
        for shape in input_info.values():
            values = np.random.rand(*shape)
            np.savetxt(data_file, values.reshape(-1), fmt="%0.18f")
    return data_path
Esempio n. 8
0
def run_tnn_model_check(proto_path, model_path, input_path, reference_output_path):
    """Run ``bin/model_check`` (NAIVE device) and print the alignment verdict.

    The command is logged at debug level and executed through ``cmd.run``.
    A return code of 0 prints the "align" message, anything else the
    "not align" message.

    :param proto_path: path of the TNN proto file (``-p``).
    :param model_path: path of the TNN model file (``-m``).
    :param input_path: path of the input data file (``-i``).
    :param reference_output_path: path of the reference output file (``-f``).
    :return: None
    """
    cmd.run("pwd")
    checker_binary = parse_path.parse_path("bin/model_check")
    checker.check_file_exist(checker_binary)
    command = "".join((
        checker_binary,
        " -p  ", proto_path,
        " -m ", model_path,
        " -i ", input_path,
        " -f ", reference_output_path,
        " -d NAIVE",
    ))

    logging.debug(command)
    if cmd.run(command) == 0:
        print_align_message()
        return
    print_not_align_message()
Esempio n. 9
0
def convert(tf_path, input_names, output_names, output_dir, version, optimize, half, align=False, not_fold_const=False,
            input_path=None, refer_path=None):
    """Convert a TensorFlow ``.pb`` model to TNN by way of ONNX.

    :param tf_path: path of the ``.pb`` file; must exist.
    :param input_names: forwarded to ``tf2onnx``.
    :param output_names: forwarded to ``tf2onnx``.
    :param output_dir: target directory; falls back to the directory of
        ``tf_path`` when it is None or not an existing directory.
    :param version: model version, "v1.0" when None.
    :param optimize: forwarded to ``onnx2tnn.convert``; when it ``is True``
        the aligned files carry the ``.opt`` infix.
    :param half: forwarded to ``onnx2tnn.convert``.
    :param align: alignment is run only when this ``is True``.
    :param not_fold_const: forwarded to ``tf2onnx``.
    :param input_path: optional input data file for the alignment.
    :param refer_path: optional reference output file for the alignment.
    :return: None.  Exits with ``return_code.CONVERT_FAILED`` when
        ``tf2onnx`` returns False.
    """
    logging.info("Converter Tensorflow to TNN model\n")
    checker.check_file_exist(tf_path)
    pb_file_name = os.path.basename(tf_path)
    if output_dir is None or not os.path.isdir(output_dir):
        output_dir = os.path.dirname(tf_path)
    checker.check_file_exist(output_dir)
    # The ".pb" extension is removed by length before appending ".onnx".
    onnx_path = os.path.join(output_dir, pb_file_name[:-len(".pb")] + ".onnx")

    if tf2onnx(tf_path, input_names, output_names, onnx_path, not_fold_const) is False:
        logging.error("Oh No, tf2onnx failed :(\n")
        sys.exit(return_code.CONVERT_FAILED)
    logging.info("Convert TensorFlow to ONNX model succeed!\n")

    if version is None:
        version = "v1.0"
    checker.check_file_exist(onnx_path)
    onnx2tnn.convert(onnx_path, output_dir, version, optimize, half)

    if align is not True:
        return

    stem = os.path.basename(onnx_path)[:-len('.onnx')]
    if optimize is True:
        stem = stem + '.opt'
    tnn_proto_path = os.path.join(output_dir, stem + '.tnnproto')
    tnn_model_path = os.path.join(output_dir, stem + '.tnnmodel')
    align_model.align_model(onnx_path, tnn_proto_path, tnn_model_path, input_path, refer_path)
Esempio n. 10
0
def convert(proto_path,
            model_path,
            output_dir,
            version,
            optimize,
            half,
            align=False,
            input_path=None,
            refer_path=None):
    """Convert a Caffe model to TNN by way of ONNX and optionally align it.

    :param proto_path: path of the ``.prototxt`` file; must exist.
    :param model_path: path of the Caffe model file; must exist.
    :param output_dir: target directory; defaults to the directory of
        ``proto_path`` when None.
    :param version: model version, "v1.0" when None.
    :param optimize: forwarded to ``onnx2tnn.convert``; forced to False for
        SSD models.
    :param half: forwarded to ``onnx2tnn.convert``.
    :param align: alignment is run only when this ``is True``; it is disabled
        for SSD models unless both ``input_path`` and ``refer_path`` are given.
    :param input_path: optional input data file for the alignment.
    :param refer_path: optional reference output file for the alignment.
    :return: None.  Exits with ``return_code.CONVERT_FAILED`` when
        ``caffe2onnx`` returns False.
    """
    logging.info("Converter Caffe to ONNX Model\n")
    checker.check_file_exist(proto_path)
    checker.check_file_exist(model_path)
    if output_dir is None:
        output_dir = os.path.dirname(proto_path)
    checker.check_file_exist(output_dir)

    proto_name = os.path.basename(proto_path)
    proto_name = proto_name[:-len(".prototxt")]
    onnx_path = os.path.join(output_dir, proto_name + ".onnx")

    if caffe2onnx(proto_path, model_path, onnx_path) is False:
        logging.error("Oh No, caff2onnx failed :(\n")
        sys.exit(return_code.CONVERT_FAILED)
    else:
        logging.info("Congratulations! caffe2onnx succeed!\n")
    if version is None:
        version = "v1.0"

    is_ssd = checker.is_ssd_model(proto_path)
    if is_ssd:
        # SSD models are always converted with optimization switched off.
        # Record that here so the alignment step below does not build the
        # '.opt' file names, which belong to optimized conversions only.
        optimize = False
        onnx2tnn.convert(onnx_path,
                         output_dir,
                         version,
                         False,
                         half,
                         is_ssd=True)
    else:
        onnx2tnn.convert(onnx_path, output_dir, version, optimize, half)

    # SSD alignment needs user supplied input and reference data.
    if is_ssd and ((input_path is None) or (refer_path is None)):
        align = False

    if align is True:
        proto_suffix = '.tnnproto'
        model_suffix = '.tnnmodel'
        base_name = os.path.basename(onnx_path)[:-len('.onnx')]
        if optimize is True:
            base_name = base_name + '.opt'
        tnn_proto_path = os.path.join(output_dir, base_name + proto_suffix)
        tnn_model_path = os.path.join(output_dir, base_name + model_suffix)
        align_model.align_model(onnx_path, tnn_proto_path, tnn_model_path,
                                input_path, refer_path)
Esempio n. 11
0
def convert(onnx_path,
            output_dir=None,
            version="v1.0",
            optimize=True,
            half=False):
    """
    Run the command that converts an ONNX model to TNN.
    :parameter:
          onnx_path:    path of the ONNX file to convert; must exist
          output_dir:   directory for the generated TNN files; defaults to
                        the directory of onnx_path
          version:      version number of the converted model ("v1.0" when None)
          optimize:     whether the model should be optimized (default: yes)
          half:         whether to convert to an FP16 model to reduce its size
    :return None; success or failure is only printed
    """
    # Honour the requested version instead of always sending "v1.0".
    if version is None:
        version = "v1.0"
    command = "python3 onnx2tnn.py " + onnx_path
    command = command + " -version=" + version
    checker.check_file_exist(onnx_path)
    command = command + (" -optimize=1" if optimize is True else " -optimize=0")
    command = command + (" -half=1" if half is True else " -half=0")

    if output_dir is None:
        output_dir = os.path.dirname(onnx_path)
    checker.check_file_exist(output_dir)
    command = command + " -o " + output_dir
    print("the onnx2tnn command:" + command)
    # The converter script is run from its own directory.
    work_dir = "../onnx2tnn/onnx-converter/"
    result = cmd.run(command, work_dir=work_dir)
    if result == 0:
        print("onnx2tnn succeed!")
    else:
        print("onnx2tnn failed!")
def tflite2tnn(tf_path, tnn_path, not_fold_const=False):
    """Run ``bin/TnnConverter`` on a TFLite model.

    :param tf_path: path of the ``.tflite`` file; must exist.
    :param tnn_path: output directory; defaults to the directory of
        ``tf_path`` when None.
    :param not_fold_const: accepted but not used by this function.
    :return: True when the converter command returns 0, otherwise False.
    """
    cmd.run("pwd")
    converter_binary = parse_path.parse_path("bin/TnnConverter")
    checker.check_file_exist(converter_binary)
    command = converter_binary + " -mt TFLITE  -mp " + tf_path
    checker.check_file_exist(converter_binary)
    checker.check_file_exist(tf_path)
    if tnn_path is None:
        tnn_path = os.path.dirname(tf_path)
    checker.check_file_exist(tnn_path)
    command += " -od " + tnn_path + "/"
    logging.debug(command)
    exit_status = cmd.run(command)
    # Always hand back a genuine bool, exactly like the explicit
    # True/False branches did.
    return bool(exit_status == 0)
Esempio n. 13
0
def convert(tf_path, input_names, output_names, output_dir, version, optimize, half):
    """Convert a TensorFlow ``.pb`` model to TNN by way of ONNX.

    The outcome of ``tf2onnx`` is only printed; the ONNX to TNN step is
    attempted afterwards in both cases.

    :param tf_path: path of the ``.pb`` file; must exist.
    :param input_names: forwarded to ``tf2onnx``.
    :param output_names: forwarded to ``tf2onnx``.
    :param output_dir: target directory; falls back to the directory of
        ``tf_path`` when it is None or not an existing directory.
    :param version: model version, "v1.0" when None.
    :param optimize: forwarded to ``onnx2tnn.convert``.
    :param half: forwarded to ``onnx2tnn.convert``.
    :return: None
    """
    checker.check_file_exist(tf_path)
    pb_file_name = os.path.basename(tf_path)
    if output_dir is None or not os.path.isdir(output_dir):
        output_dir = os.path.dirname(tf_path)
    checker.check_file_exist(output_dir)
    onnx_path = os.path.join(output_dir, pb_file_name[:-len(".pb")] + ".onnx")

    failed = tf2onnx(tf_path, input_names, output_names, onnx_path) is False
    print("Oh No, tf2onnx failed" if failed else "congratulations! tf2onnx succeed!")

    if version is None:
        version = "v1.0"
    checker.check_file_exist(onnx_path)
    onnx2tnn.convert(onnx_path, output_dir, version, optimize, half)
Esempio n. 14
0
def convert(proto_path,
            model_path,
            output_dir,
            version,
            optimize,
            half,
            align=False,
            input_path=None,
            refer_path=None):
    """Convert a Caffe model to TNN by way of ONNX and optionally align it.

    The outcome of ``caffe2onnx`` is only printed; the ONNX to TNN step is
    attempted afterwards in both cases.

    :param proto_path: path of the ``.prototxt`` file; must exist.
    :param model_path: path of the Caffe model file; must exist.
    :param output_dir: target directory; defaults to the directory of
        ``proto_path`` when None.
    :param version: model version, "v1.0" when None.
    :param optimize: forwarded to ``onnx2tnn.convert``; forced to False for
        SSD models.
    :param half: forwarded to ``onnx2tnn.convert``.
    :param align: alignment is run only when this ``is True``.
    :param input_path: optional input data file for the alignment.
    :param refer_path: optional reference output file for the alignment.
    :return: None
    """
    checker.check_file_exist(proto_path)
    checker.check_file_exist(model_path)
    if output_dir is None:
        output_dir = os.path.dirname(proto_path)
    checker.check_file_exist(output_dir)
    proto_name = os.path.basename(proto_path)
    proto_name = proto_name[:-len(".prototxt")]
    onnx_path = os.path.join(output_dir, proto_name + ".onnx")
    if caffe2onnx(proto_path, model_path, onnx_path) is False:
        print("Oh No, caff2onnx failed")
    else:
        print("congratulations! caffe2onnx succeed!")
    if version is None:
        version = "v1.0"

    # SSD models are always converted with optimization switched off.
    # Apply that to ``optimize`` itself so the alignment step below does not
    # build the '.opt' file names, which belong to optimized conversions only.
    if checker.is_ssd_model(proto_path):
        optimize = False
    onnx2tnn.convert(onnx_path, output_dir, version, optimize, half)

    if align is True:
        proto_suffix = '.tnnproto'
        model_suffix = '.tnnmodel'
        base_name = os.path.basename(onnx_path)[:-len('.onnx')]
        if optimize is True:
            base_name = base_name + '.opt'
        tnn_proto_path = os.path.join(output_dir, base_name + proto_suffix)
        tnn_model_path = os.path.join(output_dir, base_name + model_suffix)
        align_model.align_model(onnx_path, tnn_proto_path, tnn_model_path,
                                input_path, refer_path)
Esempio n. 15
0
def convert(tf_path,
            input_names,
            output_names,
            output_dir,
            version,
            optimize,
            half,
            align=False,
            not_fold_const=False,
            input_path=None,
            refer_path=None):
    """Convert a TensorFlow ``.pb`` model to TNN and optionally align it.

    The outcome of ``tf2onnx`` is only printed; the ONNX to TNN step is
    attempted afterwards in both cases.

    :param tf_path: path of the ``.pb`` file; must exist.
    :param input_names: forwarded to ``tf2onnx``.
    :param output_names: forwarded to ``tf2onnx``.
    :param output_dir: target directory; falls back to the directory of
        ``tf_path`` when it is None or not an existing directory.
    :param version: model version, "v1.0" when None.
    :param optimize: forwarded to ``onnx2tnn.convert``; when it ``is True``
        the aligned files carry the ``.opt`` infix.
    :param half: forwarded to ``onnx2tnn.convert``.
    :param align: alignment is run only when this ``is True``.
    :param not_fold_const: forwarded to ``tf2onnx``.
    :param input_path: optional input data file for the alignment.
    :param refer_path: optional reference output file for the alignment.
    :return: None
    """
    checker.check_file_exist(tf_path)
    pb_file_name = os.path.basename(tf_path)
    if output_dir is None or not os.path.isdir(output_dir):
        output_dir = os.path.dirname(tf_path)
    checker.check_file_exist(output_dir)
    onnx_path = os.path.join(output_dir, pb_file_name[:-len(".pb")] + ".onnx")

    failed = tf2onnx(tf_path, input_names, output_names, onnx_path,
                     not_fold_const) is False
    print("Oh No, tf2onnx failed" if failed else "congratulations! tf2onnx succeed!")

    if version is None:
        version = "v1.0"
    checker.check_file_exist(onnx_path)
    onnx2tnn.convert(onnx_path, output_dir, version, optimize, half)

    if align is not True:
        return

    stem = os.path.basename(onnx_path)[:-len('.onnx')]
    if optimize is True:
        stem = stem + '.opt'
    tnn_proto_path = os.path.join(output_dir, stem + '.tnnproto')
    tnn_model_path = os.path.join(output_dir, stem + '.tnnmodel')
    align_model.align_model(onnx_path, tnn_proto_path, tnn_model_path,
                            input_path, refer_path)
Esempio n. 16
0
def convert(proto_path, model_path, output_dir, version, optimize, half):
    """Convert a Caffe model to TNN by way of ONNX.

    The outcome of ``caffe2onnx`` is only printed; the ONNX to TNN step is
    attempted afterwards in both cases.  SSD models are converted with
    optimization switched off.

    :param proto_path: path of the ``.prototxt`` file; must exist.
    :param model_path: path of the Caffe model file; must exist.
    :param output_dir: target directory; defaults to the directory of
        ``proto_path`` when None.
    :param version: model version, "v1.0" when None.
    :param optimize: forwarded to ``onnx2tnn.convert`` for non-SSD models.
    :param half: forwarded to ``onnx2tnn.convert``.
    :return: None
    """
    checker.check_file_exist(proto_path)
    checker.check_file_exist(model_path)
    if output_dir is None:
        output_dir = os.path.dirname(proto_path)
    checker.check_file_exist(output_dir)

    stem = os.path.basename(proto_path)[:-len(".prototxt")]
    onnx_path = os.path.join(output_dir, stem + ".onnx")

    failed = caffe2onnx(proto_path, model_path, onnx_path) is False
    print("Oh No, caff2onnx failed" if failed else "congratulations! caffe2onnx succeed!")

    if version is None:
        version = "v1.0"
    use_optimize = False if checker.is_ssd_model(proto_path) else optimize
    onnx2tnn.convert(onnx_path, output_dir, version, use_optimize, half)
Esempio n. 17
0
def align_model(original_model_path: str, tnn_proto_path: str, tnn_model_path: str, input_file_path: str = None,
                refer_path: str = None, specify_input_args: str = None, is_tflite: bool = False, debug_mode: bool = False, align_batch: bool = False) -> bool:
    """
    Align the original (ONNX or TFLite) model with the converted TNN model.
    Currently supported models: single input with single output; single input
    with multiple outputs.
    :param original_model_path: path of the original ONNX / TFLite model
    :param tnn_proto_path: path of the TNN proto file; must exist
    :param tnn_model_path: path of the TNN model file; must exist
    :param input_file_path: optional input data file; random data is generated when None
    :param refer_path: optional reference output file; the original model is run when None
    :param specify_input_args: optional input specification used to update the
        input info read from the original model
    :param is_tflite: selects the tflite helpers instead of the onnx ones
    :param debug_mode: when not False, generated temp files are kept
    :param align_batch: forwarded to run_tnn_model_check
    :return: True; exits with return_code.ALIGN_FAILED when a supplied
        input/reference path does not exist
    """
    banner = "-" * 10
    logging.info("{}  align model (tflite or ONNX vs TNN),please wait a moment {}\n" .format(banner, banner))

    checker.check_file_exist(tnn_proto_path)
    checker.check_file_exist(tnn_model_path)
    # Layout of the input info dictionaries:
    # list = {  "input name1":{
    #                           {"shape": [n, c,...]},
    #                           {"data_type": 0}
    #                        },
    #           "input name22": {
    #                            {"shape": [n, c,...]},
    #                            {"data_type": 0}
    #                         }
    # Input info of the original model first, then of the TNN model.
    read_original_inputs = get_input_shape_from_tflite if is_tflite else get_input_shape_from_onnx
    original_input_info = read_original_inputs(original_model_path)
    tnn_input_info = get_input_shape_from_tnn(tnn_proto_path)

    # Apply user specified input arguments before comparing the two sides.
    if specify_input_args is not None:
        specify_input_info = parse_specify_input_args(specify_input_args)
        update_original_input_shape(original_input_info, specify_input_info)

    verify_inputs = check_input_lite_info if is_tflite else check_input_info
    verify_inputs(original_input_info, tnn_input_info)

    # Input data: generated unless the caller supplied a file.
    if input_file_path is None:
        input_path = data.gene_random_data(original_input_info)
    elif os.path.exists(input_file_path):
        input_path = input_file_path
    else:
        logging.error("Invalid input_file_path")
        sys.exit(return_code.ALIGN_FAILED)

    # Reference output: produced by the original model unless supplied.
    if refer_path is None:
        run_original = run_tflite if is_tflite == True else run_onnx
        reference_output_path = run_original(original_model_path, input_path, original_input_info)
    elif os.path.exists(refer_path):
        reference_output_path = refer_path
    else:
        logging.error("Invalid refer_path")
        sys.exit(return_code.ALIGN_FAILED)

    logging.info("Run tnn model_check...")
    run_tnn_model_check(tnn_proto_path, tnn_model_path, input_path, reference_output_path, is_tflite, align_batch)

    if debug_mode is not False:
        return True

    # Only files created by this function are cleaned up.
    if input_file_path is None and os.path.exists(input_path):
        data.clean_temp_data(os.path.dirname(input_path))
    if refer_path is None and os.path.exists(reference_output_path):
        data.clean_temp_data(reference_output_path)
    return True
Esempio n. 18
0
def convert(onnx_path,
            output_dir=None,
            version="v1.0",
            optimize=True,
            half=False,
            align=False,
            input_path=None,
            refer_path=None,
            input_names: str = None):
    """
    Run the command that converts an ONNX model to TNN and optionally align it.
    :parameter:
          onnx_path:    path of the ONNX file to convert; must exist
          output_dir:   directory for the generated TNN files; defaults to
                        the directory of onnx_path
          version:      version number of the converted model ("v1.0" when None)
          optimize:     whether the model should be optimized (default: yes)
          half:         whether to convert to an FP16 model to reduce its size
          align:        alignment is run only when this is True
          input_path:   optional input data file for the alignment
          refer_path:   optional reference output file for the alignment
          input_names:  optional input shapes written as "name[dims]"; the
                        brackets are rewritten to "name:dims" for the converter
    :return None; exits with return_code.CONVERT_FAILED when the converter
            command does not return 0
    """
    logging.info("Converter ONNX to TNN Model\n")

    checker.check_file_exist(onnx_path)

    ret, current_shape = checker.check_onnx_dim(onnx_path)

    # When the dimension check fails, a bracketed input shape is required.
    if ret is False and current_shape is not None:
        if input_names is None or not ("[" in input_names
                                       and "]" in input_names):
            throw_exception(current_shape)

    proto_suffix = '.tnnproto'
    model_suffix = '.tnnmodel'
    # Honour the requested version instead of always sending "v1.0".
    if version is None:
        version = "v1.0"
    command = "python3 onnx2tnn.py " + onnx_path
    command = command + " -version=" + version
    checker.check_file_exist(onnx_path)
    command = command + (" -optimize=1" if optimize is True else " -optimize=0")
    command = command + (" -half=1" if half is True else " -half=0")

    if output_dir is None:
        output_dir = os.path.dirname(onnx_path)
    checker.check_file_exist(output_dir)
    command = command + " -o " + output_dir

    if input_names is not None:
        # "name[1,3,224,224]" -> "name:1,3,224,224"
        new_input_names = input_names.replace("[", ":").replace("]", "")
        command = command + " -input_shape " + new_input_names

    # Log only once the command is complete, so -input_shape shows up too.
    logging.debug("The onnx2tnn command:" + command + "\n")

    # The converter script is run from its own directory.
    work_dir = "../onnx2tnn/onnx-converter/"
    result = cmd.run(command, work_dir=work_dir)

    if result == 0:
        logging.info("Converter ONNX to TNN model succeed!\n")
    else:
        logging.error("Converter ONNX to TNN model failed!\n")
        sys.exit(return_code.CONVERT_FAILED)

    if align is True:
        base_name = os.path.basename(onnx_path)[:-len('.onnx')]
        if optimize is True:
            base_name = base_name + '.opt'
        tnn_proto_path = os.path.join(output_dir, base_name + proto_suffix)
        tnn_model_path = os.path.join(output_dir, base_name + model_suffix)

        if input_names is None:
            align_model.align_model(onnx_path, tnn_proto_path, tnn_model_path,
                                    input_path, refer_path)
        else:
            align_model.align_model(onnx_path, tnn_proto_path, tnn_model_path,
                                    input_path, refer_path, new_input_names)
Esempio n. 19
0
def convert(onnx_path,
            output_dir=None,
            version="v1.0",
            optimize=True,
            half=False,
            align='',
            align_batch=False,
            input_path=None,
            refer_path=None,
            input_names: str = None,
            is_ssd=False,
            debug_mode: bool = False):
    """
    Run the command that converts an ONNX model to TNN and optionally align it.
    :parameter:
          onnx_path:    path of the ONNX file to convert; must exist
          output_dir:   directory for the generated TNN files; defaults to
                        the directory of onnx_path
          version:      version number of the converted model
          optimize:     whether the model should be optimized (default: yes)
          half:         whether to convert to an FP16 model to reduce its size
          align:        'output' aligns the model outputs, 'all' calls
                        align_model.align_all
          align_batch:  when True, output alignment is run as well and the
                        flag is forwarded to align_model.align_model
          input_path:   optional input data file for the alignment
          refer_path:   optional reference output file for the alignment
          input_names:  optional input shape specification passed to the
                        converter as -input_shape
          is_ssd:       when truthy, the ONNX dimension check is skipped
          debug_mode:   forwarded to align_model.align_model
    :return None; exits with return_code.CONVERT_FAILED when the converter
            command does not return 0
    """
    logging.info("Converter ONNX to TNN Model...\n")

    checker.check_file_exist(onnx_path)

    # Any failure of the dimension check is reported but does not stop the
    # conversion.
    try:
        if not is_ssd:
            logging.info("Converter ONNX to TNN check_onnx_dim...\n")
            ret, current_shape = checker.check_onnx_dim(onnx_path)
            logging.info("Converter ONNX to TNN check_onnx_dim...\n")
            shape_check_failed = ret is False and current_shape is not None
            if shape_check_failed and input_names is None:
                logging.info("Converter ONNX to TNN current_shape...\n")
                throw_exception(current_shape)
            if input_names is not None:
                input_names = input_names.strip()
                has_no_name = ":" not in input_names and " " not in input_names
                if has_no_name:
                    # A bare shape is attached to the first input.
                    first_input = list(current_shape.keys())[0]
                    input_names = first_input + ":" + input_names
                check_input_names(input_names, current_shape)
    except Exception as e:
        print(e)
        logging.error(
            "check_onnx_dim failed, next stage of convertion may failed too\n")

    command = "python3 onnx2tnn.py " + onnx_path + " -version=" + version
    checker.check_file_exist(onnx_path)
    command += " -optimize=1" if optimize is True else " -optimize=0"
    command += " -half=1" if half is True else " -half=0"

    if output_dir is None:
        output_dir = os.path.dirname(onnx_path)
    checker.check_file_exist(output_dir)
    command += " -o " + output_dir

    if input_names is not None:
        command += " -input_shape " + input_names
    logging.debug("The onnx2tnn command:" + command + "\n")

    # The converter script is run from its own directory.
    if cmd.run(command, work_dir="../onnx2tnn/onnx-converter/") != 0:
        logging.error("Converter ONNX to TNN model failed!\n")
        sys.exit(return_code.CONVERT_FAILED)
    logging.info("Converter ONNX to TNN model succeed!\n")

    # Optimized conversions carry an extra '.opt' infix in their file names.
    stem = os.path.basename(onnx_path)[:-len('.onnx')]
    if optimize is True:
        stem = stem + '.opt'
    tnn_proto_path = os.path.join(output_dir, stem + '.tnnproto')
    tnn_model_path = os.path.join(output_dir, stem + '.tnnmodel')

    if align == 'output' or align_batch is True:
        positional = [onnx_path, tnn_proto_path, tnn_model_path,
                      input_path, refer_path]
        # input_names is only passed along when the caller provided it.
        if input_names is not None:
            positional.append(input_names)
        align_model.align_model(*positional,
                                debug_mode=debug_mode,
                                align_batch=align_batch)
    elif align == 'all':
        align_model.align_all(onnx_path, tnn_proto_path, True,
                              input_names, input_path, refer_path)