Code example #1
0
File: burst_kitti.py  Project: gauenk/cl_gen
 def __init__(self,
              root,
              split,
              edition,
              nframes,
              noise_info,
              nnf_K,
              nnf_ps,
              nnf_exists=True):
     """Initialize the KITTI burst dataset wrapper.

     Stores the constructor arguments, resolves the on-disk paths for the
     requested edition/split, loads the per-burst path index, and builds
     the noise transform applied to each burst.
     """
     # -- bookkeeping of constructor arguments --
     self.root = root
     self.split = split
     self.edition = edition
     self.nframes = nframes
     self.noise_info = noise_info
     self.nnf_K = nnf_K
     self.nnf_ps = nnf_ps
     self.nnf_exists = nnf_exists
     self.istest = (split == "test")
     self.read_resize = (370, 1224)  # fixed read size (presumably H, W)

     # -- resolve dataset paths and load the burst index --
     self.paths = get_kitti_path(root)
     split_parts = self._get_split_parts_name(split)
     self.dataset = self._read_dataset_paths(self.paths, edition,
                                             split_parts, nframes,
                                             self.read_resize, nnf_K,
                                             nnf_ps, nnf_exists)

     # -- noise transform applied to loaded bursts --
     self.noise_xform = get_noise_transform(noise_info, use_to_tensor=False)
     print(split, len(self.dataset['burst_id']))
Code example #2
0
File: eval_score.py  Project: gauenk/cl_gen
def init_exp(cfg,exp):
    """Configure `cfg` from an experiment record and build its transforms.

    Copies patchsize / nframes / nblocks from `exp` into `cfg` (keeping the
    legacy aliases `cfg.N` and `cfg.nh_size`), fills in the dynamics
    settings, and returns the noise transform, the wrapped dynamics
    transform, and the score function named by the experiment.
    """
    # -- geometry settings (with legacy aliases) --
    cfg.patchsize = int(exp.patchsize)
    cfg.nframes = int(exp.nframes)
    cfg.N = cfg.nframes  # legacy alias
    cfg.nblocks = int(exp.nblocks)
    cfg.nh_size = cfg.nblocks  # old name

    # -- noise transform --
    noise_cfg = get_noise_config(cfg, exp.noise_type)
    noise_xform = get_noise_transform(noise_cfg, use_to_tensor=False)

    # -- dynamics settings --
    cfg.dynamic.ppf = exp.ppf
    cfg.dynamic.bool = True
    cfg.dynamic.random_eraser = False
    cfg.dynamic.frame_size = cfg.frame_size
    cfg.dynamic.total_pixels = cfg.dynamic.ppf * (cfg.nframes - 1)
    cfg.dynamic.frames = exp.nframes

    # identity "noise" so the raw dynamics transform adds no noise itself
    def identity(image):
        return image

    raw_dynamic_xform = get_dynamic_transform(cfg.dynamic, identity)
    dynamic_xform = dynamic_wrapper(raw_dynamic_xform)

    # -- score function --
    score_function = get_score_function(exp.score_function)

    return noise_xform, dynamic_xform, score_function
Code example #3
0
    def __init__(self, root, split, isize, nsamples, noise_info, dynamic_info):
        """Initialize the dataset: store settings, build the noise and
        dynamics transforms, and set up fixed-randomness bookkeeping.
        """
        # -- store constructor settings --
        self.root = root
        self.split = split
        self.isize = isize
        self.nsamples = nsamples
        self.noise_info = noise_info
        self.dynamic_info = dynamic_info

        # -- build transforms --
        self.noise_trans = get_noise_transform(noise_info, noise_only=True)
        self.dynamic_trans = get_dynamic_transform(dynamic_info, None,
                                                   load_res)

        # -- paths (none loaded yet) --
        self.paths = []

        # -- limit the number of samples --
        self.indices = enumerate_indices(len(self.paths), nsamples)
        self.nsamples = len(self.indices)

        # -- optionally freeze the per-sample dynamics randomness --
        self.dyn_once = return_optional(dynamic_info, "sim_once", False)
        self.fixRandDynamics = RandomOnce(self.dyn_once, nsamples)

        # -- optionally freeze the per-sample noise randomness --
        self.noise_once = return_optional(noise_info, "sim_once", False)
        self.fixRandNoise_1 = RandomOnce(self.noise_once, nsamples)
        self.fixRandNoise_2 = RandomOnce(self.noise_once, nsamples)
Code example #4
0
def transforms_from_cfg(cfg):
    """Build the (noise, dynamics) transforms described by `cfg`.

    Returns
    -------
    noise_xform : callable
        Applies the noise described by `cfg.noise_params` to a tensor.
    dynamic_xform : callable
        Maps an image tensor to a `(burst, flow)` pair.
    """
    # -- noise transform --
    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)

    # -- identity "noise" for the raw dynamics transform --
    def passthrough(image):
        return image

    # -- dynamics: adapt raw transform to tensor in -> (burst, flow) out --
    raw_xform = get_dynamic_transform(cfg.dynamic, passthrough)

    def dynamic_xform(image):
        pil_image = tvT.ToPILImage()(image).convert("RGB")
        results = raw_xform(pil_image)
        # offset the burst by +0.5; results[3] holds the flow
        return results[0] + 0.5, results[3]

    return noise_xform, dynamic_xform
Code example #5
0
    def __init__(self, iroot, froot, sroot, split, isize, ps, nsamples,
                 nframes, noise_info):
        """Initialize the dataset: store settings, build the noise
        transform, read the file index, and prepare sample bookkeeping.
        """
        # -- store constructor settings --
        self.iroot = iroot
        self.froot = froot
        self.sroot = sroot
        self.split = split
        self.isize = isize
        self.ps = ps
        self.nsamples = nsamples
        self.noise_info = noise_info

        # -- noise transform --
        self.noise_trans = get_noise_transform(noise_info, noise_only=True)

        # -- read the image/flow/sim file index --
        self.paths, self.nframes, all_eq = read_files(iroot, froot, sroot,
                                                      split, isize, ps,
                                                      nframes)
        if not all_eq:
            print("\n\n\n\nWarning: Not all bursts are same length!!!\n\n\n\n")
        self.groups = sorted(list(self.paths['images'].keys()))

        # -- limit the number of samples --
        self.indices = enumerate_indices(len(self.paths['images']), nsamples)
        self.nsamples = len(self.indices)

        # -- optionally freeze the per-sample noise randomness --
        self.noise_once = return_optional(noise_info, "sim_once", False)
        self.fixRandNoise_1 = RandomOnce(self.noise_once, nsamples)
        self.fixRandNoise_2 = RandomOnce(self.noise_once, nsamples)
Code example #6
0
File: interface.py  Project: gauenk/cl_gen
def get_align_noise(align_name):
    """Return a function `(noisy, clean) -> tensor` selecting alignment noise.

    For `align_name == "same"` the returned function simply passes through
    the provided noisy samples; otherwise it draws a fresh noisy sample
    from the clean image using the noise transform named by `align_name`.
    """
    if align_name == "same":
        # reuse the noisy samples the caller already has
        def align_noise_fxn(noisy, clean):
            return noisy
        return align_noise_fxn

    # draw a different noisy sample (possibly different noise level)
    apply_noise = get_noise_transform(align_name, noise_only=True)

    def align_noise_fxn(noisy, clean):
        return apply_noise(clean)

    return align_noise_fxn
Code example #7
0
File: explore_fast_unet.py  Project: gauenk/cl_gen
def single_image_unet(cfg, queue, full_image, device):
    """Probe small-UNet denoising on aligned vs. misaligned noisy bursts.

    Crops a 32x32 patch from ``full_image``, replicates it T=5 times,
    applies Poisson noise, then trains fresh ``UNet_small`` models on
    (1) the fully-aligned burst, (2) a burst whose first frame is shifted
    by one pixel, and (3) pairs of two-frame sub-bursts, printing PSNR
    comparisons and saving reconstructions after each run.

    NOTE(review): ``queue`` is never used in this body; all results are
    reported via ``print``/``save_image`` side effects only.
    """
    full_image = full_image.to(device)
    image = tvF.crop(full_image, 128, 128, 32, 32)
    T = 5  # number of frames in the synthetic burst

    # -- poisson noise --
    noise_type = "pn"
    cfg.noise_type = noise_type
    cfg.noise_params['pn']['alpha'] = 40.0
    cfg.noise_params['pn']['readout'] = 0.0
    cfg.noise_params.ntype = cfg.noise_type
    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)

    # -- aligned burst: T identical frames, independently noised --
    clean = torch.stack([image for i in range(T)], dim=0)
    noisy = noise_xform(clean)
    save_image(clean, "clean.png", normalize=True)
    # -- misaligned copies: first T//2 frames replaced by a 1px-shifted crop.
    # NOTE(review): m_clean/m_noisy are never read again below; only
    # image_mis (the shifted crop) survives this loop and is reused later.
    m_clean, m_noisy = clean.clone(), noisy.clone()
    for i in range(T // 2):
        image_mis = tvF.crop(full_image, 128 + 1, 128, 32, 32)
        m_clean[i] = image_mis
        m_noisy[i] = noise_xform(m_clean[i])

    noise_level = 50. / 255.  # NOTE(review): unused in this function

    # -- model all --
    print("-- All Aligned --")
    model = UNet_small(3)  # UNet_n2n(1)
    cfg.init_lr = 1e-4
    optim = torch.optim.Adam(model.parameters(),
                             lr=cfg.init_lr,
                             betas=(0.9, 0.99))
    train(cfg, image, clean, noisy, model, optim)
    results = test(cfg, image, clean, noisy, model, 0)
    rec = model(noisy) + 0.5
    save_image(rec, "rec_all.png", normalize=True)
    print("Single Image Unet:")
    print(images_to_psnrs(clean, rec))
    print(images_to_psnrs(rec - 0.5, noisy))
    print(images_to_psnrs(rec[[0]], rec[[1]]))
    print(images_to_psnrs(rec[[0]], rec[[2]]))
    print(images_to_psnrs(rec[[1]], rec[[2]]))

    # -- keep the aligned burst, then swap frame 0 for the shifted crop --
    og_clean, og_noisy = clean.clone(), noisy.clone()
    clean[0] = image_mis
    noisy[0] = noise_xform(clean[[0]])[0]

    # -- model all --
    print("-- All Misligned --")
    model = UNet_small(3)  # UNet_n2n(1)
    cfg.init_lr = 1e-4
    optim = torch.optim.Adam(model.parameters(),
                             lr=cfg.init_lr,
                             betas=(0.9, 0.99))
    train(cfg, image, clean, noisy, model, optim)
    results = test(cfg, image, clean, noisy, model, 0)
    rec = model(noisy) + 0.5
    save_image(rec, "rec_all.png", normalize=True)
    print("Single Image Unet:")
    print(images_to_psnrs(clean, rec))
    print(images_to_psnrs(rec - 0.5, noisy))
    print(images_to_psnrs(rec[[0]], rec[[1]]))
    print(images_to_psnrs(rec[[0]], rec[[2]]))
    print(images_to_psnrs(rec[[1]], rec[[2]]))

    # -- three repetitions of pairwise (two-frame) training runs --
    for j in range(3):
        # -- data --
        # pair 1: misaligned frame 0 + frame 1
        noisy1 = torch.stack([noisy[0], noisy[1]], dim=0)
        clean1 = torch.stack([clean[0], clean[1]], dim=0)

        # pair 2: aligned frames 1 and 2
        noisy2 = torch.stack([noisy[1], noisy[2]], dim=0)
        clean2 = torch.stack([clean[1], clean[2]], dim=0)

        # pair 3: ORIGINAL (unshifted) frame 0 + frame 1
        noisy3 = torch.stack([og_noisy[0], noisy[1]], dim=0)
        clean3 = torch.stack([og_clean[0], clean[1]], dim=0)

        # -- model 1: train on the misaligned pair, cross-test on pair 2 --
        model = UNet_small(3)  # UNet_n2n(1)
        cfg.init_lr = 1e-4
        optim = torch.optim.Adam(model.parameters(),
                                 lr=cfg.init_lr,
                                 betas=(0.9, 0.99))
        train(cfg, image, clean1, noisy1, model, optim)
        results = test(cfg, image, clean1, noisy1, model, 0)
        rec = model(noisy1) + 0.5
        xrec = model(noisy2) + 0.5
        save_image(rec, "rec1.png", normalize=True)
        print("[misaligned] Single Image Unet:", images_to_psnrs(clean1, rec),
              images_to_psnrs(clean2, xrec),
              images_to_psnrs(rec - 0.5, noisy1),
              images_to_psnrs(xrec - 0.5, noisy2),
              images_to_psnrs(rec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[0]], xrec[[1]]),
              images_to_psnrs(xrec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[1]], rec[[0]]))

        # -- model 2: train on the aligned pair, cross-test on pair 1 --
        model = UNet_small(3)  # UNet_n2n(1)
        cfg.init_lr = 1e-4
        optim = torch.optim.Adam(model.parameters(),
                                 lr=cfg.init_lr,
                                 betas=(0.9, 0.99))
        train(cfg, image, clean2, noisy2, model, optim)
        results = test(cfg, image, clean2, noisy2, model, 0)
        rec = model(noisy2) + 0.5
        xrec = model(noisy1) + 0.5
        save_image(rec, "rec2.png", normalize=True)
        print("[aligned] Single Image Unet:", images_to_psnrs(clean2, rec),
              images_to_psnrs(clean1, xrec),
              images_to_psnrs(rec - 0.5, noisy2),
              images_to_psnrs(xrec - 0.5, noisy1),
              images_to_psnrs(rec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[0]], xrec[[1]]),
              images_to_psnrs(xrec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[1]], rec[[0]]))

        # -- model 3: train on pair 3, test against all three pairs --
        model = UNet_small(3)  # UNet_n2n(1)
        cfg.init_lr = 1e-4
        optim = torch.optim.Adam(model.parameters(),
                                 lr=cfg.init_lr,
                                 betas=(0.9, 0.99))
        train(cfg, image, clean3, noisy3, model, optim)
        results = test(cfg, image, clean3, noisy3, model, 0)
        rec = model(noisy3) + 0.5
        rec_2 = model(noisy2) + 0.5
        rec_1 = model(noisy1) + 0.5
        save_image(rec, "rec1.png", normalize=True)
        print("[aligned (v3)] Single Image Unet:")
        print("clean-rec", images_to_psnrs(clean3, rec))
        print("clean1-rec1", images_to_psnrs(clean1, rec_1))
        print("clean2-rec2", images_to_psnrs(clean2, rec_2))
        print("rec-noisy3", images_to_psnrs(rec - 0.5, noisy3))
        print("rec1-noisy1", images_to_psnrs(rec_1 - 0.5, noisy1))
        print("rec2-noisy2", images_to_psnrs(rec_2 - 0.5, noisy2))
        print("[v3]: rec0-rec1", images_to_psnrs(rec[[0]], rec[[1]]))
        print("[v1]: rec0-rec1", images_to_psnrs(rec_1[[0]], rec_1[[1]]))
        print("[v2]: rec0-rec1", images_to_psnrs(rec_2[[0]], rec_2[[1]]))
        print("[v1-v2](a):", images_to_psnrs(rec_1[[1]], rec_2[[1]]))
        print("[v1-v2](b):", images_to_psnrs(rec_1[[1]], rec_2[[0]]))
        print("[v1-v2](c):", images_to_psnrs(rec_1[[0]], rec_2[[1]]))
        print("[v1-v2](d):", images_to_psnrs(rec_1[[0]], rec_2[[0]]))
        print("-" * 20)
        print("[v2-v3](a):", images_to_psnrs(rec_2[[1]], rec[[1]]))
        print("[v2-v3](b):", images_to_psnrs(rec_2[[1]], rec[[0]]))
        print("[v2-v3](c):", images_to_psnrs(rec_2[[0]], rec[[1]]))
        print("[v2-v3](d):", images_to_psnrs(rec_2[[0]], rec[[0]]))
        print("-" * 20)
        print("[v1-v3](a):", images_to_psnrs(rec_1[[1]], rec[[1]]))
        print("[v1-v3](b):", images_to_psnrs(rec_1[[1]], rec[[0]]))
        print("[v1-v3](c):", images_to_psnrs(rec_1[[0]], rec[[1]]))
        print("[v1-v3](d):", images_to_psnrs(rec_1[[0]], rec[[0]]))
        print("-" * 20)
        print("rec0-rec1", images_to_psnrs(rec[[0]], rec[[1]]))
Code example #8
0
File: explore_fast_unet.py  Project: gauenk/cl_gen
def run_experiment(cfg, data, record_fn, bss_fn):
    """Run a block-search experiment over noisy crops and record COG scores.

    Builds, for each of T frames, a grid of H*H spatially-shifted 32x32
    crops from one training image, applies Gaussian noise (stddev 50),
    then for each proposed block arrangement trains a COG model ensemble
    and records its operator-consistency score alongside the standard
    result fields.

    Parameters
    ----------
    cfg : configuration object (reads/writes noise settings; reads
        nframes, nblocks, device)
    data : dataset wrapper; ``data.tr[0][2]`` is the source image
    record_fn : destination path for the results CSV
    bss_fn : path used to persist the sampled block search space (.npy)

    Returns
    -------
    pandas.DataFrame with one row per evaluated block arrangement.
    """

    # -- setup noise: Gaussian, stddev 50 --
    cfg.noise_type = 'g'
    cfg.ntype = cfg.noise_type
    cfg.noise_params.ntype = cfg.noise_type
    noise_level = 50.
    cfg.noise_params['g']['stddev'] = noise_level
    noise_level_str = f"{int(noise_level)}"
    # nconfig = get_noise_config(cfg,exp.noise_type)
    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)

    # -- set configs --
    T = cfg.nframes
    H = cfg.nblocks

    # -- create our neighborhood of shifted crops --
    full_image = data.tr[0][2]

    clean = []
    # per-frame (top, left) offsets; currently all zeros
    # tl_list = [[0,0],[1,0],[0,1]]
    # BUGFIX: np.int was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin int is the documented replacement.
    tl_list = np.zeros((T, 2)).astype(int)
    for t in range(T):
        clean_t = []
        # BUGFIX: was `t, l = tl_list[t]`, which shadowed the loop index;
        # distinct names keep the same values without the shadowing.
        top, left = tl_list[t]
        for i in range(-H // 2 + 1, H // 2 + 1):
            for j in range(-H // 2 + 1, H // 2 + 1):
                clean_t.append(
                    tvF.crop(full_image, top + 128 + i, left + 128 + j,
                             32, 32))
        clean_t = torch.stack(clean_t, dim=0)
        clean.append(clean_t)
    clean = torch.stack(clean, dim=0)
    REF_H = get_ref_block_index(cfg.nblocks)
    image = clean[T // 2, REF_H]  # reference (unshifted) center crop

    # -- normalize to [0, 1] --
    clean -= clean.min()
    clean /= clean.max()
    save_image(clean, "fast_unet_clean.png", normalize=True)
    save_image(image, "fast_unet_image.png", normalize=True)

    # -- apply noise --
    # torch.manual_seed(123)
    noisy = noise_xform(clean)

    record = []

    gridT = torch.arange(T)
    # -- exhaustive search space (subset) --
    block_search_space = get_block_arangements_subset(cfg.nblocks,
                                                      cfg.nframes,
                                                      tcount=3)
    # block_search_space = get_block_arangements(cfg.nblocks,cfg.nframes)
    # -- random subset of at most 100 arrangements (plus ground truth) --
    use_rand = True
    if use_rand:
        # NOTE: the cache-read branch is intentionally disabled ("and False")
        if bss_fn.exists() and False:
            print(f"Reading bss {bss_fn}")
            block_search_space = np.load(bss_fn, allow_pickle=True)
        else:
            bss = block_search_space
            print(f"Original block search space: [{len(bss)}]")
            if len(block_search_space) >= 100:
                rand_blocks = random.sample(list(block_search_space), 100)
                block_search_space = [
                    np.array([REF_H] * T),
                ]  # include gt
                block_search_space.extend(rand_blocks)
            bss = block_search_space
            print(f"Writing block search space: [{bss_fn}]")
            np.save(bss_fn, np.array(block_search_space))
    print(f"Search Space Size: {len(block_search_space)}")

    clean = clean.to(cfg.device)
    noisy = noisy.to(cfg.device)
    for idx, prop in enumerate(tqdm(block_search_space)):
        # -- fetch the proposed crop for each frame --
        clean_prop = clean[gridT, prop]
        noisy_prop = noisy[gridT, prop]
        # BUGFIX: condition previously repeated `idx == 40` twice.
        if idx == 49 or idx == 40:
            save_image(noisy_prop, f"noisy_prop_{idx}.png", normalize=True)
            save_image(clean_prop, f"clean_prop_{idx}.png", normalize=True)

        # -- train/test a COG ensemble; score operator consistency --
        train_steps = 500
        cog = COG(UNet_small,
                  T,
                  noisy.device,
                  nn_params=None,
                  train_steps=train_steps)
        cog.train_models(noisy_prop)
        recs = cog.test_models(noisy_prop)
        score = cog.operator_consistency(recs, noisy_prop)
        results = fill_results(cfg, image, clean_prop, noisy_prop, None, idx)
        results['cog'] = score
        print(score, prop)

        # -- update record --
        record.append(results)

    record = pd.DataFrame(record)
    print(f"Writing record to {record_fn}")
    record.to_csv(record_fn)
    return record
Code example #9
0
File: learn.py  Project: gauenk/cl_gen
def train_loop(cfg, model, scheduler, train_loader, epoch, record_losses,
               writer):

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Setup for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    model.align_info.model.train()
    model.denoiser_info.model.train()
    model.unet_info.model.train()
    model.denoiser_info.model = model.denoiser_info.model.to(cfg.device)
    model.align_info.model = model.align_info.model.to(cfg.device)
    model.unet_info.model = model.unet_info.model.to(cfg.device)

    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })
    noise_type = cfg.noise_params.ntype

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Record Keeping
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    align_mse_losses, align_mse_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0
    dynamics_acc, dynamics_count = 0, 0

    write_examples = False
    write_examples_iter = 200
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #   Load Pre-Simulated Random Numbers
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    if cfg.use_kindex_lmdb: kindex_ds = kIndexPermLMDB(cfg.batch_size, cfg.N)

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Dataset Augmentation
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    transforms = [tvF.vflip, tvF.hflip, tvF.rotate]
    aug = RandomChoice(transforms)

    def apply_transformations(burst, gt_img):
        N, B = burst.shape[:2]
        gt_img_rs = rearrange(gt_img, 'b c h w -> 1 b c h w')
        all_images = torch.cat([gt_img_rs, burst], dim=0)
        all_images = rearrange(all_images, 'n b c h w -> (n b) c h w')
        tv_utils.save_image(all_images,
                            'aug_original.png',
                            nrow=N + 1,
                            normalize=True)
        aug_images = aug(all_images)
        tv_utils.save_image(aug_images,
                            'aug_augmented.png',
                            nrow=N + 1,
                            normalize=True)
        aug_images = rearrange(aug_images, '(n b) c h w -> n b c h w', b=B)
        aug_gt_img = aug_images[0]
        aug_burst = aug_images[1:]
        return aug_burst, aug_gt_img

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Half Precision
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    # model.align_info.model.half()
    # model.denoiser_info.model.half()
    # model.unet_info.model.half()
    # models = [model.align_info.model,
    #           model.denoiser_info.model,
    #           model.unet_info.model]
    # for model_l in models:
    #     model_l.half()
    #     for layer in model_l.modules():
    #         if isinstance(layer, torch.nn.BatchNorm2d):
    #             layer.float()

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Loss Functions
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss(alpha=cfg.kpn_burst_alpha,
                                  gradient_L1=~cfg.supervised)
    # denoiseLossOT = BurstResidualLoss()
    entropyLoss = EntropyLoss()

    # -=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #    Add hooks for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-

    align_hook = AlignmentFilterHooks(cfg.N)
    align_hooks = []
    for kpn_module in model.align_info.model.children():
        for name, layer in kpn_module.named_children():
            if name == "filter_cls":
                align_hook_handle = layer.register_forward_hook(align_hook)
                align_hooks.append(align_hook_handle)

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Noise2Noise
    #
    # -=-=-=-=-=-=-=-=-=-=-

    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Final Configs
    #
    # -=-=-=-=-=-=-=-=-=-=-

    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    if use_timer:
        data_clock = Timer()
        clock = Timer()
    ds_size = len(train_loader)
    small_ds = ds_size < 500
    steps_per_epoch = ds_size if not small_ds else 500

    write_examples_iter = steps_per_epoch // 3
    all_filters = []

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Start Epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-
    dynamics_acc_i = -1.
    if cfg.use_seed:
        init = torch.initial_seed()
        torch.manual_seed(cfg.seed + 1 + epoch + init)
    train_iter = iter(train_loader)
    for batch_idx in range(steps_per_epoch):

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Setting up for Iteration
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- setup iteration timer --
        if use_timer:
            data_clock.tic()
            clock.tic()

        # -- grab data batch --
        if small_ds and batch_idx >= ds_size:
            if cfg.use_seed:
                init = torch.initial_seed()
                torch.manual_seed(cfg.seed + 1 + epoch + init)
            train_iter = iter(train_loader)  # reset if too big
        sample = next(train_iter)
        burst, raw_img, motion = sample['burst'], sample['clean'], sample[
            'flow']
        raw_img_iid = sample['iid']
        raw_img_iid = raw_img_iid.cuda(non_blocking=True)
        burst = burst.cuda(non_blocking=True)

        # -- handle possibly cached simulated bursts --
        if 'sim_burst' in sample:
            sim_burst = rearrange(sample['sim_burst'],
                                  'b n k c h w -> n b k c h w')
        else:
            sim_burst = None
        non_sim_method = cfg.n2n or cfg.supervised
        if sim_burst is None and not (non_sim_method or cfg.abps):
            if sim_burst is None:
                if cfg.use_kindex_lmdb:
                    kindex = kindex_ds[batch_idx].cuda(non_blocking=True)
                else:
                    kindex = None
                query = burst[[N // 2]]
                database = torch.cat([burst[:N // 2], burst[N // 2 + 1:]])
                sim_burst = compute_similar_bursts(
                    cfg,
                    query,
                    database,
                    cfg.sim_K,
                    noise_level / 255.,
                    patchsize=cfg.sim_patchsize,
                    shuffle_k=cfg.sim_shuffleK,
                    kindex=kindex,
                    only_middle=cfg.sim_only_middle,
                    search_method=cfg.sim_method,
                    db_level="frame")

        if (sim_burst is None) and cfg.abps:
            # scores,aligned = abp_search(cfg,burst)
            # scores,aligned = lpas_search(cfg,burst,motion)
            if cfg.lpas_method == "spoof":
                mtype = "global"
                acc = cfg.optical_flow_acc
                scores, aligned = lpas_spoof(burst, motion, cfg.nblocks, mtype,
                                             acc)
            else:
                ref_frame = (cfg.nframes + 1) // 2
                nblocks = cfg.nblocks
                method = cfg.lpas_method
                scores, aligned, dacc = lpas_search(burst, ref_frame, nblocks,
                                                    motion, method)
                dynamics_acc_i = dacc
            # scores,aligned = lpas_spoof(motion,accuracy=cfg.optical_flow_acc)
            # shuffled = shuffle_aligned_pixels_noncenter(aligned,cfg.nframes)
            nsims = cfg.nframes
            sim_aligned = create_sim_from_aligned(burst, aligned, nsims)
            burst_s = rearrange(burst, 't b c h w -> t b 1 c h w')
            sim_burst = torch.cat([burst_s, sim_aligned], dim=2)
            # print("sim_burst.shape",sim_burst.shape)

        # raw_img = raw_img.cuda(non_blocking=True)-0.5
        # # print(np.sqrt(cfg.noise_params['g']['stddev']))
        # print(motion)
        # tiled = tile_across_blocks(burst[[cfg.nframes//2]],cfg.nblocks)
        # rep_burst = repeat(burst,'t b c h w -> t b g c h w',g=tiled.shape[2])
        # for t in range(cfg.nframes):
        #     save_image(tiled[0] - rep_burst[t],f"tiled_sub_burst_{t}.png")
        # save_image(aligned,"aligned.png")
        # print(aligned.shape)
        # # save_image(aligned[0] - aligned[cfg.nframes//2],"aligned_0.png")
        # # save_image(aligned[2] - aligned[cfg.nframes//2],"aligned_2.png")
        # M = (1+cfg.dynamic.ppf)*cfg.nframes
        # fs = cfg.dynamic.frame_size - M
        # fs = cfg.frame_size
        # cropped = crop_center_patch([burst,aligned,raw_img],cfg.nframes,cfg.frame_size)
        # burst,aligned,raw_img = cropped[0],cropped[1],cropped[2]
        # print(aligned.shape)
        # for t in range(cfg.nframes+1):
        #     diff_t = aligned[t] - raw_img
        #     spacing = cfg.nframes+1
        #     diff_t = crop_center_patch([diff_t],spacing,cfg.frame_size)[0]
        #     print_tensor_stats(f"diff_aligned_{t}",diff_t)
        #     save_image(diff_t,f"diff_aligned_{t}.png")
        #     if t < cfg.nframes:
        #         dt = aligned[t+1]-aligned[t]
        #         dt = crop_center_patch([dt],spacing,cfg.frame_size)[0]
        #         save_image(dt,f"dt_aligned_{t+1}m{t}.png")
        #     save_image(aligned[t],f"aligned_{t}.png")
        #     diff_t = tvF.crop(aligned[t] - raw_img,cfg.nframes,cfg.nframes,fs,fs)
        #     print_tensor_stats(f"diff_aligned_{t}",diff_t)

        # save_image(burst,"burst.png")
        # save_image(burst[0] - burst[cfg.nframes//2],"burst_0.png")
        # save_image(burst[2] - burst[cfg.nframes//2],"burst_2.png")
        # exit()

        # print(sample['burst'].shape,sample['res'].shape)
        # b_clean = sample['burst'] - sample['res']
        # scores,ave,t_aligned = test_abp_global_search(cfg,b_clean,noisy_img=burst)

        # burstBN = rearrange(burst,'n b c h w -> (b n) c h w')
        # tv_utils.save_image(burstBN,"abps_burst.png",normalize=True)
        # alignedBN = rearrange(aligned,'n b c h w -> (b n) c h w')
        # tv_utils.save_image(alignedBN,"abps_aligned.png",normalize=True)
        # rep_burst = burst[[N//2]].repeat(N,1,1,1,1)
        # deltaBN = rearrange(aligned - rep_burst,'n b c h w -> (b n) c h w')
        # tv_utils.save_image(deltaBN,"abps_delta.png",normalize=True)
        # b_clean_rep = b_clean[[N//2]].repeat(N,1,1,1,1)
        # tdeltaBN = rearrange(t_aligned - b_clean_rep.cpu(),'n b c h w -> (b n) c h w')
        # tv_utils.save_image(tdeltaBN,"abps_tdelta.png",normalize=True)

        if non_sim_method:
            sim_burst = burst.unsqueeze(2).repeat(1, 1, 2, 1, 1, 1)
        else:
            sim_burst = sim_burst.cuda(non_blocking=True)
        if use_timer: data_clock.toc()

        # -- to cuda --
        burst = burst.cuda(non_blocking=True)
        raw_zm_img = szm(raw_img.cuda(non_blocking=True))
        # anscombe.test(cfg,burst_og)
        # save_image(burst,f"burst_{batch_idx}_{cfg.n2n}.png")

        # -- crop images --
        if True:  #cfg.abps or cfg.abps_inputs:
            images = [burst, sim_burst, raw_img, raw_img_iid]
            spacing = burst.shape[0]  # we use frames as spacing
            cropped = crop_center_patch(images, spacing, cfg.frame_size)
            burst, sim_burst = cropped[0], cropped[1]
            raw_img, raw_img_iid = cropped[2], cropped[3]
            if cfg.abps or cfg.abps_inputs:
                aligned = crop_center_patch([aligned], spacing,
                                            cfg.frame_size)[0]
            # print_tensor_stats("d-eq?",burst[-1] - aligned[-1])
            burst = burst[:cfg.nframes]  # last frame is target

        # -- getting shapes of data --
        N, B, C, H, W = burst.shape
        burst_og = burst.clone()

        # -- shuffle over Simulated Samples --
        k_ins, k_outs = create_k_grid(sim_burst, shuffle=True)
        k_ins, k_outs = [k_ins[0]], [k_outs[0]]
        # k_ins,k_outs = create_k_grid_v3(sim_burst)

        for k_in, k_out in zip(k_ins, k_outs):
            if k_in == k_out: continue

            # -- zero gradients; ready 2 go --
            model.align_info.model.zero_grad()
            model.align_info.optim.zero_grad()
            model.denoiser_info.model.zero_grad()
            model.denoiser_info.optim.zero_grad()
            model.unet_info.model.zero_grad()
            model.unet_info.optim.zero_grad()

            # -- compute input/output data --
            if cfg.sim_only_middle and (not cfg.abps):
                # sim_burst.shape == T,B,K,C,H,W
                midi = 0 if sim_burst.shape[0] == 1 else N // 2
                left_burst, right_burst = burst[:N // 2], burst[N // 2 + 1:]
                cat_burst = [
                    left_burst, sim_burst[[midi], :, k_in], right_burst
                ]
                burst = torch.cat(cat_burst, dim=0)
                mid_img = sim_burst[midi, :, k_out]
            elif cfg.abps and (not cfg.abps_inputs):
                # -- v1 --
                mid_img = aligned[-1]

                # -- v2 --
                # left_aligned,right_aligned = aligned[:N//2],aligned[N//2+1:]
                # nc_aligned = torch.cat([left_aligned,right_aligned],dim=0)
                # shuf = shuffle_aligned_pixels(nc_aligned,cfg.nframes)
                # mid_img = shuf[1]

                # ---- v3 ----
                # shuf = shuffle_aligned_pixels(aligned)
                # shuf = aligned[[N//2,0]]
                # midi = 0 if sim_burst.shape[0] == 1 else N//2
                # left_burst,right_burst = burst[:N//2],burst[N//2+1:]
                # burst = torch.cat([left_burst,shuf[[0]],right_burst],dim=0)
                # nc_burst = torch.cat([left_burst,right_burst],dim=0)
                # shuf = shuffle_aligned_pixels(aligned)

                # ---- v4 ----
                # nc_shuf = shuffle_aligned_pixels(nc_aligned)
                # mid_img = nc_shuf[0]
                # pick = npr.randint(0,2,size=(1,))[0]
                # mid_img = nc_aligned[pick]
                # mid_img = shuf[1]

                # save_image(shuf,"shuf.png")
                # print(shuf.shape)

                # diff = raw_img.cuda(non_blocking=True) - aligned[0]
                # mean = torch.mean(diff).item()
                # std = torch.std(diff).item()
                # print(mean,std)

                # -- v1 --
                # burst = burst
                # notMid = sample_not_mid(N)
                # mid_img = aligned[notMid]

            elif cfg.abps_inputs:
                burst = aligned.clone()
                burst_og = aligned.clone()
                mid_img = shuffle_aligned_pixels(burst, cfg.nframes)[0]

            else:
                burst = sim_burst[:, :, k_in]
                mid_img = sim_burst[N // 2, :, k_out]
            # mid_img =  sim_burst[N//2,:]
            # print(burst.shape,mid_img.shape)
            # print(F.mse_loss(burst,mid_img).item())
            if cfg.supervised:
                gt_img = get_nmlz_tgt_img(cfg, raw_img).cuda(non_blocking=True)
            elif cfg.n2n:
                gt_img = raw_img_iid  #noise_xform(raw_img).cuda(non_blocking=True)
            else:
                gt_img = mid_img

            # another = noise_xform(raw_img).cuda(non_blocking=True)
            # print_tensor_stats("a-iid?",raw_img_iid.cuda() - raw_img.cuda())
            # print_tensor_stats("b-iid?",mid_img.cuda() - raw_img.cuda())
            # print_tensor_stats("c-iid?",mid_img.cuda() - another)
            # print_tensor_stats("d-iid?",raw_img_iid.cuda() - another)
            # print_tensor_stats("e-iid?",mid_img.cuda() - raw_img_iid.cuda())

            # for bt in range(cfg.nframes):
            #     tiled = tile_across_blocks(burst[[bt]],cfg.nblocks)
            #     rep_burst = repeat(burst,'t b c h w -> t b g c h w',g=tiled.shape[2])
            #     for t in range(cfg.nframes):
            #         save_image(tiled[0] - rep_burst[t],f"tiled_{bt}_sub_burst_{t}.png")
            #         print_tensor_stats(f"delta_{bt}_{t}",tiled[0,:,4] - burst[t])

            # raw_img = raw_img.cuda(non_blocking=True) - 0.5
            # print_tensor_stats("gt_img - raw",gt_img - raw_img)
            # # save_image(gt_img,"gt.png")
            # # save_image(raw,"raw.png")
            # save_image(gt_img - raw_img,"gt_sub_raw.png")
            # print_tensor_stats("burst[N//2] - raw",burst[N//2] - raw_img)
            # save_image(burst[N//2] - raw_img,"burst_sub_raw.png")
            # print_tensor_stats("burst[N//2] - gt_img",burst[N//2] - gt_img)
            # save_image(burst[N//2] - gt_img,"burst_sub_gt.png")
            # print_tensor_stats("aligned[N//2] - raw",aligned[N//2] - raw_img)
            # save_image(aligned[N//2] - raw_img,"aligned_sub_raw.png")
            # print_tensor_stats("aligned[N//2] - burst[N//2]",
            # aligned[N//2] - burst[N//2])
            # save_image(aligned[N//2] - burst[N//2],"aligned_sub_burst.png")
            # gt_img = torch.normal(raw_zm_img,noise_level/255.)

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #        Dataset Augmentation
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # burst,gt_img = apply_transformations(burst,gt_img)

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #      Formatting Images for FP
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
            cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #           Foward Pass
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            outputs = model(burst)
            m_aligned, m_aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #    Decrease Entropy within a Kernel
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            filters_entropy = 0
            filters_entropy_coeff = 0.  # 1000.
            all_filters = []
            L = len(align_hook.filters)
            iter_filters = align_hook.filters if L > 0 else [aligned_filters]
            for filters in iter_filters:
                f_shape = 'b n k2 c h w -> (b n c h w) k2'
                filters_shaped = rearrange(filters, f_shape)
                filters_entropy += one  #entropyLoss(filters_shaped)
                all_filters.append(filters)
            if L > 0: filters_entropy /= L
            all_filters = torch.stack(all_filters, dim=1)
            align_hook.clear()

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #   Reconstruction Losses (MSE)
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            losses = [F.mse_loss(denoised_ave, gt_img)]
            # losses = denoiseLossMSE(denoised,denoised_ave,gt_img,cfg.global_step)
            # losses = [ one, one ]
            # ave_loss,burst_loss = [loss.item() for loss in losses]
            rec_mse = np.sum(losses)
            # rec_mse = F.mse_loss(denoised_ave,gt_img)
            rec_mse_coeff = 1.

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #    Reconstruction Losses (Distribution)
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            gt_img_rep = gt_img.unsqueeze(1).repeat(1, denoised.shape[1], 1, 1,
                                                    1)
            residuals = denoised - gt_img_rep
            rec_ot = torch.FloatTensor([0.]).to(cfg.device)
            # rec_ot = kl_gaussian_bp(residuals,noise_level,flip=True)
            # rec_ot = kl_gaussian_bp_patches(residuals,noise_level,flip=True,patchsize=16)
            if torch.any(torch.isnan(rec_ot)):
                rec_ot = torch.FloatTensor([0.]).to(cfg.device)
            if torch.any(torch.isinf(rec_ot)):
                rec_ot = torch.FloatTensor([0.]).to(cfg.device)
            rec_ot_coeff = 0.

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #              Final Losses
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            rec_loss = rec_mse_coeff * rec_mse + rec_ot_coeff * rec_ot
            final_loss = rec_loss

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #              Record Keeping
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # -- reconstruction MSE --
            rec_mse_losses += rec_mse.item()
            rec_mse_count += 1

            # -- reconstruction Dist. --
            rec_ot_losses += rec_ot.item()
            rec_ot_count += 1

            # -- dynamic acc -
            dynamics_acc += dynamics_acc_i
            dynamics_count += 1

            # -- total loss --
            running_loss += final_loss.item()
            total_loss += final_loss.item()

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #
            #        Gradients & Backpropogration
            #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # -- compute the gradients! --
            if cfg.use_seed: torch.set_deterministic(False)
            final_loss.backward()
            if cfg.use_seed: torch.set_deterministic(True)

            # -- backprop now. --
            model.align_info.optim.step()
            model.denoiser_info.optim.step()
            model.unet_info.optim.step()
            scheduler.step()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #            Printing to Stdout
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- recompute model output for original images --
            outputs = model(burst_og)
            m_aligned, m_aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]

            # -- compute mse for fun --
            B = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)
            raw_img = get_nmlz_tgt_img(cfg, raw_img)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, m_aligned_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(burst_og, dim=0)
            if noise_type == "qis": mis_ave = quantize_img(cfg, mis_ave)
            mse_loss = F.mse_loss(raw_img, mis_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # tv_utils.save_image(raw_img,"raw.png",nrow=1,normalize=True,range=(-0.5,1.25))
            # tv_utils.save_image(mis_ave,"mis.png",nrow=1,normalize=True,range=(-0.5,1.25))

            # -- psnr for [bm3d] --
            mid_img_og = burst[N // 2]
            bm3d_nb_psnrs = []
            M = 4 if B > 4 else B
            for b in range(M):
                bm3d_rec = bm3d.bm3d(mid_img_og[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                # maybe an issue here
                b_loss = F.mse_loss(raw_img[b].cpu(),
                                    bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for input averaged frames --
            # burst_ave = torch.mean(burst_og,dim=0)
            # mse_loss = F.mse_loss(raw_img,burst_ave,reduction='none').reshape(B,-1)
            # mse_loss = torch.mean(mse_loss,1).detach().cpu().numpy()
            # psnr_input_ave = np.mean(mse_to_psnr(mse_loss))
            # psnr_input_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for aligned + denoised --
            R = denoised.shape[1]
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, R, 1, 1, 1)
            # if noise_type == "qis": denoised = quantize_img(cfg,denoised)
            # save_image(denoised_ave,"denoised_ave.png")
            # save_image(denoised,"denoised.png")
            mse_loss = F.mse_loss(raw_img_repN, denoised,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img, denoised_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- reconstruction MSE --
            rec_mse_ave = rec_mse_losses / rec_mse_count
            rec_mse_losses, rec_mse_count = 0, 0

            # -- reconstruction Dist. --
            rec_ot_ave = rec_ot_losses / rec_ot_count
            rec_ot_losses, rec_ot_count = 0, 0

            # -- ave dynamic acc --
            ave_dyn_acc = dynamics_acc / dynamics_count * 100.
            dynamics_acc, dynamics_count = 0, 0

            # -- write record --
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                record_losses = record_losses.append(info, ignore_index=True)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch,
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, ave_dyn_acc)  #rec_ot_ave)

            #print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e" % write_info)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [dyn]: %.2e"
                % write_info,
                flush=True)
            # -- write to summary writer --
            if writer:
                writer.add_scalar('train/running-loss', running_loss,
                                  cfg.global_step)
                writer.add_scalars('train/model-psnr', {
                    'ave': psnr,
                    'std': psnr_std
                }, cfg.global_step)
                writer.add_scalars('train/dn-frame-psnr', {
                    'ave': psnr_denoised_ave,
                    'std': psnr_denoised_std
                }, cfg.global_step)

            # -- reset loss --
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               all_filters, motion)

        if use_timer: clock.toc()

        if use_timer:
            print("data_clock", data_clock.average_time)
            print("clock", clock.average_time)
        cfg.global_step += 1

    # -- remove hooks --
    for hook in align_hooks:
        hook.remove()

    total_loss /= len(train_loader)
    return total_loss, record_losses
コード例 #10
0
def main():
    """Demo of LPAS burst alignment: 'simple' vs 'split' search.

    Builds a synthetic noisy burst from a single still image using global
    dynamics, runs the two LPAS search methods, then writes qualitative
    (images) and quantitative (PSNR) comparisons to disk.
    """

    #
    # -- init experiment --
    #

    cfg = edict()
    cfg.gpuid = 1
    cfg.noise_params = edict()
    cfg.noise_params.g = edict()
    # data = load_dataset(cfg)
    torch.manual_seed(143)  #131 = 80% vs 20%

    #
    # -- pick our noise --
    #

    # -- gaussian noise --
    # cfg.noise_type = 'g'
    # cfg.noise_params['g']['mean'] = 0.
    # cfg.noise_params['g']['stddev'] = 125.
    # cfg.noise_params.ntype = cfg.noise_type

    # -- poisson noise --
    cfg.noise_type = "pn"
    cfg.noise_params['pn'] = edict()
    cfg.noise_params['pn']['alpha'] = 1.0
    cfg.noise_params['pn']['std'] = 0.0
    cfg.noise_params.ntype = cfg.noise_type

    # -- low-light noise --
    # cfg.noise_type = "qis"
    # cfg.noise_params['qis'] = edict()
    # cfg.noise_params['qis']['alpha'] = 4.0
    # cfg.noise_params['qis']['readout'] = 0.0
    # cfg.noise_params['qis']['nbits'] = 3
    # cfg.noise_params['qis']['use_adc'] = True
    # cfg.noise_params.ntype = cfg.noise_type

    #
    # -- setup the dynamics --
    #

    cfg.nframes = 5
    cfg.frame_size = 350
    cfg.nblocks = 5
    T = cfg.nframes

    cfg.dynamic = edict()
    cfg.dynamic.frames = cfg.nframes
    cfg.dynamic.bool = True
    cfg.dynamic.ppf = 1  # pixels-per-frame of global motion
    cfg.dynamic.mode = "global"
    cfg.dynamic.random_eraser = False
    cfg.dynamic.frame_size = cfg.frame_size
    cfg.dynamic.total_pixels = cfg.dynamic.ppf * (cfg.nframes - 1)

    # -- setup noise and dynamics --
    noise_xform = get_noise_transform(cfg.noise_params, noise_only=True)

    def null(image):
        # identity "noise" so the dynamics transform returns clean frames
        return image

    dynamics_xform = get_dynamic_transform(cfg.dynamic, null)

    # -- sample data --
    image_path = "./data/512-512-grayscale-image-Cameraman.png"
    image = Image.open(image_path).convert("RGB")
    image = image.crop((0, 0, cfg.frame_size, cfg.frame_size))
    clean, res, raw, flow = dynamics_xform(image)
    clean = clean[:, None]  # add batch dim: (t, 1, c, h, w)
    burst = noise_xform(clean + 0.5)
    flow = flow[None, :]
    reference = repeat(clean[[T // 2]], '1 b c h w -> t b c h w', t=T)
    print("Flow")
    print(flow)

    # -- our method --
    ref_frame = T // 2
    nblocks = cfg.nblocks
    method = "simple"
    noise_info = cfg.noise_params
    scores, aligned_simp, dacc_simp = lpas_search(burst, ref_frame, nblocks,
                                                  flow, method, clean,
                                                  noise_info)

    # -- split search --
    ref_frame = T // 2
    nblocks = cfg.nblocks
    method = "split"
    noise_info = cfg.noise_params
    scores, aligned_split, dacc_split = lpas_search(burst, ref_frame, nblocks,
                                                    flow, method, clean,
                                                    noise_info)

    # -- quantitative comparison --
    crop_size = 256
    image1, image2 = cc(aligned_simp, crop_size), cc(reference, crop_size)
    psnrs = images_to_psnrs(image1, image2)
    print("Aligned Simple Method: ", psnrs, dacc_simp.item())
    image1, image2 = cc(aligned_split, crop_size), cc(reference, crop_size)
    psnrs = images_to_psnrs(image1, image2)
    print("Aligned Split Method: ", psnrs, dacc_split.item())

    # -- compute noise 2 sim --
    # T,K = cfg.nframes,cfg.nframes
    # patchsize = 31
    # query = burst[[T//2]]
    # database = torch.cat([burst[:T//2],burst[T//2+1:]])
    # clean_db = clean
    # sim_outputs = compute_similar_bursts_analysis(cfg,query,database,clean_db,K,-1.,
    #                                               patchsize=patchsize,shuffle_k=False,
    #                                               kindex=None,only_middle=False,
    #                                               search_method="l2",db_level="burst")
    # sims,csims,wsims,b_dist,b_indx = sim_outputs

    # -- display images --
    print(aligned_simp.shape)
    print(aligned_split.shape)
    print_tensor_stats("aligned", aligned_simp)

    # print(csims.shape)
    save_image(burst, "lpas_demo_burst.png", [-0.5, 0.5])
    save_image(clean, "lpas_demo_clean.png")

    save_image(aligned_simp, "lpas_demo_aligned_simp.png")
    save_image(aligned_split, "lpas_demo_aligned_split.png")
    save_image(cc(aligned_simp, crop_size), "lpas_demo_aligned_simp_ccrop.png")
    save_image(cc(aligned_split, crop_size),
               "lpas_demo_aligned_split_ccrop.png")

    # -- per-frame deltas w.r.t. the middle (reference) frame --
    delta_full_simp = aligned_simp - aligned_simp[T // 2]
    delta_full_split = aligned_split - aligned_split[T // 2]
    save_image(delta_full_simp, "lpas_demo_aligned_full_delta_simp.png",
               [-0.5, 0.5])
    save_image(delta_full_split, "lpas_demo_aligned_full_delta_split.png",
               [-0.5, 0.5])

    delta_cc_simp = cc(delta_full_simp, crop_size)
    delta_cc_split = cc(delta_full_split, crop_size)
    # FIX: previously the *full* deltas were saved under the "cc" filenames,
    # leaving delta_cc_simp/delta_cc_split unused; save the cropped deltas.
    save_image(delta_cc_simp, "lpas_demo_aligned_cc_delta_simp.png")
    save_image(delta_cc_split, "lpas_demo_aligned_cc_delta_split.png")

    # -- zoom into a small region for visual inspection --
    top = 75
    size = 64
    simp = tvF.crop(aligned_simp, top, 200, size, size)
    split = tvF.crop(aligned_split, top, 200, size, size)
    print_tensor_stats("delta", simp)
    save_image(simp, "lpas_demo_aligned_simp_inspect.png")
    save_image(split, "lpas_demo_aligned_split_inspect.png")

    delta_simp = simp - simp[T // 2]
    delta_split = split - split[T // 2]
    print_tensor_stats("delta", delta_simp)
    save_image(delta_simp, "lpas_demo_aligned_simp_inspect_delta.png",
               [-1, 1.])
    save_image(delta_split, "lpas_demo_aligned_split_inspect_delta.png",
               [-1, 1.])
コード例 #11
0
def test_sim_search_attn_v2(cfg, clean, model):
    """Score attention-feature similarity search across a noise-level grid.

    For each noise setting in the grid: noise the unfolded clean patches,
    reconstruct the per-pixel noisy image from patch centers, build an
    attention feature image, run the similarity search, and record PSNR
    statistics of the matched frames against the clean reference.
    """

    # -- setup --
    N, B, C, H, W = clean.shape
    ps = cfg.byol_patchsize

    # -- unfold clean burst into patch stacks on the gpu --
    patches = model.patch_helper.prepare_burst_patches(clean)
    patches = patches.cuda(non_blocking=True)
    # R,N,B,L,C,H,W = patches.shape

    def _entry(pix, ftr):
        # small helper: pack a (pix, ftr) pair into the search's edict format
        e = edict()
        e.pix = pix
        e.ftr = ftr
        e.shape = e.pix.shape
        return e

    psnrs = {}
    for noise_params in create_noise_level_grid(cfg):

        # -- configure the noise transform for this grid point --
        cfg.noise_type = noise_params.ntype
        cfg.noise_params.ntype = cfg.noise_type
        cfg.noise_params[cfg.noise_type] = noise_params
        noise_func = get_noise_transform(cfg.noise_params, use_to_tensor=False)

        # -- noise the patches; shape = (r n b nh_size^2 c ps_B ps_B) --
        noisy_patches = noise_func(patches)

        # -- rebuild per-pixel noisy image from the center of each patch --
        f_mid = cfg.byol_nh_size**2 // 2
        p_mid = cfg.byol_patchsize // 2
        centers = noisy_patches[:, :, :, f_mid, :, p_mid, p_mid]
        noisy_img = rearrange(centers,
                              '(h w) n b c -> n b c h w',
                              h=cfg.frame_size)

        ftr_img = get_feature_image(cfg, noisy_patches, model, "attn")

        print("[ftr_img.shape]", ftr_img.shape)
        # print("[emd] PSNR: ",np.mean(images_to_psnrs(embeddings_0,embeddings_1)))
        # print("[ftr] PSNR: ",np.mean(images_to_psnrs(ftr_img_0,ftr_img_1)))

        # -- query = frame 0, database = frame 1 --
        query = _entry(noisy_img[[0]], ftr_img[[0]])
        database = _entry(noisy_img[[1]], ftr_img[[1]])
        clean_pix = clean[[1]]
        clean_db = _entry(clean_pix, clean_pix)

        sim_outputs = compute_similar_bursts_analysis(
            cfg,
            query,
            database,
            clean_db,
            1,
            patchsize=cfg.sim_patchsize,
            shuffle_k=False,
            kindex=None,
            only_middle=cfg.sim_only_middle,
            db_level='frame',
            search_method=cfg.sim_method,
            noise_level=None)

        # -- psnr statistics of matched frames vs clean reference --
        ref = clean[0]
        clean_sims = sim_outputs[1][0, :, 0]
        psnrs_np = images_to_psnrs(ref.cpu(), clean_sims.cpu())
        stats = edict()
        stats.psnrs = psnrs_np
        stats.ave = np.mean(psnrs_np)
        stats.std = np.std(psnrs_np)
        stats.min = np.min(psnrs_np)
        stats.max = np.max(psnrs_np)
        psnrs[noise_params.name] = stats
        # print(noise_params.name,psnrs[noise_params.name])

    return psnrs
コード例 #12
0
def test_sim_search_pix_v2(cfg, clean, model):
    """Score pixel-space similarity search across a noise-level grid.

    For each noise setting: noise the clean burst (via PIL round-trip),
    run the similarity search with pixels as both query and feature,
    and record PSNR statistics of the matches vs the clean reference.
    """

    # -- setup --
    N, B, C, H, W = clean.shape
    cleanBN = rearrange(clean, 'n b c h w -> (b n) c h w')
    to_pil = tvT.ToPILImage()  # stateless; safe to reuse across frames
    clean_pil = [
        to_pil(cleanBN[i] + 0.5).convert("RGB") for i in range(B * N)
    ]
    ps = cfg.byol_patchsize
    unfold = nn.Unfold(ps, 1, 0, 1)

    def _entry(pix, ftr):
        # small helper: pack a (pix, ftr) pair into the search's edict format
        e = edict()
        e.pix = pix
        e.ftr = ftr
        e.shape = e.pix.shape
        return e

    psnrs = {}
    for noise_params in create_noise_level_grid(cfg):

        # -- configure the noise transform and noise each frame --
        cfg.noise_type = noise_params.ntype
        cfg.noise_params.ntype = cfg.noise_type
        cfg.noise_params[cfg.noise_type] = noise_params
        noise_func = get_noise_transform(cfg.noise_params)
        noisyBN = torch.stack([noise_func(img) for img in clean_pil], dim=0)
        noisy = rearrange(noisyBN, '(b n) c h w -> n b c h w', b=B)

        # -- query = frame 0, database = frame 1; pixels double as features --
        query = _entry(noisy[[0]], noisy[[0]])
        database = _entry(noisy[[1]], noisy[[1]])
        clean_pix = clean[[1]]
        clean_db = _entry(clean_pix, clean_pix)

        sim_outputs = compute_similar_bursts_analysis(
            cfg,
            query,
            database,
            clean_db,
            1,
            patchsize=cfg.sim_patchsize,
            shuffle_k=False,
            kindex=None,
            only_middle=cfg.sim_only_middle,
            db_level='frame',
            search_method=cfg.sim_method,
            noise_level=None)

        # -- psnr statistics of matched frames vs clean reference --
        ref = clean[0]
        clean_sims = sim_outputs[1][0, :, 0]
        psnrs_np = images_to_psnrs(ref.cpu(), clean_sims.cpu())
        stats = edict()
        stats.psnrs = psnrs_np
        stats.ave = np.mean(psnrs_np)
        stats.std = np.std(psnrs_np)
        stats.min = np.min(psnrs_np)
        stats.max = np.max(psnrs_np)
        psnrs[noise_params.name] = stats
        # print(noise_params.name,psnrs[noise_params.name])

    return psnrs
コード例 #13
0
def test_sim_search_ftr(cfg, clean, model, ftr_types):
    """Score similarity search for several feature types across a noise grid.

    For each noise setting and each feature type in `ftr_types`, compute a
    feature image from noised patches, run the similarity search, and record
    PSNR summary statistics.

    Returns an edict mapping ftr_type -> noise name -> PSNR stats.
    """

    # -- init --
    N, B, C, H, W = clean.shape
    # FIX: was `clean += 0.5`, which mutated the caller's tensor in place;
    # shift out-of-place so the argument is left untouched.
    if clean.min() < 0:
        clean = clean + 0.5  # non-negative pixels

    # -- unfold clean image --
    patches = model.patch_helper.prepare_burst_patches(clean)
    patches = patches.cuda(non_blocking=True)

    # shape = (r n b nh_size^2 c ps_B ps_B)

    # -- start loop --
    psnrs = edict({})
    for ftr_type in ftr_types:
        psnrs[ftr_type] = edict({})
    noisy_grid = create_noise_level_grid(cfg)
    with torch.no_grad():
        for noise_params in noisy_grid:

            # -- setup noise xform --
            cfg.noise_type = noise_params.ntype
            cfg.noise_params.ntype = cfg.noise_type
            cfg.noise_params[cfg.noise_type] = noise_params
            noise_func = get_noise_transform(cfg.noise_params, noise_only=True)

            # -- apply noise --
            noisy_patches = noise_func(
                patches)  # shape = (r n b nh_size^2 c ps_B ps_B)

            # -- create noisy img --
            noisy_img = get_pixel_features(cfg, noisy_patches)

            # -- get features --
            for ftype in ftr_types:
                ftr_img = get_feature_image(cfg, noisy_patches, model, ftype)

                # -- some debugging code --
                vis = False
                if vis:
                    vis_noisy_features(cfg, noisy_img, ftr_img, clean, ftype)

                testing_indexing = False
                if testing_indexing:
                    test_patch_helper_indexing(cfg, noisy_img, ftr_img, clean,
                                               ftype)

                # -- construct similar image --
                if ftype != "pix":
                    # non-pixel features search at patchsize 1; restore the
                    # configured value even if the search raises (try/finally).
                    sim_patchsize = cfg.sim_patchsize
                    cfg.sim_patchsize = 1
                    try:
                        psnrs_np = compute_similar_psnr(
                            cfg, noisy_img, ftr_img, clean)
                    finally:
                        cfg.sim_patchsize = sim_patchsize
                else:
                    psnrs_np = compute_similar_psnr(cfg, noisy_img, ftr_img,
                                                    clean)

                # -- compute psnr --
                psnrs[ftype][noise_params.name] = edict()
                psnrs[ftype][noise_params.name].psnrs = psnrs_np
                compute_psnrs_summary(psnrs[ftype][noise_params.name])
                del ftr_img

    return psnrs