Example no. 1
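These are evaluation/training scripts from a PyTorch image-denoising codebase, shown without their module-level imports. Below is a minimal sketch of the standard imports the snippets appear to rely on; project-specific helpers (MIRNet_DGF, MIRNet, MIRNet_kpn, MWRN_lv1, load_checkpoint, save_checkpoint, load_data_split, load_data_filter, SingleLoader, SingleLoader_raw, MovingAverage, calculate_psnr, calculate_ssim, pack_raw, losses, common, robust_loss) are assumed to come from the repository itself and are not reproduced here.

# Assumed imports for the examples below; project-specific modules are
# expected to be importable from the repository and are omitted.
import os
import time
import glob

import numpy as np
import scipy.io
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms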
def test(args):
    model = MIRNet_DGF()
    # summary(model,[[3,128,128],[0]])
    # exit()
    if args.data_type == 'rgb':
        load_data = load_data_split
    elif args.data_type == 'filter':
        load_data = load_data_filter
    else:
        print("Data type not valid")
        return
    checkpoint_dir = args.checkpoint
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # try:
    checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda', 'latest')
    start_epoch = checkpoint['epoch']
    global_step = checkpoint['global_iter']
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict)
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(start_epoch, global_step))
    # except:
    #     print('=> no checkpoint file to be loaded.')    # model.load_state_dict(state_dict)
    #     exit(1)
    model.eval()
    model = model.to(device)
    trans = transforms.ToPILImage()
    torch.manual_seed(0)
    all_noisy_imgs = scipy.io.loadmat(args.noise_dir)['siddplus_valid_noisy_srgb']
    all_clean_imgs = scipy.io.loadmat(args.gt)['siddplus_valid_gt_srgb']
    # noisy_path = sorted(glob.glob(args.noise_dir+ "/*.png"))
    # clean_path = [ i.replace("noisy","clean") for i in noisy_path]
    i_imgs, _, _, _ = all_noisy_imgs.shape
    psnrs = []
    ssims = []
    # print(noisy_path)
    for i_img in range(i_imgs):
        noise = transforms.ToTensor()(Image.fromarray(all_noisy_imgs[i_img]))
        image_noise, image_noise_hr = load_data(noise, args.burst_length)
        image_noise_hr = image_noise_hr.to(device)
        burst_noise = image_noise.to(device)
        begin = time.time()
        _, pred = model(burst_noise, image_noise_hr)
        pred = pred.detach().cpu()
        gt = transforms.ToTensor()(Image.fromarray(all_clean_imgs[i_img]))
        gt = gt.unsqueeze(0)
        psnr_t = calculate_psnr(pred, gt)
        ssim_t = calculate_ssim(pred, gt)
        psnrs.append(psnr_t)
        ssims.append(ssim_t)
        print(i_img, "   UP   :  PSNR : ", str(psnr_t), " :  SSIM : ", str(ssim_t))
        if args.save_img != '':
            if not os.path.exists(args.save_img):
                os.makedirs(args.save_img)
            plt.figure(figsize=(15, 15))
            plt.imshow(np.array(trans(pred[0])))
            plt.title("denoise KPN DGF " + args.model_type, fontsize=25)
            image_name = str(i_img)
            plt.axis("off")
            plt.suptitle(image_name + "   UP   :  PSNR : " + str(psnr_t) + " :  SSIM : " + str(ssim_t), fontsize=25)
            plt.savefig(os.path.join(args.save_img, image_name + "_" + args.checkpoint + '.png'), pad_inches=0)
            plt.close()
    print("   AVG   :  PSNR : "+ str(np.mean(psnrs))+" :  SSIM : "+ str(np.mean(ssims)))
Example no. 2
def test(args):
    model = MIRNet()
    save_img = args.save_img
    checkpoint_dir = "checkpoints/mir"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # try:
    checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda', 'latest')
    start_epoch = checkpoint['epoch']
    global_step = checkpoint['global_iter']
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict)
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(
        start_epoch, global_step))
    # except:
    #     print('=> no checkpoint file to be loaded.')    # model.load_state_dict(state_dict)
    #     exit(1)
    model.eval()
    model = model.to(device)
    trans = transforms.ToPILImage()
    torch.manual_seed(0)
    noisy_path = sorted(glob.glob(args.noise_dir + "/2_*.png"))
    clean_path = [i.replace("noisy", "clean") for i in noisy_path]
    print(noisy_path)
    for i in range(len(noisy_path)):
        noise = transforms.ToTensor()(Image.open(
            noisy_path[i]).convert('RGB'))[:, 0:args.image_size,
                                           0:args.image_size].unsqueeze(0)
        noise = noise.to(device)
        begin = time.time()
        print(noise.size())
        pred = model(noise)
        pred = pred.detach().cpu()
        gt = transforms.ToTensor()(Image.open(
            clean_path[i]).convert('RGB'))[:, 0:args.image_size,
                                           0:args.image_size]
        gt = gt.unsqueeze(0)
        psnr_t = calculate_psnr(pred, gt)
        ssim_t = calculate_ssim(pred, gt)
        print(i, "   UP   :  PSNR : ", str(psnr_t), " :  SSIM : ", str(ssim_t))
        if save_img != '':
            if not os.path.exists(args.save_img):
                os.makedirs(args.save_img)
            plt.figure(figsize=(15, 15))
            plt.imshow(np.array(trans(pred[0])))
            plt.title("denoise KPN DGF " + args.model_type, fontsize=25)
            image_name = noisy_path[i].split("/")[-1].split(".")[0]
            plt.axis("off")
            plt.suptitle(image_name + "   UP   :  PSNR : " + str(psnr_t) +
                         " :  SSIM : " + str(ssim_t),
                         fontsize=25)
            plt.savefig(os.path.join(
                args.save_img, image_name + "_" + args.checkpoint + '.png'),
                        pad_inches=0)
            plt.close()
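Both test scripts report quality through the repository's calculate_psnr and calculate_ssim helpers, whose implementations are not shown. As a reference only, PSNR over tensors scaled to [0, 1] can be computed as below; the actual helper may use a different data range or crop borders first.

# Reference PSNR for [0, 1] tensors; calculate_psnr in the repository may differ.
import torch

def psnr(pred, gt, max_val=1.0):
    mse = torch.mean((pred - gt) ** 2)
    if mse == 0:
        return float('inf')
    return 10.0 * torch.log10(max_val ** 2 / mse).item()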
Example no. 3
def train(args):
    torch.set_num_threads(args.num_workers)
    torch.manual_seed(0)
    if args.data_type == 'rgb':
        data_set = SingleLoader(noise_dir=args.noise_dir,
                                gt_dir=args.gt_dir,
                                image_size=args.image_size)
    elif args.data_type == 'raw':
        data_set = SingleLoader_raw(noise_dir=args.noise_dir,
                                    gt_dir=args.gt_dir,
                                    image_size=args.image_size)
    else:
        print("Data type not valid")
        exit()
    data_loader = DataLoader(data_set,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_func = losses.CharbonnierLoss().to(device)
    # loss_func = losses.AlginLoss().to(device)
    adaptive = robust_loss.adaptive.AdaptiveLossFunction(
        num_dims=3 * args.image_size**2, float_dtype=np.float32, device=device)
    checkpoint_dir = args.checkpoint
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    if args.model_type == "MIR":
        model = MIRNet(in_channels=args.n_colors,
                       out_channels=args.out_channels).to(device)
    elif args.model_type == "KPN":
        model = MIRNet_kpn(in_channels=args.n_colors,
                           out_channels=args.out_channels).to(device)
    else:
        print(" Model type not valid")
        return
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    optimizer.zero_grad()
    average_loss = MovingAverage(args.save_every)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               [2, 4, 6, 8, 10, 12, 14, 16],
                                               0.8)
    if args.restart:
        start_epoch = 0
        global_step = 0
        best_loss = np.inf
        print('=> no checkpoint file to be loaded.')
    else:
        try:
            checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda',
                                         'latest')
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_iter']
            best_loss = checkpoint['best_loss']
            state_dict = checkpoint['state_dict']
            # new_state_dict = OrderedDict()
            # for k, v in state_dict.items():
            #     name = "model."+ k  # remove `module.`
            #     new_state_dict[name] = v
            model.load_state_dict(state_dict)
            optimizer.load_state_dict(checkpoint['optimizer'])
            print('=> loaded checkpoint (epoch {}, global_step {})'.format(
                start_epoch, global_step))
        except Exception:
            start_epoch = 0
            global_step = 0
            best_loss = np.inf
            print('=> no checkpoint file to be loaded.')
    eps = 1e-4
    for epoch in range(start_epoch, args.epoch):
        for step, (noise, gt) in enumerate(data_loader):
            noise = noise.to(device)
            gt = gt.to(device)
            pred = model(noise)
            # print(pred.size())
            loss = loss_func(pred, gt)
            # bs = gt.size()[0]
            # diff = noise - gt
            # loss = torch.sqrt((diff * diff) + (eps * eps))
            # loss = loss.view(bs,-1)
            # loss = adaptive.lossfun(loss)
            # loss = torch.mean(loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            average_loss.update(loss)
            if global_step % args.save_every == 0:
                print(len(average_loss._cache))
                if average_loss.get_value() < best_loss:
                    is_best = True
                    best_loss = average_loss.get_value()
                else:
                    is_best = False

                save_dict = {
                    'epoch': epoch,
                    'global_iter': global_step,
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                }
                save_checkpoint(save_dict, is_best, checkpoint_dir,
                                global_step)
            if global_step % args.loss_every == 0:
                print(global_step, "PSNR  : ", calculate_psnr(pred, gt))
                print(average_loss.get_value())
            global_step += 1
        print('Epoch {} is finished.'.format(epoch))
        scheduler.step()
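Example no. 3 trains with losses.CharbonnierLoss(), and the commented-out block inside the loop spells out the same idea elementwise. A minimal sketch of a Charbonnier (smooth L1) loss module, assuming the standard sqrt((x - y)^2 + eps^2) formulation, follows; the repository's version may differ in eps or reduction.

# Minimal Charbonnier loss, matching the commented-out math in the loop above.
import torch
import torch.nn as nn

class CharbonnierLoss(nn.Module):
    def __init__(self, eps=1e-4):
        super().__init__()
        self.eps = eps

    def forward(self, pred, gt):
        diff = pred - gt
        return torch.mean(torch.sqrt(diff * diff + self.eps * self.eps))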
Example no. 4
def train(args):
    # torch.set_num_threads(4)
    # torch.manual_seed(args.seed)
    # checkpoint = utility.checkpoint(args)
    data_set = SingleLoader(noise_dir=args.noise_dir,
                            gt_dir=args.gt_dir,
                            image_size=args.image_size)
    data_loader = DataLoader(data_set,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers)

    loss_basic = BasicLoss()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint_dir = args.checkpoint
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    model = MWRN_lv1().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               [5, 10, 15, 20, 25, 30], 0.5)
    optimizer.zero_grad()
    average_loss = MovingAverage(args.save_every)
    if args.checkpoint2 != "":
        if device.type == 'cuda':
            checkpoint2 = torch.load(args.checkpoint2)
        else:
            checkpoint2 = torch.load(args.checkpoint2,
                                     map_location=torch.device('cpu'))
        state_dict2 = checkpoint2['state_dict']
        model.lv2.load_state_dict(state_dict2)
        print("load lv2 done ....")
    try:
        checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda',
                                     'latest')
        start_epoch = checkpoint['epoch']
        global_step = checkpoint['global_iter']
        best_loss = checkpoint['best_loss']
        state_dict = checkpoint['state_dict']
        # new_state_dict = OrderedDict()
        # for k, v in state_dict.items():
        #     name = "model."+ k  # remove `module.`
        #     new_state_dict[name] = v
        model.load_state_dict(state_dict)
        optimizer.load_state_dict(checkpoint['optimizer'])
        print('=> loaded checkpoint (epoch {}, global_step {})'.format(
            start_epoch, global_step))
    except Exception:
        start_epoch = 0
        global_step = 0
        best_loss = np.inf
        print('=> no checkpoint file to be loaded.')
    DWT = common.DWT()
    param = [x for name, x in model.named_parameters()]
    clip_grad_D = 1e4
    grad_norm_D = 0
    for epoch in range(start_epoch, args.epoch):
        for step, (noise, gt) in enumerate(data_loader):
            noise = noise.to(device)
            gt = gt.to(device)
            x1 = DWT(gt).to(device)
            x2 = DWT(x1).to(device)
            x3 = DWT(x2).to(device)
            pred, img_lv2, img_lv3 = model(noise)
            # print(pred.size())
            loss_pred = loss_basic(pred, gt)
            scale_loss_lv2 = loss_basic(x2, img_lv2)
            scale_loss_lv3 = loss_basic(x3, img_lv3)
            loss = loss_pred + scale_loss_lv2 + scale_loss_lv3
            optimizer.zero_grad()
            loss.backward()
            total_norm_D = nn.utils.clip_grad_norm_(param, clip_grad_D)
            grad_norm_D = (grad_norm_D * (step / (step + 1)) + total_norm_D /
                           (step + 1))
            optimizer.step()
            average_loss.update(loss)
            if global_step % args.save_every == 0:
                print("Save : epoch ", epoch,
                      " step : ", global_step, " with avg loss : ",
                      average_loss.get_value(), ",   best loss : ", best_loss)
                if average_loss.get_value() < best_loss:
                    is_best = True
                    best_loss = average_loss.get_value()
                else:
                    is_best = False
                save_dict = {
                    'epoch': epoch,
                    'global_iter': global_step,
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                }
                save_checkpoint(save_dict, is_best, checkpoint_dir,
                                global_step)
            if global_step % args.loss_every == 0:
                print(global_step, "PSNR  : ", calculate_psnr(pred, gt))
                print(average_loss.get_value())
            global_step += 1
        clip_grad_D = min(clip_grad_D, grad_norm_D)
        scheduler.step()
        print("Epoch : ", epoch, "end at step: ", global_step)
Example no. 5
def test(args):
    if args.model_type == "DGF":
        model = MIRNet_DGF(n_colors=args.n_colors,
                           out_channels=args.out_channels)
    elif args.model_type == "noise":
        model = MIRNet_noise(n_colors=args.n_colors,
                             out_channels=args.out_channels)
    else:
        print(" Model type not valid")
        return
    checkpoint_dir = args.checkpoint
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # try:

    checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda', 'latest')
    start_epoch = checkpoint['epoch']
    global_step = checkpoint['global_iter']
    state_dict = checkpoint['state_dict']
    # new_state_dict = OrderedDict()
    # for k, v in state_dict.items():
    #     name = "model." + k  # remove `module.`
    #     new_state_dict[name] = v
    model.load_state_dict(state_dict)
    print('=> loaded checkpoint (epoch {}, global_step {})'.format(
        start_epoch, global_step))
    # except:
    #     print('=> no checkpoint file to be loaded.')    # model.load_state_dict(state_dict)
    #     exit(1)
    model.eval()
    model = model.to(device)
    trans = transforms.ToPILImage()
    torch.manual_seed(0)

    all_noisy_imgs = scipy.io.loadmat(
        args.noise_dir)['ValidationNoisyBlocksRaw']
    all_clean_imgs = scipy.io.loadmat(args.gt_dir)['ValidationGtBlocksRaw']
    # noisy_path = sorted(glob.glob(args.noise_dir+ "/*.png"))
    # clean_path = [ i.replace("noisy","clean") for i in noisy_path]
    i_imgs, i_blocks, _, _ = all_noisy_imgs.shape
    psnrs = []
    ssims = []
    # print(noisy_path)
    for i_img in range(i_imgs):
        for i_block in range(i_blocks):
            noise = transforms.ToTensor()(pack_raw(
                all_noisy_imgs[i_img][i_block]))
            image_noise, image_noise_hr = load_data(noise, args.burst_length)
            image_noise_hr = image_noise_hr.to(device)
            burst_noise = image_noise.to(device)
            begin = time.time()
            _, pred = model(burst_noise, image_noise_hr)
            pred = pred.detach().cpu()
            gt = transforms.ToTensor()(
                pack_raw(all_clean_imgs[i_img][i_block]))
            gt = gt.unsqueeze(0)
            psnr_t = calculate_psnr(pred, gt)
            ssim_t = calculate_ssim(pred, gt)
            psnrs.append(psnr_t)
            ssims.append(ssim_t)
            print(i_img, "   UP   :  PSNR : ", str(psnr_t), " :  SSIM : ",
                  str(ssim_t))
            if args.save_img != '':
                if not os.path.exists(args.save_img):
                    os.makedirs(args.save_img)
                plt.figure(figsize=(15, 15))
                plt.imshow(np.array(trans(pred[0])))
                plt.title("denoise KPN DGF " + args.model_type, fontsize=25)
                image_name = str(i_img) + "_" + str(i_block)
                plt.axis("off")
                plt.suptitle(image_name + "   UP   :  PSNR : " + str(psnr_t) +
                             " :  SSIM : " + str(ssim_t),
                             fontsize=25)
                plt.savefig(os.path.join(
                    args.save_img,
                    image_name + "_" + args.checkpoint + '.png'),
                            pad_inches=0)
                plt.close()
    print("   AVG   :  PSNR : " + str(np.mean(psnrs)) + " :  SSIM : " +
          str(np.mean(ssims)))
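Example no. 5 evaluates on raw validation blocks and runs each Bayer block through pack_raw before converting it to a tensor. The repository's pack_raw is not shown; the conventional packing (popularized by the Learning-to-See-in-the-Dark code) turns an H x W Bayer mosaic into an H/2 x W/2 x 4 array, one channel per CFA position, roughly as follows. The exact channel order and any black-level normalization in the real helper may differ.

# Conventional Bayer-to-4-channel packing; channel order and normalization
# in the repository's pack_raw may differ from this sketch.
import numpy as np

def pack_raw(bayer):
    h, w = bayer.shape
    return np.stack([bayer[0:h:2, 0:w:2],   # top-left CFA position
                     bayer[0:h:2, 1:w:2],   # top-right
                     bayer[1:h:2, 1:w:2],   # bottom-right
                     bayer[1:h:2, 0:w:2]],  # bottom-left
                    axis=2).astype(np.float32)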