def log_bernoulli_marginal_estimate(x, x_mu_list, z_list, z_mu, z_logvar):
    r"""Estimate log p(x) with an importance-weighted (IWAE-style) average.

    NOTE: this is not the objective that should be directly optimized.
    Returns the negative of the batch-averaged estimate.

    @param x: torch.Tensor (batch_size x input_dim)
              original observed data
    @param x_mu_list: list of torch.Tensor (batch_size x input_dim)
                      reconstructed Bernoulli means
    @param z_list: list of torch.Tensor (batch_size x z_dim)
                   samples drawn from the variational distribution
    @param z_mu: torch.Tensor (batch_size x # samples x z_dim)
                 means of the variational distribution
    @param z_logvar: torch.Tensor (batch_size x # samples x z_dim)
                     log-variances of the variational distribution
    """
    k = len(z_list)
    batch_size = x.size(0)

    log_w = []
    for i in range(k):
        log_p_x_given_z_i = bernoulli_log_pdf(
            x.view(batch_size, -1),
            x_mu_list[i].view(batch_size, -1))
        log_q_z_given_x_i = gaussian_log_pdf(z_list[i], z_mu, z_logvar)
        log_p_z_i = unit_gaussian_log_pdf(z_list[i])

        # unnormalized importance weight for sample i:
        # log p(x|z_i) + log p(z_i) - log q(z_i|x)
        log_w_i = log_p_x_given_z_i + log_p_z_i - log_q_z_given_x_i
        log_w.append(log_w_i)

    log_w = torch.stack(log_w).t()  # (batch_size, k)
    # need to compute normalization constant for weights,
    # i.e. log ( mean ( exp ( log_weights ) ) )
    log_p_x = log_mean_exp(log_w, dim=1)
    return -torch.mean(log_p_x)
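# The losses in this section call bernoulli_log_pdf, gaussian_log_pdf,
# unit_gaussian_log_pdf, and log_mean_exp, which are not shown here. Below is a
# minimal sketch of such helpers, assuming diagonal Gaussians parameterized by
# log-variance and sums over the feature dimension; it is not necessarily the
# exact implementation used in this codebase.
import math
import torch


def bernoulli_log_pdf(x, mu, eps=1e-7):
    # log p(x) under independent Bernoullis with means mu, summed over features
    mu = mu.clamp(eps, 1 - eps)
    return torch.sum(x * torch.log(mu) + (1 - x) * torch.log(1 - mu), dim=1)


def gaussian_log_pdf(x, mu, logvar):
    # log N(x; mu, diag(exp(logvar))), summed over the last dimension
    return torch.sum(
        -0.5 * (math.log(2 * math.pi) + logvar + (x - mu)**2 / logvar.exp()),
        dim=-1)


def unit_gaussian_log_pdf(x):
    # log N(x; 0, I), summed over the last dimension
    return torch.sum(-0.5 * (math.log(2 * math.pi) + x**2), dim=-1)


def log_mean_exp(x, dim=1):
    # numerically stable log(mean(exp(x))) along `dim`
    m, _ = torch.max(x, dim=dim, keepdim=True)
    out = m + torch.log(torch.mean(torch.exp(x - m), dim=dim, keepdim=True))
    return out.squeeze(dim)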
def gaussian_elbo_loss(x, x_mu, x_logvar, z, z_mu, z_logvar):
    r"""Negative evidence lower bound (ELBO) for a Gaussian likelihood.

    Returns a scalar loss to minimize: E_q[-log p(x|z)] + KL(q(z|x) || p(z)),
    averaged over the batch.
    """
    # reconstruction term: negative Gaussian log-likelihood of x under the decoder
    log_p_x_given_z = -gaussian_log_pdf(x, x_mu, x_logvar)

    # analytic KL between q(z|x) = N(z_mu, diag(exp(z_logvar))) and p(z) = N(0, I)
    kl_divergence = -0.5 * (1 + z_logvar - z_mu.pow(2) - z_logvar.exp())
    kl_divergence = torch.sum(kl_divergence, dim=1)

    elbo = log_p_x_given_z + kl_divergence  # negative ELBO per example
    elbo = torch.mean(elbo)
    return elbo
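# A minimal, hypothetical usage of gaussian_elbo_loss with the reparameterization
# trick. The linear encoder/decoder below are toy stand-ins, not the actual model;
# all names and sizes here are illustrative assumptions.
import torch
import torch.nn as nn

input_dim, z_dim = 784, 20
enc = nn.Linear(input_dim, 2 * z_dim)   # toy encoder: outputs concatenated (mu, logvar)
dec = nn.Linear(z_dim, 2 * input_dim)   # toy decoder: outputs concatenated (mu, logvar)

x = torch.randn(32, input_dim)
z_mu, z_logvar = enc(x).chunk(2, dim=1)
z = z_mu + torch.exp(0.5 * z_logvar) * torch.randn_like(z_mu)   # reparameterize
x_mu, x_logvar = dec(z).chunk(2, dim=1)

loss = gaussian_elbo_loss(x, x_mu, x_logvar, z, z_mu, z_logvar)
loss.backward()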
def log_p_c(self, c):
    # learned prior over the dataset-level context c: a uniform mixture of the
    # posteriors produced by the statistic network at learned pseudo-datasets
    x_flat = self.means(self.idle_input)
    x_dset = x_flat.view(self.num_components, self.pseudoinputs_samples, 1,
                         self.image_size, self.image_size)

    h_dset = self.encoder_net(x_dset)
    h_dset = h_dset.view(self.num_components, self.pseudoinputs_samples, 256 * 4 * 4)
    c_p_mean, c_p_logvar = self.statistic_net(h_dset)

    c_expand = c.unsqueeze(1)
    means = c_p_mean.unsqueeze(0)
    logvars = c_p_logvar.unsqueeze(0)

    a = gaussian_log_pdf(c_expand, means, logvars) - math.log(self.num_components)  # MB x C
    a_max, _ = torch.max(a, 1)  # MB x 1
    # stabilized log-sum-exp over the mixture components
    log_prior = a_max + torch.log(torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))  # MB x 1
    return log_prior
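# log_p_c reads its mixture components from learned pseudo-inputs via
# self.means(self.idle_input), which this section does not define. The sketch
# below shows one common construction (illustrative hyperparameters, not
# necessarily the exact setup used here): a fixed identity matrix mapped through
# a linear layer, one pseudo-dataset per mixture component.
import torch
import torch.nn as nn

num_components, pseudoinputs_samples, image_size = 50, 5, 28  # illustrative values

idle_input = torch.eye(num_components)  # one one-hot row per mixture component
means = nn.Sequential(
    nn.Linear(num_components, pseudoinputs_samples * image_size * image_size),
    nn.Sigmoid())  # keep generated pseudo-pixels in [0, 1]

x_flat = means(idle_input)  # (num_components, pseudoinputs_samples * image_size**2)
x_dset = x_flat.view(num_components, pseudoinputs_samples, 1, image_size, image_size)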
def compiled_inference_objective(z, z_mu, z_logvar):
    r"""Objective for compiled (amortized) inference.

    NOTE: (x, z) are sampled from p(x, z), a known graphical model.
    Compiled inference uses a different objective: https://arxiv.org/pdf/1610.09900.pdf

    Derivation of the objective:

        loss = E_{p(x)}[ KL( p(z|x) || q_\phi(z|x) ) ]
             = \int_x p(x) \int_z p(z|x) \log \frac{p(z|x)}{q_\phi(z|x)} \, dz \, dx
             = \int_x \int_z p(x, z) \log \frac{p(z|x)}{q_\phi(z|x)} \, dz \, dx
             = E_{p(x,z)}[ \log p(z|x) ] - E_{p(x,z)}[ \log q_\phi(z|x) ]

    The first term does not depend on \phi, so minimizing the loss is
    equivalent to minimizing E_{p(x,z)}[ -\log q_\phi(z|x) ].
    """
    log_q_z_given_x = gaussian_log_pdf(z, z_mu, z_logvar)
    return -torch.mean(log_q_z_given_x)
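# Illustrative use of compiled_inference_objective: (x, z) pairs are drawn from a
# known generative model, and the objective trains q_\phi(z|x) to approximate
# p(z|x). The toy model and the linear "inference network" below are assumptions
# for the sake of a runnable example.
import torch
import torch.nn as nn

inference_net = nn.Linear(1, 2)  # maps x to concatenated (z_mu, z_logvar)

z_true = torch.randn(128, 1)             # z ~ p(z) = N(0, 1)
x = z_true + 0.5 * torch.randn(128, 1)   # x ~ p(x|z) = N(z, 0.5^2)

z_mu, z_logvar = inference_net(x).chunk(2, dim=1)
loss = compiled_inference_objective(z_true, z_mu, z_logvar)
loss.backward()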
def log_p_c(self, c):
    # this is a function now thanks to a learned prior
    x_flat = self.means(self.idle_input)
    x_dset = x_flat.view(self.num_components, self.pseudoinputs_samples,
                         self.input_dim)
    c_p_mean, c_p_logvar = self.statistic_net(x_dset)

    c_expand = c.unsqueeze(1)
    means = c_p_mean.unsqueeze(0)
    logvars = c_p_logvar.unsqueeze(0)

    a = gaussian_log_pdf(c_expand, means, logvars) - math.log(self.num_components)  # MB x C
    a_max, _ = torch.max(a, 1)  # MB x 1
    log_prior = a_max + torch.log(torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))  # MB x 1
    return log_prior
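# Both log_p_c variants stabilize the mixture sum by subtracting the per-row
# maximum before exponentiating; torch.logsumexp computes the same quantity in
# one call. A quick equivalence check on a random matrix standing in for the
# MB x C array of component log-densities:
import torch

a = torch.randn(4, 10)
a_max, _ = torch.max(a, 1)
manual = a_max + torch.log(torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))
assert torch.allclose(manual, torch.logsumexp(a, dim=1))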
def bernoulli_elbo(self, outputs, reduce=True):
    (c, c_mu, c_logvar), (q_mu, q_logvar, p_mu, p_logvar), (x, x_mu) = outputs
    batch_size = x.size(0)

    # Bernoulli log-likelihood of the reconstruction
    recon_loss = bernoulli_log_pdf(x.view(batch_size, -1),
                                   x_mu.view(batch_size, -1))

    # KL(q(c|D) || p(c)), estimated with a single sample of c
    log_p_c = self.log_p_c(c)
    log_q_c = gaussian_log_pdf(c, c_mu, c_logvar)
    kl_c = -(log_p_c - log_q_c)

    # analytic KL(q(z|x,c) || p(z|c)) between two diagonal Gaussians
    kl_z = 0.5 * (p_logvar - q_logvar +
                  ((q_mu - p_mu)**2 + q_logvar.exp()) / p_logvar.exp() - 1)
    kl_z = torch.sum(kl_z, dim=1)

    # negative ELBO (a loss to minimize)
    ELBO = -recon_loss + kl_z + kl_c
    if reduce:
        return torch.mean(ELBO)
    else:
        return ELBO  # (n_datasets)
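# bernoulli_elbo expects `outputs` packed as three tuples, mirroring the
# unpacking above. The shapes and names below are illustrative assumptions only,
# meant to show the expected structure rather than the actual forward pass.
import torch

n, x_dim, c_dim, z_dim = 8, 784, 64, 32
c, c_mu, c_logvar = torch.randn(n, c_dim), torch.randn(n, c_dim), torch.randn(n, c_dim)
q_mu, q_logvar = torch.randn(n, z_dim), torch.randn(n, z_dim)
p_mu, p_logvar = torch.randn(n, z_dim), torch.randn(n, z_dim)
x, x_mu = torch.rand(n, x_dim), torch.rand(n, x_dim)

outputs = ((c, c_mu, c_logvar),               # context sample and its posterior
           (q_mu, q_logvar, p_mu, p_logvar),  # posterior q(z|x,c) and prior p(z|c)
           (x, x_mu))                         # data and Bernoulli reconstruction means
# loss = model.bernoulli_elbo(outputs)  # `model` here is hypothetical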