def test_log_prob_LKJCholesky_uniform(dimension):
    """Check that LKJCholesky(concentration=1) is uniform over correlation matrices.

    The log-density of the induced correlation matrix (the Cholesky sample's
    log_prob minus the Jacobian of the cholesky -> corr map) must be a
    constant, independent of which matrix was sampled.
    """
    d = dist.LKJCholesky(dimension=dimension, concentration=1)

    num_samples = 5
    corr_log_prob = []
    for seed in range(num_samples):
        sample = d.sample(random.PRNGKey(seed))
        log_prob = d.log_prob(sample)
        sample_tril = matrix_to_tril_vec(sample, diagonal=-1)
        # log|det J| of the map from lower-triangular cholesky entries to
        # lower-triangular correlation entries.
        cholesky_to_corr_jac = onp.linalg.slogdet(
            jax.jacobian(_tril_cholesky_to_tril_corr)(sample_tril))[1]
        corr_log_prob.append(log_prob - cholesky_to_corr_jac)
    corr_log_prob = np.array(corr_log_prob)

    # A uniform density means every sample yields the same log-probability.
    assert_allclose(corr_log_prob,
                    np.broadcast_to(corr_log_prob[0], corr_log_prob.shape),
                    rtol=1e-6)

    if dimension == 2:
        # For dimension = 2 the single off-diagonal entry of the correlation
        # matrix is Uniform(-1, 1), i.e. density 0.5, and the Jacobian of the
        # cholesky -> corr transform is 1 (log value 0) because that
        # off-diagonal lower-triangular entry is unchanged by the transform.
        # Hence the constant above must equal log(0.5).
        assert_allclose(corr_log_prob[0], np.log(0.5), rtol=1e-6)
def test_log_prob_LKJCholesky(dimension, concentration):
    """Validate LKJCholesky.log_prob against its C-vine construction.

    LKJCorrCholesky can be seen as a TransformedDistribution whose base is the
    distribution of partial correlations from the C-vine method (an affine map
    of Beta variates from (0, 1) to (-1, 1)) and whose transform is the signed
    stick-breaking process.
    """
    d = dist.LKJCholesky(dimension, concentration, sample_method="cvine")

    # Base density: independent Beta draws mapped affinely to (-1, 1).
    beta_sample = d._beta.sample(random.PRNGKey(0))
    beta_log_prob = np.sum(d._beta.log_prob(beta_sample))
    partial_correlation = 2 * beta_sample - 1
    affine_logdet = beta_sample.shape[-1] * np.log(2)
    sample = signed_stick_breaking_tril(partial_correlation)

    # log|det J| of the signed stick-breaking transform, obtained by composing
    # the unconstrained -> corr_cholesky bijector with atanh.
    inv_tanh = lambda t: np.log((1 + t) / (1 - t)) / 2  # noqa: E731
    inv_tanh_logdet = np.sum(np.log(vmap(grad(inv_tanh))(partial_correlation)))
    unconstrained = inv_tanh(partial_correlation)
    corr_cholesky_logdet = biject_to(constraints.corr_cholesky).log_abs_det_jacobian(
        unconstrained,
        sample,
    )
    signed_stick_breaking_logdet = corr_cholesky_logdet + inv_tanh_logdet

    actual_log_prob = d.log_prob(sample)
    expected_log_prob = beta_log_prob - affine_logdet - signed_stick_breaking_logdet
    assert_allclose(actual_log_prob, expected_log_prob, rtol=1e-5)

    # log_prob must agree under jit compilation.
    assert_allclose(jax.jit(d.log_prob)(sample), d.log_prob(sample), atol=1e-7)
def twoh_c_kf(T=None, T_forecast=15, obs=None):
    """Define Kalman Filter with two hidden variates."""
    T = len(obs) if T is None else T

    # Priors over the transition weights, noise scales and initial state z_1.
    beta = numpyro.sample(name="beta",
                          fn=dist.Normal(loc=jnp.array([0., 0.]), scale=jnp.ones(2)))
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=.1))
    z_prev = numpyro.sample(name="z_1",
                            fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2)))

    # LKJ prior on the latent-noise correlation; combined with tau it gives
    # the lower Cholesky factor of the noise covariance matrix.
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(2, 10.))
    Sigma_lower = jnp.matmul(jnp.diag(jnp.sqrt(tau)), L_Omega)
    noises = numpyro.sample(
        "noises",
        fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower),
        sample_shape=(T + T_forecast,))

    # Propagate the dynamics forward with jax.lax.scan; the transition
    # function `f` is defined elsewhere in this module.
    carry = (beta, z_prev, tau)
    carry, zs_exp = lax.scan(f, carry, noises, T + T_forecast)
    # NOTE(review): prepending z_1 to T+T_forecast scanned states yields
    # T+T_forecast+1 rows, so `pred_mean` has T_forecast+1 entries; the
    # sibling models use T+T_forecast-1 noises — confirm which is intended.
    z_collection = jnp.concatenate((jnp.array([z_prev]), zs_exp), axis=0)

    # Linear readout c mapping the 2-d latent state to the observed scalar.
    c = numpyro.sample(name="c",
                       fn=dist.Normal(loc=jnp.array([[0.], [0.]]),
                                      scale=jnp.ones((2, 1))))
    obs_mean = jnp.dot(z_collection[:T, :], c).squeeze()
    pred_mean = jnp.dot(z_collection[T:, :], c).squeeze()

    # Observation model and forecast site.
    numpyro.sample(name="y_obs", fn=dist.Normal(loc=obs_mean, scale=sigma), obs=obs)
    numpyro.sample(name="y_pred", fn=dist.Normal(loc=pred_mean, scale=sigma), obs=None)
def multivariate_kf(T=None, T_forecast=15, obs=None):
    """Define Kalman Filter in a multivariate fashion.

    The time series are correlated. To express these relationships
    efficiently, the covariance matrix of h_t (or, equivalently, of the
    noises) is drawn via a Cholesky-decomposed correlation matrix.

    Parameters
    ----------
    T: int
        number of observed steps (defaults to len(obs))
    T_forecast: int
        steps to forecast ahead
    obs: np.array
        observed variable (infected, deaths...)
    """
    T = len(obs) if T is None else T

    # Priors over the transition weights, noise scales and initial state z_1.
    beta = numpyro.sample(
        name="beta", fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2))
    )
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=0.1))
    z_prev = numpyro.sample(
        name="z_1", fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2))
    )

    # LKJ prior on the noise correlation; with tau it forms the lower
    # Cholesky factor of the noise covariance matrix.
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(2, 10.0))
    Sigma_lower = jnp.matmul(jnp.diag(jnp.sqrt(tau)), L_Omega)
    noises = numpyro.sample(
        "noises",
        fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower),
        sample_shape=(T + T_forecast - 1,),
    )

    # Roll the dynamics forward with jax.lax.scan (`f` is defined elsewhere);
    # prepending z_1 gives T + T_forecast latent states in total.
    carry = (beta, z_prev, tau)
    carry, zs_exp = lax.scan(f, carry, noises, T + T_forecast - 1)
    z_collection = jnp.concatenate((jnp.array([z_prev]), zs_exp), axis=0)

    # Observation model for each series plus its forecast site.
    numpyro.sample(
        name="y_obs1",
        fn=dist.Normal(loc=z_collection[:T, 0], scale=sigma),
        obs=obs[:, 0],
    )
    numpyro.sample(
        name="y_pred1", fn=dist.Normal(loc=z_collection[T:, 0], scale=sigma), obs=None
    )
    numpyro.sample(
        name="y_obs2",
        fn=dist.Normal(loc=z_collection[:T, 1], scale=sigma),
        obs=obs[:, 1],
    )
    numpyro.sample(
        name="y_pred2", fn=dist.Normal(loc=z_collection[T:, 1], scale=sigma), obs=None
    )
def glmm(dept, male, applications, admit):
    """Binomial GLMM with correlated per-department intercept and slope."""
    # Hyper-priors: global effect means, their scales, and the Cholesky
    # factor of the effect correlation (LKJ prior).
    v_mu = numpyro.sample('v_mu', dist.Normal(0, np.array([4., 1.])))
    sigma = numpyro.sample('sigma', dist.HalfNormal(np.ones(2)))
    L_Rho = numpyro.sample('L_Rho', dist.LKJCholesky(2))
    scale_tril = sigma[..., np.newaxis] * L_Rho

    # Non-centered parameterization of the per-department random effects.
    num_dept = len(onp.unique(dept))
    z = numpyro.sample('z', dist.Normal(np.zeros((num_dept, 2)), 1))
    v = np.dot(scale_tril, z.T).T

    # Department-specific intercept and gender slope.
    intercept = v_mu[0] + v[dept, 0]
    slope = v_mu[1] + v[dept, 1]
    logits = intercept + slope * male
    numpyro.sample('admit', dist.Binomial(applications, logits=logits), obs=admit)
def multih_kf(T=None, T_forecast=15, hidden=4, obs=None):
    """Define Kalman Filter: multiple hidden variables; just one time series.

    Parameters
    ----------
    T: int
        number of observed steps (defaults to len(obs))
    T_forecast: int
        Times to forecast ahead.
    hidden: int
        number of variables in the latent space
    obs: np.array
        observed variable (infected, deaths...)
    """
    T = len(obs) if T is None else T

    # Priors over the transition weights, noise scales and initial state z_1.
    beta = numpyro.sample(
        name="beta", fn=dist.Normal(loc=jnp.zeros(hidden), scale=jnp.ones(hidden))
    )
    # BUG FIX: the latent dimensionality of every site below was hard-coded
    # to 2, so `hidden` only affected `beta` and the carry mixed a hidden-dim
    # beta with a 2-d state for hidden != 2. All latent sites now use
    # `hidden` (identical behavior at hidden=2).
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(hidden)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=0.1))
    z_prev = numpyro.sample(
        name="z_1", fn=dist.Normal(loc=jnp.zeros(hidden), scale=jnp.ones(hidden))
    )

    # LKJ prior on the latent-noise correlation; with tau it forms the lower
    # Cholesky factor of the noise covariance matrix.
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(hidden, 10.0))
    Sigma_lower = jnp.matmul(jnp.diag(jnp.sqrt(tau)), L_Omega)
    noises = numpyro.sample(
        "noises",
        fn=dist.MultivariateNormal(loc=jnp.zeros(hidden), scale_tril=Sigma_lower),
        sample_shape=(T + T_forecast - 1,),
    )

    # Propagate the dynamics forward with jax.lax.scan (`f` defined
    # elsewhere); prepending z_1 gives T + T_forecast latent states.
    carry = (beta, z_prev, tau)
    carry, zs_exp = lax.scan(f, carry, noises, T + T_forecast - 1)
    z_collection = jnp.concatenate((jnp.array([z_prev]), zs_exp), axis=0)

    # The single observed series is the sum of the hidden components.
    numpyro.sample(
        name="y_obs",
        fn=dist.Normal(loc=z_collection[:T, :].sum(axis=1), scale=sigma),
        obs=obs[:, 0],
    )
    numpyro.sample(
        name="y_pred",
        fn=dist.Normal(loc=z_collection[T:, :].sum(axis=1), scale=sigma),
        obs=None,
    )
def glmm(dept, male, applications, admit):
    """Binomial GLMM with correlated department effects (batched variant)."""
    # Hyper-priors: effect means, scales and LKJ correlation Cholesky factor.
    v_mu = sample('v_mu', dist.Normal(0, np.array([4., 1.])))
    sigma = sample('sigma', dist.HalfNormal(np.ones(2)))
    L_Rho = sample('L_Rho', dist.LKJCholesky(2))
    scale_tril = np.expand_dims(sigma, axis=-1) * L_Rho

    # Non-centered parameterization of the per-department effects.
    num_dept = len(onp.unique(dept))
    z = sample('z', dist.Normal(np.zeros((num_dept, 2)), 1))
    # Batched matrix-vector product: scale_tril applied to each department's z.
    v = np.squeeze(np.matmul(np.expand_dims(scale_tril, axis=-3),
                             np.expand_dims(z, axis=-1)),
                   axis=-1)

    # Department-specific intercept and gender slope (batch-dim aware).
    intercept = v_mu[..., :1] + v[..., dept, 0]
    slope = v_mu[..., 1:] + v[..., dept, 1]
    sample('admit', dist.Binomial(applications, logits=intercept + slope * male),
           obs=admit)
def model(self, *views: np.ndarray):
    """Generative model: a shared latent z produces every observed view.

    Per view i: a mean vector mu_i, an LKJ-distributed Cholesky factor of a
    correlation matrix psi_i (used as scale_tril of the observation noise),
    and weights W_i mapping the latent space to the view's feature space.
    To include variances as well, psi_i could be combined with
    diag(sqrt(theta)) before being passed as scale_tril.
    """
    n = views[0].shape[0]
    p = [view.shape[1] for view in views]

    # Mean of each column in each view of data, shape (p_i,).
    mu = [
        numpyro.sample("mu_" + str(i),
                       dist.MultivariateNormal(0., 10 * jnp.eye(p_)))
        for i, p_ in enumerate(p)
    ]
    # Cholesky factors of correlation matrices, one per view (LKJ prior).
    psi = [
        numpyro.sample("psi_" + str(i), dist.LKJCholesky(p_))
        for i, p_ in enumerate(p)
    ]

    # Weights to map from latent to data space, shape (k, p_i).
    with numpyro.plate("plate_views", self.latent_dims):
        self.weights_list = [
            numpyro.sample("W_" + str(i),
                           dist.MultivariateNormal(0., jnp.diag(jnp.ones(p_))))
            for i, p_ in enumerate(p)
        ]

    with numpyro.plate("plate_i", n):
        # Latent variables z, standard multivariate normal, shape (n, k).
        z = numpyro.sample(
            "z",
            dist.MultivariateNormal(0., jnp.diag(jnp.ones(self.latent_dims))))
        # Observe each view as multivariate normal around its reconstruction.
        for i, (X_, psi_, mu_, W_) in enumerate(
                zip(views, psi, mu, self.weights_list)):
            numpyro.sample("obs" + str(i),
                           dist.MultivariateNormal((z @ W_) + mu_,
                                                   scale_tril=psi_),
                           obs=X_)
def glmm(dept, male, applications, admit=None):
    """Binomial GLMM with correlated department effects.

    When `admit` is None the model runs in prediction mode and records the
    admission probabilities through a Delta site.
    """
    # Hyper-priors: effect means, scales and the LKJ correlation Cholesky
    # factor (concentration=2) of the per-department effects.
    v_mu = numpyro.sample('v_mu', dist.Normal(0, jnp.array([4., 1.])))
    sigma = numpyro.sample('sigma', dist.HalfNormal(jnp.ones(2)))
    L_Rho = numpyro.sample('L_Rho', dist.LKJCholesky(2, concentration=2))
    scale_tril = sigma[..., jnp.newaxis] * L_Rho

    # Non-centered parameterization of the per-department random effects.
    num_dept = len(np.unique(dept))
    z = numpyro.sample('z', dist.Normal(jnp.zeros((num_dept, 2)), 1))
    v = jnp.dot(scale_tril, z.T).T

    # Department-specific intercept and gender slope.
    intercept = v_mu[0] + v[dept, 0]
    slope = v_mu[1] + v[dept, 1]
    logits = intercept + slope * male

    if admit is None:
        # Use a Delta site to record probs for the predictive distribution.
        probs = expit(logits)
        numpyro.sample('probs', dist.Delta(probs), obs=probs)
    numpyro.sample('admit', dist.Binomial(applications, logits=logits), obs=admit)
def model_w_c(T, T_forecast, x, obs=None):
    """State-space model with exogenous inputs `x` and a learned readout `c`.

    A 2-d latent state evolves via the transition function `f` (defined
    elsewhere) driven by correlated noise; observations are a linear readout
    of the latent state.
    """
    # Priors over input weights, transition weights, noise scales and z_1.
    W = numpyro.sample(name="W",
                       fn=dist.Normal(loc=jnp.zeros((2, 4)),
                                      scale=jnp.ones((2, 4))))
    beta = numpyro.sample(name="beta",
                          fn=dist.Normal(loc=jnp.array([0.0, 0.0]),
                                         scale=jnp.ones(2)))
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=0.1))
    z_prev = numpyro.sample(name="z_1",
                            fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2)))

    # LKJ prior on the noise correlation; with tau it yields the lower
    # Cholesky factor of the noise covariance matrix.
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(2, 10.0))
    Sigma_lower = jnp.matmul(jnp.diag(jnp.sqrt(tau)), L_Omega)
    noises = numpyro.sample(
        "noises",
        fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower),
        sample_shape=(T + T_forecast, ),
    )

    # Roll the dynamics forward; `f` consumes (x_t, noise_t) at each step.
    carry = (W, beta, z_prev, tau)
    carry, zs_exp = lax.scan(f, carry, (x, noises), T + T_forecast)
    z_collection = jnp.concatenate((jnp.array([z_prev]), zs_exp), axis=0)

    # Linear readout from the latent state to the observed scalar.
    c = numpyro.sample(name="c",
                       fn=dist.Normal(loc=jnp.array([[0.0], [0.0]]),
                                      scale=jnp.ones((2, 1))))
    obs_mean = jnp.dot(z_collection[:T, :], c).squeeze()
    pred_mean = jnp.dot(z_collection[T:, :], c).squeeze()

    # Observation model and forecast site.
    numpyro.sample(name="y_obs", fn=dist.Normal(loc=obs_mean, scale=sigma), obs=obs)
    numpyro.sample(name="y_pred", fn=dist.Normal(loc=pred_mean, scale=sigma), obs=None)
def _model(self, views: Iterable[np.ndarray]):
    """Generative model: a shared latent z produces every observed view.

    Per view i: a column-mean vector mu_i, an LKJ-distributed Cholesky
    factor psi_i (the observation scale_tril), and weights W_i mapping the
    latent space to the view's feature space.
    """
    n = views[0].shape[0]
    p = [view.shape[1] for view in views]

    # Mean of each column in each view of data.
    mu = [
        numpyro.sample(
            "mu_" + str(i), dist.MultivariateNormal(0.0, 10 * jnp.eye(p_))
        )
        for i, p_ in enumerate(p)
    ]
    # Within-view correlation structure via an LKJ prior on Cholesky factors.
    psi = [
        numpyro.sample("psi_" + str(i), dist.LKJCholesky(p_))
        for i, p_ in enumerate(p)
    ]

    # Weights applied to the latent variables, one row per latent dimension.
    with numpyro.plate("plate_views", self.latent_dims):
        self.weights_list = [
            numpyro.sample(
                "W_" + str(i),
                dist.MultivariateNormal(0.0, 10 * jnp.diag(jnp.ones(p_))),
            )
            for i, p_ in enumerate(p)
        ]

    with numpyro.plate("plate_i", n):
        # Latent variables of the model, standard multivariate normal.
        z = numpyro.sample(
            "z",
            dist.MultivariateNormal(0.0, jnp.diag(jnp.ones(self.latent_dims))),
        )
        # Observe each view as multivariate normal around its reconstruction.
        for i, (X_, psi_, mu_, W_) in enumerate(
            zip(views, psi, mu, self.weights_list)
        ):
            numpyro.sample(
                "obs" + str(i),
                dist.MultivariateNormal((z @ W_) + mu_, scale_tril=psi_),
                obs=X_,
            )