Example #1
def log_bernoulli_marginal_estimate(x, x_mu_list, z_list, z_mu, z_logvar):
    r"""Estimate log p(x). NOTE: this is not the objective that
    should be directly optimized.

    @param x: torch.Tensor (batch size x input_dim)
              original observed data
    @param x_mu_list: list of torch.Tensor (batch size x input_dim)
                      reconstructed means on bernoulli
    @param z_list: list of torch.Tensor (batch_size x z dim)
                    samples drawn from variational distribution
    @param z_mu: torch.Tensor (batch_size x # samples x z dim)
                 means of variational distribution
    @param z_logvar: torch.Tensor (batch_size x # samples x z dim)
                     log-variance of variational distribution
    """
    k = len(z_list)
    batch_size = x.size(0)

    log_w = []
    for i in range(k):
        log_p_x_given_z_i = bernoulli_log_pdf(
            x.view(batch_size, -1), x_mu_list[i].view(batch_size, -1))
        log_q_z_given_x_i = gaussian_log_pdf(z_list[i], z_mu, z_logvar)
        log_p_z_i = unit_gaussian_log_pdf(z_list[i])
        log_w_i = log_p_x_given_z_i + log_p_z_i - log_q_z_given_x_i
        log_w.append(log_w_i)

    log_w = torch.stack(log_w).t()  # (batch_size, k)
    # importance-weighted estimate: log( mean( exp(log_w) ) ) over the k samples
    log_p_x = log_mean_exp(log_w, dim=1)

    return -torch.mean(log_p_x)  # note: returns the negated batch average
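
These snippets call four helpers that are not shown here. The sketch below is one plausible set of implementations, assuming each log-density is summed over the feature dimension and returns one value per example; the bodies are inferred from how the helpers are used above, not taken from the original code.

import math

import torch


def bernoulli_log_pdf(x, mu, eps=1e-7):
    # log p(x | mu) of a factorized Bernoulli, summed over input_dim -> (batch_size,)
    mu = mu.clamp(eps, 1 - eps)
    return torch.sum(x * torch.log(mu) + (1 - x) * torch.log(1 - mu), dim=1)


def gaussian_log_pdf(z, mu, logvar):
    # log N(z; mu, diag(exp(logvar))), summed over z_dim -> (batch_size,)
    return -0.5 * torch.sum(
        math.log(2 * math.pi) + logvar + (z - mu) ** 2 / logvar.exp(), dim=1)


def unit_gaussian_log_pdf(z):
    # log N(z; 0, I), summed over z_dim -> (batch_size,)
    return -0.5 * torch.sum(math.log(2 * math.pi) + z ** 2, dim=1)


def log_mean_exp(x, dim=1):
    # numerically stable log(mean(exp(x))) along `dim`
    return torch.logsumexp(x, dim=dim) - math.log(x.size(dim))

With k > 1 samples per example, log_bernoulli_marginal_estimate is then the usual importance-weighted (IWAE-style) bound on log p(x), which is tighter than the single-sample ELBO.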
Example #2
    def bernoulli_elbo(self, outputs, reduce=True):
        (c_mu, c_logvar), (q_mu, q_logvar, p_mu, p_logvar), (x, x_mu) = outputs
        batch_size = x.size(0)
        recon_loss = bernoulli_log_pdf(x.view(batch_size, -1),
                                       x_mu.view(batch_size, -1))
        kl_c = -0.5 * (1 + c_logvar - c_mu.pow(2) - c_logvar.exp())
        kl_c = torch.sum(kl_c, dim=1)
        kl_z = 0.5 * (p_logvar - q_logvar + ((q_mu - p_mu)**2 + q_logvar.exp())/p_logvar.exp() - 1)
        kl_z = torch.sum(kl_z, dim=1)

        ELBO = -recon_loss + kl_z + kl_c  # note: this is the negative ELBO (a loss to minimize)

        if reduce:
            return torch.mean(ELBO)
        else:
            return ELBO  # (batch_size,)
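
For reference, the kl_z line is the closed-form KL divergence between two diagonal Gaussians q = N(mu_q, diag sigma_q^2) and p = N(mu_p, diag sigma_p^2), summed over latent dimensions:

\mathrm{KL}\!\left( q \,\middle\|\, p \right)
  = \frac{1}{2} \sum_{d} \left( \log\sigma_{p,d}^{2} - \log\sigma_{q,d}^{2}
    + \frac{(\mu_{q,d} - \mu_{p,d})^{2} + \sigma_{q,d}^{2}}{\sigma_{p,d}^{2}} - 1 \right)

With a unit-Gaussian prior (mu_p = 0, sigma_p^2 = 1) this reduces to
-\tfrac{1}{2} \sum_{d} \left( 1 + \log\sigma_{q,d}^{2} - \mu_{q,d}^{2} - \sigma_{q,d}^{2} \right),
which is exactly the kl_c line.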
Example #3
    def elbo(self, outputs, reduce=True):
        (c, c_mu, c_logvar), (q_mu, q_logvar, p_mu, p_logvar), (x, x_mu) = outputs
        batch_size = x.size(0)
        recon_loss = bernoulli_log_pdf(x.view(batch_size, -1),
                                       x_mu.view(batch_size, -1))
        log_p_c = self.log_p_c(c)
        log_q_c = gaussian_log_pdf(c, c_mu, c_logvar)
        kl_c = -(log_p_c - log_q_c)  # single-sample Monte Carlo estimate of KL(q(c|x) || p(c))
        kl_z = 0.5 * (p_logvar - q_logvar +
                      ((q_mu - p_mu)**2 + q_logvar.exp()) / p_logvar.exp() - 1)
        kl_z = torch.sum(kl_z, dim=1)

        ELBO = -recon_loss + kl_z + kl_c

        if reduce:
            return torch.mean(ELBO)
        else:
            return ELBO
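
Example #3 swaps the analytic kl_c of example #2 for a single-sample Monte Carlo estimate, log q(c|x) - log p(c), which also works when self.log_p_c is not a unit Gaussian. A minimal sketch (made-up shapes, assuming a unit-Gaussian prior; not from the source) showing that the estimate matches the closed form in expectation:

import math

import torch

torch.manual_seed(0)
c_mu = torch.zeros(4, 8)
c_logvar = torch.full((4, 8), -1.0)

# analytic KL against a unit-Gaussian prior, as in example #2
analytic = torch.sum(-0.5 * (1 + c_logvar - c_mu.pow(2) - c_logvar.exp()), dim=1)

# Monte Carlo estimate, averaged over many reparameterized samples of c
n_samples = 50000
c = c_mu + (0.5 * c_logvar).exp() * torch.randn(n_samples, 4, 8)
log_q_c = -0.5 * (math.log(2 * math.pi) + c_logvar
                  + (c - c_mu) ** 2 / c_logvar.exp()).sum(-1)
log_p_c = -0.5 * (math.log(2 * math.pi) + c ** 2).sum(-1)
mc = (log_q_c - log_p_c).mean(0)

print(analytic)
print(mc)  # agrees with `analytic` up to Monte Carlo noise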
Example #4
    def bernoulli_elbo_loss_sets(self, outputs, reduce=True):
        c_outputs, z_outputs, x_outputs = outputs

        # 1. Reconstruction loss
        x, x_mu = x_outputs
        n_datasets = x.size(0)
        batch_size = x.size(1)
        recon_loss = bernoulli_log_pdf(x.view(n_datasets * batch_size, -1),
                                       x_mu.view(n_datasets * batch_size, -1))
        recon_loss = recon_loss.view(n_datasets, batch_size)

        # 2. KL Divergence terms
        # a) Context divergence
        c_mu, c_logvar = c_outputs
        kl_c = -0.5 * (1 + c_logvar - c_mu.pow(2) - c_logvar.exp())
        kl_c = torch.sum(kl_c, dim=-1)  # (n_datasets)

        # b) Latent divergences
        qz_params, pz_params = z_outputs

        # this is kl(q_z||p_z)
        p_mu, p_logvar = pz_params
        q_mu, q_logvar = qz_params

        # p_mu/p_logvar are per-dataset, so broadcast them across the batch dimension
        p_mu = p_mu.unsqueeze(1).expand_as(q_mu)
        p_logvar = p_logvar.unsqueeze(1).expand_as(q_logvar)

        kl_z = 0.5 * (p_logvar - q_logvar +
                      ((q_mu - p_mu)**2 + q_logvar.exp()) / p_logvar.exp() - 1)
        kl_z = torch.sum(kl_z, dim=-1)  # (n_datasets, batch_size)

        ELBO = -recon_loss + kl_z  # these will both be (n_datasets, batch_size)
        ELBO = ELBO.sum(-1) / batch_size  # average over batch_size (== self.sample_size)
        ELBO = ELBO + kl_c  # now this is (n_datasets,)

        if reduce:
            return torch.mean(ELBO)  # averaging over (n_datasets)
        else:
            return ELBO  # (n_datasets)
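
A quick shape walkthrough of the broadcasting step (all sizes are made up for illustration; only the unsqueeze/expand pattern comes from the code above):

import torch

n_datasets, batch_size, z_dim = 3, 5, 16

q_mu = torch.randn(n_datasets, batch_size, z_dim)      # per-element posterior
q_logvar = torch.randn(n_datasets, batch_size, z_dim)
p_mu = torch.randn(n_datasets, z_dim)                  # one prior per dataset
p_logvar = torch.randn(n_datasets, z_dim)

# broadcast the per-dataset prior across the batch dimension
p_mu = p_mu.unsqueeze(1).expand_as(q_mu)               # (n_datasets, batch_size, z_dim)
p_logvar = p_logvar.unsqueeze(1).expand_as(q_logvar)

kl_z = 0.5 * (p_logvar - q_logvar +
              ((q_mu - p_mu) ** 2 + q_logvar.exp()) / p_logvar.exp() - 1)
print(kl_z.sum(-1).shape)  # torch.Size([3, 5]): one KL term per element of each set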
Example #5
def bernoulli_elbo_loss(x, x_mu, z, z_mu, z_logvar):
    r"""Lower bound on model evidence (average over multiple samples).

    Closed form solution for KL[p(z|x), p(z)]

    Kingma, Diederik P., and Max Welling. "Auto-encoding
    variational bayes." arXiv preprint arXiv:1312.6114 (2013).
    """
    batch_size = x.size(0)
    x = x.view(batch_size, -1)
    x_mu = x_mu.view(batch_size, -1)
    log_p_x_given_z = -bernoulli_log_pdf(x, x_mu)  # note: this is -log p(x|z), the reconstruction negative log-likelihood
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    kl_divergence = -0.5 * (1 + z_logvar - z_mu.pow(2) - z_logvar.exp())
    kl_divergence = torch.sum(kl_divergence, dim=1)

    # negative lower bound on marginal likelihood (the quantity to minimize)
    ELBO = log_p_x_given_z + kl_divergence
    ELBO = torch.mean(ELBO)

    return ELBO
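
In the notation of Kingma and Welling, the bound being computed is

\log p_\theta(x) \;\ge\;
  \mathbb{E}_{q_\phi(z \mid x)}\!\left[ \log p_\theta(x \mid z) \right]
  - \mathrm{KL}\!\left( q_\phi(z \mid x) \,\middle\|\, p(z) \right),

and bernoulli_elbo_loss returns the negative of the right-hand side (a single-sample reconstruction term plus the closed-form KL), averaged over the minibatch, so it can be minimized directly.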