def evaluate():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", help="model name", type=str, default='densenet-lite-lstm')
    parser.add_argument("--model-epoch", type=int, default=None, help="model epoch")
    parser.add_argument(
        "--gpu",
        help="Number of GPUs for training [Default 0, means using cpu]. "
        "Currently restricted to gpu <= 1: with gpu > 1 the predictions are wrong "
        "and differ from the gpu = 1 results; the cause has not been found yet.",
        type=int,
        default=0,
    )
    parser.add_argument(
        "-i",
        "--input-fp",
        default='test.txt',
        help="the file path with image names and labels",
    )
    parser.add_argument(
        "--image-prefix-dir",
        default='.',
        help="folder containing the images; the image paths recorded in the index file "
        "are relative to this folder",
    )
    parser.add_argument("--batch-size", type=int, default=128, help="batch size")
    parser.add_argument(
        "-v",
        "--verbose",
        action='store_true',
        help="whether to print details to screen",
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        default=False,
        help="the output directory which records the analysis results",
    )
    args = parser.parse_args()
    assert args.gpu <= 1

    context = gen_context(args.gpu)
    ocr = CnOcr(model_name=args.model_name, model_epoch=args.model_epoch, context=context)
    alphabet = ocr._alphabet
    fn_labels_list = read_input_file(args.input_fp)

    miss_cnt, redundant_cnt = Counter(), Counter()
    model_time_cost = 0.0
    start_idx = 0
    bad_cnt = 0
    badcases = []
    while start_idx < len(fn_labels_list):
        logger.info('start_idx: %d', start_idx)
        batch = fn_labels_list[start_idx:start_idx + args.batch_size]
        batch_img_fns = []
        batch_labels = []
        batch_imgs = []
        for fn, labels in batch:
            batch_labels.append(labels)
            img_fp = os.path.join(args.image_prefix_dir, fn)
            batch_img_fns.append(img_fp)
            img = mx.image.imread(img_fp, 1).asnumpy()
            batch_imgs.append(img)

        start_time = time.time()
        batch_preds = ocr.ocr_for_single_lines(batch_imgs)
        model_time_cost += time.time() - start_time

        for bad_info in compare_preds_to_reals(batch_preds, batch_labels, batch_img_fns, alphabet):
            if args.verbose:
                logger.info('\t'.join(bad_info))
            distance = Levenshtein.distance(bad_info[1], bad_info[2])
            bad_info.insert(0, distance)
            badcases.append(bad_info)
            miss_cnt.update(list(bad_info[-2]))
            redundant_cnt.update(list(bad_info[-1]))
            bad_cnt += 1

        start_idx += args.batch_size

    badcases.sort(key=itemgetter(0), reverse=True)

    output_dir = Path(args.output_dir)
    if not output_dir.exists():
        os.makedirs(output_dir)
    with open(output_dir / 'badcases.txt', 'w') as f:
        f.write(
            '\t'.join([
                'distance',
                'image_fp',
                'real_words',
                'pred_words',
                'miss_words',
                'redundant_words',
            ]) + '\n'
        )
        for bad_info in badcases:
            f.write('\t'.join(map(str, bad_info)) + '\n')
    with open(output_dir / 'miss_words_stat.txt', 'w') as f:
        for word, num in miss_cnt.most_common():
            f.write('\t'.join([word, str(num)]) + '\n')
    with open(output_dir / 'redundant_words_stat.txt', 'w') as f:
        for word, num in redundant_cnt.most_common():
            f.write('\t'.join([word, str(num)]) + '\n')

    logger.info(
        "number of total cases: %d, number of bad cases: %d, acc: %.4f, time cost per image: %f"
        % (
            len(fn_labels_list),
            bad_cnt,
            1.0 - bad_cnt / len(fn_labels_list),
            model_time_cost / len(fn_labels_list),
        )
    )
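# Example invocation (a sketch, not from the original script): the script name,
# index file, and image folder below are hypothetical; the index file is assumed
# to list one image per line together with its character labels, in whatever
# format read_input_file() expects.
#
#   python evaluate.py --model-name densenet-lite-lstm -i eval.txt \
#       --image-prefix-dir data/images --batch-size 128 -o eval_results -v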
def evaluate(
    model_name,
    pretrained_model_fp,
    context,
    eval_index_fp,
    img_folder,
    batch_size,
    output_dir,
    verbose,
):
    ocr = CnOcr(model_name=model_name, model_fp=pretrained_model_fp, context=context)
    fn_labels_list = read_input_file(eval_index_fp)

    miss_cnt, redundant_cnt = Counter(), Counter()
    total_time_cost = 0.0
    bad_cnt = 0
    badcases = []

    start_idx = 0
    while start_idx < len(fn_labels_list):
        logger.info('start_idx: %d', start_idx)
        batch = fn_labels_list[start_idx:start_idx + batch_size]
        img_fps = [os.path.join(img_folder, fn) for fn, _ in batch]
        reals = [labels for _, labels in batch]
        imgs = [read_img(img) for img in img_fps]

        start_time = time.time()
        outs = ocr.ocr_for_single_lines(imgs, batch_size=1)
        total_time_cost += time.time() - start_time

        preds = [out[0] for out in outs]
        for bad_info in compare_preds_to_reals(preds, reals, img_fps):
            if verbose:
                logger.info('\t'.join(bad_info))
            distance = Levenshtein.distance(bad_info[1], bad_info[2])
            bad_info.insert(0, distance)
            badcases.append(bad_info)
            miss_cnt.update(list(bad_info[-2]))
            redundant_cnt.update(list(bad_info[-1]))
            bad_cnt += 1

        start_idx += batch_size

    badcases.sort(key=itemgetter(0), reverse=True)

    output_dir = Path(output_dir)
    if not output_dir.exists():
        os.makedirs(output_dir)
    with open(output_dir / 'badcases.txt', 'w') as f:
        f.write(
            '\t'.join([
                'distance',
                'image_fp',
                'real_words',
                'pred_words',
                'miss_words',
                'redundant_words',
            ]) + '\n'
        )
        for bad_info in badcases:
            f.write('\t'.join(map(str, bad_info)) + '\n')
    with open(output_dir / 'miss_words_stat.txt', 'w') as f:
        for word, num in miss_cnt.most_common():
            f.write('\t'.join([word, str(num)]) + '\n')
    with open(output_dir / 'redundant_words_stat.txt', 'w') as f:
        for word, num in redundant_cnt.most_common():
            f.write('\t'.join([word, str(num)]) + '\n')

    logger.info(
        "number of total cases: %d, number of bad cases: %d, acc: %.4f, time cost per image: %f"
        % (
            len(fn_labels_list),
            bad_cnt,
            1.0 - bad_cnt / len(fn_labels_list),
            total_time_cost / len(fn_labels_list),
        )
    )
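# A minimal usage sketch (not part of the original script). The index file
# 'eval.txt', the image folder 'data/images', and the output directory
# 'eval_results' are hypothetical; context='cpu' assumes the string form of the
# context argument accepted by CnOcr.
def _example_evaluate():
    evaluate(
        model_name='densenet-lite-lstm',
        pretrained_model_fp=None,   # fall back to the released model
        context='cpu',
        eval_index_fp='eval.txt',
        img_folder='data/images',
        batch_size=128,
        output_dir='eval_results',
        verbose=True,
    )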
def main():
    '''
    DCMMC: A very critical flaw: results predicted on multi-gpu are misplaced!
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name", help="model name", type=str, default='conv-lite-fc'
    )
    parser.add_argument("--model_epoch", type=int, default=None, help="model epoch")
    parser.add_argument(
        '--dataset', type=str, help='location of hdf5 dataset'
    )
    parser.add_argument('--train_ratio', type=float, default=0.8, help='train ratio of the dataset')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')
    parser.add_argument('--gpus', type=int, default=0, help='number of gpus, 0 indicates cpu only')
    args = parser.parse_args()
    ocr = CnOcr(model_name=args.model_name, model_epoch=args.model_epoch, gpus=args.gpus)
    log_cp = 50
    dataset = h5py.File(args.dataset, 'r')
    batch_size = args.batch_size
    gold_train, gold_val = [], []
    res_train, res_val = [], []
    num_train = int(len(dataset) * args.train_ratio)
    num_train_batch = int(np.ceil(num_train / batch_size))
    num_val_batch = int(np.ceil((len(dataset) - num_train) / batch_size))
    logger.info(f'total num samples={len(dataset)}')
    logger.info(f'num_train_batch={num_train_batch}, num_val_batch={num_val_batch}')

    logger.info('start train dataset')
    s_t = time()
    for idx_batch in range(num_train_batch):
        data = []
        for idx in range(idx_batch * args.batch_size,
                         min((idx_batch + 1) * args.batch_size, num_train)):
            data.append(dataset[str(idx)]['img'][...])
            gold_train.append(str(dataset[str(idx)]['y'][...]))
        res = ocr.ocr_for_single_lines(data)
        res = [''.join(r) for r in res]
        res_train += res
        if idx_batch % log_cp == 0:
            log_str = f'batch [{idx_batch + 1}/{num_train_batch}]: '
            log_str += '{:.2f}s/batch'.format((time() - s_t) / (idx_batch + 1))
            logger.info(log_str)
        # if idx_batch >= 4:
        #     break

    logger.info('start val dataset')
    s_t = time()
    for idx_batch in range(num_val_batch):
        data = []
        for idx in range(idx_batch * args.batch_size + num_train,
                         min((idx_batch + 1) * args.batch_size + num_train, len(dataset))):
            data.append(dataset[str(idx)]['img'][...])
            gold_val.append(str(dataset[str(idx)]['y'][...]))
        res = ocr.ocr_for_single_lines(data)
        res = [''.join(r) for r in res]
        res_val += res
        if idx_batch % log_cp == 0:
            log_str = f'batch [{idx_batch + 1}/{num_val_batch}]: '
            log_str += '{:.2f}s/batch'.format((time() - s_t) / (idx_batch + 1))
            logger.info(log_str)
        # if idx_batch >= 4:
        #     break

    assert len(res_val) == len(gold_val)
    assert len(res_train) == len(gold_train)
    acc_train = sum([r == p for r, p in zip(res_train, gold_train)]) / len(res_train)
    acc_val = sum([r == p for r, p in zip(res_val, gold_val)]) / len(res_val)
    logger.info(f'acc_train={acc_train}, acc_val={acc_val}')
    dist_fn = lambda r, p: distance(r, p) / max(len(r), len(p))
    dist_train = sum([dist_fn(r, p) for r, p in zip(res_train, gold_train)]) / len(res_train)
    dist_val = sum([dist_fn(r, p) for r, p in zip(res_val, gold_val)]) / len(res_val)
    logger.info(f'dist_train={dist_train}, dist_val={dist_val}')

    logger.info(f'write to file. #train={len(res_train)}, #val={len(res_val)}')
    with open('train_no_bpe.tok.src', 'w') as f:
        f.writelines([r + '\n' for r in res_train])
    with open('train_no_bpe.tok.trg', 'w') as f:
        f.writelines([r + '\n' for r in gold_train])
    with open('dev_no_bpe.tok.src', 'w') as f:
        f.writelines([r + '\n' for r in res_val])
    with open('dev_no_bpe.tok.trg', 'w') as f:
        f.writelines([r + '\n' for r in gold_val])
    logger.info('write done.')
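# A minimal sketch (not part of the original script) of the HDF5 layout that
# main() expects: one group per sample, keyed by the sample index as a string,
# holding an 'img' dataset (the line image) and a 'y' dataset (the label text).
# The file name, image shape, and label value below are hypothetical.
def _example_build_hdf5(out_fp='lines.hdf5'):
    with h5py.File(out_fp, 'w') as f:
        for idx in range(4):
            grp = f.create_group(str(idx))
            # assumed shape: a single grayscale text-line image of height 32
            grp.create_dataset('img', data=np.zeros((32, 280), dtype=np.uint8))
            grp.create_dataset('y', data='示例文本')
    return out_fp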
class PickStuNumber:
    def __init__(self, path: str, show_img: bool = False):
        self.__ext = {'jpg', 'jpeg'}
        self.__ocr = CnOcr(model_name='densenet-lite-gru',
                           cand_alphabet=string.digits,
                           name=path)
        self.__std = CnStd(name=path)
        self.__info_dict = {}
        self.__dup_name_dict = {}
        # Normalize the path first
        path = self.__format_path(path)
        # Decide what to do based on the kind of path passed in
        if os.path.isdir(path) or os.path.isfile(path):
            files = [self.__format_path(os.path.join(path, f))
                     for f in os.listdir(path)
                     if (os.path.isfile(os.path.join(path, f)) and self.__is_image(f))] \
                if os.path.isdir(path) \
                else [path]
            for file in tqdm(files):
                self.__handle_info(
                    file,
                    self.__ocr_number(
                        self.__std_number(self.__cutter(file, show_img))))
        else:
            print(f'Failed to read data: "{path}" is neither a file nor a directory')

    @staticmethod
    def __format_path(path: str):
        return os.path.abspath(path).replace('\\', '/')

    @staticmethod
    def __get_suffix(path: str) -> str:
        """
        Get the file suffix
        :param path: image path
        :return: the suffix of the file
        """
        return path.split('.')[-1]

    def __is_image(self, path: str) -> bool:
        return self.__get_suffix(path) in self.__ext

    @staticmethod
    def __cutter(path: str, show_img: bool = False) -> numpy.ndarray:
        """
        Crop the image
        :param path: image path
        :param show_img: whether to display intermediate images
        :return: the ndarray of the cropped image
        """
        print(path)
        # Read the image in grayscale
        origin_img = cv2.imread(path, 0)
        if show_img:
            # Make the window freely resizable
            # cv2.namedWindow('bin img', 0)
            cv2.imshow('origin img', origin_img)
        # Keep only the top half; the ratio is an empirical value
        origin_img = origin_img[:origin_img.shape[0] // 2]
        # Binarize
        _, origin_img = cv2.threshold(origin_img, 0, 255,
                                      cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        if show_img:
            # Make the window freely resizable
            # cv2.namedWindow('bin img', 0)
            cv2.imshow('bin img', origin_img)
        # Morphological transform, mainly to detect the red banner
        kernel = numpy.ones((15, 15), dtype=numpy.uint8)
        # img = cv2.erode(img, kernel=kernel, iterations=1)
        img = cv2.dilate(origin_img, kernel=kernel, iterations=2)
        # Find contours
        contours, _ = cv2.findContours(img, 1, 2)
        # Take the second largest contour, i.e. the red banner
        contours = sorted(contours, key=cv2.contourArea, reverse=True)
        if len(contours) > 1:
            # Bounding rectangle of the banner
            x, y, w, h = cv2.boundingRect(contours[1])
            # All the thresholds below are empirical values
            if w * h > 250000:
                # The student-number region to recognize
                # Top-left corner
                left_top_x = x
                left_top_y = y + h + 20
                # Bottom-right corner
                right_down_x = x + w
                right_down_y = y + h + 190
                img = origin_img[left_top_y:right_down_y, left_top_x:right_down_x]
            else:
                img = origin_img[120:]
        else:
            img = origin_img[120:]
        # Post-process the crop to make recognition easier
        kernel = numpy.ones((2, 2), dtype=numpy.uint8)
        # Erode to thicken the strokes
        img = cv2.erode(img, kernel=kernel, iterations=1)
        # Map back to RGB
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        if show_img:
            # Make the window freely resizable
            # cv2.namedWindow('final img', 0)
            cv2.imshow('final img', img)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        return img

    def __ocr_number(self, img_list: List[numpy.ndarray]):
        """
        Recognize the digits
        :param img_list: list of cropped digit-line images
        :return: recognition results
        """
        return self.__ocr.ocr_for_single_lines(img_list)

    def __std_number(self, img: numpy.ndarray):
        """
        Locate the digit regions
        :param img: input image
        :return: cropped images detected by CnStd
        """
        return [i['cropped_img'] for i in self.__std.detect(img)]

    @staticmethod
    def __handle_result_list(result_list: List[List[str]]) -> [str, bool]:
        """
        Post-process the result list
        :param result_list: result list
        :return: the result string and whether it is valid
        """
        result = result_list[0]
        if len(result) < 12 and len(result_list) > 1:
            for i in result_list:
                if len(i) >= 12:
                    result = i
        result = ''.join(result[:12] if len(result) >= 12 else result)
        print(result, re.match(r'\d{12}', result) is not None)
        return result, re.match(r'\d{12}', result) is not None

    def __handle_dup_name(self, name, path):
        dup_keys = self.__dup_name_dict.get(name)
        # If the name has been seen before, it is a duplicate
        if dup_keys:
            # Mark the first occurrence as duplicated; only needed on the first repeat
            if 1 == len(dup_keys):
                self.__info_dict[dup_keys[0]]['dup'] = True
            # Also record the current path
            self.__dup_name_dict[name].append(path)
            return True
        else:
            self.__dup_name_dict[name] = [path]
            return False

    def __handle_info(self, key, value):
        """
        Handle one record
        :param key: image path
        :param value: OCR result list for this image
        """
        name, is_legal = self.__handle_result_list(value)
        self.__info_dict[key] = {
            'name': name,
            'suffix': self.__get_suffix(key),
            'legal': is_legal,
            'dup': self.__handle_dup_name(name, key)
        }

    def print_info(self):
        """
        Print the image info
        :return: self
        """
        beeprint.pp(self.__info_dict)
        return self

    def print_dup(self):
        """
        Print the duplicated-image info
        :return: self
        """
        beeprint.pp(self.__dup_name_dict)
        return self

    def write_out(self,
                  path: str = '.',
                  out_path_suc: str = 'output_suc',
                  out_path_dup: str = 'output_dup',
                  out_path_fail: str = 'output_fail'):
        """
        Copy the renamed images into output folders
        :param path: root folder
        :param out_path_suc: folder for valid, non-duplicated images
        :param out_path_dup: folder for valid but duplicated images
        :param out_path_fail: folder for all other images
        :return: self
        """
        # Normalize the path
        path = self.__format_path(path)
        if os.path.isdir(path):
            # Build the output paths
            suc = os.path.join(path, out_path_suc)
            fail = os.path.join(path, out_path_fail)
            dup = os.path.join(path, out_path_dup)
            # Create the output folders
            not os.path.exists(suc) and os.makedirs(suc)
            not os.path.exists(fail) and os.makedirs(fail)
            not os.path.exists(dup) and os.makedirs(dup)
            # Copy each image into the corresponding folder
            for key, value in self.__info_dict.items():
                # Valid and not duplicated
                if value.get('legal') is True and value.get('dup') is False:
                    copyfile(
                        key,
                        os.path.join(
                            suc, f'{value.get("name")}.{value.get("suffix")}'))
                # Valid but duplicated
                elif value.get('legal') is True and value.get('dup') is True:
                    index = self.__dup_name_dict[value.get("name")].index(key)
                    copyfile(
                        key,
                        os.path.join(
                            dup,
                            f'{value.get("name")}.{index}.{value.get("suffix")}'
                        ))
                else:
                    copyfile(
                        key,
                        os.path.join(
                            fail,
                            f'{value.get("name")}.{value.get("suffix")}'
                            or os.path.split(key)[1]))
        else:
            print(f'"{path}" is not a valid path!')
        return self
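# A minimal usage sketch (not part of the original class; './photos' is a
# hypothetical folder of images): recognize the student number on every image,
# print the per-image results and the duplicates, then copy the renamed files
# into 'output_suc' / 'output_dup' / 'output_fail' under the current directory.
def _example_pick_stu_number():
    PickStuNumber('./photos').print_info().print_dup().write_out('.')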