def train(opt):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    """ model configuration """
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel,
          opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length,
          opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # filter the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part. we also validate at iteration == 0 to see training progress early.
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
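
# For reference: loss_avg above relies on an Averager helper (in this repo it
# lives in utils.py). A minimal sketch consistent with how it is used here
# (.add() on a loss tensor, .val() for the running mean, .reset()); treat the
# details as an approximation rather than the canonical implementation.
class Averager(object):
    """Compute the running average of torch.Tensor values (loss averaging)."""

    def __init__(self):
        self.reset()

    def add(self, v):
        count = v.data.numel()
        v = v.data.sum()
        self.n_count += count
        self.sum += v

    def reset(self):
        self.n_count = 0
        self.sum = 0

    def val(self):
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res
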
def train(opt):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    """ model configuration """
    # both converters are built: CTC drives the major path, attention the guide path.
    if opt.baiduCTC:
        CTC_converter = CTCLabelConverterForBaiduWarpctc(opt.character)
    else:
        CTC_converter = CTCLabelConverter(opt.character)
    Attn_converter = AttnLabelConverter(opt.character)
    opt.num_class_ctc = len(CTC_converter.character)
    opt.num_class_attn = len(Attn_converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel,
          opt.output_channel, opt.hidden_size, opt.num_class_ctc, opt.num_class_attn,
          opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    print("Model:")
    print(model)
    # print(summary(model, (1, opt.imgH, opt.imgW, 1)))

    """ setup loss """
    if opt.baiduCTC:
        # need to install warpctc. see our guideline.
        if opt.label_smooth:
            criterion_major_path = SmoothCTCLoss(num_classes=opt.num_class_ctc, weight=0.05)
        else:
            criterion_major_path = CTCLoss()
        # criterion_major_path = CTCLoss(average_frames=False, reduction="mean", blank=0)
    else:
        criterion_major_path = torch.nn.CTCLoss(zero_infinity=True).to(device)
    criterion_guide_path = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    # loss averager, one per path
    loss_avg_major_path = Averager()
    loss_avg_guide_path = Averager()

    # split the parameters that require gradient descent between the two paths
    guide_parameters = []
    major_parameters = []
    guide_model_part_names = ["Transformation", "FeatureExtraction", "SequenceModeling_Attn", "Attention"]
    major_model_part_names = ["SequenceModeling_CTC", "CTC"]
    for name, param in model.named_parameters():
        if param.requires_grad:
            if name.split(".")[1] in guide_model_part_names:
                guide_parameters.append(param)
            elif name.split(".")[1] in major_model_part_names:
                major_parameters.append(param)
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]
    if opt.continue_training:
        guide_parameters = []
    filtered_parameters = guide_parameters + major_parameters  # all trainable parameters, for the optional Adam path

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer_ctc = AdamW(major_parameters, lr=opt.lr)
        if not opt.continue_training:
            optimizer_attn = AdamW(guide_parameters, lr=opt.lr)
    scheduler_ctc = get_linear_schedule_with_warmup(
        optimizer_ctc, num_warmup_steps=10000, num_training_steps=opt.num_iter)
    if not opt.continue_training:  # optimizer_attn only exists in this case
        scheduler_attn = get_linear_schedule_with_warmup(
            optimizer_attn, num_warmup_steps=10000, num_training_steps=opt.num_iter)

    start_iter = 0
    if opt.saved_model != '' and (not opt.continue_training):
        print(f'loading pretrained model from {opt.saved_model}')
        checkpoint = torch.load(opt.saved_model)
        start_iter = checkpoint['start_iter'] + 1
        if not opt.adam:
            optimizer_ctc.load_state_dict(checkpoint['optimizer_ctc_state_dict'])
            if not opt.continue_training:
                optimizer_attn.load_state_dict(checkpoint['optimizer_attn_state_dict'])
        scheduler_ctc.load_state_dict(checkpoint['scheduler_ctc_state_dict'])
        scheduler_attn.load_state_dict(checkpoint['scheduler_attn_state_dict'])
        print(scheduler_ctc.get_lr())
        print(scheduler_attn.get_lr())
        if opt.FT:
            model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        else:
            model.load_state_dict(checkpoint['model_state_dict'])
    if opt.continue_training:
        model.load_state_dict(torch.load(opt.saved_model))
    # print("Optimizer:")
    # print(optimizer)

    # re-create the schedulers so they resume from the loaded iteration
    scheduler_ctc = get_linear_schedule_with_warmup(
        optimizer_ctc, num_warmup_steps=10000, num_training_steps=opt.num_iter,
        last_epoch=start_iter - 1)
    if not opt.continue_training:
        scheduler_attn = get_linear_schedule_with_warmup(
            optimizer_attn, num_warmup_steps=10000, num_training_steps=opt.num_iter,
            last_epoch=start_iter - 1)

    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter - 1
    if opt.continue_training:
        start_iter = 0

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        iteration += 1
        if iteration < start_iter:
            continue
        image = image_tensors.to(device)
        text_attn, length_attn = Attn_converter.encode(labels, batch_max_length=opt.batch_max_length)
        text_ctc, length_ctc = CTC_converter.encode(labels, batch_max_length=opt.batch_max_length)
        # if iteration == start_iter:
        #     writer.add_graph(model, (image, text_attn))
        batch_size = image.size(0)
        preds_major, preds_guide = model(image, text_attn[:, :-1])

        preds_size = torch.IntTensor([preds_major.size(1)] * batch_size)
        if opt.baiduCTC:
            preds_major = preds_major.permute(1, 0, 2)  # to use CTCLoss format
            if opt.label_smooth:
                cost_ctc = criterion_major_path(preds_major, text_ctc, preds_size, length_ctc, batch_size)
            else:
                cost_ctc = criterion_major_path(preds_major, text_ctc, preds_size, length_ctc) / batch_size
        else:
            preds_major = preds_major.log_softmax(2).permute(1, 0, 2)
            cost_ctc = criterion_major_path(preds_major, text_ctc, preds_size, length_ctc)

        target = text_attn[:, 1:]  # without [GO] symbol
        if not opt.continue_training:
            cost_attn = criterion_guide_path(preds_guide.view(-1, preds_guide.shape[-1]),
                                             target.contiguous().view(-1))
            optimizer_attn.zero_grad()
            cost_attn.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(guide_parameters, opt.grad_clip)  # gradient clipping with 5 (default)
            optimizer_attn.step()
        optimizer_ctc.zero_grad()
        cost_ctc.backward()
        torch.nn.utils.clip_grad_norm_(major_parameters, opt.grad_clip)  # gradient clipping with 5 (default)
        optimizer_ctc.step()
        scheduler_ctc.step()
        if not opt.continue_training:
            scheduler_attn.step()

        loss_avg_major_path.add(cost_ctc)
        if not opt.continue_training:
            loss_avg_guide_path.add(cost_attn)
        if (iteration + 1) % 100 == 0:
            # writer: module-level SummaryWriter assumed in this fork
            writer.add_scalar("Loss/train_ctc", loss_avg_major_path.val(), (iteration + 1) // 100)
            loss_avg_major_path.reset()
            if not opt.continue_training:
                writer.add_scalar("Loss/train_attn", loss_avg_guide_path.val(), (iteration + 1) // 100)
                loss_avg_guide_path.reset()

        # validation part
        if (iteration + 1) % opt.valInterval == 0:  # or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion_major_path, valid_loader, CTC_converter, opt)
                model.train()
                writer.add_scalar("Loss/valid", valid_loss, (iteration + 1) // opt.valInterval)
                writer.add_scalar("Metrics/accuracy", current_accuracy, (iteration + 1) // opt.valInterval)
                writer.add_scalar("Metrics/norm_ED", current_norm_ED, (iteration + 1) // opt.valInterval)

                # training loss and validation loss
                if not opt.continue_training:
                    loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss ctc: {loss_avg_major_path.val():0.5f}, Train loss attn: {loss_avg_guide_path.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                else:
                    loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss ctc: {loss_avg_major_path.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg_major_path.reset()
                if not opt.continue_training:
                    loss_avg_guide_path.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset); fol_ckpt is assumed module-level in this fork
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'{fol_ckpt}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'{fol_ckpt}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    # if 'Attn' in opt.Prediction:
                    #     gt = gt[:gt.find('[s]')]
                    #     pred = pred[:pred.find('[s]')]
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save a resumable checkpoint per 1e+3 iter.
        if (iteration + 1) % 1e+3 == 0 and (not opt.continue_training):
            # print(scheduler_ctc.get_lr())
            # print(scheduler_attn.get_lr())
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_attn_state_dict': optimizer_attn.state_dict(),
                'optimizer_ctc_state_dict': optimizer_ctc.state_dict(),
                'start_iter': iteration,
                'scheduler_ctc_state_dict': scheduler_ctc.state_dict(),
                'scheduler_attn_state_dict': scheduler_attn.state_dict(),
            }, f'{fol_ckpt}/current_model.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
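
# Hedged sketch of the SmoothCTCLoss used above when opt.label_smooth is set:
# a label-smoothed CTC criterion that mixes torch.nn.CTCLoss with a KL term
# pulling the per-frame distribution toward uniform. The call signature
# (logits, targets, input_lengths, target_lengths, batch_size) mirrors the
# call site in train(); the mixing scheme and reduction are assumptions, not
# this fork's verified implementation.
import torch
import torch.nn as nn


class SmoothCTCLoss(nn.Module):
    def __init__(self, num_classes, blank=0, weight=0.05):
        super().__init__()
        self.weight = weight
        self.num_classes = num_classes
        self.ctc = nn.CTCLoss(blank=blank, reduction='mean', zero_infinity=True)
        self.kldiv = nn.KLDivLoss(reduction='batchmean')

    def forward(self, logits, targets, input_lengths, target_lengths, batch_size):
        # logits: (T, N, C) raw scores, as permuted in the training loop above.
        log_probs = logits.log_softmax(-1)
        ctc_loss = self.ctc(log_probs, targets, input_lengths, target_lengths)
        # smoothing term: KL divergence toward the uniform distribution over classes.
        uniform = torch.full_like(log_probs, 1.0 / self.num_classes)
        kl_loss = self.kldiv(log_probs.transpose(0, 1), uniform.transpose(0, 1))
        return (1.0 - self.weight) * ctc_loss + self.weight * kl_loss
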
def test(opt):
    """ model configuration """
    # both converters are built: one for the CTC head, one for the attention head.
    if opt.baiduCTC:
        converter_ctc = CTCLabelConverterForBaiduWarpctc(opt.character)
    else:
        converter_ctc = CTCLabelConverter(opt.character)
    converter_attn = AttnLabelConverter(opt.character)
    opt.num_class_ctc = len(converter_ctc.character)
    opt.num_class_attn = len(converter_attn.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    # print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel,
    #       opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length,
    #       opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device), strict=False)
    opt.exp_name = '_'.join(opt.saved_model.split('/')[1:])
    # print(model)

    """ keep evaluation model and result logs """
    os.makedirs(f'./result/{opt.exp_name}', exist_ok=True)
    os.system(f'cp {opt.saved_model} ./result/{opt.exp_name}/')

    """ setup loss """
    criterion_ctc = torch.nn.CTCLoss(zero_infinity=True).to(device)
    criterion_attn = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    """ evaluation """
    model.eval()
    with torch.no_grad():
        if opt.benchmark_all_eval:  # evaluation with 10 benchmark evaluation datasets
            log = open(f'./result/{opt.exp_name}/log_evaluation.txt', 'a')
            AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
            eval_data, eval_data_log = hierarchical_dataset(root=opt.eval_data, opt=opt)
            evaluation_loader = torch.utils.data.DataLoader(
                eval_data, batch_size=opt.batch_size,
                shuffle=False,
                num_workers=int(opt.workers),
                collate_fn=AlignCollate_evaluation, pin_memory=True)
            benchmark_all_eval(model, criterion_ctc, criterion_attn, evaluation_loader,
                               converter_ctc, converter_attn, opt)
        else:
            log = open(f'./result/{opt.exp_name}/log_evaluation.txt', 'a')
            AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
            eval_data, eval_data_log = hierarchical_dataset(root=opt.eval_data, opt=opt)
            evaluation_loader = torch.utils.data.DataLoader(
                eval_data, batch_size=opt.batch_size,
                shuffle=False,
                num_workers=int(opt.workers),
                collate_fn=AlignCollate_evaluation, pin_memory=True)
            _, accuracy_by_best_model, _, _, _, _, _, _, _, acc_attn, _, _, _ = validation_ctc_and_attn(
                model, criterion_ctc, criterion_attn, evaluation_loader,
                converter_ctc, converter_attn, opt)
            log.write(eval_data_log)
            print(f'{accuracy_by_best_model:0.3f}')
            print(f'{acc_attn:0.3f}')
            log.write(f'{accuracy_by_best_model:0.3f}\n')
            log.close()
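
# Hedged usage sketch for test(): a minimal argparse driver wired to the
# options the function reads (eval_data, benchmark_all_eval, saved_model,
# batch/loader settings, image geometry, plus the architecture flags Model(opt)
# expects). Flag names follow the upstream deep-text-recognition-benchmark CLI;
# defaults are illustrative assumptions, not this fork's verified interface.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--eval_data', required=True, help='path to evaluation dataset')
    parser.add_argument('--benchmark_all_eval', action='store_true', help='evaluate on the benchmark datasets')
    parser.add_argument('--saved_model', required=True, help='path to the checkpoint to evaluate')
    parser.add_argument('--batch_size', type=int, default=192)
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--batch_max_length', type=int, default=25)
    parser.add_argument('--imgH', type=int, default=32)
    parser.add_argument('--imgW', type=int, default=100)
    parser.add_argument('--rgb', action='store_true')
    parser.add_argument('--PAD', action='store_true')
    parser.add_argument('--baiduCTC', action='store_true')
    parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz')
    parser.add_argument('--Transformation', type=str, default='TPS')
    parser.add_argument('--FeatureExtraction', type=str, default='ResNet')
    parser.add_argument('--SequenceModeling', type=str, default='BiLSTM')
    parser.add_argument('--Prediction', type=str, default='CTC')
    parser.add_argument('--num_fiducial', type=int, default=20)
    parser.add_argument('--input_channel', type=int, default=1)
    parser.add_argument('--output_channel', type=int, default=512)
    parser.add_argument('--hidden_size', type=int, default=256)
    opt = parser.parse_args()
    test(opt)
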
def demo(opt):
    """ model configuration """
    if opt.guide_training:
        from model_guide import Model
    else:
        from model import Model
    if opt.baiduCTC:
        converter = CTCLabelConverterForBaiduWarpctc(opt.character)
    else:
        converter = CTCLabelConverter(opt.character)
    if opt.Prediction == 'Attn':
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    opt.num_class_ctc = opt.num_class
    opt.num_class_attn = opt.num_class_ctc + 1
    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel,
          opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length,
          opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device), strict=False)

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    data = pd.DataFrame()
    with torch.no_grad():
        ind = 0
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                if opt.guide_training:
                    preds = model.module.inference(image, text_for_pred)
                else:
                    preds = model(image, text_for_pred)
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                # Select max probability (greedy decoding) then decode index to character
                if opt.baiduCTC:
                    if opt.beam_search:
                        preds_index = preds
                    else:
                        _, preds_index = preds.max(2)
                        preds_index = preds_index.view(-1)
                else:
                    _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index.data, preds_size.data, opt.beam_search)
            else:
                preds = model(image, text_for_pred, is_train=False)
                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            log = open(f'./log_demo_result.txt', 'a')
            dashed_line = '-' * 80
            head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'
            print(f'{dashed_line}\n{head}\n{dashed_line}')
            log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                if 'Attn' in opt.Prediction:
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_max_prob = pred_max_prob[:pred_EOS]

                # calculate confidence score (= product of pred_max_prob)
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                filename = img_name
                label = pred
                conf = round(confidence_score.item(), 3)
                # embed the image as base64 PNG for the HTML report
                img = cv2.imread(filename)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img_pil = Image.fromarray(img)
                img_buffer = io.BytesIO()
                img_pil.save(img_buffer, format="PNG")
                imgStr = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
                data.loc[ind, 'img'] = '<img src="data:image/png;base64,{0:s}">'.format(imgStr)
                data.loc[ind, 'id'] = filename
                data.loc[ind, 'label'] = label
                data.loc[ind, 'conf'] = conf
                ind += 1
                print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
                log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')
            log.close()

    html_all = data.to_html(escape=False)
    if opt.is_save:
        text_file = open("result.html", "w")
        text_file.write(html_all)
        text_file.close()
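
# Quick illustration of the confidence score computed in demo() above: it is
# the product of the per-step max probabilities, obtained via cumprod(...)[-1].
# The numbers below are made up for demonstration.
def _confidence_score_example():
    p = torch.tensor([0.9, 0.8, 0.95])  # per-step max probabilities
    assert torch.isclose(p.cumprod(dim=0)[-1], p.prod())
    return p.prod().item()  # ~0.684
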
def train(opt):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    # opt.select_data = opt.select_data.split('-')  # [MJ, ST]
    # opt.batch_ratio = opt.batch_ratio.split('-')  # [0.5, 0.5]
    # train_dataset = Batch_Balanced_Dataset(opt)
    # log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')

    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    # valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)

    # train_dataset = iiit5k_dataset_builder("/media/ps/hd1/lll/textRecognition/SAR/IIIT5K/train",
    #                                        "/media/ps/hd1/lll/textRecognition/SAR/IIIT5K/traindata.mat", opt)
    # train_dataset = PpocrDataset("/home/ldl/桌面/论文/文本识别/data/paddleocr",
    #                              "/home/ldl/桌面/论文/文本识别/data/paddleocr/label/train.txt", 6625 * 100)
    # train_dataset_chinese = TextRecognition(4068 * 75, opt.charalength, opt.chinesefile)
    # train_dataset_english = TextRecognition(4068 * 25, opt.charalength, opt.englishfile)
    # train_dataset = ConcatDataset([train_dataset_chinese, train_dataset_english])
    train_dataset_xunfeieng = mytrdg_cutimg_dataset(
        total_img_path='/home/ldl/桌面/论文/文本识别/data/finish_data/eng_image/train/img',
        annotation_path='/home/ldl/桌面/论文/文本识别/data/finish_data/eng_image/train/gt')
    train_dataset_xunfeichn = mytrdg_cutimg_dataset(
        total_img_path='/home/ldl/桌面/论文/文本识别/data/finish_data/lan_image/train/img',
        annotation_path='/home/ldl/桌面/论文/文本识别/data/finish_data/lan_image/train/gt')
    train_dataset = ConcatDataset([train_dataset_xunfeichn, train_dataset_xunfeieng])
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid)

    # valid_dataset = iiit5k_dataset_builder("/media/ps/hd1/lll/textRecognition/SAR/IIIT5K/test",
    #                                        "/media/ps/hd1/lll/textRecognition/SAR/IIIT5K/testdata.mat", opt)
    # valid_dataset_chinese = TextRecognition(1001, opt.charalength, opt.chinesefile)
    # valid_dataset_english = TextRecognition(1001, opt.charalength, opt.englishfile)
    # valid_dataset = ConcatDataset([valid_dataset_chinese, valid_dataset_english])
    valid_dataset_xunfeieng = mytrdg_cutimg_dataset(
        total_img_path='/home/ldl/桌面/论文/文本识别/data/finish_data/eng_image/test/img',
        annotation_path='/home/ldl/桌面/论文/文本识别/data/finish_data/eng_image/test/gt')
    valid_dataset_xunfeichn = mytrdg_cutimg_dataset(
        total_img_path='/home/ldl/桌面/论文/文本识别/data/finish_data/lan_image/test/img',
        annotation_path='/home/ldl/桌面/论文/文本识别/data/finish_data/lan_image/test/gt')
    valid_dataset = ConcatDataset([valid_dataset_xunfeichn, valid_dataset_xunfeieng])
    # valid_dataset = PpocrDataset("/home/ldl/桌面/论文/文本识别/data/paddleocr/Synthetic_Chinese_String_Dataset/images",
    #                              "/home/ldl/桌面/论文/文本识别/data/paddleocr/Synthetic_Chinese_String_Dataset/test.txt", 6625,
    #                              split='jpg')
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid)
    # log.write(valid_dataset_log)
    print('-' * 80)
    # log.write('-' * 80 + '\n')
    # log.close()

    """ model configuration """
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel,
          opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length,
          opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    if opt.num_gpu > 1:
        model = torch.nn.DataParallel(model).to(device)
    else:
        model.to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    # print(model)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # filter the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    print(f"Train dataset length {len(train_dataset)}")
    print(f"Valid dataset length {len(valid_dataset)}")
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    # if opt.adam:
    #     optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    # else:
    #     optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    optimizer = optim.Adam(filtered_parameters, lr=0.0001)
    print("Optimizer:")
    print(optimizer)

    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    # if opt.saved_model != '':
    #     try:
    #         start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
    #         print(f'continue to train, start_iter: {start_iter}')
    #     except:
    #         pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter
    epoch = 0
    train_iter_loader = iter(train_loader)

    while True:
        # train part
        try:
            image_tensors, labels = next(train_iter_loader)
            # if len(labels) > 80:
            #     print(labels)
            print("{:4}".format(iteration), end='\r')
        except StopIteration:
            epoch += 1
            print(f"epoch:{epoch}")
            # if epoch >= 1:
            #     break
            # train_loader = torch.utils.data.DataLoader(
            #     train_dataset, batch_size=opt.batch_size,
            #     shuffle=True,  # 'True' to check training progress with validation function.
            #     num_workers=int(opt.workers),
            #     collate_fn=AlignCollate_valid)
            train_iter_loader = iter(train_loader)
            continue
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * preds.size(0))
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                try:
                    cost = criterion(preds, text, preds_size, length)
                except Exception:
                    print(preds.shape, preds_size.shape)
                    raise
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part. we also validate at iteration == 0 to see training progress early.
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                print("validation")
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy >= best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED >= best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+4 iter.
        if (iteration + 1) % 1e+4 == 0:
            torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
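
# The while-loop above restarts the DataLoader iterator on StopIteration to
# stream batches across epochs. A self-contained distillation of that pattern
# as a generator (the name infinite_batches is illustrative, not from this repo):
def infinite_batches(loader):
    """Yield batches forever, starting a new epoch whenever the loader is exhausted."""
    it = iter(loader)
    while True:
        try:
            yield next(it)
        except StopIteration:
            it = iter(loader)  # start a new epoch

# usage sketch: stream = infinite_batches(train_loader); image_tensors, labels = next(stream)
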
def train(opt):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    # train_dataset yields (image, label) batches
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    """ model configuration """
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel,
          opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length,
          opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    print("Model:")
    print(model)
    total_num, true_grad_num, false_grad_num = calculate_model_params(model)
    print("Total parameters: ", total_num)
    print("Number of parameters that require grad: ", true_grad_num)
    print("Number of parameters that do not require grad: ", false_grad_num)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    # load pretrained model
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_pretrained_networks()
        elif opt.continue_train:
            model.load_checkpoint(opt.model_name)
        else:
            raise Exception('Something went wrong!')

    """ setup loss """
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    log_dir = f'./saved_models/{opt.exp_name}'
    writer = SummaryWriter(log_dir)

    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (default)
        model.optimize_parameters()
        writer.add_scalar('train_loss', cost, iteration + 1)

        loss_avg.add(cost)

        # validation part. we also validate at iteration == 0 to see training progress early.
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt, iteration)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                writer.add_scalar('val_loss', valid_loss, iteration + 1)
                writer.add_scalar('accuracy', current_accuracy, iteration + 1)
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    model.save_checkpoints(iteration, 'best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    model.save_checkpoints(iteration, 'best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            model.save_checkpoints(iteration + 1, opt.model_name)

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
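
# calculate_model_params() is called in the train() variant above but defined
# elsewhere in this repo. A minimal sketch matching the call site, which
# expects (total, requires-grad, frozen) parameter counts; treat it as an
# approximation of the real helper.
def calculate_model_params(model):
    total_num = sum(p.numel() for p in model.parameters())
    true_grad_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    false_grad_num = total_num - true_grad_num
    return total_num, true_grad_num, false_grad_num
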