Example #1
0
def main(args):
    """Train a PixelCNN or PixelSNAIL prior on one level of an LMDB dataset.

    Args:
        args: parsed command-line namespace. Read here: ``dataset_path``,
            ``level``, ``batch_size`` and ``use_model``. The whole namespace
            is also passed to the model constructor and to
            ``pl.Trainer.from_argparse_args``. ``args.num_embeddings`` is set
            from the data module before the model is built.

    Raises:
        ValueError: if ``args.use_model`` is neither ``'pixelcnn'`` nor
            ``'pixelsnail'`` (checked before any data is loaded).
    """
    model_classes = {'pixelcnn': PixelCNN, 'pixelsnail': PixelSNAIL}
    model_cls = model_classes.get(args.use_model)
    if model_cls is None:
        # Previously an unknown value only surfaced later as an
        # UnboundLocalError on `model`.
        raise ValueError('Unknown model: {}'.format(args.use_model))

    torch.cuda.empty_cache()
    pl.trainer.seed_everything(seed=42)

    datamodule = LMDBDataModule(
        path=args.dataset_path,
        embedding_id=args.level,
        batch_size=args.batch_size,
        num_workers=5,
    )

    # setup() must run before num_embeddings can be read. It is stored on
    # args, which (presumably) the model constructor consumes.
    datamodule.setup()
    args.num_embeddings = datamodule.num_embeddings

    model = model_cls(args)

    # Keep the best checkpoint by validation loss plus the most recent one.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(save_top_k=1,
                                                       save_last=True,
                                                       monitor='val_loss_mean')
    trainer = pl.Trainer.from_argparse_args(args,
                                            callbacks=[checkpoint_callback])
    trainer.fit(model, datamodule=datamodule)
Example #2
0
def main():
    """Sample images from a trained PixelCNN and save them to disk.

    Builds a PixelCNN from command-line hyper-parameters and loads its
    weights from ``--model-path``. It then draws ``--count`` samples of shape
    ``(3, --height, --width)`` (optionally conditioned on ``--label``) and
    writes them with ``save_samples`` into ``OUTPUT_DIRNAME`` under the name
    ``--output-fname``.
    """
    parser = argparse.ArgumentParser(description='PixelCNN')

    parser.add_argument('--causal-ksize', type=int, default=7,
                        help='Kernel size of causal convolution')
    parser.add_argument('--hidden-ksize', type=int, default=7,
                        help='Kernel size of hidden layers convolutions')

    parser.add_argument('--color-levels', type=int, default=2,
                        help='Number of levels to quantisize value of each channel of each pixel into')

    parser.add_argument('--hidden-fmaps', type=int, default=30,
                        help='Number of feature maps in hidden layer')
    parser.add_argument('--out-hidden-fmaps', type=int, default=10,
                        help='Number of feature maps in outer hidden layer')
    parser.add_argument('--hidden-layers', type=int, default=6,
                        help='Number of layers of gated convolutions with mask of type "B"')

    parser.add_argument('--cuda', type=str2bool, default=True,
                        help='Flag indicating whether CUDA should be used')
    # Required: without weights there is nothing meaningful to sample, and a
    # missing path would otherwise only fail later inside torch.load(None).
    parser.add_argument('--model-path', '-m', required=True,
                        help="Path to model's saved parameters")
    parser.add_argument('--output-fname', type=str, default='samples.png',
                        help='Name of output file (.png format)')

    parser.add_argument('--label', '--l', type=int, default=-1,
                        help='Label of sampled images. -1 indicates random labels.')

    parser.add_argument('--count', '-c', type=int, default=64,
                        help='Number of images to generate')
    parser.add_argument('--height', type=int, default=28, help='Output image height')
    parser.add_argument('--width', type=int, default=28, help='Output image width')

    cfg = parser.parse_args()
    OUTPUT_FILENAME = cfg.output_fname

    model = PixelCNN(cfg=cfg)
    model.eval()

    # Fall back to CPU when CUDA is unavailable or disabled with --cuda.
    device = torch.device("cuda" if torch.cuda.is_available() and cfg.cuda else "cpu")
    model.to(device)

    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
    # host (the CPU fallback above) instead of failing inside torch.load.
    model.load_state_dict(torch.load(cfg.model_path, map_location=device))

    # -1 on the command line means "no fixed label".
    label = None if cfg.label == -1 else cfg.label
    samples = model.sample((3, cfg.height, cfg.width), cfg.count, label=label, device=device)
    save_samples(samples, OUTPUT_DIRNAME, OUTPUT_FILENAME)
    def __init__(self, params):
        """Build the fully-connected encoder and the PixelCNN decoder.

        Args:
            params: dict-like of hyper-parameters. Keys read here (defaults in
                parentheses): ``latent_dim`` (2), ``hdim`` (400) and
                ``batchnorm`` (True). The whole object is also passed to
                ``PixelCNN``.
        """
        super(PixelVAE, self).__init__()
        self.model_str = 'PixelVAE'
        self.is_cuda = False

        self.latent_dim = latent_dim = params.get('latent_dim', 2)
        self.hdim = hdim = params.get('hdim', 400)
        self.batchnorm = params.get('batchnorm', True)

        # encoder
        # Input is 784 features (flattened 28x28 image).
        self.fc1 = fc(784, hdim)
        # NOTE(review): if BN is torch.nn.BatchNorm1d, momentum=.9 weights the
        # *new* batch statistics by 0.9 (PyTorch convention, the opposite of
        # Keras). Confirm this is intended.
        if self.batchnorm:
            self.bn_1 = BN(hdim, momentum=.9)
        self.fc_mu = fc(hdim, latent_dim)  # output the mean of z
        if self.batchnorm:
            self.bn_mu = BN(latent_dim, momentum=.9)
        self.fc_logvar = fc(hdim,
                            latent_dim)  # output the log of the variance of z
        if self.batchnorm:
            self.bn_logvar = BN(latent_dim, momentum=.9)

        # decoder
        self.pixelcnn = PixelCNN(params)
Example #4
0
def parse_arguments():
    """Build the CLI for training a PixelCNN/PixelSNAIL prior and parse it.

    Parsing happens in two passes. ``--use-model`` is read first (ignoring
    all other arguments) so that only the chosen model's
    ``add_model_specific_args`` options are registered before the full parse.

    Returns:
        argparse.Namespace with the Trainer options, the model options,
        ``dataset_path`` (resolved absolute path as ``str``), ``level``,
        ``batch_size`` and ``use_model``.

    Exits through ``parser.error`` (usage message, exit status 2) if
    ``dataset_path`` does not exist.
    """
    parser = ArgumentParser()
    parser.add_argument("--use-model",
                        type=str,
                        choices=['pixelcnn', 'pixelsnail'],
                        default='pixelcnn')
    # First pass: only --use-model is needed to decide which model-specific
    # options to register; all other arguments are left for the full parse.
    use_model = parser.parse_known_args()[0].use_model

    if use_model == 'pixelcnn':
        parser = PixelCNN.add_model_specific_args(parser)
    elif use_model == 'pixelsnail':
        parser = PixelSNAIL.add_model_specific_args(parser)

    parser = pl.Trainer.add_argparse_args(parser)

    parser.add_argument("dataset_path", type=Path)
    parser.add_argument("level",
                        type=int,
                        help="Which PixelCNN hierarchy level to train")
    parser.add_argument("--batch-size", type=int)

    parser.set_defaults(gpus="-1",
                        distributed_backend='ddp',
                        benchmark=True,
                        num_sanity_val_steps=0,
                        precision=16,
                        log_every_n_steps=50,
                        val_check_interval=0.5,
                        flush_logs_every_n_steps=100,
                        weights_summary='full',
                        max_epochs=int(5e4))

    args = parser.parse_args()
    args.use_model = use_model

    # Validate with parser.error rather than assert. Asserts are stripped
    # under `python -O`, and parser.error prints usage plus a clear message.
    dataset_path = args.dataset_path.resolve()
    if not dataset_path.exists():
        parser.error('dataset_path does not exist: {}'.format(dataset_path))
    args.dataset_path = str(dataset_path)

    return args
Example #5
0
def main():
    """Variationally train an autoregressive network on an Ising model.

    All configuration comes from the module-level ``args``. Each step:

    * draws ``args.batch_size`` samples from the network (without gradients);
    * scores them with ``loss = log_prob + beta * energy``, where
      ``beta = args.beta * (1 - args.beta_anneal**step)``;
    * updates the network with the REINFORCE estimator
      ``mean((loss - mean(loss)) * log_prob)``.

    Training resumes from the latest ``{out_filename}_save/{step}.state``
    checkpoint if one exists. Statistics are logged every
    ``args.print_step`` steps, checkpoints saved every ``args.save_step``
    steps, and sample images written every ``args.visual_step`` steps.

    Raises:
        ValueError: if ``args.net`` or ``args.optimizer`` is not recognized.
    """
    start_time = time.time()

    init_out_dir()
    if args.clear_checkpoint:
        clear_checkpoint()
    last_step = get_last_checkpoint_step()
    if last_step >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_step))
    else:
        clear_log()
    print_args()

    if args.net == 'made':
        net = MADE(**vars(args))
    elif args.net == 'pixelcnn':
        net = PixelCNN(**vars(args))
    elif args.net == 'bernoulli':
        net = BernoulliMixture(**vars(args))
    else:
        raise ValueError('Unknown net: {}'.format(args.net))
    net.to(args.device)
    my_log('{}\n'.format(net))

    # Only parameters that require gradients are optimized and counted.
    params = [p for p in net.parameters() if p.requires_grad]
    nparams = sum(p.numel() for p in params)
    my_log('Total number of trainable parameters: {}'.format(nparams))
    named_params = list(net.named_parameters())

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, lr=args.lr)
    elif args.optimizer == 'sgdm':
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(params, lr=args.lr, alpha=0.99)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.9, 0.999))
    elif args.optimizer == 'adam0.5':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.5, 0.999))
    else:
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

    if args.lr_schedule:
        # 0.92**80 ~ 1e-3
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=0.92, patience=100, threshold=1e-4, min_lr=1e-6)

    if last_step >= 0:
        # map_location lets a checkpoint written on another device (e.g. a
        # GPU) be resumed on args.device.
        state = torch.load('{}_save/{}.state'.format(args.out_filename,
                                                     last_step),
                           map_location=args.device)
        ignore_param(state['net'], net)
        net.load_state_dict(state['net'])
        if state.get('optimizer'):
            optimizer.load_state_dict(state['optimizer'])
        if args.lr_schedule and state.get('scheduler'):
            scheduler.load_state_dict(state['scheduler'])

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    sample_time = 0
    train_time = 0
    start_time = time.time()
    for step in range(last_step + 1, args.max_step + 1):
        optimizer.zero_grad()

        sample_start_time = time.time()
        with torch.no_grad():
            sample, x_hat = net.sample(args.batch_size)
        assert not sample.requires_grad
        assert not x_hat.requires_grad
        sample_time += time.time() - sample_start_time

        train_start_time = time.time()

        log_prob = net.log_prob(sample)
        # 0.998**9000 ~ 1e-8
        beta = args.beta * (1 - args.beta_anneal**step)
        with torch.no_grad():
            energy = ising.energy(sample, args.ham, args.lattice,
                                  args.boundary)
            loss = log_prob + beta * energy
        assert not energy.requires_grad
        assert not loss.requires_grad
        # REINFORCE with the batch-mean loss as baseline; gradients flow only
        # through log_prob.
        loss_reinforce = torch.mean((loss - loss.mean()) * log_prob)
        loss_reinforce.backward()

        if args.clip_grad:
            nn.utils.clip_grad_norm_(params, args.clip_grad)

        optimizer.step()

        if args.lr_schedule:
            scheduler.step(loss.mean())

        train_time += time.time() - train_start_time

        if args.print_step and step % args.print_step == 0:
            # Per-site quantities: divide by the number of sites L**2.
            free_energy_mean = loss.mean() / args.beta / args.L**2
            free_energy_std = loss.std() / args.beta / args.L**2
            entropy_mean = -log_prob.mean() / args.L**2
            energy_mean = energy.mean() / args.L**2
            mag = sample.mean(dim=0)
            mag_mean = mag.mean()
            mag_sqr_mean = (mag**2).mean()
            if step > 0:
                # Report average time per step since the last print.
                sample_time /= args.print_step
                train_time /= args.print_step
            used_time = time.time() - start_time
            my_log(
                'step = {}, F = {:.8g}, F_std = {:.8g}, S = {:.8g}, E = {:.8g}, M = {:.8g}, Q = {:.8g}, lr = {:.3g}, beta = {:.8g}, sample_time = {:.3f}, train_time = {:.3f}, used_time = {:.3f}'
                .format(
                    step,
                    free_energy_mean.item(),
                    free_energy_std.item(),
                    entropy_mean.item(),
                    energy_mean.item(),
                    mag_mean.item(),
                    mag_sqr_mean.item(),
                    optimizer.param_groups[0]['lr'],
                    beta,
                    sample_time,
                    train_time,
                    used_time,
                ))
            sample_time = 0
            train_time = 0

            if args.save_sample:
                state = {
                    'sample': sample,
                    'x_hat': x_hat,
                    'log_prob': log_prob,
                    'energy': energy,
                    'loss': loss,
                }
                torch.save(state, '{}_save/{}.sample'.format(
                    args.out_filename, step))

        if (args.out_filename and args.save_step
                and step % args.save_step == 0):
            state = {
                'net': net.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            if args.lr_schedule:
                state['scheduler'] = scheduler.state_dict()
            torch.save(state, '{}_save/{}.state'.format(
                args.out_filename, step))

        if (args.out_filename and args.visual_step
                and step % args.visual_step == 0):
            torchvision.utils.save_image(
                sample,
                '{}_img/{}.png'.format(args.out_filename, step),
                nrow=int(np.sqrt(sample.shape[0])),
                padding=0,
                normalize=True)

            if args.print_sample:
                x_hat_np = x_hat.view(x_hat.shape[0], -1).cpu().numpy()
                x_hat_std = np.std(x_hat_np, axis=0).reshape([args.L] * 2)

                x_hat_cov = np.cov(x_hat_np.T)
                x_hat_cov_diag = np.diag(x_hat_cov)
                # Element-wise sqrt over the outer product of variances:
                # np.sqrt is required here (math.sqrt rejects arrays).
                x_hat_corr = x_hat_cov / (
                    np.sqrt(x_hat_cov_diag[:, None] * x_hat_cov_diag[None, :]) +
                    args.epsilon)
                # Strongest correlation of each site with any earlier site.
                x_hat_corr = np.tril(x_hat_corr, -1)
                x_hat_corr = np.max(np.abs(x_hat_corr), axis=1)
                x_hat_corr = x_hat_corr.reshape([args.L] * 2)

                energy_np = energy.cpu().numpy()
                energy_count = np.stack(
                    np.unique(energy_np, return_counts=True)).T

                my_log(
                    '\nsample\n{}\nx_hat\n{}\nlog_prob\n{}\nenergy\n{}\nloss\n{}\nx_hat_std\n{}\nx_hat_corr\n{}\nenergy_count\n{}\n'
                    .format(
                        sample[:args.print_sample, 0],
                        x_hat[:args.print_sample, 0],
                        log_prob[:args.print_sample],
                        energy[:args.print_sample],
                        loss[:args.print_sample],
                        x_hat_std,
                        x_hat_corr,
                        energy_count,
                    ))

            if args.print_grad:
                my_log('grad max_abs min_abs mean std')
                for name, param in named_params:
                    if param.grad is not None:
                        grad = param.grad
                        grad_abs = torch.abs(grad)
                        my_log('{} {:.3g} {:.3g} {:.3g} {:.3g}'.format(
                            name,
                            torch.max(grad_abs).item(),
                            torch.min(grad_abs).item(),
                            torch.mean(grad).item(),
                            torch.std(grad).item(),
                        ))
                    else:
                        my_log('{} None'.format(name))
                my_log('')
Example #6
0
# Path components for this run; get_ham_args_features() presumably returns
# strings describing the Hamiltonian and network features. TODO confirm.
ham_args, features = get_ham_args_features()
# Checkpoint written at training step 10000 ("<out>_save/<step>.state").
# At module level locals() is globals(), so `state_dir` and `args` must be
# module-level names defined earlier in the file.
state_filename = '{state_dir}/{ham_args}/{features}/out{args.out_infix}_save/10000.state'.format(
    **locals())

target_layer = 1
# Used as the first dimension of `sample` below. Presumably the number of
# independent inputs optimized at once; verify against the code that follows.
num_channel = 1
out_dir = '../support/fig/filters/{ham_args}/{features}/layer{target_layer}'.format(
    **locals())

if __name__ == '__main__':
    ensure_dir(out_dir + '/')

    if args.net == 'made':
        net = MADE(**vars(args))
    elif args.net == 'pixelcnn':
        net = PixelCNN(**vars(args))
    else:
        raise ValueError('Unknown net: {}'.format(args.net))
    net.to(args.device)
    print('{}\n'.format(net))

    print(state_filename)
    # map_location allows a GPU-trained checkpoint to load on args.device.
    state = torch.load(state_filename, map_location=args.device)
    net.load_state_dict(state['net'])

    # The input itself is the optimization variable (requires_grad=True),
    # initialized from a standard normal.
    # NOTE(review): `sample` is created on the CPU while `net` lives on
    # args.device. Confirm the (not shown) loop moves it before the forward pass.
    sample = torch.zeros([num_channel, 1, args.L, args.L], requires_grad=True)
    nn.init.normal_(sample)

    # NOTE(review): weight_decay=1 is a strong L2 penalty on the input being
    # optimized. Presumably intentional regularization; confirm.
    optimizer = torch.optim.Adam([sample], lr=1e-3, weight_decay=1)

    start_time = time.time()
class PixelVAE(nn.Module):
    """A simple VAE using BN.

    The encoder is a small fully-connected network over flattened
    784-dimensional inputs. The decoder is a ``PixelCNN`` that receives the
    image (viewed as ``(batch, 1, 28, 28)``) together with the latent code
    ``z``.
    """
    def __init__(self, params):
        """Build the fully-connected encoder and the PixelCNN decoder.

        Args:
            params: dict-like of hyper-parameters. Keys read here (defaults in
                parentheses): ``latent_dim`` (2), ``hdim`` (400) and
                ``batchnorm`` (True). The whole object is also passed to
                ``PixelCNN``.
        """
        super(PixelVAE, self).__init__()
        self.model_str = 'PixelVAE'
        # Kept for backward compatibility. reparameterize() no longer needs
        # it because its noise follows the device and dtype of logvar.
        self.is_cuda = False

        self.latent_dim = latent_dim = params.get('latent_dim', 2)
        self.hdim = hdim = params.get('hdim', 400)
        self.batchnorm = params.get('batchnorm', True)

        # encoder
        # Input is 784 features (flattened 28x28 image).
        self.fc1 = fc(784, hdim)
        # NOTE(review): if BN is torch.nn.BatchNorm1d, momentum=.9 weights the
        # *new* batch statistics by 0.9 (PyTorch convention, the opposite of
        # Keras). Confirm this is intended.
        if self.batchnorm:
            self.bn_1 = BN(hdim, momentum=.9)
        self.fc_mu = fc(hdim, latent_dim)  # output the mean of z
        if self.batchnorm:
            self.bn_mu = BN(latent_dim, momentum=.9)
        self.fc_logvar = fc(hdim,
                            latent_dim)  # output the log of the variance of z
        if self.batchnorm:
            self.bn_logvar = BN(latent_dim, momentum=.9)

        # decoder
        self.pixelcnn = PixelCNN(params)

    def encode(self, x, **kwargs):
        """Flatten ``x`` to ``(batch, -1)`` and return ``(mu, logvar)`` of z."""
        x = x.view(x.size(0), -1)
        h1 = self.fc1(x)
        if self.batchnorm:
            h1 = relu(self.bn_1(h1))
        else:
            h1 = relu(h1)

        mu = self.fc_mu(h1)
        if self.batchnorm:
            mu = self.bn_mu(mu)

        logvar = self.fc_logvar(h1)
        if self.batchnorm:
            logvar = self.bn_logvar(logvar)

        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Return ``mu + exp(logvar / 2) * eps`` in training mode, else ``mu``.

        ``eps`` is standard-normal noise with the shape, dtype and device of
        ``logvar``, so the model keeps working after ``.to(device)``. The
        previous implementation relied on ``self.is_cuda``, which is never
        updated when the module is moved.
        """
        if self.training:
            std = t.exp(.5 * logvar)
            eps = t.randn_like(std)
            return mu + std * eps
        else:
            return mu

    def decode(self, x, z, **kwargs):
        """Run the PixelCNN on ``x`` viewed as ``(batch, 1, 28, 28)`` given z."""
        x = x.view(x.size(0), 1, 28, 28)
        return self.pixelcnn.forward(x, z, **kwargs)

    def forward(self, x, **kwargs):
        """Encode, sample z, decode. Returns ``(decoded, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(x, z, **kwargs), mu, logvar

    def generate(self, x, z, **kwargs):
        """Autoregressively fill channel 0 of a copy of ``x``, pixel by pixel.

        Args:
            x: seed image indexed as ``[batch, channel, row, col]``. It is
                copied, not modified. Every pixel of channel 0 is overwritten
                in raster order (rows, then columns).
            z: latent code passed to the PixelCNN at every step.
            **kwargs: forwarded to ``self.pixelcnn.forward``.

        Returns:
            The generated image. At step ``(i, j)`` the PixelCNN sees the
            image generated so far, i.e. its own earlier outputs.
        """
        generated_pic = x * 1.
        height, width = generated_pic.size(2), generated_pic.size(3)

        # No gradients: generation would otherwise build a graph across
        # height * width forward passes and then mutate tensors it saved.
        with t.no_grad():
            # autoregressive generation using your own outputs as inputs
            for i in range(height):
                for j in range(width):
                    # Feed the partially generated image, not the seed x.
                    xx = self.pixelcnn.forward(generated_pic, z, **kwargs)
                    generated_pic[:, 0, i, j] = xx[:, 0, i,
                                                   j]  # take only the ij-th pixel
        return generated_pic
Example #8
0
def main():
    """Train a PixelCNN from command-line options and save the best weights.

    Trains for ``--epochs`` epochs on ``--dataset`` using Adam with a CyclicLR
    schedule, and calls ``test_and_sample`` after each epoch. That helper
    presumably appends the epoch's test loss to ``losses`` and its parameters
    to ``params``; the indexing below relies on the two lists being aligned.
    The parameters of the epoch with the lowest test loss are saved under
    ``MODEL_PARAMS_OUTPUT_DIR``. If no epoch was run, nothing is saved.
    """
    parser = argparse.ArgumentParser(description='PixelCNN')

    parser.add_argument('--epochs',
                        type=int,
                        default=25,
                        help='Number of epochs to train model for')
    parser.add_argument('--batch-size',
                        type=int,
                        default=32,
                        help='Number of images per mini-batch')
    parser.add_argument(
        '--dataset',
        type=str,
        default='mnist',
        help='Dataset to train model on. Either mnist, fashionmnist or cifar.')

    parser.add_argument('--causal-ksize',
                        type=int,
                        default=7,
                        help='Kernel size of causal convolution')
    parser.add_argument('--hidden-ksize',
                        type=int,
                        default=7,
                        help='Kernel size of hidden layers convolutions')

    parser.add_argument(
        '--color-levels',
        type=int,
        default=2,
        help=
        'Number of levels to quantisize value of each channel of each pixel into'
    )

    parser.add_argument(
        '--hidden-fmaps',
        type=int,
        default=30,
        help='Number of feature maps in hidden layer (must be divisible by 3)')
    parser.add_argument('--out-hidden-fmaps',
                        type=int,
                        default=10,
                        help='Number of feature maps in outer hidden layer')
    parser.add_argument(
        '--hidden-layers',
        type=int,
        default=6,
        help='Number of layers of gated convolutions with mask of type "B"')

    parser.add_argument('--learning-rate',
                        '--lr',
                        type=float,
                        default=0.0001,
                        help='Learning rate of optimizer')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=0.0001,
                        help='Weight decay rate of optimizer')
    parser.add_argument('--max-norm',
                        type=float,
                        default=1.,
                        help='Max norm of the gradients after clipping')

    parser.add_argument('--epoch-samples',
                        type=int,
                        default=25,
                        help='Number of images to sample each epoch')

    parser.add_argument('--cuda',
                        type=str2bool,
                        default=True,
                        help='Flag indicating whether CUDA should be used')

    cfg = parser.parse_args()

    wandb.init(project="PixelCNN")
    wandb.config.update(cfg)
    torch.manual_seed(42)

    EPOCHS = cfg.epochs

    model = PixelCNN(cfg=cfg)

    device = torch.device(
        "cuda" if torch.cuda.is_available() and cfg.cuda else "cpu")
    model.to(device)

    train_loader, test_loader, HEIGHT, WIDTH = get_loaders(
        cfg.dataset, cfg.batch_size, cfg.color_levels, TRAIN_DATASET_ROOT,
        TEST_DATASET_ROOT)

    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.learning_rate,
                           weight_decay=cfg.weight_decay)
    # Learning rate cycles between lr and 10 * lr.
    scheduler = optim.lr_scheduler.CyclicLR(optimizer,
                                            cfg.learning_rate,
                                            10 * cfg.learning_rate,
                                            cycle_momentum=False)

    wandb.watch(model)

    losses = []
    params = []

    for epoch in range(EPOCHS):
        train(cfg, model, device, train_loader, optimizer, scheduler, epoch)
        test_and_sample(cfg, model, device, test_loader, HEIGHT, WIDTH, losses,
                        params, epoch)

    # With --epochs <= 0 there is no "best" epoch; np.argmin would raise on
    # an empty array.
    if not losses:
        print('\nNo epochs were run (--epochs={}); nothing to save.'.format(
            EPOCHS))
        return

    losses_arr = np.array(losses)
    best_idx = int(np.argmin(losses_arr))
    print('\nBest test loss: {}'.format(losses_arr[best_idx]))
    print('Best epoch: {}'.format(best_idx + 1))
    best_params = params[best_idx]

    # makedirs also creates missing parents and tolerates an existing dir.
    os.makedirs(MODEL_PARAMS_OUTPUT_DIR, exist_ok=True)
    MODEL_PARAMS_OUTPUT_FILENAME = '{}_cks{}hks{}cl{}hfm{}ohfm{}hl{}_params.pth'\
        .format(cfg.dataset, cfg.causal_ksize, cfg.hidden_ksize, cfg.color_levels, cfg.hidden_fmaps, cfg.out_hidden_fmaps, cfg.hidden_layers)
    torch.save(
        best_params,
        os.path.join(MODEL_PARAMS_OUTPUT_DIR, MODEL_PARAMS_OUTPUT_FILENAME))
Example #9
0
            print(X.shape)
        X_noncausal_graph = NonCausal(conf, data).get_test_samples_graph()
        tf.reset_default_graph()
        get_metric(X, X_noncausal_graph)
        tf.reset_default_graph()
    elif conf.model == 'evaluate':
        data = Dataset(conf)
        test_data = data.get_plain_test_values()
        with tf.Session() as sess:
            samples = []
            for _ in range(data.total_test_batches):
                X, _ = sess.run(test_data)
                samples.append(X)
            X = np.concatenate(samples)
            print(X.shape)
        X_denoising = PixelCNN(conf, data, True).get_test_samples()
        tf.reset_default_graph()
        X_noncausal = NonCausal(conf, data).get_test_samples()
        tf.reset_default_graph()
        X_pixelcnn = PixelCNN(conf, data, False).get_test_samples()
        tf.reset_default_graph()
        get_metric(X, [X_pixelcnn, X_denoising, X_noncausal])
    else:
        data = Dataset(conf)
        model = NonCausal(conf,
                          data) if conf.model == 'noncausal' else PixelCNN(
                              conf, data, conf.model == 'denoising')

        if conf.test:
            model.run_tests()
        elif conf.samples:
Example #10
0
                               transforms.Resize(img_size),
                               transforms.CenterCrop(img_size),
                               transforms.ToTensor()
                           ]))

# dataset = dset.MNIST('../data', train=True, download=True,
#                    transform=transforms.Compose([
#                        transforms.ToTensor()
#                    ]))

# Shuffled mini-batches of `bsize` images from `dataset`.
loader = torch.utils.data.DataLoader(dataset,
                                     batch_size=bsize,
                                     shuffle=True,
                                     num_workers=4)

pcnn = PixelCNN()
pcnn.cuda()
# nn.NLLLoss2d is deprecated. nn.NLLLoss accepts (N, C, H, W) log-probs with
# (N, H, W) class targets directly, so it is a drop-in replacement.
criterion = nn.NLLLoss()

optimizer = optim.Adam(pcnn.parameters(), lr=0.0002, betas=(0.5, 0.999))
# train
for epoch in range(100):
    for i, (data, _) in enumerate(loader, 0):
        data = data.mean(1, keepdim=True)
        bsize_now, _, h, w = data.size()

        ids = (255 * data).long()

        label = torch.FloatTensor(bsize_now, 256, h,
                                  w).scatter_(1, ids,
                                              torch.ones(ids.size())).cuda()