예제 #1
0
def convert(tf_path, output_dir, version, half, align=False,
            input_path=None, refer_path=None, debug_mode: bool = False):
    """Convert a TensorFlow Lite model to TNN and optionally align the result.

    The conversion is delegated to ``tflite2tnn``; the process exits with
    ``return_code.CONVERT_FAILED`` when that call returns ``False``.

    :param tf_path: path of the ``.tflite`` model to convert.
    :param output_dir: directory for the generated files; when it is ``None``
        or not an existing directory, the directory of ``tf_path`` is used.
    :param version: replaced by ``"v1.0"`` when ``None``; the value is not
        used any further inside this function.
    :param half: forwarded unchanged to ``tflite2tnn``.
    :param align: alignment runs only when this equals ``'output'``.
    :param input_path: optional input data file forwarded to the alignment.
    :param refer_path: optional reference output file forwarded to the
        alignment.
    :param debug_mode: forwarded to ``align_model.align_model``.
    """
    checker.check_file_exist(tf_path)
    # Drops the last len(".tflite") characters; the extension itself is not
    # verified.
    stem = os.path.basename(tf_path)[:-len(".tflite")]

    usable_dir = output_dir is not None and os.path.isdir(output_dir)
    if not usable_dir:
        output_dir = os.path.dirname(tf_path)
    checker.check_file_exist(output_dir)

    if tflite2tnn(tf_path, output_dir, half) is False:
        logging.error("Oh No, tflite2tnn failed :(\n")
        sys.exit(return_code.CONVERT_FAILED)
    else:
        logging.info("Convert TensorFlowLite to TNN model succeed!\n")

    if version is None:
        version = "v1.0"

    if align == 'output':
        tnn_proto_path = os.path.join(output_dir, stem + '.tnnproto')
        tnn_model_path = os.path.join(output_dir, stem + '.tnnmodel')
        # The sixth positional argument is None and the seventh is True
        # (presumably "no explicit input spec" and "is_tflite" -- confirm
        # against the align_model signature in use).
        align_model.align_model(
            tf_path, tnn_proto_path, tnn_model_path,
            input_path, refer_path, None, True,
            debug_mode=debug_mode)
예제 #2
0
def convert(tf_path, input_names, output_names, output_dir, version, optimize, half, align=False, not_fold_const=False,
            input_path=None, refer_path=None):
    """Convert a TensorFlow ``.pb`` model to TNN by way of ONNX.

    ``tf2onnx`` writes ``<model>.onnx`` into the output directory, then
    ``onnx2tnn.convert`` turns it into TNN files.  The process exits with
    ``return_code.CONVERT_FAILED`` when ``tf2onnx`` returns ``False``.

    :param tf_path: path of the TensorFlow model file.
    :param input_names: forwarded to ``tf2onnx``.
    :param output_names: forwarded to ``tf2onnx``.
    :param output_dir: directory for generated files; when ``None`` or not an
        existing directory, the directory of ``tf_path`` is used.
    :param version: model version, ``"v1.0"`` when ``None``.
    :param optimize: forwarded to ``onnx2tnn.convert``; when it is ``True``
        the TNN files looked up for alignment carry a ``.opt`` infix.
    :param half: forwarded to ``onnx2tnn.convert``.
    :param align: alignment runs only when this is exactly ``True``.
    :param not_fold_const: forwarded to ``tf2onnx``.
    :param input_path: optional input data file forwarded to the alignment.
    :param refer_path: optional reference output file forwarded to the
        alignment.
    """
    logging.info("Converter Tensorflow to TNN model\n")
    checker.check_file_exist(tf_path)
    # Drops the last len(".pb") characters; the extension is not verified.
    stem = os.path.basename(tf_path)[:-len(".pb")]

    if not (output_dir is not None and os.path.isdir(output_dir)):
        output_dir = os.path.dirname(tf_path)
    checker.check_file_exist(output_dir)

    onnx_path = os.path.join(output_dir, stem + ".onnx")
    converted = tf2onnx(tf_path, input_names, output_names, onnx_path, not_fold_const)
    if converted is False:
        logging.error("Oh No, tf2onnx failed :(\n")
        sys.exit(return_code.CONVERT_FAILED)
    else:
        logging.info("Convert TensorFlow to ONNX model succeed!\n")

    if version is None:
        version = "v1.0"
    checker.check_file_exist(onnx_path)
    onnx2tnn.convert(onnx_path, output_dir, version, optimize, half)

    if align is True:
        onnx_stem = os.path.basename(onnx_path)[:-len('.onnx')]
        # Optimized conversions are looked up under "<name>.opt.<suffix>".
        infix = '.opt' if optimize is True else ''
        tnn_proto_path = os.path.join(output_dir, onnx_stem + infix + '.tnnproto')
        tnn_model_path = os.path.join(output_dir, onnx_stem + infix + '.tnnmodel')
        align_model.align_model(onnx_path, tnn_proto_path, tnn_model_path, input_path, refer_path)
예제 #3
0
def convert(proto_path,
            model_path,
            output_dir,
            version,
            optimize,
            half,
            align=False,
            input_path=None,
            refer_path=None):
    """Convert a Caffe model to TNN by way of ONNX and optionally align it.

    ``caffe2onnx`` writes ``<proto name>.onnx`` into the output directory and
    ``onnx2tnn.convert`` turns that file into TNN files.  The process exits
    with ``return_code.CONVERT_FAILED`` when ``caffe2onnx`` returns ``False``.

    Models that ``checker.is_ssd_model`` reports as SSD are always converted
    with optimization disabled, and they are only aligned when both
    ``input_path`` and ``refer_path`` are supplied.

    :param proto_path: path of the Caffe ``.prototxt`` file.
    :param model_path: path of the Caffe weights file.
    :param output_dir: directory for generated files; the directory of
        ``proto_path`` is used when ``None``.
    :param version: model version, ``"v1.0"`` when ``None``.
    :param optimize: whether the ONNX -> TNN step should optimize the model
        (ignored for SSD models).
    :param half: forwarded to ``onnx2tnn.convert``.
    :param align: alignment runs only when this is exactly ``True``.
    :param input_path: optional input data file forwarded to the alignment.
    :param refer_path: optional reference output file forwarded to the
        alignment.
    """
    logging.info("Converter Caffe to ONNX Model\n")
    checker.check_file_exist(proto_path)
    checker.check_file_exist(model_path)
    if output_dir is None:
        output_dir = os.path.dirname(proto_path)
    checker.check_file_exist(output_dir)

    proto_name = os.path.basename(proto_path)
    # Drops the last len(".prototxt") characters; the extension is not
    # verified.
    proto_name = proto_name[:-len(".prototxt")]
    onnx_path = os.path.join(output_dir, proto_name + ".onnx")

    if caffe2onnx(proto_path, model_path, onnx_path) is False:
        logging.error("Oh No, caff2onnx failed :(\n")
        sys.exit(return_code.CONVERT_FAILED)
    else:
        logging.info("Congratulations! caffe2onnx succeed!\n")
    if version is None:
        version = "v1.0"

    is_ssd = checker.is_ssd_model(proto_path)
    if is_ssd:
        # SSD models are converted without optimization, so the generated
        # files never carry the '.opt' infix.  Record that here so the
        # alignment step below looks for the files that actually exist, even
        # when the caller asked for optimize=True.
        optimize = False
        onnx2tnn.convert(onnx_path,
                         output_dir,
                         version,
                         optimize,
                         half,
                         is_ssd=True)
        if input_path is None or refer_path is None:
            # SSD alignment is skipped unless both files are provided.
            align = False
    else:
        onnx2tnn.convert(onnx_path, output_dir, version, optimize, half)

    if align is True:
        proto_suffix = '.tnnproto'
        model_suffix = '.tnnmodel'
        onnx_stem = os.path.basename(onnx_path)[:-len('.onnx')]
        infix = '.opt' if optimize is True else ''
        tnn_proto_path = os.path.join(output_dir,
                                      onnx_stem + infix + proto_suffix)
        tnn_model_path = os.path.join(output_dir,
                                      onnx_stem + infix + model_suffix)
        align_model.align_model(onnx_path, tnn_proto_path, tnn_model_path,
                                input_path, refer_path)
예제 #4
0
def align_model(onnx_path: str, tnn_proto_path: str, tnn_model_path: str, input_file_path: str = None,
                refer_path: str = None, input_names: str = None, is_tflite: bool = False) -> bool:
    """
    Align the original (ONNX or TFLite) model with the converted TNN model.

    Currently supported models: single input with single output, and single
    input with multiple outputs.

    :param onnx_path: path of the original model (a TFLite file when
        ``is_tflite`` is set).
    :param tnn_proto_path: path of the generated ``.tnnproto`` file.
    :param tnn_model_path: path of the generated ``.tnnmodel`` file.
    :param input_file_path: optional input data file; random data is
        generated when it is ``None``.
    :param refer_path: optional reference output file; the original model is
        run to produce one when it is ``None``.
    :param input_names: optional input specification; when given it is parsed
        and used as the input description for both models.
    :param is_tflite: selects the TFLite helpers instead of the ONNX ones.
    :return: ``True`` once the check has run.  The process exits with
        ``return_code.ALIGN_FAILED`` when a supplied path does not exist.
    """
    logging.info("{}  align model (tflite or ONNX vs TNN),please wait a moment {}\n" .format("-" * 10, "-" * 10))

    checker.check_file_exist(tnn_proto_path)
    checker.check_file_exist(tnn_model_path)

    # Evaluated once; the comparison (rather than plain truthiness) mirrors
    # how the flag has always been tested here.
    from_tflite = is_tflite == True

    # Describe the inputs of both models.
    if input_names is None:
        tnn_input_info = get_input_shape_from_tnn(tnn_proto_path)
        if from_tflite:
            onnx_input_info = get_input_shape_from_tflite(onnx_path)
        else:
            onnx_input_info = get_input_shape_from_onnx(onnx_path)
    else:
        tnn_input_info = onnx_input_info = parse_input_names(input_names)

    verify_inputs = check_input_lite_info if from_tflite else check_input_info
    verify_inputs(onnx_input_info, tnn_input_info)

    # Input data: generated on demand, otherwise the caller's file.
    if input_file_path is None:
        input_path = data.gene_random_data(onnx_input_info)
    elif os.path.exists(input_file_path):
        input_path = input_file_path
    else:
        logging.error("Invalid input_file_path")
        sys.exit(return_code.ALIGN_FAILED)

    # Reference output: produced by the original model unless supplied.
    if refer_path is None:
        run_original = run_tflite if from_tflite else run_onnx
        reference_output_path = run_original(onnx_path, input_path, onnx_input_info)
    elif os.path.exists(refer_path):
        reference_output_path = refer_path
    else:
        logging.error("Invalid refer_path")
        sys.exit(return_code.ALIGN_FAILED)

    run_tnn_model_check(tnn_proto_path, tnn_model_path, input_path, reference_output_path, is_tflite)

    # Only files created by this function are cleaned up.
    if input_file_path is None and os.path.exists(input_path):
        data.clean_temp_data(os.path.dirname(input_path))
    if refer_path is None and os.path.exists(reference_output_path):
        data.clean_temp_data(reference_output_path)

    return True
예제 #5
0
파일: onnx2tnn.py 프로젝트: yytdfc/TNN
def throw_exception(current_shape):
    """Log the current input shapes with a usage hint, then exit with -1.

    Despite its name this function raises no custom exception: it logs two
    INFO messages and terminates the process through ``exit(-1)``.

    :param current_shape: iterable of ``(name, shape)`` pairs.
    """
    described = "".join(
        str(name) + ": " + str(shape) + "   "
        for name, shape in current_shape)

    logging.info(
        "You should use -in to specify input's name and shape. e.g.: -in name[1,3,32,32]"
    )
    logging.info("Current shape: " + described)

    exit(-1)
예제 #6
0
def check_input_info(onnx_input_info: dict, tnn_input_info: dict):
    """Compare every ONNX input description with its TNN counterpart.

    A differing number of inputs, or any input for which
    ``check_shape_info`` does not return ``True``, is reported through
    ``print_not_align_message``.

    :param onnx_input_info: input descriptions keyed by ONNX input name.
    :param tnn_input_info: input descriptions keyed by TNN input name.
    """
    if len(onnx_input_info) != len(tnn_input_info):
        print_not_align_message("onnx input size != tnn input size")

    for input_name, onnx_side in onnx_input_info.items():
        # TNN may use a different name for the same input.
        tnn_side = tnn_input_info[convert_name.onnx_name2tnn_name(input_name)]
        if check_shape_info(onnx_side, tnn_side) == True:
            logging.info(input_name + ": input shape of onnx and tnn is aligned!\n")
            continue

        logging.error("input is not align 194\n")
        reason = "The {}'s shape not equal! the onnx shape:{}, tnn shape: {}\n".format(
            input_name, str(onnx_side), str(tnn_side))
        print_not_align_message(reason)
예제 #7
0
파일: align_model.py 프로젝트: zwh1024/TNN
def check_input_info(onnx_input_info: dict, tnn_input_info: dict):
    """Check that the ONNX and TNN models declare the same input shapes.

    A differing number of inputs or any differing shape is reported through
    ``print_not_align_message``.  Note that a first dimension that is not a
    plain ``int`` is overwritten with 1 in place, so the caller's
    ``onnx_input_info`` is modified.

    :param onnx_input_info: shapes keyed by ONNX input name.
    :param tnn_input_info: shapes keyed by TNN input name.
    """
    if len(onnx_input_info) != len(tnn_input_info):
        print_not_align_message("onnx input size != tnn input size")

    for input_name, onnx_dims in onnx_input_info.items():
        tnn_dims = tnn_input_info[convert_name.onnx_name2tnn_name(input_name)]

        leading = onnx_dims[0]
        if type(leading) is not int:
            # Non-integer leading dimension: compare as if it were 1.
            onnx_dims[0] = 1

        if tnn_dims != onnx_dims:
            reason = "The {}'s shape not equal! the onnx shape:{}, tnn shape: {}\n".format(
                input_name, str(onnx_dims), str(tnn_dims))
            print_not_align_message(reason)

    logging.info("Check onnx input shape and tnn input shape align!\n")
예제 #8
0
def check_input_lite_info(onnx_input_info: dict, tnn_input_info: dict):
    """Check that the TFLite and TNN models declare matching input shapes.

    Each original shape is reordered from positions ``[0, 1, 2, 3]`` to
    ``[0, 3, 1, 2]`` before it is compared with the TNN shape (presumably an
    NHWC -> NCHW conversion; the source layout is not visible here).  A first
    dimension that is not a plain ``int`` is overwritten with 1 in place.
    Mismatches are reported through ``print_not_align_message``.

    :param onnx_input_info: four-element shapes keyed by original input name.
    :param tnn_input_info: shapes keyed by TNN input name.
    """
    if len(onnx_input_info) != len(tnn_input_info):
        print_not_align_message("tflite input size != tnn input size")

    for input_name, lite_dims in onnx_input_info.items():
        tnn_dims = tnn_input_info[convert_name.onnx_name2tnn_name(input_name)]

        if type(lite_dims[0]) is not int:
            lite_dims[0] = 1

        # Move the last axis to position 1, keeping the others in order.
        nchw = [lite_dims[0], lite_dims[3], lite_dims[1], lite_dims[2]]
        if tnn_dims != nchw:
            logging.info("input is not algin 216\n")
            reason = "The {}'s shape not equal! the onnx shape:{}, tnn shape: {}\n".format(
                input_name, str(lite_dims), str(tnn_dims))
            print_not_align_message(reason)

    logging.info("Check tflite input shape and tnn input shape align!\n")
예제 #9
0
def check_input_lite_info(onnx_input_info: dict, tnn_input_info: dict):
    """Compare every TFLite input description with its TNN counterpart.

    A differing number of inputs, or any input rejected by
    ``check_shape_info``, is reported through ``print_not_align_message``.
    A success message is logged for each accepted input and once more after
    the loop.

    :param onnx_input_info: input descriptions keyed by original input name.
    :param tnn_input_info: input descriptions keyed by TNN input name.
    """
    if len(onnx_input_info) != len(tnn_input_info):
        print_not_align_message("tflite input size != tnn input size")

    for input_name, lite_side in onnx_input_info.items():
        tnn_side = tnn_input_info[convert_name.onnx_name2tnn_name(input_name)]
        if check_shape_info(lite_side, tnn_side):
            logging.info("Check tflite input shape and tnn input shape align!\n")
            continue

        logging.info("input is not align\n")
        reason = "The {}'s shape not equal! the onnx shape:{}, tnn shape: {}\n".format(
            input_name, str(lite_side), str(tnn_side))
        print_not_align_message(reason)

    logging.info("Check tflite input shape and tnn input shape align!\n")
예제 #10
0
def convert(onnx_path,
            output_dir=None,
            version="v1.0",
            optimize=True,
            half=False,
            align=False,
            input_path=None,
            refer_path=None,
            input_names: str = None):
    """
    Run the onnx2tnn command that converts an ONNX model to TNN.

    :param onnx_path: path of the ONNX file to convert.
    :param output_dir: directory that receives the generated TNN files; the
        directory of ``onnx_path`` is used when ``None``.
    :param version: version of the converted model, passed to the converter
        through ``-version``; ``None`` falls back to ``"v1.0"``.
    :param optimize: whether the model should be optimized (enabled by
        default).
    :param half: whether to convert to an FP16 model to reduce its size.
    :param align: when exactly ``True``, align the ONNX model with the
        generated TNN model after the conversion.
    :param input_path: optional input data file forwarded to the alignment.
    :param refer_path: optional reference output file forwarded to the
        alignment.
    :param input_names: optional input specification written with brackets,
        e.g. ``name[1,3,32,32]``.
    :return: ``None``.  The process exits with ``return_code.CONVERT_FAILED``
        when the command does not return 0.
    """
    logging.info("Converter ONNX to TNN Model\n")

    checker.check_file_exist(onnx_path)

    ret, current_shape = checker.check_onnx_dim(onnx_path)

    if ret is False and current_shape is not None:
        # The dimension check failed: an explicit bracketed shape is required.
        has_bracketed_shape = (input_names is not None
                               and "[" in input_names
                               and "]" in input_names)
        if not has_bracketed_shape:
            throw_exception(current_shape)

    if version is None:
        version = "v1.0"

    proto_suffix = '.tnnproto'
    model_suffix = '.tnnmodel'
    command = "python3 onnx2tnn.py " + onnx_path
    # Honour the caller's version instead of always sending "v1.0".
    command = command + " -version=" + str(version)
    command = command + (" -optimize=1" if optimize is True else " -optimize=0")
    command = command + (" -half=1" if half is True else " -half=0")

    if output_dir is None:
        output_dir = os.path.dirname(onnx_path)
    checker.check_file_exist(output_dir)
    command = command + " -o " + output_dir

    new_input_names = None
    if input_names is not None:
        # "name[1,3,32,32]" becomes "name:1,3,32,32" for -input_shape.
        new_input_names = input_names.replace("[", ":").replace("]", "")
        command = command + " -input_shape " + new_input_names

    # Logged only once the command is complete, so the debug output shows
    # exactly what is executed.
    logging.debug("The onnx2tnn command:" + command + "\n")

    work_dir = "../onnx2tnn/onnx-converter/"
    result = cmd.run(command, work_dir=work_dir)

    if result == 0:
        logging.info("Converter ONNX to TNN model succeed!\n")
    else:
        logging.error("Converter ONNX to TNN model failed!\n")
        sys.exit(return_code.CONVERT_FAILED)

    if align is True:
        onnx_stem = os.path.basename(onnx_path)[:-len('.onnx')]
        # Optimized conversions are looked up under "<name>.opt.<suffix>".
        infix = '.opt' if optimize is True else ''
        tnn_proto_path = os.path.join(output_dir,
                                      onnx_stem + infix + proto_suffix)
        tnn_model_path = os.path.join(output_dir,
                                      onnx_stem + infix + model_suffix)

        if new_input_names is None:
            align_model.align_model(onnx_path, tnn_proto_path, tnn_model_path,
                                    input_path, refer_path)
        else:
            align_model.align_model(onnx_path, tnn_proto_path, tnn_model_path,
                                    input_path, refer_path, new_input_names)
예제 #11
0
파일: align_model.py 프로젝트: zwh1024/TNN
def print_align_message():
    """Log the two-line success banner for an ONNX/TNN alignment."""
    rule = "-" * 10
    banner = "{}  Congratulations!   {}".format(rule, rule)
    logging.info(banner)
    logging.info("The onnx model aligned with tnn model\n")
예제 #12
0
파일: align_model.py 프로젝트: yytdfc/TNN
def print_not_align_message(reason):
    """Log why the models are not aligned, then exit with status -1.

    :param reason: text appended to the final log line; it is concatenated
        with ``+``, so it has to be a ``str``.
    """
    rule = "-" * 10
    banner = "{}   Unfortunately   {}" .format(rule, rule)
    logging.info(banner)
    logging.info("The onnx model not aligned with tnn model\n")
    detail = "the reason " + reason + "\n"
    logging.info(detail)
    exit(-1)
예제 #13
0
파일: align_model.py 프로젝트: yfz912/TNN
def print_align_message(is_tflite=False):
    """Log the success banner for a finished alignment.

    :param is_tflite: when it compares equal to ``True`` the message names
        the tflite model, otherwise the onnx model.
    """
    rule = "-" * 10
    logging.info("{}  Congratulations!   {}".format(rule, rule))
    outcome = ("The tflite model aligned with tnn model\n"
               if is_tflite == True
               else "The onnx model aligned with tnn model\n")
    logging.info(outcome)
예제 #14
0
def align_model(original_model_path: str, tnn_proto_path: str, tnn_model_path: str, input_file_path: str = None,
                refer_path: str = None, specify_input_args: str = None, is_tflite: bool = False, debug_mode: bool = False, align_batch: bool = False) -> bool:
    """
    Align the original (ONNX or TFLite) model with the converted TNN model.

    Currently supported models: single input with single output, and single
    input with multiple outputs.

    :param original_model_path: path of the original ONNX or TFLite model.
    :param tnn_proto_path: path of the generated ``.tnnproto`` file.
    :param tnn_model_path: path of the generated ``.tnnmodel`` file.
    :param input_file_path: optional input data file; random data is
        generated when it is ``None``.
    :param refer_path: optional reference output file; the original model is
        run to produce one when it is ``None``.
    :param specify_input_args: optional input specification that is parsed
        and applied on top of the original model's input description.
    :param is_tflite: selects the TFLite helpers instead of the ONNX ones.
    :param debug_mode: unless it is exactly ``False``, generated temporary
        data is kept instead of cleaned up.
    :param align_batch: forwarded to ``run_tnn_model_check``.
    :return: ``True`` once the check has run.  The process exits with
        ``return_code.ALIGN_FAILED`` when a supplied path does not exist.
    """
    logging.info("{}  align model (tflite or ONNX vs TNN),please wait a moment {}\n" .format("-" * 10, "-" * 10))

    checker.check_file_exist(tnn_proto_path)
    checker.check_file_exist(tnn_model_path)

    # Per the original author's note, an input-info mapping is keyed by input
    # name and each entry carries a "shape" list ([n, c, ...]) and a
    # "data_type" code.
    read_original_inputs = get_input_shape_from_tflite if is_tflite else get_input_shape_from_onnx
    original_input_info = read_original_inputs(original_model_path)
    tnn_input_info = get_input_shape_from_tnn(tnn_proto_path)

    # Shapes given by the caller are applied to the original description.
    if specify_input_args is not None:
        update_original_input_shape(original_input_info,
                                    parse_specify_input_args(specify_input_args))

    verify_inputs = check_input_lite_info if is_tflite else check_input_info
    verify_inputs(original_input_info, tnn_input_info)

    # Input data: generated on demand, otherwise the caller's file.
    if input_file_path is None:
        input_path = data.gene_random_data(original_input_info)
    elif os.path.exists(input_file_path):
        input_path = input_file_path
    else:
        logging.error("Invalid input_file_path")
        sys.exit(return_code.ALIGN_FAILED)

    # Reference output: produced by the original model unless supplied.
    if refer_path is None:
        # This branch tests equality with True, unlike the truthiness tests
        # above.
        run_original = run_tflite if is_tflite == True else run_onnx
        reference_output_path = run_original(original_model_path, input_path, original_input_info)
    elif os.path.exists(refer_path):
        reference_output_path = refer_path
    else:
        logging.error("Invalid refer_path")
        sys.exit(return_code.ALIGN_FAILED)

    logging.info("Run tnn model_check...")
    run_tnn_model_check(tnn_proto_path, tnn_model_path, input_path, reference_output_path, is_tflite, align_batch)

    if debug_mode is not False:
        return True

    # Only files created by this function are cleaned up.
    if input_file_path is None and os.path.exists(input_path):
        data.clean_temp_data(os.path.dirname(input_path))
    if refer_path is None and os.path.exists(reference_output_path):
        data.clean_temp_data(reference_output_path)
    return True
예제 #15
0
파일: checker.py 프로젝트: yytdfc/TNN
def check_file_exist(file_path):
    """Exit the process with status -1 unless ``file_path`` exists.

    :param file_path: path to check; ``str``, ``bytes`` and path-like values
        are accepted.  ``None`` is treated as a missing path.
    :raises SystemExit: with code -1 when the path is ``None`` or does not
        exist.
    """
    if file_path is not None and os.path.exists(file_path):
        return

    # Format instead of concatenating so that non-str paths (None, bytes,
    # pathlib.Path) still produce the message rather than a TypeError.
    logging.error(
        "the {} does not exist! please make sure the file exist!\n".format(
            file_path))
    # Raised directly: the exit() builtin is injected by the site module and
    # is not guaranteed to be available.
    raise SystemExit(-1)
예제 #16
0
파일: onnx2tnn.py 프로젝트: zhiyuyan/TNN
def convert(onnx_path,
            output_dir=None,
            version="v1.0",
            optimize=True,
            half=False,
            align='',
            align_batch=False,
            input_path=None,
            refer_path=None,
            input_names: str = None,
            is_ssd=False,
            debug_mode: bool = False):
    """
    Run the onnx2tnn command that converts an ONNX model to TNN.

    :param onnx_path: path of the ONNX file to convert.
    :param output_dir: directory that receives the generated TNN files; the
        directory of ``onnx_path`` is used when ``None``.
    :param version: version of the converted model, passed through
        ``-version``.
    :param optimize: whether the model should be optimized (enabled by
        default).
    :param half: whether to convert to an FP16 model to reduce its size.
    :param align: ``'output'`` runs ``align_model.align_model``; ``'all'``
        runs ``align_model.align_all``.
    :param align_batch: when exactly ``True`` the output alignment also runs;
        the flag is forwarded to ``align_model.align_model``.
    :param input_path: optional input data file forwarded to the alignment.
    :param refer_path: optional reference output file forwarded to the
        alignment.
    :param input_names: optional input shape specification passed through
        ``-input_shape``; a value containing neither ``":"`` nor a space is
        prefixed with the first input name of the model.
    :param is_ssd: when truthy, the ONNX dimension check is skipped.
    :param debug_mode: forwarded to ``align_model.align_model``.
    :return: ``None``.  The process exits with ``return_code.CONVERT_FAILED``
        when the command does not return 0.
    """
    logging.info("Converter ONNX to TNN Model...\n")

    checker.check_file_exist(onnx_path)

    # Best-effort validation: any ordinary exception raised here is reported
    # and the conversion is attempted anyway.
    try:
        if not is_ssd:
            logging.info("Converter ONNX to TNN check_onnx_dim...\n")
            ret, current_shape = checker.check_onnx_dim(onnx_path)
            logging.info("Converter ONNX to TNN check_onnx_dim...\n")
            dims_unresolved = ret is False and current_shape is not None
            if dims_unresolved and input_names is None:
                logging.info("Converter ONNX to TNN current_shape...\n")
                throw_exception(current_shape)
            if input_names is not None:
                input_names = input_names.strip()
                bare_shape = ":" not in input_names and " " not in input_names
                if bare_shape:
                    # Only a shape was given: attach the first input's name.
                    first_input = list(current_shape.keys())[0]
                    input_names = first_input + ":" + input_names
                check_input_names(input_names, current_shape)
    except Exception as e:
        print(e)
        logging.error(
            "check_onnx_dim failed, next stage of convertion may failed too\n")

    proto_suffix = '.tnnproto'
    model_suffix = '.tnnmodel'

    # The command line is collected piece by piece and joined at the end.
    pieces = ["python3 onnx2tnn.py " + onnx_path, " -version=" + version]
    checker.check_file_exist(onnx_path)
    pieces.append(" -optimize=1" if optimize is True else " -optimize=0")
    pieces.append(" -half=1" if half is True else " -half=0")

    if output_dir is None:
        output_dir = os.path.dirname(onnx_path)
    checker.check_file_exist(output_dir)
    pieces.append(" -o " + output_dir)

    if input_names is not None:
        pieces.append(" -input_shape " + input_names)
    command = "".join(pieces)
    logging.debug("The onnx2tnn command:" + command + "\n")

    work_dir = "../onnx2tnn/onnx-converter/"
    result = cmd.run(command, work_dir=work_dir)

    if result == 0:
        logging.info("Converter ONNX to TNN model succeed!\n")
    else:
        logging.error("Converter ONNX to TNN model failed!\n")
        sys.exit(return_code.CONVERT_FAILED)

    # Optimized conversions are looked up under "<name>.opt.<suffix>".
    onnx_stem = os.path.basename(onnx_path)[:-len('.onnx')]
    infix = '.opt' if optimize is True else ''
    tnn_proto_path = os.path.join(output_dir, onnx_stem + infix + proto_suffix)
    tnn_model_path = os.path.join(output_dir, onnx_stem + infix + model_suffix)

    if align == 'output' or align_batch is True:
        positional = [onnx_path, tnn_proto_path, tnn_model_path,
                      input_path, refer_path]
        if input_names is not None:
            # The input specification is passed only when one was supplied.
            positional.append(input_names)
        align_model.align_model(*positional,
                                debug_mode=debug_mode,
                                align_batch=align_batch)
    elif align == 'all':
        align_model.align_all(onnx_path, tnn_proto_path, align == 'all',
                              input_names, input_path, refer_path)