def compare_pois_l2(s, c, saved_epoch, img_ind, add2exp='_400/',
        path2valdata='./DATASETS/BSDS500/BSDS500_validation_MAXVALs_01_2/'):
    """Show one validation image denoised by the 'pois' and the 'l2' model.

    Builds a ``PoisNet`` and a ``UDNet`` with ``stages=s`` and
    ``output_features=c``, restores both from the checkpoints saved at
    ``saved_epoch``, runs them on the noisy image at index ``img_ind`` and
    passes the noisy / clean / estimated images, titled with their PSNR,
    to ``show_images``.

    Args:
        s: forwarded to both models as ``stages``; also used in the
            experiment directory name (``s{s}c{c}``).
        c: forwarded to both models as ``output_features``; also used in
            the experiment directory name.
        saved_epoch: selects the checkpoint file ``state_{saved_epoch}.pth``.
        img_ind: index into the ``BSDS500`` validation dataset.
        add2exp: text placed between the model tag ('pois' / 'l2') and
            ``s{s}c{c}`` when building the experiment path.
        path2valdata: dataset root; ``'val/'`` is appended to it.

    Returns:
        None. The comparison is handed to ``show_images``.
    """
    models_root = './PoisDenoiser/networks/PoisNet/models/'
    arch_tag = 's{}c{}'.format(s, c)
    exp_name_pois = 'pois' + add2exp + arch_tag
    exp_name_l2 = 'l2' + add2exp + arch_tag

    def restore(model, exp_name):
        # map_location keeps every tensor on the storage it was
        # deserialized into, i.e. the checkpoint is loaded on the CPU.
        state = th.load(
            models_root + exp_name + '/state_{}.pth'.format(saved_epoch),
            map_location=lambda storage, loc: storage)
        model.load_state_dict(state['model_state_dict'])

    model_pois = PoisNet(output_features=c, stages=s)
    model_l2 = UDNet(output_features=c, stages=s)

    BSDSval = BSDS500(path2valdata + 'val/', get_name=True)
    gt, noisy, file_name = BSDSval[img_ind]

    # The third '_'-separated field of the file name carries the max value
    # after a 7-character prefix (presumably 'maxval=' -- TODO confirm).
    # Its first and last characters are read as the integer and the tenths
    # digit respectively.
    maxval = file_name.split('_')[2][7:]
    maxval = (int(maxval[0]) * 10 + int(maxval[-1])) / 10

    # Add a leading dimension (in place) before feeding the networks.
    gt.unsqueeze_(0)
    noisy.unsqueeze_(0)

    # Inference only: no_grad avoids building an autograd graph that would
    # be thrown away by detach() right after the forward pass.
    restore(model_pois, exp_name_pois)
    with th.no_grad():
        estim_pois = model_pois(noisy, noisy).detach()

    # NOTE(review): the l2 (UDNet) checkpoint is also read from the
    # PoisNet/models directory -- confirm this is the intended location.
    restore(model_l2, exp_name_l2)
    # NOTE(review): fixed value 5 passed as the second UDNet argument --
    # presumably the noise level the l2 model expects; confirm.
    stdn = th.Tensor([5])
    with th.no_grad():
        estim_l2 = model_l2(noisy, stdn, noisy).detach()

    psnr_noisy = psnr(gt, noisy)
    psnr_est_pois = psnr(gt, estim_pois)
    psnr_est_l2 = psnr(gt, estim_l2)

    estim_title = '{} (epoch={}),\nPSNR: {:.3f} dB'
    gt_title = 'clear ({} in BSDS val)'.format(img_ind)
    noisy_title = 'noisy (max val={}) \nPSNR: {:.2f} dB'.format(
        maxval, psnr_noisy)
    estim_pois_title = estim_title.format('pois', saved_epoch, psnr_est_pois)
    estim_l2_title = estim_title.format('l2', saved_epoch, psnr_est_l2)

    show_images(
        [noisy, gt, estim_pois, estim_l2],
        [noisy_title, gt_title, estim_pois_title, estim_l2_title])
# Network hyper-parameters. Insertion order matters: the values are passed
# to UDNet positionally below.
params = OrderedDict([
    ('kernel_size', opt.kernel_size),
    ('input_channels', input_channels),
    ('output_features', output_features),
    ('rbf_mixtures', opt.rbf_mixtures),
    ('rbf_precision', opt.rbf_precision),
    ('stages', opt.stages),
    ('pad', opt.pad),
    ('padType', opt.padType),
    ('convWeightSharing', opt.convWeightSharing),
    ('scale_f', opt.scale_f),
    ('scale_t', opt.scale_t),
    ('normalizedWeights', opt.normalizedWeights),
    ('zeroMeanWeights', opt.zeroMeanWeights),
    ('rbf_start', opt.rbf_start),
    ('rbf_end', opt.rbf_end),
    ('data_min', opt.data_min),
    ('data_max', opt.data_max),
    ('data_step', opt.data_step),
    ('alpha', opt.alpha),
    ('clb', opt.clb),
    ('cub', opt.cub),
])

model = UDNet(*params.values())

if opt.initModelPath != '':
    # Initial weights were supplied: load them (onto the CPU storage they
    # deserialize into) and switch resuming off.
    state = th.load(
        opt.initModelPath,
        map_location=lambda storage, loc: storage,
    )
    model.load_state_dict(state['model_state_dict'])
    opt.resume = False

# PSNR-based loss with peak value opt.cub; the MSE alternative
# (nn.MSELoss(size_average=True, reduce=True)) is not used.
criterion = PSNRLoss(peakval=opt.cub)
optimizer = optim.Adam(
    model.parameters(),
    lr=opt.lr,
    betas=(0.9, 0.999),
    eps=1e-04,
)

if opt.cuda:
    model = model.cuda()
    criterion = criterion.cuda()

start = 0
def compare_pois_l2_pois_w_prox(s, c, saved_epoch, img_ind, add2exp='_400/',
        path2valdata='./DATASETS/BSDS500/BSDS500_validation_MAXVALs_01_2/'):
    """Plot one validation image denoised by 'pois', 'pois_w_prox' and 'l2'.

    Builds two ``PoisNet`` models (one with ``prox_param=True``) and a
    ``UDNet``, all with ``stages=s`` and ``output_features=c``, restores
    them from the checkpoints saved at ``saved_epoch``, runs them on the
    noisy image at index ``img_ind`` and draws a 2x3 matplotlib grid:
    noisy and clean images in the top row, the three estimates in the
    bottom row (the top-right cell is removed). Each title carries the
    PSNR against the clean image.

    Args:
        s: forwarded to all models as ``stages``; also used in the
            experiment directory name (``s{s}c{c}``).
        c: forwarded to all models as ``output_features``; also used in
            the experiment directory name.
        saved_epoch: selects the checkpoint file ``state_{saved_epoch}.pth``.
        img_ind: index into the ``BSDS500`` validation dataset.
        add2exp: text placed between the model tag ('pois' /
            'pois_w_prox' / 'l2') and ``s{s}c{c}`` in the experiment path.
        path2valdata: dataset root; ``'val/'`` is appended to it.

    Returns:
        None. The figure is created through ``plt.subplots``.
    """
    models_root = './PoisDenoiser/networks/PoisNet/models/'
    arch_tag = 's{}c{}'.format(s, c)
    exp_name_pois = 'pois' + add2exp + arch_tag
    exp_name_poisprox = 'pois_w_prox' + add2exp + arch_tag
    exp_name_l2 = 'l2' + add2exp + arch_tag

    def restore(model, exp_name):
        # map_location keeps every tensor on the storage it was
        # deserialized into, i.e. the checkpoint is loaded on the CPU.
        state = th.load(
            models_root + exp_name + '/state_{}.pth'.format(saved_epoch),
            map_location=lambda storage, loc: storage)
        model.load_state_dict(state['model_state_dict'])

    def to_displayable(img):
        # Drop the leading dimension of a 4-D tensor, then either squeeze a
        # size-1 first dimension or move the first dimension last, giving
        # the layout handed to imshow.
        if img.dim() == 4:
            img = img[0]
        return img[0] if img.size()[0] == 1 else img.permute(1, 2, 0)

    model_pois = PoisNet(output_features=c, stages=s)
    model_poisprox = PoisNet(output_features=c, stages=s, prox_param=True)
    model_l2 = UDNet(output_features=c, stages=s)

    BSDSval = BSDS500(path2valdata + 'val/', get_name=True)
    gt, noisy, file_name = BSDSval[img_ind]

    # The third '_'-separated field of the file name carries the max value
    # after a 7-character prefix (presumably 'maxval=' -- TODO confirm).
    # Its first and last characters are read as the integer and the tenths
    # digit respectively.
    maxval = file_name.split('_')[2][7:]
    maxval = (int(maxval[0]) * 10 + int(maxval[-1])) / 10

    # Add a leading dimension (in place) before feeding the networks.
    gt.unsqueeze_(0)
    noisy.unsqueeze_(0)

    # Inference only: no_grad avoids building an autograd graph that would
    # be thrown away by detach() right after each forward pass.
    restore(model_pois, exp_name_pois)
    with th.no_grad():
        estim_pois = model_pois(noisy, noisy).detach()

    restore(model_poisprox, exp_name_poisprox)
    with th.no_grad():
        estim_poisprox = model_poisprox(noisy, noisy).detach()

    # NOTE(review): the l2 (UDNet) checkpoint is also read from the
    # PoisNet/models directory -- confirm this is the intended location.
    restore(model_l2, exp_name_l2)
    # NOTE(review): fixed value 5 passed as the second UDNet argument --
    # presumably the noise level the l2 model expects; confirm.
    stdn = th.Tensor([5])
    with th.no_grad():
        estim_l2 = model_l2(noisy, stdn, noisy).detach()

    psnr_noisy = psnr(gt, noisy)
    psnr_est_pois = psnr(gt, estim_pois)
    psnr_est_poisprox = psnr(gt, estim_poisprox)
    psnr_est_l2 = psnr(gt, estim_l2)

    estim_title = '{} (epoch={}),\nPSNR: {:.3f} dB'
    gt_title = 'clear ({} in BSDS val)'.format(img_ind)
    noisy_title = 'noisy (max val={}) \nPSNR: {:.2f} dB'.format(
        maxval, psnr_noisy)
    estim_pois_title = estim_title.format('pois', saved_epoch, psnr_est_pois)
    estim_poisprox_title = estim_title.format(
        'poisprox', saved_epoch, psnr_est_poisprox)
    estim_l2_title = estim_title.format('l2', saved_epoch, psnr_est_l2)

    # Grid cell, image and title of every panel; cell (0, 2) stays empty.
    cells = [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)]
    images = [
        to_displayable(img)
        for img in (noisy, gt, estim_pois, estim_poisprox, estim_l2)
    ]
    titles = [
        noisy_title, gt_title,
        estim_pois_title, estim_poisprox_title, estim_l2_title,
    ]

    fontsize = 15
    fig, ax = plt.subplots(2, 3, figsize=(20, 10))
    fig.patch.set_facecolor('white')

    for (row, col), img, title in zip(cells, images, titles):
        axis = ax[row, col]
        axis.imshow(img, cmap='gray')
        axis.set_axis_off()
        axis.set_title(title, fontsize=fontsize)

    # Remove the unused top-right cell so it does not show as an empty plot.
    ax[0, 2].remove()
    fig.tight_layout()