def test_resnet50_v1b():
    model_name = 'resnet50_v1b'
    img_fp = os.path.join(example_dir, 'beauty2.jpg')
    std = CnStd(model_name)
    box_info_list = std.detect(img_fp)
    img = mx.image.imread(img_fp, 1)
    box_info_list2 = std.detect(img)
    assert len(box_info_list) == len(box_info_list2)
def test_multiple_instances(model_name):
    global INSTANCE_ID
    print('test multiple instances for model_name: %s' % model_name)
    img_fp = os.path.join(example_dir, 'beauty2.jpg')
    INSTANCE_ID += 1
    print('instance id: %d' % INSTANCE_ID)
    std1 = CnStd(model_name, name='instance-%d' % INSTANCE_ID)
    box_info_list = std1.detect(img_fp)
    INSTANCE_ID += 1
    std2 = CnStd(model_name, name='instance-%d' % INSTANCE_ID)
    box_info_list2 = std2.detect(img_fp)
    assert len(box_info_list) == len(box_info_list2)
def get_list(self):
    std = CnStd()
    ocr = CnOcr()
    ocr._model_name = 'conv-lite-fc'
    print(ocr._model_name)
    ocr_res2 = ocr.ocr(img_fp=self.filepath)
    box_info_list = std.detect(self.filepath, pse_threshold=0.7, pse_min_area=150,
                               context='gpu', height_border=0.10)
    image = Image.open(self.filepath)
    fontstyle = ImageFont.truetype('./simhei.ttf', 13, encoding='utf-8')
    draw = ImageDraw.Draw(image)
    for box_info in box_info_list:
        print(box_info)
        info_box = box_info['box']
        crp_img = box_info['cropped_img']
        ocr_res1 = ocr.ocr_for_single_line(crp_img)
        print('result: %s' % ''.join(ocr_res1))
        x1, y1 = info_box[0, 0], info_box[0, 1]
        x2, y2 = info_box[1, 0], info_box[1, 1]
        x3, y3 = info_box[2, 0], info_box[2, 1]
        x4, y4 = info_box[3, 0], info_box[3, 1]
        draw.polygon([(x1, y1), (x4, y4), (x3, y3), (x2, y2)], outline=(255, 0, 0))
        draw.text((x4, y4), str(ocr_res1), (200, 0, 0), font=fontstyle)
    image.show()
    print(ocr_res2)
    return box_info_list
def identify(image):
    std = CnStd(model_name='resnet50_v1b',
                root='C:/Users/HP/AppData/Roaming/Python/Python36/site-packages/cnstd')
    global cn_ocr
    box_info_list = std.detect(image)
    ocr_res = []
    for box_info in box_info_list:
        cropped_img = box_info['cropped_img']
        ocr_res.extend(cn_ocr.ocr_for_single_line(cropped_img))
    print('ocr result1: %s' % ''.join(ocr_res))
    return ocr_res
def identify(image):
    std = CnStd(model_name='resnet50_v1b',
                root='C:/Users/dell/AppData/Roaming/cnstd/0.1.0/mobilenetv3')
    global cn_ocr
    box_info_list = std.detect(image)
    ocr_res = []
    for box_info in box_info_list:
        cropped_img = box_info['cropped_img']
        ocr_res.extend(cn_ocr.ocr_for_single_line(cropped_img))
    print('ocr result1: %s' % ''.join(ocr_res))
    return ocr_res
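# Usage sketch for the identify() helpers above (a minimal sketch, not part of the
# original snippets): both variants rely on a module-level `cn_ocr` instance and
# accept anything CnStd.detect() takes, e.g. an image file path. The file name
# 'example.jpg' is only a placeholder.
from cnocr import CnOcr

cn_ocr = CnOcr()

if __name__ == '__main__':
    chars = identify('example.jpg')
    print(''.join(chars))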
async def Img2Text(img: Image, ocr_model: CnOcr, std: CnStd) -> str:
    url = img.url
    async with aiohttp.ClientSession() as session:
        async with session.get(url=url) as resp:
            img_content = await resp.read()
    img = IMG.open(BytesIO(img_content)).convert("RGB")
    img = numpy.array(img)
    box_info_list = std.detect(img)
    res = []
    for box_info in box_info_list:
        cropped_img = box_info['cropped_img']  # detected text box
        ocr_res = ocr_model.ocr_for_single_line(cropped_img)
        res.append([ocr_res])
    print(res)
    return "".join(await flat(res))
def extract_content(save_dir_path):
    std = CnStd()
    cn_ocr = CnOcr()
    base_path = os.path.abspath(os.path.dirname(__file__))
    pic_base_dir = os.path.join(base_path, "received")
    pic_path_list = (glob.glob(os.path.join(pic_base_dir, "*.png"))
                     + glob.glob(os.path.join(pic_base_dir, "*.jpg"))
                     + glob.glob(os.path.join(pic_base_dir, "*.jpeg")))
    workbook = xlwt.Workbook()
    for index, pic_path in enumerate(pic_path_list):
        sheet = workbook.add_sheet('sheet{}'.format(index), cell_overwrite_ok=True)
        box_info_list = std.detect(pic_path)
        for box_info in box_info_list:
            x_list, y_list = [], []
            for (x, y) in box_info['box']:
                x_list.append(x)
                y_list.append(y)
            top, bottom, left, right = min(y_list), max(y_list), min(x_list), max(x_list)
            top_row, bottom_row = int(top // 80), int(bottom // 80)
            left_column, right_column = int(left // 60), int(right // 60)
            cropped_img = box_info['cropped_img']  # detected text box
            ocr_res = ''.join(cn_ocr.ocr_for_single_line(cropped_img))
            try:
                logger.info(
                    "top_row:{}, bottom_row:{}, left_column:{}, right_column:{}, ocr_res:{}",
                    top_row, bottom_row, left_column, right_column, ocr_res,
                    feature="f-strings")
                sheet.write_merge(top_row, bottom_row, left_column, right_column, ocr_res)
            except Exception as e:
                print(e)
    xls_base_dir = os.path.join(base_path, save_dir_path)
    xls_path = os.path.join(xls_base_dir, "res.xls")
    workbook.save(xls_path)
def ocr():
    if 'file' not in request.files:
        return jsonify(code=-1, message='no file error'), 400
    file = request.files['file']
    _uuid = str(uuid.uuid1())
    file_name = '/tmp/ocr/' + _uuid
    file.save(file_name)
    ocr = CnOcr(name=_uuid)
    std = CnStd(name=_uuid)
    box_info_list = std.detect(file_name)
    lines = []
    for box_info in box_info_list:
        cropped_img = box_info['cropped_img']  # detected text box
        ocr_res = ocr.ocr_for_single_line(cropped_img)
        lines.append(''.join(ocr_res))
    return jsonify(code=0, message='ok', data=lines)
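# Client-side sketch for the Flask view above (an assumption-laden example, since
# the snippet does not show the route or port): it presumes the view is registered
# as POST /ocr on 127.0.0.1:5000 and that /tmp/ocr exists on the server.
import requests

with open('example.jpg', 'rb') as f:
    resp = requests.post('http://127.0.0.1:5000/ocr', files={'file': f})
print(resp.json())  # expected shape: {"code": 0, "message": "ok", "data": ["line1", ...]}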
def target_function(idx, ls, prefix):
    # make sure every GPU card gets used;
    # at submission time there is only one card, so all of these are set to 0
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    std = CnStd(context='gpu')
    cn_ocr = CnOcr(context='gpu')
    result = dict()
    for file_name in tqdm(ls):
        file_path = os.path.join(prefix, file_name)
        box_info_list = std.detect(file_path)
        output = ''
        for box_info in box_info_list:
            cropped_img = box_info['cropped_img']  # detected text box
            if cropped_img is not None:
                ocr_res = cn_ocr.ocr_for_single_line(cropped_img)
                output += ''.join(ocr_res)
                # print('ocr result: %s' % ''.join(ocr_res))
        output = output.replace(' ', '')
        output = re.sub("[^\u4e00-\u9fa5]", "", output)
        result[file_name] = output
    with open('./output_%d.json' % idx, 'w', encoding='utf-8') as w:
        w.write(json.dumps(result, ensure_ascii=False, indent=2))
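# Driver sketch for target_function() above (not part of the original code): it
# assumes the images live under IMG_DIR and that splitting the file list into one
# shard per process is acceptable; IMG_DIR and NUM_WORKERS are placeholders.
import os
from multiprocessing import Process

IMG_DIR = './images'
NUM_WORKERS = 4

if __name__ == '__main__':
    files = sorted(os.listdir(IMG_DIR))
    shards = [files[i::NUM_WORKERS] for i in range(NUM_WORKERS)]
    procs = [Process(target=target_function, args=(i, shard, IMG_DIR))
             for i, shard in enumerate(shards)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()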
class PickStuNumber:
    def __init__(self, path: str, show_img: bool = False):
        self.__ext = {'jpg', 'jpeg'}
        self.__ocr = CnOcr(model_name='densenet-lite-gru',
                           cand_alphabet=string.digits,
                           name=path)
        self.__std = CnStd(name=path)
        self.__info_dict = {}
        self.__dup_name_dict = {}
        # normalize the path first
        path = self.__format_path(path)
        # decide what to do based on the given path
        if os.path.isdir(path) or os.path.isfile(path):
            files = [self.__format_path(os.path.join(path, f))
                     for f in os.listdir(path)
                     if (os.path.isfile(os.path.join(path, f)) and self.__is_image(f))] \
                if os.path.isdir(path) \
                else [path]
            for file in tqdm(files):
                self.__handle_info(
                    file,
                    self.__ocr_number(
                        self.__std_number(self.__cutter(file, show_img))))
        else:
            print(f'获取数据错误,“{path}”既不是文件也不是文件夹')

    @staticmethod
    def __format_path(path: str):
        return os.path.abspath(path).replace('\\', '/')

    @staticmethod
    def __get_suffix(path: str) -> str:
        """
        Get the file suffix
        :param path: image path
        :return: the file suffix
        """
        return path.split('.')[-1]

    def __is_image(self, path: str) -> bool:
        return self.__get_suffix(path) in self.__ext

    @staticmethod
    def __cutter(path: str, show_img: bool = False) -> numpy.ndarray:
        """
        Crop the image
        :param path: image path
        :param show_img: whether to display intermediate images
        :return: the ndarray of the processed image
        """
        print(path)
        # read the image in grayscale mode
        origin_img = cv2.imread(path, 0)
        if show_img:
            # resizable window
            # cv2.namedWindow('bin img', 0)
            cv2.imshow('origin img', origin_img)
        # keep only the upper half; the values are empirical
        origin_img = origin_img[:origin_img.shape[0] // 2]
        # binarize
        _, origin_img = cv2.threshold(origin_img, 0, 255,
                                      cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        if show_img:
            # resizable window
            # cv2.namedWindow('bin img', 0)
            cv2.imshow('bin img', origin_img)
        # morphological transform, mainly to detect the red banner
        kernel = numpy.ones((15, 15), dtype=numpy.uint8)
        # img = cv2.erode(img, kernel=kernel, iterations=1)
        img = cv2.dilate(origin_img, kernel=kernel, iterations=2)
        # contour detection
        contours, _ = cv2.findContours(img, 1, 2)
        # find the second largest contour, i.e. the red banner
        contours = sorted(contours, key=cv2.contourArea, reverse=True)
        if len(contours) > 1:
            # get the bounding rectangle around the banner
            x, y, w, h = cv2.boundingRect(contours[1])
            # all thresholds below are empirical values
            if w * h > 250000:
                # the student-number area to recognize
                # top-left corner
                left_top_x = x
                left_top_y = y + h + 20
                # bottom-right corner
                right_down_x = x + w
                right_down_y = y + h + 190
                img = origin_img[left_top_y:right_down_y, left_top_x:right_down_x]
            else:
                img = origin_img[120:]
        else:
            img = origin_img[120:]
        # post-process the cropped image to help recognition
        kernel = numpy.ones((2, 2), dtype=numpy.uint8)
        # erode to thicken the strokes
        img = cv2.erode(img, kernel=kernel, iterations=1)
        # map back to RGB
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        if show_img:
            # resizable window
            # cv2.namedWindow('final img', 0)
            cv2.imshow('final img', img)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        return img

    def __ocr_number(self, img_list: List[numpy.ndarray]):
        """
        Recognize the digits
        :param img_list:
        :return:
        """
        return self.__ocr.ocr_for_single_lines(img_list)

    def __std_number(self, img: numpy.ndarray):
        """
        Locate the digits
        :param img:
        :return:
        """
        return [i['cropped_img'] for i in self.__std.detect(img)]

    @staticmethod
    def __handle_result_list(result_list: List[List[str]]) -> [str, bool]:
        """
        Process the result list
        :param result_list: result list
        :return: result, whether it is valid
        """
        result = result_list[0]
        if len(result) < 12 and len(result_list) > 1:
            for i in result_list:
                if len(i) >= 12:
                    result = i
        result = ''.join(result[:12] if len(result) >= 12 else result)
        print(result, re.match(r'\d{12}', result) is not None)
        return result, re.match(r'\d{12}', result) is not None

    def __handle_dup_name(self, name, path):
        dup_keys = self.__dup_name_dict.get(name)
        # if this name was seen before, it is a duplicate
        if dup_keys:
            # mark the first occurrence as duplicated; only needed on the first duplication
            if 1 == len(dup_keys):
                self.__info_dict[dup_keys[0]]['dup'] = True
            # also record this path
            self.__dup_name_dict[name].append(path)
            return True
        else:
            self.__dup_name_dict[name] = [path]
            return False

    def __handle_info(self, key, value):
        """
        Process a single record
        :param key:
        :param value:
        """
        name, is_legal = self.__handle_result_list(value)
        self.__info_dict[key] = {
            'name': name,
            'suffix': self.__get_suffix(key),
            'legal': is_legal,
            'dup': self.__handle_dup_name(name, key)
        }

    def print_info(self):
        """
        Print the image info
        :return:
        """
        beeprint.pp(self.__info_dict)
        return self

    def print_dup(self):
        """
        Print info about duplicated images
        :return:
        """
        beeprint.pp(self.__dup_name_dict)
        return self

    def write_out(self, path: str = '.',
                  out_path_suc: str = 'output_suc',
                  out_path_dup: str = 'output_dup',
                  out_path_fail: str = 'output_fail'):
        """
        Write the renamed images out to folders
        :param path: base folder path
        :param out_path_suc: folder for valid and unique images
        :param out_path_dup: folder for valid but duplicated images
        :param out_path_fail: folder for all other images
        :return: self
        """
        # normalize the path
        path = self.__format_path(path)
        if os.path.isdir(path):
            # build the output paths
            suc = os.path.join(path, out_path_suc)
            fail = os.path.join(path, out_path_fail)
            dup = os.path.join(path, out_path_dup)
            # create the result folders
            not os.path.exists(suc) and os.makedirs(suc)
            not os.path.exists(fail) and os.makedirs(fail)
            not os.path.exists(dup) and os.makedirs(dup)
            # copy each image into the matching folder
            for key, value in self.__info_dict.items():
                # valid and unique
                if value.get('legal') is True and value.get('dup') is False:
                    copyfile(
                        key,
                        os.path.join(
                            suc, f'{value.get("name")}.{value.get("suffix")}'))
                # valid but duplicated
                elif value.get('legal') is True and value.get('dup') is True:
                    index = self.__dup_name_dict[value.get("name")].index(key)
                    copyfile(
                        key,
                        os.path.join(
                            dup,
                            f'{value.get("name")}.{index}.{value.get("suffix")}'))
                else:
                    copyfile(
                        key,
                        os.path.join(
                            fail,
                            f'{value.get("name")}.{value.get("suffix")}'
                            or os.path.split(key)[1]))
        else:
            print(f'“{path}” 并非一个合法的路径!')
        return self
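# Usage sketch for PickStuNumber above (a minimal sketch, assuming a folder of
# photos; the directory name './photos' is only a placeholder): the public methods
# return self, so the calls can be chained.
if __name__ == '__main__':
    PickStuNumber('./photos').print_info().print_dup().write_out('.')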
class picture(QWidget):
    def __init__(self):
        super(picture, self).__init__()
        self.resize(350, 350)
        self.setWindowTitle("图片转文字")
        self.label = QLabel(self)
        # self.label.setText("显示图片")
        self.label.setScaledContents(True)
        self.label.setFixedSize(300, 200)
        self.label.move(25, 60)
        self.label.setStyleSheet(
            "QLabel{background:white;}"
            "QLabel{color:rgb(300,300,300,120);font-size:10px;font-weight:bold;font-family:宋体;}"
        )
        btn = QPushButton(self)
        btn.setText("打开图片")
        btn.move(135, 20)
        btn.clicked.connect(self.openimage)
        self.label_text = QLabel(self)
        self.label_text.setFixedSize(300, 30)
        self.label_text.move(25, 270)
        self.label_text.setTextInteractionFlags(Qt.TextSelectableByMouse)  # selectable so the text can be copied
        self.label_wait = QLabel(self)
        self.label_wait.setFixedSize(300, 30)
        self.label_wait.move(25, 300)
        # enable background auto-fill, otherwise the background cannot be shown
        self.label_wait.setAutoFillBackground(True)
        # # create a palette to set the background color
        # palette = QPalette()
        # palette.setColor(QPalette.Window, Qt.green)
        # # apply the background to the label
        # self.label_wait.setPalette(palette)
        # center the text
        self.label_wait.setAlignment(Qt.AlignCenter)
        self.label_wait.setText('tips:识别过程可能会卡住,需几秒到几十秒不等')
        self.std = CnStd()
        self.cn_ocr = CnOcr()

    def openimage(self):
        imgName, imgType = QFileDialog.getOpenFileName(
            self, "打开图片", "", "*.jpg;;*.png;;All Files(*)")
        if imgName and imgType:
            # create a palette to set the background color
            palette = QPalette()
            palette.setColor(QPalette.Window, Qt.green)
            # apply the background to the label
            self.label_wait.setPalette(palette)
            box_info_list = self.std.detect(imgName)
            result = ''
            for box_info in box_info_list:
                cropped_img = box_info['cropped_img']  # detected text box
                # cv2.imshow('1', cropped_img)
                # cv2.waitKey(0)
                ocr_res = self.cn_ocr.ocr_for_single_line(cropped_img)
                result += ''.join(ocr_res)
                # print('ocr result: %s' % ''.join(ocr_res))
            # print(result)
            self.label_text.setText(result)
            self.label_wait.setText('↑点击文字,ctrl+a全选、ctrl+c复制、ctrl+v粘贴')
            jpg = QtGui.QPixmap(imgName).scaled(self.label.width(), self.label.height())
            self.label.setPixmap(jpg)
            with open('history.txt', 'a', encoding='utf8') as f:
                f.write(result + '\n')
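# Launcher sketch for the picture widget above (the imports are not shown in the
# original snippet, so PyQt5 is an assumption): a QApplication event loop is needed
# before the widget can be shown.
import sys
from PyQt5.QtWidgets import QApplication

if __name__ == '__main__':
    app = QApplication(sys.argv)
    win = picture()
    win.show()
    sys.exit(app.exec_())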
from cnocr import CnOcr
import cnocr
from cnstd import CnStd

std = CnStd()
ocr = CnOcr()
box_info_list = std.detect('E:\\Work Place\\pocr\\pic\\2.png')
res = ocr.ocr('E:\\Work Place\\pocr\\pic\\1.png')
for box_info in box_info_list:
    crp_img = box_info['cropped_img']
    ocr_res = ocr.ocr_for_single_line(crp_img)
    print('result: %s' % ''.join(ocr_res))
    print('初始化失败,未找到预加载模型')
    return None


def has_chinese(word):
    for ch in word:
        if '\u4e00' <= ch <= '\u9fff':
            return True
    return False


if __name__ == '__main__':
    ocr = CnOcr()
    std = CnStd()
    path = '/images/life.jpg'
    img = cv2.imdecode(np.fromfile(path, dtype=np.uint8), -1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = np.array(img)
    # res = ocr.ocr(img)
    box_info_list = std.detect(img)
    res = ''
    for box_info in box_info_list:
        cropped_img = box_info['cropped_img']
        ocr_res = ocr.ocr_for_single_line(cropped_img)
        res += ''.join(ocr_res) + '\n'
        # print('ocr result: %s' % ''.join(ocr_res))
    print("Predicted Chars:", res)
# -*- coding: utf-8 -*-
# @Time : 2020/12/30 下午10:19
# @Author : 江啸栋19262010049
# @File : run.py
# @Software : PyCharm

# from cnstd import CnStd
# std = CnStd()
# box_info_list = std.detect('examples/taobao.jpg')

from cnstd import CnStd
from cnocr import CnOcr

std = CnStd()
cn_ocr = CnOcr()
box_info_list = std.detect('examples/taobao.jpg')
for box_info in box_info_list:
    cropped_img = box_info['cropped_img']
    ocr_res = cn_ocr.ocr_for_single_line(cropped_img)
    print('ocr result: %s' % ''.join(ocr_res))