예제 #1
0
PILimg2.save(join(savedir, "eigvect_sph_fin_trunc%.1f_%04d.jpg" % (1, RND)))
print("Spent time %.1f sec" % (time() - T00))

#%%
# Time one full-Hessian computation of the image distance with respect to a
# single randomly drawn latent code.
T00 = time()
truncation = 0.8
truncation_mean = 4096
# Random integer tag in [0, 10000); it is not used within this cell.
RND = np.random.randint(10000)
# Mean latent obtained from g_ema; passed to G.visualize together with
# `truncation` below.
mean_latent = g_ema.mean_latent(truncation_mean)
ref_z = torch.randn(1, latent, device=device).cuda()
# Independent copy of the reference code that tracks gradients; the Hessian
# below is taken with respect to this copy.
mov_z = ref_z.detach().clone().requires_grad_(
    True)  # requires grad doesn't work for 1024 images.
ref_samp = G.visualize(ref_z, truncation=truncation, mean_latent=mean_latent)
mov_samp = G.visualize(mov_z, truncation=truncation, mean_latent=mean_latent)
# Distance between the two generated samples, as computed by ImDist.
dsim = ImDist(ref_samp, mov_samp)
H = get_full_hessian(dsim, mov_z)
# Elapsed wall-clock time since the start of the cell.
print(time() - T00, " sec")

#%%
savedir = r"E:\OneDrive - Washington University in St. Louis\HessGANCmp\StyleGAN2"
# NOTE: this value is overwritten by the inner loop variable of the same name.
truncation = 0.5
# NOTE: T00 is reset at the start of every inner iteration below.
T00 = time()
# Three trials, each sweeping over three truncation values.
for triali in range(3):
    for truncation in [1, 0.8, 0.6]:
        # The combination (truncation == 1, first trial) is skipped.
        if truncation == 1 and triali == 0:
            continue
        T00 = time()
        truncation_mean = 4096
        # Random integer tag in [0, 1000) drawn for this iteration.
        RND = np.random.randint(1000)
        mean_latent = g_ema.mean_latent(truncation_mean)
        # Fresh random reference latent code for this iteration.
        ref_z = torch.randn(1, latent, device=device).cuda()
            num_eigenthings=800,
            use_gpu=True,
            max_steps=200,
            tol=1e-6,
        )
        eigvects = eigvects.T
        # EPS=1E-2, max_steps=20 takes 84 sec on K20x cluster.
        # The hessian is not so close
    elif hessian_method == "BP":  # 240 sec on cluster
        ref_vect = feat.detach().clone().float().cuda()
        mov_vect = ref_vect.float().detach().clone().requires_grad_(True)
        imgs1 = G.visualize(ref_vect)
        imgs2 = G.visualize(mov_vect)
        dsim = model_squ(imgs1, imgs2)
        H = get_full_hessian(
            dsim, mov_vect
        )  # 122 sec for a 256d hessian, # 240 sec on cluster for 4096d hessian
        eigvals, eigvects = np.linalg.eigh(H)

    print(
        "Finish computing img %d %.2f sec passed, max %.2e min %.2e 5th %.1e 10th %.1e 50th %.1e 100th %.1e 200th "
        "%.1e 400th %.1e" %
        (imgi, time() - t0, max(np.abs(eigvals)), min(
            np.abs(eigvals)), eigvals[-5], eigvals[-10], eigvals[-50],
         eigvals[-100], eigvals[-200], eigvals[-400]))
    np.savez(join(
        out_dir,
        "%s_%03d_%s.npz" % (args.dataset, imgi, labeldict[hessian_method])),
             eigvals=eigvals,
             eigvects=eigvects,
             code=code)
예제 #3
0
#%% Compute Hessian decomposition and get the vectors
# Method selector; only the "BP" branch is visible in this part of the script.
Hess_method = "BP"  # "BackwardIter" "ForwardIter"
# When False, the Hessian over the full (noise + class) vector is skipped, so
# H / eigvals / eigvects are not assigned by this block.
Hess_all = False # Set to False to reduce computation time.
t0 = time()
if Hess_method == "BP":
    print("Computing Hessian Decomposition Through auto-grad and full eigen decomposition.")
    classvec = torch.from_numpy(ref_class_vec).float().cuda()  # embed_mat[:, class_id:class_id+1].cuda().T
    noisevec = torch.from_numpy(ref_noise_vec).float().cuda()
    # The latent code is the noise part followed by the class part along dim 1.
    ref_vect = torch.cat((noisevec, classvec, ), dim=1).detach().clone()
    # Gradient-tracking copy of the reference code.
    mov_vect = ref_vect.detach().clone().requires_grad_(True)
    #%
    # Reference image; reused for every distance computed below.
    imgs1 = G.visualize(ref_vect)
    if Hess_all:
        # Hessian with respect to the whole concatenated vector.
        imgs2 = G.visualize(mov_vect)
        dsim = ImDist(imgs1, imgs2)
        H = get_full_hessian(dsim, mov_vect)  # 77sec to compute a Hessian. # 114sec on ML2a
        # ToPILImage()(imgs[0,:,:,:].cpu()).show()
        eigvals, eigvects = np.linalg.eigh(H)  # 75 ms
    #%
    # Hessian with respect to the noise part only: gradients are enabled on
    # noisevec and disabled on classvec before rebuilding the input.
    noisevec.requires_grad_(True)
    classvec.requires_grad_(False)
    mov_vect = torch.cat((noisevec, classvec, ), dim=1)
    imgs2 = G.visualize(mov_vect)
    dsim = ImDist(imgs1, imgs2)
    H_nois = get_full_hessian(dsim, noisevec)  # 39.3 sec to compute a Hessian.# 59 sec on ML2a
    eigvals_nois, eigvects_nois = np.linalg.eigh(H_nois)  # 75 ms
    #%
    # Now the reverse: gradients enabled on classvec only. The input is rebuilt
    # and re-rendered here; the rest of this step lies beyond the visible code.
    noisevec.requires_grad_(False)
    classvec.requires_grad_(True)
    mov_vect = torch.cat((noisevec, classvec, ), dim=1)
    imgs2 = G.visualize(mov_vect)
예제 #4
0
# H = get_full_hessian()
#%%
savedir = r"E:\iclr2021\Results"
# This second assignment replaces the path above; only this value takes effect.
savedir = r"E:\OneDrive - Washington University in St. Louis\HessGANCmp"
#%%
T00 = time()
# For each listed class index: compute the Hessian of the image distance with
# respect to the full latent code, then with respect to the noise part only.
for class_id in [17, 79, 95, 107, 224, 346, 493, 542, 579, 637, 667, 754, 761, 805, 814, 847, 856, 941, 954, 968]:
    # Single-column slice of embed_mat for this class, moved to the GPU and transposed.
    classvec = embed_mat[:, class_id:class_id+1].cuda().T
    noisevec = torch.from_numpy(truncated_noise_sample(1, 128, 0.6)).cuda()
    # The latent code is the noise part followed by the class part along dim 1.
    ref_vect = torch.cat((noisevec, classvec, ), dim=1).detach().clone()
    # Gradient-tracking copy of the reference code.
    mov_vect = ref_vect.detach().clone().requires_grad_(True)
    #%%
    # Hessian with respect to the whole concatenated vector.
    imgs1 = G.visualize(ref_vect)
    imgs2 = G.visualize(mov_vect)
    dsim = ImDist(imgs1, imgs2)
    H = get_full_hessian(dsim, mov_vect)  # 77sec to compute a Hessian.
    # ToPILImage()(imgs[0,:,:,:].cpu()).show()
    eigvals, eigvects = np.linalg.eigh(H)  # 75 ms
    #%%
    # Hessian with respect to the noise part only: gradients are enabled on
    # noisevec and disabled on classvec before rebuilding the input.
    noisevec.requires_grad_(True)
    classvec.requires_grad_(False)
    mov_vect = torch.cat((noisevec, classvec, ), dim=1)
    imgs2 = G.visualize(mov_vect)
    dsim = ImDist(imgs1, imgs2)
    H_nois = get_full_hessian(dsim, noisevec)  # 39.3 sec to compute a Hessian.
    eigvals_nois, eigvects_nois = np.linalg.eigh(H_nois)  # 75 ms
    #%
    # Now the reverse: gradients enabled on classvec only. The input is rebuilt
    # and re-rendered here; the rest of this step lies beyond the visible code.
    noisevec.requires_grad_(False)
    classvec.requires_grad_(True)
    mov_vect = torch.cat((noisevec, classvec, ), dim=1)
    imgs2 = G.visualize(mov_vect)
예제 #5
0
        16,
        17,
        18,
        19,
        20,
        21,
        22,
        23,
]:  #
    L2dist_col = []
    torch.cuda.empty_cache()
    H1 = G.G[Li].register_forward_hook(Hess_hook)
    img = G.visualize(feat)
    H1.remove()
    T0 = time()
    H10 = get_full_hessian(L2dist_col[0], feat)
    eva10, evc10 = np.linalg.eigh(H10)
    print("Layer %d, cost %.2f sec" % (Li, time() - T0))
    #%
    np.savez(join(archdir, "eig_Layer%d.npz" % (Li)), evc=evc10, eva=eva10)
    plt.plot(np.log10(eva10)[::-1])
    plt.title("Layer %d %s\n%s" % (Li, layernames[Li], G.G[Li].__repr__()))
    plt.xlim([0, 4096])
    plt.savefig(join(archdir, "spectrum_Layer%d.png" % (Li)))
    plt.show()
#%%
eva_col = []
for Li in [
        0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
        21, 22, 23
]:
예제 #6
0
def hessian_compute(G,
                    feat,
                    ImDist,
                    hessian_method="BackwardIter",
                    cutoff=None,
                    preprocess=lambda img: img,
                    EPS=1E-2,
                    device="cuda"):
    """Higher level API for computing the Hessian of an image distance in GAN latent space.

    Parameters:
        G: GAN, usually wrapped up by a custom class. Equipped with a
            `visualize` function that takes a torch vector and outputs a torch
            image.
        feat: a latent code as input to the GAN.
        ImDist: the image distance function. Supports
            `dsim = ImDist(img1, img2)`: takes in 2 torch images and outputs a
            scalar distance through which gradients can flow. If it has a
            callable `to` attribute, `ImDist.to(device)` is called first.
        hessian_method: one of "BP", "ForwardIter", "BackwardIter".
        cutoff: for the iterative methods ("ForwardIter", "BackwardIter"), how
            many eigenpairs to compute. Defaults to `feat.numel() // 2 - 1`.
            Not used by "BP".
        preprocess: operation applied to the images generated by the GAN
            before they are compared. Defaults to the identity map.
            `lambda img: F.interpolate(img, (256, 256), mode='bilinear',
            align_corners=True)` is a common choice.
        EPS: forwarded to `GANForwardMetricHVPOperator`; only used by
            "ForwardIter".
        device: device passed to `ImDist.to(...)` and, for "BP", the device
            the latent code is moved to.

    Returns:
        (eigvals, eigvects, H). For the iterative methods `eigvects` is the
        transposed Lanczos output (one eigenvector per column, matching the
        layout of `np.linalg.eigh`) and `H` is the reconstruction
        `eigvects @ diag(eigvals) @ eigvects.T` from the computed eigenpairs.
        For "BP", `H` is the output of `get_full_hessian` and the eigenpairs
        come from `np.linalg.eigh(H)`.

    Raises:
        NotImplementedError: if `hessian_method` is not a supported name.
    """
    supported_methods = ("BackwardIter", "ForwardIter", "BP")
    # Fail fast, before touching ImDist or feat, and say what was wrong.
    if hessian_method not in supported_methods:
        raise NotImplementedError(
            "Unsupported hessian_method %r; expected one of %s." %
            (hessian_method, ", ".join(supported_methods)))

    if cutoff is None:
        cutoff = feat.numel() // 2 - 1
    # Plain callables have no `.to`; module-like distance objects do.
    if callable(getattr(ImDist, "to", None)):
        ImDist.to(device)

    if hessian_method == "BackwardIter":
        metricHVP = GANHVPOperator(G, feat, ImDist, preprocess=preprocess)
        # NOTE(review): use_gpu is fixed to True regardless of `device`.
        eigvals, eigvects = lanczos(
            metricHVP, num_eigenthings=cutoff,
            use_gpu=True)  # takes 113 sec on K20x cluster,
        # lanczos returns eigenvectors as rows, unlike linalg.eigh; transpose
        # so that each column is an eigenvector.
        eigvects = eigvects.T
        H = eigvects @ np.diag(eigvals) @ eigvects.T
        # The spectrum has a close correspondence with the full Hessian,
        # since they use the same graph.
    elif hessian_method == "ForwardIter":
        metricHVP = GANForwardMetricHVPOperator(G,
                                                feat,
                                                ImDist,
                                                preprocess=preprocess,
                                                EPS=EPS)  # 1E-3,)
        # NOTE(review): use_gpu is fixed to True regardless of `device`.
        eigvals, eigvects = lanczos(
            metricHVP,
            num_eigenthings=cutoff,
            use_gpu=True,
            max_steps=200,
            tol=1e-6,
        )
        eigvects = eigvects.T  # rows -> columns, as in the branch above
        H = eigvects @ np.diag(eigvals) @ eigvects.T
        # EPS=1E-2, max_steps=20 takes 84 sec on K20x cluster.
        # The hessian is not so close
    else:  # "BP": 240 sec on cluster
        ref_vect = feat.detach().clone().float().to(device)
        # Gradient-tracking copy; the Hessian is taken with respect to it.
        mov_vect = ref_vect.detach().clone().requires_grad_(True)
        imgs1 = G.visualize(ref_vect)
        imgs2 = G.visualize(mov_vect)
        dsim = ImDist(preprocess(imgs1), preprocess(imgs2))
        # 122 sec for a 256d hessian, 240 sec on cluster for 4096d hessian
        H = get_full_hessian(dsim, mov_vect)
        eigvals, eigvects = np.linalg.eigh(H)
    return eigvals, eigvects, H
예제 #7
0
savedir = r"E:\OneDrive - Washington University in St. Louis\HessGANCmp\BigBiGAN"
#%%
T00 = time()
# 20 trials; in each, sweep the norm (`trunc`) of a random latent direction and
# compute, plot and save the Hessian spectrum at that point.
for triali in range(20):
    for trunc in [0.1, 1, 3, 6, 9, 10, 12, 15]:
        # The 0.1 entry is listed but always skipped.
        if trunc == 0.1:
            continue
        # Random integer tag in [0, 1000) used in the output file names.
        RND = np.random.randint(1000)
        noisevect = torch.randn(1, 120)
        # Normalise to unit norm, then scale so the reference code has norm `trunc`.
        noisevect = noisevect / noisevect.norm()
        ref_vect = trunc * noisevect.detach().clone().cuda()
        # Gradient-tracking copy; the Hessian is taken with respect to it.
        mov_vect = ref_vect.detach().clone().requires_grad_(True)
        imgs1 = G.visualize(ref_vect)
        imgs2 = G.visualize(mov_vect)
        dsim = ImDist(imgs1, imgs2)
        H = get_full_hessian(dsim, mov_vect)  # 77sec to compute a Hessian.
        # ToPILImage()(imgs[0,:,:,:].cpu()).show()
        eigvals, eigvects = np.linalg.eigh(H)
        # Left panel: raw eigenvalues; right panel: log10 of the eigenvalues.
        plt.figure(figsize=[7, 5])
        plt.subplot(1, 2, 1)
        plt.plot(eigvals)
        plt.ylabel("eigenvalue")
        plt.subplot(1, 2, 2)
        # NOTE(review): log10 is applied to the raw eigenvalues; any zero or
        # negative eigenvalue would produce -inf / NaN here - confirm intended.
        plt.plot(np.log10(eigvals))
        plt.ylabel("eigenvalue (log)")
        plt.suptitle("Hessian Spectrum Full Space")
        # `trunc` is formatted with %d; every value that reaches this line is an int.
        plt.savefig(join(savedir, "Hessian_norm%d_%03d.jpg" % (trunc, RND)))
        np.savez(
            join(savedir, "Hess_norm%d_%03d.npz" % (trunc, RND)),
            H=H,
            eigvals=eigvals,