def train(opt):
    """Train (or fine-tune) an attention-based text recognizer.

    The loop never returns normally: once ``i == opt.num_iter`` it prints
    ``'end the training'`` and calls ``sys.exit()``.

    Module-level names this function relies on but does not define:
    ``save_dir`` (root directory holding ``<experiment_name>/`` logs and
    checkpoints), ``device``, ``Model``, ``Batch_Balanced_Dataset``,
    ``AlignCollate``, ``hierarchical_dataset``, ``AttnLabelConverter``,
    ``Averager``, ``validation`` and ``SummaryWriter``.

    Args:
        opt: option namespace. Mutated in place: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``, ``num_class`` is set from the
            label converter, and ``input_channel`` is set to 3 when ``opt.rgb``.
    """
    writer = None
    if opt.use_tb:
        tb_dir = f'/home_hongdo/{getpass.getuser()}/tb/{opt.experiment_name}'
        print('tensorboard : ', tb_dir)
        os.makedirs(tb_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=tb_dir)

    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # Context manager guarantees the dataset log is closed even if building
    # the validation set raises.
    with open(f'{save_dir}/{opt.experiment_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    if opt.rgb:
        opt.input_channel = 3
    # sekim for transfer learning: the network is built with a fixed 38-class
    # head (presumably the class count of the pretrained checkpoint -- TODO
    # confirm); the prediction head is replaced for opt.num_class below.
    model = Model(opt, 38)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel,
          opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm: kaiming_normal_ rejects tensors with < 2 dims.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))

    # sekim change last layer: rebuild the attention LSTM cell and the output
    # projection so they match opt.num_class. The 256 sizes are hard-coded and
    # presumably equal the decoder hidden size -- TODO confirm vs opt.hidden_size.
    in_feature = model.module.Prediction.generator.in_features
    model.module.Prediction.attention_cell.rnn = nn.LSTMCell(256 + opt.num_class, 256).to(device)
    model.module.Prediction.generator = nn.Linear(in_feature, opt.num_class).to(device)
    print(model.module.Prediction.generator)
    print("Model:")
    print(model)
    model.train()

    # ---- setup loss ----
    criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    loss_avg = Averager()

    # keep only the parameters that require gradient descent
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'{save_dir}/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        for k, v in vars(opt).items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        # Resume the counter from a checkpoint named like '..._<iter>.pth'.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print("-------------------------------------------------")
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass  # name carries no iteration number: start from 0

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, _ = converter.encode(labels, batch_max_length=opt.batch_max_length)

        preds = model(image, text[:, :-1])  # align with Attention.forward
        target = text[:, 1:]  # without [GO] Symbol
        cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            with open(f'{save_dir}/{opt.experiment_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds, confidence_score,
                     labels, infer_time, length_of_data) = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # Read the running train loss BEFORE resetting the averager so the
                # text log and TensorBoard both report the real value.
                train_loss = loss_avg.val()
                loss_avg.reset()
                loss_log = (f'[{i}/{opt.num_iter}] Train loss: {train_loss:0.5f}, '
                            f'Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}')
                if writer is not None:
                    writer.add_scalar('OCR_loss/train_loss', train_loss, i)
                    writer.add_scalar('OCR_loss/validation_loss', valid_loss, i)

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset); higher norm_ED is better here
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'{save_dir}/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'{save_dir}/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    # Cut at the end-of-sequence token. split() keeps the whole
                    # string when '[s]' is absent, whereas s[:s.find('[s]')]
                    # would silently drop the last character (find returns -1).
                    gt = gt.split('[s]', 1)[0]
                    pred = pred.split('[s]', 1)[0]
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (i + 1) % 100000 == 0:
            torch.save(model.state_dict(), f'{save_dir}/{opt.experiment_name}/iter_{i + 1}.pth')

        if i == opt.num_iter:
            print('end the training')
            if writer is not None:
                writer.close()  # flush pending TensorBoard events before exiting
            sys.exit()
        i += 1
def train(self, opt):
    """Adversarial domain-adaptation training of the recognizer (source -> target).

    Each step runs the recognizer on a labeled source batch and on a target
    batch. It computes the source recognition loss plus image-level (global)
    and instance-level (local) domain-discriminator losses, then steps the
    recognizer optimizer and both discriminator optimizers. For the
    discriminators, source features are labeled 0 and target features 1.
    Steps run over ``start_iter .. opt.num_iter`` inclusive. Validation,
    logging and best-model checkpointing happen every ``opt.valInterval`` steps.

    Uses attributes of ``self`` (``model``, ``global_discriminator``,
    ``local_discriminator``, ``optimizer``, ``d_image_opt``, ``d_inst_opt``,
    ``criterion``, ``D_criterion``, ``converter``, ``dataloader``, ``load``,
    ``save``, ``print_prediction_result``). It also uses module-level names
    defined elsewhere (``device``, ``Averager``, ``validation``,
    ``filter_local_features``).

    Args:
        opt: option namespace.
    """
    # src, tar dataloaders
    src_dataset, tar_dataset, valid_loader = self.dataloader(opt)
    train_size = max(src_dataset.total_data_size, tar_dataset.total_data_size)
    # At least 1 so that `step // iters_per_epoch` cannot raise ZeroDivisionError
    # when the larger dataset holds fewer samples than one batch.
    iters_per_epoch = max(1, int(train_size / opt.batch_size))
    # TODO: Modify train size. Make sure both are of same size.
    # TODO: Modify training loop to continue giving src loss after tar is done.

    self.model.train()
    self.global_discriminator.train()
    self.local_discriminator.train()
    start_iter = 0
    if opt.continue_model != '':
        self.load(opt.continue_model)
        print(" [*] Load SUCCESS")

    # loss averagers
    cls_loss_avg = Averager()
    sim_loss_avg = Averager()
    loss_avg = Averager()

    # Linear decay of the discriminator learning rates over the second half of
    # training. max(1, ...) avoids a ZeroDivisionError when num_iter < 2.
    # NOTE(review): only d_image_opt / d_inst_opt are decayed, not self.optimizer.
    decay_steps = max(1, opt.num_iter // 2)
    lr_decrement = opt.lr / decay_steps

    print('training start !')
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6  # lower is better for norm_ED here (compared with `<` below)
    gamma = 0
    omega = 1
    epoch = 0
    for step in range(start_iter, opt.num_iter + 1):
        epoch = step // iters_per_epoch  # currently informational only
        if opt.decay_flag and step > (opt.num_iter // 2):
            # Clamp at 0: with an odd num_iter there is one more decay step than
            # `decay_steps`, which would otherwise make the lr negative.
            for d_opt in (self.d_image_opt, self.d_inst_opt):
                group = d_opt.param_groups[0]
                group['lr'] = max(0.0, group['lr'] - lr_decrement)

        src_image, src_labels = src_dataset.get_batch()
        src_image = src_image.to(device)
        src_text, src_length = self.converter.encode(src_labels, batch_max_length=opt.batch_max_length)

        tar_image, tar_labels = tar_dataset.get_batch()
        tar_image = tar_image.to(device)
        tar_text, tar_length = self.converter.encode(tar_labels, batch_max_length=opt.batch_max_length)

        # Set gradients to zero (recognizer and both domain classifiers).
        self.model.zero_grad()
        self.global_discriminator.zero_grad()
        self.local_discriminator.zero_grad()

        # Source pass: align with Attention.forward (decoder input drops the last token).
        src_preds, src_global_feature, src_local_feature = self.model(src_image, src_text[:, :-1])
        target = src_text[:, 1:]  # without [GO] Symbol
        src_cls_loss = self.criterion(src_preds.view(-1, src_preds.shape[-1]), target.contiguous().view(-1))
        # Flatten: global features to (batch, -1), local features to (-1, last_dim).
        src_global_feature = src_global_feature.view(src_global_feature.shape[0], -1)
        src_local_feature = src_local_feature.view(-1, src_local_feature.shape[-1])

        # Target pass: target labels are only fed as decoder input; no
        # recognition loss is computed for them.
        tar_preds, tar_global_feature, tar_local_feature = self.model(
            tar_image, tar_text[:, :-1], is_train=False)
        tar_global_feature = tar_global_feature.view(tar_global_feature.shape[0], -1)
        tar_local_feature = tar_local_feature.view(-1, tar_local_feature.shape[-1])

        src_local_feature, tar_local_feature = filter_local_features(
            opt, src_local_feature, src_preds, tar_local_feature, tar_preds)

        # Domain-adaptation schedule, refreshed every 2000 steps from training
        # progress p: gamma ramps from 0 towards ~1, omega from 0.5 towards ~0.
        if step % 2000 == 0:
            p = float(step + start_iter) / opt.num_iter
            gamma = 2. / (1. + np.exp(-10 * p)) - 1
            omega = 1 - 1. / (1. + np.exp(-10 * p))
        # presumably the gradient-reversal weight of each discriminator -- TODO confirm
        self.global_discriminator.module.set_beta(gamma)
        self.local_discriminator.module.set_beta(gamma)

        src_d_img_score = self.global_discriminator(src_global_feature)
        src_d_inst_score = self.local_discriminator(src_local_feature)
        tar_d_img_score = self.global_discriminator(tar_global_feature)
        tar_d_inst_score = self.local_discriminator(tar_local_feature)

        # Domain labels: source = 0, target = 1. zeros_like/ones_like already
        # allocate on the score tensor's device.
        src_d_img_loss = self.D_criterion(src_d_img_score, torch.zeros_like(src_d_img_score))
        src_d_inst_loss = self.D_criterion(src_d_inst_score, torch.zeros_like(src_d_inst_score))
        tar_d_img_loss = self.D_criterion(tar_d_img_score, torch.ones_like(tar_d_img_score))
        tar_d_inst_loss = self.D_criterion(tar_d_inst_score, torch.ones_like(tar_d_inst_score))

        d_img_loss = src_d_img_loss + tar_d_img_loss
        d_inst_loss = src_d_inst_loss + tar_d_inst_loss

        # Total loss: recognition + omega-weighted domain losses.
        loss = src_cls_loss.mean() + omega * (d_img_loss.mean() + d_inst_loss.mean())
        loss_avg.add(loss)
        cls_loss_avg.add(src_cls_loss)
        sim_loss_avg.add(d_img_loss + d_inst_loss)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        # recognizer update, then the domain discriminators
        self.optimizer.step()
        self.d_inst_opt.step()
        self.d_image_opt.step()

        # validation part
        if step % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            print(
                f'[{step}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} CLS_Loss: {cls_loss_avg.val():0.5f} SIMI_Loss: {sim_loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
            )
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                log.write(
                    f'[{step}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n'
                )
                loss_avg.reset()
                cls_loss_avg.reset()
                sim_loss_avg.reset()

                self.model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds, labels,
                     infer_time, length_of_data) = validation(
                        self.model, self.criterion, valid_loader, self.converter, opt)

                self.print_prediction_result(preds, labels, log)

                valid_log = f'[{step}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                print(valid_log)
                log.write(valid_log + '\n')

                self.model.train()
                self.global_discriminator.train()
                self.local_discriminator.train()

                # keep best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    save_name = f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                    self.save(opt, save_name)
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    save_name = f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                    self.save(opt, save_name)
                best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                print(best_model_log)
                log.write(best_model_log + '\n')

        # save model per 1e+5 iter.
        if (step + 1) % 100000 == 0:
            save_name = f'./saved_models/{opt.experiment_name}/iter_{step+1}.pth'
            self.save(opt, save_name)
def train(opt):
    """Distributed (DDP + apex AMP) training loop for CTC / Attn / CTC_Attn recognizers.

    ``opt.local_rank`` selects the CUDA device of this process. Gradients are
    accumulated over ``opt.batch_mul`` micro-batches before each optimizer
    step, and the counter ``i`` counts optimizer steps (micro-batches that do
    not step ``continue`` before ``i`` is incremented). The loop never returns
    normally. It calls ``sys.exit()`` when ``i == opt.num_iter``, or once ``i``
    exceeds ``opt.prof_iter`` when ``opt.prof_iter > 0``.

    Relies on module-level names defined elsewhere: ``Batch_Balanced_Dataset``,
    ``CTCLabelConverter``, ``AttnLabelConverter``, ``Model``, ``Averager``,
    ``alpha_loss``, ``reduce_tensor``, ``validation``, ``amp``, ``apex``,
    ``DDP``, ``dist`` and ``traceback``.

    Args:
        opt: option namespace. ``device`` and ``num_class`` (and
            ``input_channel`` when ``opt.rgb``) are written onto it.
    """
    print(opt.local_rank)
    opt.device = torch.device('cuda:{}'.format(opt.local_rank))
    device = opt.device

    # ---- dataset preparation ----
    train_dataset = Batch_Balanced_Dataset(opt)
    valid_loader = train_dataset.getValDataloader()
    print('-' * 80)

    # ---- model configuration ----
    if 'CTC' == opt.Prediction:
        converter = CTCLabelConverter(opt.character, opt)
    elif 'Attn' == opt.Prediction:
        converter = AttnLabelConverter(opt.character, opt)
    elif 'CTC_Attn' == opt.Prediction:
        # tuple: (ctc_converter, attn_converter)
        converter = CTCLabelConverter(opt.character, opt), AttnLabelConverter(opt.character, opt)
    opt.num_class = len(opt.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    model.to(opt.device)
    print(model)
    print('model input parameters', opt.rgb, opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel,
          opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length,
          opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm: kaiming_normal_ rejects tensors with < 2 dims.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # ---- setup loss ----
    if 'CTC' == opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    elif 'Attn' == opt.Prediction:
        # (attention CE loss ignoring [GO] = index 0, attention-position MSE loss)
        criterion = (torch.nn.CrossEntropyLoss(ignore_index=0).to(device),
                     torch.nn.MSELoss(reduction="sum").to(device))
    elif 'CTC_Attn' == opt.Prediction:
        # (CTC loss, attention CE loss, attention-position MSE loss)
        criterion = (torch.nn.CTCLoss(zero_infinity=True).to(device),
                     torch.nn.CrossEntropyLoss(ignore_index=0).to(device),
                     torch.nn.MSELoss(reduction='sum').to(device))
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradient descent
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    if opt.local_rank == 0:
        print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.sgd:
        optimizer = optim.SGD(filtered_parameters, lr=opt.lr, momentum=0.9, weight_decay=opt.weight_decay)
    elif opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    if opt.local_rank == 0:
        print("Optimizer:")
        print(optimizer)

    if opt.sync_bn:
        model = apex.parallel.convert_syncbn_model(model)
    if opt.amp > 1:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O" + str(opt.amp),
                                          keep_batchnorm_fp32=True, loss_scale="dynamic")
    else:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O" + str(opt.amp))

    # data parallel for multi-GPU
    model = DDP(model)

    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        map_location = torch.device('cuda', torch.cuda.current_device())
        try:
            model.load_state_dict(torch.load(opt.continue_model, map_location=map_location))
        except Exception:
            # Strict load failed: copy only the tensors whose key and shape
            # match the current model.
            traceback.print_exc()
            print(f'COPYING pretrained model from {opt.continue_model}')
            pretrained_dict = torch.load(opt.continue_model, map_location=map_location)
            model_dict = model.state_dict()
            pretrained_dict2 = dict()
            for k, v in pretrained_dict.items():
                # An Attn model takes its head from 'Prediction_attn' keys
                # (presumably saved by a CTC_Attn run -- TODO confirm).
                if opt.Prediction == 'Attn' and 'module.Prediction_attn.' in k:
                    k = k.replace('module.Prediction_attn.', 'module.Prediction.')
                if k in model_dict and model_dict[k].shape == v.shape:
                    pretrained_dict2[k] = v
            model_dict.update(pretrained_dict2)
            model.load_state_dict(model_dict)

    model.train()

    # ---- final options ----
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        for k, v in vars(opt).items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        opt_log += str(model)
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6  # lower is better for norm_ED here (compared with `<` below)
    i = start_iter
    ct = opt.batch_mul  # micro-batches left before the next optimizer step
    model.zero_grad()
    dist.barrier()

    while True:
        # train part
        start = time.time()
        image, labels, pos = train_dataset.sync_get_batch()
        data_t = time.time() - start
        start = time.time()
        batch_size = image.size(0)

        if 'CTC' == opt.Prediction:
            text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size).to(device)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format
            # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            torch.backends.cudnn.enabled = False
            try:
                cost = criterion(preds, text, preds_size, length)
            finally:
                torch.backends.cudnn.enabled = True  # restore even if the loss raises
        elif 'Attn' == opt.Prediction:
            text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
            preds = model(image, text[:, :-1])  # align with Attention.forward
            preds_attn = preds[0]
            preds_alpha = preds[1]
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion[0](preds_attn.view(-1, preds_attn.shape[-1]), target.contiguous().view(-1))
            if opt.posreg_w > 0.001:
                cost_pos = alpha_loss(preds_alpha, pos, opt, criterion[1])
                print('attn_cost = ', cost, 'pos_cost = ', cost_pos * opt.posreg_w)
                cost += opt.posreg_w * cost_pos
            else:
                # In this branch `cost` is the attention loss (there is no
                # separate `cost_attn` variable here).
                print('attn_cost = ', cost)
        elif 'CTC_Attn' == opt.Prediction:
            text_ctc, length_ctc = converter[0].encode(labels, batch_max_length=opt.batch_max_length)
            text_attn, length_attn = converter[1].encode(labels, batch_max_length=opt.batch_max_length)

            # ctc prediction and loss (the model takes the attention text as decoder input)
            preds = model(image, text_attn[:, :-1])
            preds_ctc = preds[0].log_softmax(2)
            preds_ctc_size = torch.IntTensor([preds_ctc.size(1)] * batch_size).to(device)
            preds_ctc = preds_ctc.permute(1, 0, 2)  # to use CTCLoss format
            # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            torch.backends.cudnn.enabled = False
            try:
                cost_ctc = criterion[0](preds_ctc, text_ctc, preds_ctc_size, length_ctc)
            finally:
                torch.backends.cudnn.enabled = True

            # attention prediction and loss
            preds_attn = preds[1][0]  # align with Attention.forward
            preds_alpha = preds[1][1]
            target = text_attn[:, 1:]  # without [GO] Symbol
            cost_attn = criterion[1](preds_attn.view(-1, preds_attn.shape[-1]), target.contiguous().view(-1))
            cost = opt.ctc_attn_loss_ratio * cost_ctc + (1 - opt.ctc_attn_loss_ratio) * cost_attn

            if opt.posreg_w > 0.001:
                cost_pos = alpha_loss(preds_alpha, pos, opt, criterion[2])
                cost += opt.posreg_w * cost_pos
                cost_ctc = reduce_tensor(cost_ctc)
                cost_attn = reduce_tensor(cost_attn)
                cost_pos = reduce_tensor(cost_pos)
                if opt.local_rank == 0:
                    print('ctc_cost = ', cost_ctc, 'attn_cost = ', cost_attn, 'pos_cost = ',
                          cost_pos * opt.posreg_w)
            else:
                cost_ctc = reduce_tensor(cost_ctc)
                cost_attn = reduce_tensor(cost_attn)
                if opt.local_rank == 0:
                    print('ctc_cost = ', cost_ctc, 'attn_cost = ', cost_attn)

        # Scale so the accumulated gradient is the mean over batch_mul micro-batches.
        cost /= opt.batch_mul
        if opt.amp > 0:
            with amp.scale_loss(cost, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            cost.backward()

        # https://github.com/davidlmorton/learning-rate-schedules/blob/master/increasing_batch_size_without_increasing_memory.ipynb
        ct -= 1
        if ct == 0:
            if opt.amp > 0:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), opt.grad_clip)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
            optimizer.step()
            model.zero_grad()
            ct = opt.batch_mul
        else:
            continue  # keep accumulating; `i` only advances on optimizer steps

        train_t = time.time() - start
        cost = reduce_tensor(cost)
        loss_avg.add(cost)
        if opt.local_rank == 0:
            print('iter', i, 'loss =', cost, ', data_t=', data_t, ',train_t=', train_t,
                  ', batchsz=', opt.batch_mul * opt.batch_size)
            sys.stdout.flush()

        # validation part
        if (i > 0 and i % opt.valInterval == 0) or (i == 0 and opt.continue_model != ''):
            elapsed_time = time.time() - start_time
            print(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}')
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                log.write(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n')
                loss_avg.reset()

                model.eval()
                with torch.no_grad():
                    if 'CTC_Attn' in opt.Prediction:
                        # we only count for attention accuracy, because ctc is used to help attention
                        (valid_loss, current_accuracy_ctc, current_accuracy, current_norm_ED_ctc,
                         current_norm_ED, preds, labels, infer_time, length_of_data) = validation(
                            model, criterion[1], valid_loader, converter[1], opt, converter[0])
                    elif 'Attn' in opt.Prediction:
                        (valid_loss, current_accuracy, current_norm_ED, preds, labels,
                         infer_time, length_of_data) = validation(
                            model, criterion[0], valid_loader, converter, opt)
                    else:
                        (valid_loss, current_accuracy, current_norm_ED, preds, labels,
                         infer_time, length_of_data) = validation(
                            model, criterion, valid_loader, converter, opt)
                model.train()

                for pred, gt in zip(preds[:10], labels[:10]):
                    if 'Attn' in opt.Prediction:
                        # split() keeps the whole string when '[s]' is absent;
                        # s[:s.find('[s]')] would drop the last character.
                        pred = pred.split('[s]', 1)[0]
                        gt = gt.split('[s]', 1)[0]
                    print(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}')
                    log.write(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}\n')

                valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                if 'CTC_Attn' in opt.Prediction:
                    valid_log += f' ctc_accuracy: {current_accuracy_ctc:0.3f}, ctc_norm_ED: {current_norm_ED_ctc:0.2f}'
                    current_accuracy = max(current_accuracy, current_accuracy_ctc)
                    current_norm_ED = min(current_norm_ED, current_norm_ED_ctc)

                # Only rank 0 logs and writes checkpoints, consistent with the
                # periodic save below, so ranks never write the same file.
                if opt.local_rank == 0:
                    print(valid_log)
                    log.write(valid_log + '\n')

                    # keep best accuracy model
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                        torch.save(model, f'./saved_models/{opt.experiment_name}/best_accuracy.model')
                    if current_norm_ED < best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                        torch.save(model, f'./saved_models/{opt.experiment_name}/best_norm_ED.model')
                    best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                    print(best_model_log)
                    log.write(best_model_log + '\n')

        # save model per iter.
        if (i + 1) % opt.save_interval == 0 and opt.local_rank == 0:
            torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()

        if opt.prof_iter > 0 and i > opt.prof_iter:
            sys.exit()

        i += 1
def train(opt, log):
    """Train a (optionally semi-supervised) text recognizer, then evaluate on benchmarks.

    Builds the labeled (and, when ``opt.semi != "None"``, unlabeled) training
    loaders, the model, loss, optimizer and optional OneCycle scheduler. It
    then trains for iterations ``start_iter + 1 .. opt.num_iter``, validating
    every ``opt.val_interval`` iterations and at iteration 1. Finally it
    reloads ``best_score.pth`` and runs ``benchmark_all_eval``.

    Module-level names this function relies on but does not define: ``device``,
    ``Batch_Balanced_Dataset``, ``AlignCollate``, ``hierarchical_dataset``,
    ``CTCLabelConverter``, ``AttnLabelConverter``, ``Model``,
    ``PseudoLabelLoss``, ``MeanTeacherLoss``, ``Averager``, ``validation``,
    ``adjust_learning_rate``, ``benchmark_all_eval`` and ``tqdm``.

    Args:
        opt: option namespace; several attributes (``batch_ratio`` for the
            ``synth_SA`` preset, ``num_class``, ``sos_token_index``,
            ``eos_token_index``, ``eval_type``) are written onto it.
            ``opt.writer`` must provide ``add_scalar``.
        log: writable file object for setup logs; it is closed by this function.

    Raises:
        ValueError: if ``opt.optimizer`` is not one of ``sgd``, ``adadelta``, ``adam``.
    """
    # ---- dataset preparation ----
    # Named dataset presets; any other value is a '-'-separated list of names.
    real_data = [
        "1.SVT", "2.IIIT", "3.IC13", "4.IC15", "5.COCO", "6.RCTW17",
        "7.Uber", "8.ArT", "9.LSVT", "10.MLT19", "11.ReCTS",
    ]
    presets = {
        "label": real_data,
        "synth": ["MJ", "ST"],
        "synth_SA": ["MJ", "ST", "SA"],
        "mix": real_data + ["MJ", "ST"],
        "mix_SA": real_data + ["MJ", "ST", "SA"],
    }
    if opt.select_data in presets:
        select_data = list(presets[opt.select_data])
        if opt.select_data == "synth_SA":
            opt.batch_ratio = "0.4-0.4-0.2"  # same ratio with SCATTER paper.
    else:
        select_data = opt.select_data.split("-")

    # set batch_ratio for each data (uniform when not given).
    if opt.batch_ratio:
        batch_ratio = opt.batch_ratio.split("-")
    else:
        batch_ratio = [round(1 / len(select_data), 3)] * len(select_data)

    train_loader = Batch_Balanced_Dataset(opt, opt.train_data, select_data, batch_ratio, log)

    if opt.semi != "None":
        select_data_unlabel = ["U1.Book32", "U2.TextVQA", "U3.STVQA"]
        batch_ratio_unlabel = [round(1 / len(select_data_unlabel), 3)] * len(select_data_unlabel)
        dataset_root_unlabel = "data_CVPR2021/training/unlabel/"
        train_loader_unlabel_semi = Batch_Balanced_Dataset(
            opt,
            dataset_root_unlabel,
            select_data_unlabel,
            batch_ratio_unlabel,
            log,
            learn_type="semi",
        )

    AlignCollate_valid = AlignCollate(opt, mode="test")
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt, mode="test")
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=False,
    )
    log.write(valid_dataset_log)
    print("-" * 80)
    log.write("-" * 80 + "\n")

    # ---- model configuration ----
    if "CTC" in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
        opt.sos_token_index = converter.dict["[SOS]"]
        opt.eos_token_index = converter.dict["[EOS]"]
    opt.num_class = len(converter.character)

    model = Model(opt)

    # weight initialization
    for name, param in model.named_parameters():
        if "localization_fc2" in name:
            print(f"Skip {name} as it is already initialized")
            continue
        try:
            if "bias" in name:
                init.constant_(param, 0.0)
            elif "weight" in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm: kaiming_normal_ rejects tensors with < 2 dims.
            if "weight" in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()

    if opt.saved_model != "":
        fine_tuning_log = f"### loading pretrained model from {opt.saved_model}\n"

        if "MoCo" in opt.saved_model or "MoCo" in opt.self_pre:
            # Keep only the query-encoder weights, with the 'encoder_q.' prefix stripped.
            pretrained_state_dict_qk = torch.load(opt.saved_model)
            pretrained_state_dict = {}
            for name in pretrained_state_dict_qk:
                if "encoder_q" in name:
                    rename = name.replace("encoder_q.", "")
                    pretrained_state_dict[rename] = pretrained_state_dict_qk[name]
        else:
            pretrained_state_dict = torch.load(opt.saved_model)

        for name, param in model.named_parameters():
            try:
                param.data.copy_(pretrained_state_dict[name].data)  # load from pretrained model
            except (KeyError, RuntimeError):
                # missing key or shape mismatch: keep the fresh initialization
                fine_tuning_log += f"non-pretrained layer: {name}\n"
                continue
            if opt.FT == "freeze":
                param.requires_grad = False  # frozen params are excluded from the optimizer below
                fine_tuning_log += f"pretrained layer (freezed): {name}\n"
            else:
                fine_tuning_log += f"pretrained layer: {name}\n"

        print(fine_tuning_log)
        log.write(fine_tuning_log + "\n")

    log.write(repr(model) + "\n")

    # ---- setup loss ----
    if "CTC" in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        # ignore [PAD] token
        criterion = torch.nn.CrossEntropyLoss(ignore_index=converter.dict["[PAD]"]).to(device)

    if "Pseudo" in opt.semi:
        criterion_SemiSL = PseudoLabelLoss(opt, converter, criterion)
    elif "MeanT" in opt.semi:
        criterion_SemiSL = MeanTeacherLoss(opt, student_for_init_teacher=model)

    # loss averagers
    train_loss_avg = Averager()
    semi_loss_avg = Averager()  # semi supervised loss avg

    # keep only the parameters that require gradient descent
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print(f"Trainable params num: {sum(params_num)}")
    log.write(f"Trainable params num: {sum(params_num)}\n")

    # setup optimizer
    if opt.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            filtered_parameters,
            lr=opt.lr,
            momentum=opt.sgd_momentum,
            weight_decay=opt.sgd_weight_decay,
        )
    elif opt.optimizer == "adadelta":
        optimizer = torch.optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    elif opt.optimizer == "adam":
        optimizer = torch.optim.Adam(filtered_parameters, lr=opt.lr)
    else:
        # Fail here instead of with a NameError on `optimizer` further down.
        raise ValueError(
            f"Unsupported optimizer: {opt.optimizer!r} (expected 'sgd', 'adadelta' or 'adam')")
    print("Optimizer:")
    print(optimizer)
    log.write(repr(optimizer) + "\n")

    if "super" in opt.schedule:
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=opt.lr,
            cycle_momentum=(opt.optimizer == "sgd"),  # only SGD has momentum to cycle
            div_factor=20,
            final_div_factor=1000,
            total_steps=opt.num_iter,
        )
        print("Scheduler:")
        print(scheduler)
        log.write(repr(scheduler) + "\n")

    # ---- final options ----
    opt_log = "------------ Options -------------\n"
    for k, v in vars(opt).items():
        if str(k) == "character" and len(str(v)) > 500:
            opt_log += f"{str(k)}: So many characters to show all: number of characters: {len(str(v))}\n"
        else:
            opt_log += f"{str(k)}: {str(v)}\n"
    opt_log += "---------------------------------------\n"
    print(opt_log)
    log.write(opt_log)
    log.close()

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != "":
        # Resume the counter from a checkpoint named like '..._<iter>.pth'.
        try:
            start_iter = int(opt.saved_model.split("_")[-1].split(".")[0])
            print(f"continue to train, start_iter: {start_iter}")
        except ValueError:
            pass  # name carries no iteration number: start from 0

    start_time = time.time()
    best_score = -1
    iteration = start_iter  # defined even if the loop below does not execute

    # training loop; `initial` keeps the progress bar correct when resuming
    for iteration in tqdm(
        range(start_iter + 1, opt.num_iter + 1),
        total=opt.num_iter,
        initial=start_iter,
        position=0,
        leave=True,
    ):
        if "MeanT" in opt.semi:
            image_tensors, image_tensors_ema, labels = train_loader.get_batch_ema()
        else:
            image_tensors, labels = train_loader.get_batch()

        image = image_tensors.to(device)
        labels_index, labels_length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        # default recognition loss part
        if "CTC" in opt.Prediction:
            preds = model(image)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds_log_softmax = preds.log_softmax(2).permute(1, 0, 2)
            loss = criterion(preds_log_softmax, labels_index, preds_size, labels_length)
        else:
            preds = model(image, labels_index[:, :-1])  # align with Attention.forward
            target = labels_index[:, 1:]  # without [SOS] Symbol
            loss = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        # semi supervised part (SemiSL)
        if "Pseudo" in opt.semi:
            image_unlabel, _ = train_loader_unlabel_semi.get_batch_two_images()
            image_unlabel = image_unlabel.to(device)
            loss_SemiSL = criterion_SemiSL(image_unlabel, model)

            loss = loss + loss_SemiSL
            semi_loss_avg.add(loss_SemiSL)

        elif "MeanT" in opt.semi:
            image_tensors_unlabel, image_tensors_unlabel_ema = train_loader_unlabel_semi.get_batch_two_images()
            image_unlabel = image_tensors_unlabel.to(device)
            student_input = torch.cat([image, image_unlabel], dim=0)

            image_ema = image_tensors_ema.to(device)
            image_unlabel_ema = image_tensors_unlabel_ema.to(device)
            teacher_input = torch.cat([image_ema, image_unlabel_ema], dim=0)
            loss_SemiSL = criterion_SemiSL(
                student_input=student_input,
                student_logit=preds,
                student=model,
                teacher_input=teacher_input,
                iteration=iteration,
            )

            loss = loss + loss_SemiSL
            semi_loss_avg.add(loss_SemiSL)

        model.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()
        train_loss_avg.add(loss)

        if "super" in opt.schedule:
            scheduler.step()
        else:
            adjust_learning_rate(optimizer, iteration, opt)

        # validation part.
        # To see training progress, we also conduct validation when 'iteration == 1'
        if iteration % opt.val_interval == 0 or iteration == 1:
            # for validation log
            with open(f"./saved_models/{opt.exp_name}/log_train.txt", "a") as log:
                model.eval()
                with torch.no_grad():
                    (
                        valid_loss,
                        current_score,
                        preds,
                        confidence_score,
                        labels,
                        infer_time,
                        length_of_data,
                    ) = validation(model, criterion, valid_loader, converter, opt)
                model.train()

                # keep best score (accuracy or norm ED) model on valid dataset
                # Do not use this on test datasets. It would be an unfair comparison
                # (training should be done without referring test set).
                if current_score > best_score:
                    best_score = current_score
                    torch.save(model.state_dict(), f"./saved_models/{opt.exp_name}/best_score.pth")

                # validation log: loss, lr, score (accuracy or norm ED), time.
                lr = optimizer.param_groups[0]["lr"]
                elapsed_time = time.time() - start_time
                valid_log = f"\n[{iteration}/{opt.num_iter}] Train_loss: {train_loss_avg.val():0.5f}, Valid_loss: {valid_loss:0.5f}"
                valid_log += f", Semi_loss: {semi_loss_avg.val():0.5f}\n"
                valid_log += f'{"Current_score":17s}: {current_score:0.2f}, Current_lr: {lr:0.7f}\n'
                valid_log += f'{"Best_score":17s}: {best_score:0.2f}, Infer_time: {infer_time:0.1f}, Elapsed_time: {elapsed_time:0.1f}'

                # show some predicted results
                dashed_line = "-" * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f"{dashed_line}\n{head}\n{dashed_line}\n"
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if "Attn" in opt.Prediction:
                        # split() keeps the whole string when "[EOS]" is absent;
                        # s[:s.find("[EOS]")] would drop the last character.
                        gt = gt.split("[EOS]", 1)[0]
                        pred = pred.split("[EOS]", 1)[0]
                    predicted_result_log += f"{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n"
                predicted_result_log += f"{dashed_line}"
                valid_log = f"{valid_log}\n{predicted_result_log}"
                print(valid_log)
                log.write(valid_log + "\n")

                opt.writer.add_scalar("train/train_loss", float(f"{train_loss_avg.val():0.5f}"), iteration)
                opt.writer.add_scalar("train/semi_loss", float(f"{semi_loss_avg.val():0.5f}"), iteration)
                opt.writer.add_scalar("train/lr", float(f"{lr:0.7f}"), iteration)
                opt.writer.add_scalar("train/elapsed_time", float(f"{elapsed_time:0.1f}"), iteration)
                opt.writer.add_scalar("valid/valid_loss", float(f"{valid_loss:0.5f}"), iteration)
                opt.writer.add_scalar("valid/current_score", float(f"{current_score:0.2f}"), iteration)
                opt.writer.add_scalar("valid/best_score", float(f"{best_score:0.2f}"), iteration)

                train_loss_avg.reset()
                semi_loss_avg.reset()

    # ---- evaluation at the end of training ----
    print("Start evaluation on benchmark testset")
    # keep evaluation model and result logs
    os.makedirs(f"./result/{opt.exp_name}", exist_ok=True)
    os.makedirs("./evaluation_log", exist_ok=True)
    saved_best_model = f"./saved_models/{opt.exp_name}/best_score.pth"
    model.load_state_dict(torch.load(f"{saved_best_model}"))

    opt.eval_type = "benchmark"
    model.eval()
    with torch.no_grad():
        total_accuracy, eval_data_list, accuracy_list = benchmark_all_eval(model, criterion, converter, opt)

    opt.writer.add_scalar("test/total_accuracy", float(f"{total_accuracy:0.2f}"), iteration)
    for eval_data, accuracy in zip(eval_data_list, accuracy_list):
        accuracy = float(accuracy)
        opt.writer.add_scalar(f"test/{eval_data}", float(f"{accuracy:0.2f}"), iteration)

    print(f'finished the experiment: {opt.exp_name}, "CUDA_VISIBLE_DEVICES" was {opt.CUDA_VISIBLE_DEVICES}')
def train(opt):
    """Train a text recognizer whose ``Model`` owns its optimizer and checkpointing.

    Builds the balanced training set and a shuffled validation loader, picks a
    CTC or attention label converter and loss, initializes weights, optionally
    loads weights through the model's own helpers, then loops until
    ``opt.num_iter`` iterations have run and exits the process via
    ``sys.exit()``. Validation runs at iteration 0 and every
    ``opt.valInterval`` iterations. Scalars go to a TensorBoard
    ``SummaryWriter`` and text logs go to ``./saved_models/{opt.exp_name}/``.

    Args:
        opt: parsed option namespace. It is mutated in place:
            ``select_data`` and ``batch_ratio`` are split on ``'-'``,
            ``num_class`` is set from the converter, and ``input_channel`` is
            forced to 3 when ``opt.rgb`` is set.

    Raises:
        ValueError: ``opt.saved_model`` is set but neither ``opt.FT`` nor
            ``opt.continue_train`` says how to load it.
    """
    # --- dataset preparation ---
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    # train_dataset (image, label)
    train_dataset = Batch_Balanced_Dataset(opt)

    # `with` guarantees the dataset log is closed even if building the
    # validation set raises.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # --- model configuration ---
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    print("Model:")
    print(model)
    total_num, true_grad_num, false_grad_num = calculate_model_params(model)
    print("Total parameters: ", total_num)
    print("Number of parameters requires grad: ", true_grad_num)
    print("Number of parameters do not require grad: ", false_grad_num)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm: 1-D weights reject kaiming init.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # The model's own methods (optimize_parameters, save_checkpoints,
    # load_*) live on the bare module. The original code wrapped the model
    # in DataParallel and immediately unwrapped it, which amounts to moving
    # it to `device`. Training is therefore single-device.
    model = model.to(device)
    model.train()

    # load pretrained model
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_pretrained_networks()
        elif opt.continue_train:
            model.load_checkpoint(opt.model_name)
        else:
            raise ValueError(
                f'opt.saved_model is set ({opt.saved_model!r}) but neither '
                'opt.FT nor opt.continue_train is enabled')

    # --- setup loss ---
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        # ignore [GO] token = ignore index 0
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)
    # loss averager
    loss_avg = Averager()

    log_dir = f'./saved_models/{opt.exp_name}'
    writer = SummaryWriter(log_dir)

    # --- final options ---
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # --- start training ---
    start_iter = 0
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1  # norm_ED is maximized in this variant.
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        # gradient clipping with 5 (Default)
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        model.optimize_parameters()

        # Log a plain float rather than the graph-carrying loss tensor.
        writer.add_scalar('train_loss', cost.item(), iteration + 1)
        loss_avg.add(cost)

        # validation part
        # To see training progress, we also conduct validation when 'iteration == 0'
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds,
                     confidence_score, labels, infer_time,
                     length_of_data) = validation(model, criterion, valid_loader,
                                                  converter, opt, iteration)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                writer.add_scalar('val_loss', valid_loss, iteration + 1)
                writer.add_scalar('accuracy', current_accuracy, iteration + 1)
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    model.save_checkpoints(iteration, 'best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    model.save_checkpoints(iteration, 'best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        # Cut at the first '[s]'. split() keeps the whole string
                        # when the token is absent, whereas s[:s.find(tok)]
                        # would silently drop the last character.
                        gt = gt.split('[s]', 1)[0]
                        pred = pred.split('[s]', 1)[0]
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            model.save_checkpoints(iteration + 1, opt.model_name)

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            # Flush pending TensorBoard events before the process exits.
            writer.close()
            sys.exit()
        iteration += 1
def train(opt):
    """Train a Transformer / CTC / attention text recognizer with periodic validation.

    The sequence model selects the label converter, the loss and, for a
    Transformer with ``opt.use_scheduled_optim`` (and without ``opt.adam``),
    a warm-up ``ScheduledOptim``.

    Weights can be partially initialized from ``opt.load_weights``. Only
    tensors whose name and size match the current model are copied.
    Training can be resumed from ``opt.continue_model``, which restores the
    model, step, optimizer state and best scores.

    The loop runs ``opt.num_iter`` iterations and validates every
    ``opt.valInterval`` iterations. It writes the following into
    ``./saved_models/{opt.experiment_name}/``:

    * ``best_accuracy.pth`` and ``best_norm_ED.pth`` on improvement (lower
      norm_ED is better here);
    * ``iter_{N}.pth`` every 1000 iterations.

    Args:
        opt: parsed option namespace. It is mutated in place:
            ``select_data`` and ``batch_ratio`` are split on ``'-'``;
            ``num_class`` and possibly ``input_channel`` are set.
    """
    # --- dataset preparation ---
    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        # 'True' to check training progress with validation function.
        shuffle=True,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    print('-' * 80)

    # --- model configuration ---
    if 'Transformer' in opt.SequenceModeling:
        converter = TransformerLabelConverter(opt.character)
    elif 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm: 1-D weights reject kaiming init.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # --- setup loss ---
    if 'Transformer' in opt.SequenceModeling:
        criterion = transformer_loss
    elif 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
    else:
        # ignore [GO] token = ignore index 0
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda()
    # loss averager
    loss_avg = Averager()

    # filter that only require gradient decent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer.
    # optimizer_schedule is bound on every path. Previously, opt.adam combined
    # with a Transformer + use_scheduled_optim left it unbound and crashed
    # later with NameError.
    optimizer_schedule = None
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    elif 'Transformer' in opt.SequenceModeling and opt.use_scheduled_optim:
        optimizer = optim.Adam(filtered_parameters, betas=(0.9, 0.98), eps=1e-09)
        optimizer_schedule = ScheduledOptim(optimizer, opt.d_model,
                                            opt.n_warmup_steps)
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr,
                                   rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # --- final options ---
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # --- start training ---
    start_iter = 0
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6  # norm_ED is minimized in this variant.

    if opt.load_weights != '' and check_isfile(opt.load_weights):
        # Load pretrained weights but ignore layers that don't match in size.
        # encoding='latin1' is passed per call so that Python 2 pickles load,
        # instead of monkey-patching the process-wide pickle module.
        checkpoint = torch.load(opt.load_weights, pickle_module=pickle,
                                encoding='latin1')
        # Accept both {'state_dict': ...} wrappers and bare state dicts. A bare
        # state_dict is an OrderedDict (a dict subclass), hence the key check.
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            pretrain_dict = checkpoint['state_dict']
        else:
            pretrain_dict = checkpoint
        model_dict = model.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and model_dict[k].size() == v.size()
        }
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)
        print("Loaded pretrained weights from '{}'".format(opt.load_weights))
        del checkpoint
        torch.cuda.empty_cache()

    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        checkpoint = torch.load(opt.continue_model)
        print(checkpoint.keys())
        model.load_state_dict(checkpoint['state_dict'])
        start_iter = checkpoint['step'] + 1
        print('continue to train start_iter: ', start_iter)
        if 'optimizer' in checkpoint.keys():
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Move restored optimizer buffers onto the GPU alongside the params.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()
        if 'best_accuracy' in checkpoint.keys():
            best_accuracy = checkpoint['best_accuracy']
        if 'best_norm_ED' in checkpoint.keys():
            best_norm_ED = checkpoint['best_norm_ED']
        del checkpoint
        torch.cuda.empty_cache()

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).cuda()
    model.train()
    print("Model size:", count_num_param(model), 'M')

    if optimizer_schedule is not None:
        optimizer_schedule.n_current_steps = start_iter

    for i in tqdm(range(start_iter, opt.num_iter)):
        for p in model.parameters():
            p.requires_grad = True

        cpu_images, cpu_texts = train_dataset.get_batch()
        image = cpu_images.cuda()
        if 'Transformer' in opt.SequenceModeling:
            text, length, text_pos = converter.encode(cpu_texts, opt.batch_max_length)
        elif 'CTC' in opt.Prediction:
            text, length = converter.encode(cpu_texts)
        else:
            text, length = converter.encode(cpu_texts, opt.batch_max_length)
        batch_size = image.size(0)

        if 'Transformer' in opt.SequenceModeling:
            preds = model(image, text, tgt_pos=text_pos)
            target = text[:, 1:]  # without <s> Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))
        elif 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format
            cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text)
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        if optimizer_schedule is not None:
            optimizer_schedule.step_and_update_lr()
        elif 'Transformer' in opt.SequenceModeling:
            optimizer.step()
        else:
            # gradient clipping with 5 (Default)
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i > 0 and (i + 1) % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            print(
                f'[{i+1}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
            )
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                log.write(
                    f'[{i+1}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n'
                )
                loss_avg.reset()

                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds, gts,
                     infer_time, length_of_data) = validation(
                         model, criterion, valid_loader, converter, opt)
                model.train()

                for pred, gt in zip(preds[:5], gts[:5]):
                    # Cut at the first EOS token. split() keeps the whole
                    # string when the token is absent, whereas
                    # s[:s.find(tok)] would drop the last character.
                    if 'Transformer' in opt.SequenceModeling:
                        pred = pred.split('</s>', 1)[0]
                        gt = gt.split('</s>', 1)[0]
                    elif 'Attn' in opt.Prediction:
                        pred = pred.split('[s]', 1)[0]
                        gt = gt.split('[s]', 1)[0]
                    print(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}')
                    log.write(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}\n')

                valid_log = f'[{i+1}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                print(valid_log)
                log.write(valid_log + '\n')

                # keep best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    state_dict = model.module.state_dict()
                    save_checkpoint(
                        {
                            'best_accuracy': best_accuracy,
                            'state_dict': state_dict,
                        }, False,
                        f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    state_dict = model.module.state_dict()
                    save_checkpoint(
                        {
                            'best_norm_ED': best_norm_ED,
                            'state_dict': state_dict,
                        }, False,
                        f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                print(best_model_log)
                log.write(best_model_log + '\n')

        # save model per 1000 iter (full resumable checkpoint).
        if (i + 1) % 1000 == 0:
            state_dict = model.module.state_dict()
            optimizer_state_dict = optimizer.state_dict()
            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'optimizer': optimizer_state_dict,
                    'step': i,
                    'best_accuracy': best_accuracy,
                    'best_norm_ED': best_norm_ED,
                }, False,
                f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')
valid = matches > -1 mkpts0 = kpts0[valid] mkpts1 = kpts1[matches[valid]] mconf = conf[valid] viz_path = eval_output_dir / '{}_matches.{}'.format( str(i), opt.viz_extension) color = cm.jet(mconf) stem = pred['file_name'] text = [] out = make_matching_plot(image0, image1, kpts0, kpts1, mkpts0, mkpts1, color, text, viz_path, stem, opt.show_keypoints, opt.fast_viz, opt.opencv_display, 'Matches') ### eval ### superglue.eval() homogrpahy_auc = validation(superglue, 'datasets/val.txt') epoch_loss /= len(train_loader) if not os.path.isdir("exp"): os.makedirs("exp") model_out_path = "exp/model_epoch_{}_{}.pth".format( epoch, homogrpahy_auc) torch.save(superglue, model_out_path) print( "Epoch [{}/{}] done. Epoch Loss {}. Checkpoint saved to {}".format( epoch, opt.epoch, epoch_loss, model_out_path))
def train(opt):
    """Train a scene-text recognizer (CTC or attention prediction) until ``opt.num_iter``.

    Builds the balanced training set and a shuffled validation loader, picks
    the label converter and loss from ``opt.Prediction``, initializes
    weights, wraps the model in ``DataParallel`` and optionally loads
    ``opt.saved_model``. When that file name ends in ``_<N>.<ext>``, training
    resumes its iteration counter at N.

    Validation runs at iteration 0 and every ``opt.valInterval`` iterations.
    On improvement the best-accuracy and best-norm_ED weights are saved
    (higher norm_ED is better here), and a snapshot is saved every 1e5
    iterations. All files go under ``./saved_models/{opt.exp_name}/``. The
    process ends via ``sys.exit()`` when ``iteration + 1 == opt.num_iter``.

    Args:
        opt: parsed option namespace. It is mutated in place:
            ``select_data`` and ``batch_ratio`` are split on ``'-'``;
            ``num_class`` and possibly ``input_channel`` are set.
    """
    # --- dataset preparation ---
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # `with` guarantees the dataset log is closed even if building the
    # validation set raises.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # --- model configuration ---
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm: 1-D weights reject kaiming init.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            # Fine-tuning: tolerate missing/unexpected keys.
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # --- setup loss ---
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        # ignore [GO] token = ignore index 0
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)
    # loss averager
    loss_avg = Averager()

    # filter that only require gradient decent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr,
                                   rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # --- final options ---
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # --- start training ---
    start_iter = 0
    if opt.saved_model != '':
        try:
            # e.g. '.../iter_300000.pth' -> 300000
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            # Names like 'best_accuracy.pth' carry no iteration; start at 0.
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1  # norm_ED is maximized in this variant.
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        # gradient clipping with 5 (Default)
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        # To see training progress, we also conduct validation when 'iteration == 0'
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds,
                     confidence_score, labels, infer_time,
                     length_of_data) = validation(model, criterion, valid_loader,
                                                  converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(),
                               f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(),
                               f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        # Cut at the first '[s]'. split() keeps the whole string
                        # when the token is absent, whereas s[:s.find(tok)]
                        # would silently drop the last character.
                        gt = gt.split('[s]', 1)[0]
                        pred = pred.split('[s]', 1)[0]
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
def train(self, opt):
    """Train the recognizer on a labeled source domain while aligning it to a target domain.

    Each step does the following:

    * draws one source batch and one target batch;
    * computes ``self.criterion`` on the source predictions against
      ``src_text[:, 1:]``;
    * adds ``0.1 *`` the mean of a ``coral_loss`` computed from the source
      and target local features and predictions;
    * backpropagates, clips gradients to ``opt.grad_clip`` and steps
      ``self.optimizer``.

    Validation runs whenever ``step % opt.valInterval == 0`` (step 0
    included). The best-accuracy model, the best-norm_ED model (lower is
    better here) and a snapshot every 1e5 steps are written via
    ``self.save`` under ``./saved_models/{opt.experiment_name}/``.

    Note carried over from da-faster-rcnn: when adding a custom dataset,
    add its cfgs and change ``imdb_name`` in factory.py. Dummy format:
        args.src_dataset == '$YOUR_DATASET_NAME'
        args.src_imdb_name = '$YOUR_DATASET_NAME_2007_trainval'
        args.src_imdbval_name = '$YOUR_DATASET_NAME_2007_test'
        args.set_cfgs = [...]

    Args:
        opt: parsed option namespace. Fields read here include
            ``continue_model``, ``num_iter``, ``batch_max_length``,
            ``grad_clip``, ``valInterval`` and ``experiment_name``.
    """
    # src, tar dataloaders
    src_dataset, tar_dataset, valid_loader = self.dataloader(opt)
    # NOTE(review): these three sizes are never used later in this method.
    src_dataset_size = src_dataset.total_data_size
    tar_dataset_size = tar_dataset.total_data_size
    train_size = max([src_dataset_size, tar_dataset_size])

    self.model.train()

    start_iter = 0
    if opt.continue_model != '':
        self.load(opt.continue_model)
        print(" [*] Load SUCCESS")
        # NOTE(review): start_iter stays 0 after loading, so a resumed run
        # restarts its step counter and overwrites iter_* snapshots.
        # Confirm whether self.load can supply the saved step.
        # if opt.decay_flag and start_iter > (opt.num_iter // 2):
        #     self.d_image_opt.param_groups[0]['lr'] -= (opt.lr / (opt.num_iter // 2)) * (
        #         start_iter - opt.num_iter // 2)
        #     self.d_inst_opt.param_groups[0]['lr'] -= (opt.lr / (opt.num_iter // 2)) * (
        #         start_iter - opt.num_iter // 2)

    # loss averagers: total, source classification, domain similarity
    cls_loss_avg = Averager()
    sim_loss_avg = Averager()
    loss_avg = Averager()

    # training loop
    print('training start !')
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    # The range is inclusive of opt.num_iter (note the + 1).
    for step in range(start_iter, opt.num_iter + 1):
        src_image, src_labels = src_dataset.get_batch()
        src_image = src_image.to(device)
        src_text, src_length = self.converter.encode(
            src_labels, batch_max_length=opt.batch_max_length)

        tar_image, tar_labels = tar_dataset.get_batch()
        tar_image = tar_image.to(device)
        tar_text, tar_length = self.converter.encode(
            tar_labels, batch_max_length=opt.batch_max_length)

        # Set gradient to zero...
        self.model.zero_grad()

        # Attention
        # align with Attention.forward
        src_preds, src_global_feature, src_local_feature = self.model(
            src_image, src_text[:, :-1])
        target = src_text[:, 1:]  # without [GO] Symbol
        src_cls_loss = self.criterion(
            src_preds.view(-1, src_preds.shape[-1]),
            target.contiguous().view(-1))
        # Flatten every leading dim so each row is one local feature vector.
        src_local_feature = src_local_feature.view(
            -1, src_local_feature.shape[-1])

        # TODO
        # NOTE(review): target-domain labels (tar_text) are passed to the
        # model. Whether they are used when is_train=False is not visible
        # here; confirm this does not leak target labels.
        tar_preds, tar_global_feature, tar_local_feature = self.model(
            tar_image, tar_text[:, :-1], is_train=False)
        tar_local_feature = tar_local_feature.view(
            -1, tar_local_feature.shape[-1])

        d_inst_loss = coral_loss(src_local_feature, src_preds,
                                 tar_local_feature, tar_preds)

        # Add domain loss (weighted 0.1 relative to the source classification loss)
        loss = src_cls_loss.mean() + 0.1 * d_inst_loss.mean()
        loss_avg.add(loss)
        cls_loss_avg.add(src_cls_loss)
        sim_loss_avg.add(d_inst_loss)

        # frcnn backward
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        # frcnn optimizer update
        self.optimizer.step()

        # validation part (step 0 included, since 0 % valInterval == 0)
        if step % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            print(
                f'[{step}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} CLS_Loss: {cls_loss_avg.val():0.5f} SIMI_Loss: {sim_loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
            )
            # for log
            with open(
                    f'./saved_models/{opt.experiment_name}/log_train.txt',
                    'a') as log:
                log.write(
                    f'[{step}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n'
                )
                loss_avg.reset()
                cls_loss_avg.reset()
                sim_loss_avg.reset()

                self.model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                        self.model, self.criterion, valid_loader,
                        self.converter, opt)
                    self.print_prediction_result(preds, labels, log)
                    valid_log = f'[{step}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                    valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                    print(valid_log)
                    log.write(valid_log + '\n')
                self.model.train()

                # keep best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    save_name = f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                    self.save(opt, save_name)
                # norm_ED is minimized here (best_norm_ED starts at 1e+6).
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    save_name = f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                    self.save(opt, save_name)
                best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                print(best_model_log)
                log.write(best_model_log + '\n')

        # save model per 1e+5 iter.
        if (step + 1) % 1e+5 == 0:
            save_name = f'./saved_models/{opt.experiment_name}/iter_{step+1}.pth'
            self.save(opt, save_name)
def train(opt, AMP, WdB, ralph_path, train_data_path, train_data_list,
          test_data_path, test_data_list, experiment_name, train_batch_size,
          val_batch_size, workers, lr, valInterval, num_iter, wdbprj,
          continue_model='', finetune=''):
    """Train OrigamiNet with CTC loss, an EMA copy of the weights and optional DP/DDP/Horovod.

    Runs blocks of ``valInterval`` training batches. After each block it
    validates the EMA model, logs to
    ``./saved_models/{experiment_name}/log_train.txt`` and saves
    ``best_norm_ED.pth`` (model, EMA and optimizer state) whenever the
    validation norm_ED improves (lower is better).

    Args:
        opt: option namespace. Fields used include ``num_gpu``,
            ``world_size`` and ``rank``; all fields are dumped to ``opt.txt``.
        AMP: mixed-precision flag. NOTE(review): the apex AMP code path is
            commented out, so with ``AMP=True`` each batch is forwarded up to
            3 times and no backward pass runs. Confirm before enabling.
        WdB: Weights & Biases flag. W&B logging is currently disabled.
        ralph_path: JSON file loaded and passed as ``ralph`` to both datasets
            (presumably the alphabet mapping; verify against ``ds_load``).
        train_data_path, train_data_list: training-set location and list.
        test_data_path, test_data_list: validation-set location and list.
        experiment_name: sub-directory of ``./saved_models`` for outputs.
        train_batch_size, val_batch_size: loader batch sizes.
        workers: DataLoader workers. Multiplied by ``opt.num_gpu`` for
            single-process multi-GPU runs.
        lr: initial Adam learning rate.
        valInterval: number of training batches between validations.
        num_iter: stop once this many batches have run (``<= 0`` never stops).
        wdbprj: W&B project name (unused while W&B is disabled).
        continue_model: checkpoint to resume model, optimizer and EMA from.
        finetune: checkpoint passed to ``load_finetune`` before training.
    """
    # Horovod or DDP: one process per GPU with distributed samplers.
    HVD3P = pO.HVD or pO.DDP

    os.makedirs(f'./saved_models/{experiment_name}', exist_ok=True)

    # load supplied ralph
    with open(ralph_path, 'r') as f:
        ralph_train = json.load(f)

    print('[4] IN TRAIN; BEFORE MAKING DATASET')
    train_dataset = ds_load.myLoadDS(train_data_list, train_data_path,
                                     ralph=ralph_train)
    valid_dataset = ds_load.myLoadDS(test_data_list, test_data_path,
                                     ralph=ralph_train)
    print('[5] DATASET DONE LOADING')

    if OnceExecWorker:
        print(pO)
        print('Alphabet :', len(train_dataset.alph), train_dataset.alph)
        for d in [train_dataset, valid_dataset]:
            print('Dataset Size :', len(d.fns))
            print('-' * 80)

    if opt.num_gpu > 1:
        workers = workers * (1 if HVD3P else opt.num_gpu)

    if HVD3P:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=opt.world_size, rank=opt.rank)
        valid_sampler = torch.utils.data.distributed.DistributedSampler(
            valid_dataset, num_replicas=opt.world_size, rank=opt.rank)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        # DataLoader forbids shuffle=True together with an explicit sampler.
        shuffle=not HVD3P,
        pin_memory=True,
        num_workers=int(workers),
        sampler=train_sampler if HVD3P else None,
        worker_init_fn=WrkSeeder,
        collate_fn=ds_load.SameTrCollate)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=val_batch_size,
        pin_memory=True,
        num_workers=int(workers),
        sampler=valid_sampler if HVD3P else None)

    model = OrigamiNet()
    model.apply(init_bn)

    # load finetune ckpt
    if finetune != '':
        model = load_finetune(model, finetune)
    model.train()

    if OnceExecWorker:
        for k in sorted(model.lreszs.keys()):
            print(k, model.lreszs[k])

    if not pO.DDP:
        model = model.to(device)
    else:
        model.cuda(opt.rank)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    # gamma**90000 == 0.1: the LR shrinks 10x every 90000 scheduler steps.
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                          gamma=10**(-1 / 90000))

    if pO.HVD:
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        optimizer = hvd.DistributedOptimizer(
            optimizer, named_parameters=model.named_parameters())

    if pO.DDP and opt.rank != 0:
        # Re-seed non-zero ranks so their Python/NumPy RNG streams differ.
        random.seed()
        np.random.seed()

    if pO.DP:
        model = torch.nn.DataParallel(model)
    elif pO.DDP:
        model = pDDP(model, device_ids=[opt.rank], output_device=opt.rank,
                     find_unused_parameters=False)

    model_ema = ModelEma(model)

    if continue_model != '':
        if OnceExecWorker:
            print(f'loading pretrained model from {continue_model}')
        map_location = f'cuda:{opt.rank}' if HVD3P else None
        checkpoint = torch.load(continue_model, map_location=map_location)
        model.load_state_dict(checkpoint['model'], strict=True)
        optimizer.load_state_dict(checkpoint['optimizer'])
        model_ema._load_checkpoint(continue_model, map_location)

    # reduction='none' yields per-sample losses; they are averaged below.
    criterion = torch.nn.CTCLoss(reduction='none', zero_infinity=True).to(device)
    converter = CTCLabelConverter(train_dataset.ralph.values())

    if OnceExecWorker:
        with open(f'./saved_models/{experiment_name}/opt.txt', 'a') as opt_file:
            opt_log = '------------ Options -------------\n'
            args = vars(opt)
            for k, v in args.items():
                opt_log += f'{str(k)}: {str(v)}\n'
            opt_log += '---------------------------------------\n'
            opt_log += gin.operative_config_str()
            opt_file.write(opt_log)
        print(optimizer)
        print(opt_log)

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    best_CER = 1e+6
    i = 0  # global batch counter
    # Gradient-accumulation factor: the optimizer steps every gAcc batches.
    # NOTE(review): with Horovod and gAcc > 1, DistributedOptimizer also
    # needs backward_passes_per_step=gAcc.
    gAcc = 1
    epoch = 1

    if HVD3P:
        train_sampler.set_epoch(epoch)
    titer = iter(train_loader)

    while True:
        start_time = time.time()

        model.zero_grad()
        train_loss = Metric(pO, 'train_loss')

        for _ in trange(valInterval, leave=False, desc='Training'):
            # Load a batch, restarting the loader (new epoch) when exhausted.
            try:
                image_tensors, labels, fnames = next(titer)
            except StopIteration:
                epoch += 1
                if HVD3P:
                    train_sampler.set_epoch(epoch)
                titer = iter(train_loader)
                image_tensors, labels, fnames = next(titer)

            # Move to device
            image = image_tensors.to(device)
            text, length = converter.encode(labels)
            batch_size = image.size(0)

            # Batch-replay loop kept for the (disabled) AMP path. Without AMP
            # the body runs exactly once.
            replay_batch = True
            maxR = 3
            while replay_batch and maxR > 0:
                maxR -= 1

                # Forward pass
                preds = model(image, text).float()
                preds_size = torch.IntTensor([preds.size(1)] * batch_size).to(device)
                preds = preds.permute(1, 0, 2).log_softmax(2)

                if i == 0 and OnceExecWorker:
                    print('Model inp : ', image.dtype, image.size())
                    print('CTC inp : ', preds.dtype, preds.size(), preds_size[0])

                # To avoid ctc_loss issue, disabled cudnn for the computation of
                # the ctc_loss. try/finally re-enables cudnn even if the loss raises.
                torch.backends.cudnn.enabled = False
                try:
                    cost = criterion(preds, text.to(device), preds_size,
                                     length.to(device)).mean() / gAcc
                finally:
                    torch.backends.cudnn.enabled = True

                train_loss.update(cost)

                # Backward. Gradients are NOT zeroed per batch, so they
                # accumulate over gAcc batches and are cleared right after
                # each optimizer step (and at the start of each block).
                if not AMP:
                    cost.backward()
                    replay_batch = False
                # else: apex AMP scaled backward is disabled (see docstring).

            if (i + 1) % gAcc == 0:
                if pO.HVD and AMP:
                    with optimizer.skip_synchronize():
                        optimizer.step()
                else:
                    optimizer.step()

                model.zero_grad()
                model_ema.update(model, num_updates=i / 2)

                if (i + 1) % (gAcc * 2) == 0:
                    lr_scheduler.step()

            i += 1

        # validation part (after every valInterval training batches)
        elapsed_time = time.time() - start_time
        start_time = time.time()

        model.eval()
        with torch.no_grad():
            (valid_loss, current_accuracy, current_norm_ED, ted, bleu, preds,
             labels, infer_time) = validation(model_ema.ema, criterion,
                                              valid_loader, converter, opt, pO)
        model.train()
        v_time = time.time() - start_time

        if OnceExecWorker:
            if current_norm_ED < best_norm_ED:
                best_norm_ED = current_norm_ED
                checkpoint = {
                    'model': model.state_dict(),
                    'state_dict_ema': model_ema.ema.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(checkpoint,
                           f'./saved_models/{experiment_name}/best_norm_ED.pth')

            if ted < best_CER:
                best_CER = ted

            if current_accuracy > best_accuracy:
                best_accuracy = current_accuracy

            # get_last_lr() reports the LR actually in use. Calling get_lr()
            # outside step() returns the next, gamma-scaled value (and is
            # deprecated).
            current_lr = lr_scheduler.get_last_lr()[0]
            out = f'[{i}] Loss: {train_loss.avg:0.5f} time: ({elapsed_time:0.1f},{v_time:0.1f})'
            out += f' vloss: {valid_loss:0.3f}'
            out += f' CER: {ted:0.4f} NER: {current_norm_ED:0.4f} lr: {current_lr:0.5f}'
            out += f' bAcc: {best_accuracy:0.1f}, bNER: {best_norm_ED:0.4f}, bCER: {best_CER:0.4f}, B: {bleu*100:0.2f}'
            print(out)

            with open(f'./saved_models/{experiment_name}/log_train.txt', 'a') as log:
                log.write(out + '\n')

        if DEBUG:
            print(
                f'[!!!] Iteration check. Value of i: {i} | Value of num_iter: {num_iter}'
            )

        # >= (not ==) so stopping still works if i overshoots num_iter;
        # num_iter <= 0 means "train indefinitely".
        if num_iter > 0 and i >= num_iter:
            print('end the training')
            break
def train(args):
    """Train a U-Net segmentation model using settings read from ``config.yaml``.

    Paths and hyper-parameters come from the ``PATH`` and ``PARAMETERS``
    sections of the config. The network is trained with ``amp``
    mixed precision (``opt_level='O1'``), validated after every epoch,
    checkpointed through ``save_ckp`` and stopped early when ``save_ckp``
    returns a truthy value.

    Args:
        args: namespace providing ``model_name`` and ``loss_fn``; ``loss_fn``
            must be one of 'Dice', 'CrossEntropy', 'FocalLoss', 'Dice_CE',
            'Dice_FL'.

    Raises:
        NotImplementedError: if ``args.loss_fn`` is not a supported loss name.
    """
    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    def _as_bool(value):
        # bool('False') is True, so string config values must be parsed
        # explicitly; genuine booleans / ints keep their usual truthiness.
        if isinstance(value, str):
            return value.strip().lower() in ('1', 'true', 'yes', 'y', 'on')
        return bool(value)

    # user defined
    model_name = args.model_name
    model_loss_fn = args.loss_fn

    config_file = 'config.yaml'
    config = load_config(config_file)
    data_root = config['PATH']['data_root']
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['root']
    model_dir = config['PATH']['model_path']
    best_dir = config['PATH']['best_model_path']
    data_class = config['PATH']['data_class']
    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    base_channel = int(config['PARAMETERS']['base_channels'])
    crop_size = int(config['PARAMETERS']['crop_size'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    epochs = int(config['PARAMETERS']['epoch'])
    is_best = _as_bool(config['PARAMETERS']['is_best'])
    is_resume = _as_bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    ignore_idx = int(config['PARAMETERS']['ignore_index'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_data', model_name)
    train_info_file = os.path.join(intermidiate_data_save, '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    # makedirs(exist_ok=True) also creates missing parent directories
    for directory in (model_path, best_path, intermidiate_data_save, log_path):
        os.makedirs(directory, exist_ok=True)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('Dataset is loading ...')

    # split dataset
    dir_ = os.path.join(data_root, data_class)
    data_content = train_split(dir_)

    # load training set and validation set
    train_set = data_loader(data_content=data_content,
                            key='train',
                            form='LGG',
                            crop_size=crop_size,
                            batch_size=batch_size,
                            num_works=8)
    n_train = len(train_set)
    train_loader = train_set.load()

    val_set = data_loader(data_content=data_content,
                          key='val',
                          form='LGG',
                          crop_size=crop_size,
                          batch_size=batch_size,
                          num_works=8)

    logger.info('Dataset loading finished!')

    n_val = len(val_set)
    # batches per epoch, assuming the loader keeps the last partial batch
    nb_batches = np.ceil(n_train / batch_size)
    n_total = n_train + n_val
    logger.info(
        '{} images will be used in total, {} for trainning and {} for validation'
        .format(n_total, n_train, n_val))

    net = init_U_Net(input_modalites, output_channels, base_channel)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    if torch.cuda.device_count() > 1:
        logger.info('{} GPUs available.'.format(torch.cuda.device_count()))
        net = nn.DataParallel(net)

    net.to(device)

    if model_loss_fn == 'Dice':
        criterion = DiceLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'CrossEntropy':
        criterion = CrossEntropyLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'FocalLoss':
        criterion = FocalLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_CE':
        criterion = Dice_CE(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_FL':
        criterion = Dice_FL(labels=labels, ignore_index=ignore_idx)
    else:
        raise NotImplementedError()

    optimizer = optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=patience)

    net, optimizer = amp.initialize(net, optimizer, opt_level='O1')

    min_loss = float('Inf')
    early_stop_count = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'BG_acc': [],
        'NET_acc': [],
        'ED_acc': [],
        'ET_acc': []
    }

    if is_resume:
        try:
            # save_ckp below is given save_model_dir=model_path, so look for
            # the checkpoint there (root_path/model_dir), not in model_dir.
            ckp_path = os.path.join(model_path, '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(
                ckp_path, net, optimizer, scheduler)

            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)

            logger.info(
                'Training loss from last time is {}'.format(start_loss) + '\n' +
                'Mininum training loss from last time is {}'.format(min_loss))
        except Exception as err:
            # Resuming is best-effort, but record why it failed instead of
            # silently swallowing everything (incl. KeyboardInterrupt).
            logger.warning('No checkpoint available, strat training from scratch')
            logger.warning('Resume failed: {}'.format(err))

    # start training
    for epoch in range(start_epoch, epochs):

        # setup to train mode
        net.train()
        running_loss = 0
        dice_coeff_bg = 0
        dice_coeff_net = 0
        dice_coeff_ed = 0
        dice_coeff_et = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train, desc=f'Epoch {epoch+1}/{epochs}', unit='patch') as pbar:
            for i, data in enumerate(train_loader, 0):
                images, segs = data['image'].to(device), data['seg'].to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs = net(images)
                loss = criterion(outputs, segs)
                # scale the loss through amp before back-propagating
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()

                # save the output at the beginning of each epoch to visualize it
                if i == 0:
                    in_images = images.detach().cpu().numpy()[:, 0, ...]
                    in_segs = segs.detach().cpu().numpy()
                    in_pred = outputs.detach().cpu().numpy()
                    heatmap_plot(image=in_images, mask=in_segs, pred=in_pred,
                                 name=model_name, epoch=epoch + 1)

                running_loss += loss.detach().item()
                dice_score = dice_coe(outputs.detach().cpu(), segs.detach().cpu())
                dice_coeff_bg += dice_score['BG']
                dice_coeff_ed += dice_score['ED']
                dice_coeff_et += dice_score['ET']
                dice_coeff_net += dice_score['NET']

                # show progress bar
                pbar.set_postfix(
                    **{
                        'Training loss': loss.detach().item(),
                        'Training (avg) accuracy': dice_score['avg']
                    })
                pbar.update(images.shape[0])

        # Validate once all of this epoch's batches are consumed, so val_loss,
        # val_acc and verbose are always bound below, whatever the loader's
        # exact length turns out to be.
        net.eval()
        val_loss, val_acc = validation(net, val_set, criterion, device, batch_size)

        train_info['train_loss'].append(running_loss / nb_batches)
        train_info['val_loss'].append(val_loss)
        train_info['BG_acc'].append(dice_coeff_bg / nb_batches)
        train_info['NET_acc'].append(dice_coeff_net / nb_batches)
        train_info['ED_acc'].append(dice_coeff_ed / nb_batches)
        train_info['ET_acc'].append(dice_coeff_et / nb_batches)

        # NOTE(review): the plateau scheduler is driven by the average
        # training loss, not val_loss -- confirm this is intended.
        scheduler.step(running_loss / nb_batches)

        # save best trained model
        if min_loss > val_loss:
            min_loss = val_loss
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': running_loss / nb_batches,
            'min_loss': min_loss
        }
        # a truthy return value from save_ckp means "stop early"
        verbose = save_ckp(state,
                           is_best,
                           early_stop_count=early_stop_count,
                           early_stop_patience=early_stop_patience,
                           save_model_dir=model_path,
                           best_dir=best_path,
                           name=model_name)

        logger.info('The average training loss for this epoch is {}'.format(
            running_loss / (np.ceil(n_train / batch_size))))
        logger.info(
            'Validation dice loss: {}; Validation (avg) accuracy: {}'.format(
                val_loss, val_acc))
        logger.info('The best validation loss till now is {}'.format(min_loss))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        loss_plot(train_info_file, name=model_name)

        if verbose:
            logger.info(
                'The validation loss has not improved for {} epochs, training will stop here.'
                .format(early_stop_patience))
            break

    logger.info('finish training!')
def validation_part(best_accuracy, best_norm_ED, converter, criterion, i,
                    loss_avg, model, opt, start_time, valid_loader):
    """Validate every ``opt.valInterval`` iterations, log results and keep best models.

    When ``i`` is a multiple of ``opt.valInterval``, this function:

    * runs ``validation`` under ``torch.no_grad()``;
    * appends losses, metrics and five sample predictions to
      ``./saved_models/{opt.experiment_name}/log_train.txt``;
    * saves ``best_accuracy.pth`` or ``best_norm_ED.pth`` whenever the
      corresponding metric improves;
    * resets ``loss_avg``.

    Returns:
        tuple: ``(best_accuracy, best_norm_ED)`` after this call. Rebinding
        the parameters inside the function does not affect the caller, so
        the caller must take the updated bests from the return value, e.g.
        ``best_accuracy, best_norm_ED = validation_part(...)``. Otherwise
        later calls compare against stale bests and overwrite the best
        checkpoints with worse models.
    """
    if i % opt.valInterval != 0:
        return best_accuracy, best_norm_ED

    elapsed_time = time.time() - start_time
    # for log
    with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
        model.eval()
        with torch.no_grad():
            valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                model, criterion, valid_loader, converter, opt)
        model.train()

        # training loss and validation loss
        loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
        loss_avg.reset()

        current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

        # keep best accuracy model (on valid dataset)
        if current_accuracy > best_accuracy:
            best_accuracy = current_accuracy
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
        if current_norm_ED > best_norm_ED:
            best_norm_ED = current_norm_ED
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
        best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

        loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
        print(loss_model_log)
        log.write(loss_model_log + '\n')

        # show some predicted results
        dashed_line = '-' * 80
        head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
        predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
        for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
            if 'Attn' in opt.Prediction:
                # cut everything from the end-of-sequence token onward
                gt = gt[:gt.find('[s]')]
                pred = pred[:pred.find('[s]')]

            predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
        predicted_result_log += f'{dashed_line}'
        print(predicted_result_log)
        log.write(predicted_result_log + '\n')

    return best_accuracy, best_norm_ED
def train(opt):
    """Train a text recognizer on the Xunfei English + Chinese cropped-text sets.

    The train and validation sets are ``ConcatDataset``s built from
    hard-coded ``mytrdg_cutimg_dataset`` directories. The model (CTC or
    attention head, chosen by ``opt.Prediction``) and its loss are built
    from ``opt``. Training then cycles over the loader indefinitely. It
    validates at iteration 0 and every ``opt.valInterval`` iterations, and
    saves a snapshot every 1e+4 iterations. It ends the process with
    ``sys.exit()`` after ``opt.num_iter`` iterations.
    """
    # dataset preparation
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)

    train_dataset_xunfeieng = mytrdg_cutimg_dataset(
        total_img_path='/home/ldl/桌面/论文/文本识别/data/finish_data/eng_image/train/img',
        annotation_path='/home/ldl/桌面/论文/文本识别/data/finish_data/eng_image/train/gt')
    train_dataset_xunfeichn = mytrdg_cutimg_dataset(
        total_img_path='/home/ldl/桌面/论文/文本识别/data/finish_data/lan_image/train/img',
        annotation_path='/home/ldl/桌面/论文/文本识别/data/finish_data/lan_image/train/gt')
    train_dataset = ConcatDataset([train_dataset_xunfeichn, train_dataset_xunfeieng])
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid)

    valid_dataset_xunfeieng = mytrdg_cutimg_dataset(
        total_img_path='/home/ldl/桌面/论文/文本识别/data/finish_data/eng_image/test/img',
        annotation_path='/home/ldl/桌面/论文/文本识别/data/finish_data/eng_image/test/gt')
    valid_dataset_xunfeichn = mytrdg_cutimg_dataset(
        total_img_path='/home/ldl/桌面/论文/文本识别/data/finish_data/lan_image/test/img',
        annotation_path='/home/ldl/桌面/论文/文本识别/data/finish_data/lan_image/test/gt')
    valid_dataset = ConcatDataset([valid_dataset_xunfeichn, valid_dataset_xunfeieng])
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid)
    print('-' * 80)

    # model configuration
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # kaiming_normal_ rejects tensors with fewer than 2 dims
            # (e.g. BatchNorm weights); those are set to 1 instead.
            if 'weight' in name:
                param.data.fill_(1)

    # data parallel for multi-GPU
    if opt.num_gpu > 1:
        model = torch.nn.DataParallel(model).to(device)
    else:
        model.to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")

    # setup loss
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # only parameters that require gradient descent are optimized
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))
    print(f"Train dataset length {len(train_dataset)}")
    print(f"Valid dataset length {len(valid_dataset)}")

    # setup optimizer (fixed Adam; opt.adam / opt.lr are not consulted here)
    optimizer = optim.Adam(filtered_parameters, lr=0.0001)
    print("Optimizer:")
    print(optimizer)

    # final options
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # start training
    start_iter = 0
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter
    epoch = 0

    train_iter_loader = iter(train_loader)
    while True:
        # train part: restart the loader iterator whenever an epoch ends
        try:
            # builtin next(); DataLoader iterators have no .next() method on
            # current PyTorch versions
            image_tensors, labels = next(train_iter_loader)
            print("{:4}".format(iteration), end='\r')
        except StopIteration:
            epoch += 1
            print(f"epoch:{epoch}")
            train_iter_loader = iter(train_loader)
            continue

        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * preds.size(0))
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                try:
                    cost = criterion(preds, text, preds_size, length)
                except Exception:
                    print(preds.shape, preds_size.shape)
                    # re-raise the original error (raising a str is a TypeError)
                    raise
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:  # To see training progress, we also conduct validation when 'iteration == 0'
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                print("validation")
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy >= best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED >= best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+4 iter.
        if (iteration + 1) % 1e+4 == 0:
            torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
def train(opt):
    """Train a model with a CTC-scored ``refiner`` output plus attention predictions.

    The model returns ``(preds, refiner)``. ``refiner`` is scored with CTC
    loss, weighted by ``opt.lambda_ctc``. Each element of ``preds`` is
    scored with cross-entropy, weighted by ``opt.lambda_attn``.

    Training resumes from the iteration encoded in ``opt.saved_model`` when
    that file name ends in ``_<iter>.<ext>``. Losses go to TensorBoard under
    ``opt.log``. Validation runs at iteration 0 and every
    ``opt.valInterval`` iterations. A snapshot is saved every 1e+3
    iterations, and the process exits with ``sys.exit()`` after
    ``opt.num_iter`` iterations.
    """
    os.makedirs(opt.log, exist_ok=True)
    writer = SummaryWriter(opt.log)

    # dataset preparation
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # `with` guarantees the dataset log is closed even if loading fails
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # model configuration
    ctc_converter = CTCLabelConverter(opt.character)
    attn_converter = AttnLabelConverter(opt.character)
    opt.num_class = len(attn_converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # kaiming_normal_ rejects tensors with fewer than 2 dims
            # (e.g. BatchNorm weights); those are set to 1 instead.
            if 'weight' in name:
                param.data.fill_(1)

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))

    # setup loss
    loss_avg = Averager()
    ctc_loss = torch.nn.CTCLoss(zero_infinity=True).to(device)
    attn_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    # only parameters that require gradient descent are optimized
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # final options
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # start training
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            # e.g. "best_accuracy.pth": no iteration number to resume from
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    # Resume at start_iter (previously range(num_iter) discarded it); the bar
    # still reports progress out of num_iter.
    pbar = tqdm(range(start_iter, opt.num_iter), initial=start_iter, total=opt.num_iter)
    for iteration in pbar:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        ctc_text, ctc_length = ctc_converter.encode(labels, batch_max_length=opt.batch_max_length)
        attn_text, attn_length = attn_converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        preds, refiner = model(image, attn_text[:, :-1])
        refiner_size = torch.IntTensor([refiner.size(1)] * batch_size)
        # CTCLoss takes log-probs in (T, batch, classes) order
        refiner = refiner.log_softmax(2).permute(1, 0, 2)
        refiner_loss = ctc_loss(refiner, ctc_text, refiner_size, ctc_length)
        total_loss = opt.lambda_ctc * refiner_loss
        target = attn_text[:, 1:]  # without [GO] Symbol
        for pred in preds:
            total_loss += opt.lambda_attn * attn_loss(pred.view(-1, pred.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()
        loss_avg.add(total_loss)
        # tighten the clipping threshold as the running loss drops
        if loss_avg.val() <= 0.6:
            opt.grad_clip = 2
        if loss_avg.val() <= 0.3:
            opt.grad_clip = 1

        # Drop this step's tensors so empty_cache can actually release them
        # (the previous generator expression was never evaluated).
        del preds, refiner, image
        torch.cuda.empty_cache()

        writer.add_scalar('train_loss', loss_avg.val(), iteration)
        pbar.set_description('Iteration {0}/{1}, AvgLoss {2}'.format(iteration, opt.num_iter, loss_avg.val()))

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, attn_loss, valid_loader, attn_converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                # without a step every point would be logged at the same x
                writer.add_scalar('Val_loss', valid_loss, iteration)
                pbar.set_description(loss_log)
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(),
                               f'./saved_models/{opt.exp_name}/best_accuracy_{str(best_accuracy)}.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(),
                               f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    # `'Attn' or 'Transformer' in ...` was always true; test
                    # each substring explicitly.
                    if 'Attn' in opt.Prediction or 'Transformer' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                log.write(predicted_result_log + '\n')

        # save model per 1e+3 iter.
        if (iteration + 1) % 1e+3 == 0:
            torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/SCATTER_STR.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            # flush pending TensorBoard events before the process exits
            writer.close()
            sys.exit()

    # reached only when there was nothing left to train (start_iter >= num_iter)
    writer.close()
def train(opt):
    """Prepare the training and validation datasets, then train the recognizer.

    Both sets are ``LmdbDataset``s transformed with ``ToTensor()``. The
    model (CTC or attention head, chosen by ``opt.Prediction``) is trained by
    cycling over the training loader, and may resume from
    ``opt.continue_model``. Every ``opt.valInterval`` iterations the step
    validates, writes samples to ``log_train.txt`` and saves
    ``best_accuracy.pth`` / ``best_norm_ED.pth``. The lowest norm_ED is kept
    as best here. A snapshot is saved every 1e+5 iterations, and the process
    exits with ``sys.exit()`` when ``i == opt.num_iter``.
    """
    transform = transforms.Compose([
        ToTensor(),
    ])
    train_dataset = LmdbDataset(opt.train_data, opt=opt, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
    )
    valid_dataset = LmdbDataset(root=opt.valid_data, opt=opt, transform=transform)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
    )
    print('-' * 80)

    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # kaiming_normal_ rejects tensors with fewer than 2 dims
            # (e.g. BatchNorm weights); those are set to 1 instead.
            if 'weight' in name:
                param.data.fill_(1)

    model = model.to(device)
    model.train()
    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        model.load_state_dict(torch.load(opt.continue_model))
    print("Model:")
    print(model)

    # setup loss
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # only parameters that require gradient descent are optimized
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # final options
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # start training
    start_iter = 0
    if opt.continue_model != '':
        try:
            start_iter = int(opt.continue_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            # names such as best_accuracy.pth carry no iteration number;
            # keep the loaded weights and count from 0
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    i = start_iter

    while True:
        # train part: one pass over the loader per outer iteration
        for image_tensors, labels in train_loader:
            image = image_tensors.to(device)
            # text: concatenated/padded label indices; length: label length per sample
            text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
            batch_size = image.size(0)

            if 'CTC' in opt.Prediction:
                # (batch, T, num_class) log-probabilities, e.g. (100, 63, 12)
                preds = model(image, text).log_softmax(2)
                preds_size = torch.IntTensor([preds.size(1)] * batch_size).to(device)
                # CTCLoss expects (T, batch, num_class), e.g. (100, 63, 7) -> (63, 100, 7)
                preds = preds.permute(1, 0, 2)

                # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
                # https://github.com/jpuigcerver/PyLaia/issues/16
                torch.backends.cudnn.enabled = False
                try:
                    # Example shapes from the author's run: preds (63, 100, 7) =
                    # (sequence length, batch size, classes); text (1000,) = all
                    # 1000 target characters; preds_size = [63] * 100 sequence
                    # lengths; length = [10] * 100 label lengths (10 chars/image).
                    cost = criterion(preds, text, preds_size, length)
                finally:
                    # restore cudnn even if the loss raises
                    torch.backends.cudnn.enabled = True
            else:
                preds = model(image, text[:, :-1])  # align with Attention.forward
                target = text[:, 1:]  # without [GO] Symbol
                cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

            model.zero_grad()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
            optimizer.step()

            loss_avg.add(cost)

            # validation part
            if i % opt.valInterval == 0:
                elapsed_time = time.time() - start_time
                print(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}')
                # for log
                with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                    log.write(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n')
                    loss_avg.reset()

                    model.eval()
                    with torch.no_grad():
                        valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                            model, criterion, valid_loader, converter, opt)
                    model.train()

                    for pred, gt in zip(preds[:5], labels[:5]):
                        if 'Attn' in opt.Prediction:
                            pred = pred[:pred.find('[s]')]
                            gt = gt[:gt.find('[s]')]
                        print(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}')
                        log.write(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}\n')

                    valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                    valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                    print(valid_log)
                    log.write(valid_log + '\n')

                    # keep best accuracy model
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(model.state_dict(),
                                   f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                    if current_norm_ED < best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(model.state_dict(),
                                   f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                    best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                    print(best_model_log)
                    log.write(best_model_log + '\n')

            # save model per 1e+5 iter.
            if (i + 1) % 1e+5 == 0:
                torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

            if i == opt.num_iter:
                print('end the training')
                sys.exit()
            i += 1
def train(opt):
    """Train a text-recognition model, optionally resuming from ``opt.continue_model``.

    Builds the balanced training dataset and a shuffled validation loader,
    creates a CTC- or attention-based ``Model`` wrapped in ``DataParallel``,
    and then trains in an endless loop.

    Every ``opt.valInterval`` iterations it does four things:
      * runs validation;
      * appends progress to ``./saved_models/{opt.experiment_name}/log_train.txt``;
      * saves ``best_accuracy.pth`` when accuracy improves;
      * saves ``best_norm_ED.pth`` when norm_ED improves (a lower norm_ED is
        treated as better here).

    A snapshot is also saved every 1e5 iterations. The process is terminated
    with ``sys.exit()`` once the iteration counter reaches ``opt.num_iter``.

    Args:
        opt: option namespace. It is mutated in place:
            * ``select_data`` and ``batch_ratio`` are split on ``'-'``;
            * ``num_class`` is set from the label converter;
            * ``input_channel`` is forced to 3 when ``opt.rgb`` is set.
    """
    # --- dataset preparation ---
    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    print('-' * 80)

    # --- model configuration ---
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm (1-D weights cannot be kaiming-initialized).
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        model.load_state_dict(torch.load(opt.continue_model))
    print("Model:")

    # --- setup loss ---
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # --- final options ---
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a', encoding="utf-8") as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # --- start training ---
    start_iter = 0
    if opt.continue_model != '':
        # Only iter_{N}.pth snapshots encode the iteration in their name. Names such as
        # best_accuracy.pth (also written by this function) do not, so resume from 0
        # instead of crashing with ValueError.
        try:
            start_iter = int(opt.continue_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6  # lower is better for this norm_ED, so start high
    i = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size).to(device)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format

            # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            # try/finally guarantees cudnn is re-enabled even if the loss raises.
            torch.backends.cudnn.enabled = False
            try:
                cost = criterion(preds, text, preds_size, length)
            finally:
                torch.backends.cudnn.enabled = True
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            print(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}')
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a', encoding="utf-8") as log:
                log.write(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n')
                loss_avg.reset()

                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                for pred, gt in zip(preds[:5], labels[:5]):
                    if 'Attn' in opt.Prediction:
                        # attention outputs are terminated by the '[s]' end token
                        pred = pred[:pred.find('[s]')]
                        gt = gt[:gt.find('[s]')]
                    print(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}')
                    log.write(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}\n')

                valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                print(valid_log)
                log.write(valid_log + '\n')

                # keep best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                print(best_model_log)
                log.write(best_model_log + '\n')

        # save model per 1e+5 iter.
        if (i + 1) % 1e+5 == 0:
            torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
def train(opt, show_number=2, amp=False):
    """Train a text-recognition model (trainer variant with AMP support and layer freezing).

    Args:
        opt: option namespace. It is mutated in place:
            * ``select_data`` and ``batch_ratio`` are split on ``'-'``;
            * ``num_class`` is set from the label converter;
            * ``input_channel`` is forced to 3 when ``opt.rgb`` is set.
        show_number: how many validation samples to print after each validation
            round. They form a random contiguous slice of what ``validation``
            returns. The count is clamped to the number of returned samples.
        amp: if true, the forward pass and loss run under ``autocast``, and the
            backward pass uses ``GradScaler``.

    Validation runs every ``opt.valInterval`` iterations (never at ``i == 0``).
    Checkpoints are written under ``./saved_models/{opt.experiment_name}/``:
      * the best accuracy model;
      * the best norm_ED model (a higher norm_ED is treated as better);
      * a snapshot every 1e4 iterations.
    The process exits via ``sys.exit()`` once ``i`` reaches ``opt.num_iter``.
    """
    # --- dataset preparation ---
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    with open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a', encoding="utf8") as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD,
                                          contrast_adjust=opt.contrast_adjust)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        num_workers = int(opt.workers)
        # DataLoader rejects prefetch_factor when num_workers == 0 (no worker processes),
        # so only pass it when multiprocessing loading is actually enabled.
        loader_kwargs = {'prefetch_factor': 512} if num_workers > 0 else {}
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=min(32, opt.batch_size),
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=num_workers,
            collate_fn=AlignCollate_valid, pin_memory=True, **loader_kwargs)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # --- model configuration ---
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    if opt.saved_model != '':
        pretrained_dict = torch.load(opt.saved_model)
        if opt.new_prediction:
            # Temporarily size the head like the checkpoint's head so the weights load cleanly.
            model.Prediction = nn.Linear(model.SequenceModeling_output,
                                         len(pretrained_dict['module.Prediction.weight']))

        model = torch.nn.DataParallel(model).to(device)
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(pretrained_dict, strict=False)
        else:
            model.load_state_dict(pretrained_dict)
        if opt.new_prediction:
            # Replace the loaded head with a freshly initialized one sized for this character set.
            model.module.Prediction = nn.Linear(model.module.SequenceModeling_output, opt.num_class)
            for name, param in model.module.Prediction.named_parameters():
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            model = model.to(device)
    else:
        # weight initialization
        for name, param in model.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception:  # for batchnorm (1-D weights cannot be kaiming-initialized).
                if 'weight' in name:
                    param.data.fill_(1)
                continue
        model = torch.nn.DataParallel(model).to(device)

    model.train()
    print("Model:")
    print(model)
    count_parameters(model)

    # --- setup loss ---
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # Freeze some layers (best effort). Each flag is checked independently.
    # A flag missing from opt, or a sub-module missing from the model, is skipped.
    # 'freeze_FeatureFxtraction' keeps the option name exactly as spelled in the configs.
    for flag, module_name in (('freeze_FeatureFxtraction', 'FeatureExtraction'),
                              ('freeze_SequenceModeling', 'SequenceModeling')):
        submodule = getattr(model.module, module_name, None)
        if getattr(opt, flag, False) and submodule is not None:
            for param in submodule.parameters():
                param.requires_grad = False

    # keep only the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.optim == 'adam':
        # NOTE: Adam is built with its default hyper-parameters; opt.lr / opt.beta1 are not applied here.
        optimizer = optim.Adam(filtered_parameters)
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # --- final options ---
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a', encoding="utf8") as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # --- start training ---
    start_iter = 0
    if opt.saved_model != '':
        # Only iter_{N}.pth style names encode the iteration; anything else resumes from 0.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter
    scaler = GradScaler()
    t1 = time.time()

    def _forward_loss():
        """Fetch one training batch, run the forward pass and return the loss tensor."""
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)
        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format
            # cudnn is disabled around CTC loss; try/finally guarantees it is re-enabled.
            torch.backends.cudnn.enabled = False
            try:
                return criterion(preds, text.to(device), preds_size.to(device), length.to(device))
            finally:
                torch.backends.cudnn.enabled = True
        preds = model(image, text[:, :-1])  # align with Attention.forward
        target = text[:, 1:]  # without [GO] Symbol
        return criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

    while True:
        # train part
        optimizer.zero_grad(set_to_none=True)

        if amp:
            # Only the forward pass + loss run under autocast; backward and the
            # optimizer step run outside it, as recommended by the PyTorch AMP docs.
            with autocast():
                cost = _forward_loss()
            scaler.scale(cost).backward()
            scaler.unscale_(optimizer)  # unscale so clipping sees the true gradient norm
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            cost = _forward_loss()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()
        loss_avg.add(cost)

        # validation part
        if (i % opt.valInterval == 0) and (i != 0):
            print('training time: ', time.time() - t1)
            t1 = time.time()
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a', encoding="utf8") as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels,
                     infer_time, length_of_data) = validation(model, criterion, valid_loader, converter, opt, device)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.4f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.4f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'

                # validation() can return fewer samples than show_number. Clamp the count
                # so randint's upper bound never goes negative (ValueError otherwise).
                n_show = min(show_number, len(labels))
                start = random.randint(0, len(labels) - n_show)
                stop = start + n_show
                for gt, pred, confidence in zip(labels[start:stop], preds[start:stop], confidence_score[start:stop]):
                    if 'Attn' in opt.Prediction:
                        # attention outputs are terminated by the '[s]' end token
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')
                print('validation time: ', time.time() - t1)
                t1 = time.time()

        # save model per 1e+4 iter.
        if (i + 1) % 1e+4 == 0:
            torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
def train(opt, tb):
    """Train a text-recognition model and record metrics to a TensorBoard writer.

    Args:
        opt: option namespace. It is mutated in place:
            * ``select_data`` and ``batch_ratio`` are split on ``'-'``;
            * ``num_class`` is set from the label converter;
            * ``input_channel`` is forced to 3 when ``opt.rgb`` is set.
        tb: writer exposing ``add_scalar(tag, value, step)``. It receives the
            train/validation loss, current accuracy and current norm_ED at each
            validation round.

    Validation runs every ``opt.valInterval`` iterations, including the first.
    Checkpoints are written under ``./saved_models/{opt.experiment_name}/``:
      * the best accuracy model;
      * the best norm_ED model (a higher norm_ED is treated as better);
      * a snapshot every 1e5 iterations.
    The process exits via ``sys.exit()`` once ``i`` reaches ``opt.num_iter``.
    """
    # --- dataset preparation ---
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    with open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # --- model configuration ---
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm (1-D weights cannot be kaiming-initialized).
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # --- setup loss ---
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    print("~~~~~~~~~~~~Gradient Descent~~~~~~~~~~~~~")
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Filtered parameters for gradient descent: \n', len(filtered_parameters))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # --- final options ---
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # --- start training ---
    start_iter = 0
    if opt.saved_model != '':
        # Only iter_{N}.pth style names encode the iteration; anything else resumes from 0.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format

            # (ctc_a) For PyTorch 1.2.0 and 1.3.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            # try/finally guarantees cudnn is re-enabled even if the loss raises.
            torch.backends.cudnn.enabled = False
            try:
                cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
            finally:
                torch.backends.cudnn.enabled = True

            # (ctc_b) To reproduce the upstream pretrained model / paper, the previous code
            # `cost = criterion(preds, text, preds_size, length)` was used instead of (ctc_a).
            # With PyTorch 1.2.0 that code produces NaN, so PyTorch 1.1.0 would be needed; the
            # CTCLoss result therefore differs between PyTorch 1.1.0 and 1.2.0.
            # See https://github.com/clovaai/deep-text-recognition-benchmark/issues/56#issuecomment-526490707
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels,
                     infer_time, length_of_data) = validation(model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                tb.add_scalar('Training Loss vs Iteration', loss_avg.val(), i)  # Record to Tensorboard
                tb.add_scalar('Validation Loss vs Iteration', valid_loss, i)  # Record to Tensorboard
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'
                tb.add_scalar('Current Accuracy vs Iteration', current_accuracy, i)  # Record to Tensorboard
                tb.add_scalar('Current Norm ED vs Iteration', current_norm_ED, i)  # Record to Tensorboard

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        # attention outputs are terminated by the '[s]' end token
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (i + 1) % 1e+5 == 0:
            torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
def train(opt):
    """Training pipeline for our character recognition model.

    Builds the balanced training dataset and a validation loader, creates a
    CTC- or attention-based ``Model`` on ``device`` (no DataParallel wrapper),
    optionally loads ``opt.saved_model``, then trains in an endless loop.

    Validation runs at the first iteration and then every ``opt.valInterval``
    iterations. Checkpoints are written under ``./saved_models/{opt.exp_name}/``:
      * the best accuracy model;
      * the best norm_ED model (a higher norm_ED is treated as better);
      * a snapshot every 1e5 iterations.
    The process exits via ``sys.exit()`` after iteration ``opt.num_iter``.

    Args:
        opt: option namespace. It is mutated in place:
            * ``select_data`` and ``batch_ratio`` are split on ``'-'``;
            * ``num_class`` is set from the label converter;
            * ``input_channel`` is forced to 3 when ``opt.rgb`` is set.
    """
    if not opt.data_filtering_off:
        print("Filtering the images containing characters which are not in opt.character")
        print("Filtering the images whose label is longer than opt.batch_max_length")

    opt.select_data = opt.select_data.split("-")
    opt.batch_ratio = opt.batch_ratio.split("-")
    train_dataset = Batch_Balanced_Dataset(opt)

    # Log the dataset summary so that the performance of previous runs can be referred to later.
    with open(f"./saved_models/{opt.exp_name}/log_dataset.txt", "a") as log:
        # Collation function for the validation dataloader, parameterized by user options
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)

        # Validation dataloader
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True,
        )
        log.write(valid_dataset_log)
        print("-" * 80)
        log.write("-" * 80 + "\n")

    # Use either CTC or Attention for character predictions
    if "CTC" in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    # Run the OCR model in grayscale or RGB
    if opt.rgb:
        opt.input_channel = 3

    # Define the model from user inputs
    model = Model(opt)
    print(
        "model input parameters",
        opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
        opt.hidden_size, opt.num_class, opt.batch_max_length,
        opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction,
    )

    # weight initialization
    for name, param in model.named_parameters():
        if "localization_fc2" in name:
            print(f"Skip {name} as it is already initialized")
            continue
        try:
            if "bias" in name:
                init.constant_(param, 0.0)
            elif "weight" in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm (1-D weights cannot be kaiming-initialized).
            if "weight" in name:
                param.data.fill_(1)
            continue

    # Put the model in training mode
    model.train()

    # Fine-tune / resume from a model saved by a previous run.
    # map_location lets a GPU-saved checkpoint load on a CPU-only host (this variant runs on either).
    if opt.saved_model != "":
        print(f"loading pretrained model from {opt.saved_model}")
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model, map_location=device), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model, map_location=device))
    print("Model:")

    # Send the model to CPU or GPU, depending on availability
    model.to(device)

    # Set up the loss function for either CTC or Attention
    if "CTC" in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print("Trainable params num : ", sum(params_num))

    # Set up the optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    with open(f"./saved_models/{opt.exp_name}/opt.txt", "a") as opt_file:
        opt_log = "------------ Options -------------\n"
        args = vars(opt)
        for k, v in args.items():
            opt_log += f"{str(k)}: {str(v)}\n"
        opt_log += "---------------------------------------\n"
        print(opt_log)
        opt_file.write(opt_log)

    # Training iteration starts here
    start_iter = 0
    if opt.saved_model != "":
        # Only iter_{N}.pth style names encode the iteration; anything else resumes from 0.
        try:
            start_iter = int(opt.saved_model.split("_")[-1].split(".")[0])
            print(f"continue to train, start_iter: {start_iter}")
        except ValueError:
            pass

    # Initial metric values and the timer
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if "CTC" in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.log_softmax(2).permute(1, 0, 2)  # to use CTCLoss format
            cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        # To see training progress, we also conduct validation when 'iteration == 0'
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f"./saved_models/{opt.exp_name}/log_train.txt", "a") as log:
                model.eval()
                with torch.no_grad():
                    (
                        valid_loss,
                        current_accuracy,
                        current_norm_ED,
                        preds,
                        confidence_score,
                        labels,
                        infer_time,
                        length_of_data,
                    ) = validation(model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f"[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}"
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f"./saved_models/{opt.exp_name}/best_accuracy.pth",
                    )
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f"./saved_models/{opt.exp_name}/best_norm_ED.pth",
                    )
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f"{loss_log}\n{current_model_log}\n{best_model_log}"
                print(loss_model_log)
                log.write(loss_model_log + "\n")

                # show some predicted results
                dashed_line = "-" * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f"{dashed_line}\n{head}\n{dashed_line}\n"
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if "Attn" in opt.Prediction:
                        # attention outputs are terminated by the '[s]' end token
                        gt = gt[:gt.find("[s]")]
                        pred = pred[:pred.find("[s]")]

                    predicted_result_log += f"{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n"
                predicted_result_log += f"{dashed_line}"
                print(predicted_result_log)
                log.write(predicted_result_log + "\n")

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e5 == 0:
            torch.save(
                model.state_dict(),
                f"./saved_models/{opt.exp_name}/iter_{iteration+1}.pth",
            )

        if (iteration + 1) == opt.num_iter:
            print("end the training")
            sys.exit()
        iteration += 1