def test_fixedpoint_vmap():
    """Batching a lazily evaluated fixed point with vmap must match a Python loop."""

    def elem(y):
        # Map whose first entry is pinned to 1 and whose remaining three
        # entries are twice the leading 3-element slice of ``x + y``.
        def fixed_point(x):
            head = np.array([1.])
            tail = 2 * lax.slice(x + y, [0], [3])
            return np.concatenate([head, tail])

        # NOTE(review): the trailing [:] is kept from the original; its
        # purpose is not visible from here.
        return lazy_eval_fixed_point(fixed_point, np.zeros(4))[:]

    inputs = np.array([1, 2, 3, 4])
    looped = np.array([elem(v) for v in inputs])
    batched = vmap(elem)(inputs)
    jtu.check_close(looped, batched)
def test_odeint_linearize_fwrap():
    """The tangent from ``jvp`` must equal the one from ``linearize`` for an odeint wrapper.

    NOTE(review): another function with this same name appears later in
    SOURCE; if both live in one module, only the later one is collected.
    """

    def odeint_fwrap(y0, ts, fargs):
        # Close over the dynamics ``f`` so that only array arguments are traced.
        return odeint(y0, ts, func=f, fargs=fargs)

    primals = (y0, ts, fargs)
    # The primal values double as the tangent vectors.
    _, tangent_from_jvp = jvp(odeint_fwrap, primals, primals)

    # If this comparison fails, the jvp result above is the reference that the
    # linearized function is expected to reproduce.
    _, linear_fn = linearize(odeint_fwrap, *primals)
    tangent_from_linearize = linear_fn(*primals)

    check_close(tangent_from_jvp, tangent_from_linearize)
def test_jitted_update_fn():
    """A jit-compiled ``svi.update`` must yield the same params as the eager call."""
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        positive = constraints.positive
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optim.Adam(0.05), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    # Element 0 of the update result is the new SVI state.
    eager_state = svi.update(svi_state, data)[0]
    jitted_state = jit(svi.update)(svi_state, data=data)[0]

    check_close(svi.get_params(jitted_state), svi.get_params(eager_state), atol=1e-5)
def test_mcmc_progbar():
    """Disabling the progress bar must not change MCMC warmup or sampling results.

    ``MCMC`` receives ``num_warmup``/``num_samples`` by keyword. Newer numpyro
    deprecates positional use, and keywords are accepted by older versions too.

    NOTE(review): SOURCE contains a later function with this same name; if
    both live in one module, the later definition shadows this one and pytest
    never collects it. Consider renaming one of them.
    """
    true_mean, true_std = 1., 2.
    num_warmup, num_samples = 10, 10

    def model(data):
        mean = numpyro.param('mean', 0.)
        std = numpyro.param('std', 1., constraint=constraints.positive)
        return numpyro.sample('obs', dist.Normal(mean, std), obs=data)

    data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000, ))
    kernel = NUTS(model=model)

    # Reference: explicit warmup (key 2) followed by sampling (key 3).
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.warmup(random.PRNGKey(2), data)
    mcmc.run(random.PRNGKey(3), data)

    mcmc1 = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples,
                 progress_bar=False)
    # A run() without a preceding warmup() is expected to give different samples.
    mcmc1.run(random.PRNGKey(2), data)
    with pytest.raises(AssertionError):
        check_close(mcmc1.get_samples(), mcmc.get_samples(), atol=1e-4, rtol=1e-4)

    # Repeating the reference's warmup/run key sequence must reproduce its results.
    mcmc1.warmup(random.PRNGKey(2), data)
    mcmc1.run(random.PRNGKey(3), data)
    check_close(mcmc1.get_samples(), mcmc.get_samples(), atol=1e-4, rtol=1e-4)
    check_close(mcmc1._warmup_state, mcmc._warmup_state, atol=1e-4, rtol=1e-4)
def test_mcmc_progbar():
    """Turning off the progress bar must leave MCMC samples and warmup state unchanged."""
    true_mean, true_std = 1.0, 2.0
    num_warmup, num_samples = 10, 10
    tol = dict(atol=1e-4, rtol=1e-4)

    def model(data):
        mean = numpyro.sample("mean", dist.Normal(0, 1).mask(False))
        std = numpyro.sample("std", dist.LogNormal(0, 1).mask(False))
        return numpyro.sample("obs", dist.Normal(mean, std), obs=data)

    data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000, ))
    kernel = NUTS(model=model)

    def make_mcmc(**extra):
        # Every MCMC in this test shares the kernel and chain lengths.
        return MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, **extra)

    # Reference: explicit warmup (key 2) followed by sampling (key 3).
    reference = make_mcmc()
    reference.warmup(random.PRNGKey(2), data)
    reference.run(random.PRNGKey(3), data)

    quiet = make_mcmc(progress_bar=False)
    # Skipping the separate warmup() call is expected to change the samples.
    quiet.run(random.PRNGKey(2), data)
    with pytest.raises(AssertionError):
        check_close(quiet.get_samples(), reference.get_samples(), **tol)

    # Following the reference's key sequence must reproduce it exactly.
    quiet.warmup(random.PRNGKey(2), data)
    quiet.run(random.PRNGKey(3), data)
    check_close(quiet.get_samples(), reference.get_samples(), **tol)
    check_close(quiet.post_warmup_state, reference.post_warmup_state, **tol)
def test_richardson_solve(self):
    """Check that richardson_solve produces correct outputs and derivatives."""
    n = 50
    direct_solve = jax.scipy.linalg.solve

    # Build matrix = I - 0.9 * raw / colsum(|raw|): every column of
    # (I - matrix) then has absolute sum 0.9, so we converge to the fixed point.
    raw = jax.random.normal(jax.random.PRNGKey(0), (n, n))
    col_norms = jnp.sum(jnp.abs(raw), axis=0, keepdims=True)
    matrix = jnp.eye(n) - 0.9 * raw / col_norms
    b = jax.random.normal(jax.random.PRNGKey(1), (n,))

    def iter_solve(matrix, b):
        return linear_solvers.richardson_solve(
            lambda x: matrix @ x, b, iterations=100)

    # Primal outputs agree with a direct solve.
    jtu.check_close(iter_solve(matrix, b), direct_solve(matrix, b), rtol=1e-4)

    # Forward-mode (jvp) derivatives agree.
    primals = (matrix, b)
    tangents = (jax.random.normal(jax.random.PRNGKey(2), (n, n)),
                jax.random.normal(jax.random.PRNGKey(3), (n,)))
    jtu.check_close(jax.jvp(iter_solve, primals, tangents),
                    jax.jvp(direct_solve, primals, tangents),
                    rtol=1e-4)

    # Reverse-mode (vjp) derivatives agree.
    # NOTE(review): PRNGKey(3) is reused here, so this cotangent is identical
    # to the b-tangent above; a fresh key would make the check more independent.
    cotangent = jax.random.normal(jax.random.PRNGKey(3), (n,))
    jtu.check_close(jax.vjp(iter_solve, matrix, b)[1](cotangent),
                    jax.vjp(direct_solve, matrix, b)[1](cotangent),
                    rtol=1e-4)
def test_richardson_solve_structured(self):
    """Check that richardson_solve works on pytrees."""
    n = 50

    def structured_matvec(m1, m2, m3, xs):
        # Block-diagonal operator acting on a (vector, {"foo", "bar"}) pytree.
        x1, x2 = xs
        return m1 @ x1, {"foo": m2 @ x2["foo"], "bar": m3 @ x2["bar"]}

    def structured_direct_solve(m1, m2, m3, b1, b2):
        solve = jax.scipy.linalg.solve
        return solve(m1, b1), {"foo": solve(m2, b2["foo"]),
                               "bar": solve(m3, b2["bar"])}

    def structured_iter_solve(m1, m2, m3, b1, b2):
        matvec = functools.partial(structured_matvec, m1, m2, m3)
        return linear_solvers.richardson_solve(matvec, (b1, b2), iterations=200)

    def mk_mat(seed):
        # Same construction as the unstructured test: I - 0.9 * raw / colsum(|raw|).
        raw = jax.random.normal(jax.random.PRNGKey(seed), (n, n))
        return jnp.eye(n) - 0.9 * raw / jnp.sum(jnp.abs(raw), axis=0, keepdims=True)

    def mk_vec(seed):
        return jax.random.normal(jax.random.PRNGKey(seed), (n,))

    m1, m2, m3 = map(mk_mat, range(3))
    b1 = mk_vec(3)
    b2 = {"foo": mk_vec(4), "bar": mk_vec(5)}

    args = (m1, m2, m3, b1, b2)
    jtu.check_close(structured_iter_solve(*args), structured_direct_solve(*args))
def test_jit(self, f, args):
    """Check that ``f`` returns the same result under ``jit`` as when run eagerly."""
    compiled_result = jit(f)(*args)
    eager_result = f(*args)
    jtu.check_close(compiled_result, eager_result)
def odeint2(y0, ts, fargs): return odeint(f, y0, ts, fargs, atol=1e-8, rtol=1e-8) odeint2_prim = custom_transforms(odeint2).primitive def odeint2_jvp((y0, ts, fargs), (tan_y, tan_ts, tan_fargs)): return jvp_odeint(f, (y0, ts, fargs), (tan_y, tan_ts, tan_fargs)) ad.defjvp(odeint2_prim, odeint2_jvp) _, out_tangent = jvp(odeint2, (y0, ts, fargs), (y0, ts, fargs)) # when break this is why y, f_jvp = linearize(odeint2, *(y0, ts, fargs)) out_tangent_2 = f_jvp(*(y0, ts, fargs)) # print(make_jaxpr(f_jvp)(y0,t0,t1,fargs)) check_close(out_tangent, out_tangent_2) def test_odeint_linearize_fwrap(): def odeint_fwrap(y0, ts, fargs): return odeint(y0, ts, func=f, fargs=fargs) _, out_tangent = jvp(odeint_fwrap, (y0, ts, fargs), (y0, ts, fargs)) # when break this is why y, f_jvp = linearize(odeint_fwrap, *(y0, ts, fargs)) out_tangent_2 = f_jvp(*(y0, ts, fargs)) # print(make_jaxpr(f_jvp)((y0,t0,t1,fargs),)) check_close(out_tangent, out_tangent_2)
def test_pickle_hmcecs():
    """An HMCECS-driven MCMC run must keep its samples through a pickle round trip."""
    sampler = HMCECS(NUTS(logistic_regression))
    mcmc = MCMC(sampler, num_warmup=10, num_samples=10)
    mcmc.run(random.PRNGKey(0))

    # Round trip through bytes produced by this test itself.
    restored = pickle.loads(pickle.dumps(mcmc))
    test_util.check_close(mcmc.get_samples(), restored.get_samples())
def test_pickle_discrete_hmc(kernel):
    """A discrete-latent kernel wrapping HMC must keep its samples through pickling.

    Args:
        kernel: parametrized kernel constructor that wraps an ``HMC`` instance.
    """
    sampler = kernel(HMC(bernoulli_model))
    mcmc = MCMC(sampler, num_warmup=10, num_samples=10)
    mcmc.run(random.PRNGKey(0))

    # Round trip through bytes produced by this test itself.
    restored = pickle.loads(pickle.dumps(mcmc))
    test_util.check_close(mcmc.get_samples(), restored.get_samples())
def test_pickle_hmc(kernel):
    """An MCMC run must keep its samples through a pickle round trip.

    ``num_warmup``/``num_samples`` are passed by keyword, matching the other
    pickle tests. Newer numpyro deprecates passing them positionally.

    Args:
        kernel: parametrized MCMC kernel constructor applied to ``normal_model``.
    """
    mcmc = MCMC(kernel(normal_model), num_warmup=10, num_samples=10)
    mcmc.run(random.PRNGKey(0))

    # Round trip through bytes produced by this test itself.
    pickled_mcmc = pickle.loads(pickle.dumps(mcmc))
    test_util.check_close(mcmc.get_samples(), pickled_mcmc.get_samples())