Exemplo n.º 1
0
def test_welford_covariance(jitted, diagonal, regularize):
    """Check ``welford_covariance`` against a known covariance matrix.

    Draws ``num_samples`` points from a ``dim``-dimensional multivariate
    normal whose covariance is ``a @ a.T`` for a random ``a``. The points are
    folded one at a time through the update returned by
    ``welford_covariance`` inside a jitted ``fori_loop``. The finalized
    covariance and its inverse square root are then compared with reference
    values to a relative tolerance of ``rtol``.

    Args:
        jitted: passed as the condition to ``optional`` for the
            ``disable_jit()`` and ``control_flow_prims_disabled()`` contexts.
            NOTE(review): as written these contexts are entered when
            ``jitted`` is True, which reads inverted relative to the name;
            presumably both values are parametrized so coverage is
            unaffected -- confirm intent.
        diagonal: forwarded to ``welford_covariance``. When true, the
            estimate is compared with the diagonal of the target covariance.
        regularize: forwarded to the finalizer ``wc_final``.
    """
    # Single source of truth for the problem size. The sampler, ``wc_init``
    # and the loop bound below must all agree on these values.
    dim, num_samples = 3, 2000
    rtol = 0.06

    with optional(jitted, disable_jit()), optional(
            jitted, control_flow_prims_disabled()):
        np.random.seed(0)
        loc = np.random.randn(dim)
        a = np.random.randn(dim, dim)
        target_cov = np.matmul(a, a.T)
        x = np.random.multivariate_normal(loc, target_cov,
                                          size=(num_samples, ))
        x = device_put(x)

        @jit
        def get_cov(x):
            wc_init, wc_update, wc_final = welford_covariance(
                diagonal=diagonal)
            wc_state = wc_init(dim)
            # Feed the samples through the Welford update one row at a time.
            wc_state = fori_loop(0, num_samples,
                                 lambda i, val: wc_update(x[i], val),
                                 wc_state)
            cov, cov_inv_sqrt = wc_final(wc_state, regularize=regularize)
            return cov, cov_inv_sqrt

        cov, cov_inv_sqrt = get_cov(x)

        if diagonal:
            diag_cov = jnp.diagonal(target_cov)
            assert_allclose(cov, diag_cov, rtol=rtol)
            # Diagonal case: the reference is elementwise 1 / sqrt(variance).
            assert_allclose(cov_inv_sqrt,
                            jnp.sqrt(jnp.reciprocal(diag_cov)),
                            rtol=rtol)
        else:
            assert_allclose(cov, target_cov, rtol=rtol)
            # Dense case: the reference is built from the *estimated* cov,
            # not target_cov. This checks that wc_final's two outputs are
            # consistent with each other.
            assert_allclose(cov_inv_sqrt,
                            jnp.linalg.cholesky(jnp.linalg.inv(cov)),
                            rtol=rtol)
Exemplo n.º 2
0
def test_tscan(prims_disabled):
    """Compare ``tscan`` with ``lax.scan`` on a subset of tuple fields.

    ``tscan`` is asked for fields 0 and 2 (``x`` and ``z`` of ``Tree``). The
    test expects those fields to match the ``lax.scan`` reference and ``y``
    to come back as ``None``. ``prims_disabled`` toggles
    ``control_flow_prims_disabled`` around the ``tscan`` call only.

    NOTE(review): ``step`` returns just the new carry, and the results are
    read as ``Tree`` fields. This appears to target an early ``lax.scan`` API
    that stacks carries -- confirm against the pinned jax version.
    """
    def step(carry, inputs):
        offset, factor = inputs
        return tree_map(lambda leaf: (leaf + offset) * factor, carry)

    Tree = laxtuple("Tree", ["x", "y", "z"])
    init = Tree(np.array([1., 2.]),
                np.array(3., dtype=np.float32),
                np.array(4., dtype=np.float32))
    offsets = np.array([1., 2., 3., 4.])
    factors = np.array([4., 3., 2., 1.])
    xs = (offsets, factors)

    # The reference result is always computed outside the optional context.
    expected = lax.scan(step, init, xs)
    with optional(prims_disabled, control_flow_prims_disabled()):
        actual = tscan(step, init, xs, fields=(0, 2))

    for name in ("x", "z"):
        assert_allclose(getattr(actual, name), getattr(expected, name))
    assert actual.y is None
Exemplo n.º 3
0
 def model(data):
     """Toy model: a latent ``x`` with a Normal likelihood on ``data``.

     Samples site ``'x'`` from ``Normal(0, 1)``, then observes ``data`` at
     site ``'obs'`` under ``Normal(x, 1)``. The observation is placed inside
     ``optional(use_context_manager, handlers.scale(scale=10))``. Presumably
     the scale handler is entered only when ``use_context_manager`` is
     truthy, scaling that site's log-density by 10 -- confirm against the
     ``optional`` helper and the numpyro docs.

     ``use_context_manager`` is a free variable taken from the enclosing
     scope, which is not visible here.
     """
     x = numpyro.sample('x', dist.Normal(0, 1))
     # Only the 'obs' site sits inside the (possibly entered) scale handler.
     with optional(use_context_manager, handlers.scale(scale=10)):
         numpyro.sample('obs', dist.Normal(x, 1), obs=data)