Exemplo n.º 1
0
def test_c2r_vjp(comm):
    """Check the c2r vector-Jacobian product against finite differences.

    The objective is the sum of squares of ``comp.c2r()``, combined across
    ranks with ``comm.allreduce``.  Its gradient with respect to the real
    field is ``2 * real``; that gradient is propagated back to the complex
    field with ``c2r_vjp`` followed by ``decompress_vjp`` and compared,
    entry by entry, with a forward finite difference of the objective.
    """
    pm = ParticleMesh(BoxSize=8.0, Nmesh=[4, 4], comm=comm, dtype='f8')

    real = pm.generate_whitenoise(1234, type='real', mean=1.0)
    comp = real.r2c()

    def objective(comp):
        # Sum of squares of the real-space field, reduced over comm.
        real = comp.c2r()
        obj = (real.value ** 2).sum()
        return comm.allreduce(obj)

    # d(objective)/d(real) = 2 * real, from the sum of squares above.
    grad_real = RealField(pm)
    grad_real[...] = real[...] * 2
    # The result of c2r_vjp is bound directly; allocating a ComplexField
    # beforehand would only create an object that is discarded at once.
    grad_comp = grad_real.c2r_vjp(grad_real)
    grad_comp.decompress_vjp(grad_comp)

    ng = []  # numerical (finite-difference) gradients
    ag = []  # analytic gradients from the vjp chain
    dx = 1e-7
    # The trailing axis of length 2 presumably selects the real/imaginary
    # component of each complex mode -- TODO confirm against cgetitem.
    for ind1 in numpy.ndindex(*(list(grad_comp.cshape) + [2])):
        dx1, c1 = perturb(comp, ind1, dx)
        ng1 = (objective(c1) - objective(comp)) / dx
        # dx1 is the step reported by perturb; scaling by dx1 / dx keeps the
        # analytic value comparable with the division by dx above.
        ag1 = grad_comp.cgetitem(ind1) * dx1 / dx
        comm.barrier()
        ng.append(ng1)
        ag.append(ag1)

    assert_allclose(ng, ag, rtol=1e-5)
Exemplo n.º 2
0
def test_c2r_vjp(comm):
    """Check the c2r vector-Jacobian product against finite differences.

    The objective is the sum of squares of ``comp.c2r()``, combined across
    ranks with ``comm.allreduce``.  Its gradient with respect to the real
    field is ``2 * real``; that gradient is propagated back to the complex
    field with ``c2r_vjp`` followed by ``decompress_vjp`` and compared,
    entry by entry, with a forward finite difference of the objective.
    """
    pm = ParticleMesh(BoxSize=8.0, Nmesh=[4, 4], comm=comm, dtype='f8')

    real = pm.generate_whitenoise(1234, mode='real')
    comp = real.r2c()

    def objective(comp):
        # Sum of squares of the real-space field, reduced over comm.
        real = comp.c2r()
        obj = (real.value ** 2).sum()
        return comm.allreduce(obj)

    # d(objective)/d(real) = 2 * real, from the sum of squares above.
    grad_real = RealField(pm)
    grad_real[...] = real[...] * 2
    # The result of c2r_vjp is bound directly; allocating a ComplexField
    # beforehand would only create an object that is discarded at once.
    grad_comp = grad_real.c2r_vjp(grad_real)
    grad_comp.decompress_vjp(grad_comp)

    ng = []  # numerical (finite-difference) gradients
    ag = []  # analytic gradients from the vjp chain
    dx = 1e-7
    # The trailing axis of length 2 presumably selects the real/imaginary
    # component of each complex mode -- TODO confirm against cgetitem.
    for ind1 in numpy.ndindex(*(list(grad_comp.cshape) + [2])):
        dx1, c1 = perturb(comp, ind1, dx)
        ng1 = (objective(c1) - objective(comp)) / dx
        # dx1 is the step reported by perturb; scaling by dx1 / dx keeps the
        # analytic value comparable with the division by dx above.
        ag1 = grad_comp.cgetitem(ind1) * dx1 / dx
        comm.barrier()
        ng.append(ng1)
        ag.append(ag1)

    assert_allclose(ng, ag, rtol=1e-5)