Example no. 1
0
        results['Avg. PSNR1'].append(float('%.2f'%(avg_psnr1 / len(val_data_loader))))
        results['Avg. PSNR2'].append(float('%.2f'%(avg_psnr2 / len(val_data_loader))))


def checkpoint(epoch):
    """Persist the current training state for *epoch* to disk.

    Serializes the module-level ``model`` and ``optimizer`` state dicts
    together with the epoch index and current learning rate ``lr`` (all
    read from module globals), then reports the output path.
    """
    out_path = f"LapSRN_model_epoch_g_{epoch}.pth"
    snapshot = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'lr': lr,
    }
    # Legacy (non-zipfile) serialization so older torch versions can load it.
    torch.save(snapshot, out_path, _use_new_zipfile_serialization=False)
    print(f"Checkpoint saved to {out_path}")


# Resume training if a pretrained checkpoint file exists.
if os.path.exists(opt.pre_model):
    model = LapSRN().cuda()
    checkpoints = torch.load(opt.pre_model)
    # Restore weights and switch to training mode before continuing.
    model.load_state_dict(checkpoints['model'])
    model.train()
    epoch_continue = checkpoints['epoch']
    
    optimizer = optim.Adagrad(model.parameters())


    # Continue from the epoch after the one stored in the checkpoint.
    for epoch in range(epoch_continue + 1, opt.nEpochs + 1):

        if epoch == epoch_continue + 1:
            # First resumed epoch: restore optimizer state and saved lr.
            optimizer.load_state_dict(checkpoints['optimizer'])
            lr = checkpoints['lr']

        else:
            # Later epochs: rebuild the optimizer with the current lr.
            # NOTE(review): this discards accumulated Adagrad state every
            # epoch after the first, and adds weight_decay only here —
            # confirm this is intentional.
            optimizer = optim.Adagrad(model.parameters(), lr = lr, weight_decay = 1e-5)
            
            
Example no. 2
0
    # Save the "r" network: weights, optimizer state, epoch index and lr.
    # (presumably one of two sub-networks, r/g — TODO confirm naming)
    model_out_r_path = "LapSRN_model_epoch_r_{}.pth".format(epoch)
    state_r = {'model': model_r.state_dict(), 'optimizer': optimizer_r.state_dict(), 'epoch':epoch, 'lr':lr_r}
    # Legacy (non-zipfile) serialization so older torch versions can load it.
    torch.save(state_r, model_out_r_path, _use_new_zipfile_serialization=False)
    
    # Save the "g" network the same way, to its own file.
    model_out_g_path = "LapSRN_model_epoch_g_{}.pth".format(epoch)
    state_g = {'model': model_g.state_dict(), 'optimizer': optimizer_g.state_dict(), 'epoch':epoch, 'lr':lr_g}
    torch.save(state_g, model_out_g_path, _use_new_zipfile_serialization=False)
    
    print("Checkpoint saved to {} and {}".format(model_out_r_path, model_out_g_path))


# Resume training of both networks from their saved checkpoints.
# NOTE(review): only pre_model_r existence is checked, but pre_model_g is
# also loaded below — confirm both files always exist together.
if os.path.exists(opt.pre_model_r):
    model_r = LapSRN().to(device)
    checkpoints_r = torch.load(opt.pre_model_r)
    model_r.load_state_dict(checkpoints_r['model'])
    model_r.train()
    epoch_continue_r = checkpoints_r['epoch']
    optimizer_r = optim.Adagrad(model_r.parameters())


    model_g = LapSRN().to(device)
    checkpoints_g = torch.load(opt.pre_model_g)
    model_g.load_state_dict(checkpoints_g['model'])
    model_g.train()
    epoch_continue_g = checkpoints_g['epoch']
    optimizer_g = optim.Adagrad(model_g.parameters())

    # Continue from the epoch after the "g" checkpoint's epoch.
    # NOTE(review): epoch_continue_r is unused here — the loop assumes both
    # checkpoints were saved at the same epoch; verify that assumption.
    for epoch in range(epoch_continue_g + 1, opt.nEpochs + 1):

        if epoch == epoch_continue_g + 1:
            # First resumed epoch: restore the "r" optimizer state.
            optimizer_r.load_state_dict(checkpoints_r['optimizer'])