Example #1
def run_fnet(cfg, burst, score_fxn):
    model = UNet_small(3)
    model = model.to(burst.device)
    optim = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.99))
    train(cfg, burst, model, optim, 1000)
    score, scores_t = test(cfg, burst, model, score_fxn)
    return score, scores_t
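
A minimal driver for run_fnet, shown as a sketch only: the cfg fields, burst shape, and the two-argument MSE scorer are assumptions, since the real score_fxn signature is not shown above.

import torch
from easydict import EasyDict as edict

cfg = edict({'device': 'cuda:0' if torch.cuda.is_available() else 'cpu'})
burst = torch.rand(5, 3, 32, 32).to(cfg.device)  # T=5 frames, assumed layout

def mse_score(rec, target):
    # placeholder scoring function; the project's score_fxn may differ
    return ((rec - target) ** 2).mean().item()

score, scores_t = run_fnet(cfg, burst, mse_score)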
Example #2
File: cog.py Project: gauenk/cl_gen
def run_cog(cfg, clean, noisy, directions, results):

    image_volume = noisy
    backbone = UNet_small(3).to(cfg.device)
    nn_params = {'lr': 1e-3, 'init_params': None}
    train_steps = 1000
    score = score_cog(cfg, image_volume, backbone, nn_params, train_steps)
    results['cog'] = score

    return
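
A hypothetical call site for run_cog; the burst shapes, noise level, and cfg.device are assumptions, and directions is passed through unused by the body above.

import torch
from easydict import EasyDict as edict

cfg = edict({'device': 'cuda:0' if torch.cuda.is_available() else 'cpu'})
clean = torch.rand(5, 3, 32, 32)                       # T=5 clean frames
noisy = clean + (25. / 255.) * torch.randn_like(clean)  # assumed Gaussian noise
results = {}
run_cog(cfg, clean, noisy, directions=None, results=results)
print(results['cog'])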
Example #3
def load_model_v2(cfg):
    simclr = None
    denoise_model = UNet_small(3 * cfg.input_N)
    input_patch, output_patch = cfg.patch_sizes
    d_model = cfg.d_model_attn
    xformer_args = [
        simclr, d_model, cfg.dynamic.frame_size, input_patch, output_patch,
        cfg.input_N, cfg.dataset.bw, denoise_model
    ]
    model = TransformerNetwork32_noxform(*xformer_args)
    model = model.to(cfg.device)
    return model
Example #4
def load_model_v4(cfg):
    simclr = None  #load_model_simclr(cfg)
    denoise_model = UNet_small(cfg.d_model_attn * cfg.input_N)
    # denoise_model = load_denoise_model_fp(cfg,denoise_model)

    input_patch, output_patch = cfg.patch_sizes
    # d_model = 2048 // (patch_height * patch_width)
    # d_model = 398336 // (patch_height * patch_width)
    d_model = cfg.d_model_attn
    xformer_args = [
        simclr, d_model, cfg.dynamic.frame_size, input_patch, output_patch,
        cfg.input_N, cfg.dataset.bw, denoise_model
    ]
    model = TransformerNetwork32_v4(*xformer_args)
    model = model.to(cfg.device)
    return model
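
Both loaders read the same handful of config fields, so one sketch covers load_model_v2 and load_model_v4; every value below is a placeholder for illustration, not a project default.

from easydict import EasyDict as edict

cfg = edict()
cfg.device = 'cuda:0'
cfg.input_N = 5                      # burst length
cfg.patch_sizes = (16, 16)           # (input_patch, output_patch)
cfg.d_model_attn = 3                 # attention feature dimension
cfg.dynamic = edict({'frame_size': 32})
cfg.dataset = edict({'bw': False})   # color, not grayscale
model = load_model_v4(cfg)           # load_model_v2(cfg) reads the same fields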
Example #5
    def __init__(self, num_frames, num_in_ftrs, num_out_ftrs, num_out_channels,
                 patchsize, nh_size, img_size, attn_params = None):
        super().__init__()

        # -- init --
        self.num_frames = num_frames
        self.num_in_ftrs = num_in_ftrs
        self.num_out_ftrs = num_out_ftrs
        self.nh_size = nh_size # number of patches around center pixel
        self._patchsize = patchsize
        self.img_size = img_size
        self.layer_norm = nn.LayerNorm(num_out_channels*patchsize**2) # note: raw patchsize here; unet_out_size below uses the even-rounded patchsize_e

        # -- create model --
        self.unet = UNet_small(num_in_ftrs+3,num_out_channels)
        # self.unet_out_size = num_out_channels * patchsize**2
        patchsize_e = patchsize + (patchsize % 2) # round an odd patch size up to the next even size
        self.unet_out_size = num_out_channels * patchsize_e**2
        self.mlp = MLP( self.unet_out_size, num_out_ftrs, hidden_size = 256)
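
The even-rounding matters only for odd patch sizes: with patchsize = 5 and num_out_channels = 3, patchsize_e becomes 6, so the MLP sees 3 * 6**2 = 108 input features rather than 75. A quick check of that arithmetic:

patchsize, num_out_channels = 5, 3
patchsize_e = patchsize + (patchsize % 2)        # 5 -> 6; even sizes unchanged
unet_out_size = num_out_channels * patchsize_e ** 2
assert (patchsize_e, unet_out_size) == (6, 108)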
Example #6
    def __init__(self, ftr_model, d_model, im_size, input_patch, output_patch, n_frames, bw=False, denoise_model = None, batch_size = 8):
        """
        Transformer for denoising.

        im_size : int : the length and width of the square input image
        """
        super(TransformerNetwork32_dip_v2, self).__init__()
        self.d_model = d_model
        self.ftr_model = {'spoof':ftr_model}
        print(d_model*n_frames,d_model)
        self.denoise_model_preproc_a = nn.Conv2d(d_model*n_frames,d_model*n_frames,3,1,1)
        # self.denoise_model_preproc_b = nn.Conv2d(d_model*n_frames,3*n_frames,3,1,1)
        # self.denoise_model_preproc_b = nn.Conv2d(d_model*n_frames,3,3,1,1)
        self.denoise_model_preproc_b = nn.Conv2d(d_model,3,3,1,1)
        self.denoise_model = denoise_model
        self.std = 5./255
        nhead = 4 # 4
        num_enc_layers = 2 # 6
        num_dec_layers = 2 # 6
        dim_ff = 256 # 512
        dropout = 0.0
        xform_args = [d_model, nhead, num_enc_layers, num_dec_layers, dim_ff, dropout]
        self.use_pos_enc = True
        self.use_pos_enc_values = False
        self.use_all_loss = False
        self.all_loss_v = "v1"

        self.color_code = nn.Linear(d_model,3)

        d_model = 3
        # self.perform = Performer(d_model,8,1) # 8,4
        # self.xform = nn.Transformer(*xform_args)
        self.clusters = np.load(f"{settings.ROOT_PATH}/data/kmeans_centers.npy")

        # -- contrastive learning loss --
        num_transforms = n_frames
        hyperparams = edict()
        hyperparams.temperature = 0.1
        self.simclr_loss = ClBlockLoss(hyperparams, num_transforms, batch_size)

        # kwargs = {'num_tokens':512,'max_seq_len':input_patch**2,
        #           'dim':512, 'depth':6,'heads':4,'causal':False,'cross_attend':False}
        # self.xform_enc = PerformerLM(**kwargs)
        # self.xform_dec = PerformerLM(**kwargs)

        nhead = 1
        vdim = 1 if bw else 3
        self.pos_enc = PositionalEncoder(3,32*32*n_frames)
        # note: qkv_same_params is not a stock nn.MultiheadAttention kwarg;
        # this assumes a patched attention module with separate q/k/v projections
        self.attn_1 = nn.MultiheadAttention(3,nhead,dropout,qkv_same_params=False)
        self.attn_2 = nn.MultiheadAttention(3,nhead,dropout,qkv_same_params=False)
        self.attn_3 = nn.MultiheadAttention(3,nhead,dropout,qkv_same_params=False)
        self.attn = [self.attn_1,self.attn_2,self.attn_3]
        # self.stn = STN_Net()

        # -- pin value/output projections to the identity so the attention
        #    output stays a re-weighting (shift) of the inputs --
        for i in range(3):
            self.attn[i].v_proj_weight.data = torch.eye(3)
            self.attn[i].v_proj_weight = self.attn[i].v_proj_weight.requires_grad_(False)
            self.attn[i].out_proj.weight.data = torch.eye(3)
            self.attn[i].out_proj.bias.data = torch.zeros_like(self.attn[i].out_proj.bias.data)
            self.attn[i].out_proj = self.attn[i].out_proj.requires_grad_(False)

        # self.attn = nn.MultiheadAttention(d_model,nhead,dropout)#,vdim=vdim)
        self.input_patch,self.output_patch = input_patch,output_patch
        self.im_size = im_size

        ftr_size = (input_patch**2)*d_model
        img_size = (input_patch**2)*3
        in_channels = d_model * n_frames
        # in_channels = 2304
        # self.conv = nn.Sequential(*[nn.Conv2d(in_channels,3,1)])
        stride = input_patch // output_patch
        padding = 1
        # self.conv = nn.Sequential(*[nn.Conv2d(in_channels,3,3,stride,padding)])

        # -- left settings: reduce imsize by 2 pix (10 in, 8 out) --
        # stride = 1
        # padding = 0
        # self.conv = nn.Sequential(*[nn.Conv2d(in_channels,3,3,stride,padding)])

        # -- conv settings; reduce imsize by Half -- ( 16 in, 8 out ) ( A in, A/2 out )
        # stride = 2
        # padding = 1
        # self.conv = nn.Sequential(*[nn.Conv2d(in_channels,3,3,stride,padding)])

        # -- conv settings; maintain imsize -- ( A in, A out )
        stride = 1
        padding = 1
        out_chann = 1 if bw else 3
        # self.conv = nn.Sequential(*[nn.Conv2d(in_channels,out_chann,3,stride,padding)])

        # -- using MEAN as denoiser --
        in_channels = d_model
        # self.conv = nn.Sequential(*[nn.Conv2d(in_channels,out_chann,3,stride,padding)])

        self.end_conv = nn.Sequential(*[nn.Conv2d(3,3,1),nn.LeakyReLU(),nn.Conv2d(3,3,1)])
        self.unet = UNet_small(3)

        padding = (input_patch - output_patch) // 2
        # stride = im_size-input_patch
        stride = output_patch
        # print(input_patch,padding,stride)
        self.unfold_input = nn.Unfold(input_patch,1,padding,stride)
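
The for-loop above pins each attention block's value path to the identity so attention can only re-weight (shift) its inputs. Below is a standalone sketch of the same trick on stock nn.MultiheadAttention, where q, k, and v share a single stacked in_proj_weight; it is an illustration, not code from the project.

import torch
import torch.nn as nn

embed_dim, nhead = 3, 1
attn = nn.MultiheadAttention(embed_dim, nhead, dropout=0.0)

with torch.no_grad():
    # the value projection is the last third of the stacked q/k/v weights
    attn.in_proj_weight[2 * embed_dim:] = torch.eye(embed_dim)
    attn.in_proj_bias[2 * embed_dim:] = 0.
    attn.out_proj.weight.copy_(torch.eye(embed_dim))
    attn.out_proj.bias.zero_()
attn.out_proj.requires_grad_(False)   # freeze the identity output projection

x = torch.rand(10, 1, embed_dim)      # (seq, batch, embed)
out, _ = attn(x, x, x)                # rows of out are convex combos of x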
Example #7
def run_experiment(cfg, data, record_fn, bss_dir):

    #
    # Experiment Setup
    #

    # -- init variables for access --
    T = cfg.nframes
    H = cfg.nblocks
    framesize = 156
    patchsize = 32
    P = 9
    REF_H = get_ref_block_index(cfg.nblocks)
    nn_params = edict({'lr': 1e-3, 'init_params': None})
    gridT = torch.arange(T)

    # -- setup noise --
    cfg.noise_type = 'g'
    cfg.ntype = cfg.noise_type
    cfg.noise_params.ntype = cfg.noise_type
    noise_level = 25.
    cfg.noise_params['g']['stddev'] = noise_level
    noise_level_str = f"{int(noise_level)}"
    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)

    # -- simulate no motion --
    nomotion = np.zeros((T, 2)).astype(np.int64)
    image_index = 10
    full_image = data.tr[image_index][2]
    full_burst, aligned, motion = simulate_dynamics(full_image, T, nomotion, 0,
                                                    framesize)
    save_image(full_burst, "full_burst.png")
    clean_full_ref = full_burst[T // 2]

    # -- apply noise --
    full_noisy = noise_xform(full_burst)
    save_image(full_noisy, "full_noisy.png")
    noisy_full_ref = full_noisy[T // 2]

    #
    # Start Method
    #

    # -- find good patches from full noisy image --
    # init_tl_list = sample_good_init_tl(clean_full_ref,P,patchsize)
    init_tl_list = sample_good_init_tl(noisy_full_ref, P, patchsize)

    # -- grab blocks from selected patches for burst --
    clean = crop_burst_to_blocks(full_burst, cfg.nblocks, init_tl_list[0],
                                 patchsize)
    noisy = crop_burst_to_blocks(full_noisy, cfg.nblocks, init_tl_list[0],
                                 patchsize)

    # -- image for "test" function --
    REF_PATCH = 0  # backward compat. for functions without patch-dim support
    # image = clean[T//2,REF_H,REF_PATCH] # legacy
    image = noisy[T // 2, REF_H, REF_PATCH] + 0.5  # legacy

    # -- normalize --
    # clean -= clean.min()
    # clean /= clean.max()
    print_tensor_stats("clean", clean)
    print_tensor_stats("noisy", clean)
    print_tensor_stats("full_burst", full_burst)
    print_tensor_stats("full_noisy", full_noisy)

    save_image(clean, "fast_unet_clean.png", normalize=True)
    save_image(noisy, "fast_unet_noisy.png", normalize=True)
    save_image(image, "fast_unet_image.png", normalize=True)

    # -- select search space --
    # block_search_space = get_block_arangements_subset(cfg.nblocks,cfg.nframes,
    #                                                   tcount=4,difficult=True)
    use_rand = True
    if use_rand:
        tcount = 3
        size = 30
        bss = get_small_test_block_arangements(bss_dir,
                                               cfg.nblocks,
                                               cfg.nframes,
                                               tcount,
                                               size,
                                               difficult=True)
        print("LEN BSS", len(bss))
        block_search_space = bss
    else:
        # exhaustive fallback so block_search_space is always defined
        block_search_space = get_block_arangements(cfg.nblocks, cfg.nframes)

    # -- setup loop --
    clean = clean.to(cfg.device)
    noisy = noisy.to(cfg.device)
    image = image.to(cfg.device)
    record, idx = [], 0

    # -- search over search space --
    for prop in tqdm(block_search_space):

        # -- fetch --
        clean_prop = clean[gridT, prop].to(cfg.device)
        noisy_prop = noisy[gridT, prop].to(cfg.device)
        save_image(clean_prop[:, 0], "clean_prop.png")
        save_image(noisy_prop[:, 0], "noisy_prop.png")

        # -- compute COG --
        backbone = UNet_small
        if (nn_params['init_params'] is None):
            train_steps = cfg.cog_train_steps
        else:
            train_steps = cfg.cog_second_train_steps
        score = 0.
        # score = score_cog(cfg,image_volume,backbone,nn_params,train_steps)

        # -- [legacy] fill results with -1's --
        # results = fill_results(cfg,image,clean_prop,noisy_prop,None,idx)

        # -- compute single UNet --
        model = UNet_small(3).to(cfg.device)
        init_lr = 1e-4
        optim = torch.optim.Adam(model.parameters(),
                                 lr=init_lr,
                                 betas=(0.9, 0.99))
        train(cfg, image, clean_prop[:, REF_PATCH], noisy_prop[:, REF_PATCH],
              model, optim)
        print("noisy_prop.shape", noisy_prop.shape)
        results = test(cfg, image, clean_prop[:, REF_PATCH],
                       noisy_prop[:, REF_PATCH], model, idx)
        results['cog'] = score
        print(score, results['ave'], results['psnr_clean'], prop)

        # -- update --
        record.append(results)
        idx += 1

    record = pd.DataFrame(record)
    print(f"Writing record to {record_fn}")
    record.to_csv(record_fn)
    return record
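
The record is a flat CSV with one row per proposed block arrangement. A sketch of post-hoc inspection; the filename is a placeholder, and the column names (ave, psnr_clean, cog) come from the results dict filled in above.

import pandas as pd

record = pd.read_csv("record.csv")    # placeholder path
# rank proposals by clean PSNR, the metric printed in the search loop
best = record.sort_values("psnr_clean", ascending=False).head(5)
print(best[["ave", "psnr_clean", "cog"]])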
Example #8
def single_image_unet(cfg, queue, full_image, device):
    full_image = full_image.to(device)
    image = tvF.crop(full_image, 128, 128, 32, 32)
    T = 5

    # -- poisson noise --
    noise_type = "pn"
    cfg.noise_type = noise_type
    cfg.noise_params['pn']['alpha'] = 40.0
    cfg.noise_params['pn']['readout'] = 0.0
    cfg.noise_params.ntype = cfg.noise_type
    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)

    clean = torch.stack([image for i in range(T)], dim=0)
    noisy = noise_xform(clean)
    save_image(clean, "clean.png", normalize=True)
    m_clean, m_noisy = clean.clone(), noisy.clone()
    for i in range(T // 2):
        image_mis = tvF.crop(full_image, 128 + 1, 128, 32, 32)
        m_clean[i] = image_mis
        m_noisy[i] = noise_xform(m_clean[i])

    noise_level = 50. / 255.

    # -- model all --
    print("-- All Aligned --")
    model = UNet_small(3).to(device)  # UNet_n2n(1)
    cfg.init_lr = 1e-4
    optim = torch.optim.Adam(model.parameters(),
                             lr=cfg.init_lr,
                             betas=(0.9, 0.99))
    train(cfg, image, clean, noisy, model, optim)
    results = test(cfg, image, clean, noisy, model, 0)
    rec = model(noisy) + 0.5
    save_image(rec, "rec_all.png", normalize=True)
    print("Single Image Unet:")
    print(images_to_psnrs(clean, rec))
    print(images_to_psnrs(rec - 0.5, noisy))
    print(images_to_psnrs(rec[[0]], rec[[1]]))
    print(images_to_psnrs(rec[[0]], rec[[2]]))
    print(images_to_psnrs(rec[[1]], rec[[2]]))

    og_clean, og_noisy = clean.clone(), noisy.clone()
    clean[0] = image_mis
    noisy[0] = noise_xform(clean[[0]])[0]

    # -- model all --
    print("-- All Misligned --")
    model = UNet_small(3).to(device)  # UNet_n2n(1)
    cfg.init_lr = 1e-4
    optim = torch.optim.Adam(model.parameters(),
                             lr=cfg.init_lr,
                             betas=(0.9, 0.99))
    train(cfg, image, clean, noisy, model, optim)
    results = test(cfg, image, clean, noisy, model, 0)
    rec = model(noisy) + 0.5
    save_image(rec, "rec_all.png", normalize=True)
    print("Single Image Unet:")
    print(images_to_psnrs(clean, rec))
    print(images_to_psnrs(rec - 0.5, noisy))
    print(images_to_psnrs(rec[[0]], rec[[1]]))
    print(images_to_psnrs(rec[[0]], rec[[2]]))
    print(images_to_psnrs(rec[[1]], rec[[2]]))

    for j in range(3):
        # -- data --
        noisy1 = torch.stack([noisy[0], noisy[1]], dim=0)
        clean1 = torch.stack([clean[0], clean[1]], dim=0)

        noisy2 = torch.stack([noisy[1], noisy[2]], dim=0)
        clean2 = torch.stack([clean[1], clean[2]], dim=0)

        noisy3 = torch.stack([og_noisy[0], noisy[1]], dim=0)
        clean3 = torch.stack([og_clean[0], clean[1]], dim=0)

        # -- model 1 --
        model = UNet_small(3).to(device)  # UNet_n2n(1)
        cfg.init_lr = 1e-4
        optim = torch.optim.Adam(model.parameters(),
                                 lr=cfg.init_lr,
                                 betas=(0.9, 0.99))
        train(cfg, image, clean1, noisy1, model, optim)
        results = test(cfg, image, clean1, noisy1, model, 0)
        rec = model(noisy1) + 0.5
        xrec = model(noisy2) + 0.5
        save_image(rec, "rec1.png", normalize=True)
        print("[misaligned] Single Image Unet:", images_to_psnrs(clean1, rec),
              images_to_psnrs(clean2, xrec),
              images_to_psnrs(rec - 0.5, noisy1),
              images_to_psnrs(xrec - 0.5, noisy2),
              images_to_psnrs(rec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[0]], xrec[[1]]),
              images_to_psnrs(xrec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[1]], rec[[0]]))

        # -- model 2 --
        model = UNet_small(3).to(device)  # UNet_n2n(1)
        cfg.init_lr = 1e-4
        optim = torch.optim.Adam(model.parameters(),
                                 lr=cfg.init_lr,
                                 betas=(0.9, 0.99))
        train(cfg, image, clean2, noisy2, model, optim)
        results = test(cfg, image, clean2, noisy2, model, 0)
        rec = model(noisy2) + 0.5
        xrec = model(noisy1) + 0.5
        save_image(rec, "rec2.png", normalize=True)
        print("[aligned] Single Image Unet:", images_to_psnrs(clean2, rec),
              images_to_psnrs(clean1, xrec),
              images_to_psnrs(rec - 0.5, noisy2),
              images_to_psnrs(xrec - 0.5, noisy1),
              images_to_psnrs(rec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[0]], xrec[[1]]),
              images_to_psnrs(xrec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[1]], rec[[0]]))

        # -- model 3 --
        model = UNet_small(3).to(device)  # UNet_n2n(1)
        cfg.init_lr = 1e-4
        optim = torch.optim.Adam(model.parameters(),
                                 lr=cfg.init_lr,
                                 betas=(0.9, 0.99))
        train(cfg, image, clean3, noisy3, model, optim)
        results = test(cfg, image, clean3, noisy3, model, 0)
        rec = model(noisy3) + 0.5
        rec_2 = model(noisy2) + 0.5
        rec_1 = model(noisy1) + 0.5
        save_image(rec, "rec1.png", normalize=True)
        print("[aligned (v3)] Single Image Unet:")
        print("clean-rec", images_to_psnrs(clean3, rec))
        print("clean1-rec1", images_to_psnrs(clean1, rec_1))
        print("clean2-rec2", images_to_psnrs(clean2, rec_2))
        print("rec-noisy3", images_to_psnrs(rec - 0.5, noisy3))
        print("rec1-noisy1", images_to_psnrs(rec_1 - 0.5, noisy1))
        print("rec2-noisy2", images_to_psnrs(rec_2 - 0.5, noisy2))
        print("[v3]: rec0-rec1", images_to_psnrs(rec[[0]], rec[[1]]))
        print("[v1]: rec0-rec1", images_to_psnrs(rec_1[[0]], rec_1[[1]]))
        print("[v2]: rec0-rec1", images_to_psnrs(rec_2[[0]], rec_2[[1]]))
        print("[v1-v2](a):", images_to_psnrs(rec_1[[1]], rec_2[[1]]))
        print("[v1-v2](b):", images_to_psnrs(rec_1[[1]], rec_2[[0]]))
        print("[v1-v2](c):", images_to_psnrs(rec_1[[0]], rec_2[[1]]))
        print("[v1-v2](d):", images_to_psnrs(rec_1[[0]], rec_2[[0]]))
        print("-" * 20)
        print("[v2-v3](a):", images_to_psnrs(rec_2[[1]], rec[[1]]))
        print("[v2-v3](b):", images_to_psnrs(rec_2[[1]], rec[[0]]))
        print("[v2-v3](c):", images_to_psnrs(rec_2[[0]], rec[[1]]))
        print("[v2-v3](d):", images_to_psnrs(rec_2[[0]], rec[[0]]))
        print("-" * 20)
        print("[v1-v3](a):", images_to_psnrs(rec_1[[1]], rec[[1]]))
        print("[v1-v3](b):", images_to_psnrs(rec_1[[1]], rec[[0]]))
        print("[v1-v3](c):", images_to_psnrs(rec_1[[0]], rec[[1]]))
        print("[v1-v3](d):", images_to_psnrs(rec_1[[0]], rec[[0]]))
        print("-" * 20)
        print("rec0-rec1", images_to_psnrs(rec[[0]], rec[[1]]))