def reset_gen():
    # Relies on the module-level `args` parsed in __main__; `args.model`
    # selects which pretrained generator to restore.
    if args.model in ['iagan_began_cs']:
        gen = Generator128(64)
        gen = load_trained_net(
            gen, ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
                  '.n_cuts=0/gen_ckpt.49.pt'))
        gen = gen.eval().to(DEVICE)
        img_size = 128
    elif args.model in ['iagan_dcgan_cs']:
        gen = dcgan_generator()
        t = torch.load(('./dcgan_checkpoints/netG.epoch_24.n_cuts_0.bs_64'
                        '.b1_0.5.lr_0.0002.pt'))
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        img_size = 64
    elif args.model in ['iagan_vanilla_vae_cs']:
        gen = VAE()
        t = torch.load('./vae_checkpoints/vae_bs=128_beta=1.0/epoch_19.pt')
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        gen = gen.decoder
        img_size = 128
    else:
        raise NotImplementedError()
    return gen, img_size
def reset_gen(model):
    if model == 'began':
        gen = Generator128(64)
        gen = load_trained_net(
            gen, ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
                  '.n_cuts=0/gen_ckpt.49.pt'))
        gen = gen.eval().to(DEVICE)
        img_size = 128
    elif model == 'vae':
        gen = VAE()
        t = torch.load('./vae_checkpoints/vae_bs=128_beta=1.0/epoch_19.pt')
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        gen = gen.decoder
        img_size = 128
    elif model == 'dcgan':
        gen = dcgan_generator()
        t = torch.load(('./dcgan_checkpoints/netG.epoch_24.n_cuts_0.bs_64'
                        '.b1_0.5.lr_0.0002.pt'))
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        img_size = 64
    else:
        # Previously fell through here and hit UnboundLocalError at return;
        # fail explicitly for unknown model names instead.
        raise NotImplementedError()
    return gen, img_size
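# Usage sketch (illustrative; not part of the original scripts): restore the
# pretrained BEGAN generator and draw one sample. It uses only names defined
# above (`reset_gen`, `DEVICE`) and the `input_shapes` / forward(z1, z2,
# n_cuts=...) interface used throughout these scripts.
#
#   gen, img_size = reset_gen('began')
#   z1_shape, z2_shape = gen.input_shapes[0]      # input shapes at n_cuts=0
#   z1 = torch.randn(1, *z1_shape).to(DEVICE)
#   z2 = None if len(z2_shape) == 0 else torch.randn(1, *z2_shape).to(DEVICE)
#   img = gen(z1, z2, n_cuts=0)                   # (1, 3, img_size, img_size)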
def cut_training(n_cols):
    began_settings = {
        1: {
            'batch_size': 32,
            'z_lr': 3e-5,
            'path': ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
                     '.n_cuts=1.z_lr=3e-5/gen_ckpt.24.pt')
        },
        2: {
            'batch_size': 32,
            'z_lr': 8e-5,
            'path': ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
                     '.n_cuts=1.z_lr=8e-5/gen_ckpt.19.pt')
        },
        3: {
            'batch_size': 64,
            'z_lr': 1e-4,
            'path': ('./checkpoints/celeba_began.withskips.bs64.cosine.min=0.25'
                     '.n_cuts=1.z_lr=1e-4/gen_ckpt.19.pt')
        },
        4: {
            'batch_size': 64,
            'z_lr': 3e-5,
            'path': ('./checkpoints/celeba_began.withskips.bs64.cosine.min=0.25'
                     '.n_cuts=1.z_lr=3e-5/gen_ckpt.24.pt')
        },
        5: {
            'batch_size': 64,
            'z_lr': 8e-5,
            'path': ('./checkpoints/celeba_began.withskips.bs64.cosine.min=0.25'
                     '.n_cuts=1.z_lr=8e-5/gen_ckpt.24.pt')
        }
    }
    fig, ax = plt.subplots(len(began_settings), n_cols,
                           figsize=(n_cols, len(began_settings)))
    fig.suptitle('BEGAN (cuts=1)', fontsize=16)
    for i, settings in began_settings.items():
        g = Generator128(64).to('cuda')
        g = load_trained_net(g, settings['path'])
        input_shapes = g.input_shapes[1]
        z1_shape = input_shapes[0]
        z2_shape = input_shapes[1]
        for col in range(n_cols):
            z1 = torch.randn(1, *z1_shape).clamp(-1, 1).to('cuda')
            if len(z2_shape) == 0:
                z2 = None
            else:
                z2 = torch.randn(1, *z2_shape).clamp(-1, 1).to('cuda')
            img = g.forward(z1, z2, n_cuts=1).detach().cpu().squeeze(
                0).numpy().transpose([1, 2, 0])
            # BEGAN outputs are already in [0, 1]; no rescale needed
            # (contrast with DCGAN below)
            ax[i - 1, col].imshow(np.clip(img, 0, 1), aspect='auto')
            ax[i - 1, col].set_xticks([])
            ax[i - 1, col].set_yticks([])
            ax[i - 1, col].set_frame_on(False)
    fig.subplots_adjust(0, 0, 1, 0.93, 0, 0)
    os.makedirs('./figures/cut_training/', exist_ok=True)
    plt.savefig('./figures/cut_training/began_cut_training.pdf',
                bbox_inches='tight', dpi=300)

    dcgan_settings = {
        1: {
            'z_lr': 5e-5,
            'b1': 0.5,
            'path': ('./dcgan_checkpoints/netG.epoch_24.n_cuts_1'
                     '.bs_64.b1_0.5.lr_5e-05.pt')
        },
        2: {
            'z_lr': 1e-4,
            'b1': 0.5,
            'path': ('./dcgan_checkpoints/netG.epoch_24.n_cuts_1'
                     '.bs_64.b1_0.5.lr_0.0001.pt')
        },
        3: {
            'z_lr': 2e-4,
            'b1': 0.5,
            'path': ('./dcgan_checkpoints/netG.epoch_24.n_cuts_1'
                     '.bs_64.b1_0.5.lr_0.0002.pt')
        },
        4: {
            'z_lr': 5e-5,
            'b1': 0.9,
            'path': ('./dcgan_checkpoints/netG.epoch_24.n_cuts_1'
                     '.bs_64.b1_0.9.lr_5e-05.pt')
        },
        5: {
            'z_lr': 2e-4,
            'b1': 0.9,
            'path': ('./dcgan_checkpoints/netG.epoch_24.n_cuts_1'
                     '.bs_64.b1_0.9.lr_0.0002.pt')
        },
    }
    fig, ax = plt.subplots(len(dcgan_settings), n_cols,
                           figsize=(n_cols, len(dcgan_settings)))
    fig.suptitle('DCGAN (cuts=1)', fontsize=16)
    for i, settings in dcgan_settings.items():
        g = dcgan_generator().to('cuda')
        g.load_state_dict(torch.load(settings['path']))
        input_shapes = g.input_shapes[1]
        z1_shape = input_shapes[0]
        z2_shape = input_shapes[1]
        for col in range(n_cols):
            z1 = torch.randn(1, *z1_shape).clamp(-1, 1).to('cuda')
            if len(z2_shape) == 0:
                z2 = None
            else:
                z2 = torch.randn(1, *z2_shape).clamp(-1, 1).to('cuda')
            img = g.forward(z1, z2, n_cuts=1).detach().cpu().squeeze(
                0).numpy().transpose([1, 2, 0])
            # Rescale from [-1, 1] to [0, 1]
            img = (img + 1) / 2
            ax[i - 1, col].imshow(np.clip(img, 0, 1), aspect='auto')
            ax[i - 1, col].set_xticks([])
            ax[i - 1, col].set_yticks([])
            ax[i - 1, col].set_frame_on(False)
    fig.subplots_adjust(0, 0, 1, 0.93, 0, 0)
    os.makedirs('./figures/cut_training/', exist_ok=True)
    plt.savefig('./figures/cut_training/dcgan_cut_training.pdf',
                bbox_inches='tight', dpi=300)
def generator_samples(model):
    if model == 'began':
        g = Generator128(64).to('cuda:0')
        g = load_trained_net(
            g, ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
                '.n_cuts=0/gen_ckpt.49.pt'))
    elif model == 'vae':
        g = VAE().to('cuda:0')
        g.load_state_dict(
            torch.load('./vae_checkpoints/vae_bs=128_beta=1.0/epoch_19.pt'))
        g = g.decoder
    elif model == 'biggan':
        g = BigGanSkip().to('cuda:0')
    elif model == 'dcgan':
        g = dcgan_generator().to('cuda:0')
        g.load_state_dict(
            torch.load(('./dcgan_checkpoints/netG.epoch_24.n_cuts_0.bs_64'
                        '.b1_0.5.lr_0.0002.pt')))
    else:
        raise NotImplementedError()

    nseed = 10
    n_cuts_list = [0, 1, 2, 3, 4, 5]
    fig, ax = plt.subplots(len(n_cuts_list), nseed,
                           figsize=(10, len(n_cuts_list)))
    for row, n_cuts in enumerate(n_cuts_list):
        input_shapes = g.input_shapes[n_cuts]
        z1_shape = input_shapes[0]
        z2_shape = input_shapes[1]
        for col in range(nseed):
            torch.manual_seed(col)
            np.random.seed(col)
            if n_cuts == 0 and model == 'biggan':
                class_vector = torch.tensor(
                    949, dtype=torch.long).to('cuda:0').unsqueeze(
                        0)  # 949 = strawberry
                embed = g.biggan.embeddings(
                    torch.nn.functional.one_hot(
                        class_vector, num_classes=1000).to(torch.float))
                cond_vector = torch.cat(
                    (torch.randn(1, 128).to('cuda:0'), embed), dim=1)
                img = orig_biggan_forward(
                    g.biggan.generator, cond_vector,
                    truncation=1.0).detach().cpu().squeeze(
                        0).numpy().transpose([1, 2, 0])
            elif n_cuts > 0 and model == 'biggan':
                z1 = torch.randn(1, *z1_shape).to('cuda:0')
                class_vector = torch.tensor(
                    949, dtype=torch.long).to('cuda:0').unsqueeze(
                        0)  # 949 = strawberry
                embed = g.biggan.embeddings(
                    torch.nn.functional.one_hot(
                        class_vector, num_classes=1000).to(torch.float))
                cond_vector = torch.cat(
                    (torch.randn(1, 128).to('cuda:0'), embed), dim=1)
                z2 = cond_vector
                img = g(z1, z2, truncation=1.0,
                        n_cuts=n_cuts).detach().cpu().squeeze(
                            0).numpy().transpose([1, 2, 0])
            else:
                z1 = torch.randn(1, *z1_shape).to('cuda:0')
                if len(z2_shape) == 0:
                    z2 = None
                else:
                    z2 = torch.randn(1, *z2_shape).to('cuda:0')
                img = g(z1, z2, n_cuts=n_cuts).detach().cpu().squeeze(
                    0).numpy().transpose([1, 2, 0])

            if g.rescale:
                img = (img + 1) / 2
            ax[row, col].imshow(np.clip(img, 0, 1), aspect='auto')
            ax[row, col].set_xticks([])
            ax[row, col].set_yticks([])
            ax[row, col].set_frame_on(False)
            if col == 0:
                ax[row, col].set_ylabel(f'{n_cuts}')
    fig.subplots_adjust(0, 0, 1, 1, 0, 0)
    os.makedirs('./figures/generator_samples', exist_ok=True)
    plt.savefig(f'./figures/generator_samples/model={model}.pdf',
                dpi=300, bbox_inches='tight')
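# Usage note: generator_samples('began') (likewise 'vae', 'biggan', 'dcgan')
# fixes the seeds per column and writes a grid of samples, one row per value
# of n_cuts in {0,...,5}, to ./figures/generator_samples/model=<name>.pdf.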
def main(args):
    checkpoint_path = f"checkpoints/{args.dataset}_{args.run_name}"
    tensorboard_path = f"tensorboard_logs/{args.dataset}_{args.run_name}"
    torch.backends.cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    writer = SummaryWriter(tensorboard_path)

    dataloader, _ = get_dataloader(args.dataset_dir, args.batch_size,
                                   args.n_train, True)

    gen = Generator128(args.latent_dim).to(device)
    disc = Discriminator128(args.latent_dim).to(device)

    # Get latent_shape for x1 only
    latent_shape = gen.input_shapes[args.n_cuts][0]

    if torch.cuda.device_count() > 1:
        gen = torch.nn.DataParallel(gen)
        disc = torch.nn.DataParallel(disc)

    gen_optimizer = torch.optim.Adam(gen.parameters(), args.lr)
    gen_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        gen_optimizer, len(dataloader) * args.epochs, 0.25 * args.lr)
    disc_optimizer = torch.optim.Adam(disc.parameters(), args.lr)
    disc_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        disc_optimizer, len(dataloader) * args.epochs, 0.25 * args.lr)

    current_checkpoint = 0
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path)
    else:
        print("Restoring from checkpoint...")
        paths = os.listdir(checkpoint_path)
        try:
            available = sorted(set([int(x.split(".")[1]) for x in paths]))

            # Find a checkpoint that both gen AND disc have reached;
            # reaching zero will cause IndexError during pop()
            while True:
                latest_idx = available.pop()
                latest_disc = os.path.join(checkpoint_path,
                                           f"disc_ckpt.{latest_idx}.pt")
                latest_gen = os.path.join(checkpoint_path,
                                          f"gen_ckpt.{latest_idx}.pt")
                if os.path.exists(latest_disc) and os.path.exists(latest_gen):
                    break

            current_checkpoint = latest_idx
            disc_epoch = load(latest_disc, disc, disc_optimizer,
                              disc_scheduler)
            gen_epoch = load(latest_gen, gen, gen_optimizer, gen_scheduler)
            assert disc_epoch == gen_epoch, \
                'Checkpoint contents are mismatched!'
            print(f"Loaded checkpoint {current_checkpoint}")
        except Exception as e:
            print(e)
            print("Unable to load from checkpoint.")

    k = 0

    # Uniform from -1 to 1
    const_sample = get_z_vector((args.batch_size, *latent_shape),
                                mode='uniform',
                                dtype=torch.float,
                                device=device)

    n_gen_param = sum([x.numel() for x in gen.parameters()
                       if x.requires_grad])
    n_disc_param = sum(
        [x.numel() for x in disc.parameters() if x.requires_grad])
    print(f"{n_gen_param + n_disc_param} Trainable Parameters")

    if current_checkpoint < args.epochs - 1:
        for e in trange(current_checkpoint,
                        args.epochs,
                        initial=current_checkpoint,
                        desc='Epoch',
                        leave=True,
                        disable=args.disable_tqdm):
            for i, img_batch in tqdm(enumerate(dataloader),
                                     total=len(dataloader),
                                     leave=False,
                                     disable=args.disable_tqdm):
                disc_optimizer.zero_grad()
                gen_optimizer.zero_grad()

                img_batch = img_batch.to(device)

                # Uniform from -1 to 1
                d_latent_sample = get_z_vector(
                    (args.batch_size, *latent_shape),
                    mode='uniform',
                    dtype=torch.float,
                    device=device)
                g_latent_sample = get_z_vector(
                    (args.batch_size, *latent_shape),
                    mode='uniform',
                    dtype=torch.float,
                    device=device)

                batch_ac_loss = ac_loss(img_batch, disc)
                d_fake_ac_loss = ac_loss(
                    gen.forward(d_latent_sample, x2=None,
                                n_cuts=args.n_cuts).detach(), disc)
                g_fake_ac_loss = ac_loss(
                    gen.forward(g_latent_sample, x2=None,
                                n_cuts=args.n_cuts), disc)

                def d_loss():
                    loss = batch_ac_loss - k * d_fake_ac_loss
                    loss.backward()
                    return loss

                def g_loss():
                    loss = g_fake_ac_loss
                    loss.backward()
                    return loss

                disc_optimizer.step(d_loss)
                gen_optimizer.step(g_loss)
                disc_scheduler.step()
                gen_scheduler.step()

                k = k + args.prop_gain * (
                    args.gamma * batch_ac_loss.item() -
                    g_fake_ac_loss.item())

                m = ac_loss(img_batch, disc) + \
                    torch.abs(args.gamma * batch_ac_loss - g_fake_ac_loss)
                writer.add_scalar("Convergence", m, len(dataloader) * e + i)

                if i % args.log_every == 0:
                    ex_img = gen.forward(g_latent_sample,
                                         x2=None,
                                         n_cuts=args.n_cuts)[0]
                    writer.add_image("Random/Raw", ex_img,
                                     len(dataloader) * e + i)
                    writer.add_image("Random/Clamp", ex_img.clamp(0, 1),
                                     len(dataloader) * e + i)
                    writer.add_image("Random/Normalize", normalize(ex_img),
                                     len(dataloader) * e + i)
                    ex_img_const = gen.forward(const_sample,
                                               x2=None,
                                               n_cuts=args.n_cuts)[0]
                    writer.add_image("Constant/Raw", ex_img_const,
                                     len(dataloader) * e + i)
                    writer.add_image("Constant/Clamp",
                                     ex_img_const.clamp(0, 1),
                                     len(dataloader) * e + i)
                    writer.add_image("Constant/Normalize",
                                     normalize(ex_img_const),
                                     len(dataloader) * e + i)

            save(os.path.join(checkpoint_path, f"gen_ckpt.{e}.pt"), e, gen,
                 gen_optimizer, gen_scheduler)
            save(os.path.join(checkpoint_path, f"disc_ckpt.{e}.pt"), e, disc,
                 disc_optimizer, disc_scheduler)
def lasso_estimator(A_val, y_val, gamma):
    # NOTE: this snippet originally began mid-function; the header (name and
    # signature) is reconstructed from the variables used below.
    lasso_est = Lasso(alpha=gamma)
    lasso_est.fit(A_val.T, y_val.reshape(-1))
    x_hat = lasso_est.coef_
    x_hat = np.reshape(x_hat, [-1])
    return x_hat


if __name__ == '__main__':
    DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    a = argparse.ArgumentParser()
    a.add_argument('--img_dir', required=True)
    # store_true instead of `default=False`: the original form treats any
    # provided string (even "False") as truthy.
    a.add_argument('--disable_tqdm', action='store_true')
    args = a.parse_args()

    gen = Generator128(64)
    gen = load_trained_net(
        gen, ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
              '.n_cuts=0/gen_ckpt.49.pt'))
    gen = gen.eval().to(DEVICE)
    n_cuts = 3

    img_size = 128
    img_shape = (3, img_size, img_size)

    forward_model = GaussianCompressiveSensing(n_measure=2500,
                                               img_shape=img_shape)
    # forward_model = NoOp()

    for img_name in tqdm(os.listdir(args.img_dir),
                         disable=args.disable_tqdm):
        # (The loop body is truncated in this excerpt; the tqdm call is
        # completed following the pattern used elsewhere in these scripts.)
        ...
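# Minimal, self-contained sketch of the Lasso baseline above. Dimensions,
# scaling, and alpha are illustrative assumptions, not the repo's settings;
# `lasso_demo` is a hypothetical name. sklearn's Lasso solves
#     min_w (1/(2*n)) * ||Xw - y||_2^2 + alpha * ||w||_1,
# here with X = A.T, mirroring the fit(A_val.T, y_val) call pattern above.
def lasso_demo():
    import numpy as np
    from sklearn.linear_model import Lasso

    rng = np.random.default_rng(0)
    n_pixels, n_measure = 256, 64
    x_true = np.zeros(n_pixels)
    x_true[rng.choice(n_pixels, size=5, replace=False)] = 1.0  # sparse signal
    A = rng.standard_normal((n_pixels, n_measure)) / np.sqrt(n_measure)
    y = A.T @ x_true                    # noiseless Gaussian measurements
    est = Lasso(alpha=1e-4)
    est.fit(A.T, y)
    x_hat = est.coef_.reshape(-1)       # sparse estimate of x_true
    return x_hat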
def mgan_images(args):
    if args.set_seed:
        torch.manual_seed(0)
        np.random.seed(0)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    os.makedirs(BASE_DIR, exist_ok=True)

    if args.model in ['mgan_began_cs']:
        gen = Generator128(64)
        gen = load_trained_net(
            gen, ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
                  '.n_cuts=0/gen_ckpt.49.pt'))
        gen = gen.eval().to(DEVICE)
        img_size = 128
    elif args.model in ['mgan_vanilla_vae_cs']:
        gen = VAE()
        t = torch.load('./vae_checkpoints/vae_bs=128_beta=1.0/epoch_19.pt')
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        gen = gen.decoder
        img_size = 128
    elif args.model in ['mgan_dcgan_cs']:
        gen = dcgan_generator()
        t = torch.load(('./dcgan_checkpoints/netG.epoch_24.n_cuts_0.bs_64'
                        '.b1_0.5.lr_0.0002.pt'))
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        img_size = 64
    else:
        raise NotImplementedError()

    img_shape = (3, img_size, img_size)
    metadata = recovery_settings[args.model]
    n_cuts_list = metadata['n_cuts_list']
    del metadata['n_cuts_list']

    z_init_mode_list = metadata['z_init_mode']
    limit_list = metadata['limit']
    assert len(z_init_mode_list) == len(limit_list)
    del metadata['z_init_mode']
    del metadata['limit']

    forwards = forward_models[args.model]

    data_split = Path(args.img_dir).name
    for img_name in tqdm(sorted(os.listdir(args.img_dir)),
                         desc='Images',
                         leave=True,
                         disable=args.disable_tqdm):
        # Load image and get filename without extension
        orig_img = load_target_image(os.path.join(args.img_dir, img_name),
                                     img_size).to(DEVICE)
        img_basename, _ = os.path.splitext(img_name)

        for n_cuts in tqdm(n_cuts_list,
                           desc='N_cuts',
                           leave=False,
                           disable=args.disable_tqdm):
            metadata['n_cuts'] = n_cuts
            for f, f_args_list in tqdm(forwards.items(),
                                       desc='Forwards',
                                       leave=False,
                                       disable=args.disable_tqdm):
                for f_args in tqdm(f_args_list,
                                   desc=f'{f} Args',
                                   leave=False,
                                   disable=args.disable_tqdm):

                    f_args['img_shape'] = img_shape
                    forward_model = get_forward_model(f, **f_args)

                    for z_init_mode, limit in zip(
                            tqdm(z_init_mode_list,
                                 desc='z_init_mode',
                                 leave=False), limit_list):
                        metadata['z_init_mode'] = z_init_mode
                        metadata['limit'] = limit

                        # Before doing recovery, check if results already
                        # exist and possibly skip
                        recovered_name = 'recovered.pt'
                        results_folder = get_results_folder(
                            image_name=img_basename,
                            model=args.model,
                            n_cuts=n_cuts,
                            split=data_split,
                            forward_model=forward_model,
                            recovery_params=dict_to_str(metadata),
                            base_dir=BASE_DIR)

                        os.makedirs(results_folder, exist_ok=True)

                        recovered_path = results_folder / recovered_name
                        if (os.path.exists(recovered_path)
                                and not args.overwrite):
                            print(f'{recovered_path} already exists, '
                                  f'skipping...')
                            continue

                        if args.run_name is not None:
                            current_run_name = (
                                f'{img_basename}.{forward_model}'
                                f'.{dict_to_str(metadata)}'
                                f'.{args.run_name}')
                        else:
                            current_run_name = None

                        recovered_img, distorted_img, _ = mgan_recover(
                            orig_img, gen, n_cuts, forward_model,
                            metadata['optimizer'], z_init_mode, limit,
                            metadata['z_lr'], metadata['n_steps'],
                            metadata['z_number'], metadata['restarts'],
                            args.run_dir, current_run_name,
                            args.disable_tqdm)

                        # Make images folder
                        img_folder = get_images_folder(
                            split=data_split,
                            image_name=img_basename,
                            img_size=img_size,
                            base_dir=BASE_DIR)
                        os.makedirs(img_folder, exist_ok=True)

                        # Save original image if needed
                        original_img_path = img_folder / 'original.pt'
                        if not os.path.exists(original_img_path):
                            torch.save(orig_img, original_img_path)

                        # Save distorted image if needed
                        if forward_model.viewable:
                            distorted_img_path = (
                                img_folder / f'{forward_model}.pt')
                            if not os.path.exists(distorted_img_path):
                                torch.save(distorted_img, distorted_img_path)

                        # Save recovered image and metadata; use context
                        # managers so the pickle files are closed promptly
                        torch.save(recovered_img, recovered_path)
                        with open(results_folder / 'metadata.pkl', 'wb') as fp:
                            pickle.dump(metadata, fp)
                        p = psnr(recovered_img, orig_img)
                        with open(results_folder / 'psnr.pkl', 'wb') as fp:
                            pickle.dump(p, fp)
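# For reference, a minimal PSNR in the form the repo's psnr() helper
# presumably computes (assumption: images are scaled to [0, 1], so the peak
# value is 1.0). `psnr_sketch` is an illustrative name, not the repo's.
import torch  # already imported by the scripts above; repeated here so the
              # sketch is self-contained

def psnr_sketch(x_hat, x, peak=1.0):
    mse = torch.mean((x_hat - x) ** 2)
    return 10 * torch.log10(peak ** 2 / mse)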