def test(args):
    if args.model_type == "MIR":
        model = MIRNet()
    elif args.model_type == "KPN":
        model = MIRNet_kpn()
    else:
        print(" Model type not valid")
        return
    # summary(model,[[3,128,128],[0]])
    # exit()
    checkpoint_dir = args.checkpoint
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # try:
    checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda', 'latest')
    start_epoch = checkpoint['epoch']
    global_step = checkpoint['global_iter']
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict)
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(start_epoch, global_step))
    # except:
    #     print('=> no checkpoint file to be loaded.')
    #     exit(1)
    model.eval()
    model = model.to(device)
    trans = transforms.ToPILImage()
    torch.manual_seed(0)
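    # SIDD+ validation set: these .mat files hold the noisy and ground-truth sRGB images.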
    all_noisy_imgs = scipy.io.loadmat(args.noise_dir)['siddplus_valid_noisy_srgb']
    all_clean_imgs = scipy.io.loadmat(args.gt)['siddplus_valid_gt_srgb']
    # noisy_path = sorted(glob.glob(args.noise_dir+ "/*.png"))
    # clean_path = [ i.replace("noisy","clean") for i in noisy_path]
    i_imgs, _, _, _ = all_noisy_imgs.shape
    psnrs = []
    ssims = []
    # print(noisy_path)
    for i_img in range(i_imgs):
        noise = transforms.ToTensor()(Image.fromarray(all_noisy_imgs[i_img])).unsqueeze(0)
        noise = noise.to(device)
        begin = time.time()
        pred = model(noise)
        pred = pred.detach().cpu()
        gt = transforms.ToTensor()((Image.fromarray(all_clean_imgs[i_img])))
        gt = gt.unsqueeze(0)
        psnr_t = calculate_psnr(pred, gt)
        ssim_t = calculate_ssim(pred, gt)
        psnrs.append(psnr_t)
        ssims.append(ssim_t)
        print(i_img, "   UP   :  PSNR : ", str(psnr_t), " :  SSIM : ", str(ssim_t))
        if args.save_img != '':
            if not os.path.exists(args.save_img):
                os.makedirs(args.save_img)
            plt.figure(figsize=(15, 15))
            plt.imshow(np.array(trans(pred[0])))
            plt.title("denoise KPN DGF " + args.model_type, fontsize=25)
            image_name = str(i_img) + "_"
            plt.axis("off")
            plt.suptitle(image_name + "   UP   :  PSNR : " + str(psnr_t) + " :  SSIM : " + str(ssim_t), fontsize=25)
            plt.savefig(os.path.join(args.save_img, image_name + "_" + args.checkpoint + '.png'), pad_inches=0)
    print("   AVG   :  PSNR : "+ str(np.mean(psnrs))+" :  SSIM : "+ str(np.mean(ssims)))
Example #2
def test(args):
    model = MIRNet()

    # summary(model,[[3,128,128],[0]])
    # exit()
    checkpoint_dir = args.checkpoint
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # try:
    checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda', 'latest')
    start_epoch = checkpoint['epoch']
    global_step = checkpoint['global_iter']
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict)
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(
        start_epoch, global_step))
    # except:
    #     print('=> no checkpoint file to be loaded.')
    #     exit(1)
    model.eval()
    model = model.to(device)
    trans = transforms.ToPILImage()
    torch.manual_seed(0)
    # noisy_path = sorted(glob.glob(args.noise_dir+ "/*.png"))
    test_img = glob.glob(
        "/vinai/tampm2/cityscapes_noise/gtFine/val/*/*_gtFine_color.png")
    if not os.path.exists(args.save_img):
        os.makedirs(args.save_img)
    for i in range(len(test_img)):
        # print(noisy_path[i])
        img_path = os.path.join(
            args.noise_dir,
            test_img[i].split("/")[-1].replace("_gtFine_color",
                                               "_leftImg8bit"))
        print(img_path)
        image_noise = load_data(img_path)
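        # Split the full-resolution frame into patches with three nested levels of
        # split_tensor (presumably to keep each patch small enough for GPU memory),
        # denoise every patch, then reassemble the result with merge_tensor below.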
        image_noises1 = split_tensor(image_noise)
        preds1 = []
        for image1 in image_noises1:
            image1 = image1.unsqueeze(0)
            image_noises2 = split_tensor(image1)
            preds2 = []
            for image2 in image_noises2:
                image2 = image2.unsqueeze(0)
                image_noises3 = split_tensor(image2)
                preds3 = []
                for image3 in image_noises3:
                    image3 = image3.unsqueeze(0).to(device)
                    print(image3.size())
                    pred3 = model(image3)
                    pred3 = pred3.detach().cpu().squeeze(0)
                    preds3.append(pred3)

                pred2 = merge_tensor(preds3)
                preds2.append(pred2)
            pred1 = merge_tensor(preds2)
            preds1.append(pred1)
        pred = merge_tensor(preds1)
        pred = trans(pred)
        name_img = img_path.split("/")[-1].split(".")[0]
        pred.save(args.save_img + "/" + name_img + ".png")
Example #3
def test(args):
    model = MIRNet()

    checkpoint_dir = args.checkpoint
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # try:
    checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda', 'latest')
    start_epoch = checkpoint['epoch']
    global_step = checkpoint['global_iter']
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict)
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(
        start_epoch, global_step))
    model.eval()
    model = model.to(device)
    trans = transforms.ToPILImage()
    torch.manual_seed(0)
    all_noisy_imgs = scipy.io.loadmat(
        args.noise_dir)['BenchmarkNoisyBlocksSrgb']
    mat_re = np.zeros_like(all_noisy_imgs)
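    # Denoised patches are written back into an array with the same shape and dtype
    # as the noisy benchmark blocks, and the filled array is returned below.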
    i_imgs, i_blocks, _, _, _ = all_noisy_imgs.shape

    for i_img in range(i_imgs):
        for i_block in range(i_blocks):
            noise = transforms.ToTensor()(Image.fromarray(
                all_noisy_imgs[i_img][i_block])).unsqueeze(0)
            noise = noise.to(device)
            begin = time.time()
            pred = model(noise)
            pred = pred.detach().cpu()
            mat_re[i_img][i_block] = np.array(trans(pred[0]))

    return mat_re
Example #4
def convert_torch_to_onnx():
    img_path = '/home/dell/Downloads/FullTest/noisy/2_1.png'
    model_path = '../../denoiser/pretrained_models/denoising/sidd_rgb.pth'
    output_path = 'models/denoiser_rgb.onnx'

    input_node_names = ['input_image']
    output_node_names = ['output_image']

    # torch_model = DenoiseNet()
    # load_checkpoint(torch_model, model_path, 'cpu')
    checkpoint = load_checkpoint("../checkpoints/mir/", False, 'latest')
    state_dict = checkpoint['state_dict']
    torch_model = MIRNet()
    torch_model.load_state_dict(state_dict)
    torch_model.eval()

    img = imageio.imread(img_path)
    # img = img[0:256,0:256,:]
    print(img.shape)
    img = np.asarray(img, dtype=np.float32) / 255.

    img_tensor = torch.from_numpy(img)
    img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)
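    # HWC float image -> 1 x 3 x H x W (NCHW) tensor expected by the network.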

    print('Test forward pass')
    s = time.time()
    with torch.no_grad():
        enhanced_img_tensor = torch_model(img_tensor)
        enhanced_img_tensor = torch.clamp(enhanced_img_tensor, 0, 1)
        enhanced_img = (enhanced_img_tensor.permute(0, 2, 3, 1).squeeze(0)
                        .cpu().detach().numpy())
        enhanced_img = np.clip(enhanced_img * 255, 0, 255).astype('uint8')
        imageio.imwrite('../img/denoised.jpg', enhanced_img)
    print('- Time: ', time.time() - s)

    print('Export to onnx format')
    s = time.time()
    torch2onnx(torch_model, img_tensor, output_path, input_node_names,
               output_node_names, keep_initializers=False,
               verify_after_export=True)
    print('- Time: ', time.time() - s)
Example #5
def test(args):
    model = MIRNet()
    save_img = args.save_img
    checkpoint_dir = "checkpoints/mir"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # try:
    checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda', 'latest')
    start_epoch = checkpoint['epoch']
    global_step = checkpoint['global_iter']
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict)
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(
        start_epoch, global_step))
    # except:
    #     print('=> no checkpoint file to be loaded.')
    #     exit(1)
    model.eval()
    model = model.to(device)
    trans = transforms.ToPILImage()
    torch.manual_seed(0)
    noisy_path = sorted(glob.glob(args.noise_dir + "/2_*.png"))
    clean_path = [i.replace("noisy", "clean") for i in noisy_path]
    print(noisy_path)
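    # Each noisy image and its "clean" counterpart are cropped to the top-left
    # args.image_size square before computing PSNR/SSIM.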
    for i in range(len(noisy_path)):
        noise = transforms.ToTensor()(Image.open(
            noisy_path[i]).convert('RGB'))[:, 0:args.image_size,
                                           0:args.image_size].unsqueeze(0)
        noise = noise.to(device)
        begin = time.time()
        print(noise.size())
        pred = model(noise)
        pred = pred.detach().cpu()
        gt = transforms.ToTensor()(Image.open(
            clean_path[i]).convert('RGB'))[:, 0:args.image_size,
                                           0:args.image_size]
        gt = gt.unsqueeze(0)
        psnr_t = calculate_psnr(pred, gt)
        ssim_t = calculate_ssim(pred, gt)
        print(i, "   UP   :  PSNR : ", str(psnr_t), " :  SSIM : ", str(ssim_t))
        if save_img != '':
            if not os.path.exists(args.save_img):
                os.makedirs(args.save_img)
            plt.figure(figsize=(15, 15))
            plt.imshow(np.array(trans(pred[0])))
            plt.title("denoise KPN DGF " + args.model_type, fontsize=25)
            image_name = noisy_path[i].split("/")[-1].split(".")[0]
            plt.axis("off")
            plt.suptitle(image_name + "   UP   :  PSNR : " + str(psnr_t) +
                         " :  SSIM : " + str(ssim_t),
                         fontsize=25)
            plt.savefig(os.path.join(
                args.save_img, image_name + "_" + args.checkpoint + '.png'),
                        pad_inches=0)
Example #6
def test(args):
    if args.model_type == "MIR":
        model = MIRNet(in_channels=args.n_colors,
                       out_channels=args.out_channels)
    elif args.model_type == "KPN":
        model = MIRNet_kpn(in_channels=args.n_colors,
                           out_channels=args.out_channels)
    else:
        print(" Model type not valid")
        return
    checkpoint_dir = args.checkpoint
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # try:
    checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda', 'latest')
    start_epoch = checkpoint['epoch']
    global_step = checkpoint['global_iter']
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict)
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(
        start_epoch, global_step))
    model.eval()
    model = model.to(device)
    trans = transforms.ToPILImage()
    torch.manual_seed(0)
    all_noisy_imgs = scipy.io.loadmat(
        args.noise_dir)['BenchmarkNoisyBlocksRaw']
    mat_re = np.zeros_like(all_noisy_imgs)
    i_imgs, i_blocks, _, _ = all_noisy_imgs.shape

    for i_img in range(i_imgs):
        for i_block in range(i_blocks):
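            # pack_raw (a repo utility) presumably packs the flat Bayer raw block into
            # the multi-channel tensor the model expects; unpack_raw reverses it below.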
            noise = transforms.ToTensor()(pack_raw(
                all_noisy_imgs[i_img][i_block])).unsqueeze(0)
            noise = noise.to(device)
            begin = time.time()
            pred = model(noise)
            pred = np.array(pred.detach().cpu()[0]).transpose(1, 2, 0)
            pred = unpack_raw(pred)
            mat_re[i_img][i_block] = np.array(pred)

    return mat_re
Example #7
# from denoiser.utils import load_checkpoint
from model.MIRNet import MIRNet
from utils.training_util import save_checkpoint, MovingAverage, load_checkpoint

img_path = '/home/dell/Downloads/FullTest/noisy/2_1.png'
model_path = '../../denoiser/pretrained_models/denoising/sidd_rgb.pth'
output_path = 'models/denoiser_rgb.onnx'

input_node_names = ['input_image']
output_node_names = ['output_image']

# torch_model = DenoiseNet()
# load_checkpoint(torch_model, model_path, 'cpu')
checkpoint = load_checkpoint("../checkpoints/mir/", False, 'latest')
state_dict = checkpoint['state_dict']
torch_model = MIRNet()
torch_model.load_state_dict(state_dict)
torch_model.eval()
img = imageio.imread(img_path)
img = img[0:256, 0:256, :]
print(img.shape)
img = np.asarray(img, dtype=np.float32) / 255.

img_tensor = torch.from_numpy(img)
img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)

print('Test forward pass')
s = time.time()
# with torch.no_grad():
Example #8
def train(args):
    torch.set_num_threads(args.num_workers)
    torch.manual_seed(0)
    if args.data_type == 'rgb':
        data_set = SingleLoader(noise_dir=args.noise_dir,
                                gt_dir=args.gt_dir,
                                image_size=args.image_size)
    elif args.data_type == 'raw':
        data_set = SingleLoader_raw(noise_dir=args.noise_dir,
                                    gt_dir=args.gt_dir,
                                    image_size=args.image_size)
    else:
        print("Data type not valid")
        exit()
    data_loader = DataLoader(data_set,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_func = losses.CharbonnierLoss().to(device)
    # loss_func = losses.AlginLoss().to(device)
    adaptive = robust_loss.adaptive.AdaptiveLossFunction(
        num_dims=3 * args.image_size**2, float_dtype=np.float32, device=device)
    checkpoint_dir = args.checkpoint
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    if args.model_type == "MIR":
        model = MIRNet(in_channels=args.n_colors,
                       out_channels=args.out_channels).to(device)
    elif args.model_type == "KPN":
        model = MIRNet_kpn(in_channels=args.n_colors,
                           out_channels=args.out_channels).to(device)
    else:
        print(" Model type not valid")
        return
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    optimizer.zero_grad()
    average_loss = MovingAverage(args.save_every)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               [2, 4, 6, 8, 10, 12, 14, 16],
                                               0.8)
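    # MultiStepLR decays the learning rate by a factor of 0.8 at each listed epoch milestone.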
    if args.restart:
        start_epoch = 0
        global_step = 0
        best_loss = np.inf
        print('=> no checkpoint file to be loaded.')
    else:
        try:
            checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda',
                                         'latest')
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_iter']
            best_loss = checkpoint['best_loss']
            state_dict = checkpoint['state_dict']
            # new_state_dict = OrderedDict()
            # for k, v in state_dict.items():
            #     name = "model."+ k  # remove `module.`
            #     new_state_dict[name] = v
            model.load_state_dict(state_dict)
            optimizer.load_state_dict(checkpoint['optimizer'])
            print('=> loaded checkpoint (epoch {}, global_step {})'.format(
                start_epoch, global_step))
        except Exception:
            start_epoch = 0
            global_step = 0
            best_loss = np.inf
            print('=> no checkpoint file to be loaded.')
    eps = 1e-4
    for epoch in range(start_epoch, args.epoch):
        for step, (noise, gt) in enumerate(data_loader):
            noise = noise.to(device)
            gt = gt.to(device)
            pred = model(noise)
            # print(pred.size())
            loss = loss_func(pred, gt)
            # bs = gt.size()[0]
            # diff = noise - gt
            # loss = torch.sqrt((diff * diff) + (eps * eps))
            # loss = loss.view(bs,-1)
            # loss = adaptive.lossfun(loss)
            # loss = torch.mean(loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            average_loss.update(loss)
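            # Every save_every steps: checkpoint the model and keep track of the best
            # moving-average loss seen so far.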
            if global_step % args.save_every == 0:
                print(len(average_loss._cache))
                if average_loss.get_value() < best_loss:
                    is_best = True
                    best_loss = average_loss.get_value()
                else:
                    is_best = False

                save_dict = {
                    'epoch': epoch,
                    'global_iter': global_step,
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                }
                save_checkpoint(save_dict, is_best, checkpoint_dir,
                                global_step)
            if global_step % args.loss_every == 0:
                print(global_step, "PSNR  : ", calculate_psnr(pred, gt))
                print(average_loss.get_value())
            global_step += 1
        print('Epoch {} is finished.'.format(epoch))
        scheduler.step()
Example #9
def test_multi(args):
    model = MIRNet()

    checkpoint_dir = args.checkpoint
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # try:

    checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda', 'latest')
    start_epoch = checkpoint['epoch']
    global_step = checkpoint['global_iter']
    state_dict = checkpoint['state_dict']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = "model." + k  # remove `module.`
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(
        start_epoch, global_step))
    # except:
    #     print('=> no checkpoint file to be loaded.')
    #     exit(1)
    model.eval()
    model = model.to(device)
    trans = transforms.ToPILImage()
    torch.manual_seed(0)

    mat_folders = glob.glob(os.path.join(args.noise_dir, '*'))

    if not os.path.exists(args.save_img):
        os.makedirs(args.save_img)
    for mat_folder in mat_folders:
        save_mat_folder = os.path.join(args.save_img,
                                       mat_folder.split("/")[-1])
        for mat_file in glob.glob(os.path.join(mat_folder, '*')):
            mat_contents = sio.loadmat(mat_file)
            sub_image, y_gb, x_gb = mat_contents['image'], mat_contents[
                'y_gb'][0][0], mat_contents['x_gb'][0][0]
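            # Each .mat tile stores a sub-image plus its global/local offsets and the
            # original H/W; these fields are copied through below so the denoised tiles
            # can presumably be stitched back into full images later.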
            image_noise = transforms.ToTensor()(
                Image.fromarray(sub_image)).unsqueeze(0)
            image_noise_batch = image_noise.to(device)

            pred = model(image_noise_batch)
            pred = np.array(trans(pred[0].cpu()))
            if args.save_img != '':
                if not os.path.exists(save_mat_folder):
                    os.makedirs(save_mat_folder)
                # mat_contents['image'] = pred
                # print(mat_contents)
                print("save : ",
                      os.path.join(save_mat_folder,
                                   mat_file.split("/")[-1]))
                data = {
                    "image": pred,
                    "y_gb": mat_contents['y_gb'][0][0],
                    "x_gb": mat_contents['x_gb'][0][0],
                    "y_lc": mat_contents['y_lc'][0][0],
                    "x_lc": mat_contents['x_lc'][0][0],
                    'size': mat_contents['size'][0][0],
                    "H": mat_contents['H'][0][0],
                    "W": mat_contents['W'][0][0]
                }
                # print(data)
                sio.savemat(
                    os.path.join(save_mat_folder,
                                 mat_file.split("/")[-1]), data)
Example #10
from model.MIRNet import MIRNet
import torch
import tensorflow as tf
import onnx
from onnx_tf.backend import prepare
import os
from utils.training_util import save_checkpoint, MovingAverage, load_checkpoint

checkpoint = load_checkpoint("../checkpoints/mir/", False, 'latest')
state_dict = checkpoint['state_dict']
model = MIRNet()
model.load_state_dict(state_dict)
model.eval()

# Converting model to ONNX
print('===> Converting model to ONNX.')
try:
    for module in model.modules():
        module.training = False

    sample_input = torch.randn(1, 3, 256, 256)
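    # Dummy 1 x 3 x 256 x 256 input used only to trace the graph for ONNX export.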

    input_nodes = ['input']
    output_nodes = ['output']

    torch.onnx.export(model,
                      args=sample_input,
                      f="model.onnx",
                      export_params=True,
                      input_names=input_nodes,
                      output_names=output_nodes,