Ejemplo n.º 1
0
def main():
    """Inspect a trained VQ-VAE from the command line.

    The positional argument ``model`` is the VQ-VAE checkpoint handed to
    ``load_vqvae``. Any combination of the following actions may be
    requested:

    * ``-r/--reconstruction`` -- reconstruct the first ``--num-samples``
      items of the second dataset returned by
      ``load_data_and_data_loaders`` (presumably the validation split);
    * ``-u/--uniform-sample`` -- call ``uniform_sample`` for
      ``--num-samples`` samples;
    * ``-p/--pixelcnn PATH`` -- load a PixelCNN checkpoint (a dict with
      ``'model'`` and ``'config'`` entries), sample ``--num-samples``
      latent codes with it and decode them with the VQ-VAE.

    Device selection: ``-g/--gpu`` forces CUDA, ``--no-gpu`` forces CPU,
    and with neither flag CUDA is used when ``torch.cuda.is_available()``.
    When ``--plot_path`` is given, that directory is created and image
    paths inside it are passed to the plotting helpers.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('model')
    parser.add_argument('-r', '--reconstruction', action='store_true')
    parser.add_argument('-u', '--uniform-sample', action='store_true')
    parser.add_argument('-p', '--pixelcnn', type=str, default=None)
    # --gpu and --no-gpu must write to the same destination. Otherwise
    # --no-gpu lands in an unused ``args.no_gpu`` and has no effect.
    # A value of None means "auto-detect".
    parser.add_argument('-g', '--gpu', dest='gpu', action='store_true',
                        default=None)
    parser.add_argument('--no-gpu', dest='gpu', action='store_false',
                        default=None)
    parser.add_argument('-n', '--num-samples', type=int, default=16)
    parser.add_argument('--plot_path', type=str, default=None)
    args = parser.parse_args()

    recon_path = None
    sample_path = None
    pixelcnn_sample_path = None
    if args.plot_path:
        os.makedirs(args.plot_path, exist_ok=True)
        recon_path = os.path.join(args.plot_path, 'recon.png')
        sample_path = os.path.join(args.plot_path, 'sample.png')
        pixelcnn_sample_path = os.path.join(args.plot_path,
                                            'pixelcnn_sample.png')

    use_gpu = args.gpu
    if use_gpu is None:
        use_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if use_gpu else 'cpu')

    vqvae, config = load_vqvae(args.model, device)
    params = config["hyperparameters"]

    print(f"Loaded model {args.model}")

    if args.reconstruction:
        _, data, _, _, _ = load_data_and_data_loaders(
            params['dataset'], params['batch_size'])
        # Each item is indexed with [0]; presumably (image, label) pairs.
        reconstruct(vqvae, [data[i][0] for i in range(args.num_samples)],
                    device,
                    plot_path=recon_path)

    if args.uniform_sample:
        uniform_sample(vqvae, args.num_samples, device, plot_path=sample_path)

    if args.pixelcnn:
        # map_location lets a checkpoint saved on a GPU be loaded when
        # running on CPU (and vice versa).
        ckpt = torch.load(args.pixelcnn, map_location=device)
        pixelcnn = PixelCNN(ckpt['config']).to(device)
        pixelcnn.load_state_dict(ckpt['model'])
        # Inference only: disable any training-mode layer behaviour.
        pixelcnn.eval()
        # Probe the VQ-VAE encoder with a dummy batch to learn the latent
        # code shape. No gradients are needed for this.
        # NOTE(review): assumes 3-channel 32x32 inputs; confirm this
        # matches params['dataset'].
        with torch.no_grad():
            code_shape = vqvae.encode(torch.zeros((1, 3, 32, 32),
                                                  device=device)).shape
        code = pixelcnn.sample(code_shape, args.num_samples, device=device)
        # NOTE(review): the title is set on the current pyplot figure
        # before decode(); confirm decode() draws on that same figure.
        if not pixelcnn_sample_path:
            plt.title('PixelCNN decode')
        decode(vqvae, code, plot_path=pixelcnn_sample_path)
Ejemplo n.º 2
0
def encode(model_path: Union[str, Path], output_path: Union[str, Path]):
    """Encode a VQ-VAE's dataset splits into ``.npz`` files.

    Loads the VQ-VAE at *model_path* with ``load_vqvae``, reads the
    dataset name from the checkpoint's ``hyperparameters`` and builds the
    data loaders with a batch size of 128. Then, via
    ``encode_from_loader``:

    * the validation loader is written to
      ``<output_path>/test/encoded_<dataset>.npz``;
    * the training loader is written to
      ``<output_path>/train/encoded_<dataset>.npz``.

    CUDA is used when available, otherwise the CPU.
    """
    model_file = Path(model_path)
    out_dir = Path(output_path)

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    vqvae, config = load_vqvae(model_file, device)
    dataset = config["hyperparameters"]['dataset']

    _, _, train_loader, val_loader, _ = load_data_and_data_loaders(
        dataset, 128)

    # The validation split is stored under "test"; it is processed first.
    for split, loader in (("test", val_loader), ("train", train_loader)):
        print(f"Encoding {dataset} {split}...")
        target = out_dir / split / f'encoded_{dataset}.npz'
        encode_from_loader(vqvae, loader, target, device)
Ejemplo n.º 3
0
# Whether or not to save the model (announced by the message printed below).
parser.add_argument("-save", action="store_true")
# Base name of the saved results file; defaults to `timestamp`, which is
# defined earlier in the script.
parser.add_argument("--filename", type=str, default=timestamp)

args = parser.parse_args()

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

if args.save:
    print('Results will be saved in ./results/vqvae_' + args.filename + '.pth')

# ---- Load data and define batch data loaders ------------------------------
(training_data, validation_data, training_loader, validation_loader,
 x_train_var) = utils.load_data_and_data_loaders(args.dataset,
                                                 args.batch_size)

# ---- VQ-VAE model built from the components in ./models/ ------------------
model = VQVAE(
    args.n_hiddens,
    args.n_residual_hiddens,
    args.n_residual_layers,
    args.n_embeddings,
    args.embedding_dim,
    args.beta,
).to(device)

# ---- Optimizer and training mode -------------------------------------------
optimizer = optim.Adam(model.parameters(),
                       lr=args.learning_rate,
                       amsgrad=True)

model.train()

results = {
    'n_updates': 0,
Ejemplo n.º 4
0
                    help='1 for grayscale 3 for rgb')
# Codebook size of the VQ-VAE whose latent codes are modelled here.
parser.add_argument("--n_embeddings", type=int, default=512,
                    help='number of embeddings from VQ VAE')
# Layer count and learning rate (consumed further down the script).
parser.add_argument("--n_layers", type=int, default=15)
parser.add_argument("--learning_rate", type=float, default=3e-4)

args = parser.parse_args()

# Prefer the GPU whenever torch can see one.
device = (torch.device("cuda") if torch.cuda.is_available()
          else torch.device("cpu"))

# ---- data loaders -----------------------------------------------------------
if args.dataset == 'LATENT_BLOCK':
    _, _, train_loader, test_loader, _ = utils.load_data_and_data_loaders(
        'LATENT_BLOCK', args.batch_size)
else:
    train_loader = torch.utils.data.DataLoader(
        eval('datasets.' + args.dataset)(
            '../data/{}/'.format(args.dataset),
            train=True,
            download=True,
            transform=transforms.ToTensor(),
        ),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True)
    test_loader = torch.utils.data.DataLoader(eval('datasets.' + args.dataset)(
        '../data/{}/'.format(args.dataset),
        train=False,