def test(opt):
    """Evaluate a saved recognition model and log its accuracy.

    Builds the label converter and model from ``opt``, loads the checkpoint at
    ``opt.saved_model`` when one is given, archives that checkpoint under
    ``./result/<opt.name>/`` and then either runs ``benchmark_all_eval`` (when
    ``opt.benchmark_all_eval`` is set) or evaluates on ``opt.eval_data`` and
    appends the accuracy to ``log_evaluation.txt``.

    Args:
        opt: option namespace. ``opt.num_class`` and ``opt.name`` are written
            by this function.
    """
    import shutil  # function-scope import: only needed to archive the checkpoint

    # --- model configuration ---
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    model = Model(opt.imgH, opt.imgW, opt.input_channel, opt.output_channel,
                  opt.hidden_size, opt.num_class, opt.batch_max_length,
                  Transformation=opt.Transformation,
                  FeatureExtraction=opt.FeatureExtraction,
                  SequenceModeling=opt.SequenceModeling,
                  Prediction=opt.Prediction)
    print('model input parameters', opt.imgH, opt.imgW, opt.input_channel,
          opt.output_channel, opt.hidden_size, opt.num_class,
          opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).cuda()

    # --- load model ---
    if opt.saved_model != '':
        print('loading pretrained model from %s' % opt.saved_model)
        model.load_state_dict(torch.load(opt.saved_model))
    # Result folder name: checkpoint path without its first component.
    opt.name = '_'.join(opt.saved_model.split('/')[1:])

    # --- keep evaluation model and result logs ---
    result_dir = f'./result/{opt.name}'
    os.makedirs(result_dir, exist_ok=True)
    if opt.saved_model != '':
        # shutil.copy avoids building a shell command from the path (breaks on
        # spaces / shell metacharacters, and `cp` is not portable).
        try:
            shutil.copy(opt.saved_model, result_dir)
        except OSError as err:
            # Archiving stays best-effort: a failed copy must not abort evaluation.
            print(f'could not copy {opt.saved_model} to {result_dir}: {err}')

    # --- setup loss ---
    if 'CTC' in opt.Prediction:
        criterion = CTCLoss(reduction='sum')
    else:
        # ignore [GO] token = ignore index 0
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda()

    # --- evaluation ---
    model.eval()
    # No gradients are needed for evaluation; this matches the other
    # evaluation entry points in this file.
    with torch.no_grad():
        if opt.benchmark_all_eval:
            # evaluation with 10 benchmark evaluation datasets
            benchmark_all_eval(model, criterion, converter, opt)
        else:
            AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW)
            eval_data = hierarchical_dataset(root=opt.eval_data, opt=opt)
            evaluation_loader = torch.utils.data.DataLoader(
                eval_data, batch_size=opt.batch_size,
                shuffle=False,
                num_workers=int(opt.workers),
                collate_fn=AlignCollate_evaluation, pin_memory=True)
            _, accuracy_by_best_model, _, _, _, _, _ = validation(
                model, criterion, evaluation_loader, converter, opt)

            print(accuracy_by_best_model)
            with open('./result/{0}/log_evaluation.txt'.format(opt.name), 'a') as log:
                log.write(str(accuracy_by_best_model) + '\n')
def test(opt):
    """Evaluate the checkpoint at ``opt.saved_model`` and log its accuracy.

    Picks a label converter from ``opt.Prediction`` / ``opt.SequenceModeling``,
    builds the model, loads the checkpoint, archives it under
    ``./result/<opt.experiment_name>/`` and then either runs
    ``benchmark_all_eval`` or evaluates on ``opt.eval_data``, appending the
    dataset log and accuracy to ``log_evaluation.txt``.

    Args:
        opt: option namespace. ``opt.num_class``, ``opt.experiment_name`` and
            (when ``opt.rgb`` is set) ``opt.input_channel`` are written here.

    Raises:
        ValueError: if no label converter matches the configured prediction /
            sequence-modeling modules.
    """
    import shutil  # function-scope import: only needed to archive the checkpoint

    # --- model configuration ---
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    elif 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    elif ('Transformer' in opt.Prediction or 'Test' in opt.Prediction
          or 'Transformer' in opt.SequenceModeling):
        converter = TransformerLabelConverter(opt.character)
    else:
        # Without this branch `converter` would be unbound and fail later with
        # an UnboundLocalError that hides the real configuration problem.
        raise ValueError(
            f'no label converter for Prediction={opt.Prediction!r}, '
            f'SequenceModeling={opt.SequenceModeling!r}')
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # --- load model ---
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))
    # Result folder name: checkpoint path without its first component.
    opt.experiment_name = '_'.join(opt.saved_model.split('/')[1:])

    # --- keep evaluation model and result logs ---
    result_dir = f'./result/{opt.experiment_name}'
    os.makedirs(result_dir, exist_ok=True)
    # shutil.copy avoids building a shell command from the path (breaks on
    # spaces / shell metacharacters, and `cp` is not portable).
    try:
        shutil.copy(opt.saved_model, result_dir)
    except OSError as err:
        # Archiving stays best-effort: a failed copy must not abort evaluation.
        print(f'could not copy {opt.saved_model} to {result_dir}: {err}')

    # --- setup loss ---
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        # Index 0 is excluded from the loss.
        # NOTE(review): the other evaluation functions describe index 0 as the
        # [GO] token; confirm that also holds for TransformerLabelConverter.
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    # --- evaluation ---
    model.eval()
    with torch.no_grad():
        if opt.benchmark_all_eval:
            # evaluation with 10 benchmark evaluation datasets
            benchmark_all_eval(model, criterion, converter, opt)
        else:
            # The context manager closes the log even if evaluation raises.
            with open(f'{result_dir}/log_evaluation.txt', 'a') as log:
                AlignCollate_evaluation = AlignCollate(
                    imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
                eval_data, eval_data_log = hierarchical_dataset(
                    root=opt.eval_data, opt=opt)
                evaluation_loader = torch.utils.data.DataLoader(
                    eval_data, batch_size=opt.batch_size,
                    shuffle=False,
                    num_workers=int(opt.workers),
                    collate_fn=AlignCollate_evaluation, pin_memory=True)
                _, accuracy_by_best_model, _, _, _, _, _, _ = validation(
                    model, criterion, evaluation_loader, converter, opt)
                log.write(eval_data_log)
                print(f'{accuracy_by_best_model:0.3f}')
                log.write(f'{accuracy_by_best_model:0.3f}\n')
def dataloader(self, opt):
    """Create the source/target training datasets and the validation loader.

    Args:
        opt: options supplying the ``src_*`` / ``tar_*`` data settings plus
            ``valid_data``, ``imgH``, ``imgW``, ``PAD``, ``batch_size`` and
            ``workers``.

    Returns:
        Tuple ``(src_train_dataset, tar_train_dataset, valid_loader)``.
    """
    # One balanced dataset per domain, each configured from its own option trio.
    source_set = Batch_Balanced_Dataset(
        opt, opt.src_train_data, opt.src_select_data, opt.src_batch_ratio)
    target_set = Batch_Balanced_Dataset(
        opt, opt.tar_train_data, opt.tar_select_data, opt.tar_batch_ratio)

    collate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                           keep_ratio_with_pad=opt.PAD)
    validation_set = hierarchical_dataset(root=opt.valid_data, opt=opt)
    loader_options = {
        'batch_size': opt.batch_size,
        # 'True' to check training progress with validation function.
        'shuffle': True,
        'num_workers': int(opt.workers),
        'collate_fn': collate,
        'pin_memory': True,
    }
    validation_loader = torch.utils.data.DataLoader(validation_set,
                                                    **loader_options)

    return source_set, target_set, validation_loader
def load_model(self, root_dir):
    """Load the recognition model for ``self.langs`` from ``root_dir``.

    Sets ``self.converter``, ``self.AlignCollate_demo`` and ``self.model``
    (switched to eval mode), and updates ``self.opt`` with ``num_class``,
    ``saved_model`` and, when ``opt.rgb`` is set, ``input_channel``.

    Args:
        root_dir: directory containing one sub-folder per language set; the
            sub-folder name is the sorted languages joined with ``_`` and it
            must contain ``best_accuracy.pth``.

    Raises:
        SystemExit: if the model folder for the language set does not exist.
    """
    opt = self.opt

    # --- Load Prediction component ---
    if 'CTC' in opt.Prediction:
        from libs.clova_ai_recognition.utils import CTCLabelConverter
        self.converter = CTCLabelConverter(opt.character)
    else:
        from libs.clova_ai_recognition.utils import AttnLabelConverter
        self.converter = AttnLabelConverter(opt.character)
    opt.num_class = len(self.converter.character)

    # --- Data loader ---
    from dataset import AlignCollate
    self.AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)

    if opt.rgb:
        opt.input_channel = 3

    # --- Load model ---
    model_dir = os.path.join(root_dir, '_'.join(sorted(self.langs)))
    if not os.path.isdir(model_dir):
        # Raise SystemExit directly: the `exit()` builtin is injected by the
        # `site` module for interactive use and is not guaranteed to exist.
        raise SystemExit('ERROR: Model folder %s not found.' % model_dir)
    opt.saved_model = os.path.join(model_dir, 'best_accuracy.pth')

    from model import Model
    self.model = torch.nn.DataParallel(Model(opt)).to(self.device)
    self.model.load_state_dict(
        torch.load(opt.saved_model, map_location=self.device))
    self.model.eval()
def predictAllImagesInFolder(self, src_path):
    """Run recognition on every image found under ``src_path``.

    Args:
        src_path: root folder handed to ``RawDataset``.

    Returns:
        List of ``"<image file name>,<predicted text>"`` strings, one per
        image, in loader order.
    """
    opt = self.opts
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=src_path, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo,
        pin_memory=torch.cuda.is_available())

    prune_at_eos = 'Attn' in opt.Prediction
    results = []
    for image_tensors, image_path_list in demo_loader:
        preds_str = self.predict(image_tensors)
        for img_name, pred in zip(image_path_list, preds_str):
            if prune_at_eos:
                # Prune after the "end of sentence" token ([s]). partition()
                # leaves the text untouched when the token is absent, whereas
                # slicing with find() == -1 would drop the last character.
                pred = pred.partition('[s]')[0]
            results.append(f'{os.path.basename(img_name)},{pred}')
    return results
def init_model(args):
    """Build the label converter, the loaded model and the demo collate function.

    Args:
        args: option namespace. ``args.num_class`` and, when ``args.rgb`` is
            set, ``args.input_channel`` are written here.

    Returns:
        Tuple ``(converter, model, AlignCollate_demo)``; the model has the
        weights from ``args.saved_model`` loaded and is in eval mode.

    Raises:
        RuntimeError: carrying ``device`` when the model cannot be moved to it.
    """
    # --- model configuration ---
    converter_cls = (CTCLabelConverter if 'CTC' in args.Prediction
                     else AttnLabelConverter)
    converter = converter_cls(args.character)
    args.num_class = len(converter.character)
    if args.rgb:
        args.input_channel = 3

    net = Model(args)
    print('model input parameters', args.imgH, args.imgW, args.num_fiducial,
          args.input_channel, args.output_channel, args.hidden_size,
          args.num_class, args.batch_max_length, args.Transformation,
          args.FeatureExtraction, args.SequenceModeling, args.Prediction)
    try:
        net = torch.nn.DataParallel(net).to(device)
    except RuntimeError:
        # Re-raised with the device so the failing target is visible.
        raise RuntimeError(device)

    # --- load model ---
    print('loading pretrained model from %s' % args.saved_model)
    state_dict = torch.load(args.saved_model, map_location=device)
    net.load_state_dict(state_dict)

    # Collate function the caller uses to batch demo images.
    collate = AlignCollate(imgH=args.imgH, imgW=args.imgW,
                           keep_ratio_with_pad=args.PAD)
    net.eval()
    return converter, net, collate
def dataset_preparation(opt):
    """Build the balanced training dataset and the validation loader.

    Splits ``opt.select_data`` and ``opt.batch_ratio`` on ``-`` (rewriting them
    in place as lists) and appends the validation dataset log to
    ``./saved_models/<opt.experiment_name>/log_dataset.txt``.

    Args:
        opt: option namespace; ``select_data`` and ``batch_ratio`` are expected
            as ``-``-separated strings on entry.

    Returns:
        Tuple ``(train_dataset, valid_loader)``.
    """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager closes the log even if building the validation
    # dataset or loader raises.
    with open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    return train_dataset, valid_loader
def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=False):
    """Evaluate ``model`` on the 10 benchmark evaluation datasets.

    Each dataset is read from ``<opt.eval_data>/<name>``. Per-dataset accuracy
    and normalized edit distance, plus a final summary line, are printed and
    appended to ``./result/<opt.exp_name>/log_all_evaluation.txt``.

    Args:
        model: model passed through to ``validation``.
        criterion: loss passed through to ``validation``.
        converter: label converter passed through to ``validation``.
        opt: option namespace (``eval_data``, ``exp_name``, ``imgH``, ``imgW``,
            ``PAD``, ``batch_size``, ``workers``).
        calculate_infer_time: when true, evaluate with batch size 1 so the
            measured time is per image.

    Returns:
        None.
    """
    # The evaluation datasets, dataset order is same with Table 1 in our paper.
    eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857',
                      'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
    # To easily compute the total accuracy of our paper, the subset is:
    # ['IIIT5k_3000', 'SVT', 'IC03_867', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80']

    # batch_size should be 1 to calculate the GPU inference time per image.
    evaluation_batch_size = 1 if calculate_infer_time else opt.batch_size

    list_accuracy = []
    total_forward_time = 0
    total_evaluation_data_number = 0
    total_correct_number = 0
    dashed_line = '-' * 80

    # The context manager closes the log even if an evaluation step raises.
    with open(f'./result/{opt.exp_name}/log_all_evaluation.txt', 'a') as log:
        print(dashed_line)
        log.write(dashed_line + '\n')

        for dataset_name in eval_data_list:
            eval_data_path = os.path.join(opt.eval_data, dataset_name)
            AlignCollate_evaluation = AlignCollate(
                imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
            # Distinct name for the dataset object so the loop variable (the
            # dataset name) is not rebound mid-iteration.
            dataset, dataset_log = hierarchical_dataset(root=eval_data_path, opt=opt)
            evaluation_loader = torch.utils.data.DataLoader(
                dataset, batch_size=evaluation_batch_size,
                shuffle=False,
                num_workers=int(opt.workers),
                collate_fn=AlignCollate_evaluation, pin_memory=True)

            (_, accuracy_by_best_model, norm_ED_by_best_model, _, _, _,
             infer_time, length_of_data) = validation(
                model, criterion, evaluation_loader, converter, opt)

            list_accuracy.append(f'{accuracy_by_best_model:0.3f}')
            total_forward_time += infer_time
            total_evaluation_data_number += len(dataset)
            # Weight each dataset's accuracy by the number of samples that
            # validation() reports having evaluated.
            total_correct_number += accuracy_by_best_model * length_of_data

            result_line = (f'Acc {accuracy_by_best_model:0.3f}\t '
                           f'normalized_ED {norm_ED_by_best_model:0.3f}')
            log.write(dataset_log)
            print(result_line)
            log.write(result_line + '\n')
            print(dashed_line)
            log.write(dashed_line + '\n')

        # Per-sample forward time scaled by 1000.
        # NOTE(review): presumably milliseconds; confirm the unit of infer_time.
        averaged_forward_time = total_forward_time / total_evaluation_data_number * 1000
        total_accuracy = total_correct_number / total_evaluation_data_number
        params_num = sum(np.prod(p.size()) for p in model.parameters())

        per_dataset = ''.join(
            f'{name}: {accuracy}\t'
            for name, accuracy in zip(eval_data_list, list_accuracy))
        evaluation_log = (
            'accuracy: ' + per_dataset
            + f'total_accuracy: {total_accuracy:0.3f}\t'
            + f'averaged_infer_time: {averaged_forward_time:0.3f}\t'
              f'# parameters: {params_num/1e6:0.3f}')
        print(evaluation_log)
        log.write(evaluation_log + '\n')

    return None
def predict(images_arr):
    """Recognise the text in each image crop.

    Uses the module-level ``opt``, ``model``, ``converter`` and ``device``.

    Args:
        images_arr: iterable of arrays accepted by ``Image.fromarray``.

    Returns:
        List of predicted strings, one per crop across all batches, with
        ``"@"`` replaced by ``"-"``. Empty when there are no crops.
    """
    images = [Image.fromarray(crop) for crop in images_arr]
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = np_RawDataset(images=images, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # Collected across ALL batches. Creating this list inside the loop would
    # return only the last batch and leave it undefined for an empty loader.
    out = []
    prune_at_eos = 'Attn' in opt.Prediction

    model.eval()
    with torch.no_grad():
        for image_tensors, _ in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor(
                [opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(
                batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)
                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = model(image, text_for_pred, is_train=False)
                # Select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            for pred in preds_str:
                if prune_at_eos:
                    # Prune after the "end of sentence" token ([s]);
                    # partition() keeps the whole text when the token is
                    # absent instead of dropping the last character.
                    pred = pred.partition('[s]')[0]
                out.append(pred.replace("@", "-"))
    return out
def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=False):
    """Run ``validation`` on each of the 10 benchmark datasets and log a summary.

    Datasets are read from ``<opt.eval_data>/<name>``. One summary line with
    the per-dataset accuracies, the averaged forward time and the parameter
    count is printed and appended to
    ``./result/<opt.experiment_name>/log_all_evaluation.txt``.

    Args:
        model, criterion, converter: passed through to ``validation``.
        opt: option namespace (``eval_data``, ``experiment_name``, ``imgH``,
            ``imgW``, ``batch_size``, ``workers``).
        calculate_infer_time: when true, evaluate with batch size 1.

    Returns:
        None.
    """
    # The evaluation datasets, dataset order is same with Table 1 in our paper.
    eval_data_list = [
        'IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857',
        'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80',
    ]

    # batch_size should be 1 to calculate the GPU inference time per image.
    evaluation_batch_size = 1 if calculate_infer_time else opt.batch_size

    separator = '-' * 80
    accuracy_strings = []
    forward_time_sum = 0
    sample_count = 0

    print(separator)
    for dataset_name in eval_data_list:
        dataset_root = os.path.join(opt.eval_data, dataset_name)
        collate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW)
        dataset = hierarchical_dataset(root=dataset_root, opt=opt)
        print(separator)
        sample_count += len(dataset)

        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=evaluation_batch_size,
            shuffle=False,
            num_workers=int(opt.workers),
            collate_fn=collate,
            pin_memory=True,
        )
        _, accuracy, _, _, _, elapsed = validation(
            model, criterion, loader, converter, opt)
        forward_time_sum += elapsed
        accuracy_strings.append(f'{accuracy:0.3f}')

    averaged_forward_time = forward_time_sum / sample_count * 1000
    params_num = sum([np.prod(p.size()) for p in model.parameters()])

    per_dataset = ''.join(
        f'{name}: {accuracy}\t'
        for name, accuracy in zip(eval_data_list, accuracy_strings))
    evaluation_log = (
        'accuracy: ' + per_dataset
        + f'averaged_infer_time: {averaged_forward_time:0.3f}, # parameters: {params_num/1e6:0.3f}')
    print(evaluation_log)

    log_path = f'./result/{opt.experiment_name}/log_all_evaluation.txt'
    with open(log_path, 'a') as log:
        log.write(evaluation_log + '\n')

    return None
def main(self, image, image_path, bboxes):
    """Extract text for the given bounding boxes of one image.

    Wraps ``image`` / ``image_path`` / ``bboxes`` in ``self.cropedDataset``,
    batches every item of that dataset into a single batch and hands the
    loader to ``self.extract_text``.

    Returns:
        Whatever ``self.extract_text`` returns for the loader.
    """
    collate = AlignCollate(imgH=self.imgH, imgW=self.imgW,
                           keep_ratio_with_pad=self.PAD)
    crops = self.cropedDataset(image=image, image_path=image_path,
                               bboxes=bboxes)
    # batch_size == len(dataset): the whole dataset goes through in one batch.
    loader = torch.utils.data.DataLoader(
        dataset=crops,
        batch_size=len(crops),
        shuffle=False,
        num_workers=int(self.workers),
        collate_fn=collate,
        pin_memory=True,
    )
    return self.extract_text(loader)
def predict(self, img):
    """Recognise the text in a single image.

    Args:
        img: image accepted by ``transforms.ToPILImage()``.

    Returns:
        Tuple ``(pred, confidence_score)``: the decoded string and the product
        of the per-step maximum probabilities as a 0-dim tensor. The score is
        0 when the prediction is empty after pruning at ``[s]``.
    """
    self.eval()
    batch_size = 1
    with torch.no_grad():
        AlignCollate_demo = AlignCollate(imgH=self.opt.imgH, imgW=self.opt.imgW,
                                         keep_ratio_with_pad=self.opt.PAD)
        transform_PIL = transforms.ToPILImage()
        image = [transform_PIL(img)]
        image = AlignCollate_demo((image, ""))[0]
        # NOTE(review): unlike the tensors below, `image` is never moved to
        # `device` here; confirm that is intended.

        # For max length prediction
        length_for_pred = torch.IntTensor(
            [self.opt.batch_max_length] * batch_size).to(device)
        text_for_pred = torch.LongTensor(
            batch_size, self.opt.batch_max_length + 1).fill_(0).to(device)

        # NOTE(review): these two lines look like debug output; confirm they
        # are still wanted before relying on this in a headless environment.
        print(image.shape)
        cv2.imshow("", tensor2im(image[0]))

        if 'CTC' in self.opt.Prediction:
            preds = self(image, text_for_pred).log_softmax(2)
            # Select max probability (greedy decoding) then decode index to character
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            _, preds_index = preds.max(2)
            preds_index = preds_index.view(-1)
            preds_str = self.converter.decode(preds_index.data, preds_size.data)
        else:
            preds = self(image, text_for_pred, is_train=False)
            # Select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = self.converter.decode(preds_index, length_for_pred)

        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)
        pred_max_prob = preds_max_prob[0]
        pred = preds_str[0]

        if 'Attn' in self.opt.Prediction:
            pred_EOS = pred.find('[s]')
            # find() returns -1 when the token is absent; slicing with -1
            # would silently drop the last character and probability.
            if pred_EOS != -1:
                # prune after "end of sentence" token ([s])
                pred = pred[:pred_EOS]
                pred_max_prob = pred_max_prob[:pred_EOS]

        if pred_max_prob.numel() == 0:
            # Empty prediction (e.g. [s] predicted first): indexing the
            # cumulative product with [-1] would raise IndexError.
            confidence_score = pred_max_prob.new_zeros(())
        else:
            confidence_score = pred_max_prob.cumprod(dim=0)[-1]
    return pred, confidence_score
def train():
    """Assemble model, losses, optimizer and data loader, then run ``train_model``.

    Dataset locations and annotation settings (``list_path``, ``img_root``,
    ``number_of_instances``, ``semantic_ann_npy``, ``instances_ann_npy``,
    ``x_transform``) are read from module scope.
    """
    net = ReSeg(2, pretrained=False, use_coordinates=True, usegpu=False)
    dice_criterion = DiceLoss()
    discriminative_criterion = DiscriminativeLoss(0.5, 1.5, 2)
    adam = optim.Adam(net.parameters())

    dataset_options = dict(
        list_path=list_path,
        img_root=img_root,
        height=448,
        width=448,
        number_of_instances=number_of_instances,
        semantic_ann_npy=semantic_ann_npy,
        instances_ann_npy=instances_ann_npy,
        transform=x_transform,
    )
    dataset = SegDataset(**dataset_options)
    collate = AlignCollate(2, 100, 448, 448)
    # The loader feeds one sample per batch.
    loader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                         collate_fn=collate)

    train_model(net, dice_criterion, discriminative_criterion, adam, loader)
def original_demo(model, converter, length_for_pred, text_for_pred):
    """Run the model on every image in ``opt['image_folder']`` and print the results.

    Options come from ``option()``, which is indexed like a dict here.

    Args:
        model: recognition model, called as ``model(image, text)``.
        converter: label converter providing ``decode``.
        length_for_pred: per-sample prediction lengths prepared by the caller.
        text_for_pred: decoder input prepared by the caller.
            NOTE(review): both are assumed to be batch-first with at least
            ``opt['batch_size']`` entries; confirm against the caller.
    """
    opt = option()
    AlignCollate_demo = AlignCollate(imgH=opt['imgH'], imgW=opt['imgW'],
                                     keep_ratio_with_pad=opt['PAD'])
    demo_data = RawDataset(root=opt['image_folder'], opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt['batch_size'],
        shuffle=False,
        num_workers=int(opt['workers']),
        collate_fn=AlignCollate_demo, pin_memory=True)
    print(demo_loader)

    # predict
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction. The last batch can be smaller than
            # opt['batch_size'], so trim the caller-supplied tensors to the
            # actual batch; this is a no-op for full batches.
            batch_length = length_for_pred[:batch_size]
            batch_text = text_for_pred[:batch_size]

            if 'CTC' == opt['Prediction']:
                print('kotti')  # debug trace kept from the original code
                preds = model(image, batch_text).log_softmax(2)

                # Select max probability, then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.permute(1, 0, 2).max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = model(image, batch_text, is_train=False)

                # Select max probability, then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, batch_length)

            print('-' * 80)
            print('image_path\tpredicted_labels')
            print('-' * 80)
            for img_name, pred in zip(image_path_list, preds_str):
                if 'Attn' == opt['Prediction']:
                    # Prune after the "end of sentence" token ([s]);
                    # partition() keeps the whole text when the token is
                    # absent instead of dropping the last character.
                    pred = pred.partition('[s]')[0]

                print(f'{img_name}\t{pred}')
def test(opt):
    """Evaluate the checkpoint at ``opt.saved_model`` and log its accuracy.

    Obtains the converter and model from ``model_configuration(opt)``,
    archives the checkpoint under ``./result/<opt.experiment_name>/`` and then
    either runs ``benchmark_all_eval`` or evaluates on ``opt.eval_data``,
    appending the dataset log and accuracy to ``log_evaluation.txt``.

    Args:
        opt: option namespace. ``opt.experiment_name`` is written here.
    """
    import shutil  # function-scope import: only needed to archive the checkpoint

    # --- model configuration ---
    converter, model = model_configuration(opt)
    # Result folder name: checkpoint path without its first component.
    opt.experiment_name = '_'.join(opt.saved_model.split('/')[1:])

    # --- keep evaluation model and result logs ---
    result_dir = f'./result/{opt.experiment_name}'
    os.makedirs(result_dir, exist_ok=True)
    # shutil.copy avoids building a shell command from the path (breaks on
    # spaces / shell metacharacters, and `cp` is not portable).
    try:
        shutil.copy(opt.saved_model, result_dir)
    except OSError as err:
        # Archiving stays best-effort: a failed copy must not abort evaluation.
        print(f'could not copy {opt.saved_model} to {result_dir}: {err}')

    # --- setup loss ---
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        # ignore [GO] token = ignore index 0
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    # --- evaluation ---
    model.eval()
    with torch.no_grad():
        if opt.benchmark_all_eval:
            # evaluation with 10 benchmark evaluation datasets
            benchmark_all_eval(model, criterion, converter, opt)
        else:
            # The context manager closes the log even if evaluation raises.
            with open(f'{result_dir}/log_evaluation.txt', 'a') as log:
                AlignCollate_evaluation = AlignCollate(
                    imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
                eval_data, eval_data_log = hierarchical_dataset(
                    root=opt.eval_data, opt=opt)
                evaluation_loader = torch.utils.data.DataLoader(
                    eval_data, batch_size=opt.batch_size,
                    shuffle=False,
                    num_workers=int(opt.workers),
                    collate_fn=AlignCollate_evaluation, pin_memory=True)
                _, accuracy_by_best_model, _, _, _, _, _, _ = validation(
                    model, criterion, evaluation_loader, converter, opt)
                log.write(eval_data_log)
                print(f'{accuracy_by_best_model:0.3f}')
                log.write(f'{accuracy_by_best_model:0.3f}\n')
def train(opt): plotDir = os.path.join(opt.exp_dir,opt.exp_name,'plots') if not os.path.exists(plotDir): os.makedirs(plotDir) lib.print_model_settings(locals().copy()) """ dataset preparation """ if not opt.data_filtering_off: print('Filtering the images containing characters which are not in opt.character') print('Filtering the images whose label is longer than opt.batch_max_length') # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 opt.select_data = opt.select_data.split('-') opt.batch_ratio = opt.batch_ratio.split('-') #considering the real images for discriminator opt.batch_size = opt.batch_size*2 train_dataset = Batch_Balanced_Dataset(opt) log = open(os.path.join(opt.exp_dir,opt.exp_name,'log_dataset.txt'), 'a') AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=opt.batch_size, shuffle=False, # 'True' to check training progress with validation function. 
num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True) log.write(valid_dataset_log) print('-' * 80) log.write('-' * 80 + '\n') log.close() """ model configuration """ if 'CTC' in opt.Prediction: converter = CTCLabelConverter(opt.character) else: converter = AttnLabelConverter(opt.character) opt.num_class = len(converter.character) if opt.rgb: opt.input_channel = 3 model = AdaINGen(opt) ocrModel = Model(opt) disModel = MsImageDisV1(opt) print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction) # weight initialization for currModel in [model, ocrModel, disModel]: for name, param in currModel.named_parameters(): if 'localization_fc2' in name: print(f'Skip {name} as it is already initialized') continue try: if 'bias' in name: init.constant_(param, 0.0) elif 'weight' in name: init.kaiming_normal_(param) except Exception as e: # for batchnorm. 
if 'weight' in name: param.data.fill_(1) continue # data parallel for multi-GPU ocrModel = torch.nn.DataParallel(ocrModel).to(device) if not opt.ocrFixed: ocrModel.train() else: ocrModel.module.Transformation.eval() ocrModel.module.FeatureExtraction.eval() ocrModel.module.AdaptiveAvgPool.eval() # ocrModel.module.SequenceModeling.eval() ocrModel.module.Prediction.eval() model = torch.nn.DataParallel(model).to(device) model.train() disModel = torch.nn.DataParallel(disModel).to(device) disModel.train() if opt.modelFolderFlag: if len(glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth")))>0: opt.saved_synth_model = glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth"))[-1] if len(glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_dis.pth")))>0: opt.saved_dis_model = glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_dis.pth"))[-1] #loading pre-trained model if opt.saved_ocr_model != '' and opt.saved_ocr_model != 'None': print(f'loading pretrained ocr model from {opt.saved_ocr_model}') if opt.FT: ocrModel.load_state_dict(torch.load(opt.saved_ocr_model), strict=False) else: ocrModel.load_state_dict(torch.load(opt.saved_ocr_model)) print("OCRModel:") print(ocrModel) if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': print(f'loading pretrained synth model from {opt.saved_synth_model}') if opt.FT: model.load_state_dict(torch.load(opt.saved_synth_model), strict=False) else: model.load_state_dict(torch.load(opt.saved_synth_model)) print("SynthModel:") print(model) if opt.saved_dis_model != '' and opt.saved_dis_model != 'None': print(f'loading pretrained discriminator model from {opt.saved_dis_model}') if opt.FT: disModel.load_state_dict(torch.load(opt.saved_dis_model), strict=False) else: disModel.load_state_dict(torch.load(opt.saved_dis_model)) print("DisModel:") print(disModel) """ setup loss """ if 'CTC' in opt.Prediction: ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device) else: ocrCriterion = 
torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 recCriterion = torch.nn.L1Loss() styleRecCriterion = torch.nn.L1Loss() # loss averager loss_avg_ocr = Averager() loss_avg = Averager() loss_avg_dis = Averager() loss_avg_ocrRecon_1 = Averager() loss_avg_ocrRecon_2 = Averager() loss_avg_gen = Averager() loss_avg_imgRecon = Averager() loss_avg_styRecon = Averager() ##---------------------------------------## # filter that only require gradient decent filtered_parameters = [] params_num = [] for p in filter(lambda p: p.requires_grad, model.parameters()): filtered_parameters.append(p) params_num.append(np.prod(p.size())) print('Trainable params num : ', sum(params_num)) # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())] # setup optimizer if opt.optim=='adam': optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay) else: optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay) print("SynthOptimizer:") print(optimizer) #filter parameters for OCR training ocr_filtered_parameters = [] ocr_params_num = [] for p in filter(lambda p: p.requires_grad, ocrModel.parameters()): ocr_filtered_parameters.append(p) ocr_params_num.append(np.prod(p.size())) print('OCR Trainable params num : ', sum(ocr_params_num)) # setup optimizer if opt.optim=='adam': ocr_optimizer = optim.Adam(ocr_filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay) else: ocr_optimizer = optim.Adadelta(ocr_filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay) print("OCROptimizer:") print(ocr_optimizer) #filter parameters for OCR training dis_filtered_parameters = [] dis_params_num = [] for p in filter(lambda p: p.requires_grad, disModel.parameters()): dis_filtered_parameters.append(p) dis_params_num.append(np.prod(p.size())) print('Dis 
Trainable params num : ', sum(dis_params_num)) # setup optimizer if opt.optim=='adam': dis_optimizer = optim.Adam(dis_filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay) else: dis_optimizer = optim.Adadelta(dis_filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay) print("DisOptimizer:") print(dis_optimizer) ##---------------------------------------## """ final options """ with open(os.path.join(opt.exp_dir,opt.exp_name,'opt.txt'), 'a') as opt_file: opt_log = '------------ Options -------------\n' args = vars(opt) for k, v in args.items(): opt_log += f'{str(k)}: {str(v)}\n' opt_log += '---------------------------------------\n' print(opt_log) opt_file.write(opt_log) """ start training """ start_iter = 0 if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': try: start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0]) print(f'continue to train, start_iter: {start_iter}') except: pass #get schedulers scheduler = get_scheduler(optimizer,opt) ocr_scheduler = get_scheduler(ocr_optimizer,opt) dis_scheduler = get_scheduler(dis_optimizer,opt) start_time = time.time() best_accuracy = -1 best_norm_ED = -1 best_accuracy_ocr = -1 best_norm_ED_ocr = -1 iteration = start_iter cntr=0 while(True): # train part if opt.lr_policy !="None": scheduler.step() ocr_scheduler.step() dis_scheduler.step() image_tensors_all, labels_1_all, labels_2_all = train_dataset.get_batch() # ## comment # pdb.set_trace() # for imgCntr in range(image_tensors.shape[0]): # save_image(tensor2im(image_tensors[imgCntr]),'temp/'+str(imgCntr)+'.png') # pdb.set_trace() # ### # print(cntr) cntr+=1 disCnt = int(image_tensors_all.size(0)/2) image_tensors, image_tensors_real, labels_gt, labels_2 = image_tensors_all[:disCnt], image_tensors_all[disCnt:disCnt+disCnt], labels_1_all[:disCnt], labels_2_all[:disCnt] image = image_tensors.to(device) image_real = image_tensors_real.to(device) batch_size = image.size(0) 
##-----------------------------------## #generate text(labels) from ocr.forward if opt.ocrFixed: # ocrModel.eval() length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) if 'CTC' in opt.Prediction: preds = ocrModel(image, text_for_pred) preds = preds[:, :text_for_loss.shape[1] - 1, :] preds_size = torch.IntTensor([preds.size(1)] * batch_size) _, preds_index = preds.max(2) labels_1 = converter.decode(preds_index.data, preds_size.data) else: preds = ocrModel(image, text_for_pred, is_train=False) _, preds_index = preds.max(2) labels_1 = converter.decode(preds_index, length_for_pred) for idx, pred in enumerate(labels_1): pred_EOS = pred.find('[s]') labels_1[idx] = pred[:pred_EOS] # prune after "end of sentence" token ([s]) # ocrModel.train() else: labels_1 = labels_gt ##-----------------------------------## text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length) text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length) #forward pass from style and word generator images_recon_1, images_recon_2, style = model(image, text_1, text_2) if 'CTC' in opt.Prediction: if not opt.ocrFixed: #ocr training with orig image preds_ocr = ocrModel(image, text_1) preds_size_ocr = torch.IntTensor([preds_ocr.size(1)] * batch_size) preds_ocr = preds_ocr.log_softmax(2).permute(1, 0, 2) ocrCost_train = ocrCriterion(preds_ocr, text_1, preds_size_ocr, length_1) #content loss for reconstructed images preds_1 = ocrModel(images_recon_1, text_1) preds_size_1 = torch.IntTensor([preds_1.size(1)] * batch_size) preds_1 = preds_1.log_softmax(2).permute(1, 0, 2) preds_2 = ocrModel(images_recon_2, text_2) preds_size_2 = torch.IntTensor([preds_2.size(1)] * batch_size) preds_2 = preds_2.log_softmax(2).permute(1, 0, 2) ocrCost_1 = ocrCriterion(preds_1, text_1, preds_size_1, length_1) ocrCost_2 = ocrCriterion(preds_2, text_2, preds_size_2, 
length_2) # ocrCost = 0.5*( ocrCost_1 + ocrCost_2 ) else: if not opt.ocrFixed: #ocr training with orig image preds_ocr = ocrModel(image, text_1[:, :-1]) # align with Attention.forward target_ocr = text_1[:, 1:] # without [GO] Symbol ocrCost_train = ocrCriterion(preds_ocr.view(-1, preds_ocr.shape[-1]), target_ocr.contiguous().view(-1)) #content loss for reconstructed images preds_1 = ocrModel(images_recon_1, text_1[:, :-1], is_train=False) # align with Attention.forward target_1 = text_1[:, 1:] # without [GO] Symbol preds_2 = ocrModel(images_recon_2, text_2[:, :-1], is_train=False) # align with Attention.forward target_2 = text_2[:, 1:] # without [GO] Symbol ocrCost_1 = ocrCriterion(preds_1.view(-1, preds_1.shape[-1]), target_1.contiguous().view(-1)) ocrCost_2 = ocrCriterion(preds_2.view(-1, preds_2.shape[-1]), target_2.contiguous().view(-1)) # ocrCost = 0.5*(ocrCost_1+ocrCost_2) if not opt.ocrFixed: #training OCR ocrModel.zero_grad() ocrCost_train.backward() # torch.nn.utils.clip_grad_norm_(ocrModel.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) ocr_optimizer.step() #if ocr is fixed; ignore this loss loss_avg_ocr.add(ocrCost_train) else: loss_avg_ocr.add(torch.tensor(0.0)) #Domain discriminator: Dis update disModel.zero_grad() disCost = opt.disWeight*0.5*(disModel.module.calc_dis_loss(images_recon_1.detach(), image_real) + disModel.module.calc_dis_loss(images_recon_2.detach(), image)) disCost.backward() # torch.nn.utils.clip_grad_norm_(disModel.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) dis_optimizer.step() loss_avg_dis.add(disCost) # #[Style Encoder] + [Word Generator] update #Adversarial loss disGenCost = 0.5*(disModel.module.calc_gen_loss(images_recon_1)+disModel.module.calc_gen_loss(images_recon_2)) #Input reconstruction loss recCost = recCriterion(images_recon_1,image) #Pair style reconstruction loss if opt.styleReconWeight == 0.0: styleRecCost = torch.tensor(0.0) else: if opt.styleDetach: styleRecCost = 
styleRecCriterion(model(images_recon_2, None, None, styleFlag=True), style.detach()) else: styleRecCost = styleRecCriterion(model(images_recon_2, None, None, styleFlag=True), style) #OCR Content cost ocrCost = 0.5*(ocrCost_1+ocrCost_2) cost = opt.ocrWeight*ocrCost + opt.reconWeight*recCost + opt.disWeight*disGenCost + opt.styleReconWeight*styleRecCost model.zero_grad() ocrModel.zero_grad() disModel.zero_grad() cost.backward() # torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) optimizer.step() loss_avg.add(cost) #Individual losses loss_avg_ocrRecon_1.add(opt.ocrWeight*0.5*ocrCost_1) loss_avg_ocrRecon_2.add(opt.ocrWeight*0.5*ocrCost_2) loss_avg_gen.add(opt.disWeight*disGenCost) loss_avg_imgRecon.add(opt.reconWeight*recCost) loss_avg_styRecon.add(opt.styleReconWeight*styleRecCost) # validation part if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' #Save training images os.makedirs(os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration)), exist_ok=True) for trImgCntr in range(batch_size): try: save_image(tensor2im(image[trImgCntr].detach()),os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration),str(trImgCntr)+'_input_'+labels_gt[trImgCntr]+'.png')) save_image(tensor2im(images_recon_1[trImgCntr].detach()),os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration),str(trImgCntr)+'_recon_'+labels_1[trImgCntr]+'.png')) save_image(tensor2im(images_recon_2[trImgCntr].detach()),os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration),str(trImgCntr)+'_pair_'+labels_2[trImgCntr]+'.png')) except: print('Warning while saving training image') elapsed_time = time.time() - start_time # for log with open(os.path.join(opt.exp_dir,opt.exp_name,'log_train.txt'), 'a') as log: model.eval() ocrModel.module.Transformation.eval() ocrModel.module.FeatureExtraction.eval() ocrModel.module.AdaptiveAvgPool.eval() 
ocrModel.module.SequenceModeling.eval() ocrModel.module.Prediction.eval() disModel.eval() with torch.no_grad(): valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation_synth_lrw_res( iteration, model, ocrModel, disModel, recCriterion, styleRecCriterion, ocrCriterion, valid_loader, converter, opt) model.train() if not opt.ocrFixed: ocrModel.train() else: # ocrModel.module.Transformation.eval() # ocrModel.module.FeatureExtraction.eval() # ocrModel.module.AdaptiveAvgPool.eval() ocrModel.module.SequenceModeling.train() # ocrModel.module.Prediction.eval() disModel.train() # training loss and validation loss loss_log = f'[{iteration+1}/{opt.num_iter}] Train OCR loss: {loss_avg_ocr.val():0.5f}, Train Synth loss: {loss_avg.val():0.5f}, Train Dis loss: {loss_avg_dis.val():0.5f}, Valid OCR loss: {valid_loss[0]:0.5f}, Valid Synth loss: {valid_loss[1]:0.5f}, Valid Dis loss: {valid_loss[2]:0.5f}, Elapsed_time: {elapsed_time:0.5f}' current_model_log_ocr = f'{"Current_accuracy_OCR":17s}: {current_accuracy[0]:0.3f}, {"Current_norm_ED_OCR":17s}: {current_norm_ED[0]:0.2f}' current_model_log_1 = f'{"Current_accuracy_recon":17s}: {current_accuracy[1]:0.3f}, {"Current_norm_ED_recon":17s}: {current_norm_ED[1]:0.2f}' current_model_log_2 = f'{"Current_accuracy_pair":17s}: {current_accuracy[2]:0.3f}, {"Current_norm_ED_pair":17s}: {current_norm_ED[2]:0.2f}' #plotting lib.plot.plot(os.path.join(plotDir,'Train-OCR-Loss'), loss_avg_ocr.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-Synth-Loss'), loss_avg.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-Dis-Loss'), loss_avg_dis.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-OCR-Recon1-Loss'), loss_avg_ocrRecon_1.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-OCR-Recon2-Loss'), loss_avg_ocrRecon_2.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-Gen-Loss'), loss_avg_gen.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-ImgRecon1-Loss'), 
loss_avg_imgRecon.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-StyRecon2-Loss'), loss_avg_styRecon.val().item()) lib.plot.plot(os.path.join(plotDir,'Valid-OCR-Loss'), valid_loss[0].item()) lib.plot.plot(os.path.join(plotDir,'Valid-Synth-Loss'), valid_loss[1].item()) lib.plot.plot(os.path.join(plotDir,'Valid-Dis-Loss'), valid_loss[2].item()) lib.plot.plot(os.path.join(plotDir,'Valid-OCR-Recon1-Loss'), valid_loss[3].item()) lib.plot.plot(os.path.join(plotDir,'Valid-OCR-Recon2-Loss'), valid_loss[4].item()) lib.plot.plot(os.path.join(plotDir,'Valid-Gen-Loss'), valid_loss[5].item()) lib.plot.plot(os.path.join(plotDir,'Valid-ImgRecon1-Loss'), valid_loss[6].item()) lib.plot.plot(os.path.join(plotDir,'Valid-StyRecon2-Loss'), valid_loss[7].item()) lib.plot.plot(os.path.join(plotDir,'Orig-OCR-WordAccuracy'), current_accuracy[0]) lib.plot.plot(os.path.join(plotDir,'Recon-OCR-WordAccuracy'), current_accuracy[1]) lib.plot.plot(os.path.join(plotDir,'Pair-OCR-WordAccuracy'), current_accuracy[2]) lib.plot.plot(os.path.join(plotDir,'Orig-OCR-CharAccuracy'), current_norm_ED[0]) lib.plot.plot(os.path.join(plotDir,'Recon-OCR-CharAccuracy'), current_norm_ED[1]) lib.plot.plot(os.path.join(plotDir,'Pair-OCR-CharAccuracy'), current_norm_ED[2]) # keep best accuracy model (on valid dataset) if current_accuracy[1] > best_accuracy: best_accuracy = current_accuracy[1] torch.save(model.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_accuracy.pth')) torch.save(disModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_accuracy_dis.pth')) if current_norm_ED[1] > best_norm_ED: best_norm_ED = current_norm_ED[1] torch.save(model.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_norm_ED.pth')) torch.save(disModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_norm_ED_dis.pth')) best_model_log = f'{"Best_accuracy_Recon":17s}: {best_accuracy:0.3f}, {"Best_norm_ED_Recon":17s}: {best_norm_ED:0.2f}' # keep best accuracy model (on valid dataset) if 
current_accuracy[0] > best_accuracy_ocr: best_accuracy_ocr = current_accuracy[0] if not opt.ocrFixed: torch.save(ocrModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_accuracy_ocr.pth')) if current_norm_ED[0] > best_norm_ED_ocr: best_norm_ED_ocr = current_norm_ED[0] if not opt.ocrFixed: torch.save(ocrModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_norm_ED_ocr.pth')) best_model_log_ocr = f'{"Best_accuracy_ocr":17s}: {best_accuracy_ocr:0.3f}, {"Best_norm_ED_ocr":17s}: {best_norm_ED_ocr:0.2f}' loss_model_log = f'{loss_log}\n{current_model_log_ocr}\n{current_model_log_1}\n{current_model_log_2}\n{best_model_log_ocr}\n{best_model_log}' print(loss_model_log) log.write(loss_model_log + '\n') # show some predicted results dashed_line = '-' * 80 head = f'{"Ground Truth":32s} | {"Prediction":25s} | Confidence Score & T/F' predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n' for gt_ocr, pred_ocr, confidence_ocr, gt_1, pred_1, confidence_1, gt_2, pred_2, confidence_2 in zip(labels[0][:5], preds[0][:5], confidence_score[0][:5], labels[1][:5], preds[1][:5], confidence_score[1][:5], labels[2][:5], preds[2][:5], confidence_score[2][:5]): if 'Attn' in opt.Prediction: # gt_ocr = gt_ocr[:gt_ocr.find('[s]')] pred_ocr = pred_ocr[:pred_ocr.find('[s]')] # gt_1 = gt_1[:gt_1.find('[s]')] pred_1 = pred_1[:pred_1.find('[s]')] # gt_2 = gt_2[:gt_2.find('[s]')] pred_2 = pred_2[:pred_2.find('[s]')] predicted_result_log += f'{"ocr"}: {gt_ocr:27s} | {pred_ocr:25s} | {confidence_ocr:0.4f}\t{str(pred_ocr == gt_ocr)}\n' predicted_result_log += f'{"recon"}: {gt_1:25s} | {pred_1:25s} | {confidence_1:0.4f}\t{str(pred_1 == gt_1)}\n' predicted_result_log += f'{"pair"}: {gt_2:26s} | {pred_2:25s} | {confidence_2:0.4f}\t{str(pred_2 == gt_2)}\n' predicted_result_log += f'{dashed_line}' print(predicted_result_log) log.write(predicted_result_log + '\n') loss_avg_ocr.reset() loss_avg.reset() loss_avg_dis.reset() loss_avg_ocrRecon_1.reset() loss_avg_ocrRecon_2.reset() 
loss_avg_gen.reset() loss_avg_imgRecon.reset() loss_avg_styRecon.reset() lib.plot.flush() lib.plot.tick() # save model per 1e+5 iter. if (iteration) % 1e+5 == 0: torch.save( model.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_synth.pth')) if not opt.ocrFixed: torch.save( ocrModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_ocr.pth')) torch.save( disModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_dis.pth')) if (iteration + 1) == opt.num_iter: print('end the training') sys.exit() iteration += 1
def demoToTxt1(image_folder, saved_model, txtFile):  # sensitive
    """Recognize every image in a folder and write the predictions to a file.

    The option namespace is built with ``argparse``; ``image_folder`` and
    ``saved_model`` only provide the defaults of the corresponding flags.
    The checkpoint is loaded into a ``DataParallel``-wrapped ``Model`` and
    one tab-separated ``image path / prediction`` line per image is printed
    and written to ``txtFile``.

    Args:
        image_folder: default for ``--image_folder`` (folder of text images).
        saved_model: default for ``--saved_model`` (checkpoint path).
        txtFile: path of the output text file; it is opened in ``'w'`` mode,
            so existing content is overwritten.

    Note:
        ``parser.parse_args()`` reads ``sys.argv``, so flags given on the
        command line of the running process override the defaults above.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_folder', default=image_folder,
                        help='path to image_folder which contains text images')
    parser.add_argument('--workers', type=int,
                        help='number of data loading workers', default=4)
    parser.add_argument('--batch_size', type=int, default=100,
                        help='input batch size')
    parser.add_argument('--saved_model', default=saved_model,
                        help="path to saved_model to evaluation")
    # Data processing
    parser.add_argument('--batch_max_length', type=int, default=20,
                        help='maximum-label-length')
    parser.add_argument('--imgH', type=int, default=32,
                        help='the height of the input image')
    parser.add_argument('--imgW', type=int, default=100,
                        help='the width of the input image')
    parser.add_argument('--rgb', action='store_true', help='use rgb input')
    parser.add_argument('--character', type=str, default='0123456789',
                        help='character label')
    parser.add_argument('--sensitive', default=True,
                        help='for sensitive character mode')
    parser.add_argument('--PAD', default=False, action='store_true',
                        help='whether to keep ratio then pad for image resize')
    # Model Architecture
    parser.add_argument('--Transformation', default='TPS', type=str,
                        help='Transformation stage. None|TPS')
    parser.add_argument('--FeatureExtraction', default='ResNet', type=str,
                        help='FeatureExtraction stage. VGG|RCNN|ResNet')
    parser.add_argument('--SequenceModeling', default='BiLSTM', type=str,
                        help='SequenceModeling stage. None|BiLSTM')
    parser.add_argument('--Prediction', default='CTC', type=str,
                        help='Prediction stage. CTC|Attn')
    parser.add_argument('--num_fiducial', type=int, default=20,
                        help='number of fiducial points of TPS-STN')
    parser.add_argument('--input_channel', type=int, default=1,
                        help='the number of input channel of Feature extractor')
    parser.add_argument('--output_channel', type=int, default=512,
                        help='the number of output channel of Feature extractor')
    parser.add_argument('--hidden_size', type=int, default=256,
                        help='the size of the LSTM hidden state')

    opt = parser.parse_args()

    # vocab / character number configuration
    if opt.sensitive:
        opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
        # opt.character = string.printable[:-6]  # same with ASTER setting (use 94 char).

    cudnn.benchmark = True
    cudnn.deterministic = True
    opt.num_gpu = torch.cuda.device_count()

    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # Use one device for the model, the checkpoint and the inputs; the
    # previous hard-coded .cuda() calls failed on CPU-only machines even
    # though the model itself was only moved to CUDA when available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    # The context manager closes the output file even if inference raises,
    # and no_grad now covers the forward pass as well as the input tensors.
    with open(txtFile, 'w') as saved_file, torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor(
                [opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(
                batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred).log_softmax(2)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.permute(1, 0, 2).max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            print('-' * 80)
            print('image_path\tpredicted_labels')
            print('-' * 80)
            for img_name, pred in zip(image_path_list, preds_str):
                if 'Attn' in opt.Prediction:
                    # prune after "end of sentence" token ([s])
                    pred = pred[:pred.find('[s]')]

                print(f'{img_name}\t{pred}')
                saved_file.write(f'{img_name}\t{pred}\n')
def demo(opt):
    """Run the recognizer over ``opt.image_folder`` and print the predictions.

    Builds the label converter and ``Model`` from ``opt``, loads the
    checkpoint at ``opt.saved_model`` and, for every batch of images, prints
    a small header followed by one tab-separated ``path / prediction`` line
    per image.

    Args:
        opt: option namespace; ``opt.num_class`` is set here and
            ``opt.input_channel`` is forced to 3 when ``opt.rgb`` is set.
    """
    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    state = torch.load(opt.saved_model, map_location=device)
    model.load_state_dict(state)

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    collate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                           keep_ratio_with_pad=opt.PAD)
    dataset = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=collate,
        pin_memory=True)

    rule = '-' * 80

    # predict
    model.eval()
    with torch.no_grad():
        for batch_images, batch_paths in loader:
            n_images = batch_images.size(0)
            image = batch_images.to(device)

            # Decoder inputs sized for the maximum label length
            max_lengths = torch.IntTensor(
                [opt.batch_max_length] * n_images).to(device)
            decoder_input = torch.LongTensor(
                n_images, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                log_probs = model(image, decoder_input).log_softmax(2)

                # Greedy decoding: best class per step, flattened for the converter
                seq_lengths = torch.IntTensor([log_probs.size(1)] * n_images)
                _, best = log_probs.permute(1, 0, 2).max(2)
                best = best.transpose(1, 0).contiguous().view(-1)
                texts = converter.decode(best.data, seq_lengths.data)
            else:
                scores = model(image, decoder_input, is_train=False)

                # Greedy decoding: best class per step
                _, best = scores.max(2)
                texts = converter.decode(best, max_lengths)

            print(rule)
            print('image_path\tpredicted_labels')
            print(rule)
            for path, text in zip(batch_paths, texts):
                if 'Attn' in opt.Prediction:
                    # prune after "end of sentence" token ([s])
                    eos_at = text.find('[s]')
                    text = text[:eos_at]

                print(f'{path}\t{text}')
def _recognize_loader(model, converter, loader, opt, known_texts=None):
    """Recognize every image served by ``loader``; print and log one row each.

    For each batch a table header is printed and appended to
    ``./log_demo_result.txt``, followed by one row per image holding the
    image path, the prediction trimmed at the ``[s]`` token and its
    confidence score (product of the per-step maximum probabilities).

    Args:
        model: network called as ``model(image, text, is_train=False)``.
        converter: label converter providing ``decode``.
        loader: iterable yielding ``(image_tensors, image_path_list)``.
        opt: option namespace; ``opt.batch_max_length`` is read.
        known_texts: optional container of reference predictions. When it
            is not ``None`` (even if empty) a message saying whether each
            prediction is contained in it is printed after the row.

    Returns:
        list of the trimmed prediction strings, in loader order.
    """
    recognized = []
    dashed_line = '-' * 80
    head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'

    with torch.no_grad():
        for image_tensors, image_path_list in loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor(
                [opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(
                batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            # Attention decoder
            preds = model(image, text_for_pred, is_train=False)

            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)

            # Open in append mode so successive batches/runs accumulate; the
            # context manager closes the file even if a row fails.
            with open('./log_demo_result.txt', 'a') as log:
                print(f'{dashed_line}\n{head}\n{dashed_line}')  # table header
                log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

                for img_name, pred, pred_max_prob in zip(
                        image_path_list, preds_str, preds_max_prob):
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

                    if pred_EOS == 0:
                        # Empty prediction: there is no probability to
                        # multiply, and indexing [-1] on the empty slice
                        # would raise IndexError.
                        confidence_score = 0.0
                    else:
                        pred_max_prob = pred_max_prob[:pred_EOS]
                        # confidence score (= product of pred_max_prob)
                        confidence_score = pred_max_prob.cumprod(dim=0)[-1]

                    recognized.append(pred)
                    row = f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}'
                    print(row)
                    log.write(row + '\n')

                    if known_texts is not None:
                        # Messages are Korean for "is / is not a correct
                        # destination"; they are program output and kept as is.
                        if pred in known_texts:
                            print(pred + "은(는) 알맞은 목적지입니다.")
                        else:
                            print(pred + "은(는) 알맞은 목적지가 아닙니다.")
    return recognized


def demo(opt):
    """Recognize text in two image folders and report which texts match.

    Images under ``opt.image_folder1`` (signboard detection results) are
    recognized first and their predictions collected. Images under
    ``opt.image_folder2`` (Google Maps text detection results) are then
    recognized and, for each one, a message is printed saying whether its
    prediction also occurred in the first folder. All rows are appended to
    ``./log_demo_result.txt``.

    Args:
        opt: option namespace; ``opt.num_class`` is set here and
            ``opt.input_channel`` is forced to 3 when ``opt.rgb`` is set.
    """
    # model configuration
    converter = AttnLabelConverter(opt.character)  # attention-based decoding
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)  # Model comes from model.py
    # print the parameter values
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)  # data-parallel wrapper on `device`

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    # load the model parameters
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data1 = RawDataset(root=opt.image_folder1, opt=opt)  # signboard detection results
    demo_data2 = RawDataset(root=opt.image_folder2, opt=opt)  # Google Maps text detection results
    demo_loader1 = torch.utils.data.DataLoader(
        demo_data1, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)
    demo_loader2 = torch.utils.data.DataLoader(
        demo_data2, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    # Texts recognized in the images assumed to show the destination.
    destination_texts = _recognize_loader(model, converter, demo_loader1, opt)
    # A set gives constant-time membership tests for the second pass.
    _recognize_loader(model, converter, demo_loader2, opt,
                      known_texts=set(destination_texts))
def train(opt):
    """Train the recognizer with a combined CTC + attention loss.

    Every iteration draws a batch from ``Batch_Balanced_Dataset``, runs the
    model (which returns a list of attention predictions and a ``refiner``
    output), and minimizes ``opt.lambda_ctc * CTC(refiner)`` plus
    ``opt.lambda_attn * CrossEntropy(pred)`` summed over the attention
    predictions. Every ``opt.valInterval`` iterations (and at iteration 0)
    the model is validated, losses are written to TensorBoard and
    ``./saved_models/<exp_name>/log_train.txt``, and the best checkpoints
    are saved. The process exits through ``sys.exit()`` once
    ``opt.num_iter`` iterations are reached.

    Args:
        opt: option namespace. It is mutated: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``, ``num_class`` is set,
            ``input_channel`` is forced to 3 for RGB, and ``grad_clip`` is
            lowered as the running loss drops.
    """
    os.makedirs(opt.log, exist_ok=True)
    writer = SummaryWriter(opt.log)

    # dataset preparation
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager closes the dataset log even if building the
    # validation set raises.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # model configuration
    ctc_converter = CTCLabelConverter(opt.character)
    attn_converter = AttnLabelConverter(opt.character)
    opt.num_class = len(attn_converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))

    # setup loss
    loss_avg = Averager()
    ctc_loss = torch.nn.CTCLoss(zero_infinity=True).to(device)
    attn_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    # keep only the parameters that require gradient descent
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    print('Trainable params num : ',
          sum(p.numel() for p in filtered_parameters))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr,
                                   rho=opt.rho, eps=opt.eps)
    print("Optimizer:")

    # final options
    # print(opt)
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        for k, v in vars(opt).items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # start training
    start_iter = 0
    if opt.saved_model != '':
        try:
            # Checkpoints named like "..._<iteration>.pth" carry the
            # iteration to resume from.
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            # File name does not end in an integer: start from 0.
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1

    # Resume from start_iter. The loop previously always ran over
    # range(opt.num_iter), silently discarding the parsed start iteration.
    pbar = tqdm(range(start_iter, opt.num_iter))
    for iteration in pbar:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        ctc_text, ctc_length = ctc_converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        attn_text, _ = attn_converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        preds, refiner = model(image, attn_text[:, :-1])
        refiner_size = torch.IntTensor([refiner.size(1)] * batch_size)
        refiner = refiner.log_softmax(2).permute(1, 0, 2)
        refiner_loss = ctc_loss(refiner, ctc_text, refiner_size, ctc_length)
        total_loss = opt.lambda_ctc * refiner_loss
        target = attn_text[:, 1:]  # without [GO] Symbol
        for pred in preds:
            total_loss += opt.lambda_attn * attn_loss(
                pred.view(-1, pred.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(total_loss)
        # Tighten gradient clipping as the running loss falls.
        if loss_avg.val() <= 0.6:
            opt.grad_clip = 2
        if loss_avg.val() <= 0.3:
            opt.grad_clip = 1

        # NOTE(review): the generator below is never consumed, so no CPU copy
        # of `preds` is actually made; kept as in the original.
        preds = (p.cpu() for p in preds)
        refiner = refiner.cpu()
        image = image.cpu()
        torch.cuda.empty_cache()

        writer.add_scalar('train_loss', loss_avg.val(), iteration)
        pbar.set_description('Iteration {0}/{1}, AvgLoss {2}'.format(
            iteration, opt.num_iter, loss_avg.val()))

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds,
                     confidence_score, labels, infer_time,
                     length_of_data) = validation(
                        model, attn_loss, valid_loader, attn_converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                # Pass the step so validation points are spread along the
                # x-axis instead of all being recorded without a step.
                writer.add_scalar('Val_loss', valid_loss, iteration)
                pbar.set_description(loss_log)
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy_{str(best_accuracy)}.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                # print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    # Validation always goes through attn_converter, so the
                    # strings are always cut at '[s]'. The former guard
                    # `'Attn' or 'Transformer' in opt.Prediction` was always
                    # true and has been dropped.
                    gt = gt[:gt.find('[s]')]
                    pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                log.write(predicted_result_log + '\n')

        # save model per 1e+3 iter.
        if (iteration + 1) % 1e+3 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.exp_name}/SCATTER_STR.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            writer.close()  # flush pending TensorBoard events before exiting
            sys.exit()
def detect_ocr(config, image, timestamp, save_img):
    """Detect text regions in ``image`` and recognize the text in each one.

    ``Detection_txt`` produces the cropped regions, which are fed through
    ``config.model`` in batches. Each prediction is cut at the ``[s]`` token
    and given a confidence score (product of the per-step maximum
    probabilities, or 0.0 for an empty prediction). Results are appended to
    ``config.logfilepath`` and ``config.recog_time`` is increased by the
    elapsed recognition time.

    Args:
        config: namespace providing ``net``, ``device``, ``model``,
            ``converter``, image/batch settings, ``logfilepath``,
            ``recog_time`` and the result-image settings.
        image: input passed to ``Detection_txt``.
        timestamp: returned unchanged alongside the predictions.
        save_img: when true, rows are also printed and the annotated result
            image is saved through ``saveResult``.

    Returns:
        ``(pred_list, timestamp)`` where ``pred_list`` holds
        ``[coordinate, text, confidence]`` entries.
    """
    detection_list, img, boxes = Detection_txt(config, image, config.net)
    # print(detection_list)
    t = time.time()

    device = config.device
    print("config device", device)
    model = config.model
    converter = config.converter

    # 32 * 100
    AlignCollate_demo = AlignCollate(imgH=config.imgH, imgW=config.imgW,
                                     keep_ratio_with_pad=config.PAD)
    # demo_data = RawDataset(root=image, opt=config)  # use RawDataset
    demo_data = RawDataset_wPosition(root=detection_list, opt=config)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=config.batch_size,
        shuffle=False,
        num_workers=int(config.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)
    # (< PIL.Image.Image image mode=L size=398x120 at 0x7F376DAF30B8 >, './demo_image/demo_12.png')

    dashed_line = '-' * 80
    head = f'{"coordinates":25s}\t{"predicted_labels":25s}\tconfidence score'
    pred_list = []

    # predict
    model.eval()
    # The context manager closes the log even if inference raises; it was
    # previously closed only on the success path.
    with torch.no_grad(), open(f'{config.logfilepath}', 'a') as log:
        # NOTE(review): only the console output is gated by save_img here;
        # the log file is always written - confirm this matches the intent.
        if save_img:
            print(f'{dashed_line}\n{head}\n{dashed_line}')
        log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

        for image_tensors, coordinate_list in demo_loader:
            batch_size = image_tensors.size(0)
            # print(image_tensors.shape)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor(
                [config.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(
                batch_size, config.batch_max_length + 1).fill_(0).to(device)

            preds = model(image, text_for_pred, is_train=False)

            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for coordinate, pred, pred_max_prob in zip(
                    coordinate_list, preds_str, preds_max_prob):
                pred_EOS = pred.find('[s]')
                pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                if pred_EOS == 0:
                    # Empty prediction: nothing to multiply.
                    confidence_score = 0.0
                else:
                    pred_max_prob = pred_max_prob[:pred_EOS]
                    # calculate confidence score (= multiply of pred_max_prob)
                    confidence_score = pred_max_prob.cumprod(dim=0)[-1].item()

                coordinate = list(coordinate)
                pred_list.append([coordinate, pred, confidence_score])
                if save_img:
                    print(f'{coordinate}\t{pred:25s}\t{confidence_score:0.4f}')
                log.write(f'{coordinate}\t{pred:25s}\t{confidence_score:0.4f}\n')

    recog_time = time.time() - t
    config.recog_time = config.recog_time + recog_time
    # print("\nrun time (recognition) : {:.2f} , {:.2f} s".format(recog_time,config.recog_time))

    if save_img:
        saveResult(img, boxes, pred_list, config.result_folder,
                   config.res_imagefileName)
    return pred_list, timestamp
def demo(args):
    """Recognize cropped text images and append results to per-source logs.

    Builds the label converter and ``Model`` from ``args``, loads the
    checkpoint at ``args.saved_model`` and recognizes every image under
    ``args.image_folder``. For each image a row with its file name, the
    prediction and a confidence score (product of the per-step maximum
    probabilities) is printed and appended to
    ``<prefix>_log_demo_result.txt``. The prefix is the image path relative
    to ``../data/crop_img/`` with the last eight ``_``-separated fields of
    the file name removed.

    Args:
        args: option namespace; ``args.num_class`` is set here and
            ``args.input_channel`` is forced to 3 when ``args.rgb`` is set.
    """
    # NOTE(review): the CSV is loaded but its contents are not used anywhere
    # in this function; kept so a missing CRAFT output still fails up front.
    # TODO confirm whether this read can be dropped.
    data = pd.read_csv('../data/craft_output/data.csv')

    # model configuration
    if 'CTC' in args.Prediction:
        converter = CTCLabelConverter(args.character)
    else:
        converter = AttnLabelConverter(args.character)
    args.num_class = len(converter.character)

    if args.rgb:
        args.input_channel = 3
    model = Model(args)
    print('model input parameters', args.imgH, args.imgW, args.num_fiducial,
          args.input_channel, args.output_channel, args.hidden_size,
          args.num_class, args.batch_max_length, args.Transformation,
          args.FeatureExtraction, args.SequenceModeling, args.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % args.saved_model)
    model.load_state_dict(torch.load(args.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=args.imgH, imgW=args.imgW,
                                     keep_ratio_with_pad=args.PAD)
    demo_data = RawDataset(root=args.image_folder, args=args)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=args.batch_size,
        shuffle=False,
        num_workers=int(args.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # Root against which image paths are made relative (loop-invariant).
    start = '../data/crop_img/'
    dashed_line = '-' * 80
    head = f'{"image_path":25s}\t {"predicted_labels":25s}\t confidence score'

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor(
                [args.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(
                batch_size, args.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in args.Prediction:
                preds = model(image, text_for_pred)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                # preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            print(f'{dashed_line}\n{head}\n{dashed_line}')
            # log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(
                    image_path_list, preds_str, preds_max_prob):
                path = os.path.relpath(img_name, start)
                folder = os.path.dirname(path)
                image_name = os.path.basename(path)
                # Drop the last eight '_'-separated fields of the crop name
                # to get the log file prefix.
                file_name = '_'.join(image_name.split('_')[:-8])
                txt_file = os.path.join(start, folder, file_name)

                # The context manager closes the per-image log even if
                # scoring or formatting the row raises.
                with open(f'{txt_file}_log_demo_result.txt', 'a') as log:
                    if 'Attn' in args.Prediction:
                        pred_EOS = pred.find('[s]')
                        pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                        pred_max_prob = pred_max_prob[:pred_EOS]

                    if pred_max_prob.numel() == 0:
                        # Empty prediction: indexing [-1] on the empty
                        # cumprod would raise IndexError.
                        confidence_score = 0.0
                    else:
                        # calculate confidence score (= multiply of pred_max_prob)
                        confidence_score = pred_max_prob.cumprod(dim=0)[-1]

                    print(f'{image_name:25s}\t {pred:25s}\t {confidence_score:0.4f}')
                    log.write(f'{image_name:25s}\t {pred:25s}\t {confidence_score:0.4f}\n')
def train(opt):
    """Train a model with a CTC ("major") path and an attention ("guide") path.

    The model returns two predictions per batch.  The CTC output is trained
    with ``criterion_major_path`` and the attention output with a
    cross-entropy loss; each path has its own optimizer and warmup schedule.
    When ``opt.continue_training`` is set only the CTC path is optimized and
    ``opt.saved_model`` is loaded as a plain state dict.

    Validation runs every ``opt.valInterval`` iterations; best-accuracy and
    best-norm-ED weights and a resumable ``current_model.pth`` are written
    under the module-level ``fol_ckpt`` directory.  Scalars are sent to the
    module-level ``writer``.  The function ends the process with
    ``sys.exit()`` once ``opt.num_iter`` iterations are done.

    Args:
        opt: option namespace.  ``select_data`` / ``batch_ratio`` are split on
            ``'-'`` in place, and ``num_class_ctc``, ``num_class_attn`` (and
            ``input_channel`` when ``rgb`` is set) are written back onto it.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    if opt.baiduCTC:
        CTC_converter = CTCLabelConverterForBaiduWarpctc(opt.character)
    else:
        CTC_converter = CTCLabelConverter(opt.character)
    Attn_converter = AttnLabelConverter(opt.character)
    opt.num_class_ctc = len(CTC_converter.character)
    opt.num_class_attn = len(Attn_converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class_ctc, opt.num_class_attn, opt.batch_max_length,
          opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling,
          opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    print("Model:")
    print(model)

    # ---- setup loss ----
    if opt.baiduCTC:
        # need to install warpctc. see our guideline.
        if opt.label_smooth:
            criterion_major_path = SmoothCTCLoss(num_classes=opt.num_class_ctc,
                                                 weight=0.05)
        else:
            criterion_major_path = CTCLoss()
    else:
        criterion_major_path = torch.nn.CTCLoss(zero_infinity=True).to(device)
    # ignore [GO] token = ignore index 0
    criterion_guide_path = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    # loss averagers
    loss_avg_major_path = Averager()
    loss_avg_guide_path = Averager()

    # The guide (attention) path is only optimized in a regular run.
    train_guide_path = not opt.continue_training

    # Split trainable parameters by the sub-module that owns them.  Names
    # look like "module.<part>.<...>" because of the DataParallel wrapper,
    # so index 1 is the model part.
    guide_model_part_names = [
        "Transformation", "FeatureExtraction", "SequenceModeling_Attn",
        "Attention"
    ]
    major_model_part_names = ["SequenceModeling_CTC", "CTC"]
    guide_parameters = []
    major_parameters = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        part = name.split(".")[1]
        if part in guide_model_part_names:
            guide_parameters.append(param)
        elif part in major_model_part_names:
            major_parameters.append(param)
    if not train_guide_path:
        guide_parameters = []

    # ---- setup optimizers / schedulers ----
    def _make_optimizer(params):
        # The original opt.adam branch referenced an undefined
        # `filtered_parameters` and never created optimizer_ctc; build the
        # same pair of per-path optimizers for both choices instead.
        if opt.adam:
            return optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
        return AdamW(params, lr=opt.lr)

    def _make_scheduler(optimizer, **kwargs):
        return get_linear_schedule_with_warmup(optimizer,
                                               num_warmup_steps=10000,
                                               num_training_steps=opt.num_iter,
                                               **kwargs)

    optimizer_ctc = _make_optimizer(major_parameters)
    scheduler_ctc = _make_scheduler(optimizer_ctc)
    # No attention optimizer/scheduler exists when only the CTC path trains.
    optimizer_attn = None
    scheduler_attn = None
    if train_guide_path:
        optimizer_attn = _make_optimizer(guide_parameters)
        scheduler_attn = _make_scheduler(optimizer_attn)

    start_iter = 0
    if opt.saved_model != '' and train_guide_path:
        print(f'loading pretrained model from {opt.saved_model}')
        checkpoint = torch.load(opt.saved_model)
        start_iter = checkpoint['start_iter'] + 1
        optimizer_ctc.load_state_dict(checkpoint['optimizer_ctc_state_dict'])
        optimizer_attn.load_state_dict(checkpoint['optimizer_attn_state_dict'])
        scheduler_ctc.load_state_dict(checkpoint['scheduler_ctc_state_dict'])
        scheduler_attn.load_state_dict(checkpoint['scheduler_attn_state_dict'])
        print(scheduler_ctc.get_lr())
        print(scheduler_attn.get_lr())
        if opt.FT:
            model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        else:
            model.load_state_dict(checkpoint['model_state_dict'])
    if opt.continue_training:
        model.load_state_dict(torch.load(opt.saved_model))

    # Rebuild the schedules so they are positioned at the resume iteration.
    scheduler_ctc = _make_scheduler(optimizer_ctc, last_epoch=start_iter - 1)
    if train_guide_path:
        scheduler_attn = _make_scheduler(optimizer_attn,
                                         last_epoch=start_iter - 1)

    # ---- final options ----
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter - 1

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        iteration += 1
        image = image_tensors.to(device)
        text_attn, length_attn = Attn_converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        text_ctc, length_ctc = CTC_converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        # Attention decoder input is the target without its last position.
        preds_major, preds_guide = model(image, text_attn[:, :-1])
        preds_size = torch.IntTensor([preds_major.size(1)] * batch_size)
        if opt.baiduCTC:
            preds_major = preds_major.permute(1, 0, 2)  # to use CTCLoss format
            if opt.label_smooth:
                cost_ctc = criterion_major_path(preds_major, text_ctc,
                                                preds_size, length_ctc,
                                                batch_size)
            else:
                cost_ctc = criterion_major_path(
                    preds_major, text_ctc, preds_size, length_ctc) / batch_size
        else:
            preds_major = preds_major.log_softmax(2).permute(1, 0, 2)
            cost_ctc = criterion_major_path(preds_major, text_ctc, preds_size,
                                            length_ctc)

        if train_guide_path:
            target = text_attn[:, 1:]  # without [GO] Symbol
            cost_attn = criterion_guide_path(
                preds_guide.view(-1, preds_guide.shape[-1]),
                target.contiguous().view(-1))
            optimizer_attn.zero_grad()
            # The graph is reused by the CTC backward pass below.
            cost_attn.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(
                guide_parameters,
                opt.grad_clip)  # gradient clipping with 5 (Default)
            optimizer_attn.step()

        optimizer_ctc.zero_grad()
        cost_ctc.backward()
        torch.nn.utils.clip_grad_norm_(
            major_parameters,
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer_ctc.step()
        scheduler_ctc.step()
        if train_guide_path:
            scheduler_attn.step()

        loss_avg_major_path.add(cost_ctc)
        if train_guide_path:
            loss_avg_guide_path.add(cost_attn)

        # NOTE(review): the averagers are reset here every 100 iterations, so
        # the train losses printed in the validation log below only cover the
        # iterations since the last multiple of 100 - confirm this is wanted.
        if (iteration + 1) % 100 == 0:
            writer.add_scalar("Loss/train_ctc", loss_avg_major_path.val(),
                              (iteration + 1) // 100)
            loss_avg_major_path.reset()
            if train_guide_path:
                writer.add_scalar("Loss/train_attn", loss_avg_guide_path.val(),
                                  (iteration + 1) // 100)
                loss_avg_guide_path.reset()

        # validation part
        if (iteration + 1) % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds,
                     confidence_score, labels, infer_time,
                     length_of_data) = validation(model, criterion_major_path,
                                                  valid_loader, CTC_converter,
                                                  opt)
                model.train()

                val_step = (iteration + 1) // opt.valInterval
                writer.add_scalar("Loss/valid", valid_loss, val_step)
                writer.add_scalar("Metrics/accuracy", current_accuracy,
                                  val_step)
                writer.add_scalar("Metrics/norm_ED", current_norm_ED, val_step)

                # training loss and validation loss
                if train_guide_path:
                    loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss ctc: {loss_avg_major_path.val():0.5f}, Train loss attn: {loss_avg_guide_path.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                else:
                    loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss ctc: {loss_avg_major_path.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg_major_path.reset()
                if train_guide_path:
                    loss_avg_guide_path.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(),
                               f'{fol_ckpt}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(),
                               f'{fol_ckpt}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                # One record per line so the log file stays readable.
                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f} {str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save a resumable checkpoint every 1e+3 iterations.
        if (iteration + 1) % 1e+3 == 0 and train_guide_path:
            torch.save(
                {
                    'model_state_dict': model.state_dict(),
                    'optimizer_attn_state_dict': optimizer_attn.state_dict(),
                    'optimizer_ctc_state_dict': optimizer_ctc.state_dict(),
                    'start_iter': iteration,
                    'scheduler_ctc_state_dict': scheduler_ctc.state_dict(),
                    'scheduler_attn_state_dict': scheduler_attn.state_dict(),
                }, f'{fol_ckpt}/current_model.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
def demo(opt):
    """Recognize text in every image of ``opt.image_folder`` and log it.

    A label converter and model are built from ``opt``, the weights in
    ``opt.saved_model`` are loaded, and each batch from ``RawDataset`` is
    decoded greedily.  For every image one line (path, predicted text,
    confidence) is printed, appended to ``./log_demo_result.txt`` and written
    to ``custom_output``.

    Args:
        opt: option namespace.  Attributes read here include ``Prediction``,
            ``character``, ``rgb``, ``saved_model``, ``image_folder``,
            ``imgH``, ``imgW``, ``PAD``, ``batch_size``, ``workers`` and
            ``batch_max_length``.  ``num_class`` (and ``input_channel`` when
            ``rgb`` is set) are written back onto it.
    """
    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(
                batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)

                # Select max probability (greedy decoding), then decode
                # index to character.
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, preds_size)
            else:
                preds = model(image, text_for_pred, is_train=False)

                # Select max probability (greedy decoding), then decode
                # index to character.
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            # Context manager closes the log even if a row below fails.
            with open('./log_demo_result.txt', 'a') as log:
                dashed_line = '-' * 80
                head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'

                print(f'{dashed_line}\n{head}\n{dashed_line}')
                log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

                preds_prob = F.softmax(preds, dim=2)
                preds_max_prob, _ = preds_prob.max(dim=2)
                for img_name, pred, pred_max_prob in zip(
                        image_path_list, preds_str, preds_max_prob):
                    if 'Attn' in opt.Prediction:
                        pred_EOS = pred.find('[s]')
                        # find() returns -1 when no "end of sentence" token
                        # was produced; slicing with -1 would wrongly drop
                        # the last character, so prune only when present.
                        if pred_EOS != -1:
                            pred = pred[:pred_EOS]
                            pred_max_prob = pred_max_prob[:pred_EOS]

                    # confidence score = product of per-step max
                    # probabilities.  An empty prediction has no steps;
                    # report 0 instead of indexing into an empty tensor.
                    if pred_max_prob.numel() > 0:
                        confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                    else:
                        confidence_score = 0.0

                    print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
                    log.write(
                        f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')
                    # NOTE(review): custom_output is not defined in this
                    # function; presumably a module-level file handle.
                    custom_output.write(
                        f'{img_name}\t{pred}\t{confidence_score:0.4f}\n')
def train(opt):
    """Train a CTC or attention text-recognition model, optionally resuming.

    Builds the balanced training set and validation loader, creates the
    model (loading ``opt.continue_model`` when given), then loops forever:
    one optimization step per batch, validation every ``opt.valInterval``
    iterations with best-accuracy / best-norm-ED weights saved under
    ``./saved_models/<experiment_name>/``, and a snapshot every 1e+5
    iterations.  The process is ended with ``sys.exit()`` when the iteration
    counter reaches ``opt.num_iter``.

    In this variant a *lower* ``norm_ED`` counts as better.

    Args:
        opt: option namespace.  ``select_data`` / ``batch_ratio`` are split on
            ``'-'`` in place, and ``num_class`` (and ``input_channel`` when
            ``rgb`` is set) are written back onto it.
    """
    # ---- dataset preparation ----
    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    print('-' * 80)

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        model.load_state_dict(torch.load(opt.continue_model))
    print("Model:")

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradient descent
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'./saved_models/{opt.experiment_name}/opt.txt',
              'a',
              encoding="utf-8") as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.continue_model != '':
        # Snapshots are named "iter_<N>.pth"; recover N.  Other names (for
        # example "best_accuracy.pth") carry no iteration number, so fall
        # back to 0 instead of aborting after the model was already loaded.
        try:
            start_iter = int(opt.continue_model.split('_')[-1].split('.')[0])
        except ValueError:
            print(
                f'no iteration number in {opt.continue_model}, start_iter: 0')
        else:
            print(f'continue to train, start_iter: {start_iter}')

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    i = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] *
                                         batch_size).to(device)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format

            # To avoid ctc_loss issue, disable cudnn for the computation of
            # the ctc_loss: https://github.com/jpuigcerver/PyLaia/issues/16
            # The previous setting is restored afterwards (even on error)
            # rather than being forced to True.
            cudnn_was_enabled = torch.backends.cudnn.enabled
            torch.backends.cudnn.enabled = False
            try:
                cost = criterion(preds, text, preds_size, length)
            finally:
                torch.backends.cudnn.enabled = cudnn_was_enabled
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            print(
                f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
            )
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt',
                      'a',
                      encoding="utf-8") as log:
                log.write(
                    f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n'
                )
                loss_avg.reset()

                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds,
                     labels, infer_time,
                     length_of_data) = validation(model, criterion,
                                                  valid_loader, converter, opt)
                model.train()

                for pred, gt in zip(preds[:5], labels[:5]):
                    if 'Attn' in opt.Prediction:
                        # Keep the text before the "[s]" end token; when the
                        # token is absent the whole string is kept (slicing
                        # at find() == -1 would drop the last character).
                        pred = pred.partition('[s]')[0]
                        gt = gt.partition('[s]')[0]

                    print(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}')
                    log.write(
                        f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}\n')

                valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                print(valid_log)
                log.write(valid_log + '\n')

                # keep best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                    )
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                    )
                best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                print(best_model_log)
                log.write(best_model_log + '\n')

        # save model per 1e+5 iter.
        if (i + 1) % 1e+5 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
def train(opt):
    """Fine-tune an attention text-recognition model (transfer learning).

    The model is created with ``Model(opt, 38)``, optionally loaded from
    ``opt.saved_model``, and its attention LSTM cell and output layer are then
    replaced with fresh layers sized for ``opt.num_class``.  Training loops
    forever: one step per batch, validation every ``opt.valInterval``
    iterations with best-accuracy / best-norm-ED weights saved under
    ``<save_dir>/<experiment_name>/`` (``save_dir`` is a module-level name),
    and a snapshot every 1e+5 iterations.  When ``opt.use_tb`` is set, train
    and validation losses are also written to TensorBoard.  The process is
    ended with ``sys.exit()`` when the counter reaches ``opt.num_iter``.

    Args:
        opt: option namespace.  ``select_data`` / ``batch_ratio`` are split on
            ``'-'`` in place, and ``num_class`` (and ``input_channel`` when
            ``rgb`` is set) are written back onto it.
    """
    if opt.use_tb:
        tb_dir = f'/home_hongdo/{getpass.getuser()}/tb/{opt.experiment_name}'
        print('tensorboard : ', tb_dir)
        os.makedirs(tb_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=tb_dir)

    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    with open(f'{save_dir}/{opt.experiment_name}/log_dataset.txt', 'a') as log:
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3

    # sekim for transfer learning
    model = Model(opt, 38)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))

    # sekim change last layer: swap the attention cell and the classifier for
    # new layers sized for the current character set.
    # NOTE(review): applied whether or not a checkpoint was loaded - confirm.
    in_feature = model.module.Prediction.generator.in_features
    model.module.Prediction.attention_cell.rnn = nn.LSTMCell(
        256 + opt.num_class, 256).to(device)
    model.module.Prediction.generator = nn.Linear(in_feature,
                                                  opt.num_class).to(device)

    print(model.module.Prediction.generator)
    print("Model:")
    print(model)
    model.train()

    # ---- setup loss ----
    criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
        device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradient descent
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'{save_dir}/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        # Snapshots are named "iter_<N>.pth"; recover N when present.  Any
        # other file name simply starts counting from 0.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print("-------------------------------------------------")
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        preds = model(image, text[:, :-1])  # align with Attention.forward
        target = text[:, 1:]  # without [GO] Symbol
        cost = criterion(preds.view(-1, preds.shape[-1]),
                         target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'{save_dir}/{opt.experiment_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds,
                     confidence_score, labels, infer_time,
                     length_of_data) = validation(model, criterion,
                                                  valid_loader, converter, opt)
                model.train()

                # Read the running train loss once, BEFORE resetting the
                # averager, so the text log and TensorBoard get the same
                # value (it used to be read again after reset()).
                train_loss = loss_avg.val()
                loss_avg.reset()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {train_loss:0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'

                if opt.use_tb:
                    writer.add_scalar('OCR_loss/train_loss', train_loss, i)
                    writer.add_scalar('OCR_loss/validation_loss', valid_loss,
                                      i)

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'{save_dir}/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'{save_dir}/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    # Keep the text before the "[s]" end token; when the
                    # token is absent the whole string is kept (slicing at
                    # find() == -1 would drop the last character).
                    gt = gt.partition('[s]')[0]
                    pred = pred.partition('[s]')[0]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (i + 1) % 1e+5 == 0:
            torch.save(model.state_dict(),
                       f'{save_dir}/{opt.experiment_name}/iter_{i + 1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
def train(opt):
    """Train a CTC or attention text-recognition model.

    Builds the balanced training set and validation loader, creates the
    model (loading ``opt.saved_model`` when given), then loops forever: one
    optimization step per batch, validation on the first iteration and every
    ``opt.valInterval`` iterations with best-accuracy / best-norm-ED weights
    saved under ``./saved_models/<exp_name>/``, and a snapshot every 1e+5
    iterations.  The process is ended with ``sys.exit()`` once
    ``opt.num_iter`` iterations are done.

    Args:
        opt: option namespace.  ``select_data`` / ``batch_ratio`` are split on
            ``'-'`` in place, and ``num_class`` (and ``input_channel`` when
            ``rgb`` is set) are written back onto it.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradient descent
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        # Snapshots are named "iter_<N>.pth"; recover N when present.  Any
        # other file name simply starts counting from 0.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        # To see training progress, we also conduct validation when
        # 'iteration == 0'.
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds,
                     confidence_score, labels, infer_time,
                     length_of_data) = validation(model, criterion,
                                                  valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        # Keep the text before the "[s]" end token; when the
                        # token is absent the whole string is kept (slicing
                        # at find() == -1 would drop the last character).
                        gt = gt.partition('[s]')[0]
                        pred = pred.partition('[s]')[0]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            torch.save(
                model.state_dict(),
                f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
def _init_module_weights(module):
    """Initialize the parameters of ``module`` in place.

    Parameters whose name contains ``localization_fc2`` are left untouched.
    Biases are set to 0 and weights get Kaiming-normal initialization; when
    that raises (the original code attributes this to batchnorm parameters),
    weights are filled with 1 instead.
    """
    for name, param in module.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)


def _load_pretrained(module, path, label, fine_tune):
    """Load the state dict stored at ``path`` into ``module``; no-op when ``path`` is ''.

    With ``fine_tune`` truthy the load is non-strict, so missing/unexpected
    keys are tolerated.
    """
    if path == '':
        return
    print(f'loading pretrained {label} model from {path}')
    module.load_state_dict(torch.load(path), strict=not fine_tune)


def _make_optimizer(module, opt, count_label):
    """Build Adam (``opt.adam``) or Adadelta over the trainable parameters of ``module``.

    Prints ``count_label`` followed by the total number of trainable elements.
    """
    trainable = [p for p in module.parameters() if p.requires_grad]
    print(count_label, sum(np.prod(p.size()) for p in trainable))
    if opt.adam:
        return optim.Adam(trainable, lr=opt.lr, betas=(opt.beta1, 0.999))
    return optim.Adadelta(trainable, lr=opt.lr, rho=opt.rho, eps=opt.eps)


def train(opt):
    """Jointly train the synthesizer, the recognizer and the discriminator.

    Builds the training/validation data, an ``AdaINGen`` synthesizer, a
    ``Model`` recognizer and an ``MsImageDis`` discriminator, then loops
    forever over training batches. Every ``opt.valInterval`` iterations (and
    at iteration 0) it dumps the current training images, runs
    ``validation_synth_adv`` and stores the best checkpoints.

    Side effects:
        * mutates ``opt``: ``select_data``/``batch_ratio`` are split on '-',
          ``batch_size`` is doubled, ``num_class`` is set and, with
          ``opt.rgb``, ``input_channel`` becomes 3;
        * writes logs, images and checkpoints below
          ``os.path.join(opt.exp_dir, opt.exp_name)``.

    The function does not return: it calls ``sys.exit()`` once
    ``iteration + 1 == opt.num_iter``.

    Raises:
        NotImplementedError: when ``opt.Prediction`` does not contain 'CTC';
            only the CTC objective is implemented in the training loop.
    """
    run_dir = os.path.join(opt.exp_dir, opt.exp_name)

    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')

    # Each batch is doubled: the second half supplies the real images for the
    # discriminator (see the split at the top of the training loop).
    opt.batch_size = opt.batch_size * 2

    train_dataset = Batch_Balanced_Dataset(opt)

    with open(os.path.join(run_dir, 'log_dataset.txt'), 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3

    model = AdaINGen(opt)
    ocrModel = Model(opt)
    disModel = MsImageDis()

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # Synthesizer, recognizer and discriminator weight initialization.
    _init_module_weights(model)
    _init_module_weights(ocrModel)
    _init_module_weights(disModel)

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    ocrModel.train()
    disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()

    _load_pretrained(model, opt.saved_synth_model, 'synth', opt.FT)
    print("Model:")
    print(model)

    # The recognizer stays in train mode: RNN.backward cannot be called in eval mode.
    _load_pretrained(ocrModel, opt.saved_ocr_model, 'ocr', opt.FT)
    print("OCRModel:")
    print(ocrModel)

    _load_pretrained(disModel, opt.saved_dis_model, 'discriminator', opt.FT)
    print("DisModel:")
    print(disModel)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    recCriterion = torch.nn.L1Loss()

    # loss averagers
    loss_avg = Averager()
    loss_avg_ocr = Averager()
    loss_avg_dis = Averager()

    # ---- optimizers (one per network, trainable parameters only) ----
    optimizer = _make_optimizer(model, opt, 'Trainable params num : ')
    print("SynthOptimizer:")
    print(optimizer)

    ocr_optimizer = _make_optimizer(ocrModel, opt, 'OCR Trainable params num : ')
    print("OCROptimizer:")
    print(ocr_optimizer)

    dis_optimizer = _make_optimizer(disModel, opt, 'Dis Trainable params num : ')
    print("DisOptimizer:")
    print(dis_optimizer)

    # ---- final options ----
    with open(os.path.join(run_dir, 'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        for k, v in vars(opt).items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_synth_model != '':
        # Resume the counter from checkpoint names such as 'iter_1000.pth';
        # names without a trailing integer simply start from 0.
        try:
            start_iter = int(opt.saved_synth_model.split('_')[-1].split('.')[0])
        except ValueError:
            pass
        else:
            print(f'continue to train, start_iter: {start_iter}')

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    best_accuracy_ocr = -1
    best_norm_ED_ocr = -1
    iteration = start_iter

    while True:
        # ---- train part ----
        image_tensors_all, labels_1_all, labels_2_all = train_dataset.get_batch()

        # First half: images fed to the synthesizer; second half: real images
        # shown to the discriminator.
        disCnt = image_tensors_all.size(0) // 2
        image_tensors = image_tensors_all[:disCnt]
        image_tensors_real = image_tensors_all[disCnt:disCnt + disCnt]
        labels_1 = labels_1_all[:disCnt]
        labels_2 = labels_2_all[:disCnt]

        image = image_tensors.to(device)
        image_real = image_tensors_real.to(device)
        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        images_recon_1, images_recon_2, _ = model(image, text_1, text_2)

        if 'CTC' in opt.Prediction:
            # ocr training: recognizer loss on the input images
            preds_ocr = ocrModel(image, text_1)
            preds_size_ocr = torch.IntTensor([preds_ocr.size(1)] * batch_size)
            preds_ocr = preds_ocr.log_softmax(2).permute(1, 0, 2)
            ocrCost_train = ocrCriterion(preds_ocr, text_1, preds_size_ocr, length_1)

            # dis training: generated images are detached so this loss only
            # reaches the discriminator.
            # Check: Using alternate real images
            disCost = opt.disWeight * 0.5 * (
                disModel.module.calc_dis_loss(images_recon_1.detach(), image_real)
                + disModel.module.calc_dis_loss(images_recon_2.detach(), image))

            # synth training: recognizer loss on both generated images
            preds_1 = ocrModel(images_recon_1, text_1)
            preds_size_1 = torch.IntTensor([preds_1.size(1)] * batch_size)
            preds_1 = preds_1.log_softmax(2).permute(1, 0, 2)

            preds_2 = ocrModel(images_recon_2, text_2)
            preds_size_2 = torch.IntTensor([preds_2.size(1)] * batch_size)
            preds_2 = preds_2.log_softmax(2).permute(1, 0, 2)

            ocrCost = 0.5 * (ocrCriterion(preds_1, text_1, preds_size_1, length_1)
                             + ocrCriterion(preds_2, text_2, preds_size_2, length_2))

            # gen training: adversarial loss of the generated images
            disGenCost = 0.5 * (disModel.module.calc_gen_loss(images_recon_1)
                                + disModel.module.calc_gen_loss(images_recon_2))
        else:
            # The attention objective was never wired into this loop (no
            # discriminator / recognizer losses are defined for it), so fail
            # with a clear message instead of a NameError.
            raise NotImplementedError(
                "train() only implements the CTC objective; "
                f"got opt.Prediction={opt.Prediction!r}")

        recCost = recCriterion(images_recon_1, image)

        cost = opt.ocrWeight * ocrCost + opt.reconWeight * recCost + opt.disWeight * disGenCost

        # discriminator update
        disModel.zero_grad()
        disCost.backward()
        torch.nn.utils.clip_grad_norm_(disModel.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        dis_optimizer.step()
        loss_avg_dis.add(disCost)

        # synthesizer update
        # NOTE(review): disGenCost was computed with the discriminator weights
        # that dis_optimizer.step() has just modified in place; newer PyTorch
        # releases may reject the backward pass below for that reason --
        # confirm against the torch version this project pins.
        model.zero_grad()
        ocrModel.zero_grad()
        disModel.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()
        loss_avg.add(cost)

        # recognizer update (gradients accumulated by cost.backward() are dropped first)
        ocrModel.zero_grad()
        ocrCost_train.backward()
        torch.nn.utils.clip_grad_norm_(ocrModel.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        ocr_optimizer.step()
        loss_avg_ocr.add(ocrCost_train)

        # ---- validation part ----
        # To see training progress, we also conduct validation when 'iteration == 0'
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:

            # Save training images (best effort: a failed write must not stop training)
            image_dir = os.path.join(run_dir, 'trainImages', str(iteration))
            os.makedirs(image_dir, exist_ok=True)
            for trImgCntr in range(batch_size):
                try:
                    save_image(tensor2im(image[trImgCntr].detach()),
                               os.path.join(image_dir, str(trImgCntr) + '_input_' + labels_1[trImgCntr] + '.png'))
                    save_image(tensor2im(images_recon_1[trImgCntr].detach()),
                               os.path.join(image_dir, str(trImgCntr) + '_recon_' + labels_1[trImgCntr] + '.png'))
                    save_image(tensor2im(images_recon_2[trImgCntr].detach()),
                               os.path.join(image_dir, str(trImgCntr) + '_pair_' + labels_2[trImgCntr] + '.png'))
                except Exception:
                    print('Warning while saving training image')

            elapsed_time = time.time() - start_time
            # for log
            with open(os.path.join(run_dir, 'log_train.txt'), 'a') as log:
                model.eval()
                ocrModel.eval()
                disModel.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds, confidence_score,
                     labels, infer_time, length_of_data) = validation_synth_adv(
                        iteration, model, ocrModel, disModel, recCriterion, ocrCriterion,
                        valid_loader, converter, opt)
                model.train()
                ocrModel.train()
                disModel.train()

                # training loss and validation loss
                loss_log = (
                    f'[{iteration+1}/{opt.num_iter}] Train OCR loss: {loss_avg_ocr.val():0.5f}, '
                    f'Train Synth loss: {loss_avg.val():0.5f}, Train Dis loss: {loss_avg_dis.val():0.5f}, '
                    f'Valid OCR loss: {valid_loss[0]:0.5f}, Valid Synth loss: {valid_loss[1]:0.5f}, '
                    f'Valid Dis loss: {valid_loss[2]:0.5f}, Elapsed_time: {elapsed_time:0.5f}')

                loss_avg_ocr.reset()
                loss_avg.reset()
                loss_avg_dis.reset()

                # Index 0 = recognizer on input images, 1 = reconstruction, 2 = paired image.
                current_model_log_ocr = f'{"Current_accuracy_OCR":17s}: {current_accuracy[0]:0.3f}, {"Current_norm_ED_OCR":17s}: {current_norm_ED[0]:0.2f}'
                current_model_log_1 = f'{"Current_accuracy_recon":17s}: {current_accuracy[1]:0.3f}, {"Current_norm_ED_recon":17s}: {current_norm_ED[1]:0.2f}'
                current_model_log_2 = f'{"Current_accuracy_pair":17s}: {current_accuracy[2]:0.3f}, {"Current_norm_ED_pair":17s}: {current_norm_ED[2]:0.2f}'

                # keep best synthesizer/discriminator (reconstruction metrics on valid dataset)
                if current_accuracy[1] > best_accuracy:
                    best_accuracy = current_accuracy[1]
                    torch.save(model.state_dict(), os.path.join(run_dir, 'best_accuracy.pth'))
                    torch.save(disModel.state_dict(), os.path.join(run_dir, 'best_accuracy_dis.pth'))
                if current_norm_ED[1] > best_norm_ED:
                    best_norm_ED = current_norm_ED[1]
                    torch.save(model.state_dict(), os.path.join(run_dir, 'best_norm_ED.pth'))
                    torch.save(disModel.state_dict(), os.path.join(run_dir, 'best_norm_ED_dis.pth'))
                best_model_log = f'{"Best_accuracy_Recon":17s}: {best_accuracy:0.3f}, {"Best_norm_ED_Recon":17s}: {best_norm_ED:0.2f}'

                # keep best recognizer (OCR metrics on valid dataset)
                if current_accuracy[0] > best_accuracy_ocr:
                    best_accuracy_ocr = current_accuracy[0]
                    torch.save(ocrModel.state_dict(), os.path.join(run_dir, 'best_accuracy_ocr.pth'))
                if current_norm_ED[0] > best_norm_ED_ocr:
                    best_norm_ED_ocr = current_norm_ED[0]
                    torch.save(ocrModel.state_dict(), os.path.join(run_dir, 'best_norm_ED_ocr.pth'))
                best_model_log_ocr = f'{"Best_accuracy_ocr":17s}: {best_accuracy_ocr:0.3f}, {"Best_norm_ED_ocr":17s}: {best_norm_ED_ocr:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log_ocr}\n{current_model_log_1}\n{current_model_log_2}\n{best_model_log_ocr}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":32s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'

                samples = zip(labels[0][:5], preds[0][:5], confidence_score[0][:5],
                              labels[1][:5], preds[1][:5], confidence_score[1][:5],
                              labels[2][:5], preds[2][:5], confidence_score[2][:5])
                for (gt_ocr, pred_ocr, confidence_ocr,
                     gt_1, pred_1, confidence_1,
                     gt_2, pred_2, confidence_2) in samples:
                    if 'Attn' in opt.Prediction:
                        # cut everything from the "[s]" token on, for every shown pair
                        gt_ocr = gt_ocr[:gt_ocr.find('[s]')]
                        pred_ocr = pred_ocr[:pred_ocr.find('[s]')]
                        gt_1 = gt_1[:gt_1.find('[s]')]
                        pred_1 = pred_1[:pred_1.find('[s]')]
                        gt_2 = gt_2[:gt_2.find('[s]')]
                        pred_2 = pred_2[:pred_2.find('[s]')]

                    predicted_result_log += f'{"ocr"}: {gt_ocr:27s} | {pred_ocr:25s} | {confidence_ocr:0.4f}\t{str(pred_ocr == gt_ocr)}\n'
                    predicted_result_log += f'{"recon"}: {gt_1:25s} | {pred_1:25s} | {confidence_1:0.4f}\t{str(pred_1 == gt_1)}\n'
                    predicted_result_log += f'{"pair"}: {gt_2:26s} | {pred_2:25s} | {confidence_2:0.4f}\t{str(pred_2 == gt_2)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            torch.save(model.state_dict(), os.path.join(run_dir, f'iter_{iteration+1}.pth'))
            torch.save(ocrModel.state_dict(), os.path.join(run_dir, f'iter_{iteration+1}_ocr.pth'))
            torch.save(disModel.state_dict(), os.path.join(run_dir, f'iter_{iteration+1}_dis.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
def train(opt, tb):
    """Train a single recognition ``Model`` and report progress to ``tb``.

    Builds the training/validation data and the model, then loops forever
    over training batches. Whenever ``i % opt.valInterval == 0`` it runs
    ``validation``, logs the losses/metrics to
    ``./saved_models/<opt.experiment_name>/log_train.txt`` and to ``tb`` (via
    ``tb.add_scalar``), and stores the best checkpoints.

    Side effects:
        * mutates ``opt``: ``select_data``/``batch_ratio`` are split on '-',
          ``num_class`` is set and, with ``opt.rgb``, ``input_channel``
          becomes 3;
        * writes logs and checkpoints below
          ``./saved_models/<opt.experiment_name>/``.

    The function does not return: it calls ``sys.exit()`` once the iteration
    counter equals ``opt.num_iter``.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    with open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization: biases to 0, weights Kaiming-normal
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        # With opt.FT the load is non-strict so mismatching keys are tolerated.
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients
    print("~~~~~~~~~~~~Gradient Descent~~~~~~~~~~~~~")
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Filtered parameters for gradient descent: \n', len(filtered_parameters))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        for k, v in vars(opt).items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        # Resume the counter from checkpoint names such as 'iter_1000.pth';
        # names without a trailing integer simply start from 0.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
        except ValueError:
            pass
        else:
            print(f'continue to train, start_iter: {start_iter}')

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter

    while True:
        # ---- train part ----
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)

            # (ctc_a) For PyTorch 1.2.0 and 1.3.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            # try/finally makes sure cudnn is switched back on even if the loss raises.
            torch.backends.cudnn.enabled = False
            try:
                cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
            finally:
                torch.backends.cudnn.enabled = True

            # (ctc_b) To reproduce our pretrained model / paper, use our previous code (below code) instead of (ctc_a).
            # With PyTorch 1.2.0, the below code occurs NAN, so you may use PyTorch 1.1.0.
            # Thus, the result of CTCLoss is different in PyTorch 1.1.0 and PyTorch 1.2.0.
            # See https://github.com/clovaai/deep-text-recognition-benchmark/issues/56#issuecomment-526490707
            # cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # ---- validation part ----
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds, confidence_score,
                     labels, infer_time, length_of_data) = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                tb.add_scalar('Training Loss vs Iteration', loss_avg.val(), i)  # Record to Tensorboard
                tb.add_scalar('Validation Loss vs Iteration', valid_loss, i)  # Record to Tensorboard
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'
                tb.add_scalar('Current Accuracy vs Iteration', current_accuracy, i)  # Record to Tensorboard
                tb.add_scalar('Current Norm ED vs Iteration', current_norm_ED, i)  # Record to Tensorboard

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        # cut everything from the "[s]" token on
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (i + 1) % 1e+5 == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
def demo(opt):
    """Run top-k single-character prediction over ``FontDataset`` and count hits.

    Loads the checkpoint at ``opt.saved_model``, predicts the ``opt.topk``
    most probable characters for every image and appends one CSV row per
    image (name, then alternating character and probability) to
    ``./log_demo_result.csv``.

    Returns:
        tuple: ``(total_cnt, acc1_cnt, acck_cnt)`` -- number of images seen,
        number whose name equals the top-1 prediction, and number whose name
        appears anywhere in the top-k predictions.

    Raises:
        ValueError: if ``opt.batch_max_length`` is not 1.
        NotImplementedError: if ``opt.Prediction`` contains 'CTC'; top-k
            decoding is only implemented for the non-CTC prediction head.
    """
    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = FontDataset(opt=opt)  # use FontDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    total_cnt = 0
    acc1_cnt = 0
    acck_cnt = 0
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            # Configuration checks live inside the loop so that an empty
            # loader still returns zero counts instead of raising.
            if opt.batch_max_length != 1:
                raise ValueError(
                    f'demo() only supports opt.batch_max_length == 1, got {opt.batch_max_length}')
            if 'CTC' in opt.Prediction:
                raise NotImplementedError(
                    'demo() has no top-k decoding for CTC prediction heads')

            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            preds, _ = model(image, text_for_pred, is_train=False)

            # select top_k probability (greedy decoding) then decode index to character
            preds = F.softmax(preds, dim=2)
            topk_prob, topk_id = preds.topk(opt.topk)
            # keep only the first decoding step: (batch_size, 1, topk)
            topk_id = topk_id.detach().cpu()[:, 0, :].unsqueeze(dim=1).numpy()
            # Append index 3 after every candidate.
            # NOTE(review): presumably the "[s]" end token in converter's
            # vocabulary -- confirm against the label converter.
            topk_s = np.ones_like(topk_id) * 3
            topk_id = np.concatenate((topk_id, topk_s), axis=1)
            topk_chars = converter.decode(topk_id, length_for_pred)
            topk_probs = topk_prob.detach().cpu()[:, 0, :]  # (batch_size, topk)

            with open('./log_demo_result.csv', 'a', encoding='utf-8') as log:
                for img_name, pred, pred_max_prob in zip(image_path_list, topk_chars, topk_probs):
                    if 'Attn' in opt.Prediction:
                        pred = [p[:p.find('[s]')] for p in pred]  # prune after "end of sentence" token ([s])

                    # one CSV row: name,char1,prob1,char2,prob2,...
                    row = img_name
                    for pred_char, pred_prob in zip(pred, pred_max_prob):
                        row += ',' + pred_char
                        row += ',%.4f' % pred_prob
                    log.write(row + '\n')

                    if img_name == pred[0]:
                        acc1_cnt += 1
                    if img_name in pred:
                        acck_cnt += 1
                    total_cnt += 1

    return total_cnt, acc1_cnt, acck_cnt