예제 #1
0
파일: main.py 프로젝트: gauenk/cl_gen
def run_me(tr_method,
           rank=0,
           Sgrid=(1,),
           Ngrid=(3,),
           nNgrid=1,
           Ggrid=(75.,),
           nGgrid=1,
           ngpus=3,
           idx=0):
    """Train and test a dynamic-burst denoiser with the chosen training method.

    ``rank`` and ``idx`` pick one cell of a hyper-parameter sweep described by
    the grid arguments. Several of the picked values are then overridden by
    hard-coded settings below (GPU id, blindness, number of frames). After
    training, the best test PSNR is written to ``results.csv`` and the model
    is saved.

    Args:
        tr_method: ``"mse"`` or ``"n2n"``; selects the train/test loop pair.
        rank: worker rank, combined with ``idx`` into the sweep cell index.
        Sgrid: candidate values assigned to ``cfg.S``.
        Ngrid: candidate frame counts (currently overridden: ``cfg.N = 5``).
        nNgrid: number of ``Ngrid`` entries that take part in the sweep.
        Ggrid: candidate Gaussian noise standard deviations.
        nGgrid: number of ``Ggrid`` entries that take part in the sweep.
        ngpus: number of GPUs the ranks are spread over.
        idx: sweep offset; the cell index is ``idx * ngpus + rank``.

    Raises:
        ValueError: if ``tr_method`` is neither ``"mse"`` nor ``"n2n"``.
    """
    # Validate up front so a bad method name fails before any directories are
    # created or the model/data are loaded.
    if tr_method not in ("mse", "n2n"):
        raise ValueError(f"Unknown training method [{tr_method}]")

    args = get_args()
    args.name = "default"
    cfg = get_cfg(args)
    cfg.use_ddp = False
    cfg.use_apex = False
    cfg.global_steps = 0
    # The sibling runners track progress in ``cfg.global_step`` (singular);
    # initialise it too so either spelling starts from zero.
    cfg.global_step = 0
    # Per-rank placement would be ``rank % ngpus``; this run is pinned to GPU 1.
    gpuid = 1
    cfg.gpuid = gpuid
    cfg.device = f"cuda:{gpuid}"

    # -- decompose the sweep index into (S, G, N, blind) coordinates --
    # The lowest "digit" (grid_idx % 2) would select blind/non-blind, but
    # blindness is hard-coded below, so it is not computed here.
    grid_idx = idx * (1 * ngpus) + rank
    N_grid_idx = (grid_idx // 2) % nNgrid
    G_grid_idx = grid_idx // (nNgrid * 2) % nGgrid
    S_grid_idx = grid_idx // (nGgrid * nNgrid * 2)

    # -- config settings --
    cfg.use_collate = True
    # cfg.dataset.download = False
    # cfg.cls = cfg
    cfg.S = Sgrid[S_grid_idx]
    # cfg.dataset.name = "cifar10"
    cfg.dataset.name = "voc"
    # cfg.dataset.name = "rebel2021"
    cfg.blind = False  # hard-coded: non-blind training regardless of the sweep
    cfg.N = Ngrid[N_grid_idx]
    cfg.N = 5  # hard-coded override of the swept frame count
    cfg.use_anscombe = True

    # -- noise 2 simulate parameters --
    cfg.sim_shuffleK = True
    cfg.sim_method = "l2"
    cfg.sim_K = 1
    cfg.sim_patchsize = 9
    # cfg.N = 30
    cfg.num_workers = 2
    cfg.dynamic.frames = cfg.N

    # -- gaussian noise --
    noise_type = "g"
    cfg.noise_type = noise_type
    cfg.noise_params['g']['stddev'] = Ggrid[G_grid_idx]
    cfg.noise_params.ntype = cfg.noise_type
    noise_params = cfg.noise_params['g']
    noise_level = Ggrid[G_grid_idx]
    noise_level_str = f"{int(noise_level)}"

    # -- heteroskedastic gaussian noise --
    # noise_type = "hg"
    # cfg.noise_type = noise_type
    # cfg.noise_params['hg']['read'] = Ggrid[G_grid_idx]
    # cfg.noise_params['hg']['shot'] = 25.
    # noise_params = cfg.noise_params['hg']
    # cfg.noise_params.ntype = cfg.noise_type
    # noise_level = Ggrid[G_grid_idx]
    # noise_level_str = f"{int(noise_params['read']),int(noise_params['shot'])}"

    # -- low-light noise --
    # noise_type = "qis"
    # cfg.noise_type = noise_type
    # cfg.noise_params['qis']['alpha'] = 4.0
    # cfg.noise_params['qis']['readout'] = 0.0
    # cfg.noise_params['qis']['nbits'] = 3
    # noise_params = cfg.noise_params['qis']
    # cfg.noise_params.ntype = cfg.noise_type
    # noise_level = noise_params['readout']
    # noise_level_str = f"{int(noise_params['alpha']),int(noise_params['readout']),int(noise_params['nbits'])}"

    # -- batch info --
    cfg.batch_size = 10
    cfg.init_lr = 1e-4
    cfg.unet_channels = 3
    cfg.input_N = cfg.N - 1
    cfg.epochs = 50
    cfg.color_cat = True
    cfg.log_interval = int(int(50000 / cfg.batch_size) / 100)
    cfg.dynamic.bool = True
    cfg.dynamic.ppf = 0
    cfg.dynamic.random_eraser = False
    cfg.dynamic.frame_size = 128
    cfg.dynamic.total_pixels = cfg.dynamic.ppf * cfg.N
    cfg.dynamic.nframe = cfg.N
    # cfg.dynamic.total_pixels = 6
    cfg.load = False
    print(
        "This is to benchmark a new distribution class. We want to find the case when mL2 fails."
    )

    cfg.input_noise = False
    cfg.input_noise_middle_only = False
    cfg.input_with_middle_frame = True

    cfg.middle_frame_random_erase = False
    cfg.input_noise_level = noise_level / 255.
    if (cfg.blind == 0):  # e.g. supervised is true
        cfg.input_with_middle_frame = True
    if cfg.input_with_middle_frame:
        # The middle frame is fed to the model, so it sees all N frames.
        cfg.input_N = cfg.N

    blind = "blind" if cfg.blind else "nonblind"
    print(grid_idx, blind, cfg.N, Ggrid[G_grid_idx], gpuid)

    # -- output locations, keyed by the experiment settings --
    # if blind == "nonblind": return
    dynamic_str = "dynamic_input_noise" if cfg.input_noise else "dynamic"
    if cfg.input_noise_middle_only: dynamic_str += "_mo"
    if cfg.input_with_middle_frame: dynamic_str += "_wmf"
    postfix = Path(
        f"./{dynamic_str}/{cfg.dynamic.frame_size}_{cfg.dynamic.ppf}_{cfg.dynamic.total_pixels}/{cfg.S}/{blind}/{cfg.N}/{noise_level}/"
    )
    print(postfix, cfg.dynamic.total_pixels)
    cfg.model_path = cfg.model_path / postfix
    cfg.optim_path = cfg.optim_path / postfix
    if not cfg.model_path.exists(): cfg.model_path.mkdir(parents=True)
    if not cfg.optim_path.exists(): cfg.optim_path.mkdir(parents=True)

    checkpoint = cfg.model_path / Path("checkpoint_{}.tar".format(cfg.epochs))
    # if checkpoint.exists(): return
    print(
        f"Sim Method: {cfg.sim_method} | Shuffle K {cfg.sim_shuffleK} | Sim K: {cfg.sim_K} | Patchsize: {cfg.sim_patchsize}"
    )
    print(f"Ascombe Transform: {cfg.use_anscombe}")
    print("N: {} | Noise Level: {} | Noise Type: {}".format(
        cfg.N, noise_level_str, noise_type))
    print("PID: {}".format(os.getpid()))
    torch.cuda.set_device(gpuid)

    # load model
    torch.manual_seed(cfg.seed)
    model = load_model(cfg)
    optimizer = load_optimizer(cfg, model)
    scheduler = load_scheduler(cfg, model, optimizer)
    nparams = count_parameters(model)
    print("Number of Trainable Parameters: {}".format(nparams))

    # load data
    # data,loader = load_dataset(cfg,'denoising')
    data, loader = load_dataset(cfg, 'dynamic')
    # data,loader = simulate_noisy_dataset(data,loaders,M,N)

    # load criterion
    criterion = nn.BCELoss()

    if cfg.load:
        fp = cfg.model_path / Path("checkpoint_30.tar")
        # NOTE(review): passes 0 here but ``gpuid`` for the checkpoint load
        # below; confirm which device load_model_fp should map to.
        model = load_model_fp(cfg, model, fp, 0)

    cfg.current_epoch = 0
    te_ave_psnr = {}  # epoch -> average test PSNR
    test_before = False
    if test_before:
        ave_psnr = test_loop(cfg, model, criterion, loader.te, -1)
        print("PSNR before training {:2.3e}".format(ave_psnr))
        return
    if checkpoint.exists() and cfg.load:
        model = load_model_fp(cfg, model, checkpoint, gpuid)
        print("Loaded model.")
        # Training is already complete, so the epoch loop below is skipped.
        cfg.current_epoch = cfg.epochs

    # run_ot_v_displacement(cfg,criterion,loader.te)
    for epoch in range(cfg.current_epoch, cfg.epochs):

        if tr_method == "mse":
            losses = train_loop_mse(cfg, model, optimizer, criterion,
                                    loader.tr, epoch)
            ave_psnr = test_loop_mse(cfg, model, criterion, loader.te, epoch)
        else:  # "n2n" -- validated at the top of the function
            losses = train_loop_n2n(cfg, model, optimizer, criterion,
                                    loader.tr, epoch)
            ave_psnr = test_loop_n2n(cfg, model, criterion, loader.te, epoch)

        # losses = train_loop(cfg,model,optimizer,criterion,loader.tr,epoch)
        # ave_psnr = test_loop(cfg,model,criterion,loader.te,epoch)
        te_ave_psnr[epoch] = ave_psnr
        cfg.current_epoch += 1

    if not te_ave_psnr:
        # No epoch ran (e.g. the final checkpoint was loaded), so there is no
        # best epoch to report; unpacking zip(*{}.items()) would raise.
        print("No training epochs were run; skipping results and model save.")
        return

    epochs, psnr = zip(*te_ave_psnr.items())
    best_index = np.argmax(psnr)
    best_epoch, best_psnr = epochs[best_index], psnr[best_index]
    print(
        f"Best Epoch {best_epoch} | Best PSNR {best_psnr} | N: {cfg.N} | Blind: {blind}"
    )

    root = Path(f"{settings.ROOT_PATH}/output/n2n/{postfix}/")
    # if cfg.blind: root = root / Path(f"./blind/")
    # else: root = root / Path(f"./nonblind/")
    fn = Path(f"results.csv")

    if not root.exists(): root.mkdir(parents=True)
    path = root / fn
    with open(path, 'w') as f:
        f.write("{:d},{:d},{:2.10e},{:d}\n".format(cfg.N, best_epoch,
                                                   best_psnr, nparams))

    save_model(cfg, model, optimizer)
예제 #2
0
def run_me(rank=0,Sgrid=(50000,),Ngrid=(3,),nNgrid=1,Ggrid=(1.,),nGgrid=1,ngpus=3,idx=0):
    """Train and test a supervised KPN denoiser on dynamic bursts.

    Rough PSNR / AWGN correspondences used when choosing noise levels:

        PSNR 20 = (can equal) = AWGN @ 25
        PSNR 25 = (can equal) = AWGN @ 14
        PSNR 28 = (can equal) = AWGN @ 5

    ``rank`` and ``idx`` select one cell of the sweep described by the grid
    arguments. ``cfg.N`` is then overridden to 6 and the GPU is pinned to 1.
    The best test PSNR is written to ``results.csv`` and the model is saved.

    Args:
        rank: worker rank, combined with ``idx`` into the sweep cell index.
        Sgrid: candidate values assigned to ``cfg.S``.
        Ngrid: candidate frame counts (overridden below).
        nNgrid: number of ``Ngrid`` entries that take part in the sweep.
        Ggrid: candidate Gaussian noise standard deviations.
        nGgrid: number of ``Ggrid`` entries that take part in the sweep.
        ngpus: number of GPUs the ranks are spread over.
        idx: sweep offset; the cell index is ``idx * ngpus + rank``.
    """

    args = get_args()
    args.name = "default"
    cfg = get_cfg(args)
    cfg.use_ddp = False
    cfg.use_apex = False
    # Per-rank placement would be ``rank % ngpus``; this run is pinned to GPU 1.
    gpuid = 1
    cfg.gpuid = gpuid
    cfg.device = f"cuda:{gpuid}"

    # -- experiment info --
    cfg.exp_name = "sup_n9_kpn-standard-filterSize15_f128_kpnLoss"
    cfg.desc = "Desc: sup kpn-standard-filterSized15, f128, kpnLoss"

    # -- decompose the sweep index into (S, G, N, blind) coordinates --
    # The lowest "digit" (grid_idx % 2) would select blindness; here blindness
    # is derived from cfg.supervised instead, so it is not computed.
    grid_idx = idx*(1*ngpus)+rank
    N_grid_idx = ( grid_idx // 2 ) % nNgrid
    G_grid_idx = grid_idx // (nNgrid * 2) % nGgrid
    S_grid_idx = grid_idx // (nGgrid * nNgrid * 2)

    cfg.use_collate = True
    # cfg.dataset.download = False
    # cfg.cls = cfg
    cfg.S = Sgrid[S_grid_idx]
    # cfg.dataset.name = "cifar10"
    cfg.dataset.name = "voc"
    cfg.supervised = True
    cfg.blind = not cfg.supervised
    cfg.N = Ngrid[N_grid_idx]
    cfg.N = 6  # hard-coded override of the swept frame count
    cfg.kpn_filter_onehot = True
    cfg.kpn_frame_size = 15
    cfg.dynamic.frames = cfg.N
    cfg.noise_type = 'g'
    cfg.noise_params['g']['stddev'] = Ggrid[G_grid_idx]
    # cfg.noise_type = 'll'
    # cfg.noise_params['ll']['alpha'] = 255*0.015
    # cfg.noise_params['ll']['read_noise'] = 0.25
    # cfg.recon_l1 = True
    noise_level = Ggrid[G_grid_idx]
    cfg.batch_size = 4
    cfg.init_lr = 1e-4
    cfg.unet_channels = 3
    cfg.input_N = cfg.N-1
    cfg.epochs = 100
    cfg.log_interval = 50 # int(int(50000 / cfg.batch_size) / 100)
    cfg.dynamic.bool = True
    cfg.dynamic.ppf = 2
    cfg.dynamic.frame_size = 128
    cfg.dynamic.total_pixels = 2*cfg.N
    cfg.load = False

    # -- input noise for learning --
    cfg.input_noise = False
    cfg.input_noise_middle_only = False
    cfg.input_with_middle_frame = True
    cfg.input_noise_level = noise_level/255
    if cfg.input_with_middle_frame:
        # The middle frame is fed to the model, so it sees all N frames.
        cfg.input_N = cfg.N

    blind = "blind" if cfg.blind else "nonblind"
    print(grid_idx,blind,cfg.N,Ggrid[G_grid_idx],gpuid,cfg.input_noise,cfg.input_with_middle_frame)

    # -- output locations, keyed by the experiment settings --
    # if blind == "nonblind": return
    dynamic_str = "dynamic_input_noise" if cfg.input_noise else "dynamic"
    if cfg.input_noise_middle_only: dynamic_str += "_mo"
    if cfg.input_with_middle_frame: dynamic_str += "_wmf"
    postfix = Path(f"./{dynamic_str}/{cfg.dynamic.frame_size}_{cfg.dynamic.ppf}_{cfg.dynamic.total_pixels}/{cfg.S}/{blind}/{cfg.N}/{noise_level}/")
    print(postfix)
    cfg.model_path = cfg.model_path / postfix
    cfg.optim_path = cfg.optim_path / postfix
    if not cfg.model_path.exists(): cfg.model_path.mkdir(parents=True)
    if not cfg.optim_path.exists(): cfg.optim_path.mkdir(parents=True)

    checkpoint = cfg.model_path / Path("checkpoint_{}.tar".format(cfg.epochs))
    # if checkpoint.exists(): return

    print("PID: {}".format(os.getpid()))
    print("N: {} | Noise Level: {}".format(cfg.N,cfg.noise_params['g']['stddev']))

    torch.cuda.set_device(gpuid)

    # -- load model --
    model,criterion = load_model_kpn(cfg)
    optimizer = load_optimizer(cfg,model)
    scheduler = load_scheduler(cfg,model,optimizer)
    nparams = count_parameters(model)
    print("Number of Trainable Parameters: {}".format(nparams))

    # load data
    # data,loader = load_dataset(cfg,'denoising')
    data,loader = load_dataset(cfg,'dynamic')
    # data,loader = simulate_noisy_dataset(data,loaders,M,N)

    if cfg.load:
        fp = cfg.model_path / Path("checkpoint_30.tar")
        # NOTE(review): passes 0 here but ``gpuid`` for the checkpoint load
        # below; confirm which device load_model_fp should map to.
        model = load_model_fp(cfg,model,fp,0)

    cfg.current_epoch = 0
    te_ave_psnr = {}  # epoch -> average test PSNR
    test_before = False
    if test_before:
        ave_psnr = test_loop(cfg,model,criterion,loader.te,-1)
        print("PSNR before training {:2.3e}".format(ave_psnr))
        return
    if checkpoint.exists() and cfg.load:
        model = load_model_fp(cfg,model,checkpoint,gpuid)
        print("Loaded model.")
        # Training is already complete, so the epoch loop below is skipped.
        cfg.current_epoch = cfg.epochs

    cfg.global_step = 0
    use_record = False
    record = init_record()
    # run_test_xbatch(cfg,criterion,loader.tr)
    # run_ot_v_displacement(cfg,criterion,loader.tr)
    # exit()

    for epoch in range(cfg.current_epoch,cfg.epochs):

        print(cfg.desc)
        sys.stdout.flush()

        losses,epoch_record = train_loop(cfg,model,optimizer,criterion,loader.tr,epoch)

        if use_record:
            # NOTE(review): this assumes ``record.append`` *returns* the
            # extended record (old pandas-style). For a list it returns None,
            # and DataFrame.append was removed in pandas 2.0 -- confirm the
            # type returned by init_record().
            record = record.append(epoch_record)
            write_record_file(cfg.current_epoch,postfix,record)

        ave_psnr = test_loop(cfg,model,criterion,loader.te,epoch)
        te_ave_psnr[epoch] = ave_psnr
        cfg.current_epoch += 1

    if not te_ave_psnr:
        # No epoch ran (e.g. the final checkpoint was loaded), so there is no
        # best epoch to report; unpacking zip(*{}.items()) would raise.
        print("No training epochs were run; skipping results and model save.")
        return

    epochs,psnr = zip(*te_ave_psnr.items())
    best_index = np.argmax(psnr)
    best_epoch,best_psnr = epochs[best_index],psnr[best_index]
    print(f"Best Epoch {best_epoch} | Best PSNR {best_psnr} | N: {cfg.N} | Blind: {blind}")

    root = Path(f"{settings.ROOT_PATH}/output/n2n-kpn/{postfix}/")
    # if cfg.blind: root = root / Path(f"./blind/")
    # else: root = root / Path(f"./nonblind/")
    fn = Path(f"results.csv")

    if not root.exists(): root.mkdir(parents=True)
    path = root / fn
    with open(path,'w') as f:
        f.write("{:d},{:d},{:2.10e},{:d}\n".format(cfg.N,best_epoch,best_psnr,nparams))

    save_model(cfg, model, optimizer)
예제 #3
0
파일: dncnn.py 프로젝트: gauenk/cl_gen
def run_me(rank=0,Sgrid=(50000,),Ngrid=(2,),nNgrid=2,Ggrid=(25,),nGgrid=1,ngpus=3,idx=0):
    """Train and test a DnCNN denoiser on CIFAR-10 with Gaussian noise.

    ``rank`` and ``idx`` select one cell of the sweep described by the grid
    arguments. The GPU is pinned to 0 and training is non-blind. The best
    test PSNR is written to ``results.csv`` and the model is saved.

    Args:
        rank: worker rank, combined with ``idx`` into the sweep cell index.
        Sgrid: candidate values assigned to ``cfg.S``.
        Ngrid: candidate frame counts.
        nNgrid: number of ``Ngrid`` entries that take part in the sweep.
        Ggrid: candidate Gaussian noise standard deviations.
        nGgrid: number of ``Ggrid`` entries that take part in the sweep.
        ngpus: number of GPUs the ranks are spread over.
        idx: sweep offset; the cell index is ``idx * ngpus + rank``.

    Note:
        With the defaults, ``nNgrid=2`` but ``Ngrid`` has a single entry.
        Any cell with ``(grid_idx // 2) % 2 == 1`` (e.g. ``rank=2``) therefore
        indexes past the end of ``Ngrid`` and raises ``IndexError``. Pass an
        ``Ngrid`` with at least ``nNgrid`` entries when sweeping.
    """

    args = get_args()
    args.name = "default"
    cfg = get_cfg(args)
    cfg.use_ddp = False
    cfg.use_apex = False
    # Per-rank placement would be ``rank % ngpus``; this run is pinned to GPU 0.
    gpuid = 0
    # Keep cfg.gpuid in sync with cfg.device, as the other runners do.
    cfg.gpuid = gpuid
    cfg.device = f"cuda:{gpuid}"

    # -- decompose the sweep index into (S, G, N, blind) coordinates --
    # The lowest "digit" (grid_idx % 2) would select blindness, but training
    # is hard-coded to non-blind below, so it is not computed.
    grid_idx = idx*(1*ngpus)+rank
    N_grid_idx = ( grid_idx // 2 ) % nNgrid
    G_grid_idx = grid_idx // (nNgrid * 2) % nGgrid
    S_grid_idx = grid_idx // (nGgrid * nNgrid * 2)

    cfg.use_collate = True
    # cfg.dataset.download = False
    # cfg.cls = cfg
    cfg.S = Sgrid[S_grid_idx]
    cfg.dataset.name = "cifar10"
    # cfg.dataset.name = "voc"
    cfg.blind = False  # hard-coded: non-blind training regardless of the sweep
    cfg.N = Ngrid[N_grid_idx]
    cfg.dynamic.frames = cfg.N
    cfg.noise_type = 'g'
    cfg.noise_params['g']['stddev'] = Ggrid[G_grid_idx]
    noise_level = Ggrid[G_grid_idx]
    cfg.batch_size = 16
    cfg.init_lr = 1e-3
    cfg.unet_channels = 3
    # if cfg.blind: cfg.input_N = cfg.N - 1
    # else: cfg.input_N = cfg.N
    cfg.input_N = cfg.N-1
    cfg.epochs = 30
    cfg.log_interval = int(int(50000 / cfg.batch_size) / 100)
    cfg.dataset.load_residual = True
    cfg.dynamic.bool = True
    cfg.dynamic.ppf = 0
    cfg.dynamic.frame_size = 128
    cfg.dynamic.total_pixels = 0
    cfg.load = False

    blind = "blind" if cfg.blind else "nonblind"
    print(grid_idx,blind,cfg.N,Ggrid[G_grid_idx],gpuid)

    # -- output locations, keyed by the experiment settings --
    # if blind == "nonblind": return
    postfix = get_postfix_str(cfg,blind,noise_level)
    cfg.model_path = cfg.model_path / postfix
    cfg.optim_path = cfg.optim_path / postfix
    if not cfg.model_path.exists(): cfg.model_path.mkdir(parents=True)
    if not cfg.optim_path.exists(): cfg.optim_path.mkdir(parents=True)

    checkpoint = cfg.model_path / Path("checkpoint_{}.tar".format(cfg.epochs))
    # if checkpoint.exists(): return

    print("N: {} | Noise Level: {}".format(cfg.N,cfg.noise_params['g']['stddev']))

    torch.cuda.set_device(gpuid)

    # load model
    model = DnCNN_Net(3)
    optimizer = load_optimizer(cfg,model)
    scheduler = load_scheduler(cfg,model,optimizer)
    nparams = count_parameters(model)
    print("Number of Trainable Parameters: {}".format(nparams))

    # load data
    data,loader = load_dataset(cfg,'denoising')
    # data,loader = load_dataset(cfg,'dynamic')
    # data,loader = simulate_noisy_dataset(data,loaders,M,N)

    # load criterion
    criterion = nn.BCELoss()

    cfg.current_epoch = 0
    te_ave_psnr = {}  # epoch -> average test PSNR
    test_before = False
    if test_before:
        # Unlike the other runners, this continues into training afterwards.
        ave_psnr = test_loop(cfg,model,criterion,loader.te,-1)
        print("PSNR before training {:2.3e}".format(ave_psnr))
    if checkpoint.exists() and cfg.load:
        model = load_model_fp(cfg,model,checkpoint,gpuid)
        print("Loaded model.")
        # Training is already complete, so the epoch loop below is skipped.
        cfg.current_epoch = cfg.epochs

    for epoch in range(cfg.current_epoch,cfg.epochs):

        losses = train_loop(cfg,model,optimizer,criterion,loader.tr,epoch)
        ave_psnr = test_loop(cfg,model,criterion,loader.te,epoch)
        te_ave_psnr[epoch] = ave_psnr
        cfg.current_epoch += 1

    if not te_ave_psnr:
        # No epoch ran (e.g. the final checkpoint was loaded), so there is no
        # best epoch to report; unpacking zip(*{}.items()) would raise.
        print("No training epochs were run; skipping results and model save.")
        return

    epochs,psnr = zip(*te_ave_psnr.items())
    best_index = np.argmax(psnr)
    best_epoch,best_psnr = epochs[best_index],psnr[best_index]
    print(f"Best Epoch {best_epoch} | Best PSNR {best_psnr} | N: {cfg.N} | Blind: {blind}")

    root = Path(f"{settings.ROOT_PATH}/output/dncnn/{postfix}/")
    # if cfg.blind: root = root / Path(f"./blind/")
    # else: root = root / Path(f"./nonblind/")
    fn = Path(f"results.csv")

    if not root.exists(): root.mkdir(parents=True)
    path = root / fn
    with open(path,'w') as f:
        f.write("{:d},{:d},{:2.10e},{:d}\n".format(cfg.N,best_epoch,best_psnr,nparams))

    save_model(cfg, model, optimizer)
예제 #4
0
파일: main.py 프로젝트: gauenk/cl_gen
def run_me(rank=0,
           Sgrid=(1,),
           Ngrid=(3,),
           nNgrid=1,
           Ggrid=(25.,),
           nGgrid=1,
           ngpus=3,
           idx=0):
    """Train and test the burst KPN denoiser configured by ``get_main_config``.

    All settings come from ``get_main_config()``. The sweep arguments are
    accepted for call compatibility with the other runners but are unused.
    Progress is logged to a ``SummaryWriter`` under ``runs/<name>/<exp_name>``.
    The align/denoiser/critic sub-models are checkpointed every
    ``cfg.save_interval`` epochs and once more at the end, and the best test
    PSNR is written to ``results.csv``.
    """

    cfg = get_main_config()

    # -- noise info --
    # NOTE(review): reads the Gaussian ('g') parameters regardless of
    # ``cfg.noise_params.ntype``; other noise types would need other keys.
    noise_type = cfg.noise_params.ntype
    noise_params = cfg.noise_params['g']
    noise_level = noise_params['stddev']
    noise_level_str = f"{int(noise_params['stddev'])}"

    # noise_params = cfg.noise_params['qis']
    # noise_level = noise_params['readout']
    # noise_level_str = f"{int(noise_params['alpha']),int(noise_params['readout']),int(noise_params['nbits'])}"

    # -- experiment info --
    name = "n2sim_burstv2_testingAlignedAbps"
    ds_name = cfg.dataset.name.lower()
    sup_str = "sup" if cfg.supervised else "unsup"
    bs_str = "b{}".format(cfg.batch_size)
    align_str = "yesAlignNet" if cfg.burst_use_alignment else "noAlignNet"
    unet_str = "yesUnet" if cfg.burst_use_unet else "noUnet"
    if cfg.burst_use_unet_only: unet_str += "Only"
    kpn_cascade_str = "cascade{}".format(
        cfg.kpn_cascade_num) if cfg.kpn_cascade else "noCascade"
    kpnba_str = "kpnBurstAlpha{}".format(int(cfg.kpn_burst_alpha * 1000))
    frame_str = "n{}".format(cfg.N)
    framesize_str = "f{}".format(cfg.dynamic.frame_size)
    filtersize_str = "filterSized{}".format(cfg.kpn_frame_size)
    misc = "noKL"
    cfg.exp_name = f"{sup_str}_{name}_{ds_name}_{kpn_cascade_str}_{bs_str}_{frame_str}_{framesize_str}_{filtersize_str}_{align_str}_{unet_str}_{kpnba_str}_{misc}"
    print(f"Experiment name: {cfg.exp_name}")
    desc_fmt = (frame_str, kpn_cascade_str, framesize_str, filtersize_str,
                cfg.init_lr, align_str)
    cfg.desc = "Desc: unsup, frames {}, cascade {}, framesize {}, filter size {}, lr {}, {}, kl loss, anneal mse".format(
        *desc_fmt)
    print(f"Description: [{cfg.desc}]")

    # -- attn params --
    cfg.patch_sizes = [128, 128]
    cfg.d_model_attn = 3

    cfg.input_noise = False
    cfg.input_noise_middle_only = False
    cfg.input_with_middle_frame = True

    cfg.middle_frame_random_erase = False
    cfg.input_noise_level = noise_level / 255.
    if (cfg.blind == 0):  # e.g. supervised is true
        cfg.input_with_middle_frame = True
    if cfg.input_with_middle_frame:
        # The middle frame is fed to the model, so it sees all N frames.
        cfg.input_N = cfg.N

    blind = "blind" if cfg.blind else "nonblind"
    gpuid = cfg.gpuid
    print(blind, cfg.N, noise_level, gpuid)

    # -- output locations, keyed by the experiment settings --
    # if blind == "nonblind": return
    dynamic_str = "dynamic_input_noise" if cfg.input_noise else "dynamic"
    if cfg.input_noise_middle_only: dynamic_str += "_mo"
    if cfg.input_with_middle_frame: dynamic_str += "_wmf"
    postfix = Path(
        f"./modelBurst/{cfg.exp_name}/{dynamic_str}/{cfg.dynamic.frame_size}_{cfg.dynamic.ppf}_{cfg.dynamic.total_pixels}/{cfg.S}/{blind}/{cfg.N}/{noise_level}/"
    )
    print(postfix, cfg.dynamic.total_pixels)
    cfg.model_path = cfg.model_path / postfix
    cfg.optim_path = cfg.optim_path / postfix
    if not cfg.model_path.exists(): cfg.model_path.mkdir(parents=True)
    if not cfg.optim_path.exists(): cfg.optim_path.mkdir(parents=True)

    checkpoint = cfg.model_path / Path("checkpoint_{}.tar".format(cfg.epochs))
    # if checkpoint.exists(): return

    print(
        f"Supervised: {cfg.supervised} | Noise2Noise: {cfg.n2n} | APBS: {cfg.abps} | ABPS-Inputs: {cfg.abps_inputs}"
    )
    print(
        f"Sim Method: {cfg.sim_method} | Shuffle K {cfg.sim_shuffleK} | Sim K: {cfg.sim_K} | Patchsize: {cfg.sim_patchsize}"
    )
    print("N: {} | Noise Level: {} | Noise Type: {}".format(
        cfg.N, noise_level_str, noise_type))

    torch.cuda.set_device(gpuid)

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #   init summary writer
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    log_base = Path(f"runs/{name}")
    if not log_base.exists(): log_base.mkdir(parents=True)
    log_dir = log_base / Path(f"{cfg.exp_name}")
    writer = SummaryWriter(log_dir=str(log_dir))

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #   Load the Model, Data, Optim, Crit
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    # load model
    # model = load_unet_model(cfg)
    # model,criterion = load_burst_n2n_model(cfg)
    model, noise_critic, criterion = load_burst_kpn_model(cfg)
    # model,criterion = load_model_kpn(cfg)
    # optimizer = load_optimizer(cfg,model)
    # scheduler = load_scheduler(cfg,model,optimizer)
    # scheduler = make_lr_scheduler(cfg,model.unet_info.optim)
    nparams = count_parameters(model.denoiser_info.model)
    print("Number of Trainable Parameters: {}".format(nparams))
    print("GPUID: {}".format(gpuid))
    print("PID: {}".format(os.getpid()))

    # load data
    # data,loader = load_dataset(cfg,'denoising')
    # data,loader = load_dataset(cfg,'default')
    data, loader = load_dataset(cfg, 'dynamic')
    # data,loader = load_dataset(cfg,'dynamic-lmdb-all')
    # data,loader = load_dataset(cfg,'dynamic-lmdb-burst')
    # data,loader = load_dataset(cfg,'default')
    # data,loader = simulate_noisy_dataset(data,loaders,M,N)

    # load criterion
    # criterion = nn.BCELoss()

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #    Load the Model from Memory
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    if cfg.load:
        name = "denoiser"
        # NOTE(review): hard-coded absolute checkpoint paths from a previous
        # run; the commented cfg.model_path-based form is portable.
        fp = "/home/gauenk/Documents/experiments/cl_gen/output/n2sim/cifar10/default/model/modelBurst/unsup_n2sim_burstv2_voc_noCascade_b20_n8_f128_filterSized9_noAlignNet_noUnet_unet_mse_noKL/dynamic_wmf/128_1_8/1/blind/8/25.0/denoiser/checkpoint_83.tar"
        # fp = cfg.model_path / Path("{}/checkpoint_{}.tar".format(name,cfg.load_epoch))
        # fp = Path("/home/gauenk/Documents/experiments/cl_gen/output/n2n_wl/cifar10/default/model/modelBurst/unsup_burst_noCascade_b4_n10_f128_filterSized12_kpn_klLoss_annealMSE_klPRes/dynamic_wmf/128_1_10/1/blind/10/25.0/denoiser/checkpoint_{}.tar".format(cfg.load_epoch))
        model.denoiser_info.model = load_model_fp(cfg,
                                                  model.denoiser_info.model,
                                                  fp, cfg.gpuid)
        fp = "/home/gauenk/Documents/experiments/cl_gen/output/n2sim/cifar10/default/optim/modelBurst/unsup_n2sim_burstv2_voc_noCascade_b20_n8_f128_filterSized9_noAlignNet_noUnet_unet_mse_noKL/dynamic_wmf/128_1_8/1/blind/8/25.0/denoiser/checkpoint_83.tar"
        # model.denoiser_info.optim = load_optim_fp(cfg,model.denoiser_info.optim,fp,cfg.gpuid)
        # name = "critic"
        # fp = cfg.model_path / Path("{}/checkpoint_{}.tar".format(name,cfg.load_epoch))
        # noise_critic.disc = load_model_fp(cfg,noise_critic.disc,fp,cfg.gpuid)
        if cfg.restart_after_load:
            cfg.current_epoch = 0
            cfg.global_step = 0
        else:
            cfg.current_epoch = cfg.load_epoch + 1
            cfg.global_step = cfg.load_epoch * len(data.tr)
            ce, gs = cfg.current_epoch, cfg.global_step
            print(
                f"Starting Training from epoch [{ce}] and global step [{gs}]")
    else:
        cfg.current_epoch = 0
        # Fresh run: start the LR schedule at step 0, matching the restart
        # branch above (cfg.global_step is read immediately below).
        cfg.global_step = 0
    scheduler = make_lr_scheduler(cfg, model.denoiser_info.optim,
                                  cfg.global_step)

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #       Pre train-loop setup
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    te_ave_psnr = {}  # epoch -> average test PSNR
    test_before = False
    if test_before:
        ave_psnr, record_test = test_loop(cfg, model, loader.te, -1, writer)
        print("PSNR before training {:2.3e}".format(ave_psnr))
        writer.close()
        return
    if checkpoint.exists() and cfg.load:
        model = load_model_fp(cfg, model, checkpoint, cfg.gpuid)
        print("Loaded model.")
        # Training is already complete, so the epoch loop below is skipped.
        cfg.current_epoch = cfg.epochs + 1
        # Same step accounting as the resume branch above (was the undefined
        # name ``train_data``).
        cfg.global_step = len(data.tr) * cfg.epochs

    record_losses = pd.DataFrame({
        'kpn': [],
        'ot': [],
        'psnr': [],
        'psnr_std': []
    })
    use_record = False
    loss_type = "sup_r_ot"

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #       Training Loop
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    for epoch in range(cfg.current_epoch, cfg.epochs):
        lr = model.denoiser_info.optim.param_groups[0]["lr"]
        print(cfg.desc)
        print("Learning Rate: %2.2e" % (lr))
        sys.stdout.flush()

        losses, record_losses = train_loop(cfg, model, scheduler, loader.tr,
                                           epoch, record_losses, writer)
        if use_record:
            write_record_losses_file(cfg.current_epoch, postfix, loss_type,
                                     record_losses)

        cfg.current_epoch += 1
        if epoch % cfg.save_interval == 0 and epoch > 0:
            save_burst_model(cfg, "align", model.align_info.model,
                             model.align_info.optim)
            save_burst_model(cfg, "denoiser", model.denoiser_info.model,
                             model.denoiser_info.optim)
            save_burst_model(cfg, "critic", noise_critic.disc,
                             noise_critic.optim)

        ave_psnr, record_test = test_loop(cfg, model, loader.te, epoch)
        if use_record:
            write_record_test_file(cfg.current_epoch, postfix, loss_type,
                                   record_test, writer)
        te_ave_psnr[epoch] = ave_psnr

    if not te_ave_psnr:
        # No epoch ran (e.g. the final checkpoint was loaded), so there is no
        # best epoch to report; unpacking zip(*{}.items()) would raise.
        print("No training epochs were run; skipping results and model save.")
        writer.close()
        return

    epochs, psnr = zip(*te_ave_psnr.items())
    best_index = np.argmax(psnr)
    best_epoch, best_psnr = epochs[best_index], psnr[best_index]
    print(
        f"Best Epoch {best_epoch} | Best PSNR {best_psnr} | N: {cfg.N} | Blind: {blind}"
    )

    root = Path(f"{settings.ROOT_PATH}/output/n2sim/{postfix}/")
    # if cfg.blind: root = root / Path(f"./blind/")
    # else: root = root / Path(f"./nonblind/")
    fn = Path(f"results.csv")

    if not root.exists(): root.mkdir(parents=True)
    path = root / fn
    with open(path, 'w') as f:
        f.write("{:d},{:d},{:2.10e},{:d}\n".format(cfg.N, best_epoch,
                                                   best_psnr, nparams))

    # Final checkpoint of every sub-model, saved the same way as the periodic
    # checkpoints above. (Previously ``save_model(cfg, model, optimizer)``,
    # but ``optimizer`` is never defined in this runner.)
    save_burst_model(cfg, "align", model.align_info.model,
                     model.align_info.optim)
    save_burst_model(cfg, "denoiser", model.denoiser_info.model,
                     model.denoiser_info.optim)
    save_burst_model(cfg, "critic", noise_critic.disc, noise_critic.optim)
    writer.close()