def test_ZIP_log_prob(rate):
    """Check ZeroInflatedPoisson against its two degenerate gate settings.

    * gate = 0: no zero-inflation, so log_prob must equal Poisson(rate).
    * gate = 1: the distribution reduces to Delta(0), so log_prob must equal
      Delta(0.).log_prob on both a zero and a non-zero point.
    """
    # gate == 0 -> plain Poisson; compare on values drawn from the ZIP itself.
    no_inflation = dist.ZeroInflatedPoisson(0., rate)
    reference_poisson = dist.Poisson(rate)
    draws = no_inflation.sample(random.PRNGKey(0), (20,))
    assert_allclose(
        no_inflation.log_prob(draws),
        reference_poisson.log_prob(draws),
    )

    # gate == 1 -> Delta(0); the non-zero point exercises the -inf branch.
    full_inflation = dist.ZeroInflatedPoisson(1., rate)
    reference_delta = dist.Delta(0.)
    probe_points = np.array([0., 1.])
    assert_allclose(
        full_inflation.log_prob(probe_points),
        reference_delta.log_prob(probe_points),
    )
def model(X, Y):
    """Zero-inflated Poisson regression model.

    Both the zero-inflation gate and the Poisson rate are linear in ``X``.
    The gate goes through a logistic (expit) link and the rate through an
    exponential link. Each coefficient vector gets an independent standard
    Normal prior.

    Args:
        X: design matrix; ``X.shape[0]`` is used as the number of
            observations and ``X.shape[1]`` as the number of features.
        Y: observed counts passed as ``obs`` to the likelihood site ``"Y"``
            (may be ``None`` when sampling from the prior predictive).
    """
    num_features = X.shape[1]

    # Independent N(0, 1) priors over each coefficient vector.
    b1 = numpyro.sample(
        "b1", dist.Normal(0.0, 1.0).expand([num_features]).to_event(1)
    )
    b2 = numpyro.sample(
        "b2", dist.Normal(0.0, 1.0).expand([num_features]).to_event(1)
    )

    # Column-vector products, then flattened to one value per observation.
    gate_logits = jnp.dot(X, b1[:, None])
    q = jsp.special.expit(gate_logits).reshape(-1)
    log_rate = jnp.dot(X, b2[:, None]).reshape(-1)
    lam = jnp.exp(log_rate)

    with numpyro.plate("obs", X.shape[0]):
        numpyro.sample("Y", dist.ZeroInflatedPoisson(gate=q, rate=lam), obs=Y)