def check_input_info(onnx_input_info: dict, tnn_input_info: dict):
    """Compare every ONNX input entry against its TNN counterpart.

    The number of entries in both dicts is compared first; a mismatch is
    reported through ``print_not_align_message``. Then, for each ONNX input
    name, the name is translated with ``convert_name.onnx_name2tnn_name`` and
    the two entries are handed to ``check_shape_info``. A match is logged at
    INFO level; a mismatch is logged at ERROR level and reported through
    ``print_not_align_message``.

    Args:
        onnx_input_info: mapping of ONNX input name to its input info.
        tnn_input_info: mapping of TNN input name to its input info.

    Raises:
        KeyError: if a translated name is not a key of ``tnn_input_info``.

    NOTE(review): ``print_not_align_message`` is defined outside this block;
    whether it terminates the process is not visible here -- confirm.
    NOTE(review): another ``check_input_info`` definition follows in this
    file; in Python the later definition is the one that stays bound.
    """
    if len(onnx_input_info) != len(tnn_input_info):
        print_not_align_message("onnx input size != tnn input size")

    mismatch_template = "The {}'s shape not equal! the onnx shape:{}, tnn shape: {}\n"
    for name in onnx_input_info:
        onnx_info = onnx_input_info[name]
        tnn_info = tnn_input_info[convert_name.onnx_name2tnn_name(name)]

        # The explicit comparison with True is kept as-is: the return type of
        # check_shape_info is not visible from this block.
        aligned = check_shape_info(onnx_info, tnn_info) == True
        if not aligned:
            logging.error("input is not align 194\n")
            print_not_align_message(
                mismatch_template.format(name, str(onnx_info), str(tnn_info)))
            continue

        logging.info(name + ": input shape of onnx and tnn is aligned!\n")
def check_input_info(onnx_input_info: dict, tnn_input_info: dict):
    """Check that each ONNX input shape equals the matching TNN input shape.

    The number of entries in both dicts is compared first; a mismatch is
    reported through ``print_not_align_message``. Each ONNX name is translated
    with ``convert_name.onnx_name2tnn_name`` and the two shapes are compared
    with ``!=``; a difference is reported through ``print_not_align_message``.
    A final INFO line is logged once all inputs have been visited.

    Side effect: when the first element of an ONNX shape is not exactly of
    type ``int``, it is overwritten with ``1`` *in place*, so the caller's
    ``onnx_input_info`` is modified (presumably to pin a symbolic batch
    dimension -- TODO confirm).

    Args:
        onnx_input_info: mapping of ONNX input name to a mutable shape sequence.
        tnn_input_info: mapping of TNN input name to a shape sequence.

    Raises:
        KeyError: if a translated name is not a key of ``tnn_input_info``.

    NOTE(review): ``print_not_align_message`` is defined outside this block;
    whether it terminates the process is not visible here -- confirm.
    """
    if len(onnx_input_info) != len(tnn_input_info):
        print_not_align_message("onnx input size != tnn input size")

    mismatch_template = "The {}'s shape not equal! the onnx shape:{}, tnn shape: {}\n"
    for input_name, onnx_dims in onnx_input_info.items():
        tnn_dims = tnn_input_info[convert_name.onnx_name2tnn_name(input_name)]

        # Exact type test (not isinstance) is intentional to keep behaviour:
        # anything that is not a plain int in the leading slot becomes 1.
        if type(onnx_dims[0]) is not int:
            onnx_dims[0] = 1

        if tnn_dims != onnx_dims:
            print_not_align_message(
                mismatch_template.format(input_name, str(onnx_dims), str(tnn_dims)))

    logging.info("Check onnx input shape and tnn input shape align!\n")
def check_input_lite_info(onnx_input_info: dict, tnn_input_info: dict):
    """Compare every TFLite-side input entry against its TNN counterpart.

    The number of entries in both dicts is compared first; a mismatch is
    reported through ``print_not_align_message``. For each name in
    ``onnx_input_info`` the name is translated with
    ``convert_name.onnx_name2tnn_name`` and both entries are passed to
    ``check_shape_info``. A truthy result logs the "align" INFO line; a falsy
    result logs "input is not align" and reports the mismatch through
    ``print_not_align_message``. The "align" INFO line is logged once more
    after all inputs have been visited.

    Args:
        onnx_input_info: mapping of source-model input name to its input info.
        tnn_input_info: mapping of TNN input name to its input info.

    Raises:
        KeyError: if a translated name is not a key of ``tnn_input_info``.

    NOTE(review): ``print_not_align_message`` is defined outside this block;
    whether it terminates the process is not visible here -- confirm.
    NOTE(review): another ``check_input_lite_info`` definition follows in this
    file; in Python the later definition is the one that stays bound.
    """
    if len(onnx_input_info) != len(tnn_input_info):
        print_not_align_message("tflite input size != tnn input size")

    mismatch_template = "The {}'s shape not equal! the onnx shape:{}, tnn shape: {}\n"
    for name in onnx_input_info:
        onnx_info = onnx_input_info[name]
        tnn_info = tnn_input_info[convert_name.onnx_name2tnn_name(name)]

        if not check_shape_info(onnx_info, tnn_info):
            logging.info("input is not align\n")
            print_not_align_message(
                mismatch_template.format(name, str(onnx_info), str(tnn_info)))
            continue

        logging.info("Check tflite input shape and tnn input shape align!\n")

    logging.info("Check tflite input shape and tnn input shape align!\n")
def check_input_lite_info(onnx_input_info: dict, tnn_input_info: dict):
    """Check TFLite-side input shapes against TNN shapes after axis reordering.

    The number of entries in both dicts is compared first; a mismatch is
    reported through ``print_not_align_message``. For each input, the source
    shape is rearranged to ``[s[0], s[3], s[1], s[2]]`` (presumably NHWC to
    NCHW -- TODO confirm) and compared with the TNN shape using ``!=``. A
    difference is logged and reported through ``print_not_align_message``.
    A final INFO line is logged once all inputs have been visited.

    Side effect: when the first element of a source shape is not exactly of
    type ``int``, it is overwritten with ``1`` *in place*, so the caller's
    ``onnx_input_info`` is modified.

    Args:
        onnx_input_info: mapping of source-model input name to a mutable shape
            sequence; indices 0 through 3 are read.
        tnn_input_info: mapping of TNN input name to a shape sequence.

    Raises:
        KeyError: if a translated name is not a key of ``tnn_input_info``.
        IndexError: if a source shape has fewer than four elements.

    NOTE(review): ``print_not_align_message`` is defined outside this block;
    whether it terminates the process is not visible here -- confirm.
    """
    if len(onnx_input_info) != len(tnn_input_info):
        print_not_align_message("tflite input size != tnn input size")

    mismatch_template = "The {}'s shape not equal! the onnx shape:{}, tnn shape: {}\n"
    for input_name, src_dims in onnx_input_info.items():
        tnn_dims = tnn_input_info[convert_name.onnx_name2tnn_name(input_name)]

        # Exact type test (not isinstance) is intentional to keep behaviour:
        # anything that is not a plain int in the leading slot becomes 1.
        if type(src_dims[0]) is not int:
            src_dims[0] = 1

        # Elements are read in the order 0, 3, 1, 2 and always form a list.
        reordered = [src_dims[0], src_dims[3], src_dims[1], src_dims[2]]

        if tnn_dims != reordered:
            logging.info("input is not algin 216\n")
            print_not_align_message(
                mismatch_template.format(input_name, str(src_dims), str(tnn_dims)))

    logging.info("Check tflite input shape and tnn input shape align!\n")
def parse_input_names(input_names: str) -> dict:
    """Parse a whitespace-separated input specification into a shape table.

    Each token has one of two forms:

    * ``d0,d1,...`` -- an unnamed input, stored under the key ``None``;
    * ``name:d0,d1,...`` -- a named input. Only the *last* ``:`` separates the
      name from the dimensions, so names such as ``input:0`` are accepted.

    A name that itself still contains ``:`` is translated with
    ``convert_name.onnx_name2tnn_name`` before being used as the key; other
    names are used unchanged.

    The result is built in a single pass. The previous implementation renamed
    keys while iterating over the same dict (unsafe in Python) and evaluated
    ``":" not in name`` with ``name`` being ``None`` for unnamed inputs, which
    raised ``TypeError``.

    Args:
        input_names: the specification string, e.g. ``"in:1,3,224,224"``.
            Tokens may be separated by any run of whitespace.

    Returns:
        A dict mapping each input name (or ``None``) to
        ``{'shape': [int, ...], 'data_type': 0}``, in token order. When two
        tokens resolve to the same key, the later one wins.

    Raises:
        ValueError: if a dimension is not a valid integer literal.
    """
    input_info = {}
    # split() without an argument skips empty tokens produced by repeated or
    # leading/trailing whitespace, which would otherwise fail in int('').
    for token in input_names.split():
        # rpartition keeps any earlier ':' inside the name (e.g. "input:0").
        name, sep, dims = token.rpartition(':')
        shape = [int(dim) for dim in dims.split(',')]

        if not sep:
            key = None  # no name given: unnamed input
        elif ':' in name:
            key = convert_name.onnx_name2tnn_name(name)
        else:
            key = name

        input_info[key] = {'shape': shape, 'data_type': 0}

    return input_info