def test(args):
    """Evaluate MIRNet / MIRNet_kpn on the SIDD+ sRGB validation set.

    Loads the latest checkpoint from ``args.checkpoint`` and denoises every image
    of the ``siddplus_valid_noisy_srgb`` array stored in the .mat file
    ``args.noise_dir``. It compares each result with ``siddplus_valid_gt_srgb``
    from ``args.gt`` and prints per-image and average PSNR/SSIM. When
    ``args.save_img`` is non-empty, a titled figure of every prediction is saved
    into that directory.

    Args:
        args: parsed CLI namespace; reads ``model_type`` ("MIR" or "KPN"),
            ``checkpoint``, ``noise_dir``, ``gt`` and ``save_img``.
    """
    if args.model_type == "MIR":
        model = MIRNet()
    elif args.model_type == "KPN":
        model = MIRNet_kpn()
    else:
        print(" Model type not valid")
        return

    checkpoint_dir = args.checkpoint
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Compare the device *type*: a torch.device never equals a plain str, so
    # ``device == 'cuda'`` was always False.
    checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda', 'latest')
    start_epoch = checkpoint['epoch']
    global_step = checkpoint['global_iter']
    model.load_state_dict(checkpoint['state_dict'])
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(start_epoch, global_step))

    model.eval()
    model = model.to(device)
    trans = transforms.ToPILImage()
    torch.manual_seed(0)
    all_noisy_imgs = scipy.io.loadmat(args.noise_dir)['siddplus_valid_noisy_srgb']
    all_clean_imgs = scipy.io.loadmat(args.gt)['siddplus_valid_gt_srgb']
    # Only the leading (image) axis is used; the array must be 4-D.
    i_imgs, _, _, _ = all_noisy_imgs.shape

    save_enabled = args.save_img != ''
    if save_enabled:
        os.makedirs(args.save_img, exist_ok=True)
    # args.checkpoint is a directory path (e.g. "checkpoints/mir/"); embedding it
    # verbatim in a file name would point into a non-existent subdirectory.
    checkpoint_tag = os.path.basename(os.path.normpath(args.checkpoint))

    psnrs = []
    ssims = []
    for i_img in range(i_imgs):
        noise = transforms.ToTensor()(Image.fromarray(all_noisy_imgs[i_img])).unsqueeze(0)
        noise = noise.to(device)
        # Inference only: do not build an autograd graph.
        with torch.no_grad():
            pred = model(noise)
        pred = pred.detach().cpu()
        gt = transforms.ToTensor()(Image.fromarray(all_clean_imgs[i_img])).unsqueeze(0)
        psnr_t = calculate_psnr(pred, gt)
        ssim_t = calculate_ssim(pred, gt)
        psnrs.append(psnr_t)
        ssims.append(ssim_t)
        print(i_img, " UP : PSNR : ", str(psnr_t), " : SSIM : ", str(ssim_t))
        if save_enabled:
            fig = plt.figure(figsize=(15, 15))
            plt.imshow(np.array(trans(pred[0])))
            plt.title("denoise KPN DGF " + args.model_type, fontsize=25)
            image_name = str(i_img) + "_"
            plt.axis("off")
            plt.suptitle(image_name + " UP : PSNR : " + str(psnr_t) + " : SSIM : " + str(ssim_t),
                         fontsize=25)
            plt.savefig(os.path.join(args.save_img, image_name + "_" + checkpoint_tag + '.png'),
                        pad_inches=0)
            # Release the figure; otherwise one open figure per image accumulates.
            plt.close(fig)
    print(" AVG : PSNR : " + str(np.mean(psnrs)) + " : SSIM : " + str(np.mean(ssims)))
def _denoise_tiled(model, image, device, depth):
    """Denoise ``image`` by recursively tiling it ``depth`` levels deep.

    At every level ``split_tensor`` cuts the input into pieces, and each piece
    gets a leading batch axis via ``unsqueeze(0)``. Only the innermost pieces go
    through the model; the predictions are stitched back with ``merge_tensor``
    level by level. Presumably this bounds memory for large frames — confirm.
    Returns the merged prediction on the CPU.
    """
    preds = []
    for piece in split_tensor(image):
        piece = piece.unsqueeze(0)
        if depth > 1:
            preds.append(_denoise_tiled(model, piece, device, depth - 1))
        else:
            piece = piece.to(device)
            print(piece.size())
            preds.append(model(piece).detach().cpu().squeeze(0))
    return merge_tensor(preds)


def test(args):
    """Denoise Cityscapes validation frames with MIRNet and save them as PNGs.

    The list of frames comes from the ``*_gtFine_color.png`` label files under a
    hard-coded directory. The matching ``*_leftImg8bit`` file is read from
    ``args.noise_dir`` and denoised through three levels of tiling. The result
    is written to ``args.save_img``.

    Args:
        args: parsed CLI namespace; reads ``checkpoint``, ``noise_dir`` and
            ``save_img``.
    """
    model = MIRNet()
    checkpoint_dir = args.checkpoint
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # A torch.device never equals a str; compare its type instead.
    checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda', 'latest')
    start_epoch = checkpoint['epoch']
    global_step = checkpoint['global_iter']
    model.load_state_dict(checkpoint['state_dict'])
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(
        start_epoch, global_step))

    model.eval()
    model = model.to(device)
    trans = transforms.ToPILImage()
    torch.manual_seed(0)
    # NOTE(review): machine-specific path; only used to enumerate which
    # validation frames to process. Sorted for a deterministic order.
    test_img = sorted(glob.glob(
        "/vinai/tampm2/cityscapes_noise/gtFine/val/*/*_gtFine_color.png"))
    os.makedirs(args.save_img, exist_ok=True)
    for label_path in test_img:
        img_path = os.path.join(
            args.noise_dir,
            label_path.split("/")[-1].replace("_gtFine_color", "_leftImg8bit"))
        print(img_path)
        image_noise = load_data(img_path)
        with torch.no_grad():
            pred = _denoise_tiled(model, image_noise, device, depth=3)
        # ToPILImage multiplies float input by 255 and casts to uint8; clamp so
        # out-of-range outputs saturate instead of wrapping around.
        pred = trans(pred.clamp(0, 1))
        name_img = img_path.split("/")[-1].split(".")[0]
        pred.save(os.path.join(args.save_img, name_img + ".png"))
def test(args):
    """Denoise the SIDD sRGB benchmark blocks with MIRNet.

    Reads ``BenchmarkNoisyBlocksSrgb`` from the .mat file ``args.noise_dir``
    (a 5-D array indexed as [image][block]...). Each block is denoised and the
    results are collected into a new array.

    Args:
        args: parsed CLI namespace; reads ``checkpoint`` and ``noise_dir``.

    Returns:
        np.ndarray: same shape and dtype as the input block array, holding the
        denoised blocks.
    """
    model = MIRNet()
    checkpoint_dir = args.checkpoint
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # A torch.device never equals a str; compare its type instead.
    checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda', 'latest')
    start_epoch = checkpoint['epoch']
    global_step = checkpoint['global_iter']
    model.load_state_dict(checkpoint['state_dict'])
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(
        start_epoch, global_step))

    model.eval()
    model = model.to(device)
    trans = transforms.ToPILImage()
    torch.manual_seed(0)
    all_noisy_imgs = scipy.io.loadmat(args.noise_dir)['BenchmarkNoisyBlocksSrgb']
    mat_re = np.zeros_like(all_noisy_imgs)
    i_imgs, i_blocks, _, _, _ = all_noisy_imgs.shape
    for i_img in range(i_imgs):
        for i_block in range(i_blocks):
            noise = transforms.ToTensor()(Image.fromarray(
                all_noisy_imgs[i_img][i_block])).unsqueeze(0).to(device)
            with torch.no_grad():
                pred = model(noise).cpu()
            # ToPILImage multiplies float input by 255 and casts to uint8; clamp
            # so out-of-range outputs saturate instead of wrapping around.
            mat_re[i_img][i_block] = np.array(trans(pred[0].clamp(0, 1)))
    return mat_re
def convert_torch_to_onnx(img_path='/home/dell/Downloads/FullTest/noisy/2_1.png',
                          checkpoint_dir='../checkpoints/mir/',
                          output_path='models/denoiser_rgb.onnx',
                          denoised_path='../img/denoised.jpg'):
    """Sanity-check a MIRNet checkpoint on one image, then export it to ONNX.

    The latest checkpoint in ``checkpoint_dir`` is loaded (cuda flag False) and
    run once on ``img_path``. The clamped result is written to
    ``denoised_path``. The model is then exported with ``torch2onnx`` to
    ``output_path``, using the same image tensor as the example input. The
    defaults reproduce the previously hard-coded paths.

    Args:
        img_path: image used for the test pass and as the export example input.
        checkpoint_dir: directory passed to ``load_checkpoint``.
        output_path: destination of the ONNX file.
        denoised_path: where the test-pass output image is written.
    """
    input_node_names = ['input_image']
    output_node_names = ['output_image']

    checkpoint = load_checkpoint(checkpoint_dir, False, 'latest')
    torch_model = MIRNet()
    torch_model.load_state_dict(checkpoint['state_dict'])
    # Switch any training-only behavior (dropout / batch-norm statistics, if the
    # model has them) off before the test pass and the export.
    torch_model.eval()

    img = imageio.imread(img_path)
    print(img.shape)
    img = np.asarray(img, dtype=np.float32) / 255.
    # HWC image -> (1, C, H, W) float tensor.
    img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)

    print('Test forward pass')
    s = time.time()
    with torch.no_grad():
        enhanced_img_tensor = torch.clamp(torch_model(img_tensor), 0, 1)
    enhanced_img = enhanced_img_tensor.permute(0, 2, 3, 1).squeeze(0).cpu().numpy()
    enhanced_img = np.clip(enhanced_img * 255, 0, 255).astype('uint8')
    imageio.imwrite(denoised_path, enhanced_img)
    print('- Time: ', time.time() - s)

    print('Export to onnx format')
    s = time.time()
    torch2onnx(torch_model, img_tensor, output_path, input_node_names,
               output_node_names, keep_initializers=False,
               verify_after_export=True)
    print('- Time: ', time.time() - s)
def test(args):
    """Evaluate MIRNet on ``2_*.png`` noisy/clean image pairs.

    Noisy images are globbed from ``args.noise_dir``. Each clean counterpart is
    found by replacing "noisy" with "clean" in the path. Both are cropped to
    ``args.image_size`` from the top-left corner. Per-image PSNR/SSIM is
    printed, and a titled figure is saved to ``args.save_img`` when that is
    non-empty.

    Args:
        args: parsed CLI namespace; reads ``save_img``, ``noise_dir``,
            ``image_size``, ``model_type`` and ``checkpoint``.
    """
    model = MIRNet()
    save_img = args.save_img
    # NOTE(review): the checkpoint directory is hard-coded here, while the saved
    # figure names use ``args.checkpoint``; the two can disagree.
    checkpoint_dir = "checkpoints/mir"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # A torch.device never equals a str; compare its type instead.
    checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda', 'latest')
    start_epoch = checkpoint['epoch']
    global_step = checkpoint['global_iter']
    model.load_state_dict(checkpoint['state_dict'])
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(
        start_epoch, global_step))

    model.eval()
    model = model.to(device)
    trans = transforms.ToPILImage()
    torch.manual_seed(0)
    noisy_path = sorted(glob.glob(args.noise_dir + "/2_*.png"))
    clean_path = [p.replace("noisy", "clean") for p in noisy_path]
    print(noisy_path)
    size = args.image_size

    def _load_crop(path):
        # Context manager closes the file handle PIL keeps open after open().
        with Image.open(path) as im:
            return transforms.ToTensor()(im.convert('RGB'))[:, 0:size, 0:size]

    for i, (noisy_file, clean_file) in enumerate(zip(noisy_path, clean_path)):
        noise = _load_crop(noisy_file).unsqueeze(0).to(device)
        print(noise.size())
        with torch.no_grad():
            pred = model(noise)
        pred = pred.detach().cpu()
        gt = _load_crop(clean_file).unsqueeze(0)
        psnr_t = calculate_psnr(pred, gt)
        ssim_t = calculate_ssim(pred, gt)
        print(i, " UP : PSNR : ", str(psnr_t), " : SSIM : ", str(ssim_t))
        if save_img != '':
            os.makedirs(args.save_img, exist_ok=True)
            # Only the last path component of args.checkpoint is safe inside a
            # file name; a "/" would point into a missing subdirectory.
            checkpoint_tag = os.path.basename(os.path.normpath(args.checkpoint))
            fig = plt.figure(figsize=(15, 15))
            plt.imshow(np.array(trans(pred[0])))
            plt.title("denoise KPN DGF " + args.model_type, fontsize=25)
            image_name = noisy_file.split("/")[-1].split(".")[0]
            plt.axis("off")
            plt.suptitle(image_name + " UP : PSNR : " + str(psnr_t) + " : SSIM : " + str(ssim_t),
                         fontsize=25)
            plt.savefig(os.path.join(args.save_img, image_name + "_" + checkpoint_tag + '.png'),
                        pad_inches=0)
            # Release the figure; otherwise one open figure per image accumulates.
            plt.close(fig)
def test(args):
    """Denoise the SIDD raw benchmark blocks with MIRNet / MIRNet_kpn.

    Reads ``BenchmarkNoisyBlocksRaw`` from the .mat file ``args.noise_dir``
    (a 4-D array indexed as [image][block]...). Each block is packed with
    ``pack_raw``, denoised, unpacked with ``unpack_raw`` and stored in a new
    array.

    Args:
        args: parsed CLI namespace; reads ``model_type`` ("MIR" or "KPN"),
            ``n_colors``, ``out_channels``, ``checkpoint`` and ``noise_dir``.

    Returns:
        np.ndarray with the same shape and dtype as the input block array, or
        None when ``model_type`` is not recognised.
    """
    if args.model_type == "MIR":
        model = MIRNet(in_channels=args.n_colors, out_channels=args.out_channels)
    elif args.model_type == "KPN":
        model = MIRNet_kpn(in_channels=args.n_colors, out_channels=args.out_channels)
    else:
        print(" Model type not valid")
        return
    checkpoint_dir = args.checkpoint
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # A torch.device never equals a str; compare its type instead.
    checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda', 'latest')
    start_epoch = checkpoint['epoch']
    global_step = checkpoint['global_iter']
    model.load_state_dict(checkpoint['state_dict'])
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(
        start_epoch, global_step))

    model.eval()
    model = model.to(device)
    torch.manual_seed(0)
    all_noisy_imgs = scipy.io.loadmat(args.noise_dir)['BenchmarkNoisyBlocksRaw']
    mat_re = np.zeros_like(all_noisy_imgs)
    i_imgs, i_blocks, _, _ = all_noisy_imgs.shape
    for i_img in range(i_imgs):
        for i_block in range(i_blocks):
            noise = transforms.ToTensor()(pack_raw(
                all_noisy_imgs[i_img][i_block])).unsqueeze(0).to(device)
            with torch.no_grad():
                pred = model(noise)
            # (C, H, W) -> (H, W, C) before handing back to unpack_raw.
            pred = np.array(pred.cpu()[0]).transpose(1, 2, 0)
            mat_re[i_img][i_block] = np.array(unpack_raw(pred))
    return mat_re
# from denoiser.utils import load_checkpoint from model.MIRNet import MIRNet from utils.training_util import save_checkpoint, MovingAverage, load_checkpoint img_path = '/home/dell/Downloads/FullTest/noisy/2_1.png' model_path = '../../denoiser/pretrained_models/denoising/sidd_rgb.pth' output_path = 'models/denoiser_rgb.onnx' input_node_names = ['input_image'] output_nodel_names = ['output_image'] # torch_model = DenoiseNet() # load_checkpoint(torch_model, model_path, 'cpu') checkpoint = load_checkpoint("../checkpoints/mir/", False, 'latest') state_dict = checkpoint['state_dict'] torch_model = MIRNet() print(torch_model) exit(0) torch_model.load_state_dict(state_dict) torch_model.eval() img = imageio.imread(img_path) img = img[0:256, 0:256, :] print(img.shape) img = np.asarray(img, dtype=np.float32) / 255. img_tensor = torch.from_numpy(img) img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0) print('Test forward pass') s = time.time() # with torch.no_grad():
def train(args):
    """Train MIRNet / MIRNet_kpn with a Charbonnier loss, checkpointing as it goes.

    Resumes from the latest checkpoint in ``args.checkpoint`` unless
    ``args.restart`` is set. The model, optimizer and (when present) LR-scheduler
    state are restored. Every ``args.save_every`` steps a checkpoint is written
    and flagged "best" when the moving-average loss improved. Every
    ``args.loss_every`` steps the batch PSNR and the average loss are printed.

    Args:
        args: parsed CLI namespace; reads ``num_workers``, ``data_type``
            ("rgb" or "raw"), ``noise_dir``, ``gt_dir``, ``image_size``,
            ``batch_size``, ``checkpoint``, ``model_type`` ("MIR" or "KPN"),
            ``n_colors``, ``out_channels``, ``lr``, ``save_every``,
            ``restart``, ``epoch`` and ``loss_every``.
    """
    torch.set_num_threads(args.num_workers)
    torch.manual_seed(0)
    if args.data_type == 'rgb':
        data_set = SingleLoader(noise_dir=args.noise_dir, gt_dir=args.gt_dir,
                                image_size=args.image_size)
    elif args.data_type == 'raw':
        data_set = SingleLoader_raw(noise_dir=args.noise_dir, gt_dir=args.gt_dir,
                                    image_size=args.image_size)
    else:
        print("Data type not valid")
        exit()
    data_loader = DataLoader(data_set, batch_size=args.batch_size, shuffle=True,
                             num_workers=args.num_workers, pin_memory=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_func = losses.CharbonnierLoss().to(device)

    checkpoint_dir = args.checkpoint
    os.makedirs(checkpoint_dir, exist_ok=True)

    if args.model_type == "MIR":
        model = MIRNet(in_channels=args.n_colors, out_channels=args.out_channels).to(device)
    elif args.model_type == "KPN":
        model = MIRNet_kpn(in_channels=args.n_colors, out_channels=args.out_channels).to(device)
    else:
        print(" Model type not valid")
        return

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    average_loss = MovingAverage(args.save_every)
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, [2, 4, 6, 8, 10, 12, 14, 16], 0.8)

    start_epoch, global_step, best_loss = 0, 0, np.inf
    if args.restart:
        print('=> no checkpoint file to be loaded.')
    else:
        try:
            # A torch.device never equals a str; compare its type instead.
            checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda', 'latest')
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_iter']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Older checkpoints have no scheduler state; without it the LR
            # schedule would restart from its first milestone after a resume.
            if 'scheduler' in checkpoint:
                scheduler.load_state_dict(checkpoint['scheduler'])
            print('=> loaded checkpoint (epoch {}, global_step {})'.format(
                start_epoch, global_step))
        except Exception as err:
            # Best effort: fall back to a fresh run, but say why. Unlike the
            # former bare `except:`, this does not swallow KeyboardInterrupt or
            # SystemExit, and a mismatching state dict is no longer silent.
            start_epoch, global_step, best_loss = 0, 0, np.inf
            print('=> no checkpoint file to be loaded. ({!r})'.format(err))

    for epoch in range(start_epoch, args.epoch):
        for noise, gt in data_loader:
            noise = noise.to(device)
            gt = gt.to(device)
            pred = model(noise)
            loss = loss_func(pred, gt)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Record a plain float: MovingAverage caches the values it gets, and
            # caching the loss tensor would keep each step's autograd graph alive.
            average_loss.update(loss.item())
            if global_step % args.save_every == 0:
                print(len(average_loss._cache))
                if average_loss.get_value() < best_loss:
                    is_best = True
                    best_loss = average_loss.get_value()
                else:
                    is_best = False
                save_dict = {
                    'epoch': epoch,
                    'global_iter': global_step,
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                }
                save_checkpoint(save_dict, is_best, checkpoint_dir, global_step)
            if global_step % args.loss_every == 0:
                print(global_step, "PSNR : ", calculate_psnr(pred.detach(), gt))
                print(average_loss.get_value())
            global_step += 1
        print('Epoch {} is finished.'.format(epoch))
        scheduler.step()
def test_multi(args):
    """Denoise every .mat tile under ``args.noise_dir/<folder>/`` with MIRNet.

    Each .mat file must contain an ``image`` array plus the scalar entries
    ``y_gb``, ``x_gb``, ``y_lc``, ``x_lc``, ``size``, ``H`` and ``W``. When
    ``args.save_img`` is non-empty, the denoised tile and those scalars are
    written to ``args.save_img/<folder>/<same file name>``. An empty
    ``save_img`` used to crash in ``os.makedirs('')``; now it only computes.

    Args:
        args: parsed CLI namespace; reads ``checkpoint``, ``noise_dir`` and
            ``save_img``.
    """
    model = MIRNet()
    checkpoint_dir = args.checkpoint
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # A torch.device never equals a str; compare its type instead.
    checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda', 'latest')
    start_epoch = checkpoint['epoch']
    global_step = checkpoint['global_iter']
    # Prepend "model." to every key. (This ADDS a prefix; the earlier comment
    # claiming it removes `module.` was wrong.) Presumably the checkpoint was
    # saved without the prefix this MIRNet variant expects — confirm.
    new_state_dict = OrderedDict(
        ("model." + k, v) for k, v in checkpoint['state_dict'].items())
    model.load_state_dict(new_state_dict)
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(
        start_epoch, global_step))

    model.eval()
    model = model.to(device)
    trans = transforms.ToPILImage()
    torch.manual_seed(0)

    save_enabled = args.save_img != ''
    if save_enabled:
        os.makedirs(args.save_img, exist_ok=True)
    for mat_folder in glob.glob(os.path.join(args.noise_dir, '*')):
        save_mat_folder = os.path.join(args.save_img, os.path.basename(mat_folder))
        for mat_file in glob.glob(os.path.join(mat_folder, '*')):
            mat_contents = sio.loadmat(mat_file)
            image_noise = transforms.ToTensor()(
                Image.fromarray(mat_contents['image'])).unsqueeze(0).to(device)
            with torch.no_grad():
                pred = model(image_noise)
            # ToPILImage multiplies float input by 255 and casts to uint8; clamp
            # so out-of-range outputs saturate instead of wrapping around.
            pred = np.array(trans(pred[0].cpu().clamp(0, 1)))
            if not save_enabled:
                continue
            os.makedirs(save_mat_folder, exist_ok=True)
            save_path = os.path.join(save_mat_folder, os.path.basename(mat_file))
            print("save : ", save_path)
            data = {"image": pred}
            # Copy the tile's placement metadata through unchanged.
            for key in ("y_gb", "x_gb", "y_lc", "x_lc", "size", "H", "W"):
                data[key] = mat_contents[key][0][0]
            sio.savemat(save_path, data)
from model.MIRNet import MIRNet import torch import tensorflow as tf import onnx from onnx_tf.backend import prepare import os from utils.training_util import save_checkpoint, MovingAverage, load_checkpoint checkpoint = load_checkpoint("../checkpoints/mir/", False, 'latest') state_dict = checkpoint['state_dict'] model = MIRNet() model.load_state_dict(state_dict) model.eval() # Converting model to ONNX print('===> Converting model to ONNX.') try: for _ in model.modules(): _.training = False sample_input = torch.randn(1, 3, 256, 256) input_nodes = ['input'] output_nodes = ['output'] torch.onnx.export(model, args=sample_input, f="model.onnx", export_params=True, input_names=input_nodes, output_names=output_nodes,