def _mog_log_prob(
    theta: Tensor,
    logits_pp: Tensor,
    means_pp: Tensor,
    precisions_pp: Tensor,
) -> Tensor:
    r"""
    Returns the log-probability of parameter sets $\theta$ under a mixture of
    Gaussians.

    Note that the mixture can have different logits, means, covariances for any
    theta in the batch. This is because these values were computed from a batch
    of $x$ (and the $x$ in the batch are not the same).

    This code is similar to the code of mdn.py in pyknos, but it does not use
    log(det(Cov)) = -2*sum(log(diag(L))), L being Cholesky of Precision.
    Instead, it uses `torch.logdet` on the precision matrices directly. Also,
    it uses the helper `batched_mixture_vmv()`.

    Args:
        theta: Parameters at which to evaluate the mixture.
        logits_pp: (Unnormalized) mixture components.
        means_pp: Means of all mixture components. Shape
            (batch_dim, num_components, theta_dim).
        precisions_pp: Precisions of all mixtures. Shape
            (batch_dim, num_components, theta_dim, theta_dim).

    Returns:
        The log-probability.
    """
    _, _, output_dim = means_pp.size()
    theta = theta.view(-1, 1, output_dim)

    # Split up evaluation into parts: normalized mixture weights, Gaussian
    # normalizing constant, log-determinant term, and quadratic exponent.
    weights = logits_pp - torch.logsumexp(logits_pp, dim=-1, keepdim=True)

    # BUGFIX: the constant used to be allocated on a hard-coded CUDA device
    # (device=torch.device('cuda')), which crashes on CPU-only machines and
    # mismatches inputs living on CPU or another GPU. Allocate it on the same
    # device (and dtype) as the evaluated parameters instead.
    constant = -(output_dim / 2.0) * torch.log(
        torch.tensor(2 * pi, dtype=theta.dtype, device=theta.device)
    )

    # 0.5 * log(det(Precision)) == -0.5 * log(det(Cov)). Using torch.logdet
    # avoids overflow/underflow of the raw determinant for larger dimensions,
    # compared to torch.log(torch.det(...)).
    log_det = 0.5 * torch.logdet(precisions_pp)

    theta_minus_mean = theta.expand_as(means_pp) - means_pp
    exponent = -0.5 * batched_mixture_vmv(precisions_pp, theta_minus_mean)

    return torch.logsumexp(weights + constant + log_det + exponent, dim=-1)
def _logits_proposal_posterior(
    means_pp: Tensor,
    precisions_pp: Tensor,
    covariances_pp: Tensor,
    logits_p: Tensor,
    means_p: Tensor,
    precisions_p: Tensor,
    logits_d: Tensor,
    means_d: Tensor,
    precisions_d: Tensor,
):
    """
    Return the component weights (i.e. logits) of the proposal posterior.

    Args:
        means_pp: Means of the proposal posterior.
        precisions_pp: Precision matrices of the proposal posterior.
        covariances_pp: Covariance matrices of the proposal posterior.
        logits_p: Component weights (i.e. logits) of the proposal distribution.
        means_p: Means of the proposal distribution.
        precisions_p: Precision matrices of the proposal distribution.
        logits_d: Component weights (i.e. logits) of the density estimator.
        means_d: Means of the density estimator.
        precisions_d: Precision matrices of the density estimator.

    Returns:
        Component weights of the proposal posterior. L*K terms.
    """
    num_comps_p = precisions_p.shape[1]
    num_comps_d = precisions_d.shape[1]

    def all_pairs(term_p: Tensor, term_d: Tensor):
        # Lay out every (i, j) pairing of the L proposal components with the
        # K density-estimator components: proposal terms vary slowly
        # (repeat_interleave), density-estimator terms vary fast (repeat).
        return (
            term_p.repeat_interleave(num_comps_d, dim=1),
            term_d.repeat(1, num_comps_p),
        )

    # log(alpha_i * beta_j) for all L*K component pairs.
    logits_p_rep, logits_d_rep = all_pairs(logits_p, logits_d)
    logit_factors = logits_p_rep + logits_d_rep

    # log sqrt(det(S_pp) / (det(S_p) * det(S_d))), where the proposal and
    # density-estimator covariance log-determinants are obtained as the
    # negated log-determinants of the corresponding precision matrices.
    logdet_p_rep, logdet_d_rep = all_pairs(
        -torch.logdet(precisions_p), -torch.logdet(precisions_d)
    )
    log_sqrt_det_ratio = 0.5 * (
        torch.logdet(covariances_pp) - (logdet_p_rep + logdet_d_rep)
    )

    # mu.T * P * mu for proposal, density estimator, and proposal posterior.
    exponent_p_rep, exponent_d_rep = all_pairs(
        batched_mixture_vmv(precisions_p, means_p),
        batched_mixture_vmv(precisions_d, means_d),
    )
    exponent_pp = batched_mixture_vmv(precisions_pp, means_pp)
    exponent = -0.5 * (exponent_p_rep + exponent_d_rep - exponent_pp)

    return logit_factors + log_sqrt_det_ratio + exponent
def _logits_posterior(
    means_post: Tensor,
    precisions_post: Tensor,
    covariances_post: Tensor,
    logits_pp: Tensor,
    means_pp: Tensor,
    precisions_pp: Tensor,
    logits_d: Tensor,
    means_d: Tensor,
    precisions_d: Tensor,
):
    r"""
    Return the component weights (i.e. logits) of the MoG posterior.

    $\alpha_k^\prime = \frac{ \alpha_k exp(-0.5 c_k) }{ \sum{j} \alpha_j
        exp(-0.5 c_j) } $
    with
    $c_k = logdet(S_k) - logdet(S_0) - logdet(S_k^\prime) +
        + m_k^T P_k m_k - m_0^T P_0 m_0 - m_k^\prime^T P_k^\prime m_k^\prime$
    (see eqs. (25, 26) in Appendix C of [1])

    Args:
        means_post: Means of the posterior.
        precisions_post: Precision matrices of the posterior.
        covariances_post: Covariance matrices of the posterior.
        logits_pp: Component weights (i.e. logits) of the proposal prior.
        means_pp: Means of the proposal prior.
        precisions_pp: Precision matrices of the proposal prior.
        logits_d: Component weights (i.e. logits) of the density estimator.
        means_d: Means of the density estimator.
        precisions_d: Precision matrices of the density estimator.

    Returns:
        Component weights of the proposal posterior.
    """
    num_comps_pp = precisions_pp.shape[1]
    num_comps_d = precisions_d.shape[1]

    def all_pairs(term_pp: Tensor, term_d: Tensor):
        # Lay out every (k, j) pairing of the proposal-prior components with
        # the density-estimator components: proposal-prior terms vary slowly
        # (repeat_interleave), density-estimator terms vary fast (repeat).
        return (
            term_pp.repeat_interleave(num_comps_d, dim=1),
            term_d.repeat(1, num_comps_pp),
        )

    # Ratio of the logits, similar to eq (10) in Appendix A.1 of [2].
    logits_pp_rep, logits_d_rep = all_pairs(logits_pp, logits_d)
    logit_factors = logits_d_rep - logits_pp_rep

    # Log-determinants; covariance logdets of the proposal prior and density
    # estimator are the negated logdets of their precision matrices.
    logdet_pp_rep, logdet_d_rep = all_pairs(
        -torch.logdet(precisions_pp), -torch.logdet(precisions_d)
    )
    # Similar to eq (14) in Appendix A.1 of [2].
    log_sqrt_det_ratio = 0.5 * (
        torch.logdet(covariances_post) + logdet_pp_rep - logdet_d_rep
    )

    # m^T P m terms of eq (26) in Appendix C of [1]:
    # m_0 (proposal prior) and m_k (density estimator), extended to LK terms.
    exponent_pp_rep, exponent_d_rep = all_pairs(
        utils.batched_mixture_vmv(precisions_pp, means_pp),
        utils.batched_mixture_vmv(precisions_d, means_d),
    )
    # m_k^\prime (posterior) in eq (26) in Appendix C of [1].
    exponent_post = utils.batched_mixture_vmv(precisions_post, means_post)

    # eq (26) in [1].
    exponent = -0.5 * (exponent_d_rep - exponent_pp_rep - exponent_post)

    return logit_factors + log_sqrt_det_ratio + exponent