def test(modelpara):
    """Run CRAFT detection on every image in ``image_list`` and save the results.

    Builds a ``CRAFT`` network, loads the checkpoint at ``modelpara`` and, for
    each image path, calls ``test_net`` followed by ``file_utils.saveResult``.
    Reads the module-level names ``args``, ``image_list`` and ``result_folder``.

    Args:
        modelpara: Checkpoint path handed to ``torch.load``.
    """
    detector = CRAFT()

    print('Loading weights from checkpoint {}'.format(modelpara))
    # Without CUDA the stored tensors are remapped onto the CPU while loading.
    if not args.cuda:
        weights = torch.load(modelpara, map_location='cpu')
    else:
        weights = torch.load(modelpara)
    detector.load_state_dict(copyStateDict(weights))

    if args.cuda:
        detector = torch.nn.DataParallel(detector.cuda())
        cudnn.benchmark = False

    detector.eval()
    started = time.time()

    for position, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(position + 1, len(image_list), image_path), end='\r')
        image = imgproc.loadImage(image_path)

        bboxes, polys, score_text = test_net(
            detector, image, args.text_threshold, args.link_threshold,
            args.low_text, args.cuda, args.poly)

        # The mask path is still computed, but writing the score map is disabled.
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        # cv2.imwrite(mask_file, score_text)

        # The image is passed with its channel axis reversed.
        file_utils.saveResult(image_path, image[:, :, ::-1], polys, dirname=result_folder)

    print("elapsed time : {}s".format(time.time() - started))
def main(trained_model='weights/craft_mlt_25k.pth',
         text_threshold=0.7,
         low_text=0.4,
         link_threshold=0.4,
         cuda=True,
         canvas_size=1280,
         mag_ratio=1.5,
         poly=False,
         show_time=False,
         test_folder='/data/',
         refine=True,
         refiner_model='weights/craft_refiner_CTW1500.pth'):
    """Detect text in a single image with CRAFT and an optional link refiner.

    Loads the detector (and, when ``refine`` is true, the ``RefineNet``), runs
    ``test_net`` on the image at ``image_path``, writes the score map to
    ``<result_folder>/res_<name>_mask.jpg`` and calls ``file_utils.saveResult``.

    NOTE(review): ``image_path`` and ``result_folder`` are not parameters or
    locals here; they must exist at module level -- confirm. ``canvas_size``,
    ``mag_ratio``, ``show_time`` and ``test_folder`` are accepted but not read
    in this function.

    Args:
        trained_model: Detector checkpoint path.
        text_threshold: Forwarded to ``test_net``.
        low_text: Forwarded to ``test_net``.
        link_threshold: Forwarded to ``test_net``.
        cuda: When true, load without remapping and move the networks to CUDA.
        poly: Forwarded to ``test_net``; forced to ``True`` when ``refine`` is set.
        refine: Whether to load and use the link refiner.
        refiner_model: Refiner checkpoint path.
    """
    detector = CRAFT()

    print('Loading weights from checkpoint (' + trained_model + ')')
    if cuda:
        detector_weights = torch.load(trained_model)
    else:
        detector_weights = torch.load(trained_model, map_location='cpu')
    detector.load_state_dict(copyStateDict(detector_weights))

    if cuda:
        detector = torch.nn.DataParallel(detector.cuda())
        cudnn.benchmark = False

    detector.eval()

    # Optional link refiner.
    refine_net = None
    if refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + refiner_model + ')')
        if cuda:
            refiner_weights = torch.load(refiner_model)
        else:
            refiner_weights = torch.load(refiner_model, map_location='cpu')
        refine_net.load_state_dict(copyStateDict(refiner_weights))
        if cuda:
            refine_net = torch.nn.DataParallel(refine_net.cuda())

        refine_net.eval()
        # Using the refiner always switches polygon output on.
        poly = True

    started = time.time()

    image = imgproc.loadImage(image_path)
    bboxes, polys, score_text = test_net(
        detector, image, text_threshold, link_threshold, low_text, cuda, poly,
        refine_net)

    # Save the score map next to the detection result.
    filename, file_ext = os.path.splitext(os.path.basename(image_path))
    mask_file = result_folder + "/res_" + filename + '_mask.jpg'
    cv2.imwrite(mask_file, score_text)

    final_img = file_utils.saveResult(image_path, image[:, :, ::-1], polys, dirname=result_folder)

    print("elapsed time : {}s".format(time.time() - started))
def LoadDetectionModel(args):
    """Build a CRAFT detector, load its checkpoint and return it in eval mode.

    Args:
        args: Object exposing ``trained_model`` (checkpoint path) and ``cuda``
            (whether to run on the GPU).

    Returns:
        The network in evaluation mode. With ``args.cuda`` it is moved to CUDA
        and wrapped in ``torch.nn.DataParallel``.
    """
    net = CRAFT()

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    # A checkpoint saved from a GPU cannot be deserialised on a CPU-only host
    # unless its tensors are remapped, so remap whenever CUDA is not requested.
    # map_location=None is torch.load's default, so the CUDA path is unchanged.
    location = None if args.cuda else 'cpu'
    net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location=location)))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()
    return net
def main():
    """Run CRAFT (plus the optional link refiner) over ``image_list``.

    All settings come from the module-level ``args``; results are written into
    the module-level ``result_folder``. For every image the score map is saved
    as ``res_<name>_mask.jpg`` and ``file_utils.saveResult`` is called.
    """
    detector = CRAFT()

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        detector_weights = torch.load(args.trained_model)
    else:
        detector_weights = torch.load(args.trained_model, map_location='cpu')
    detector.load_state_dict(copyStateDict(detector_weights))

    if args.cuda:
        detector = torch.nn.DataParallel(detector.cuda())
        cudnn.benchmark = False

    detector.eval()

    # Optional link refiner.
    refine_net = None
    if args.refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
        if args.cuda:
            refiner_weights = torch.load(args.refiner_model)
        else:
            refiner_weights = torch.load(args.refiner_model, map_location='cpu')
        refine_net.load_state_dict(copyStateDict(refiner_weights))
        if args.cuda:
            refine_net = torch.nn.DataParallel(refine_net.cuda())

        refine_net.eval()
        # Using the refiner always switches polygon output on.
        args.poly = True

    t = time.time()
    print(image_list)

    for position, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(position + 1, len(image_list), image_path), end='\r')
        image = imgproc.loadImage(image_path)

        bboxes, polys, score_text = test_net(
            detector, image, args.text_threshold, args.link_threshold,
            args.low_text, args.cuda, args.poly, refine_net)

        # Save the score map.
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        cv2.imwrite(mask_file, score_text)

        file_utils.saveResult(image_path, image[:, :, ::-1], polys, dirname=result_folder)

    # print("elapsed time : {}s".format(time.time() - t))
def get_detector(trained_model, device='cpu'):
    """Return a CRAFT detector loaded from ``trained_model`` in eval mode.

    Args:
        trained_model: Checkpoint path; its tensors are mapped onto ``device``.
        device: ``'cpu'`` keeps the bare network. Any other value wraps it in
            ``torch.nn.DataParallel`` and moves it to that device.

    Returns:
        The network in evaluation mode.
    """
    net = CRAFT()
    # Both device paths load the checkpoint the same way.
    weights = torch.load(trained_model, map_location=device)
    net.load_state_dict(copyStateDict(weights))

    if device != 'cpu':
        net = torch.nn.DataParallel(net).to(device)
        cudnn.benchmark = False

    net.eval()
    return net
def createModel():
    """Create the CUDA CRAFT detector from the weights under ``settings.BASE_DIR``.

    Returns:
        The network, moved to CUDA, wrapped in ``torch.nn.DataParallel`` and
        switched to evaluation mode.
    """
    weightPath = os.path.join(settings.BASE_DIR, 'CRAFT/weights/craft_mlt_25k.pth')
    model = CRAFT()

    print('Loading weights from checkpoint (' + weightPath + ')')
    state = torch.load(weightPath)
    model.load_state_dict(copyStateDict(state))

    model = torch.nn.DataParallel(model.cuda())
    cudnn.benchmark = False

    model.eval()
    return model
def load_detection_model():
    """Build the CRAFT detector and link refiner from a fixed argument list.

    The options are parsed from a hard-coded argv (model paths under
    ``./models`` plus ``--refine``), not from the process command line.

    Returns:
        Tuple ``(net, refine_net, args)``: the detector in eval mode, the
        refiner in eval mode (``None`` when ``args.refine`` is false) and the
        parsed namespace.
    """
    parser = argparse.ArgumentParser(description='CRAFT Text Detection')
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ('--trained_model', {'default': 'weights/craft_mlt_25k.pth', 'type': str, 'help': 'pretrained model'}),
        ('--text_threshold', {'default': 0.7, 'type': float, 'help': 'text confidence threshold'}),
        ('--low_text', {'default': 0.4, 'type': float, 'help': 'text low-bound score'}),
        ('--link_threshold', {'default': 0.4, 'type': float, 'help': 'link confidence threshold'}),
        ('--cuda', {'default': True, 'type': str2bool, 'help': 'Use cuda for inference'}),
        ('--canvas_size', {'default': 1280, 'type': int, 'help': 'image size for inference'}),
        ('--mag_ratio', {'default': 1.5, 'type': float, 'help': 'image magnification ratio'}),
        ('--poly', {'default': False, 'action': 'store_true', 'help': 'enable polygon type'}),
        ('--show_time', {'default': False, 'action': 'store_true', 'help': 'show processing time'}),
        ('--test_folder', {'default': '/data/', 'type': str, 'help': 'folder path to input images'}),
        ('--refine', {'default': False, 'action': 'store_true', 'help': 'enable link refiner'}),
        ('--refiner_model', {'default': 'weights/craft_refiner_CTW1500.pth', 'type': str, 'help': 'pretrained refiner model'}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    args = parser.parse_args([
        "--trained_model=./models/craft_mlt_25k.pth",
        "--refine",
        "--refiner_model=./models/craft_refiner_CTW1500.pth",
    ])

    net = CRAFT()

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        detector_weights = torch.load(args.trained_model)
    else:
        detector_weights = torch.load(args.trained_model, map_location='cpu')
    net.load_state_dict(copyStateDict(detector_weights))

    if args.cuda:
        net = torch.nn.DataParallel(net.cuda())
        cudnn.benchmark = False

    net.eval()

    # Optional link refiner.
    refine_net = None
    if args.refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
        if args.cuda:
            refiner_weights = torch.load(args.refiner_model)
        else:
            refiner_weights = torch.load(args.refiner_model, map_location='cpu')
        refine_net.load_state_dict(copyStateDict(refiner_weights))
        if args.cuda:
            refine_net = torch.nn.DataParallel(refine_net.cuda())

        refine_net.eval()
        # args.poly = True

    return net, refine_net, args
def runCraftNet(image_list):
    """Run CPU CRAFT detection and return the image with its boxes as text.

    Loads ``craft_mlt_25k.pth`` onto the CPU, runs ``test_net`` for each entry
    of ``image_list``, then serialises the detected polygons as
    comma-separated integer strings.

    Args:
        image_list: Iterable of image paths; each entry is passed to
            ``imgproc.loadImage``.
            NOTE(review): the original comment called this "the folder
            containing the images", but the code iterates it entry by entry --
            confirm what callers pass.

    Returns:
        ``[img, txt]``: ``img`` is the loaded image with its last axis
        reversed, ``txt`` is a list with one ``"x1,y1,x2,y2,..."`` string per
        polygon.

    NOTE(review): ``image`` and ``polys`` are read after the loop, so only the
    last entry of ``image_list`` contributes to the result and an empty list
    raises ``UnboundLocalError`` -- confirm this is intended.
    """
    # Fixed configuration: CPU inference, no refiner, quadrilateral output.
    args = argparse.Namespace(
        canvas_size=1280,
        cuda=False,
        link_threshold=0.4,
        low_text=0.4,
        mag_ratio=1.5,
        poly=False,
        refine=False,
        refiner_model='weights/craft_refiner_CTW1500.pth',
        show_time=False,
        test_folder='images',
        text_threshold=0.7,
        trained_model='craft_mlt_25k.pth')
    net = CRAFT()  # initialize
    net.load_state_dict(
        copyStateDict(torch.load(args.trained_model, map_location='cpu')))
    net.eval()
    # image_list, _, _ = file_utils.get_files(args.test_folder)
    t = time.time()  # only used by the disabled timing print below
    # result_folder = './result/'
    # load data
    refine_net = None
    for k, image_path in enumerate(image_list):
        image = imgproc.loadImage(image_path)
        bboxes, polys, score_text = test_net(net, image, args.text_threshold,
                                             args.link_threshold,
                                             args.low_text, args.cuda,
                                             args.poly, refine_net)
    # print("elapsed time : {}s ".format(time.time() - t))
    # Copy of the image with its channel axis reversed.
    img = np.array(image[:, :, ::-1])
    txt = []
    for i, box in enumerate(polys):
        # Flatten each polygon to a 1-D int32 vector, then join as text.
        poly = np.array(box).astype(np.int32).reshape((-1))
        strResult = ','.join([str(p) for p in poly])
        txt.append(strResult)
    return [img, txt]
def test(image, epoch, index, cvt=False):
    """Predict the CRAFT region and affinity score maps for one image.

    Loads ``/root/data/test_param/<epoch>_<index>.pth``, runs a single forward
    pass on the GPU at 768x768 and returns the two output channels.

    Args:
        image: Image array. It must have three axes after normalisation,
            because it is permuted with ``(2, 0, 1)`` before the forward pass.
        epoch: First component of the checkpoint file name.
        index: Second component of the checkpoint file name.
        cvt: When true, both maps are passed through ``Gray2RGB`` (H x W x C);
            otherwise they stay H x W.

    Returns:
        Tuple ``(pred_region, pred_affinity)`` of numpy arrays.
    """
    print('input image shape {}'.format(image.shape))
    checkpoint = torch.load('/root/data/test_param/{}_{}.pth'.format(
        epoch, index))
    net = CRAFT().cuda()
    net.load_state_dict(copyStateDict(checkpoint['model_state_dict']))
    # Inference must run in evaluation mode; a freshly built module is in
    # training mode, where layers such as batch norm and dropout behave
    # differently.
    net.eval()

    # Image resizing etc. would go here; assume it has been done and carry on.
    image = normalizeMeanVariance(image)
    image = cv2.resize(image, (768, 768), interpolation=cv2.INTER_LINEAR)
    x = torch.from_numpy(image).permute(2, 0, 1)
    # Add the batch axis and force float32; no autograd wrapper is needed
    # because the forward pass runs under no_grad.
    x = x.unsqueeze(0).type(torch.FloatTensor)
    x = x.cuda()
    print(x.size())

    with torch.no_grad():
        y, _ = net(x)

    pred_region = y[0, :, :, 0].cpu().data.numpy()
    pred_affinity = y[0, :, :, 1].cpu().data.numpy()
    print(type(pred_region))
    print(pred_region.shape)

    # cvt == True  -> Region, Affinity score H x W x C
    # cvt == False -> Region, Affinity score H x W
    if cvt:
        pred_region = Gray2RGB(pred_region)
        pred_affinity = Gray2RGB(pred_affinity)
    return pred_region, pred_affinity
def get_detector(trained_model, device='cpu', quantize=True):
    """Return a CRAFT detector loaded from ``trained_model`` in eval mode.

    Args:
        trained_model: Checkpoint path; its tensors are mapped onto ``device``.
        device: ``'cpu'`` keeps the bare network (optionally quantized). Any
            other value wraps it in ``torch.nn.DataParallel`` and moves it to
            that device.
        quantize: On CPU only, try dynamic int8 quantization in place. This is
            best effort: a failure is reported as a warning and the
            unquantized network is used.

    Returns:
        The network in evaluation mode.
    """
    import warnings

    net = CRAFT()
    # Both device paths load the checkpoint the same way.
    net.load_state_dict(
        copyStateDict(torch.load(trained_model, map_location=device)))

    if device == 'cpu':
        if quantize:
            try:
                torch.quantization.quantize_dynamic(net,
                                                    dtype=torch.qint8,
                                                    inplace=True)
            except Exception as err:
                # Stay best effort, but do not hide the failure and do not
                # swallow KeyboardInterrupt/SystemExit as a bare except would.
                warnings.warn(
                    'Dynamic quantization failed; using the unquantized '
                    'detector: {}'.format(err))
    else:
        net = torch.nn.DataParallel(net).to(device)
        cudnn.benchmark = False

    net.eval()
    return net
def main(pth_file_path):
    """Convert a CRAFT ``.pth`` checkpoint into a TorchScript ``.pt`` file.

    The scripted module is saved next to the input, with the extension
    replaced by ``.pt``.

    Args:
        pth_file_path: Path of the checkpoint to convert.
    """
    # Use the GPU when one is present. On a CUDA host this matches the old
    # hard-coded ``cuda = True``; without CUDA the conversion now runs on the
    # CPU instead of failing.
    cuda = torch.cuda.is_available()

    net = CRAFT()

    print('Loading weights from checkpoint (' + pth_file_path + ')')
    # map_location=None is torch.load's default, so the CUDA path is unchanged.
    location = None if cuda else 'cpu'
    net.load_state_dict(copyStateDict(torch.load(pth_file_path, map_location=location)))

    if cuda:
        net = net.cuda()
        cudnn.benchmark = False

    net.eval()

    script_module = torch.jit.script(net)
    file_path_without_ext = os.path.splitext(pth_file_path)[0]
    output_file_path = file_path_without_ext + ".pt"
    script_module.save(output_file_path)
    print("TorchScript model created:", output_file_path)
if polys[k] is None: polys[k] = boxes[k] t1 = time.time() - t1 # render results (optional) render_img = score_text.copy() render_img = np.hstack((render_img, score_link)) ret_score_text = cvt2HeatmapImg(render_img) if show_time: print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1)) return boxes, polys, ret_score_text net = CRAFT() net.load_state_dict(copyStateDict(torch.load(trained_model_path, map_location='cpu'))) net.eval() # image_path = './doc/2.jpg' # image = loadImage(image_path) # bboxes, polys, score_text = test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net) # # poly_indexes = {} # central_poly_indexes = [] # for i in range(len(polys)): # poly_indexes[i] = polys[i] # x_central = (polys[i][0][0] + polys[i][1][0] + polys[i][2][0] + polys[i][3][0]) / 4 # y_central = (polys[i][0][1] + polys[i][1][1] + polys[i][2][1] + polys[i][3][1]) / 4 # central_poly_indexes.append({i: [int(x_central), int(y_central)]}) import copy
def applyCraft(image_file):
    """Detect text with CRAFT, cluster the boxes and OCR each cluster.

    Steps: run ``test_net`` on the image, take the centre of every detected
    box, group the centres with ``GDBSCAN``, build one quadrilateral per
    cluster from its extreme boxes, then crop, preprocess and OCR that region
    with ``pytesseract``.

    Uses the names ``n_pred`` and ``w_card``, which are defined outside this
    function.

    Args:
        image_file: Image path passed to ``imgproc.loadImage``.

    Returns:
        List with one ``pytesseract.image_to_string`` result per cluster.
    """
    # Initialize CRAFT parameters
    text_threshold = 0.7
    low_text = 0.4
    link_threshold = 0.4
    cuda = False
    canvas_size = 1280
    mag_ratio = 1.5
    # if text image present curve --> poly=true
    poly = False
    refine = False
    show_time = False
    refine_net = None
    trained_model_path = './app/CRAFT/craft_mlt_25k.pth'
    net = CRAFT()
    net.load_state_dict(
        copyStateDict(torch.load(trained_model_path, map_location='cpu')))
    net.eval()
    image = imgproc.loadImage(image_file)
    # NOTE(review): these four assignments repeat the values set above.
    poly = False
    refine = False
    show_time = False
    refine_net = None
    bboxes, polys, score_text = test_net(net, canvas_size, mag_ratio, image,
                                         text_threshold, link_threshold,
                                         low_text, cuda, poly, refine_net)
    # Compute coordinate of central point in each bounding box returned by CRAFT
    # Purpose: easier for us to make cluster in G-DBScan step
    poly_indexes = {}  # box index -> polygon
    central_poly_indexes = []  # one {box index: [x, y]} dict per box
    for i in range(len(polys)):
        poly_indexes[i] = polys[i]
        # The centre is the mean of the first four vertices.
        x_central = (polys[i][0][0] + polys[i][1][0] + polys[i][2][0] +
                     polys[i][3][0]) / 4
        y_central = (polys[i][0][1] + polys[i][1][1] + polys[i][2][1] +
                     polys[i][3][1]) / 4
        central_poly_indexes.append({i: [int(x_central), int(y_central)]})
    # for i in central_poly_indexes:
    #     print(i)
    # For each of these coordinates convert them to new Point instances
    X = []
    for idx, x in enumerate(central_poly_indexes):
        # The Point id is the box index, used below to look the polygon up.
        point = Point(x[idx][0], x[idx][1], idx)
        X.append(point)
    # Cluster these central points
    clustered = GDBSCAN(Points(X), n_pred, 1, w_card)
    cluster_values = []
    for cluster in clustered:
        # Order members by (x, y): first is the minimum, last the maximum.
        sort_cluster = sorted(cluster, key=lambda elem: (elem.x, elem.y))
        max_point_id = sort_cluster[len(sort_cluster) - 1].id
        min_point_id = sort_cluster[0].id
        # Sort each extreme box's vertices by (x, y) as well.
        max_rectangle = sorted(poly_indexes[max_point_id],
                               key=lambda elem: (elem[0], elem[1]))
        min_rectangle = sorted(poly_indexes[min_point_id],
                               key=lambda elem: (elem[0], elem[1]))
        # Last two vertices of the max box form the right edge, first two of
        # the min box the left edge.
        right_above_max_vertex = max_rectangle[len(max_rectangle) - 1]
        right_below_max_vertex = max_rectangle[len(max_rectangle) - 2]
        left_above_min_vertex = min_rectangle[0]
        left_below_min_vertex = min_rectangle[1]
        # Within each pair, the vertex with the smaller y is the "above" one.
        if (int(min_rectangle[0][1]) > int(min_rectangle[1][1])):
            left_above_min_vertex = min_rectangle[1]
            left_below_min_vertex = min_rectangle[0]
        if (int(max_rectangle[len(max_rectangle) - 1][1]) < int(
                max_rectangle[len(max_rectangle) - 2][1])):
            right_above_max_vertex = max_rectangle[len(max_rectangle) - 2]
            right_below_max_vertex = max_rectangle[len(max_rectangle) - 1]
        cluster_values.append([
            left_above_min_vertex, left_below_min_vertex,
            right_above_max_vertex, right_below_max_vertex
        ])
    # The image is loaded a second time here rather than reused.
    image = imgproc.loadImage(image_file)
    img = np.array(image[:, :, ::-1])
    img = img.astype('uint8')
    ocr_res = []
    for i, box in enumerate(cluster_values):
        poly = np.array(box).astype(np.int32).reshape((-1))
        poly = poly.reshape(-1, 2)
        # Crop the axis-aligned bounding rectangle of the cluster.
        rect = cv2.boundingRect(poly)
        x, y, w, h = rect
        cropped = img[y:y + h, x:x + w].copy()
        # Preprocess cropped segment
        cropped = cv2.resize(cropped,
                             None,
                             fx=5,
                             fy=5,
                             interpolation=cv2.INTER_LINEAR)
        cropped = cv2.cvtColor(cropped, cv2.COLOR_BGR2GRAY)
        cropped = cv2.GaussianBlur(cropped, (3, 3), 0)
        cropped = cv2.bilateralFilter(cropped, 5, 25, 25)
        cropped = cv2.dilate(cropped, None, iterations=1)
        # Otsu picks the binarisation threshold automatically.
        cropped = cv2.threshold(cropped, 0, 255,
                                cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
        #cropped = cv2.threshold(cropped, 90, 255, cv2.THRESH_BINARY)[1]
        #cropped = cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)
        ocr_res.append(pytesseract.image_to_string(cropped, lang='eng'))
    return ocr_res
def ground_truth(args):
    """Generate and save CRAFT ground-truth maps for every listed image.

    Loads the pretrained detector, reads each image's word-level annotation
    file, calls ``generate_gt`` and stores the region, link and confidence
    maps as ``region.pt``, ``link.pt`` and ``conf.pt`` in a per-image label
    directory.

    Args:
        args: Object exposing ``trained_model`` and ``cuda``; it is also
            forwarded unchanged to ``generate_gt``.
    """
    # Pretrained network.
    net = CRAFT()

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        weights = torch.load(args.trained_model)
    else:
        weights = torch.load(args.trained_model, map_location='cpu')
    net.load_state_dict(test.copyStateDict(weights))

    if args.cuda:
        net = torch.nn.DataParallel(net.cuda())
        cudnn.benchmark = False

    net.eval()

    filelist, _, _ = file_utils.list_files('/home/ubuntu/Kyumin/Autotation/data/IC13/images')

    for img_name in filelist:
        # Derive the annotation path and the output directory from the image path.
        subset = 'train' if 'train' in img_name else 'test'
        label_name = img_name.replace(
            'images/' + subset + '/', 'labels/' + subset + '/gt_').replace('jpg', 'txt')
        label_dir = img_name.replace('Autotation', 'craft').replace('images', 'labels').replace('.jpg', '/')
        os.makedirs(label_dir, exist_ok=True)

        image = imgproc.loadImage(img_name)

        gt_boxes, gt_words = [], []
        with open(label_name, 'r', encoding='utf-8-sig') as f:
            for line in f:
                if 'IC13' in img_name:
                    # IC13: the word is quoted, preceded by two box corners.
                    gt_box, gt_word, _ = line.split('"')
                    if 'train' in img_name:
                        corners = [int(a) for a in gt_box.strip().split(' ')]
                    else:
                        corners = [int(a.strip()) for a in gt_box.split(',') if a.strip().isdigit()]
                    x1, y1, x2, y2 = corners
                    gt_boxes.append(np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]]))
                    gt_words.append(gt_word)
                elif 'IC15' in img_name:
                    # IC15: eight comma-separated coordinates, then the word.
                    gt_data = line.strip().split(',')
                    # A word containing commas was split apart; join it back.
                    gt_word = ','.join(gt_data[8:]) if len(gt_data) > 9 else gt_data[-1]
                    quad = np.reshape(np.array([int(a) for a in gt_data[:8]]), (4, 2))
                    gt_boxes.append(quad)
                    gt_words.append(gt_word)

        score_region, score_link, conf_map = generate_gt(net, image, gt_boxes, gt_words, args)

        for tensor, stem in ((score_region, 'region.pt'),
                             (score_link, 'link.pt'),
                             (conf_map, 'conf.pt')):
            torch.save(tensor, label_dir + stem)
torch.save(score_region, label_dir + 'region.pt') torch.save(score_link, label_dir + 'link.pt') torch.save(conf_map, label_dir + 'conf.pt') if __name__ == '__main__': import ocr score_region = torch.load('/home/ubuntu/Kyumin/craft/data/IC13/labels/train/100/region.pt') score_link = torch.load('/home/ubuntu/Kyumin/craft/data/IC13/labels/train/100/link.pt') conf_map = torch.load('/home/ubuntu/Kyumin/craft/data/IC13/labels/train/100/conf.pt') image = imgproc.loadImage('/home/ubuntu/Kyumin/Autotation/data/IC13/images/train/100.jpg') print(score_region.shape, score_link.shape, conf_map.shape) # cv2.imshow('original', image) cv2.imshow('region', imgproc.cvt2HeatmapImg(score_region)) cv2.imshow('link', score_link) cv2.imshow('conf', conf_map) net = CRAFT().cuda() net.load_state_dict(test.copyStateDict(torch.load('weights/craft_mlt_25k.pth'))) net.eval() _, _, ref_text, ref_link, _ = test.test_net(net, image, ocr.argument_parser().parse_args()) cv2.imshow('ref text', imgproc.cvt2HeatmapImg(ref_text)) cv2.imshow('ref link', ref_link) cv2.waitKey(0) cv2.destroyAllWindows()
#dataloader = syndata(imgname, charbox, imgtxt) dataloader = Synth80k('./data/SynthText', target_size = args.target_size) train_loader = torch.utils.data.DataLoader( dataloader, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True, pin_memory=True) batch_syn = iter(train_loader) # prefetcher = data_prefetcher(dataloader) # input, target1, target2 = prefetcher.next() #print(input.size()) net = CRAFT(freeze=True) net.load_state_dict(copyStateDict(torch.load(args.load_model))) #net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/CRAFT_net_050000.pth'))) #net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/1-7.pth'))) #net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/craft_mlt_25k.pth'))) #net.load_state_dict(copyStateDict(torch.load('vgg16_bn-6c64b313.pth'))) #realdata = realdata(net) # realdata = ICDAR2015(net, '/data/CRAFT-pytorch/icdar2015', target_size = 768) # real_data_loader = torch.utils.data.DataLoader( # realdata, # batch_size=10, # shuffle=True, # num_workers=0, # drop_last=True, # pin_memory=True) net = net.cuda() #net = CRAFT_net
print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1)) return boxes, ret_score_text if __name__ == '__main__': # load net net = CRAFT() # initialize if args.cuda: net = net.cuda() net = torch.nn.DataParallel(net) cudnn.benchmark = False print('Loading weights from checkpoint (' + args.trained_model + ')') net.load_state_dict(torch.load(args.trained_model)) net.eval() t = time.time() # load data for k, image_path in enumerate(image_list): print("Test image {:d}/{:d}: {:s}".format(k + 1, len(image_list), image_path), end='\r') image = imgproc.loadImage(image_path) bboxes, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda)
if __name__ == '__main__':
    # Synthetic data: SynthText samples at a 768 target size, one per batch.
    dataloader = Synth80k(root_data + '/SynthText/SynthText', target_size=768)
    train_loader = torch.utils.data.DataLoader(dataloader,
                                               batch_size=1,
                                               shuffle=True,
                                               num_workers=0,
                                               drop_last=True,
                                               pin_memory=True)
    # Explicit iterator over the loader, kept in ``batch_syn``.
    batch_syn = iter(train_loader)
    print("Loaded Synth data.")

    # Start from the pretrained CRAFT weights and move the network to the GPU.
    net = CRAFT()
    net.load_state_dict(copyStateDict(
        torch.load('pretrain/craft_mlt_25k.pth')))
    net = net.cuda()
    print("Loaded CRAFT net.")

    # Wrapped in DataParallel restricted to device 0.
    net = torch.nn.DataParallel(net, device_ids=[0]).cuda()
    cudnn.benchmark = True
    net.train()

    # The real-data set receives the network itself as its first argument.
    # NOTE(review): presumably it uses the net to derive labels -- confirm
    # against the ICDAR2015 class.
    realdata = ICDAR2015(net, root_data + '/DDI', target_size=768)
    real_data_loader = torch.utils.data.DataLoader(realdata,
                                                   batch_size=1,
                                                   shuffle=True,
                                                   num_workers=0,
                                                   drop_last=True,
                                                   pin_memory=True)
def test(modelpara):
    """Run the multi-class detector on ``image_list`` and write visualisations.

    For each image a directory ``<result_folder>/<name>/`` is created holding:
    one heat map per class (``res_<name>_<class>.jpg``), the input resized to
    the heat-map size (``res_<name>.jpg``) and the resized input with labelled
    polygons drawn on it (``res_<name>_bbox.jpg``).

    Reads the module-level ``args``, ``image_list``, ``result_folder`` and
    ``CLASSES``.

    Args:
        modelpara: Checkpoint path handed to ``torch.load``.
    """
    net = CRAFT()

    print('Loading weights from checkpoint {}'.format(modelpara))
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(modelpara)))
    else:
        net.load_state_dict(
            copyStateDict(torch.load(modelpara, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    t = time.time()

    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k + 1, len(image_list),
                                                  image_path),
              end='\n')
        image = imgproc.loadImage(image_path)
        res = image.copy()

        gh_pred, bboxes_pred, polys_pred, size_heatmap = test_net(
            net, image, args.text_threshold, args.link_threshold,
            args.low_text, args.cuda, args.poly)

        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        result_dir = os.path.join(result_folder, filename)
        os.makedirs(result_dir, exist_ok=True)

        # One heat map per class.
        for gh_img, field in zip(gh_pred, CLASSES):
            img = imgproc.cvt2HeatmapImg(gh_img)
            img_path = os.path.join(result_dir,
                                    'res_{}_{}.jpg'.format(filename, field))
            cv2.imwrite(img_path, img)

        # Input image resized to the heat-map size, channel axis reversed.
        # The file name depends only on ``filename``. The old code also passed
        # the leaked loop variable ``field`` to format(), which raised
        # NameError whenever the loop above had not run.
        img = cv2.resize(image, size_heatmap)[::, ::, ::-1]
        img_path = os.path.join(result_dir, 'res_{}.jpg'.format(filename))
        cv2.imwrite(img_path, img)

        res = cv2.resize(res, size_heatmap)
        for polys, field in zip(polys_pred, CLASSES):
            TEXT_WIDTH = 10 * len(field) + 10
            TEXT_HEIGHT = 15
            polys = np.int32([poly.reshape((-1, 1, 2)) for poly in polys])
            res = cv2.polylines(res, polys, True, (0, 0, 255), 2)

            # Turn every polygon, in place, into a label rectangle anchored at
            # its first vertex. Vertex 0 is rewritten last because the other
            # three are computed from its original position.
            for poly in polys:
                poly[1, 0] = [poly[0, 0, 0] - 10, poly[0, 0, 1]]
                poly[2, 0] = [poly[0, 0, 0] - 10, poly[0, 0, 1] + TEXT_HEIGHT]
                poly[3, 0] = [
                    poly[0, 0, 0] - TEXT_WIDTH, poly[0, 0, 1] + TEXT_HEIGHT
                ]
                poly[0, 0] = [poly[0, 0, 0] - TEXT_WIDTH, poly[0, 0, 1]]
            res = cv2.fillPoly(res, polys, (224, 224, 224))

            # Write the class name inside each label rectangle.
            for poly in polys:
                res = cv2.putText(res,
                                  field,
                                  tuple(poly[3, 0] + [+5, -5]),
                                  cv2.FONT_HERSHEY_SIMPLEX,
                                  0.4, (0, 0, 0),
                                  thickness=1)

        # As above, the name depends only on ``filename``.
        res_file = os.path.join(result_dir, 'res_{}_bbox.jpg'.format(filename))
        cv2.imwrite(res_file, res[::, ::, ::-1])

    print("elapsed time : {}s".format(time.time() - t))
def main(args, logger=None):
    """Run CRAFT over a folder of images, save the outputs and evaluate them.

    Detection results go to ``<rst_path>/est``, score maps to
    ``<rst_path>/mask`` (existing mask files are not overwritten), and
    ``eval_dataset`` is called once all images are processed.

    Args:
        args: Object exposing ``model_path``, ``cuda``, ``img_path``,
            ``rst_path``, ``gt_path``, ``dataset_type`` and the ``test_net``
            settings (thresholds, ``canvas_size``, ``mag_ratio``, ``poly``,
            ``show_time``).
        logger: Accepted but not used in this function.
    """
    detector = CRAFT(pretrained=False)

    print('Loading weights from checkpoint {}'.format(args.model_path))
    if args.cuda:
        weights = torch.load(args.model_path)
    else:
        weights = torch.load(args.model_path, map_location='cpu')
    detector.load_state_dict(copyStateDict(weights))

    if args.cuda:
        detector = torch.nn.DataParallel(detector.cuda())
        cudnn.benchmark = False

    detector.eval()

    started = time.time()

    # Test images come from a folder.
    image_list, _, _ = file_utils.get_files(args.img_path)

    est_folder, mask_folder, eval_folder = (
        os.path.join(args.rst_path, sub) for sub in ('est', 'mask', 'eval'))
    for folder in (est_folder, mask_folder, eval_folder):
        cg.folder_exists(folder, create_=True)

    total = len(image_list)
    for position, image_path in enumerate(image_list, start=1):
        print("Test image {:d}/{:d}: {:s}".format(position, total, image_path))
        image = imgproc.loadImage(image_path)
        # image = cv2.resize(image, dsize=(768, 768), interpolation=cv2.INTER_CUBIC) ##

        net_options = {
            'text_threshold': args.text_threshold,
            'link_threshold': args.link_threshold,
            'low_text': args.low_text,
            'cuda': args.cuda,
            'canvas_size': args.canvas_size,
            'mag_ratio': args.mag_ratio,
            'poly': args.poly,
            'show_time': args.show_time,
        }
        bboxes, polys, score_text = test_net(detector, image, **net_options)

        # Save the score map unless it is already on disk.
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = mask_folder + "/res_" + filename + '_mask.jpg'
        if not cg.file_exists(mask_file):
            cv2.imwrite(mask_file, score_text)

        file_utils.saveResult15(image_path, bboxes, dirname=est_folder, mode='test')

    eval_dataset(est_folder=est_folder,
                 gt_folder=args.gt_path,
                 eval_folder=eval_folder,
                 dataset_type=args.dataset_type)
    print("elapsed time : {}s".format(time.time() - started))
# Training setup: pick the device, build the synthetic-data loader, load the
# synthetic-pretrained CRAFT weights, then build the real-data loader.
use_cuda = torch.cuda.is_available()
device = 'cuda:0' if use_cuda else 'cpu'

print('Load the synthetic data ...')
data_loader = Synth80k('D:/Datasets/SynthText')
train_loader = torch.utils.data.DataLoader(data_loader,
                                           batch_size=1,
                                           shuffle=True,
                                           num_workers=0,
                                           drop_last=True,
                                           pin_memory=True)
# Explicit iterator over the loader, kept in ``batch_syn``.
batch_syn = iter(train_loader)

print('Prepare the net ...')
net = CRAFT()
# NOTE(review): no map_location is passed, so on a CPU-only host this load
# may fail for a checkpoint saved from a GPU, even though ``device`` falls
# back to 'cpu' -- confirm.
net.load_state_dict(copyStateDict(
    torch.load('./weigths/synweights/0.pth')))
net.to(device)

# ``data_parallel`` records whether the net was wrapped for multiple GPUs.
data_parallel = False
if torch.cuda.device_count() > 1:
    net = nn.DataParallel(net)
    data_parallel = True
cudnn.benchmark = False

print('Load the real data')
# The real-data set receives the network itself as its first argument.
# NOTE(review): presumably it uses the net to derive labels -- confirm
# against the ICDAR2013 class.
real_data = ICDAR2013(net, 'D:/Datasets/ICDAR_2013')
real_data_loader = torch.utils.data.DataLoader(real_data,
                                               batch_size=5,
                                               shuffle=True,
                                               num_workers=0,
                                               drop_last=True,
                                               pin_memory=True)
batch_size=8, shuffle=True, num_workers=0, drop_last=True, pin_memory=True) # print("train_loade1", train_loader) #batch_syn = iter(train_loader) # prefetcher = data_prefetcher(dataloader) # input, target1, target2 = prefetcher.next() #print(input.size()) net = CRAFT() #net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/CRAFT_net_050000.pth'))) #net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/1-7.pth'))) #net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/craft_mlt_25k.pth'))) net.load_state_dict( copyStateDict( torch.load( './pretrain/data/CRAFT-pytorch/synweights/Syndata.pth'))) # net.load_state_dict(copyStateDict(torch.load('./pretrain/data/CRAFT-pytorch/vgg16_bn-6c64b313.pth'))) #realdata = realdata(net) # realdata = ICDAR2015(net, '/data/CRAFT-pytorch/icdar2015', target_size = 768) # real_data_loader = torch.utils.data.DataLoader( # realdata, # batch_size=10, # shuffle=True, # num_workers=0, # drop_last=True, # pin_memory=True) net = net.cuda() #net = CRAFT_net # if args.cdua: # print('__Number CUDA Devices:', torch.cuda.device_count())
class TextExtractor():
    """Two-stage OCR over a folder of images.

    Stage 1 runs a CRAFT network to find word boxes, groups the boxes into
    rows, and writes one crop per box into ``result_folder``.  Stage 2 runs
    a text-recognition model over the crops and writes, per image, a
    four-line record (index, file name, recognized text, blank line) to
    ``extract_text_file``.  ``get_item`` reads the text line of one record.
    """

    def __init__(self, image_folder, extract_text_file, split):
        """Load both models and prepare the intermediate crop folder.

        Args:
            image_folder: Folder holding the input images.  It is joined to
                file names by plain string concatenation, so it must end
                with a path separator.
            extract_text_file: Path of the text file written by
                ``extract_text`` and read by ``get_item``.
            split: Name used to build the crop folder
                ``./<split>_intermediate_result/``.
        """
        self.i_folder = image_folder
        self.extract_text_file = extract_text_file
        self.canvas_size = 1280
        self.mag_ratio = 1.5
        self.show_time = False
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.cuda = torch.cuda.is_available()

        # (1st model) model to detect words in images
        self.net = CRAFT()
        if self.cuda:
            self.net.load_state_dict(
                self.copyStateDict(
                    torch.load('CRAFT-pytorch/craft_mlt_25k.pth')))
        else:
            self.net.load_state_dict(
                self.copyStateDict(
                    torch.load('CRAFT-pytorch/craft_mlt_25k.pth',
                               map_location='cpu')))
        if self.cuda:
            self.net = self.net.cuda()
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = False
        self.net.eval()

        self.refine_net = None
        self.text_threshold = 0.7
        self.link_threshold = 0.4
        self.low_text = 0.4
        self.poly = False

        self.result_folder = './' + split + '_' + 'intermediate_result/'
        # exist_ok avoids the separate isdir() check and its race window.
        os.makedirs(self.result_folder, exist_ok=True)

        # Parameters for image to text model (2nd model)
        self.parser = argparse.ArgumentParser()
        # Data processing
        self.parser.add_argument('--batch_max_length',
                                 type=int,
                                 default=25,
                                 help='maximum-label-length')
        self.parser.add_argument('--imgH',
                                 type=int,
                                 default=32,
                                 help='the height of the input image')
        self.parser.add_argument('--imgW',
                                 type=int,
                                 default=100,
                                 help='the width of the input image')
        self.parser.add_argument('--rgb',
                                 default=False,
                                 action='store_true',
                                 help='use rgb input')
        self.parser.add_argument(
            '--character',
            type=str,
            default='0123456789abcdefghijklmnopqrstuvwxyz',
            help='character label')
        self.parser.add_argument('--sensitive',
                                 action='store_true',
                                 help='for sensitive character mode')
        self.parser.add_argument(
            '--PAD',
            action='store_true',
            help='whether to keep ratio then pad for image resize')
        # Model Architecture
        self.parser.add_argument('--Transformation',
                                 type=str,
                                 default='TPS',
                                 help='Transformation stage. None|TPS')
        self.parser.add_argument(
            '--FeatureExtraction',
            type=str,
            default='ResNet',
            help='FeatureExtraction stage. VGG|RCNN|ResNet')
        self.parser.add_argument('--SequenceModeling',
                                 type=str,
                                 default='BiLSTM',
                                 help='SequenceModeling stage. None|BiLSTM')
        self.parser.add_argument('--Prediction',
                                 type=str,
                                 default='Attn',
                                 help='Prediction stage. CTC|Attn')
        self.parser.add_argument('--num_fiducial',
                                 type=int,
                                 default=20,
                                 help='number of fiducial points of TPS-STN')
        self.parser.add_argument(
            '--input_channel',
            type=int,
            default=1,
            help='the number of input channel of Feature extractor')
        self.parser.add_argument(
            '--output_channel',
            type=int,
            default=512,
            help='the number of output channel of Feature extractor')
        self.parser.add_argument('--hidden_size',
                                 type=int,
                                 default=256,
                                 help='the size of the LSTM hidden state')

        # parse_known_args tolerates command-line flags that belong to the
        # surrounding program instead of aborting on them.
        self.opt, _ = self.parser.parse_known_args()

        if 'CTC' in self.opt.Prediction:
            self.converter = CTCLabelConverter(self.opt.character)
        else:
            self.converter = AttnLabelConverter(self.opt.character)
        self.opt.num_class = len(self.converter.character)

        if self.opt.rgb:
            self.opt.input_channel = 3
        self.opt.num_gpu = torch.cuda.device_count()
        self.opt.batch_size = 192
        self.opt.workers = 0

        # image to text model (2nd model)
        self.model = Model(self.opt)
        self.model = torch.nn.DataParallel(self.model).to(self.device)
        self.model.load_state_dict(
            torch.load(
                'deep-text-recognition-benchmark/TPS-ResNet-BiLSTM-Attn.pth',
                map_location=self.device))
        self.model.eval()

    def copyStateDict(self, state_dict):
        """Return a copy of ``state_dict`` without a leading ``module`` part.

        When the first key starts with ``"module"`` the first dot-separated
        component is dropped from every key; otherwise keys are kept as-is.
        """
        if list(state_dict.keys())[0].startswith("module"):
            start_idx = 1
        else:
            start_idx = 0
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = ".".join(k.split(".")[start_idx:])
            new_state_dict[name] = v
        return new_state_dict

    def test_net(self,
                 net,
                 image,
                 text_threshold,
                 link_threshold,
                 low_text,
                 cuda,
                 poly,
                 refine_net=None):
        """Run the detector on one image.

        Returns:
            ``(boxes, polys, heatmap)`` where boxes/polys are rescaled to
            the input image and ``heatmap`` is the rendered text and link
            score maps placed side by side.
        """
        t0 = time.time()

        # resize
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
            image,
            self.canvas_size,
            interpolation=cv2.INTER_LINEAR,
            mag_ratio=self.mag_ratio)
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]
        if cuda:
            x = x.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = net(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # refine link
        if refine_net is not None:
            with torch.no_grad():
                y_refiner = refine_net(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

        t0 = time.time() - t0
        t1 = time.time()

        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                               text_threshold, link_threshold,
                                               low_text, poly)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None:
                polys[k] = boxes[k]

        t1 = time.time() - t1

        # render results (optional)
        render_img = score_text.copy()
        render_img = np.hstack((render_img, score_link))
        ret_score_text = imgproc.cvt2HeatmapImg(render_img)

        if self.show_time:
            print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

        return boxes, polys, ret_score_text

    def extract_text(self):
        """Crop every detected word and write the recognized text to disk.

        For each image in ``i_folder`` the detected boxes are grouped into
        rows (top to bottom), ordered left to right inside a row, and saved
        as ``<name>_<row>_<index><ext>`` in ``result_folder``.  The crops are
        then transcribed and one record per image is written to
        ``extract_text_file``; rows are separated by ``"[SEP] "``.

        NOTE(review): file names are split on "." and crop names on "_", so
        input names containing extra dots or underscores are not handled.
        """
        file_names = sorted(os.listdir(self.i_folder))
        img_to_index = {}
        count = 0
        for full_file in file_names:
            split_file = full_file.split(".")
            filename = split_file[0]
            img_to_index[count] = filename
            count += 1
            file_extension = "." + split_file[1]

            image = imgproc.loadImage(self.i_folder + full_file)
            bboxes, polys, score_text = self.test_net(
                self.net, image, self.text_threshold, self.link_threshold,
                self.low_text, self.cuda, self.poly, self.refine_net)

            img = cv2.imread(self.i_folder + filename + file_extension)
            rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            # Reduce each quadrilateral to an axis-aligned (min, max) corner
            # pair, clamped to the image bounds.
            points = []
            order = []
            for sample_bbox in bboxes:
                min_point = sample_bbox[0]
                max_point = sample_bbox[2]
                for p in sample_bbox:
                    if (p[0] <= min_point[0]):
                        min_point = (p[0], min_point[1])
                    if (p[1] <= min_point[1]):
                        min_point = (min_point[0], p[1])
                    if (p[0] >= max_point[0]):
                        max_point = (p[0], max_point[1])
                    if (p[1] >= max_point[1]):
                        max_point = (max_point[0], p[1])
                min_point = (max(min(len(rgb_img[0]), min_point[0]), 0),
                             max(min(len(rgb_img), min_point[1]), 0))
                max_point = (max(min(len(rgb_img[0]), max_point[0]), 0),
                             max(min(len(rgb_img), max_point[1]), 0))
                points.append((min_point, max_point))
                order.append(0)  # 0 marks "not yet assigned to a row"

            num_ordered = 0
            rows_ordered = 0
            points_sorted = []
            ordered_points_index = 0
            order_sorted = []
            while (num_ordered < len(points)):
                # find lowest-y that is unordered
                min_y = len(rgb_img)
                min_y_index = -1
                for i in range(0, len(points)):
                    if (order[i] == 0):
                        if (points[i][0][1] <= min_y):
                            min_y = points[i][0][1]
                            min_y_index = i
                rows_ordered += 1
                order[min_y_index] = rows_ordered
                num_ordered += 1
                points_sorted.append(points[min_y_index])
                order_sorted.append(rows_ordered)
                # Index in points_sorted where the current row begins.
                ordered_points_index = len(points_sorted) - 1

                # Group bboxes that are on the same row
                max_y = points[min_y_index][1][1]
                range_y = max_y - min_y
                for i in range(0, len(points)):
                    if (order[i] == 0):
                        min_y_i = points[i][0][1]
                        max_y_i = points[i][1][1]
                        range_y_i = max_y_i - min_y_i
                        if (max_y_i >= min_y and min_y_i <= max_y):
                            # Vertical overlap relative to the shorter box.
                            overlap = (min(max_y_i, max_y) -
                                       max(min_y_i, min_y)) / (max(
                                           1, min(range_y, range_y_i)))
                            if (overlap >= 0.30):
                                order[i] = rows_ordered
                                num_ordered += 1
                                min_x_i = points[i][0][0]
                                # Insert into the current row, keeping it
                                # sorted by the left edge.
                                for j in range(ordered_points_index,
                                               len(points_sorted) + 1):
                                    if (j < len(points_sorted)):
                                        # insert before
                                        min_x_j = points_sorted[j][0][0]
                                        if (min_x_i < min_x_j):
                                            points_sorted.insert(j, points[i])
                                            order_sorted.insert(
                                                j, rows_ordered)
                                            break
                                    else:
                                        # insert at the end of array
                                        points_sorted.insert(j, points[i])
                                        order_sorted.insert(j, rows_ordered)
                                        break

            for i in range(0, len(points_sorted)):
                min_point = points_sorted[i][0]
                max_point = points_sorted[i][1]
                mask_file = self.result_folder + filename + "_" + str(
                    order_sorted[i]) + "_" + str(i) + file_extension
                crop_image = rgb_img[int(min_point[1]):int(max_point[1]),
                                     int(min_point[0]):int(max_point[0])]
                cv2.imwrite(mask_file, crop_image)

        AlignCollate_demo = AlignCollate(imgH=self.opt.imgH,
                                         imgW=self.opt.imgW,
                                         keep_ratio_with_pad=self.opt.PAD)
        demo_data = RawDataset(root=self.result_folder,
                               opt=self.opt)  # use RawDataset
        demo_loader = torch.utils.data.DataLoader(
            demo_data,
            batch_size=self.opt.batch_size,
            shuffle=False,
            num_workers=int(self.opt.workers),
            collate_fn=AlignCollate_demo,
            pin_memory=True)

        count = -1
        curr_order = 1
        curr_filename = ""
        output_string = ""
        end_line = "[SEP] "
        # The context manager closes the output file even when inference
        # raises part-way through.
        with open(self.extract_text_file, "w") as f:
            with torch.no_grad():
                for image_tensors, image_path_list in demo_loader:
                    batch_size = image_tensors.size(0)
                    image = image_tensors.to(self.device)
                    length_for_pred = torch.IntTensor(
                        [self.opt.batch_max_length] * batch_size).to(
                            self.device)
                    text_for_pred = torch.LongTensor(
                        batch_size,
                        self.opt.batch_max_length + 1).fill_(0).to(
                            self.device)

                    preds = self.model(image, text_for_pred, is_train=False)
                    _, preds_index = preds.max(2)
                    preds_str = self.converter.decode(preds_index,
                                                      length_for_pred)

                    for path, p in zip(image_path_list, preds_str):
                        if 'Attn' in self.opt.Prediction:
                            pred_EOS = p.find('[s]')
                            # prune after "end of sentence" token ([s])
                            p = p[:pred_EOS]

                        # Crop names are "<image>_<row>_<index><ext>".
                        path_info = path[len(self.result_folder):].split(
                            ".")[0].split("_")

                        if (not (curr_filename == path_info[0])):
                            # A new image starts: flush the previous record.
                            if (not (curr_filename == "")):
                                f.write(str(count) + "\n")
                                f.write(curr_filename + "\n")
                                f.write(output_string + "\n\n")
                            count += 1
                            curr_filename = img_to_index[count]
                            # Images without any crop get an empty record so
                            # record positions stay aligned with image order.
                            while (not (curr_filename == path_info[0])):
                                f.write(str(count) + "\n")
                                f.write(curr_filename + "\n")
                                f.write("\n\n")
                                count += 1
                                curr_filename = img_to_index[count]
                            output_string = ""
                            curr_order = 1

                        if (int(path_info[1]) > curr_order):
                            curr_order += 1
                            output_string += end_line
                        output_string += p + " "

            # Flush the record of the last image that produced crops.
            f.write(str(count) + "\n")
            f.write(curr_filename + "\n")
            f.write(output_string + "\n\n")

    def get_item(self, index):
        """Return the recognized text of record ``index``.

        Each record occupies four lines, the text being the third; the
        trailing newline is stripped.
        """
        with open(self.extract_text_file, "r") as f:
            lines = f.readlines()
        return (lines[4 * index + 2][:-1])
class Character_detect(object):
    """CPU-only CRAFT wrapper that returns sorted, padded text boxes."""

    def __init__(self):
        """Build the CRAFT network and restore its weights on the CPU."""
        self.net = CRAFT()
        checkpoint = torch.load("weight/craft_mlt_25k.pth",
                                map_location='cpu')
        self.net.load_state_dict(self.copyStateDict(checkpoint))
        self.net.eval()

    def test_net(self,
                 net,
                 image,
                 text_threshold,
                 link_threshold,
                 low_text,
                 poly,
                 refine_net=None):
        """Run ``net`` on ``image`` and post-process its score maps.

        ``refine_net`` is accepted but not used by this implementation.

        Returns:
            ``(boxes, polys, heatmap)``; missing polygons are replaced by
            the corresponding box, and ``heatmap`` renders the text and
            link score maps side by side.
        """
        resized, target_ratio, _size_heatmap = imgproc.resize_aspect_ratio(
            image, 1280, interpolation=cv.INTER_LINEAR, mag_ratio=1.5)
        ratio_h = ratio_w = 1 / target_ratio

        # Normalize, then go from [h, w, c] to a [1, c, h, w] batch.
        normalized = imgproc.normalizeMeanVariance(resized)
        batch = Variable(
            torch.from_numpy(normalized).permute(2, 0, 1).unsqueeze(0))

        with torch.no_grad():
            y, _feature = net(batch)

        # Channel 0 carries the text score, channel 1 the link score.
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                               text_threshold, link_threshold,
                                               low_text, poly)

        # Map the coordinates back to the original image scale.
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        for index, candidate in enumerate(polys):
            if candidate is None:
                polys[index] = boxes[index]

        heatmap_input = np.hstack((score_text.copy(), score_link))
        return boxes, polys, imgproc.cvt2HeatmapImg(heatmap_input)

    def detect(self, path):
        """Detect text in the image at ``path``.

        The annotated result is saved through ``file_utils.saveResult``
        into ``Detect_result/``.

        Returns:
            A list of four-element boxes built from flattened polygon
            entries 0, 1, 2 and 5 (the first two reduced by 3 and 5, the
            last increased by 5), sorted with ``sorting_key``.
        """
        image = imgproc.loadImage(path)
        _boxes, polys, _heatmap = self.test_net(self.net, image, 0.7, 999999,
                                                0.5, False, None)

        bbox = []
        for box in polys:
            flat = np.array(box).astype(np.int32).reshape((-1))
            bbox.append([flat[0] - 3, flat[1] - 5, flat[2], flat[5] + 5])

        file_utils.saveResult(path,
                              image[:, :, ::-1],
                              polys,
                              dirname="Detect_result/")
        bbox.sort(key=sorting_key)
        return bbox

    def copyStateDict(self, state_dict):
        """Return ``state_dict`` with a leading ``module`` key part removed.

        The first dot-separated component is dropped from every key when
        the first key starts with ``"module"``; otherwise keys are kept.
        """
        first_key = list(state_dict.keys())[0]
        skip = 1 if first_key.startswith("module") else 0
        stripped = OrderedDict()
        for key, value in state_dict.items():
            stripped[".".join(key.split(".")[skip:])] = value
        return stripped
if __name__ == '__main__':
    # Manual check of the ICDAR2015 dataset: restore a CRAFT checkpoint and
    # wrap the dataset (which receives the network) in a DataLoader.
    # An alternative driver built on
    # Synth80k('/home/jiachx/publicdatasets/SynthText/SynthText',
    #          target_size=768, viz=True, debug=True)
    # with the same DataLoader settings was left disabled here.
    from craft import CRAFT
    from torchutil import copyStateDict

    net = CRAFT(freeze=True)
    checkpoint = torch.load('/ic15_iter_1300.pth')
    net.load_state_dict(copyStateDict(checkpoint))
    net = torch.nn.DataParallel(net.cuda())
    net.eval()

    dataloader = ICDAR2015(net,
                           '/icdar2015/icdar2015train',
                           target_size=640,
                           viz=True)
    loader_options = dict(batch_size=1,
                          shuffle=False,
                          num_workers=0,
                          drop_last=True,
                          pin_memory=True)
    train_loader = torch.utils.data.DataLoader(dataloader, **loader_options)

    # Running counters for the code that follows.
    total = 0
    total_sum = 0
class CraftNet(object):
    """CRAFT detector combined with an OCR object and result bookkeeping.

    Configuration is read from module-level names (``trained_model``,
    ``isCuda``, ``canvas_size``, ``mag_ratio``, ``isTest``, ``saveResult``,
    ``includeTesseract`` and the ``*Val`` thresholds).
    """

    def __init__(self, ocrObj):
        """Load the CRAFT weights and remember the OCR object.

        Args:
            ocrObj: Object whose ``getStringnGram(image, poly)`` returns an
                ``(incorrect, correct)`` pair for one detected box.
        """
        self.net = CRAFT()
        print('Loading weights from checkpoint (' + trained_model + ')')
        if isCuda:
            self.net.load_state_dict(copyStateDict(torch.load(trained_model)))
        else:
            self.net.load_state_dict(
                copyStateDict(torch.load(trained_model, map_location='cpu')))
        if isCuda:
            self.net = self.net.cuda()
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = False
        self.net.eval()
        # Accumulated statistics, keyed by folder and then by image name.
        self.jsonFile = defaultdict(dict)
        self.ocrObj = ocrObj

    def test_net(self,
                 image,
                 text_threshold,
                 link_threshold,
                 low_text,
                 cuda,
                 poly,
                 refine_net=None):
        """Run the detector on one image.

        Returns:
            ``(boxes, polys, heatmap)`` with coordinates rescaled to the
            input image; missing polygons fall back to their box.
        """
        # resize
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
            image,
            canvas_size,
            interpolation=cv2.INTER_LINEAR,
            mag_ratio=mag_ratio)
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]
        if cuda:
            x = x.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = self.net(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # refine link
        if refine_net is not None:
            with torch.no_grad():
                y_refiner = refine_net(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                               text_threshold, link_threshold,
                                               low_text, poly)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None:
                polys[k] = boxes[k]

        # render results (optional)
        render_img = score_text.copy()
        render_img = np.hstack((render_img, score_link))
        ret_score_text = imgproc.cvt2HeatmapImg(render_img)

        return boxes, polys, ret_score_text

    def evaluateBB(self, image_path):
        """Detect and read the text of one image and summarize it.

        When ``isTest`` is set the per-box results are added to
        ``self.jsonFile`` and dumped to ``stats.json``; when ``saveResult``
        is set an annotated copy of the image is written below
        ``./result/edited/``.

        Returns:
            The dictionary produced by ``evaluateResponse``.
        """
        print(image_path)
        print(os.getcwd())
        image = imgproc.loadImage(image_path)
        # Annotate a real copy: drawing on an alias would put rectangles
        # and labels into the pixels that are handed to the OCR below.
        imageCpy = image.copy()

        tnew = time.time()
        _, polys, _ = self.test_net(image, text_thresholdVal,
                                    link_thresholdVal, low_textVal, isCuda,
                                    polyVal)
        deltaTime = time.time() - tnew

        # Per-box results are always collected because the return value is
        # built from them regardless of isTest.
        bbs = defaultdict(dict)
        for i in range(len(polys)):
            if saveResult:
                cv2.rectangle(imageCpy,
                              (int(polys[i][0][0]), int(polys[i][0][1])),
                              (int(polys[i][1][0]), int(polys[i][2][1])),
                              (255, 0, 0), 2)

            tnew = time.time()
            incorrect, correct = self.ocrObj.getStringnGram(image, polys[i])
            nTime = time.time() - tnew

            if correct is not None and saveResult:
                cv2.putText(imageCpy, correct,
                            (int(polys[i][0][0]), int(polys[i][0][1] - 10)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)

            record = {
                "BB": polys[i].tolist(),
                "strings": incorrect,
                "stringsCorrect": correct,
                "ocrTime": nTime,
            }
            if includeTesseract:
                tnew = time.time()
                incTess, corrTess = tesseractOCR.getStringnGram(
                    image, polys[i])
                tessTime = time.time() - tnew
                record["stringsTess"] = incTess
                record["stringsCorrectTess"] = corrTess
                record["ocrTimeTess"] = tessTime
            bbs[i] = record

        if isTest:
            curImg = {
                "BBs": bbs,
                "pretrained": "MLT",
                "procTime": deltaTime,
                "OCR": "CRNN",
            }
            name, folder = getNameAndFolder(image_path)
            self.jsonFile[folder][name] = curImg
            with open("./CRAFT-pytorch-master/stats.json",
                      "w") as write_file:
                json.dump(self.jsonFile,
                          write_file,
                          sort_keys=True,
                          indent=4)

        if saveResult:
            name, folder = getNameAndFolder(image_path)
            imageCpy = cv2.cvtColor(imageCpy, cv2.COLOR_BGR2RGB)
            if not os.path.exists("./result/edited/" + folder):
                os.makedirs("./result/edited/" + folder)
            cv2.imwrite("./result/edited/" + folder + "/" + name, imageCpy)

        return self.evaluateResponse(bbs, image)

    def getQuadrant(self, bb, image):
        """Return the 3x3 grid cell (1..9, row-major) of a box centre.

        The box spans ``bb[0]`` (top-left) to ``(bb[1][0], bb[2][1])``;
        ``image.shape`` supplies the height and width of the grid.
        """
        shape = image.shape
        newRect = [bb[0][0], bb[0][1], bb[1][0], bb[2][1]]
        xPt = newRect[0] + (newRect[2] - newRect[0]) / 2
        yPt = newRect[1] + (newRect[3] - newRect[1]) / 2

        # Column index 1..3.
        if 0 <= xPt < shape[1] / 3:
            xQuad = 1
        elif shape[1] / 3 <= xPt < shape[1] * 2 / 3:
            xQuad = 2
        else:
            xQuad = 3

        # Row index 0..2.
        if 0 <= yPt < shape[0] / 3:
            yQuad = 0
        elif shape[0] / 3 <= yPt < shape[0] * 2 / 3:
            yQuad = 1
        else:
            yQuad = 2

        return xQuad + yQuad * 3

    def evaluateResponse(self, bbValues, image):
        """Summarize recognized words by grid position and box area.

        Args:
            bbValues: Mapping of box index to a record holding ``"BB"`` and
                ``"stringsCorrect"``.
            image: Array whose ``shape`` defines the 3x3 grid.

        Returns:
            A dict with ``"grid"`` (words per named grid cell),
            ``"threeWords"`` (up to three words chosen by box area) and
            ``"newWords"`` (number of boxes with a corrected string).
        """
        gridWords = {cell: [] for cell in range(1, 10)}
        words = 0
        threeWords = [("", 0), ("", 0), ("", 0)]
        full = False
        for i in bbValues:
            if bbValues[i]["stringsCorrect"] is not None:
                words += 1
                gridWords[self.getQuadrant(bbValues[i]["BB"], image)].append(
                    bbValues[i]["stringsCorrect"])
                found = False
                j = 0
                # Fill the empty slots first; once all three are taken,
                # replace the first slot holding a smaller box.
                while not found and j < len(threeWords):
                    if ((threeWords[j][1] < self.getArea(bbValues[i]["BB"])
                         and full) or threeWords[j][1] == 0):
                        found = True
                        threeWords[j] = (bbValues[i]["stringsCorrect"],
                                         self.getArea(bbValues[i]["BB"]))
                        if j == len(threeWords) - 1:
                            full = True
                    j += 1

        dictionary = {
            "grid": {
                "Top Left": gridWords[1],
                "Top": gridWords[2],
                "Top Right": gridWords[3],
                "Center Left": gridWords[4],
                "Center": gridWords[5],
                "Center Right": gridWords[6],
                "Bottom Left": gridWords[7],
                "Bottom": gridWords[8],
                "Bottom Right": gridWords[9]
            },
            "threeWords":
            [threeWords[0][0], threeWords[1][0], threeWords[2][0]],
            "newWords": words
        }
        return dictionary

    def getArea(self, bb):
        """Return the area of the rectangle spanned by ``bb``."""
        newRect = [bb[0][0], bb[0][1], bb[1][0], bb[2][1]]
        return (newRect[2] - newRect[0]) * (newRect[3] - newRect[1])
if __name__ == '__main__':
    # Manual check of the ICDAR2015 dataset: restore a CRAFT checkpoint and
    # wrap the dataset (which receives the network) in a DataLoader.
    # An alternative driver built on
    # Synth80k('/home/jiachx/publicdatasets/SynthText/SynthText',
    #          target_size=768, viz=True, debug=True)
    # with the same DataLoader settings was left disabled here.
    from craft import CRAFT
    from torchutil import copyStateDict

    net = CRAFT(freeze=True)
    checkpoint = torch.load('/data/CRAFT-pytorch/1-7.pth')
    net.load_state_dict(copyStateDict(checkpoint))
    net = torch.nn.DataParallel(net.cuda())
    net.eval()

    dataloader = ICDAR2015(net,
                           '/data/CRAFT-pytorch/icdar2015',
                           target_size=768,
                           viz=True)
    loader_options = dict(batch_size=1,
                          shuffle=False,
                          num_workers=0,
                          drop_last=True,
                          pin_memory=True)
    train_loader = torch.utils.data.DataLoader(dataloader, **loader_options)

    # Running counters for the code that follows.
    total = 0
    total_sum = 0
class Ocr:
    """CPU OCR pipeline: CRAFT text detection plus a transformer recognizer.

    ``forward`` takes named regions of a document image and fills one
    manager-backed result list per region name.
    """

    def __init__(self):
        """Create the shared result lists and load the CRAFT detector."""
        super().__init__()
        # Manager lists so that results written by worker processes are
        # visible in the parent.
        manager = Manager()
        self.send = manager.list()
        self.date = manager.list()
        self.quote = manager.list()
        self.number = manager.list()
        self.header = manager.list()
        self.sign = manager.list()

        self.device = torch.device('cpu')
        # map_location keeps the load working on CPU-only hosts even when
        # the checkpoint was saved from a GPU.
        state_dict = torch.load(
            '/home/dung/Project/Python/ocr/craft_mlt_25k.pth',
            map_location=self.device)
        # Drop the leading "module" key component if it is present.
        if list(state_dict.keys())[0].startswith("module"):
            start_idx = 1
        else:
            start_idx = 0
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = ".".join(k.split(".")[start_idx:])
            new_state_dict[name] = v

        self.craft = CRAFT()
        self.craft.load_state_dict(new_state_dict)
        self.craft.to(self.device)
        self.craft.eval()
        self.craft.share_memory()

        self.config = Cfg.load_config_from_name('vgg_transformer')
        self.config[
            'weights'] = 'https://drive.google.com/uc?id=13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA'
        self.config['device'] = 'cpu'
        self.config['predictor']['beamsearch'] = False
        self.weights = '/home/dung/Documents/transformerocr.pth'

    def predict(self, model, vocab, seq, key, idx, img):
        """Recognize the text of one crop and store it in ``seq[idx]``.

        Decoding is greedy: the top-1 token is appended step by step until
        every sequence contains token 2 or more than 128 steps were taken.
        NOTE(review): tokens 1 and 2 are presumably the start and end
        markers of the vocabulary — confirm against ``vocab``.
        """
        img = process_input(img, self.config['dataset']['image_height'],
                            self.config['dataset']['image_min_width'],
                            self.config['dataset']['image_max_width'])
        img = img.to(self.config['device'])
        with torch.no_grad():
            src = model.cnn(img)
            memory = model.transformer.forward_encoder(src)
            translated_sentence = [[1] * len(img)]
            max_length = 0
            while max_length <= 128 and not all(
                    np.any(np.asarray(translated_sentence).T == 2, axis=1)):
                tgt_inp = torch.LongTensor(translated_sentence).to(
                    self.device)
                output = model.transformer.forward_decoder(tgt_inp, memory)
                output = output.to('cpu')
                _, indices = torch.topk(output, 5)
                # Keep the best candidate of the last position only.
                indices = indices[:, -1, 0]
                indices = indices.tolist()
                translated_sentence.append(indices)
                max_length += 1
                del output
            translated_sentence = np.asarray(translated_sentence).T
        s = translated_sentence[0].tolist()
        s = vocab.decode(s)
        seq[idx] = s

    def process(self, craft, seq, key, sub_img):
        """Detect text lines in ``sub_img`` and recognize each of them.

        ``seq`` is resized to the number of horizontal boxes found and then
        filled in place, one recognized string per box.
        """
        img_resized, target_ratio, size_heatmap = resize_aspect_ratio(
            sub_img, 2560, interpolation=cv2.INTER_LINEAR, mag_ratio=1.)
        ratio_h = ratio_w = 1 / target_ratio

        x = normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = x.unsqueeze(0)  # [c, h, w] to [b, c, h, w]
        x = x.to(self.device)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            y, feature = craft(x)

        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        boxes, polys = getDetBoxes(score_text,
                                   score_link,
                                   text_threshold=0.7,
                                   link_threshold=0.4,
                                   low_text=0.4,
                                   poly=False)
        boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None:
                polys[k] = boxes[k]

        result = [
            np.array(box).astype(np.int32).reshape((-1)) for box in polys
        ]

        horizontal_list, free_list = group_text_box(result,
                                                    slope_ths=0.8,
                                                    ycenter_ths=0.5,
                                                    height_ths=1,
                                                    width_ths=1,
                                                    add_margin=0.1)

        # Discard boxes that are too small to hold readable text.
        min_size = 20
        if min_size:
            horizontal_list = [
                i for i in horizontal_list
                if max(i[1] - i[0], i[3] - i[2]) > 10
            ]
            free_list = [
                i for i in free_list
                if max(diff([c[0] for c in i]), diff([c[1]
                                                      for c in i])) > min_size
            ]

        seq[:] = [None] * len(horizontal_list)

        model, vocab = build_model(self.config)
        model.load_state_dict(
            torch.load(self.weights, map_location=torch.device('cpu')))

        for i, ele in enumerate(horizontal_list):
            # Clamp negative coordinates so slicing does not wrap around.
            ele = [0 if coord < 0 else coord for coord in ele]
            img = sub_img[ele[2]:ele[3], ele[0]:ele[1], :]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(img.astype(np.uint8))
            p = threading.Thread(target=self.predict,
                                 args=(model, vocab, seq, key, i, img))
            p.start()
            p.join()

    def forward(self, img, rs):
        """Run OCR on every named region of ``img``.

        Args:
            img: Image array indexed as ``img[y0:y1, x0:x1, :]``.
            rs: Mapping of region name (``send``, ``date``, ``quote``,
                ``number``, ``header``, ``sign``) to ``(x0, y0, x1, y1)``.
                Entries with any other name are ignored.

        Returns:
            Six lists of recognized strings, in the order send, date,
            quote, number, header, sign.
        """
        # Each region writes into its own result list.
        targets = {
            'send': self.send,
            'date': self.date,
            'quote': self.quote,
            'number': self.number,
            'header': self.header,
            'sign': self.sign,
        }
        for key, v in rs.items():
            x0, y0, x1, y1 = v
            if key not in targets:
                # No result list exists for this name; skip the region.
                continue
            p = mp.Process(target=self.process,
                           args=(
                               self.craft,
                               targets[key],
                               key,
                               img[y0:y1, x0:x1, :],
                           ))
            p.start()
            p.join()
        return self.send[:], self.date[:], self.quote[:], self.number[:], self.header[:], self.sign[:]
render_img = np.hstack((render_img, score_link)) ret_score_text = imgproc.cvt2HeatmapImg(render_img) if args.show_time: print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1)) return boxes, polys, ret_score_text if __name__ == '__main__': # load net net = CRAFT() # initialize print('Loading weights from checkpoint (' + args.trained_model + ')') if args.cuda: net.load_state_dict(copyStateDict(torch.load(args.trained_model))) else: net.load_state_dict( copyStateDict(torch.load(args.trained_model, map_location='cpu'))) if args.cuda: net = net.cuda() net = torch.nn.DataParallel(net) cudnn.benchmark = False net.eval() # LinkRefiner refine_net = None if args.refine: from refinenet import RefineNet
class TextDetector:
    """CRAFT-based text detector returning boxes converted by ``box2xyxy``."""

    def __init__(self, cuda=None):
        """Load the CRAFT weights (and the link refiner when enabled).

        Args:
            cuda: ``True``/``False`` to force GPU/CPU execution.  ``None``
                (the default) uses the GPU when ``torch.cuda`` reports one,
                so the detector can also be constructed on CPU-only hosts.
        """
        # Parameters
        self.canvas_size = 1280
        self.mag_ratio = 1.5
        self.text_threshold = 0.7
        self.low_text = 0.4
        self.link_threshold = 0.4
        self.refine = False
        self.refiner_model = ''
        self.poly = False
        self.cuda = torch.cuda.is_available() if cuda is None else bool(cuda)

        # map_location=None is torch.load's default behavior.
        map_location = None if self.cuda else 'cpu'

        self.net = CRAFT()
        self.net.load_state_dict(
            copyStateDict(
                torch.load('CRAFT/weights/craft_mlt_25k.pth',
                           map_location=map_location)))
        if self.cuda:
            self.net = self.net.cuda()
            self.net = torch.nn.DataParallel(self.net)
        self.net.eval()

        # LinkRefiner
        self.refine_net = None
        if self.refine:
            from refinenet import RefineNet
            self.refine_net = RefineNet()
            self.refine_net.load_state_dict(
                copyStateDict(
                    torch.load(self.refiner_model,
                               map_location=map_location)))
            if self.cuda:
                self.refine_net = self.refine_net.cuda()
                self.refine_net = torch.nn.DataParallel(self.refine_net)
            self.refine_net.eval()
            self.poly = True

    def detect(self, image):
        """Detect text in ``image``.

        Returns:
            A list with one ``box2xyxy(box, image.shape[0:2])`` result per
            detected box, in detection order.
        """
        # resize
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
            image,
            self.canvas_size,
            interpolation=cv2.INTER_LINEAR,
            mag_ratio=self.mag_ratio)
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]
        if self.cuda:
            x = x.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = self.net(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # refine link
        if self.refine_net is not None:
            with torch.no_grad():
                y_refiner = self.refine_net(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

        # Post-processing (the polygon output is not needed here)
        boxes, _ = craft_utils.getDetBoxes(score_text, score_link,
                                           self.text_threshold,
                                           self.link_threshold, self.low_text,
                                           self.poly)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)

        return [box2xyxy(box, image.shape[0:2]) for box in boxes]