Example #1
0
def get_input_shape_from_onnx(onnx_path) -> dict:
    """Collect name, shape and data type of every input of an ONNX model.

    The model is loaded with onnxruntime and each input is described as::

        {
            "input name": {"shape": [n, c, ...], "data_type": 0},
            "input name": {"shape": [n, c, ...], "data_type": 0},
        }

    Every ``:`` in an input name is replaced by ``_``. ``data_type`` is 0 for
    ``tensor(float)`` and 3 for ``tensor(int64)``; any other element type is
    reported through ``logging.error`` and recorded as 0. When the leading
    dimension is not an ``int`` it is replaced by 1.

    :param onnx_path: path of the ONNX model file.
    :return: dict mapping the sanitized input name to its shape information.
    """
    # NOTE(review): 3 is presumably onnxruntime's ERROR severity - confirm
    # against the onnxruntime documentation.
    onnxruntime.set_default_logger_severity(3)
    session = onnxruntime.InferenceSession(onnx_path)
    type_codes = {'tensor(float)': 0, 'tensor(int64)': 3}
    input_info: dict = {}
    for ip in session.get_inputs():
        name = ip.name.replace(":", "_")
        shape = list(ip.shape)
        data_type = type_codes.get(ip.type)
        if data_type is None:
            logging.error("Do not support input date type")
            data_type = 0
        # A rank-0 (scalar) input has an empty shape; indexing it would raise
        # IndexError, so only patch the leading dimension when there is one.
        if shape and not isinstance(shape[0], int):
            shape[0] = 1
        input_info[name] = {'shape': shape, 'data_type': data_type}
    return input_info
Example #2
0
def convert(tf_path, input_names, output_names, output_dir, version, optimize, half, align=False, not_fold_const=False,
            input_path=None, refer_path=None):
    """Convert a TensorFlow model to TNN, going through ONNX.

    The model is first converted with ``tf2onnx`` to ``<model name>.onnx``
    inside ``output_dir`` and that file is then handed to
    ``onnx2tnn.convert``. When ``align`` is ``True`` the resulting TNN files
    are checked against the ONNX model with ``align_model.align_model``.

    :param tf_path: path of the TensorFlow model file.
    :param input_names: forwarded to ``tf2onnx``.
    :param output_names: forwarded to ``tf2onnx``.
    :param output_dir: output directory; when it is ``None`` or not an
        existing directory, the directory of ``tf_path`` is used.
    :param version: model version forwarded to ``onnx2tnn.convert``;
        ``None`` means ``"v1.0"``.
    :param optimize: forwarded to ``onnx2tnn.convert``; also decides whether
        the TNN file names carry the ``.opt`` infix.
    :param half: forwarded to ``onnx2tnn.convert``.
    :param align: run the alignment step only when this is exactly ``True``.
    :param not_fold_const: forwarded to ``tf2onnx``.
    :param input_path: forwarded to ``align_model.align_model``.
    :param refer_path: forwarded to ``align_model.align_model``.
    :return: None. Exits the process with ``return_code.CONVERT_FAILED`` when
        ``tf2onnx`` returns ``False``.
    """
    logging.info("Converter Tensorflow to TNN model\n")
    checker.check_file_exist(tf_path)
    if output_dir is None or not os.path.isdir(output_dir):
        output_dir = os.path.dirname(tf_path)
    checker.check_file_exist(output_dir)
    # Remove the real extension instead of blindly cutting len(".pb")
    # characters, which mangled names that do not end in ".pb".
    model_name = os.path.splitext(os.path.basename(tf_path))[0]
    onnx_path = os.path.join(output_dir, model_name + ".onnx")
    if tf2onnx(tf_path, input_names, output_names, onnx_path, not_fold_const) is False:
        logging.error("Oh No, tf2onnx failed :(\n")
        sys.exit(return_code.CONVERT_FAILED)
    logging.info("Convert TensorFlow to ONNX model succeed!\n")
    if version is None:
        version = "v1.0"
    checker.check_file_exist(onnx_path)
    onnx2tnn.convert(onnx_path, output_dir, version, optimize, half)

    if align is True:
        # The TNN files carry an extra ".opt" infix when optimization was
        # requested.
        tnn_stem = model_name + ('.opt' if optimize is True else '')
        tnn_proto_path = os.path.join(output_dir, tnn_stem + '.tnnproto')
        tnn_model_path = os.path.join(output_dir, tnn_stem + '.tnnmodel')
        align_model.align_model(onnx_path, tnn_proto_path, tnn_model_path, input_path, refer_path)
Example #3
0
def run(cmd_string, work_dir=None, timeout=None, is_shell=True):
    """
         Run a command in a subprocess (wrapper around subprocess.Popen) and
         return its exit code. stdout is logged at debug level; stderr is
         logged at error level when the exit code is not 0.
        :parameter:
              cmd_string: the command line to execute
              work_dir: working directory of the child process (passed as cwd)
              timeout: time limit in seconds, fractions allowed; None or 0
                       means no limit
              is_shell: run the command line through the shell when true,
                        otherwise split it with shlex and execute it directly
        :return return_code
        :exception subprocess.TimeoutExpired when the time limit is exceeded
    """

    if is_shell:
        command = cmd_string
    else:
        command = shlex.split(cmd_string)

    # shell must follow is_shell: handing an argv list to Popen together with
    # shell=True makes the shell execute only the first list element.
    sub = subprocess.Popen(command,
                           stdout=subprocess.PIPE,
                           stderr=subprocess.PIPE,
                           shell=bool(is_shell),
                           bufsize=4096,
                           cwd=work_dir,
                           close_fds=True)
    try:
        (stdout, stderr) = sub.communicate(timeout=timeout if timeout else None)
    except subprocess.TimeoutExpired:
        # Do not leave the child running; reap it, then report the timeout.
        sub.kill()
        sub.communicate()
        raise
    # errors='replace' keeps non UTF-8 output from crashing the caller.
    logging.debug(stdout.decode('utf-8', errors='replace'))
    rc = sub.returncode
    if rc != 0:
        logging.error(stderr.decode('utf-8', errors='replace'))
    return rc
Example #4
0
def convert(tf_path,  output_dir, version, half, align=False,
            input_path=None, refer_path=None, debug_mode: bool = False):
    """Convert a TensorFlow Lite model to TNN and optionally align the result.

    ``tflite2tnn`` performs the conversion; the process exits with
    ``return_code.CONVERT_FAILED`` when it returns ``False``. When ``align``
    equals ``'output'`` the generated ``<name>.tnnproto`` / ``<name>.tnnmodel``
    pair in ``output_dir`` is checked with ``align_model.align_model``.

    :param tf_path: path of the TFLite model file.
    :param output_dir: output directory; the directory of ``tf_path`` is used
        when it is ``None`` or not an existing directory.
    :param version: defaulted to ``"v1.0"`` when ``None``; not used further
        in this function.
    :param half: forwarded to ``tflite2tnn``.
    :param align: alignment runs only when this equals ``'output'``.
    :param input_path: forwarded to ``align_model.align_model``.
    :param refer_path: forwarded to ``align_model.align_model``.
    :param debug_mode: forwarded to ``align_model.align_model``.
    """
    checker.check_file_exist(tf_path)
    usable_output_dir = output_dir is not None and os.path.isdir(output_dir)
    if not usable_output_dir:
        output_dir = os.path.dirname(tf_path)
    checker.check_file_exist(output_dir)
    # File name with the last len(".tflite") characters removed.
    stem = os.path.basename(tf_path)[:-len(".tflite")]

    if tflite2tnn(tf_path, output_dir, half) is False:
        logging.error("Oh No, tflite2tnn failed :(\n")
        sys.exit(return_code.CONVERT_FAILED)
    logging.info("Convert TensorFlowLite to TNN model succeed!\n")

    version = "v1.0" if version is None else version
    if align != 'output':
        return

    proto_file = os.path.join(output_dir, stem + '.tnnproto')
    model_file = os.path.join(output_dir, stem + '.tnnmodel')
    align_model.align_model(tf_path, proto_file, model_file, input_path, refer_path, None, True,
                            debug_mode=debug_mode)
Example #5
0
def get_input_shape_from_tnn(tnn_proto_path):
    """Read the input descriptions from the first two lines of a tnnproto file.

    The last space-separated token of line 1 is the magic number that selects
    the layout of line 2, where the inputs are separated by ``:``:

    * ``TNN_MAGIC_NUMBER``: ``name n c h w``; ``data_type`` is recorded as 0.
    * ``TNN_MAGIC_NUMBER_V2``: ``name size dim1 dim2 ... data_type``.

    Any other magic number is logged as an error and the process exits with
    status -1.

    :param tnn_proto_path: path of the ``.tnnproto`` file.
    :return: dict mapping each input name to
        ``{'shape': [...], 'data_type': int}``.
    """
    first_line = linecache.getline(tnn_proto_path, 1)
    magic_token = first_line.strip('\n').strip('\"').strip(',').strip(' ').split(" ")[-1]
    magic_number = int(magic_token)

    if magic_number not in (TNN_MAGIC_NUMBER, TNN_MAGIC_NUMBER_V2):
        logging.error("Unspport TNN model version\n")
        sys.exit(-1)

    legacy_layout = magic_number == TNN_MAGIC_NUMBER
    declaration = linecache.getline(tnn_proto_path, 2).strip('\n').strip('\"').strip(',')

    input_info: dict = {}
    for entry in declaration.split(':'):
        fields = entry.strip(' ').split(' ')
        if legacy_layout:
            name, n, c, h, w = fields
            dims = [int(n), int(c), int(h), int(w)]
            dtype_code = 0
        else:
            # fields: name size shape1 shape2... data_type
            name = fields[0]
            # The size field is parsed, but the shape is taken from the
            # dimension fields themselves.
            _declared_size = int(fields[1])
            dtype_code = int(fields[-1])
            dims = [int(dim) for dim in fields[2:-1]]
        input_info[name] = {'shape': dims, 'data_type': dtype_code}
    return input_info
Example #6
0
def convert(proto_path,
            model_path,
            output_dir,
            version,
            optimize,
            half,
            align=False,
            input_path=None,
            refer_path=None):
    """Convert a Caffe model to TNN, going through ONNX.

    ``caffe2onnx`` writes ``<proto name>.onnx`` into ``output_dir`` and
    ``onnx2tnn.convert`` turns that file into the TNN model. Models that
    ``checker.is_ssd_model`` reports as SSD are always converted without
    optimization. When ``align`` is ``True`` the TNN files are checked
    against the ONNX model with ``align_model.align_model``; for SSD models
    this only happens when both ``input_path`` and ``refer_path`` are given.

    :param proto_path: path of the Caffe ``.prototxt`` file.
    :param model_path: path of the Caffe weights file.
    :param output_dir: output directory; the directory of ``proto_path`` is
        used when it is ``None``.
    :param version: model version forwarded to ``onnx2tnn.convert``;
        ``None`` means ``"v1.0"``.
    :param optimize: forwarded to ``onnx2tnn.convert`` (ignored for SSD
        models, which are converted with optimization off).
    :param half: forwarded to ``onnx2tnn.convert``.
    :param align: run the alignment step only when this is exactly ``True``.
    :param input_path: forwarded to ``align_model.align_model``.
    :param refer_path: forwarded to ``align_model.align_model``.
    :return: None. Exits the process with ``return_code.CONVERT_FAILED`` when
        ``caffe2onnx`` returns ``False``.
    """
    logging.info("Converter Caffe to ONNX Model\n")
    checker.check_file_exist(proto_path)
    checker.check_file_exist(model_path)
    if output_dir is None:
        output_dir = os.path.dirname(proto_path)
    checker.check_file_exist(output_dir)

    # Remove the real extension instead of blindly cutting
    # len(".prototxt") characters from the file name.
    proto_name = os.path.splitext(os.path.basename(proto_path))[0]
    onnx_path = os.path.join(output_dir, proto_name + ".onnx")

    if caffe2onnx(proto_path, model_path, onnx_path) is False:
        logging.error("Oh No, caff2onnx failed :(\n")
        sys.exit(return_code.CONVERT_FAILED)
    logging.info("Congratulations! caffe2onnx succeed!\n")
    if version is None:
        version = "v1.0"

    is_ssd = checker.is_ssd_model(proto_path)
    if is_ssd:
        # SSD models are converted with optimization off, so the generated
        # files never carry the ".opt" infix. Record that here so the align
        # step below looks for the files that were actually written.
        optimize = False
        onnx2tnn.convert(onnx_path,
                         output_dir,
                         version,
                         False,
                         half,
                         is_ssd=True)
    else:
        onnx2tnn.convert(onnx_path, output_dir, version, optimize, half)

    # SSD models are only aligned with caller-supplied input and reference.
    if is_ssd and (input_path is None or refer_path is None):
        align = False

    if align is True:
        tnn_stem = proto_name + ('.opt' if optimize is True else '')
        tnn_proto_path = os.path.join(output_dir, tnn_stem + '.tnnproto')
        tnn_model_path = os.path.join(output_dir, tnn_stem + '.tnnmodel')
        align_model.align_model(onnx_path, tnn_proto_path, tnn_model_path,
                                input_path, refer_path)
Example #7
0
def align_model(onnx_path: str, tnn_proto_path: str, tnn_model_path: str, input_file_path: str=None,
                refer_path: str = None, input_names: str = None, is_tflite: bool=False ) -> bool:
    """
    Align an ONNX (or TFLite) model with a TNN model.
    Currently supported models: single input, single output; single input, multiple outputs.

    Input shapes come from ``input_names`` when given, otherwise from the
    model files. Input data is generated randomly unless ``input_file_path``
    is given, and the reference output is produced by running the original
    model unless ``refer_path`` is given. Generated files are removed after
    the check.

    :param onnx_path: path of the original model (ONNX, or TFLite when is_tflite is set)
    :param tnn_proto_path: path of the tnnproto file
    :param tnn_model_path: path of the tnnmodel file
    :param input_file_path: existing input data file; exits with ALIGN_FAILED if it does not exist
    :param refer_path: existing reference output file; exits with ALIGN_FAILED if it does not exist
    :param input_names: input specification parsed by parse_input_names
    :param is_tflite: selects the TFLite helpers instead of the ONNX ones
    :return: True
    """
    logging.info("{}  align model (tflite or ONNX vs TNN),please wait a moment {}\n" .format("-" * 10, "-" * 10))

    checker.check_file_exist(tnn_proto_path)
    checker.check_file_exist(tnn_model_path)

    # The flag is compared with == True, exactly as the original code did.
    tflite_mode = is_tflite == True  # noqa: E712

    # check input
    if input_names is not None:
        tnn_input_info = onnx_input_info = parse_input_names(input_names)
    else:
        tnn_input_info = get_input_shape_from_tnn(tnn_proto_path)
        read_original_info = get_input_shape_from_tflite if tflite_mode else get_input_shape_from_onnx
        onnx_input_info = read_original_info(onnx_path)

    verify_inputs = check_input_lite_info if tflite_mode else check_input_info
    verify_inputs(onnx_input_info, tnn_input_info)

    if input_file_path is None:
        # generate data
        input_path = data.gene_random_data(onnx_input_info)
    elif os.path.exists(input_file_path):
        input_path = input_file_path
    else:
        logging.error("Invalid input_file_path")
        sys.exit(return_code.ALIGN_FAILED)

    if refer_path is None:
        run_reference = run_tflite if tflite_mode else run_onnx
        reference_output_path = run_reference(onnx_path, input_path, onnx_input_info)
    elif os.path.exists(refer_path):
        reference_output_path = refer_path
    else:
        logging.error("Invalid refer_path")
        sys.exit(return_code.ALIGN_FAILED)

    run_tnn_model_check(tnn_proto_path, tnn_model_path, input_path, reference_output_path, is_tflite)

    # Only files generated by this function are cleaned up.
    generated_input = input_file_path is None
    if generated_input and os.path.exists(input_path):
        data.clean_temp_data(os.path.dirname(input_path))
    generated_reference = refer_path is None
    if generated_reference and os.path.exists(reference_output_path):
        data.clean_temp_data(reference_output_path)

    return True
def throw_exception(current_shape):
    """Log the current input shapes plus a hint to use ``-in``, then exit.

    :param current_shape: mapping of input name to its current shape.
    :return: never returns; exits with ``return_code.CONVERT_FAILED``.
    """
    described = "".join(
        str(name) + ": " + str(shape) + "   "
        for name, shape in current_shape.items())

    logging.error("You should use -in to specify input's name and shape. e.g.: -in name:1,3,32,32")
    logging.error("Current shape: " + described)

    sys.exit(return_code.CONVERT_FAILED)
Example #9
0
def check_input_info(onnx_input_info: dict, tnn_input_info: dict):
    """Check that every ONNX input has a matching TNN input of the same shape.

    ONNX input names are translated with ``convert_name.onnx_name2tnn_name``
    before they are looked up in ``tnn_input_info``. Any mismatch (different
    number of inputs, an input missing on the TNN side, or a difference
    reported by ``check_shape_info``) is handed to
    ``print_not_align_message``.

    :param onnx_input_info: input name -> shape information of the ONNX model.
    :param tnn_input_info: input name -> shape information of the TNN model.
    """
    if len(onnx_input_info) != len(tnn_input_info):
        print_not_align_message("onnx input size != tnn input size")
    for name, onnx_info in onnx_input_info.items():
        tnn_name = convert_name.onnx_name2tnn_name(name)
        tnn_info = tnn_input_info.get(tnn_name)
        if tnn_info is None:
            # A plain tnn_input_info[tnn_name] raised KeyError here, hiding
            # the actual problem: the input does not exist in the TNN model.
            reason = "The input {} (tnn name: {}) is missing in the tnn model, tnn inputs: {}\n".format(
                name, tnn_name, list(tnn_input_info.keys()))
            logging.error(reason)
            print_not_align_message(reason)
        elif check_shape_info(onnx_info, tnn_info):
            logging.info(name + ": input shape of onnx and tnn is aligned!\n")
        else:
            logging.error("input is not align 194\n")
            print_not_align_message(
                "The {}'s shape not equal! the onnx shape:{}, tnn shape: {}\n".format(name, str(onnx_info),
                                                                                      str(tnn_info)))
Example #10
0
def parse_path(path: str):
    """Turn a user supplied path into an absolute path.

    ``None`` is passed through. A path (or a current working directory) that
    contains a space is rejected: an error is logged and the process exits
    with ``return_code.SPACE_IN_PATH``. A leading ``~`` or ``~user`` is
    expanded with ``os.path.expanduser``; absolute paths are returned as
    they are; everything else is joined onto the current working directory
    (a leading ``./`` is dropped first).

    :param path: path to resolve, or ``None``.
    :return: the absolute path, or ``None`` when ``path`` is ``None``.
    """
    if path is None:
        return None
    if " " in path or " " in os.getcwd():
        logging.error("The path can not contain spaces!")
        sys.exit(return_code.SPACE_IN_PATH)
    if path.startswith("~"):
        # expanduser also handles "~user/..."; gluing the current user's home
        # in front of path[1:] produced a wrong path for that form.
        path = os.path.expanduser(path)
    if os.path.isabs(path):
        return path
    if path.startswith("./"):
        path = path[2:]
    return os.path.join(os.getcwd(), path)
Example #11
0
def get_input_shape_from_tflite(tflite_path) -> dict:
    """Collect name, shape and data type of every input of a TFLite model.

    The model is loaded with ``tf.lite.Interpreter``. Each input shape is
    passed through ``nhwc_shape_to_nchw`` before it is stored. ``data_type``
    is 0 for ``np.float32`` and 3 for ``np.int64``; any other dtype is
    reported through ``logging.error`` and recorded as 0.

    :param tflite_path: path of the TFLite model file.
    :return: dict mapping each input name to
        ``{"shape": [...], "data_type": int}``.
    """
    import tensorflow as tf
    input_info: dict = {}
    interpreter = tf.lite.Interpreter(tflite_path)
    interpreter.allocate_tensors()
    for item in interpreter.get_input_details():
        name = item["name"]
        shape = nhwc_shape_to_nchw(list(item["shape"]))

        if item["dtype"] == np.float32:
            data_type = 0
        elif item["dtype"] == np.int64:
            data_type = 3
        else:
            logging.error("Do not support input date type")
            # Without an explicit value here data_type was unbound for the
            # first input (or silently reused from the previous input).
            data_type = 0
        input_info[name] = {"shape": shape, "data_type": data_type}
    return input_info
Example #12
0
def print_not_align_message(reason=None, is_tflite=False):
    """Log that the original model and the TNN model do not align, then exit.

    :param reason: optional description of the mismatch; logged when given.
    :param is_tflite: report a TFLite model instead of an ONNX model.
    :return: never returns; exits with ``return_code.ALIGN_FAILED``.
    """
    logging.error("{}   Unfortunately   {}".format("-" * 10, "-" * 10))
    if is_tflite:
        logging.error("The tflite model not aligned with tnn model\n")
    else:
        logging.error("The onnx model not aligned with tnn model\n")
    # The reason used to be accepted but never shown to the user.
    if reason is not None:
        logging.error("%s", reason)
    sys.exit(return_code.ALIGN_FAILED)
Example #13
0
def run(cmd_string,
        work_dir=None,
        timeout=None,
        is_shell=True,
        log_level="debug"):
    """
         Run a command in a subprocess (wrapper around subprocess.Popen),
         log its combined stdout/stderr line by line while it runs and
         return its exit code.
        :parameter:
              cmd_string: the command line to execute
              work_dir: working directory of the child process (passed as cwd)
              timeout: time limit in seconds, fractions allowed; None or 0
                       means no limit. The limit is checked after each output
                       line and while waiting for the process to exit, so a
                       child that blocks without printing is only caught once
                       it produces output or closes its stdout
              is_shell: run the command line through the shell when true,
                        otherwise split it with shlex and execute it directly
              log_level: "error" or "debug" selects the logging level of the
                         output lines; any other value logs nothing
        :return return_code
        :exception subprocess.TimeoutExpired when the time limit is exceeded
    """

    if is_shell:
        command = cmd_string
    else:
        command = shlex.split(cmd_string)

    end_time = None
    if timeout:
        end_time = datetime.datetime.now() + \
                   datetime.timedelta(seconds=timeout)

    if log_level == "error":
        emit = logging.error
    elif log_level == "debug":
        emit = logging.debug
    else:
        emit = None

    # shell must follow is_shell: handing an argv list to Popen together with
    # shell=True makes the shell execute only the first list element.
    sub = subprocess.Popen(command,
                           stdout=subprocess.PIPE,
                           stderr=subprocess.STDOUT,
                           shell=bool(is_shell),
                           bufsize=0,
                           cwd=work_dir,
                           close_fds=True)
    # The context manager closes the pipe and reaps the child on every exit
    # path, including the timeout raise below.
    with sub:
        # readline() returns b'' only at EOF, so the loop ends without
        # polling in a busy loop or logging empty lines.
        for raw_line in iter(sub.stdout.readline, b''):
            if emit is not None:
                emit(raw_line.decode('utf-8', errors='replace'))
            if end_time is not None and datetime.datetime.now() > end_time:
                sub.kill()
                raise subprocess.TimeoutExpired(cmd_string, timeout)

        remaining = None
        if end_time is not None:
            remaining = max(
                (end_time - datetime.datetime.now()).total_seconds(), 0)
        try:
            return sub.wait(timeout=remaining)
        except subprocess.TimeoutExpired:
            sub.kill()
            raise
def check_input_names(input_names: str, onnx_input_info: dict):
    """Check that the user specified inputs match the ONNX inputs in number.

    ``input_names`` is a whitespace separated list of ``name:d0,d1,...``
    items. The name may itself contain ``:`` (e.g. ``input:0:1,3,32,32``);
    an item without any ``:`` is a bare shape and is stored under the key
    ``None``. Only the number of inputs is compared with ``onnx_input_info``.

    :param input_names: the input specification string.
    :param onnx_input_info: mapping keyed by the ONNX model's input names.
    :return: None. Logs both name lists and exits with
        ``return_code.CONVERT_FAILED`` when the counts differ.
    :raises ValueError: when a dimension is not an integer.
    """
    input_shapes_ = {}
    # split() without an argument ignores repeated, leading and trailing
    # whitespace; split(" ") produced empty items that crashed int('').
    for item in input_names.split():
        # for the input name like input:0 - only the last ':' separates the
        # name from the shape
        name, separator, dims = item.rpartition(':')
        shape = list(map(int, dims.split(',')))
        input_shapes_[name if separator else None] = shape

    if len(input_shapes_) != len(onnx_input_info):
        logging.error("The specified input does not match the input of the ONNX model.")
        # logging formats lazily with %-style placeholders; without "%s" the
        # extra argument caused a formatting error and the list was lost.
        logging.error("Specified: %s", list(input_shapes_.keys()))
        logging.error("ONNX: %s", list(onnx_input_info.keys()))
        sys.exit(return_code.CONVERT_FAILED)
Example #15
0
def print_not_align_message(reason=None):
    """Log that the ONNX model and the TNN model do not align, then exit.

    :param reason: optional description of the mismatch; logged when given.
    :return: never returns; exits with ``return_code.ALIGN_FAILED``.
    """
    logging.error("{}   Unfortunately   {}".format("-" * 10, "-" * 10))
    logging.error("The onnx model not aligned with tnn model\n")
    # The reason used to be accepted but never shown to the user.
    if reason is not None:
        logging.error("%s", reason)
    sys.exit(return_code.ALIGN_FAILED)
Example #16
0
def check_file_exist(file_path):
    """Exit with ``return_code.CONVERT_FAILED`` unless *file_path* exists.

    :param file_path: path of a file or directory that must exist.
    :return: None when the path exists; otherwise an error is logged and the
        process exits.
    """
    if os.path.exists(file_path):
        return
    complaint = "The " + file_path + " does not exist! please make sure the file exist!\n"
    logging.error(complaint)
    sys.exit(return_code.CONVERT_FAILED)
Example #17
0
def align_model(original_model_path: str, tnn_proto_path: str, tnn_model_path: str, input_file_path: str = None,
                refer_path: str = None, specify_input_args: str = None, is_tflite: bool = False, debug_mode: bool = False, align_batch: bool = False) -> bool:
    """
    Align an ONNX (or TFLite) model with a TNN model.
    Currently supported models: single input, single output; single input, multiple outputs.

    The input information of both models has the form::

        {
            "input name1": {"shape": [n, c, ...], "data_type": 0},
            "input name2": {"shape": [n, c, ...], "data_type": 0},
        }

    Input data is generated randomly unless ``input_file_path`` is given, and
    the reference output is produced by running the original model unless
    ``refer_path`` is given. Generated files are removed after the check
    unless ``debug_mode`` is set.

    :param original_model_path: path of the original model (ONNX, or TFLite when is_tflite is set)
    :param tnn_proto_path: path of the tnnproto file
    :param tnn_model_path: path of the tnnmodel file
    :param input_file_path: existing input data file; exits with ALIGN_FAILED if it does not exist
    :param refer_path: existing reference output file; exits with ALIGN_FAILED if it does not exist
    :param specify_input_args: input specification used to update the original input shapes
    :param is_tflite: selects the TFLite helpers instead of the ONNX ones
    :param debug_mode: generated files are kept unless this is exactly False
    :param align_batch: forwarded to run_tnn_model_check
    :return: True
    """
    logging.info("{}  align model (tflite or ONNX vs TNN),please wait a moment {}\n" .format("-" * 10, "-" * 10))

    checker.check_file_exist(tnn_proto_path)
    checker.check_file_exist(tnn_model_path)

    # get original input info
    read_original_info = get_input_shape_from_tflite if is_tflite else get_input_shape_from_onnx
    original_input_info = read_original_info(original_model_path)
    # get tnn input info
    tnn_input_info = get_input_shape_from_tnn(tnn_proto_path)

    # check input
    if specify_input_args is not None:
        update_original_input_shape(original_input_info,
                                    parse_specify_input_args(specify_input_args))
    verify_inputs = check_input_lite_info if is_tflite else check_input_info
    verify_inputs(original_input_info, tnn_input_info)

    if input_file_path is None:
        # generate data
        input_path = data.gene_random_data(original_input_info)
    elif os.path.exists(input_file_path):
        input_path = input_file_path
    else:
        logging.error("Invalid input_file_path")
        sys.exit(return_code.ALIGN_FAILED)

    if refer_path is None:
        # This spot compares the flag with == True, exactly as the original.
        run_reference = run_tflite if is_tflite == True else run_onnx  # noqa: E712
        reference_output_path = run_reference(original_model_path, input_path, original_input_info)
    elif os.path.exists(refer_path):
        reference_output_path = refer_path
    else:
        logging.error("Invalid refer_path")
        sys.exit(return_code.ALIGN_FAILED)

    logging.info("Run tnn model_check...")
    run_tnn_model_check(tnn_proto_path, tnn_model_path, input_path, reference_output_path, is_tflite, align_batch)

    if debug_mode is not False:
        return True
    # Only files generated by this function are cleaned up.
    if input_file_path is None and os.path.exists(input_path):
        data.clean_temp_data(os.path.dirname(input_path))
    if refer_path is None and os.path.exists(reference_output_path):
        data.clean_temp_data(reference_output_path)
    return True
Example #18
0
def convert(onnx_path,
            output_dir=None,
            version="v1.0",
            optimize=True,
            half=False,
            align=False,
            input_path=None,
            refer_path=None,
            input_names: str = None):
    """
    Run the command that converts an ONNX model to a TNN model.
    :parameter:
          onnx_path:    path of the ONNX file to convert
          output_dir:   directory for the generated TNN files; defaults to
                        the directory of onnx_path
          version:      version number of the converted model; None means
                        "v1.0"
          optimize:     whether the model should be optimized (default: yes)
          half:         whether to convert to an FP16 model to reduce its size
          align:        when exactly True, align the TNN model with the ONNX
                        model after the conversion
          input_path:   forwarded to align_model.align_model
          refer_path:   forwarded to align_model.align_model
          input_names:  input specification in the form "name[d0,d1,...]";
                        passed to onnx2tnn.py as "name:d0,d1,..."
    :return None; exits with return_code.CONVERT_FAILED when the conversion
            command returns a non-zero exit code
    """
    logging.info("Converter ONNX to TNN Model\n")

    checker.check_file_exist(onnx_path)

    ret, current_shape = checker.check_onnx_dim(onnx_path)

    if ret is False and current_shape is not None:
        # The dimension check failed: the caller has to supply shapes in the
        # bracket form, otherwise report the current shapes and stop.
        has_bracket_shapes = (input_names is not None
                              and "[" in input_names
                              and "]" in input_names)
        if not has_bracket_shapes:
            throw_exception(current_shape)

    if version is None:
        version = "v1.0"

    proto_suffix = '.tnnproto'
    model_suffix = '.tnnmodel'
    command = "python3 onnx2tnn.py " + onnx_path
    # Honor the version argument; it used to be ignored in favor of a
    # hard-coded "v1.0".
    command += " -version=" + version
    command += " -optimize=1" if optimize is True else " -optimize=0"
    command += " -half=1" if half is True else " -half=0"

    if output_dir is None:
        output_dir = os.path.dirname(onnx_path)
    checker.check_file_exist(output_dir)
    command += " -o " + output_dir

    new_input_names = None
    if input_names is not None:
        # "name[1,3,224,224]" -> "name:1,3,224,224"
        new_input_names = input_names.replace("[", ":").replace("]", "")
        command += " -input_shape " + new_input_names
    # Logged after all options are appended so the log shows the command
    # that is actually executed.
    logging.debug("The onnx2tnn command:" + command + "\n")

    work_dir = "../onnx2tnn/onnx-converter/"
    result = cmd.run(command, work_dir=work_dir)

    if result != 0:
        logging.error("Converter ONNX to TNN model failed!\n")
        sys.exit(return_code.CONVERT_FAILED)
    logging.info("Converter ONNX to TNN model succeed!\n")

    if align is True:
        tnn_stem = os.path.basename(onnx_path)[:-len('.onnx')]
        if optimize is True:
            tnn_stem += '.opt'
        tnn_proto_path = os.path.join(output_dir, tnn_stem + proto_suffix)
        tnn_model_path = os.path.join(output_dir, tnn_stem + model_suffix)

        # The converted input names are only passed when the caller gave any.
        extra_args = () if new_input_names is None else (new_input_names,)
        align_model.align_model(onnx_path, tnn_proto_path, tnn_model_path,
                                input_path, refer_path, *extra_args)
Example #19
0
def convert(onnx_path,
            output_dir=None,
            version="v1.0",
            optimize=True,
            half=False,
            align='',
            align_batch=False,
            input_path=None,
            refer_path=None,
            input_names: str = None,
            is_ssd=False,
            debug_mode: bool = False):
    """
    Run the command that converts an ONNX model to a TNN model.
    :parameter:
          onnx_path:    path of the ONNX file to convert
          output_dir:   directory for the generated TNN files; defaults to
                        the directory of onnx_path
          version:      version number of the converted model; None means
                        "v1.0"
          optimize:     whether the model should be optimized (default: yes)
          half:         whether to convert to an FP16 model to reduce its size
          align:        'output' aligns the model outputs, 'all' calls
                        align_model.align_all; any other value skips the
                        alignment unless align_batch is True
          align_batch:  when True the output alignment runs as well; the flag
                        is forwarded to align_model.align_model
          input_path:   forwarded to the alignment functions
          refer_path:   forwarded to the alignment functions
          input_names:  input specification "name:d0,d1,..."; a bare shape
                        without name is given the first input's name
          is_ssd:       skips the ONNX dimension check when true
          debug_mode:   forwarded to align_model.align_model
    :return None; exits with return_code.CONVERT_FAILED when the conversion
            command returns a non-zero exit code
    """
    logging.info("Converter ONNX to TNN Model...\n")

    checker.check_file_exist(onnx_path)

    # The pre-check is best effort: an unexpected error is logged and the
    # conversion is still attempted. sys.exit() raised by the helpers is not
    # an Exception subclass and therefore still terminates the process.
    try:
        if not is_ssd:
            logging.info("Converter ONNX to TNN check_onnx_dim...\n")
            ret, current_shape = checker.check_onnx_dim(onnx_path)
            if ret is False and current_shape is not None:
                if input_names is None:
                    logging.info("Converter ONNX to TNN current_shape...\n")
                    throw_exception(current_shape)
            if input_names is not None:
                input_names = input_names.strip()
                if ":" not in input_names and " " not in input_names:
                    # A bare shape belongs to the first input of the model.
                    input_names = list(
                        current_shape.keys())[0] + ":" + input_names
                check_input_names(input_names, current_shape)
    except Exception as e:
        # Report through the logger like every other message of this module
        # instead of printing to stdout.
        logging.error("%s", e)
        logging.error(
            "check_onnx_dim failed, next stage of convertion may failed too\n")

    # A caller passing version=None must not crash the string building below.
    if version is None:
        version = "v1.0"

    proto_suffix = '.tnnproto'
    model_suffix = '.tnnmodel'
    command = "python3 onnx2tnn.py " + onnx_path
    command += " -version=" + version
    checker.check_file_exist(onnx_path)
    command += " -optimize=1" if optimize is True else " -optimize=0"
    command += " -half=1" if half is True else " -half=0"

    if output_dir is None:
        output_dir = os.path.dirname(onnx_path)
    checker.check_file_exist(output_dir)
    command += " -o " + output_dir

    if input_names is not None:
        command += " -input_shape " + input_names
    logging.debug("The onnx2tnn command:" + command + "\n")

    work_dir = "../onnx2tnn/onnx-converter/"
    result = cmd.run(command, work_dir=work_dir)

    if result != 0:
        logging.error("Converter ONNX to TNN model failed!\n")
        sys.exit(return_code.CONVERT_FAILED)
    logging.info("Converter ONNX to TNN model succeed!\n")

    tnn_stem = os.path.basename(onnx_path)[:-len('.onnx')]
    if optimize is True:
        tnn_stem += '.opt'
    tnn_proto_path = os.path.join(output_dir, tnn_stem + proto_suffix)
    tnn_model_path = os.path.join(output_dir, tnn_stem + model_suffix)

    if align == 'output' or align_batch is True:
        # The input specification is only passed when the caller gave one.
        extra_args = () if input_names is None else (input_names,)
        align_model.align_model(onnx_path,
                                tnn_proto_path,
                                tnn_model_path,
                                input_path,
                                refer_path,
                                *extra_args,
                                debug_mode=debug_mode,
                                align_batch=align_batch)
    elif align == 'all':
        # In this branch align == 'all' always holds, so the flag is True.
        align_model.align_all(onnx_path, tnn_proto_path, True,
                              input_names, input_path, refer_path)