def test_ocr_for_single_line(img_fp, expected):
    """Check ``ocr_for_single_line`` across several input forms of one image.

    The example image is fed as a file path, as the array returned by
    ``read_img`` with default arguments, as a color array
    (``gray=False``), as a 2-D gray array and as a 3-D ``(h, w, 1)`` array.
    Every prediction must score at least 0.8 against ``expected``.
    """
    ocr = CNOCR
    examples_dir = os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'examples'
    )
    img_fp = os.path.join(examples_dir, img_fp)

    def _check(image, leading_newline=False):
        # Predict, show the result and require a minimum score.
        pred = ocr.ocr_for_single_line(image)
        if leading_newline:
            print('\n')
        print_preds([pred])
        assert cal_score([pred], expected) >= 0.8

    _check(img_fp, leading_newline=True)
    _check(read_img(img_fp))

    color_img = read_img(img_fp, gray=False)
    _check(color_img)

    gray_2d = np.array(Image.fromarray(color_img).convert('L'))
    assert len(gray_2d.shape) == 2
    _check(gray_2d)

    gray_3d = np.expand_dims(gray_2d, axis=2)
    assert len(gray_3d.shape) == 3 and gray_3d.shape[2] == 1
    _check(gray_3d)
def test_ocr(img_fp, expected):
    """Check ``ocr`` on one example image given as a path, gray array and color array.

    Each variant's predictions must score at least 0.8 against ``expected``.
    """
    ocr = CNOCR
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    img_fp = os.path.join(project_root, 'examples', img_fp)

    # Loaders are called lazily, one per iteration, so each image is read
    # right before it is recognized.
    loaders = (
        lambda: img_fp,
        lambda: read_img(img_fp),
        lambda: read_img(img_fp, gray=False),
    )
    for idx, load in enumerate(loaders):
        pred = ocr.ocr(load())
        if idx == 0:
            print('\n')
        print_preds(pred)
        assert cal_score(pred, expected) >= 0.8
def test_ocr_for_single_lines(img_fp, expected):
    """Split an example image into text lines and OCR them as a batch.

    The batch is recognized twice: once with the line images exactly as
    ``line_split`` returned them, and once after converting each line with
    ``np.array``. Both runs must score at least 0.8 against ``expected``.
    """
    ocr = CNOCR
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    img_fp = os.path.join(project_root, 'examples', img_fp)

    img = read_img(img_fp)
    # Invert images with white text on a black background into black text
    # on a white background.
    if img.mean() < 145:
        img = 255 - img
    segments = line_split(np.squeeze(img, -1), blank=True)
    line_img_list = [segment for segment, _ in segments]

    first_pass = ocr.ocr_for_single_lines(line_img_list)
    print('\n')
    print_preds(first_pass)
    assert cal_score(first_pass, expected) >= 0.8

    line_img_list = list(map(np.array, line_img_list))
    second_pass = ocr.ocr_for_single_lines(line_img_list)
    print_preds(second_pass)
    assert cal_score(second_pass, expected) >= 0.8
def _prepare_img(
        self, img_fp: Union[str, Path, torch.Tensor, np.ndarray]) -> np.ndarray:
    """Load or normalize an image into a single-channel uint8 array.

    Args:
        img_fp (Union[str, Path, torch.Tensor, np.ndarray]): image file path,
            or image array with type torch.Tensor or np.ndarray, with shape
            [height, width] or [height, width, channel]. channel should be 1
            (gray image) or 3 (color image).

    Returns:
        np.ndarray: with shape (height, width, 1), dtype uint8, scale [0, 255]

    Raises:
        FileNotFoundError: if ``img_fp`` is a path that is not an existing file.
        ValueError: if the array is neither 2-D nor 3-D, or is 3-D with a
            channel count other than 1 or 3.
    """
    img = img_fp
    if isinstance(img_fp, (str, Path)):
        if not os.path.isfile(img_fp):
            raise FileNotFoundError(img_fp)
        img = read_img(img_fp)

    if isinstance(img, torch.Tensor):
        # detach() + cpu(): Tensor.numpy() refuses tensors that require grad
        # or that live on a non-CPU device.
        img = img.detach().cpu().numpy()

    # Cast before any PIL conversion: Image.fromarray cannot build an RGB
    # image from non-uint8 (e.g. float) 3-channel data.
    # NOTE(review): this is a plain cast, not a rescale, so float images in
    # [0, 1] collapse to 0/1; callers are expected to pass [0, 255] values.
    if img.dtype != np.dtype('uint8'):
        img = img.astype('uint8')

    if img.ndim == 2:
        img = np.expand_dims(img, -1)
    elif img.ndim == 3:
        if img.shape[2] == 3:  # color to gray
            img = np.expand_dims(
                np.array(Image.fromarray(img).convert('L')), -1)
        elif img.shape[2] != 1:
            raise ValueError(
                'only images with shape [height, width, 1] (gray images), '
                'or [height, width, 3] (RGB-formated color images) are supported'
            )
    else:
        raise ValueError(
            'only images with shape [height, width] or [height, width, channel] '
            'are supported, but got shape %s' % (img.shape,)
        )
    return img
def _write_eval_reports(output_dir, badcases, miss_cnt, redundant_cnt):
    """Write the bad-case table and miss/redundant word statistics as UTF-8 TSV files."""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    with open(output_dir / 'badcases.txt', 'w', encoding='utf-8') as f:
        f.write('\t'.join([
            'distance',
            'image_fp',
            'real_words',
            'pred_words',
            'miss_words',
            'redundant_words',
        ]) + '\n')
        for bad_info in badcases:
            f.write('\t'.join(map(str, bad_info)) + '\n')

    with open(output_dir / 'miss_words_stat.txt', 'w', encoding='utf-8') as f:
        for word, num in miss_cnt.most_common():
            f.write('\t'.join([word, str(num)]) + '\n')

    with open(output_dir / 'redundant_words_stat.txt', 'w', encoding='utf-8') as f:
        for word, num in redundant_cnt.most_common():
            f.write('\t'.join([word, str(num)]) + '\n')


def evaluate(
    model_name,
    pretrained_model_fp,
    context,
    eval_index_fp,
    img_folder,
    batch_size,
    output_dir,
    verbose,
):
    """Run OCR over a labelled evaluation set and write bad-case reports.

    Args:
        model_name: model name passed to ``CnOcr``.
        pretrained_model_fp: model file passed to ``CnOcr`` as ``model_fp``.
        context: device context passed to ``CnOcr``.
        eval_index_fp: index file read by ``read_input_file``; it yields
            ``(file_name, labels)`` pairs.
        img_folder: directory that the file names are relative to.
        batch_size: number of images loaded per iteration; must be >= 1.
        output_dir: directory receiving ``badcases.txt``,
            ``miss_words_stat.txt`` and ``redundant_words_stat.txt``
            (created if missing).
        verbose: if true, log every bad case as it is found.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # The original loop never advanced for non-positive batch sizes.
        raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))

    ocr = CnOcr(model_name=model_name, model_fp=pretrained_model_fp, context=context)
    fn_labels_list = read_input_file(eval_index_fp)
    n_total = len(fn_labels_list)

    miss_cnt, redundant_cnt = Counter(), Counter()
    total_time_cost = 0.0
    bad_cnt = 0
    badcases = []
    for start_idx in range(0, n_total, batch_size):
        logger.info('start_idx: %d', start_idx)
        batch = fn_labels_list[start_idx:start_idx + batch_size]
        img_fps = [os.path.join(img_folder, fn) for fn, _ in batch]
        reals = [labels for _, labels in batch]

        imgs = [read_img(img) for img in img_fps]
        start_time = time.time()
        outs = ocr.ocr_for_single_lines(imgs, batch_size=1)
        total_time_cost += time.time() - start_time

        preds = [out[0] for out in outs]
        for bad_info in compare_preds_to_reals(preds, reals, img_fps):
            if verbose:
                logger.info('\t'.join(bad_info))
            # bad_info is [image_fp, real, pred, miss, redundant] at this point;
            # prepend the edit distance between real and pred.
            distance = Levenshtein.distance(bad_info[1], bad_info[2])
            bad_info.insert(0, distance)
            badcases.append(bad_info)
            miss_cnt.update(list(bad_info[-2]))
            redundant_cnt.update(list(bad_info[-1]))
            bad_cnt += 1

    badcases.sort(key=itemgetter(0), reverse=True)
    _write_eval_reports(output_dir, badcases, miss_cnt, redundant_cnt)

    if n_total:
        acc = 1.0 - bad_cnt / n_total
        time_per_img = total_time_cost / n_total
    else:
        # Avoid ZeroDivisionError on an empty index file.
        logger.warning('no evaluation cases found in %s', eval_index_fp)
        acc = time_per_img = float('nan')
    logger.info(
        "number of total cases: %d, number of bad cases: %d, acc: %.4f, time cost per image: %f",
        n_total,
        bad_cnt,
        acc,
        time_per_img,
    )
def test_read_img():
    """Smoke-test ``read_img`` on a bundled example image and check its layout."""
    img_fp = EXAMPLE_DIR / '00010991.jpg'
    img = read_img(img_fp)
    print(img.shape, img)
    # Previously this test asserted nothing. The default read_img output is
    # expected to carry a trailing channel axis of size 1
    # (test_ocr_for_single_lines relies on this via np.squeeze(img, -1)).
    assert img.ndim == 3 and img.shape[2] == 1