def __init__(self, ctpn_weight_path, densenet_weight_path, dict_path,
             ctpn_config_path=None, densenet_config_path=None):
    """Load the character dictionary and build the CTPN and DenseNet models.

    :param ctpn_weight_path: path to the CTPN model weight file
    :param densenet_weight_path: path to the DenseNet model weight file
    :param dict_path: path to the character dictionary file
    :param ctpn_config_path: path to the CTPN model config file; when
        ``None`` the CTPN constructor defaults are used
    :param densenet_config_path: path to the DenseNet model config file;
        when ``None`` only ``num_classes`` (dictionary size) is supplied
    """
    self.id_to_char = load_dict(dict_path, encoding="utf-8")

    # Build the CTPN model. With a config file the weight path always
    # overrides whatever the file holds. Without one, the weight path is
    # still forwarded (previously it was silently dropped); it is only
    # omitted when the caller passed None, so constructor defaults apply.
    if ctpn_config_path is not None:
        ctpn_kwargs = CTPN.load_config(ctpn_config_path)
        ctpn_kwargs["weight_path"] = ctpn_weight_path
    elif ctpn_weight_path is not None:
        ctpn_kwargs = {"weight_path": ctpn_weight_path}
    else:
        ctpn_kwargs = {}
    self.ctpn = CTPN(**ctpn_kwargs)

    # Build the DenseNet model with the same rules as above.
    if densenet_config_path is not None:
        densenet_kwargs = DenseNetOCR.load_config(densenet_config_path)
        densenet_kwargs["weight_path"] = densenet_weight_path
    else:
        densenet_kwargs = {"num_classes": len(self.id_to_char)}
        if densenet_weight_path is not None:
            densenet_kwargs["weight_path"] = densenet_weight_path
    self.ocr = DenseNetOCR(**densenet_kwargs)
class TextDetectionApp:
    """Text detection and recognition pipeline.

    A CTPN model proposes text boxes on an image and a DenseNet OCR model
    reads the text inside each box. ``get_or_create`` hands out one shared
    instance for the whole process.
    """

    # Serialises lazy creation of the shared instance in ``get_or_create``.
    __lock = Lock()
    # Shared instance; stays None until a ``get_or_create`` call succeeds.
    __ocr = None

    def __init__(self, ctpn_weight_path, densenet_weight_path, dict_path,
                 ctpn_config_path=None, densenet_config_path=None):
        """Load the character dictionary and build both models.

        :param ctpn_weight_path: path to the CTPN model weight file
        :param densenet_weight_path: path to the DenseNet model weight file
        :param dict_path: path to the character dictionary file
        :param ctpn_config_path: path to the CTPN model config file; when
            ``None`` the CTPN constructor defaults are used
        :param densenet_config_path: path to the DenseNet model config file;
            when ``None`` only ``num_classes`` (dictionary size) is supplied
        """
        self.id_to_char = load_dict(dict_path, encoding="utf-8")

        # Build the CTPN model. With a config file the weight path always
        # overrides whatever the file holds. Without one, the weight path is
        # still forwarded (previously it was silently dropped); it is only
        # omitted when the caller passed None, so constructor defaults apply.
        if ctpn_config_path is not None:
            ctpn_kwargs = CTPN.load_config(ctpn_config_path)
            ctpn_kwargs["weight_path"] = ctpn_weight_path
        elif ctpn_weight_path is not None:
            ctpn_kwargs = {"weight_path": ctpn_weight_path}
        else:
            ctpn_kwargs = {}
        self.ctpn = CTPN(**ctpn_kwargs)

        # Build the DenseNet model with the same rules as above.
        if densenet_config_path is not None:
            densenet_kwargs = DenseNetOCR.load_config(densenet_config_path)
            densenet_kwargs["weight_path"] = densenet_weight_path
        else:
            densenet_kwargs = {"num_classes": len(self.id_to_char)}
            if densenet_weight_path is not None:
                densenet_kwargs["weight_path"] = densenet_weight_path
        self.ocr = DenseNetOCR(**densenet_kwargs)

    def detect(self, image, adjust=True, parallel=True):
        """Detect text boxes in an image and recognise the text in each.

        :param image: numpy array of shape (h, w, c), or a path to an image
        :param adjust: whether to adjust the detected boxes before cropping
        :param parallel: if true, recognise all boxes in one batched call;
            otherwise recognise them one at a time
        :return: ``(text_recs, texts)``; ``([], [])`` when no box is found.
            NOTE: in the sequential mode empty recognition results are
            dropped, so ``texts`` may be shorter than ``text_recs``.
        :raises ValueError: if ``image`` is a path that does not exist
        """
        if isinstance(image, str) and not os.path.exists(image):
            raise ValueError("The images path: " + image + " not exists!")

        # Collect every detection box together with the loaded image.
        text_recs, img = self.ctpn.predict(image, mode=2)
        # Explicit length check: truthiness would be ambiguous if the boxes
        # come back as an array rather than a list.
        if len(text_recs) == 0:
            return [], []

        text_recs = sort_box(text_recs)

        if parallel:
            imgs = clip_imgs_with_bboxes(text_recs, img, adjust)
            texts = self.ocr.predict_multi(imgs, id_to_char=self.id_to_char)
        else:
            texts = []
            for rec in text_recs:
                # The cropped image returned alongside the text is not needed.
                _, text = single_text_detect(rec, self.ocr, self.id_to_char, img, adjust)
                if text is not None and len(text) > 0:
                    texts.append(text)

        return text_recs, texts

    @staticmethod
    def get_or_create(ctpn_weight_path=default_ctpn_weight_path,
                      ctpn_config_path=default_ctpn_config_path,
                      densenet_weight_path=default_densenet_weight_path,
                      densenet_config_path=default_densenet_config_path,
                      dict_path=default_dict_path):
        """Return the shared ``TextDetectionApp``, creating it on first use.

        The arguments are only used by the call that actually builds the
        instance; once one exists they are ignored.

        :return: the shared instance, or ``None`` if construction failed
            (the error is printed and the next call will try again)
        """
        with TextDetectionApp.__lock:
            if TextDetectionApp.__ocr is None:
                try:
                    TextDetectionApp.__ocr = TextDetectionApp(
                        ctpn_weight_path=ctpn_weight_path,
                        ctpn_config_path=ctpn_config_path,
                        densenet_weight_path=densenet_weight_path,
                        densenet_config_path=densenet_config_path,
                        dict_path=dict_path)
                except Exception as e:
                    # Deliberately best-effort: report the failure and fall
                    # through so the caller receives None instead of a crash.
                    print(e)
            return TextDetectionApp.__ocr
# Optional initial weights; when left as None the config's own weight
# setting is not overridden (see the check further down).
parser.add_argument("--weights_file_path", help="模型初始权重文件位置", default=None)
# Path pattern for the weights saved during training.
# NOTE(review): "densent" in the default looks like a typo for "densenet";
# it is a runtime path, so it is left untouched here.
parser.add_argument("--save_weights_file_path", help="保存模型训练权重文件位置", default=r'model/weights-densent-{epoch:02d}.hdf5')

args = parser.parse_args()

# NOTE(review): 0.8 is presumably the GPU memory fraction handed to the
# session - confirm against utils.get_session.
K.set_session(utils.get_session(0.8))

batch_size = args.batch_size
encoding = "UTF-8"
initial_epoch = args.initial_epoch

# Load the model configuration file.
config = DenseNetOCR.load_config(args.config_file_path)
weights_file_path = args.weights_file_path
gpus = args.gpus
config['num_gpu'] = gpus

# Use the initial weights only when a path was supplied.
if weights_file_path is not None:
    config["weight_path"] = weights_file_path

# Locations of the training data and of the output weights.
images_dir = args.images_dir
dict_file_path = args.dict_file_path
train_labeled_file_path = args.train_file_path
test_labeled_file_path = args.test_file_path
save_weights_file_path = args.save_weights_file_path