def _denoise_tiled(model, image, device, depth):
    """Denoise ``image`` by recursively tiling it ``depth`` times.

    At every level the input is cut with ``split_tensor``, each tile gets a
    leading axis via ``unsqueeze(0)``, and the processed tiles are stitched
    back together with ``merge_tensor``. At depth 0 the tile is moved to
    ``device``, run through ``model`` and returned on the CPU with the
    leading axis squeezed away.
    """
    if depth == 0:
        tile = image.to(device)
        print(tile.size())
        return model(tile).detach().cpu().squeeze(0)
    return merge_tensor([
        _denoise_tiled(model, tile.unsqueeze(0), device, depth - 1)
        for tile in split_tensor(image)
    ])


def test(args):
    """Denoise the Cityscapes validation images tile by tile and save PNGs.

    Reads ``args.checkpoint`` (checkpoint directory), ``args.noise_dir``
    (directory holding the noisy ``*_leftImg8bit`` images) and
    ``args.save_img`` (output directory, created if missing).
    """
    model = MIRNet()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # A torch.device never compares equal to a str, so test ``device.type``.
    # NOTE(review): the second argument is presumably a "use CUDA" flag --
    # confirm against load_checkpoint.
    checkpoint = load_checkpoint(args.checkpoint, device.type == 'cuda',
                                 'latest')
    model.load_state_dict(checkpoint['state_dict'])
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(
        checkpoint['epoch'], checkpoint['global_iter']))
    model.eval()
    model = model.to(device)
    to_pil = transforms.ToPILImage()
    torch.manual_seed(0)

    # The label files only provide the list of image names to process.
    test_img = glob.glob(
        "/vinai/tampm2/cityscapes_noise/gtFine/val/*/*_gtFine_color.png")
    if not os.path.exists(args.save_img):
        os.makedirs(args.save_img)

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for label_path in test_img:
            img_path = os.path.join(
                args.noise_dir,
                os.path.basename(label_path).replace("_gtFine_color",
                                                     "_leftImg8bit"))
            print(img_path)
            image_noise = load_data(img_path)
            # Three nested levels of split -> denoise -> merge.
            pred = to_pil(_denoise_tiled(model, image_noise, device, 3))
            name_img = os.path.basename(img_path).split(".")[0]
            pred.save(args.save_img + "/" + name_img + ".png")
def test(args):
    """Denoise every block of the sRGB benchmark ``.mat`` file.

    ``args.checkpoint`` is the checkpoint directory and ``args.noise_dir``
    the path of a ``.mat`` file containing ``BenchmarkNoisyBlocksSrgb``.

    Returns:
        An array with the same shape and dtype as the noisy input in which
        each block is replaced by the model prediction.
    """
    model = MIRNet()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # A torch.device never compares equal to a str, so test ``device.type``.
    checkpoint = load_checkpoint(args.checkpoint, device.type == 'cuda',
                                 'latest')
    model.load_state_dict(checkpoint['state_dict'])
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(
        checkpoint['epoch'], checkpoint['global_iter']))
    model.eval()
    model = model.to(device)
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    torch.manual_seed(0)

    all_noisy_imgs = scipy.io.loadmat(
        args.noise_dir)['BenchmarkNoisyBlocksSrgb']
    mat_re = np.zeros_like(all_noisy_imgs)
    i_imgs, i_blocks, _, _, _ = all_noisy_imgs.shape

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for i_img in range(i_imgs):
            for i_block in range(i_blocks):
                noise = to_tensor(Image.fromarray(
                    all_noisy_imgs[i_img][i_block])).unsqueeze(0)
                pred = model(noise.to(device)).detach().cpu()
                mat_re[i_img][i_block] = np.array(to_pil(pred[0]))
    return mat_re
def test(args):
    """Evaluate MIRNet_DGF on the SIDD+ validation ``.mat`` files.

    Reads ``args.data_type`` ('rgb' or 'filter'), ``args.checkpoint``,
    ``args.noise_dir`` and ``args.gt`` (paths of ``.mat`` files),
    ``args.burst_length``, ``args.model_type`` and ``args.save_img``
    (figure output directory; empty string disables saving). Prints PSNR and
    SSIM per image and their averages.

    Raises:
        ValueError: if ``args.data_type`` is not a supported value.
    """
    model = MIRNet_DGF()
    if args.data_type == 'rgb':
        load_data = load_data_split
    elif args.data_type == 'filter':
        load_data = load_data_filter
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on ``load_data``.
        raise ValueError(
            "Data type not valid: {!r}".format(args.data_type))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # A torch.device never compares equal to a str, so test ``device.type``.
    checkpoint = load_checkpoint(args.checkpoint, device.type == 'cuda',
                                 'latest')
    model.load_state_dict(checkpoint['state_dict'])
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(
        checkpoint['epoch'], checkpoint['global_iter']))
    model.eval()
    model = model.to(device)
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    torch.manual_seed(0)

    all_noisy_imgs = scipy.io.loadmat(
        args.noise_dir)['siddplus_valid_noisy_srgb']
    all_clean_imgs = scipy.io.loadmat(args.gt)['siddplus_valid_gt_srgb']
    i_imgs, _, _, _ = all_noisy_imgs.shape
    psnrs = []
    ssims = []

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for i_img in range(i_imgs):
            noise = to_tensor(Image.fromarray(all_noisy_imgs[i_img]))
            image_noise, image_noise_hr = load_data(noise, args.burst_length)
            image_noise_hr = image_noise_hr.to(device)
            burst_noise = image_noise.to(device)
            _, pred = model(burst_noise, image_noise_hr)
            pred = pred.detach().cpu()
            gt = to_tensor(Image.fromarray(all_clean_imgs[i_img])).unsqueeze(0)
            psnr_t = calculate_psnr(pred, gt)
            ssim_t = calculate_ssim(pred, gt)
            psnrs.append(psnr_t)
            ssims.append(ssim_t)
            print(i_img, "   UP   :  PSNR : ", str(psnr_t), " :  SSIM : ",
                  str(ssim_t))
            if args.save_img != '':
                if not os.path.exists(args.save_img):
                    os.makedirs(args.save_img)
                plt.figure(figsize=(15, 15))
                plt.imshow(np.array(to_pil(pred[0])))
                plt.title("denoise KPN DGF " + args.model_type, fontsize=25)
                image_name = str(i_img) + "_"
                plt.axis("off")
                plt.suptitle(image_name + "   UP   :  PSNR : " + str(psnr_t) +
                             " :  SSIM : " + str(ssim_t), fontsize=25)
                plt.savefig(os.path.join(
                    args.save_img,
                    image_name + "_" + args.checkpoint + '.png'),
                    pad_inches=0)
                # Release the figure; pyplot keeps every open figure alive.
                plt.close()
    print("   AVG   :  PSNR : " + str(np.mean(psnrs)) + " :  SSIM : " +
          str(np.mean(ssims)))
def convert_torch_to_onnx():
    """Run one forward pass of Att_KPN_DGF and export the model to ONNX.

    Loads the latest checkpoint from ``../checkpoints/kpn_att_repeat_new/``,
    denoises a hard-coded sample image (written to ``../img/denoised.jpg``)
    and exports the model to ``models/denoiser_rgb.onnx`` via ``torch2onnx``.
    """
    img_path = '/home/dell/Downloads/FullTest/noisy/2_1.png'
    output_path = 'models/denoiser_rgb.onnx'
    input_node_names = ['input_image']
    output_node_names = ['output_image']

    checkpoint = load_checkpoint("../checkpoints/kpn_att_repeat_new/", False,
                                 'latest')
    torch_model = Att_KPN_DGF(
        color=True,
        burst_length=4,
        blind_est=True,
        kernel_size=[5],
        sep_conv=False,
        channel_att=True,
        spatial_att=True,
        upMode="bilinear",
        core_bias=False
    )
    torch_model.load_state_dict(checkpoint['state_dict'])
    # Inference/export must not run in training mode (affects layers such as
    # batch norm and dropout, if the network has any).
    torch_model.eval()

    img = imageio.imread(img_path)
    img = np.asarray(img, dtype=np.float32) / 255.
    img_tensor = torch.from_numpy(img).permute(2, 0, 1)  # HWC -> CHW
    image_noise, image_noise_hr = load_data(img_tensor, 4)
    b, N, c, h, w = image_noise.size()
    # Fold the burst and channel axes into a single channel axis.
    feedData = image_noise.view(b, -1, h, w)
    model_inputs = (feedData, image_noise[:, 0:4, ...], image_noise_hr)

    print('Test forward pass')
    s = time.time()
    print("feedData   :", feedData.size())
    print("image_noise   ", image_noise[:, 0:4, ...].size())
    print("image_noise_hr   ", image_noise_hr.size())
    with torch.no_grad():
        _, enhanced_img_tensor = torch_model(*model_inputs)
        enhanced_img_tensor = torch.clamp(enhanced_img_tensor, 0, 1)
        enhanced_img = (enhanced_img_tensor.permute(0, 2, 3, 1).squeeze(0)
                        .cpu().detach().numpy())
        enhanced_img = np.clip(enhanced_img * 255, 0, 255).astype('uint8')
        imageio.imwrite('../img/denoised.jpg', enhanced_img)
    print('- Time: ', time.time() - s)

    print('Export to onnx format')
    s = time.time()
    torch2onnx(torch_model, model_inputs, output_path, input_node_names,
               output_node_names, keep_initializers=False,
               verify_after_export=True)
    print('- Time: ', time.time() - s)
def test(args):
    """Denoise ``2_*.png`` images and report PSNR/SSIM against clean images.

    Noisy images are globbed from ``args.noise_dir``; the clean counterpart
    of each is found by replacing "noisy" with "clean" in its path. Both are
    cropped to the top-left ``args.image_size`` square. When
    ``args.save_img`` is non-empty a figure per image is written there.
    The checkpoint is always loaded from ``checkpoints/mir``.
    """
    model = MIRNet()
    save_img = args.save_img
    checkpoint_dir = "checkpoints/mir"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # A torch.device never compares equal to a str, so test ``device.type``.
    checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda',
                                 'latest')
    model.load_state_dict(checkpoint['state_dict'])
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(
        checkpoint['epoch'], checkpoint['global_iter']))
    model.eval()
    model = model.to(device)
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    torch.manual_seed(0)

    noisy_path = sorted(glob.glob(args.noise_dir + "/2_*.png"))
    clean_path = [p.replace("noisy", "clean") for p in noisy_path]
    print(noisy_path)
    size = args.image_size

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for i, (noisy_file, clean_file) in enumerate(
                zip(noisy_path, clean_path)):
            noise = to_tensor(Image.open(noisy_file).convert('RGB'))
            noise = noise[:, 0:size, 0:size].unsqueeze(0).to(device)
            print(noise.size())
            pred = model(noise).detach().cpu()
            gt = to_tensor(Image.open(clean_file).convert('RGB'))
            gt = gt[:, 0:size, 0:size].unsqueeze(0)
            psnr_t = calculate_psnr(pred, gt)
            ssim_t = calculate_ssim(pred, gt)
            print(i, "   UP   :  PSNR : ", str(psnr_t), " :  SSIM : ",
                  str(ssim_t))
            if save_img != '':
                if not os.path.exists(args.save_img):
                    os.makedirs(args.save_img)
                plt.figure(figsize=(15, 15))
                plt.imshow(np.array(to_pil(pred[0])))
                plt.title("denoise KPN DGF " + args.model_type, fontsize=25)
                image_name = noisy_file.split("/")[-1].split(".")[0]
                plt.axis("off")
                plt.suptitle(image_name + "   UP   :  PSNR : " + str(psnr_t) +
                             " :  SSIM : " + str(ssim_t), fontsize=25)
                plt.savefig(os.path.join(
                    args.save_img,
                    image_name + "_" + args.checkpoint + '.png'),
                    pad_inches=0)
                # Release the figure; pyplot keeps every open figure alive.
                plt.close()
def eval():
    """Run the noise network on ``./003/1.jpg`` and save the prediction.

    Loads the best checkpoint from ``./noise_models``, prints the minimum and
    maximum of the prediction, scales it by its maximum and writes it to
    ``./003/1_pred.png``. Requires CUDA (the model and input are moved with
    ``.cuda()``).
    """
    from torchvision.transforms import transforms
    from PIL import Image

    model = Network(True).cuda()
    model.load_state_dict(
        load_checkpoint('./noise_models', best_or_latest='best'))
    model.eval()

    to_tensor = transforms.ToTensor()
    to_pil = transforms.ToPILImage()
    img = to_tensor(Image.open('./003/1.jpg')).unsqueeze(0).cuda()

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        pred = model(img).squeeze()
    print('min:', torch.min(pred), 'max:', torch.max(pred))
    pred = pred / torch.max(pred)
    pred = pred.cpu()
    to_pil(pred).save('./003/1_pred.png', quality=100)
    print('OK!')
def test(args):
    """Denoise every block of the raw benchmark ``.mat`` file.

    ``args.model_type`` selects the network ("DGF" or "noise");
    ``args.noise_dir`` is the path of a ``.mat`` file containing
    ``BenchmarkNoisyBlocksRaw``.

    Returns:
        An array shaped like the noisy input holding the unpacked
        predictions, or ``None`` when ``args.model_type`` is not recognised.
    """
    if args.model_type == "DGF":
        model = MIRNet_DGF(n_colors=args.n_colors,
                           out_channels=args.out_channels)
    elif args.model_type == "noise":
        model = MIRNet_noise(n_colors=args.n_colors,
                             out_channels=args.out_channels)
    else:
        print(" Model type not valid")
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # A torch.device never compares equal to a str, so test ``device.type``.
    checkpoint = load_checkpoint(args.checkpoint, device.type == 'cuda',
                                 'latest')
    model.load_state_dict(checkpoint['state_dict'])
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(
        checkpoint['epoch'], checkpoint['global_iter']))
    model.eval()
    model = model.to(device)
    to_tensor = transforms.ToTensor()
    torch.manual_seed(0)

    all_noisy_imgs = scipy.io.loadmat(
        args.noise_dir)['BenchmarkNoisyBlocksRaw']
    mat_re = np.zeros_like(all_noisy_imgs)
    i_imgs, i_blocks, _, _ = all_noisy_imgs.shape

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for i_img in range(i_imgs):
            for i_block in range(i_blocks):
                noise = to_tensor(pack_raw(all_noisy_imgs[i_img][i_block]))
                image_noise, image_noise_hr = load_data(noise,
                                                        args.burst_length)
                image_noise_hr = image_noise_hr.to(device)
                burst_noise = image_noise.to(device)
                _, pred = model(burst_noise, image_noise_hr)
                # CHW -> HWC before undoing the raw packing.
                pred = np.array(pred.detach().cpu()[0]).transpose(1, 2, 0)
                mat_re[i_img][i_block] = np.array(unpack_raw(pred))
    return mat_re
def convert_torch_to_onnx():
    """Run one forward pass of MIRNet and export the model to ONNX.

    Loads the latest checkpoint from ``../checkpoints/mir/``, denoises a
    hard-coded sample image (written to ``../img/denoised.jpg``) and exports
    the model to ``models/denoiser_rgb.onnx`` via ``torch2onnx``.
    """
    img_path = '/home/dell/Downloads/FullTest/noisy/2_1.png'
    output_path = 'models/denoiser_rgb.onnx'
    input_node_names = ['input_image']
    output_node_names = ['output_image']

    checkpoint = load_checkpoint("../checkpoints/mir/", False, 'latest')
    torch_model = MIRNet()
    torch_model.load_state_dict(checkpoint['state_dict'])
    # Inference/export must not run in training mode (affects layers such as
    # batch norm and dropout, if the network has any).
    torch_model.eval()

    img = imageio.imread(img_path)
    print(img.shape)
    img = np.asarray(img, dtype=np.float32) / 255.
    # HWC -> CHW, then add the batch axis.
    img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)

    print('Test forward pass')
    s = time.time()
    with torch.no_grad():
        enhanced_img_tensor = torch_model(img_tensor)
        enhanced_img_tensor = torch.clamp(enhanced_img_tensor, 0, 1)
        enhanced_img = (enhanced_img_tensor.permute(0, 2, 3, 1).squeeze(0)
                        .cpu().detach().numpy())
        enhanced_img = np.clip(enhanced_img * 255, 0, 255).astype('uint8')
        imageio.imwrite('../img/denoised.jpg', enhanced_img)
    print('- Time: ', time.time() - s)

    print('Export to onnx format')
    s = time.time()
    torch2onnx(torch_model, img_tensor, output_path, input_node_names,
               output_node_names, keep_initializers=False,
               verify_after_export=True)
    print('- Time: ', time.time() - s)
def eval(config, args):
    """Evaluate a trained KPN on up to 100 samples and save result images.

    ``config`` supplies the ``training`` and ``architecture`` sections;
    ``args`` supplies ``num_workers``, ``cuda``, ``mGPU`` and ``checkpoint``.
    Every existing file in the configured ``eval_dir`` is deleted, then the
    noisy input, prediction and ground truth of each evaluated sample are
    written there as PNGs. Per-sample and average PSNR/SSIM are printed.
    """
    max_images = 100
    train_config = config['training']
    arch_config = config['architecture']

    print('Eval Process......')

    checkpoint_dir = train_config['checkpoint_dir']
    if not os.path.exists(checkpoint_dir) or not os.listdir(checkpoint_dir):
        print('There is no any checkpoint file in path:{}'.format(
            checkpoint_dir))
        # Nothing to evaluate; stop before wiping the eval directory.
        return

    # The path for saving eval images; start from an empty directory.
    eval_dir = train_config['eval_dir']
    if not os.path.exists(eval_dir):
        os.mkdir(eval_dir)
    for f in os.listdir(eval_dir):
        os.remove(os.path.join(eval_dir, f))

    # Dataset and dataloader.
    data_set = TrainDataSet(train_config['dataset_configs'],
                            img_format='.bmp',
                            degamma=True,
                            color=False,
                            blind=arch_config['blind_est'],
                            train=False)
    data_loader = DataLoader(data_set,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.num_workers)

    dataset_config = read_config(train_config['dataset_configs'],
                                 _configspec_path())['dataset_configs']
    burst_length = dataset_config['burst_length']

    model = KPN(color=False,
                burst_length=burst_length,
                blind_est=arch_config['blind_est'],
                kernel_size=list(map(int, arch_config['kernel_size'].split())),
                sep_conv=arch_config['sep_conv'],
                channel_att=arch_config['channel_att'],
                spatial_att=arch_config['spatial_att'],
                upMode=arch_config['upMode'],
                core_bias=arch_config['core_bias'])
    if args.cuda:
        model = model.cuda()
    if args.mGPU:
        model = nn.DataParallel(model)

    # Load the trained model.
    ckpt = load_checkpoint(checkpoint_dir, args.checkpoint)
    model.load_state_dict(ckpt['state_dict'])
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))
    model.eval()

    trans = transforms.ToPILImage()

    with torch.no_grad():
        psnr = 0.0
        ssim = 0.0
        n_done = 0
        for i, (burst_noise, gt, white_level) in enumerate(data_loader):
            if i >= max_images:
                break
            if args.cuda:
                burst_noise = burst_noise.cuda()
                gt = gt.cuda()
                white_level = white_level.cuda()

            pred_i, pred = model(burst_noise,
                                 burst_noise[:, 0:burst_length, ...],
                                 white_level)

            pred_i = sRGBGamma(pred_i)
            pred = sRGBGamma(pred)
            gt = sRGBGamma(gt)
            burst_noise = sRGBGamma(burst_noise / white_level)

            psnr_t = calculate_psnr(pred.unsqueeze(1), gt.unsqueeze(1))
            ssim_t = calculate_ssim(pred.unsqueeze(1), gt.unsqueeze(1))
            psnr_noisy = calculate_psnr(
                burst_noise[:, 0, ...].unsqueeze(1), gt.unsqueeze(1))
            psnr += psnr_t
            ssim += ssim_t
            n_done += 1

            pred = torch.clamp(pred, 0.0, 1.0)

            if args.cuda:
                pred = pred.cpu()
                gt = gt.cpu()
                burst_noise = burst_noise.cpu()

            trans(burst_noise[0, 0, ...].squeeze()).save(os.path.join(
                eval_dir, '{}_noisy_{:.2f}dB.png'.format(i, psnr_noisy)),
                quality=100)
            trans(pred.squeeze()).save(os.path.join(
                eval_dir, '{}_pred_{:.2f}dB.png'.format(i, psnr_t)),
                quality=100)
            trans(gt.squeeze()).save(os.path.join(eval_dir,
                                                  '{}_gt.png'.format(i)),
                                     quality=100)

            print('{}-th image is OK, with PSNR: {:.2f}dB, SSIM: {:.4f}'.
                  format(i, psnr_t, ssim_t))

        # Average over the samples actually evaluated, not a fixed 100.
        denom = max(n_done, 1)
        print('All images are OK, average PSNR: {:.2f}dB, SSIM: {:.4f}'.format(
            psnr / denom, ssim / denom))
def train(args):
    """Train MIRNet / MIRNet_kpn with a Charbonnier loss.

    Reads from ``args``: ``num_workers``, ``data_type`` ('rgb' or 'raw'),
    ``noise_dir``, ``gt_dir``, ``image_size``, ``batch_size``,
    ``checkpoint`` (directory, created if missing), ``model_type`` ("MIR" or
    "KPN"), ``n_colors``, ``out_channels``, ``lr``, ``save_every``,
    ``loss_every``, ``restart`` and ``epoch``. Unless ``args.restart`` is
    set, training resumes from the latest checkpoint when one can be loaded.
    """
    torch.set_num_threads(args.num_workers)
    torch.manual_seed(0)
    if args.data_type == 'rgb':
        data_set = SingleLoader(noise_dir=args.noise_dir,
                                gt_dir=args.gt_dir,
                                image_size=args.image_size)
    elif args.data_type == 'raw':
        data_set = SingleLoader_raw(noise_dir=args.noise_dir,
                                    gt_dir=args.gt_dir,
                                    image_size=args.image_size)
    else:
        print("Data type not valid")
        exit()
    data_loader = DataLoader(data_set,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_func = losses.CharbonnierLoss().to(device)
    # NOTE(review): `adaptive` is not used by the active loss below; it is
    # kept because constructing it may affect RNG state -- confirm and drop.
    adaptive = robust_loss.adaptive.AdaptiveLossFunction(
        num_dims=3 * args.image_size**2, float_dtype=np.float32, device=device)

    checkpoint_dir = args.checkpoint
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    if args.model_type == "MIR":
        model = MIRNet(in_channels=args.n_colors,
                       out_channels=args.out_channels).to(device)
    elif args.model_type == "KPN":
        model = MIRNet_kpn(in_channels=args.n_colors,
                           out_channels=args.out_channels).to(device)
    else:
        print(" Model type not valid")
        return

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    optimizer.zero_grad()
    average_loss = MovingAverage(args.save_every)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               [2, 4, 6, 8, 10, 12, 14, 16],
                                               0.8)

    if args.restart:
        start_epoch = 0
        global_step = 0
        best_loss = np.inf
        print('=> no checkpoint file to be loaded.')
    else:
        try:
            # A torch.device never compares equal to a str, so test
            # ``device.type``.
            checkpoint = load_checkpoint(checkpoint_dir,
                                         device.type == 'cuda', 'latest')
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_iter']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print('=> loaded checkpoint (epoch {}, global_step {})'.format(
                start_epoch, global_step))
        except Exception:
            # Best effort: start from scratch when nothing usable is found,
            # but let KeyboardInterrupt / SystemExit propagate.
            start_epoch = 0
            global_step = 0
            best_loss = np.inf
            print('=> no checkpoint file to be loaded.')

    for epoch in range(start_epoch, args.epoch):
        for step, (noise, gt) in enumerate(data_loader):
            noise = noise.to(device)
            gt = gt.to(device)
            pred = model(noise)
            loss = loss_func(pred, gt)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            average_loss.update(loss)

            if global_step % args.save_every == 0:
                print(len(average_loss._cache))
                if average_loss.get_value() < best_loss:
                    is_best = True
                    best_loss = average_loss.get_value()
                else:
                    is_best = False
                save_dict = {
                    'epoch': epoch,
                    'global_iter': global_step,
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                }
                save_checkpoint(save_dict, is_best, checkpoint_dir,
                                global_step)
            if global_step % args.loss_every == 0:
                print(global_step, "PSNR  : ", calculate_psnr(pred, gt))
                print(average_loss.get_value())
            global_step += 1
        print('Epoch {} is finished.'.format(epoch))
        scheduler.step()
def train(num_workers, cuda, restart_train, mGPU):
    """Train a KPN-style denoiser that also predicts the noise map.

    Hyper-parameters and paths come from the module-level ``args``
    (``batch_size``, ``epoch``, ``save_every``, ``loss_every``,
    ``burst_length``, ``checkpoint``, ``noise_dir``, ``gt_dir``,
    ``image_size``, ``model_type``, ``wavelet_loss``, ``load_type``).

    Args:
        num_workers: number of DataLoader worker processes.
        cuda: move the model and batches to the GPU when true.
        restart_train: start from scratch instead of resuming a checkpoint.
        mGPU: wrap the model in ``nn.DataParallel`` when true.
    """
    color = True
    batch_size = args.batch_size
    lr = 2e-4
    lr_decay = 0.89125093813
    n_epoch = args.epoch
    save_freq = args.save_every
    loss_freq = args.loss_every
    lr_step_size = 100
    burst_length = args.burst_length
    min_lr = 5e-6

    # Checkpoint path.
    checkpoint_dir = "checkpoints/" + args.checkpoint
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # Logs path: make sure it exists, then wipe it so each run starts clean.
    logs_dir = "checkpoints/logs/" + args.checkpoint
    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)
    shutil.rmtree(logs_dir)
    log_writer = SummaryWriter(logs_dir)

    # Dataset and dataloader.
    data_set = SingleLoader_DGF(noise_dir=args.noise_dir,
                                gt_dir=args.gt_dir,
                                image_size=args.image_size,
                                burst_length=burst_length)
    data_loader = DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )

    # Model selection; all variants share the same constructor keywords.
    model_kwargs = dict(
        color=color,
        burst_length=burst_length,
        blind_est=False,
        kernel_size=[5],
        sep_conv=False,
        upMode="bilinear",
        core_bias=False
    )
    if args.model_type == "attKPN":
        model = Att_KPN_noise_DGF(channel_att=True, spatial_att=True,
                                  **model_kwargs)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN_noise_DGF(channel_att=True, spatial_att=True,
                                         **model_kwargs)
    elif args.model_type == 'KPN':
        model = KPN_noise_DGF(channel_att=False, spatial_att=False,
                              **model_kwargs)
    else:
        print(" Model type not valid")
        return
    if cuda:
        model = model.cuda()
    if mGPU:
        model = nn.DataParallel(model)
    model.train()

    # Loss functions.
    loss_func = LossBasic()
    if args.wavelet_loss:
        print("Use wavelet loss")
        loss_func2 = WaveletLoss()

    optimizer = optim.Adam(model.parameters(), lr=lr)
    optimizer.zero_grad()
    scheduler = lr_scheduler.StepLR(optimizer, step_size=lr_step_size,
                                    gamma=lr_decay)
    average_loss = MovingAverage(save_freq)

    start_epoch = 0
    global_step = 0
    best_loss = np.inf
    if not restart_train:
        try:
            checkpoint = load_checkpoint(checkpoint_dir, cuda,
                                         best_or_latest=args.load_type)
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_iter']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['lr_scheduler'])
            print('=> loaded checkpoint (epoch {}, global_step {})'.format(
                start_epoch, global_step))
        except Exception:
            # Best effort: start from scratch when nothing usable is found,
            # but let KeyboardInterrupt / SystemExit propagate.
            start_epoch = 0
            global_step = 0
            best_loss = np.inf
            print('=> no checkpoint file to be loaded.')

    print('=> training')
    for epoch in range(start_epoch, n_epoch):
        epoch_start_time = time.time()
        t1 = time.time()
        for step, (image_noise_hr, image_noise_lr, image_gt_hr, _) in \
                enumerate(data_loader):
            if cuda:
                burst_noise = image_noise_lr.cuda()
                gt = image_gt_hr.cuda()
                image_noise_hr = image_noise_hr.cuda()
            else:
                burst_noise = image_noise_lr
                gt = image_gt_hr
            # Noise target: both operands live on the same device (the old
            # CUDA path subtracted a CPU tensor from a GPU tensor).
            noise_gt = image_noise_hr - gt

            _, pred, noise = model(burst_noise, image_noise_hr)

            loss_basic = loss_func(pred, gt)
            loss_noise = loss_func(noise, noise_gt)
            loss = loss_basic + loss_noise
            if args.wavelet_loss:
                loss_wave = loss_func2(pred, gt)
                loss_wave_noise = loss_func2(noise, noise_gt)
                loss = loss_basic + loss_wave + loss_noise + loss_wave_noise

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            average_loss.update(loss)

            if not color:
                pred = pred.unsqueeze(1)
                gt = gt.unsqueeze(1)
            if global_step % loss_freq == 0:
                print("burst_noise  : ", burst_noise.size())
                print("gt   :  ", gt.size())
                psnr = calculate_psnr(pred, gt)
                ssim = calculate_ssim(pred, gt)

                # Add scalars to tensorboardX.
                log_writer.add_scalar('loss_basic', loss_basic, global_step)
                log_writer.add_scalar('loss_total', loss, global_step)
                log_writer.add_scalar('psnr', psnr, global_step)
                log_writer.add_scalar('ssim', ssim, global_step)

                print('{:-4d}\t| epoch {:2d}\t| step {:4d}\t| loss_basic: {:.4f}\t|'
                      ' loss: {:.4f}\t| PSNR: {:.2f}dB\t| SSIM: {:.4f}\t| time:{:.2f} seconds.'
                      .format(global_step, epoch, step, loss_basic, loss,
                              psnr, ssim, time.time()-t1))
                t1 = time.time()

            if global_step % save_freq == 0:
                if average_loss.get_value() < best_loss:
                    is_best = True
                    best_loss = average_loss.get_value()
                else:
                    is_best = False

                save_dict = {
                    'epoch': epoch,
                    'global_iter': global_step,
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': scheduler.state_dict()
                }
                save_checkpoint(
                    save_dict, is_best, checkpoint_dir, global_step,
                    max_keep=10
                )

            global_step += 1
        print('Epoch {} is finished, time elapsed {:.2f} seconds.'.format(
            epoch, time.time()-epoch_start_time))
        # Decay the learning rate until it reaches the floor, then pin it.
        lr_cur = [param['lr'] for param in optimizer.param_groups]
        if lr_cur[0] > min_lr:
            scheduler.step()
        else:
            for param in optimizer.param_groups:
                param['lr'] = min_lr
def test_multi(args):
    """Denoise sub-images stored as ``.mat`` files and save them as ``.mat``.

    ``args.noise_dir`` contains one folder per image, each holding ``.mat``
    files with an ``image`` entry plus position metadata. For every file the
    denoised image and the copied metadata are written under
    ``args.save_img/<folder name>/``. An empty ``args.save_img`` disables
    writing.
    """
    model = MIRNet()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # A torch.device never compares equal to a str, so test ``device.type``.
    checkpoint = load_checkpoint(args.checkpoint, device.type == 'cuda',
                                 'latest')
    # Prepend "model." to every key so the names match this model's
    # parameters.
    new_state_dict = OrderedDict(
        ("model." + k, v) for k, v in checkpoint['state_dict'].items())
    model.load_state_dict(new_state_dict)
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(
        checkpoint['epoch'], checkpoint['global_iter']))
    model.eval()
    model = model.to(device)
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    torch.manual_seed(0)

    save_enabled = args.save_img != ''
    # Only create the output root when saving is requested;
    # os.makedirs('') would raise.
    if save_enabled and not os.path.exists(args.save_img):
        os.makedirs(args.save_img)

    # Scalar metadata entries copied unchanged into the output file.
    meta_keys = ("y_gb", "x_gb", "y_lc", "x_lc", "size", "H", "W")

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for mat_folder in glob.glob(os.path.join(args.noise_dir, '*')):
            save_mat_folder = os.path.join(args.save_img,
                                           mat_folder.split("/")[-1])
            for mat_file in glob.glob(os.path.join(mat_folder, '*')):
                mat_contents = sio.loadmat(mat_file)
                image_noise = to_tensor(
                    Image.fromarray(mat_contents['image'])).unsqueeze(0)
                pred = model(image_noise.to(device))
                pred = np.array(to_pil(pred[0].cpu()))

                if save_enabled:
                    if not os.path.exists(save_mat_folder):
                        os.makedirs(save_mat_folder)
                    out_path = os.path.join(save_mat_folder,
                                            mat_file.split("/")[-1])
                    print("save : ", out_path)
                    data = {"image": pred}
                    for key in meta_keys:
                        data[key] = mat_contents[key][0][0]
                    sio.savemat(out_path, data)
from model.KPN_DGF import KPN_DGF, Att_KPN_DGF, Att_Weight_KPN_DGF, Att_KPN_Wavelet_DGF import torch import tensorflow as tf import onnx from onnx_tf.backend import prepare import os from utils.training_util import save_checkpoint, MovingAverage, load_checkpoint checkpoint = load_checkpoint("../checkpoints/kpn_att_repeat_new/", False, 'latest') state_dict = checkpoint['state_dict'] model = Att_KPN_DGF(color=True, burst_length=4, blind_est=True, kernel_size=[5], sep_conv=False, channel_att=True, spatial_att=True, upMode="bilinear", core_bias=False) # model.load_state_dict(state_dict) model.eval() from torchsummary import summary summary(model, [(12, 256, 256), (4, 3, 256, 256), (3, 512, 512)], batch_size=1) exit() # Converting model to ONNX print('===> Converting model to ONNX.') try: for _ in model.modules(): _.training = False
def test(args):
    """Evaluate a raw-domain denoiser on the validation ``.mat`` blocks.

    ``args.model_type`` selects the network ("DGF" or "noise");
    ``args.noise_dir`` and ``args.gt_dir`` are paths of ``.mat`` files
    holding ``ValidationNoisyBlocksRaw`` and ``ValidationGtBlocksRaw``.
    Prints PSNR/SSIM per block and their averages; when ``args.save_img`` is
    non-empty a figure per block is written there.
    """
    if args.model_type == "DGF":
        model = MIRNet_DGF(n_colors=args.n_colors,
                           out_channels=args.out_channels)
    elif args.model_type == "noise":
        model = MIRNet_noise(n_colors=args.n_colors,
                             out_channels=args.out_channels)
    else:
        print(" Model type not valid")
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # A torch.device never compares equal to a str, so test ``device.type``.
    checkpoint = load_checkpoint(args.checkpoint, device.type == 'cuda',
                                 'latest')
    model.load_state_dict(checkpoint['state_dict'])
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(
        checkpoint['epoch'], checkpoint['global_iter']))
    model.eval()
    model = model.to(device)
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    torch.manual_seed(0)

    all_noisy_imgs = scipy.io.loadmat(
        args.noise_dir)['ValidationNoisyBlocksRaw']
    all_clean_imgs = scipy.io.loadmat(args.gt_dir)['ValidationGtBlocksRaw']
    i_imgs, i_blocks, _, _ = all_noisy_imgs.shape
    psnrs = []
    ssims = []

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for i_img in range(i_imgs):
            for i_block in range(i_blocks):
                noise = to_tensor(pack_raw(all_noisy_imgs[i_img][i_block]))
                image_noise, image_noise_hr = load_data(noise,
                                                        args.burst_length)
                image_noise_hr = image_noise_hr.to(device)
                burst_noise = image_noise.to(device)
                _, pred = model(burst_noise, image_noise_hr)
                pred = pred.detach().cpu()
                gt = to_tensor(
                    pack_raw(all_clean_imgs[i_img][i_block])).unsqueeze(0)
                psnr_t = calculate_psnr(pred, gt)
                ssim_t = calculate_ssim(pred, gt)
                psnrs.append(psnr_t)
                ssims.append(ssim_t)
                print(i_img, "   UP   :  PSNR : ", str(psnr_t),
                      " :  SSIM : ", str(ssim_t))
                if args.save_img != '':
                    if not os.path.exists(args.save_img):
                        os.makedirs(args.save_img)
                    plt.figure(figsize=(15, 15))
                    plt.imshow(np.array(to_pil(pred[0])))
                    plt.title("denoise KPN DGF " + args.model_type,
                              fontsize=25)
                    # Include the block index so the blocks of one image do
                    # not overwrite each other's figure.
                    image_name = "{}_{}".format(i_img, i_block)
                    plt.axis("off")
                    plt.suptitle(image_name + "   UP   :  PSNR : " +
                                 str(psnr_t) + " :  SSIM : " + str(ssim_t),
                                 fontsize=25)
                    plt.savefig(os.path.join(
                        args.save_img,
                        image_name + "_" + args.checkpoint + '.png'),
                        pad_inches=0)
                    # Release the figure; pyplot keeps open figures alive.
                    plt.close()
    print("   AVG   :  PSNR : " + str(np.mean(psnrs)) + " :  SSIM : " +
          str(np.mean(ssims)))
def train():
    """Train the noise-estimation ``Network`` with an L1 loss.

    Parses ``--restart``/``-r`` from the command line; without it the best
    checkpoint in ``./noise_models`` is loaded first. Logs the loss to
    ``./logs`` and saves a checkpoint every 200 iterations for 100 epochs.
    Requires CUDA.
    """
    import os

    ckpt_dir = './noise_models'
    save_interval = 200

    log_writer = SummaryWriter('./logs')

    cli = argparse.ArgumentParser()
    cli.add_argument('--restart', '-r', action='store_true')
    args = cli.parse_args()

    train_config = read_config('kpn_specs/att_kpn_config.conf',
                               'kpn_specs/configspec.conf')['training']
    data_loader = DataLoader(
        dataset=TrainDataSet(
            train_config['dataset_configs'],
            img_format='.bmp',
            degamma=True,
            color=True,
            blind=False
        ),
        batch_size=32,
        shuffle=True,
        num_workers=4
    )

    loss_fn = nn.L1Loss()
    model = Network(True).cuda()
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=5e-5)

    if not args.restart:
        weights = load_checkpoint(ckpt_dir, best_or_latest='best')
        model.load_state_dict(weights)

    global_iter = 0
    min_loss = np.inf
    loss_ave = MovingAverage(save_interval)

    if not os.path.exists(ckpt_dir):
        os.mkdir(ckpt_dir)

    for epoch in range(100):
        for step, (data, A, B) in enumerate(data_loader):
            # First frame is the network input, last frame the target.
            feed = data[:, 0, ...].cuda()
            gt = data[:, -1, ...].cuda()

            loss = loss_fn(model(feed), gt)
            global_iter += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            log_writer.add_scalar('loss', loss, global_iter)
            loss_ave.update(loss)

            if global_iter % save_interval == 0:
                loss_t = loss_ave.get_value()
                min_loss = min(min_loss, loss_t)
                is_best = min_loss == loss_t
                save_checkpoint(
                    model.state_dict(),
                    is_best=is_best,
                    checkpoint_dir=ckpt_dir,
                    n_iter=global_iter
                )
            print('{: 6d}, epoch {: 3d}, iter {: 4d}, loss {:.4f}'.format(
                global_iter, epoch, step, loss))
def train(args):
    """Train ``MWRN_lv3`` on third-level wavelet sub-bands of image pairs.

    Reads from ``args``: ``noise_dir``, ``gt_dir``, ``image_size``,
    ``batch_size``, ``num_workers``, ``checkpoint`` (directory), ``save_every``,
    ``loss_every`` and ``epoch``.

    Training resumes from the latest checkpoint in ``args.checkpoint`` when
    one can be loaded; otherwise it starts from epoch 0.  A checkpoint is
    written every ``args.save_every`` global steps.
    """
    data_set = SingleLoader(noise_dir=args.noise_dir, gt_dir=args.gt_dir,
                            image_size=args.image_size)
    data_loader = DataLoader(data_set,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers)
    loss_basic = BasicLoss()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint_dir = args.checkpoint
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    model = MWRN_lv3().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, [5, 10, 15, 20, 25, 30], 0.5)
    optimizer.zero_grad()
    average_loss = MovingAverage(args.save_every)

    try:
        # ``device`` is a torch.device; comparing it with the string 'cuda'
        # is always False, so test the device type instead.
        checkpoint = load_checkpoint(checkpoint_dir,
                                     device.type == 'cuda', 'latest')
        start_epoch = checkpoint['epoch']
        global_step = checkpoint['global_iter']
        best_loss = checkpoint['best_loss']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print('=> loaded checkpoint (epoch {}, global_step {})'.format(
            start_epoch, global_step))
    except Exception:
        # Any failure to restore means "start fresh"; Ctrl-C is no longer
        # swallowed here as it was with a bare ``except``.
        start_epoch = 0
        global_step = 0
        best_loss = np.inf
        print('=> no checkpoint file to be loaded.')

    DWT = common.DWT()
    param = [x for name, x in model.named_parameters()]
    clip_grad_D = 1e4
    grad_norm_D = 0

    for epoch in range(start_epoch, args.epoch):
        for step, (noise, gt) in enumerate(data_loader):
            noise = noise.to(device)
            gt = gt.to(device)

            # Apply the wavelet transform three times to target and input.
            x1 = DWT(gt).to(device)
            x2 = DWT(x1).to(device)
            x3 = DWT(x2).to(device)
            y1 = DWT(noise).to(device)
            y2 = DWT(y1).to(device)
            y3 = DWT(y2).to(device)

            _, img_lv3 = model(y3, None)
            loss = loss_basic(x3, img_lv3)

            optimizer.zero_grad()
            loss.backward()
            total_norm_D = nn.utils.clip_grad_norm_(param, clip_grad_D)
            # Running mean of the gradient norm over the current epoch.
            grad_norm_D = (grad_norm_D * (step / (step + 1)) +
                           total_norm_D / (step + 1))
            optimizer.step()
            average_loss.update(loss)

            if global_step % args.save_every == 0:
                print("Save : epoch ", epoch, " step : ", global_step,
                      " with avg loss : ", average_loss.get_value(),
                      ", best loss : ", best_loss)
                if average_loss.get_value() < best_loss:
                    is_best = True
                    best_loss = average_loss.get_value()
                else:
                    is_best = False
                save_dict = {
                    'epoch': epoch,
                    'global_iter': global_step,
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                }
                save_checkpoint(save_dict, is_best, checkpoint_dir,
                                global_step)

            if global_step % args.loss_every == 0:
                print(global_step, ": ", average_loss.get_value())
            global_step += 1

        # Tighten the clipping threshold to the epoch's mean gradient norm.
        clip_grad_D = min(clip_grad_D, grad_norm_D)
        scheduler.step()
        print("Epoch : ", epoch, "end at step: ", global_step)
def test_multi(args):
    """Denoise every ``*.png`` in ``args.noise_dir`` and save ``.mat`` results.

    Reads from ``args``: ``burst_length``, ``model_type`` (one of ``attKPN``,
    ``attKPN_Wave``, ``attWKPN``, ``KPN``), ``checkpoint`` (sub-directory of
    ``checkpoints/``), ``load_type``, ``noise_dir`` and ``save_img``.

    Each prediction is converted to a PIL image, then to a NumPy array, and
    stored under the key ``Idenoised_crop`` in ``<save_img>/<image stem>.mat``.
    Nothing is written when ``args.save_img`` is the empty string.  Returns
    ``None``; an unknown ``model_type`` prints a message and returns early.
    """
    color = True
    burst_length = args.burst_length
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    shared_kwargs = dict(
        color=color,
        burst_length=burst_length,
        blind_est=True,
        kernel_size=[5],
        sep_conv=False,
        upMode="bilinear",
        core_bias=False,
    )
    if args.model_type == "attKPN":
        model = Att_KPN_DGF(channel_att=True, spatial_att=True,
                            **shared_kwargs)
    elif args.model_type == "attKPN_Wave":
        model = Att_KPN_Wavelet_DGF(channel_att=True, spatial_att=True,
                                    **shared_kwargs)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN_DGF(channel_att=True, spatial_att=True,
                                   **shared_kwargs)
    elif args.model_type == "KPN":
        model = KPN_DGF(channel_att=False, spatial_att=False,
                        **shared_kwargs)
    else:
        print(" Model type not valid")
        return

    checkpoint_dir = "checkpoints/" + args.checkpoint
    if not os.path.exists(checkpoint_dir) or len(os.listdir(checkpoint_dir)) == 0:
        print('There is no any checkpoint file in path:{}'.format(checkpoint_dir))

    # load trained model
    # ``device`` is a torch.device, so compare its type rather than the
    # object itself against the string 'cuda' (that comparison is False).
    ckpt = load_checkpoint(checkpoint_dir,
                           cuda=device.type == 'cuda',
                           best_or_latest=args.load_type)
    state_dict = ckpt['state_dict']
    # Drop the ``module.`` prefix only where present, so checkpoints saved
    # with or without that prefix both load into the bare model.
    new_state_dict = OrderedDict(
        (k[7:] if k.startswith('module.') else k, v)
        for k, v in state_dict.items())
    model.load_state_dict(new_state_dict)
    model.to(device)
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))

    # switch the eval mode
    noisy_path = sorted(glob.glob(args.noise_dir + "/*.png"))
    model.eval()
    torch.manual_seed(0)
    trans = transforms.ToPILImage()

    save_results = args.save_img != ''
    # Guarded so that an empty save path disables saving instead of making
    # os.makedirs('') raise.
    if save_results and not os.path.exists(args.save_img):
        os.makedirs(args.save_img)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for path in noisy_path:
            image_noise = transforms.ToTensor()(
                Image.open(path).convert('RGB'))
            image_noise, image_noise_hr = load_data(image_noise, burst_length)
            image_noise_hr = image_noise_hr.to(device)
            burst_noise = image_noise.to(device)

            if color:
                # Merge the burst and channel dims into one for the network.
                b, N, c, h, w = burst_noise.size()
                feedData = burst_noise.view(b, -1, h, w)
            else:
                feedData = burst_noise

            _, pred = model(feedData,
                            burst_noise[:, 0:burst_length, ...],
                            image_noise_hr)
            print(pred.size())
            pred = np.array(trans(pred[0].cpu()))
            print(pred.shape)

            if save_results:
                out_path = os.path.join(
                    args.save_img,
                    path.split("/")[-1].split(".")[0] + '.mat')
                print("save : ", out_path)
                sio.savemat(out_path, {"Idenoised_crop": pred})
def validation(args):
    """Denoise one burst with the attention-weighted KPN, patch by patch.

    Reads from ``args``: ``img`` (a directory of frames), ``read`` and
    ``shot`` (noise parameters).  Three KPN variants and a noise-estimation
    network are built and loaded, but only ``gkpn_model`` is actually run;
    the calls to the other models are commented out.

    Writes ``<name>_pred_gkpn.png`` and ``<name>_noisy.png`` into
    ``./eval_images_real``.

    Raises:
        ValueError: if ``args.img`` does not exist or is not a directory.
    """
    gkpn_model = Att_Weight_KPN(color=False, burst_length=8, blind_est=False,
                                kernel_size=[5], sep_conv=False,
                                channel_att=True, spatial_att=True)
    gkpn_model = nn.DataParallel(gkpn_model.cuda())
    state = load_checkpoint('../models/att_weight_kpn/checkpoint',
                            best_or_latest='best')
    gkpn_model.load_state_dict(state['state_dict'])
    gkpn_model.eval()

    wkpn_model = Att_Weight_KPN(color=False, burst_length=8, blind_est=False,
                                kernel_size=[5], sep_conv=False,
                                channel_att=False, spatial_att=False)
    wkpn_model = nn.DataParallel(wkpn_model.cuda())
    state = load_checkpoint('../models/weight_kpn/checkpoint',
                            best_or_latest='best')
    wkpn_model.load_state_dict(state['state_dict'])
    wkpn_model.eval()

    kpn_model = KPN(color=False, burst_length=8, blind_est=False,
                    kernel_size=[5], sep_conv=False,
                    channel_att=False, spatial_att=False)
    kpn_model = nn.DataParallel(kpn_model.cuda())
    state = load_checkpoint('../models/kpn_aug/checkpoint',
                            best_or_latest='best')
    kpn_model.load_state_dict(state['state_dict'])
    kpn_model.eval()

    # NOTE(review): this model is loaded from the 'model' key (the others use
    # 'state_dict') and the name ``noise_est`` is rebound to a tensor below,
    # so the network itself is never called.
    noise_est = nn.DataParallel(Network(True).cuda())
    state = load_checkpoint('../noise_models', best_or_latest='best')
    noise_est.load_state_dict(state['model'])
    noise_est.eval()

    trans = transforms.ToTensor()
    imgs = []
    with torch.no_grad():
        if os.path.exists(args.img):
            if os.path.isdir(args.img):
                files = os.listdir(args.img)
                # file_index = np.random.permutation(len(files))[:8]
                # Always uses the first 8 directory entries in listdir order.
                file_index = range(8)
                for index in file_index:
                    img = Image.open(os.path.join(args.img, files[index]))
                    img = trans(img)
                    # Add signal-dependent Gaussian noise, then clamp to [0, 1].
                    img += (0.02 + 0.01 * img) * torch.randn_like(img)
                    imgs.append(img.clamp(0.0, 1.0))
            else:
                raise ValueError('should be a burst of frames, not a image!')
            s_read, s_shot = torch.Tensor([[[args.read]]]), torch.Tensor([[[args.shot]]])
            # Analytic noise map computed from the first frame; it is appended
            # as a ninth entry after the 8 frames.
            noise_est = torch.sqrt(
                s_read**2 + s_shot * torch.max(torch.zeros_like(imgs[0]), imgs[0]))
            imgs.append(noise_est)
            # Stack the 9 entries and add a leading batch dimension of 1.
            imgs = torch.stack(imgs, dim=0).unsqueeze(0).cuda()
            # imgs = torch.stack(imgs, dim=0).unsqueeze(0).cuda()
            # h, w = imgs.size()[-2:]
            # noise_est = noise_est(imgs[:, 0, ...].expand(1, 3, h, w))[:, 0, ...].unsqueeze(1).unsqueeze(2)
            # imgs = torch.cat([imgs, 15*noise_est], dim=1)
        else:
            raise ValueError('The path for image is not existing.')

        b, N, c, h, w = imgs.size()
        # res_wkpn and res_kpn stay all-zero because their model calls below
        # are commented out.
        res_wkpn = torch.zeros(c, h, w).cuda()
        res_gkpn = torch.zeros(c, h, w).cuda()
        res_kpn = torch.zeros(c, h, w).cuda()
        res_gkpn_pred_i = torch.zeros(8, c, h, w).cuda()
        res_gkpn_residual = torch.zeros(8, c, h, w).cuda()

        patch_size = 512
        receptiveFiled = 120
        # Zero-pad each spatial side by ``receptiveFiled`` pixels so every
        # patch carries a border that is cropped off again after prediction.
        imgs_pad = torch.zeros(b, N, c, h + 2 * receptiveFiled, w + 2 * receptiveFiled)
        imgs_pad[..., receptiveFiled:-receptiveFiled, receptiveFiled:-receptiveFiled] = imgs

        if not os.path.exists('./eval_images_real'):
            os.mkdir('./eval_images_real')
        filename = os.path.basename(args.img)
        filename = os.path.splitext(filename)[0]

        # ``trans`` is rebound: from here on it converts tensors to PIL images.
        trans = transforms.ToPILImage()

        # Each channel is processed independently (the models are built with
        # color=False), tiling the image in steps of ``patch_size``.
        for channel in range(c):
            for i in range(0, h, patch_size):
                for j in range(0, w, patch_size):
                    # The four branches select a full tile or a shorter tile
                    # at the bottom and/or right edge of the image.
                    if i + patch_size <= h and j + patch_size <= w:
                        # feed = imgs[..., i:i+patch_size, j:j+patch_size].contiguous()
                        feed = imgs_pad[..., channel, i:i + patch_size + 2 * receptiveFiled, j:j + patch_size + 2 * receptiveFiled].contiguous()
                    elif i + patch_size <= h:
                        # feed = imgs[..., i:i + patch_size, j:].contiguous()
                        feed = imgs_pad[..., channel, i:i + patch_size + 2 * receptiveFiled, j:].contiguous()
                    elif j + patch_size <= w:
                        # feed = imgs[..., i:, j:j+patch_size].contiguous()
                        feed = imgs_pad[..., channel, i:, j:j + patch_size + 2 * receptiveFiled].contiguous()
                    else:
                        # feed = imgs[..., i:, j:].contiguous()
                        feed = imgs_pad[..., channel, i:, j:].contiguous()

                    # Valid (unpadded) height/width of this tile.
                    hs, ws = feed.size()[-2:]
                    hs -= 2 * receptiveFiled
                    ws -= 2 * receptiveFiled
                    # ``padding`` is a project helper; presumably it pads edge
                    # tiles up to the full tile size — TODO confirm.
                    feed = padding(feed, patch_size + 2 * receptiveFiled, patch_size + 2 * receptiveFiled)

                    # _, pred = wkpn_model(feed.view(b, -1, patch_size+2*receptiveFiled, patch_size+2*receptiveFiled), feed[:, 0:8, ...])
                    # res_wkpn[channel, i:i+patch_size, j:j+patch_size] = pred[..., receptiveFiled:hs+receptiveFiled, receptiveFiled:ws+receptiveFiled].squeeze()

                    # Second argument passes only the first 8 entries (the
                    # frames), excluding the appended noise map.
                    pred_i, pred, residuals = gkpn_model(
                        feed.view(b, -1, patch_size + 2 * receptiveFiled, patch_size + 2 * receptiveFiled),
                        feed[:, 0:8, ...])

                    # Crop the border off and write the valid region back.
                    res_gkpn[channel, i:i + patch_size, j:j + patch_size] = pred[..., receptiveFiled:hs + receptiveFiled, receptiveFiled:ws + receptiveFiled].squeeze()
                    res_gkpn_pred_i[:, channel, i:i + patch_size, j:j + patch_size] = pred_i[
                        ..., receptiveFiled:hs + receptiveFiled, receptiveFiled:ws + receptiveFiled].squeeze()
                    res_gkpn_residual[:, channel, i:i + patch_size, j:j + patch_size] = residuals[
                        ..., receptiveFiled:hs + receptiveFiled, receptiveFiled:ws + receptiveFiled].squeeze()

                    # _, pred = kpn_model(feed.view(b, -1, patch_size + 2 * receptiveFiled, patch_size + 2 * receptiveFiled),
                    #                     feed[:, 0:8, ...])
                    # res_kpn[channel, i:i + patch_size, j:j + patch_size] = pred[..., receptiveFiled:hs + receptiveFiled,
                    #                                                        receptiveFiled:ws + receptiveFiled].squeeze()

                    print('{}, {} OK!'.format(i, j))

        res_kpn = res_kpn.cpu().clamp(0.0, 1.0)
        res_wkpn = res_wkpn.cpu().clamp(0.0, 1.0)
        res_gkpn = res_gkpn.cpu().clamp(0.0, 1.0)

        # trans(res_kpn).save('./eval_images_real/{}_pred_kpn.png'.format(filename), quality=100)
        # trans(res_wkpn).save('./eval_images_real/{}_pred_wkpn.png'.format(filename), quality=100)
        trans(res_gkpn).save(
            './eval_images_real/{}_pred_gkpn.png'.format(filename), quality=100)
        # The first (noisy) frame of the burst is saved for reference.
        trans(imgs[0, 0, ...].cpu()).save(
            './eval_images_real/{}_noisy.png'.format(filename), quality=100)
        print('OK!')
def eval(args):
    """Evaluate a noise-aware KPN variant on up to 100 samples.

    Reads from ``args``: ``checkpoint``, ``noise_dir``, ``gt_dir``,
    ``image_size``, ``num_workers``, ``model_type`` (``attKPN``, ``attWKPN``
    or ``KPN``), ``cuda`` and ``mGPU``.

    The ``eval_img`` directory is emptied first; for every evaluated sample
    the noisy input, the prediction and the ground truth are saved there as
    PNG files and PSNR/SSIM are printed.
    """
    color = True
    print('Eval Process......')
    burst_length = 8

    checkpoint_dir = "checkpoints/" + args.checkpoint
    checkpoint_missing = (not os.path.exists(checkpoint_dir)
                          or len(os.listdir(checkpoint_dir)) == 0)
    if checkpoint_missing:
        print('There is no any checkpoint file in path:{}'.format(
            checkpoint_dir))

    # the path for saving eval images; stale files are deleted
    eval_dir = "eval_img"
    if not os.path.exists(eval_dir):
        os.mkdir(eval_dir)
    for stale in os.listdir(eval_dir):
        os.remove(os.path.join(eval_dir, stale))

    # dataset and dataloader
    data_set = MultiLoader(noise_dir=args.noise_dir,
                           gt_dir=args.gt_dir,
                           image_size=args.image_size)
    data_loader = DataLoader(data_set,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.num_workers)

    # model here: every variant shares these constructor arguments
    shared_kwargs = dict(color=color,
                         burst_length=burst_length,
                         blind_est=False,
                         kernel_size=[5],
                         sep_conv=False,
                         upMode="bilinear",
                         core_bias=False)
    if args.model_type == "attKPN":
        model = Att_KPN_noise(channel_att=True, spatial_att=True,
                              **shared_kwargs)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN_noise(channel_att=True, spatial_att=True,
                                     **shared_kwargs)
    elif args.model_type == "KPN":
        model = KPN_noise(channel_att=False, spatial_att=False,
                          **shared_kwargs)
    else:
        print(" Model type not valid")
        return

    if args.cuda:
        model = model.cuda()
    if args.mGPU:
        model = nn.DataParallel(model)

    # load trained model
    ckpt = load_checkpoint(checkpoint_dir, cuda=args.cuda)
    model.load_state_dict(ckpt['state_dict'])
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))

    # switch the eval mode
    model.eval()
    to_pil = transforms.ToPILImage()

    with torch.no_grad():
        psnr = 0.0
        ssim = 0.0
        for idx, (burst_noise, gt) in enumerate(data_loader):
            # Only the first 100 samples are evaluated.
            if idx >= 100:
                break

            if args.cuda:
                burst_noise = burst_noise.cuda()
                gt = gt.cuda()

            pred_i, pred = model(burst_noise)

            first_frame = burst_noise[:, 0, ...]
            if color:
                psnr_t = calculate_psnr(pred, gt)
                ssim_t = calculate_ssim(pred, gt)
                psnr_noisy = calculate_psnr(first_frame, gt)
            else:
                psnr_t = calculate_psnr(pred.unsqueeze(1), gt.unsqueeze(1))
                ssim_t = calculate_ssim(pred.unsqueeze(1), gt.unsqueeze(1))
                psnr_noisy = calculate_psnr(first_frame.unsqueeze(1),
                                            gt.unsqueeze(1))
            psnr += psnr_t
            ssim += ssim_t

            pred = torch.clamp(pred, 0.0, 1.0)
            if args.cuda:
                pred = pred.cpu()
                gt = gt.cpu()
                burst_noise = burst_noise.cpu()

            noisy_name = '{}_noisy_{:.2f}dB.png'.format(idx, psnr_noisy)
            pred_name = '{}_pred_{:.2f}dB.png'.format(idx, psnr_t)
            gt_name = '{}_gt.png'.format(idx)
            to_pil(burst_noise[0, 0, ...].squeeze()).save(
                os.path.join(eval_dir, noisy_name), quality=100)
            to_pil(pred.squeeze()).save(
                os.path.join(eval_dir, pred_name), quality=100)
            to_pil(gt.squeeze()).save(
                os.path.join(eval_dir, gt_name), quality=100)

            print('{}-th image is OK, with PSNR: {:.2f}dB, SSIM: {:.4f}'.
                  format(idx, psnr_t, ssim_t))
def test_multi(dir, image_size, args):
    """Compare a selected KPN variant against a baseline KPN on 10 bursts.

    Args:
        dir: passed to ``load_data`` as the image source.
        image_size: passed to ``load_data``.
        args: reads ``model_type`` (``attKPN``, ``attWKPN`` or ``KPN``),
            ``checkpoint`` (sub-directory of ``checkpoints/`` for the selected
            model), ``load_type`` (for the baseline checkpoint in
            ``checkpoints/kpn``) and ``save_img``.

    For each of 10 bursts returned by ``load_data`` both models are run and,
    when ``args.save_img`` is non-empty, a three-panel figure (selected model,
    baseline KPN, noisy input) is saved there.  Returns ``None``; an unknown
    ``model_type`` prints a message and returns early.
    """
    color = True
    burst_length = 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # ``device`` is a torch.device; comparing it to the string 'cuda' is
    # always False, so derive the flag from the device type.
    use_cuda = device.type == 'cuda'

    shared_kwargs = dict(color=color,
                         burst_length=burst_length,
                         blind_est=True,
                         kernel_size=[5],
                         sep_conv=False,
                         channel_att=False,
                         spatial_att=False,
                         upMode="bilinear",
                         core_bias=False)
    if args.model_type == "attKPN":
        model = Att_KPN(**shared_kwargs)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN(**shared_kwargs)
    elif args.model_type == "KPN":
        model = KPN(**shared_kwargs)
    else:
        print(" Model type not valid")
        return
    model2 = KPN(**shared_kwargs)

    # ---- selected model ---------------------------------------------------
    checkpoint_dir = "checkpoints/" + args.checkpoint
    if not os.path.exists(checkpoint_dir) or len(
            os.listdir(checkpoint_dir)) == 0:
        print('There is no any checkpoint file in path:{}'.format(
            checkpoint_dir))
    ckpt = load_checkpoint(checkpoint_dir, cuda=use_cuda)
    # Drop the ``module.`` prefix only where it is present.
    model.load_state_dict(OrderedDict(
        (k[7:] if k.startswith('module.') else k, v)
        for k, v in ckpt['state_dict'].items()))

    # ---- baseline KPN -----------------------------------------------------
    checkpoint_dir = "checkpoints/" + "kpn"
    if not os.path.exists(checkpoint_dir) or len(
            os.listdir(checkpoint_dir)) == 0:
        print('There is no any checkpoint file in path:{}'.format(
            checkpoint_dir))
    ckpt = load_checkpoint(checkpoint_dir, cuda=use_cuda,
                           best_or_latest=args.load_type)
    # The baseline weights always go into ``model2``.  Previously one branch
    # loaded them into ``model``, overwriting the selected model and leaving
    # ``model2`` with its initial weights.
    model2.load_state_dict(OrderedDict(
        (k[7:] if k.startswith('module.') else k, v)
        for k, v in ckpt['state_dict'].items()))
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))

    # switch the eval mode
    model.to(device)
    model2.to(device)
    model.eval()
    model2.eval()

    trans = transforms.ToPILImage()
    torch.manual_seed(0)

    save_figures = args.save_img != ''
    if save_figures and not os.path.exists(args.save_img):
        os.makedirs(args.save_img)

    for i in range(10):
        image_noise = load_data(dir, image_size, burst_length)
        begin = time.time()
        burst_noise = image_noise.to(device)
        print(burst_noise.size())

        if color:
            # Merge the burst and channel dims into one for the network.
            b, N, c, h, w = burst_noise.size()
            feedData = burst_noise.view(b, -1, h, w)
        else:
            feedData = burst_noise

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            pred_i, pred = model(feedData,
                                 burst_noise[:, 0:burst_length, ...])
            _, pred2 = model2(feedData,
                              burst_noise[:, 0:burst_length, ...])
        pred = pred.cpu()
        pred2 = pred2.cpu()
        print("Time : ", time.time() - begin)
        print(pred_i.size())
        print(pred.size())

        if save_figures:
            plt.figure(figsize=(10, 3))
            plt.subplot(1, 3, 1)
            plt.imshow(np.array(trans(pred[0])))
            plt.title("denoise attKPN")
            plt.subplot(1, 3, 2)
            plt.imshow(np.array(trans(pred2[0])))
            plt.title("denoise KPN")
            plt.subplot(1, 3, 3)
            plt.imshow(np.array(trans(image_noise[0][0])))
            plt.title("noise ")
            image_name = str(i)
            plt.savefig(os.path.join(
                args.save_img, image_name + "_" + args.checkpoint + '.png'),
                pad_inches=0)
            # Release the figure; otherwise one stays open per iteration.
            plt.close()
data_loader = DataLoader(dataset=data_set, batch_size=1, shuffle=True, num_workers=8) if 'kpn_5x5' in val_dict: kpn_5x5 = KPN(color=color, burst_length=8, blind_est=False, kernel_size=[5], sep_conv=False, channel_att=False, spatial_att=False).cuda() kpn_5x5 = nn.DataParallel(kpn_5x5) state = load_checkpoint('../models/kpn_aug/checkpoint', best_or_latest='best') kpn_5x5.load_state_dict(state['state_dict']) print('KPN with 5x5-sized kernel is loaded for iteration {}!'.format( state['global_iter'])) psnr_kpn_5x5, ssim_kpn_5x5 = [], [] if 'kpn_7x7' in val_dict: kpn_7x7 = KPN(color=color, burst_length=8, blind_est=False, kernel_size=[21], sep_conv=True, channel_att=False, spatial_att=False).cuda() kpn_7x7 = nn.DataParallel(kpn_7x7) state = load_checkpoint('../models/kpn_aug_15x15/checkpoint',
def test_multi(args):
    """Evaluate a KPN-DGF variant on paired noisy/clean validation blocks.

    Reads from ``args``: ``burst_length``, ``model_type`` (``attKPN``,
    ``attKPN_Wave``, ``attWKPN`` or ``KPN``), ``checkpoint`` (sub-directory of
    ``checkpoints/``), ``load_type``, ``noise_dir`` (a ``.mat`` file holding
    ``ValidationNoisyBlocksSrgb``), ``gt`` (a ``.mat`` file holding
    ``ValidationGtBlocksSrgb``) and ``save_img``.

    Prints PSNR/SSIM per block and the averages at the end.  When
    ``args.save_img`` is non-empty, one figure per block is saved there.
    Returns ``None``; an unknown ``model_type`` prints a message and returns
    early.
    """
    color = True
    burst_length = args.burst_length
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    shared_kwargs = dict(color=color,
                         burst_length=burst_length,
                         blind_est=True,
                         kernel_size=[5],
                         sep_conv=False,
                         upMode="bilinear",
                         core_bias=False)
    if args.model_type == "attKPN":
        model = Att_KPN_DGF(channel_att=True, spatial_att=True,
                            **shared_kwargs)
    elif args.model_type == "attKPN_Wave":
        model = Att_KPN_Wavelet_DGF(channel_att=True, spatial_att=True,
                                    **shared_kwargs)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN_DGF(channel_att=True, spatial_att=True,
                                   **shared_kwargs)
    elif args.model_type == "KPN":
        model = KPN_DGF(channel_att=False, spatial_att=False,
                        **shared_kwargs)
    else:
        print(" Model type not valid")
        return

    checkpoint_dir = "checkpoints/" + args.checkpoint
    if not os.path.exists(checkpoint_dir) or len(
            os.listdir(checkpoint_dir)) == 0:
        print('There is no any checkpoint file in path:{}'.format(
            checkpoint_dir))

    # load trained model
    # ``device`` is a torch.device; comparing it to the string 'cuda' is
    # always False, so test the device type instead.
    ckpt = load_checkpoint(checkpoint_dir,
                           cuda=device.type == 'cuda',
                           best_or_latest=args.load_type)
    # Drop the ``module.`` prefix only where it is present.
    new_state_dict = OrderedDict(
        (k[7:] if k.startswith('module.') else k, v)
        for k, v in ckpt['state_dict'].items())
    model.load_state_dict(new_state_dict)
    model.to(device)
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))

    # switch the eval mode
    model.eval()
    trans = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    torch.manual_seed(0)

    all_noisy_imgs = scipy.io.loadmat(
        args.noise_dir)['ValidationNoisyBlocksSrgb']
    all_clean_imgs = scipy.io.loadmat(args.gt)['ValidationGtBlocksSrgb']
    i_imgs, i_blocks, _, _, _ = all_noisy_imgs.shape

    save_figures = args.save_img != ''
    if save_figures and not os.path.exists(args.save_img):
        os.makedirs(args.save_img)

    psnrs = []
    ssims = []
    for i_img in range(i_imgs):
        for i_block in range(i_blocks):
            image_noise = to_tensor(Image.fromarray(
                all_noisy_imgs[i_img][i_block]))
            image_noise, image_noise_hr = load_data(image_noise, burst_length)
            image_noise_hr = image_noise_hr.to(device)
            image_noise = image_noise.to(device)

            if color:
                # Merge the burst and channel dims into one for the network.
                b, N, c, h, w = image_noise.size()
                feedData = image_noise.view(b, -1, h, w)
            else:
                feedData = image_noise

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                _, pred = model(feedData,
                                image_noise[:, 0:burst_length, ...],
                                image_noise_hr)
            pred = pred.cpu()

            gt = to_tensor(Image.fromarray(all_clean_imgs[i_img][i_block]))
            gt = gt.unsqueeze(0)

            psnr_t = calculate_psnr(pred, gt)
            ssim_t = calculate_ssim(pred, gt)
            psnrs.append(psnr_t)
            ssims.append(ssim_t)
            print(i_img, "  ", i_block, "   UP   :  PSNR : ", str(psnr_t),
                  " :  SSIM : ", str(ssim_t))

            if save_figures:
                plt.figure(figsize=(15, 15))
                plt.imshow(np.array(trans(pred[0])))
                plt.title("denoise KPN DGF " + args.model_type, fontsize=25)
                image_name = str(i_img) + "_" + str(i_block)
                plt.axis("off")
                plt.suptitle(image_name + "   UP   :  PSNR : " + str(psnr_t) +
                             " :  SSIM : " + str(ssim_t), fontsize=25)
                plt.savefig(os.path.join(
                    args.save_img,
                    image_name + "_" + args.checkpoint + '.png'),
                    pad_inches=0)
                # Release the figure; otherwise one stays open per block.
                plt.close()

    print("   AVG   :  PSNR : " + str(np.mean(psnrs)) + " :  SSIM : " +
          str(np.mean(ssims)))
def test_multi(args):
    """Denoise every block of a benchmark ``.mat`` file and return the result.

    Reads from ``args``: ``burst_length``, ``model_type`` (``attKPN``,
    ``attKPN_Wave``, ``attWKPN`` or ``KPN``), ``checkpoint`` (checkpoint
    directory, used as given), ``load_type`` and ``noise_dir`` (a ``.mat``
    file holding ``BenchmarkNoisyBlocksSrgb``).

    Returns:
        An array with the same shape and dtype as the noisy input
        (``np.zeros_like``), each block replaced by the model prediction
        converted through ``ToPILImage``.  ``None`` is returned when
        ``model_type`` is not recognised.
    """
    color = True
    burst_length = args.burst_length
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    shared_kwargs = dict(color=color,
                         burst_length=burst_length,
                         blind_est=True,
                         kernel_size=[5],
                         sep_conv=False,
                         upMode="bilinear",
                         core_bias=False)
    if args.model_type == "attKPN":
        model = Att_KPN_DGF(channel_att=True, spatial_att=True,
                            **shared_kwargs)
    elif args.model_type == "attKPN_Wave":
        model = Att_KPN_Wavelet_DGF(channel_att=True, spatial_att=True,
                                    **shared_kwargs)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN_DGF(channel_att=True, spatial_att=True,
                                   **shared_kwargs)
    elif args.model_type == "KPN":
        model = KPN_DGF(channel_att=False, spatial_att=False,
                        **shared_kwargs)
    else:
        print(" Model type not valid")
        return

    checkpoint_dir = args.checkpoint
    if not os.path.exists(checkpoint_dir) or len(
            os.listdir(checkpoint_dir)) == 0:
        print('There is no any checkpoint file in path:{}'.format(
            checkpoint_dir))

    # load trained model
    # ``device`` is a torch.device; comparing it to the string 'cuda' is
    # always False, so test the device type instead.
    ckpt = load_checkpoint(checkpoint_dir,
                           cuda=device.type == 'cuda',
                           best_or_latest=args.load_type)
    # Drop the ``module.`` prefix only where it is present.
    new_state_dict = OrderedDict(
        (k[7:] if k.startswith('module.') else k, v)
        for k, v in ckpt['state_dict'].items())
    model.load_state_dict(new_state_dict)
    model.to(device)
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))

    # switch the eval mode
    model.eval()
    trans = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    torch.manual_seed(0)

    all_noisy_imgs = scipy.io.loadmat(
        args.noise_dir)['BenchmarkNoisyBlocksSrgb']
    mat_re = np.zeros_like(all_noisy_imgs)
    i_imgs, i_blocks, _, _, _ = all_noisy_imgs.shape

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for i_img in range(i_imgs):
            for i_block in range(i_blocks):
                image_noise = to_tensor(Image.fromarray(
                    all_noisy_imgs[i_img][i_block]))
                image_noise, image_noise_hr = load_data(image_noise,
                                                        burst_length)
                image_noise_hr = image_noise_hr.to(device)
                burst_noise = image_noise.to(device)

                if color:
                    # Merge the burst and channel dims into one.
                    b, N, c, h, w = burst_noise.size()
                    feedData = burst_noise.view(b, -1, h, w)
                else:
                    feedData = burst_noise

                _, pred = model(feedData,
                                burst_noise[:, 0:burst_length, ...],
                                image_noise_hr)
                pred = pred.cpu()
                mat_re[i_img][i_block] = np.array(trans(pred[0]))

    return mat_re
def train(config, num_workers, num_threads, cuda, restart_train, mGPU):
    """Train a grayscale ``KPN`` from a parsed configuration.

    Args:
        config: mapping with ``'training'`` and ``'architecture'`` sections.
        num_workers: worker count for the ``DataLoader``.
        num_threads: accepted for interface compatibility; not used.
        cuda: move the model and each batch to the GPU when true.
        restart_train: when true, start from scratch instead of resuming from
            the best checkpoint in the configured checkpoint directory.
        mGPU: wrap the model in ``nn.DataParallel`` when true.

    The configured logs directory is wiped before a new ``SummaryWriter`` is
    created.  A checkpoint is saved every ``save_freq`` global steps.
    """
    train_config = config['training']
    arch_config = config['architecture']

    batch_size = train_config['batch_size']
    lr = train_config['learning_rate']
    weight_decay = train_config['weight_decay']
    lr_decay = train_config['lr_decay']
    n_epoch = train_config['num_epochs']

    print('Configs:', config)

    # checkpoint path
    checkpoint_dir = train_config['checkpoint_dir']
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # logs path: created if missing so rmtree cannot fail, then wiped
    logs_dir = train_config['logs_dir']
    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)
    shutil.rmtree(logs_dir)
    log_writer = SummaryWriter(logs_dir)

    # dataset and dataloader
    data_set = TrainDataSet(train_config['dataset_configs'],
                            img_format='.bmp',
                            degamma=True,
                            color=False,
                            blind=arch_config['blind_est'])
    data_loader = DataLoader(data_set,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=num_workers)
    dataset_config = read_config(train_config['dataset_configs'],
                                 _configspec_path())['dataset_configs']

    # model here
    model = KPN(color=False,
                burst_length=dataset_config['burst_length'],
                blind_est=arch_config['blind_est'],
                kernel_size=list(map(int, arch_config['kernel_size'].split())),
                sep_conv=arch_config['sep_conv'],
                channel_att=arch_config['channel_att'],
                spatial_att=arch_config['spatial_att'],
                upMode=arch_config['upMode'],
                core_bias=arch_config['core_bias'])
    if cuda:
        model = model.cuda()
    if mGPU:
        model = nn.DataParallel(model)
    model.train()

    # loss function here
    loss_func = LossFunc(coeff_basic=1.0,
                         coeff_anneal=1.0,
                         gradient_L1=True,
                         alpha=arch_config['alpha'],
                         beta=arch_config['beta'])

    # Optimizer here
    if train_config['optimizer'] == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=lr)
    elif train_config['optimizer'] == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=lr,
                              momentum=0.9,
                              weight_decay=weight_decay)
    else:
        raise ValueError(
            "Optimizer must be 'sgd' or 'adam', but received {}.".format(
                train_config['optimizer']))
    optimizer.zero_grad()

    # learning rate scheduler here
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=lr_decay)

    average_loss = MovingAverage(train_config['save_freq'])

    if not restart_train:
        try:
            # Pass the selector by keyword: other call sites in this file use
            # the second positional slot for the ``cuda`` flag.
            checkpoint = load_checkpoint(checkpoint_dir,
                                         best_or_latest='best')
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_iter']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['lr_scheduler'])
            print('=> loaded checkpoint (epoch {}, global_step {})'.format(
                start_epoch, global_step))
        except Exception:
            # Any failure to restore means "start fresh"; Ctrl-C is no longer
            # swallowed here as it was with a bare ``except``.
            start_epoch = 0
            global_step = 0
            best_loss = np.inf
            print('=> no checkpoint file to be loaded.')
    else:
        start_epoch = 0
        global_step = 0
        best_loss = np.inf
        # checkpoint_dir already exists (created above); existing checkpoint
        # files are deliberately kept.
        print('=> training')

    burst_length = dataset_config['burst_length']

    for epoch in range(start_epoch, n_epoch):
        epoch_start_time = time.time()

        # decay the learning rate, holding it at 5e-6 once it gets that low
        lr_cur = [param['lr'] for param in optimizer.param_groups]
        if lr_cur[0] > 5e-6:
            scheduler.step()
        else:
            for param in optimizer.param_groups:
                param['lr'] = 5e-6
        print(
            '=' * 20,
            'lr={}'.format([param['lr'] for param in optimizer.param_groups]),
            '=' * 20)

        t1 = time.time()
        for step, (burst_noise, gt, white_level) in enumerate(data_loader):
            if cuda:
                burst_noise = burst_noise.cuda()
                gt = gt.cuda()

            pred_i, pred = model(burst_noise,
                                 burst_noise[:, 0:burst_length, ...],
                                 white_level)

            # The loss is evaluated after the sRGBGamma transform.
            loss_basic, loss_anneal = loss_func(sRGBGamma(pred_i),
                                                sRGBGamma(pred),
                                                sRGBGamma(gt),
                                                global_step)
            loss = loss_basic + loss_anneal

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update the average loss
            average_loss.update(loss)

            # calculate PSNR / SSIM
            psnr = calculate_psnr(pred.unsqueeze(1), gt.unsqueeze(1))
            ssim = calculate_ssim(pred.unsqueeze(1), gt.unsqueeze(1))

            # add scalars to tensorboardX
            log_writer.add_scalar('loss_basic', loss_basic, global_step)
            log_writer.add_scalar('loss_anneal', loss_anneal, global_step)
            log_writer.add_scalar('loss_total', loss, global_step)
            log_writer.add_scalar('psnr', psnr, global_step)
            log_writer.add_scalar('ssim', ssim, global_step)

            print(
                '{:-4d}\t| epoch {:2d}\t| step {:4d}\t| loss_basic: {:.4f}\t| loss_anneal: {:.4f}\t|'
                ' loss: {:.4f}\t| PSNR: {:.2f}dB\t| SSIM: {:.4f}\t| time:{:.2f} seconds.'
                .format(global_step, epoch, step, loss_basic, loss_anneal,
                        loss, psnr, ssim, time.time() - t1))
            t1 = time.time()

            # global_step
            global_step += 1

            if global_step % train_config['save_freq'] == 0:
                if average_loss.get_value() < best_loss:
                    is_best = True
                    best_loss = average_loss.get_value()
                else:
                    is_best = False
                save_dict = {
                    'epoch': epoch,
                    'global_iter': global_step,
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': scheduler.state_dict()
                }
                save_checkpoint(save_dict,
                                is_best,
                                checkpoint_dir,
                                global_step,
                                max_keep=train_config['ckpt_to_keep'])

        print('Epoch {} is finished, time elapsed {:.2f} seconds.'.format(
            epoch, time.time() - epoch_start_time))
def test_multi(args):
    """Denoise every ``*.png`` in ``args.noise_dir`` with a KPN-DGF model.

    For each noisy image the matching ground-truth path is derived by
    replacing ``"noisy"`` with ``"clean"`` in the file path.  PSNR/SSIM are
    printed for both the model's final prediction and the pixel-shuffled
    per-frame prediction, and the final prediction is optionally saved as a
    matplotlib figure.

    Attributes read from ``args``:
        burst_length: burst size handed to the model and to ``load_data``.
        model_type: one of ``"attKPN"``, ``"attKPN_Wave"``, ``"attWKPN"``,
            ``"KPN"``.
        checkpoint: sub-directory of ``checkpoints/`` holding the weights;
            also used as a suffix of the saved figure names.
        load_type: forwarded to ``load_checkpoint`` as ``best_or_latest``.
        noise_dir: directory that is globbed for ``*.png`` inputs.
        save_img: output directory for figures; ``''`` disables saving.

    Returns:
        None.  Returns early (after printing a message) when the model type
        is unknown or the checkpoint directory is missing or empty.
    """
    color = True
    burst_length = args.burst_length
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Every variant takes the same constructor arguments apart from the two
    # attention switches.
    common_kwargs = dict(color=color,
                         burst_length=burst_length,
                         blind_est=True,
                         kernel_size=[5],
                         sep_conv=False,
                         upMode="bilinear",
                         core_bias=False)
    if args.model_type == "attKPN":
        model = Att_KPN_DGF(channel_att=True, spatial_att=True,
                            **common_kwargs)
    elif args.model_type == "attKPN_Wave":
        model = Att_KPN_Wavelet_DGF(channel_att=True, spatial_att=True,
                                    **common_kwargs)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN_DGF(channel_att=True, spatial_att=True,
                                   **common_kwargs)
    elif args.model_type == "KPN":
        model = KPN_DGF(channel_att=False, spatial_att=False,
                        **common_kwargs)
    else:
        print(" Model type not valid")
        return

    checkpoint_dir = "checkpoints/" + args.checkpoint
    if not os.path.exists(checkpoint_dir) or not os.listdir(checkpoint_dir):
        print('There is no any checkpoint file in path:{}'.format(
            checkpoint_dir))
        # Nothing to load: stop here instead of failing inside
        # load_checkpoint.
        return

    # load trained model
    # Compare the device *type*; comparing the torch.device object itself to
    # a string is not a reliable way to detect CUDA.
    ckpt = load_checkpoint(checkpoint_dir,
                           cuda=device.type == 'cuda',
                           best_or_latest=args.load_type)

    # Drop the leading `module.` from parameter names (presumably written by
    # an nn.DataParallel wrapper at training time -- confirm).  Keys without
    # that prefix are kept untouched rather than blindly truncated.
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in ckpt['state_dict'].items())
    model.load_state_dict(new_state_dict)
    model.to(device)
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))

    # switch the eval mode
    model.eval()
    trans = transforms.ToPILImage()
    torch.manual_seed(0)

    noisy_path = sorted(glob.glob(args.noise_dir + "/*.png"))
    upscale_factor = int(math.sqrt(burst_length))

    for i, noisy_file in enumerate(noisy_path):
        clean_file = noisy_file.replace("noisy", "clean")

        image_noise, image_noise_hr = load_data(noisy_file, burst_length)
        image_noise_hr = image_noise_hr.to(device)
        burst_noise = image_noise.to(device)

        if color:
            # Fold the frame and channel axes of the 5-D burst into one
            # axis for the network input.
            b, _, _, h, w = burst_noise.size()
            feed_data = burst_noise.view(b, -1, h, w)
        else:
            feed_data = burst_noise

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            pred_i, pred = model(feed_data,
                                 burst_noise[:, 0:burst_length, ...],
                                 image_noise_hr)
        pred_i = pred_i.cpu()
        print(pred_i.size())
        pred_full = pixel_shuffle(pred_i, upscale_factor)
        print(pred_full.size())
        pred = pred.cpu()

        gt = transforms.ToTensor()(Image.open(clean_file).convert('RGB'))
        gt = gt.unsqueeze(0)

        psnr_t = calculate_psnr(pred, gt)
        ssim_t = calculate_ssim(pred, gt)
        print(i, " pixel_shuffle UP : PSNR : ",
              str(calculate_psnr(pred_full, gt)), " : SSIM : ",
              str(calculate_ssim(pred_full, gt)))
        print(i, " UP : PSNR : ", str(psnr_t), " : SSIM : ", str(ssim_t))

        if args.save_img != '':
            if not os.path.exists(args.save_img):
                os.makedirs(args.save_img)
            plt.figure(figsize=(15, 15))
            plt.imshow(np.array(trans(pred[0])))
            plt.title("denoise KPN DGF " + args.model_type, fontsize=25)
            image_name = noisy_file.split("/")[-1].split(".")[0]
            plt.axis("off")
            plt.suptitle(image_name + " UP : PSNR : " + str(psnr_t) +
                         " : SSIM : " + str(ssim_t),
                         fontsize=25)
            plt.savefig(os.path.join(
                args.save_img, image_name + "_" + args.checkpoint + '.png'),
                        pad_inches=0)
            # Release the figure; otherwise one figure per image stays
            # alive for the whole run.
            plt.close()


# NOTE(review): the original text had a dangling triple-quote delimiter
# directly after this function (apparently the start of a commented-out
# region); it is recorded here as a comment instead of a live token.
# Imports of the alternative pipeline (conversion helpers and the
# DenoiseNet-based denoiser), currently disabled:
# from convert.converter import torch2onnx, onnx2keras, keras2tflite
# from denoiser.networks.denoising_rgb import DenoiseNet
# from denoiser.utils import load_checkpoint
from model.MIRNet import MIRNet
from utils.training_util import save_checkpoint, MovingAverage, load_checkpoint

# Hard-coded locations and node names used by this experiment.
img_path = '/home/dell/Downloads/FullTest/noisy/2_1.png'
model_path = '../../denoiser/pretrained_models/denoising/sidd_rgb.pth'
output_path = 'models/denoiser_rgb.onnx'
input_node_names = ['input_image']
output_nodel_names = ['output_image']

# Disabled DenoiseNet variant:
# torch_model = DenoiseNet()
# load_checkpoint(torch_model, model_path, 'cpu')

# Fetch the latest MIRNet checkpoint and pull out its weights.
checkpoint = load_checkpoint("../checkpoints/mir/", False, 'latest')
state_dict = checkpoint['state_dict']

torch_model = MIRNet()
print(torch_model)
# NOTE(review): the script stops here, right after printing the model;
# none of the statements below are ever executed.
exit(0)

torch_model.load_state_dict(state_dict)
torch_model.eval()

# Read the test image, keep its top-left 256x256 crop and scale it to
# floats by dividing by 255.
img = imageio.imread(img_path)[0:256, 0:256, :]
print(img.shape)
img = np.asarray(img, dtype=np.float32) / 255.

# Move the last axis to the front and add a leading batch axis
# (presumably HWC -> 1xCxHxW; confirm against imageio's output layout).
img_tensor = torch.from_numpy(img)
img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)

print('Test forward pass')
def test_multi(args):
    """Denoise every ``*.png`` in ``args.noise_dir`` with a KPN-family model.

    The ground-truth path of each image is derived by replacing ``"noisy"``
    with ``"clean"`` in the noisy path.  Prediction and ground truth may
    differ in spatial size, so metrics are printed twice: prediction
    bilinearly resized to the ground-truth size ("UP") and ground truth
    resized to the prediction size ("DOWN").  The up-sampled prediction is
    optionally saved as a matplotlib figure.

    Attributes read from ``args``:
        model_type: one of ``"attKPN"``, ``"attWKPN"``, ``"KPN"``.
        checkpoint: sub-directory of ``checkpoints/`` holding the weights;
            also used as a suffix of the saved figure names.
        load_type: forwarded to ``load_checkpoint`` as ``best_or_latest``.
        noise_dir: directory that is globbed for ``*.png`` inputs.
        save_img: output directory for figures; ``''`` disables saving.

    Returns:
        None.  Returns early (after printing a message) when the model type
        is unknown or the checkpoint directory is missing or empty.
    """
    color = True
    burst_length = 8  # fixed burst size used by this variant
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Every variant takes the same constructor arguments apart from the two
    # attention switches.
    common_kwargs = dict(color=color,
                         burst_length=burst_length,
                         blind_est=True,
                         kernel_size=[5],
                         sep_conv=False,
                         upMode="bilinear",
                         core_bias=False)
    if args.model_type == "attKPN":
        model = Att_KPN(channel_att=True, spatial_att=True, **common_kwargs)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN(channel_att=True, spatial_att=True,
                               **common_kwargs)
    elif args.model_type == "KPN":
        model = KPN(channel_att=False, spatial_att=False, **common_kwargs)
    else:
        print(" Model type not valid")
        return

    checkpoint_dir = "checkpoints/" + args.checkpoint
    if not os.path.exists(checkpoint_dir) or not os.listdir(checkpoint_dir):
        print('There is no any checkpoint file in path:{}'.format(
            checkpoint_dir))
        # Nothing to load: stop here instead of failing inside
        # load_checkpoint.
        return

    # load trained model
    # Compare the device *type*; comparing the torch.device object itself to
    # a string is not a reliable way to detect CUDA.
    ckpt = load_checkpoint(checkpoint_dir,
                           cuda=device.type == 'cuda',
                           best_or_latest=args.load_type)

    # Drop the leading `module.` from parameter names (presumably written by
    # an nn.DataParallel wrapper at training time -- confirm).  Keys without
    # that prefix are kept untouched rather than blindly truncated.
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in ckpt['state_dict'].items())
    model.load_state_dict(new_state_dict)
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))

    # switch the eval mode
    model.to(device)
    model.eval()
    trans = transforms.ToPILImage()
    torch.manual_seed(0)

    noisy_path = sorted(glob.glob(args.noise_dir + "/*.png"))

    for i, noisy_file in enumerate(noisy_path):
        clean_file = noisy_file.replace("noisy", "clean")

        burst_noise = load_data(noisy_file, burst_length).to(device)
        if color:
            # Fold the frame and channel axes of the 5-D burst into one
            # axis for the network input.
            b, _, _, h, w = burst_noise.size()
            feed_data = burst_noise.view(b, -1, h, w)
        else:
            feed_data = burst_noise

        # Pure inference: no autograd graph is needed.  Only the fused
        # prediction is used; the per-frame output is discarded.
        with torch.no_grad():
            _, pred = model(feed_data, burst_noise[:, 0:burst_length, ...])
        pred = pred.cpu()

        gt = transforms.ToTensor()(Image.open(clean_file).convert('RGB'))
        gt = gt.unsqueeze(0)

        # Bring prediction and ground truth to a common size in both
        # directions.
        _, _, h_hr, w_hr = gt.size()
        _, _, h_lr, w_lr = pred.size()
        gt_down = F.interpolate(gt, (h_lr, w_lr),
                                mode='bilinear',
                                align_corners=True)
        pred_up = F.interpolate(pred, (h_hr, w_hr),
                                mode='bilinear',
                                align_corners=True)

        psnr_t_up = calculate_psnr(pred_up, gt)
        ssim_t_up = calculate_ssim(pred_up, gt)
        psnr_t_down = calculate_psnr(pred, gt_down)
        ssim_t_down = calculate_ssim(pred, gt_down)
        print(i, " UP : PSNR : ", str(psnr_t_up), " : SSIM : ",
              str(ssim_t_up), " : DOWN : PSNR : ", str(psnr_t_down),
              " : SSIM : ", str(ssim_t_down))

        if args.save_img != '':
            if not os.path.exists(args.save_img):
                os.makedirs(args.save_img)
            plt.figure(figsize=(15, 15))
            plt.imshow(np.array(trans(pred_up[0])))
            plt.title("denoise KPN split " + args.model_type, fontsize=25)
            image_name = noisy_file.split("/")[-1].split(".")[0]
            plt.axis("off")
            plt.suptitle(image_name + " UP : PSNR : " + str(psnr_t_up) +
                         " : SSIM : " + str(ssim_t_up),
                         fontsize=25)
            plt.savefig(os.path.join(
                args.save_img, image_name + "_" + args.checkpoint + '.png'),
                        pad_inches=0)
            # Release the figure; otherwise one figure per image stays
            # alive for the whole run.
            plt.close()


# NOTE(review): the original text had a dangling triple-quote delimiter
# directly after this function (apparently the end of a commented-out
# region); it is recorded here as a comment instead of a live token.
def eval(args):
    """Evaluate a KPN-DGF model on the first 100 samples of a dataset.

    Builds a ``SingleLoader_DGF`` dataset from ``args.noise_dir`` /
    ``args.gt_dir``, restores the weights found under
    ``checkpoints/<args.checkpoint>``, prints PSNR/SSIM per sample and,
    when ``args.save_img`` is truthy, writes the noisy input, the clamped
    prediction and the ground truth as PNG files into ``eval_img/``.

    Attributes read from ``args``: ``color``, ``burst_length``,
    ``checkpoint``, ``noise_dir``, ``gt_dir``, ``image_size``,
    ``num_workers``, ``model_type`` (``"attKPN"``, ``"attWKPN"`` or
    ``"KPN"``), ``cuda``, ``mGPU``, ``load_type`` and ``save_img``.

    Returns:
        None.  Returns early (after printing a message) when the model type
        is unknown or the checkpoint directory is missing or empty.

    NOTE: the name shadows the builtin ``eval``; it is kept because callers
    refer to it.
    """
    color = args.color
    print('Eval Process......')
    burst_length = args.burst_length

    checkpoint_dir = "checkpoints/" + args.checkpoint
    if not os.path.exists(checkpoint_dir) or not os.listdir(checkpoint_dir):
        print('There is no any checkpoint file in path:{}'.format(
            checkpoint_dir))
        # Nothing to load: stop here instead of failing inside
        # load_checkpoint.
        return

    # the path for saving eval images
    eval_dir = "eval_img"
    if not os.path.exists(eval_dir):
        os.mkdir(eval_dir)

    # dataset and dataloader
    data_set = SingleLoader_DGF(noise_dir=args.noise_dir,
                                gt_dir=args.gt_dir,
                                image_size=args.image_size,
                                burst_length=burst_length)
    data_loader = DataLoader(data_set,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.num_workers)

    # model here
    # NOTE(review): every variant, including the attention ones, is built
    # with channel_att=False / spatial_att=False here -- confirm that this
    # matches how the checkpoints were trained.
    common_kwargs = dict(color=color,
                         burst_length=burst_length,
                         blind_est=True,
                         kernel_size=[5],
                         sep_conv=False,
                         channel_att=False,
                         spatial_att=False,
                         upMode="bilinear",
                         core_bias=False)
    if args.model_type == "attKPN":
        model = Att_KPN_DGF(**common_kwargs)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN_DGF(**common_kwargs)
    elif args.model_type == "KPN":
        model = KPN_DGF(**common_kwargs)
    else:
        print(" Model type not valid")
        return

    if args.cuda:
        model = model.cuda()
    if args.mGPU:
        model = nn.DataParallel(model)

    # load trained model
    ckpt = load_checkpoint(checkpoint_dir,
                           cuda=args.cuda,
                           best_or_latest=args.load_type)
    state_dict = ckpt['state_dict']
    if not args.mGPU:
        # A bare (non-DataParallel) model has no `module.` level in its
        # parameter names, so strip that prefix wherever it is present.
        # This is done regardless of args.cuda: previously the renaming only
        # ran on CPU, which left nothing loadable for single-GPU runs.
        prefix = 'module.'
        state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items())
    model.load_state_dict(state_dict)
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))

    # switch the eval mode
    model.eval()
    trans = transforms.ToPILImage()
    max_samples = 100  # only the first 100 samples are evaluated

    with torch.no_grad():
        torch.manual_seed(0)
        for i, (image_noise_hr, image_noise_lr,
                image_gt_hr) in enumerate(data_loader):
            if i >= max_samples:
                break

            if args.cuda:
                # Every tensor handed to the model must live on the same
                # device as the model itself, including the high-resolution
                # input.
                burst_noise = image_noise_lr.cuda()
                gt = image_gt_hr.cuda()
                image_noise_hr = image_noise_hr.cuda()
            else:
                burst_noise = image_noise_lr
                gt = image_gt_hr

            if color:
                # Fold the frame and channel axes of the 5-D burst into one
                # axis for the network input.
                b, _, _, h, w = burst_noise.size()
                feed_data = burst_noise.view(b, -1, h, w)
            else:
                feed_data = burst_noise

            _, pred = model(feed_data, burst_noise[:, 0:burst_length, ...],
                            image_noise_hr)

            # Metrics are computed on the raw prediction, before clamping.
            psnr_t = calculate_psnr(pred, gt)
            ssim_t = calculate_ssim(pred, gt)
            print("PSNR : ", str(psnr_t), " : SSIM : ", str(ssim_t))

            pred = torch.clamp(pred, 0.0, 1.0)
            if args.cuda:
                pred = pred.cpu()
                gt = gt.cpu()
                burst_noise = burst_noise.cpu()

            if args.save_img:
                trans(burst_noise[0, 0, ...].squeeze()).save(os.path.join(
                    eval_dir, '{}_noisy.png'.format(i)),
                                                             quality=100)
                trans(pred.squeeze()).save(os.path.join(
                    eval_dir, '{}_pred_{:.2f}dB.png'.format(i, psnr_t)),
                                           quality=100)
                trans(gt.squeeze()).save(os.path.join(
                    eval_dir, '{}_gt.png'.format(i)),
                                         quality=100)
def train(config, restart_training, num_workers, num_threads):
    """Train the model described by ``config`` on on-the-fly generated data.

    Args:
        config: nested dict.  ``config["training"]`` supplies
            ``batch_size``, ``learning_rate``, ``weight_decay``,
            ``decay_steps``, ``lr_decay``, ``beta1``, ``beta2``,
            ``num_epochs``, ``dataset_configs``, ``use_cache``,
            ``checkpoint_dir``, ``logs_dir``, ``image_width``,
            ``image_height``, ``optimizer`` (``"adam"`` or ``"sgd"``),
            ``n_loss_average``, ``vis_freq`` and ``save_freq``;
            ``config["architecture"]`` is passed to ``get_model``.
        restart_training: when true, skip checkpoint loading and start from
            epoch 0.
        num_workers: worker count for the training ``DataLoader``.
        num_threads: value passed to ``torch.set_num_threads``.

    Raises:
        ValueError: if ``config["training"]["optimizer"]`` is neither
            ``"adam"`` nor ``"sgd"``.

    Side effects: creates the checkpoint directory if needed, writes
    scalar/image summaries through ``Logger`` and saves a checkpoint every
    ``save_freq`` global iterations.
    """
    torch.set_num_threads(num_threads)
    print("Using {} CPU threads".format(torch.get_num_threads()))

    train_config = config["training"]
    batch_size = train_config["batch_size"]
    lr = train_config["learning_rate"]
    w_decay = train_config["weight_decay"]
    step_size = train_config["decay_steps"]
    gamma = train_config["lr_decay"]
    betas = (train_config["beta1"], train_config["beta2"])
    n_epochs = train_config["num_epochs"]
    dataset_configs = train_config["dataset_configs"]
    use_cache = train_config["use_cache"]
    print("Configs:", config)

    # create dir for model
    checkpoint_dir = train_config["checkpoint_dir"]
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    logger = Logger(train_config["logs_dir"])

    use_gpu = torch.cuda.is_available()
    num_gpu = list(range(torch.cuda.device_count()))

    print("Using On the fly TRAIN datasets")
    train_data = OnTheFlyDataset(dataset_configs,
                                 im_size=(train_config["image_width"],
                                          train_config["image_height"]),
                                 use_cache=use_cache)
    train_loader = DataLoader(train_data,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers)

    model = get_model(config["architecture"])
    l1_loss = nn.SmoothL1Loss()

    if use_gpu:
        ts = time.time()
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=num_gpu)
        print("Finish cuda loading, time elapsed {}".format(time.time() - ts))

    # Only parameters that require gradients are optimised.
    all_parameters = [
        p for _, p in model.named_parameters() if p.requires_grad
    ]
    if train_config["optimizer"] == "adam":
        print("Using Adam.")
        optimizer = optim.Adam([
            {
                'params': all_parameters
            },
        ],
                               lr=lr,
                               betas=betas,
                               weight_decay=w_decay,
                               amsgrad=True)
    elif train_config["optimizer"] == "sgd":
        print("Using SGD.")
        # beta1 doubles as the SGD momentum.
        optimizer = optim.SGD([
            {
                'params': all_parameters
            },
        ],
                              lr=lr,
                              momentum=betas[0],
                              weight_decay=w_decay)
    else:
        raise ValueError(
            "Optimizer must be 'sgd' or 'adam', received '{}'".format(
                train_config["optimizer"]))

    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=step_size,
                                    gamma=gamma)

    n_global_iter = 0
    average_loss = MovingAverage(train_config["n_loss_average"])
    best_loss = np.inf
    start_epoch = 0

    if not restart_training:
        # Best-effort resume: any failure falls back to a fresh start.
        # `Exception` (not a bare except) so Ctrl-C / SystemExit still
        # propagate, and the reason for the failure is reported.
        try:
            checkpoint = load_checkpoint(checkpoint_dir, 'best')
            start_epoch = checkpoint['epoch']
            n_global_iter = checkpoint['global_iter']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        except Exception as err:
            start_epoch = 0
            n_global_iter = 0
            best_loss = np.inf
            print("=> load checkpoint failed, training from scratch")
            print("   reason: {!r}".format(err))
    else:
        print("=> training from scratch")

    # Timing statistics are averaged every N_report iterations and printed
    # every N_print iterations.
    N_report = 100
    N_print = 1000

    for epoch in range(start_epoch, n_epochs):
        # NOTE(review): PyTorch >= 1.1 expects scheduler.step() *after* the
        # epoch's optimizer.step() calls; stepping first shifts the schedule
        # by one epoch there.  Left in place to keep the existing schedule --
        # confirm the target PyTorch version before moving it.
        scheduler.step()
        ts = time.time()

        # t0..t4 are per-iteration timestamps; t4 is None until the first
        # iteration of the epoch has completed.
        t4 = None
        t_generate_data = []
        t_train_disc = []
        t_train_gen = []
        t_vis = []
        t_save = []

        for step, batch in enumerate(train_loader):
            if t4 is not None:
                # Keep the previous iteration's start time before it is
                # overwritten.
                t0_old = t0
            t0 = time.time()
            if t4 is not None:
                # collect information and print out average time.
                t_generate_data.append(t0 - t4)
                t_train_disc.append(t1 - t0_old)
                t_train_gen.append(t2 - t1)
                t_vis.append(t3 - t2)
                t_save.append(t4 - t3)
                if (step % N_report) == 0:
                    t_generate_data = np.mean(t_generate_data)
                    t_train_disc = np.mean(t_train_disc)
                    t_train_gen = np.mean(t_train_gen)
                    t_vis = np.mean(t_vis)
                    t_save = np.mean(t_save)
                    t_total = (t_generate_data + t_train_disc + t_train_gen +
                               t_vis + t_save)
                    if (step % N_print) == 0:
                        print("t_generate_data: {:0.4g} s ({:0.4g}%)".format(
                            t_generate_data,
                            t_generate_data / t_total * 100))
                        print("t_train_disc: {:0.4g} s ({:0.4g}%)".format(
                            t_train_disc, t_train_disc / t_total * 100))
                        print("t_train_gen: {:0.4g} s ({:0.4g}%)".format(
                            t_train_gen, t_train_gen / t_total * 100))
                        print("t_vis: {:0.4g} s ({:0.4g}%)".format(
                            t_vis, t_vis / t_total * 100))
                        print("t_save: {:0.4g} s ({:0.4g}%)".format(
                            t_save, t_save / t_total * 100))
                    logger.scalar_summary('Steps per sec', 1.0 / t_total,
                                          n_global_iter)
                    # Start a fresh averaging window.
                    t_generate_data = []
                    t_train_disc = []
                    t_train_gen = []
                    t_vis = []
                    t_save = []

            should_vis = ((n_global_iter + 1) %
                          train_config["vis_freq"]) == 0

            if use_gpu:
                degraded_img = batch['degraded_img'].cuda()
                target_img = batch['original_img'].cuda()
            else:
                degraded_img = batch['degraded_img']
                target_img = batch['original_img']
            t1 = time.time()

            optimizer.zero_grad()
            # Run the input through the model.
            output_img = model(degraded_img)
            loss = l1_loss(output_img, target_img)
            loss.backward()
            optimizer.step()

            # `.item()` extracts the Python scalar; indexing a 0-dim loss
            # tensor with `[0]` fails on current PyTorch.
            loss_value = loss.item()
            logger.scalar_summary('Loss', loss_value, n_global_iter)
            psnr = calculate_psnr(output_img, target_img)
            logger.scalar_summary('Train PSNR', psnr, n_global_iter)
            average_loss.update(loss_value)
            t2 = time.time()

            if step % 10 == 0:
                print("epoch{}, iter{}, loss: {}".format(
                    epoch, step, loss_value))
            n_global_iter += 1

            if should_vis:
                exp = batch['vis_exposure'] if 'vis_exposure' in batch else None
                img = create_vis(degraded_img[:, :3, ...], target_img,
                                 output_img, exp)
                logger.image_summary("Train Images", img, n_global_iter)
            t3 = time.time()

            if (n_global_iter % train_config["save_freq"]) == 0:
                # A checkpoint counts as "best" when the moving-average loss
                # improves on the best value seen so far.
                if average_loss.get_value() < best_loss:
                    is_best = True
                    best_loss = average_loss.get_value()
                else:
                    is_best = False
                save_dict = {
                    'epoch': epoch,
                    'global_iter': n_global_iter,
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                }
                save_checkpoint(save_dict, is_best, checkpoint_dir,
                                n_global_iter)
            t4 = time.time()

        print("Finish epoch {}, time elapsed {}".format(
            epoch, time.time() - ts))