def test_welford_covariance(jitted, diagonal, regularize):
    """Check that ``welford_covariance`` recovers a known covariance from samples.

    Draws 2000 samples from a seeded 3-dimensional multivariate normal.
    Feeds them one at a time through the ``welford_covariance``
    init/update/final triple inside a jitted ``fori_loop``. Compares the
    resulting ``cov`` and ``cov_inv_sqrt`` against reference values with a 6%
    relative tolerance.

    Args:
        jitted: flag passed to ``optional`` together with ``disable_jit()`` and
            ``control_flow_prims_disabled()``. NOTE(review): if ``optional(flag,
            ctx)`` enters ``ctx`` only when ``flag`` is truthy, then
            ``jitted=True`` actually *disables* jit; confirm the naming is
            intended.
        diagonal: forwarded to ``welford_covariance``. When truthy, the
            estimate is checked against the diagonal of the target covariance.
        regularize: forwarded to ``wc_final``.
    """
    num_samples, dim, rel_tol = 2000, 3, 0.06

    with optional(jitted, disable_jit()), optional(jitted, control_flow_prims_disabled()):
        # Seed first so the random draws below are reproducible; keep their order.
        np.random.seed(0)
        mean = np.random.randn(dim)
        factor = np.random.randn(dim, dim)
        target_cov = np.matmul(factor, factor.T)  # factor @ factor.T is a valid covariance
        samples = device_put(
            np.random.multivariate_normal(mean, target_cov, size=(num_samples,)))

        @jit
        def get_cov(x):
            wc_init, wc_update, wc_final = welford_covariance(diagonal=diagonal)

            def accumulate(i, state):
                # Fold the i-th sample into the running estimator state.
                return wc_update(x[i], state)

            final_state = fori_loop(0, num_samples, accumulate, wc_init(dim))
            est_cov, est_inv_sqrt = wc_final(final_state, regularize=regularize)
            return est_cov, est_inv_sqrt

        cov, cov_inv_sqrt = get_cov(samples)

        if diagonal:
            expected_cov = jnp.diagonal(target_cov)
            expected_inv_sqrt = jnp.sqrt(jnp.reciprocal(expected_cov))
        else:
            expected_cov = target_cov
            # Compared against the estimate itself, not the target: this checks
            # that cov_inv_sqrt is consistent with the returned cov.
            expected_inv_sqrt = jnp.linalg.cholesky(jnp.linalg.inv(cov))

        assert_allclose(cov, expected_cov, rtol=rel_tol)
        assert_allclose(cov_inv_sqrt, expected_inv_sqrt, rtol=rel_tol)
def test_tscan(prims_disabled):
    """Check ``tscan`` against ``lax.scan`` on a subset of a ``laxtuple``'s fields.

    Scans the same step function over the same inputs with both ``lax.scan``
    and ``tscan(..., fields=(0, 2))``. Then checks that the x and z fields
    agree and that the unrequested y field comes back as ``None``.

    Args:
        prims_disabled: flag passed to ``optional`` alongside
            ``control_flow_prims_disabled()`` around the ``tscan`` call.
    """
    def step(tree, yz):
        shift, scale = yz
        return tree_map(lambda leaf: (leaf + shift) * scale, tree)

    Tree = laxtuple("Tree", ["x", "y", "z"])
    init = Tree(
        np.array([1., 2.]),
        np.array(3., dtype=np.float32),
        np.array(4., dtype=np.float32),
    )
    xs = (
        np.array([1., 2., 3., 4.]),
        np.array([4., 3., 2., 1.]),
    )

    # Reference result, computed outside the optional context.
    expected = lax.scan(step, init, xs)

    with optional(prims_disabled, control_flow_prims_disabled()):
        actual = tscan(step, init, xs, fields=(0, 2))
        for name in ("x", "z"):
            assert_allclose(getattr(actual, name), getattr(expected, name))
        # Field 1 (y) was not requested, so tscan leaves it unset.
        assert actual.y is None
def model(data):
    """Toy model: a standard-normal latent ``x`` and a Normal(x, 1) likelihood on ``data``.

    ``use_context_manager`` is read from the enclosing scope, which is not
    visible here. It is passed to ``optional`` together with
    ``handlers.scale(scale=10)`` around the observed site. Presumably the
    scale handler is entered only when the flag is truthy; confirm against
    ``optional``.
    """
    latent = numpyro.sample('x', dist.Normal(0, 1))
    scaler = handlers.scale(scale=10)
    with optional(use_context_manager, scaler):
        numpyro.sample('obs', dist.Normal(latent, 1), obs=data)