def denoise(fname, plot=False, stopping_mode="AMNS"):
    """Add AWGN with sigma=25 to an image and denoise it region by region.

    The image is cropped to a multiple of 32 and corrupted with synthetic
    Gaussian noise. It is then split into overlapping regions with
    ``get_regions``, each region is denoised with ``denoise_region``, and the
    results are stitched back together with ``image_from_regions``.

    Args:
        fname: Path to the image.
        plot: If True, plot the clean and the noisy image before denoising.
        stopping_mode: Accepted for signature compatibility with the other
            ``denoise`` variants; this patch-based version does not use it.

    Returns:
        The stitched denoised image produced by ``image_from_regions``.
    """
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    imsize = -1
    sigma = 25
    sigma_ = sigma / 255.
    OVERLAP = 16      # overlap passed to get_regions / image_from_regions
    patch_size = 128  # nominal region size used to count regions per axis

    # Seed NumPy so the synthetic noise is reproducible across calls
    # (presumably get_noisy_image draws from np.random).
    np.random.seed(7)
    orig_img_pil = crop_image(get_image(fname, imsize)[0], d=32)
    orig_img_np = pil_to_np(orig_img_pil)
    orig_img_noisy_pil, orig_img_noisy_np = get_noisy_image(orig_img_np, sigma_)
    if plot:
        plot_image_grid([orig_img_np, orig_img_noisy_np], 4, 6)

    # Axes 1 and 2 are treated as height and width (presumably a CxHxW array).
    regions_n_y = orig_img_np.shape[1] // patch_size
    regions_n_x = orig_img_np.shape[2] // patch_size
    print("Splitting image of shape {} in ({}, {}) regions".format(orig_img_np.shape, regions_n_y, regions_n_x))
    noisy_regions = get_regions(orig_img_noisy_np, regions_n_y, regions_n_x, OVERLAP)
    clean_regions = get_regions(orig_img_np, regions_n_y, regions_n_x, OVERLAP)
    denoised = [
        [var_to_np(denoise_region(noisy_region, clean_region)[0])
         for noisy_region, clean_region in zip(noisy_row, clean_row)]
        for noisy_row, clean_row in zip(noisy_regions, clean_regions)
    ]
    out = image_from_regions(denoised, OVERLAP)
    print("Patched PSNR: {:.4f}".format(compare_psnr(orig_img_np, out)))
    return out
# Evaluation script: denoise every test image twice with denoise() and report
# the PSNR of each run and of the average of both runs.
if len(sys.argv) > 1:
    stopping_mode = sys.argv[1]
elif "stopping_mode" not in globals():
    # Without a default, the denoise() calls below would raise NameError when
    # no command-line argument is given; "AMNS" matches denoise()'s default.
    stopping_mode = "AMNS"

IMAGES = ["data/denoising/" + image for image in [
    'house.png',
    # 'F16.png',
    # 'lena.png',
    # 'baboon.png',
    # 'kodim03.png',
    # 'kodim01.png',
    # 'peppers.png',
    # 'kodim02.png',
    # 'kodim12.png'
]]

psnrs = []
for fname in IMAGES:
    img_np = pil_to_np(crop_image(get_image(fname, -1)[0], d=32))
    run1 = denoise(fname, False, stopping_mode)
    run2 = denoise(fname, False, stopping_mode)
    # Reference image first, matching how compare_psnr is called inside denoise().
    psnr1, psnr2, psnr_avg = [compare_psnr(img_np, run) for run in (run1, run2, 0.5 * (run1 + run2))]
    print("Run 1: {}\nRun 2: {}\n PSNR of Average: {}".format(psnr1, psnr2, psnr_avg))
    psnrs.append(psnr_avg)
print("Average PSNR over test set: {}".format(np.mean(psnrs)))
def _build_skip_net(input_depth, pad, n_channels, dtype,
                    dropout_mode_down='None', dropout_mode_up='None',
                    dropout_p=0.0):
    """Build the 5-scale 'skip' network shared by all experiments in main().

    Only the number of output channels and the down/up dropout settings differ
    between experiments. The skip/output dropout modes are always 'None', and
    every dropout probability is set to ``dropout_p``.
    """
    return get_net(input_depth, 'skip', pad,
                   skip_n33d=128,
                   skip_n33u=128,
                   skip_n11=4,
                   num_scales=5,
                   n_channels=n_channels,
                   upsample_mode='bilinear',
                   dropout_mode_down=dropout_mode_down,
                   dropout_p_down=dropout_p,
                   dropout_mode_up=dropout_mode_up,
                   dropout_p_up=dropout_p,
                   dropout_mode_skip='None',
                   dropout_p_skip=dropout_p,
                   dropout_mode_output='None',
                   dropout_p_output=dropout_p).type(dtype)


def _save_calibration_plot(errs, uncerts, path):
    """Save a calibration plot of squared errors ``errs**2`` against ``uncerts`` to ``path``."""
    uce, err, uncert, freq = uceloss(errs**2, uncerts, n_bins=21)
    fig, ax = plot_uncert(err, uncert, freq, outlier_freq=0.001)
    ax.set_title(
        f'U = {uncerts.mean().sqrt().item():.4f}, UCE = {uce.item()*100:.3f}')
    plt.tight_layout()
    fig.savefig(path)


def main(img: int = 0, num_iter: int = 40000, lr: float = 3e-4, gpu: int = 0, seed: int = 42, save: bool = True):
    """Run DIP, SGLD, SGLD+NLL and MC-dropout DIP denoising on one test image.

    Each experiment trains a freshly initialised network with AdamW for
    ``num_iter`` iterations on a noisy copy of the selected image. It records
    loss / PSNR / SSIM histories, reconstructions and (where applicable)
    uncertainty maps. Figures are written to
    ``/media/fastdata/laves/unsure/<unix timestamp>/``.

    Args:
        img: Index of the test image (0-3).
        num_iter: Optimisation iterations per experiment.
        lr: AdamW learning rate.
        gpu: Value written to ``CUDA_VISIBLE_DEVICES``.
        seed: Seed for NumPy and PyTorch.
        save: If True, also write all histories to ``save.npz``.

    Raises:
        ValueError: If ``img`` is not a known image index.
    """
    # The training closures exchange state through these module-level names.
    global i, out_avg, psnr_noisy_last, last_net, net_input, losses, psnrs, ssims, average_dropout_rate, no_layers, \
        img_mean, sample_count, recons, uncerts, uncerts_ale, loss_last, roll_back

    np.random.seed(seed)
    torch.manual_seed(seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

    images = {
        0: '../NORMAL-4951060-8.jpeg',      # Gaussian noise to simulate speckle
        1: '../BACTERIA-1351146-0006.png',  # low-dose X-ray (Poisson approximated by Gaussian)
        2: '../081_HC.png',                 # Gaussian noise to simulate speckle
        3: '../CNV-9997680-30.png',         # Gaussian noise to simulate speckle
    }
    if img not in images:
        raise ValueError(f'Unknown image index {img}; expected one of {sorted(images)}')
    fname = images[img]
    imsize = (256, 256)

    PLOT = True
    timestamp = int(time.time())
    save_path = '/media/fastdata/laves/unsure'
    out_dir = f'{save_path}/{timestamp}'
    os.makedirs(out_dir)

    img_pil = crop_image(get_image(fname, imsize)[0], d=32)
    img_np = pil_to_np(img_pil)
    print(img_np.shape)
    # For lambda > 20 Poisson noise is well approximated by a Gaussian, so every
    # image (including the X-ray) gets additive Gaussian noise with sigma 0.1.
    p_sigma = 0.1
    img_noisy_pil, img_noisy_np = get_noisy_image_gaussian(img_np, p_sigma)

    if PLOT:
        q = plot_image_grid([img_np, img_noisy_np], 4, 6)
        np_to_pil(q).save(f'{out_dir}/input.png', 'PNG')

    # Settings shared by all experiments.
    INPUT = 'noise'
    pad = 'reflection'
    OPT_OVER = 'net'  # 'net,input'
    reg_noise_std = 1. / 10.
    LR = lr
    roll_back = False  # To solve the oscillation of model training
    input_depth = 32
    show_every = 100
    exp_weight = 0.99
    mse = torch.nn.MSELoss()
    img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)

    LOSSES = {}
    RECONS = {}
    UNCERTS = {}
    UNCERTS_ALE = {}
    PSNRS = {}
    SSIMS = {}

    def new_input():
        """Draw a fresh global ``net_input``; return (saved copy, perturbation buffer)."""
        global net_input
        net_input = get_noise(
            input_depth, INPUT, (img_pil.size[1], img_pil.size[0])).type(dtype).detach()
        return net_input.detach().clone(), net_input.detach().clone()

    def reset_history(initial_loss_last):
        """Reset the module-level per-experiment state before training a new network."""
        global i, out_avg, psnr_noisy_last, last_net, losses, psnrs, ssims, \
            img_mean, sample_count, recons, uncerts, uncerts_ale, loss_last
        losses = []
        recons = []
        uncerts = []
        uncerts_ale = []
        psnrs = []
        ssims = []
        img_mean = 0
        sample_count = 0
        i = 0
        psnr_noisy_last = 0
        loss_last = initial_loss_last
        out_avg = None
        last_net = None

    def track(out, _loss):
        """Update the EMA output and the loss/PSNR/SSIM histories; return channel 0 of ``out``."""
        global out_avg, losses, psnrs, ssims, i
        if out_avg is None:
            out_avg = out.detach()
        else:
            out_avg = out_avg * exp_weight + out.detach() * (1 - exp_weight)
        losses.append(mse(out_avg[:, :1], img_noisy_torch).item())

        _out = out.detach().cpu().numpy()[0, :1]
        _out_avg = out_avg.detach().cpu().numpy()[0, :1]
        psnr_noisy = compare_psnr(img_noisy_np, _out)
        psnr_gt = compare_psnr(img_np, _out)
        psnr_gt_sm = compare_psnr(img_np, _out_avg)
        ssim_noisy = compare_ssim(img_noisy_np[0], _out[0])
        ssim_gt = compare_ssim(img_np[0], _out[0])
        ssim_gt_sm = compare_ssim(img_np[0], _out_avg[0])
        psnrs.append([psnr_noisy, psnr_gt, psnr_gt_sm])
        ssims.append([ssim_noisy, ssim_gt, ssim_gt_sm])
        if PLOT and i % show_every == 0:
            print(
                f'Iteration: {i} Loss: {_loss.item():.4f} PSNR_noisy: {psnr_noisy:.4f} PSRN_gt: {psnr_gt:.4f} PSNR_gt_sm: {psnr_gt_sm:.4f}'
            )
        return _out

    def store_results(key):
        """Copy the current histories into the per-method dicts and save a reconstruction grid."""
        LOSSES[key] = losses
        RECONS[key] = recons
        UNCERTS[key] = uncerts
        UNCERTS_ALE[key] = uncerts_ale
        PSNRS[key] = psnrs
        SSIMS[key] = ssims
        to_plot = [img_np] + [np.clip(r, 0, 1) for r in RECONS[key]]
        q = plot_image_grid(to_plot, factor=13)
        np_to_pil(q).save(f'{out_dir}/{key}_recons.png', 'PNG')

    # ---------------------------------------------------------------- DIP
    weight_decay = 0
    net_input_saved, noise = new_input()

    def closure_dip():
        global i, net_input, recons
        if reg_noise_std > 0:
            net_input = net_input_saved + (noise.normal_() * reg_noise_std)
        out = net(net_input)
        out[:, :1] = out[:, :1].sigmoid()
        _loss = mse(out[:, :1], img_noisy_torch)
        _loss.backward()
        out_np = track(out, _loss)
        print('###################')
        recons.append(out_np)
        i += 1
        return _loss

    net = _build_skip_net(input_depth, pad, 1, dtype)
    net.apply(init_normal)
    reset_history(1e16)
    parameters = get_params(OPT_OVER, net, net_input)
    optimizer = torch.optim.AdamW(parameters, lr=LR, weight_decay=weight_decay)
    optimize(optimizer, closure_dip, num_iter)
    store_results('dip')

    # ---------------------------------------------------------------- SGLD
    weight_decay = 1e-4
    param_noise_sigma = 2
    net_input_saved, noise = new_input()
    mc_iter = 25

    def add_noise(model):
        """Add N(0, (param_noise_sigma * LR)^2) noise to every 4-D parameter (presumably conv kernels)."""
        for n in [x for x in model.parameters() if len(x.size()) == 4]:
            perturbation = torch.randn(n.size()) * param_noise_sigma * LR
            perturbation = perturbation.type(dtype)
            n.data = n.data + perturbation

    def closure_sgld():
        global i, net_input, recons, uncerts
        add_noise(net)
        if reg_noise_std > 0:
            net_input = net_input_saved + (noise.normal_() * reg_noise_std)
        out = net(net_input)
        out[:, :1] = out[:, :1].sigmoid()
        _loss = mse(out[:, :1], img_noisy_torch)
        _loss.backward()
        out_np = track(out, _loss)
        recons.append(out_np)
        # Epistemic uncertainty: variance over the last mc_iter reconstructions.
        out_np_var = np.var(np.array(recons[-mc_iter:]), axis=0)[:1]
        print('mean epi', out_np_var.mean())
        print('###################')
        uncerts.append(out_np_var)
        i += 1
        return _loss

    net = _build_skip_net(input_depth, pad, 1, dtype)
    net.apply(init_normal)
    reset_history(1e10)
    parameters = get_params(OPT_OVER, net, net_input)
    optimizer = torch.optim.AdamW(parameters, lr=LR, weight_decay=weight_decay)
    optimize(optimizer, closure_sgld, num_iter)
    store_results('sgld')

    errs = img_noisy_torch.cpu() - torch.tensor(RECONS['sgld'][-1])
    uncerts_epi = torch.tensor(UNCERTS['sgld'][-1]).unsqueeze(0)
    uncerts = uncerts_epi
    _save_calibration_plot(errs, uncerts, f'{out_dir}/sgld_calib.png')

    # ---------------------------------------------------------------- SGLD + NLL
    # The network gets a second output channel whose exp(-x) is used as
    # aleatoric uncertainty.
    net_input_saved, noise = new_input()

    def closure_sgldnll():
        global i, net_input, recons, uncerts, uncerts_ale
        add_noise(net)
        if reg_noise_std > 0:
            net_input = net_input_saved + (noise.normal_() * reg_noise_std)
        out = net(net_input)
        out[:, :1] = out[:, :1].sigmoid()
        _loss = gaussian_nll(out[:, :1], out[:, 1:], img_noisy_torch)
        _loss.backward()
        out[:, 1:] = torch.exp(-out[:, 1:])  # aleatoric uncertainty
        out_np = track(out, _loss)
        recons.append(out_np)
        out_np_ale = out.detach().cpu().numpy()[0, 1:]
        out_np_var = np.var(np.array(recons[-mc_iter:]), axis=0)[:1]
        print('mean epi', out_np_var.mean())
        print('mean ale', out_np_ale.mean())
        print('###################')
        uncerts.append(out_np_var)
        uncerts_ale.append(out_np_ale)
        i += 1
        return _loss

    net = _build_skip_net(input_depth, pad, 2, dtype)
    net.apply(init_normal)
    reset_history(1e6)
    parameters = get_params(OPT_OVER, net, net_input)
    optimizer = torch.optim.AdamW(parameters, lr=LR, weight_decay=weight_decay)
    optimize(optimizer, closure_sgldnll, num_iter)
    store_results('sgldnll')

    uncerts_epi = torch.tensor(UNCERTS['sgldnll'][-1]).unsqueeze(0)
    uncerts_ale = torch.tensor(UNCERTS_ALE['sgldnll'][-1]).unsqueeze(0)
    uncerts = uncerts_epi + uncerts_ale
    # Calibration against the noisy input, then against the clean image.
    errs = img_noisy_torch.cpu() - torch.tensor(RECONS['sgldnll'][-1])
    _save_calibration_plot(errs, uncerts, f'{out_dir}/sgldnll_calib.png')
    errs = torch.tensor(img_np).unsqueeze(0) - torch.tensor(RECONS['sgldnll'][-1])
    _save_calibration_plot(errs, uncerts, f'{out_dir}/sgldnll_calib2.png')

    # ---------------------------------------------------------------- MC dropout DIP
    weight_decay = 1e-4
    net_input_saved, noise = new_input()
    mc_iter = 25

    def closure_mcdip():
        global i, net_input, recons, uncerts, uncerts_ale
        if reg_noise_std > 0:
            net_input = net_input_saved + (noise.normal_() * reg_noise_std)
        out = net(net_input)
        out[:, :1] = out[:, :1].sigmoid()
        _loss = gaussian_nll(out[:, :1], out[:, 1:], img_noisy_torch)
        _loss.backward()
        out[:, 1:] = torch.exp(-out[:, 1:])  # aleatoric uncertainty
        track(out, _loss)

        # Average mc_iter stochastic forward passes. No eval() call is made
        # here, so dropout is presumably active during sampling.
        img_list = []
        aleatoric_list = []
        with torch.no_grad():
            net_input = net_input_saved + (noise.normal_() * reg_noise_std)
            for _ in range(mc_iter):
                sample = net(net_input)
                sample[:, :1] = sample[:, :1].sigmoid()
                sample[:, 1:] = torch.exp(-sample[:, 1:])
                img_list.append(torch_to_np(sample[:1]))
                aleatoric_list.append(torch_to_np(sample[:, 1:]))

        img_list_np = np.array(img_list)
        out_np = np.mean(img_list_np, axis=0)[:1]
        out_np_ale = np.mean(aleatoric_list, axis=0)[:1]
        out_np_var = np.var(img_list_np, axis=0)[:1]
        print('mean epi', out_np_var.mean())
        print('mean ale', out_np_ale.mean())
        print('###################')
        recons.append(out_np)
        uncerts.append(out_np_var)
        uncerts_ale.append(out_np_ale)
        i += 1
        return _loss

    net = _build_skip_net(input_depth, pad, 2, dtype,
                          dropout_mode_down='2d', dropout_mode_up='2d',
                          dropout_p=0.3)
    net.apply(init_normal)
    reset_history(1e16)
    parameters = get_params(OPT_OVER, net, net_input)
    optimizer = torch.optim.AdamW(parameters, lr=LR, weight_decay=weight_decay)
    optimize(optimizer, closure_mcdip, num_iter)
    store_results('mcdip')

    uncerts_epi = torch.tensor(UNCERTS['mcdip'][-1]).unsqueeze(0)
    uncerts_ale = torch.tensor(UNCERTS_ALE['mcdip'][-1]).unsqueeze(0)
    uncerts = uncerts_epi + uncerts_ale
    # Calibration against the noisy input, then against the clean image.
    errs = img_noisy_torch.cpu() - torch.tensor(RECONS['mcdip'][-1])
    _save_calibration_plot(errs, uncerts, f'{out_dir}/mcdip_calib.png')
    errs = torch.tensor(img_np).unsqueeze(0) - torch.tensor(RECONS['mcdip'][-1])
    _save_calibration_plot(errs, uncerts, f'{out_dir}/mcdip_calib2.png')

    # ---------------------------------------------------------------- summary plots
    fig, ax0 = plt.subplots(1, 1)
    for key, loss in LOSSES.items():
        ax0.plot(range(len(loss)), loss, label=key)
    ax0.set_title('MSE')
    ax0.set_xlabel('iteration')
    ax0.set_ylabel('mse loss')
    ax0.set_ylim(0, 0.03)
    ax0.grid(True)
    ax0.legend()
    plt.tight_layout()
    plt.savefig(f'{out_dir}/losses.png')
    plt.show()

    # One figure per metric, with columns [vs noisy, vs gt, smoothed vs gt].
    for metric, histories in (('psnr', PSNRS), ('ssim', SSIMS)):
        fig, axs = plt.subplots(1, 3, constrained_layout=True)
        labels = [f"{metric}_noisy", f"{metric}_gt", f"{metric}_gt_sm"]
        for key, history in histories.items():
            history = np.array(history)
            for col in range(history.shape[1]):
                axs[col].plot(range(history.shape[0]), history[:, col], label=key)
                axs[col].set_title(labels[col])
                axs[col].set_xlabel('iteration')
                axs[col].set_ylabel(metric)
                axs[col].legend()
        plt.savefig(f'{out_dir}/{metric}s.png')
        plt.show()

    # save stuff for plotting
    if save:
        np.savez(f"{out_dir}/save.npz",
                 noisy_img=img_noisy_np, losses=LOSSES, recons=RECONS, uncerts=UNCERTS,
                 uncerts_ale=UNCERTS_ALE, psnrs=PSNRS, ssims=SSIMS)
def denoise(fname, plot=False, stopping_mode="AMNS"):
    """Add AWGN with sigma=25 to the given image and denoise it with DIP.

    Training uses an exponential moving average of the network output. It stops
    early once the method noise has been below the target std in 75 iterations
    (not necessarily consecutive). The method noise is the RMS of the noisy
    image minus the averaged output, on the 0-255 scale.

    Args:
        fname: Path to the image.
        plot: If True, plot the clean/noisy pair and periodic reconstructions.
        stopping_mode: "AMNS" predicts the target method-noise std from the
            noisy image; "SMNS" uses the fixed target 24.45.

    Returns:
        A tuple ``(out_avg, psnr_history)``. ``out_avg`` is the exponentially
        averaged network output (a torch tensor). ``psnr_history`` is a list of
        ``(psnr_gt_sm, method_noise_std)`` pairs, one per iteration.

    Raises:
        ValueError: If ``stopping_mode`` is neither "AMNS" nor "SMNS".
    """
    dtype = torch.cuda.FloatTensor
    sigma = 25
    sigma_ = sigma / 255.

    # Seed NumPy so the synthetic noise is reproducible across calls
    # (presumably get_noisy_image draws from np.random).
    np.random.seed(7)
    # NOTE(review): `imsize` is not defined in this function; it must be
    # provided at module scope.
    img_pil = crop_image(get_image(fname, imsize)[0], d=32)
    img_np = pil_to_np(img_pil)
    img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma_)
    if plot:
        plot_image_grid([img_np, img_noisy_np], 4, 6)

    INPUT = 'noise'  # 'meshgrid'
    pad = 'reflection'
    OPT_OVER = 'net'  # 'net,input'
    reg_noise_std = 1. / 30.  # set to 1./20. for sigma=50

    if stopping_mode == "AMNS":
        # Predict the target from this call's own noisy image.
        target_method_noise_std = predict_method_noise_std(
            img_noisy_np, sigma / 255) * 255
    elif stopping_mode == "SMNS":
        target_method_noise_std = 24.45
    else:
        raise ValueError("Unknown stopping mode {}".format(stopping_mode))
    print("Predicted method noise std: {:.4f}".format(target_method_noise_std))

    LR = 0.01
    exp_weight = 0.99  # Exponential averaging coefficient
    OPTIMIZER = 'adam'  # 'LBFGS'
    show_every = 500
    num_iter = 4000
    input_depth = 32
    figsize = 4

    net = get_net(input_depth, 'skip', pad,
                  skip_n33d=128, skip_n33u=128, skip_n11=4, num_scales=5,
                  upsample_mode='bilinear').type(dtype)
    net_input = get_noise(
        input_depth, INPUT, (img_np.shape[1], img_np.shape[2])).type(dtype).detach()

    # Compute number of parameters
    s = sum(np.prod(list(p.size())) for p in net.parameters())
    print('Number of params: %d' % s)

    # Loss
    mse = torch.nn.MSELoss().type(dtype)
    img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)

    net_input_saved = net_input.detach().clone()
    noise = net_input.detach().clone()
    out_avg = None
    last_net = None
    psrn_noisy_last = 0
    i = 0
    psnr_history = []
    # Counts up by one per iteration below the target; training stops at 0.
    overfit_counter = -75

    def closure():
        nonlocal i, out_avg, psrn_noisy_last, last_net, overfit_counter

        # Perturb the fixed input with fresh noise each step (DIP regularisation).
        if reg_noise_std > 0:
            step_input = net_input_saved + (noise.normal_() * reg_noise_std)
        else:
            step_input = net_input
        out = net(step_input)

        # Smoothing
        if exp_weight is not None:
            if out_avg is None:
                out_avg = out.detach()
            else:
                out_avg = out_avg * exp_weight + out.detach() * (1 - exp_weight)

        total_loss = mse(out, img_noisy_torch)
        total_loss.backward()

        out_first = out.detach().cpu().numpy()[0]
        psrn_noisy = compare_psnr(img_noisy_np, out_first)
        psrn_gt = compare_psnr(img_np, out_first)
        psrn_gt_sm = compare_psnr(img_np, out_avg.detach().cpu().numpy()[0])
        method_noise_mse = np.sqrt(
            compare_mse(
                img_noisy_np - out_avg.detach().cpu().type(torch.FloatTensor).numpy()[0],
                np.zeros(img_np.shape, dtype=np.float32)) * 255**2)
        psnr_history.append((psrn_gt_sm, method_noise_mse))

        print(
            'Iteration %05d Loss %f PSNR_noisy: %f PSRN_gt: %f PSNR_gt_sm: %f' %
            (i, total_loss.item(), psrn_noisy, psrn_gt, psrn_gt_sm), '\r', end='')
        if plot and i % show_every == 0:
            out_np = torch_to_np(out)
            plot_image_grid(
                [np.clip(out_np, 0, 1), np.clip(torch_to_np(out_avg), 0, 1)],
                factor=figsize, nrow=1)

        # Early stopping against the mode-dependent target.
        if method_noise_mse < target_method_noise_std:
            overfit_counter += 1
            if overfit_counter == 0:
                raise StopIteration()

        # Backtracking: on a PSNR drop of more than 5 dB, restore the last
        # saved parameters and return a zero loss for this step.
        if i % show_every:
            if last_net is not None and psrn_noisy - psrn_noisy_last < -5:
                print('Falling back to previous checkpoint.')
                for new_param, net_param in zip(last_net, net.parameters()):
                    net_param.data.copy_(new_param.cuda())
                return total_loss * 0
            last_net = [x.data.cpu() for x in net.parameters()]
            psrn_noisy_last = psrn_noisy

        i += 1
        return total_loss

    p = get_params(OPT_OVER, net, net_input)
    try:
        optimize(OPTIMIZER, p, closure, LR, num_iter)
    except StopIteration:
        pass  # early stopping was triggered inside closure()
    return out_avg, psnr_history
def denoise(fname, plot=False, stopping_mode="AMNS"):
    """Add AWGN with sigma=25 to the given image and denoise it progressively.

    A ``SkipNetwork`` is grown one resolution level at a time (levels
    ``1 .. MAX_LEVEL``). Every level after the first has two phases:

    - A transition phase advances ``net.update_alpha`` from 0 towards 1. Its
      training target is built from the current and previous level's noisy
      images via ``interpolate_lr``.
    - A stabilisation phase (after ``net.flush()``) trains on the current-level
      noisy image alone.

    At the last level the output is exponentially averaged. Training stops early
    once the method noise has been below the target in 25 iterations.

    Args:
        fname: Path to the image.
        plot: If True, plot the noisy image and intermediate outputs.
        stopping_mode: "AMNS" predicts the target method-noise std from the
            noisy image, and "SMNS" uses the fixed target 24.45. None or
            "static" disables early stopping, so the last stabilisation phase
            runs a fixed 650 iterations.

    Returns:
        A tuple ``(out_avg, psnr_history)``. ``out_avg`` is the exponentially
        averaged output at the last level (a torch tensor). ``psnr_history`` is
        a list of ``(psnr_gt_avg, method_noise_std)`` pairs, one per last-level
        iteration that did not trigger early stopping.

    Raises:
        ValueError: If ``stopping_mode`` is not one of the values above.
    """
    start_time = datetime.now()
    dtype = torch.cuda.FloatTensor
    imsize = -1
    sigma = 25
    sigma_ = sigma / 255.
    MAX_LEVEL = 5

    # Fix the random seed so that the noisy image is identical every time.
    # This is mandatory as in the report we compute PSNR by averaging two runs.
    # If the runs used images corrupted with different noise, the noise would
    # cancel out, artificially inflating PSNR.
    np.random.seed(7)
    orig_img_pil = crop_image(get_image(fname, imsize)[0], d=32)
    orig_img_np = pil_to_np(orig_img_pil)
    # Comment the following line and uncomment the next when ground truth is not known
    orig_img_noisy_pil, orig_img_noisy_np = get_noisy_image(orig_img_np, sigma_)
    # orig_img_noisy_pil, orig_img_noisy_np = orig_img_pil, orig_img_np
    if plot:
        plot_image_grid([orig_img_noisy_np], 4, 5)

    # Set up parameters and net
    input_depth = 3 if "snail" in fname else 32
    INPUT = 'noise'
    LR = 0.01
    weight_decay = 0.0
    show_every = 100
    RAMPUP_DURATION = 70
    figsize = 10
    reg_noise_std = sigma / (60 * 25) + 1 / 60

    fixed_schedule = stopping_mode is None or stopping_mode == "static"
    if stopping_mode == "AMNS":
        target_method_noise_std = predict_method_noise_std(orig_img_noisy_np, sigma / 255) * 255
    elif stopping_mode == "SMNS":
        target_method_noise_std = 24.45
    elif fixed_schedule:
        # A method-noise std is never negative, so early stopping never fires.
        target_method_noise_std = -1
    else:
        raise ValueError("Unknown stopping mode {}".format(stopping_mode))
    print("Target method noise: {:.4f}".format(target_method_noise_std))

    def get_phase_duration(level, phase):
        """Return the number of iterations for ``phase`` ('trans' or 'stab') at ``level``."""
        if level <= MAX_LEVEL - 1:
            durations = {'trans': 70, 'stab': 50}
        else:
            # At the last level, use a fixed number of stabilisation iterations
            # without early stopping; otherwise use an arbitrarily large number
            # so that early stopping kicks in before it is reached.
            durations = {'trans': 100, 'stab': 650 if fixed_schedule else 2000000}
        return durations[phase]

    net = SkipNetwork(
        input_channels=input_depth,
        skip_channels=[4, 4, 4, 4, 4],
        down_channels=[128, 128, 128, 128, 128],
        norm_fun="BatchNorm"
    ).type(dtype)
    exp_weight = 0.99
    mse = torch.nn.MSELoss().type(dtype)
    s = sum(np.prod(list(p.size())) for p in net.parameters())
    print('Number of params for D: %d' % s)

    out_avg = None
    psnr_history = []
    # Counts up by one per last-level iteration below the target; stop at 0.
    overfit_counter = -25

    def closure(i, j, max_iter, cur_level, phase, image_target):
        """Innermost loop of the optimization procedure.

        ``j`` and ``max_iter`` are relative to the current phase; ``i`` counts
        all iterations since the start. Raises StopIteration on early stop.
        """
        nonlocal out_avg, overfit_counter
        at_last_level = cur_level == MAX_LEVEL

        if reg_noise_std > 0:
            # Adapt regularization noise amplitude to current level.
            # At lower resolutions the DIP reg. noise seems too strong.
            net_input.data = net_input_saved + (noise.normal_() * (reg_noise_std * 10**(-(MAX_LEVEL - cur_level))))

        out = net(net_input)

        # If at last level, compute the exponential moving average.
        if exp_weight is not None and at_last_level:
            if out_avg is None:
                out_avg = out.detach()
            else:
                out_avg = out_avg * exp_weight + out.detach() * (1 - exp_weight)

        if at_last_level:
            # Measure PSNR and method noise
            psrn_noisy = compare_psnr(img_noisy_np, out.detach().cpu().numpy()[0])
            psrn_gt = compare_psnr(orig_img_np, out.detach().cpu().numpy()[0])
            psrn_gt_avg = compare_psnr(orig_img_np, out_avg.detach().cpu().numpy()[0])
            method_noise_mse = np.sqrt(compare_mse(
                orig_img_noisy_np - out_avg.detach().cpu().type(torch.FloatTensor).numpy()[0],
                np.zeros(orig_img_np.shape, dtype=np.float32)) * 255**2)
            if method_noise_mse < target_method_noise_std:
                overfit_counter += 1
                if overfit_counter == 0:
                    raise StopIteration()
            psnr_history.append((psrn_gt_avg, method_noise_mse))

        if plot and (i % show_every == 0 or j == max_iter - 1):
            print("i:{} j:{}/{} phase:{}\n".format(i, j + 1, max_iter, phase))
            out_np = var_to_np(out)
            img = np.clip(out_np, 0, 1)
            plot_image_grid([img], factor=figsize, nrow=2)

        total_loss = mse(out, image_target)
        total_loss.backward()

        if at_last_level:
            print('Iteration %05d Loss %f Noise_stddev %f PSNR_noisy: %f PSRN_gt: %f PSNR_gt_sm: %f' % (i, total_loss.item(), method_noise_mse, psrn_noisy, psrn_gt, psrn_gt_avg), '\r', end='')
        else:
            print('Iteration %05d Loss %f' % (i, total_loss.item()), '\r', end='')
        if i % 100 == 0:
            print("")

        return total_loss, out

    i = 0  # Global iteration count
    # Init fixed random code vector.
    orig_noise = get_noise(
        input_depth, INPUT,
        (int(orig_img_pil.size[1]), int(orig_img_pil.size[0]))
    ).type(dtype).detach()
    img_noisy_np = None
    stopped_early = False

    # Iterate over each resolution level
    for cur_level in range(1, MAX_LEVEL + 1):
        net.grow()
        net = net.type(dtype)
        print("Increased network size")
        s = sum(np.prod(list(p.size())) for p in net.parameters())
        print('Number of params: %d' % s)

        scale = 2 ** (MAX_LEVEL - cur_level)
        # Downsample z
        if cur_level != MAX_LEVEL:
            net_input = nn.AvgPool2d(kernel_size=scale)(orig_noise)
        else:
            net_input = orig_noise
        # Save a copy of z as z will be perturbed by normal noise at each iteration
        net_input_saved = net_input.data.clone()

        # LANCZOS is the filter formerly aliased as ANTIALIAS (removed in Pillow 10).
        img_noisy_pil = orig_img_noisy_pil.resize(
            (orig_img_noisy_pil.size[0] // scale, orig_img_noisy_pil.size[1] // scale),
            Image.LANCZOS
        )
        # Keep the previous level's noisy image for interpolation during transition.
        if cur_level != 1:
            prev_img_noisy_np = img_noisy_np
        img_noisy_np = pil_to_np(img_noisy_pil)
        noise = net_input.data.clone()

        # Skip the transition phase at the first level.
        for phase in ["trans", "stab"] if cur_level != 1 else ["stab"]:
            if phase == "stab":
                # Flush the network before stabilisation and train on the
                # current-level noisy image only.
                net.flush()
                net = net.type(dtype)
                img_noisy_var = np_to_var(img_noisy_np).type(dtype)
            # Re-create the optimizer after grow()/flush() so it sees exactly
            # the current model parameters.
            optimizer = torch.optim.Adam(net.parameters(), lr=LR, weight_decay=weight_decay)

            n_iter = get_phase_duration(cur_level, phase)
            for j in range(n_iter):
                # Increase alpha smoothly from 0 to 1 over RAMPUP_DURATION iterations,
                # ramping the learning rate up along a half sine.
                alpha = min(j / RAMPUP_DURATION, 1.0)
                LR_rampup = np.sin((alpha + 1.5) * np.pi) / 2 + 0.5
                set_lr(optimizer, LR * LR_rampup)
                if phase == "trans":
                    net.update_alpha(alpha)
                    img_noisy_var = np_to_var(interpolate_lr(img_noisy_np, prev_img_noisy_np, alpha)).type(dtype)

                optimizer.zero_grad()
                try:
                    closure(i, j, n_iter, cur_level, phase, img_noisy_var)
                except StopIteration:
                    # Early stopping ends training entirely, not just this phase.
                    stopped_early = True
                    break
                i += 1
                optimizer.step()
            if stopped_early:
                break
        if stopped_early:
            break

    print("finished, time: {}".format(datetime.now() - start_time))
    return out_avg, psnr_history