def define_G(opt):
    """Instantiate the generator network named by ``opt['netG']['net_type']``.

    The architecture module is imported lazily inside the matching branch, so
    only the selected network's module is loaded.

    Args:
        opt: option dict holding ``'netG'`` (network sub-options, including
            ``'net_type'``) and ``'is_train'``.

    Returns:
        The constructed generator. When ``opt['is_train']`` is truthy, its
        weights are initialised through ``init_weights`` using the
        ``init_type`` / ``init_bn_type`` / ``init_gain`` sub-options.

    Raises:
        NotImplementedError: if ``net_type`` is not a supported architecture.
    """
    net_opt = opt['netG']
    name = net_opt['net_type']

    # -- choose the architecture class: denoisers first, then super-resolvers --
    if name == 'dncnn':
        from models.network_dncnn import DnCNN as arch
    elif name == 'fdncnn':  # flexible DnCNN
        from models.network_dncnn import FDnCNN as arch
    elif name == 'ffdnet':
        from models.network_ffdnet import FFDNet as arch
    elif name == 'srmd':
        from models.network_srmd import SRMD as arch
    elif name == 'dpsr':  # super-resolver prior of DPSR
        from models.network_dpsr import MSRResNet_prior as arch
    elif name == 'msrresnet0':  # modified SRResNet v0.0
        from models.network_msrresnet import MSRResNet0 as arch
    elif name == 'msrresnet1':  # modified SRResNet v0.1
        from models.network_msrresnet import MSRResNet1 as arch
    elif name == 'rrdb':
        from models.network_rrdb import RRDB as arch
    elif name == 'imdn':
        from models.network_imdn import IMDN as arch
    elif name == 'usrnet':
        from models.network_usrnet import USRNet as arch
    else:
        raise NotImplementedError('netG [{:s}] is not found.'.format(name))

    # -- collect constructor keyword arguments (read in the same order as before) --
    if name == 'usrnet':
        kwargs = {
            'n_iter': net_opt['n_iter'],
            'h_nc': net_opt['h_nc'],
            'in_nc': net_opt['in_nc'],
            'out_nc': net_opt['out_nc'],
            'nc': net_opt['nc'],
            'nb': net_opt['nb'],
            'act_mode': net_opt['act_mode'],
            'downsample_mode': net_opt['downsample_mode'],
            'upsample_mode': net_opt['upsample_mode'],
        }
    else:
        # Shared by every other architecture; for DnCNN/FDnCNN ``nb`` is the
        # total number of conv layers.
        kwargs = {key: net_opt[key] for key in ('in_nc', 'out_nc', 'nc', 'nb')}
        if name in ('dncnn', 'fdncnn', 'ffdnet'):
            kwargs['act_mode'] = net_opt['act_mode']
        else:
            if name == 'rrdb':
                kwargs['gc'] = net_opt['gc']
            kwargs['upscale'] = net_opt['scale']
            kwargs['act_mode'] = net_opt['act_mode']
            kwargs['upsample_mode'] = net_opt['upsample_mode']
    netG = arch(**kwargs)

    # -- weight initialisation is only applied for training runs --
    if opt['is_train']:
        init_weights(netG,
                     init_type=net_opt['init_type'],
                     init_bn_type=net_opt['init_bn_type'],
                     gain=net_opt['init_gain'])
    return netG
def main():
    """Evaluate a pretrained FDnCNN checkpoint on a test set with synthetic noise.

    Clean images are read from ``testsets/<testset_name>``. When
    ``need_degradation`` is set, Gaussian noise with std
    ``noise_level_img / 255`` is added (seeded with 0 before every image).
    The images are then denoised by ``model_zoo/<model_name>.pth``, which is
    conditioned on a constant noise-level map of ``noise_level_model / 255``.
    Results are written to ``results/<testset_name>_<model_name>``, and
    per-image plus average PSNR/SSIM are logged.
    """
    # ----------------------------------------
    # Preparation
    # ----------------------------------------
    noise_level_img = 15  # noise level for noisy image
    noise_level_model = noise_level_img  # noise level for model
    model_name = 'fdncnn_gray'  # 'fdncnn_gray' | 'fdncnn_color' | 'fdncnn_color_clip' | 'fdncnn_gray_clip'
    testset_name = 'bsd68'  # test set, 'bsd68' | 'cbsd68' | 'set12'
    need_degradation = True  # default: True
    x8 = False  # default: False, x8 to boost performance
    show_img = False  # default: False
    task_current = 'dn'  # 'dn' for denoising | 'sr' for super-resolution
    sf = 1  # unused for denoising

    n_channels = 3 if 'color' in model_name else 1  # 3 for color, 1 for grayscale
    # For '*_clip' models the noisy input is round-tripped through uint8,
    # which clips the intensities into [0, 1].
    use_clip = 'clip' in model_name

    model_pool = 'model_zoo'  # fixed
    testsets = 'testsets'  # fixed
    results = 'results'  # fixed
    result_name = testset_name + '_' + model_name
    border = sf if task_current == 'sr' else 0  # shave border to calculate PSNR and SSIM
    model_path = os.path.join(model_pool, model_name + '.pth')

    # ----------------------------------------
    # L_path, E_path, H_path
    # ----------------------------------------
    L_path = os.path.join(testsets, testset_name)  # L_path, for Low-quality images
    H_path = L_path  # H_path, for High-quality images
    E_path = os.path.join(results, result_name)  # E_path, for Estimated images
    util.mkdir(E_path)

    if H_path == L_path:
        # No separate noisy set exists, so the noisy input must be synthesised.
        need_degradation = True
    logger_name = result_name
    utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    need_H = H_path is not None
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------
    from models.network_dncnn import FDnCNN as net
    # FDnCNN input = image channels + one noise-level-map channel.
    model = net(in_nc=n_channels + 1, out_nc=n_channels, nc=64, nb=20, act_mode='R')
    # map_location lets a GPU-saved checkpoint load when falling back to CPU.
    model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))
    number_parameters = sum(p.numel() for p in model.parameters())
    logger.info('Params number: {}'.format(number_parameters))

    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []

    # Arguments match the labels: model sigma first, then image sigma.
    logger.info('model_name:{}, model sigma:{}, image sigma:{}'.format(
        model_name, noise_level_model, noise_level_img))
    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)
    H_paths = util.get_image_paths(H_path) if need_H else None

    for idx, img in enumerate(L_paths):
        # ------------------------------------
        # (1) img_L
        # ------------------------------------
        img_name, ext = os.path.splitext(os.path.basename(img))
        img_L = util.imread_uint(img, n_channels=n_channels)
        img_L = util.uint2single(img_L)

        if need_degradation:  # degradation process
            np.random.seed(seed=0)  # for reproducibility
            img_L += np.random.normal(0, noise_level_img / 255., img_L.shape)
            if use_clip:
                img_L = util.uint2single(util.single2uint(img_L))

        if show_img:
            util.imshow(util.single2uint(img_L),
                        title='Noisy image with noise level {}'.format(noise_level_img))

        img_L = util.single2tensor4(img_L)
        noise_level_map = torch.ones(
            (1, 1, img_L.size(2), img_L.size(3)),
            dtype=torch.float).mul_(noise_level_model / 255.)
        img_L = torch.cat((img_L, noise_level_map), dim=1)
        img_L = img_L.to(device)

        # ------------------------------------
        # (2) img_E
        # ------------------------------------
        with torch.no_grad():  # inference only: no autograd bookkeeping
            if not x8:
                img_E = model(img_L)
            else:
                img_E = utils_model.test_mode(model, img_L, mode=3)
        img_E = util.tensor2uint(img_E)

        if need_H:
            # --------------------------------
            # (3) img_H
            # --------------------------------
            img_H = util.imread_uint(H_paths[idx], n_channels=n_channels)
            img_H = img_H.squeeze()

            # --------------------------------
            # PSNR and SSIM
            # --------------------------------
            psnr = util.calculate_psnr(img_E, img_H, border=border)
            ssim = util.calculate_ssim(img_E, img_H, border=border)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(
                img_name + ext, psnr, ssim))
            if show_img:
                util.imshow(np.concatenate([img_E, img_H], axis=1),
                            title='Recovered / Ground-truth')

        # ------------------------------------
        # save results
        # ------------------------------------
        util.imsave(img_E, os.path.join(E_path, img_name + ext))

    if need_H and test_results['psnr']:
        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        logger.info(
            'Average PSNR/SSIM(RGB) - {} - PSNR: {:.2f} dB; SSIM: {:.4f}'.
            format(result_name, ave_psnr, ave_ssim))
    elif need_H:
        # Guard against ZeroDivisionError when the test set is empty.
        logger.info('No test images found in {}'.format(L_path))
def _save_admm_snapshots(step, i0, Y, mean, out, clean_im, result_root):
    """Write the per-iteration diagnostic images of ``denoising`` into ``result_root``."""
    i0_np = torch_to_np(i0)
    Y_np = torch_to_np(Y)
    mean_np = torch_to_np(mean)
    out_np = torch_to_np(out)

    def _path(prefix):
        return os.path.join(result_root, '{}{:04d}.png'.format(prefix, step))

    # i0 + Y is the point handed to the DnCNN prior in sub-problem 2.
    np_to_pil((i0_np + Y_np).clip(0, 1)).save(_path('denoise_obj_'))
    # L2 norm of the dual variable taken over axis 0 (the channel axis for CHW arrays).
    save_heatmap(np.sqrt((Y_np * Y_np).sum(0)), _path('Y_'))
    np_to_pil(i0_np).save(_path('i0_num_epoch_'))
    np_to_pil(mean_np).save(_path('Latent_im_num_epoch_'))
    np_to_pil(out_np).save(_path('res_of_dec_num_epoch_'))
    save_hist(mean_np - clean_im, _path('Latent_dis_num_epoch_'))


def denoising(noise_im, clean_im, LR=1e-2, sigma=5, rho=1, eta=0.5,
              total_step=20, prob1_iter=500, result_root=None, f=None,
              model_path='/home/dihan/KAIR/model_zoo/dncnn_color_blind.pth'):
    """Denoise ``noise_im`` with an ADMM loop coupling a VAE and a frozen DnCNN prior.

    Each outer step runs three sub-problems:

    1. Train the encoder/decoder for ``prob1_iter`` Adam steps. After each
       step, set ``i0`` to a weighted average of the encoder mean and
       ``i0_til - Y``.
    2. Set ``i0_til`` to the DnCNN output on ``i0 + Y``.
    3. Update the dual variable: ``Y += eta * (i0 - i0_til)``.

    The loop stops at the first outer step whose PSNR does not improve.

    Args:
        noise_im, clean_im: numpy images. The ``transpose(1, 2, 0)`` calls
            imply channel-first (C, H, W) layout, and the networks are built
            for 3 channels. Values are presumably in [0, 1], since PSNR/SSIM
            use ``data_range=1``.
        LR: Adam learning rate for the encoder/decoder.
        sigma: std used in the KL term and in the ``i0`` update.
        rho: ADMM penalty weight.
        eta: dual step size.
        total_step: maximum number of outer ADMM iterations.
        prob1_iter: Adam steps per outer iteration. Must be >= 1, because the
            last loss and encoder outputs are used afterwards.
        result_root: directory receiving the diagnostic images.
        f: text stream for progress messages (``None`` -> stdout).
        model_path: DnCNN checkpoint used as the denoising prior.

    Returns:
        ``(best_im, best_psnr, best_ssim)``: the clipped DnCNN estimate with
        the highest PSNR seen, together with its PSNR and SSIM. ``best_im``
        is ``None`` only if no step improved on a PSNR of 0 (e.g.
        ``total_step == 0``).
    """
    input_depth = 3
    latent_dim = 3
    en_net = Encoder(input_depth, latent_dim, down_sample_norm='batchnorm',
                     up_sample_norm='batchnorm').cuda()
    de_net = Decoder(latent_dim, input_depth, down_sample_norm='batchnorm',
                     up_sample_norm='batchnorm').cuda()

    # Frozen, pretrained DnCNN used as the plug-and-play prior.
    model = net(3, 3, nc=64, nb=20, act_mode='R')
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.cuda()

    noise_im_torch = np_to_torch(noise_im).cuda()

    # Baseline: DnCNN applied directly to the noisy image.
    with torch.no_grad():
        r_dncnn_np = torch_to_np(model(noise_im_torch))
    psnr_dncnn = compare_psnr(clean_im.transpose(1, 2, 0),
                              r_dncnn_np.transpose(1, 2, 0), 1)
    ssim_dncnn = compare_ssim(r_dncnn_np.transpose(1, 2, 0),
                              clean_im.transpose(1, 2, 0),
                              multichannel=True, data_range=1)
    print('PSNR_DNCNN: {}, SSIM_DNCNN: {}'.format(psnr_dncnn, ssim_dncnn),
          file=f, flush=True)

    parameters = list(en_net.parameters()) + list(de_net.parameters())
    optimizer = torch.optim.Adam(parameters, lr=LR)
    l2_loss = torch.nn.MSELoss(reduction='sum').cuda()

    # ADMM state: primal i0 and auxiliary i0_til both start at the noisy
    # image; the dual variable Y starts at zero.
    i0 = noise_im_torch.clone()
    Y = torch.zeros_like(noise_im_torch)
    i0_til_torch = noise_im_torch.clone()

    diff_original_np = noise_im.astype(np.float32) - clean_im.astype(np.float32)
    save_hist(diff_original_np, os.path.join(result_root, 'Original_dis.png'))

    best_psnr = 0
    best_ssim = 0
    best_im = None
    for i in range(total_step):
        # ---------------- sub-problem 1: VAE fit + closed-form i0 ----------------
        for _ in range(prob1_iter):
            optimizer.zero_grad()
            mean, log_var = en_net(noise_im_torch)
            z = sample_z(mean, log_var)
            out = de_net(z)
            total_loss = 0.5 * l2_loss(out, noise_im_torch)
            total_loss += kl_loss(mean, log_var, i0, sigma)
            # NOTE(review): i0, Y and i0_til_torch are only ever produced from
            # numpy or under no_grad, so this term carries no gradient; it
            # only shifts the reported loss value.
            total_loss += (rho / 2) * l2_loss(i0 + Y, i0_til_torch)
            total_loss.backward()
            optimizer.step()
            with torch.no_grad():
                i0 = ((1 / sigma**2) * mean + rho * (i0_til_torch - Y)) / ((1 / sigma**2) + rho)

        with torch.no_grad():
            # ---------------- sub-problem 2: DnCNN prior on i0 + Y ----------------
            i0_til_torch = model(i0 + Y)
            # ---------------- sub-problem 3: dual update ----------------
            Y = Y + eta * (i0 - i0_til_torch)

        _save_admm_snapshots(i, i0, Y, mean, out, clean_im, result_root)

        i0_til_np = torch_to_np(i0_til_torch).clip(0, 1)
        psnr = compare_psnr(clean_im.transpose(1, 2, 0),
                            i0_til_np.transpose(1, 2, 0), 1)
        ssim = compare_ssim(clean_im.transpose(1, 2, 0),
                            i0_til_np.transpose(1, 2, 0),
                            multichannel=True, data_range=1)
        np_to_pil(i0_til_np).save(os.path.join(result_root, '{}.png'.format(i)))
        print('Iteration: %02d, VAE Loss: %f, PSNR: %f, SSIM: %f' %
              (i, total_loss.item(), psnr, ssim), file=f, flush=True)

        # Early stopping: continue only while PSNR improves. Keep the best
        # estimate so the returned image matches the returned metrics.
        if best_psnr < psnr:
            best_psnr = psnr
            best_ssim = ssim
            best_im = i0_til_np
        else:
            break

    return best_im, best_psnr, best_ssim
def main():
    """Run the blind DnCNN-3 checkpoint over every image in ``testsets/<testset_name>``.

    The model has one input and one output channel. With ``n_channels == 3``
    the image is converted to YCbCr, only the Y channel is restored, and the
    result is converted back to RGB. Outputs are saved as PNG files to
    ``results/<testset_name>_<model_name>``.
    """
    # ----------------------------------------
    # Preparation
    # ----------------------------------------
    # 'dncnn3' - can be used for blind Gaussian denoising, JPEG deblocking
    # (quality factor 5-100) and super-resolution (x234)
    model_name = 'dncnn3'

    # important!
    testset_name = 'bsd68'  # test set, low-quality grayscale/color JPEG images
    n_channels = 1  # set 1 for grayscale image, set 3 for color image
    x8 = False  # default: False, x8 to boost performance

    testsets = 'testsets'  # fixed
    results = 'results'  # fixed
    result_name = testset_name + '_' + model_name  # fixed
    L_path = os.path.join(testsets, testset_name)  # L_path, for Low-quality grayscale/Y-channel JPEG images
    E_path = os.path.join(results, result_name)  # E_path, for Estimated images
    util.mkdir(E_path)

    model_pool = 'model_zoo'  # fixed
    model_path = os.path.join(model_pool, model_name + '.pth')
    logger_name = result_name
    utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------
    from models.network_dncnn import DnCNN as net
    model = net(in_nc=1, out_nc=1, nc=64, nb=20, act_mode='R')
    # map_location lets a GPU-saved checkpoint load when falling back to CPU.
    model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))
    number_parameters = sum(p.numel() for p in model.parameters())
    logger.info('Params number: {}'.format(number_parameters))

    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)

    for idx, img in enumerate(L_paths):
        # ------------------------------------
        # (1) img_L
        # ------------------------------------
        img_name, ext = os.path.splitext(os.path.basename(img))
        logger.info('{:->4d}--> {:>10s}'.format(idx + 1, img_name + ext))
        img_L = util.imread_uint(img, n_channels=n_channels)
        img_L = util.uint2single(img_L)
        if n_channels == 3:
            # Keep the full YCbCr image so the restored Y can be recombined.
            ycbcr = util.rgb2ycbcr(img_L, False)
            img_L = ycbcr[..., 0:1]
        img_L = util.single2tensor4(img_L)
        img_L = img_L.to(device)

        # ------------------------------------
        # (2) img_E
        # ------------------------------------
        with torch.no_grad():  # inference only: no autograd bookkeeping
            if not x8:
                img_E = model(img_L)
            else:
                img_E = utils_model.test_mode(model, img_L, mode=3)
        img_E = util.tensor2single(img_E)
        if n_channels == 3:
            ycbcr[..., 0] = img_E
            img_E = util.ycbcr2rgb(ycbcr)
        img_E = util.single2uint(img_E)

        # ------------------------------------
        # save results
        # ------------------------------------
        util.imsave(img_E, os.path.join(E_path, img_name + '.png'))
def main(args):
    """Denoise one image, or every regular file in a directory, with a pretrained DnCNN.

    Args:
        args: mapping of command-line options. Keys read:

            - ``'model'``: checkpoint name under ``model_pool``; ``None``
              selects ``'dncnn_50'``.
            - ``'color'``: ``'rgb'`` forces ``'dncnn_color_blind'``.
            - ``'input'``: path of a single image.
            - ``'batch'``: directory whose files are all processed.

            Exactly one of ``'input'`` / ``'batch'`` must be set.

    Exits with status 1 when the checkpoint or the batch directory is missing.
    """
    import sys

    # 'dncnn_25' | 'dncnn_50' | 'dncnn_gray_blind' | 'dncnn_color_blind' | 'dncnn3'
    model_name = 'dncnn_50' if args['model'] is None else args['model']
    # RGB mode always uses the colour blind model.
    if args['color'] == 'rgb':
        model_name = 'dncnn_color_blind'

    # os.path.join never raises, so the checkpoint has to be checked explicitly.
    model_path = os.path.join(model_pool, model_name + '.pth')
    if not os.path.isfile(model_path):
        print('Model not found')
        sys.exit(1)
    print("Using model %s" % (model_name))
    print("---------------------------------")

    # Disabled now - found to reduce quality of output
    x8 = False  # default: False, x8 to boost performance

    # Colour models work on 3 channels, all others on grayscale.
    n_channels = 3 if 'color' in model_name else 1
    # The blind and dncnn3 checkpoints have 20 conv layers; the others have 17.
    nb = 20 if model_name in ('dncnn_gray_blind', 'dncnn_color_blind', 'dncnn3') else 17

    # Load model
    model = net(in_nc=n_channels, out_nc=n_channels, nc=64, nb=nb, act_mode='R')
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)

    single_image = args['input']
    batch_dir = args['batch']
    if single_image is not None and batch_dir is None:
        predict(single_image, n_channels, model, x8)
    elif single_image is None and batch_dir is not None:
        if not os.path.exists(batch_dir):
            print("Path does not exist")
            sys.exit(1)
        files = sorted(
            name for name in listdir(batch_dir)
            if os.path.isfile(os.path.join(batch_dir, name))
        )
        print("Found %d files" % (len(files)))
        for item in files:
            # Best-effort batch: report a failing file and keep going, but do
            # not swallow KeyboardInterrupt/SystemExit.
            try:
                predict(os.path.join(batch_dir, item), n_channels, model, x8)
            except Exception as err:
                print("Error with %s: %s" % (item, err))
    else:
        print('Specify exactly one of input or batch')