Пример #1
0
def train_loop(cfg, model, optimizer, criterion, train_loader, epoch):
    """Train ``model`` for one epoch as a residual predictor on ``burst_imgs[0]``.

    For every batch the model predicts a residual for the first entry of
    ``burst_imgs``; ``img0 - residual + 0.5`` is compared against ``raw_img``
    with an MSE loss and one optimizer step is taken.

    Args:
        cfg: config object; reads ``device``, ``log_interval`` and ``epochs``.
        model: module called as ``model(img0)``.
        optimizer: optimizer that is stepped once per batch.
        criterion: unused; kept so existing callers keep working.
        train_loader: iterable of ``(burst_imgs, raw_img)`` batches that
            supports ``len()``.
        epoch: current epoch index (only used in the log line).

    Returns:
        The sum of per-batch losses divided by ``len(train_loader)``.
    """
    model.train()
    model = model.to(cfg.device)
    total_loss = 0
    running_loss = 0

    for batch_idx, (burst_imgs, raw_img) in enumerate(train_loader):

        optimizer.zero_grad()
        model.zero_grad()

        # -- move data to the same device as the model --
        # (a bare .cuda() breaks on CPU and on non-default GPUs)
        raw_img = raw_img.to(cfg.device, non_blocking=True)
        burst_imgs = burst_imgs.to(cfg.device, non_blocking=True)
        img0 = burst_imgs[0]

        # -- predict residual --
        pred_res = model(img0)
        rec_img = img0 - pred_res

        # -- compare with target image --
        loss = F.mse_loss(raw_img, rec_img + 0.5)

        # -- update info --
        batch_loss = loss.item()
        running_loss += batch_loss
        total_loss += batch_loss

        # -- BP and optimize --
        loss.backward()
        optimizer.step()

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:
            # -- per-sample mse -> psnr, for logging only --
            BS = raw_img.shape[0]
            mse_loss = F.mse_loss(raw_img, rec_img + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            running_loss /= cfg.log_interval
            print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.3e" %
                  (epoch, cfg.epochs, batch_idx, len(train_loader),
                   running_loss, psnr))
            running_loss = 0
    total_loss /= len(train_loader)
    return total_loss
Пример #2
0
def test_loop(cfg, model, criterion, test_loader, epoch):
    """Evaluate the residual predictor on ``test_loader`` and report PSNR.

    For each batch the model output is subtracted from ``burst_imgs[0]``, the
    result is passed through ``rescale_noisy_image`` and compared with
    ``raw_img`` via a per-sample MSE, which is converted with ``mse_to_psnr``.
    Every ``cfg.test_log_interval`` batches a grid of reconstructions is
    saved as a PNG below ``settings.ROOT_PATH``.

    Args:
        cfg: config object; reads ``device``, ``test_log_interval`` and
            ``batch_size``.
        model: module called as ``model(img0)``.
        criterion: unused; kept so existing callers keep working.
        test_loader: iterable of ``(burst_imgs, raw_img)`` batches that
            supports ``len()``.
        epoch: epoch index, used in the output directory name.

    Returns:
        The per-batch mean PSNR averaged over ``len(test_loader)``.
    """
    model.eval()
    model = model.to(cfg.device)
    total_psnr = 0
    total_loss = 0
    with torch.no_grad():
        for batch_idx, (burst_imgs, raw_img) in enumerate(test_loader):

            BS = raw_img.shape[0]

            # -- move data to the same device as the model --
            raw_img = raw_img.to(cfg.device, non_blocking=True)
            burst_imgs = burst_imgs.to(cfg.device, non_blocking=True)
            img0 = burst_imgs[0]

            # -- denoising --
            pred_res = model(img0)
            rec_img = img0 - pred_res

            # -- compare with target image --
            rec_img = rescale_noisy_image(rec_img)
            loss = F.mse_loss(raw_img, rec_img,
                              reduction='none').reshape(BS, -1)
            loss = torch.mean(loss, 1).detach().cpu().numpy()
            psnr = mse_to_psnr(loss)

            total_psnr += np.mean(psnr)
            total_loss += np.mean(loss)

            if (batch_idx % cfg.test_log_interval) == 0:
                root = Path(
                    f"{settings.ROOT_PATH}/output/n2n/rec_imgs/e{epoch}")
                # exist_ok avoids the check-then-create race
                root.mkdir(parents=True, exist_ok=True)
                fn = root / Path(f"b{batch_idx}.png")
                nrow = int(np.sqrt(cfg.batch_size))
                rec_img = rec_img.detach().cpu()
                grid_imgs = vutils.make_grid(rec_img,
                                             padding=2,
                                             normalize=True,
                                             nrow=nrow)
                plt.imshow(grid_imgs.permute(1, 2, 0))
                plt.savefig(fn)
                plt.close('all')

    ave_psnr = total_psnr / len(test_loader)
    ave_loss = total_loss / len(test_loader)
    print("Testing results: Ave psnr %2.3e Ave loss %2.3e" %
          (ave_psnr, ave_loss))
    return ave_psnr
Пример #3
0
def train_loop_offset(cfg,model,optimizer,criterion,train_loader,epoch):
    """Train ``model`` for one epoch on a selection of burst frames.

    ``cfg.input_N`` frames, chosen from ``cfg.N`` (optionally without the
    middle one), are concatenated or stacked along dim 1 and fed to the
    model. The output is compared with an MSE loss against either the
    un-perturbed middle burst frame (``cfg.blind``) or ``raw_img`` passed
    through ``ScaleZeroMean``.

    Args:
        cfg: config object; reads ``device``, ``N``, ``input_N``,
            ``input_with_middle_frame``, ``input_noise``,
            ``input_noise_level``, ``input_noise_middle_only``,
            ``color_cat``, ``blind``, ``log_interval`` and ``epochs``.
        model: module called as ``model(stacked_burst)``.
        optimizer: optimizer that is stepped once per batch.
        criterion: unused; kept so existing callers keep working.
        train_loader: iterable of ``(burst_imgs, res_imgs, raw_img)``
            batches that supports ``len()``; ``res_imgs`` is not used.
        epoch: current epoch index (only used in the log line).

    Returns:
        The sum of per-batch losses divided by ``len(train_loader)``.
    """
    model.train()
    model = model.to(cfg.device)
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()

    # -- select input frames (same for every batch, so computed once) --
    middle_img_idx = cfg.N // 2
    input_order = np.arange(cfg.N)
    if not cfg.input_with_middle_frame:
        # hold the middle frame out of the network input
        input_order = np.delete(input_order, middle_img_idx)

    for batch_idx, (burst_imgs, res_imgs, raw_img) in enumerate(train_loader):

        optimizer.zero_grad()
        model.zero_grad()

        # -- move to the model's device and add input noise --
        burst_imgs = burst_imgs.to(cfg.device, non_blocking=True)
        burst_imgs_noisy = burst_imgs.clone()
        if cfg.input_noise:
            noise = cfg.input_noise_level
            if cfg.input_noise_middle_only:
                burst_imgs_noisy[middle_img_idx] = torch.normal(burst_imgs_noisy[middle_img_idx],noise)
            else:
                burst_imgs_noisy = torch.normal(burst_imgs_noisy,noise)

        # -- assemble network input --
        frames = [burst_imgs_noisy[input_order[x]] for x in range(cfg.input_N)]
        if cfg.color_cat:
            stacked_burst = torch.cat(frames,dim=1)
        else:
            stacked_burst = torch.stack(frames,dim=1)

        # -- extract target image --
        if cfg.blind:
            t_img = burst_imgs[middle_img_idx]
        else:
            t_img = szm(raw_img.to(cfg.device, non_blocking=True))

        # -- denoising --
        rec_img = model(stacked_burst)

        # -- compute loss --
        loss = F.mse_loss(t_img,rec_img)

        # -- update info --
        batch_loss = loss.item()
        running_loss += batch_loss
        total_loss += batch_loss

        # -- BP and optimize --
        loss.backward()
        optimizer.step()

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- per-sample mse -> psnr, for logging only --
            BS = raw_img.shape[0]
            raw_img = raw_img.to(cfg.device, non_blocking=True)
            mse_loss = F.mse_loss(raw_img,rec_img+0.5,reduction='none').reshape(BS,-1)
            mse_loss = torch.mean(mse_loss,1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            running_loss /= cfg.log_interval
            print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.3e"%(epoch, cfg.epochs, batch_idx,
                                                         len(train_loader),
                                                         running_loss,psnr))
            running_loss = 0
    total_loss /= len(train_loader)
    return total_loss
Пример #4
0
def test_loop_offset(cfg,model,criterion,test_loader,epoch):
    """Evaluate the burst model on ``test_loader`` and report PSNR.

    ``cfg.input_N`` frames, chosen from ``cfg.N`` (optionally without the
    middle one), are concatenated or stacked along dim 1 and fed to the
    model. The output is passed through ``rescale_noisy_image`` and compared
    with ``raw_img`` via a per-sample MSE converted with ``mse_to_psnr``.
    Every ``cfg.test_log_interval`` batches a grid of reconstructions is
    saved as a PNG below ``settings.ROOT_PATH``.

    Args:
        cfg: config object; reads ``device``, ``N``, ``input_N``,
            ``input_with_middle_frame``, ``color_cat``, ``blind``,
            ``test_log_interval`` and ``batch_size``.
        model: module called as ``model(stacked_burst)``.
        criterion: unused; kept so existing callers keep working.
        test_loader: iterable of ``(burst_imgs, res_imgs, raw_img)``
            batches that supports ``len()``; ``res_imgs`` is not used.
        epoch: epoch index, used in the output directory name.

    Returns:
        The per-batch mean PSNR averaged over ``len(test_loader)``.
    """
    model.eval()
    model = model.to(cfg.device)
    total_psnr = 0
    total_loss = 0

    # -- selecting input frames (same for every batch, so computed once) --
    input_order = np.arange(cfg.N)
    if not cfg.input_with_middle_frame:
        # hold the middle frame out of the network input
        input_order = np.delete(input_order, cfg.N // 2)

    with torch.no_grad():
        for batch_idx, (burst_imgs, res_imgs, raw_img) in enumerate(test_loader):

            BS = raw_img.shape[0]

            # -- move data to the same device as the model --
            raw_img = raw_img.to(cfg.device, non_blocking=True)
            burst_imgs = burst_imgs.to(cfg.device, non_blocking=True)

            # -- assemble network input --
            frames = [burst_imgs[input_order[x]] for x in range(cfg.input_N)]
            if cfg.color_cat:
                stacked_burst = torch.cat(frames,dim=1)
            else:
                stacked_burst = torch.stack(frames,dim=1)

            # -- direct denoising --
            rec_img = model(stacked_burst)

            # -- compare with target image --
            rec_img = rescale_noisy_image(rec_img)
            loss = F.mse_loss(raw_img,rec_img,reduction='none').reshape(BS,-1)
            loss = torch.mean(loss,1).detach().cpu().numpy()
            psnr = mse_to_psnr(loss)

            total_psnr += np.mean(psnr)
            total_loss += np.mean(loss)

            if (batch_idx % cfg.test_log_interval) == 0:
                root = Path(f"{settings.ROOT_PATH}/output/n2n/offset_out_noise/rec_imgs/N{cfg.N}/e{epoch}")
                # exist_ok avoids the check-then-create race
                root.mkdir(parents=True, exist_ok=True)
                fn = root / Path(f"b{batch_idx}.png")
                nrow = int(np.sqrt(cfg.batch_size))
                rec_img = rec_img.detach().cpu()
                grid_imgs = vutils.make_grid(rec_img, padding=2, normalize=True, nrow=nrow)
                plt.imshow(grid_imgs.permute(1,2,0))
                plt.savefig(fn)
                plt.close('all')

    ave_psnr = total_psnr / len(test_loader)
    ave_loss = total_loss / len(test_loader)
    print("[Blind: %d | N: %d] Testing results: Ave psnr %2.3e Ave loss %2.3e"%(cfg.blind,cfg.N,ave_psnr,ave_loss))
    return ave_psnr
Пример #5
0
def test_loop(cfg, model, test_loader, epoch):
    """Evaluate the align+denoise model on every batch of ``test_loader``.

    Each batch is optionally aligned (``cfg.abps_inputs``), center-cropped
    with ``crop_center_patch``, truncated to ``cfg.nframes`` frames and fed
    to the model. The fourth model output (``denoised_ave``) is compared with
    the normalized target from ``get_nmlz_tgt_img`` via a per-sample MSE
    converted with ``mse_to_psnr``.

    Args:
        cfg: config object; reads ``device``, ``use_seed``, ``seed``,
            ``abps_inputs``, ``lpas_method``, ``optical_flow_acc``,
            ``nblocks``, ``nframes``, ``frame_size`` and ``N``.
        model: module returning six outputs when called on a burst; also
            exposes ``align_info.model``, ``denoiser_info.model`` and
            ``unet_info.model`` sub-models.
        test_loader: loader whose iterator supports ``len()`` and yields
            dicts with ``'burst'``, ``'clean'`` and ``'directions'``.
        epoch: epoch index, mixed into the RNG seed when ``cfg.use_seed``.

    Returns:
        ``(psnr_ave, record_test)``: the mean PSNR over all evaluated samples
        and a DataFrame with a ``'psnr'`` column (left empty because
        recording is disabled inside this function).
    """
    model.eval()
    model.align_info.model.eval()
    model.denoiser_info.model.eval()
    model.unet_info.model.eval()
    model = model.to(cfg.device)
    total_psnr = 0
    total_loss = 0
    use_record = False
    record_test = pd.DataFrame({'psnr': []})

    if cfg.use_seed:
        init = torch.initial_seed()
        torch.manual_seed(cfg.seed + 1 + epoch + init)
    test_iter = iter(test_loader)
    num_batches = len(test_iter)
    # per-batch PSNR arrays; a list tolerates a smaller final batch, which a
    # fixed (num_batches, batch_size) buffer does not
    psnrs = []

    with torch.no_grad():
        for batch_idx in range(num_batches):

            sample = next(test_iter)
            burst, raw_img, motion = sample['burst'], sample['clean'], sample[
                'directions']
            B = raw_img.shape[0]

            # -- move data to the same device as the model --
            raw_img = raw_img.to(cfg.device, non_blocking=True)
            burst = burst.to(cfg.device, non_blocking=True)

            # -- align images if necessary --
            if cfg.abps_inputs:
                if cfg.lpas_method == "spoof":
                    _, aligned = lpas_spoof(burst, motion, cfg.nblocks,
                                            "global", cfg.optical_flow_acc)
                else:
                    ref_frame = (cfg.nframes + 1) // 2
                    _, aligned, _ = lpas_search(burst, ref_frame, cfg.nblocks,
                                                motion, cfg.lpas_method)
                burst = aligned.clone()

            # -- crop to the center patch and keep the first nframes --
            # NOTE: the former extra crop of `aligned` referenced an undefined
            # name (`spacing`) and its result was never read, so it is dropped.
            cropped = crop_center_patch([burst, raw_img], cfg.nframes,
                                        cfg.frame_size)
            burst, raw_img = cropped[0], cropped[1]
            burst = burst[:cfg.nframes]

            # -- denoising (model must return exactly six outputs) --
            _, _, _, denoised_ave, _, _ = model(burst)
            denoised_ave = denoised_ave.detach()

            # -- compare with normalized target --
            raw_img = get_nmlz_tgt_img(cfg, raw_img)

            # -- compute psnr --
            loss = F.mse_loss(raw_img, denoised_ave,
                              reduction='none').reshape(B, -1)
            loss = torch.mean(loss, 1).detach().cpu().numpy()
            psnr = mse_to_psnr(loss)
            psnrs.append(np.atleast_1d(psnr))

            if use_record:
                # DataFrame.append was removed in pandas 2.0
                row = pd.DataFrame([{'psnr': psnr}])
                record_test = pd.concat([record_test, row],
                                        ignore_index=True)
            total_psnr += np.mean(psnr)
            total_loss += np.mean(loss)

            if batch_idx % 100 == 0:
                print("[%d/%d] Test PSNR: %2.2f" %
                      (batch_idx, num_batches, total_psnr / (batch_idx + 1)),
                      flush=True)

    all_psnrs = np.concatenate(psnrs) if psnrs else np.zeros(0)
    psnr_ave = np.mean(all_psnrs)
    psnr_std = np.std(all_psnrs)
    ave_loss = total_loss / num_batches
    print("[N: %d] Testing: [psnr: %2.2f +/- %2.2f] [ave loss %2.3e]" %
          (cfg.N, psnr_ave, psnr_std, ave_loss),
          flush=True)
    return psnr_ave, record_test
Пример #6
0
def train_loop(cfg, model, scheduler, train_loader, epoch, record_losses,
               writer):
    """Run one training epoch of the align + denoise model.

    Per step: fetch a batch, align it with ``align_burst``, build training
    pairs with ``subsample_aligned`` / ``create_training_pairs``, and take
    one optimizer + scheduler step per pair. Every ``cfg.log_interval`` steps
    several PSNR statistics are printed and optionally sent to ``writer``.

    NOTE(review): as written this function references several names that are
    not defined anywhere in its body (``optim``, ``losses``, ``gt_nnf``,
    ``burst_og``, ``stacked_burst``, ``burst_loss``, ``ave_loss``) and divides
    by counters that are never incremented. Unless those names exist at
    module level it cannot run to completion; each spot is flagged inline.

    Args:
        cfg: config object (device, N, noise_params, batch_size, log_interval,
            epochs, use_seed, seed, global_step, ... are read below).
        model: wrapper exposing ``align_info.model``, ``denoiser_info.model``
            and ``unet_info.model``; also callable on a burst.
        scheduler: LR scheduler, stepped after every optimizer step.
        train_loader: loader yielding dicts with ``'burst'``, ``'clean'``,
            ``'directions'`` and ``'iid'``.
        epoch: epoch index (seed offset and log line).
        record_losses: DataFrame of past records, or ``None`` to create one.
        writer: summary writer with ``add_scalar``/``add_scalars``, or falsy.

    Returns:
        ``(total_loss, record_losses)`` where ``total_loss`` is the summed
        step loss divided by ``len(train_loader)``.
    """

    # -=-=- Setup for epoch -=-=-

    model.align_info.model.train()
    model.denoiser_info.model.train()
    model.unet_info.model.train()
    model.denoiser_info.model = model.denoiser_info.model.to(cfg.device)
    model.align_info.model = model.align_info.model.to(cfg.device)
    model.unet_info.model = model.unet_info.model.to(cfg.device)

    # N is only used below to pick the middle frame for the BM3D baseline.
    N = cfg.N
    # NOTE(review): total_loss / running_loss are re-initialised again in the
    # record-keeping section below.
    total_loss = 0
    running_loss = 0
    # NOTE(review): szm and unfold are not referenced again in this function.
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    # Recording is hard-disabled; the `if use_record:` branch below is dead.
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })
    noise_type = cfg.noise_params.ntype

    # -=-=- Init Record Keeping -=-=-

    # NOTE(review): none of these accumulators/counters is updated inside the
    # step loop, so the `x / x_count` averages in the logging branch divide
    # by zero.
    align_mse_losses, align_mse_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0
    dynamics_acc, dynamics_count = 0, 0

    # Example writing is hard-disabled; write_examples_iter is overwritten
    # further down once steps_per_epoch is known.
    write_examples = False
    write_examples_iter = 200
    # Gaussian noise std; divided by 255 when passed to BM3D below.
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=- Load Pre-Simulated Random Numbers -=-=-

    # NOTE(review): kindex_ds is not referenced again in this function.
    if cfg.use_kindex_lmdb: kindex_ds = kIndexPermLMDB(cfg.batch_size, cfg.N)

    # -=-=- Dataset Augmentation -=-=-

    transforms = [tvF.vflip, tvF.hflip, tvF.rotate]
    aug = RandomChoice(transforms)

    # NOTE(review): apply_transformations is defined but never called in
    # this function.
    def apply_transformations(burst, gt_img):
        """Apply one transform drawn by ``aug`` to the burst and its target.

        The target is prepended to the burst so both go through the same
        ``aug`` call; before/after grids are written to
        ``aug_original.png`` / ``aug_augmented.png`` in the working directory.
        """
        N, B = burst.shape[:2]
        gt_img_rs = rearrange(gt_img, 'b c h w -> 1 b c h w')
        all_images = torch.cat([gt_img_rs, burst], dim=0)
        all_images = rearrange(all_images, 'n b c h w -> (n b) c h w')
        tv_utils.save_image(all_images,
                            'aug_original.png',
                            nrow=N + 1,
                            normalize=True)
        aug_images = aug(all_images)
        tv_utils.save_image(aug_images,
                            'aug_augmented.png',
                            nrow=N + 1,
                            normalize=True)
        aug_images = rearrange(aug_images, '(n b) c h w -> n b c h w', b=B)
        # index 0 is the target that was prepended above
        aug_gt_img = aug_images[0]
        aug_burst = aug_images[1:]
        return aug_burst, aug_gt_img

    # -=-=- Half Precision (disabled) -=-=-

    # model.align_info.model.half()
    # model.denoiser_info.model.half()
    # model.unet_info.model.half()
    # models = [model.align_info.model,
    #           model.denoiser_info.model,
    #           model.unet_info.model]
    # for model_l in models:
    #     model_l.half()
    #     for layer in model_l.modules():
    #         if isinstance(layer, torch.nn.BatchNorm2d):
    #             layer.float()

    # -=-=- Init Loss Functions -=-=-

    # NOTE(review): these loss objects are constructed but not used below.
    # NOTE(review): if cfg.supervised is a plain Python bool, `~` is bitwise
    # (~True == -2, ~False == -1) and both results are truthy; `not` was
    # probably intended -- confirm against BurstRecLoss.
    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss(alpha=cfg.kpn_burst_alpha,
                                  gradient_L1=~cfg.supervised)
    # denoiseLossOT = BurstResidualLoss()
    entropyLoss = EntropyLoss()

    # -=-=- Add hooks for epoch -=-=-

    # Register a forward hook on every child named "filter_cls" one level
    # below the alignment model's children; handles are removed at the end.
    align_hook = AlignmentFilterHooks(cfg.N)
    align_hooks = []
    for kpn_module in model.align_info.model.children():
        for name, layer in kpn_module.named_children():
            if name == "filter_cls":
                align_hook_handle = layer.register_forward_hook(align_hook)
                align_hooks.append(align_hook_handle)

    # -=-=- Noise2Noise -=-=-

    # NOTE(review): noise_xform is not referenced again in this function.
    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)

    # -=-=- Final Configs -=-=-

    use_timer = False
    # NOTE(review): `one` and `switch` are not referenced again below.
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    if use_timer:
        data_clock = Timer()
        clock = Timer()
    ds_size = len(train_loader)
    small_ds = ds_size < 500
    # NOTE(review): for loaders shorter than 500 batches this still runs 500
    # steps, but train_iter below is created once and never restarted, so
    # next(train_iter) would raise StopIteration once it is exhausted.
    steps_per_epoch = ds_size if not small_ds else 500

    write_examples_iter = steps_per_epoch // 3
    all_filters = []

    # -=-=- Start Epoch -=-=-
    dynamics_acc_i = -1.
    if cfg.use_seed:
        init = torch.initial_seed()
        torch.manual_seed(cfg.seed + 1 + epoch + init)
    train_iter = iter(train_loader)
    for batch_idx in range(steps_per_epoch):

        # -=-=- Setting up for Iteration -=-=-

        # -- setup iteration timer --
        if use_timer:
            data_clock.tic()
            clock.tic()

        # -- grab data batch --
        sample = next(train_iter)
        burst, raw_img, motion = sample['burst'], sample['clean'], sample[
            'directions']
        raw_img_iid = sample['iid']
        raw_img_iid = raw_img_iid.cuda(non_blocking=True)
        burst = burst.cuda(non_blocking=True)

        # -- align, subsample and build (input, target) training pairs --
        aligned, est_nnf = align_burst(cfg, burst, model)
        sim_images = subsample_aligned(cfg, aligned)
        burst_in, tgt_out = create_training_pairs(burst, sim_images)

        dn_losses = []
        # NOTE(review): the loop variable rebinds `burst`, so after this loop
        # `burst` is the last training input rather than the loaded batch.
        for burst, target in zip(burst_in, tgt_out):

            # -- forward pass --
            est_denoised = model(burst)
            dn_loss = compute_denoising_loss(est_denoised, target)

            # -- compute grads --
            # Deterministic mode is switched off only around backward().
            # NOTE(review): torch.set_deterministic was later replaced by
            # torch.use_deterministic_algorithms -- check the pinned version.
            if cfg.use_seed: torch.set_deterministic(False)
            dn_loss.backward()
            if cfg.use_seed: torch.set_deterministic(True)

            # -- backprop --
            # NOTE(review): `optim` is not a parameter or local of this
            # function, and no zero_grad() call is visible in this loop.
            optim.step()
            scheduler.step()

            # -- store info --
            # NOTE(review): `losses` is undefined here; the list created
            # above is `dn_losses`.
            losses.append(dn_loss.item())

        # -- average over losses --
        # NOTE(review): dn_losses is a Python list; torch.mean expects a
        # tensor.
        dn_loss = torch.mean(dn_losses)

        # -- alignment loss --
        # NOTE(review): `gt_nnf` is not defined in this function. Also note
        # that align_loss only feeds the logged totals; no backward() is
        # called on it here.
        align_loss = compute_nnf_loss(gt_nnf, est_nnf)

        # -- total loss --
        final_loss = dn_loss + align_loss
        running_loss += final_loss.item()
        total_loss += final_loss.item()

        # -=-=- Printing to Stdout -=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- recompute model output for original images --
            # NOTE(review): `burst_og` is not defined in this function.
            # The model output is expected to unpack into exactly 4 + 2 items.
            outputs = model(burst_og)
            m_aligned, m_aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]

            # -- compute mse for fun --
            B = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)
            raw_img = get_nmlz_tgt_img(cfg, raw_img)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, m_aligned_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(burst_og, dim=0)
            if noise_type == "qis": mis_ave = quantize_img(cfg, mis_ave)
            mse_loss = F.mse_loss(raw_img, mis_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # tv_utils.save_image(raw_img,"raw.png",nrow=1,normalize=True,range=(-0.5,1.25))
            # tv_utils.save_image(mis_ave,"mis.png",nrow=1,normalize=True,range=(-0.5,1.25))

            # -- psnr for [bm3d] --
            # NOTE(review): `burst` here is the last training input from the
            # pair loop above, not the batch loaded at the top of the step.
            mid_img_og = burst[N // 2]
            bm3d_nb_psnrs = []
            # BM3D runs on the CPU per image, so at most 4 samples are used.
            M = 4 if B > 4 else B
            for b in range(M):
                bm3d_rec = bm3d.bm3d(mid_img_og[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                # maybe an issue here
                b_loss = F.mse_loss(raw_img[b].cpu(),
                                    bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for input averaged frames --
            # burst_ave = torch.mean(burst_og,dim=0)
            # mse_loss = F.mse_loss(raw_img,burst_ave,reduction='none').reshape(B,-1)
            # mse_loss = torch.mean(mse_loss,1).detach().cpu().numpy()
            # psnr_input_ave = np.mean(mse_to_psnr(mse_loss))
            # psnr_input_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for aligned + denoised --
            # The target is repeated along a new dim 1 to match `denoised`.
            R = denoised.shape[1]
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, R, 1, 1, 1)
            # if noise_type == "qis": denoised = quantize_img(cfg,denoised)
            # save_image(denoised_ave,"denoised_ave.png")
            # save_image(denoised,"denoised.png")
            mse_loss = F.mse_loss(raw_img_repN, denoised,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img, denoised_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- reconstruction MSE --
            # NOTE(review): rec_mse_count is still 0 here (never incremented)
            # so this raises ZeroDivisionError; same for the two below.
            rec_mse_ave = rec_mse_losses / rec_mse_count
            rec_mse_losses, rec_mse_count = 0, 0

            # -- reconstruction Dist. --
            rec_ot_ave = rec_ot_losses / rec_ot_count
            rec_ot_losses, rec_ot_count = 0, 0

            # -- ave dynamic acc --
            ave_dyn_acc = dynamics_acc / dynamics_count * 100.
            dynamics_acc, dynamics_count = 0, 0

            # -- write record --
            # Dead while use_record is False.
            # NOTE(review): burst_loss and ave_loss are undefined here.
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                record_losses = record_losses.append(info, ignore_index=True)

            # -- write to stdout --
            # 17 values, matching the 17 placeholders in the format string.
            write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch,
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, ave_dyn_acc)  #rec_ot_ave)

            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [dyn]: %.2e"
                % write_info,
                flush=True)
            # -- write to summary writer --
            if writer:
                writer.add_scalar('train/running-loss', running_loss,
                                  cfg.global_step)
                writer.add_scalars('train/model-psnr', {
                    'ave': psnr,
                    'std': psnr_std
                }, cfg.global_step)
                writer.add_scalars('train/dn-frame-psnr', {
                    'ave': psnr_denoised_ave,
                    'std': psnr_denoised_std
                }, cfg.global_step)

            # -- reset loss --
            running_loss = 0

        # -- write examples --
        # Dead while write_examples is False.
        # NOTE(review): `stacked_burst` is undefined here and `denoised` only
        # exists after a logging step has run.
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               all_filters, motion)

        if use_timer: clock.toc()

        if use_timer:
            print("data_clock", data_clock.average_time)
            print("clock", clock.average_time)
        # The step counter lives on cfg, so it persists across calls.
        cfg.global_step += 1

    # -- remove hooks --
    for hook in align_hooks:
        hook.remove()

    # NOTE(review): normalises by len(train_loader) although the loop ran
    # steps_per_epoch iterations (these differ for loaders under 500 batches).
    total_loss /= len(train_loader)
    return total_loss, record_losses
Пример #7
0
def train_loop_mse(cfg, model, optimizer, criterion, train_loader, epoch):
    """Train ``model`` for one epoch with a bootstrap loss plus a decaying MSE term.

    Every batch is unpacked as ``(burst, res_img, raw_img, d)``.  ``burst`` is
    indexed as ``(t, b, c, h, w)`` and flattened to ``((t b), c, h, w)`` before
    it is fed to the model.  The optimised loss is the mean of
    ``compute_bootstrap_loss(denoised, B, T, R=100)`` plus
    ``F.mse_loss(burst, denoised)`` weighted by
    ``(1 / (cfg.global_step + 1)) ** 1.2``.

    Args:
        cfg: configuration object.  Reads ``device``, ``global_step``,
            ``log_interval``, ``epochs`` and ``noise_params``; overwrites
            ``noise_params['qis']`` entries and increments ``global_step``
            once per batch.
        model: network called on the flattened burst.
        optimizer: optimizer stepped once per batch.
        criterion: unused; kept for interface compatibility.
        train_loader: iterable of ``(burst, res_img, raw_img, d)`` batches.
        epoch: current epoch index, used only in the log line.

    Returns:
        The sum of per-batch loss values divided by ``len(train_loader)``.
    """
    model.train()
    model = model.to(cfg.device)
    total_loss = 0
    running_loss = 0

    # -- qis parameters are overwritten on the (shared) config object --
    cfg.noise_params['qis']['alpha'] = 255.0
    cfg.noise_params['qis']['readout'] = 0.0
    cfg.noise_params['qis']['nbits'] = 8
    # NOTE(review): the transform is not used by the active loss path below;
    # the call is kept because its side effects (if any) are not visible here.
    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)

    for batch_idx, (burst, res_img, raw_img, d) in enumerate(train_loader):

        optimizer.zero_grad()
        model.zero_grad()

        # -- reshaping of data --
        BS = raw_img.shape[0]
        raw_img = raw_img.cuda(non_blocking=True)
        burst = burst.cuda(non_blocking=True)
        T, B = burst.shape[:2]

        # -- denoise all frames at once --
        burst = rearrange(burst, 't b c h w -> (t b) c h w')
        denoised = model(burst)

        # -- bootstrap loss + decaying identity (mse) regulariser --
        loss = compute_bootstrap_loss(denoised, B, T, R=100)
        loss = torch.mean(loss)
        mse_coeff = (1 / (cfg.global_step + 1.))**1.2
        loss = loss + mse_coeff * F.mse_loss(burst, denoised)

        # -- update info --
        running_loss += loss.item()
        total_loss += loss.item()

        # -- BP and optimize --
        loss.backward()
        optimizer.step()

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- restore the frame axis so burst[0] is the first frame --
            burst = rearrange(burst, '(t b) c h w -> t b c h w', t=T)
            # this forward pass is for metrics only; no autograd graph needed
            with torch.no_grad():
                rec_img = model(burst[0])
            print_tensor_stats("burst", burst)

            # -- per-image mse -> psnr --
            mse = F.mse_loss(raw_img, rec_img + 0.5,
                             reduction='none').reshape(BS, -1)
            mse = torch.mean(mse, 1).detach().cpu().numpy()
            psnr = mse_to_psnr(mse)
            psnr_ave = np.mean(psnr)
            psnr_std = np.std(psnr)

            print_tensor_stats("rec_img", rec_img + 0.5)
            print_tensor_stats("raw_img", raw_img)

            running_loss /= cfg.log_interval
            print("[%d/%d][%d/%d]: %2.3e [PSNR] %2.2f +/- %2.2f " %
                  (epoch, cfg.epochs, batch_idx, len(train_loader),
                   running_loss, psnr_ave, psnr_std))
            running_loss = 0

        # Advance the same counter that the mse coefficient reads; the
        # original incremented `cfg.global_steps`, so the decay never moved.
        cfg.global_step += 1
    total_loss /= len(train_loader)
    return total_loss
Пример #8
0
def test_loop_mse(cfg, model, criterion, test_loader, epoch):
    """Evaluate ``model`` on ``test_loader`` and return the average PSNR.

    For each batch the model is applied to ``burst[0]``; the output shifted by
    ``+0.5`` is compared against ``raw_img`` with a per-image MSE, which is
    converted to PSNR via ``mse_to_psnr``.  Every ``cfg.test_log_interval``
    batches a grid of the reconstructions is saved as a PNG under
    ``{settings.ROOT_PATH}/output/mse/rec_imgs/e{epoch}``.

    ``criterion`` is accepted but not used.

    Returns:
        The mean over batches of the per-batch mean PSNR.
    """
    model.eval()
    model = model.to(cfg.device)
    psnr_accum = 0
    mse_accum = 0

    # read from the config but not used further in this function
    noise_type = cfg.noise_params.ntype

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            burst, res_img, raw_img, d = batch
            BS = raw_img.shape[0]

            # -- move data to the gpu --
            raw_img = raw_img.cuda(non_blocking=True)
            burst = burst.cuda(non_blocking=True)

            # -- denoise the first frame of the burst --
            rec_img = model(burst[0])

            # -- per-image mse against the raw image --
            sq_err = F.mse_loss(raw_img, rec_img + 0.5, reduction='none')
            loss = torch.mean(sq_err.reshape(BS, -1), 1)
            loss = loss.detach().cpu().numpy()
            psnr = mse_to_psnr(loss)

            psnr_accum += np.mean(psnr)
            mse_accum += np.mean(loss)

            # -- only dump example images on logging iterations --
            if (batch_idx % cfg.test_log_interval) != 0:
                continue

            out_dir = Path(
                f"{settings.ROOT_PATH}/output/mse/rec_imgs/e{epoch}")
            if not out_dir.exists():
                out_dir.mkdir(parents=True)
            out_file = out_dir / Path(f"b{batch_idx}.png")
            grid = tv_utils.make_grid(rec_img.detach().cpu(),
                                      padding=2,
                                      normalize=True,
                                      nrow=int(np.sqrt(cfg.batch_size)))
            plt.imshow(grid.permute(1, 2, 0))
            plt.savefig(out_file)
            plt.close('all')

    n_batches = len(test_loader)
    ave_psnr = psnr_accum / n_batches
    ave_loss = mse_accum / n_batches
    print("Testing results: Ave psnr %2.3e Ave loss %2.3e" %
          (ave_psnr, ave_loss))
    return ave_psnr
Пример #9
0
def train_loop_offset(cfg,model,optimizer,criterion,train_loader,epoch):
    """Train the burst model for one epoch and periodically log PSNR metrics.

    Each batch is unpacked as ``(burst_imgs, res_imgs, raw_img, directions)``.
    The first ``cfg.input_N`` frames (optionally perturbed with extra Gaussian
    noise) are both stacked (``dim=1``) and concatenated (``dim=1``) and passed
    to ``model``, which returns ``(aligned, rec_img, filters)``.  The optimised
    loss is ``np.sum(criterion(aligned, rec_img, t_img, cfg.global_step))``
    where ``t_img`` is the zero-mean raw image when ``cfg.supervised`` and the
    middle burst frame otherwise.  A pairwise KL term is computed for logging
    only; it is not part of the optimised loss.

    Args:
        cfg: configuration object; ``cfg.global_step`` is incremented once per
            batch.
        model: callable ``model(cat_burst, stacked_burst)``.
        optimizer: optimizer stepped once per batch.
        criterion: callable returning the loss terms summed with ``np.sum``.
        train_loader: iterable of batches as described above.
        epoch: current epoch index, used only in the log line.

    Returns:
        ``(total_loss, record)``: the summed batch losses divided by
        ``len(train_loader)`` and the record object from ``init_record()``
        (only extended when the local ``use_record`` flag is enabled).
    """
    model.train()
    model = model.to(cfg.device)
    N = cfg.N
    total_loss = 0
    running_loss = 0
    sf_losses,sf_count = 0,0
    kl_losses,kl_count = 0,0
    temporal_losses,temporal_count = 0,0
    write_examples = True
    write_examples_iter = 800
    szm = ScaleZeroMean()
    record = init_record()
    use_record = False

    for batch_idx, (burst_imgs, res_imgs, raw_img, directions) in enumerate(train_loader):

        optimizer.zero_grad()
        model.zero_grad()

        # -- select input frames and the middle (target) frame --
        # The original branched on cfg.input_with_middle_frame, but both
        # branches selected the same middle index and the same input order.
        input_order = np.arange(cfg.N)
        middle_img_idx = input_order[len(input_order) // 2]

        burst_imgs = burst_imgs.cuda(non_blocking=True)

        # -- add input noise --
        burst_imgs_noisy = burst_imgs.clone()
        if cfg.input_noise:
            noise = np.random.rand() * cfg.input_noise_level
            if cfg.input_noise_middle_only:
                burst_imgs_noisy[middle_img_idx] = torch.normal(burst_imgs_noisy[middle_img_idx],noise)
            else:
                burst_imgs_noisy = torch.normal(burst_imgs_noisy,noise)

        # -- create inputs for kpn --
        input_frames = [burst_imgs_noisy[input_order[x]] for x in range(cfg.input_N)]
        stacked_burst = torch.stack(input_frames,dim=1)
        cat_burst = torch.cat(input_frames,dim=1)

        # -- extract target image --
        mid_img = burst_imgs[middle_img_idx]
        raw_img_zm = szm(raw_img.cuda(non_blocking=True))
        if cfg.supervised: t_img = szm(raw_img.cuda(non_blocking=True))
        else: t_img = burst_imgs[middle_img_idx]

        # -- direct denoising --
        mis_ave = torch.mean(stacked_burst,dim=1)
        aligned,rec_img,filters = model(cat_burst,stacked_burst)
        # temporal loss is disabled; a -1 placeholder keeps the log format
        temporal_loss = torch.FloatTensor([-1.]).to(cfg.device)

        # -- sparse filter loss (sf_loss) is disabled; -1 placeholder --
        sf_loss = torch.FloatTensor([-1.]).to(cfg.device)

        # -- compute loss to optimize --
        losses = criterion(aligned, rec_img, t_img, cfg.global_step)
        loss = np.sum(losses)
        kpn_loss = loss
        kpn_coeff = 1.

        # -- compute kl loss (logged only, not optimized) --
        if cfg.supervised: kl_ref = szm(raw_img.cuda(non_blocking=True))
        else: kl_ref = rec_img
        residuals = aligned - kl_ref.unsqueeze(1).repeat(1,N,1,1,1)
        residuals = rearrange(residuals,'b n c h w -> b n (h w) c')
        kl_loss = kl_pairwise_bp(residuals,K=100,supervised=cfg.supervised)

        # -- final loss --
        loss = kpn_coeff * kpn_loss

        # -- update info --
        running_loss += loss.item()
        total_loss += loss.item()

        # -- update sparse filter loss info --
        sf_losses += sf_loss.item()
        sf_count += 1

        # -- update temporal loss info --
        temporal_losses += temporal_loss.item()
        temporal_count += 1

        # -- update kl loss info --
        kl_losses += kl_loss.item()
        kl_count += 1

        # -- BP and optimize --
        loss.backward()
        optimizer.step()

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- compute psnr for [rec img] --
            BS = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)
            mse_loss = F.mse_loss(raw_img,rec_img+0.5,reduction='none').reshape(BS,-1)
            mse_loss = torch.mean(mse_loss,1).detach().cpu().numpy()
            rec_psnrs = mse_to_psnr(mse_loss)
            psnr_ave = np.mean(rec_psnrs)
            psnr_std = np.std(rec_psnrs)
            running_loss /= cfg.log_interval

            # -- psnr for misaligned ave --
            mse_loss = F.mse_loss(raw_img,mis_ave+0.5,reduction='none').reshape(BS,-1)
            mse_loss = torch.mean(mse_loss,1).detach().cpu().numpy()
            mis_psnrs = mse_to_psnr(mse_loss)
            mis_psnr_ave = np.mean(mis_psnrs)
            mis_psnr_std = np.std(mis_psnrs)

            # -- psnr for [bm3d] --
            bm3d_nb_psnrs = []
            for b in range(BS):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0,2)+0.5, sigma_psd=25/255, stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0,2)
                # raw_img[b] is a single image, so the squared error is
                # flattened to ONE row; reshaping to (BS,-1) as before split
                # one image into BS chunks and averaged chunk-wise PSNRs.
                b_loss = F.mse_loss(raw_img[b].cpu(),bm3d_rec,reduction='none').reshape(1,-1)
                b_loss = torch.mean(b_loss,1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- temporal loss --
            ave_temporal_loss = temporal_losses / temporal_count if temporal_count > 0 else 0
            temporal_losses,temporal_count = 0,0

            # -- sparse filter loss --
            ave_sf_loss = sf_losses / sf_count if sf_count > 0 else 0
            sf_losses,sf_count = 0,0

            # -- kl loss --
            ave_kl_loss = kl_losses / kl_count if kl_count > 0 else 0
            kl_losses,kl_count = 0,0

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx,len(train_loader),running_loss,psnr_ave,psnr_std,bm3d_nb_ave,bm3d_nb_std,
                          mis_psnr_ave,mis_psnr_std,ave_temporal_loss,ave_sf_loss,ave_kl_loss)
            print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [misaligned]: %2.2f +/- %2.2f [loss-t]: %.2e [loss-sf]: %.2e [loss-kl]: %.2e" % write_info)
            running_loss = 0

            # -- record information --
            if use_record:
                rec = rec_img
                raw = raw_img_zm
                frame_results = compute_ot_frame(aligned,rec,raw,reg=0.5)
                burst_results = compute_ot_burst(aligned,rec,raw,reg=0.5)
                psnr_record = {'psnr_ave':psnr_ave,'psnr_std':psnr_std}
                kpn_record = {'kpn_loss':kpn_loss}
                new_record = merge_records(frame_results,burst_results,psnr_record,kpn_record)
                record = record.append(new_record,ignore_index=True)

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0:
            write_input_output(cfg,model,stacked_burst,aligned,filters,directions)

        cfg.global_step += 1
    total_loss /= len(train_loader)
    return total_loss,record
Пример #10
0
def test_loop_offset(cfg,model,criterion,test_loader,epoch):
    """Evaluate the burst model on ``test_loader`` and return the mean PSNR.

    Each batch is unpacked as ``(burst_imgs, res_imgs, raw_img, directions)``.
    The first ``cfg.input_N`` frames are stacked and concatenated along
    ``dim=1`` and passed to ``model``; the second element of its output is
    taken as the reconstruction, passed through ``rescale_noisy_image`` and
    compared with ``raw_img`` via a per-image MSE converted to PSNR.

    Per-image PSNRs are collected batch by batch, so a batch smaller than
    ``cfg.batch_size`` (e.g. a final partial batch) is handled; the previous
    fixed-size ``(len(test_loader), cfg.batch_size)`` buffer raised a shape
    error in that case.

    ``criterion`` and ``epoch`` are accepted but not used.

    Returns:
        Mean PSNR over all evaluated images.
    """
    model.eval()
    model = model.to(cfg.device)
    total_psnr = 0
    total_loss = 0
    psnr_batches = []

    with torch.no_grad():
        for batch_idx, (burst_imgs, res_imgs, raw_img, directions) in enumerate(test_loader):
            BS = raw_img.shape[0]

            # -- selecting input frames --
            input_order = np.arange(cfg.N)

            # -- reshaping of data --
            raw_img = raw_img.cuda(non_blocking=True)
            burst_imgs = burst_imgs.cuda(non_blocking=True)
            input_frames = [burst_imgs[input_order[x]] for x in range(cfg.input_N)]
            stacked_burst = torch.stack(input_frames,dim=1)
            cat_burst = torch.cat(input_frames,dim=1)

            # -- denoising --
            rec_img = model(cat_burst,stacked_burst)[1].detach()

            # -- compare with stacked targets --
            rec_img = rescale_noisy_image(rec_img)

            # -- compute psnr --
            loss = F.mse_loss(raw_img,rec_img,reduction='none').reshape(BS,-1)
            loss = torch.mean(loss,1).detach().cpu().numpy()
            psnr = mse_to_psnr(loss)
            psnr_batches.append(np.asarray(psnr).reshape(-1))

            total_psnr += np.mean(psnr)
            total_loss += np.mean(loss)

            if (batch_idx % cfg.test_log_interval) == 0:
                print("[%d/%d] Running Test PSNR: %2.2f" % (batch_idx, len(test_loader), total_psnr / (batch_idx+1)))

    # np.concatenate rejects an empty list, so fall back to an empty array
    psnrs = np.concatenate(psnr_batches) if psnr_batches else np.zeros(0)
    psnr_ave = np.mean(psnrs)
    psnr_std = np.std(psnrs)
    ave_loss = total_loss / len(test_loader)
    print("[N: %d] Testing: [psnr: %2.2f +/- %2.2f] [ave loss %2.3e]"%(cfg.N,psnr_ave,psnr_std,ave_loss))
    return psnr_ave
Пример #11
0
def train_loop(cfg, model, scheduler, train_loader, epoch, record_losses,
               writer):

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Setup for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    model.align_info.model.train()
    model.denoiser_info.model.train()
    model.unet_info.model.train()
    model.denoiser_info.model = model.denoiser_info.model.to(cfg.device)
    model.align_info.model = model.align_info.model.to(cfg.device)
    model.unet_info.model = model.unet_info.model.to(cfg.device)

    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })
    noise_type = cfg.noise_params.ntype

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Record Keeping
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    align_mse_losses, align_mse_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0
    dynamics_acc, dynamics_count = 0, 0

    write_examples = False
    write_examples_iter = 200
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #   Load Pre-Simulated Random Numbers
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    if cfg.use_kindex_lmdb: kindex_ds = kIndexPermLMDB(cfg.batch_size, cfg.N)

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Dataset Augmentation
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    transforms = [tvF.vflip, tvF.hflip, tvF.rotate]
    aug = RandomChoice(transforms)

    def apply_transformations(burst, gt_img):
        N, B = burst.shape[:2]
        gt_img_rs = rearrange(gt_img, 'b c h w -> 1 b c h w')
        all_images = torch.cat([gt_img_rs, burst], dim=0)
        all_images = rearrange(all_images, 'n b c h w -> (n b) c h w')
        tv_utils.save_image(all_images,
                            'aug_original.png',
                            nrow=N + 1,
                            normalize=True)
        aug_images = aug(all_images)
        tv_utils.save_image(aug_images,
                            'aug_augmented.png',
                            nrow=N + 1,
                            normalize=True)
        aug_images = rearrange(aug_images, '(n b) c h w -> n b c h w', b=B)
        aug_gt_img = aug_images[0]
        aug_burst = aug_images[1:]
        return aug_burst, aug_gt_img

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Half Precision
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    # model.align_info.model.half()
    # model.denoiser_info.model.half()
    # model.unet_info.model.half()
    # models = [model.align_info.model,
    #           model.denoiser_info.model,
    #           model.unet_info.model]
    # for model_l in models:
    #     model_l.half()
    #     for layer in model_l.modules():
    #         if isinstance(layer, torch.nn.BatchNorm2d):
    #             layer.float()

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Loss Functions
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss(alpha=cfg.kpn_burst_alpha,
                                  gradient_L1=~cfg.supervised)
    # denoiseLossOT = BurstResidualLoss()
    entropyLoss = EntropyLoss()

    # -=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #    Add hooks for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-

    align_hook = AlignmentFilterHooks(cfg.N)
    align_hooks = []
    for kpn_module in model.align_info.model.children():
        for name, layer in kpn_module.named_children():
            if name == "filter_cls":
                align_hook_handle = layer.register_forward_hook(align_hook)
                align_hooks.append(align_hook_handle)

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Noise2Noise
    #
    # -=-=-=-=-=-=-=-=-=-=-

    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Final Configs
    #
    # -=-=-=-=-=-=-=-=-=-=-

    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    if use_timer:
        data_clock = Timer()
        clock = Timer()
    ds_size = len(train_loader)
    small_ds = ds_size < 500
    steps_per_epoch = ds_size if not small_ds else 500

    write_examples_iter = steps_per_epoch // 3
    all_filters = []

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Start Epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-
    dynamics_acc_i = -1.
    if cfg.use_seed:
        init = torch.initial_seed()
        torch.manual_seed(cfg.seed + 1 + epoch + init)
    train_iter = iter(train_loader)
    for batch_idx in range(steps_per_epoch):

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Setting up for Iteration
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- setup iteration timer --
        if use_timer:
            data_clock.tic()
            clock.tic()

        # -- grab data batch --
        if small_ds and batch_idx >= ds_size:
            if cfg.use_seed:
                init = torch.initial_seed()
                torch.manual_seed(cfg.seed + 1 + epoch + init)
            train_iter = iter(train_loader)  # reset if too big
        sample = next(train_iter)
        burst, raw_img, motion = sample['burst'], sample['clean'], sample[
            'flow']
        raw_img_iid = sample['iid']
        raw_img_iid = raw_img_iid.cuda(non_blocking=True)
        burst = burst.cuda(non_blocking=True)

        # -- handle possibly cached simulated bursts --
        if 'sim_burst' in sample:
            sim_burst = rearrange(sample['sim_burst'],
                                  'b n k c h w -> n b k c h w')
        else:
            sim_burst = None
        non_sim_method = cfg.n2n or cfg.supervised
        if sim_burst is None and not (non_sim_method or cfg.abps):
            if sim_burst is None:
                if cfg.use_kindex_lmdb:
                    kindex = kindex_ds[batch_idx].cuda(non_blocking=True)
                else:
                    kindex = None
                query = burst[[N // 2]]
                database = torch.cat([burst[:N // 2], burst[N // 2 + 1:]])
                sim_burst = compute_similar_bursts(
                    cfg,
                    query,
                    database,
                    cfg.sim_K,
                    noise_level / 255.,
                    patchsize=cfg.sim_patchsize,
                    shuffle_k=cfg.sim_shuffleK,
                    kindex=kindex,
                    only_middle=cfg.sim_only_middle,
                    search_method=cfg.sim_method,
                    db_level="frame")

        if (sim_burst is None) and cfg.abps:
            # scores,aligned = abp_search(cfg,burst)
            # scores,aligned = lpas_search(cfg,burst,motion)
            if cfg.lpas_method == "spoof":
                mtype = "global"
                acc = cfg.optical_flow_acc
                scores, aligned = lpas_spoof(burst, motion, cfg.nblocks, mtype,
                                             acc)
            else:
                ref_frame = (cfg.nframes + 1) // 2
                nblocks = cfg.nblocks
                method = cfg.lpas_method
                scores, aligned, dacc = lpas_search(burst, ref_frame, nblocks,
                                                    motion, method)
                dynamics_acc_i = dacc
            # scores,aligned = lpas_spoof(motion,accuracy=cfg.optical_flow_acc)
            # shuffled = shuffle_aligned_pixels_noncenter(aligned,cfg.nframes)
            nsims = cfg.nframes
            sim_aligned = create_sim_from_aligned(burst, aligned, nsims)
            burst_s = rearrange(burst, 't b c h w -> t b 1 c h w')
            sim_burst = torch.cat([burst_s, sim_aligned], dim=2)
            # print("sim_burst.shape",sim_burst.shape)

        # raw_img = raw_img.cuda(non_blocking=True)-0.5
        # # print(np.sqrt(cfg.noise_params['g']['stddev']))
        # print(motion)
        # tiled = tile_across_blocks(burst[[cfg.nframes//2]],cfg.nblocks)
        # rep_burst = repeat(burst,'t b c h w -> t b g c h w',g=tiled.shape[2])
        # for t in range(cfg.nframes):
        #     save_image(tiled[0] - rep_burst[t],f"tiled_sub_burst_{t}.png")
        # save_image(aligned,"aligned.png")
        # print(aligned.shape)
        # # save_image(aligned[0] - aligned[cfg.nframes//2],"aligned_0.png")
        # # save_image(aligned[2] - aligned[cfg.nframes//2],"aligned_2.png")
        # M = (1+cfg.dynamic.ppf)*cfg.nframes
        # fs = cfg.dynamic.frame_size - M
        # fs = cfg.frame_size
        # cropped = crop_center_patch([burst,aligned,raw_img],cfg.nframes,cfg.frame_size)
        # burst,aligned,raw_img = cropped[0],cropped[1],cropped[2]
        # print(aligned.shape)
        # for t in range(cfg.nframes+1):
        #     diff_t = aligned[t] - raw_img
        #     spacing = cfg.nframes+1
        #     diff_t = crop_center_patch([diff_t],spacing,cfg.frame_size)[0]
        #     print_tensor_stats(f"diff_aligned_{t}",diff_t)
        #     save_image(diff_t,f"diff_aligned_{t}.png")
        #     if t < cfg.nframes:
        #         dt = aligned[t+1]-aligned[t]
        #         dt = crop_center_patch([dt],spacing,cfg.frame_size)[0]
        #         save_image(dt,f"dt_aligned_{t+1}m{t}.png")
        #     save_image(aligned[t],f"aligned_{t}.png")
        #     diff_t = tvF.crop(aligned[t] - raw_img,cfg.nframes,cfg.nframes,fs,fs)
        #     print_tensor_stats(f"diff_aligned_{t}",diff_t)

        # save_image(burst,"burst.png")
        # save_image(burst[0] - burst[cfg.nframes//2],"burst_0.png")
        # save_image(burst[2] - burst[cfg.nframes//2],"burst_2.png")
        # exit()

        # print(sample['burst'].shape,sample['res'].shape)
        # b_clean = sample['burst'] - sample['res']
        # scores,ave,t_aligned = test_abp_global_search(cfg,b_clean,noisy_img=burst)

        # burstBN = rearrange(burst,'n b c h w -> (b n) c h w')
        # tv_utils.save_image(burstBN,"abps_burst.png",normalize=True)
        # alignedBN = rearrange(aligned,'n b c h w -> (b n) c h w')
        # tv_utils.save_image(alignedBN,"abps_aligned.png",normalize=True)
        # rep_burst = burst[[N//2]].repeat(N,1,1,1,1)
        # deltaBN = rearrange(aligned - rep_burst,'n b c h w -> (b n) c h w')
        # tv_utils.save_image(deltaBN,"abps_delta.png",normalize=True)
        # b_clean_rep = b_clean[[N//2]].repeat(N,1,1,1,1)
        # tdeltaBN = rearrange(t_aligned - b_clean_rep.cpu(),'n b c h w -> (b n) c h w')
        # tv_utils.save_image(tdeltaBN,"abps_tdelta.png",normalize=True)

        if non_sim_method:
            sim_burst = burst.unsqueeze(2).repeat(1, 1, 2, 1, 1, 1)
        else:
            sim_burst = sim_burst.cuda(non_blocking=True)
        if use_timer: data_clock.toc()

        # -- to cuda --
        burst = burst.cuda(non_blocking=True)
        raw_zm_img = szm(raw_img.cuda(non_blocking=True))
        # anscombe.test(cfg,burst_og)
        # save_image(burst,f"burst_{batch_idx}_{cfg.n2n}.png")

        # -- crop images --
        if True:  #cfg.abps or cfg.abps_inputs:
            images = [burst, sim_burst, raw_img, raw_img_iid]
            spacing = burst.shape[0]  # we use frames as spacing
            cropped = crop_center_patch(images, spacing, cfg.frame_size)
            burst, sim_burst = cropped[0], cropped[1]
            raw_img, raw_img_iid = cropped[2], cropped[3]
            if cfg.abps or cfg.abps_inputs:
                aligned = crop_center_patch([aligned], spacing,
                                            cfg.frame_size)[0]
            # print_tensor_stats("d-eq?",burst[-1] - aligned[-1])
            burst = burst[:cfg.nframes]  # last frame is target

        # -- getting shapes of data --
        N, B, C, H, W = burst.shape
        burst_og = burst.clone()

        # -- shuffle over Simulated Samples --
        k_ins, k_outs = create_k_grid(sim_burst, shuffle=True)
        k_ins, k_outs = [k_ins[0]], [k_outs[0]]
        # k_ins,k_outs = create_k_grid_v3(sim_burst)

        for k_in, k_out in zip(k_ins, k_outs):
            if k_in == k_out: continue

            # -- zero gradients; ready 2 go --
            model.align_info.model.zero_grad()
            model.align_info.optim.zero_grad()
            model.denoiser_info.model.zero_grad()
            model.denoiser_info.optim.zero_grad()
            model.unet_info.model.zero_grad()
            model.unet_info.optim.zero_grad()

            # -- compute input/output data --
            if cfg.sim_only_middle and (not cfg.abps):
                # sim_burst.shape == T,B,K,C,H,W
                midi = 0 if sim_burst.shape[0] == 1 else N // 2
                left_burst, right_burst = burst[:N // 2], burst[N // 2 + 1:]
                cat_burst = [
                    left_burst, sim_burst[[midi], :, k_in], right_burst
                ]
                burst = torch.cat(cat_burst, dim=0)
                mid_img = sim_burst[midi, :, k_out]
            elif cfg.abps and (not cfg.abps_inputs):
                # -- v1 --
                mid_img = aligned[-1]

                # -- v2 --
                # left_aligned,right_aligned = aligned[:N//2],aligned[N//2+1:]
                # nc_aligned = torch.cat([left_aligned,right_aligned],dim=0)
                # shuf = shuffle_aligned_pixels(nc_aligned,cfg.nframes)
                # mid_img = shuf[1]

                # ---- v3 ----
                # shuf = shuffle_aligned_pixels(aligned)
                # shuf = aligned[[N//2,0]]
                # midi = 0 if sim_burst.shape[0] == 1 else N//2
                # left_burst,right_burst = burst[:N//2],burst[N//2+1:]
                # burst = torch.cat([left_burst,shuf[[0]],right_burst],dim=0)
                # nc_burst = torch.cat([left_burst,right_burst],dim=0)
                # shuf = shuffle_aligned_pixels(aligned)

                # ---- v4 ----
                # nc_shuf = shuffle_aligned_pixels(nc_aligned)
                # mid_img = nc_shuf[0]
                # pick = npr.randint(0,2,size=(1,))[0]
                # mid_img = nc_aligned[pick]
                # mid_img = shuf[1]

                # save_image(shuf,"shuf.png")
                # print(shuf.shape)

                # diff = raw_img.cuda(non_blocking=True) - aligned[0]
                # mean = torch.mean(diff).item()
                # std = torch.std(diff).item()
                # print(mean,std)

                # -- v1 --
                # burst = burst
                # notMid = sample_not_mid(N)
                # mid_img = aligned[notMid]

            elif cfg.abps_inputs:
                burst = aligned.clone()
                burst_og = aligned.clone()
                mid_img = shuffle_aligned_pixels(burst, cfg.nframes)[0]

            else:
                burst = sim_burst[:, :, k_in]
                mid_img = sim_burst[N // 2, :, k_out]
            # mid_img =  sim_burst[N//2,:]
            # print(burst.shape,mid_img.shape)
            # print(F.mse_loss(burst,mid_img).item())
            if cfg.supervised:
                gt_img = get_nmlz_tgt_img(cfg, raw_img).cuda(non_blocking=True)
            elif cfg.n2n:
                gt_img = raw_img_iid  #noise_xform(raw_img).cuda(non_blocking=True)
            else:
                gt_img = mid_img

            # another = noise_xform(raw_img).cuda(non_blocking=True)
            # print_tensor_stats("a-iid?",raw_img_iid.cuda() - raw_img.cuda())
            # print_tensor_stats("b-iid?",mid_img.cuda() - raw_img.cuda())
            # print_tensor_stats("c-iid?",mid_img.cuda() - another)
            # print_tensor_stats("d-iid?",raw_img_iid.cuda() - another)
            # print_tensor_stats("e-iid?",mid_img.cuda() - raw_img_iid.cuda())

            # for bt in range(cfg.nframes):
            #     tiled = tile_across_blocks(burst[[bt]],cfg.nblocks)
            #     rep_burst = repeat(burst,'t b c h w -> t b g c h w',g=tiled.shape[2])
            #     for t in range(cfg.nframes):
            #         save_image(tiled[0] - rep_burst[t],f"tiled_{bt}_sub_burst_{t}.png")
            #         print_tensor_stats(f"delta_{bt}_{t}",tiled[0,:,4] - burst[t])

            # raw_img = raw_img.cuda(non_blocking=True) - 0.5
            # print_tensor_stats("gt_img - raw",gt_img - raw_img)
            # # save_image(gt_img,"gt.png")
            # # save_image(raw,"raw.png")
            # save_image(gt_img - raw_img,"gt_sub_raw.png")
            # print_tensor_stats("burst[N//2] - raw",burst[N//2] - raw_img)
            # save_image(burst[N//2] - raw_img,"burst_sub_raw.png")
            # print_tensor_stats("burst[N//2] - gt_img",burst[N//2] - gt_img)
            # save_image(burst[N//2] - gt_img,"burst_sub_gt.png")
            # print_tensor_stats("aligned[N//2] - raw",aligned[N//2] - raw_img)
            # save_image(aligned[N//2] - raw_img,"aligned_sub_raw.png")
            # print_tensor_stats("aligned[N//2] - burst[N//2]",
            # aligned[N//2] - burst[N//2])
            # save_image(aligned[N//2] - burst[N//2],"aligned_sub_burst.png")
            # gt_img = torch.normal(raw_zm_img,noise_level/255.)

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #        Dataset Augmentation
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # burst,gt_img = apply_transformations(burst,gt_img)

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #      Formatting Images for FP
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
            cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #           Forward Pass
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            outputs = model(burst)
            m_aligned, m_aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #    Decrease Entropy within a Kernel
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            filters_entropy = 0
            filters_entropy_coeff = 0.  # 1000.
            all_filters = []
            L = len(align_hook.filters)
            iter_filters = align_hook.filters if L > 0 else [aligned_filters]
            for filters in iter_filters:
                f_shape = 'b n k2 c h w -> (b n c h w) k2'
                filters_shaped = rearrange(filters, f_shape)
                filters_entropy += one  #entropyLoss(filters_shaped)
                all_filters.append(filters)
            if L > 0: filters_entropy /= L
            all_filters = torch.stack(all_filters, dim=1)
            align_hook.clear()

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #   Reconstruction Losses (MSE)
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            losses = [F.mse_loss(denoised_ave, gt_img)]
            # losses = denoiseLossMSE(denoised,denoised_ave,gt_img,cfg.global_step)
            # losses = [ one, one ]
            # ave_loss,burst_loss = [loss.item() for loss in losses]
            rec_mse = np.sum(losses)
            # rec_mse = F.mse_loss(denoised_ave,gt_img)
            rec_mse_coeff = 1.

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #    Reconstruction Losses (Distribution)
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            gt_img_rep = gt_img.unsqueeze(1).repeat(1, denoised.shape[1], 1, 1,
                                                    1)
            residuals = denoised - gt_img_rep
            rec_ot = torch.FloatTensor([0.]).to(cfg.device)
            # rec_ot = kl_gaussian_bp(residuals,noise_level,flip=True)
            # rec_ot = kl_gaussian_bp_patches(residuals,noise_level,flip=True,patchsize=16)
            if torch.any(torch.isnan(rec_ot)):
                rec_ot = torch.FloatTensor([0.]).to(cfg.device)
            if torch.any(torch.isinf(rec_ot)):
                rec_ot = torch.FloatTensor([0.]).to(cfg.device)
            rec_ot_coeff = 0.

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #              Final Losses
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            rec_loss = rec_mse_coeff * rec_mse + rec_ot_coeff * rec_ot
            final_loss = rec_loss

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #              Record Keeping
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # -- reconstruction MSE --
            rec_mse_losses += rec_mse.item()
            rec_mse_count += 1

            # -- reconstruction Dist. --
            rec_ot_losses += rec_ot.item()
            rec_ot_count += 1

            # -- dynamic acc -
            dynamics_acc += dynamics_acc_i
            dynamics_count += 1

            # -- total loss --
            running_loss += final_loss.item()
            total_loss += final_loss.item()

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #        Gradients & Backpropagation
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # -- compute the gradients! --
            if cfg.use_seed: torch.set_deterministic(False)
            final_loss.backward()
            if cfg.use_seed: torch.set_deterministic(True)

            # -- apply optimizer updates and advance the scheduler --
            model.align_info.optim.step()
            model.denoiser_info.optim.step()
            model.unet_info.optim.step()
            scheduler.step()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #            Printing to Stdout
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- recompute model output for original images --
            outputs = model(burst_og)
            m_aligned, m_aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]

            # -- compute mse for fun --
            B = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)
            raw_img = get_nmlz_tgt_img(cfg, raw_img)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, m_aligned_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(burst_og, dim=0)
            if noise_type == "qis": mis_ave = quantize_img(cfg, mis_ave)
            mse_loss = F.mse_loss(raw_img, mis_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # tv_utils.save_image(raw_img,"raw.png",nrow=1,normalize=True,range=(-0.5,1.25))
            # tv_utils.save_image(mis_ave,"mis.png",nrow=1,normalize=True,range=(-0.5,1.25))

            # -- psnr for [bm3d] --
            mid_img_og = burst[N // 2]
            bm3d_nb_psnrs = []
            M = 4 if B > 4 else B
            for b in range(M):
                bm3d_rec = bm3d.bm3d(mid_img_og[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                # maybe an issue here
                b_loss = F.mse_loss(raw_img[b].cpu(),
                                    bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for input averaged frames --
            # burst_ave = torch.mean(burst_og,dim=0)
            # mse_loss = F.mse_loss(raw_img,burst_ave,reduction='none').reshape(B,-1)
            # mse_loss = torch.mean(mse_loss,1).detach().cpu().numpy()
            # psnr_input_ave = np.mean(mse_to_psnr(mse_loss))
            # psnr_input_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for aligned + denoised --
            R = denoised.shape[1]
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, R, 1, 1, 1)
            # if noise_type == "qis": denoised = quantize_img(cfg,denoised)
            # save_image(denoised_ave,"denoised_ave.png")
            # save_image(denoised,"denoised.png")
            mse_loss = F.mse_loss(raw_img_repN, denoised,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img, denoised_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- reconstruction MSE --
            rec_mse_ave = rec_mse_losses / rec_mse_count
            rec_mse_losses, rec_mse_count = 0, 0

            # -- reconstruction Dist. --
            rec_ot_ave = rec_ot_losses / rec_ot_count
            rec_ot_losses, rec_ot_count = 0, 0

            # -- ave dynamic acc --
            ave_dyn_acc = dynamics_acc / dynamics_count * 100.
            dynamics_acc, dynamics_count = 0, 0

            # -- write record --
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                record_losses = record_losses.append(info, ignore_index=True)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch,
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, ave_dyn_acc)  #rec_ot_ave)

            #print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e" % write_info)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [dyn]: %.2e"
                % write_info,
                flush=True)
            # -- write to summary writer --
            if writer:
                writer.add_scalar('train/running-loss', running_loss,
                                  cfg.global_step)
                writer.add_scalars('train/model-psnr', {
                    'ave': psnr,
                    'std': psnr_std
                }, cfg.global_step)
                writer.add_scalars('train/dn-frame-psnr', {
                    'ave': psnr_denoised_ave,
                    'std': psnr_denoised_std
                }, cfg.global_step)

            # -- reset loss --
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               all_filters, motion)

        if use_timer: clock.toc()

        if use_timer:
            print("data_clock", data_clock.average_time)
            print("clock", clock.average_time)
        cfg.global_step += 1

    # -- remove hooks --
    for hook in align_hooks:
        hook.remove()

    total_loss /= len(train_loader)
    return total_loss, record_losses