Пример #1
0
def eval(args):
    """Evaluate a trained KPN-DGF model on (at most) the first 100 test samples.

    Builds the network selected by ``args.model_type``, restores its weights
    from ``checkpoints/<args.checkpoint>``, runs inference over a
    ``SingleLoader_DGF`` dataset and prints PSNR / SSIM for every sample.
    When ``args.save_img`` is truthy, the first noisy frame, the clamped
    prediction and the ground truth are written as PNGs to ``eval_img/``.

    The function name shadows the ``eval`` builtin; it is kept as-is for
    existing callers.

    Args:
        args: parsed command-line namespace. Reads ``color``,
            ``burst_length``, ``checkpoint``, ``noise_dir``, ``gt_dir``,
            ``image_size``, ``num_workers``, ``model_type``, ``cuda``,
            ``mGPU``, ``load_type`` and ``save_img``.

    Returns:
        None. Returns early (after printing a message) when
        ``args.model_type`` is not ``attKPN``, ``attWKPN`` or ``KPN``.
    """
    max_samples = 100
    color = args.color
    print('Eval Process......')
    burst_length = args.burst_length
    checkpoint_dir = "checkpoints/" + args.checkpoint
    if not os.path.exists(checkpoint_dir) or len(
            os.listdir(checkpoint_dir)) == 0:
        print('There is no any checkpoint file in path:{}'.format(
            checkpoint_dir))
    # the path for saving eval images
    eval_dir = "eval_img"
    if not os.path.exists(eval_dir):
        os.mkdir(eval_dir)

    # dataset and dataloader
    data_set = SingleLoader_DGF(noise_dir=args.noise_dir,
                                gt_dir=args.gt_dir,
                                image_size=args.image_size,
                                burst_length=burst_length)
    data_loader = DataLoader(data_set,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.num_workers)

    # Every supported variant is built with the same arguments here.
    model_kwargs = dict(color=color,
                        burst_length=burst_length,
                        blind_est=True,
                        kernel_size=[5],
                        sep_conv=False,
                        channel_att=False,
                        spatial_att=False,
                        upMode="bilinear",
                        core_bias=False)
    if args.model_type == "attKPN":
        model = Att_KPN_DGF(**model_kwargs)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN_DGF(**model_kwargs)
    elif args.model_type == "KPN":
        model = KPN_DGF(**model_kwargs)
    else:
        print(" Model type not valid")
        return
    if args.cuda:
        model = model.cuda()
    if args.mGPU:
        model = nn.DataParallel(model)

    # load trained model
    ckpt = load_checkpoint(checkpoint_dir,
                           cuda=args.cuda,
                           best_or_latest=args.load_type)
    state_dict = ckpt['state_dict']
    if not args.mGPU:
        # Keys saved from an nn.DataParallel model carry a 'module.' prefix
        # that the bare model does not expect. Strip it only where present,
        # and do so for CPU and single-GPU runs alike (the single-GPU path
        # must not end up loading an empty dict).
        prefix = 'module.'
        state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items())
    model.load_state_dict(state_dict)
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))
    # switch the eval mode
    model.eval()

    trans = transforms.ToPILImage()

    with torch.no_grad():
        torch.manual_seed(0)
        for i, (image_noise_hr, image_noise_lr,
                image_gt_hr) in enumerate(data_loader):
            if i >= max_samples:
                break
            if args.cuda:
                # All model inputs must live on the model's device.
                burst_noise = image_noise_lr.cuda()
                image_noise_hr = image_noise_hr.cuda()
                gt = image_gt_hr.cuda()
            else:
                burst_noise = image_noise_lr
                gt = image_gt_hr
            if color:
                # Fold the burst and colour dims together: (b, N*c, h, w).
                b, _, _, h, w = burst_noise.size()
                feed_data = burst_noise.view(b, -1, h, w)
            else:
                feed_data = burst_noise
            _, pred = model(feed_data, burst_noise[:, 0:burst_length, ...],
                            image_noise_hr)

            psnr_t = calculate_psnr(pred, gt)
            ssim_t = calculate_ssim(pred, gt)
            print("PSNR : ", str(psnr_t), " :  SSIM : ", str(ssim_t))

            pred = torch.clamp(pred, 0.0, 1.0)

            if args.cuda:
                pred = pred.cpu()
                gt = gt.cpu()
                burst_noise = burst_noise.cpu()
            if args.save_img:
                trans(burst_noise[0, 0, ...].squeeze()).save(
                    os.path.join(eval_dir, '{}_noisy.png'.format(i)),
                    quality=100)
                trans(pred.squeeze()).save(
                    os.path.join(eval_dir,
                                 '{}_pred_{:.2f}dB.png'.format(i, psnr_t)),
                    quality=100)
                trans(gt.squeeze()).save(
                    os.path.join(eval_dir, '{}_gt.png'.format(i)),
                    quality=100)
Пример #2
0
def test_multi(args):
    """Evaluate a KPN-DGF variant on the SIDD validation blocks (.mat files).

    Builds the network selected by ``args.model_type``, restores weights
    from ``checkpoints/<args.checkpoint>`` and denoises every block of
    ``ValidationNoisyBlocksSrgb`` loaded from ``args.noise_dir``. PSNR / SSIM
    against ``ValidationGtBlocksSrgb`` from ``args.gt`` are printed per block
    and averaged at the end. When ``args.save_img`` is non-empty, a titled
    figure of every prediction is written into that directory.

    Args:
        args: parsed command-line namespace. Reads ``burst_length``,
            ``model_type``, ``checkpoint``, ``load_type``, ``noise_dir``,
            ``gt`` and ``save_img``.

    Returns:
        None. Returns early (after printing a message) for an unknown
        ``args.model_type``.
    """
    color = True
    burst_length = args.burst_length
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The attention variants enable channel/spatial attention; plain KPN does not.
    model_kwargs = dict(color=color,
                        burst_length=burst_length,
                        blind_est=True,
                        kernel_size=[5],
                        sep_conv=False,
                        channel_att=True,
                        spatial_att=True,
                        upMode="bilinear",
                        core_bias=False)
    if args.model_type == "attKPN":
        model = Att_KPN_DGF(**model_kwargs)
    elif args.model_type == "attKPN_Wave":
        model = Att_KPN_Wavelet_DGF(**model_kwargs)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN_DGF(**model_kwargs)
    elif args.model_type == "KPN":
        model_kwargs.update(channel_att=False, spatial_att=False)
        model = KPN_DGF(**model_kwargs)
    else:
        print(" Model type not valid")
        return

    checkpoint_dir = "checkpoints/" + args.checkpoint
    if not os.path.exists(checkpoint_dir) or len(
            os.listdir(checkpoint_dir)) == 0:
        print('There is no any checkpoint file in path:{}'.format(
            checkpoint_dir))
    # load trained model; `device` is a torch.device, so compare its type
    # (a torch.device never compares equal to the plain string 'cuda').
    ckpt = load_checkpoint(checkpoint_dir,
                           cuda=device.type == 'cuda',
                           best_or_latest=args.load_type)
    # Strip the 'module.' prefix nn.DataParallel adds to parameter names,
    # only where present, so plain checkpoints load as well.
    prefix = 'module.'
    state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in ckpt['state_dict'].items())
    model.load_state_dict(state_dict)

    model.to(device)
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))
    # switch the eval mode
    model.eval()

    to_tensor = transforms.ToTensor()
    trans = transforms.ToPILImage()
    torch.manual_seed(0)
    all_noisy_imgs = scipy.io.loadmat(
        args.noise_dir)['ValidationNoisyBlocksSrgb']
    all_clean_imgs = scipy.io.loadmat(args.gt)['ValidationGtBlocksSrgb']
    i_imgs, i_blocks, _, _, _ = all_noisy_imgs.shape
    save_figures = args.save_img != ''
    if save_figures and not os.path.exists(args.save_img):
        os.makedirs(args.save_img)

    psnrs = []
    ssims = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for i_img in range(i_imgs):
            for i_block in range(i_blocks):
                image_noise = to_tensor(
                    Image.fromarray(all_noisy_imgs[i_img][i_block]))
                image_noise, image_noise_hr = load_data(image_noise,
                                                        burst_length)
                image_noise = image_noise.to(device)
                image_noise_hr = image_noise_hr.to(device)
                if color:
                    # Fold the burst and colour dims together: (b, N*c, h, w).
                    b, _, _, h, w = image_noise.size()
                    feed_data = image_noise.view(b, -1, h, w)
                else:
                    feed_data = image_noise
                pred_i, pred = model(feed_data,
                                     image_noise[:, 0:burst_length, ...],
                                     image_noise_hr)
                del pred_i
                pred = pred.cpu()
                gt = to_tensor(
                    Image.fromarray(all_clean_imgs[i_img][i_block]))
                gt = gt.unsqueeze(0)
                psnr_t = calculate_psnr(pred, gt)
                ssim_t = calculate_ssim(pred, gt)
                psnrs.append(psnr_t)
                ssims.append(ssim_t)
                print(i_img, "  ", i_block, "   UP   :  PSNR : ", str(psnr_t),
                      " :  SSIM : ", str(ssim_t))
                if save_figures:
                    image_name = str(i_img) + "_" + str(i_block)
                    fig = plt.figure(figsize=(15, 15))
                    plt.imshow(np.array(trans(pred[0])))
                    plt.title("denoise KPN DGF " + args.model_type,
                              fontsize=25)
                    plt.axis("off")
                    plt.suptitle(image_name + "   UP   :  PSNR : " +
                                 str(psnr_t) + " :  SSIM : " + str(ssim_t),
                                 fontsize=25)
                    plt.savefig(os.path.join(
                        args.save_img,
                        image_name + "_" + args.checkpoint + '.png'),
                                pad_inches=0)
                    # Release the figure; otherwise one stays open per block.
                    plt.close(fig)
    print("   AVG   :  PSNR : " + str(np.mean(psnrs)) + " :  SSIM : " +
          str(np.mean(ssims)))
Пример #3
0
def test_multi(args):
    """Denoise every PNG in a folder and save each result as a .mat file.

    Builds the network selected by ``args.model_type``, restores weights from
    ``checkpoints/<args.checkpoint>`` and runs it on every ``*.png`` in
    ``args.noise_dir`` (sorted by path). Each prediction is converted through
    ``ToPILImage`` to a numpy array and stored under the key
    ``"Idenoised_crop"`` in ``<args.save_img>/<name>.mat``, where ``<name>``
    is the input file name without its extension. Nothing is saved when
    ``args.save_img`` is an empty string.

    Args:
        args: parsed command-line namespace. Reads ``burst_length``,
            ``model_type``, ``checkpoint``, ``load_type``, ``noise_dir`` and
            ``save_img``.

    Returns:
        None. Returns early (after printing a message) for an unknown
        ``args.model_type``.
    """
    color = True
    burst_length = args.burst_length
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The attention variants enable channel/spatial attention; plain KPN does not.
    model_kwargs = dict(color=color,
                        burst_length=burst_length,
                        blind_est=True,
                        kernel_size=[5],
                        sep_conv=False,
                        channel_att=True,
                        spatial_att=True,
                        upMode="bilinear",
                        core_bias=False)
    if args.model_type == "attKPN":
        model = Att_KPN_DGF(**model_kwargs)
    elif args.model_type == "attKPN_Wave":
        model = Att_KPN_Wavelet_DGF(**model_kwargs)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN_DGF(**model_kwargs)
    elif args.model_type == "KPN":
        model_kwargs.update(channel_att=False, spatial_att=False)
        model = KPN_DGF(**model_kwargs)
    else:
        print(" Model type not valid")
        return

    checkpoint_dir = "checkpoints/" + args.checkpoint
    if not os.path.exists(checkpoint_dir) or len(os.listdir(checkpoint_dir)) == 0:
        print('There is no any checkpoint file in path:{}'.format(checkpoint_dir))
    # load trained model; compare the device *type* (a torch.device never
    # equals the plain string 'cuda').
    ckpt = load_checkpoint(checkpoint_dir,
                           cuda=device.type == 'cuda',
                           best_or_latest=args.load_type)
    # Strip the 'module.' prefix nn.DataParallel adds, only where present,
    # so checkpoints saved with or without DataParallel both load.
    prefix = 'module.'
    state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in ckpt['state_dict'].items())
    model.load_state_dict(state_dict)

    model.to(device)
    print('The model has been loaded from epoch {}, n_iter {}.'.format(ckpt['epoch'], ckpt['global_iter']))
    noisy_path = sorted(glob.glob(args.noise_dir + "/*.png"))
    # switch the eval mode
    model.eval()
    torch.manual_seed(0)
    to_tensor = transforms.ToTensor()
    trans = transforms.ToPILImage()
    # os.makedirs('') raises, so only create the output dir when one is given.
    save_mats = args.save_img != ''
    if save_mats and not os.path.exists(args.save_img):
        os.makedirs(args.save_img)

    with torch.no_grad():
        for path in noisy_path:
            image_noise = to_tensor(Image.open(path).convert('RGB'))
            burst_noise, image_noise_hr = load_data(image_noise, burst_length)
            burst_noise = burst_noise.to(device)
            image_noise_hr = image_noise_hr.to(device)
            if color:
                # Fold the burst and colour dims together: (b, N*c, h, w).
                b, _, _, h, w = burst_noise.size()
                feed_data = burst_noise.view(b, -1, h, w)
            else:
                feed_data = burst_noise
            pred_i, pred = model(feed_data, burst_noise[:, 0:burst_length, ...],
                                 image_noise_hr)
            del pred_i
            pred = np.array(trans(pred[0].cpu()))
            if save_mats:
                name = os.path.splitext(os.path.basename(path))[0]
                out_path = os.path.join(args.save_img, name + '.mat')
                print("save : ", out_path)
                sio.savemat(out_path, {"Idenoised_crop": pred})
Пример #4
0
from utils.training_util import save_checkpoint, MovingAverage, load_checkpoint

# Quick inspection script: build an Att_KPN_DGF network and print a
# layer-by-layer summary of it.
# NOTE(review): the second positional argument (False) is presumably
# load_checkpoint's `cuda` flag (other call sites pass `cuda=` by keyword)
# — TODO confirm against its signature.
checkpoint = load_checkpoint("../checkpoints/kpn_att_repeat_new/", False,
                             'latest')
state_dict = checkpoint['state_dict']
model = Att_KPN_DGF(color=True,
                    burst_length=4,
                    blind_est=True,
                    kernel_size=[5],
                    sep_conv=False,
                    channel_att=True,
                    spatial_att=True,
                    upMode="bilinear",
                    core_bias=False)
# NOTE(review): weight loading is commented out, so everything below runs on
# a freshly initialised model and `state_dict` is unused.
# model.load_state_dict(state_dict)
model.eval()
from torchsummary import summary
# One size tuple per forward() input, batch dim omitted. 12 channels
# presumably = 4 burst frames x 3 colour channels folded together; the third
# input is presumably the full-resolution guide image — TODO confirm.
summary(model, [(12, 256, 256), (4, 3, 256, 256), (3, 512, 512)], batch_size=1)
# Terminates the script here: the ONNX conversion code that follows never runs.
exit()
# Converting model to ONNX
print('===> Converting model to ONNX.')
try:
    for _ in model.modules():
        _.training = False

    sample_input1 = torch.randn(1, 12, 256, 256)
    sample_input2 = torch.randn(1, 4, 3, 256, 256)
    sample_input3 = torch.randn(1, 3, 512, 512)

    input_nodes = ['input']
    output_nodes = ['output']
Пример #5
0
def test_multi(args):
    """Evaluate a KPN-DGF variant on a folder of noisy PNG images.

    Builds the network selected by ``args.model_type``, restores weights from
    ``checkpoints/<args.checkpoint>`` and, for every ``*.png`` in
    ``args.noise_dir`` (sorted by path), predicts a denoised image. The clean
    reference is read from the same path with ``"noisy"`` replaced by
    ``"clean"``. PSNR / SSIM are printed both for the model output ``pred``
    and for ``pixel_shuffle(pred_i, int(sqrt(burst_length)))``. When
    ``args.save_img`` is non-empty, a titled figure of each ``pred`` is saved
    into that directory.

    Returns early (after printing a message) for an unknown
    ``args.model_type``.
    """
    color = True
    burst_length = args.burst_length
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.model_type == "attKPN":
        model = Att_KPN_DGF(color=color,
                            burst_length=burst_length,
                            blind_est=True,
                            kernel_size=[5],
                            sep_conv=False,
                            channel_att=True,
                            spatial_att=True,
                            upMode="bilinear",
                            core_bias=False)
    elif args.model_type == "attKPN_Wave":
        model = Att_KPN_Wavelet_DGF(color=color,
                                    burst_length=burst_length,
                                    blind_est=True,
                                    kernel_size=[5],
                                    sep_conv=False,
                                    channel_att=True,
                                    spatial_att=True,
                                    upMode="bilinear",
                                    core_bias=False)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN_DGF(color=color,
                                   burst_length=burst_length,
                                   blind_est=True,
                                   kernel_size=[5],
                                   sep_conv=False,
                                   channel_att=True,
                                   spatial_att=True,
                                   upMode="bilinear",
                                   core_bias=False)
    elif args.model_type == "KPN":
        model = KPN_DGF(color=color,
                        burst_length=burst_length,
                        blind_est=True,
                        kernel_size=[5],
                        sep_conv=False,
                        channel_att=False,
                        spatial_att=False,
                        upMode="bilinear",
                        core_bias=False)
    else:
        print(" Model type not valid")
        return
    checkpoint_dir = "checkpoints/" + args.checkpoint
    # Only warns; load_checkpoint below is still called.
    if not os.path.exists(checkpoint_dir) or len(
            os.listdir(checkpoint_dir)) == 0:
        print('There is no any checkpoint file in path:{}'.format(
            checkpoint_dir))
    # load trained model
    # NOTE(review): `device` is a torch.device; comparing it with the string
    # 'cuda' evaluates to False, so cuda=False is always passed here.
    # `device.type == 'cuda'` was probably intended.
    ckpt = load_checkpoint(checkpoint_dir,
                           cuda=device == 'cuda',
                           best_or_latest=args.load_type)
    state_dict = ckpt['state_dict']

    # if not args.cuda:
    # Drop the first 7 characters (the 'module.' prefix) from every key.
    # NOTE(review): the strip is unconditional, so a checkpoint whose keys do
    # not carry that prefix will fail to load.
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    # else:
    #     model.load_state_dict(ckpt['state_dict'])

    model.to(device)
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))
    # switch the eval mode
    model.eval()
    # model= save_dict['state_dict']
    trans = transforms.ToPILImage()
    torch.manual_seed(0)
    noisy_path = sorted(glob.glob(args.noise_dir + "/*.png"))
    # Clean references: same path with every "noisy" replaced by "clean".
    clean_path = [i.replace("noisy", "clean") for i in noisy_path]
    # Factor passed to pixel_shuffle below: sqrt(burst_length), truncated to int.
    upscale_factor = int(math.sqrt(burst_length))
    # NOTE(review): the loop runs without torch.no_grad(), so autograd
    # bookkeeping is kept for every forward pass.
    for i in range(len(noisy_path)):
        image_noise, image_noise_hr = load_data(noisy_path[i], burst_length)
        image_noise_hr = image_noise_hr.to(device)
        # begin = time.time()
        image_noise_batch = image_noise.to(device)
        # print(image_noise_batch.size())
        # NOTE(review): burst_size is never used.
        burst_size = image_noise_batch.size()[1]
        burst_noise = image_noise_batch.to(device)
        # print(burst_noise.size())
        # print(image_noise_hr.size())
        if color:
            # Fold the burst and colour dims together: (b, N*c, h, w).
            b, N, c, h, w = burst_noise.size()
            feedData = burst_noise.view(b, -1, h, w)
        else:
            feedData = burst_noise
        # print(feedData.size())
        pred_i, pred = model(feedData, burst_noise[:, 0:burst_length, ...],
                             image_noise_hr)
        # del pred_i
        pred_i = pred_i.detach().cpu()
        print(pred_i.size())
        # Second candidate output: pred_i passed through pixel_shuffle; it is
        # scored alongside `pred` in the "pixel_shuffle UP" print below.
        pred_full = pixel_shuffle(pred_i, upscale_factor)
        # NOTE(review): no-op self-assignment.
        pred_full = pred_full
        print(pred_full.size())

        pred = pred.detach().cpu()
        # print("Time : ", time.time()-begin)
        gt = transforms.ToTensor()(Image.open(clean_path[i]).convert('RGB'))
        gt = gt.unsqueeze(0)
        # print(pred_i.size())
        # print(pred[0].size())
        psnr_t = calculate_psnr(pred, gt)
        ssim_t = calculate_ssim(pred, gt)
        print(i, "  pixel_shuffle UP   :  PSNR : ",
              str(calculate_psnr(pred_full, gt)), " :  SSIM : ",
              str(calculate_ssim(pred_full, gt)))
        print(i, "   UP   :  PSNR : ", str(psnr_t), " :  SSIM : ", str(ssim_t))
        if args.save_img != '':
            if not os.path.exists(args.save_img):
                os.makedirs(args.save_img)
            # NOTE(review): the figure is never closed (plt.close), so one
            # figure stays open per image.
            plt.figure(figsize=(15, 15))
            plt.imshow(np.array(trans(pred[0])))
            plt.title("denoise KPN DGF " + args.model_type, fontsize=25)
            image_name = noisy_path[i].split("/")[-1].split(".")[0]
            plt.axis("off")
            plt.suptitle(image_name + "   UP   :  PSNR : " + str(psnr_t) +
                         " :  SSIM : " + str(ssim_t),
                         fontsize=25)
            plt.savefig(os.path.join(
                args.save_img, image_name + "_" + args.checkpoint + '.png'),
                        pad_inches=0)
        """
Пример #6
0
def test_multi(args):
    """Denoise the SIDD benchmark blocks and return them as one array.

    Builds the network selected by ``args.model_type``, restores weights from
    the directory ``args.checkpoint`` (used as a path directly), and runs it
    on every block of ``BenchmarkNoisyBlocksSrgb`` loaded from the .mat file
    ``args.noise_dir``.

    Args:
        args: parsed command-line namespace. Reads ``burst_length``,
            ``model_type``, ``checkpoint``, ``load_type`` and ``noise_dir``.

    Returns:
        A numpy array created with ``np.zeros_like`` on the noisy-block array
        (same shape and dtype) in which ``[i_img][i_block]`` holds the
        denoised block as produced by ``ToPILImage``. Returns None (after
        printing a message) for an unknown ``args.model_type``.
    """
    color = True
    burst_length = args.burst_length
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The attention variants enable channel/spatial attention; plain KPN does not.
    model_kwargs = dict(color=color,
                        burst_length=burst_length,
                        blind_est=True,
                        kernel_size=[5],
                        sep_conv=False,
                        channel_att=True,
                        spatial_att=True,
                        upMode="bilinear",
                        core_bias=False)
    if args.model_type == "attKPN":
        model = Att_KPN_DGF(**model_kwargs)
    elif args.model_type == "attKPN_Wave":
        model = Att_KPN_Wavelet_DGF(**model_kwargs)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN_DGF(**model_kwargs)
    elif args.model_type == "KPN":
        model_kwargs.update(channel_att=False, spatial_att=False)
        model = KPN_DGF(**model_kwargs)
    else:
        print(" Model type not valid")
        return

    checkpoint_dir = args.checkpoint
    if not os.path.exists(checkpoint_dir) or len(
            os.listdir(checkpoint_dir)) == 0:
        print('There is no any checkpoint file in path:{}'.format(
            checkpoint_dir))
    # load trained model; compare the device *type* (a torch.device never
    # equals the plain string 'cuda').
    ckpt = load_checkpoint(checkpoint_dir,
                           cuda=device.type == 'cuda',
                           best_or_latest=args.load_type)
    # Strip the 'module.' prefix nn.DataParallel adds, only where present.
    prefix = 'module.'
    state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in ckpt['state_dict'].items())
    model.load_state_dict(state_dict)

    model.to(device)
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))
    # switch the eval mode
    model.eval()

    to_tensor = transforms.ToTensor()
    trans = transforms.ToPILImage()
    torch.manual_seed(0)
    all_noisy_imgs = scipy.io.loadmat(
        args.noise_dir)['BenchmarkNoisyBlocksSrgb']
    mat_re = np.zeros_like(all_noisy_imgs)
    i_imgs, i_blocks, _, _, _ = all_noisy_imgs.shape
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for i_img in range(i_imgs):
            for i_block in range(i_blocks):
                image_noise = to_tensor(
                    Image.fromarray(all_noisy_imgs[i_img][i_block]))
                burst_noise, image_noise_hr = load_data(image_noise,
                                                        burst_length)
                burst_noise = burst_noise.to(device)
                image_noise_hr = image_noise_hr.to(device)
                if color:
                    # Fold the burst and colour dims together: (b, N*c, h, w).
                    b, _, _, h, w = burst_noise.size()
                    feed_data = burst_noise.view(b, -1, h, w)
                else:
                    feed_data = burst_noise
                _, pred = model(feed_data,
                                burst_noise[:, 0:burst_length, ...],
                                image_noise_hr)
                pred = pred.cpu()
                mat_re[i_img][i_block] = np.array(trans(pred[0]))

    return mat_re