def test(opt):
    """Evaluate a saved checkpoint and log its accuracy.

    Builds the label converter and model described by ``opt``, loads the
    weights stored at ``opt.saved_model``, copies that checkpoint into
    ``./result/<experiment_name>/`` and then either calls
    ``benchmark_all_eval`` or evaluates the single dataset rooted at
    ``opt.eval_data``, printing the accuracy and appending it to
    ``log_evaluation.txt`` in the result directory.

    Args:
        opt: Option namespace. It is mutated in place: ``num_class`` and
            ``experiment_name`` are set, and ``input_channel`` is forced to 3
            when ``opt.rgb`` is truthy.

    Raises:
        ValueError: If ``opt.Prediction`` / ``opt.SequenceModeling`` select
            none of the known label converters.
    """
    # Local import: only this function needs it.
    import shutil

    # --- model configuration ---
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    elif 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    elif ('Transformer' in opt.Prediction or 'Test' in opt.Prediction
          or 'Transformer' in opt.SequenceModeling):
        converter = TransformerLabelConverter(opt.character)
    else:
        # Previously this fell through and failed later with an unbound local.
        raise ValueError(
            f'cannot choose a label converter for Prediction={opt.Prediction!r}, '
            f'SequenceModeling={opt.SequenceModeling!r}')
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # --- load model ---
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))
    # Experiment name = checkpoint path without its first component,
    # with the remaining components joined by '_'.
    opt.experiment_name = '_'.join(opt.saved_model.split('/')[1:])
    # print(model)

    # --- keep evaluation model and result logs ---
    result_dir = f'./result/{opt.experiment_name}'
    os.makedirs(result_dir, exist_ok=True)
    # shutil.copy instead of a shell `cp`: portable, and safe for paths that
    # contain spaces or shell metacharacters.
    shutil.copy(opt.saved_model, result_dir)

    # --- setup loss ---
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        # Index 0 (the [GO] token) is excluded from the loss.
        # NOTE(review): this branch also serves the Transformer converter;
        # confirm index 0 is the right one to ignore for it.
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    # --- evaluation ---
    model.eval()
    with torch.no_grad():
        if opt.benchmark_all_eval:  # evaluation with 10 benchmark evaluation datasets
            benchmark_all_eval(model, criterion, converter, opt)
        else:
            # Context manager so the log is closed even if evaluation raises.
            with open(f'{result_dir}/log_evaluation.txt', 'a') as log:
                AlignCollate_evaluation = AlignCollate(
                    imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
                eval_data, eval_data_log = hierarchical_dataset(
                    root=opt.eval_data, opt=opt)
                evaluation_loader = torch.utils.data.DataLoader(
                    eval_data, batch_size=opt.batch_size,
                    shuffle=False,
                    num_workers=int(opt.workers),
                    collate_fn=AlignCollate_evaluation, pin_memory=True)
                # Only the second returned value (accuracy) is used here.
                _, accuracy_by_best_model, _, _, _, _, _, _ = validation(
                    model, criterion, evaluation_loader, converter, opt)
                log.write(eval_data_log)
                print(f'{accuracy_by_best_model:0.3f}')
                log.write(f'{accuracy_by_best_model:0.3f}\n')
def train(opt):
    """Train a text-recognition model, validating and checkpointing as it goes.

    Prepares the balanced training set and the validation loader, builds the
    label converter, model, loss and optimizer from ``opt``, optionally loads
    ``opt.saved_model``, then loops until ``opt.num_iter`` iterations have run.
    Every ``opt.valInterval`` iterations (and on the first one) it validates,
    appends to ``./saved_models/<exp_name>/log_train.txt`` and saves
    ``best_accuracy.pth`` / ``best_norm_ED.pth`` when the respective metric
    improves. The process exits through ``sys.exit()`` when training is done.

    Args:
        opt: Option namespace. It is mutated in place: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``, ``num_class`` is set, and
            ``input_channel`` is forced to 3 when ``opt.rgb`` is truthy.
    """
    # --- dataset preparation ---
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # Context manager so the dataset log is closed even if loading fails.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(
            imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # --- model configuration ---
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    elif "Transformer" in opt.Prediction:
        converter = TransformerLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    skip_markers = ('localization_fc2', 'decoder', 'self_attn', 'Seq2Seq')
    for name, param in model.named_parameters():
        if any(marker in name for marker in skip_markers):
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # Deliberate best-effort: when the initializer rejects the
            # parameter, weights are set to 1 instead.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            # Fine-tuning: tolerate missing / unexpected keys.
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # --- setup loss ---
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    elif "Transformer" in opt.Prediction:
        # Index 2 is excluded from the loss.
        # NOTE(review): presumably the [PAD] index of TransformerLabelConverter — confirm.
        criterion = torch.nn.CrossEntropyLoss(ignore_index=2).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        print("use Adadelta")
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # --- final options ---
    # print(opt)
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # --- start training ---
    start_iter = 0
    if opt.saved_model != '':
        # Resume the counter when the checkpoint name ends in "_<int>.<ext>"
        # (e.g. "iter_100000.pth"); any other name starts from 0.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            # Name carries no iteration number (e.g. "best_accuracy.pth").
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        # To see training progress, we also conduct validation when 'iteration == 0'
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds,
                     confidence_score, labels, infer_time, length_of_data) = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                prune_eos = 'Attn' in opt.Prediction or "Transformer" in opt.Prediction
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if prune_eos:
                        # partition() keeps the whole string when '[s]' is
                        # absent; slicing with find() == -1 used to drop the
                        # last character in that case.
                        gt = gt.partition('[s]')[0]
                        pred = pred.partition('[s]')[0]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
def train(opt):
    """Train a model for iterations ``start_iter .. opt.num_iter - 1``.

    Builds the training batcher and validation loader, the label converter,
    model, loss and optimizer from ``opt``; optionally warm-starts from
    ``opt.load_weights`` (size-matching tensors only) and/or resumes from
    ``opt.continue_model`` (weights, step, optimizer state and best metrics);
    then trains, validating every ``opt.valInterval`` iterations and writing
    a full checkpoint every 1000 iterations under
    ``./saved_models/<opt.experiment_name>/``.

    ``opt`` is mutated in place: ``select_data`` / ``batch_ratio`` are split
    on ``'-'``, ``num_class`` is set, and ``input_channel`` becomes 3 when
    ``opt.rgb`` is truthy.

    NOTE(review): ``pickle.load`` and ``pickle.Unpickler`` are reassigned on
    the ``pickle`` module itself, so the change is process-wide and each call
    of this function wraps the previous wrapper again.
    """
    # dataset preparation
    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    # In this variant hierarchical_dataset's result is used as a single value.
    valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        # 'True' to check training progress with validation function.
        shuffle=True,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    print('-' * 80)

    """ model configuration """
    # 'Transformer' in SequenceModeling takes precedence over opt.Prediction.
    if 'Transformer' in opt.SequenceModeling:
        converter = TransformerLabelConverter(opt.character)
    elif 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            # Fallback when the initializer raises: weights are filled with 1.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    """ setup loss """
    if 'Transformer' in opt.SequenceModeling:
        # transformer_loss is a callable defined outside this function.
        criterion = transformer_loss
    elif 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
    else:
        # ignore [GO] token = ignore index 0
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda()
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    # NOTE(review): optimizer_schedule is only bound in the elif branch. If
    # opt.adam is set together with a Transformer SequenceModeling and
    # opt.use_scheduled_optim, the later uses of optimizer_schedule raise
    # NameError.
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    elif 'Transformer' in opt.SequenceModeling and opt.use_scheduled_optim:
        optimizer = optim.Adam(filtered_parameters, betas=(0.9, 0.98), eps=1e-09)
        optimizer_schedule = ScheduledOptim(optimizer, opt.d_model, opt.n_warmup_steps)
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    """ final options """
    # print(opt)
    # Append every option as "key: value" to opt.txt and echo it to stdout.
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    start_time = time.time()
    best_accuracy = -1
    # norm_ED is minimised in this variant (see the `<` comparison below),
    # hence the large initial value.
    best_norm_ED = 1e+6

    # Force latin1 decoding for the pickle module handed to torch.load below.
    # NOTE(review): this rebinds attributes of the global pickle module.
    pickle.load = partial(pickle.load, encoding="latin1")
    pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
    if opt.load_weights != '' and check_isfile(opt.load_weights):
        # load pretrained weights but ignore layers that don't match in size
        # NOTE(review): unpickles the file — only load trusted checkpoints.
        checkpoint = torch.load(opt.load_weights, pickle_module=pickle)
        # Exact type check: dict subclasses (e.g. an OrderedDict) take the
        # else branch and are used as the state dict directly.
        # NOTE(review): presumably intentional; isinstance() would change
        # which branch such a checkpoint takes.
        if type(checkpoint) == dict:
            pretrain_dict = checkpoint['state_dict']
        else:
            pretrain_dict = checkpoint
        model_dict = model.state_dict()
        # Keep only entries whose key exists in the model with the same size.
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and model_dict[k].size() == v.size()
        }
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)
        print("Loaded pretrained weights from '{}'".format(opt.load_weights))
        del checkpoint
        torch.cuda.empty_cache()
    if opt.continue_model != '':
        # Resume: expects a checkpoint with 'state_dict' and 'step';
        # 'optimizer', 'best_accuracy' and 'best_norm_ED' are optional.
        print(f'loading pretrained model from {opt.continue_model}')
        checkpoint = torch.load(opt.continue_model)
        print(checkpoint.keys())
        model.load_state_dict(checkpoint['state_dict'])
        start_iter = checkpoint['step'] + 1
        print('continue to train start_iter: ', start_iter)
        if 'optimizer' in checkpoint.keys():
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Move every tensor in the restored optimizer state to the GPU.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()
        if 'best_accuracy' in checkpoint.keys():
            best_accuracy = checkpoint['best_accuracy']
        if 'best_norm_ED' in checkpoint.keys():
            best_norm_ED = checkpoint['best_norm_ED']
        del checkpoint
        torch.cuda.empty_cache()
    # data parallel for multi-GPU
    # Wrapped only after loading, so checkpoint keys are those of the bare
    # model; the saves below use model.module.state_dict() to match.
    model = torch.nn.DataParallel(model).cuda()
    model.train()
    print("Model size:", count_num_param(model), 'M')
    if 'Transformer' in opt.SequenceModeling and opt.use_scheduled_optim:
        # Let the scheduler continue from the resumed step.
        optimizer_schedule.n_current_steps = start_iter

    for i in tqdm(range(start_iter, opt.num_iter)):
        # Re-enable gradients on every parameter at the start of each step.
        for p in model.parameters():
            p.requires_grad = True

        cpu_images, cpu_texts = train_dataset.get_batch()
        image = cpu_images.cuda()
        # The Transformer converter returns an extra value (text_pos).
        if 'Transformer' in opt.SequenceModeling:
            text, length, text_pos = converter.encode(cpu_texts, opt.batch_max_length)
        elif 'CTC' in opt.Prediction:
            text, length = converter.encode(cpu_texts)
        else:
            text, length = converter.encode(cpu_texts, opt.batch_max_length)
        batch_size = image.size(0)

        if 'Transformer' in opt.SequenceModeling:
            preds = model(image, text, tgt_pos=text_pos)
            target = text[:, 1:]  # without <s> Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
        elif 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format
            cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text)
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()

        # Gradient clipping is applied only on the non-Transformer path.
        if 'Transformer' in opt.SequenceModeling and opt.use_scheduled_optim:
            optimizer_schedule.step_and_update_lr()
        elif 'Transformer' in opt.SequenceModeling:
            optimizer.step()
        else:
            # gradient clipping with 5 (Default)
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()
        loss_avg.add(cost)

        # validation part (never on the very first iteration, i == 0)
        if i > 0 and (i + 1) % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            print(
                f'[{i+1}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
            )
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                log.write(
                    f'[{i+1}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n'
                )
                loss_avg.reset()

                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, gts, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # Log the first five predictions next to their ground truth.
                # NOTE(review): str.find returns -1 when the end marker is
                # absent, in which case the slice drops the last character.
                for pred, gt in zip(preds[:5], gts[:5]):
                    if 'Transformer' in opt.SequenceModeling:
                        pred = pred[:pred.find('</s>')]
                        gt = gt[:gt.find('</s>')]
                    elif 'Attn' in opt.Prediction:
                        pred = pred[:pred.find('[s]')]
                        gt = gt[:gt.find('[s]')]

                    print(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}')
                    log.write(
                        f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}\n')

                valid_log = f'[{i+1}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                print(valid_log)
                log.write(valid_log + '\n')

                # keep best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    state_dict = model.module.state_dict()
                    save_checkpoint(
                        {
                            'best_accuracy': best_accuracy,
                            'state_dict': state_dict,
                        }, False,
                        f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                    )
                # keep best (lowest) norm_ED model
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    state_dict = model.module.state_dict()
                    save_checkpoint(
                        {
                            'best_norm_ED': best_norm_ED,
                            'state_dict': state_dict,
                        }, False,
                        f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                    )
                    # torch.save(
                    #     model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                print(best_model_log)
                log.write(best_model_log + '\n')

        # save model per 1000 iter.
        # This checkpoint carries everything the resume path above reads:
        # weights, optimizer state, step and the best metrics so far.
        if (i + 1) % 1000 == 0:
            state_dict = model.module.state_dict()
            optimizer_state_dict = optimizer.state_dict()
            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'optimizer': optimizer_state_dict,
                    'step': i,
                    'best_accuracy': best_accuracy,
                    'best_norm_ED': best_norm_ED,
                }, False,
                f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')
def test(opt):
    """Evaluate a saved checkpoint on CUDA and log its accuracy.

    Builds the label converter and model from ``opt``, loads
    ``opt.saved_model`` when given, copies the checkpoint into
    ``./result/<experiment_name>/`` and then either calls
    ``benchmark_all_eval`` or evaluates the dataset rooted at
    ``opt.eval_data``, printing the accuracy and appending it to
    ``log_evaluation.txt``.

    Two checkpoint layouts are accepted: a plain ``dict`` holding the weights
    under ``'state_dict'`` (loaded into the bare model), or anything else,
    which is loaded as a state dict into the ``DataParallel``-wrapped model.

    Args:
        opt: Option namespace. It is mutated in place: ``num_class`` and
            ``experiment_name`` are set, and ``input_channel`` becomes 3 when
            ``opt.rgb`` is truthy.
    """
    # Local import: only this function needs it.
    import shutil

    # --- model configuration ---
    if 'Transformer' in opt.SequenceModeling:
        converter = TransformerLabelConverter(opt.character)
    elif 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # --- load model ---
    if opt.saved_model != '':
        print('loading pretrained model from %s' % opt.saved_model)
        checkpoint = torch.load(opt.saved_model)
        # Exact type check on purpose: a bare state dict is a dict subclass
        # (OrderedDict) and must take the else branch.
        if type(checkpoint) == dict:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            # The state dict is loaded into the DataParallel wrapper here.
            model = torch.nn.DataParallel(model).cuda()
            model.load_state_dict(checkpoint)
    opt.experiment_name = '_'.join(opt.saved_model.split('/')[1:])

    # parallel model
    # Wrap only once: the branch above may already have wrapped the model,
    # and wrapping again would nest DataParallel inside DataParallel.
    if not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model).cuda()
    # print(model)

    # --- keep evaluation model and result logs ---
    result_dir = f'./result/{opt.experiment_name}'
    os.makedirs(result_dir, exist_ok=True)
    if opt.saved_model != '':
        # shutil.copy instead of a shell `cp`: portable, and safe for paths
        # that contain spaces or shell metacharacters.
        shutil.copy(opt.saved_model, result_dir)

    # --- setup loss ---
    if 'Transformer' in opt.SequenceModeling:
        # ignore PAD token = ignore index 2
        criterion = torch.nn.CrossEntropyLoss(ignore_index=2).cuda()
    elif 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
    else:
        # ignore [GO] token = ignore index 0
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda()

    # --- evaluation ---
    model.eval()
    # No gradients are needed for evaluation; skipping them saves memory.
    with torch.no_grad():
        if opt.benchmark_all_eval:  # evaluation with 10 benchmark evaluation datasets
            benchmark_all_eval(model, criterion, converter, opt)
        else:
            AlignCollate_evaluation = AlignCollate(
                imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
            eval_data = hierarchical_dataset(root=opt.eval_data, opt=opt)
            evaluation_loader = torch.utils.data.DataLoader(
                eval_data, batch_size=opt.batch_size,
                shuffle=False,
                num_workers=int(opt.workers),
                collate_fn=AlignCollate_evaluation, pin_memory=True)
            # Only the second returned value (accuracy) is used here.
            _, accuracy_by_best_model, _, _, _, _, _ = validation(
                model, criterion, evaluation_loader, converter, opt)
            print(accuracy_by_best_model)
            with open('./result/{0}/log_evaluation.txt'.format(opt.experiment_name), 'a') as log:
                log.write(str(accuracy_by_best_model) + '\n')
def demo(opt):
    """Run a trained model over a folder of images and report predictions.

    Builds the label converter and model from ``opt``, loads
    ``opt.saved_model`` onto ``device``, greedily decodes every image under
    ``opt.image_folder`` and, for each batch, prints a table of
    image path / predicted text / confidence and appends it to
    ``./log_demo_result.txt`` (UTF-16).

    The confidence is the product of the per-step maximum probabilities; for
    attention models only the steps before the ``[s]`` marker are counted.
    A prediction left with no steps gets a confidence of 0.

    Args:
        opt: Option namespace. It is mutated in place: ``num_class`` is set,
            and ``input_channel`` becomes 3 when ``opt.rgb`` is truthy.
    """
    # --- model configuration ---
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    elif 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    else:
        converter = TransformerLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    dashed_line = '-' * 80
    head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)

            # Context manager so the log is closed even if a row fails.
            with open('./log_demo_result.txt', 'a', encoding='utf-16') as log:
                print(f'{dashed_line}\n{head}\n{dashed_line}')
                log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

                for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                    if 'Attn' in opt.Prediction:
                        pred_EOS = pred.find('[s]')
                        # find() returns -1 when there is no end marker;
                        # slicing with -1 would drop the last character, so
                        # only prune when the marker is present.
                        if pred_EOS != -1:
                            pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                            pred_max_prob = pred_max_prob[:pred_EOS]

                    # calculate confidence score (= multiply of pred_max_prob)
                    if pred_max_prob.numel() == 0:
                        # Empty prediction (end marker at position 0):
                        # indexing [-1] on an empty tensor would raise.
                        confidence_score = 0.0
                    else:
                        confidence_score = pred_max_prob.cumprod(dim=0)[-1]

                    print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
                    log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')
def demo(opt):
    """Predict text for a folder of images and save annotated comparisons.

    Builds the label converter and model from ``opt``, loads
    ``opt.saved_model`` when given, greedily decodes every image under
    ``opt.image_folder`` and compares each prediction with the ground truth
    read from ``gt.txt`` in the working directory. For every image a 200-px
    wide picture with the prediction drawn under it is written to
    ``./trash/true/`` (match, green text) or ``./trash/false/`` (mismatch:
    prediction in red, ground truth in green).

    ``gt.txt`` holds one ``<file name>, "<label>"`` entry per line; labels are
    lower-cased and blank lines are ignored.

    Runs on CUDA when available and on the CPU otherwise.

    Args:
        opt: Option namespace. It is mutated in place: ``num_class`` is set,
            and ``input_channel`` becomes 3 when ``opt.rgb`` is truthy.

    Raises:
        ValueError: If a non-blank ``gt.txt`` line lacks the ``, "`` separator.
        KeyError: If an image has no entry in ``gt.txt``.
    """
    # Local import: only this function needs it.
    import os

    # --- model configuration ---
    if 'Transformer' in opt.SequenceModeling:
        converter = TransformerLabelConverter(opt.character)
    elif 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # One device choice for the whole function, so the CPU fallback that the
    # model placement already had also covers loading and the input tensors.
    use_cuda = torch.cuda.is_available()
    run_device = torch.device('cuda' if use_cuda else 'cpu')

    # --- load model ---
    if opt.saved_model != '':
        print('loading pretrained model from %s' % opt.saved_model)
        # With CUDA the default placement is kept; without it tensors are
        # remapped to the CPU so CUDA-saved checkpoints still load.
        checkpoint = torch.load(opt.saved_model,
                                map_location=None if use_cuda else 'cpu')
        # Exact type check on purpose: a bare state dict is a dict subclass
        # (OrderedDict) and must take the else branch.
        if type(checkpoint) == dict:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model.load_state_dict(checkpoint)
        del checkpoint
        torch.cuda.empty_cache()

    model = torch.nn.DataParallel(model)
    if use_cuda:
        model = model.cuda()

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # --- ground truth ---
    separator = ', "'
    dict_gt = {}
    with open('gt.txt', 'r') as gt_file:
        for line in gt_file:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            if separator not in line:
                raise ValueError(f'malformed gt.txt line: {line!r}')
            # Split on the first separator only, so a label that itself
            # contains the separator is not truncated.
            key, value = line.split(separator, 1)
            value = value.rstrip('\n')
            # Drop the closing quote; also works for a final line that has
            # no trailing newline.
            if value.endswith('"'):
                value = value[:-1]
            dict_gt[key] = value.lower()

    # cv2.imwrite does not create missing directories.
    true_dir = './trash/true/'
    false_dir = './trash/false/'
    os.makedirs(true_dir, exist_ok=True)
    os.makedirs(false_dir, exist_ok=True)

    # Drawing constants (colors are BGR).
    font = cv2.FONT_HERSHEY_SIMPLEX
    fontScale = 1
    lineType = 2
    green = (0, 255, 0)
    red = (0, 0, 255)

    # predict
    model.eval()
    for image_tensors, image_path_list in demo_loader:
        batch_size = image_tensors.size(0)
        with torch.no_grad():
            image = image_tensors.to(run_device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(run_device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(run_device)

            if 'CTC' in opt.Prediction and 'Transformer' not in opt.SequenceModeling:
                preds = model(image, text_for_pred).log_softmax(2)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.permute(1, 0, 2).max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)
            else:
                # Transformer and attention models decode the same way.
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

        print('-' * 80)
        print('image_path\tpredicted_labels')
        print('-' * 80)
        for img_name, pred in zip(image_path_list, preds_str):
            if 'Transformer' in opt.SequenceModeling:
                pred = pred[:pred.find('</s>')]
            elif 'Attn' in opt.Prediction:
                # prune after "end of sentence" token ([s])
                pred = pred[:pred.find('[s]')]

            file_name = img_name.split('/')[-1]
            expected = dict_gt[file_name]

            # Canvas: the image resized to 200x64 on top of a white strip
            # that receives the text.
            raw_img = cv2.imread(img_name)
            raw_img = cv2.resize(raw_img, (200, 64))
            canvas = np.zeros((128, 200, 3), np.uint8)
            canvas.fill(255)
            canvas[:64, :200] = raw_img

            if pred == expected:
                cv2.putText(canvas, pred, (5, 90), font, fontScale, green, lineType)
                # Only one text row is needed for a correct prediction.
                canvas = canvas[:96, :200]
                cv2.imwrite(true_dir + file_name, canvas)
            else:
                cv2.putText(canvas, pred, (5, 90), font, fontScale, red, lineType)
                cv2.putText(canvas, expected, (5, 125), font, fontScale, green, lineType)
                cv2.imwrite(false_dir + file_name, canvas)

            print(f'{img_name}\t{pred}')