Example #1
def run_experiment_serial(cfgs, version, mode, use_ddp=True):
    for cfg in cfgs:
        t = Timer()

        # log process
        record_experiment(cfg, version, mode, 'start', t)
        print(_build_v1_summary(cfg))

        cfg.mode = mode
        if cfg.mode == "test":
            cfg.load = True
            cfg.epoch_num = cfg.epochs  # load last epoch
        else:
            cfg.load = False
            cfg.epoch_num = -1
            # cfg.load = True
            # cfg.epoch_num = 500

        if use_ddp:
            # cfg.use_apex = False
            run_ddp(cfg=cfg)
        else:
            run_serial(cfg=cfg)

        record_experiment(cfg, version, mode, 'end', t)
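
The helpers used above (`Timer`, `record_experiment`, `_build_v1_summary`, `run_ddp`, `run_serial`) are defined elsewhere in the project. A minimal sketch of what `Timer` and `record_experiment` could look like, assuming the timer only needs `tic`/`toc`/`average_time` (as used in the later examples) and that records are appended to a CSV file; the file name and row layout here are illustrative assumptions, not the project's actual format:

import csv
import time


class Timer:
    """Hypothetical stand-in: accumulate wall-clock intervals between tic() and toc()."""

    def __init__(self):
        self.total, self.calls, self._start = 0.0, 0, None

    def tic(self):
        self._start = time.perf_counter()

    def toc(self):
        self.total += time.perf_counter() - self._start
        self.calls += 1

    @property
    def average_time(self):
        return self.total / max(self.calls, 1)


def record_experiment(cfg, version, mode, stage, timer, path="experiments.csv"):
    """Hypothetical stand-in: append one row per 'start'/'end' event."""
    with open(path, "a", newline="") as f:
        csv.writer(f).writerow([version, mode, stage, time.time()])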
Example #2
def run_experiment_parallel(cfgs, version, mode, use_ddp=False):

    ngpus = 3
    nproc_per_gpu = 1
    max_procs = nproc_per_gpu * ngpus
    gpuid = 0
    procs = []

    # for idx,cfg in enumerate(cfgs):
    for idx in [0, 3, 6, 9]:
        cfg = cfgs[idx]
        # if idx < start_idx: continue
        print(f"Running idx {idx}")
        # log process
        t = Timer()

        # wait if need to
        if len(procs) == max_procs:
            wait(procs)
            procs = []

        # run in proper mode
        cfg.use_ddp = False
        if cfg.mode == "test":
            cfg.load = True
            cfg.epoch_num = cfg.epochs  # load last epoch
        else:
            cfg.load = False
            cfg.epoch_num = -1

        # log process
        record_experiment(cfg, version, mode, 'start', t)

        # launch process
        p = run_process(version, idx, gpuid)

        # print what process is being launched
        print(_build_summary(cfg, version))

        # create separate tensorboard files
        time.sleep(5)

        # update current gpuid
        gpuid = (gpuid + 1) % ngpus

        # add process to list
        procs.append(p)

    # wait if need to
    if len(procs) > 0:
        wait(procs)
        procs = []
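
`run_process` and `wait` are project helpers whose definitions are not shown. A plausible sketch using `subprocess.Popen` and `CUDA_VISIBLE_DEVICES`; the script name and CLI flags are purely illustrative assumptions:

import os
import subprocess


def run_process(version, idx, gpuid, script="run_single.py"):
    """Hypothetical launcher: run one config index on a chosen GPU as a child process."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpuid))
    cmd = ["python", script, "--version", str(version), "--idx", str(idx)]
    return subprocess.Popen(cmd, env=env)


def wait(procs):
    """Block until every launched child process has exited."""
    for p in procs:
        p.wait()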
Example #3
def train_loop(cfg, model, scheduler, train_loader, epoch, record_losses,
               writer):

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Setup for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    model.align_info.model.train()
    model.denoiser_info.model.train()
    model.unet_info.model.train()
    model.denoiser_info.model = model.denoiser_info.model.to(cfg.device)
    model.align_info.model = model.align_info.model.to(cfg.device)
    model.unet_info.model = model.unet_info.model.to(cfg.device)

    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })
    noise_type = cfg.noise_params.ntype

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Record Keeping
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    align_mse_losses, align_mse_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0
    dynamics_acc, dynamics_count = 0, 0

    write_examples = False
    write_examples_iter = 200
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #   Load Pre-Simulated Random Numbers
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    if cfg.use_kindex_lmdb: kindex_ds = kIndexPermLMDB(cfg.batch_size, cfg.N)

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Dataset Augmentation
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    transforms = [tvF.vflip, tvF.hflip, tvF.rotate]
    aug = RandomChoice(transforms)

    def apply_transformations(burst, gt_img):
        N, B = burst.shape[:2]
        gt_img_rs = rearrange(gt_img, 'b c h w -> 1 b c h w')
        all_images = torch.cat([gt_img_rs, burst], dim=0)
        all_images = rearrange(all_images, 'n b c h w -> (n b) c h w')
        tv_utils.save_image(all_images,
                            'aug_original.png',
                            nrow=N + 1,
                            normalize=True)
        aug_images = aug(all_images)
        tv_utils.save_image(aug_images,
                            'aug_augmented.png',
                            nrow=N + 1,
                            normalize=True)
        aug_images = rearrange(aug_images, '(n b) c h w -> n b c h w', b=B)
        aug_gt_img = aug_images[0]
        aug_burst = aug_images[1:]
        return aug_burst, aug_gt_img

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Half Precision
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    # model.align_info.model.half()
    # model.denoiser_info.model.half()
    # model.unet_info.model.half()
    # models = [model.align_info.model,
    #           model.denoiser_info.model,
    #           model.unet_info.model]
    # for model_l in models:
    #     model_l.half()
    #     for layer in model_l.modules():
    #         if isinstance(layer, torch.nn.BatchNorm2d):
    #             layer.float()

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Loss Functions
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss(alpha=cfg.kpn_burst_alpha,
                                  gradient_L1=not cfg.supervised)
    # denoiseLossOT = BurstResidualLoss()
    entropyLoss = EntropyLoss()

    # -=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #    Add hooks for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-

    align_hook = AlignmentFilterHooks(cfg.N)
    align_hooks = []
    for kpn_module in model.align_info.model.children():
        for name, layer in kpn_module.named_children():
            if name == "filter_cls":
                align_hook_handle = layer.register_forward_hook(align_hook)
                align_hooks.append(align_hook_handle)

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Noise2Noise
    #
    # -=-=-=-=-=-=-=-=-=-=-

    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Final Configs
    #
    # -=-=-=-=-=-=-=-=-=-=-

    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    if use_timer:
        data_clock = Timer()
        clock = Timer()
    ds_size = len(train_loader)
    small_ds = ds_size < 500
    steps_per_epoch = ds_size if not small_ds else 500

    write_examples_iter = steps_per_epoch // 3
    all_filters = []

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Start Epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-
    dynamics_acc_i = -1.
    if cfg.use_seed:
        init = torch.initial_seed()
        torch.manual_seed(cfg.seed + 1 + epoch + init)
    train_iter = iter(train_loader)
    for batch_idx in range(steps_per_epoch):

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Setting up for Iteration
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- setup iteration timer --
        if use_timer:
            data_clock.tic()
            clock.tic()

        # -- grab data batch --
        sample = next(train_iter)
        burst, raw_img, motion = sample['burst'], sample['clean'], sample[
            'directions']
        raw_img_iid = sample['iid']
        raw_img_iid = raw_img_iid.cuda(non_blocking=True)
        burst = burst.cuda(non_blocking=True)
        burst_og = burst  # assumed: keep the original burst for the logging block below

        aligned, est_nnf = align_burst(cfg, burst, model)
        sim_images = subsample_aligned(cfg, aligned)
        burst_in, tgt_out = create_training_pairs(burst, sim_images)

        dn_losses = []
        for burst, target in zip(burst_in, tgt_out):

            # -- forward pass --
            est_denoised = model(burst)
            dn_loss = compute_denoising_loss(est_denoised, target)

            # -- compute grads --
            if cfg.use_seed: torch.set_deterministic(False)
            dn_loss.backward()
            if cfg.use_seed: torch.set_deterministic(True)

            # -- backprop --
            optim.step()
            scheduler.step()

            # -- store info --
            dn_losses.append(dn_loss.detach())

        # -- average over losses --
        dn_loss = torch.stack(dn_losses).mean()

        # -- alignment loss --
        align_loss = compute_nnf_loss(gt_nnf, est_nnf)

        # -- total loss --
        final_loss = dn_loss + align_loss
        running_loss += final_loss.item()
        total_loss += final_loss.item()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #            Printing to Stdout
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- recompute model output for original images --
            outputs = model(burst_og)
            m_aligned, m_aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]

            # -- compute mse for fun --
            B = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)
            raw_img = get_nmlz_tgt_img(cfg, raw_img)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, m_aligned_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(burst_og, dim=0)
            if noise_type == "qis": mis_ave = quantize_img(cfg, mis_ave)
            mse_loss = F.mse_loss(raw_img, mis_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # tv_utils.save_image(raw_img,"raw.png",nrow=1,normalize=True,range=(-0.5,1.25))
            # tv_utils.save_image(mis_ave,"mis.png",nrow=1,normalize=True,range=(-0.5,1.25))

            # -- psnr for [bm3d] --
            mid_img_og = burst_og[N // 2]
            bm3d_nb_psnrs = []
            M = 4 if B > 4 else B
            for b in range(M):
                bm3d_rec = bm3d.bm3d(mid_img_og[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                # maybe an issue here
                b_loss = F.mse_loss(raw_img[b].cpu(),
                                    bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for input averaged frames --
            # burst_ave = torch.mean(burst_og,dim=0)
            # mse_loss = F.mse_loss(raw_img,burst_ave,reduction='none').reshape(B,-1)
            # mse_loss = torch.mean(mse_loss,1).detach().cpu().numpy()
            # psnr_input_ave = np.mean(mse_to_psnr(mse_loss))
            # psnr_input_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for aligned + denoised --
            R = denoised.shape[1]
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, R, 1, 1, 1)
            # if noise_type == "qis": denoised = quantize_img(cfg,denoised)
            # save_image(denoised_ave,"denoised_ave.png")
            # save_image(denoised,"denoised.png")
            mse_loss = F.mse_loss(raw_img_repN, denoised,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img, denoised_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- reconstruction MSE --
            rec_mse_ave = rec_mse_losses / rec_mse_count if rec_mse_count > 0 else 0.
            rec_mse_losses, rec_mse_count = 0, 0

            # -- reconstruction Dist. --
            rec_ot_ave = rec_ot_losses / rec_ot_count if rec_ot_count > 0 else 0.
            rec_ot_losses, rec_ot_count = 0, 0

            # -- ave dynamic acc --
            ave_dyn_acc = dynamics_acc / dynamics_count * 100. if dynamics_count > 0 else 0.
            dynamics_acc, dynamics_count = 0, 0

            # -- write record --
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                record_losses = record_losses.append(info, ignore_index=True)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch,
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, ave_dyn_acc)  #rec_ot_ave)

            #print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e" % write_info)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [dyn]: %.2e"
                % write_info,
                flush=True)
            # -- write to summary writer --
            if writer:
                writer.add_scalar('train/running-loss', running_loss,
                                  cfg.global_step)
                writer.add_scalars('train/model-psnr', {
                    'ave': psnr,
                    'std': psnr_std
                }, cfg.global_step)
                writer.add_scalars('train/dn-frame-psnr', {
                    'ave': psnr_denoised_ave,
                    'std': psnr_denoised_std
                }, cfg.global_step)

            # -- reset loss --
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               all_filters, motion)

        if use_timer: clock.toc()

        if use_timer:
            print("data_clock", data_clock.average_time)
            print("clock", clock.average_time)
        cfg.global_step += 1

    # -- remove hooks --
    for hook in align_hooks:
        hook.remove()

    total_loss /= len(train_loader)
    return total_loss, record_losses
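
`mse_to_psnr` is used throughout the logging blocks but never defined in these examples. A minimal sketch, assuming images live in a [0, 1] range so the peak signal value is 1:

import numpy as np


def mse_to_psnr(mse, max_val=1.0):
    """PSNR = 10 * log10(max_val^2 / mse), applied elementwise to a batch of MSE values."""
    mse = np.asarray(mse, dtype=np.float64)
    return 10.0 * np.log10((max_val ** 2) / np.maximum(mse, 1e-12))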
Example #4
def train_loop(cfg, model, optimizer, criterion, train_loader, epoch,
               record_losses):

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Setup for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    model.train()
    model = model.to(cfg.device)
    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Record Keeping
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    align_mse_losses, align_mse_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0

    write_examples = True
    write_examples_iter = 800
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Loss Functions
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss()
    # denoiseLossOT = BurstResidualLoss()
    entropyLoss = EntropyLoss()

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Final Configs
    #
    # -=-=-=-=-=-=-=-=-=-=-

    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    if use_timer: clock = Timer()
    train_iter = iter(train_loader)
    D = 5 * 10**3
    steps_per_epoch = len(train_loader)

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Start Epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    for batch_idx in range(steps_per_epoch):

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Setting up for Iteration
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- setup iteration timer --
        if use_timer: clock.tic()

        # -- zero gradients; ready 2 go --
        optimizer.zero_grad()
        model.zero_grad()
        model.denoiser_info.optim.zero_grad()

        # -- grab data batch --
        burst, res_imgs, raw_img, directions = next(train_iter)

        # -- getting shapes of data --
        N, BS, C, H, W = burst.shape
        burst = burst.cuda(non_blocking=True)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Formatting Images for FP
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- creating some transforms --
        stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
        cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')

        # -- extract target image --
        mid_img = burst[N // 2]
        raw_zm_img = szm(raw_img.cuda(non_blocking=True))
        if cfg.supervised: gt_img = raw_zm_img
        else: gt_img = mid_img

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #           Forward Pass
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        aligned, aligned_ave, denoised, denoised_ave, filters = model(burst)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Entropy Loss for Filters
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        filters_shaped = rearrange(filters, 'b n k2 1 1 1 -> (b n) k2', n=N)
        filters_entropy = entropyLoss(filters_shaped)
        filters_entropy_coeff = 10.

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Alignment Losses (MSE)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        losses = alignmentLossMSE(aligned, aligned_ave, gt_img,
                                  cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        align_mse = np.sum(losses)
        align_mse_coeff = 0  #.933**cfg.global_step if cfg.global_step < 100 else 0

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #   Reconstruction Losses (MSE)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        denoised_ave_d = denoised_ave.detach()
        losses = criterion(denoised, denoised_ave, gt_img, cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        rec_mse = np.sum(losses)
        rec_mse_coeff = 0.997**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Reconstruction Losses (Distribution)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- regularization scheduler --
        if cfg.global_step < 100: reg = 0.5
        elif cfg.global_step < 200: reg = 0.25
        elif cfg.global_step < 5000: reg = 0.15
        elif cfg.global_step < 10000: reg = 0.1
        else: reg = 0.05

        # -- computation --
        residuals = denoised - mid_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        residuals = rearrange(residuals, 'b n c h w -> b n (h w) c')
        # rec_ot_pair_loss_v1 = w_gaussian_bp(residuals,noise_level)
        rec_ot_pair_loss_v1 = kl_gaussian_bp(residuals, noise_level)
        # rec_ot_pair_loss_v1 = ot_pairwise2gaussian_bp(residuals,K=6,reg=reg)
        # rec_ot_pair_loss_v2 = ot_pairwise_bp(residuals,K=3)
        rec_ot_pair_loss_v2 = torch.FloatTensor([0.]).to(cfg.device)
        rec_ot_pair = (rec_ot_pair_loss_v1 + rec_ot_pair_loss_v2) / 2.
        rec_ot_pair_coeff = 100  # - .997**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #              Final Losses
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        align_loss = align_mse_coeff * align_mse
        rec_loss = rec_ot_pair_coeff * rec_ot_pair + rec_mse_coeff * rec_mse
        entropy_loss = filters_entropy_coeff * filters_entropy
        final_loss = align_loss + rec_loss + entropy_loss

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #              Record Keeping
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- alignment MSE --
        align_mse_losses += align_mse.item()
        align_mse_count += 1

        # -- reconstruction MSE --
        rec_mse_losses += rec_mse.item()
        rec_mse_count += 1

        # -- reconstruction Dist. --
        rec_ot_losses += rec_ot_pair.item()
        rec_ot_count += 1

        # -- total loss --
        running_loss += final_loss.item()
        total_loss += final_loss.item()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #        Gradients & Backpropagation
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- compute the gradients! --
        final_loss.backward()

        # -- backprop now. --
        model.denoiser_info.optim.step()
        optimizer.step()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #            Printing to Stdout
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- compute mse for fun --
            BS = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(stacked_burst, dim=1)
            mse_loss = F.mse_loss(raw_img, mis_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [bm3d] --
            bm3d_nb_psnrs = []
            for b in range(BS):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                b_loss = F.mse_loss(raw_img[b].cpu(),
                                    bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for aligned + denoised --
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            mse_loss = F.mse_loss(raw_img_repN,
                                  denoised + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img,
                                  denoised_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- alignment MSE --
            align_mse_ave = align_mse_losses / align_mse_count
            align_mse_losses, align_mse_count = 0, 0

            # -- reconstruction MSE --
            rec_mse_ave = rec_mse_losses / rec_mse_count
            rec_mse_losses, rec_mse_count = 0, 0

            # -- reconstruction Dist. --
            rec_ot_ave = rec_ot_losses / rec_ot_count
            rec_ot_losses, rec_ot_count = 0, 0

            # -- write record --
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                record_losses = record_losses.append(info, ignore_index=True)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, len(train_loader),
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, rec_ot_ave)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e"
                % write_info)
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               filters, directions)

        if use_timer: clock.toc()
        if use_timer: print(clock)
        cfg.global_step += 1
    total_loss /= len(train_loader)
    return total_loss, record_losses
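
`kl_gaussian_bp` (the "Reconstruction Losses (Distribution)" term) is also project code that is not shown. Its name suggests a KL divergence between the empirical residual statistics and a zero-mean Gaussian at the known noise level; a rough sketch under that assumption, with the meaning of `flip` (reversing the KL direction) also assumed:

import torch


def kl_gaussian_bp(residuals, noise_level, flip=False):
    """Hypothetical sketch: closed-form KL between per-sample residual statistics
    and a zero-mean Gaussian with std noise_level / 255."""
    sigma_tgt = noise_level / 255.
    r = residuals.reshape(residuals.shape[0], -1)
    mu_p, s_p = r.mean(dim=1), r.std(dim=1)
    mu_q = torch.zeros_like(mu_p)
    s_q = torch.full_like(s_p, sigma_tgt)
    if flip:  # assumed: swap the direction of the divergence
        (mu_p, s_p), (mu_q, s_q) = (mu_q, s_q), (mu_p, s_p)
    kl = torch.log(s_q / s_p) + (s_p ** 2 + (mu_p - mu_q) ** 2) / (2 * s_q ** 2) - 0.5
    return kl.mean()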
Example #5
def train_loop(cfg, model, noise_critic, optimizer, criterion, train_loader,
               epoch, record_losses):

    # -=-=-=-=-=-=-=-=-=-=-
    #    Setup for epoch
    # -=-=-=-=-=-=-=-=-=-=-

    model.train()
    model = model.to(cfg.device)
    N = cfg.N
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    D = 5 * 10**3
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })
    write_examples = True
    write_examples_iter = 800
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #      Init Record Keeping
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    losses_nc, losses_nc_count = 0, 0
    losses_mse, losses_mse_count = 0, 0
    running_loss, total_loss = 0, 0

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #      Init Loss Functions
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    lossRecMSE = LossRec(tensor_grad=cfg.supervised)
    lossBurstMSE = LossRecBurst(tensor_grad=cfg.supervised)

    # -=-=-=-=-=-=-=-=-=-=-
    #    Final Configs
    # -=-=-=-=-=-=-=-=-=-=-

    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    train_iter = iter(train_loader)
    if use_timer: clock = Timer()

    # -=-=-=-=-=-=-=-=-=-=-
    #    GAN Scheduler
    # -=-=-=-=-=-=-=-=-=-=-

    # -- noise critic steps --
    if epoch == 0: disc_steps = 0
    elif epoch < 3: disc_steps = 1
    elif epoch < 10: disc_steps = 1
    else: disc_steps = 1

    # -- denoising steps --
    if epoch == 0: gen_steps = 1
    elif epoch < 3: gen_steps = 15
    elif epoch < 10: gen_steps = 10
    else: gen_steps = 10

    # -- steps each epoch --
    steps_per_iter = max(disc_steps, 1) * gen_steps  # guard: disc_steps is 0 when epoch == 0
    steps_per_epoch = len(train_loader) // steps_per_iter
    if steps_per_epoch > 120: steps_per_epoch = 120

    # -=-=-=-=-=-=-=-=-=-=-
    #    Start Epoch
    # -=-=-=-=-=-=-=-=-=-=-

    for batch_idx in range(steps_per_epoch):

        for gen_step in range(gen_steps):

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #      Setting up for Iteration
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # -- setup iteration timer --
            if use_timer: clock.tic()

            # -- zero gradients --
            optimizer.zero_grad()
            model.zero_grad()
            model.denoiser_info.model.zero_grad()
            model.denoiser_info.optim.zero_grad()
            noise_critic.disc.zero_grad()
            noise_critic.optim.zero_grad()

            # -- grab data batch --
            burst, res_imgs, raw_img, directions = next(train_iter)

            # -- getting shapes of data --
            N, BS, C, H, W = burst.shape
            burst = burst.cuda(non_blocking=True)

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #      Formatting Images for FP
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # -- creating some transforms --
            stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
            cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')

            # -- extract target image --
            mid_img = burst[N // 2]
            raw_zm_img = szm(raw_img.cuda(non_blocking=True))
            if cfg.supervised: gt_img = raw_zm_img
            else: gt_img = mid_img

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #           Forward Pass
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            aligned, aligned_ave, denoised, denoised_ave, filters = model(
                burst)

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #    MSE (KPN) Reconstruction Loss
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            loss_rec = lossRecMSE(denoised_ave, gt_img)
            loss_burst = lossBurstMSE(denoised, gt_img)
            loss_mse = loss_rec + 100 * loss_burst
            gbs, spe = cfg.global_step, steps_per_epoch
            if epoch < 3: weight_mse = 10
            else: weight_mse = 10 * 0.9999**(gbs - 3 * spe)

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #      Noise Critic Loss
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            loss_nc = noise_critic.compute_residual_loss(denoised, gt_img)
            weight_nc = 1

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #          Final Loss
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            final_loss = weight_mse * loss_mse + weight_nc * loss_nc

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #    Update Info for Record Keeping
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # -- update alignment kl loss info --
            losses_nc += loss_nc.item()
            losses_nc_count += 1

            # -- update reconstruction kl loss info --
            losses_mse += loss_mse.item()
            losses_mse_count += 1

            # -- update info --
            running_loss += final_loss.item()
            total_loss += final_loss.item()

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #          Backward Pass
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # -- compute the gradients! --
            final_loss.backward()

            # -- backprop now. --
            model.denoiser_info.optim.step()
            optimizer.step()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #     Iterate for Noise Critic
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        for disc_step in range(disc_steps):

            # -- zero gradients --
            optimizer.zero_grad()
            model.zero_grad()
            model.denoiser_info.optim.zero_grad()
            noise_critic.disc.zero_grad()
            noise_critic.optim.zero_grad()

            # -- grab noisy data --
            _burst, _res_imgs, _raw_img, _directions = next(train_iter)
            _burst = _burst.to(cfg.device)

            # -- generate "fake" data from noisy data --
            _aligned, _aligned_ave, _denoised, _denoised_ave, _filters = model(
                _burst)
            _residuals = _denoised - _burst[N // 2].unsqueeze(1).repeat(
                1, N, 1, 1, 1)

            # -- update discriminator --
            loss_disc = noise_critic.update_disc(_residuals)

            # -- message to stdout --
            first_update = (disc_step == 0)
            last_update = (disc_step == disc_steps - 1)
            iter_update = first_update or last_update
            # if (batch_idx % cfg.log_interval//2) == 0 and batch_idx > 0 and iter_update:
            print(f"[Noise Critic]: {loss_disc}")

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #      Print Message to Stdout
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- init --
            BS = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(stacked_burst, dim=1)
            mse_loss = F.mse_loss(raw_img, mis_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [bm3d] --
            bm3d_nb_psnrs = []
            for b in range(BS):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                b_loss = F.mse_loss(raw_img[b].cpu(),
                                    bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for aligned + denoised --
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            mse_loss = F.mse_loss(raw_img_repN,
                                  denoised + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img,
                                  denoised_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- write record --
            if use_record:
                record_losses = record_losses.append(
                    {
                        'burst': burst_loss,
                        'ave': ave_loss,
                        'ot': ot_loss,
                        'psnr': psnr,
                        'psnr_std': psnr_std
                    },
                    ignore_index=True)

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- average mse losses --
            ave_losses_mse = losses_mse / losses_mse_count
            losses_mse, losses_mse_count = 0, 0

            # -- average noise critic loss --
            ave_losses_nc = losses_nc / losses_nc_count
            losses_nc, losses_nc_count = 0, 0

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch,
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          ave_losses_mse, ave_losses_nc)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [mse]: %.2e [nc]: %.2e"
                % write_info)
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, stacked_burst, aligned, denoised, filters,
                               directions)

        if use_timer: clock.toc()
        if use_timer: print(clock)
        cfg.global_step += 1
    total_loss /= len(train_loader)
    return total_loss, record_losses
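
The `noise_critic` object bundles a residual discriminator with its optimizer; only its interface (`disc`, `optim`, `compute_residual_loss`, `update_disc`) is visible above. A rough sketch of a compatible wrapper, using a plain non-saturating GAN loss as a placeholder for whatever criterion the original uses; treating "real" samples as Gaussian noise at the configured level is likewise an assumption:

import torch
import torch.nn as nn
import torch.nn.functional as F


class NoiseCritic:
    """Hypothetical wrapper exposing the interface used by the training loop above."""

    def __init__(self, disc: nn.Module, lr=1e-4, sigma=25. / 255.):
        self.disc = disc  # maps residual images to per-sample logits
        self.optim = torch.optim.Adam(disc.parameters(), lr=lr)
        self.sigma = sigma  # std of the "real" noise distribution

    def compute_residual_loss(self, denoised, gt_img):
        # generator-side loss: push denoiser residuals toward what the critic calls real noise
        residuals = denoised - gt_img.unsqueeze(1)  # broadcast over the frame dimension
        logits = self.disc(residuals.flatten(0, 1))
        return F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits))

    def update_disc(self, fake_residuals):
        # discriminator-side update: real = sampled Gaussian noise, fake = model residuals
        real = torch.randn_like(fake_residuals) * self.sigma
        real_logits = self.disc(real.flatten(0, 1))
        fake_logits = self.disc(fake_residuals.detach().flatten(0, 1))
        loss = (F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits)) +
                F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits)))
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        return loss.item()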
Example #6
def train_loop(cfg, model, train_loader, epoch, record_losses):

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Setup for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    model.align_info.model.train()
    model.denoiser_info.model.train()
    model.unet_info.model.train()
    model.denoiser_info.model = model.denoiser_info.model.to(cfg.device)
    model.align_info.model = model.align_info.model.to(cfg.device)
    model.unet_info.model = model.unet_info.model.to(cfg.device)

    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Record Keeping
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    align_mse_losses, align_mse_count = 0, 0
    align_ot_losses, align_ot_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0

    write_examples = True
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #    Add hooks for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-

    align_hook = AlignmentFilterHooks(cfg.N)
    align_hooks = []
    for kpn_module in model.align_info.model.children():
        for name, layer in kpn_module.named_children():
            if name == "filter_cls":
                align_hook_handle = layer.register_forward_hook(align_hook)
                align_hooks.append(align_hook_handle)

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Loss Functions
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss()
    # denoiseLossOT = BurstResidualLoss()
    entropyLoss = EntropyLoss()

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Final Configs
    #
    # -=-=-=-=-=-=-=-=-=-=-

    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    if use_timer: clock = Timer()
    train_iter = iter(train_loader)
    steps_per_epoch = len(train_loader)
    write_examples_iter = steps_per_epoch // 2

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Start Epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    for batch_idx in range(steps_per_epoch):

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Setting up for Iteration
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- setup iteration timer --
        if use_timer: clock.tic()

        # -- zero gradients; ready 2 go --
        model.align_info.model.zero_grad()
        model.align_info.optim.zero_grad()
        model.denoiser_info.model.zero_grad()
        model.denoiser_info.optim.zero_grad()
        model.unet_info.model.zero_grad()
        model.unet_info.optim.zero_grad()

        # -- grab data batch --
        burst, res_imgs, raw_img, directions = next(train_iter)

        # -- getting shapes of data --
        N, B, C, H, W = burst.shape
        burst = burst.cuda(non_blocking=True)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Formatting Images for FP
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- creating some transforms --
        stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
        cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')

        # -- extract target image --
        mid_img = burst[N // 2]
        raw_zm_img = szm(raw_img.cuda(non_blocking=True))
        if cfg.supervised: gt_img = raw_zm_img
        else: gt_img = mid_img

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #              Check Some Gradients
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        def mse_v_wassersteinG_check_some_gradients(cfg, burst, gt_img, model):
            grads = edict()
            gt_img_rs = gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            model.unet_info.model.zero_grad()
            burst.requires_grad_(True)

            outputs = model(burst)
            aligned, aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]
            residuals = denoised - gt_img_rs
            P = 1.  #residuals.numel()
            denoised.retain_grad()
            rec_mse = (denoised.reshape(B, -1) - gt_img.reshape(B, -1))**2
            rec_mse.retain_grad()
            ones = P * torch.ones_like(rec_mse)
            rec_mse.backward(ones, retain_graph=True)
            grads.rmse = rec_mse.grad.clone().reshape(B, -1)
            grad_rec_mse = grads.rmse
            grads.dmse = denoised.grad.clone().reshape(B, -1)
            grad_denoised_mse = grads.dmse
            ones = torch.ones_like(rec_mse)
            grads.d_to_b = torch.autograd.grad(rec_mse, denoised,
                                               ones)[0].reshape(B, -1)

            model.unet_info.model.zero_grad()
            outputs = model(burst)
            aligned, aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]
            # residuals = denoised - gt_img_rs
            # rec_ot = w_gaussian_bp(residuals,noise_level)
            denoised.retain_grad()
            rec_ot_v = (denoised - gt_img_rs)**2
            rec_ot_v.retain_grad()
            rec_ot = (rec_ot_v.mean() - noise_level / 255.)**2
            rec_ot.retain_grad()
            ones = P * torch.ones_like(rec_ot)
            rec_ot.backward(ones)
            grad_denoised_ot = denoised.grad.clone().reshape(B, -1)
            grads.dot = grad_denoised_ot
            grad_rec_ot = rec_ot_v.grad.clone().reshape(B, -1)
            grads.rot = grad_rec_ot

            print("Gradient Name Info")
            for name, g in grads.items():
                g_norm = g.norm().item()
                g_mean = g.mean().item()
                g_std = g.std().item()
                print(name, g.shape, g_norm, g_mean, g_std)

            print_pairs = False
            if print_pairs:
                print("All Gradient Ratios")
                for name_t, g_t in grads.items():
                    for name_b, g_b in grads.items():
                        ratio = g_t / g_b
                        ratio_m = ratio.mean().item()
                        ratio_std = ratio.std().item()
                        print("[%s/%s] [%2.2e +/- %2.2e]" %
                              (name_t, name_b, ratio_m, ratio_std))

            use_true_mse = False
            if use_true_mse:
                print("Ratios with Estimated MSE Gradient")
                true_dmse = 2 * torch.mean(denoised_ave - gt_img)**2
                ratio_mse = grads.dmse / true_dmse
                ratio_mse_dtb = grads.dmse / grads.d_to_b
                print(ratio_mse)
                print(ratio_mse_dtb)

            dot_v_dmse = True
            if dot_v_dmse:
                print("Ratio of Denoised OT and Denoised MSE")
                ratio_mseot = (grads.dmse / grads.dot)
                print(ratio_mseot.mean(), ratio_mseot.std())
                ratio_mseot = ratio_mseot[0, 0].item()

                c1 = torch.mean((denoised - gt_img_rs)**2).item()
                c2 = noise_level / 255
                m = torch.mean(gt_img_rs).item()
                true_ratio = 2. * (c1 - c2) / (np.product(burst.shape))
                # diff = denoised.reshape(B,-1)-gt_img_rs.reshape(B,-1)
                # true_ratio = 2.*(c1 - c2) * ( diff / ( np.product(burst.shape) ) )
                # print(c1,c2,m,true_ratio,1./true_ratio)
                ratio_mseot = (grads.dmse / (grads.dot))
                print(ratio_mseot * true_ratio)

                # ratio_mseot = (grads.dmse / ( grads.dot / diff) )
                # print(ratio_mseot*true_ratio)
                # print(ratio_mseot.mean(),ratio_mseot.std())

            exit()
            model.unet_info.model.zero_grad()

        # mse_v_wassersteinG_check_some_gradients(cfg,burst,gt_img,model)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #           Forward Pass
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        outputs = model(burst)
        aligned, aligned_ave, denoised, denoised_ave = outputs[:4]
        aligned_filters, denoised_filters = outputs[4:]

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #   Require Approx Equal Filter Norms (aligned)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        aligned_filters_rs = rearrange(aligned_filters,
                                       'b n k2 c h w -> b n (k2 c h w)')
        norms = torch.norm(aligned_filters_rs, p=2., dim=2)
        norms_mid = norms[:, N // 2].unsqueeze(1).repeat(1, N)
        norm_loss_align = torch.mean(
            torch.pow(torch.abs(norms - norms_mid), 1.))

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #   Require Approx Equal Filter Norms (denoised)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        denoised_filters = rearrange(denoised_filters,
                                     'b n k2 c h w -> b n (k2 c h w)')
        norms = torch.norm(denoised_filters, p=2., dim=2)
        norms_mid = norms[:, N // 2].unsqueeze(1).repeat(1, N)
        norm_loss_denoiser = torch.mean(
            torch.pow(torch.abs(norms - norms_mid), 1.))
        norm_loss_coeff = 0.

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Decrease Entropy within a Kernel
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        filters_entropy = 0
        filters_entropy_coeff = 0.  # 1000.
        all_filters = []
        L = len(align_hook.filters)
        iter_filters = align_hook.filters if L > 0 else [aligned_filters]
        for filters in iter_filters:
            filters_shaped = rearrange(filters,
                                       'b n k2 c h w -> (b n c h w) k2',
                                       n=N)
            filters_entropy += entropyLoss(filters_shaped)
            all_filters.append(filters)
        if L > 0: filters_entropy /= L
        all_filters = torch.stack(all_filters, dim=1)
        align_hook.clear()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Increase Entropy across each Kernel
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        filters_dist_entropy = 0

        # -- across each frame --
        # filters_shaped = rearrange(all_filters,'b l n k2 c h w -> (b l) (n c h w) k2')
        # filters_shaped = torch.mean(filters_shaped,dim=1)
        # filters_dist_entropy += -1 * entropyLoss(filters_shaped)

        # -- across each batch --
        filters_shaped = rearrange(all_filters,
                                   'b l n k2 c h w -> (n l) (b c h w) k2')
        filters_shaped = torch.mean(filters_shaped, dim=1)
        filters_dist_entropy += -1 * entropyLoss(filters_shaped)

        # -- across each kpn cascade --
        # filters_shaped = rearrange(all_filters,'b l n k2 c h w -> (b n) (l c h w) k2')
        # filters_shaped = torch.mean(filters_shaped,dim=1)
        # filters_dist_entropy += -1 * entropyLoss(filters_shaped)

        filters_dist_coeff = 0

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Alignment Losses (MSE)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        losses = alignmentLossMSE(aligned, aligned_ave, gt_img,
                                  cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        align_mse = np.sum(losses)
        align_mse_coeff = 0.  #0.95**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #   Alignment Losses (Distribution)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # pad = 2*cfg.N
        # fs = cfg.dynamic.frame_size
        residuals = aligned - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        # centered_residuals = tvF.center_crop(residuals,(fs-pad,fs-pad))
        # centered_residuals = tvF.center_crop(residuals,(fs//2,fs//2))
        # align_ot = kl_gaussian_bp(residuals,noise_level,flip=True)
        align_ot = kl_gaussian_bp_patches(residuals,
                                          noise_level,
                                          flip=True,
                                          patchsize=16)
        align_ot_coeff = 0  # 100.

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #   Reconstruction Losses (MSE)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        losses = denoiseLossMSE(denoised, denoised_ave, gt_img,
                                cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        rec_mse = np.sum(losses)
        rec_mse_coeff = 0.95**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Reconstruction Losses (Distribution)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- computation --
        gt_img_rs = gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        residuals = denoised - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        # rec_ot = kl_gaussian_bp(residuals,noise_level)
        rec_ot = kl_gaussian_bp(residuals, noise_level, flip=True)
        # rec_ot /= 2.
        # alpha_grid = [0.,1.,5.,10.,25.]
        # for alpha in alpha_grid:
        #     # residuals = torch.normal(torch.zeros_like(residuals)+ gt_img_rs*alpha/255.,noise_level/255.)
        #     residuals = torch.normal(torch.zeros_like(residuals),noise_level/255.+ gt_img_rs*alpha/255.)

        #     rec_ot_v2_a = kl_gaussian_bp_patches(residuals,noise_level,patchsize=16)
        #     rec_ot_v1_b = kl_gaussian_bp(residuals,noise_level,flip=True)
        #     rec_ot_v2_b = kl_gaussian_bp_patches(residuals,noise_level,flip=True,patchsize=16)
        #     rec_ot_all = torch.tensor([rec_ot_v1_a,rec_ot_v2_a,rec_ot_v1_b,rec_ot_v2_b])

        #     rec_ot_v2 = (rec_ot_v2_a + rec_ot_v2_b).item()/2.
        #     print(alpha,torch.min(rec_ot_all),torch.max(rec_ot_all),rec_ot_v1,rec_ot_v2)
        # exit()
        # rec_ot = w_gaussian_bp(residuals,noise_level)
        # print(residuals.numel())
        rec_ot_coeff = 100.  #residuals.numel()*2.
        # 1000.# - .997**cfg.global_step

        # residuals = rearrange(residuals,'b n c h w -> b n (h w) c')
        # rec_ot_pair_loss_v1 = w_gaussian_bp(residuals,noise_level)
        # rec_ot_loss_v1 = kl_gaussian_bp(residuals,noise_level,flip=True)
        # rec_ot_loss_v1 = kl_gaussian_pair_bp(residuals)
        # rec_ot_loss_v1 = ot_pairwise2gaussian_bp(residuals,K=6,reg=reg)
        # rec_ot_loss_v2 = ot_pairwise_bp(residuals,K=3)
        # rec_ot_pair_loss_v2 = torch.FloatTensor([0.]).to(cfg.device)
        # rec_ot = (rec_ot_loss_v1 + rec_ot_pair_loss_v2)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #              Final Losses
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        rec_loss = rec_ot_coeff * rec_ot + rec_mse_coeff * rec_mse
        norm_loss = norm_loss_coeff * (norm_loss_denoiser + norm_loss_align)
        align_loss = align_mse_coeff * align_mse + align_ot_coeff * align_ot
        entropy_loss = 0  #filters_entropy_coeff * filters_entropy + filters_dist_coeff * filters_dist_entropy
        # final_loss = align_loss + rec_loss + entropy_loss + norm_loss
        final_loss = rec_loss

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #              Record Keeping
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- alignment MSE --
        align_mse_losses += align_mse.item()
        align_mse_count += 1

        # -- alignment Dist --
        align_ot_losses += align_ot.item()
        align_ot_count += 1

        # -- reconstruction MSE --
        rec_mse_losses += rec_mse.item()
        rec_mse_count += 1

        # -- reconstruction Dist. --
        rec_ot_losses += rec_ot.item()
        rec_ot_count += 1

        # -- total loss --
        running_loss += final_loss.item()
        total_loss += final_loss.item()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #        Gradients & Backpropagation
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- compute the gradients! --
        final_loss.backward()

        # -- backprop now. --
        model.align_info.optim.step()
        model.denoiser_info.optim.step()
        model.unet_info.optim.step()

        # for name,params in model.unet_info.model.named_parameters():
        #     if not ("weight" in name): continue
        #     print(params.grad.norm())
        #     # print(module.conv1.parameters())
        #     # print(module.conv1.data.grad)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #            Printing to Stdout
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- compute mse for fun --
            B = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(stacked_burst, dim=1)
            mse_loss = F.mse_loss(raw_img, mis_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [bm3d] --
            bm3d_nb_psnrs = []
            M = 10 if B > 10 else B  # cap the number of BM3D evaluations
            for b in range(M):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                b_loss = F.mse_loss(raw_img[b].cpu(),
                                    bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for aligned + denoised --
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            mse_loss = F.mse_loss(raw_img_repN,
                                  denoised + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img,
                                  denoised_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- alignment MSE --
            align_mse_ave = align_mse_losses / align_mse_count
            align_mse_losses, align_mse_count = 0, 0

            # -- alignment Dist. --
            align_ot_ave = align_ot_losses / align_ot_count
            align_ot_losses, align_ot_count = 0, 0

            # -- reconstruction MSE --
            rec_mse_ave = rec_mse_losses / rec_mse_count
            rec_mse_losses, rec_mse_count = 0, 0

            # -- reconstruction Dist. --
            rec_ot_ave = rec_ot_losses / rec_ot_count
            rec_ot_losses, rec_ot_count = 0, 0

            # -- write record --
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                record_losses = record_losses.append(info, ignore_index=True)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, len(train_loader),
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, rec_ot_ave)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e"
                % write_info)
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               all_filters, directions)

        if use_timer: clock.toc()
        if use_timer: print(clock)
        cfg.global_step += 1

    # -- remove hooks --
    for hook in align_hooks:
        hook.remove()

    total_loss /= len(train_loader)
    return total_loss, record_losses
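
The PSNR statistics printed above are produced by converting per-image MSE values with mse_to_psnr, a helper that is not included in these examples. A minimal sketch of what it could look like, assuming images are scaled to [0, 1] so the peak signal value is 1.0:

import numpy as np

def mse_to_psnr(mse, peak=1.0):
    """Convert per-image MSE values to PSNR in dB; hypothetical helper."""
    mse = np.maximum(np.asarray(mse, dtype=np.float64), 1e-12)  # avoid log(0)
    return 10.0 * np.log10((peak ** 2) / mse)
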
Exemple #7
0
def run_train(cfg, rank, models, data, loader):
    this_proc_prints = (rank == 0 and cfg.use_ddp) or (cfg.use_ddp is False)

    s = int(npr.rand() * 5 + 1)
    time.sleep(s)

    hyperparams = load_hyperparameters(cfg)

    criterion_inputs = [hyperparams]
    criterion_inputs += extract_loss_inputs(cfg, rank)
    criterion = DenoisingLossDDP(*criterion_inputs)
    criterion = criterion.to(cfg.device)

    optimizer = load_optimizer(cfg, models)
    scheduler = load_scheduler(cfg, optimizer, len(loader.tr))
    print("Loaded optimizer: ")
    print(optimizer)
    print("Loaded scheduler: ")
    print(scheduler)

    # apply apex
    if cfg.use_apex:
        models, optimizer = amp.initialize(models, optimizer, opt_level='O2')

    # for name,param in models.named_parameters():
    #     wnorm = param.norm()
    #     print("{}: {}".format(name,wnorm))

    # init writer
    if this_proc_prints:
        writer = SummaryWriter(filename_suffix=cfg.exp_name)
    else:
        writer = None

    # init training loop
    global_step, current_epoch = get_model_epoch_info(cfg)
    cfg.global_step = global_step
    cfg.current_epoch = current_epoch

    # training loop
    loop_scheduler = get_loop_scheduler(scheduler)
    test_losses = {}

    print(f"cfg.epochs: {cfg.epochs}")
    print(f"cfg.use_apex: {cfg.use_apex}")
    print(f"cfg.use_bn: {cfg.use_bn}")
    print("len of loader.val", len(loader.val))
    print(_build_v2_summary(cfg))
    t = Timer()
    for epoch in range(cfg.current_epoch, cfg.epochs):
        t.tic()
        loss_epoch = train_loop(cfg, loader.tr, models, criterion, optimizer,
                                epoch, writer, loop_scheduler)
        t.toc()
        lr = optimizer.param_groups[0]["lr"]
        # print(t)

        # if ms_scheduler:
        #     ms_scheduler.step()

        if scheduler and loop_scheduler is None:
            val_loss = test_loop(cfg, models, loader.val)
            scheduler.step(val_loss)
            if this_proc_prints:
                writer.add_scalar("Loss/val", val_loss, epoch)

        if epoch % cfg.checkpoint_interval == 0 and this_proc_prints:
            save_denoising_model(cfg, models, optimizer)

        if epoch % cfg.test_interval == 0:
            if this_proc_prints:
                te_loss = test_loop(cfg, models, loader.te)
                writer.add_scalar("Loss/test", te_loss, epoch)

        if this_proc_prints:
            writer.add_scalar("Loss/train", loss_epoch / len(loader.tr), epoch)
            writer.add_scalar("Misc/learning_rate", lr, epoch)
            msg = f"Epoch [{epoch}/{cfg.epochs}]\t"
            msg += f"Loss: {loss_epoch / len(loader.tr)}\t"
            msg += "{:2.3e}".format(lr)
            print(msg)
        cfg.current_epoch += 1

    if this_proc_prints:
        te_loss = test_loop(cfg, models, loader.te)
        writer.add_scalar("Loss/test", te_loss, epoch)
    save_denoising_model(cfg, models, optimizer)
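
Both this training driver and the classification snippet below rely on a small Timer exposing tic()/toc() plus total_time and average_time fields. The class itself is not shown here; a minimal sketch consistent with that usage (an assumed implementation, not the original) is:

import time

class Timer:
    """Hypothetical tic/toc timer matching the calls used in these examples."""

    def __init__(self):
        self.total_time = 0.0
        self.average_time = 0.0
        self.calls = 0
        self._start = None

    def tic(self):
        self._start = time.perf_counter()

    def toc(self):
        elapsed = time.perf_counter() - self._start
        self.total_time += elapsed
        self.calls += 1
        self.average_time = self.total_time / self.calls
        return elapsed

    def __str__(self):
        return "total=%.3fs ave=%.3fs calls=%d" % (
            self.total_time, self.average_time, self.calls)
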
# set the size of the input (we can skip this if we're happy with the default;
#  we can also change it later, e.g., for different batch sizes)
net.blobs['data'].reshape(
    50,  # batch size
    3,  # 3-channel (BGR) images
    224,
    224)  # image size is 224x224

image_path = 'data/image/67/1698/2005/'
image_name = '0a884f5a90267d.jpg'
image = caffe.io.load_image(IMAGE_ROOT + image_path + image_name)
transformed_image = transformer.preprocess('data', image)
#plt.imshow(image)

# copy the image data into the memory allocated for the net
net.blobs['data'].data[...] = transformed_image
timer = Timer()
timer.tic()
### perform classification
output = net.forward()
timer.toc()
print('Elapsed time {:.3f} s.'.format(timer.total_time))
output_prob = output['prob'][
    0]  # the output probability vector for the first image in the batch

print('predicted class is:', output_prob.argmax())

top_inds = output_prob.argsort(
)[::-1][:5]  # reverse sort and take five largest items

print('5 top probabilities and labels:')
print(list(zip(output_prob[top_inds], top_inds)))
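
The classification snippet above assumes a net and a transformer created earlier (along with IMAGE_ROOT). In the standard Caffe workflow the preprocessing is typically configured as below; the prototxt/caffemodel paths and the mean values are placeholders, not taken from this example:

import numpy as np
import caffe

# hypothetical model files; substitute the real deploy prototxt and weights
net = caffe.Net('deploy.prototxt', 'weights.caffemodel', caffe.TEST)

# preprocessor matched to the shape of the 'data' blob
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
transformer.set_transpose('data', (2, 0, 1))      # HWC -> CHW
transformer.set_mean('data', np.array([104.0, 117.0, 123.0]))  # per-channel BGR mean (placeholder)
transformer.set_raw_scale('data', 255)            # caffe.io.load_image returns [0, 1]; rescale to [0, 255]
transformer.set_channel_swap('data', (2, 1, 0))   # RGB -> BGR
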
Exemple #9
0
def train_loop(cfg, model, train_loader, epoch, record_losses):

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Setup for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    model.align_info.model.train()
    model.denoiser_info.model.train()
    model.denoiser_info.model = model.denoiser_info.model.to(cfg.device)
    model.align_info.model = model.align_info.model.to(cfg.device)

    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Record Keeping
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    align_mse_losses, align_mse_count = 0, 0
    align_ot_losses, align_ot_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0

    write_examples = True
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #    Add hooks for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-

    align_hook = AlignmentFilterHooks(cfg.N)
    align_hooks = []
    for kpn_module in model.align_info.model.children():
        for name, layer in kpn_module.named_children():
            if name == "filter_cls":
                align_hook_handle = layer.register_forward_hook(align_hook)
                align_hooks.append(align_hook_handle)

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Loss Functions
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss()
    # denoiseLossOT = BurstResidualLoss()
    entropyLoss = EntropyLoss()

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Final Configs
    #
    # -=-=-=-=-=-=-=-=-=-=-

    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    if use_timer: clock = Timer()
    train_iter = iter(train_loader)
    steps_per_epoch = len(train_loader)
    write_examples_iter = steps_per_epoch // 3

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Start Epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    for batch_idx in range(steps_per_epoch):

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Setting up for Iteration
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- setup iteration timer --
        if use_timer: clock.tic()

        # -- zero gradients; ready 2 go --
        model.align_info.model.zero_grad()
        model.align_info.optim.zero_grad()
        model.denoiser_info.model.zero_grad()
        model.denoiser_info.optim.zero_grad()

        # -- grab data batch --
        burst, res_imgs, raw_img, directions = next(train_iter)

        # -- getting shapes of data --
        N, B, C, H, W = burst.shape
        burst = burst.cuda(non_blocking=True)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Formatting Images for FP
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- creating some transforms --
        stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
        cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')

        # -- extract target image --
        mid_img = burst[N // 2]
        raw_zm_img = szm(raw_img.cuda(non_blocking=True))
        if cfg.supervised: gt_img = raw_zm_img
        else: gt_img = mid_img

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #           Forward Pass
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        outputs = model(burst)
        aligned, aligned_ave, denoised, denoised_ave = outputs[:4]
        aligned_filters, denoised_filters = outputs[4:]

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #   Require Approx Equal Filter Norms
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        denoised_filters = rearrange(denoised_filters.detach(),
                                     'b n k2 c h w -> n (b k2 c h w)')
        norms = denoised_filters.norm(dim=1)
        norm_loss_denoiser = torch.mean((norms - norms[N // 2])**2)
        norm_loss_coeff = 100.

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Decrease Entropy within a Kernel
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        filters_entropy = 0
        filters_entropy_coeff = 1000.
        all_filters = []
        L = len(align_hook.filters)
        iter_filters = align_hook.filters if L > 0 else [aligned_filters]
        for filters in iter_filters:
            filters_shaped = rearrange(filters,
                                       'b n k2 c h w -> (b n c h w) k2',
                                       n=N)
            filters_entropy += entropyLoss(filters_shaped)
            all_filters.append(filters)
        if L > 0: filters_entropy /= L
        all_filters = torch.stack(all_filters, dim=1)
        align_hook.clear()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Increase Entropy across each Kernel
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        filters_dist_entropy = 0

        # -- across each frame --
        # filters_shaped = rearrange(all_filters,'b l n k2 c h w -> (b l) (n c h w) k2')
        # filters_shaped = torch.mean(filters_shaped,dim=1)
        # filters_dist_entropy += -1 * entropyLoss(filters_shaped)

        # -- across each batch --
        filters_shaped = rearrange(all_filters,
                                   'b l n k2 c h w -> (n l) (b c h w) k2')
        filters_shaped = torch.mean(filters_shaped, dim=1)
        filters_dist_entropy += -1 * entropyLoss(filters_shaped)

        # -- across each kpn cascade --
        # filters_shaped = rearrange(all_filters,'b l n k2 c h w -> (b n) (l c h w) k2')
        # filters_shaped = torch.mean(filters_shaped,dim=1)
        # filters_dist_entropy += -1 * entropyLoss(filters_shaped)

        filters_dist_coeff = 0

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Alignment Losses (MSE)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        losses = alignmentLossMSE(aligned, aligned_ave, gt_img,
                                  cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        align_mse = sum(losses)  # keep as a torch tensor so gradients flow
        align_mse_coeff = 0.95**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #   Alignment Losses (Distribution)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        fs = cfg.dynamic.frame_size
        residuals = aligned - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        centered_residuals = tvF.center_crop(residuals, (fs // 2, fs // 2))
        align_ot = kl_gaussian_bp(centered_residuals, noise_level, flip=True)
        align_ot_coeff = 100.

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #   Reconstruction Losses (MSE)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        losses = denoiseLossMSE(denoised, denoised_ave, gt_img,
                                cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        rec_mse = sum(losses)  # keep as a torch tensor so gradients flow
        rec_mse_coeff = 0.95**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Reconstruction Losses (Distribution)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- regularization scheduler --
        if cfg.global_step < 100: reg = 0.5
        elif cfg.global_step < 200: reg = 0.25
        elif cfg.global_step < 5000: reg = 0.15
        elif cfg.global_step < 10000: reg = 0.1
        else: reg = 0.05

        # -- computation --
        residuals = denoised - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        # residuals = rearrange(residuals,'b n c h w -> b n (h w) c')
        # rec_ot_pair_loss_v1 = w_gaussian_bp(residuals,noise_level)
        rec_ot_loss_v1 = kl_gaussian_bp(residuals, noise_level, flip=True)
        # rec_ot_loss_v1 = kl_gaussian_pair_bp(residuals)
        # rec_ot_loss_v1 = kl_gaussian_bp_patches(residuals,noise_level,flip=True,patchsize=16)
        # rec_ot_loss_v1 = ot_pairwise2gaussian_bp(residuals,K=6,reg=reg)
        # rec_ot_loss_v2 = ot_pairwise_bp(residuals,K=3)
        rec_ot_pair_loss_v2 = torch.FloatTensor([0.]).to(cfg.device)
        rec_ot = (rec_ot_loss_v1 + rec_ot_pair_loss_v2)
        rec_ot_coeff = 100.  # - .997**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #              Final Losses
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        rec_loss = rec_ot_coeff * rec_ot + rec_mse_coeff * rec_mse
        norm_loss = norm_loss_coeff * norm_loss_denoiser
        align_loss = align_mse_coeff * align_mse + align_ot_coeff * align_ot
        entropy_loss = filters_entropy_coeff * filters_entropy + filters_dist_coeff * filters_dist_entropy
        final_loss = align_loss + rec_loss + entropy_loss + norm_loss

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #              Record Keeping
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- alignment MSE --
        align_mse_losses += align_mse.item()
        align_mse_count += 1

        # -- alignment Dist --
        align_ot_losses += align_ot.item()
        align_ot_count += 1

        # -- reconstruction MSE --
        rec_mse_losses += rec_mse.item()
        rec_mse_count += 1

        # -- reconstruction Dist. --
        rec_ot_losses += rec_ot.item()
        rec_ot_count += 1

        # -- total loss --
        running_loss += final_loss.item()
        total_loss += final_loss.item()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #        Gradients & Backpropagation
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- compute the gradients! --
        final_loss.backward()

        # -- update parameters with the optimizer steps --
        model.align_info.optim.step()
        model.denoiser_info.optim.step()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #            Printing to Stdout
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- compute mse for fun --
            B = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(stacked_burst, dim=1)
            mse_loss = F.mse_loss(raw_img, mis_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [bm3d] --
            bm3d_nb_psnrs = []
            M = 10 if B > 10 else B  # cap the number of BM3D evaluations
            for b in range(M):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                b_loss = F.mse_loss(raw_img[b].cpu(),
                                    bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for aligned + denoised --
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            mse_loss = F.mse_loss(raw_img_repN,
                                  denoised + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img,
                                  denoised_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- alignment MSE --
            align_mse_ave = align_mse_losses / align_mse_count
            align_mse_losses, align_mse_count = 0, 0

            # -- alignment Dist. --
            align_ot_ave = align_ot_losses / align_ot_count
            align_ot_losses, align_ot_count = 0, 0

            # -- reconstruction MSE --
            rec_mse_ave = rec_mse_losses / rec_mse_count
            rec_mse_losses, rec_mse_count = 0, 0

            # -- reconstruction Dist. --
            rec_ot_ave = rec_ot_losses / rec_ot_count
            rec_ot_losses, rec_ot_count = 0, 0

            # -- write record --
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                record_losses = record_losses.append(info, ignore_index=True)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, len(train_loader),
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, rec_ot_ave)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e"
                % write_info)
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               all_filters, directions)

        if use_timer: clock.toc()
        if use_timer: print(clock)
        cfg.global_step += 1

    # -- remove hooks --
    for hook in align_hooks:
        hook.remove()

    total_loss /= len(train_loader)
    return total_loss, record_losses
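
The distribution-matching terms in these loops (kl_gaussian_bp and friends) push the denoising residuals toward the zero-mean Gaussian statistics of the injected noise. Their implementations are not part of this example; a minimal differentiable sketch of the idea, assuming each image's residual is summarized by its empirical mean and variance and compared to N(0, (noise_level/255)^2) via the closed-form Gaussian KL divergence, could be:

import math
import torch

def kl_to_zero_mean_gaussian(residuals, noise_level, eps=1e-8):
    """Hypothetical stand-in for a kl_gaussian_bp-style residual loss."""
    target_var = (noise_level / 255.0) ** 2
    flat = residuals.reshape(residuals.shape[0], -1)
    mu = flat.mean(dim=1)
    var = flat.var(dim=1) + eps
    # KL( N(mu, var) || N(0, target_var) ), averaged over the batch
    kl = 0.5 * (math.log(target_var) - torch.log(var)
                + (var + mu ** 2) / target_var - 1.0)
    return kl.mean()
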
Exemple #10
0
def train_loop(cfg, model, scheduler, train_loader, epoch, record_losses,
               writer):

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Setup for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    model.align_info.model.train()
    model.denoiser_info.model.train()
    model.unet_info.model.train()
    model.denoiser_info.model = model.denoiser_info.model.to(cfg.device)
    model.align_info.model = model.align_info.model.to(cfg.device)
    model.unet_info.model = model.unet_info.model.to(cfg.device)

    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })
    noise_type = cfg.noise_params.ntype

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Record Keeping
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    align_mse_losses, align_mse_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0
    dynamics_acc, dynamics_count = 0, 0

    write_examples = False
    write_examples_iter = 200
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #   Load Pre-Simulated Random Numbers
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    if cfg.use_kindex_lmdb: kindex_ds = kIndexPermLMDB(cfg.batch_size, cfg.N)

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Dataset Augmentation
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    transforms = [tvF.vflip, tvF.hflip, tvF.rotate]
    aug = RandomChoice(transforms)

    def apply_transformations(burst, gt_img):
        N, B = burst.shape[:2]
        gt_img_rs = rearrange(gt_img, 'b c h w -> 1 b c h w')
        all_images = torch.cat([gt_img_rs, burst], dim=0)
        all_images = rearrange(all_images, 'n b c h w -> (n b) c h w')
        tv_utils.save_image(all_images,
                            'aug_original.png',
                            nrow=N + 1,
                            normalize=True)
        aug_images = aug(all_images)
        tv_utils.save_image(aug_images,
                            'aug_augmented.png',
                            nrow=N + 1,
                            normalize=True)
        aug_images = rearrange(aug_images, '(n b) c h w -> n b c h w', b=B)
        aug_gt_img = aug_images[0]
        aug_burst = aug_images[1:]
        return aug_burst, aug_gt_img

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Half Precision
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    # model.align_info.model.half()
    # model.denoiser_info.model.half()
    # model.unet_info.model.half()
    # models = [model.align_info.model,
    #           model.denoiser_info.model,
    #           model.unet_info.model]
    # for model_l in models:
    #     model_l.half()
    #     for layer in model_l.modules():
    #         if isinstance(layer, torch.nn.BatchNorm2d):
    #             layer.float()

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Loss Functions
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss(alpha=cfg.kpn_burst_alpha,
                                  gradient_L1=not cfg.supervised)
    # denoiseLossOT = BurstResidualLoss()
    entropyLoss = EntropyLoss()

    # -=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #    Add hooks for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-

    align_hook = AlignmentFilterHooks(cfg.N)
    align_hooks = []
    for kpn_module in model.align_info.model.children():
        for name, layer in kpn_module.named_children():
            if name == "filter_cls":
                align_hook_handle = layer.register_forward_hook(align_hook)
                align_hooks.append(align_hook_handle)

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Noise2Noise
    #
    # -=-=-=-=-=-=-=-=-=-=-

    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Final Configs
    #
    # -=-=-=-=-=-=-=-=-=-=-

    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    if use_timer:
        data_clock = Timer()
        clock = Timer()
    ds_size = len(train_loader)
    small_ds = ds_size < 500
    steps_per_epoch = ds_size if not small_ds else 500

    write_examples_iter = steps_per_epoch // 3
    all_filters = []

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Start Epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-
    dynamics_acc_i = -1.
    if cfg.use_seed:
        init = torch.initial_seed()
        torch.manual_seed(cfg.seed + 1 + epoch + init)
    train_iter = iter(train_loader)
    for batch_idx in range(steps_per_epoch):

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Setting up for Iteration
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- setup iteration timer --
        if use_timer:
            data_clock.tic()
            clock.tic()

        # -- grab data batch --
        if small_ds and batch_idx >= ds_size:
            if cfg.use_seed:
                init = torch.initial_seed()
                torch.manual_seed(cfg.seed + 1 + epoch + init)
            train_iter = iter(train_loader)  # reset if too big
        sample = next(train_iter)
        burst, raw_img, motion = sample['burst'], sample['clean'], sample[
            'flow']
        raw_img_iid = sample['iid']
        raw_img_iid = raw_img_iid.cuda(non_blocking=True)
        burst = burst.cuda(non_blocking=True)

        # -- handle possibly cached simulated bursts --
        if 'sim_burst' in sample:
            sim_burst = rearrange(sample['sim_burst'],
                                  'b n k c h w -> n b k c h w')
        else:
            sim_burst = None
        non_sim_method = cfg.n2n or cfg.supervised
        if sim_burst is None and not (non_sim_method or cfg.abps):
            if sim_burst is None:
                if cfg.use_kindex_lmdb:
                    kindex = kindex_ds[batch_idx].cuda(non_blocking=True)
                else:
                    kindex = None
                query = burst[[N // 2]]
                database = torch.cat([burst[:N // 2], burst[N // 2 + 1:]])
                sim_burst = compute_similar_bursts(
                    cfg,
                    query,
                    database,
                    cfg.sim_K,
                    noise_level / 255.,
                    patchsize=cfg.sim_patchsize,
                    shuffle_k=cfg.sim_shuffleK,
                    kindex=kindex,
                    only_middle=cfg.sim_only_middle,
                    search_method=cfg.sim_method,
                    db_level="frame")

        if (sim_burst is None) and cfg.abps:
            # scores,aligned = abp_search(cfg,burst)
            # scores,aligned = lpas_search(cfg,burst,motion)
            if cfg.lpas_method == "spoof":
                mtype = "global"
                acc = cfg.optical_flow_acc
                scores, aligned = lpas_spoof(burst, motion, cfg.nblocks, mtype,
                                             acc)
            else:
                ref_frame = (cfg.nframes + 1) // 2
                nblocks = cfg.nblocks
                method = cfg.lpas_method
                scores, aligned, dacc = lpas_search(burst, ref_frame, nblocks,
                                                    motion, method)
                dynamics_acc_i = dacc
            # scores,aligned = lpas_spoof(motion,accuracy=cfg.optical_flow_acc)
            # shuffled = shuffle_aligned_pixels_noncenter(aligned,cfg.nframes)
            nsims = cfg.nframes
            sim_aligned = create_sim_from_aligned(burst, aligned, nsims)
            burst_s = rearrange(burst, 't b c h w -> t b 1 c h w')
            sim_burst = torch.cat([burst_s, sim_aligned], dim=2)
            # print("sim_burst.shape",sim_burst.shape)

        # raw_img = raw_img.cuda(non_blocking=True)-0.5
        # # print(np.sqrt(cfg.noise_params['g']['stddev']))
        # print(motion)
        # tiled = tile_across_blocks(burst[[cfg.nframes//2]],cfg.nblocks)
        # rep_burst = repeat(burst,'t b c h w -> t b g c h w',g=tiled.shape[2])
        # for t in range(cfg.nframes):
        #     save_image(tiled[0] - rep_burst[t],f"tiled_sub_burst_{t}.png")
        # save_image(aligned,"aligned.png")
        # print(aligned.shape)
        # # save_image(aligned[0] - aligned[cfg.nframes//2],"aligned_0.png")
        # # save_image(aligned[2] - aligned[cfg.nframes//2],"aligned_2.png")
        # M = (1+cfg.dynamic.ppf)*cfg.nframes
        # fs = cfg.dynamic.frame_size - M
        # fs = cfg.frame_size
        # cropped = crop_center_patch([burst,aligned,raw_img],cfg.nframes,cfg.frame_size)
        # burst,aligned,raw_img = cropped[0],cropped[1],cropped[2]
        # print(aligned.shape)
        # for t in range(cfg.nframes+1):
        #     diff_t = aligned[t] - raw_img
        #     spacing = cfg.nframes+1
        #     diff_t = crop_center_patch([diff_t],spacing,cfg.frame_size)[0]
        #     print_tensor_stats(f"diff_aligned_{t}",diff_t)
        #     save_image(diff_t,f"diff_aligned_{t}.png")
        #     if t < cfg.nframes:
        #         dt = aligned[t+1]-aligned[t]
        #         dt = crop_center_patch([dt],spacing,cfg.frame_size)[0]
        #         save_image(dt,f"dt_aligned_{t+1}m{t}.png")
        #     save_image(aligned[t],f"aligned_{t}.png")
        #     diff_t = tvF.crop(aligned[t] - raw_img,cfg.nframes,cfg.nframes,fs,fs)
        #     print_tensor_stats(f"diff_aligned_{t}",diff_t)

        # save_image(burst,"burst.png")
        # save_image(burst[0] - burst[cfg.nframes//2],"burst_0.png")
        # save_image(burst[2] - burst[cfg.nframes//2],"burst_2.png")
        # exit()

        # print(sample['burst'].shape,sample['res'].shape)
        # b_clean = sample['burst'] - sample['res']
        # scores,ave,t_aligned = test_abp_global_search(cfg,b_clean,noisy_img=burst)

        # burstBN = rearrange(burst,'n b c h w -> (b n) c h w')
        # tv_utils.save_image(burstBN,"abps_burst.png",normalize=True)
        # alignedBN = rearrange(aligned,'n b c h w -> (b n) c h w')
        # tv_utils.save_image(alignedBN,"abps_aligned.png",normalize=True)
        # rep_burst = burst[[N//2]].repeat(N,1,1,1,1)
        # deltaBN = rearrange(aligned - rep_burst,'n b c h w -> (b n) c h w')
        # tv_utils.save_image(deltaBN,"abps_delta.png",normalize=True)
        # b_clean_rep = b_clean[[N//2]].repeat(N,1,1,1,1)
        # tdeltaBN = rearrange(t_aligned - b_clean_rep.cpu(),'n b c h w -> (b n) c h w')
        # tv_utils.save_image(tdeltaBN,"abps_tdelta.png",normalize=True)

        if non_sim_method:
            sim_burst = burst.unsqueeze(2).repeat(1, 1, 2, 1, 1, 1)
        else:
            sim_burst = sim_burst.cuda(non_blocking=True)
        if use_timer: data_clock.toc()

        # -- to cuda --
        burst = burst.cuda(non_blocking=True)
        raw_zm_img = szm(raw_img.cuda(non_blocking=True))
        # anscombe.test(cfg,burst_og)
        # save_image(burst,f"burst_{batch_idx}_{cfg.n2n}.png")

        # -- crop images --
        if True:  #cfg.abps or cfg.abps_inputs:
            images = [burst, sim_burst, raw_img, raw_img_iid]
            spacing = burst.shape[0]  # we use frames as spacing
            cropped = crop_center_patch(images, spacing, cfg.frame_size)
            burst, sim_burst = cropped[0], cropped[1]
            raw_img, raw_img_iid = cropped[2], cropped[3]
            if cfg.abps or cfg.abps_inputs:
                aligned = crop_center_patch([aligned], spacing,
                                            cfg.frame_size)[0]
            # print_tensor_stats("d-eq?",burst[-1] - aligned[-1])
            burst = burst[:cfg.nframes]  # last frame is target

        # -- getting shapes of data --
        N, B, C, H, W = burst.shape
        burst_og = burst.clone()

        # -- shuffle over Simulated Samples --
        k_ins, k_outs = create_k_grid(sim_burst, shuffle=True)
        k_ins, k_outs = [k_ins[0]], [k_outs[0]]
        # k_ins,k_outs = create_k_grid_v3(sim_burst)

        for k_in, k_out in zip(k_ins, k_outs):
            if k_in == k_out: continue

            # -- zero gradients; ready 2 go --
            model.align_info.model.zero_grad()
            model.align_info.optim.zero_grad()
            model.denoiser_info.model.zero_grad()
            model.denoiser_info.optim.zero_grad()
            model.unet_info.model.zero_grad()
            model.unet_info.optim.zero_grad()

            # -- compute input/output data --
            if cfg.sim_only_middle and (not cfg.abps):
                # sim_burst.shape == T,B,K,C,H,W
                midi = 0 if sim_burst.shape[0] == 1 else N // 2
                left_burst, right_burst = burst[:N // 2], burst[N // 2 + 1:]
                cat_burst = [
                    left_burst, sim_burst[[midi], :, k_in], right_burst
                ]
                burst = torch.cat(cat_burst, dim=0)
                mid_img = sim_burst[midi, :, k_out]
            elif cfg.abps and (not cfg.abps_inputs):
                # -- v1 --
                mid_img = aligned[-1]

                # -- v2 --
                # left_aligned,right_aligned = aligned[:N//2],aligned[N//2+1:]
                # nc_aligned = torch.cat([left_aligned,right_aligned],dim=0)
                # shuf = shuffle_aligned_pixels(nc_aligned,cfg.nframes)
                # mid_img = shuf[1]

                # ---- v3 ----
                # shuf = shuffle_aligned_pixels(aligned)
                # shuf = aligned[[N//2,0]]
                # midi = 0 if sim_burst.shape[0] == 1 else N//2
                # left_burst,right_burst = burst[:N//2],burst[N//2+1:]
                # burst = torch.cat([left_burst,shuf[[0]],right_burst],dim=0)
                # nc_burst = torch.cat([left_burst,right_burst],dim=0)
                # shuf = shuffle_aligned_pixels(aligned)

                # ---- v4 ----
                # nc_shuf = shuffle_aligned_pixels(nc_aligned)
                # mid_img = nc_shuf[0]
                # pick = npr.randint(0,2,size=(1,))[0]
                # mid_img = nc_aligned[pick]
                # mid_img = shuf[1]

                # save_image(shuf,"shuf.png")
                # print(shuf.shape)

                # diff = raw_img.cuda(non_blocking=True) - aligned[0]
                # mean = torch.mean(diff).item()
                # std = torch.std(diff).item()
                # print(mean,std)

                # -- v1 --
                # burst = burst
                # notMid = sample_not_mid(N)
                # mid_img = aligned[notMid]

            elif cfg.abps_inputs:
                burst = aligned.clone()
                burst_og = aligned.clone()
                mid_img = shuffle_aligned_pixels(burst, cfg.nframes)[0]

            else:
                burst = sim_burst[:, :, k_in]
                mid_img = sim_burst[N // 2, :, k_out]
            # mid_img =  sim_burst[N//2,:]
            # print(burst.shape,mid_img.shape)
            # print(F.mse_loss(burst,mid_img).item())
            if cfg.supervised:
                gt_img = get_nmlz_tgt_img(cfg, raw_img).cuda(non_blocking=True)
            elif cfg.n2n:
                gt_img = raw_img_iid  #noise_xform(raw_img).cuda(non_blocking=True)
            else:
                gt_img = mid_img

            # another = noise_xform(raw_img).cuda(non_blocking=True)
            # print_tensor_stats("a-iid?",raw_img_iid.cuda() - raw_img.cuda())
            # print_tensor_stats("b-iid?",mid_img.cuda() - raw_img.cuda())
            # print_tensor_stats("c-iid?",mid_img.cuda() - another)
            # print_tensor_stats("d-iid?",raw_img_iid.cuda() - another)
            # print_tensor_stats("e-iid?",mid_img.cuda() - raw_img_iid.cuda())

            # for bt in range(cfg.nframes):
            #     tiled = tile_across_blocks(burst[[bt]],cfg.nblocks)
            #     rep_burst = repeat(burst,'t b c h w -> t b g c h w',g=tiled.shape[2])
            #     for t in range(cfg.nframes):
            #         save_image(tiled[0] - rep_burst[t],f"tiled_{bt}_sub_burst_{t}.png")
            #         print_tensor_stats(f"delta_{bt}_{t}",tiled[0,:,4] - burst[t])

            # raw_img = raw_img.cuda(non_blocking=True) - 0.5
            # print_tensor_stats("gt_img - raw",gt_img - raw_img)
            # # save_image(gt_img,"gt.png")
            # # save_image(raw,"raw.png")
            # save_image(gt_img - raw_img,"gt_sub_raw.png")
            # print_tensor_stats("burst[N//2] - raw",burst[N//2] - raw_img)
            # save_image(burst[N//2] - raw_img,"burst_sub_raw.png")
            # print_tensor_stats("burst[N//2] - gt_img",burst[N//2] - gt_img)
            # save_image(burst[N//2] - gt_img,"burst_sub_gt.png")
            # print_tensor_stats("aligned[N//2] - raw",aligned[N//2] - raw_img)
            # save_image(aligned[N//2] - raw_img,"aligned_sub_raw.png")
            # print_tensor_stats("aligned[N//2] - burst[N//2]",
            # aligned[N//2] - burst[N//2])
            # save_image(aligned[N//2] - burst[N//2],"aligned_sub_burst.png")
            # gt_img = torch.normal(raw_zm_img,noise_level/255.)

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #        Dataset Augmentation
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # burst,gt_img = apply_transformations(burst,gt_img)

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #      Formatting Images for FP
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
            cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #           Forward Pass
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            outputs = model(burst)
            m_aligned, m_aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #    Decrease Entropy within a Kernel
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            filters_entropy = 0
            filters_entropy_coeff = 0.  # 1000.
            all_filters = []
            L = len(align_hook.filters)
            iter_filters = align_hook.filters if L > 0 else [aligned_filters]
            for filters in iter_filters:
                f_shape = 'b n k2 c h w -> (b n c h w) k2'
                filters_shaped = rearrange(filters, f_shape)
                filters_entropy += one  #entropyLoss(filters_shaped)
                all_filters.append(filters)
            if L > 0: filters_entropy /= L
            all_filters = torch.stack(all_filters, dim=1)
            align_hook.clear()

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #   Reconstruction Losses (MSE)
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            losses = [F.mse_loss(denoised_ave, gt_img)]
            # losses = denoiseLossMSE(denoised,denoised_ave,gt_img,cfg.global_step)
            # losses = [ one, one ]
            # ave_loss,burst_loss = [loss.item() for loss in losses]
            rec_mse = sum(losses)  # keep as a torch tensor so gradients flow
            # rec_mse = F.mse_loss(denoised_ave,gt_img)
            rec_mse_coeff = 1.

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #    Reconstruction Losses (Distribution)
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            gt_img_rep = gt_img.unsqueeze(1).repeat(1, denoised.shape[1], 1, 1,
                                                    1)
            residuals = denoised - gt_img_rep
            rec_ot = torch.FloatTensor([0.]).to(cfg.device)
            # rec_ot = kl_gaussian_bp(residuals,noise_level,flip=True)
            # rec_ot = kl_gaussian_bp_patches(residuals,noise_level,flip=True,patchsize=16)
            if torch.any(torch.isnan(rec_ot)):
                rec_ot = torch.FloatTensor([0.]).to(cfg.device)
            if torch.any(torch.isinf(rec_ot)):
                rec_ot = torch.FloatTensor([0.]).to(cfg.device)
            rec_ot_coeff = 0.

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #              Final Losses
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            rec_loss = rec_mse_coeff * rec_mse + rec_ot_coeff * rec_ot
            final_loss = rec_loss

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #              Record Keeping
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # -- reconstruction MSE --
            rec_mse_losses += rec_mse.item()
            rec_mse_count += 1

            # -- reconstruction Dist. --
            rec_ot_losses += rec_ot.item()
            rec_ot_count += 1

            # -- dynamic acc -
            dynamics_acc += dynamics_acc_i
            dynamics_count += 1

            # -- total loss --
            running_loss += final_loss.item()
            total_loss += final_loss.item()

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #        Gradients & Backpropagation
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # -- compute the gradients! --
            if cfg.use_seed: torch.set_deterministic(False)
            final_loss.backward()
            if cfg.use_seed: torch.set_deterministic(True)

            # -- update parameters with the optimizer steps --
            model.align_info.optim.step()
            model.denoiser_info.optim.step()
            model.unet_info.optim.step()
            scheduler.step()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #            Printing to Stdout
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- recompute model output for original images --
            outputs = model(burst_og)
            m_aligned, m_aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]

            # -- compute mse for fun --
            B = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)
            raw_img = get_nmlz_tgt_img(cfg, raw_img)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, m_aligned_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(burst_og, dim=0)
            if noise_type == "qis": mis_ave = quantize_img(cfg, mis_ave)
            mse_loss = F.mse_loss(raw_img, mis_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # tv_utils.save_image(raw_img,"raw.png",nrow=1,normalize=True,range=(-0.5,1.25))
            # tv_utils.save_image(mis_ave,"mis.png",nrow=1,normalize=True,range=(-0.5,1.25))

            # -- psnr for [bm3d] --
            mid_img_og = burst[N // 2]
            bm3d_nb_psnrs = []
            M = 4 if B > 4 else B
            for b in range(M):
                bm3d_rec = bm3d.bm3d(mid_img_og[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                # maybe an issue here
                b_loss = F.mse_loss(raw_img[b].cpu(),
                                    bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for input averaged frames --
            # burst_ave = torch.mean(burst_og,dim=0)
            # mse_loss = F.mse_loss(raw_img,burst_ave,reduction='none').reshape(B,-1)
            # mse_loss = torch.mean(mse_loss,1).detach().cpu().numpy()
            # psnr_input_ave = np.mean(mse_to_psnr(mse_loss))
            # psnr_input_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for aligned + denoised --
            R = denoised.shape[1]
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, R, 1, 1, 1)
            # if noise_type == "qis": denoised = quantize_img(cfg,denoised)
            # save_image(denoised_ave,"denoised_ave.png")
            # save_image(denoised,"denoised.png")
            mse_loss = F.mse_loss(raw_img_repN, denoised,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img, denoised_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- reconstruction MSE --
            rec_mse_ave = rec_mse_losses / rec_mse_count
            rec_mse_losses, rec_mse_count = 0, 0

            # -- reconstruction Dist. --
            rec_ot_ave = rec_ot_losses / rec_ot_count
            rec_ot_losses, rec_ot_count = 0, 0

            # -- ave dynamic acc --
            ave_dyn_acc = dynamics_acc / dynamics_count * 100.
            dynamics_acc, dynamics_count = 0, 0

            # -- write record --
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                record_losses = record_losses.append(info, ignore_index=True)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch,
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, ave_dyn_acc)  #rec_ot_ave)

            #print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e" % write_info)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [dyn]: %.2e"
                % write_info,
                flush=True)
            # -- write to summary writer --
            if writer:
                writer.add_scalar('train/running-loss', running_loss,
                                  cfg.global_step)
                writer.add_scalars('train/model-psnr', {
                    'ave': psnr,
                    'std': psnr_std
                }, cfg.global_step)
                writer.add_scalars('train/dn-frame-psnr', {
                    'ave': psnr_denoised_ave,
                    'std': psnr_denoised_std
                }, cfg.global_step)

            # -- reset loss --
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               all_filters, motion)

        if use_timer: clock.toc()

        if use_timer:
            print("data_clock", data_clock.average_time)
            print("clock", clock.average_time)
        cfg.global_step += 1

    # -- remove hooks --
    for hook in align_hooks:
        hook.remove()

    total_loss /= len(train_loader)
    return total_loss, record_losses
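
The augmentation helper in this example (and the next) wraps a RandomChoice object that picks one transform per call and applies it to the whole image stack, so the burst and its ground truth stay consistent. A minimal sketch of such a wrapper, assuming the chosen transforms take a single tensor argument (e.g., tvF.vflip / tvF.hflip):

import random

class RandomChoice:
    """Hypothetical wrapper: apply one randomly selected transform per call."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, images):
        transform = random.choice(self.transforms)
        return transform(images)

# usage sketch:
#   aug = RandomChoice([tvF.vflip, tvF.hflip])
#   aug_images = aug(all_images)   # same flip applied to every frame
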
Exemple #11
0
def train_loop(cfg, model, optimizer, scheduler, train_loader, epoch,
               record_losses, writer):

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Setup for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    model.train()
    model = model.to(cfg.gpuid)

    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })
    noise_type = cfg.noise_params.ntype

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Record Keeping
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    align_mse_losses, align_mse_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0

    write_examples = False
    write_examples_iter = 200
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #   Load Pre-Simulated Random Numbers
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    if cfg.use_kindex_lmdb: kindex_ds = kIndexPermLMDB(cfg.batch_size, cfg.N)

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Dataset Augmentation
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    transforms = [tvF.vflip, tvF.hflip, tvF.rotate]
    aug = RandomChoice(transforms)

    def apply_transformations(burst, gt_img):
        N, B = burst.shape[:2]
        gt_img_rs = rearrange(gt_img, 'b c h w -> 1 b c h w')
        all_images = torch.cat([gt_img_rs, burst], dim=0)
        all_images = rearrange(all_images, 'n b c h w -> (n b) c h w')
        tv_utils.save_image(all_images,
                            'aug_original.png',
                            nrow=N + 1,
                            normalize=True)
        aug_images = aug(all_images)
        tv_utils.save_image(aug_images,
                            'aug_augmented.png',
                            nrow=N + 1,
                            normalize=True)
        aug_images = rearrange(aug_images, '(n b) c h w -> n b c h w', b=B)
        aug_gt_img = aug_images[0]
        aug_burst = aug_images[1:]
        return aug_burst, aug_gt_img

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Half Precision
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    # model.align_info.model.half()
    # model.denoiser_info.model.half()
    # model.unet_info.model.half()
    # models = [model.align_info.model,
    #           model.denoiser_info.model,
    #           model.unet_info.model]
    # for model_l in models:
    #     model_l.half()
    #     for layer in model_l.modules():
    #         if isinstance(layer, torch.nn.BatchNorm2d):
    #             layer.float()

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Loss Functions
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss(alpha=cfg.kpn_burst_alpha,
                                  gradient_L1=not cfg.supervised)
    # denoiseLossOT = BurstResidualLoss()
    entropyLoss = EntropyLoss()

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Noise2Noise
    #
    # -=-=-=-=-=-=-=-=-=-=-

    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Final Configs
    #
    # -=-=-=-=-=-=-=-=-=-=-

    random_crop = tvT.RandomCrop(cfg.byol_patchsize)
    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    if use_timer:
        data_clock = Timer()
        clock = Timer()
    train_iter = iter(train_loader)
    ds_size = len(train_loader)
    small_ds = ds_size < 500
    steps_per_epoch = ds_size if not small_ds else 500

    write_examples_iter = steps_per_epoch // 3
    all_filters = []

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Start Epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    for batch_idx in range(steps_per_epoch):

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Setting up for Iteration
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- setup iteration timer --
        if use_timer:
            data_clock.tic()
            clock.tic()

        # -- grab data batch --
        if small_ds and batch_idx >= ds_size:
            train_iter = iter(train_loader)  # reset if too big
        sample = next(train_iter)
        burst, raw_img, directions = sample['burst'], sample['clean'], sample[
            'directions']
        burst = burst.cuda(non_blocking=True)

        # -- handle possibly cached simulated bursts --
        if 'sim_burst' in sample:
            sim_burst = rearrange(sample['sim_burst'],
                                  'b n k c h w -> n b k c h w')
        else:
            sim_burst = None
        if sim_burst is None and not (cfg.n2n or cfg.supervised):
            if cfg.use_kindex_lmdb:
                kindex = kindex_ds[batch_idx].cuda(non_blocking=True)
            else:
                kindex = None
            query = burst[[N // 2]]
            database = torch.cat([burst[:N // 2], burst[N // 2 + 1:]])
            sim_burst = compute_similar_bursts(
                cfg,
                query,
                database,
                cfg.sim_K,
                noise_level / 255.,
                patchsize=cfg.sim_patchsize,
                shuffle_k=cfg.sim_shuffleK,
                kindex=kindex,
                only_middle=cfg.sim_only_middle,
                search_method=cfg.sim_method,
                db_level="frame")
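        # -- n2n / supervised: no similarity search; duplicate the burst along the k axis --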
        if cfg.n2n or cfg.supervised:
            sim_burst = burst.unsqueeze(2).repeat(1, 1, 2, 1, 1, 1)
        else:
            sim_burst = sim_burst.cuda(non_blocking=True)

        if use_timer: data_clock.toc()

        # -- getting shapes of data --
        N, B, C, H, W = burst.shape
        burst = burst.cuda(non_blocking=True)
        raw_zm_img = szm(raw_img.cuda(non_blocking=True))
        burst_og = burst.clone()
        mid_img_og = burst[N // 2]

        # -- shuffle over Simulated Samples --
        k_ins, k_outs = create_k_grid(sim_burst, shuffle=True)
        # k_ins,k_outs = [k_ins[0]],[k_outs[0]]

        for k_in, k_out in zip(k_ins, k_outs):
            if k_in == k_out: continue

            # -- zero gradients; ready 2 go --
            optimizer.zero_grad()
            model.zero_grad()

            # -- compute input/output data --
            if cfg.sim_only_middle:
                midi = 0 if sim_burst.shape[0] == 1 else N // 2
                left_burst, right_burst = burst[:N // 2], burst[N // 2 + 1:]
                burst = torch.cat(
                    [left_burst, sim_burst[[midi], :, k_in], right_burst],
                    dim=0)
                mid_img = sim_burst[midi, :, k_out]
            else:
                burst = sim_burst[:, :, k_in]
                mid_img = sim_burst[N // 2, :, k_out]
            # mid_img =  sim_burst[N//2,:]
            # print(burst.shape,mid_img.shape)
            # print(F.mse_loss(burst,mid_img).item())
            if cfg.supervised:
                gt_img = get_nmlz_img(cfg, raw_img).cuda(non_blocking=True)
            elif cfg.n2n:
                gt_img = noise_xform(raw_img).cuda(non_blocking=True)
            else:
                gt_img = mid_img
            # gt_img = torch.normal(raw_zm_img,noise_level/255.)

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #        Dataset Augmentation
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # burst,gt_img = apply_transformations(burst,gt_img)

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #   Experimentally Set Hyperparams
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # -- [before training] setting the ps and nh --
            # test_ps_nh_sizes(cfg,model,burst)

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #      Formatting Images & FP
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

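            # -- sample patches from the burst (shifted by +0.5), run the model on the
            #    original and frame-reversed patch orderings, and sum the two losses --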
            patches = sample_burst_patches(cfg, model, burst + 0.5)
            input_patches_0 = model.patch_helper.form_input_patches(patches)
            f_patches = torch.flip(patches, dims=(0, ))  # reverse
            input_patches_1 = model.patch_helper.form_input_patches(f_patches)
            final_loss = model(input_patches_0)
            final_loss += model(input_patches_1)

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #              Record Keeping
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # -- total loss --
            running_loss += final_loss.item()
            total_loss += final_loss.item()

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #        Gradients & Backpropagation
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # -- compute the gradients! --
            final_loss.backward()

            # -- backprop now. --
            optimizer.step()
            model.update_moving_average()
            # scheduler.step()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #            Printing to Stdout
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch,
                          running_loss)
            print("[%d/%d][%d/%d]: %2.3e" % write_info)

            nbatches = 2
            burst = burst[:, :nbatches]  # limit batch size to run test
            psnrs_sim = test_sim_search(cfg, burst + 0.5, model)
            psnrs_ftr = psnrs_sim[cfg.byol_backbone_name]
            psnrs_pix = psnrs_sim["pix"]
            print_psnr_results(psnrs_ftr, "[PSNR-ftr]")
            print_psnr_results(psnrs_pix, "[PSNR-pix]")
            print_edge_info(burst)

            # psnrs = test_sim_search(cfg,burst,model)
            # print_psnr_results(psnrs,"[PSNR-ftr]")
            # psnrs = test_sim_search_pix(cfg,burst,model)
            # print_psnr_results(psnrs,"[PSNR-pix]")

            # -- reset loss --
            running_loss = 0

        if use_timer: clock.toc()

        if use_timer:
            print("data_clock", data_clock.average_time)
            print("clock", clock.average_time)
        cfg.global_step += 1

    total_loss /= len(train_loader)
    return total_loss, record_losses
Exemple #12
0
def run_train(cfg, rank, model, data, loader):
    this_proc_prints = (rank == 0 and cfg.use_ddp) or (cfg.use_ddp is False)

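    # -- sleep a random 1-5 seconds to stagger concurrently launched processes --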
    s = int(npr.rand() * 5 + 1)
    time.sleep(s)

    hyperparams = load_hyperparameters(cfg)

    criterion_inputs = [hyperparams]
    criterion = ClBlockLoss(hyperparams, cfg.N, cfg.batch_size)
    criterion = criterion.to(cfg.device)

    optimizer = load_optimizer(cfg, model)
    scheduler = load_scheduler(cfg, optimizer, len(loader.tr))
    print("Loaded optimizer: ")
    print(optimizer)
    print("Loaded scheduler: ")
    print(scheduler)

    # apply apex
    if cfg.use_apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

    # init writer
    if this_proc_prints:
        datetime_now = datetime.datetime.now().strftime("%b%d_%H-%M-%S")
        writer_dir = cfg.summary_log_dir / Path(datetime_now)
        writer = SummaryWriter(writer_dir)
    else:
        writer = None

    # init training loop
    global_step, current_epoch = get_model_epoch_info(cfg)
    cfg.global_step = global_step
    cfg.current_epoch = current_epoch

    # training loop
    loop_scheduler = get_loop_scheduler(scheduler)
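    # -- the loop scheduler is handed to train_loop; when it is None, the epoch-level
    #    scheduler steps on the validation loss below --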
    test_losses = {}

    # if this_proc_prints:
    #     spawn_split_eval(cfg,'val',writer,24)

    print(f"cfg.epochs: {cfg.epochs}")
    print(f"cfg.use_apex: {cfg.use_apex}")
    print(f"cfg.use_bn: {cfg.use_bn}")
    print("len of loader.val", len(loader.val))
    print(f"cfg.optim_type: {cfg.optim_type}")
    print(f"cfg.checkpoint_interval: {cfg.checkpoint_interval}")
    print(_build_v2_summary(cfg))
    t = Timer()
    for epoch in range(cfg.current_epoch, cfg.epochs):
        t.tic()
        loss_epoch = train_loop(cfg, loader.tr, model, criterion, optimizer,
                                epoch, writer, loop_scheduler)
        t.toc()
        lr = optimizer.param_groups[0]["lr"]
        # print(t)

        # if ms_scheduler:
        #     ms_scheduler.step()

        if epoch % cfg.checkpoint_interval == 0 and this_proc_prints and epoch > 0:
            save_simcl_model(cfg, model, optimizer)

        if scheduler and loop_scheduler is None and epoch % cfg.val_interval == 0 and epoch > 0:
            val_loss = test_loop(cfg, model, 'val')
            scheduler.step(val_loss)
            if this_proc_prints:
                writer.add_scalar("Loss/val", val_loss, epoch)
        elif epoch % cfg.val_interval == 0 and this_proc_prints and epoch > 0:
            if this_proc_prints:
                # spawn_split_eval(cfg,'val',writer,epoch)
                val_loss = test_loop(cfg, model, 'val')
                writer.add_scalar("Loss/val", val_loss, epoch)

        if epoch % cfg.test_interval == 0 and epoch > 0:
            if this_proc_prints:
                # spawn_split_eval(cfg,'test',writer,epoch)
                te_loss = test_loop(cfg, model, 'test')
                writer.add_scalar("Loss/test", te_loss, epoch)

        if this_proc_prints:
            writer.add_scalar("Loss/train", loss_epoch / len(loader.tr), epoch)
            writer.add_scalar("Misc/learning_rate", lr, epoch)
            msg = f"Epoch [{epoch}/{cfg.epochs}]\t"
            msg += f"Loss: {loss_epoch / len(loader.tr)}\t"
            msg += "{:2.3e}".format(lr)
            print(msg)
        cfg.current_epoch += 1

    if this_proc_prints:
        te_loss = test_loop(cfg, model, 'test')
        writer.add_scalar("Loss/test", te_loss, epoch)
    save_simcl_model(cfg, model, optimizer)
Exemple #13
0
def thtrain_denoising(cfg,
                      train_loader,
                      model,
                      criterion,
                      optimizer,
                      epoch,
                      writer,
                      scheduler=None):

    model.train()
    # model.encoder.eval()
    idx = 0
    loss_epoch = 0
    data = train_loader.dataset.data
    print("N samples:", len(data))
    simcl_t, loss_t, optim_t = Timer(), Timer(), Timer()
    for batch_idx, (noisy_imgs, raw_img) in enumerate(train_loader):

        optimizer.zero_grad()

        # setup the forward pass
        idx += cfg.batch_size

        noisy_imgs = noisy_imgs.cuda(non_blocking=True)

        simcl_t.tic()
        dec_imgs, proj = model(noisy_imgs)
        simcl_t.toc()

        loss_t.tic()
        # print(noisy_imgs.mean().item(),noisy_imgs.max().item(),noisy_imgs.min().item())
        # print(dec_imgs.mean().item(),dec_imgs.max().item(),dec_imgs.min().item())
        loss = criterion(noisy_imgs, dec_imgs, proj)
        loss_t.toc()

        # print(dec_imgs.mean(),dec_imgs.min(),dec_imgs.max())
        # compute gradients
        if cfg.use_apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        # print(loss.item())
        # print(loss.grad)
        # psum = 0
        # for param in model.decoder.parameters():
        #     pnorm = param.grad.norm()
        #     print(pnorm)
        #     psum += pnorm
        # print("Overall: {:2.3e}".format(psum))
        # exit()

        # update weights
        optim_t.tic()
        optimizer.step()
        optim_t.toc()

        if scheduler:
            scheduler.step()

        # print updates
        if writer:
            writer.add_scalar("Loss/train_epoch", loss.item(), cfg.global_step)
        cfg.global_step += 1
        if batch_idx % cfg.log_interval == 0 and writer:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, cfg.world_size * batch_idx * cfg.batch_size,
                len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

        loss_epoch += loss.item()
    print(simcl_t)
    print(loss_t)
    print(optim_t)
    return loss_epoch
Exemple #14
0
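# Compare the two similar-image search implementations on a single frame:
# time each one and report per-K MSE against the original image and between methods.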
def compare_sim_images_methods(cfg, burst, K, patchsize=3):
    b = 0
    n = 0
    ps = patchsize
    img = burst[n, b]
    single_burst = img.unsqueeze(0).unsqueeze(0)

    img_rs = rearrange(img, 'c h w -> h w c').cpu()
    t_1 = Timer()
    t_1.tic()
    sim_burst_v1 = rearrange(torch.tensor(compute_sim_images(img_rs, ps, K)),
                             'k h w c -> k c h w')
    t_1.toc()

    t_2 = Timer()
    t_2.tic()
    sim_burst_v2 = compute_similar_bursts(cfg,
                                          single_burst,
                                          K,
                                          patchsize=ps,
                                          shuffle_k=False)[0, 0, 1:]
    t_2.toc()

    print(t_1, t_2)

    print("v1", sim_burst_v1.shape)
    print("v2", sim_burst_v2.shape)
    for k in range(K):
        print("mse-v1-{}".format(k),
              F.mse_loss(sim_burst_v1[k].cpu(), img.cpu()))
        print("mse-v2-{}".format(k),
              F.mse_loss(sim_burst_v2[k].cpu(), img.cpu()))
        print("mse-{}".format(k),
              F.mse_loss(sim_burst_v1[k].cpu(), sim_burst_v2[k].cpu()))
Exemple #15
0
def main():

    args = get_args()
    cfg = get_cfg(args)

    gpuid = 2
    cfg.device = f"cuda:{gpuid}"
    torch.cuda.set_device(gpuid)

    cfg.S = 50000
    cfg.N = 30
    cfg.noise_type = 'g'
    cfg.noise_params['g']['stddev'] = 25
    cfg.dataset.name = "cifar10"
    cfg.dynamic = edict()
    cfg.dynamic.bool = False
    cfg.dynamic.ppf = 2
    cfg.dynamic.frames = cfg.N
    cfg.dynamic.mode = "global"
    cfg.dynamic.global_mode = "shift"
    cfg.dynamic.frame_size = 128
    cfg.dynamic.total_pixels = 20
    cfg.use_ddp = False
    cfg.use_collate = True
    cfg.set_worker_seed = False
    cfg.batch_size = 8
    cfg.num_workers = 4

    data, loader = get_dataset(cfg, 'single_denoising')
    # noisy_trans = data.tr._get_noise_transform(cfg.noise_type,cfg.noise_params[cfg.noise_type])
    # motion = GlobalCameraMotionTransform(cfg.dynamic,noisy_trans,True)
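    # -- build a small batch by indexing the training set directly (the loader is not used below) --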
    noisy_l, res_l, raw_l = [], [], []
    timer = Timer()
    timer.tic()
    for index in range(cfg.batch_size):
        noisy, raw = data.tr[index]
        res = data.tr.noise_set[index]
        print(noisy.shape, raw.shape, res.shape)
        noisy_l.append(noisy), res_l.append(res), raw_l.append(raw)
    timer.toc()

    print(timer)
    noisy = torch.stack(noisy_l, dim=1)
    res = torch.stack(res_l, dim=1)
    raw = torch.stack(raw_l, dim=1)
    # noisy,raw = next(iter(loader.tr))
    # noisy += 0.5
    # noisy.clamp_(0,1.)

    # rec += 0.5
    # rec.clamp_(0,1.)

    raw = raw.expand(noisy.shape)

    print(noisy.shape)
    images = torch.cat([noisy, res, raw], dim=1)
    print("pre", images.shape)
    images = images.transpose(0, 1)
    print("post", images.shape)

    fig, ax = plt.subplots(figsize=(10, 10))
    grid = vutils.make_grid(images, nrow=8)
    print(grid.shape)
    ax.imshow(np.transpose(grid, (1, 2, 0)))
    path = f"{settings.ROOT_PATH}/output/vis_cifar10.png"
    plt.savefig(path)
    # grids = [vutils.make_grid(images[i],nrow=8) for i in range(cfg.dynamic.frames)]
    # ims = [[ax.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in grids]
    # ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
    # Writer = animation.writers['ffmpeg']
    # writer = Writer(fps=3, metadata=dict(artist='Me'), bitrate=1800)
    # path = f"{settings.ROOT_PATH}/test_voc.mp4"
    # ani.save(path, writer=writer)
    print(f"Wrote to {path}")
Exemple #16
0
def main():

    args = get_args()
    cfg = get_cfg(args)

    cfg.S = 50000
    cfg.N = 30
    cfg.noise_type = 'g'
    cfg.noise_params['g']['stddev'] = 50
    cfg.dataset.name = "voc"
    cfg.dynamic = edict()
    cfg.dynamic.bool = True
    cfg.dynamic.ppf = 2
    cfg.dynamic.frames = cfg.N
    cfg.dynamic.mode = "global"
    cfg.dynamic.global_mode = "shift"
    cfg.dynamic.frame_size = 128
    cfg.dynamic.total_pixels = 20
    cfg.use_ddp = False
    cfg.use_collate = True
    cfg.set_worker_seed = False
    cfg.batch_size = 8
    cfg.num_workers = 4

    data, loader = get_dataset(cfg, 'dynamic')
    noisy_trans = data.tr._get_noise_transform(
        cfg.noise_type, cfg.noise_params[cfg.noise_type])
    motion = GlobalCameraMotionTransform(cfg.dynamic, noisy_trans, True)
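    # -- build each sample by hand: the motion transform returns (noisy, rec, raw) for one image --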
    noisy_l, rec_l, raw_l = [], [], []
    timer = Timer()
    timer.tic()
    for index in range(cfg.batch_size):
        img = Image.open(data.tr.images[index])
        noisy, rec, raw = motion(img)
        noisy_l.append(noisy), rec_l.append(rec), raw_l.append(raw)
    timer.toc()
    print(timer)
    noisy = torch.stack(noisy_l, dim=1)
    rec = torch.stack(rec_l, dim=1)
    raw = torch.stack(raw_l)
    # noisy,raw = next(iter(loader.tr))
    noisy += 0.5
    noisy.clamp_(0, 1.)

    rec += 0.5
    rec.clamp_(0, 1.)

    raw = raw.expand(noisy.shape)
    images = torch.cat([noisy, rec, raw], dim=1)

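    # -- render each burst frame as an image grid and save the sequence as an mp4 animation --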
    fig, ax = plt.subplots(figsize=(10, 10))
    grids = [
        vutils.make_grid(images[i], nrow=8) for i in range(cfg.dynamic.frames)
    ]
    ims = [[ax.imshow(np.transpose(i, (1, 2, 0)), animated=True)]
           for i in grids]
    ani = animation.ArtistAnimation(fig,
                                    ims,
                                    interval=1000,
                                    repeat_delay=1000,
                                    blit=True)
    Writer = animation.writers['ffmpeg']
    writer = Writer(fps=3, metadata=dict(artist='Me'), bitrate=1800)
    path = f"{settings.ROOT_PATH}/test_voc.mp4"
    ani.save(path, writer=writer)
    print(f"Wrote to {path}")