コード例 #1
0
def main(args):
    """Run CTPN text detection on one image and save a visualization.

    Loads ``args.image_path``, builds the network with
    ``models.ctpn_net(config, 'test')``, loads weights from
    ``args.weight_path`` (or ``config.WEIGHT_PATH`` when that is None),
    groups the predicted proposals into text lines with ``TextDetector`` and
    draws the first 15 lines over the image into ``examples.<k>.png`` in the
    current working directory, where ``k`` is a random integer in [0, 10).

    Args:
        args: parsed command-line namespace; reads ``image_path`` and
            ``weight_path``.

    Side effects:
        Sets ``config.IMAGES_PER_GPU = 1`` on the shared config object and
        writes one PNG file.
    """
    # Load the image. The first argument is an image id; a random value is
    # passed (presumably irrelevant at inference time -- TODO confirm).
    image, image_meta, _, _ = image_utils.load_image_gt(
        np.random.randint(10), args.image_path, config.IMAGE_SHAPE[0], None)

    # Build the inference model (one image per batch).
    config.IMAGES_PER_GPU = 1
    m = models.ctpn_net(config, 'test')
    weight_path = (args.weight_path if args.weight_path is not None
                   else config.WEIGHT_PATH)
    m.load_weights(weight_path)

    # Predict on a batch of one and strip padding from the outputs.
    text_boxes, text_scores, _ = m.predict(np.array([image]))
    text_boxes = np_utils.remove_pad(text_boxes[0])
    text_scores = np_utils.remove_pad(text_scores[0])[:, 0]

    # Group the proposals into text lines using the window from image_meta.
    image_meta = image_utils.parse_image_meta(image_meta)
    detector = TextDetector(config)
    text_lines = detector.detect(text_boxes, text_scores, config.IMAGE_SHAPE,
                                 image_meta['window'])

    # Draw at most `boxes_num` lines: columns :8 are passed as the polygons
    # and column 8 as the next argument (presumably the per-line score).
    boxes_num = 15
    fig = plt.figure(figsize=(16, 16))
    try:
        ax = fig.add_subplot(1, 1, 1)
        visualize.display_polygons(image,
                                   text_lines[:boxes_num, :8],
                                   text_lines[:boxes_num, 8],
                                   ax=ax)
        fig.savefig('examples.{}.png'.format(np.random.randint(10)))
    finally:
        # Release the figure; pyplot keeps every open figure alive otherwise.
        plt.close(fig)
コード例 #2
0
ファイル: predict.py プロジェクト: hogwild/keras-ctpn
def main(args):
    """Run CTPN text detection on one image and save a visualization.

    Overrides fields of the shared ``config`` from ``args``, loads
    ``args.image_path``, runs ``models.ctpn_net(config, 'test')`` and groups
    the predicted proposals into text lines with ``TextDetector``. The first
    30 lines are drawn over the image and saved as
    ``<image stem>.<int(config.USE_SIDE_REFINE)>.jpg`` in the current
    working directory.

    Args:
        args: parsed command-line namespace; reads ``use_side_refine``,
            ``weight_path`` and ``image_path``.

    Side effects:
        Mutates ``config.USE_SIDE_REFINE``, ``config.WEIGHT_PATH`` (only when
        ``args.weight_path`` is not None), ``config.IMAGES_PER_GPU`` and
        ``config.IMAGE_SHAPE``; writes one image file.
    """
    # Override configuration from the command line.
    config.USE_SIDE_REFINE = bool(args.use_side_refine)
    if args.weight_path is not None:
        config.WEIGHT_PATH = args.weight_path
    config.IMAGES_PER_GPU = 1
    config.IMAGE_SHAPE = (1024, 1024, 3)

    # Load the image. The first argument is an image id; a random value is
    # passed (presumably irrelevant at inference time -- TODO confirm).
    image, image_meta, _, _ = image_utils.load_image_gt(np.random.randint(10),
                                                        args.image_path,
                                                        config.IMAGE_SHAPE[0],
                                                        None)
    # Build the inference model and load weights by layer name.
    m = models.ctpn_net(config, 'test')
    m.load_weights(config.WEIGHT_PATH, by_name=True)

    # Predict on a batch of one (image + meta) and strip padding.
    text_boxes, text_scores, _ = m.predict([np.array([image]), np.array([image_meta])])
    text_boxes = np_utils.remove_pad(text_boxes[0])
    text_scores = np_utils.remove_pad(text_scores[0])[:, 0]

    # Group the proposals into text lines using the window from image_meta.
    image_meta = image_utils.parse_image_meta(image_meta)
    detector = TextDetector(config)
    text_lines = detector.detect(text_boxes, text_scores, config.IMAGE_SHAPE, image_meta['window'])

    # Draw at most `boxes_num` lines: columns :8 are passed as the polygons
    # and column 8 as the next argument (presumably the per-line score).
    boxes_num = 30
    fig = plt.figure(figsize=(16, 16))
    try:
        ax = fig.add_subplot(1, 1, 1)
        visualize.display_polygons(image, text_lines[:boxes_num, :8], text_lines[:boxes_num, 8],
                                   ax=ax)
        image_name = os.path.basename(args.image_path)
        fig.savefig('{}.{}.jpg'.format(os.path.splitext(image_name)[0], int(config.USE_SIDE_REFINE)))
    finally:
        # Release the figure; pyplot keeps every open figure alive otherwise.
        plt.close(fig)
コード例 #3
0
ファイル: evaluate.py プロジェクト: ximingr/ctpn-with-keras
def generator(image_path_list, image_shape):
    """Lazily build single-image evaluation inputs, one per path.

    For every path, loads the image with ``image_utils.load_image_gt``
    (passing a random id and ``image_shape[0]``), prints a progress message
    and yields a feed dict whose entries each wrap the loaded value in a
    batch of one.

    Args:
        image_path_list: iterable of image paths.
        image_shape: sequence whose first element is forwarded to the loader.

    Yields:
        dict with keys "input_image" and "input_image_meta".
    """
    for idx, path in enumerate(image_path_list):
        # Random image id -- presumably ignored during evaluation (TODO confirm).
        random_id = np.random.randint(10)
        img, meta, _ignored_a, _ignored_b = image_utils.load_image_gt(
            random_id, path, image_shape[0])
        print("开始评估第 {} 张图像".format(idx))
        feed = dict(input_image=np.asarray([img]),
                    input_image_meta=np.asarray([meta]))
        yield feed
コード例 #4
0
ファイル: evaluate.py プロジェクト: hogwild/keras-ctpn
def generator(image_path_list, image_shape):
    """Lazily build single-image evaluation inputs, one per path.

    Each path is loaded with ``image_utils.load_image_gt`` (random id,
    ``image_shape[0]`` as the size argument). A progress line is printed for
    every 200th image, starting with the first.

    Args:
        image_path_list: iterable of image paths.
        image_shape: sequence whose first element is forwarded to the loader.

    Yields:
        dict with keys "input_image" and "input_image_meta", each a batch
        of one.
    """
    progress_every = 200
    for idx, path in enumerate(image_path_list):
        img, meta, _, _ = image_utils.load_image_gt(
            np.random.randint(10), path, image_shape[0])
        if not idx % progress_every:
            print("Evaluating No. {} ".format(idx))
        batch_image = np.asarray([img])
        batch_meta = np.asarray([meta])
        yield {"input_image": batch_image, "input_image_meta": batch_meta}
コード例 #5
0
    def get_text_lines(self,
                       img: np.ndarray,
                       interactive=False) -> List[ctpn_coordinate_pair]:
        """Detect text lines in ``img`` with the loaded CTPN model ``self.m``.

        Preprocesses ``img`` with ``image_utils.load_image_gt``, runs
        ``self.m.predict`` on a batch of one (image + meta), strips padding
        and groups the proposals into text lines with ``TextDetector``.

        Args:
            img: input image array, passed to ``image_utils.load_image_gt``
                in place of a path.
            interactive: when True, open a blocking OpenCV window showing the
                proposal boxes and detected lines drawn on a copy of the
                preprocessed image.

        Returns:
            One ``ctpn_coordinate_pair(row, row[4])`` per detected text-line
            row (meaning of ``row[4]`` depends on ``ctpn_coordinate_pair`` --
            not visible here).
        """
        # Load/preprocess the image. The first argument is an image id; a
        # random value is passed (presumably irrelevant for inference).
        image, image_meta, _, _ = image_utils.load_image_gt(
            np.random.randint(10), img, config.IMAGE_SHAPE[0], None)

        # Model prediction on a batch of one; strip padding from the outputs.
        text_boxes, text_scores, _ = self.m.predict(
            [np.array([image]), np.array([image_meta])])
        text_boxes = np_utils.remove_pad(text_boxes[0])
        text_scores = np_utils.remove_pad(text_scores[0])[:, 0]

        # Group proposals into text lines using the window from image_meta.
        image_meta = image_utils.parse_image_meta(image_meta)
        detector = TextDetector(config)
        text_lines = detector.detect(text_boxes, text_scores,
                                     config.IMAGE_SHAPE, image_meta['window'])

        # Render at most `boxes_num` polygons. The figure is never saved or
        # shown here, so close it to avoid accumulating open figures across
        # calls.
        boxes_num = 30
        fig = plt.figure(figsize=(16, 16))
        try:
            ax = fig.add_subplot(1, 1, 1)
            visualize.display_polygons(image,
                                       text_lines[:boxes_num, :8],
                                       text_lines[:boxes_num, 8],
                                       ax=ax)
        finally:
            plt.close(fig)

        lines = [ctpn_coordinate_pair(row, row[4]) for row in text_lines]

        # TODO: optionally deskew, e.g.
        # rotate_image(image, get_rotation_angle(lines)).
        if interactive:
            # Draw on a copy: load_image_gt may hand back the caller's array
            # unchanged (not visible here), and the drawing is debug-only.
            canvas = image.copy()
            for r in text_boxes:
                # Box stored as (y1, x1, y2, x2). OpenCV needs integer (x, y)
                # points; model outputs are presumably floats, so cast.
                cv2.rectangle(canvas, (int(r[1]), int(r[0])),
                              (int(r[3]), int(r[2])), (0, 255, 0), 2)
            for line in lines:
                cv2.line(canvas, (int(line.x1), int(line.y1)),
                         (int(line.x2), int(line.y2)), (255, 0, 0), 2)
            cv2.imshow('img', canvas)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        return lines