def _means_posterior(
    self,
    covariances_p: Tensor,
    means_pp: Tensor,
    precisions_pp: Tensor,
    means_d: Tensor,
    precisions_d: Tensor,
) -> Tensor:
    r"""
    Return the means of the MoG posterior.

    Each posterior component pairs one proposal-prior component (mean $m_0$,
    precision $P_0$) with one density-estimator component (mean $m_k$,
    precision $P_k$). Its mean is

    $m_k^\prime = S_k^\prime ( P_k m_k - P_0 m_0 + P_\pi m_\pi )$

    where $S_k^\prime$ is the matching posterior covariance. The prior term
    $P_\pi m_\pi$ (``self.prec_m_prod_prior``) is only added when the
    (possibly z-scored) prior is a ``MultivariateNormal``; otherwise it is
    dropped. The prior-free form is the one given as eq (24) in Appendix C
    of [1].

    Args:
        covariances_p: Covariance matrices of the MoG posterior, one per
            (proposal-prior, density-estimator) component pair, ordered like
            the repeated products built below.
        means_pp: Means of the proposal prior.
        precisions_pp: Precision matrices of the proposal prior.
        means_d: Means of the density estimator.
        precisions_d: Precision matrices of the density estimator.

    Returns:
        Means of the MoG posterior, one per component pair.
    """
    # Dim 1 indexes the mixture components.
    num_comps_pp = precisions_pp.shape[1]
    num_comps_d = precisions_d.shape[1]

    # Precision-weighted means: P_0 m_0 (proposal prior) and P_k m_k (density estimator).
    prec_m_prod_pp = utils.batched_mixture_mv(precisions_pp, means_pp)
    prec_m_prod_d = utils.batched_mixture_mv(precisions_d, means_d)

    # Enumerate all component pairs along dim 1. Entry i * num_comps_d + j pairs
    # proposal-prior component i with density-estimator component j. This must
    # match the pairing used to build `covariances_p` (the "same trick as for the
    # precisions" used elsewhere in this class).
    prec_m_prod_pp_rep = prec_m_prod_pp.repeat_interleave(num_comps_d, dim=1)
    prec_m_prod_d_rep = prec_m_prod_d.repeat(1, num_comps_pp, 1)

    # Posterior precision-weighted mean: P_k m_k - P_0 m_0 (+ prior term).
    prec_m_prod_post = prec_m_prod_d_rep - prec_m_prod_pp_rep
    if isinstance(self._maybe_z_scored_prior, MultivariateNormal):
        # Out-of-place add: an in-place `+=` requires the prior term to broadcast
        # into the existing shape exactly; a plain add has no such restriction.
        prec_m_prod_post = prec_m_prod_post + self.prec_m_prod_prior

    # m_k' = S_k' * (posterior precision-weighted mean).
    return utils.batched_mixture_mv(covariances_p, prec_m_prod_post)
def _means_proposal_posterior(
    self,
    covariances_pp: Tensor,
    means_p: Tensor,
    precisions_p: Tensor,
    means_d: Tensor,
    precisions_d: Tensor,
):
    r"""
    Return the means of the proposal posterior.

    Each proposal-posterior component pairs proposal component $i$ with
    density-estimator component $j$. Its mean is

    $m_{ij} = C_{ij} (P_i m_i + P_j m_j - P_0 m_0)$,

    where $C_{ij}$ is the matching covariance. The prior term $P_0 m_0$
    (``self.prec_m_prod_prior``) is only subtracted when the (possibly
    z-scored) prior is a ``MultivariateNormal``.

    Args:
        covariances_pp: Covariance matrices of the proposal posterior, one per
            (proposal, density-estimator) component pair.
        means_p: Means of the proposal distribution.
        precisions_p: Precision matrices of the proposal distribution.
        means_d: Means of the density estimator.
        precisions_d: Precision matrices of the density estimator.

    Returns:
        Means of the proposal posterior, one per component pair (L*K terms).
    """
    # Dim 1 indexes the mixture components.
    n_proposal = precisions_p.shape[1]
    n_estimator = precisions_d.shape[1]

    # Precision-weighted means: P_i m_i (proposal) and P_j m_j (density estimator).
    weighted_means_p = batched_mixture_mv(precisions_p, means_p)
    weighted_means_d = batched_mixture_mv(precisions_d, means_d)

    # Lay out every (i, j) pair along dim 1 at index i * n_estimator + j.
    # Each proposal term is repeated n_estimator times in a row, and the
    # density-estimator terms are tiled once per proposal component.
    pairwise_p = weighted_means_p.repeat_interleave(n_estimator, dim=1)
    pairwise_d = weighted_means_d.repeat(1, n_proposal, 1)

    combined = pairwise_p + pairwise_d
    if isinstance(self._maybe_z_scored_prior, MultivariateNormal):
        combined -= self.prec_m_prod_prior

    return batched_mixture_mv(covariances_pp, combined)