def sample(cfg, logger):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(0)  # keep the device index consistent with "cuda:0"
    model = SoftIntroVAEModelTL(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
        beta_kl=cfg.MODEL.BETA_KL,
        beta_rec=cfg.MODEL.BETA_REC,
        beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
        scale=cfg.MODEL.SCALE)
    model.to(device)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters decoder:")
    print(count_parameters(decoder))
    logger.info("Trainable parameters encoder:")
    print(count_parameters(encoder))

    # Load the smoothed (EMA) weights from the checkpoint into this model.
    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }
    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)
    checkpointer.load()
    model.eval()

    path = './make_figures/output'
    os.makedirs(os.path.join(path, cfg.NAME), exist_ok=True)

    with torch.no_grad():
        generate_samples(cfg, model, path, 5, device=device)
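All of the figure scripts here share one pixel convention: uint8 images in [0, 255] are mapped to [-1, 1] with x / 127.5 - 1 before entering the model, and model outputs are mapped back to [0, 1] with x * 0.5 + 0.5 before save_image. A minimal standalone check of that round trip (illustration only, not part of the repo code):

import torch

pixels = torch.tensor([0., 127., 255.])      # uint8 values as floats
normalized = pixels / 127.5 - 1.0            # model input range: [-1, 1]
restored = normalized * 0.5 + 0.5            # save_image range: [0, 1]
# (p / 127.5 - 1) * 0.5 + 0.5 == p / 255 exactly
assert torch.allclose(restored, pixels / 255.0, atol=1e-6)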
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = SoftIntroVAEModelTL(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
        beta_kl=cfg.MODEL.BETA_KL,
        beta_rec=cfg.MODEL.BETA_REC,
        beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
        scale=cfg.MODEL.SCALE)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters decoder:")
    print(count_parameters(decoder))
    logger.info("Trainable parameters encoder:")
    print(count_parameters(encoder))

    # Load the smoothed (EMA) weights from the checkpoint into this model.
    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }
    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)
    checkpointer.load()
    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        # Encode at the final LOD (blend factor 1) and map mu to style vectors.
        z, mu, _ = model.encode(x, layer_count - 1, 1)
        styles = model.mapping_fl(mu)
        return styles

    def decode(x):
        return model.decoder(x, layer_count - 1, 1, noise=True)

    path = cfg.DATASET.SAMPLES_PATH
    im_size = 2**(cfg.MODEL.LAYER_COUNT + 1)

    paths = sorted(os.listdir(path))
    random.seed(1)
    random.shuffle(paths)

    def make(paths):
        canvas = []
        with torch.no_grad():
            for filename in paths:
                img = np.asarray(Image.open(os.path.join(path, filename)))
                if img.shape[2] == 4:
                    img = img[:, :, :3]  # drop the alpha channel
                im = img.transpose((2, 0, 1))
                x = torch.tensor(np.asarray(im, dtype=np.float32)).cuda() / 127.5 - 1.
                # Downscale to the model resolution if needed.
                factor = x.shape[2] // im_size
                if factor != 1:
                    x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0]
                assert x.shape[2] == im_size
                latents = encode(x[None, ...])
                f = decode(latents)
                # Input and reconstruction side by side along the width axis.
                r = torch.cat([x[None, ...].detach().cpu(), f.detach().cpu()], dim=3)
                canvas.append(r)
        return canvas

    def chunker_list(seq, n):
        # Consecutive chunks of n items; the last chunk may be shorter.
        return [seq[i * n:(i + 1) * n] for i in range((len(seq) + n - 1) // n)]

    paths = chunker_list(paths, 8 * 3)

    path = './make_figures/output'
    os.makedirs(os.path.join(path, cfg.NAME), exist_ok=True)

    for i, chunk in enumerate(paths):
        canvas = torch.cat(make(chunk), dim=0)
        save_path = './make_figures/output/%s/reconstructions_%d.png' % (cfg.NAME, i)
        save_image(canvas * 0.5 + 0.5, save_path, nrow=3, pad_value=1.0)
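chunker_list above splits the shuffled file list into consecutive chunks of n items (the last chunk may be shorter); each chunk becomes one reconstruction grid of 8 rows by 3 input/reconstruction pairs. A quick standalone check of the slicing:

def chunker_list(seq, n):
    return [seq[i * n:(i + 1) * n] for i in range((len(seq) + n - 1) // n)]

print(chunker_list(list(range(7)), 3))  # [[0, 1, 2], [3, 4, 5], [6]]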
def train(cfg, logger, local_rank, world_size, distributed):
    torch.cuda.set_device(local_rank)
    model = SoftIntroVAEModelTL(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
        beta_kl=cfg.MODEL.BETA_KL,
        beta_rec=cfg.MODEL.BETA_REC,
        beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
        scale=cfg.MODEL.SCALE)
    model.cuda(local_rank)
    model.train()

    if local_rank == 0:
        # Smoothed (EMA) copy of the model, used for sampling and FID.
        model_s = SoftIntroVAEModelTL(
            startf=cfg.MODEL.START_CHANNEL_COUNT,
            layer_count=cfg.MODEL.LAYER_COUNT,
            maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
            latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
            truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,        # key spelling follows the config definition
            truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
            mapping_layers=cfg.MODEL.MAPPING_LAYERS,
            channels=cfg.MODEL.CHANNELS,
            generator=cfg.MODEL.GENERATOR,
            encoder=cfg.MODEL.ENCODER,
            beta_kl=cfg.MODEL.BETA_KL,
            beta_rec=cfg.MODEL.BETA_REC,
            beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
            scale=cfg.MODEL.SCALE)
        model_s.cuda(local_rank)
        model_s.eval()
        model_s.requires_grad_(False)

    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False,
            bucket_cap_mb=25,
            find_unused_parameters=True)
        model.device_ids = None
        decoder = model.module.decoder
        encoder = model.module.encoder
        mapping_tl = model.module.mapping_tl
        mapping_fl = model.module.mapping_fl
        dlatent_avg = model.module.dlatent_avg
    else:
        decoder = model.decoder
        encoder = model.encoder
        mapping_tl = model.mapping_tl
        mapping_fl = model.mapping_fl
        dlatent_avg = model.dlatent_avg

    count_parameters.print = lambda a: logger.info(a)

    num_vae_epochs = cfg.TRAIN.NUM_VAE

    logger.info("Trainable parameters decoder:")
    print(count_parameters(decoder))
    logger.info("Trainable parameters encoder:")
    print(count_parameters(encoder))

    arguments = dict()
    arguments["iteration"] = 0

    # Separate optimizers so the adversarial phase can step one side at a time.
    decoder_optimizer = LREQAdam([
        {'params': decoder.parameters()},
        {'params': mapping_fl.parameters()}
    ], lr=cfg.TRAIN.BASE_LEARNING_RATE,
        betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1),
        weight_decay=0)

    encoder_optimizer = LREQAdam([
        {'params': encoder.parameters()},
        {'params': mapping_tl.parameters()}
    ], lr=cfg.TRAIN.BASE_LEARNING_RATE,
        betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1),
        weight_decay=0)

    scheduler = ComboMultiStepLR(
        optimizers={
            'encoder_optimizer': encoder_optimizer,
            'decoder_optimizer': decoder_optimizer
        },
        milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
        gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
        reference_batch_size=32,
        base_lr=cfg.TRAIN.LEARNING_RATES)

    model_dict = {
        'discriminator': encoder,
        'generator': decoder,
        'mapping_fl': mapping_fl,
        'mapping_tl': mapping_tl,
        'dlatent_avg': dlatent_avg
    }
    if local_rank == 0:
        model_dict['discriminator_s'] = model_s.encoder
        model_dict['generator_s'] = model_s.decoder
        model_dict['mapping_fl_s'] = model_s.mapping_fl
        model_dict['mapping_tl_s'] = model_s.mapping_tl

    tracker = LossTracker(cfg.OUTPUT_DIR)

    checkpointer = Checkpointer(
        cfg,
        model_dict,
        {
            'encoder_optimizer': encoder_optimizer,
            'decoder_optimizer': decoder_optimizer,
            'scheduler': scheduler,
            'tracker': tracker
        },
        logger=logger,
        save=local_rank == 0)

    extra_checkpoint_data = checkpointer.load()
    logger.info("Starting from epoch: %d" % (scheduler.start_epoch()))

    arguments.update(extra_checkpoint_data)

    layer_to_resolution = decoder.layer_to_resolution
    dataset = TFRecordsDataset(cfg, logger, rank=local_rank, world_size=world_size,
                               buffer_size_mb=1024, channels=cfg.MODEL.CHANNELS)

    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    samplez = torch.tensor(latents).float().cuda()

    lod2batch = lod_driver.LODDriver(cfg, logger, world_size,
                                     dataset_size=len(dataset) * world_size)

    # Fixed batch of real images, reused for the periodic sample figures.
    if cfg.DATASET.SAMPLES_PATH:
        path = cfg.DATASET.SAMPLES_PATH
        src = []
        with torch.no_grad():
            for filename in list(os.listdir(path))[:32]:
                img = np.asarray(Image.open(os.path.join(path, filename)))
                if img.shape[2] == 4:
                    img = img[:, :, :3]  # drop the alpha channel
                im = img.transpose((2, 0, 1))
                x = torch.tensor(np.asarray(im, dtype=np.float32)).cuda() / 127.5 - 1.
                src.append(x)
            sample = torch.stack(src)
    else:
        dataset.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, 32)
        sample = next(make_dataloader(cfg, logger, dataset, 32, local_rank))
        sample = (sample / 127.5 - 1.)

    lod2batch.set_epoch(scheduler.start_epoch(), [encoder_optimizer, decoder_optimizer])

    kls_real = []
    kls_fake = []
    rec_errs = []
    best_fid = None  # best_fid = 20.0

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        # beta_neg is scheduled per level of detail (LOD).
        new_beta_neg = cfg.MODEL.BETA_NEG[lod2batch.lod]
        cur_model = model.module if distributed else model
        if cur_model.beta_neg != new_beta_neg:
            cur_model.beta_neg = new_beta_neg
            print("beta negative changed to:", new_beta_neg)

        # Once the final resolution is reached, compute FID every 10 epochs
        # and checkpoint whenever it improves.
        if (epoch > cfg.TRAIN.EPOCHS_PER_LOD * (cfg.MODEL.LAYER_COUNT - 1)) and (
                epoch % 10 == 0) and (local_rank == 0):
            print("calculating fid...")
            fid = calc_fid_from_dataset_generate(
                cfg, dataset, model_s, batch_size=50, cuda=1, dims=2048,
                device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
                num_images=50000)
            print("epoch: {}, fid: {}".format(epoch, fid))
            if best_fid is None:
                best_fid = fid
            elif fid < best_fid:
                print("best fid updated: {} -> {}".format(best_fid, fid))
                best_fid = fid
                checkpointer.save("model_tmp_lod{}_fid_{}".format(lod2batch.lod, fid))

        lod2batch.set_epoch(epoch, [encoder_optimizer, decoder_optimizer])

        logger.info(
            "Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d" %
            (lod2batch.get_batch_size(), lod2batch.get_per_GPU_batch_size(),
             lod2batch.lod, 2**lod2batch.get_lod_power2(),
             2**lod2batch.get_lod_power2(), lod2batch.get_blend_factor(),
             len(dataset) * world_size))

        dataset.reset(lod2batch.get_lod_power2(), lod2batch.get_per_GPU_batch_size())
        batches = make_dataloader(cfg, logger, dataset,
                                  lod2batch.get_per_GPU_batch_size(), local_rank)

        scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod)

        model.train()

        need_permute = False
        epoch_start_time = time.time()

        i = 0
        diff_kls = []
        batch_kls_real = []
        batch_kls_fake = []
        batch_rec_errs = []
        for x_orig in tqdm(batches):
            i += 1
            with torch.no_grad():
                if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size():
                    continue
                if need_permute:
                    x_orig = x_orig.permute(0, 3, 1, 2)
                x_orig = (x_orig / 127.5 - 1.)
                blend_factor = lod2batch.get_blend_factor()

                needed_resolution = layer_to_resolution[lod2batch.lod]
                x = x_orig

                # During a LOD transition, blend the image at the new resolution
                # with an upsampled copy of the previous (half) resolution.
                if lod2batch.in_transition:
                    needed_resolution_prev = layer_to_resolution[lod2batch.lod - 1]
                    x_prev = F.avg_pool2d(x_orig, 2, 2)
                    x_prev_2x = F.interpolate(x_prev, needed_resolution)
                    x = x * blend_factor + x_prev_2x * (1.0 - blend_factor)

            x.requires_grad = True

            if epoch < num_vae_epochs:
                # Plain VAE warm-up: one joint loss updates both networks.
                encoder_optimizer.zero_grad()
                decoder_optimizer.zero_grad()
                loss = model(x, lod2batch.lod, blend_factor, d_train=False, e_train=False)
                tracker.update(dict(loss_e=loss))
                tracker.update(dict(loss_d=loss))
                loss.backward()
                encoder_optimizer.step()
                decoder_optimizer.step()
            else:
                # ------------- Update Encoder ------------- #
                encoder_optimizer.zero_grad()
                loss_e = model(x, lod2batch.lod, blend_factor, d_train=False, e_train=True)
                tracker.update(dict(loss_e=loss_e))
                loss_e.backward()
                encoder_optimizer.step()

                # ------------- Update Decoder ------------- #
                decoder_optimizer.zero_grad()
                loss_d = model(x, lod2batch.lod, blend_factor, d_train=True, e_train=False)
                loss_d.backward()
                tracker.update(dict(loss_d=loss_d))
                decoder_optimizer.step()

                # ------------- Update Statistics ------------- #
                cur_model = model.module if distributed else model
                tracker.update(dict(rec_loss=cur_model.last_rec_loss))
                tracker.update(dict(real_kl=cur_model.last_real_kl))
                tracker.update(dict(fake_kl=cur_model.last_fake_kl))
                tracker.update(dict(kl_diff=cur_model.last_kl_diff))
                tracker.update(dict(expelbo_f=cur_model.last_expelbo_fake))
                tracker.update(dict(expelbo_r=cur_model.last_expelbo_rec))
                diff_kls.append(cur_model.last_kl_diff.data.cpu())
                batch_kls_real.append(cur_model.last_real_kl)
                batch_kls_fake.append(cur_model.last_fake_kl)
                batch_rec_errs.append(cur_model.last_rec_loss)

            if local_rank == 0:
                # EMA update of the smoothed model: half-life of 10k images,
                # independent of the batch size.
                beta = 0.5**(lod2batch.get_batch_size() / (10 * 1000.0))
                model_s.lerp(model, beta)

            epoch_end_time = time.time()
            per_epoch_ptime = epoch_end_time - epoch_start_time

            lod_for_saving_model = lod2batch.lod
            lod2batch.step()
            if local_rank == 0:
                if lod2batch.is_time_to_save():
                    checkpointer.save("model_tmp_intermediate_lod%d" % lod_for_saving_model)
                if lod2batch.is_time_to_report():
                    save_sample(lod2batch, tracker, sample, samplez, x, logger,
                                model_s, cfg, encoder_optimizer, decoder_optimizer)

        scheduler.step()

        mean_diff_kl = np.mean(diff_kls)
        print("mean diff kl: ", mean_diff_kl)
        if epoch > num_vae_epochs - 1:
            kls_real.append(np.mean(batch_kls_real))
            kls_fake.append(np.mean(batch_kls_fake))
            rec_errs.append(np.mean(batch_rec_errs))

        if local_rank == 0:
            checkpointer.save("model_tmp_lod%d" % lod_for_saving_model)
            save_sample(lod2batch, tracker, sample, samplez, x, logger, model_s,
                        cfg, encoder_optimizer, decoder_optimizer)

    logger.info("Training finished! Saving training results.")
    if local_rank == 0:
        checkpointer.save("model_final").wait()
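The model_s.lerp(model, beta) call above keeps an exponential moving average of the weights with a decay tied to the number of images seen rather than the number of steps. A minimal sketch of that update for plain modules (the repo's lerp also handles buffers and the dlatent average, so this is an illustration, not the actual implementation):

import torch
import torch.nn as nn

def lerp_ema(target: nn.Module, source: nn.Module, beta: float):
    # target = beta * target + (1 - beta) * source, parameter-wise
    with torch.no_grad():
        for p_t, p_s in zip(target.parameters(), source.parameters()):
            p_t.lerp_(p_s, 1.0 - beta)

model = nn.Linear(4, 4)
model_s = nn.Linear(4, 4)
batch_size = 32
# Half-life of 10k training images, independent of batch size.
beta = 0.5 ** (batch_size / (10 * 1000.0))
lerp_ema(model_s, model, beta)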
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = SoftIntroVAEModelTL(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
        beta_kl=cfg.MODEL.BETA_KL,
        beta_rec=cfg.MODEL.BETA_REC,
        beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
        scale=cfg.MODEL.SCALE)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters decoder:")
    print(count_parameters(decoder))
    logger.info("Trainable parameters encoder:")
    print(count_parameters(encoder))

    # Load the smoothed (EMA) weights from the checkpoint into this model.
    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }
    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)
    checkpointer.load()
    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        # Encode at the final LOD (blend factor 1) and map mu to style vectors.
        z, mu, _ = model.encode(x, layer_count - 1, 1)
        styles = model.mapping_fl(mu)
        return styles

    def decode(x):
        return model.decoder(x, layer_count - 1, 1, noise=True)

    dataset = TFRecordsDataset(cfg, logger, rank=0, world_size=1,
                               buffer_size_mb=10, channels=cfg.MODEL.CHANNELS,
                               train=False)
    dataset.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, 10)
    b = iter(make_dataloader(cfg, logger, dataset, 10, 0, numpy=True))

    def make(sample):
        canvas = []
        with torch.no_grad():
            for img in sample:
                x = torch.tensor(np.asarray(img, dtype=np.float32)).cuda() / 127.5 - 1.
                if x.shape[0] == 4:
                    x = x[:3]  # drop the alpha channel
                latents = encode(x[None, ...])
                f = decode(latents)
                # Input and reconstruction side by side along the width axis.
                r = torch.cat([x[None, ...].detach().cpu(), f.detach().cpu()], dim=3)
                canvas.append(r)
        return canvas

    sample = next(b)
    canvas = torch.cat(make(sample), dim=0)
    save_image(canvas * 0.5 + 0.5,
               './make_figures/reconstructions_ffhq_real_1.png',
               nrow=2, pad_value=1.0)

    sample = next(b)
    canvas = torch.cat(make(sample), dim=0)
    save_image(canvas * 0.5 + 0.5,
               './make_figures/reconstructions_ffhq_real_2.png',
               nrow=2, pad_value=1.0)
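Each canvas row concatenates the input and its reconstruction along the width axis (dim=3 of an NCHW tensor), so save_image with nrow=2 lays out two input/reconstruction pairs per grid row. A self-contained illustration with dummy tensors standing in for model inputs and outputs:

import torch
from torchvision.utils import save_image

x = torch.rand(1, 3, 64, 64)       # stand-in for a real image in [0, 1]
f = torch.rand(1, 3, 64, 64)       # stand-in for its reconstruction
pair = torch.cat([x, f], dim=3)    # 1 x 3 x 64 x 128: input | reconstruction
canvas = torch.cat([pair] * 4, dim=0)
save_image(canvas, 'pairs.png', nrow=2, pad_value=1.0)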
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = SoftIntroVAEModelTL(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
        beta_kl=cfg.MODEL.BETA_KL,
        beta_rec=cfg.MODEL.BETA_REC,
        beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
        scale=cfg.MODEL.SCALE)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters decoder:")
    print(count_parameters(decoder))
    logger.info("Trainable parameters encoder:")
    print(count_parameters(encoder))

    # Load the smoothed (EMA) weights from the checkpoint into this model.
    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }
    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)
    checkpointer.load()
    model.eval()

    im_size = 2**(cfg.MODEL.LAYER_COUNT + 1)

    seed = np.random.randint(0, 999999)
    print("seed:", seed)
    with torch.no_grad():
        path = './make_figures/output'
        os.makedirs(os.path.join(path, cfg.NAME), exist_ok=True)
        draw_uncurated_result_figure(
            cfg,
            './make_figures/output/%s/generations.jpg' % cfg.NAME,
            model,
            cx=0, cy=0, cw=im_size, ch=im_size,
            rows=6,
            lods=[0, 0, 0, 1, 1, 2],
            seed=seed)
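These scripts recover the output resolution from the model depth: the generator starts at 4x4 and doubles the resolution per layer, so im_size = 2 ** (LAYER_COUNT + 1). A quick check of the arithmetic (layer counts chosen for illustration):

for layer_count in (4, 6, 8):
    print(layer_count, 2 ** (layer_count + 1))
# 4 -> 32, 6 -> 128, 8 -> 512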
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = SoftIntroVAEModelTL(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
        beta_kl=cfg.MODEL.BETA_KL,
        beta_rec=cfg.MODEL.BETA_REC,
        beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
        scale=cfg.MODEL.SCALE)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters decoder:")
    print(count_parameters(decoder))
    logger.info("Trainable parameters encoder:")
    print(count_parameters(encoder))

    # Load the smoothed (EMA) weights from the checkpoint into this model.
    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }
    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)
    checkpointer.load()
    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        # Encode at the final LOD (blend factor 1) and map mu to style vectors.
        z, mu, _ = model.encode(x, layer_count - 1, 1)
        styles = model.mapping_fl(mu)
        return styles

    def decode(x):
        return model.decoder(x, layer_count - 1, 1, noise=True)

    path = cfg.DATASET.SAMPLES_PATH
    im_size = 2**(cfg.MODEL.LAYER_COUNT + 1)

    pathA = '17460.jpg'
    pathB = '02973.jpg'

    def open_image(filename):
        # Load an image, fit it to the model resolution, and encode it to a
        # single w vector.
        img = np.asarray(Image.open(os.path.join(path, filename)))
        if img.shape[2] == 4:
            img = img[:, :, :3]  # drop the alpha channel
        im = img.transpose((2, 0, 1))
        x = torch.tensor(np.asarray(im, dtype=np.float32)).cuda() / 127.5 - 1.
        factor = x.shape[2] // im_size
        if factor != 1:
            x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0]
        assert x.shape[2] == im_size
        _latents = encode(x[None, ...])
        latents = _latents[0, 0]
        return latents

    def make(w):
        with torch.no_grad():
            # Broadcast the single w vector to all style layers.
            w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers, 1)
            x_rec = decode(w)
            return x_rec

    wa = open_image(pathA)
    wb = open_image(pathB)

    width = 7
    images = []
    for j in range(width):
        # Linear interpolation in w space between the two endpoints.
        kh = j / (width - 1.0)
        w = (1.0 - kh) * wa + kh * wb
        images.append(make(w))
    images = torch.cat(images)

    path = './make_figures/output'
    os.makedirs(os.path.join(path, cfg.NAME), exist_ok=True)
    save_image(images * 0.5 + 0.5,
               './make_figures/output/%s/interpolations.png' % cfg.NAME, nrow=width)
    save_image(images * 0.5 + 0.5,
               './make_figures/output/%s/interpolations.jpg' % cfg.NAME, nrow=width)
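The interpolation above is a straight line in W space: for j = 0..width-1, w = (1 - t) * wa + t * wb with t = j / (width - 1), so the endpoints reproduce the two encoded images exactly. A standalone version of the schedule with random vectors standing in for encoded latents:

import torch

wa, wb = torch.randn(512), torch.randn(512)   # stand-ins for encoded w vectors
width = 7
ws = [(1 - j / (width - 1)) * wa + (j / (width - 1)) * wb for j in range(width)]
assert torch.allclose(ws[0], wa) and torch.allclose(ws[-1], wb)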
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = SoftIntroVAEModelTL(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
        beta_kl=cfg.MODEL.BETA_KL,
        beta_rec=cfg.MODEL.BETA_REC,
        beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
        scale=cfg.MODEL.SCALE)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters decoder:")
    print(count_parameters(decoder))
    logger.info("Trainable parameters encoder:")
    print(count_parameters(encoder))

    # Load the smoothed (EMA) weights from the checkpoint into this model.
    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }
    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)
    checkpointer.load()
    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        # Encode at the final LOD (blend factor 1) and map mu to style vectors.
        z, mu, _ = model.encode(x, layer_count - 1, 1)
        styles = model.mapping_fl(mu)
        return styles

    def decode(x):
        return model.decoder(x, layer_count - 1, 1, noise=True)

    path = cfg.DATASET.SAMPLES_PATH
    im_size = 2**(cfg.MODEL.LAYER_COUNT + 1)

    paths = sorted(os.listdir(path))
    random.seed(5)
    random.shuffle(paths)

    def make(paths):
        src = []
        for filename in paths:
            img = np.asarray(Image.open(os.path.join(path, filename)))
            if img.shape[2] == 4:
                img = img[:, :, :3]  # drop the alpha channel
            im = img.transpose((2, 0, 1))
            x = torch.tensor(np.asarray(im, dtype=np.float32)).cuda() / 127.5 - 1.
            # Downscale to the model resolution if needed.
            factor = x.shape[2] // im_size
            if factor != 1:
                x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0]
            assert x.shape[2] == im_size
            src.append(x)

        with torch.no_grad():
            reconstructions = []
            for s in src:
                latents = encode(s[None, ...])
                reconstructions.append(decode(latents).cpu().detach().numpy())
        return src, reconstructions

    def chunker_list(seq, size):
        # Deal the items out into `size` interleaved groups.
        return [seq[i::size] for i in range(size)]

    path0, path1, path2, path3 = chunker_list(paths, 4)
    path0.reverse()
    path1.reverse()
    path2.reverse()
    path3.reverse()

    src0, rec0 = make(path0)
    src1, rec1 = make(path1)
    src2, rec2 = make(path2)
    src3, rec3 = make(path3)

    # Layout: one full-resolution pair on top, then rows of progressively
    # smaller pairs, with padding that shrinks along with the images.
    initial_resolution = im_size
    lods_down = 1
    padding_step = 4

    width = 0
    height = 0
    current_padding = 0

    final_resolution = initial_resolution
    for _ in range(lods_down):
        final_resolution /= 2

    for i in range(lods_down + 1):
        width += current_padding * 2**(lods_down - i)
        height += current_padding * 2**(lods_down - i)
        current_padding += padding_step

    width += 2**(lods_down + 1) * final_resolution
    height += (lods_down + 1) * initial_resolution

    width = int(width)
    height = int(height)

    def make_part(current_padding, src, rec):
        canvas = np.ones([3, height + 20, width + 10])
        initial_padding = current_padding
        height_padding = 0
        for i in range(lods_down + 1):
            for x in range(2**i):
                for y in range(2**i):
                    try:
                        ims = src.pop().cpu().detach().numpy()
                        imr = rec.pop()[0]
                        res = int(initial_resolution / 2**i)
                        ims = resize(ims, (3, res, res))
                        imr = resize(imr, (3, res, res))
                        place(canvas, ims,
                              current_padding + x * (2 * res + current_padding),
                              i * initial_resolution + height_padding + y * (res + current_padding))
                        place(canvas, imr,
                              current_padding + res + x * (2 * res + current_padding),
                              i * initial_resolution + height_padding + y * (res + current_padding))
                    except IndexError:
                        return canvas
            height_padding += initial_padding * 2
            current_padding -= padding_step
        return canvas

    canvas = [
        make_part(current_padding, src0, rec0),
        make_part(current_padding, src1, rec1),
        make_part(current_padding, src2, rec2),
        make_part(current_padding, src3, rec3)
    ]
    canvas = np.concatenate(canvas, axis=2)

    path = './make_figures/output'
    os.makedirs(os.path.join(path, cfg.NAME), exist_ok=True)
    print('Saving image')
    save_path = './make_figures/output/%s/reconstructions_multiresolution.png' % cfg.NAME
    save_image(torch.Tensor(canvas), save_path)
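Note that this chunker_list differs from the one in the earlier reconstruction script: it deals the paths out into `size` interleaved groups (one per canvas column) rather than consecutive chunks. A quick standalone check of the difference:

def chunker_consecutive(seq, n):
    return [seq[i * n:(i + 1) * n] for i in range((len(seq) + n - 1) // n)]

def chunker_interleaved(seq, size):
    return [seq[i::size] for i in range(size)]

print(chunker_consecutive(list(range(8)), 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(chunker_interleaved(list(range(8)), 4))  # [[0, 4], [1, 5], [2, 6], [3, 7]]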