Example #1
# Assumed module-level imports and globals, inferred from how they are used in this
# example: Model, utils, and evaluate are project modules, while device, converter,
# data, data_loader, and valid_loader are expected to be defined elsewhere in the file.
import os
import time

import numpy as np
import torch
import torch.nn.init as init
import torch.optim as optim

import Model
import utils
import evaluate


def train(opt):
    model = Model.Basemodel(opt, device)
    
    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
                
        except Exception:  # e.g. 1-D weights (BatchNorm) where kaiming init cannot compute fan-in
            if 'weight' in name:
                param.data.fill_(1)
            continue
            
    # load pretrained model
    if opt.saved_model != '':
        base_path = './models'
        print(f'looking for pretrained model from {os.path.join(base_path, opt.saved_model)}')
        
        try:
            model.load_state_dict(torch.load(os.path.join(base_path, opt.saved_model)))
            print('loading complete ')    
        except Exception as e:
            print(e)
            print('could not load the pretrained model')
            
    # wrap in DataParallel (restricted to GPU 0 here) and move to the device
    model = torch.nn.DataParallel(model, device_ids=[0]).to(device)
    model.train() 
    
    # collect only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params : ', sum(params_num))

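    # running averages of the recognition loss and the glyph-generation loss, used for logging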
    loss_avg = utils.Averager()
    loss_avg_glyph = utils.Averager()
    
    # optimizer
    
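    # Adadelta is the optimizer actually used here; the commented-out lines below are
    # alternative optimizers (Adam, SWA, AdamW, AdaBound) apparently left from experiments.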
    optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
#     optimizer = torch.optim.Adam(filtered_parameters, lr=0.0001)
#     optimizer = SWA(base_opt)
#     optimizer = torch.optim.AdamW(filtered_parameters)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', verbose=True, patience=2, factor=0.5)  # note: never stepped in this example
#     optimizer = adabound.AdaBound(filtered_parameters, lr=1e-3, final_lr=0.1)

    # opt log
    with open(f'./models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '---------------------Options-----------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)} : {str(v)}\n'
        opt_log += '---------------------------------------------\n'
        opt_file.write(opt_log)
        
    #start training
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    swa_count = 0  # only used by the commented-out SWA block further below
    
    for n_epoch in range(opt.num_epoch):
        for n_iter, data_point in enumerate(data_loader):
            
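            # unpack one batch: an image tensor and its ground-truth label strings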
            image, labels = data_point 
            image = image.to(device)
            try:
                target, length = converter.encode(labels, batch_max_length=opt.max_length)
                batch_size = image.size(0)
            except Exception as e:
                print(f'{e}')
                continue

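            # forward pass: recognition logits, generated glyph images, and glyph embedding ids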
            logits, glyphs, embedding_ids = model(image, (target, length), is_train=True)
            
            recognition_loss = model.module.decoder.recognition_loss(logits.view(-1, opt.num_classes+2), target.view(-1))
            generation_loss = model.module.generator.glyph_loss(glyphs, target, length, embedding_ids, opt)
            
            cost = recognition_loss + generation_loss

            loss_avg.add(recognition_loss)
            loss_avg_glyph.add(generation_loss)
            
            model.zero_grad()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping at opt.grad_clip
            optimizer.step()
            
            #validation
            if n_iter % opt.val_interval == 0 and n_iter != 0:
                elapsed_time = time.time() - start_time
                with open(f'./models/{opt.experiment_name}/log_train.txt', 'a') as log:
                    model.eval()
                    with torch.no_grad():
                        (valid_loss_recog, valid_loss_glyph, current_accuracy, current_norm_ED, preds,
                         confidence_score, labels, infer_time, length_of_data) = evaluate.validation_efifstr(
                            model, valid_loader, converter, opt)
                    model.train()

                    present_time = time.localtime()
                    loss_log = (f'[epoch : {n_epoch}/{opt.num_epoch}] [iter : {n_iter*opt.batch_size} / {int(len(data) * 0.998)}]\n'
                                f'Train recognition loss : {loss_avg.val():0.5f}, Glyph loss : {loss_avg_glyph.val():0.5f}\n'
                                f'Valid recognition loss : {valid_loss_recog:0.5f}, Glyph loss : {valid_loss_glyph:0.5f}, '
                                f'Elapsed time : {elapsed_time:0.5f}, Present time : {present_time[1]}/{present_time[2]}, '
                                f'{present_time[3]+9} : {present_time[4]}')
                    loss_avg.reset()
                    loss_avg_glyph.reset()

                    current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"current_norm_ED":17s}: {current_norm_ED:0.2f}'

                    #keep the best
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(model.module.state_dict(), f'./models/{opt.experiment_name}/best_accuracy_{round(current_accuracy,2)}.pth')

                    if current_norm_ED > best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(model.module.state_dict(), f'./models/{opt.experiment_name}/best_norm_ED.pth')

                    best_model_log = f'{"Best accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'
                    loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                    print(loss_model_log)
                    log.write(loss_model_log+'\n')

                    dashed_line = '-'*80
                    head = f'{"Ground Truth":25s} | {"Prediction" :25s}| Confidence Score & T/F'
                    predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'

                    random_idx = np.random.choice(range(len(labels)), size=5, replace=False)
                    for gt, pred, confidence in zip(list(np.asarray(labels)[random_idx]), list(np.asarray(preds)[random_idx]), list(np.asarray(confidence_score)[random_idx])):
                        gt = gt[: gt.find('[s]')]
                        pred = pred[: pred.find('[s]')]

                        predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                    predicted_result_log += f'{dashed_line}'
                    print(predicted_result_log)
                    log.write(predicted_result_log+'\n')

#                 # Stochastic weight averaging
#                 optimizer.update_swa()
#                 swa_count+=1
#                 if swa_count % 3 ==0:
#                     optimizer.swap_swa_sgd()
#                     torch.save(model.module.state_dict(), f'./models/{opt.experiment_name}/swa_{swa_count}.pth')

        if n_epoch % 5 == 0:
            torch.save(model.module.state_dict(), f'./models/{opt.experiment_name}/{n_epoch}.pth')
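
A minimal sketch of how this train() function might be invoked, assuming opt is an argparse Namespace. The option names below are inferred from the attributes read inside train(); the default values are placeholders, and module-level objects such as device, converter, data, data_loader, and valid_loader still have to be created elsewhere.

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # hypothetical options, inferred from the attributes train() reads off `opt`
    parser.add_argument('--experiment_name', default='demo')
    parser.add_argument('--saved_model', default='', help='file name under ./models to resume from')
    parser.add_argument('--lr', type=float, default=1.0)        # placeholder value
    parser.add_argument('--rho', type=float, default=0.95)      # placeholder value
    parser.add_argument('--eps', type=float, default=1e-8)      # placeholder value
    parser.add_argument('--num_epoch', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--max_length', type=int, default=25)
    parser.add_argument('--num_classes', type=int, default=38)  # placeholder value
    parser.add_argument('--grad_clip', type=float, default=5.0)
    parser.add_argument('--val_interval', type=int, default=500)
    opt = parser.parse_args()

    # make sure the experiment directory exists before opt.txt / checkpoints are written
    os.makedirs(f'./models/{opt.experiment_name}', exist_ok=True)
    train(opt)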