def eval_log_model_posterior(self, x, grid_z):
    """Perform a grid search to calculate the true posterior p(z|x).

    Args:
        x: the data tensor (or a tuple whose first element is the data tensor)
        grid_z: tensor of z points to be evaluated, with shape (k^2, nz),
            where k = (zmax - zmin) / pace

    Returns:
        Tensor: the log posterior log p(z|x), with shape (batch_size, k^2)
    """
    try:
        batch_size = x.size(0)
    except AttributeError:
        batch_size = x[0].size(0)

    # (batch_size, k^2, nz)
    grid_z = grid_z.unsqueeze(0).expand(batch_size, *grid_z.size()).contiguous()

    # (batch_size, k^2)
    log_comp = self.eval_complete_ll(x, grid_z)

    # normalize to obtain the posterior
    log_posterior = log_comp - log_sum_exp(log_comp, dim=1, keepdim=True)

    return log_posterior
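# Usage sketch (an assumption, not from the repo): `grid_z` is typically built by
# discretising a low-dimensional latent space, e.g. to visualise the posterior.
# The names `model`, `x`, `zmin`, `zmax`, and `pace` below are illustrative.
#
#     zmin, zmax, pace = -4.0, 4.0, 0.1
#     side = torch.arange(zmin, zmax, pace)
#     grid_z = torch.cartesian_prod(side, side)              # (k^2, 2), assumes nz == 2
#     log_post = model.eval_log_model_posterior(x, grid_z)   # (batch_size, k^2)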
def loss_iw(self, x0, x1, nsamples=50, ns=1):
    """Compute the importance-weighted training objective.

    Args:
        x0: the encoder-side tokenization of the data, shape (batch, *); if the
            data is variable-length, a tuple of the data tensor and a length list
        x1: the decoder-side tokenization of the same data
        nsamples: total number of posterior samples
        ns: number of samples drawn per iteration (to limit memory usage)

    Returns:
        Tensor1: importance-weighted estimate of log p(x), shape (batch,)
        Tensor2: reconstruction term log p(x|z), averaged over samples, shape (batch,)
        Tensor3: KL(q(z|x) || p(z)), shape (batch,)
    """
    # encode into BERT features
    bert_fea = self.encoder(x0)[1]

    # (batch_size, nz)
    mu, logvar = self.encoder.linear(bert_fea).chunk(2, -1)

    ##################
    # compute KL
    ##################
    KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)

    ll_tmp, rc_tmp = [], []
    for _ in range(nsamples // ns):
        # (batch, ns, nz)
        z = self.reparameterize(mu, logvar, ns)
        past = z

        # (batch, ns)
        log_prior = self.eval_prior_dist(z)
        log_gen = self.eval_cond_ll(x1, past)
        log_infer = self.eval_inference_dist(z, (mu, logvar))

        log_gen = log_gen.unsqueeze(0).contiguous().view(z.shape[0], -1)

        rc_tmp.append(log_gen)
        ll_tmp.append(log_gen + log_prior - log_infer)

    log_prob_iw = log_sum_exp(torch.cat(ll_tmp, dim=-1), dim=-1) - math.log(nsamples)
    log_gen_iw = torch.mean(torch.cat(rc_tmp, dim=-1), dim=-1)

    return log_prob_iw, log_gen_iw, KL
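# Usage sketch (assumption: `model` is the VAE wrapper holding this method and
# `x0`, `x1` are the two tokenizations of one batch; the names are illustrative):
#
#     log_prob_iw, log_gen_iw, kl = model.loss_iw(x0, x1, nsamples=50, ns=1)
#     loss = -log_prob_iw.mean()   # negative IW log-likelihood as a training loss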
def nll_iw(self, x0, x1, nsamples, ns=1):
    """Compute the importance-weighted estimate of the log-likelihood.

    Args:
        x0, x1: two different tokenizations of the same data x, where x is the
            data tensor with shape (batch, *)
        nsamples: number of samples used to estimate the marginal data likelihood
        ns: number of samples drawn per iteration (to limit memory usage)

    Returns:
        Tensor: the estimate of log p(x), shape (batch,)
    """
    # compute the importance weights ns samples at a time to address the memory issue,
    # e.g. nsamples = 500 with ns = 100, or nsamples = 500 with ns = 10

    # TODO: note that x is forwarded twice, in self.encoder.sample(x, ns) and in
    # self.eval_inference_dist(x, z, param); removing the duplicate forward pass
    # would speed this up
    tmp = []
    for _ in range(nsamples // ns):
        # [batch, ns, nz]
        # encode into BERT features
        pooled_hidden_fea = self.encoder(x0)[1]

        # param holds the parameters required to evaluate q(z|x)
        z, param = self.encoder_sample(pooled_hidden_fea, ns)

        # [batch, ns]
        log_comp_ll = self.eval_complete_ll(x1, z)
        log_infer_ll = self.eval_inference_dist(z, param)

        tmp.append(log_comp_ll - log_infer_ll)

    ll_iw = log_sum_exp(torch.cat(tmp, dim=-1), dim=-1) - math.log(nsamples)

    return ll_iw
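# Usage sketch (assumption: perplexity is computed from the summed estimate and a
# corpus token count obtained elsewhere; `report_num_words` is illustrative):
#
#     ll_iw = model.nll_iw(x0, x1, nsamples=500, ns=100)   # (batch,)
#     nll = -ll_iw.sum().item()
#     ppl = math.exp(nll / report_num_words)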
def calc_mi(model, test_data_batch):
    """Estimate the mutual information I(x; z) under q(z|x) (calc_mi_v3):
    a Monte Carlo estimate of E_x[E_{q(z|x)} log q(z|x)] minus an estimate of
    E_{q(z)} log q(z) computed from the aggregated posterior."""
    import math

    import numpy as np
    import torch

    from modules.utils import log_sum_exp

    mi = 0
    num_examples = 0

    mu_batch_list, logvar_batch_list = [], []
    neg_entropy = 0.
    for batch_data in test_data_batch:
        mu, logvar = model.encoder.forward(batch_data)
        x_batch, nz = mu.size()
        num_examples += x_batch

        # E_{q(z|x)} log q(z|x) = -0.5*nz*log(2*pi) - 0.5*(1 + logvar).sum(-1)
        neg_entropy += (-0.5 * nz * math.log(2 * math.pi) - 0.5 * (1 + logvar).sum(-1)).sum().item()
        mu_batch_list += [mu.cpu()]
        logvar_batch_list += [logvar.cpu()]

    neg_entropy = neg_entropy / num_examples

    num_examples = 0
    log_qz = 0.
    for i in range(len(mu_batch_list)):
        ###############
        # get z_samples
        ###############
        mu, logvar = mu_batch_list[i].cuda(), logvar_batch_list[i].cuda()

        # [z_batch, 1, nz]
        if hasattr(model.encoder, 'reparameterize'):
            z_samples = model.encoder.reparameterize(mu, logvar, 1)
        else:
            z_samples = model.encoder.gaussian_enc.reparameterize(mu, logvar, 1)

        z_samples = z_samples.view(-1, 1, nz)
        num_examples += z_samples.size(0)

        ###############
        # compute density
        ###############
        # [1, x_batch, nz]
        indices = np.arange(len(mu_batch_list))
        mu = torch.cat([mu_batch_list[_] for _ in indices], dim=0).cuda()
        logvar = torch.cat([logvar_batch_list[_] for _ in indices], dim=0).cuda()
        x_batch, nz = mu.size()

        mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0)
        var = logvar.exp()

        # (z_batch, x_batch, nz)
        dev = z_samples - mu

        # (z_batch, x_batch)
        log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
            0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))

        # log q(z): aggregate posterior
        # [z_batch]
        log_qz += (log_sum_exp(log_density, dim=1) - math.log(x_batch)).sum(-1)

    log_qz /= num_examples
    mi = neg_entropy - log_qz

    return mi
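# Usage sketch (assumption: `model.encoder.forward(batch)` returns (mu, logvar) as
# expected above and the batches already live on the GPU; names are illustrative):
#
#     mi = calc_mi(model, test_data_batch)
#     print("I(x; z) estimate:", float(mi))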
def calc_mi(self, test_data_batch, args):
    """Estimate the mutual information I(x; z) under q(z|x) (calc_mi_v3), using the
    BERT encoder features to obtain the posterior parameters for each batch."""
    import math

    import numpy as np
    import torch

    from modules.utils import log_sum_exp

    mi = 0
    num_examples = 0

    mu_batch_list, logvar_batch_list = [], []
    neg_entropy = 0.
    for batch_data in test_data_batch:
        x0, _, _ = batch_data
        x0 = x0.to(args.device)

        # encode into BERT features
        bert_fea = self.encoder(x0)[1]

        # (batch_size, nz)
        mu, logvar = self.encoder.linear(bert_fea).chunk(2, -1)

        x_batch, nz = mu.size()
        num_examples += x_batch

        # E_{q(z|x)} log q(z|x) = -0.5*nz*log(2*pi) - 0.5*(1 + logvar).sum(-1)
        neg_entropy += (-0.5 * nz * math.log(2 * math.pi) - 0.5 * (1 + logvar).sum(-1)).sum().item()
        mu_batch_list += [mu.cpu()]
        logvar_batch_list += [logvar.cpu()]

    neg_entropy = neg_entropy / num_examples

    num_examples = 0
    log_qz = 0.
    for i in range(len(mu_batch_list)):
        ###############
        # get z_samples
        ###############
        mu, logvar = mu_batch_list[i].cuda(), logvar_batch_list[i].cuda()

        # [z_batch, 1, nz]
        z_samples = self.reparameterize(mu, logvar, 1)
        z_samples = z_samples.view(-1, 1, nz)
        num_examples += z_samples.size(0)

        ###############
        # compute density
        ###############
        # [1, x_batch, nz]
        indices = np.arange(len(mu_batch_list))
        mu = torch.cat([mu_batch_list[_] for _ in indices], dim=0).cuda()
        logvar = torch.cat([logvar_batch_list[_] for _ in indices], dim=0).cuda()
        x_batch, nz = mu.size()

        mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0)
        var = logvar.exp()

        # (z_batch, x_batch, nz)
        dev = z_samples - mu

        # (z_batch, x_batch)
        log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
            0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))

        # log q(z): aggregate posterior
        # [z_batch]
        log_qz += (log_sum_exp(log_density, dim=1) - math.log(x_batch)).sum(-1)

    log_qz /= num_examples
    mi = neg_entropy - log_qz

    return mi
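# Usage sketch (assumption: `test_data_batch` yields (x0, x1, lengths)-style tuples so
# the unpacking above works, and `args.device` is set; names are illustrative):
#
#     mi = model.calc_mi(test_data_batch, args)
#     print("mutual information:", float(mi))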