def test_n5_d2(self):
    """Check ``utils.gaussian_potential`` on five identical 2-d points.

    Every row of the input is the same vector, so each expected result is
    the single-row reference value repeated five times.
    """
    n_rows = 5
    x = jnp.ones((n_rows, 2))
    row = x[0]
    dim = x.shape[-1]

    # Default arguments: reference is a zero-mean normal with cov ``1.``.
    expected = jnp.repeat(
        -multivariate_normal.logpdf(row, jnp.zeros(dim), 1.), n_rows)
    npt.assert_array_equal(utils.gaussian_potential(x), expected)

    # Scalar mean: the reference uses the same value in every coordinate.
    m = 3.
    expected = jnp.repeat(
        -multivariate_normal.logpdf(row, m * jnp.ones(dim), 1.), n_rows)
    npt.assert_array_equal(utils.gaussian_potential(x, m), expected)

    # Vector mean.
    m = jnp.ones(2) * 3.
    expected = jnp.repeat(-multivariate_normal.logpdf(row, m, 1.), n_rows)
    npt.assert_array_equal(utils.gaussian_potential(x, m), expected)

    # ``sqrt_prec`` argument: compared against the bare quadratic form, with
    # any warnings raised during the calls silenced.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sqrt_prec = jnp.array([[5., 0.], [2., 3.]])
        expected = jnp.repeat(
            0.5 * row.T @ sqrt_prec @ sqrt_prec.T @ row, n_rows)
        npt.assert_array_equal(
            utils.gaussian_potential(x, sqrt_prec=sqrt_prec), expected)

        centred = row - m
        expected = jnp.repeat(
            0.5 * centred.T @ sqrt_prec @ sqrt_prec.T @ centred, n_rows)
        npt.assert_array_equal(
            utils.gaussian_potential(x, m, sqrt_prec=sqrt_prec), expected)
def loss_one_pair_with_prior(mu_i, mu_j, s_i, s_j, D, n_components):
    """Return ``loss_one_pair`` minus the log-prior of both locations.

    The log-prior is the sum of the normal log-densities of ``mu_i`` and
    ``mu_j`` evaluated with mean ``zeros2`` and covariance ``1.0``.
    """
    # Accumulate from 0.0 so the additions happen in the order
    # 0.0 + logpdf(mu_i) + logpdf(mu_j).
    log_prior = sum(
        (multivariate_normal.logpdf(point, mean=zeros2, cov=1.0)
         for point in (mu_i, mu_j)),
        0.0,
    )
    pair_loss = loss_one_pair(mu_i, mu_j, s_i, s_j, D, n_components)
    return pair_loss - log_prior
def log_driven_Langevin_kernel(x_tm1, x_t, potential, dt, A_function,
                               b_function, potential_parameter, A_parameter,
                               b_parameter):
    """Log transition probability of the driven Langevin kernel.

    Parameters
    ----------
    x_tm1 : jnp.array(N)
        Position at the previous iteration.
    x_t : jnp.array(N)
        Position at the current iteration.
    potential : function
        Potential function.
    dt : float
        Time increment.
    A_function : function
        Variance-controlling function.
    b_function : function
        Mean-controlling function.
    potential_parameter : jnp.array(Q)
        Second argument to the potential function.
    A_parameter : jnp.array(R)
        Argument to ``A_function``.
    b_parameter : jnp.array(S)
        Argument to ``b_function``.

    Returns
    -------
    logp : float
        Normal log-density of ``x_t`` under the mean and covariance returned
        by ``driven_mu_cov``.
    """
    from jax.scipy.stats import multivariate_normal as mvn

    # The first returned value is not needed to build the kernel moments.
    _, b_value, f_value, theta_value = driven_Langevin_parameters(
        x_tm1, potential, dt, A_function, b_function,
        potential_parameter, A_parameter, b_parameter)
    kernel_mean, kernel_cov = driven_mu_cov(b_value, f_value, theta_value, dt)
    return mvn.logpdf(x_t, kernel_mean, kernel_cov)
def log_pdf(params, inputs):
    """Log-density of ``inputs`` under a weighted mixture of normals.

    ``weights``, ``means`` and ``covariances`` come from the enclosing scope;
    ``params`` is accepted but not read here.
    """
    # One row per component: log-weight plus that component's log-density.
    component_lls = [
        log_w + multivariate_normal.logpdf(inputs, mu, sigma)
        for log_w, mu, sigma in zip(np.log(weights), means, covariances)
    ]
    # Log-sum-exp over the component axis.
    return logsumexp(np.vstack(component_lls), axis=0)
def logpdf(self, z):
    """Compute the logpdf from sample z.

    The value returned is ``log(2) + log phi(z - loc) + log Phi(alpha . (z - loc))``
    where ``phi`` is the zero-mean normal density with covariance
    ``self.cov`` and ``Phi`` is the standard normal CDF. This is the form of
    a skew-normal log-density, whose density carries a factor ``2``; that
    factor must enter the log-density as ``log(2)``, not as the literal ``2``.

    Parameters
    ----------
    z : array
        Sample(s) at which to evaluate the log-density.

    Returns
    -------
    array
        Log-density at ``z``.
    """
    centred = z - self.loc
    capital_phi = norm.logcdf(jnp.matmul(self.alpha, centred.T))
    small_phi = mvn.logpdf(centred,
                           mean=jnp.zeros(shape=(self.k), ),
                           cov=self.cov)
    return jnp.log(2.0) + small_phi + capital_phi
def loss_MAP(mu, tau_unc, D, i0, i1, mu0, beta=1.0, gamma_shape=1.0,
             gamma_rate=1.0, alpha=1.0):
    """Sum of the pairwise log-likelihood and the log-priors on ``mu``/``tau``.

    ``tau`` is obtained from ``tau_unc`` as ``EPSILON + softplus(SCALE * tau_unc)``.
    ``i0`` and ``i1`` index the two end points of every pair, and ``D`` is
    compared with the Euclidean distance between the paired rows of ``mu``.
    ``alpha`` is accepted but not used in the computation.
    """
    tau = EPSILON + jax.nn.softplus(SCALE * tau_unc)

    # Gather the two end points of every pair.
    mu_a, mu_b = mu[i0], mu[i1]
    tau_a, tau_b = tau[i0], tau[i1]

    tau_sum = tau_a + tau_b
    tau_ij_inv = tau_a * tau_b / tau_sum
    log_tau_ij_inv = jnp.log(tau_a) + jnp.log(tau_b) - jnp.log(tau_sum)

    d = jnp.linalg.norm(mu_a - mu_b, ord=2, axis=1, keepdims=1)

    quadratic = 0.5 * tau_ij_inv * (D - d)**2
    log_bessel = jnp.log(i0e(tau_ij_inv * D * d))
    log_llh = jnp.log(D) + log_tau_ij_inv - quadratic + log_bessel

    # Priors: normal on every row of ``mu``, gamma on every ``tau``.
    log_mu = multivariate_normal.logpdf(mu, mean=mu0, cov=beta * jnp.eye(2))
    log_tau = gamma.logpdf(tau, a=gamma_shape, scale=1.0 / gamma_rate)

    return jnp.sum(log_llh) + jnp.sum(log_mu) + jnp.sum(log_tau)
def ll(sigma, theta, y):
    """Negative log-likelihood of ``y`` for a skew-normal-form density.

    Each observation contributes
    ``log(2) + log phi(y; 0, sigma) + log Phi(al . y)`` with
    ``al = theta / sqrt(diag(sigma))``. The normalising factor ``2`` of the
    density therefore enters the summed log-likelihood as ``n * log(2)``
    (``n`` = number of observations), not as the literal ``2``.

    Parameters
    ----------
    sigma : (p, p) array
        Covariance passed to the normal log-density.
    theta : (p,) array
        Skewness parameters before scaling by the standard deviations.
    y : (n, p) or (p,) array
        Observations.

    Returns
    -------
    scalar
        Negative of the summed log-likelihood.
    """
    p = y.shape[-1]
    sc = jnp.sqrt(jnp.diag(sigma))
    al = theta / sc
    log_cdf = norm.logcdf(jnp.matmul(al, y.T))
    log_density = mvn.logpdf(y, mean=jnp.zeros(p), cov=sigma)
    # One log(2) per observation.
    n_obs = jnp.size(log_density)
    log_norm = n_obs * jnp.log(2.0)
    return -(log_norm + jnp.sum(log_density) + jnp.sum(log_cdf))
def log_normal_batch(W, sigma, Vs):
    """Evaluate ``logpdf`` for a batch of samples stacked along the last axis.

    ``W`` is flattened into the mean; ``Vs`` is collapsed to two dimensions
    (keeping its last axis) and transposed so that each row is one sample.
    The covariance is ``sigma**2`` times the identity.
    """
    mean = W.ravel()
    dim = mean.shape[0]
    n_batch = Vs.shape[-1]
    # After the transpose: one row per batch entry, one column per coordinate.
    samples = Vs.reshape((-1, n_batch)).T
    assert mean.shape[0] == samples.shape[1]
    covariance = sigma**2 * jnp.eye(dim)
    return logpdf(samples, mean, covariance)
def update(observation_function: Callable,
           observation_covariance: jnp.ndarray,
           predicted_state: MVNormalParameters,
           observation: jnp.ndarray,
           linearization_state: MVNormalParameters
           ) -> "Tuple[float, MVNormalParameters]":
    r"""
    Computes the sigma-point linearized Kalman update of :math:`x_t \mid y_t`

    Parameters
    ----------
    observation_function: callable :math:`h(x_t,\epsilon_t)\mapsto y_t`
        observation function of the state space model
    observation_covariance: (K,K) array
        observation_error :math:`\Sigma` fed to observation_function
    predicted_state: MVNormalParameters
        predicted approximate mv normal parameters of the filter :math:`x`
    observation: (K) array
        Observation :math:`y`
    linearization_state: MVNormalParameters
        state for the linearization of the update; when ``None`` the
        predicted state is used instead

    Returns
    -------
    loglikelihood: float
        normal log-density of the residual under the residual covariance
    updated_mvn_parameters: MVNormalParameters
        filtered state (its covariance is symmetrized before being returned)
    """
    if linearization_state is None:
        linearization_state = predicted_state

    # Push the sigma points of the linearization state through the
    # observation function and summarise them as a normal.
    sigma_points = get_sigma_points(linearization_state)
    obs_points = observation_function(sigma_points.points)
    obs_sigma_points = SigmaPoints(obs_points, sigma_points.wm,
                                   sigma_points.wc)
    obs_state = get_mv_normal_parameters(obs_sigma_points)
    cross_covariance = covariance_sigma_points(sigma_points,
                                               linearization_state.mean,
                                               obs_sigma_points,
                                               obs_state.mean)

    # NOTE(review): `sym_pos` is deprecated in recent SciPy/JAX releases in
    # favour of `assume_a="pos"` - confirm against the pinned version.
    H = jlinalg.solve(linearization_state.cov, cross_covariance,
                      sym_pos=True).T  # linearized observation function
    d = obs_state.mean - jnp.dot(
        H, linearization_state.mean)  # linearized observation offset

    residual_cov = H @ (predicted_state.cov - linearization_state.cov) @ H.T + \
        observation_covariance + obs_state.cov
    gain = jlinalg.solve(residual_cov, H @ predicted_state.cov).T

    predicted_observation = H @ predicted_state.mean + d
    residual = observation - predicted_observation

    mean = predicted_state.mean + gain @ residual
    cov = predicted_state.cov - gain @ residual_cov @ gain.T

    loglikelihood = multivariate_normal.logpdf(residual,
                                               jnp.zeros_like(residual),
                                               residual_cov)
    return loglikelihood, MVNormalParameters(mean, 0.5 * (cov + cov.T))
def ll_chol(pars, y):
    """Negative skew-normal-form log-likelihood with a factored covariance.

    ``pars`` is split into ``X`` (everything but the last ``p`` entries) and
    ``theta`` (the last ``p`` entries), where ``p = y.shape[-1]``. ``X`` is
    scattered into the upper triangle of a ``p x p`` matrix, which is then
    transposed to give a factor ``L``; the covariance is ``L @ L.T``.

    Each observation contributes
    ``log(2) + log phi(y; 0, sigma) + log Phi(al . y)`` with
    ``al = theta / sqrt(diag(sigma))``, so the normaliser of the summed
    log-likelihood is ``n * log(2)`` rather than the literal ``2``.

    Parameters
    ----------
    pars : array of length ``p * (p + 1) / 2 + p``
        Triangular-factor entries followed by ``theta``.
    y : (n, p) or (p,) array
        Observations.

    Returns
    -------
    scalar
        Negative of the summed log-likelihood.
    """
    p = y.shape[-1]
    X, theta = pars[:-p], pars[-p:]
    # `.at[...].set(...)` is the functional update that replaces the removed
    # `jax.ops.index_update` helper.
    factor = jnp.zeros(shape=(p, p)).at[jnp.triu_indices(p)].set(X).T
    sigma = jnp.matmul(factor, factor.T)
    sc = jnp.sqrt(jnp.diag(sigma))
    al = theta / sc
    log_cdf = norm.logcdf(jnp.matmul(al, y.T))
    log_density = mvn.logpdf(y, mean=jnp.zeros(p), cov=sigma)
    # One log(2) per observation.
    n_obs = jnp.size(log_density)
    log_norm = n_obs * jnp.log(2.0)
    return -(log_norm + jnp.sum(log_density) + jnp.sum(log_cdf))
def log_normal_gamma_prior(mu, tau, mu0=0.0, beta=1.0, gamma_shape=1.0,
                           gamma_rate=1.0):
    """Log-density of a normal prior on ``mu`` plus a gamma prior on ``tau``.

    Parameters
    ----------
    mu : array
        Location(s); the normal log-density is summed over its entries.
    tau : array or float
        Value(s) at which the gamma log-density is evaluated.
    mu0 : float or array, optional
        Mean of the normal prior (default ``0.0``).
    beta : float, optional
        Covariance passed to the normal log-density.
    gamma_shape : float, optional
        Shape ``a`` of the gamma prior.
    gamma_rate : float, optional
        Rate of the gamma prior; its reciprocal is passed as ``scale``.

    Returns
    -------
    array
        ``log p(mu) + log p(tau)``.
    """
    # Use the caller-supplied prior mean rather than a hard-coded zero.
    log_mu = multivariate_normal.logpdf(mu, mean=mu0, cov=beta).sum()
    log_tau = gamma.logpdf(tau, a=gamma_shape, scale=1.0 / gamma_rate)
    return log_mu + log_tau
def update(
        observation_function: Callable[[jnp.ndarray], jnp.ndarray],
        observation_covariance: jnp.ndarray, predicted: MVNormalParameters,
        observation: jnp.ndarray,
        linearization_point: jnp.ndarray) -> Tuple[float, MVNormalParameters]:
    r"""
    Computes the extended kalman filter linearization of :math:`x_t \mid y_t`

    Parameters
    ----------
    observation_function: callable :math:`h(x_t,\epsilon_t)\mapsto y_t`
        observation function of the state space model
    observation_covariance: (K,K) array
        observation_error :math:`\Sigma` fed to observation_function
    predicted: MVNormalParameters
        predicted state of the filter :math:`x`
    observation: (K) array
        Observation :math:`y`
    linearization_point: jnp.ndarray
        Where to compute the Jacobian; when ``None`` the predicted mean is
        used instead

    Returns
    -------
    loglikelihood: float
        Log-likelihood increment for observation
    updated_state: MVNormalParameters
        filtered state (its covariance is symmetrized before being returned)
    """
    if linearization_point is None:
        linearization_point = predicted.mean

    # First-order expansion of the observation function around the
    # linearization point, evaluated at the predicted mean.
    jac_x = jacfwd(observation_function, 0)(linearization_point)
    obs_mean = observation_function(linearization_point) + jnp.dot(
        jac_x, predicted.mean - linearization_point)
    residual = observation - obs_mean

    residual_covariance = jnp.dot(jac_x, jnp.dot(predicted.cov, jac_x.T))
    residual_covariance = residual_covariance + observation_covariance

    # NOTE(review): `sym_pos` is deprecated in recent SciPy/JAX releases in
    # favour of `assume_a="pos"` - confirm against the pinned version.
    gain = jnp.dot(predicted.cov,
                   jlag.solve(residual_covariance, jac_x, sym_pos=True).T)

    mean = predicted.mean + jnp.dot(gain, residual)
    cov = predicted.cov - jnp.dot(gain, jnp.dot(residual_covariance, gain.T))
    updated_state = MVNormalParameters(mean, 0.5 * (cov + cov.T))

    loglikelihood = multivariate_normal.logpdf(residual,
                                               jnp.zeros_like(residual),
                                               residual_covariance)
    return loglikelihood, updated_state
def test_n1_d1(self):
    """Check ``utils.gaussian_potential`` on a single one-dimensional point."""
    x = jnp.array([7.])
    m = jnp.array([1.])
    prec = jnp.array([[2.]])
    diag_prec = prec[0]
    cov = 1 / prec

    # Defaults, then an explicit mean, against the normal negative log-density.
    npt.assert_array_equal(
        utils.gaussian_potential(x),
        -multivariate_normal.logpdf(x, 0., 1.))
    npt.assert_array_equal(
        utils.gaussian_potential(x, m),
        -multivariate_normal.logpdf(x, m, 1.))

    # test diag
    npt.assert_array_equal(
        utils.gaussian_potential(x, prec=diag_prec),
        -multivariate_normal.logpdf(x, 0, cov))
    npt.assert_array_almost_equal(
        utils.gaussian_potential(x, m, diag_prec),
        -multivariate_normal.logpdf(x, m, cov),
        decimal=4)

    # test full (omits norm constant)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        npt.assert_array_equal(
            utils.gaussian_potential(x, prec=prec),
            0.5 * x**2 * prec)
        npt.assert_array_equal(
            utils.gaussian_potential(x, m, prec),
            0.5 * (x - m)**2 * prec)
def log_prob(self, inputs: np.ndarray) -> np.ndarray:
    """Log-density of ``inputs`` under a zero-mean, identity-covariance normal.

    Parameters
    ----------
    inputs : np.ndarray
        Points at which the log-density is evaluated.

    Returns
    -------
    np.ndarray
        Device array of shape (inputs.shape[0],).
    """
    dim = self.input_dim
    zero_mean = np.zeros(dim)
    unit_cov = np.identity(dim)
    return multivariate_normal.logpdf(x=inputs, mean=zero_mean, cov=unit_cov)
def test_normal(self):
    """``flows.Normal`` log-density must match a standard normal log-density."""
    seed = random.PRNGKey(0)
    inputs = random.uniform(seed, (20, 2), minval=-3.0, maxval=3.0)
    input_dim = inputs.shape[1]

    # Only the first half of the split key is consumed.
    init_key, _ = random.split(random.PRNGKey(0))
    params, log_pdf, _ = flows.Normal()(init_key, input_dim)

    expected = multivariate_normal.logpdf(
        inputs, np.zeros(input_dim), np.eye(input_dim))
    actual = log_pdf(params, inputs)
    self.assertTrue(np.allclose(actual, expected))

    # Shared shape check on a fresh instance.
    returns_correct_shape(self, flows.Normal())
def log_Euler_Maruyma_kernel(x_tm1, x_t, potential, potential_parameters, dt):
    """Log transition probability of the kernel.

    Parameters
    ----------
    x_tm1 : jnp.array(N)
        Position at the previous iteration.
    x_t : jnp.array(N)
        Position at the current iteration.
    potential : function
        Potential function.
    potential_parameters : jnp.array(Q)
        Second argument to the potential function.
    dt : float
        Time increment.

    Returns
    -------
    logp : float
        Normal log-density of ``x_t`` under the mean and covariance returned
        by ``EL_mu_sigma``.
    """
    from jax.scipy.stats import multivariate_normal as mvn

    kernel_mean, kernel_cov = EL_mu_sigma(
        x_tm1, potential, dt, potential_parameters)
    return mvn.logpdf(x_t, kernel_mean, kernel_cov)
def log_prior_mu(mu, mu0, sigma0):
    """Summed normal log-density of ``mu`` with mean ``mu0`` and cov ``sigma0``."""
    per_point = multivariate_normal.logpdf(mu, mean=mu0, cov=sigma0)
    return per_point.sum()
def calculate_prior(x, mu, cov_mat, theta):
    """Normal log-density of ``x`` with mean ``mu`` and covariance ``cov_mat``.

    ``theta`` is accepted but does not enter the computation.
    """
    log_density = multivariate_normal.logpdf(x, mu, cov_mat)
    return log_density
def _log_prior_ss(ss):
    """Negated normal log-density of ``ss`` (mean ``zeros2``, covariance 1.0)."""
    log_density = multivariate_normal.logpdf(ss, mean=zeros2, cov=1.0)
    return -log_density
def log_normal(W, sigma, V):
    """Evaluate ``logpdf`` of flattened ``V`` around flattened ``W``.

    The covariance is ``sigma**2`` times the identity, sized to match the
    flattened ``W``.
    """
    mean = W.ravel()
    sample = V.ravel()
    covariance = sigma**2 * jnp.eye(mean.shape[0])
    return logpdf(sample, mean, covariance)
def log_pdf(params, inputs):
    """Normal log-density of ``inputs``.

    ``mean`` and ``covariance`` come from the enclosing scope; ``params`` is
    accepted but not read here.
    """
    log_density = multivariate_normal.logpdf(inputs, mean, covariance)
    return log_density
def test_n1_d5(self):
    """Check ``utils.gaussian_potential`` on a single five-dimensional point."""
    x = jnp.ones(5)
    origin = jnp.zeros_like(x)

    def neg_logpdf(mean, cov):
        # Reference value: negated normal log-density at ``x``.
        return -multivariate_normal.logpdf(x, mean, cov)

    def check_full_precision(mean, prec, det_prec):
        # Shared checks for a full precision matrix.
        cov = jnp.linalg.inv(prec)
        centred = x - mean

        # test full (omits norm constant)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            npt.assert_array_equal(
                utils.gaussian_potential(x, prec=prec),
                0.5 * x.T @ prec @ x)
            npt.assert_array_equal(
                utils.gaussian_potential(x, mean, prec),
                0.5 * centred.T @ prec @ centred)

        # test full with det
        npt.assert_array_almost_equal(
            utils.gaussian_potential(x, prec=prec, det_prec=det_prec),
            neg_logpdf(origin, cov),
            decimal=5)
        npt.assert_array_almost_equal(
            utils.gaussian_potential(x, mean, prec=prec, det_prec=det_prec),
            neg_logpdf(mean, cov),
            decimal=5)

    # Default arguments.
    npt.assert_array_equal(
        utils.gaussian_potential(x), neg_logpdf(origin, 1.))

    # Scalar mean.
    m = 3.
    npt.assert_array_equal(
        utils.gaussian_potential(x, m),
        neg_logpdf(m * jnp.ones_like(x), 1.))

    # Vector mean.
    m = jnp.ones(5) * 3.
    npt.assert_array_equal(
        utils.gaussian_potential(x, m), neg_logpdf(m, 1.))

    # diagonal precision
    prec = jnp.eye(5) * 2
    cov = jnp.linalg.inv(prec)
    # test diag
    npt.assert_array_almost_equal(
        utils.gaussian_potential(x, prec=jnp.diag(prec)),
        neg_logpdf(origin, cov),
        decimal=5)
    npt.assert_array_almost_equal(
        utils.gaussian_potential(x, m, prec=jnp.diag(prec), det_prec=2**5),
        neg_logpdf(m, cov),
        decimal=5)
    check_full_precision(m, prec, 2**5)

    # non-diagonal precision
    sqrt_prec = jnp.arange(25).reshape(5, 5) / 100 + jnp.eye(5)
    prec = sqrt_prec @ sqrt_prec.T
    check_full_precision(m, prec, jnp.linalg.det(prec))
def _log_prior_mu(mu):
    """Negated normal log-density of ``mu`` (mean ``zeros2``, covariance 1.0)."""
    log_density = multivariate_normal.logpdf(mu, mean=zeros2, cov=1.0)
    return -log_density