import os
import time

import numpy as np
import torch
import torch.nn.init as init
import torch.optim as optim

# project-local modules (names inferred from how they are used in train() below)
import Model
import utils
import evaluate

# NOTE: train() also relies on module-level objects assumed to be set up elsewhere
# in this script: device, data_loader, valid_loader, converter, and data (used for
# the iteration-count denominator in the training log).


def train(opt):
    model = Model.Basemodel(opt, device)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:
            # kaiming init fails for 1-D weights (e.g. batch norm); fill them with 1
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # load pretrained model
    if opt.saved_model != '':
        base_path = './models'
        print(f'looking for pretrained model at {os.path.join(base_path, opt.saved_model)}')
        try:
            model.load_state_dict(torch.load(os.path.join(base_path, opt.saved_model)))
            print('loading complete')
        except Exception as e:
            print(e)
            print('could not find model')

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model, device_ids=[0]).to(device)
    model.train()

    # keep only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params :', sum(params_num))

    loss_avg = utils.Averager()
    loss_avg_glyph = utils.Averager()

    # optimizer
    optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    # optimizer = torch.optim.Adam(filtered_parameters, lr=0.0001)
    # optimizer = SWA(base_opt)
    # optimizer = torch.optim.AdamW(filtered_parameters)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', verbose=True, patience=2, factor=0.5)
    # optimizer = adabound.AdaBound(filtered_parameters, lr=1e-3, final_lr=0.1)

    # opt log
    with open(f'./models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '---------------------Options-----------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)} : {str(v)}\n'
        opt_log += '---------------------------------------------\n'
        opt_file.write(opt_log)

    # start training
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    swa_count = 0

    for n_epoch, epoch in enumerate(range(opt.num_epoch)):
        for n_iter, data_point in enumerate(data_loader):
            image, labels = data_point
            image = image.to(device)
            try:
                target, length = converter.encode(labels, batch_max_length=opt.max_length)
                batch_size = image.size(0)
            except Exception as e:
                print(f'{e}')
                continue

            logits, glyphs, embedding_ids = model(image, (target, length), is_train=True)
            recognition_loss = model.module.decoder.recognition_loss(logits.view(-1, opt.num_classes + 2), target.view(-1))
            generation_loss = model.module.generator.glyph_loss(glyphs, target, length, embedding_ids, opt)
            cost = recognition_loss + generation_loss

            loss_avg.add(recognition_loss)
            loss_avg_glyph.add(generation_loss)

            model.zero_grad()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with opt.grad_clip (default 5)
            optimizer.step()

            # validation
            if (n_iter % opt.val_interval == 0) and (n_iter != 0):
                elapsed_time = time.time() - start_time
                with open(f'./models/{opt.experiment_name}/log_train.txt', 'a') as log:
                    model.eval()
                    with torch.no_grad():
                        (valid_loss_recog, valid_loss_glyph, current_accuracy, current_norm_ED,
                         preds, confidence_score, labels, infer_time, length_of_data) = evaluate.validation_efifstr(
                            model, valid_loader, converter, opt)
                    model.train()

                    present_time = time.localtime()  # the hour is offset by +9 in the log line below
                    loss_log = (
                        f'[epoch : {n_epoch}/{opt.num_epoch}] [iter : {n_iter * opt.batch_size} / {int(len(data) * 0.998)}]\n'
                        f'Train recognition loss : {loss_avg.val():0.5f}, Glyph loss : {loss_avg_glyph.val():0.5f}\n'
                        f'Valid recognition loss : {valid_loss_recog:0.5f}, Glyph loss : {valid_loss_glyph:0.5f}, '
                        f'Elapsed time : {elapsed_time:0.5f}, '
                        f'Present time : {present_time[1]}/{present_time[2]}, {present_time[3] + 9} : {present_time[4]}'
                    )
                    loss_avg.reset()
                    loss_avg_glyph.reset()

                    current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                    # keep the best checkpoints
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(model.module.state_dict(), f'./models/{opt.experiment_name}/best_accuracy_{round(current_accuracy, 2)}.pth')
                    if current_norm_ED > best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(model.module.state_dict(), f'./models/{opt.experiment_name}/best_norm_ED.pth')
                    best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                    loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                    print(loss_model_log)
                    log.write(loss_model_log + '\n')

                    # show a few random validation predictions
                    dashed_line = '-' * 80
                    head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                    predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                    random_idx = np.random.choice(range(len(labels)), size=5, replace=False)
                    for gt, pred, confidence in zip(list(np.asarray(labels)[random_idx]),
                                                    list(np.asarray(preds)[random_idx]),
                                                    list(np.asarray(confidence_score)[random_idx])):
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]
                        predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                    predicted_result_log += f'{dashed_line}'
                    print(predicted_result_log)
                    log.write(predicted_result_log + '\n')

            # # Stochastic weight averaging
            # optimizer.update_swa()
            # swa_count += 1
            # if swa_count % 3 == 0:
            #     optimizer.swap_swa_sgd()
            #     torch.save(model.module.state_dict(), f'./models/{opt.experiment_name}/swa_{swa_count}.pth')

        if n_epoch % 5 == 0:
            torch.save(model.module.state_dict(), f'./models/{opt.experiment_name}/{n_epoch}.pth')
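

if __name__ == '__main__':
    # Hypothetical entry point -- an illustrative sketch, not part of the original
    # training script. It shows how the `opt` namespace consumed by train() could
    # be built with argparse; the flag names mirror the opt.* attributes referenced
    # above, and every default except grad_clip (5, per the clipping comment in
    # train()) is an assumption. The module-level objects train() depends on
    # (device, data_loader, valid_loader, converter, data) still have to be
    # constructed elsewhere in this script before train(opt) is called.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_name', required=True, help='run name; logs and checkpoints go to ./models/<experiment_name>/')
    parser.add_argument('--saved_model', default='', help='checkpoint filename under ./models to warm-start from (empty = train from scratch)')
    parser.add_argument('--num_epoch', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=192)
    parser.add_argument('--val_interval', type=int, default=500, help='run validation every N iterations')
    parser.add_argument('--num_classes', type=int, required=True, help='character-set size (excluding the 2 extra tokens added in the recognition loss)')
    parser.add_argument('--max_length', type=int, default=25, help='batch_max_length passed to converter.encode')
    parser.add_argument('--lr', type=float, default=1.0, help='Adadelta learning rate')
    parser.add_argument('--rho', type=float, default=0.95, help='Adadelta rho')
    parser.add_argument('--eps', type=float, default=1e-8, help='Adadelta eps')
    parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping norm')
    opt = parser.parse_args()

    os.makedirs(f'./models/{opt.experiment_name}', exist_ok=True)
    train(opt)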