Example #1
def test_log_prob_LKJCholesky_uniform(dimension):
    # When concentration=1, the distribution of correlation matrices is uniform.
    # We will test that fact here.
    d = dist.LKJCholesky(dimension=dimension, concentration=1)
    N = 5
    corr_log_prob = []
    for i in range(N):
        sample = d.sample(random.PRNGKey(i))
        log_prob = d.log_prob(sample)
        sample_tril = matrix_to_tril_vec(sample, diagonal=-1)
        cholesky_to_corr_jac = onp.linalg.slogdet(
            jax.jacobian(_tril_cholesky_to_tril_corr)(sample_tril))[1]
        corr_log_prob.append(log_prob - cholesky_to_corr_jac)

    corr_log_prob = np.array(corr_log_prob)
    # the log density should be constant across samples (uniform distribution)
    assert_allclose(corr_log_prob,
                    np.broadcast_to(corr_log_prob[0], corr_log_prob.shape),
                    rtol=1e-6)

    if dimension == 2:
        # When concentration = 1, LKJ gives a uniform distribution over correlation
        # matrices; for dimension = 2 the single off-diagonal correlation is
        # Uniform(-1, 1), whose density is 0.5. In addition, the Jacobian of the
        # transform from Cholesky factor to correlation matrix is 1 (hence its log
        # value is 0) because the off-diagonal lower-triangular element does not
        # change under the transform. So target_log_prob = log(0.5).
        assert_allclose(corr_log_prob[0], np.log(0.5), rtol=1e-6)
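matrix_to_tril_vec and _tril_cholesky_to_tril_corr are helpers from the surrounding test module and are not shown above. As a rough sketch of what _tril_cholesky_to_tril_corr is assumed to do (together with a hypothetical vec_to_tril_matrix inverse of matrix_to_tril_vec): rebuild the Cholesky factor from its strict lower triangle, form the correlation matrix, and return its strict lower triangle.

def _tril_cholesky_to_tril_corr(x):
    # rebuild the strict lower triangle of the Cholesky factor from the vector x
    w = vec_to_tril_matrix(x, diagonal=-1)
    # rows of a correlation Cholesky factor have unit norm, so the diagonal
    # entries are determined by the off-diagonal ones
    diag = np.sqrt(1 - np.sum(w ** 2, axis=-1))
    cholesky = w + np.expand_dims(diag, axis=-1) * np.identity(w.shape[-1])
    corr = np.matmul(cholesky, cholesky.T)
    # return only the strict lower triangle of the correlation matrix
    return matrix_to_tril_vec(corr, diagonal=-1)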
Example #2
def test_log_prob_LKJCholesky(dimension, concentration):
    # We test against the fact that LKJCorrCholesky can be seen as a
    # TransformedDistribution whose base distribution is the distribution of
    # partial correlations in the C-vine method (modulo an affine transform to
    # change the domain from (0, 1) to (-1, 1)) and whose transform is a signed
    # stick-breaking process.
    d = dist.LKJCholesky(dimension, concentration, sample_method="cvine")

    beta_sample = d._beta.sample(random.PRNGKey(0))
    beta_log_prob = np.sum(d._beta.log_prob(beta_sample))
    partial_correlation = 2 * beta_sample - 1
    affine_logdet = beta_sample.shape[-1] * np.log(2)
    sample = signed_stick_breaking_tril(partial_correlation)

    # compute signed stick breaking logdet
    inv_tanh = lambda t: np.log((1 + t) / (1 - t)) / 2  # noqa: E731
    inv_tanh_logdet = np.sum(np.log(vmap(grad(inv_tanh))(partial_correlation)))
    unconstrained = inv_tanh(partial_correlation)
    corr_cholesky_logdet = biject_to(
        constraints.corr_cholesky).log_abs_det_jacobian(
            unconstrained,
            sample,
        )
    signed_stick_breaking_logdet = corr_cholesky_logdet + inv_tanh_logdet

    actual_log_prob = d.log_prob(sample)
    expected_log_prob = beta_log_prob - affine_logdet - signed_stick_breaking_logdet
    assert_allclose(actual_log_prob, expected_log_prob, rtol=1e-5)

    assert_allclose(jax.jit(d.log_prob)(sample), d.log_prob(sample), atol=1e-7)
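Since biject_to(constraints.corr_cholesky) is exactly tanh followed by signed stick-breaking, the test could be extended with one more sanity check that the bijector maps the unconstrained values back to sample (a sketch reusing the locals above):

    # the corr_cholesky bijector should reproduce `sample` from `unconstrained`
    transform = biject_to(constraints.corr_cholesky)
    assert_allclose(transform(unconstrained), sample, rtol=1e-6)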
Example #3
def twoh_c_kf(T=None, T_forecast=15, obs=None):
    """Define a Kalman filter with two hidden variates."""
    T = len(obs) if T is None else T

    # Define priors over beta, tau, sigma, z_1 (keep the shapes in mind)
    # W = numpyro.sample(name="W", fn=dist.Normal(loc=jnp.zeros((2, 4)), scale=jnp.ones((2, 4))))
    beta = numpyro.sample(
        name="beta", fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2))
    )
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=0.1))
    z_prev = numpyro.sample(
        name="z_1", fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2))
    )
    # Define LKJ prior
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(2, 10.0))
    Sigma_lower = jnp.matmul(
        jnp.diag(jnp.sqrt(tau)), L_Omega
    )  # lower Cholesky factor of the covariance matrix
    noises = numpyro.sample(
        "noises",
        fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower),
        sample_shape=(T + T_forecast,),
    )
    # Propagate the dynamics forward using jax.lax.scan
    carry = (beta, z_prev, tau)
    z_collection = [z_prev]
    carry, zs_exp = lax.scan(f, carry, noises, T + T_forecast)
    z_collection = jnp.concatenate((jnp.array(z_collection), zs_exp), axis=0)

    c = numpyro.sample(
        name="c", fn=dist.Normal(loc=jnp.zeros((2, 1)), scale=jnp.ones((2, 1)))
    )
    obs_mean = jnp.dot(z_collection[:T, :], c).squeeze()
    pred_mean = jnp.dot(z_collection[T:, :], c).squeeze()

    # Sample the observed y (y_obs)
    numpyro.sample(name="y_obs", fn=dist.Normal(loc=obs_mean, scale=sigma), obs=obs)
    numpyro.sample(name="y_pred", fn=dist.Normal(loc=pred_mean, scale=sigma), obs=None)
Example #4
def multivariate_kf(T=None, T_forecast=15, obs=None):
    """Define Kalman Filter in a multivariate fashion.

    The "time-series" are correlated. To define these relationships in
    a efficient manner, the covarianze matrix of h_t (or, equivalently, the
    noises) is drown from a Cholesky decomposed matrix.

    Parameters
    ----------
    T:  int
    T_forecast: int
    obs: np.array
       observed variable (infected, deaths...)

    """
    T = len(obs) if T is None else T
    beta = numpyro.sample(
        name="beta", fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2))
    )
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=0.1))
    z_prev = numpyro.sample(
        name="z_1", fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2))
    )
    # Define LKJ prior
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(2, 10.0))
    Sigma_lower = jnp.matmul(
        jnp.diag(jnp.sqrt(tau)), L_Omega
    )  # lower cholesky factor of the covariance matrix
    noises = numpyro.sample(
        "noises",
        fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower),
        sample_shape=(T + T_forecast - 1,),
    )

    # Propagate the dynamics forward using jax.lax.scan
    carry = (beta, z_prev, tau)
    z_collection = [z_prev]
    carry, zs_exp = lax.scan(f, carry, noises, T + T_forecast - 1)
    z_collection = jnp.concatenate((jnp.array(z_collection), zs_exp), axis=0)

    # Sample the observed y (y_obs) and missing y (y_mis)
    numpyro.sample(
        name="y_obs1",
        fn=dist.Normal(loc=z_collection[:T, 0], scale=sigma),
        obs=obs[:, 0],
    )
    numpyro.sample(
        name="y_pred1", fn=dist.Normal(loc=z_collection[T:, 0], scale=sigma), obs=None
    )
    numpyro.sample(
        name="y_obs2",
        fn=dist.Normal(loc=z_collection[:T, 1], scale=sigma),
        obs=obs[:, 1],
    )
    numpyro.sample(
        name="y_pred2", fn=dist.Normal(loc=z_collection[T:, 1], scale=sigma), obs=None
    )
Example #5
def glmm(dept, male, applications, admit):
    v_mu = numpyro.sample('v_mu', dist.Normal(0, np.array([4., 1.])))

    sigma = numpyro.sample('sigma', dist.HalfNormal(np.ones(2)))
    L_Rho = numpyro.sample('L_Rho', dist.LKJCholesky(2))
    scale_tril = sigma[..., np.newaxis] * L_Rho
    # non-centered parameterization
    num_dept = len(onp.unique(dept))
    z = numpyro.sample('z', dist.Normal(np.zeros((num_dept, 2)), 1))
    v = np.dot(scale_tril, z.T).T

    logits = v_mu[0] + v[dept, 0] + (v_mu[1] + v[dept, 1]) * male
    numpyro.sample('admit', dist.Binomial(applications, logits=logits), obs=admit)
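The broadcasting in scale_tril = sigma[..., np.newaxis] * L_Rho scales row i of the correlation Cholesky factor by sigma[i], which is the same as left-multiplying by diag(sigma). A quick standalone check of that identity:

import jax.numpy as jnp

sigma = jnp.array([1.5, 0.5])
L_Rho = jnp.array([[1.0, 0.0], [0.3, 0.9]])
# row-wise broadcasting equals a diagonal left-multiplication
assert jnp.allclose(sigma[..., jnp.newaxis] * L_Rho, jnp.diag(sigma) @ L_Rho)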
Example #6
def multih_kf(T=None, T_forecast=15, hidden=4, obs=None):
    """Define Kalman Filter: multiple hidden variables; just one time series.

    Parameters
    ----------
    T: int
    T_forecast: int
        Number of steps to forecast ahead.
    hidden: int
        Number of variables in the latent space.
    obs: np.array
        observed variable (infected, deaths...)

    """
    # Define priors over beta, tau, sigma, z_1 (keep the shapes in mind)
    T = len(obs) if T is None else T
    beta = numpyro.sample(
        name="beta", fn=dist.Normal(loc=jnp.zeros(hidden), scale=jnp.ones(hidden))
    )
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(hidden)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=0.1))
    z_prev = numpyro.sample(
        name="z_1", fn=dist.Normal(loc=jnp.zeros(hidden), scale=jnp.ones(hidden))
    )
    # Define LKJ prior
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(hidden, 10.0))
    Sigma_lower = jnp.matmul(
        jnp.diag(jnp.sqrt(tau)), L_Omega
    )  # lower Cholesky factor of the covariance matrix
    noises = numpyro.sample(
        "noises",
        fn=dist.MultivariateNormal(loc=jnp.zeros(hidden), scale_tril=Sigma_lower),
        sample_shape=(T + T_forecast - 1,),
    )

    # Propagate the dynamics forward using jax.lax.scan
    carry = (beta, z_prev, tau)
    z_collection = [z_prev]
    carry, zs_exp = lax.scan(f, carry, noises, T + T_forecast - 1)
    z_collection = jnp.concatenate((jnp.array(z_collection), zs_exp), axis=0)

    # Sample the observed y (y_obs) and missing y (y_mis)
    numpyro.sample(
        name="y_obs",
        fn=dist.Normal(loc=z_collection[:T, :].sum(axis=1), scale=sigma),
        obs=obs[:, 0],
    )
    numpyro.sample(
        name="y_pred",
        fn=dist.Normal(loc=z_collection[T:, :].sum(axis=1), scale=sigma),
        obs=None,
    )
Example #7
def glmm(dept, male, applications, admit):
    v_mu = sample('v_mu', dist.Normal(0, np.array([4., 1.])))

    sigma = sample('sigma', dist.HalfNormal(np.ones(2)))
    L_Rho = sample('L_Rho', dist.LKJCholesky(2))
    scale_tril = np.expand_dims(sigma, axis=-1) * L_Rho
    # non-centered parameterization
    num_dept = len(onp.unique(dept))
    z = sample('z', dist.Normal(np.zeros((num_dept, 2)), 1))
    v = np.squeeze(np.matmul(np.expand_dims(scale_tril, axis=-3), np.expand_dims(z, axis=-1)),
                   axis=-1)

    logits = v_mu[..., :1] + v[..., dept, 0] + (v_mu[..., 1:] + v[..., dept, 1]) * male
    sample('admit', dist.Binomial(applications, logits=logits), obs=admit)
Example #8
    def model(self, *views: np.ndarray):
        n = views[0].shape[0]
        p = [view.shape[1] for view in views]
        # mean of each column in each view of the data, shape (p_i,)
        mu = [
            numpyro.sample("mu_" + str(i),
                           dist.MultivariateNormal(0., 10 * jnp.eye(p_)))
            for i, p_ in enumerate(p)
        ]
        """
        Generates cholesky factors of correlation matrices using an LKJ prior.

        The expected use is to combine it with a vector of variances and pass it
        to the scale_tril parameter of a multivariate distribution such as MultivariateNormal.

        E.g., if theta is a (positive) vector of covariances with the same dimensionality
        as this distribution, and Omega is sampled from this distribution,
        scale_tril=torch.mm(torch.diag(sqrt(theta)), Omega)
        """
        psi = [
            numpyro.sample("psi_" + str(i), dist.LKJCholesky(p_))
            for i, p_ in enumerate(p)
        ]
        # sample weights to get from latent to data space (k,p)
        with numpyro.plate("plate_views", self.latent_dims):
            self.weights_list = [
                numpyro.sample(
                    "W_" + str(i),
                    dist.MultivariateNormal(0., jnp.diag(jnp.ones(p_))))
                for i, p_ in enumerate(p)
            ]
        with numpyro.plate("plate_i", n):
            # sample the latent z - normally distributed, shape (n, k)
            z = numpyro.sample(
                "z",
                dist.MultivariateNormal(0.,
                                        jnp.diag(jnp.ones(self.latent_dims))))
            # sample from multivariate normal and observe data
            [
                numpyro.sample("obs" + str(i),
                               dist.MultivariateNormal((z @ W_) + mu_,
                                                       scale_tril=psi_),
                               obs=X_)
                for i, (
                    X_, psi_, mu_,
                    W_) in enumerate(zip(views, psi, mu, self.weights_list))
            ]
Example #9
def glmm(dept, male, applications, admit=None):
    v_mu = numpyro.sample('v_mu', dist.Normal(0, jnp.array([4., 1.])))

    sigma = numpyro.sample('sigma', dist.HalfNormal(jnp.ones(2)))
    L_Rho = numpyro.sample('L_Rho', dist.LKJCholesky(2, concentration=2))
    scale_tril = sigma[..., jnp.newaxis] * L_Rho
    # non-centered parameterization
    num_dept = len(np.unique(dept))
    z = numpyro.sample('z', dist.Normal(jnp.zeros((num_dept, 2)), 1))
    v = jnp.dot(scale_tril, z.T).T

    logits = v_mu[0] + v[dept, 0] + (v_mu[1] + v[dept, 1]) * male
    if admit is None:
        # we use a Delta site to record probs for predictive distribution
        probs = expit(logits)
        numpyro.sample('probs', dist.Delta(probs), obs=probs)
    numpyro.sample('admit', dist.Binomial(applications, logits=logits), obs=admit)
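A hypothetical posterior-predictive run for this model (the data variable names are assumptions):

from jax import random
from numpyro.infer import MCMC, NUTS, Predictive

mcmc = MCMC(NUTS(glmm), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), dept, male, applications, admit)
# with admit=None, the model records `probs` through the Delta site above
predictive = Predictive(glmm, mcmc.get_samples())
preds = predictive(random.PRNGKey(1), dept, male, applications)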
Example #10
def model_w_c(T, T_forecast, x, obs=None):
    # Define priors over beta, tau, sigma, z_1 (keep the shapes in mind)
    W = numpyro.sample(name="W",
                       fn=dist.Normal(loc=jnp.zeros((2, 4)),
                                      scale=jnp.ones((2, 4))))
    beta = numpyro.sample(name="beta",
                          fn=dist.Normal(loc=jnp.array([0.0, 0.0]),
                                         scale=jnp.ones(2)))
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=0.1))
    z_prev = numpyro.sample(name="z_1",
                            fn=dist.Normal(loc=jnp.zeros(2),
                                           scale=jnp.ones(2)))
    # Define LKJ prior
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(2, 10.0))
    Sigma_lower = jnp.matmul(
        jnp.diag(jnp.sqrt(tau)),
        L_Omega)  # lower cholesky factor of the covariance matrix
    noises = numpyro.sample(
        "noises",
        fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower),
        sample_shape=(T + T_forecast, ),
    )
    # Propagate the dynamics forward using jax.lax.scan
    carry = (W, beta, z_prev, tau)
    z_collection = [z_prev]
    carry, zs_exp = lax.scan(f, carry, (x, noises), T + T_forecast)
    z_collection = jnp.concatenate((jnp.array(z_collection), zs_exp), axis=0)

    c = numpyro.sample(name="c",
                       fn=dist.Normal(loc=jnp.array([[0.0], [0.0]]),
                                      scale=jnp.ones((2, 1))))
    obs_mean = jnp.dot(z_collection[:T, :], c).squeeze()
    pred_mean = jnp.dot(z_collection[T:, :], c).squeeze()

    # Sample the observed y (y_obs)
    numpyro.sample(name="y_obs",
                   fn=dist.Normal(loc=obs_mean, scale=sigma),
                   obs=obs)
    numpyro.sample(name="y_pred",
                   fn=dist.Normal(loc=pred_mean, scale=sigma),
                   obs=None)
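Here the scan consumes pairs (x_t, noise_t) and the carry includes the input weights W. A sketch of a compatible transition function — an assumption, since f is not part of the snippet — might look like:

def f(carry, xs_t):
    W, beta, z_prev, tau = carry
    x_t, noise_t = xs_t
    # assumed linear dynamics with exogenous input:
    # z_t = beta * z_{t-1} + W @ x_t + noise_t
    z_t = beta * z_prev + jnp.dot(W, x_t) + noise_t
    return (W, beta, z_t, tau), z_t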
Example #11
    def _model(self, views: Iterable[np.ndarray]):
        n = views[0].shape[0]
        p = [view.shape[1] for view in views]
        # parameter representing the mean of each column in each view of the data
        mu = [
            numpyro.sample(
                "mu_" + str(i), dist.MultivariateNormal(0.0, 10 * jnp.eye(p_))
            )
            for i, p_ in enumerate(p)
        ]
        # parameter representing the within-view correlation structure
        # (a correlation Cholesky factor) for each view of the data
        psi = [
            numpyro.sample("psi_" + str(i), dist.LKJCholesky(p_))
            for i, p_ in enumerate(p)
        ]
        # parameter representing weights applied to the latent variables
        with numpyro.plate("plate_views", self.latent_dims):
            self.weights_list = [
                numpyro.sample(
                    "W_" + str(i),
                    dist.MultivariateNormal(0.0, 10 * jnp.diag(jnp.ones(p_))),
                )
                for i, p_ in enumerate(p)
            ]
        with numpyro.plate("plate_i", n):
            # sample the latent z: the latent variables of the model
            z = numpyro.sample(
                "z", dist.MultivariateNormal(0.0, jnp.diag(jnp.ones(self.latent_dims)))
            )
            # sample from a multivariate normal and observe the data
            [
                numpyro.sample(
                    "obs" + str(i),
                    dist.MultivariateNormal((z @ W_) + mu_, scale_tril=psi_),
                    obs=X_,
                )
                for i, (X_, psi_, mu_, W_) in enumerate(
                    zip(views, psi, mu, self.weights_list)
                )
            ]
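A hypothetical way to run inference on this model, assuming an instance pcca exposing _model and latent_dims, and views, a list of (n, p_i) arrays:

from jax import random
from numpyro.infer import MCMC, NUTS

mcmc = MCMC(NUTS(pcca._model), num_warmup=200, num_samples=500)
mcmc.run(random.PRNGKey(0), views)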