예제 #1
0
    def __init__(self,
                 ctpn_weight_path,
                 densenet_weight_path,
                 dict_path,
                 ctpn_config_path=None,
                 densenet_config_path=None):
        """Load the character dictionary and build the CTPN and DenseNet models.

        :param ctpn_weight_path:     path to the CTPN model weights file
        :param densenet_weight_path: path to the DenseNet model weights file
        :param dict_path:            path to the character dictionary file
                                     (read with UTF-8 encoding)
        :param ctpn_config_path:     optional path to a CTPN model config file
        :param densenet_config_path: optional path to a DenseNet model config file

        With a config path, the model is built from that config and its
        ``weight_path`` entry is replaced by the matching ``*_weight_path``
        argument. Without a config path, the model is built with its default
        settings, and the matching ``*_weight_path`` (when not None) is still
        passed as ``weight_path`` so the caller's weights are not dropped.
        """
        self.id_to_char = load_dict(dict_path, encoding="utf-8")

        # Build the CTPN text detector.
        if ctpn_config_path is not None:
            ctpn_config = CTPN.load_config(ctpn_config_path)
            ctpn_config["weight_path"] = ctpn_weight_path
            self.ctpn = CTPN(**ctpn_config)
        elif ctpn_weight_path is not None:
            # CTPN accepts ``weight_path`` (the config branch passes it via the
            # config dict); forward the caller's path instead of ignoring it.
            self.ctpn = CTPN(weight_path=ctpn_weight_path)
        else:
            self.ctpn = CTPN()

        # Build the DenseNet recognizer; its output size is the dictionary size.
        if densenet_config_path is not None:
            densenet_config = DenseNetOCR.load_config(densenet_config_path)
            densenet_config["weight_path"] = densenet_weight_path
            self.ocr = DenseNetOCR(**densenet_config)
        elif densenet_weight_path is not None:
            # Same as above: DenseNetOCR accepts ``weight_path``.
            self.ocr = DenseNetOCR(num_classes=len(self.id_to_char),
                                   weight_path=densenet_weight_path)
        else:
            self.ocr = DenseNetOCR(num_classes=len(self.id_to_char))
예제 #2
0
class TextDetectionApp:
    """Text pipeline: CTPN detects text boxes, DenseNet recognizes the text in them.

    ``get_or_create`` returns a lazily created instance shared by all callers.
    """

    # Guards creation of the shared instance in get_or_create.
    __lock = Lock()
    # Shared instance; stays None until get_or_create succeeds once.
    __ocr = None

    def __init__(self,
                 ctpn_weight_path,
                 densenet_weight_path,
                 dict_path,
                 ctpn_config_path=None,
                 densenet_config_path=None):
        """Load the character dictionary and build the CTPN and DenseNet models.

        :param ctpn_weight_path:     path to the CTPN model weights file
        :param densenet_weight_path: path to the DenseNet model weights file
        :param dict_path:            path to the character dictionary file
                                     (read with UTF-8 encoding)
        :param ctpn_config_path:     optional path to a CTPN model config file
        :param densenet_config_path: optional path to a DenseNet model config file

        With a config path, the model is built from that config and its
        ``weight_path`` entry is replaced by the matching ``*_weight_path``
        argument. Without a config path, the model is built with its default
        settings, and the matching ``*_weight_path`` (when not None) is still
        passed as ``weight_path`` so the caller's weights are not dropped.
        """
        self.id_to_char = load_dict(dict_path, encoding="utf-8")

        # Build the CTPN text detector.
        if ctpn_config_path is not None:
            ctpn_config = CTPN.load_config(ctpn_config_path)
            ctpn_config["weight_path"] = ctpn_weight_path
            self.ctpn = CTPN(**ctpn_config)
        elif ctpn_weight_path is not None:
            # CTPN accepts ``weight_path`` (the config branch passes it via the
            # config dict); forward the caller's path instead of ignoring it.
            self.ctpn = CTPN(weight_path=ctpn_weight_path)
        else:
            self.ctpn = CTPN()

        # Build the DenseNet recognizer; its output size is the dictionary size.
        if densenet_config_path is not None:
            densenet_config = DenseNetOCR.load_config(densenet_config_path)
            densenet_config["weight_path"] = densenet_weight_path
            self.ocr = DenseNetOCR(**densenet_config)
        elif densenet_weight_path is not None:
            # Same as above: DenseNetOCR accepts ``weight_path``.
            self.ocr = DenseNetOCR(num_classes=len(self.id_to_char),
                                   weight_path=densenet_weight_path)
        else:
            self.ocr = DenseNetOCR(num_classes=len(self.id_to_char))

    def detect(self, image, adjust=True, parallel=True):
        """Detect text boxes in an image and recognize the text inside them.

        :param image: numpy array of shape (h, w, c), or a path to an image file
        :param adjust: whether to adjust the detection boxes
        :param parallel: if True, crop every box and recognize them all in one
            ``predict_multi`` call; otherwise recognize the boxes one by one
        :return: ``(text_recs, texts)``: the sorted detection boxes and the
            recognized strings; ``([], [])`` when no box is detected
        :raises ValueError: if ``image`` is a path that does not exist
        """
        if isinstance(image, str) and not os.path.exists(image):
            raise ValueError("The images path: " + image + " not exists!")

        # Get all detection boxes plus the image they refer to.
        text_recs, img = self.ctpn.predict(image, mode=2)

        if len(text_recs) == 0:
            return [], []

        text_recs = sort_box(text_recs)

        if parallel:
            imgs = clip_imgs_with_bboxes(text_recs, img, adjust)
            texts = self.ocr.predict_multi(imgs, id_to_char=self.id_to_char)
        else:
            texts = []
            for rec in text_recs:
                # The first element of the returned pair is not needed here;
                # binding it to ``_`` avoids shadowing the ``image`` parameter.
                _, text = single_text_detect(rec, self.ocr,
                                             self.id_to_char, img,
                                             adjust)
                # NOTE(review): None/empty results are dropped only in this
                # branch, so ``texts`` may be shorter than ``text_recs`` and no
                # longer index-aligned with it. Confirm whether callers rely on
                # alignment.
                if text is not None and len(text) > 0:
                    texts.append(text)

        return text_recs, texts

    @staticmethod
    def get_or_create(ctpn_weight_path=default_ctpn_weight_path,
                      ctpn_config_path=default_ctpn_config_path,
                      densenet_weight_path=default_densenet_weight_path,
                      densenet_config_path=default_densenet_config_path,
                      dict_path=default_dict_path):
        """Return the shared TextDetectionApp, creating it on the first call.

        The arguments are used only while no shared instance exists yet; once
        created, the same instance is returned regardless of the arguments.
        If construction raises, the exception is printed and None is returned;
        the shared slot stays empty, so the next call tries again.
        """
        with TextDetectionApp.__lock:
            if TextDetectionApp.__ocr is None:
                try:
                    TextDetectionApp.__ocr = TextDetectionApp(
                        ctpn_weight_path=ctpn_weight_path,
                        ctpn_config_path=ctpn_config_path,
                        densenet_weight_path=densenet_weight_path,
                        densenet_config_path=densenet_config_path,
                        dict_path=dict_path)
                except Exception as e:
                    print(e)
            return TextDetectionApp.__ocr
예제 #3
0
    # Optional initial weights file; when given it overrides the config's
    # "weight_path" (see the check below), otherwise the config is left as loaded.
    parser.add_argument("--weights_file_path", help="模型初始权重文件位置", default=None)
    # Where to save trained weights. The "{epoch:02d}" placeholder looks like a
    # per-epoch filename template; confirm against how it is used further down.
    # NOTE(review): "densent" in the default looks like a typo for "densenet";
    # left unchanged because fixing it would rename the output files.
    parser.add_argument("--save_weights_file_path",
                        help="保存模型训练权重文件位置",
                        default=r'model/weights-densent-{epoch:02d}.hdf5')

    args = parser.parse_args()

    # NOTE(review): presumably builds a session capped at 0.8 of GPU memory;
    # confirm in utils.get_session.
    K.set_session(utils.get_session(0.8))

    batch_size = args.batch_size

    # Text encoding; presumably used later for the dict/label files (not visible here).
    encoding = "UTF-8"
    initial_epoch = args.initial_epoch

    # Load the model config file.
    config = DenseNetOCR.load_config(args.config_file_path)
    weights_file_path = args.weights_file_path
    gpus = args.gpus
    config['num_gpu'] = gpus

    # Load initial weights: override the config's weight_path only when one was given.
    if weights_file_path is not None:
        config["weight_path"] = weights_file_path

    # Load training data: gather the data and output paths from the CLI arguments.
    images_dir = args.images_dir
    dict_file_path = args.dict_file_path
    train_labeled_file_path = args.train_file_path
    test_labeled_file_path = args.test_file_path
    save_weights_file_path = args.save_weights_file_path