Пример #1
0
        print("===> Avg. PSNR2: {:.4f} dB".format(avg_psnr2 / len(val_data_loader)))
        results['Avg. PSNR1'].append(float('%.2f'%(avg_psnr1 / len(val_data_loader))))
        results['Avg. PSNR2'].append(float('%.2f'%(avg_psnr2 / len(val_data_loader))))


def checkpoint(epoch):
    """Persist the current model/optimizer state for ``epoch`` to disk.

    Writes a dict with keys ``'model'``, ``'optimizer'``, ``'epoch'`` and
    ``'lr'`` to ``LapSRN_model_epoch_g_<epoch>.pth`` in the current working
    directory, using the legacy (non-zipfile) ``torch.save`` format, then
    prints the destination path.

    Args:
        epoch: Epoch number stored in the checkpoint and embedded in the
            filename.

    Note:
        Reads the module-level names ``model``, ``optimizer`` and ``lr``;
        they must be bound before this is called.
    """
    target = f"LapSRN_model_epoch_g_{epoch}.pth"
    payload = dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        epoch=epoch,
        lr=lr,
    )
    # Legacy format kept, presumably for loading with older PyTorch versions.
    torch.save(payload, target, _use_new_zipfile_serialization=False)
    print(f"Checkpoint saved to {target}")


# Resume from the checkpoint at opt.pre_model, if that file exists.
# NOTE(review): this block appears to splice two different scripts together:
# a resume-training loop driven by `opt`, then a test/inference section
# driven by `args`. `opt`, `args`, `device`, `LapSRN`, `is_image_file` and
# `Image` are all defined outside this view. Confirm against the full file.
if os.path.exists(opt.pre_model):
    # Rebuild the network on the GPU and restore its saved weights.
    model = LapSRN().cuda()
    checkpoints = torch.load(opt.pre_model)
    model.load_state_dict(checkpoints['model'])
    model.train()
    epoch_continue = checkpoints['epoch']
    
    # Optimizer with default settings. If the loop below runs at least once,
    # its state is overwritten from the checkpoint on the first resumed epoch.
    optimizer = optim.Adagrad(model.parameters())


    # Continue with the epoch after the one recorded in the checkpoint.
    for epoch in range(epoch_continue + 1, opt.nEpochs + 1):

        if epoch == epoch_continue + 1:
            # First resumed epoch: restore optimizer state and learning rate.
            optimizer.load_state_dict(checkpoints['optimizer'])
            lr = checkpoints['lr']

        else:
            # NOTE(review): a brand-new Adagrad is built on every later epoch.
            # This discards the optimizer state restored above. Also, no
            # training step is visible in this loop, so the snippet may be
            # truncated.
            optimizer = optim.Adagrad(model.parameters(), lr = lr, weight_decay = 1e-5)
            
    # Collect the paths of the test images in args.img_dir, sorted by path.
    filenames = os.listdir(args.img_dir)
    image_filenames = [os.path.join(args.img_dir, x) for x in filenames \
                       if is_image_file(x)]
    image_filenames = sorted(image_filenames)

    # NOTE(review): rebinds `model`, replacing the network restored above.
    # The network is configured for 1 input channel, yet test images below
    # are converted to 'RGB'. Presumably a single channel is extracted later
    # (not visible here), so verify this.
    model = LapSRN(img_channels=1,
                   upscale_factor=args.upscale_factor,
                   n_feat=10,
                   n_recursive=1,
                   local_residual='ns').to(device)
    # Load weights from args.model. Storages are mapped to the CPU when CUDA
    # is not requested.
    if args.cuda:
        ckpt = torch.load(args.model)
    else:
        ckpt = torch.load(args.model, map_location='cpu')
    model.load_state_dict(ckpt['model'])

    # Result container. Not used within the visible part of this snippet.
    res = {}

    for i, f in enumerate(image_filenames):
        # Read test image.
        img = Image.open(f).convert('RGB')
        width, height = img.size[0], img.size[1]

        # Crop the test image so width and height are exact multiples of the
        # upscale factor. Despite their names, `pad_*` hold the remainders
        # that are cut off, not padding. The crop keeps the top-left region.
        pad_width = width % args.upscale_factor
        pad_height = height % args.upscale_factor
        width -= pad_width
        height -= pad_height
        img = img.crop((0, 0, width, height))
Пример #3
0
def checkpoint(epoch):
    """Persist both LapSRN models (``r`` and ``g``) and their optimizers.

    For each tag, a dict with keys ``'model'``, ``'optimizer'``, ``'epoch'``
    and ``'lr'`` is written to ``LapSRN_model_epoch_<tag>_<epoch>.pth`` in the
    current working directory, using the legacy (non-zipfile) ``torch.save``
    format. The ``r`` file is written before the ``g`` file. Both paths are
    printed once at the end.

    Args:
        epoch: Epoch number stored in each checkpoint and embedded in the
            filenames.

    Note:
        Reads the module-level names ``model_r``, ``optimizer_r``, ``lr_r``,
        ``model_g``, ``optimizer_g`` and ``lr_g``.
    """

    def _dump(tag, net, solver, rate):
        # Save one model/optimizer pair and return the path written.
        path = f"LapSRN_model_epoch_{tag}_{epoch}.pth"
        torch.save(
            {
                'model': net.state_dict(),
                'optimizer': solver.state_dict(),
                'epoch': epoch,
                'lr': rate,
            },
            path,
            _use_new_zipfile_serialization=False,
        )
        return path

    r_path = _dump("r", model_r, optimizer_r, lr_r)
    g_path = _dump("g", model_g, optimizer_g, lr_g)
    print(f"Checkpoint saved to {r_path} and {g_path}")


if os.path.exists(opt.pre_model_r):
    model_r = LapSRN().to(device)
    checkpoints_r = torch.load(opt.pre_model_r)
    model_r.load_state_dict(checkpoints_r['model'])
    model_r.train()
    epoch_continue_r = checkpoints_r['epoch']
    optimizer_r = optim.Adagrad(model_r.parameters())


    model_g = LapSRN().to(device)
    checkpoints_g = torch.load(opt.pre_model_g)
    model_g.load_state_dict(checkpoints_g['model'])
    model_g.train()
    epoch_continue_g = checkpoints_g['epoch']
    optimizer_g = optim.Adagrad(model_g.parameters())

    for epoch in range(epoch_continue_g + 1, opt.nEpochs + 1):

        if epoch == epoch_continue_g + 1: