Ejemplo n.º 1
0
        net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
    else:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    # LinkRefiner
    refine_net = None
    if args.refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
        if args.cuda:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))

        refine_net.eval()
        args.poly = True

    t = time.time()

    # load data
    for k, image_path in enumerate(image_list):
def text_detection(args, net):
    """Detect text regions in every image found under ``args.input_dir``.

    Each image is loaded, run through ``test_net`` and its polygons are
    handed to ``craft.file_utils.saveResult`` (with ``write_image=True``)
    targeting ``args.output_dir``.  A failure on one image is reported on
    stdout and the loop continues with the next image.

    Args:
        args: Settings object.  Reads ``input_dir``, ``output_dir``,
            ``refine``, ``refiner_model``, ``cuda``, ``text_threshold``,
            ``link_threshold``, ``low_text`` and ``poly``.  ``args.poly`` is
            overwritten with ``True`` when the link refiner is enabled.
        net: Detection network, forwarded unchanged to ``test_net``.

    Returns:
        None.
    """
    image_list, _, _ = craft.file_utils.get_files(args.input_dir)

    output_dir = args.output_dir
    # makedirs(exist_ok=True) creates missing parent directories and does not
    # race with a concurrent creator, unlike an isdir() check plus mkdir().
    os.makedirs(output_dir, exist_ok=True)

    # LinkRefiner
    refine_net = None
    if args.refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' +
              args.refiner_model + ')')
        if args.cuda:
            refine_net.load_state_dict(
                copyStateDict(torch.load(args.refiner_model)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(
                copyStateDict(
                    torch.load(args.refiner_model, map_location='cpu')))

        refine_net.eval()
        # The refiner path always requests polygon output.
        args.poly = True

    started = time.time()
    total = len(image_list)

    print('-' * 30)
    print('文本区域检测  图片数量= {}'.format(total))
    print('-' * 30)

    # load data
    for k, image_path in enumerate(image_list):
        try:
            # '\r' keeps the progress report on a single console line.
            print("Test image {:d}/{:d}: {:s}".format(k + 1, total,
                                                      image_path),
                  end='\r')
            image = craft.imgproc.loadImage(image_path)

            _bboxes, polys, _score_text = test_net(args, net, image,
                                                   args.text_threshold,
                                                   args.link_threshold,
                                                   args.low_text, args.cuda,
                                                   args.poly, refine_net)

            # The image is passed with its channel axis reversed.
            craft.file_utils.saveResult(image_path,
                                        image[:, :, ::-1],
                                        polys,
                                        dirname=output_dir,
                                        write_image=True)
        except Exception as exc:
            # Best-effort batch: keep going, but report why this image failed
            # instead of silently discarding the exception.
            print("【Error】 图片[{}]  文本区域检测识别失败".format(image_path))
            print("    {!r}".format(exc))

    print("elapsed time : {}s".format(time.time() - started))
Ejemplo n.º 3
0
class CraftOne:
    """CRAFT text detector that loads its network(s) once, at construction.

    Settings come from the module-level ``config_craft`` mapping: every entry
    is copied onto the instance as an attribute.  The code below reads
    ``trained_model``, ``refine``, ``refiner_model``, ``canvas_size``,
    ``mag_ratio``, ``show_time``, ``text_threshold``, ``link_threshold``,
    ``low_text`` and ``poly`` from those attributes.
    """

    def __init__(self, cuda=True):
        """Build the detector and, if ``self.refine`` is set, the link refiner.

        Args:
            cuda: Run on GPU when true.  Because the config entries are
                copied after this value is stored, a ``cuda`` key in
                ``config_craft`` would override the argument.
        """
        self.cuda = cuda
        for key, value in config_craft.items():
            setattr(self, key, value)

        self.net = CRAFT()
        print(f'Loading weights from checkpoint ({self.trained_model})')
        if self.cuda:
            weights = torch.load(self.trained_model)
        else:
            weights = torch.load(self.trained_model, map_location='cpu')
        self.net.load_state_dict(copyStateDict(weights))

        if self.cuda:
            self.net = torch.nn.DataParallel(self.net.cuda())
            cudnn.benchmark = False

        self.net.eval()

        # LinkRefiner
        self.refine_net = None
        if not self.refine:
            return

        from refinenet import RefineNet
        self.refine_net = RefineNet()
        print(
            f'Loading weights of refiner from checkpoint ({self.refiner_model})'
        )
        if self.cuda:
            refiner_weights = torch.load(self.refiner_model)
            self.refine_net.load_state_dict(copyStateDict(refiner_weights))
            self.refine_net = torch.nn.DataParallel(self.refine_net.cuda())
        else:
            refiner_weights = torch.load(self.refiner_model,
                                         map_location='cpu')
            self.refine_net.load_state_dict(copyStateDict(refiner_weights))

        self.refine_net.eval()
        # Using the refiner forces polygon output.
        self.poly = True

    def detect_text(self,
                    net,
                    image,
                    text_threshold,
                    link_threshold,
                    low_text,
                    poly,
                    refine_net=None):
        """Run one image through ``net`` and post-process the score maps.

        Args:
            net: Network called as ``net(x)``; must return ``(y, feature)``.
            image: Input image, passed to ``imgproc.resize_aspect_ratio``.
            text_threshold: Forwarded to ``craft_utils.getDetBoxes``.
            link_threshold: Forwarded to ``craft_utils.getDetBoxes``.
            low_text: Forwarded to ``craft_utils.getDetBoxes``.
            poly: Forwarded to ``craft_utils.getDetBoxes``.
            refine_net: Optional refiner; when given, its output replaces
                the link score map.

        Returns:
            ``(boxes, polys, heatmap)`` where ``heatmap`` is
            ``imgproc.cvt2HeatmapImg`` applied to the text and link score
            maps stacked side by side.
        """
        infer_started = time()

        # resize
        resized, target_ratio, _ = imgproc.resize_aspect_ratio(
            image,
            self.canvas_size,
            interpolation=cv2.INTER_LINEAR,
            mag_ratio=self.mag_ratio)
        # The same inverse ratio is applied to both axes.
        inv_ratio = 1 / target_ratio

        # preprocessing
        tensor = torch.from_numpy(imgproc.normalizeMeanVariance(resized))
        tensor = tensor.permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        tensor = Variable(tensor.unsqueeze(0))  # [c, h, w] to [b, c, h, w]
        if self.cuda:
            tensor = tensor.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = net(tensor)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # refine link
        if refine_net is not None:
            with torch.no_grad():
                refined = refine_net(y, feature)
            score_link = refined[0, :, :, 0].cpu().data.numpy()

        infer_elapsed = time() - infer_started
        post_started = time()

        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                               text_threshold, link_threshold,
                                               low_text, poly)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, inv_ratio,
                                                    inv_ratio)
        polys = craft_utils.adjustResultCoordinates(polys, inv_ratio,
                                                    inv_ratio)
        # A missing polygon falls back to the matching box.
        for idx, candidate in enumerate(polys):
            if candidate is None:
                polys[idx] = boxes[idx]

        post_elapsed = time() - post_started

        # render results (optional)
        side_by_side = np.hstack((score_text.copy(), score_link))
        heatmap = imgproc.cvt2HeatmapImg(side_by_side)

        if self.show_time:
            print("\ninfer/postproc time : {:.3f}/{:.3f}".format(
                infer_elapsed, post_elapsed))

        return boxes, polys, heatmap

    def main(self, image):
        """Detect text in ``image`` using the instance's networks and settings.

        Returns:
            The ``(boxes, polys, heatmap)`` tuple produced by ``detect_text``.
        """
        return self.detect_text(self.net, image,
                                self.text_threshold,
                                self.link_threshold,
                                self.low_text, self.poly,
                                self.refine_net)
Ejemplo n.º 4
0
class DETECTION:
    """CRAFT text detector with hard-coded settings.

    Typical use: construct, call ``load_model()`` once, then call
    ``test_net(image)`` per image.
    """

    def __init__(self):
        # model settings #
        self.trained_model = 'model/craft_mlt_25k.pth'
        self.text_threshold = 0.7
        self.low_text = 0.4
        self.link_threshold = 0.4
        self.cuda = True
        self.canvas_size = 1280
        self.mag_ratio = 1.5
        self.poly = True
        self.show_time = False
        self.video_folder = 'input/'
        self.refine = False
        self.refiner_model = 'model/craft_refiner_CTW1500.pth'
        self.interpolation = cv2.INTER_LINEAR

        #import model
        self.net = CRAFT()  # initialize
        # Defined up front so test_net() never hits AttributeError on this
        # attribute; load_model() replaces it when self.refine is set.
        self.refine_net = None

    def copyStateDict(self, state_dict):
        """Return a copy of ``state_dict`` without a leading ``module`` key part.

        If the first key starts with ``"module"``, the first dot-separated
        component is dropped from every key; otherwise the keys are copied
        unchanged.  An empty mapping yields an empty ``OrderedDict``.
        """
        # next(..., "") avoids the IndexError that indexing the key list
        # raised for an empty state dict.
        first_key = next(iter(state_dict), "")
        start_idx = 1 if first_key.startswith("module") else 0
        return OrderedDict(
            (".".join(key.split(".")[start_idx:]), value)
            for key, value in state_dict.items())

    def load_model(self):
        """Load the detector weights and, if enabled, the link refiner.

        Both networks are moved to the GPU and wrapped in ``DataParallel``
        when ``self.cuda`` is true, and both are put in evaluation mode.
        Enabling the refiner forces ``self.poly`` to ``True``.
        """
        print('Loading weights from checkpoint (' + self.trained_model + ')')
        if self.cuda:
            weights = torch.load(self.trained_model)
        else:
            weights = torch.load(self.trained_model, map_location='cpu')
        self.net.load_state_dict(self.copyStateDict(weights))

        if self.cuda:
            self.net = self.net.cuda()
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = False

        # Inference only: switch the module out of training mode so layers
        # such as batch norm / dropout behave deterministically.
        self.net.eval()

        # LinkRefiner
        self.refine_net = None
        if self.refine:
            self.refine_net = RefineNet()
            print('Loading weights of refiner from checkpoint (' +
                  self.refiner_model + ')')
            if self.cuda:
                self.refine_net.load_state_dict(
                    self.copyStateDict(torch.load(self.refiner_model)))
                self.refine_net = self.refine_net.cuda()
                self.refine_net = torch.nn.DataParallel(self.refine_net)
            else:
                self.refine_net.load_state_dict(
                    self.copyStateDict(
                        torch.load(self.refiner_model, map_location='cpu')))

            self.refine_net.eval()
            self.poly = True

    def test_net(self, image_opencv):
        """Detect text in one image.

        Args:
            image_opencv: Image array passed to
                ``imgproc.resize_aspect_ratio``.

        Returns:
            ``(boxes, polys)`` as produced by ``craft_utils.getDetBoxes`` and
            rescaled with ``craft_utils.adjustResultCoordinates``; polygons
            that are ``None`` are replaced by the matching box.
        """
        # Time the whole inference stage (resize, forward pass, refiner) so
        # the "infer" figure printed below matches its label.
        t0 = time.time()

        # resize
        img_resized, target_ratio, _ = imgproc.resize_aspect_ratio(
            image_opencv,
            self.canvas_size,
            interpolation=self.interpolation,
            mag_ratio=self.mag_ratio)
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]

        if self.cuda:
            x = x.cuda()

        # forward pass; no_grad() avoids building an autograd graph that is
        # never used during inference.
        with torch.no_grad():
            y, feature = self.net(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # refine link
        if self.refine_net is not None:
            with torch.no_grad():
                y_refiner = self.refine_net(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().data.numpy()
        t0 = time.time() - t0
        t1 = time.time()

        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                               self.text_threshold,
                                               self.link_threshold,
                                               self.low_text, self.poly)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k, candidate in enumerate(polys):
            if candidate is None:
                polys[k] = boxes[k]
        t1 = time.time() - t1

        if self.show_time:
            print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))
        return boxes, polys
Ejemplo n.º 5
0
class Detection:
    """CRAFT text detector configured through constructor arguments.

    Construction creates the result folder and loads the network(s);
    ``TextDetect`` then processes one image path at a time.
    """

    def __init__(
        self,
        trained_model="weights/craft_mlt_25k.pth",
        text_threshold=0.7,
        low_text=0.4,
        link_threshold=0.4,
        cuda=True,
        canvas_size=1280,
        mag_ratio=1.5,
        poly=False,
        show_time=False,
        test_folder="/data/",
        refine=False,
        refiner_model="weights/craft_refiner_CTW1500.pth",
        result_folder="./results",
    ):
        """Store the settings, ensure ``result_folder`` exists, load weights.

        ``poly`` is forced to ``True`` later if ``refine`` is enabled.
        """
        self.trained_model = trained_model
        self.text_threshold = text_threshold
        self.low_text = low_text
        self.link_threshold = link_threshold
        self.cuda = cuda
        self.canvas_size = canvas_size
        self.mag_ratio = mag_ratio
        self.poly = poly
        self.show_time = show_time
        self.test_folder = test_folder
        self.refine = refine
        self.refiner_model = refiner_model
        self.result_folder = result_folder

        self.__checkFolder()
        self.__loadNet()

    def __checkFolder(self):
        """Create ``self.result_folder`` (and missing parents) if needed."""
        # os.mkdir() failed when the parent directory was missing and could
        # race with another creator between the isdir() check and the call.
        os.makedirs(self.result_folder, exist_ok=True)

    def __loadNet(self):
        """Instantiate CRAFT (and optionally RefineNet) and load weights."""
        # load net
        self.net = CRAFT()  # initialize

        logger.info("Loading weights from checkpoint (" + self.trained_model + ")")
        if self.cuda:
            self.net.load_state_dict(copyStateDict(torch.load(self.trained_model)))
        else:
            self.net.load_state_dict(
                copyStateDict(torch.load(self.trained_model, map_location="cpu"))
            )

        if self.cuda:
            self.net = self.net.cuda()
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = False

        self.net.eval()

        # LinkRefiner
        self.refine_net = None
        if self.refine:
            from refinenet import RefineNet

            self.refine_net = RefineNet()
            logger.info("Loading weights of refiner from checkpoint (" + self.refiner_model + ")")
            if self.cuda:
                self.refine_net.load_state_dict(copyStateDict(torch.load(self.refiner_model)))
                self.refine_net = self.refine_net.cuda()
                self.refine_net = torch.nn.DataParallel(self.refine_net)
            else:
                self.refine_net.load_state_dict(
                    copyStateDict(torch.load(self.refiner_model, map_location="cpu"))
                )

            self.refine_net.eval()
            # Using the refiner forces polygon output.
            self.poly = True

    def TextDetect(self, image_path=None):
        """Detect text in the image at ``image_path`` and save the result.

        Args:
            image_path: Path handed to ``imgproc.loadImage``.
                NOTE(review): the ``None`` default is passed straight to
                ``loadImage``; callers presumably always supply a path.

        Returns:
            ``(num, location, image.shape)`` where ``num`` and ``location``
            are whatever ``file_utils.saveResult`` returns; a falsy ``num``
            is logged as "No image box found".
        """
        image = imgproc.loadImage(image_path)

        _bboxes, polys, _score_text = self.test_net(
            self.net,
            image,
            self.text_threshold,
            self.link_threshold,
            self.low_text,
            self.cuda,
            self.poly,
            self.refine_net,
        )

        # The image is passed with its channel axis reversed.
        num, location = file_utils.saveResult(
            image_path, image[:, :, ::-1], polys, dirname=self.result_folder
        )
        if not num:
            logger.warning("No image box found")
        else:
            logger.info(f"Saved to {self.result_folder}, {image_path} done")
        return num, location, image.shape

    def test_net(
        self, net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None
    ):
        """Run ``net`` on ``image`` and post-process the score maps.

        Args:
            net: Network called as ``net(x)``; must return ``(y, feature)``.
            image: Image array passed to ``imgproc.resize_aspect_ratio``.
            text_threshold: Forwarded to ``craft_utils.getDetBoxes``.
            link_threshold: Forwarded to ``craft_utils.getDetBoxes``.
            low_text: Forwarded to ``craft_utils.getDetBoxes``.
            cuda: Move the input tensor to the GPU when true.
            poly: Forwarded to ``craft_utils.getDetBoxes``.
            refine_net: Optional refiner whose output replaces the link map.

        Returns:
            ``(boxes, polys, heatmap)``; ``heatmap`` is
            ``imgproc.cvt2HeatmapImg`` of the text and link maps stacked
            horizontally.
        """
        t0 = time.time()

        # resize
        img_resized, target_ratio, _ = imgproc.resize_aspect_ratio(
            image, self.canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=self.mag_ratio
        )
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]
        if cuda:
            x = x.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = net(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # refine link
        if refine_net is not None:
            with torch.no_grad():
                y_refiner = refine_net(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

        t0 = time.time() - t0
        t1 = time.time()

        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(
            score_text, score_link, text_threshold, link_threshold, low_text, poly
        )

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        # A missing polygon falls back to the matching box.
        for k, candidate in enumerate(polys):
            if candidate is None:
                polys[k] = boxes[k]

        t1 = time.time() - t1

        # render results (optional)
        render_img = np.hstack((score_text.copy(), score_link))
        ret_score_text = imgproc.cvt2HeatmapImg(render_img)

        if self.show_time:
            logger.info("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

        return boxes, polys, ret_score_text
Ejemplo n.º 6
0
def text_detection_service(image_list, cuda=False):
    """Crop the (up to) two largest detected text boxes from an image.

    A CRAFT network (plus RefineNet when the module-level ``REFINE`` flag is
    set) is loaded on every call, the image is run through ``test_net`` and
    the boxes with the largest ``calculate_area`` are cut out with a 5-pixel
    margin.

    Args:
        image_list: Iterable of image arrays.
            NOTE(review): the function returns inside the loop, so only the
            FIRST image is ever processed - kept as-is because callers may
            rely on the flat list of crops; confirm the intent.
        cuda: Load and run the networks on the GPU when true.

    Returns:
        A list with the crops of the largest and second-largest box (largest
        first) for the first image.  The list holds one crop when only one
        box was detected and is empty when none was.  ``None`` is returned
        when ``image_list`` is empty.
    """
    # load net
    net = CRAFT()

    if cuda:
        net.load_state_dict(copyStateDict(torch.load(TRAINED_MODEL)))
    else:
        net.load_state_dict(copyStateDict(torch.load(TRAINED_MODEL, map_location='cpu')))

    if cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    # LinkRefiner
    refine_net = None
    use_poly = False
    if REFINE:
        from refinenet import RefineNet
        refine_net = RefineNet()

        if cuda:
            refine_net.load_state_dict(copyStateDict(torch.load(REFINER_MODEL)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(copyStateDict(torch.load(REFINER_MODEL, map_location='cpu')))

        refine_net.eval()
        use_poly = True

    for image in image_list:
        bboxes, _polys, _score_text = test_net(net, image, TEXT_THRESHOLD, LINK_THRESHOLD, LOW_TEXT, cuda, use_poly, refine_net)

        # Rank box indices by area, largest first.  Using the index as the
        # tie-breaker reproduces the previous ">=" comparisons, where a later
        # box of equal area displaced an earlier one.  Unlike the previous
        # [0, 0]-seeded bookkeeping this neither raises IndexError when no
        # box was found nor crops the same box twice when only one was.
        # (Assumes calculate_area() is non-negative - TODO confirm.)
        ranked = sorted(range(len(bboxes)),
                        key=lambda i: (calculate_area(bboxes[i]), i),
                        reverse=True)

        area_crops = []
        for box_index in ranked[:2]:
            x, y, w, h = cv2.boundingRect(bboxes[box_index])
            # 5-pixel margin, clamped at the top/left image edge; slicing
            # clamps the bottom/right edge by itself.
            y1 = max(y - 5, 0)
            y2 = y + h + 5
            x1 = max(x - 5, 0)
            x2 = x + w + 5
            area_crops.append(image[int(y1):int(y2), int(x1):int(x2)])
        return area_crops
Ejemplo n.º 7
0
def craftnet():
    """Detect text regions, OCR the saved crops and sort the text into fields.

    Steps:
      1. Load CRAFT (and RefineNet when ``CONFIG['refine']`` is set).
      2. For every path in the module-level ``image_list``: run ``test_net``,
         write the score heat map to ``result_folder`` and call
         ``file_utils.saveResult``.
      3. Run Tesseract on every ``.tif`` file in ``result/temp_result``.
      4. Pass each recognised string that ``parse_info`` accepts through the
         field parsers in a fixed order; the first parser returning a
         non-``None`` value decides the field.

    Returns:
        dict with the keys ``name``, ``phone``, ``email``, ``company``,
        ``website``, ``address`` and ``extra_information``, each mapping to
        a list of values.
    """
    # load net
    net = CRAFT()  # initialize

    print('Loading weights from checkpoint (' + CONFIG['trained_model'] + ')')
    if CONFIG['cuda']:
        net.load_state_dict(copyStateDict(torch.load(CONFIG['trained_model'])))
    else:
        net.load_state_dict(
            copyStateDict(
                torch.load(CONFIG['trained_model'], map_location='cpu')))

    if CONFIG['cuda']:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    # LinkRefiner
    refine_net = None
    if CONFIG['refine']:
        from refinenet import RefineNet
        refine_net = RefineNet()
        if CONFIG['cuda']:
            refine_net.load_state_dict(
                copyStateDict(torch.load(CONFIG['refiner_model'])))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(
                copyStateDict(
                    torch.load(CONFIG['refiner_model'], map_location='cpu')))

        refine_net.eval()
        # Using the refiner forces polygon output.
        CONFIG['poly'] = True

    # load data
    for image_path in image_list:
        # This project's loadImage returns a pair; only the second item is
        # used here.
        _orig, image = imgproc.loadImage(image_path)

        _bboxes, polys, score_text = test_net(
            net, image, CONFIG['text_threshold'], CONFIG['link_threshold'],
            CONFIG['low_text'], CONFIG['cuda'], CONFIG['poly'], refine_net)

        # save score text
        filename, _ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        cv2.imwrite(mask_file, score_text)

        file_utils.saveResult(image_path,
                              image[:, :, ::-1],
                              polys,
                              dirname=result_folder)

    # OCR every .tif found in the temp result directory.
    # NOTE(review): presumably saveResult writes these crops - confirm.
    tesseract_config = '--psm 10 --oem 3 -l vie+eng'
    information = []
    for entry in os.listdir('result/temp_result'):
        if os.path.splitext(entry)[1] != '.tif':
            continue
        # The context manager closes the file handle once OCR is done.
        with Image.open('result/temp_result/' + entry) as crop:
            raw_text = pytesseract.image_to_string(crop,
                                                   lang='eng+vie',
                                                   config=tesseract_config)
        information.append(raw_text)

    X = {
        "name": [],
        "phone": [],
        "email": [],
        "company": [],
        "website": [],
        "address": [],
        "extra_information": []
    }
    # Order matters: the first parser that recognises the text wins.
    field_parsers = (
        ("email", parse_email),
        ("phone", parse_phone),
        ("website", parse_website),
        ("company", parse_company),
        ("address", parse_address),
        ("name", parse_name),
    )
    for info in information:
        if not parse_info(info):
            continue

        for field, parser in field_parsers:
            parsed = parser(info)
            if parsed is None:
                continue
            # NOTE(review): for "name" the raw text is stored rather than
            # the parser's result, exactly as before - confirm this is
            # intended.
            X[field].append(info if field == "name" else parsed)
            break
        else:
            X["extra_information"].append(info)
    return X
Ejemplo n.º 8
0
def runCRAFT(img):
    """Detect text in a single image with CRAFT.

    Settings are read from module-level names (``trained_model``,
    ``text_threshold``, ``low_text``, ``link_threshold``, ``cuda``,
    ``refine``, ``refiner_model``).  The networks are loaded on every call.

    Side effect: the module-level ``poly`` is set to ``True`` when ``refine``
    is enabled.

    Args:
        img: Anything ``np.array`` accepts.  2-D input is converted from
            gray to RGB and a fourth channel is dropped.

    Returns:
        Whatever ``file_utils.saveResult`` returns for the detected boxes.
    """
    # Only `poly` is rebound below; every other setting is merely read, which
    # needs no `global` declaration.
    global poly

    # load net
    net = CRAFT()  # initialize

    if cuda:
        net.load_state_dict(copyStateDict(torch.load(trained_model)))
    else:
        net.load_state_dict(
            copyStateDict(torch.load(trained_model, map_location='cpu')))

    if cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    # LinkRefiner
    refine_net = None
    if refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + refiner_model +
              ')')
        if cuda:
            refine_net.load_state_dict(copyStateDict(
                torch.load(refiner_model)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(
                copyStateDict(torch.load(refiner_model, map_location='cpu')))

        refine_net.eval()
        poly = True

    # Normalise the input to a 3-channel array.
    img = np.array(img)
    # NOTE(review): presumably unwraps a 2-element container whose first item
    # is the image; an image that is exactly 2 pixels tall would match this
    # check as well - confirm.
    if img.shape[0] == 2:
        img = img[0]
    if len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    if img.shape[2] == 4:
        # Keep the first three channels only.
        img = img[:, :, :3]

    boxes = test_net(net, img, text_threshold, link_threshold, low_text, cuda,
                     refine_net)

    # The image is passed with its channel axis reversed.
    strResult = file_utils.saveResult(img[:, :, ::-1], boxes)

    return strResult
Ejemplo n.º 9
0
class NpPointsCraft(object):
    """
    NpPointsCraft Class
    git clone https://github.com/clovaai/CRAFT-pytorch.git

    Uses CRAFT (optionally with RefineNet) to find the corner points of a
    text zone inside given bounding boxes.
    """

    # Default used by detectInBbox() when loadModel() ran with
    # is_refine=False (or was not called at all); loadModel() sets the
    # instance attribute to True when the refiner is loaded.
    is_poly = False

    def __init__(self, **args):
        pass

    @classmethod
    def get_classname(cls):
        """Return the name of this class."""
        return cls.__name__

    def load(self, mtl_model_path="latest", refiner_model_path="latest"):
        """
        Resolve the model paths and load both networks.

        A path equal to ``"latest"`` is replaced by the ``"path"`` entry of
        what ``download_latest_model`` returns.  The device is ``"cuda"``
        when ``get_mode_torch()`` reports ``"gpu"``, otherwise ``"cpu"``.
        The refiner is always enabled.
        """
        if mtl_model_path == "latest":
            model_info = download_latest_model(self.get_classname(),
                                               "mtl",
                                               ext="pth",
                                               mode=get_mode_torch())
            mtl_model_path = model_info["path"]
        if refiner_model_path == "latest":
            model_info = download_latest_model(self.get_classname(),
                                               "refiner",
                                               ext="pth",
                                               mode=get_mode_torch())
            refiner_model_path = model_info["path"]
        device = "cuda" if get_mode_torch() == "gpu" else "cpu"
        self.loadModel(device, True, mtl_model_path, refiner_model_path)

    def loadModel(self,
                  device="cuda",
                  is_refine=True,
                  trained_model=os.path.join(CRAFT_DIR,
                                             'weights/craft_mlt_25k.pth'),
                  refiner_model=os.path.join(
                      CRAFT_DIR, 'weights/craft_refiner_CTW1500.pth')):
        """
        Instantiate CRAFT (and RefineNet when ``is_refine``) and load weights.

        :param device: ``"cuda"`` selects the GPU; any other value loads the
            weights on the CPU.
        :param is_refine: also load the link refiner and enable polygon mode.
        :param trained_model: path of the CRAFT checkpoint.
        :param refiner_model: path of the RefineNet checkpoint.
        """
        is_cuda = device == "cuda"
        self.is_cuda = is_cuda

        # load net
        self.net = CRAFT()  # initialize

        print('Loading weights from checkpoint (' + trained_model + ')')
        if is_cuda:
            self.net.load_state_dict(copyStateDict(torch.load(trained_model)))
        else:
            self.net.load_state_dict(
                copyStateDict(torch.load(trained_model, map_location='cpu')))

        if is_cuda:
            self.net = self.net.cuda()
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = False

        self.net.eval()

        # LinkRefiner
        self.refine_net = None
        if is_refine:
            from refinenet import RefineNet
            self.refine_net = RefineNet()
            print('Loading weights of refiner from checkpoint (' +
                  refiner_model + ')')
            if is_cuda:
                self.refine_net.load_state_dict(
                    copyStateDict(torch.load(refiner_model)))
                self.refine_net = self.refine_net.cuda()
                self.refine_net = torch.nn.DataParallel(self.refine_net)
            else:
                self.refine_net.load_state_dict(
                    copyStateDict(torch.load(refiner_model,
                                             map_location='cpu')))

            self.refine_net.eval()
            self.is_poly = True

    def detectByImagePath(self,
                          image_path,
                          targetBoxes,
                          qualityProfile=None,
                          debug=False):
        """
        Load an image and detect zone points inside each target box.

        :param image_path: path handed to ``imgproc.loadImage``.
        :param targetBoxes: list of dicts with ``x1``, ``y1``, ``x2``,
            ``y2``.  Each dict is updated in place with ``points`` and
            ``imgParts`` entries (both empty lists when nothing was found).
        :param qualityProfile: forwarded to ``makeRectVariants2``; defaults
            to ``[1, 0, 0]``.
        :param debug: accepted for interface compatibility; not used here.
        :return: ``(targetBoxes, image)``.
        """
        if qualityProfile is None:
            qualityProfile = [1, 0, 0]

        image = imgproc.loadImage(image_path)
        for targetBox in targetBoxes:
            x = min(targetBox['x1'], targetBox['x2'])
            w = abs(targetBox['x2'] - targetBox['x1'])
            y = min(targetBox['y1'], targetBox['y2'])
            h = abs(targetBox['y2'] - targetBox['y1'])

            image_part = image[y:y + h, x:x + w]
            points = self.detectInBbox(image_part)
            # Shift the points from crop coordinates back to image
            # coordinates.
            propablyPoints = addCoordinatesOffset(points, x, y)
            targetBox['points'] = []
            targetBox['imgParts'] = []
            if len(propablyPoints):
                targetPointsVariants = makeRectVariants2(
                    propablyPoints, h, w, qualityProfile)
                if len(targetPointsVariants) > 1:
                    # Several candidates: pick the one whose rectified image
                    # detectBestPerspective() prefers.
                    imgParts = [
                        getCvZoneRGB(image, reshapePoints(rect, 1))
                        for rect in targetPointsVariants
                    ]
                    idx = detectBestPerspective(
                        normalizePerspectiveImages(imgParts))
                    print('--------------------------------------------------')
                    print('idx={}'.format(idx))
                    targetBox['points'] = targetPointsVariants[idx]
                    targetBox['imgParts'] = imgParts
                else:
                    targetBox['points'] = targetPointsVariants[0]
        return targetBoxes, image

    def detect(self,
               image,
               targetBoxes,
               qualityProfile=None,
               debug=False):
        """
        Detect zone points inside each target box of an in-memory image.

        :param image: image array; it is sliced as ``image[y:y+h, x:x+w]``.
        :param targetBoxes: sequence of boxes indexed as
            ``[x1, y1, x2, y2]``.
        :param qualityProfile: forwarded to ``makeRectVariants2``; defaults
            to ``[1, 0, 0]``.
        :param debug: accepted for interface compatibility; not used here.
        :return: one entry per target box - the selected points, or the four
            corners of the box itself when nothing was detected inside it.
        """
        if qualityProfile is None:
            qualityProfile = [1, 0, 0]

        all_points = []
        for targetBox in targetBoxes:
            x = int(min(targetBox[0], targetBox[2]))
            w = int(abs(targetBox[2] - targetBox[0]))
            y = int(min(targetBox[1], targetBox[3]))
            h = int(abs(targetBox[3] - targetBox[1]))

            image_part = image[y:y + h, x:x + w]
            propablyPoints = addCoordinatesOffset(
                self.detectInBbox(image_part), x, y)
            if not len(propablyPoints):
                # Nothing detected: fall back to the box corners.
                all_points.append([[x, y + h], [x, y], [x + w, y],
                                   [x + w, y + h]])
                continue

            targetPointsVariants = makeRectVariants2(
                propablyPoints, h, w, qualityProfile)
            if len(targetPointsVariants) > 1:
                imgParts = [
                    getCvZoneRGB(image, reshapePoints(rect, 1))
                    for rect in targetPointsVariants
                ]
                idx = detectBestPerspective(
                    normalizePerspectiveImages(imgParts))
                points = targetPointsVariants[idx]
            else:
                points = targetPointsVariants[0]
            all_points.append(points)
        return all_points

    def detectInBbox(self, image, debug=False):
        """
        Run CRAFT on an image crop and merge the kept boxes into one
        rectangle.

        Boxes returned by ``test_net`` are divided by ``split_boxes`` into
        kept and garbage indices.  A single kept box is used as is; several
        are merged with ``minimum_bounding_rectangle``.  The result is passed
        through ``normalizeRect`` and ``addoptRectToBbox``.

        :param image: image crop passed to ``test_net``.
        :param debug: print timing and intermediate values.
        :return: the resulting points, or an empty list when no box is kept.
        """
        low_text = 0.4
        link_threshold = 0.7  # 0.4
        text_threshold = 0.6
        canvas_size = 1280
        mag_ratio = 1.5

        t = time.time()
        bboxes, polys, score_text = test_net(self.net, image, text_threshold,
                                             link_threshold, low_text,
                                             self.is_cuda, self.is_poly,
                                             canvas_size, self.refine_net,
                                             mag_ratio)
        if debug:
            print("elapsed time : {}s".format(time.time() - t))

        # Side lengths of every box, consumed by split_boxes().
        dimensions = [{
            'dx': distance(box[0], box[1]),
            'dy': distance(box[1], box[2])
        } for box in bboxes]

        if debug:
            print(score_text.shape)
            print(dimensions)
            print(bboxes)

        np_bboxes_idx, garbage_bboxes_idx = split_boxes(bboxes, dimensions)

        if debug:
            # The former prints of `raw_boxes` / `raw_polys` referenced names
            # that are not defined in this method and raised NameError.
            print('np_bboxes_idx')
            print(np_bboxes_idx)
            print('garbage_bboxes_idx')
            print(garbage_bboxes_idx)

        targetPoints = []
        if len(np_bboxes_idx) == 1:
            targetPoints = bboxes[np_bboxes_idx[0]]
        elif len(np_bboxes_idx) > 1:
            targetPoints = minimum_bounding_rectangle(
                np.concatenate([bboxes[i] for i in np_bboxes_idx], axis=0))

        if len(np_bboxes_idx) > 0:
            targetPoints = normalizeRect(targetPoints)
            if debug:
                print('###################################')
                print(targetPoints)
                print('image.shape')
                print(image.shape)
            targetPoints = addoptRectToBbox(targetPoints, image.shape, 7, 12,
                                            0, 12)
        return targetPoints
Ejemplo n.º 10
0
def analysis(image_path, result_path):
    """Detect text regions in a single image and save an annotated copy.

    Builds a CRAFT network from ``args.trained_model`` (plus a RefineNet from
    ``args.refiner_model`` when ``args.refine`` is set), runs ``test_net`` on
    the image at ``image_path`` and draws one rectangle per returned polygon.
    When ``args.ocr_on`` is set, each cropped region is also passed to the
    recognizer selected by ``args.ocr_method`` and the recognized string is
    drawn above its box.

    Args:
        image_path: Path of the image to analyse.
        result_path: Path the annotated image is written to. The file is only
            written when at least one box was annotated successfully.

    Side effects:
        Sets ``args.poly = True`` when the refiner is enabled and sets
        ``cudnn.benchmark = False`` when running on CUDA.
    """
    # ``map_location=None`` is torch.load's default, so the CUDA path loads
    # the checkpoint exactly as before while the CPU path remaps to 'cpu'.
    map_location = None if args.cuda else 'cpu'

    net = CRAFT()     # initialize
    net.load_state_dict(
        copyStateDict(torch.load(args.trained_model,
                                 map_location=map_location)))
    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False
    net.eval()

    # LinkRefiner
    refine_net = None
    if args.refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
        refine_net.load_state_dict(
            copyStateDict(torch.load(args.refiner_model,
                                     map_location=map_location)))
        if args.cuda:
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        refine_net.eval()
        args.poly = True

    image = imgproc.loadImage(image_path)

    _, polys, _ = test_net(net,
                           image,
                           args.text_threshold,
                           args.link_threshold,
                           args.low_text,
                           args.cuda,
                           args.poly,
                           refine_net)

    # Separate copy used as the drawing canvas; `image` stays untouched so the
    # crops handed to the recognizer never contain earlier annotations.
    opencv_image = cv2.imread(image_path)

    annotated = False
    for box in polys:
        # NOTE(review): assumes the four corners are ordered top-left,
        # top-right, bottom-right, bottom-left -- confirm against test_net.
        xmin, xmax, ymin, ymax = box[0][0], box[1][0], box[0][1], box[2][1]

        # Clamp at zero: a negative index would wrap around to the opposite
        # edge of the array and yield an empty or wrong crop.
        x0, x1 = max(int(xmin), 0), max(int(xmax), 0)
        y0, y1 = max(int(ymin), 0), max(int(ymax), 0)
        multiplier_area = image[y0:y1, x0:x1]

        try:
            # Built before the OCR switch so a degenerate crop skips the box
            # the same way regardless of the OCR settings.
            im_pil = Image.fromarray(multiplier_area)
            # if you want to detect the text on the image
            if args.ocr_on:
                if args.ocr_method == 'pytesseract':
                    configuration = ("-l eng --oem 1 --psm 7")
                    multiplier = (pytesseract.image_to_string(multiplier_area, config=configuration).lower())
                    multiplier = multiplier.split("\n")[0]
                elif args.ocr_method == 'TPS-ResNet-BiLSTM':
                    multiplier = text_recognition.recognition(im_pil)
                else:
                    # Report a clear cause instead of an unbound-variable
                    # NameError; handled below like any other per-box failure.
                    raise ValueError(
                        'unsupported ocr_method: {!r}'.format(args.ocr_method))

                cv2.putText(opencv_image, multiplier, (int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 1)

            cv2.rectangle(opencv_image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0, 0, 255), 1)
            annotated = True
        except Exception:
            # Best effort: one failing box must not abort the remaining ones.
            # KeyboardInterrupt / SystemExit are deliberately not caught.
            print("====ERROR====", traceback.format_exc())

    # Encode and write the result once instead of once per box.
    if annotated:
        try:
            cv2.imwrite(result_path, opencv_image)
        except Exception:
            print("====ERROR====", traceback.format_exc())