Esempio n. 1
0
def test_welford_covariance(jitted, diagonal, regularize):
    """Check that ``welford_covariance`` recovers a known covariance.

    Draws samples from a multivariate normal with a fixed-seed random
    covariance ``A @ A.T``. The samples are streamed one at a time through
    the Welford init/update/final functions inside a ``fori_loop``. The
    test then checks the estimate against the ground truth with a 6%
    relative tolerance.

    Args:
        jitted: flag handed to ``optional`` together with ``disable_jit()``
            and ``control_flow_prims_disabled()``. Presumably those contexts
            are entered when it is truthy. NOTE(review): that would mean
            ``jitted=True`` runs *without* jit. Confirm the naming is
            intentional.
        diagonal: forwarded to ``welford_covariance``. When truthy, only the
            diagonal of the target covariance is compared.
        regularize: forwarded to the final (finalize) function.
    """
    dim = 3
    num_samples = 2000

    with optional(jitted, disable_jit()), optional(
            jitted, control_flow_prims_disabled()):
        # Seed first and keep the draw order below unchanged, so the sampled
        # data (and therefore the tolerance margin) is reproducible.
        np.random.seed(0)
        mean = np.random.randn(dim)
        factor = np.random.randn(dim, dim)
        true_cov = factor @ factor.T
        samples = device_put(
            np.random.multivariate_normal(mean, true_cov,
                                          size=(num_samples, )))

        @jit
        def estimate(data):
            init_fn, update_fn, final_fn = welford_covariance(
                diagonal=diagonal)
            state = fori_loop(0, num_samples,
                              lambda i, acc: update_fn(data[i], acc),
                              init_fn(dim))
            est_cov, est_inv_sqrt = final_fn(state, regularize=regularize)
            return est_cov, est_inv_sqrt

        est_cov, est_inv_sqrt = estimate(samples)

        if not diagonal:
            assert_allclose(est_cov, true_cov, rtol=0.06)
            # Self-consistency check: compare against the Cholesky factor of
            # the inverse of the *estimated* covariance, not the true one.
            expected_inv_sqrt = jnp.linalg.cholesky(jnp.linalg.inv(est_cov))
        else:
            true_diag = jnp.diagonal(true_cov)
            assert_allclose(est_cov, true_diag, rtol=0.06)
            expected_inv_sqrt = jnp.sqrt(jnp.reciprocal(true_diag))
        assert_allclose(est_inv_sqrt, expected_inv_sqrt, rtol=0.06)
Esempio n. 2
0
def test_scan_prims_disabled():
    """Check that ``scan`` under ``control_flow_prims_disabled()`` matches ``lax.scan``.

    The carry is a three-field ``laxtuple`` pytree. Each step shifts and then
    scales every leaf by the current pair of scanned inputs. The test
    compares the ``x``, ``y`` and ``z`` fields of both results.

    NOTE(review): ``laxtuple`` and a step function that returns only the
    carry reflect an older JAX API. Current ``lax.scan`` expects the step to
    return ``(carry, y)``. Confirm against the pinned JAX version.
    """
    def step(carry, inputs):
        shift, scale = inputs
        return tree_map(lambda leaf: (leaf + shift) * scale, carry)

    Tree = laxtuple("Tree", ["x", "y", "z"])
    init = Tree(
        np.array([1., 2.]),
        np.array(3., dtype=np.float32),
        np.array(4., dtype=np.float32),
    )
    xs = (
        np.array([1., 2., 3., 4.]),
        np.array([4., 3., 2., 1.]),
    )

    expected = lax.scan(step, init, xs)
    with control_flow_prims_disabled():
        actual = scan(step, init, xs)

    for field in ("x", "y", "z"):
        assert_allclose(getattr(actual, field), getattr(expected, field))