Code example #1
0
 def update(self,
            likelihood,
            y,
            post_mean,
            post_cov,
            hyp=None,
            site_params=None):
     """
     Compute a CVI (conjugate variational inference) update of the site parameters.

     :param likelihood: likelihood object providing variational_expectation() and moment_match()
     :param y: observed data
     :param post_mean: posterior mean
     :param post_cov: posterior covariance
     :param hyp: likelihood hyperparameters (optional)
     :param site_params: (site_mean, site_cov) from the previous iteration, or None on the first pass
     :return: (log_marg_lik, site_mean, site_cov)
     """
     if site_params is None:
         # First pass: initialise the sites directly from the variational expectations.
         _, dE_dm, dE_dv = likelihood.variational_expectation(
             y, post_mean, post_cov, hyp, self.cubature_func)
         dE_dm, dE_dv = np.atleast_2d(dE_dm), np.atleast_2d(dE_dv)
         # small jitter stabilises the inversion
         site_cov = -0.5 * inv_any(dE_dv + 1e-10 * np.eye(dE_dv.shape[0]))
         site_mean = post_mean + site_cov @ dE_dm
         site_cov = ensure_positive_variance(site_cov)
     else:
         site_mean, site_cov = site_params
         # The marginal-likelihood value returned here was previously bound to
         # log_marg_lik but immediately overwritten by moment_match() below, so
         # it is discarded explicitly.
         _, dE_dm, dE_dv = likelihood.variational_expectation(
             y, post_mean, post_cov, hyp, self.cubature_func)
         dE_dm, dE_dv = np.atleast_2d(dE_dm), np.atleast_2d(dE_dv)
         # force dE_dv to be negative, i.e. a valid (negative) precision contribution
         dE_dv = -ensure_positive_variance(-dE_dv)
         # convert the current site to natural parameters (precision, precision-adjusted mean)
         lambda_t_2 = inv_any(site_cov + 1e-10 * np.eye(site_cov.shape[0]))
         lambda_t_1 = lambda_t_2 @ site_mean
         # damped natural-parameter update
         lambda_t_1 = (1 - self.damping) * lambda_t_1 + self.damping * (
             dE_dm - 2 * dE_dv @ post_mean)
         lambda_t_2 = (1 - self.damping) * lambda_t_2 + self.damping * (
             -2 * dE_dv)
         # convert back to moment (mean / covariance) form
         site_cov = inv_any(lambda_t_2 + 1e-10 * np.eye(site_cov.shape[0]))
         site_mean = site_cov @ lambda_t_1
     # the marginal likelihood is always evaluated via moment matching with power = 1.0
     log_marg_lik, _, _ = likelihood.moment_match(y, post_mean, post_cov,
                                                  hyp, 1.0,
                                                  self.cubature_func)
     return log_marg_lik, site_mean, site_cov
Code example #2
0
File: likelihoods.py  Project: johannah/kalman-jax
 def moment_match_unstable(self,
                           y,
                           cav_mean,
                           cav_cov,
                           hyp=None,
                           power=1.0,
                           cubature_func=None):
     """
     TODO: Attempt to compute full site covariance, including cross terms. However, this makes things unstable.
     """
     if cubature_func is None:
         # default: 20-point Gauss-Hermite sigma points and weights
         x, w = gauss_hermite(1, 20)
     else:
         x, w = cubature_func(1)
     mean_flat = np.squeeze(cav_mean)
     var_flat = np.squeeze(np.diag(cav_cov))
     # log partition function and its gradient w.r.t. the cavity mean
     lZ = self.log_expected_likelihood(y, x, w, mean_flat, var_flat, power)
     dlZ = self.dlZ_dm(y, x, w, mean_flat, var_flat, power)[:, None]
     # full Hessian of logZ w.r.t. the mean, including cross terms
     d2lZ = jacrev(self.dlZ_dm, argnums=3)(y, x, w, mean_flat, var_flat,
                                           power)
     # d2lZ = np.diag(np.diag(d2lZ))  # discard cross terms
     id2lZ = inv_any(d2lZ + 1e-10 * np.eye(d2lZ.shape[0]))
     # approx. likelihood (site) mean (see Rasmussen & Williams p75)
     site_mean = cav_mean - id2lZ @ dlZ
     # approx. likelihood (site) variance
     site_cov = -power * (cav_cov + id2lZ)
     return lZ, site_mean, site_cov
Code example #3
0
File: likelihoods.py  Project: johannah/kalman-jax
    def moment_match(self,
                     y,
                     cav_mean,
                     cav_cov,
                     hyp=None,
                     power=1.0,
                     cubature_func=None):
        """
        Power-EP moment matching via cubature for a stacked latent state of
        `num_components` subband channels followed by `num_components`
        modulator channels (the modulators pass through `self.link_fn`).

        :param y: observed datum
        :param cav_mean: cavity mean; assumed shape (2*num_components, 1) — TODO confirm
        :param cav_cov: cavity covariance, (2*num_components, 2*num_components)
        :param hyp: likelihood hyperparameter (appears as an observation-noise
            variance in `const` and `var` below — TODO confirm)
        :param power: EP power / fraction
        :param cubature_func: optional function returning sigma points and weights
        :return: (lZ, site_mean, site_cov)
        """
        # cavity stacks subband and modulator components, half the state each
        num_components = int(cav_mean.shape[0] / 2)
        if cubature_func is None:
            x, w = gauss_hermite(num_components,
                                 20)  # Gauss-Hermite sigma points and weights
        else:
            x, w = cubature_func(num_components)

        # split the cavity: subbands (first half) and modulators (second half);
        # the modulator mean is mapped through the link function
        subband_mean, modulator_mean = cav_mean[:num_components], self.link_fn(
            cav_mean[num_components:])
        subband_cov, modulator_cov = cav_cov[:num_components, :
                                             num_components], cav_cov[
                                                 num_components:,
                                                 num_components:]
        # scale sigma points by the modulator cavity: fsigᵢ = chol(C) xᵢ + mₙ
        sigma_points = cholesky(modulator_cov) @ x + modulator_mean
        # normalising constant of the fractional (power-a) Gaussian likelihood
        const = power**-0.5 * (2 * pi * hyp)**(0.5 - 0.5 * power)
        # per-sigma-point predictive mean and variance of y
        mu = (self.link_fn(sigma_points).T @ subband_mean)[:, 0]
        var = hyp / power + (self.link_fn(sigma_points).T**2
                             @ np.diag(subband_cov)[..., None])[:, 0]
        normpdf = const * (2 * pi * var)**-0.5 * np.exp(-0.5 *
                                                        (y - mu)**2 / var)
        # partition function Zₙ ≈ ∑ᵢ wᵢ N(y | muᵢ, varᵢ); 1e-8 guards log(0) / div-by-0
        Z = np.sum(w * normpdf)
        Zinv = 1. / (Z + 1e-8)
        lZ = np.log(Z + 1e-8)

        # first derivatives of Z w.r.t. the subband (dZ1) and modulator (dZ2) means
        dZ1 = np.sum(w * self.link_fn(sigma_points) * (y - mu) / var * normpdf,
                     axis=-1)
        dZ2 = np.sum(w * (sigma_points - modulator_mean) *
                     np.diag(modulator_cov)[..., None]**-1 * normpdf,
                     axis=-1)
        dlZ = Zinv * np.block([dZ1, dZ2])

        # second derivatives, per component (cross terms are dropped when d2lZ
        # is diagonalised below)
        d2Z1 = np.sum(w * self.link_fn(sigma_points)**2 *
                      (((y - mu) / var)**2 - var**-1) * normpdf,
                      axis=-1)
        d2Z2 = np.sum(w * (((sigma_points - modulator_mean) *
                            np.diag(modulator_cov)[..., None]**-1)**2 -
                           np.diag(modulator_cov)[..., None]**-1) * normpdf,
                      axis=-1)
        # d²logZ/dm² = d²Z/dm² / Z − (dlogZ/dm)², kept diagonal
        d2lZ = np.diag(-dlZ**2 + Zinv * np.block([d2Z1, d2Z2]))

        # jitter stabilises the inversion
        id2lZ = inv_any(d2lZ + 1e-10 * np.eye(d2lZ.shape[0]))
        site_mean = cav_mean - id2lZ @ dlZ[
            ...,
            None]  # approx. likelihood (site) mean (see Rasmussen & Williams p75)
        site_cov = -power * (cav_cov + id2lZ
                             )  # approx. likelihood (site) variance
        return lZ, site_mean, site_cov
Code example #4
0
File: likelihoods.py  Project: johannah/kalman-jax
    def moment_match(self,
                     y,
                     cav_mean,
                     cav_cov,
                     hyp=None,
                     power=1.0,
                     cubature_func=None):
        """
        Power-EP moment matching via 1-D cubature over the second (modulator)
        cavity dimension; the first dimension is handled in closed form.
        Returns the log partition function and the approximate site moments.
        """
        if cubature_func is None:
            # default: 20-point Gauss-Hermite sigma points and weights
            x, w = gauss_hermite(1, 20)
        else:
            x, w = cubature_func(1)
        # fsigᵢ = xᵢ√cₙ + mₙ: scale locations according to the cavity
        sigma_points = np.sqrt(cav_cov[1, 1]) * x + cav_mean[1]

        # likelihood variance per sigma point (link output squared, tempered by power)
        link_sq = self.link_fn(sigma_points)**2.
        f2 = link_sq / power
        obs_var = f2 + cav_cov[0, 0]
        const = power**-0.5 * (2 * pi * link_sq)**(0.5 - 0.5 * power)
        residual = y - cav_mean[0, 0]
        normpdf = const * (2 * pi * obs_var)**-0.5 * np.exp(
            -0.5 * residual**2 / obs_var)
        # partition function, clamped to avoid log(0) / division by zero
        Z = np.sum(w * normpdf)
        Z_safe = np.maximum(Z, 1e-8)
        Zinv = 1. / Z_safe
        lZ = np.log(Z_safe)

        # first derivatives of logZ w.r.t. each cavity-mean component
        dlZ1 = Zinv * np.sum(w * (residual / obs_var * normpdf))
        centred = sigma_points - cav_mean[1, 0]
        dlZ2 = Zinv * np.sum(w * (centred / cav_cov[1, 1] * normpdf))

        # diagonal second derivatives of logZ
        d2lZ1 = -dlZ1**2 + Zinv * np.sum(
            w * ((-(f2 + cav_cov[0, 0])**-1 +
                  (residual / obs_var)**2) * normpdf))
        d2lZ2 = -dlZ2**2 + Zinv * np.sum(
            w * ((-cav_cov[1, 1]**-1 +
                  (centred / cav_cov[1, 1])**2) * normpdf))

        dlZ = np.block([[dlZ1], [dlZ2]])
        d2lZ = np.block([[d2lZ1, 0], [0., d2lZ2]])
        # jitter stabilises the inversion
        id2lZ = inv_any(d2lZ + 1e-10 * np.eye(d2lZ.shape[0]))
        # approx. likelihood (site) mean (see Rasmussen & Williams p75)
        site_mean = cav_mean - id2lZ @ dlZ
        # approx. likelihood (site) variance
        site_cov = -power * (cav_cov + id2lZ)
        return lZ, site_mean, site_cov
Code example #5
0
File: likelihoods.py  Project: johannah/kalman-jax
    def moment_match_cubature(self,
                              y,
                              cav_mean,
                              cav_cov,
                              hyp=None,
                              power=1.0,
                              cubature_func=None):
        """
        TODO: N.B. THIS VERSION IS SUPERCEDED BY THE FUNCTION BELOW. HOWEVER THIS ONE MAY BE MORE STABLE.
        Perform moment matching via cubature.
        Moment matching invloves computing the log partition function, logZₙ, and its derivatives w.r.t. the cavity mean
            logZₙ = log ∫ pᵃ(yₙ|fₙ) 𝓝(fₙ|mₙ,vₙ) dfₙ
        with EP power a.
        :param y: observed data (yₙ) [scalar]
        :param cav_mean: cavity mean (mₙ) [scalar]
        :param cav_cov: cavity covariance (cₙ) [scalar]
        :param hyp: likelihood hyperparameter [scalar]
        :param power: EP power / fraction (a) [scalar]
        :param cubature_func: the function to compute sigma points and weights to use during cubature
        :return:
            lZ: the log partition function, logZₙ  [scalar]
            dlZ: first derivative of logZₙ w.r.t. mₙ (if derivatives=True)  [scalar]
            d2lZ: second derivative of logZₙ w.r.t. mₙ (if derivatives=True)  [scalar]
        """
        if cubature_func is None:
            x, w = gauss_hermite(cav_mean.shape[0],
                                 20)  # Gauss-Hermite sigma points and weights
        else:
            x, w = cubature_func(cav_mean.shape[0])
        cav_cho, low = cho_factor(cav_cov)
        # fsigᵢ=xᵢ√cₙ + mₙ: scale locations according to cavity dist.
        sigma_points = cav_cho @ np.atleast_2d(x) + cav_mean
        # pre-compute wᵢ pᵃ(yₙ|xᵢ√(2vₙ) + mₙ)
        weighted_likelihood_eval = w * self.evaluate_likelihood(
            y, sigma_points, hyp)**power

        # a different approach, based on the log-likelihood, which can be more stable:
        # ll = self.evaluate_log_likelihood(y, sigma_points)
        # lmax = np.max(ll)
        # weighted_likelihood_eval = np.exp(lmax * power) * w * np.exp(power * (ll - lmax))

        # Compute partition function via cubature:
        # Zₙ = ∫ pᵃ(yₙ|fₙ) 𝓝(fₙ|mₙ,vₙ) dfₙ
        #    ≈ ∑ᵢ wᵢ pᵃ(yₙ|fsigᵢ)
        Z = np.sum(weighted_likelihood_eval, axis=-1)
        # NOTE(review): unlike the sibling version, Z is not clamped here, so
        # lZ can be -inf and Zinv can overflow when the likelihood underflows.
        lZ = np.log(Z)
        Zinv = 1.0 / Z

        # Compute derivative of partition function via cubature:
        # dZₙ/dmₙ = ∫ (fₙ-mₙ) vₙ⁻¹ pᵃ(yₙ|fₙ) 𝓝(fₙ|mₙ,vₙ) dfₙ
        #         ≈ ∑ᵢ wᵢ (fₙ-mₙ) vₙ⁻¹ pᵃ(yₙ|fsigᵢ)
        # solve cₙ⁻¹(fsigᵢ - mₙ) via the Cholesky factor rather than an explicit inverse
        covinv_f_m = cho_solve((cav_cho, low), sigma_points - cav_mean)
        dZ = np.sum(
            # (sigma_points - cav_mean) / cav_cov
            covinv_f_m * weighted_likelihood_eval,
            axis=-1)
        # dlogZₙ/dmₙ = (dZₙ/dmₙ) / Zₙ
        dlZ = Zinv * dZ

        # Compute second derivative of partition function via cubature:
        # d²Zₙ/dmₙ² = ∫ [(fₙ-mₙ)² vₙ⁻² - vₙ⁻¹] pᵃ(yₙ|fₙ) 𝓝(fₙ|mₙ,vₙ) dfₙ
        #           ≈ ∑ᵢ wᵢ [(fₙ-mₙ)² vₙ⁻² - vₙ⁻¹] pᵃ(yₙ|fsigᵢ)
        d2Z = np.sum(
            ((sigma_points - cav_mean)**2 / cav_cov**2 - 1.0 / cav_cov) *
            weighted_likelihood_eval)

        # d²logZₙ/dmₙ² = d[(dZₙ/dmₙ) / Zₙ]/dmₙ
        #              = (d²Zₙ/dmₙ² * Zₙ - (dZₙ/dmₙ)²) / Zₙ²
        #              = d²Zₙ/dmₙ² / Zₙ - (dlogZₙ/dmₙ)²
        # NOTE(review): if dlZ is 1-D, dlZ @ dlZ.T contracts to a scalar rather
        # than an outer product — confirm intended for multi-dim cavities.
        d2lZ = -dlZ @ dlZ.T + Zinv * d2Z
        # NOTE(review): no jitter is added before inv_any here, unlike the
        # sibling implementations which add 1e-10 * I.
        site_mean = cav_mean - inv_any(
            d2lZ
        ) @ dlZ  # approx. likelihood (site) mean (see Rasmussen & Williams p75)
        site_cov = -power * (cav_cov + inv_any(d2lZ)
                             )  # approx. likelihood (site) variance
        return lZ, site_mean, site_cov
Code example #6
0
File: likelihoods.py  Project: johannah/kalman-jax
    def moment_match_cubature(self,
                              y,
                              cav_mean,
                              cav_cov,
                              hyp=None,
                              power=1.0,
                              cubature_func=None):
        """
        TODO: N.B. THIS VERSION ALLOWS MULTI-DIMENSIONAL MOMENT MATCHING, BUT CAN BE UNSTABLE
        Cubature-based moment matching for power EP.

        Computes the log partition function
            logZₙ = log ∫ pᵃ(yₙ|fₙ) 𝓝(fₙ|mₙ,vₙ) dfₙ
        with EP power a, along with the approximate site mean and covariance
        derived from its first and second derivatives w.r.t. the cavity mean.

        :param y: observed data (yₙ)
        :param cav_mean: cavity mean (mₙ)
        :param cav_cov: cavity covariance (cₙ)
        :param hyp: likelihood hyperparameter
        :param power: EP power / fraction (a)
        :param cubature_func: the function to compute sigma points and weights to use during cubature
        :return: lZ, site_mean, site_cov
        """
        dim = cav_mean.shape[0]
        if cubature_func is None:
            # default: 20-point Gauss-Hermite sigma points and weights
            x, w = gauss_hermite(dim, 20)
        else:
            x, w = cubature_func(dim)
        cav_cho, low = cho_factor(cav_cov)
        # fsigᵢ = chol(cₙ) xᵢ + mₙ: scale locations according to the cavity dist.
        sigma_points = cav_cho @ np.atleast_2d(x) + cav_mean
        # pre-compute wᵢ pᵃ(yₙ|fsigᵢ)
        weighted_likelihood_eval = w * self.evaluate_likelihood(
            y, sigma_points, hyp)**power

        # partition function: Zₙ ≈ ∑ᵢ wᵢ pᵃ(yₙ|fsigᵢ), clamped to avoid log(0)
        Z = np.sum(weighted_likelihood_eval, axis=-1)
        Z_safe = np.maximum(Z, 1e-8)
        lZ = np.log(Z_safe)
        Zinv = 1.0 / Z_safe

        # first derivative: dZₙ/dmₙ ≈ ∑ᵢ wᵢ (fsigᵢ-mₙ) cₙ⁻¹ pᵃ(yₙ|fsigᵢ)
        first_terms = vmap(gaussian_first_derivative_wrt_mean,
                           (1, None, None, 1))(sigma_points[..., None],
                                               cav_mean, cav_cov,
                                               weighted_likelihood_eval)
        # dlogZₙ/dmₙ = (dZₙ/dmₙ) / Zₙ
        dlZ = Zinv * np.sum(first_terms, axis=0)

        # second derivative: d²Zₙ/dmₙ² ≈ ∑ᵢ wᵢ [(fsigᵢ-mₙ)²cₙ⁻² - cₙ⁻¹] pᵃ(yₙ|fsigᵢ)
        second_terms = vmap(gaussian_second_derivative_wrt_mean,
                            (1, None, None, 1))(sigma_points[..., None],
                                                cav_mean, cav_cov,
                                                weighted_likelihood_eval)
        d2Z = np.sum(second_terms, axis=0)

        # d²logZₙ/dmₙ² = d²Zₙ/dmₙ² / Zₙ - (dlogZₙ/dmₙ)²
        d2lZ = -dlZ @ dlZ.T + Zinv * d2Z
        # jitter stabilises the inversion
        id2lZ = inv_any(d2lZ + 1e-10 * np.eye(d2lZ.shape[0]))
        # approx. likelihood (site) mean (see Rasmussen & Williams p75)
        site_mean = cav_mean - id2lZ @ dlZ
        # approx. likelihood (site) variance
        site_cov = -power * (cav_cov + id2lZ)
        return lZ, site_mean, site_cov