def test_n5_d2(self):
    """gaussian_potential on a batch of 5 identical 2-d rows gives one value per row."""
    points = jnp.ones((5, 2))
    first = points[0]
    n_rows = points.shape[0]
    dim = points.shape[-1]

    # Default mean and precision: compare against the standard normal logpdf.
    npt.assert_array_equal(
        utils.gaussian_potential(points),
        jnp.repeat(-multivariate_normal.logpdf(first, jnp.zeros(dim), 1.),
                   n_rows))

    # A scalar mean is broadcast across every coordinate.
    scalar_mean = 3.
    npt.assert_array_equal(
        utils.gaussian_potential(points, scalar_mean),
        jnp.repeat(
            -multivariate_normal.logpdf(first, scalar_mean * jnp.ones(dim), 1.),
            n_rows))

    # An explicit vector mean.
    vector_mean = jnp.ones(2) * 3.
    npt.assert_array_equal(
        utils.gaussian_potential(points, vector_mean),
        jnp.repeat(-multivariate_normal.logpdf(first, vector_mean, 1.), n_rows))

    # Full square-root precision: expected value is the bare quadratic form.
    # Warnings raised by gaussian_potential in this mode are silenced.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sqrt_prec = jnp.array([[5., 0.], [2., 3.]])
        npt.assert_array_equal(
            utils.gaussian_potential(points, sqrt_prec=sqrt_prec),
            jnp.repeat(0.5 * first.T @ sqrt_prec @ sqrt_prec.T @ first, n_rows))
        centred = first - vector_mean
        npt.assert_array_equal(
            utils.gaussian_potential(points, vector_mean, sqrt_prec=sqrt_prec),
            jnp.repeat(0.5 * centred.T @ sqrt_prec @ sqrt_prec.T @ centred,
                       n_rows))
def acceptance_probability(self,
                           scenario: Scenario,
                           reject_state: cdict,
                           reject_extra: cdict,
                           proposed_state: cdict,
                           proposed_extra: cdict) -> Union[float, jnp.ndarray]:
    """Metropolis acceptance probability for a position/momentum proposal.

    The log ratio is the decrease in ``potential`` plus the decrease in the
    default Gaussian potential of the ``momenta`` (an HMC-style kinetic
    term). The exponentiated ratio is capped at one.

    Args:
        scenario: unused here.
        reject_state: current state; reads ``potential`` and ``momenta``.
        reject_extra: unused here.
        proposed_state: candidate state; reads ``potential`` and ``momenta``.
        proposed_extra: unused here.

    Returns:
        ``min(1, exp(log_ratio))``.
    """
    proposed_kinetic = utils.gaussian_potential(proposed_state.momenta)
    reject_kinetic = utils.gaussian_potential(reject_state.momenta)
    log_ratio = (-proposed_state.potential + reject_state.potential
                 - proposed_kinetic + reject_kinetic)
    return jnp.minimum(1., jnp.exp(log_ratio))
def transition_potential(self,
                         x_previous: jnp.ndarray,
                         t_previous: float,
                         x_new: jnp.ndarray,
                         t_new: float) -> Union[float, jnp.ndarray]:
    """Gaussian potential of ``x_new`` given ``x_previous`` over the interval.

    Mean is ``x_previous @ transition_matrix.T``. The square-root precision
    is ``transition_precision_sqrt`` multiplied by ``sqrt(t_new - t_previous)``.
    No ``det_prec`` is passed, so for a full-matrix precision the
    normalising constant may be omitted by gaussian_potential.

    NOTE(review): multiplying the sqrt-precision by sqrt(dt) multiplies the
    precision by dt, i.e. the transition covariance shrinks as the elapsed
    time grows. If the noise variance is meant to grow with dt this should
    be a division -- confirm against the matching transition sampler.
    """
    elapsed = t_new - t_previous
    predicted_mean = x_previous @ self.transition_matrix.T
    scaled_sqrt_prec = self.transition_precision_sqrt * jnp.sqrt(elapsed)
    return gaussian_potential(x_new, predicted_mean, sqrt_prec=scaled_sqrt_prec)
def transition_potential(self,
                         x_previous: jnp.ndarray,
                         t_previous: float,
                         x_new: jnp.ndarray,
                         t_new: float) -> Union[float, jnp.ndarray]:
    """Gaussian potential of ``x_new`` centred at the propagated previous state.

    The mean is ``transition_function(x_previous, t_previous, t_new)``. The
    precision is given by ``transition_precision_sqrt`` and
    ``transition_precision_det``.
    """
    propagated = self.transition_function(x_previous, t_previous, t_new)
    return gaussian_potential(x_new,
                              propagated,
                              sqrt_prec=self.transition_precision_sqrt,
                              det_prec=self.transition_precision_det)
def prior_potential(self,
                    x: jnp.ndarray,
                    random_key: jnp.ndarray = None) -> Union[float, jnp.ndarray]:
    """Gaussian prior potential of ``x``.

    Uses ``prior_mean``, ``prior_precision_sqrt`` and ``prior_precision_det``.
    ``random_key`` is accepted for interface compatibility and ignored.
    """
    mean = self.prior_mean
    sqrt_prec = self.prior_precision_sqrt
    det_prec = self.prior_precision_det
    return gaussian_potential(x, mean, sqrt_prec=sqrt_prec, det_prec=det_prec)
def likelihood_potential(self,
                         x: jnp.ndarray,
                         random_key: jnp.ndarray = None) -> Union[float, jnp.ndarray]:
    """Gaussian potential of the stored ``data`` given latent ``x``.

    The observation mean is the linear map ``x @ likelihood_matrix.T``.
    ``random_key`` is accepted for interface compatibility and ignored.
    """
    predicted = x @ self.likelihood_matrix.T
    return gaussian_potential(self.data,
                              predicted,
                              sqrt_prec=self.likelihood_precision_sqrt,
                              det_prec=self.likelihood_precision_det)
def initial_potential(self,
                      x: jnp.ndarray,
                      t: Union[float, None]) -> Union[float, jnp.ndarray]:
    """Gaussian potential of ``x`` under the (possibly time-dependent) initial law.

    Mean, sqrt-precision and precision determinant are obtained from the
    ``get_initial_*`` accessors at time ``t``, in that order.
    """
    mean_t = self.get_initial_mean(t)
    sqrt_prec_t = self.get_initial_precision_sqrt(t)
    det_prec_t = self.get_initial_precision_det(t)
    return gaussian_potential(x, mean_t, sqrt_prec=sqrt_prec_t, det_prec=det_prec_t)
def likelihood_potential(self,
                         x: jnp.ndarray,
                         y: jnp.ndarray,
                         t: float) -> Union[float, jnp.ndarray]:
    """Gaussian potential of observation ``y`` given latent ``x`` at time ``t``.

    The mean is ``x @ H_t.T``. ``H_t``, the sqrt-precision and its determinant
    come from the ``get_likelihood_*`` accessors, called in that order.
    """
    obs_matrix = self.get_likelihood_matrix(t)
    obs_sqrt_prec = self.get_likelihood_precision_sqrt(t)
    obs_det_prec = self.get_likelihood_precision_det(t)
    predicted = x @ obs_matrix.T
    return gaussian_potential(y, predicted, sqrt_prec=obs_sqrt_prec, det_prec=obs_det_prec)
def proposal_potential(self,
                       scenario: Scenario,
                       reject_state: cdict,
                       reject_extra: cdict,
                       proposed_state: cdict,
                       proposed_extra: cdict) -> Union[float, jnp.ndarray]:
    """Potential of the gradient-step Gaussian proposal evaluated at the candidate.

    The proposal is centred at ``value - stepsize * grad_potential`` of the
    current state, with scalar precision ``1 / (2 * stepsize)`` (variance
    ``2 * stepsize``, the form used by Langevin-type proposals).
    ``stepsize`` is read from ``reject_extra.parameters``.
    """
    step = reject_extra.parameters.stepsize
    drift_mean = reject_state.value - step * reject_state.grad_potential
    scalar_prec = 1. / (2 * step)
    # Third positional argument of gaussian_potential is the precision.
    return utils.gaussian_potential(proposed_state.value, drift_mean, scalar_prec)
def intermediate_log_weight(self,
                            ssm_scenario: NonLinearGaussian,
                            x_previous: jnp.ndarray,
                            t_previous: float,
                            x_new: jnp.ndarray,
                            y_new: jnp.ndarray,
                            t_new: float) -> Union[float, jnp.ndarray]:
    """Log incremental weight: Gaussian log density of ``y_new`` given the predicted state.

    The predicted observation is ``transition_function(x_previous, ...) @ H.T``.
    The precision is ``weight_precision_sqrt`` / ``weight_precision_det``.
    ``x_new`` is not used.
    """
    predicted_state = ssm_scenario.transition_function(x_previous, t_previous, t_new)
    predicted_obs = predicted_state @ ssm_scenario.likelihood_matrix.T
    potential = gaussian_potential(y_new,
                                   predicted_obs,
                                   sqrt_prec=self.weight_precision_sqrt,
                                   det_prec=self.weight_precision_det)
    return -potential
def proposal_potential(self,
                       ssm_scenario: NonLinearGaussian,
                       x_previous: jnp.ndarray,
                       t_previous: float,
                       x_new: jnp.ndarray,
                       y_new: jnp.ndarray,
                       t_new: float) -> Union[float, jnp.ndarray]:
    """Potential of ``x_new`` under the observation-conditioned Gaussian proposal.

    The mean is the propagated state corrected by the gain-weighted
    innovation ``y_new - mx @ H.T``. The precision is
    ``proposal_precision_sqrt`` / ``proposal_precision_det``.
    """
    propagated = ssm_scenario.transition_function(x_previous, t_previous, t_new)
    innovation = y_new - propagated @ ssm_scenario.likelihood_matrix.T
    conditioned_mean = propagated + innovation @ self.proposal_kalman_gain.T
    return gaussian_potential(x_new,
                              conditioned_mean,
                              sqrt_prec=self.proposal_precision_sqrt,
                              det_prec=self.proposal_precision_det)
def initial_potential(self,
                      ssm_scenario: NonLinearGaussian,
                      x: jnp.ndarray,
                      y: jnp.ndarray,
                      t: float) -> Union[float, jnp.ndarray]:
    """Potential of ``x`` under the observation-conditioned initial Gaussian.

    The mean is ``initial_mean`` corrected by ``initial_kalman_gain`` applied
    to the innovation ``y - initial_mean @ H.T``. The precision is
    ``initial_conditioned_precision_sqrt`` / ``initial_conditioned_precision_det``.
    ``t`` is not used.

    Args:
        ssm_scenario: provides ``initial_mean`` and ``likelihood_matrix``.
        x: state(s) at which to evaluate the potential.
        y: first observation; may carry leading batch dimensions.
        t: unused.

    Returns:
        Output of gaussian_potential for ``x``.
    """
    prior_mean = ssm_scenario.initial_mean
    # Row-vector convention (``v @ M.T``), matching the conditioned proposal
    # elsewhere in this class. For 1-d ``y`` this equals the column form
    # ``M @ v``, but unlike ``K @ (...)`` it also broadcasts correctly over
    # leading batch dimensions of ``y``.
    innovation = y - prior_mean @ ssm_scenario.likelihood_matrix.T
    initial_conditioned_mean = prior_mean + innovation @ self.initial_kalman_gain.T
    return gaussian_potential(
        x,
        initial_conditioned_mean,
        sqrt_prec=self.initial_conditioned_precision_sqrt,
        det_prec=self.initial_conditioned_precision_det)
def transition_potential(self,
                         x_previous: jnp.ndarray,
                         t_previous: float,
                         x_new: jnp.ndarray,
                         t_new: float) -> Union[float, jnp.ndarray]:
    """Gaussian potential of ``x_new`` given ``x_previous`` over ``[t_previous, t_new]``.

    The matrix, sqrt-precision and determinant come from the
    ``get_transition_*`` accessors, called in that order. The mean is
    ``x_previous @ A.T``.
    """
    interval = (t_previous, t_new)
    trans_matrix = self.get_transition_matrix(*interval)
    trans_sqrt_prec = self.get_transition_precision_sqrt(*interval)
    trans_det_prec = self.get_transition_precision_det(*interval)
    predicted_mean = x_previous @ trans_matrix.T
    return gaussian_potential(x_new,
                              predicted_mean,
                              sqrt_prec=trans_sqrt_prec,
                              det_prec=trans_det_prec)
def propose_and_intermediate_weight_vectorised(
        self,
        ssm_scenario: NonLinearGaussian,
        x_previous: jnp.ndarray,
        t_previous: float,
        y_new: jnp.ndarray,
        t_new: float,
        random_keys: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Sample new states from the conditioned proposal and compute their log weights.

    ``transition_function`` is vmapped over the leading axis of
    ``x_previous``. Gaussian noise with shape ``x_previous.shape`` is drawn
    from ``random_keys[0]`` and mapped through ``proposal_covariance_sqrt``.

    Returns:
        ``(x_new, log_weight_new)``. The log weight is the negative
        gaussian_potential of ``y_new`` around the predicted observations.
    """
    batched_transition = vmap(ssm_scenario.transition_function, (0, None, None))
    propagated = batched_transition(x_previous, t_previous, t_new)
    predicted_obs = propagated @ ssm_scenario.likelihood_matrix.T

    conditioned_mean = propagated + (y_new - predicted_obs) @ self.proposal_kalman_gain.T
    noise = random.normal(random_keys[0], shape=x_previous.shape)
    x_new = conditioned_mean + noise @ self.proposal_covariance_sqrt.T

    log_weight_new = -gaussian_potential(y_new,
                                         predicted_obs,
                                         sqrt_prec=self.weight_precision_sqrt,
                                         det_prec=self.weight_precision_det)
    return x_new, log_weight_new
def test_n1_d1(self):
    """gaussian_potential on a single 1-d point with default, diagonal and full precision."""
    point = jnp.array([7.])
    shift = jnp.array([1.])
    precision = jnp.array([[2.]])
    diag_precision = precision[0]
    variance = 1 / precision

    # Default (unit) precision, with and without a mean.
    npt.assert_array_equal(utils.gaussian_potential(point),
                           -multivariate_normal.logpdf(point, 0., 1.))
    npt.assert_array_equal(utils.gaussian_potential(point, shift),
                           -multivariate_normal.logpdf(point, shift, 1.))

    # Diagonal (vector) precision: compared against the full logpdf.
    npt.assert_array_equal(utils.gaussian_potential(point, prec=diag_precision),
                           -multivariate_normal.logpdf(point, 0, variance))
    npt.assert_array_almost_equal(
        utils.gaussian_potential(point, shift, diag_precision),
        -multivariate_normal.logpdf(point, shift, variance),
        decimal=4)

    # Full (matrix) precision: expected value omits the normalising constant.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        npt.assert_array_equal(utils.gaussian_potential(point, prec=precision),
                               0.5 * point**2 * precision)
        npt.assert_array_equal(utils.gaussian_potential(point, shift, precision),
                               0.5 * (point - shift)**2 * precision)
def test_n1_d5(self):
    """gaussian_potential on a single 5-d point for diagonal and non-diagonal precisions."""
    x = jnp.ones(5)
    origin = jnp.zeros_like(x)

    # Default (unit) precision: no mean, scalar mean, vector mean.
    npt.assert_array_equal(utils.gaussian_potential(x),
                           -multivariate_normal.logpdf(x, origin, 1.))
    scalar_mean = 3.
    npt.assert_array_equal(
        utils.gaussian_potential(x, scalar_mean),
        -multivariate_normal.logpdf(x, scalar_mean * jnp.ones_like(x), 1.))
    m = jnp.ones(5) * 3.
    npt.assert_array_equal(utils.gaussian_potential(x, m),
                           -multivariate_normal.logpdf(x, m, 1.))

    def check_full_precision(prec, det):
        """Matrix precision: bare quadratic form without det, full logpdf with det."""
        cov = jnp.linalg.inv(prec)
        centred = x - m
        # Without det_prec the normalising constant is omitted; warnings silenced.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            npt.assert_array_equal(utils.gaussian_potential(x, prec=prec),
                                   0.5 * x.T @ prec @ x)
            npt.assert_array_equal(utils.gaussian_potential(x, m, prec),
                                   0.5 * centred.T @ prec @ centred)
        # Supplying det_prec recovers the normalised density.
        npt.assert_array_almost_equal(
            utils.gaussian_potential(x, prec=prec, det_prec=det),
            -multivariate_normal.logpdf(x, origin, cov),
            decimal=5)
        npt.assert_array_almost_equal(
            utils.gaussian_potential(x, m, prec=prec, det_prec=det),
            -multivariate_normal.logpdf(x, m, cov),
            decimal=5)

    # Diagonal precision 2*I, first passed as its diagonal vector, then as a matrix.
    diag_prec = jnp.eye(5) * 2
    diag_cov = jnp.linalg.inv(diag_prec)
    npt.assert_array_almost_equal(
        utils.gaussian_potential(x, prec=jnp.diag(diag_prec)),
        -multivariate_normal.logpdf(x, origin, diag_cov),
        decimal=5)
    npt.assert_array_almost_equal(
        utils.gaussian_potential(x, m, prec=jnp.diag(diag_prec), det_prec=2**5),
        -multivariate_normal.logpdf(x, m, diag_cov),
        decimal=5)
    check_full_precision(diag_prec, 2**5)

    # Non-diagonal precision built from an explicit square root.
    sqrt_prec = jnp.arange(25).reshape(5, 5) / 100 + jnp.eye(5)
    full_prec = sqrt_prec @ sqrt_prec.T
    check_full_precision(full_prec, jnp.linalg.det(full_prec))
def initial_potential(self,
                      x: jnp.ndarray,
                      t: Union[float, None]) -> Union[float, jnp.ndarray]:
    """Gaussian potential of ``x`` under ``initial_mean`` / ``initial_precision_sqrt``.

    ``t`` is unused. No ``det_prec`` is passed, so for a full-matrix
    precision gaussian_potential may omit the normalising constant.
    """
    return gaussian_potential(x,
                              self.initial_mean,
                              sqrt_prec=self.initial_precision_sqrt)
def likelihood_potential(self,
                         x: jnp.ndarray,
                         y: jnp.ndarray,
                         t: float) -> Union[float, jnp.ndarray]:
    """Zero-mean Gaussian potential of ``y`` whose scale is modulated by ``x``.

    The sqrt-precision is ``likelihood_precision_diag_sqrt / exp(x / 2)``
    (elementwise), so the variance is scaled by ``exp(x)``. This is a
    stochastic-volatility-style observation model. ``t`` is unused.
    """
    volatility_scale = jnp.exp(0.5 * x)
    modulated_sqrt_prec = self.likelihood_precision_diag_sqrt / volatility_scale
    return gaussian_potential(y, 0, sqrt_prec=modulated_sqrt_prec)
def initial_potential(self,
                      x: jnp.ndarray,
                      t: float) -> Union[float, jnp.ndarray]:
    """Gaussian potential of ``x`` under the stored initial mean and precision.

    ``t`` is unused.
    """
    mean = self.initial_mean
    sqrt_prec = self.initial_precision_sqrt
    det_prec = self.initial_precision_det
    return gaussian_potential(x, mean, sqrt_prec=sqrt_prec, det_prec=det_prec)
def likelihood_potential(self,
                         x: jnp.ndarray,
                         y: jnp.ndarray,
                         t: float) -> Union[float, jnp.ndarray]:
    """Gaussian potential of ``y`` around the linear prediction ``x @ likelihood_matrix.T``.

    ``t`` is unused.
    """
    predicted = x @ self.likelihood_matrix.T
    return gaussian_potential(y,
                              predicted,
                              sqrt_prec=self.likelihood_precision_sqrt,
                              det_prec=self.likelihood_precision_det)