Code example #1
    return boxes, ret_score_text


if __name__ == '__main__':
    # load net
    net = CRAFT()  # initialize

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
    else:
        # map_location is needed so the checkpoint also loads on CPU-only machines
        net.load_state_dict(
            copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    t = time.time()

    # load data
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k + 1, len(image_list),
                                                  image_path),
              end='\r')
        image = imgproc.loadImage(image_path)

        bboxes, score_text = test_net(net, image, args.text_threshold,
                                      args.link_threshold, args.low_text,
                                      args.cuda)

        # save score text
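All of these snippets pass the checkpoint through copyStateDict before load_state_dict. A minimal sketch of that helper, close to the one shipped in the CRAFT-pytorch repo (it strips the "module." prefix that DataParallel checkpoints carry):

from collections import OrderedDict

def copyStateDict(state_dict):
    # Checkpoints saved from a DataParallel model prefix every key with
    # "module."; strip it so the weights load into a plain module.
    if list(state_dict.keys())[0].startswith("module"):
        start_idx = 1
    else:
        start_idx = 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict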
Code example #2
File: process.py  Project: manhntm3/CRAFT-pytorch
def craftnet():
    # load net
    net = CRAFT()  # initialize

    print('Loading weights from checkpoint (' + CONFIG['trained_model'] + ')')
    if CONFIG['cuda']:
        net.load_state_dict(copyStateDict(torch.load(CONFIG['trained_model'])))
    else:
        net.load_state_dict(
            copyStateDict(
                torch.load(CONFIG['trained_model'], map_location='cpu')))

    if CONFIG['cuda']:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    # LinkRefiner
    refine_net = None
    if CONFIG['refine']:
        from refinenet import RefineNet
        refine_net = RefineNet()
        #print('Loading weights of refiner from checkpoint (' + CONFIG['refiner_model'] + ')')
        if CONFIG['cuda']:
            refine_net.load_state_dict(
                copyStateDict(torch.load(CONFIG['refiner_model'])))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(
                copyStateDict(
                    torch.load(CONFIG['refiner_model'], map_location='cpu')))

        refine_net.eval()
        CONFIG['poly'] = True

    t = time.time()

    # load data
    for k, image_path in enumerate(image_list):
        #print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
        orig, image = imgproc.loadImage(image_path)

        bboxes, polys, score_text = test_net(
            net, image, CONFIG['text_threshold'], CONFIG['link_threshold'],
            CONFIG['low_text'], CONFIG['cuda'], CONFIG['poly'], refine_net)

        # save score text
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        cv2.imwrite(mask_file, score_text)

        file_utils.saveResult(image_path,
                              image[:, :, ::-1],
                              polys,
                              dirname=result_folder)

    information = []
    for file in os.listdir('result/temp_result'):
        filename = os.path.splitext(file)[0]
        extension = os.path.splitext(file)[1]
        if extension == '.tif':
            #!tesseract oem 13 --tessdata-dir ./result/ ./result/temp_result{filename}.png ./test/{filename+'-eng'} -l eng+vie
            image = Image.open('result/temp_result/' + file)

            # The language is already passed via `lang`, so the config string
            # only needs the page-segmentation and engine modes.
            config = '--psm 10 --oem 3'
            raw_text = pytesseract.image_to_string(image,
                                                   lang='eng+vie',
                                                   config=config)
            information.append(raw_text)

    X = {
        "name": [],
        "phone": [],
        "email": [],
        "company": [],
        "website": [],
        "address": [],
        "extra_information": []
    }
    for info in information:
        if parse_info(info):

            email_parse = parse_email(info)
            if email_parse is not None:
                X["email"].append(email_parse)
                continue

            phone_parse = parse_phone(info)
            if phone_parse is not None:
                X["phone"].append(phone_parse)
                continue

            website_parse = parse_website(info)
            if website_parse is not None:
                X["website"].append(website_parse)
                continue

            company_parse = parse_company(info)
            if company_parse is not None:
                X["company"].append(company_parse)
                continue

            address_parse = parse_address(info)
            if address_parse is not None:
                X["address"].append(address_parse)
                continue

            name_parse = parse_name(info)
            if name_parse is not None:
                X["name"].append(info)
                continue

            X["extra_information"].append(info)
    return X
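The parse_* helpers are project-specific and not shown here. As a purely hypothetical illustration (this is not the project's actual code), an email extractor in the same spirit could be:

import re

def parse_email(text):
    # Hypothetical sketch: return the first email-like token, else None.
    match = re.search(r'[\w.+-]+@[\w-]+\.[\w.-]+', text)
    return match.group(0) if match else None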
Code example #3
File: ocr_process.py  Project: anhnhivu/OCR_pipeline
def processing(file_name, crs):
    # path to the pre-trained CRAFT model file
    trained_model_path = './craft_mlt_25k.pth'
    #trained_model_path = './vgg16.ckpt'

    net = CRAFT()
    net.load_state_dict(
        copyStateDict(torch.load(trained_model_path, map_location='cpu')))
    net.eval()

    # Load image from its path
    image_path = f'./imgtxtenh/pre_{file_name}'
    image = imgproc.loadImage(image_path)

    fig2 = plt.figure(figsize=(10, 10))  # create a 10 x 10 figure
    ax3 = fig2.add_subplot(111)
    ax3.imshow(image, interpolation='none')
    ax3.set_title('larger figure')
    plt.show()

    poly = False
    refine = False
    show_time = False
    refine_net = None
    # text_threshold, link_threshold, low_text and cuda are assumed to be
    # module-level settings in this project.
    bboxes, polys, score_text = test_net(net, image, text_threshold,
                                         link_threshold, low_text, cuda, poly,
                                         refine_net)
    file_utils.saveResult(image_path,
                          image[:, :, ::-1],
                          bboxes,
                          dirname='./craft_result/')
    # Compute the central point of each bounding box returned by CRAFT.
    # Purpose: makes clustering easier in the G-DBSCAN step.
    poly_indexes = {}
    central_poly_indexes = []
    for i in range(len(polys)):
        poly_indexes[i] = polys[i]
        x_central = (polys[i][0][0] + polys[i][1][0] + polys[i][2][0] +
                     polys[i][3][0]) / 4
        y_central = (polys[i][0][1] + polys[i][1][1] + polys[i][2][1] +
                     polys[i][3][1]) / 4
        central_poly_indexes.append({i: [int(x_central), int(y_central)]})

    # for i in central_poly_indexes:
    #   print(i)

    # For each of these coordinates, create a new Point instance
    X = []

    for idx, x in enumerate(central_poly_indexes):
        point = Point(x[idx][0], x[idx][1], idx)
        X.append(point)

    # Cluster these central points
    clustered = GDBSCAN(Points(X), n_pred, 1, w_card)

    # Create a 4-point bounding box for each cluster.
    # Purpose: merge the words in one cluster into a single bounding box.
    cluster_values = []
    for cluster in clustered:
        sort_cluster = sorted(cluster, key=lambda elem: (elem.x, elem.y))
        max_point_id = sort_cluster[len(sort_cluster) - 1].id
        min_point_id = sort_cluster[0].id
        max_rectangle = sorted(poly_indexes[max_point_id],
                               key=lambda elem: (elem[0], elem[1]))
        min_rectangle = sorted(poly_indexes[min_point_id],
                               key=lambda elem: (elem[0], elem[1]))

        right_above_max_vertex = max_rectangle[len(max_rectangle) - 1]
        right_below_max_vertex = max_rectangle[len(max_rectangle) - 2]
        left_above_min_vertex = min_rectangle[0]
        left_below_min_vertex = min_rectangle[1]

        if (int(min_rectangle[0][1]) > int(min_rectangle[1][1])):
            left_above_min_vertex = min_rectangle[1]
            left_below_min_vertex = min_rectangle[0]
        if (int(max_rectangle[len(max_rectangle) - 1][1]) < int(
                max_rectangle[len(max_rectangle) - 2][1])):
            right_above_max_vertex = max_rectangle[len(max_rectangle) - 2]
            right_below_max_vertex = max_rectangle[len(max_rectangle) - 1]

        cluster_values.append([
            left_above_min_vertex, left_below_min_vertex,
            right_above_max_vertex, right_below_max_vertex
        ])

        # for p in cluster_values:
        #   print(p)

    file_utils.saveResult(image_path,
                          image[:, :, ::-1],
                          cluster_values,
                          dirname='./cluster_result/')
    img = np.array(image[:, :, ::-1])
    ocr_res = []
    plain_txt = ""
    for i, box in enumerate(cluster_values):
        poly = np.array(box).astype(np.int32).reshape((-1))
        poly = poly.reshape(-1, 2)

        rect = cv2.boundingRect(poly)
        x, y, w, h = rect
        cropped = img[y:y + h, x:x + w].copy()

        # Preprocess the cropped segment: upscale, denoise, binarize
        cropped = cv2.resize(cropped,
                             None,
                             fx=5,
                             fy=5,
                             interpolation=cv2.INTER_LINEAR)
        cropped = cv2.cvtColor(cropped, cv2.COLOR_BGR2GRAY)
        cropped = cv2.GaussianBlur(cropped, (3, 3), 0)
        cropped = cv2.bilateralFilter(cropped, 5, 25, 25)
        cropped = cv2.dilate(cropped, None, iterations=1)
        cropped = cv2.threshold(cropped, 0, 255,
                                cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
        #     cropped = cv2.threshold(cropped, 90, 255, cv2.THRESH_BINARY)[1]
        # The image is single-channel after thresholding, so convert from
        # GRAY (not BGR) back to RGB before handing it to tesseract.
        cropped = cv2.cvtColor(cropped, cv2.COLOR_GRAY2RGB)
        custom_oem_psm_config = r'--oem 1 --psm 12'
        # print("--------")
        # print(pytesseract.image_to_string(cropped, lang='eng'))
        plain_txt += "--------\n"
        plain_txt += pytesseract.image_to_string(cropped,
                                                 lang='eng',
                                                 config=custom_oem_psm_config)

    copy_plain_txt = plain_txt
    # plain_txt = re.sub(r"b", "6", plain_txt)
    plain_txt = re.sub(r"\$", "5", plain_txt)
    plain_txt = re.sub(r"%", "7", plain_txt)
    plain_txt = re.sub(r"Y", "5", plain_txt)
    plain_txt = re.sub(r"W", "99", plain_txt)
    plain_txt = re.sub(r"£", "1", plain_txt)
    plain_txt = re.sub(r"\)", "1", plain_txt)
    plain_txt = re.sub(r"\}", "1", plain_txt)
    plain_txt = re.sub(r"\|", "1", plain_txt)

    # print(plain_txt)
    # return 0
    #Localization
    init_patterns_1 = re.compile(r'TOA\sDO', re.IGNORECASE)
    init_patterns_2 = re.compile(r'\w{0,2}\d{5,}', re.IGNORECASE)
    term_patterns = re.compile(r'\n[^\-\d]{10,}', re.IGNORECASE)
    coor_patterns = re.compile(r'\d+\s*[\d]*\s*[\d\.]*', re.IGNORECASE)
    coordinates = coor_patterns.findall(plain_txt)
    for i in range(len(coordinates)):
        coordinates[i] = re.sub('\n', '', coordinates[i])
        coordinates[i] = re.sub('\x0c', '', coordinates[i])
        coordinates[i] = re.sub(r'\s', '', coordinates[i])
    # print(coordinates)
    # return 0
    temp_arr = coordinates.copy()
    for item in temp_arr:
        try:
            float(item)  # non-numeric entries raise ValueError and are dropped
            if len(item) <= 7:
                coordinates.remove(item)
        except ValueError:
            coordinates.remove(item)
    print(coordinates)

    cluster_arr = [[coor] for coor in coordinates]
    for i in range(len(coordinates)):
        for coor in coordinates:
            # Group coordinates whose first three characters match.
            if cluster_arr[i][0] != coor and cluster_arr[i][0][:3] == coor[:3]:
                cluster_arr[i].append(coor)
    # print(cluster_arr)

    cluster_lens = []
    for cluster in cluster_arr:
        cluster_lens.append(len(cluster))
    # print(cluster_lens)

    try:
        max_len = max(cluster_lens)
    except ValueError:
        max_len = 0
    coor_arr_1 = []
    for cluster in cluster_arr:
        if max_len == len(cluster):
            coor_arr_1 = cluster
            break
    # print(coor_arr_1)

    cluster_arr = []
    for coor in coordinates:
        if coor not in coor_arr_1:
            cluster_arr.append([coor])
    # print(cluster_arr)

    for i in range(len(cluster_arr)):
        for coor in coordinates:
            if (coor not in coor_arr_1 and cluster_arr[i][0] != coor
                    and cluster_arr[i][0][:3] == coor[:3]):
                cluster_arr[i].append(coor)
    # print(cluster_arr)

    cluster_lens = []
    for cluster in cluster_arr:
        cluster_lens.append(len(cluster))
    # print(cluster_lens)

    try:
        max_len = max(cluster_lens)
    except ValueError:
        max_len = 0

    # print(cluster_arr)
    coor_arr_2 = []
    similar_cluster_arr = []
    temp = 0
    for cluster in cluster_arr:
        if max_len == len(cluster):
            temp += 1
            coor_arr_2 = cluster
            similar_cluster_arr.append(cluster)
    if temp > 1:
        similar_val_arr = []
        for cluster in similar_cluster_arr:
            similar_val_arr.append(similar_value(cluster, coor_arr_1))
        right_index = np.where(
            similar_val_arr == np.amin(similar_val_arr))[0][0]
        coor_arr_2 = similar_cluster_arr[right_index]
    # print(coor_arr_2)

    temp_lst = []

    if len(eliminate(coor_arr_1, temp_lst)) != 0:
        coor_arr_1 = eliminate(coor_arr_1, temp_lst)
    else:
        insert_point(coor_arr_1)
    # print('Arr 1 after remove:')
    # print(coor_arr_1)

    if len(eliminate(coor_arr_2, temp_lst)) != 0:
        coor_arr_2 = eliminate(coor_arr_2, temp_lst)
    else:
        insert_point(coor_arr_2)
    # print('Arr 2 after remove:')
    # print(coor_arr_2)

    X = []
    Y = []

    if findX(coor_arr_1, coordinates) > findX(coor_arr_2, coordinates):
        X = coor_arr_1
        Y = coor_arr_2
    else:
        X = coor_arr_2
        Y = coor_arr_1

    print('X: ' + str(X))
    print('Y: ' + str(Y))

    temp_arr = []
    for coor in X:
        try:
            float(coor)
            temp_arr.append(float(coor))
        except ValueError:
            pass
    X = temp_arr
    temp_arr = []
    for coor in Y:
        try:
            float(coor)
            temp_arr.append(float(coor))
        except ValueError:
            pass
    Y = temp_arr

    sim_arr = str_similarity(X, coordinates)
    sim_arr = np.array(sim_arr)
    try:
        optimal_index = np.where(sim_arr == np.amax(sim_arr))[0][0]
        x = X[optimal_index]
    except ValueError:
        x = 0

    sim_arr = str_similarity(Y, coordinates)
    sim_arr = np.array(sim_arr)
    try:
        optimal_index = np.where(sim_arr == np.amax(sim_arr))[0][0]
        y = Y[optimal_index]
    except ValueError:
        y = 0

    print('Most likely to be x: ' + str(x))
    print('Most likely to be y: ' + str(y))

    #################### VN2K TO WGS84 ####################

    y, x = vn2k_to_wgs84((x, y), crs)
    print((x, y))
    return (x, y)


# processing('test_16.jpg', 9210)
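vn2k_to_wgs84 is imported from elsewhere in the project. A plausible reconstruction with pyproj (the EPSG handling here is an assumption, not the project's verified code):

from pyproj import Transformer

def vn2k_to_wgs84(point, crs):
    # Assumption: `crs` is the EPSG code of the local VN-2000 zone
    # (e.g. 9210) and `point` is an (x, y) pair in that zone.
    transformer = Transformer.from_crs(f"EPSG:{crs}", "EPSG:4326")
    return transformer.transform(point[0], point[1])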
Code example #4
class NpPointsCraft(object):
    """
    NpPointsCraft Class
    git clone https://github.com/clovaai/CRAFT-pytorch.git
    """
    def __init__(self, **args):
        pass

    @classmethod
    def get_classname(cls):
        return cls.__name__

    def load(self, mtl_model_path="latest", refiner_model_path="latest"):
        """
        TODO: describe method
        """
        if mtl_model_path == "latest":
            model_info = download_latest_model(self.get_classname(),
                                               "mtl",
                                               ext="pth",
                                               mode=get_mode_torch())
            mtl_model_path = model_info["path"]
        if refiner_model_path == "latest":
            model_info = download_latest_model(self.get_classname(),
                                               "refiner",
                                               ext="pth",
                                               mode=get_mode_torch())
            refiner_model_path = model_info["path"]
        device = "cpu"
        if get_mode_torch() == "gpu":
            device = "cuda"
        self.loadModel(device, True, mtl_model_path, refiner_model_path)

    def loadModel(self,
                  device="cuda",
                  is_refine=True,
                  trained_model=os.path.join(CRAFT_DIR,
                                             'weights/craft_mlt_25k.pth'),
                  refiner_model=os.path.join(
                      CRAFT_DIR, 'weights/craft_refiner_CTW1500.pth')):
        """
        TODO: describe method
        """
        is_cuda = device == "cuda"
        self.is_cuda = is_cuda

        # load net
        self.net = CRAFT()  # initialize

        print('Loading weights from checkpoint (' + trained_model + ')')
        if is_cuda:
            self.net.load_state_dict(copyStateDict(torch.load(trained_model)))
        else:
            self.net.load_state_dict(
                copyStateDict(torch.load(trained_model, map_location='cpu')))

        if is_cuda:
            self.net = self.net.cuda()
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = False

        self.net.eval()

        # LinkRefiner
        self.refine_net = None
        if is_refine:
            from refinenet import RefineNet
            self.refine_net = RefineNet()
            print('Loading weights of refiner from checkpoint (' +
                  refiner_model + ')')
            if is_cuda:
                self.refine_net.load_state_dict(
                    copyStateDict(torch.load(refiner_model)))
                self.refine_net = self.refine_net.cuda()
                self.refine_net = torch.nn.DataParallel(self.refine_net)
            else:
                self.refine_net.load_state_dict(
                    copyStateDict(torch.load(refiner_model,
                                             map_location='cpu')))

            self.refine_net.eval()
            self.is_poly = True

    def detectByImagePath(self,
                          image_path,
                          targetBoxes,
                          qualityProfile=[1, 0, 0],
                          debug=False):
        """
        TODO: describe method
        """
        image = imgproc.loadImage(image_path)
        for targetBox in targetBoxes:
            x = min(targetBox['x1'], targetBox['x2'])
            w = abs(targetBox['x2'] - targetBox['x1'])
            y = min(targetBox['y1'], targetBox['y2'])
            h = abs(targetBox['y2'] - targetBox['y1'])
            #print('x: {} w: {} y: {} h: {}'.format(x,w,y,h))
            image_part = image[y:y + h, x:x + w]
            points = self.detectInBbox(image_part)
            probablyPoints = addCoordinatesOffset(points, x, y)
            targetBox['points'] = []
            targetBox['imgParts'] = []
            if len(probablyPoints):
                targetPointsVariants = makeRectVariants2(
                    probablyPoints, h, w, qualityProfile)
                # targetBox['points'] = addCoordinatesOffset(points, x, y)
                # targetPointsVariants = [targetPoints, fixSideFacets(targetPoints)]
                if len(targetPointsVariants) > 1:
                    imgParts = [
                        getCvZoneRGB(image, reshapePoints(rect, 1))
                        for rect in targetPointsVariants
                    ]
                    idx = detectBestPerspective(
                        normalizePerspectiveImages(imgParts))
                    print('--------------------------------------------------')
                    print('idx={}'.format(idx))
                    #targetBox['points'] = addoptRectToBbox2(targetPointsVariants[idx], image.shape,x,y)
                    targetBox['points'] = targetPointsVariants[idx]
                    targetBox['imgParts'] = imgParts
                else:
                    targetBox['points'] = targetPointsVariants[0]
        return targetBoxes, image

    def detect(self,
               image,
               targetBoxes,
               qualityProfile=[1, 0, 0],
               debug=False):
        """
        TODO: describe method
        """
        all_points = []
        for targetBox in targetBoxes:
            x = int(min(targetBox[0], targetBox[2]))
            w = int(abs(targetBox[2] - targetBox[0]))
            y = int(min(targetBox[1], targetBox[3]))
            h = int(abs(targetBox[3] - targetBox[1]))

            image_part = image[y:y + h, x:x + w]
            probablyPoints = addCoordinatesOffset(
                self.detectInBbox(image_part), x, y)
            points = []
            if len(probablyPoints):
                targetPointsVariants = makeRectVariants2(
                    probablyPoints, h, w, qualityProfile)
                if len(targetPointsVariants) > 1:
                    imgParts = [
                        getCvZoneRGB(image, reshapePoints(rect, 1))
                        for rect in targetPointsVariants
                    ]
                    idx = detectBestPerspective(
                        normalizePerspectiveImages(imgParts))
                    points = targetPointsVariants[idx]
                else:
                    points = targetPointsVariants[0]
                all_points.append(points)
            else:
                all_points.append([[x, y + h], [x, y], [x + w, y],
                                   [x + w, y + h]])
        return all_points

    def detectInBbox(self, image, debug=False):
        """
        TODO: describe method
        """
        low_text = 0.4
        link_threshold = 0.7  # 0.4
        text_threshold = 0.6
        canvas_size = 1280
        mag_ratio = 1.5

        t = time.time()
        bboxes, polys, score_text = test_net(self.net, image, text_threshold,
                                             link_threshold, low_text,
                                             self.is_cuda, self.is_poly,
                                             canvas_size, self.refine_net,
                                             mag_ratio)
        if debug:
            print("elapsed time : {}s".format(time.time() - t))
        dimensions = []
        for poly in bboxes:
            dimensions.append({
                'dx': distance(poly[0], poly[1]),
                'dy': distance(poly[1], poly[2])
            })

        if debug:
            print(score_text.shape)
            # print(polys)
            print(dimensions)
            print(bboxes)

        np_bboxes_idx, garbage_bboxes_idx = split_boxes(bboxes, dimensions)

        targetPoints = []
        if debug:
            print('np_bboxes_idx')
            print(np_bboxes_idx)
            print('garbage_bboxes_idx')
            print(garbage_bboxes_idx)
            # raw_boxes/raw_polys do not exist in this scope; print the
            # local detection results instead.
            print('bboxes')
            print(bboxes)
            print('polys')
            print(polys)

        if len(np_bboxes_idx) == 1:
            targetPoints = bboxes[np_bboxes_idx[0]]

        if len(np_bboxes_idx) > 1:
            targetPoints = minimum_bounding_rectangle(
                np.concatenate([bboxes[i] for i in np_bboxes_idx], axis=0))

        imgParts = []
        if len(np_bboxes_idx) > 0:
            targetPoints = normalizeRect(targetPoints)
            if debug:
                print('###################################')
                print(targetPoints)
                print('image.shape')
                print(image.shape)
            #targetPoints = fixSideFacets(targetPoints, image.shape)
            targetPoints = addoptRectToBbox(targetPoints, image.shape, 7, 12,
                                            0, 12)
        return targetPoints
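A typical call sequence for this class, inferred from the method signatures above (the image path and box values are placeholders):

npPointsCraft = NpPointsCraft()
npPointsCraft.load()  # fetches the latest mtl and refiner checkpoints
image = imgproc.loadImage('frame.jpg')  # placeholder path
# detect() expects [x1, y1, x2, y2] boxes and returns one point set per box
all_points = npPointsCraft.detect(image, [[10, 20, 310, 120]])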
Code example #5
def analysis(image_path, result_path):
    """ For test images in a folder """
    net = CRAFT()     # initialize

    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
    else:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    # LinkRefiner
    refine_net = None
    if args.refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
        if args.cuda:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))

        refine_net.eval()
        args.poly = True

    t = time.time()

    image = imgproc.loadImage(image_path)

    bboxes, polys, score_text = test_net(net,
                                         image,
                                         args.text_threshold,
                                         args.link_threshold,
                                         args.low_text,
                                         args.cuda,
                                         args.poly,
                                         refine_net)
                                         
    opencv_image = cv2.imread(image_path)
    
    for index, box in enumerate(polys):
        xmin, xmax, ymin, ymax = box[0][0], box[1][0], box[0][1], box[2][1]
        multiplier_area = image[int(ymin):int(ymax), int(xmin):int(xmax)]
        
        try:
            im_pil = Image.fromarray(multiplier_area)
            # if you want to recognize the text inside the detected box
            if args.ocr_on:
                if args.ocr_method == 'pytesseract':
                    configuration = ("-l eng --oem 1 --psm 7")
                    multiplier = (pytesseract.image_to_string(multiplier_area, config=configuration).lower())
                    multiplier = multiplier.split("\n")[0]
                    
                elif args.ocr_method == 'TPS-ResNet-BiLSTM':
                    multiplier = text_recognition.recognition(im_pil)
                    
                cv2.putText(opencv_image, multiplier, (int(xmin),int(ymin-5)), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0,0,255), 1)
                
            cv2.rectangle(opencv_image,(int(xmin),int(ymin)), (int(xmax),int(ymax)),(0,0,255), 1)
            cv2.imwrite(result_path, opencv_image)
                
        except Exception:
            print("====ERROR====", traceback.format_exc())
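The xmin/xmax/ymin/ymax unpacking above assumes each detected quad is axis-aligned. A safer variant (a sketch, assuming `box` is a 4x2 array of vertices) takes the min/max over all four points:

import numpy as np

box = np.asarray(box)
xmin, ymin = box[:, 0].min(), box[:, 1].min()
xmax, ymax = box[:, 0].max(), box[:, 1].max()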
Code example #6
class TextDetection:
    def __init__(self):

        self.trained_model = '../chinese-ocr/weights/craft_mlt_25k.pth'
        self.text_threshold = 0.75
        self.low_text = 0.6
        self.link_threshold = 0.9
        self.cuda = True
        self.canvas_size = 1280
        self.mag_ratio = 1.5
        self.poly = False
        self.show_time = False

        self.net = CRAFT()

        self.net.load_state_dict(
            copy_state_dict(torch.load(self.trained_model)))
        self.net = self.net.cuda()

        cudnn.benchmark = False

        self.net.eval()

    def get_bounding_box(self, image_file, verbose=False):
        """
        Get the bounding boxes from image_file
        :param image_file
        :param verbose
        :return:
        """
        image = cv2.imread(image_file)
        img_dim = image.shape
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
            image,
            self.canvas_size,
            interpolation=cv2.INTER_LINEAR,
            mag_ratio=self.mag_ratio)

        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]; Variable is a no-op in PyTorch >= 0.4
        if self.cuda:
            x = x.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = self.net(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                               self.text_threshold,
                                               self.link_threshold,
                                               self.low_text, self.poly)

        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)

        center_point = []
        for i, _b in enumerate(boxes):
            b = np.array(_b, dtype=np.int16)
            xmin = np.min(b[:, 0])
            ymin = np.min(b[:, 1])

            xmax = np.max(b[:, 0])
            ymax = np.max(b[:, 1])
            x_m = xmin + (xmax - xmin) / 2
            y_m = ymin + (ymax - ymin) / 2
            center_point.append([x_m, y_m])

        list_images = get_box_img(boxes, image)

        if verbose:
            for _b in boxes:
                b = np.array(_b, dtype=np.int16)
                xmin = np.min(b[:, 0])
                ymin = np.min(b[:, 1])

                xmax = np.max(b[:, 0])
                ymax = np.max(b[:, 1])

                # Crop computed for inspection only; it is not used further.
                r = image[ymin:ymax, xmin:xmax, :].copy()

        return boxes, list_images, center_point, img_dim
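A minimal usage sketch for this class (the file name is a placeholder):

detector = TextDetection()
boxes, crops, centers, img_dim = detector.get_bounding_box('sample.jpg')
print(len(boxes), 'text regions found; image shape:', img_dim)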