def gen_random_input(model_output_dir,
                     input_nodes,
                     input_shapes,
                     input_files,
                     input_ranges,
                     input_data_types,
                     input_file_name="model_input"):
    """Generate random model inputs, then overwrite them with given files.

    NOTE(review): this module defines ``gen_random_input`` twice; a later
    definition rebinds this name at import time, so this one is shadowed.

    Args:
        model_output_dir: Directory the per-input files are written into.
        input_nodes: Sequence of input node names (joined with "," for
            ``generate_input_data``).
        input_shapes: Per-node shape strings (joined with ":").
        input_files: ``None`` (no files), a single path/URL, or a list/tuple
            with one entry per input node.  ``None`` entries keep the random
            data for that node.
        input_ranges: Per-node range strings (joined with ":").
        input_data_types: Per-node data type strings (joined with ",").
        input_file_name: Base name passed to ``common.formatted_file_name``
            and used as the output prefix for ``generate_input_data``.

    Raises:
        Exception: if files are given but their count differs from the
            number of input nodes.
    """
    # Remove any existing per-input files before regenerating them.
    for input_name in input_nodes:
        formatted_name = common.formatted_file_name(
            input_file_name, input_name)
        existing_path = "%s/%s" % (model_output_dir, formatted_name)
        if os.path.exists(existing_path):
            sh.rm(existing_path)

    generate_input_data("%s/%s" % (model_output_dir, input_file_name),
                        ",".join(input_nodes),
                        ":".join(input_shapes),
                        ":".join(input_ranges),
                        ",".join(input_data_types))

    # ``None`` means "no files supplied".  Wrapping it as ``[None]`` made the
    # length check below fail whenever the model had several inputs.
    if input_files is None:
        input_file_list = []
    elif isinstance(input_files, (list, tuple)):
        input_file_list = list(input_files)
    else:
        input_file_list = [input_files]
    if not input_file_list:
        return

    if isinstance(input_nodes, (list, tuple)):
        input_name_list = list(input_nodes)
    else:
        input_name_list = [input_nodes]
    if len(input_file_list) != len(input_name_list):
        raise Exception('If input_files set, the input files should '
                        'match the input names.')

    for input_file, input_name in zip(input_file_list, input_name_list):
        if input_file is None:
            continue  # keep the randomly generated data for this node
        dst_input_file = model_output_dir + '/' + \
            common.formatted_file_name(input_file_name, input_name)
        if input_file.startswith(("http://", "https://")):
            six.moves.urllib.request.urlretrieve(
                input_file, dst_input_file)
        else:
            sh.cp("-f", input_file, dst_input_file)
def gen_random_input(model_output_dir,
                     input_nodes,
                     input_shapes,
                     input_files,
                     input_ranges,
                     input_data_types,
                     input_file_name="model_input"):
    """Generate random model inputs, then overwrite them with given files.

    NOTE(review): this module defines ``gen_random_input`` twice; this later
    definition is the one bound to the name at import time.

    Args:
        model_output_dir: Directory the per-input files are written into.
        input_nodes: Sequence of input node names (joined with "," for
            ``generate_input_data``).
        input_shapes: Per-node shape strings (joined with ":").
        input_files: ``None`` (no files), a single path/URL, or a list/tuple
            with one entry per input node.  ``None`` entries keep the random
            data for that node.
        input_ranges: Per-node range strings (joined with ":").
        input_data_types: Per-node data type strings (joined with ",").
        input_file_name: Base name passed to ``common.formatted_file_name``
            and used as the output prefix for ``generate_input_data``.

    Raises:
        Exception: if files are given but their count differs from the
            number of input nodes.
    """
    # Remove any existing per-input files before regenerating them.
    for input_name in input_nodes:
        formatted_name = common.formatted_file_name(
            input_file_name, input_name)
        existing_path = "%s/%s" % (model_output_dir, formatted_name)
        if os.path.exists(existing_path):
            sh.rm(existing_path)

    generate_input_data("%s/%s" % (model_output_dir, input_file_name),
                        ",".join(input_nodes),
                        ":".join(input_shapes),
                        ":".join(input_ranges),
                        ",".join(input_data_types))

    # ``None`` means "no files supplied".  Wrapping it as ``[None]`` made the
    # length check below fail whenever the model had several inputs.
    if input_files is None:
        input_file_list = []
    elif isinstance(input_files, (list, tuple)):
        input_file_list = list(input_files)
    else:
        input_file_list = [input_files]
    if not input_file_list:
        return

    if isinstance(input_nodes, (list, tuple)):
        input_name_list = list(input_nodes)
    else:
        input_name_list = [input_nodes]
    if len(input_file_list) != len(input_name_list):
        raise Exception('If input_files set, the input files should '
                        'match the input names.')

    for input_file, input_name in zip(input_file_list, input_name_list):
        if input_file is None:
            continue  # keep the randomly generated data for this node
        dst_input_file = model_output_dir + '/' + \
            common.formatted_file_name(input_file_name, input_name)
        if input_file.startswith(("http://", "https://")):
            # ``urllib.urlretrieve`` only exists on Python 2; the six shim
            # resolves to ``urllib.request.urlretrieve`` on Python 3.
            from six.moves.urllib.request import urlretrieve
            urlretrieve(input_file, dst_input_file)
        else:
            sh.cp("-f", input_file, dst_input_file)
def gen_input(model_output_dir,
              input_nodes,
              input_shapes,
              input_files=None,
              input_ranges=None,
              input_data_types=None,
              input_data_map=None,
              input_file_name="model_input"):
    """Write one input data file per model input node.

    The data source is chosen in this priority order:
      1. ``input_data_map`` (node name -> array-like), converted to the
         node's declared dtype and shape-checked against ``input_shapes``;
      2. ``input_files`` (one path/URL per node; ``None`` entries skipped);
      3. random data produced by ``generate_input_data``.

    Args:
        model_output_dir: Directory the per-input files are written into.
        input_nodes: Sequence of input node names.
        input_shapes: Per-node shape strings (parsed with
            ``common.split_shape``; joined with ":" for random generation).
        input_files: ``None`` (no files), a single path/URL, or a list/tuple
            with one entry per input node.
        input_ranges: Per-node range strings; required for random data.
        input_data_types: Per-node data type strings; only 'float32' and
            'int32' are accepted when ``input_data_map`` is used.
        input_data_map: Optional mapping from node name to data (per the
            error messages, returned by a PrecisionValidator preprocessor).
        input_file_name: Base name passed to ``common.formatted_file_name``.
    """
    # Remove any existing per-input files before regenerating them.
    for input_name in input_nodes:
        formatted_name = common.formatted_file_name(
            input_file_name, input_name)
        existing_path = "%s/%s" % (model_output_dir, formatted_name)
        if os.path.exists(existing_path):
            sh.rm(existing_path)

    # ``None`` means "no files supplied".  It used to become ``[None]``,
    # which made the random-generation branch below unreachable with the
    # default arguments.
    if input_files is None:
        input_file_list = []
    elif isinstance(input_files, (list, tuple)):
        input_file_list = list(input_files)
    else:
        input_file_list = [input_files]

    if input_data_map:
        supported_dtypes = {'float32': np.float32, 'int32': np.int32}
        for i, input_name in enumerate(input_nodes):
            dst_input_file = model_output_dir + '/' + \
                common.formatted_file_name(input_file_name, input_name)
            common.mace_check(
                input_name in input_data_map, common.ModuleName.RUN,
                "The preprocessor API in PrecisionValidator"
                " script should return all inputs of model")
            data_type = input_data_types[i]
            common.mace_check(
                data_type in supported_dtypes, common.ModuleName.RUN,
                'Do not support input data type %s' % data_type)
            input_data = np.array(input_data_map[input_name],
                                  dtype=supported_dtypes[data_type])
            expected_shape = list(
                map(int, common.split_shape(input_shapes[i])))
            common.mace_check(
                expected_shape == list(input_data.shape),
                common.ModuleName.RUN,
                "The shape return from preprocessor API of"
                " PrecisionValidator script is not same with"
                " model deployment file. \n%s vs %s"
                % (str(input_shapes[i]), str(input_data.shape)))
            input_data.tofile(dst_input_file)
    elif input_file_list:
        if isinstance(input_nodes, (list, tuple)):
            input_name_list = list(input_nodes)
        else:
            input_name_list = [input_nodes]
        common.mace_check(
            len(input_file_list) == len(input_name_list),
            common.ModuleName.RUN,
            'If input_files set, the input files should '
            'match the input names.')
        for input_file, input_name in zip(input_file_list, input_name_list):
            if input_file is None:
                continue
            dst_input_file = model_output_dir + '/' + \
                common.formatted_file_name(input_file_name, input_name)
            if input_file.startswith(("http://", "https://")):
                six.moves.urllib.request.urlretrieve(
                    input_file, dst_input_file)
            else:
                sh.cp("-f", input_file, dst_input_file)
    else:
        # generate random input files
        common.mace_check(
            input_ranges is not None and input_data_types is not None,
            common.ModuleName.RUN,
            "input_ranges and input_data_types are required to generate"
            " random inputs")
        generate_input_data("%s/%s" % (model_output_dir, input_file_name),
                            ",".join(input_nodes),
                            ":".join(input_shapes),
                            ":".join(input_ranges),
                            ",".join(input_data_types))