def main():
    """CLI entry point: load a trained VQ-VAE checkpoint and optionally run
    reconstruction, uniform codebook sampling, and/or PixelCNN-prior sampling,
    saving plots under --plot_path when given (otherwise showing interactively).
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('model')
    parser.add_argument('-r', '--reconstruction', action='store_true')
    parser.add_argument('-u', '--uniform-sample', action='store_true')
    parser.add_argument('-p', '--pixelcnn', type=str, default=None)
    # -g/--gpu forces CUDA on, --no-gpu forces it off; when neither flag is
    # given the shared 'gpu' dest stays None and we auto-detect below.
    parser.add_argument('-g', '--gpu', action='store_true', default=None)
    # BUG FIX: --no-gpu previously wrote to its own implicit 'no_gpu' dest,
    # which nothing ever read, so the flag had no effect. Route it to 'gpu'.
    parser.add_argument('--no-gpu', dest='gpu', action='store_false',
                        default=None)
    parser.add_argument('-n', '--num-samples', type=int, default=16)
    parser.add_argument('--plot_path', type=str, default=None)
    args = parser.parse_args()

    recon_path = None
    sample_path = None
    pixelcnn_sample_path = None
    if args.plot_path:
        os.makedirs(args.plot_path, exist_ok=True)
        recon_path = os.path.join(args.plot_path, 'recon.png')
        sample_path = os.path.join(args.plot_path, 'sample.png')
        pixelcnn_sample_path = os.path.join(args.plot_path,
                                            'pixelcnn_sample.png')

    use_gpu = args.gpu
    if use_gpu is None:
        # No explicit flag: fall back to hardware detection.
        use_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if use_gpu else 'cpu')

    vqvae, config = load_vqvae(args.model, device)
    params = config["hyperparameters"]
    print(f"Loaded model {args.model}")

    if args.reconstruction:
        # The dataset is only needed for reconstruction, so load it lazily
        # here (the original guarded this with an always-true `data is None`
        # check — data was never assigned elsewhere).
        _, data, _, _, _ = load_data_and_data_loaders(
            params['dataset'], params['batch_size'])
        reconstruct(vqvae, [data[i][0] for i in range(args.num_samples)],
                    device, plot_path=recon_path)

    if args.uniform_sample:
        uniform_sample(vqvae, args.num_samples, device, plot_path=sample_path)

    if args.pixelcnn:
        ckpt = torch.load(args.pixelcnn)
        pixelcnn = PixelCNN(ckpt['config']).to(device)
        pixelcnn.load_state_dict(ckpt['model'])
        # Probe the encoder with a dummy image to learn the latent-code shape.
        # NOTE(review): assumes the VQ-VAE was trained on 3x32x32 inputs —
        # confirm against the training config.
        code_shape = vqvae.encode(
            torch.zeros((1, 3, 32, 32), device=device)).shape
        code = pixelcnn.sample(code_shape, args.num_samples, device=device)
        if not pixelcnn_sample_path:
            # Interactive display only: give the figure a title before decode
            # renders it (saved plots skip this, matching original behavior).
            plt.title('PixelCNN decode')
        decode(vqvae, code, plot_path=pixelcnn_sample_path)
def encode(model_path: Union[str, Path], output_path: Union[str, Path]):
    """Encode the train and test splits of a VQ-VAE's dataset to .npz files.

    Loads the checkpoint at *model_path*, builds the data loaders for the
    dataset it was trained on, and writes encoded_{dataset}.npz under
    output_path/test and output_path/train.
    """
    model_path = Path(model_path)
    output_path = Path(output_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vqvae, config = load_vqvae(model_path, device)
    dataset = config["hyperparameters"]['dataset']

    _, _, training_loader, validation_loader, _ = load_data_and_data_loaders(
        dataset, 128)

    # test split first, then train — same order as before.
    for split, loader in (('test', validation_loader),
                          ('train', training_loader)):
        print(f"Encoding {dataset} {split}...")
        encode_from_loader(vqvae, loader,
                           output_path / split / f'encoded_{dataset}.npz',
                           device)
# whether or not to save model parser.add_argument("-save", action="store_true") parser.add_argument("--filename", type=str, default=timestamp) args = parser.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if args.save: print('Results will be saved in ./results/vqvae_' + args.filename + '.pth') """ Load data and define batch data loaders """ training_data, validation_data, training_loader, validation_loader, x_train_var = utils.load_data_and_data_loaders( args.dataset, args.batch_size) """ Set up VQ-VAE model with components defined in ./models/ folder """ model = VQVAE(args.n_hiddens, args.n_residual_hiddens, args.n_residual_layers, args.n_embeddings, args.embedding_dim, args.beta).to(device) """ Set up optimizer and training loop """ optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, amsgrad=True) model.train() results = { 'n_updates': 0,
                    help='1 for grayscale 3 for rgb')
parser.add_argument("--n_embeddings", type=int, default=512,
                    help='number of embeddings from VQ VAE')
parser.add_argument("--n_layers", type=int, default=15)
parser.add_argument("--learning_rate", type=float, default=3e-4)
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

"""
data loaders
"""
if args.dataset == 'LATENT_BLOCK':
    # LATENT_BLOCK: pre-encoded VQ-VAE latents loaded via the project utils
    _, _, train_loader, test_loader, _ = utils.load_data_and_data_loaders(
        'LATENT_BLOCK', args.batch_size)
else:
    # NOTE(review): eval() on a CLI-supplied dataset name resolves a
    # torchvision datasets class by string — acceptable for a trusted
    # training script, but unsafe on untrusted input; a dict mapping names
    # to dataset classes would be safer.
    train_loader = torch.utils.data.DataLoader(
        eval('datasets.' + args.dataset)(
            '../data/{}/'.format(args.dataset), train=True, download=True,
            transform=transforms.ToTensor(),
        ),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)
    # test loader continues beyond this chunk
    test_loader = torch.utils.data.DataLoader(eval('datasets.' + args.dataset)(
        '../data/{}/'.format(args.dataset), train=False,