def __init__(self):
    """Initialise the OCR worker: shared result lists, a CPU CRAFT detector
    and a VietOCR (vgg_transformer) recognizer configuration."""
    super().__init__()
    # multiprocessing.Manager lists so results can be appended from worker
    # processes and read by the parent.
    manager = Manager()
    self.send = manager.list()
    self.date = manager.list()
    self.quote = manager.list()
    self.number = manager.list()
    self.header = manager.list()
    self.sign = manager.list()
    self.device = torch.device('cpu')
    # NOTE(review): absolute, machine-specific checkpoint path — consider
    # making it configurable.
    state_dict = torch.load(
        '/home/dung/Project/Python/ocr/craft_mlt_25k.pth')
    # Checkpoints saved from a DataParallel model prefix keys with
    # "module." — detect that and strip the prefix.
    if list(state_dict.keys())[0].startswith("module"):
        start_idx = 1
    else:
        start_idx = 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    self.craft = CRAFT()
    self.craft.load_state_dict(new_state_dict)
    self.craft.to(self.device)
    self.craft.eval()
    # Share weights across forked worker processes.
    self.craft.share_memory()
    self.config = Cfg.load_config_from_name('vgg_transformer')
    self.config[
        'weights'] = 'https://drive.google.com/uc?id=13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA'
    self.config['device'] = 'cpu'
    self.config['predictor']['beamsearch'] = False
    self.weights = '/home/dung/Documents/transformerocr.pth'
def test(modelpara):
    """Run CRAFT detection over the module-level image_list using the
    checkpoint at `modelpara`; results go to the module-level result_folder."""
    # load net
    net = CRAFT()     # initialize
    print('Loading weights from checkpoint {}'.format(modelpara))
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(modelpara)))
    else:
        net.load_state_dict(copyStateDict(torch.load(modelpara, map_location='cpu')))
    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False
    net.eval()
    t = time.time()
    # load data — image_list / result_folder are presumed module globals.
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
        image = imgproc.loadImage(image_path)
        bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly)
        # save score text (heatmap write is disabled below)
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        #cv2.imwrite(mask_file, score_text)
        file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)
    print("elapsed time : {}s".format(time.time() - t))
def __init__(self, ocrObj):
    """Build the CRAFT detector from the module-level trained_model path and
    keep a reference to the recognizer `ocrObj`."""
    model = CRAFT()
    print('Loading weights from checkpoint (' + trained_model + ')')
    # Map the checkpoint onto GPU or CPU depending on the module-level flag.
    if isCuda:
        state = torch.load(trained_model)
    else:
        state = torch.load(trained_model, map_location='cpu')
    model.load_state_dict(copyStateDict(state))
    if isCuda:
        model = torch.nn.DataParallel(model.cuda())
        cudnn.benchmark = False
    model.eval()
    self.net = model
    # Per-document JSON accumulator and the OCR recognizer handle.
    self.jsonFile = defaultdict(dict)
    self.ocrObj = ocrObj
def loadModel(self, device="cuda", is_refine=True,
              trained_model=os.path.join(CRAFT_DIR, 'weights/craft_mlt_25k.pth'),
              refiner_model=os.path.join(
                  CRAFT_DIR, 'weights/craft_refiner_CTW1500.pth')):
    """Load the CRAFT detector and, optionally, the link-refiner network.

    Args:
        device: "cuda" to run on GPU (wrapped in DataParallel); any other
            value loads on CPU.
        is_refine: when True, also load RefineNet and enable polygon output.
        trained_model: path to the CRAFT checkpoint.
        refiner_model: path to the RefineNet checkpoint.
    """
    is_cuda = device == "cuda"
    self.is_cuda = is_cuda
    # load net
    self.net = CRAFT()     # initialize
    print('Loading weights from checkpoint (' + trained_model + ')')
    if is_cuda:
        self.net.load_state_dict(copyStateDict(torch.load(trained_model)))
    else:
        self.net.load_state_dict(
            copyStateDict(torch.load(trained_model, map_location='cpu')))
    if is_cuda:
        self.net = self.net.cuda()
        self.net = torch.nn.DataParallel(self.net)
        cudnn.benchmark = False
    self.net.eval()
    # LinkRefiner
    self.refine_net = None
    if is_refine:
        from refinenet import RefineNet
        self.refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + refiner_model + ')')
        if is_cuda:
            self.refine_net.load_state_dict(
                copyStateDict(torch.load(refiner_model)))
            self.refine_net = self.refine_net.cuda()
            self.refine_net = torch.nn.DataParallel(self.refine_net)
        else:
            self.refine_net.load_state_dict(
                copyStateDict(torch.load(refiner_model, map_location='cpu')))
        self.refine_net.eval()
        # Refined links produce polygon output.
        self.is_poly = True
def main(trained_model='weights/craft_mlt_25k.pth', text_threshold=0.7, low_text=0.4,
         link_threshold=0.4, cuda=True, canvas_size=1280, mag_ratio=1.5, poly=False,
         show_time=False, test_folder='/data/', refine=True,
         refiner_model='weights/craft_refiner_CTW1500.pth'):
    """Load CRAFT (+ optional refiner) and run detection on one image.

    NOTE(review): `image_path` and `result_folder` are not parameters or
    locals — presumably module globals set by the caller; confirm.
    `canvas_size`/`mag_ratio`/`show_time`/`test_folder` are accepted but
    unused here.
    """
    # if __name__ == '__main__':
    # load net
    net = CRAFT()     # initialize
    print('Loading weights from checkpoint (' + trained_model + ')')
    if cuda:
        net.load_state_dict(copyStateDict(torch.load(trained_model)))
    else:
        net.load_state_dict(copyStateDict(torch.load(trained_model, map_location='cpu')))
    if cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False
    net.eval()
    # LinkRefiner
    refine_net = None
    if refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + refiner_model + ')')
        if cuda:
            refine_net.load_state_dict(copyStateDict(torch.load(refiner_model)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(copyStateDict(torch.load(refiner_model, map_location='cpu')))
        refine_net.eval()
        # Refiner output is polygonal.
        poly = True
    t = time.time()
    # load data
    image = imgproc.loadImage(image_path)
    bboxes, polys, score_text = test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net)
    # save score text
    filename, file_ext = os.path.splitext(os.path.basename(image_path))
    mask_file = result_folder + "/res_" + filename + '_mask.jpg'
    cv2.imwrite(mask_file, score_text)
    final_img = file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)
    print("elapsed time : {}s".format(time.time() - t))
def main():
    """Load CRAFT (+ optional refiner) from args and run detection over the
    module-level image_list, writing masks and results to result_folder."""
    # load net
    net = CRAFT()     # initialize
    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
    else:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))
    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False
    net.eval()
    # LinkRefiner
    refine_net = None
    if args.refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
        if args.cuda:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))
        refine_net.eval()
        # Refiner implies polygon output.
        args.poly = True
    t = time.time()
    print(image_list)
    # load data — image_list / result_folder are presumed module globals.
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
        image = imgproc.loadImage(image_path)
        bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, refine_net)
        # save score text
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        cv2.imwrite(mask_file, score_text)
        file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)
    # print("elapsed time : {}s".format(time.time() - t))
def __init__(self):
    """Default configuration for CRAFT-based video text detection."""
    # Detection thresholds.
    self.text_threshold = 0.7
    self.link_threshold = 0.4
    self.low_text = 0.4
    # Runtime / preprocessing options.
    self.cuda = True
    self.canvas_size = 1280
    self.mag_ratio = 1.5
    self.interpolation = cv2.INTER_LINEAR
    # Output, input folder and link-refiner settings.
    self.poly = True
    self.show_time = False
    self.video_folder = 'input/'
    self.refine = False
    self.refiner_model = 'model/craft_refiner_CTW1500.pth'
    # Untrained CRAFT network; weights are loaded elsewhere.
    self.net = CRAFT()
def test(image, epoch, index, cvt=False):
    """Run the CRAFT checkpoint saved for (epoch, index) on a single image.

    Args:
        image: HxWxC numpy image (resized internally to 768x768).
        epoch, index: identify the checkpoint file under /root/data/test_param.
        cvt: when True, convert the score maps to RGB via Gray2RGB.

    Returns:
        (pred_region, pred_affinity): region/affinity score maps, HxW
        (or HxWxC when cvt is True).
    """
    # Fix: removed the no-op `image = image` self-assignment.
    print('input image shape {}'.format(image.shape))
    checkpoint = torch.load('/root/data/test_param/{}_{}.pth'.format(
        epoch, index))
    net = CRAFT().cuda()
    net.load_state_dict(copyStateDict(checkpoint['model_state_dict']))
    # Image resizing etc. — assume preprocessing is done and proceed.
    image = normalizeMeanVariance(image)
    image = cv2.resize(image, (768, 768), interpolation=cv2.INTER_LINEAR)
    x = torch.from_numpy(image).permute(2, 0, 1)
    # torch.autograd.Variable is deprecated (a no-op since torch 0.4);
    # plain tensor ops are equivalent.
    x = x.unsqueeze(0).float()
    x = x.cuda()
    print(x.size())
    with torch.no_grad():
        y, _ = net(x)
    pred_region = y[0, :, :, 0].cpu().data.numpy()
    pred_affinity = y[0, :, :, 1].cpu().data.numpy()
    print(type(pred_region))
    print(pred_region.shape)
    # cvt == True -> Region, Affinity score H x W x C
    # cvt == False -> Region, Affinity score H x W
    if cvt:
        pred_region = Gray2RGB(pred_region)
        pred_affinity = Gray2RGB(pred_affinity)
    return pred_region, pred_affinity
def __init__(self, cuda=True):
    """Build the CRAFT detector (+ optional refiner) from the module-level
    config_craft dict, whose entries become instance attributes."""
    self.cuda = cuda
    # Copy every config entry (trained_model, refine, refiner_model,
    # thresholds, ...) onto the instance.
    for k, v in config_craft.items():
        setattr(self, k, v)
    self.net = CRAFT()
    print(f'Loading weights from checkpoint ({self.trained_model})')
    if self.cuda:
        self.net.load_state_dict(
            copyStateDict(torch.load(self.trained_model)))
    else:
        self.net.load_state_dict(
            copyStateDict(
                torch.load(self.trained_model, map_location='cpu')))
    if self.cuda:
        self.net = self.net.cuda()
        self.net = torch.nn.DataParallel(self.net)
        cudnn.benchmark = False
    self.net.eval()
    # LinkRefiner
    self.refine_net = None
    if self.refine:
        from refinenet import RefineNet
        self.refine_net = RefineNet()
        print(
            f'Loading weights of refiner from checkpoint ({self.refiner_model})'
        )
        if self.cuda:
            self.refine_net.load_state_dict(
                copyStateDict(torch.load(self.refiner_model)))
            self.refine_net = self.refine_net.cuda()
            self.refine_net = torch.nn.DataParallel(self.refine_net)
        else:
            self.refine_net.load_state_dict(
                copyStateDict(
                    torch.load(self.refiner_model, map_location='cpu')))
        self.refine_net.eval()
        # Refiner output is polygonal.
        self.poly = True
def __init__(self):
    """CRAFT detector preconfigured for GPU inference (chinese-ocr weights)."""
    # Checkpoint location and detection thresholds.
    self.trained_model = '../chinese-ocr/weights/craft_mlt_25k.pth'
    self.text_threshold = 0.75
    self.low_text = 0.6
    self.link_threshold = 0.9
    # Inference settings.
    self.cuda = True
    self.canvas_size = 1280
    self.mag_ratio = 1.5
    self.poly = False
    self.show_time = False
    # Build the network, restore weights and move it to the GPU.
    model = CRAFT()
    weights = torch.load(self.trained_model)
    model.load_state_dict(copy_state_dict(weights))
    self.net = model.cuda()
    cudnn.benchmark = False
    self.net.eval()
def get_detector(trained_model, device='cpu'):
    """Return an eval-mode CRAFT detector loaded from `trained_model`.

    Args:
        trained_model: path to the CRAFT checkpoint.
        device: torch device string; non-CPU devices get a DataParallel wrap.

    Returns:
        The detector in eval mode (DataParallel-wrapped unless device=='cpu').
    """
    net = CRAFT()
    # Fix: both branches performed the identical load with
    # map_location=device, so the duplicated call is hoisted out; only the
    # non-CPU path differs (DataParallel + cudnn setting).
    net.load_state_dict(
        copyStateDict(torch.load(trained_model, map_location=device)))
    if device != 'cpu':
        net = torch.nn.DataParallel(net).to(device)
        cudnn.benchmark = False
    net.eval()
    return net
def load_detection_model():
    """Build CRAFT + RefineNet with a hard-coded argument list and return
    (net, refine_net, args).

    NOTE(review): parse_args is given a fixed list, so CLI arguments are
    ignored; --cuda defaults to True with no availability check.
    """
    parser = argparse.ArgumentParser(description='CRAFT Text Detection')
    parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str, help='pretrained model')
    parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
    parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
    parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
    parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
    parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
    parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
    parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
    parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
    parser.add_argument('--test_folder', default='/data/', type=str, help='folder path to input images')
    parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner')
    parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model')
    # Fixed argv: always load local models and enable the refiner.
    args = parser.parse_args(["--trained_model=./models/craft_mlt_25k.pth","--refine", "--refiner_model=./models/craft_refiner_CTW1500.pth"])
    net = CRAFT()     # initialize
    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
    else:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))
    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False
    net.eval()
    # LinkRefiner
    refine_net = None
    if args.refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
        if args.cuda:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))
        refine_net.eval()
        # args.poly = True
    return net,refine_net,args
def test(modelpara):
    """Run CRAFT over the module-level image_list.

    NOTE(review): the checkpoint-loading code is commented out below, so the
    network runs with randomly initialised weights even though the print
    claims to load `modelpara` — confirm whether this is intentional.
    """
    # load net
    net = CRAFT()     # initialize
    print('Loading weights from checkpoint {}'.format(modelpara))
    ####
    # if args.cuda:
    #     net.load_state_dict(copyStateDict(torch.load(modelpara)))
    # else:
    #     net.load_state_dict(copyStateDict(torch.load(modelpara, map_location='cpu')))
    #
    # if args.cuda:
    #     net = net.cuda()
    #     net = torch.nn.DataParallel(net)
    #     cudnn.benchmark = False
    ###
    # Auto-select GPU when available; otherwise run on CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = net.to(device)
    net.eval() #stop update the weight of the neuron
    t = time.time()
    # load data
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
        image = imgproc.loadImage(image_path)
        bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly)
        print("\n bboxes = ", bboxes, "\n poly = ", polys, "\n text = ", score_text, "\n text.shape = ", score_text.shape)
        # save score text
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        #cv2.imwrite(mask_file, score_text)
        print("save in" + result_folder)
        file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)
    print("elapsed time : {}s".format(time.time() - t))
def main():
    """Smoke-test the data pipeline: push every batch through CRAFT on GPU
    and print the resulting tensor shapes."""
    loader = data.DataLoader(ImageLoader(args), args.batch_size,
                             num_workers=1, shuffle=True, collate_fn=collate)
    model = CRAFT(pretrained=True).cuda()
    for step, (images, char_labels, interval_labels) in enumerate(loader):
        # Forward pass; keep the score map, discard the feature output.
        scores, _ = model(images.cuda())
        print(step, scores.shape, char_labels.shape, interval_labels.shape)
def let_load(self):
    """Instantiate CRAFT (and optionally RefineNet), load weights from
    self.trained_model / self.refiner_model, and switch both to eval mode.

    Side effects: sets self.net, self.refine_net and (when refining)
    self.poly = True.
    """
    self.net = CRAFT()     # initialize
    print('Loading weights from checkpoint (' + self.trained_model + ')')
    if self.cuda:
        self.net.load_state_dict(
            copyStateDict(torch.load(self.trained_model)))
    else:
        self.net.load_state_dict(
            copyStateDict(
                torch.load(self.trained_model, map_location='cpu')))
    if self.cuda:
        self.net = self.net.cuda()
        self.net = torch.nn.DataParallel(self.net)
        cudnn.benchmark = False
    self.net.eval()
    self.refine_net = None
    if self.refine:
        from refinenet import RefineNet
        # Fix: the refiner was previously built into a local variable and
        # never stored, leaving self.refine_net at None even when refinement
        # was requested; store it on the instance instead.
        self.refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + self.refiner_model + ')')
        if self.cuda:
            self.refine_net.load_state_dict(
                copyStateDict(torch.load(self.refiner_model)))
            self.refine_net = self.refine_net.cuda()
            self.refine_net = torch.nn.DataParallel(self.refine_net)
        else:
            self.refine_net.load_state_dict(
                copyStateDict(
                    torch.load(self.refiner_model, map_location='cpu')))
        self.refine_net.eval()
        # Refiner output is polygonal.
        self.poly = True
    # (removed: dead local `t = time.time()` that was never used)
def LoadDetectionModel(args):
    """Return an eval-mode CRAFT detector loaded from args.trained_model."""
    print('Loading weights from checkpoint (' + args.trained_model + ')')
    detector = CRAFT()
    detector.load_state_dict(copyStateDict(torch.load(args.trained_model)))
    if args.cuda:
        # Move to GPU and wrap for multi-GPU inference.
        detector = torch.nn.DataParallel(detector.cuda())
        cudnn.benchmark = False
    detector.eval()
    return detector
def __init__(self):
    """CRAFT detector with fixed parameters and a bundled checkpoint path."""
    #Parameters
    self.canvas_size = 1280
    self.mag_ratio = 1.5
    self.text_threshold = 0.7
    self.low_text = 0.4
    self.link_threshold = 0.4
    self.refine = False
    # NOTE(review): refiner_model is empty — enabling refine without setting
    # it would make torch.load fail below.
    self.refiner_model = ''
    self.poly = False
    self.cuda = True
    self.net = CRAFT()
    if self.cuda:
        self.net.load_state_dict(copyStateDict(torch.load('CRAFT/weights/craft_mlt_25k.pth')))
    else:
        self.net.load_state_dict(copyStateDict(torch.load('CRAFT/weights/craft_mlt_25k.pth', map_location='cpu')))
    if self.cuda:
        self.net = self.net.cuda()
        self.net = torch.nn.DataParallel(self.net)
    self.net.eval()
    # LinkRefiner
    self.refine_net = None
    if self.refine:
        from refinenet import RefineNet
        self.refine_net = RefineNet()
        if self.cuda:
            self.refine_net.load_state_dict(copyStateDict(torch.load(self.refiner_model)))
            self.refine_net = self.refine_net.cuda()
            self.refine_net = torch.nn.DataParallel(self.refine_net)
        else:
            self.refine_net.load_state_dict(copyStateDict(torch.load(self.refiner_model, map_location='cpu')))
        self.refine_net.eval()
        # Refiner output is polygonal.
        self.poly = True
def createModel():
    """Build the CRAFT detector from the Django-bundled weights, move it to
    the GPU and return it in eval mode."""
    weightPath = os.path.join(settings.BASE_DIR, 'CRAFT/weights/craft_mlt_25k.pth')
    print('Loading weights from checkpoint (' + weightPath + ')')
    model = CRAFT()
    model.load_state_dict(copyStateDict(torch.load(weightPath)))
    # This deployment always runs on GPU with DataParallel.
    model = torch.nn.DataParallel(model.cuda())
    cudnn.benchmark = False
    model.eval()
    return model
def main(pth_file_path):
    """Convert a CRAFT .pth checkpoint into a TorchScript .pt module saved
    next to the input file."""
    cuda = True
    print('Loading weights from checkpoint (' + pth_file_path + ')')
    model = CRAFT()     # initialize
    # cuda is hard-coded True here, but the CPU fallback is kept intact.
    if cuda:
        state = torch.load(pth_file_path)
    else:
        state = torch.load(pth_file_path, map_location='cpu')
    model.load_state_dict(copyStateDict(state))
    if cuda:
        model = model.cuda()
    cudnn.benchmark = False
    model.eval()
    # Script the model and write it beside the checkpoint with .pt extension.
    scripted = torch.jit.script(model)
    output_file_path = os.path.splitext(pth_file_path)[0] + ".pt"
    scripted.save(output_file_path)
    print("TorchScript model created:", output_file_path)
def get_detector(trained_model, device='cpu', quantize=True):
    """Return an eval-mode CRAFT detector, optionally dynamically quantized
    on CPU.

    Args:
        trained_model: path to the CRAFT checkpoint.
        device: torch device string; non-CPU devices get DataParallel.
        quantize: on CPU, attempt best-effort dynamic int8 quantization.
    """
    net = CRAFT()
    # Both original branches performed the identical load with
    # map_location=device, so it is hoisted out of the if/else.
    net.load_state_dict(
        copyStateDict(torch.load(trained_model, map_location=device)))
    if device == 'cpu':
        if quantize:
            try:
                torch.quantization.quantize_dynamic(net, dtype=torch.qint8, inplace=True)
            except Exception:
                # Fix: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit. Quantization is best-effort:
                # builds without a quantized backend fall back to fp32.
                pass
    else:
        net = torch.nn.DataParallel(net).to(device)
        cudnn.benchmark = False
    net.eval()
    return net
def runCraftNet(image_list):
    """Run CPU CRAFT detection over a list of image paths.

    Returns [img, txt] where img is the (BGR-flipped) image array and txt is
    a list of comma-joined polygon coordinate strings.

    NOTE(review): img/txt are rebuilt on every loop iteration, so only the
    LAST image's results are returned — confirm whether per-image results
    were intended.
    """
    # image list is the folder containing the images
    args = argparse.Namespace(
        canvas_size=1280, cuda=False, link_threshold=0.4, low_text=0.4,
        mag_ratio=1.5, poly=False, refine=False,
        refiner_model='weights/craft_refiner_CTW1500.pth', show_time=False,
        test_folder='images', text_threshold=0.7,
        trained_model='craft_mlt_25k.pth')
    net = CRAFT()     # initialize
    net.load_state_dict(
        copyStateDict(torch.load(args.trained_model, map_location='cpu')))
    net.eval()
    # image_list, _, _ = file_utils.get_files(args.test_folder)
    t = time.time()
    # result_folder = './result/'
    # load data
    refine_net = None
    for k, image_path in enumerate(image_list):
        image = imgproc.loadImage(image_path)
        bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, refine_net)
        # print("elapsed time : {}s ".format(time.time() - t))
        # RGB -> BGR copy for downstream consumers.
        img = np.array(image[:, :, ::-1])
        txt = []
        for i, box in enumerate(polys):
            # Flatten each polygon to "x1,y1,x2,y2,..." integer string.
            poly = np.array(box).astype(np.int32).reshape((-1))
            strResult = ','.join([str(p) for p in poly])
            txt.append(strResult)
    return [img, txt]
# imgtxt = box['txt'][0] #dataloader = syndata(imgname, charbox, imgtxt) dataloader = Synth80k('/data/CRAFT-pytorch/syntext/SynthText/SynthText', target_size=768) train_loader = torch.utils.data.DataLoader(dataloader, batch_size=16, shuffle=True, num_workers=0, drop_last=True, pin_memory=True) #batch_syn = iter(train_loader) # prefetcher = data_prefetcher(dataloader) # input, target1, target2 = prefetcher.next() #print(input.size()) net = CRAFT() #net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/CRAFT_net_050000.pth'))) #net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/1-7.pth'))) #net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/craft_mlt_25k.pth'))) #net.load_state_dict(copyStateDict(torch.load('vgg16_bn-6c64b313.pth'))) #realdata = realdata(net) # realdata = ICDAR2015(net, '/data/CRAFT-pytorch/icdar2015', target_size = 768) # real_data_loader = torch.utils.data.DataLoader( # realdata, # batch_size=10, # shuffle=True, # num_workers=0, # drop_last=True, # pin_memory=True) net = net.cuda() #net = CRAFT_net
t1 = time.time() - t1 # render results (optional) render_img = score_text.copy() render_img = np.hstack((render_img, score_link)) ret_score_text = imgproc.cvt2HeatmapImg(render_img) if args.show_time: print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1)) return boxes, polys, ret_score_text if __name__ == '__main__': # load net net = CRAFT() # initialize print('Loading weights from checkpoint (' + args.trained_model + ')') if args.cuda: net.load_state_dict(copyStateDict(torch.load(args.trained_model))) else: net.load_state_dict( copyStateDict(torch.load(args.trained_model, map_location='cpu'))) if args.cuda: net = net.cuda() net = torch.nn.DataParallel(net) cudnn.benchmark = False net.eval()
# NOTE(review): the first lines are the tail of a DataLoader(...) call whose
# opening lies outside this chunk; kept byte-identical.
    shuffle=False,
    num_workers=0,
    drop_last=True,
    pin_memory=True,
)
# sample_train_loader = torch.utils.data.DataLoader(
#     sample_dataset,
#     batch_size = 32,
#     shuffle = True,
#     num_workers = 0,
#     drop_last = True,
#     pin_memory = True)
# Mixed-precision gradient scaler for AMP training.
scaler = torch.cuda.amp.GradScaler()
net = CRAFT()
# NOTE(review): the optimizer is built before .cuda()/DataParallel; this
# works because Module.cuda() moves parameters in place, but confirm.
optimizer = optim.Adam(net.parameters(), lr=1e-4)
net = net.cuda()
#DataParallel
net = torch.nn.DataParallel(net, device_ids=[0, 1]).cuda()
#Distributed Parallel
#net = DistributedDataParallel(net, device_ids= [0,1]).cuda()
cudnn.benchmark = True
# Decay LR by 10% every 10k steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.9)
if __name__ == '__main__':
    # Visual sanity check: load a frozen CRAFT checkpoint and feed it to the
    # ICDAR2015 dataset builder with visualization enabled.
    # synthtextloader = Synth80k('/home/jiachx/publicdatasets/SynthText/SynthText', target_size=768, viz=True, debug=True)
    # train_loader = torch.utils.data.DataLoader(
    #     synthtextloader,
    #     batch_size=1,
    #     shuffle=False,
    #     num_workers=0,
    #     drop_last=True,
    #     pin_memory=True)
    # train_batch = iter(train_loader)
    # image_origin, target_gaussian_heatmap, target_gaussian_affinity_heatmap, mask = next(train_batch)
    from craft import CRAFT
    from torchutil import copyStateDict
    net = CRAFT(freeze=True)
    net.load_state_dict(
        copyStateDict(torch.load('/data/CRAFT-pytorch/1-7.pth')))
    net = net.cuda()
    net = torch.nn.DataParallel(net)
    net.eval()
    # The dataset uses the network to pseudo-label word boxes (weak supervision).
    dataloader = ICDAR2015(net, '/data/CRAFT-pytorch/icdar2015', target_size=768, viz=True)
    train_loader = torch.utils.data.DataLoader(dataloader, batch_size=1, shuffle=False, num_workers=0, drop_last=True, pin_memory=True)
def test(modelpara):
    """Run a multi-class CRAFT variant over image_list, writing per-field
    heatmaps and an annotated bbox image per input into result_folder."""
    # load net
    net = CRAFT()     # initialize
    print('Loading weights from checkpoint {}'.format(modelpara))
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(modelpara)))
    else:
        net.load_state_dict(
            copyStateDict(torch.load(modelpara, map_location='cpu')))
    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False
    net.eval()
    t = time.time()
    # load data
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k + 1, len(image_list),
                                                  image_path), end='\n')
        image = imgproc.loadImage(image_path)
        res = image.copy()
        # bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly)
        # Per-class predictions: heatmaps, boxes, polygons + heatmap size.
        gh_pred, bboxes_pred, polys_pred, size_heatmap = test_net(
            net, image, args.text_threshold, args.link_threshold,
            args.low_text, args.cuda, args.poly)
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        result_dir = os.path.join(result_folder, filename)
        os.makedirs(result_dir, exist_ok=True)
        # One heatmap image per predicted field class.
        for gh_img, field in zip(gh_pred, CLASSES):
            img = imgproc.cvt2HeatmapImg(gh_img)
            img_path = os.path.join(result_dir,
                                    'res_{}_{}.jpg'.format(filename, field))
            cv2.imwrite(img_path, img)
        h, w = image.shape[:2]
        # Resized original (RGB->BGR) alongside the heatmaps.
        img = cv2.resize(image, size_heatmap)[::, ::, ::-1]
        img_path = os.path.join(result_dir, 'res_{}.jpg'.format(filename, field))
        cv2.imwrite(img_path, img)
        # # save score text
        # filename, file_ext = os.path.splitext(os.path.basename(image_path))
        # mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        # cv2.imwrite(mask_file, score_text)
        res = cv2.resize(res, size_heatmap)
        # Draw each class's polygons, then a label tab to the left of each box.
        for polys, field in zip(polys_pred, CLASSES):
            TEXT_WIDTH = 10 * len(field) + 10
            TEXT_HEIGHT = 15
            polys = np.int32([poly.reshape((-1, 1, 2)) for poly in polys])
            res = cv2.polylines(res, polys, True, (0, 0, 255), 2)
            # Reshape each polygon into the label-tab rectangle.
            for poly in polys:
                poly[1, 0] = [poly[0, 0, 0] - 10, poly[0, 0, 1]]
                poly[2, 0] = [poly[0, 0, 0] - 10, poly[0, 0, 1] + TEXT_HEIGHT]
                poly[3, 0] = [
                    poly[0, 0, 0] - TEXT_WIDTH, poly[0, 0, 1] + TEXT_HEIGHT
                ]
                poly[0, 0] = [poly[0, 0, 0] - TEXT_WIDTH, poly[0, 0, 1]]
            res = cv2.fillPoly(res, polys, (224, 224, 224))
            # print(poly)
            for poly in polys:
                res = cv2.putText(res, field, tuple(poly[3, 0] + [+5, -5]),
                                  cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0),
                                  thickness=1)
        res_file = os.path.join(result_dir,
                                'res_{}_bbox.jpg'.format(filename, field))
        cv2.imwrite(res_file, res[::, ::, ::-1])
        # break
    # file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)
    print("elapsed time : {}s".format(time.time() - t))
# NOTE(review): the lines below are the tail of a test_net-style function
# whose header lies outside this chunk; kept byte-identical.
    # Fall back to the axis-aligned box wherever no polygon was produced.
    for k in range(len(polys)):
        if polys[k] is None:
            polys[k] = boxes[k]
    t1 = time.time() - t1

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = cvt2HeatmapImg(render_img)

    if show_time:
        print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, polys, ret_score_text


# Module-level CPU load of the CRAFT detector.
net = CRAFT()
net.load_state_dict(copyStateDict(torch.load(trained_model_path, map_location='cpu')))
net.eval()
# image_path = './doc/2.jpg'
# image = loadImage(image_path)
# bboxes, polys, score_text = test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net)
#
# poly_indexes = {}
# central_poly_indexes = []
# for i in range(len(polys)):
#     poly_indexes[i] = polys[i]
#     x_central = (polys[i][0][0] + polys[i][1][0] + polys[i][2][0] + polys[i][3][0]) / 4
#     y_central = (polys[i][0][1] + polys[i][1][1] + polys[i][2][1] + polys[i][3][1]) / 4
#     central_poly_indexes.append({i: [int(x_central), int(y_central)]})
def train():
    """TensorFlow CRAFT training loop: builds model/optimizer/checkpointing,
    loads the CTW dataset, and runs args.iterations gradient steps."""
    # declare model
    net = CRAFT(input_shape=(args.canvas_size, args.canvas_size, 3))
    loss_function = craft_loss()
    # lr decay depend on https://github.com/clovaai/CRAFT-pytorch/issues/18
    lr_fn = tf.optimizers.schedules.ExponentialDecay(args.learning_rate, decay_steps=10000, decay_rate=0.8)
    optimizer = tf.keras.optimizers.Adam(lr_fn)
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=net)
    manager = tf.train.CheckpointManager(checkpoint, directory=args.weight_dir, max_to_keep=10)
    # Create a checkpoint directory to store the checkpoints.
    if not os.path.exists(args.weight_dir):
        os.makedirs(args.weight_dir)
    checkpoint_dir = os.path.join(args.weight_dir, "ckpt")
    checkpoint_prefix = os.path.abspath(checkpoint_dir)
    # load dataset
    print("Data Set Loading ..")
    # train_real_data_list, test_data_list = TTLoader(args.real_data_path).get_dataset()
    train_real_data_list, test_data_list = CTWLoader(
        args.real_data_path).get_dataset()
    np.random.shuffle(train_real_data_list)
    np.random.shuffle(test_data_list)
    if args.use_fake:
        # Mix real and synthetic data at a 5:1 sampling ratio.
        train_fake_data_list = []  # TODO
        np.random.shuffle(train_fake_data_list)
        train_generator = DataGenerator(net, {
            "real": train_real_data_list,
            "fake": train_fake_data_list
        }, [5, 1], args.canvas_size, args.batch_size)
    else:
        train_generator = DataGenerator(net, {"real": train_real_data_list},
                                        [1], args.canvas_size, args.batch_size)
    print("Training Start ..")
    for idx in range(args.iterations):
        batch = train_generator.get_batch(args.batch_size)
        with tf.GradientTape() as tape:
            # Channel 0: region score, channel 1: affinity score.
            y, feature = net(batch["image"])
            region = y[:, :, :, 0]
            affinity = y[:, :, :, 1]
            """
            kind = "region"
            temp = batch[kind][0]
            img_temp = np.transpose([temp, temp, temp], (1, 2, 0)) * 255
            cv2.imwrite("./logs/temp_%s.jpg" % kind, img_temp)
            """
            try:
                loss, l_region, l_affinity, hard_bg_mask = loss_function([
                    batch["region"], batch["affinity"], region, affinity,
                    batch["confidence"], batch["fg_mask"], batch["bg_mask"],
                    args.alpha
                ])
            except Exception as e:
                # On a loss failure, dump the offending batch for debugging,
                # re-run the loss (to surface the error) and abort.
                print(e)
                save_batch_images(idx, batch["image"], batch["word_box"], prefix="error_")
                loss, l_region, l_affinity, hard_bg_mask = loss_function([
                    batch["region"], batch["affinity"], region, affinity,
                    batch["confidence"], batch["fg_mask"], batch["bg_mask"],
                    args.alpha
                ])
                exit()
        # Periodic debug dumps of inputs and score maps.
        if idx % 50 == 0:
            save_batch_images(idx, batch["image"], batch["word_box"])
            save_log(region, l_region, batch["region"], batch["fg_mask"], hard_bg_mask, "region", prefix="iter%d" % (idx + 1))
            save_log(affinity, l_affinity, batch["affinity"], batch["fg_mask"], hard_bg_mask, "affinity", prefix="iter%d" % (idx + 1))
        gradients = tape.gradient(loss, net.trainable_variables)
        optimizer.apply_gradients(zip(gradients, net.trainable_variables))
        print("iteration %d, batch loss: " % (idx + 1), loss)
        # if (idx+1) % 100 == 0:
        #     checkpoint.save(checkpoint_prefix)
        manager.save()
# NOTE(review): this assignment is the tail of an adjust-learning-rate style
# function whose header lies outside this chunk; kept byte-identical.
    param_group['lr'] = lr


if __name__ == '__main__':
    # Synthetic pretraining data.
    dataloader = Synth80k('/data/CRAFT-pytorch/syntext/SynthText/SynthText', target_size=768)
    train_loader = torch.utils.data.DataLoader(dataloader, batch_size=2, shuffle=True, num_workers=0, drop_last=True, pin_memory=True)
    batch_syn = iter(train_loader)
    # Frozen-backbone CRAFT initialised from an intermediate checkpoint.
    net = CRAFT(freeze=True)
    net.load_state_dict(
        copyStateDict(torch.load('/data/CRAFT-pytorch/1-7.pth')))
    net = net.cuda()
    net = torch.nn.DataParallel(net, device_ids=[0, 1, 2, 3]).cuda()
    cudnn.benchmark = True
    # Real data is pseudo-labelled by the network itself (weak supervision).
    realdata = ICDAR2015(net, '/data/CRAFT-pytorch/icdar2015', target_size=768)
    real_data_loader = torch.utils.data.DataLoader(realdata, batch_size=10, shuffle=True, num_workers=0, drop_last=True, pin_memory=True)
def applyCraft(image_file):
    """Detect text with CRAFT, cluster boxes into lines with G-DBSCAN, then
    OCR each clustered region with Tesseract.

    Returns a list of recognised strings, one per cluster.
    """
    # Initialize CRAFT parameters
    text_threshold = 0.7
    low_text = 0.4
    link_threshold = 0.4
    cuda = False
    canvas_size = 1280
    mag_ratio = 1.5
    # if text image present curve --> poly=true
    poly = False
    refine = False
    show_time = False
    refine_net = None
    trained_model_path = './app/CRAFT/craft_mlt_25k.pth'
    net = CRAFT()
    net.load_state_dict(
        copyStateDict(torch.load(trained_model_path, map_location='cpu')))
    net.eval()
    image = imgproc.loadImage(image_file)
    poly = False
    refine = False
    show_time = False
    refine_net = None
    bboxes, polys, score_text = test_net(net, canvas_size, mag_ratio, image,
                                         text_threshold, link_threshold,
                                         low_text, cuda, poly, refine_net)
    # Compute coordinate of central point in each bounding box returned by CRAFT
    # Purpose: easier for us to make cluster in G-DBScan step
    poly_indexes = {}
    central_poly_indexes = []
    for i in range(len(polys)):
        poly_indexes[i] = polys[i]
        x_central = (polys[i][0][0] + polys[i][1][0] + polys[i][2][0] +
                     polys[i][3][0]) / 4
        y_central = (polys[i][0][1] + polys[i][1][1] + polys[i][2][1] +
                     polys[i][3][1]) / 4
        central_poly_indexes.append({i: [int(x_central), int(y_central)]})
    # for i in central_poly_indexes:
    #     print(i)
    # For each of these cordinates convert them to new Point instances
    X = []
    for idx, x in enumerate(central_poly_indexes):
        point = Point(x[idx][0], x[idx][1], idx)
        X.append(point)
    # Cluster these central points
    clustered = GDBSCAN(Points(X), n_pred, 1, w_card)
    cluster_values = []
    for cluster in clustered:
        # Order the cluster left-to-right and take the extreme boxes to form
        # one merged rectangle spanning the whole cluster.
        sort_cluster = sorted(cluster, key=lambda elem: (elem.x, elem.y))
        max_point_id = sort_cluster[len(sort_cluster) - 1].id
        min_point_id = sort_cluster[0].id
        max_rectangle = sorted(poly_indexes[max_point_id],
                               key=lambda elem: (elem[0], elem[1]))
        min_rectangle = sorted(poly_indexes[min_point_id],
                               key=lambda elem: (elem[0], elem[1]))
        right_above_max_vertex = max_rectangle[len(max_rectangle) - 1]
        right_below_max_vertex = max_rectangle[len(max_rectangle) - 2]
        left_above_min_vertex = min_rectangle[0]
        left_below_min_vertex = min_rectangle[1]
        # Swap vertices when the sort put them in the wrong vertical order.
        if (int(min_rectangle[0][1]) > int(min_rectangle[1][1])):
            left_above_min_vertex = min_rectangle[1]
            left_below_min_vertex = min_rectangle[0]
        if (int(max_rectangle[len(max_rectangle) - 1][1]) < int(
                max_rectangle[len(max_rectangle) - 2][1])):
            right_above_max_vertex = max_rectangle[len(max_rectangle) - 2]
            right_below_max_vertex = max_rectangle[len(max_rectangle) - 1]
        cluster_values.append([
            left_above_min_vertex, left_below_min_vertex,
            right_above_max_vertex, right_below_max_vertex
        ])
    image = imgproc.loadImage(image_file)
    img = np.array(image[:, :, ::-1])
    img = img.astype('uint8')
    ocr_res = []
    for i, box in enumerate(cluster_values):
        # Crop the cluster's bounding rect, enhance it, then run Tesseract.
        poly = np.array(box).astype(np.int32).reshape((-1))
        poly = poly.reshape(-1, 2)
        rect = cv2.boundingRect(poly)
        x, y, w, h = rect
        cropped = img[y:y + h, x:x + w].copy()
        # Preprocess cropped segment: upscale, grayscale, denoise, binarise.
        cropped = cv2.resize(cropped, None, fx=5, fy=5,
                             interpolation=cv2.INTER_LINEAR)
        cropped = cv2.cvtColor(cropped, cv2.COLOR_BGR2GRAY)
        cropped = cv2.GaussianBlur(cropped, (3, 3), 0)
        cropped = cv2.bilateralFilter(cropped, 5, 25, 25)
        cropped = cv2.dilate(cropped, None, iterations=1)
        cropped = cv2.threshold(cropped, 0, 255,
                                cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
        #cropped = cv2.threshold(cropped, 90, 255, cv2.THRESH_BINARY)[1]
        #cropped = cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)
        ocr_res.append(pytesseract.image_to_string(cropped, lang='eng'))
    return ocr_res