def test(modelpara):
    # load net
    net = CRAFT()     # initialize

    print('Loading weights from checkpoint {}'.format(modelpara))
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(modelpara)))
    else:
        net.load_state_dict(copyStateDict(torch.load(modelpara, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    t = time.time()

    # load data
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
        image = imgproc.loadImage(image_path)

        bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly)
        # save score text
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        #cv2.imwrite(mask_file, score_text)

        file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)

    print("elapsed time : {}s".format(time.time() - t))
Example #2
def LoadDetectionModel(args):
    net = CRAFT()     # initialize
    print('Loading weights from checkpoint (' + args.trained_model + ')')
    net.load_state_dict(copyStateDict(torch.load(args.trained_model)))  # pass map_location='cpu' when loading without CUDA

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()
    return net
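
A minimal usage sketch for the loader above, assuming the same `args` namespace and the `test_net` helper used by the other examples in this listing; the image path is a placeholder:

net = LoadDetectionModel(args)
image = imgproc.loadImage('sample.jpg')  # placeholder path
bboxes, polys, score_text = test_net(net, image, args.text_threshold,
                                     args.link_threshold, args.low_text,
                                     args.cuda, args.poly)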
Example #3
def main(image_path, result_folder,
         trained_model='weights/craft_mlt_25k.pth',
         text_threshold=0.7, low_text=0.4, link_threshold=0.4, cuda=True,
         canvas_size=1280, mag_ratio=1.5,
         poly=False, show_time=False, test_folder='/data/',
         refine=True, refiner_model='weights/craft_refiner_CTW1500.pth'):
# if __name__ == '__main__':
    # load net
    net = CRAFT()     # initialize

    print('Loading weights from checkpoint (' + trained_model + ')')
    if cuda:
        net.load_state_dict(copyStateDict(torch.load(trained_model)))
    else:
        net.load_state_dict(copyStateDict(torch.load(trained_model, map_location='cpu')))

    if cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    # LinkRefiner
    refine_net = None
    if refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + refiner_model + ')')
        if cuda:
            refine_net.load_state_dict(copyStateDict(torch.load(refiner_model)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(copyStateDict(torch.load(refiner_model, map_location='cpu')))

        refine_net.eval()
        poly = True

    t = time.time()

    # load data
    image = imgproc.loadImage(image_path)

    bboxes, polys, score_text = test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net)

    # save score text
    filename, file_ext = os.path.splitext(os.path.basename(image_path))
    mask_file = result_folder + "/res_" + filename + '_mask.jpg'
    cv2.imwrite(mask_file, score_text)

    final_img = file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)
    
    print("elapsed time : {}s".format(time.time() - t))
Example #4
def main():
    # load net
    net = CRAFT()     # initialize

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
    else:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    # LinkRefiner
    refine_net = None
    if args.refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
        if args.cuda:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))

        refine_net.eval()
        args.poly = True

    t = time.time()
    print(image_list)
    # load data
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
        image = imgproc.loadImage(image_path)

        bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, refine_net)

        # save score text
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        cv2.imwrite(mask_file, score_text)

        file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)

    # print("elapsed time : {}s".format(time.time() - t))
Example #5
def createModel():
    net = CRAFT()

    weightPath = os.path.join(settings.BASE_DIR,
                              'CRAFT/weights/craft_mlt_25k.pth')

    print('Loading weights from checkpoint (' + weightPath + ')')

    net.load_state_dict(copyStateDict(torch.load(weightPath)))
    net = net.cuda()
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = False
    net.eval()
    return net
Example #6
def load_detection_model():
  parser = argparse.ArgumentParser(description='CRAFT Text Detection')
  parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str, help='pretrained model')
  parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
  parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
  parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
  parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
  parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
  parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
  parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
  parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
  parser.add_argument('--test_folder', default='/data/', type=str, help='folder path to input images')
  parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner')
  parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model')
  args = parser.parse_args(["--trained_model=./models/craft_mlt_25k.pth","--refine", "--refiner_model=./models/craft_refiner_CTW1500.pth"])
  net = CRAFT()     # initialize
  print('Loading weights from checkpoint (' + args.trained_model + ')')
  if args.cuda:
      net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
  else:
      net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))

  if args.cuda:
      net = net.cuda()
      net = torch.nn.DataParallel(net)
      cudnn.benchmark = False

  net.eval()

  # LinkRefiner
  refine_net = None
  if args.refine:
      from refinenet import RefineNet
      refine_net = RefineNet()
      print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
      if args.cuda:
          refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
          refine_net = refine_net.cuda()
          refine_net = torch.nn.DataParallel(refine_net)
      else:
          refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))

      refine_net.eval()
      # args.poly = True
  return net,refine_net,args
Example #7
def main(pth_file_path):
    cuda = True
    net = CRAFT()     # initialize

    print('Loading weights from checkpoint (' + pth_file_path + ')')
    if cuda:
        net.load_state_dict(copyStateDict(torch.load(pth_file_path)))
    else:
        net.load_state_dict(copyStateDict(torch.load(pth_file_path, map_location='cpu')))

    if cuda:
        net = net.cuda()
        cudnn.benchmark = False

    net.eval()

    script_module = torch.jit.script(net)

    file_path_without_ext = os.path.splitext(pth_file_path)[0]
    output_file_path = file_path_without_ext + ".pt"
    script_module.save(output_file_path)
    print("TorchScript model created:", output_file_path)
Example #8
    return boxes, polys, ret_score_text



if __name__ == '__main__':
    # load net
    mine_net = CRAFT()     # initialize

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        mine_net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
    else:
        mine_net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        mine_net = mine_net.cuda()
        mine_net = torch.nn.DataParallel(mine_net)
        cudnn.benchmark = False

    mine_net.eval()

    # LinkRefiner
    refine_net = None
    # if args.refine:
    #     from refinenet import RefineNet
    #     refine_net = RefineNet()
    #     print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
    #     if args.cuda:
    #         refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
    #         refine_net = refine_net.cuda()
    #         refine_net = torch.nn.DataParallel(refine_net)
Example #9
def ground_truth(args):
    # initiate pretrained network
    net = CRAFT()  # initialize

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        net.load_state_dict(test.copyStateDict(torch.load(args.trained_model)))
    else:
        net.load_state_dict(test.copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    filelist, _, _ = file_utils.list_files('/home/ubuntu/Kyumin/Autotation/data/IC13/images')

    for img_name in filelist:
        # get datapath
        if 'train' in img_name:
            label_name = img_name.replace('images/train/', 'labels/train/gt_').replace('jpg', 'txt')
        else:
            label_name = img_name.replace('images/test/', 'labels/test/gt_').replace('jpg', 'txt')
        label_dir = img_name.replace('Autotation', 'craft').replace('images', 'labels').replace('.jpg', '/')

        os.makedirs(label_dir, exist_ok=True)

        image = imgproc.loadImage(img_name)

        gt_boxes = []
        gt_words = []
        with open(label_name, 'r', encoding='utf-8-sig') as f:
            lines = f.readlines()
        for line in lines:
            if 'IC13' in img_name:  # IC13
                gt_box, gt_word, _ = line.split('"')
                if 'train' in img_name:
                    x1, y1, x2, y2 = [int(a) for a in gt_box.strip().split(' ')]
                else:
                    x1, y1, x2, y2 = [int(a.strip()) for a in gt_box.split(',') if a.strip().isdigit()]
                gt_boxes.append(np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]]))
                gt_words.append(gt_word)
            elif 'IC15' in img_name:
                gt_data = line.strip().split(',')
                gt_box = gt_data[:8]
                if len(gt_data) > 9:
                    gt_word = ','.join(gt_data[8:])
                else:
                    gt_word = gt_data[-1]
                gt_box = [int(a) for a in gt_box]
                gt_box = np.reshape(np.array(gt_box), (4, 2))
                gt_boxes.append(gt_box)
                gt_words.append(gt_word)

        score_region, score_link, conf_map = generate_gt(net, image, gt_boxes, gt_words, args)

        torch.save(score_region, label_dir + 'region.pt')
        torch.save(score_link, label_dir + 'link.pt')
        torch.save(conf_map, label_dir + 'conf.pt')
Example #10
class TextDetector:
    def __init__(self):
        #Parameters
        self.canvas_size = 1280
        self.mag_ratio = 1.5
        self.text_threshold = 0.7
        self.low_text = 0.4
        self.link_threshold = 0.4
        self.refine = False
        self.refiner_model = ''
        self.poly = False
        self.cuda = True

        self.net = CRAFT()
        if self.cuda:
            self.net.load_state_dict(copyStateDict(torch.load('CRAFT/weights/craft_mlt_25k.pth')))
        else:
            self.net.load_state_dict(copyStateDict(torch.load('CRAFT/weights/craft_mlt_25k.pth', map_location='cpu')))

        if self.cuda:
            self.net = self.net.cuda()
            self.net = torch.nn.DataParallel(self.net)
        self.net.eval()

        # LinkRefiner
        self.refine_net = None
        if self.refine:
            from refinenet import RefineNet
            self.refine_net = RefineNet()
            if self.cuda:
                self.refine_net.load_state_dict(copyStateDict(torch.load(self.refiner_model)))
                self.refine_net = self.refine_net.cuda()
                self.refine_net = torch.nn.DataParallel(self.refine_net)
            else:
                self.refine_net.load_state_dict(copyStateDict(torch.load(self.refiner_model, map_location='cpu')))
            self.refine_net.eval()
            self.poly = True

    def detect(self, image):
        # resize
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, self.canvas_size,
                                                                              interpolation=cv2.INTER_LINEAR,
                                                                              mag_ratio=self.mag_ratio)
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]

        if self.cuda:
            x = x.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = self.net(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # refine link
        if self.refine_net is not None:
            with torch.no_grad():
                y_refiner = self.refine_net(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().data.numpy()


        # Post-processing
        boxes, _ = craft_utils.getDetBoxes(score_text, score_link, self.text_threshold, self.link_threshold,
                                               self.low_text, self.poly)
        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        toRet = []
        for box in boxes:
            toRet.append(box2xyxy(box, image.shape[0: 2]))

        return toRet
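
A minimal usage sketch for the detector class above; the image path is a placeholder:

detector = TextDetector()
image = imgproc.loadImage('page.jpg')  # placeholder path
boxes = detector.detect(image)  # xyxy boxes produced by box2xyxy above
print(len(boxes), 'text regions')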
Example #11
class NpPointsCraft(object):
    """
    NpPointsCraft Class
    git clone https://github.com/clovaai/CRAFT-pytorch.git
    """
    def __init__(self, **args):
        pass

    @classmethod
    def get_classname(cls):
        return cls.__name__

    def load(self, mtl_model_path="latest", refiner_model_path="latest"):
        """
        TODO: describe method
        """
        if mtl_model_path == "latest":
            model_info = download_latest_model(self.get_classname(),
                                               "mtl",
                                               ext="pth",
                                               mode=get_mode_torch())
            mtl_model_path = model_info["path"]
        if refiner_model_path == "latest":
            model_info = download_latest_model(self.get_classname(),
                                               "refiner",
                                               ext="pth",
                                               mode=get_mode_torch())
            refiner_model_path = model_info["path"]
        device = "cpu"
        if get_mode_torch() == "gpu":
            device = "cuda"
        self.loadModel(device, True, mtl_model_path, refiner_model_path)

    def loadModel(self,
                  device="cuda",
                  is_refine=True,
                  trained_model=os.path.join(CRAFT_DIR,
                                             'weights/craft_mlt_25k.pth'),
                  refiner_model=os.path.join(
                      CRAFT_DIR, 'weights/craft_refiner_CTW1500.pth')):
        """
        TODO: describe method
        """
        is_cuda = device == "cuda"
        self.is_cuda = is_cuda

        # load net
        self.net = CRAFT()  # initialize

        print('Loading weights from checkpoint (' + trained_model + ')')
        if is_cuda:
            self.net.load_state_dict(copyStateDict(torch.load(trained_model)))
        else:
            self.net.load_state_dict(
                copyStateDict(torch.load(trained_model, map_location='cpu')))

        if is_cuda:
            self.net = self.net.cuda()
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = False

        self.net.eval()

        # LinkRefiner
        self.refine_net = None
        if is_refine:
            from refinenet import RefineNet
            self.refine_net = RefineNet()
            print('Loading weights of refiner from checkpoint (' +
                  refiner_model + ')')
            if is_cuda:
                self.refine_net.load_state_dict(
                    copyStateDict(torch.load(refiner_model)))
                self.refine_net = self.refine_net.cuda()
                self.refine_net = torch.nn.DataParallel(self.refine_net)
            else:
                self.refine_net.load_state_dict(
                    copyStateDict(torch.load(refiner_model,
                                             map_location='cpu')))

            self.refine_net.eval()
            self.is_poly = True

    def detectByImagePath(self,
                          image_path,
                          targetBoxes,
                          qualityProfile=[1, 0, 0],
                          debug=False):
        """
        TODO: describe method
        """
        image = imgproc.loadImage(image_path)
        for targetBox in targetBoxes:
            x = min(targetBox['x1'], targetBox['x2'])
            w = abs(targetBox['x2'] - targetBox['x1'])
            y = min(targetBox['y1'], targetBox['y2'])
            h = abs(targetBox['y2'] - targetBox['y1'])
            #print('x: {} w: {} y: {} h: {}'.format(x,w,y,h))
            image_part = image[y:y + h, x:x + w]
            points = self.detectInBbox(image_part)
            propablyPoints = addCoordinatesOffset(points, x, y)
            targetBox['points'] = []
            targetBox['imgParts'] = []
            if (len(propablyPoints)):
                targetPointsVariants = makeRectVariants2(
                    propablyPoints, h, w, qualityProfile)
                # targetBox['points'] = addCoordinatesOffset(points, x, y)
                # targetPointsVariants = [targetPoints, fixSideFacets(targetPoints)]
                if len(targetPointsVariants) > 1:
                    imgParts = [
                        getCvZoneRGB(image, reshapePoints(rect, 1))
                        for rect in targetPointsVariants
                    ]
                    idx = detectBestPerspective(
                        normalizePerspectiveImages(imgParts))
                    print('--------------------------------------------------')
                    print('idx={}'.format(idx))
                    #targetBox['points'] = addoptRectToBbox2(targetPointsVariants[idx], image.shape,x,y)
                    targetBox['points'] = targetPointsVariants[idx]
                    targetBox['imgParts'] = imgParts
                else:
                    targetBox['points'] = targetPointsVariants[0]
        return targetBoxes, image

    def detect(self,
               image,
               targetBoxes,
               qualityProfile=[1, 0, 0],
               debug=False):
        """
        TODO: describe method
        """
        all_points = []
        for targetBox in targetBoxes:
            x = int(min(targetBox[0], targetBox[2]))
            w = int(abs(targetBox[2] - targetBox[0]))
            y = int(min(targetBox[1], targetBox[3]))
            h = int(abs(targetBox[3] - targetBox[1]))

            image_part = image[y:y + h, x:x + w]
            propablyPoints = addCoordinatesOffset(
                self.detectInBbox(image_part), x, y)
            points = []
            if (len(propablyPoints)):
                targetPointsVariants = makeRectVariants2(
                    propablyPoints, h, w, qualityProfile)
                if len(targetPointsVariants) > 1:
                    imgParts = [
                        getCvZoneRGB(image, reshapePoints(rect, 1))
                        for rect in targetPointsVariants
                    ]
                    idx = detectBestPerspective(
                        normalizePerspectiveImages(imgParts))
                    points = targetPointsVariants[idx]
                else:
                    points = targetPointsVariants[0]
                all_points.append(points)
            else:
                all_points.append([[x, y + h], [x, y], [x + w, y],
                                   [x + w, y + h]])
        return all_points

    def detectInBbox(self, image, debug=False):
        """
        TODO: describe method
        """
        low_text = 0.4
        link_threshold = 0.7  # 0.4
        text_threshold = 0.6
        canvas_size = 1280
        mag_ratio = 1.5

        t = time.time()
        bboxes, polys, score_text = test_net(self.net, image, text_threshold,
                                             link_threshold, low_text,
                                             self.is_cuda, self.is_poly,
                                             canvas_size, self.refine_net,
                                             mag_ratio)
        if debug:
            print("elapsed time : {}s".format(time.time() - t))
        dimensions = []
        for poly in bboxes:
            dimensions.append({
                'dx': distance(poly[0], poly[1]),
                'dy': distance(poly[1], poly[2])
            })

        if (debug):
            print(score_text.shape)
            # print(polys)
            print(dimensions)
            print(bboxes)

        np_bboxes_idx, garbage_bboxes_idx = split_boxes(bboxes, dimensions)

        targetPoints = []
        if (debug):
            print('np_bboxes_idx')
            print(np_bboxes_idx)
            print('garbage_bboxes_idx')
            print(garbage_bboxes_idx)
            print('bboxes')
            print(bboxes)
            print('polys')
            print(polys)

        if len(np_bboxes_idx) == 1:
            targetPoints = bboxes[np_bboxes_idx[0]]

        if len(np_bboxes_idx) > 1:
            targetPoints = minimum_bounding_rectangle(
                np.concatenate([bboxes[i] for i in np_bboxes_idx], axis=0))

        imgParts = []
        if len(np_bboxes_idx) > 0:
            targetPoints = normalizeRect(targetPoints)
            if (debug):
                print('###################################')
                print(targetPoints)

            if (debug):
                print('image.shape')
                print(image.shape)
            #targetPoints = fixSideFacets(targetPoints, image.shape)
            targetPoints = addoptRectToBbox(targetPoints, image.shape, 7, 12,
                                            0, 12)
        return targetPoints
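
A minimal usage sketch for NpPointsCraft, with a placeholder image path and box; the targetBoxes format follows detectByImagePath above:

craft = NpPointsCraft()
craft.load()  # downloads the latest mtl and refiner checkpoints
boxes = [{'x1': 10, 'y1': 20, 'x2': 300, 'y2': 120}]  # placeholder box
target_boxes, img = craft.detectByImagePath('plate.jpg', boxes)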
Example #12
def main(args, logger=None):
    # load net
    net = CRAFT(pretrained=False)  # initialize

    print('Loading weights from checkpoint {}'.format(args.model_path))
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(args.model_path)))
    else:
        net.load_state_dict(
            copyStateDict(torch.load(args.model_path, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    t = time.time()

    # load data
    """ For test images in a folder """
    image_list, _, _ = file_utils.get_files(args.img_path)
    est_folder = os.path.join(args.rst_path, 'est')
    mask_folder = os.path.join(args.rst_path, 'mask')
    eval_folder = os.path.join(args.rst_path, 'eval')
    cg.folder_exists(est_folder, create_=True)
    cg.folder_exists(mask_folder, create_=True)
    cg.folder_exists(eval_folder, create_=True)

    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k + 1, len(image_list),
                                                  image_path))
        image = imgproc.loadImage(image_path)
        # image = cv2.resize(image, dsize=(768, 768), interpolation=cv2.INTER_CUBIC) ##
        bboxes, polys, score_text = test_net(
            net,
            image,
            text_threshold=args.text_threshold,
            link_threshold=args.link_threshold,
            low_text=args.low_text,
            cuda=args.cuda,
            canvas_size=args.canvas_size,
            mag_ratio=args.mag_ratio,
            poly=args.poly,
            show_time=args.show_time)
        # save score text
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = mask_folder + "/res_" + filename + '_mask.jpg'
        if not (cg.file_exists(mask_file)):
            cv2.imwrite(mask_file, score_text)

        file_utils.saveResult15(image_path,
                                bboxes,
                                dirname=est_folder,
                                mode='test')

    eval_dataset(est_folder=est_folder,
                 gt_folder=args.gt_path,
                 eval_folder=eval_folder,
                 dataset_type=args.dataset_type)
    print("elapsed time : {}s".format(time.time() - t))
Example #13
class TextExtractor():
    def __init__(self, image_folder, extract_text_file, split):
        self.i_folder = image_folder
        #print(image_folder)
        #print("aaaaaaa test")
        self.extract_text_file = extract_text_file
        self.canvas_size = 1280
        self.mag_ratio = 1.5
        self.show_time = False
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.cuda = torch.cuda.is_available()
        self.net = CRAFT()  #(1st model) model to detect words in images
        if self.cuda:
            self.net.load_state_dict(
                self.copyStateDict(
                    torch.load('CRAFT-pytorch/craft_mlt_25k.pth')))
        else:
            self.net.load_state_dict(
                self.copyStateDict(
                    torch.load('CRAFT-pytorch/craft_mlt_25k.pth',
                               map_location='cpu')))
        if self.cuda:
            self.net = self.net.cuda()
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = False
        self.net.eval()
        self.refine_net = None

        self.text_threshold = 0.7
        self.link_threshold = 0.4
        self.low_text = 0.4
        self.poly = False

        self.result_folder = './' + split + '_' + 'intermediate_result/'

        if not os.path.isdir(self.result_folder):
            os.mkdir(self.result_folder)

        #Parameters for image to text model (2nd model)
        self.parser = argparse.ArgumentParser()
        #Data processing
        self.parser.add_argument('--batch_max_length',
                                 type=int,
                                 default=25,
                                 help='maximum-label-length')
        self.parser.add_argument('--imgH',
                                 type=int,
                                 default=32,
                                 help='the height of the input image')
        self.parser.add_argument('--imgW',
                                 type=int,
                                 default=100,
                                 help='the width of the input image')
        self.parser.add_argument('--rgb',
                                 default=False,
                                 action='store_true',
                                 help='use rgb input')
        self.parser.add_argument(
            '--character',
            type=str,
            default='0123456789abcdefghijklmnopqrstuvwxyz',
            help='character label')
        self.parser.add_argument('--sensitive',
                                 action='store_true',
                                 help='for sensitive character mode')
        self.parser.add_argument(
            '--PAD',
            action='store_true',
            help='whether to keep ratio then pad for image resize')
        #Model Architecture
        self.parser.add_argument('--Transformation',
                                 type=str,
                                 default='TPS',
                                 help='Transformation stage. None|TPS')
        self.parser.add_argument(
            '--FeatureExtraction',
            type=str,
            default='ResNet',
            help='FeatureExtraction stage. VGG|RCNN|ResNet')
        self.parser.add_argument('--SequenceModeling',
                                 type=str,
                                 default='BiLSTM',
                                 help='SequenceModeling stage. None|BiLSTM')
        self.parser.add_argument('--Prediction',
                                 type=str,
                                 default='Attn',
                                 help='Prediction stage. CTC|Attn')
        self.parser.add_argument('--num_fiducial',
                                 type=int,
                                 default=20,
                                 help='number of fiducial points of TPS-STN')
        self.parser.add_argument(
            '--input_channel',
            type=int,
            default=1,
            help='the number of input channel of Feature extractor')
        self.parser.add_argument(
            '--output_channel',
            type=int,
            default=512,
            help='the number of output channel of Feature extractor')
        self.parser.add_argument('--hidden_size',
                                 type=int,
                                 default=256,
                                 help='the size of the LSTM hidden state')
        #self.opt = self.parser.parse_args()
        self.opt, unknown = self.parser.parse_known_args()
        #self.opt, unknown = self.parser.parse_known_args()

        if 'CTC' in self.opt.Prediction:
            self.converter = CTCLabelConverter(self.opt.character)
        else:
            self.converter = AttnLabelConverter(self.opt.character)
        self.opt.num_class = len(self.converter.character)
        #print(opt.rgb)
        if self.opt.rgb:
            self.opt.input_channel = 3
        self.opt.num_gpu = torch.cuda.device_count()
        self.opt.batch_size = 192
        #self.opt.batch_size = 3
        self.opt.workers = 0
        self.model = Model(self.opt)  #image to text model (2nd model)
        self.model = torch.nn.DataParallel(self.model).to(self.device)
        self.model.load_state_dict(
            torch.load(
                'deep-text-recognition-benchmark/TPS-ResNet-BiLSTM-Attn.pth',
                map_location=self.device))
        self.model.eval()

    def copyStateDict(self, state_dict):
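        # Checkpoints saved from a DataParallel-wrapped model prefix every key
        # with "module."; strip it so the weights load into a bare model too.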
        if list(state_dict.keys())[0].startswith("module"):
            start_idx = 1
        else:
            start_idx = 0
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = ".".join(k.split(".")[start_idx:])
            new_state_dict[name] = v
        return new_state_dict

    def test_net(self,
                 net,
                 image,
                 text_threshold,
                 link_threshold,
                 low_text,
                 cuda,
                 poly,
                 refine_net=None):
        t0 = time.time()

        # resize
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
            image,
            self.canvas_size,
            interpolation=cv2.INTER_LINEAR,
            mag_ratio=self.mag_ratio)
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]
        if cuda:
            x = x.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = net(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # refine link
        if refine_net is not None:
            with torch.no_grad():
                y_refiner = refine_net(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

        t0 = time.time() - t0
        t1 = time.time()

        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                               text_threshold, link_threshold,
                                               low_text, poly)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None: polys[k] = boxes[k]

        t1 = time.time() - t1

        # render results (optional)
        render_img = score_text.copy()
        render_img = np.hstack((render_img, score_link))
        ret_score_text = imgproc.cvt2HeatmapImg(render_img)

        if self.show_time:
            print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

        return boxes, polys, ret_score_text

    def extract_text(self):
        l = sorted(os.listdir(self.i_folder))
        img_to_index = {}
        count = 0
        for full_file in l:
            split_file = full_file.split(".")
            filename = split_file[0]
            img_to_index[count] = filename
            #print(count, filename)
            count += 1
            #print(filename)
            file_extension = "." + split_file[1]
            #print(filename, file_extension)
            image = imgproc.loadImage(self.i_folder + full_file)
            bboxes, polys, score_text = self.test_net(
                self.net, image, self.text_threshold, self.link_threshold,
                self.low_text, self.cuda, self.poly, self.refine_net)
            img = cv2.imread(self.i_folder + filename + file_extension)
            rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            points = []
            order = []
            for i in range(0, len(bboxes)):
                sample_bbox = bboxes[i]
                min_point = sample_bbox[0]
                max_point = sample_bbox[2]
                for j, p in enumerate(sample_bbox):
                    if (p[0] <= min_point[0]):
                        min_point = (p[0], min_point[1])
                    if (p[1] <= min_point[1]):
                        min_point = (min_point[0], p[1])
                    if (p[0] >= max_point[0]):
                        max_point = (p[0], max_point[1])
                    if (p[1] >= max_point[1]):
                        max_point = (max_point[0], p[1])
                min_point = (max(min(len(rgb_img[0]), min_point[0]),
                                 0), max(min(len(rgb_img), min_point[1]), 0))
                max_point = (max(min(len(rgb_img[0]), max_point[0]),
                                 0), max(min(len(rgb_img), max_point[1]), 0))
                points.append((min_point, max_point))
                order.append(0)
            num_ordered = 0
            rows_ordered = 0
            points_sorted = []
            ordered_points_index = 0
            order_sorted = []
            while (num_ordered < len(points)):
                #find lowest-y that is unordered
                min_y = len(rgb_img)
                min_y_index = -1
                for i in range(0, len(points)):
                    if (order[i] == 0):
                        if (points[i][0][1] <= min_y):
                            min_y = points[i][0][1]
                            min_y_index = i
                rows_ordered += 1
                order[min_y_index] = rows_ordered
                num_ordered += 1
                points_sorted.append(points[min_y_index])
                order_sorted.append(rows_ordered)
                ordered_points_index = len(points_sorted) - 1

                # Group bboxes that are on the same row
                max_y = points[min_y_index][1][1]
                range_y = max_y - min_y
                for i in range(0, len(points)):
                    if (order[i] == 0):
                        min_y_i = points[i][0][1]
                        max_y_i = points[i][1][1]
                        range_y_i = max_y_i - min_y_i
                        if (max_y_i >= min_y and min_y_i <= max_y):
                            overlap = (min(max_y_i, max_y) -
                                       max(min_y_i, min_y)) / (max(
                                           1, min(range_y, range_y_i)))
                            if (overlap >= 0.30):
                                order[i] = rows_ordered
                                num_ordered += 1
                                min_x_i = points[i][0][0]
                                for j in range(ordered_points_index,
                                               len(points_sorted) + 1):
                                    if (j < len(points_sorted)
                                        ):  #insert before
                                        min_x_j = points_sorted[j][0][0]
                                        if (min_x_i < min_x_j):
                                            points_sorted.insert(j, points[i])
                                            order_sorted.insert(
                                                j, rows_ordered)
                                            break
                                    else:  #insert at the end of array
                                        points_sorted.insert(j, points[i])
                                        order_sorted.insert(j, rows_ordered)
                                        break
            for i in range(0, len(points_sorted)):
                min_point = points_sorted[i][0]
                max_point = points_sorted[i][1]
                mask_file = self.result_folder + filename + "_" + str(
                    order_sorted[i]) + "_" + str(i) + file_extension
                crop_image = rgb_img[int(min_point[1]):int(max_point[1]),
                                     int(min_point[0]):int(max_point[0])]
                #print(filename, min_point, max_point, len(rgb_img), len(rgb_img[0]))
                cv2.imwrite(mask_file, crop_image)
        AlignCollate_demo = AlignCollate(imgH=self.opt.imgH,
                                         imgW=self.opt.imgW,
                                         keep_ratio_with_pad=self.opt.PAD)
        demo_data = RawDataset(root=self.result_folder,
                               opt=self.opt)  # use RawDataset
        demo_loader = torch.utils.data.DataLoader(
            demo_data,
            batch_size=self.opt.batch_size,
            shuffle=False,
            num_workers=int(self.opt.workers),
            collate_fn=AlignCollate_demo,
            pin_memory=True)
        f = open(self.extract_text_file, "w")
        count = -1
        curr_order = 1
        curr_filename = ""
        output_string = ""
        end_line = "[SEP] "
        with torch.no_grad():
            for image_tensors, image_path_list in demo_loader:
                batch_size = image_tensors.size(0)
                image = image_tensors.to(self.device)
                #image = (torch.from_numpy(crop_image).unsqueeze(0)).to(device)
                #print(image_path_list)
                #print(image.size())
                length_for_pred = torch.IntTensor([self.opt.batch_max_length] *
                                                  batch_size).to(self.device)
                text_for_pred = torch.LongTensor(batch_size,
                                                 self.opt.batch_max_length +
                                                 1).fill_(0).to(self.device)
                preds = self.model(image, text_for_pred, is_train=False)
                _, preds_index = preds.max(2)
                preds_str = self.converter.decode(preds_index, length_for_pred)
                for path, p in zip(image_path_list, preds_str):
                    #print(path)
                    if 'Attn' in self.opt.Prediction:
                        pred_EOS = p.find('[s]')
                        p = p[:
                              pred_EOS]  # prune after "end of sentence" token ([s])
                    path_info = path[len(self.result_folder):].split(
                        ".")[0].split(
                            "_"
                        )  #ASSUMES FILE EXTENSION OF SIZE 4 (.PNG, .JPG, ETC)
                    #print(curr_filename)
                    #print(path_info[0])
                    #print("PATHINFO: ",path_info[0])
                    if (not (curr_filename == path_info[0])):
                        if (not (curr_filename == "")):
                            f.write(str(count) + "\n")
                            f.write(curr_filename + "\n")
                            f.write(output_string + "\n\n")
                        count += 1
                        curr_filename = img_to_index[count]  #path_info[0]
                        #print("CURRFILE: ", curr_filename)
                        while (not (curr_filename == path_info[0])):
                            f.write(str(count) + "\n")
                            f.write(curr_filename + "\n")
                            f.write("\n\n")
                            count += 1
                            curr_filename = img_to_index[count]  #path_info[0]
                            #print("CURRFILE: ", curr_filename)
                        output_string = ""
                        curr_order = 1
                    if (int(path_info[1]) > curr_order):
                        curr_order += 1
                        output_string += end_line
                    output_string += p + " "
            f.write(str(count) + "\n")
            f.write(curr_filename + "\n")
            f.write(output_string + "\n\n")
        f.close()

        #Go through each image in the i_folder and crop out text

        #generate text and write to text file

    def get_item(self, index):
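        # extract_text writes four lines per image (index, filename, text,
        # blank), so the text for record `index` sits at line 4 * index + 2.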
        f = open(self.extract_text_file, "r")
        Lines = f.readlines()
        return (Lines[4 * index + 2][:-1])
        # read text file


#TEST
#t_e = TextExtractor("data/mmimdb-256/dataset-resized-256max/dev_n/images/","text_extract_output.txt")
#t_e.extract_text()
#text = t_e.get_item(1)
#print(text)
Example #14
class CraftNet(object):
    def __init__(self, ocrObj):
        self.net = CRAFT()    
        print('Loading weights from checkpoint (' + trained_model + ')')
        if isCuda:
            self.net.load_state_dict(copyStateDict(torch.load(trained_model)))
        else:
            self.net.load_state_dict(copyStateDict(torch.load(trained_model, map_location='cpu')))
        if isCuda:
            self.net = self.net.cuda()
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = False 
        self.net.eval()
        self.jsonFile = defaultdict(dict)
        self.ocrObj = ocrObj

    def test_net(self, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
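        # Note: canvas_size and mag_ratio are taken from module-level globals
        # here, unlike the self.* parameters used elsewhere in this listing.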
        t0 = time.time()

        # resize
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio)
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
        if cuda:
            x = x.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = self.net(x)

        # make score and link map
        score_text = y[0,:,:,0].cpu().data.numpy()
        score_link = y[0,:,:,1].cpu().data.numpy()

        # refine link
        if refine_net is not None:
            with torch.no_grad():
                y_refiner = refine_net(y, feature)
            score_link = y_refiner[0,:,:,0].cpu().data.numpy()

        t0 = time.time() - t0
        t1 = time.time()

        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None: polys[k] = boxes[k]

        t1 = time.time() - t1

        # render results (optional)
        render_img = score_text.copy()
        render_img = np.hstack((render_img, score_link))
        ret_score_text = imgproc.cvt2HeatmapImg(render_img)

        # if show_time : print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

        return boxes, polys, ret_score_text


    def evaluateBB(self, image_path):
        print(image_path)
        print(os.getcwd())
        image = imgproc.loadImage(image_path)
        imageCpy = image
        t = time.time()
        tnew = t
        bboxes, polys, score_text = self.test_net(image, text_thresholdVal, link_thresholdVal, low_textVal, isCuda, polyVal)
        deltaTime = time.time() - tnew
        words = []
        # # save image with BB
        # filename, file_ext = os.path.splitext(os.path.basename(image_path))
        # real_folder = result_folder + '/' + image_path.replace('images', '').replace(filename + file_ext, '')
        # file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=real_folder)
        if(isTest):
            curImg = {
                "BBs" : defaultdict(dict),
                "pretrained" : "MLT",
                "procTime" : deltaTime,
                "OCR" : "CRNN",
            }
        for i in range(len(polys)):
            if(saveResult):
                cv2.rectangle(imageCpy, (int(polys[i][0][0]), int(polys[i][0][1])), (int(polys[i][1][0]), int(polys[i][2][1])), (255,0,0), 2)
            tnew = time.time()
            # incorrect, correct = self.ocrObj.getString(image, polys[i])
            # distTime = time.time() - tnew
            # print(incorrect, correct) 
            # tnew = time.time()
            incorrect, correct = self.ocrObj.getStringnGram(image, polys[i])
            nTime = time.time() - tnew
            if(correct is not None and saveResult):
                cv2.putText(imageCpy, correct, (int(polys[i][0][0]), int(polys[i][0][1] - 10)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,255),2) 
            words.append((incorrect,correct))
            if(includeTesseract):
                tnew = time.time()
                incTess, corrTess = tesseractOCR.getStringnGram(image, polys[i])
                tessTime = time.time() - tnew
            if(isTest):
                if(includeTesseract):
                    curImg["BBs"][i] = {
                        "BB" : polys[i].tolist(),
                        "strings" : incorrect,
                        "stringsCorrect" : correct,
                        "ocrTime" : nTime,
                        "stringsTess": incTess,
                        "stringsCorrectTess": corrTess,
                        "ocrTimeTess": tessTime
                    }
                else:
                    curImg["BBs"][i] = {
                        "BB" : polys[i].tolist(),
                        "strings" : incorrect,
                        "stringsCorrect" : correct,
                        "ocrTime" : nTime,
                    }
        if(isTest):
            name, folder = getNameAndFolder(image_path)
            self.jsonFile[folder][name] = curImg
            with open("./CRAFT-pytorch-master/stats.json", "w") as write_file:
                json.dump(self.jsonFile, write_file, sort_keys=True, indent=4)
        if(saveResult):
            name, folder = getNameAndFolder(image_path)
            imageCpy = cv2.cvtColor(imageCpy, cv2.COLOR_BGR2RGB)
            if(not os.path.exists("./result/edited/"+folder)):
                os.makedirs("./result/edited/"+folder)
            cv2.imwrite("./result/edited/"+folder+"/"+name, imageCpy)# + image_path
        # return polys
        corr = ""
        incorr = ""
        for w in words:
            if(w[1] != None):
                corr.join(w[1] + "  ")
            if(w[0] != None):
                incorr.join(w[0] + "  ")
        return self.evaluateResponse(curImg["BBs"],image)

    def getQuadrant(self, bb, image):
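        # Map the box's center point onto a 3x3 grid over the image and
        # return the cell index 1..9 in row-major order.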
        shape = image.shape
        newRect = [bb[0][0], bb[0][1], bb[1][0], bb[2][1]]
        # cv2.rectangle(image, (int(newRect[0]), int(newRect[1])), (int(newRect[2]), int(newRect[3])), (255,0,0), 2)
        xPt = newRect[0] + (newRect[2]-newRect[0])/2
        yPt = newRect[1] + (newRect[3]-newRect[1])/2
        xQuad = 0
        yQuad = 0
        if(0 <= xPt < shape[1]/3):
            xQuad = 1
        elif(shape[1]/3 <= xPt < shape[1]*2/3):
            xQuad = 2
        else:
            xQuad = 3
        if(0 <= yPt < shape[0]/3):
            yQuad = 0
        elif(shape[0]/3 <= yPt < shape[0]*2/3):
            yQuad =1
        else:
            yQuad = 2
        # cv2.putText(image, str(xQuad + yQuad*3), (int(newRect[0]), int(newRect[1]-10)), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36,255,12), 2)
        
        return xQuad + yQuad*3

    def evaluateResponse(self,bbValues,image):
        gridWords = {
            1 : [],
            2 : [],
            3 : [],
            4 : [],
            5 : [],
            6 : [],
            7 : [],
            8 : [],
            9 : []
        }
        words = 0
        threeWords = [("",0),("",0),("",0)]
        full = False
        for i in bbValues:
            if(bbValues[i]["stringsCorrect"] != None):
                words += 1
                gridWords[self.getQuadrant(bbValues[i]["BB"],image)].append(bbValues[i]["stringsCorrect"])
                found = False
                j = 0
                while not found and j < len(threeWords):
                    if((threeWords[j][1] < self.getArea(bbValues[i]["BB"]) and full) or threeWords[j][1] == 0):
                        found = True
                        threeWords[j] = (bbValues[i]["stringsCorrect"], self.getArea(bbValues[i]["BB"]))
                    if(j == len(threeWords)-1):
                        full = True

                    j += 1
        dictionary = {
            "grid":{
                "Top Left" : gridWords[1],
                "Top" : gridWords[2],
                "Top Right" : gridWords[3],
                "Center Left" : gridWords[4],
                "Center" : gridWords[5],
                "Center Right" : gridWords[6],
                "Bottom Left" : gridWords[7],
                "Bottom" : gridWords[8],
                "Bottom Right" : gridWords[9]
            },
            "threeWords":[threeWords[0][0],threeWords[1][0],threeWords[2][0]],
            "newWords":words
        }
        # cv2.imshow("gigi",image)
        # cv2.waitKey(0)
        
        return dictionary

    def getArea(self, bb):
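        # Treat the quad as an axis-aligned rectangle and return its area.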
        newRect = [bb[0][0], bb[0][1], bb[1][0], bb[2][1]]
        return (newRect[2]-newRect[0])*(newRect[3]-newRect[1])
Example #15
def analysis(image_path, result_path):
    """ For test images in a folder """
    net = CRAFT()     # initialize

    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
    else:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    # LinkRefiner
    refine_net = None
    if args.refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
        if args.cuda:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))

        refine_net.eval()
        args.poly = True

    t = time.time()

    image = imgproc.loadImage(image_path)

    bboxes, polys, score_text = test_net(net,
                                         image,
                                         args.text_threshold,
                                         args.link_threshold,
                                         args.low_text,
                                         args.cuda,
                                         args.poly,
                                         refine_net)
                                         
    opencv_image = cv2.imread(image_path)
    
    for index, box in enumerate(polys):
        xmin, xmax, ymin, ymax = box[0][0], box[1][0], box[0][1], box[2][1]
        multiplier_area = image[int(ymin):int(ymax), int(xmin):int(xmax)]
        
        try:
            im_pil = Image.fromarray(multiplier_area)
            #if you want to detect the text on the image
            if args.ocr_on:
                if args.ocr_method == 'pytesseract':
                    configuration = ("-l eng --oem 1 --psm 7")
                    multiplier = (pytesseract.image_to_string(multiplier_area, config=configuration).lower())
                    multiplier = multiplier.split("\n")[0]
                    
                elif args.ocr_method == 'TPS-ResNet-BiLSTM':
                    multiplier = text_recognition.recognition(im_pil)
                    
                cv2.putText(opencv_image, multiplier, (int(xmin),int(ymin-5)), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0,0,255), 1)
                
            cv2.rectangle(opencv_image,(int(xmin),int(ymin)), (int(xmax),int(ymax)),(0,0,255), 1)
            cv2.imwrite(result_path, opencv_image)
                
        except Exception:
            print("====ERROR====", traceback.format_exc())
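
Called as, for example (both paths are placeholders):

analysis('input.jpg', 'result.jpg')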
Example #16
    transformer = dataset.ResizeNormalize(img_width=args.img_width,
                                          img_height=args.img_height)

    # load detect net
    detect_net = CRAFT()  # initialize

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        detect_net.load_state_dict(
            copyStateDict(torch.load(args.trained_model)))
    else:
        detect_net.load_state_dict(
            copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        detect_net = detect_net.cuda()
        detect_net = torch.nn.DataParallel(detect_net)
        cudnn.benchmark = False

    detect_net.eval()

    # load rec_net
    encoder = crnn.Encoder(3, args.hidden_size)
    # no dropout during inference
    decoder = crnn.Decoder(args.hidden_size,
                           num_classes,
                           dropout_p=0.0,
                           max_length=args.max_width)
    print(encoder)
    print(decoder)
    if torch.cuda.is_available() and args.use_gpu:
Example #17
def test(modelpara):
    # load net
    net = CRAFT()  # initialize

    print('Loading weights from checkpoint {}'.format(modelpara))
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(modelpara)))
    else:
        net.load_state_dict(
            copyStateDict(torch.load(modelpara, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    t = time.time()

    # load data
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k + 1, len(image_list),
                                                  image_path),
              end='\n')
        image = imgproc.loadImage(image_path)
        res = image.copy()

        # bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly)
        gh_pred, bboxes_pred, polys_pred, size_heatmap = test_net(
            net, image, args.text_threshold, args.link_threshold,
            args.low_text, args.cuda, args.poly)

        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        result_dir = os.path.join(result_folder, filename)
        os.makedirs(result_dir, exist_ok=True)
        for gh_img, field in zip(gh_pred, CLASSES):
            img = imgproc.cvt2HeatmapImg(gh_img)
            img_path = os.path.join(result_dir,
                                    'res_{}_{}.jpg'.format(filename, field))
            cv2.imwrite(img_path, img)
        h, w = image.shape[:2]
        img = cv2.resize(image, size_heatmap)[::, ::, ::-1]
        img_path = os.path.join(result_dir, 'res_{}.jpg'.format(filename))
        cv2.imwrite(img_path, img)

        # # save score text
        # filename, file_ext = os.path.splitext(os.path.basename(image_path))
        # mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        # cv2.imwrite(mask_file, score_text)

        res = cv2.resize(res, size_heatmap)
        for polys, field in zip(polys_pred, CLASSES):
            TEXT_WIDTH = 10 * len(field) + 10
            TEXT_HEIGHT = 15
            polys = np.int32([poly.reshape((-1, 1, 2)) for poly in polys])
            res = cv2.polylines(res, polys, True, (0, 0, 255), 2)
            for poly in polys:
                poly[1, 0] = [poly[0, 0, 0] - 10, poly[0, 0, 1]]
                poly[2, 0] = [poly[0, 0, 0] - 10, poly[0, 0, 1] + TEXT_HEIGHT]
                poly[3, 0] = [
                    poly[0, 0, 0] - TEXT_WIDTH, poly[0, 0, 1] + TEXT_HEIGHT
                ]
                poly[0, 0] = [poly[0, 0, 0] - TEXT_WIDTH, poly[0, 0, 1]]
            res = cv2.fillPoly(res, polys, (224, 224, 224))
            # print(poly)
            for poly in polys:
                res = cv2.putText(res,
                                  field,
                                  tuple(poly[3, 0] + [+5, -5]),
                                  cv2.FONT_HERSHEY_SIMPLEX,
                                  0.4, (0, 0, 0),
                                  thickness=1)
        res_file = os.path.join(result_dir, 'res_{}_bbox.jpg'.format(filename))
        cv2.imwrite(res_file, res[::, ::, ::-1])
        # break

        # file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)

    print("elapsed time : {}s".format(time.time() - t))
Exemple #18
class TextDetector:
    def __init__(self):
        self.trained_model = "text-detect/craft_mlt_25k.pth"
        self.text_threshold = 0.7
        self.low_text = 0.4
        self.link_threshold = 0.4
        self.cuda = True
        self.canvas_size = 1280
        self.mag_ratio = 1.5
        self.poly = False
        self.show_time = False
        self.refine = False
        self.refiner_model = "text-detect/craft_refiner_CTW1500.pth"
        # Equivalent argparse definitions for the attributes above:
        # parser = argparse.ArgumentParser(description='CRAFT Text Detection')
        # parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str, help='pretrained model')
        # parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
        # parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
        # parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
        # parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
        # parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
        # parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
        # parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
        # parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
        # parser.add_argument('--test_folder', default='/data/', type=str, help='folder path to input images')
        # parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner')
        # parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model')

    def test_net(self,
                 net,
                 image,
                 text_threshold,
                 link_threshold,
                 low_text,
                 cuda,
                 poly,
                 refine_net=None):
        t0 = time.time()

        # resize
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
            image,
            self.canvas_size,
            interpolation=cv2.INTER_LINEAR,
            mag_ratio=self.mag_ratio)
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]
        if cuda:
            x = x.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = net(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # refine link
        if refine_net is not None:
            with torch.no_grad():
                y_refiner = refine_net(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

        t0 = time.time() - t0
        t1 = time.time()

        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                               text_threshold, link_threshold,
                                               low_text, poly)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None: polys[k] = boxes[k]

        t1 = time.time() - t1

        # render results (optional)
        render_img = score_text.copy()
        render_img = np.hstack((render_img, score_link))
        ret_score_text = imgproc.cvt2HeatmapImg(render_img)

        if self.show_time:
            print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

        return boxes, polys, ret_score_text

    def let_load(self):
        self.net = CRAFT()  # initialize
        print('Loading weights from checkpoint (' + self.trained_model + ')')
        if self.cuda:
            self.net.load_state_dict(
                copyStateDict(torch.load(self.trained_model)))
        else:
            self.net.load_state_dict(
                copyStateDict(
                    torch.load(self.trained_model, map_location='cpu')))

        if self.cuda:
            self.net = self.net.cuda()
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = False

        self.net.eval()

        # LinkRefiner: assign to self.refine_net so text_detect() can use it
        self.refine_net = None
        if self.refine:
            from refinenet import RefineNet
            self.refine_net = RefineNet()
            print('Loading weights of refiner from checkpoint (' +
                  self.refiner_model + ')')
            if self.cuda:
                self.refine_net.load_state_dict(
                    copyStateDict(torch.load(self.refiner_model)))
                self.refine_net = self.refine_net.cuda()
                self.refine_net = torch.nn.DataParallel(self.refine_net)
            else:
                self.refine_net.load_state_dict(
                    copyStateDict(
                        torch.load(self.refiner_model, map_location='cpu')))

            self.refine_net.eval()
            self.poly = True

    def text_detect(self, image):
        bboxes, polys, score_text = self.test_net(self.net, image,
                                                  self.text_threshold,
                                                  self.link_threshold,
                                                  self.low_text, self.cuda,
                                                  self.poly, self.refine_net)
        # cv2.imwrite(mask_file, score_text)

        # file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)
        return bboxes, polys, score_text
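
# Minimal usage sketch for TextDetector (an assumption, not part of the
# original example): load the weights once, then detect per image.
if __name__ == '__main__':
    detector = TextDetector()
    detector.let_load()
    image = imgproc.loadImage('sample.jpg')  # hypothetical input path
    bboxes, polys, score_text = detector.text_detect(image)
    print('{} text regions detected'.format(len(bboxes)))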
Exemple #19
class DETECTION:
    def __init__(self):
        # model settings #
        self.trained_model = 'model/craft_mlt_25k.pth'
        self.text_threshold = 0.7
        self.low_text = 0.4
        self.link_threshold = 0.4
        self.cuda = True
        self.canvas_size = 1280
        self.mag_ratio = 1.5
        self.poly = True
        self.show_time = False
        self.video_folder = 'input/'
        self.refine = False
        self.refiner_model = 'model/craft_refiner_CTW1500.pth'
        self.interpolation = cv2.INTER_LINEAR

        #import model
        self.net = CRAFT()  # initialize

    def copyStateDict(self, state_dict):
        if list(state_dict.keys())[0].startswith("module"):
            start_idx = 1
        else:
            start_idx = 0
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = ".".join(k.split(".")[start_idx:])
            new_state_dict[name] = v
        return new_state_dict

    def load_model(self):
        print('Loading weights from checkpoint (' + self.trained_model + ')')
        if self.cuda:
            self.net.load_state_dict(
                self.copyStateDict(torch.load(self.trained_model)))
        else:
            self.net.load_state_dict(
                self.copyStateDict(
                    torch.load(self.trained_model, map_location='cpu')))

        if self.cuda:
            self.net = self.net.cuda()
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = False

        self.net.eval()

        # LinkRefiner
        self.refine_net = None
        if self.refine:
            from refinenet import RefineNet
            self.refine_net = RefineNet()
            print('Loading weights of refiner from checkpoint (' +
                  self.refiner_model + ')')
            if self.cuda:
                self.refine_net.load_state_dict(
                    self.copyStateDict(torch.load(self.refiner_model)))
                self.refine_net = self.refine_net.cuda()
                self.refine_net = torch.nn.DataParallel(self.refine_net)
            else:
                self.refine_net.load_state_dict(
                    self.copyStateDict(
                        torch.load(self.refiner_model, map_location='cpu')))

            self.refine_net.eval()
            self.poly = True

    def test_net(self, image_opencv):

        # resize
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
            image_opencv,
            self.canvas_size,
            interpolation=self.interpolation,
            mag_ratio=self.mag_ratio)
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]

        if self.cuda:
            x = x.cuda()

        # forward pass (no_grad avoids building the autograd graph at inference)
        with torch.no_grad():
            y, feature = self.net(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # refine link
        t0 = time.time()
        if self.refine_net is not None:
            with torch.no_grad():
                y_refiner = self.refine_net(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().data.numpy()
        t0 = time.time() - t0
        t1 = time.time()

        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                               self.text_threshold,
                                               self.link_threshold,
                                               self.low_text, self.poly)
        #print(boxes)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None: polys[k] = boxes[k]
        t1 = time.time() - t1

        if self.show_time:
            print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))
        return boxes, polys
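
# Minimal usage sketch for DETECTION (an assumption, not part of the original
# example):
if __name__ == '__main__':
    detector = DETECTION()
    detector.load_model()
    frame = cv2.imread('input/frame_0001.jpg')  # hypothetical frame path
    boxes, polys = detector.test_net(frame)
    print('{} boxes detected'.format(len(boxes)))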
Exemple #20
def craftnet():
    # load net
    net = CRAFT()  # initialize

    print('Loading weights from checkpoint (' + CONFIG['trained_model'] + ')')
    if CONFIG['cuda']:
        net.load_state_dict(copyStateDict(torch.load(CONFIG['trained_model'])))
    else:
        net.load_state_dict(
            copyStateDict(
                torch.load(CONFIG['trained_model'], map_location='cpu')))

    if CONFIG['cuda']:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    # LinkRefiner
    refine_net = None
    if CONFIG['refine']:
        from refinenet import RefineNet
        refine_net = RefineNet()
        #print('Loading weights of refiner from checkpoint (' + CONFIG['refiner_model'] + ')')
        if CONFIG['cuda']:
            refine_net.load_state_dict(
                copyStateDict(torch.load(CONFIG['refiner_model'])))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(
                copyStateDict(
                    torch.load(CONFIG['refiner_model'], map_location='cpu')))

        refine_net.eval()
        CONFIG['poly'] = True

    t = time.time()

    # load data
    for k, image_path in enumerate(image_list):
        #print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
        orig, image = imgproc.loadImage(image_path)

        bboxes, polys, score_text = test_net(
            net, image, CONFIG['text_threshold'], CONFIG['link_threshold'],
            CONFIG['low_text'], CONFIG['cuda'], CONFIG['poly'], refine_net)

        # save score text
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        cv2.imwrite(mask_file, score_text)

        file_utils.saveResult(image_path,
                              image[:, :, ::-1],
                              polys,
                              dirname=result_folder)

    information = []
    for file in os.listdir('result/temp_result'):
        filename = os.path.splitext(file)[0]
        extension = os.path.splitext(file)[1]
        if extension == '.tif':
            #!tesseract oem 13 --tessdata-dir ./result/ ./result/temp_result{filename}.png ./test/{filename+'-eng'} -l eng+vie
            image = Image.open('result/temp_result/' + file)

            config = '--psm 10 --oem 3 -l vie+eng'
            raw_text = pytesseract.image_to_string(image,
                                                   lang='eng+vie',
                                                   config=config)
            information.append(raw_text)

    X = {
        "name": [],
        "phone": [],
        "email": [],
        "company": [],
        "website": [],
        "address": [],
        "extra_information": []
    }
    for info in information:
        if parse_info(info):

            email_parse = parse_email(info)
            if email_parse is not None:
                X["email"].append(email_parse)
                continue

            phone_parse = parse_phone(info)
            if phone_parse is not None:
                X["phone"].append(phone_parse)
                continue

            website_parse = parse_website(info)
            if website_parse is not None:
                X["website"].append(website_parse)
                continue

            company_parse = parse_company(info)
            if company_parse is not None:
                X["company"].append(company_parse)
                continue

            address_parse = parse_address(info)
            if address_parse is not None:
                X["address"].append(address_parse)
                continue

            name_parse = parse_name(info)
            if name_parse is not None:
                X["name"].append(info)
                continue

            X["extra_information"].append(info)
    return X
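
# The parse_* helpers used above are not shown in this listing. A minimal
# regex-based sketch of one of them, assuming the contract implied by the
# calling code (return the match or None):
import re

def parse_email(text):
    # return the first email-like token in the OCR'd line, or None
    match = re.search(r'[\w.+-]+@[\w-]+\.[\w.-]+', text)
    return match.group(0) if match else None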
Exemple #21
def train(args):
    # load net
    net = CRAFT()  # initialize

    if not os.path.exists(args.trained_model):
        args.trained_model = None

    if args.trained_model is not None:
        print('Loading weights from checkpoint (' + args.trained_model + ')')
        if args.cuda:
            net.load_state_dict(test.copyStateDict(torch.load(args.trained_model)))
        else:
            net.load_state_dict(test.copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    # # LinkRefiner
    # refine_net = None
    # if args.refine:
    #     from refinenet import RefineNet
    #
    #     refine_net = RefineNet()
    #     print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
    #     if args.cuda:
    #         refine_net.load_state_dict(test.copyStateDict(torch.load(args.refiner_model)))
    #         refine_net = refine_net.cuda()
    #         refine_net = torch.nn.DataParallel(refine_net)
    #     else:
    #         refine_net.load_state_dict(test.copyStateDict(torch.load(args.refiner_model, map_location='cpu')))
    #
    #     args.poly = True

    criterion = craft_utils.CRAFTLoss()
    optimizer = optim.Adam(net.parameters(), args.learning_rate)
    train_data = CRAFTDataset(args)
    dataloader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True)
    t0 = time.time()

    for epoch in range(args.max_epoch):
        pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f'Epoch {epoch}')
        running_loss = 0.0
        for i, data in pbar:
            x, y_region, y_link, y_conf = data
            if args.cuda:
                x = x.cuda()
                y_region = y_region.cuda()
                y_link = y_link.cuda()
                y_conf = y_conf.cuda()
            optimizer.zero_grad()

            y, feature = net(x)

            score_text = y[:, :, :, 0]
            score_link = y[:, :, :, 1]

            L = criterion(score_text, score_link, y_region, y_link, y_conf)

            L.backward()
            optimizer.step()

            running_loss += L.data.item()
            if i % 2000 == 1999 or i == len(dataloader) - 1:
                pbar.set_postfix_str('[%d, %5d] loss: %.3f' %
                                     (epoch + 1, i + 1, running_loss / min(i + 1, 2000)))
                running_loss = 0.0

    # Save trained model
    torch.save(net.state_dict(), args.weight)

    print(f'training finished\n {time.time() - t0} spent for {args.max_epoch} epochs')
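
# Hedged entry-point sketch for train() (argument names inferred from the
# attributes the function reads; defaults are assumptions, and CRAFTDataset
# may read further fields not listed here):
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='CRAFT training')
    parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str)
    parser.add_argument('--cuda', default=True, type=lambda v: str(v).lower() in ('1', 'true', 'yes'))
    parser.add_argument('--learning_rate', default=1e-4, type=float)
    parser.add_argument('--batch_size', default=8, type=int)
    parser.add_argument('--max_epoch', default=10, type=int)
    parser.add_argument('--weight', default='weights/craft_trained.pth', type=str)
    train(parser.parse_args())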
Exemple #22
def test(text_detection_modelpara, ocr_modelpara, dictionary_target):
    # load net
    net = CRAFT()  # initialize

    print('Loading text detection model from checkpoint {}'.format(
        text_detection_modelpara))
    if args.cuda:
        net.load_state_dict(copyStateDict(
            torch.load(text_detection_modelpara)))
    else:
        net.load_state_dict(
            copyStateDict(
                torch.load(text_detection_modelpara, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    params = {}
    params['n'] = 256
    params['m'] = 256
    params['dim_attention'] = 512
    params['D'] = 684
    params['K'] = 5748
    params['growthRate'] = 24
    params['reduction'] = 0.5
    params['bottleneck'] = True
    params['use_dropout'] = True
    params['input_channels'] = 3
    params['cuda'] = args.cuda

    # load model
    OCR = Encoder_Decoder(params)
    if args.cuda:
        OCR.load_state_dict(copyStateDict(torch.load(ocr_modelpara)))
    else:
        OCR.load_state_dict(
            copyStateDict(torch.load(ocr_modelpara, map_location='cpu')))
    if args.cuda:
        #OCR = OCR.cuda()
        OCR = torch.nn.DataParallel(OCR)
        cudnn.benchmark = False

    OCR.eval()
    net.eval()

    # load dictionary
    worddicts = load_dict(dictionary_target)
    worddicts_r = [None] * len(worddicts)
    for kk, vv in worddicts.items():
        worddicts_r[vv] = kk
    t = time.time()

    fontPIL = '/usr/share/fonts/truetype/fonts-japanese-gothic.ttf'  # japanese font
    size = 40
    colorBGR = (0, 0, 255)

    paper = ET.Element('paper')
    paper.set('xmlns', "http://codh.rois.ac.jp/modern-magazine/")
    # load data
    for k, image_path in enumerate(image_list[:]):
        print("Test image {:d}/{:d}: {:s}".format(k + 1, len(image_list),
                                                  image_path),
              end='\r')
        res_img_file = result_folder + "res_" + os.path.basename(image_path)

        #print (res_img_file, os.path.basename(image_path), os.path.exists(res_img_file))
        #if os.path.exists(res_img_file): continue
        #image = imgproc.loadImage(image_path)
        '''image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        ret2,image = cv2.threshold(image,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
        height = image.shape[0]
        width = image.shape[1]
        scale = 1000.0/height
        H = int(image.shape[0] * scale)
        W = int(image.shape[1] * scale)
        image = cv2.resize(image , (W, H))
        print(image.shape, image_path)
        cv2.imwrite(image_path, image) 
        continue'''
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image.shape[0], image.shape[1]
        print(image_path)
        page = ET.SubElement(paper, "page")
        page.set('file', os.path.basename(image_path).replace('.jpg', ''))
        page.set('height', str(h))
        page.set('width', str(w))
        page.set('dpi', str(100))
        page.set('number', str(1))

        bboxes, polys, score_text = test_net(net, image, args.text_threshold,
                                             args.link_threshold,
                                             args.low_text, args.cuda,
                                             args.poly)
        text = []
        locations = []
        for i, box in enumerate(bboxes):
            poly = np.array(box).astype(np.int32)
            min_x = np.min(poly[:, 0])
            max_x = np.max(poly[:, 0])
            min_y = np.min(poly[:, 1])
            max_y = np.max(poly[:, 1])
            if min_x < 0:
                min_x = 0
            if min_y < 0:
                min_y = 0

            #image = cv2.rectangle(image,(min_x,min_y),(max_x,max_y),(0,255,0),3)
            input_img = image[min_y:max_y, min_x:max_x]

            w = max_x - min_x + 1
            h = max_y - min_y + 1
            line = ET.SubElement(page, "line")
            line.set("x", str(min_x))
            line.set("y", str(min_y))
            line.set("height", str(h))
            line.set("width", str(w))
            if w < h:
                rate = 20.0 / w
                w = int(round(w * rate))
                h = int(round(h * rate / 20.0) * 20)
            else:
                rate = 20.0 / h
                w = int(round(w * rate / 20.0) * 20)
                h = int(round(h * rate))
            #print (w, h, rate)
            input_img = cv2.resize(input_img, (w, h))

            # ITU-R BT.601 luma weights: convert the RGB crop to one grayscale channel
            mat = np.zeros([1, h, w], dtype='uint8')
            mat[0, :, :] = (0.299 * input_img[:, :, 0] +
                            0.587 * input_img[:, :, 1] +
                            0.114 * input_img[:, :, 2])

            xx_pad = mat.astype(np.float32) / 255.
            xx_pad = torch.from_numpy(xx_pad[None, :, :, :])  # (1,1,H,W)
            if args.cuda:
                xx_pad = xx_pad.cuda()  # .cuda() returns a new tensor; the result must be reassigned
            with torch.no_grad():
                sample, score, alpha_past_list = gen_sample(OCR,
                                                            xx_pad,
                                                            params,
                                                            args.cuda,
                                                            k=10,
                                                            maxlen=600)
            score = score / np.array([len(s) for s in sample])
            ss = sample[score.argmin()]
            alpha_past = alpha_past_list[score.argmin()]
            result = ''
            i = 0
            location = []
            for vv in ss:

                if vv == 0:  # <eol>
                    break
                alpha = alpha_past[i]
                if i != 0: alpha = alpha_past[i] - alpha_past[i - 1]
                (y, x) = np.unravel_index(np.argmax(alpha, axis=None),
                                          alpha.shape)
                #print (int(16* x /rate), int(16* y/rate) , chr(int(worddicts_r[vv],16)))
                location.append(
                    [int(16 * x / rate) + min_x,
                     int(16 * y / rate) + min_y])
                #image = cv2.circle(image,(int(16* x/rate) -  8 + min_x, int(16* y/rate) + 8 + min_y),25, (0,0,255), -1)

                result += chr(int(worddicts_r[vv], 16))
                '''char = ET.SubElement(line, "char") 
                char.set('num_cand', '1') 
                char.set('x', str(int(16* x/rate) -  8 + min_x)) 
                char.set('y', str(int(16* y/rate) + 8 + min_y)) 
                res = ET.SubElement(char, "result") 
                res.set('CC', str(100))
                res.text = chr(int(worddicts_r[vv],16))
                cand = ET.SubElement(char, "cand") 
                cand.set('CC', str(100))
                cand.text = chr(int(worddicts_r[vv],16))'''
                i += 1
            line.text = result
            text.append(result)
            locations.append(location)
            image = cv2_putText_1(img=image,
                                  text=result,
                                  org=(min_x, max_x, min_y, max_y),
                                  fontFace=fontPIL,
                                  fontScale=size,
                                  color=colorBGR)

        print('save image')
        # save score text
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        #cv2.imwrite(mask_file, score_text)
        file_utils.saveResult(image_path, image, polys, dirname=result_folder)

    xml_string = ET.tostring(paper, 'Shift_JIS')

    fout = codecs.open('./data/result.xml', 'w', 'shift_jis')
    fout.write(xml_string.decode('shift_jis'))
    fout.close()

    print("elapsed time : {}s".format(time.time() - t))
Exemple #23
class CraftOne:
    def __init__(self, cuda=True):
        self.cuda = cuda
        for k, v in config_craft.items():
            setattr(self, k, v)

        self.net = CRAFT()
        print(f'Loading weights from checkpoint ({self.trained_model})')
        if self.cuda:
            self.net.load_state_dict(
                copyStateDict(torch.load(self.trained_model)))
        else:
            self.net.load_state_dict(
                copyStateDict(
                    torch.load(self.trained_model, map_location='cpu')))

        if self.cuda:
            self.net = self.net.cuda()
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = False

        self.net.eval()

        # LinkRefiner
        self.refine_net = None
        if self.refine:
            from refinenet import RefineNet
            self.refine_net = RefineNet()
            print(
                f'Loading weights of refiner from checkpoint ({self.refiner_model})'
            )
            if self.cuda:
                self.refine_net.load_state_dict(
                    copyStateDict(torch.load(self.refiner_model)))
                self.refine_net = self.refine_net.cuda()
                self.refine_net = torch.nn.DataParallel(self.refine_net)
            else:
                self.refine_net.load_state_dict(
                    copyStateDict(
                        torch.load(self.refiner_model, map_location='cpu')))

            self.refine_net.eval()
            self.poly = True

    def detect_text(self,
                    net,
                    image,
                    text_threshold,
                    link_threshold,
                    low_text,
                    poly,
                    refine_net=None):
        t0 = time()

        # resize
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
            image,
            self.canvas_size,
            interpolation=cv2.INTER_LINEAR,
            mag_ratio=self.mag_ratio)
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]
        if self.cuda:
            x = x.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = net(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # refine link
        if refine_net is not None:
            with torch.no_grad():
                y_refiner = refine_net(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

        t0 = time() - t0
        t1 = time()

        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                               text_threshold, link_threshold,
                                               low_text, poly)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None: polys[k] = boxes[k]

        t1 = time() - t1

        # render results (optional)
        render_img = score_text.copy()
        render_img = np.hstack((render_img, score_link))
        ret_score_text = imgproc.cvt2HeatmapImg(render_img)

        if self.show_time:
            print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

        return boxes, polys, ret_score_text

    def main(self, image):
        bboxes, polys, score_text = self.detect_text(self.net, image,
                                                     self.text_threshold,
                                                     self.link_threshold,
                                                     self.low_text, self.poly,
                                                     self.refine_net)

        return bboxes, polys, score_text
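
# Minimal usage sketch for CraftOne (an assumption, not part of the original
# example; config_craft must provide the attributes read in __init__):
if __name__ == '__main__':
    craft = CraftOne(cuda=False)
    image = imgproc.loadImage('sample.jpg')  # hypothetical input path
    bboxes, polys, score_text = craft.main(image)
    print('{} text boxes'.format(len(bboxes)))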
Exemple #24
    #print(input.size())
    net = CRAFT()
    #net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/CRAFT_net_050000.pth')))
    #net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/1-7.pth')))
    #net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/craft_mlt_25k.pth')))
    #net.load_state_dict(copyStateDict(torch.load('vgg16_bn-6c64b313.pth')))
    #realdata = realdata(net)
    # realdata = ICDAR2015(net, '/data/CRAFT-pytorch/icdar2015', target_size = 768)
    # real_data_loader = torch.utils.data.DataLoader(
    #     realdata,
    #     batch_size=10,
    #     shuffle=True,
    #     num_workers=0,
    #     drop_last=True,
    #     pin_memory=True)
    net = net.cuda()
    #net = CRAFT_net

    # if args.cuda:
    net = torch.nn.DataParallel(net, device_ids=[0, 1, 2, 3]).cuda()
    cudnn.benchmark = True
    # realdata = ICDAR2015(net, '/data/CRAFT-pytorch/icdar2015', target_size=768)
    # real_data_loader = torch.utils.data.DataLoader(
    #     realdata,
    #     batch_size=10,
    #     shuffle=True,
    #     num_workers=0,
    #     drop_last=True,
    #     pin_memory=True)

    optimizer = optim.Adam(net.parameters(), lr=1e-4)  # assumed learning rate; the original call is cut off at this point
Exemple #25
class TextDetection:
    def __init__(self):

        self.trained_model = '../chinese-ocr/weights/craft_mlt_25k.pth'
        self.text_threshold = 0.75
        self.low_text = 0.6
        self.link_threshold = 0.9
        self.cuda = True
        self.canvas_size = 1280
        self.mag_ratio = 1.5
        self.poly = False
        self.show_time = False

        self.net = CRAFT()

        self.net.load_state_dict(
            copy_state_dict(torch.load(self.trained_model)))
        self.net = self.net.cuda()

        cudnn.benchmark = False

        self.net.eval()

    def get_bounding_box(self, image_file, verbose=False):
        """
        Get the bounding boxes from image_file
        :param image_file
        :param verbose
        :return:
        """
        image = cv2.imread(image_file)
        img_dim = image.shape
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
            image,
            self.canvas_size,
            interpolation=cv2.INTER_LINEAR,
            mag_ratio=self.mag_ratio)

        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]
        if self.cuda:
            x = x.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = self.net(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                               self.text_threshold,
                                               self.link_threshold,
                                               self.low_text, self.poly)

        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)

        center_point = []
        for i, _b in enumerate(boxes):
            b = np.array(_b, dtype=np.int16)
            xmin = np.min(b[:, 0])
            ymin = np.min(b[:, 1])

            xmax = np.max(b[:, 0])
            ymax = np.max(b[:, 1])
            x_m = xmin + (xmax - xmin) / 2
            y_m = ymin + (ymax - ymin) / 2
            center_point.append([x_m, y_m])

        list_images = get_box_img(boxes, image)

        if verbose:
            for _b in boxes:
                b = np.array(_b, dtype=np.int16)
                xmin = np.min(b[:, 0])
                ymin = np.min(b[:, 1])

                xmax = np.max(b[:, 0])
                ymax = np.max(b[:, 1])

                r = image[ymin:ymax, xmin:xmax, :].copy()

        return boxes, list_images, center_point, img_dim
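
# Minimal usage sketch for TextDetection (an assumption, not part of the
# original example):
if __name__ == '__main__':
    detector = TextDetection()
    boxes, crops, centers, img_dim = detector.get_bounding_box('card.jpg')  # hypothetical path
    print('{} boxes, image shape {}'.format(len(boxes), img_dim))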