def test_c2r_vjp(comm):
    """Check ``RealField.c2r_vjp`` against a finite-difference gradient.

    The objective is ``sum(comp.c2r().value ** 2)`` summed over all ranks of
    ``comm``.  Its gradient with respect to the real-space field is
    ``2 * real``; that gradient is pulled back to complex space with
    ``c2r_vjp`` followed by ``decompress_vjp`` and compared, component by
    component, with a one-sided finite difference of the objective.
    """
    # NOTE(review): this module defines ``test_c2r_vjp`` more than once; only
    # the last definition stays bound after import, so pytest never collects
    # the earlier ones.  Confirm which variant is meant to be kept.
    pm = ParticleMesh(BoxSize=8.0, Nmesh=[4, 4], comm=comm, dtype='f8')

    real = pm.generate_whitenoise(1234, type='real', mean=1.0)
    comp = real.r2c()

    def objective(comp):
        # comm.allreduce is collective, so every rank must call this in lockstep.
        real = comp.c2r()
        obj = (real.value ** 2).sum()
        return comm.allreduce(obj)

    # d(sum(real ** 2)) / d(real) == 2 * real.
    grad_real = RealField(pm)
    grad_real[...] = real[...] * 2

    # The VJP of c2r lives in complex space, so write it into a ComplexField
    # rather than handing the RealField itself in as the output buffer.
    grad_comp = ComplexField(pm)
    grad_comp = grad_real.c2r_vjp(grad_comp)
    grad_comp.decompress_vjp(grad_comp)

    ng = []  # numerical (finite-difference) derivatives
    ag = []  # analytic derivatives from the VJP
    dx = 1e-7

    # The unperturbed objective does not depend on the loop index; compute once.
    obj0 = objective(comp)

    # The trailing axis of length 2 presumably selects the real / imaginary
    # part of each complex mode, as interpreted by perturb / cgetitem -- TODO
    # confirm against those helpers.
    for ind1 in numpy.ndindex(*(list(grad_comp.cshape) + [2])):
        # dx1 is the perturbation perturb actually applied (may differ from dx).
        dx1, c1 = perturb(comp, ind1, dx)
        ng1 = (objective(c1) - obj0) / dx
        ag1 = grad_comp.cgetitem(ind1) * dx1 / dx
        comm.barrier()
        ng.append(ng1)
        ag.append(ag1)

    assert_allclose(ng, ag, rtol=1e-5)
def test_c2r_vjp(comm):
    """Check ``RealField.c2r_vjp`` against a finite-difference gradient.

    The objective is ``sum(comp.c2r().value ** 2)`` summed over all ranks of
    ``comm``.  Its gradient with respect to the real-space field is
    ``2 * real``; that gradient is pulled back to complex space with
    ``c2r_vjp`` followed by ``decompress_vjp`` and compared, component by
    component, with a one-sided finite difference of the objective.
    """
    # NOTE(review): this module defines ``test_c2r_vjp`` more than once; only
    # the last definition stays bound after import, so pytest never collects
    # the earlier ones.  Confirm which variant is meant to be kept.
    pm = ParticleMesh(BoxSize=8.0, Nmesh=[4, 4], comm=comm, dtype='f8')

    real = pm.generate_whitenoise(1234, mode='real')
    comp = real.r2c()

    def objective(comp):
        # comm.allreduce is collective, so every rank must call this in lockstep.
        real = comp.c2r()
        obj = (real.value ** 2).sum()
        return comm.allreduce(obj)

    # d(sum(real ** 2)) / d(real) == 2 * real.
    grad_real = RealField(pm)
    grad_real[...] = real[...] * 2

    # The VJP of c2r lives in complex space, so write it into a ComplexField
    # rather than handing the RealField itself in as the output buffer.
    grad_comp = ComplexField(pm)
    grad_comp = grad_real.c2r_vjp(grad_comp)
    grad_comp.decompress_vjp(grad_comp)

    ng = []  # numerical (finite-difference) derivatives
    ag = []  # analytic derivatives from the VJP
    dx = 1e-7

    # The unperturbed objective does not depend on the loop index; compute once.
    obj0 = objective(comp)

    # The trailing axis of length 2 presumably selects the real / imaginary
    # part of each complex mode, as interpreted by perturb / cgetitem -- TODO
    # confirm against those helpers.
    for ind1 in numpy.ndindex(*(list(grad_comp.cshape) + [2])):
        # dx1 is the perturbation perturb actually applied (may differ from dx).
        dx1, c1 = perturb(comp, ind1, dx)
        ng1 = (objective(c1) - obj0) / dx
        ag1 = grad_comp.cgetitem(ind1) * dx1 / dx
        comm.barrier()
        ng.append(ng1)
        ag.append(ag1)

    assert_allclose(ng, ag, rtol=1e-5)