def train(opt, tb):
    """Train the recognition model, validating and checkpointing as it goes.

    Builds the training/validation data pipelines, the model, the loss and the
    optimizer from ``opt``, then loops over training batches.  Whenever the
    iteration counter is a multiple of ``opt.valInterval`` it runs
    ``validation``, appends the results to
    ``./saved_models/<opt.experiment_name>/log_train.txt``, records four
    scalars on ``tb`` and saves ``best_accuracy.pth`` / ``best_norm_ED.pth``
    when the corresponding validation metric improves.

    Args:
        opt: Option namespace.  It is mutated: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``, ``num_class`` is set from
            the label converter and ``input_channel`` is forced to 3 when
            ``opt.rgb`` is truthy.
        tb: Scalar logger; only ``tb.add_scalar(tag, value, step)`` is called.

    The function does not return normally: it calls ``sys.exit()`` once the
    iteration counter equals ``opt.num_iter``.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager guarantees the dataset log is closed even when
    # building the validation set raises.
    with open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            # fine-tuning: tolerate missing/unexpected keys
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients
    print("~~~~~~~~~~~~Gradient Descent~~~~~~~~~~~~~")
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Filtered parameters for gradient descent: \n', len(filtered_parameters))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        opt_log += ''.join(f'{k!s}: {v!s}\n' for k, v in vars(opt).items())
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        try:
            # checkpoints written below are named ``iter_<N>.pth``
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            # file name carries no iteration number: start counting from 0
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)

            # (ctc_a) For PyTorch 1.2.0 and 1.3.0. To avoid ctc_loss issue, disabled cudnn for the
            # computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            # The previous flag value is restored even if the loss raises.
            cudnn_was_enabled = torch.backends.cudnn.enabled
            torch.backends.cudnn.enabled = False
            try:
                cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
            finally:
                torch.backends.cudnn.enabled = cudnn_was_enabled
            # (ctc_b) To reproduce the pretrained model / paper, the previous call
            # ``criterion(preds, text, preds_size, length)`` with PyTorch 1.1.0 is needed instead.
            # See https://github.com/clovaai/deep-text-recognition-benchmark/issues/56#issuecomment-526490707
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds, confidence_score,
                     labels, infer_time, length_of_data) = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                tb.add_scalar('Training Loss vs Iteration', loss_avg.val(), i)  # Record to Tensorboard
                tb.add_scalar('Validation Loss vs Iteration', valid_loss, i)  # Record to Tensorboard
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'
                tb.add_scalar('Current Accuracy vs Iteration', current_accuracy, i)  # Record to Tensorboard
                tb.add_scalar('Current Norm ED vs Iteration', current_norm_ED, i)  # Record to Tensorboard

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        # Cut at the end-of-sentence token.  ``partition`` leaves the
                        # string intact when '[s]' is absent, whereas slicing with
                        # ``find`` (-1 on a miss) would drop the last character.
                        gt = gt.partition('[s]')[0]
                        pred = pred.partition('[s]')[0]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (i + 1) % 100000 == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
def train(opt):
    """Train the recognition model, validating and checkpointing as it goes.

    Builds the training/validation data pipelines, the model, the loss and the
    optimizer from ``opt``, then loops over training batches.  Whenever the
    iteration counter is a multiple of ``opt.valInterval`` it logs the running
    training loss, runs ``validation``, appends the results to
    ``./saved_models/<opt.experiment_name>/log_train.txt`` and saves
    ``best_accuracy.pth`` (higher accuracy) / ``best_norm_ED.pth`` (lower
    norm_ED) when the corresponding validation metric improves.

    Args:
        opt: Option namespace.  It is mutated: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``, ``num_class`` is set from
            the label converter and ``input_channel`` is forced to 3 when
            ``opt.rgb`` is truthy.  ``opt.continue_model`` (may be ``''``) is
            the checkpoint to resume from.

    The function does not return normally: it calls ``sys.exit()`` once the
    iteration counter equals ``opt.num_iter``.
    """
    # ---- dataset preparation ----
    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    print('-' * 80)

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        model.load_state_dict(torch.load(opt.continue_model))
    print("Model:")

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a', encoding="utf-8") as opt_file:
        opt_log = '------------ Options -------------\n'
        opt_log += ''.join(f'{k!s}: {v!s}\n' for k, v in vars(opt).items())
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.continue_model != '':
        try:
            # checkpoints written below are named ``iter_<N>.pth``
            start_iter = int(opt.continue_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            # e.g. ``best_accuracy.pth`` carries no iteration number:
            # resume the weights but start counting from 0 instead of crashing
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6  # lower is better in this variant
    i = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size).to(device)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format

            # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            # The previous flag value is restored even if the loss raises.
            cudnn_was_enabled = torch.backends.cudnn.enabled
            torch.backends.cudnn.enabled = False
            try:
                cost = criterion(preds, text, preds_size, length)
            finally:
                torch.backends.cudnn.enabled = cudnn_was_enabled
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            train_log = f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
            print(train_log)
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a', encoding="utf-8") as log:
                log.write(train_log + '\n')
                loss_avg.reset()

                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds, labels,
                     infer_time, length_of_data) = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                for pred, gt in zip(preds[:5], labels[:5]):
                    if 'Attn' in opt.Prediction:
                        # Cut at the end-of-sentence token.  ``partition`` leaves the
                        # string intact when '[s]' is absent, whereas slicing with
                        # ``find`` (-1 on a miss) would drop the last character.
                        pred = pred.partition('[s]')[0]
                        gt = gt.partition('[s]')[0]

                    sample_log = f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}'
                    print(sample_log)
                    log.write(sample_log + '\n')

                valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                print(valid_log)
                log.write(valid_log + '\n')

                # keep best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                print(best_model_log)
                log.write(best_model_log + '\n')

        # save model per 1e+5 iter.
        if (i + 1) % 100000 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
def train(opt):
    """Train the recognition model, validating and checkpointing as it goes.

    Builds the training/validation data pipelines, the model, the loss and the
    optimizer from ``opt``, then loops over training batches.  On the first
    iteration and every ``opt.valInterval`` iterations afterwards it runs
    ``validation``, appends the results to
    ``./saved_models/<opt.exp_name>/log_train.txt`` and saves
    ``best_accuracy.pth`` / ``best_norm_ED.pth`` when the corresponding
    validation metric improves.

    Args:
        opt: Option namespace.  It is mutated: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``, ``num_class`` is set from
            the label converter and ``input_channel`` is forced to 3 when
            ``opt.rgb`` is truthy.  ``opt.baiduCTC`` selects the warp-ctc
            converter and loss when ``'CTC'`` is in ``opt.Prediction``.

    The function does not return normally: it calls ``sys.exit()`` after
    iteration ``opt.num_iter - 1`` has completed.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager guarantees the dataset log is closed even when
    # building the validation set raises.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            # fine-tuning: tolerate missing/unexpected keys
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        opt_log += ''.join(f'{k!s}: {v!s}\n' for k, v in vars(opt).items())
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        try:
            # checkpoints written below are named ``iter_<N>.pth``
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            # file name carries no iteration number: start counting from 0
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        # To see training progress, we also conduct validation when 'iteration == 0'
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds, confidence_score,
                     labels, infer_time, length_of_data) = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        # Cut at the end-of-sentence token.  ``partition`` leaves the
                        # string intact when '[s]' is absent, whereas slicing with
                        # ``find`` (-1 on a miss) would drop the last character.
                        gt = gt.partition('[s]')[0]
                        pred = pred.partition('[s]')[0]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 100000 == 0:
            torch.save(
                model.state_dict(),
                f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
def train(opt):
    """Train the joint CTC + attention model, validating and checkpointing.

    Builds the training/validation data pipelines, the model, both losses and
    the optimizer from ``opt``, then iterates over training batches.  The
    model returns ``(preds, refiner)``; the loss is
    ``opt.lambda_ctc * CTC(refiner)`` plus ``opt.lambda_attn * CE(pred)``
    summed over every element of ``preds``.  On the first iteration and every
    ``opt.valInterval`` iterations afterwards it runs ``validation`` with the
    attention loss/converter, appends the results to
    ``./saved_models/<opt.exp_name>/log_train.txt`` and saves the best
    checkpoints.  Training and validation losses are written to a
    ``SummaryWriter`` rooted at ``opt.log``.

    Args:
        opt: Option namespace.  It is mutated: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``, ``num_class`` is set from
            the attention converter, ``input_channel`` is forced to 3 when
            ``opt.rgb`` is truthy, and ``grad_clip`` is lowered to 2 / 1 once
            the running loss drops to 0.6 / 0.3.

    The function calls ``sys.exit()`` after iteration ``opt.num_iter - 1``;
    it only returns normally when there is no iteration left to run.
    """
    os.makedirs(opt.log, exist_ok=True)
    writer = SummaryWriter(opt.log)

    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager guarantees the dataset log is closed even when
    # building the validation set raises.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    ctc_converter = CTCLabelConverter(opt.character)
    attn_converter = AttnLabelConverter(opt.character)
    opt.num_class = len(attn_converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            # fine-tuning: tolerate missing/unexpected keys
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))

    # ---- setup loss ----
    loss_avg = Averager()
    ctc_loss = torch.nn.CTCLoss(zero_infinity=True).to(device)
    attn_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # index 0 is ignored

    # keep only the parameters that require gradients
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")

    # ---- final options ----
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        opt_log += ''.join(f'{k!s}: {v!s}\n' for k, v in vars(opt).items())
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        # Only a file name of the form ``<anything>_<N>.<ext>`` carries an
        # iteration number.  Parsing the extension-less basename keeps names
        # such as ``best_accuracy_85.3.pth`` from being mistaken for
        # iteration 85.
        stem = os.path.splitext(os.path.basename(opt.saved_model))[0]
        try:
            start_iter = int(stem.split('_')[-1])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1

    # Start the loop at ``start_iter`` so that a resumed run really continues
    # counting; ``initial``/``total`` keep the progress bar consistent.
    pbar = tqdm(range(start_iter, opt.num_iter), initial=start_iter, total=opt.num_iter)
    for iteration in pbar:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        ctc_text, ctc_length = ctc_converter.encode(labels, batch_max_length=opt.batch_max_length)
        attn_text, _ = attn_converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        preds, refiner = model(image, attn_text[:, :-1])
        refiner_size = torch.IntTensor([refiner.size(1)] * batch_size)
        refiner = refiner.log_softmax(2).permute(1, 0, 2)
        refiner_loss = ctc_loss(refiner, ctc_text, refiner_size, ctc_length)
        total_loss = opt.lambda_ctc * refiner_loss

        target = attn_text[:, 1:].contiguous().view(-1)  # without [GO] Symbol
        for pred in preds:
            total_loss += opt.lambda_attn * attn_loss(pred.view(-1, pred.shape[-1]), target)

        model.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(total_loss)
        # tighten gradient clipping as the running loss gets small
        if loss_avg.val() <= 0.6:
            opt.grad_clip = 2
        if loss_avg.val() <= 0.3:
            opt.grad_clip = 1

        # Drop the references to this step's tensors before emptying the CUDA
        # cache; none of them is read again before being rebound.
        del preds, refiner, image
        torch.cuda.empty_cache()

        writer.add_scalar('train_loss', loss_avg.val(), iteration)
        pbar.set_description('Iteration {0}/{1}, AvgLoss {2}'.format(
            iteration, opt.num_iter, loss_avg.val()))

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds, confidence_score,
                     labels, infer_time, length_of_data) = validation(
                        model, attn_loss, valid_loader, attn_converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                # pass the step explicitly so validation points are placed on the iteration axis
                writer.add_scalar('Val_loss', valid_loss, iteration)
                pbar.set_description(loss_log)
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy_{str(best_accuracy)}.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    # validation is always called with attn_converter, so the
                    # strings are cut at the end-of-sentence token
                    # unconditionally.  ``partition`` leaves a string intact
                    # when '[s]' is absent, whereas slicing with ``find``
                    # (-1 on a miss) would drop the last character.
                    gt = gt.partition('[s]')[0]
                    pred = pred.partition('[s]')[0]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                log.write(predicted_result_log + '\n')

        # save model per 1e+3 iter.
        if (iteration + 1) % 1000 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.exp_name}/SCATTER_STR.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            writer.close()  # flush pending TensorBoard events before exiting
            sys.exit()

    # Reached only when there was nothing left to run (start_iter >= num_iter).
    writer.close()
def train(opt): plotDir = os.path.join(opt.exp_dir,opt.exp_name,'plots') if not os.path.exists(plotDir): os.makedirs(plotDir) lib.print_model_settings(locals().copy()) """ dataset preparation """ if not opt.data_filtering_off: print('Filtering the images containing characters which are not in opt.character') print('Filtering the images whose label is longer than opt.batch_max_length') # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 opt.select_data = opt.select_data.split('-') opt.batch_ratio = opt.batch_ratio.split('-') #considering the real images for discriminator opt.batch_size = opt.batch_size*2 train_dataset = Batch_Balanced_Dataset(opt) log = open(os.path.join(opt.exp_dir,opt.exp_name,'log_dataset.txt'), 'a') AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=opt.batch_size, shuffle=False, # 'True' to check training progress with validation function. 
num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True) log.write(valid_dataset_log) print('-' * 80) log.write('-' * 80 + '\n') log.close() """ model configuration """ if 'CTC' in opt.Prediction: converter = CTCLabelConverter(opt.character) else: converter = AttnLabelConverter(opt.character) opt.num_class = len(converter.character) if opt.rgb: opt.input_channel = 3 model = AdaINGen(opt) ocrModel = Model(opt) disModel = MsImageDisV1(opt) print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction) # weight initialization for currModel in [model, ocrModel, disModel]: for name, param in currModel.named_parameters(): if 'localization_fc2' in name: print(f'Skip {name} as it is already initialized') continue try: if 'bias' in name: init.constant_(param, 0.0) elif 'weight' in name: init.kaiming_normal_(param) except Exception as e: # for batchnorm. 
if 'weight' in name: param.data.fill_(1) continue # data parallel for multi-GPU ocrModel = torch.nn.DataParallel(ocrModel).to(device) if not opt.ocrFixed: ocrModel.train() else: ocrModel.module.Transformation.eval() ocrModel.module.FeatureExtraction.eval() ocrModel.module.AdaptiveAvgPool.eval() # ocrModel.module.SequenceModeling.eval() ocrModel.module.Prediction.eval() model = torch.nn.DataParallel(model).to(device) model.train() disModel = torch.nn.DataParallel(disModel).to(device) disModel.train() if opt.modelFolderFlag: if len(glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth")))>0: opt.saved_synth_model = glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth"))[-1] if len(glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_dis.pth")))>0: opt.saved_dis_model = glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_dis.pth"))[-1] #loading pre-trained model if opt.saved_ocr_model != '' and opt.saved_ocr_model != 'None': print(f'loading pretrained ocr model from {opt.saved_ocr_model}') if opt.FT: ocrModel.load_state_dict(torch.load(opt.saved_ocr_model), strict=False) else: ocrModel.load_state_dict(torch.load(opt.saved_ocr_model)) print("OCRModel:") print(ocrModel) if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': print(f'loading pretrained synth model from {opt.saved_synth_model}') if opt.FT: model.load_state_dict(torch.load(opt.saved_synth_model), strict=False) else: model.load_state_dict(torch.load(opt.saved_synth_model)) print("SynthModel:") print(model) if opt.saved_dis_model != '' and opt.saved_dis_model != 'None': print(f'loading pretrained discriminator model from {opt.saved_dis_model}') if opt.FT: disModel.load_state_dict(torch.load(opt.saved_dis_model), strict=False) else: disModel.load_state_dict(torch.load(opt.saved_dis_model)) print("DisModel:") print(disModel) """ setup loss """ if 'CTC' in opt.Prediction: ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device) else: ocrCriterion = 
torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 recCriterion = torch.nn.L1Loss() styleRecCriterion = torch.nn.L1Loss() # loss averager loss_avg_ocr = Averager() loss_avg = Averager() loss_avg_dis = Averager() loss_avg_ocrRecon_1 = Averager() loss_avg_ocrRecon_2 = Averager() loss_avg_gen = Averager() loss_avg_imgRecon = Averager() loss_avg_styRecon = Averager() ##---------------------------------------## # filter that only require gradient decent filtered_parameters = [] params_num = [] for p in filter(lambda p: p.requires_grad, model.parameters()): filtered_parameters.append(p) params_num.append(np.prod(p.size())) print('Trainable params num : ', sum(params_num)) # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())] # setup optimizer if opt.optim=='adam': optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay) else: optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay) print("SynthOptimizer:") print(optimizer) #filter parameters for OCR training ocr_filtered_parameters = [] ocr_params_num = [] for p in filter(lambda p: p.requires_grad, ocrModel.parameters()): ocr_filtered_parameters.append(p) ocr_params_num.append(np.prod(p.size())) print('OCR Trainable params num : ', sum(ocr_params_num)) # setup optimizer if opt.optim=='adam': ocr_optimizer = optim.Adam(ocr_filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay) else: ocr_optimizer = optim.Adadelta(ocr_filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay) print("OCROptimizer:") print(ocr_optimizer) #filter parameters for OCR training dis_filtered_parameters = [] dis_params_num = [] for p in filter(lambda p: p.requires_grad, disModel.parameters()): dis_filtered_parameters.append(p) dis_params_num.append(np.prod(p.size())) print('Dis 
Trainable params num : ', sum(dis_params_num)) # setup optimizer if opt.optim=='adam': dis_optimizer = optim.Adam(dis_filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay) else: dis_optimizer = optim.Adadelta(dis_filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay) print("DisOptimizer:") print(dis_optimizer) ##---------------------------------------## """ final options """ with open(os.path.join(opt.exp_dir,opt.exp_name,'opt.txt'), 'a') as opt_file: opt_log = '------------ Options -------------\n' args = vars(opt) for k, v in args.items(): opt_log += f'{str(k)}: {str(v)}\n' opt_log += '---------------------------------------\n' print(opt_log) opt_file.write(opt_log) """ start training """ start_iter = 0 if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': try: start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0]) print(f'continue to train, start_iter: {start_iter}') except: pass #get schedulers scheduler = get_scheduler(optimizer,opt) ocr_scheduler = get_scheduler(ocr_optimizer,opt) dis_scheduler = get_scheduler(dis_optimizer,opt) start_time = time.time() best_accuracy = -1 best_norm_ED = -1 best_accuracy_ocr = -1 best_norm_ED_ocr = -1 iteration = start_iter cntr=0 while(True): # train part if opt.lr_policy !="None": scheduler.step() ocr_scheduler.step() dis_scheduler.step() image_tensors_all, labels_1_all, labels_2_all = train_dataset.get_batch() # ## comment # pdb.set_trace() # for imgCntr in range(image_tensors.shape[0]): # save_image(tensor2im(image_tensors[imgCntr]),'temp/'+str(imgCntr)+'.png') # pdb.set_trace() # ### # print(cntr) cntr+=1 disCnt = int(image_tensors_all.size(0)/2) image_tensors, image_tensors_real, labels_gt, labels_2 = image_tensors_all[:disCnt], image_tensors_all[disCnt:disCnt+disCnt], labels_1_all[:disCnt], labels_2_all[:disCnt] image = image_tensors.to(device) image_real = image_tensors_real.to(device) batch_size = image.size(0) 
##-----------------------------------## #generate text(labels) from ocr.forward if opt.ocrFixed: # ocrModel.eval() length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) if 'CTC' in opt.Prediction: preds = ocrModel(image, text_for_pred) preds = preds[:, :text_for_loss.shape[1] - 1, :] preds_size = torch.IntTensor([preds.size(1)] * batch_size) _, preds_index = preds.max(2) labels_1 = converter.decode(preds_index.data, preds_size.data) else: preds = ocrModel(image, text_for_pred, is_train=False) _, preds_index = preds.max(2) labels_1 = converter.decode(preds_index, length_for_pred) for idx, pred in enumerate(labels_1): pred_EOS = pred.find('[s]') labels_1[idx] = pred[:pred_EOS] # prune after "end of sentence" token ([s]) # ocrModel.train() else: labels_1 = labels_gt ##-----------------------------------## text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length) text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length) #forward pass from style and word generator images_recon_1, images_recon_2, style = model(image, text_1, text_2) if 'CTC' in opt.Prediction: if not opt.ocrFixed: #ocr training with orig image preds_ocr = ocrModel(image, text_1) preds_size_ocr = torch.IntTensor([preds_ocr.size(1)] * batch_size) preds_ocr = preds_ocr.log_softmax(2).permute(1, 0, 2) ocrCost_train = ocrCriterion(preds_ocr, text_1, preds_size_ocr, length_1) #content loss for reconstructed images preds_1 = ocrModel(images_recon_1, text_1) preds_size_1 = torch.IntTensor([preds_1.size(1)] * batch_size) preds_1 = preds_1.log_softmax(2).permute(1, 0, 2) preds_2 = ocrModel(images_recon_2, text_2) preds_size_2 = torch.IntTensor([preds_2.size(1)] * batch_size) preds_2 = preds_2.log_softmax(2).permute(1, 0, 2) ocrCost_1 = ocrCriterion(preds_1, text_1, preds_size_1, length_1) ocrCost_2 = ocrCriterion(preds_2, text_2, preds_size_2, 
length_2) # ocrCost = 0.5*( ocrCost_1 + ocrCost_2 ) else: if not opt.ocrFixed: #ocr training with orig image preds_ocr = ocrModel(image, text_1[:, :-1]) # align with Attention.forward target_ocr = text_1[:, 1:] # without [GO] Symbol ocrCost_train = ocrCriterion(preds_ocr.view(-1, preds_ocr.shape[-1]), target_ocr.contiguous().view(-1)) #content loss for reconstructed images preds_1 = ocrModel(images_recon_1, text_1[:, :-1], is_train=False) # align with Attention.forward target_1 = text_1[:, 1:] # without [GO] Symbol preds_2 = ocrModel(images_recon_2, text_2[:, :-1], is_train=False) # align with Attention.forward target_2 = text_2[:, 1:] # without [GO] Symbol ocrCost_1 = ocrCriterion(preds_1.view(-1, preds_1.shape[-1]), target_1.contiguous().view(-1)) ocrCost_2 = ocrCriterion(preds_2.view(-1, preds_2.shape[-1]), target_2.contiguous().view(-1)) # ocrCost = 0.5*(ocrCost_1+ocrCost_2) if not opt.ocrFixed: #training OCR ocrModel.zero_grad() ocrCost_train.backward() # torch.nn.utils.clip_grad_norm_(ocrModel.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) ocr_optimizer.step() #if ocr is fixed; ignore this loss loss_avg_ocr.add(ocrCost_train) else: loss_avg_ocr.add(torch.tensor(0.0)) #Domain discriminator: Dis update disModel.zero_grad() disCost = opt.disWeight*0.5*(disModel.module.calc_dis_loss(images_recon_1.detach(), image_real) + disModel.module.calc_dis_loss(images_recon_2.detach(), image)) disCost.backward() # torch.nn.utils.clip_grad_norm_(disModel.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) dis_optimizer.step() loss_avg_dis.add(disCost) # #[Style Encoder] + [Word Generator] update #Adversarial loss disGenCost = 0.5*(disModel.module.calc_gen_loss(images_recon_1)+disModel.module.calc_gen_loss(images_recon_2)) #Input reconstruction loss recCost = recCriterion(images_recon_1,image) #Pair style reconstruction loss if opt.styleReconWeight == 0.0: styleRecCost = torch.tensor(0.0) else: if opt.styleDetach: styleRecCost = 
styleRecCriterion(model(images_recon_2, None, None, styleFlag=True), style.detach()) else: styleRecCost = styleRecCriterion(model(images_recon_2, None, None, styleFlag=True), style) #OCR Content cost ocrCost = 0.5*(ocrCost_1+ocrCost_2) cost = opt.ocrWeight*ocrCost + opt.reconWeight*recCost + opt.disWeight*disGenCost + opt.styleReconWeight*styleRecCost model.zero_grad() ocrModel.zero_grad() disModel.zero_grad() cost.backward() # torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) optimizer.step() loss_avg.add(cost) #Individual losses loss_avg_ocrRecon_1.add(opt.ocrWeight*0.5*ocrCost_1) loss_avg_ocrRecon_2.add(opt.ocrWeight*0.5*ocrCost_2) loss_avg_gen.add(opt.disWeight*disGenCost) loss_avg_imgRecon.add(opt.reconWeight*recCost) loss_avg_styRecon.add(opt.styleReconWeight*styleRecCost) # validation part if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' #Save training images os.makedirs(os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration)), exist_ok=True) for trImgCntr in range(batch_size): try: save_image(tensor2im(image[trImgCntr].detach()),os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration),str(trImgCntr)+'_input_'+labels_gt[trImgCntr]+'.png')) save_image(tensor2im(images_recon_1[trImgCntr].detach()),os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration),str(trImgCntr)+'_recon_'+labels_1[trImgCntr]+'.png')) save_image(tensor2im(images_recon_2[trImgCntr].detach()),os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration),str(trImgCntr)+'_pair_'+labels_2[trImgCntr]+'.png')) except: print('Warning while saving training image') elapsed_time = time.time() - start_time # for log with open(os.path.join(opt.exp_dir,opt.exp_name,'log_train.txt'), 'a') as log: model.eval() ocrModel.module.Transformation.eval() ocrModel.module.FeatureExtraction.eval() ocrModel.module.AdaptiveAvgPool.eval() 
ocrModel.module.SequenceModeling.eval() ocrModel.module.Prediction.eval() disModel.eval() with torch.no_grad(): valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation_synth_lrw_res( iteration, model, ocrModel, disModel, recCriterion, styleRecCriterion, ocrCriterion, valid_loader, converter, opt) model.train() if not opt.ocrFixed: ocrModel.train() else: # ocrModel.module.Transformation.eval() # ocrModel.module.FeatureExtraction.eval() # ocrModel.module.AdaptiveAvgPool.eval() ocrModel.module.SequenceModeling.train() # ocrModel.module.Prediction.eval() disModel.train() # training loss and validation loss loss_log = f'[{iteration+1}/{opt.num_iter}] Train OCR loss: {loss_avg_ocr.val():0.5f}, Train Synth loss: {loss_avg.val():0.5f}, Train Dis loss: {loss_avg_dis.val():0.5f}, Valid OCR loss: {valid_loss[0]:0.5f}, Valid Synth loss: {valid_loss[1]:0.5f}, Valid Dis loss: {valid_loss[2]:0.5f}, Elapsed_time: {elapsed_time:0.5f}' current_model_log_ocr = f'{"Current_accuracy_OCR":17s}: {current_accuracy[0]:0.3f}, {"Current_norm_ED_OCR":17s}: {current_norm_ED[0]:0.2f}' current_model_log_1 = f'{"Current_accuracy_recon":17s}: {current_accuracy[1]:0.3f}, {"Current_norm_ED_recon":17s}: {current_norm_ED[1]:0.2f}' current_model_log_2 = f'{"Current_accuracy_pair":17s}: {current_accuracy[2]:0.3f}, {"Current_norm_ED_pair":17s}: {current_norm_ED[2]:0.2f}' #plotting lib.plot.plot(os.path.join(plotDir,'Train-OCR-Loss'), loss_avg_ocr.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-Synth-Loss'), loss_avg.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-Dis-Loss'), loss_avg_dis.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-OCR-Recon1-Loss'), loss_avg_ocrRecon_1.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-OCR-Recon2-Loss'), loss_avg_ocrRecon_2.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-Gen-Loss'), loss_avg_gen.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-ImgRecon1-Loss'), 
loss_avg_imgRecon.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-StyRecon2-Loss'), loss_avg_styRecon.val().item()) lib.plot.plot(os.path.join(plotDir,'Valid-OCR-Loss'), valid_loss[0].item()) lib.plot.plot(os.path.join(plotDir,'Valid-Synth-Loss'), valid_loss[1].item()) lib.plot.plot(os.path.join(plotDir,'Valid-Dis-Loss'), valid_loss[2].item()) lib.plot.plot(os.path.join(plotDir,'Valid-OCR-Recon1-Loss'), valid_loss[3].item()) lib.plot.plot(os.path.join(plotDir,'Valid-OCR-Recon2-Loss'), valid_loss[4].item()) lib.plot.plot(os.path.join(plotDir,'Valid-Gen-Loss'), valid_loss[5].item()) lib.plot.plot(os.path.join(plotDir,'Valid-ImgRecon1-Loss'), valid_loss[6].item()) lib.plot.plot(os.path.join(plotDir,'Valid-StyRecon2-Loss'), valid_loss[7].item()) lib.plot.plot(os.path.join(plotDir,'Orig-OCR-WordAccuracy'), current_accuracy[0]) lib.plot.plot(os.path.join(plotDir,'Recon-OCR-WordAccuracy'), current_accuracy[1]) lib.plot.plot(os.path.join(plotDir,'Pair-OCR-WordAccuracy'), current_accuracy[2]) lib.plot.plot(os.path.join(plotDir,'Orig-OCR-CharAccuracy'), current_norm_ED[0]) lib.plot.plot(os.path.join(plotDir,'Recon-OCR-CharAccuracy'), current_norm_ED[1]) lib.plot.plot(os.path.join(plotDir,'Pair-OCR-CharAccuracy'), current_norm_ED[2]) # keep best accuracy model (on valid dataset) if current_accuracy[1] > best_accuracy: best_accuracy = current_accuracy[1] torch.save(model.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_accuracy.pth')) torch.save(disModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_accuracy_dis.pth')) if current_norm_ED[1] > best_norm_ED: best_norm_ED = current_norm_ED[1] torch.save(model.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_norm_ED.pth')) torch.save(disModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_norm_ED_dis.pth')) best_model_log = f'{"Best_accuracy_Recon":17s}: {best_accuracy:0.3f}, {"Best_norm_ED_Recon":17s}: {best_norm_ED:0.2f}' # keep best accuracy model (on valid dataset) if 
current_accuracy[0] > best_accuracy_ocr: best_accuracy_ocr = current_accuracy[0] if not opt.ocrFixed: torch.save(ocrModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_accuracy_ocr.pth')) if current_norm_ED[0] > best_norm_ED_ocr: best_norm_ED_ocr = current_norm_ED[0] if not opt.ocrFixed: torch.save(ocrModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_norm_ED_ocr.pth')) best_model_log_ocr = f'{"Best_accuracy_ocr":17s}: {best_accuracy_ocr:0.3f}, {"Best_norm_ED_ocr":17s}: {best_norm_ED_ocr:0.2f}' loss_model_log = f'{loss_log}\n{current_model_log_ocr}\n{current_model_log_1}\n{current_model_log_2}\n{best_model_log_ocr}\n{best_model_log}' print(loss_model_log) log.write(loss_model_log + '\n') # show some predicted results dashed_line = '-' * 80 head = f'{"Ground Truth":32s} | {"Prediction":25s} | Confidence Score & T/F' predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n' for gt_ocr, pred_ocr, confidence_ocr, gt_1, pred_1, confidence_1, gt_2, pred_2, confidence_2 in zip(labels[0][:5], preds[0][:5], confidence_score[0][:5], labels[1][:5], preds[1][:5], confidence_score[1][:5], labels[2][:5], preds[2][:5], confidence_score[2][:5]): if 'Attn' in opt.Prediction: # gt_ocr = gt_ocr[:gt_ocr.find('[s]')] pred_ocr = pred_ocr[:pred_ocr.find('[s]')] # gt_1 = gt_1[:gt_1.find('[s]')] pred_1 = pred_1[:pred_1.find('[s]')] # gt_2 = gt_2[:gt_2.find('[s]')] pred_2 = pred_2[:pred_2.find('[s]')] predicted_result_log += f'{"ocr"}: {gt_ocr:27s} | {pred_ocr:25s} | {confidence_ocr:0.4f}\t{str(pred_ocr == gt_ocr)}\n' predicted_result_log += f'{"recon"}: {gt_1:25s} | {pred_1:25s} | {confidence_1:0.4f}\t{str(pred_1 == gt_1)}\n' predicted_result_log += f'{"pair"}: {gt_2:26s} | {pred_2:25s} | {confidence_2:0.4f}\t{str(pred_2 == gt_2)}\n' predicted_result_log += f'{dashed_line}' print(predicted_result_log) log.write(predicted_result_log + '\n') loss_avg_ocr.reset() loss_avg.reset() loss_avg_dis.reset() loss_avg_ocrRecon_1.reset() loss_avg_ocrRecon_2.reset() 
loss_avg_gen.reset() loss_avg_imgRecon.reset() loss_avg_styRecon.reset() lib.plot.flush() lib.plot.tick() # save model per 1e+5 iter. if (iteration) % 1e+5 == 0: torch.save( model.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_synth.pth')) if not opt.ocrFixed: torch.save( ocrModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_ocr.pth')) torch.save( disModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_dis.pth')) if (iteration + 1) == opt.num_iter: print('end the training') sys.exit() iteration += 1
def train(opt):
    """Fine-tune the attention-based recognizer and validate it periodically.

    The routine builds the balanced training set and the validation loader,
    constructs ``Model(opt, 38)``, optionally loads ``opt.saved_model``,
    replaces the attention cell and the output layer of the prediction head
    with ones sized for ``opt.num_class``, and then trains with cross-entropy
    loss.  Every ``opt.valInterval`` iterations it runs ``validation``, appends
    the results to ``log_train.txt`` and keeps the best checkpoints (by
    accuracy and by normalized edit distance) under
    ``{save_dir}/{opt.experiment_name}``.

    Args:
        opt: option namespace.  ``opt.select_data`` and ``opt.batch_ratio``
            are split on ``'-'`` in place, and ``opt.num_class`` (plus
            ``opt.input_channel`` when ``opt.rgb`` is set) is overwritten.

    The function never returns normally: it calls ``sys.exit()`` once the
    iteration counter equals ``opt.num_iter``.
    """
    if opt.use_tb:
        tb_dir = f'/home_hongdo/{getpass.getuser()}/tb/{opt.experiment_name}'
        print('tensorboard : ', tb_dir)
        os.makedirs(tb_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=tb_dir)

    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    # The context manager closes the log even if a write fails.
    with open(f'{save_dir}/{opt.experiment_name}/log_dataset.txt', 'a') as log:
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    """ model configuration """
    converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3

    # sekim for transfer learning
    # NOTE(review): 38 is presumably the class count of the pretrained
    # checkpoint -- confirm.
    model = Model(opt, 38)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))

    # sekim change last layer
    # NOTE(review): the original indentation of this section was lost; the
    # head is rebuilt unconditionally here because the model is always
    # constructed with 38 classes -- confirm against the intended behavior.
    in_feature = model.module.Prediction.generator.in_features
    model.module.Prediction.attention_cell.rnn = nn.LSTMCell(
        256 + opt.num_class, 256).to(device)
    model.module.Prediction.generator = nn.Linear(in_feature,
                                                  opt.num_class).to(device)
    print(model.module.Prediction.generator)
    print("Model:")
    print(model)
    model.train()

    """ setup loss """
    criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
        device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr,
                                   rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    """ final options """
    with open(f'{save_dir}/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print("-------------------------------------------------")
            print(f'continue to train, start_iter: {start_iter}')
        except (AttributeError, ValueError):
            # The checkpoint name does not end in an iteration number:
            # start counting from zero.
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, _ = converter.encode(labels,
                                   batch_max_length=opt.batch_max_length)

        preds = model(image, text[:, :-1])  # align with Attention.forward
        target = text[:, 1:]  # without [GO] Symbol
        cost = criterion(preds.view(-1, preds.shape[-1]),
                         target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'{save_dir}/{opt.experiment_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                # Read the running average *before* resetting it so that both
                # the text log and TensorBoard record the same value.
                train_loss = loss_avg.val()
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {train_loss:0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()
                if opt.use_tb:
                    writer.add_scalar('OCR_loss/train_loss', train_loss, i)
                    writer.add_scalar('OCR_loss/validation_loss', valid_loss, i)

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'{save_dir}/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'{save_dir}/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    # partition() keeps the whole string when '[s]' is absent;
                    # slicing with find() would drop the last character.
                    gt = gt.partition('[s]')[0]
                    pred = pred.partition('[s]')[0]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (i + 1) % 1e+5 == 0:
            torch.save(model.state_dict(),
                       f'{save_dir}/{opt.experiment_name}/iter_{i + 1}.pth')

        if i == opt.num_iter:
            print('end the training')
            if opt.use_tb:
                writer.close()  # flush pending TensorBoard events
            sys.exit()
        i += 1
def train(opt):
    """Jointly train the synthesizer, the OCR recognizer and the discriminator.

    Three networks are built and wrapped in ``DataParallel``: ``AdaINGen``
    (synthesizer), ``Model`` (recognizer) and ``MsImageDis`` (discriminator).
    Each iteration draws one batch, uses its first half as synthesizer input
    and its second half as "real" images for the discriminator, and then
    updates, in this order, the discriminator, the synthesizer and the
    recognizer.  On the first iteration and every ``opt.valInterval``
    iterations it saves sample training images, runs ``validation_synth_adv``,
    appends the results to ``log_train.txt`` and keeps the best checkpoints
    under ``opt.exp_dir/opt.exp_name``.

    Args:
        opt: option namespace.  ``opt.select_data`` and ``opt.batch_ratio``
            are split on ``'-'`` in place, ``opt.batch_size`` is doubled, and
            ``opt.num_class`` (plus ``opt.input_channel`` when ``opt.rgb`` is
            set) is overwritten.

    The function never returns normally: it calls ``sys.exit()`` once
    ``iteration + 1`` equals ``opt.num_iter``.
    """

    def _init_weights(net):
        # Zero biases, Kaiming-normal weights; parameters that kaiming_normal_
        # rejects (e.g. batchnorm weights) are filled with ones instead.
        for name, param in net.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception:  # for batchnorm.
                if 'weight' in name:
                    param.data.fill_(1)

    def _make_optimizer(net, count_label, title):
        # Build Adam/Adadelta over the parameters of `net` that need grads.
        params = [p for p in net.parameters() if p.requires_grad]
        print(count_label, sum([np.prod(p.size()) for p in params]))
        if opt.adam:
            new_optimizer = optim.Adam(params, lr=opt.lr,
                                       betas=(opt.beta1, 0.999))
        else:
            new_optimizer = optim.Adadelta(params, lr=opt.lr, rho=opt.rho,
                                           eps=opt.eps)
        print(title)
        print(new_optimizer)
        return new_optimizer

    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')

    # considering the real images for discriminator
    opt.batch_size = opt.batch_size*2

    train_dataset = Batch_Balanced_Dataset(opt)

    exp_path = os.path.join(opt.exp_dir, opt.exp_name)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    # The context manager closes the log even if a write fails.
    with open(os.path.join(exp_path, 'log_dataset.txt'), 'a') as log:
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3

    model = AdaINGen(opt)
    ocrModel = Model(opt)
    disModel = MsImageDis()

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # Synthesizer, recognizer and discriminator weight initialization
    for net in (model, ocrModel, disModel):
        _init_weights(net)

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    ocrModel.train()
    disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()

    # optional warm start of each network from its own checkpoint
    checkpoints = (
        (model, opt.saved_synth_model, 'synth', "Model:"),
        (ocrModel, opt.saved_ocr_model, 'ocr', "OCRModel:"),
        (disModel, opt.saved_dis_model, 'discriminator', "DisModel:"),
    )
    for net, ckpt_path, kind, title in checkpoints:
        if ckpt_path != '':
            print(f'loading pretrained {kind} model from {ckpt_path}')
            if opt.FT:
                net.load_state_dict(torch.load(ckpt_path), strict=False)
            else:
                net.load_state_dict(torch.load(ckpt_path))
        print(title)
        print(net)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    recCriterion = torch.nn.L1Loss()

    # loss averager
    loss_avg = Averager()
    loss_avg_ocr = Averager()
    loss_avg_dis = Averager()

    # one optimizer per network, each over its trainable parameters only
    optimizer = _make_optimizer(model, 'Trainable params num : ', "SynthOptimizer:")
    ocr_optimizer = _make_optimizer(ocrModel, 'OCR Trainable params num : ', "OCROptimizer:")
    dis_optimizer = _make_optimizer(disModel, 'Dis Trainable params num : ', "DisOptimizer:")

    """ final options """
    with open(os.path.join(exp_path, 'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_synth_model != '':
        try:
            start_iter = int(opt.saved_synth_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except (AttributeError, ValueError):
            # The checkpoint name does not end in an iteration number:
            # start counting from zero.
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    best_accuracy_ocr = -1
    best_norm_ED_ocr = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors_all, labels_1_all, labels_2_all = train_dataset.get_batch()

        # First half: synthesizer input; second half: real images for the
        # discriminator.
        disCnt = int(image_tensors_all.size(0)/2)
        image_tensors = image_tensors_all[:disCnt]
        image_tensors_real = image_tensors_all[disCnt:disCnt+disCnt]
        labels_1 = labels_1_all[:disCnt]
        labels_2 = labels_2_all[:disCnt]

        image = image_tensors.to(device)
        image_real = image_tensors_real.to(device)
        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        images_recon_1, images_recon_2, _ = model(image, text_1, text_2)

        if 'CTC' in opt.Prediction:
            # ocr training
            preds_ocr = ocrModel(image, text_1)
            preds_size_ocr = torch.IntTensor([preds_ocr.size(1)] * batch_size)
            preds_ocr = preds_ocr.log_softmax(2).permute(1, 0, 2)
            ocrCost_train = ocrCriterion(preds_ocr, text_1, preds_size_ocr, length_1)

            # synth training
            preds_1 = ocrModel(images_recon_1, text_1)
            preds_size_1 = torch.IntTensor([preds_1.size(1)] * batch_size)
            preds_1 = preds_1.log_softmax(2).permute(1, 0, 2)

            preds_2 = ocrModel(images_recon_2, text_2)
            preds_size_2 = torch.IntTensor([preds_2.size(1)] * batch_size)
            preds_2 = preds_2.log_softmax(2).permute(1, 0, 2)
            ocrCost = 0.5*(ocrCriterion(preds_1, text_1, preds_size_1, length_1)
                           + ocrCriterion(preds_2, text_2, preds_size_2, length_2))
        else:
            # The previous attention branch referenced undefined names and
            # left ocrCost_train / disCost / disGenCost unset.
            # NOTE(review): the calls below follow the attention branch of the
            # other synthesizer training routine in this file (is_train=False
            # for the reconstructed images) -- confirm this is intended here.
            # ocr training
            preds_ocr = ocrModel(image, text_1[:, :-1])  # align with Attention.forward
            target_ocr = text_1[:, 1:]  # without [GO] Symbol
            ocrCost_train = ocrCriterion(preds_ocr.view(-1, preds_ocr.shape[-1]),
                                         target_ocr.contiguous().view(-1))

            # synth training
            preds_1 = ocrModel(images_recon_1, text_1[:, :-1], is_train=False)
            target_1 = text_1[:, 1:]  # without [GO] Symbol
            preds_2 = ocrModel(images_recon_2, text_2[:, :-1], is_train=False)
            target_2 = text_2[:, 1:]  # without [GO] Symbol
            ocrCost = 0.5*(ocrCriterion(preds_1.view(-1, preds_1.shape[-1]), target_1.contiguous().view(-1))
                           + ocrCriterion(preds_2.view(-1, preds_2.shape[-1]), target_2.contiguous().view(-1)))

        # dis training
        # Check: Using alternate real images
        disCost = opt.disWeight*0.5*(disModel.module.calc_dis_loss(images_recon_1.detach(), image_real)
                                     + disModel.module.calc_dis_loss(images_recon_2.detach(), image))
        disModel.zero_grad()
        disCost.backward()
        torch.nn.utils.clip_grad_norm_(disModel.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        dis_optimizer.step()
        loss_avg_dis.add(disCost)

        # gen training
        # The adversarial term is computed only after the discriminator step:
        # back-propagating through discriminator weights that the optimizer
        # has since modified in place trips autograd's version check.
        disGenCost = 0.5*(disModel.module.calc_gen_loss(images_recon_1)
                          + disModel.module.calc_gen_loss(images_recon_2))
        recCost = recCriterion(images_recon_1, image)
        cost = opt.ocrWeight*ocrCost + opt.reconWeight*recCost + opt.disWeight*disGenCost

        model.zero_grad()
        ocrModel.zero_grad()
        disModel.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()
        loss_avg.add(cost)

        # training OCR
        ocrModel.zero_grad()
        ocrCost_train.backward()
        torch.nn.utils.clip_grad_norm_(ocrModel.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        ocr_optimizer.step()
        loss_avg_ocr.add(ocrCost_train)

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:  # To see training progress, we also conduct validation when 'iteration == 0'

            # Save training images (best effort: a failed save must not stop
            # training)
            train_img_dir = os.path.join(exp_path, 'trainImages', str(iteration))
            os.makedirs(train_img_dir, exist_ok=True)
            for trImgCntr in range(batch_size):
                try:
                    save_image(tensor2im(image[trImgCntr].detach()),
                               os.path.join(train_img_dir, str(trImgCntr)+'_input_'+labels_1[trImgCntr]+'.png'))
                    save_image(tensor2im(images_recon_1[trImgCntr].detach()),
                               os.path.join(train_img_dir, str(trImgCntr)+'_recon_'+labels_1[trImgCntr]+'.png'))
                    save_image(tensor2im(images_recon_2[trImgCntr].detach()),
                               os.path.join(train_img_dir, str(trImgCntr)+'_pair_'+labels_2[trImgCntr]+'.png'))
                except Exception:
                    print('Warning while saving training image')

            elapsed_time = time.time() - start_time
            # for log
            with open(os.path.join(exp_path, 'log_train.txt'), 'a') as log:
                model.eval()
                ocrModel.eval()
                disModel.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation_synth_adv(
                        iteration, model, ocrModel, disModel, recCriterion, ocrCriterion, valid_loader, converter, opt)
                model.train()
                ocrModel.train()
                disModel.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train OCR loss: {loss_avg_ocr.val():0.5f}, Train Synth loss: {loss_avg.val():0.5f}, Train Dis loss: {loss_avg_dis.val():0.5f}, Valid OCR loss: {valid_loss[0]:0.5f}, Valid Synth loss: {valid_loss[1]:0.5f}, Valid Dis loss: {valid_loss[2]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg_ocr.reset()
                loss_avg.reset()
                loss_avg_dis.reset()

                current_model_log_ocr = f'{"Current_accuracy_OCR":17s}: {current_accuracy[0]:0.3f}, {"Current_norm_ED_OCR":17s}: {current_norm_ED[0]:0.2f}'
                current_model_log_1 = f'{"Current_accuracy_recon":17s}: {current_accuracy[1]:0.3f}, {"Current_norm_ED_recon":17s}: {current_norm_ED[1]:0.2f}'
                current_model_log_2 = f'{"Current_accuracy_pair":17s}: {current_accuracy[2]:0.3f}, {"Current_norm_ED_pair":17s}: {current_norm_ED[2]:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy[1] > best_accuracy:
                    best_accuracy = current_accuracy[1]
                    torch.save(model.state_dict(), os.path.join(exp_path, 'best_accuracy.pth'))
                    torch.save(disModel.state_dict(), os.path.join(exp_path, 'best_accuracy_dis.pth'))
                if current_norm_ED[1] > best_norm_ED:
                    best_norm_ED = current_norm_ED[1]
                    torch.save(model.state_dict(), os.path.join(exp_path, 'best_norm_ED.pth'))
                    torch.save(disModel.state_dict(), os.path.join(exp_path, 'best_norm_ED_dis.pth'))
                best_model_log = f'{"Best_accuracy_Recon":17s}: {best_accuracy:0.3f}, {"Best_norm_ED_Recon":17s}: {best_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy[0] > best_accuracy_ocr:
                    best_accuracy_ocr = current_accuracy[0]
                    torch.save(ocrModel.state_dict(), os.path.join(exp_path, 'best_accuracy_ocr.pth'))
                if current_norm_ED[0] > best_norm_ED_ocr:
                    best_norm_ED_ocr = current_norm_ED[0]
                    torch.save(ocrModel.state_dict(), os.path.join(exp_path, 'best_norm_ED_ocr.pth'))
                best_model_log_ocr = f'{"Best_accuracy_ocr":17s}: {best_accuracy_ocr:0.3f}, {"Best_norm_ED_ocr":17s}: {best_norm_ED_ocr:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log_ocr}\n{current_model_log_1}\n{current_model_log_2}\n{best_model_log_ocr}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":32s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'

                for gt_ocr, pred_ocr, confidence_ocr, gt_1, pred_1, confidence_1, gt_2, pred_2, confidence_2 in zip(
                        labels[0][:5], preds[0][:5], confidence_score[0][:5],
                        labels[1][:5], preds[1][:5], confidence_score[1][:5],
                        labels[2][:5], preds[2][:5], confidence_score[2][:5]):
                    if 'Attn' in opt.Prediction:
                        # Cut each prediction at the end-of-sentence token.
                        # partition() keeps the whole string when '[s]' is
                        # absent; slicing with find() would drop a character.
                        pred_ocr = pred_ocr.partition('[s]')[0]
                        pred_1 = pred_1.partition('[s]')[0]
                        pred_2 = pred_2.partition('[s]')[0]

                    predicted_result_log += f'{"ocr"}: {gt_ocr:27s} | {pred_ocr:25s} | {confidence_ocr:0.4f}\t{str(pred_ocr == gt_ocr)}\n'
                    predicted_result_log += f'{"recon"}: {gt_1:25s} | {pred_1:25s} | {confidence_1:0.4f}\t{str(pred_1 == gt_1)}\n'
                    predicted_result_log += f'{"pair"}: {gt_2:26s} | {pred_2:25s} | {confidence_2:0.4f}\t{str(pred_2 == gt_2)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            # f-strings so that each periodic checkpoint gets its own file;
            # the synthesizer name ends in the iteration number, which is what
            # the start_iter parsing above expects.
            torch.save(
                model.state_dict(), os.path.join(exp_path, f'iter_{iteration+1}.pth'))
            torch.save(
                ocrModel.state_dict(), os.path.join(exp_path, f'iter_{iteration+1}_ocr.pth'))
            torch.save(
                disModel.state_dict(), os.path.join(exp_path, f'iter_{iteration+1}_dis.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
def train(opt, log):
    """Train a text recognizer, optionally with a semi-supervised loss.

    Builds the labeled batch loader (plus an unlabeled one when
    ``opt.semi != "None"``), the validation loader, model, loss, optimizer
    and optional one-cycle scheduler, then iterates from ``start_iter + 1``
    to ``opt.num_iter``. Validation runs at iteration 1 and every
    ``opt.val_interval`` iterations; the best-scoring weights are saved to
    ``./saved_models/{opt.exp_name}/best_score.pth`` and reloaded after the
    loop for the benchmark evaluation.

    Args:
        opt: Option namespace. Mutated in place: ``batch_ratio`` (only for
            the ``"synth_SA"`` preset), ``sos_token_index`` /
            ``eos_token_index`` (attention prediction only), ``num_class``
            and ``eval_type``.
        log: Open, writable text file receiving the setup log. It is closed
            by this function right before the training loop starts.

    Raises:
        ValueError: If ``opt.optimizer`` is not one of ``"sgd"``,
            ``"adadelta"`` or ``"adam"``.
    """
    # ---- dataset preparation ----
    # Named presets for convenience; anything else is a "-"-separated list.
    real_data = [
        "1.SVT",
        "2.IIIT",
        "3.IC13",
        "4.IC15",
        "5.COCO",
        "6.RCTW17",
        "7.Uber",
        "8.ArT",
        "9.LSVT",
        "10.MLT19",
        "11.ReCTS",
    ]
    presets = {
        "label": real_data,
        "synth": ["MJ", "ST"],
        "synth_SA": ["MJ", "ST", "SA"],
        "mix": real_data + ["MJ", "ST"],
        "mix_SA": real_data + ["MJ", "ST", "SA"],
    }
    if opt.select_data in presets:
        select_data = presets[opt.select_data]
        if opt.select_data == "synth_SA":
            opt.batch_ratio = "0.4-0.4-0.2"  # same ratio with SCATTER paper.
    else:
        select_data = opt.select_data.split("-")

    # set batch_ratio for each data; default is an even split.
    if opt.batch_ratio:
        batch_ratio = opt.batch_ratio.split("-")
    else:
        batch_ratio = [round(1 / len(select_data), 3)] * len(select_data)

    train_loader = Batch_Balanced_Dataset(opt, opt.train_data, select_data,
                                          batch_ratio, log)

    if opt.semi != "None":
        select_data_unlabel = ["U1.Book32", "U2.TextVQA", "U3.STVQA"]
        batch_ratio_unlabel = [round(1 / len(select_data_unlabel), 3)
                               ] * len(select_data_unlabel)
        dataset_root_unlabel = "data_CVPR2021/training/unlabel/"
        train_loader_unlabel_semi = Batch_Balanced_Dataset(
            opt,
            dataset_root_unlabel,
            select_data_unlabel,
            batch_ratio_unlabel,
            log,
            learn_type="semi",
        )

    AlignCollate_valid = AlignCollate(opt, mode="test")
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt, mode="test")
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=False,
    )
    log.write(valid_dataset_log)
    print("-" * 80)
    log.write("-" * 80 + "\n")

    # ---- model configuration ----
    if "CTC" in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
        opt.sos_token_index = converter.dict["[SOS]"]
        opt.eos_token_index = converter.dict["[EOS]"]
    opt.num_class = len(converter.character)

    model = Model(opt)

    # weight initialization
    for name, param in model.named_parameters():
        if "localization_fc2" in name:
            print(f"Skip {name} as it is already initialized")
            continue
        try:
            if "bias" in name:
                init.constant_(param, 0.0)
            elif "weight" in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if "weight" in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()

    if opt.saved_model != "":
        fine_tuning_log = f"### loading pretrained model from {opt.saved_model}\n"

        if "MoCo" in opt.saved_model or "MoCo" in opt.self_pre:
            # Keep only the query-encoder weights and strip their prefix.
            pretrained_state_dict_qk = torch.load(opt.saved_model)
            pretrained_state_dict = {}
            for name in pretrained_state_dict_qk:
                if "encoder_q" in name:
                    rename = name.replace("encoder_q.", "")
                    pretrained_state_dict[rename] = pretrained_state_dict_qk[
                        name]
        else:
            pretrained_state_dict = torch.load(opt.saved_model)

        for name, param in model.named_parameters():
            try:
                # load from pretrained model
                param.data.copy_(pretrained_state_dict[name].data)
                if opt.FT == "freeze":
                    param.requires_grad = False  # Freeze
                    fine_tuning_log += f"pretrained layer (freezed): {name}\n"
                else:
                    fine_tuning_log += f"pretrained layer: {name}\n"
            except Exception:
                # Best effort: a missing or incompatible entry leaves the
                # layer at its fresh initialization. Interrupts propagate.
                fine_tuning_log += f"non-pretrained layer: {name}\n"

        print(fine_tuning_log)
        log.write(fine_tuning_log + "\n")

    log.write(repr(model) + "\n")

    # ---- setup loss ----
    if "CTC" in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        # ignore [PAD] token
        criterion = torch.nn.CrossEntropyLoss(
            ignore_index=converter.dict["[PAD]"]).to(device)

    if "Pseudo" in opt.semi:
        criterion_SemiSL = PseudoLabelLoss(opt, converter, criterion)
    elif "MeanT" in opt.semi:
        criterion_SemiSL = MeanTeacherLoss(opt, student_for_init_teacher=model)

    # loss averager
    train_loss_avg = Averager()
    semi_loss_avg = Averager()  # semi supervised loss avg

    # keep only the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print(f"Trainable params num: {sum(params_num)}")
    log.write(f"Trainable params num: {sum(params_num)}\n")

    # setup optimizer
    if opt.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            filtered_parameters,
            lr=opt.lr,
            momentum=opt.sgd_momentum,
            weight_decay=opt.sgd_weight_decay,
        )
    elif opt.optimizer == "adadelta":
        optimizer = torch.optim.Adadelta(filtered_parameters,
                                         lr=opt.lr,
                                         rho=opt.rho,
                                         eps=opt.eps)
    elif opt.optimizer == "adam":
        optimizer = torch.optim.Adam(filtered_parameters, lr=opt.lr)
    else:
        raise ValueError(
            f"unsupported optimizer {opt.optimizer!r}; "
            "expected 'sgd', 'adadelta' or 'adam'")
    print("Optimizer:")
    print(optimizer)
    log.write(repr(optimizer) + "\n")

    if "super" in opt.schedule:
        # Momentum cycling is only enabled for SGD.
        cycle_momentum = opt.optimizer == "sgd"

        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=opt.lr,
            cycle_momentum=cycle_momentum,
            div_factor=20,
            final_div_factor=1000,
            total_steps=opt.num_iter,
        )
        print("Scheduler:")
        print(scheduler)
        log.write(repr(scheduler) + "\n")

    # ---- final options ----
    opt_log = "------------ Options -------------\n"
    args = vars(opt)
    for k, v in args.items():
        if str(k) == "character" and len(str(v)) > 500:
            opt_log += f"{str(k)}: So many characters to show all: number of characters: {len(str(v))}\n"
        else:
            opt_log += f"{str(k)}: {str(v)}\n"
    opt_log += "---------------------------------------\n"
    print(opt_log)
    log.write(opt_log)
    log.close()

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != "":
        try:
            # Checkpoints named like "..._<iter>.pth" resume from <iter>.
            start_iter = int(opt.saved_model.split("_")[-1].split(".")[0])
            print(f"continue to train, start_iter: {start_iter}")
        except ValueError:
            pass  # file name carries no iteration number; start from 0

    start_time = time.time()
    best_score = -1
    # Bound up front so the final evaluation can log even if the loop below
    # has nothing left to run.
    iteration = start_iter

    # training loop
    for iteration in tqdm(
            range(start_iter + 1, opt.num_iter + 1),
            total=opt.num_iter,
            position=0,
            leave=True,
    ):
        if "MeanT" in opt.semi:
            image_tensors, image_tensors_ema, labels = train_loader.get_batch_ema(
            )
        else:
            image_tensors, labels = train_loader.get_batch()

        image = image_tensors.to(device)
        labels_index, labels_length = converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        # default recognition loss part
        if "CTC" in opt.Prediction:
            preds = model(image)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds_log_softmax = preds.log_softmax(2).permute(1, 0, 2)
            loss = criterion(preds_log_softmax, labels_index, preds_size,
                             labels_length)
        else:
            preds = model(image,
                          labels_index[:, :-1])  # align with Attention.forward
            target = labels_index[:, 1:]  # without [SOS] Symbol
            loss = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        # semi supervised part (SemiSL)
        if "Pseudo" in opt.semi:
            image_unlabel, _ = train_loader_unlabel_semi.get_batch_two_images()
            image_unlabel = image_unlabel.to(device)
            loss_SemiSL = criterion_SemiSL(image_unlabel, model)

            loss = loss + loss_SemiSL
            semi_loss_avg.add(loss_SemiSL)

        elif "MeanT" in opt.semi:
            (
                image_tensors_unlabel,
                image_tensors_unlabel_ema,
            ) = train_loader_unlabel_semi.get_batch_two_images()
            image_unlabel = image_tensors_unlabel.to(device)
            student_input = torch.cat([image, image_unlabel], dim=0)

            image_ema = image_tensors_ema.to(device)
            image_unlabel_ema = image_tensors_unlabel_ema.to(device)
            teacher_input = torch.cat([image_ema, image_unlabel_ema], dim=0)
            loss_SemiSL = criterion_SemiSL(
                student_input=student_input,
                student_logit=preds,
                student=model,
                teacher_input=teacher_input,
                iteration=iteration,
            )

            loss = loss + loss_SemiSL
            semi_loss_avg.add(loss_SemiSL)

        model.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()
        train_loss_avg.add(loss)

        if "super" in opt.schedule:
            scheduler.step()
        else:
            adjust_learning_rate(optimizer, iteration, opt)

        # validation part.
        # To see training progress, we also conduct validation when 'iteration == 1'
        if iteration % opt.val_interval == 0 or iteration == 1:
            # for validation log
            with open(f"./saved_models/{opt.exp_name}/log_train.txt",
                      "a") as train_log:
                model.eval()
                with torch.no_grad():
                    (
                        valid_loss,
                        current_score,
                        preds,
                        confidence_score,
                        labels,
                        infer_time,
                        length_of_data,
                    ) = validation(model, criterion, valid_loader, converter,
                                   opt)
                model.train()

                # keep best score (accuracy or norm ED) model on valid dataset
                # Do not use this on test datasets. It would be an unfair comparison
                # (training should be done without referring test set).
                if current_score > best_score:
                    best_score = current_score
                    torch.save(
                        model.state_dict(),
                        f"./saved_models/{opt.exp_name}/best_score.pth",
                    )

                # validation log: loss, lr, score (accuracy or norm ED), time.
                lr = optimizer.param_groups[0]["lr"]
                elapsed_time = time.time() - start_time
                valid_log = f"\n[{iteration}/{opt.num_iter}] Train_loss: {train_loss_avg.val():0.5f}, Valid_loss: {valid_loss:0.5f}"
                valid_log += f", Semi_loss: {semi_loss_avg.val():0.5f}\n"
                valid_log += f'{"Current_score":17s}: {current_score:0.2f}, Current_lr: {lr:0.7f}\n'
                valid_log += f'{"Best_score":17s}: {best_score:0.2f}, Infer_time: {infer_time:0.1f}, Elapsed_time: {elapsed_time:0.1f}'

                # show some predicted results
                dashed_line = "-" * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f"{dashed_line}\n{head}\n{dashed_line}\n"
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if "Attn" in opt.Prediction:
                        # Cut both strings at the end-of-sequence marker.
                        gt = gt[:gt.find("[EOS]")]
                        pred = pred[:pred.find("[EOS]")]

                    predicted_result_log += f"{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n"
                predicted_result_log += f"{dashed_line}"
                valid_log = f"{valid_log}\n{predicted_result_log}"
                print(valid_log)
                train_log.write(valid_log + "\n")

                opt.writer.add_scalar("train/train_loss",
                                      float(f"{train_loss_avg.val():0.5f}"),
                                      iteration)
                opt.writer.add_scalar("train/semi_loss",
                                      float(f"{semi_loss_avg.val():0.5f}"),
                                      iteration)
                opt.writer.add_scalar("train/lr", float(f"{lr:0.7f}"),
                                      iteration)
                opt.writer.add_scalar("train/elapsed_time",
                                      float(f"{elapsed_time:0.1f}"), iteration)
                opt.writer.add_scalar("valid/valid_loss",
                                      float(f"{valid_loss:0.5f}"), iteration)
                opt.writer.add_scalar("valid/current_score",
                                      float(f"{current_score:0.2f}"),
                                      iteration)
                opt.writer.add_scalar("valid/best_score",
                                      float(f"{best_score:0.2f}"), iteration)

                train_loss_avg.reset()
                semi_loss_avg.reset()

    # ---- evaluation at the end of training ----
    print("Start evaluation on benchmark testset")
    # keep evaluation model and result logs
    os.makedirs(f"./result/{opt.exp_name}", exist_ok=True)
    os.makedirs("./evaluation_log", exist_ok=True)
    saved_best_model = f"./saved_models/{opt.exp_name}/best_score.pth"
    model.load_state_dict(torch.load(saved_best_model))

    opt.eval_type = "benchmark"
    model.eval()
    with torch.no_grad():
        total_accuracy, eval_data_list, accuracy_list = benchmark_all_eval(
            model, criterion, converter, opt)

    opt.writer.add_scalar("test/total_accuracy",
                          float(f"{total_accuracy:0.2f}"), iteration)
    for eval_data, accuracy in zip(eval_data_list, accuracy_list):
        accuracy = float(accuracy)
        opt.writer.add_scalar(f"test/{eval_data}", float(f"{accuracy:0.2f}"),
                              iteration)

    print(
        f'finished the experiment: {opt.exp_name}, "CUDA_VISIBLE_DEVICES" was {opt.CUDA_VISIBLE_DEVICES}'
    )
def train(opt):
    """Train a two-headed recognizer: a CTC major path and an attention guide path.

    The model returns ``(preds_major, preds_guide)``. Parameters are split by
    the second component of their name into a "guide" group (Transformation,
    FeatureExtraction, SequenceModeling_Attn, Attention) and a "major" group
    (SequenceModeling_CTC, CTC), each with its own optimizer and linear
    warmup schedule. With ``opt.continue_training`` only the major path is
    optimized and no guide optimizer/scheduler is created.

    Runs until ``iteration + 1 == opt.num_iter`` and then calls
    ``sys.exit()``; it never returns normally.

    Args:
        opt: Option namespace. Mutated in place: ``select_data`` and
            ``batch_ratio`` are split on ``"-"``; ``num_class_ctc``,
            ``num_class_attn`` and (when ``opt.rgb``) ``input_channel`` are
            set.

    Note:
        ``writer`` and ``fol_ckpt`` are not defined in this function;
        presumably they are module-level globals -- TODO confirm.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    # Context manager so the file is closed even if a write fails.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    # Both converters are always built: one per prediction head.
    if opt.baiduCTC:
        CTC_converter = CTCLabelConverterForBaiduWarpctc(opt.character)
    else:
        CTC_converter = CTCLabelConverter(opt.character)
    Attn_converter = AttnLabelConverter(opt.character)
    opt.num_class_ctc = len(CTC_converter.character)
    opt.num_class_attn = len(Attn_converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class_ctc, opt.num_class_attn, opt.batch_max_length,
          opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling,
          opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    print("Model:")
    print(model)

    # ---- setup loss ----
    if opt.baiduCTC:
        # need to install warpctc. see our guideline.
        if opt.label_smooth:
            criterion_major_path = SmoothCTCLoss(num_classes=opt.num_class_ctc,
                                                 weight=0.05)
        else:
            criterion_major_path = CTCLoss()
    else:
        criterion_major_path = torch.nn.CTCLoss(zero_infinity=True).to(device)
    # ignore [GO] token = ignore index 0
    criterion_guide_path = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    # loss averagers, one per path
    loss_avg_major_path = Averager()
    loss_avg_guide_path = Averager()

    # Split trainable parameters into the guide and major groups using the
    # second dotted component of the parameter name.
    guide_parameters = []
    major_parameters = []
    guide_model_part_names = [
        "Transformation", "FeatureExtraction", "SequenceModeling_Attn",
        "Attention"
    ]
    major_model_part_names = ["SequenceModeling_CTC", "CTC"]
    for name, param in model.named_parameters():
        if param.requires_grad:
            part = name.split(".")[1]
            if part in guide_model_part_names:
                guide_parameters.append(param)
            elif part in major_model_part_names:
                major_parameters.append(param)

    # The guide path is only trained on a fresh / resumed run, never when
    # continuing training of the major path alone.
    train_guide = not opt.continue_training
    if not train_guide:
        guide_parameters = []

    # setup optimizers: one per parameter group. opt.adam selects Adam
    # instead of AdamW for both groups.
    if opt.adam:
        optimizer_ctc = optim.Adam(major_parameters,
                                   lr=opt.lr,
                                   betas=(opt.beta1, 0.999))
    else:
        optimizer_ctc = AdamW(major_parameters, lr=opt.lr)
    optimizer_attn = None
    if train_guide:
        if opt.adam:
            optimizer_attn = optim.Adam(guide_parameters,
                                        lr=opt.lr,
                                        betas=(opt.beta1, 0.999))
        else:
            optimizer_attn = AdamW(guide_parameters, lr=opt.lr)

    scheduler_ctc = get_linear_schedule_with_warmup(
        optimizer_ctc, num_warmup_steps=10000, num_training_steps=opt.num_iter)
    scheduler_attn = None
    if train_guide:
        scheduler_attn = get_linear_schedule_with_warmup(
            optimizer_attn,
            num_warmup_steps=10000,
            num_training_steps=opt.num_iter)

    start_iter = 0
    if opt.saved_model != '' and train_guide:
        # Resume: the checkpoint is a dict holding model, optimizer and
        # scheduler state plus the iteration it was saved at.
        print(f'loading pretrained model from {opt.saved_model}')
        checkpoint = torch.load(opt.saved_model)
        start_iter = checkpoint['start_iter'] + 1
        if not opt.adam:
            optimizer_ctc.load_state_dict(
                checkpoint['optimizer_ctc_state_dict'])
            optimizer_attn.load_state_dict(
                checkpoint['optimizer_attn_state_dict'])
            scheduler_ctc.load_state_dict(
                checkpoint['scheduler_ctc_state_dict'])
            scheduler_attn.load_state_dict(
                checkpoint['scheduler_attn_state_dict'])
            print(scheduler_ctc.get_lr())
            print(scheduler_attn.get_lr())
        if opt.FT:
            model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        else:
            model.load_state_dict(checkpoint['model_state_dict'])
    if opt.continue_training:
        # Here the saved file is a bare state dict, not a checkpoint dict.
        model.load_state_dict(torch.load(opt.saved_model))

    # Rebuild the schedules so they are positioned at the resume iteration.
    scheduler_ctc = get_linear_schedule_with_warmup(
        optimizer_ctc,
        num_warmup_steps=10000,
        num_training_steps=opt.num_iter,
        last_epoch=start_iter - 1)
    if train_guide:
        scheduler_attn = get_linear_schedule_with_warmup(
            optimizer_attn,
            num_warmup_steps=10000,
            num_training_steps=opt.num_iter,
            last_epoch=start_iter - 1)

    # ---- final options ----
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter - 1

    if opt.continue_training:
        start_iter = 0

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        iteration += 1
        if iteration < start_iter:
            continue
        image = image_tensors.to(device)
        text_attn, length_attn = Attn_converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        text_ctc, length_ctc = CTC_converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        # The attention head is teacher-forced with the target minus its
        # last symbol (align with Attention.forward).
        preds_major, preds_guide = model(image, text_attn[:, :-1])
        preds_size = torch.IntTensor([preds_major.size(1)] * batch_size)

        if opt.baiduCTC:
            preds_major = preds_major.permute(1, 0, 2)  # to use CTCLoss format
            if opt.label_smooth:
                cost_ctc = criterion_major_path(preds_major, text_ctc,
                                                preds_size, length_ctc,
                                                batch_size)
            else:
                cost_ctc = criterion_major_path(
                    preds_major, text_ctc, preds_size, length_ctc) / batch_size
        else:
            preds_major = preds_major.log_softmax(2).permute(1, 0, 2)
            cost_ctc = criterion_major_path(preds_major, text_ctc, preds_size,
                                            length_ctc)

        target = text_attn[:, 1:]  # without [GO] Symbol
        if train_guide:
            cost_attn = criterion_guide_path(
                preds_guide.view(-1, preds_guide.shape[-1]),
                target.contiguous().view(-1))
            optimizer_attn.zero_grad()
            # The graph is shared with the CTC loss below, so keep it alive.
            # NOTE(review): stepping the guide optimizer before the CTC
            # backward modifies shared weights in place; confirm this is
            # accepted by the PyTorch version in use.
            cost_attn.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(
                guide_parameters,
                opt.grad_clip)  # gradient clipping with 5 (Default)
            optimizer_attn.step()

        optimizer_ctc.zero_grad()
        cost_ctc.backward()
        torch.nn.utils.clip_grad_norm_(
            major_parameters,
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer_ctc.step()
        scheduler_ctc.step()
        if train_guide:
            scheduler_attn.step()

        loss_avg_major_path.add(cost_ctc)
        if train_guide:
            loss_avg_guide_path.add(cost_attn)

        # Report (and reset) the running training losses every 100 iterations.
        if (iteration + 1) % 100 == 0:
            writer.add_scalar("Loss/train_ctc", loss_avg_major_path.val(),
                              (iteration + 1) // 100)
            loss_avg_major_path.reset()
            if train_guide:
                writer.add_scalar("Loss/train_attn", loss_avg_guide_path.val(),
                                  (iteration + 1) // 100)
                loss_avg_guide_path.reset()

        # validation part
        if (iteration + 1) % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion_major_path, valid_loader,
                        CTC_converter, opt)
                model.train()
                writer.add_scalar("Loss/valid", valid_loss,
                                  (iteration + 1) // opt.valInterval)
                writer.add_scalar("Metrics/accuracy", current_accuracy,
                                  (iteration + 1) // opt.valInterval)
                writer.add_scalar("Metrics/norm_ED", current_norm_ED,
                                  (iteration + 1) // opt.valInterval)

                # training loss and validation loss
                if train_guide:
                    loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss ctc: {loss_avg_major_path.val():0.5f}, Train loss attn: {loss_avg_guide_path.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                else:
                    loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss ctc: {loss_avg_major_path.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg_major_path.reset()
                if train_guide:
                    loss_avg_guide_path.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(),
                               f'{fol_ckpt}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(),
                               f'{fol_ckpt}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save a resumable checkpoint every 1e+3 iterations.
        if (iteration + 1) % 1e+3 == 0 and train_guide:
            torch.save(
                {
                    'model_state_dict': model.state_dict(),
                    'optimizer_attn_state_dict': optimizer_attn.state_dict(),
                    'optimizer_ctc_state_dict': optimizer_ctc.state_dict(),
                    'start_iter': iteration,
                    'scheduler_ctc_state_dict': scheduler_ctc.state_dict(),
                    'scheduler_attn_state_dict': scheduler_attn.state_dict(),
                }, f'{fol_ckpt}/current_model.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
def train(opt, show_number=2, amp=False):
    """Train a single-head recognizer (CTC or attention), optionally with AMP.

    Builds the balanced training dataset and validation loader, creates or
    restores the model, then loops forever: one optimization step per
    iteration, validation every ``opt.valInterval`` iterations, a snapshot
    every 1e+4 iterations, and ``sys.exit()`` once ``i == opt.num_iter``.
    The function therefore never returns normally.

    Args:
        opt: Option namespace. Mutated in place: ``select_data`` and
            ``batch_ratio`` are split on ``"-"``; ``num_class`` and (when
            ``opt.rgb``) ``input_channel`` are set.
        show_number: How many validation samples to print per validation
            round. Clamped to the number of samples actually returned.
        amp: When true, the forward pass runs under ``autocast()`` and the
            backward/step goes through a ``GradScaler``.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD,
                                      contrast_adjust=opt.contrast_adjust)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=min(32, opt.batch_size),
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        prefetch_factor=512,
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    # Context manager so the file is closed even if a write fails.
    with open(f'./saved_models/{opt.experiment_name}/log_dataset.txt',
              'a',
              encoding="utf8") as log:
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    if opt.saved_model != '':
        pretrained_dict = torch.load(opt.saved_model)
        if opt.new_prediction:
            # Temporarily size the head like the checkpoint's so that its
            # weights can be loaded; it is replaced again below.
            model.Prediction = nn.Linear(
                model.SequenceModeling_output,
                len(pretrained_dict['module.Prediction.weight']))

        model = torch.nn.DataParallel(model).to(device)
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(pretrained_dict, strict=False)
        else:
            model.load_state_dict(pretrained_dict)
        if opt.new_prediction:
            # Fresh head for the current character set.
            model.module.Prediction = nn.Linear(
                model.module.SequenceModeling_output, opt.num_class)
            for name, param in model.module.Prediction.named_parameters():
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            model = model.to(device)
    else:
        # weight initialization
        for name, param in model.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception:  # for batchnorm.
                if 'weight' in name:
                    param.data.fill_(1)
                continue
        model = torch.nn.DataParallel(model).to(device)

    model.train()
    print("Model:")
    print(model)
    count_parameters(model)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # freeze some layers (best effort: options or sub-modules may be absent)
    try:
        if opt.freeze_FeatureFxtraction:
            for param in model.module.FeatureExtraction.parameters():
                param.requires_grad = False
        if opt.freeze_SequenceModeling:
            for param in model.module.SequenceModeling.parameters():
                param.requires_grad = False
    except AttributeError:
        pass

    # keep only the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.optim == 'adam':
        optimizer = optim.Adam(filtered_parameters)
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'./saved_models/{opt.experiment_name}/opt.txt',
              'a',
              encoding="utf8") as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        try:
            # Checkpoints named like "..._<iter>.pth" resume from <iter>.
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass  # file name carries no iteration number; start from 0

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter

    scaler = GradScaler()
    t1 = time.time()

    def _forward_cost():
        """Fetch one training batch and return its loss."""
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)
            # cuDNN is switched off only around the CTC loss call.
            torch.backends.cudnn.enabled = False
            cost = criterion(preds, text.to(device), preds_size.to(device),
                             length.to(device))
            torch.backends.cudnn.enabled = True
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))
        return cost

    while True:
        # train part
        optimizer.zero_grad(set_to_none=True)

        if amp:
            with autocast():
                cost = _forward_cost()
            scaler.scale(cost).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            cost = _forward_cost()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()
        loss_avg.add(cost)

        # validation part
        if (i % opt.valInterval == 0) and (i != 0):
            print('training time: ', time.time() - t1)
            t1 = time.time()
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt',
                      'a',
                      encoding="utf8") as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels,\
                    infer_time, length_of_data = validation(model, criterion, valid_loader, converter, opt, device)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.4f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                    )
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                    )
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.4f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'

                # Clamp the window so a small validation set cannot make
                # randint() receive an empty range.
                n_show = min(show_number, len(labels))
                start = random.randint(0, len(labels) - n_show)
                for gt, pred, confidence in zip(
                        labels[start:start + n_show],
                        preds[start:start + n_show],
                        confidence_score[start:start + n_show]):
                    if 'Attn' in opt.Prediction:
                        # Cut both strings at the end-of-sequence marker.
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')
                print('validation time: ', time.time() - t1)
                t1 = time.time()

        # save model per 1e+4 iter.
        if (i + 1) % 1e+4 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
def train(opt):
    """Train a text-recognition model with periodic validation and checkpoints.

    Depending on ``opt.SequenceModeling`` / ``opt.Prediction`` the model is
    trained with the Transformer loss, CTC loss or attention cross-entropy.
    Weights can be warm-started from ``opt.load_weights`` (only tensors whose
    name and size match the model are copied) and a full run can be resumed
    from ``opt.continue_model`` (weights, step, optimizer state, best scores).

    Iterations run from ``start_iter`` to ``opt.num_iter - 1``.  Every
    ``opt.valInterval`` iterations the model is validated; the weights are
    saved when accuracy rises above, or norm_ED drops below, the best value
    seen so far.  Every 1000 iterations a resumable checkpoint is written to
    ``./saved_models/<opt.experiment_name>/iter_<n>.pth``.

    Args:
        opt: option namespace.  ``opt.select_data`` and ``opt.batch_ratio``
            are replaced in place by their '-'-split lists, and
            ``opt.num_class`` (plus ``opt.input_channel`` when ``opt.rgb``)
            is set here.
    """
    # ---- dataset preparation ----
    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(
        imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        # 'True' to check training progress with validation function.
        shuffle=True,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    print('-' * 80)

    # ---- model configuration ----
    is_transformer = 'Transformer' in opt.SequenceModeling
    if is_transformer:
        converter = TransformerLabelConverter(opt.character)
    elif 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # Deliberate best-effort: weights the initializer rejects are
            # filled with ones instead.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # ---- setup loss ----
    if is_transformer:
        criterion = transformer_loss
    elif 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
    else:
        # ignore [GO] token = ignore index 0
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda()
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradient descent
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # ---- setup optimizer ----
    # `optimizer_schedule` stays None unless the warm-up schedule is really
    # built.  The original code referenced it whenever Transformer +
    # use_scheduled_optim was requested, which raised NameError when
    # `opt.adam` had already selected the plain Adam branch.
    optimizer_schedule = None
    if opt.adam:
        optimizer = optim.Adam(
            filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    elif is_transformer and opt.use_scheduled_optim:
        optimizer = optim.Adam(
            filtered_parameters, betas=(0.9, 0.98), eps=1e-09)
        optimizer_schedule = ScheduledOptim(
            optimizer, opt.d_model, opt.n_warmup_steps)
    else:
        optimizer = optim.Adadelta(
            filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        for k, v in vars(opt).items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6  # lower is better here (see the `<` comparison below)

    # NOTE(review): this monkey-patches the global `pickle` module so that
    # checkpoints pickled under Python 2 can be read; it affects every later
    # user of `pickle` in the process.
    pickle.load = partial(pickle.load, encoding="latin1")
    pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")

    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints.
    if opt.load_weights != '' and check_isfile(opt.load_weights):
        # load pretrained weights but ignore layers that don't match in size
        checkpoint = torch.load(opt.load_weights, pickle_module=pickle)
        # Exact type check on purpose: a bare state dict is an OrderedDict
        # (a dict subclass) and must take the `else` branch.
        if type(checkpoint) == dict:
            pretrain_dict = checkpoint['state_dict']
        else:
            pretrain_dict = checkpoint
        model_dict = model.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and model_dict[k].size() == v.size()
        }
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)
        print("Loaded pretrained weights from '{}'".format(opt.load_weights))
        del checkpoint
        torch.cuda.empty_cache()

    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        checkpoint = torch.load(opt.continue_model)
        print(checkpoint.keys())
        model.load_state_dict(checkpoint['state_dict'])
        start_iter = checkpoint['step'] + 1
        print('continue to train start_iter: ', start_iter)
        if 'optimizer' in checkpoint.keys():
            optimizer.load_state_dict(checkpoint['optimizer'])
            # move the restored optimizer state tensors onto the GPU
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()
        if 'best_accuracy' in checkpoint.keys():
            best_accuracy = checkpoint['best_accuracy']
        if 'best_norm_ED' in checkpoint.keys():
            best_norm_ED = checkpoint['best_norm_ED']
        del checkpoint
        torch.cuda.empty_cache()

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).cuda()
    model.train()
    print("Model size:", count_num_param(model), 'M')

    if optimizer_schedule is not None:
        optimizer_schedule.n_current_steps = start_iter

    for i in tqdm(range(start_iter, opt.num_iter)):
        for p in model.parameters():
            p.requires_grad = True

        cpu_images, cpu_texts = train_dataset.get_batch()
        image = cpu_images.cuda()
        if is_transformer:
            text, length, text_pos = converter.encode(
                cpu_texts, opt.batch_max_length)
        elif 'CTC' in opt.Prediction:
            text, length = converter.encode(cpu_texts)
        else:
            text, length = converter.encode(cpu_texts, opt.batch_max_length)
        batch_size = image.size(0)

        if is_transformer:
            preds = model(image, text, tgt_pos=text_pos)
            target = text[:, 1:]  # without <s> Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))
        elif 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format
            cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text)
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()

        if optimizer_schedule is not None:
            optimizer_schedule.step_and_update_lr()
        elif is_transformer:
            optimizer.step()
        else:
            # gradient clipping with 5 (Default)
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()
        loss_avg.add(cost)

        # validation part
        if i > 0 and (i + 1) % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            train_log = (f'[{i+1}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} '
                         f'elapsed_time: {elapsed_time:0.5f}')
            print(train_log)
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt',
                      'a') as log:
                log.write(train_log + '\n')
                loss_avg.reset()

                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds,
                     gts, infer_time, length_of_data) = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # show the first few predictions, cut at the end-of-sequence
                # token
                for pred, gt in zip(preds[:5], gts[:5]):
                    if is_transformer:
                        pred = pred[:pred.find('</s>')]
                        gt = gt[:gt.find('</s>')]
                    elif 'Attn' in opt.Prediction:
                        pred = pred[:pred.find('[s]')]
                        gt = gt[:gt.find('[s]')]
                    sample_log = f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}'
                    print(sample_log)
                    log.write(sample_log + '\n')

                valid_log = f'[{i+1}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                print(valid_log)
                log.write(valid_log + '\n')

                # keep best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    state_dict = model.module.state_dict()
                    save_checkpoint(
                        {
                            'best_accuracy': best_accuracy,
                            'state_dict': state_dict,
                        }, False,
                        f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                    )
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    state_dict = model.module.state_dict()
                    save_checkpoint(
                        {
                            'best_norm_ED': best_norm_ED,
                            'state_dict': state_dict,
                        }, False,
                        f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                    )
                best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                print(best_model_log)
                log.write(best_model_log + '\n')

        # save model per 1000 iter.
        if (i + 1) % 1000 == 0:
            state_dict = model.module.state_dict()
            optimizer_state_dict = optimizer.state_dict()
            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'optimizer': optimizer_state_dict,
                    'step': i,
                    'best_accuracy': best_accuracy,
                    'best_norm_ED': best_norm_ED,
                }, False,
                f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')
def train(opt, log):
    """Run self-supervised pre-training (RotNet or MoCo) with validation.

    Builds the unlabeled training loader and validation loader, the model
    selected by ``opt.self``, an optimizer and (when ``"super"`` is in
    ``opt.schedule``) a OneCycleLR scheduler, then trains for iterations
    ``start_iter + 1`` .. ``opt.num_iter``.  Validation runs at iteration 1
    and every ``opt.val_interval`` iterations; the weights with the best
    validation accuracy are saved to
    ``./saved_models/<opt.exp_name>/best_score.pth``.

    Args:
        opt: option namespace.  ``opt.batch_size`` is forced to 256 for MoCo;
            for MoCo with a non-adam optimizer ``opt.schedule``, ``opt.lr``
            and ``opt.lr_drop_rate`` are overwritten with their ``moco_*``
            counterparts.
        log: open, writable text file used for the setup log.  It is closed
            by this function once the options have been written.

    Raises:
        ValueError: if ``opt.self`` is neither ``"RotNet"`` nor ``"MoCo"``.
    """
    if opt.self == "MoCo":
        opt.batch_size = 256

    # ---- dataset preparation ----
    if opt.select_data == "unlabel":
        select_data = ["U1.Book32", "U2.TextVQA", "U3.STVQA"]
        batch_ratio = [round(1 / len(select_data), 3)] * len(select_data)
    else:
        select_data = opt.select_data.split("-")
        batch_ratio = opt.batch_ratio.split("-")

    train_loader = Batch_Balanced_Dataset(
        opt, opt.train_data, select_data, batch_ratio, log, learn_type="self"
    )

    AlignCollate_valid = AlignCollate_SelfSL(opt)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt, data_type="unlabel"
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=False,
    )
    log.write(valid_dataset_log)
    print("-" * 80)
    log.write("-" * 80 + "\n")

    # ---- model configuration ----
    if opt.self == "RotNet":
        model = Model(opt, SelfSL_layer=opt.SelfSL_layer)

        # weight initialization
        for name, param in model.named_parameters():
            if "localization_fc2" in name:
                print(f"Skip {name} as it is already initialized")
                continue
            try:
                if "bias" in name:
                    init.constant_(param, 0.0)
                elif "weight" in name:
                    init.kaiming_normal_(param)
            except Exception:  # for batchnorm.
                # Deliberate best-effort: weights the initializer rejects
                # are filled with ones instead.
                if "weight" in name:
                    param.data.fill_(1)
                continue
    elif opt.self == "MoCo":
        model = MoCoLoss(
            opt, dim=opt.moco_dim, K=opt.moco_k, m=opt.moco_m, T=opt.moco_t
        )
    else:
        # Previously this fell through and failed later with a NameError on
        # `model`; fail early with an explicit message instead.
        raise ValueError(
            f"Unsupported self-supervised method opt.self={opt.self!r}; "
            'expected "RotNet" or "MoCo"'
        )

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != "":
        print(f"loading pretrained model from {opt.saved_model}")
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints.
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))

    print("Model:")
    print(model)
    log.write(repr(model) + "\n")

    # ---- setup loss ----
    criterion = torch.nn.CrossEntropyLoss(ignore_index=-1).to(device)

    # loss averager
    train_loss_avg = Averager()
    valid_loss_avg = Averager()

    # keep only the parameters that require gradient descent
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print(f"Trainable params num: {sum(params_num)}")
    log.write(f"Trainable params num: {sum(params_num)}\n")

    # ---- setup optimizer ----
    if opt.optimizer == "adam":
        optimizer = torch.optim.Adam(filtered_parameters, lr=opt.lr)
    elif opt.self == "MoCo":
        optimizer = torch.optim.SGD(
            filtered_parameters,
            lr=opt.moco_lr,
            momentum=opt.moco_SGD_m,
            weight_decay=opt.moco_wd,
        )
        # MoCo overrides the generic schedule settings with its own.
        opt.schedule = opt.moco_schedule
        opt.lr = opt.moco_lr
        opt.lr_drop_rate = opt.moco_lr_drop_rate
    else:
        optimizer = torch.optim.SGD(
            filtered_parameters,
            lr=opt.lr,
            momentum=opt.momentum,
            weight_decay=opt.weight_decay,
        )
    print("Optimizer:")
    print(optimizer)
    log.write(repr(optimizer) + "\n")

    use_one_cycle = "super" in opt.schedule
    if use_one_cycle:
        # momentum is cycled only for the "sgd" optimizer setting
        cycle_momentum = opt.optimizer == "sgd"
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=opt.lr,
            cycle_momentum=cycle_momentum,
            div_factor=20,
            final_div_factor=1000,
            total_steps=opt.num_iter,
        )
        print("Scheduler:")
        print(scheduler)
        log.write(repr(scheduler) + "\n")

    # ---- final options ----
    opt_log = "------------ Options -------------\n"
    for k, v in vars(opt).items():
        opt_log += f"{str(k)}: {str(v)}\n"
    opt_log += "---------------------------------------\n"
    print(opt_log)
    log.write(opt_log)
    log.close()

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != "":
        # Resume the iteration counter from names like ".../iter_3000.pth";
        # any other file name simply restarts from 0.
        try:
            start_iter = int(opt.saved_model.split("_")[-1].split(".")[0])
            print(f"continue to train, start_iter: {start_iter}")
        except ValueError:
            pass

    start_time = time.time()
    best_score = -1

    # training loop
    for iteration in tqdm(
        range(start_iter + 1, opt.num_iter + 1),
        total=opt.num_iter,
        initial=start_iter,  # keep the bar consistent with `total` on resume
        position=0,
        leave=True,
    ):
        # train part
        if opt.self == "RotNet":
            image, Self_label = train_loader.get_batch()
            image = image.to(device)
            preds = model(image, SelfSL_layer=opt.SelfSL_layer)
            target = torch.LongTensor(Self_label).to(device)
        elif opt.self == "MoCo":
            q, k = train_loader.get_batch_two_images()
            q = q.to(device)
            k = k.to(device)
            preds, target = model(im_q=q, im_k=k)

        loss = criterion(preds, target)
        train_loss_avg.add(loss)

        model.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), opt.grad_clip
        )  # gradient clipping with 5 (Default)
        optimizer.step()

        if use_one_cycle:
            scheduler.step()
        else:
            adjust_learning_rate(optimizer, iteration, opt)

        # validation part.
        # To see training progress, we also conduct validation when 'iteration == 1'
        if iteration % opt.val_interval == 0 or iteration == 1:
            # for validation log
            with open(f"./saved_models/{opt.exp_name}/log_train.txt", "a") as log:
                model.eval()
                with torch.no_grad():
                    length_of_data = 0
                    infer_time = 0
                    n_correct = 0
                    for i, (image_valid, Self_label_valid) in tqdm(
                        enumerate(valid_loader),
                        total=len(valid_loader),
                        position=1,
                        leave=False,
                    ):
                        if opt.self == "RotNet":
                            batch_size = image_valid.size(0)
                            start_infer_time = time.time()
                            preds = model(
                                image_valid.to(device), SelfSL_layer=opt.SelfSL_layer
                            )
                            forward_time = time.time() - start_infer_time
                            target = torch.LongTensor(Self_label_valid).to(device)
                        elif opt.self == "MoCo":
                            # for MoCo the second batch element is passed to
                            # the model as `im_k`
                            batch_size = image_valid.size(0)
                            q_valid = image_valid.to(device)
                            k_valid = Self_label_valid.to(device)
                            start_infer_time = time.time()
                            preds, target = model(im_q=q_valid, im_k=k_valid)
                            forward_time = time.time() - start_infer_time

                        loss = criterion(preds, target)
                        valid_loss_avg.add(loss)
                        infer_time += forward_time

                        _, preds_index = preds.max(1)
                        n_correct += (preds_index == target).sum().item()
                        length_of_data = length_of_data + batch_size

                    current_score = n_correct / length_of_data * 100
                model.train()

                # keep best score (accuracy) model on valid dataset
                if current_score > best_score:
                    best_score = current_score
                    torch.save(
                        model.state_dict(),
                        f"./saved_models/{opt.exp_name}/best_score.pth",
                    )

                # validation log: loss, lr, score, time.
                lr = optimizer.param_groups[0]["lr"]
                elapsed_time = time.time() - start_time
                valid_log = f"\n[{iteration}/{opt.num_iter}] Train loss: {train_loss_avg.val():0.5f}, Valid loss: {valid_loss_avg.val():0.5f}, lr: {lr:0.7f}\n"
                valid_log += f"Best_score: {best_score:0.2f}, Current_score: {current_score:0.2f}, "
                valid_log += (
                    f"Infer_time: {infer_time:0.1f}, Elapsed_time: {elapsed_time:0.1f}"
                )
                train_loss_avg.reset()
                valid_loss_avg.reset()

                # show some predicted results (taken from the last
                # validation batch)
                dashed_line = "-" * 80
                if opt.self == "RotNet":
                    head = "GT:0 vs Pred | GT:90 vs Pred | GT:180 vs Pred | GT:270 vs Pred"
                    preds_index = preds_index[:20]
                    gts = Self_label_valid[:20]
                elif opt.self == "MoCo":
                    head = "GT:0 vs Pred | GT:0 vs Pred | GT:0 vs Pred | GT:0 vs Pred"
                    preds_index = preds_index[:8]
                    gts = torch.zeros(preds_index.shape[0], dtype=torch.long)

                predicted_result_log = f"{dashed_line}\n{head}\n{dashed_line}\n"
                for i, (gt, pred) in enumerate(zip(gts, preds_index)):
                    if opt.self == "RotNet":
                        # class index -> rotation angle in degrees
                        gt, pred = gt * 90, pred * 90

                    # four results per output row
                    if i % 4 != 3:
                        predicted_result_log += f"{gt} vs {pred} | "
                    else:
                        predicted_result_log += f"{gt} vs {pred} \n"
                predicted_result_log += f"{dashed_line}"
                valid_log = f"{valid_log}\n{predicted_result_log}"
                print(valid_log)
                log.write(valid_log + "\n")

    print(
        f'finished the experiment: {opt.exp_name}, "CUDA_VISIBLE_DEVICES" was {opt.CUDA_VISIBLE_DEVICES}'
    )
def train(opt):
    """Train a model with a joint CTC + attention objective.

    The model returns a CTC prediction and an iterable of attention
    predictions; the loss is ``0.1 * CTC loss`` plus the sum of the attention
    cross-entropy losses.  Optional apex ``amp`` mixed precision is enabled
    with ``opt.fp16``.  Scalars are logged to a SummaryWriter in
    ``./saved_models/<opt.exp_name>``.

    Validation (attention head only) runs at iteration 0 and every
    ``opt.valInterval`` iterations; the weights are saved whenever accuracy
    or norm_ED exceeds the best value so far.  Weights are also saved every
    1e+5 iterations.  The process exits via ``sys.exit()`` after
    ``opt.num_iter`` iterations or as soon as the loss is NaN/Inf.

    Args:
        opt: option namespace.  ``opt.select_data`` and ``opt.batch_ratio``
            are replaced in place by their '-'-split lists, and
            ``opt.num_class_ctc`` / ``opt.num_class_atten`` (plus
            ``opt.input_channel`` when ``opt.rgb``) are set here.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # `with` guarantees the dataset log is closed even if building the
    # validation set raises.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    # CTCLoss
    converter_ctc = CTCLabelConverter(opt.character)
    # Attention
    converter_atten = AttnLabelConverter(opt.character)
    opt.num_class_ctc = len(converter_ctc.character)
    opt.num_class_atten = len(converter_atten.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class_ctc, opt.num_class_atten, opt.batch_max_length,
          opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling,
          opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # Deliberate best-effort: weights the initializer rejects are
            # filled with ones instead.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # keep only the parameters that require gradient descent
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # use fp16 to train
    model = model.to(device)
    if opt.fp16:
        with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
            log.write('==> Enable fp16 training' + '\n')
        print('==> Enable fp16 training')
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # data parallel for multi-GPU
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).to(device)
    model.train()

    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints.
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # ---- setup loss ----
    criterion_ctc = torch.nn.CTCLoss(zero_infinity=True).to(device)
    criterion_atten = torch.nn.CrossEntropyLoss(ignore_index=0).to(
        device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # ---- final options ----
    writer = SummaryWriter(f'./saved_models/{opt.exp_name}')
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        for k, v in vars(opt).items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        # Resume the iteration counter from names like ".../iter_3000.pth";
        # any other file name simply restarts from 0.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    # The loop only ends through sys.exit(); the `finally` makes sure pending
    # TensorBoard events are flushed and the writer is closed on every exit
    # path.
    try:
        while True:
            # train part
            image_tensors, labels = train_dataset.get_batch()
            image = image_tensors.to(device)
            batch_size = image.size(0)
            text_ctc, length_ctc = converter_ctc.encode(
                labels, batch_max_length=opt.batch_max_length)
            text_atten, length_atten = converter_atten.encode(
                labels, batch_max_length=opt.batch_max_length)

            # text_atten[:, :-1]: align with Attention.forward
            preds_ctc, preds_atten = model(image, text_atten[:, :-1])

            # CTC loss (weighted by 0.1)
            preds_size = torch.IntTensor([preds_ctc.size(1)] * batch_size)
            preds_ctc = preds_ctc.log_softmax(2).permute(1, 0, 2)
            cost_ctc = 0.1 * criterion_ctc(preds_ctc, text_ctc, preds_size,
                                           length_ctc)

            # Attention loss: summed over every attention prediction.
            target = text_atten[:, 1:]  # without [GO] Symbol
            flat_target = target.contiguous().view(-1)
            cost_atten = 0  # stays defined even if preds_atten is empty
            for pred in preds_atten:
                cost_atten = cost_atten + 1.0 * criterion_atten(
                    pred.view(-1, pred.shape[-1]), flat_target)

            cost = cost_ctc + cost_atten
            writer.add_scalar('loss', cost.item(), global_step=iteration + 1)

            # One status line, rewritten in place; every 100th iteration the
            # line is kept by ending it with a newline.
            line_end = '\n' if (iteration + 1) % 100 == 0 else ''
            print('\riter: {:4d}\tloss: {:6.3f}\tavg: {:6.3f}'.format(
                iteration + 1, cost.item(), loss_avg.val()),
                  end=line_end)
            sys.stdout.flush()
            if cost < 0.001:
                print(f'iter: {iteration + 1}\tloss: {cost}')

            optimizer.zero_grad()
            if torch.isnan(cost):
                print(f'iter: {iteration + 1}\tloss: {cost}\t==> Loss is NAN')
                sys.exit()
            elif torch.isinf(cost):
                print(f'iter: {iteration + 1}\tloss: {cost}\t==> Loss is INF')
                sys.exit()
            else:
                if opt.fp16:
                    with amp.scale_loss(cost, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    cost.backward()
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    opt.grad_clip)  # gradient clipping with 5 (Default)
                optimizer.step()

            loss_avg.add(cost)
            writer.add_scalar('loss_avg',
                              loss_avg.val(),
                              global_step=iteration + 1)

            # validation part
            # To see training progress, we also conduct validation when 'iteration == 0'
            if iteration == 0 or (iteration + 1) % opt.valInterval == 0:
                elapsed_time = time.time() - start_time
                # for log
                with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                          'a') as log:
                    model.eval()
                    with torch.no_grad():
                        (valid_loss, current_accuracy, current_norm_ED, preds,
                         confidence_score, labels, infer_time,
                         length_of_data) = validation(
                            model, criterion_atten, valid_loader,
                            converter_atten, opt)
                    model.train()
                    writer.add_scalar('accuracy',
                                      current_accuracy,
                                      global_step=iteration + 1)

                    # training loss and validation loss
                    loss_log = f'[{iteration + 1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                    loss_avg.reset()

                    current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                    # keep best accuracy model (on valid dataset)
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(
                            model.state_dict(),
                            f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                    if current_norm_ED > best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(
                            model.state_dict(),
                            f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                    best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                    loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                    print(loss_model_log)
                    log.write(loss_model_log + '\n')

                    # show some predicted results
                    dashed_line = '-' * 80
                    head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                    predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                    for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                    confidence_score[:5]):
                        # cut both strings at the end-of-sentence token
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                        predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                    predicted_result_log += f'{dashed_line}'
                    print(predicted_result_log)
                    log.write(predicted_result_log + '\n')

            # save model per 1e+5 iter.
            if (iteration + 1) % 1e+5 == 0:
                torch.save(
                    model.state_dict(),
                    f'./saved_models/{opt.exp_name}/iter_{iteration + 1}.pth')

            if (iteration + 1) == opt.num_iter:
                print('end the training')
                sys.exit()
            iteration += 1
    finally:
        writer.close()
def train(opt):
    """Train a CTC or attention text recognizer whose model owns its optimizer.

    The model is expected to provide ``load_pretrained_networks``,
    ``load_checkpoint``, ``optimize_parameters`` and ``save_checkpoints``;
    this function only computes the loss, back-propagates, clips gradients
    and asks the model to step.  Training and validation scalars are logged
    to a SummaryWriter in ``./saved_models/<opt.exp_name>``.

    Validation runs at iteration 0 and every ``opt.valInterval`` iterations;
    a checkpoint is saved whenever accuracy or norm_ED exceeds the best value
    so far, and every 1e+5 iterations.  The process exits via ``sys.exit()``
    after ``opt.num_iter`` iterations.

    Args:
        opt: option namespace.  ``opt.select_data`` and ``opt.batch_ratio``
            are replaced in place by their '-'-split lists, and
            ``opt.num_class`` (plus ``opt.input_channel`` when ``opt.rgb``)
            is set here.

    Raises:
        ValueError: if ``opt.saved_model`` is given but neither ``opt.FT``
            nor ``opt.continue_train`` is set.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    # train_dataset (image, label)
    train_dataset = Batch_Balanced_Dataset(opt)

    # `with` guarantees the dataset log is closed even if building the
    # validation set raises.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    use_ctc = 'CTC' in opt.Prediction
    if use_ctc:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    print("Model:")
    print(model)

    total_num, true_grad_num, false_grad_num = calculate_model_params(model)
    print("Total parameters: ", total_num)
    print("Number of parameters requires grad: ", true_grad_num)
    print("Number of parameters do not require grad: ", false_grad_num)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # Deliberate best-effort: weights the initializer rejects are
            # filled with ones instead.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    # Unwrap again so the model's own methods (load/save/optimize) can be
    # called directly below.
    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    # load pretrained model
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_pretrained_networks()
        elif opt.continue_train:
            model.load_checkpoint(opt.model_name)
        else:
            raise ValueError(
                'opt.saved_model is set, but neither opt.FT nor '
                'opt.continue_train was given; cannot decide how to load it')

    # ---- setup loss ----
    if use_ctc:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    log_dir = f'./saved_models/{opt.exp_name}'
    writer = SummaryWriter(log_dir)

    # ---- final options ----
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        for k, v in vars(opt).items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    # The loop only ends through sys.exit(); the `finally` makes sure pending
    # TensorBoard events are flushed and the writer is closed on every exit
    # path.
    try:
        while True:
            # train part
            image_tensors, labels = train_dataset.get_batch()
            image = image_tensors.to(device)
            text, length = converter.encode(
                labels, batch_max_length=opt.batch_max_length)
            batch_size = image.size(0)

            if use_ctc:
                preds = model(image, text)
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                if opt.baiduCTC:
                    preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                    cost = criterion(preds, text, preds_size,
                                     length) / batch_size
                else:
                    preds = preds.log_softmax(2).permute(1, 0, 2)
                    cost = criterion(preds, text, preds_size, length)
            else:
                preds = model(image,
                              text[:, :-1])  # align with Attention.forward
                target = text[:, 1:]  # without [GO] Symbol
                cost = criterion(preds.view(-1, preds.shape[-1]),
                                 target.contiguous().view(-1))

            model.zero_grad()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                opt.grad_clip)  # gradient clipping with 5 (Default)
            model.optimize_parameters()
            writer.add_scalar('train_loss', cost, iteration + 1)

            loss_avg.add(cost)

            # validation part
            # To see training progress, we also conduct validation when 'iteration == 0'
            if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
                elapsed_time = time.time() - start_time
                # for log
                with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                          'a') as log:
                    model.eval()
                    with torch.no_grad():
                        (valid_loss, current_accuracy, current_norm_ED, preds,
                         confidence_score, labels, infer_time,
                         length_of_data) = validation(
                            model, criterion, valid_loader, converter, opt,
                            iteration)
                    model.train()

                    # training loss and validation loss
                    loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                    writer.add_scalar('val_loss', valid_loss, iteration + 1)
                    writer.add_scalar('accuracy', current_accuracy,
                                      iteration + 1)
                    loss_avg.reset()

                    current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                    # keep best accuracy model (on valid dataset)
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        model.save_checkpoints(iteration, 'best_accuracy.pth')
                    if current_norm_ED > best_norm_ED:
                        best_norm_ED = current_norm_ED
                        model.save_checkpoints(iteration, 'best_norm_ED.pth')
                    best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                    loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                    print(loss_model_log)
                    log.write(loss_model_log + '\n')

                    # show some predicted results
                    dashed_line = '-' * 80
                    head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                    predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                    for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                    confidence_score[:5]):
                        if 'Attn' in opt.Prediction:
                            # cut both strings at the end-of-sentence token
                            gt = gt[:gt.find('[s]')]
                            pred = pred[:pred.find('[s]')]

                        predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                    predicted_result_log += f'{dashed_line}'
                    print(predicted_result_log)
                    log.write(predicted_result_log + '\n')

            # save model per 1e+5 iter.
            if (iteration + 1) % 1e+5 == 0:
                model.save_checkpoints(iteration + 1, opt.model_name)

            if (iteration + 1) == opt.num_iter:
                print('end the training')
                sys.exit()
            iteration += 1
    finally:
        writer.close()
def train(opt):
    """Train a VGG16-based image classifier and plot its validation metrics.

    The function prepares the batch-balanced training set and the validation
    loader from ``opt``, builds a pretrained VGG16 whose classifier is
    replaced by a new three-layer head, and optimizes it with Adam
    (learning rate 1e-4) and cross-entropy loss. Training progress is printed
    every 50 steps, a checkpoint is written every 500 steps under
    ``./saved_models/{opt.experiment_name}/`` (the directory is not created
    here), and the validation set is evaluated every ``opt.valInterval``
    steps. When training stops, the collected validation loss/accuracy are
    plotted, saved to ``accuracy_loss.jpg`` and the process exits.

    Args:
        opt: Option namespace. ``select_data`` and ``batch_ratio`` are
            replaced in place by their ``'-'``-split lists.

    Raises:
        SystemExit: always, via ``sys.exit()`` after the plot is produced.
    """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager guarantees the log is closed even if building the
    # validation set raises.
    with open('./saved_models/{}/log_dataset.txt'.format(opt.experiment_name), 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # Partial parameter initialization.
    learning_rate = 1e-4
    label2num, num2label = label_num('all_labels.txt')
    num_classes = len(label2num)
    # Runtime messages below: number of training classes / list of training labels.
    print('训练类别数:{}'.format(num_classes))
    print('训练集标签列表:\n{}'.format(num2label.values()))
    print('-' * 80)

    class VGGNet(nn.Module):
        """Pretrained VGG16 backbone followed by a new classification head."""

        def __init__(self, num_classes=num_classes):
            super().__init__()
            net = models.vgg16(pretrained=True)
            # Drop the stock classifier so the backbone returns its flattened
            # features, which the head below consumes.
            net.classifier = nn.Sequential()
            self.features = net
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 512),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(512, 128),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(128, num_classes),
            )

        def forward(self, x):
            x = self.features(x)
            x = x.view(x.size(0), -1)
            return self.classifier(x)

    # -------------------- training process --------------------
    model = VGGNet()
    # Keep the model on the same device the batches are moved to.
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    loss_func = nn.CrossEntropyLoss()

    def _as_targets(batch_labels):
        # int64 is what CrossEntropyLoss expects; the narrower int8 would wrap
        # around for class ids above 127.
        return torch.from_numpy(np.asarray(batch_labels, dtype=np.int64)).to(device)

    loss_history = []
    accuracy_history = []

    # NOTE(review): opt.saved_model is only used to recover the iteration
    # counter; its weights are not loaded anywhere in this function.
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print('continue to train, start_iter: {}'.format(start_iter))
        except ValueError:
            # File name does not end in "_<iteration>.<ext>": start from 0.
            pass

    i = start_iter
    while True:
        # ---------------- training step ----------------
        image_tensors, labels = train_dataset.get_batch()
        batch_x = image_tensors.to(device)
        batch_y = _as_targets(labels)

        out = model(batch_x)
        loss = loss_func(out, batch_y)
        pred = torch.max(out, 1)[1]
        train_correct = (pred == batch_y).sum().item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 50 == 0:
            print('Step{}:'.format(i + 1))
            # CrossEntropyLoss already averages over the batch, so the loss is
            # reported as-is; only the correct count needs normalizing.
            print('Train Loss: {:.6f}, Acc: {:.6f}'.format(
                loss.item(), train_correct / len(labels)))

        checkpoint_path = './saved_models/{}/iter_{}.pth'.format(opt.experiment_name, i + 1)

        # save model every 500 iterations.
        if (i + 1) % 500 == 0:
            torch.save(model.state_dict(), checkpoint_path)

        # ">=" (rather than "==") so that resuming from an iteration beyond
        # opt.num_iter still terminates instead of looping forever.
        if i >= opt.num_iter:
            torch.save(model.state_dict(), checkpoint_path)
            print('end the training')
            break
        i += 1

        # ---------------- evaluation ----------------
        if i % opt.valInterval == 0:
            model.eval()
            eval_loss = 0.
            eval_correct = 0
            length_of_data = 0
            with torch.no_grad():
                for image_tensors, labels in valid_loader:
                    batch_x = image_tensors.to(device)
                    batch_y = _as_targets(labels)
                    batch_len = len(labels)
                    length_of_data += batch_len

                    out = model(batch_x)
                    # Undo the per-batch mean so the total can be averaged
                    # over all validation samples.
                    eval_loss += loss_func(out, batch_y).item() * batch_len
                    pred = torch.max(out, 1)[1]
                    eval_correct += (pred == batch_y).sum().item()
            # Re-enable dropout for the following training steps.
            model.train()

            mean_loss = eval_loss / length_of_data
            mean_acc = eval_correct / length_of_data
            print('Test Loss: {:.6f}, Acc: {:.6f}'.format(mean_loss, mean_acc))
            loss_history.append(mean_loss)
            accuracy_history.append(100 * mean_acc)

    # One x position per evaluation actually performed.
    evaluations = np.arange(len(accuracy_history))
    plt.figure()
    plt.subplot(2, 1, 1)
    plt.plot(evaluations, accuracy_history, 'o-')
    plt.title('Test accuracy vs. epoches')
    plt.ylabel('Test accuracy')
    plt.subplot(2, 1, 2)
    plt.plot(evaluations, loss_history, '.-')
    plt.xlabel('Test loss vs. epoches')
    plt.ylabel('Test loss')
    # Save before show(): once a blocking show() returns, the figure may
    # already be closed and savefig would write an empty image.
    plt.savefig("accuracy_loss.jpg")
    plt.show()
    sys.exit()
def train(opt):
    """Run the training pipeline for the character recognition model.

    The function builds the batch-balanced training set and the validation
    loader, creates the model, the loss (CTC or cross-entropy, chosen by
    ``opt.Prediction``) and the optimizer (Adam or Adadelta, chosen by
    ``opt.adam``), then trains until ``opt.num_iter`` iterations are reached.
    Every ``opt.valInterval`` iterations, and when the iteration counter is 0,
    the model is validated; results are appended to ``log_train.txt`` and the
    weights with the best accuracy / best ``norm_ED`` are saved. All files are
    written under ``./saved_models/{opt.exp_name}/``, which is not created
    here.

    Args:
        opt: Option namespace. ``select_data`` and ``batch_ratio`` are
            replaced in place by their ``'-'``-split lists, ``num_class`` is
            set from the label converter, and ``input_channel`` is set to 3
            when ``opt.rgb`` is truthy.

    Raises:
        SystemExit: via ``sys.exit()`` once the final iteration has run; the
            function never returns normally.
    """
    if not opt.data_filtering_off:
        print("Filtering the images containing characters which are not in opt.character")
        print("Filtering the images whose label is longer than opt.batch_max_length")

    opt.select_data = opt.select_data.split("-")
    opt.batch_ratio = opt.batch_ratio.split("-")
    train_dataset = Batch_Balanced_Dataset(opt)

    # Log the validation-set summary so earlier runs can be compared later.
    # The context manager closes the file even if dataset creation raises.
    with open(f"./saved_models/{opt.exp_name}/log_dataset.txt", "a") as log:
        # Collate function for the validation loader, configured from user options.
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True,
        )
        log.write(valid_dataset_log)
        print("-" * 80)
        log.write("-" * 80 + "\n")

    # Label converter matching the prediction head (CTC or Attention).
    if "CTC" in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    # Run the model on RGB input when requested.
    if opt.rgb:
        opt.input_channel = 3

    model = Model(opt)
    print(
        "model input parameters",
        opt.imgH,
        opt.imgW,
        opt.num_fiducial,
        opt.input_channel,
        opt.output_channel,
        opt.hidden_size,
        opt.num_class,
        opt.batch_max_length,
        opt.Transformation,
        opt.FeatureExtraction,
        opt.SequenceModeling,
        opt.Prediction,
    )

    # weight initialization
    for name, param in model.named_parameters():
        if "localization_fc2" in name:
            print(f"Skip {name} as it is already initialized")
            continue
        try:
            if "bias" in name:
                init.constant_(param, 0.0)
            elif "weight" in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if "weight" in name:
                param.data.fill_(1)
            continue

    model.train()

    # Optionally start from weights saved by a previous run.
    if opt.saved_model != "":
        print(f"loading pretrained model from {opt.saved_model}")
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only host; fine-tuning (opt.FT) tolerates missing/extra keys.
        state_dict = torch.load(opt.saved_model, map_location=device)
        model.load_state_dict(state_dict, strict=not opt.FT)
    print("Model:")
    # print(model)

    # Move the model to the configured device (CPU or GPU).
    model.to(device)

    # Loss function matching the prediction head.
    if "CTC" in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        # ignore [GO] token = ignore index 0
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    # loss averager
    loss_avg = Averager()

    # Only parameters that require gradients are handed to the optimizer.
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    print("Trainable params num : ", sum(p.numel() for p in filtered_parameters))

    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # Record every option of this run next to its checkpoints.
    with open(f"./saved_models/{opt.exp_name}/opt.txt", "a") as opt_file:
        option_lines = "".join(f"{key}: {value}\n" for key, value in vars(opt).items())
        opt_log = (
            "------------ Options -------------\n"
            + option_lines
            + "---------------------------------------\n"
        )
        print(opt_log)
        opt_file.write(opt_log)

    # Resume the iteration counter from a checkpoint named "..._<iter>.<ext>".
    start_iter = 0
    if opt.saved_model != "":
        try:
            start_iter = int(opt.saved_model.split("_")[-1].split(".")[0])
            print(f"continue to train, start_iter: {start_iter}")
        except ValueError:
            # Name carries no iteration number (e.g. best_accuracy.pth).
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if "CTC" in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.log_softmax(2).permute(1, 0, 2)
            cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        # gradient clipping with 5 (Default)
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        # To see training progress, we also conduct validation when 'iteration == 0'
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f"./saved_models/{opt.exp_name}/log_train.txt", "a") as log:
                model.eval()
                with torch.no_grad():
                    (
                        valid_loss,
                        current_accuracy,
                        current_norm_ED,
                        valid_preds,
                        confidence_score,
                        valid_labels,
                        _infer_time,
                        _length_of_data,
                    ) = validation(model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f"[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}"
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f"./saved_models/{opt.exp_name}/best_accuracy.pth",
                    )
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f"./saved_models/{opt.exp_name}/best_norm_ED.pth",
                    )
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f"{loss_log}\n{current_model_log}\n{best_model_log}"
                print(loss_model_log)
                log.write(loss_model_log + "\n")

                # show some predicted results
                dashed_line = "-" * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f"{dashed_line}\n{head}\n{dashed_line}\n"
                for gt, pred, confidence in zip(valid_labels[:5], valid_preds[:5], confidence_score[:5]):
                    if "Attn" in opt.Prediction:
                        # Keep the text before the "[s]" marker. partition()
                        # leaves the string untouched when the marker is
                        # missing, whereas slicing with find() == -1 would
                        # silently drop the last character.
                        gt = gt.partition("[s]")[0]
                        pred = pred.partition("[s]")[0]

                    predicted_result_log += f"{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n"
                predicted_result_log += f"{dashed_line}"
                print(predicted_result_log)
                log.write(predicted_result_log + "\n")

        # save model per 1e+5 iter.
        if (iteration + 1) % 100000 == 0:
            torch.save(
                model.state_dict(),
                f"./saved_models/{opt.exp_name}/iter_{iteration+1}.pth",
            )

        # ">=" (rather than "==") so that resuming from a checkpoint at or
        # beyond opt.num_iter still terminates instead of training forever.
        if (iteration + 1) >= opt.num_iter:
            print("end the training")
            sys.exit()
        iteration += 1