def log_bernoulli_marginal_estimate(x, x_mu_list, z_list, z_mu, z_logvar):
    r"""Importance-sampled estimate of the marginal likelihood, returned as a loss.

    For every sample z_i in ``z_list`` the log importance weight

        log w_i = log p(x | z_i) + log p(z_i) - log q(z_i | x)

    is formed. The weights are combined per example as log(mean_i exp(log w_i)).
    The function returns the batch mean of that quantity with its sign flipped.
    In other words, it returns an estimate of -log p(x) averaged over the batch,
    not log p(x) itself.

    NOTE: this is an evaluation quantity, not the objective that should be
    directly optimized.

    @param x: torch.Tensor (batch size x input_dim)
              original observed data
    @param x_mu_list: list of torch.Tensor (batch size x input_dim)
                      reconstructed Bernoulli means, one per entry of z_list
    @param z_list: list of torch.Tensor (batch_size x z dim)
                   samples drawn from the variational distribution
    @param z_mu: torch.Tensor
                 means of the variational distribution. The original doc lists
                 (batch_size x # samples x z dim); it must be accepted by
                 gaussian_log_pdf alongside each z_list[i] -- TODO confirm shape.
    @param z_logvar: torch.Tensor
                     log-variance of the variational distribution (same shape as z_mu)
    @return: scalar torch.Tensor, -mean over the batch of the log p(x) estimate
    """
    n = x.size(0)
    flat_x = x.view(n, -1)

    log_weights = []
    for i, z_i in enumerate(z_list):
        log_lik = bernoulli_log_pdf(flat_x, x_mu_list[i].view(n, -1))
        log_post = gaussian_log_pdf(z_i, z_mu, z_logvar)
        log_prior = unit_gaussian_log_pdf(z_i)
        log_weights.append(log_lik + log_prior - log_post)

    # Stack along a new leading sample axis, then transpose so the samples run
    # along dim 1: (batch_size, k).
    log_w = torch.stack(log_weights).t()

    # log-mean-exp over the samples gives the normalized importance estimate.
    return -torch.mean(log_mean_exp(log_w, dim=1))
def bernoulli_elbo(self, outputs, reduce=True):
    """Negative ELBO for a model with a context latent c and a latent z.

    ``outputs`` is unpacked as::

        ((c_mu, c_logvar), (q_mu, q_logvar, p_mu, p_logvar), (x, x_mu))

    The per-example loss is

        -log p(x | z) + KL(q(z) || p(z)) + KL(q(c) || N(0, I)),

    where both KL terms use the closed-form diagonal-Gaussian expression and
    are summed over dim 1.

    @param outputs: nested tuple as described above
    @param reduce: if truthy, return the mean over the batch; otherwise return
                   the per-example loss (one entry per row of x)
    """
    c_params, z_params, x_params = outputs
    c_mu, c_logvar = c_params
    q_mu, q_logvar, p_mu, p_logvar = z_params
    x, x_mu = x_params

    n = x.size(0)
    log_lik = bernoulli_log_pdf(x.view(n, -1), x_mu.view(n, -1))

    # Closed-form KL between N(c_mu, exp(c_logvar)) and the standard normal.
    kl_c = torch.sum(-0.5 * (1 + c_logvar - c_mu.pow(2) - c_logvar.exp()), dim=1)

    # Closed-form KL between two diagonal Gaussians q and p.
    q_var = q_logvar.exp()
    p_var = p_logvar.exp()
    kl_z = torch.sum(
        0.5 * (p_logvar - q_logvar + ((q_mu - p_mu) ** 2 + q_var) / p_var - 1),
        dim=1,
    )

    loss = -log_lik + kl_z + kl_c
    return torch.mean(loss) if reduce else loss
def elbo(self, outputs, reduce=True):
    """Negative ELBO using a sampled (Monte Carlo) context KL term.

    ``outputs`` is unpacked as::

        ((c, c_mu, c_logvar), (q_mu, q_logvar, p_mu, p_logvar), (x, x_mu))

    The context KL is estimated from the single value ``c`` as
    ``log q(c) - log p(c)``. The prior log-density comes from ``self.log_p_c``
    and the posterior log-density from ``gaussian_log_pdf``. The latent KL uses
    the closed-form diagonal-Gaussian expression summed over dim 1.

    @param outputs: nested tuple as described above
    @param reduce: if truthy, return the batch mean; otherwise the unreduced loss
    """
    (c, c_mu, c_logvar), z_params, (x, x_mu) = outputs
    q_mu, q_logvar, p_mu, p_logvar = z_params

    n = x.size(0)
    log_lik = bernoulli_log_pdf(x.view(n, -1), x_mu.view(n, -1))

    # Single-sample estimate of KL(q(c) || p(c)).
    log_prior_c = self.log_p_c(c)
    log_post_c = gaussian_log_pdf(c, c_mu, c_logvar)
    kl_c = -(log_prior_c - log_post_c)

    # Closed-form KL between two diagonal Gaussians q and p.
    q_var = q_logvar.exp()
    p_var = p_logvar.exp()
    kl_z = torch.sum(
        0.5 * (p_logvar - q_logvar + ((q_mu - p_mu) ** 2 + q_var) / p_var - 1),
        dim=1,
    )

    loss = -log_lik + kl_z + kl_c
    return torch.mean(loss) if reduce else loss
def bernoulli_elbo_loss_sets(self, outputs, reduce=True):
    """Negative ELBO for a model over sets (datasets) of examples.

    ``outputs`` is unpacked as::

        ((c_mu, c_logvar), ((q_mu, q_logvar), (p_mu, p_logvar)), (x, x_mu))

    ``x`` is indexed as (n_datasets, batch_size, ...). Its first two axes are
    flattened together for the reconstruction term.

    Per dataset, the loss is the mean over its elements of
    ``-log p(x | z) + KL(q(z) || p(z))``, plus the closed-form
    ``KL(q(c) || N(0, I))`` for the dataset's context.

    @param outputs: nested tuple as described above
    @param reduce: if truthy, average over datasets; otherwise return one value
                   per dataset
    """
    (c_mu, c_logvar), ((q_mu, q_logvar), (p_mu, p_logvar)), (x, x_mu) = outputs

    n_sets = x.size(0)
    set_size = x.size(1)
    n_flat = n_sets * set_size

    # Reconstruction term, computed on the flattened batch and reshaped back to
    # (n_datasets, batch_size).
    log_lik = bernoulli_log_pdf(
        x.view(n_flat, -1), x_mu.view(n_flat, -1)
    ).view(n_sets, set_size)

    # Closed-form context KL against the standard normal -> (n_datasets,).
    kl_c = torch.sum(-0.5 * (1 + c_logvar - c_mu.pow(2) - c_logvar.exp()), dim=-1)

    # The prior parameters lack the per-element axis, so insert it at dim 1 and
    # broadcast to the posterior's shape.
    p_mu = p_mu.unsqueeze(1).expand_as(q_mu)
    p_logvar = p_logvar.unsqueeze(1).expand_as(q_logvar)
    q_var = q_logvar.exp()
    p_var = p_logvar.exp()
    kl_z = torch.sum(
        0.5 * (p_logvar - q_logvar + ((q_mu - p_mu) ** 2 + q_var) / p_var - 1),
        dim=-1,
    )  # (n_datasets, batch_size)

    per_element = -log_lik + kl_z
    # Average over the elements of each set, then add that set's context KL.
    per_set = per_element.sum(-1) / set_size + kl_c

    return torch.mean(per_set) if reduce else per_set
def bernoulli_elbo_loss(x, x_mu, z, z_mu, z_logvar):
    r"""Negative evidence lower bound for a Bernoulli-likelihood VAE, batch-averaged.

    Per example this computes

        -log p(x | z) + KL[q(z | x) || N(0, I)],

    using the closed-form Gaussian KL from Kingma & Welling. It then returns the
    mean over the batch.

    Kingma, Diederik P., and Max Welling. "Auto-encoding variational bayes."
    arXiv preprint arXiv:1312.6114 (2013).

    @param x: observed data; flattened to (batch_size, -1)
    @param x_mu: reconstructed Bernoulli means; flattened to (batch_size, -1)
    @param z: not used by this computation; kept for interface compatibility
    @param z_mu: means of the variational Gaussian; KL is summed over dim 1
    @param z_logvar: log-variances of the variational Gaussian
    @return: scalar torch.Tensor
    """
    n = x.size(0)
    nll = -bernoulli_log_pdf(x.view(n, -1), x_mu.view(n, -1))

    # Closed-form KL between N(z_mu, exp(z_logvar)) and the standard normal.
    kl = torch.sum(-0.5 * (1 + z_logvar - z_mu.pow(2) - z_logvar.exp()), dim=1)

    return torch.mean(nll + kl)