コード例 #1
0
ファイル: test_clmath.py プロジェクト: ssouyris/pyopencl
def test_bessel(ctx_factory):
    """Check pyopencl's Bessel functions against SciPy (and pyfmmlib).

    For every order ``n`` in ``range(nterms)`` this evaluates
    ``clmath.bessel_jn`` and ``clmath.bessel_yn`` on a log-spaced grid and
    compares the results with ``scipy.special.jn`` / ``scipy.special.yn``.
    J_n errors are absolute; Y_n errors are normalized by the largest
    reference magnitude.  Both must be below 1e-10.

    When ``pyfmmlib`` is importable, J_n is also compared with ``jfuns2d``,
    and Y_0 / Y_1 with the imaginary parts of ``hank103_vec``.

    The test is skipped when SciPy is missing or the device has no double
    precision support.
    """
    try:
        import scipy.special as spec
    except ImportError:
        from pytest import skip
        skip("scipy not present--cannot test Bessel function")

    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    if not has_double_support(ctx.devices[0]):
        from pytest import skip
        skip("no double precision support--cannot test bessel function")

    nterms = 30

    # Debugging aid: set to an order in range(nterms) to plot the error for
    # that order (requires matplotlib).  None disables plotting.
    debug_plot_n = None

    try:
        from pyfmmlib import jfuns2d, hank103_vec
    except ImportError:
        use_pyfmmlib = False
    else:
        use_pyfmmlib = True

    print("PYFMMLIB", use_pyfmmlib)

    # NOTE(review): the grid is narrower when pyfmmlib is used -- presumably
    # the range where pyfmmlib meets the 1e-10 tolerance; confirm.
    if use_pyfmmlib:
        a = np.logspace(-3, 3, 10**6)
    else:
        a = np.logspace(-5, 5, 10**6)

    def get_err(check, ref, is_rel):
        """Return max |check - ref|, divided by max |ref| if *is_rel*."""
        err = np.max(np.abs(check - ref))
        if is_rel:
            err = err / np.max(np.abs(ref))
        return err

    for which_func, cl_func, scipy_func, is_rel in [
        ("j", clmath.bessel_jn, spec.jn, False),
        ("y", clmath.bessel_yn, spec.yn, True)
    ]:
        # pyfmmlib reference values, one column per order (or None).
        # For "y" only Y_0 and Y_1 are available, so only two columns are
        # allocated instead of the full (len(a), nterms) complex array.
        pyfmm_result = None
        if use_pyfmmlib:
            if which_func == "j":
                pyfmm_result = np.empty((len(a), nterms), dtype=np.complex128)
                for i, a_i in enumerate(a):
                    if i % 100000 == 0:
                        print("%.1f %%" % (100 * i / len(a)))
                    ier, fjs, _, _ = jfuns2d(nterms, a_i, 1, 0, 10000)
                    # Check every call's status, not only the last one.
                    assert ier == 0, (ier, a_i)
                    pyfmm_result[i] = fjs[:nterms]
            elif which_func == "y":
                h0, h1 = hank103_vec(a, ifexpon=1)
                pyfmm_result = np.empty((len(a), 2), dtype=np.complex128)
                # The imaginary parts serve as the Y_0 / Y_1 reference.
                pyfmm_result[:, 0] = h0.imag
                pyfmm_result[:, 1] = h1.imag

        a_dev = cl_array.to_device(queue, a)

        for n in range(nterms):
            cl_bessel = cl_func(n, a_dev).get()
            scipy_bessel = scipy_func(n, a)

            # Check for NaNs first.  Otherwise a NaN only shows up as an
            # opaque "nan" failure of the error bound below.
            assert not np.isnan(cl_bessel).any()

            error_scipy = get_err(cl_bessel, scipy_bessel, is_rel)
            assert error_scipy < 1e-10, error_scipy

            # Compare against pyfmmlib only for orders it provided.
            pyfmm_bessel = None
            if pyfmm_result is not None and n < pyfmm_result.shape[1]:
                pyfmm_bessel = pyfmm_result[:, n]
                error_pyfmm = get_err(cl_bessel, pyfmm_bessel, is_rel)
                assert error_pyfmm < 1e-10, error_pyfmm
                error_pyfmm_scipy = get_err(
                    scipy_bessel, pyfmm_bessel, is_rel)
                print(which_func, n, error_scipy, error_pyfmm,
                      error_pyfmm_scipy)
            else:
                print(which_func, n, error_scipy)

            if n == debug_plot_n:
                import matplotlib.pyplot as pt

                pt.loglog(a,
                          np.abs(cl_bessel - scipy_bessel),
                          label="vs scipy")
                # Only plot vs pyfmmlib when a reference exists for this
                # order, rather than reusing one from an earlier order.
                if pyfmm_bessel is not None:
                    pt.loglog(a,
                              np.abs(cl_bessel - pyfmm_bessel),
                              label="vs pyfmmlib")
                pt.legend()
                pt.show()
コード例 #2
0
ファイル: test_clmath.py プロジェクト: romanstingler/pyopencl
def test_bessel(ctx_factory):
    """Check pyopencl's Bessel functions against SciPy (and pyfmmlib).

    For every order ``n`` in ``range(nterms)`` this evaluates
    ``clmath.bessel_jn`` and ``clmath.bessel_yn`` on a log-spaced grid and
    compares the results with ``scipy.special.jn`` / ``scipy.special.yn``.
    J_n errors are absolute; Y_n errors are normalized by the largest
    reference magnitude.  Both must be below 1e-10.

    When ``pyfmmlib`` is importable, J_n is also compared with ``jfuns2d``,
    and Y_0 / Y_1 with the imaginary parts of ``hank103_vec``.

    The test is skipped when SciPy is missing or the device has no double
    precision support.
    """
    try:
        import scipy.special as spec
    except ImportError:
        from pytest import skip

        skip("scipy not present--cannot test Bessel function")

    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    if not has_double_support(ctx.devices[0]):
        from pytest import skip

        skip("no double precision support--cannot test bessel function")

    nterms = 30

    # Debugging aid: set to an order in range(nterms) to plot the error for
    # that order (requires matplotlib).  None disables plotting.
    debug_plot_n = None

    try:
        from pyfmmlib import jfuns2d, hank103_vec
    except ImportError:
        use_pyfmmlib = False
    else:
        use_pyfmmlib = True

    print("PYFMMLIB", use_pyfmmlib)

    # NOTE(review): the grid is narrower when pyfmmlib is used -- presumably
    # the range where pyfmmlib meets the 1e-10 tolerance; confirm.
    if use_pyfmmlib:
        a = np.logspace(-3, 3, 10 ** 6)
    else:
        a = np.logspace(-5, 5, 10 ** 6)

    def get_err(check, ref, is_rel):
        """Return max |check - ref|, divided by max |ref| if *is_rel*."""
        err = np.max(np.abs(check - ref))
        if is_rel:
            err = err / np.max(np.abs(ref))
        return err

    for which_func, cl_func, scipy_func, is_rel in [
        ("j", clmath.bessel_jn, spec.jn, False),
        ("y", clmath.bessel_yn, spec.yn, True),
    ]:
        # pyfmmlib reference values, one column per order (or None).
        # For "y" only Y_0 and Y_1 are available, so only two columns are
        # allocated instead of the full (len(a), nterms) complex array.
        pyfmm_result = None
        if use_pyfmmlib:
            if which_func == "j":
                pyfmm_result = np.empty((len(a), nterms), dtype=np.complex128)
                for i, a_i in enumerate(a):
                    if i % 100000 == 0:
                        print("%.1f %%" % (100 * i / len(a)))
                    ier, fjs, _, _ = jfuns2d(nterms, a_i, 1, 0, 10000)
                    # Check every call's status, not only the last one.
                    assert ier == 0, (ier, a_i)
                    pyfmm_result[i] = fjs[:nterms]
            elif which_func == "y":
                h0, h1 = hank103_vec(a, ifexpon=1)
                pyfmm_result = np.empty((len(a), 2), dtype=np.complex128)
                # The imaginary parts serve as the Y_0 / Y_1 reference.
                pyfmm_result[:, 0] = h0.imag
                pyfmm_result[:, 1] = h1.imag

        a_dev = cl_array.to_device(queue, a)

        for n in range(nterms):
            cl_bessel = cl_func(n, a_dev).get()
            scipy_bessel = scipy_func(n, a)

            # Check for NaNs first.  Otherwise a NaN only shows up as an
            # opaque "nan" failure of the error bound below.
            assert not np.isnan(cl_bessel).any()

            error_scipy = get_err(cl_bessel, scipy_bessel, is_rel)
            assert error_scipy < 1e-10, error_scipy

            # Compare against pyfmmlib only for orders it provided.
            pyfmm_bessel = None
            if pyfmm_result is not None and n < pyfmm_result.shape[1]:
                pyfmm_bessel = pyfmm_result[:, n]
                error_pyfmm = get_err(cl_bessel, pyfmm_bessel, is_rel)
                assert error_pyfmm < 1e-10, error_pyfmm
                error_pyfmm_scipy = get_err(scipy_bessel, pyfmm_bessel, is_rel)
                print(which_func, n, error_scipy, error_pyfmm, error_pyfmm_scipy)
            else:
                print(which_func, n, error_scipy)

            if n == debug_plot_n:
                import matplotlib.pyplot as pt

                pt.loglog(a, np.abs(cl_bessel - scipy_bessel), label="vs scipy")
                # The previous version referenced an undefined ``hk_bessel``
                # here.  Only plot when a pyfmmlib reference exists for this
                # order.
                if pyfmm_bessel is not None:
                    pt.loglog(a, np.abs(cl_bessel - pyfmm_bessel), label="vs pyfmmlib")
                pt.legend()
                pt.show()