def sample(self, mu, logvar, grad_fn=lambda x: 1, x_info=None):
    # reparameterized sample from q(z|x): z = mu + sigma * eps
    eps = Variable(torch.FloatTensor(mu.size()).normal_().type(self.dtype))
    z = eps.mul(logvar.mul(0.5).exp_()).add_(mu)
    logqz = log_normal(z, mu, logvar)

    if self.has_flow:
        z, logprob = self.q_dist.forward(z, grad_fn, x_info)
        logqz += logprob

    zeros = Variable(torch.zeros(z.size()).type(self.dtype))
    logpz = log_normal(z, zeros, zeros)

    return z, logpz, logqz
def sample(self, mu, logvar, grad_fn=lambda x: 1, x_info=None):
    # grad_fn default is identity, i.e. don't use grad info
    eps = Variable(torch.randn(mu.size()).type(self.dtype))
    z = eps.mul(logvar.mul(0.5).exp()).add(mu)
    logqz = log_normal(z, mu, logvar)

    if self.has_flow:
        z, logprob = self.q_dist.forward(z, grad_fn, x_info)
        logqz += logprob

    zeros = Variable(torch.zeros(z.size()).type(self.dtype))
    logpz = log_normal(z, zeros, zeros)

    return z, logpz, logqz
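# The excerpts in this listing rely on a few helpers from the surrounding
# codebase: `log_normal`, `log_bernoulli` / `log_likelihood`, `log_mean_exp`,
# and `safe_repeat`. Their actual implementations are not reproduced here;
# the following is a minimal sketch of the assumed semantics, written against
# current PyTorch (where Variable and Tensor are merged). Shapes follow the
# usage above: per-example log-densities of shape (batch,).
import math

import torch
import torch.nn.functional as F


def log_normal(x, mean, logvar):
    # diagonal Gaussian log-density, summed over the latent dimensions
    return -0.5 * (math.log(2 * math.pi) + logvar
                   + (x - mean).pow(2) / logvar.exp()).sum(dim=1)


def log_bernoulli(logits, target):
    # Bernoulli log-likelihood of `target` given decoder logits, summed over pixels
    return -F.binary_cross_entropy_with_logits(
        logits, target, reduction='none').sum(dim=1)


def log_mean_exp(x, dim=1):
    # numerically stable log(mean(exp(x))) along `dim`
    return torch.logsumexp(x, dim=dim) - math.log(x.size(dim))


def safe_repeat(x, k):
    # tile a batch k times along dim 0, e.g. (B, D) -> (B * k, D)
    return x.repeat(k, *([1] * (x.dim() - 1)))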
def log_joint(x_logits, x, z):
    """log p(x, z)"""
    zeros = Variable(torch.zeros(z.size()).type(model.dtype))
    logpz = log_normal(z, zeros, zeros)
    logpx = log_likelihood(x_logits, x)
    return logpx + logpz
def log_f_i(z, data, t, log_likelihood_fn=log_bernoulli):
    """Unnormalized density for intermediate distribution `f_i`:

        f_i = p(z)^(1 - t) p(x, z)^t = p(z) p(x|z)^t
        => log f_i = log p(z) + t * log p(x|z)
    """
    zeros = Variable(torch.zeros(B, z_size).type(mdtype))
    log_prior = log_normal(z, zeros, zeros)
    log_likelihood = log_likelihood_fn(model.decode(z), data)
    return log_prior + log_likelihood.mul_(t)
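# `log_f_i` defines the geometric path from the prior p(z) at t = 0 to the
# joint p(x, z) at t = 1. The sketch below shows how such a density is
# typically used to accumulate AIS log importance weights over a temperature
# schedule; `sample_prior` and `transition` are hypothetical placeholders for
# the prior sampler and an MCMC kernel (e.g. HMC) that leaves f_t invariant,
# not functions from this codebase.
def ais_log_weights(data, schedule, sample_prior, transition):
    # schedule: increasing temperatures with schedule[0] == 0.0, schedule[-1] == 1.0
    z = sample_prior()
    log_w = torch.zeros(z.size(0))
    for t_prev, t_next in zip(schedule[:-1], schedule[1:]):
        # weight update for moving the target from f_{t_prev} to f_{t_next}
        log_w = log_w + log_f_i(z, data, t_next) - log_f_i(z, data, t_prev)
        # move z under a kernel targeting f_{t_next}
        z = transition(z, data, t_next)
    return log_w  # log_mean_exp(log_w) is a stochastic lower bound on log p(x)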
def _sample(self, z0, grad_fn, x_info):
    B = z0.size(0)
    z_size = self.z_size
    act_func = F.elu
    qv_weights, rv_weights, params = self.qv_weights, self.rv_weights, self.params

    # auxiliary model q(v0|x,z0): predict mean/logvar of the auxiliary variable
    out = torch.cat((z0, x_info), dim=1)
    for i in range(len(qv_weights) - 1):
        out = act_func(qv_weights[i](out))
    out = qv_weights[-1](out)
    mean_v0, logvar_v0 = out[:, :z_size], out[:, z_size:]

    eps = Variable(torch.randn(B, z_size).type(type(out.data)))
    v0 = eps.mul(logvar_v0.mul(0.5).exp_()) + mean_v0
    logqv0 = log_normal(v0, mean_v0, logvar_v0)

    # run the flow on the (z, v) pair, accumulating log-determinants
    zT, vT = z0, v0
    logdetsum = 0.
    for i in range(self.n_flows):
        zT, vT, logdet = self._norm_flow(params[i], zT, vT, grad_fn, x_info)
        logdetsum += logdet

    # reverse model, r(vT|x,zT)
    out = torch.cat((zT, x_info), dim=1)
    for i in range(len(rv_weights) - 1):
        out = act_func(rv_weights[i](out))
    out = rv_weights[-1](out)
    mean_vT, logvar_vT = out[:, :z_size], out[:, z_size:]
    logrvT = log_normal(vT, mean_vT, logvar_vT)

    assert logqv0.size() == (B,)
    assert logdetsum.size() == (B,)
    assert logrvT.size() == (B,)

    logprob = logqv0 - logdetsum - logrvT
    return zT, logprob
def sample(mean_v0, logvar_v0):
    B = mean_v0.size()[0]

    # sample the auxiliary variable v0 from the given Gaussian parameters
    eps = Variable(
        torch.FloatTensor(B, z_size).normal_().type(model.dtype))
    v0 = eps.mul(logvar_v0.mul(0.5).exp_()) + mean_v0
    logqv0 = log_normal(v0, mean_v0, logvar_v0)

    # map v0 to the initial latent: z0 ~ q(z0|v0)
    out = v0
    for i in range(len(qz_weights) - 1):
        out = act_func(qz_weights[i](out))
    out = qz_weights[-1](out)
    mean_z0, logvar_z0 = out[:, :z_size], out[:, z_size:]

    eps = Variable(
        torch.FloatTensor(B, z_size).normal_().type(model.dtype))
    z0 = eps.mul(logvar_z0.mul(0.5).exp_()) + mean_z0
    logqz0 = log_normal(z0, mean_z0, logvar_z0)

    # run the flow on the (z, v) pair, accumulating log-determinants
    zT, vT = z0, v0
    logdetsum = 0.
    for i in range(n_flows):
        zT, vT, logdet = norm_flow(params[i], zT, vT)
        logdetsum += logdet

    # reverse model, r(vT|x,zT)
    out = zT
    for i in range(len(rv_weights) - 1):
        out = act_func(rv_weights[i](out))
    out = rv_weights[-1](out)
    mean_vT, logvar_vT = out[:, :z_size], out[:, z_size:]
    logrvT = log_normal(vT, mean_vT, logvar_vT)

    logq = logqz0 + logqv0 - logdetsum - logrvT
    return zT, logq
def U(z):
    # energy as -log p(x, z), with a standard normal prior on z
    logpx = log_bernoulli(self.decode(z), x)
    zeros = Variable(torch.zeros(z.size()).type(self.dtype))
    logpz = log_normal(z, zeros, zeros)
    return -logpx - logpz
def optimize_local_gaussian(log_likelihood, model, data_var,
                            k=100, check_every=100, sentinel_thres=10,
                            debug=False):
    """Optimize a per-datapoint factorized Gaussian q(z|x); data_var should be a (CUDA) Variable."""
    B = data_var.size()[0]
    z_size = model.z_size
    data_var = safe_repeat(data_var, k)

    zeros = Variable(torch.zeros(B * k, z_size).type(model.dtype))
    mean = Variable(torch.zeros(B * k, z_size).type(model.dtype),
                    requires_grad=True)
    logvar = Variable(torch.zeros(B * k, z_size).type(model.dtype),
                      requires_grad=True)
    optimizer = optim.Adam([mean, logvar], lr=1e-3)

    best_avg, sentinel, prev_seq = 999999, 0, []

    # perform local optimization of the variational parameters
    time_ = time.time()
    for epoch in range(1, 999999):
        eps = Variable(
            torch.FloatTensor(mean.size()).normal_().type(model.dtype))
        z = eps.mul(logvar.mul(0.5).exp_()).add_(mean)
        x_logits = model.decode(z)

        logpz = log_normal(z, zeros, zeros)
        logqz = log_normal(z, mean, logvar)
        logpx = log_likelihood(x_logits, data_var)

        optimizer.zero_grad()
        loss = -torch.mean(logpx + logpz - logqz)
        loss_np = loss.data.cpu().numpy()
        loss.backward()
        optimizer.step()

        prev_seq.append(loss_np)
        if epoch % check_every == 0:
            last_avg = np.mean(prev_seq)
            if debug:  # debugging helper
                sys.stderr.write(
                    'Epoch %d, time elapsed %.4f, last avg %.4f, prev best %.4f\n' %
                    (epoch, time.time() - time_, -last_avg, -best_avg)
                )
            if last_avg < best_avg:
                sentinel, best_avg = 0, last_avg
            else:
                sentinel += 1
            if sentinel > sentinel_thres:
                break
            prev_seq = []
            time_ = time.time()

    # evaluation
    eps = Variable(
        torch.FloatTensor(B * k, z_size).normal_().type(model.dtype))
    z = eps.mul(logvar.mul(0.5).exp_()).add_(mean)

    logpz = log_normal(z, zeros, zeros)
    logqz = log_normal(z, mean, logvar)
    logpx = log_likelihood(model.decode(z), data_var)
    elbo = logpx + logpz - logqz

    vae_elbo = torch.mean(elbo)
    iwae_elbo = torch.mean(log_mean_exp(elbo.view(k, -1).transpose(0, 1)))

    return vae_elbo.data[0], iwae_elbo.data[0]
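# Illustrative usage, assuming `model` and a binarized `batch` already on the
# model's device (these names are placeholders, not from the original script):
#
#     vae_elbo, iwae_elbo = optimize_local_gaussian(log_bernoulli, model, batch, k=100)
#     print('local FFG ELBO: %.4f, IWAE bound: %.4f' % (vae_elbo, iwae_elbo))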
def normalized_kinetic(v):
    zeros = Variable(torch.zeros(B, z_size).type(mdtype))
    # this is superior to the unnormalized version
    return -log_normal(v, zeros, zeros)
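# `U` and `normalized_kinetic` together give the Hamiltonian
# H(z, v) = U(z) + K(v) = -log p(x, z) - log N(v; 0, I) used by the HMC
# transitions. Below is a minimal leapfrog + Metropolis accept/reject sketch
# under that Hamiltonian, written against current PyTorch; step-size
# adaptation and the repo's batching details are omitted, `U` is assumed to
# return a per-sample energy of shape (B,), and the kinetic term's additive
# constant is dropped since it cancels in the accept ratio.
def hmc_step(z, U, step_size=0.1, n_leapfrog=10):
    def grad_U(z_in):
        z_in = z_in.detach().requires_grad_(True)
        return torch.autograd.grad(U(z_in).sum(), z_in)[0]

    v0 = torch.randn_like(z)
    z_new, v = z.detach(), v0.clone()

    # leapfrog integration of Hamilton's equations
    v = v - 0.5 * step_size * grad_U(z_new)
    for _ in range(n_leapfrog - 1):
        z_new = z_new + step_size * v
        v = v - step_size * grad_U(z_new)
    z_new = z_new + step_size * v
    v = v - 0.5 * step_size * grad_U(z_new)

    # per-sample Metropolis-Hastings correction on the joint Hamiltonian
    current_h = U(z) + 0.5 * v0.pow(2).sum(dim=1)
    proposed_h = U(z_new) + 0.5 * v.pow(2).sum(dim=1)
    accept = (torch.rand_like(current_h) < torch.exp(current_h - proposed_h)).float()
    accept = accept.unsqueeze(1)
    return accept * z_new + (1 - accept) * z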