def train(opt): lib.print_model_settings(locals().copy()) if 'Attn' in opt.Prediction: converter = AttnLabelConverter(opt.character) text_len = opt.batch_max_length+2 else: converter = CTCLabelConverter(opt.character) text_len = opt.batch_max_length opt.classes = converter.character """ dataset preparation """ if not opt.data_filtering_off: print('Filtering the images containing characters which are not in opt.character') print('Filtering the images whose label is longer than opt.batch_max_length') # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 log = open(os.path.join(opt.exp_dir,opt.exp_name,'log_dataset.txt'), 'a') AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) train_dataset = LmdbStyleDataset(root=opt.train_data, opt=opt) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=opt.batch_size*2, #*2 to sample different images from training encoder and discriminator real images shuffle=True, # 'True' to check training progress with validation function. num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True) print('-' * 80) valid_dataset = LmdbStyleDataset(root=opt.valid_data, opt=opt) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=opt.batch_size*2, #*2 to sample different images from training encoder and discriminator real images shuffle=False, # 'True' to check training progress with validation function. num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True) print('-' * 80) log.write('-' * 80 + '\n') log.close() text_dataset = text_gen(opt) text_loader = torch.utils.data.DataLoader( text_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers), pin_memory=True, drop_last=True) opt.num_class = len(converter.character) c_code_size = opt.latent cEncoder = GlobalContentEncoder(opt.num_class, text_len, opt.char_embed_size, c_code_size) ocrModel = ModelV1(opt) genModel = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier) g_ema = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier) disEncModel = styleGANDis(opt.size, channel_multiplier=opt.channel_multiplier, input_dim=opt.input_channel, code_s_dim=c_code_size) accumulate(g_ema, genModel, 0) # uCriterion = torch.nn.MSELoss() # sCriterion = torch.nn.MSELoss() # if opt.contentLoss == 'vis' or opt.contentLoss == 'seq': # ocrCriterion = torch.nn.L1Loss() # else: if 'CTC' in opt.Prediction: ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device) else: print('Not implemented error') sys.exit() # ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 cEncoder= torch.nn.DataParallel(cEncoder).to(device) cEncoder.train() genModel = torch.nn.DataParallel(genModel).to(device) g_ema = torch.nn.DataParallel(g_ema).to(device) genModel.train() g_ema.eval() disEncModel = torch.nn.DataParallel(disEncModel).to(device) disEncModel.train() ocrModel = torch.nn.DataParallel(ocrModel).to(device) if opt.ocrFixed: if opt.Transformation == 'TPS': ocrModel.module.Transformation.eval() ocrModel.module.FeatureExtraction.eval() ocrModel.module.AdaptiveAvgPool.eval() # ocrModel.module.SequenceModeling.eval() ocrModel.module.Prediction.eval() else: ocrModel.train() g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1) d_reg_ratio = opt.d_reg_every / 
(opt.d_reg_every + 1) optimizer = optim.Adam( list(genModel.parameters())+list(cEncoder.parameters()), lr=opt.lr * g_reg_ratio, betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio), ) dis_optimizer = optim.Adam( disEncModel.parameters(), lr=opt.lr * d_reg_ratio, betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), ) ocr_optimizer = optim.Adam( ocrModel.parameters(), lr=opt.lr, betas=(0.9, 0.99), ) ## Loading pre-trained files if opt.modelFolderFlag: if len(glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth")))>0: opt.saved_synth_model = glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth"))[-1] if opt.saved_ocr_model !='' and opt.saved_ocr_model !='None': print(f'loading pretrained ocr model from {opt.saved_ocr_model}') checkpoint = torch.load(opt.saved_ocr_model) ocrModel.load_state_dict(checkpoint) # if opt.saved_gen_model !='' and opt.saved_gen_model !='None': # print(f'loading pretrained gen model from {opt.saved_gen_model}') # checkpoint = torch.load(opt.saved_gen_model, map_location=lambda storage, loc: storage) # genModel.module.load_state_dict(checkpoint['g']) # g_ema.module.load_state_dict(checkpoint['g_ema']) if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': print(f'loading pretrained synth model from {opt.saved_synth_model}') checkpoint = torch.load(opt.saved_synth_model) # styleModel.load_state_dict(checkpoint['styleModel']) # mixModel.load_state_dict(checkpoint['mixModel']) genModel.load_state_dict(checkpoint['genModel']) g_ema.load_state_dict(checkpoint['g_ema']) disEncModel.load_state_dict(checkpoint['disEncModel']) ocrModel.load_state_dict(checkpoint['ocrModel']) optimizer.load_state_dict(checkpoint["optimizer"]) dis_optimizer.load_state_dict(checkpoint["dis_optimizer"]) ocr_optimizer.load_state_dict(checkpoint["ocr_optimizer"]) # if opt.imgReconLoss == 'l1': # recCriterion = torch.nn.L1Loss() # elif opt.imgReconLoss == 'ssim': # recCriterion = ssim # elif opt.imgReconLoss == 'ms-ssim': # recCriterion = msssim # loss averager loss_avg_dis = Averager() loss_avg_gen = Averager() loss_avg_unsup = Averager() loss_avg_sup = Averager() log_r1_val = Averager() log_avg_path_loss_val = Averager() log_avg_mean_path_length_avg = Averager() log_ada_aug_p = Averager() loss_avg_ocr_sup = Averager() loss_avg_ocr_unsup = Averager() """ final options """ with open(os.path.join(opt.exp_dir,opt.exp_name,'opt.txt'), 'a') as opt_file: opt_log = '------------ Options -------------\n' args = vars(opt) for k, v in args.items(): opt_log += f'{str(k)}: {str(v)}\n' opt_log += '---------------------------------------\n' print(opt_log) opt_file.write(opt_log) """ start training """ start_iter = 0 if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': try: start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0]) print(f'continue to train, start_iter: {start_iter}') except: pass #get schedulers scheduler = get_scheduler(optimizer,opt) dis_scheduler = get_scheduler(dis_optimizer,opt) ocr_scheduler = get_scheduler(ocr_optimizer,opt) start_time = time.time() iteration = start_iter cntr=0 mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 # loss_dict = {} accum = 0.5 ** (32 / (10 * 1000)) ada_augment = torch.tensor([0.0, 0.0], device=device) ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0 ada_aug_step = opt.ada_target / opt.ada_length r_t_stat = 0 epsilon = 10e-50 # sample_z = 
torch.randn(opt.n_sample, opt.latent, device=device) while(True): # print(cntr) # train part if opt.lr_policy !="None": scheduler.step() dis_scheduler.step() ocr_scheduler.step() image_input_tensors, _, labels, _ = iter(train_loader).next() labels_z_c = iter(text_loader).next() image_input_tensors = image_input_tensors.to(device) gt_image_tensors = image_input_tensors[:opt.batch_size].detach() real_image_tensors = image_input_tensors[opt.batch_size:].detach() labels_gt = labels[:opt.batch_size] requires_grad(cEncoder, False) requires_grad(genModel, False) requires_grad(disEncModel, True) requires_grad(ocrModel, False) text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length) text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length) z_c_code = cEncoder(text_z_c) noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device) style=[] style.append(noise_style[0]*z_c_code) if len(noise_style)>1: style.append(noise_style[1]*z_c_code) if opt.zAlone: #to validate orig style gan results newstyle = [] newstyle.append(style[0][:,:opt.latent]) if len(style)>1: newstyle.append(style[1][:,:opt.latent]) style = newstyle fake_img,_ = genModel(style, input_is_latent=opt.input_latent) # #unsupervised code prediction on generated image # u_pred_code = disEncModel(fake_img, mode='enc') # uCost = uCriterion(u_pred_code, z_code) # #supervised code prediction on gt image # s_pred_code = disEncModel(gt_image_tensors, mode='enc') # sCost = uCriterion(s_pred_code, gt_phoc_tensors) #Domain discriminator fake_pred = disEncModel(fake_img) real_pred = disEncModel(real_image_tensors) disCost = d_logistic_loss(real_pred, fake_pred) # dis_cost = disCost + opt.gamma_e*uCost + opt.beta*sCost loss_avg_dis.add(disCost) # loss_avg_sup.add(opt.beta*sCost) # loss_avg_unsup.add(opt.gamma_e * uCost) disEncModel.zero_grad() disCost.backward() dis_optimizer.step() d_regularize = cntr % opt.d_reg_every == 0 if d_regularize: real_image_tensors.requires_grad = True real_pred = disEncModel(real_image_tensors) r1_loss = d_r1_loss(real_pred, real_image_tensors) disEncModel.zero_grad() (opt.r1 / 2 * r1_loss * opt.d_reg_every + 0 * real_pred[0]).backward() dis_optimizer.step() log_r1_val.add(r1_loss) # Recognizer update if not opt.ocrFixed and not opt.zAlone: requires_grad(disEncModel, False) requires_grad(ocrModel, True) if 'CTC' in opt.Prediction: preds_recon = ocrModel(gt_image_tensors, text_gt, is_train=True) preds_size = torch.IntTensor([preds_recon.size(1)] * opt.batch_size) preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2) ocrCost = ocrCriterion(preds_recon_softmax, text_gt, preds_size, length_gt) else: print("Not implemented error") sys.exit() ocrModel.zero_grad() ocrCost.backward() # torch.nn.utils.clip_grad_norm_(ocrModel.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) ocr_optimizer.step() loss_avg_ocr_sup.add(ocrCost) else: loss_avg_ocr_sup.add(torch.tensor(0.0)) # [Word Generator] update # image_input_tensors, _, labels, _ = iter(train_loader).next() labels_z_c = iter(text_loader).next() # image_input_tensors = image_input_tensors.to(device) # gt_image_tensors = image_input_tensors[:opt.batch_size] # real_image_tensors = image_input_tensors[opt.batch_size:] # labels_gt = labels[:opt.batch_size] requires_grad(cEncoder, True) requires_grad(genModel, True) requires_grad(disEncModel, False) requires_grad(ocrModel, False) text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length) z_c_code 
= cEncoder(text_z_c) noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device) style=[] style.append(noise_style[0]*z_c_code) if len(noise_style)>1: style.append(noise_style[1]*z_c_code) if opt.zAlone: #to validate orig style gan results newstyle = [] newstyle.append(style[0][:,:opt.latent]) if len(style)>1: newstyle.append(style[1][:,:opt.latent]) style = newstyle fake_img,_ = genModel(style, input_is_latent=opt.input_latent) fake_pred = disEncModel(fake_img) disGenCost = g_nonsaturating_loss(fake_pred) if opt.zAlone: ocrCost = torch.tensor(0.0) else: #Compute OCR prediction (Reconstruction of content) # text_for_pred = torch.LongTensor(opt.batch_size, opt.batch_max_length + 1).fill_(0).to(device) # length_for_pred = torch.IntTensor([opt.batch_max_length] * opt.batch_size).to(device) if 'CTC' in opt.Prediction: preds_recon = ocrModel(fake_img, text_z_c, is_train=False) preds_size = torch.IntTensor([preds_recon.size(1)] * opt.batch_size) preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2) ocrCost = ocrCriterion(preds_recon_softmax, text_z_c, preds_size, length_z_c) else: print("Not implemented error") sys.exit() genModel.zero_grad() cEncoder.zero_grad() gen_enc_cost = disGenCost + opt.ocrWeight * ocrCost grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, retain_graph=True)[0] loss_grad_fake_OCR = 10**6*torch.mean(grad_fake_OCR**2) grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, retain_graph=True)[0] loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2) if opt.grad_balance: gen_enc_cost.backward(retain_graph=True) grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, create_graph=True, retain_graph=True)[0] grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, create_graph=True, retain_graph=True)[0] a = opt.ocrWeight * torch.div(torch.std(grad_fake_adv), epsilon+torch.std(grad_fake_OCR)) if a is None: print(ocrCost, disGenCost, torch.std(grad_fake_adv), torch.std(grad_fake_OCR)) if a>1000 or a<0.0001: print(a) ocrCost = a.detach() * ocrCost gen_enc_cost = disGenCost + ocrCost gen_enc_cost.backward(retain_graph=True) grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, create_graph=False, retain_graph=True)[0] grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, create_graph=False, retain_graph=True)[0] loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2) loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2) with torch.no_grad(): gen_enc_cost.backward() else: gen_enc_cost.backward() loss_avg_gen.add(disGenCost) loss_avg_ocr_unsup.add(opt.ocrWeight * ocrCost) optimizer.step() g_regularize = cntr % opt.g_reg_every == 0 if g_regularize: path_batch_size = max(1, opt.batch_size // opt.path_batch_shrink) # image_input_tensors, _, labels, _ = iter(train_loader).next() labels_z_c = iter(text_loader).next() # image_input_tensors = image_input_tensors.to(device) # gt_image_tensors = image_input_tensors[:path_batch_size] # labels_gt = labels[:path_batch_size] text_z_c, length_z_c = converter.encode(labels_z_c[:path_batch_size], batch_max_length=opt.batch_max_length) # text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length) z_c_code = cEncoder(text_z_c) noise_style = mixing_noise_style(path_batch_size, opt.latent, opt.mixing, device) style=[] style.append(noise_style[0]*z_c_code) if len(noise_style)>1: style.append(noise_style[1]*z_c_code) if opt.zAlone: #to validate orig style gan results newstyle = [] newstyle.append(style[0][:,:opt.latent]) if len(style)>1: 
newstyle.append(style[1][:,:opt.latent]) style = newstyle fake_img, grad = genModel(style, return_latents=True, g_path_regularize=True, mean_path_length=mean_path_length) decay = 0.01 path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) mean_path_length_orig = mean_path_length + decay * (path_lengths.mean() - mean_path_length) path_loss = (path_lengths - mean_path_length_orig).pow(2).mean() mean_path_length = mean_path_length_orig.detach().item() genModel.zero_grad() cEncoder.zero_grad() weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss if opt.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() optimizer.step() # mean_path_length_avg = ( # reduce_sum(mean_path_length).item() / get_world_size() # ) #commented above for multi-gpu , non-distributed setting mean_path_length_avg = mean_path_length accumulate(g_ema, genModel, accum) log_avg_path_loss_val.add(path_loss) log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg)) log_ada_aug_p.add(torch.tensor(ada_aug_p)) if get_rank() == 0: if wandb and opt.wandb: wandb.log( { "Generator": g_loss_val, "Discriminator": d_loss_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1": r1_val, "Path Length Regularization": path_loss_val, "Mean Path Length": mean_path_length, "Real Score": real_score_val, "Fake Score": fake_score_val, "Path Length": path_length_val, } ) # validation part if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' #generate paired content with similar style labels_z_c_1 = iter(text_loader).next() labels_z_c_2 = iter(text_loader).next() text_z_c_1, length_z_c_1 = converter.encode(labels_z_c_1, batch_max_length=opt.batch_max_length) text_z_c_2, length_z_c_2 = converter.encode(labels_z_c_2, batch_max_length=opt.batch_max_length) z_c_code_1 = cEncoder(text_z_c_1) z_c_code_2 = cEncoder(text_z_c_2) style_c1_s1 = [] style_c2_s1 = [] style_s1 = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device) style_c1_s1.append(style_s1[0]*z_c_code_1) style_c2_s1.append(style_s1[0]*z_c_code_2) if len(style_s1)>1: style_c1_s1.append(style_s1[1]*z_c_code_1) style_c2_s1.append(style_s1[1]*z_c_code_2) noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device) style_c1_s2 = [] style_c1_s2.append(noise_style[0]*z_c_code_1) if len(noise_style)>1: style_c1_s2.append(noise_style[1]*z_c_code_1) if opt.zAlone: #to validate orig style gan results newstyle = [] newstyle.append(style_c1_s1[0][:,:opt.latent]) if len(style_c1_s1)>1: newstyle.append(style_c1_s1[1][:,:opt.latent]) style_c1_s1 = newstyle style_c2_s1 = newstyle style_c1_s2 = newstyle fake_img_c1_s1, _ = g_ema(style_c1_s1, input_is_latent=opt.input_latent) fake_img_c2_s1, _ = g_ema(style_c2_s1, input_is_latent=opt.input_latent) fake_img_c1_s2, _ = g_ema(style_c1_s2, input_is_latent=opt.input_latent) if not opt.zAlone: #Run OCR prediction if 'CTC' in opt.Prediction: preds = ocrModel(fake_img_c1_s1, text_z_c_1, is_train=False) preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size) _, preds_index = preds.max(2) preds_str_fake_img_c1_s1 = converter.decode(preds_index.data, preds_size.data) preds = ocrModel(fake_img_c2_s1, text_z_c_2, is_train=False) preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size) _, preds_index = preds.max(2) preds_str_fake_img_c2_s1 = converter.decode(preds_index.data, preds_size.data) preds = ocrModel(fake_img_c1_s2, text_z_c_1, is_train=False) preds_size = 
torch.IntTensor([preds.size(1)] * opt.batch_size) _, preds_index = preds.max(2) preds_str_fake_img_c1_s2 = converter.decode(preds_index.data, preds_size.data) preds = ocrModel(gt_image_tensors, text_gt, is_train=False) preds_size = torch.IntTensor([preds.size(1)] * gt_image_tensors.shape[0]) _, preds_index = preds.max(2) preds_str_gt = converter.decode(preds_index.data, preds_size.data) else: print("Not implemented error") sys.exit() else: preds_str_fake_img_c1_s1 = [':None:'] * fake_img_c1_s1.shape[0] preds_str_gt = [':None:'] * fake_img_c1_s1.shape[0] os.makedirs(os.path.join(opt.trainDir,str(iteration)), exist_ok=True) for trImgCntr in range(opt.batch_size): try: save_image(tensor2im(fake_img_c1_s1[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_c1_s1_'+labels_z_c_1[trImgCntr]+'_ocr:'+preds_str_fake_img_c1_s1[trImgCntr]+'.png')) if not opt.zAlone: save_image(tensor2im(fake_img_c2_s1[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_c2_s1_'+labels_z_c_2[trImgCntr]+'_ocr:'+preds_str_fake_img_c2_s1[trImgCntr]+'.png')) save_image(tensor2im(fake_img_c1_s2[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_c1_s2_'+labels_z_c_1[trImgCntr]+'_ocr:'+preds_str_fake_img_c1_s2[trImgCntr]+'.png')) if trImgCntr<gt_image_tensors.shape[0]: save_image(tensor2im(gt_image_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_gt_act:'+labels_gt[trImgCntr]+'_ocr:'+preds_str_gt[trImgCntr]+'.png')) except: print('Warning while saving training image') elapsed_time = time.time() - start_time # for log with open(os.path.join(opt.exp_dir,opt.exp_name,'log_train.txt'), 'a') as log: # training loss and validation loss loss_log = f'[{iteration+1}/{opt.num_iter}] \ Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f},\ Train UnSup OCR loss: {loss_avg_ocr_unsup.val():0.5f}, Train Sup OCR loss: {loss_avg_ocr_sup.val():0.5f}, \ Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, \ Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, \ Elapsed_time: {elapsed_time:0.5f}' #plotting lib.plot.plot(os.path.join(opt.plotDir,'Train-Dis-Loss'), loss_avg_dis.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-Gen-Loss'), loss_avg_gen.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-UnSup-OCR-Loss'), loss_avg_ocr_unsup.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-Sup-OCR-Loss'), loss_avg_ocr_sup.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-r1_val'), log_r1_val.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-path_loss_val'), log_avg_path_loss_val.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-mean_path_length_avg'), log_avg_mean_path_length_avg.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-ada_aug_p'), log_ada_aug_p.val().item()) print(loss_log) loss_avg_dis.reset() loss_avg_gen.reset() loss_avg_ocr_unsup.reset() loss_avg_ocr_sup.reset() log_r1_val.reset() log_avg_path_loss_val.reset() log_avg_mean_path_length_avg.reset() log_ada_aug_p.reset() lib.plot.flush() lib.plot.tick() # save model per 1e+5 iter. 
if (iteration) % 1e+4 == 0: torch.save({ 'cEncoder':cEncoder.state_dict(), 'genModel':genModel.state_dict(), 'g_ema':g_ema.state_dict(), 'ocrModel':ocrModel.state_dict(), 'disEncModel':disEncModel.state_dict(), 'optimizer':optimizer.state_dict(), 'ocr_optimizer':ocr_optimizer.state_dict(), 'dis_optimizer':dis_optimizer.state_dict()}, os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_synth.pth')) if (iteration + 1) == opt.num_iter: print('end the training') sys.exit() iteration += 1 cntr+=1
def train(opt): lib.print_model_settings(locals().copy()) """ dataset preparation """ if not opt.data_filtering_off: print('Filtering the images containing characters which are not in opt.character') print('Filtering the images whose label is longer than opt.batch_max_length') # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 log = open(os.path.join(opt.exp_dir,opt.exp_name,'log_dataset.txt'), 'a') AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) train_dataset, train_dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=opt.batch_size, sampler=data_sampler(train_dataset, shuffle=True, distributed=opt.distributed), num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True) log.write(train_dataset_log) print('-' * 80) valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=opt.batch_size, sampler=data_sampler(train_dataset, shuffle=False, distributed=opt.distributed), num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True) log.write(valid_dataset_log) print('-' * 80) log.write('-' * 80 + '\n') log.close() if 'Attn' in opt.Prediction: converter = AttnLabelConverter(opt.character) else: converter = CTCLabelConverter(opt.character) opt.num_class = len(converter.character) # styleModel = StyleTensorEncoder(input_dim=opt.input_channel) # genModel = AdaIN_Tensor_WordGenerator(opt) # disModel = MsImageDisV2(opt) # styleModel = StyleLatentEncoder(input_dim=opt.input_channel, norm='none') # mixModel = Mixer(opt,nblk=3, dim=opt.latent) genModel = styleGANGen(opt.size, opt.latent, opt.n_mlp, opt.num_class, channel_multiplier=opt.channel_multiplier).to(device) disModel = styleGANDis(opt.size, channel_multiplier=opt.channel_multiplier, input_dim=opt.input_channel).to(device) g_ema = styleGANGen(opt.size, opt.latent, opt.n_mlp, opt.num_class, channel_multiplier=opt.channel_multiplier).to(device) ocrModel = ModelV1(opt).to(device) accumulate(g_ema, genModel, 0) # # weight initialization # for currModel in [styleModel, mixModel]: # for name, param in currModel.named_parameters(): # if 'localization_fc2' in name: # print(f'Skip {name} as it is already initialized') # continue # try: # if 'bias' in name: # init.constant_(param, 0.0) # elif 'weight' in name: # init.kaiming_normal_(param) # except Exception as e: # for batchnorm. 
# if 'weight' in name: # param.data.fill_(1) # continue if opt.contentLoss == 'vis' or opt.contentLoss == 'seq': ocrCriterion = torch.nn.L1Loss() else: if 'CTC' in opt.Prediction: ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device) else: ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 # vggRecCriterion = torch.nn.L1Loss() # vggModel = VGGPerceptualLossModel(models.vgg19(pretrained=True), vggRecCriterion) print('model input parameters', opt.imgH, opt.imgW, opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length) if opt.distributed: genModel = torch.nn.parallel.DistributedDataParallel( genModel, device_ids=[opt.local_rank], output_device=opt.local_rank, broadcast_buffers=False, ) disModel = torch.nn.parallel.DistributedDataParallel( disModel, device_ids=[opt.local_rank], output_device=opt.local_rank, broadcast_buffers=False, ) ocrModel = torch.nn.parallel.DistributedDataParallel( ocrModel, device_ids=[opt.local_rank], output_device=opt.local_rank, broadcast_buffers=False ) # styleModel = torch.nn.DataParallel(styleModel).to(device) # styleModel.train() # mixModel = torch.nn.DataParallel(mixModel).to(device) # mixModel.train() # genModel = torch.nn.DataParallel(genModel).to(device) # g_ema = torch.nn.DataParallel(g_ema).to(device) genModel.train() g_ema.eval() # disModel = torch.nn.DataParallel(disModel).to(device) disModel.train() # vggModel = torch.nn.DataParallel(vggModel).to(device) # vggModel.eval() # ocrModel = torch.nn.DataParallel(ocrModel).to(device) # if opt.distributed: # ocrModel.module.Transformation.eval() # ocrModel.module.FeatureExtraction.eval() # ocrModel.module.AdaptiveAvgPool.eval() # # ocrModel.module.SequenceModeling.eval() # ocrModel.module.Prediction.eval() # else: # ocrModel.Transformation.eval() # ocrModel.FeatureExtraction.eval() # ocrModel.AdaptiveAvgPool.eval() # # ocrModel.SequenceModeling.eval() # ocrModel.Prediction.eval() ocrModel.eval() if opt.distributed: g_module = genModel.module d_module = disModel.module else: g_module = genModel d_module = disModel g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1) d_reg_ratio = opt.d_reg_every / (opt.d_reg_every + 1) optimizer = optim.Adam( genModel.parameters(), lr=opt.lr * g_reg_ratio, betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio), ) dis_optimizer = optim.Adam( disModel.parameters(), lr=opt.lr * d_reg_ratio, betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), ) ## Loading pre-trained files if opt.modelFolderFlag: if len(glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth")))>0: opt.saved_synth_model = glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth"))[-1] if opt.saved_ocr_model !='' and opt.saved_ocr_model !='None': if not opt.distributed: ocrModel = torch.nn.DataParallel(ocrModel) print(f'loading pretrained ocr model from {opt.saved_ocr_model}') checkpoint = torch.load(opt.saved_ocr_model) ocrModel.load_state_dict(checkpoint) #temporary fix if not opt.distributed: ocrModel = ocrModel.module if opt.saved_gen_model !='' and opt.saved_gen_model !='None': print(f'loading pretrained gen model from {opt.saved_gen_model}') checkpoint = torch.load(opt.saved_gen_model, map_location=lambda storage, loc: storage) genModel.module.load_state_dict(checkpoint['g']) g_ema.module.load_state_dict(checkpoint['g_ema']) if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': print(f'loading pretrained synth model from {opt.saved_synth_model}') checkpoint = torch.load(opt.saved_synth_model) # 
styleModel.load_state_dict(checkpoint['styleModel']) # mixModel.load_state_dict(checkpoint['mixModel']) genModel.load_state_dict(checkpoint['genModel']) g_ema.load_state_dict(checkpoint['g_ema']) disModel.load_state_dict(checkpoint['disModel']) optimizer.load_state_dict(checkpoint["optimizer"]) dis_optimizer.load_state_dict(checkpoint["dis_optimizer"]) # if opt.imgReconLoss == 'l1': # recCriterion = torch.nn.L1Loss() # elif opt.imgReconLoss == 'ssim': # recCriterion = ssim # elif opt.imgReconLoss == 'ms-ssim': # recCriterion = msssim # loss averager loss_avg = Averager() loss_avg_dis = Averager() loss_avg_gen = Averager() loss_avg_imgRecon = Averager() loss_avg_vgg_per = Averager() loss_avg_vgg_sty = Averager() loss_avg_ocr = Averager() log_r1_val = Averager() log_avg_path_loss_val = Averager() log_avg_mean_path_length_avg = Averager() log_ada_aug_p = Averager() """ final options """ with open(os.path.join(opt.exp_dir,opt.exp_name,'opt.txt'), 'a') as opt_file: opt_log = '------------ Options -------------\n' args = vars(opt) for k, v in args.items(): opt_log += f'{str(k)}: {str(v)}\n' opt_log += '---------------------------------------\n' print(opt_log) opt_file.write(opt_log) """ start training """ start_iter = 0 if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': try: start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0]) print(f'continue to train, start_iter: {start_iter}') except: pass #get schedulers scheduler = get_scheduler(optimizer,opt) dis_scheduler = get_scheduler(dis_optimizer,opt) start_time = time.time() iteration = start_iter cntr=0 mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} accum = 0.5 ** (32 / (10 * 1000)) ada_augment = torch.tensor([0.0, 0.0], device=device) ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0 ada_aug_step = opt.ada_target / opt.ada_length r_t_stat = 0 sample_z = torch.randn(opt.n_sample, opt.latent, device=device) while(True): # print(cntr) # train part if opt.lr_policy !="None": scheduler.step() dis_scheduler.step() image_input_tensors, image_gt_tensors, labels_1, labels_2 = iter(train_loader).next() image_input_tensors = image_input_tensors.to(device) image_gt_tensors = image_gt_tensors.to(device) batch_size = image_input_tensors.size(0) requires_grad(genModel, False) # requires_grad(styleModel, False) # requires_grad(mixModel, False) requires_grad(disModel, True) text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length) text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length) #forward pass from style and word generator # style = styleModel(image_input_tensors).squeeze(2).squeeze(2) style = mixing_noise(opt.batch_size, opt.latent, opt.mixing, device) # scInput = mixModel(style,text_2) if 'CTC' in opt.Prediction: images_recon_2,_ = genModel(style, text_2, input_is_latent=opt.input_latent) else: images_recon_2,_ = genModel(style, text_2[:,1:-1], input_is_latent=opt.input_latent) #Domain discriminator: Dis update if opt.augment: image_gt_tensors_aug, _ = augment(image_gt_tensors, ada_aug_p) images_recon_2, _ = augment(images_recon_2, ada_aug_p) else: image_gt_tensors_aug = image_gt_tensors fake_pred = disModel(images_recon_2) real_pred = disModel(image_gt_tensors_aug) disCost = d_logistic_loss(real_pred, fake_pred) loss_dict["d"] = disCost*opt.disWeight loss_dict["real_score"] = 
real_pred.mean() loss_dict["fake_score"] = fake_pred.mean() loss_avg_dis.add(disCost) disModel.zero_grad() disCost.backward() dis_optimizer.step() if opt.augment and opt.augment_p == 0: ada_augment += torch.tensor( (torch.sign(real_pred).sum().item(), real_pred.shape[0]), device=device ) ada_augment = reduce_sum(ada_augment) if ada_augment[1] > 255: pred_signs, n_pred = ada_augment.tolist() r_t_stat = pred_signs / n_pred if r_t_stat > opt.ada_target: sign = 1 else: sign = -1 ada_aug_p += sign * ada_aug_step * n_pred ada_aug_p = min(1, max(0, ada_aug_p)) ada_augment.mul_(0) d_regularize = cntr % opt.d_reg_every == 0 if d_regularize: image_gt_tensors.requires_grad = True image_input_tensors.requires_grad = True cat_tensor = image_gt_tensors real_pred = disModel(cat_tensor) r1_loss = d_r1_loss(real_pred, cat_tensor) disModel.zero_grad() (opt.r1 / 2 * r1_loss * opt.d_reg_every + 0 * real_pred[0]).backward() dis_optimizer.step() loss_dict["r1"] = r1_loss # #[Style Encoder] + [Word Generator] update image_input_tensors, image_gt_tensors, labels_1, labels_2 = iter(train_loader).next() image_input_tensors = image_input_tensors.to(device) image_gt_tensors = image_gt_tensors.to(device) batch_size = image_input_tensors.size(0) requires_grad(genModel, True) # requires_grad(styleModel, True) # requires_grad(mixModel, True) requires_grad(disModel, False) text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length) text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length) # style = styleModel(image_input_tensors).squeeze(2).squeeze(2) # scInput = mixModel(style,text_2) # images_recon_2,_ = genModel([scInput], input_is_latent=opt.input_latent) style = mixing_noise(batch_size, opt.latent, opt.mixing, device) if 'CTC' in opt.Prediction: images_recon_2, _ = genModel(style, text_2) else: images_recon_2, _ = genModel(style, text_2[:,1:-1]) if opt.augment: images_recon_2, _ = augment(images_recon_2, ada_aug_p) fake_pred = disModel(images_recon_2) disGenCost = g_nonsaturating_loss(fake_pred) loss_dict["g"] = disGenCost # # #Adversarial loss # # disGenCost = disModel.module.calc_gen_loss(torch.cat((images_recon_2,image_input_tensors),dim=1)) # #Input reconstruction loss # recCost = recCriterion(images_recon_2,image_gt_tensors) # #vgg loss # vggPerCost, vggStyleCost = vggModel(image_gt_tensors, images_recon_2) #ocr loss text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) if opt.contentLoss == 'vis' or opt.contentLoss == 'seq': preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False, returnFeat=opt.contentLoss) preds_gt = ocrModel(image_gt_tensors, text_for_pred, is_train=False, returnFeat=opt.contentLoss) ocrCost = ocrCriterion(preds_recon, preds_gt) else: if 'CTC' in opt.Prediction: preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False) # preds_o = preds_recon[:, :text_1.shape[1], :] preds_size = torch.IntTensor([preds_recon.size(1)] * batch_size) preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2) ocrCost = ocrCriterion(preds_recon_softmax, text_2, preds_size, length_2) #predict ocr recognition on generated images # preds_recon_size = torch.IntTensor([preds_recon.size(1)] * batch_size) _, preds_recon_index = preds_recon.max(2) labels_o_ocr = converter.decode(preds_recon_index.data, preds_size.data) #predict ocr recognition on gt style images preds_s = ocrModel(image_input_tensors, text_for_pred, 
is_train=False) # preds_s = preds_s[:, :text_1.shape[1] - 1, :] preds_s_size = torch.IntTensor([preds_s.size(1)] * batch_size) _, preds_s_index = preds_s.max(2) labels_s_ocr = converter.decode(preds_s_index.data, preds_s_size.data) #predict ocr recognition on gt stylecontent images preds_sc = ocrModel(image_gt_tensors, text_for_pred, is_train=False) # preds_sc = preds_sc[:, :text_2.shape[1] - 1, :] preds_sc_size = torch.IntTensor([preds_sc.size(1)] * batch_size) _, preds_sc_index = preds_sc.max(2) labels_sc_ocr = converter.decode(preds_sc_index.data, preds_sc_size.data) else: preds_recon = ocrModel(images_recon_2, text_for_pred[:, :-1], is_train=False) # align with Attention.forward target_2 = text_2[:, 1:] # without [GO] Symbol ocrCost = ocrCriterion(preds_recon.view(-1, preds_recon.shape[-1]), target_2.contiguous().view(-1)) #predict ocr recognition on generated images _, preds_o_index = preds_recon.max(2) labels_o_ocr = converter.decode(preds_o_index, length_for_pred) for idx, pred in enumerate(labels_o_ocr): pred_EOS = pred.find('[s]') labels_o_ocr[idx] = pred[:pred_EOS] # prune after "end of sentence" token ([s]) #predict ocr recognition on gt style images preds_s = ocrModel(image_input_tensors, text_for_pred, is_train=False) _, preds_s_index = preds_s.max(2) labels_s_ocr = converter.decode(preds_s_index, length_for_pred) for idx, pred in enumerate(labels_s_ocr): pred_EOS = pred.find('[s]') labels_s_ocr[idx] = pred[:pred_EOS] # prune after "end of sentence" token ([s]) #predict ocr recognition on gt stylecontent images preds_sc = ocrModel(image_gt_tensors, text_for_pred, is_train=False) _, preds_sc_index = preds_sc.max(2) labels_sc_ocr = converter.decode(preds_sc_index, length_for_pred) for idx, pred in enumerate(labels_sc_ocr): pred_EOS = pred.find('[s]') labels_sc_ocr[idx] = pred[:pred_EOS] # prune after "end of sentence" token ([s]) # cost = opt.reconWeight*recCost + opt.disWeight*disGenCost + opt.vggPerWeight*vggPerCost + opt.vggStyWeight*vggStyleCost + opt.ocrWeight*ocrCost cost = opt.disWeight*disGenCost + opt.ocrWeight*ocrCost # styleModel.zero_grad() genModel.zero_grad() # mixModel.zero_grad() disModel.zero_grad() # vggModel.zero_grad() ocrModel.zero_grad() cost.backward() optimizer.step() loss_avg.add(cost) g_regularize = cntr % opt.g_reg_every == 0 if g_regularize: image_input_tensors, image_gt_tensors, labels_1, labels_2 = iter(train_loader).next() image_input_tensors = image_input_tensors.to(device) image_gt_tensors = image_gt_tensors.to(device) batch_size = image_input_tensors.size(0) text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length) text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length) path_batch_size = max(1, batch_size // opt.path_batch_shrink) # style = styleModel(image_input_tensors).squeeze(2).squeeze(2) # scInput = mixModel(style,text_2) # images_recon_2, latents = genModel([scInput],input_is_latent=opt.input_latent, return_latents=True) style = mixing_noise(path_batch_size, opt.latent, opt.mixing, device) if 'CTC' in opt.Prediction: images_recon_2, latents = genModel(style, text_2[:path_batch_size], return_latents=True) else: images_recon_2, latents = genModel(style, text_2[:path_batch_size,1:-1], return_latents=True) path_loss, mean_path_length, path_lengths = g_path_regularize( images_recon_2, latents, mean_path_length ) genModel.zero_grad() weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss if opt.path_batch_shrink: weighted_path_loss += 0 * images_recon_2[0, 0, 0, 0] 
weighted_path_loss.backward() optimizer.step() mean_path_length_avg = ( reduce_sum(mean_path_length).item() / get_world_size() ) loss_dict["path"] = path_loss loss_dict["path_length"] = path_lengths.mean() accumulate(g_ema, g_module, accum) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() g_loss_val = loss_reduced["g"].mean().item() r1_val = loss_reduced["r1"].mean().item() path_loss_val = loss_reduced["path"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = loss_reduced["fake_score"].mean().item() path_length_val = loss_reduced["path_length"].mean().item() #Individual losses loss_avg_gen.add(opt.disWeight*disGenCost) loss_avg_imgRecon.add(torch.tensor(0.0)) loss_avg_vgg_per.add(torch.tensor(0.0)) loss_avg_vgg_sty.add(torch.tensor(0.0)) loss_avg_ocr.add(opt.ocrWeight*ocrCost) log_r1_val.add(loss_reduced["path"]) log_avg_path_loss_val.add(loss_reduced["path"]) log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg)) log_ada_aug_p.add(torch.tensor(ada_aug_p)) if get_rank() == 0: # pbar.set_description( # ( # f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; " # f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; " # f"augment: {ada_aug_p:.4f}" # ) # ) if wandb and opt.wandb: wandb.log( { "Generator": g_loss_val, "Discriminator": d_loss_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1": r1_val, "Path Length Regularization": path_loss_val, "Mean Path Length": mean_path_length, "Real Score": real_score_val, "Fake Score": fake_score_val, "Path Length": path_length_val, } ) # if cntr % 100 == 0: # with torch.no_grad(): # g_ema.eval() # sample, _ = g_ema([scInput[:,:opt.latent],scInput[:,opt.latent:]]) # utils.save_image( # sample, # os.path.join(opt.trainDir, f"sample_{str(cntr).zfill(6)}.png"), # nrow=int(opt.n_sample ** 0.5), # normalize=True, # range=(-1, 1), # ) # validation part if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' #Save training images curr_batch_size = style[0].shape[0] images_recon_2, _ = g_ema(style, text_2[:curr_batch_size], input_is_latent=opt.input_latent) os.makedirs(os.path.join(opt.trainDir,str(iteration)), exist_ok=True) for trImgCntr in range(batch_size): try: if opt.contentLoss == 'vis' or opt.contentLoss == 'seq': save_image(tensor2im(image_input_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_sInput_'+labels_1[trImgCntr]+'.png')) save_image(tensor2im(image_gt_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_csGT_'+labels_2[trImgCntr]+'.png')) save_image(tensor2im(images_recon_2[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_csRecon_'+labels_2[trImgCntr]+'.png')) else: save_image(tensor2im(image_input_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_sInput_'+labels_1[trImgCntr]+'_'+labels_s_ocr[trImgCntr]+'.png')) save_image(tensor2im(image_gt_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_csGT_'+labels_2[trImgCntr]+'_'+labels_sc_ocr[trImgCntr]+'.png')) save_image(tensor2im(images_recon_2[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_csRecon_'+labels_2[trImgCntr]+'_'+labels_o_ocr[trImgCntr]+'.png')) except: print('Warning while saving training image') elapsed_time = time.time() - start_time # for log with 
open(os.path.join(opt.exp_dir,opt.exp_name,'log_train.txt'), 'a') as log: # styleModel.eval() genModel.eval() g_ema.eval() # mixModel.eval() disModel.eval() with torch.no_grad(): valid_loss, infer_time, length_of_data = validation_synth_v6( iteration, g_ema, ocrModel, disModel, ocrCriterion, valid_loader, converter, opt) # styleModel.train() genModel.train() # mixModel.train() disModel.train() # training loss and validation loss loss_log = f'[{iteration+1}/{opt.num_iter}] Train Synth loss: {loss_avg.val():0.5f}, \ Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f},\ Train OCR loss: {loss_avg_ocr.val():0.5f}, \ Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, \ Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, \ Valid Synth loss: {valid_loss[0]:0.5f}, \ Valid Dis loss: {valid_loss[1]:0.5f}, Valid Gen loss: {valid_loss[2]:0.5f}, \ Valid OCR loss: {valid_loss[6]:0.5f}, Elapsed_time: {elapsed_time:0.5f}' #plotting lib.plot.plot(os.path.join(opt.plotDir,'Train-Synth-Loss'), loss_avg.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-Dis-Loss'), loss_avg_dis.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-Gen-Loss'), loss_avg_gen.val().item()) # lib.plot.plot(os.path.join(opt.plotDir,'Train-ImgRecon1-Loss'), loss_avg_imgRecon.val().item()) # lib.plot.plot(os.path.join(opt.plotDir,'Train-VGG-Per-Loss'), loss_avg_vgg_per.val().item()) # lib.plot.plot(os.path.join(opt.plotDir,'Train-VGG-Sty-Loss'), loss_avg_vgg_sty.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-OCR-Loss'), loss_avg_ocr.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-r1_val'), log_r1_val.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-path_loss_val'), log_avg_path_loss_val.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-mean_path_length_avg'), log_avg_mean_path_length_avg.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-ada_aug_p'), log_ada_aug_p.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Valid-Synth-Loss'), valid_loss[0].item()) lib.plot.plot(os.path.join(opt.plotDir,'Valid-Dis-Loss'), valid_loss[1].item()) lib.plot.plot(os.path.join(opt.plotDir,'Valid-Gen-Loss'), valid_loss[2].item()) # lib.plot.plot(os.path.join(opt.plotDir,'Valid-ImgRecon1-Loss'), valid_loss[3].item()) # lib.plot.plot(os.path.join(opt.plotDir,'Valid-VGG-Per-Loss'), valid_loss[4].item()) # lib.plot.plot(os.path.join(opt.plotDir,'Valid-VGG-Sty-Loss'), valid_loss[5].item()) lib.plot.plot(os.path.join(opt.plotDir,'Valid-OCR-Loss'), valid_loss[6].item()) print(loss_log) loss_avg.reset() loss_avg_dis.reset() loss_avg_gen.reset() loss_avg_imgRecon.reset() loss_avg_vgg_per.reset() loss_avg_vgg_sty.reset() loss_avg_ocr.reset() log_r1_val.reset() log_avg_path_loss_val.reset() log_avg_mean_path_length_avg.reset() log_ada_aug_p.reset() lib.plot.flush() lib.plot.tick() # save model per 1e+5 iter. if (iteration) % 1e+4 == 0: torch.save({ # 'styleModel':styleModel.state_dict(), # 'mixModel':mixModel.state_dict(), 'genModel':g_module.state_dict(), 'g_ema':g_ema.state_dict(), 'disModel':d_module.state_dict(), 'optimizer':optimizer.state_dict(), 'dis_optimizer':dis_optimizer.state_dict()}, os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_synth.pth')) if (iteration + 1) == opt.num_iter: print('end the training') sys.exit() iteration += 1 cntr+=1
def test(opt): lib.print_model_settings(locals().copy()) """ dataset preparation """ if not opt.data_filtering_off: print( 'Filtering the images containing characters which are not in opt.character' ) print( 'Filtering the images whose label is longer than opt.batch_max_length' ) # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 log = open(os.path.join(opt.exp_dir, opt.exp_name, 'log_dataset.txt'), 'a') AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) valid_dataset, valid_dataset_log = hierarchical_dataset( root=opt.valid_data, opt=opt) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=opt.batch_size, shuffle= False, # 'True' to check training progress with validation function. num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True) log.write(valid_dataset_log) print('-' * 80) log.write('-' * 80 + '\n') log.close() if 'Attn' in opt.Prediction: converter = AttnLabelConverter(opt.character) else: converter = CTCLabelConverter(opt.character) opt.num_class = len(converter.character) g_ema = styleGANGen(opt.size, opt.latent, opt.n_mlp, opt.num_class, channel_multiplier=opt.channel_multiplier) g_ema = torch.nn.DataParallel(g_ema).to(device) g_ema.eval() print('model input parameters', opt.imgH, opt.imgW, opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length) ## Loading pre-trained files if opt.modelFolderFlag: if len( glob.glob( os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth"))) > 0: opt.saved_synth_model = glob.glob( os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth"))[-1] if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': print(f'loading pretrained synth model from {opt.saved_synth_model}') checkpoint = torch.load(opt.saved_synth_model) g_ema.load_state_dict(checkpoint['g_ema'], strict=False) # pdb.set_trace() if opt.truncation < 1: with torch.no_grad(): mean_latent = g_ema.module.mean_latent_content(opt.truncation_mean) else: mean_latent = None cntr = 0 for i, (image_input_tensors, image_gt_tensors, labels_1, labels_2) in enumerate(valid_loader): print(i, len(valid_loader)) image_input_tensors = image_input_tensors.to(device) image_gt_tensors = image_gt_tensors.to(device) batch_size = image_input_tensors.size(0) text_1, length_1 = converter.encode( labels_1, batch_max_length=opt.batch_max_length) text_2, length_2 = converter.encode( labels_2, batch_max_length=opt.batch_max_length) #forward pass from style and word generator if opt.fixedStyleBatch: fixstyle = [] # pdb.set_trace() style = mixing_noise(1, opt.latent, opt.mixing, device) fixstyle.append(style[0].repeat(opt.batch_size, 1)) if len(style) > 1: fixstyle.append(style[1].repeat(opt.batch_size, 1)) style = fixstyle else: style = mixing_noise(opt.batch_size, opt.latent, opt.mixing, device) if 'CTC' in opt.Prediction: images_recon_2, _ = g_ema(style, text_2, input_is_latent=opt.input_latent, inject_index=5, truncation=opt.truncation, truncation_latent=mean_latent, randomize_noise=False) else: images_recon_2, _ = g_ema(style, text_2[:, 1:-1], input_is_latent=opt.input_latent, inject_index=5, truncation=opt.truncation, truncation_latent=mean_latent, randomize_noise=False) # os.makedirs(os.path.join(opt.valDir,str(iteration)), exist_ok=True) for trImgCntr in range(batch_size): try: save_image( tensor2im(image_input_tensors[trImgCntr].detach()), os.path.join( opt.valDir, str(cntr) + '_' + 
str(trImgCntr) + '_sInput_' + labels_1[trImgCntr] + '.png')) save_image( tensor2im(image_gt_tensors[trImgCntr].detach()), os.path.join( opt.valDir, str(cntr) + '_' + str(trImgCntr) + '_csGT_' + labels_2[trImgCntr] + '.png')) save_image( tensor2im(images_recon_2[trImgCntr].detach()), os.path.join( opt.valDir, str(cntr) + '_' + str(trImgCntr) + '_csRecon_' + labels_2[trImgCntr] + '.png')) except: print('Warning while saving training image') cntr += 1
def train(opt): lib.print_model_settings(locals().copy()) if 'Attn' in opt.Prediction: converter = AttnLabelConverter(opt.character) else: converter = CTCLabelConverter(opt.character) opt.classes = converter.character """ dataset preparation """ if not opt.data_filtering_off: print( 'Filtering the images containing characters which are not in opt.character' ) print( 'Filtering the images whose label is longer than opt.batch_max_length' ) # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 log = open(os.path.join(opt.exp_dir, opt.exp_name, 'log_dataset.txt'), 'a') AlignCollate_valid = AlignPHOCCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) train_dataset = LmdbStylePHOCDataset(root=opt.train_data, opt=opt) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=opt.batch_size * 2, #*2 to sample different images from training encoder and discriminator real images shuffle= True, # 'True' to check training progress with validation function. num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True) print('-' * 80) valid_dataset = LmdbStylePHOCDataset(root=opt.valid_data, opt=opt) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=opt.batch_size * 2, #*2 to sample different images from training encoder and discriminator real images shuffle= False, # 'True' to check training progress with validation function. num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True) print('-' * 80) log.write('-' * 80 + '\n') log.close() phoc_dataset = phoc_gen(opt) phoc_loader = torch.utils.data.DataLoader(phoc_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers), pin_memory=True, drop_last=True) opt.num_class = len(converter.character) if opt.zAlone: genModel = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier) g_ema = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier) else: genModel = styleGANGen(opt.size, opt.latent + phoc_dataset.phoc_size, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier) g_ema = styleGANGen(opt.size, opt.latent + phoc_dataset.phoc_size, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier) disEncModel = styleGANDis(opt.size, channel_multiplier=opt.channel_multiplier, input_dim=opt.input_channel, code_s_dim=phoc_dataset.phoc_size) accumulate(g_ema, genModel, 0) uCriterion = torch.nn.MSELoss() sCriterion = torch.nn.MSELoss() genModel = torch.nn.DataParallel(genModel).to(device) g_ema = torch.nn.DataParallel(g_ema).to(device) genModel.train() g_ema.eval() disEncModel = torch.nn.DataParallel(disEncModel).to(device) disEncModel.train() g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1) d_reg_ratio = opt.d_reg_every / (opt.d_reg_every + 1) optimizer = optim.Adam( genModel.parameters(), lr=opt.lr * g_reg_ratio, betas=(0**g_reg_ratio, 0.99**g_reg_ratio), ) dis_optimizer = optim.Adam( disEncModel.parameters(), lr=opt.lr * d_reg_ratio, betas=(0**d_reg_ratio, 0.99**d_reg_ratio), ) ## Loading pre-trained files if opt.modelFolderFlag: if len( glob.glob( os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth"))) > 0: opt.saved_synth_model = glob.glob( os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth"))[-1] # if opt.saved_ocr_model !='' and opt.saved_ocr_model !='None': # print(f'loading pretrained ocr model from {opt.saved_ocr_model}') # 
checkpoint = torch.load(opt.saved_ocr_model) # ocrModel.load_state_dict(checkpoint) # if opt.saved_gen_model !='' and opt.saved_gen_model !='None': # print(f'loading pretrained gen model from {opt.saved_gen_model}') # checkpoint = torch.load(opt.saved_gen_model, map_location=lambda storage, loc: storage) # genModel.module.load_state_dict(checkpoint['g']) # g_ema.module.load_state_dict(checkpoint['g_ema']) if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': print(f'loading pretrained synth model from {opt.saved_synth_model}') checkpoint = torch.load(opt.saved_synth_model) # styleModel.load_state_dict(checkpoint['styleModel']) # mixModel.load_state_dict(checkpoint['mixModel']) genModel.load_state_dict(checkpoint['genModel']) g_ema.load_state_dict(checkpoint['g_ema']) disEncModel.load_state_dict(checkpoint['disEncModel']) optimizer.load_state_dict(checkpoint["optimizer"]) dis_optimizer.load_state_dict(checkpoint["dis_optimizer"]) # if opt.imgReconLoss == 'l1': # recCriterion = torch.nn.L1Loss() # elif opt.imgReconLoss == 'ssim': # recCriterion = ssim # elif opt.imgReconLoss == 'ms-ssim': # recCriterion = msssim # loss averager loss_avg_dis = Averager() loss_avg_gen = Averager() loss_avg_unsup = Averager() loss_avg_sup = Averager() log_r1_val = Averager() log_avg_path_loss_val = Averager() log_avg_mean_path_length_avg = Averager() log_ada_aug_p = Averager() """ final options """ with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'), 'a') as opt_file: opt_log = '------------ Options -------------\n' args = vars(opt) for k, v in args.items(): opt_log += f'{str(k)}: {str(v)}\n' opt_log += '---------------------------------------\n' print(opt_log) opt_file.write(opt_log) """ start training """ start_iter = 0 if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': try: start_iter = int( opt.saved_synth_model.split('_')[-2].split('.')[0]) print(f'continue to train, start_iter: {start_iter}') except: pass #get schedulers scheduler = get_scheduler(optimizer, opt) dis_scheduler = get_scheduler(dis_optimizer, opt) start_time = time.time() iteration = start_iter cntr = 0 mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 # loss_dict = {} accum = 0.5**(32 / (10 * 1000)) ada_augment = torch.tensor([0.0, 0.0], device=device) ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0 ada_aug_step = opt.ada_target / opt.ada_length r_t_stat = 0 # sample_z = torch.randn(opt.n_sample, opt.latent, device=device) while (True): # print(cntr) # train part if opt.lr_policy != "None": scheduler.step() dis_scheduler.step() image_input_tensors, _, labels_1, _, phoc_1, _ = iter( train_loader).next() z_code, z_labels = iter(phoc_loader).next() image_input_tensors = image_input_tensors.to(device) gt_image_tensors = image_input_tensors[:opt.batch_size] real_image_tensors = image_input_tensors[opt.batch_size:] phoc_1 = phoc_1.to(device) gt_phoc_tensors = phoc_1[:opt.batch_size] labels_1 = labels_1[:opt.batch_size] z_code = z_code.to(device) requires_grad(genModel, False) # requires_grad(styleModel, False) # requires_grad(mixModel, False) requires_grad(disEncModel, True) text_1, length_1 = converter.encode( labels_1, batch_max_length=opt.batch_max_length) style = mixing_noise(z_code, opt.batch_size, opt.latent, opt.mixing, device) if opt.zAlone: #to validate orig style gan results newstyle = [] newstyle.append(style[0][:, :opt.latent]) 
if len(style) > 1: newstyle.append(style[1][:, :opt.latent]) style = newstyle fake_img, _ = genModel(style, input_is_latent=opt.input_latent) #unsupervised code prediction on generated image u_pred_code = disEncModel(fake_img, mode='enc') uCost = uCriterion(u_pred_code, z_code) #supervised code prediction on gt image s_pred_code = disEncModel(gt_image_tensors, mode='enc') sCost = uCriterion(s_pred_code, gt_phoc_tensors) #Domain discriminator fake_pred = disEncModel(fake_img) real_pred = disEncModel(real_image_tensors) disCost = d_logistic_loss(real_pred, fake_pred) dis_enc_cost = disCost + opt.gamma_e * uCost + opt.beta * sCost loss_avg_dis.add(disCost) loss_avg_sup.add(opt.beta * sCost) loss_avg_unsup.add(opt.gamma_e * uCost) disEncModel.zero_grad() dis_enc_cost.backward() dis_optimizer.step() d_regularize = cntr % opt.d_reg_every == 0 if d_regularize: real_image_tensors.requires_grad = True real_pred = disEncModel(real_image_tensors) r1_loss = d_r1_loss(real_pred, real_image_tensors) disEncModel.zero_grad() (opt.r1 / 2 * r1_loss * opt.d_reg_every + 0 * real_pred[0]).backward() dis_optimizer.step() # loss_dict["r1"] = r1_loss # [Word Generator] update image_input_tensors, _, labels_1, _, phoc_1, _ = iter( train_loader).next() z_code, z_labels = iter(phoc_loader).next() image_input_tensors = image_input_tensors.to(device) gt_image_tensors = image_input_tensors[:opt.batch_size] real_image_tensors = image_input_tensors[opt.batch_size:] phoc_1 = phoc_1.to(device) gt_phoc_tensors = phoc_1[:opt.batch_size] labels_1 = labels_1[:opt.batch_size] z_code = z_code.to(device) requires_grad(genModel, True) requires_grad(disEncModel, False) text_1, length_1 = converter.encode( labels_1, batch_max_length=opt.batch_max_length) style = mixing_noise(z_code, opt.batch_size, opt.latent, opt.mixing, device) if opt.zAlone: #to validate orig style gan results newstyle = [] newstyle.append(style[0][:, :opt.latent]) if len(style) > 1: newstyle.append(style[1][:, :opt.latent]) style = newstyle fake_img, _ = genModel(style, input_is_latent=opt.input_latent) #unsupervised code prediction on generated image u_pred_code = disEncModel(fake_img, mode='enc') uCost = uCriterion(u_pred_code, z_code) fake_pred = disEncModel(fake_img) disGenCost = g_nonsaturating_loss(fake_pred) gen_enc_cost = disGenCost + opt.gamma_g * uCost loss_avg_gen.add(disGenCost) loss_avg_unsup.add(opt.gamma_g * uCost) # loss_dict["g"] = disGenCost genModel.zero_grad() disEncModel.zero_grad() gen_enc_cost.backward() optimizer.step() g_regularize = cntr % opt.g_reg_every == 0 if g_regularize: image_input_tensors, _, labels_1, _, phoc_1, _ = iter( train_loader).next() z_code, z_labels = iter(phoc_loader).next() image_input_tensors = image_input_tensors.to(device) path_batch_size = max(1, opt.batch_size // opt.path_batch_shrink) gt_image_tensors = image_input_tensors[:path_batch_size] phoc_1 = phoc_1.to(device) gt_phoc_tensors = phoc_1[:path_batch_size] labels_1 = labels_1[:path_batch_size] z_code = z_code.to(device) z_code = z_code[:path_batch_size] z_labels = z_labels[:path_batch_size] text_1, length_1 = converter.encode( labels_1, batch_max_length=opt.batch_max_length) style = mixing_noise(z_code, path_batch_size, opt.latent, opt.mixing, device) if opt.zAlone: #to validate orig style gan results newstyle = [] newstyle.append(style[0][:, :opt.latent]) if len(style) > 1: newstyle.append(style[1][:, :opt.latent]) style = newstyle fake_img, grad = genModel(style, return_latents=True, g_path_regularize=True, mean_path_length=mean_path_length) decay = 0.01 
            path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
            mean_path_length_orig = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
            path_loss = (path_lengths - mean_path_length_orig).pow(2).mean()
            mean_path_length = mean_path_length_orig.detach().item()

            # path_loss, mean_path_length, path_lengths = g_path_regularize(
            #     images_recon_2, latents, mean_path_length
            # )

            genModel.zero_grad()
            weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss

            if opt.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()
            optimizer.step()

            # mean_path_length_avg = (
            #     reduce_sum(mean_path_length).item() / get_world_size()
            # )
            # commented above for multi-gpu, non-distributed setting
            mean_path_length_avg = mean_path_length

        accumulate(g_ema, genModel, accum)

        log_r1_val.add(r1_loss)
        log_avg_path_loss_val.add(path_loss)
        log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg))
        log_ada_aug_p.add(torch.tensor(ada_aug_p))

        if get_rank() == 0:
            if wandb and opt.wandb:
                # NOTE: r1_val, path_loss_val, real_score_val, fake_score_val and
                # path_length_val are never assigned in this non-distributed loop
                # (the loss_dict reduction is commented out), so enabling --wandb
                # without defining them first raises a NameError.
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            # validation part
            if (iteration + 1) % opt.valInterval == 0 or iteration == 0:  # To see training progress, we also conduct validation when 'iteration == 0'
                # generate paired content with similar style
                z_code_1, z_labels_1 = next(iter(phoc_loader))
                z_code_2, z_labels_2 = next(iter(phoc_loader))

                z_code_1 = z_code_1.to(device)
                z_code_2 = z_code_2.to(device)

                style_1 = mixing_noise(z_code_1, opt.batch_size, opt.latent, opt.mixing, device)

                style_2 = []
                style_2.append(torch.cat((style_1[0][:, :opt.latent], z_code_2), dim=1))
                if len(style_1) > 1:
                    style_2.append(torch.cat((style_1[1][:, :opt.latent], z_code_2), dim=1))

                if opt.zAlone:
                    # to validate orig style gan results
                    newstyle = []
                    newstyle.append(style_1[0][:, :opt.latent])
                    if len(style_1) > 1:
                        newstyle.append(style_1[1][:, :opt.latent])
                    style_1 = newstyle
                    style_2 = newstyle

                fake_img_1, _ = g_ema(style_1, input_is_latent=opt.input_latent)
                fake_img_2, _ = g_ema(style_2, input_is_latent=opt.input_latent)

                os.makedirs(os.path.join(opt.trainDir, str(iteration)), exist_ok=True)
                for trImgCntr in range(opt.batch_size):
                    try:
                        save_image(tensor2im(fake_img_1[trImgCntr].detach()),
                                   os.path.join(opt.trainDir, str(iteration),
                                                str(trImgCntr) + '_pair1_' + z_labels_1[trImgCntr] + '.png'))
                        save_image(tensor2im(fake_img_2[trImgCntr].detach()),
                                   os.path.join(opt.trainDir, str(iteration),
                                                str(trImgCntr) + '_pair2_' + z_labels_2[trImgCntr] + '.png'))
                    except:
                        print('Warning while saving training image')

                elapsed_time = time.time() - start_time
                # for log
                with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_train.txt'), 'a') as log:
                    # training loss and validation loss
                    loss_log = f'[{iteration+1}/{opt.num_iter}] \
                        Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f},\
                        Train UnSup loss: {loss_avg_unsup.val():0.5f}, Train Sup loss: {loss_avg_sup.val():0.5f}, \
                        Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, \
                        Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, \
                        Elapsed_time: {elapsed_time:0.5f}'

                    # plotting
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-Dis-Loss'), loss_avg_dis.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-Gen-Loss'), loss_avg_gen.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-UnSup-Loss'), loss_avg_unsup.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-Sup-Loss'), loss_avg_sup.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-r1_val'), log_r1_val.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-path_loss_val'), log_avg_path_loss_val.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-mean_path_length_avg'), log_avg_mean_path_length_avg.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-ada_aug_p'), log_ada_aug_p.val().item())

                    print(loss_log)
                    log.write(loss_log + '\n')  # persist the same summary to log_train.txt

                loss_avg_dis.reset()
                loss_avg_gen.reset()
                loss_avg_unsup.reset()
                loss_avg_sup.reset()
                log_r1_val.reset()
                log_avg_path_loss_val.reset()
                log_avg_mean_path_length_avg.reset()
                log_ada_aug_p.reset()

                lib.plot.flush()
                lib.plot.tick()

        # save model every 1e+4 iterations
        if (iteration) % 1e+4 == 0:
            torch.save(
                {
                    'genModel': genModel.state_dict(),
                    'g_ema': g_ema.state_dict(),
                    'disEncModel': disEncModel.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'dis_optimizer': dis_optimizer.state_dict(),
                },
                os.path.join(opt.exp_dir, opt.exp_name, 'iter_' + str(iteration + 1) + '_synth.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()

        iteration += 1
        cntr += 1
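
# ---------------------------------------------------------------------------
# Reference only: train() above assumes a set of StyleGAN2-style helpers
# (d_logistic_loss, g_nonsaturating_loss, d_r1_loss, requires_grad, accumulate,
# mixing_noise) that are defined elsewhere in this repository.  The sketch
# below is a minimal illustration of how such helpers are commonly implemented
# (after rosinality's stylegan2-pytorch); it is wrapped in a private function
# so it does not shadow the real imports, and the project's own definitions
# (signatures included, in particular mixing_noise) take precedence.
# ---------------------------------------------------------------------------
def _sketch_stylegan2_helpers():
    import random

    import torch
    from torch import autograd
    from torch.nn import functional as F

    def d_logistic_loss(real_pred, fake_pred):
        # Non-saturating logistic loss for the discriminator.
        return F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()

    def g_nonsaturating_loss(fake_pred):
        # Non-saturating logistic loss for the generator.
        return F.softplus(-fake_pred).mean()

    def d_r1_loss(real_pred, real_img):
        # R1 gradient penalty on real images (requires real_img.requires_grad=True).
        grad_real, = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)
        return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()

    def requires_grad(model, flag=True):
        # Toggle gradients for all parameters of a module.
        for p in model.parameters():
            p.requires_grad = flag

    def accumulate(model1, model2, decay=0.999):
        # Exponential moving average of model2's weights into model1 (used for g_ema).
        par1 = dict(model1.named_parameters())
        par2 = dict(model2.named_parameters())
        for k in par1.keys():
            par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)

    def mixing_noise(z_code, batch, latent_dim, prob, device):
        # Hypothetical signature: one or two per-sample noise vectors (style mixing
        # with probability `prob`), each concatenated with the content/PHOC code, so
        # that style[i][:, :latent_dim] recovers the pure noise part as train() does.
        n_noise = 2 if prob > 0 and random.random() < prob else 1
        noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)
        return [torch.cat((noise, z_code), dim=1) for noise in noises]

    return d_logistic_loss, g_nonsaturating_loss, d_r1_loss, requires_grad, accumulate, mixing_noise
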
def test(opt):
    lib.print_model_settings(locals().copy())

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
        text_len = opt.batch_max_length + 2
    else:
        converter = CTCLabelConverter(opt.character)
        text_len = opt.batch_max_length
    opt.classes = converter.character

    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    # AlignCollate_valid = AlignPairImgCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    AlignCollate_valid = AlignPairImgCollate_Test(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)

    valid_dataset = LmdbTestStyleContentDataset(root=opt.test_data, opt=opt, dataMode=opt.realVaData)
    test_data_sampler = data_sampler(valid_dataset, shuffle=False, distributed=False)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=False,  # 'True' to check training progress with validation function.
        sampler=test_data_sampler,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=False)
    print('-' * 80)

    AlignCollate_text = AlignSynthTextCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    text_dataset = text_gen_synth(opt)
    text_data_sampler = data_sampler(text_dataset, shuffle=True, distributed=False)
    text_loader = torch.utils.data.DataLoader(
        text_dataset, batch_size=opt.batch_size,
        shuffle=False,
        sampler=text_data_sampler,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_text,
        drop_last=True)
    opt.num_class = len(converter.character)
    text_loader = sample_data(text_loader, text_data_sampler, False)

    c_code_size = opt.latent
    if opt.cEncode == 'mlp':
        cEncoder = GlobalContentEncoder(opt.num_class, text_len, opt.char_embed_size, c_code_size).to(device)
    elif opt.cEncode == 'cnn':
        if opt.contentNorm == 'in':
            cEncoder = ResNet_StyleExtractor_WIN(1, opt.latent).to(device)
        else:
            cEncoder = ResNet_StyleExtractor(1, opt.latent).to(device)

    if opt.styleNorm == 'in':
        styleModel = ResNet_StyleExtractor_WIN(opt.input_channel, opt.latent).to(device)
    else:
        styleModel = ResNet_StyleExtractor(opt.input_channel, opt.latent).to(device)

    ocrModel = ModelV1(opt).to(device)

    g_ema = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, content_dim=c_code_size,
                        channel_multiplier=opt.channel_multiplier).to(device)
    g_ema.eval()

    bestModelError = 1e5

    ## Loading pre-trained files
    print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
    checkpoint = torch.load(opt.saved_ocr_model, map_location=lambda storage, loc: storage)
    ocrModel.load_state_dict(checkpoint)

    print(f'loading pretrained synth model from {opt.saved_synth_model}')
    checkpoint = torch.load(opt.saved_synth_model, map_location=lambda storage, loc: storage)
    cEncoder.load_state_dict(checkpoint['cEncoder'])
    styleModel.load_state_dict(checkpoint['styleModel'])
    g_ema.load_state_dict(checkpoint['g_ema'])

    iCntr = 0
    evalCntr = 0
    fCntr = 0

    valMSE = 0.0
    valSSIM = 0.0
    valPSNR = 0.0
    c1_s1_input_correct = 0.0
    c2_s1_gen_correct = 0.0
    c1_s1_input_ed_correct = 0.0
    c2_s1_gen_ed_correct = 0.0

    ims, txts = [], []

    for vCntr, (image_input_tensors, image_output_tensors, labels_gt, labels_z_c, labelSynthImg,
                synth_z_c, input_1_shape, input_2_shape) in enumerate(valid_loader):
        print(vCntr)

        if opt.debugFlag and vCntr > 10:
            break
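        # Each test batch pairs a writer-style image (image_input_tensors, label labels_gt)
        # with a target-content sample (labels_z_c / synth_z_c) and its rendered ground truth
        # (image_output_tensors).  Two images are generated with g_ema below -- same content +
        # same style (c1_s1) and new content + same style (c2_s1) -- and scored with OCR and,
        # for synthetic data, MSE/SSIM/PSNR.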
        image_input_tensors = image_input_tensors.to(device)
        image_output_tensors = image_output_tensors.to(device)

        if opt.realVaData and opt.outPairFile == "":
            # pdb.set_trace()
            labels_z_c, synth_z_c = next(text_loader)

        labelSynthImg = labelSynthImg.to(device)
        synth_z_c = synth_z_c.to(device)
        synth_z_c = synth_z_c[:labelSynthImg.shape[0]]

        text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length)
        text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length)
        # print(labels_z_c)

        cEncoder.eval()
        styleModel.eval()
        g_ema.eval()

        with torch.no_grad():
            if opt.cEncode == 'mlp':
                z_c_code = cEncoder(text_z_c)
                z_gt_code = cEncoder(text_gt)
            elif opt.cEncode == 'cnn':
                z_c_code = cEncoder(synth_z_c)
                z_gt_code = cEncoder(labelSynthImg)

            style = styleModel(image_input_tensors)

            if opt.noiseConcat or opt.zAlone:
                style = mixing_noise(opt.batch_size, opt.latent, opt.mixing, device, style)
            else:
                style = [style]

            fake_img_c1_s1, _ = g_ema(style, z_gt_code, input_is_latent=opt.input_latent)
            fake_img_c2_s1, _ = g_ema(style, z_c_code, input_is_latent=opt.input_latent)

            currBatchSize = fake_img_c1_s1.shape[0]
            # text_for_pred = torch.LongTensor(opt.batch_size, opt.batch_max_length + 1).fill_(0).to(device)
            length_for_pred = torch.IntTensor([opt.batch_max_length] * currBatchSize).to(device)

            # Run OCR prediction
            if 'CTC' in opt.Prediction:
                preds = ocrModel(fake_img_c1_s1, text_gt, is_train=False, inAct=opt.taskActivation)
                preds_size = torch.IntTensor([preds.size(1)] * currBatchSize)
                _, preds_index = preds.max(2)
                preds_str_fake_img_c1_s1 = converter.decode(preds_index.data, preds_size.data)

                preds = ocrModel(fake_img_c2_s1, text_z_c, is_train=False, inAct=opt.taskActivation)
                preds_size = torch.IntTensor([preds.size(1)] * currBatchSize)
                _, preds_index = preds.max(2)
                preds_str_fake_img_c2_s1 = converter.decode(preds_index.data, preds_size.data)

                preds = ocrModel(image_input_tensors, text_gt, is_train=False)
                preds_size = torch.IntTensor([preds.size(1)] * image_input_tensors.shape[0])
                _, preds_index = preds.max(2)
                preds_str_gt_1 = converter.decode(preds_index.data, preds_size.data)

                preds = ocrModel(image_output_tensors, text_z_c, is_train=False)
                preds_size = torch.IntTensor([preds.size(1)] * image_output_tensors.shape[0])
                _, preds_index = preds.max(2)
                preds_str_gt_2 = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = ocrModel(fake_img_c1_s1, text_gt[:, :-1], is_train=False, inAct=opt.taskActivation)  # align with Attention.forward
                _, preds_index = preds.max(2)
                preds_str_fake_img_c1_s1 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(preds_str_fake_img_c1_s1):
                    pred_EOS = pred.find('[s]')
                    preds_str_fake_img_c1_s1[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

                preds = ocrModel(fake_img_c2_s1, text_z_c[:, :-1], is_train=False, inAct=opt.taskActivation)  # align with Attention.forward
                _, preds_index = preds.max(2)
                preds_str_fake_img_c2_s1 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(preds_str_fake_img_c2_s1):
                    pred_EOS = pred.find('[s]')
                    preds_str_fake_img_c2_s1[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

                preds = ocrModel(image_input_tensors, text_gt[:, :-1], is_train=False)  # align with Attention.forward
                _, preds_index = preds.max(2)
                preds_str_gt_1 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(preds_str_gt_1):
                    pred_EOS = pred.find('[s]')
                    preds_str_gt_1[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

                preds = ocrModel(image_output_tensors, text_z_c[:, :-1], is_train=False)  # align with Attention.forward
                _, preds_index = preds.max(2)
                preds_str_gt_2 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(preds_str_gt_2):
                    pred_EOS = pred.find('[s]')
                    preds_str_gt_2[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

        pathPrefix = os.path.join(opt.valDir, opt.exp_iter)
        os.makedirs(os.path.join(pathPrefix), exist_ok=True)

        for trImgCntr in range(image_output_tensors.shape[0]):
            if opt.outPairFile != "":
                labelId = 'label-' + valid_loader.dataset.pairId[fCntr] + '-' + str(fCntr)
            else:
                labelId = 'label-%09d' % valid_loader.dataset.filtered_index_list[fCntr]

            # evaluations
            valRange = (-1, +1)
            # pdb.set_trace()
            # inpTensor = skimage.img_as_ubyte(resize(tensor2im(image_input_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])), (input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0])))
            # gtTensor = skimage.img_as_ubyte(resize(tensor2im(image_output_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])), (input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0])))
            # predTensor = skimage.img_as_ubyte(resize(tensor2im(fake_img_c2_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])), (input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0])))
            # inpTensor = resize(tensor2im(image_input_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])), (input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]), anti_aliasing=True)
            # gtTensor = resize(tensor2im(image_output_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])), (input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]), anti_aliasing=True)
            # predTensor = resize(tensor2im(fake_img_c2_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])), (input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]), anti_aliasing=True)

            # resize the network outputs back to each sample's original shape
            inpTensor = F.interpolate(image_input_tensors[trImgCntr].unsqueeze(0), (input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]))
            gtTensor = F.interpolate(image_output_tensors[trImgCntr].unsqueeze(0), (input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]))
            predTensor = F.interpolate(fake_img_c2_s1[trImgCntr].unsqueeze(0), (input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]))
            predGTTensor = F.interpolate(fake_img_c1_s1[trImgCntr].unsqueeze(0), (input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]))

            # inpTensor = cv2.resize(tensor2im(image_input_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])), tuple(input_1_shape[trImgCntr]))
            # gtTensor = cv2.resize(tensor2im(image_output_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])), tuple(input_1_shape[trImgCntr]))
            # predTensor = cv2.resize(tensor2im(fake_img_c2_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])), tuple(input_1_shape[trImgCntr]))
            # predGTTensor = cv2.resize(tensor2im(fake_img_c1_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])), tuple(input_1_shape[trImgCntr]))
            # inpTensor = cv2.medianBlur(cv2.resize(tensor2im(image_input_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])), tuple(input_1_shape[trImgCntr])), 5)
            # gtTensor = cv2.medianBlur(cv2.resize(tensor2im(image_output_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])), tuple(input_1_shape[trImgCntr])), 5)
            # predTensor = cv2.medianBlur(cv2.resize(tensor2im(fake_img_c2_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])), tuple(input_1_shape[trImgCntr])), 5)
            # pdb.set_trace()

            if not(opt.realVaData):
                evalMSE = mean_squared_error(tensor2im(gtTensor.squeeze())/255, tensor2im(predTensor.squeeze())/255)
                # evalMSE = mean_squared_error(gtTensor/255, predTensor/255)
                evalSSIM = structural_similarity(tensor2im(gtTensor.squeeze())/255, tensor2im(predTensor.squeeze())/255, multichannel=True)
                # evalSSIM = structural_similarity(gtTensor/255, predTensor/255, multichannel=True)
                evalPSNR = peak_signal_noise_ratio(tensor2im(gtTensor.squeeze())/255, tensor2im(predTensor.squeeze())/255)
                # evalPSNR = peak_signal_noise_ratio(gtTensor/255, predTensor/255)
                # print(evalMSE, evalSSIM, evalPSNR)

                valMSE += evalMSE
                valSSIM += evalSSIM
                valPSNR += evalPSNR

            # ocr accuracy
            # for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            c1_s1_input_gt = labels_gt[trImgCntr]
            c1_s1_input_ocr = preds_str_gt_1[trImgCntr]
            c2_s1_gen_gt = labels_z_c[trImgCntr]
            c2_s1_gen_ocr = preds_str_fake_img_c2_s1[trImgCntr]

            if c1_s1_input_gt == c1_s1_input_ocr:
                c1_s1_input_correct += 1
            if c2_s1_gen_gt == c2_s1_gen_ocr:
                c2_s1_gen_correct += 1

            # ICDAR2019 Normalized Edit Distance
            if len(c1_s1_input_gt) == 0 or len(c1_s1_input_ocr) == 0:
                c1_s1_input_ed_correct += 0
            elif len(c1_s1_input_gt) > len(c1_s1_input_ocr):
                c1_s1_input_ed_correct += 1 - edit_distance(c1_s1_input_ocr, c1_s1_input_gt) / len(c1_s1_input_gt)
            else:
                c1_s1_input_ed_correct += 1 - edit_distance(c1_s1_input_ocr, c1_s1_input_gt) / len(c1_s1_input_ocr)

            if len(c2_s1_gen_gt) == 0 or len(c2_s1_gen_ocr) == 0:
                c2_s1_gen_ed_correct += 0
            elif len(c2_s1_gen_gt) > len(c2_s1_gen_ocr):
                c2_s1_gen_ed_correct += 1 - edit_distance(c2_s1_gen_ocr, c2_s1_gen_gt) / len(c2_s1_gen_gt)
            else:
                c2_s1_gen_ed_correct += 1 - edit_distance(c2_s1_gen_ocr, c2_s1_gen_gt) / len(c2_s1_gen_ocr)

            evalCntr += 1

            # save generated images
            if opt.visFlag and iCntr > 500:
                pass
            else:
                try:
                    if iCntr == 0:
                        # update website
                        webpage = html.HTML(pathPrefix, 'Experiment name = %s' % 'Test')
                        webpage.add_header('Testing iteration')
                    iCntr += 1

                    img_path_c1_s1 = os.path.join(labelId + '_pred_val_c1_s1_' + labels_gt[trImgCntr] + '_ocr_' + preds_str_fake_img_c1_s1[trImgCntr] + '.png')
                    img_path_gt_1 = os.path.join(labelId + '_gt_val_c1_s1_' + labels_gt[trImgCntr] + '_ocr_' + preds_str_gt_1[trImgCntr] + '.png')
                    img_path_gt_2 = os.path.join(labelId + '_gt_val_c2_s1_' + labels_z_c[trImgCntr] + '_ocr_' + preds_str_gt_2[trImgCntr] + '.png')
                    img_path_c2_s1 = os.path.join(labelId + '_pred_val_c2_s1_' + labels_z_c[trImgCntr] + '_ocr_' + preds_str_fake_img_c2_s1[trImgCntr] + '.png')

                    ims.append([img_path_gt_1, img_path_c1_s1, img_path_gt_2, img_path_c2_s1])

                    content_c1_s1 = 'PSTYLE-1 ' + 'Text-1:' + labels_gt[trImgCntr] + ' OCR:' + preds_str_fake_img_c1_s1[trImgCntr]
                    content_gt_1 = 'OSTYLE-1 ' + 'GT:' + labels_gt[trImgCntr] + ' OCR:' + preds_str_gt_1[trImgCntr]
                    content_gt_2 = 'OSTYLE-1 ' + 'GT:' + labels_z_c[trImgCntr] + ' OCR:' + preds_str_gt_2[trImgCntr]
                    content_c2_s1 = 'PSTYLE-1 ' + 'Text-2:' + labels_z_c[trImgCntr] + ' OCR:' + preds_str_fake_img_c2_s1[trImgCntr]

                    txts.append([content_gt_1, content_c1_s1, content_gt_2, content_c2_s1])

                    utils.save_image(predGTTensor, os.path.join(pathPrefix, labelId + '_pred_val_c1_s1_' + labels_gt[trImgCntr] + '_ocr_' + preds_str_fake_img_c1_s1[trImgCntr] + '.png'), nrow=1, normalize=True, range=(-1, 1))
                    # cv2.imwrite(os.path.join(pathPrefix, labelId + '_pred_val_c1_s1_' + labels_gt[trImgCntr] + '_ocr_' + preds_str_fake_img_c1_s1[trImgCntr] + '.png'), predGTTensor)
                    # pdb.set_trace()

                    if not opt.zAlone:
                        utils.save_image(inpTensor, os.path.join(pathPrefix, labelId + '_gt_val_c1_s1_' + labels_gt[trImgCntr] + '_ocr_' + preds_str_gt_1[trImgCntr] + '.png'), nrow=1, normalize=True, range=(-1, 1))
                        utils.save_image(gtTensor, os.path.join(pathPrefix, labelId + '_gt_val_c2_s1_' + labels_z_c[trImgCntr] + '_ocr_' + preds_str_gt_2[trImgCntr] + '.png'), nrow=1, normalize=True, range=(-1, 1))
                        utils.save_image(predTensor, os.path.join(pathPrefix, labelId + '_pred_val_c2_s1_' + labels_z_c[trImgCntr] + '_ocr_' + preds_str_fake_img_c2_s1[trImgCntr] + '.png'), nrow=1, normalize=True, range=(-1, 1))
                        # imsave(os.path.join(pathPrefix, labelId + '_gt_val_c1_s1_' + labels_gt[trImgCntr] + '_ocr_' + preds_str_gt_1[trImgCntr] + '.png'), inpTensor)
                        # imsave(os.path.join(pathPrefix, labelId + '_gt_val_c2_s1_' + labels_z_c[trImgCntr] + '_ocr_' + preds_str_gt_2[trImgCntr] + '.png'), gtTensor)
                        # imsave(os.path.join(pathPrefix, labelId + '_pred_val_c2_s1_' + labels_z_c[trImgCntr] + '_ocr_' + preds_str_fake_img_c2_s1[trImgCntr] + '.png'), predTensor)
                        # cv2.imwrite(os.path.join(pathPrefix, labelId + '_gt_val_c1_s1_' + labels_gt[trImgCntr] + '_ocr_' + preds_str_gt_1[trImgCntr] + '.png'), inpTensor)
                        # cv2.imwrite(os.path.join(pathPrefix, labelId + '_gt_val_c2_s1_' + labels_z_c[trImgCntr] + '_ocr_' + preds_str_gt_2[trImgCntr] + '.png'), gtTensor)
                        # cv2.imwrite(os.path.join(pathPrefix, labelId + '_pred_val_c2_s1_' + labels_z_c[trImgCntr] + '_ocr_' + preds_str_fake_img_c2_s1[trImgCntr] + '.png'), predTensor)
                except:
                    print('Warning while saving validation image')

            fCntr += 1

    webpage.add_images(ims, txts, width=256, realFlag=opt.realVaData)
    webpage.save()

    avg_valMSE = valMSE / float(evalCntr)
    avg_valSSIM = valSSIM / float(evalCntr)
    avg_valPSNR = valPSNR / float(evalCntr)
    avg_c1_s1_input_wer = c1_s1_input_correct / float(evalCntr)
    avg_c2_s1_gen_wer = c2_s1_gen_correct / float(evalCntr)
    avg_c1_s1_input_cer = c1_s1_input_ed_correct / float(evalCntr)
    avg_c2_s1_gen_cer = c2_s1_gen_ed_correct / float(evalCntr)

    # if not(opt.realVaData):
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_test.txt'), 'a') as log:
        # test metrics
        if opt.realVaData:
            loss_log = f'Test Input Word Acc: {avg_c1_s1_input_wer:0.5f}, Test Gen Word Acc: {avg_c2_s1_gen_wer:0.5f}, Test Input Char Acc: {avg_c1_s1_input_cer:0.5f}, Test Gen Char Acc: {avg_c2_s1_gen_cer:0.5f}'
        else:
            loss_log = f'Test MSE: {avg_valMSE:0.5f}, Test SSIM: {avg_valSSIM:0.5f}, Test PSNR: {avg_valPSNR:0.5f}, Test Input Word Acc: {avg_c1_s1_input_wer:0.5f}, Test Gen Word Acc: {avg_c2_s1_gen_wer:0.5f}, Test Input Char Acc: {avg_c1_s1_input_cer:0.5f}, Test Gen Char Acc: {avg_c2_s1_gen_cer:0.5f}'
        print(loss_log)
        log.write(loss_log + "\n")
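
# ---------------------------------------------------------------------------
# Reference only: the per-sample OCR scoring inside test() can be summarised by
# the helper below.  It restates the exact-match word accuracy and the
# ICDAR2019-style normalised edit-distance character accuracy used above
# (normalising by the longer of the two strings); it is illustrative and is not
# called by test().  `edit_distance` is assumed to be a Levenshtein distance
# such as nltk.metrics.distance.edit_distance.
# ---------------------------------------------------------------------------
def _sketch_ocr_scores(gt, pred, edit_distance):
    """Return (word_correct, char_score) for one ground-truth/prediction pair."""
    word_correct = 1.0 if gt == pred else 0.0
    if len(gt) == 0 or len(pred) == 0:
        char_score = 0.0
    else:
        # Dividing by max(len(gt), len(pred)) matches the len(gt) > len(pred) branches above.
        char_score = 1.0 - edit_distance(pred, gt) / max(len(gt), len(pred))
    return word_correct, char_score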