def __init__(self, config):
    super().__init__()
    self.config = config
    if config['image_size'][0] != config['image_size'][1]:
        raise Exception('Non-square images are not supported yet.')
    device = config['device']

    self.downsampler_1024_256 = BicubicDownSample(4)
    self.downsampler_1024_image = BicubicDownSample(1024 // config['image_size'][0])
    self.downsampler_image_256 = BicubicDownSample(config['image_size'][0] // 256)

    # Load models and pre-trained weights
    gen = Generator(1024, 512, 8)
    gen.load_state_dict(torch.load(config["ckpt"])["g_ema"], strict=False)
    gen.eval()
    self.gen = gen.to(device)
    self.gen.start_layer = config['start_layer']
    self.gen.end_layer = config['end_layer']

    self.mpl = MappingProxy(torch.load('gaussian_fit.pt'))
    self.percept = lpips.PerceptualLoss(model="net-lin", net="vgg",
                                        use_gpu=device.startswith("cuda"))
    self.init_state()
def __init__(self, args):
    self.args = args
    self.init_vars()
    self.percept = lpips.PerceptualLoss(
        model='net-lin', net='vgg',
        use_gpu=self.args.device.startswith('cuda')
    )
def __init__(self, cuda, des="Learned Perceptual Image Patch Similarity", version="0.1"):
    self.des = des
    self.version = version
    self.model = lpips.PerceptualLoss(model='net-lin', net='alex', use_gpu=cuda)
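# --- Hedged usage sketch (added for illustration; not part of the snippets
# above). These snippets all use the legacy richzhang/PerceptualSimilarity
# API, where lpips.PerceptualLoss(model='net-lin', net=...) scores image
# batches of shape (N, 3, H, W) in [-1, 1] and returns one distance per pair.
# The random tensors below are placeholders.
import torch
import lpips

percept = lpips.PerceptualLoss(model='net-lin', net='alex', use_gpu=False)  # CPU for the toy example
img0 = torch.rand(1, 3, 64, 64) * 2 - 1
img1 = torch.rand(1, 3, 64, 64) * 2 - 1
dist = percept(img0, img1)  # one distance per image pair
print(dist.item())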
def __init__(self, config):
    super().__init__()
    self.config = config
    if config['image_size'][0] != config['image_size'][1]:
        raise Exception('Non-square images are not supported yet.')
    self.reconstruction = config["reconstruction"]
    self.steps = config["steps"]
    self.lr = config['lr']
    # self.mse_record = []

    transform = get_transformation(config['image_size'][0])
    images = []
    for imgfile in config['input_files']:
        images.append(transform(Image.open(imgfile).convert("RGB")))
    self.images = torch.stack(images, 0).to(config['device'])

    self.downsampler_256_image = BicubicDownSample(256 // config['image_size'][0])

    biggan = BigGAN(config)
    biggan.load_pretrained()
    self.generator = biggan.generator
    self.generator.eval()
    self.generator.to(config['device'])

    (self.z, y) = prepare_z_y(self.images.shape[0],
                              self.generator.dim_z,
                              config["n_classes"],
                              device=config["device"],
                              fp16=config["G_fp16"],
                              z_var=config["z_var"],
                              target=config["target"],
                              range=config["range"])
    self.y = self.generator.shared(y)
    self.z.requires_grad = True

    self.perceptual_loss = lpips.PerceptualLoss(
        model="net-lin", net="vgg",
        use_gpu=config['device'].startswith("cuda"))
import sys
from os.path import join
from time import time

import matplotlib.pylab as plt
from matplotlib import cm
import torch
import numpy as np

sys.path.append(r"E:\Github_Projects\Visual_Neuro_InSilico_Exp")
sys.path.append(r"D:\Github\Visual_Neuro_InSilico_Exp")
import os
# os.system(r"'C:\Program Files (x86)\Microsoft Visual Studio\2019\Professional\VC\Auxiliary\Build\vcvars64.bat'")
import lpips
try:
    ImDist = lpips.LPIPS(net="squeeze").cuda()  # newer lpips API
except AttributeError:
    ImDist = lpips.PerceptualLoss(net="squeeze").cuda()  # legacy API fallback
from GAN_hessian_compute import hessian_compute, get_full_hessian
from hessian_analysis_tools import compute_vector_hess_corr, compute_hess_corr, plot_layer_consistency_mat
from GAN_utils import loadStyleGAN2, StyleGAN2_wrapper
#%%
modelname = "ffhq-256-config-e-003810"  # 109 sec
SGAN = loadStyleGAN2(modelname + ".pt", size=256, channel_multiplier=1)  # 491 sec per BP
#%%
L2dist_col = []

def Hess_hook(module, fea_in, fea_out):
    print("hooker on %s" % module.__class__)
    ref_feat = fea_out.detach().clone()
    ref_feat.requires_grad_(False)
    L2dist = torch.pow(fea_out - ref_feat, 2).sum()
    L2dist_col.append(L2dist)
    return None
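# --- Hedged sketch (added): how a forward hook like Hess_hook is typically
# attached before a forward pass and removed afterwards. The Conv2d below is
# a stand-in; which StyleGAN2 layer the original registered the hook on is
# not shown in this excerpt.
import torch
import torch.nn as nn

collected = []

def hook(module, fea_in, fea_out):
    ref = fea_out.detach().clone()
    collected.append(torch.pow(fea_out - ref, 2).sum())

layer = nn.Conv2d(3, 8, 3)
handle = layer.register_forward_hook(hook)
_ = layer(torch.randn(1, 3, 16, 16))  # forward pass triggers the hook
handle.remove()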
        return image

    def saveImage(self, image, name):
        #### save image
        utils.save_image(
            image,
            name + ".png",
            nrow=int(1 ** 0.5),
            normalize=True,
            range=(-1, 1),
        )
        #### save image ends

if __name__ == "__main__":
    loss_fn_vgg = lpips.PerceptualLoss(net='vgg')
    with torch.no_grad():
        ckpt = "/common/users/sm2322/MS-Thesis/GAN-Thesis-Work-Remote/styleGAN2-AE-Ligong-Remote/trainedPts/gan/168000.pt"
        ckpt = torch.load(ckpt, map_location=lambda storage, loc: storage)
        device = "cuda"
        generator = Generator(128, 512, 8, 2).to(device)
        generator.load_state_dict(ckpt["g_ema"])
        generator.eval()
        # print(list(generator.children()))

        z0, z1 = torch.randn([1 * 2, 512], device=device).chunk(2)
        # sample_z = torch.randn(1, 512, device=device)
        genBlocks = GeneratorBlocks(generator)

        w0, stackW0 = genBlocks.G_Mapper(z0)
        w1, stackW1 = genBlocks.G_Mapper(z1)
        args.start_iter = int(ckpt_name.split('.')[0].split('-')[-1])
    except ValueError:
        pass

    generator.load_state_dict(ckpt['g'])
    discriminator.load_my_state_dict(ckpt['d'])
    g_ema.load_state_dict(ckpt['g_ema'])
    del ckpt
    torch.cuda.empty_cache()

percept = None
if args.with_gan_feat_loss:
    import lpips
    percept = lpips.PerceptualLoss(model='net-lin', net='vgg')

if args.distributed:
    generator = nn.parallel.DistributedDataParallel(
        generator,
        device_ids=[args.local_rank],
        output_device=args.local_rank,
        broadcast_buffers=False,
        # find_unused_parameters=True,
    )
    discriminator = nn.parallel.DistributedDataParallel(
        discriminator,
        device_ids=[args.local_rank],
        output_device=args.local_rank,
        broadcast_buffers=False,
def train():
    from benchmark import calc_fid, extract_feature_from_generator_fn, load_patched_inception_v3, real_image_loader, image_generator, image_generator_perm
    import lpips

    from config import IM_SIZE_GAN, BATCH_SIZE_GAN, NFC, NBR_CLS, DATALOADER_WORKERS, EPOCH_GAN, ITERATION_AE, GAN_CKECKPOINT
    from config import SAVE_IMAGE_INTERVAL, SAVE_MODEL_INTERVAL, LOG_INTERVAL, SAVE_FOLDER, TRIAL_NAME, DATA_NAME, MULTI_GPU
    from config import FID_INTERVAL, FID_BATCH_NBR, PRETRAINED_AE_PATH
    from config import data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3

    real_features = None
    inception = load_patched_inception_v3().cuda()
    inception.eval()

    percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

    saved_image_folder = saved_model_folder = None
    log_file_path = None
    if saved_image_folder is None:
        saved_image_folder, saved_model_folder = make_folders(SAVE_FOLDER, 'GAN_' + TRIAL_NAME)
        log_file_path = saved_image_folder + '/../gan_log.txt'
        log_file = open(log_file_path, 'w')
        log_file.close()

    dataset = PairedMultiDataset(data_root_colorful,
                                 data_root_sketch_1,
                                 data_root_sketch_2,
                                 data_root_sketch_3,
                                 im_size=IM_SIZE_GAN,
                                 rand_crop=True)
    print('the dataset contains %d images.' % len(dataset))
    dataloader = iter(
        DataLoader(dataset,
                   BATCH_SIZE_GAN,
                   sampler=InfiniteSamplerWrapper(dataset),
                   num_workers=DATALOADER_WORKERS,
                   pin_memory=True))

    from datasets import ImageFolder
    from datasets import trans_maker_augment as trans_maker
    dataset_rgb = ImageFolder(data_root_colorful, trans_maker(512))
    dataset_skt = ImageFolder(data_root_sketch_3, trans_maker(512))

    net_ae = AE(nfc=NFC, nbr_cls=NBR_CLS)
    if PRETRAINED_AE_PATH is None:
        PRETRAINED_AE_PATH = 'train_results/' + 'AE_' + TRIAL_NAME + '/models/%d.pth' % ITERATION_AE
    else:
        from config import PRETRAINED_AE_ITER
        PRETRAINED_AE_PATH = PRETRAINED_AE_PATH + '/models/%d.pth' % PRETRAINED_AE_ITER
    net_ae.load_state_dicts(PRETRAINED_AE_PATH)
    net_ae.cuda()
    net_ae.eval()

    RefineGenerator = None
    if DATA_NAME == 'celeba':
        from models import RefineGenerator_face as RefineGenerator
    elif DATA_NAME == 'art' or DATA_NAME == 'shoe':
        from models import RefineGenerator_art as RefineGenerator
    net_ig = RefineGenerator(nfc=NFC, im_size=IM_SIZE_GAN).cuda()
    net_id = Discriminator(nc=3).cuda()
    # we use the patch_gan, so the im_size for D should be 512 even if training image size is 1024

    if MULTI_GPU:
        net_ae = nn.DataParallel(net_ae)
        net_ig = nn.DataParallel(net_ig)
        net_id = nn.DataParallel(net_id)

    net_ig_ema = copy_G_params(net_ig)

    opt_ig = optim.Adam(net_ig.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_id = optim.Adam(net_id.parameters(), lr=2e-4, betas=(0.5, 0.999))

    if GAN_CKECKPOINT is not None:
        ckpt = torch.load(GAN_CKECKPOINT)
        net_ig.load_state_dict(ckpt['ig'])
        net_id.load_state_dict(ckpt['id'])
        net_ig_ema = ckpt['ig_ema']
        opt_ig.load_state_dict(ckpt['opt_ig'])
        opt_id.load_state_dict(ckpt['opt_id'])

    ## create a log file
    losses_g_img = AverageMeter()
    losses_d_img = AverageMeter()
    losses_mse = AverageMeter()
    losses_rec_s = AverageMeter()
    losses_rec_ae = AverageMeter()

    fixed_skt = fixed_rgb = fixed_perm = None
    fid = [[0, 0]]

    for epoch in range(EPOCH_GAN):
        for iteration in tqdm(range(10000)):
            rgb_img, skt_img_1, skt_img_2, skt_img_3 = next(dataloader)

            rgb_img = rgb_img.cuda()
            rd = random.randint(0, 3)
            if rd == 0:
                skt_img = skt_img_1.cuda()
            elif rd == 1:
                skt_img = skt_img_2.cuda()
            else:
                skt_img = skt_img_3.cuda()

            if iteration == 0:
                fixed_skt = skt_img_3[:8].clone().cuda()
                fixed_rgb = rgb_img[:8].clone()
                fixed_perm = true_randperm(fixed_rgb.shape[0], 'cuda')

            ### 1. train D
            gimg_ae, style_feats = net_ae(skt_img, rgb_img)
            g_image = net_ig(gimg_ae, style_feats)

            pred_r = net_id(rgb_img)
            pred_f = net_id(g_image.detach())

            loss_d = d_hinge_loss(pred_r, pred_f)

            net_id.zero_grad()
            loss_d.backward()
            opt_id.step()

            loss_rec_ae = F.mse_loss(gimg_ae, rgb_img) + F.l1_loss(gimg_ae, rgb_img)
            losses_rec_ae.update(loss_rec_ae.item(), BATCH_SIZE_GAN)

            ### 2. train G
            pred_g = net_id(g_image)
            loss_g = g_hinge_loss(pred_g)

            if DATA_NAME == 'shoe':
                loss_mse = 10 * (F.l1_loss(g_image, rgb_img) + F.mse_loss(g_image, rgb_img))
            else:
                loss_mse = 10 * percept(
                    F.adaptive_avg_pool2d(g_image, output_size=256),
                    F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum()
            losses_mse.update(loss_mse.item() / BATCH_SIZE_GAN, BATCH_SIZE_GAN)

            loss_all = loss_g + loss_mse

            if DATA_NAME == 'shoe':
                ### the grey image reconstruction
                perm = true_randperm(BATCH_SIZE_GAN)
                img_ae_perm, style_feats_perm = net_ae(skt_img, rgb_img[perm])

                gimg_grey = net_ig(img_ae_perm, style_feats_perm)
                gimg_grey = gimg_grey.mean(dim=1, keepdim=True)
                real_grey = rgb_img.mean(dim=1, keepdim=True)
                loss_rec_grey = F.mse_loss(gimg_grey, real_grey)
                loss_all += 10 * loss_rec_grey

            net_ig.zero_grad()
            loss_all.backward()
            opt_ig.step()

            for p, avg_p in zip(net_ig.parameters(), net_ig_ema):
                avg_p.mul_(0.999).add_(p.data, alpha=0.001)

            ### 3. logging
            losses_g_img.update(pred_g.mean().item(), BATCH_SIZE_GAN)
            losses_d_img.update(pred_r.mean().item(), BATCH_SIZE_GAN)

            if iteration % SAVE_IMAGE_INTERVAL == 0:
                # show the current images
                with torch.no_grad():
                    backup_para_g = copy_G_params(net_ig)
                    load_params(net_ig, net_ig_ema)

                    gimg_ae, style_feats = net_ae(fixed_skt, fixed_rgb)
                    gmatch = net_ig(gimg_ae, style_feats)
                    gimg_ae_perm, style_feats = net_ae(fixed_skt, fixed_rgb[fixed_perm])
                    gmismatch = net_ig(gimg_ae_perm, style_feats)

                    gimg = torch.cat([
                        F.interpolate(fixed_rgb, IM_SIZE_GAN),
                        F.interpolate(fixed_skt.repeat(1, 3, 1, 1), IM_SIZE_GAN),
                        gmatch,
                        F.interpolate(gimg_ae, IM_SIZE_GAN),
                        gmismatch,
                        F.interpolate(gimg_ae_perm, IM_SIZE_GAN)
                    ])

                    vutils.save_image(
                        gimg,
                        f'{saved_image_folder}/img_iter_{epoch}_{iteration}.jpg',
                        normalize=True,
                        range=(-1, 1))
                    del gimg

                    make_matrix(
                        dataset_rgb, dataset_skt, net_ae, net_ig, 5,
                        f'{saved_image_folder}/img_iter_{epoch}_{iteration}_matrix.jpg')

                    load_params(net_ig, backup_para_g)

            if iteration % LOG_INTERVAL == 0:
                log_msg = 'Iter: [{0}/{1}] G: {losses_g_img.avg:.4f} D: {losses_d_img.avg:.4f} MSE: {losses_mse.avg:.4f} Rec: {losses_rec_s.avg:.5f} FID: {fid:.4f}'.format(
                    epoch,
                    iteration,
                    losses_g_img=losses_g_img,
                    losses_d_img=losses_d_img,
                    losses_mse=losses_mse,
                    losses_rec_s=losses_rec_s,
                    fid=fid[-1][0])

                print(log_msg)
                print('%.5f' % (losses_rec_ae.avg))

                if log_file_path is not None:
                    log_file = open(log_file_path, 'a')
                    log_file.write(log_msg + '\n')
                    log_file.close()

                losses_g_img.reset()
                losses_d_img.reset()
                losses_mse.reset()
                losses_rec_s.reset()
                losses_rec_ae.reset()

            if iteration % SAVE_MODEL_INTERVAL == 0 or iteration + 1 == 10000:
                print('Saving history model')
                torch.save(
                    {
                        'ig': net_ig.state_dict(),
                        'id': net_id.state_dict(),
                        'ae': net_ae.state_dict(),
                        'ig_ema': net_ig_ema,
                        'opt_ig': opt_ig.state_dict(),
                        'opt_id': opt_id.state_dict(),
                    }, '%s/%d.pth' % (saved_model_folder, epoch))

            if iteration % FID_INTERVAL == 0 and iteration > 1:
                print("calculating FID ...")
                fid_batch_images = FID_BATCH_NBR
                if real_features is None:
                    if os.path.exists('%s_fid_feats.npy' % (DATA_NAME)):
                        real_features = pickle.load(open('%s_fid_feats.npy' % (DATA_NAME), 'rb'))
                    else:
                        real_features = extract_feature_from_generator_fn(
                            real_image_loader(dataloader, n_batches=fid_batch_images),
                            inception)
                        real_mean = np.mean(real_features, 0)
                        real_cov = np.cov(real_features, rowvar=False)
                        pickle.dump(
                            {
                                'feats': real_features,
                                'mean': real_mean,
                                'cov': real_cov
                            }, open('%s_fid_feats.npy' % (DATA_NAME), 'wb'))
                        real_features = pickle.load(open('%s_fid_feats.npy' % (DATA_NAME), 'rb'))

                sample_features = extract_feature_from_generator_fn(
                    image_generator(dataset, net_ae, net_ig, n_batches=fid_batch_images),
                    inception,
                    total=fid_batch_images)
                cur_fid = calc_fid(sample_features,
                                   real_mean=real_features['mean'],
                                   real_cov=real_features['cov'])
                sample_features_perm = extract_feature_from_generator_fn(
                    image_generator_perm(dataset, net_ae, net_ig, n_batches=fid_batch_images),
                    inception,
                    total=fid_batch_images)
                cur_fid_perm = calc_fid(sample_features_perm,
                                        real_mean=real_features['mean'],
                                        real_cov=real_features['cov'])

                fid.append([cur_fid, cur_fid_perm])
                print('fid:', fid)
                if log_file_path is not None:
                    log_file = open(log_file_path, 'a')
                    log_msg = 'fid: %.5f, %.5f' % (fid[-1][0], fid[-1][1])
                    log_file.write(log_msg + '\n')
                    log_file.close()
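# --- Hedged reference (added): calc_fid above lives in benchmark.py, which is
# not shown here. The standard Fréchet distance it computes between real and
# generated feature statistics looks like this sketch:
import numpy as np
from scipy import linalg

def frechet_distance(mu1, cov1, mu2, cov2):
    # FID = ||mu1 - mu2||^2 + Tr(cov1 + cov2 - 2 * (cov1 @ cov2)^(1/2))
    covmean, _ = linalg.sqrtm(cov1 @ cov2, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard numerical imaginary residue
    diff = mu1 - mu2
    return diff @ diff + np.trace(cov1 + cov2 - 2 * covmean)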
def lerp(a, b, t):
    return a + (b - a) * t

latent_dim = 512

ckpt = torch.load(args.ckpt)

g = Generator(args.size, latent_dim, 8).to(device)
g.load_state_dict(ckpt['g_ema'])
g.eval()

percept = lpips.PerceptualLoss(model='net-lin', net='vgg',
                               use_gpu=device.startswith('cuda'))

distances = []

n_batch = args.n_sample // args.batch
resid = args.n_sample - (n_batch * args.batch)
batch_sizes = [args.batch] * n_batch + [resid]

with torch.no_grad():
    for batch in tqdm(batch_sizes):
        noise = g.make_noise()

        inputs = torch.randn([batch * 2, latent_dim], device=device)
        lerp_t = torch.rand(batch, device=device)

        if args.space == 'w':
            latent = g.get_latent(inputs)
            latent_t0, latent_t1 = latent[::2], latent[1::2]
            latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None])
            latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)
            latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape)

        image, _ = g([latent_e], input_is_latent=True, noise=noise)

        if args.crop:
            c = image.shape[2] // 8
            image = image[:, :, c * 3:c * 7, c * 2:c * 6]

        factor = image.shape[2] // 256

        if factor > 1:
            image = F.interpolate(image,
                                  size=(256, 256),
                                  mode='bilinear',
                                  align_corners=False)

        dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / (args.eps ** 2)
        distances.append(dist.to('cpu').numpy())

distances = np.concatenate(distances, 0)

lo = np.percentile(distances, 1, interpolation='lower')
hi = np.percentile(distances, 99, interpolation='higher')
filtered_dist = np.extract(
    np.logical_and(lo <= distances, distances <= hi), distances)

print('ppl:', filtered_dist.mean())
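# --- Hedged note (added): 'dist' above is StyleGAN's perceptual path length,
# the LPIPS distance between images generated from interpolation endpoints
# eps apart, divided by eps**2. Using the lerp defined above, one sample is
# constructed like this:
eps = 1e-4
z0, z1 = torch.randn(512), torch.randn(512)
t = float(torch.rand(()))
e0 = lerp(z0, z1, t)
e1 = lerp(z0, z1, t + eps)  # one PPL sample = lpips(G(e0), G(e1)) / eps**2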
elif args.gan == "wgan-gp":
    discriminator = ResidualDiscriminator(1, dim=128).to(device)
    discriminator.eval()
    discriminator.load_state_dict(torch.load(args.ckpt)['d'])
    g_ema = ResidualGenerator(512, dim=128).to(device)
    g_ema.load_state_dict(torch.load(args.ckpt)['g_ema'])
    g_ema.eval()
    # encoder = Encoder(1, dim=128).to(device)
    # encoder.load_state_dict(torch.load(args.ckptE))
    # encoder.eval()

percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

def get_npy_idx(patient_n):
    # print(test_metadata[test_metadata["patient_n"].isin(patient_n)])
    return np.arange(95)[test_metadata["patient_n"].isin(patient_n)]

def get_lr_features(
    losses,
    query_images,
    anomaly_score=False,
):
    latents = losses["latents"]
import random
import argparse
from tqdm import tqdm
import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import torch
from models import weights_init, Discriminator, Generator
from operation import copy_G_params, load_params, get_dir
from operation import ImageFolder, InfiniteSamplerWrapper
from diffaug import DiffAugment

policy = 'color,translation'
import lpips
percept = lpips.PerceptualLoss(model='net-lin', net='vgg',
                               use_gpu=torch.cuda.is_available())

# torch.backends.cudnn.benchmark = True

def crop_image_by_part(image, part):
    hw = image.shape[2] // 2
    if part == 0:
        return image[:, :, :hw, :hw]
    if part == 1:
        return image[:, :, :hw, hw:]
    if part == 2:
        return image[:, :, hw:, :hw]
    if part == 3:
        return image[:, :, hw:, hw:]
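# --- Hedged illustration (added): crop_image_by_part above maps parts 0-3 to
# the top-left, top-right, bottom-left, and bottom-right quadrants of a
# square image batch. A quick shape check with a placeholder tensor:
img = torch.randn(2, 3, 8, 8)
for part in range(4):
    assert crop_image_by_part(img, part).shape == (2, 3, 4, 4)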
def train():
    from config import IM_SIZE_AE, BATCH_SIZE_AE, NFC, NBR_CLS, DATALOADER_WORKERS, ITERATION_AE
    from config import SAVE_IMAGE_INTERVAL, SAVE_MODEL_INTERVAL, SAVE_FOLDER, TRIAL_NAME, LOG_INTERVAL
    from config import DATA_NAME
    from config import data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3

    dataset = PairedMultiDataset(data_root_colorful,
                                 data_root_sketch_1,
                                 data_root_sketch_2,
                                 data_root_sketch_3,
                                 im_size=IM_SIZE_AE,
                                 rand_crop=True)
    print(len(dataset))
    dataloader = iter(DataLoader(dataset, BATCH_SIZE_AE,
                                 sampler=InfiniteSamplerWrapper(dataset),
                                 num_workers=DATALOADER_WORKERS,
                                 pin_memory=True))

    dataset_ss = SelfSupervisedDataset(data_root_colorful,
                                       data_root_sketch_3,
                                       im_size=IM_SIZE_AE,
                                       nbr_cls=NBR_CLS,
                                       rand_crop=True)
    print(len(dataset_ss), len(dataset_ss.frame))
    dataloader_ss = iter(DataLoader(dataset_ss, BATCH_SIZE_AE,
                                    sampler=InfiniteSamplerWrapper(dataset_ss),
                                    num_workers=DATALOADER_WORKERS,
                                    pin_memory=True))

    style_encoder = StyleEncoder(nfc=NFC, nbr_cls=NBR_CLS).cuda()
    content_encoder = ContentEncoder(nfc=NFC).cuda()
    decoder = Decoder(nfc=NFC).cuda()

    opt_c = optim.Adam(content_encoder.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_s = optim.Adam(style_encoder.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_d = optim.Adam(decoder.parameters(), lr=2e-4, betas=(0.5, 0.999))

    style_encoder.reset_cls()
    style_encoder.final_cls.cuda()

    from config import PRETRAINED_AE_PATH, PRETRAINED_AE_ITER
    if PRETRAINED_AE_PATH is not None:
        PRETRAINED_AE_PATH = PRETRAINED_AE_PATH + '/models/%d.pth' % PRETRAINED_AE_ITER

        ckpt = torch.load(PRETRAINED_AE_PATH)
        print(PRETRAINED_AE_PATH)
        style_encoder.load_state_dict(ckpt['s'])
        content_encoder.load_state_dict(ckpt['c'])
        decoder.load_state_dict(ckpt['d'])
        opt_c.load_state_dict(ckpt['opt_c'])
        opt_s.load_state_dict(ckpt['opt_s'])
        opt_d.load_state_dict(ckpt['opt_d'])
        print('loaded pre-trained AE')

    style_encoder.reset_cls()
    style_encoder.final_cls.cuda()
    opt_s_cls = optim.Adam(style_encoder.final_cls.parameters(), lr=2e-4, betas=(0.5, 0.999))

    saved_image_folder, saved_model_folder = make_folders(SAVE_FOLDER, 'AE_' + TRIAL_NAME)
    log_file_path = saved_image_folder + '/../ae_log.txt'
    log_file = open(log_file_path, 'w')
    log_file.close()

    ## for logging
    losses_sf_consist = AverageMeter()
    losses_cf_consist = AverageMeter()
    losses_cls = AverageMeter()
    losses_rec_rd = AverageMeter()
    losses_rec_org = AverageMeter()
    losses_rec_grey = AverageMeter()

    import lpips
    percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

    for iteration in tqdm(range(ITERATION_AE)):
        if iteration % ((NBR_CLS * 100) // BATCH_SIZE_AE) == 0 and iteration > 1:
            dataset_ss._next_set()
            dataloader_ss = iter(
                DataLoader(dataset_ss,
                           BATCH_SIZE_AE,
                           sampler=InfiniteSamplerWrapper(dataset_ss),
                           num_workers=DATALOADER_WORKERS,
                           pin_memory=True))
            style_encoder.reset_cls()
            opt_s_cls = optim.Adam(style_encoder.final_cls.parameters(), lr=2e-4, betas=(0.5, 0.999))
            opt_s.param_groups[0]['lr'] = 1e-4
            opt_d.param_groups[0]['lr'] = 1e-4

        ### 1. train the encoder with self-supervision methods
        rgb_img_rd, rgb_img_org, skt_org, skt_bold, skt_erased, skt_erased_bold, img_idx = next(dataloader_ss)

        rgb_img_rd = rgb_img_rd.cuda()
        rgb_img_org = rgb_img_org.cuda()
        img_idx = img_idx.cuda()

        skt_org = F.interpolate(skt_org, size=512).cuda()
        skt_bold = F.interpolate(skt_bold, size=512).cuda()
        skt_erased = F.interpolate(skt_erased, size=512).cuda()
        skt_erased_bold = F.interpolate(skt_erased_bold, size=512).cuda()

        style_encoder.zero_grad()
        decoder.zero_grad()
        content_encoder.zero_grad()

        style_vector_rd, pred_cls_rd = style_encoder(rgb_img_rd)
        style_vector_org, pred_cls_org = style_encoder(rgb_img_org)

        content_feats = content_encoder(skt_org)
        content_feats_bold = content_encoder(skt_bold)
        content_feats_erased = content_encoder(skt_erased)
        content_feats_eb = content_encoder(skt_erased_bold)

        rd = random.randint(0, 3)
        gimg_rd = None
        if rd == 0:
            gimg_rd = decoder(content_feats, style_vector_rd)
        elif rd == 1:
            gimg_rd = decoder(content_feats_bold, style_vector_rd)
        elif rd == 2:
            gimg_rd = decoder(content_feats_erased, style_vector_rd)
        elif rd == 3:
            gimg_rd = decoder(content_feats_eb, style_vector_rd)

        loss_cf_consist = loss_for_list_perm(F.mse_loss, content_feats_bold, content_feats) + \
            loss_for_list_perm(F.mse_loss, content_feats_erased, content_feats) + \
            loss_for_list_perm(F.mse_loss, content_feats_eb, content_feats)

        loss_sf_consist = 0
        for loss_idx in range(3):
            loss_sf_consist += -F.cosine_similarity(style_vector_rd[loss_idx], style_vector_org[loss_idx].detach()).mean() + \
                F.cosine_similarity(style_vector_rd[loss_idx], style_vector_org[loss_idx][torch.randperm(BATCH_SIZE_AE)].detach()).mean()

        loss_cls = F.cross_entropy(pred_cls_rd, img_idx) + F.cross_entropy(pred_cls_org, img_idx)

        loss_rec_rd = F.mse_loss(gimg_rd, rgb_img_org)
        if DATA_NAME != 'shoe':
            loss_rec_rd += percept(
                F.adaptive_avg_pool2d(gimg_rd, output_size=256),
                F.adaptive_avg_pool2d(rgb_img_org, output_size=256)).sum()
        else:
            loss_rec_rd += F.l1_loss(gimg_rd, rgb_img_org)

        loss_total = loss_cls + loss_sf_consist + loss_rec_rd + loss_cf_consist  # + loss_kl_c + loss_kl_s
        loss_total.backward()

        opt_s.step()
        opt_s_cls.step()
        opt_c.step()
        opt_d.step()

        ### 2. train as AutoEncoder
        rgb_img, skt_img_1, skt_img_2, skt_img_3 = next(dataloader)
        rgb_img = rgb_img.cuda()

        rd = random.randint(0, 3)
        if rd == 0:
            skt_img = skt_img_1
        elif rd == 1:
            skt_img = skt_img_2
        else:
            skt_img = skt_img_3
        skt_img = F.interpolate(skt_img, size=512).cuda()

        style_encoder.zero_grad()
        decoder.zero_grad()
        content_encoder.zero_grad()

        style_vector, _ = style_encoder(rgb_img)
        content_feats = content_encoder(skt_img)
        gimg = decoder(content_feats, style_vector)

        loss_rec_org = F.mse_loss(gimg, rgb_img)
        if DATA_NAME != 'shoe':
            loss_rec_org += percept(
                F.adaptive_avg_pool2d(gimg, output_size=256),
                F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum()
        # else:
        #     loss_rec_org += F.l1_loss(gimg, rgb_img)

        loss_rec = loss_rec_org
        if DATA_NAME == 'shoe':
            ### the grey image reconstruction
            perm = true_randperm(BATCH_SIZE_AE)
            gimg_perm = decoder(content_feats, [s[perm] for s in style_vector])
            gimg_grey = gimg_perm.mean(dim=1, keepdim=True)
            real_grey = rgb_img.mean(dim=1, keepdim=True)
            loss_rec_grey = F.mse_loss(gimg_grey, real_grey)
            loss_rec += loss_rec_grey

        loss_rec.backward()

        opt_s.step()
        opt_d.step()
        opt_c.step()

        ### Logging
        losses_cf_consist.update(loss_cf_consist.mean().item(), BATCH_SIZE_AE)
        losses_sf_consist.update(loss_sf_consist.mean().item(), BATCH_SIZE_AE)
        losses_cls.update(loss_cls.mean().item(), BATCH_SIZE_AE)
        losses_rec_rd.update(loss_rec_rd.item(), BATCH_SIZE_AE)
        losses_rec_org.update(loss_rec_org.item(), BATCH_SIZE_AE)
        if DATA_NAME == 'shoe':
            losses_rec_grey.update(loss_rec_grey.item(), BATCH_SIZE_AE)

        if iteration % LOG_INTERVAL == 0:
            log_msg = 'Train Stage 1: AE: \nrec_rd: %.4f rec_org: %.4f cls: %.4f style_consist: %.4f content_consist: %.4f rec_grey: %.4f' % (
                losses_rec_rd.avg, losses_rec_org.avg, losses_cls.avg,
                losses_sf_consist.avg, losses_cf_consist.avg, losses_rec_grey.avg)
            print(log_msg)
            if log_file_path is not None:
                log_file = open(log_file_path, 'a')
                log_file.write(log_msg + '\n')
                log_file.close()
            losses_sf_consist.reset()
            losses_cls.reset()
            losses_rec_rd.reset()
            losses_rec_org.reset()
            losses_cf_consist.reset()
            losses_rec_grey.reset()

        if iteration % SAVE_IMAGE_INTERVAL == 0:
            vutils.save_image(torch.cat([
                rgb_img_rd,
                F.interpolate(skt_org.repeat(1, 3, 1, 1), size=512),
                gimg_rd
            ]),
                '%s/rd_%d.jpg' % (saved_image_folder, iteration),
                normalize=True,
                range=(-1, 1))
            if DATA_NAME != 'shoe':
                with torch.no_grad():
                    perm = true_randperm(BATCH_SIZE_AE)
                    gimg_perm = decoder([c for c in content_feats],
                                        [s[perm] for s in style_vector])
            vutils.save_image(torch.cat([
                rgb_img,
                F.interpolate(skt_img.repeat(1, 3, 1, 1), size=512),
                gimg,
                gimg_perm
            ]),
                '%s/org_%d.jpg' % (saved_image_folder, iteration),
                normalize=True,
                range=(-1, 1))

        if iteration % SAVE_MODEL_INTERVAL == 0:
            print('Saving history model')
            torch.save(
                {
                    's': style_encoder.state_dict(),
                    'd': decoder.state_dict(),
                    'c': content_encoder.state_dict(),
                    'opt_c': opt_c.state_dict(),
                    'opt_s_cls': opt_s_cls.state_dict(),
                    'opt_s': opt_s.state_dict(),
                    'opt_d': opt_d.state_dict(),
                }, '%s/%d.pth' % (saved_model_folder, iteration))

    torch.save(
        {
            's': style_encoder.state_dict(),
            'd': decoder.state_dict(),
            'c': content_encoder.state_dict(),
            'opt_c': opt_c.state_dict(),
            'opt_s_cls': opt_s_cls.state_dict(),
            'opt_s': opt_s.state_dict(),
            'opt_d': opt_d.state_dict(),
        }, '%s/%d.pth' % (saved_model_folder, ITERATION_AE))
def projector_factor_fn(ckpt, fact, files):
    device = "cuda"

    parser = argparse.ArgumentParser()
    # parser.add_argument("--ckpt", type=str, required=True)
    parser.add_argument("--size", type=int, default=256)
    parser.add_argument("--lr_rampup", type=float, default=0.05)
    parser.add_argument("--lr_rampdown", type=float, default=0.25)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--noise", type=float, default=0.05)
    parser.add_argument("--noise_ramp", type=float, default=0.75)
    parser.add_argument("--step", type=int, default=1000)
    parser.add_argument("--noise_regularize", type=float, default=1e5)
    parser.add_argument("--mse", type=float, default=0)
    parser.add_argument("--w_plus", action="store_true")
    parser.add_argument("--device", type=str, default="cuda")
    # parser.add_argument("--fact", type=str, required=True)
    # parser.add_argument("files", metavar="FILES", nargs="+")
    args = parser.parse_args()

    eigvec = torch.load(fact)["eigvec"].to(args.device)  # args.fact
    eigvec.requires_grad = False

    n_mean_degree = 10000
    n_mean_weight = 10000

    resize = min(args.size, 256)

    transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.CenterCrop(resize),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    imgs = []
    for imgfile in files:  # args.files
        img = transform(Image.open(imgfile).convert("RGB"))
        imgs.append(img)
    imgs = torch.stack(imgs, 0).to(device)

    g_ema = Generator(args.size, 512, 8)
    g_ema.load_state_dict(torch.load(ckpt)["g_ema"], strict=False)  # args.ckpt
    g_ema.eval()
    g_ema = g_ema.to(device)

    trunc = g_ema.mean_latent(4096)

    with torch.no_grad():
        noise_sample = torch.randn(1, 512, device=device)
        latent = g_ema.get_latent(noise_sample)

        degree_sample = torch.randn(n_mean_degree, 1, device=device) * 1000
        weight_sample = torch.randn(n_mean_weight, 512, device=device).unsqueeze(1)  # row of factor

        degree_mean = degree_sample.mean(0)
        degree_std = ((degree_sample - degree_mean).pow(2).sum() / n_mean_degree) ** 0.5

        weight_mean = weight_sample.mean(0)
        weight_std = ((weight_sample - weight_mean).pow(2).sum() / n_mean_weight) ** 0.5

        direction = torch.mm(eigvec, weight_mean.T)

    percept = lpips.PerceptualLoss(model="net-lin", net="vgg",
                                   use_gpu=device.startswith("cuda"))

    noises_single = g_ema.make_noise()
    noises = []
    for noise in noises_single:
        noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())

    direction_in = direction.T
    weight_in = weight_mean
    weight_mean.requires_grad = True

    for noise in noises:
        noise.requires_grad = True

    optimizer = optim.Adam([weight_mean] + noises, lr=args.lr)

    pbar = tqdm(range(args.step))
    latent_path = []
    weight_path = []

    imgs.detach()

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr
        noise_strength = weight_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2
        weight_n = latent_noise(weight_in, noise_strength.item())

        direction_n = torch.mm(weight_in, eigvec)
        img_gen, _ = g_ema([direction_n], input_is_latent=True, noise=noises)

        batch, channel, height, width = img_gen.shape

        if height > 256:
            factor = height // 256
            img_gen = img_gen.reshape(batch, channel, height // factor, factor,
                                      width // factor, factor)
            img_gen = img_gen.mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = F.mse_loss(img_gen, imgs)

        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        noise_normalize_(noises)

        if (i + 1) % 100 == 0:
            direction_in = torch.mm(weight_in, eigvec)
            latent_path.append(direction_in.detach().clone())
            weight_path.append(weight_in.detach().clone())

        pbar.set_description((
            f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
            f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"))

    img_gen, _ = g_ema([latent_path[-1]], input_is_latent=True, noise=noises)

    filename = os.path.splitext(os.path.basename(files[0]))[0] + ".pt"  # args.files[0]

    img_ar = make_image(img_gen)

    result_file = {}
    # 'factor_base_path' is otherwise undefined; './models/' follows the
    # commented-out hint below and is an assumption.
    factor_base_path = './models/'
    # factor_base_path_1 = './models/'
    for i, input_name in enumerate(files):  # args.files
        noise_single = []
        for noise in noises:
            noise_single.append(noise[i:i + 1])

        result_file[input_name] = {
            "img": img_gen[i],
            "latent": latent_path[i],
            "weight": weight_path[i],
            "noise": noise_single,
        }
        img_name = os.path.splitext(os.path.basename(input_name))[0] + "-project.png"
        pil_img = Image.fromarray(img_ar[i])
        save_latent_image_path = factor_base_path + img_name
        pil_img.save(save_latent_image_path)  # add factor_base_path

    save_latent_code_path = factor_base_path + filename
    torch.save(result_file, save_latent_code_path)  # factor_base_path

    return save_latent_image_path, save_latent_code_path
#================================
loss_l1_fn = nn.L1Loss()
loss_vgg_fn = VGGLoss(device, n_channels=3)
if args.gan_loss_type == "lsgan":
    loss_adv_fn = LSGANLoss(device)
elif args.gan_loss_type == "hinge":
    loss_adv_fn = HingeGANLoss()
else:
    raise NotImplementedError()

if args.rec_loss_type == "vgg":
    loss_rec_fn = VGGLoss(device, n_channels=3)
elif args.rec_loss_type == "lpips":
    if args.device == "gpu":
        loss_rec_fn = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)
    else:
        loss_rec_fn = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=False)
else:
    raise NotImplementedError()

#================================
# Model training
#================================
print("Starting Training Loop...")
n_print = 1
step = 0
for epoch in tqdm(range(args.n_epoches), desc="epoches"):
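# --- Hedged sketch (added): HingeGANLoss above is not defined in this
# excerpt. The standard hinge formulation, which would also match the
# d_hinge_loss / g_hinge_loss calls in the GAN training snippet further up,
# looks like this:
import torch.nn.functional as F

def d_hinge_loss(pred_real, pred_fake):
    # discriminator: push real predictions above +1, fake below -1
    return F.relu(1 - pred_real).mean() + F.relu(1 + pred_fake).mean()

def g_hinge_loss(pred_fake):
    # generator: maximize the discriminator's score on fakes
    return -pred_fake.mean()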
def project(path_ckpt, path_files, step=1000):
    device = "cuda"

    parser = argparse.ArgumentParser()
    parser.add_argument('-f', type=str, help='jup kernel')
    # parser.add_argument("--ckpt", type=str, required=True)
    parser.add_argument("--size", type=int, default=512)
    parser.add_argument("--lr_rampup", type=float, default=0.05)
    parser.add_argument("--lr_rampdown", type=float, default=0.25)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--noise", type=float, default=0.05)
    parser.add_argument("--noise_ramp", type=float, default=0.75)
    parser.add_argument("--step", type=int, default=1000)
    parser.add_argument("--noise_regularize", type=float, default=1e5)
    parser.add_argument("--mse", type=float, default=0)
    # parser.add_argument("--w_plus", action="store_true")
    # parser.add_argument("files", metavar="FILES", nargs="+")
    args = parser.parse_args()

    args.ckpt = path_ckpt
    args.files = path_files
    args.w_plus = False
    args.step = step

    n_mean_latent = 10000

    resize = min(args.size, 256)

    transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.CenterCrop(resize),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    imgs = []
    for imgfile in args.files:
        img = transform(Image.open(imgfile).convert("RGB"))
        imgs.append(img)
    imgs = torch.stack(imgs, 0).to(device)

    g_ema = Generator(args.size, 512, 8)
    g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
    g_ema.eval()
    g_ema = g_ema.to(device)

    with torch.no_grad():
        noise_sample = torch.randn(n_mean_latent, 512, device=device)
        latent_out = g_ema.style(noise_sample)

        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

    percept = lpips.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    noises_single = g_ema.make_noise()
    noises = []
    for noise in noises_single:
        noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())

    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)

    if args.w_plus:
        latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)

    latent_in.requires_grad = True

    for noise in noises:
        noise.requires_grad = True

    optimizer = optim.Adam([latent_in] + noises, lr=args.lr)

    pbar = tqdm(range(args.step))
    latent_path = []

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr
        noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2
        latent_n = latent_noise(latent_in, noise_strength.item())

        img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises)

        batch, channel, height, width = img_gen.shape

        if height > resize:
            factor = height // resize
            img_gen = img_gen.reshape(
                batch, channel, height // factor, factor, width // factor, factor
            )
            img_gen = img_gen.mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = F.mse_loss(img_gen, imgs)

        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        noise_normalize_(noises)

        if (i + 1) % 100 == 0:
            latent_path.append(latent_in.detach().clone())

        pbar.set_description((
            f"perceptual: {p_loss.item():.8f}; noise regularize: {n_loss.item():.8f}; "
            f"mse: {mse_loss.item():.8f}; lr: {lr:.4f}"
        ))

    img_gen, _ = g_ema([latent_path[-1]], input_is_latent=True, noise=noises)

    filename = os.path.splitext(os.path.basename(args.files[0]))[0] + ".pt"

    img_ar = make_image(img_gen)

    result_file = {}
    for i, input_name in enumerate(args.files):
        noise_single = []
        for noise in noises:
            noise_single.append(noise[i:i + 1])

        result_file[input_name] = {
            "img": img_gen[i],
            "latent": latent_in[i],
            "noise": noise_single,
        }
        img_name = os.path.splitext(os.path.basename(input_name))[0] + "-project.png"
        pil_img = Image.fromarray(img_ar[i])
        pil_img.save(img_name)

    torch.save(result_file, filename)
    print(filename)
    return img_gen, latent_path, latent_in
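# --- Hedged usage sketch for project() above; the checkpoint and image paths
# are placeholders (assumptions), not files referenced by the source.
if __name__ == "__main__":
    img_gen, latent_path, latent_in = project(
        "stylegan2-ffhq-config-f.pt",  # hypothetical checkpoint path
        ["example.png"],               # hypothetical input image
        step=500,
    )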
g_ema.eval()
g_ema = g_ema.to(device)
g_ema = MyDataParallel(g_ema, device_ids=range(args.n_gpu))

if args.feature_extractor == "d":
    discriminator = Discriminator(args.size, channel_multiplier=2)
    discriminator.load_state_dict(torch.load(args.ckpt)['d'])
    discriminator.eval()
    discriminator = discriminator.to(device)
    discriminator = MyDataParallel(discriminator, device_ids=range(args.n_gpu))
    percept = d_feat_loss
elif args.feature_extractor == "vgg":
    percept = lpips.PerceptualLoss(model='net-lin', net='vgg',
                                   use_gpu=device.startswith('cuda'),
                                   gpu_ids=range(args.n_gpu))

# checkpoint = torch.load("/scratch/gobi2/lechang/trained_vae.pth")
# vae = VariationalAutoEncoderLite(5)
# vae.load_state_dict(checkpoint['ae'])
# vae = vae.to(device)
# vae.eval()

print("loaded models")

if args.latent_space == "w":
    with torch.no_grad():
        noise_sample = torch.randn(n_mean_latent, 512, device=device)
        latent_out = g_ema.style(noise_sample)
        latent_mean = latent_out.mean(0)
parser.add_argument('--eps', type=float, default=1e-4)
parser.add_argument('--crop', action='store_true')
parser.add_argument('ckpt', metavar='CHECKPOINT')

args = parser.parse_args()

latent_dim = 512

ckpt = torch.load(args.ckpt)

g = Generator(args.size, latent_dim, 8).to(device)
g.load_state_dict(ckpt['g_ema'])
g.eval()

percept = lpips.PerceptualLoss(model='net-lin', net='vgg',
                               use_gpu=device.startswith('cuda'))

distances = []

n_batch = args.n_sample // args.batch
resid = args.n_sample - (n_batch * args.batch)
batch_sizes = [args.batch] * n_batch + [resid]

with torch.no_grad():
    for batch in tqdm(batch_sizes):
        noise = g.make_noise()

        inputs = torch.randn([batch * 2, latent_dim], device=device)
        lerp_t = torch.rand(batch, device=device)
def __init__(self, x, gma, device, lr=0.1, steps='1000', task='invert',
             search_space='W+', search_noise=True, project=True,
             start_layer=0, end_layer=5, discriminator=None, cls_alpha=0,
             mask=1, mse_weight=1, lpips_alpha=0, r_alpha=0.1):
    """
    :param x:
    :param gma:
    :param device:
    :param lr:
    :param steps:
    :param task:
    :param search_space: W, W+, Z, Z+
    :param search_noise:
    :param project:
    :param start_layer:
    :param end_layer:
    :param discriminator:
    :param cls_alpha:
    :param mask:
    :param mse_weight:
    :param lpips_alpha:
    :param r_alpha:
    """
    super().__init__()
    # self.config = config
    # if config['image_size'][0] != config['image_size'][1]:
    #     raise Exception('Non-square images are not supported yet.')
    self.task = task
    self.search_space = search_space
    self.search_noise = search_noise
    self.project = project
    self.steps = steps
    self.start_layer = start_layer
    self.end_layer = end_layer
    self.dead_zone_linear = 0
    self.dead_zone_linear_alpha = 0.1
    self.device = device
    self.geocross_alpha = 0.1
    self.cls_alpha = cls_alpha
    self.lpips_alpha = lpips_alpha
    self.r_alpha = r_alpha
    self.mask = mask
    self.mse_weight = mse_weight
    self.layer_in = None
    self.best = None
    self.skip = None
    self.lr = lr
    self.lr_record = []
    self.current_step = 0

    self.original_imgs = x.to(device)
    self.discriminator = discriminator
    if self.discriminator is not None:
        self.discriminator = self.discriminator.to(device)

    if self.task == 'separate':
        bs = self.original_imgs.shape[0] * 2
    else:
        bs = self.original_imgs.shape[0]

    # self.downsampler_1024_256 = BicubicDownSample(4)
    # self.downsampler_1024_image = BicubicDownSample(1024 // config['image_size'][0])
    # self.downsampler_image_256 = BicubicDownSample(config['image_size'][0] // 256)

    # Load models and pre-trained weights
    self.gen = gma.to(device)
    self.gen.start_layer = start_layer
    self.gen.end_layer = end_layer
    for p in self.gen.parameters():
        p.requires_grad = False

    self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2)
    self.plrelu = torch.nn.LeakyReLU(negative_slope=5)  # slope 5 = 1/0.2, inverts the mapping network's LeakyReLU

    # if self.verbose: print("\tRunning Mapping Network")
    with torch.no_grad():
        # torch.manual_seed(0)
        # latent = torch.randn((1000000, 512), dtype=torch.float32, device="cuda")
        # latent_out = torch.nn.LeakyReLU(5)(self.gen.style(latent))
        latent_p = self.plrelu(
            self.gen.style(
                torch.randn((500000, 512), dtype=torch.float32,
                            device="cuda"))).double()
        self.mu = latent_p.mean(dim=0, keepdim=True)
        self.Sigma = (latent_p - self.mu).T @ (latent_p - self.mu) / latent_p.shape[0]
        d, V = torch.symeig(self.Sigma, eigenvectors=True)
        # small eigenvalues do not get overamplified.
        D = torch.diag(1. / torch.sqrt(d + 1e-18))  # whitening matrix
        # W = np.dot(np.dot(V, D), V.T)  # ZCA whitening
        self.W = (V @ D).float()  # PCA whitening
        self.W_inv = torch.inverse(self.W)
        latent_p = latent_p.float()
        self.mu = self.mu.float().unsqueeze(0).to(device)
        self.Sigma = self.Sigma.float().to(device)
        self.gaussian_fit = {
            "mean": latent_p.mean(0).to(device),
            "std": latent_p.std(0).to(device)
        }
        del latent_p
        torch.cuda.empty_cache()

    # self.mpl = MappingProxy(torch.load('gaussian_fit.pt'))
    self.percept = lpips.PerceptualLoss(model="net-lin", net="vgg",
                                        use_gpu=device.startswith("cuda"))

    # # load a classifier
    # self.cls = imagenet_models.resnet50()
    # state_dict = torch.load('imagenet_l2_3_0.pt')['model']
    # new_dict = OrderedDict()
    # for key in state_dict.keys():
    #     if 'module.model' in key:
    #         new_dict[key[13:]] = state_dict[key]
    # self.cls.load_state_dict(new_dict)
    # self.cls.to(config['device'])

    # initialization
    # self.scalar = torch.ones(bs, requires_grad=True).to(device)
    self.scalar = torch.ones((bs, 1, 1, 1), dtype=torch.float,
                             requires_grad=True, device='cuda')

    if start_layer == 0:
        noises_single = self.gen.make_noise(bs)
        self.noises = []
        for noise in noises_single:
            self.noises.append(noise.normal_())

        if self.search_space == 'W':
            # # self.latent_z = torch.randn(
            # #     (bs, self.gen.n_latent, self.gen.style_dim),
            # #     dtype=torch.float,
            # #     requires_grad=True, device='cuda')
            # with torch.no_grad():
            #     self.latent_z = self.gen.style(F.normalize(torch.randn(bs, self.gen.n_latent, self.gen.style_dim), p=2, dim=2).to(device))  # random w
            #     # self.latent_z = self.gen.style(F.normalize(torch.randn(bs, 1, self.gen.style_dim), p=2, dim=2).repeat(1, self.gen.n_latent, 1).to(device))  # random w
            #     # self.latent_z = self.gen.mean_latent(16384).unsqueeze(1).repeat(bs, self.gen.n_latent, 1).to(device)  # mean w
            # self.latent_z.requires_grad = True

            # Generate latent tensor
            self.latent = torch.randn((bs, 1, self.gen.style_dim),
                                      dtype=torch.float,
                                      requires_grad=True, device='cuda')
        elif self.search_space == 'W+':
            self.latent = torch.randn(
                (bs, self.gen.n_latent, self.gen.style_dim),
                dtype=torch.float,
                requires_grad=True, device='cuda')
            # with torch.no_grad():
            #     self.latent = self.gen.style(torch.randn(bs, self.gen.n_latent, self.gen.style_dim).to(device))  # random w
            # self.latent.requires_grad = True
        elif self.search_space == 'Z':
            self.latent = torch.randn((bs, 1, self.gen.style_dim),
                                      dtype=torch.float,
                                      requires_grad=True, device='cuda')
        elif self.search_space == 'Z+':
            # self.latent_z = torch.randn(
            #     (bs, self.gen.style_dim),
            #     dtype=torch.float,
            #     requires_grad=True, device='cuda')
            self.latent_z = torch.randn(
                (bs, self.gen.n_latent, self.gen.style_dim),
                dtype=torch.float,
                requires_grad=True, device='cuda')
            self.latent_w = self.gen.style(self.latent_z)
        else:
            raise ValueError("searching_space incorrect")

        self.gen_outs = [None]
    else:
        # restore noises
        # NOTE: 'config' is not defined in this method's scope; this branch
        # still expects the original config dict with the saved-noise paths.
        self.noises = torch.load(config['saved_noises'][0])
        self.latent_z = torch.load(config['saved_noises'][1]).to(config['device'])
        self.gen_outs = torch.load(config['saved_noises'][2])
        self.latent_z.requires_grad = True
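# --- Hedged numeric check (added): with D = diag(1 / sqrt(d)), the PCA
# whitening matrix W = V @ D built above satisfies W.T @ Sigma @ W ≈ I.
# (torch.linalg.eigh is the current replacement for the deprecated
# torch.symeig used in the snippet.)
import torch

X = torch.randn(1000, 4).double()
Xc = X - X.mean(0)
Sigma = Xc.T @ Xc / X.shape[0]
d, V = torch.linalg.eigh(Sigma)
W = V @ torch.diag(1.0 / torch.sqrt(d + 1e-18))
assert torch.allclose(W.T @ Sigma @ W, torch.eye(4).double(), atol=1e-6)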
imgs = torch.stack(imgs, 0).to(device)

g_ema = Generator(args.size, 512, 8)
g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
g_ema.eval()
g_ema = g_ema.to(device)

with torch.no_grad():
    noise_sample = torch.randn(n_mean_latent, 512, device=device)
    latent_out = g_ema.mapping_network(noise_sample)

    latent_mean = latent_out.mean(0)
    latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

percept = lpips.PerceptualLoss(
    model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
)

noises_single = g_ema.make_noise()
noises = []
for noise in noises_single:
    noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())

latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)

if args.w_plus:
    latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)

latent_in.requires_grad = True

for noise in noises:
def train(self, ref_images, epochs=100, start_lr=0.1):
    '''Train ...'''
    if ref_images.dim() < 4:
        ref_images = ref_images.unsqueeze(0)
    ref_images = ref_images.to(os.environ["DEVICE"])
    ref_batch, channel, ref_height, ref_width = ref_images.shape
    assert ref_height == ref_width

    noise_var_list = []
    for noise in self.generator.make_noise():
        normal_noise = noise.repeat(ref_batch, 1, 1, 1).normal_()
        normal_noise.requires_grad = True
        noise_var_list.append(normal_noise)

    n_mean_latent = 10000
    latent_out = torch.randn(n_mean_latent, 512, device=self.device)
    latent_mean = latent_out.mean(0)
    latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5
    del latent_out

    latent_var = latent_mean.detach().clone().unsqueeze(0).repeat(ref_batch, 1)
    latent_var.requires_grad = True

    optimizer = optim.Adam([latent_var] + noise_var_list, lr=start_lr)
    percept = lpips.PerceptualLoss(
        model="net-lin", net="vgg",
        use_gpu=os.environ["DEVICE"].startswith("cuda"))

    progress_bar = tqdm(range(epochs))
    noise_level = 0.05
    noise_ramp = 0.07

    self.generator.train()
    for i in progress_bar:
        t = i / epochs
        lr = get_lr(t, start_lr)
        optimizer.param_groups[0]["lr"] = lr

        # Inject noise into the latent variable
        noise_strength = latent_std * noise_level * max(0, 1 - t / noise_ramp) ** 2
        latent_n = latent_var + torch.randn_like(latent_var) * noise_strength.item()

        gen_images = self.generator(latent_n, noise=noise_var_list)

        batch, channel, height, width = gen_images.shape
        if height > ref_height:
            factor = height // ref_height
            gen_images = gen_images.reshape(batch, channel, height // factor,
                                            factor, width // factor, factor)
            gen_images = gen_images.mean([3, 5])

        p_loss = percept(gen_images, ref_images).sum()
        n_loss = noise_loss(noise_var_list)
        mse_loss = F.mse_loss(gen_images, ref_images)
        loss = p_loss + 1e5 * n_loss + 0.2 * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        noise_normalize_(noise_var_list)

        progress_bar.set_description((
            f"Loss = perceptual: {p_loss.item():.4f}; noise: {n_loss.item():.4f};"
            f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"))

    last_latent = latent_n.detach().clone()

    self.generator.eval()
    del noise_var_list
    torch.cuda.empty_cache()

    # maybe the best latent ?
    return last_latent
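# --- Hedged sketch (added): the get_lr schedule called by several snippets
# above is not shown in this excerpt. A plausible implementation, matching
# the cosine ramp-down / linear ramp-up schedule of rosinality's
# stylegan2-pytorch projector (an assumption about the exact source):
import math

def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    # t in [0, 1]: linear warm-up over the first 'rampup' fraction,
    # cosine decay over the final 'rampdown' fraction
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)
    return initial_lr * lr_ramp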