def create_run(device, temp, pixelsnail_batch, ckpt_epoch, pixelsnail_ckpt_epoch, hier,
               architecture, num_embeddings, neighborhood, selection_fn, dataset, size, **kwargs):
    experiment_name = util_funcs.create_experiment_name(
        architecture, dataset, num_embeddings, neighborhood, selection_fn, size, **kwargs)
    vqvae_checkpoint_name = util_funcs.create_checkpoint_name(experiment_name, ckpt_epoch)

    # NOTE: the checkpoint name is hardcoded here and overrides the formatted name
    # kept above for reference
    # pixelsnail_checkpoint_name = f'pixelsnail_{experiment_name}_{hier}_{str(pixelsnail_ckpt_epoch + 1).zfill(3)}.pt'
    pixelsnail_checkpoint_name = 'pixelsnail_vqvae_imagenet_num_embeddings[512]_neighborhood[1]_selectionFN[vanilla]_size[128]_bottom_420.pt'

    # The VQ-VAE is needed below to decode the sampled code maps back to pixels
    model_vqvae = load_model('vqvae', vqvae_checkpoint_name, device, architecture,
                             num_embeddings, neighborhood, selection_fn, **kwargs)
    # model_top = load_model('pixelsnail_top', args.top, device)
    model_bottom = load_model('pixelsnail_bottom', pixelsnail_checkpoint_name, device,
                              size=size, **kwargs)

    num_samples = 50000
    sampled_directory = os.path.join('sampled_images', pixelsnail_checkpoint_name).replace('.pt', '')
    if os.path.exists(sampled_directory):
        shutil.rmtree(sampled_directory)
    os.makedirs(sampled_directory)

    for sample_ind in tqdm(range(num_samples),
                           'Sampling image for: {}'.format(pixelsnail_checkpoint_name)):
        # top_sample = sample_model(model_top, device, args.batch, [32, 32], args.temp)
        bottom_sample = sample_model(
            model_bottom, device, pixelsnail_batch, [size // 4, size // 4], temp, condition=None
            # model_bottom, device, args.batch, [64, 64], args.temp, condition=top_sample
        )

        # Decode the sampled bottom codes back to an image; if the VQ-VAE is wrapped
        # in DataParallel, use model_vqvae._modules['module'].decode_code(...) instead
        decoded_sample = model_vqvae.decode_code(bottom_sample)
        # decoded_sample = model_vqvae.decode_code(top_sample, bottom_sample)  # two-level variant
        decoded_sample = decoded_sample.clamp(-1, 1)

        filename = 'sampled_{}.png'.format(sample_ind)
        target_path = os.path.join(sampled_directory, filename)
        save_image(decoded_sample, target_path, normalize=True, range=(-1, 1))
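# `sample_model` is defined elsewhere in the repository. For reference, a minimal
# sketch of an autoregressive PixelSNAIL sampler consistent with the call above
# (model, device, batch, [height, width], temperature, condition); the exact
# forward signature and the cache handling are assumptions, not the repository's
# confirmed implementation.
@torch.no_grad()
def sample_model_sketch(model, device, batch, size, temperature, condition=None):
    row = torch.zeros(batch, *size, dtype=torch.int64, device=device)
    cache = {}
    # Sample one code index at a time, conditioning on everything above/left of it
    for i in range(size[0]):
        for j in range(size[1]):
            out, cache = model(row[:, : i + 1, :], condition=condition, cache=cache)
            prob = torch.softmax(out[:, :, i, j] / temperature, 1)
            row[:, i, j] = torch.multinomial(prob, 1).squeeze(-1)
    return row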
def create_extraction_run(size, device, dataset, data_path, num_workers, num_embeddings,
                          architecture, ckpt_epoch, neighborhood, selection_fn, embed_dim, **kwargs):
    train_dataset, test_dataset = get_dataset(dataset, data_path, size)

    print('Creating named datasets')
    # We don't really use the "Named" part, but I'm keeping it to stay close to the
    # original code repository
    train_named_dataset = NamedDataset(train_dataset)
    test_named_dataset = NamedDataset(test_dataset)

    print('Creating data loaders')
    train_loader = DataLoader(train_named_dataset, batch_size=kwargs['vae_batch'],
                              shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_named_dataset, batch_size=kwargs['vae_batch'],
                             shuffle=False, num_workers=num_workers)

    # This is still the VQ-VAE experiment name and path
    experiment_name = util_funcs.create_experiment_name(
        architecture, dataset, num_embeddings, neighborhood, selection_fn, size, **kwargs)
    checkpoint_name = util_funcs.create_checkpoint_name(experiment_name, ckpt_epoch)
    checkpoint_path = f'checkpoint/{checkpoint_name}'

    print('Loading model')
    model = get_model(architecture, num_embeddings, device, neighborhood, selection_fn,
                      embed_dim, parallel=False, **kwargs)
    model.load_state_dict(torch.load(checkpoint_path), strict=False)
    model = model.to(device)
    model.eval()

    print('Creating LMDB DBs')
    map_size = 100 * 1024 * 1024 * 1024  # Maximum size the databases may grow to
    # The DB name consists of the experiment name and the epoch the codes are taken from
    db_name = checkpoint_name[:-3] + '_dataset[{}]'.format(dataset)
    train_env = lmdb.open(os.path.join('codes', 'train_codes', db_name),
                          map_size=map_size)  # Will hold the encodings of train samples
    test_env = lmdb.open(os.path.join('codes', 'test_codes', db_name),
                         map_size=map_size)  # Will hold the encodings of test samples

    print('Extracting')
    if architecture == 'vqvae':
        extract(train_env, train_loader, model, device, 'train')
        extract(test_env, test_loader, model, device, 'test')
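# `extract` is defined elsewhere. A minimal sketch, assuming the NamedDataset
# loader yields (img, label, filename), the model exposes an `encode` that
# returns bottom-level code indices, and records are stored as pickled
# CodeRow tuples keyed by a running index (the CodeRow name and the encode()
# return layout are assumptions):
import pickle
from collections import namedtuple

CodeRow = namedtuple('CodeRow', ['bottom', 'filename'])

def extract_sketch(lmdb_env, loader, model, device, split):
    index = 0
    with lmdb_env.begin(write=True) as txn:
        for img, _, filename in tqdm(loader, desc='Extracting {} codes'.format(split)):
            img = img.to(device)
            with torch.no_grad():
                _, _, _, id_b = model.encode(img)
            id_b = id_b.detach().cpu().numpy()
            for file, bottom in zip(filename, id_b):
                row = CodeRow(bottom=bottom, filename=file)
                txn.put(str(index).encode('utf-8'), pickle.dumps(row))
                index += 1
        # Store the dataset length so readers know how many records exist
        txn.put('length'.encode('utf-8'), str(index).encode('utf-8'))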
def get_checkpoint_names(architecture, ckpt_epoch, dataset, hier, kwargs, neighborhood,
                         num_embeddings, pixelsnail_ckpt_epoch, selection_fn, size):
    experiment_name = util_funcs.create_experiment_name(
        architecture, dataset, num_embeddings, neighborhood, selection_fn, size, **kwargs)
    vqvae_checkpoint_name = util_funcs.create_checkpoint_name(experiment_name, ckpt_epoch)
    pixelsnail_checkpoint_name = f'pixelsnail_{experiment_name}_{hier}_{str(pixelsnail_ckpt_epoch + 1).zfill(3)}.pt'
    return pixelsnail_checkpoint_name, vqvae_checkpoint_name
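# Illustrative usage (hypothetical values): with hier='bottom' and
# pixelsnail_ckpt_epoch=419, the f-string above yields a name ending in
# '_bottom_420.pt', since the stored epoch is 1-based and zero-padded to
# three digits, e.g.
#   'pixelsnail_vqvae_imagenet_num_embeddings[512]_..._bottom_420.pt'
# which matches the hardcoded checkpoint used in the sampling run above.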
def create_run(architecture, dataset, num_embeddings, num_workers, selection_fn, neighborhood,
               device, embed_dim, size, **kwargs):
    global args, scheduler

    # Get VQ-VAE experiment name
    experiment_name = util_funcs.create_experiment_name(
        architecture, dataset, num_embeddings, neighborhood, selection_fn=selection_fn,
        size=size, **kwargs)

    # Prepare logger
    writer = SummaryWriter(
        os.path.join('runs', 'pixelsnail_' + experiment_name + '2', str(datetime.datetime.now())))

    # Load datasets
    test_loader, train_loader = load_datasets(args, experiment_name, num_workers, dataset)

    # Create model and optimizer
    model, optimizer = prepare_model_parts(train_loader)

    # Get checkpoint path for the underlying VQ-VAE model
    checkpoint_name = util_funcs.create_checkpoint_name(experiment_name, kwargs['ckpt_epoch'])
    checkpoint_path = f'checkpoint/{checkpoint_name}'

    # Load the underlying VQ-VAE model for logging purposes
    vqvae_model = get_model(architecture, num_embeddings, device, neighborhood, selection_fn,
                            embed_dim, parallel=False, **kwargs)
    vqvae_model.load_state_dict(torch.load(checkpoint_path), strict=False)
    vqvae_model = vqvae_model.to(args.device)
    vqvae_model.eval()

    # Train model
    train_coefficients_loss, train_num_nonzeros_loss, train_atom_loss, train_losses, \
        test_coefficients_loss, test_num_nonzeros_loss, test_atom_loss, test_losses = \
        run_train(args, experiment_name, model, optimizer, scheduler, test_loader, train_loader,
                  writer, vqvae_model)

    return train_coefficients_loss, train_num_nonzeros_loss, train_atom_loss, train_losses, \
        test_coefficients_loss, test_num_nonzeros_loss, test_atom_loss, test_losses
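# `prepare_model_parts` and `run_train` are defined elsewhere. As a reference
# point, a single PixelSNAIL training step over extracted bottom codes typically
# reduces to next-code cross-entropy; this is a hedged sketch, not the
# repository's confirmed loop (the logits layout (B, n_embed, H, W) is an
# assumption):
import torch.nn.functional as F

def train_step_sketch(model, optimizer, scheduler, bottom, device):
    model.zero_grad()
    bottom = bottom.to(device)
    out, _ = model(bottom)               # logits over the code vocabulary
    loss = F.cross_entropy(out, bottom)  # autoregressive NLL of the true codes
    loss.backward()
    if scheduler is not None:
        scheduler.step()
    optimizer.step()
    accuracy = (out.argmax(1) == bottom).float().mean()
    return loss.item(), accuracy.item()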
def load_datasets(args, experiment_name, num_workers, dataset):
    db_name = util_funcs.create_checkpoint_name(
        experiment_name, args.ckpt_epoch)[:-3] + '_dataset[{}]'.format(dataset)
    train_dataset = LMDBDataset(os.path.join('..', 'codes', 'train_codes', db_name),
                                args.architecture)
    test_dataset = LMDBDataset(os.path.join('..', 'codes', 'test_codes', db_name),
                               args.architecture)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True,
                              num_workers=num_workers, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True,
                             num_workers=num_workers, drop_last=True)
    return test_loader, train_loader
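# `LMDBDataset` is defined elsewhere. A minimal sketch consistent with the
# extraction sketch above (pickled records keyed '0', '1', ... plus a 'length'
# key); returning (filename, code) matches batch[1] being a code map in the
# decode run below. The record layout is an assumption:
import pickle
from torch.utils.data import Dataset

class LMDBDatasetSketch(Dataset):
    def __init__(self, path, architecture):
        self.env = lmdb.open(path, readonly=True, lock=False,
                             readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        with self.env.begin(write=False) as txn:
            row = pickle.loads(txn.get(str(index).encode('utf-8')))
        return row.filename, torch.from_numpy(row.bottom)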
def create_run(architecture, dataset, num_embeddings, num_workers, selection_fn, neighborhood,
               device, size, ckpt_epoch, embed_dim, **kwargs):
    global args, scheduler

    print('Creating data loaders')
    experiment_name = util_funcs.create_experiment_name(
        architecture, dataset, num_embeddings, neighborhood, selection_fn, size, **kwargs)
    checkpoint_name = util_funcs.create_checkpoint_name(experiment_name, ckpt_epoch)
    checkpoint_path = f'checkpoint/{checkpoint_name}'
    test_loader, train_loader = load_datasets(args, experiment_name, num_workers, dataset)

    print('Loading model')
    model = get_model(architecture, num_embeddings, device, neighborhood, selection_fn,
                      embed_dim, parallel=False, **kwargs)
    model.load_state_dict(torch.load(os.path.join('..', checkpoint_path)), strict=False)
    model = model.to(device)
    model.eval()

    # Debugging helper: decode the codes of a single batch back to pixel space,
    # save the result, and stop
    for batch in train_loader:
        print('decoding')
        X = model.decode_code(batch[1].to(next(model.parameters()).device))
        print('decoded')
        utils.save_image(X, 'X_img.png', nrow=1, normalize=True, range=(-1, 1))
        exit()
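# `decode_code` on the VQ-VAE maps integer code indices back to pixel space.
# A hedged sketch, assuming a bottom-level codebook with an `embed_code`
# lookup and a convolutional decoder `dec` (both attribute names are
# assumptions):
def decode_code_sketch(self, code_b):
    quant_b = self.quantize_b.embed_code(code_b)  # (B, H, W, C) embedding lookup
    quant_b = quant_b.permute(0, 3, 1, 2)         # to (B, C, H, W) for the decoder
    return self.dec(quant_b)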
def get_PSNR(size, device, dataset, data_path, num_workers, num_embeddings, architecture,
             ckpt_epoch, neighborhood, selection_fn, embed_dim, **kwargs):
    print('Setting up dataset')
    _, test_dataset = get_dataset(dataset, data_path, size)

    print('Creating data loaders')
    test_loader = DataLoader(test_dataset, batch_size=kwargs['vae_batch'], shuffle=True,
                             num_workers=num_workers)

    experiment_name = util_funcs.create_experiment_name(
        architecture, dataset, num_embeddings, neighborhood, selection_fn, size, **kwargs)
    checkpoint_name = util_funcs.create_checkpoint_name(experiment_name, ckpt_epoch)
    checkpoint_path = f"{kwargs['checkpoint_path']}/{checkpoint_name}"
    _validate_args(checkpoint_path)
    print('Calculating PSNR for: {}'.format(checkpoint_name))

    print('Loading model')
    model = get_model(architecture, num_embeddings, device, neighborhood, selection_fn,
                      embed_dim, parallel=False, **kwargs)
    model.load_state_dict(torch.load(os.path.join('..', checkpoint_path), map_location='cuda:0'),
                          strict=False)
    model = model.to(device)
    model.eval()

    mse = nn.MSELoss()
    MAX_i = 255
    to_MAX = lambda t: (t + 1) * MAX_i / 2  # map [-1, 1] images to [0, 255]
    psnr_term = 20 * torch.log10(torch.ones(1, 1) * MAX_i)

    # Calculate PSNR over the test set
    sparsity = 0
    top_percentiles = 0
    num_zeros = 0
    psnrs = 0
    done = 0
    for batch in tqdm(test_loader, desc='Calculating PSNR'):
        with torch.no_grad():
            img = batch[0].to(device)
            out, _, num_quantization_steps, mean_D, mean_Z, norm_Z, top_percentile, \
                batch_num_zeros = model(img)

            # Save the first test batch (psnrs is still 0 before the first accumulation)
            if psnrs == 0 and os.path.isdir(args.sample_save_path):
                _save_tensors(experiment_name, img, out)

            cur_psnr = psnr_term.item() - 10 * torch.log10(mse(to_MAX(out), to_MAX(img)))

            # Gather data
            psnrs += cur_psnr
            sparsity += norm_Z.mean()
            top_percentiles += top_percentile.mean()
            num_zeros += batch_num_zeros.mean()  # model output renamed so it no longer shadows the accumulator
            done += 1

    # Dump results
    print('sparsity: {}'.format(sparsity))
    print('done: {}'.format(done))
    print('(sparsity/float(done)).item(): {}'.format((sparsity / float(done)).item()))

    avg_psnr = (psnrs / float(done)).item()
    avg_top_percentiles = (top_percentiles / float(done)).item()
    avg_num_zeros = (num_zeros / float(done)).item()
    avg_sparsity = (sparsity / float(done)).item()

    print('#' * 30)
    print('Experiment name: {}'.format(experiment_name))
    print('Epoch name: {}'.format(ckpt_epoch))
    print('avg_psnr: {}'.format(avg_psnr))
    print('is_quantize_coefs: {}'.format(kwargs['is_quantize_coefs']))

    # Dump params and stats into a JSON file
    _save_root = os.path.join(args.sample_save_path, experiment_name)
    os.makedirs(_save_root, exist_ok=True)
    with open(f'{_save_root}/eval_result_{_NOW}.json', 'w') as fp:
        json.dump(
            dict(
                # --- params
                dataset=args.dataset,
                selection_fn=args.selection_fn,
                embed_dim=args.embed_dim,
                num_atoms=args.num_embeddings,
                image_size=args.size,
                batch_size=args.vae_batch,
                num_strides=args.num_strides,
                num_nonzero=args.num_nonzero,
                normalize_x=args.normalize_x,
                normalize_d=args.normalize_dict,
                epoch=args.ckpt_epoch,
                # --- stats
                psnr=avg_psnr,
                compression=None,  # todo - calculate this
                atom_bits=None,  # todo - calculate this
                non_zero_mean=avg_sparsity,
                non_zero_99pct=avg_top_percentiles,
            ),
            fp,
            indent=1)
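# The computation above follows the standard definition
#   PSNR = 20*log10(MAX_I) - 10*log10(MSE)
# after mapping images from [-1, 1] to [0, 255]. A self-contained sanity check
# with synthetic tensors (illustrative values only):
import torch

img = torch.rand(1, 3, 8, 8) * 2 - 1  # fake image in [-1, 1]
out = (img + 0.1).clamp(-1, 1)        # fake reconstruction
to_255 = lambda t: (t + 1) * 255 / 2
mse_val = torch.mean((to_255(out) - to_255(img)) ** 2)
psnr = 20 * torch.log10(torch.tensor(255.0)) - 10 * torch.log10(mse_val)
print(psnr.item())  # higher is better; identical tensors would give infinity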