def text_detector(image, count):
    net = CRAFT()  # initialize

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
    else:
        net.load_state_dict(
            copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    # LinkRefiner
    refine_net = None
    if args.refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
        if args.cuda:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(
                copyStateDict(torch.load(args.refiner_model, map_location='cpu')))

        refine_net.eval()
        args.poly = True

    t = time.time()

    # load data
    # for k, image_path in enumerate(image_list):
    #     print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
    #     image = imgproc.loadImage(image_path)

    bboxes, polys, score_text = test_net(net, image, args.text_threshold,
                                         args.link_threshold, args.low_text,
                                         args.cuda, args.poly, refine_net)

    # # save score text
    # filename, file_ext = os.path.splitext(os.path.basename(image_path))
    # mask_file = result_folder + "/res_" + filename + '_mask.jpg'
    # cv2.imwrite(mask_file, score_text)

    # cropping to be used for OCR
    crop_object(image[:, :, ::-1], polys, result_folder, count)
    # file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)

    print("elapsed time : {}s".format(time.time() - t))
def init_craft_net(args):
    print("-" * 50)
    print("init_craft_net")
    net = CRAFT()  # initialize

    print('CRAFT Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
    else:
        net.load_state_dict(
            copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()
    return net
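# NOTE: the loaders above and below all rely on a copyStateDict helper that is
# not shown in these snippets. A minimal sketch, mirroring the copy_state_dict
# static method of the Localization class further below: it strips the
# "module." prefix that torch.nn.DataParallel prepends to parameter names, so
# checkpoints saved from a DataParallel-wrapped model load into a bare one.
from collections import OrderedDict

def copyStateDict(state_dict):
    # keys look like "module.conv1.weight" when saved from DataParallel
    start_idx = 1 if list(state_dict.keys())[0].startswith("module") else 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict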
def init_craft_net(self):
    """Initialize the text region detection network."""
    net = CRAFT()  # initialize

    print('CRAFT Loading weights from checkpoint (' + self.args.trained_model + ')')
    if self.args.cuda:
        net.load_state_dict(copyStateDict(torch.load(self.args.trained_model)))
    else:
        net.load_state_dict(
            copyStateDict(torch.load(self.args.trained_model, map_location='cpu')))

    if self.args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()
    return net
def test_craft(objetos, dir_pth="craft_mlt_25k.pth"):
    net = CRAFT()  # initialize

    print('Loading weights from checkpoint (' + dir_pth + ')')
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(dir_pth)))
    else:
        net.load_state_dict(
            copyStateDict(torch.load(dir_pth, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    t = time.time()
    imgs_text_obj = {}

    # Analyze detected objects
    for i in range(len(objetos)):
        # Retrieve the object's id and images
        imagenes = objetos[i]["objetos"]
        id_obj = objetos[i]["id_obj"]
        imgs_text_obj[id_obj] = []

        # Extract the text from each of the images that make up the object
        crops_text = []
        for img in imagenes:
            img_path = img["img_path"]
            bb = img["bb"]
            img_crop = recuperar_crop_imagen_bb(img_path, bb)
            img_crop = np.array(img_crop)
            image = imgproc.loadImage(img_crop)

            # detection
            bboxes, polys, score_text = test_net(net, image, args.text_threshold,
                                                 args.link_threshold, args.low_text,
                                                 args.cuda, args.poly, None)

            for j, bbs in enumerate(bboxes):
                crop = bounding_box(bbs)
                cropped = image[crop[0][1]:crop[1][1], crop[0][0]:crop[1][0]]
                crops_text.append(cropped)

        imgs_text_obj[id_obj] = crops_text

    print("elapsed time : {}s".format(time.time() - t))
    return imgs_text_obj
def test_craft_prueba(ruta, DIR_PTH="craft_mlt_25k.pth"):
    net = CRAFT()  # initialize

    print('Loading weights from checkpoint (' + DIR_PTH + ')')
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(DIR_PTH)))
    else:
        net.load_state_dict(
            copyStateDict(torch.load(DIR_PTH, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    t = time.time()
    imgs_text_obj = {}

    arr = os.listdir(ruta)
    for i, name in enumerate(arr):
        arr[i] = ruta + "/" + name

    crops_text = []
    # Analyze detected objects
    for i in range(len(arr)):
        # Extract the text from each of the images that make up the object
        img_crop = recuperar_imagen_dir(arr[i])
        img_crop = np.array(img_crop)
        image = imgproc.loadImage(img_crop)

        # detection
        bboxes, polys, score_text = test_net(net, image, args.text_threshold,
                                             args.link_threshold, args.low_text,
                                             args.cuda, args.poly, None)

        for j, bbs in enumerate(bboxes):
            crop = bounding_box(bbs)
            cropped = image[crop[0][1]:crop[1][1], crop[0][0]:crop[1][0]]
            crops_text.append(cropped)

    print("elapsed time : {}s".format(time.time() - t))
    return crops_text
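# The bounding_box helper used by the two functions above is not shown in these
# snippets. A minimal sketch consistent with how its result is indexed
# (crop[0] = top-left point, crop[1] = bottom-right point): collapse a 4-point
# CRAFT polygon into an axis-aligned box with integer pixel coordinates.
import numpy as np

def bounding_box(points):
    pts = np.asarray(points)                       # expected shape (4, 2)
    x_min, y_min = pts.min(axis=0).astype(int)
    x_max, y_max = pts.max(axis=0).astype(int)
    return [(x_min, y_min), (x_max, y_max)]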
def init_model(weight_dir, weight_path=None, num_class=2, linear=True):
    # make weight_dir if it doesn't exist
    Path(weight_dir).mkdir(parents=True, exist_ok=True)

    # input: NCHW
    model = CRAFT(pretrained=True, num_class=num_class, linear=linear).cuda()
    # output: NHWC

    if weight_path:
        # pretrained_weight_path = os.path.join(weight_dir, weight_fname)
        model.load_state_dict(torch.load(weight_path))
        model.eval()

    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)  # tweak parameters

    return model, criterion, optimizer
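# A minimal training-step sketch for init_model. The shapes are assumptions:
# this CRAFT variant returns an NHWC map plus a feature tensor (see the
# "output: NHWC" comment above and the `output, _ = detector(...)` calls
# below), and the official CRAFT emits heatmaps at half the input resolution;
# adjust the dummy target if your variant differs.
import torch

model, criterion, optimizer = init_model('weights', num_class=2, linear=True)

images = torch.randn(2, 3, 256, 256).cuda()   # NCHW input batch
targets = torch.randn(2, 128, 128, 2).cuda()  # NHWC targets (assumed half-res)

model.train()
output, _ = model(images)                     # forward returns (heatmaps, features)
loss = criterion(output, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()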
def __init__(self,
             img_dir,
             gt_dir,
             weights_path=None,
             color_flag=1,
             character_map=True,
             affinity_map=False,
             word_map=False,
             direction_map=True):
    # super(ICDAR2015Dataset).__init__()
    self.raw_dataset = ICDAR2015Dataset(img_dir, gt_dir, color_flag=color_flag)
    # booleans coerce to 0/1, so this counts the enabled output maps
    self.num_class = character_map + affinity_map + word_map + 2 * direction_map

    model = CRAFT(pretrained=False, num_class=self.num_class).cuda()
    if weights_path:
        model.load_state_dict(torch.load(weights_path))
    model.eval()
    self.model = model
class OCRRecognizer:
    def __init__(self):
        self.net = None    # detect
        self.model = None  # recog
        self.converter = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.res_imagefileName = None

        self.opt_craft, self.opt_recog = self.setup_parser()

        self.args_craft = vars(self.opt_craft)
        self.args = vars(self.opt_recog)

        self.detect_time = 0.0
        self.recog_time = 0.0
        self.total_time = 0.0
        # print("~~~~~~~~ Hyperparameters used: ~~~~~~~")
        # for x, y in self.args.items():
        #     print("{} : {}".format(x, y))
        self.__dict__.update(self.args_craft)
        self.__dict__.update(self.args)

    def initialize(self):
        start = time.time()

        # self.saved_model = '/home_hongdo/sungeun.kim/checkpoints/ocr/ocr_train_addKorean_synth/best_accuracy.pth'
        # self.craft_trained_model = '/home_hongdo/sungeun.kim/checkpoints/ocr/ocr_train/craft_mlt_25k.pth'
        # self.saved_model = '/home_hongdo/sungeun.kim/checkpoints/ocr/ocr_train_v2/best_accuracy.pth'
        # self.craft_trained_model = '/home_hongdo/sungeun.kim/checkpoints/ocr/ocr_train_v2/best_accuracy_craft.pth'

        # official
        self.saved_model = './data_ocr/best_accuracy.pth'
        self.craft_trained_model = './data_ocr/best_accuracy_craft.pth'
        self.logfilepath = './data_ocr/log_ocr_result.txt'

        """ vocab / character number configuration """
        # if self.sensitive:
        #     self.character = string.printable[:-6]  # same with ASTER setting (use 94 char).

        cudnn.benchmark = True
        cudnn.deterministic = True
        self.num_gpu = torch.cuda.device_count()

        """ model configuration """
        # detection
        self.net = CRAFT(self)  # initialize
        print('Loading detection weights from checkpoint ' + self.craft_trained_model)
        self.net.load_state_dict(
            copyStateDict(torch.load(self.craft_trained_model, map_location=self.device)))
        self.net = torch.nn.DataParallel(self.net).to(self.device)

        self.converter = AttnLabelConverter(self.character)
        self.num_class = len(self.converter.character)

        if self.rgb:
            self.input_channel = 3
        self.model = Model(self, self.num_class)
        # print('model input parameters', self.imgH, self.imgW, self.num_fiducial,
        #       self.input_channel, self.output_channel, self.hidden_size,
        #       self.num_class, self.batch_max_length)

        # load model
        self.model = torch.nn.DataParallel(self.model).to(self.device)
        print('Loading recognition weights from checkpoint %s' % self.saved_model)
        self.model.load_state_dict(
            torch.load(self.saved_model, map_location=self.device))

        if torch.cuda.is_available():
            self.model = self.model.cuda()
            self.net = self.net.cuda()
            cudnn.benchmark = False

        # print('Initialization Done! It took {:.2f} mins.\n'.format((time.time() - start) / 60))
        print('Initialization Done! It took {:.2f} sec.\n'.format(time.time() - start))
        return True

    def setup_parser(self):
        """ Sets up an argument parser """
        parser_craft = argparse.ArgumentParser(description='CRAFT Text Detection')
        parser_craft.add_argument('--craft_trained_model', default='weights/craft_mlt_25k.pth',
                                  type=str, help='pretrained model')
        parser_craft.add_argument('--text_threshold', default=0.7, type=float,
                                  help='text confidence threshold')
        parser_craft.add_argument('--low_text', default=0.4, type=float,
                                  help='text low-bound score')
        parser_craft.add_argument('--link_threshold', default=0.4, type=float,
                                  help='link confidence threshold')
        parser_craft.add_argument('--cuda', default=False, type=str2bool,
                                  help='Use cuda for inference')
        parser_craft.add_argument('--canvas_size', default=1280, type=int,
                                  help='image size for inference')
        parser_craft.add_argument('--mag_ratio', default=1.5, type=float,
                                  help='image magnification ratio')
        parser_craft.add_argument('--poly', default=False, action='store_true',
                                  help='enable polygon type')
        parser_craft.add_argument('--show_time', default=False, action='store_true',
                                  help='show processing time')
        parser_craft.add_argument('--test_folder', default='/data/', type=str,
                                  help='folder path to input images')
        parser_craft.add_argument('--result_folder', default='./data_ocr/', type=str,
                                  help='result folder path')
        parser_craft.add_argument('--refine', default=False, action='store_true',
                                  help='enable link refiner')
        parser_craft.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth',
                                  type=str, help='pretrained refiner model')
        args_craft = parser_craft.parse_args()

        parser_recog = argparse.ArgumentParser(description='ocr recognition')
        parser_recog.add_argument('--image_path',
                                  help='path to image_folder or image_file which contains text images')
        parser_recog.add_argument('--workers', type=int, default=4,
                                  help='number of data loading workers')
        parser_recog.add_argument('--batch_size', type=int, default=1,
                                  help='input batch size')
        parser_recog.add_argument('--saved_model',
                                  help='path to saved_model to evaluation')
        parser_recog.add_argument('--logfilepath', help='path to log to demo')

        """ Data processing """
        parser_recog.add_argument('--batch_max_length', type=int, default=25,
                                  help='maximum-label-length')
        parser_recog.add_argument('--imgH', type=int, default=32,
                                  help='the height of the input image')
        parser_recog.add_argument('--imgW', type=int, default=100,
                                  help='the width of the input image')
        parser_recog.add_argument('--rgb', action='store_true', help='use rgb input')
        # parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz',
        #                     help='character label')
        parser_recog.add_argument('--character', type=str,
                                  default='0123456789abcdefghijklmnopqrstuvwxyz가각간갇갈감갑값갓강갖같갚갛개객걀걔거걱건걷걸검겁것겉게겨격겪견결겹경곁계고곡곤곧골곰곱곳공과관광괜괴굉교구국군굳굴굵굶굽궁권귀귓규균귤그극근글긁금급긋긍기긴길김깅깊까깍깎깐깔깜깝깡깥깨꺼꺾껌껍껏껑께껴꼬꼭꼴꼼꼽꽂꽃꽉꽤꾸꾼꿀꿈뀌끄끈끊끌끓끔끗끝끼낌나낙낚난날낡남납낫낭낮낯낱낳내냄냇냉냐냥너넉넌널넓넘넣네넥넷녀녁년념녕노녹논놀놈농높놓놔뇌뇨누눈눕뉘뉴늄느늑는늘늙능늦늬니닐님다닥닦단닫달닭닮담답닷당닿대댁댐댓더덕던덜덟덤덥덧덩덮데델도독돈돌돕돗동돼되된두둑둘둠둡둥뒤뒷드득든듣들듬듭듯등디딩딪따딱딴딸땀땅때땜떠떡떤떨떻떼또똑뚜뚫뚱뛰뜨뜩뜯뜰뜻띄라락란람랍랑랗래랜램랫략량러럭런럴럼럽럿렁렇레렉렌려력련렬렵령례로록론롬롭롯료루룩룹룻뤄류륙률륭르른름릇릎리릭린림립릿링마막만많말맑맘맙맛망맞맡맣매맥맨맵맺머먹먼멀멈멋멍멎메멘멩며면멸명몇모목몬몰몸몹못몽묘무묵묶문묻물뭄뭇뭐뭘뭣므미민믿밀밉밌및밑바박밖반받발밝밟밤밥방밭배백뱀뱃뱉버번벌범법벗베벤벨벼벽변별볍병볕보복볶본볼봄봇봉뵈뵙부북분불붉붐붓붕붙뷰브븐블비빌빔빗빚빛빠빡빨빵빼뺏뺨뻐뻔뻗뼈뼉뽑뿌뿐쁘쁨사삭산살삶삼삿상새색샌생샤서석섞선설섬섭섯성세섹센셈셋셔션소속손솔솜솟송솥쇄쇠쇼수숙순숟술숨숫숭숲쉬쉰쉽슈스슨슬슴습슷승시식신싣실싫심십싯싱싶싸싹싼쌀쌍쌓써썩썰썹쎄쏘쏟쑤쓰쓴쓸씀씌씨씩씬씹씻아악안앉않알앓암압앗앙앞애액앨야약얀얄얇양얕얗얘어억언얹얻얼엄업없엇엉엊엌엎에엔엘여역연열엷염엽엿영옆예옛오옥온올옮옳옷옹와완왕왜왠외왼요욕용우욱운울움웃웅워원월웨웬위윗유육율으윽은을음응의이익인일읽잃임입잇있잊잎자작잔잖잘잠잡잣장잦재쟁쟤저적전절젊점접젓정젖제젠젯져조족존졸좀좁종좋좌죄주죽준줄줌줍중쥐즈즉즌즐즘증지직진질짐집짓징짙짚짜짝짧째쨌쩌쩍쩐쩔쩜쪽쫓쭈쭉찌찍찢차착찬찮찰참찻창찾채책챔챙처척천철첩첫청체쳐초촉촌촛총촬최추축춘출춤춥춧충취츠측츰층치칙친칠침칫칭카칸칼캄캐캠커컨컬컴컵컷케켓켜코콘콜콤콩쾌쿄쿠퀴크큰클큼키킬타탁탄탈탑탓탕태택탤터턱턴털텅테텍텔템토톤톨톱통퇴투툴툼퉁튀튜트특튼튿틀틈티틱팀팅파팎판팔팝패팩팬퍼퍽페펜펴편펼평폐포폭폰표푸푹풀품풍퓨프플픔피픽필핏핑하학한할함합항해핵핸햄햇행향허헌험헤헬혀현혈협형혜호혹혼홀홈홉홍화확환활황회획횟횡효후훈훌훔훨휘휴흉흐흑흔흘흙흡흥흩희흰히힘',
                                  help='character label')
        parser_recog.add_argument('--sensitive', action='store_true',
                                  help='for sensitive character mode')
        parser_recog.add_argument('--PAD', action='store_true',
                                  help='whether to keep ratio then pad for image resize')

        """ Model Architecture """
        parser_recog.add_argument('--num_fiducial', type=int, default=20,
                                  help='number of fiducial points of TPS-STN')
        parser_recog.add_argument('--input_channel', type=int, default=1,
                                  help='the number of input channel of Feature extractor')
        parser_recog.add_argument('--output_channel', type=int, default=512,
                                  help='the number of output channel of Feature extractor')
        parser_recog.add_argument('--hidden_size', type=int, default=256,
                                  help='the size of the LSTM hidden state')
        args_recog = parser_recog.parse_args()

        return args_craft, args_recog

    def apply(self, image, timestamp, save_img=False):
        # coordinate : list
        pred, timestamp = detect_ocr(self, image, timestamp, save_img)
        return pred, timestamp
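# Usage sketch for OCRRecognizer, assuming the ./data_ocr checkpoints that
# initialize() hard-codes exist and that cv2 is imported at module level;
# 'sample.jpg' is a hypothetical test image. apply() delegates to detect_ocr
# for the detection + recognition pipeline.
if __name__ == '__main__':
    engine = OCRRecognizer()
    engine.initialize()
    frame = cv2.imread('sample.jpg')
    pred, ts = engine.apply(frame, timestamp=0.0, save_img=False)
    print(pred, ts)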
class Localization:
    def __init__(self,
                 weights_path='./craft/weights/craft_mlt_25k.pth',
                 canvas_size=1280,
                 text_threshold=0.9,
                 link_threshold=0.1,
                 mag_ratio=1.5,
                 low_text=0.5,
                 gpu=True):
        self.canvas_size = canvas_size
        self.text_threshold = text_threshold
        self.link_threshold = link_threshold
        self.low_text = low_text
        self.mag_ratio = mag_ratio
        self.gpu = gpu

        self.model = CRAFT()
        self.model.load_state_dict(
            self.copy_state_dict(torch.load(weights_path, map_location='cpu')))
        if self.gpu:
            self.model.cuda()
            self.model = torch.nn.DataParallel(self.model)
        self.model.eval()

    @staticmethod
    def copy_state_dict(state_dict):
        # strip the "module." prefix left by DataParallel checkpoints
        if list(state_dict.keys())[0].startswith("module"):
            start_idx = 1
        else:
            start_idx = 0
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = ".".join(k.split(".")[start_idx:])
            new_state_dict[name] = v
        return new_state_dict

    def predict(self, image):
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
            image, self.canvas_size, interpolation=cv2.INTER_AREA,
            mag_ratio=self.mag_ratio)
        ratio_h = ratio_w = 1 / target_ratio

        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))              # [c, h, w] to [b, c, h, w]
        if self.gpu:
            x = x.cuda()

        y, _ = self.model(x)
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        boxes = craft_utils.getDetBoxes(score_text, score_link,
                                        self.text_threshold,
                                        self.link_threshold, self.low_text)
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        boxes = np.reshape(boxes, newshape=(-1, 8)).astype(int)
        print(boxes)
        return boxes
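# Usage sketch for Localization, assuming an RGB numpy image (the CRAFT repo's
# imgproc loads images as RGB); 'sample.jpg' is a hypothetical path. predict()
# returns an (N, 8) integer array: x1,y1,x2,y2,x3,y3,x4,y4 per detected box.
if __name__ == '__main__':
    localizer = Localization(weights_path='./craft/weights/craft_mlt_25k.pth', gpu=True)
    image = cv2.imread('sample.jpg')[:, :, ::-1].copy()  # BGR -> RGB, copy for contiguity
    boxes = localizer.predict(image)
    print(boxes.shape)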
# print("\nrun time (detection) : {:.2f} {:.2f} s".format(detect_time, args.detect_time )) return detection_list, image[:, :, ::-1], polys if __name__ == '__main__': args = Config() if not os.path.isdir(args.result_folder): os.mkdir(args.result_folder) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # load net net = CRAFT() # initialize print('Loading detection weights from checkpoint ' + args.craft_trained_model) # if args.cuda: # net.load_state_dict(copyStateDict(torch.load(args.craft_trained_model))) # else: # # net.load_state_dict(copyStateDict(torch.load(args.craft_trained_model, map_location='cpu'))) net.load_state_dict( copyStateDict(torch.load(args.craft_trained_model, map_location=device))) if args.cuda: net = net.cuda()
def export_task3(img_dir,
                 classifier_weight_path,
                 detector_weight_path,
                 alphabet=datasets.ALPHANUMERIC,
                 submit_dir='submit',
                 expand_factor=3,
                 thresh=0.3,
                 verbose=True,
                 gt_path=None,
                 min_side_length=16):
    # make submit dirs
    Path(submit_dir).mkdir(parents=True, exist_ok=True)

    # instantiate models
    detector = CRAFT(pretrained=True, num_class=2, linear=True).cuda()
    detector.load_state_dict(torch.load(detector_weight_path))
    detector.eval()

    classifier = CharClassifier(num_classes=len(alphabet)).double().cuda()
    classifier.load_state_dict(torch.load(classifier_weight_path))
    classifier.eval()

    # prepare results filename
    submit_name = f"task3.txt"
    submit_path = os.path.join(submit_dir, submit_name)
    with open(submit_path, 'w') as f:  # clear contents first
        f.write("")

    # read gt file
    if gt_path:
        with open(gt_path, 'r') as f:
            gt = f.read()
    else:
        gt = ""

    # get predictions for all images
    img_names = get_filenames(img_dir, ['.png'], recursive=False)
    for img_name in img_names:
        res_name = img_name[5:]
        img_path = os.path.join(img_dir, img_name)

        # get image
        img = PIL.Image.open(img_path).convert('RGB')

        # resize if too small (i.e. when a side < min_side_length)
        w, h = img.width, img.height
        if min(h, w) < min_side_length:
            if h <= w:
                ratio = min_side_length / h
            elif w < h:
                ratio = min_side_length / w
            new_size = (int(ratio * w), int(ratio * h))
            img = img.resize(new_size)

        # convert PIL.Image to torch.Tensor (1, C, H, W), scaled to [-1, 1]
        img = ((np.array(img) / 255.0) - 0.5) * 2
        img = torch.from_numpy(img).permute(2, 0, 1)[None, ...].cuda()

        # get heatmaps
        with torch.no_grad():
            output, _ = detector(img.float())
        char_heatmap = output[0, :, :, :1]  # (H, W, 1)

        cropped_chars, _, _, _ = map_to_crop(img[0], char_heatmap,
                                             thresh=thresh,
                                             expand_factor=expand_factor)

        if cropped_chars is not None:
            # get character predictions
            with torch.no_grad():
                onehots = classifier(cropped_chars.double())
            chars = onehot_to_chars(onehots, alphabet)

            # string has same order as charBBs
            string = ""
            for c in chars:
                if c is None:
                    string += '?'
                elif c == '"':
                    string += r'\"'
                elif c == '\\':
                    string += r'\\'
                else:
                    string += c
        else:
            string = ""

        # write to file
        contents = f'{res_name}, "{string}"\n'
        if verbose:
            print(contents[:-1], end='')
            if gt:
                pattern = re.compile(re.escape(img_name) + r',\ "(.*)"')
                matches = re.search(pattern, gt)
                if matches:
                    print(f'\tgt: "{matches.group(1)}"', end='')
            print('')
        with open(submit_path, 'a') as f:
            f.write(contents)

    print("Done!")
def export_task1_4(test_img_dir,
                   test_gt_dir,
                   classifier_weight_path,
                   detector_weight_path,
                   alphabet=datasets.ALPHANUMERIC,
                   submit_dir='submit',
                   thresh=0.3):
    testset = datasets.ICDAR2013TestDataset(test_gt_dir, test_img_dir)
    max_size = (700, 700)

    # make submit dirs
    submit_task1_dirpath = os.path.join(submit_dir, 'text_localization_output')
    submit_task4_dirpath = os.path.join(submit_dir, 'end_to_end_output')
    Path(submit_task1_dirpath).mkdir(parents=True, exist_ok=True)
    Path(submit_task4_dirpath).mkdir(parents=True, exist_ok=True)

    # instantiate models
    detector = CRAFT(pretrained=True, num_class=2, linear=True).cuda()
    detector.load_state_dict(torch.load(detector_weight_path))
    detector.eval()

    classifier = CharClassifier(num_classes=len(alphabet)).double().cuda()
    classifier.load_state_dict(torch.load(classifier_weight_path))
    classifier.eval()

    # get predictions for all test images
    for img, _, _, pil_img in testset:
        filename = os.path.basename(pil_img.filename)
        orig_w, orig_h = pil_img.width, pil_img.height

        # resize when # of pixels exceeds max_size
        if (max_size is not None) and ((orig_w * orig_h) > (max_size[0] * max_size[1])):
            """
            max_w, max_h = max_size
            if orig_w > orig_h:
                ratio = max_w / orig_w
            elif orig_h > orig_w:
                ratio = max_h / orig_h
            new_w, new_h = int(ratio*orig_w), int(ratio*orig_h)
            print(f"Resizing ({orig_w}, {orig_h}) to ({new_w}, {new_h})", end='')
            img = img.resize((new_w, new_h))
            """
            img = img.permute(1, 2, 0).cpu().numpy()
            img = cv2.resize(img, dsize=max_size)          # HWC
            img = torch.from_numpy(img).permute(2, 0, 1)   # CHW
            new_w, new_h = max_size
        else:
            new_w, new_h = orig_w, orig_h

        print(filename + '... ', end='')
        wordBBs, words = recognizer(img, detector, classifier, alphabet, thresh=thresh)
        if wordBBs is None:
            # write blank files
            print("No characters found. Making blank file... ", end='')
            submit_name = f"res_{filename[:-4]}.txt"
            submit_path = os.path.join(submit_task1_dirpath, submit_name)
            with open(submit_path, 'w') as f:
                f.write("")
            submit_path = os.path.join(submit_task4_dirpath, submit_name)
            with open(submit_path, 'w') as f:
                f.write("")
            continue

        # undo the squish introduced by the resize above
        unsquish_factor = np.array([orig_w / new_w, orig_h / new_h]).reshape(1, 1, 2)
        wordBBs = unsquish_factor * wordBBs

        print('Formatting submission... ', end='')
        # create submission contents for task 1 (text localization)
        xymin = np.min(wordBBs, axis=1).astype('int32')
        xymax = np.max(wordBBs, axis=1).astype('int32')
        xmin, ymin = xymin[:, 0], xymin[:, 1]
        xmax, ymax = xymax[:, 0], xymax[:, 1]
        contents = {'xmin': xmin, 'ymin': ymin, 'xmax': xmax, 'ymax': ymax}
        submission = pd.DataFrame(contents)

        # write to file
        submit_name = f"res_{filename[:-4]}.txt"
        submit_path = os.path.join(submit_task1_dirpath, submit_name)
        submission.to_csv(submit_path, header=False, index=False)

        # create submission contents for task 4 (end-to-end recognition)
        x1, y1 = xmin, ymin
        x2, y2 = xmax, ymin
        x3, y3 = xmax, ymax
        x4, y4 = xmin, ymax
        contents = {'x1': x1, 'y1': y1, 'x2': x2, 'y2': y2,
                    'x3': x3, 'y3': y3, 'x4': x4, 'y4': y4,
                    'transcription': words}
        submission = pd.DataFrame(contents)
        submit_path = os.path.join(submit_task4_dirpath, submit_name)
        submission.to_csv(submit_path, header=False, index=False)
        print(f'Done.')

    print("Done!")
    t1 = time.time() - t1

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    if args.show_time:
        print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, ret_score_text


if __name__ == '__main__':
    # load net
    net = CRAFT()  # initialize

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    net.load_state_dict(copyStateDict(torch.load(args.trained_model)))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    t = time.time()

    # load data
    for k, image_path in enumerate(image_list):