# NOTE(review): a second definition of test_hankel_01_complex follows later in
# this file and shadows this one at import time; one of them should be removed.
def test_hankel_01_complex(ctx_factory, ref_src):
    """Check ``clmath.hankel_01`` on complex double input against a reference.

    Reference values come from ``pyfmmlib.hank103_vec`` or from
    ``scipy.special.hankel1`` (orders 0 and 1), selected by *ref_src*. The
    test is skipped when the chosen library is not installed or when the
    device lacks double precision support.

    :arg ctx_factory: fixture that produces an OpenCL context.
    :arg ref_src: ``"pyfmmlib"`` or ``"scipy"``.
    :raises ValueError: for any other *ref_src*.
    """
    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    if not has_double_support(ctx.devices[0]):
        pytest.skip(
            "no double precision support--cannot test complex bessel function")

    # 10**6 sample points:
    # - magnitudes are log-spaced over [1e-5, 1e2];
    # - arguments are drawn uniformly from [0, 2*pi).
    # A private RandomState seeded with 11 produces exactly the values that
    # np.random.seed(11) + np.random.rand(n) did, without resetting the global
    # NumPy RNG that other tests may rely on.
    n = 10**6
    rng = np.random.RandomState(11)
    z = np.logspace(-5, 2, n) * np.exp(1j * 2 * np.pi * rng.rand(n))

    if ref_src == "pyfmmlib":
        pyfmmlib = pytest.importorskip("pyfmmlib")
        h0_ref, h1_ref = pyfmmlib.hank103_vec(z, ifexpon=1)
    elif ref_src == "scipy":
        spec = pytest.importorskip("scipy.special")
        h0_ref = spec.hankel1(0, z)
        h1_ref = spec.hankel1(1, z)
    else:
        raise ValueError("unknown ref_src: %r" % (ref_src,))

    z_dev = cl_array.to_device(queue, z)
    h0_dev, h1_dev = clmath.hankel_01(z_dev)

    # Pointwise relative error of the device result against the reference.
    rel_err_h0 = np.abs(h0_dev.get() - h0_ref) / np.abs(h0_ref)
    rel_err_h1 = np.abs(h1_dev.get() - h1_ref) / np.abs(h1_ref)
    max_rel_err_h0 = np.max(rel_err_h0)
    max_rel_err_h1 = np.max(rel_err_h1)

    print("H0", max_rel_err_h0)
    print("H1", max_rel_err_h1)

    assert max_rel_err_h0 < 4e-13
    assert max_rel_err_h1 < 2e-13
def test_hankel_01_complex(ctx_factory, ref_src):
    """Check ``clmath.hankel_01`` on complex double input against a reference.

    Reference values come from ``pyfmmlib.hank103_vec`` or from
    ``scipy.special.hankel1`` (orders 0 and 1), selected by *ref_src*. The
    test is skipped when the chosen library is not installed or when the
    device lacks double precision support.

    :arg ctx_factory: fixture that produces an OpenCL context.
    :arg ref_src: ``"pyfmmlib"`` or ``"scipy"``.
    :raises ValueError: for any other *ref_src*.
    """
    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    if not has_double_support(ctx.devices[0]):
        pytest.skip(
            "no double precision support--cannot test complex bessel function")

    # 10**6 sample points:
    # - magnitudes are log-spaced over [1e-5, 1e2];
    # - arguments are drawn uniformly from [0, 2*pi).
    # A private RandomState seeded with 11 produces exactly the values that
    # np.random.seed(11) + np.random.rand(n) did, without resetting the global
    # NumPy RNG that other tests may rely on.
    n = 10**6
    rng = np.random.RandomState(11)
    z = np.logspace(-5, 2, n) * np.exp(1j * 2 * np.pi * rng.rand(n))

    if ref_src == "pyfmmlib":
        pyfmmlib = pytest.importorskip("pyfmmlib")
        h0_ref, h1_ref = pyfmmlib.hank103_vec(z, ifexpon=1)
    elif ref_src == "scipy":
        spec = pytest.importorskip("scipy.special")
        h0_ref = spec.hankel1(0, z)
        h1_ref = spec.hankel1(1, z)
    else:
        raise ValueError("unknown ref_src: %r" % (ref_src,))

    z_dev = cl_array.to_device(queue, z)
    h0_dev, h1_dev = clmath.hankel_01(z_dev)

    # Pointwise relative error of the device result against the reference.
    rel_err_h0 = np.abs(h0_dev.get() - h0_ref) / np.abs(h0_ref)
    rel_err_h1 = np.abs(h1_dev.get() - h1_ref) / np.abs(h1_ref)
    max_rel_err_h0 = np.max(rel_err_h0)
    max_rel_err_h1 = np.max(rel_err_h1)

    print("H0", max_rel_err_h0)
    print("H1", max_rel_err_h1)

    assert max_rel_err_h0 < 4e-13
    assert max_rel_err_h1 < 2e-13