Example #1
0
import matplotlib.pylab as plt
import torch, numpy as np
from os.path import join
import os
from torchvision.utils import make_grid
from torchvision.transforms import ToPILImage
import matplotlib.pylab as plt
from torch_utils import show_imgrid, save_imgrid

saveroot = r"E:\OneDrive - Washington University in St. Louis\GAN_baddots"
#%% Progressive Growing GAN
# Render the generator output along the first 20 eigenvectors (columns of evc)
# of the averaged Hessian, each scaled to a typical latent-code norm.
modelsnm = "PGGAN"
savedir = join(saveroot, modelsnm)
os.makedirs(savedir, exist_ok=True)
PGAN = loadPGGAN()
G = PGGAN_wrapper(PGAN)
H, eva, evc = load_Haverage(modelsnm, descend=True)
# Typical latent norm: mean L2 norm over 10 sampled codes.
# Alternative kept for reference: distnorm = np.sqrt(randvec.shape[1])
randvec = G.sample_vector(10, device='cpu')
distnorm = float(randvec.norm(dim=1).mean())
#%
# Rows of evc[:, :20].T are the first 20 eigenvectors; scaling by distnorm puts
# each one at roughly the same radius as a sampled code.
baddot_imgs_pos = G.visualize_batch_np(distnorm * evc[:, :20].T)
show_imgrid(baddot_imgs_pos, nrow=5, padding=2)
save_imgrid(
    baddot_imgs_pos,
    join(savedir, "GAN_baddots_pos.png"),
    nrow=5,
Example #2
0
            eiglist=list(range(0,20)), maxdist=1E-3, rown=5, sphere=False, scaling=eva_avg[-1:-21:-1]**(-1/3))
#%%
# StyleGAN: explore around a fresh random 512-d reference code with sphere=True,
# maxdist=0.4, eigen-indices 0-19. `scaling` is the last 20 entries of eva_avg in
# reverse order (eva_avg[-1:-21:-1]) raised to -1/3 -- presumably the 20 largest
# eigenvalues if eva_avg is sorted ascending; verify against how eva_avg was saved.
refvec = np.random.randn(1, 512)
_=vis_eigen_explore(refvec, evc_avg, eva_avg, SG, figdir=figsumdir, namestr="StyleGAN_sphexpon", transpose=False,
            eiglist=list(range(0, 20)), maxdist=0.4, rown=5, sphere=True, scaling=eva_avg[-1:-21:-1]**(-1/3))
#%%
# Same exploration from a new random reference code, but with sphere=False and
# maxdist=8.
refvec = np.random.randn(1, 512)
_=vis_eigen_explore(refvec, evc_avg, eva_avg, SG, figdir=figsumdir, namestr="StyleGAN_expon", transpose=False,
            eiglist=list(range(0,20)), maxdist=8, rown=5, sphere=False, scaling=eva_avg[-1:-21:-1]**(-1/3))

#%% ProgGAN
"""Progressive Growing GAN! PG"""
modelnm = "PGGAN"
Hessdir = join(rootdir, modelnm)
PGAN = loadPGGAN()
PG = PGGAN_wrapper(PGAN)
# Load the averaged-Hessian arrays saved for this model (eva_avg / evc_avg --
# presumably eigenvalues / eigenvectors; confirm against the saving script).
# Indexing the NpzFile reads each array into memory, so both remain usable
# after the `with` block closes the file.
with np.load(join(Hessdir, "H_avg_%s.npz" % modelnm)) as data:
    eva_avg = data["eva_avg"]
    evc_avg = data["evc_avg"]
    # feat_col = data["feat_col"]
#%%
# Render one random 512-d code on the GPU. The PIL image below is not stored;
# it is only displayed when this is run as an interactive cell.
refvec = torch.randn(1, 512).cuda()
img = PG.visualize(refvec)
ToPILImage()(make_grid(img).cpu())
#%%
# Disabled call to vis_eigen_frame, kept for reference.
# vis_eigen_frame(evc_avg, eva_avg, PG, ref_code=refvec, figdir=figsumdir, namestr="PGGAN", transpose=False,
#                     eiglist=None, eig_rng=(0, 20), maxdist=1, rown=5, sphere=False, )
#%% Linear exploration from origin
# Linear (sphere=False) exploration starting at the all-zero 512-d code over
# eigen-indices 0-19, maxdist=1E-4, scaling = eva_avg[-1:-21:-1]**(-1/2.5).
_=vis_eigen_explore(np.zeros((1, 512)), evc_avg, eva_avg, PG, figdir=figsumdir, namestr="PGGAN_orig_expon", transpose=False,
            eiglist=list(range(0,20)), maxdist=1E-4, rown=5, sphere=False, scaling=eva_avg[-1:-21:-1]**(-1/2.5))
#%% Linear exploration from origin
# Same exploration with maxdist=.5E-4 and scaling exponent -1/3.
# NOTE(review): this reuses namestr="PGGAN_orig_expon" from the cell above. If
# vis_eigen_explore names its saved figures only by namestr, running both cells
# overwrites the first cell's output -- confirm, and give this cell its own
# namestr if both results should be kept.
_=vis_eigen_explore(np.zeros((1, 512)), evc_avg, eva_avg, PG, figdir=figsumdir, namestr="PGGAN_orig_expon", transpose=False,
            eiglist=list(range(0,20)), maxdist=.5E-4, rown=5, sphere=False, scaling=eva_avg[-1:-21:-1]**(-1/3))
Example #3
0
                                "reg_log_slop", "reg_log_intcp", "H_cc",
                                "logH_cc"
                            ])
# Write the MSE and SSIM statistics tables (one row per collected entry,
# columns as listed) to CSV in savedir.
MSE_stat_tab.to_csv(join(savedir, "H_cmp_MSE_stat.csv"))
SSIM_stat_tab = pd.DataFrame(SSIM_stat_col,
                             columns=[
                                 "id", "cc", "logcc", "reg_slop", "reg_intcp",
                                 "reg_log_slop", "reg_log_intcp", "H_cc",
                                 "logH_cc"
                             ])
SSIM_stat_tab.to_csv(join(savedir, "H_cmp_SSIM_stat.csv"))
# Drop the current generator objects and release cached CUDA memory before
# the next model is loaded.
del G, BGAN
torch.cuda.empty_cache()
#%%
# Set up PGGAN and the output directory for the image-distance Hessian
# comparison that follows.
PGGAN = loadPGGAN()
G = PGGAN_wrapper(PGGAN)
# Pass path components separately instead of embedding a "\\" separator: the
# result is identical on Windows, and on POSIX it creates a real
# ImDist_cmp/PGGAN sub-directory instead of one file name containing a backslash.
savedir = join(saveroot, "ImDist_cmp", "PGGAN")
os.makedirs(savedir, exist_ok=True)
# Per-trial statistics rows, filled by the loop that follows.
SSIM_stat_col = []
MSE_stat_col = []
for idx in range(100):
    refvec = G.sample_vector(1, device="cuda")  # 0.1 * torch.randn(1, 256)
    eigvals_PS, eigvects_PS, H_PS = hessian_compute(G,
                                                    refvec,
                                                    ImDist,
                                                    hessian_method="BP")
    eigvals_SSIM, eigvects_SSIM, H_SSIM = hessian_compute(G,
                                                          refvec,
                                                          D,
                                                          hessian_method="BP")
    eigvals_MSE, eigvects_MSE, H_MSE = hessian_compute(G,
Example #4
0
#%%
# Load PGGAN through the project helper and take its state_dict: a mapping
# from parameter/buffer name to tensor, used for the shuffled-weight control below.
PGGAN = loadPGGAN()
SD = PGGAN.state_dict()
#%%
# Build a control state dict. Every tensor keeps its shape, but its elements are
# randomly permuted: the set of values in each tensor is preserved, while their
# positional arrangement is destroyed.
shuffled_SD = {}
for name, Weight in SD.items():
    idx = torch.randperm(Weight.numel())
    # reshape instead of view: view(-1) raises on non-contiguous tensors, while
    # reshape is identical (a view) for contiguous ones.
    W_shuf = Weight.reshape(-1)[idx].reshape(Weight.shape)
    shuffled_SD[name] = W_shuf
    # print(name, Weight.shape, Weight.mean().item(), Weight.std().item())
#%%
torch.save(shuffled_SD, join(datadir, "PGGAN_shuffle.pt"))
#%%
# Load the shuffled weights from PGGAN_shuffle.pt into a freshly built PGGAN
# and wrap it.
PGGAN_sf = loadPGGAN()
PGGAN_sf.load_state_dict(torch.load(join(datadir, "PGGAN_shuffle.pt")))
G_sf = PGGAN_wrapper(PGGAN_sf)
#%%
# Render one random 512-d code, bring the batch to the CPU once, and show the
# first image of the batch.
img = G_sf.visualize(torch.randn(1, 512).cuda()).cpu()
ToPILImage()(img[0]).show()
#%%
def Hess_hook(module, fea_in, fea_out):
    """Record a zero-valued squared-distance term for a module's output.

    Prints the hooked module's class, then appends
    ``sum((fea_out - fea_out.detach())**2)`` to the module-level list
    ``L2dist_col``. That list must exist before the hook fires.

    The value is always zero. Only the reference copy is detached, so the term
    stays attached to ``fea_out``'s autograd graph, presumably so second-order
    derivatives can be taken through it. ``fea_in`` is unused. The signature
    matches ``register_forward_hook``; returning None leaves the output unchanged.
    """
    print("hooker on %s" % module.__class__)
    frozen = fea_out.detach().clone().requires_grad_(False)
    sq_dist = (fea_out - frozen).pow(2).sum()
    L2dist_col.append(sq_dist)

# Output directory for the control Hessians (presumably those of the
# shuffled-weight PGGAN built above -- confirm). This is a hard-coded Windows
# path; adjust it per machine.
savedir = r"E:\OneDrive - Washington University in St. Louis\HessNetArchit\PGGAN\ctrl_Hessians"
os.makedirs(savedir, exist_ok=True)
for triali in tqdm(range(0, 100)):
Example #5
0
import torch, numpy as np
from os.path import join
from torchvision.utils import make_grid
from torchvision.transforms import ToPILImage
import matplotlib.pylab as plt
from sklearn.decomposition import TruncatedSVD
from scipy.sparse.linalg import svds, eigs
from sklearn.linear_model import LinearRegression
from scipy.spatial.distance import cdist, pdist
from load_hessian_data import load_Haverage
"""The major challenge is memory"""
# Root for the image-space PCA figures (hard-coded Windows path).
figroot = r"E:\OneDrive - Washington University in St. Louis\GAN_PCA"
#%% PGGAN
modelsnm = "PGGAN"
PGAN = loadPGGAN()
G = PGGAN_wrapper(PGAN)
# NOTE(review): savedir is not created in this visible part of the script;
# confirm it exists before anything is saved to it.
savedir = join(figroot, "PGGAN")
#%%
# Render vecn random 512-d codes and take a 30-component truncated SVD of the
# (vecn x pixels) image matrix. With the (3, 256, 256) image shape used below,
# that is 196608 values per image: ~3.1 GB for vecn=4000 if stored as float32,
# hence the memory concern above.
vecn = 4000
codes = np.random.randn(vecn, 512)
imgs = G.visualize_batch_np(codes)
#% Compute sparse SVD
# assumes imgs is a contiguous CPU torch tensor (.view / .numpy) -- TODO confirm.
# NOTE(review): the loop below treats VH[-1] as the leading component, i.e. it
# assumes svds returns singular values in ascending order -- verify for the
# installed SciPy version.
U, S, VH = svds(imgs.view(vecn, -1).numpy(), k=30)
#%%
for PCi in range(30):
    img_mean = VH[-1, :].reshape((3, 256, 256)).transpose((1,2,0)) * np.sign(VH[-1, :].mean()) # sign may be inverted
    img_dev = VH[-PCi, :].reshape((3, 256, 256)).transpose((1,2,0))
    plt.figure(figsize=[7,3])
    plt.subplot(1, 3, 1)
    plt.imshow((img_mean-0.4*img_dev)*255)
    plt.axis(False)