示例#1
0
    def test_n5_d2(self):
        """Check ``gaussian_potential`` on a batch of five 2-d points.

        Every row of the batch is identical, so each expected value is the
        single-row reference repeated five times.
        """
        batch = jnp.ones((5, 2))
        row = batch[0]
        dim = batch.shape[-1]

        # Default (zero) mean and unit covariance.
        reference = -multivariate_normal.logpdf(row, jnp.zeros(dim), 1.)
        npt.assert_array_equal(utils.gaussian_potential(batch),
                               jnp.repeat(reference, 5))

        # Scalar mean, broadcast over every dimension.
        scalar_mean = 3.
        reference = -multivariate_normal.logpdf(row,
                                                scalar_mean * jnp.ones(dim),
                                                1.)
        npt.assert_array_equal(utils.gaussian_potential(batch, scalar_mean),
                               jnp.repeat(reference, 5))

        # Vector mean; also reused by the square-root precision checks.
        vector_mean = jnp.ones(2) * 3.
        reference = -multivariate_normal.logpdf(row, vector_mean, 1.)
        npt.assert_array_equal(utils.gaussian_potential(batch, vector_mean),
                               jnp.repeat(reference, 5))

        # Full square-root precision with no determinant: the expected value
        # is the bare quadratic form (no normalising constant). Any warnings
        # emitted on this path are silenced.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            sqrt_prec = jnp.array([[5., 0.], [2., 3.]])
            quad_form = 0.5 * row.T @ sqrt_prec @ sqrt_prec.T @ row
            npt.assert_array_equal(
                utils.gaussian_potential(batch, sqrt_prec=sqrt_prec),
                jnp.repeat(quad_form, 5))
            centred = row - vector_mean
            quad_form = 0.5 * centred.T @ sqrt_prec @ sqrt_prec.T @ centred
            npt.assert_array_equal(
                utils.gaussian_potential(batch, vector_mean,
                                         sqrt_prec=sqrt_prec),
                jnp.repeat(quad_form, 5))
示例#2
0
    def acceptance_probability(
            self, scenario: Scenario, reject_state: cdict, reject_extra: cdict,
            proposed_state: cdict,
            proposed_extra: cdict) -> Union[float, jnp.ndarray]:
        """Metropolis acceptance probability for a state/momenta proposal.

        The log ratio is the decrease in target potential plus the decrease
        in the Gaussian potential of the momenta. ``utils.gaussian_potential``
        is called with its defaults, i.e. zero mean and unit covariance. The
        result is ``min(1, exp(log_ratio))``.

        ``scenario``, ``reject_extra`` and ``proposed_extra`` are not used.
        """
        log_ratio = (-proposed_state.potential + reject_state.potential
                     - utils.gaussian_potential(proposed_state.momenta)
                     + utils.gaussian_potential(reject_state.momenta))
        ratio = jnp.exp(log_ratio)
        return jnp.minimum(1., ratio)
 def transition_potential(self, x_previous: jnp.ndarray, t_previous: float,
                          x_new: jnp.ndarray,
                          t_new: float) -> Union[float, jnp.ndarray]:
     """Gaussian potential of moving from ``x_previous`` to ``x_new``.

     The mean is ``x_previous @ self.transition_matrix.T``. The square-root
     precision is ``self.transition_precision_sqrt`` multiplied by
     ``sqrt(t_new - t_previous)``. No ``det_prec`` is passed.

     NOTE(review): multiplying the precision square root by sqrt(dt) makes
     the precision grow (the variance shrink) as the time step lengthens.
     For a diffusion-style transition whose covariance scales with dt, the
     factor would be a division instead. Confirm against the matching
     transition sampler.
     """
     elapsed = t_new - t_previous
     mean = x_previous @ self.transition_matrix.T
     scaled_sqrt_prec = self.transition_precision_sqrt * jnp.sqrt(elapsed)
     return gaussian_potential(x_new, mean, sqrt_prec=scaled_sqrt_prec)
示例#4
0
 def transition_potential(self, x_previous: jnp.ndarray, t_previous: float,
                          x_new: jnp.ndarray,
                          t_new: float) -> Union[float, jnp.ndarray]:
     """Gaussian potential of ``x_new`` around the transition-function mean.

     The mean is ``self.transition_function(x_previous, t_previous, t_new)``.
     The stored transition precision square root and determinant are passed
     straight through.
     """
     mean = self.transition_function(x_previous, t_previous, t_new)
     sqrt_prec = self.transition_precision_sqrt
     det_prec = self.transition_precision_det
     return gaussian_potential(x_new, mean,
                               sqrt_prec=sqrt_prec, det_prec=det_prec)
示例#5
0
 def prior_potential(
         self,
         x: jnp.ndarray,
         random_key: jnp.ndarray = None) -> Union[float, jnp.ndarray]:
     """Gaussian prior potential of ``x``.

     Uses ``self.prior_mean`` together with the stored prior precision
     square root and determinant. ``random_key`` is not used here.
     """
     mean, sqrt_prec, det_prec = (self.prior_mean,
                                  self.prior_precision_sqrt,
                                  self.prior_precision_det)
     return gaussian_potential(x, mean, sqrt_prec=sqrt_prec, det_prec=det_prec)
示例#6
0
 def likelihood_potential(
         self,
         x: jnp.ndarray,
         random_key: jnp.ndarray = None) -> Union[float, jnp.ndarray]:
     """Gaussian likelihood potential of ``self.data`` given ``x``.

     The data mean is the linear map ``x @ self.likelihood_matrix.T``.
     The stored likelihood precision square root and determinant are
     passed through. ``random_key`` is not used here.
     """
     predicted_data = x @ self.likelihood_matrix.T
     return gaussian_potential(self.data,
                               predicted_data,
                               sqrt_prec=self.likelihood_precision_sqrt,
                               det_prec=self.likelihood_precision_det)
示例#7
0
 def initial_potential(self, x: jnp.ndarray,
                       t: Union[float, None]) -> Union[float, jnp.ndarray]:
     """Gaussian potential of the initial state ``x`` at time ``t``.

     Mean, precision square root and precision determinant all come from the
     ``get_initial_*`` getters evaluated at ``t``. They are called in that
     order.
     """
     return gaussian_potential(x,
                               self.get_initial_mean(t),
                               sqrt_prec=self.get_initial_precision_sqrt(t),
                               det_prec=self.get_initial_precision_det(t))
示例#8
0
 def likelihood_potential(self, x: jnp.ndarray, y: jnp.ndarray,
                          t: float) -> Union[float, jnp.ndarray]:
     """Gaussian potential of observation ``y`` given state ``x`` at time ``t``.

     The observation mean is ``x`` mapped through the time-``t`` likelihood
     matrix. The precision square root and determinant are also taken from
     their time-``t`` getters.
     """
     return gaussian_potential(y,
                               x @ self.get_likelihood_matrix(t).T,
                               sqrt_prec=self.get_likelihood_precision_sqrt(t),
                               det_prec=self.get_likelihood_precision_det(t))
示例#9
0
    def proposal_potential(self, scenario: Scenario, reject_state: cdict,
                           reject_extra: cdict, proposed_state: cdict,
                           proposed_extra: cdict) -> Union[float, jnp.ndarray]:
        """Potential of the Langevin-style proposal from reject to proposed.

        The proposal is Gaussian with mean
        ``reject_state.value - stepsize * reject_state.grad_potential`` and
        precision ``1 / (2 * stepsize)``. The precision is passed as the
        third positional argument of ``utils.gaussian_potential``.
        ``stepsize`` is read from ``reject_extra.parameters``.
        """
        h = reject_extra.parameters.stepsize
        drift_mean = reject_state.value - h * reject_state.grad_potential
        precision = 1. / (2 * h)
        return utils.gaussian_potential(proposed_state.value, drift_mean,
                                        precision)
示例#10
0
 def intermediate_log_weight(self, ssm_scenario: NonLinearGaussian,
                             x_previous: jnp.ndarray, t_previous: float,
                             x_new: jnp.ndarray, y_new: jnp.ndarray,
                             t_new: float) -> Union[float, jnp.ndarray]:
     """Log weight of ``y_new`` under the predicted observation.

     ``x_previous`` is pushed through the scenario's transition function and
     then through its likelihood matrix to form the observation mean. The
     weight is the negated Gaussian potential under the stored weight
     precision. ``x_new`` is not used.
     """
     predicted_state = ssm_scenario.transition_function(x_previous,
                                                        t_previous, t_new)
     predicted_obs = predicted_state @ ssm_scenario.likelihood_matrix.T
     potential = gaussian_potential(y_new,
                                    predicted_obs,
                                    sqrt_prec=self.weight_precision_sqrt,
                                    det_prec=self.weight_precision_det)
     return -potential
示例#11
0
 def proposal_potential(self, ssm_scenario: NonLinearGaussian,
                        x_previous: jnp.ndarray, t_previous: float,
                        x_new: jnp.ndarray, y_new: jnp.ndarray,
                        t_new: float) -> Union[float, jnp.ndarray]:
     """Gaussian proposal potential of ``x_new`` conditioned on ``y_new``.

     The transition mean is corrected by the observation residual through
     ``self.proposal_kalman_gain`` (row-vector convention). The result is
     evaluated with the stored proposal precision.
     """
     predicted_mean = ssm_scenario.transition_function(x_previous, t_previous,
                                                       t_new)
     innovation = y_new - predicted_mean @ ssm_scenario.likelihood_matrix.T
     conditioned_mean = predicted_mean + innovation @ self.proposal_kalman_gain.T
     return gaussian_potential(x_new,
                               conditioned_mean,
                               sqrt_prec=self.proposal_precision_sqrt,
                               det_prec=self.proposal_precision_det)
示例#12
0
 def initial_potential(self, ssm_scenario: NonLinearGaussian,
                       x: jnp.ndarray, y: jnp.ndarray,
                       t: float) -> Union[float, jnp.ndarray]:
     """Gaussian potential of the initial state ``x`` conditioned on ``y``.

     The scenario's initial mean is corrected by the observation residual
     through ``self.initial_kalman_gain``. The result is evaluated with the
     stored conditioned precision. ``t`` is not used.
     """
     prior_mean = ssm_scenario.initial_mean
     # Column-vector convention (matrix @ vector); presumably initial_mean is
     # 1-d here — TODO confirm it is never batched.
     innovation = y - ssm_scenario.likelihood_matrix @ prior_mean
     conditioned_mean = prior_mean + self.initial_kalman_gain @ innovation
     return gaussian_potential(
         x,
         conditioned_mean,
         sqrt_prec=self.initial_conditioned_precision_sqrt,
         det_prec=self.initial_conditioned_precision_det)
示例#13
0
 def transition_potential(self, x_previous: jnp.ndarray, t_previous: float,
                          x_new: jnp.ndarray,
                          t_new: float) -> Union[float, jnp.ndarray]:
     """Gaussian potential of moving from ``x_previous`` to ``x_new``.

     Transition matrix, precision square root and precision determinant all
     come from getters evaluated on ``(t_previous, t_new)``. They are called
     in that order.
     """
     step = (t_previous, t_new)
     return gaussian_potential(
         x_new,
         x_previous @ self.get_transition_matrix(*step).T,
         sqrt_prec=self.get_transition_precision_sqrt(*step),
         det_prec=self.get_transition_precision_det(*step))
示例#14
0
    def propose_and_intermediate_weight_vectorised(
            self, ssm_scenario: NonLinearGaussian, x_previous: jnp.ndarray,
            t_previous: float, y_new: jnp.ndarray, t_new: float,
            random_keys: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Propose new particles and compute their log weights together.

        Each row of ``x_previous`` is propagated by
        ``ssm_scenario.transition_function``, vectorised with ``vmap`` over
        axis 0 only. The proposal mean is the propagated state corrected
        toward ``y_new`` via ``self.proposal_kalman_gain``. Noise is one
        standard-normal draw of shape ``x_previous.shape`` from
        ``random_keys[0]``, coloured by ``self.proposal_covariance_sqrt``.
        The log weight is the negated Gaussian potential of ``y_new`` under
        the predicted observation.

        Returns:
            ``(x_new, log_weight_new)``.
        """
        propagate = vmap(ssm_scenario.transition_function, (0, None, None))
        predicted = propagate(x_previous, t_previous, t_new)
        predicted_obs = predicted @ ssm_scenario.likelihood_matrix.T

        innovation = y_new - predicted_obs
        conditioned_mean = predicted + innovation @ self.proposal_kalman_gain.T
        noise = random.normal(random_keys[0], shape=x_previous.shape)
        x_new = conditioned_mean + noise @ self.proposal_covariance_sqrt.T

        log_weight_new = -gaussian_potential(
            y_new,
            predicted_obs,
            sqrt_prec=self.weight_precision_sqrt,
            det_prec=self.weight_precision_det)
        return x_new, log_weight_new
示例#15
0
    def test_n1_d1(self):
        """Single 1-d point under default, mean, diagonal and full precision."""
        point = jnp.array([7.])
        mean = jnp.array([1.])
        full_prec = jnp.array([[2.]])
        diag_prec = full_prec[0]
        cov = 1 / full_prec

        # Zero mean, unit variance.
        npt.assert_array_equal(utils.gaussian_potential(point),
                               -multivariate_normal.logpdf(point, 0., 1.))
        # Non-zero mean, unit variance.
        npt.assert_array_equal(utils.gaussian_potential(point, mean),
                               -multivariate_normal.logpdf(point, mean, 1.))

        # Diagonal precision given as a 1-d array, without and with a mean.
        npt.assert_array_equal(
            utils.gaussian_potential(point, prec=diag_prec),
            -multivariate_normal.logpdf(point, 0, cov))
        npt.assert_array_almost_equal(
            utils.gaussian_potential(point, mean, diag_prec),
            -multivariate_normal.logpdf(point, mean, cov),
            decimal=4)

        # Full (matrix) precision without det_prec: the normalising constant
        # is omitted, so only the quadratic form is compared. Any warnings
        # emitted on this path are silenced.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            npt.assert_array_equal(
                utils.gaussian_potential(point, prec=full_prec),
                0.5 * point**2 * full_prec)
            npt.assert_array_equal(
                utils.gaussian_potential(point, mean, full_prec),
                0.5 * (point - mean)**2 * full_prec)
示例#16
0
    def test_n1_d5(self):
        """Single 5-d point under several mean and precision forms."""
        point = jnp.ones(5)
        origin = jnp.zeros_like(point)

        # Default (zero) mean, unit covariance.
        npt.assert_array_equal(
            utils.gaussian_potential(point),
            -multivariate_normal.logpdf(point, origin, 1.))

        # Scalar mean broadcast over all five dimensions.
        scalar_mean = 3.
        npt.assert_array_equal(
            utils.gaussian_potential(point, scalar_mean),
            -multivariate_normal.logpdf(point,
                                        scalar_mean * jnp.ones_like(point),
                                        1.))

        # Vector mean; reused by every precision check below.
        mean = jnp.ones(5) * 3.
        npt.assert_array_equal(utils.gaussian_potential(point, mean),
                               -multivariate_normal.logpdf(point, mean, 1.))

        def check_full_precision(prec, det):
            """Check a full precision matrix, first without and then with det_prec."""
            # Without det_prec the normalising constant is omitted, so only
            # the quadratic form is compared; any warnings are silenced.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                npt.assert_array_equal(
                    utils.gaussian_potential(point, prec=prec),
                    0.5 * point.T @ prec @ point)
                npt.assert_array_equal(
                    utils.gaussian_potential(point, mean, prec),
                    0.5 * (point - mean).T @ prec @ (point - mean))
            # With det_prec the result is compared to the full log density.
            cov = jnp.linalg.inv(prec)
            npt.assert_array_almost_equal(
                utils.gaussian_potential(point, prec=prec, det_prec=det),
                -multivariate_normal.logpdf(point, origin, cov),
                decimal=5)
            npt.assert_array_almost_equal(
                utils.gaussian_potential(point, mean, prec=prec,
                                         det_prec=det),
                -multivariate_normal.logpdf(point, mean, cov),
                decimal=5)

        # Diagonal precision, first passed as its 1-d diagonal ...
        diag_full = jnp.eye(5) * 2
        diag_cov = jnp.linalg.inv(diag_full)
        npt.assert_array_almost_equal(
            utils.gaussian_potential(point, prec=jnp.diag(diag_full)),
            -multivariate_normal.logpdf(point, origin, diag_cov),
            decimal=5)
        npt.assert_array_almost_equal(
            utils.gaussian_potential(point, mean, prec=jnp.diag(diag_full),
                                     det_prec=2**5),
            -multivariate_normal.logpdf(point, mean, diag_cov),
            decimal=5)
        # ... then as the full matrix.
        check_full_precision(diag_full, 2**5)

        # Non-diagonal precision built from a square-root factor.
        sqrt_prec = jnp.arange(25).reshape(5, 5) / 100 + jnp.eye(5)
        dense_prec = sqrt_prec @ sqrt_prec.T
        check_full_precision(dense_prec, jnp.linalg.det(dense_prec))
示例#17
0
 def initial_potential(self, x: jnp.ndarray,
                       t: Union[float, None]) -> Union[float, jnp.ndarray]:
     """Gaussian potential of the initial state ``x`` (``t`` is not used).

     Uses ``self.initial_mean`` and ``self.initial_precision_sqrt``.

     NOTE(review): no ``det_prec`` is passed. If the precision square root is
     a full matrix, the normalising constant may be omitted by
     ``gaussian_potential`` — confirm this is intended.
     """
     return gaussian_potential(x,
                               self.initial_mean,
                               sqrt_prec=self.initial_precision_sqrt)
示例#18
0
 def likelihood_potential(self, x: jnp.ndarray, y: jnp.ndarray,
                          t: float) -> Union[float, jnp.ndarray]:
     """Zero-mean Gaussian potential of ``y`` with a state-dependent scale.

     The precision square root is ``self.likelihood_precision_diag_sqrt``
     divided by ``exp(0.5 * x)``. Assuming it acts elementwise (diagonal), the
     variance of ``y`` is scaled by ``exp(x)``. ``t`` is not used.
     """
     state_scale = jnp.exp(0.5 * x)
     scaled_sqrt_prec = self.likelihood_precision_diag_sqrt / state_scale
     return gaussian_potential(y, 0, sqrt_prec=scaled_sqrt_prec)
示例#19
0
 def initial_potential(self, x: jnp.ndarray,
                       t: float) -> Union[float, jnp.ndarray]:
     """Gaussian potential of the initial state ``x`` (``t`` is not used).

     Uses the stored initial mean, precision square root and precision
     determinant.
     """
     mean = self.initial_mean
     sqrt_prec = self.initial_precision_sqrt
     det_prec = self.initial_precision_det
     return gaussian_potential(x, mean, sqrt_prec=sqrt_prec, det_prec=det_prec)
示例#20
0
 def likelihood_potential(self, x: jnp.ndarray, y: jnp.ndarray,
                          t: float) -> Union[float, jnp.ndarray]:
     """Gaussian potential of observation ``y`` given state ``x``.

     The observation mean is ``x @ self.likelihood_matrix.T``. The stored
     likelihood precision square root and determinant are passed through.
     ``t`` is not used.
     """
     predicted_obs = x @ self.likelihood_matrix.T
     return gaussian_potential(y,
                               predicted_obs,
                               sqrt_prec=self.likelihood_precision_sqrt,
                               det_prec=self.likelihood_precision_det)