Beispiel #1
0
def check_input_info(onnx_input_info: dict, tnn_input_info: dict):
    """Check that every ONNX model input has a TNN input with a matching shape.

    Each ONNX input name is mapped to its TNN counterpart with
    ``convert_name.onnx_name2tnn_name`` and the two shape descriptions are
    compared with ``check_shape_info``. Every problem found (different input
    count, TNN input missing, shapes not equal) is reported through
    ``print_not_align_message``.

    Args:
        onnx_input_info: Mapping from ONNX input name to its shape info.
        tnn_input_info: Mapping from TNN input name to its shape info.
    """
    if len(onnx_input_info) != len(tnn_input_info):
        print_not_align_message("onnx input size != tnn input size")
    for name, onnx_info in onnx_input_info.items():
        tnn_name = convert_name.onnx_name2tnn_name(name)
        # Report a name that is absent on the TNN side as a misalignment
        # instead of letting it escape as a bare KeyError.
        if tnn_name not in tnn_input_info:
            logging.error("input is not align\n")
            print_not_align_message(
                "The {}'s tnn input {} is not found! the onnx shape:{}\n".format(name, tnn_name,
                                                                                 str(onnx_info)))
            continue
        tnn_info = tnn_input_info[tnn_name]
        if check_shape_info(onnx_info, tnn_info):
            logging.info(name + ": input shape of onnx and tnn is aligned!\n")
        else:
            logging.error("input is not align\n")
            print_not_align_message(
                "The {}'s shape not equal! the onnx shape:{}, tnn shape: {}\n".format(name, str(onnx_info),
                                                                                      str(tnn_info)))
Beispiel #2
0
def check_input_info(onnx_input_info: dict, tnn_input_info: dict):
    """Check that ONNX input shapes equal the corresponding TNN input shapes.

    Each ONNX input name is mapped to its TNN counterpart with
    ``convert_name.onnx_name2tnn_name`` and the two shapes are compared with
    ``!=``. Every problem found (different input count, TNN input missing,
    shapes not equal) is reported through ``print_not_align_message``. The
    success message is logged only when no problem was found.

    Args:
        onnx_input_info: Mapping from ONNX input name to its shape, a mutable
            sequence (presumably a list of ints — confirm against caller).
        tnn_input_info: Mapping from TNN input name to its shape.

    Side effects:
        If the first dimension of an ONNX shape is not an ``int``, it is set
        to 1 *in place*, so the caller's ``onnx_input_info`` is modified.
    """
    aligned = len(onnx_input_info) == len(tnn_input_info)
    if not aligned:
        print_not_align_message("onnx input size != tnn input size")
    for name, onnx_shape in onnx_input_info.items():
        tnn_name = convert_name.onnx_name2tnn_name(name)
        # Report a name that is absent on the TNN side as a misalignment
        # instead of letting it escape as a bare KeyError.
        if tnn_name not in tnn_input_info:
            aligned = False
            print_not_align_message(
                "The {}'s tnn input {} is not found! the onnx shape:{}\n".
                format(name, tnn_name, str(onnx_shape)))
            continue
        tnn_shape = tnn_input_info[tnn_name]
        # A non-int leading dimension (presumably a symbolic/dynamic batch
        # size — TODO confirm) is pinned to 1 before comparing. This writes
        # into the caller's list on purpose-or-not; callers may rely on the
        # pinned value, so do not switch to a copy without checking them.
        # The emptiness guard avoids an IndexError for rank-0 shapes.
        if onnx_shape and type(onnx_shape[0]) is not int:
            onnx_shape[0] = 1
        if tnn_shape != onnx_shape:
            aligned = False
            print_not_align_message(
                "The {}'s shape not equal! the onnx shape:{}, tnn shape: {}\n".
                format(name, str(onnx_shape), str(tnn_shape)))

    if aligned:
        logging.info("Check onnx input shape and tnn input shape align!\n")
Beispiel #3
0
def check_input_lite_info(onnx_input_info: dict, tnn_input_info: dict):
    """Check that every TFLite model input has a TNN input with a matching shape.

    Each source input name is mapped to its TNN counterpart with
    ``convert_name.onnx_name2tnn_name`` and the two shape descriptions are
    compared with ``check_shape_info``. Every problem found (different input
    count, TNN input missing, shapes not equal) is reported through
    ``print_not_align_message``. The overall success message is logged once,
    and only when no problem was found.

    Args:
        onnx_input_info: Mapping from source (TFLite) input name to its shape
            info; the parameter name is kept for interface compatibility.
        tnn_input_info: Mapping from TNN input name to its shape info.
    """
    aligned = len(onnx_input_info) == len(tnn_input_info)
    if not aligned:
        print_not_align_message("tflite input size != tnn input size")
    for name, onnx_info in onnx_input_info.items():
        tnn_name = convert_name.onnx_name2tnn_name(name)
        # Report a name that is absent on the TNN side as a misalignment
        # instead of letting it escape as a bare KeyError.
        if tnn_name not in tnn_input_info:
            aligned = False
            logging.error("input is not align\n")
            print_not_align_message(
                "The {}'s tnn input {} is not found! the onnx shape:{}\n".format(name, tnn_name,
                                                                                 str(onnx_info)))
            continue
        tnn_info = tnn_input_info[tnn_name]
        if check_shape_info(onnx_info, tnn_info):
            logging.info(name + ": input shape of tflite and tnn is aligned!\n")
        else:
            aligned = False
            logging.error("input is not align\n")
            print_not_align_message(
                "The {}'s shape not equal! the onnx shape:{}, tnn shape: {}\n".format(name, str(onnx_info),
                                                                                      str(tnn_info)))
    if aligned:
        logging.info("Check tflite input shape and tnn input shape align!\n")
Beispiel #4
0
def check_input_lite_info(onnx_input_info: dict, tnn_input_info: dict):
    """Check TFLite input shapes against TNN input shapes after a layout permute.

    Each source shape is reordered from index order ``[0, 1, 2, 3]`` to
    ``[0, 3, 1, 2]`` (NHWC -> NCHW) before being compared with the TNN shape.
    Only 4-D shapes can be permuted this way; any other rank is reported as
    not aligned instead of crashing. Every problem found (different input
    count, TNN input missing, unsupported rank, shapes not equal) is reported
    through ``print_not_align_message``. The success message is logged only
    when no problem was found.

    Args:
        onnx_input_info: Mapping from source (TFLite) input name to its shape,
            a mutable sequence; the parameter name is kept for compatibility.
        tnn_input_info: Mapping from TNN input name to its shape.

    Side effects:
        If the first dimension of a source shape is not an ``int``, it is set
        to 1 *in place*, so the caller's ``onnx_input_info`` is modified.
    """
    aligned = len(onnx_input_info) == len(tnn_input_info)
    if not aligned:
        print_not_align_message("tflite input size != tnn input size")
    for name, onnx_shape in onnx_input_info.items():
        tnn_name = convert_name.onnx_name2tnn_name(name)
        # Report a name that is absent on the TNN side as a misalignment
        # instead of letting it escape as a bare KeyError.
        if tnn_name not in tnn_input_info:
            aligned = False
            logging.error("input is not align\n")
            print_not_align_message(
                "The {}'s tnn input {} is not found! the onnx shape:{}\n".
                format(name, tnn_name, str(onnx_shape)))
            continue
        tnn_shape = tnn_input_info[tnn_name]
        # A non-int leading dimension (presumably a dynamic batch size —
        # TODO confirm) is pinned to 1 in the caller's list; callers may rely
        # on that, so this stays an in-place write.
        if onnx_shape and type(onnx_shape[0]) is not int:
            onnx_shape[0] = 1
        # The NHWC -> NCHW permutation below needs exactly four dimensions;
        # previously lower ranks raised IndexError.
        if len(onnx_shape) != 4:
            aligned = False
            logging.error("input is not align\n")
            print_not_align_message(
                "The {}'s shape is not 4-D, can not convert it from NHWC to NCHW! "
                "the onnx shape:{}, tnn shape: {}\n".
                format(name, str(onnx_shape), str(tnn_shape)))
            continue
        nchw = [onnx_shape[0], onnx_shape[3], onnx_shape[1], onnx_shape[2]]
        if tnn_shape != nchw:
            aligned = False
            logging.error("input is not align\n")
            print_not_align_message(
                "The {}'s shape not equal! the onnx shape:{}, tnn shape: {}\n".
                format(name, str(onnx_shape), str(tnn_shape)))
    if aligned:
        logging.info("Check tflite input shape and tnn input shape align!\n")
Beispiel #5
0
def parse_input_names(input_names: str) -> dict:
    """Parse whitespace-separated input specs into a name -> shape-info dict.

    Each spec is either ``name:d0,d1,...`` or a bare ``d0,d1,...``:

    * For a named spec, the name is everything before the *last* ``:``, so
      names such as ``input:0`` are kept whole (``input:0:1,3`` -> name
      ``input:0``). A name that contains ``:`` is converted with
      ``convert_name.onnx_name2tnn_name``.
    * A bare spec is stored under the key ``None``; if several bare specs are
      given, the last one wins.

    Specs are separated by any run of whitespace, so repeated, leading or
    trailing spaces are tolerated and an empty string yields ``{}``.

    Args:
        input_names: The spec string, e.g. ``"input:0:1,3,224,224 mask:1,224"``.

    Returns:
        dict mapping name (or ``None``) to ``{'shape': [int, ...], 'data_type': 0}``.

    Raises:
        ValueError: If a dimension is not an integer.
    """
    input_info = {}
    # A single pass builds the final keys directly. The former second pass
    # renamed keys while iterating the same dict (RuntimeError) and evaluated
    # `":" in None` for bare specs (TypeError).
    for spec in input_names.split():
        name, sep, dims = spec.rpartition(':')
        shape = list(map(int, dims.split(',')))
        if not sep:
            key = None
        elif ':' in name:
            key = convert_name.onnx_name2tnn_name(name)
        else:
            key = name
        input_info[key] = {'shape': shape, 'data_type': 0}

    return input_info