def update(self, likelihood, y, post_mean, post_cov, hyp=None, site_params=None):
    """
    Update the Gaussian site parameters with conjugate-computation variational inference (CVI).

    The gradients of the expected log-likelihood w.r.t. the posterior mean (dE/dm) and
    variance (dE/dv) define the site in natural-parameter form:
        precision       = -2 dE/dv
        precision-mean  = dE/dm - 2 (dE/dv) m
    Without previous site parameters the site is set from these directly. Otherwise the
    previous site's natural parameters are blended towards them with weight self.damping.

    :param likelihood: likelihood object exposing variational_expectation and moment_match
    :param y: observed data
    :param post_mean: posterior mean
    :param post_cov: posterior covariance
    :param hyp: likelihood hyperparameters
    :param site_params: optional (site_mean, site_cov) from the previous iteration
    :return: (log_marg_lik, site_mean, site_cov), where log_marg_lik is the log partition
             function returned by likelihood.moment_match at the posterior with power 1.0
    """
    _, grad_m, grad_v = likelihood.variational_expectation(
        y, post_mean, post_cov, hyp, self.cubature_func)
    grad_m, grad_v = np.atleast_2d(grad_m), np.atleast_2d(grad_v)

    if site_params is None:
        # First pass: site covariance (-2 dE/dv)^-1, and a mean equal to
        # site_cov @ (dE/dm - 2 (dE/dv) m), written in the equivalent form m + site_cov @ dE/dm
        site_cov = -0.5 * inv_any(grad_v + 1e-10 * np.eye(grad_v.shape[0]))
        site_mean = post_mean + site_cov @ grad_m
        site_cov = ensure_positive_variance(site_cov)
    else:
        prev_mean, prev_cov = site_params
        jitter = 1e-10 * np.eye(prev_cov.shape[0])
        # Flip sign around ensure_positive_variance so the implied site precision -2 dE/dv is kept valid
        grad_v = -ensure_positive_variance(-grad_v)
        # Natural parameters of the previous site
        prev_prec = inv_any(prev_cov + jitter)
        prev_prec_mean = prev_prec @ prev_mean
        # Damped step towards the CVI natural parameters
        rho = self.damping
        prec_mean = (1 - rho) * prev_prec_mean + rho * (grad_m - 2 * grad_v @ post_mean)
        prec = (1 - rho) * prev_prec + rho * (-2 * grad_v)
        # Back to moment parameters
        site_cov = inv_any(prec + jitter)
        site_mean = site_cov @ prec_mean

    log_marg_lik, _, _ = likelihood.moment_match(
        y, post_mean, post_cov, hyp, 1.0, self.cubature_func)
    return log_marg_lik, site_mean, site_cov
def moment_match_unstable(self, y, cav_mean, cav_cov, hyp=None, power=1.0, cubature_func=None):
    """
    Moment matching that keeps the full Jacobian of dlogZ/dm, cross terms included.

    TODO: this attempt at a full (non-diagonal) site covariance makes the updates unstable.
    Replacing the Jacobian `hess` below by np.diag(np.diag(hess)) would discard the cross terms.

    The cavity is summarised by its (squeezed) mean vector and the diagonal of its
    covariance before being passed to self.log_expected_likelihood and self.dlZ_dm.
    The second derivative is the Jacobian of self.dlZ_dm w.r.t. its mean argument,
    obtained with jax.jacrev.

    :param y: observed data (yₙ)
    :param cav_mean: cavity mean (mₙ)
    :param cav_cov: cavity covariance (Cₙ); only its diagonal enters the integrals
    :param hyp: likelihood hyperparameter (not referenced by this method)
    :param power: EP power / fraction (a)
    :param cubature_func: returns (sigma points, weights) for a given dimension;
                          defaults to 20-point Gauss-Hermite in one dimension
    :return: lZ: log partition function
             site_mean: approximate likelihood (site) mean
             site_cov: approximate likelihood (site) covariance
    """
    x, w = gauss_hermite(1, 20) if cubature_func is None else cubature_func(1)
    mean_vec = np.squeeze(cav_mean)
    var_vec = np.squeeze(np.diag(cav_cov))

    lZ = self.log_expected_likelihood(y, x, w, mean_vec, var_vec, power)
    grad = self.dlZ_dm(y, x, w, mean_vec, var_vec, power)[:, None]
    # argnums=3 selects the mean argument of dlZ_dm(y, x, w, mean, var, power)
    hess = jacrev(self.dlZ_dm, argnums=3)(y, x, w, mean_vec, var_vec, power)
    inv_hess = inv_any(hess + 1e-10 * np.eye(hess.shape[0]))

    # approx. likelihood (site) mean and covariance (see Rasmussen & Williams p75)
    site_mean = cav_mean - inv_hess @ grad
    site_cov = -power * (cav_cov + inv_hess)
    return lZ, site_mean, site_cov
def moment_match(self, y, cav_mean, cav_cov, hyp=None, power=1.0, cubature_func=None):
    """
    Moment matching for a latent state that stacks K "subband" processes sₖ
    (first half of cav_mean) on top of K "modulator" processes gₖ (second half).

    Computes logZ = log ∫ pᵃ(y|s,g) 𝓝([s; g]|m, C) ds dg with EP power a. The
    expressions below model y | s, g as Gaussian with mean Σₖ link(gₖ) sₖ and
    variance hyp. The subbands are integrated out analytically and the modulators
    by cubature. Only the diagonal of the second derivative is kept.

    :param y: observed data (yₙ)
    :param cav_mean: cavity mean, 2K entries [subbands; modulators] (latent, pre-link space)
    :param cav_cov: cavity covariance, 2K x 2K
    :param hyp: observation noise variance
    :param power: EP power / fraction (a)
    :param cubature_func: returns (sigma points, weights) for a given dimension;
                          defaults to 20-point Gauss-Hermite
    :return: lZ: log partition function
             site_mean: approximate likelihood (site) mean
             site_cov: approximate likelihood (site) covariance
    """
    num_components = cav_mean.shape[0] // 2
    if cubature_func is None:
        x, w = gauss_hermite(num_components, 20)  # Gauss-Hermite sigma points and weights
    else:
        x, w = cubature_func(num_components)

    subband_mean = cav_mean[:num_components]
    # The modulator cavity is Gaussian in latent (pre-link) space. The sigma points must
    # therefore be centred on the latent mean, with the link applied to the sigma points
    # only. Applying link_fn here as well would apply the link twice, and it would make the
    # (f - m) derivative terms below relative to link(m) instead of the cavity mean.
    modulator_mean = cav_mean[num_components:]
    subband_cov = cav_cov[:num_components, :num_components]
    modulator_cov = cav_cov[num_components:, num_components:]
    modulator_prec = np.diag(modulator_cov)[..., None] ** -1  # diagonal precision only

    sigma_points = cholesky(modulator_cov) @ x + modulator_mean
    amplitudes = self.link_fn(sigma_points)

    # For a Gaussian likelihood: pᵃ = a^{-1/2} (2π hyp)^{(1-a)/2} 𝓝(y|·, hyp/a)
    const = power ** -0.5 * (2 * pi * hyp) ** (0.5 - 0.5 * power)
    # Subbands integrated out analytically at each modulator sigma point
    mu = (amplitudes.T @ subband_mean)[:, 0]
    var = hyp / power + (amplitudes.T ** 2 @ np.diag(subband_cov)[..., None])[:, 0]
    normpdf = const * (2 * pi * var) ** -0.5 * np.exp(-0.5 * (y - mu) ** 2 / var)
    Z = np.sum(w * normpdf)
    Zinv = 1. / (Z + 1e-8)
    lZ = np.log(Z + 1e-8)

    # First derivatives of Z w.r.t. the subband and modulator cavity means
    dZ1 = np.sum(w * amplitudes * (y - mu) / var * normpdf, axis=-1)
    dZ2 = np.sum(w * (sigma_points - modulator_mean) * modulator_prec * normpdf, axis=-1)
    dlZ = Zinv * np.block([dZ1, dZ2])

    # Diagonal second derivatives of Z
    d2Z1 = np.sum(w * amplitudes ** 2 * (
        ((y - mu) / var) ** 2
        - var ** -1
    ) * normpdf, axis=-1)
    d2Z2 = np.sum(w * (
        ((sigma_points - modulator_mean) * modulator_prec) ** 2
        - modulator_prec
    ) * normpdf, axis=-1)
    d2lZ = np.diag(-dlZ ** 2 + Zinv * np.block([d2Z1, d2Z2]))
    id2lZ = inv_any(d2lZ + 1e-10 * np.eye(d2lZ.shape[0]))
    site_mean = cav_mean - id2lZ @ dlZ[..., None]  # approx. likelihood (site) mean (see Rasmussen & Williams p75)
    site_cov = -power * (cav_cov + id2lZ)  # approx. likelihood (site) variance
    return lZ, site_mean, site_cov
def moment_match(self, y, cav_mean, cav_cov, hyp=None, power=1.0, cubature_func=None):
    """
    Moment matching for a two-dimensional latent [f₀, f₁]. The expressions below
    model y as Gaussian with mean f₀ and variance link(f₁)².

    f₀ is integrated out analytically and f₁ by one-dimensional cubature. Cross
    terms between f₀ and f₁ are dropped from the second derivative.

    :param y: observed data (yₙ)
    :param cav_mean: cavity mean [m₀; m₁] (indexed as a 2 x 1 column)
    :param cav_cov: cavity covariance (2 x 2); only the diagonal is used
    :param hyp: likelihood hyperparameter (not referenced by this method)
    :param power: EP power / fraction (a)
    :param cubature_func: returns (sigma points, weights) for a given dimension;
                          defaults to 20-point Gauss-Hermite
    :return: lZ: log partition function
             site_mean: approximate likelihood (site) mean
             site_cov: approximate likelihood (site) covariance
    """
    x, w = gauss_hermite(1, 20) if cubature_func is None else cubature_func(1)
    v0 = cav_cov[0, 0]
    v1 = cav_cov[1, 1]

    # fsigᵢ = xᵢ√v₁ + m₁: sigma points for f₁ under its cavity marginal
    sigma_points = np.sqrt(v1) * x + cav_mean[1]
    link_sq = self.link_fn(sigma_points) ** 2.
    # Powered noise variance link²/a plus the cavity variance of f₀ (integrated out)
    obs_var = link_sq / power + v0
    # pᵃ of a Gaussian = a^{-1/2} (2π link²)^{(1-a)/2} 𝓝(y|·, link²/a)
    const = power ** -0.5 * (2 * pi * link_sq) ** (0.5 - 0.5 * power)
    resid = y - cav_mean[0, 0]
    normpdf = const * (2 * pi * obs_var) ** -0.5 * np.exp(-0.5 * resid ** 2 / obs_var)

    Z = np.sum(w * normpdf)
    Z_floor = np.maximum(Z, 1e-8)
    Zinv = 1. / Z_floor
    lZ = np.log(Z_floor)

    # Derivatives w.r.t. m₀
    score0 = resid / obs_var * normpdf
    dlZ1 = Zinv * np.sum(w * score0)
    curv0 = (-obs_var ** -1 + (resid / obs_var) ** 2) * normpdf
    d2lZ1 = -dlZ1 ** 2 + Zinv * np.sum(w * curv0)

    # Derivatives w.r.t. m₁
    offset1 = sigma_points - cav_mean[1, 0]
    score1 = offset1 / v1 * normpdf
    dlZ2 = Zinv * np.sum(w * score1)
    curv1 = (-v1 ** -1 + (offset1 / v1) ** 2) * normpdf
    d2lZ2 = -dlZ2 ** 2 + Zinv * np.sum(w * curv1)

    dlZ = np.block([[dlZ1], [dlZ2]])
    d2lZ = np.block([[d2lZ1, 0], [0., d2lZ2]])
    id2lZ = inv_any(d2lZ + 1e-10 * np.eye(d2lZ.shape[0]))
    # approx. likelihood (site) mean and variance (see Rasmussen & Williams p75)
    site_mean = cav_mean - id2lZ @ dlZ
    site_cov = -power * (cav_cov + id2lZ)
    return lZ, site_mean, site_cov
def moment_match_cubature(self, y, cav_mean, cav_cov, hyp=None, power=1.0, cubature_func=None):
    """
    TODO: N.B. THIS VERSION IS SUPERSEDED BY THE FUNCTION BELOW. HOWEVER THIS ONE MAY BE MORE STABLE.

    Moment matching via cubature (scalar-latent version).

    Computes the log partition function
        logZₙ = log ∫ pᵃ(yₙ|fₙ) 𝓝(fₙ|mₙ,vₙ) dfₙ
    with EP power a, and its first and second derivatives w.r.t. the cavity mean.
    These are then converted into site parameters.

    NOTE(review): the second-derivative term divides element-wise by cav_cov, so it is
    only correct for a scalar (one-dimensional) latent. cho_factor is called with its
    default (upper) factor, which for a scalar is just √vₙ.

    :param y: observed data (yₙ)
    :param cav_mean: cavity mean (mₙ)
    :param cav_cov: cavity variance (vₙ)
    :param hyp: likelihood hyperparameter
    :param power: EP power / fraction (a)
    :param cubature_func: returns (sigma points, weights) for a given dimension;
                          defaults to 20-point Gauss-Hermite
    :return: lZ: the log partition function, logZₙ
             site_mean: approximate likelihood (site) mean
             site_cov: approximate likelihood (site) variance
    """
    dim = cav_mean.shape[0]
    if cubature_func is not None:
        x, w = cubature_func(dim)
    else:
        x, w = gauss_hermite(dim, 20)  # Gauss-Hermite sigma points and weights

    chol = cho_factor(cav_cov)
    # fsigᵢ = chol(vₙ) xᵢ + mₙ: place sigma points according to the cavity
    sigma_points = chol[0] @ np.atleast_2d(x) + cav_mean
    offsets = sigma_points - cav_mean
    # wᵢ pᵃ(yₙ|fsigᵢ)
    weighted_lik = w * self.evaluate_likelihood(y, sigma_points, hyp) ** power

    # Zₙ ≈ Σᵢ wᵢ pᵃ(yₙ|fsigᵢ)
    Z = np.sum(weighted_lik, axis=-1)
    lZ = np.log(Z)
    Zinv = 1.0 / Z

    # dlogZₙ/dmₙ = (1/Zₙ) Σᵢ wᵢ vₙ⁻¹(fsigᵢ - mₙ) pᵃ(yₙ|fsigᵢ)
    dlZ = Zinv * np.sum(cho_solve(chol, offsets) * weighted_lik, axis=-1)

    # d²Zₙ/dmₙ² ≈ Σᵢ wᵢ [(fsigᵢ-mₙ)² vₙ⁻² - vₙ⁻¹] pᵃ(yₙ|fsigᵢ)
    d2Z = np.sum((offsets ** 2 / cav_cov ** 2 - 1.0 / cav_cov) * weighted_lik)
    # d²logZₙ/dmₙ² = d²Zₙ/dmₙ² / Zₙ - (dlogZₙ/dmₙ)²
    d2lZ = -dlZ @ dlZ.T + Zinv * d2Z

    inv_d2lZ = inv_any(d2lZ)
    site_mean = cav_mean - inv_d2lZ @ dlZ  # approx. likelihood (site) mean (see Rasmussen & Williams p75)
    site_cov = -power * (cav_cov + inv_d2lZ)  # approx. likelihood (site) variance
    return lZ, site_mean, site_cov
def moment_match_cubature(self, y, cav_mean, cav_cov, hyp=None, power=1.0, cubature_func=None):
    """
    TODO: N.B. THIS VERSION ALLOWS MULTI-DIMENSIONAL MOMENT MATCHING, BUT CAN BE UNSTABLE

    Perform moment matching via cubature.

    Moment matching involves computing the log partition function, logZₙ, and its
    derivatives w.r.t. the cavity mean
        logZₙ = log ∫ pᵃ(yₙ|fₙ) 𝓝(fₙ|mₙ,Cₙ) dfₙ
    with EP power a. The derivatives are then converted into site parameters.

    :param y: observed data (yₙ)
    :param cav_mean: cavity mean (mₙ); presumably a D x 1 column — TODO confirm against callers
    :param cav_cov: cavity covariance (Cₙ), D x D
    :param hyp: likelihood hyperparameter
    :param power: EP power / fraction (a) [scalar]
    :param cubature_func: the function to compute sigma points and weights to use during cubature;
                          defaults to 20-point Gauss-Hermite
    :return:
        lZ: the log partition function, logZₙ
        site_mean: approximate likelihood (site) mean
        site_cov: approximate likelihood (site) covariance
    """
    dim = cav_mean.shape[0]
    if cubature_func is None:
        x, w = gauss_hermite(dim, 20)  # Gauss-Hermite sigma points and weights
    else:
        x, w = cubature_func(dim)

    # Sigma points must be fsigᵢ = L xᵢ + mₙ with L Lᵀ = Cₙ. cho_factor defaults to the
    # upper factor U (Uᵀ U = Cₙ), and U x has covariance U Uᵀ ≠ Cₙ whenever D > 1 and
    # Cₙ is not diagonal, so request the lower factor explicitly. np.tril discards the
    # unused triangle (scipy documents its contents as arbitrary).
    cav_cho, _ = cho_factor(cav_cov, lower=True)
    sigma_points = np.tril(cav_cho) @ np.atleast_2d(x) + cav_mean

    # pre-compute wᵢ pᵃ(yₙ|fsigᵢ)
    weighted_likelihood_eval = w * self.evaluate_likelihood(y, sigma_points, hyp) ** power

    # Compute partition function via cubature:
    # Zₙ = ∫ pᵃ(yₙ|fₙ) 𝓝(fₙ|mₙ,Cₙ) dfₙ ≈ ∑ᵢ wᵢ pᵃ(yₙ|fsigᵢ)
    # Z is clamped away from zero before the log and the reciprocal.
    Z = np.sum(weighted_likelihood_eval, axis=-1)
    lZ = np.log(np.maximum(Z, 1e-8))
    Zinv = 1.0 / np.maximum(Z, 1e-8)

    # Compute derivative of partition function via cubature:
    # dZₙ/dmₙ = ∫ Cₙ⁻¹(fₙ-mₙ) pᵃ(yₙ|fₙ) 𝓝(fₙ|mₙ,Cₙ) dfₙ ≈ ∑ᵢ wᵢ Cₙ⁻¹(fsigᵢ-mₙ) pᵃ(yₙ|fsigᵢ)
    # vmap over axis 1 gives one helper call per sigma point (helper defined elsewhere).
    d1 = vmap(gaussian_first_derivative_wrt_mean, (1, None, None, 1))(
        sigma_points[..., None], cav_mean, cav_cov, weighted_likelihood_eval)
    dZ = np.sum(d1, axis=0)
    # dlogZₙ/dmₙ = (dZₙ/dmₙ) / Zₙ
    dlZ = Zinv * dZ

    # Compute second derivative of partition function via cubature:
    # d²Zₙ/dmₙ² ≈ ∑ᵢ wᵢ [Cₙ⁻¹(fsigᵢ-mₙ)(fsigᵢ-mₙ)ᵀCₙ⁻¹ - Cₙ⁻¹] pᵃ(yₙ|fsigᵢ)
    d2 = vmap(gaussian_second_derivative_wrt_mean, (1, None, None, 1))(
        sigma_points[..., None], cav_mean, cav_cov, weighted_likelihood_eval)
    d2Z = np.sum(d2, axis=0)
    # d²logZₙ/dmₙ² = d²Zₙ/dmₙ² / Zₙ - (dlogZₙ/dmₙ)(dlogZₙ/dmₙ)ᵀ
    d2lZ = -dlZ @ dlZ.T + Zinv * d2Z

    id2lZ = inv_any(d2lZ + 1e-10 * np.eye(d2lZ.shape[0]))
    site_mean = cav_mean - id2lZ @ dlZ  # approx. likelihood (site) mean (see Rasmussen & Williams p75)
    site_cov = -power * (cav_cov + id2lZ)  # approx. likelihood (site) variance
    return lZ, site_mean, site_cov