def train(opt, log):
    """Train the text recognizer described by ``opt`` and benchmark the best checkpoint.

    Phases: dataset preparation, model / loss / optimizer / scheduler setup,
    the training loop (validating every ``opt.val_interval`` iterations and on
    the first iteration, saving ``best_score.pth`` whenever the validation
    score improves), and a final ``benchmark_all_eval`` run on that checkpoint.

    Args:
        opt: Parsed options namespace. Besides data/model settings it must
            provide ``writer`` (used through ``add_scalar``) and ``exp_name``.
            ``opt.batch_ratio``, ``opt.sos_token_index``, ``opt.eos_token_index``,
            ``opt.num_class`` and ``opt.eval_type`` are written by this function.
        log: Writable file object for the dataset / model / options log. It is
            closed by this function once setup has finished.

    Raises:
        ValueError: If ``opt.optimizer`` is not ``"sgd"``, ``"adadelta"`` or ``"adam"``.
    """
    # ---------------- dataset preparation ----------------
    # Named presets for the training-data selection (for convenience).
    labeled_real = [
        "1.SVT", "2.IIIT", "3.IC13", "4.IC15", "5.COCO", "6.RCTW17",
        "7.Uber", "8.ArT", "9.LSVT", "10.MLT19", "11.ReCTS",
    ]
    if opt.select_data == "label":
        select_data = list(labeled_real)
    elif opt.select_data == "synth":
        select_data = ["MJ", "ST"]
    elif opt.select_data == "synth_SA":
        select_data = ["MJ", "ST", "SA"]
        opt.batch_ratio = "0.4-0.4-0.2"  # same ratio with SCATTER paper.
    elif opt.select_data == "mix":
        select_data = labeled_real + ["MJ", "ST"]
    elif opt.select_data == "mix_SA":
        select_data = labeled_real + ["MJ", "ST", "SA"]
    else:
        select_data = opt.select_data.split("-")

    # Share of each batch drawn from each dataset; default to an even split.
    if opt.batch_ratio:
        batch_ratio = opt.batch_ratio.split("-")
    else:
        batch_ratio = [round(1 / len(select_data), 3)] * len(select_data)

    train_loader = Batch_Balanced_Dataset(opt, opt.train_data, select_data, batch_ratio, log)

    if opt.semi != "None":
        select_data_unlabel = ["U1.Book32", "U2.TextVQA", "U3.STVQA"]
        batch_ratio_unlabel = [round(1 / len(select_data_unlabel), 3)] * len(select_data_unlabel)
        dataset_root_unlabel = "data_CVPR2021/training/unlabel/"
        train_loader_unlabel_semi = Batch_Balanced_Dataset(
            opt,
            dataset_root_unlabel,
            select_data_unlabel,
            batch_ratio_unlabel,
            log,
            learn_type="semi",
        )

    AlignCollate_valid = AlignCollate(opt, mode="test")
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt, mode="test")
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=False,
    )
    log.write(valid_dataset_log)
    print("-" * 80)
    log.write("-" * 80 + "\n")

    # ---------------- model configuration ----------------
    if "CTC" in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
        # Only the attention converter defines [SOS]/[EOS] tokens.
        opt.sos_token_index = converter.dict["[SOS]"]
        opt.eos_token_index = converter.dict["[EOS]"]
    opt.num_class = len(converter.character)

    model = Model(opt)

    # weight initialization
    for name, param in model.named_parameters():
        if "localization_fc2" in name:
            print(f"Skip {name} as it is already initialized")
            continue
        try:
            if "bias" in name:
                init.constant_(param, 0.0)
            elif "weight" in name:
                init.kaiming_normal_(param)
        except ValueError:
            # kaiming_normal_ rejects tensors with fewer than 2 dims (e.g. batchnorm weights).
            if "weight" in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != "":
        fine_tuning_log = f"### loading pretrained model from {opt.saved_model}\n"
        if "MoCo" in opt.saved_model or "MoCo" in opt.self_pre:
            # Keep only the query encoder of a MoCo checkpoint, dropping its prefix.
            pretrained_state_dict_qk = torch.load(opt.saved_model)
            pretrained_state_dict = {}
            for name in pretrained_state_dict_qk:
                if "encoder_q" in name:
                    rename = name.replace("encoder_q.", "")
                    pretrained_state_dict[rename] = pretrained_state_dict_qk[name]
        else:
            pretrained_state_dict = torch.load(opt.saved_model)

        for name, param in model.named_parameters():
            try:
                param.data.copy_(pretrained_state_dict[name].data)  # load from pretrained model
            except (KeyError, RuntimeError):
                # Missing from the checkpoint or shape mismatch: keep the fresh init.
                fine_tuning_log += f"non-pretrained layer: {name}\n"
                continue
            if opt.FT == "freeze":
                param.requires_grad = False  # Freeze
                fine_tuning_log += f"pretrained layer (freezed): {name}\n"
            else:
                fine_tuning_log += f"pretrained layer: {name}\n"

        print(fine_tuning_log)
        log.write(fine_tuning_log + "\n")

    log.write(repr(model) + "\n")

    # ---------------- setup loss ----------------
    if "CTC" in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        # ignore [PAD] token
        criterion = torch.nn.CrossEntropyLoss(
            ignore_index=converter.dict["[PAD]"]).to(device)

    if "Pseudo" in opt.semi:
        criterion_SemiSL = PseudoLabelLoss(opt, converter, criterion)
    elif "MeanT" in opt.semi:
        criterion_SemiSL = MeanTeacherLoss(opt, student_for_init_teacher=model)

    # loss averager
    train_loss_avg = Averager()
    semi_loss_avg = Averager()  # semi supervised loss avg

    # filter that only require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print(f"Trainable params num: {sum(params_num)}")
    log.write(f"Trainable params num: {sum(params_num)}\n")

    # setup optimizer
    if opt.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            filtered_parameters,
            lr=opt.lr,
            momentum=opt.sgd_momentum,
            weight_decay=opt.sgd_weight_decay,
        )
    elif opt.optimizer == "adadelta":
        optimizer = torch.optim.Adadelta(
            filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    elif opt.optimizer == "adam":
        optimizer = torch.optim.Adam(filtered_parameters, lr=opt.lr)
    else:
        raise ValueError(
            f"Unknown optimizer {opt.optimizer!r}; expected 'sgd', 'adadelta' or 'adam'")
    print("Optimizer:")
    print(optimizer)
    log.write(repr(optimizer) + "\n")

    if "super" in opt.schedule:
        # Momentum cycling only makes sense for SGD.
        cycle_momentum = opt.optimizer == "sgd"
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=opt.lr,
            cycle_momentum=cycle_momentum,
            div_factor=20,
            final_div_factor=1000,
            total_steps=opt.num_iter,
        )
        print("Scheduler:")
        print(scheduler)
        log.write(repr(scheduler) + "\n")

    # ---------------- final options ----------------
    opt_log = "------------ Options -------------\n"
    args = vars(opt)
    for k, v in args.items():
        if str(k) == "character" and len(str(v)) > 500:
            opt_log += f"{str(k)}: So many characters to show all: number of characters: {len(str(v))}\n"
        else:
            opt_log += f"{str(k)}: {str(v)}\n"
    opt_log += "---------------------------------------\n"
    print(opt_log)
    log.write(opt_log)
    log.close()

    # ---------------- start training ----------------
    start_iter = 0
    if opt.saved_model != "":
        try:
            # Checkpoints named like "iter_<N>.pth" resume from iteration N.
            start_iter = int(opt.saved_model.split("_")[-1].split(".")[0])
            print(f"continue to train, start_iter: {start_iter}")
        except ValueError:
            pass  # e.g. "best_score.pth": no iteration number, start from 0.

    start_time = time.time()
    best_score = -1
    # Bound before the loop so the final evaluation can log even if no iteration runs.
    iteration = start_iter

    # training loop
    for iteration in tqdm(
        range(start_iter + 1, opt.num_iter + 1),
        total=opt.num_iter,
        initial=start_iter,
        position=0,
        leave=True,
    ):
        if "MeanT" in opt.semi:
            image_tensors, image_tensors_ema, labels = train_loader.get_batch_ema()
        else:
            image_tensors, labels = train_loader.get_batch()

        image = image_tensors.to(device)
        labels_index, labels_length = converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        # default recognition loss part
        if "CTC" in opt.Prediction:
            preds = model(image)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds_log_softmax = preds.log_softmax(2).permute(1, 0, 2)
            loss = criterion(preds_log_softmax, labels_index, preds_size, labels_length)
        else:
            preds = model(image, labels_index[:, :-1])  # align with Attention.forward
            target = labels_index[:, 1:]  # without [SOS] Symbol
            loss = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        # semi supervised part (SemiSL)
        if "Pseudo" in opt.semi:
            image_unlabel, _ = train_loader_unlabel_semi.get_batch_two_images()
            image_unlabel = image_unlabel.to(device)
            loss_SemiSL = criterion_SemiSL(image_unlabel, model)
            loss = loss + loss_SemiSL
            semi_loss_avg.add(loss_SemiSL)
        elif "MeanT" in opt.semi:
            (
                image_tensors_unlabel,
                image_tensors_unlabel_ema,
            ) = train_loader_unlabel_semi.get_batch_two_images()
            image_unlabel = image_tensors_unlabel.to(device)
            student_input = torch.cat([image, image_unlabel], dim=0)

            image_ema = image_tensors_ema.to(device)
            image_unlabel_ema = image_tensors_unlabel_ema.to(device)
            teacher_input = torch.cat([image_ema, image_unlabel_ema], dim=0)

            loss_SemiSL = criterion_SemiSL(
                student_input=student_input,
                student_logit=preds,
                student=model,
                teacher_input=teacher_input,
                iteration=iteration,
            )
            loss = loss + loss_SemiSL
            semi_loss_avg.add(loss_SemiSL)

        model.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()
        train_loss_avg.add(loss)

        if "super" in opt.schedule:
            scheduler.step()
        else:
            adjust_learning_rate(optimizer, iteration, opt)

        # validation part.
        # To see training progress, we also conduct validation when 'iteration == 1'
        if iteration % opt.val_interval == 0 or iteration == 1:
            with open(f"./saved_models/{opt.exp_name}/log_train.txt", "a") as train_log:
                model.eval()
                with torch.no_grad():
                    (
                        valid_loss,
                        current_score,
                        preds,
                        confidence_score,
                        labels,
                        infer_time,
                        _length_of_data,
                    ) = validation(model, criterion, valid_loader, converter, opt)
                model.train()

                # keep best score (accuracy or norm ED) model on valid dataset
                # Do not use this on test datasets. It would be an unfair comparison
                # (training should be done without referring test set).
                if current_score > best_score:
                    best_score = current_score
                    torch.save(
                        model.state_dict(),
                        f"./saved_models/{opt.exp_name}/best_score.pth",
                    )

                # validation log: loss, lr, score (accuracy or norm ED), time.
                lr = optimizer.param_groups[0]["lr"]
                elapsed_time = time.time() - start_time
                valid_log = f"\n[{iteration}/{opt.num_iter}] Train_loss: {train_loss_avg.val():0.5f}, Valid_loss: {valid_loss:0.5f}"
                valid_log += f", Semi_loss: {semi_loss_avg.val():0.5f}\n"
                valid_log += f'{"Current_score":17s}: {current_score:0.2f}, Current_lr: {lr:0.7f}\n'
                valid_log += f'{"Best_score":17s}: {best_score:0.2f}, Infer_time: {infer_time:0.1f}, Elapsed_time: {elapsed_time:0.1f}'

                # show some predicted results
                dashed_line = "-" * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f"{dashed_line}\n{head}\n{dashed_line}\n"
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if "Attn" in opt.Prediction:
                        gt = gt[:gt.find("[EOS]")]
                        pred = pred[:pred.find("[EOS]")]
                    predicted_result_log += f"{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n"
                predicted_result_log += f"{dashed_line}"
                valid_log = f"{valid_log}\n{predicted_result_log}"
                print(valid_log)
                train_log.write(valid_log + "\n")

                opt.writer.add_scalar("train/train_loss", float(f"{train_loss_avg.val():0.5f}"), iteration)
                opt.writer.add_scalar("train/semi_loss", float(f"{semi_loss_avg.val():0.5f}"), iteration)
                opt.writer.add_scalar("train/lr", float(f"{lr:0.7f}"), iteration)
                opt.writer.add_scalar("train/elapsed_time", float(f"{elapsed_time:0.1f}"), iteration)
                opt.writer.add_scalar("valid/valid_loss", float(f"{valid_loss:0.5f}"), iteration)
                opt.writer.add_scalar("valid/current_score", float(f"{current_score:0.2f}"), iteration)
                opt.writer.add_scalar("valid/best_score", float(f"{best_score:0.2f}"), iteration)

                train_loss_avg.reset()
                semi_loss_avg.reset()

    # ---------------- evaluation at the end of training ----------------
    print("Start evaluation on benchmark testset")
    # keep evaluation model and result logs
    os.makedirs(f"./result/{opt.exp_name}", exist_ok=True)
    os.makedirs("./evaluation_log", exist_ok=True)
    saved_best_model = f"./saved_models/{opt.exp_name}/best_score.pth"
    model.load_state_dict(torch.load(saved_best_model))

    opt.eval_type = "benchmark"
    model.eval()
    with torch.no_grad():
        total_accuracy, eval_data_list, accuracy_list = benchmark_all_eval(
            model, criterion, converter, opt)

    opt.writer.add_scalar("test/total_accuracy", float(f"{total_accuracy:0.2f}"), iteration)
    for eval_data, accuracy in zip(eval_data_list, accuracy_list):
        accuracy = float(accuracy)
        opt.writer.add_scalar(f"test/{eval_data}", float(f"{accuracy:0.2f}"), iteration)

    print(
        f'finished the experiment: {opt.exp_name}, "CUDA_VISIBLE_DEVICES" was {opt.CUDA_VISIBLE_DEVICES}'
    )
def test(opt):
    """Write EigenCAM heat-map visualizations for the evaluation set.

    Builds the model from ``opt``, loads ``opt.saved_model``, and for every
    batch of ``opt.eval_data`` writes ``visual_image2/<batch index>_2.bmp``:
    the first image of the batch stacked above its CAM overlay, both
    upscaled 5x with cubic interpolation.

    Side effects: sets ``opt.num_class`` and ``opt.exp_name`` (and
    ``opt.input_channel`` when ``opt.rgb``), and creates ``visual_image2/``.
    """
    import os

    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))
    opt.exp_name = '_'.join(opt.saved_model.split('/')[1:])

    AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                           keep_ratio_with_pad=opt.PAD)
    eval_data, _eval_data_log = hierarchical_dataset(root=opt.eval_data, opt=opt)
    evaluation_loader = torch.utils.data.DataLoader(
        eval_data,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_evaluation,
        pin_memory=True)

    # Construct the CAM object once, and then re-use it on every batch
    # (previously it was rebuilt, and its hooks re-registered, per batch).
    target_layer = model.module.FeatureExtraction.ConvNet.layer4[-1]
    cam = EigenCAM(model=model, target_layer=target_layer, use_cuda=opt.use_cuda)
    to_pil = transforms.ToPILImage()

    # cv2.imwrite does not create missing directories.
    output_dir = 'visual_image2'
    os.makedirs(output_dir, exist_ok=True)

    for i, (image_tensors, labels) in enumerate(evaluation_loader):
        # The encoded label sequence is passed as the CAM target category.
        text_for_loss, _ = converter.encode(labels, batch_max_length=opt.batch_max_length)
        grayscale_cam = cam(input_tensor=image_tensors, target_category=text_for_loss)
        # Only the first image of each batch is visualized.
        grayscale_cam = grayscale_cam[0, :]

        # NOTE(review): ToPILImage expects values in [0, 1]; confirm that
        # AlignCollate does not normalize to another range (e.g. [-1, 1]).
        source_image = to_pil(image_tensors[0])
        source_image = cv2.cvtColor(np.asarray(source_image), cv2.COLOR_RGB2BGR)
        rgb_img = np.float32(source_image) / 255
        visualization = show_cam_on_image(rgb_img, grayscale_cam)

        source_image = cv2.resize(source_image, (0, 0), fx=5, fy=5,
                                  interpolation=cv2.INTER_CUBIC)
        visualization = cv2.resize(visualization, (0, 0), fx=5, fy=5,
                                   interpolation=cv2.INTER_CUBIC)
        cat_image = np.vstack((source_image, visualization))
        cv2.imwrite(os.path.join(output_dir, str(i) + '_2.bmp'), cat_image)
def train(opt):
    """Train the recognizer until ``opt.num_iter``, then terminate the process.

    Every ``opt.valInterval`` iterations the model is validated, the log is
    appended to ``./saved_models/<experiment_name>/log_train.txt``, and
    ``best_accuracy.pth`` / ``best_norm_ED.pth`` are refreshed on improvement.
    A snapshot ``iter_<N>.pth`` is also written every 100000 iterations.

    Note: ``opt.select_data`` and ``opt.batch_ratio`` are split in place into
    lists, and the function ends with ``sys.exit()`` rather than returning.
    """
    # dataset preparation
    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    print('-' * 80)

    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except ValueError:
            # kaiming_normal_ rejects tensors with fewer than 2 dims (e.g. batchnorm weights).
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # setup loss
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # filter that only require gradient decent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # final options
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # start training
    start_iter = 0
    if opt.saved_model != '':
        try:
            # Checkpoints named like "iter_<N>.pth" resume from iteration N.
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass  # e.g. "best_accuracy.pth" (fine-tuning): start from 0.

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    i = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)
            # cuDNN is disabled only for the CTC loss; restore it even if the loss raises.
            torch.backends.cudnn.enabled = False
            try:
                cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
            finally:
                torch.backends.cudnn.enabled = True
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds,
                     confidence_score, labels, infer_time, _length_of_data) = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                print(loss_log)
                log.write(loss_log + '\n')
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'
                print(current_model_log)
                log.write(current_model_log + '\n')

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'
                print(best_model_log)
                log.write(best_model_log + '\n')

                # show some predicted results
                print('-' * 80)
                print(f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F')
                log.write(f'{"Ground Truth":25s} | {"Prediction":25s} | {"Confidence Score"}\n')
                print('-' * 80)
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    print(f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}')
                    log.write(f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n')

                print('-' * 80)

        # save model per 1e+5 iter (integer modulo instead of float 1e+5).
        if (i + 1) % 100_000 == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
def train(opt, show_number=2, amp=False):
    """Train (or fine-tune) the recognizer until ``opt.num_iter``, then exit the process.

    Args:
        opt: Parsed options namespace. ``opt.select_data`` and
            ``opt.batch_ratio`` are split in place into lists.
        show_number: How many validation predictions to print and log at each
            validation step. It is clamped to the validation batch size.
        amp: If true, run the forward pass under ``torch.cuda.amp.autocast``
            and scale the loss with a ``GradScaler``.

    Every ``opt.valInterval`` iterations (except iteration 0) the model is
    validated, and ``best_accuracy.pth`` / ``best_norm_ED.pth`` are refreshed
    on improvement. A snapshot ``iter_<N>.pth`` is written every 10000
    iterations. The function ends with ``sys.exit()`` rather than returning.
    """
    # dataset preparation
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # `with` guarantees the dataset log is closed even if loader setup fails.
    with open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a',
              encoding="utf8") as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD,
                                          contrast_adjust=opt.contrast_adjust)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        num_workers = int(opt.workers)
        loader_kwargs = {}
        if num_workers > 0:
            # DataLoader rejects prefetch_factor when there are no worker processes.
            loader_kwargs['prefetch_factor'] = 512
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=min(32, opt.batch_size),
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=num_workers,
            collate_fn=AlignCollate_valid,
            pin_memory=True,
            **loader_kwargs)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    if opt.saved_model != '':
        pretrained_dict = torch.load(opt.saved_model)
        if opt.new_prediction:
            # Match the checkpoint's head size so its state dict loads, then replace it below.
            model.Prediction = nn.Linear(
                model.SequenceModeling_output,
                len(pretrained_dict['module.Prediction.weight']))

        model = torch.nn.DataParallel(model).to(device)
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(pretrained_dict, strict=False)
        else:
            model.load_state_dict(pretrained_dict)
        if opt.new_prediction:
            # Fresh prediction head sized for the current character set.
            model.module.Prediction = nn.Linear(
                model.module.SequenceModeling_output, opt.num_class)
            for name, param in model.module.Prediction.named_parameters():
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            model = model.to(device)
    else:
        # weight initialization
        for name, param in model.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except ValueError:
                # kaiming_normal_ rejects tensors with fewer than 2 dims (e.g. batchnorm weights).
                if 'weight' in name:
                    param.data.fill_(1)
                continue
        model = torch.nn.DataParallel(model).to(device)

    model.train()
    print("Model:")
    print(model)
    count_parameters(model)

    # setup loss
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # freeze some layers; options missing from `opt` mean "do not freeze".
    # (The attribute name spelling `freeze_FeatureFxtraction` is part of the config interface.)
    if getattr(opt, 'freeze_FeatureFxtraction', False):
        for param in model.module.FeatureExtraction.parameters():
            param.requires_grad = False
    if getattr(opt, 'freeze_SequenceModeling', False):
        for param in model.module.SequenceModeling.parameters():
            param.requires_grad = False

    # filter that only require gradient decent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.optim == 'adam':
        # Adam deliberately uses its default lr here; opt.lr is tuned for Adadelta.
        optimizer = optim.Adam(filtered_parameters)
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # final options
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a', encoding="utf8") as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # start training
    start_iter = 0
    if opt.saved_model != '':
        try:
            # Checkpoints named like "iter_<N>.pth" resume from iteration N.
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass  # no iteration number in the filename: start from 0.

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter

    scaler = GradScaler()
    t1 = time.time()

    def compute_cost():
        """Fetch one training batch, run the forward pass and return the loss."""
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)
            # cuDNN is disabled only for the CTC loss; restore it even if the loss raises.
            torch.backends.cudnn.enabled = False
            try:
                return criterion(preds, text.to(device), preds_size.to(device), length.to(device))
            finally:
                torch.backends.cudnn.enabled = True

        preds = model(image, text[:, :-1])  # align with Attention.forward
        target = text[:, 1:]  # without [GO] Symbol
        return criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

    while True:
        # train part
        optimizer.zero_grad(set_to_none=True)

        if amp:
            # Only the forward pass runs under autocast; backward and the
            # optimizer step happen outside it, as the PyTorch AMP docs recommend.
            with autocast():
                cost = compute_cost()
            scaler.scale(cost).backward()
            scaler.unscale_(optimizer)  # so clipping sees true gradient magnitudes
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            cost = compute_cost()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()
        loss_avg.add(cost)

        # validation part
        if (i % opt.valInterval == 0) and (i != 0):
            print('training time: ', time.time() - t1)
            t1 = time.time()
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a',
                      encoding="utf8") as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds, confidence_score,
                     labels, infer_time, _length_of_data) = validation(
                        model, criterion, valid_loader, converter, opt, device)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.4f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(),
                               f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(),
                               f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.4f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results from a random window of the batch
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                # Clamp so randint's upper bound is never negative on a small final batch.
                n_show = min(show_number, len(labels))
                start = random.randint(0, len(labels) - n_show)
                for gt, pred, confidence in zip(labels[start:start + n_show],
                                                preds[start:start + n_show],
                                                confidence_score[start:start + n_show]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')
                print('validation time: ', time.time() - t1)
                t1 = time.time()

        # save model per 1e+4 iter (integer modulo instead of float 1e+4).
        if (i + 1) % 10_000 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
def train(opt):
    """Train a text recognizer with DistributedDataParallel, one CUDA device per process.

    Builds the balanced training dataset and its validation loader.
    ``opt.Prediction`` ('CTC', 'Attn' or 'CTC_Attn') selects the label
    converter(s) and loss(es). The function then builds the optimizer, wraps
    the model with apex AMP and DDP, and optionally resumes from
    ``opt.continue_model``.

    The training loop runs forever. Gradients are accumulated over
    ``opt.batch_mul`` micro-batches per optimizer step, and ``i`` counts
    optimizer steps. Validation runs every ``opt.valInterval`` steps, and also
    at step 0 when resuming. The process ends with ``sys.exit()`` when
    ``i == opt.num_iter``, or once ``i > opt.prof_iter`` when
    ``opt.prof_iter > 0``.

    Args:
        opt: options namespace. Mutated in place: ``device`` and ``num_class``
            are always set, and ``input_channel`` is set to 3 when ``opt.rgb``.

    Files written under ``./saved_models/{opt.experiment_name}/``:
    ``opt.txt``, ``log_train.txt``, ``best_accuracy.{pth,model}`` and
    ``best_norm_ED.{pth,model}`` (rank 0 only), and ``iter_{i+1}.pth`` every
    ``opt.save_interval`` steps (rank 0 only).
    """
    print(opt.local_rank)
    opt.device = torch.device('cuda:{}'.format(opt.local_rank))
    device = opt.device
    save_dir = f'./saved_models/{opt.experiment_name}'

    # ---- dataset preparation ----
    train_dataset = Batch_Balanced_Dataset(opt)
    valid_loader = train_dataset.getValDataloader()
    print('-' * 80)

    # ---- model configuration ----
    if 'CTC' == opt.Prediction:
        converter = CTCLabelConverter(opt.character, opt)
    elif 'Attn' == opt.Prediction:
        converter = AttnLabelConverter(opt.character, opt)
    elif 'CTC_Attn' == opt.Prediction:
        # (ctc_converter, attn_converter): used as converter[0] / converter[1]
        converter = (CTCLabelConverter(opt.character, opt),
                     AttnLabelConverter(opt.character, opt))
    opt.num_class = len(opt.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    model.to(opt.device)
    print(model)
    print('model input parameters', opt.rgb, opt.imgH, opt.imgW,
          opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length,
          opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling,
          opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # ---- setup loss ----
    if 'CTC' == opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    elif 'Attn' == opt.Prediction:
        # (attention CE loss ignoring [GO] = index 0, position-regression loss)
        criterion = (torch.nn.CrossEntropyLoss(ignore_index=0).to(device),
                     torch.nn.MSELoss(reduction="sum").to(device))
    elif 'CTC_Attn' == opt.Prediction:
        # (CTC loss, attention CE loss, position-regression loss)
        criterion = (torch.nn.CTCLoss(zero_infinity=True).to(device),
                     torch.nn.CrossEntropyLoss(ignore_index=0).to(device),
                     torch.nn.MSELoss(reduction='sum').to(device))

    # loss averager
    loss_avg = Averager()

    # only parameters that require gradient descent go to the optimizer
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    if opt.local_rank == 0:
        print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.sgd:
        optimizer = optim.SGD(filtered_parameters, lr=opt.lr, momentum=0.9,
                              weight_decay=opt.weight_decay)
    elif opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr,
                                   rho=opt.rho, eps=opt.eps)
    if opt.local_rank == 0:
        print("Optimizer:")
        print(optimizer)

    if opt.sync_bn:
        model = apex.parallel.convert_syncbn_model(model)
    if opt.amp > 1:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level="O" + str(opt.amp),
                                          keep_batchnorm_fp32=True,
                                          loss_scale="dynamic")
    else:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level="O" + str(opt.amp))

    # data parallel for multi-GPU
    model = DDP(model)

    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        map_location = torch.device('cuda', torch.cuda.current_device())
        try:
            model.load_state_dict(torch.load(opt.continue_model,
                                             map_location=map_location))
        except Exception:
            # Strict load failed: copy only the entries whose name and shape
            # match the current model.
            traceback.print_exc()
            print(f'COPYING pretrained model from {opt.continue_model}')
            pretrained_dict = torch.load(opt.continue_model,
                                         map_location=map_location)
            model_dict = model.state_dict()
            pretrained_dict2 = dict()
            for k, v in pretrained_dict.items():
                if opt.Prediction == 'Attn':
                    # checkpoint's attention head is named Prediction_attn here
                    if 'module.Prediction_attn.' in k:
                        k = k.replace('module.Prediction_attn.',
                                      'module.Prediction.')
                if k in model_dict and model_dict[k].shape == v.shape:
                    pretrained_dict2[k] = v
            model_dict.update(pretrained_dict2)
            model.load_state_dict(model_dict)

    model.train()

    # ---- final options ----
    with open(f'{save_dir}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        for k, v in vars(opt).items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        opt_log += str(model)
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6  # lower is better for norm_ED in this variant
    i = start_iter
    ct = opt.batch_mul  # micro-batches left before the next optimizer step

    model.zero_grad()
    dist.barrier()
    while True:
        # train part
        start = time.time()
        image, labels, pos = train_dataset.sync_get_batch()
        end = time.time()
        data_t = end - start

        start = time.time()
        batch_size = image.size(0)
        if 'CTC' == opt.Prediction:
            text, length = converter.encode(
                labels, batch_max_length=opt.batch_max_length)
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size).to(device)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format

            # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            torch.backends.cudnn.enabled = False
            cost = criterion(preds, text, preds_size, length)
            torch.backends.cudnn.enabled = True

        elif 'Attn' == opt.Prediction:
            text, length = converter.encode(
                labels, batch_max_length=opt.batch_max_length)
            preds = model(image, text[:, :-1])  # align with Attention.forward
            preds_attn = preds[0]
            preds_alpha = preds[1]
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion[0](preds_attn.view(-1, preds_attn.shape[-1]),
                                target.contiguous().view(-1))
            if opt.posreg_w > 0.001:
                cost_pos = alpha_loss(preds_alpha, pos, opt, criterion[1])
                print('attn_cost = ', cost, 'pos_cost = ', cost_pos * opt.posreg_w)
                cost += opt.posreg_w * cost_pos
            else:
                # `cost` is the attention loss here (cost_attn only exists in
                # the CTC_Attn branch)
                print('attn_cost = ', cost)

        elif 'CTC_Attn' == opt.Prediction:
            text_ctc, length_ctc = converter[0].encode(
                labels, batch_max_length=opt.batch_max_length)
            text_attn, length_attn = converter[1].encode(
                labels, batch_max_length=opt.batch_max_length)

            # ctc prediction and loss
            # should input text_attn here
            preds = model(image, text_attn[:, :-1])
            preds_ctc = preds[0].log_softmax(2)
            preds_ctc_size = torch.IntTensor([preds_ctc.size(1)] * batch_size).to(device)
            preds_ctc = preds_ctc.permute(1, 0, 2)  # to use CTCLoss format

            # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            torch.backends.cudnn.enabled = False
            cost_ctc = criterion[0](preds_ctc, text_ctc, preds_ctc_size, length_ctc)
            torch.backends.cudnn.enabled = True

            # attention prediction and loss
            preds_attn = preds[1][0]  # align with Attention.forward
            preds_alpha = preds[1][1]
            target = text_attn[:, 1:]  # without [GO] Symbol
            cost_attn = criterion[1](preds_attn.view(-1, preds_attn.shape[-1]),
                                     target.contiguous().view(-1))

            cost = (opt.ctc_attn_loss_ratio * cost_ctc
                    + (1 - opt.ctc_attn_loss_ratio) * cost_attn)

            if opt.posreg_w > 0.001:
                cost_pos = alpha_loss(preds_alpha, pos, opt, criterion[2])
                cost += opt.posreg_w * cost_pos
                # reduced copies are for logging only
                cost_ctc = reduce_tensor(cost_ctc)
                cost_attn = reduce_tensor(cost_attn)
                cost_pos = reduce_tensor(cost_pos)
                if opt.local_rank == 0:
                    print('ctc_cost = ', cost_ctc, 'attn_cost = ', cost_attn,
                          'pos_cost = ', cost_pos * opt.posreg_w)
            else:
                cost_ctc = reduce_tensor(cost_ctc)
                cost_attn = reduce_tensor(cost_attn)
                if opt.local_rank == 0:
                    print('ctc_cost = ', cost_ctc, 'attn_cost = ', cost_attn)

        # scale so that accumulated gradients average over the micro-batches
        cost /= opt.batch_mul
        if opt.amp > 0:
            with amp.scale_loss(cost, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            cost.backward()

        # https://github.com/davidlmorton/learning-rate-schedules/blob/master/increasing_batch_size_without_increasing_memory.ipynb
        ct -= 1
        if ct != 0:
            # keep accumulating; `i` only advances on optimizer steps
            continue
        if opt.amp > 0:
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           opt.grad_clip)
        else:
            # gradient clipping with 5 (Default)
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        optimizer.step()
        model.zero_grad()
        ct = opt.batch_mul

        train_t = time.time() - start
        cost = reduce_tensor(cost)
        loss_avg.add(cost)
        if opt.local_rank == 0:
            print('iter', i, 'loss =', cost, ', data_t=', data_t, ',train_t=',
                  train_t, ', batchsz=', opt.batch_mul * opt.batch_size)
        sys.stdout.flush()

        # validation part
        if (i > 0 and i % opt.valInterval == 0) or (i == 0 and opt.continue_model != ''):
            elapsed_time = time.time() - start_time
            print(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}')
            # for log
            with open(f'{save_dir}/log_train.txt', 'a') as log:
                log.write(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n')
                loss_avg.reset()

                model.eval()
                with torch.no_grad():
                    if 'CTC_Attn' in opt.Prediction:
                        # we only count for attention accuracy, because ctc is used to help attention
                        (valid_loss, current_accuracy_ctc, current_accuracy,
                         current_norm_ED_ctc, current_norm_ED, preds, labels,
                         infer_time, length_of_data) = validation(
                             model, criterion[1], valid_loader, converter[1],
                             opt, converter[0])
                    elif 'Attn' in opt.Prediction:
                        (valid_loss, current_accuracy, current_norm_ED, preds,
                         labels, infer_time, length_of_data) = validation(
                             model, criterion[0], valid_loader, converter, opt)
                    else:
                        (valid_loss, current_accuracy, current_norm_ED, preds,
                         labels, infer_time, length_of_data) = validation(
                             model, criterion, valid_loader, converter, opt)
                model.train()

                for pred, gt in zip(preds[:10], labels[:10]):
                    if 'Attn' in opt.Prediction:
                        pred = pred[:pred.find('[s]')]
                        gt = gt[:gt.find('[s]')]
                    print(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}')
                    log.write(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}\n')

                valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'

                if 'CTC_Attn' in opt.Prediction:
                    valid_log += f' ctc_accuracy: {current_accuracy_ctc:0.3f}, ctc_norm_ED: {current_norm_ED_ctc:0.2f}'
                    current_accuracy = max(current_accuracy, current_accuracy_ctc)
                    current_norm_ED = min(current_norm_ED, current_norm_ED_ctc)

                if opt.local_rank == 0:
                    print(valid_log)
                    log.write(valid_log + '\n')

                    # keep best accuracy model. Only rank 0 writes these
                    # checkpoints: every rank would otherwise save to the same
                    # paths at the same time.
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(model.state_dict(), f'{save_dir}/best_accuracy.pth')
                        torch.save(model, f'{save_dir}/best_accuracy.model')
                    if current_norm_ED < best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(model.state_dict(), f'{save_dir}/best_norm_ED.pth')
                        torch.save(model, f'{save_dir}/best_norm_ED.model')
                    best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                    print(best_model_log)
                    log.write(best_model_log + '\n')

        # save model per iter.
        if (i + 1) % opt.save_interval == 0 and opt.local_rank == 0:
            torch.save(model.state_dict(), f'{save_dir}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        if opt.prof_iter > 0 and i > opt.prof_iter:
            sys.exit()
        i += 1
def train(opt):
    """Jointly train the style/word generator, an OCR model and a discriminator.

    The batch size is doubled before the dataset is built. Each batch is then
    split in half:
      * the first half (with random holes from ``genRandomMasks``) is the
        generator input;
      * the second half is used as "real" images by the discriminator.

    Per iteration:
      * The OCR model is trained on the input images, unless ``opt.ocrFixed``
        is set; in that case it only predicts the labels ``labels_1``.
      * The discriminator is updated on detached reconstructions.
      * The generator is updated with a weighted sum of OCR content,
        image-reconstruction, adversarial and style-reconstruction losses.

    Validation and plotting run when ``(iteration + 1) % opt.valInterval == 0``
    and at ``iteration == 0``. Checkpoints ``iter_<N>_{synth,ocr,dis}.pth``
    are written whenever ``iteration % 1e5 == 0``. The process exits via
    ``sys.exit()`` when ``iteration + 1 == opt.num_iter``.

    Args:
        opt: options namespace, mutated in place:
            * ``select_data`` and ``batch_ratio`` are split on '-';
            * ``batch_size`` is doubled;
            * ``num_class`` (and ``input_channel`` when ``opt.rgb``) is set;
            * with ``opt.modelFolderFlag``, ``saved_synth_model`` and
              ``saved_dis_model`` may be replaced by the highest-iteration
              checkpoint found in the experiment folder.
    """
    plotDir = os.path.join(opt.exp_dir, opt.exp_name, 'plots')
    os.makedirs(plotDir, exist_ok=True)

    lib.print_model_settings(locals().copy())

    import re  # function-local: only used for checkpoint discovery below

    exp_path = os.path.join(opt.exp_dir, opt.exp_name)

    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')

    # considering the real images for discriminator: each batch is split in half below
    opt.batch_size = opt.batch_size * 2
    train_dataset = Batch_Balanced_Dataset(opt)

    with open(os.path.join(exp_path, 'log_dataset.txt'), 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=False,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = AdaINGenV4(opt)
    ocrModel = Model(opt)
    disModel = MsImageDisV1(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for currModel in (model, ocrModel, disModel):
        for name, param in currModel.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception:  # for batchnorm.
                if 'weight' in name:
                    param.data.fill_(1)
                continue

    # data parallel for multi-GPU
    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    if not opt.ocrFixed:
        ocrModel.train()
    else:
        # frozen OCR: SequenceModeling is deliberately left in train mode
        ocrModel.module.Transformation.eval()
        ocrModel.module.FeatureExtraction.eval()
        ocrModel.module.AdaptiveAvgPool.eval()
        ocrModel.module.Prediction.eval()

    model = torch.nn.DataParallel(model).to(device)
    model.train()

    disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()

    def _ckpt_iter(path):
        """Return the iteration number in an ``iter_<N>_*.pth`` file name, or -1."""
        match = re.search(r'iter_(\d+)_', os.path.basename(path))
        return int(match.group(1)) if match else -1

    if opt.modelFolderFlag:
        # Resume from the newest checkpoints. glob() returns paths in arbitrary
        # order, so select the highest iteration number explicitly.
        synth_ckpts = glob.glob(os.path.join(exp_path, "iter_*_synth.pth"))
        if synth_ckpts:
            opt.saved_synth_model = max(synth_ckpts, key=_ckpt_iter)
        dis_ckpts = glob.glob(os.path.join(exp_path, "iter_*_dis.pth"))
        if dis_ckpts:
            opt.saved_dis_model = max(dis_ckpts, key=_ckpt_iter)

    def _load_weights(net, path):
        """Load ``path`` into ``net``; the load is non-strict when ``opt.FT``."""
        if opt.FT:
            net.load_state_dict(torch.load(path), strict=False)
        else:
            net.load_state_dict(torch.load(path))

    # loading pre-trained models ('' and 'None' mean "no checkpoint")
    if opt.saved_ocr_model not in ('', 'None'):
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        _load_weights(ocrModel, opt.saved_ocr_model)
    print("OCRModel:")
    print(ocrModel)

    if opt.saved_synth_model not in ('', 'None'):
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        _load_weights(model, opt.saved_synth_model)
    print("SynthModel:")
    print(model)

    if opt.saved_dis_model not in ('', 'None'):
        print(f'loading pretrained discriminator model from {opt.saved_dis_model}')
        _load_weights(disModel, opt.saved_dis_model)
    print("DisModel:")
    print(disModel)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        # ignore [GO] token = ignore index 0
        ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    if opt.imgReconLoss == 'l1':
        recCriterion = torch.nn.L1Loss()
    elif opt.imgReconLoss == 'ssim':
        recCriterion = ssim
    elif opt.imgReconLoss == 'ms-ssim':
        recCriterion = msssim

    if opt.styleLoss == 'l1':
        styleRecCriterion = torch.nn.L1Loss()
    elif opt.styleLoss == 'triplet':
        styleRecCriterion = torch.nn.TripletMarginLoss(margin=opt.tripletMargin, p=1)
    # for validation; check only positive pairs
    styleTestRecCriterion = torch.nn.L1Loss()

    # loss averagers
    loss_avg_ocr = Averager()
    loss_avg = Averager()
    loss_avg_dis = Averager()
    loss_avg_ocrRecon_1 = Averager()
    loss_avg_ocrRecon_2 = Averager()
    loss_avg_gen = Averager()
    loss_avg_imgRecon = Averager()
    loss_avg_styRecon = Averager()

    def _build_optimizer(net, params_label, optim_label):
        """Create the configured optimizer over ``net``'s trainable parameters and log it."""
        params = [p for p in net.parameters() if p.requires_grad]
        print(params_label, sum(np.prod(p.size()) for p in params))
        if opt.optim == 'adam':
            net_optimizer = optim.Adam(params, lr=opt.lr,
                                       betas=(opt.beta1, opt.beta2),
                                       weight_decay=opt.weight_decay)
        else:
            net_optimizer = optim.Adadelta(params, lr=opt.lr, rho=opt.rho,
                                           eps=opt.eps,
                                           weight_decay=opt.weight_decay)
        print(optim_label)
        print(net_optimizer)
        return net_optimizer

    optimizer = _build_optimizer(model, 'Trainable params num : ', "SynthOptimizer:")
    ocr_optimizer = _build_optimizer(ocrModel, 'OCR Trainable params num : ', "OCROptimizer:")
    dis_optimizer = _build_optimizer(disModel, 'Dis Trainable params num : ', "DisOptimizer:")

    # ---- final options ----
    with open(os.path.join(exp_path, 'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        for k, v in vars(opt).items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_synth_model not in ('', 'None'):
        try:
            # checkpoint names look like iter_<N>_synth.pth
            start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except (ValueError, IndexError):
            pass

    # get schedulers
    scheduler = get_scheduler(optimizer, opt)
    ocr_scheduler = get_scheduler(ocr_optimizer, opt)
    dis_scheduler = get_scheduler(dis_optimizer, opt)

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    best_accuracy_ocr = -1
    best_norm_ED_ocr = -1
    iteration = start_iter

    while True:
        # train part
        if opt.lr_policy != "None":
            scheduler.step()
            ocr_scheduler.step()
            dis_scheduler.step()

        image_tensors_all, labels_1_all, labels_2_all = train_dataset.get_batch()

        # first half: generator input; second half: real images for the discriminator
        disCnt = image_tensors_all.size(0) // 2
        image_tensors = image_tensors_all[:disCnt]
        image_tensors_real = image_tensors_all[disCnt:disCnt + disCnt]
        labels_gt = labels_1_all[:disCnt]
        labels_2 = labels_2_all[:disCnt]

        image_hole_tensors, image_mask_tensors = genRandomMasks(image_tensors)

        image = image_tensors.to(device)
        image_hole_tensors = image_hole_tensors.to(device)
        image_mask_tensors = image_mask_tensors.to(device)
        image_real = image_tensors_real.to(device)
        batch_size = image.size(0)

        # generate text (labels) from ocr.forward when the OCR model is frozen
        if opt.ocrFixed:
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                # CTC decoding uses every output time step
                preds = ocrModel(image, text_for_pred)
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                labels_1 = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = ocrModel(image, text_for_pred, is_train=False)
                _, preds_index = preds.max(2)
                labels_1 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(labels_1):
                    pred_EOS = pred.find('[s]')
                    # Prune after the "end of sentence" token ([s]). If no [s]
                    # was emitted, keep the whole string: slicing at -1 would
                    # drop the last character.
                    if pred_EOS != -1:
                        labels_1[idx] = pred[:pred_EOS]
        else:
            labels_1 = labels_gt

        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

        # forward pass from style and word generator
        images_recon_1, images_recon_2, style = model(image_hole_tensors, text_1, text_2)

        if 'CTC' in opt.Prediction:
            if not opt.ocrFixed:
                # ocr training with orig image
                preds_ocr = ocrModel(image, text_1)
                preds_size_ocr = torch.IntTensor([preds_ocr.size(1)] * batch_size)
                preds_ocr = preds_ocr.log_softmax(2).permute(1, 0, 2)
                ocrCost_train = ocrCriterion(preds_ocr, text_1, preds_size_ocr, length_1)

            # content loss for reconstructed images
            preds_1 = ocrModel(images_recon_1, text_1)
            preds_size_1 = torch.IntTensor([preds_1.size(1)] * batch_size)
            preds_1 = preds_1.log_softmax(2).permute(1, 0, 2)

            preds_2 = ocrModel(images_recon_2, text_2)
            preds_size_2 = torch.IntTensor([preds_2.size(1)] * batch_size)
            preds_2 = preds_2.log_softmax(2).permute(1, 0, 2)

            ocrCost_1 = ocrCriterion(preds_1, text_1, preds_size_1, length_1)
            ocrCost_2 = ocrCriterion(preds_2, text_2, preds_size_2, length_2)
        else:
            if not opt.ocrFixed:
                # ocr training with orig image
                preds_ocr = ocrModel(image, text_1[:, :-1])  # align with Attention.forward
                target_ocr = text_1[:, 1:]  # without [GO] Symbol
                ocrCost_train = ocrCriterion(preds_ocr.view(-1, preds_ocr.shape[-1]),
                                             target_ocr.contiguous().view(-1))

            # content loss for reconstructed images
            preds_1 = ocrModel(images_recon_1, text_1[:, :-1], is_train=False)  # align with Attention.forward
            target_1 = text_1[:, 1:]  # without [GO] Symbol

            preds_2 = ocrModel(images_recon_2, text_2[:, :-1], is_train=False)  # align with Attention.forward
            target_2 = text_2[:, 1:]  # without [GO] Symbol

            ocrCost_1 = ocrCriterion(preds_1.view(-1, preds_1.shape[-1]),
                                     target_1.contiguous().view(-1))
            ocrCost_2 = ocrCriterion(preds_2.view(-1, preds_2.shape[-1]),
                                     target_2.contiguous().view(-1))

        if not opt.ocrFixed:
            # training OCR
            ocrModel.zero_grad()
            ocrCost_train.backward()
            ocr_optimizer.step()
            loss_avg_ocr.add(ocrCost_train)
        else:
            # if ocr is fixed; ignore this loss
            loss_avg_ocr.add(torch.tensor(0.0))

        # Domain discriminator: Dis update
        disModel.zero_grad()
        disCost = opt.disWeight * 0.5 * (
            disModel.module.calc_dis_loss(images_recon_1.detach(), image_real)
            + disModel.module.calc_dis_loss(images_recon_2.detach(), image))
        disCost.backward()
        dis_optimizer.step()
        loss_avg_dis.add(disCost)

        # [Style Encoder] + [Word Generator] update
        # Adversarial loss
        disGenCost = 0.5 * (disModel.module.calc_gen_loss(images_recon_1)
                            + disModel.module.calc_gen_loss(images_recon_2))

        # Input image reconstruction loss
        if opt.imgReconLoss == 'ssim':
            recCost = -1 * recCriterion(images_recon_1, image, val_range=2)
        elif opt.imgReconLoss == 'ms-ssim':
            recCost = -1 * recCriterion(images_recon_1, image, val_range=2,
                                        normalize='relu')
        else:
            # masked and unmasked regions weighted equally
            recCost = 0.5 * (
                recCriterion(image_mask_tensors * images_recon_1,
                             image_mask_tensors * image)
                + recCriterion((1 - image_mask_tensors) * images_recon_1,
                               (1 - image_mask_tensors) * image))

        # Pair style reconstruction loss
        if opt.styleReconWeight == 0.0:
            styleRecCost = torch.tensor(0.0)
        else:
            style_target = style.detach() if opt.styleDetach else style
            if opt.styleLoss == 'l1':
                styleRecCost = styleRecCriterion(
                    model(images_recon_2, None, None, styleFlag=True),
                    style_target)
            elif opt.styleLoss == 'triplet':
                styleRecCost = styleRecCriterion(
                    model(images_recon_2, None, None, styleFlag=True),
                    style_target,
                    model(image_real, None, None, styleFlag=True))

        # OCR Content cost
        ocrCost = 0.5 * (opt.ocrWeight_1 * ocrCost_1 + opt.ocrWeight_2 * ocrCost_2)

        cost = (opt.ocrWeight * ocrCost + opt.reconWeight * recCost
                + opt.disWeight * disGenCost + opt.styleReconWeight * styleRecCost)

        model.zero_grad()
        ocrModel.zero_grad()
        disModel.zero_grad()
        cost.backward()
        optimizer.step()
        loss_avg.add(cost)

        # Individual losses
        loss_avg_ocrRecon_1.add(opt.ocrWeight * 0.5 * ocrCost_1)
        loss_avg_ocrRecon_2.add(opt.ocrWeight * 0.5 * ocrCost_2)
        loss_avg_gen.add(opt.disWeight * disGenCost)
        loss_avg_imgRecon.add(opt.reconWeight * recCost)
        loss_avg_styRecon.add(opt.styleReconWeight * styleRecCost)

        # validation part
        # To see training progress, we also conduct validation when 'iteration == 0'
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            # Save training images
            train_img_dir = os.path.join(exp_path, 'trainImages', str(iteration))
            os.makedirs(train_img_dir, exist_ok=True)
            for trImgCntr in range(batch_size):
                try:
                    save_image(tensor2im(image[trImgCntr].detach()),
                               os.path.join(train_img_dir, str(trImgCntr) + '_input_' + labels_gt[trImgCntr] + '.png'))
                    save_image(tensor2im(images_recon_1[trImgCntr].detach()),
                               os.path.join(train_img_dir, str(trImgCntr) + '_recon_' + labels_1[trImgCntr] + '.png'))
                    save_image(tensor2im(images_recon_2[trImgCntr].detach()),
                               os.path.join(train_img_dir, str(trImgCntr) + '_pair_' + labels_2[trImgCntr] + '.png'))
                except Exception:
                    # best effort; presumably guards against labels that are
                    # not valid file names
                    print('Warning while saving training image')

            elapsed_time = time.time() - start_time
            # for log
            with open(os.path.join(exp_path, 'log_train.txt'), 'a') as log:
                model.eval()
                ocrModel.module.Transformation.eval()
                ocrModel.module.FeatureExtraction.eval()
                ocrModel.module.AdaptiveAvgPool.eval()
                ocrModel.module.SequenceModeling.eval()
                ocrModel.module.Prediction.eval()
                disModel.eval()

                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds,
                     confidence_score, labels, infer_time,
                     length_of_data) = validation_synth_lrw_res(
                         iteration, model, ocrModel, disModel, recCriterion,
                         styleTestRecCriterion, ocrCriterion, valid_loader,
                         converter, opt)

                model.train()
                if not opt.ocrFixed:
                    ocrModel.train()
                else:
                    # restore the frozen-OCR modes set up before training
                    ocrModel.module.SequenceModeling.train()
                disModel.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train OCR loss: {loss_avg_ocr.val():0.5f}, Train Synth loss: {loss_avg.val():0.5f}, Train Dis loss: {loss_avg_dis.val():0.5f}, Valid OCR loss: {valid_loss[0]:0.5f}, Valid Synth loss: {valid_loss[1]:0.5f}, Valid Dis loss: {valid_loss[2]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'

                current_model_log_ocr = f'{"Current_accuracy_OCR":17s}: {current_accuracy[0]:0.3f}, {"Current_norm_ED_OCR":17s}: {current_norm_ED[0]:0.2f}'
                current_model_log_1 = f'{"Current_accuracy_recon":17s}: {current_accuracy[1]:0.3f}, {"Current_norm_ED_recon":17s}: {current_norm_ED[1]:0.2f}'
                current_model_log_2 = f'{"Current_accuracy_pair":17s}: {current_accuracy[2]:0.3f}, {"Current_norm_ED_pair":17s}: {current_norm_ED[2]:0.2f}'

                # plotting
                train_plots = [
                    ('Train-OCR-Loss', loss_avg_ocr),
                    ('Train-Synth-Loss', loss_avg),
                    ('Train-Dis-Loss', loss_avg_dis),
                    ('Train-OCR-Recon1-Loss', loss_avg_ocrRecon_1),
                    ('Train-OCR-Recon2-Loss', loss_avg_ocrRecon_2),
                    ('Train-Gen-Loss', loss_avg_gen),
                    ('Train-ImgRecon1-Loss', loss_avg_imgRecon),
                    ('Train-StyRecon2-Loss', loss_avg_styRecon),
                ]
                for plot_name, averager in train_plots:
                    lib.plot.plot(os.path.join(plotDir, plot_name), averager.val().item())

                # valid_loss entries follow this order
                valid_plot_names = [
                    'Valid-OCR-Loss', 'Valid-Synth-Loss', 'Valid-Dis-Loss',
                    'Valid-OCR-Recon1-Loss', 'Valid-OCR-Recon2-Loss',
                    'Valid-Gen-Loss', 'Valid-ImgRecon1-Loss',
                    'Valid-StyRecon2-Loss',
                ]
                for idx, plot_name in enumerate(valid_plot_names):
                    lib.plot.plot(os.path.join(plotDir, plot_name), valid_loss[idx].item())

                # accuracy entries: 0 = original image, 1 = recon, 2 = pair
                acc_prefixes = ['Orig', 'Recon', 'Pair']
                for idx, prefix in enumerate(acc_prefixes):
                    lib.plot.plot(os.path.join(plotDir, f'{prefix}-OCR-WordAccuracy'), current_accuracy[idx])
                for idx, prefix in enumerate(acc_prefixes):
                    lib.plot.plot(os.path.join(plotDir, f'{prefix}-OCR-CharAccuracy'), current_norm_ED[idx])

                # keep best accuracy model (on valid dataset)
                if current_accuracy[1] > best_accuracy:
                    best_accuracy = current_accuracy[1]
                    torch.save(model.state_dict(), os.path.join(exp_path, 'best_accuracy.pth'))
                    torch.save(disModel.state_dict(), os.path.join(exp_path, 'best_accuracy_dis.pth'))
                if current_norm_ED[1] > best_norm_ED:
                    best_norm_ED = current_norm_ED[1]
                    torch.save(model.state_dict(), os.path.join(exp_path, 'best_norm_ED.pth'))
                    torch.save(disModel.state_dict(), os.path.join(exp_path, 'best_norm_ED_dis.pth'))
                best_model_log = f'{"Best_accuracy_Recon":17s}: {best_accuracy:0.3f}, {"Best_norm_ED_Recon":17s}: {best_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy[0] > best_accuracy_ocr:
                    best_accuracy_ocr = current_accuracy[0]
                    if not opt.ocrFixed:
                        torch.save(ocrModel.state_dict(), os.path.join(exp_path, 'best_accuracy_ocr.pth'))
                if current_norm_ED[0] > best_norm_ED_ocr:
                    best_norm_ED_ocr = current_norm_ED[0]
                    if not opt.ocrFixed:
                        torch.save(ocrModel.state_dict(), os.path.join(exp_path, 'best_norm_ED_ocr.pth'))
                best_model_log_ocr = f'{"Best_accuracy_ocr":17s}: {best_accuracy_ocr:0.3f}, {"Best_norm_ED_ocr":17s}: {best_norm_ED_ocr:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log_ocr}\n{current_model_log_1}\n{current_model_log_2}\n{best_model_log_ocr}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":32s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for (gt_ocr, pred_ocr, confidence_ocr, gt_1, pred_1, confidence_1,
                     gt_2, pred_2, confidence_2) in zip(
                         labels[0][:5], preds[0][:5], confidence_score[0][:5],
                         labels[1][:5], preds[1][:5], confidence_score[1][:5],
                         labels[2][:5], preds[2][:5], confidence_score[2][:5]):
                    if 'Attn' in opt.Prediction:
                        pred_ocr = pred_ocr[:pred_ocr.find('[s]')]
                        pred_1 = pred_1[:pred_1.find('[s]')]
                        pred_2 = pred_2[:pred_2.find('[s]')]

                    predicted_result_log += f'{"ocr"}: {gt_ocr:27s} | {pred_ocr:25s} | {confidence_ocr:0.4f}\t{str(pred_ocr == gt_ocr)}\n'
                    predicted_result_log += f'{"recon"}: {gt_1:25s} | {pred_1:25s} | {confidence_1:0.4f}\t{str(pred_1 == gt_1)}\n'
                    predicted_result_log += f'{"pair"}: {gt_2:26s} | {pred_2:25s} | {confidence_2:0.4f}\t{str(pred_2 == gt_2)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

                for averager in (loss_avg_ocr, loss_avg, loss_avg_dis,
                                 loss_avg_ocrRecon_1, loss_avg_ocrRecon_2,
                                 loss_avg_gen, loss_avg_imgRecon,
                                 loss_avg_styRecon):
                    averager.reset()

            lib.plot.flush()

        lib.plot.tick()

        # save model per 1e+5 iter.
        if iteration % 1e+5 == 0:
            torch.save(model.state_dict(),
                       os.path.join(exp_path, 'iter_' + str(iteration + 1) + '_synth.pth'))
            if not opt.ocrFixed:
                torch.save(ocrModel.state_dict(),
                           os.path.join(exp_path, 'iter_' + str(iteration + 1) + '_ocr.pth'))
            torch.save(disModel.state_dict(),
                       os.path.join(exp_path, 'iter_' + str(iteration + 1) + '_dis.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
def train(opt):
    """Train a text recognizer built as ``Model(opt)``.

    Builds a balanced training set (``Batch_Balanced_Dataset``) and a shuffled
    validation loader, picks a label converter from ``opt.Prediction`` /
    ``opt.SequenceModeling``, optionally loads ``opt.saved_model``, and then
    trains, validating whenever ``i % opt.valInterval == 0`` (so also at
    iteration 0).

    Files written under ``./saved_models/{opt.experiment_name}/``:
    ``log_dataset.txt`` and ``opt.txt`` (UTF-16), ``log_train.txt``,
    ``best_accuracy.pth``, ``best_norm_ED.pth`` and ``iter_{n}.pth`` every
    10,000 iterations.

    Side effects: mutates ``opt`` (``select_data``, ``batch_ratio``,
    ``num_class`` and, when ``opt.rgb``, ``input_channel``) and terminates the
    process with ``sys.exit()`` once the iteration counter reaches
    ``opt.num_iter``.

    Raises:
        ValueError: if no label converter matches ``opt.Prediction`` /
            ``opt.SequenceModeling``.
    """

    def _strip_eos(s):
        """Return ``s`` cut at its first '[s]' token; unchanged if there is none."""
        eos = s.find('[s]')
        return s if eos == -1 else s[:eos]

    # ---------------- dataset preparation ----------------
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager closes the dataset log even if building the
    # validation set raises.
    with open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a', encoding='utf-16') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---------------- model configuration ----------------
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    elif 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    elif ('Transformer' in opt.Prediction or 'Test' in opt.Prediction
          or 'Transformer' in opt.SequenceModeling):
        converter = TransformerLabelConverter(opt.character)
    else:
        # Fail early with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f'no label converter matches Prediction={opt.Prediction!r}, '
            f'SequenceModeling={opt.SequenceModeling!r}')
    # Selects both the loss/forward path and whether decoded strings are cut
    # at '[s]' for display; it matches the converter branch chosen above.
    is_ctc = 'CTC' in opt.Prediction
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization: biases -> 0, weights -> Kaiming normal; a weight
    # whose initialization raises (batchnorm, per the original note) is set to 1.
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        # Fine-tuning (opt.FT) tolerates missing/unexpected keys.
        model.load_state_dict(torch.load(opt.saved_model), strict=not opt.FT)
    print("Model:")
    print(model)

    # ---------------- setup loss ----------------
    if is_ctc:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    loss_avg = Averager()

    # only parameters that require gradients are optimized
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    print('Trainable params num : ', sum(np.prod(p.size()) for p in filtered_parameters))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---------------- final options ----------------
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a', encoding='utf-16') as opt_file:
        opt_log = '------------ Options -------------\n'
        opt_log += ''.join(f'{str(k)}: {str(v)}\n' for k, v in vars(opt).items())
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---------------- start training ----------------
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass  # checkpoint name carries no iteration number: start from 0

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if is_ctc:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)
            # (ctc_a) For PyTorch 1.2.0 and 1.3.0, cudnn is disabled while
            # computing ctc_loss to avoid a ctc_loss issue
            # (https://github.com/jpuigcerver/PyLaia/issues/16). try/finally
            # re-enables cudnn even if the loss raises.
            # (ctc_b) The pretrained model / paper used
            # criterion(preds, text, preds_size, length) on PyTorch 1.1.0, see
            # https://github.com/clovaai/deep-text-recognition-benchmark/issues/56#issuecomment-526490707
            torch.backends.cudnn.enabled = False
            try:
                cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
            finally:
                torch.backends.cudnn.enabled = True
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds, confidence_score,
                     labels, infer_time, length_of_data) = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results (printed only, not written to the log)
                dashed_line = '-' * 80
                head = f'{"Ground Truth":10s} | {"Prediction":10s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if not is_ctc:
                        gt = _strip_eos(gt)
                        pred = _strip_eos(pred)
                    predicted_result_log += f'{gt:10s} | {pred:10s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)

        # save model every 1e+4 iterations
        if (i + 1) % 1e+4 == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        # '>=' (not '==') also terminates a run resumed past opt.num_iter,
        # which previously looped forever.
        if i >= opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
def train(opt):
    """Jointly train a text-image generator, an OCR recognizer and a discriminator.

    Three networks are built from ``opt``:

    * ``model``: ``AdaINGen`` (generator)
    * ``ocrModel``: ``Model`` (recognizer)
    * ``disModel``: ``MsImageDis`` (discriminator)

    ``opt.batch_size`` is doubled so that each loaded batch can be split in
    half. The first half feeds the generator/recognizer; the second half is
    used as real images for the discriminator.

    Each iteration:

    1. updates the discriminator on detached generated images;
    2. updates the generator with the weighted sum of OCR, reconstruction,
       adversarial and style-reconstruction losses;
    3. updates the recognizer on the input images.

    Validation runs every ``opt.valInterval`` iterations and at iteration 0.
    Logs, sample training images and checkpoints go under
    ``opt.exp_dir/opt.exp_name``.

    Side effects: mutates ``opt`` (``select_data``, ``batch_ratio``,
    ``batch_size``, ``num_class`` and, when ``opt.rgb``, ``input_channel``)
    and terminates the process with ``sys.exit()`` after the iteration
    numbered ``opt.num_iter``.
    """
    exp_path = os.path.join(opt.exp_dir, opt.exp_name)

    def _init_weights(net):
        """Zero biases, Kaiming-init weights; set to 1 any weight whose init raises."""
        for name, param in net.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception:  # for batchnorm.
                if 'weight' in name:
                    param.data.fill_(1)
                continue

    def _load_pretrained(net, path, desc, title):
        """Load ``path`` into ``net`` when non-empty (non-strict if ``opt.FT``), then print ``net``."""
        if path != '':
            print(f'loading pretrained {desc} model from {path}')
            net.load_state_dict(torch.load(path), strict=not opt.FT)
        print(title)
        print(net)

    def _build_optimizer(net, count_label, title):
        """Return the Adam/Adadelta optimizer (per ``opt.adam``) over ``net``'s trainable parameters."""
        params = [p for p in net.parameters() if p.requires_grad]
        print(count_label, sum(np.prod(p.size()) for p in params))
        if opt.adam:
            net_optimizer = optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
        else:
            net_optimizer = optim.Adadelta(params, lr=opt.lr, rho=opt.rho, eps=opt.eps)
        print(title)
        print(net_optimizer)
        return net_optimizer

    def _strip_eos(s):
        """Return ``s`` cut at its first '[s]' token; unchanged if there is none."""
        eos = s.find('[s]')
        return s if eos == -1 else s[:eos]

    # ---------------- dataset preparation ----------------
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    # considering the real images for discriminator: each batch is split in half below
    opt.batch_size = opt.batch_size * 2
    train_dataset = Batch_Balanced_Dataset(opt)

    with open(os.path.join(exp_path, 'log_dataset.txt'), 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---------------- model configuration ----------------
    is_ctc = 'CTC' in opt.Prediction
    if is_ctc:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = AdaINGen(opt)
    ocrModel = Model(opt)
    disModel = MsImageDis(opt)

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # generator, recognizer and discriminator weight initialization (same order as before)
    for net in (model, ocrModel, disModel):
        _init_weights(net)

    # data parallel for multi-GPU
    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    ocrModel.train()
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()

    # loading pre-trained models
    _load_pretrained(ocrModel, opt.saved_ocr_model, 'ocr', "OCRModel:")
    _load_pretrained(model, opt.saved_synth_model, 'synth', "SynthModel:")
    _load_pretrained(disModel, opt.saved_dis_model, 'discriminator', "DisModel:")

    # ---------------- setup loss ----------------
    if is_ctc:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    recCriterion = torch.nn.L1Loss()
    styleRecCriterion = torch.nn.L1Loss()

    # loss averagers
    loss_avg_ocr = Averager()
    loss_avg = Averager()
    loss_avg_dis = Averager()

    # one optimizer per network, over parameters that require gradients
    optimizer = _build_optimizer(model, 'Trainable params num : ', "SynthOptimizer:")
    ocr_optimizer = _build_optimizer(ocrModel, 'OCR Trainable params num : ', "OCROptimizer:")
    dis_optimizer = _build_optimizer(disModel, 'Dis Trainable params num : ', "DisOptimizer:")

    def _ocr_loss(images, text, length, batch_size):
        """Recognize ``images`` with ``ocrModel`` and return ``ocrCriterion`` against ``text``."""
        if is_ctc:
            preds = ocrModel(images, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.log_softmax(2).permute(1, 0, 2)
            return ocrCriterion(preds, text, preds_size, length)
        # Attention path: the same decoder-input / target shift used by the
        # single-model trainer (text[:, :-1] in, text[:, 1:] without [GO] as
        # target). NOTE(review): confirm against ocrModel.forward.
        preds = ocrModel(images, text[:, :-1])
        target = text[:, 1:]
        return ocrCriterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

    # ---------------- final options ----------------
    with open(os.path.join(exp_path, 'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        opt_log += ''.join(f'{str(k)}: {str(v)}\n' for k, v in vars(opt).items())
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---------------- start training ----------------
    start_iter = 0
    if opt.saved_synth_model != '':
        try:
            start_iter = int(opt.saved_synth_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass  # checkpoint name carries no iteration number: start from 0

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    best_accuracy_ocr = -1
    best_norm_ED_ocr = -1
    iteration = start_iter

    while True:
        # train part: first half -> generator/recognizer input, second half -> real images
        image_tensors_all, labels_1_all, labels_2_all = train_dataset.get_batch()
        disCnt = image_tensors_all.size(0) // 2
        image_tensors = image_tensors_all[:disCnt]
        image_tensors_real = image_tensors_all[disCnt:disCnt + disCnt]
        labels_1 = labels_1_all[:disCnt]
        labels_2 = labels_2_all[:disCnt]

        image = image_tensors.to(device)
        image_real = image_tensors_real.to(device)
        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        images_recon_1, images_recon_2, style = model(image, text_1, text_2)

        # OCR training loss on the input images
        ocrCost_train = _ocr_loss(image, text_1, length_1, batch_size)

        # discriminator loss on detached generated images
        # (recon_1 is paired with the separate real half, recon_2 with the input)
        disCost = opt.disWeight * 0.5 * (
            disModel.module.calc_dis_loss(images_recon_1.detach(), image_real)
            + disModel.module.calc_dis_loss(images_recon_2.detach(), image))

        # generator losses: recognizability, reconstruction and style reconstruction
        ocrCost = 0.5 * (_ocr_loss(images_recon_1, text_1, length_1, batch_size)
                         + _ocr_loss(images_recon_2, text_2, length_2, batch_size))
        recCost = recCriterion(images_recon_1, image)
        styleRecCost = styleRecCriterion(
            model(images_recon_2, None, None, styleFlag=True), style.detach())

        # discriminator update
        disModel.zero_grad()
        disCost.backward()
        torch.nn.utils.clip_grad_norm_(disModel.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        dis_optimizer.step()
        loss_avg_dis.add(disCost)

        # The adversarial term is built AFTER dis_optimizer.step(): the step
        # modifies the discriminator weights in place, so a graph recorded
        # before it would fail autograd's version check on backward.
        disGenCost = 0.5 * (disModel.module.calc_gen_loss(images_recon_1)
                            + disModel.module.calc_gen_loss(images_recon_2))
        cost = (opt.ocrWeight * ocrCost + opt.reconWeight * recCost
                + opt.disWeight * disGenCost + opt.styleReconWeight * styleRecCost)

        # generator update (OCR/dis gradients produced here are cleared before their own steps)
        model.zero_grad()
        ocrModel.zero_grad()
        disModel.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()
        loss_avg.add(cost)

        # recognizer update
        ocrModel.zero_grad()
        ocrCost_train.backward()
        torch.nn.utils.clip_grad_norm_(ocrModel.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        ocr_optimizer.step()
        loss_avg_ocr.add(ocrCost_train)

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:  # To see training progress, we also conduct validation when 'iteration == 0'
            # save the current inputs and generated images (best effort: failures only warn)
            img_dir = os.path.join(exp_path, 'trainImages', str(iteration))
            os.makedirs(img_dir, exist_ok=True)
            for trImgCntr in range(batch_size):
                try:
                    save_image(
                        tensor2im(image[trImgCntr].detach()),
                        os.path.join(img_dir, str(trImgCntr) + '_input_' + labels_1[trImgCntr] + '.png'))
                    save_image(
                        tensor2im(images_recon_1[trImgCntr].detach()),
                        os.path.join(img_dir, str(trImgCntr) + '_recon_' + labels_1[trImgCntr] + '.png'))
                    save_image(
                        tensor2im(images_recon_2[trImgCntr].detach()),
                        os.path.join(img_dir, str(trImgCntr) + '_pair_' + labels_2[trImgCntr] + '.png'))
                except Exception:
                    print('Warning while saving training image')

            elapsed_time = time.time() - start_time
            # for log
            with open(os.path.join(exp_path, 'log_train.txt'), 'a') as log:
                model.eval()
                ocrModel.eval()
                disModel.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds, confidence_score,
                     labels, infer_time, length_of_data) = validation_synth_lrw(
                        iteration, model, ocrModel, disModel, recCriterion, styleRecCriterion,
                        ocrCriterion, valid_loader, converter, opt)
                model.train()
                ocrModel.train()
                disModel.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train OCR loss: {loss_avg_ocr.val():0.5f}, Train Synth loss: {loss_avg.val():0.5f}, Train Dis loss: {loss_avg_dis.val():0.5f}, Valid OCR loss: {valid_loss[0]:0.5f}, Valid Synth loss: {valid_loss[1]:0.5f}, Valid Dis loss: {valid_loss[2]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg_ocr.reset()
                loss_avg.reset()
                loss_avg_dis.reset()

                current_model_log_ocr = f'{"Current_accuracy_OCR":17s}: {current_accuracy[0]:0.3f}, {"Current_norm_ED_OCR":17s}: {current_norm_ED[0]:0.2f}'
                current_model_log_1 = f'{"Current_accuracy_recon":17s}: {current_accuracy[1]:0.3f}, {"Current_norm_ED_recon":17s}: {current_norm_ED[1]:0.2f}'
                current_model_log_2 = f'{"Current_accuracy_pair":17s}: {current_accuracy[2]:0.3f}, {"Current_norm_ED_pair":17s}: {current_norm_ED[2]:0.2f}'

                # keep best generator/discriminator (by recon metrics on valid dataset)
                if current_accuracy[1] > best_accuracy:
                    best_accuracy = current_accuracy[1]
                    torch.save(model.state_dict(), os.path.join(exp_path, 'best_accuracy.pth'))
                    torch.save(disModel.state_dict(), os.path.join(exp_path, 'best_accuracy_dis.pth'))
                if current_norm_ED[1] > best_norm_ED:
                    best_norm_ED = current_norm_ED[1]
                    torch.save(model.state_dict(), os.path.join(exp_path, 'best_norm_ED.pth'))
                    torch.save(disModel.state_dict(), os.path.join(exp_path, 'best_norm_ED_dis.pth'))
                best_model_log = f'{"Best_accuracy_Recon":17s}: {best_accuracy:0.3f}, {"Best_norm_ED_Recon":17s}: {best_norm_ED:0.2f}'

                # keep best recognizer (by OCR metrics on valid dataset)
                if current_accuracy[0] > best_accuracy_ocr:
                    best_accuracy_ocr = current_accuracy[0]
                    torch.save(ocrModel.state_dict(), os.path.join(exp_path, 'best_accuracy_ocr.pth'))
                if current_norm_ED[0] > best_norm_ED_ocr:
                    best_norm_ED_ocr = current_norm_ED[0]
                    torch.save(ocrModel.state_dict(), os.path.join(exp_path, 'best_norm_ED_ocr.pth'))
                best_model_log_ocr = f'{"Best_accuracy_ocr":17s}: {best_accuracy_ocr:0.3f}, {"Best_norm_ED_ocr":17s}: {best_norm_ED_ocr:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log_ocr}\n{current_model_log_1}\n{current_model_log_2}\n{best_model_log_ocr}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":32s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for (gt_ocr, pred_ocr, confidence_ocr, gt_1, pred_1, confidence_1,
                     gt_2, pred_2, confidence_2) in zip(
                        labels[0][:5], preds[0][:5], confidence_score[0][:5],
                        labels[1][:5], preds[1][:5], confidence_score[1][:5],
                        labels[2][:5], preds[2][:5], confidence_score[2][:5]):
                    if not is_ctc:
                        # the original referenced undefined 'gt'/'pred' here (NameError)
                        gt_ocr, pred_ocr = _strip_eos(gt_ocr), _strip_eos(pred_ocr)
                        gt_1, pred_1 = _strip_eos(gt_1), _strip_eos(pred_1)
                        gt_2, pred_2 = _strip_eos(gt_2), _strip_eos(pred_2)
                    predicted_result_log += f'{"ocr"}: {gt_ocr:27s} | {pred_ocr:25s} | {confidence_ocr:0.4f}\t{str(pred_ocr == gt_ocr)}\n'
                    predicted_result_log += f'{"recon"}: {gt_1:25s} | {pred_1:25s} | {confidence_1:0.4f}\t{str(pred_1 == gt_1)}\n'
                    predicted_result_log += f'{"pair"}: {gt_2:26s} | {pred_2:25s} | {confidence_2:0.4f}\t{str(pred_2 == gt_2)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter. (f-strings: the original saved every
        # checkpoint to a literal 'iter_{iteration+1}.pth', overwriting it)
        if (iteration + 1) % 1e+5 == 0:
            torch.save(model.state_dict(), os.path.join(exp_path, f'iter_{iteration+1}.pth'))
            torch.save(ocrModel.state_dict(), os.path.join(exp_path, f'iter_{iteration+1}_ocr.pth'))
            torch.save(disModel.state_dict(), os.path.join(exp_path, f'iter_{iteration+1}_dis.pth'))

        # '>=' also terminates a run resumed past opt.num_iter
        if (iteration + 1) >= opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
def train(opt):
    """Train the SCATTER-style recognizer (CTC refiner + attention decoder outputs).

    ``Model(opt)`` returns ``(preds, refiner)``. ``refiner`` is scored with
    CTC loss, weighted by ``opt.lambda_ctc``. Every element of ``preds`` is
    scored with cross-entropy against the attention targets, weighted by
    ``opt.lambda_attn``.

    ``opt.grad_clip`` is tightened to 2, then 1, once the running training
    loss falls to 0.6 and 0.3 respectively. Training resumes at the
    iteration number parsed from ``opt.saved_model`` when its file name ends
    in ``_<int>.<ext>``.

    Outputs:
    * TensorBoard scalars go to ``opt.log``.
    * Logs and checkpoints go to ``./saved_models/{opt.exp_name}/``:
      ``best_accuracy_{acc}.pth``, ``best_norm_ED.pth``, and
      ``SCATTER_STR.pth`` every 1000 iterations.

    Side effects: mutates ``opt`` (``select_data``, ``batch_ratio``,
    ``num_class``, ``grad_clip`` and, when ``opt.rgb``, ``input_channel``)
    and calls ``sys.exit()`` after the last iteration.
    """

    def _strip_eos(s):
        """Return ``s`` cut at its first '[s]' token; unchanged if there is none."""
        eos = s.find('[s]')
        return s if eos == -1 else s[:eos]

    os.makedirs(opt.log, exist_ok=True)
    writer = SummaryWriter(opt.log)

    # ---------------- dataset preparation ----------------
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---------------- model configuration ----------------
    # CTC targets for the refiner, attention targets for the decoder outputs.
    ctc_converter = CTCLabelConverter(opt.character)
    attn_converter = AttnLabelConverter(opt.character)
    opt.num_class = len(attn_converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel,
          opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length)

    # weight initialization: biases -> 0, weights -> Kaiming normal; a weight
    # whose initialization raises (batchnorm, per the original note) is set to 1.
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        # Fine-tuning (opt.FT) tolerates missing/unexpected keys.
        model.load_state_dict(torch.load(opt.saved_model), strict=not opt.FT)

    # ---------------- setup loss ----------------
    loss_avg = Averager()
    ctc_loss = torch.nn.CTCLoss(zero_infinity=True).to(device)
    attn_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    # only parameters that require gradients are optimized
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    print('Trainable params num : ', sum(np.prod(p.size()) for p in filtered_parameters))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)  # the header was previously printed without the optimizer

    # ---------------- final options ----------------
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        opt_log += ''.join(f'{str(k)}: {str(v)}\n' for k, v in vars(opt).items())
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---------------- start training ----------------
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass  # checkpoint name carries no iteration number: start from 0

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1

    # Start at start_iter: the original always iterated from 0, silently
    # ignoring the resume point it had just parsed.
    pbar = tqdm(range(start_iter, opt.num_iter), initial=start_iter, total=opt.num_iter)
    for iteration in pbar:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        ctc_text, ctc_length = ctc_converter.encode(labels, batch_max_length=opt.batch_max_length)
        attn_text, attn_length = attn_converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        preds, refiner = model(image, attn_text[:, :-1])
        refiner_size = torch.IntTensor([refiner.size(1)] * batch_size)
        refiner = refiner.log_softmax(2).permute(1, 0, 2)
        refiner_loss = ctc_loss(refiner, ctc_text, refiner_size, ctc_length)

        total_loss = opt.lambda_ctc * refiner_loss
        target = attn_text[:, 1:]  # without [GO] Symbol
        for pred in preds:
            total_loss += opt.lambda_attn * attn_loss(
                pred.view(-1, pred.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()
        loss_avg.add(total_loss)

        # tighten gradient clipping as the running loss drops
        if loss_avg.val() <= 0.6:
            opt.grad_clip = 2
        if loss_avg.val() <= 0.3:
            opt.grad_clip = 1

        # Drop GPU references before emptying the cache. The original built an
        # unused generator over preds and made needless device->host copies.
        del preds, refiner, image
        torch.cuda.empty_cache()

        writer.add_scalar('train_loss', loss_avg.val(), iteration)
        pbar.set_description('Iteration {0}/{1}, AvgLoss {2}'.format(
            iteration, opt.num_iter, loss_avg.val()))

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds, confidence_score,
                     labels, infer_time, length_of_data) = validation(
                        model, attn_loss, valid_loader, attn_converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                # pass the step: without it every validation point was logged at step 0
                writer.add_scalar('Val_loss', valid_loss, iteration)
                pbar.set_description(loss_log)
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy_{str(best_accuracy)}.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    # Validation always decodes with attn_converter, so strings
                    # are always cut at '[s]'. The original condition
                    # `'Attn' or 'Transformer' in opt.Prediction` was always true.
                    gt = _strip_eos(gt)
                    pred = _strip_eos(pred)
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                log.write(predicted_result_log + '\n')

        # save model per 1e+3 iter.
        if (iteration + 1) % 1e+3 == 0:
            torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/SCATTER_STR.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            writer.close()  # flush pending TensorBoard events before exiting
            sys.exit()

    # reached only when there was nothing left to train (start_iter >= num_iter)
    writer.close()
def train(opt):
    """Train a text recognizer on LMDB datasets with periodic validation.

    Builds LMDB-backed train/validation loaders, a CTC or attention label
    converter (chosen from ``opt.Prediction``), the model, loss and optimizer,
    then iterates over the training loader indefinitely.  Every
    ``opt.valInterval`` iterations it validates, logs to
    ``./saved_models/<experiment_name>/log_train.txt`` and saves
    ``best_accuracy.pth`` / ``best_norm_ED.pth``; every 100000 iterations it
    saves ``iter_<n>.pth``.  The process ends via ``sys.exit()`` once the
    iteration counter reaches ``opt.num_iter``.

    Args:
        opt: options namespace.  Mutated: ``opt.num_class`` is set from the
            converter and ``opt.input_channel`` is forced to 3 when ``opt.rgb``.
    """
    # ---- prepare the training and validation datasets ----
    transform = transforms.Compose([
        ToTensor(),
    ])

    train_dataset = LmdbDataset(opt.train_data, opt=opt, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
    )

    valid_dataset = LmdbDataset(root=opt.valid_data, opt=opt, transform=transform)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
    )
    print('-' * 80)

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # Weight initialization: zero biases, Kaiming-normal weights.  Weights for
    # which kaiming_normal_ raises (batchnorm, per the original note) are set to 1.
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    model = model.to(device)
    model.train()
    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        model.load_state_dict(torch.load(opt.continue_model))
    print("Model:")
    print(model)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        # ignore [GO] token = ignore index 0
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    # loss averager
    loss_avg = Averager()

    # Only parameters that require gradients are handed to the optimizer.
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.continue_model != '':
        # Periodic checkpoints below are named 'iter_<n>.pth'.  Other names
        # (e.g. 'best_accuracy.pth') carry no iteration number: start from 0.
        try:
            start_iter = int(opt.continue_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            print(f'no iteration number in {opt.continue_model}, start_iter: {start_iter}')

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    i = start_iter

    while True:
        # train part
        for image_tensors, labels in train_loader:
            image = image_tensors.to(device)
            # text: label indices; length: per-sample label lengths
            text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
            batch_size = image.size(0)

            if 'CTC' in opt.Prediction:
                preds = model(image, text).log_softmax(2)
                preds_size = torch.IntTensor([preds.size(1)] * batch_size).to(device)
                # Swap the first two axes into the (seq_len, batch, num_class)
                # layout CTCLoss expects.
                preds = preds.permute(1, 0, 2)
                # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
                # https://github.com/jpuigcerver/PyLaia/issues/16
                # try/finally: cudnn is re-enabled even if the loss raises.
                torch.backends.cudnn.enabled = False
                try:
                    cost = criterion(preds, text, preds_size, length)
                finally:
                    torch.backends.cudnn.enabled = True
            else:
                preds = model(image, text[:, :-1])  # align with Attention.forward
                target = text[:, 1:]  # without [GO] Symbol
                cost = criterion(preds.view(-1, preds.shape[-1]),
                                 target.contiguous().view(-1))

            model.zero_grad()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
            optimizer.step()

            loss_avg.add(cost)

            # validation part
            if i % opt.valInterval == 0:
                elapsed_time = time.time() - start_time
                print(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}')
                # for log
                with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                    log.write(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n')
                    loss_avg.reset()

                    model.eval()
                    with torch.no_grad():
                        valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                            model, criterion, valid_loader, converter, opt)
                    model.train()

                    for pred, gt in zip(preds[:5], labels[:5]):
                        if 'Attn' in opt.Prediction:
                            pred = pred[:pred.find('[s]')]
                            gt = gt[:gt.find('[s]')]
                        print(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}')
                        log.write(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}\n')

                    valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                    valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                    print(valid_log)
                    log.write(valid_log + '\n')

                    # keep best accuracy model
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(model.state_dict(),
                                   f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                    if current_norm_ED < best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(model.state_dict(),
                                   f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                    best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                    print(best_model_log)
                    log.write(best_model_log + '\n')

            # save model per 1e+5 iter.
            if (i + 1) % 1e+5 == 0:
                torch.save(model.state_dict(),
                           f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

            # '>=' so that resuming past num_iter still terminates.
            if i >= opt.num_iter:
                print('end the training')
                sys.exit()
            i += 1
def train(opt, AMP, WdB, ralph_path, train_data_path, train_data_list,
          test_data_path, test_data_list, experiment_name, train_batch_size,
          val_batch_size, workers, lr, valInterval, num_iter, wdbprj,
          continue_model='', finetune=''):
    """Train OrigamiNet with CTC loss, validating its EMA weights periodically.

    Loads the label mapping ("ralph") from the JSON file ``ralph_path`` and
    builds the train/validation datasets with ``ds_load.myLoadDS``.  Each
    loader uses a ``DistributedSampler`` when Horovod or DDP is active.  Then
    it constructs the model, an Adam optimizer, an ``ExponentialLR`` schedule
    and a ``ModelEma`` copy, and optionally restores the ``finetune`` and
    ``continue_model`` checkpoints.

    Training runs in rounds of ``valInterval`` batches.  After every round
    the EMA model is validated.  On the ``OnceExecWorker`` process the result
    is logged to ``./saved_models/<experiment_name>/log_train.txt`` and a
    checkpoint (model, EMA, optimizer) is written to ``best_norm_ED.pth``
    when the normalized edit distance improves.  Training stops once
    ``num_iter > 0`` and the iteration counter reaches ``num_iter``.

    ``AMP``, ``WdB`` and ``wdbprj`` are kept for interface compatibility.
    apex mixed precision and wandb logging are disabled (their calls are
    commented out), so the loss is always back-propagated unscaled.
    """
    # True when gradients are synchronized across processes (Horovod or DDP).
    HVD3P = pO.HVD or pO.DDP

    os.makedirs(f'./saved_models/{experiment_name}', exist_ok=True)

    # if OnceExecWorker and WdB:
    #     wandb.init(project=wdbprj, name=experiment_name)
    #     wandb.config.update(opt)

    # load supplied ralph
    with open(ralph_path, 'r') as f:
        ralph_train = json.load(f)

    print('[4] IN TRAIN; BEFORE MAKING DATASET')
    train_dataset = ds_load.myLoadDS(train_data_list, train_data_path, ralph=ralph_train)
    valid_dataset = ds_load.myLoadDS(test_data_list, test_data_path, ralph=ralph_train)
    print('[5] DATASET DONE LOADING')

    if OnceExecWorker:
        print(pO)
        print('Alphabet :', len(train_dataset.alph), train_dataset.alph)
        for d in [train_dataset, valid_dataset]:
            print('Dataset Size :', len(d.fns))
            print('-' * 80)

    if opt.num_gpu > 1:
        workers = workers * (1 if HVD3P else opt.num_gpu)

    train_sampler = valid_sampler = None
    if HVD3P:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=opt.world_size, rank=opt.rank)
        valid_sampler = torch.utils.data.distributed.DistributedSampler(
            valid_dataset, num_replicas=opt.world_size, rank=opt.rank)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=not HVD3P,  # a DistributedSampler, when present, does the shuffling
        pin_memory=True,
        num_workers=int(workers),
        sampler=train_sampler,
        worker_init_fn=WrkSeeder,
        collate_fn=ds_load.SameTrCollate)

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=val_batch_size,
        pin_memory=True,
        num_workers=int(workers),
        sampler=valid_sampler)

    model = OrigamiNet()
    model.apply(init_bn)

    # load finetune ckpt
    if finetune != '':
        model = load_finetune(model, finetune)

    model.train()

    if OnceExecWorker:
        for k in sorted(model.lreszs.keys()):
            print(k, model.lreszs[k])

    if not pO.DDP:
        model = model.to(device)
    else:
        model.cuda(opt.rank)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    # gamma ** 90000 == 0.1: the LR drops 10x every 90000 scheduler steps.
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=10**(-1 / 90000))

    if pO.HVD:
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        optimizer = hvd.DistributedOptimizer(
            optimizer, named_parameters=model.named_parameters())

    # Non-zero DDP ranks reseed from fresh entropy so their random streams differ.
    if pO.DDP and opt.rank != 0:
        random.seed()
        np.random.seed()

    if pO.DP:
        model = torch.nn.DataParallel(model)
    elif pO.DDP:
        model = pDDP(model, device_ids=[opt.rank], output_device=opt.rank,
                     find_unused_parameters=False)

    model_ema = ModelEma(model)

    if continue_model != '':
        if OnceExecWorker:
            print(f'loading pretrained model from {continue_model}')
        checkpoint = torch.load(continue_model,
                                map_location=f'cuda:{opt.rank}' if HVD3P else None)
        model.load_state_dict(checkpoint['model'], strict=True)
        optimizer.load_state_dict(checkpoint['optimizer'])
        model_ema._load_checkpoint(continue_model, f'cuda:{opt.rank}' if HVD3P else None)

    criterion = torch.nn.CTCLoss(reduction='none', zero_infinity=True).to(device)
    converter = CTCLabelConverter(train_dataset.ralph.values())

    if OnceExecWorker:
        with open(f'./saved_models/{experiment_name}/opt.txt', 'a') as opt_file:
            opt_log = '------------ Options -------------\n'
            args = vars(opt)
            for k, v in args.items():
                opt_log += f'{str(k)}: {str(v)}\n'
            opt_log += '---------------------------------------\n'
            opt_log += gin.operative_config_str()
            opt_file.write(opt_log)
            # if WdB:
            #     wandb.config.gin_str = gin.operative_config_str().splitlines()

        print(optimizer)
        print(opt_log)

    best_accuracy = -1
    best_norm_ED = 1e+6
    best_CER = 1e+6
    i = 0
    gAcc = 1  # gradient accumulation: the optimizer steps every gAcc batches
    epoch = 1

    if HVD3P:
        train_sampler.set_epoch(epoch)
    titer = iter(train_loader)

    while True:
        start_time = time.time()

        model.zero_grad()
        train_loss = Metric(pO, 'train_loss')

        for _ in trange(valInterval, leave=False, desc='Training'):
            # Load a batch; start a new epoch when the loader is exhausted.
            try:
                image_tensors, labels, _fnames = next(titer)
            except StopIteration:
                epoch += 1
                if HVD3P:
                    train_sampler.set_epoch(epoch)
                titer = iter(train_loader)
                image_tensors, labels, _fnames = next(titer)

            # Move to device
            image = image_tensors.to(device)
            text, length = converter.encode(labels)
            batch_size = image.size(0)

            # Forward pass; preds are permuted to (seq, batch, class) log-probs for CTCLoss.
            preds = model(image, text).float()
            preds_size = torch.IntTensor([preds.size(1)] * batch_size).to(device)
            preds = preds.permute(1, 0, 2).log_softmax(2)

            if i == 0 and OnceExecWorker:
                print('Model inp : ', image.dtype, image.size())
                print('CTC inp : ', preds.dtype, preds.size(), preds_size[0])

            # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss.
            # try/finally: cudnn is re-enabled even if the loss raises.
            torch.backends.cudnn.enabled = False
            try:
                cost = criterion(preds, text.to(device), preds_size,
                                 length.to(device)).mean() / gAcc
            finally:
                torch.backends.cudnn.enabled = True

            train_loss.update(cost)

            # Backward pass.  The loss is back-propagated unscaled regardless of
            # AMP, because apex loss scaling is not active in this version.
            optimizer.zero_grad()
            cost.backward()

            if (i + 1) % gAcc == 0:
                # Under Horovod, DistributedOptimizer.step() synchronizes gradients itself.
                optimizer.step()
                model.zero_grad()
                model_ema.update(model, num_updates=i / 2)

                if (i + 1) % (gAcc * 2) == 0:
                    lr_scheduler.step()

            i += 1

        # validation part (after every valInterval training batches)
        elapsed_time = time.time() - start_time
        start_time = time.time()

        model.eval()
        with torch.no_grad():
            valid_loss, current_accuracy, current_norm_ED, ted, bleu, preds, labels, infer_time = validation(
                model_ema.ema, criterion, valid_loader, converter, opt, pO)
        model.train()
        v_time = time.time() - start_time

        if OnceExecWorker:
            if current_norm_ED < best_norm_ED:
                best_norm_ED = current_norm_ED
                checkpoint = {
                    'model': model.state_dict(),
                    'state_dict_ema': model_ema.ema.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(checkpoint, f'./saved_models/{experiment_name}/best_norm_ED.pth')

            if ted < best_CER:
                best_CER = ted

            if current_accuracy > best_accuracy:
                best_accuracy = current_accuracy

            # get_last_lr() reports the LR actually in use; get_lr() is deprecated
            # and, outside step(), returns lr * gamma for ExponentialLR.
            current_lr = lr_scheduler.get_last_lr()[0]
            out = f'[{i}] Loss: {train_loss.avg:0.5f} time: ({elapsed_time:0.1f},{v_time:0.1f})'
            out += f' vloss: {valid_loss:0.3f}'
            out += f' CER: {ted:0.4f} NER: {current_norm_ED:0.4f} lr: {current_lr:0.5f}'
            out += f' bAcc: {best_accuracy:0.1f}, bNER: {best_norm_ED:0.4f}, bCER: {best_CER:0.4f}, B: {bleu*100:0.2f}'
            print(out)

            with open(f'./saved_models/{experiment_name}/log_train.txt', 'a') as log:
                log.write(out + '\n')

            # if WdB:
            #     wandb.log({'lr': current_lr, 'It': i, 'nED': current_norm_ED, 'B': bleu*100,
            #                'tloss': train_loss.avg, 'AnED': best_norm_ED, 'CER': ted,
            #                'bestCER': best_CER, 'vloss': valid_loss})

        if DEBUG:
            print(f'[!!!] Iteration check. Value of i: {i} | Value of num_iter: {num_iter}')

        # num_iter <= 0 means "train until interrupted".
        if num_iter > 0 and i >= num_iter:
            print('end the training')
            break
def train(opt):
    """Train a text recognizer on the Baidu CSV data or balanced LMDB data.

    When ``opt.select_data == 'baidu'`` the train/validation sets come from
    ``opt.train_csv`` / ``opt.val_csv`` via ``BAIDUset``.  Otherwise
    ``opt.select_data`` / ``opt.batch_ratio`` are split on '-' (mutating
    ``opt``) and a ``Batch_Balanced_Dataset`` plus a hierarchical validation
    set are used.

    The converter and loss are chosen from ``opt.Prediction`` (CTC, Bert,
    SRN, otherwise attention).  Training prints progress every
    ``opt.disInterval`` iterations and validates every ``opt.valInterval``
    iterations (after the first), saving the best-accuracy and
    best-norm-ED weights.  It saves ``iter_<n>.pth`` every
    ``opt.saveInterval`` iterations and steps a MultiStepLR schedule
    (milestones 5/20/30, gamma 0.5) once per pass over the training data.
    It ends via ``sys.exit()`` at ``opt.num_iter``.

    Args:
        opt: options namespace.  Mutated: ``num_class``, possibly
            ``input_channel``, ``select_data`` and ``batch_ratio``.
    """
    import math  # local: used only for the steps-per-epoch computation

    use_baidu = opt.select_data == 'baidu'

    # ---- dataset preparation ----
    if use_baidu:
        train_set = BAIDUset(opt, opt.train_csv)
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=int(opt.workers),
            collate_fn=BaiduCollate(opt.imgH, opt.imgW, keep_ratio=False))
        val_set = BAIDUset(opt, opt.val_csv)
        valid_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=int(opt.workers),
            collate_fn=BaiduCollate(opt.imgH, opt.imgW, keep_ratio=False),
            pin_memory=True)
    else:
        opt.select_data = opt.select_data.split('-')
        opt.batch_ratio = opt.batch_ratio.split('-')
        train_dataset = Batch_Balanced_Dataset(opt)

        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
    print('-' * 80)

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    elif 'Bert' in opt.Prediction:
        converter = TransformerConverter(opt.character, opt.max_seq)
    elif 'SRN' in opt.Prediction:
        converter = SRNConverter(opt.character, opt.SRN_PAD)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).cuda()
    model.train()
    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        model.load_state_dict(torch.load(opt.continue_model))
    print("Model:")
    print(model)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
    elif 'Bert' in opt.Prediction:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda()
    elif 'SRN' in opt.Prediction:
        criterion = cal_performance
    else:
        # ignore [GO] token = ignore index 0
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda()

    # loss averager
    loss_avg = Averager()

    # Only parameters that require gradients are handed to the optimizer.
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    elif opt.ranger:
        optimizer = Ranger(filtered_parameters, lr=opt.lr)
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)
    # Halve the learning rate at epochs 5, 20 and 30 (stepped once per epoch below).
    lrScheduler = lr_scheduler.MultiStepLR(optimizer, [5, 20, 30], gamma=0.5)

    # ---- final options ----
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.continue_model != '':
        # Periodic checkpoints are named 'iter_<n>.pth'; names without an
        # iteration number (e.g. 'best_accuracy.pth') resume from 0.
        try:
            start_iter = int(opt.continue_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            print(f'no iteration number in {opt.continue_model}, start_iter: {start_iter}')

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    i = start_iter

    if use_baidu:
        train_iter = iter(train_loader)
        num_samples = len(train_set)
    else:
        num_samples = train_dataset.nums_samples
    # Integer number of batches per pass over the data, so that
    # `i % step_per_epoch == 0` really fires once per epoch.
    step_per_epoch = max(1, int(math.ceil(num_samples / opt.batch_size)))
    print('一代有多少step:', step_per_epoch)

    while True:
        # train part
        for p in model.parameters():
            p.requires_grad = True

        if use_baidu:
            try:
                image_tensors, labels = next(train_iter)
            except StopIteration:
                train_iter = iter(train_loader)
                image_tensors, labels = next(train_iter)
        else:
            image_tensors, labels = train_dataset.get_batch()

        image = image_tensors.cuda()
        text, length = converter.encode(labels)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format
            cost = criterion(preds, text, preds_size, length)
        elif 'Bert' in opt.Prediction:
            pad_mask = None
            preds = model(image, pad_mask)
            flat_text = text.contiguous().view(-1)
            cost = (criterion(preds[0].view(-1, preds[0].shape[-1]), flat_text)
                    + criterion(preds[1].view(-1, preds[1].shape[-1]), flat_text))
        elif 'SRN' in opt.Prediction:
            preds = model(image, None)
            cost, n_correct = criterion(preds, text)
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        if i % opt.disInterval == 0:
            elapsed_time = time.time() - start_time
            print(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}')
            start_time = time.time()

        # validation part
        if i % opt.valInterval == 0 and i > start_iter:
            elapsed_time = time.time() - start_time
            print(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}')
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                log.write(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n')
                loss_avg.reset()

                model.eval()
                with torch.no_grad():  # inference only: no autograd graph needed
                    valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                for pred, gt in zip(preds[:5], labels[:5]):
                    if 'Attn' in opt.Prediction:
                        pred = pred[:pred.find('[s]')]
                        gt = gt[:gt.find('[s]')]
                    print(f'pred: {pred:20s}, gt: {gt:20s}, {str(pred == gt)}')
                    log.write(f'pred: {pred:20s}, gt: {gt:20s}, {str(pred == gt)}\n')

                valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                print(valid_log)
                log.write(valid_log + '\n')

                # keep best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(),
                               f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(),
                               f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                print(best_model_log)
                log.write(best_model_log + '\n')

        # periodic checkpoint
        if (i + 1) % opt.saveInterval == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        # '>=' so that resuming past num_iter still terminates.
        if i >= opt.num_iter:
            print('end the training')
            sys.exit()

        # decay the learning rate once per epoch
        if i > 0 and i % step_per_epoch == 0:
            lrScheduler.step()
        i += 1
def train(opt):
    """Jointly train the CTC and attention heads of a text recognizer.

    Uses a ``Batch_Balanced_Dataset`` for training (``opt.select_data`` and
    ``opt.batch_ratio`` are split on '-', mutating ``opt``) and a hierarchical
    validation set.  The model returns CTC predictions and a sequence of
    attention predictions.  The loss is ``0.1 * CTC + sum of per-head
    cross-entropy``.

    Losses and validation accuracy go to a TensorBoard ``SummaryWriter``.
    Validation runs at iteration 0 and every ``opt.valInterval`` iterations,
    saving ``best_accuracy.pth`` / ``best_norm_ED.pth``.  ``iter_<n>.pth`` is
    saved every 100000 iterations.  The process exits via ``sys.exit()`` on a
    NaN/Inf loss or at ``opt.num_iter``.  The writer is closed first so its
    buffered events are flushed to disk.

    Args:
        opt: options namespace.  Mutated: ``select_data``, ``batch_ratio``,
            ``num_class_ctc``, ``num_class_atten`` and possibly ``input_channel``.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    converter_ctc = CTCLabelConverter(opt.character)  # CTCLoss
    converter_atten = AttnLabelConverter(opt.character)  # Attention
    opt.num_class_ctc = len(converter_ctc.character)
    opt.num_class_atten = len(converter_atten.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class_ctc, opt.num_class_atten, opt.batch_max_length,
          opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling,
          opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # Only parameters that require gradients are handed to the optimizer.
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # use fp16 to train
    model = model.to(device)
    if opt.fp16:
        with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
            log.write('==> Enable fp16 training' + '\n')
        print('==> Enable fp16 training')
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # data parallel for multi-GPU
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).to(device)
    model.train()

    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        # When fine-tuning, tolerate missing/unexpected keys.
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # ---- setup loss ----
    criterion_ctc = torch.nn.CTCLoss(zero_infinity=True).to(device)
    # ignore [GO] token = ignore index 0
    criterion_atten = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    # loss averager
    loss_avg = Averager()

    # ---- final options ----
    writer = SummaryWriter(f'./saved_models/{opt.exp_name}')
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        # 'iter_<n>.pth' resumes at n; other names (e.g. 'best_accuracy.pth') start at 0.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        batch_size = image.size(0)

        text_ctc, length_ctc = converter_ctc.encode(labels, batch_max_length=opt.batch_max_length)
        text_atten, length_atten = converter_atten.encode(labels, batch_max_length=opt.batch_max_length)

        # text_atten[:, :-1]: align with Attention.forward
        preds_ctc, preds_atten = model(image, text_atten[:, :-1])

        # CTC loss (weighted 0.1); preds are permuted to (seq, batch, class) for CTCLoss.
        preds_size = torch.IntTensor([preds_ctc.size(1)] * batch_size)
        preds_ctc = preds_ctc.log_softmax(2).permute(1, 0, 2)
        cost_ctc = 0.1 * criterion_ctc(preds_ctc, text_ctc, preds_size, length_ctc)

        # Attention loss: sum of the cross-entropy of every attention output.
        target = text_atten[:, 1:]  # without [GO] Symbol
        flat_target = target.contiguous().view(-1)
        cost_atten = sum(
            1.0 * criterion_atten(pred.view(-1, pred.shape[-1]), flat_target)
            for pred in preds_atten)

        cost = cost_ctc + cost_atten
        writer.add_scalar('loss', cost.item(), global_step=iteration + 1)

        progress = '\riter: {:4d}\tloss: {:6.3f}\tavg: {:6.3f}'.format(
            iteration + 1, cost.item(), loss_avg.val())
        print(progress, end='\n' if (iteration + 1) % 100 == 0 else '')
        sys.stdout.flush()

        if cost < 0.001:
            print(f'iter: {iteration + 1}\tloss: {cost}')

        optimizer.zero_grad()
        if torch.isnan(cost):
            print(f'iter: {iteration + 1}\tloss: {cost}\t==> Loss is NAN')
            writer.close()  # flush pending TensorBoard events before exiting
            sys.exit()
        elif torch.isinf(cost):
            print(f'iter: {iteration + 1}\tloss: {cost}\t==> Loss is INF')
            writer.close()  # flush pending TensorBoard events before exiting
            sys.exit()
        else:
            if opt.fp16:
                with amp.scale_loss(cost, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                cost.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
            optimizer.step()

        loss_avg.add(cost)
        writer.add_scalar('loss_avg', loss_avg.val(), global_step=iteration + 1)

        # validation part; also at iteration 0 to see training progress
        if iteration == 0 or (iteration + 1) % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion_atten, valid_loader, converter_atten, opt)
                model.train()
                writer.add_scalar('accuracy', current_accuracy, global_step=iteration + 1)

                # training loss and validation loss
                loss_log = f'[{iteration + 1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(),
                               f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(),
                               f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    gt = gt[:gt.find('[s]')]
                    pred = pred[:pred.find('[s]')]
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.exp_name}/iter_{iteration + 1}.pth')

        # '>=' so that resuming at or past num_iter still terminates.
        if (iteration + 1) >= opt.num_iter:
            print('end the training')
            writer.close()  # flush pending TensorBoard events before exiting
            sys.exit()
        iteration += 1
def _initialize_weights(model):
    """Initialize ``model``'s parameters in place before training.

    Biases are set to 0 and weights get Kaiming-normal initialization.
    Parameters whose name contains ``localization_fc2`` are skipped because
    they are already initialized (as the printed message states). Weights
    for which Kaiming init raises are filled with 1 instead. Kaiming init
    needs at least 2 dimensions, so this covers e.g. batch-norm scales.
    """
    for name, param in model.named_parameters():
        if "localization_fc2" in name:
            print(f"Skip {name} as it is already initialized")
            continue
        try:
            if "bias" in name:
                init.constant_(param, 0.0)
            elif "weight" in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm: 1-D weights have no fan-in/fan-out.
            if "weight" in name:
                param.data.fill_(1)


def _format_sample_predictions(labels, preds, confidence_score,
                               prediction_type, num_samples=5):
    """Build the "Ground Truth | Prediction | Confidence" table shown after validation.

    Args:
        labels: ground-truth strings returned by ``validation``.
        preds: predicted strings returned by ``validation``.
        confidence_score: per-sample confidences (formatted with ``:0.4f``).
        prediction_type: ``opt.Prediction``. For attention decoders ("Attn"),
            both strings are cut at the ``[s]`` end-of-sequence token.
        num_samples: how many leading samples to show.

    Returns:
        The table as a single string without a trailing newline.
    """
    dashed_line = "-" * 80
    head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
    rows = [dashed_line, head, dashed_line]
    for gt, pred, confidence in zip(labels[:num_samples], preds[:num_samples],
                                    confidence_score[:num_samples]):
        if "Attn" in prediction_type:
            # partition() keeps strings without "[s]" intact. The previous
            # s[:s.find("[s]")] dropped the last character when find() returned -1.
            gt = gt.partition("[s]")[0]
            pred = pred.partition("[s]")[0]
        rows.append(f"{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}")
    rows.append(dashed_line)
    return "\n".join(rows)


def train(opt):
    """Training pipeline for the character recognition model.

    Builds the balanced training set and the validation loader from ``opt``.
    Creates and initializes the model, optionally loading ``opt.saved_model``,
    then trains until ``opt.num_iter`` iterations have run.

    Side effects:
        * Mutates ``opt``:
            - ``select_data`` and ``batch_ratio`` are split on "-";
            - ``num_class`` is set;
            - ``input_channel`` is set to 3 when ``opt.rgb``.
        * Appends to ``log_dataset.txt``, ``opt.txt`` and ``log_train.txt``
          under ``./saved_models/{opt.exp_name}/``.
        * Saves ``best_accuracy.pth`` and ``best_norm_ED.pth`` there, plus
          ``iter_<N>.pth`` every 100000 iterations.

    Raises:
        SystemExit: via ``sys.exit()`` once the final iteration completes.
    """
    if not opt.data_filtering_off:
        print("Filtering the images containing characters which are not in opt.character")
        print("Filtering the images whose label is longer than opt.batch_max_length")

    # select_data / batch_ratio arrive as "-"-separated strings; turn them into lists.
    opt.select_data = opt.select_data.split("-")
    opt.batch_ratio = opt.batch_ratio.split("-")
    train_dataset = Batch_Balanced_Dataset(opt)

    # Collate function for the validation dataloader, configured from user input.
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True,
    )
    # Log the dataset summary so later runs can be compared against this one.
    with open(f"./saved_models/{opt.exp_name}/log_dataset.txt", "a") as log:
        log.write(valid_dataset_log)
        print("-" * 80)
        log.write("-" * 80 + "\n")

    # CTC or attention decoding selects the label converter (and later the loss).
    if "CTC" in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    # Run the OCR model on RGB instead of grayscale input when requested.
    if opt.rgb:
        opt.input_channel = 3

    model = Model(opt)
    print(
        "model input parameters",
        opt.imgH,
        opt.imgW,
        opt.num_fiducial,
        opt.input_channel,
        opt.output_channel,
        opt.hidden_size,
        opt.num_class,
        opt.batch_max_length,
        opt.Transformation,
        opt.FeatureExtraction,
        opt.SequenceModeling,
        opt.Prediction,
    )

    _initialize_weights(model)
    model.train()

    # Optionally fine-tune from a model saved by a previous run.
    if opt.saved_model != "":
        print(f"loading pretrained model from {opt.saved_model}")
        # The model is still on the CPU here. Mapping to CPU lets a GPU-saved
        # checkpoint load on CPU-only machines; model.to(device) below moves it.
        state_dict = torch.load(opt.saved_model, map_location="cpu")
        # With opt.FT, missing/unexpected keys are tolerated (strict=False).
        model.load_state_dict(state_dict, strict=not opt.FT)

    # Send the model to the CPU or GPU, depending on availability.
    model.to(device)

    if "CTC" in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        # ignore [GO] token = ignore index 0
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    loss_avg = Averager()

    # Only parameters that require gradients are handed to the optimizer.
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    print("Trainable params num : ", sum(p.numel() for p in filtered_parameters))

    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # Record every option of this run next to its checkpoints.
    opt_log = "------------ Options -------------\n"
    opt_log += "".join(f"{k}: {v}\n" for k, v in vars(opt).items())
    opt_log += "---------------------------------------\n"
    print(opt_log)
    with open(f"./saved_models/{opt.exp_name}/opt.txt", "a") as opt_file:
        opt_file.write(opt_log)

    # Resume the iteration counter from checkpoints named "iter_<N>.pth" (saved below).
    start_iter = 0
    if opt.saved_model != "":
        try:
            start_iter = int(opt.saved_model.split("_")[-1].split(".")[0])
            print(f"continue to train, start_iter: {start_iter}")
        except ValueError:
            pass  # e.g. best_accuracy.pth carries no iteration number.

    checkpoint_interval = 100_000  # save an iter_<N>.pth snapshot every 1e5 iterations
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if "CTC" in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            # CTCLoss takes (T, N, C) log-probabilities, hence log_softmax + permute.
            preds = preds.log_softmax(2).permute(1, 0, 2)
            cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        # gradient clipping with 5 (Default)
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        optimizer.step()
        loss_avg.add(cost)

        # validation part; also run at iteration 0 to see training progress early.
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            with open(f"./saved_models/{opt.exp_name}/log_train.txt", "a") as log:
                model.eval()
                with torch.no_grad():
                    (
                        valid_loss,
                        current_accuracy,
                        current_norm_ED,
                        valid_preds,
                        confidence_score,
                        valid_labels,
                        _,  # infer_time
                        _,  # length_of_data
                    ) = validation(model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = (
                    f"[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, "
                    f"Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}"
                )
                loss_avg.reset()
                current_model_log = (
                    f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, '
                    f'{"Current_norm_ED":17s}: {current_norm_ED:0.2f}'
                )

                # keep the best models (on the validation dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(),
                               f"./saved_models/{opt.exp_name}/best_accuracy.pth")
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(),
                               f"./saved_models/{opt.exp_name}/best_norm_ED.pth")
                best_model_log = (
                    f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, '
                    f'{"Best_norm_ED":17s}: {best_norm_ED:0.2f}'
                )

                loss_model_log = f"{loss_log}\n{current_model_log}\n{best_model_log}"
                print(loss_model_log)
                log.write(loss_model_log + "\n")

                # show some predicted results
                predicted_result_log = _format_sample_predictions(
                    valid_labels, valid_preds, confidence_score, opt.Prediction)
                print(predicted_result_log)
                log.write(predicted_result_log + "\n")

        if (iteration + 1) % checkpoint_interval == 0:
            torch.save(model.state_dict(),
                       f"./saved_models/{opt.exp_name}/iter_{iteration+1}.pth")

        # >= rather than == so a run resumed at or past num_iter still terminates.
        if iteration + 1 >= opt.num_iter:
            print("end the training")
            sys.exit()
        iteration += 1