Example 1
def ot_all_pairwise_items(residuals_raw, reg=1.0, S=6):
    """
    :param residuals_raw: shape [B N D C]
    """

    # -- init --
    residuals = residuals_raw.detach().requires_grad_(False)
    B, N, D, C = residuals.shape

    # -- restrict number of terms  --
    Ngrid_i, Ngrid_j = np.tril_indices(N)
    order = npr.permutation(len(Ngrid_i))
    Ngrid_i, Ngrid_j = Ngrid_i[order[:S]], Ngrid_j[order[:S]]

    # -- compute losses --
    records = []  # DataFrame.append was removed in pandas 2.0; collect rows first
    for bi in range(B):
        for bj in range(B):
            if bi > bj: continue
            for ni, nj in zip(Ngrid_i, Ngrid_j):
                ri, rj = residuals[bi, ni], residuals[bj, nj]
                M = torch.sum(torch.pow(ri.unsqueeze(1) - rj, 2), dim=-1)
                loss = sink_stabilized(M, reg)
                result = {'indices': (bi, bj, ni, nj), 'loss': loss.item()}
                records.append(result)
    return pd.DataFrame(records)
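All of the examples on this page assume the same surrounding imports (numpy as np, numpy.random as npr, pandas as pd, torch, plus torch.nn.functional as F and einops.rearrange for the training loop in Example 11) and a project-specific sink_stabilized(M, reg) helper that solves an entropically regularized optimal-transport problem on the pairwise cost matrix M. The project's own implementation is not shown here; as a point of reference, the following is a minimal log-domain (stabilized) Sinkhorn sketch with uniform marginals, written so it stays differentiable for the "_bp" (backprop) variants below. Treat it as an assumption, not the original helper.

import math
import torch

def sink_stabilized(M, reg, n_iters=100):
    """Hypothetical stand-in for the project's stabilized Sinkhorn solver.

    M: [n, m] pairwise cost matrix; reg: entropic regularization strength.
    Returns <P, M>, the regularized transport cost under uniform marginals.
    """
    n, m = M.shape
    log_a = torch.full((n, 1), -math.log(n), device=M.device, dtype=M.dtype)
    log_b = torch.full((1, m), -math.log(m), device=M.device, dtype=M.dtype)
    f = torch.zeros_like(log_a)  # dual potentials, updated in the log domain
    g = torch.zeros_like(log_b)
    for _ in range(n_iters):
        f = reg * (log_a - torch.logsumexp((g - M) / reg, dim=1, keepdim=True))
        g = reg * (log_b - torch.logsumexp((f - M) / reg, dim=0, keepdim=True))
    P = torch.exp((f + g - M) / reg)  # transport plan
    return torch.sum(P * M)

Working in the log domain is what makes the iteration stable for small reg, where the kernel exp(-M/reg) would underflow.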
Example 2
def ot_frame_pairwise_bp(self, diffs, S=6):
    """
    :param diffs: shape [B N D C]
    :param S: number of frame pairs to sample (assumed parameter; it was
        used before assignment in the source)
    """
    BS, N = diffs.shape[:2]
    ot_loss = 0
    for b in range(BS):
        pairs, S = self._select_pairs(S, N)
        for (i, j) in pairs:
            di, dj = diffs[b, i], diffs[b, j]
            M = torch.sum(torch.pow(di.unsqueeze(1) - dj, 2), dim=-1)
            ot_loss += sink_stabilized(M, self.reg)
    ot_loss /= S * BS
    return ot_loss
Example 3
def ot_burst_pairwise(residuals,reg=0.5):
    """
    :param residuals: shape [B D C]
    """
    # -- init --
    B,D,C = residuals.shape

    # -- create pairs --
    pairs,S = select_pairs(B)

    # -- compute losses --
    ot_loss,ot_loss_w = 0,0
    for (i,j) in pairs:
        ri,rj = residuals[i],residuals[j]
        M = torch.sum(torch.pow(ri.unsqueeze(1) - rj,2),dim=-1)
        loss = sink_stabilized(M,reg)
        ot_loss += loss
        weight = (torch.sum(ri**2) + torch.sum(rj**2)).item()
        ot_loss_w += weight * loss
    return ot_loss,ot_loss_w
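Examples 3 and 10 call a select_pairs helper that returns a list of index pairs and their count; its source is not included on this page. A minimal sketch consistent with how it is used (unordered pairs i < j, optionally subsampled); the optional S argument is an assumption:

import itertools
import numpy as np

def select_pairs(N, S=None):
    # Hypothetical sketch: all unordered pairs (i < j) from range(N),
    # optionally subsampled down to S pairs.
    pairs = list(itertools.combinations(range(N), 2))
    if S is None or S >= len(pairs):
        return pairs, len(pairs)
    keep = np.random.choice(len(pairs), size=S, replace=False)
    return [pairs[k] for k in keep], S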
Example 4
def ot_gaussian_bp(residuals, std=25, reg=1.0, K=4):
    """
    :param residuals: shape [B D C]
    """

    # -- init --
    B, D, C = residuals.shape

    # -- grab some batches --
    Bgrid = torch.randperm(B)[:K]

    # -- compute losses --
    ot_loss = 0
    for b in Bgrid:
        rb = residuals[b]
        noise = torch.normal(torch.zeros_like(rb), std=std / 255.)
        M = torch.sum(torch.pow(rb.unsqueeze(1) - noise, 2), dim=-1)
        loss = sink_stabilized(M, reg)
        ot_loss += loss
    return ot_loss / len(Bgrid)
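Since ot_gaussian_bp never detaches its input, its output can be backpropagated directly (provided sink_stabilized is differentiable, as in the sketch above). A small usage sketch with made-up shapes:

# 4 bursts, 128 pixels, 3 channels; values roughly at the noise scale
residuals = (torch.randn(4, 128, 3) * (25 / 255.)).requires_grad_()
loss = ot_gaussian_bp(residuals, std=25, reg=1.0, K=4)
loss.backward()  # gradients flow back into `residuals`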
Example 5
def ot_all_gaussian_items(residuals_raw, std=25, reg=1.0):
    """
    :param residuals_raw: shape [B N D C]
    """

    # -- init --
    residuals = residuals_raw.detach().requires_grad_(False)
    B, N, D, C = residuals.shape

    # -- compute losses --
    records = []  # DataFrame.append was removed in pandas 2.0; collect rows first
    for b in range(B):
        for n in range(N):
            r = residuals[b, n]
            noise = torch.normal(torch.zeros_like(r), std=std / 255.)
            M = torch.sum(torch.pow(r.unsqueeze(1) - noise, 2), dim=-1)
            loss = sink_stabilized(M, reg)
            result = {'indices': (b, n), 'loss': loss.item()}
            records.append(result)
    return pd.DataFrame(records)
Example 6
def ot_frame_pairwise_xbatch_bp(self, residuals, reg=0.5, K=3):
    """
    :param residuals: shape [B N D C]
    """

    # -- init --
    B, N, D, C = residuals.shape

    # -- create index tuples --
    S = B * K
    indices, S = create_ot_indices(B, N, S)

    # -- compute losses --
    ot_loss = 0
    for (bi, bj, i, j) in indices:
        ri, rj = residuals[bi, i], residuals[bj, j]
        M = torch.sum(torch.pow(ri.unsqueeze(1) - rj, 2), dim=-1)
        loss = sink_stabilized(M, reg)
        weight = (torch.mean(ri) + torch.mean(rj)) / 2
        ot_loss += loss * weight.item()
    return ot_loss / len(indices)
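The create_ot_indices(B, N, S) helper is likewise not shown. Based on its use here, it should produce S quadruples (bi, bj, i, j) indexing a batch pair and a frame pair. A hypothetical sketch:

import numpy.random as npr

def create_ot_indices(B, N, S):
    # Hypothetical sketch: S random (bi, bj, i, j) quadruples with
    # bi <= bj over batches and i < j over frames.
    indices = []
    for _ in range(S):
        bi, bj = sorted(npr.randint(0, B, 2))
        i, j = sorted(npr.choice(N, 2, replace=False))
        indices.append((int(bi), int(bj), int(i), int(j)))
    return indices, len(indices)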
Example 7
def ot_all_pairwise_items(residuals_raw,reg=1.0):
    """
    :param residuals_raw: shape [B N D C]
    """
    
    # -- init --
    residuals = residuals_raw.detach().requires_grad_(False)
    B,N,D,C = residuals.shape

    records = []  # DataFrame.append was removed in pandas 2.0; collect rows first
    for bi in range(B):
        for bj in range(B):
            if bi > bj: continue
            for ni in range(N):
                for nj in range(N):
                    if ni > nj: continue
                    ri,rj = residuals[bi,ni],residuals[bj,nj]
                    M = torch.sum(torch.pow(ri.unsqueeze(1) - rj,2),dim=-1)
                    loss = sink_stabilized(M,reg)
                    result = {'indices':(bi,bj,ni,nj),'loss':loss.item()}
                    records.append(result)
    return pd.DataFrame(records)
Example 8
def ot_pairwise_bp(residuals,reg=1.0,K=3):
    """
    :param residuals: shape [B N D C]
    """
    
    # -- init --
    B,N,D,C = residuals.shape

    # -- compute all ot --
    results = ot_all_pairwise_items(residuals,reg=reg)

    # -- select the top-K index tuples --
    indices = get_ot_topK(results,K)

    # -- compute losses --
    ot_loss = 0
    for (bi,bj,i,j) in indices:
        ri,rj = residuals[bi,i],residuals[bj,j]
        M = torch.sum(torch.pow(ri.unsqueeze(1) - rj,2),dim=-1)
        loss = sink_stabilized(M,reg)
        ot_loss += loss
    return ot_loss / len(indices)
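get_ot_topK is another helper whose source is not on this page. Given that Examples 1 and 5 record a scalar loss per index tuple, a plausible sketch keeps the K index tuples with the largest recorded loss (the selection criterion is an assumption):

def get_ot_topK(results, K):
    # Hypothetical sketch: 'indices' of the K rows with the largest loss.
    return list(results.nlargest(K, 'loss')['indices'])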
Example 9
def ot_pairwise2gaussian_bp(residuals, std=25, reg=1.0, K=3):
    """
    :param residuals: shape [B N D C]
    """

    # -- init --
    B, N, D, C = residuals.shape

    # -- compute all ot --
    results = ot_all_gaussian_items(residuals, std=std, reg=reg)

    # -- select the top-K items --
    indices = get_ot_topK(results, K)

    # -- compute losses --
    ot_loss = 0
    for (b, n) in indices:
        r = residuals[b, n]
        noise = torch.normal(torch.zeros_like(r), std=std / 255)
        M = torch.sum(torch.pow(r.unsqueeze(1) - noise, 2), dim=-1)
        loss = sink_stabilized(M, reg, print_me=False, print_period=5)
        ot_loss += loss
    return ot_loss / len(indices)
Example 10
def ot_frame_pairwise(residuals,reg=0.5):
    """
    :param residuals: shape [B N D C]
    """
    
    # -- init --
    B,N,D,C = residuals.shape

    # -- create pairs --
    pairs,S = select_pairs(N)

    # -- compute losses --
    ot_loss,ot_loss_mid,ot_loss_w = 0,0,0
    for b in range(B):
        for (i,j) in pairs:
            ri,rj = residuals[b,i],residuals[b,j]
            M = torch.sum(torch.pow(ri.unsqueeze(1) - rj,2),dim=-1)
            loss = sink_stabilized(M,reg)
            ot_loss += loss
            weight = (torch.sum(ri**2) + torch.sum(rj**2)).item()
            ot_loss_w += weight * loss
            if i == N//2 or j == N//2:
                ot_loss_mid += loss
    return ot_loss,ot_loss_mid,ot_loss_w
Example 11
def train_loop_offset(cfg, model, optimizer, criterion, train_loader, epoch,
                      record_losses):
    model.train()
    model = model.to(cfg.device)
    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    D = 5 * 10**3
    if record_losses is None:
        record_losses = pd.DataFrame({
            'kpn': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })

    for batch_idx, (burst_imgs, res_imgs, raw_img) in enumerate(train_loader):
        if batch_idx > D: break

        optimizer.zero_grad()
        model.zero_grad()

        # -- reshaping of data --
        input_order = np.arange(cfg.N)
        middle = len(input_order) // 2
        middle_img_idx = input_order[middle]

        N, BS, C, H, W = burst_imgs.shape
        burst_imgs = burst_imgs.cuda(non_blocking=True)

        # -- add input noise --
        burst_imgs_noisy = burst_imgs.clone()
        if cfg.input_noise:
            noise = np.random.rand() * cfg.input_noise_level
            if cfg.input_noise_middle_only:
                burst_imgs_noisy[middle_img_idx] = torch.normal(
                    burst_imgs_noisy[middle_img_idx], noise)
            else:
                burst_imgs_noisy = torch.normal(burst_imgs_noisy, noise)

        # -- create inputs for kpn --
        stacked_burst = torch.stack(
            [burst_imgs_noisy[input_order[x]] for x in range(cfg.input_N)],
            dim=1)
        cat_burst = torch.cat(
            [burst_imgs_noisy[input_order[x]] for x in range(cfg.input_N)],
            dim=1)

        # -- extract target image --
        if cfg.blind:
            t_img = burst_imgs[middle_img_idx]
        else:
            t_img = szm(raw_img.cuda(non_blocking=True))

        # -- direct denoising --
        rec_img_i, rec_img = model(cat_burst, stacked_burst)

        # -- compute kpn loss to optimize --
        kpn_losses = criterion(rec_img_i, rec_img, t_img, cfg.global_step)
        # built-in sum keeps the autograd graph, assuming a list of scalar losses
        kpn_loss = sum(kpn_losses)

        # -- compute blockwise differences --
        rec_img_i_bn = rearrange(rec_img_i, 'b n c h w -> (b n) c h w')
        r_middle_img = t_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        r_middle_img = rearrange(r_middle_img, 'b n c h w -> (b n) c h w')
        diffs = r_middle_img - rec_img_i_bn

        # -- compute OT loss --
        diffs = rearrange(diffs, '(b n) c h w -> b n (h w) c', n=N)
        ot_loss = 0
        pairs = [(i, j) for i in range(N) for j in range(N) if i < j]
        P = len(pairs)
        S = 3  # number of randomly sampled frame pairs
        r_idx = npr.choice(range(P), S)
        for idx in r_idx:
            i, j = pairs[idx]
            for b in range(BS):
                di, dj = diffs[b, i], diffs[b, j]
                M = torch.sum(torch.pow(di.unsqueeze(1) - dj, 2), dim=-1)
                ot_loss += sink_stabilized(M, 0.5)
        ot_loss /= S * BS

        # -- combine losses --
        ot_coeff = 10
        loss = kpn_loss + ot_coeff * ot_loss

        # -- update info --
        running_loss += loss.item()
        total_loss += loss.item()

        # -- BP and optimize --
        loss.backward()
        optimizer.step()

        # -- compute mse for fun --
        BS = raw_img.shape[0]
        raw_img = raw_img.cuda(non_blocking=True)
        mse_loss = F.mse_loss(raw_img, rec_img + 0.5,
                              reduction='none').reshape(BS, -1)
        mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
        psnr = np.mean(mse_to_psnr(mse_loss))
        psnr_std = np.std(mse_to_psnr(mse_loss))
        # DataFrame.append was removed in pandas 2.0; use pd.concat instead
        record = pd.DataFrame([{
            'kpn': kpn_loss.item(),
            'ot': ot_loss.item(),
            'psnr': psnr,
            'psnr_std': psnr_std
        }])
        record_losses = pd.concat([record_losses, record], ignore_index=True)
        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:
            running_loss /= cfg.log_interval
            print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f" %
                  (epoch, cfg.epochs, batch_idx, len(train_loader),
                   running_loss, psnr, psnr_std))
            running_loss = 0
        cfg.global_step += 1
    total_loss /= len(train_loader)
    return total_loss, record_losses
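The loop logs PSNR through an mse_to_psnr helper that is not defined on this page. For images scaled to [0, 1] (as the `rec_img + 0.5` shift suggests), the standard conversion is PSNR = 10 * log10(1 / MSE); a minimal sketch under that assumption:

import numpy as np

def mse_to_psnr(mse, max_val=1.0):
    # Hypothetical sketch: elementwise PSNR for signals bounded by max_val.
    return 10.0 * np.log10(max_val ** 2 / mse)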
Example 12
def compute_pair_ot(ri, rj, reg):
    """Sinkhorn loss between two residual sets ri, rj of shape [D C]."""
    M = torch.sum(torch.pow(ri.unsqueeze(1) - rj, 2), dim=-1)
    loss = sink_stabilized(M, reg)
    return loss
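A quick usage sketch with made-up shapes, following the per-item [D, C] convention used throughout:

ri = torch.randn(64, 3)  # 64 spatial positions, 3 channels
rj = torch.randn(64, 3)
loss = compute_pair_ot(ri, rj, reg=0.5)  # scalar Sinkhorn cost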