import math

import torch
from tqdm import trange


def _get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data, is_mss=True):
    batch_size, hidden_dim = latent_sample.shape

    # calculate log q(z|x)
    log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)

    # calculate log p(z)
    # mean and log var are 0
    zeros = torch.zeros_like(latent_sample)
    log_pz = log_density_gaussian(latent_sample, zeros, zeros).sum(1)

    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

    if is_mss:
        # use stratification
        log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(latent_sample.device)
        mat_log_qz = mat_log_qz + log_iw_mat.view(batch_size, batch_size, 1)

    log_qz = torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False)
    log_prod_qzi = torch.logsumexp(mat_log_qz, dim=1, keepdim=False).sum(1)

    return log_pz, log_qz, log_prod_qzi, log_q_zCx
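
# --- Hedged sketch (not part of the original file) --------------------------
# The Gaussian log-density helpers used above are assumed to be defined
# elsewhere in the project. A minimal sketch of what they could look like,
# assuming diagonal Gaussians parameterised by (mean, log_var) and the
# minibatch stratified sampling weights of Chen et al. (2018):

def log_density_gaussian(x, mu, logvar):
    """Elementwise log density of a diagonal Gaussian N(mu, exp(logvar)) at x."""
    normalization = -0.5 * (math.log(2 * math.pi) + logvar)
    inv_var = torch.exp(-logvar)
    return normalization - 0.5 * ((x - mu) ** 2 * inv_var)


def matrix_log_density_gaussian(x, mu, logvar):
    """Pairwise log q(z_i | x_j), output shape (batch_size, batch_size, dim)."""
    batch_size, dim = x.shape
    x = x.view(batch_size, 1, dim)
    mu = mu.view(1, batch_size, dim)
    logvar = logvar.view(1, batch_size, dim)
    return log_density_gaussian(x, mu, logvar)


def log_importance_weight_matrix(batch_size, dataset_size):
    """Log importance weights for minibatch stratified sampling (MSS)."""
    N, M = dataset_size, batch_size - 1
    strat_weight = (N - M) / (N * M)
    W = torch.full((batch_size, batch_size), 1 / M)
    W.view(-1)[::M + 1] = 1 / N
    W.view(-1)[1::M + 1] = strat_weight
    W[M - 1, 0] = strat_weight
    return W.log()

# The four returned log densities are typically combined into the decomposed
# KL terms of beta-TC-VAE, e.g. (illustrative only):
#   mi_loss    = (log_q_zCx - log_qz).mean()        # index-code mutual information
#   tc_loss    = (log_qz - log_prod_qzi).mean()     # total correlation
#   dw_kl_loss = (log_prod_qzi - log_pz).mean()     # dimension-wise KL
# -----------------------------------------------------------------------------
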
def _estimate_latent_entropies(self, samples_zCx, params_zCX, n_samples=10000):
    r"""Estimate :math:`H(z_j) = E_{q(z_j)}[-\log q(z_j)] = E_{p(x)} E_{q(z_j|x)}[-\log q(z_j)]`
    using the empirical distribution of :math:`p(x)`.

    Note
    ----
    - the expectation over the empirical distribution is
      :math:`q(z) = \frac{1}{N} \sum_{n=1}^N q(z|x_n)`.
    - we assume that q(z|x) is factorial, i.e. :math:`q(z|x) = \prod_j q(z_j|x)`.
    - computes a numerically stable NLL:
      :math:`-\log q(z) = \log N - \operatorname{logsumexp}_{n=1}^N \log q(z|x_n)`.

    Parameters
    ----------
    samples_zCx : torch.Tensor
        Tensor of shape (len_dataset, latent_dim) containing a sample of
        q(z|x) for every x in the dataset.

    params_zCX : tuple of torch.Tensor
        Sufficient statistics of q(z|x) for each training example, e.g. for a
        Gaussian (mean, log_var), each of shape (len_dataset, latent_dim).

    n_samples : int, optional
        Number of samples to use to estimate the entropies.

    Return
    ------
    H_z : torch.Tensor
        Tensor of shape (latent_dim,) containing the marginal entropies H(z_j).
    """
    len_dataset, latent_dim = samples_zCx.shape
    device = samples_zCx.device
    H_z = torch.zeros(latent_dim, device=device)

    # sample from p(x)
    samples_x = torch.randperm(len_dataset, device=device)[:n_samples]
    # sample from q(z|x); transpose so that row j only contains samples of z_j
    samples_zCx = samples_zCx.index_select(0, samples_x).t()

    mini_batch_size = 10
    samples_zCx = samples_zCx.expand(len_dataset, latent_dim, n_samples)
    mean = params_zCX[0].unsqueeze(-1).expand(len_dataset, latent_dim, n_samples)
    log_var = params_zCX[1].unsqueeze(-1).expand(len_dataset, latent_dim, n_samples)
    log_N = math.log(len_dataset)
    with trange(n_samples, leave=False, disable=self.is_progress_bar) as t:
        for k in range(0, n_samples, mini_batch_size):
            # log q(z_j|x) for n_samples
            idcs = slice(k, k + mini_batch_size)
            log_q_zCx = log_density_gaussian(samples_zCx[..., idcs],
                                             mean[..., idcs],
                                             log_var[..., idcs])
            # numerically stable log q(z_j) for n_samples:
            # log q(z_j) = -log N + logsumexp_{n=1}^N log q(z_j|x_n)
            # As we don't know q(z), we approximate it with the Monte Carlo
            # expectation of q(z_j|x_n) over x: fix a single z and look at the
            # probability that every x generates it. n_samples is not used here!
            log_q_z = -log_N + torch.logsumexp(log_q_zCx, dim=0, keepdim=False)
            # H(z_j) = E_{z_j}[-log q(z_j)]
            # sum over the mini batch (dimension 1, dimension 0 was reduced above);
            # the division by n_samples below turns this into a mean.
            H_z += (-log_q_z).sum(1)

            t.update(mini_batch_size)

    H_z /= n_samples

    return H_z
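
# --- Hedged sketch (not part of the original file) --------------------------
# Illustration of how the marginal entropies H(z_j) could feed into the
# Mutual Information Gap (MIG) of Chen et al. (2018). `H_zCv` (conditional
# entropies H(z_j|v_k), shape (n_factors, latent_dim)) and `H_v` (entropies of
# the ground-truth factors, shape (n_factors,)) are hypothetical inputs
# assumed to be computed elsewhere, e.g. by reusing _estimate_latent_entropies
# on the subsets of samples sharing each factor value.

def mutual_information_gap(H_z, H_zCv, H_v):
    """MIG: normalised gap between the two latents most informative about each factor."""
    # I(z_j; v_k) = H(z_j) - H(z_j | v_k), shape (n_factors, latent_dim)
    mut_info = H_z.unsqueeze(0) - H_zCv
    sorted_mi, _ = torch.sort(mut_info, dim=1, descending=True)
    # gap between the two most informative latents, normalised by the factor entropy
    return ((sorted_mi[:, 0] - sorted_mi[:, 1]) / H_v).mean()
# -----------------------------------------------------------------------------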