# Example #1
# 0
def compare_pois_l2(s, c, saved_epoch, img_ind, add2exp='_400/',
                    path2valdata='./DATASETS/BSDS500/BSDS500_validation_MAXVALs_01_2/',
                    path2models='./PoisDenoiser/networks/PoisNet/models/'):
    """Show a noisy BSDS500 validation image with its PoisNet and UDNet (l2) estimates.

    Builds a PoisNet and a UDNet with ``s`` stages and ``c`` output features.
    It loads the checkpoint ``state_<saved_epoch>.pth`` of the experiments
    ``pois<add2exp>s<s>c<c>`` and ``l2<add2exp>s<s>c<c>`` into them, runs both
    on validation image ``img_ind`` and passes the noisy, clean and two
    restored images to ``show_images``. Each image is titled with its PSNR
    against the clean image.

    Args:
        s: number of stages passed to both networks.
        c: number of output features passed to both networks.
        saved_epoch: epoch index of the checkpoint file to load.
        img_ind: index of the image in the validation set.
        add2exp: infix inserted between the model family name and
            ``s{}c{}`` in the experiment directory name.
        path2valdata: dataset root; images are read from its ``val/``
            sub-directory.
        path2models: directory containing the experiment folders. It defaults
            to the location that was previously hard-coded.
    """
    exp_suffix = add2exp + 's{}c{}'.format(s, c)

    model_pois = PoisNet(output_features=c, stages=s)
    model_l2 = UDNet(output_features=c, stages=s)

    BSDSval = BSDS500(path2valdata + 'val/', get_name=True)
    gt, noisy, file_name = BSDSval[img_ind]

    # NOTE(review): assumes the file name's third '_'-separated field holds
    # the peak value after a 7-character prefix. Its first and last
    # characters are combined as first + last / 10 (e.g. '1..5' -> 1.5).
    maxval_digits = file_name.split('_')[2][7:]
    maxval = (int(maxval_digits[0]) * 10 + int(maxval_digits[-1])) / 10

    # Prepend a batch dimension. This is done out-of-place so the tensors
    # handed out by the dataset are not modified.
    gt = gt.unsqueeze(0)
    noisy = noisy.unsqueeze(0)

    def load_checkpoint(model, exp_name):
        """Load ``state_<saved_epoch>.pth`` of experiment ``exp_name`` into ``model``."""
        # This map_location keeps every tensor on the CPU, whatever device it
        # was saved from.
        state = th.load(path2models + exp_name + '/state_{}.pth'.format(saved_epoch),
                        map_location=lambda storage, loc: storage)
        model.load_state_dict(state['model_state_dict'])

    load_checkpoint(model_pois, 'pois' + exp_suffix)
    # NOTE(review): the noisy image is fed as both inputs; confirm against
    # PoisNet.forward.
    estim_pois = model_pois(noisy, noisy).detach()

    load_checkpoint(model_l2, 'l2' + exp_suffix)
    # Fixed value for UDNet's `stdn` input (presumably the noise level).
    stdn = th.Tensor([5])
    estim_l2 = model_l2(noisy, stdn, noisy).detach()

    psnr_noisy = psnr(gt, noisy)
    psnr_est_pois = psnr(gt, estim_pois)
    psnr_est_l2 = psnr(gt, estim_l2)

    estim_fmt = '{} (epoch={}),\nPSNR: {:.3f} dB'
    gt_title = 'clear ({} in BSDS val)'.format(img_ind)
    noisy_title = 'noisy (max val={}) \nPSNR: {:.2f} dB'.format(maxval, psnr_noisy)
    estim_pois_title = estim_fmt.format('pois', saved_epoch, psnr_est_pois)
    estim_l2_title = estim_fmt.format('l2', saved_epoch, psnr_est_l2)

    show_images([noisy, gt, estim_pois, estim_l2],
                [noisy_title, gt_title, estim_pois_title, estim_l2_title])
# Constructor arguments for UDNet. They are passed positionally via
# *params.values(), so this order is assumed to match UDNet's signature.
# A list of pairs is used because OrderedDict(**kwargs) only preserves
# keyword order on Python 3.6+ (PEP 468).
params = OrderedDict([
    ('kernel_size', opt.kernel_size),
    ('input_channels', input_channels),
    ('output_features', output_features),
    ('rbf_mixtures', opt.rbf_mixtures),
    ('rbf_precision', opt.rbf_precision),
    ('stages', opt.stages),
    ('pad', opt.pad),
    ('padType', opt.padType),
    ('convWeightSharing', opt.convWeightSharing),
    ('scale_f', opt.scale_f),
    ('scale_t', opt.scale_t),
    ('normalizedWeights', opt.normalizedWeights),
    ('zeroMeanWeights', opt.zeroMeanWeights),
    ('rbf_start', opt.rbf_start),
    ('rbf_end', opt.rbf_end),
    ('data_min', opt.data_min),
    ('data_max', opt.data_max),
    ('data_step', opt.data_step),
    ('alpha', opt.alpha),
    ('clb', opt.clb),
    ('cub', opt.cub),
])

model = UDNet(*params.values())

# Optional warm start: load weights from a checkpoint, with tensors mapped
# to the CPU, and clear opt.resume.
if opt.initModelPath != '':
    state = th.load(opt.initModelPath,
                    map_location=lambda storage, loc: storage)
    model.load_state_dict(state['model_state_dict'])
    opt.resume = False

# PSNR-based loss whose peak value is opt.cub (also passed to UDNet as 'cub').
criterion = PSNRLoss(peakval=opt.cub)

# Move to the GPU *before* building the optimizer, as the torch.optim docs
# require, so the optimizer references the parameters where they live.
if opt.cuda:
    model = model.cuda()
    criterion = criterion.cuda()

optimizer = optim.Adam(model.parameters(),
                       lr=opt.lr,
                       betas=(0.9, 0.999),
                       eps=1e-04)

start = 0
# Example #3
# 0
def compare_pois_l2_pois_w_prox(s, c, saved_epoch, img_ind, add2exp='_400/',\
    path2valdata='./DATASETS/BSDS500/BSDS500_validation_MAXVALs_01_2/'):
    """Plot a BSDS500 validation image restored by three models side by side.

    Builds three networks with ``s`` stages and ``c`` output features:
    PoisNet, PoisNet with ``prox_param=True``, and UDNet. It loads the
    ``state_<saved_epoch>.pth`` checkpoints of the experiments
    ``pois<add2exp>s<s>c<c>``, ``pois_w_prox<add2exp>s<s>c<c>`` and
    ``l2<add2exp>s<s>c<c>`` into them, then runs each on validation image
    ``img_ind``. The result is drawn on a 2x3 matplotlib grid:

    * top row: the noisy and clean images (the top-right cell is removed);
    * bottom row: the three estimates.

    Each image is titled with its PSNR against the clean image.

    Args:
        s: number of stages passed to every network.
        c: number of output features passed to every network.
        saved_epoch: epoch index of the checkpoint files to load.
        img_ind: index of the image in the validation set.
        add2exp: infix between the model family name and ``s{}c{}`` in the
            experiment directory name.
        path2valdata: dataset root; images are read from its ``val/``
            sub-directory.
    """

    # Experiment directory names, e.g. 'pois_400/s5c48' with the default add2exp.
    exp_name_pois = 'pois' + add2exp + 's{}c{}'.format(s, c)
    exp_name_poisprox = 'pois_w_prox' + add2exp + 's{}c{}'.format(s, c)
    exp_name_l2 = 'l2' + add2exp + 's{}c{}'.format(s, c)

    # Same size for all three; the second PoisNet enables prox_param.
    model_pois = PoisNet(output_features=c, stages=s)
    model_poisprox = PoisNet(output_features=c, stages=s, prox_param=True)
    model_l2 = UDNet(output_features=c, stages=s)

    path2dataset = path2valdata
    BSDSval = BSDS500(path2dataset + 'val/', get_name=True)

    gt, noisy, file_name = BSDSval[img_ind]
    # NOTE(review): assumes the file name's third '_'-separated field holds
    # the peak value after a 7-character prefix. Its first and last characters
    # are combined as first + last / 10. `name` is not used below.
    split = file_name.split('_')
    name, maxval = split[0], split[2][7:]
    maxval = (int(maxval[0]) * 10 + int(maxval[-1])) / 10

    # Prepend a batch dimension (in place, on the dataset's tensors).
    gt.unsqueeze_(0)
    noisy.unsqueeze_(0)

    # Each checkpoint is loaded with tensors mapped to the CPU
    # (map_location returns the storage unchanged).
    state = th.load('./PoisDenoiser/networks/PoisNet/models/'\
        +exp_name_pois+'/state_{}.pth'.format(saved_epoch),\
                   map_location=lambda storage,loc:storage)
    model_pois.load_state_dict(state['model_state_dict'])
    estim_pois = model_pois(noisy, noisy).detach()

    state = th.load('./PoisDenoiser/networks/PoisNet/models/'\
        +exp_name_poisprox+'/state_{}.pth'.format(saved_epoch),\
                   map_location=lambda storage,loc:storage)
    model_poisprox.load_state_dict(state['model_state_dict'])
    estim_poisprox = model_poisprox(noisy, noisy).detach()

    state = th.load('./PoisDenoiser/networks/PoisNet/models/'\
        +exp_name_l2+'/state_{}.pth'.format(saved_epoch),\
                   map_location=lambda storage,loc:storage)
    model_l2.load_state_dict(state['model_state_dict'])
    # Fixed value for UDNet's `stdn` input (presumably the noise level).
    stdn = th.Tensor([5])
    estim_l2 = model_l2(noisy, stdn, noisy).detach()

    psnr_noisy = psnr(gt, noisy)
    psnr_est_pois = psnr(gt, estim_pois)
    psnr_est_poisprox = psnr(gt, estim_poisprox)
    psnr_est_l2 = psnr(gt, estim_l2)

    gt_title = 'clear ({} in BSDS val)'.format(img_ind)
    noisy_title = 'noisy (max val={}) \nPSNR: {:.2f} dB'.format(
        maxval, psnr_noisy)
    estim_pois_title = '{} (epoch={}),\nPSNR: {:.3f} dB'\
    .format('pois', saved_epoch, psnr_est_pois)
    estim_poisprox_title = '{} (epoch={}),\nPSNR: {:.3f} dB'\
    .format('poisprox', saved_epoch, psnr_est_poisprox)
    estim_l2_title = '{} (epoch={}),\nPSNR: {:.3f} dB'\
    .format('l2', saved_epoch, psnr_est_l2)

    # NOTE(review): `titles` is not used in the visible code; titles are set
    # individually on each axis below.
    images = [noisy, gt, estim_pois, estim_poisprox, estim_l2]
    titles = [
        noisy_title, gt_title, estim_pois_title, estim_poisprox_title,
        estim_l2_title
    ]

    fontsize = 15

    # Convert every tensor to a layout imshow accepts:
    # - drop a leading batch axis of a 4-D tensor;
    # - a single-entry first axis gives a 2-D image;
    # - otherwise move the first (presumably channel) axis last.
    images_corrected_dims = []
    for i, img in enumerate(images):
        if img.dim() == 4:
            img = img[0]

        img = img[0] if img.size()[0] == 1 else img.permute(1, 2, 0)
        images_corrected_dims.append(img)

    images = images_corrected_dims

    # 2x3 grid: noisy | clean | (removed) on top, the three estimates below.
    figsize = (20, 10)
    fig, ax = plt.subplots(2, 3, figsize=figsize)
    fig.patch.set_facecolor('white')

    ax[0, 0].imshow(images[0], cmap='gray')
    ax[0, 0].set_axis_off()
    ax[0, 0].set_title(noisy_title, fontsize=fontsize)
    ax[0, 1].imshow(images[1], cmap='gray')
    ax[0, 1].set_axis_off()
    ax[0, 1].set_title(gt_title, fontsize=fontsize)
    ax[1, 0].imshow(images[2], cmap='gray')
    ax[1, 0].set_axis_off()
    ax[1, 0].set_title(estim_pois_title, fontsize=fontsize)
    ax[1, 1].imshow(images[3], cmap='gray')
    ax[1, 1].set_axis_off()
    ax[1, 1].set_title(estim_poisprox_title, fontsize=fontsize)
    ax[1, 2].imshow(images[4], cmap='gray')
    ax[1, 2].set_axis_off()
    ax[1, 2].set_title(estim_l2_title, fontsize=fontsize)
    ax[0, 2].remove()

    fig.tight_layout()