Exemple #1
0
def _psnr_stats(target, estimate):
    """Return the (mean, std) over the batch of per-sample PSNR values.

    The squared error is averaged over all non-batch elements of each sample
    and converted with ``mse_to_psnr``.
    """
    batch_size = target.shape[0]
    sq_err = F.mse_loss(target, estimate, reduction='none')
    mse = torch.mean(sq_err.reshape(batch_size, -1), 1).detach().cpu().numpy()
    psnrs = mse_to_psnr(mse)
    return np.mean(psnrs), np.std(psnrs)


def _bm3d_psnr_stats(noisy, clean, noise_level):
    """Return the (mean, std) PSNR of BM3D applied to each sample of ``noisy``."""
    psnrs = []
    for b in range(clean.shape[0]):
        # -- bm3d runs on cpu; the 0.5 offset undoes the zero-mean shift --
        bm3d_rec = bm3d.bm3d(noisy[b].cpu().transpose(0, 2) + 0.5,
                             sigma_psd=noise_level / 255,
                             stage_arg=bm3d.BM3DStages.ALL_STAGES)
        bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
        b_loss = F.mse_loss(clean[b].cpu(), bm3d_rec,
                            reduction='none').reshape(1, -1)
        b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
        psnrs.append(np.mean(mse_to_psnr(b_loss)))
    return np.mean(psnrs), np.std(psnrs)


def train_loop(cfg, model, optimizer, criterion, train_loader, epoch,
               record_losses):
    """Run one training epoch of the joint alignment + denoising model.

    Each step draws a ``(burst, res_imgs, raw_img, directions)`` batch, runs
    ``model(burst)``, and combines four terms into one loss: filter entropy,
    alignment MSE (currently weighted by 0), reconstruction MSE from
    ``criterion``, and the ``kl_gaussian_bp`` residual loss. It then steps
    both ``model.denoiser_info.optim`` and ``optimizer``. Every
    ``cfg.log_interval`` steps a line of PSNR statistics is printed.

    Args:
        cfg: run configuration. Reads ``device``, ``noise_params``,
            ``supervised``, ``log_interval`` and ``epochs``; reads and
            increments ``global_step`` once per step.
        model: network returning
            ``(aligned, aligned_ave, denoised, denoised_ave, filters)``.
        optimizer: optimizer stepped after ``model.denoiser_info.optim``.
        criterion: callable ``(denoised, denoised_ave, gt_img, global_step)``
            returning two scalar loss tensors.
        train_loader: sized iterable of batches; ``burst`` is indexed
            ``n b c h w`` (frames first).
        epoch: epoch index, used only in the log line.
        record_losses: ``pandas.DataFrame`` of past records, or ``None`` to
            start an empty one.

    Returns:
        ``(total_loss, record_losses)`` where ``total_loss`` is the mean final
        loss per step (``0`` for an empty loader).
    """

    # -=-=-=-=-=-=-=-=-=-=-
    #    Setup for epoch
    # -=-=-=-=-=-=-=-=-=-=-

    model.train()
    model = model.to(cfg.device)
    szm = ScaleZeroMean()
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })

    # -- record keeping --
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, running_count, total_loss = 0, 0, 0

    write_examples = True
    write_examples_iter = 800
    noise_level = cfg.noise_params['g']['stddev']

    # -- loss functions --
    alignmentLossMSE = BurstRecLoss()
    entropyLoss = EntropyLoss()

    # -- final configs --
    use_timer = False
    if use_timer: clock = Timer()
    train_iter = iter(train_loader)
    steps_per_epoch = len(train_loader)

    # -=-=-=-=-=-=-=-=-=-=-
    #     Start Epoch
    # -=-=-=-=-=-=-=-=-=-=-

    for batch_idx in range(steps_per_epoch):

        # -- setup iteration timer --
        if use_timer: clock.tic()

        # -- zero gradients; ready 2 go --
        optimizer.zero_grad()
        model.zero_grad()
        model.denoiser_info.optim.zero_grad()

        # -- grab data batch; keep it on the same device as the model --
        burst, res_imgs, raw_img, directions = next(train_iter)
        N = burst.shape[0]
        burst = burst.to(cfg.device, non_blocking=True)

        # -- batch-first view, used for logging and example dumps --
        stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')

        # -- extract target image --
        mid_img = burst[N // 2]
        raw_zm_img = szm(raw_img.to(cfg.device, non_blocking=True))
        gt_img = raw_zm_img if cfg.supervised else mid_img

        # -- forward pass --
        aligned, aligned_ave, denoised, denoised_ave, filters = model(burst)

        # -- entropy loss for filters --
        filters_shaped = rearrange(filters, 'b n k2 1 1 1 -> (b n) k2', n=N)
        filters_entropy = entropyLoss(filters_shaped)
        filters_entropy_coeff = 10.

        # -- alignment loss (MSE); builtin sum keeps the autograd graph and
        #    does not rely on numpy coercing a list of tensors --
        losses = alignmentLossMSE(aligned, aligned_ave, gt_img,
                                  cfg.global_step)
        align_mse = sum(losses)
        align_mse_coeff = 0

        # -- reconstruction loss (MSE) --
        losses = criterion(denoised, denoised_ave, gt_img, cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        rec_mse = sum(losses)
        rec_mse_coeff = 0.997**cfg.global_step

        # -- reconstruction loss (distribution of residuals) --
        residuals = denoised - mid_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        residuals = rearrange(residuals, 'b n c h w -> b n (h w) c')
        rec_ot_pair_loss_v1 = kl_gaussian_bp(residuals, noise_level)
        rec_ot_pair_loss_v2 = torch.FloatTensor([0.]).to(cfg.device)
        rec_ot_pair = (rec_ot_pair_loss_v1 + rec_ot_pair_loss_v2) / 2.
        rec_ot_pair_coeff = 100

        # -- final losses --
        align_loss = align_mse_coeff * align_mse
        rec_loss = rec_ot_pair_coeff * rec_ot_pair + rec_mse_coeff * rec_mse
        entropy_loss = filters_entropy_coeff * filters_entropy
        final_loss = align_loss + rec_loss + entropy_loss

        # -- record keeping --
        rec_mse_losses += rec_mse.item()
        rec_mse_count += 1
        rec_ot_losses += rec_ot_pair.item()
        rec_ot_count += 1
        step_loss = final_loss.item()
        running_loss += step_loss
        running_count += 1
        total_loss += step_loss

        # -- compute the gradients, then update the parameters --
        final_loss.backward()
        model.denoiser_info.optim.step()
        optimizer.step()

        # -- printing to stdout --
        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            raw_img = raw_img.to(cfg.device, non_blocking=True)

            # -- psnr for [average of aligned frames] --
            psnr_aligned_ave, psnr_aligned_std = _psnr_stats(
                raw_img, aligned_ave + 0.5)

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(stacked_burst, dim=1)
            psnr_misaligned_ave, psnr_misaligned_std = _psnr_stats(
                raw_img, mis_ave + 0.5)

            # -- psnr for [bm3d] --
            bm3d_nb_ave, bm3d_nb_std = _bm3d_psnr_stats(
                mid_img, raw_img, noise_level)

            # -- psnr for aligned + denoised --
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            psnr_denoised_ave, psnr_denoised_std = _psnr_stats(
                raw_img_repN, denoised + 0.5)

            # -- psnr for [model output image] --
            psnr, psnr_std = _psnr_stats(raw_img, denoised_ave + 0.5)

            # -- average over the steps actually accumulated; the first
            #    window holds log_interval + 1 steps --
            running_loss /= running_count
            rec_mse_ave = rec_mse_losses / rec_mse_count
            rec_mse_losses, rec_mse_count = 0, 0
            rec_ot_ave = rec_ot_losses / rec_ot_count
            rec_ot_losses, rec_ot_count = 0, 0

            # -- write record --
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                # DataFrame.append was removed in pandas 2.0
                record_losses = pd.concat(
                    [record_losses, pd.DataFrame([info])], ignore_index=True)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch,
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, rec_ot_ave)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e"
                % write_info)
            running_loss, running_count = 0, 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               filters, directions)

        if use_timer: clock.toc()
        if use_timer: print(clock)
        cfg.global_step += 1

    # -- an empty loader yields a mean loss of 0 rather than dividing by 0 --
    if steps_per_epoch > 0:
        total_loss /= steps_per_epoch
    return total_loss, record_losses
Exemple #2
0
def train_loop(cfg,model,optimizer,criterion,train_loader,epoch):
    """Train ``model`` for one epoch with an MSE loss on burst inputs.

    For every ``(burst_imgs, res_imgs, raw_img)`` batch the frames selected by
    ``input_order`` are concatenated (``cfg.color_cat``) or stacked along
    dim 1, passed through ``model``, and compared with the target by MSE.
    The target is the middle frame when ``cfg.blind`` is set, otherwise
    ``ScaleZeroMean`` applied to ``raw_img``.

    Args:
        cfg: run configuration. Reads ``device``, ``N``, ``input_N``,
            ``input_with_middle_frame``, ``input_noise``,
            ``input_noise_level``, ``color_cat``, ``blind``, ``log_interval``
            and ``epochs``.
        model: network mapping the stacked burst to a reconstructed image.
        optimizer: optimizer stepped once per batch.
        criterion: unused; the loss is always ``F.mse_loss``.
        train_loader: sized iterable of batches; ``burst_imgs`` is indexed by
            frame along dim 0.
        epoch: epoch index, used only in the log line.

    Returns:
        Mean loss per batch (``0`` for an empty loader).
    """
    model.train()
    model = model.to(cfg.device)
    total_loss = 0
    running_loss, running_count = 0, 0
    szm = ScaleZeroMean()

    for batch_idx, (burst_imgs, res_imgs, raw_img) in enumerate(train_loader):

        optimizer.zero_grad()
        model.zero_grad()

        # -- choose input frames; optionally hold out the middle frame --
        input_order = np.arange(cfg.N)
        middle = len(input_order) // 2
        middle_img_idx = input_order[middle]
        if not cfg.input_with_middle_frame:
            input_order = np.r_[input_order[:middle],input_order[middle+1:]]

        # -- add input noise (middle frame only); keep data on model device --
        burst_imgs = burst_imgs.to(cfg.device, non_blocking=True)
        burst_imgs_noisy = burst_imgs.clone()
        if cfg.input_noise:
            noise = np.random.rand() * cfg.input_noise_level
            burst_imgs_noisy[middle_img_idx] = torch.normal(burst_imgs_noisy[middle_img_idx],noise)

        # -- build network input --
        frames = [burst_imgs_noisy[input_order[x]] for x in range(cfg.input_N)]
        if cfg.color_cat:
            stacked_burst = torch.cat(frames,dim=1)
        else:
            stacked_burst = torch.stack(frames,dim=1)

        # -- extract target image --
        if cfg.blind:
            t_img = burst_imgs[middle_img_idx]
        else:
            t_img = szm(raw_img.to(cfg.device, non_blocking=True))

        # -- denoising --
        rec_img = model(stacked_burst)

        # -- compute loss; (input, target) order, value is symmetric --
        loss = F.mse_loss(rec_img,t_img)

        # -- update info --
        step_loss = loss.item()
        running_loss += step_loss
        running_count += 1
        total_loss += step_loss

        # -- BP and optimize --
        loss.backward()
        optimizer.step()

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- compute mse for fun --
            BS = raw_img.shape[0]
            raw_img = raw_img.to(cfg.device, non_blocking=True)
            mse_loss = F.mse_loss(raw_img,rec_img+0.5,reduction='none').reshape(BS,-1)
            mse_loss = torch.mean(mse_loss,1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))

            # -- average over the steps actually accumulated, then reset so
            #    the next window does not inherit this one --
            running_loss /= running_count
            print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.3e"%(epoch, cfg.epochs, batch_idx,
                                                         len(train_loader),
                                                         running_loss,psnr))
            running_loss, running_count = 0, 0

    # -- an empty loader yields a mean loss of 0 rather than dividing by 0 --
    num_batches = len(train_loader)
    if num_batches > 0:
        total_loss /= num_batches
    return total_loss
Exemple #3
0
def _psnr_stats(target, estimate):
    """Return the (mean, std) over the batch of per-sample PSNR values.

    The squared error is averaged over all non-batch elements of each sample
    and converted with ``mse_to_psnr``.
    """
    batch_size = target.shape[0]
    sq_err = F.mse_loss(target, estimate, reduction='none')
    mse = torch.mean(sq_err.reshape(batch_size, -1), 1).detach().cpu().numpy()
    psnrs = mse_to_psnr(mse)
    return np.mean(psnrs), np.std(psnrs)


def _bm3d_psnr_stats(noisy, clean, noise_level):
    """Return the (mean, std) PSNR of BM3D applied to each sample of ``noisy``."""
    psnrs = []
    for b in range(clean.shape[0]):
        # -- bm3d runs on cpu; the 0.5 offset undoes the zero-mean shift --
        bm3d_rec = bm3d.bm3d(noisy[b].cpu().transpose(0, 2) + 0.5,
                             sigma_psd=noise_level / 255,
                             stage_arg=bm3d.BM3DStages.ALL_STAGES)
        bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
        b_loss = F.mse_loss(clean[b].cpu(), bm3d_rec,
                            reduction='none').reshape(1, -1)
        b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
        psnrs.append(np.mean(mse_to_psnr(b_loss)))
    return np.mean(psnrs), np.std(psnrs)


def train_loop(cfg, model, train_loader, epoch, record_losses):
    """Run one training epoch of the two-stage (align, denoise) model.

    Each step draws a ``(burst, res_imgs, raw_img, directions)`` batch, runs
    ``model(burst)``, and combines these terms into one loss: alignment MSE,
    alignment ``kl_gaussian_bp``, reconstruction MSE, reconstruction
    ``kl_gaussian_bp``, filter entropy and a denoiser filter-norm term. It
    then steps ``model.align_info.optim`` and ``model.denoiser_info.optim``.
    Forward hooks on every ``filter_cls`` child of the alignment model collect
    the alignment filters, and are always removed before returning.

    Args:
        cfg: run configuration. Reads ``device``, ``N``, ``noise_params``,
            ``supervised``, ``dynamic.frame_size``, ``log_interval`` and
            ``epochs``; reads and increments ``global_step`` once per step.
        model: wrapper exposing ``align_info`` and ``denoiser_info`` (each
            with ``model`` and ``optim``). Calling it returns ``(aligned,
            aligned_ave, denoised, denoised_ave, aligned_filters,
            denoised_filters)``.
        train_loader: sized iterable of batches; ``burst`` is indexed
            ``n b c h w`` (frames first).
        epoch: epoch index, used only in the log line.
        record_losses: ``pandas.DataFrame`` of past records, or ``None`` to
            start an empty one.

    Returns:
        ``(total_loss, record_losses)`` where ``total_loss`` is the mean final
        loss per step (``0`` for an empty loader).
    """

    # -=-=-=-=-=-=-=-=-=-=-
    #    Setup for epoch
    # -=-=-=-=-=-=-=-=-=-=-

    model.align_info.model.train()
    model.denoiser_info.model.train()
    model.denoiser_info.model = model.denoiser_info.model.to(cfg.device)
    model.align_info.model = model.align_info.model.to(cfg.device)

    szm = ScaleZeroMean()
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })

    # -- record keeping --
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, running_count, total_loss = 0, 0, 0

    write_examples = True
    noise_level = cfg.noise_params['g']['stddev']

    # -- loss functions --
    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss()
    entropyLoss = EntropyLoss()

    # -- final configs --
    use_timer = False
    if use_timer: clock = Timer()
    train_iter = iter(train_loader)
    steps_per_epoch = len(train_loader)
    # at least 1 so the modulo below never divides by zero on tiny loaders
    write_examples_iter = max(steps_per_epoch // 3, 1)

    # -- add hooks for epoch --
    align_hook = AlignmentFilterHooks(cfg.N)
    align_hooks = []
    for kpn_module in model.align_info.model.children():
        for name, layer in kpn_module.named_children():
            if name == "filter_cls":
                align_hooks.append(layer.register_forward_hook(align_hook))

    # -=-=-=-=-=-=-=-=-=-=-
    #     Start Epoch
    # -=-=-=-=-=-=-=-=-=-=-

    try:
        for batch_idx in range(steps_per_epoch):

            # -- setup iteration timer --
            if use_timer: clock.tic()

            # -- zero gradients; ready 2 go --
            model.align_info.model.zero_grad()
            model.align_info.optim.zero_grad()
            model.denoiser_info.model.zero_grad()
            model.denoiser_info.optim.zero_grad()

            # -- grab data batch; keep it on the same device as the models --
            burst, res_imgs, raw_img, directions = next(train_iter)
            N = burst.shape[0]
            burst = burst.to(cfg.device, non_blocking=True)

            # -- batch-first view, used for logging and example dumps --
            stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')

            # -- extract target image --
            mid_img = burst[N // 2]
            raw_zm_img = szm(raw_img.to(cfg.device, non_blocking=True))
            gt_img = raw_zm_img if cfg.supervised else mid_img

            # -- forward pass --
            outputs = model(burst)
            aligned, aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]

            # -- require approx equal filter norms; the filters are detached,
            #    so this term only changes the reported loss value --
            denoised_filters = rearrange(denoised_filters.detach(),
                                         'b n k2 c h w -> n (b k2 c h w)')
            norms = denoised_filters.norm(dim=1)
            norm_loss_denoiser = torch.mean((norms - norms[N // 2])**2)
            norm_loss_coeff = 100.

            # -- decrease entropy within a kernel; fall back to the model's
            #    returned filters when the hooks captured nothing --
            filters_entropy = 0
            filters_entropy_coeff = 1000.
            all_filters = []
            L = len(align_hook.filters)
            iter_filters = align_hook.filters if L > 0 else [aligned_filters]
            for filters in iter_filters:
                filters_shaped = rearrange(filters,
                                           'b n k2 c h w -> (b n c h w) k2',
                                           n=N)
                filters_entropy += entropyLoss(filters_shaped)
                all_filters.append(filters)
            if L > 0: filters_entropy /= L
            all_filters = torch.stack(all_filters, dim=1)
            align_hook.clear()

            # -- increase entropy across each kernel (across each batch);
            #    currently disabled through a zero coefficient --
            filters_shaped = rearrange(all_filters,
                                       'b l n k2 c h w -> (n l) (b c h w) k2')
            filters_shaped = torch.mean(filters_shaped, dim=1)
            filters_dist_entropy = -1 * entropyLoss(filters_shaped)
            filters_dist_coeff = 0

            # -- alignment loss (MSE); builtin sum keeps the autograd graph
            #    and does not rely on numpy coercing a list of tensors --
            losses = alignmentLossMSE(aligned, aligned_ave, gt_img,
                                      cfg.global_step)
            align_mse = sum(losses)
            align_mse_coeff = 0.95**cfg.global_step

            # -- alignment loss (distribution), on the centre crop --
            fs = cfg.dynamic.frame_size
            residuals = aligned - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            centered_residuals = tvF.center_crop(residuals,
                                                 (fs // 2, fs // 2))
            align_ot = kl_gaussian_bp(centered_residuals, noise_level,
                                      flip=True)
            align_ot_coeff = 100.

            # -- reconstruction loss (MSE) --
            losses = denoiseLossMSE(denoised, denoised_ave, gt_img,
                                    cfg.global_step)
            ave_loss, burst_loss = [loss.item() for loss in losses]
            rec_mse = sum(losses)
            rec_mse_coeff = 0.95**cfg.global_step

            # -- reconstruction loss (distribution) --
            residuals = denoised - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            rec_ot_loss_v1 = kl_gaussian_bp(residuals, noise_level, flip=True)
            rec_ot_pair_loss_v2 = torch.FloatTensor([0.]).to(cfg.device)
            rec_ot = (rec_ot_loss_v1 + rec_ot_pair_loss_v2)
            rec_ot_coeff = 100.

            # -- final losses --
            rec_loss = rec_ot_coeff * rec_ot + rec_mse_coeff * rec_mse
            norm_loss = norm_loss_coeff * norm_loss_denoiser
            align_loss = align_mse_coeff * align_mse + align_ot_coeff * align_ot
            entropy_loss = filters_entropy_coeff * filters_entropy + filters_dist_coeff * filters_dist_entropy
            final_loss = align_loss + rec_loss + entropy_loss + norm_loss

            # -- record keeping --
            rec_mse_losses += rec_mse.item()
            rec_mse_count += 1
            rec_ot_losses += rec_ot.item()
            rec_ot_count += 1
            step_loss = final_loss.item()
            running_loss += step_loss
            running_count += 1
            total_loss += step_loss

            # -- compute the gradients, then update the parameters --
            final_loss.backward()
            model.align_info.optim.step()
            model.denoiser_info.optim.step()

            # -- printing to stdout --
            if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

                raw_img = raw_img.to(cfg.device, non_blocking=True)

                # -- psnr for [average of aligned frames] --
                psnr_aligned_ave, psnr_aligned_std = _psnr_stats(
                    raw_img, aligned_ave + 0.5)

                # -- psnr for [average of input, misaligned frames] --
                mis_ave = torch.mean(stacked_burst, dim=1)
                psnr_misaligned_ave, psnr_misaligned_std = _psnr_stats(
                    raw_img, mis_ave + 0.5)

                # -- psnr for [bm3d] --
                bm3d_nb_ave, bm3d_nb_std = _bm3d_psnr_stats(
                    mid_img, raw_img, noise_level)

                # -- psnr for aligned + denoised --
                raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
                psnr_denoised_ave, psnr_denoised_std = _psnr_stats(
                    raw_img_repN, denoised + 0.5)

                # -- psnr for [model output image] --
                psnr, psnr_std = _psnr_stats(raw_img, denoised_ave + 0.5)

                # -- average over the steps actually accumulated; the first
                #    window holds log_interval + 1 steps --
                running_loss /= running_count
                rec_mse_ave = rec_mse_losses / rec_mse_count
                rec_mse_losses, rec_mse_count = 0, 0
                rec_ot_ave = rec_ot_losses / rec_ot_count
                rec_ot_losses, rec_ot_count = 0, 0

                # -- write record --
                if use_record:
                    info = {
                        'burst': burst_loss,
                        'ave': ave_loss,
                        'ot': rec_ot_ave,
                        'psnr': psnr,
                        'psnr_std': psnr_std
                    }
                    # DataFrame.append was removed in pandas 2.0
                    record_losses = pd.concat(
                        [record_losses, pd.DataFrame([info])],
                        ignore_index=True)

                # -- write to stdout --
                write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch,
                              running_loss, psnr, psnr_std, psnr_denoised_ave,
                              psnr_denoised_std, psnr_aligned_ave,
                              psnr_aligned_std, psnr_misaligned_ave,
                              psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                              rec_mse_ave, rec_ot_ave)
                print(
                    "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e"
                    % write_info)
                running_loss, running_count = 0, 0

            # -- write examples --
            if write_examples and (batch_idx % write_examples_iter) == 0 and (
                    batch_idx > 0 or cfg.global_step == 0):
                write_input_output(cfg, model, stacked_burst, aligned,
                                   denoised, all_filters, directions)

            if use_timer: clock.toc()
            if use_timer: print(clock)
            cfg.global_step += 1
    finally:
        # -- remove hooks even if a step raised, so they do not pile up
        #    on the alignment model across epochs --
        for hook in align_hooks:
            hook.remove()

    # -- an empty loader yields a mean loss of 0 rather than dividing by 0 --
    if steps_per_epoch > 0:
        total_loss /= steps_per_epoch
    return total_loss, record_losses
Exemple #4
0
def train_loop(cfg, model, optimizer, criterion, train_loader, epoch):
    """Run one training epoch of the burst denoiser.

    Each batch from ``train_loader`` is either ``(burst_imgs, res_imgs,
    raw_img)`` or ``(burst_imgs, raw_img)``; ``burst_imgs`` is indexed by
    frame along its first axis.  A subset of the frames is stacked along
    ``dim=1`` and passed to ``model`` together with a target image; the
    model returns ``(loss, rec_img)`` and the loss is optimized.

    Args:
        cfg: run configuration.  Reads ``device``, ``N``, ``input_N``,
            ``input_with_middle_frame``, ``input_noise``,
            ``input_noise_level``, ``blind``, ``spoof_batch``,
            ``spoof_batch_size``, ``log_interval`` and ``epochs``.
        model: callable as ``model(stacked_burst, target)``.
        optimizer: optimizer stepping the parameters of ``model``.
        criterion: unused; kept for interface compatibility.
        train_loader: iterable of batches that also supports ``len()``.
        epoch: current epoch index, only used in the log line.

    Returns:
        The sum of the per-batch losses divided by ``len(train_loader)``.

    Raises:
        ValueError: if a batch does not contain 2 or 3 items.
    """
    model.train()
    model = model.to(cfg.device)
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    spoof_loss = 0

    optimizer.zero_grad()
    model.zero_grad()

    # -- frame selection does not depend on the batch; compute it once --
    middle_img_idx = cfg.N // 2
    if cfg.input_with_middle_frame:
        input_order = list(range(cfg.N))
    else:
        input_order = [i for i in range(cfg.N) if i != middle_img_idx]

    for batch_idx, stuff in enumerate(train_loader):
        if len(stuff) == 3:
            burst_imgs, _, raw_img = stuff
        elif len(stuff) == 2:
            burst_imgs, raw_img = stuff
        else:
            raise ValueError(
                f"expected a batch of 2 or 3 items, got {len(stuff)}")

        optimizer.zero_grad()
        model.zero_grad()

        # -- add input noise --
        # move to the same device as the model instead of a hard-coded
        # ``.cuda()``, which fails on CPU and on a non-default GPU
        burst_imgs = burst_imgs.to(cfg.device, non_blocking=True)
        burst_imgs_noisy = burst_imgs.clone()
        if cfg.input_noise:
            # the noise std is redrawn uniformly in
            # [0, input_noise_level) for every batch
            noise = np.random.rand() * cfg.input_noise_level
            burst_imgs_noisy[middle_img_idx] = torch.normal(
                burst_imgs_noisy[middle_img_idx], noise)

        # -- stack images --
        stacked_burst = torch.stack(
            [burst_imgs_noisy[input_order[x]] for x in range(cfg.input_N)],
            dim=1)

        # -- extract target image --
        if cfg.blind:
            # the target is the middle frame without the extra input noise
            t_img = burst_imgs[middle_img_idx]
        else:
            t_img = szm(raw_img.to(cfg.device, non_blocking=True))

        # -- denoising --
        loss, rec_img = model(stacked_burst, t_img)

        # -- spoof batch size --
        if cfg.spoof_batch:
            spoof_loss += loss

        # -- update info --
        running_loss += loss.item()
        total_loss += loss.item()

        # -- BP and optimize --
        # NOTE(review): with spoof_batch the first step happens at
        # batch_idx == 0, and losses accumulated after the last multiple of
        # spoof_batch_size are never back-propagated -- confirm intended.
        if cfg.spoof_batch and (batch_idx % cfg.spoof_batch_size) == 0:
            spoof_loss.backward()
            optimizer.step()

            optimizer.zero_grad()
            model.zero_grad()
            spoof_loss = 0

        elif not cfg.spoof_batch:
            loss.backward()
            optimizer.step()

            optimizer.zero_grad()
            model.zero_grad()

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:
            running_loss /= cfg.log_interval
            print("[%d/%d][%d/%d]: %2.3e " % (epoch, cfg.epochs, batch_idx,
                                              len(train_loader), running_loss))
            running_loss = 0
    total_loss /= len(train_loader)
    return total_loss
Exemple #5
0
def test_loop(cfg, model, criterion, test_loader, epoch):
    """Evaluate the burst denoiser on ``test_loader`` and report PSNR.

    Each batch is either ``(burst_imgs, res_imgs, raw_img)`` or
    ``(burst_imgs, raw_img)``.  The reconstruction returned by the model is
    passed through ``rescale_noisy_image`` and compared with ``raw_img`` by
    per-sample MSE, which is converted to PSNR with ``mse_to_psnr``.  Every
    ``cfg.test_log_interval`` batches a grid of reconstructions is saved as
    a PNG under ``settings.ROOT_PATH``.

    Args:
        cfg: run configuration.  Reads ``device``, ``N``, ``input_N``,
            ``input_with_middle_frame``, ``blind``, ``batch_size`` and
            ``test_log_interval``.
        model: callable as ``model(stacked_burst, target)`` returning
            ``(loss, rec_img)``; the returned loss is ignored here.
        criterion: unused; kept for interface compatibility.
        test_loader: iterable of batches that also supports ``len()``.
        epoch: current epoch index, used in the output directory name.

    Returns:
        The mean over batches of the per-batch mean PSNR.

    Raises:
        ValueError: if a batch does not contain 2 or 3 items.
    """
    model.eval()
    model = model.to(cfg.device)
    total_psnr = 0
    total_loss = 0
    szm = ScaleZeroMean()

    # -- selecting input frames (batch independent) --
    # The middle index is always defined.  Previously it stayed at -1 when
    # the middle frame was part of the input, so the blind target silently
    # became the last frame, unlike the training loop.
    middle_img_idx = cfg.N // 2
    if cfg.input_with_middle_frame:
        input_order = list(range(cfg.N))
    else:
        input_order = [i for i in range(cfg.N) if i != middle_img_idx]

    with torch.no_grad():
        for batch_idx, stuff in enumerate(test_loader):
            if len(stuff) == 3:
                burst_imgs, _, raw_img = stuff
            elif len(stuff) == 2:
                burst_imgs, raw_img = stuff
            else:
                raise ValueError(
                    f"expected a batch of 2 or 3 items, got {len(stuff)}")

            BS = raw_img.shape[0]

            # -- target image --
            if cfg.blind:
                t_img = burst_imgs[middle_img_idx]
            else:
                t_img = szm(raw_img.to(cfg.device, non_blocking=True))
            t_img = t_img.to(cfg.device)

            # -- move data to the model's device and stack the frames --
            raw_img = raw_img.to(cfg.device, non_blocking=True)
            burst_imgs = burst_imgs.to(cfg.device, non_blocking=True)
            stacked_burst = torch.stack(
                [burst_imgs[input_order[x]] for x in range(cfg.input_N)],
                dim=1)

            # -- denoising; the model's own loss is not used at test time --
            _, rec_img = model(stacked_burst, t_img)

            # -- compare with the clean image, one MSE value per sample --
            rec_img = rescale_noisy_image(rec_img)
            mse = F.mse_loss(raw_img, rec_img,
                             reduction='none').reshape(BS, -1)
            mse = torch.mean(mse, 1).detach().cpu().numpy()
            psnr = mse_to_psnr(mse)

            total_psnr += np.mean(psnr)
            total_loss += np.mean(mse)

            # -- periodically save a grid of reconstructions --
            if (batch_idx % cfg.test_log_interval) == 0:
                root = Path(
                    f"{settings.ROOT_PATH}/output/attn/offset_out_noise/rec_imgs/N{cfg.N}/e{epoch}"
                )
                root.mkdir(parents=True, exist_ok=True)
                fn = root / f"b{batch_idx}.png"
                nrow = int(np.sqrt(cfg.batch_size))
                rec_img = rec_img.detach().cpu()
                grid_imgs = vutils.make_grid(rec_img,
                                             padding=2,
                                             normalize=True,
                                             nrow=nrow)
                plt.imshow(grid_imgs.permute(1, 2, 0))
                plt.savefig(fn)
                plt.close('all')

    ave_psnr = total_psnr / len(test_loader)
    ave_loss = total_loss / len(test_loader)
    print(
        "[Blind: %d | N: %d] Testing results: Ave psnr %2.3e Ave loss %2.3e" %
        (cfg.blind, cfg.N, ave_psnr, ave_loss))
    return ave_psnr
Exemple #6
0
def train_loop(cfg,model_target,model_online,optim_target,optim_online,criterion,train_loader,epoch,record_losses):
    """Run one training epoch with an online model and a target model.

    Both models are called on the same burst.  The online model is optimized
    with the sum of an alignment loss (from ``criterion``) and a denoising
    loss (MSE between online and target reconstructions plus a decaying MSE
    to the middle frame).  After every optimizer step the target model is
    updated through ``update_moving_average`` with an ``EMA(0.99)`` updater;
    ``optim_target`` is only zeroed, never stepped.

    Args:
        cfg: run configuration.  Reads ``device``, ``supervised``,
            ``noise_params['g']['stddev']``, ``log_interval``, ``epochs``
            and reads/increments ``global_step``.
        model_target: target network; ``model_target(burst)`` returns
            ``(aligned, aligned_ave, denoised, rec_img, filters)``.
        model_online: online network with the same call signature; it also
            exposes ``denoiser_info.optim``, which is stepped here.
        optim_target: optimizer of the target model (zeroed only).
        optim_online: optimizer of the online model.
        criterion: called as ``criterion(aligned, aligned_ave, rec_img,
            t_img, raw_zm_img, global_step)``; must return four losses
            ``(nc, ave, burst, ot)``.  Only ``ave + burst`` is optimized.
        train_loader: yields ``(burst, res_imgs, raw_img, directions)`` with
            ``burst`` of shape ``(N, B, C, H, W)``; must support ``len()``.
        epoch: current epoch index, only used in the log line.
        record_losses: ``pandas.DataFrame`` of past records, or ``None`` to
            start an empty one.

    Returns:
        Tuple ``(total_loss / len(train_loader), record_losses)``.
    """

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #     setup for train epoch
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    # -- setup for training --
    model_online.train()
    model_online = model_online.to(cfg.device)

    model_target.train()
    model_target = model_target.to(cfg.device)

    moving_average_decay = 0.99
    ema_updater = EMA(moving_average_decay)

    # -- init vars --
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    D = 5 * 10**3  # the epoch stops once batch_idx exceeds this value
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({'burst':[],'ave':[],'ot':[],'psnr':[],'psnr_std':[]})
    nc_losses, nc_count = 0, 0
    al_ot_losses, al_ot_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    write_examples = True
    write_examples_iter = 800
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #       run training epoch
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    for batch_idx, (burst, res_imgs, raw_img, directions) in enumerate(train_loader):
        if batch_idx > D:
            break

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #        forward pass
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- zero gradient --
        optim_online.zero_grad()
        optim_target.zero_grad()
        model_online.zero_grad()
        model_online.denoiser_info.optim.zero_grad()
        model_target.zero_grad()
        model_target.denoiser_info.optim.zero_grad()

        # -- reshaping of data --
        N, BS, C, H, W = burst.shape
        # use the models' device rather than a hard-coded ``.cuda()``
        burst = burst.to(cfg.device, non_blocking=True)
        stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')

        # -- create target image --
        mid_img = burst[N//2]
        raw_zm_img = szm(raw_img.to(cfg.device, non_blocking=True))
        # supervised: clean (rescaled) image; otherwise the noisy middle frame
        t_img = raw_zm_img if cfg.supervised else mid_img

        # -- direct denoising --
        aligned_o, aligned_ave_o, denoised_o, rec_img_o, filters_o = model_online(burst)
        aligned_t, aligned_ave_t, denoised_t, rec_img_t, filters_t = model_target(burst)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #        alignment losses
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- compute aligned losses to optimize --
        # the reconstruction is detached so the criterion cannot push
        # gradients through it
        rec_img_d_o = rec_img_o.detach()
        losses = criterion(aligned_o, aligned_ave_o, rec_img_d_o, t_img, raw_zm_img, cfg.global_step)
        nc_loss, ave_loss, burst_loss, ot_loss = [loss.item() for loss in losses]
        kpn_loss = losses[1] + losses[2]
        kpn_coeff = .9997**cfg.global_step

        # -- OT loss (disabled: constant zero, kept for the log line) --
        al_ot_loss = torch.FloatTensor([0.]).to(cfg.device)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #    reconstruction losses
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- decaying rec loss --
        rec_mse_coeff = 0.997**cfg.global_step
        rec_mse_loss = F.mse_loss(rec_img_o, mid_img)

        # -- BYOL loss --
        byol_loss = F.mse_loss(rec_img_o, rec_img_t)

        # -- OT loss (disabled: constant zero, so no residuals are built) --
        rec_ot_coeff = 100
        rec_ot_loss = torch.FloatTensor([0.]).to(cfg.device)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #    final losses & recording
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- final losses --
        align_loss = kpn_coeff * kpn_loss
        denoise_loss = rec_ot_coeff * rec_ot_loss + byol_loss + rec_mse_coeff * rec_mse_loss

        # -- update alignment kl loss info --
        al_ot_losses += al_ot_loss.item()
        al_ot_count += 1

        # -- update reconstruction kl loss info --
        rec_ot_losses += rec_ot_loss.item()
        rec_ot_count += 1

        # -- update info --
        if not np.isclose(nc_loss, 0):
            nc_losses += nc_loss
            nc_count += 1
        running_loss += align_loss.item() + denoise_loss.item()
        total_loss += align_loss.item() + denoise_loss.item()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #    backprop and optimize
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- compute the gradients! --
        loss = align_loss + denoise_loss
        loss.backward()

        # -- backprop for [online] --
        optim_online.step()
        model_online.denoiser_info.optim.step()

        # -- exponential moving average for [target] --
        update_moving_average(ema_updater, model_target, model_online)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #      message to stdout
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:
            # -- compute mse for fun --
            BS = raw_img.shape[0]
            raw_img = raw_img.to(cfg.device, non_blocking=True)

            # -- psnr for [average of aligned frames] --
            # model outputs are shifted by +0.5 before comparison with raw_img
            mse_loss = F.mse_loss(raw_img, aligned_ave_o+0.5, reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(stacked_burst, dim=1)
            mse_loss = F.mse_loss(raw_img, mis_ave+0.5, reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [bm3d] --
            bm3d_nb_psnrs = []
            for b in range(BS):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0,2)+0.5, sigma_psd=noise_level/255, stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0,2)
                # A single image is compared here, so its error map is one
                # row.  Reshaping to (BS, -1) averaged the PSNR over
                # arbitrary pixel chunks and failed when the pixel count was
                # not divisible by BS.
                b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec, reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for aligned + denoised --
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            mse_loss = F.mse_loss(raw_img_repN, denoised_o+0.5, reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img, rec_img_o+0.5, reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- write record --
            if use_record:
                # DataFrame.append was removed in pandas 2.0; concat a
                # one-row frame instead
                info = {'burst':burst_loss,'ave':ave_loss,'ot':ot_loss,'psnr':psnr,'psnr_std':psnr_std}
                record_losses = pd.concat([record_losses, pd.DataFrame([info])], ignore_index=True)

            # -- update losses --
            running_loss /= cfg.log_interval
            # nc_losses / nc_count are not reset, so this is an average over
            # the whole epoch so far
            ave_nc_loss = nc_losses / nc_count if nc_count > 0 else 0

            # -- alignment kl loss --
            ave_al_ot_loss = al_ot_losses / al_ot_count if al_ot_count > 0 else 0
            al_ot_losses, al_ot_count = 0, 0

            # -- reconstruction kl loss --
            ave_rec_ot_loss = rec_ot_losses / rec_ot_count if rec_ot_count > 0 else 0
            rec_ot_losses, rec_ot_count = 0, 0

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, len(train_loader), running_loss, psnr, psnr_std,
                          psnr_denoised_ave, psnr_denoised_std, psnr_aligned_ave, psnr_aligned_std,
                          psnr_misaligned_ave, psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          ave_nc_loss, ave_rec_ot_loss, ave_al_ot_loss)
            print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [denoised]: %2.2f +/- %2.2f [aligned]: %2.2f +/- %2.2f [misaligned]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [loss-nc]: %.2e [loss-rot]: %.2e [loss-aot]: %.2e" % write_info)
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, stacked_burst, aligned_o, denoised_o, filters_o, directions)

        cfg.global_step += 1
    total_loss /= len(train_loader)
    return total_loss, record_losses