def recognition(image):
    converter = AttnLabelConverter(args.character)
    args.num_class = len(converter.character)
    if args.rgb:
        args.input_channel = 3

    model = Model(args)
    model = torch.nn.DataParallel(model).to(device)
    model.load_state_dict(torch.load(args.model_dir, map_location=device))

    transformer = resizeNormalize((100, 32))
    # Convert RGB images to grayscale, which is necessary for our convolution layers.
    if image.mode == 'RGB':
        image = image.convert('L')
    image = transformer(image)
    # note: image is now C x H x W, so size(0) is the channel count (1 for
    # grayscale), which doubles as the batch size of this single-image call
    batch_size = image.size(0)
    if torch.cuda.is_available():
        image = image.cuda()
    image = image.view(1, *image.size())
    image = Variable(image)

    model.eval()
    length_for_pred = torch.IntTensor([args.batch_max_length] * batch_size).to(device)
    text_for_pred = torch.LongTensor(batch_size, args.batch_max_length + 1).fill_(0).to(device)
    preds = model(image, text_for_pred, is_train=False)

    _, preds_index = preds.max(2)
    preds_str = converter.decode(preds_index, length_for_pred)
    preds_prob = F.softmax(preds, dim=2)
    text_prediction = preds_str[0].replace("[s]", "")
    return text_prediction
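# Usage sketch for recognition(): a minimal, hedged example. It assumes the
# surrounding script has already populated the module-level `args` and `device`
# used above; the demo image path is a hypothetical placeholder.
if __name__ == '__main__':
    from PIL import Image
    sample = Image.open('demo_image/sample.png')  # hypothetical input crop
    print(recognition(sample))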
class PytorchNet:
    def __init__(self, model_path, gpu_id=None):
        """
        Initialize the model.
        :param model_path: path to the model checkpoint
        :param gpu_id: which GPU to run on
        """
        checkpoint = torch.load(model_path)
        print(f"load {checkpoint['epoch']} epoch params")
        config = checkpoint['config']
        alphabet = config['dataset']['alphabet']
        if gpu_id is not None and isinstance(gpu_id, int) and torch.cuda.is_available():
            self.device = torch.device("cuda:%s" % gpu_id)
        else:
            self.device = torch.device("cpu")
        print('device:', self.device)

        self.transform = []
        for t in config['dataset']['train']['dataset']['args']['transforms']:
            if t['type'] in ['ToTensor', 'Normalize']:
                self.transform.append(t)
        self.transform = get_transforms(self.transform)
        self.gpu_id = gpu_id

        img_h, img_w = 32, 100
        for process in config['dataset']['train']['dataset']['args']['pre_processes']:
            if process['type'] == "Resize":
                img_h = process['args']['img_h']
                img_w = process['args']['img_w']
                break
        self.img_w = img_w
        self.img_h = img_h
        self.img_mode = config['dataset']['train']['dataset']['args']['img_mode']
        self.alphabet = alphabet
        img_channel = 3 if config['dataset']['train']['dataset']['args']['img_mode'] != 'GRAY' else 1

        if config['arch']['args']['prediction']['type'] == 'CTC':
            self.converter = CTCLabelConverter(config['dataset']['alphabet'])
        elif config['arch']['args']['prediction']['type'] == 'Attn':
            self.converter = AttnLabelConverter(config['dataset']['alphabet'])

        self.net = get_model(img_channel, len(self.converter.character), config['arch']['args'])
        self.net.load_state_dict(checkpoint['state_dict'])
        # self.net = torch.jit.load('crnn_lite_gpu.pt')
        self.net.to(self.device)
        self.net.eval()

        sample_input = torch.zeros((2, img_channel, img_h, img_w)).to(self.device)
        self.net.get_batch_max_length(sample_input)

    def predict(self, img_path, model_save_path=None):
        """
        Run prediction on the given image; accepts an image path or a numpy array.
        :param img_path: path to the image
        :return:
        """
        assert os.path.exists(img_path), 'file does not exist'
        img = self.pre_processing(img_path)
        tensor = self.transform(img)
        tensor = tensor.unsqueeze(dim=0)
        tensor = tensor.to(self.device)
        preds, tensor_img = self.net(tensor)
        preds = preds.softmax(dim=2).detach().cpu().numpy()
        # result = decode(preds, self.alphabet, raw=True)
        # print(result)
        result = self.converter.decode(preds)
        if model_save_path is not None:
            # export the model for deployment
            save(self.net, tensor, model_save_path)
        return result, tensor_img

    def pre_processing(self, img_path):
        """
        Preprocess the image: first resize by height; if the resulting width is
        below the target width, pad with black pixels, otherwise force-scale to
        the target width.
        :param img_path: path to the image
        :return:
        """
        img = cv2.imread(img_path, 1 if self.img_mode != 'GRAY' else 0)
        if self.img_mode == 'RGB':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]
        ratio_h = float(self.img_h) / h
        new_w = int(w * ratio_h)
        if new_w < self.img_w:
            img = cv2.resize(img, (new_w, self.img_h))
            step = np.zeros((self.img_h, self.img_w - new_w, img.shape[-1]), dtype=img.dtype)
            img = np.column_stack((img, step))
        else:
            img = cv2.resize(img, (self.img_w, self.img_h))
        return img
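# Usage sketch for PytorchNet (hedged): assumes a checkpoint at the hypothetical
# path 'output/crnn.pth' whose 'config' entry matches the keys read in __init__.
if __name__ == '__main__':
    recognizer = PytorchNet(model_path='output/crnn.pth', gpu_id=0)
    text, _ = recognizer.predict('test.jpg')  # hypothetical test image
    print(text)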
def demo(opt):
    """ model configuration """
    lists = []  # collects the text recognized from the images assumed to be the destination
    converter = AttnLabelConverter(opt.character)  # Attn
    opt.num_class = len(converter.character)
    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)  # Model imported from model.py
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class,
          opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)  # print the parameter values
    model = torch.nn.DataParallel(model).to(device)  # data-parallel execution on the GPU

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))  # load the model parameters

    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data1 = RawDataset(root=opt.image_folder1, opt=opt)  # use RawDataset: signboard-detection results
    demo_data2 = RawDataset(root=opt.image_folder2, opt=opt)  # use RawDataset: Google Maps text-detection results
    demo_loader1 = torch.utils.data.DataLoader(
        demo_data1, batch_size=opt.batch_size, shuffle=False,
        num_workers=int(opt.workers), collate_fn=AlignCollate_demo, pin_memory=True)
    demo_loader2 = torch.utils.data.DataLoader(
        demo_data2, batch_size=opt.batch_size, shuffle=False,
        num_workers=int(opt.workers), collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader1:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            # Attn
            preds = model(image, text_for_pred, is_train=False)
            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)

            log = open(f'./log_demo_result.txt', 'a')  # open in append mode
            dashed_line = '-' * 80
            head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'
            print(f'{dashed_line}\n{head}\n{dashed_line}')  # print the table header
            log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')  # write the table header to the txt file

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                pred_EOS = pred.find('[s]')
                pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                pred_max_prob = pred_max_prob[:pred_EOS]

                # calculate confidence score (= product of pred_max_prob)
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                lists.append(pred)
                print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')  # print the results
                log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')  # save the results to the txt file
            log.close()  # close the file

    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader2:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            # Attn
            preds = model(image, text_for_pred, is_train=False)
            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)

            log = open(f'./log_demo_result.txt', 'a')  # open in append mode
            dashed_line = '-' * 80
            head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'
            print(f'{dashed_line}\n{head}\n{dashed_line}')  # print the table header
            log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')  # write the table header to the txt file

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                pred_EOS = pred.find('[s]')
                pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                pred_max_prob = pred_max_prob[:pred_EOS]

                # calculate the confidence score
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')  # print the results
                log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')  # save the results to the txt file
                if pred in lists:
                    print(pred + " is a valid destination.")
                else:
                    print(pred + " is not a valid destination.")
            log.close()  # close the file
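# Usage sketch for demo() (hedged): the function expects an argparse-style
# namespace. The fields below mirror exactly the attributes demo() reads; the
# folders and checkpoint path are hypothetical placeholders, and the defaults
# follow the usual deep-text-recognition-benchmark TPS-ResNet-BiLSTM-Attn setup.
if __name__ == '__main__':
    from types import SimpleNamespace
    opt = SimpleNamespace(
        character='0123456789abcdefghijklmnopqrstuvwxyz', rgb=False,
        imgH=32, imgW=100, num_fiducial=20, input_channel=1, output_channel=512,
        hidden_size=256, batch_max_length=25, Transformation='TPS',
        FeatureExtraction='ResNet', SequenceModeling='BiLSTM', Prediction='Attn',
        saved_model='TPS-ResNet-BiLSTM-Attn.pth',   # hypothetical checkpoint
        PAD=False, image_folder1='signboards/',      # hypothetical folder
        image_folder2='maps/',                       # hypothetical folder
        batch_size=192, workers=4)
    demo(opt)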
def train(opt):
    lib.print_model_settings(locals().copy())

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
        text_len = opt.batch_max_length + 2
    else:
        converter = CTCLabelConverter(opt.character)
        text_len = opt.batch_max_length

    opt.classes = converter.character

    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    log = open(os.path.join(opt.exp_dir, opt.exp_name, 'log_dataset.txt'), 'a')
    AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)

    train_dataset = LmdbStyleDataset(root=opt.train_data, opt=opt)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size * 2,  # *2 to sample different images for the encoder and for the discriminator's real images
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
    print('-' * 80)

    valid_dataset = LmdbStyleDataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size * 2,  # *2 to sample different images for the encoder and for the discriminator's real images
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    text_dataset = text_gen(opt)
    text_loader = torch.utils.data.DataLoader(
        text_dataset, batch_size=opt.batch_size,
        shuffle=True, num_workers=int(opt.workers),
        pin_memory=True, drop_last=True)

    opt.num_class = len(converter.character)

    c_code_size = opt.latent
    cEncoder = GlobalContentEncoder(opt.num_class, text_len, opt.char_embed_size, c_code_size)
    ocrModel = ModelV1(opt)

    genModel = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier)
    g_ema = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier)
    disEncModel = styleGANDis(opt.size, channel_multiplier=opt.channel_multiplier, input_dim=opt.input_channel, code_s_dim=c_code_size)

    accumulate(g_ema, genModel, 0)

    # uCriterion = torch.nn.MSELoss()
    # sCriterion = torch.nn.MSELoss()
    # if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
    #     ocrCriterion = torch.nn.L1Loss()
    # else:
    if 'CTC' in opt.Prediction:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        print('Not implemented error')
        sys.exit()
        # ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    cEncoder = torch.nn.DataParallel(cEncoder).to(device)
    cEncoder.train()
    genModel = torch.nn.DataParallel(genModel).to(device)
    g_ema = torch.nn.DataParallel(g_ema).to(device)
    genModel.train()
    g_ema.eval()
    disEncModel = torch.nn.DataParallel(disEncModel).to(device)
    disEncModel.train()
    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    if opt.ocrFixed:
        if opt.Transformation == 'TPS':
            ocrModel.module.Transformation.eval()
        ocrModel.module.FeatureExtraction.eval()
        ocrModel.module.AdaptiveAvgPool.eval()
        # ocrModel.module.SequenceModeling.eval()
        ocrModel.module.Prediction.eval()
    else:
        ocrModel.train()

    g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1)
    d_reg_ratio = opt.d_reg_every / (opt.d_reg_every + 1)

    optimizer = optim.Adam(
        list(genModel.parameters()) + list(cEncoder.parameters()),
        lr=opt.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    dis_optimizer = optim.Adam(
        disEncModel.parameters(),
        lr=opt.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )
    ocr_optimizer = optim.Adam(
        ocrModel.parameters(),
        lr=opt.lr,
        betas=(0.9, 0.99),
    )

    ## Loading pre-trained files
    if opt.modelFolderFlag:
        if len(glob.glob(os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth"))) > 0:
            opt.saved_synth_model = glob.glob(os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth"))[-1]

    if opt.saved_ocr_model != '' and opt.saved_ocr_model != 'None':
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        checkpoint = torch.load(opt.saved_ocr_model)
        ocrModel.load_state_dict(checkpoint)

    # if opt.saved_gen_model != '' and opt.saved_gen_model != 'None':
    #     print(f'loading pretrained gen model from {opt.saved_gen_model}')
    #     checkpoint = torch.load(opt.saved_gen_model, map_location=lambda storage, loc: storage)
    #     genModel.module.load_state_dict(checkpoint['g'])
    #     g_ema.module.load_state_dict(checkpoint['g_ema'])

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        checkpoint = torch.load(opt.saved_synth_model)

        # styleModel.load_state_dict(checkpoint['styleModel'])
        # mixModel.load_state_dict(checkpoint['mixModel'])
        genModel.load_state_dict(checkpoint['genModel'])
        g_ema.load_state_dict(checkpoint['g_ema'])
        disEncModel.load_state_dict(checkpoint['disEncModel'])
        ocrModel.load_state_dict(checkpoint['ocrModel'])

        optimizer.load_state_dict(checkpoint["optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])
        ocr_optimizer.load_state_dict(checkpoint["ocr_optimizer"])

    # if opt.imgReconLoss == 'l1':
    #     recCriterion = torch.nn.L1Loss()
    # elif opt.imgReconLoss == 'ssim':
    #     recCriterion = ssim
    # elif opt.imgReconLoss == 'ms-ssim':
    #     recCriterion = msssim

    # loss averagers
    loss_avg_dis = Averager()
    loss_avg_gen = Averager()
    loss_avg_unsup = Averager()
    loss_avg_sup = Averager()
    log_r1_val = Averager()
    log_avg_path_loss_val = Averager()
    log_avg_mean_path_length_avg = Averager()
    log_ada_aug_p = Averager()
    loss_avg_ocr_sup = Averager()
    loss_avg_ocr_unsup = Averager()

    """ final options """
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        try:
            start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    # get schedulers
    scheduler = get_scheduler(optimizer, opt)
    dis_scheduler = get_scheduler(dis_optimizer, opt)
    ocr_scheduler = get_scheduler(ocr_optimizer, opt)

    start_time = time.time()
    iteration = start_iter
    cntr = 0

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    # loss_dict = {}
    accum = 0.5 ** (32 / (10 * 1000))
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0
    ada_aug_step = opt.ada_target / opt.ada_length
    r_t_stat = 0
    epsilon = 10e-50
    # sample_z = torch.randn(opt.n_sample, opt.latent, device=device)

    while True:
        # print(cntr)
        # train part
        if opt.lr_policy != "None":
            scheduler.step()
            dis_scheduler.step()
            ocr_scheduler.step()

        # the original used Python 2 style iter(...).next(); next(iter(...)) is the Python 3 equivalent
        image_input_tensors, _, labels, _ = next(iter(train_loader))
        labels_z_c = next(iter(text_loader))

        image_input_tensors = image_input_tensors.to(device)
        gt_image_tensors = image_input_tensors[:opt.batch_size].detach()
        real_image_tensors = image_input_tensors[opt.batch_size:].detach()
        labels_gt = labels[:opt.batch_size]

        requires_grad(cEncoder, False)
        requires_grad(genModel, False)
        requires_grad(disEncModel, True)
        requires_grad(ocrModel, False)

        text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length)
        text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length)

        z_c_code = cEncoder(text_z_c)
        noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
        style = []
        style.append(noise_style[0] * z_c_code)
        if len(noise_style) > 1:
            style.append(noise_style[1] * z_c_code)

        if opt.zAlone:
            # to validate orig style gan results
            newstyle = []
            newstyle.append(style[0][:, :opt.latent])
            if len(style) > 1:
                newstyle.append(style[1][:, :opt.latent])
            style = newstyle

        fake_img, _ = genModel(style, input_is_latent=opt.input_latent)

        # # unsupervised code prediction on generated image
        # u_pred_code = disEncModel(fake_img, mode='enc')
        # uCost = uCriterion(u_pred_code, z_code)

        # # supervised code prediction on gt image
        # s_pred_code = disEncModel(gt_image_tensors, mode='enc')
        # sCost = uCriterion(s_pred_code, gt_phoc_tensors)

        # Domain discriminator
        fake_pred = disEncModel(fake_img)
        real_pred = disEncModel(real_image_tensors)
        disCost = d_logistic_loss(real_pred, fake_pred)

        # dis_cost = disCost + opt.gamma_e * uCost + opt.beta * sCost
        loss_avg_dis.add(disCost)
        # loss_avg_sup.add(opt.beta * sCost)
        # loss_avg_unsup.add(opt.gamma_e * uCost)

        disEncModel.zero_grad()
        disCost.backward()
        dis_optimizer.step()

        d_regularize = cntr % opt.d_reg_every == 0
        if d_regularize:
            real_image_tensors.requires_grad = True
            real_pred = disEncModel(real_image_tensors)
            r1_loss = d_r1_loss(real_pred, real_image_tensors)
            disEncModel.zero_grad()
            (opt.r1 / 2 * r1_loss * opt.d_reg_every + 0 * real_pred[0]).backward()
            dis_optimizer.step()
        log_r1_val.add(r1_loss)

        # Recognizer update
        if not opt.ocrFixed and not opt.zAlone:
            requires_grad(disEncModel, False)
            requires_grad(ocrModel, True)

            if 'CTC' in opt.Prediction:
                preds_recon = ocrModel(gt_image_tensors, text_gt, is_train=True)
                preds_size = torch.IntTensor([preds_recon.size(1)] * opt.batch_size)
                preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2)
                ocrCost = ocrCriterion(preds_recon_softmax, text_gt, preds_size, length_gt)
            else:
                print("Not implemented error")
                sys.exit()

            ocrModel.zero_grad()
            ocrCost.backward()
            # torch.nn.utils.clip_grad_norm_(ocrModel.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
            ocr_optimizer.step()
            loss_avg_ocr_sup.add(ocrCost)
        else:
            loss_avg_ocr_sup.add(torch.tensor(0.0))

        # [Word Generator] update
        # image_input_tensors, _, labels, _ = next(iter(train_loader))
        labels_z_c = next(iter(text_loader))

        # image_input_tensors = image_input_tensors.to(device)
        # gt_image_tensors = image_input_tensors[:opt.batch_size]
        # real_image_tensors = image_input_tensors[opt.batch_size:]
        # labels_gt = labels[:opt.batch_size]

        requires_grad(cEncoder, True)
        requires_grad(genModel, True)
        requires_grad(disEncModel, False)
        requires_grad(ocrModel, False)

        text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length)

        z_c_code = cEncoder(text_z_c)
        noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
        style = []
        style.append(noise_style[0] * z_c_code)
        if len(noise_style) > 1:
            style.append(noise_style[1] * z_c_code)

        if opt.zAlone:
            # to validate orig style gan results
            newstyle = []
            newstyle.append(style[0][:, :opt.latent])
            if len(style) > 1:
                newstyle.append(style[1][:, :opt.latent])
            style = newstyle

        fake_img, _ = genModel(style, input_is_latent=opt.input_latent)

        fake_pred = disEncModel(fake_img)
        disGenCost = g_nonsaturating_loss(fake_pred)

        if opt.zAlone:
            ocrCost = torch.tensor(0.0)
        else:
            # Compute OCR prediction (Reconstruction of content)
            # text_for_pred = torch.LongTensor(opt.batch_size, opt.batch_max_length + 1).fill_(0).to(device)
            # length_for_pred = torch.IntTensor([opt.batch_max_length] * opt.batch_size).to(device)
            if 'CTC' in opt.Prediction:
                preds_recon = ocrModel(fake_img, text_z_c, is_train=False)
                preds_size = torch.IntTensor([preds_recon.size(1)] * opt.batch_size)
                preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2)
                ocrCost = ocrCriterion(preds_recon_softmax, text_z_c, preds_size, length_z_c)
            else:
                print("Not implemented error")
                sys.exit()

        genModel.zero_grad()
        cEncoder.zero_grad()

        gen_enc_cost = disGenCost + opt.ocrWeight * ocrCost
        grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, retain_graph=True)[0]
        loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
        grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, retain_graph=True)[0]
        loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)

        if opt.grad_balance:
            gen_enc_cost.backward(retain_graph=True)
            grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, create_graph=True, retain_graph=True)[0]
            grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, create_graph=True, retain_graph=True)[0]
            a = opt.ocrWeight * torch.div(torch.std(grad_fake_adv), epsilon + torch.std(grad_fake_OCR))
            if a is None:
                print(ocrCost, disGenCost, torch.std(grad_fake_adv), torch.std(grad_fake_OCR))
            if a > 1000 or a < 0.0001:
                print(a)
            ocrCost = a.detach() * ocrCost
            gen_enc_cost = disGenCost + ocrCost
            gen_enc_cost.backward(retain_graph=True)
            grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, create_graph=False, retain_graph=True)[0]
            grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, create_graph=False, retain_graph=True)[0]
            loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
            loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
            with torch.no_grad():
                gen_enc_cost.backward()
        else:
            gen_enc_cost.backward()

        loss_avg_gen.add(disGenCost)
        loss_avg_ocr_unsup.add(opt.ocrWeight * ocrCost)
        optimizer.step()

        g_regularize = cntr % opt.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, opt.batch_size // opt.path_batch_shrink)
            # image_input_tensors, _, labels, _ = next(iter(train_loader))
            labels_z_c = next(iter(text_loader))

            # image_input_tensors = image_input_tensors.to(device)
            # gt_image_tensors = image_input_tensors[:path_batch_size]
            # labels_gt = labels[:path_batch_size]

            text_z_c, length_z_c = converter.encode(labels_z_c[:path_batch_size], batch_max_length=opt.batch_max_length)
            # text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length)

            z_c_code = cEncoder(text_z_c)
            noise_style = mixing_noise_style(path_batch_size, opt.latent, opt.mixing, device)
            style = []
            style.append(noise_style[0] * z_c_code)
            if len(noise_style) > 1:
                style.append(noise_style[1] * z_c_code)

            if opt.zAlone:
                # to validate orig style gan results
                newstyle = []
                newstyle.append(style[0][:, :opt.latent])
                if len(style) > 1:
                    newstyle.append(style[1][:, :opt.latent])
                style = newstyle

            fake_img, grad = genModel(style, return_latents=True, g_path_regularize=True, mean_path_length=mean_path_length)

            decay = 0.01
            path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
            mean_path_length_orig = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
            path_loss = (path_lengths - mean_path_length_orig).pow(2).mean()
            mean_path_length = mean_path_length_orig.detach().item()

            genModel.zero_grad()
            cEncoder.zero_grad()
            weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss

            if opt.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()
            optimizer.step()

            # mean_path_length_avg = (
            #     reduce_sum(mean_path_length).item() / get_world_size()
            # )
            # commented above for multi-gpu, non-distributed setting
            mean_path_length_avg = mean_path_length

        accumulate(g_ema, genModel, accum)

        log_avg_path_loss_val.add(path_loss)
        log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg))
        log_ada_aug_p.add(torch.tensor(ada_aug_p))

        if get_rank() == 0:
            if wandb and opt.wandb:
                # the original referenced undefined names (r1_val, path_loss_val,
                # real_score_val, fake_score_val, path_length_val) here; log the
                # corresponding defined quantities instead
                wandb.log(
                    {
                        "Generator": g_loss_val,
                        "Discriminator": d_loss_val,
                        "Augment": ada_aug_p,
                        "Rt": r_t_stat,
                        "R1": r1_loss.item(),
                        "Path Length Regularization": path_loss.item(),
                        "Mean Path Length": mean_path_length,
                        "Real Score": real_pred.mean().item(),
                        "Fake Score": fake_pred.mean().item(),
                        "Path Length": path_lengths.mean().item(),
                    }
                )

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            # To see training progress, we also conduct validation when 'iteration == 0'

            # generate paired content with similar style
            labels_z_c_1 = next(iter(text_loader))
            labels_z_c_2 = next(iter(text_loader))

            text_z_c_1, length_z_c_1 = converter.encode(labels_z_c_1, batch_max_length=opt.batch_max_length)
            text_z_c_2, length_z_c_2 = converter.encode(labels_z_c_2, batch_max_length=opt.batch_max_length)

            z_c_code_1 = cEncoder(text_z_c_1)
            z_c_code_2 = cEncoder(text_z_c_2)

            style_c1_s1 = []
            style_c2_s1 = []
            style_s1 = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
            style_c1_s1.append(style_s1[0] * z_c_code_1)
            style_c2_s1.append(style_s1[0] * z_c_code_2)
            if len(style_s1) > 1:
                style_c1_s1.append(style_s1[1] * z_c_code_1)
                style_c2_s1.append(style_s1[1] * z_c_code_2)

            noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
            style_c1_s2 = []
            style_c1_s2.append(noise_style[0] * z_c_code_1)
            if len(noise_style) > 1:
                style_c1_s2.append(noise_style[1] * z_c_code_1)

            if opt.zAlone:
                # to validate orig style gan results
                newstyle = []
                newstyle.append(style_c1_s1[0][:, :opt.latent])
                if len(style_c1_s1) > 1:
                    newstyle.append(style_c1_s1[1][:, :opt.latent])
                style_c1_s1 = newstyle
                style_c2_s1 = newstyle
                style_c1_s2 = newstyle

            fake_img_c1_s1, _ = g_ema(style_c1_s1, input_is_latent=opt.input_latent)
            fake_img_c2_s1, _ = g_ema(style_c2_s1, input_is_latent=opt.input_latent)
            fake_img_c1_s2, _ = g_ema(style_c1_s2, input_is_latent=opt.input_latent)

            if not opt.zAlone:
                # Run OCR prediction
                if 'CTC' in opt.Prediction:
                    preds = ocrModel(fake_img_c1_s1, text_z_c_1, is_train=False)
                    preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size)
                    _, preds_index = preds.max(2)
                    preds_str_fake_img_c1_s1 = converter.decode(preds_index.data, preds_size.data)

                    preds = ocrModel(fake_img_c2_s1, text_z_c_2, is_train=False)
                    preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size)
                    _, preds_index = preds.max(2)
                    preds_str_fake_img_c2_s1 = converter.decode(preds_index.data, preds_size.data)

                    preds = ocrModel(fake_img_c1_s2, text_z_c_1, is_train=False)
                    preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size)
                    _, preds_index = preds.max(2)
                    preds_str_fake_img_c1_s2 = converter.decode(preds_index.data, preds_size.data)

                    preds = ocrModel(gt_image_tensors, text_gt, is_train=False)
                    preds_size = torch.IntTensor([preds.size(1)] * gt_image_tensors.shape[0])
                    _, preds_index = preds.max(2)
                    preds_str_gt = converter.decode(preds_index.data, preds_size.data)
                else:
                    print("Not implemented error")
                    sys.exit()
            else:
                preds_str_fake_img_c1_s1 = [':None:'] * fake_img_c1_s1.shape[0]
                preds_str_gt = [':None:'] * fake_img_c1_s1.shape[0]

            os.makedirs(os.path.join(opt.trainDir, str(iteration)), exist_ok=True)
            for trImgCntr in range(opt.batch_size):
                try:
                    save_image(tensor2im(fake_img_c1_s1[trImgCntr].detach()),
                               os.path.join(opt.trainDir, str(iteration), str(trImgCntr) + '_c1_s1_' + labels_z_c_1[trImgCntr] + '_ocr:' + preds_str_fake_img_c1_s1[trImgCntr] + '.png'))
                    if not opt.zAlone:
                        save_image(tensor2im(fake_img_c2_s1[trImgCntr].detach()),
                                   os.path.join(opt.trainDir, str(iteration), str(trImgCntr) + '_c2_s1_' + labels_z_c_2[trImgCntr] + '_ocr:' + preds_str_fake_img_c2_s1[trImgCntr] + '.png'))
                        save_image(tensor2im(fake_img_c1_s2[trImgCntr].detach()),
                                   os.path.join(opt.trainDir, str(iteration), str(trImgCntr) + '_c1_s2_' + labels_z_c_1[trImgCntr] + '_ocr:' + preds_str_fake_img_c1_s2[trImgCntr] + '.png'))
                        if trImgCntr < gt_image_tensors.shape[0]:
                            save_image(tensor2im(gt_image_tensors[trImgCntr].detach()),
                                       os.path.join(opt.trainDir, str(iteration), str(trImgCntr) + '_gt_act:' + labels_gt[trImgCntr] + '_ocr:' + preds_str_gt[trImgCntr] + '.png'))
                except:
                    print('Warning while saving training image')

            elapsed_time = time.time() - start_time
            # for log
            with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_train.txt'), 'a') as log:
                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] \
                    Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f},\
                    Train UnSup OCR loss: {loss_avg_ocr_unsup.val():0.5f}, Train Sup OCR loss: {loss_avg_ocr_sup.val():0.5f}, \
                    Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, \
                    Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, \
                    Elapsed_time: {elapsed_time:0.5f}'

                # plotting
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-Dis-Loss'), loss_avg_dis.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-Gen-Loss'), loss_avg_gen.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-UnSup-OCR-Loss'), loss_avg_ocr_unsup.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-Sup-OCR-Loss'), loss_avg_ocr_sup.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-r1_val'), log_r1_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-path_loss_val'), log_avg_path_loss_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-mean_path_length_avg'), log_avg_mean_path_length_avg.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-ada_aug_p'), log_ada_aug_p.val().item())

                print(loss_log)
                log.write(loss_log + '\n')  # the original opened the log file but never wrote loss_log

                loss_avg_dis.reset()
                loss_avg_gen.reset()
                loss_avg_ocr_unsup.reset()
                loss_avg_ocr_sup.reset()
                log_r1_val.reset()
                log_avg_path_loss_val.reset()
                log_avg_mean_path_length_avg.reset()
                log_ada_aug_p.reset()

            lib.plot.flush()

        lib.plot.tick()

        # save model every 1e+4 iterations (the original comment said 1e+5, but the check below uses 1e+4)
        if (iteration) % 1e+4 == 0:
            torch.save({
                'cEncoder': cEncoder.state_dict(),
                'genModel': genModel.state_dict(),
                'g_ema': g_ema.state_dict(),
                'ocrModel': ocrModel.state_dict(),
                'disEncModel': disEncModel.state_dict(),
                'optimizer': optimizer.state_dict(),
                'ocr_optimizer': ocr_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict()},
                os.path.join(opt.exp_dir, opt.exp_name, 'iter_' + str(iteration + 1) + '_synth.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
        cntr += 1
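# Both train() variants lean on the Averager helper for running loss means.
# Below is a minimal sketch consistent with how it is called above
# (add / val / reset on scalar tensors); the clovaai
# deep-text-recognition-benchmark utils define it essentially this way, but
# treat the exact body as an assumption rather than this repo's code.
class Averager(object):
    """Compute the average of torch.Tensor values, used for loss averaging."""

    def __init__(self):
        self.reset()

    def add(self, v):
        count = v.data.numel()      # number of elements being added
        self.n_count += count
        self.sum += v.data.sum()    # accumulate their total

    def val(self):
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)  # mean since the last reset
        return res

    def reset(self):
        self.n_count = 0
        self.sum = 0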
def train(opt):
    lib.print_model_settings(locals().copy())

    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    log = open(os.path.join(opt.exp_dir, opt.exp_name, 'log_dataset.txt'), 'a')
    AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)

    train_dataset, train_dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size,
        sampler=data_sampler(train_dataset, shuffle=True, distributed=opt.distributed),
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
    log.write(train_dataset_log)
    print('-' * 80)

    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        # the original sampled from train_dataset here, which looks like a copy-paste slip
        sampler=data_sampler(valid_dataset, shuffle=False, distributed=opt.distributed),
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    else:
        converter = CTCLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    # styleModel = StyleTensorEncoder(input_dim=opt.input_channel)
    # genModel = AdaIN_Tensor_WordGenerator(opt)
    # disModel = MsImageDisV2(opt)

    # styleModel = StyleLatentEncoder(input_dim=opt.input_channel, norm='none')
    # mixModel = Mixer(opt, nblk=3, dim=opt.latent)
    genModel = styleGANGen(opt.size, opt.latent, opt.n_mlp, opt.num_class, channel_multiplier=opt.channel_multiplier).to(device)
    disModel = styleGANDis(opt.size, channel_multiplier=opt.channel_multiplier, input_dim=opt.input_channel).to(device)
    g_ema = styleGANGen(opt.size, opt.latent, opt.n_mlp, opt.num_class, channel_multiplier=opt.channel_multiplier).to(device)
    ocrModel = ModelV1(opt).to(device)

    accumulate(g_ema, genModel, 0)

    # # weight initialization
    # for currModel in [styleModel, mixModel]:
    #     for name, param in currModel.named_parameters():
    #         if 'localization_fc2' in name:
    #             print(f'Skip {name} as it is already initialized')
    #             continue
    #         try:
    #             if 'bias' in name:
    #                 init.constant_(param, 0.0)
    #             elif 'weight' in name:
    #                 init.kaiming_normal_(param)
    #         except Exception as e:  # for batchnorm.
    #             if 'weight' in name:
    #                 param.data.fill_(1)
    #             continue

    if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
        ocrCriterion = torch.nn.L1Loss()
    else:
        if 'CTC' in opt.Prediction:
            ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
        else:
            ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    # vggRecCriterion = torch.nn.L1Loss()
    # vggModel = VGGPerceptualLossModel(models.vgg19(pretrained=True), vggRecCriterion)

    print('model input parameters', opt.imgH, opt.imgW, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length)

    if opt.distributed:
        genModel = torch.nn.parallel.DistributedDataParallel(
            genModel,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            broadcast_buffers=False,
        )
        disModel = torch.nn.parallel.DistributedDataParallel(
            disModel,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            broadcast_buffers=False,
        )
        ocrModel = torch.nn.parallel.DistributedDataParallel(
            ocrModel,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            broadcast_buffers=False
        )

    # styleModel = torch.nn.DataParallel(styleModel).to(device)
    # styleModel.train()
    # mixModel = torch.nn.DataParallel(mixModel).to(device)
    # mixModel.train()
    # genModel = torch.nn.DataParallel(genModel).to(device)
    # g_ema = torch.nn.DataParallel(g_ema).to(device)
    genModel.train()
    g_ema.eval()
    # disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()
    # vggModel = torch.nn.DataParallel(vggModel).to(device)
    # vggModel.eval()
    # ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    # if opt.distributed:
    #     ocrModel.module.Transformation.eval()
    #     ocrModel.module.FeatureExtraction.eval()
    #     ocrModel.module.AdaptiveAvgPool.eval()
    #     # ocrModel.module.SequenceModeling.eval()
    #     ocrModel.module.Prediction.eval()
    # else:
    #     ocrModel.Transformation.eval()
    #     ocrModel.FeatureExtraction.eval()
    #     ocrModel.AdaptiveAvgPool.eval()
    #     # ocrModel.SequenceModeling.eval()
    #     ocrModel.Prediction.eval()
    ocrModel.eval()

    if opt.distributed:
        g_module = genModel.module
        d_module = disModel.module
    else:
        g_module = genModel
        d_module = disModel

    g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1)
    d_reg_ratio = opt.d_reg_every / (opt.d_reg_every + 1)

    optimizer = optim.Adam(
        genModel.parameters(),
        lr=opt.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    dis_optimizer = optim.Adam(
        disModel.parameters(),
        lr=opt.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )

    ## Loading pre-trained files
    if opt.modelFolderFlag:
        if len(glob.glob(os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth"))) > 0:
            opt.saved_synth_model = glob.glob(os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth"))[-1]

    if opt.saved_ocr_model != '' and opt.saved_ocr_model != 'None':
        if not opt.distributed:
            ocrModel = torch.nn.DataParallel(ocrModel)
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        checkpoint = torch.load(opt.saved_ocr_model)
        ocrModel.load_state_dict(checkpoint)
        # temporary fix
        if not opt.distributed:
            ocrModel = ocrModel.module

    if opt.saved_gen_model != '' and opt.saved_gen_model != 'None':
        print(f'loading pretrained gen model from {opt.saved_gen_model}')
        checkpoint = torch.load(opt.saved_gen_model, map_location=lambda storage, loc: storage)
        genModel.module.load_state_dict(checkpoint['g'])
        g_ema.module.load_state_dict(checkpoint['g_ema'])

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        checkpoint = torch.load(opt.saved_synth_model)

        # styleModel.load_state_dict(checkpoint['styleModel'])
        # mixModel.load_state_dict(checkpoint['mixModel'])
        genModel.load_state_dict(checkpoint['genModel'])
        g_ema.load_state_dict(checkpoint['g_ema'])
        disModel.load_state_dict(checkpoint['disModel'])

        optimizer.load_state_dict(checkpoint["optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])

    # if opt.imgReconLoss == 'l1':
    #     recCriterion = torch.nn.L1Loss()
    # elif opt.imgReconLoss == 'ssim':
    #     recCriterion = ssim
    # elif opt.imgReconLoss == 'ms-ssim':
    #     recCriterion = msssim

    # loss averagers
    loss_avg = Averager()
    loss_avg_dis = Averager()
    loss_avg_gen = Averager()
    loss_avg_imgRecon = Averager()
    loss_avg_vgg_per = Averager()
    loss_avg_vgg_sty = Averager()
    loss_avg_ocr = Averager()

    log_r1_val = Averager()
    log_avg_path_loss_val = Averager()
    log_avg_mean_path_length_avg = Averager()
    log_ada_aug_p = Averager()

    """ final options """
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        try:
            start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    # get schedulers
    scheduler = get_scheduler(optimizer, opt)
    dis_scheduler = get_scheduler(dis_optimizer, opt)

    start_time = time.time()
    iteration = start_iter
    cntr = 0

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    accum = 0.5 ** (32 / (10 * 1000))
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0
    ada_aug_step = opt.ada_target / opt.ada_length
    r_t_stat = 0

    sample_z = torch.randn(opt.n_sample, opt.latent, device=device)

    while True:
        # print(cntr)
        # train part
        if opt.lr_policy != "None":
            scheduler.step()
            dis_scheduler.step()

        # the original used Python 2 style iter(...).next(); next(iter(...)) is the Python 3 equivalent
        image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(iter(train_loader))

        image_input_tensors = image_input_tensors.to(device)
        image_gt_tensors = image_gt_tensors.to(device)
        batch_size = image_input_tensors.size(0)

        requires_grad(genModel, False)
        # requires_grad(styleModel, False)
        # requires_grad(mixModel, False)
        requires_grad(disModel, True)

        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

        # forward pass from style and word generator
        # style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
        style = mixing_noise(opt.batch_size, opt.latent, opt.mixing, device)
        # scInput = mixModel(style, text_2)

        if 'CTC' in opt.Prediction:
            images_recon_2, _ = genModel(style, text_2, input_is_latent=opt.input_latent)
        else:
            images_recon_2, _ = genModel(style, text_2[:, 1:-1], input_is_latent=opt.input_latent)

        # Domain discriminator: Dis update
        if opt.augment:
            image_gt_tensors_aug, _ = augment(image_gt_tensors, ada_aug_p)
            images_recon_2, _ = augment(images_recon_2, ada_aug_p)
        else:
            image_gt_tensors_aug = image_gt_tensors

        fake_pred = disModel(images_recon_2)
        real_pred = disModel(image_gt_tensors_aug)
        disCost = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = disCost * opt.disWeight
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()
        loss_avg_dis.add(disCost)

        disModel.zero_grad()
        disCost.backward()
        dis_optimizer.step()

        if opt.augment and opt.augment_p == 0:
            ada_augment += torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]), device=device
            )
            ada_augment = reduce_sum(ada_augment)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()
                r_t_stat = pred_signs / n_pred
                if r_t_stat > opt.ada_target:
                    sign = 1
                else:
                    sign = -1
                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        d_regularize = cntr % opt.d_reg_every == 0
        if d_regularize:
            image_gt_tensors.requires_grad = True
            image_input_tensors.requires_grad = True
            cat_tensor = image_gt_tensors
            real_pred = disModel(cat_tensor)
            r1_loss = d_r1_loss(real_pred, cat_tensor)
            disModel.zero_grad()
            (opt.r1 / 2 * r1_loss * opt.d_reg_every + 0 * real_pred[0]).backward()
            dis_optimizer.step()
        loss_dict["r1"] = r1_loss

        # [Style Encoder] + [Word Generator] update
        image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(iter(train_loader))

        image_input_tensors = image_input_tensors.to(device)
        image_gt_tensors = image_gt_tensors.to(device)
        batch_size = image_input_tensors.size(0)

        requires_grad(genModel, True)
        # requires_grad(styleModel, True)
        # requires_grad(mixModel, True)
        requires_grad(disModel, False)

        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

        # style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
        # scInput = mixModel(style, text_2)
        # images_recon_2, _ = genModel([scInput], input_is_latent=opt.input_latent)
        style = mixing_noise(batch_size, opt.latent, opt.mixing, device)

        if 'CTC' in opt.Prediction:
            images_recon_2, _ = genModel(style, text_2)
        else:
            images_recon_2, _ = genModel(style, text_2[:, 1:-1])

        if opt.augment:
            images_recon_2, _ = augment(images_recon_2, ada_aug_p)

        fake_pred = disModel(images_recon_2)
        disGenCost = g_nonsaturating_loss(fake_pred)
        loss_dict["g"] = disGenCost

        # # Adversarial loss
        # disGenCost = disModel.module.calc_gen_loss(torch.cat((images_recon_2, image_input_tensors), dim=1))

        # # Input reconstruction loss
        # recCost = recCriterion(images_recon_2, image_gt_tensors)

        # # vgg loss
        # vggPerCost, vggStyleCost = vggModel(image_gt_tensors, images_recon_2)

        # ocr loss
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
        length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)

        if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
            preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False, returnFeat=opt.contentLoss)
            preds_gt = ocrModel(image_gt_tensors, text_for_pred, is_train=False, returnFeat=opt.contentLoss)
            ocrCost = ocrCriterion(preds_recon, preds_gt)
        else:
            if 'CTC' in opt.Prediction:
                preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False)
                # preds_o = preds_recon[:, :text_1.shape[1], :]
                preds_size = torch.IntTensor([preds_recon.size(1)] * batch_size)
                preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2)
                ocrCost = ocrCriterion(preds_recon_softmax, text_2, preds_size, length_2)

                # predict ocr recognition on generated images
                # preds_recon_size = torch.IntTensor([preds_recon.size(1)] * batch_size)
                _, preds_recon_index = preds_recon.max(2)
                labels_o_ocr = converter.decode(preds_recon_index.data, preds_size.data)

                # predict ocr recognition on gt style images
                preds_s = ocrModel(image_input_tensors, text_for_pred, is_train=False)
                # preds_s = preds_s[:, :text_1.shape[1] - 1, :]
                preds_s_size = torch.IntTensor([preds_s.size(1)] * batch_size)
                _, preds_s_index = preds_s.max(2)
                labels_s_ocr = converter.decode(preds_s_index.data, preds_s_size.data)

                # predict ocr recognition on gt stylecontent images
                preds_sc = ocrModel(image_gt_tensors, text_for_pred, is_train=False)
                # preds_sc = preds_sc[:, :text_2.shape[1] - 1, :]
                preds_sc_size = torch.IntTensor([preds_sc.size(1)] * batch_size)
                _, preds_sc_index = preds_sc.max(2)
                labels_sc_ocr = converter.decode(preds_sc_index.data, preds_sc_size.data)
            else:
                preds_recon = ocrModel(images_recon_2, text_for_pred[:, :-1], is_train=False)  # align with Attention.forward
                target_2 = text_2[:, 1:]  # without [GO] Symbol
                ocrCost = ocrCriterion(preds_recon.view(-1, preds_recon.shape[-1]), target_2.contiguous().view(-1))

                # predict ocr recognition on generated images
                _, preds_o_index = preds_recon.max(2)
                labels_o_ocr = converter.decode(preds_o_index, length_for_pred)
                for idx, pred in enumerate(labels_o_ocr):
                    pred_EOS = pred.find('[s]')
                    labels_o_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

                # predict ocr recognition on gt style images
                preds_s = ocrModel(image_input_tensors, text_for_pred, is_train=False)
                _, preds_s_index = preds_s.max(2)
                labels_s_ocr = converter.decode(preds_s_index, length_for_pred)
                for idx, pred in enumerate(labels_s_ocr):
                    pred_EOS = pred.find('[s]')
                    labels_s_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

                # predict ocr recognition on gt stylecontent images
                preds_sc = ocrModel(image_gt_tensors, text_for_pred, is_train=False)
                _, preds_sc_index = preds_sc.max(2)
                labels_sc_ocr = converter.decode(preds_sc_index, length_for_pred)
                for idx, pred in enumerate(labels_sc_ocr):
                    pred_EOS = pred.find('[s]')
                    labels_sc_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

        # cost = opt.reconWeight*recCost + opt.disWeight*disGenCost + opt.vggPerWeight*vggPerCost + opt.vggStyWeight*vggStyleCost + opt.ocrWeight*ocrCost
        cost = opt.disWeight * disGenCost + opt.ocrWeight * ocrCost

        # styleModel.zero_grad()
        genModel.zero_grad()
        # mixModel.zero_grad()
        disModel.zero_grad()
        # vggModel.zero_grad()
        ocrModel.zero_grad()

        cost.backward()
        optimizer.step()
        loss_avg.add(cost)

        g_regularize = cntr % opt.g_reg_every == 0
        if g_regularize:
            image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(iter(train_loader))

            image_input_tensors = image_input_tensors.to(device)
            image_gt_tensors = image_gt_tensors.to(device)
            batch_size = image_input_tensors.size(0)

            text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
            text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

            path_batch_size = max(1, batch_size // opt.path_batch_shrink)

            # style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
            # scInput = mixModel(style, text_2)
            # images_recon_2, latents = genModel([scInput], input_is_latent=opt.input_latent, return_latents=True)
            style = mixing_noise(path_batch_size, opt.latent, opt.mixing, device)

            if 'CTC' in opt.Prediction:
                images_recon_2, latents = genModel(style, text_2[:path_batch_size], return_latents=True)
            else:
                images_recon_2, latents = genModel(style, text_2[:path_batch_size, 1:-1], return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                images_recon_2, latents, mean_path_length
            )

            genModel.zero_grad()
            weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss

            if opt.path_batch_shrink:
                weighted_path_loss += 0 * images_recon_2[0, 0, 0, 0]

            weighted_path_loss.backward()
            optimizer.step()

            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        # Individual losses
        loss_avg_gen.add(opt.disWeight * disGenCost)
        loss_avg_imgRecon.add(torch.tensor(0.0))
        loss_avg_vgg_per.add(torch.tensor(0.0))
        loss_avg_vgg_sty.add(torch.tensor(0.0))
        loss_avg_ocr.add(opt.ocrWeight * ocrCost)

        log_r1_val.add(loss_reduced["r1"])  # the original logged loss_reduced["path"] here, which looks like a copy-paste slip
        log_avg_path_loss_val.add(loss_reduced["path"])
        log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg))
        log_ada_aug_p.add(torch.tensor(ada_aug_p))

        if get_rank() == 0:
            # pbar.set_description(
            #     (
            #         f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
            #         f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
            #         f"augment: {ada_aug_p:.4f}"
            #     )
            # )

            if wandb and opt.wandb:
                wandb.log(
                    {
                        "Generator": g_loss_val,
                        "Discriminator": d_loss_val,
                        "Augment": ada_aug_p,
                        "Rt": r_t_stat,
                        "R1": r1_val,
                        "Path Length Regularization": path_loss_val,
                        "Mean Path Length": mean_path_length,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,
                        "Path Length": path_length_val,
                    }
                )

            # if cntr % 100 == 0:
            #     with torch.no_grad():
            #         g_ema.eval()
            #         sample, _ = g_ema([scInput[:, :opt.latent], scInput[:, opt.latent:]])
            #         utils.save_image(
            #             sample,
            #             os.path.join(opt.trainDir, f"sample_{str(cntr).zfill(6)}.png"),
            #             nrow=int(opt.n_sample ** 0.5),
            #             normalize=True,
            #             range=(-1, 1),
            #         )

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            # To see training progress, we also conduct validation when 'iteration == 0'

            # Save training images
            curr_batch_size = style[0].shape[0]
            images_recon_2, _ = g_ema(style, text_2[:curr_batch_size], input_is_latent=opt.input_latent)
            os.makedirs(os.path.join(opt.trainDir, str(iteration)), exist_ok=True)
            for trImgCntr in range(batch_size):
                try:
                    if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
                        save_image(tensor2im(image_input_tensors[trImgCntr].detach()),
                                   os.path.join(opt.trainDir, str(iteration), str(trImgCntr) + '_sInput_' + labels_1[trImgCntr] + '.png'))
                        save_image(tensor2im(image_gt_tensors[trImgCntr].detach()),
                                   os.path.join(opt.trainDir, str(iteration), str(trImgCntr) + '_csGT_' + labels_2[trImgCntr] + '.png'))
                        save_image(tensor2im(images_recon_2[trImgCntr].detach()),
                                   os.path.join(opt.trainDir, str(iteration), str(trImgCntr) + '_csRecon_' + labels_2[trImgCntr] + '.png'))
                    else:
                        save_image(tensor2im(image_input_tensors[trImgCntr].detach()),
                                   os.path.join(opt.trainDir, str(iteration), str(trImgCntr) + '_sInput_' + labels_1[trImgCntr] + '_' + labels_s_ocr[trImgCntr] + '.png'))
                        save_image(tensor2im(image_gt_tensors[trImgCntr].detach()),
                                   os.path.join(opt.trainDir, str(iteration), str(trImgCntr) + '_csGT_' + labels_2[trImgCntr] + '_' + labels_sc_ocr[trImgCntr] + '.png'))
                        save_image(tensor2im(images_recon_2[trImgCntr].detach()),
                                   os.path.join(opt.trainDir, str(iteration), str(trImgCntr) + '_csRecon_' + labels_2[trImgCntr] + '_' + labels_o_ocr[trImgCntr] + '.png'))
                except:
                    print('Warning while saving training image')

            elapsed_time = time.time() - start_time
            # for log
            with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_train.txt'), 'a') as log:
                # styleModel.eval()
                genModel.eval()
                g_ema.eval()
                # mixModel.eval()
                disModel.eval()

                with torch.no_grad():
                    valid_loss, infer_time, length_of_data = validation_synth_v6(
                        iteration, g_ema, ocrModel, disModel, ocrCriterion, valid_loader, converter, opt)

                # styleModel.train()
                genModel.train()
                # mixModel.train()
                disModel.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train Synth loss: {loss_avg.val():0.5f}, \
                    Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f},\
                    Train OCR loss: {loss_avg_ocr.val():0.5f}, \
                    Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, \
                    Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, \
                    Valid Synth loss: {valid_loss[0]:0.5f}, \
                    Valid Dis loss: {valid_loss[1]:0.5f}, Valid Gen loss: {valid_loss[2]:0.5f}, \
                    Valid OCR loss: {valid_loss[6]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'

                # plotting
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-Synth-Loss'), loss_avg.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-Dis-Loss'), loss_avg_dis.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-Gen-Loss'), loss_avg_gen.val().item())
                # lib.plot.plot(os.path.join(opt.plotDir, 'Train-ImgRecon1-Loss'), loss_avg_imgRecon.val().item())
                # lib.plot.plot(os.path.join(opt.plotDir, 'Train-VGG-Per-Loss'), loss_avg_vgg_per.val().item())
                # lib.plot.plot(os.path.join(opt.plotDir, 'Train-VGG-Sty-Loss'), loss_avg_vgg_sty.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-OCR-Loss'), loss_avg_ocr.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-r1_val'), log_r1_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-path_loss_val'), log_avg_path_loss_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-mean_path_length_avg'), log_avg_mean_path_length_avg.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-ada_aug_p'), log_ada_aug_p.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Valid-Synth-Loss'), valid_loss[0].item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Valid-Dis-Loss'), valid_loss[1].item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Valid-Gen-Loss'), valid_loss[2].item())
                # lib.plot.plot(os.path.join(opt.plotDir, 'Valid-ImgRecon1-Loss'), valid_loss[3].item())
                # lib.plot.plot(os.path.join(opt.plotDir, 'Valid-VGG-Per-Loss'), valid_loss[4].item())
                # lib.plot.plot(os.path.join(opt.plotDir, 'Valid-VGG-Sty-Loss'), valid_loss[5].item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Valid-OCR-Loss'), valid_loss[6].item())

                print(loss_log)
                log.write(loss_log + '\n')  # the original opened the log file but never wrote loss_log

                loss_avg.reset()
                loss_avg_dis.reset()
                loss_avg_gen.reset()
                loss_avg_imgRecon.reset()
                loss_avg_vgg_per.reset()
                loss_avg_vgg_sty.reset()
                loss_avg_ocr.reset()
                log_r1_val.reset()
                log_avg_path_loss_val.reset()
                log_avg_mean_path_length_avg.reset()
                log_ada_aug_p.reset()

            lib.plot.flush()

        lib.plot.tick()

        # save model every 1e+4 iterations (the original comment said 1e+5, but the check below uses 1e+4)
        if (iteration) % 1e+4 == 0:
            torch.save({
                # 'styleModel': styleModel.state_dict(),
                # 'mixModel': mixModel.state_dict(),
                'genModel': g_module.state_dict(),
                'g_ema': g_ema.state_dict(),
                'disModel': d_module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict()},
                os.path.join(opt.exp_dir, opt.exp_name, 'iter_' + str(iteration + 1) + '_synth.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
        cntr += 1
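# This train() also relies on two small helpers from rosinality's
# stylegan2-pytorch training script. Minimal sketches consistent with how they
# are called above (treat them as assumptions, not this repo's exact code):
from torch.utils import data


def data_sampler(dataset, shuffle, distributed):
    # pick a sampler matching the (shuffle, distributed) combination
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        return data.RandomSampler(dataset)
    return data.SequentialSampler(dataset)


def requires_grad(model, flag=True):
    # freeze or unfreeze every parameter of a module
    for p in model.parameters():
        p.requires_grad = flag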
class OCR:
    def __init__(self):
        # model settings
        self.path_model = 'model/TPS-ResNet-BiLSTM-Attn.pth'  # was commented out in the original, but torch.load below needs it
        self.batch_size = 1
        self.batch_max_length = 25
        self.imgH = 32
        self.imgW = 100
        self.character = '0123456789abcdefghijklmnopqrstuvwxyz'
        self.Transformation = 'TPS'
        self.FeatureExtraction = 'ResNet'
        self.SequenceModeling = 'BiLSTM'
        self.Prediction = 'Attn'
        self.num_fiducial = 20
        self.input_channel = 1
        self.output_channel = 512
        self.hidden_size = 256
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        parser = argparse.ArgumentParser()
        parser.add_argument('--rgb', action='store_true', help='use rgb input')
        self.opt = parser.parse_args()
        self.opt.num_gpu = torch.cuda.device_count()
        # copy the settings above onto opt, since Model() reads them from there
        for k in ['batch_max_length', 'imgH', 'imgW', 'character', 'Transformation',
                  'FeatureExtraction', 'SequenceModeling', 'Prediction', 'num_fiducial',
                  'input_channel', 'output_channel', 'hidden_size', 'batch_size']:
            setattr(self.opt, k, getattr(self, k))

        # load model
        self.converter = AttnLabelConverter(self.character)
        self.opt.num_class = len(self.converter.character)
        if self.opt.rgb:
            self.opt.input_channel = 3
        self.model = Model(self.opt)
        self.model = torch.nn.DataParallel(self.model).to(self.device)  # the original hard-coded 'cuda:0' here, which breaks on CPU-only machines

        # load model weights
        self.model.load_state_dict(torch.load(self.path_model, map_location=self.device))

    def run(self, img):
        with torch.no_grad():
            img = cv2.normalize(img, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
            img = cv2.resize(img, dsize=(100, 32), interpolation=cv2.INTER_CUBIC)
            image_tensor = img[np.newaxis, np.newaxis, ...]
            image = torch.from_numpy(image_tensor).float().to(self.device)

            # For max length prediction
            length_for_pred = torch.IntTensor([self.batch_max_length] * self.batch_size).to(self.device)
            text_for_pred = torch.LongTensor(self.batch_size, self.batch_max_length + 1).fill_(0).to(self.device)

            preds = self.model(image, text_for_pred, is_train=False)

            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = self.converter.decode(preds_index, length_for_pred)

            pred = ''
            for i in range(len(preds_str)):
                pred += preds_str[i][:preds_str[i].find('[s]')]  # prune after "end of sentence" token ([s])
            print(pred)
            return pred
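# Usage sketch for OCR (hedged): feed a grayscale numpy crop of a single word;
# the image path is a hypothetical placeholder.
if __name__ == '__main__':
    ocr = OCR()
    crop = cv2.imread('word_crop.png', cv2.IMREAD_GRAYSCALE)  # hypothetical crop
    ocr.run(crop)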
class Pytorch_model:
    def __init__(self, model_path: str, gpu_id=None):
        '''
        Initialize the pytorch model.
        :param model_path: path to the model (either parameters only, or parameters saved together with the computation graph)
        :param alphabet: the alphabet
        :param img_shape: image size (w, h)
        :param net: the network computation graph; required if model_path only stores parameters
        :param img_channel: number of image channels: 1 or 3
        :param gpu_id: which GPU to run on
        '''
        self.gpu_id = gpu_id
        checkpoint = torch.load(model_path)
        if self.gpu_id is not None and isinstance(self.gpu_id, int) and torch.cuda.is_available():
            self.device = torch.device("cuda:%s" % self.gpu_id)
        else:
            self.device = torch.device("cpu")
        print('device:', self.device)

        config = checkpoint['config']
        self.prediction_type = config['arch']['args']['prediction']['type']
        if self.prediction_type == 'CTC':
            self.converter = CTCLabelConverter(config['data_loader']['args']['alphabet'])
        else:
            self.converter = AttnLabelConverter(config['data_loader']['args']['alphabet'])
        num_class = len(self.converter.character)

        self.net = get_model(num_class, config)
        self.img_w = config['data_loader']['args']['dataset']['img_w']
        self.img_h = config['data_loader']['args']['dataset']['img_h']
        self.img_channel = config['data_loader']['args']['dataset']['img_channel']
        self.net.load_state_dict(checkpoint['state_dict'])
        self.net.to(self.device)
        self.net.eval()

    def predict(self, img):
        '''
        Run prediction on the given image; accepts an image path or a numpy array.
        :param img: image path or numpy array
        :return:
        '''
        assert self.img_channel in [1, 3], 'img_channel must be in [1, 3]'
        if isinstance(img, str):  # read image
            assert os.path.exists(img), 'file does not exist'
            img = cv2.imread(img, 0 if self.img_channel == 1 else 1)
        img = self.pre_processing(img)
        # reshape the image from (w, h) to (1, img_channel, h, w)
        img = transforms.ToTensor()(img)
        img = img.unsqueeze_(0)
        img = img.to(self.device)
        with torch.no_grad():
            text = torch.zeros(1, 80, dtype=torch.long, device=self.device)
            preds = self.net(img, text)
            preds = torch.softmax(preds, dim=2).detach().cpu().numpy()  # the original omitted .cpu(), which fails on GPU tensors
        preds_str = self.converter.decode(preds)
        return preds_str

    def pre_processing(self, img):
        """
        Preprocess the image: resize by height. (The original docstring also
        described padding to the target width, but this version only scales
        by the height ratio.)
        :param img: the image
        :return:
        """
        img_h = self.img_h
        img_w = self.img_w
        h, w = img.shape[:2]
        ratio_h = float(img_h) / h
        new_w = int(w * ratio_h)
        img = cv2.resize(img, (new_w, img_h))
        return img
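# Usage sketch for Pytorch_model (hedged; checkpoint and image paths are
# hypothetical placeholders):
if __name__ == '__main__':
    model = Pytorch_model('output/checkpoint.pth', gpu_id=None)  # CPU inference
    print(model.predict('test_line.jpg'))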
class TextExtractor(): def __init__(self, image_folder, extract_text_file, split): self.i_folder = image_folder #print(image_folder) #print("aaaaaaa test") self.extract_text_file = extract_text_file self.canvas_size = 1280 self.mag_ratio = 1.5 self.show_time = False self.device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu') self.cuda = torch.cuda.is_available() self.net = CRAFT() #(1st model) model to detect words in images if self.cuda: self.net.load_state_dict( self.copyStateDict( torch.load('CRAFT-pytorch/craft_mlt_25k.pth'))) else: self.net.load_state_dict( self.copyStateDict( torch.load('CRAFT-pytorch/craft_mlt_25k.pth', map_location='cpu'))) if self.cuda: self.net = self.net.cuda() self.net = torch.nn.DataParallel(self.net) cudnn.benchmark = False self.net.eval() self.refine_net = None self.text_threshold = 0.7 self.link_threshold = 0.4 self.low_text = 0.4 self.poly = False self.result_folder = './' + split + '_' + 'intermediate_result/' if not os.path.isdir(self.result_folder): os.mkdir(self.result_folder) #Parameters for image to text model (2nd model) self.parser = argparse.ArgumentParser() #Data processing self.parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') self.parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') self.parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') self.parser.add_argument('--rgb', default=False, action='store_true', help='use rgb input') self.parser.add_argument( '--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label') self.parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') self.parser.add_argument( '--PAD', action='store_true', help='whether to keep ratio then pad for image resize') #Model Architecture self.parser.add_argument('--Transformation', type=str, default='TPS', help='Transformation stage. None|TPS') self.parser.add_argument( '--FeatureExtraction', type=str, default='ResNet', help='FeatureExtraction stage. VGG|RCNN|ResNet') self.parser.add_argument('--SequenceModeling', type=str, default='BiLSTM', help='SequenceModeling stage. None|BiLSTM') self.parser.add_argument('--Prediction', type=str, default='Attn', help='Prediction stage. 
CTC|Attn') self.parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') self.parser.add_argument( '--input_channel', type=int, default=1, help='the number of input channel of Feature extractor') self.parser.add_argument( '--output_channel', type=int, default=512, help='the number of output channel of Feature extractor') self.parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') #self.opt = self.parser.parse_args() self.opt, unknown = self.parser.parse_known_args() #self.opt, unknown = self.parser.parse_known_args() if 'CTC' in self.opt.Prediction: self.converter = CTCLabelConverter(self.opt.character) else: self.converter = AttnLabelConverter(self.opt.character) self.opt.num_class = len(self.converter.character) #print(opt.rgb) if self.opt.rgb: self.opt.input_channel = 3 self.opt.num_gpu = torch.cuda.device_count() self.opt.batch_size = 192 #self.opt.batch_size = 3 self.opt.workers = 0 self.model = Model(self.opt) #image to text model (2nd model) self.model = torch.nn.DataParallel(self.model).to(self.device) self.model.load_state_dict( torch.load( 'deep-text-recognition-benchmark/TPS-ResNet-BiLSTM-Attn.pth', map_location=self.device)) self.model.eval() def copyStateDict(self, state_dict): if list(state_dict.keys())[0].startswith("module"): start_idx = 1 else: start_idx = 0 new_state_dict = OrderedDict() for k, v in state_dict.items(): name = ".".join(k.split(".")[start_idx:]) new_state_dict[name] = v return new_state_dict def test_net(self, net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None): t0 = time.time() # resize img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio( image, self.canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=self.mag_ratio) ratio_h = ratio_w = 1 / target_ratio # preprocessing x = imgproc.normalizeMeanVariance(img_resized) x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w] x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w] if cuda: x = x.cuda() # forward pass with torch.no_grad(): y, feature = net(x) # make score and link map score_text = y[0, :, :, 0].cpu().data.numpy() score_link = y[0, :, :, 1].cpu().data.numpy() # refine link if refine_net is not None: with torch.no_grad(): y_refiner = refine_net(y, feature) score_link = y_refiner[0, :, :, 0].cpu().data.numpy() t0 = time.time() - t0 t1 = time.time() # Post-processing boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly) # coordinate adjustment boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h) polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h) for k in range(len(polys)): if polys[k] is None: polys[k] = boxes[k] t1 = time.time() - t1 # render results (optional) render_img = score_text.copy() render_img = np.hstack((render_img, score_link)) ret_score_text = imgproc.cvt2HeatmapImg(render_img) if self.show_time: print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1)) return boxes, polys, ret_score_text def extract_text(self): l = sorted(os.listdir(self.i_folder)) img_to_index = {} count = 0 for full_file in l: split_file = full_file.split(".") filename = split_file[0] img_to_index[count] = filename #print(count, filename) count += 1 #print(filename) file_extension = "." 
+ split_file[1] #print(filename, file_extension) image = imgproc.loadImage(self.i_folder + full_file) bboxes, polys, score_text = self.test_net( self.net, image, self.text_threshold, self.link_threshold, self.low_text, self.cuda, self.poly, self.refine_net) img = cv2.imread(self.i_folder + filename + file_extension) rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) points = [] order = [] for i in range(0, len(bboxes)): sample_bbox = bboxes[i] min_point = sample_bbox[0] max_point = sample_bbox[2] for j, p in enumerate(sample_bbox): if (p[0] <= min_point[0]): min_point = (p[0], min_point[1]) if (p[1] <= min_point[1]): min_point = (min_point[0], p[1]) if (p[0] >= max_point[0]): max_point = (p[0], max_point[1]) if (p[1] >= max_point[1]): max_point = (max_point[0], p[1]) min_point = (max(min(len(rgb_img[0]), min_point[0]), 0), max(min(len(rgb_img), min_point[1]), 0)) max_point = (max(min(len(rgb_img[0]), max_point[0]), 0), max(min(len(rgb_img), max_point[1]), 0)) points.append((min_point, max_point)) order.append(0) num_ordered = 0 rows_ordered = 0 points_sorted = [] ordered_points_index = 0 order_sorted = [] while (num_ordered < len(points)): #find lowest-y that is unordered min_y = len(rgb_img) min_y_index = -1 for i in range(0, len(points)): if (order[i] == 0): if (points[i][0][1] <= min_y): min_y = points[i][0][1] min_y_index = i rows_ordered += 1 order[min_y_index] = rows_ordered num_ordered += 1 points_sorted.append(points[min_y_index]) order_sorted.append(rows_ordered) ordered_points_index = len(points_sorted) - 1 # Group bboxes that are on the same row max_y = points[min_y_index][1][1] range_y = max_y - min_y for i in range(0, len(points)): if (order[i] == 0): min_y_i = points[i][0][1] max_y_i = points[i][1][1] range_y_i = max_y_i - min_y_i if (max_y_i >= min_y and min_y_i <= max_y): overlap = (min(max_y_i, max_y) - max(min_y_i, min_y)) / (max( 1, min(range_y, range_y_i))) if (overlap >= 0.30): order[i] = rows_ordered num_ordered += 1 min_x_i = points[i][0][0] for j in range(ordered_points_index, len(points_sorted) + 1): if (j < len(points_sorted) ): #insert before min_x_j = points_sorted[j][0][0] if (min_x_i < min_x_j): points_sorted.insert(j, points[i]) order_sorted.insert( j, rows_ordered) break else: #insert at the end of array points_sorted.insert(j, points[i]) order_sorted.insert(j, rows_ordered) break for i in range(0, len(points_sorted)): min_point = points_sorted[i][0] max_point = points_sorted[i][1] mask_file = self.result_folder + filename + "_" + str( order_sorted[i]) + "_" + str(i) + file_extension crop_image = rgb_img[int(min_point[1]):int(max_point[1]), int(min_point[0]):int(max_point[0])] #print(filename, min_point, max_point, len(rgb_img), len(rgb_img[0])) cv2.imwrite(mask_file, crop_image) AlignCollate_demo = AlignCollate(imgH=self.opt.imgH, imgW=self.opt.imgW, keep_ratio_with_pad=self.opt.PAD) demo_data = RawDataset(root=self.result_folder, opt=self.opt) # use RawDataset demo_loader = torch.utils.data.DataLoader( demo_data, batch_size=self.opt.batch_size, shuffle=False, num_workers=int(self.opt.workers), collate_fn=AlignCollate_demo, pin_memory=True) f = open(self.extract_text_file, "w") count = -1 curr_order = 1 curr_filename = "" output_string = "" end_line = "[SEP] " with torch.no_grad(): for image_tensors, image_path_list in demo_loader: batch_size = image_tensors.size(0) image = image_tensors.to(self.device) #image = (torch.from_numpy(crop_image).unsqueeze(0)).to(device) #print(image_path_list) #print(image.size()) length_for_pred = 
torch.IntTensor([self.opt.batch_max_length] * batch_size).to(self.device) text_for_pred = torch.LongTensor(batch_size, self.opt.batch_max_length + 1).fill_(0).to(self.device) preds = self.model(image, text_for_pred, is_train=False) _, preds_index = preds.max(2) preds_str = self.converter.decode(preds_index, length_for_pred) for path, p in zip(image_path_list, preds_str): #print(path) if 'Attn' in self.opt.Prediction: pred_EOS = p.find('[s]') p = p[: pred_EOS] # prune after "end of sentence" token ([s]) path_info = path[len(self.result_folder):].split( ".")[0].split( "_" ) #ASSUMES FILE EXTENSION OF SIZE 4 (.PNG, .JPG, ETC) #print(curr_filename) #print(path_info[0]) #print("PATHINFO: ",path_info[0]) if (not (curr_filename == path_info[0])): if (not (curr_filename == "")): f.write(str(count) + "\n") f.write(curr_filename + "\n") f.write(output_string + "\n\n") count += 1 curr_filename = img_to_index[count] #path_info[0] #print("CURRFILE: ", curr_filename) while (not (curr_filename == path_info[0])): f.write(str(count) + "\n") f.write(curr_filename + "\n") f.write("\n\n") count += 1 curr_filename = img_to_index[count] #path_info[0] #print("CURRFILE: ", curr_filename) output_string = "" curr_order = 1 if (int(path_info[1]) > curr_order): curr_order += 1 output_string += end_line output_string += p + " " f.write(str(count) + "\n") f.write(curr_filename + "\n") f.write(output_string + "\n\n") f.close() #Go through each image in the i_folder and crop out text #generate text and write to text file def get_item(self, index): f = open(self.extract_text_file, "r") Lines = f.readlines() return (Lines[4 * index + 2][:-1]) # read text file #TEST #t_e = TextExtractor("data/mmimdb-256/dataset-resized-256max/dev_n/images/","text_extract_output.txt") #t_e.extract_text() #text = t_e.get_item(1) #print(text)
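# The reading-order logic inside extract_text is easy to lose in the indexing;
# this standalone sketch (our names, plain tuples) reproduces the same idea:
# greedy top-to-bottom row grouping with a 30% vertical-overlap test against
# the shorter box, then left-to-right ordering within each row.
def group_boxes_into_rows(boxes, min_overlap=0.30):
    """boxes: list of ((min_x, min_y), (max_x, max_y)); returns a list of rows."""
    remaining = sorted(boxes, key=lambda b: b[0][1])  # by top edge
    rows = []
    while remaining:
        anchor = remaining.pop(0)          # topmost unassigned box starts a row
        row, keep = [anchor], []
        top, bottom = anchor[0][1], anchor[1][1]
        for b in remaining:
            b_top, b_bottom = b[0][1], b[1][1]
            overlap = min(bottom, b_bottom) - max(top, b_top)
            shorter = max(1, min(bottom - top, b_bottom - b_top))
            if overlap / shorter >= min_overlap:
                row.append(b)              # same row as the anchor
            else:
                keep.append(b)
        remaining = keep
        rows.append(sorted(row, key=lambda b: b[0][0]))  # left-to-right
    return rows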
def demo(opt):
    """ model configuration """
    converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size, shuffle=False,
        num_workers=int(opt.workers), collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            all_pred_strs = []
            all_confidence_scores = []
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            predss = model(image, text_for_pred, is_train=False)[0]

            for i, preds in enumerate(predss):
                confidence_score_list = []
                pred_str_list = []
                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)
                preds_prob = F.softmax(preds, dim=2)
                preds_max_prob, _ = preds_prob.max(dim=2)
                for pred, pred_max_prob in zip(preds_str, preds_max_prob):
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_str_list.append(pred)
                    pred_max_prob = pred_max_prob[:pred_EOS]
                    # calculate confidence score (= product of pred_max_prob)
                    try:
                        confidence_score = pred_max_prob.cumprod(dim=0)[-1].cpu().numpy()
                    except IndexError:
                        confidence_score = 0  # empty pred case after pruning the [s] token
                    confidence_score_list.append(confidence_score)
                all_pred_strs.append(pred_str_list)
                all_confidence_scores.append(confidence_score_list)

            all_confidence_scores = np.array(all_confidence_scores)
            all_pred_strs = np.array(all_pred_strs)
            best_pred_index = np.argmax(all_confidence_scores, axis=0)
            best_pred_index = np.expand_dims(best_pred_index, axis=0)
            # keep, per image, the prediction from the block with the highest confidence
            all_pred_strs = np.take_along_axis(all_pred_strs, best_pred_index, axis=0)[0]
            all_confidence_scores = np.take_along_axis(all_confidence_scores, best_pred_index, axis=0)[0]

            log = open('./log_demo_result.txt', 'a')  # append so earlier batches are kept
            dashed_line = '-' * 80
            head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'
            print(f'{dashed_line}\n{head}\n{dashed_line}')
            log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')
            for img_name, pred, confidence_score in zip(
                    image_path_list, all_pred_strs, all_confidence_scores):
                print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
                log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')
            log.close()
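# demo() scores each decoded word as the product of its per-step max softmax
# probabilities; the cumprod(...)[-1] idiom is compact but opaque, so here is
# the same computation isolated (our helper name; returns 0.0 for empty preds):
import torch


def sequence_confidence(step_probs):
    """step_probs: 1-D tensor of softmax maxima for the kept (pre-[s]) steps."""
    if step_probs.numel() == 0:
        return 0.0  # empty prediction after pruning the [s] token
    return step_probs.cumprod(dim=0)[-1].item()

# sequence_confidence(torch.tensor([0.99, 0.95, 0.90]))  # ~0.8465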
class ModelWrapper(Model):
    def __init__(self, opt):
        if 'CTC' in opt.Prediction:
            self.converter = CTCLabelConverter(opt.character)
        else:
            self.converter = AttnLabelConverter(opt.character)
        opt.num_class = len(self.converter.character)
        if opt.rgb:
            opt.input_channel = 3
        super().__init__(opt)
        # NOTE: rebinding `self = torch.nn.DataParallel(self)` inside __init__
        # has no effect on the constructed object, so we keep the bare module
        # and strip the 'module.' prefix that DataParallel training added to
        # the checkpoint keys.
        print('loading pretrained model from %s' % opt.saved_model)
        state_dict = torch.load(opt.saved_model, map_location=device)
        state_dict = {k[len('module.'):] if k.startswith('module.') else k: v
                      for k, v in state_dict.items()}
        self.load_state_dict(state_dict)
        self.to(device)
        self.opt = opt

    def predict(self, img):
        self.eval()
        batch_size = 1
        with torch.no_grad():
            AlignCollate_demo = AlignCollate(imgH=self.opt.imgH,
                                             imgW=self.opt.imgW,
                                             keep_ratio_with_pad=self.opt.PAD)
            transform_PIL = transforms.ToPILImage()
            image = [transform_PIL(img)]
            # AlignCollate expects a batch of (image, label) pairs
            image = AlignCollate_demo(list(zip(image, [''])))[0].to(device)
            length_for_pred = torch.IntTensor([self.opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(
                batch_size, self.opt.batch_max_length + 1).fill_(0).to(device)
            if 'CTC' in self.opt.Prediction:
                preds = self(image, text_for_pred).log_softmax(2)
                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_index = preds_index.view(-1)
                preds_str = self.converter.decode(preds_index.data, preds_size.data)
            else:
                preds = self(image, text_for_pred, is_train=False)
                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = self.converter.decode(preds_index, length_for_pred)
            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            pred_max_prob = preds_max_prob[0]
            pred = preds_str[0]
            if 'Attn' in self.opt.Prediction:
                pred_EOS = pred.find('[s]')
                pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                pred_max_prob = pred_max_prob[:pred_EOS]
            confidence_score = pred_max_prob.cumprod(dim=0)[-1]
            return pred, confidence_score
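# A minimal call sketch (ours): given an `opt` namespace with the usual
# benchmark fields (imgH, imgW, PAD, Prediction, character, saved_model, rgb)
# and an HxW or HxWxC uint8 array `img` (ToPILImage inside predict converts it):
#
#     model = ModelWrapper(opt)
#     pred, conf = model.predict(img)   # e.g. ('available', tensor(0.93))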
class TextReader(object):
    def __init__(self, args):
        path = os.path.abspath(__file__)
        dir_path = os.path.dirname(path)
        # per-instance options (a class-level SimpleNamespace would be shared
        # across instances and silently leak settings between readers)
        self.opts = SimpleNamespace()
        self.trmodel = None
        self.device = None
        self.opts.workers = args.get('workers', 5)  # number of data loading workers
        ## data processing args
        self.opts.batch_size = args.get('batch_size', 50)  # input batch size
        self.opts.batch_max_length = args.get('batch_max_length', 25)  # maximum label length
        self.opts.imgH = args.get('imgH', 32)  # the height of the input image
        self.opts.imgW = args.get('imgW', 100)  # the width of the input image
        self.opts.rgb = args.get('rgb', False)  # use rgb input
        self.opts.character = args.get(
            'character',
            '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~'
        )  # character label set
        # self.opts.character = args.get('character', '0123456789abcdefghijklmnopqrstuvwxyz')  # lowercase-only alternative
        self.opts.sensitive = args.get('sensitive', True)  # case-sensitive character mode
        self.opts.PAD = args.get('PAD', True)  # whether to keep ratio then pad for image resize
        ## Model architecture
        self.opts.Transformation = args.get('Transformation', 'TPS')  # Transformation stage. None|TPS
        self.opts.FeatureExtraction = args.get('FeatureExtraction', 'ResNet')  # FeatureExtraction stage. VGG|RCNN|ResNet
        self.opts.SequenceModeling = args.get('SequenceModeling', 'BiLSTM')  # SequenceModeling stage. None|BiLSTM
        self.opts.Prediction = args.get('Prediction', 'Attn')  # Prediction stage. CTC|Attn
        self.opts.num_fiducial = args.get('num_fiducial', 20)  # number of fiducial points of TPS-STN
        self.opts.input_channel = args.get('input_channel', 1)  # number of input channels of the feature extractor
        self.opts.output_channel = args.get('output_channel', 512)  # number of output channels of the feature extractor
        self.opts.hidden_size = args.get('hidden_size', 256)  # size of the LSTM hidden state
        # count GPUs when CUDA is available (the original condition was inverted)
        self.opts.num_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0
        self.opts.saved_model = args.get(
            'saved_model',
            dir_path + "/models/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth")
        if self.opts.sensitive:
            self.opts.character = string.printable[:-6]  # same as the ASTER setting (94 chars)
        if 'CTC' in self.opts.Prediction:
            self.converter = CTCLabelConverter(self.opts.character)
        else:
            self.converter = AttnLabelConverter(self.opts.character)
        self.opts.num_class = len(self.converter.character)
        if self.opts.rgb:
            self.opts.input_channel = 3
        self.pre_load_model()

    #
    # private function to preload the torch model
    #
    def pre_load_model(self):
        opt = self.opts
        print("preloading the model with opts " + str(opt))
        self.trmodel = Model(opt)
        self.trmodel = torch.nn.DataParallel(self.trmodel)
        if torch.cuda.is_available():
            self.trmodel = self.trmodel.cuda()
            self.device = torch.device('cuda:0')
        else:
            self.device = torch.device('cpu')
        # load model
        print('loading pretrained model from %s' % self.opts.saved_model)
        if torch.cuda.is_available():
            self.trmodel.load_state_dict(torch.load(self.opts.saved_model))
        else:
            self.trmodel.load_state_dict(
                torch.load(opt.saved_model, map_location='cpu'))
        self.trmodel.eval()

    ##
    ##
    ##
    def predictAllImagesInFolder(self, src_path):
        opt = self.opts
        AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                         keep_ratio_with_pad=opt.PAD)
        demo_data = RawDataset(root=src_path, opt=opt)  # use RawDataset
        demo_loader = torch.utils.data.DataLoader(
            demo_data, batch_size=opt.batch_size, shuffle=False,
            num_workers=int(opt.workers), collate_fn=AlignCollate_demo,
            pin_memory=torch.cuda.is_available())
        results = []
        for image_tensors, image_path_list in demo_loader:
            preds_str = self.predict(image_tensors)
            for img_name, pred in zip(image_path_list, preds_str):
                if 'Attn' in opt.Prediction:
                    pred = pred[:pred.find('[s]')]  # prune after "end of sentence" token ([s])
                results.append(f'{os.path.basename(img_name)},{pred}')
        return results

    ##
    ##
    ##
    def predict(self, image_tensors):
        print("############# About to predict for next batch ****")
        batch_size = image_tensors.size(0)
        with torch.no_grad():
            image = image_tensors.to(self.device)
            length_for_pred = torch.IntTensor([self.opts.batch_max_length] * batch_size)
            text_for_pred = torch.LongTensor(
                batch_size, self.opts.batch_max_length + 1).fill_(0)
            if 'CTC' in self.opts.Prediction:
                preds = self.trmodel(image, text_for_pred).log_softmax(2)
                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.permute(1, 0, 2).max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = self.converter.decode(preds_index.data, preds_size.data)
            else:
                preds = self.trmodel(image, text_for_pred, is_train=False)
                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = self.converter.decode(preds_index, length_for_pred)
        return preds_str
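# Usage sketch (ours): any key omitted from the dict falls back to the
# defaults above; assumes the default checkpoint exists under ./models/.
reader = TextReader({'batch_size': 16, 'sensitive': True})
for line in reader.predictAllImagesInFolder('demo_image/'):
    print(line)   # "<filename>,<prediction>"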
class TextRecognition():
    def __init__(self):
        self.opt = Recognition_Option()
        if self.opt.sensitive:
            self.opt.character = string.printable[:-6]  # same as the ASTER setting (94 chars)
        cudnn.benchmark = False
        cudnn.deterministic = True
        self.opt.num_gpu = torch.cuda.device_count()

    def load_net(self):
        """ model configuration """
        if 'CTC' in self.opt.Prediction:
            self.converter = CTCLabelConverter(self.opt.character)
        else:
            self.converter = AttnLabelConverter(self.opt.character)
        self.opt.num_class = len(self.converter.character)
        if self.opt.rgb:
            self.opt.input_channel = 3
        print("input channel: {}".format(self.opt.input_channel))
        self.model = Model(self.opt)
        print('model input parameters', self.opt.imgH, self.opt.imgW,
              self.opt.num_fiducial, self.opt.input_channel,
              self.opt.output_channel, self.opt.hidden_size,
              self.opt.num_class, self.opt.batch_max_length,
              self.opt.Transformation, self.opt.FeatureExtraction,
              self.opt.SequenceModeling, self.opt.Prediction)
        self.model = torch.nn.DataParallel(self.model).to(device)

        # load model
        print('loading pretrained model from %s' % self.opt.saved_model)
        state_dict = torch.load(self.opt.saved_model, map_location=device)
        self.model.load_state_dict(state_dict)
        # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
        self.AlignCollate_demo = AlignCollate(imgH=self.opt.imgH, imgW=self.opt.imgW,
                                              keep_ratio_with_pad=self.opt.PAD)

    def predict(self, image_list):
        if len(image_list) <= 0:
            return [""]
        demo_data = RawDataset(image_list, opt=self.opt)  # use RawDataset
        demo_loader = torch.utils.data.DataLoader(
            demo_data, batch_size=self.opt.batch_size, shuffle=False,
            num_workers=int(self.opt.workers),
            collate_fn=self.AlignCollate_demo, pin_memory=True)

        # predict
        ret = []
        self.model.eval()
        with torch.no_grad():
            for image_tensors, image_path_list in demo_loader:
                batch_size = image_tensors.size(0)
                image = image_tensors.to(device)
                # For max length prediction
                length_for_pred = torch.IntTensor(
                    [self.opt.batch_max_length] * batch_size).to(device)
                text_for_pred = torch.LongTensor(
                    batch_size, self.opt.batch_max_length + 1).fill_(0).to(device)

                if 'CTC' in self.opt.Prediction:
                    preds = self.model(image, text_for_pred)
                    # Select max probability (greedy decoding) then decode index to character
                    preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                    _, preds_index = preds.max(2)
                    preds_index = preds_index.view(-1)
                    preds_str = self.converter.decode(preds_index.data, preds_size.data)
                else:
                    preds = self.model(image, text_for_pred, is_train=False)
                    # select max probability (greedy decoding) then decode index to character
                    _, preds_index = preds.max(2)
                    preds_str = self.converter.decode(preds_index, length_for_pred)

                preds_prob = F.softmax(preds, dim=2)
                preds_max_prob, _ = preds_prob.max(dim=2)
                for img_name, pred, pred_max_prob in zip(
                        image_path_list, preds_str, preds_max_prob):
                    if 'Attn' in self.opt.Prediction:
                        pred_EOS = pred.find('[s]')
                        pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                        pred_max_prob = pred_max_prob[:pred_EOS]
                    # calculate confidence score (= product of pred_max_prob)
                    confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                    # drop low-confidence words
                    if confidence_score <= 0.4:
                        pred = "None"
                    ret.append(pred)
                    print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
        return ret
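# Usage sketch (ours): assumes Recognition_Option() supplies the usual fields
# (saved_model, Prediction, character, batch_size, workers, ...) and that the
# crop files already exist.
recognizer = TextRecognition()
recognizer.load_net()
words = recognizer.predict(['crops/word_0.png', 'crops/word_1.png'])  # hypothetical paths
print(words)   # entries with confidence <= 0.4 come back as "None"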
def test(opt): lib.print_model_settings(locals().copy()) if 'Attn' in opt.Prediction: converter = AttnLabelConverter(opt.character) text_len = opt.batch_max_length + 2 else: converter = CTCLabelConverter(opt.character) text_len = opt.batch_max_length opt.classes = converter.character """ dataset preparation """ if not opt.data_filtering_off: print( 'Filtering the images containing characters which are not in opt.character' ) print( 'Filtering the images whose label is longer than opt.batch_max_length' ) # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) valid_dataset = LmdbDataset(root=opt.test_data, opt=opt) test_data_sampler = data_sampler(valid_dataset, shuffle=False, distributed=False) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=opt.batch_size, shuffle= False, # 'True' to check training progress with validation function. sampler=test_data_sampler, num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True, drop_last=False) print('-' * 80) opt.num_class = len(converter.character) ocrModel = ModelV1(opt).to(device) ## Loading pre-trained files print(f'loading pretrained ocr model from {opt.saved_ocr_model}') checkpoint = torch.load(opt.saved_ocr_model, map_location=lambda storage, loc: storage) ocrModel.load_state_dict(checkpoint) evalCntr = 0 fCntr = 0 c1_s1_input_correct = 0.0 c1_s1_input_ed_correct = 0.0 # pdb.set_trace() for vCntr, (image_input_tensors, labels_gt) in enumerate(valid_loader): print(vCntr) image_input_tensors = image_input_tensors.to(device) text_gt, length_gt = converter.encode( labels_gt, batch_max_length=opt.batch_max_length) with torch.no_grad(): currBatchSize = image_input_tensors.shape[0] # text_for_pred = torch.LongTensor(opt.batch_size, opt.batch_max_length + 1).fill_(0).to(device) length_for_pred = torch.IntTensor([opt.batch_max_length] * currBatchSize).to(device) #Run OCR prediction if 'CTC' in opt.Prediction: preds = ocrModel(image_input_tensors, text_gt, is_train=False) preds_size = torch.IntTensor([preds.size(1)] * image_input_tensors.shape[0]) _, preds_index = preds.max(2) preds_str_gt_1 = converter.decode(preds_index.data, preds_size.data) else: preds = ocrModel( image_input_tensors, text_gt[:, :-1], is_train=False) # align with Attention.forward _, preds_index = preds.max(2) preds_str_gt_1 = converter.decode(preds_index, length_for_pred) for idx, pred in enumerate(preds_str_gt_1): pred_EOS = pred.find('[s]') preds_str_gt_1[ idx] = pred[: pred_EOS] # prune after "end of sentence" token ([s]) for trImgCntr in range(image_input_tensors.shape[0]): #ocr accuracy # for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob): c1_s1_input_gt = labels_gt[trImgCntr] c1_s1_input_ocr = preds_str_gt_1[trImgCntr] if c1_s1_input_gt == c1_s1_input_ocr: c1_s1_input_correct += 1 # ICDAR2019 Normalized Edit Distance if len(c1_s1_input_gt) == 0 or len(c1_s1_input_ocr) == 0: c1_s1_input_ed_correct += 0 elif len(c1_s1_input_gt) > len(c1_s1_input_ocr): c1_s1_input_ed_correct += 1 - edit_distance( c1_s1_input_ocr, c1_s1_input_gt) / len(c1_s1_input_gt) else: c1_s1_input_ed_correct += 1 - edit_distance( c1_s1_input_ocr, c1_s1_input_gt) / len(c1_s1_input_ocr) evalCntr += 1 fCntr += 1 avg_c1_s1_input_wer = c1_s1_input_correct / float(evalCntr) avg_c1_s1_input_cer = c1_s1_input_ed_correct / float(evalCntr) # if not(opt.realVaData): with 
open(os.path.join(opt.exp_dir, opt.exp_name, 'log_test.txt'), 'a') as log:
        # test metrics (word accuracy and normalized-edit-distance character accuracy)
        loss_log = f'Test Input Word Acc: {avg_c1_s1_input_wer:0.5f}, Test Input Char Acc: {avg_c1_s1_input_cer:0.5f}'
        print(loss_log)
        log.write(loss_log + "\n")
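# Both test() variants score recognition with the ICDAR2019 normalized edit
# distance; since the inline branching is easy to misread, here is the same
# metric as a standalone sketch (pure-Python Levenshtein, our helper names):
def levenshtein(a, b):
    """Plain DP edit distance between two strings."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        curr = [i]
        for j, cb in enumerate(b, 1):
            curr.append(min(prev[j] + 1,                 # deletion
                            curr[j - 1] + 1,             # insertion
                            prev[j - 1] + (ca != cb)))   # substitution
        prev = curr
    return prev[-1]


def normalized_ed_score(gt, pred):
    """ICDAR2019-style per-sample score: 1 - ED / len(longer string).

    An empty ground truth or prediction scores 0, matching the branches above.
    """
    if len(gt) == 0 or len(pred) == 0:
        return 0.0
    return 1 - levenshtein(pred, gt) / max(len(gt), len(pred))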
def test(opt): lib.print_model_settings(locals().copy()) if 'Attn' in opt.Prediction: converter = AttnLabelConverter(opt.character) text_len = opt.batch_max_length+2 else: converter = CTCLabelConverter(opt.character) text_len = opt.batch_max_length opt.classes = converter.character """ dataset preparation """ if not opt.data_filtering_off: print('Filtering the images containing characters which are not in opt.character') print('Filtering the images whose label is longer than opt.batch_max_length') # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 # AlignCollate_valid = AlignPairImgCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) AlignCollate_valid = AlignPairImgCollate_Test(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) valid_dataset = LmdbTestStyleContentDataset(root=opt.test_data, opt=opt, dataMode=opt.realVaData) test_data_sampler = data_sampler(valid_dataset, shuffle=False, distributed=False) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=opt.batch_size, shuffle=False, # 'True' to check training progress with validation function. sampler=test_data_sampler, num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True, drop_last=False) print('-' * 80) AlignCollate_text = AlignSynthTextCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) text_dataset = text_gen_synth(opt) text_data_sampler = data_sampler(text_dataset, shuffle=True, distributed=False) text_loader = torch.utils.data.DataLoader( text_dataset, batch_size=opt.batch_size, shuffle=False, sampler=text_data_sampler, num_workers=int(opt.workers), collate_fn=AlignCollate_text, drop_last=True) opt.num_class = len(converter.character) text_loader = sample_data(text_loader, text_data_sampler, False) c_code_size = opt.latent if opt.cEncode == 'mlp': cEncoder = GlobalContentEncoder(opt.num_class, text_len, opt.char_embed_size, c_code_size).to(device) elif opt.cEncode == 'cnn': if opt.contentNorm == 'in': cEncoder = ResNet_StyleExtractor_WIN(1, opt.latent).to(device) else: cEncoder = ResNet_StyleExtractor(1, opt.latent).to(device) if opt.styleNorm == 'in': styleModel = ResNet_StyleExtractor_WIN(opt.input_channel, opt.latent).to(device) else: styleModel = ResNet_StyleExtractor(opt.input_channel, opt.latent).to(device) ocrModel = ModelV1(opt).to(device) g_ema = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, content_dim=c_code_size, channel_multiplier=opt.channel_multiplier).to(device) g_ema.eval() bestModelError=1e5 ## Loading pre-trained files print(f'loading pretrained ocr model from {opt.saved_ocr_model}') checkpoint = torch.load(opt.saved_ocr_model, map_location=lambda storage, loc: storage) ocrModel.load_state_dict(checkpoint) print(f'loading pretrained synth model from {opt.saved_synth_model}') checkpoint = torch.load(opt.saved_synth_model, map_location=lambda storage, loc: storage) cEncoder.load_state_dict(checkpoint['cEncoder']) styleModel.load_state_dict(checkpoint['styleModel']) g_ema.load_state_dict(checkpoint['g_ema']) iCntr=0 evalCntr=0 fCntr=0 valMSE=0.0 valSSIM=0.0 valPSNR=0.0 c1_s1_input_correct=0.0 c2_s1_gen_correct=0.0 c1_s1_input_ed_correct=0.0 c2_s1_gen_ed_correct=0.0 ims, txts = [], [] for vCntr, (image_input_tensors, image_output_tensors, labels_gt, labels_z_c, labelSynthImg, synth_z_c, input_1_shape, input_2_shape) in enumerate(valid_loader): print(vCntr) if opt.debugFlag and vCntr >10: break image_input_tensors = image_input_tensors.to(device) 
image_output_tensors = image_output_tensors.to(device) if opt.realVaData and opt.outPairFile=="": # pdb.set_trace() labels_z_c, synth_z_c = next(text_loader) labelSynthImg = labelSynthImg.to(device) synth_z_c = synth_z_c.to(device) synth_z_c = synth_z_c[:labelSynthImg.shape[0]] text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length) text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length) # print(labels_z_c) cEncoder.eval() styleModel.eval() g_ema.eval() with torch.no_grad(): if opt.cEncode == 'mlp': z_c_code = cEncoder(text_z_c) z_gt_code = cEncoder(text_gt) elif opt.cEncode == 'cnn': z_c_code = cEncoder(synth_z_c) z_gt_code = cEncoder(labelSynthImg) style = styleModel(image_input_tensors) if opt.noiseConcat or opt.zAlone: style = mixing_noise(opt.batch_size, opt.latent, opt.mixing, device, style) else: style = [style] fake_img_c1_s1, _ = g_ema(style, z_gt_code, input_is_latent=opt.input_latent) fake_img_c2_s1, _ = g_ema(style, z_c_code, input_is_latent=opt.input_latent) currBatchSize = fake_img_c1_s1.shape[0] # text_for_pred = torch.LongTensor(opt.batch_size, opt.batch_max_length + 1).fill_(0).to(device) length_for_pred = torch.IntTensor([opt.batch_max_length] * currBatchSize).to(device) #Run OCR prediction if 'CTC' in opt.Prediction: preds = ocrModel(fake_img_c1_s1, text_gt, is_train=False, inAct = opt.taskActivation) preds_size = torch.IntTensor([preds.size(1)] * currBatchSize) _, preds_index = preds.max(2) preds_str_fake_img_c1_s1 = converter.decode(preds_index.data, preds_size.data) preds = ocrModel(fake_img_c2_s1, text_z_c, is_train=False, inAct = opt.taskActivation) preds_size = torch.IntTensor([preds.size(1)] * currBatchSize) _, preds_index = preds.max(2) preds_str_fake_img_c2_s1 = converter.decode(preds_index.data, preds_size.data) preds = ocrModel(image_input_tensors, text_gt, is_train=False) preds_size = torch.IntTensor([preds.size(1)] * image_input_tensors.shape[0]) _, preds_index = preds.max(2) preds_str_gt_1 = converter.decode(preds_index.data, preds_size.data) preds = ocrModel(image_output_tensors, text_z_c, is_train=False) preds_size = torch.IntTensor([preds.size(1)] * image_output_tensors.shape[0]) _, preds_index = preds.max(2) preds_str_gt_2 = converter.decode(preds_index.data, preds_size.data) else: preds = ocrModel(fake_img_c1_s1, text_gt[:, :-1], is_train=False, inAct = opt.taskActivation) # align with Attention.forward _, preds_index = preds.max(2) preds_str_fake_img_c1_s1 = converter.decode(preds_index, length_for_pred) for idx, pred in enumerate(preds_str_fake_img_c1_s1): pred_EOS = pred.find('[s]') preds_str_fake_img_c1_s1[idx] = pred[:pred_EOS] # prune after "end of sentence" token ([s]) preds = ocrModel(fake_img_c2_s1, text_z_c[:, :-1], is_train=False, inAct = opt.taskActivation) # align with Attention.forward _, preds_index = preds.max(2) preds_str_fake_img_c2_s1 = converter.decode(preds_index, length_for_pred) for idx, pred in enumerate(preds_str_fake_img_c2_s1): pred_EOS = pred.find('[s]') preds_str_fake_img_c2_s1[idx] = pred[:pred_EOS] # prune after "end of sentence" token ([s]) preds = ocrModel(image_input_tensors, text_gt[:, :-1], is_train=False) # align with Attention.forward _, preds_index = preds.max(2) preds_str_gt_1 = converter.decode(preds_index, length_for_pred) for idx, pred in enumerate(preds_str_gt_1): pred_EOS = pred.find('[s]') preds_str_gt_1[idx] = pred[:pred_EOS] # prune after "end of sentence" token ([s]) preds = ocrModel(image_output_tensors, text_z_c[:, :-1], is_train=False) 
# align with Attention.forward _, preds_index = preds.max(2) preds_str_gt_2 = converter.decode(preds_index, length_for_pred) for idx, pred in enumerate(preds_str_gt_2): pred_EOS = pred.find('[s]') preds_str_gt_2[idx] = pred[:pred_EOS] # prune after "end of sentence" token ([s]) pathPrefix = os.path.join(opt.valDir, opt.exp_iter) os.makedirs(os.path.join(pathPrefix), exist_ok=True) for trImgCntr in range(image_output_tensors.shape[0]): if opt.outPairFile!="": labelId = 'label-' + valid_loader.dataset.pairId[fCntr] + '-' + str(fCntr) else: labelId = 'label-%09d' % valid_loader.dataset.filtered_index_list[fCntr] #evaluations valRange = (-1,+1) # pdb.set_trace() # inpTensor = skimage.img_as_ubyte(resize(tensor2im(image_input_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]))) # gtTensor = skimage.img_as_ubyte(resize(tensor2im(image_output_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]))) # predTensor = skimage.img_as_ubyte(resize(tensor2im(fake_img_c2_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]))) # inpTensor = resize(tensor2im(image_input_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]), anti_aliasing=True) # gtTensor = resize(tensor2im(image_output_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]), anti_aliasing=True) # predTensor = resize(tensor2im(fake_img_c2_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]), anti_aliasing=True) inpTensor = F.interpolate(image_input_tensors[trImgCntr].unsqueeze(0),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0])) gtTensor = F.interpolate(image_output_tensors[trImgCntr].unsqueeze(0),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0])) predTensor = F.interpolate(fake_img_c2_s1[trImgCntr].unsqueeze(0),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0])) predGTTensor = F.interpolate(fake_img_c1_s1[trImgCntr].unsqueeze(0),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0])) # inpTensor = cv2.resize(tensor2im(image_input_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr])) # gtTensor = cv2.resize(tensor2im(image_output_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr])) # predTensor = cv2.resize(tensor2im(fake_img_c2_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr])) # predGTTensor = cv2.resize(tensor2im(fake_img_c1_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr])) # inpTensor = cv2.medianBlur(cv2.resize(tensor2im(image_input_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr])),5) # gtTensor = cv2.medianBlur(cv2.resize(tensor2im(image_output_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr])),5) # predTensor = cv2.medianBlur(cv2.resize(tensor2im(fake_img_c2_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr])),5) # pdb.set_trace() if not(opt.realVaData): evalMSE = mean_squared_error(tensor2im(gtTensor.squeeze())/255, 
tensor2im(predTensor.squeeze())/255) # evalMSE = mean_squared_error(gtTensor/255, predTensor/255) evalSSIM = structural_similarity(tensor2im(gtTensor.squeeze())/255, tensor2im(predTensor.squeeze())/255, multichannel=True) # evalSSIM = structural_similarity(gtTensor/255, predTensor/255, multichannel=True) evalPSNR = peak_signal_noise_ratio(tensor2im(gtTensor.squeeze())/255, tensor2im(predTensor.squeeze())/255) # evalPSNR = peak_signal_noise_ratio(gtTensor/255, predTensor/255) # print(evalMSE,evalSSIM,evalPSNR) valMSE+=evalMSE valSSIM+=evalSSIM valPSNR+=evalPSNR #ocr accuracy # for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob): c1_s1_input_gt = labels_gt[trImgCntr] c1_s1_input_ocr = preds_str_gt_1[trImgCntr] c2_s1_gen_gt = labels_z_c[trImgCntr] c2_s1_gen_ocr = preds_str_fake_img_c2_s1[trImgCntr] if c1_s1_input_gt == c1_s1_input_ocr: c1_s1_input_correct += 1 if c2_s1_gen_gt == c2_s1_gen_ocr: c2_s1_gen_correct += 1 # ICDAR2019 Normalized Edit Distance if len(c1_s1_input_gt) == 0 or len(c1_s1_input_ocr) == 0: c1_s1_input_ed_correct += 0 elif len(c1_s1_input_gt) > len(c1_s1_input_ocr): c1_s1_input_ed_correct += 1 - edit_distance(c1_s1_input_ocr, c1_s1_input_gt) / len(c1_s1_input_gt) else: c1_s1_input_ed_correct += 1 - edit_distance(c1_s1_input_ocr, c1_s1_input_gt) / len(c1_s1_input_ocr) if len(c2_s1_gen_gt) == 0 or len(c2_s1_gen_ocr) == 0: c2_s1_gen_ed_correct += 0 elif len(c2_s1_gen_gt) > len(c2_s1_gen_ocr): c2_s1_gen_ed_correct += 1 - edit_distance(c2_s1_gen_ocr, c2_s1_gen_gt) / len(c2_s1_gen_gt) else: c2_s1_gen_ed_correct += 1 - edit_distance(c2_s1_gen_ocr, c2_s1_gen_gt) / len(c2_s1_gen_ocr) evalCntr+=1 #save generated images if opt.visFlag and iCntr>500: pass else: try: if iCntr == 0: # update website webpage = html.HTML(pathPrefix, 'Experiment name = %s' % 'Test') webpage.add_header('Testing iteration') iCntr += 1 img_path_c1_s1 = os.path.join(labelId+'_pred_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_fake_img_c1_s1[trImgCntr]+'.png') img_path_gt_1 = os.path.join(labelId+'_gt_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_gt_1[trImgCntr]+'.png') img_path_gt_2 = os.path.join(labelId+'_gt_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_gt_2[trImgCntr]+'.png') img_path_c2_s1 = os.path.join(labelId+'_pred_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_fake_img_c2_s1[trImgCntr]+'.png') ims.append([img_path_gt_1, img_path_c1_s1, img_path_gt_2, img_path_c2_s1]) content_c1_s1 = 'PSTYLE-1 '+'Text-1:' + labels_gt[trImgCntr]+' OCR:' + preds_str_fake_img_c1_s1[trImgCntr] content_gt_1 = 'OSTYLE-1 '+'GT:' + labels_gt[trImgCntr]+' OCR:' + preds_str_gt_1[trImgCntr] content_gt_2 = 'OSTYLE-1 '+'GT:' + labels_z_c[trImgCntr]+' OCR:'+preds_str_gt_2[trImgCntr] content_c2_s1 = 'PSTYLE-1 '+'Text-2:' + labels_z_c[trImgCntr]+' OCR:'+preds_str_fake_img_c2_s1[trImgCntr] txts.append([content_gt_1, content_c1_s1, content_gt_2, content_c2_s1]) utils.save_image(predGTTensor,os.path.join(pathPrefix,labelId+'_pred_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_fake_img_c1_s1[trImgCntr]+'.png'),nrow=1,normalize=True,range=(-1, 1)) # cv2.imwrite(os.path.join(pathPrefix,labelId+'_pred_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_fake_img_c1_s1[trImgCntr]+'.png'), predGTTensor) # pdb.set_trace() if not opt.zAlone: utils.save_image(inpTensor,os.path.join(pathPrefix,labelId+'_gt_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_gt_1[trImgCntr]+'.png'),nrow=1,normalize=True,range=(-1, 1)) 
utils.save_image(gtTensor,os.path.join(pathPrefix,labelId+'_gt_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_gt_2[trImgCntr]+'.png'),nrow=1,normalize=True,range=(-1, 1)) utils.save_image(predTensor,os.path.join(pathPrefix,labelId+'_pred_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_fake_img_c2_s1[trImgCntr]+'.png'),nrow=1,normalize=True,range=(-1, 1)) # imsave(os.path.join(pathPrefix,labelId+'_gt_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_gt_1[trImgCntr]+'.png'), inpTensor) # imsave(os.path.join(pathPrefix,labelId+'_gt_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_gt_2[trImgCntr]+'.png'), gtTensor) # imsave(os.path.join(pathPrefix,labelId+'_pred_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_fake_img_c2_s1[trImgCntr]+'.png'), predTensor) # cv2.imwrite(os.path.join(pathPrefix,labelId+'_gt_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_gt_1[trImgCntr]+'.png'), inpTensor) # cv2.imwrite(os.path.join(pathPrefix,labelId+'_gt_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_gt_2[trImgCntr]+'.png'), gtTensor) # cv2.imwrite(os.path.join(pathPrefix,labelId+'_pred_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_fake_img_c2_s1[trImgCntr]+'.png'), predTensor) except: print('Warning while saving validation image') fCntr += 1 webpage.add_images(ims, txts, width=256, realFlag=opt.realVaData) webpage.save() avg_valMSE = valMSE/float(evalCntr) avg_valSSIM = valSSIM/float(evalCntr) avg_valPSNR = valPSNR/float(evalCntr) avg_c1_s1_input_wer = c1_s1_input_correct/float(evalCntr) avg_c2_s1_gen_wer = c2_s1_gen_correct/float(evalCntr) avg_c1_s1_input_cer = c1_s1_input_ed_correct/float(evalCntr) avg_c2_s1_gen_cer = c2_s1_gen_ed_correct/float(evalCntr) # if not(opt.realVaData): with open(os.path.join(opt.exp_dir,opt.exp_name,'log_test.txt'), 'a') as log: # training loss and validation loss if opt.realVaData: loss_log = f'Test Input Word Acc: {avg_c1_s1_input_wer:0.5f}, Test Gen Word Acc: {avg_c2_s1_gen_wer:0.5f}, Test Input Char Acc: {avg_c1_s1_input_cer:0.5f}, Test Gen Char Acc: {avg_c2_s1_gen_cer:0.5f}' else: loss_log = f'Test MSE: {avg_valMSE:0.5f}, Test SSIM: {avg_valSSIM:0.5f}, Test PSNR: {avg_valPSNR:0.5f}, Test Input Word Acc: {avg_c1_s1_input_wer:0.5f}, Test Gen Word Acc: {avg_c2_s1_gen_wer:0.5f}, Test Input Char Acc: {avg_c1_s1_input_cer:0.5f}, Test Gen Char Acc: {avg_c2_s1_gen_cer:0.5f}' print(loss_log) log.write(loss_log+"\n")
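# The second test() accumulates MSE/SSIM/PSNR per generated image. A compact
# sketch of that evaluation on two uint8 arrays, using the same skimage.metrics
# functions as above (the helper name and [0, 1] scaling are our choices):
import numpy as np
from skimage.metrics import (mean_squared_error,
                             peak_signal_noise_ratio,
                             structural_similarity)


def image_quality(gt, pred):
    """Return (MSE, SSIM, PSNR) between two HxWxC uint8 images scaled to [0, 1]."""
    gt = gt.astype(np.float64) / 255
    pred = pred.astype(np.float64) / 255
    return (mean_squared_error(gt, pred),
            structural_similarity(gt, pred, multichannel=True),
            peak_signal_noise_ratio(gt, pred, data_range=1.0))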
def runDeepTextNet(segmentedImagesList):
    opt = argparse.Namespace(
        FeatureExtraction='ResNet', PAD=False, Prediction='Attn',
        SequenceModeling='BiLSTM', Transformation='TPS',
        batch_max_length=25, batch_size=192,
        character='0123456789abcdefghijklmnopqrstuvwxyz',
        hidden_size=256, image_folder='demo_image/', imgH=32, imgW=100,
        input_channel=1, num_class=38, num_fiducial=20, num_gpu=0,
        output_channel=512, rgb=False,
        saved_model='TPS-ResNet-BiLSTM-Attn.pth', sensitive=False, workers=4)

    # build the converter first so num_class is derived rather than hard-coded
    converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    if opt.rgb:
        opt.input_channel = 3

    device = torch.device('cpu')  # this helper runs the model on CPU
    model = Model(opt)
    model = torch.nn.DataParallel(model).to(device)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=segmentedImagesList, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size, shuffle=False,
        num_workers=int(opt.workers), collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    out_preds_texts = []
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(
                batch_size, opt.batch_max_length + 1).fill_(0).to(device)
            preds = model(image, text_for_pred, is_train=False)

            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)
            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(
                    image_path_list, preds_str, preds_max_prob):
                if 'Attn' in opt.Prediction:
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_max_prob = pred_max_prob[:pred_EOS]
                # calculate confidence score (= product of pred_max_prob)
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                out_preds_texts.append(pred)

    # join the per-crop words into a single sentence
    sentence_out = [' '.join(out_preds_texts)]
    return sentence_out
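# Usage sketch (ours): RawDataset scans a folder root, so despite the
# parameter name a directory of word crops is what this helper expects.
sentence = runDeepTextNet('segmented_words/')   # hypothetical crop folder
print(sentence)   # e.g. ['the quick brown fox']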
def train(opt): plotDir = os.path.join(opt.exp_dir, opt.exp_name, 'plots') if not os.path.exists(plotDir): os.makedirs(plotDir) lib.print_model_settings(locals().copy()) """ dataset preparation """ if not opt.data_filtering_off: print( 'Filtering the images containing characters which are not in opt.character' ) print( 'Filtering the images whose label is longer than opt.batch_max_length' ) # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 opt.select_data = opt.select_data.split('-') opt.batch_ratio = opt.batch_ratio.split('-') log = open(os.path.join(opt.exp_dir, opt.exp_name, 'log_dataset.txt'), 'a') AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) train_dataset, train_dataset_log = hierarchical_dataset( root=opt.train_data, opt=opt) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=opt.batch_size, shuffle= True, # 'True' to check training progress with validation function. num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True) log.write(train_dataset_log) print('-' * 80) valid_dataset, valid_dataset_log = hierarchical_dataset( root=opt.valid_data, opt=opt) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=opt.batch_size, shuffle= False, # 'True' to check training progress with validation function. num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True) log.write(valid_dataset_log) print('-' * 80) log.write('-' * 80 + '\n') log.close() if 'Attn' in opt.Prediction: converter = AttnLabelConverter(opt.character) else: converter = CTCLabelConverter(opt.character) opt.num_class = len(converter.character) if opt.rgb: opt.input_channel = 3 ocrModel = ModelV1(opt) styleModel = StyleTensorEncoder(input_dim=opt.input_channel) genModel = AdaIN_Tensor_WordGenerator(opt) disModel = MsImageDisV2(opt) if opt.contentLoss == 'vis' or opt.contentLoss == 'seq': ocrCriterion = torch.nn.L1Loss() else: if 'CTC' in opt.Prediction: ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device) else: ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to( device) # ignore [GO] token = ignore index 0 vggRecCriterion = torch.nn.L1Loss() vggModel = VGGPerceptualLossModel(models.vgg19(pretrained=True), vggRecCriterion) print('model input parameters', opt.imgH, opt.imgW, opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length) # weight initialization for currModel in [styleModel, genModel, disModel]: for name, param in currModel.named_parameters(): if 'localization_fc2' in name: print(f'Skip {name} as it is already initialized') continue try: if 'bias' in name: init.constant_(param, 0.0) elif 'weight' in name: init.kaiming_normal_(param) except Exception as e: # for batchnorm. 
if 'weight' in name: param.data.fill_(1) continue styleModel = torch.nn.DataParallel(styleModel).to(device) styleModel.train() genModel = torch.nn.DataParallel(genModel).to(device) genModel.train() disModel = torch.nn.DataParallel(disModel).to(device) disModel.train() vggModel = torch.nn.DataParallel(vggModel).to(device) vggModel.eval() ocrModel = torch.nn.DataParallel(ocrModel).to(device) ocrModel.module.Transformation.eval() ocrModel.module.FeatureExtraction.eval() ocrModel.module.AdaptiveAvgPool.eval() # ocrModel.module.SequenceModeling.eval() ocrModel.module.Prediction.eval() if opt.modelFolderFlag: if len( glob.glob( os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth"))) > 0: opt.saved_synth_model = glob.glob( os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth"))[-1] if opt.saved_ocr_model != '' and opt.saved_ocr_model != 'None': print(f'loading pretrained ocr model from {opt.saved_ocr_model}') checkpoint = torch.load(opt.saved_ocr_model) ocrModel.load_state_dict(checkpoint) if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': print(f'loading pretrained synth model from {opt.saved_synth_model}') checkpoint = torch.load(opt.saved_synth_model) styleModel.load_state_dict(checkpoint['styleModel']) genModel.load_state_dict(checkpoint['genModel']) disModel.load_state_dict(checkpoint['disModel']) if opt.imgReconLoss == 'l1': recCriterion = torch.nn.L1Loss() elif opt.imgReconLoss == 'ssim': recCriterion = ssim elif opt.imgReconLoss == 'ms-ssim': recCriterion = msssim if opt.styleLoss == 'l1': styleRecCriterion = torch.nn.L1Loss() elif opt.styleLoss == 'triplet': styleRecCriterion = torch.nn.TripletMarginLoss( margin=opt.tripletMargin, p=1) #for validation; check only positive pairs styleTestRecCriterion = torch.nn.L1Loss() # loss averager loss_avg = Averager() loss_avg_dis = Averager() loss_avg_gen = Averager() loss_avg_imgRecon = Averager() loss_avg_vgg_per = Averager() loss_avg_vgg_sty = Averager() loss_avg_ocr = Averager() ##---------------------------------------## # filter that only require gradient decent filtered_parameters = [] params_num = [] for p in filter(lambda p: p.requires_grad, styleModel.parameters()): filtered_parameters.append(p) params_num.append(np.prod(p.size())) for p in filter(lambda p: p.requires_grad, genModel.parameters()): filtered_parameters.append(p) params_num.append(np.prod(p.size())) print('Trainable style and generator params num : ', sum(params_num)) # setup optimizer if opt.optim == 'adam': optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay) else: optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay) print("SynthOptimizer:") print(optimizer) #filter parameters for Dis training dis_filtered_parameters = [] dis_params_num = [] for p in filter(lambda p: p.requires_grad, disModel.parameters()): dis_filtered_parameters.append(p) dis_params_num.append(np.prod(p.size())) print('Dis Trainable params num : ', sum(dis_params_num)) # setup optimizer if opt.optim == 'adam': dis_optimizer = optim.Adam(dis_filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay) else: dis_optimizer = optim.Adadelta(dis_filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay) print("DisOptimizer:") print(dis_optimizer) ##---------------------------------------## """ final options """ with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'), 'a') as opt_file: opt_log 
= '------------ Options -------------\n' args = vars(opt) for k, v in args.items(): opt_log += f'{str(k)}: {str(v)}\n' opt_log += '---------------------------------------\n' print(opt_log) opt_file.write(opt_log) """ start training """ start_iter = 0 if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': try: start_iter = int( opt.saved_synth_model.split('_')[-2].split('.')[0]) print(f'continue to train, start_iter: {start_iter}') except: pass #get schedulers scheduler = get_scheduler(optimizer, opt) dis_scheduler = get_scheduler(dis_optimizer, opt) start_time = time.time() iteration = start_iter cntr = 0 while (True): # train part if opt.lr_policy != "None": scheduler.step() dis_scheduler.step() image_input_tensors, image_gt_tensors, labels_1, labels_2 = iter( train_loader).next() cntr += 1 image_input_tensors = image_input_tensors.to(device) image_gt_tensors = image_gt_tensors.to(device) batch_size = image_input_tensors.size(0) text_1, length_1 = converter.encode( labels_1, batch_max_length=opt.batch_max_length) text_2, length_2 = converter.encode( labels_2, batch_max_length=opt.batch_max_length) #forward pass from style and word generator style = styleModel(image_input_tensors) images_recon_2 = genModel(style, text_2) #Domain discriminator: Dis update disModel.zero_grad() disCost = opt.disWeight * (disModel.module.calc_dis_loss( torch.cat((images_recon_2.detach(), image_input_tensors), dim=1), torch.cat((image_gt_tensors, image_input_tensors), dim=1))) disCost.backward() dis_optimizer.step() loss_avg_dis.add(disCost) # #[Style Encoder] + [Word Generator] update #Adversarial loss disGenCost = disModel.module.calc_gen_loss( torch.cat((images_recon_2, image_input_tensors), dim=1)) #Input reconstruction loss recCost = recCriterion(images_recon_2, image_gt_tensors) #vgg loss vggPerCost, vggStyleCost = vggModel(image_gt_tensors, images_recon_2) #ocr loss text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) if opt.contentLoss == 'vis' or opt.contentLoss == 'seq': preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False, returnFeat=opt.contentLoss) preds_gt = ocrModel(image_gt_tensors, text_for_pred, is_train=False, returnFeat=opt.contentLoss) ocrCost = ocrCriterion(preds_recon, preds_gt) else: if 'CTC' in opt.Prediction: preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False) preds_o = preds_recon.deepcopy()[:, :text_1.shape[1] - 1, :] preds_size = torch.IntTensor([preds_recon.size(1)] * batch_size) preds_recon = preds_recon.log_softmax(2).permute(1, 0, 2) ocrCost = ocrCriterion(preds_recon, text_2, preds_size, length_2) #predict ocr recognition on generated images preds_o_size = torch.IntTensor([preds_o.size(1)] * batch_size) _, preds_o_index = preds_o.max(2) labels_o_ocr = converter.decode(preds_o_index.data, preds_o_size.data) #predict ocr recognition on gt style images preds_s = ocrModel(image_input_tensors, text_for_pred, is_train=False) preds_s = preds_s[:, :text_1.shape[1] - 1, :] preds_s_size = torch.IntTensor([preds_s.size(1)] * batch_size) _, preds_s_index = preds_s.max(2) labels_s_ocr = converter.decode(preds_s_index.data, preds_s_size.data) #predict ocr recognition on gt stylecontent images preds_sc = ocrModel(image_input_tensors, text_for_pred, is_train=False) preds_sc = preds_sc[:, :text_2.shape[1] - 1, :] preds_sc_size = torch.IntTensor([preds_sc.size(1)] * batch_size) _, preds_sc_index = preds_sc.max(2) labels_sc_ocr = 
    # [Style Encoder] + [Word Generator] update
    # adversarial loss
    disGenCost = disModel.module.calc_gen_loss(
        torch.cat((images_recon_2, image_input_tensors), dim=1))

    # input reconstruction loss
    recCost = recCriterion(images_recon_2, image_gt_tensors)

    # vgg perceptual and style losses
    vggPerCost, vggStyleCost = vggModel(image_gt_tensors, images_recon_2)

    # ocr loss
    text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
    length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)

    if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
        preds_recon = ocrModel(images_recon_2, text_for_pred,
                               is_train=False, returnFeat=opt.contentLoss)
        preds_gt = ocrModel(image_gt_tensors, text_for_pred,
                            is_train=False, returnFeat=opt.contentLoss)
        ocrCost = ocrCriterion(preds_recon, preds_gt)
    else:
        if 'CTC' in opt.Prediction:
            preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False)
            # torch.Tensor has no .deepcopy(); .clone() gives the intended copy
            preds_o = preds_recon.clone()[:, :text_1.shape[1] - 1, :]
            preds_size = torch.IntTensor([preds_recon.size(1)] * batch_size)
            preds_recon = preds_recon.log_softmax(2).permute(1, 0, 2)
            ocrCost = ocrCriterion(preds_recon, text_2, preds_size, length_2)

            # predict ocr recognition on generated images
            preds_o_size = torch.IntTensor([preds_o.size(1)] * batch_size)
            _, preds_o_index = preds_o.max(2)
            labels_o_ocr = converter.decode(preds_o_index.data, preds_o_size.data)

            # predict ocr recognition on gt style images
            preds_s = ocrModel(image_input_tensors, text_for_pred, is_train=False)
            preds_s = preds_s[:, :text_1.shape[1] - 1, :]
            preds_s_size = torch.IntTensor([preds_s.size(1)] * batch_size)
            _, preds_s_index = preds_s.max(2)
            labels_s_ocr = converter.decode(preds_s_index.data, preds_s_size.data)

            # predict ocr recognition on gt stylecontent images
            # (image_gt_tensors, matching the comment, the attention branch
            # below, and the '_csGT_' file names in the validation part)
            preds_sc = ocrModel(image_gt_tensors, text_for_pred, is_train=False)
            preds_sc = preds_sc[:, :text_2.shape[1] - 1, :]
            preds_sc_size = torch.IntTensor([preds_sc.size(1)] * batch_size)
            _, preds_sc_index = preds_sc.max(2)
            labels_sc_ocr = converter.decode(preds_sc_index.data, preds_sc_size.data)
        else:
            # align with Attention.forward
            preds_recon = ocrModel(images_recon_2, text_for_pred[:, :-1], is_train=False)
            target_2 = text_2[:, 1:]  # without [GO] symbol
            ocrCost = ocrCriterion(preds_recon.view(-1, preds_recon.shape[-1]),
                                   target_2.contiguous().view(-1))

            # predict ocr recognition on generated images
            _, preds_o_index = preds_recon.max(2)
            labels_o_ocr = converter.decode(preds_o_index, length_for_pred)
            for idx, pred in enumerate(labels_o_ocr):
                pred_EOS = pred.find('[s]')
                labels_o_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

            # predict ocr recognition on gt style images
            preds_s = ocrModel(image_input_tensors, text_for_pred, is_train=False)
            _, preds_s_index = preds_s.max(2)
            labels_s_ocr = converter.decode(preds_s_index, length_for_pred)
            for idx, pred in enumerate(labels_s_ocr):
                pred_EOS = pred.find('[s]')
                labels_s_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

            # predict ocr recognition on gt stylecontent images
            preds_sc = ocrModel(image_gt_tensors, text_for_pred, is_train=False)
            _, preds_sc_index = preds_sc.max(2)
            labels_sc_ocr = converter.decode(preds_sc_index, length_for_pred)
            for idx, pred in enumerate(labels_sc_ocr):
                pred_EOS = pred.find('[s]')
                labels_sc_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

    cost = (opt.reconWeight * recCost
            + opt.disWeight * disGenCost
            + opt.vggPerWeight * vggPerCost
            + opt.vggStyWeight * vggStyleCost
            + opt.ocrWeight * ocrCost)

    styleModel.zero_grad()
    genModel.zero_grad()
    disModel.zero_grad()
    vggModel.zero_grad()
    ocrModel.zero_grad()
    cost.backward()
    optimizer.step()
    loss_avg.add(cost)

    # individual losses
    loss_avg_gen.add(opt.disWeight * disGenCost)
    loss_avg_imgRecon.add(opt.reconWeight * recCost)
    loss_avg_vgg_per.add(opt.vggPerWeight * vggPerCost)
    loss_avg_vgg_sty.add(opt.vggStyWeight * vggStyleCost)
    loss_avg_ocr.add(opt.ocrWeight * ocrCost)
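    # NOTE (sketch, not part of the original file): in the CTC branch,
    # ocrCriterion is assumed to be torch.nn.CTCLoss, which expects
    # log-probabilities of shape (T, N, C) plus per-sample input and target
    # lengths; that is why the predictions are passed through
    # .log_softmax(2).permute(1, 0, 2) above. A minimal stand-alone call:
    #
    #     ctc = torch.nn.CTCLoss(zero_infinity=True)  # hypothetical setup
    #     loss = ctc(preds.log_softmax(2).permute(1, 0, 2),
    #                text_targets, pred_lengths, target_lengths)
    #
    # In the attention branch it is assumed to be a cross-entropy loss over
    # the flattened (N*T, C) predictions, with [GO] stripped from the target.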
    # validation part
    # (to see training progress, validation also runs at iteration == 0)
    if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
        # save training images
        os.makedirs(os.path.join(opt.exp_dir, opt.exp_name,
                                 'trainImages', str(iteration)), exist_ok=True)
        for trImgCntr in range(batch_size):
            try:
                if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
                    save_image(
                        tensor2im(image_input_tensors[trImgCntr].detach()),
                        os.path.join(opt.exp_dir, opt.exp_name, 'trainImages', str(iteration),
                                     str(trImgCntr) + '_sInput_' + labels_1[trImgCntr] + '.png'))
                    save_image(
                        tensor2im(image_gt_tensors[trImgCntr].detach()),
                        os.path.join(opt.exp_dir, opt.exp_name, 'trainImages', str(iteration),
                                     str(trImgCntr) + '_csGT_' + labels_2[trImgCntr] + '.png'))
                    save_image(
                        tensor2im(images_recon_2[trImgCntr].detach()),
                        os.path.join(opt.exp_dir, opt.exp_name, 'trainImages', str(iteration),
                                     str(trImgCntr) + '_csRecon_' + labels_2[trImgCntr] + '.png'))
                else:
                    # include the OCR predictions in the file names
                    save_image(
                        tensor2im(image_input_tensors[trImgCntr].detach()),
                        os.path.join(opt.exp_dir, opt.exp_name, 'trainImages', str(iteration),
                                     str(trImgCntr) + '_sInput_' + labels_1[trImgCntr]
                                     + '_' + labels_s_ocr[trImgCntr] + '.png'))
                    save_image(
                        tensor2im(image_gt_tensors[trImgCntr].detach()),
                        os.path.join(opt.exp_dir, opt.exp_name, 'trainImages', str(iteration),
                                     str(trImgCntr) + '_csGT_' + labels_2[trImgCntr]
                                     + '_' + labels_sc_ocr[trImgCntr] + '.png'))
                    save_image(
                        tensor2im(images_recon_2[trImgCntr].detach()),
                        os.path.join(opt.exp_dir, opt.exp_name, 'trainImages', str(iteration),
                                     str(trImgCntr) + '_csRecon_' + labels_2[trImgCntr]
                                     + '_' + labels_o_ocr[trImgCntr] + '.png'))
            except Exception:
                print('Warning while saving training image')

        elapsed_time = time.time() - start_time
        # for log
        with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_train.txt'), 'a') as log:
            styleModel.eval()
            genModel.eval()
            disModel.eval()
            with torch.no_grad():
                valid_loss, infer_time, length_of_data = validation_synth_v4(
                    iteration, styleModel, genModel, vggModel, ocrModel, disModel,
                    recCriterion, ocrCriterion, valid_loader, converter, opt)
            styleModel.train()
            genModel.train()
            disModel.train()

            # training and validation losses
            loss_log = (
                f'[{iteration + 1}/{opt.num_iter}] '
                f'Train Synth loss: {loss_avg.val():0.5f}, '
                f'Train Dis loss: {loss_avg_dis.val():0.5f}, '
                f'Train Gen loss: {loss_avg_gen.val():0.5f}, '
                f'Train ImgRecon loss: {loss_avg_imgRecon.val():0.5f}, '
                f'Train VGG-Per loss: {loss_avg_vgg_per.val():0.5f}, '
                f'Train VGG-Sty loss: {loss_avg_vgg_sty.val():0.5f}, '
                f'Train OCR loss: {loss_avg_ocr.val():0.5f}, '
                f'Valid Synth loss: {valid_loss[0]:0.5f}, '
                f'Valid Dis loss: {valid_loss[1]:0.5f}, '
                f'Valid Gen loss: {valid_loss[2]:0.5f}, '
                f'Valid ImgRecon loss: {valid_loss[3]:0.5f}, '
                f'Valid VGG-Per loss: {valid_loss[4]:0.5f}, '
                f'Valid VGG-Sty loss: {valid_loss[5]:0.5f}, '
                f'Valid OCR loss: {valid_loss[6]:0.5f}, '
                f'Elapsed_time: {elapsed_time:0.5f}')

            # plotting
            lib.plot.plot(os.path.join(plotDir, 'Train-Synth-Loss'), loss_avg.val().item())
            lib.plot.plot(os.path.join(plotDir, 'Train-Dis-Loss'), loss_avg_dis.val().item())
            lib.plot.plot(os.path.join(plotDir, 'Train-Gen-Loss'), loss_avg_gen.val().item())
            lib.plot.plot(os.path.join(plotDir, 'Train-ImgRecon1-Loss'), loss_avg_imgRecon.val().item())
            lib.plot.plot(os.path.join(plotDir, 'Train-VGG-Per-Loss'), loss_avg_vgg_per.val().item())
            lib.plot.plot(os.path.join(plotDir, 'Train-VGG-Sty-Loss'), loss_avg_vgg_sty.val().item())
            lib.plot.plot(os.path.join(plotDir, 'Train-OCR-Loss'), loss_avg_ocr.val().item())
            lib.plot.plot(os.path.join(plotDir, 'Valid-Synth-Loss'), valid_loss[0].item())
            lib.plot.plot(os.path.join(plotDir, 'Valid-Dis-Loss'), valid_loss[1].item())
            lib.plot.plot(os.path.join(plotDir, 'Valid-Gen-Loss'), valid_loss[2].item())
            lib.plot.plot(os.path.join(plotDir, 'Valid-ImgRecon1-Loss'), valid_loss[3].item())
            lib.plot.plot(os.path.join(plotDir, 'Valid-VGG-Per-Loss'), valid_loss[4].item())
            lib.plot.plot(os.path.join(plotDir, 'Valid-VGG-Sty-Loss'), valid_loss[5].item())
            lib.plot.plot(os.path.join(plotDir, 'Valid-OCR-Loss'), valid_loss[6].item())

            print(loss_log)
            log.write(loss_log + '\n')  # append the same line to log_train.txt

            loss_avg.reset()
            loss_avg_dis.reset()
            loss_avg_gen.reset()
            loss_avg_imgRecon.reset()
            loss_avg_vgg_per.reset()
            loss_avg_vgg_sty.reset()
            loss_avg_ocr.reset()

        lib.plot.flush()

    lib.plot.tick()

    # save model every 10000 iterations
    if iteration % 10000 == 0:
        torch.save(
            {
                'styleModel': styleModel.state_dict(),
                'genModel': genModel.state_dict(),
                'disModel': disModel.state_dict()
            },
            os.path.join(opt.exp_dir, opt.exp_name,
                         'iter_' + str(iteration + 1) + '_synth.pth'))

    if (iteration + 1) == opt.num_iter:
        print('end the training')
        sys.exit()
    iteration += 1
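# NOTE (sketch, not part of the original file): checkpoints are saved above as
# plain dicts keyed by 'styleModel' / 'genModel' / 'disModel' (state dicts of
# the DataParallel-wrapped models, so keys carry the 'module.' prefix and
# should be loaded into similarly wrapped models). A minimal resume helper
# could therefore look like this (the function name is hypothetical):
#
#     def load_synth_checkpoint(path, styleModel, genModel, disModel, device='cpu'):
#         checkpoint = torch.load(path, map_location=device)
#         styleModel.load_state_dict(checkpoint['styleModel'])
#         genModel.load_state_dict(checkpoint['genModel'])
#         disModel.load_state_dict(checkpoint['disModel'])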