def main(args): r"""Performs the main training loop """ # Load dataset print('> Loading dataset ...') dataset_train = Dataset(train=True, gray_mode=args.gray, shuffle=True) dataset_val = Dataset(train=False, gray_mode=args.gray, shuffle=False) loader_train = DataLoader(dataset=dataset_train, num_workers=6, \ batch_size=args.batch_size, shuffle=True) print("\t# of training samples: %d\n" % int(len(dataset_train))) # Init loggers if not os.path.exists(args.log_dir): os.makedirs(args.log_dir) writer = SummaryWriter(args.log_dir) logger = init_logger(args) # Create model if not args.gray: in_ch = 3 else: in_ch = 1 net = FFDNet(num_input_channels=in_ch) # Initialize model with He init net.apply(weights_init_kaiming) # Define loss criterion = nn.MSELoss(size_average=False) # Move to GPU device_ids = [0] model = nn.DataParallel(net, device_ids=device_ids).cuda() criterion.cuda() # Optimizer optimizer = optim.Adam(model.parameters(), lr=args.lr) # Resume training or start anew if args.resume_training: resumef = os.path.join(args.log_dir, 'ckpt.pth') if os.path.isfile(resumef): checkpoint = torch.load(resumef) print("> Resuming previous training") model.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) new_epoch = args.epochs new_milestone = args.milestone current_lr = args.lr args = checkpoint['args'] training_params = checkpoint['training_params'] start_epoch = training_params['start_epoch'] args.epochs = new_epoch args.milestone = new_milestone args.lr = current_lr print("=> loaded checkpoint '{}' (epoch {})"\ .format(resumef, start_epoch)) print("=> loaded parameters :") print("==> checkpoint['optimizer']['param_groups']") print("\t{}".format(checkpoint['optimizer']['param_groups'])) print("==> checkpoint['training_params']") for k in checkpoint['training_params']: print("\t{}, {}".format(k, checkpoint['training_params'][k])) argpri = vars(checkpoint['args']) print("==> checkpoint['args']") for k in argpri: print("\t{}, {}".format(k, argpri[k])) args.resume_training = False else: raise Exception("Couldn't resume training with checkpoint {}".\ format(resumef)) else: start_epoch = 0 training_params = {} training_params['step'] = 0 training_params['current_lr'] = 0 training_params['no_orthog'] = args.no_orthog # Training for epoch in range(start_epoch, args.epochs): # Learning rate value scheduling according to args.milestone if epoch > args.milestone[1]: current_lr = args.lr / 1000. training_params['no_orthog'] = True elif epoch > args.milestone[0]: current_lr = args.lr / 10. 
        else:
            current_lr = args.lr

        # Set learning rate in optimizer
        for param_group in optimizer.param_groups:
            param_group["lr"] = current_lr
        print('learning rate %f' % current_lr)

        # Train
        for i, data in enumerate(loader_train, 0):
            # Pre-training step
            model.train()
            model.zero_grad()
            optimizer.zero_grad()

            # Inputs: noise and noisy image
            img_train = data
            noise = torch.zeros(img_train.size())
            stdn = np.random.uniform(args.noiseIntL[0], args.noiseIntL[1],
                                     size=noise.size()[0])
            for nx in range(noise.size()[0]):
                sizen = noise[0, :, :, :].size()
                noise[nx, :, :, :] = torch.FloatTensor(sizen).normal_(mean=0, std=stdn[nx])
            imgn_train = img_train + noise

            # Move inputs to the GPU. Variable/volatile wrappers are deprecated since
            # PyTorch v0.4.0, and volatile=True would disable autograd during training.
            img_train = img_train.cuda()
            imgn_train = imgn_train.cuda()
            noise = noise.cuda()
            stdn_var = torch.cuda.FloatTensor(stdn)

            # Evaluate model and optimize it
            out_train = model(imgn_train, stdn_var)
            loss = criterion(out_train, noise) / (imgn_train.size()[0] * 2)
            loss.backward()
            optimizer.step()

            # Results
            model.eval()
            with torch.no_grad():
                out_train = torch.clamp(imgn_train - model(imgn_train, stdn_var), 0., 1.)
                psnr_train = batch_psnr(out_train, img_train, 1.)
            if training_params['step'] % args.save_every == 0:
                # Apply regularization by orthogonalizing filters
                if not training_params['no_orthog']:
                    model.apply(svd_orthogonalization)
                # Log the scalar values (loss.data[0] was replaced by loss.item() in PyTorch v0.4.0)
                writer.add_scalar('loss', loss.item(), training_params['step'])
                writer.add_scalar('PSNR on training data', psnr_train,
                                  training_params['step'])
                print("[epoch %d][%d/%d] loss: %.4f PSNR_train: %.4f" %
                      (epoch + 1, i + 1, len(loader_train), loss.item(), psnr_train))
            training_params['step'] += 1

        # The end of each epoch
        model.eval()

        # Validation
        psnr_val = 0
        with torch.no_grad():
            for valimg in dataset_val:
                img_val = torch.unsqueeze(valimg, 0)
                noise = torch.FloatTensor(img_val.size()).normal_(mean=0, std=args.val_noiseL)
                imgn_val = img_val + noise
                img_val, imgn_val = img_val.cuda(), imgn_val.cuda()
                sigma_noise = torch.cuda.FloatTensor([args.val_noiseL])
                out_val = torch.clamp(imgn_val - model(imgn_val, sigma_noise), 0., 1.)
                psnr_val += batch_psnr(out_val, img_val, 1.)
        psnr_val /= len(dataset_val)
        print("\n[epoch %d] PSNR_val: %.4f" % (epoch + 1, psnr_val))
        writer.add_scalar('PSNR on validation data', psnr_val, epoch)
        writer.add_scalar('Learning rate', current_lr, epoch)

        # Log val images
        try:
            if epoch == 0:
                # Log graph of the model
                writer.add_graph(model, (imgn_val, sigma_noise))
                # Log validation images
                for idx in range(2):
                    imclean = utils.make_grid(img_val.data[idx].clamp(0., 1.),
                                              nrow=2, normalize=False, scale_each=False)
                    imnsy = utils.make_grid(imgn_val.data[idx].clamp(0., 1.),
                                            nrow=2, normalize=False, scale_each=False)
                    writer.add_image('Clean validation image {}'.format(idx), imclean, epoch)
                    writer.add_image('Noisy validation image {}'.format(idx), imnsy, epoch)
            for idx in range(2):
                imrecons = utils.make_grid(out_val.data[idx].clamp(0., 1.),
                                           nrow=2, normalize=False, scale_each=False)
                writer.add_image('Reconstructed validation image {}'.format(idx),
                                 imrecons, epoch)
            # Log training images
            imclean = utils.make_grid(img_train.data, nrow=8, normalize=True,
                                      scale_each=True)
            writer.add_image('Training patches', imclean, epoch)
        except Exception as e:
            logger.error("Couldn't log results: {}".format(e))

        # Save model and checkpoint
        training_params['start_epoch'] = epoch + 1
        torch.save(model.state_dict(), os.path.join(args.log_dir, 'net.pth'))
        save_dict = {
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'training_params': training_params,
            'args': args
        }
        torch.save(save_dict, os.path.join(args.log_dir, 'ckpt.pth'))
        if epoch % args.save_every_epochs == 0:
            torch.save(save_dict,
                       os.path.join(args.log_dir, 'ckpt_e{}.pth'.format(epoch + 1)))
        del save_dict
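
# A minimal entry-point sketch showing how main() above could be driven from the
# command line. The option names mirror the attributes that main() reads from `args`;
# the default values and the /255 normalization of the noise levels are illustrative
# assumptions (the noise std values above are applied to images in [0, 1]), not the
# original training configuration.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="FFDNet training (sketch)")
    parser.add_argument("--gray", action="store_true", help="train the grayscale model")
    parser.add_argument("--log_dir", type=str, default="logs", help="directory for logs and checkpoints")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--epochs", type=int, default=80)
    parser.add_argument("--milestone", nargs=2, type=int, default=[50, 60],
                        help="epochs after which the learning rate is divided by 10 and by 1000")
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--no_orthog", action="store_true", help="disable filter orthogonalization")
    parser.add_argument("--save_every", type=int, default=10, help="log scalars every N training steps")
    parser.add_argument("--save_every_epochs", type=int, default=5)
    parser.add_argument("--noiseIntL", nargs=2, type=float, default=[0, 75],
                        help="training noise-level interval, on a [0, 255] scale")
    parser.add_argument("--val_noiseL", type=float, default=25,
                        help="validation noise level, on a [0, 255] scale")
    parser.add_argument("--resume_training", action="store_true")
    argspar = parser.parse_args()

    # Noise levels are given on a [0, 255] scale on the command line but applied
    # to normalized images inside main(), hence the division by 255 (an assumption).
    argspar.val_noiseL /= 255.
    argspar.noiseIntL[0] /= 255.
    argspar.noiseIntL[1] /= 255.
    main(argspar)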
def test_ffdnet(**args): r"""Denoises an input image with FFDNet """ # Init logger logger = init_logger_ipol() # Check if input exists and if it is RGB try: rgb_den = is_rgb(args['input']) except: raise Exception('Could not open the input image') # Open image as a CxHxW torch.Tensor if rgb_den: in_ch = 3 model_fn = 'models/net_rgb.pth' imorig = cv2.imread(args['input']) # from HxWxC to CxHxW, RGB image imorig = (cv2.cvtColor(imorig, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1) else: # from HxWxC to CxHxW grayscale image (C=1) in_ch = 1 model_fn = 'logs/net.pth' imorig = cv2.imread(args['input'], cv2.IMREAD_GRAYSCALE) imorig_copy = imorig.copy() imorig = np.expand_dims(imorig, 0) imorig = np.expand_dims(imorig, 0) # Handle odd sizes expanded_h = False expanded_w = False sh_im = imorig.shape if sh_im[2] % 2 == 1: expanded_h = True imorig = np.concatenate((imorig, \ imorig[:, :, -1, :][:, :, np.newaxis, :]), axis=2) if sh_im[3] % 2 == 1: expanded_w = True imorig = np.concatenate((imorig, \ imorig[:, :, :, -1][:, :, :, np.newaxis]), axis=3) imorig = normalize(imorig) imorig = torch.Tensor(imorig) # Absolute path to model file model_fn = os.path.join(os.path.abspath(os.path.dirname(__file__)), \ model_fn) # Create model print('Loading model ...\n') net = FFDNet(num_input_channels=in_ch) # Load saved weights if args['cuda']: state_dict = torch.load(model_fn) device_ids = [0] model = nn.DataParallel(net, device_ids=device_ids).cuda() else: state_dict = torch.load(model_fn, map_location='cpu') # CPU mode: remove the DataParallel wrapper state_dict = remove_dataparallel_wrapper(state_dict) model = net model.load_state_dict(state_dict) # Sets the model in evaluation mode (e.g. it removes BN) model.eval() # Sets data type according to CPU or GPU modes if args['cuda']: dtype = torch.cuda.FloatTensor else: dtype = torch.FloatTensor # Add noise if args['add_noise']: noise = torch.FloatTensor(imorig.size()).\ normal_(mean=0, std=args['noise_sigma']) imnoisy = imorig + noise else: imnoisy = imorig.clone() # Test mode with torch.no_grad(): # PyTorch v0.4.0 imorig, imnoisy = Variable(imorig.type(dtype)), \ Variable(imnoisy.type(dtype)) nsigma = Variable(torch.FloatTensor([args['noise_sigma']]).type(dtype)) # Measure runtime start_t = time.time() # Estimate noise and subtract it to the input image im_noise_estim = model(imnoisy, nsigma) outim = torch.clamp(imnoisy - im_noise_estim, 0., 1.) stop_t = time.time() if expanded_h: imorig = imorig[:, :, :-1, :] outim = outim[:, :, :-1, :] imnoisy = imnoisy[:, :, :-1, :] if expanded_w: imorig = imorig[:, :, :, :-1] outim = outim[:, :, :, :-1] imnoisy = imnoisy[:, :, :, :-1] # Compute PSNR and log it if rgb_den: print("### RGB denoising ###") else: print("### Grayscale denoising ###") if args['add_noise']: psnr = batch_psnr(outim, imorig, 1.) psnr_noisy = batch_psnr(imnoisy, imorig, 1.) 
print("----------PSNR noisy {0:0.2f}dB".format(psnr_noisy)) print("----------PSNR denoised {0:0.2f}dB".format(psnr)) else: logger.info("\tNo noise was added, cannot compute PSNR") print("----------Runtime {0:0.4f}s".format(stop_t - start_t)) # Compute difference diffout = 2 * (outim - imorig) + .5 diffnoise = 2 * (imnoisy - imorig) + .5 # Save images if not args['dont_save_results']: noisyimg = variable_to_cv2_image(imnoisy) outimg = variable_to_cv2_image(outim) cv2.imwrite( "bfffd/noisy-" + str(int(args['noise_sigma'] * 255)) + '-' + args['input'], noisyimg) cv2.imwrite( "bfffd/ffdnet-" + str(int(args['noise_sigma'] * 255)) + '-' + args['input'], outimg) if args['add_noise']: cv2.imwrite("noisy_diff.png", variable_to_cv2_image(diffnoise)) cv2.imwrite("ffdnet_diff.png", variable_to_cv2_image(diffout)) (score, diff) = compare_ssim(noisyimg, imorig_copy, full=True) (score2, diff) = compare_ssim(outimg, imorig_copy, full=True) print("----------Noisy ssim: {0:0.4f}".format(score)) print("----------Denoisy ssim: {0:0.4f}".format(score2))
# cuda = True
cuda = False
# Only the grayscale model is provided here; the RGB model performs even better.
rgb_den = False
B = 2048.
# B = 256.
########
in_ch = 1
# This model was trained on the Waterloo Exploration Database, sigma in [0, 20],
# for 50 epochs (see Fig. 2 of the paper). Training took about 12 hours on a GPU
# server (Nvidia RTX 2080 Ti with 11 GB of graphics memory).
model_fn = 'mat/net-gray-v2.pth'

# Absolute path to model file
model_fn = os.path.join(os.path.abspath(os.path.dirname(__file__)), model_fn)

# Create model
print('Loading model ...\n')
net = FFDNet(num_input_channels=in_ch)

# Load saved weights
if cuda:
    state_dict = torch.load(model_fn)
    device_ids = [0]
    model = nn.DataParallel(net, device_ids=device_ids).cuda()
    # Set data type according to GPU mode
    dtype = torch.cuda.FloatTensor
else:
    state_dict = torch.load(model_fn, map_location='cpu')
    # CPU mode: remove the DataParallel wrapper
    state_dict = remove_dataparallel_wrapper(state_dict)
    model = net
    dtype = torch.FloatTensor
model.load_state_dict(state_dict)
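
# A short follow-up sketch showing how the grayscale model loaded above might be
# applied to a noisy image. `imnoisy` is a placeholder 1x1xHxW tensor with values
# in [0, 1] (an assumption standing in for a real noisy patch); as in the training
# code, the network predicts the noise, which is subtracted from its input.
model.eval()
with torch.no_grad():
    imnoisy = torch.rand(1, 1, 64, 64).type(dtype)        # placeholder noisy patch
    sigma = torch.FloatTensor([15 / 255.]).type(dtype)    # assumed noise level within [0, 20]/255
    denoised = torch.clamp(imnoisy - model(imnoisy, sigma), 0., 1.)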
def test_ffdnet(**args): r"""Denoises an input image with FFDNet """ # Init logger logger = init_logger_ipol() # Check if input exists and if it is RGB try: rgb_den = is_rgb(args['input']) except: raise Exception('Could not open the input image') # Measure runtime start_t = time.time() # Open image as a CxHxW torch.Tensor if rgb_den: in_ch = 3 model_fn = 'net_rgb.pth' imorig = Image.open(args['input']) imorig = np.array(imorig, dtype=np.float32).transpose(2, 0, 1) else: # from HxWxC to CxHxW grayscale image (C=1) in_ch = 1 model_fn = 'models/net_gray.pth' # imorig = cv2.imread(args['input'], cv2.IMREAD_GRAYSCALE) imorig = np.expand_dims(imorig, 0) imorig = np.expand_dims(imorig, 0) # Handle odd sizes expanded_h = False expanded_w = False sh_im = imorig.shape if sh_im[2] % 2 == 1: expanded_h = True imorig = np.concatenate((imorig, \ imorig[:, :, -1, :][:, :, np.newaxis, :]), axis=2) if sh_im[3] % 2 == 1: expanded_w = True imorig = np.concatenate((imorig, \ imorig[:, :, :, -1][:, :, :, np.newaxis]), axis=3) imorig = normalize(imorig) imorig = torch.Tensor(imorig) # Absolute path to model file model_fn = os.path.join(os.path.abspath(os.path.dirname(__file__)), \ model_fn) # Create model print('Loading model ...\n') net = FFDNet(num_input_channels=in_ch) # Load saved weights if args['cuda']: state_dict = torch.load(model_fn) #device_ids = [0,1,2,3] #model = nn.DataParallel(net, device_ids=device_ids).cuda() #state_dict = remove_dataparallel_wrapper(state_dict) model = net else: state_dict = torch.load(model_fn, map_location='cpu') # CPU mode: remove the DataParallel wrapper state_dict = remove_dataparallel_wrapper(state_dict) model = net model.load_state_dict(state_dict) # Sets the model in evaluation mode (e.g. it removes BN) model.eval() # Sets data type according to CPU or GPU modes if args['cuda']: dtype = torch.cuda.FloatTensor else: dtype = torch.FloatTensor # Test mode with torch.no_grad(): # PyTorch v0.4.0 imorig = Variable(imorig.type(dtype)) nsigma = Variable(torch.FloatTensor([args['noise_sigma']]).type(dtype)) # # Measure runtime # start_t = time.time() # Estimate noise and subtract it to the input image im_noise_estim = model(imorig, nsigma) stop_t = time.time() # log time if rgb_den: print("### RGB denoising ###") else: print("### Grayscale denoising ###") print("\tRuntime {0:0.4f}s".format(stop_t - start_t)) # Save noises noise = variable_to_numpy(imorig.to(3) - im_noise_estim).transpose(1, 2, 0) filename = args['input'].split('/')[-1].split('.')[0] if args['save_path']: sio.savemat( './output_noise/' + args['save_path'] + '/' + filename + '.mat', {'Noisex': noise}) else: sio.savemat('./output_noise/' + filename + '.mat', {'Noisex': noise})
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = "0" r"""Denoises an input image with FFDNet """ # Init logger logger = init_logger_ipol() # Check if input exists and if it is RGB # Absolute path to model file model_fn = os.path.join(os.path.abspath(os.path.dirname(__file__)), \ 'models/net_rgb.pth') # Create model print('Loading model ...\n') net = FFDNet(num_input_channels=3, test_mode=True) model_fn = os.path.join(os.path.abspath(os.path.dirname(__file__)), \ model_fn) # Load saved weights print(model_fn) state_dict = torch.load(model_fn) device_ids = [0] model = nn.DataParallel(net, device_ids=device_ids).cuda() model.load_state_dict(state_dict) # Sets the model in evaluation mode (e.g. it removes BN) model.eval() # Sets data type according to CPU or GPU modes dtype = torch.cuda.FloatTensor