示例#1
0
def log_logistic_volume_flow_marginal_estimate(recon_x_mu, recon_x_logvar, x,
                                               zk, z0, z0_mu, z0_logvar):
    r"""Estimate -log p(x) by importance sampling, with no Jacobian term.

    The importance weight of each sample is
        log p(x | z_k) + log p(z_k) - log q(z_0 | x)
    i.e. log q(z_k | x) is taken to be equal to log q(z_0 | x). The
    function name suggests this is meant for volume-preserving flows,
    where the log-abs-det-Jacobian correction vanishes.

    @param recon_x_mu: torch.Tensor, viewed as
                       (batch size * # samples x input_dim)
                       first parameter handed to logistic_256_log_pdf
    @param recon_x_logvar: torch.Tensor, viewed as
                           (batch size * # samples x input_dim)
                           second parameter handed to logistic_256_log_pdf
    @param x: torch.Tensor (batch size x input_dim)
              observed data; tiled along a new sample dimension
    @param zk: torch.Tensor, viewed as (batch size * # samples x z dim)
               samples scored under the unit Gaussian prior
    @param z0: torch.Tensor (batch size x # samples x z dim)
               samples scored under the variational distribution
    @param z0_mu: torch.Tensor, viewed as (batch size * # samples x z dim)
                  means passed to gaussian_log_pdf
    @param z0_logvar: torch.Tensor, viewed as
                      (batch size * # samples x z dim)
                      log-variances passed to gaussian_log_pdf
    @return: negated mean (over the batch) of the log-mean-exp of the
             per-sample log weights
    """
    batch_size, n_samples, z_dim = z0.size()
    input_dim = x.size(1)
    n_rows = batch_size * n_samples

    def _flatten(tensor, width):
        # Collapse the batch and sample dimensions into a single row axis.
        return tensor.view(n_rows, width)

    # Give every sample its own copy of the observation.
    x_tiled = x.unsqueeze(1).repeat(1, n_samples, 1)

    flat_z0 = _flatten(z0, z_dim)
    flat_zk = _flatten(zk, z_dim)
    flat_z0_mu = _flatten(z0_mu, z_dim)
    flat_z0_logvar = _flatten(z0_logvar, z_dim)
    flat_recon_mu = _flatten(recon_x_mu, input_dim)
    flat_recon_logvar = _flatten(recon_x_logvar, input_dim)
    flat_x = _flatten(x_tiled, input_dim)

    log_likelihood = logistic_256_log_pdf(flat_x, flat_recon_mu,
                                          flat_recon_logvar)
    # No Jacobian correction here: q(z_k | x) is scored as q(z_0 | x).
    log_posterior = gaussian_log_pdf(flat_z0, flat_z0_mu, flat_z0_logvar)
    log_prior = unit_gaussian_log_pdf(flat_zk)

    flat_log_weight = log_likelihood + log_prior - log_posterior
    per_example_log_weight = flat_log_weight.view(batch_size, n_samples)

    # log ( mean ( exp ( log_weights ) ) ) over the sample dimension
    log_marginal = log_mean_exp(per_example_log_weight, dim=1)
    return -torch.mean(log_marginal)
示例#2
0
def weighted_bernoulli_elbo_loss(recon_x_mu, x, z, z_mu, z_logvar):
    r"""Importance weighted evidence lower bound, negated for use as a loss.

    For each sample index i along dimension 1 the log importance weight
        log p(x | z_i) + log p(z_i) - log q(z_i | x)
    is computed; the weights are combined with log-mean-exp over the
    sample dimension and the negated batch mean is returned.

    @param recon_x_mu: torch.Tensor (batch size x # samples x |input_dim|)
                       reconstructed means on bernoulli
    @param x: torch.Tensor (batch size x |input_dim|)
                 original observed data
    @param z: torch.Tensor (batch_size x # samples x z dim)
              samples drawn from variational distribution
    @param z_mu: torch.Tensor (batch_size x # samples x z dim)
                 means of variational distribution
    @param z_logvar: torch.Tensor (batch_size x # samples x z dim)
                     log-variance of variational distribution
    @return: negated mean over the batch of the per-example bound
    """
    n_samples = recon_x_mu.size(1)

    log_ws = []
    # `range` rather than `xrange`: the latter does not exist on Python 3,
    # while `range` behaves the same on both major versions for this loop.
    for i in range(n_samples):
        log_p_x_given_z = bernoulli_log_pdf(x, recon_x_mu[:, i])
        log_q_z_given_x = gaussian_log_pdf(z[:, i], z_mu[:, i], z_logvar[:, i])
        log_p_z = unit_gaussian_log_pdf(z[:, i])

        log_ws_i = log_p_x_given_z + log_p_z - log_q_z_given_x
        # Add a sample axis so the per-sample weights can be concatenated.
        log_ws.append(log_ws_i.unsqueeze(1))

    log_ws = torch.cat(log_ws, dim=1)
    # log ( mean ( exp ( log_weights ) ) ) over the sample dimension
    log_ws = log_mean_exp(log_ws, dim=1)
    BOUND = -torch.mean(log_ws)

    return BOUND
示例#3
0
def log_bernoulli_norm_flow_marginal_estimate(recon_x_mu, x, zk, z0, z0_mu,
                                              z0_logvar, log_abs_det_jacobian):
    r"""Estimate -log p(x) by importance sampling under a normalizing flow.

    The importance weight of each sample is
        log p(x | z_k) + log p(z_k) - log q(z_k | x)
    with
        log q(z_k | x) = log q(z_0 | x) - log_abs_det_jacobian.

    @param recon_x_mu: torch.Tensor, viewed as
                       (batch size * # samples x input_dim)
                       reconstruction parameter passed to bernoulli_log_pdf
    @param x: torch.Tensor (batch size x input_dim)
              observed data; tiled along a new sample dimension
    @param zk: torch.Tensor, viewed as (batch size * # samples x z dim)
               samples scored under the unit Gaussian prior
    @param z0: torch.Tensor (batch size x # samples x z dim)
               samples scored under the variational distribution
    @param z0_mu: torch.Tensor, viewed as (batch size * # samples x z dim)
                  means passed to gaussian_log_pdf
    @param z0_logvar: torch.Tensor, viewed as
                      (batch size * # samples x z dim)
                      log-variances passed to gaussian_log_pdf
    @param log_abs_det_jacobian: torch.Tensor, viewed as
                                 (batch size * # samples)
                                 one value per sample
    @return: negated mean (over the batch) of the log-mean-exp of the
             per-sample log weights
    """
    batch_size, n_samples, z_dim = z0.size()
    input_dim = x.size(1)
    n_rows = batch_size * n_samples

    def _flatten(tensor, *trailing):
        # Collapse the batch and sample dimensions into a single row axis.
        return tensor.view(n_rows, *trailing)

    # Give every sample its own copy of the observation.
    x_tiled = x.unsqueeze(1).repeat(1, n_samples, 1)

    flat_z0 = _flatten(z0, z_dim)
    flat_zk = _flatten(zk, z_dim)
    flat_z0_mu = _flatten(z0_mu, z_dim)
    flat_z0_logvar = _flatten(z0_logvar, z_dim)
    flat_log_det = _flatten(log_abs_det_jacobian)
    flat_recon_mu = _flatten(recon_x_mu, input_dim)
    flat_x = _flatten(x_tiled, input_dim)

    log_likelihood = bernoulli_log_pdf(flat_x, flat_recon_mu)
    log_base_posterior = gaussian_log_pdf(flat_z0, flat_z0_mu, flat_z0_logvar)
    # Density of the transformed sample: base density minus log|det J|.
    log_flow_posterior = log_base_posterior - flat_log_det
    log_prior = unit_gaussian_log_pdf(flat_zk)

    flat_log_weight = log_likelihood + log_prior - log_flow_posterior
    per_example_log_weight = flat_log_weight.view(batch_size, n_samples)

    # log ( mean ( exp ( log_weights ) ) ) over the sample dimension
    log_marginal = log_mean_exp(per_example_log_weight, dim=1)
    return -torch.mean(log_marginal)
示例#4
0
def gaussian_free_energy_bound(recon_x_mu,
                               recon_x_logvar,
                               x,
                               zk,
                               z0,
                               z_mu,
                               z_logvar,
                               log_abs_det_jacobian=None,
                               beta=1.):
    r"""Negated, sample-averaged free energy bound for a flow posterior.

    For each sample index i along dimension 1 this evaluates
        log p(x | z_k) + beta * (log p(z_k) - log q(z_k | x))
    where log q(z_k | x) = log q(z_0 | x) - log_abs_det_jacobian when a
    Jacobian term is supplied and log q(z_0 | x) otherwise. The terms are
    averaged over samples, negated, and averaged with torch.mean.

    NOTE(review): despite the function name, the likelihood term is
    evaluated with logistic_256_log_pdf.

    @param recon_x_mu: torch.Tensor, indexed as [:, i] per sample
                       first parameter handed to logistic_256_log_pdf
    @param recon_x_logvar: torch.Tensor, indexed as [:, i] per sample
                           second parameter handed to logistic_256_log_pdf
    @param x: torch.Tensor
              observed data, passed unchanged for every sample
    @param zk: torch.Tensor, same size as z0
               samples scored under the unit Gaussian prior
    @param z0: torch.Tensor, indexed as [:, i] per sample
               samples scored under the variational distribution
    @param z_mu: torch.Tensor, indexed as [:, i] per sample
                 means passed to gaussian_log_pdf
    @param z_logvar: torch.Tensor, indexed as [:, i] per sample
                     log-variances passed to gaussian_log_pdf
    @param log_abs_det_jacobian: torch.Tensor or None, indexed as [:, i]
                                 subtracted from log q(z_0 | x) when given
    @param beta: float
                 weight on the (log prior - log posterior) term
    @return: torch.mean of the negated, sample-averaged bound
    """
    assert z0.size() == zk.size()
    n_samples = recon_x_mu.size(1)

    BOUND = 0
    # `range` rather than `xrange`: the latter does not exist on Python 3,
    # while `range` behaves the same on both major versions for this loop.
    for i in range(n_samples):
        log_p_x_given_z = logistic_256_log_pdf(x, recon_x_mu[:, i],
                                               recon_x_logvar[:, i])
        log_q_z0_given_x = gaussian_log_pdf(z0[:, i], z_mu[:, i],
                                            z_logvar[:, i])
        log_p_zk = unit_gaussian_log_pdf(zk[:, i])

        if log_abs_det_jacobian is not None:
            # Change of variables: correct the base density by log|det J|.
            log_q_zk_given_x = log_q_z0_given_x - log_abs_det_jacobian[:, i]
        else:
            log_q_zk_given_x = log_q_z0_given_x

        BOUND_i = log_p_x_given_z + beta * (log_p_zk - log_q_zk_given_x)
        BOUND += BOUND_i

    # Average over samples, flip the sign, then reduce to a single value.
    BOUND = BOUND / float(n_samples)
    BOUND = -BOUND
    BOUND = torch.mean(BOUND)

    return BOUND
示例#5
0
def weighted_gaussian_elbo_loss(recon_x_mu, recon_x_logvar, x, z, z_mu,
                                z_logvar):
    r"""Importance weighted evidence lower bound, negated for use as a loss.

    For each sample index i along dimension 1 the log importance weight
        log p(x | z_i) + log p(z_i) - log q(z_i | x)
    is computed; the weights are combined with log-mean-exp over the
    sample dimension and the negated mean is returned.

    NOTE(review): despite the function name, the likelihood term is
    evaluated with logistic_256_log_pdf.

    @param recon_x_mu: torch.Tensor, indexed as [:, i] per sample
                       first parameter handed to logistic_256_log_pdf
    @param recon_x_logvar: torch.Tensor, indexed as [:, i] per sample
                           second parameter handed to logistic_256_log_pdf
    @param x: torch.Tensor
              observed data, passed unchanged for every sample
    @param z: torch.Tensor, indexed as [:, i] per sample
              samples scored under both prior and variational distribution
    @param z_mu: torch.Tensor, indexed as [:, i] per sample
                 means passed to gaussian_log_pdf
    @param z_logvar: torch.Tensor, indexed as [:, i] per sample
                     log-variances passed to gaussian_log_pdf
    @return: negated torch.mean of the per-example bound
    """
    n_samples = recon_x_mu.size(1)

    log_ws = []
    # `range` rather than `xrange`: the latter does not exist on Python 3,
    # while `range` behaves the same on both major versions for this loop.
    for i in range(n_samples):
        log_p_x_given_z = logistic_256_log_pdf(x, recon_x_mu[:, i],
                                               recon_x_logvar[:, i])
        log_q_z_given_x = gaussian_log_pdf(z[:, i], z_mu[:, i], z_logvar[:, i])
        log_p_z = unit_gaussian_log_pdf(z[:, i])

        log_ws_i = log_p_x_given_z + log_p_z - log_q_z_given_x
        # Add a sample axis so the per-sample weights can be concatenated.
        log_ws.append(log_ws_i.unsqueeze(1))

    log_ws = torch.cat(log_ws, dim=1)
    # log ( mean ( exp ( log_weights ) ) ) over the sample dimension
    log_ws = log_mean_exp(log_ws, dim=1)
    BOUND = -torch.mean(log_ws)

    return BOUND
示例#6
0
def bernoulli_free_energy_bound(recon_x_mu,
                                x,
                                zk,
                                z0,
                                z_mu,
                                z_logvar,
                                log_abs_det_jacobian=None,
                                beta=1.):
    r"""Lower bound on approximate posterior distribution transformed by
    many normalizing flows (negated and averaged, for use as a loss).

    For each sample index i along dimension 1 this evaluates
        log p(x | z_k) + beta * (log p(z_k) - log q(z_k | x))
    where log q(z_k | x) = log q(z_0 | x) - log_abs_det_jacobian when a
    Jacobian term is supplied and log q(z_0 | x) otherwise. The terms are
    averaged over samples, negated, and averaged with torch.mean.

    NOTE(review): earlier documentation described this as using the closed
    form solution for ELBO (see <closed_form_elbo_loss>); the code here
    evaluates every term at the supplied samples instead -- confirm which
    description is intended.

    See https://github.com/Lyusungwon/generative_models_pytorch/blob/master/vae_nf/main.py

    For volume preserving transformations, keep log_abs_det_jacobian as None.

    @param recon_x_mu: torch.Tensor, indexed as [:, i] per sample
                       reconstruction parameter passed to bernoulli_log_pdf
    @param x: torch.Tensor
              observed data, passed unchanged for every sample
    @param zk: torch.Tensor, same size as z0
               samples scored under the unit Gaussian prior
    @param z0: torch.Tensor, indexed as [:, i] per sample
               samples scored under the variational distribution
    @param z_mu: torch.Tensor, indexed as [:, i] per sample
                 means passed to gaussian_log_pdf
    @param z_logvar: torch.Tensor, indexed as [:, i] per sample
                     log-variances passed to gaussian_log_pdf
    @param log_abs_det_jacobian: torch.Tensor or None, indexed as [:, i]
                                 subtracted from log q(z_0 | x) when given
    @param beta: float
                 weight on the (log prior - log posterior) term
    @return: torch.mean of the negated, sample-averaged bound
    """
    assert z0.size() == zk.size()
    n_samples = recon_x_mu.size(1)

    BOUND = 0
    # `range` rather than `xrange`: the latter does not exist on Python 3,
    # while `range` behaves the same on both major versions for this loop.
    for i in range(n_samples):
        log_p_x_given_z = bernoulli_log_pdf(x, recon_x_mu[:, i])
        log_q_z0_given_x = gaussian_log_pdf(z0[:, i], z_mu[:, i],
                                            z_logvar[:, i])
        log_p_zk = unit_gaussian_log_pdf(zk[:, i])

        if log_abs_det_jacobian is not None:
            # Change of variables: correct the base density by log|det J|.
            log_q_zk_given_x = log_q_z0_given_x - log_abs_det_jacobian[:, i]
        else:
            log_q_zk_given_x = log_q_z0_given_x

        BOUND_i = log_p_x_given_z + beta * (log_p_zk - log_q_zk_given_x)
        BOUND += BOUND_i

    # Average over samples, flip the sign, then reduce to a single value.
    BOUND = BOUND / float(n_samples)
    BOUND = -BOUND
    BOUND = torch.mean(BOUND)

    return BOUND
示例#7
0
def log_bernoulli_marginal_estimate(recon_x_mu, x, z, z_mu, z_logvar):
    r"""Importance-sampled estimate of log p(x), returned negated and
    averaged over the batch. NOTE: this is an evaluation quantity and is
    not the objective that should be directly optimized.

    @param recon_x_mu: torch.Tensor (batch size x # samples x input_dim)
                       reconstructed means on bernoulli
    @param x: torch.Tensor (batch size x input_dim)
              original observed data
    @param z: torch.Tensor (batch_size x # samples x z dim)
              samples drawn from variational distribution
    @param z_mu: torch.Tensor (batch_size x # samples x z dim)
                 means of variational distribution
    @param z_logvar: torch.Tensor (batch_size x # samples x z dim)
                     log-variance of variational distribution
    @return: negated mean (over the batch) of the log-mean-exp of the
             per-sample log weights
    """
    batch_size, n_samples, z_dim = z.size()
    input_dim = x.size(1)
    n_rows = batch_size * n_samples

    def _flatten(tensor, width):
        # Collapse the batch and sample dimensions into a single row axis.
        return tensor.view(n_rows, width)

    # Give every sample its own copy of the observation.
    x_tiled = x.unsqueeze(1).repeat(1, n_samples, 1)

    flat_z = _flatten(z, z_dim)
    flat_z_mu = _flatten(z_mu, z_dim)
    flat_z_logvar = _flatten(z_logvar, z_dim)
    flat_recon_mu = _flatten(recon_x_mu, input_dim)
    flat_x = _flatten(x_tiled, input_dim)

    log_likelihood = bernoulli_log_pdf(flat_x, flat_recon_mu)
    log_posterior = gaussian_log_pdf(flat_z, flat_z_mu, flat_z_logvar)
    log_prior = unit_gaussian_log_pdf(flat_z)

    flat_log_weight = log_likelihood + log_prior - log_posterior
    per_example_log_weight = flat_log_weight.view(batch_size, n_samples)

    # Normalizing the weights amounts to
    # log ( mean ( exp ( log_weights ) ) ) over the sample dimension.
    log_marginal = log_mean_exp(per_example_log_weight, dim=1)
    return -torch.mean(log_marginal)