def test(modelpara):
    # load net
    net = CRAFT()     # initialize

    print('Loading weights from checkpoint {}'.format(modelpara))
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(modelpara)))
    else:
        net.load_state_dict(copyStateDict(torch.load(modelpara, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    t = time.time()

    # load data
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
        image = imgproc.loadImage(image_path)

        bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly)
        # save score text
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        #cv2.imwrite(mask_file, score_text)

        file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)

    print("elapsed time : {}s".format(time.time() - t))
Example #2
def main(trained_model='weights/craft_mlt_25k.pth',
         text_threshold=0.7, low_text=0.4, link_threshold=0.4, cuda=True,
         canvas_size=1280, mag_ratio=1.5,
         poly=False, show_time=False, test_folder='/data/',
         refine=True, refiner_model='weights/craft_refiner_CTW1500.pth'):
# if __name__ == '__main__':
    # load net
    net = CRAFT()     # initialize

    print('Loading weights from checkpoint (' + trained_model + ')')
    if cuda:
        net.load_state_dict(copyStateDict(torch.load(trained_model)))
    else:
        net.load_state_dict(copyStateDict(torch.load(trained_model, map_location='cpu')))

    if cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    # LinkRefiner
    refine_net = None
    if refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + refiner_model + ')')
        if cuda:
            refine_net.load_state_dict(copyStateDict(torch.load(refiner_model)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(copyStateDict(torch.load(refiner_model, map_location='cpu')))

        refine_net.eval()
        poly = True

    t = time.time()

    # load data (image_path and result_folder are module-level globals in the original script)
    image = imgproc.loadImage(image_path)

    bboxes, polys, score_text = test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net)

    # save score text
    filename, file_ext = os.path.splitext(os.path.basename(image_path))
    mask_file = result_folder + "/res_" + filename + '_mask.jpg'
    cv2.imwrite(mask_file, score_text)

    final_img = file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)
    
    print("elapsed time : {}s".format(time.time() - t))
Example #3
def LoadDetectionModel(args):
    net = CRAFT()     # initialize
    print('Loading weights from checkpoint (' + args.trained_model + ')')
    net.load_state_dict(copyStateDict(torch.load(args.trained_model)))  # add map_location='cpu' to load without a GPU

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()
    return net
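LoadDetectionModel only reads trained_model and cuda from its argument, so any attribute container works in place of parsed CLI args; a minimal usage sketch (the checkpoint path is an assumption):

import argparse
import torch

args = argparse.Namespace(trained_model='weights/craft_mlt_25k.pth',
                          cuda=torch.cuda.is_available())
net = LoadDetectionModel(args)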
Example #4
def main():
    # load net
    net = CRAFT()     # initialize

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
    else:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    # LinkRefiner
    refine_net = None
    if args.refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
        if args.cuda:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))

        refine_net.eval()
        args.poly = True

    t = time.time()
    print(image_list)
    # load data
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
        image = imgproc.loadImage(image_path)

        bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, refine_net)

        # save score text
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        cv2.imwrite(mask_file, score_text)

        file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)

    # print("elapsed time : {}s".format(time.time() - t))
Example #5
def get_detector(trained_model, device='cpu'):
    net = CRAFT()

    if device == 'cpu':
        net.load_state_dict(
            copyStateDict(torch.load(trained_model, map_location=device)))
    else:
        net.load_state_dict(
            copyStateDict(torch.load(trained_model, map_location=device)))
        net = torch.nn.DataParallel(net).to(device)
        cudnn.benchmark = False

    net.eval()
    return net
Example #6
def createModel():
    net = CRAFT()

    weightPath = os.path.join(settings.BASE_DIR,
                              'CRAFT/weights/craft_mlt_25k.pth')

    print('Loading weights from checkpoint (' + weightPath + ')')

    net.load_state_dict(copyStateDict(torch.load(weightPath)))
    net = net.cuda()
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = False
    net.eval()
    return net
Example #7
def load_detection_model():
  parser = argparse.ArgumentParser(description='CRAFT Text Detection')
  parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str, help='pretrained model')
  parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
  parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
  parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
  parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
  parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
  parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
  parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
  parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
  parser.add_argument('--test_folder', default='/data/', type=str, help='folder path to input images')
  parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner')
  parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model')
  args = parser.parse_args(["--trained_model=./models/craft_mlt_25k.pth","--refine", "--refiner_model=./models/craft_refiner_CTW1500.pth"])
  net = CRAFT()     # initialize
  print('Loading weights from checkpoint (' + args.trained_model + ')')
  if args.cuda:
      net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
  else:
      net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))

  if args.cuda:
      net = net.cuda()
      net = torch.nn.DataParallel(net)
      cudnn.benchmark = False

  net.eval()

  # LinkRefiner
  refine_net = None
  if args.refine:
      from refinenet import RefineNet
      refine_net = RefineNet()
      print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
      if args.cuda:
          refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
          refine_net = refine_net.cuda()
          refine_net = torch.nn.DataParallel(refine_net)
      else:
          refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))

      refine_net.eval()
      # args.poly = True
  return net, refine_net, args
Example #8
def runCraftNet(image_list):  # image_list is a list of image file paths

    args = argparse.Namespace(
        canvas_size=1280,
        cuda=False,
        link_threshold=0.4,
        low_text=0.4,
        mag_ratio=1.5,
        poly=False,
        refine=False,
        refiner_model='weights/craft_refiner_CTW1500.pth',
        show_time=False,
        test_folder='images',
        text_threshold=0.7,
        trained_model='craft_mlt_25k.pth')
    net = CRAFT()  # initialize
    net.load_state_dict(
        copyStateDict(torch.load(args.trained_model, map_location='cpu')))
    net.eval()

    # image_list, _, _ = file_utils.get_files(args.test_folder)
    t = time.time()
    # result_folder = './result/'

    # load data
    refine_net = None

    for k, image_path in enumerate(image_list):
        image = imgproc.loadImage(image_path)

        bboxes, polys, score_text = test_net(net, image, args.text_threshold,
                                             args.link_threshold,
                                             args.low_text, args.cuda,
                                             args.poly, refine_net)

    # print("elapsed time : {}s ".format(time.time() - t))
    img = np.array(image[:, :, ::-1])
    txt = []
    for i, box in enumerate(polys):
        poly = np.array(box).astype(np.int32).reshape((-1))
        strResult = ','.join([str(p) for p in poly])
        txt.append(strResult)

    return [img, txt]
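A hedged usage sketch for runCraftNet (the image path is hypothetical); each entry of txt is one detected polygon flattened to a comma-joined string "x1,y1,x2,y2,...":

img, txt = runCraftNet(['images/sample.jpg'])  # hypothetical path
for line in txt:
    print(line)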
Example #9
def test(image, epoch, index, cvt=False):

    print('input image shape {}'.format(image.shape))

    checkpoint = torch.load('/root/data/test_param/{}_{}.pth'.format(
        epoch, index))

    net = CRAFT().cuda()

    net.load_state_dict(copyStateDict(checkpoint['model_state_dict']))

    # image resizing, etc.: assume that's done and proceed

    image = normalizeMeanVariance(image)

    image = cv2.resize(image, (768, 768), interpolation=cv2.INTER_LINEAR)

    x = torch.from_numpy(image).permute(2, 0, 1)
    x = Variable(x.unsqueeze(0).type(torch.FloatTensor))
    x = x.cuda()

    print(x.size())

    with torch.no_grad():
        y, _ = net(x)

    pred_region = y[0, :, :, 0].cpu().data.numpy()
    pred_affinity = y[0, :, :, 1].cpu().data.numpy()

    print(type(pred_region))
    print(pred_region.shape)

    # cvt == True -> Region, Affinity score H x W x C
    # cvt == False -> Region, Affinity score H x W
    if cvt:
        pred_region = Gray2RGB(pred_region)
        pred_affinity = Gray2RGB(pred_affinity)

    return pred_region, pred_affinity
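Gray2RGB is not defined in this snippet; a plausible sketch, assuming it scales a float score map to uint8 and replicates it across three channels:

import cv2
import numpy as np

def Gray2RGB(score_map):
    # assumed helper: map [0, 1] float scores to a 3-channel uint8 image
    img = np.clip(score_map * 255, 0, 255).astype(np.uint8)
    return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)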
Example #10
def get_detector(trained_model, device='cpu', quantize=True):
    net = CRAFT()

    if device == 'cpu':
        net.load_state_dict(
            copyStateDict(torch.load(trained_model, map_location=device)))
        if quantize:
            try:
                torch.quantization.quantize_dynamic(net,
                                                    dtype=torch.qint8,
                                                    inplace=True)
            except:
                pass
    else:
        net.load_state_dict(
            copyStateDict(torch.load(trained_model, map_location=device)))
        net = torch.nn.DataParallel(net).to(device)
        cudnn.benchmark = False

    net.eval()
    return net
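torch.quantization.quantize_dynamic swaps eligible layers (by default nn.Linear and recurrent layers) for dynamically quantized versions to shrink the model and speed up CPU inference, and the bare try/except lets the call degrade silently on builds without quantization support. A hedged usage sketch (the checkpoint path is an assumption):

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = get_detector('weights/craft_mlt_25k.pth', device=device,
                   quantize=(device == 'cpu'))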
Example #11
def main(pth_file_path):
    cuda = True
    net = CRAFT()     # initialize

    print('Loading weights from checkpoint (' + pth_file_path + ')')
    if cuda:
        net.load_state_dict(copyStateDict(torch.load(pth_file_path)))
    else:
        net.load_state_dict(copyStateDict(torch.load(pth_file_path, map_location='cpu')))

    if cuda:
        net = net.cuda()
        cudnn.benchmark = False

    net.eval()

    script_module = torch.jit.script(net)

    file_path_without_ext = os.path.splitext(pth_file_path)[0]
    output_file_path = file_path_without_ext + ".pt"
    script_module.save(output_file_path)
    print("TorchScript model created:", output_file_path)
Example #12
        if polys[k] is None: polys[k] = boxes[k]

    t1 = time.time() - t1

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = cvt2HeatmapImg(render_img)

    if show_time: print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, polys, ret_score_text


net = CRAFT()
net.load_state_dict(copyStateDict(torch.load(trained_model_path, map_location='cpu')))
net.eval()

# image_path = './doc/2.jpg'
# image = loadImage(image_path)
# bboxes, polys, score_text = test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net)
#
# poly_indexes = {}
# central_poly_indexes = []
# for i in range(len(polys)):
#     poly_indexes[i] = polys[i]
#     x_central = (polys[i][0][0] + polys[i][1][0] + polys[i][2][0] + polys[i][3][0]) / 4
#     y_central = (polys[i][0][1] + polys[i][1][1] + polys[i][2][1] + polys[i][3][1]) / 4
#     central_poly_indexes.append({i: [int(x_central), int(y_central)]})

import copy
Example #13
def applyCraft(image_file):
    # Initialize CRAFT parameters
    text_threshold = 0.7
    low_text = 0.4
    link_threshold = 0.4
    cuda = False
    canvas_size = 1280
    mag_ratio = 1.5
    # if text image present curve --> poly=true
    poly = False
    refine = False
    show_time = False
    refine_net = None
    trained_model_path = './app/CRAFT/craft_mlt_25k.pth'

    net = CRAFT()
    net.load_state_dict(
        copyStateDict(torch.load(trained_model_path, map_location='cpu')))
    net.eval()

    image = imgproc.loadImage(image_file)

    bboxes, polys, score_text = test_net(net, canvas_size, mag_ratio, image,
                                         text_threshold, link_threshold,
                                         low_text, cuda, poly, refine_net)

    # Compute coordinate of central point in each bounding box returned by CRAFT
    # Purpose: easier for us to make cluster in G-DBScan step
    poly_indexes = {}
    central_poly_indexes = []
    for i in range(len(polys)):
        poly_indexes[i] = polys[i]
        x_central = (polys[i][0][0] + polys[i][1][0] + polys[i][2][0] +
                     polys[i][3][0]) / 4
        y_central = (polys[i][0][1] + polys[i][1][1] + polys[i][2][1] +
                     polys[i][3][1]) / 4
        central_poly_indexes.append({i: [int(x_central), int(y_central)]})

    # for i in central_poly_indexes:
    #   print(i)

    # For each of these cordinates convert them to new Point instances
    X = []

    for idx, x in enumerate(central_poly_indexes):
        point = Point(x[idx][0], x[idx][1], idx)
        X.append(point)

    # Cluster these central points
    clustered = GDBSCAN(Points(X), n_pred, 1, w_card)

    cluster_values = []
    for cluster in clustered:
        sort_cluster = sorted(cluster, key=lambda elem: (elem.x, elem.y))
        max_point_id = sort_cluster[len(sort_cluster) - 1].id
        min_point_id = sort_cluster[0].id
        max_rectangle = sorted(poly_indexes[max_point_id],
                               key=lambda elem: (elem[0], elem[1]))
        min_rectangle = sorted(poly_indexes[min_point_id],
                               key=lambda elem: (elem[0], elem[1]))

        right_above_max_vertex = max_rectangle[len(max_rectangle) - 1]
        right_below_max_vertex = max_rectangle[len(max_rectangle) - 2]
        left_above_min_vertex = min_rectangle[0]
        left_below_min_vertex = min_rectangle[1]

        if (int(min_rectangle[0][1]) > int(min_rectangle[1][1])):
            left_above_min_vertex = min_rectangle[1]
            left_below_min_vertex = min_rectangle[0]
        if (int(max_rectangle[len(max_rectangle) - 1][1]) < int(
                max_rectangle[len(max_rectangle) - 2][1])):
            right_above_max_vertex = max_rectangle[len(max_rectangle) - 2]
            right_below_max_vertex = max_rectangle[len(max_rectangle) - 1]

        cluster_values.append([
            left_above_min_vertex, left_below_min_vertex,
            right_above_max_vertex, right_below_max_vertex
        ])

    image = imgproc.loadImage(image_file)
    img = np.array(image[:, :, ::-1])
    img = img.astype('uint8')
    ocr_res = []
    for i, box in enumerate(cluster_values):
        poly = np.array(box).astype(np.int32).reshape((-1))
        poly = poly.reshape(-1, 2)

        rect = cv2.boundingRect(poly)
        x, y, w, h = rect
        cropped = img[y:y + h, x:x + w].copy()

        # Preprocess cropped segment
        cropped = cv2.resize(cropped,
                             None,
                             fx=5,
                             fy=5,
                             interpolation=cv2.INTER_LINEAR)
        cropped = cv2.cvtColor(cropped, cv2.COLOR_BGR2GRAY)
        cropped = cv2.GaussianBlur(cropped, (3, 3), 0)
        cropped = cv2.bilateralFilter(cropped, 5, 25, 25)
        cropped = cv2.dilate(cropped, None, iterations=1)
        cropped = cv2.threshold(cropped, 0, 255,
                                cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
        #cropped = cv2.threshold(cropped, 90, 255, cv2.THRESH_BINARY)[1]
        #cropped = cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)

        ocr_res.append(pytesseract.image_to_string(cropped, lang='eng'))

    return ocr_res
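Point, Points, n_pred and w_card belong to a G-DBSCAN implementation that is not shown here; a hedged sketch of the shapes these helpers typically take (the neighborhood test and its radius are assumptions, and Points is just an iterable wrapper over the list):

import math

class Point:
    def __init__(self, x, y, id):
        self.x, self.y, self.id = x, y, id

def n_pred(p1, p2):
    # neighborhood predicate: treat two box centers as neighbors when close enough
    return math.hypot(p1.x - p2.x, p1.y - p2.y) <= 30  # assumed radius

def w_card(points):
    # weighted cardinality: plain point count
    return len(points)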
Example #14
def ground_truth(args):
    # initiate pretrained network
    net = CRAFT()  # initialize

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        net.load_state_dict(test.copyStateDict(torch.load(args.trained_model)))
    else:
        net.load_state_dict(test.copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    filelist, _, _ = file_utils.list_files('/home/ubuntu/Kyumin/Autotation/data/IC13/images')

    for img_name in filelist:
        # get datapath
        if 'train' in img_name:
            label_name = img_name.replace('images/train/', 'labels/train/gt_').replace('jpg', 'txt')
        else:
            label_name = img_name.replace('images/test/', 'labels/test/gt_').replace('jpg', 'txt')
        label_dir = img_name.replace('Autotation', 'craft').replace('images', 'labels').replace('.jpg', '/')

        os.makedirs(label_dir, exist_ok=True)

        image = imgproc.loadImage(img_name)

        gt_boxes = []
        gt_words = []
        with open(label_name, 'r', encoding='utf-8-sig') as f:
            lines = f.readlines()
        for line in lines:
            if 'IC13' in img_name:  # IC13
                gt_box, gt_word, _ = line.split('"')
                if 'train' in img_name:
                    x1, y1, x2, y2 = [int(a) for a in gt_box.strip().split(' ')]
                else:
                    x1, y1, x2, y2 = [int(a.strip()) for a in gt_box.split(',') if a.strip().isdigit()]
                gt_boxes.append(np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]]))
                gt_words.append(gt_word)
            elif 'IC15' in img_name:
                gt_data = line.strip().split(',')
                gt_box = gt_data[:8]
                if len(gt_data) > 9:
                    gt_word = ','.join(gt_data[8:])
                else:
                    gt_word = gt_data[-1]
                gt_box = [int(a) for a in gt_box]
                gt_box = np.reshape(np.array(gt_box), (4, 2))
                gt_boxes.append(gt_box)
                gt_words.append(gt_word)

        score_region, score_link, conf_map = generate_gt(net, image, gt_boxes, gt_words, args)

        torch.save(score_region, label_dir + 'region.pt')
        torch.save(score_link, label_dir + 'link.pt')
        torch.save(conf_map, label_dir + 'conf.pt')
Example #15
        torch.save(score_region, label_dir + 'region.pt')
        torch.save(score_link, label_dir + 'link.pt')
        torch.save(conf_map, label_dir + 'conf.pt')


if __name__ == '__main__':
    import ocr
    score_region = torch.load('/home/ubuntu/Kyumin/craft/data/IC13/labels/train/100/region.pt')
    score_link = torch.load('/home/ubuntu/Kyumin/craft/data/IC13/labels/train/100/link.pt')
    conf_map = torch.load('/home/ubuntu/Kyumin/craft/data/IC13/labels/train/100/conf.pt')
    image = imgproc.loadImage('/home/ubuntu/Kyumin/Autotation/data/IC13/images/train/100.jpg')
    print(score_region.shape, score_link.shape, conf_map.shape)
    # cv2.imshow('original', image)
    cv2.imshow('region', imgproc.cvt2HeatmapImg(score_region))
    cv2.imshow('link', score_link)
    cv2.imshow('conf', conf_map)

    net = CRAFT().cuda()
    net.load_state_dict(test.copyStateDict(torch.load('weights/craft_mlt_25k.pth')))

    net.eval()
    _, _, ref_text, ref_link, _ = test.test_net(net, image, ocr.argument_parser().parse_args())
    cv2.imshow('ref text', imgproc.cvt2HeatmapImg(ref_text))
    cv2.imshow('ref link', ref_link)

    cv2.waitKey(0)
    cv2.destroyAllWindows()


Example #16
    #dataloader = syndata(imgname, charbox, imgtxt)
    dataloader = Synth80k('./data/SynthText', target_size = args.target_size)
    train_loader = torch.utils.data.DataLoader(
        dataloader,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        drop_last=True,
        pin_memory=True)
    batch_syn = iter(train_loader)
    # prefetcher = data_prefetcher(dataloader)
    # input, target1, target2 = prefetcher.next()
    #print(input.size())
    net = CRAFT(freeze=True)
    net.load_state_dict(copyStateDict(torch.load(args.load_model)))
    #net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/CRAFT_net_050000.pth')))
    #net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/1-7.pth')))
    #net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/craft_mlt_25k.pth')))
    #net.load_state_dict(copyStateDict(torch.load('vgg16_bn-6c64b313.pth')))
    #realdata = realdata(net)
    # realdata = ICDAR2015(net, '/data/CRAFT-pytorch/icdar2015', target_size = 768)
    # real_data_loader = torch.utils.data.DataLoader(
    #     realdata,
    #     batch_size=10,
    #     shuffle=True,
    #     num_workers=0,
    #     drop_last=True,
    #     pin_memory=True)
    net = net.cuda()
    #net = CRAFT_net
Example #17
        print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, ret_score_text


if __name__ == '__main__':
    # load net
    net = CRAFT()  # initialize

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    net.load_state_dict(torch.load(args.trained_model))
    net.eval()

    t = time.time()

    # load data
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k + 1, len(image_list),
                                                  image_path),
              end='\r')
        image = imgproc.loadImage(image_path)

        bboxes, score_text = test_net(net, image, args.text_threshold,
                                      args.link_threshold, args.low_text,
                                      args.cuda)
Example #18
if __name__ == '__main__':

    dataloader = Synth80k(root_data + '/SynthText/SynthText', target_size=768)
    train_loader = torch.utils.data.DataLoader(dataloader,
                                               batch_size=1,
                                               shuffle=True,
                                               num_workers=0,
                                               drop_last=True,
                                               pin_memory=True)
    batch_syn = iter(train_loader)
    print("Loaded Synth data.")

    net = CRAFT()

    net.load_state_dict(copyStateDict(
        torch.load('pretrain/craft_mlt_25k.pth')))

    net = net.cuda()

    print("Loaded CRAFT net.")

    net = torch.nn.DataParallel(net, device_ids=[0]).cuda()
    cudnn.benchmark = True
    net.train()
    realdata = ICDAR2015(net, root_data + '/DDI', target_size=768)
    real_data_loader = torch.utils.data.DataLoader(realdata,
                                                   batch_size=1,
                                                   shuffle=True,
                                                   num_workers=0,
                                                   drop_last=True,
                                                   pin_memory=True)
Example #19
def test(modelpara):
    # load net
    net = CRAFT()  # initialize

    print('Loading weights from checkpoint {}'.format(modelpara))
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(modelpara)))
    else:
        net.load_state_dict(
            copyStateDict(torch.load(modelpara, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    t = time.time()

    # load data
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k + 1, len(image_list),
                                                  image_path),
              end='\n')
        image = imgproc.loadImage(image_path)
        res = image.copy()

        # bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly)
        gh_pred, bboxes_pred, polys_pred, size_heatmap = test_net(
            net, image, args.text_threshold, args.link_threshold,
            args.low_text, args.cuda, args.poly)

        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        result_dir = os.path.join(result_folder, filename)
        os.makedirs(result_dir, exist_ok=True)
        for gh_img, field in zip(gh_pred, CLASSES):
            img = imgproc.cvt2HeatmapImg(gh_img)
            img_path = os.path.join(result_dir,
                                    'res_{}_{}.jpg'.format(filename, field))
            cv2.imwrite(img_path, img)
        h, w = image.shape[:2]
        img = cv2.resize(image, size_heatmap)[::, ::, ::-1]
        img_path = os.path.join(result_dir, 'res_{}.jpg'.format(filename))
        cv2.imwrite(img_path, img)

        # # save score text
        # filename, file_ext = os.path.splitext(os.path.basename(image_path))
        # mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        # cv2.imwrite(mask_file, score_text)

        res = cv2.resize(res, size_heatmap)
        for polys, field in zip(polys_pred, CLASSES):
            TEXT_WIDTH = 10 * len(field) + 10
            TEXT_HEIGHT = 15
            polys = np.int32([poly.reshape((-1, 1, 2)) for poly in polys])
            res = cv2.polylines(res, polys, True, (0, 0, 255), 2)
            for poly in polys:
                poly[1, 0] = [poly[0, 0, 0] - 10, poly[0, 0, 1]]
                poly[2, 0] = [poly[0, 0, 0] - 10, poly[0, 0, 1] + TEXT_HEIGHT]
                poly[3, 0] = [
                    poly[0, 0, 0] - TEXT_WIDTH, poly[0, 0, 1] + TEXT_HEIGHT
                ]
                poly[0, 0] = [poly[0, 0, 0] - TEXT_WIDTH, poly[0, 0, 1]]
            res = cv2.fillPoly(res, polys, (224, 224, 224))
            # print(poly)
            for poly in polys:
                res = cv2.putText(res,
                                  field,
                                  tuple(poly[3, 0] + [+5, -5]),
                                  cv2.FONT_HERSHEY_SIMPLEX,
                                  0.4, (0, 0, 0),
                                  thickness=1)
        res_file = os.path.join(result_dir, 'res_{}_bbox.jpg'.format(filename))
        cv2.imwrite(res_file, res[::, ::, ::-1])
        # break

        # file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)

    print("elapsed time : {}s".format(time.time() - t))
Example #20
def main(args, logger=None):
    # load net
    net = CRAFT(pretrained=False)  # initialize

    print('Loading weights from checkpoint {}'.format(args.model_path))
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(args.model_path)))
    else:
        net.load_state_dict(
            copyStateDict(torch.load(args.model_path, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    t = time.time()

    # load data
    """ For test images in a folder """
    image_list, _, _ = file_utils.get_files(args.img_path)
    est_folder = os.path.join(args.rst_path, 'est')
    mask_folder = os.path.join(args.rst_path, 'mask')
    eval_folder = os.path.join(args.rst_path, 'eval')
    cg.folder_exists(est_folder, create_=True)
    cg.folder_exists(mask_folder, create_=True)
    cg.folder_exists(eval_folder, create_=True)

    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k + 1, len(image_list),
                                                  image_path))
        image = imgproc.loadImage(image_path)
        # image = cv2.resize(image, dsize=(768, 768), interpolation=cv2.INTER_CUBIC) ##
        bboxes, polys, score_text = test_net(
            net,
            image,
            text_threshold=args.text_threshold,
            link_threshold=args.link_threshold,
            low_text=args.low_text,
            cuda=args.cuda,
            canvas_size=args.canvas_size,
            mag_ratio=args.mag_ratio,
            poly=args.poly,
            show_time=args.show_time)
        # save score text
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = mask_folder + "/res_" + filename + '_mask.jpg'
        if not (cg.file_exists(mask_file)):
            cv2.imwrite(mask_file, score_text)

        file_utils.saveResult15(image_path,
                                bboxes,
                                dirname=est_folder,
                                mode='test')

    eval_dataset(est_folder=est_folder,
                 gt_folder=args.gt_path,
                 eval_folder=eval_folder,
                 dataset_type=args.dataset_type)
    print("elapsed time : {}s".format(time.time() - t))
Example #21
    use_cuda = torch.cuda.is_available()
    device = 'cuda:0' if use_cuda else 'cpu'

    print('Load the synthetic data ...')
    data_loader = Synth80k('D:/Datasets/SynthText')
    train_loader = torch.utils.data.DataLoader(data_loader,
                                               batch_size=1,
                                               shuffle=True,
                                               num_workers=0,
                                               drop_last=True,
                                               pin_memory=True)
    batch_syn = iter(train_loader)

    print('Prepare the net ...')
    net = CRAFT()
    net.load_state_dict(copyStateDict(
        torch.load('./weigths/synweights/0.pth')))
    net.to(device)
    data_parallel = False
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
        data_parallel = True
    cudnn.benchmark = False

    print('Load the real data')
    real_data = ICDAR2013(net, 'D:/Datasets/ICDAR_2013')
    real_data_loader = torch.utils.data.DataLoader(real_data,
                                                   batch_size=5,
                                                   shuffle=True,
                                                   num_workers=0,
                                                   drop_last=True,
                                                   pin_memory=True)
Example #22
                                            batch_size=8,
                                            shuffle=True,
                                            num_workers=0,
                                            drop_last=True,
                                            pin_memory=True)
 # print("train_loade1", train_loader)
 #batch_syn = iter(train_loader)
 # prefetcher = data_prefetcher(dataloader)
 # input, target1, target2 = prefetcher.next()
 #print(input.size())
 net = CRAFT()
 #net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/CRAFT_net_050000.pth')))
 #net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/1-7.pth')))
 #net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/craft_mlt_25k.pth')))
 net.load_state_dict(
     copyStateDict(
         torch.load(
             './pretrain/data/CRAFT-pytorch/synweights/Syndata.pth')))
 # net.load_state_dict(copyStateDict(torch.load('./pretrain/data/CRAFT-pytorch/vgg16_bn-6c64b313.pth')))
 #realdata = realdata(net)
 # realdata = ICDAR2015(net, '/data/CRAFT-pytorch/icdar2015', target_size = 768)
 # real_data_loader = torch.utils.data.DataLoader(
 #     realdata,
 #     batch_size=10,
 #     shuffle=True,
 #     num_workers=0,
 #     drop_last=True,
 #     pin_memory=True)
 net = net.cuda()
 #net = CRAFT_net
 # if args.cuda:
 # print('__Number CUDA Devices:', torch.cuda.device_count())
Example #23
class TextExtractor():
    def __init__(self, image_folder, extract_text_file, split):
        self.i_folder = image_folder
        #print(image_folder)
        #print("aaaaaaa test")
        self.extract_text_file = extract_text_file
        self.canvas_size = 1280
        self.mag_ratio = 1.5
        self.show_time = False
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.cuda = torch.cuda.is_available()
        self.net = CRAFT()  #(1st model) model to detect words in images
        if self.cuda:
            self.net.load_state_dict(
                self.copyStateDict(
                    torch.load('CRAFT-pytorch/craft_mlt_25k.pth')))
        else:
            self.net.load_state_dict(
                self.copyStateDict(
                    torch.load('CRAFT-pytorch/craft_mlt_25k.pth',
                               map_location='cpu')))
        if self.cuda:
            self.net = self.net.cuda()
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = False
        self.net.eval()
        self.refine_net = None

        self.text_threshold = 0.7
        self.link_threshold = 0.4
        self.low_text = 0.4
        self.poly = False

        self.result_folder = './' + split + '_' + 'intermediate_result/'

        if not os.path.isdir(self.result_folder):
            os.mkdir(self.result_folder)

        #Parameters for image to text model (2nd model)
        self.parser = argparse.ArgumentParser()
        #Data processing
        self.parser.add_argument('--batch_max_length',
                                 type=int,
                                 default=25,
                                 help='maximum-label-length')
        self.parser.add_argument('--imgH',
                                 type=int,
                                 default=32,
                                 help='the height of the input image')
        self.parser.add_argument('--imgW',
                                 type=int,
                                 default=100,
                                 help='the width of the input image')
        self.parser.add_argument('--rgb',
                                 default=False,
                                 action='store_true',
                                 help='use rgb input')
        self.parser.add_argument(
            '--character',
            type=str,
            default='0123456789abcdefghijklmnopqrstuvwxyz',
            help='character label')
        self.parser.add_argument('--sensitive',
                                 action='store_true',
                                 help='for sensitive character mode')
        self.parser.add_argument(
            '--PAD',
            action='store_true',
            help='whether to keep ratio then pad for image resize')
        #Model Architecture
        self.parser.add_argument('--Transformation',
                                 type=str,
                                 default='TPS',
                                 help='Transformation stage. None|TPS')
        self.parser.add_argument(
            '--FeatureExtraction',
            type=str,
            default='ResNet',
            help='FeatureExtraction stage. VGG|RCNN|ResNet')
        self.parser.add_argument('--SequenceModeling',
                                 type=str,
                                 default='BiLSTM',
                                 help='SequenceModeling stage. None|BiLSTM')
        self.parser.add_argument('--Prediction',
                                 type=str,
                                 default='Attn',
                                 help='Prediction stage. CTC|Attn')
        self.parser.add_argument('--num_fiducial',
                                 type=int,
                                 default=20,
                                 help='number of fiducial points of TPS-STN')
        self.parser.add_argument(
            '--input_channel',
            type=int,
            default=1,
            help='the number of input channel of Feature extractor')
        self.parser.add_argument(
            '--output_channel',
            type=int,
            default=512,
            help='the number of output channel of Feature extractor')
        self.parser.add_argument('--hidden_size',
                                 type=int,
                                 default=256,
                                 help='the size of the LSTM hidden state')
        #self.opt = self.parser.parse_args()
        self.opt, unknown = self.parser.parse_known_args()
        #self.opt, unknown = self.parser.parse_known_args()

        if 'CTC' in self.opt.Prediction:
            self.converter = CTCLabelConverter(self.opt.character)
        else:
            self.converter = AttnLabelConverter(self.opt.character)
        self.opt.num_class = len(self.converter.character)
        #print(opt.rgb)
        if self.opt.rgb:
            self.opt.input_channel = 3
        self.opt.num_gpu = torch.cuda.device_count()
        self.opt.batch_size = 192
        #self.opt.batch_size = 3
        self.opt.workers = 0
        self.model = Model(self.opt)  #image to text model (2nd model)
        self.model = torch.nn.DataParallel(self.model).to(self.device)
        self.model.load_state_dict(
            torch.load(
                'deep-text-recognition-benchmark/TPS-ResNet-BiLSTM-Attn.pth',
                map_location=self.device))
        self.model.eval()

    def copyStateDict(self, state_dict):
        if list(state_dict.keys())[0].startswith("module"):
            start_idx = 1
        else:
            start_idx = 0
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = ".".join(k.split(".")[start_idx:])
            new_state_dict[name] = v
        return new_state_dict

    def test_net(self,
                 net,
                 image,
                 text_threshold,
                 link_threshold,
                 low_text,
                 cuda,
                 poly,
                 refine_net=None):
        t0 = time.time()

        # resize
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
            image,
            self.canvas_size,
            interpolation=cv2.INTER_LINEAR,
            mag_ratio=self.mag_ratio)
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]
        if cuda:
            x = x.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = net(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # refine link
        if refine_net is not None:
            with torch.no_grad():
                y_refiner = refine_net(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

        t0 = time.time() - t0
        t1 = time.time()

        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                               text_threshold, link_threshold,
                                               low_text, poly)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None: polys[k] = boxes[k]

        t1 = time.time() - t1

        # render results (optional)
        render_img = score_text.copy()
        render_img = np.hstack((render_img, score_link))
        ret_score_text = imgproc.cvt2HeatmapImg(render_img)

        if self.show_time:
            print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

        return boxes, polys, ret_score_text

    def extract_text(self):
        l = sorted(os.listdir(self.i_folder))
        img_to_index = {}
        count = 0
        for full_file in l:
            split_file = full_file.split(".")
            filename = split_file[0]
            img_to_index[count] = filename
            #print(count, filename)
            count += 1
            #print(filename)
            file_extension = "." + split_file[1]
            #print(filename, file_extension)
            image = imgproc.loadImage(self.i_folder + full_file)
            bboxes, polys, score_text = self.test_net(
                self.net, image, self.text_threshold, self.link_threshold,
                self.low_text, self.cuda, self.poly, self.refine_net)
            img = cv2.imread(self.i_folder + filename + file_extension)
            rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            points = []
            order = []
            for i in range(0, len(bboxes)):
                sample_bbox = bboxes[i]
                min_point = sample_bbox[0]
                max_point = sample_bbox[2]
                for j, p in enumerate(sample_bbox):
                    if (p[0] <= min_point[0]):
                        min_point = (p[0], min_point[1])
                    if (p[1] <= min_point[1]):
                        min_point = (min_point[0], p[1])
                    if (p[0] >= max_point[0]):
                        max_point = (p[0], max_point[1])
                    if (p[1] >= max_point[1]):
                        max_point = (max_point[0], p[1])
                min_point = (max(min(len(rgb_img[0]), min_point[0]),
                                 0), max(min(len(rgb_img), min_point[1]), 0))
                max_point = (max(min(len(rgb_img[0]), max_point[0]),
                                 0), max(min(len(rgb_img), max_point[1]), 0))
                points.append((min_point, max_point))
                order.append(0)
            num_ordered = 0
            rows_ordered = 0
            points_sorted = []
            ordered_points_index = 0
            order_sorted = []
            while (num_ordered < len(points)):
                #find lowest-y that is unordered
                min_y = len(rgb_img)
                min_y_index = -1
                for i in range(0, len(points)):
                    if (order[i] == 0):
                        if (points[i][0][1] <= min_y):
                            min_y = points[i][0][1]
                            min_y_index = i
                rows_ordered += 1
                order[min_y_index] = rows_ordered
                num_ordered += 1
                points_sorted.append(points[min_y_index])
                order_sorted.append(rows_ordered)
                ordered_points_index = len(points_sorted) - 1

                # Group bboxes that are on the same row
                max_y = points[min_y_index][1][1]
                range_y = max_y - min_y
                for i in range(0, len(points)):
                    if (order[i] == 0):
                        min_y_i = points[i][0][1]
                        max_y_i = points[i][1][1]
                        range_y_i = max_y_i - min_y_i
                        if (max_y_i >= min_y and min_y_i <= max_y):
                            overlap = (min(max_y_i, max_y) -
                                       max(min_y_i, min_y)) / (max(
                                           1, min(range_y, range_y_i)))
                            if (overlap >= 0.30):
                                order[i] = rows_ordered
                                num_ordered += 1
                                min_x_i = points[i][0][0]
                                for j in range(ordered_points_index,
                                               len(points_sorted) + 1):
                                    if (j < len(points_sorted)
                                        ):  #insert before
                                        min_x_j = points_sorted[j][0][0]
                                        if (min_x_i < min_x_j):
                                            points_sorted.insert(j, points[i])
                                            order_sorted.insert(
                                                j, rows_ordered)
                                            break
                                    else:  #insert at the end of array
                                        points_sorted.insert(j, points[i])
                                        order_sorted.insert(j, rows_ordered)
                                        break
            for i in range(0, len(points_sorted)):
                min_point = points_sorted[i][0]
                max_point = points_sorted[i][1]
                mask_file = self.result_folder + filename + "_" + str(
                    order_sorted[i]) + "_" + str(i) + file_extension
                crop_image = rgb_img[int(min_point[1]):int(max_point[1]),
                                     int(min_point[0]):int(max_point[0])]
                #print(filename, min_point, max_point, len(rgb_img), len(rgb_img[0]))
                cv2.imwrite(mask_file, crop_image)
        AlignCollate_demo = AlignCollate(imgH=self.opt.imgH,
                                         imgW=self.opt.imgW,
                                         keep_ratio_with_pad=self.opt.PAD)
        demo_data = RawDataset(root=self.result_folder,
                               opt=self.opt)  # use RawDataset
        demo_loader = torch.utils.data.DataLoader(
            demo_data,
            batch_size=self.opt.batch_size,
            shuffle=False,
            num_workers=int(self.opt.workers),
            collate_fn=AlignCollate_demo,
            pin_memory=True)
        f = open(self.extract_text_file, "w")
        count = -1
        curr_order = 1
        curr_filename = ""
        output_string = ""
        end_line = "[SEP] "
        with torch.no_grad():
            for image_tensors, image_path_list in demo_loader:
                batch_size = image_tensors.size(0)
                image = image_tensors.to(self.device)
                #image = (torch.from_numpy(crop_image).unsqueeze(0)).to(device)
                #print(image_path_list)
                #print(image.size())
                length_for_pred = torch.IntTensor([self.opt.batch_max_length] *
                                                  batch_size).to(self.device)
                text_for_pred = torch.LongTensor(batch_size,
                                                 self.opt.batch_max_length +
                                                 1).fill_(0).to(self.device)
                preds = self.model(image, text_for_pred, is_train=False)
                _, preds_index = preds.max(2)
                preds_str = self.converter.decode(preds_index, length_for_pred)
                for path, p in zip(image_path_list, preds_str):
                    #print(path)
                    if 'Attn' in self.opt.Prediction:
                        pred_EOS = p.find('[s]')
                        p = p[:pred_EOS]  # prune after "end of sentence" token ([s])
                    # assumes a file extension of length 4 (.png, .jpg, etc.)
                    path_info = path[len(self.result_folder):].split(".")[0].split("_")
                    #print(curr_filename)
                    #print(path_info[0])
                    #print("PATHINFO: ",path_info[0])
                    if (not (curr_filename == path_info[0])):
                        if (not (curr_filename == "")):
                            f.write(str(count) + "\n")
                            f.write(curr_filename + "\n")
                            f.write(output_string + "\n\n")
                        count += 1
                        curr_filename = img_to_index[count]  #path_info[0]
                        #print("CURRFILE: ", curr_filename)
                        while (not (curr_filename == path_info[0])):
                            f.write(str(count) + "\n")
                            f.write(curr_filename + "\n")
                            f.write("\n\n")
                            count += 1
                            curr_filename = img_to_index[count]  #path_info[0]
                            #print("CURRFILE: ", curr_filename)
                        output_string = ""
                        curr_order = 1
                    if (int(path_info[1]) > curr_order):
                        curr_order += 1
                        output_string += end_line
                    output_string += p + " "
            f.write(str(count) + "\n")
            f.write(curr_filename + "\n")
            f.write(output_string + "\n\n")
        f.close()

        #Go through each image in the i_folder and crop out text

        #generate text and write to text file

    def get_item(self, index):
        f = open(self.extract_text_file, "r")
        Lines = f.readlines()
        return (Lines[4 * index + 2][:-1])
        # read text file


#TEST
#t_e = TextExtractor("data/mmimdb-256/dataset-resized-256max/dev_n/images/","text_extract_output.txt")
#t_e.extract_text()
#text = t_e.get_item(1)
#print(text)
Example #24
class Character_detect(object):
    def __init__(self):
        self.net = CRAFT()
        self.net.load_state_dict(
            self.copyStateDict(
                torch.load("weight/craft_mlt_25k.pth", map_location='cpu')))
        self.net.eval()

    def test_net(self,
                 net,
                 image,
                 text_threshold,
                 link_threshold,
                 low_text,
                 poly,
                 refine_net=None):
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
            image, 1280, interpolation=cv.INTER_LINEAR, mag_ratio=1.5)
        ratio_h = ratio_w = 1 / target_ratio
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]

        with torch.no_grad():
            y, feature = net(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                               text_threshold, link_threshold,
                                               low_text, poly)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None: polys[k] = boxes[k]

        # render results (optional)
        render_img = score_text.copy()
        render_img = np.hstack((render_img, score_link))
        ret_score_text = imgproc.cvt2HeatmapImg(render_img)

        return boxes, polys, ret_score_text

    def detect(self, path):
        image = imgproc.loadImage(path)
        refine_net = None
        bboxes, polys, score_text = self.test_net(self.net, image, 0.7, 999999,
                                                  0.5, False, refine_net)
        bbox = []
        for i, box in enumerate(polys):
            poly = np.array(box).astype(np.int32).reshape((-1))
            bbox.append([poly[0] - 3, poly[1] - 5, poly[2], poly[5] + 5])
        file_utils.saveResult(path,
                              image[:, :, ::-1],
                              polys,
                              dirname="Detect_result/")
        bbox.sort(key=sorting_key)
        return bbox

    def copyStateDict(self, state_dict):
        if list(state_dict.keys())[0].startswith("module"):
            start_idx = 1
        else:
            start_idx = 0
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = ".".join(k.split(".")[start_idx:])
            new_state_dict[name] = v
        return new_state_dict
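A hedged usage sketch for Character_detect (the image path is hypothetical); detect() returns approximate [x1, y1, x2, y2] boxes padded by a few pixels, sorted by the external sorting_key helper, and also writes a visualization under Detect_result/:

cd = Character_detect()
boxes = cd.detect('sample.jpg')  # hypothetical path
print(boxes[:3])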
Example #25
if __name__ == '__main__':
    # synthtextloader = Synth80k('/home/jiachx/publicdatasets/SynthText/SynthText', target_size=768, viz=True, debug=True)
    # train_loader = torch.utils.data.DataLoader(
    #     synthtextloader,
    #     batch_size=1,
    #     shuffle=False,
    #     num_workers=0,
    #     drop_last=True,
    #     pin_memory=True)
    # train_batch = iter(train_loader)
    # image_origin, target_gaussian_heatmap, target_gaussian_affinity_heatmap, mask = next(train_batch)
    from craft import CRAFT
    from torchutil import copyStateDict

    net = CRAFT(freeze=True)
    net.load_state_dict(copyStateDict(torch.load('/ic15_iter_1300.pth')))
    net = net.cuda()
    net = torch.nn.DataParallel(net)
    net.eval()
    dataloader = ICDAR2015(net,
                           '/icdar2015/icdar2015train',
                           target_size=640,
                           viz=True)
    train_loader = torch.utils.data.DataLoader(dataloader,
                                               batch_size=1,
                                               shuffle=False,
                                               num_workers=0,
                                               drop_last=True,
                                               pin_memory=True)
    total = 0
    total_sum = 0
Example #26
class CraftNet(object):
    def __init__(self, ocrObj):
        self.net = CRAFT()    
        print('Loading weights from checkpoint (' + trained_model + ')')
        if isCuda:
            self.net.load_state_dict(copyStateDict(torch.load(trained_model)))
        else:
            self.net.load_state_dict(copyStateDict(torch.load(trained_model, map_location='cpu')))
        if isCuda:
            self.net = self.net.cuda()
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = False 
        self.net.eval()
        self.jsonFile = defaultdict(dict)
        self.ocrObj = ocrObj

    def test_net(self, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
        t0 = time.time()

        # resize
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio)
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
        if cuda:
            x = x.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = self.net(x)

        # make score and link map
        score_text = y[0,:,:,0].cpu().data.numpy()
        score_link = y[0,:,:,1].cpu().data.numpy()

        # refine link
        if refine_net is not None:
            with torch.no_grad():
                y_refiner = refine_net(y, feature)
            score_link = y_refiner[0,:,:,0].cpu().data.numpy()

        t0 = time.time() - t0
        t1 = time.time()

        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None: polys[k] = boxes[k]

        t1 = time.time() - t1

        # render results (optional)
        render_img = score_text.copy()
        render_img = np.hstack((render_img, score_link))
        ret_score_text = imgproc.cvt2HeatmapImg(render_img)

        # if show_time : print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

        return boxes, polys, ret_score_text


    def evaluateBB(self, image_path):
        print(image_path)
        print(os.getcwd())
        image = imgproc.loadImage(image_path)
        imageCpy = image.copy()  # draw on a copy so OCR still reads the clean image
        t = time.time()
        tnew = t
        bboxes, polys, score_text = self.test_net(image, text_thresholdVal, link_thresholdVal, low_textVal, isCuda, polyVal)
        deltaTime = time.time() - tnew
        words = []
        # # save image with BB
        # filename, file_ext = os.path.splitext(os.path.basename(image_path))
        # real_folder = result_folder + '/' + image_path.replace('images', '').replace(filename + file_ext, '')
        # file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=real_folder)
        # evaluateResponse below reads curImg["BBs"] even when isTest is off,
        # so the per-image record is always built
        curImg = {
            "BBs": defaultdict(dict),
            "pretrained": "MLT",
            "procTime": deltaTime,
            "OCR": "CRNN",
        }
        for i in range(len(polys)):
            if(saveResult):
                cv2.rectangle(imageCpy, (int(polys[i][0][0]), int(polys[i][0][1])), (int(polys[i][1][0]), int(polys[i][2][1])), (255,0,0), 2)
            tnew = time.time()
            # incorrect, correct = self.ocrObj.getString(image, polys[i])
            # distTime = time.time() - tnew
            # print(incorrect, correct) 
            # tnew = time.time()
            incorrect, correct = self.ocrObj.getStringnGram(image, polys[i])
            nTime = time.time() - tnew
            if(correct is not None and saveResult):
                cv2.putText(imageCpy, correct, (int(polys[i][0][0]), int(polys[i][0][1] - 10)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,255),2) 
            words.append((incorrect,correct))
            if(includeTesseract):
                tnew = time.time()
                incTess, corrTess = tesseractOCR.getStringnGram(image, polys[i])
                tessTime = time.time() - tnew
            curImg["BBs"][i] = {
                "BB": polys[i].tolist(),
                "strings": incorrect,
                "stringsCorrect": correct,
                "ocrTime": nTime,
            }
            if includeTesseract:
                curImg["BBs"][i].update({
                    "stringsTess": incTess,
                    "stringsCorrectTess": corrTess,
                    "ocrTimeTess": tessTime,
                })
        if(isTest):
            name, folder = getNameAndFolder(image_path)
            self.jsonFile[folder][name] = curImg
            with open("./CRAFT-pytorch-master/stats.json", "w") as write_file:
                json.dump(self.jsonFile, write_file, sort_keys=True, indent=4)
        if(saveResult):
            name, folder = getNameAndFolder(image_path)
            imageCpy = cv2.cvtColor(imageCpy, cv2.COLOR_BGR2RGB)
            if(not os.path.exists("./result/edited/"+folder)):
                os.makedirs("./result/edited/"+folder)
            cv2.imwrite("./result/edited/"+folder+"/"+name, imageCpy)# + image_path
        # return polys
        corr = ""
        incorr = ""
        for w in words:
            if(w[1] != None):
                corr.join(w[1] + "  ")
            if(w[0] != None):
                incorr.join(w[0] + "  ")
        return self.evaluateResponse(curImg["BBs"],image)

    def getQuadrant(self, bb, image):
        shape = image.shape
        newRect = [bb[0][0], bb[0][1], bb[1][0], bb[2][1]]
        # cv2.rectangle(image, (int(newRect[0]), int(newRect[1])), (int(newRect[2]), int(newRect[3])), (255,0,0), 2)
        xPt = newRect[0] + (newRect[2]-newRect[0])/2
        yPt = newRect[1] + (newRect[3]-newRect[1])/2
        xQuad = 0
        yQuad = 0
        if(0 <= xPt < shape[1]/3):
            xQuad = 1
        elif(shape[1]/3 <= xPt < shape[1]*2/3):
            xQuad = 2
        else:
            xQuad = 3
        if(0 <= yPt < shape[0]/3):
            yQuad = 0
        elif(shape[0]/3 <= yPt < shape[0]*2/3):
            yQuad = 1
        else:
            yQuad = 2
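        # cells are numbered 1-9 reading left-to-right, top-to-bottom:
        # xQuad in {1,2,3} picks the horizontal third, yQuad*3 adds 0/3/6 for
        # the vertical third, matching the gridWords keys in evaluateResponse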
        # cv2.putText(image, str(xQuad + yQuad*3), (int(newRect[0]), int(newRect[1]-10)), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36,255,12), 2)
        
        return xQuad + yQuad*3

    def evaluateResponse(self,bbValues,image):
        gridWords = {
            1 : [],
            2 : [],
            3 : [],
            4 : [],
            5 : [],
            6 : [],
            7 : [],
            8 : [],
            9 : []
        }
        words = 0
        threeWords = [("", 0), ("", 0), ("", 0)]  # (word, box area); empty slots pad the output
        for i in bbValues:
            if bbValues[i]["stringsCorrect"] is not None:
                words += 1
                gridWords[self.getQuadrant(bbValues[i]["BB"], image)].append(bbValues[i]["stringsCorrect"])
                # keep the three words with the largest bounding-box area
                threeWords.append((bbValues[i]["stringsCorrect"], self.getArea(bbValues[i]["BB"])))
                threeWords.sort(key=lambda w: w[1], reverse=True)
                threeWords = threeWords[:3]
        dictionary = {
            "grid":{
                "Top Left" : gridWords[1],
                "Top" : gridWords[2],
                "Top Right" : gridWords[3],
                "Center Left" : gridWords[4],
                "Center" : gridWords[5],
                "Center Right" : gridWords[6],
                "Bottom Left" : gridWords[7],
                "Bottom" : gridWords[8],
                "Bottom Right" : gridWords[9]
            },
            "threeWords":[threeWords[0][0],threeWords[1][0],threeWords[2][0]],
            "newWords":words
        }
        # cv2.imshow("gigi",image)
        # cv2.waitKey(0)
        
        return dictionary

    def getArea(self, bb):
        newRect = [bb[0][0], bb[0][1], bb[1][0], bb[2][1]]
        return (newRect[2]-newRect[0])*(newRect[3]-newRect[1])
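
A hedged usage sketch for CraftNet; trained_model, isCuda, the threshold values and the saveResult/isTest flags are module-level globals assumed by the class, and DummyOCR is a hypothetical stand-in for the OCR helper:

# craft_net = CraftNet(ocrObj=DummyOCR())
# response = craft_net.evaluateBB('images/menu.jpg')
# print(response['newWords'], response['threeWords'])
# print(response['grid']['Top Left'])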
Example #27
if __name__ == '__main__':
    # synthtextloader = Synth80k('/home/jiachx/publicdatasets/SynthText/SynthText', target_size=768, viz=True, debug=True)
    # train_loader = torch.utils.data.DataLoader(
    #     synthtextloader,
    #     batch_size=1,
    #     shuffle=False,
    #     num_workers=0,
    #     drop_last=True,
    #     pin_memory=True)
    # train_batch = iter(train_loader)
    # image_origin, target_gaussian_heatmap, target_gaussian_affinity_heatmap, mask = next(train_batch)
    from craft import CRAFT
    from torchutil import copyStateDict

    net = CRAFT(freeze=True)
    net.load_state_dict(
        copyStateDict(torch.load('/data/CRAFT-pytorch/1-7.pth')))
    net = net.cuda()
    net = torch.nn.DataParallel(net)
    net.eval()
    dataloader = ICDAR2015(net,
                           '/data/CRAFT-pytorch/icdar2015',
                           target_size=768,
                           viz=True)
    train_loader = torch.utils.data.DataLoader(dataloader,
                                               batch_size=1,
                                               shuffle=False,
                                               num_workers=0,
                                               drop_last=True,
                                               pin_memory=True)
    total = 0
    total_sum = 0
Example #28
class Ocr:
    def __init__(self):
        super().__init__()
        manager = Manager()
        self.send = manager.list()
        self.date = manager.list()
        self.quote = manager.list()
        self.number = manager.list()
        self.header = manager.list()
        self.sign = manager.list()
        self.device = torch.device('cpu')
        state_dict = torch.load(
            '/home/dung/Project/Python/ocr/craft_mlt_25k.pth')
        if list(state_dict.keys())[0].startswith("module"):
            start_idx = 1
        else:
            start_idx = 0
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = ".".join(k.split(".")[start_idx:])
            new_state_dict[name] = v

        self.craft = CRAFT()
        self.craft.load_state_dict(new_state_dict)
        self.craft.to(self.device)
        self.craft.eval()
        self.craft.share_memory()
        self.config = Cfg.load_config_from_name('vgg_transformer')
        self.config['weights'] = 'https://drive.google.com/uc?id=13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA'
        self.config['device'] = 'cpu'
        self.config['predictor']['beamsearch'] = False
        self.weights = '/home/dung/Documents/transformerocr.pth'

        # self.model, self.vocab = build_model(self.config)

    def predict(self, model, vocab, seq, key, idx, img):

        img = process_input(img, self.config['dataset']['image_height'],
                            self.config['dataset']['image_min_width'],
                            self.config['dataset']['image_max_width'])
        img = img.to(self.config['device'])
        with torch.no_grad():
            src = model.cnn(img)
            memory = model.transformer.forward_encoder(src)
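            # greedy decoding: every sequence starts with the <go> token (id 1)
            # and generation stops once all rows contain <eos> (id 2) or after
            # 128 steps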
            translated_sentence = [[1] * len(img)]
            max_length = 0
            while max_length <= 128 and not all(
                    np.any(np.asarray(translated_sentence).T == 2, axis=1)):
                tgt_inp = torch.LongTensor(translated_sentence).to(self.device)
                output = model.transformer.forward_decoder(tgt_inp, memory)
                output = output.to('cpu')
                values, indices = torch.topk(output, 5)
                indices = indices[:, -1, 0]
                indices = indices.tolist()
                translated_sentence.append(indices)
                max_length += 1
                del output
            translated_sentence = np.asarray(translated_sentence).T
        s = translated_sentence[0].tolist()
        s = vocab.decode(s)
        seq[idx] = s
        # print(time.time() - time1)

    def process(self, craft, seq, key, sub_img):
        img_resized, target_ratio, size_heatmap = resize_aspect_ratio(
            sub_img, 2560, interpolation=cv2.INTER_LINEAR, mag_ratio=1.)
        ratio_h = ratio_w = 1 / target_ratio

        x = normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = x.unsqueeze(0)  # [c, h, w] to [b, c, h, w]
        x = x.to(self.device)
        y, feature = craft(x)
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()
        boxes, polys = getDetBoxes(score_text,
                                   score_link,
                                   text_threshold=0.7,
                                   link_threshold=0.4,
                                   low_text=0.4,
                                   poly=False)
        boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None:
                polys[k] = boxes[k]
        result = []
        for i, box in enumerate(polys):
            poly = np.array(box).astype(np.int32).reshape((-1))
            result.append(poly)
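        # merge the raw CRAFT detections into line-level boxes: axis-aligned
        # ones land in horizontal_list, rotated/curved ones in free_list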
        horizontal_list, free_list = group_text_box(result,
                                                    slope_ths=0.8,
                                                    ycenter_ths=0.5,
                                                    height_ths=1,
                                                    width_ths=1,
                                                    add_margin=0.1)
        # horizontal_list = [i for i in horizontal_list if i[0] > 0 and i[1] > 0]
        # drop boxes whose larger side is at most min_size pixels
        min_size = 20
        if min_size:
            horizontal_list = [
                i for i in horizontal_list
                if max(i[1] - i[0], i[3] - i[2]) > min_size
            ]
            free_list = [
                i for i in free_list
                if max(diff([c[0] for c in i]),
                       diff([c[1] for c in i])) > min_size
            ]
        seq[:] = [None] * len(horizontal_list)
        # the recognizer is rebuilt inside this worker process; only the CRAFT
        # detector is shared across processes via share_memory()
        model, vocab = build_model(self.config)
        model.load_state_dict(
            torch.load(self.weights, map_location=torch.device('cpu')))

        for i, ele in enumerate(horizontal_list):
            ele = [0 if i < 0 else i for i in ele]
            img = sub_img[ele[2]:ele[3], ele[0]:ele[1], :]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(img.astype(np.uint8))
            # start() followed immediately by join() recognizes the crops one
            # at a time, so the Thread adds no real concurrency here
            p = threading.Thread(target=self.predict,
                                 args=(model, vocab, seq, key, i, img))
            p.start()
            p.join()
        # print(time.time() - time1)

    def forward(self, img, rs):
        # time1 = time.time()
        # route each region's crop to its own shared result list; each worker
        # process is joined immediately, so the regions run sequentially
        targets = {
            'send': self.send,
            'date': self.date,
            'quote': self.quote,
            'number': self.number,
            'header': self.header,
            'sign': self.sign,
        }
        for key, v in rs.items():
            x0, y0, x1, y1 = v
            p = mp.Process(target=self.process,
                           args=(self.craft, targets[key], key,
                                 img[y0:y1, x0:x1, :]))
            p.start()
            p.join()
        return self.send[:], self.date[:], self.quote[:], self.number[:], self.header[:], self.sign[:]
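
A hedged usage sketch for Ocr; rs maps each field name to an (x0, y0, x1, y1) crop of the page, and the coordinates and file name are illustrative:

# ocr = Ocr()
# img = cv2.imread('letter.jpg')
# rs = {'date': (30, 40, 400, 90), 'sign': (50, 900, 420, 1000)}
# send, date, quote, number, header, sign = ocr.forward(img, rs)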
Example #29
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    if args.show_time:
        print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, polys, ret_score_text


if __name__ == '__main__':
    # load net
    net = CRAFT()  # initialize

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
    else:
        net.load_state_dict(
            copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    # LinkRefiner
    refine_net = None
    if args.refine:
        from refinenet import RefineNet
Example #30
class TextDetector:
    def __init__(self):
        #Parameters
        self.canvas_size = 1280
        self.mag_ratio = 1.5
        self.text_threshold = 0.7
        self.low_text = 0.4
        self.link_threshold = 0.4
        self.refine = False
        self.refiner_model = ''
        self.poly = False
        self.cuda = True

        self.net = CRAFT()
        if self.cuda:
            self.net.load_state_dict(copyStateDict(torch.load('CRAFT/weights/craft_mlt_25k.pth')))
        else:
            self.net.load_state_dict(copyStateDict(torch.load('CRAFT/weights/craft_mlt_25k.pth', map_location='cpu')))

        if self.cuda:
            self.net = self.net.cuda()
            self.net = torch.nn.DataParallel(self.net)
        self.net.eval()

        # LinkRefiner
        self.refine_net = None
        if self.refine:
            from refinenet import RefineNet
            self.refine_net = RefineNet()
            if self.cuda:
                self.refine_net.load_state_dict(copyStateDict(torch.load(self.refiner_model)))
                self.refine_net = self.refine_net.cuda()
                self.refine_net = torch.nn.DataParallel(self.refine_net)
            else:
                self.refine_net.load_state_dict(copyStateDict(torch.load(self.refiner_model, map_location='cpu')))
            self.refine_net.eval()
            self.poly = True

    def detect(self, image):
        # resize
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, self.canvas_size,
                                                                              interpolation=cv2.INTER_LINEAR,
                                                                              mag_ratio=self.mag_ratio)
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]

        if self.cuda:
            x = x.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = self.net(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # refine link
        if self.refine_net is not None:
            with torch.no_grad():
                y_refiner = self.refine_net(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().data.numpy()


        # Post-processing
        boxes, _ = craft_utils.getDetBoxes(score_text, score_link, self.text_threshold, self.link_threshold,
                                               self.low_text, self.poly)
        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        toRet = []
        for box in boxes:
            toRet.append(box2xyxy(box, image.shape[0: 2]))

        return toRet
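
A hedged usage sketch for TextDetector, assuming box2xyxy (defined elsewhere in the original module) returns (x0, y0, x1, y1) tuples clipped to the image; the file name is illustrative:

# detector = TextDetector()
# image = imgproc.loadImage('samples/receipt.jpg')
# for x0, y0, x1, y1 in detector.detect(image):
#     cv2.rectangle(image, (int(x0), int(y0)), (int(x1), int(y1)), (0, 255, 0), 2)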