def train(opt):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt, mode='Train')

    log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt, mode='Val')
    # valid_dataset = RawDataset(root=opt.valid_data, opt=opt)  # use RawDataset
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        # shuffle=True,  # 'True' to check training progress with validation function.
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    """ model configuration """
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # filter the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except Exception:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)
        if iteration % 20 == 0:
            print(f'#{iteration}: loss:{cost}', flush=True)

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:  # To see training progress, we also conduct validation when 'iteration == 0'
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
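# The CTC branch above feeds criterion(preds, text, preds_size, length) with
# log-probs permuted to (T, N, C). Below is a minimal, self-contained sketch
# of those shape conventions with torch.nn.CTCLoss (dummy sizes; not part of
# the training pipeline above):
def _ctc_shape_sketch():
    import torch
    T, N, C = 26, 4, 37  # time steps, batch size, num classes (incl. blank at index 0)
    log_probs = torch.randn(N, T, C).log_softmax(2).permute(1, 0, 2)  # -> (T, N, C)
    targets = torch.randint(1, C, (N, 10), dtype=torch.long)          # padded label indices
    input_lengths = torch.IntTensor([T] * N)                          # analogue of preds_size
    target_lengths = torch.IntTensor([10] * N)                        # analogue of length
    criterion = torch.nn.CTCLoss(blank=0, zero_infinity=True)
    return criterion(log_probs, targets, input_lengths, target_lengths)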
def train(opt):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # filter the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except Exception:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    best_loss = float("inf")
    i = start_iter

    if opt.early_stopping_param == 'accuracy':
        early_stopping_param = opt.early_stopping_param
    elif opt.early_stopping_param == 'loss':
        early_stopping_param = opt.early_stopping_param
    elif opt.early_stopping_param == 'None':
        early_stopping_param = None
    else:
        raise Exception('early_stopping_param is none of accuracy, loss, or None')
    current_patience = opt.early_stopping_patience

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)

            # (ctc_a) For PyTorch 1.2.0 and 1.3.0. To avoid the ctc_loss issue, disable cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            torch.backends.cudnn.enabled = False
            cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
            torch.backends.cudnn.enabled = True

            # # (ctc_b) To reproduce our pretrained model / paper, use our previous code (below code) instead of (ctc_a).
            # # With PyTorch 1.2.0, the below code produces NaN, so you may use PyTorch 1.1.0.
            # # Thus, the result of CTCLoss is different in PyTorch 1.1.0 and PyTorch 1.2.0.
            # # See https://github.com/clovaai/deep-text-recognition-benchmark/issues/56#issuecomment-526490707
            # cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i > (opt.ignore_x_vals * opt.valInterval) and i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                    if early_stopping_param == 'accuracy':
                        current_patience = opt.early_stopping_patience
                elif early_stopping_param == 'accuracy':
                    # if it enters here, the accuracy didn't improve
                    current_patience -= 1

                if valid_loss < best_loss:
                    best_loss = valid_loss
                    if early_stopping_param == 'loss':
                        current_patience = opt.early_stopping_patience
                elif early_stopping_param == 'loss':
                    # if it enters here, the loss didn't improve
                    current_patience -= 1

                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                if early_stopping_param is not None:
                    loss_model_log += f'\n{"Early Stopping Param":17s}: {early_stopping_param}, {"Patience"}: {current_patience}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

                if current_patience == 0:
                    print("Stopped the training due to early stopping patience")
                    sys.exit()

        # save model per 1e+5 iter.
        if (i + 1) % 1e+5 == 0:
            torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
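# The early-stopping bookkeeping above reduces to a simple rule: reset the
# patience counter whenever the watched metric improves, otherwise decrement
# it, and stop once it reaches zero. A minimal standalone sketch of that rule
# (hypothetical helper, not part of the source above):
def _patience_step(improved, current_patience, initial_patience):
    """Return the updated patience counter after one validation round."""
    return initial_patience if improved else current_patience - 1

# e.g. _patience_step(False, 3, 5) -> 2; _patience_step(True, 1, 5) -> 5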
def train(opt):
    """ training pipeline for our character recognition model """
    if not opt.data_filtering_off:
        print("Filtering the images containing characters which are not in opt.character")
        print("Filtering the images whose label is longer than opt.batch_max_length")

    opt.select_data = opt.select_data.split("-")
    opt.batch_ratio = opt.batch_ratio.split("-")
    train_dataset = Batch_Balanced_Dataset(opt)

    # Logging the experiment, so that we can refer to the performance of previous runs
    log = open(f"./saved_models/{opt.exp_name}/log_dataset.txt", "a")

    # Passing params from user input to the collation function for the dataloader
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)

    # Defining our validation dataloader
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True,
    )
    log.write(valid_dataset_log)
    print("-" * 80)
    log.write("-" * 80 + "\n")
    log.close()

    # Using either CTC or Attention for char predictions
    if "CTC" in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    # Running our OCR model in grayscale or RGB
    if opt.rgb:
        opt.input_channel = 3

    # Defining our model using user inputs
    model = Model(opt)
    print(
        "model input parameters",
        opt.imgH, opt.imgW, opt.num_fiducial,
        opt.input_channel, opt.output_channel, opt.hidden_size,
        opt.num_class, opt.batch_max_length,
        opt.Transformation, opt.FeatureExtraction,
        opt.SequenceModeling, opt.Prediction,
    )

    # weight initialization
    for name, param in model.named_parameters():
        if "localization_fc2" in name:
            print(f"Skip {name} as it is already initialized")
            continue
        try:
            if "bias" in name:
                init.constant_(param, 0.0)
            elif "weight" in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if "weight" in name:
                param.data.fill_(1)
            continue

    # Putting model in training mode
    model.train()

    # Fine-tuning from a saved model of a previous run
    if opt.saved_model != "":
        print(f"loading pretrained model from {opt.saved_model}")
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    # print(model)

    # Sending model to cpu or gpu, depending upon availability
    model.to(device)

    # Setting up loss functions in the case of either CTC or Attention
    if "CTC" in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    # loss averager
    loss_avg = Averager()

    # filter the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print("Trainable params num : ", sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # Setup of optimizer to be used
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # print(opt)
    with open(f"./saved_models/{opt.exp_name}/opt.txt", "a") as opt_file:
        opt_log = "------------ Options -------------\n"
        args = vars(opt)
        for k, v in args.items():
            opt_log += f"{str(k)}: {str(v)}\n"
        opt_log += "---------------------------------------\n"
        print(opt_log)
        opt_file.write(opt_log)

    # Training iteration starts here
    start_iter = 0
    if opt.saved_model != "":
        try:
            start_iter = int(opt.saved_model.split("_")[-1].split(".")[0])
            print(f"continue to train, start_iter: {start_iter}")
        except Exception:
            pass

    # Setting up initial metric results and initializing the timer
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if "CTC" in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.log_softmax(2).permute(1, 0, 2)
            cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:  # To see training progress, we also conduct validation when 'iteration == 0'
            elapsed_time = time.time() - start_time
            # for log
            with open(f"./saved_models/{opt.exp_name}/log_train.txt", "a") as log:
                model.eval()
                with torch.no_grad():
                    (
                        valid_loss,
                        current_accuracy,
                        current_norm_ED,
                        preds,
                        confidence_score,
                        labels,
                        infer_time,
                        length_of_data,
                    ) = validation(model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f"[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}"
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f"./saved_models/{opt.exp_name}/best_accuracy.pth")
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f"./saved_models/{opt.exp_name}/best_norm_ED.pth")
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f"{loss_log}\n{current_model_log}\n{best_model_log}"
                print(loss_model_log)
                log.write(loss_model_log + "\n")

                # show some predicted results
                dashed_line = "-" * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f"{dashed_line}\n{head}\n{dashed_line}\n"
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if "Attn" in opt.Prediction:
                        gt = gt[:gt.find("[s]")]
                        pred = pred[:pred.find("[s]")]
                    predicted_result_log += f"{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n"
                predicted_result_log += f"{dashed_line}"
                print(predicted_result_log)
                log.write(predicted_result_log + "\n")

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e5 == 0:
            torch.save(model.state_dict(), f"./saved_models/{opt.exp_name}/iter_{iteration+1}.pth")

        if (iteration + 1) == opt.num_iter:
            print("end the training")
            sys.exit()
        iteration += 1
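# All three training loops above resume by parsing the iteration count out of
# a checkpoint filename such as 'iter_300000.pth'. A sketch of that naming
# convention (hypothetical helper mirroring the split('_')/split('.') logic):
def _parse_start_iter(saved_model_path):
    try:
        return int(saved_model_path.split("_")[-1].split(".")[0])
    except ValueError:
        return 0  # e.g. 'best_accuracy.pth' carries no iteration number

# _parse_start_iter('./saved_models/exp/iter_300000.pth') -> 300000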
def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=False):
    """ evaluation with 10 benchmark evaluation datasets """
    # The evaluation datasets; the dataset order is the same as Table 1 in our paper.
    eval_data_list = [
        'IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', 'IC13_1015',
        'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80'
    ]
    # # To easily compute the total accuracy of our paper.
    # eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_867',
    #                   'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80']

    if calculate_infer_time:
        evaluation_batch_size = 1  # batch_size should be 1 to calculate the GPU inference time per image.
    else:
        evaluation_batch_size = opt.batch_size

    list_accuracy = []
    total_forward_time = 0
    total_evaluation_data_number = 0
    total_correct_number = 0
    log = open(f'./result/{opt.exp_name}/log_all_evaluation.txt', 'a')
    dashed_line = '-' * 80
    print(dashed_line)
    log.write(dashed_line + '\n')
    for eval_data in eval_data_list:
        eval_data_path = os.path.join(opt.eval_data, eval_data)
        AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
        eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt)
        evaluation_loader = torch.utils.data.DataLoader(
            eval_data, batch_size=evaluation_batch_size,
            shuffle=False,
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_evaluation, pin_memory=True)

        _, accuracy_by_best_model, norm_ED_by_best_model, _, _, _, infer_time, length_of_data = validation(
            model, criterion, evaluation_loader, converter, opt)
        list_accuracy.append(f'{accuracy_by_best_model:0.3f}')
        total_forward_time += infer_time
        total_evaluation_data_number += len(eval_data)
        total_correct_number += accuracy_by_best_model * length_of_data
        log.write(eval_data_log)
        print(f'Acc {accuracy_by_best_model:0.3f}\t normalized_ED {norm_ED_by_best_model:0.3f}')
        log.write(f'Acc {accuracy_by_best_model:0.3f}\t normalized_ED {norm_ED_by_best_model:0.3f}\n')
        print(dashed_line)
        log.write(dashed_line + '\n')

    averaged_forward_time = total_forward_time / total_evaluation_data_number * 1000
    total_accuracy = total_correct_number / total_evaluation_data_number
    params_num = sum([np.prod(p.size()) for p in model.parameters()])

    evaluation_log = 'accuracy: '
    for name, accuracy in zip(eval_data_list, list_accuracy):
        evaluation_log += f'{name}: {accuracy}\t'
    evaluation_log += f'total_accuracy: {total_accuracy:0.3f}\t'
    evaluation_log += f'averaged_infer_time: {averaged_forward_time:0.3f}\t# parameters: {params_num/1e6:0.3f}'
    print(evaluation_log)
    log.write(evaluation_log + '\n')
    log.close()

    return None
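# benchmark_all_eval weights each dataset's accuracy by its size before
# averaging: total_correct_number accumulates accuracy * length_of_data and is
# then divided by the total sample count. A small sketch of that aggregation
# with made-up numbers (hypothetical values, not real benchmark results):
def _weighted_accuracy_sketch():
    per_dataset = [(87.5, 3000), (86.9, 647)]  # (accuracy %, num samples)
    total_correct = sum(acc * n for acc, n in per_dataset)
    total_samples = sum(n for _, n in per_dataset)
    return total_correct / total_samples  # size-weighted overall accuracy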
def detect_ocr(config, image, timestamp, save_img):
    detection_list, img, boxes = Detection_txt(config, image, config.net)
    # print(detection_list)

    t = time.time()
    device = config.device
    model = config.model
    converter = config.converter

    # 32 * 100
    AlignCollate_demo = AlignCollate(imgH=config.imgH, imgW=config.imgW, keep_ratio_with_pad=config.PAD)
    # demo_data = RawDataset(root=image, opt=config)  # use RawDataset
    demo_data = RawDataset_wPosition(root=detection_list, opt=config)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=config.batch_size,
        shuffle=False,
        num_workers=int(config.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)
    # each item: (<PIL.Image.Image image mode=L size=398x120>, './demo_image/demo_12.png')

    # predict
    model.eval()
    with torch.no_grad():
        log = open(f'{config.logfilepath}', 'a')
        dashed_line = '-' * 80
        head = f'{"coordinates":25s}\t{"predicted_labels":25s}\tconfidence score'

        if save_img:
            print(f'{dashed_line}\n{head}\n{dashed_line}')
        log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

        pred_list = []
        new_boxes = []
        for image_tensors, coordinate_list in demo_loader:
            batch_size = image_tensors.size(0)
            # print(image_tensors.shape)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([config.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, config.batch_max_length + 1).fill_(0).to(device)

            preds = model(image, text_for_pred, is_train=False)

            # select max probability (greedy decoding), then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for coordinate, pred, pred_max_prob in zip(coordinate_list, preds_str, preds_max_prob):
                pred_EOS = pred.find('[s]')
                pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                if pred_EOS == 0:
                    confidence_score = 0.0
                else:
                    pred_max_prob = pred_max_prob[:pred_EOS]
                    # calculate confidence score (= product of pred_max_prob)
                    confidence_score = pred_max_prob.cumprod(dim=0)[-1].item()

                coordinate = list(coordinate)
                pred_list.append([coordinate, pred, confidence_score])

                if save_img:
                    print(f'{coordinate}\t{pred:25s}\t{confidence_score:0.4f}')
                log.write(f'{coordinate}\t{pred:25s}\t{confidence_score:0.4f}\n')

        log.close()

    recog_time = time.time() - t
    config.recog_time = config.recog_time + recog_time
    # print("\nrun time (recognition) : {:.2f} , {:.2f} s".format(recog_time, config.recog_time))

    if save_img:
        saveResult(img, boxes, pred_list, config.result_folder, config.res_imagefileName)

    return pred_list, timestamp
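# detect_ocr scores each prediction by multiplying the per-timestep max
# probabilities up to the '[s]' token via cumprod(dim=0)[-1]. A minimal sketch
# of that confidence computation (dummy tensor, not real model output):
def _confidence_sketch():
    import torch
    pred_max_prob = torch.tensor([0.99, 0.95, 0.90])  # per-character max softmax probs
    return pred_max_prob.cumprod(dim=0)[-1].item()    # 0.99 * 0.95 * 0.90 ~= 0.846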
def test(opt):
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    # model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))
    opt.experiment_name = '_'.join(opt.saved_model.split('/')[1:])
    # print(model)

    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        model.load_state_dict(torch.load(opt.continue_model))

    """ keep evaluation model and result logs """
    os.makedirs(f'./result/{opt.experiment_name}', exist_ok=True)
    os.system(f'cp {opt.saved_model} ./result/{opt.experiment_name}/')

    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    """ evaluation """
    model.eval()
    with torch.no_grad():
        if opt.benchmark_all_eval:  # evaluation with 10 benchmark evaluation datasets
            benchmark_all_eval(model, criterion, converter, opt)
        else:
            log = open(f'./result/{opt.experiment_name}/log_evaluation.txt', 'a')
            AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
            eval_data, eval_data_log = hierarchical_dataset(root=opt.eval_data, opt=opt)
            evaluation_loader = torch.utils.data.DataLoader(
                eval_data, batch_size=opt.batch_size,
                shuffle=False,
                num_workers=int(opt.workers),
                collate_fn=AlignCollate_evaluation, pin_memory=True)
            _, accuracy_by_best_model, _, _, _, _, _, _ = validation(
                model, criterion, evaluation_loader, converter, opt)
            log.write(eval_data_log)
            print(f'{accuracy_by_best_model:0.3f}')
            log.write(f'{accuracy_by_best_model:0.3f}\n')
            log.close()
def demo(opt):
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    with torch.no_grad():
        temp = []
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred).log_softmax(2)
                # Select max probability (greedy decoding), then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.permute(1, 0, 2).max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = model(image, text_for_pred, is_train=False)
                # select max probability (greedy decoding), then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            print('-' * 80)
            print('image_path\tpredicted_labels')
            print('-' * 80)
            for img_name, pred in zip(image_path_list, preds_str):
                if 'Attn' in opt.Prediction:
                    pred = pred[:pred.find('[s]')]  # prune after "end of sentence" token ([s])
                temp.append(pred)
                print(f'{img_name}\t{pred}')
        return temp
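# The CTC branch of demo() decodes greedily: take the argmax class per time
# step, then the converter collapses repeated indices and drops the blank.
# A standalone sketch of that collapse rule (blank assumed to be index 0, as
# in CTCLabelConverter; not the converter's actual implementation):
def _greedy_ctc_collapse(preds_index):
    """preds_index: 1-D sequence of per-timestep argmax indices."""
    decoded, prev = [], None
    for idx in preds_index:
        if idx != prev and idx != 0:  # skip repeats and the blank
            decoded.append(idx)
        prev = idx
    return decoded

# _greedy_ctc_collapse([0, 5, 5, 0, 3, 3, 3, 0]) -> [5, 3]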
def _textRecognition(opt):
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    char_list = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
    csv_filename = os.path.join(opt.output_dirpath, "text_information.csv")
    Header = ["bookID", "prediction"]
    with open(csv_filename, 'w') as f:
        writer = csv.DictWriter(f, fieldnames=Header)
        writer.writeheader()

        model.eval()
        with torch.no_grad():
            for image_tensors, image_path_list in demo_loader:
                batch_size = image_tensors.size(0)
                image = image_tensors.to(device)
                # For max length prediction
                length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
                text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

                if 'CTC' in opt.Prediction:
                    preds = model(image, text_for_pred)
                    # Select max probability (greedy decoding), then decode index to character
                    preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                    _, preds_index = preds.max(2)
                    preds_index = preds_index.view(-1)
                    preds_str = converter.decode(preds_index.data, preds_size.data)
                else:
                    preds = model(image, text_for_pred, is_train=False)
                    # select max probability (greedy decoding), then decode index to character
                    _, preds_index = preds.max(2)
                    preds_str = converter.decode(preds_index, length_for_pred)

                # log = open(f'./text_information.csv', 'a')
                dashed_line = '-' * 80
                head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'
                print(f'{dashed_line}\n{head}\n{dashed_line}')

                preds_prob = F.softmax(preds, dim=2)
                preds_max_prob, _ = preds_prob.max(dim=2)
                for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                    if 'Attn' in opt.Prediction:
                        pred_EOS = pred.find('[s]')
                        pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                        pred_max_prob = pred_max_prob[:pred_EOS]

                    # calculate confidence score (= product of pred_max_prob)
                    # print("{}:{}".format(pred_max_prob, len(pred_max_prob)))
                    try:
                        confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                    except Exception:
                        deleteImageAndText(opt.book_img_dirpath, img_name)
                        continue
                    # confidence_score = pred_max_prob.cumprod(dim=0)[-1]

                    pred = pred[0]
                    if confidence_score < 0.5:
                        pred = "Unreadable"
                        deleteImageAndText(opt.book_img_dirpath, img_name)
                        continue
                    elif not (pred in char_list):
                        pred = "Undefined"

                    # extract the name part of the image
                    filename = os.path.basename(img_name)
                    print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
                    writer.writerow({"bookID": filename, "prediction": pred})
def demo(opt):
    inputimage = opt.input_image
    boxescsv = opt.boxescsv
    bboxes = parse_csv(inputimage, boxescsv)

    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)
                # Select max probability (greedy decoding), then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                # preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index, preds_size)
            else:
                preds = model(image, text_for_pred, is_train=False)
                # select max probability (greedy decoding), then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            log = open(f'{opt.output_folder}result.csv', 'w')
            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_index, (pred, pred_max_prob) in enumerate(zip(preds_str, preds_max_prob)):
                if 'Attn' in opt.Prediction:
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_max_prob = pred_max_prob[:pred_EOS]

                # calculate confidence score (= product of pred_max_prob)
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]

                for pts in bboxes[img_index]:
                    x, y = pts
                    log.write(f'{x},{y},')
                log.write(f'{pred}\n')
            log.close()

    # copy log to local output folder
    os.system(f'cp {opt.output_folder}result.csv /input/output')
    shutil.make_archive('per_word_visual', 'zip', '/input/output')
def demo(opt):
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    acc = 0
    avgconf = [0.0, 0.0]  # correct, wrong
    sample_num = 0
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)
                # Select max probability (greedy decoding), then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = model(image, text_for_pred, is_train=False)
                # select max probability (greedy decoding), then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            log = open(f'./log_demo_result.txt', 'a')
            dashed_line = '-' * 80
            head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'
            print(f'{dashed_line}\n{head}\n{dashed_line}')
            log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                try:
                    if 'Attn' in opt.Prediction:
                        pred_EOS = pred.find('[s]')
                        pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                        pred_max_prob = pred_max_prob[:pred_EOS]

                    # calculate confidence score (= product of pred_max_prob)
                    confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                    print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
                    log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')

                    if os.path.split(img_name)[-1].split("_")[0] == pred:
                        acc += 1.0
                        avgconf[0] += confidence_score
                    else:
                        avgconf[1] += confidence_score
                except Exception:
                    avgconf[1] += confidence_score
                    print("!!!!!", img_name)
            log.close()
            sample_num += len(image_path_list)

    acc /= sample_num
    avgconf[0] /= sample_num
    avgconf[1] /= sample_num
    print(f"Sample num:{sample_num}\nAccuracy:{acc}\nCorrect samples Average Confidence:{avgconf[0]}\nWrong samples Average Confidence:{avgconf[1]}")
def train(opt):
    plotDir = os.path.join(opt.exp_dir, opt.exp_name, 'plots')
    if not os.path.exists(plotDir):
        os.makedirs(plotDir)

    lib.print_model_settings(locals().copy())

    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')

    # considering the real images for discriminator
    opt.batch_size = opt.batch_size * 2

    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(os.path.join(opt.exp_dir, opt.exp_name, 'log_dataset.txt'), 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=False,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3

    model = AdaINGenV4(opt)
    ocrModel = Model(opt)
    disModel = MsImageDisV1(opt)

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for currModel in [model, ocrModel, disModel]:
        for name, param in currModel.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception:  # for batchnorm.
                if 'weight' in name:
                    param.data.fill_(1)
                continue

    # data parallel for multi-GPU
    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    if not opt.ocrFixed:
        ocrModel.train()
    else:
        ocrModel.module.Transformation.eval()
        ocrModel.module.FeatureExtraction.eval()
        ocrModel.module.AdaptiveAvgPool.eval()
        # ocrModel.module.SequenceModeling.eval()
        ocrModel.module.Prediction.eval()

    model = torch.nn.DataParallel(model).to(device)
    model.train()

    disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()

    if opt.modelFolderFlag:
        if len(glob.glob(os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth"))) > 0:
            opt.saved_synth_model = glob.glob(os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth"))[-1]
        if len(glob.glob(os.path.join(opt.exp_dir, opt.exp_name, "iter_*_dis.pth"))) > 0:
            opt.saved_dis_model = glob.glob(os.path.join(opt.exp_dir, opt.exp_name, "iter_*_dis.pth"))[-1]

    # loading pre-trained model
    if opt.saved_ocr_model != '' and opt.saved_ocr_model != 'None':
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        if opt.FT:
            ocrModel.load_state_dict(torch.load(opt.saved_ocr_model), strict=False)
        else:
            ocrModel.load_state_dict(torch.load(opt.saved_ocr_model))
    print("OCRModel:")
    print(ocrModel)

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_synth_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_synth_model))
    print("SynthModel:")
    print(model)

    if opt.saved_dis_model != '' and opt.saved_dis_model != 'None':
        print(f'loading pretrained discriminator model from {opt.saved_dis_model}')
        if opt.FT:
            disModel.load_state_dict(torch.load(opt.saved_dis_model), strict=False)
        else:
            disModel.load_state_dict(torch.load(opt.saved_dis_model))
    print("DisModel:")
    print(disModel)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    recCriterion = torch.nn.L1Loss()
    styleRecCriterion = torch.nn.L1Loss()

    # loss averagers
    loss_avg_ocr = Averager()
    loss_avg = Averager()
    loss_avg_dis = Averager()

    loss_avg_ocrRecon_1 = Averager()
    loss_avg_ocrRecon_2 = Averager()
    loss_avg_gen = Averager()
    loss_avg_imgRecon = Averager()
    loss_avg_styRecon = Averager()

    ##---------------------------------------##
    # filter the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.optim == 'adam':
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay)
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay)
    print("SynthOptimizer:")
    print(optimizer)

    # filter parameters for OCR training
    ocr_filtered_parameters = []
    ocr_params_num = []
    for p in filter(lambda p: p.requires_grad, ocrModel.parameters()):
        ocr_filtered_parameters.append(p)
        ocr_params_num.append(np.prod(p.size()))
    print('OCR Trainable params num : ', sum(ocr_params_num))

    # setup optimizer
    if opt.optim == 'adam':
        ocr_optimizer = optim.Adam(ocr_filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay)
    else:
        ocr_optimizer = optim.Adadelta(ocr_filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay)
    print("OCROptimizer:")
    print(ocr_optimizer)

    # filter parameters for discriminator training
    dis_filtered_parameters = []
    dis_params_num = []
    for p in filter(lambda p: p.requires_grad, disModel.parameters()):
        dis_filtered_parameters.append(p)
        dis_params_num.append(np.prod(p.size()))
    print('Dis Trainable params num : ', sum(dis_params_num))

    # setup optimizer
    if opt.optim == 'adam':
        dis_optimizer = optim.Adam(dis_filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay)
    else:
        dis_optimizer = optim.Adadelta(dis_filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay)
    print("DisOptimizer:")
    print(dis_optimizer)
    ##---------------------------------------##

    """ final options """
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        try:
            start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except Exception:
            pass

    # get schedulers
    scheduler = get_scheduler(optimizer, opt)
    ocr_scheduler = get_scheduler(ocr_optimizer, opt)
    dis_scheduler = get_scheduler(dis_optimizer, opt)

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    best_accuracy_ocr = -1
    best_norm_ED_ocr = -1
    iteration = start_iter
    cntr = 0

    while True:
        # train part
        if opt.lr_policy != "None":
            scheduler.step()
            ocr_scheduler.step()
            dis_scheduler.step()

        image_tensors_all, labels_1_all, labels_2_all = train_dataset.get_batch()

        # ## comment
        # pdb.set_trace()
        # for imgCntr in range(image_tensors.shape[0]):
        #     save_image(tensor2im(image_tensors[imgCntr]), 'temp/' + str(imgCntr) + '.png')
        # pdb.set_trace()
        # ###
        # print(cntr)

        cntr += 1
        disCnt = int(image_tensors_all.size(0) / 2)
        image_tensors = image_tensors_all[:disCnt]
        image_tensors_real = image_tensors_all[disCnt:disCnt + disCnt]
        labels_gt = labels_1_all[:disCnt]
        labels_2 = labels_2_all[:disCnt]

        image = image_tensors.to(device)
        image_real = image_tensors_real.to(device)
        batch_size = image.size(0)

        ##-----------------------------------##
        # generate text (labels) from ocr.forward
        if opt.ocrFixed:
            # ocrModel.eval()
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = ocrModel(image, text_for_pred)
                preds = preds[:, :text_for_pred.shape[1] - 1, :]  # keep at most batch_max_length time steps
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                labels_1 = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = ocrModel(image, text_for_pred, is_train=False)
                _, preds_index = preds.max(2)
                labels_1 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(labels_1):
                    pred_EOS = pred.find('[s]')
                    labels_1[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
            # ocrModel.train()
        else:
            labels_1 = labels_gt
        ##-----------------------------------##

        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

        # forward pass through the style and word generator
        images_recon_1, images_recon_2, style = model(image, text_1, text_2)

        if 'CTC' in opt.Prediction:
            if not opt.ocrFixed:
                # ocr training with orig image
                preds_ocr = ocrModel(image, text_1)
                preds_size_ocr = torch.IntTensor([preds_ocr.size(1)] * batch_size)
                preds_ocr = preds_ocr.log_softmax(2).permute(1, 0, 2)
                ocrCost_train = ocrCriterion(preds_ocr, text_1, preds_size_ocr, length_1)

            # content loss for reconstructed images
            preds_1 = ocrModel(images_recon_1, text_1)
            preds_size_1 = torch.IntTensor([preds_1.size(1)] * batch_size)
            preds_1 = preds_1.log_softmax(2).permute(1, 0, 2)

            preds_2 = ocrModel(images_recon_2, text_2)
            preds_size_2 = torch.IntTensor([preds_2.size(1)] * batch_size)
            preds_2 = preds_2.log_softmax(2).permute(1, 0, 2)

            ocrCost_1 = ocrCriterion(preds_1, text_1, preds_size_1, length_1)
            ocrCost_2 = ocrCriterion(preds_2, text_2, preds_size_2, length_2)
            # ocrCost = 0.5 * (ocrCost_1 + ocrCost_2)
        else:
            if not opt.ocrFixed:
                # ocr training with orig image
                preds_ocr = ocrModel(image, text_1[:, :-1])  # align with Attention.forward
                target_ocr = text_1[:, 1:]  # without [GO] Symbol
                ocrCost_train = ocrCriterion(preds_ocr.view(-1, preds_ocr.shape[-1]), target_ocr.contiguous().view(-1))

            # content loss for reconstructed images
            preds_1 = ocrModel(images_recon_1, text_1[:, :-1], is_train=False)  # align with Attention.forward
            target_1 = text_1[:, 1:]  # without [GO] Symbol

            preds_2 = ocrModel(images_recon_2, text_2[:, :-1], is_train=False)  # align with Attention.forward
            target_2 = text_2[:, 1:]  # without [GO] Symbol

            ocrCost_1 = ocrCriterion(preds_1.view(-1, preds_1.shape[-1]), target_1.contiguous().view(-1))
            ocrCost_2 = ocrCriterion(preds_2.view(-1, preds_2.shape[-1]), target_2.contiguous().view(-1))
            # ocrCost = 0.5 * (ocrCost_1 + ocrCost_2)

        if not opt.ocrFixed:
            # training OCR
            ocrModel.zero_grad()
            ocrCost_train.backward()
            # torch.nn.utils.clip_grad_norm_(ocrModel.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
            ocr_optimizer.step()
            loss_avg_ocr.add(ocrCost_train)
        else:
            # if ocr is fixed, ignore this loss
            loss_avg_ocr.add(torch.tensor(0.0))

        # Domain discriminator: Dis update
        disModel.zero_grad()
        disCost = opt.disWeight * 0.5 * (
            disModel.module.calc_dis_loss(images_recon_1.detach(), image_real) +
            disModel.module.calc_dis_loss(images_recon_2.detach(), image))
        disCost.backward()
        # torch.nn.utils.clip_grad_norm_(disModel.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        dis_optimizer.step()
        loss_avg_dis.add(disCost)

        # [Style Encoder] + [Word Generator] update
        # Adversarial loss
        disGenCost = 0.5 * (disModel.module.calc_gen_loss(images_recon_1) +
                            disModel.module.calc_gen_loss(images_recon_2))

        # Input reconstruction loss
        recCost = recCriterion(images_recon_1, image)

        # Pair style reconstruction loss
        if opt.styleReconWeight == 0.0:
            styleRecCost = torch.tensor(0.0)
        else:
            if opt.styleDetach:
                styleRecCost = styleRecCriterion(model(images_recon_2, None, None, styleFlag=True), style.detach())
            else:
                styleRecCost = styleRecCriterion(model(images_recon_2, None, None, styleFlag=True), style)

        # OCR content cost
        ocrCost = 0.5 * (ocrCost_1 + ocrCost_2)

        cost = opt.ocrWeight * ocrCost + opt.reconWeight * recCost + opt.disWeight * disGenCost + opt.styleReconWeight * styleRecCost

        model.zero_grad()
        ocrModel.zero_grad()
        disModel.zero_grad()
        cost.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()
        loss_avg.add(cost)

        # Individual losses
        loss_avg_ocrRecon_1.add(opt.ocrWeight * 0.5 * ocrCost_1)
        loss_avg_ocrRecon_2.add(opt.ocrWeight * 0.5 * ocrCost_2)
        loss_avg_gen.add(opt.disWeight * disGenCost)
        loss_avg_imgRecon.add(opt.reconWeight * recCost)
        loss_avg_styRecon.add(opt.styleReconWeight * styleRecCost)

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:  # To see training progress, we also conduct validation when 'iteration == 0'
            # Save training images
            os.makedirs(os.path.join(opt.exp_dir, opt.exp_name, 'trainImages', str(iteration)), exist_ok=True)
            for trImgCntr in range(batch_size):
                try:
                    save_image(tensor2im(image[trImgCntr].detach()),
                               os.path.join(opt.exp_dir, opt.exp_name, 'trainImages', str(iteration),
                                            str(trImgCntr) + '_input_' + labels_gt[trImgCntr] + '.png'))
                    save_image(tensor2im(images_recon_1[trImgCntr].detach()),
                               os.path.join(opt.exp_dir, opt.exp_name, 'trainImages', str(iteration),
                                            str(trImgCntr) + '_recon_' + labels_1[trImgCntr] + '.png'))
                    save_image(tensor2im(images_recon_2[trImgCntr].detach()),
                               os.path.join(opt.exp_dir, opt.exp_name, 'trainImages', str(iteration),
                                            str(trImgCntr) + '_pair_' + labels_2[trImgCntr] + '.png'))
                except Exception:
                    print('Warning while saving training image')

            elapsed_time = time.time() - start_time
            # for log
            with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_train.txt'), 'a') as log:
                model.eval()
                ocrModel.module.Transformation.eval()
                ocrModel.module.FeatureExtraction.eval()
                ocrModel.module.AdaptiveAvgPool.eval()
                ocrModel.module.SequenceModeling.eval()
                ocrModel.module.Prediction.eval()
                disModel.eval()

                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation_synth_lrw_res(
                        iteration, model, ocrModel, disModel, recCriterion, styleRecCriterion, ocrCriterion,
                        valid_loader, converter, opt)

                model.train()
                if not opt.ocrFixed:
                    ocrModel.train()
                else:
                    # ocrModel.module.Transformation.eval()
                    # ocrModel.module.FeatureExtraction.eval()
                    # ocrModel.module.AdaptiveAvgPool.eval()
                    ocrModel.module.SequenceModeling.train()
                    # ocrModel.module.Prediction.eval()
                disModel.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train OCR loss: {loss_avg_ocr.val():0.5f}, Train Synth loss: {loss_avg.val():0.5f}, Train Dis loss: {loss_avg_dis.val():0.5f}, Valid OCR loss: {valid_loss[0]:0.5f}, Valid Synth loss: {valid_loss[1]:0.5f}, Valid Dis loss: {valid_loss[2]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'

                current_model_log_ocr = f'{"Current_accuracy_OCR":17s}: {current_accuracy[0]:0.3f}, {"Current_norm_ED_OCR":17s}: {current_norm_ED[0]:0.2f}'
                current_model_log_1 = f'{"Current_accuracy_recon":17s}: {current_accuracy[1]:0.3f}, {"Current_norm_ED_recon":17s}: {current_norm_ED[1]:0.2f}'
                current_model_log_2 = f'{"Current_accuracy_pair":17s}: {current_accuracy[2]:0.3f}, {"Current_norm_ED_pair":17s}: {current_norm_ED[2]:0.2f}'

                # plotting
                lib.plot.plot(os.path.join(plotDir, 'Train-OCR-Loss'), loss_avg_ocr.val().item())
                lib.plot.plot(os.path.join(plotDir, 'Train-Synth-Loss'), loss_avg.val().item())
                lib.plot.plot(os.path.join(plotDir, 'Train-Dis-Loss'), loss_avg_dis.val().item())
                lib.plot.plot(os.path.join(plotDir, 'Train-OCR-Recon1-Loss'), loss_avg_ocrRecon_1.val().item())
                lib.plot.plot(os.path.join(plotDir, 'Train-OCR-Recon2-Loss'), loss_avg_ocrRecon_2.val().item())
                lib.plot.plot(os.path.join(plotDir, 'Train-Gen-Loss'), loss_avg_gen.val().item())
                lib.plot.plot(os.path.join(plotDir, 'Train-ImgRecon1-Loss'), loss_avg_imgRecon.val().item())
                lib.plot.plot(os.path.join(plotDir, 'Train-StyRecon2-Loss'), loss_avg_styRecon.val().item())

                lib.plot.plot(os.path.join(plotDir, 'Valid-OCR-Loss'), valid_loss[0].item())
                lib.plot.plot(os.path.join(plotDir, 'Valid-Synth-Loss'), valid_loss[1].item())
                lib.plot.plot(os.path.join(plotDir, 'Valid-Dis-Loss'), valid_loss[2].item())
                lib.plot.plot(os.path.join(plotDir, 'Valid-OCR-Recon1-Loss'), valid_loss[3].item())
                lib.plot.plot(os.path.join(plotDir, 'Valid-OCR-Recon2-Loss'), valid_loss[4].item())
                lib.plot.plot(os.path.join(plotDir, 'Valid-Gen-Loss'), valid_loss[5].item())
                lib.plot.plot(os.path.join(plotDir, 'Valid-ImgRecon1-Loss'), valid_loss[6].item())
                lib.plot.plot(os.path.join(plotDir, 'Valid-StyRecon2-Loss'), valid_loss[7].item())

                lib.plot.plot(os.path.join(plotDir, 'Orig-OCR-WordAccuracy'), current_accuracy[0])
                lib.plot.plot(os.path.join(plotDir, 'Recon-OCR-WordAccuracy'), current_accuracy[1])
                lib.plot.plot(os.path.join(plotDir, 'Pair-OCR-WordAccuracy'), current_accuracy[2])
                lib.plot.plot(os.path.join(plotDir, 'Orig-OCR-CharAccuracy'), current_norm_ED[0])
                lib.plot.plot(os.path.join(plotDir, 'Recon-OCR-CharAccuracy'), current_norm_ED[1])
                lib.plot.plot(os.path.join(plotDir, 'Pair-OCR-CharAccuracy'), current_norm_ED[2])

                # keep best accuracy model (on valid dataset)
                if current_accuracy[1] > best_accuracy:
                    best_accuracy = current_accuracy[1]
                    torch.save(model.state_dict(), os.path.join(opt.exp_dir, opt.exp_name, 'best_accuracy.pth'))
                    torch.save(disModel.state_dict(), os.path.join(opt.exp_dir, opt.exp_name, 'best_accuracy_dis.pth'))
                if current_norm_ED[1] > best_norm_ED:
                    best_norm_ED = current_norm_ED[1]
                    torch.save(model.state_dict(), os.path.join(opt.exp_dir, opt.exp_name, 'best_norm_ED.pth'))
                    torch.save(disModel.state_dict(), os.path.join(opt.exp_dir, opt.exp_name, 'best_norm_ED_dis.pth'))
                best_model_log = f'{"Best_accuracy_Recon":17s}: {best_accuracy:0.3f}, {"Best_norm_ED_Recon":17s}: {best_norm_ED:0.2f}'

                # keep best accuracy OCR model (on valid dataset)
                if current_accuracy[0] > best_accuracy_ocr:
                    best_accuracy_ocr = current_accuracy[0]
                    if not opt.ocrFixed:
                        torch.save(ocrModel.state_dict(), os.path.join(opt.exp_dir, opt.exp_name, 'best_accuracy_ocr.pth'))
                if current_norm_ED[0] > best_norm_ED_ocr:
                    best_norm_ED_ocr = current_norm_ED[0]
                    if not opt.ocrFixed:
                        torch.save(ocrModel.state_dict(), os.path.join(opt.exp_dir, opt.exp_name, 'best_norm_ED_ocr.pth'))
                best_model_log_ocr = f'{"Best_accuracy_ocr":17s}: {best_accuracy_ocr:0.3f}, {"Best_norm_ED_ocr":17s}: {best_norm_ED_ocr:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log_ocr}\n{current_model_log_1}\n{current_model_log_2}\n{best_model_log_ocr}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":32s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt_ocr, pred_ocr, confidence_ocr, gt_1, pred_1, confidence_1, gt_2, pred_2, confidence_2 in zip(
                        labels[0][:5], preds[0][:5], confidence_score[0][:5],
                        labels[1][:5], preds[1][:5], confidence_score[1][:5],
                        labels[2][:5], preds[2][:5], confidence_score[2][:5]):
                    if 'Attn' in opt.Prediction:
                        # gt_ocr = gt_ocr[:gt_ocr.find('[s]')]
                        pred_ocr = pred_ocr[:pred_ocr.find('[s]')]
                        # gt_1 = gt_1[:gt_1.find('[s]')]
                        pred_1 = pred_1[:pred_1.find('[s]')]
                        # gt_2 = gt_2[:gt_2.find('[s]')]
                        pred_2 = pred_2[:pred_2.find('[s]')]

                    predicted_result_log += f'{"ocr"}: {gt_ocr:27s} | {pred_ocr:25s} | {confidence_ocr:0.4f}\t{str(pred_ocr == gt_ocr)}\n'
                    predicted_result_log += f'{"recon"}: {gt_1:25s} | {pred_1:25s} | {confidence_1:0.4f}\t{str(pred_1 == gt_1)}\n'
{pred_1:25s} | {confidence_1:0.4f}\t{str(pred_1 == gt_1)}\n' predicted_result_log += f'{"pair"}: {gt_2:26s} | {pred_2:25s} | {confidence_2:0.4f}\t{str(pred_2 == gt_2)}\n' predicted_result_log += f'{dashed_line}' print(predicted_result_log) log.write(predicted_result_log + '\n') loss_avg_ocr.reset() loss_avg.reset() loss_avg_dis.reset() loss_avg_ocrRecon_1.reset() loss_avg_ocrRecon_2.reset() loss_avg_gen.reset() loss_avg_imgRecon.reset() loss_avg_styRecon.reset() lib.plot.flush() lib.plot.tick() # save model per 1e+5 iter. if (iteration) % 1e+5 == 0: torch.save( model.state_dict(), os.path.join(opt.exp_dir, opt.exp_name, 'iter_' + str(iteration + 1) + '_synth.pth')) if not opt.ocrFixed: torch.save( ocrModel.state_dict(), os.path.join(opt.exp_dir, opt.exp_name, 'iter_' + str(iteration + 1) + '_ocr.pth')) torch.save( disModel.state_dict(), os.path.join(opt.exp_dir, opt.exp_name, 'iter_' + str(iteration + 1) + '_dis.pth')) if (iteration + 1) == opt.num_iter: print('end the training') sys.exit() iteration += 1
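# The loss_avg_* objects above come from this repo's utils. For reference, a
# minimal sketch of what such an Averager is assumed to look like (it mirrors
# the clovaai deep-text-recognition-benchmark utils; treat it as an
# illustration, not necessarily the exact implementation used here):
class Averager(object):
    """Compute average for torch.Tensor, used for loss averaging."""

    def __init__(self):
        self.reset()

    def add(self, v):
        count = v.data.numel()
        v = v.data.sum()
        self.n_count += count
        self.sum += v

    def val(self):
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res

    def reset(self):
        self.n_count = 0
        self.sum = 0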
def demo(opt):
    """ model configuration """
    if 'Transformer' in opt.SequenceModeling:
        converter = TransformerLabelConverter(opt.character)
    elif 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # load model
    if opt.saved_model != '':
        print('loading pretrained model from %s' % opt.saved_model)
        checkpoint = torch.load(opt.saved_model)
        if type(checkpoint) == dict:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model.load_state_dict(checkpoint)
        del checkpoint
        torch.cuda.empty_cache()
    model = torch.nn.DataParallel(model)
    if torch.cuda.is_available():
        model = model.cuda()

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    dict_gt = {}
    with open('gt.txt', 'r') as gt_file:
        gt = gt_file.readlines()
        for line in gt:
            key = line.split(', "')[0]
            value = line.split(', "')[1].replace('"\n', '').lower()
            dict_gt[key] = value

    for image_tensors, image_path_list in demo_loader:
        batch_size = image_tensors.size(0)
        with torch.no_grad():
            image = image_tensors.cuda()
            # for max length prediction
            length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * batch_size)
            text_for_pred = torch.cuda.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0)

            if 'Transformer' in opt.SequenceModeling:
                preds = model(image, text_for_pred, is_train=False)
                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)
            elif 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred).log_softmax(2)
                # select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.permute(1, 0, 2).max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = model(image, text_for_pred, is_train=False)
                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

        print('-' * 80)
        print('image_path\tpredicted_labels')
        print('-' * 80)
        for img_name, pred in zip(image_path_list, preds_str):
            if 'Transformer' in opt.SequenceModeling:
                pred = pred[:pred.find('</s>')]
            elif 'Attn' in opt.Prediction:
                pred = pred[:pred.find('[s]')]  # prune after "end of sentence" token ([s])

            raw_img = cv2.imread(img_name)
            raw_img = cv2.resize(raw_img, (200, 64))
            tmp_img = np.zeros((128, 200, 3), np.uint8)
            tmp_img.fill(255)
            tmp_img[:64, :200] = raw_img
            raw_img = tmp_img
            font = cv2.FONT_HERSHEY_SIMPLEX
            bottomLeftCornerOfText = (5, 90)
            fontScale = 1
            fontColor = (0, 0, 255)
            lineType = 2
            if pred == dict_gt[img_name.split('/')[-1]]:
                cv2.putText(raw_img, pred, (5, 90), font, fontScale, (0, 255, 0), lineType)
                raw_img = raw_img[:96, :200]
                cv2.imwrite('./trash/true/' + img_name.split('/')[-1], raw_img)
            else:
                cv2.putText(raw_img, pred, (5, 90), font, fontScale, (0, 0, 255), lineType)
                cv2.putText(raw_img, dict_gt[img_name.split('/')[-1]], (5, 125), font, fontScale, (0, 255, 0), lineType)
                cv2.imwrite('./trash/false/' + img_name.split('/')[-1], raw_img)
            print(f'{img_name}\t{pred}')
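# The dict_gt parsing in demo() above assumes gt.txt lines shaped like
# '<image_name>, "<label>"'. A tiny self-contained illustration of that format
# (file name and label are made up):
def _parse_gt_line(line):
    key = line.split(', "')[0]
    value = line.split(', "')[1].replace('"\n', '').lower()
    return key, value

assert _parse_gt_line('demo_1.png, "Available"\n') == ('demo_1.png', 'available')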
            min_point = points_sorted[i][0]
            max_point = points_sorted[i][1]
            # print("Cropped image")
            mask_file = result_folder + filename + "_" + str(order_sorted[i]) + "_" + str(i) + '.jpg'
            # print(mask_file)
            crop_image = rgb_img[int(min_point[1]):int(max_point[1]),
                                 int(min_point[0]):int(max_point[0])]
            # plt.imshow(crop_image)
            # plt.show()
            cv2.imwrite(mask_file, crop_image)

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    # result_folder = './intermediate_result/'
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=result_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    print("Starting text classification")
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # image = (torch.from_numpy(crop_image).unsqueeze(0)).to(device)
def train(opt):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open('./saved_models/{}/log_dataset.txt'.format(opt.experiment_name), 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    """ initialize some parameters """
    learning_rate = 1e-4
    label2num, num2label = label_num('all_labels.txt')
    num_classes = len(label2num)
    print('Number of training classes: {}'.format(num_classes))
    print('Training label list:\n{}'.format(num2label.values()))
    print('-' * 80)

    class VGGNet(nn.Module):
        def __init__(self, num_classes=num_classes):
            super(VGGNet, self).__init__()
            net = models.vgg16(pretrained=True)
            net.classifier = nn.Sequential()
            self.features = net
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 512),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(512, 128),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(128, num_classes),
            )

        def forward(self, x):
            x = self.features(x)
            x = x.view(x.size(0), -1)
            x = self.classifier(x)
            return x

    # -------------------- training process ---------------------------------
    model = VGGNet()
    if torch.cuda.is_available():
        model.cuda()
    params = [{'params': md.parameters()} for md in model.children() if md in [model.classifier]]
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    loss_func = nn.CrossEntropyLoss()
    Loss_list = []
    Accuracy_list = []

    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print('continue to train, start_iter: {}'.format(start_iter))
        except:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter
    num2label = opt.num2label

    while (True):
        # train part
        image_tensors, labels = train_dataset.get_batch()
        batch_x = image_tensors.to(device)
        # labels = [num2label[x] for x in labels]  # convert Chinese characters back to label indices
        batch_y = torch.from_numpy(np.asarray(labels, dtype=np.int8)).to(device)
        train_loss = 0.
        train_acc = 0.

        out = model(batch_x)
        loss = loss_func(out, batch_y.long())
        train_loss += loss.item()
        pred = torch.max(out, 1)[1]
        train_correct = (pred == batch_y).sum()
        train_acc += train_correct.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 0.5e+2 == 0:
            print('Step{}:'.format(i + 1))
            print('Train Loss: {:.6f}, Acc: {:.6f}'.format(
                train_loss / len(labels), train_acc / len(labels)))

        # save model per 5e+2 iter.
        if (i + 1) % 5e+2 == 0:
            torch.save(model.state_dict(),
                       './saved_models/{}/iter_{}.pth'.format(opt.experiment_name, i + 1))
        if i == opt.num_iter:
            torch.save(model.state_dict(),
                       './saved_models/{}/iter_{}.pth'.format(opt.experiment_name, i + 1))
            print('end the training')
            break
        i += 1

        # evaluation --------------------------------
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            model.eval()
            eval_loss = 0.
            eval_acc = 0.
            length_of_data = 0
            for image_tensors, labels in valid_loader:
                batch_x = image_tensors.to(device)
                batch_y = torch.from_numpy(np.asarray(labels, dtype=np.int8)).to(device)
                length_of_data += len(labels)
                # batch_x, batch_y = Variable(batch_x, volatile=True).cuda(), Variable(batch_y, volatile=True).cuda()
                out = model(batch_x)
                loss = loss_func(out, batch_y.long())
                eval_loss += loss.item()
                pred = torch.max(out, 1)[1]
                num_correct = (pred == batch_y).sum()
                eval_acc += num_correct.item()
            print('Test Loss: {:.6f}, Acc: {:.6f}'.format(
                eval_loss / length_of_data, eval_acc / length_of_data))
            # average over the whole validation set, not only the last batch
            Loss_list.append(eval_loss / length_of_data)
            Accuracy_list.append(100 * eval_acc / length_of_data)

            x1 = np.arange(0, 100).reshape(1, -1)
            x2 = np.arange(0, 100).reshape(1, -1)
            y1 = np.array(Accuracy_list).reshape(1, -1)
            y2 = np.array(Loss_list).reshape(1, -1)
            plt.figure()
            plt.subplot(2, 1, 1)
            plt.plot(x1, y1, 'o-')
            plt.title('Test accuracy vs. epochs')
            plt.ylabel('Test accuracy')
            plt.subplot(2, 1, 2)
            plt.plot(x2, y2, '.-')
            plt.xlabel('Test loss vs. epochs')
            plt.ylabel('Test loss')
            plt.show()
            plt.savefig("accuracy_loss.jpg")
            sys.exit()
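# label_num is imported from elsewhere in this codebase; a hypothetical minimal
# version, consistent with how label2num / num2label are used above (one label
# per line in all_labels.txt). Treat this as an assumption, not the real helper:
def label_num(label_file):
    label2num, num2label = {}, {}
    with open(label_file, encoding='utf-8') as f:
        for idx, line in enumerate(f):
            label = line.strip()
            label2num[label] = idx
            num2label[idx] = label
    return label2num, num2label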
def demoToTxt2(image_folder, saved_model, txtFile):
    # sensitive
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_folder', default=image_folder, help='path to image_folder which contains text images')
    parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
    parser.add_argument('--batch_size', type=int, default=100, help='input batch size')
    parser.add_argument('--saved_model', default=saved_model, help="path to saved_model to evaluation")
    """ Data processing """
    parser.add_argument('--batch_max_length', type=int, default=20, help='maximum-label-length')
    parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
    parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
    parser.add_argument('--rgb', action='store_true', help='use rgb input')
    parser.add_argument('--character', type=str, default='0123456789', help='character label')
    parser.add_argument('--sensitive', default=True, help='for sensitive character mode')
    parser.add_argument('--PAD', default=False, action='store_true', help='whether to keep ratio then pad for image resize')
    """ Model Architecture """
    parser.add_argument('--Transformation', default='TPS', type=str, help='Transformation stage. None|TPS')
    parser.add_argument('--FeatureExtraction', default='ResNet', type=str, help='FeatureExtraction stage. VGG|RCNN|ResNet')
    parser.add_argument('--SequenceModeling', default='BiLSTM', type=str, help='SequenceModeling stage. None|BiLSTM')
    parser.add_argument('--Prediction', default='Attn', type=str, help='Prediction stage. CTC|Attn')
    parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
    parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor')
    parser.add_argument('--output_channel', type=int, default=512, help='the number of output channel of Feature extractor')
    parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')
    opt = parser.parse_args()

    """ vocab / character number configuration """
    if opt.sensitive:
        opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
        # opt.character = string.printable[:-6]  # same with ASTER setting (use 94 char).

    cudnn.benchmark = True
    cudnn.deterministic = True
    opt.num_gpu = torch.cuda.device_count()

    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model)
    if torch.cuda.is_available():
        model = model.cuda()

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    saved_file = open(txtFile, 'w')
    for image_tensors, image_path_list in demo_loader:
        batch_size = image_tensors.size(0)
        with torch.no_grad():
            image = image_tensors.cuda()
            # for max length prediction
            length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * batch_size)
            text_for_pred = torch.cuda.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred).log_softmax(2)
                # select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.permute(1, 0, 2).max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = model(image, text_for_pred, is_train=False)
                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

        print('-' * 80)
        print('image_path\tpredicted_labels')
        print('-' * 80)
        for img_name, pred in zip(image_path_list, preds_str):
            if 'Attn' in opt.Prediction:
                pred = pred[:pred.find('[s]')]  # prune after "end of sentence" token ([s])
            print(f'{img_name}\t{pred}')
            saved_file.write(f'{img_name}\t{pred}\n')
    saved_file.close()
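# Example invocation of demoToTxt2 (paths are placeholders, not from the
# source): it recognizes every image under image_folder and writes one
# "<image_path>\t<prediction>" line per image to txtFile. Note the argparse
# call inside the function also reads sys.argv, so run without extra CLI flags.
#
#   demoToTxt2('demo_image/', './saved_models/TPS-ResNet-BiLSTM-Attn.pth', 'preds.txt')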
def runDeepTextNet(segmentedImagesList):
    opt = argparse.Namespace(
        FeatureExtraction='ResNet', PAD=False, Prediction='Attn',
        SequenceModeling='BiLSTM', Transformation='TPS',
        batch_max_length=25, batch_size=192,
        character='0123456789abcdefghijklmnopqrstuvwxyz',
        hidden_size=256, image_folder='demo_image/', imgH=32, imgW=100,
        input_channel=1, num_class=38, num_fiducial=20, num_gpu=0,
        output_channel=512, rgb=False,
        saved_model='TPS-ResNet-BiLSTM-Attn.pth', sensitive=False, workers=4)

    model = Model(opt)
    model = torch.nn.DataParallel(model).to('cpu')
    directory = "TPS-ResNet-BiLSTM-Attn.pth"
    model.load_state_dict(torch.load(directory, map_location='cpu'))

    converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    if opt.rgb:
        opt.input_channel = 3

    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=segmentedImagesList, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    out_preds_texts = []
    for image_tensors, image_path_list in demo_loader:
        batch_size = image_tensors.size(0)
        image = image_tensors.to(device)
        # for max length prediction
        length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

        preds = model(image, text_for_pred, is_train=False)
        # select max probability (greedy decoding) then decode index to character
        _, preds_index = preds.max(2)
        preds_str = converter.decode(preds_index, length_for_pred)
        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)

        for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
            if 'Attn' in opt.Prediction:
                pred_EOS = pred.find('[s]')
                pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                pred_max_prob = pred_max_prob[:pred_EOS]
            # calculate confidence score (= product of pred_max_prob)
            confidence_score = pred_max_prob.cumprod(dim=0)[-1]
            # print(pred)
            out_preds_texts.append(pred)

    # print(out_preds_texts)
    sentence_out = [' '.join(out_preds_texts)]
    return sentence_out
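# Illustration of the confidence score used above: the product of the per-step
# max probabilities over the kept (pre-[s]) steps, computed with cumprod.
# The probabilities below are made up:
import torch

pred_max_prob = torch.tensor([0.99, 0.95, 0.90])
confidence_score = pred_max_prob.cumprod(dim=0)[-1]  # 0.99 * 0.95 * 0.90 ~= 0.846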
def demo(opt):
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # modified
    # result_df = pd.DataFrame(columns=['video_id', 'word_id', 'ocr_text'])
    result_df = pd.read_csv(opt.out_df_path)
    # result_df.append({'video_id': 1, 'word_id': 1, 'ocr_text': "sdf"}, ignore_index=True)
    ###################

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # for max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)
                # select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                # preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index, preds_size)
            else:
                preds = model(image, text_for_pred, is_train=False)
                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            # log = open(f'./log_demo_result.txt', 'a')
            # dashed_line = '-' * 80
            # head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'
            # print(f'{dashed_line}\n{head}\n{dashed_line}')
            # log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                if 'Attn' in opt.Prediction:
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_max_prob = pred_max_prob[:pred_EOS]

                # # calculate confidence score (= product of pred_max_prob)
                # confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                # print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
                # log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')

                # modified
                img_name_path = Path(img_name)
                # result_df = result_df.append({'video_id': img_name_path.parent.stem, 'word_id': img_name_path.stem, 'ocr_text': pred}, ignore_index=True)
                result_df.loc[(result_df.video_id == img_name_path.parent.stem) &
                              (result_df.word_id == int(img_name_path.stem)), 'ocr_text'] = pred
                ##################
            # log.close()

    result_df.to_csv(opt.out_df_path, index=False)
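# demo() above assumes opt.out_df_path points at an existing CSV holding
# video_id, word_id and ocr_text columns, keyed by each crop's parent-folder
# name and numeric file stem. A hypothetical way to bootstrap such a file
# (column names come from the code above; the rows and path are made up):
import pandas as pd

bootstrap_df = pd.DataFrame({'video_id': ['vid001', 'vid001'],
                             'word_id': [1, 2],
                             'ocr_text': ['', '']})
bootstrap_df.to_csv('out_df.csv', index=False)  # placeholder path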
def index():
    model, converter, length_for_pred, text_for_pred, opt = loader()
    start_time = time.time()

    AlignCollate_demo = AlignCollate(imgH=opt['imgH'], imgW=opt['imgW'], keep_ratio_with_pad=opt['PAD'])
    demo_data = RawDataset(root=opt['image_folder'], opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt['batch_size'],
        shuffle=False,
        num_workers=int(opt['workers']),
        collate_fn=AlignCollate_demo, pin_memory=True)
    get_data = time.time() - start_time

    # predict
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # for max length prediction
            # torch.cuda.synchronize(device)
            if 'CTC' in opt['Prediction']:
                preds = model(image, text_for_pred)  # .log_softmax(2)
                preds = preds.log_softmax(2)
                # select max probability and decode indices to characters
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = model(image, text_for_pred, is_train=False)
                # select max probability and decode indices to characters
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            print('-' * 80)
            print('image_path\tpredicted_labels')
            print('-' * 80)
            for img_name, pred in zip(image_path_list, preds_str):
                if 'Attn' in opt['Prediction']:
                    pred = pred[:pred.find('[s]')]  # prune after the end-of-sentence token ([s])
                print(f'{img_name}\t{pred}')

    forward_time = time.time() - start_time
    only_infer_time = forward_time - get_data
    print('*' * 80)
    print('get_data_time:{:.5f}[sec]'.format(get_data))
    print('only_infer_time:{:.5f}[sec]'.format(only_infer_time))
    print('total_time:{:.5f}[sec]'.format(forward_time))
    print('*' * 80)

    img_name = [i[9:] for i in image_path_list]
    items = {}
    for path, pred in zip(img_name, preds_str):
        items[path] = pred
    return render_template('index.html', images=items)
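# index() above is a Flask view (it returns render_template), and the
# img_name = [i[9:] for i in image_path_list] slice strips a fixed 9-character
# prefix, consistent with images living under './static/' (9 characters).
# A sketch of the assumed wiring; the app name, route and template file are
# guesses, not from the source:
from flask import Flask, render_template  # render_template is used by index()

app = Flask(__name__)
app.add_url_rule('/', 'index', index)  # equivalent to decorating index() with @app.route('/')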
def demo(opt):
    """ Open the csv file into which the predicted words will be written """
    data = pd.read_csv('/content/data.csv')

    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # for max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)
                # select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                # preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = model(image, text_for_pred, is_train=False)
                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            dashed_line = '-' * 80
            # head = f'{"image_path":25s}\t {"predicted_labels":25s}\t confidence score'
            # print(f'{dashed_line}\n{head}\n{dashed_line}')
            # log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                start = '/content/Result/Crop Words/'
                path = os.path.relpath(img_name, start)
                folder = os.path.dirname(path)
                image_name = os.path.basename(path)
                file_name = '_'.join(image_name.split('_')[:-8])
                txt_file = os.path.join(start, folder, file_name)
                log = open(f'{txt_file}_log_demo_result_vgg.txt', 'a')

                if 'Attn' in opt.Prediction:
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_max_prob = pred_max_prob[:pred_EOS]

                # calculate confidence score (= product of pred_max_prob)
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                imgcropped = cv2.imread(img_name)
                cv2_imshow(imgcropped)
                print(f'{pred:25s}\t {confidence_score:0.4f}\n')
                # print(f'{image_name:25s}\t {pred:25s}\t {confidence_score:0.4f}')
                log.write(f'{image_name:25s}\t {pred:25s}\t {confidence_score:0.4f}\n')
                log.close()
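# cv2_imshow used above is the Google Colab replacement for cv2.imshow
# (plain cv2.imshow does not work in notebooks). A convenience guard, sketched
# here so the same code can run outside Colab; purely an assumption about setup:
try:
    from google.colab.patches import cv2_imshow
except ImportError:
    import cv2

    def cv2_imshow(img):
        cv2.imshow('image', img)
        cv2.waitKey(1)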
def demo(opt):
    """ model configuration """
    if opt.guide_training:
        from model_guide import Model
    else:
        from model import Model

    if opt.baiduCTC:
        converter = CTCLabelConverterForBaiduWarpctc(opt.character)
    else:
        converter = CTCLabelConverter(opt.character)
    if opt.Prediction == 'Attn':
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    opt.num_class_ctc = opt.num_class
    opt.num_class_attn = opt.num_class_ctc + 1

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device), strict=False)

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    data = pd.DataFrame()
    with torch.no_grad():
        ind = 0
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # for max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                if opt.guide_training:
                    preds = model.module.inference(image, text_for_pred)
                else:
                    preds = model(image, text_for_pred)
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                # select max probability (greedy decoding) then decode index to character
                if opt.baiduCTC:
                    if opt.beam_search:
                        preds_index = preds
                    else:
                        _, preds_index = preds.max(2)
                        preds_index = preds_index.view(-1)
                else:
                    _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index.data, preds_size.data, opt.beam_search)
            else:
                preds = model(image, text_for_pred, is_train=False)
                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            log = open(f'./log_demo_result.txt', 'a')
            dashed_line = '-' * 80
            head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'
            print(f'{dashed_line}\n{head}\n{dashed_line}')
            log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                if 'Attn' in opt.Prediction:
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_max_prob = pred_max_prob[:pred_EOS]

                # calculate confidence score (= product of pred_max_prob)
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                filename = img_name
                label = pred
                conf = round(confidence_score.item(), 3)
                img = cv2.imread(filename)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img_pil = Image.fromarray(img)
                img_buffer = io.BytesIO()
                img_pil.save(img_buffer, format="PNG")
                imgStr = base64.b64encode(img_buffer.getvalue()).decode("utf-8")

                data.loc[ind, 'img'] = '<img src="data:image/png;base64,{0:s}">'.format(imgStr)
                data.loc[ind, 'id'] = filename
                data.loc[ind, 'label'] = label
                data.loc[ind, 'conf'] = conf
                ind += 1
                print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
                log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')
            log.close()

    html_all = data.to_html(escape=False)
    if opt.is_save:
        text_file = open("result.html", "w")
        text_file.write(html_all)
        text_file.close()
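# Minimal illustration of the data-URI embedding used above, with a made-up
# 8x8 image: the PNG bytes are base64-encoded into an <img> tag so that
# data.to_html(escape=False) renders thumbnails inline in result.html.
import base64
import io

from PIL import Image

img_demo = Image.new('RGB', (8, 8), 'white')
buf = io.BytesIO()
img_demo.save(buf, format='PNG')
img_tag = '<img src="data:image/png;base64,{0:s}">'.format(
    base64.b64encode(buf.getvalue()).decode('utf-8'))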
def demo(opt):
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    elif 'Bert' in opt.Prediction:
        converter = TransformerConverter(opt.character, opt.batch_max_length)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    opt.alphabet_size = len(opt.character) + 2  # +2 for [UNK]+[EOS]

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model)
    if torch.cuda.is_available():
        model = model.cuda()

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # mkdir result
    experiment_name = os.path.join('./result', opt.image_folder.split('/')[-2])
    if not os.path.exists(experiment_name):
        os.makedirs(experiment_name)
    result = {}

    # predict
    model.eval()
    for idx, (image_tensors, image_path_list) in enumerate(demo_loader):
        batch_size = image_tensors.size(0)
        with torch.no_grad():
            image = image_tensors.cuda()
            # for max length prediction
            length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * batch_size)
            text_for_pred = torch.cuda.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred).log_softmax(2)
                # select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.permute(1, 0, 2).max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)
            elif 'Bert' in opt.Prediction:
                with torch.no_grad():
                    pad_mask = None
                    preds = model(image, pad_mask)
                    # select max probability (greedy decoding) then decode index to character
                    _, preds_index = preds[1].max(2)
                    length_for_pred = torch.cuda.IntTensor([preds_index.size(-1)] * batch_size)
                    preds_str = converter.decode(preds_index, length_for_pred)
            else:
                preds = model(image, text_for_pred, is_train=False)
                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

        print(f'{idx}/{len(demo_data) / opt.batch_size}')
        for img_name, pred in zip(image_path_list, preds_str):
            if 'Attn' in opt.Prediction:
                pred = pred[:pred.find('[s]')]  # prune after "end of sentence" token ([s])
            # write in json
            name = f'{img_name}'.split('/')[-1].replace('gt', 'res').split('.')[0]
            value = [{"transcription": f'{pred}'}]
            result[name] = value

    with open(f'{experiment_name}/result.json', 'w') as f:
        json.dump(result, f)
    print("finished writing...")
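# Shape of the result.json written above: one entry per image, keyed by the
# gt->res renamed file stem, in an ICDAR-submission-like layout. A made-up
# example entry:
example_result = {"res_img_1": [{"transcription": "available"}]}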
def demo(opt):
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    print(opt.num_class)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))
    # model.load_state_dict(copy_state_dict(torch.load(opt.saved_model, map_location=device)))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    # demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_data = LmdbDataset(root=opt.image_folder, opt=opt, mode='Val')  # use LmdbDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True, drop_last=True)

    log = open(f'./log_demo_result.txt', 'a')

    # predict
    model.eval()
    fail_count, sample_count = 0, 0
    record_count = 1
    with torch.no_grad():
        for image_tensors, image_path_list, original_images, indexes in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # for max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)
                # select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                # preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index, preds_size)
            else:
                preds = model(image, text_for_pred, is_train=False)
                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            dashed_line = '-' * 80
            # head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'
            # print(f'{dashed_line}\n{head}\n{dashed_line}')
            # log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for image_tensor, gt, pred, pred_max_prob, original_image, lmdb_key in zip(
                    image_tensors, image_path_list, preds_str, preds_max_prob,
                    original_images, indexes):
                if 'Attn' in opt.Prediction:
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_max_prob = pred_max_prob[:pred_EOS]

                if pred_max_prob.shape[0] > 0:
                    # calculate confidence score (= product of pred_max_prob)
                    confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                else:
                    confidence_score = 0.0

                compare_gt = "".join(x.upper() for x in gt if x.isalnum())
                compare_pred = "".join(x.upper() for x in pred if x.isalnum())
                if compare_gt != compare_pred:
                    fail_count += 1
                    # print(f'{gt:25s}\t{pred:25s}\tFail\t{confidence_score:0.4f}\t{record_count}\n')
                    im = to_pil_image(image_tensor)
                    try:
                        # im.save(os.path.join('result', f'{lmdb_key}_{compare_pred}_{compare_gt}.jpeg'))
                        original_image.save(os.path.join('result', 'fail',
                                                         f'{lmdb_key}_{compare_pred}_{compare_gt}.jpg'))
                    except Exception as e:
                        print(f'Error: {e} {lmdb_key}_{compare_pred}_{compare_gt}')
                        exit(1)
                else:
                    # print(f'{gt:25s}\t{pred:25s}\tSuccess\t{confidence_score:0.4f}')
                    im = to_pil_image(image_tensor)
                    try:
                        # im.save(os.path.join('result', f'{lmdb_key}_{compare_pred}_{compare_gt}.jpeg'))
                        original_image.save(os.path.join('result', 'success',
                                                         f'{lmdb_key}_{compare_pred}_{compare_gt}.jpg'))
                    except Exception as e:
                        print(f'Error: {e} {lmdb_key}_{compare_pred}_{compare_gt}')
                        exit(1)
                sample_count += 1
                record_count += 1

    log.close()
    print(f'total accuracy: {(sample_count-fail_count)/sample_count:.2f} number of sample: {sample_count}')
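# The fail/success split above compares predictions to ground truth
# case-insensitively and ignoring non-alphanumerics. A tiny self-contained
# check of that rule (made-up strings):
def _normalize(s):
    return "".join(x.upper() for x in s if x.isalnum())

assert _normalize("Hello, World!") == _normalize("hello world")  # counted a success
assert _normalize("Hello") != _normalize("He1lo")                # counted a failure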
def train(opt):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    # considering the real images for discriminator
    opt.batch_size = opt.batch_size * 2
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(os.path.join(opt.exp_dir, opt.exp_name, 'log_dataset.txt'), 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3

    model = AdaINGen(opt)
    ocrModel = Model(opt)
    disModel = MsImageDis(opt)

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # recognizer weight initialization
    for name, param in ocrModel.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # discriminator weight initialization
    for name, param in disModel.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    ocrModel.train()
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()

    # loading pre-trained models
    if opt.saved_ocr_model != '':
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        if opt.FT:
            ocrModel.load_state_dict(torch.load(opt.saved_ocr_model), strict=False)
        else:
            ocrModel.load_state_dict(torch.load(opt.saved_ocr_model))
    print("OCRModel:")
    print(ocrModel)

    if opt.saved_synth_model != '':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_synth_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_synth_model))
    print("SynthModel:")
    print(model)

    if opt.saved_dis_model != '':
        print(f'loading pretrained discriminator model from {opt.saved_dis_model}')
        if opt.FT:
            disModel.load_state_dict(torch.load(opt.saved_dis_model), strict=False)
        else:
            disModel.load_state_dict(torch.load(opt.saved_dis_model))
    print("DisModel:")
    print(disModel)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    recCriterion = torch.nn.L1Loss()
    styleRecCriterion = torch.nn.L1Loss()

    # loss averagers
    loss_avg_ocr = Averager()
    loss_avg = Averager()
    loss_avg_dis = Averager()

    ## ---------------------------------------- ##
    # filter the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("SynthOptimizer:")
    print(optimizer)

    # filter parameters for OCR training
    ocr_filtered_parameters = []
    ocr_params_num = []
    for p in filter(lambda p: p.requires_grad, ocrModel.parameters()):
        ocr_filtered_parameters.append(p)
        ocr_params_num.append(np.prod(p.size()))
    print('OCR Trainable params num : ', sum(ocr_params_num))

    # setup optimizer
    if opt.adam:
        ocr_optimizer = optim.Adam(ocr_filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        ocr_optimizer = optim.Adadelta(ocr_filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("OCROptimizer:")
    print(ocr_optimizer)

    # filter parameters for discriminator training
    dis_filtered_parameters = []
    dis_params_num = []
    for p in filter(lambda p: p.requires_grad, disModel.parameters()):
        dis_filtered_parameters.append(p)
        dis_params_num.append(np.prod(p.size()))
    print('Dis Trainable params num : ', sum(dis_params_num))

    # setup optimizer
    if opt.adam:
        dis_optimizer = optim.Adam(dis_filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        dis_optimizer = optim.Adadelta(dis_filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("DisOptimizer:")
    print(dis_optimizer)
    ## ---------------------------------------- ##

    """ final options """
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_synth_model != '':
        try:
            start_iter = int(opt.saved_synth_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    best_accuracy_ocr = -1
    best_norm_ED_ocr = -1
    iteration = start_iter
    # cntr = 0

    while (True):
        # train part
        image_tensors_all, labels_1_all, labels_2_all = train_dataset.get_batch()
        # ## comment
        # pdb.set_trace()
        # for imgCntr in range(image_tensors.shape[0]):
        #     save_image(tensor2im(image_tensors[imgCntr]), 'temp/' + str(imgCntr) + '.png')
        # pdb.set_trace()
        # ###
        # print(cntr)
        # cntr += 1

        disCnt = int(image_tensors_all.size(0) / 2)
        image_tensors, image_tensors_real, labels_1, labels_2 = (
            image_tensors_all[:disCnt], image_tensors_all[disCnt:disCnt + disCnt],
            labels_1_all[:disCnt], labels_2_all[:disCnt])

        image = image_tensors.to(device)
        image_real = image_tensors_real.to(device)
        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        images_recon_1, images_recon_2, style = model(image, text_1, text_2)

        if 'CTC' in opt.Prediction:
            # OCR training
            preds_ocr = ocrModel(image, text_1)
            preds_size_ocr = torch.IntTensor([preds_ocr.size(1)] * batch_size)
            preds_ocr = preds_ocr.log_softmax(2).permute(1, 0, 2)
            ocrCost_train = ocrCriterion(preds_ocr, text_1, preds_size_ocr, length_1)

            # dis training
            # Check: using alternate real images
            disCost = opt.disWeight * 0.5 * (
                disModel.module.calc_dis_loss(images_recon_1.detach(), image_real) +
                disModel.module.calc_dis_loss(images_recon_2.detach(), image))

            # synth training
            preds_1 = ocrModel(images_recon_1, text_1)
            preds_size_1 = torch.IntTensor([preds_1.size(1)] * batch_size)
            preds_1 = preds_1.log_softmax(2).permute(1, 0, 2)
            preds_2 = ocrModel(images_recon_2, text_2)
            preds_size_2 = torch.IntTensor([preds_2.size(1)] * batch_size)
            preds_2 = preds_2.log_softmax(2).permute(1, 0, 2)
            ocrCost = 0.5 * (ocrCriterion(preds_1, text_1, preds_size_1, length_1) +
                             ocrCriterion(preds_2, text_2, preds_size_2, length_2))

            # gen training
            disGenCost = 0.5 * (disModel.module.calc_gen_loss(images_recon_1) +
                                disModel.module.calc_gen_loss(images_recon_2))
        else:
            # NOTE: this branch does not set ocrCost_train, disCost or disGenCost;
            # only the CTC path is exercised by this script.
            preds = ocrModel(image, text_1[:, :-1])  # align with Attention.forward
            target = text_1[:, 1:]  # without [GO] Symbol
            ocrCost = ocrCriterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        recCost = recCriterion(images_recon_1, image)
        styleRecCost = styleRecCriterion(model(images_recon_2, None, None, styleFlag=True), style.detach())

        cost = (opt.ocrWeight * ocrCost + opt.reconWeight * recCost +
                opt.disWeight * disGenCost + opt.styleReconWeight * styleRecCost)

        disModel.zero_grad()
        disCost.backward()
        torch.nn.utils.clip_grad_norm_(disModel.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        dis_optimizer.step()
        loss_avg_dis.add(disCost)

        model.zero_grad()
        ocrModel.zero_grad()
        disModel.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()
        loss_avg.add(cost)

        # training OCR
        ocrModel.zero_grad()
        ocrCost_train.backward()
        torch.nn.utils.clip_grad_norm_(ocrModel.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        ocr_optimizer.step()
        loss_avg_ocr.add(ocrCost_train)

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            # To see training progress, we also conduct validation when 'iteration == 0'

            # save training images
            os.makedirs(os.path.join(opt.exp_dir, opt.exp_name, 'trainImages', str(iteration)), exist_ok=True)
            for trImgCntr in range(batch_size):
                try:
                    save_image(tensor2im(image[trImgCntr].detach()),
                               os.path.join(opt.exp_dir, opt.exp_name, 'trainImages', str(iteration),
                                            str(trImgCntr) + '_input_' + labels_1[trImgCntr] + '.png'))
                    save_image(tensor2im(images_recon_1[trImgCntr].detach()),
                               os.path.join(opt.exp_dir, opt.exp_name, 'trainImages', str(iteration),
                                            str(trImgCntr) + '_recon_' + labels_1[trImgCntr] + '.png'))
                    save_image(tensor2im(images_recon_2[trImgCntr].detach()),
                               os.path.join(opt.exp_dir, opt.exp_name, 'trainImages', str(iteration),
                                            str(trImgCntr) + '_pair_' + labels_2[trImgCntr] + '.png'))
                except:
                    print('Warning while saving training image')

            elapsed_time = time.time() - start_time
            # for log
            with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_train.txt'), 'a') as log:
                model.eval()
                ocrModel.eval()
                disModel.eval()
                with torch.no_grad():
                    # valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                    #     model, criterion, valid_loader, converter, opt)
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation_synth_lrw(
                        iteration, model, ocrModel, disModel, recCriterion,
                        styleRecCriterion, ocrCriterion, valid_loader, converter, opt)
                model.train()
                ocrModel.train()
                disModel.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train OCR loss: {loss_avg_ocr.val():0.5f}, Train Synth loss: {loss_avg.val():0.5f}, Train Dis loss: {loss_avg_dis.val():0.5f}, Valid OCR loss: {valid_loss[0]:0.5f}, Valid Synth loss: {valid_loss[1]:0.5f}, Valid Dis loss: {valid_loss[2]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg_ocr.reset()
                loss_avg.reset()
                loss_avg_dis.reset()

                current_model_log_ocr = f'{"Current_accuracy_OCR":17s}: {current_accuracy[0]:0.3f}, {"Current_norm_ED_OCR":17s}: {current_norm_ED[0]:0.2f}'
                current_model_log_1 = f'{"Current_accuracy_recon":17s}: {current_accuracy[1]:0.3f}, {"Current_norm_ED_recon":17s}: {current_norm_ED[1]:0.2f}'
                current_model_log_2 = f'{"Current_accuracy_pair":17s}: {current_accuracy[2]:0.3f}, {"Current_norm_ED_pair":17s}: {current_norm_ED[2]:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy[1] > best_accuracy:
                    best_accuracy = current_accuracy[1]
                    torch.save(model.state_dict(), os.path.join(opt.exp_dir, opt.exp_name, 'best_accuracy.pth'))
                    torch.save(disModel.state_dict(), os.path.join(opt.exp_dir, opt.exp_name, 'best_accuracy_dis.pth'))
                if current_norm_ED[1] > best_norm_ED:
                    best_norm_ED = current_norm_ED[1]
                    torch.save(model.state_dict(), os.path.join(opt.exp_dir, opt.exp_name, 'best_norm_ED.pth'))
                    torch.save(disModel.state_dict(), os.path.join(opt.exp_dir, opt.exp_name, 'best_norm_ED_dis.pth'))
                best_model_log = f'{"Best_accuracy_Recon":17s}: {best_accuracy:0.3f}, {"Best_norm_ED_Recon":17s}: {best_norm_ED:0.2f}'

                # keep best OCR accuracy model (on valid dataset)
                if current_accuracy[0] > best_accuracy_ocr:
                    best_accuracy_ocr = current_accuracy[0]
                    torch.save(ocrModel.state_dict(), os.path.join(opt.exp_dir, opt.exp_name, 'best_accuracy_ocr.pth'))
                if current_norm_ED[0] > best_norm_ED_ocr:
                    best_norm_ED_ocr = current_norm_ED[0]
                    torch.save(ocrModel.state_dict(), os.path.join(opt.exp_dir, opt.exp_name, 'best_norm_ED_ocr.pth'))
                best_model_log_ocr = f'{"Best_accuracy_ocr":17s}: {best_accuracy_ocr:0.3f}, {"Best_norm_ED_ocr":17s}: {best_norm_ED_ocr:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log_ocr}\n{current_model_log_1}\n{current_model_log_2}\n{best_model_log_ocr}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":32s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt_ocr, pred_ocr, confidence_ocr, gt_1, pred_1, confidence_1, gt_2, pred_2, confidence_2 in zip(
                        labels[0][:5], preds[0][:5], confidence_score[0][:5],
                        labels[1][:5], preds[1][:5], confidence_score[1][:5],
                        labels[2][:5], preds[2][:5], confidence_score[2][:5]):
                    if 'Attn' in opt.Prediction:
                        # prune after the end-of-sentence token ([s])
                        pred_ocr = pred_ocr[:pred_ocr.find('[s]')]
                        pred_1 = pred_1[:pred_1.find('[s]')]
                        pred_2 = pred_2[:pred_2.find('[s]')]
                    predicted_result_log += f'{"ocr"}: {gt_ocr:27s} | {pred_ocr:25s} | {confidence_ocr:0.4f}\t{str(pred_ocr == gt_ocr)}\n'
                    predicted_result_log += f'{"recon"}: {gt_1:25s} | {pred_1:25s} | {confidence_1:0.4f}\t{str(pred_1 == gt_1)}\n'
                    predicted_result_log += f'{"pair"}: {gt_2:26s} | {pred_2:25s} | {confidence_2:0.4f}\t{str(pred_2 == gt_2)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            torch.save(model.state_dict(),
                       os.path.join(opt.exp_dir, opt.exp_name, f'iter_{iteration+1}.pth'))
            torch.save(ocrModel.state_dict(),
                       os.path.join(opt.exp_dir, opt.exp_name, f'iter_{iteration+1}_ocr.pth'))
            torch.save(disModel.state_dict(),
                       os.path.join(opt.exp_dir, opt.exp_name, f'iter_{iteration+1}_dis.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
def train(opt): """ dataset preparation """ if not opt.data_filtering_off: print( 'Filtering the images containing characters which are not in opt.character' ) print( 'Filtering the images whose label is longer than opt.batch_max_length' ) # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 opt.select_data = opt.select_data.split('-') opt.batch_ratio = opt.batch_ratio.split('-') train_dataset = Batch_Balanced_Dataset(opt) log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) valid_dataset, valid_dataset_log = hierarchical_dataset( root=opt.valid_data, opt=opt) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=opt.batch_size, shuffle= True, # 'True' to check training progress with validation function. num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True) log.write(valid_dataset_log) print('-' * 80) log.write('-' * 80 + '\n') log.close() """ model configuration """ # CTCLoss converter_ctc = CTCLabelConverter(opt.character) # Attention converter_atten = AttnLabelConverter(opt.character) opt.num_class_ctc = len(converter_ctc.character) opt.num_class_atten = len(converter_atten.character) if opt.rgb: opt.input_channel = 3 model = Model(opt) print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class_ctc, opt.num_class_atten, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction) # weight initialization for name, param in model.named_parameters(): if 'localization_fc2' in name: print(f'Skip {name} as it is already initialized') continue try: if 'bias' in name: init.constant_(param, 0.0) elif 'weight' in name: init.kaiming_normal_(param) except Exception as e: # for batchnorm. 
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # filter parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p_: p_.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # use fp16 to train
    model = model.to(device)
    if opt.fp16:
        with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
            log.write('==> Enable fp16 training' + '\n')
        print('==> Enable fp16 training')
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # data parallel for multi-GPU
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).to(device)
    model.train()
    # for i in model.module.Prediction_atten:
    #     i.to(device)
    # for i in model.module.Feat_Extraction.scr:
    #     i.to(device)
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    """ setup loss """
    criterion_ctc = torch.nn.CTCLoss(zero_infinity=True).to(device)
    criterion_atten = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    """ final options """
    writer = SummaryWriter(f'./saved_models/{opt.exp_name}')
    # print(opt)
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter
    # image_tensors, labels = train_dataset.get_batch()

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        batch_size = image.size(0)
        text_ctc, length_ctc = converter_ctc.encode(labels, batch_max_length=opt.batch_max_length)
        # type tuple; (tensor, list)
        text_atten, length_atten = converter_atten.encode(labels, batch_max_length=opt.batch_max_length)

        # text_atten[:, :-1]: align with Attention.forward
        preds_ctc, preds_atten = model(image, text_atten[:, :-1])

        # CTC loss
        preds_size = torch.IntTensor([preds_ctc.size(1)] * batch_size)
        # _, preds_index = preds_ctc.max(2)
        # preds_str_ctc = converter_ctc.decode(preds_index.data, preds_size.data)
        preds_ctc = preds_ctc.log_softmax(2).permute(1, 0, 2)
        cost_ctc = 0.1 * criterion_ctc(preds_ctc, text_ctc, preds_size, length_ctc)

        # Attention loss
        # preds_atten = [i[:, :text_atten.shape[1] - 1, :] for i in preds_atten]
        # # select max probability (greedy decoding) then decode index to character
        # preds_index_atten = [i.max(2)[1] for i in preds_atten]
        # length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
        # preds_str_atten = [converter_atten.decode(i, length_for_pred) for i in preds_index_atten]
        # preds_str_atten2 = preds_str_atten
        # preds_str_atten = []
        # for i in preds_str_atten2:  # prune after "end of sentence" token ([s])
        #     temp = []
        #     for j in i:
        #         j = j[:j.find('[s]')]
        #         temp.append(j)
        #     preds_str_atten.append(temp)
        # preds_str_atten = [j[:j.find('[s]')] for i in preds_str_atten for j in i]

        target = text_atten[:, 1:]  # without [GO] Symbol
        # cost_atten = 1.0 * criterion_atten(preds_atten.view(-1, preds_atten.shape[-1]), target.contiguous().view(-1))
        # sum the attention loss over every prediction head
        for index, pred in enumerate(preds_atten):
            if index == 0:
                cost_atten = 1.0 * criterion_atten(pred.view(-1, pred.shape[-1]),
                                                   target.contiguous().view(-1))
            else:
                cost_atten += 1.0 * criterion_atten(pred.view(-1, pred.shape[-1]),
                                                    target.contiguous().view(-1))
        # cost_atten = [1.0 * criterion_atten(pred.view(-1, pred.shape[-1]), target.contiguous().view(-1)) for pred in preds_atten]
        # cost_atten = criterion_atten(preds_atten.view(-1, preds_atten.shape[-1]), target.contiguous().view(-1))

        cost = cost_ctc + cost_atten
        writer.add_scalar('loss', cost.item(), global_step=iteration + 1)
        # cost = cost_ctc
        # cost = cost_atten

        if (iteration + 1) % 100 == 0:
            print('\riter: {:4d}\tloss: {:6.3f}\tavg: {:6.3f}'.format(
                iteration + 1, cost.item(), loss_avg.val()), end='\n')
        else:
            print('\riter: {:4d}\tloss: {:6.3f}\tavg: {:6.3f}'.format(
                iteration + 1, cost.item(), loss_avg.val()), end='')
        sys.stdout.flush()
        if cost < 0.001:
            print(f'iter: {iteration + 1}\tloss: {cost}')

        # model.zero_grad()
        optimizer.zero_grad()
        if torch.isnan(cost):
            print(f'iter: {iteration + 1}\tloss: {cost}\t==> Loss is NAN')
            sys.exit()
        elif torch.isinf(cost):
            print(f'iter: {iteration + 1}\tloss: {cost}\t==> Loss is INF')
            sys.exit()
        else:
            if opt.fp16:
                with amp.scale_loss(cost, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                cost.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
            optimizer.step()

        loss_avg.add(cost)
        writer.add_scalar('loss_avg', loss_avg.val(), global_step=iteration + 1)
        # if loss_avg.val() <= 0.6:
        #     opt.grad_clip = 2
        # if loss_avg.val() <= 0.3:
        #     opt.grad_clip = 1

        # validation part
        # To see training progress, we also conduct validation when 'iteration == 0'
        if iteration == 0 or (iteration + 1) % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion_atten, valid_loader, converter_atten, opt)
                model.train()
                writer.add_scalar('accuracy', current_accuracy, global_step=iteration + 1)

                # training loss and validation loss
                loss_log = f'[{iteration + 1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'
                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    gt = gt[:gt.find('[s]')]
                    pred = pred[:pred.find('[s]')]
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.exp_name}/iter_{iteration + 1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()

        # if (iteration + 1) % opt.valInterval == 0:
        #     print(f'iter: {iteration + 1}\tloss: {cost}')
        iteration += 1
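# The training loops above feed preds.log_softmax(2).permute(1, 0, 2) into
# torch.nn.CTCLoss, which expects log-probabilities shaped (T, N, C) plus
# per-sample input and target lengths. A self-contained shape sketch with
# dummy values (the T/N/C/S numbers are ours, chosen only for illustration):
import torch

T, N, C, S = 26, 4, 37, 10          # time steps, batch, classes (blank = 0), max target length
log_probs = torch.randn(N, T, C).log_softmax(2).permute(1, 0, 2)  # -> (T, N, C)
targets = torch.randint(1, C, (N, S), dtype=torch.long)           # index 0 is reserved for blank
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.randint(1, S + 1, (N,), dtype=torch.long)

ctc = torch.nn.CTCLoss(zero_infinity=True)  # zero_infinity guards against inf losses on bad alignments
loss = ctc(log_probs, targets, input_lengths, target_lengths)
print(loss.item())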
def demo(opt):
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class,
          opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)
                # select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = model(image, text_for_pred, is_train=False)
                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            log = open('./log_demo_result.txt', 'a')
            dashed_line = '-' * 80
            head = "image_path" + "\t" + "predicted_labels" + "\t" + "confidence score"
            print(dashed_line + "\n" + head + "\n" + dashed_line)
            log.write(dashed_line + "\n" + head + "\n" + dashed_line + "\n")

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                if 'Attn' in opt.Prediction:
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_max_prob = pred_max_prob[:pred_EOS]

                # calculate confidence score (= product of pred_max_prob);
                # guard against empty predictions, as the other demo() variant does
                if pred_max_prob.shape[0] > 0:
                    confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                else:
                    confidence_score = 0.0
                print(f'{img_name}\t{pred}\t{confidence_score:0.4f}')
                log.write(f'{img_name}\t{pred}\t{confidence_score:0.4f}\n')
            log.close()
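# The confidence score above is the product of the per-step max softmax
# probabilities along the greedy path; cumprod(dim=0)[-1] is exactly that
# product. A standalone sketch with a dummy (T, C) logit matrix (the sizes
# and values are ours):
import torch
import torch.nn.functional as F

logits = torch.randn(7, 37)                  # (sequence length, num classes)
probs = F.softmax(logits, dim=1)
max_probs, _ = probs.max(dim=1)              # best class probability per step
confidence = max_probs.cumprod(dim=0)[-1]    # equivalently max_probs.prod()
print(f'{confidence.item():0.4f}')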
def test(opt):
    lib.print_model_settings(locals().copy())

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
        text_len = opt.batch_max_length + 2
    else:
        converter = CTCLabelConverter(opt.character)
        text_len = opt.batch_max_length
    opt.classes = converter.character

    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset = LmdbDataset(root=opt.test_data, opt=opt)
    test_data_sampler = data_sampler(valid_dataset, shuffle=False, distributed=False)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=False,
        sampler=test_data_sampler,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True,
        drop_last=False)
    print('-' * 80)

    opt.num_class = len(converter.character)

    ocrModel = ModelV1(opt).to(device)

    ## Loading pre-trained files
    print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
    checkpoint = torch.load(opt.saved_ocr_model, map_location=lambda storage, loc: storage)
    ocrModel.load_state_dict(checkpoint)

    evalCntr = 0
    fCntr = 0
    c1_s1_input_correct = 0.0
    c1_s1_input_ed_correct = 0.0
    # pdb.set_trace()

    for vCntr, (image_input_tensors, labels_gt) in enumerate(valid_loader):
        print(vCntr)
        image_input_tensors = image_input_tensors.to(device)
        text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length)

        with torch.no_grad():
            currBatchSize = image_input_tensors.shape[0]
            # text_for_pred = torch.LongTensor(opt.batch_size, opt.batch_max_length + 1).fill_(0).to(device)
            length_for_pred = torch.IntTensor([opt.batch_max_length] * currBatchSize).to(device)

            # Run OCR prediction
            if 'CTC' in opt.Prediction:
                preds = ocrModel(image_input_tensors, text_gt, is_train=False)
                preds_size = torch.IntTensor([preds.size(1)] * image_input_tensors.shape[0])
                _, preds_index = preds.max(2)
                preds_str_gt_1 = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = ocrModel(image_input_tensors, text_gt[:, :-1], is_train=False)  # align with Attention.forward
                _, preds_index = preds.max(2)
                preds_str_gt_1 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(preds_str_gt_1):
                    pred_EOS = pred.find('[s]')
                    preds_str_gt_1[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

        for trImgCntr in range(image_input_tensors.shape[0]):
            # ocr accuracy
            c1_s1_input_gt = labels_gt[trImgCntr]
            c1_s1_input_ocr = preds_str_gt_1[trImgCntr]
            if c1_s1_input_gt == c1_s1_input_ocr:
                c1_s1_input_correct += 1

            # ICDAR2019 Normalized Edit Distance
            if len(c1_s1_input_gt) == 0 or len(c1_s1_input_ocr) == 0:
                c1_s1_input_ed_correct += 0
            elif len(c1_s1_input_gt) > len(c1_s1_input_ocr):
                c1_s1_input_ed_correct += 1 - edit_distance(c1_s1_input_ocr, c1_s1_input_gt) / len(c1_s1_input_gt)
            else:
                c1_s1_input_ed_correct += 1 - edit_distance(c1_s1_input_ocr, c1_s1_input_gt) / len(c1_s1_input_ocr)

            evalCntr += 1
            fCntr += 1

    avg_c1_s1_input_wer = c1_s1_input_correct / float(evalCntr)
    avg_c1_s1_input_cer = c1_s1_input_ed_correct / float(evalCntr)

    # if not(opt.realVaData):
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_test.txt'), 'a') as log:
        # test word accuracy and normalized character accuracy
        loss_log = f'Word Acc: {avg_c1_s1_input_wer:0.5f}, Test Input Char Acc: {avg_c1_s1_input_cer:0.5f}'
        print(loss_log)
        log.write(loss_log + "\n")
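# The ICDAR2019 normalized-edit-distance branch in test() reduces to dividing
# the edit distance by the longer of the two strings. A minimal equivalent
# helper (assuming nltk's edit_distance, which matches the call signature used
# above; the function name norm_ed_contribution is ours):
from nltk.metrics.distance import edit_distance


def norm_ed_contribution(gt: str, pred: str) -> float:
    # Per-sample score: 1 - ED / max(len(gt), len(pred)); 0 if either is empty.
    if len(gt) == 0 or len(pred) == 0:
        return 0.0
    return 1 - edit_distance(pred, gt) / max(len(gt), len(pred))

# e.g. norm_ed_contribution('hello', 'helo') == 1 - 1/5 == 0.8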
def train(opt):
    """ dataset preparation """
    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    print('-' * 80)

    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class,
          opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        model.load_state_dict(torch.load(opt.continue_model))
    print("Model:")
    print(model)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # filter parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.continue_model != '':
        start_iter = int(opt.continue_model.split('_')[-1].split('.')[0])
        print(f'continue to train, start_iter: {start_iter}')

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    i = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size).to(device)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format
            # To avoid a ctc_loss issue, disable cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            torch.backends.cudnn.enabled = False
            cost = criterion(preds, text, preds_size, length)
            torch.backends.cudnn.enabled = True
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            print(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}')
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                log.write(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n')
                loss_avg.reset()

                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                for pred, gt in zip(preds[:5], labels[:5]):
                    if 'Attn' in opt.Prediction:
                        pred = pred[:pred.find('[s]')]
                        gt = gt[:gt.find('[s]')]
                    print(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}')
                    log.write(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}\n')

                valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                print(valid_log)
                log.write(valid_log + '\n')

                # keep best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(),
                               f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                # in this variant norm_ED is an edit distance, so lower is better
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(),
                               f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                print(best_model_log)
                log.write(best_model_log + '\n')

        # save model per 1e+5 iter.
        if (i + 1) % 1e+5 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
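# The cuDNN toggle around ctc_loss above mutates global state; if an exception
# fires between the two assignments, cuDNN stays disabled for the rest of the
# process. A reentrant sketch of the same workaround using a context manager
# (the name cudnn_disabled is ours, not part of these sources):
from contextlib import contextmanager

import torch


@contextmanager
def cudnn_disabled():
    # Temporarily disable cuDNN, e.g. for the ctc_loss issue referenced above
    # (https://github.com/jpuigcerver/PyLaia/issues/16), restoring it on exit.
    prev = torch.backends.cudnn.enabled
    torch.backends.cudnn.enabled = False
    try:
        yield
    finally:
        torch.backends.cudnn.enabled = prev

# usage:
# with cudnn_disabled():
#     cost = criterion(preds, text, preds_size, length)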
def demo(opt):
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class,
          opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)
    # model = model.to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))
    # model.load_state_dict(copy_state_dict(torch.load(opt.saved_model, map_location=device)))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    # predict
    model.eval()
    fail_count, sample_count = 0, 0
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)
                # select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                # preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index, preds_size)
            else:
                preds = model(image, text_for_pred, is_train=False)
                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            log = open('./log_demo_result.txt', 'a')
            dashed_line = '-' * 80
            head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'
            print(f'{dashed_line}\n{head}\n{dashed_line}')
            log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                if 'Attn' in opt.Prediction:
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_max_prob = pred_max_prob[:pred_EOS]

                if pred_max_prob.shape[0] > 0:
                    # calculate confidence score (= product of pred_max_prob)
                    confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                else:
                    confidence_score = 0.0

                # ground truth is embedded in the file name between '_L_' separators
                gt = img_name.split('_L_')[1]
                gt = gt.split('.')[0]
                pred = pred.split('.')[0]
                # if img_name.find('1_225427_L_대전출입국관리사무소_L_21.png') >= 0:
                #     print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
                if gt.split('(')[0] != pred.split('(')[0]:
                    fail_count += 1
                    log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')
                else:
                    print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
                sample_count += 1
            log.close()

    print(f'total accuracy: {(sample_count - fail_count) / sample_count:.2f}')
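# demo() above recovers the ground truth from the image file name, relying on a
# '_L_<label>_L_' naming convention (the commented-out sample name
# '1_225427_L_대전출입국관리사무소_L_21.png' shows the expected shape). A standalone
# sketch of that parsing (the helper name gt_from_filename is ours):
def gt_from_filename(img_name: str) -> str:
    # e.g. '1_225427_L_대전출입국관리사무소_L_21.png' -> '대전출입국관리사무소'
    label = img_name.split('_L_')[1]
    return label.split('.')[0]


assert gt_from_filename('1_225427_L_ABC_L_21.png') == 'ABC'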