def make_hermitian(data, fft):
    """Return ``data`` adjusted to be valid k-space input for ``fft``.

    When ``fft`` is a ``pystella.fourier.gDFT``, ``data`` is passed through
    ``pystella.fourier.rayleigh.make_hermitian`` and then through
    ``fft.zero_corner_modes``. For any other transform ``data`` is returned
    unchanged.
    """
    from pystella.fourier import gDFT

    if not isinstance(fft, gDFT):
        return data

    from pystella.fourier.rayleigh import make_hermitian as symmetrize
    return fft.zero_corner_modes(symmetrize(data))
def make_data(queue, fft):
    """Create random complex k-space data for ``fft`` and copy it to the device.

    Real and imaginary parts are drawn with ``np.random.rand`` using the shape
    ``fft.shape(True)``. When ``fft`` is a ``gDFT``, the array is passed
    through ``pystella.fourier.rayleigh.make_hermitian``, cast to
    ``complex128``, and has its corner modes zeroed by
    ``fft.zero_corner_modes``.

    :arg queue: the command queue used by ``cla.to_device``.
    :arg fft: the transform object providing the k-space shape.
    :returns: the device array produced by ``cla.to_device``.
    """
    # Import locally (as make_hermitian does) so this helper does not rely on
    # a module-level import of gDFT being present.
    from pystella.fourier import gDFT

    kshape = fft.shape(True)
    data = np.random.rand(*kshape) + 1j * np.random.rand(*kshape)
    if isinstance(fft, gDFT):
        from pystella.fourier.rayleigh import make_hermitian
        data = make_hermitian(data).astype(np.complex128)
        data = fft.zero_corner_modes(data)
    return cla.to_device(queue, data)
def test_make_hermitian(ctx_factory, grid_shape, proc_shape, dtype):
    """Check that rayleigh.make_hermitian produces data accepted by is_hermitian.

    The input is random complex data of shape
    ``(grid_shape[0], grid_shape[1], grid_shape[2] // 2 + 1)``. The test is
    skipped unless ``proc_shape`` is ``(1, 1, 1)``. ``ctx_factory`` and
    ``dtype`` are accepted but unused.
    """
    if proc_shape != (1, 1, 1):
        pytest.skip("test make_hermitian only on one rank")

    nx, ny, nz = grid_shape[0], grid_shape[1], grid_shape[2]
    kshape = (nx, ny, nz // 2 + 1)

    real_part = np.random.rand(*kshape)
    imag_part = np.random.rand(*kshape)

    from pystella.fourier.rayleigh import make_hermitian as symmetrize
    data = symmetrize(real_part + 1j * imag_part)

    assert is_hermitian(data), "data is not hermitian"
def test_pol_spectra(ctx_factory, grid_shape, proc_shape, dtype, timing=False):
    """Check that vector/polarization projections preserve power spectra.

    For random Fourier data this asserts that:

    * ``pol_to_vec`` followed by ``vec_to_pol`` leaves the plus and minus
      power spectra unchanged;
    * the summed power of the three vector components equals the summed
      plus and minus power;
    * ``decompose_vector`` splits a random vector's power into plus, minus
      and longitudinal parts that add back up to the total;
    * the two spectra from ``spec.gw_polarization`` sum to ``spec.gw`` for
      random ``hij``.

    Complex ``dtype`` parametrizations are run as ``"float64"``.
    ``timing`` is accepted but unused.
    """
    ctx = ctx_factory()

    if np.dtype(dtype).kind != "f":
        dtype = "float64"
    # Decide precision via np.dtype: the fallback above is the *string*
    # "float64", and ``"float64" == np.float64`` is False, which would
    # silently select the loose single-precision tolerances below.
    is_double = np.dtype(dtype) == np.float64

    queue = cl.CommandQueue(ctx)
    h = 1
    mpi = ps.DomainDecomposition(proc_shape, h, grid_shape=grid_shape)
    rank_shape, _ = mpi.get_rank_shape_start(grid_shape)
    fft = ps.DFT(mpi, ctx, queue, grid_shape, dtype)

    L = (10, 8, 7)
    dk = tuple(2 * np.pi / Li for Li in L)
    dx = tuple(Li / Ni for Li, Ni in zip(L, grid_shape))
    cdtype = fft.cdtype
    # np.prod, not np.product (deprecated in NumPy 1.25, removed in 2.0).
    spec = ps.PowerSpectra(mpi, fft, dk, np.prod(L))

    k_power = 2.

    def random_modes():
        # Random complex host data of shape fft.shape(True) in cdtype, passed
        # through the module-level make_hermitian (which only acts for gDFT).
        # Built here because the visible make_data takes (queue, fft), so
        # calling make_data(*fft.shape(True)) would raise a TypeError.
        kshape = fft.shape(True)
        fk = np.random.rand(*kshape) + 1j * np.random.rand(*kshape)
        return make_hermitian(fk.astype(cdtype), fft).astype(cdtype)

    plus = cla.to_device(queue, random_modes())
    minus = cla.to_device(queue, random_modes())

    plus_ps_1 = spec.bin_power(plus, queue=queue, k_power=k_power)
    minus_ps_1 = spec.bin_power(minus, queue=queue, k_power=k_power)

    project = ps.Projector(fft, h, dk, dx)

    # pol_to_vec then vec_to_pol: plus/minus spectra should be unchanged.
    vector = cla.empty(queue, (3, ) + fft.shape(True), cdtype)
    project.pol_to_vec(queue, plus, minus, vector)
    project.vec_to_pol(queue, plus, minus, vector)

    plus_ps_2 = spec.bin_power(plus, k_power=k_power)
    minus_ps_2 = spec.bin_power(minus, k_power=k_power)

    max_rtol = 1e-8 if is_double else 1e-2
    avg_rtol = 1e-11 if is_double else 1e-4

    # Only bins [1:-2] are compared; presumably this drops the k = 0 bin and
    # the outermost high-k bins -- confirm against PowerSpectra's binning.
    max_err, avg_err = get_errs(plus_ps_1[1:-2], plus_ps_2[1:-2])
    assert max_err < max_rtol and avg_err < avg_rtol, \
        f"plus power spectrum inaccurate for {grid_shape=}: {max_err=}, {avg_err=}"

    max_err, avg_err = get_errs(minus_ps_1[1:-2], minus_ps_2[1:-2])
    assert max_err < max_rtol and avg_err < avg_rtol, \
        f"minus power spectrum inaccurate for {grid_shape=}: {max_err=}, {avg_err=}"

    vec_sum = sum(
        spec.bin_power(vector[mu], k_power=k_power)
        for mu in range(3))
    pol_sum = plus_ps_1 + minus_ps_1

    max_err, avg_err = get_errs(vec_sum[1:-2], pol_sum[1:-2])
    assert max_err < max_rtol and avg_err < avg_rtol, \
        f"polarization power spectrum inaccurate for {grid_shape=}" \
        f": {max_err=}, {avg_err=}"

    # reset: fill each vector component with fresh random modes
    for mu in range(3):
        vector[mu].set(random_modes())

    longitudinal = cla.zeros_like(plus)
    project.decompose_vector(queue, vector, plus, minus, longitudinal,
                             times_abs_k=True)
    plus_ps = spec.bin_power(plus, k_power=k_power)
    minus_ps = spec.bin_power(minus, k_power=k_power)
    long_ps = spec.bin_power(longitudinal, k_power=k_power)

    vec_sum = sum(
        spec.bin_power(vector[mu], k_power=k_power)
        for mu in range(3))
    dec_sum = plus_ps + minus_ps + long_ps

    max_err, avg_err = get_errs(vec_sum[1:-2], dec_sum[1:-2])
    assert max_err < max_rtol and avg_err < avg_rtol, \
        f"decomp power spectrum inaccurate for {grid_shape=}: {max_err=}, {avg_err=}"

    hij = cl.clrandom.rand(queue, (6, ) + rank_shape, dtype)
    gw_spec = spec.gw(hij, project, 1.3)
    gw_pol_spec = spec.gw_polarization(hij, project, 1.3)

    max_rtol = 1e-14 if is_double else 1e-2
    avg_rtol = 1e-11 if is_double else 1e-4

    pol_sum = gw_pol_spec[0] + gw_pol_spec[1]
    max_err, avg_err = get_errs(gw_spec[1:-2], pol_sum[1:-2])
    assert max_err < max_rtol and avg_err < avg_rtol, \
        f"gw pol don't add up to gw for {grid_shape=}: {max_err=}, {avg_err=}"