예제 #1
0
파일: tests.py 프로젝트: alewis/jax_vumps
def do_arnoldi_tests_identity(N=3, thresh=1E-5):
    """
    Run the Arnoldi sanity tests with A equal to the N x N identity matrix.

    A single starting vector of length N is drawn via ``random_tensors`` and
    the Krylov dimension is set to N. Two checks are run against it:
    ``test_arnoldi_orthonormality`` (orthonormality of V) and
    ``test_arnoldi_fixed_point`` (the form of H). A message is printed for
    each failing check.

    Args:
        N: Size of the identity matrix; also used as the Krylov dimension.
        thresh: Error tolerance forwarded to both checks.

    Returns:
        True if both checks passed, False otherwise.
    """
    print("Testing Arnoldi on the identity with N=", N)
    krylov_dim = N
    identity = jnp.eye(N)
    (start_vec,) = random_tensors([(N, )])
    matvec = ops.matrix_matvec

    ortho_ok, ortho_err = test_arnoldi_orthonormality(
        matvec, [identity], krylov_dim, start_vec, thresh=thresh)
    if not ortho_ok:
        print("Orthonormality failed at N, n_kry", N, krylov_dim, "by ",
              ortho_err)

    # Run the second check regardless of the first so both results are shown.
    fixed_ok, fixed_err = test_arnoldi_fixed_point(
        matvec, [identity], krylov_dim, start_vec, thresh=thresh)
    if not fixed_ok:
        print("Fixed point failed at N, n_kry", N, krylov_dim, "by ",
              fixed_err)

    print("Done!")
    # Coerce to plain Python bools, since the checks may return array scalars.
    return bool(ortho_ok) and bool(fixed_ok)
예제 #2
0
파일: tests.py 프로젝트: alewis/jax_vumps
def do_arnoldi_tests_random_dense_matrices(thresh=1E-5):
    """
    Run the Arnoldi sanity tests on a sweep of random dense matrices.

    For each N in 5, 10, 15, 20 and each Krylov dimension
    n_kry in ``range(1, N - 1, 3)``, a fresh N x N matrix A and a length-N
    starting vector are drawn via ``random_tensors``. Two checks are then run:
    ``test_arnoldi_orthonormality`` (orthonormality of V, verbose) and
    ``test_arnoldi_fixed_point`` (the form of H). Failures are printed.

    Args:
        thresh: Error tolerance forwarded to both checks.

    Returns:
        A truthy value if every check passed, falsy otherwise.
    """
    print("Testing Arnoldi on dense matrices.")
    matvec = ops.matrix_matvec
    all_ok = True
    for N in np.arange(5, 25, 5):
        for n_kry in np.arange(1, N - 1, 3):
            A, v0 = random_tensors([(N, N), (N, )])

            ortho_ok, ortho_err = test_arnoldi_orthonormality(
                matvec, [A], n_kry, v0, thresh=thresh, verbose=True)
            if not ortho_ok:
                print("Orthonormality failed at N, n_kry = ", N, n_kry, "by ",
                      ortho_err)

            fixed_ok, fixed_err = test_arnoldi_fixed_point(
                matvec, [A], n_kry, v0, thresh=thresh)
            if not fixed_ok:
                print("Fixed point failed at N=", N, "n_kry=", n_kry, "by ",
                      fixed_err)

            all_ok = all_ok and ortho_ok and fixed_ok
    return all_ok
예제 #3
0
파일: tests.py 프로젝트: alewis/jax_vumps
def test_matrix_matvec(m, n, thresh=1E-5, verbose=False):
    """
    Compare ``matrix_matvec`` against the dense product ``A @ v``.

    A random (m, n) matrix and length-n vector are drawn via
    ``random_tensors``. The test passes when the norm of the difference
    between the two results does not exceed ``thresh``.

    Args:
        m: Number of rows of A.
        n: Number of columns of A (and length of v).
        thresh: Maximum allowed error norm.
        verbose: If True, print the error when the test fails.

    Returns:
        Tuple ``(passed, error)``.
    """
    A, v = random_tensors([(m, n), (n, )])
    reference = A @ v
    candidate = matrix_matvec(A, v)
    error = jnp.linalg.norm(jnp.abs(reference - candidate))
    passed = not (error > thresh)
    if verbose and not passed:
        print("Test failed by : ", error)
    return (passed, error)
예제 #4
0
def test_solve_for_RH(d, chi, dtype=np.float32, thresh=1E-4):
    """
    Compare the NumPy and JAX implementations of ``solve_for_RH``.

    Random tensors of shapes (d, chi, chi), (d, d, d, d) and (chi, chi) are
    drawn for A_R, H and rL. Both implementations are run with the same
    ``vumps.krylov_params()``. The norm of the difference, divided by the
    result's size, is compared against ``thresh``. Prints "Passed!" or the
    failing error.

    Args:
        d: Size parameter used to build the tensor shapes above.
        chi: Size parameter used to build the tensor shapes above.
        dtype: dtype forwarded to ``np_linalg.random_tensors``.
        thresh: Error tolerance; also passed to both solvers as ``delta``.
    """
    print("Testing the RH linear solve: ")
    shapes = [(d, chi, chi), (d, d, d, d), (chi, chi)]
    A_R, H, rL = np_linalg.random_tensors(shapes, dtype)
    params = vumps.krylov_params()
    delta = thresh
    npRH = np_env.solve_for_RH(A_R, H, rL, params, delta)

    A_Rj, Hj, rLj = [jnp.array(x) for x in [A_R, H, rL]]
    jaxRH = jax_env.solve_for_RH(A_Rj, Hj, rLj, params, delta)
    err = np.linalg.norm(np.abs(jaxRH - npRH))/jaxRH.size
    # A NaN/inf in EITHER result makes err non-finite, and `NaN > thresh` is
    # False. Test finiteness explicitly so a NaN in the NumPy reference is
    # not reported as a pass; the previous check only looked for NaNs in the
    # JAX result.
    if not np.isfinite(err) or err > thresh:
        print("FAILED with err: ", err)
    else:
        print("Passed!")
예제 #5
0
def test_RH_matvec(d, chi, dtype=np.float32, thresh=1E-6):
    """
    Compare the NumPy RH linear operator's matvec with ``jax_env.RH_matvec``.

    A_R and rL come from ``np_linalg.random_tensors``. Both implementations
    are applied to the same flattened (chi, chi) input. The norm of the
    difference, divided by the result's size, is compared against
    ``thresh``. Prints "Passed!" or the failing error.

    Args:
        d: Size parameter used to build the tensor shapes.
        chi: Size parameter used to build the tensor shapes.
        dtype: dtype forwarded to ``np_linalg.random_tensors``.
        thresh: Error tolerance.
    """
    print("Testing the RH matvec operation: ")
    shapes = [(d, chi, chi), (d, d, d, d), (chi, chi), (chi, chi)]
    # H and a random x0 are still requested so the random_tensors call is
    # unchanged, but neither is used below.
    A_R, _H, rL, _ = np_linalg.random_tensors(shapes, dtype)
    # NOTE(review): the input is an all-ones array (float64 regardless of
    # `dtype`), not the random x0 drawn above. This is presumably a debugging
    # aid; confirm whether the random input was intended.
    x0 = np.ones((chi, chi))
    op = np_env.RH_linear_operator(A_R, rL)
    vn = op.matvec(x0.flatten())

    A_Rj, rLj, x0j = [jnp.array(x) for x in [A_R, rL, x0]]
    vj = jax_env.RH_matvec(rLj, A_Rj, x0j.flatten())

    err = np.linalg.norm(np.abs(vn - vj))/vn.size
    # A NaN/inf in either result makes err non-finite, and `NaN > thresh` is
    # False. Check finiteness explicitly so NumPy-side NaNs also fail.
    if not np.isfinite(err) or err > thresh:
        print("FAILED with err: ", err)
    else:
        print("Passed!")