Example #1
import numpy as np
from craft_text_detector import get_prediction


def extract_text(img, craft_net, clt):
    # detect text regions with a very permissive text threshold
    prediction_result = get_prediction(image=img,
                                       craft_net=craft_net,
                                       text_threshold=0.01,
                                       link_threshold=0.4,
                                       low_text=0.4,
                                       cuda=True,
                                       long_size=1280)

    # estimate the page background color
    # (get_bg_color and filter are project helpers defined elsewhere)
    color = get_bg_color(img, clt)

    for box in prediction_result['boxes']:
        # axis-aligned bounding box of the detected polygon
        x1, y1 = np.min(box, axis=0)
        x2, y2 = np.max(box, axis=0)
        h, w = y2 - y1, x2 - x1

        # widened ROI and original box, passed through the project's filter()
        # helper (presumably clamping the coordinates to the image bounds)
        roi = [x1, y1 - 100, x2, y2 + 100]
        roi = filter(roi, img.shape)
        a = filter([x1, y1, x2, y2], img.shape)

        # background color of the horizontal band around this text region
        color = get_bg_color(img[roi[1]:roi[3], :], clt)

        # paint over the text only if the box covers less than 3% of the image
        if (a[2] - a[0]) * (a[3] - a[1]) < img.shape[0] * img.shape[1] * 0.03:
            img[a[1] - 1:a[3] + 2, a[0] - 1:a[2] + 2] = color
    return img
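
For context, a minimal driver for extract_text might look like the sketch below. The read_image and load_craftnet_model calls are the craft_text_detector API already used in the other examples; the image path and the choice of a scikit-learn KMeans instance as the clt color-clustering object consumed by get_bg_color are assumptions made for illustration.

from craft_text_detector import read_image, load_craftnet_model
from sklearn.cluster import KMeans

img = read_image("page.jpg")                 # hypothetical input path; returns an RGB numpy array
craft_net = load_craftnet_model(cuda=True)   # CRAFT detector weights
clt = KMeans(n_clusters=3)                   # assumed color-clustering helper used by get_bg_color

cleaned = extract_text(img, craft_net, clt)  # text regions painted over with the background color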
Example #2
    def test_get_prediction(self):
        # load image
        image = read_image(self.image_path)

        # load models
        craft_net = load_craftnet_model()
        refine_net = None

        # perform prediction
        text_threshold = 0.9
        link_threshold = 0.2
        low_text = 0.2
        cuda = False
        prediction_result = get_prediction(
            image=image,
            craft_net=craft_net,
            refine_net=refine_net,
            text_threshold=text_threshold,
            link_threshold=link_threshold,
            low_text=low_text,
            cuda=cuda,
            long_size=720,
        )

        self.assertEqual(len(prediction_result["boxes"]), 35)
        self.assertEqual(len(prediction_result["boxes"][0]), 4)
        self.assertEqual(len(prediction_result["boxes"][0][0]), 2)
        self.assertEqual(int(prediction_result["boxes"][0][0][0]), 111)
        self.assertEqual(len(prediction_result["polys"]), 35)
        self.assertEqual(
            prediction_result["heatmaps"]["text_score_heatmap"].shape,
            (240, 368, 3))
Example #3
    def detect_text(self):
        # here, craft is the craft_text_detector package
        # (i.e. `import craft_text_detector as craft`)
        # load the CRAFT models lazily, on the first call only
        if not self.refine_net:
            self.refine_net = craft.load_refinenet_model(cuda=self.cuda)
            self.craft_net = craft.load_craftnet_model(cuda=self.cuda)

        # perform prediction
        self.text_areas = craft.get_prediction(
            image=self.rgb,
            craft_net=self.craft_net,
            refine_net=self.refine_net,
            text_threshold=0.7,
            link_threshold=0.4,
            low_text=0.4,
            cuda=self.cuda,
            long_size=1280
        )

        if self.debug:
            # export heatmap, detection points, box visualization
            craft.export_extra_results(
                image_path=self.image_path,
                image=self.rgb,
                regions=self.text_areas["boxes"],
                heatmaps=self.text_areas["heatmaps"],
                output_dir=self.temp_dir
            )

            # display the exported visualizations (notebook_image is a project helper)
            file = os.path.splitext(self.image_path)[0]
            notebook_image(self.temp_dir, "%s_text_detection.png" % file)
            notebook_image(self.temp_dir, "%s_text_score_heatmap.png" % file)
            notebook_image(self.temp_dir, "%s_link_score_heatmap.png" % file)
Example #4
    def detect(self, img):
        # run CRAFT with the parameters stored in self.args
        prediction_result = get_prediction(image=img,
                                           craft_net=self.craft_net,
                                           refine_net=self.refine_net,
                                           **self.args)
        # return one dict per detected box, with its corner points
        return [{
            'type': 'text',
            'points': points
        } for points in prediction_result['boxes'].tolist()]
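
This example only shows the prediction call; the surrounding class is not part of the snippet. A possible shape for it, using the model-loading helpers seen in the other examples, is sketched below. The class name, the constructor, and the concrete values placed in self.args are illustrative assumptions, not the original project's code.

from craft_text_detector import (get_prediction, load_craftnet_model,
                                 load_refinenet_model)


class TextDetector:
    def __init__(self, cuda=False):
        # load both CRAFT networks once, up front
        self.craft_net = load_craftnet_model(cuda=cuda)
        self.refine_net = load_refinenet_model(cuda=cuda)
        # remaining get_prediction keyword arguments, forwarded via **self.args
        # (values here are placeholders borrowed from the other examples)
        self.args = dict(text_threshold=0.7,
                         link_threshold=0.4,
                         low_text=0.4,
                         cuda=cuda,
                         long_size=1280)

    def detect(self, img):  # same method as in Example #4
        prediction_result = get_prediction(image=img,
                                           craft_net=self.craft_net,
                                           refine_net=self.refine_net,
                                           **self.args)
        return [{
            'type': 'text',
            'points': points
        } for points in prediction_result['boxes'].tolist()]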