Esempio n. 1
0
class hwconv:
    """Padded pseudo-spectral evaluation of nonlinear terms with mpi4py-fft.

    ``convolve`` maps a spectral input to ``num_in`` spectral work fields via
    ``fmultin`` and transforms them to (padded) real space. It then combines
    them with ``fmultout`` and transforms the first ``num_out`` fields back to
    spectral space.

    Args:
        Nx: global grid size along the first transformed axis (axis 1).
        Ny: global real-space grid size along the second transformed axis
            (axis 2, the real-to-complex axis; ``Ny // 2 + 1`` modes are kept).
        padx: padding factor for axis 1 (1.5 by default).
        pady: padding factor for axis 2 (1.5 by default).
        num_in: number of fields transformed to real space.
        num_out: number of fields transformed back to spectral space.
        fmultin: callable invoked as ``fmultin(u, datk, kx, ky, ksqr)``;
            presumably fills ``datk`` from ``u`` -- TODO confirm.
        fmultout: callable invoked as ``fmultout(dat)``; presumably combines
            the real-space fields in place -- TODO confirm.
        comm: MPI communicator used by both PFFT objects.
    """

    def __init__(self,
                 Nx,
                 Ny,
                 padx=1.5,
                 pady=1.5,
                 num_in=6,
                 num_out=2,
                 fmultin=mult_in,
                 fmultout=mult_out62,
                 comm=MPI.COMM_WORLD):
        self.comm = comm
        # Keep the grid size on the instance: convolve() needs Nx to locate the
        # Nyquist row and must not depend on a module-level global.
        self.Nx = Nx
        self.Ny = Ny
        self.num_in = num_in
        self.num_out = num_out
        self.nar = max(num_in, num_out)
        # Axis 0 enumerates fields and is neither transformed nor padded;
        # padx/pady control the padding of the two transformed axes.
        self.ffti = PFFT(comm,
                         shape=(self.num_in, Nx, Ny),
                         axes=(1, 2),
                         grid=[1, -1, 1],
                         padding=[1, padx, pady],
                         collapse=False)
        self.ffto = PFFT(comm,
                         shape=(self.num_out, Nx, Ny),
                         axes=(1, 2),
                         grid=[1, -1, 1],
                         padding=[1, padx, pady],
                         collapse=False)
        self.datk = newDistArray(self.ffti, forward_output=True)
        self.dat = newDistArray(self.ffti, forward_output=False)
        # Integer wavenumbers: FFT ordering (0..Nx/2-1, -Nx/2..-1) along x and
        # the non-negative half 0..Ny/2 along the real-to-complex y axis.
        lkx = np.r_[0:Nx // 2, -(Nx // 2):0]
        lky = np.r_[0:Ny // 2 + 1]
        self.kx = DistArray((Nx, Ny // 2 + 1),
                            subcomm=(1, 0),
                            dtype=float,
                            alignment=0)
        self.ky = DistArray((Nx, Ny // 2 + 1),
                            subcomm=(1, 0),
                            dtype=float,
                            alignment=0)
        # Fill only the locally owned part of the wavenumber grids.
        self.kx[:], self.ky[:] = np.meshgrid(lkx[self.kx.local_slice()[0]],
                                             lky[self.ky.local_slice()[1]],
                                             indexing='ij')
        self.ksqr = self.kx**2 + self.ky**2
        self.fmultin = fmultin
        self.fmultout = fmultout

    def convolve(self, u):
        """Evaluate the padded nonlinear term for the spectral input ``u``.

        Args:
            u: spectral DistArray indexed as ``u[component, kx, ky]`` (assumed
                to be laid out like ``self.datk`` -- TODO confirm). It is
                modified in place: Hermitian symmetrization and Nyquist zeroing.

        Returns:
            A view of the first ``num_out`` components of ``self.datk``. The
            buffer is reused and is overwritten by the next call.
        """
        hermitian_symmetrize(u)
        self._zero_nyquist(u)
        self.fmultin(u, self.datk, self.kx, self.ky, self.ksqr)
        self.ffti.backward(self.datk, self.dat)
        self.fmultout(self.dat)
        self.ffto.forward(self.dat[:self.num_out, ],
                          self.datk[:self.num_out, ])
        self._zero_nyquist(self.datk)
        return self.datk[:self.num_out, ]

    def _zero_nyquist(self, arr):
        """Zero the Nyquist modes of the spectral DistArray ``arr`` in place."""
        # The last ky index of the real-to-complex axis exists only on the rank
        # whose local slice reaches the global end of axis 2.
        if arr.local_slice()[2].stop == arr.global_shape[2]:
            arr[:, :, -1] = 0
        # Axis 1 is indexed with the global kx index. This assumes axis 1 is not
        # distributed in spectral space, as the original code did.
        arr[:, self.Nx // 2, :] = 0
Esempio n. 2
0
        U_pad[j] = FFT_pad.backward(U_hat[j], U_pad[j])

    curl_pad[:] = compute_curl(U_hat, curl_pad)
    rhs = cross(U_pad, curl_pad, rhs)
    P_hat[:] = np.sum(rhs * K_over_K2, 0, out=P_hat)
    rhs -= P_hat * K
    rhs -= nu * K2 * U_hat
    return rhs


# Initialize a Taylor-Green vortex: velocity components evaluated on the
# real-space mesh X (presumably the local part of a 3D periodic grid -- confirm).
U[0] = np.sin(X[0]) * np.cos(X[1]) * np.cos(X[2])
U[1] = -np.cos(X[0]) * np.sin(X[1]) * np.cos(X[2])
U[2] = 0
# Transform each velocity component to spectral space.
for i in range(3):
    U_hat[i] = FFT.forward(U[i], U_hat[i])

# Integrate using a 4th order Runge-Kutta method
# (a and b are assumed to be the RK4 accumulation/stage coefficients defined earlier -- confirm)
t = 0.0
tstep = 0
t0 = time()
# The 1e-8 tolerance avoids an extra step when rounding leaves t just below T.
while t < T - 1e-8:
    t += dt
    tstep += 1
    # U_hat0 holds the state at the start of the step; U_hat1 accumulates the
    # weighted stage contributions that form the new state.
    U_hat1[:] = U_hat0[:] = U_hat
    for rk in range(4):
        # compute_rhs presumably evaluates the RHS from the current U_hat into dU
        dU = compute_rhs(dU)
        if rk < 3:
            # Input state for the next stage evaluation.
            U_hat[:] = U_hat0 + b[rk] * dt * dU
        U_hat1[:] += a[rk] * dt * dU
    # Accept the step.
    U_hat[:] = U_hat1[:]
Esempio n. 3
0
           grid=(-1, ),
           transforms=transforms)
# Same transform as `fft`, but axis 0 is padded by a factor 1.5 in real space.
pfft = PFFT(MPI.COMM_WORLD,
            N,
            axes=((0, ), (1, 2)),
            grid=(-1, ),
            padding=[1.5, 1.0, 1.0],
            transforms=transforms)

# Padding must not change which axes are transformed.
assert fft.axes == pfft.axes

# Random real-space test field on the locally owned part of the grid.
u = newDistArray(fft, forward_output=False)
u[:] = np.random.random(u.shape).astype(u.dtype)

# Unpadded round trip must reproduce the input.
u_hat = newDistArray(fft, forward_output=True)
u_hat = fft.forward(u, u_hat)
uj = np.zeros_like(u)
uj = fft.backward(u_hat, uj)
assert np.allclose(uj, u)

# Padded round trip: spectral -> padded real space -> spectral must reproduce
# the spectral coefficients. u_hat is overwritten by pfft.forward, so a copy
# (uc) is kept for the comparison.
u_padded = newDistArray(pfft, forward_output=False)
uc = u_hat.copy()
u_padded = pfft.backward(u_hat, u_padded)
u_hat = pfft.forward(u_padded, u_hat)
assert np.allclose(u_hat, uc)

# Complex-to-complex transform (a padded variant is kept commented out below).
#cfft = PFFT(MPI.COMM_WORLD, N, dtype=complex, padding=[1.5, 1.5, 1.5])
cfft = PFFT(MPI.COMM_WORLD, N, dtype=complex)

# Random spectral input; only the real part is random, the imaginary part is 0.
uc = np.random.random(cfft.backward.input_array.shape).astype(complex)
u2 = cfft.backward(uc)
Esempio n. 4
0
          grid=[1, -1, 1],
          padding=[1, 1.5, 1.5],
          collapse=False)
# Real-space (padded) work array and its spectral counterpart.
u = newDistArray(pf, forward_output=False)
uk = newDistArray(pf, forward_output=True)

# Global coordinates for all `howmany` components; Npx/Npy are presumably the
# padded real-space grid sizes matching u's global shape -- confirm.
n, x, y = np.meshgrid(np.arange(0, howmany),
                      np.linspace(-1, 1, Npx),
                      np.linspace(-1, 1, Npy),
                      indexing='ij')
# Restrict the global coordinates to the part owned by this rank.
nl, xl, yl = n[u.local_slice()], x[u.local_slice()], y[u.local_slice()]
# Test field: plane wave sin(4*pi*(x + 2y)) under a Gaussian envelope,
# scaled by (n - 3) for component n.
u[:] = np.sin(4 * np.pi * (xl + 2 * yl)) * np.exp(-xl**2 / 2 / 0.04 -
                                                  yl**2 / 2 / 0.08) * (nl - 3)
u0 = u.copy()

# Round trip through the padded transform; u is overwritten with the result.
pf.forward(u, uk)
pf.backward(uk, u)

# Plot the round-trip error of component 0.
plt.figure()
plt.pcolormesh(xl[0, ].T,
               yl[0, ].T,
               u0[0, ].T - u[0, ].T,
               cmap='twilight_shifted',
               rasterized=True)
plt.colorbar()
plt.axis('square')
plt.axis([-1, 1, -1, 1])

# Keep the once-transformed field for a later comparison (presumably after a
# second round trip that continues beyond this excerpt).
u1 = u.copy()

pf.forward(u, uk)
Esempio n. 5
0
class allencahn_imex(ptype):
    """
    Example implementing Allen-Cahn equation in 2-3D using mpi4py-fft for solving linear parts, IMEX time-stepping

    mpi4py-fft: https://mpi4py-fft.readthedocs.io/en/latest/

    Attributes:
        fft: fft object
        X: grid coordinates in real space
        K2: squared wavenumber magnitude |k|^2 in spectral space (the spectral Laplace operator is -K2)
        dx: mesh width in x direction
        dy: mesh width in y direction
    """
    def __init__(self,
                 problem_params,
                 dtype_u=parallel_mesh,
                 dtype_f=parallel_imex_mesh):
        """
        Initialization routine

        Args:
            problem_params (dict): custom parameters for the example; missing optional entries
                ('L', 'init_type', 'comm', 'dw') are filled in with defaults in place
            dtype_u: fft data type (will be passed to parent class)
            dtype_f: fft data type with implicit and explicit parts (will be passed to parent class)

        Raises:
            ParameterError: if an essential parameter is missing
            ProblemError: if nvars is not a tuple with at least two entries
        """

        if 'L' not in problem_params:
            problem_params['L'] = 1.0
        if 'init_type' not in problem_params:
            problem_params['init_type'] = 'circle'
        if 'comm' not in problem_params:
            problem_params['comm'] = None
        if 'dw' not in problem_params:
            problem_params['dw'] = 0.0

        # these parameters will be used later, so assert their existence
        essential_keys = ['nvars', 'eps', 'L', 'radius', 'dw', 'spectral']
        for key in essential_keys:
            if key not in problem_params:
                msg = 'need %s to instantiate problem, only got %s' % (
                    key, str(problem_params.keys()))
                raise ParameterError(msg)

        if not (isinstance(problem_params['nvars'], tuple)
                and len(problem_params['nvars']) > 1):
            raise ProblemError('Need at least two dimensions')

        # Creating FFT structure (real-to-complex over all axes).
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        ndim = len(problem_params['nvars'])
        axes = tuple(range(ndim))
        self.fft = PFFT(problem_params['comm'],
                        list(problem_params['nvars']),
                        axes=axes,
                        dtype=np.float64,
                        collapse=True)

        # get test data to figure out type and dimensions
        tmp_u = newDistArray(self.fft, problem_params['spectral'])

        # invoke super init, passing the communicator and the local dimensions as init
        super(allencahn_imex,
              self).__init__(init=(tmp_u.shape, problem_params['comm'],
                                   tmp_u.dtype),
                             dtype_u=dtype_u,
                             dtype_f=dtype_f,
                             params=problem_params)

        L = np.array([self.params.L] * ndim, dtype=float)

        # global real-space grid size
        N = self.fft.global_shape()

        # get local mesh; converted to a list so the per-axis scaling below can
        # assign into it whether np.ogrid yields a list or a tuple
        X = list(np.ogrid[self.fft.local_slice(False)])
        for i in range(len(N)):
            X[i] = (X[i] * L[i] / N[i])
        self.X = [np.broadcast_to(x, self.fft.shape(False)) for x in X]

        # get local wavenumbers and Laplace operator; the last axis only holds
        # the non-negative frequencies of the real-to-complex transform
        s = self.fft.local_slice()
        k = [np.fft.fftfreq(n, 1. / n).astype(int) for n in N[:-1]]
        k.append(np.fft.rfftfreq(N[-1], 1. / N[-1]).astype(int))
        K = [ki[si] for ki, si in zip(k, s)]
        Ks = np.meshgrid(*K, indexing='ij', sparse=True)
        Lp = 2 * np.pi / L
        for i in range(ndim):
            Ks[i] = (Ks[i] * Lp[i]).astype(float)
        K = [np.broadcast_to(k, self.fft.shape(True)) for k in Ks]
        K = np.array(K).astype(float)
        self.K2 = np.sum(K * K, 0, dtype=float)

        # Need this for diagnostics
        self.dx = self.params.L / problem_params['nvars'][0]
        self.dy = self.params.L / problem_params['nvars'][1]

    def eval_f(self, u, t):
        """
        Routine to evaluate the RHS

        Args:
            u (dtype_u): current values
            t (float): current time

        Returns:
            dtype_f: the RHS (implicit part: Laplacian, explicit part: reaction term)
        """

        f = self.dtype_f(self.init)

        if self.params.spectral:

            f.impl[:] = -self.K2 * u

            if self.params.eps > 0:
                # reaction term is evaluated in real space and transformed back
                tmp = self.fft.backward(u)
                tmpf = - 2.0 / self.params.eps ** 2 * tmp * (1.0 - tmp) * (1.0 - 2.0 * tmp) - \
                    6.0 * self.params.dw * tmp * (1.0 - tmp)
                f.expl[:] = self.fft.forward(tmpf)

        else:

            u_hat = self.fft.forward(u)
            lap_u_hat = -self.K2 * u_hat
            f.impl[:] = self.fft.backward(lap_u_hat, f.impl)

            if self.params.eps > 0:
                f.expl[:] = - 2.0 / self.params.eps ** 2 * u * (1.0 - u) * (1.0 - 2.0 * u) - \
                    6.0 * self.params.dw * u * (1.0 - u)

        return f

    def solve_system(self, rhs, factor, u0, t):
        """
        Simple FFT solver for the diffusion part: solves (I - factor * Laplacian) u = rhs,
        i.e. divides by (1 + factor * K2) in spectral space

        Args:
            rhs (dtype_f): right-hand side for the linear system
            factor (float) : abbrev. for the node-to-node stepsize (or any other factor required)
            u0 (dtype_u): initial guess for the iterative solver (not used here so far)
            t (float): current time (e.g. for time-dependent BCs)

        Returns:
            dtype_u: solution as mesh
        """

        if self.params.spectral:

            me = rhs / (1.0 + factor * self.K2)

        else:

            me = self.dtype_u(self.init)
            rhs_hat = self.fft.forward(rhs)
            rhs_hat /= (1.0 + factor * self.K2)
            me[:] = self.fft.backward(rhs_hat)

        return me

    def u_exact(self, t):
        """
        Routine to compute the exact solution at time t

        Args:
            t (float): current time (only t=0 is supported)

        Returns:
            dtype_u: exact solution

        Raises:
            NotImplementedError: for an unknown init_type
        """

        assert t == 0, 'ERROR: u_exact only valid for t=0'
        me = self.dtype_u(self.init, val=0.0)
        if self.params.init_type == 'circle':
            r2 = (self.X[0] - 0.5)**2 + (self.X[1] - 0.5)**2
            if self.params.spectral:
                tmp = 0.5 * (1.0 + np.tanh((self.params.radius - np.sqrt(r2)) /
                                           (np.sqrt(2) * self.params.eps)))
                me[:] = self.fft.forward(tmp)
            else:
                me[:] = 0.5 * (1.0 + np.tanh(
                    (self.params.radius - np.sqrt(r2)) /
                    (np.sqrt(2) * self.params.eps)))
        elif self.params.init_type == 'circle_rand':
            ndim = len(me.shape)
            L = int(self.params.L)
            # get random radii for circles/spheres (fixed seed for reproducibility)
            np.random.seed(1)
            lbound = 3.0 * self.params.eps
            ubound = 0.5 - self.params.eps
            rand_radii = (ubound - lbound) * np.random.random_sample(
                size=tuple([L] * ndim)) + lbound
            # distribute circles/spheres
            tmp = newDistArray(self.fft, False)
            if ndim == 2:
                for i in range(0, L):
                    for j in range(0, L):
                        # build radius
                        r2 = (self.X[0] + i - L + 0.5)**2 + (self.X[1] + j - L + 0.5)**2
                        # add this blob, shifted by 1 to avoid issues with adding up negative contributions
                        tmp += np.tanh((rand_radii[i, j] - np.sqrt(r2)) /
                                       (np.sqrt(2) * self.params.eps)) + 1
            # normalize to [0,1]
            tmp *= 0.5
            assert np.all(tmp <= 1.0)
            if self.params.spectral:
                me[:] = self.fft.forward(tmp)
            else:
                me[:] = tmp[:]
        else:
            raise NotImplementedError(
                'type of initial value not implemented, got %s' %
                self.params.init_type)

        return me
Esempio n. 6
0
# Local real-space mesh and wavenumber mesh for the distributed transform.
X = get_local_mesh(fft, L)
K = get_local_wavenumbermesh(fft, L)
K = np.array(K).astype(float)
# |k|^2 on the local spectral grid; the spectral Laplacian is -K2.
K2 = np.sum(K * K, 0, dtype=float)

u = newDistArray(fft, False)
print(type(u))
print(u.subcomm)
uex = newDistArray(fft, False)

u[:] = np.sin(2 * np.pi * X[0]) * np.sin(2 * np.pi * X[1])
print(u.shape, X[0].shape)
# Exact Laplacian of u, assuming a unit-length periodic domain per axis.
uex[:] = -2.0 * (2.0 * np.pi)**2 * np.sin(2 * np.pi * X[0]) * np.sin(
    2 * np.pi * X[1])
u_hat = fft.forward(u)

lap_u_hat = -K2 * u_hat

lap_u = np.zeros_like(u)
lap_u = fft.backward(lap_u_hat, lap_u)
# Max-norm error: local maximum first, then the global maximum across ranks.
local_error = np.amax(abs(lap_u - uex))
err = MPI.COMM_WORLD.allreduce(local_error, MPI.MAX)
print('Laplace error:', err)

# Coarser transform (by `ratio` per axis) with a slab decomposition.
# np.float was removed in NumPy 1.24; np.float64 is the same dtype.
ratio = 2
Nc = np.array([nvars // ratio] * ndim, dtype=int)
fftc = PFFT(MPI.COMM_WORLD, Nc, axes=axes, dtype=np.float64, slab=True)
print(Nc, fftc.global_shape())
Xc = get_local_mesh(fftc, L)
Esempio n. 7
0
class nonlinearschroedinger_imex(ptype):
    """
    Example implementing the nonlinear Schrödinger equation in 2-3D using mpi4py-fft for solving linear parts,
    IMEX time-stepping

    mpi4py-fft: https://mpi4py-fft.readthedocs.io/en/latest/

    Attributes:
        fft: fft object
        X: grid coordinates in real space
        K2: squared wavenumber magnitude |k|^2 in spectral space (the spectral Laplace operator is -K2)
    """
    def __init__(self,
                 problem_params,
                 dtype_u=parallel_mesh,
                 dtype_f=parallel_imex_mesh):
        """
        Initialization routine

        Args:
            problem_params (dict): custom parameters for the example; missing optional entries
                ('L', 'comm', 'c') are filled in with defaults in place
            dtype_u: fft data type (will be passed to parent class)
            dtype_f: fft data type with implicit and explicit parts (will be passed to parent class)

        Raises:
            ProblemError: if L is not 2*pi, c is not 0 or 1, or nvars has fewer than two entries
            ParameterError: if an essential parameter is missing
        """

        if 'L' not in problem_params:
            problem_params['L'] = 2.0 * np.pi
        if 'comm' not in problem_params:
            problem_params['comm'] = MPI.COMM_WORLD
        if 'c' not in problem_params:
            problem_params['c'] = 1.0

        if not problem_params['L'] == 2.0 * np.pi:
            raise ProblemError(
                f'Setup not implemented, L has to be 2pi, got {problem_params["L"]}'
            )

        if not (problem_params['c'] == 0.0 or problem_params['c'] == 1.0):
            raise ProblemError(
                f'Setup not implemented, c has to be 0 or 1, got {problem_params["c"]}'
            )

        # these parameters will be used later, so assert their existence
        essential_keys = ['nvars', 'c', 'L', 'spectral']
        for key in essential_keys:
            if key not in problem_params:
                msg = 'need %s to instantiate problem, only got %s' % (
                    key, str(problem_params.keys()))
                raise ParameterError(msg)

        if not (isinstance(problem_params['nvars'], tuple)
                and len(problem_params['nvars']) > 1):
            raise ProblemError('Need at least two dimensions')

        # Creating FFT structure (complex-to-complex over all axes)
        self.ndim = len(problem_params['nvars'])
        axes = tuple(range(self.ndim))
        self.fft = PFFT(problem_params['comm'],
                        list(problem_params['nvars']),
                        axes=axes,
                        dtype=np.complex128,
                        collapse=True)

        # get test data to figure out type and dimensions
        tmp_u = newDistArray(self.fft, problem_params['spectral'])

        # invoke super init, passing the communicator and the local dimensions as init
        super(nonlinearschroedinger_imex,
              self).__init__(init=(tmp_u.shape, problem_params['comm'],
                                   tmp_u.dtype),
                             dtype_u=dtype_u,
                             dtype_f=dtype_f,
                             params=problem_params)

        self.L = np.array([self.params.L] * self.ndim, dtype=float)

        # global grid size
        N = self.fft.global_shape()

        # get local mesh; converted to a list so the per-axis scaling below can
        # assign into it whether np.ogrid yields a list or a tuple
        X = list(np.ogrid[self.fft.local_slice(False)])
        for i in range(len(N)):
            X[i] = (X[i] * self.L[i] / N[i])
        self.X = [np.broadcast_to(x, self.fft.shape(False)) for x in X]

        # get local wavenumbers and Laplace operator; complex-to-complex, so
        # every axis uses the full fftfreq ordering
        s = self.fft.local_slice()
        k = [np.fft.fftfreq(n, 1. / n).astype(int) for n in N]
        K = [ki[si] for ki, si in zip(k, s)]
        Ks = np.meshgrid(*K, indexing='ij', sparse=True)
        Lp = 2 * np.pi / self.L
        for i in range(self.ndim):
            Ks[i] = (Ks[i] * Lp[i]).astype(float)
        K = [np.broadcast_to(k, self.fft.shape(True)) for k in Ks]
        K = np.array(K).astype(float)
        self.K2 = np.sum(K * K, 0, dtype=float)

        # Need this for diagnostics
        self.dx = self.params.L / problem_params['nvars'][0]
        self.dy = self.params.L / problem_params['nvars'][1]

    def eval_f(self, u, t):
        """
        Routine to evaluate the RHS

        Args:
            u (dtype_u): current values
            t (float): current time

        Returns:
            dtype_f: the RHS (implicit part: i * Laplacian, explicit part: cubic nonlinearity)
        """

        f = self.dtype_f(self.init)

        if self.params.spectral:

            f.impl[:] = -self.K2 * 1j * u
            # nonlinearity is evaluated in real space and transformed back
            tmp = self.fft.backward(u)
            tmpf = self.ndim * self.params.c * 2j * np.absolute(tmp)**2 * tmp
            f.expl[:] = self.fft.forward(tmpf)

        else:

            u_hat = self.fft.forward(u)
            lap_u_hat = -self.K2 * 1j * u_hat
            f.impl[:] = self.fft.backward(lap_u_hat, f.impl)
            f.expl[:] = self.ndim * self.params.c * 2j * np.absolute(u)**2 * u

        return f

    def solve_system(self, rhs, factor, u0, t):
        """
        Simple FFT solver for the linear (dispersive) part: solves (I - factor * i * Laplacian) u = rhs,
        i.e. divides by (1 + factor * i * K2) in spectral space

        Args:
            rhs (dtype_f): right-hand side for the linear system
            factor (float) : abbrev. for the node-to-node stepsize (or any other factor required)
            u0 (dtype_u): initial guess for the iterative solver (not used here so far)
            t (float): current time (e.g. for time-dependent BCs)

        Returns:
            dtype_u: solution as mesh
        """

        if self.params.spectral:

            me = rhs / (1.0 + factor * self.K2 * 1j)

        else:

            me = self.dtype_u(self.init)
            rhs_hat = self.fft.forward(rhs)
            rhs_hat /= (1.0 + factor * self.K2 * 1j)
            me[:] = self.fft.backward(rhs_hat)

        return me

    def u_exact(self, t):
        """
        Routine to compute the exact solution at time t, see (1.3) https://arxiv.org/pdf/nlin/0702010.pdf for details

        Args:
            t (float): current time

        Returns:
            dtype_u: exact solution
        """
        def nls_exact_1D(t, x, c):

            ae = 1.0 / np.sqrt(2.0) * np.exp(1j * t)
            if c != 0:
                u = ae * ((np.cosh(t) + 1j * np.sinh(t)) /
                          (np.cosh(t) - 1.0 / np.sqrt(2.0) * np.cos(x)) - 1.0)
            else:
                u = np.sin(x) * np.exp(-t * 1j)

            return u

        me = self.dtype_u(self.init, val=0.0)

        # the multi-dimensional solution is the 1D one evaluated along x_1 + ... + x_d
        # with time scaled by the number of dimensions
        if self.params.spectral:
            tmp = nls_exact_1D(self.ndim * t, sum(self.X), self.params.c)
            me[:] = self.fft.forward(tmp)
        else:
            me[:] = nls_exact_1D(self.ndim * t, sum(self.X), self.params.c)

        return me
Esempio n. 8
0
class pDFT(BaseDFT):
    """
    Distributed Fast Fourier transforms backed by
    :class:`mpi4py_fft.mpifft.PFFT`.

    See :class:`pystella.fourier.dft.BaseDFT`.

    :arg decomp: A :class:`pystella.DomainDecomposition` whose ``proc_shape``
        attribute fixes the shape of the MPI processor grid.

    :arg queue: A :class:`pyopencl.CommandQueue`.

    :arg grid_shape: A 3-:class:`tuple` giving the shape of the position-space
        arrays that are transformed.

    :arg dtype: The datatype of position-space arrays. Momentum-space arrays use
        the complex datatype of the same precision.

    Extra keyword arguments are forwarded to :class:`mpi4py_fft.mpifft.PFFT`
    and take precedence over this class's defaults.

    .. versionchanged:: 2020.1

        Support for complex-to-complex transforms.
    """
    def __init__(self, decomp, queue, grid_shape, dtype, **kwargs):
        self.decomp = decomp
        self.grid_shape = grid_shape
        self.proc_shape = decomp.proc_shape
        self.dtype = np.dtype(dtype)
        self.is_real = self.dtype.kind == "f"

        from pystella.fourier import get_complex_dtype_with_matching_prec
        self.cdtype = get_complex_dtype_with_matching_prec(self.dtype)
        from pystella.fourier import get_real_dtype_with_matching_prec
        self.rdtype = get_real_dtype_with_matching_prec(self.dtype)

        # Slab (1D) decomposition when only the first processor axis is split.
        slab = self.proc_shape[0] > 1 and self.proc_shape[1] == 1

        from mpi4py_fft.pencil import Subcomm
        # Caller-supplied kwargs override these defaults.
        pfft_options = {
            "axes": ([0], [1], [2]),
            "threads": 16,
            "backend": "fftw",
            "collapse": True,
            **kwargs,
        }
        if slab:
            comm = decomp.comm
        else:
            comm = Subcomm(decomp.comm, self.proc_shape)

        from mpi4py_fft import PFFT
        self.fft = PFFT(comm, grid_shape, dtype=dtype, slab=slab,
                        **pfft_options)

        forward = self.fft.forward
        self.fx = forward.input_array
        self.fk = forward.output_array

        momentum_slice = self.fft.local_slice(True)
        self.sub_k = get_sliced_momenta(grid_shape, self.dtype,
                                        momentum_slice, queue)

    @property
    def proc_permutation(self):
        # Flatten the per-transform axis groups, then replay every transfer
        # as a swap of the two axes it exchanges.
        perm = [axis for group in self.fft.axes for axis in group]
        for transfer in self.fft.transfer:
            first, second = transfer.axisA, transfer.axisB
            perm[first], perm[second] = perm[second], perm[first]
        return perm

    def shape(self, forward_output=True):
        return self.fft.shape(forward_output=forward_output)

    def forward_transform(self, fx, fk, **kwargs):
        # Unnormalized forward transform unless the caller says otherwise.
        kwargs.setdefault("normalize", False)
        return self.fft.forward(input_array=fx, output_array=fk, **kwargs)

    def backward_transform(self, fk, fx, **kwargs):
        return self.fft.backward(input_array=fk, output_array=fx, **kwargs)
Esempio n. 9
0
class grayscott_imex_diffusion(ptype):
    """
    Example implementing the Gray-Scott equation in 2-3D using mpi4py-fft for solving linear parts,
    IMEX time-stepping (implicit diffusion, explicit reaction)

    mpi4py-fft: https://mpi4py-fft.readthedocs.io/en/latest/

    Attributes:
        fft: fft object
        X: grid coordinates in real space (box [-L/2, L/2] per axis)
        ndim: number of spatial dimensions
        K2: squared wavenumber magnitude |k|^2 in spectral space
        Ku: Laplace operator in spectral space (u component), i.e. -Du * K2
        Kv: Laplace operator in spectral space (v component), i.e. -Dv * K2
    """
    def __init__(self,
                 problem_params,
                 dtype_u=parallel_mesh,
                 dtype_f=parallel_imex_mesh):
        """
        Initialization routine

        Args:
            problem_params (dict): custom parameters for the example; missing optional entries
                ('L', 'comm') are filled in with defaults in place
            dtype_u: fft data type (will be passed to parent class)
            dtype_f: fft data type with implicit and explicit parts (will be passed to parent class)

        Raises:
            ParameterError: if an essential parameter is missing
            ProblemError: if nvars is not a tuple with at least two entries
        """

        if 'L' not in problem_params:
            problem_params['L'] = 2.0
        if 'comm' not in problem_params:
            problem_params['comm'] = MPI.COMM_WORLD

        # these parameters will be used later, so assert their existence
        essential_keys = ['nvars', 'Du', 'Dv', 'A', 'B', 'spectral']
        for key in essential_keys:
            if key not in problem_params:
                msg = 'need %s to instantiate problem, only got %s' % (
                    key, str(problem_params.keys()))
                raise ParameterError(msg)

        if not (isinstance(problem_params['nvars'], tuple)
                and len(problem_params['nvars']) > 1):
            raise ProblemError('Need at least two dimensions')

        # Creating FFT structure (real-to-complex over all axes)
        self.ndim = len(problem_params['nvars'])
        axes = tuple(range(self.ndim))
        self.fft = PFFT(problem_params['comm'],
                        list(problem_params['nvars']),
                        axes=axes,
                        dtype=np.float64,
                        collapse=True,
                        backend='fftw')

        # get test data to figure out type and dimensions
        tmp_u = newDistArray(self.fft, problem_params['spectral'])

        # add a trailing axis with two components holding the u and v fields
        self.ncomp = 2
        sizes = tmp_u.shape + (self.ncomp, )

        # invoke super init, passing the communicator and the local dimensions as init
        super(grayscott_imex_diffusion,
              self).__init__(init=(sizes, problem_params['comm'], tmp_u.dtype),
                             dtype_u=dtype_u,
                             dtype_f=dtype_f,
                             params=problem_params)

        L = np.array([self.params.L] * self.ndim, dtype=float)

        # global real-space grid size
        N = self.fft.global_shape()

        # get local mesh on [-L/2, L/2); converted to a list so the per-axis
        # scaling below can assign into it whether np.ogrid yields a list or a tuple
        X = list(np.ogrid[self.fft.local_slice(False)])
        for i in range(len(N)):
            X[i] = -L[i] / 2 + (X[i] * L[i] / N[i])
        self.X = [np.broadcast_to(x, self.fft.shape(False)) for x in X]

        # get local wavenumbers and Laplace operator; the last axis only holds
        # the non-negative frequencies of the real-to-complex transform
        s = self.fft.local_slice()
        k = [np.fft.fftfreq(n, 1. / n).astype(int) for n in N[:-1]]
        k.append(np.fft.rfftfreq(N[-1], 1. / N[-1]).astype(int))
        K = [ki[si] for ki, si in zip(k, s)]
        Ks = np.meshgrid(*K, indexing='ij', sparse=True)
        Lp = 2 * np.pi / L
        for i in range(self.ndim):
            Ks[i] = (Ks[i] * Lp[i]).astype(float)
        K = [np.broadcast_to(k, self.fft.shape(True)) for k in Ks]
        K = np.array(K).astype(float)
        self.K2 = np.sum(K * K, 0, dtype=float)
        self.Ku = -self.K2 * self.params.Du
        self.Kv = -self.K2 * self.params.Dv

        # Need this for diagnostics
        self.dx = self.params.L / problem_params['nvars'][0]
        self.dy = self.params.L / problem_params['nvars'][1]

    def eval_f(self, u, t):
        """
        Routine to evaluate the RHS

        Args:
            u (dtype_u): current values, components u[..., 0] (u) and u[..., 1] (v)
            t (float): current time

        Returns:
            dtype_f: the RHS (implicit part: diffusion, explicit part: reaction)
        """

        f = self.dtype_f(self.init)

        if self.params.spectral:

            f.impl[..., 0] = self.Ku * u[..., 0]
            f.impl[..., 1] = self.Kv * u[..., 1]
            # reaction terms are evaluated in real space and transformed back
            tmpu = newDistArray(self.fft, False)
            tmpv = newDistArray(self.fft, False)
            tmpu[:] = self.fft.backward(u[..., 0], tmpu)
            tmpv[:] = self.fft.backward(u[..., 1], tmpv)
            tmpfu = -tmpu * tmpv**2 + self.params.A * (1 - tmpu)
            tmpfv = tmpu * tmpv**2 - self.params.B * tmpv
            f.expl[..., 0] = self.fft.forward(tmpfu)
            f.expl[..., 1] = self.fft.forward(tmpfv)

        else:

            u_hat = self.fft.forward(u[..., 0])
            lap_u_hat = self.Ku * u_hat
            f.impl[..., 0] = self.fft.backward(lap_u_hat, f.impl[..., 0])
            u_hat = self.fft.forward(u[..., 1])
            lap_u_hat = self.Kv * u_hat
            f.impl[..., 1] = self.fft.backward(lap_u_hat, f.impl[..., 1])
            uu = u[..., 0]
            vv = u[..., 1]
            f.expl[..., 0] = -uu * vv**2 + self.params.A * (1 - uu)
            f.expl[..., 1] = uu * vv**2 - self.params.B * vv

        return f

    def solve_system(self, rhs, factor, u0, t):
        """
        Simple FFT solver for the diffusion part: solves (I - factor * D * Laplacian) u = rhs
        per component, i.e. divides by (1 - factor * Ku) resp. (1 - factor * Kv) in spectral space

        Args:
            rhs (dtype_f): right-hand side for the linear system
            factor (float) : abbrev. for the node-to-node stepsize (or any other factor required)
            u0 (dtype_u): initial guess for the iterative solver (not used here so far)
            t (float): current time (e.g. for time-dependent BCs)

        Returns:
            dtype_u: solution as mesh
        """

        me = self.dtype_u(self.init)
        if self.params.spectral:

            me[..., 0] = rhs[..., 0] / (1.0 - factor * self.Ku)
            me[..., 1] = rhs[..., 1] / (1.0 - factor * self.Kv)

        else:

            rhs_hat = self.fft.forward(rhs[..., 0])
            rhs_hat /= (1.0 - factor * self.Ku)
            me[..., 0] = self.fft.backward(rhs_hat, me[..., 0])
            rhs_hat = self.fft.forward(rhs[..., 1])
            rhs_hat /= (1.0 - factor * self.Kv)
            me[..., 1] = self.fft.backward(rhs_hat, me[..., 1])

        return me

    def u_exact(self, t):
        """
        Routine to compute the exact solution at time t=0, see https://www.chebfun.org/examples/pde/GrayScott.html

        Args:
            t (float): current time (only t=0 is supported)

        Returns:
            dtype_u: exact solution
        """
        assert t == 0.0, 'Exact solution only valid as initial condition'
        assert self.ndim == 2, 'The initial conditions are 2D for now..'

        me = self.dtype_u(self.init, val=0.0)

        # This assumes that the box is [-L/2, L/2]^2
        if self.params.spectral:
            tmp = 1.0 - np.exp(-80.0 * ((self.X[0] + 0.05)**2 +
                                        (self.X[1] + 0.02)**2))
            me[..., 0] = self.fft.forward(tmp)
            tmp = np.exp(-80.0 * ((self.X[0] - 0.05)**2 +
                                  (self.X[1] - 0.02)**2))
            me[..., 1] = self.fft.forward(tmp)
        else:
            me[..., 0] = 1.0 - np.exp(-80.0 * ((self.X[0] + 0.05)**2 +
                                               (self.X[1] + 0.02)**2))
            me[..., 1] = np.exp(-80.0 * ((self.X[0] - 0.05)**2 +
                                         (self.X[1] - 0.02)**2))

        return me