Ejemplo n.º 1
0
def test_conditional_mean_fwd():
    """Compare the driver and op implementations of the conditional mean.

    Builds the conditional test matrices, runs the driver's factor/solve
    pipeline to obtain ``z``, and hands both ``driver.conditional_mean`` and
    ``ops.conditional_mean`` to ``check_basic`` with a shared argument list.
    """
    matrices = get_matrices(vector=True, conditional=True)
    a, U, V, P, Y, U_star, V_star, inds = matrices

    # V is copied before factoring -- presumably driver.factor may write into
    # its last argument, and the untouched V is passed on below.
    V_work = np.copy(V)
    d, W = driver.factor(U, P, a, V_work)
    z = driver.solve(U, P, d, W, Y)

    inputs = [U, V, P, z, U_star, V_star, inds]
    check_basic(driver.conditional_mean, ops.conditional_mean, inputs)
Ejemplo n.º 2
0
def test_solve(vector):
    """Check ``driver.solve`` against a dense ``numpy.linalg.solve``.

    Also checks that the returned solution matches ``Y`` afterwards, i.e.
    that the solve wrote its result into ``Y`` instead of making a copy.
    """
    a, U, V, P, K, Y = get_matrices(vector=vector, include_dense=True)

    # Dense reference answer. This must be computed before the celerite
    # solve below, because that solve is expected to overwrite Y in place.
    reference = np.linalg.solve(K, Y)

    factor_d, factor_W = driver.factor(U, P, a, V)
    result = driver.solve(U, P, factor_d, factor_W, Y)

    # No-copy check: after an in-place solve, Y holds the solution too.
    assert np.allclose(result, Y)

    # Correctness check against the dense reference.
    assert np.allclose(result, reference)
Ejemplo n.º 3
0
def test_conditional_mean():
    """Check ``ops.conditional_mean`` against the driver's result (no grads).

    The driver output, written into a preallocated float64 buffer, is used
    as the expected value for ``check_op`` with gradient checks disabled.
    """
    a, U, V, P, Y, U_star, V_star, inds = get_matrices(
        vector=True, conditional=True
    )
    d, W = driver.factor(U, P, a, np.copy(V))
    z = driver.solve(U, P, d, W, Y)

    # Output buffer with one float64 slot per entry of ``inds``.
    out = np.empty(len(inds), dtype=np.float64)
    expected = driver.conditional_mean(U, V, P, z, U_star, V_star, inds, out)

    op_args = [U, V, P, z, U_star, V_star, inds]
    check_op(ops.conditional_mean, op_args, [expected], grad=False)
Ejemplo n.º 4
0
def test_solve_fwd(vector):
    """Check that ``backprop.solve_fwd`` reproduces ``driver.solve``.

    Allocates the output buffers ``solve_fwd`` takes (X, Z, F, G), runs the
    forward pass, and compares its solution to the driver's.
    """
    a, U, V, P, Y = get_matrices(vector=vector)
    d, W = driver.factor(U, P, a, V)

    # Reference solution computed on a copy, presumably so that Y stays
    # intact for the forward pass below.
    reference = driver.solve(U, P, d, W, np.copy(Y))

    X_buf, Z_buf = np.empty_like(Y), np.empty_like(Y)
    # F matches U in the vector case; otherwise it has one column per
    # (U column, Y column) pair.
    F_buf = (
        np.empty_like(U)
        if vector
        else np.empty((U.shape[0], U.shape[1] * Y.shape[1]))
    )
    G_buf = np.empty_like(F_buf)

    X_fwd, _Z, _F, _G = backprop.solve_fwd(
        U, P, d, W, Y, X_buf, Z_buf, F_buf, G_buf
    )
    assert np.allclose(reference, X_fwd)
Ejemplo n.º 5
0
def test_solve(vector):
    """Check ``ops.solve``, both eager and jitted, against ``driver.solve``."""
    a, U, V, P, Y = get_matrices(vector=vector)
    d, W = driver.factor(U, P, a, V)
    # Expected solution, computed on a copy so Y is passed unchanged below.
    expected = driver.solve(U, P, d, W, np.copy(Y))

    def _check(op):
        # Fresh argument/expected lists per call, as in separate calls.
        check_op(op, [U, P, d, W, Y], [expected])

    _check(ops.solve)
    _check(jit(ops.solve))