Code example #1
def plots_exp_v2():
    epochs = [0]
    skel = "/home/gauenk/Documents/experiments/cl_gen/output/n2n-kpn/dynamic_wmf/64_2_6/50000/nonblind/3/25.0/record_losses_26.csv"
    # skel = "/home/gauenk/Documents/experiments/cl_gen/output/n2n-kpn/dynamic_wmf/64_2_6/50000/nonblind/3/25.0/record_supOT_withweight_3.csv"
    # skel = "/home/gauenk/Documents/experiments/cl_gen/output/n2n-kpn/dynamic_wmf/64_2_6/50000/nonblind/3/25.0/record_withweight_15.csv"
    fps = [skel.format(epoch) for epoch in epochs]  # note: skel has no "{}" placeholder, so format() is a no-op here

    pds = [pd.read_csv(fp, index_col=0) for fp in fps]

    df = pds[0]
    df = df.reset_index(drop=True)
    # df = df.loc[:,~df.columns.str.contains('^psnr')]
    # df = df.loc[:,["ot_loss_rec_frame","ot_loss_rec_frame_w","ot_loss_raw_frame","ot_loss_raw_frame_w","kpn_loss","psnr_ave","psnr_std"]]
    df = df.loc[:, ["ot_loss_rec_frame", "ot_loss_raw_frame", "kpn_loss"]]
    print(df)
    cols = list(df.columns)
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_ylabel("loss")
    df.plot(ax=ax, logy=True)
    add_legend(ax, "Losses", cols)

    ax2 = ax.twinx()
    ax2.set_ylabel("psnr")
    psnr_ave = pds[0].loc[:, "psnr_ave"]
    ax2.plot(np.arange(len(psnr_ave)), psnr_ave, lw=2, color='black')
    cols += ['psnr']

    # pds[0].reset_index().plot(x="index",y="ot",ax=ax)
    # pds[0].reset_index().plot(x="index",y="kpn",ax=ax)
    # pds[0].reset_index().plot(x="index",y="psnr",ax=ax)

    # pds[0].reset_index().plot(x="index",y="kpn",'-+')
    plt.savefig("./sup_kpn_tracking.png", bbox_inches='tight', dpi=300)
Code example #2
File: misc.py  Project: gauenk/cl_gen
def plot_frame_grid(results):
    """
    Plot PSNR versus frame count, with one line per noise level ("sigma").
    """
    fig,ax = plt.subplots(figsize=(10,8))
    results = results.sort_values("noise_level")
    blind_results = results[results['blind'] == 'blind']
    nonblind_results = results[results['blind'] == 'nonblind']
    b_legend_str,b_min_val,b_max_val = plot_groupby_noise_ax(blind_results,ax,'-x')
    plt.gca().set_prop_cycle(None)
    nb_legend_str,nb_min_val,nb_max_val = plot_groupby_noise_ax(nonblind_results,ax,'--+')
    min_val = np.min([b_min_val,nb_min_val])
    max_val = np.max([b_max_val,nb_max_val])
    legend_str = [x+'-n' for x in b_legend_str] + [x + '-c' for x in nb_legend_str]

    
    frame_levels = np.sort(results['frames'].unique())
    frame_names = ["{:d}".format(int(y)) for y in frame_levels]
    ax.set_xticks(np.log(frame_levels))
    ax.set_xticklabels(frame_names,fontsize=13)
    
    ytick_locs = np.linspace(min_val,max_val,4)
    ytick_names = ["{:2d}".format(int(y)) for y in ytick_locs]
    ax.set_yticks(ytick_locs)
    ax.set_yticklabels(ytick_names,fontsize=13)

    ax.set_xlabel("log(N)",fontsize=13)
    ax.set_ylabel("PSNR",fontsize=13)
    ax.set_title("PSNR Across Frame Count",fontsize=15)
    add_legend(ax,"Noise Level",legend_str,shrink = True,fontsize=12,
               framealpha=1.0,ncol=2,shrink_perc=.80)
    plt.savefig(f"{settings.ROOT_PATH}/output/n2n/frame_grid.png")
Code example #3
File: vis.py  Project: gauenk/cl_gen
def plot_ave_nframes_scoreacc_noisetype(record):
    
    nframes = 3
    nblocks = 5
    ppf = 1.0
    score_function = ['ave','pairwise','refcmp']
    noise_type_codes,noise_type_unique = pd.factorize(record['noise_type'])
    record['noise_type_codes'] = noise_type_codes

    # ['score_function','patchsize','noise_type','nframes','nblocks','ppf']

    record = record[record['nblocks'] == nblocks]
    record = record[np.isclose(record['ppf'].to_numpy(),ppf)]
    record = record[record['score_function'] == 'ave']
    record = record[record['patchsize'] == 13]
    

    fig,ax = plt.subplots(2,figsize=(8,8))
    labels = []
    for nframes,df in record.groupby('nframes'):

        # -- compute acc --

        tcodes,tnoises,noise_acc,clean_acc = [],[],[],[]
        for tnoise,noise_df in df.groupby("noise_type"):
            noise_index = noise_df['noisy_best_block']
            clean_index = noise_df['clean_best_block']
            align_index = noise_df['align_best_block']
            tcode = noise_df['noise_type_codes'].to_numpy()[0]
            noise_acc.append(np.mean(noise_index == align_index))
            clean_acc.append(np.mean(clean_index == align_index))
            tnoises.append(tnoise)
            tcodes.append(tcode)

        # -- plot --
        ax[0].plot(tcodes,noise_acc,label=f'{int(nframes)}')
        ax[1].plot(tcodes,clean_acc,label=f'{int(nframes)}')

        ax[0].set_xticks(tcodes)
        ax[0].set_xticklabels(tnoises)

        ax[1].set_xticks(tcodes)
        ax[1].set_xticklabels(tnoises)
        labels.append(str(int(nframes)))

    add_legend(ax[0],"Frames",labels)
    add_legend(ax[1],"Frames",labels)
    plot_fn = ROOT_PATH_VIS / "ave_nframes_scoreacc_noisetype.png"
    print(f"Saving figure to path [{plot_fn}]")
    plt.savefig(plot_fn)
    plt.close("all")
Code example #4
def main():

    exps = get_exps_a()
    nexps = len(exps)
    fig,ax = plt.subplots(1,1,figsize=(8,8))
    lnames = []
    for name,info in exps.items():
        print(name)
        results = get_exp_results(info)
        epochs,losses = zip(*results.items())
        ax.plot(epochs,losses,'o-',label=name,alpha=0.5)
        # epochs,means = zip(*results['means'].items())
        # epochs,stderrs = zip(*results['stderrs'].items())
        # ax.errorbar(epochs,means,yerr=stderrs,'o-',label=name,alpha=0.5)
        lnames.append(name)
    add_legend(ax,"model",lnames)
    plt.savefig("exp_plot_test_results_a.png")

    plt.cla()
    plt.clf()
Code example #5
def plot_noise_grid(results):
    """
    Plot PSNR versus noise level (sigma), with one line per frame count.
    """
    fig, ax = plt.subplots(figsize=(8, 8))

    min_val, max_val = 1000, -1
    legend_str = []
    for frames, frames_df in results.groupby('frames'):
        frames_df = frames_df.sort_values('noise_level')
        noise_levels = frames_df['noise_level'].to_numpy()
        psnr = frames_df['psnr'].to_numpy()
        ax.plot(noise_levels, psnr)
        legend_str.append(str(int(frames)))
        print(psnr.min(), psnr.max())
        if min_val > psnr.min():
            min_val = psnr.min()
        if max_val < psnr.max():
            max_val = psnr.max()

    noise_levels = np.sort(results['noise_level'].unique())
    noise_names = ["{:d}".format(int(y)) for y in noise_levels]
    # noise_names[0] = 'msg'
    ax.set_xticks(noise_levels)
    ax.set_xticklabels(noise_names, fontsize=13)

    ytick_locs = np.linspace(min_val, max_val, 4)
    ytick_names = ["{:2d}".format(int(y)) for y in ytick_locs]
    ax.set_yticks(ytick_locs)
    ax.set_yticklabels(ytick_names, fontsize=13)

    ax.set_xlabel("sigma", fontsize=13)
    ax.set_ylabel("PSNR", fontsize=13)
    ax.set_title("PSNR Across Noise Level", fontsize=15)
    add_legend(ax,
               "Frame Count",
               legend_str,
               shrink=True,
               fontsize=13,
               framealpha=1.0)
    plt.savefig(f"{settings.ROOT_PATH}/output/n2n/noise_grid.png")
Code example #6
def plot_histogram_residual_burst(burst,filename,global_step):

    # -- init --
    burst = rearrange(burst,'n c h w -> n (c h w)')
    N,D = burst.shape
    labels = [str(i) for i in range(N)]
    fig,ax = plt.subplots()
    
    # -- frame index --
    amin,amax = burst.min(),burst.max()
    for frame_index in range(N):
        frame = burst[frame_index]
        ax.hist(frame,label=f"{frame_index}",bins=30,range=(amin,amax),alpha=0.5,lw=1,edgecolor='k')
    ax = add_legend(ax,"Frame Index",labels)
    ax.set_title(f"Residuals at Iteration [{global_step}]")

    # -- save filename --
    plt.savefig(filename,dpi=300,bbox_inches='tight')
    plt.close("all")
Code example #7
File: plot_test_results.py  Project: gauenk/cl_gen
def plot_x_sigma(df):
    gb_groups = df.groupby('agg_enc_fxn')
    fig, ax = plt.subplots(2, 2, figsize=(10, 10))
    ax_titles = []
    plt_fmt = {2: '-+r', 4: '-+g', 8: '-+b', 12: '-+k'}
    for i, (agg_enc_fxn, group_i) in enumerate(gb_groups):
        # different plot
        for j, (hyper_h, group_j) in enumerate(group_i.groupby('hyper_h')):
            # different plot
            labels = []
            for k, (N, group_k) in enumerate(group_j.groupby('N')):
                # same plot
                group_k = group_k.sort_values('sigma')
                data = group_k[['sigma', 'te_acc']].to_numpy()
                ax[i, j].plot(data[:, 0], data[:, 1], plt_fmt[N])
                labels.append(str(N))
            ax[i, j].set_title(
                f"agg_enc_fxn: {agg_enc_fxn} | CL Loss Coefficient: {hyper_h}")
            ax[i, j].set_ylabel("PSNR")
            ax[i, j].set_xlabel("sigma")
            ax[i, j] = add_legend(ax[i, j], "N", labels, fontsize=12)
    plt.savefig("./tmp.png")
Code example #8
File: noise_level.py  Project: gauenk/cl_gen
def main():
    print("HI")

    frames = [1, 2, 3, 5, 8, 10, 15, 20]
    # frames = [1,2,3,10,15,30,50]
    noise_levels = [1, 3, 5, 10, 15, 25, 50, 75, 100]
    size = 256 * 256
    repeats = 5
    ave_psnrs, std_psnrs = [], []
    df = pd.DataFrame({"ave": [], "std": [], "noise_level": [], "frames": []})
    for n_frames in frames:
        for noise_level in noise_levels:
            psnrs = []
            for r in range(repeats):
                psnrs.append(
                    compute_sample_psnr(noise_level / 255., n_frames, size))
            ave_psnr = np.mean(psnrs)
            std_psnr = np.std(psnrs)
            # note: DataFrame.append was removed in pandas 2.0; pd.concat is the modern replacement
            df = df.append(
                {
                    "ave": ave_psnr,
                    "std": std_psnr,
                    "noise_level": noise_level,
                    "frames": n_frames
                }, True)

    print(df)

    # -- plot data: frame label --
    fig, ax = plt.subplots(figsize=(8, 8))
    frames = []
    for n_frames, df_nl in df.groupby("frames"):
        frames.append(str(n_frames))
        nl = df_nl['noise_level']
        ave = df_nl['ave']
        std = df_nl['std']
        ax.errorbar(nl, ave, yerr=std, label=f"{n_frames}")
    # plt.errorbar(noise_levels,ave_psnrs,yerr=std_psnrs,label="N=1")
    add_legend(ax, "Frames", frames, fontsize=12)
    ax.set_ylabel("PSNR", fontsize=12)
    ax.set_xlabel("Noise Level", fontsize=12)
    ax.set_title("Averaging Gaussian Samples")
    plt.savefig("./output/awgn_noise_levels.png", dpi=300)
    plt.clf()
    plt.cla()
    plt.close("all")

    # -- plot data: noise level label --
    fig, ax = plt.subplots(figsize=(8, 8))
    nls = []
    for nl, df_frames in df.groupby("noise_level"):
        nls.append(str(nl))
        frames = df_frames['frames']
        ave = df_frames['ave']
        std = df_frames['std']
        ax.errorbar(frames, ave, yerr=std)
    add_legend(ax, "NoiseLevel", nls, fontsize=12)
    ax.set_ylabel("PSNR", fontsize=12)
    ax.set_xlabel("Frames", fontsize=12)
    ax.set_title("Averaging Gaussian Samples")
    plt.savefig("./output/awgn_frames.png", dpi=300)
Code example #9
def explore_cog_record(cfg, record, bss_dir=None):
    REF_H = get_ref_block_index(cfg.nblocks)
    # cfg.nblocks,cfg.nframes = 5,7

    # -- load block search space --
    tcount = 3
    size = 30
    bss = get_small_test_block_arangements(bss_dir,
                                           cfg.nblocks,
                                           cfg.nframes,
                                           tcount,
                                           size,
                                           difficult=True)
    block_search_space = bss
    # if block_search_space_fn:
    #     block_search_space = np.load(block_search_space_fn,allow_pickle=True)
    # else:
    #     block_search_space = get_block_arangements(cfg.nblocks,cfg.nframes)

    # -- print un-modified results --
    for ridx, field in enumerate(record.columns):
        if field == "Unnamed: 0": continue
        # -- printing best index and score --
        index = np.argmin(record[field])
        search = block_search_space[index]
        if torch.is_tensor(search): search = search.numpy()
        max_index = np.argmax(record[field])
        max_search = block_search_space[max_index]
        max_value = record[field][max_index]

        min_index = np.argmin(record[field])
        min_search = block_search_space[min_index]
        min_value = record[field][min_index]

        print(field)
        print("\t", max_search, max_index, max_value)
        print("\t", min_search, min_index, min_value)

    # -- prepare data for plotting --
    P = len(record)
    pgrid = np.arange(P)
    psnr_clean = jitter(zo_nmlz(record['psnr_clean'].to_numpy()))
    psnr_noisy = jitter(zo_nmlz(record['psnr_noisy'].to_numpy()))
    psnr_rec = jitter(zo_nmlz(record['psnr_rec'].to_numpy()))
    cog = jitter(zo_nmlz(record['cog'].to_numpy()))
    ave = jitter(zo_nmlz(record['ave'].to_numpy()))

    # -- plot --
    naxs, K = 2, 3
    rm_fields = ["psnr_clean", "cog", "ave", "psnr_noisy", "psnr_rec"]
    ncols = len(record.columns) - len(rm_fields)
    fig, ax = plt.subplots(naxs, 1, figsize=(12, 10))
    plt.subplots_adjust(hspace=0.3)
    ax[0].plot(pgrid, psnr_clean, 'x', label='clean')
    ax[0].plot(pgrid, psnr_noisy, 'x', label='noisy')
    ax[0].plot(pgrid, psnr_rec, '+', label='rec')
    ax[0].plot(pgrid, cog, '+', label='cog')
    ax[0].plot(pgrid, ave, '+', label='ave')
    add_legend(ax[0], "cmp", ['clean', 'noisy', 'rec', 'cog', 'ave'])
    idx, ax_idx, ax_mod = 0, 1, 100  #ncols//(naxs-1)+1
    labels = [[] for _ in range(naxs - 1)]
    print(ax_mod)
    for ridx, field in enumerate(record.columns):

        # -- plotting search type --
        if field in rm_fields: continue
        search = re.match("[0-9]f", field[-2:])
        if search is not None: continue

        scores = record[field].to_numpy()
        if not isinstance(scores[0], float): continue
        scores = jitter(zo_nmlz(scores))
        # ax[ax_idx].plot(pgrid,scores,'x-',alpha=0.8)
        extrema, topK, bottomK = extrema_scores(scores, K)
        print(idx)
        if idx < 8: mrk = 'x'
        else: mrk = '+'
        if "fu_" in field or "fnet_" in field: mrk += '-'
        vorder = np.argsort(extrema.values)
        print(np.abs(extrema.values[vorder[0]] - extrema.values[vorder[1]]))
        iorder = np.argsort(extrema.grid)
        grid = extrema.grid[iorder]
        values = extrema.values[iorder]

        ax[ax_idx].plot(grid, values, mrk, alpha=0.8, label=field)
        if "fu_" in field:
            lfield = field.replace("fu_", "fnet_")
        else:
            lfield = field
        labels[ax_idx - 1].append(lfield)
        idx += 1
        if idx % ax_mod == 0:
            idx = 0
            ax_idx += 1
            print(idx, ax_idx, ax_mod)
    for i in range(naxs - 1):
        add_legend(ax[i + 1], "cmp", labels[i])

    plot_dir = Path("./output/lpas/cog/")
    if not plot_dir.exists(): plot_dir.mkdir()
    plt.savefig(plot_dir / Path("cog_psnrs_extrema.png"))
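
The zo_nmlz and jitter helpers applied to each column above are not shown on this page. A plausible reading, assuming they perform zero-one normalization plus a small random offset so overlapping markers stay distinguishable:

import numpy as np

def zo_nmlz(x):
    """Assumed behavior: rescale an array to the [0, 1] range."""
    x = np.asarray(x, dtype=float)
    return (x - x.min()) / (x.max() - x.min() + 1e-12)

def jitter(x, scale=0.01):
    """Assumed behavior: add small uniform noise so co-located points stay
    visible under the 'x'/'+' markers (the scale here is a guess)."""
    x = np.asarray(x, dtype=float)
    return x + np.random.uniform(-scale, scale, size=x.shape)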
Code example #10
File: plot_ablation.py  Project: gauenk/cl_gen
def main():
    exps = get_exps()

    gpuid = 0
    #epoch_grid = [-1,0,5,10,15,20,25,30,35,40,45,50,60,70,80,90,100,125,150,175,200,225]
    # epoch_grid = [10,15,20,25,30,35,40,45,50,60,70,80,90,100]
    epoch_grid = [5, 25, 50, 75, 100, 150, 200, 250, 300]
    dataset = "MNIST"
    use_psnr = False

    # get results
    results = {}
    for name, fields in exps.items():
        if len(name) == 0: continue

        results[name] = edict()
        results[name].means = {}
        results[name].stderrs = {}
        results[name].te_losses = {}

        for epoch_num in epoch_grid:
            losses = load_cached_results(name, epoch_num, use_psnr=use_psnr)
            if losses is not None:
                sepoch = str(epoch_num)
                results[name].means[sepoch] = losses.mean
                results[name].stderrs[sepoch] = losses.stderr
                results[name].te_losses[sepoch] = losses.te_losses
                continue

            cfg = get_exp_cfg(name, gpuid, epoch_num, dataset, fields)
            valid = check_valid(cfg, epoch_num)
            if not valid: continue

            losses = test_disent(cfg, use_psnr=use_psnr)
            sepoch = str(epoch_num)
            results[name].means[sepoch] = losses.mean
            results[name].stderrs[sepoch] = losses.stderr
            results[name].te_losses[sepoch] = losses.te_losses

            save_cache_results(losses, name, epoch_num, use_psnr)

        # report best loss for each experiment
        print(name, results[name].means.items())
        egrid, means = zip(*results[name].means.items())
        if use_psnr:
            idx = np.argmax(means)
            rstr = "Exp {:s} best psnr of {:2.3e} at epoch {:s}"
        else:
            idx = np.argmin(means)
            rstr = "Exp {:s} best test loss of {:2.3e} at epoch {:s}"
        fmt = (name, means[idx], egrid[idx])
        print(rstr.format(*fmt))

    # plot results
    fig, ax = plt.subplots(figsize=(8, 8))
    names = []
    for name, result in results.items():
        egrid, means = zip(*result.means.items())
        egrid, stderrs = zip(*result.stderrs.items())
        grid = [int(e) for e in egrid]
        ax.errorbar(grid, means, yerr=stderrs, fmt='+-', label=name, alpha=1.0)
        names.append(name[0:5])
    # plot_noise_level(ax,epoch_grid)
    # names.append("noise")
    ax.set_title("Ablation Experiments: Testing Losses")
    add_legend(ax, "Experiment", names, framealpha=0.0)
    ax.set_yscale("log", nonpositive="clip")
    path = f"{settings.ROOT_PATH}/reports/plot_ablation.png"
    print(f"Saving plot to {path}")
    plt.savefig(path, transparent=True, dpi=300)
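
load_cached_results and save_cache_results are project helpers not shown here. One possible way to implement the per-epoch cache they imply is sketched below; the directory layout and file naming are assumptions, not the project's actual scheme.

import pickle
from pathlib import Path

CACHE_DIR = Path("./cache/plot_ablation")  # hypothetical cache location

def save_cache_results(losses, name, epoch_num, use_psnr):
    """Pickle the losses object under a key built from the experiment name,
    the epoch, and the psnr/loss flag."""
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    fn = CACHE_DIR / f"{name}_{epoch_num}_{int(use_psnr)}.pkl"
    with open(fn, "wb") as f:
        pickle.dump(losses, f)

def load_cached_results(name, epoch_num, use_psnr=False):
    """Return the cached losses object, or None if no cache entry exists."""
    fn = CACHE_DIR / f"{name}_{epoch_num}_{int(use_psnr)}.pkl"
    if not fn.exists():
        return None
    with open(fn, "rb") as f:
        return pickle.load(f)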
Code example #11
File: explore_fast_unet.py  Project: gauenk/cl_gen
def explore_fast_unet_record(cfg, record, block_search_space_fn=None):
    REF_H = get_ref_block_index(3)
    cfg.nblocks, cfg.nframes = 5, 7

    # -- load block search space --
    if block_search_space_fn:
        block_search_space = np.load(block_search_space_fn, allow_pickle=True)
    else:
        block_search_space = get_block_arangements(cfg.nblocks, cfg.nframes)

    # -- prepare data for plotting --
    P = len(record)
    mse = record['mse'].to_numpy()
    psnrs_clean = record['psnr_clean'].to_numpy()
    psnrs_bc_v1 = record['psnr_bc_v1'].to_numpy()
    psnrs_rec = record['psnr_rec'].to_numpy()
    psnrs_burst = record['psnr_burst'].to_numpy()
    psnrs_noisy = record['psnr_noisy'].to_numpy()
    fnet_ave = record['fnet_ave'].to_numpy()
    ave = record['ave'].to_numpy()
    fnet_lvi = record['fnet_lgsubset_v_indices'].to_numpy()
    lvi = record['lgsubset_v_indices'].to_numpy()
    # pii = None
    pii = record['psnr_intra_input'].to_numpy()
    got = record['fnet_gaussian_ot'].to_numpy()
    params_nm = record['params_norm_mean'].to_numpy()
    # trace_norm = None
    trace_norm = record['params_norm_mean'].to_numpy()

    # -- rescale for plotting --
    mse -= mse.min()
    mse /= mse.max()
    mse *= psnrs_rec.max()

    pii -= pii.min()
    pii /= pii.max()
    pii *= ave.max()

    params_nm -= params_nm.min()
    params_nm /= params_nm.max()
    params_nm *= ave.max()

    # -- rescale for plotting --
    # got /= got.max()
    # got *= ave.max()

    # -- plot --
    fig, ax = plt.subplots(3, 1, figsize=(8, 8))
    ax[0].plot(np.arange(P), psnrs_clean, 'x', label='clean')
    ax[1].plot(np.arange(P), psnrs_bc_v1, '+', label='bc')
    ax[1].plot(np.arange(P), psnrs_rec, '+', label='rec')
    # ax[1].plot(np.arange(P),mse,'+',label='mse')
    ax[1].plot(np.arange(P), psnrs_burst, '+', label='burst')
    ax[1].plot(np.arange(P), psnrs_noisy, '+', label='noisy')
    labels_2 = []
    ax[2].plot(np.arange(P), params_nm, '*', label='params_nm')
    labels_2.append('params_nm')
    if not (trace_norm is None):
        ax[2].plot(np.arange(P), trace_norm, '*', label='trace_norm')
        labels_2.append('trace_norm')
    ax[2].plot(np.arange(P), ave, '+', label='ave')
    labels_2.append('ave')
    ax[2].plot(np.arange(P), fnet_ave, '+', label='fnet_ave')
    labels_2.append('fnet_ave')
    ax[2].plot(np.arange(P), lvi, '^', label='lvi')
    labels_2.append('lvi')
    ax[2].plot(np.arange(P), fnet_lvi, '^', label='fnet_lvi')
    labels_2.append('fnet_lvi')
    ax[2].plot(np.arange(P), got, 'x', label='got')
    labels_2.append('got')
    if not (pii is None):
        ax[2].plot(np.arange(P), pii, 'x', label='pii')
        labels_2.append('pii')
    add_legend(ax[1], "cmp", ['bc_v1', 'rec', 'burst', 'noisy'])
    add_legend(ax[2], "srch_fxn", labels_2)
    plt.savefig("./output/lpas/fast_unet/psnrs.png")

    for field in record.columns:
        index = np.argmin(record[field])
        search = block_search_space[index]
        print(field, search, index)
        index = np.argmax(record[field])
        search = block_search_space[index]
        print(field, search, index)
Code example #12
File: ot_loss.py  Project: gauenk/cl_gen
def run_ot_v_displacement(cfg, criterion, train_loader):

    reg = 1.0
    for batch_idx, (burst, residuals, raw_img,
                    directions) in enumerate(train_loader):

        # -- init --
        N, B, C, H, W = burst.shape
        burst = burst.cuda(non_blocking=True)
        raw_img = raw_img.cuda(non_blocking=True) - 0.5

        # -- compute residuals --
        misaligned_residual = rearrange(
            burst - raw_img.unsqueeze(0).repeat(N, 1, 1, 1, 1),
            'n b c h w -> b n (h w) c')

        # -- save img --
        res = burst - raw_img.unsqueeze(0).repeat(N, 1, 1, 1, 1)
        res = rearrange(torch.cat([res, raw_img.unsqueeze(0)], dim=0),
                        'n b c h w -> b n c h w')
        burst_r = rearrange(torch.cat([burst, raw_img.unsqueeze(0)], dim=0),
                            'n b c h w -> b n c h w')
        img = rearrange(burst_r, 'b n c h w -> (b n) c h w')
        # img = rearrange(torch.cat([burst_r,res],dim=1),'b n c h w -> (b n) c h w')
        fn = f"./test_ot_v_displacement_ex_{batch_idx}.png"
        tv_utils.save_image(img, fn, nrow=B, normalize=True)
        print(f"Wrote images to [{fn}]")

        # print("25/255 = {:.2e}".format(25/255))
        # for b in range(B):
        #     for n in range(N):
        #         print("({:d},{:d}): {:.2e} +/- {:.2e}".format(b,n,misaligned_residual[b,n].mean(),misaligned_residual[b,n].std()))

        # -- compute ot for pairwise of misaligned images on sides --
        reg = 0.25
        losses = pd.DataFrame({
            'd': [],
            'ot_ave': [],
            'ot_std': [],
            'kl_ave': [],
            'kl_std': []
        })
        ot_losses = np.zeros((B, N // 2 + N % 2))
        kl_losses = np.zeros((B, N // 2 + N % 2))
        for n in range(N // 2 + N % 2):
            b_ot_losses, b_kl_losses = [], []
            for b in range(B):
                #ot_loss = compute_pair_ot(misaligned_residual[b,n],misaligned_residual[b,N//2],reg).item()
                a = misaligned_residual[b, n]
                noise = torch.zeros_like(a)
                # noise = torch.normal(torch.zeros_like(a),std=0.001/255)
                ot_loss = compute_pair_ot(misaligned_residual[b, n], noise,
                                          reg).item()
                kl_loss = compute_binned_kl(misaligned_residual[b, n],
                                            misaligned_residual[b, N //
                                                                2]).item()
                ot_losses[b, n] = ot_loss
                kl_losses[b, n] = kl_loss
                b_ot_losses.append(ot_loss), b_kl_losses.append(kl_loss)
            append_loss = {
                'd': int(cfg.dynamic.ppf * (N // 2 - n)),
                'ot_ave': np.mean(b_ot_losses),
                'ot_std': np.std(b_ot_losses),
                'kl_ave': np.mean(b_kl_losses),
                'kl_std': np.std(b_kl_losses),
            }
            # note: DataFrame.append was removed in pandas 2.0; pd.concat is the modern replacement
            losses = losses.append(append_loss, ignore_index=True)

        print(losses)
        # -- plot losses vs. displacement --
        fig, ax = plt.subplots()
        ax.errorbar(losses['d'],
                    losses['ot_ave'],
                    yerr=losses['ot_std'],
                    label='ot')
        ax.errorbar(losses['d'],
                    losses['kl_ave'],
                    yerr=losses['kl_std'],
                    label='kl')
        ax.set_title("Verifying Impact of Displacement")
        ax.set_xlabel("Number of Pixels")
        ax.set_ylabel("Loss")
        ax = add_legend(ax, 'Type', ['ot', 'kl'])
        fn = f"test_ot_v_displacement_plot_{batch_idx}.png"
        plt.savefig(fn, dpi=300)
        plt.close("all")
        print(f"Wrote ot loss image to [{fn}]")

        # -- plot top K losses vs displacement --
        K = 20

        b_ot_losses = np.sum(ot_losses, 1)
        ot_args = np.argsort(-b_ot_losses)[:K]
        ot_args_bk = np.argsort(b_ot_losses)[:K]
        ot_frame_ave = np.mean(ot_losses[ot_args, :], 0)
        ot_frame_std = np.std(ot_losses[ot_args, :], 0)

        b_kl_losses = np.sum(kl_losses, 1)
        kl_args = np.argsort(-b_kl_losses)[:K]
        kl_args_bk = np.argsort(b_kl_losses)[:K]
        kl_frame_ave = np.mean(kl_losses[kl_args, :], 0)
        kl_frame_std = np.std(kl_losses[kl_args, :], 0)

        print(ot_frame_ave.shape)
        print(ot_frame_std.shape)
        print(losses['d'])

        fig, ax = plt.subplots()
        ax.errorbar(losses['d'], ot_frame_ave, yerr=ot_frame_std, label='ot')
        ax.errorbar(losses['d'], kl_frame_ave, yerr=kl_frame_std, label='kl')
        ax.set_title("Verifying Impact of Displacement (Top-K)")
        ax.set_xlabel("Number of Pixels")
        ax.set_ylabel("Loss")
        ax = add_legend(ax, 'Type', ['ot', 'kl'])
        fn = f"test_ot_v_displacement_plot_topk_{batch_idx}.png"
        plt.savefig(fn, dpi=300)
        plt.close("all")
        print(f"Wrote plot file {fn}")

        # -- show the top K images for ot and kl --

        fn = f"./ot_args_test_displacement_topk_{batch_idx}.png"
        img = rearrange(burst_r[ot_args], 'b n c h w -> (b n) c h w')
        tv_utils.save_image(img, fn, nrow=N + 1, normalize=True)
        print(f"Wrote example ot images file {fn}")

        fn = f"./ot_args_test_displacement_bottomk_{batch_idx}.png"
        img = rearrange(burst_r[ot_args_bk], 'b n c h w -> (b n) c h w')
        tv_utils.save_image(img, fn, nrow=N + 1, normalize=True)
        print(f"Wrote example ot images file {fn}")

        fn = f"./kl_args_test_displacement_topk_{batch_idx}.png"
        img = rearrange(burst_r[kl_args], 'b n c h w -> (b n) c h w')
        tv_utils.save_image(img, fn, nrow=N + 1, normalize=True)
        print(f"Wrote example kl images file {fn}")

        fn = f"./kl_args_test_displacement_bottomk_{batch_idx}.png"
        img = rearrange(burst_r[kl_args_bk], 'b n c h w -> (b n) c h w')
        tv_utils.save_image(img, fn, nrow=N + 1, normalize=True)
        print(f"Wrote example kl images file {fn}")

        # -- compute ot for pairwise of misaligned images to middle --
        g1 = torch.normal(torch.zeros((64 * 64 * 3)), 25 / 255).view(-1, 3)
        g2 = torch.normal(torch.zeros((64 * 64 * 3)), 25 / 255).view(-1, 3)
        g1, g2 = g1.cuda(non_blocking=True), g2.cuda(non_blocking=True)
        ot_gg = compute_pair_ot(g1, g2, reg)

        print("[g_g: %.2e]" % (ot_gg))
        """
        compare 
        - iid gaussian and actual residuals [baseline]
        - wrong residual (misaligned - raw_img) v.s. another wrong residual
        - wrong residual (misaligned - raw_img) v.s. middle residual (burst[N//2] - raw_img)

        this looks like its say the noise on the input middle frame is not iid gaussian noise
        """

        # -- compute ot loss to optimize --
        # residuals = aligned - rec_img.unsqueeze(1).repeat(1,N,1,1,1)
        # residuals = rearrange(residuals,'b n c h w -> b n (h w) c')
        # ot_loss = ot_frame_pairwise_bp(residuals,reg=1.0,K=5)
        # ot_coeff = 1 - .997**cfg.global_step

        cfg.global_step += 1
    # note: `total_loss` and `record` are never defined in this function,
    # so these two lines would raise a NameError if executed as written.
    total_loss /= len(train_loader)
    return total_loss, record
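
compute_pair_ot above compares two residual point clouds with an entropy-regularized optimal-transport cost. A minimal sketch of that idea using the POT library (ot.sinkhorn2 on uniform empirical measures) follows; the project's actual implementation may differ, and compute_binned_kl is not sketched here.

import ot      # POT: Python Optimal Transport
import torch

def compute_pair_ot(x, y, reg):
    """Sketch: entropy-regularized OT (Sinkhorn) cost between two point clouds
    x and y of shape (num_points, channels), treated as uniform empirical measures."""
    x_np = x.detach().cpu().numpy()
    y_np = y.detach().cpu().numpy()
    a, b = ot.unif(x_np.shape[0]), ot.unif(y_np.shape[0])
    M = ot.dist(x_np, y_np)                      # squared Euclidean cost matrix
    return torch.tensor(float(ot.sinkhorn2(a, b, M, reg)))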