Example #1

import torch
from torch.autograd import Variable

# NormalNet, Lambda, BayesianGroupLassoGenerator, and Normal are defined
# elsewhere in the project this example is taken from; group_input_dim,
# group_output_dims, dim_z, use_cuda, and inference_net are likewise set
# earlier in the original script.
def make_group_generator(group_output_dim):
    # Note that this Variable is NOT going to show up in `net.parameters()` and
    # therefore it is implicitly free from the ridge penalty/p(theta) prior.
    log_sigma = Variable(torch.log(1e-2 * torch.ones(group_output_dim).type(
        torch.cuda.FloatTensor if use_cuda else torch.FloatTensor)),
                         requires_grad=True)
    return NormalNet(mu_net=torch.nn.Linear(group_input_dim, group_output_dim),
                     sigma_net=Lambda(lambda x, log_sigma: torch.exp(
                         log_sigma.expand(x.size(0), -1)) + 1e-3,
                                      extra_args=(log_sigma, )))


generative_net = BayesianGroupLassoGenerator(
    group_generators=[make_group_generator(gs) for gs in group_output_dims],
    group_input_dim=group_input_dim,
    dim_z=dim_z)

prior_z = Normal(Variable(torch.zeros(1, dim_z)),
                 Variable(torch.ones(1, dim_z)))

if use_cuda:
    inference_net.cuda()
    generative_net.cuda()
    prior_z.mu = prior_z.mu.cuda()
    prior_z.sigma = prior_z.sigma.cuda()
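
# Aside (added note, not in the original): the prior_z.mu / prior_z.sigma
# assignments above work because the project's Normal is a simple container
# exposing mu and sigma directly. A minimal stand-in consistent with that
# usage (an assumption; the real class presumably also implements sampling
# and the log-density):
class NormalStandIn(object):
    def __init__(self, mu, sigma):
        self.mu = mu
        self.sigma = sigma

    def sample(self):
        # Reparameterized draw: mu + sigma * eps with eps ~ N(0, I).
        return self.mu + self.sigma * torch.randn_like(self.sigma)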

lr = 1e-3
optimizer = torch.optim.Adam([
    {
        'params': inference_net.parameters(),
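
The listing breaks off inside the torch.optim.Adam call above. A minimal
sketch of how it might be completed, assuming both networks share the same
learning rate (the original's remaining parameter groups are not shown):

optimizer = torch.optim.Adam([
    {
        'params': inference_net.parameters(),
        'lr': lr,
    },
    # Assumed second group: the generative net presumably gets optimized too,
    # but its exact grouping and learning rate are a guess here.
    {
        'params': generative_net.parameters(),
        'lr': lr,
    },
])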
Example #2
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt

# As in Example #1, NormalNet, Lambda, and BayesianGroupLassoGenerator come
# from the surrounding project; image_size, group_input_dim, seq_len, dim_z,
# dim_h, and the training data X are defined earlier in the original script.
def make_group_generator():
    # Note that this Variable is NOT going to show up in `net.parameters()` and
    # therefore it is implicitly free from the ridge penalty/p(theta) prior.
    log_sigma = Variable(torch.log(1e-2 * torch.ones(image_size)),
                         requires_grad=True)
    return NormalNet(mu_net=torch.nn.Sequential(
        torch.nn.Tanh(), torch.nn.Linear(group_input_dim, image_size)),
                     sigma_net=Lambda(lambda x, log_sigma: torch.exp(
                         log_sigma.expand(x.size(0), -1)) + 1e-3,
                                      extra_args=(log_sigma, )))


generative_net = BayesianGroupLassoGenerator(
    seq_len=seq_len,
    group_generators=[make_group_generator() for _ in range(image_size)],
    group_input_dim=group_input_dim,
    dim_z=dim_z,
    dim_h=dim_h)


def debug(count):
    """Create a plot showing the first `count` training samples along with their
  mean z value, x mean, x standard deviation, and a sample from the full model
  (sample z and then sample x)."""
    fig, ax = plt.subplots(5, count, figsize=(12, 4))

    # True images
    for i in range(count):
        ax[0, i].imshow(X[7][i].view(image_size, image_size).numpy())
        ax[0, i].axes.xaxis.set_ticks([])
        ax[0, i].axes.yaxis.set_ticks([])
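
Both Example #1 and Example #2 lean on the comment that log_sigma never shows
up in net.parameters(). A self-contained sketch of the mechanism; TinyNormalNet
below is a stand-in written for this note, not the project's NormalNet:

import torch
from torch.autograd import Variable

class TinyNormalNet(torch.nn.Module):
    # mu_net is registered as a submodule, while log_sigma is a plain
    # Variable attribute and therefore invisible to .parameters().
    def __init__(self, in_dim, out_dim):
        super(TinyNormalNet, self).__init__()
        self.mu_net = torch.nn.Linear(in_dim, out_dim)
        self.log_sigma = Variable(torch.log(1e-2 * torch.ones(out_dim)),
                                  requires_grad=True)

net = TinyNormalNet(4, 8)
print([name for name, _ in net.named_parameters()])
# Prints ['mu_net.weight', 'mu_net.bias']: log_sigma is skipped, so a ridge
# penalty summed over net.parameters() never touches it, and it has to be
# handed to the optimizer explicitly in order to be trained at all.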
Example #3
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt

# As in the previous examples, NormalNet, Lambda, and
# BayesianGroupLassoGenerator come from the surrounding project; group_dims,
# group_input_dim, dim_z, and dim_h are set earlier in the original script.
# The top of this listing is cut off; the opening of make_group_generator is
# reconstructed here from the identical pattern in Examples #1 and #2.
def make_group_generator(output_dim):
    # Note that this Variable is NOT going to show up in `net.parameters()` and
    # therefore it is implicitly free from the ridge penalty/p(theta) prior.
    log_sigma = Variable(torch.log(1e-2 * torch.ones(output_dim)),
                         requires_grad=True)
    return NormalNet(
        mu_net=torch.nn.Sequential(
            # torch.nn.Linear(group_input_dim, 16),
            torch.nn.Tanh(),
            torch.nn.Linear(group_input_dim, output_dim)),
        sigma_net=Lambda(lambda x, log_sigma: torch.exp(
            log_sigma.expand(x.size(0), -1)) + 1e-3,
                         extra_args=(log_sigma, )))


seq_len = 32

generative_net = BayesianGroupLassoGenerator(
    seq_len=seq_len,
    group_generators=[make_group_generator(dim) for dim in group_dims],
    group_input_dim=group_input_dim,
    dim_z=dim_z,
    dim_h=dim_h)


def debug_z_by_group_matrix(t):
    # dim_z x groups
    # fig, ax = plt.subplots()
    # W_col_norms = torch.sqrt(
    #   torch.sum(torch.pow(generative_net.Ws.data, 2), dim=2)
    # )
    # ax.imshow(W_col_norms.t().numpy(), aspect='equal')
    # ax.set_ylabel('dimensions of z')
    # ax.set_xlabel('group generative nets')
    # ax.xaxis.tick_top()
    # ax.xaxis.set_label_position('top')
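
The body of debug_z_by_group_matrix is entirely commented out at this point in
the listing (which is also cut off below). A runnable sketch assembled from
those comments; it assumes the tensor layout they imply, where the norm over
dim=2 of generative_net.Ws leaves one entry per (z-dimension, group) pair:

def debug_z_by_group_matrix(t):
    # Reconstructed from the commented-out body above; a sketch, not
    # necessarily the author's final version. Each entry of W_col_norms is
    # the Euclidean norm of the weights tying one latent dimension to one
    # group generator, so the group-lasso sparsity pattern shows up as dark
    # rows/columns in the heatmap.
    fig, ax = plt.subplots()
    W_col_norms = torch.sqrt(
        torch.sum(torch.pow(generative_net.Ws.data, 2), dim=2))
    ax.imshow(W_col_norms.t().numpy(), aspect='equal')
    ax.set_ylabel('dimensions of z')
    ax.set_xlabel('group generative nets')
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')
    fig.suptitle('iteration {}'.format(t))  # assumes t is the iteration count
    plt.show()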