Example #1
    def test_n5_d2(self):
        x = jnp.ones((5, 2))
        npt.assert_array_equal(
            utils.gaussian_potential(x),
            jnp.repeat(
                -multivariate_normal.logpdf(x[0], jnp.zeros(x.shape[-1]), 1.),
                5))

        m = 3.
        npt.assert_array_equal(
            utils.gaussian_potential(x, m),
            jnp.repeat(
                -multivariate_normal.logpdf(x[0], m * jnp.ones(x.shape[-1]),
                                            1.), 5))

        m = jnp.ones(2) * 3.
        npt.assert_array_equal(
            utils.gaussian_potential(x, m),
            jnp.repeat(-multivariate_normal.logpdf(x[0], m, 1.), 5))

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            sqrt_prec = jnp.array([[5., 0.], [2., 3.]])
            npt.assert_array_equal(
                utils.gaussian_potential(x, sqrt_prec=sqrt_prec),
                jnp.repeat(0.5 * x[0].T @ sqrt_prec @ sqrt_prec.T @ x[0], 5))
            npt.assert_array_equal(
                utils.gaussian_potential(x, m, sqrt_prec=sqrt_prec),
                jnp.repeat(
                    0.5 * (x[0] - m).T @ sqrt_prec @ sqrt_prec.T @ (x[0] - m),
                    5))
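A standalone check of the identity these tests exercise, using only jax.scipy (utils.gaussian_potential is project-specific; this sketch assumes the standard-normal potential 0.5 * ||x||^2):

import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

x = jnp.ones((5, 2))
potential = 0.5 * jnp.sum(x**2, axis=-1)
logpdf = multivariate_normal.logpdf(x, jnp.zeros(2), jnp.eye(2))
# -logpdf equals the potential plus the constant 0.5 * d * log(2*pi); d = 2 here
assert jnp.allclose(potential, -logpdf - jnp.log(2 * jnp.pi))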
Example #2
def loss_one_pair_with_prior(mu_i, mu_j, s_i, s_j, D, n_components):
    log_prior = 0.0
    log_prior = log_prior + multivariate_normal.logpdf(
        mu_i, mean=zeros2, cov=1.0)
    log_prior = log_prior + multivariate_normal.logpdf(
        mu_j, mean=zeros2, cov=1.0)
    return loss_one_pair(mu_i, mu_j, s_i, s_j, D, n_components) - log_prior
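The prior term in isolation; zeros2 is not shown in the original, so its definition below is an assumption (loss_one_pair is likewise defined elsewhere in the source repo):

import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

zeros2 = jnp.zeros(2)  # assumed definition of the module-level constant
mu_i = jnp.array([0.5, -0.5])
# a scalar cov with a vector mean is treated as an isotropic covariance
log_prior = multivariate_normal.logpdf(mu_i, mean=zeros2, cov=1.0)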
Example #3
def log_driven_Langevin_kernel(x_tm1, x_t, potential, dt, A_function, b_function, potential_parameter, A_parameter, b_parameter):
    """
    Compute the log of the driven Langevin kernel transition probability.
    arguments
        x_tm1 : jnp.array(N)
            previous iteration position (of dimension N)
        x_t : jnp.array(N)
            current iteration position (of dimension N)
        potential : function
            potential function
        dt : float
            time increment
        A_function : function
            variance-controlling function
        b_function : function
            mean-controlling function
        potential_parameter : jnp.array(Q)
            second argument to potential function
        A_parameter :  jnp.array(R)
            argument to A_function
        b_parameter : jnp.array(S)
            argument to b_function

    returns
        logp : float
            log probability
    """
    from jax.scipy.stats import multivariate_normal as mvn
    A, b, f, theta = driven_Langevin_parameters(x_tm1, potential, dt, A_function, b_function, potential_parameter, A_parameter, b_parameter)
    mu, cov = driven_mu_cov(b, f, theta, dt)
    logp = mvn.logpdf(x_t, mu, cov)
    return logp
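Example #4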
 def log_pdf(params, inputs):
     cluster_lls = []
     for log_weight, mean, cov in zip(np.log(weights), means,
                                      covariances):
         cluster_lls.append(
             log_weight + multivariate_normal.logpdf(inputs, mean, cov))
     return logsumexp(np.vstack(cluster_lls), axis=0)
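A self-contained version of the same mixture computation; the weights, means, and covariances below are illustrative stand-ins, and np is jax.numpy as in the snippet above:

import jax.numpy as np
from jax.scipy.special import logsumexp
from jax.scipy.stats import multivariate_normal

weights = np.array([0.3, 0.7])
means = np.array([[0.0, 0.0], [2.0, 2.0]])
covariances = np.array([np.eye(2), 0.5 * np.eye(2)])
inputs = np.zeros((4, 2))

cluster_lls = [
    log_weight + multivariate_normal.logpdf(inputs, mean, cov)
    for log_weight, mean, cov in zip(np.log(weights), means, covariances)
]
log_pdfs = logsumexp(np.vstack(cluster_lls), axis=0)  # shape (4,)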
Example #5
 def logpdf(self, z):
     """Compute the logpdf from sample z."""
     capital_phi = norm.logcdf(jnp.matmul(self.alpha, (z - self.loc).T))
     small_phi = mvn.logpdf(z - self.loc,
                            mean=jnp.zeros(shape=(self.k,)),
                            cov=self.cov)
     # the skew-normal density is 2 * phi(z) * Phi(alpha . z), so the
     # constant enters the log-density as log(2), not 2
     return jnp.log(2.) + small_phi + capital_phi
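Sanity check for the corrected normalising constant: with alpha = 0 the skew normal reduces to a plain Gaussian, since log(2) + norm.logcdf(0) = log(2) - log(2) = 0 (a standalone sketch, independent of the class above):

import jax.numpy as jnp
from jax.scipy.stats import norm, multivariate_normal as mvn

z = jnp.array([[0.3, -1.2]])
alpha, loc, cov = jnp.zeros(2), jnp.zeros(2), jnp.eye(2)
capital_phi = norm.logcdf(jnp.matmul(alpha, (z - loc).T))
small_phi = mvn.logpdf(z - loc, jnp.zeros(2), cov)
assert jnp.allclose(jnp.log(2.) + small_phi + capital_phi,
                    mvn.logpdf(z, loc, cov))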
Example #6
def loss_MAP(mu,
             tau_unc,
             D,
             i0,
             i1,
             mu0,
             beta=1.0,
             gamma_shape=1.0,
             gamma_rate=1.0,
             alpha=1.0):
    mu_i, mu_j = mu[i0], mu[i1]
    tau = EPSILON + jax.nn.softplus(SCALE * tau_unc)
    tau_i, tau_j = tau[i0], tau[i1]

    tau_ij_inv = tau_i * tau_j / (tau_i + tau_j)
    log_tau_ij_inv = jnp.log(tau_i) + jnp.log(tau_j) - jnp.log(tau_i + tau_j)

    d = jnp.linalg.norm(mu_i - mu_j, ord=2, axis=1, keepdims=True)

    log_llh = (jnp.log(D) + log_tau_ij_inv - 0.5 * tau_ij_inv * (D - d)**2 +
               jnp.log(i0e(tau_ij_inv * D * d)))

    # priors are applied to all points, not just the indexed pairs
    log_mu = multivariate_normal.logpdf(mu, mean=mu0, cov=beta * jnp.eye(2))
    log_tau = gamma.logpdf(tau, a=gamma_shape, scale=1.0 / gamma_rate)

    return jnp.sum(log_llh) + jnp.sum(log_mu) + jnp.sum(log_tau)
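A hypothetical invocation of loss_MAP; EPSILON and SCALE are module-level constants not shown in the original, so the values below are assumptions, and tau_unc carries a trailing unit axis so it broadcasts against D of shape (n_pairs, 1):

import jax
import jax.numpy as jnp
from jax.scipy.special import i0e
from jax.scipy.stats import gamma, multivariate_normal

EPSILON, SCALE = 1e-5, 1.0  # assumed values

mu = jax.random.normal(jax.random.PRNGKey(0), (6, 2))  # six 2-d embeddings
tau_unc = jnp.zeros((6, 1))  # unconstrained precisions
i0, i1 = jnp.array([0, 1, 2]), jnp.array([3, 4, 5])  # indexed pairs
D = jnp.ones((3, 1))  # observed pairwise distances
value = loss_MAP(mu, tau_unc, D, i0, i1, mu0=jnp.zeros(2))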
Example #7
def ll(sigma, theta, y):
    p = y.shape[-1]
    sc = jnp.sqrt(jnp.diag(sigma))
    al = jnp.einsum('i,i->i', 1 / sc, theta)
    capital_phi = jnp.sum(norm.logcdf(jnp.matmul(al, y.T)))
    small_phi = jnp.sum(mvn.logpdf(y, mean=jnp.zeros(p), cov=sigma))
    # one log(2) normalising constant per observation (skew-normal density)
    return -(y.shape[0] * jnp.log(2.) + small_phi + capital_phi)
Example #8
def log_normal_batch(W, sigma, Vs):
    mu = W.ravel()
    # flatten all batch dimensions: Vs_ has shape (M, N) with N the event dim
    # (the stray .T contradicted both the comment and the assert below)
    Vs_ = Vs.reshape((-1, Vs.shape[-1]))
    assert mu.shape[0] == Vs_.shape[1]
    d = mu.shape[0]
    cov = sigma**2 * jnp.eye(d)
    return logpdf(Vs_, mu, cov)
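A quick shape check, assuming logpdf is jax.scipy.stats.multivariate_normal.logpdf imported under that name (as the call signature suggests):

import jax.numpy as jnp
from jax.scipy.stats.multivariate_normal import logpdf

W = jnp.ones((2, 3))       # weights flatten to a 6-d mean
Vs = jnp.zeros((4, 7, 6))  # any batch shape ending in the event dim 6
out = log_normal_batch(W, 0.5, Vs)  # shape (28,): one logpdf per sample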
Example #9
def update(observation_function: Callable, observation_covariance: jnp.ndarray,
           predicted_state: MVNormalParameters, observation: jnp.ndarray,
           linearization_state: MVNormalParameters) -> Tuple[float, MVNormalParameters]:
    """ Computes the extended kalman filter linearization of :math:`x_t \mid y_t`

    Parameters
    ----------
    observation_function: callable :math:`h(x_t,\epsilon_t)\mapsto y_t`
        observation function of the state space model
    observation_covariance: (K,K) array
        covariance :math:`\Sigma` of the observation error fed to observation_function
    predicted_state: MVNormalParameters
        predicted approximate mv normal parameters of the filter :math:`x`
    observation: (K) array
        Observation :math:`y`
    linearization_state: MVNormalParameters
        state for the linearization of the update

    Returns
    -------
    loglikelihood: float
        Log-likelihood increment for observation
    updated_mvn_parameters: MVNormalParameters
        filtered state
    """
    if linearization_state is None:
        linearization_state = predicted_state
    sigma_points = get_sigma_points(linearization_state)
    obs_points = observation_function(sigma_points.points)
    obs_sigma_points = SigmaPoints(obs_points, sigma_points.wm,
                                   sigma_points.wc)

    obs_state = get_mv_normal_parameters(obs_sigma_points)
    cross_covariance = covariance_sigma_points(sigma_points,
                                               linearization_state.mean,
                                               obs_sigma_points,
                                               obs_state.mean)

    H = jlinalg.solve(linearization_state.cov, cross_covariance,
                      assume_a='pos').T  # linearized observation function

    d = obs_state.mean - jnp.dot(
        H, linearization_state.mean)  # linearized observation offset

    residual_cov = H @ (predicted_state.cov - linearization_state.cov) @ H.T + \
                   observation_covariance + obs_state.cov

    gain = jlinalg.solve(residual_cov, H @ predicted_state.cov).T

    predicted_observation = H @ predicted_state.mean + d

    residual = observation - predicted_observation
    mean = predicted_state.mean + gain @ residual
    cov = predicted_state.cov - gain @ residual_cov @ gain.T
    loglikelihood = multivariate_normal.logpdf(residual,
                                               jnp.zeros_like(residual),
                                               residual_cov)

    return loglikelihood, MVNormalParameters(mean, 0.5 * (cov + cov.T))
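Example #10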
def ll_chol(pars, y):
    p = y.shape[-1]
    X, theta = pars[:-p], pars[-p:]
    # build the Cholesky factor from the packed upper-triangular entries
    # (.at[...].set(...) replaces the removed jax.ops.index_update)
    sigma = jnp.zeros(shape=(p, p)).at[jnp.triu_indices(p)].set(X).T
    sigma = jnp.matmul(sigma, sigma.T)
    sc = jnp.sqrt(jnp.diag(sigma))
    al = jnp.einsum('i,i->i', 1 / sc, theta)
    capital_phi = jnp.sum(norm.logcdf(jnp.matmul(al, y.T)))
    small_phi = jnp.sum(mvn.logpdf(y, mean=jnp.zeros(p), cov=sigma))
    # one log(2) normalising constant per observation (skew-normal density)
    return -(y.shape[0] * jnp.log(2.) + small_phi + capital_phi)
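A usage sketch: pars packs the p*(p+1)/2 upper-triangular Cholesky entries followed by theta; with theta = 0 the skew term contributes n * log(1/2), cancelling the n * log(2) constant, so the result reduces to a plain Gaussian negative log-likelihood:

import jax.numpy as jnp
from jax.scipy.stats import norm, multivariate_normal as mvn

p = 2
y = jnp.array([[0.1, -0.3], [0.7, 0.2]])
pars = jnp.concatenate([jnp.array([1.0, 0.2, 1.0]), jnp.zeros(p)])
nll = ll_chol(pars, y)
sigma_true = jnp.array([[1.0, 0.2], [0.2, 1.04]])  # L @ L.T for the entries above
assert jnp.allclose(nll, -jnp.sum(mvn.logpdf(y, jnp.zeros(p), sigma_true)))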
Example #11
def log_normal_gamma_prior(mu,
                           tau,
                           mu0=0.0,
                           beta=1.0,
                           gamma_shape=1.0,
                           gamma_rate=1.0):
    # scalar mean/cov broadcast over both coordinates of each point;
    # .sum() then reduces over the batch (note: the original passed
    # mean=0.0, leaving the mu0 argument unused)
    log_mu = multivariate_normal.logpdf(mu, mean=mu0, cov=beta).sum()
    log_tau = gamma.logpdf(tau, a=gamma_shape, scale=1.0 / gamma_rate)
    return log_mu + log_tau
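Note that jax.scipy's multivariate_normal.logpdf accepts scalar mean and cov, treating the coordinates as i.i.d. and summing the univariate log-densities over the last axis; a quick equivalence check:

import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal, norm

mu = jnp.array([[0.3, -0.1], [1.0, 2.0]])
lhs = multivariate_normal.logpdf(mu, mean=0.0, cov=1.0)  # shape (2,)
rhs = norm.logpdf(mu).sum(-1)
assert jnp.allclose(lhs, rhs)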
Example #12
def update(
        observation_function: Callable[[jnp.ndarray], jnp.ndarray],
        observation_covariance: jnp.ndarray, predicted: MVNormalParameters,
        observation: jnp.ndarray,
        linearization_point: jnp.ndarray) -> Tuple[float, MVNormalParameters]:
    """ Computes the extended kalman filter linearization of :math:`x_t \mid y_t`

    Parameters
    ----------
    observation_function: callable :math:`h(x_t,\epsilon_t)\mapsto y_t`
        observation function of the state space model
    observation_covariance: (K,K) array
        covariance :math:`\Sigma` of the observation error fed to observation_function
    predicted: MVNormalParameters
        predicted state of the filter :math:`x`
    observation: (K) array
        Observation :math:`y`
    linearization_point: jnp.ndarray
        Where to compute the Jacobian

    Returns
    -------
    loglikelihood: float
        Log-likelihood increment for observation
    updated_state: MVNormalParameters
        filtered state
    """
    if linearization_point is None:
        linearization_point = predicted.mean
    jac_x = jacfwd(observation_function, 0)(linearization_point)

    obs_mean = observation_function(linearization_point) + jnp.dot(
        jac_x, predicted.mean - linearization_point)

    residual = observation - obs_mean
    residual_covariance = jnp.dot(jac_x, jnp.dot(predicted.cov, jac_x.T))
    residual_covariance = residual_covariance + observation_covariance

    gain = jnp.dot(predicted.cov,
                   jlag.solve(residual_covariance, jac_x, assume_a='pos').T)

    mean = predicted.mean + jnp.dot(gain, residual)
    cov = predicted.cov - jnp.dot(gain, jnp.dot(residual_covariance, gain.T))
    updated_state = MVNormalParameters(mean, 0.5 * (cov + cov.T))

    loglikelihood = multivariate_normal.logpdf(residual,
                                               jnp.zeros_like(residual),
                                               residual_covariance)
    return loglikelihood, updated_state
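A hypothetical end-to-end call, assuming MVNormalParameters is a (mean, cov) namedtuple and that jacfwd, jlag (jax.scipy.linalg), and multivariate_normal are in scope at module level, as the body suggests:

from collections import namedtuple
from typing import Callable, Tuple
import jax.numpy as jnp
from jax import jacfwd
import jax.scipy.linalg as jlag
from jax.scipy.stats import multivariate_normal

MVNormalParameters = namedtuple("MVNormalParameters", ["mean", "cov"])

h = lambda x: jnp.array([x[0] ** 2])  # nonlinear scalar observation
predicted = MVNormalParameters(jnp.array([1.0, 0.0]), jnp.eye(2))
ll, updated = update(h, 0.1 * jnp.eye(1), predicted,
                     jnp.array([1.2]), linearization_point=None)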
Example #13
    def test_n1_d1(self):
        x = jnp.array([7.])
        npt.assert_array_equal(utils.gaussian_potential(x),
                               -multivariate_normal.logpdf(x, 0., 1.))

        m = jnp.array([1.])
        npt.assert_array_equal(utils.gaussian_potential(x, m),
                               -multivariate_normal.logpdf(x, m, 1.))

        prec = jnp.array([[2.]])
        # test diag
        npt.assert_array_equal(utils.gaussian_potential(x, prec=prec[0]),
                               -multivariate_normal.logpdf(x, 0, 1 / prec))
        npt.assert_array_almost_equal(
            utils.gaussian_potential(x, m, prec[0]),
            -multivariate_normal.logpdf(x, m, 1 / prec),
            decimal=4)
        # test full (omits norm constant)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            npt.assert_array_equal(utils.gaussian_potential(x, prec=prec),
                                   0.5 * x**2 * prec)
            npt.assert_array_equal(utils.gaussian_potential(x, m, prec),
                                   0.5 * (x - m)**2 * prec)
Example #14
    def log_prob(self, inputs: np.ndarray) -> np.ndarray:
        """Calculates log probability density of inputs.

        Parameters
        ----------
        inputs : np.ndarray
            Input data for which log probability density is calculated.

        Returns
        -------
        np.ndarray
            Device array of shape (inputs.shape[0],).
        """
        return multivariate_normal.logpdf(
            x=inputs,
            mean=np.zeros(self.input_dim),
            cov=np.identity(self.input_dim),
        )
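A standalone equivalent of the method body (np here is jax.numpy, and self.input_dim is fixed to 3 for illustration):

import jax.numpy as np
from jax.scipy.stats import multivariate_normal

input_dim = 3
inputs = np.ones((5, input_dim))
log_probs = multivariate_normal.logpdf(
    x=inputs, mean=np.zeros(input_dim), cov=np.identity(input_dim))
# log_probs.shape == (5,)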
Example #15
    def test_normal(self):
        inputs = random.uniform(random.PRNGKey(0), (20, 2),
                                minval=-3.0,
                                maxval=3.0)
        input_dim = inputs.shape[1]
        init_key, sample_key = random.split(random.PRNGKey(0))

        init_fun = flows.Normal()
        params, log_pdf, sample = init_fun(init_key, input_dim)
        log_pdfs = log_pdf(params, inputs)

        mean = np.zeros(input_dim)
        covariance = np.eye(input_dim)
        true_log_pdfs = multivariate_normal.logpdf(inputs, mean, covariance)

        self.assertTrue(np.allclose(log_pdfs, true_log_pdfs))

        for test in (returns_correct_shape, ):
            test(self, flows.Normal())
Example #16
def log_Euler_Maruyma_kernel(x_tm1, x_t, potential, potential_parameters, dt):
    """
    Compute the log probability of the Euler-Maruyama transition kernel.
    arguments
        x_tm1 : jnp.array(N)
            previous iteration position (of dimension N)
        x_t : jnp.array(N)
            current iteration position (of dimension N)
        potential : function
            potential function
        potential_parameters : jnp.array(Q)
            second argument to potential function
        dt : float
            time increment

    returns
        logp : float
            the log probability of the kernel
    """
    from jax.scipy.stats import multivariate_normal as mvn
    mu, cov = EL_mu_sigma(x_tm1, potential, dt, potential_parameters)
    logp = mvn.logpdf(x_t, mu, cov)
    return logp
Example #17
def log_prior_mu(mu, mu0, sigma0):
    return multivariate_normal.logpdf(mu, mean=mu0, cov=sigma0).sum()
Example #18
def calculate_prior(x, mu, cov_mat, theta):
    # theta is unused here; kept for signature compatibility
    return multivariate_normal.logpdf(x, mu, cov_mat)
Example #19
def _log_prior_ss(ss):
    return -multivariate_normal.logpdf(ss, mean=zeros2, cov=1.0)
Example #20
def log_normal(W, sigma, V):
    mu = W.ravel()
    V_ = V.ravel()
    d = mu.shape[0]
    cov = sigma**2 * jnp.eye(d)
    return logpdf(V_, mu, cov)
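Example #21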
 def log_pdf(params, inputs):
     return multivariate_normal.logpdf(inputs, mean, covariance)
Example #22
    def test_n1_d5(self):
        x = jnp.ones(5)
        npt.assert_array_equal(
            utils.gaussian_potential(x),
            -multivariate_normal.logpdf(x, jnp.zeros_like(x), 1.))

        m = 3.
        npt.assert_array_equal(
            utils.gaussian_potential(x, m),
            -multivariate_normal.logpdf(x, m * jnp.ones_like(x), 1.))

        m = jnp.ones(5) * 3.
        npt.assert_array_equal(utils.gaussian_potential(x, m),
                               -multivariate_normal.logpdf(x, m, 1.))

        # diagonal precision
        prec = jnp.eye(5) * 2
        # test diag
        npt.assert_array_almost_equal(
            utils.gaussian_potential(x, prec=jnp.diag(prec)),
            -multivariate_normal.logpdf(x, jnp.zeros_like(x),
                                        jnp.linalg.inv(prec)),
            decimal=5)
        npt.assert_array_almost_equal(
            utils.gaussian_potential(x, m, prec=jnp.diag(prec), det_prec=2**5),
            -multivariate_normal.logpdf(x, m, jnp.linalg.inv(prec)),
            decimal=5)
        # test full (omits norm constant)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            npt.assert_array_equal(utils.gaussian_potential(x, prec=prec),
                                   0.5 * x.T @ prec @ x)
            npt.assert_array_equal(utils.gaussian_potential(x, m, prec),
                                   0.5 * (x - m).T @ prec @ (x - m))
        # test full with det
        npt.assert_array_almost_equal(
            utils.gaussian_potential(x, prec=prec, det_prec=2**5),
            -multivariate_normal.logpdf(x, jnp.zeros_like(x),
                                        jnp.linalg.inv(prec)),
            decimal=5)
        npt.assert_array_almost_equal(
            utils.gaussian_potential(x, m, prec=prec, det_prec=2**5),
            -multivariate_normal.logpdf(x, m, jnp.linalg.inv(prec)),
            decimal=5)

        # non-diagonal precision
        sqrt_prec = jnp.arange(25).reshape(5, 5) / 100 + jnp.eye(5)
        prec = sqrt_prec @ sqrt_prec.T
        # test full (omits norm constant)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            npt.assert_array_equal(utils.gaussian_potential(x, prec=prec),
                                   0.5 * x.T @ prec @ x)
            npt.assert_array_equal(utils.gaussian_potential(x, m, prec),
                                   0.5 * (x - m).T @ prec @ (x - m))
        # test full with det
        npt.assert_array_almost_equal(
            utils.gaussian_potential(x,
                                     prec=prec,
                                     det_prec=jnp.linalg.det(prec)),
            -multivariate_normal.logpdf(x, jnp.zeros_like(x),
                                        jnp.linalg.inv(prec)),
            decimal=5)
        npt.assert_array_almost_equal(
            utils.gaussian_potential(x,
                                     m,
                                     prec=prec,
                                     det_prec=jnp.linalg.det(prec)),
            -multivariate_normal.logpdf(x, m, jnp.linalg.inv(prec)),
            decimal=5)
Example #23
def _log_prior_mu(mu):
    return -multivariate_normal.logpdf(mu, mean=zeros2, cov=1.0)