Esempio n. 1
0
def test_2d_on_2d_r2c(comm):
    """Check a distributed 8x8 R2C transform on a 2D process mesh against numpy.

    The mesh shape is derived from ``comm.size`` with ``pfft.split_size_2d``
    (e.g. one rank -> 1x1) instead of a hard-coded [2, 2], which only
    matches a communicator of exactly 4 ranks.
    NOTE(review): pfft support for every mesh split_size_2d can return on
    8x8 data is not verified here.
    """
    procmesh = pfft.ProcMesh(np=pfft.split_size_2d(comm.size), comm=comm)
    N = (8, 8)

    data = numpy.arange(numpy.prod(N), dtype='f8').reshape(N)

    correct = numpy.fft.rfftn(data.copy())
    result = numpy.zeros_like(correct)

    partition = pfft.Partition(
        pfft.Type.PFFT_R2C,
        N,
        procmesh,
        flags=pfft.Flags.PFFT_ESTIMATE
        | pfft.Flags.PFFT_TRANSPOSED_OUT
        | pfft.Flags.PFFT_DESTROY_INPUT
        #          | pfft.Flags.PADDED_R2C # doesn't work yet
    )

    buffer1 = pfft.LocalBuffer(partition)
    buffer2 = pfft.LocalBuffer(partition)

    plan = pfft.Plan(partition, pfft.Direction.PFFT_FORWARD, buffer1, buffer2)

    # Each rank transforms its own input slab, writes its local output into a
    # zero-filled global array, and allreduce sums the pieces together.
    buffer1.view_input()[:] = data[partition.local_i_slice]
    plan.execute(buffer1, buffer2)

    result[partition.local_o_slice] = buffer2.view_output()
    result = comm.allreduce(result)
    assert_almost_equal(correct, result)
Esempio n. 2
0
def test_world():
    """ProcMesh and ProcMesh.split treat ``comm=None`` as MPI.COMM_WORLD."""
    world = MPI.COMM_WORLD

    # Explicit world and None must both produce a mesh on COMM_WORLD.
    for given_comm in (world, None):
        procmesh = pfft.ProcMesh(np=[world.size], comm=given_comm)
        assert procmesh.comm == world

    # split() must give identical results for None and COMM_WORLD.
    for nparts in (2, 1):
        assert_array_equal(pfft.ProcMesh.split(nparts, None),
                           pfft.ProcMesh.split(nparts, world))
Esempio n. 3
0
def test_correct_multi(comm):
    """A distributed 2x3 C2C forward FFT on a 1D mesh matches numpy.fft.fftn."""
    procmesh = pfft.ProcMesh(np=[comm.size], comm=comm)
    shape = (2, 3)
    signal = numpy.arange(numpy.prod(shape), dtype='complex128').reshape(shape)
    expected = numpy.fft.fftn(signal)
    gathered = numpy.zeros_like(signal)

    partition = pfft.Partition(pfft.Type.PFFT_C2C, shape, procmesh,
                               flags=pfft.Flags.PFFT_ESTIMATE)

    src = pfft.LocalBuffer(partition)
    dst = pfft.LocalBuffer(partition)
    plan = pfft.Plan(partition, pfft.Direction.PFFT_FORWARD, src, dst)

    # Load this rank's slab, transform, then drop the local output into a
    # zero-filled global array so that allreduce (a sum) assembles it.
    src.view_input()[:] = signal[partition.local_i_slice]
    plan.execute(src, dst)
    gathered[partition.local_o_slice] = dst.view_output()

    assert_almost_equal(expected, comm.allreduce(gathered))
Esempio n. 4
0
def test_reuse_local_buffer(comm):
    """A LocalBuffer created with ``base=`` shares memory with its base.

    Buffers sharing an address are each "in" the other, while an
    independently allocated buffer is contained in neither.
    """
    procmesh = pfft.ProcMesh(np=[1], comm=comm)
    r2c = pfft.Type.PFFT_R2C
    estimate = pfft.Flags.PFFT_ESTIMATE

    transposed = pfft.Partition(r2c, [8, 8], procmesh,
                                flags=estimate | pfft.Flags.PFFT_TRANSPOSED_OUT)
    natural = pfft.Partition(r2c, [8, 8], procmesh, flags=estimate)

    owner = pfft.LocalBuffer(transposed)
    alias = pfft.LocalBuffer(natural, base=owner)
    separate = pfft.LocalBuffer(transposed)

    # Distinct Python objects backed by the same memory.
    assert owner is not alias
    assert owner.address == alias.address

    assert owner in alias
    assert alias in owner

    for needle, haystack in ((owner, separate), (separate, owner),
                             (alias, separate), (separate, alias)):
        assert needle not in haystack
Esempio n. 5
0
def test_nino(comm):
    """Partition.ni and Partition.no both report the full [4, 8] array shape."""
    mesh = pfft.ProcMesh(np=[comm.size], comm=comm)
    partition = pfft.Partition(pfft.Type.PFFT_C2C, [4, 8], mesh,
                               pfft.Flags.PFFT_TRANSPOSED_OUT)

    for attribute in ('ni', 'no'):
        assert_array_equal(getattr(partition, attribute), [4, 8])
Esempio n. 6
0
def test_raw(comm):
    """The raw view of an R2C LocalBuffer has size 2 * partition.alloc_local."""
    mesh = pfft.ProcMesh(np=[1], comm=comm)
    flags = pfft.Flags.PFFT_ESTIMATE | pfft.Flags.PFFT_TRANSPOSED_OUT
    partition = pfft.Partition(pfft.Type.PFFT_R2C, [8, 8], mesh, flags=flags)

    local = pfft.LocalBuffer(partition)
    raw_size = local.view_raw().size
    assert raw_size == 2 * partition.alloc_local
Esempio n. 7
0
def test_leak(comm):
    """Repeatedly create and drop pfft objects, presumably to surface leaks.

    Each of the 1024 iterations builds a ProcMesh, a 128^3 C2C Partition and
    a LocalBuffer, and takes an input view of it; rebinding the names on the
    next iteration releases the previous objects. The test asserts nothing
    itself — it only fails if construction raises or the process runs out
    of memory.
    """
    # Use `_` for the counter; the original reused `i` for the view, which
    # clobbered the loop variable.
    for _ in range(1024):
        procmesh = pfft.ProcMesh(np=[1, 1], comm=comm)

        partition = pfft.Partition(pfft.Type.PFFT_C2C, [128, 128, 128],
                                   procmesh, pfft.Flags.PFFT_TRANSPOSED_OUT)

        buffer = pfft.LocalBuffer(partition)
        # FIXME: check with @mpip if this is correct.
        view = buffer.view_input()
Esempio n. 8
0
def main():
    """Check Partition.i_edges of a transposed 4x4 C2C partition on 3 ranks.

    Raises:
        RuntimeError: if MPI.COMM_WORLD does not have exactly 3 ranks; the
            expected edge arrays below are only valid for that size.
    """
    comm = MPI.COMM_WORLD
    # this must run with comm.size == 3. Checked explicitly rather than with
    # `assert`, which is stripped when Python runs with -O.
    if comm.size != 3:
        raise RuntimeError(
            "this must run with comm.size == 3, got %d" % comm.size)
    procmesh = pfft.ProcMesh(np=[
        3,
    ], comm=comm)
    partition = pfft.Partition(pfft.Type.PFFT_C2C, [4, 4], procmesh,
                               pfft.Flags.PFFT_TRANSPOSED_OUT)

    assert_array_equal(partition.i_edges[0], [0, 2, 4, 4])
    assert_array_equal(partition.i_edges[1], [0, 4])
Esempio n. 9
0
def test_edges(comm):
    """Check input/output edge arrays of a transposed 4x4 C2C partition.

    NOTE(review): the split edges [0, 2, 4, 4] have one entry per rank plus
    one, i.e. they assume comm.size == 3 — confirm this test only runs on
    3 ranks.
    """
    mesh = pfft.ProcMesh(np=[comm.size], comm=comm)
    partition = pfft.Partition(pfft.Type.PFFT_C2C, [4, 4], mesh,
                               pfft.Flags.PFFT_TRANSPOSED_OUT)

    split_edges = [0, 2, 4, 4]
    whole_edges = [0, 4]

    # Input is split along axis 0; the transposed output along axis 1.
    assert_array_equal(partition.i_edges[0], split_edges)
    assert_array_equal(partition.i_edges[1], whole_edges)

    assert_array_equal(partition.o_edges[1], split_edges)
    assert_array_equal(partition.o_edges[0], whole_edges)
Esempio n. 10
0
def test_edges_padded(comm):
    """Check edge arrays of a transposed, padded 16x8 R2C partition.

    NOTE(review): every expected edge array has only two entries, which
    corresponds to an undecomposed axis; presumably this test is meant to
    run with comm.size == 1 — confirm.
    """
    mesh = pfft.ProcMesh(np=[comm.size], comm=comm)
    flags = pfft.Flags.PFFT_TRANSPOSED_OUT | pfft.Flags.PFFT_PADDED_R2C
    partition = pfft.Partition(pfft.Type.PFFT_R2C, [16, 8], mesh, flags)

    assert_array_equal(partition.i_edges[0], [0, 16])
    assert_array_equal(partition.i_edges[1], [0, 8])

    # The complex output keeps 16 rows but only 8 // 2 + 1 = 5 columns.
    assert_array_equal(partition.o_edges[0], [0, 16])
    assert_array_equal(partition.o_edges[1], [0, 5])
Esempio n. 11
0
def test_transpose_2d_decom(comm):
    """Byte strides of input/output views for a 4D C2C transform on a 1x1 mesh.

    With PFFT_TRANSPOSED_OUT only the leading stride differs between the
    input view (384) and the output view (64).
    """
    mesh = pfft.ProcMesh(np=[1, 1], comm=comm)
    shape = (1, 2, 3, 4)
    flags = pfft.Flags.PFFT_ESTIMATE | pfft.Flags.PFFT_TRANSPOSED_OUT
    partition = pfft.Partition(pfft.Type.PFFT_C2C, shape, mesh, flags=flags)

    local = pfft.LocalBuffer(partition)
    input_view = local.view_input()
    assert_array_equal(input_view.strides, [384, 192, 64, 16])
    output_view = local.view_output()
    assert_array_equal(output_view.strides, [64, 192, 64, 16])
Esempio n. 12
0
def test_correct_single(comm):
    """A 2x2 C2C forward FFT through pfft on a one-process mesh matches numpy.

    Uses a tolerance-based comparison (assert_almost_equal) rather than
    exact equality, since pfft and numpy.fft are different FFT
    implementations and need not agree bit-for-bit on floating-point output.
    """
    procmesh = pfft.ProcMesh(np=[1], comm=comm)

    partition = pfft.Partition(pfft.Type.PFFT_C2C, [2, 2],
                               procmesh,
                               flags=pfft.Flags.PFFT_ESTIMATE)

    buffer1 = pfft.LocalBuffer(partition)
    buffer2 = pfft.LocalBuffer(partition)

    plan = pfft.Plan(partition, pfft.Direction.PFFT_FORWARD, buffer1, buffer2)
    buffer1.view_input()[:] = numpy.arange(4).reshape(2, 2)
    # The reference is computed before executing; numpy.fft.fftn returns a
    # new array, so it is unaffected by whatever the plan does to buffer1.
    correct = numpy.fft.fftn(buffer1.view_input())
    plan.execute(buffer1, buffer2)

    assert_almost_equal(correct, buffer2.view_output())
Esempio n. 13
0
def test_transpose_3d_decom(comm):
    """Byte strides of input/output views for a 5D C2C transform on a 1x1x1 mesh.

    With PFFT_TRANSPOSED_OUT only the leading stride differs between the
    input view (1920) and the output view (80).
    """
    mesh = pfft.ProcMesh(np=[1, 1, 1], comm=comm)
    shape = (1, 2, 3, 4, 5)
    flags = pfft.Flags.PFFT_ESTIMATE | pfft.Flags.PFFT_TRANSPOSED_OUT
    partition = pfft.Partition(pfft.Type.PFFT_C2C, shape, mesh, flags=flags)

    local = pfft.LocalBuffer(partition)
    # FIXME: check with @mpip if this is correct.
    input_view = local.view_input()
    assert_array_equal(input_view.strides, [1920, 960, 320, 80, 16])
    output_view = local.view_output()
    assert_array_equal(output_view.strides, [80, 960, 320, 80, 16])
Esempio n. 14
0
def test_plan_backward(comm):
    """Plans derive direction-specific type and transpose flags from the partition.

    For an R2C partition with PFFT_TRANSPOSED_OUT:
    - the forward plan keeps type R2C and PFFT_TRANSPOSED_OUT;
    - the backward plan reports type C2R and PFFT_TRANSPOSED_IN.
    """
    mesh = pfft.ProcMesh(np=[1], comm=comm)
    flags = pfft.Flags.PFFT_ESTIMATE | pfft.Flags.PFFT_TRANSPOSED_OUT
    partition = pfft.Partition(pfft.Type.PFFT_R2C, [2, 2], mesh, flags=flags)

    src = pfft.LocalBuffer(partition)
    dst = pfft.LocalBuffer(partition)

    plan = pfft.Plan(partition, pfft.Direction.PFFT_FORWARD, src, dst)
    assert plan.flags & pfft.Flags.PFFT_TRANSPOSED_OUT
    assert plan.type == pfft.Type.PFFT_R2C

    plan = pfft.Plan(partition, pfft.Direction.PFFT_BACKWARD, src, dst)
    assert plan.flags & pfft.Flags.PFFT_TRANSPOSED_IN
    assert plan.type == pfft.Type.PFFT_C2R
Esempio n. 15
0
def test_transposed(comm):
    """Layout of the views of a transposed 4x8 C2C partition on one rank.

    Both views have logical shape (4, 8). The input view is C-ordered with
    byte strides (128, 16). The output view has strides (16, 64), i.e. it is
    stored column-major.
    """
    mesh = pfft.ProcMesh(np=[1], comm=comm)
    partition = pfft.Partition(pfft.Type.PFFT_C2C, [4, 8], mesh,
                               pfft.Flags.PFFT_TRANSPOSED_OUT)

    local = pfft.LocalBuffer(partition)
    output_view = local.view_output()
    input_view = local.view_input()

    assert_array_equal(input_view.shape, (4, 8))
    assert_array_equal(input_view.strides, (128, 16))
    assert_array_equal(output_view.shape, (4, 8))
    assert_array_equal(output_view.strides, (16, 64))

    complex128 = numpy.dtype('complex128')
    assert output_view.dtype == complex128
    assert input_view.dtype == complex128
Esempio n. 16
0
def test_padded(comm):
    """Layout of the views of a transposed, padded 4x8 R2C partition on one rank.

    The real input view has shape (4, 8) with an 80-byte row stride, i.e.
    rows padded to 10 float64 values. The complex output view has
    8 // 2 + 1 = 5 columns and is stored column-major (strides (16, 64)).
    """
    mesh = pfft.ProcMesh(np=[1], comm=comm)
    flags = pfft.Flags.PFFT_TRANSPOSED_OUT | pfft.Flags.PFFT_PADDED_R2C
    partition = pfft.Partition(pfft.Type.PFFT_R2C, [4, 8], mesh, flags)

    local = pfft.LocalBuffer(partition)
    real_view = local.view_input()
    spectral_view = local.view_output()

    assert_array_equal(real_view.shape, (4, 8))
    assert_array_equal(real_view.strides, (80, 8))
    assert_array_equal(spectral_view.shape, (4, 5))
    assert_array_equal(spectral_view.strides, (16, 64))

    assert real_view.dtype == numpy.dtype('float64')
    assert spectral_view.dtype == numpy.dtype('complex128')
Esempio n. 17
0
    def __init__(self, BoxSize, Nmesh, paintbrush='cic', comm=None, np=None, verbose=False, dtype='f8'):
        """ create a PM object.

        Parameters
        ----------
        BoxSize : float or array_like
            Physical box size; broadcast into a length-3 float64 array.
        Nmesh : int
            Cells per dimension; the mesh is Nmesh x Nmesh x Nmesh.
        paintbrush : str
            'cic' or 'tsc' (case-insensitive); selects cic.paint or tsc.paint.
        comm : MPI communicator or None
            None means MPI.COMM_WORLD.
        np : sequence of int or None
            Process-mesh shape; defaults to pfft.split_size_2d(comm.size).
        verbose : bool
            Stored as ``self.verbose``.
        dtype : 'f8' or 'f4'
            Selects double- or single-precision R2C/C2R transform types.

        Raises
        ------
        ValueError
            If dtype is not f8/f4 or paintbrush is not 'cic'/'tsc'.
        """
        # This weird sequence to initialize comm is because
        # we want to be compatible with None comm == MPI.COMM_WORLD
        # while not relying on pfft's full mpi4py compatibility
        # (passing None through to pfft).
        if comm is None:
            self.comm = MPI.COMM_WORLD
        else:
            self.comm = comm
        if np is None:
            np = pfft.split_size_2d(self.comm.size)

        dtype = numpy.dtype(dtype)
        if dtype == numpy.dtype('f8'):
            forward = pfft.Type.PFFT_R2C
            backward = pfft.Type.PFFT_C2R
        elif dtype == numpy.dtype('f4'):
            forward = pfft.Type.PFFTF_R2C
            backward = pfft.Type.PFFTF_C2R
        else:
            raise ValueError("dtype must be f8 or f4")

        # Hand pfft the resolved communicator, never the raw (possibly None)
        # argument, as the comment above intends.
        self.procmesh = pfft.ProcMesh(np, comm=self.comm)
        self.Nmesh = Nmesh
        self.BoxSize = numpy.empty(3, dtype='f8')
        self.BoxSize[:] = BoxSize
        self.partition = pfft.Partition(forward,
            [Nmesh, Nmesh, Nmesh],
            self.procmesh,
            pfft.Flags.PFFT_TRANSPOSED_OUT | pfft.Flags.PFFT_DESTROY_INPUT)

        # Both the real-space and the spectral views come from this single
        # LocalBuffer; the plans below are built on their `.base`, which is
        # presumably this same buffer, i.e. the transforms run in place.
        buffer = pfft.LocalBuffer(self.partition)
        self.real = buffer.view_input()
        self.real[:] = 0

        self.complex = buffer.view_output()

        self.T = Timers(self.comm)
        with self.T['Plan']:
            self.forward = pfft.Plan(self.partition, pfft.Direction.PFFT_FORWARD,
                    self.real.base, self.complex.base, forward,
                    pfft.Flags.PFFT_ESTIMATE | pfft.Flags.PFFT_TRANSPOSED_OUT | pfft.Flags.PFFT_DESTROY_INPUT)
            self.backward = pfft.Plan(self.partition, pfft.Direction.PFFT_BACKWARD,
                    self.complex.base, self.real.base, backward,
                    pfft.Flags.PFFT_ESTIMATE | pfft.Flags.PFFT_TRANSPOSED_IN | pfft.Flags.PFFT_DESTROY_INPUT)

        self.domain = domain.GridND(self.partition.i_edges, comm=self.comm)
        self.verbose = verbose
        self.stack = []

        # Per-axis coordinate arrays, each reshaped to broadcast along axis d:
        #   r: signed grid index of the local input slab
        #   x: physical position, r * BoxSize / Nmesh
        #   w: angular frequency in grid units, 2*pi * n / Nmesh
        #   k: physical wavenumber, w * Nmesh / BoxSize
        k = []
        x = []
        w = []
        r = []

        for d in range(self.partition.ndim):
            t = numpy.ones(self.partition.ndim, dtype='intp')
            s = numpy.ones(self.partition.ndim, dtype='intp')
            t[d] = self.partition.local_ni[d]
            s[d] = self.partition.local_no[d]
            wi = numpy.arange(s[d], dtype='f4') + self.partition.local_o_start[d]
            ri = numpy.arange(t[d], dtype='f4') + self.partition.local_i_start[d]

            # Wrap indices >= Nmesh // 2 to negative values.
            wi[wi >= self.Nmesh // 2] -= self.Nmesh
            ri[ri >= self.Nmesh // 2] -= self.Nmesh

            wi *= (2 * numpy.pi / self.Nmesh)
            ki = wi * self.Nmesh / self.BoxSize[d]
            xi = ri * self.BoxSize[d] / self.Nmesh

            w.append(wi.reshape(s))
            r.append(ri.reshape(t))
            k.append(ki.reshape(s))
            x.append(xi.reshape(t))

        self.w = w
        self.r = r
        self.k = k
        self.x = x

        # Set the painter. Compare the normalized name, so that e.g. 'CIC'
        # is accepted, consistent with what is stored in self.paintbrush.
        self.paintbrush = paintbrush.lower()
        if self.paintbrush == 'cic':
            self.painter = cic.paint
        elif self.paintbrush == 'tsc':
            self.painter = tsc.paint
        else:
            raise ValueError("valid `painter` values are: ['cic', 'tsc']")
Esempio n. 18
0
from mpi4py import MPI
import numpy
import pfft

# In-place C2C example on an 8x8 array with a 1D (slab) process mesh.
# NOTE: uses Python 2 print statements; this does not parse under Python 3.
# Rank 0 prints a description of the example.
if MPI.COMM_WORLD.rank == 0:
    print \
        """
This example performs a in-place transform, with a naive slab decomposition.

In place transform is achieved by providing a single buffer object to pfft.Plan.
Consequently, calls to plan.execute we also provide only a single buffer object.
"""

# NOTE(review): np=[4] presumably requires running with exactly 4 MPI ranks.
procmesh = pfft.ProcMesh([4], comm=MPI.COMM_WORLD)
partition = pfft.Partition(
    pfft.Type.PFFT_C2C, [8, 8], procmesh,
    pfft.Flags.PFFT_TRANSPOSED_OUT | pfft.Flags.PFFT_DESTROY_INPUT)
# Ranks take turns: every rank enters the barrier each round and only the
# rank whose turn it is prints its partition info. This is intended to keep
# the output ordered by rank.
for irank in range(4):
    MPI.COMM_WORLD.barrier()
    if irank != procmesh.rank:
        continue
    print 'My rank is', procmesh.this
    print 'local_i_start', partition.local_i_start
    print 'local_o_start', partition.local_o_start
    print 'i_edges', partition.i_edges
    print 'o_edges', partition.o_edges

buffer = pfft.LocalBuffer(partition)

# A single buffer is passed, so the forward plan transforms in place.
plan = pfft.Plan(partition, pfft.Direction.PFFT_FORWARD, buffer)
iplan = pfft.Plan(
Esempio n. 19
0
    def __init__(self, Nmesh, BoxSize=1.0, comm=None, np=None, dtype='f8', plan_method='estimate', resampler='cic'):
        """ create a PM object.

        Parameters
        ----------
        Nmesh : sequence of int
            Mesh size along each axis; its length sets ``self.ndim``.
        BoxSize : float or sequence of float
            Physical box size, broadcast to one value per axis.
        comm : MPI communicator or None
            None means MPI.COMM_WORLD.
        np : sequence of int or None
            Process-mesh shape. When None it is chosen from len(Nmesh):
            - 3+ axes: pfft.split_size_2d(comm.size);
            - 2 axes: [comm.size];
            - 1 axis: [].
        dtype : 'f8' or 'f4'
            Real-space precision; selects the R2C/C2R transform types.
        plan_method : 'estimate', 'measure' or 'exhaustive'
            Mapped to PFFT_ESTIMATE / PFFT_MEASURE / PFFT_EXHAUSTIVE for
            every plan created here.
        resampler : str
            Passed to FindResampler.

        Raises
        ------
        ValueError
            If np is None and Nmesh is empty, or dtype is not f8/f4.
        KeyError
            If plan_method is not one of the recognised names.
        """
        if comm is None:
            comm = MPI.COMM_WORLD

        self.comm = comm

        if np is None:
            if len(Nmesh) >= 3:
                np = pfft.split_size_2d(self.comm.size)
            elif len(Nmesh) == 2:
                np = [self.comm.size]
            elif len(Nmesh) == 1:
                np = []
            else:
                # Without this branch `np` would stay unbound and the
                # ProcMesh call below would fail with UnboundLocalError.
                raise ValueError("Nmesh must have at least one dimension")

        dtype = numpy.dtype(dtype)
        self.dtype = dtype

        if dtype == numpy.dtype('f8'):
            forward = pfft.Type.PFFT_R2C
            backward = pfft.Type.PFFT_C2R
        elif dtype == numpy.dtype('f4'):
            forward = pfft.Type.PFFTF_R2C
            backward = pfft.Type.PFFTF_C2R
        else:
            raise ValueError("dtype must be f8 or f4")

        self.procmesh = pfft.ProcMesh(np, comm=comm)
        self.Nmesh = numpy.array(Nmesh, dtype='i8')
        self.ndim = len(self.Nmesh)
        self.BoxSize = numpy.empty(len(Nmesh), dtype='f8')
        self.BoxSize[:] = BoxSize
        self.partition = pfft.Partition(forward,
            self.Nmesh,
            self.procmesh,
            pfft.Flags.PFFT_TRANSPOSED_OUT | pfft.Flags.PFFT_PADDED_R2C)

        bufferin = pfft.LocalBuffer(self.partition)
        bufferout = pfft.LocalBuffer(self.partition)

        # Translate the planning-effort name into a pfft flag (KeyError on an
        # unknown name). Kept in its own variable instead of rebinding the
        # string parameter.
        plan_flags = {
            "estimate": pfft.Flags.PFFT_ESTIMATE,
            "measure": pfft.Flags.PFFT_MEASURE,
            "exhaustive": pfft.Flags.PFFT_EXHAUSTIVE,
            } [plan_method]

        # Out-of-place plans: bufferin -> bufferout and back.
        self.forward = pfft.Plan(self.partition, pfft.Direction.PFFT_FORWARD,
                bufferin, bufferout, forward,
                plan_flags | pfft.Flags.PFFT_TRANSPOSED_OUT | pfft.Flags.PFFT_TUNE | pfft.Flags.PFFT_PADDED_R2C)
        self.backward = pfft.Plan(self.partition, pfft.Direction.PFFT_BACKWARD,
                bufferout, bufferin, backward,
                plan_flags | pfft.Flags.PFFT_TRANSPOSED_IN | pfft.Flags.PFFT_TUNE | pfft.Flags.PFFT_PADDED_C2R)

        # In-place plans: the same buffer is both input and output.
        self.ipforward = pfft.Plan(self.partition, pfft.Direction.PFFT_FORWARD,
                bufferin, bufferin, forward,
                plan_flags | pfft.Flags.PFFT_TRANSPOSED_OUT | pfft.Flags.PFFT_TUNE | pfft.Flags.PFFT_PADDED_R2C)
        self.ipbackward = pfft.Plan(self.partition, pfft.Direction.PFFT_BACKWARD,
                bufferout, bufferout, backward,
                plan_flags | pfft.Flags.PFFT_TRANSPOSED_IN | pfft.Flags.PFFT_TUNE | pfft.Flags.PFFT_PADDED_C2R)

        self.domain = domain.GridND(self.partition.i_edges, comm=self.comm)

        # Per-axis arrays, each reshaped to broadcast along axis d only:
        #   i_ind / o_ind: global integer indices of the local input/output
        #   r / x:         signed grid index and physical position (input side)
        #   w / k:         angular frequency in grid units and physical
        #                  wavenumber (output side)
        k = []
        x = []
        w = []
        r = []
        o_ind = []
        i_ind = []

        for d in range(self.partition.ndim):
            t = numpy.ones(self.partition.ndim, dtype='intp')
            s = numpy.ones(self.partition.ndim, dtype='intp')
            t[d] = self.partition.local_i_shape[d]
            s[d] = self.partition.local_o_shape[d]

            i_indi = numpy.arange(t[d], dtype='intp') + self.partition.local_i_start[d]
            o_indi = numpy.arange(s[d], dtype='intp') + self.partition.local_o_start[d]

            wi = numpy.arange(s[d], dtype='f4') + self.partition.local_o_start[d]
            ri = numpy.arange(t[d], dtype='f4') + self.partition.local_i_start[d]

            # Wrap indices >= Nmesh // 2 to negative values.
            wi[wi >= self.Nmesh[d] // 2] -= self.Nmesh[d]
            ri[ri >= self.Nmesh[d] // 2] -= self.Nmesh[d]

            wi *= (2 * numpy.pi / self.Nmesh[d])
            ki = wi * self.Nmesh[d] / self.BoxSize[d]
            xi = ri * self.BoxSize[d] / self.Nmesh[d]

            o_ind.append(o_indi.reshape(s))
            i_ind.append(i_indi.reshape(t))
            w.append(wi.reshape(s))
            r.append(ri.reshape(t))
            k.append(ki.reshape(s))
            x.append(xi.reshape(t))

        self.i_ind = i_ind
        self.o_ind = o_ind
        self.w = w
        self.r = r
        self.k = k
        self.x = x

        # Transform from simulation unit to local grid unit.
        self.affine = Affine(self.partition.ndim,
                    translate=-self.partition.local_i_start,
                    scale=1.0 * self.Nmesh / self.BoxSize,
                    period = self.Nmesh)

        # Transform from global grid unit to local grid unit.
        self.affine_grid = Affine(self.partition.ndim,
                    translate=-self.partition.local_i_start,
                    scale=1.0,
                    period = self.Nmesh)

        self.resampler = FindResampler(resampler)
Esempio n. 20
0
def main(comm):
    """Compute and print the FFT-based gradient of a quadratic 2D field.

    Steps:
    1. Fill phi = dx**2 + dx*dy + dy**2 on an 8x8 mesh distributed over
       ``comm``.
    2. Transform phi forward.
    3. For each axis d, multiply the spectrum by ``1j * k[d]`` and
       transform back.
    4. Print phi and each gradient component with ``cprint``/``gather``
       (helpers defined outside this block).

    NOTE(review): k is in integer index units (no 2*pi/Nmesh factor), and
    no normalization is applied after the backward transform, so the printed
    gradients are presumably scaled relative to the true derivative —
    confirm against pfft's normalization convention.
    """
    Nmesh = [8, 8]

    # 2D process mesh for 3D data, otherwise a 1D mesh over all ranks.
    if len(Nmesh) == 3:
        procmesh = pfft.ProcMesh(pfft.split_size_2d(comm.size), comm=comm)
    else:
        procmesh = pfft.ProcMesh((comm.size, ), comm=comm)

    partition = pfft.Partition(
        pfft.Type.R2C, Nmesh, procmesh, pfft.Flags.PADDED_R2C
        | pfft.Flags.PFFT_TRANSPOSED_OUT | pfft.Flags.DESTROY_INPUT)

    # generate the coordinate support.

    k = [None] * partition.ndim
    x = [None] * partition.ndim
    for d in range(partition.ndim):
        # Local slice of output indices along axis d.
        k[d] = numpy.arange(partition.no[d])[partition.local_o_slice[d]]
        # Indices >= n/2 become negative frequencies.
        k[d][k[d] >= partition.n[d] // 2] -= partition.n[d]
        # set to the right numpy broadcast shape
        k[d] = k[d].reshape(
            [-1 if i == d else 1 for i in range(partition.ndim)])

        # Local slice of input (grid) indices along axis d.
        x[d] = numpy.arange(partition.ni[d])[partition.local_i_slice[d]]
        # set to the right numpy broadcast shape
        x[d] = x[d].reshape(
            [-1 if i == d else 1 for i in range(partition.ndim)])

    # allocate memory
    buffer1 = pfft.LocalBuffer(partition)
    phi_disp = buffer1.view_input()

    buffer2 = pfft.LocalBuffer(partition)
    phi_spec = buffer2.view_output()

    # forward plan
    # The same buffer is passed as input and output, so this plan is in place.
    disp_to_spec_inplace = pfft.Plan(
        partition,
        pfft.Direction.PFFT_FORWARD,
        buffer2,
        buffer2,
        # the two lines below not needed after version 0.1.21
        # type=pfft.Type.R2C,
        # flags=pfft.Flags.TRANSPOSED_OUT | pfft.Flags.DESTROY_INPUT | pfft.Flags.PADDED_R2C
    )

    buffer3 = pfft.LocalBuffer(partition)
    grad_spec = buffer3.view_output()

    buffer4 = pfft.LocalBuffer(partition)
    grad_disp = buffer4.view_input()

    # backward plan
    # Out of place: spectrum in buffer3 -> real-space gradient in buffer4.
    spec_to_disp = pfft.Plan(
        partition,
        pfft.Direction.PFFT_BACKWARD,
        buffer3,
        buffer4,
        # the two lines below not needed after version 0.1.21
        # type=pfft.Type.C2R,
        # flags=pfft.Flags.TRANSPOSED_IN | pfft.Flags.DESTROY_INPUT | pfft.Flags.PADDED_C2R
    )

    # to do : fill in initial value
    # Offsets from the mesh centre; only x[0] and x[1] are used, so this
    # initial value is 2D-only.
    dx = x[0] - Nmesh[0] * 0.5 + 0.5
    dy = x[1] - Nmesh[1] * 0.5 + 0.5
    phi_disp[...] = dx**2 + dx * dy + dy**2

    cprint('phi =', gather(partition, phi_disp).round(2), comm=comm)

    # copy in to the buffer for inplace transform
    # this preserves value of phi_disp
    # (phi_spec.base is presumably buffer2, since phi_spec came from
    # buffer2.view_output().)
    phi_spec.base.view_input()[...] = phi_disp
    disp_to_spec_inplace.execute(phi_spec.base, phi_spec.base)

    # One real-space gradient component per axis.
    all_grad_disp = numpy.zeros([partition.ndim] + list(phi_disp.shape),
                                dtype=grad_disp.dtype)

    #    cprint('phi_k =', gather(partition, phi_spec, mode='output').round(2), comm=comm)

    for d in range(partition.ndim):
        # Multiply by i*k along axis d, then transform back to real space.
        grad_spec[...] = phi_spec[...] * (k[d] * 1j)
        spec_to_disp.execute(grad_spec.base, grad_disp.base)
        # copy the gradient along d th direction
        all_grad_disp[d] = grad_disp

    # now do your thing.

    for d in range(partition.ndim):
        cprint('dim =',
               gather(partition, all_grad_disp[d]).round(2),
               comm=comm)