def main(conf):
    """Build the dataset, model/EMA pair and diffusion schedule, then train.

    All hyperparameters except the beta schedule come from ``conf``; the
    schedule constants are fixed locally.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # NOTE(review): device is pinned to cuda:0 while DDP below uses
    # dist.get_local_rank() — confirm each rank sets its CUDA device elsewhere.
    beta_schedule = "linear"
    beta_start = 1e-4
    beta_end = 2e-2
    n_timestep = 1000

    conf.distributed = dist.get_world_size() > 1

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )
    train_set = MultiResolutionDataset(
        conf.dataset.path, transform, conf.dataset.resolution
    )
    train_sampler = dist.data_sampler(
        train_set, shuffle=True, distributed=conf.distributed
    )
    train_loader = conf.training.dataloader.make(train_set, sampler=train_sampler)

    # The EMA network is an exact architectural twin of the trained model,
    # so build both from one shared keyword set.
    unet_kwargs = dict(
        channel_multiplier=conf.model.channel_multiplier,
        n_res_blocks=conf.model.n_res_blocks,
        attn_strides=conf.model.attn_strides,
        dropout=conf.model.dropout,
        fold=conf.model.fold,
    )
    model = UNet(conf.model.in_channel, conf.model.channel, **unet_kwargs).to(device)
    ema = UNet(conf.model.in_channel, conf.model.channel, **unet_kwargs).to(device)

    if conf.distributed:
        local_rank = dist.get_local_rank()
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

    betas = make_beta_schedule(beta_schedule, beta_start, beta_end, n_timestep)
    diffusion = GaussianDiffusion(betas).to(device)

    train(conf, train_loader, model, ema, diffusion, optimizer, scheduler, device)
def main(conf):
    """Entry point: optional W&B logging, data/model setup, checkpoint resume, train."""
    # Only the primary process talks to Weights & Biases.
    wandb = None
    if dist.is_primary() and conf.evaluate.wandb:
        wandb = load_wandb()
        wandb.init(project="denoising diffusion")

    device = "cuda"
    beta_schedule = "linear"  # kept from the original; betas come from conf below
    conf.distributed = dist.get_world_size() > 1

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )
    train_set = MultiResolutionDataset(
        conf.dataset.path, transform, conf.dataset.resolution
    )
    train_sampler = dist.data_sampler(
        train_set, shuffle=True, distributed=conf.distributed
    )
    train_loader = conf.training.dataloader.make(train_set, sampler=train_sampler)

    model = conf.model.make().to(device)
    ema = conf.model.make().to(device)

    if conf.distributed:
        local_rank = dist.get_local_rank()
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

    # Resume weights for both the raw model and its EMA shadow; the DDP
    # wrapper stores the real module under .module.
    if conf.ckpt is not None:
        ckpt = torch.load(conf.ckpt, map_location=lambda storage, loc: storage)
        target = model.module if conf.distributed else model
        target.load_state_dict(ckpt["model"])
        ema.load_state_dict(ckpt["ema"])

    betas = conf.diffusion.beta_schedule.make()
    diffusion = GaussianDiffusion(betas).to(device)

    train(
        conf, train_loader, model, ema, diffusion, optimizer, scheduler, device, wandb
    )
def train_dataloader(self):
    """Return a shuffled DataLoader over the training images.

    Vertical/horizontal flips run with p=0.5 only when the corresponding
    ``self.vflip`` / ``self.hflip`` switches are set; otherwise p=0 disables them.
    """
    augmentations = [
        transforms.RandomVerticalFlip(p=0.5 if self.vflip else 0),
        transforms.RandomHorizontalFlip(p=0.5 if self.hflip else 0),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ]
    dataset = MultiResolutionDataset(
        self.path, transforms.Compose(augmentations), self.size
    )
    return data.DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=0,  # single-process loading
    )
def set_dataset(self):
    """Create ``self.dataset`` and ``self.loader`` from the stored arguments."""
    args = self.args
    pipeline = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )
    self.dataset = MultiResolutionDataset(args.path, pipeline, args.size)
    sampler = data_sampler(self.dataset, shuffle=True, distributed=args.distributed)
    self.loader = data.DataLoader(
        self.dataset,
        batch_size=args.batch,
        sampler=sampler,
        drop_last=True,  # keep the batch shape constant
    )
# NOTE(review): fragment — begins mid-function and is cut off mid-call at the end.
# Wrap the discriminator for multi-GPU training; buffers are not broadcast so
# per-rank buffer state stays independent.
discriminator = nn.parallel.DistributedDataParallel(
    discriminator,
    device_ids=[args.local_rank],
    output_device=args.local_rank,
    broadcast_buffers=False,
)
transform = transforms.Compose([
    transforms.Resize(args.size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
])
# dataset_src = MultiResolutionDataset(args.path_src, transform, args.size)
# NOTE(review): resolution hard-coded to 256 instead of args.size — confirm intentional.
dataset_src = MultiResolutionDataset(args.path_src, transform, 256)
loader_src = data.DataLoader(
    dataset_src,
    batch_size=args.batch,
    sampler=data_sampler(dataset_src, shuffle=True, distributed=args.distributed),
    drop_last=True,
)
# dataset_norm = MultiResolutionDataset(args.path_norm, transform, args.size)
dataset_norm = MultiResolutionDataset(args.path_norm, transform, 256)
loader_norm = data.DataLoader(
    dataset_norm,
    batch_size=args.batch,
    sampler=data_sampler(dataset_norm, shuffle=True,
def setup_and_run(device, args):
    """One-time training setup (directories, DDP, networks, optimizers, data), then train().

    Mutates ``args`` in place (distributed flag, latent size, start iteration,
    and — via the checkpoint branch — the resume step).
    """
    os.makedirs(f'sample_{args.name}', exist_ok=True)
    os.makedirs(f'checkpoint_{args.name}', exist_ok=True)

    # torch.distributed launchers export WORLD_SIZE; fall back to a single process.
    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    args.distributed = n_gpu > 1
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        synchronize()

    args.latent = 512
    args.n_mlp = 8
    args.start_iter = 0

    generator = Generator(
        args.size, args.latent, args.n_mlp,
        channel_multiplier=args.channel_multiplier).to(device)
    discriminator = Discriminator(
        args.size, channel_multiplier=args.channel_multiplier).to(device)
    g_ema = Generator(
        args.size, args.latent, args.n_mlp,
        channel_multiplier=args.channel_multiplier).to(device)
    g_ema.eval()
    # decay=0 initializes the EMA weights as an exact copy of the generator.
    accumulate(g_ema, generator, 0)

    # Lazy-regularization correction: lr and Adam betas are rescaled because
    # the path-length / R1 regularizers only run every *_reg_every steps.
    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)
    g_optim = optim.Adam(
        generator.parameters(),
        lr=args.lr * g_reg_ratio,
        betas=(0**g_reg_ratio, 0.99**g_reg_ratio),
    )
    d_optim = optim.Adam(
        discriminator.parameters(),
        lr=args.lr * d_reg_ratio,
        betas=(0**d_reg_ratio, 0.99**d_reg_ratio),
    )

    if args.ckpt is not None:
        print('load model:', args.ckpt)
        ckpt = torch.load(args.ckpt)
        try:
            # Checkpoints are assumed to be named <iteration>.<ext>; resume from that step.
            ckpt_name = os.path.basename(args.ckpt)
            args.start_iter = int(os.path.splitext(ckpt_name)[0])
        except ValueError:
            pass  # non-numeric checkpoint name: keep start_iter = 0
        generator.load_state_dict(ckpt['g'], strict=False)
        discriminator.load_state_dict(ckpt['d'], strict=False)
        g_ema.load_state_dict(ckpt['g_ema'], strict=False)
        g_optim.load_state_dict(ckpt['g_optim'])
        d_optim.load_state_dict(ckpt['d_optim'])

    # Wrap AFTER loading state dicts so checkpoint keys match the bare modules.
    if args.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )
        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])
    dataset = MultiResolutionDataset(args.path, transform, args.size)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
        drop_last=True,
    )

    # Only rank 0 logs to W&B, and only if the module imported successfully.
    if get_rank() == 0 and wandb is not None and args.wandb:
        wandb.init(project='stylegan 2')

    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)
# NOTE(review): fragment — the enclosing function begins before this view.
d_optim.load_state_dict(e_ckpt["d_optim"])
try:
    # Encoder checkpoints are assumed to be named ..._<iteration>.<ext>;
    # resume from that iteration.
    ckpt_name = os.path.basename(args.e_ckpt)
    args.start_iter = int(
        os.path.splitext(ckpt_name.split('_')[-1])[0])
except ValueError:
    pass  # non-numeric suffix: keep the default start iteration

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
])
dataset = MultiResolutionDataset(args.data, transform, args.size)
loader = data.DataLoader(
    dataset,
    batch_size=args.batch,
    sampler=data_sampler(dataset, shuffle=True),
    drop_last=True,
)
# Held-out images for validation; same transform as training.
test_dataset = MultiResolutionDataset(args.test_data, transform, args.size)
test_loader = data.DataLoader(
    test_dataset,
    batch_size=args.val_batch,
    sampler=data_sampler(test_dataset, shuffle=True),
    drop_last=True,
)
# NOTE(review): fragment — begins mid checkpoint-restore block.
generator.module.load_state_dict(ckpt['generator'])
discriminator.module.load_state_dict(ckpt['discriminator'])
g_running.load_state_dict(ckpt['g_running'])
g_optimizer.load_state_dict(ckpt['g_optimizer'])
d_optimizer.load_state_dict(ckpt['d_optimizer'])

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    # NOTE(review): images are forced down to 8x8 here — confirm intentional.
    transforms.Resize((8, 8)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset1 = CondDataset(args.path, transform, transform)
# NOTE(review): no resolution argument — presumably relies on the dataset's
# default; verify against MultiResolutionDataset's signature.
dataset2 = MultiResolutionDataset(args.path, transform)

if args.sched:
    # Progressive-growing schedule: per-resolution learning rates and batch sizes.
    args.lr = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
    args.batch = {
        4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 32, 256: 32
    }
else:
    args.lr = {}
    args.batch = {}
)  # NOTE(review): closes a call that starts before this fragment

# Horizontal-flip augmentation only when mirror_augment is requested.
if args.mirror_augment:
    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )
else:
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

dataset = MultiResolutionDataset(args.path, transform, args.size,
                                 args.use_label, metadata, categories)
loader = data.DataLoader(
    dataset,
    batch_size=args.batch,
    sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
    drop_last=True,
)

# Only rank 0 logs to W&B, and only if the module imported successfully.
if get_rank() == 0 and wandb is not None and args.wandb:
    wandb.init(project='stylegan 2')

train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)
def eval(args, latent_sampler, g_ema, inception, device, config):
    """Quantitative evaluation (Inception Score or a FID variant) of a generator.

    Results are printed and appended to a per-experiment log file under
    ``logs-quant/``. Real-image FID statistics are cached under ``.fid-cache/``
    keyed by metric type, dataset, resolution, and the PyTorch/CUDA version.
    """
    if g_ema is not None:
        g_ema.eval()

    """ Cast FID calculation spec """
    # Evaluation resolution: either an explicit pre-resize from the config, or
    # the full training size for the StyleGAN2 baseline.
    if hasattr(config.train_params, "extra_pre_resize"):
        real_data_res = config.train_params.extra_pre_resize
    else:  # StyleGAN2 baseline
        assert config.train_params.styleGAN2_baseline
        real_data_res = config.train_params.full_size
    assert real_data_res in {128, 256}, \
        "In this paper, we only benchmark in size {128, 256}. Got {}.".format(real_data_res)
    eval_gen_res = real_data_res * args.scale

    # InfinityGAN is trained with larger image, so the same resolution equivalents to smaller FoV.
    # Here, we ensures the FoV is the same as the StyleGAN2 baseline
    fov_scale = config.train_params.full_size / real_data_res
    raw_gen_res = int(np.ceil(eval_gen_res * fov_scale))

    if args.seq_inference:
        # Sequential inference is only meaningful for the non-baseline model
        # and for scales above 1.
        assert (not hasattr(config.train_params, "styleGAN2_baseline")) or \
            (not config.train_params.styleGAN2_baseline)
        assert args.scale > 1, "Set sequential inference with scale==1 is meaningless"
        use_seq_inf = True
    else:
        use_seq_inf = False

    """ Create dataloader and generator """
    # When sampling from a pre-rendered folder, only verify + resize; when
    # generating on the fly, undo the FoV scaling, crop, then resize.
    if args.img_folder is not None:
        postprocessing_params = [
            ["assert", eval_gen_res],
            ["resize", real_data_res],
        ]
    else:
        postprocessing_params = [
            ["scale", 1 / fov_scale],
            ["crop", eval_gen_res],
            ["resize", real_data_res],
        ]
    fake_generator = \
        QuantEvalSampleGenerator(
            g_ema,
            latent_sampler,
            img_folder=args.img_folder,  # if applicable
            output_size=raw_gen_res,
            use_seq_inf=use_seq_inf,
            postprocessing_params=postprocessing_params,
            fid_type=args.type,
            device=device,
            config=config,
            use_pil_resize=args.use_pil_resize)

    stats_key = "benchmark-{}-{}-RealRes{}".format(
        args.type, config.data_params.dataset, real_data_res)
    # FID statistics can be different for different PyTorch version, not sure about cuda
    stats_key += f"_PT{torch.__version__}_cu{torch.version.cuda}"
    fid_cache_path = os.path.join(".fid-cache/", stats_key + ".pkl")
    if os.path.exists(fid_cache_path):
        if args.clear_fid_cache:
            os.remove(fid_cache_path)
            use_cache = False
        else:
            use_cache = True
    else:
        use_cache = False

    # Real images are only needed when the statistics cache must be (re)built.
    if not use_cache:
        dataset = MultiResolutionDataset(
            split="train",
            config=config,
            is_training=False,
            # return "full" of real full images and crop on-the-fly
            disable_extra_cropping=True,
            simple_return_full=True,
            override_full_size=real_data_res)
        real_dataloader = QuantEvalDataLoader(dataset, real_data_res, device, config)
    else:
        real_dataloader = None

    """ Eval """
    st = time.time()
    if args.metric == "is":
        assert args.scale == 1, "We didn't implement scaleinv IS."
        n_batch = int(np.ceil(config.test_params.n_fid_sample / config.train_params.batch_size))
        all_imgs = []
        for img_batch in tqdm(fake_generator(n_batch), total=n_batch):
            img_batch = ((img_batch + 1) / 2).cpu()  # [-1, 1] => [0, 1]
            all_imgs.append(img_batch)
        all_imgs = torch.cat(all_imgs, 0)
        is_mean, is_std = inception_score(
            all_imgs,
            device="cuda",
            batch_size=config.train_params.batch_size,
            resize=False,
            splits=10)
        # BUGFIX: the original passed args.type as the first format argument to a
        # single-placeholder string, printing the metric type where the elapsed
        # time was intended (the extra argument was silently ignored).
        print(" [*] IS time spend {}".format(time.time() - st))
        print(" [*] IS at eval_gen_res {} is {}+-{} (ckpt patch FID = {})".format(
            eval_gen_res, is_mean, is_std, config.var.best_fid))
    elif args.metric == "fid":
        if args.type == "spatial":
            fid = eval_fid(
                real_dataloader, fake_generator, inception, stats_key, None,
                device, config,
                spatial_partition_cat=True,
                assert_eval_shape=real_data_res)
        elif args.type in {"scaleinv", "alis"}:
            fid = eval_fid(
                real_dataloader, fake_generator, inception, stats_key, None,
                device, config,
                spatial_partition_cat=False,
                assert_eval_shape=real_data_res)
        else:
            raise NotImplementedError("Unknown FID variant {}".format(args.type))
        print(" [*] {} FID time spend {}".format(args.type, time.time() - st))
        print(" [*] FID (type {}) at eval_gen_res {} is {} (ckpt patch FID = {})".format(
            args.type, eval_gen_res, fid, config.var.best_fid))

    """ Setup Logging """
    # Scores are appended so repeated runs of the same experiment accumulate.
    if args.metric == "is":
        log_root = os.path.join("logs-quant", "IS")
        filename = f"EvalGenRes{eval_gen_res}-Exp-{config.var.exp_name}.txt"
        score = "{:.6f}+-{:.6f}\n".format(is_mean, is_std)
    else:
        log_root = os.path.join("logs-quant", "FID-" + args.type)
        filename = f"Scale{args.scale}-EvalGenRes{eval_gen_res}-Exp-{config.var.exp_name}.txt"
        score = "{:.6f}\n".format(fid)
    if not os.path.exists(log_root):
        os.makedirs(log_root)
    with open(os.path.join(log_root, filename), "a") as lf:
        lf.write(score)
def train(learning_rate, lambda_mse):
    """Single hyperparameter-search trial: train a ConvSegNet autoencoder, log to W&B.

    learning_rate -- Adam learning rate for the autoencoder.
    lambda_mse    -- weight of the pixel loss relative to the VGG perceptual loss.
    """
    print(
        f"learning_rate={learning_rate:.4f}",
        f"lambda_mse={lambda_mse:.4f}",
    )
    transform = transforms.Compose(
        [
            transforms.Resize(128),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )
    batch_size = 72
    data_path = "/home/hans/trainsets/cyphis"
    name = os.path.splitext(os.path.basename(data_path))[0]
    dataset = MultiResolutionDataset(data_path, transform, 256)
    dataloader = data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=data.RandomSampler(dataset),
        num_workers=12,
        drop_last=True,
    )
    loader = sample_data(dataloader)  # presumably an infinite iterator over the dataloader — verify

    # Log a fixed grid of real images once for visual reference.
    sample_imgs = next(loader)[:24]
    wandb.log({"Real Images": [wandb.Image(utils.make_grid(sample_imgs, nrow=6, normalize=True, range=(-1, 1)))]})

    vae, vae_optim = None, None
    vae = ConvSegNet().to(device)
    vae_optim = th.optim.Adam(vae.parameters(), lr=learning_rate)

    vgg = VGGLoss()

    # Fixed latent batch reused for "Generated Images" logging throughout the run.
    sample_z = th.randn(size=(24, 512, 16, 16))
    sample_z /= sample_z.abs().max()

    scores = []
    num_iters = 100_000
    pbar = tqdm(range(num_iters), smoothing=0.1)
    for i in pbar:
        vae.train()
        real = next(loader).to(device)

        z = vae.encode(real)
        fake = vae.decode(z)

        vgg_loss = vgg(fake, real)
        # NOTE(review): this is an RMSE despite the variable name.
        mse_loss = th.sqrt((fake - real).pow(2).mean())

        # diff = fake - real
        # recons_loss = recons_alpha * diff + th.log(1.0 + th.exp(-2 * recons_alpha * diff)) - th.log(th.tensor(2.0))
        # recons_loss = (1.0 / recons_alpha) * recons_loss.mean()
        # recons_loss = recons_loss if not th.isinf(recons_loss).any() else 0

        # x, y = z.chunk(2)
        # align_loss = align(x, y, alpha=align_alpha)
        # unif_loss = -(uniform(x, t=unif_t) + uniform(y, t=unif_t)) / 2.0

        loss = (
            vgg_loss
            + lambda_mse * mse_loss
            # + lambda_recons * recons_loss
            # + lambda_align * align_loss
            # + lambda_unif * unif_loss
        )

        # print(vgg_loss.detach().cpu().item())
        # print(lambda_mse * mse_loss.detach().cpu().item())
        # # print(lambda_recons * recons_loss.detach().cpu().item())
        # print(lambda_align * align_loss.detach().cpu().item())
        # print(lambda_unif * unif_loss.detach().cpu().item())

        loss_dict = {
            "Total": loss,
            "MSE": mse_loss,
            "VGG": vgg_loss,
            # "Reconstruction": recons_loss,
            # "Alignment": align_loss,
            # "Uniformity": unif_loss,
        }

        vae.zero_grad()
        loss.backward()
        vae_optim.step()

        wandb.log(loss_dict)
        # pbar.set_description(" ".join())

        # Periodic qualitative samples + checkpoint: every 1% of the run and
        # on the final iteration.
        with th.no_grad():
            if i % int(num_iters / 100) == 0 or i + 1 == num_iters:
                vae.eval()
                sample = vae(sample_imgs.to(device))
                grid = utils.make_grid(sample, nrow=6, normalize=True, range=(-1, 1))
                del sample
                wandb.log({"Reconstructed Images VAE": [wandb.Image(grid, caption=f"Step {i}")]})

                sample = vae.decode(sample_z.to(device))
                grid = utils.make_grid(sample, nrow=6, normalize=True, range=(-1, 1))
                del sample
                wandb.log({"Generated Images VAE": [wandb.Image(grid, caption=f"Step {i}")]})

                gc.collect()
                th.cuda.empty_cache()

                th.save(
                    {"vae": vae.state_dict(), "vae_optim": vae_optim.state_dict()},
                    f"/home/hans/modelzoo/maua-sg2/vae-{name}-{wandb.run.dir.split('/')[-1].split('-')[-1]}.pt",
                )

        # Abort diverged trials; log a sentinel "bad" score so the sweep
        # controller can rank this configuration.
        if th.isnan(loss).any():
            print("NaN losses, exiting...")
            wandb.log({"Total": 27000})
            return
)  # NOTE(review): closes a call that starts before this fragment
discriminator = nn.parallel.DistributedDataParallel(
    discriminator,
    device_ids=[args.local_rank],
    output_device=args.local_rank,
    broadcast_buffers=False,
)

# No flip augmentation here — only tensor conversion and [-1, 1] normalization.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
])
dataset = MultiResolutionDataset(args.path, transform, args.resolution,
                                 condition_path=args.condition_path)
loader = data.DataLoader(
    dataset,
    batch_size=args.batch,
    sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
    drop_last=True,
    num_workers=args.num_workers,
)

train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)
def train(args, generator, discriminator):
    """Progressive-growing GAN training loop, 2x2-patch coordinate variant.

    Relies on module-level globals: transform, g_optimizer, d_optimizer,
    g_running, coord_base, code_size, n_critic — TODO confirm they are
    defined at import time.
    """
    step = int(math.log2(args.max_size)) - 2  #-> 1
    resolution = 4 * 2**step
    batch_size = args.batch.get(resolution, args.batch_default)
    dataset = MultiResolutionDataset(args.path, transform, resolution=resolution)
    loader = sample_data(dataset, batch_size, resolution)
    data_loader = iter(loader)

    adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
    adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))

    pbar = tqdm(range(3000000))

    requires_grad(generator, False)
    requires_grad(discriminator, True)

    disc_loss_val = 0
    gen_loss_val = 0
    grad_loss_val = 0

    alpha = 0
    used_sample = 0  #-> how many images has been used
    max_step = int(math.log2(args.max_size)) - 2  #-> log2(1024) - 2 = 8
    final_progress = False

    for i in pbar:
        discriminator.zero_grad()

        alpha = min(1, 1 / args.phase * (used_sample + 1))
        #-> min(1, (cur+1)/60_0000)
        #-> when more than 60_0000 sampels is used, alpha will be in const to 1.0
        #-> which means we the "skip_rgb" will not be applied

        if (resolution == args.init_size and args.ckpt is None) or final_progress:
            alpha = 1  #-> also, if initially, no previous outputs for skip-connection

        # Resolution progression: after args.phase*2 samples, grow one step,
        # rebuild the loader, checkpoint, and re-adjust learning rates.
        if used_sample > args.phase * 2:  #-> if > 1_200_000
            ## num_of_epoch_each_phase = args.phase * 2 / training_dataset_size
            used_sample = 0
            step += 1

            if step > max_step:
                step = max_step
                final_progress = True
                ckpt_step = step + 1
            else:
                alpha = 0
                ckpt_step = step

            resolution = 4 * 2**step

            loader = sample_data(
                dataset, args.batch.get(resolution, args.batch_default), resolution)
            data_loader = iter(loader)

            torch.save(
                {
                    'generator': generator.module.state_dict(),
                    'discriminator': discriminator.module.state_dict(),
                    'g_optimizer': g_optimizer.state_dict(),
                    'd_optimizer': d_optimizer.state_dict(),
                    'g_running': g_running.state_dict(),
                },
                r'checkpoint/train_step-{}.model'.format(ckpt_step))

            adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
            adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))

        #### update discriminator
        try:
            real_image = next(data_loader)
        except (OSError, StopIteration):
            # Exhausted (or broken) iterator: restart it.
            data_loader = iter(loader)
            real_image = next(data_loader)

        used_sample += real_image.shape[0]

        b_size = real_image.size(0)
        coords = coord_base.repeat(b_size, 1)
        # Interleave indices so the 4 patches of each sample are adjacent.
        # (The comprehension's `i` is scoped to the comprehension and does not
        # clobber the loop counter.)
        select = np.hstack([[i * b_size + j for i in range(4)]
                            for j in range(b_size)])
        real_image = real_image.cuda()

        if args.loss == 'wgan-gp':
            real_predict = discriminator(real_image, step=step, alpha=alpha)
            # Small drift penalty (0.001 * E[D^2]) keeps scores bounded.
            real_predict = real_predict.mean() - 0.001 * (real_predict**2).mean()
            (-real_predict).backward()

        elif args.loss == 'r1':
            real_image.requires_grad = True
            real_scores = discriminator(real_image, step=step, alpha=alpha)
            real_predict = F.softplus(-real_scores).mean()
            real_predict.backward(retain_graph=True)

            # R1 gradient penalty on real images.
            grad_real = grad(outputs=real_scores.sum(),
                             inputs=real_image,
                             create_graph=True)[0]
            grad_penalty = (grad_real.view(grad_real.size(0),
                                           -1).norm(2, dim=1)**2).mean()
            grad_penalty = 10 / 2 * grad_penalty
            grad_penalty.backward()
            if i % 10 == 0:
                grad_loss_val = grad_penalty.item()

        # Style mixing: with p=0.9 feed two latent lists; the last two latent
        # dimensions are replaced by the patch coordinates.
        if args.mixing and random.random() < 0.9:
            gen_in11, gen_in12, gen_in21, gen_in22 = torch.randn(
                4, b_size, code_size - 2, device='cuda').chunk(4, 0)
            gen_in11 = gen_in11.squeeze(0)
            gen_in11 = torch.cat([gen_in11.repeat(4, 1)[select], coords], dim=1)
            gen_in12 = gen_in12.squeeze(0)
            gen_in12 = torch.cat([gen_in12.repeat(4, 1)[select], coords], dim=1)
            gen_in21 = gen_in21.squeeze(0)
            gen_in21 = torch.cat([gen_in21.repeat(4, 1)[select], coords], dim=1)
            gen_in22 = gen_in22.squeeze(0)
            gen_in22 = torch.cat([gen_in22.repeat(4, 1)[select], coords], dim=1)

            gen_in1 = [gen_in11, gen_in12]
            gen_in2 = [gen_in21, gen_in22]
        else:
            gen_in1, gen_in2 = torch.randn(2, b_size, code_size - 2,
                                           device='cuda').chunk(
                2, 0  # 512
            )
            gen_in1 = gen_in1.squeeze(0)  # (B, 254)
            gen_in2 = gen_in2.squeeze(0)  # (B, 254)
            # repeat and copy
            gen_in1 = torch.cat([gen_in1.repeat(4, 1)[select], coords], dim=1)
            gen_in2 = torch.cat([gen_in2.repeat(4, 1)[select], coords], dim=1)

        # Generate 4 micro patches per sample and stitch them 2x2 into one image.
        fake_image = generator(gen_in1, step=step - 1, alpha=alpha)
        fake_image_up = torch.cat([fake_image[0::4], fake_image[1::4]], dim=3)
        fake_image_dn = torch.cat([fake_image[2::4], fake_image[3::4]], dim=3)
        fake_image = torch.cat([fake_image_up, fake_image_dn], dim=2)
        fake_predict = discriminator(fake_image, step=step, alpha=alpha)

        if args.loss == 'wgan-gp':
            fake_predict = fake_predict.mean()
            fake_predict.backward()

            # WGAN-GP: penalty on random interpolates between real and fake.
            eps = torch.rand(b_size, 1, 1, 1).cuda()
            x_hat = eps * real_image.data + (1 - eps) * fake_image.data
            x_hat.requires_grad = True
            hat_predict = discriminator(x_hat, step=step, alpha=alpha)
            grad_x_hat = grad(outputs=hat_predict.sum(),
                              inputs=x_hat,
                              create_graph=True)[0]
            grad_penalty = (
                (grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) -
                 1)**2).mean()
            grad_penalty = 10 * grad_penalty
            grad_penalty.backward()
            if i % 10 == 0:
                grad_loss_val = grad_penalty.item()
                disc_loss_val = (real_predict - fake_predict).item()

        elif args.loss == 'r1':
            fake_predict = F.softplus(fake_predict).mean()
            fake_predict.backward()
            if i % 10 == 0:
                disc_loss_val = (real_predict + fake_predict).item()

        d_optimizer.step()

        #### update generator
        if (i + 1) % n_critic == 0:
            generator.zero_grad()

            requires_grad(generator, True)
            requires_grad(discriminator, False)

            fake_image = generator(gen_in2, step=step - 1, alpha=alpha)
            fake_image_up = torch.cat([fake_image[0::4], fake_image[1::4]], dim=3)
            fake_image_dn = torch.cat([fake_image[2::4], fake_image[3::4]], dim=3)
            fake_image = torch.cat([fake_image_up, fake_image_dn], dim=2)
            predict = discriminator(fake_image, step=step, alpha=alpha)

            if args.loss == 'wgan-gp':
                loss = -predict.mean()
            elif args.loss == 'r1':
                loss = F.softplus(-predict).mean()
            if i % 10 == 0:
                gen_loss_val = loss.item()

            loss.backward()
            g_optimizer.step()
            accumulate(g_running, generator.module)

            requires_grad(generator, False)
            requires_grad(discriminator, True)

        #### validation
        if (i + 1) % 100 == 0:
            images = []

            gen_i, gen_j = args.gen_sample.get(resolution, (10, 5))

            coords = coord_base.repeat(gen_j, 1)
            select = np.hstack([[i * gen_j + j for i in range(4)]
                                for j in range(gen_j)])

            with torch.no_grad():
                for ii in range(gen_i):
                    style = torch.randn(gen_j, code_size - 2).cuda().repeat(4, 1)[select]
                    style = torch.cat([style, coords], dim=1)
                    image = g_running(style, step=step - 1, alpha=alpha).data.cpu()
                    image_up = torch.cat([image[0::4], image[1::4]], dim=3)
                    image_dn = torch.cat([image[2::4], image[3::4]], dim=3)
                    image = torch.cat([image_up, image_dn], dim=2)
                    images.append(image)

            utils.save_image(
                torch.cat(images, 0),
                r'sample/%06d.png' % (i + 1),
                nrow=gen_i,
                normalize=True,
                range=(-1, 1),
            )

        if (i + 1) % 10000 == 0:
            torch.save(g_running.state_dict(),
                       r'checkpoint/%06d.model' % (i + 1))

        state_msg = (
            r'Size: {}; G: {:.3f}; D: {:.3f}; Grad: {:.3f}; Alpha: {:.5f}'.
            format(4 * 2**step, gen_loss_val, disc_loss_val, grad_loss_val,
                   alpha))

        pbar.set_description(state_msg)
def train(args, generator, discriminator):
    """Training loop for a coordinate-conditioned patch GAN (COCO-GAN style).

    Relies on module-level globals: transform, g_optimizer, d_optimizer,
    g_running, coord_handler, patch_handler, spatial_predictor, criterion_mse,
    coord_loss_w, num_micro_in_macro, code_size, n_critic, config, step_D,
    step_G — TODO confirm they are defined at import time.
    """
    step = int(math.log2(args.max_size)) - 2  #-> 1
    resolution = 4 * 2 ** step
    batch_size = args.batch.get(resolution, args.batch_default)
    dataset = MultiResolutionDataset(args.path, transform, resolution=resolution)
    loader = sample_data(
        dataset, batch_size, resolution
    )
    data_loader = iter(loader)

    adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
    adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))

    pbar = tqdm(range(3000000))

    requires_grad(generator, False)
    requires_grad(discriminator, True)

    disc_loss_val = 0
    gen_loss_val = 0
    grad_loss_val = 0

    alpha = 0
    used_sample = 0  #-> how many images has been used
    max_step = int(math.log2(args.max_size)) - 2  #-> log2(1024) - 2 = 8
    final_progress = False

    for i in pbar:
        discriminator.zero_grad()

        alpha = min(1, 1 / args.phase * (used_sample + 1))
        #-> min(1, (cur+1)/60_0000)
        #-> when more than 60_0000 sampels is used, alpha will be in const to 1.0
        #-> which means we the "skip_rgb" will not be applied

        if (resolution == args.init_size and args.ckpt is None) or final_progress:
            alpha = 1  #-> also, if initially, no previous outputs for skip-connection

        # Resolution progression: grow one step, rebuild loader, checkpoint.
        if used_sample > args.phase * 2:  #-> if > 1_200_000
            ## num_of_epoch_each_phase = args.phase * 2 / training_dataset_size
            used_sample = 0
            step += 1

            if step > max_step:
                step = max_step
                final_progress = True
                ckpt_step = step + 1
            else:
                alpha = 0
                ckpt_step = step

            # NOTE(review): uses the global step_D here while the progression
            # above advances the local `step` — confirm this is intended and
            # not a typo for `step`.
            resolution = 4 * 2 ** step_D

            loader = sample_data(
                dataset, args.batch.get(resolution, args.batch_default), resolution
            )
            data_loader = iter(loader)

            torch.save(
                {
                    'generator': generator.module.state_dict(),
                    'discriminator': discriminator.module.state_dict(),
                    'g_optimizer': g_optimizer.state_dict(),
                    'd_optimizer': d_optimizer.state_dict(),
                    'g_running': g_running.state_dict(),
                },
                r'checkpoint_coco/train_step-{}.model'.format(ckpt_step))

            adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
            adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))

        #### update discriminator
        try:
            real_image = next(data_loader)
        except (OSError, StopIteration):
            # Exhausted (or broken) iterator: restart it.
            data_loader = iter(loader)
            real_image = next(data_loader)

        used_sample += real_image.shape[0]

        real_image = real_image.cuda()
        b_size = real_image.size(0)
        # Interleave indices so all micro patches of one sample are adjacent.
        select = np.hstack([[i*b_size+j for i in range(num_micro_in_macro)]
                            for j in range(b_size)])

        # get sample coords
        coord_handler.batch_size = b_size
        patch_handler.batch_size = b_size
        d_macro_coord_real, g_micro_coord_real, _ = coord_handler._euclidean_sample_coord()
        d_macro_coord_fake1, g_micro_coord_fake1, _ = coord_handler._euclidean_sample_coord()
        d_macro_coord_fake2, g_micro_coord_fake2, _ = coord_handler._euclidean_sample_coord()
        d_macro_coord_real = torch.from_numpy(d_macro_coord_real).float().cuda()
        d_macro_coord_fake1, g_micro_coord_fake1 = torch.from_numpy(d_macro_coord_fake1).float().cuda(), torch.from_numpy(g_micro_coord_fake1).float().cuda()
        d_macro_coord_fake2, g_micro_coord_fake2 = torch.from_numpy(d_macro_coord_fake2).float().cuda(), torch.from_numpy(g_micro_coord_fake2).float().cuda()

        # Crop micro patches at the sampled coordinates and assemble the macro patch.
        real_macro = micros_to_macro(
            patch_handler.crop_micro_from_full_gpu(
                real_image, g_micro_coord_real[:, 1:2], g_micro_coord_real[:, 0:1]),
            config["data_params"]["ratio_macro_to_micro"])

        if args.loss == 'wgan-gp':
            real_predict, real_H = discriminator(real_macro, d_macro_coord_real, step=step_D, alpha=alpha)
            # Small drift penalty (0.001 * E[D^2]) keeps scores bounded.
            real_predict = real_predict.mean() - 0.001 * (real_predict ** 2).mean()
            # Auxiliary spatial loss: predict the macro coordinate from D's features.
            sp_loss_real = criterion_mse(spatial_predictor(real_H), d_macro_coord_real) * coord_loss_w
            (-real_predict+sp_loss_real).backward()

        elif args.loss == 'r1':
            real_macro.requires_grad = True
            real_scores, real_H = discriminator(real_macro, d_macro_coord_real, step=step_D, alpha=alpha)
            real_predict = F.softplus(-real_scores).mean()
            sp_loss_real = criterion_mse(spatial_predictor(real_H), d_macro_coord_real) * coord_loss_w
            (real_predict+sp_loss_real).backward(retain_graph=True)

            # R1 gradient penalty on the real macro patch.
            grad_real = grad(
                outputs=real_scores.sum(), inputs=real_macro, create_graph=True
            )[0]
            grad_penalty = (
                grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
            ).mean()
            grad_penalty = 10 / 2 * grad_penalty
            grad_penalty.backward()
            if i%10 == 0:
                grad_loss_val = grad_penalty.item()

        # Style mixing with p=0.9; the last two latent dims carry coordinates.
        if args.mixing and random.random() < 0.9:
            gen_in11, gen_in12, gen_in21, gen_in22 = torch.randn(
                4, b_size, code_size-2, device='cuda'
            ).chunk(4, 0)
            gen_in11 = gen_in11.squeeze(0)
            gen_in11 = torch.cat([gen_in11.repeat(num_micro_in_macro, 1)[select], g_micro_coord_fake1], dim=1)
            gen_in12 = gen_in12.squeeze(0)
            gen_in12 = torch.cat([gen_in12.repeat(num_micro_in_macro, 1)[select], g_micro_coord_fake1], dim=1)
            gen_in21 = gen_in21.squeeze(0)
            gen_in21 = torch.cat([gen_in21.repeat(num_micro_in_macro, 1)[select], g_micro_coord_fake2], dim=1)
            gen_in22 = gen_in22.squeeze(0)
            gen_in22 = torch.cat([gen_in22.repeat(num_micro_in_macro, 1)[select], g_micro_coord_fake2], dim=1)

            gen_in1 = [gen_in11, gen_in12]
            gen_in2 = [gen_in21, gen_in22]
            #print(gen_in11[:16])
        else:
            gen_in1, gen_in2 = torch.randn(2, b_size, code_size-2, device='cuda').chunk(
                2, 0  # 512
            )
            gen_in1 = gen_in1.squeeze(0)  # (B, 254)
            gen_in2 = gen_in2.squeeze(0)  # (B, 254)
            # repeat and copy
            gen_in1 = torch.cat([gen_in1.repeat(num_micro_in_macro, 1)[select], g_micro_coord_fake1], dim=1)
            gen_in2 = torch.cat([gen_in2.repeat(num_micro_in_macro, 1)[select], g_micro_coord_fake2], dim=1)

        fake_image = generator(gen_in1, step=step_G, alpha=alpha)
        fake_image = micros_to_macro(fake_image, config["data_params"]["ratio_macro_to_micro"])
        fake_predict, fake_H = discriminator(fake_image, d_macro_coord_fake1, step=step_D, alpha=alpha)
        sp_loss_fake = criterion_mse(spatial_predictor(fake_H), d_macro_coord_fake1) * coord_loss_w

        if args.loss == 'wgan-gp':
            fake_predict = fake_predict.mean()
            (fake_predict+sp_loss_fake).backward()

            # WGAN-GP penalty on interpolates between real and fake.
            eps = torch.rand(b_size, 1, 1, 1).cuda()
            x_hat = eps * real_image.data + (1 - eps) * fake_image.data
            x_hat.requires_grad = True
            hat_predict = discriminator(x_hat, step=step_D, alpha=alpha)
            grad_x_hat = grad(
                outputs=hat_predict.sum(), inputs=x_hat, create_graph=True
            )[0]
            grad_penalty = (
                (grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) - 1) ** 2
            ).mean()
            grad_penalty = 10 * grad_penalty
            grad_penalty.backward()
            if i%10 == 0:
                grad_loss_val = grad_penalty.item()
                disc_loss_val = (real_predict - fake_predict).item()

        elif args.loss == 'r1':
            fake_predict = F.softplus(fake_predict).mean()
            (fake_predict+sp_loss_fake).backward()
            if i%10 == 0:
                disc_loss_val = (real_predict + fake_predict).item()

        d_optimizer.step()
        if i%10 == 0:
            spatial_loss_D_val = (sp_loss_real.item() + sp_loss_fake.item()) / 2

        #### update generator
        if (i + 1) % n_critic == 0:
            generator.zero_grad()

            requires_grad(generator, True)
            requires_grad(discriminator, False)

            fake_image = generator(gen_in2, step=step_G, alpha=alpha)
            fake_image = micros_to_macro(fake_image, config["data_params"]["ratio_macro_to_micro"])
            predict, H = discriminator(fake_image, d_macro_coord_fake2, step=step_D, alpha=alpha)
            spatial_loss = criterion_mse(spatial_predictor(H), d_macro_coord_fake2) * coord_loss_w

            if args.loss == 'wgan-gp':
                loss = -predict.mean()
            elif args.loss == 'r1':
                loss = F.softplus(-predict).mean()
            if i%10 == 0:
                gen_loss_val = loss.item()
                spatial_loss_G_val = spatial_loss.item()

            (loss+spatial_loss).backward()
            g_optimizer.step()
            accumulate(g_running, generator.module)

            requires_grad(generator, False)
            requires_grad(discriminator, True)

        #### validation
        if (i + 1) % 100 == 0:
            images = []

            gen_i, gen_j = args.gen_sample.get(resolution, (10, 5))

            coord_handler.batch_size = gen_i * gen_j
            _, g_micro_coord_val, _ = coord_handler._euclidean_sample_coord()
            g_micro_coord_val = torch.from_numpy(g_micro_coord_val).float().cuda()
            #print(g_micro_coord_val.shape)
            select = np.hstack([[i*gen_j+j for i in range(num_micro_in_macro)]
                                for j in range(gen_j)])

            with torch.no_grad():
                for ii in range(gen_i):
                    style = torch.randn(gen_j, code_size-2).cuda().repeat(num_micro_in_macro, 1)[select]
                    #print(style.size())
                    coords = g_micro_coord_val[ii*gen_j*num_micro_in_macro:(ii+1)*gen_j*num_micro_in_macro]
                    #print(coords.size())
                    style = torch.cat([style, coords], dim=1)
                    image = g_running(style, step=step_G, alpha=alpha).data.cpu()
                    image = micros_to_macro(image, config['data_params']['ratio_macro_to_micro'])
                    images.append(
                        image
                    )

            utils.save_image(
                torch.cat(images, 0),
                r'sample_coco/%06d.png'%(i+1),
                nrow=gen_i,
                normalize=True,
                range=(-1, 1),
            )

        if (i + 1) % 10000 == 0:
            torch.save(
                g_running.state_dict(), r'checkpoint_coco/%06d.model'%(i+1)
            )

        state_msg = (
            r'Size: {}; G: {:.3f}; D: {:.3f}; Grad: {:.3f}; sp_G: {:.3f}; sp_D: {:.3f}; Alpha: {:.5f}'.format(
                4 * 2 ** step, gen_loss_val, disc_loss_val, grad_loss_val,
                spatial_loss_G_val, spatial_loss_D_val, alpha)
        )

        pbar.set_description(state_msg)
def main(args, myargs):
    """Build the StyleGAN generator/discriminator and optimizers, optionally
    resume from a checkpoint, configure the progressive-growing schedules on
    ``args``, and hand everything to ``train``."""
    code_size = 512
    batch_size = 16
    n_critic = 1

    # Trainable networks (DataParallel) plus a non-training EMA generator copy.
    net_g = nn.DataParallel(StyledGenerator(code_size)).cuda()
    net_d = nn.DataParallel(
        Discriminator(from_rgb_activate=not args.no_from_rgb_activate)).cuda()
    ema_g = StyledGenerator(code_size).cuda()
    ema_g.train(False)

    opt_g = optim.Adam(net_g.module.generator.parameters(),
                       lr=args.lr, betas=(0.0, 0.99))
    # The style mapping network trains at 1/100th of the base learning rate.
    opt_g.add_param_group({
        'params': net_g.module.style.parameters(),
        'lr': args.lr * 0.01,
        'mult': 0.01,
    })
    opt_d = optim.Adam(net_d.parameters(), lr=args.lr, betas=(0.0, 0.99))

    # Decay 0 copies the generator weights into the EMA model verbatim.
    accumulate(ema_g, net_g.module, 0)

    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt)
        net_g.module.load_state_dict(ckpt['generator'])
        net_d.module.load_state_dict(ckpt['discriminator'])
        ema_g.load_state_dict(ckpt['g_running'])
        opt_g.load_state_dict(ckpt['g_optimizer'])
        opt_d.load_state_dict(ckpt['d_optimizer'])

    # Flip augmentation + [-1, 1] scaling.
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])
    dataset = MultiResolutionDataset(args.path, transform)

    # Per-resolution LR/batch schedules for progressive growing; when disabled,
    # empty dicts are passed through on args.
    if args.sched:
        args.lr = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        args.batch = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 32, 256: 32}
    else:
        args.lr = {}
        args.batch = {}
    args.gen_sample = {512: (8, 4), 1024: (4, 2)}
    args.batch_default = 32

    train(args, dataset, net_g, net_d,
          g_optimizer=opt_g, d_optimizer=opt_d, g_running=ema_g,
          code_size=code_size, n_critic=n_critic, myargs=myargs)
parser.add_argument('path', metavar='PATH', help='path to datset lmdb file')

args = parser.parse_args()

# Inception v3 patched for FID feature extraction; eval mode, no training.
inception = load_patched_inception_v3()
inception = nn.DataParallel(inception).eval().to(device)

transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

dset = MultiResolutionDataset(config.data_params.lmdb_path,
                              transform=transform,
                              resolution=config.train_params.full_size)
loader = DataLoader(dset, batch_size=config.train_params.batch_size, num_workers=4)

features = extract_features(loader, inception, device).numpy()
# BUG FIX: was `params.test_params.n_fid_sample` — `params` is never defined in
# this script (NameError at runtime); every other option here is read from
# `config`, so the sample cap must come from `config` as well.
features = features[: config.test_params.n_fid_sample]

print(f'extracted {features.shape[0]} features')

# Reference statistics for FID: feature mean and covariance over the dataset.
mean = np.mean(features, 0)
cov = np.cov(features, rowvar=False)

name = os.path.splitext(os.path.basename(config.data_params.lmdb_path))[0]

with open(f'inception_{name}.pkl', 'wb') as f:
    pickle.dump({'mean': mean, 'cov': cov,
                 'size': config.train_params.full_size,
                 'path': config.data_params.lmdb_path}, f)
# Restore optimizer state from the checkpoint loaded earlier in this function.
t_optimizer.load_state_dict(ckpt['t_optimizer'])
g_optimizer.load_state_dict(ckpt['g_optimizer'])
d_optimizer.load_state_dict(ckpt['d_optimizer'])

# Flip augmentation + [-1, 1] scaling (Normalize with mean/std 0.5).
transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ]
)

# Ensure the checkpoint output directory exists before training starts.
if not os.path.exists(os.path.join(args.out, 'checkpoint')):
    os.makedirs(os.path.join(args.out, 'checkpoint'))

# NOTE(review): max_length=24 presumably caps the number of items per sample —
# confirm against this MultiResolutionDataset variant's signature.
dataset = MultiResolutionDataset(args.path, transform, max_length=24)

# Inception Score evaluator; splits=1 gives a single score over all samples.
inception_score = Inception_score(resize=True, splits=1)

# Per-resolution learning-rate/batch-size schedules for progressive growing.
# With --sched off, the dicts are left empty (presumably the training loop
# falls back to args.lr / args.batch_default — confirm in train()).
if args.sched:
    args.lr = {4: 1e-3, 8: 1e-3, 16: 5e-4, 32: 1e-4, 64: 1e-4, 128: 1e-4, 256: 1e-4}
    args.batch = {4: 64, 8: 64, 16: 64, 32: 32, 64: 32, 128: 16, 256: 16}
else:
    args.lr = {}
    args.batch = {}

# Grid layout (rows, cols) for periodic sample sheets at high resolutions.
args.gen_sample = {512: (8, 4), 1024: (4, 2)}

args.batch_default = 32
def train(latent_dim, num_repeats, learning_rate, lambda_vgg, lambda_mse):
    """Train an InceptionVAE on a fixed local dataset, logging to wandb.

    Args:
        latent_dim: dimensionality of the VAE latent space.
        num_repeats: residual-block repeat count passed to InceptionVAE.
        learning_rate: Adam learning rate.
        lambda_vgg: weight of the VGG perceptual loss term.
        lambda_mse: weight of the RMSE pixel loss term.

    Returns early if the total loss goes NaN/Inf, after logging a sentinel
    value so the sweep can rank the failed run.
    """
    print(
        f"latent_dim={latent_dim:.4f}",
        f"num_repeats={num_repeats:.4f}",
        f"learning_rate={learning_rate:.4f}",
        f"lambda_vgg={lambda_vgg:.4f}",
        f"lambda_mse={lambda_mse:.4f}",
    )

    # Images stay in [0, 1]: BCE reconstruction requires targets in [0, 1],
    # so no Normalize to [-1, 1] here.
    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
    ])
    batch_size = 72
    data_path = "/home/hans/trainsets/cyphis"
    name = os.path.splitext(os.path.basename(data_path))[0]
    dataset = MultiResolutionDataset(data_path, transform, 256)
    dataloader = data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=data.RandomSampler(dataset),
        num_workers=12,
        drop_last=True,
    )
    loader = sample_data(dataloader)

    # Fixed batch of real images, logged once and reused for reconstructions.
    sample_imgs = next(loader)[:24]
    wandb.log({
        "Real Images": [
            wandb.Image(
                utils.make_grid(sample_imgs, nrow=6, normalize=True, range=(0, 1)))
        ]
    })

    vae = InceptionVAE(latent_dim=latent_dim, repeat_per_block=num_repeats).to(device)
    vae_optim = th.optim.Adam(vae.parameters(), lr=learning_rate)
    vgg = VGGLoss()

    num_iters = 100_000
    pbar = tqdm(range(num_iters), smoothing=0.1)
    for i in pbar:
        vae.train()
        real = next(loader).to(device)
        fake, mu, log_var = vae(real)

        # FIX: `size_average=False` is deprecated since PyTorch 0.4; the
        # modern equivalent of a summed loss is reduction="sum".
        bce = F.binary_cross_entropy(fake, real, reduction="sum")
        # Standard VAE KL divergence against a unit Gaussian prior.
        kld = -0.5 * th.sum(1 + log_var - mu.pow(2) - log_var.exp())
        vgg_loss = vgg(fake, real)
        # NOTE: this is actually RMSE (sqrt of the mean squared error); the
        # "MSE" log key is kept for continuity of existing wandb charts.
        mse_loss = th.sqrt((fake - real).pow(2).mean())
        loss = bce + kld + lambda_vgg * vgg_loss + lambda_mse * mse_loss
        loss_dict = {
            "Total": loss,
            "BCE": bce,
            "Kullback Leibler Divergence": kld,
            "MSE": mse_loss,
            "VGG": vgg_loss,
        }

        vae.zero_grad()
        loss.backward()
        vae_optim.step()
        wandb.log(loss_dict)

        # Periodic evaluation: reconstructions, prior samples, checkpoint.
        with th.no_grad():
            if i % int(num_iters / 100) == 0 or i + 1 == num_iters:
                vae.eval()
                sample, _, _ = vae(sample_imgs.to(device))
                grid = utils.make_grid(sample, nrow=6, normalize=True, range=(0, 1))
                del sample
                wandb.log({
                    "Reconstructed Images VAE": [wandb.Image(grid, caption=f"Step {i}")]
                })
                sample = vae.sampling()
                grid = utils.make_grid(sample, nrow=6, normalize=True, range=(0, 1))
                del sample
                wandb.log({
                    "Generated Images VAE": [wandb.Image(grid, caption=f"Step {i}")]
                })
                gc.collect()
                th.cuda.empty_cache()
                th.save(
                    {
                        "vae": vae.state_dict(),
                        "vae_optim": vae_optim.state_dict()
                    },
                    f"/home/hans/modelzoo/maua-sg2/vae-{name}-{wandb.run.dir.split('/')[-1].split('-')[-1]}.pt",
                )

        # Abort on numerical blow-up; log a large sentinel "Total" so the
        # hyperparameter sweep still receives a (bad) objective value.
        if th.isnan(loss).any() or th.isinf(loss).any():
            print("NaN losses, exiting...")
            print({
                "Total": loss,
                "\nBCE": bce,
                "\nKullback Leibler Divergence": kld,
                "\nMSE": mse_loss,
                "\nVGG": vgg_loss,
            })
            wandb.log({"Total": 27000})
            return
# (continuation of a parser.add_argument call started in the previous chunk)
metavar='PATH', help='path to datset lmdb file')

args = parser.parse_args()

# Inception v3 patched for FID feature extraction; eval mode, no training.
inception = load_patched_inception_v3()
inception = nn.DataParallel(inception).eval().to(device)

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

dset = MultiResolutionDataset(args.path, transform=transform, resolution=args.size)
loader = DataLoader(dset, batch_size=args.batch, num_workers=4)

# Cap the number of features used for the FID reference statistics.
features = extract_features(loader, inception, device).numpy()
features = features[:args.n_sample]

print(f'extracted {features.shape[0]} features')

# Reference statistics for FID: feature mean and covariance.
mean = np.mean(features, 0)
cov = np.cov(features, rowvar=False)

name = os.path.splitext(os.path.basename(args.path))[0]

# (the pickle.dump body of this `with` continues beyond this chunk)
with open(f'inception_{name}.pkl', 'wb') as f:
args = parser.parse_args()

# Fix all RNG seeds so the activation dump is reproducible.
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

inception = InceptionV3().cuda()

# Flip augmentation + [-1, 1] scaling.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
])
dataset = MultiResolutionDataset(f'./dataset/{args.dataset}_lmdb', transform)
loader = sample_data(dataset, args.batch_size, args.image_size)

pbar = tqdm(total=len(dataset))
acts = []
# Collect inception activations for every real image in the dataset.
for real_index, real_image in loader:
    real_image = real_image.cuda()
    with torch.no_grad():
        out = inception(real_image)
        # Drop the trailing 1x1 spatial dims of the pooled feature map.
        out = out[0].squeeze(-1).squeeze(-1)
    acts.append(out.cpu().numpy())
    pbar.update(len(real_image))

acts = np.concatenate(acts, axis=0)

# (the pickle dump body of this `with` continues beyond this chunk)
with open(f'dataset/{args.dataset}_acts.pickle', 'wb') as handle:
# Wrap the discriminator for multi-GPU distributed training.
discriminator = nn.parallel.DistributedDataParallel(
    discriminator,
    device_ids=[args.local_rank],
    output_device=args.local_rank,
    broadcast_buffers=False,
    # Needed because not all discriminator parameters receive gradients
    # on every step (e.g. during lazy regularization phases).
    find_unused_parameters=True,
)

# Flip augmentation + [-1, 1] scaling.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# NOTE(review): `transform` is built above but never passed to the dataset —
# either this MultiResolutionDataset variant applies its own transform
# internally or the augmentation is silently dropped; confirm.
dataset = MultiResolutionDataset(args.path)
loader = data.DataLoader(
    dataset,
    batch_size=args.batch,
    sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
    drop_last=True,
)

# Only the rank-0 process reports to wandb.
if get_rank() == 0 and wandb is not None and args.wandb:
    wandb.init(project='stylegan 2')

train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)
# (continuation of a DistributedDataParallel call started in the previous chunk)
broadcast_buffers=False,
)

# A second DDP-wrapped discriminator used for discriminator rejection sampling.
drs_discriminator = nn.parallel.DistributedDataParallel(
    drs_discriminator,
    device_ids=[args.local_rank],
    output_device=args.local_rank,
    broadcast_buffers=False,
)

# Flip augmentation + [-1, 1] scaling.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
])
dataset = MultiResolutionDataset(args.root, transform, args.size)

# Per-sample discriminator logits recorded by the baseline phase-1 run.
logit_path = f'./exp_results/{args.baseline_exp_name}/logits_netD.pkl'
print(f'Use logit from: {logit_path}')
# SECURITY NOTE: pickle.load executes arbitrary code on load — only safe
# because the experiment results directory is produced locally/trusted.
logits = pickle.load(open(logit_path, "rb"))

# Score samples over the last `window` steps leading up to (and including)
# the phase-1 cutoff step.
window = 5000
score_start_step = (args.p1_step - window)
score_end_step = args.p1_step + 1
score_dict = calculate_scores(logits, start_epoch=score_start_step, end_epoch=score_end_step)

# Resampling weights selected by the configured scoring rule.
sample_weights = score_dict[args.resample_score]

# (body of this helper continues beyond this chunk)
def print_stats(sw):
### prepare experiments ### random.seed(args.seed) torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) ### load dataset ### transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), ]) dataset = MultiResolutionDataset(f'./dataset/{args.dataset}_lmdb', transform, resolution=args.image_size) ### load G and D ### if args.supervised: G_target = nn.DataParallel( StyledGenerator(code_size, dataset_size=len(dataset), embed_dim=code_size)).cuda() G_running_target = StyledGenerator(code_size, dataset_size=len(dataset), embed_dim=code_size).cuda() G_running_target.train(False) accumulate(G_running_target, G_target.module, 0) else:
### load G and D ### gen1, dis1 = load_network(f'./checkpoint/{args.ckpt1}') gen2, dis2 = load_network(f'./checkpoint/{args.ckpt2}') gen3, dis3 = load_network(f'./checkpoint/{args.ckpt3}') ### load dataset ### transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), ]) data1 = MultiResolutionDataset(f'./dataset/{args.data1}_lmdb', transform, resolution=args.image_size) data2 = MultiResolutionDataset(f'./dataset/{args.data2}_lmdb', transform, resolution=args.image_size) data3 = MultiResolutionDataset(f'./dataset/{args.data3}_lmdb', transform, resolution=args.image_size) step = int(math.log2(args.image_size)) - 2 resolution = 4 * 2 ** step batch_size = 10 ### run experiment ### # acc11, threshold11 = test(dis1, data1, gen1) acc11, threshold11 = 77.15, 0.5685 acc12, threshold12 = test(dis1, data2, gen2) acc13, threshold13 = test(dis1, data3, gen3) acc21, threshold21 = test(dis2, data1, gen1) acc22, threshold22 = test(dis2, data2, gen2)
# (continuation of an optim.Adam call started in the previous chunk; the
# lr/beta scaling by g_reg_ratio implements lazy-regularization correction)
generator.proj.parameters(),
lr=args.lr * g_reg_ratio,
betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
)
d_optim = optim.Adam(
    discriminator.parameters(),
    # NOTE(review): the extra "/ 2" halves the usual lazy-regularization D
    # learning rate relative to the generator — confirm this is intentional.
    lr=args.lr * d_reg_ratio / 2,
    betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
)

# Flip augmentation + [-1, 1] scaling.
transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ]
)

dataset = MultiResolutionDataset(args.path, transform, args.size)
loader = data.DataLoader(
    dataset,
    batch_size=args.batch,
    sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
    drop_last=True,
)

# Only the rank-0 process reports to wandb.
if get_rank() == 0 and wandb is not None and args.wandb:
    wandb.init(project='stylegan 2')

train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)
safe_load_state_dict(g_ema, ckpt["g_ema"]) safe_load_state_dict(g_optim, ckpt["g_optim"]) safe_load_state_dict(d_optim, ckpt["d_optim"]) else: print(" [*] Did not find ckpt, fresh start!") config.var.start_iter = 0 config.var.best_fid = 500 config.var.mean_path_lengths = None """ Dataset """ train_set = MultiResolutionDataset( split="train", config=config, is_training=True) valid_set = None #MultiResolutionDataset( # os.path.join(dataset_root, "valid"), # is_training=False, # config.train_params.full_size) train_set_fid = MultiResolutionDataset( split="train", config=config, is_training=False) loaders = { "train": make_nonstopping(data.DataLoader( train_set, batch_size=config.train_params.batch_size,