Exemplo n.º 1
0
def compare_pois_l2(s, c, saved_epoch, img_ind, add2exp='_400/',\
    path2valdata='./DATASETS/BSDS500/BSDS500_validation_MAXVALs_01_2/'):
    """Show PoisNet and UDNet (L2) estimates side by side for one validation image.

    Both networks are built with ``stages=s`` and ``output_features=c``. They are
    restored from
    ``./PoisDenoiser/networks/PoisNet/models/<prefix><add2exp>s<s>c<c>/state_<saved_epoch>.pth``,
    where ``<prefix>`` is ``'pois'`` for PoisNet and ``'l2'`` for UDNet.

    Args:
        s: number of stages of both networks.
        c: number of output features of both networks.
        saved_epoch: epoch number of the checkpoint file to restore.
        img_ind: index of the image in the ``val/`` split.
        add2exp: experiment-name fragment placed between the model prefix
            and ``s{s}c{c}``.
        path2valdata: dataset root; images are read from its ``val/`` folder.

    The panels passed to ``show_images`` are, in order: noisy input, ground
    truth, PoisNet estimate and UDNet estimate. Each is titled with its PSNR
    against the ground truth.
    """
    checkpoint_root = './PoisDenoiser/networks/PoisNet/models/'
    exp_suffix = add2exp + 's{}c{}'.format(s, c)
    estim_template = '{} (epoch={}),\nPSNR: {:.3f} dB'

    # Both networks are instantiated before the dataset is opened.
    net_pois = PoisNet(output_features=c, stages=s)
    net_l2 = UDNet(output_features=c, stages=s)

    val_set = BSDS500(path2valdata + 'val/', get_name=True)
    gt, noisy, file_name = val_set[img_ind]

    # The peak value is encoded in the third '_'-separated token of the file
    # name, from character 7 on. Its first character is read as units and
    # its last character as tenths.
    peak_code = file_name.split('_')[2][7:]
    maxval = (int(peak_code[0]) * 10 + int(peak_code[-1])) / 10

    # Add a leading batch dimension in place.
    gt.unsqueeze_(0)
    noisy.unsqueeze_(0)

    def restore(net, exp_name):
        """Load ``state_<saved_epoch>.pth`` of ``exp_name`` into ``net``, mapping storages to CPU."""
        ckpt_path = checkpoint_root + exp_name + '/state_{}.pth'.format(saved_epoch)
        checkpoint = th.load(ckpt_path, map_location=lambda storage, loc: storage)
        net.load_state_dict(checkpoint['model_state_dict'])

    restore(net_pois, 'pois' + exp_suffix)
    estim_pois = net_pois(noisy, noisy).detach()

    restore(net_l2, 'l2' + exp_suffix)
    # The L2 network takes an extra second argument (named `stdn` elsewhere
    # in this file; presumably a noise level). It is fixed to 5 here.
    estim_l2 = net_l2(noisy, th.Tensor([5]), noisy).detach()

    psnr_noisy = psnr(gt, noisy)
    psnr_est_pois = psnr(gt, estim_pois)
    psnr_est_l2 = psnr(gt, estim_l2)

    panel_titles = [
        'noisy (max val={}) \nPSNR: {:.2f} dB'.format(maxval, psnr_noisy),
        'clear ({} in BSDS val)'.format(img_ind),
        estim_template.format('pois', saved_epoch, psnr_est_pois),
        estim_template.format('l2', saved_epoch, psnr_est_l2),
    ]
    show_images([noisy, gt, estim_pois, estim_l2], panel_titles)
Exemplo n.º 2
0
def do_inference_return_lists(noisy, gt, s, c, exp_name, saved_epoch, model_type='pois',\
    app=None,  prox_param=False, sharing_weights=True):
    """Restore one trained denoiser, run it on ``noisy`` and score it against ``gt``.

    Args:
        noisy: noisy input. It is passed as both the first and the last
            argument of the network's forward call.
        gt: ground-truth image used for the PSNR.
        s: number of network stages.
        c: number of output features.
        exp_name: experiment directory under
            ``./PoisDenoiser/networks/PoisNet/models/``.
        saved_epoch: epoch of the ``state_<epoch>.pth`` checkpoint to load.
        model_type: ``'pois'`` builds a PoisNet, ``'l2'`` builds a UDNet.
        app: optional label that starts the returned title (default ``'estim'``).
        prox_param: forwarded as ``prox_param`` (PoisNet) or ``alpha`` (UDNet).
        sharing_weights: forwarded as ``convWeightSharing``.

    Returns:
        ``(estim, estim_title)``: the detached network output and a title
        string containing the label, epoch and PSNR.

    Raises:
        ValueError: if ``model_type`` is neither ``'pois'`` nor ``'l2'``.
    """
    if model_type == 'pois':
        model = PoisNet(output_features=c,
                        stages=s,
                        prox_param=prox_param,
                        convWeightSharing=sharing_weights)
    elif model_type == 'l2':
        model = UDNet(output_features=c,
                      stages=s,
                      alpha=prox_param,
                      convWeightSharing=sharing_weights)
    else:
        # Fail early: otherwise `model` would be unbound further down.
        raise ValueError(
            "model_type must be 'pois' or 'l2', got {!r}".format(model_type))

    # map_location keeps every storage on the CPU, whatever device it was saved from.
    state = th.load('./PoisDenoiser/networks/PoisNet/models/'
                    + exp_name + '/state_{}.pth'.format(saved_epoch),
                    map_location=lambda storage, loc: storage)
    model.load_state_dict(state['model_state_dict'])

    if model_type == 'pois':
        estim = model(noisy, noisy).detach()
    else:
        # UDNet takes an extra second argument; a fixed `stdn` of 1 is used here.
        stdn = th.Tensor([1])
        estim = model(noisy, stdn, noisy).detach()

    psnr_est = psnr(gt, estim)

    label = app if app is not None else 'estim'
    estim_title = '{} (epoch={}),\nPSNR: {:.3f} dB'.format(
        label, saved_epoch, psnr_est)

    return estim, estim_title
Exemplo n.º 3
0
# Training-script setup: build the denoiser selected by ``opt.model_type``,
# its Adam optimizer, a MultiStepLR schedule and the loss criterion.
# ``opt`` (parsed options) and ``stages`` are assumed to be defined earlier
# in the script; neither is visible in this fragment.
output_features = opt.channels 

prox_param = opt.prox_param
# Weight sharing is enabled unless ``opt.no_sharing_weights`` is truthy.
sharing_weights = not opt.no_sharing_weights

print(prox_param)
print(sharing_weights)


model_type = opt.model_type
print(model_type)
# Build the model and move it to the GPU. UDNet receives the prox flag as ``alpha``.
# NOTE(review): any other model_type leaves ``model`` undefined, and the
# optimizer line below then fails with a NameError.
if model_type == 'pois':
    model = PoisNet(stages=stages, output_features=output_features,\
        prox_param=prox_param, convWeightSharing=sharing_weights).cuda()
elif model_type == 'l2':
    model = UDNet(stages=stages, output_features=output_features, \
        alpha=prox_param, convWeightSharing=sharing_weights).cuda()

lr0 = opt.lr0
optimizer = Adam(model.parameters(), lr=lr0)

# ``opt.milestones`` is a comma-separated list of integers, optionally
# wrapped in square brackets (e.g. "[10,20]"). An empty string means no milestones.
milestones = list(map(int, opt.milestones.strip('[]').split(','))) \
            if opt.milestones != '' else []
    
# The learning rate is multiplied by gamma=0.1 at each milestone epoch.
scheduler = MultiStepLR(optimizer, milestones=milestones, \
                                   gamma=0.1)
# Select the training loss by name.
loss = opt.loss
if loss == 'MSE':
    criterion = MSELoss().cuda()
elif loss == 'pois':
    criterion = poisLLHLoss
elif loss == 'L1':
    # NOTE(review): this branch is truncated in this copy. The next line is
    # not valid Python (unmatched ')') and no L1 criterion is created.
    train_set.stdn)

print('===> Building model')

# Parameters that we need to specify in order to initialize our model.
# NOTE(review): UDNet is constructed positionally from these values, so the
# key order here must match UDNet's constructor parameter order; confirm.
# ``input_channels`` is assumed to be defined earlier in the script.
params = OrderedDict(kernel_size=opt.kernel_size,input_channels=input_channels,\
         output_features=output_features,rbf_mixtures=opt.rbf_mixtures,\
         rbf_precision=opt.rbf_precision,stages=opt.stages,pad=opt.pad,\
         padType=opt.padType,convWeightSharing=opt.convWeightSharing,\
         scale_f=opt.scale_f,scale_t=opt.scale_t,normalizedWeights=\
         opt.normalizedWeights,zeroMeanWeights=opt.zeroMeanWeights,rbf_start=\
         opt.rbf_start,rbf_end=opt.rbf_end,data_min=opt.data_min,data_max=\
         opt.data_max,data_step=opt.data_step,alpha=opt.alpha,clb=opt.clb,\
         cub=opt.cub)

model = UDNet(*params.values())

# Optionally initialize from a checkpoint (storages mapped to CPU). When
# this happens, ``opt.resume`` is cleared.
if opt.initModelPath != '':
    state = th.load(opt.initModelPath,
                    map_location=lambda storage, loc: storage)
    model.load_state_dict(state['model_state_dict'])
    opt.resume = False

# The MSE criterion below is commented out; a PSNR-based loss with peak value opt.cub is used instead.
#criterion = nn.MSELoss(size_average=True,reduce=True)
criterion = PSNRLoss(peakval=opt.cub)

# Adam with default betas and a larger-than-default eps (1e-4).
optimizer = optim.Adam(model.parameters(),
                       lr=opt.lr,
                       betas=(0.9, 0.999),
                       eps=1e-04)
Exemplo n.º 5
0
def compare_pois_l2_pois_w_prox(s, c, saved_epoch, img_ind, add2exp='_400/',\
    path2valdata='./DATASETS/BSDS500/BSDS500_validation_MAXVALs_01_2/'):
    """Plot PoisNet, PoisNet-with-prox and UDNet (L2) estimates for one
    validation image on a 2x3 grid.

    All three networks use ``stages=s`` and ``output_features=c``. They are
    restored from
    ``./PoisDenoiser/networks/PoisNet/models/<prefix><add2exp>s<s>c<c>/state_<saved_epoch>.pth``,
    with prefixes ``'pois'``, ``'pois_w_prox'`` and ``'l2'``.

    Args:
        s: number of stages of every network.
        c: number of output features of every network.
        saved_epoch: epoch number of the checkpoint files to restore.
        img_ind: index of the image in the ``val/`` split.
        add2exp: experiment-name fragment placed between the model prefix
            and ``s{s}c{c}``.
        path2valdata: dataset root; images are read from its ``val/`` folder.

    Layout: the top row holds noisy and ground truth; the bottom row holds
    pois, pois_w_prox and l2. The top-right axis is removed. The figure is
    not returned.
    """
    checkpoint_root = './PoisDenoiser/networks/PoisNet/models/'
    exp_suffix = add2exp + 's{}c{}'.format(s, c)
    estim_template = '{} (epoch={}),\nPSNR: {:.3f} dB'
    fontsize = 15

    model_pois = PoisNet(output_features=c, stages=s)
    model_poisprox = PoisNet(output_features=c, stages=s, prox_param=True)
    model_l2 = UDNet(output_features=c, stages=s)

    BSDSval = BSDS500(path2valdata + 'val/', get_name=True)
    gt, noisy, file_name = BSDSval[img_ind]

    # Peak value encoded in the file name: the first character of the code
    # is read as units, the last as tenths.
    maxval_code = file_name.split('_')[2][7:]
    maxval = (int(maxval_code[0]) * 10 + int(maxval_code[-1])) / 10

    # Add a leading batch dimension in place.
    gt.unsqueeze_(0)
    noisy.unsqueeze_(0)

    def _load(model, exp_name):
        """Restore ``state_<saved_epoch>.pth`` of ``exp_name`` into ``model`` (CPU-mapped)."""
        state = th.load(checkpoint_root + exp_name + '/state_{}.pth'.format(saved_epoch),
                        map_location=lambda storage, loc: storage)
        model.load_state_dict(state['model_state_dict'])

    _load(model_pois, 'pois' + exp_suffix)
    estim_pois = model_pois(noisy, noisy).detach()

    _load(model_poisprox, 'pois_w_prox' + exp_suffix)
    estim_poisprox = model_poisprox(noisy, noisy).detach()

    _load(model_l2, 'l2' + exp_suffix)
    # UDNet takes an extra second argument; a fixed `stdn` of 5 is used here.
    stdn = th.Tensor([5])
    estim_l2 = model_l2(noisy, stdn, noisy).detach()

    psnr_noisy = psnr(gt, noisy)
    psnr_est_pois = psnr(gt, estim_pois)
    psnr_est_poisprox = psnr(gt, estim_poisprox)
    psnr_est_l2 = psnr(gt, estim_l2)

    titles = [
        'noisy (max val={}) \nPSNR: {:.2f} dB'.format(maxval, psnr_noisy),
        'clear ({} in BSDS val)'.format(img_ind),
        estim_template.format('pois', saved_epoch, psnr_est_pois),
        estim_template.format('poisprox', saved_epoch, psnr_est_poisprox),
        estim_template.format('l2', saved_epoch, psnr_est_l2),
    ]

    def _for_imshow(img):
        """Drop a leading batch dim, then keep the single channel or move channels last."""
        if img.dim() == 4:
            img = img[0]
        # imshow accepts (H, W) grayscale or (H, W, C) color arrays.
        return img[0] if img.size()[0] == 1 else img.permute(1, 2, 0)

    images = [_for_imshow(img) for img in
              (noisy, gt, estim_pois, estim_poisprox, estim_l2)]

    # Axis position of each panel, in the same order as `images` and `titles`.
    positions = [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)]

    fig, ax = plt.subplots(2, 3, figsize=(20, 10))
    fig.patch.set_facecolor('white')
    for pos, img, title in zip(positions, images, titles):
        ax[pos].imshow(img, cmap='gray')
        ax[pos].set_axis_off()
        ax[pos].set_title(title, fontsize=fontsize)
    # Only five panels are used; drop the unused top-right axis.
    ax[0, 2].remove()

    fig.tight_layout()
Exemplo n.º 6
0
def do_inference(s, c, exp_name, saved_epoch, model_type='pois', \
    img_ind=None, img_from_train_dataset=False, clear_ind=0, app=None,\
    path2valdata='./DATASETS/BSDS500/BSDS500_validation_MAXVALs_01_2/', prox_param=False, sharing_weights=True, do_VST=False):
    """Restore one trained denoiser, run it on a BSDS500 image and display the result.

    Args:
        s: number of network stages.
        c: number of output features.
        exp_name: experiment directory under
            ``./PoisDenoiser/networks/PoisNet/models/``.
        saved_epoch: epoch of the ``state_<epoch>.pth`` checkpoint to load.
        model_type: ``'pois'`` builds a PoisNet, ``'l2'`` builds a UDNet.
        img_ind: image index in the chosen split (default 80).
        img_from_train_dataset: read from ``train/`` instead of ``val/``.
        clear_ind: 0 shows [gt, noisy, estim]; anything else shows
            [noisy, gt, estim].
        app: optional label for the estimate title (default ``'estim'``).
        path2valdata: dataset root containing ``val/`` and ``train/``.
        prox_param: forwarded as ``prox_param`` (PoisNet) or ``alpha`` (UDNet).
        sharing_weights: forwarded as ``convWeightSharing``.
        do_VST: forwarded to BSDS500 as ``do_VST_4_visual``. When set, the
            dataset item also contains ``noisy_initial``, and the PSNRs are
            computed against it and against
            ``VST_backward_unbiased_exact(estim)``.

    Raises:
        ValueError: if ``model_type`` is neither ``'pois'`` nor ``'l2'``.
    """
    if model_type == 'pois':
        model = PoisNet(output_features=c,
                        stages=s,
                        prox_param=prox_param,
                        convWeightSharing=sharing_weights)
    elif model_type == 'l2':
        model = UDNet(output_features=c,
                      stages=s,
                      alpha=prox_param,
                      convWeightSharing=sharing_weights)
    else:
        # Fail early: otherwise `model` would be unbound further down.
        raise ValueError(
            "model_type must be 'pois' or 'l2', got {!r}".format(model_type))

    subset = 'train' if img_from_train_dataset else 'val'
    BSDSval = BSDS500(path2valdata + subset + '/', get_name=True,
                      do_VST_4_visual=do_VST)

    if img_ind is None:
        img_ind = 80

    if do_VST:
        gt, noisy, noisy_initial, file_name = BSDSval[img_ind]
    else:
        gt, noisy, file_name = BSDSval[img_ind]

    # Peak value encoded in the file name: the first character of the code
    # is read as units, the last as tenths.
    maxval_code = file_name.split('_')[2][7:]
    maxval = (int(maxval_code[0]) * 10 + int(maxval_code[-1])) / 10

    # Add a leading batch dimension in place.
    # NOTE(review): `noisy_initial` is not unsqueezed, so the VST PSNR below
    # relies on psnr() broadcasting it against the batched gt; confirm.
    gt.unsqueeze_(0)
    noisy.unsqueeze_(0)

    state = th.load('./PoisDenoiser/networks/PoisNet/models/'
                    + exp_name + '/state_{}.pth'.format(saved_epoch),
                    map_location=lambda storage, loc: storage)
    model.load_state_dict(state['model_state_dict'])

    if model_type == 'pois':
        estim = model(noisy, noisy).detach()
    else:
        # UDNet takes an extra second argument; a fixed `stdn` of 1 is used here.
        stdn = th.Tensor([1])
        estim = model(noisy, stdn, noisy).detach()

    # With VST, PSNR is measured on the untransformed noisy image and on the
    # back-transformed estimate. The displayed images are left as-is.
    if do_VST:
        psnr_noisy = psnr(gt, noisy_initial)
        psnr_est = psnr(gt, VST_backward_unbiased_exact(estim))
    else:
        psnr_noisy = psnr(gt, noisy)
        psnr_est = psnr(gt, estim)

    # Name the split the image actually came from.
    gt_title = 'clear ({} in BSDS {})'.format(img_ind, subset)
    noisy_title = 'noisy (max val={}) \nPSNR: {:.2f} dB'.format(
        maxval, psnr_noisy)
    label = app if app is not None else 'estim'
    estim_title = '{} (epoch={}),\nPSNR: {:.3f} dB'.format(
        label, saved_epoch, psnr_est)

    if clear_ind == 0:
        show_images([gt, noisy, estim], [gt_title, noisy_title, estim_title])
    else:
        show_images([noisy, gt, estim], [noisy_title, gt_title, estim_title])
print('===> Building model')

# Stage-wise setup: a list holding one independent UDNet per stage.
Lmodel = nn.ModuleList()

# Same constructor arguments for every stage, except that each network has stages=1.
# NOTE(review): UDNet is built positionally from these values, so the key
# order must match its constructor parameter order; confirm.
params = OrderedDict(kernel_size=opt.kernel_size,input_channels=input_channels,\
         output_features=output_features,rbf_mixtures=opt.rbf_mixtures,\
         rbf_precision=opt.rbf_precision,stages=1,pad=opt.pad,\
         padType=opt.padType,convWeightSharing=opt.convWeightSharing,\
         scale_f=opt.scale_f,scale_t=opt.scale_t,normalizedWeights=\
         opt.normalizedWeights,zeroMeanWeights=opt.zeroMeanWeights,rbf_start=\
         opt.rbf_start,rbf_end=opt.rbf_end,data_min=opt.data_min,data_max=\
         opt.data_max,data_step=opt.data_step,alpha=opt.alpha,clb=opt.clb,\
         cub=opt.cub)

# opt.stages separate single-stage networks, one per training stage.
for i in range(opt.stages):
    Lmodel.append(UDNet(*params.values()))

# `tic` is defined outside this fragment; presumably it starts a timer.
tic()
for stage in range(opt.stages):
    training_data_loader = DataLoader(dataset=train_set,num_workers=opt.threads,\
                                batch_size=opt.batchSize,shuffle=True)
    testing_data_loader = DataLoader(dataset=test_set,num_workers=opt.threads,\
                                batch_size=opt.testBatchSize,shuffle=False)

    #criterion = nn.MSELoss(size_average=True,reduce=True)
    criterion = PSNRLoss(peakval=opt.cub)
    stagePath = os.path.join(dirPath, 'stage' + str(stage + 1))
    os.makedirs(stagePath, exist_ok=True)
    smodel = Lmodel[stage]
    optimizer = optim.Adam(smodel.parameters(),
                           lr=opt.lr,