Ejemplo n.º 1
0
flow = flows.Flow(transform, distribution).to(device)

# Report the model size up front.
n_params = utils.get_num_parameters(flow)
print('There are {} trainable parameters in this model.'.format(n_params))

# Adam over all flow parameters. When args.anneal_learning_rate is set the
# learning rate follows a cosine curve down to 0 over the whole run;
# otherwise no scheduler is used.
optimizer = optim.Adam(flow.parameters(), lr=args.learning_rate)
if not args.anneal_learning_rate:
    scheduler = None
else:
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.num_training_steps, eta_min=0)

# TensorBoard output goes to <log root>/<dataset name>/<timestamp>; on the
# cluster the SLURM job id is appended to the timestamp.
timestamp = cutils.get_timestamp()
if cutils.on_cluster():
    timestamp = timestamp + '||{}'.format(os.environ['SLURM_JOB_ID'])
log_dir = os.path.join(cutils.get_log_root(), args.dataset_name, timestamp)

# Keep retrying (every 5 seconds) for as long as creating the writer raises
# FileExistsError; the same log_dir is reused on every attempt.
writer = None
while writer is None:
    try:
        writer = SummaryWriter(log_dir=log_dir, max_queue=20)
    except FileExistsError:
        sleep(5)

# Store the full argument namespace next to the logs.
filename = os.path.join(log_dir, 'config.json')
with open(filename, 'w') as file:
    json.dump(vars(args), file)

# Progress bar over the training steps, and the best validation score seen
# so far (initialised to a very low sentinel).
tbar = tqdm(range(args.num_training_steps))
best_val_score = -1e10
Ejemplo n.º 2
0
def run(seed):
    """Train a VAE with configurable prior/posterior, then evaluate on test.

    Everything is configured through the global ``args``:

    * ``args.prior_type`` selects a standard-normal prior or a flow prior
      built from alternating linear and coupling/autoregressive transforms.
    * ``args.approximate_posterior_type`` selects a diagonal-normal posterior
      (with its own ``ConvEncoder``) or a conditional flow posterior, in which
      case a separate ``ConvEncoder`` is passed to the model as
      ``inputs_encoder`` (how it is wired is up to ``vae.VariationalAutoencoder``).

    Training minimises the negative mean stochastic ELBO. The weights with the
    best validation ELBO are checkpointed, reloaded, and evaluated on the test
    set with a fixed seed.

    Side effects: writes TensorBoard logs, ``config.json``, per-example test
    ELBO / log-prob-lower-bound ``.npy`` arrays and ``test-results.txt`` into
    the run's log directory, and a checkpoint into
    ``cutils.get_checkpoint_root()``.

    Args:
        seed: Seed for the NumPy and PyTorch RNGs; it also fixes the
            train/validation split.

    Raises:
        RuntimeError: If no CUDA device is available.
        ValueError: If ``args`` names an unknown dataset, linear transform
            type, prior type, or KL multiplier schedule.
    """
    if not torch.cuda.is_available():
        raise RuntimeError('run() requires a CUDA device, but none is available.')
    device = torch.device('cuda')
    # Newly created float tensors (including inside the model) live on the GPU.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    np.random.seed(seed)
    torch.manual_seed(seed)

    # ---------------
    # data
    # ---------------
    # Images are converted to tensors and then binarised by sampling
    # torch.bernoulli on the pixel intensities every time the transform runs.
    data_transform = tvtransforms.Compose(
        [tvtransforms.ToTensor(),
         tvtransforms.Lambda(torch.bernoulli)])

    if args.dataset_name == 'mnist':
        mnist_root = os.path.join(utils.get_data_root(), 'mnist')
        dataset = datasets.MNIST(root=mnist_root,
                                 train=True,
                                 download=True,
                                 transform=data_transform)
        test_dataset = datasets.MNIST(root=mnist_root,
                                      train=False,
                                      download=True,
                                      transform=data_transform)
    elif args.dataset_name == 'fashion-mnist':
        fashion_root = os.path.join(utils.get_data_root(), 'fashion-mnist')
        dataset = datasets.FashionMNIST(root=fashion_root,
                                        train=True,
                                        download=True,
                                        transform=data_transform)
        test_dataset = datasets.FashionMNIST(root=fashion_root,
                                             train=False,
                                             download=True,
                                             transform=data_transform)
    elif args.dataset_name == 'omniglot':
        dataset = data_.OmniglotDataset(split='train',
                                        transform=data_transform)
        test_dataset = data_.OmniglotDataset(split='test',
                                             transform=data_transform)
    elif args.dataset_name == 'emnist':
        # Rotate by -90 degrees and mirror before binarising; presumably this
        # undoes the transposed orientation of stored EMNIST images.
        rotate = partial(tvF.rotate, angle=-90)
        hflip = tvF.hflip
        data_transform = tvtransforms.Compose([
            tvtransforms.Lambda(rotate),
            tvtransforms.Lambda(hflip),
            tvtransforms.ToTensor(),
            tvtransforms.Lambda(torch.bernoulli)
        ])
        emnist_root = os.path.join(utils.get_data_root(), 'emnist')
        dataset = datasets.EMNIST(root=emnist_root,
                                  split='letters',
                                  train=True,
                                  transform=data_transform,
                                  download=True)
        test_dataset = datasets.EMNIST(root=emnist_root,
                                       split='letters',
                                       train=False,
                                       transform=data_transform,
                                       download=True)
    else:
        raise ValueError('Unknown dataset_name: {!r}'.format(args.dataset_name))

    # The last |split| entries of a seeded random permutation of the training
    # set are held out for validation; the rest are used for training.
    if args.dataset_name == 'omniglot':
        split = -1345
    elif args.dataset_name == 'emnist':
        split = -20000
    else:
        split = -10000
    indices = np.arange(len(dataset))
    np.random.shuffle(indices)
    train_indices, val_indices = indices[:split], indices[split:]
    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)
    train_loader = data.DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=4 if args.dataset_name == 'emnist' else 0)
    train_generator = data_.batch_generator(train_loader)
    val_loader = data.DataLoader(dataset=dataset,
                                 batch_size=1024,
                                 sampler=val_sampler,
                                 shuffle=False,
                                 drop_last=False)
    # Validation during training uses one fixed batch of (up to) 1024
    # held-out examples, drawn once here.
    val_batch = next(iter(val_loader))[0]
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=16,
        shuffle=False,
        drop_last=False,
    )

    # ---------------
    # building blocks
    # ---------------
    def create_linear_transform():
        """Return the linear layer inserted between flow steps."""
        if args.linear_type == 'lu':
            return transforms.CompositeTransform([
                transforms.RandomPermutation(args.latent_features),
                transforms.LULinear(args.latent_features, identity_init=True)
            ])
        elif args.linear_type == 'svd':
            return transforms.SVDLinear(args.latent_features,
                                        num_householder=4,
                                        identity_init=True)
        elif args.linear_type == 'perm':
            return transforms.RandomPermutation(args.latent_features)
        raise ValueError('Unknown linear_type: {!r}'.format(args.linear_type))

    def create_resnet_factory(context_features):
        """Return an ``(in_features, out_features) -> ResidualNet`` factory.

        The coupling transforms call the factory to build their conditioner
        networks; all share the same width, depth, dropout and batch-norm
        settings from ``args``.
        """
        def create_resnet(in_features, out_features):
            return nn_.ResidualNet(
                in_features=in_features,
                out_features=out_features,
                hidden_features=args.hidden_features,
                context_features=context_features,
                num_blocks=args.num_transform_blocks,
                activation=F.relu,
                dropout_probability=args.dropout_probability,
                use_batch_norm=args.use_batch_norm)

        return create_resnet

    def create_base_transform(i, context_features=None):
        """Return the i-th nonlinear flow step selected by ``args.prior_type``.

        Coupling masks alternate with the parity of ``i``. A non-None
        ``context_features`` makes the step conditional.
        """
        if args.prior_type == 'affine-coupling':
            return transforms.AffineCouplingTransform(
                mask=utils.create_alternating_binary_mask(
                    features=args.latent_features, even=(i % 2 == 0)),
                transform_net_create_fn=create_resnet_factory(
                    context_features))
        elif args.prior_type == 'rq-coupling':
            return transforms.PiecewiseRationalQuadraticCouplingTransform(
                mask=utils.create_alternating_binary_mask(
                    features=args.latent_features, even=(i % 2 == 0)),
                transform_net_create_fn=create_resnet_factory(
                    context_features),
                num_bins=args.num_bins,
                tails='linear',
                tail_bound=args.tail_bound,
                apply_unconditional_transform=(
                    args.apply_unconditional_transform),
            )
        elif args.prior_type == 'affine-autoregressive':
            return transforms.MaskedAffineAutoregressiveTransform(
                features=args.latent_features,
                hidden_features=args.hidden_features,
                context_features=context_features,
                num_blocks=args.num_transform_blocks,
                use_residual_blocks=True,
                random_mask=False,
                activation=F.relu,
                dropout_probability=args.dropout_probability,
                use_batch_norm=args.use_batch_norm)
        elif args.prior_type == 'rq-autoregressive':
            return transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
                features=args.latent_features,
                hidden_features=args.hidden_features,
                context_features=context_features,
                num_bins=args.num_bins,
                tails='linear',
                tail_bound=args.tail_bound,
                num_blocks=args.num_transform_blocks,
                use_residual_blocks=True,
                random_mask=False,
                activation=F.relu,
                dropout_probability=args.dropout_probability,
                use_batch_norm=args.use_batch_norm)
        raise ValueError('Unknown prior_type: {!r}'.format(args.prior_type))

    # ---------------
    # prior
    # ---------------
    def create_prior():
        """Return a standard normal, or a flow on top of a standard normal."""
        if args.prior_type == 'standard-normal':
            return distributions_.StandardNormal((args.latent_features, ))

        distribution = distributions_.StandardNormal((args.latent_features, ))
        transform = transforms.CompositeTransform([
            transforms.CompositeTransform(
                [create_linear_transform(),
                 create_base_transform(i)])
            for i in range(args.num_flow_steps)
        ])
        transform = transforms.CompositeTransform(
            [transform, create_linear_transform()])
        return flows.Flow(transform, distribution)

    # ---------------
    # inputs encoder
    # ---------------
    def create_inputs_encoder():
        """Return the image encoder used by flow posteriors, else None."""
        if args.approximate_posterior_type == 'diagonal-normal':
            return None
        return nn_.ConvEncoder(
            context_features=args.context_features,
            channels_multiplier=16,
            dropout_probability=args.dropout_probability_encoder_decoder)

    # ---------------
    # approximate posterior
    # ---------------
    def create_approximate_posterior():
        """Return q(z | x): a conditional diagonal normal, or a conditional flow."""
        if args.approximate_posterior_type == 'diagonal-normal':
            context_encoder = nn_.ConvEncoder(
                context_features=args.context_features,
                channels_multiplier=16,
                dropout_probability=args.dropout_probability_encoder_decoder)
            return distributions_.ConditionalDiagonalNormal(
                shape=[args.latent_features], context_encoder=context_encoder)

        # The base distribution's parameters are a linear map of the context
        # (mean and scale-related terms for each latent -> 2 * latent_features).
        context_encoder = nn.Linear(args.context_features,
                                    2 * args.latent_features)
        distribution = distributions_.ConditionalDiagonalNormal(
            shape=[args.latent_features], context_encoder=context_encoder)

        transform = transforms.CompositeTransform([
            transforms.CompositeTransform([
                create_linear_transform(),
                create_base_transform(
                    i, context_features=args.context_features)
            ]) for i in range(args.num_flow_steps)
        ])
        transform = transforms.CompositeTransform(
            [transform, create_linear_transform()])
        return flows.Flow(transforms.InverseTransform(transform), distribution)

    # ---------------
    # likelihood
    # ---------------
    def create_likelihood():
        """Return p(x | z): independent Bernoullis over a 1x28x28 image."""
        latent_decoder = nn_.ConvDecoder(
            latent_features=args.latent_features,
            channels_multiplier=16,
            dropout_probability=args.dropout_probability_encoder_decoder)
        return distributions_.ConditionalIndependentBernoulli(
            shape=[1, 28, 28], context_encoder=latent_decoder)

    prior = create_prior()
    approximate_posterior = create_approximate_posterior()
    likelihood = create_likelihood()
    inputs_encoder = create_inputs_encoder()

    model = vae.VariationalAutoencoder(
        prior=prior,
        approximate_posterior=approximate_posterior,
        likelihood=likelihood,
        inputs_encoder=inputs_encoder)

    n_params = utils.get_num_parameters(model)
    print('There are {} trainable parameters in this model.'.format(n_params))

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=args.num_training_steps, eta_min=0)

    def get_kl_multiplier(step):
        """Return the KL weight to use at training step ``step``.

        'constant' keeps ``kl_multiplier_initial`` throughout. 'linear' ramps
        it from ``kl_multiplier_initial`` to ``2 * kl_multiplier_initial``
        over the first ``kl_warmup_fraction`` of training, then holds it.
        """
        if args.kl_multiplier_schedule == 'constant':
            return args.kl_multiplier_initial
        elif args.kl_multiplier_schedule == 'linear':
            multiplier = min(
                step / (args.num_training_steps * args.kl_warmup_fraction), 1.)
            # NOTE(review): the ramp ends at 2x the initial value, i.e. at 1.0
            # only when kl_multiplier_initial == 0.5 -- confirm intended.
            return args.kl_multiplier_initial * (1. + multiplier)
        # Previously an unknown schedule silently yielded None.
        raise ValueError('Unknown kl_multiplier_schedule: {!r}'.format(
            args.kl_multiplier_schedule))

    # ---------------
    # logging
    # ---------------
    # TensorBoard output goes to <log root>/<dataset name>/<timestamp>; on the
    # cluster the SLURM job id is appended to the timestamp.
    timestamp = cutils.get_timestamp()
    if cutils.on_cluster():
        timestamp += '||{}'.format(os.environ['SLURM_JOB_ID'])
    log_dir = os.path.join(cutils.get_log_root(), args.dataset_name, timestamp)
    # Retry every 5 seconds for as long as creating the writer raises
    # FileExistsError (the same log_dir is reused on each attempt).
    while True:
        try:
            writer = SummaryWriter(log_dir=log_dir, max_queue=20)
            break
        except FileExistsError:
            sleep(5)
    config_path = os.path.join(log_dir, 'config.json')
    with open(config_path, 'w') as config_file:
        json.dump(vars(args), config_file)

    # Single location for this run's best-validation weights.
    checkpoint_path = os.path.join(
        cutils.get_checkpoint_root(),
        '{}-best-val-{}.t'.format(args.dataset_name, timestamp))
    checkpoint_saved = False

    # ---------------
    # training
    # ---------------
    best_val_elbo = -np.inf
    tbar = tqdm(range(args.num_training_steps))
    for step in tbar:
        model.train()
        optimizer.zero_grad()

        batch = next(train_generator)[0].to(device)
        elbo = model.stochastic_elbo(batch,
                                     kl_multiplier=get_kl_multiplier(step))
        loss = -torch.mean(elbo)
        loss.backward()
        optimizer.step()
        # Advance the cosine schedule after the parameter update, the order
        # PyTorch expects. This follows the same cosine curve as the
        # deprecated ``scheduler.step(step)``-before-update call it replaces.
        scheduler.step()

        if (step + 1) % args.monitor_interval == 0:
            model.eval()
            with torch.no_grad():
                elbo = model.stochastic_elbo(val_batch.to(device))
                # Plain float: cheap to compare and log, no GPU tensor kept.
                mean_val_elbo = elbo.mean().item()

            if mean_val_elbo > best_val_elbo:
                best_val_elbo = mean_val_elbo
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_saved = True

            writer.add_scalar(tag='val-elbo',
                              scalar_value=mean_val_elbo,
                              global_step=step)

            writer.add_scalar(tag='best-val-elbo',
                              scalar_value=best_val_elbo,
                              global_step=step)

            with torch.no_grad():
                samples = model.sample(64)
            fig, ax = plt.subplots(figsize=(10, 10))
            cutils.gridimshow(make_grid(samples.view(64, 1, 28, 28), nrow=8),
                              ax)
            writer.add_figure(tag='vae-samples', figure=fig, global_step=step)
            plt.close(fig)

    if not checkpoint_saved:
        # No validation round ran or improved on -inf (e.g. monitor_interval
        # larger than num_training_steps, or NaN validation ELBOs). Fall back
        # to the final weights instead of crashing on the load below.
        print('No best-validation checkpoint was written; '
              'evaluating the final model weights instead.')
        torch.save(model.state_dict(), checkpoint_path)

    # ---------------
    # test evaluation
    # ---------------
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()

    # Fixed seed so the stochastic test estimates are reproducible.
    np.random.seed(5)
    torch.manual_seed(5)

    # Collect per-batch results and concatenate once; concatenating inside
    # the loop would copy the accumulated tensor on every batch.
    elbo_batches = []
    log_prob_lower_bound_batches = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            inputs = batch[0].to(device)
            elbo_batches.append(model.stochastic_elbo(inputs))
            log_prob_lower_bound_batches.append(
                model.log_prob_lower_bound(inputs, num_samples=1000))
    elbo = torch.cat(elbo_batches)
    log_prob_lower_bound = torch.cat(log_prob_lower_bound_batches)

    # Save per-example test values.
    results_prefix = '{}-prior-{}-posterior-{}'.format(
        args.dataset_name, args.prior_type, args.approximate_posterior_type)
    np.save(os.path.join(log_dir, results_prefix + '-elbo.npy'),
            utils.tensor2numpy(elbo))
    np.save(os.path.join(log_dir, results_prefix + '-log-prob-lower-bound.npy'),
            utils.tensor2numpy(log_prob_lower_bound))

    # Summarise as mean +- two standard errors over the test set.
    sqrt_n = np.sqrt(len(test_dataset))
    s = 'ELBO: {:.2f} +- {:.2f}, LOG PROB LOWER BOUND: {:.2f} +- {:.2f}'.format(
        elbo.mean().item(), 2 * elbo.std().item() / sqrt_n,
        log_prob_lower_bound.mean().item(),
        2 * log_prob_lower_bound.std().item() / sqrt_n)
    results_path = os.path.join(log_dir, 'test-results.txt')
    with open(results_path, 'w') as results_file:
        results_file.write(s)