コード例 #1
0
def text_detector(image, count):
    """Run CRAFT text detection on one image and save the cropped regions.

    The detector (and, when ``args.refine`` is set, the LinkRefiner) is built
    and loaded from the checkpoints named in the module-level ``args`` on every
    call. The polygons returned by ``test_net`` are handed to ``crop_object``,
    which receives ``result_folder`` and ``count``.

    NOTE(review): enabling ``args.refine`` sets the global ``args.poly`` to
    True as a side effect, exactly like the original code.

    Args:
        image: image array passed to ``test_net``; its last axis is reversed
            before cropping (presumably RGB -> BGR for OpenCV — confirm
            against ``crop_object``).
        count: value forwarded to ``crop_object``.
    """
    # torch.load's default map_location is None, so one call covers both the
    # CUDA case (default) and the CPU case ('cpu').
    device_map = None if args.cuda else 'cpu'

    detector = CRAFT()  # initialize
    print('Loading weights from checkpoint (' + args.trained_model + ')')
    detector.load_state_dict(
        copyStateDict(torch.load(args.trained_model, map_location=device_map)))

    if args.cuda:
        detector = torch.nn.DataParallel(detector.cuda())
        cudnn.benchmark = False
    detector.eval()

    # Optional LinkRefiner; imported lazily so it is only required when used.
    refiner = None
    if args.refine:
        from refinenet import RefineNet
        refiner = RefineNet()
        print('Loading weights of refiner from checkpoint (' +
              args.refiner_model + ')')
        refiner.load_state_dict(
            copyStateDict(
                torch.load(args.refiner_model, map_location=device_map)))
        if args.cuda:
            refiner = torch.nn.DataParallel(refiner.cuda())
        refiner.eval()
        args.poly = True

    started = time.time()

    _, polys, _ = test_net(detector, image, args.text_threshold,
                           args.link_threshold, args.low_text,
                           args.cuda, args.poly, refiner)

    # The crops are what the downstream OCR stage consumes.
    crop_object(image[:, :, ::-1], polys, result_folder, count)

    print("elapsed time : {}s".format(time.time() - started))
コード例 #2
0
ファイル: train_detector.py プロジェクト: aerjayc/asstr
def init_model(weight_dir, weight_path=None, num_class=2, linear=True):
    """Build a CUDA CRAFT model together with its loss and optimizer.

    Args:
        weight_dir: directory for weights; created (with parents) if missing.
        weight_path: optional checkpoint to load. When given, the model is
            also switched to eval mode.
        num_class: forwarded to ``CRAFT``.
        linear: forwarded to ``CRAFT``.

    Returns:
        ``(model, criterion, optimizer)``: the CRAFT model on CUDA, an
        ``nn.MSELoss`` criterion, and SGD with ``lr=0.001``.
    """
    Path(weight_dir).mkdir(parents=True, exist_ok=True)

    # Per the original notes: input layout is NCHW, output layout is NHWC.
    model = CRAFT(pretrained=True, num_class=num_class, linear=linear).cuda()

    if weight_path:
        checkpoint = torch.load(weight_path)
        model.load_state_dict(checkpoint)
        model.eval()

    optimizer = optim.SGD(model.parameters(), lr=0.001)  # tweak parameters
    return model, nn.MSELoss(), optimizer
コード例 #3
0
def init_craft_net(args):
    """Create the CRAFT detector and load its checkpoint.

    Args:
        args: namespace providing ``trained_model`` (checkpoint path) and
            ``cuda`` (whether to run on the GPU).

    Returns:
        The network in eval mode; on CUDA it is moved to the GPU and wrapped
        in ``torch.nn.DataParallel``.
    """
    print("-" * 50)
    print("init_craft_net")
    net = CRAFT()     # initialize
    print('CRAFT Loading weights from checkpoint (' + args.trained_model + ')')
    # map_location=None is torch.load's default, i.e. the original CUDA path.
    location = None if args.cuda else 'cpu'
    state = torch.load(args.trained_model, map_location=location)
    net.load_state_dict(copyStateDict(state))

    if args.cuda:
        net = torch.nn.DataParallel(net.cuda())
        cudnn.benchmark = False
    net.eval()
    return net
コード例 #4
0
    def init_craft_net(self):
        """Initialize the text-region detection network (CRAFT).

        Reads ``self.args.trained_model`` (checkpoint path) and
        ``self.args.cuda``.

        Returns:
            The network in eval mode; on CUDA it is moved to the GPU and
            wrapped in ``torch.nn.DataParallel``.
        """
        opts = self.args
        net = CRAFT()     # initialize
        print('CRAFT Loading weights from checkpoint (' + opts.trained_model + ')')
        # map_location=None is torch.load's default (the original CUDA path).
        weights = torch.load(opts.trained_model,
                             map_location=None if opts.cuda else 'cpu')
        net.load_state_dict(copyStateDict(weights))

        if not opts.cuda:
            net.eval()
            return net

        net = torch.nn.DataParallel(net.cuda())
        cudnn.benchmark = False
        net.eval()
        return net
def test_craft(objetos, dir_pth="craft_mlt_25k.pth"):
    """Detect text regions in every image of every detected object.

    Args:
        objetos: sequence of dicts, each with an ``"id_obj"`` key and an
            ``"objetos"`` key; the latter is a list of dicts holding
            ``"img_path"`` and ``"bb"`` (passed to
            ``recuperar_crop_imagen_bb``).
        dir_pth: path to the CRAFT checkpoint.

    Returns:
        dict mapping each ``id_obj`` to the list of cropped text images found
        across all of that object's images.
    """
    net = CRAFT()  # initialize
    print('Loading weights from checkpoint (' + dir_pth + ')')
    # map_location=None is torch.load's default, i.e. the original CUDA path.
    net.load_state_dict(
        copyStateDict(
            torch.load(dir_pth, map_location=None if args.cuda else 'cpu')))
    if args.cuda:
        net = torch.nn.DataParallel(net.cuda())
        cudnn.benchmark = False
    net.eval()

    t = time.time()
    imgs_text_obj = {}
    # Analyze the detected objects
    for objeto in objetos:
        imagenes = objeto["objetos"]
        id_obj = objeto["id_obj"]
        # Collect the text crops from every image that makes up the object
        crops_text = []
        for img in imagenes:
            img_crop = np.array(
                recuperar_crop_imagen_bb(img["img_path"], img["bb"]))
            image = imgproc.loadImage(img_crop)
            # detection (polys and score map are not needed here)
            bboxes, _, _ = test_net(net, image, args.text_threshold,
                                    args.link_threshold, args.low_text,
                                    args.cuda, args.poly, None)
            for bbs in bboxes:
                crop = bounding_box(bbs)
                # Clamp to 0: a negative start index would wrap around in
                # numpy slicing and silently return the wrong region.
                x0, y0 = max(crop[0][0], 0), max(crop[0][1], 0)
                crops_text.append(image[y0:crop[1][1], x0:crop[1][0]])
        imgs_text_obj[id_obj] = crops_text
    print("elapsed time : {}s".format(time.time() - t))
    return imgs_text_obj
コード例 #6
0
    def __init__(self,
                 img_dir,
                 gt_dir,
                 weights_path=None,
                 color_flag=1,
                 character_map=True,
                 affinity_map=False,
                 word_map=False,
                 direction_map=True):
        """Wrap an ICDAR2015 dataset together with a CUDA CRAFT model.

        Args:
            img_dir: image directory, forwarded to ``ICDAR2015Dataset``.
            gt_dir: ground-truth directory, forwarded to ``ICDAR2015Dataset``.
            weights_path: optional checkpoint; when given it is loaded and the
                model is put in eval mode.
            color_flag: forwarded to ``ICDAR2015Dataset``.
            character_map: adds 1 to ``num_class`` when true.
            affinity_map: adds 1 to ``num_class`` when true.
            word_map: adds 1 to ``num_class`` when true.
            direction_map: adds 2 to ``num_class`` when true.
        """
        self.raw_dataset = ICDAR2015Dataset(img_dir, gt_dir,
                                            color_flag=color_flag)

        # The flags are summed as 0/1 values; the direction map counts twice.
        self.num_class = (character_map + affinity_map + word_map
                          + 2 * direction_map)

        net = CRAFT(pretrained=False, num_class=self.num_class).cuda()
        if weights_path:
            net.load_state_dict(torch.load(weights_path))
            net.eval()
        self.model = net
def test_craft_prueba(ruta, DIR_PTH="craft_mlt_25k.pth"):
    """Detect text in every entry of a directory and return the crops.

    Args:
        ruta: directory whose entries (in ``os.listdir`` order) are loaded
            with ``recuperar_imagen_dir``.
        DIR_PTH: path to the CRAFT checkpoint.

    Returns:
        list of cropped text images found across all entries.
    """
    net = CRAFT()  # initialize
    print('Loading weights from checkpoint (' + DIR_PTH + ')')
    # map_location=None is torch.load's default, i.e. the original CUDA path.
    net.load_state_dict(
        copyStateDict(
            torch.load(DIR_PTH, map_location=None if args.cuda else 'cpu')))
    if args.cuda:
        net = torch.nn.DataParallel(net.cuda())
        cudnn.benchmark = False
    net.eval()

    t = time.time()
    paths = [os.path.join(ruta, name) for name in os.listdir(ruta)]
    crops_text = []
    for path in paths:
        img_crop = np.array(recuperar_imagen_dir(path))
        image = imgproc.loadImage(img_crop)
        # detection (polys and score map are not needed here)
        bboxes, _, _ = test_net(net, image, args.text_threshold,
                                args.link_threshold, args.low_text,
                                args.cuda, args.poly, None)
        for bbs in bboxes:
            crop = bounding_box(bbs)
            # Clamp to 0: a negative start index would wrap around in numpy
            # slicing and silently return the wrong region.
            x0, y0 = max(crop[0][0], 0), max(crop[0][1], 0)
            crops_text.append(image[y0:crop[1][1], x0:crop[1][0]])
    print("elapsed time : {}s".format(time.time() - t))
    return crops_text
コード例 #8
0
class Localization:
    """Text localizer built on a CRAFT model.

    Args:
        weights_path: CRAFT checkpoint. It is always loaded onto the CPU first
            and only moved to the GPU afterwards when *gpu* is true.
        canvas_size: passed to ``imgproc.resize_aspect_ratio``.
        text_threshold: passed to ``craft_utils.getDetBoxes``.
        link_threshold: passed to ``craft_utils.getDetBoxes``.
        mag_ratio: passed to ``imgproc.resize_aspect_ratio``.
        low_text: passed to ``craft_utils.getDetBoxes``.
        gpu: run inference on CUDA, with the model wrapped in
            ``torch.nn.DataParallel``.
    """

    def __init__(self,
                 weights_path='./craft/weights/craft_mlt_25k.pth',
                 canvas_size=1280,
                 text_threshold=0.9,
                 link_threshold=0.1,
                 mag_ratio=1.5,
                 low_text=0.5,
                 gpu=True):
        self.canvas_size = canvas_size
        self.text_threshold = text_threshold
        self.link_threshold = link_threshold
        self.low_text = low_text
        self.mag_ratio = mag_ratio
        self.gpu = gpu
        self.model = CRAFT()
        self.model.load_state_dict(
            self.copy_state_dict(torch.load(weights_path, map_location='cpu')))
        if self.gpu:
            self.model.cuda()
            self.model = torch.nn.DataParallel(self.model)
        self.model.eval()

    @staticmethod
    def copy_state_dict(state_dict):
        """Return a copy of *state_dict* with a leading key component dropped.

        If the first key starts with ``"module"`` (as keys saved from a
        ``DataParallel``-wrapped model do), the first dot-separated component
        is removed from every key; otherwise keys are copied unchanged.
        An empty mapping yields an empty ``OrderedDict``.
        """
        # next(iter(...), "") avoids the IndexError list(...)[0] raised on {}.
        first_key = next(iter(state_dict), "")
        start_idx = 1 if first_key.startswith("module") else 0
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            new_state_dict[".".join(k.split(".")[start_idx:])] = v
        return new_state_dict

    def predict(self, image):
        """Detect text boxes in *image*.

        Returns:
            integer ndarray reshaped to ``(-1, 8)`` — one row per box,
            presumably four (x, y) corner points in original-image
            coordinates (confirm against ``craft_utils``). The array is also
            printed.
        """
        img_resized, target_ratio, _ = imgproc.resize_aspect_ratio(
            image,
            self.canvas_size,
            interpolation=cv2.INTER_AREA,
            mag_ratio=self.mag_ratio)
        ratio_h = ratio_w = 1 / target_ratio

        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = x.unsqueeze(0)  # [c, h, w] to [b, c, h, w]
        if self.gpu:
            x = x.cuda()

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            y, _ = self.model(x)

        score_text = y[0, :, :, 0].cpu().numpy()
        score_link = y[0, :, :, 1].cpu().numpy()

        boxes = craft_utils.getDetBoxes(score_text, score_link,
                                        self.text_threshold,
                                        self.link_threshold, self.low_text)
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)

        # `np.int` (an alias of the builtin int) was removed in NumPy 1.24, and
        # the `newshape=` keyword is deprecated; use the builtin and positional.
        boxes = np.reshape(boxes, (-1, 8)).astype(int)
        print(boxes)
        return boxes
コード例 #9
0
ファイル: recognizer.py プロジェクト: aerjayc/asstr
def export_task3(img_dir,
                 classifier_weight_path,
                 detector_weight_path,
                 alphabet=datasets.ALPHANUMERIC,
                 submit_dir='submit',
                 expand_factor=3,
                 thresh=0.3,
                 verbose=True,
                 gt_path=None,
                 min_side_length=16):
    """Write a cropped-word recognition submission to ``<submit_dir>/task3.txt``.

    Every ``.png`` in *img_dir* is run through the CRAFT detector; the first
    output channel is used as a character heatmap, ``map_to_crop`` cuts out
    character crops, and the classifier turns them into characters. One line
    ``<res_name>, "<transcription>"`` is written per image, where ``res_name``
    is the image name without its first five characters. Unknown characters
    (``None``) become ``?``; ``"`` and ``\\`` are backslash-escaped.

    Args:
        img_dir: directory containing the word images.
        classifier_weight_path: checkpoint for ``CharClassifier``.
        detector_weight_path: checkpoint for the CRAFT detector.
        alphabet: characters the classifier predicts; its length sets the
            classifier's ``num_classes``.
        submit_dir: output directory, created if missing.
        expand_factor: forwarded to ``map_to_crop``.
        thresh: forwarded to ``map_to_crop``.
        verbose: print each result line (plus the ground truth, if found).
        gt_path: optional ground-truth file, used only for verbose output.
        min_side_length: images whose shorter side is below this are upscaled
            (aspect ratio kept) so the shorter side reaches roughly this size.
    """
    # make submit dirs
    Path(submit_dir).mkdir(parents=True, exist_ok=True)

    # instantiate models
    detector = CRAFT(pretrained=True, num_class=2, linear=True).cuda()
    detector.load_state_dict(torch.load(detector_weight_path))
    detector.eval()

    classifier = CharClassifier(num_classes=len(alphabet)).double().cuda()
    classifier.load_state_dict(torch.load(classifier_weight_path))
    classifier.eval()

    # ground truth is only used to annotate verbose output
    gt = ""
    if gt_path:
        with open(gt_path, 'r') as f:
            gt = f.read()

    # characters that must be escaped inside the quoted transcription
    escapes = {'"': r'\"', '\\': r'\\'}

    submit_path = os.path.join(submit_dir, "task3.txt")
    img_names = get_filenames(img_dir, ['.png'], recursive=False)
    # 'w' truncates any previous submission; the file then stays open for the
    # whole run instead of being reopened in append mode once per image.
    with open(submit_path, 'w') as submit_file:
        for img_name in img_names:
            res_name = img_name[5:]
            img = PIL.Image.open(os.path.join(img_dir, img_name)).convert('RGB')

            # upscale tiny images so the shorter side reaches min_side_length
            w, h = img.width, img.height
            short_side = min(h, w)
            if short_side < min_side_length:
                ratio = min_side_length / short_side
                img = img.resize((int(ratio * w), int(ratio * h)))

            # PIL.Image -> torch.Tensor (1, C, H, W) with values in [-1, 1]
            img = ((np.array(img) / 255.0) - 0.5) * 2
            img = torch.from_numpy(img).permute(2, 0, 1)[None, ...].cuda()

            # get heatmaps
            with torch.no_grad():
                output, _ = detector(img.float())
                char_heatmap = output[0, :, :, :1]  # (H, W, 1)

            cropped_chars, _, _, _ = map_to_crop(img[0],
                                                 char_heatmap,
                                                 thresh=thresh,
                                                 expand_factor=expand_factor)
            if cropped_chars is None:
                string = ""
            else:
                with torch.no_grad():
                    onehots = classifier(cropped_chars.double())
                    chars = onehot_to_chars(onehots, alphabet)
                # characters come out in the same order as the crops
                string = "".join('?' if c is None else escapes.get(c, c)
                                 for c in chars)

            contents = f'{res_name}, "{string}"\n'
            if verbose:
                line = contents[:-1]
                if gt:
                    pattern = re.compile(re.escape(img_name) + r',\ "(.*)"')
                    matches = re.search(pattern, gt)
                    if matches:
                        line += f'\tgt: "{matches.group(1)}"'
                print(line)

            submit_file.write(contents)

    print("Done!")
コード例 #10
0
ファイル: recognizer.py プロジェクト: aerjayc/asstr
def export_task1_4(test_img_dir,
                   test_gt_dir,
                   classifier_weight_path,
                   detector_weight_path,
                   alphabet=datasets.ALPHANUMERIC,
                   submit_dir='submit',
                   thresh=0.3,
                   max_size=(700, 700)):
    """Write ICDAR2013 task 1 (localization) and task 4 (end-to-end) results.

    For each image of ``datasets.ICDAR2013TestDataset``, ``recognizer`` yields
    word boxes and transcriptions. One ``res_<stem>.txt`` file is written into
    ``<submit_dir>/text_localization_output`` (``xmin,ymin,xmax,ymax`` rows)
    and one into ``<submit_dir>/end_to_end_output`` (four corners plus the
    transcription). If no words are found, both files are written empty.

    Args:
        test_img_dir: image directory, forwarded to the dataset.
        test_gt_dir: ground-truth directory, forwarded to the dataset.
        classifier_weight_path: checkpoint for ``CharClassifier``.
        detector_weight_path: checkpoint for the CRAFT detector.
        alphabet: characters the classifier predicts.
        submit_dir: root output directory; sub-directories are created.
        thresh: forwarded to ``recognizer``.
        max_size: ``(width, height)``; images with more pixels than this are
            resized to exactly this size (aspect ratio NOT kept) before
            recognition, and the boxes are scaled back afterwards. ``None``
            disables resizing.
    """
    testset = datasets.ICDAR2013TestDataset(test_gt_dir, test_img_dir)

    # make submit dirs
    submit_task1_dirpath = os.path.join(submit_dir, 'text_localization_output')
    submit_task4_dirpath = os.path.join(submit_dir, 'end_to_end_output')
    for dirpath in (submit_task1_dirpath, submit_task4_dirpath):
        Path(dirpath).mkdir(parents=True, exist_ok=True)

    # instantiate models
    detector = CRAFT(pretrained=True, num_class=2, linear=True).cuda()
    detector.load_state_dict(torch.load(detector_weight_path))
    detector.eval()

    classifier = CharClassifier(num_classes=len(alphabet)).double().cuda()
    classifier.load_state_dict(torch.load(classifier_weight_path))
    classifier.eval()

    # get predictions for all test images
    for img, _, _, pil_img in testset:
        filename = os.path.basename(pil_img.filename)
        # splitext handles any extension length (the old filename[:-4] assumed
        # a 3-character extension such as ".jpg").
        submit_name = f"res_{os.path.splitext(filename)[0]}.txt"
        task1_path = os.path.join(submit_task1_dirpath, submit_name)
        task4_path = os.path.join(submit_task4_dirpath, submit_name)
        orig_w, orig_h = pil_img.width, pil_img.height

        # squish to max_size when the pixel count exceeds it; the boxes are
        # unsquished below using the recorded size
        if (max_size is not None) and (orig_w * orig_h >
                                       max_size[0] * max_size[1]):
            img = img.permute(1, 2, 0).cpu().numpy()
            img = cv2.resize(img, dsize=tuple(max_size))  # HWC
            img = torch.from_numpy(img).permute(2, 0, 1)  # CHW
            new_w, new_h = max_size
        else:
            new_w, new_h = orig_w, orig_h

        print(filename + '... ', end='')
        wordBBs, words = recognizer(img,
                                    detector,
                                    classifier,
                                    alphabet,
                                    thresh=thresh)
        if wordBBs is None:  # write blank files
            print("No characters found. Making blank file... ", end='')
            for path in (task1_path, task4_path):
                Path(path).write_text("")
            continue

        # presumably wordBBs is (N, 4, 2) x/y points — confirm in recognizer
        unsquish_factor = np.array([orig_w / new_w,
                                    orig_h / new_h]).reshape(1, 1, 2)
        wordBBs = unsquish_factor * wordBBs

        print('Formatting submission... ', end='')
        xymin = np.min(wordBBs, axis=1).astype('int32')
        xymax = np.max(wordBBs, axis=1).astype('int32')
        xmin, ymin = xymin[:, 0], xymin[:, 1]
        xmax, ymax = xymax[:, 0], xymax[:, 1]

        # task 1 (text localization): axis-aligned boxes
        submission = pd.DataFrame({'xmin': xmin, 'ymin': ymin,
                                   'xmax': xmax, 'ymax': ymax})
        submission.to_csv(task1_path, header=False, index=False)

        # task 4 (end to end recognition): corners (xmin, ymin),
        # (xmax, ymin), (xmax, ymax), (xmin, ymax) plus transcription
        submission = pd.DataFrame({
            'x1': xmin,
            'y1': ymin,
            'x2': xmax,
            'y2': ymin,
            'x3': xmax,
            'y3': ymax,
            'x4': xmin,
            'y4': ymax,
            'transcription': words
        })
        submission.to_csv(task4_path, header=False, index=False)

        print('Done.')

    print("Done!")
コード例 #11
0
ファイル: test.py プロジェクト: xianyuntang/snapnews-loc
    return boxes, ret_score_text


if __name__ == '__main__':
    # load net
    net = CRAFT()  # initialize

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    # Map tensors onto the CPU when CUDA is disabled so that a GPU-saved
    # checkpoint still loads on a CPU-only machine; map_location=None is
    # torch.load's default and keeps the original behavior on CUDA.
    net.load_state_dict(copyStateDict(torch.load(
        args.trained_model, map_location=None if args.cuda else 'cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    t = time.time()

    # load data
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k + 1, len(image_list), image_path), end='\r')
        image = imgproc.loadImage(image_path)

        bboxes, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda)

        # save score text
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        cv2.imwrite(mask_file, score_text)