Exemplo n.º 1
0
class hwconv:
    """Product-term evaluator built on padded mpi4py-fft transforms.

    NOTE(review): the name suggests a Hasegawa-Wakatani nonlinearity -- confirm.

    ``convolve`` takes a spectral-space array ``u`` and proceeds as follows:
      1. ``fmultin`` fills ``num_in`` spectral fields in ``datk``, using the
         wavenumber arrays ``kx``, ``ky`` and ``ksqr``.
      2. Those fields are transformed to a padded real-space grid ``dat``.
      3. ``fmultout`` combines them in place.
      4. The first ``num_out`` components are transformed back to spectral space.

    Args:
        Nx, Ny: global (unpadded) grid size along x and y.
        padx, pady: padding factors applied to the x and y axes of the
            real-space grid.
        num_in: number of fields filled by ``fmultin``.
        num_out: number of fields transformed back after ``fmultout``.
        fmultin: callable ``fmultin(u, datk, kx, ky, ksqr)``.
        fmultout: callable ``fmultout(dat)``; its return value is ignored.
        comm: MPI communicator.
    """

    def __init__(self,
                 Nx,
                 Ny,
                 padx=1.5,
                 pady=1.5,
                 num_in=6,
                 num_out=2,
                 fmultin=mult_in,
                 fmultout=mult_out62,
                 comm=MPI.COMM_WORLD):
        self.comm = comm
        # Stored so that convolve() does not depend on a global ``Nx``.
        self.Nx = Nx
        self.Ny = Ny
        self.num_in = num_in
        self.num_out = num_out
        self.nar = max(num_in, num_out)
        # Axis 0 enumerates fields and is not transformed, so its padding is 1.
        # padx/pady used to be accepted but ignored (padding was hard-coded).
        self.ffti = PFFT(comm,
                         shape=(self.num_in, Nx, Ny),
                         axes=(1, 2),
                         grid=[1, -1, 1],
                         padding=[1, padx, pady],
                         collapse=False)
        self.ffto = PFFT(comm,
                         shape=(self.num_out, Nx, Ny),
                         axes=(1, 2),
                         grid=[1, -1, 1],
                         padding=[1, padx, pady],
                         collapse=False)
        # NOTE(review): ffto.forward is fed self.dat[:num_out], where self.dat
        # holds only num_in fields, so this presumably requires
        # num_out <= num_in.
        self.datk = newDistArray(self.ffti, forward_output=True)
        self.dat = newDistArray(self.ffti, forward_output=False)
        # NOTE(review): lkx has 2 * (Nx // 2) entries, so an even Nx is assumed.
        half_x = Nx // 2
        nky = Ny // 2 + 1
        lkx = np.r_[0:half_x, -half_x:0]
        lky = np.r_[0:nky]
        self.kx = DistArray((Nx, nky),
                            subcomm=(1, 0),
                            dtype=float,
                            alignment=0)
        self.ky = DistArray((Nx, nky),
                            subcomm=(1, 0),
                            dtype=float,
                            alignment=0)
        self.kx[:], self.ky[:] = np.meshgrid(lkx[self.kx.local_slice()[0]],
                                             lky[self.ky.local_slice()[1]],
                                             indexing='ij')
        self.ksqr = self.kx**2 + self.ky**2
        self.fmultin = fmultin
        self.fmultout = fmultout

    def convolve(self, u):
        """Evaluate the product terms for the spectral-space array ``u``.

        ``u`` is passed to ``hermitian_symmetrize``, presumably modified in
        place. Its last ky column and its ``Nx // 2`` kx row are then set to
        zero.

        Returns:
            ``self.datk[:num_out]``, with the same modes zeroed. It is a view
            into an internal buffer that the next call overwrites.
        """
        hermitian_symmetrize(u)
        # Only the rank that owns the final ky column zeroes it.
        if u.local_slice()[2].stop == u.global_shape[2]:
            u[:, :, -1] = 0
        u[:, self.Nx // 2, :] = 0
        self.fmultin(u, self.datk, self.kx, self.ky, self.ksqr)
        self.ffti.backward(self.datk, self.dat)
        self.fmultout(self.dat)
        self.ffto.forward(self.dat[:self.num_out, ],
                          self.datk[:self.num_out, ])
        if self.datk.local_slice()[2].stop == self.datk.global_shape[2]:
            self.datk[:, :, -1] = 0
        self.datk[:, self.Nx // 2, :] = 0
        return self.datk[:self.num_out, ]
Exemplo n.º 2
0
def test_2D(backend, forward_output):
    """Round-trip 2D DistArrays of rank 0, 1 and 2 through the given backend.

    The test loops over three domain descriptions: none, a bounding box, and
    explicit coordinate arrays. For each rank it writes steps 0, 1 and 2, plus
    scalar copies when rank > 0, then reads the data back and compares.
    """
    is_netcdf = backend == 'netcdf4'
    if is_netcdf:
        assert forward_output is False
    T = PFFT(comm, (N[0], N[1]))
    box = ((0, np.pi), (0, 2 * np.pi))
    coords = (np.arange(N[0], dtype=float) * 1 * np.pi / N[0],
              np.arange(N[1], dtype=float) * 2 * np.pi / N[1])
    for idx, domain in enumerate((None, box, coords)):
        for rank in range(3):
            stem = 'test2D_{}{}{}'.format(ex[idx == 0], ex[forward_output],
                                          rank)
            filename = stem + ending[backend]
            if is_netcdf:
                remove_if_exists(filename)
            u = newDistArray(T, forward_output=forward_output, val=1,
                             rank=rank)
            hfile = writer[backend](filename, domain=domain)
            assert hfile.backend() == backend
            for step in (0, 1):
                hfile.write(step, {'u': [u]})
            u.write(hfile, 'u', 2)
            if rank > 0:
                for step in (0, 1):
                    hfile.write(step, {'u': [u]}, as_scalar=True)
                u.write(hfile, 'u', 2, as_scalar=True)
            tname = 't' + filename
            u.write(tname, 'u', 0)
            u.write(tname, 'u', 0, [slice(None), 3])

            if backend == 'hdf5' and not forward_output and comm.Get_rank() == 0:
                generate_xdmf(filename)
                generate_xdmf(filename, order='visit')

            u0 = newDistArray(T, forward_output=forward_output, rank=rank)
            read = reader[backend](filename)
            read.read(u0, 'u', step=0)
            u0.read(filename, 'u', 2)
            u0.read(read, 'u', 2)
            assert np.allclose(u0, u)
            if is_netcdf:
                # Opening a non-existent file in mode 'a' must succeed.
                remove_if_exists('nctesta.nc')
                _ = NCFile('nctesta.nc', domain=domain, mode='a')
    T.destroy()
Exemplo n.º 3
0
 def __init__(self,
              Nx,
              Ny,
              padx=1.5,
              pady=1.5,
              num_in=6,
              num_out=2,
              fmultin=mult_in,
              fmultout=mult_out62,
              comm=MPI.COMM_WORLD):
     """Set up padded input/output PFFT plans, work arrays and wavenumbers.

     Args:
         Nx, Ny: global (unpadded) grid size along x and y.
         padx, pady: padding factors applied to the x and y axes.
         num_in: number of fields in the input transform.
         num_out: number of fields in the output transform.
         fmultin, fmultout: callables stored for later use.
         comm: MPI communicator.
     """
     self.comm = comm
     self.Nx = Nx
     self.Ny = Ny
     self.num_in = num_in
     self.num_out = num_out
     self.nar = max(num_in, num_out)
     # Axis 0 enumerates fields and is not transformed, so its padding is 1.
     # padx/pady used to be accepted but ignored (padding was hard-coded).
     self.ffti = PFFT(comm,
                      shape=(self.num_in, Nx, Ny),
                      axes=(1, 2),
                      grid=[1, -1, 1],
                      padding=[1, padx, pady],
                      collapse=False)
     self.ffto = PFFT(comm,
                      shape=(self.num_out, Nx, Ny),
                      axes=(1, 2),
                      grid=[1, -1, 1],
                      padding=[1, padx, pady],
                      collapse=False)
     self.datk = newDistArray(self.ffti, forward_output=True)
     self.dat = newDistArray(self.ffti, forward_output=False)
     # NOTE(review): lkx has 2 * (Nx // 2) entries, so an even Nx is assumed.
     half_x = Nx // 2
     nky = Ny // 2 + 1
     lkx = np.r_[0:half_x, -half_x:0]
     lky = np.r_[0:nky]
     self.kx = DistArray((Nx, nky),
                         subcomm=(1, 0),
                         dtype=float,
                         alignment=0)
     self.ky = DistArray((Nx, nky),
                         subcomm=(1, 0),
                         dtype=float,
                         alignment=0)
     self.kx[:], self.ky[:] = np.meshgrid(lkx[self.kx.local_slice()[0]],
                                          lky[self.ky.local_slice()[1]],
                                          indexing='ij')
     self.ksqr = self.kx**2 + self.ky**2
     self.fmultin = fmultin
     self.fmultout = fmultout
Exemplo n.º 4
0
def test_4D(backend, forward_output):
    """Write two 4D fields plus slices of each, then read the full fields back.

    The test loops over no domain, a bounding box, and explicit coordinate
    arrays, and over tensor ranks 0, 1 and 2.
    """
    if backend == 'netcdf4':
        assert forward_output is False
    T = PFFT(comm, (N[0], N[1], N[2], N[3]))
    box = tuple((0, m * np.pi) for m in range(1, 5))
    coords = tuple(np.arange(N[j], dtype=float) * (j + 1) * np.pi / N[j]
                   for j in range(4))
    for idx, domain in enumerate((None, box, coords)):
        for rank in range(3):
            uv_name = 'uv' + 'h5test4_{}{}{}'.format(
                ex[idx == 0], ex[forward_output], rank) + ending[backend]
            if backend == 'netcdf4':
                remove_if_exists(uv_name)
            u = newDistArray(T, forward_output=forward_output, rank=rank)
            v = newDistArray(T, forward_output=forward_output, rank=rank)
            h0file = writer[backend](uv_name, domain=domain)
            u[:] = np.random.random(u.shape)
            v[:] = 2
            for step in range(3):
                u_slice = [slice(None), 4, slice(None), slice(None)]
                v_slice = [slice(None), slice(None), 5, 6]
                h0file.write(step, {'u': [u, (u, u_slice)],
                                    'v': [v, (v, v_slice)]})

            if backend == 'hdf5' and not forward_output and comm.Get_rank() == 0:
                generate_xdmf(uv_name)

            u0 = newDistArray(T, forward_output=forward_output, rank=rank)
            read = reader[backend](uv_name)
            for field, expected in (('u', u), ('v', v)):
                read.read(u0, field, step=0)
                assert np.allclose(u0, expected)
    T.destroy()
Exemplo n.º 5
0
    def __init__(self, decomp, queue, grid_shape, dtype, **kwargs):
        """Create the mpi4py-fft plan for ``grid_shape`` arrays of ``dtype``.

        Extra keyword arguments override the defaults handed to ``PFFT``.
        """
        self.decomp = decomp
        self.grid_shape = grid_shape
        self.proc_shape = decomp.proc_shape
        self.dtype = np.dtype(dtype)
        self.is_real = self.dtype.kind == "f"

        from pystella.fourier import (
            get_complex_dtype_with_matching_prec,
            get_real_dtype_with_matching_prec,
        )
        self.cdtype = get_complex_dtype_with_matching_prec(self.dtype)
        self.rdtype = get_real_dtype_with_matching_prec(self.dtype)

        # If processes are split only along the first axis, use slab mode with
        # the plain communicator. Otherwise build a pencil Subcomm.
        slab = bool(self.proc_shape[0] > 1 and self.proc_shape[1] == 1)

        from mpi4py_fft.pencil import Subcomm
        options = {
            "axes": ([0], [1], [2]),
            "threads": 16,
            "backend": "fftw",
            "collapse": True,
        }
        options.update(kwargs)
        if slab:
            comm = decomp.comm
        else:
            comm = Subcomm(decomp.comm, self.proc_shape)

        from mpi4py_fft import PFFT
        self.fft = PFFT(comm, grid_shape, dtype=dtype, slab=slab, **options)

        forward = self.fft.forward
        self.fx = forward.input_array
        self.fk = forward.output_array

        local_k_slice = self.fft.local_slice(True)
        self.sub_k = get_sliced_momenta(grid_shape, self.dtype, local_k_slice,
                                        queue)
Exemplo n.º 6
0
    def __init__(self, fine_prob, coarse_prob, params):
        """
        Initialization routine

        Sets ``self.spectral`` (shared by both levels), ``self.ratio`` (the
        per-axis fine/coarse size ratio) and ``self.fft_pad`` (a slab PFFT on
        the coarse grid, padded by ``self.ratio``).

        Args:
            fine_prob: fine problem
            coarse_prob: coarse problem
            params: parameters for the transfer operators
        """
        super().__init__(fine_prob, coarse_prob, params)

        # Both levels must store data in the same space (spectral or real).
        assert self.fine_prob.params.spectral == self.coarse_prob.params.spectral
        self.spectral = self.fine_prob.params.spectral

        fine_shape = list(self.fine_prob.fft.global_shape())
        coarse_shape = list(self.coarse_prob.fft.global_shape())
        self.ratio = [int(n_fine / n_coarse)
                      for n_fine, n_coarse in zip(fine_shape, coarse_shape)]

        self.fft_pad = PFFT(self.coarse_prob.params.comm,
                            coarse_shape,
                            padding=self.ratio,
                            axes=tuple(range(len(fine_shape))),
                            dtype=self.coarse_prob.fft.dtype(False),
                            slab=True)
Exemplo n.º 7
0
class allencahn_imex(ptype):
    """
    Example implementing Allen-Cahn equation in 2-3D using mpi4py-fft for solving linear parts, IMEX time-stepping

    mpi4py-fft: https://mpi4py-fft.readthedocs.io/en/latest/

    The diffusion term is treated implicitly. The reaction term
    -2/eps^2 u(1-u)(1-2u) - 6 dw u(1-u) is treated explicitly. If
    ``params.spectral`` is set, values are stored in spectral space;
    otherwise they are stored in real space.

    Attributes:
        fft: fft object
        X: grid coordinates in real space
        K2: squared wavenumber magnitude |k|^2 in spectral space (the spectral Laplacian is -K2)
        dx: mesh width in x direction
        dy: mesh width in y direction
    """
    def __init__(self,
                 problem_params,
                 dtype_u=parallel_mesh,
                 dtype_f=parallel_imex_mesh):
        """
        Initialization routine

        Args:
            problem_params (dict): custom parameters for the example (missing optional entries are filled in place)
            dtype_u: fft data type (will be passed to parent class)
            dtype_f: fft data type with implicit and explicit parts (will be passed to parent class)

        Raises:
            ParameterError: if an essential parameter is missing
            ProblemError: if ``nvars`` is not a tuple with at least two entries
        """

        # defaults for optional parameters
        problem_params.setdefault('L', 1.0)
        problem_params.setdefault('init_type', 'circle')
        problem_params.setdefault('comm', None)
        problem_params.setdefault('dw', 0.0)

        # these parameters will be used later, so assert their existence
        essential_keys = ['nvars', 'eps', 'L', 'radius', 'dw', 'spectral']
        for key in essential_keys:
            if key not in problem_params:
                msg = 'need %s to instantiate problem, only got %s' % (
                    key, str(problem_params.keys()))
                raise ParameterError(msg)

        if not (isinstance(problem_params['nvars'], tuple)
                and len(problem_params['nvars']) > 1):
            raise ProblemError('Need at least two dimensions')

        # Creating FFT structure
        ndim = len(problem_params['nvars'])
        axes = tuple(range(ndim))
        # np.float was only an alias of the builtin float (float64) and was
        # removed in NumPy 1.24, so use np.float64 explicitly.
        self.fft = PFFT(problem_params['comm'],
                        list(problem_params['nvars']),
                        axes=axes,
                        dtype=np.float64,
                        collapse=True)

        # get test data to figure out type and dimensions
        tmp_u = newDistArray(self.fft, problem_params['spectral'])

        # invoke super init, passing the communicator and the local dimensions as init
        super(allencahn_imex,
              self).__init__(init=(tmp_u.shape, problem_params['comm'],
                                   tmp_u.dtype),
                             dtype_u=dtype_u,
                             dtype_f=dtype_f,
                             params=problem_params)

        L = np.array([self.params.L] * ndim, dtype=float)
        N = self.fft.global_shape()

        # Local real-space mesh. Depending on the NumPy version, ogrid may
        # return a tuple, so convert to a list before assigning entries.
        X = list(np.ogrid[self.fft.local_slice(False)])
        for i in range(len(N)):
            X[i] = (X[i] * L[i] / N[i])
        self.X = [np.broadcast_to(x, self.fft.shape(False)) for x in X]

        # Local wavenumbers; the last axis is real-to-complex, hence rfftfreq.
        s = self.fft.local_slice()
        k = [np.fft.fftfreq(n, 1. / n).astype(int) for n in N[:-1]]
        k.append(np.fft.rfftfreq(N[-1], 1. / N[-1]).astype(int))
        K = [ki[si] for ki, si in zip(k, s)]
        Ks = np.meshgrid(*K, indexing='ij', sparse=True)
        Lp = 2 * np.pi / L
        for i in range(ndim):
            Ks[i] = (Ks[i] * Lp[i]).astype(float)
        K = [np.broadcast_to(k, self.fft.shape(True)) for k in Ks]
        K = np.array(K).astype(float)
        self.K2 = np.sum(K * K, 0, dtype=float)

        # Need this for diagnostics
        self.dx = self.params.L / problem_params['nvars'][0]
        self.dy = self.params.L / problem_params['nvars'][1]

    def _reaction_term(self, v):
        """Return the explicit Allen-Cahn reaction term for real-space data ``v``."""
        return - 2.0 / self.params.eps ** 2 * v * (1.0 - v) * (1.0 - 2.0 * v) - \
            6.0 * self.params.dw * v * (1.0 - v)

    def eval_f(self, u, t):
        """
        Routine to evaluate the RHS

        Args:
            u (dtype_u): current values
            t (float): current time

        Returns:
            dtype_f: the RHS (``impl``: diffusion, ``expl``: reaction, the latter only filled if eps > 0)
        """

        f = self.dtype_f(self.init)

        # Results are written into the existing f.impl / f.expl arrays so that
        # both branches keep the mesh types created by dtype_f.
        if self.params.spectral:

            f.impl[:] = -self.K2 * u

            if self.params.eps > 0:
                tmp = self.fft.backward(u)
                f.expl[:] = self.fft.forward(self._reaction_term(tmp))

        else:

            u_hat = self.fft.forward(u)
            lap_u_hat = -self.K2 * u_hat
            f.impl[:] = self.fft.backward(lap_u_hat, f.impl)

            if self.params.eps > 0:
                f.expl[:] = self._reaction_term(u)

        return f

    def solve_system(self, rhs, factor, u0, t):
        """
        Simple FFT solver for the diffusion part, i.e. solves (1 - factor * Laplacian) u = rhs

        Args:
            rhs (dtype_f): right-hand side for the linear system
            factor (float) : abbrev. for the node-to-node stepsize (or any other factor required)
            u0 (dtype_u): initial guess for the iterative solver (not used here so far)
            t (float): current time (e.g. for time-dependent BCs)

        Returns:
            dtype_u: solution as mesh
        """

        if self.params.spectral:

            me = rhs / (1.0 + factor * self.K2)

        else:

            me = self.dtype_u(self.init)
            rhs_hat = self.fft.forward(rhs)
            rhs_hat /= (1.0 + factor * self.K2)
            me[:] = self.fft.backward(rhs_hat)

        return me

    def u_exact(self, t):
        """
        Routine to compute the exact solution at time t

        Args:
            t (float): current time (only t=0 is supported)

        Returns:
            dtype_u: exact solution

        Raises:
            NotImplementedError: for an unknown ``init_type``
        """

        assert t == 0, 'ERROR: u_exact only valid for t=0'
        me = self.dtype_u(self.init, val=0.0)
        if self.params.init_type == 'circle':
            # tanh profile of a circle of the given radius centered at (0.5, 0.5)
            r2 = (self.X[0] - 0.5)**2 + (self.X[1] - 0.5)**2
            tmp = 0.5 * (1.0 + np.tanh((self.params.radius - np.sqrt(r2)) /
                                       (np.sqrt(2) * self.params.eps)))
            if self.params.spectral:
                me[:] = self.fft.forward(tmp)
            else:
                me[:] = tmp
        elif self.params.init_type == 'circle_rand':
            ndim = len(me.shape)
            L = int(self.params.L)
            # get random radii for circles/spheres (fixed seed for reproducibility)
            np.random.seed(1)
            lbound = 3.0 * self.params.eps
            ubound = 0.5 - self.params.eps
            rand_radii = (ubound - lbound) * np.random.random_sample(
                size=tuple([L] * ndim)) + lbound
            # distribute circles/spheres
            tmp = newDistArray(self.fft, False)
            # NOTE(review): blobs are only added for ndim == 2; for other
            # dimensions tmp stays at its initial value.
            if ndim == 2:
                for i in range(0, L):
                    for j in range(0, L):
                        # build radius
                        r2 = (self.X[0] + i - L + 0.5)**2 + (self.X[1] + j -
                                                             L + 0.5)**2
                        # add this blob, shifted by 1 to avoid issues with adding up negative contributions
                        tmp += np.tanh((rand_radii[i, j] - np.sqrt(r2)) /
                                       (np.sqrt(2) * self.params.eps)) + 1
            # normalize to [0,1]
            tmp *= 0.5
            assert np.all(tmp <= 1.0)
            if self.params.spectral:
                me[:] = self.fft.forward(tmp)
            else:
                me[:] = tmp[:]
        else:
            raise NotImplementedError(
                'type of initial value not implemented, got %s' %
                self.params.init_type)

        return me
Exemplo n.º 8
0
    Ks = np.meshgrid(*K, indexing='ij', sparse=True)
    Lp = 2 * np.pi / L
    for i in range(ndim):
        Ks[i] = (Ks[i] * Lp[i]).astype(float)
    return [np.broadcast_to(k, FFT.shape(True)) for k in Ks]


comm = MPI.COMM_WORLD
subcomm = comm.Split()
print(subcomm)
nvars = 8
ndim = 2
axes = tuple(range(ndim))
N = np.array([nvars] * ndim, dtype=int)
print(N, axes)
# np.float (an alias of the builtin float, i.e. float64) was removed in
# NumPy 1.24; np.float64 is the same dtype.
fft = PFFT(subcomm, N, axes=axes, dtype=np.float64, slab=True)
# Alternative domain size: L = np.array([2*np.pi] * ndim, dtype=float)
L = np.array([1] * ndim, dtype=float)

print(fft.subcomm)

# Local real-space mesh and wavenumbers for this rank.
X = get_local_mesh(fft, L)
K = get_local_wavenumbermesh(fft, L)
K = np.array(K).astype(float)
K2 = np.sum(K * K, 0, dtype=float)

u = newDistArray(fft, False)
print(type(u))
print(u.subcomm)
uex = newDistArray(fft, False)
Exemplo n.º 9
0
class nonlinearschroedinger_imex(ptype):
    """
    Example implementing the nonlinear Schrödinger equation in 2-3D using mpi4py-fft for solving linear parts,
    IMEX time-stepping

    mpi4py-fft: https://mpi4py-fft.readthedocs.io/en/latest/

    The linear term -i |k|^2 u is treated implicitly. The cubic term
    ndim * c * 2i |u|^2 u is treated explicitly.

    Attributes:
        fft: fft object
        X: grid coordinates in real space
        K2: squared wavenumber magnitude |k|^2 in spectral space (the spectral Laplacian is -K2)
    """
    def __init__(self,
                 problem_params,
                 dtype_u=parallel_mesh,
                 dtype_f=parallel_imex_mesh):
        """
        Initialization routine

        Args:
            problem_params (dict): custom parameters for the example (missing optional entries are filled in place)
            dtype_u: fft data type (will be passed to parent class)
            dtype_f: fft data type with implicit and explicit parts (will be passed to parent class)

        Raises:
            ProblemError: if L is not 2pi, c is not 0 or 1, or ``nvars`` is not a tuple with at least two entries
            ParameterError: if an essential parameter is missing
        """

        # defaults for optional parameters
        problem_params.setdefault('L', 2.0 * np.pi)
        problem_params.setdefault('comm', MPI.COMM_WORLD)
        problem_params.setdefault('c', 1.0)

        if not problem_params['L'] == 2.0 * np.pi:
            raise ProblemError(
                f'Setup not implemented, L has to be 2pi, got {problem_params["L"]}'
            )

        if not (problem_params['c'] == 0.0 or problem_params['c'] == 1.0):
            raise ProblemError(
                f'Setup not implemented, c has to be 0 or 1, got {problem_params["c"]}'
            )

        # these parameters will be used later, so assert their existence
        essential_keys = ['nvars', 'c', 'L', 'spectral']
        for key in essential_keys:
            if key not in problem_params:
                msg = 'need %s to instantiate problem, only got %s' % (
                    key, str(problem_params.keys()))
                raise ParameterError(msg)

        if not (isinstance(problem_params['nvars'], tuple)
                and len(problem_params['nvars']) > 1):
            raise ProblemError('Need at least two dimensions')

        # Creating FFT structure (complex-to-complex on all axes)
        self.ndim = len(problem_params['nvars'])
        axes = tuple(range(self.ndim))
        self.fft = PFFT(problem_params['comm'],
                        list(problem_params['nvars']),
                        axes=axes,
                        dtype=np.complex128,
                        collapse=True)

        # get test data to figure out type and dimensions
        tmp_u = newDistArray(self.fft, problem_params['spectral'])

        # invoke super init, passing the communicator and the local dimensions as init
        super(nonlinearschroedinger_imex,
              self).__init__(init=(tmp_u.shape, problem_params['comm'],
                                   tmp_u.dtype),
                             dtype_u=dtype_u,
                             dtype_f=dtype_f,
                             params=problem_params)

        self.L = np.array([self.params.L] * self.ndim, dtype=float)
        N = self.fft.global_shape()

        # Local real-space mesh. Depending on the NumPy version, ogrid may
        # return a tuple, so convert to a list before assigning entries.
        X = list(np.ogrid[self.fft.local_slice(False)])
        for i in range(len(N)):
            X[i] = (X[i] * self.L[i] / N[i])
        self.X = [np.broadcast_to(x, self.fft.shape(False)) for x in X]

        # get local wavenumbers and |k|^2
        s = self.fft.local_slice()
        k = [np.fft.fftfreq(n, 1. / n).astype(int) for n in N]
        K = [ki[si] for ki, si in zip(k, s)]
        Ks = np.meshgrid(*K, indexing='ij', sparse=True)
        Lp = 2 * np.pi / self.L
        for i in range(self.ndim):
            Ks[i] = (Ks[i] * Lp[i]).astype(float)
        K = [np.broadcast_to(k, self.fft.shape(True)) for k in Ks]
        K = np.array(K).astype(float)
        self.K2 = np.sum(K * K, 0, dtype=float)

        # Need this for diagnostics
        self.dx = self.params.L / problem_params['nvars'][0]
        self.dy = self.params.L / problem_params['nvars'][1]

    def _nonlinear_term(self, v):
        """Return the explicit cubic term ``ndim * c * 2j * |v|^2 * v`` for real-space data ``v``."""
        return self.ndim * self.params.c * 2j * np.absolute(v)**2 * v

    def eval_f(self, u, t):
        """
        Routine to evaluate the RHS

        Args:
            u (dtype_u): current values
            t (float): current time

        Returns:
            dtype_f: the RHS (``impl``: linear part, ``expl``: cubic part)
        """

        f = self.dtype_f(self.init)

        # Results are written into the existing f.impl / f.expl arrays so that
        # both branches keep the mesh types created by dtype_f.
        if self.params.spectral:

            f.impl[:] = -self.K2 * 1j * u
            tmp = self.fft.backward(u)
            f.expl[:] = self.fft.forward(self._nonlinear_term(tmp))

        else:

            u_hat = self.fft.forward(u)
            lap_u_hat = -self.K2 * 1j * u_hat
            f.impl[:] = self.fft.backward(lap_u_hat, f.impl)
            f.expl[:] = self._nonlinear_term(u)

        return f

    def solve_system(self, rhs, factor, u0, t):
        """
        Simple FFT solver for the linear part, i.e. solves (1 + factor * i |k|^2) u_hat = rhs_hat

        Args:
            rhs (dtype_f): right-hand side for the linear system
            factor (float) : abbrev. for the node-to-node stepsize (or any other factor required)
            u0 (dtype_u): initial guess for the iterative solver (not used here so far)
            t (float): current time (e.g. for time-dependent BCs)

        Returns:
            dtype_u: solution as mesh
        """

        if self.params.spectral:

            me = rhs / (1.0 + factor * self.K2 * 1j)

        else:

            me = self.dtype_u(self.init)
            rhs_hat = self.fft.forward(rhs)
            rhs_hat /= (1.0 + factor * self.K2 * 1j)
            me[:] = self.fft.backward(rhs_hat)

        return me

    def u_exact(self, t):
        """
        Routine to compute the exact solution at time t, see (1.3) https://arxiv.org/pdf/nlin/0702010.pdf for details

        Args:
            t (float): current time

        Returns:
            dtype_u: exact solution
        """
        def nls_exact_1D(t, x, c):

            ae = 1.0 / np.sqrt(2.0) * np.exp(1j * t)
            if c != 0:
                u = ae * ((np.cosh(t) + 1j * np.sinh(t)) /
                          (np.cosh(t) - 1.0 / np.sqrt(2.0) * np.cos(x)) - 1.0)
            else:
                u = np.sin(x) * np.exp(-t * 1j)

            return u

        me = self.dtype_u(self.init, val=0.0)

        # The multi-dimensional solution is the 1D profile evaluated at
        # x_1 + ... + x_ndim with time scaled by ndim.
        tmp = nls_exact_1D(self.ndim * t, sum(self.X), self.params.c)
        if self.params.spectral:
            me[:] = self.fft.forward(tmp)
        else:
            me[:] = tmp

        return me
Exemplo n.º 10
0
import sys

o = mytype(m)
m[0, 0] = 2
assert o[0, 0] == -1
assert o is not m
# The builtin exit() is injected by the site module for interactive use and is
# not guaranteed to exist (e.g. under `python -S`); sys.exit() is the program API.
sys.exit()

print(type(m))
print(type(m+n))
print(abs(m))
print(type(abs(m)))

comm = MPI.COMM_WORLD
subcomm = comm.Split()
nvars = 8
ndim = 2
axes = tuple(range(ndim))
N = np.array([nvars] * ndim, dtype=int)
# np.float (an alias of the builtin float, i.e. float64) was removed in
# NumPy 1.24; np.float64 is the same dtype.
fft = PFFT(subcomm, N, axes=axes, dtype=np.float64, slab=True)

init = (fft, False)
m = fft_datatype(init)
m[:] = comm.Get_rank()

print(type(m))
print(m.subcomm)
print(abs(m), type(abs(m)))


Exemplo n.º 11
0
import functools

import numpy as np
from mpi4py import MPI
from mpi4py_fft import PFFT, newDistArray
from mpi4py_fft.fftw import dctn, idctn

# Set global size of the computational box
N = np.array([18, 18, 18], dtype=int)

# Type-3 DCT and its inverse, used for axes (1, 2) instead of the default FFT.
# functools was previously used without being imported (NameError).
dct = functools.partial(dctn, type=3)
idct = functools.partial(idctn, type=3)

transforms = {(1, 2): (dct, idct)}

fft = PFFT(MPI.COMM_WORLD,
           N,
           axes=None,
           collapse=True,
           grid=(-1, ),
           transforms=transforms)
pfft = PFFT(MPI.COMM_WORLD,
            N,
            axes=((0, ), (1, 2)),
            grid=(-1, ),
            padding=[1.5, 1.0, 1.0],
            transforms=transforms)

# Both plans must end up transforming the same axis groups.
assert fft.axes == pfft.axes

u = newDistArray(fft, forward_output=False)
u[:] = np.random.random(u.shape).astype(u.dtype)

u_hat = newDistArray(fft, forward_output=True)
Exemplo n.º 12
0
def test_newDistArray():
    """Check that newDistArray honours ``rank`` and ``view`` for both spaces.

    Non-view arrays (and their rank-reduced sub-arrays) must also be usable as
    ``darray`` templates for a new PFFT.
    """
    N = (8, 8, 8)
    pfft = PFFT(MPI.COMM_WORLD, N)
    for forward_output in (True, False):
        for view in (True, False):
            for rank in (0, 1, 2):
                a = newDistArray(pfft,
                                 forward_output=forward_output,
                                 rank=rank,
                                 view=view)
                if view:
                    # A view is a plain ndarray whose base carries the rank.
                    assert isinstance(a, np.ndarray)
                    assert a.base.rank == rank
                    continue

                assert isinstance(a, DistArray)
                assert a.rank == rank
                if rank == 0:
                    template = a
                elif rank == 1:
                    template = a[0]
                else:
                    template = a[0, 0]
                qfft = PFFT(MPI.COMM_WORLD, darray=template)
                qfft.destroy()
    pfft.destroy()
Exemplo n.º 13
0
class pDFT(BaseDFT):
    """
    A wrapper to :class:`mpi4py_fft.mpifft.PFFT` to compute distributed Fast Fourier
    transforms.

    See :class:`pystella.fourier.dft.BaseDFT`.

    :arg decomp: A :class:`pystella.DomainDecomposition`.
        The shape of the MPI processor grid is determined by
        the ``proc_shape`` attribute of this object.

    :arg queue: A :class:`pyopencl.CommandQueue`.

    :arg grid_shape: A 3-:class:`tuple` specifying the shape of position-space
        arrays to be transformed.

    :arg dtype: The datatype of position-space arrays to be transformed.
        The complex datatype for momentum-space arrays is chosen to have
        the same precision.

    Any keyword arguments are passed to :class:`mpi4py_fft.mpifft.PFFT`,
    overriding the defaults ``axes=([0], [1], [2])``, ``threads=16``,
    ``backend="fftw"`` and ``collapse=True``.

    .. versionchanged:: 2020.1

        Support for complex-to-complex transforms.
    """
    def __init__(self, decomp, queue, grid_shape, dtype, **kwargs):
        self.decomp = decomp
        self.grid_shape = grid_shape
        self.proc_shape = decomp.proc_shape
        self.dtype = np.dtype(dtype)
        self.is_real = self.dtype.kind == "f"

        from pystella.fourier import (
            get_complex_dtype_with_matching_prec,
            get_real_dtype_with_matching_prec,
        )
        self.cdtype = get_complex_dtype_with_matching_prec(self.dtype)
        self.rdtype = get_real_dtype_with_matching_prec(self.dtype)

        # If processes are split only along the first axis, use slab mode with
        # the plain communicator. Otherwise build a pencil Subcomm.
        slab = bool(self.proc_shape[0] > 1 and self.proc_shape[1] == 1)

        from mpi4py_fft.pencil import Subcomm
        options = {
            "axes": ([0], [1], [2]),
            "threads": 16,
            "backend": "fftw",
            "collapse": True,
        }
        options.update(kwargs)
        if slab:
            comm = decomp.comm
        else:
            comm = Subcomm(decomp.comm, self.proc_shape)

        from mpi4py_fft import PFFT
        self.fft = PFFT(comm, grid_shape, dtype=dtype, slab=slab, **options)

        forward = self.fft.forward
        self.fx = forward.input_array
        self.fk = forward.output_array

        local_k_slice = self.fft.local_slice(True)
        self.sub_k = get_sliced_momenta(grid_shape, self.dtype, local_k_slice,
                                        queue)

    @property
    def proc_permutation(self):
        """Flattened transform axes with the axis swaps of every transfer step applied."""
        perm = [ax for group in self.fft.axes for ax in group]
        for transfer in self.fft.transfer:
            a, b = transfer.axisA, transfer.axisB
            perm[a], perm[b] = perm[b], perm[a]
        return perm

    def shape(self, forward_output=True):
        """Return ``self.fft.shape`` for momentum space (default) or position space."""
        return self.fft.shape(forward_output=forward_output)

    def forward_transform(self, fx, fk, **kwargs):
        """Transform ``fx`` into ``fk``; unnormalized unless ``normalize`` is passed."""
        kwargs.setdefault("normalize", False)
        return self.fft.forward(input_array=fx, output_array=fk, **kwargs)

    def backward_transform(self, fk, fx, **kwargs):
        """Transform ``fk`` back into ``fx``, forwarding ``kwargs`` to PFFT."""
        return self.fft.backward(input_array=fk, output_array=fx, **kwargs)
Exemplo n.º 14
0
class fft_to_fft(space_transfer):
    """
    Custom base_transfer class, implements Transfer.py

    Restricts and prolongs between parallel (mpi4py-fft based) meshes using FFTs, i.e. for
    periodic boundaries.

    * Restriction injects every ``ratio``-th grid point of the fine real-space data along
      each spatial axis (in spectral mode the data is transformed back and forth around the
      injection).
    * Prolongation uses a PFFT on the coarse grid with ``padding=ratio``. Its backward
      transform evaluates the coarse spectrum on the fine (padded) real-space grid.
    """

    def __init__(self, fine_prob, coarse_prob, params):
        """
        Initialization routine

        Args:
            fine_prob: fine problem
            coarse_prob: coarse problem
            params: parameters for the transfer operators
        """
        # invoke super initialization
        super(fft_to_fft, self).__init__(fine_prob, coarse_prob, params)

        assert self.fine_prob.params.spectral == self.coarse_prob.params.spectral

        self.spectral = self.fine_prob.params.spectral

        Nf = list(self.fine_prob.fft.global_shape())
        Nc = list(self.coarse_prob.fft.global_shape())
        self.ratio = [nf // nc for nf, nc in zip(Nf, Nc)]
        # Injection slices: every ratio-th point along EACH spatial axis. Built from
        # self.ratio so restriction works in any dimension, not just 2D. A trailing
        # component axis (problems with 'ncomp') is left untouched.
        # NOTE(review): strided slicing of the local arrays assumes each rank's local
        # start index is a multiple of the ratio — confirm for the decomposition used.
        self._inject = tuple(slice(None, None, r) for r in self.ratio)
        axes = tuple(range(len(Nf)))

        self.fft_pad = PFFT(self.coarse_prob.params.comm, Nc, padding=self.ratio, axes=axes,
                            dtype=self.coarse_prob.fft.dtype(False),
                            slab=True)

    def restrict(self, F):
        """
        Restriction implementation

        Args:
            F: the fine level data (easier to access than via the fine attribute)

        Returns:
            the coarse level data (coarse_prob.dtype_u)

        Raises:
            TransferError: if F is not a parallel_mesh
        """
        if not isinstance(F, parallel_mesh):
            raise TransferError('Unknown data type, got %s' % type(F))

        G = self.coarse_prob.dtype_u(self.coarse_prob.init)
        if not self.spectral:
            # real-space injection; the component axis (if any) is not sliced
            G[:] = F[self._inject]
        elif hasattr(self.fine_prob, 'ncomp'):
            # one real-space scratch buffer, reused for every component
            tmpF = newDistArray(self.fine_prob.fft, False)
            for i in range(self.fine_prob.ncomp):
                tmpF = self.fine_prob.fft.backward(F[..., i], tmpF)
                G[..., i] = self.coarse_prob.fft.forward(tmpF[self._inject], G[..., i])
        else:
            tmpF = self.fine_prob.fft.backward(F)
            G[:] = self.coarse_prob.fft.forward(tmpF[self._inject], G)

        return G

    def _prolong_array(self, coarse, fine):
        """Prolong one coarse array into the pre-allocated fine array ``fine`` (in place)."""
        if self.spectral:
            # coarse spectrum -> fine real space (padded backward) -> fine spectrum
            tmpF = self.fft_pad.backward(coarse)
            fine[...] = self.fine_prob.fft.forward(tmpF, fine)
        else:
            # coarse real space -> coarse spectrum -> fine real space (padded backward)
            G_hat = self.coarse_prob.fft.forward(coarse)
            fine[...] = self.fft_pad.backward(G_hat, fine)

    def prolong(self, G):
        """
        Prolongation implementation

        Args:
            G: the coarse level data (easier to access than via the coarse attribute)

        Returns:
            the fine level data (fine_prob.dtype_u for a parallel_mesh input,
            fine_prob.dtype_f for a parallel_imex_mesh input)

        Raises:
            TransferError: if G is neither a parallel_mesh nor a parallel_imex_mesh
        """
        if isinstance(G, parallel_mesh):
            F = self.fine_prob.dtype_u(self.fine_prob.init)
            pairs = [(G, F)]
        elif isinstance(G, parallel_imex_mesh):
            F = self.fine_prob.dtype_f(self.fine_prob.init)
            pairs = [(G.impl, F.impl), (G.expl, F.expl)]
        else:
            raise TransferError('Unknown data type, got %s' % type(G))

        if hasattr(self.fine_prob, 'ncomp'):
            # component-major order (impl then expl for each component), as before
            for i in range(self.fine_prob.ncomp):
                for coarse, fine in pairs:
                    self._prolong_array(coarse[..., i], fine[..., i])
        else:
            for coarse, fine in pairs:
                self._prolong_array(coarse, fine)

        return F
Exemplo n.º 15
0
    def __init__(self,
                 problem_params,
                 dtype_u=parallel_mesh,
                 dtype_f=parallel_imex_mesh):
        """
        Initialization routine

        Fills in defaults for optional parameters and validates the required ones. It then
        creates the mpi4py-fft PFFT object and precomputes the local real-space grid
        (``self.X``) and the squared wavenumbers (``self.K2``).

        Args:
            problem_params (dict): custom parameters for the example; missing 'L',
                'init_type', 'comm' and 'dw' entries are added in place with defaults
            dtype_u: fft data type (will be passed to parent class)
            dtype_f: fft data type with implicit and explicit parts (will be passed to parent class)

        Raises:
            ParameterError: if any of 'nvars', 'eps', 'L', 'radius', 'dw', 'spectral' is missing
            ProblemError: if 'nvars' is not a tuple with at least two entries
        """

        if 'L' not in problem_params:
            problem_params['L'] = 1.0
        if 'init_type' not in problem_params:
            problem_params['init_type'] = 'circle'
        if 'comm' not in problem_params:
            # NOTE(review): PFFT is later built from this communicator; confirm that
            # None is accepted there, or pass a real communicator.
            problem_params['comm'] = None
        if 'dw' not in problem_params:
            problem_params['dw'] = 0.0

        # these parameters will be used later, so assert their existence
        essential_keys = ['nvars', 'eps', 'L', 'radius', 'dw', 'spectral']
        for key in essential_keys:
            if key not in problem_params:
                msg = 'need %s to instantiate problem, only got %s' % (
                    key, str(problem_params.keys()))
                raise ParameterError(msg)

        if not (isinstance(problem_params['nvars'], tuple)
                and len(problem_params['nvars']) > 1):
            raise ProblemError('Need at least two dimensions')

        # Creating FFT structure
        ndim = len(problem_params['nvars'])
        axes = tuple(range(ndim))
        # np.float was only an alias of the builtin float and no longer exists in
        # recent NumPy; np.float64 is the same dtype.
        self.fft = PFFT(problem_params['comm'],
                        list(problem_params['nvars']),
                        axes=axes,
                        dtype=np.float64,
                        collapse=True)

        # get test data to figure out type and dimensions
        tmp_u = newDistArray(self.fft, problem_params['spectral'])

        # invoke super init, passing the communicator and the local dimensions as init
        super(allencahn_imex,
              self).__init__(init=(tmp_u.shape, problem_params['comm'],
                                   tmp_u.dtype),
                             dtype_u=dtype_u,
                             dtype_f=dtype_f,
                             params=problem_params)

        L = np.array([self.params.L] * ndim, dtype=float)

        # get local mesh: grid points X[i] * L[i] / N[i] on [0, L)
        # (copied into a list so the entries can be replaced below)
        X = list(np.ogrid[self.fft.local_slice(False)])
        N = self.fft.global_shape()
        for i in range(len(N)):
            X[i] = (X[i] * L[i] / N[i])
        self.X = [np.broadcast_to(x, self.fft.shape(False)) for x in X]

        # get local wavenumbers and Laplace operator
        # (full frequencies on all but the last axis, half spectrum on the last one)
        s = self.fft.local_slice()
        k = [np.fft.fftfreq(n, 1. / n).astype(int) for n in N[:-1]]
        k.append(np.fft.rfftfreq(N[-1], 1. / N[-1]).astype(int))
        K = [ki[si] for ki, si in zip(k, s)]
        Ks = list(np.meshgrid(*K, indexing='ij', sparse=True))
        Lp = 2 * np.pi / L
        for i in range(ndim):
            Ks[i] = (Ks[i] * Lp[i]).astype(float)
        K = [np.broadcast_to(k, self.fft.shape(True)) for k in Ks]
        K = np.array(K).astype(float)
        self.K2 = np.sum(K * K, 0, dtype=float)

        # Need this for diagnostics
        self.dx = self.params.L / problem_params['nvars'][0]
        self.dy = self.params.L / problem_params['nvars'][1]
Exemplo n.º 16
0
@author: ogurcan
"""
from mpi4py import MPI
from mpi4py_fft import PFFT, newDistArray
import numpy as np
import matplotlib.pylab as plt

howmany = 6
Nx, Ny = 128, 128
padx, pady = 3 / 2, 3 / 2
# Padded (3/2-rule) real-space grid size. It is derived from Nx, Ny and the padding
# factors so that all three stay consistent when any of them is changed.
Npx, Npy = int(Nx * padx), int(Ny * pady)
comm = MPI.COMM_WORLD

# Batched 2D transform over axes 1 and 2 of a (howmany, Nx, Ny) array. Axis 0 only
# stacks independent fields, so it is neither transformed nor padded.
pf = PFFT(comm,
          shape=(howmany, Nx, Ny),
          axes=(1, 2),
          grid=[1, -1, 1],
          padding=[1, padx, pady],
          collapse=False)
u = newDistArray(pf, forward_output=False)
uk = newDistArray(pf, forward_output=True)

# Test field on the padded real-space grid: an oscillating Gaussian bump whose
# amplitude differs per field (factor nl - 3).
n, x, y = np.meshgrid(np.arange(0, howmany),
                      np.linspace(-1, 1, Npx),
                      np.linspace(-1, 1, Npy),
                      indexing='ij')
nl, xl, yl = n[u.local_slice()], x[u.local_slice()], y[u.local_slice()]
u[:] = np.sin(4 * np.pi * (xl + 2 * yl)) * np.exp(-xl**2 / 2 / 0.04 -
                                                  yl**2 / 2 / 0.08) * (nl - 3)
u0 = u.copy()

pf.forward(u, uk)
Exemplo n.º 17
0
def test_3D(backend, forward_output):
    """Write 3D DistArrays of rank 0, 1 and 2 with ``backend`` and read step 0 back.

    For each domain description (None, interval bounds, explicit coordinate arrays) and each
    tensor rank, two files are produced:

    * 'uv<name>' holds u and v plus slices of them;
    * 'v<name>' holds v plus two slices of it.

    The last step is written twice on purpose. For rank > 0 the arrays are also written
    component-wise via ``DistArray.write(..., as_scalar=True)``. Finally step 0 of 'u' and
    'v' is read back from 'uv<name>' and compared with the in-memory arrays.
    """

    def uv_fields(u, v):
        # fresh dict per write call: full arrays plus a few slices
        return {
            'u': [u,
                  (u, [slice(None), slice(None), 4]),
                  (u, [5, 5, slice(None)])],
            'v': [v,
                  (v, [slice(None), 6, slice(None)])],
        }

    def v_fields(v):
        return {
            'v': [v,
                  (v, [slice(None), 6, slice(None)]),
                  (v, [6, 6, slice(None)])],
        }

    if backend == 'netcdf4':
        assert forward_output is False
    T = PFFT(comm, (N[0], N[1], N[2]))
    d0 = ((0, np.pi), (0, 2 * np.pi), (0, 3 * np.pi))
    d1 = tuple(np.arange(N[j], dtype=float) * (j + 1) * np.pi / N[j]
               for j in range(3))

    for i, domain in enumerate([None, d0, d1]):
        for rank in range(3):
            filename = 'test_{}{}{}'.format(ex[i == 0], ex[forward_output],
                                            rank) + ending[backend]
            uv_name = 'uv' + filename
            v_name = 'v' + filename
            if backend == 'netcdf4':
                remove_if_exists(uv_name)
                remove_if_exists(v_name)

            u = newDistArray(T, forward_output=forward_output, rank=rank)
            v = newDistArray(T, forward_output=forward_output, rank=rank)
            h0file = writer[backend](uv_name, domain=domain)
            h1file = writer[backend](v_name, domain=domain)
            u[:] = np.random.random(u.shape)
            v[:] = 2

            # steps 0, 1, 2, then step 2 once more (overwrite of an existing step)
            for step in (0, 1, 2, 2):
                h0file.write(step, uv_fields(u, v))
                h1file.write(step, v_fields(v))

            if rank > 0:
                for step in range(3):
                    u.write(uv_name, 'u', step, as_scalar=True)
                    u.write(uv_name, 'u', step, [slice(None), slice(None), 4],
                            as_scalar=True)
                    u.write(uv_name, 'u', step, [5, 5, slice(None)],
                            as_scalar=True)
                    v.write(uv_name, 'v', step, as_scalar=True)
                    v.write(uv_name, 'v', step, [slice(None), 6, slice(None)],
                            as_scalar=True)

            is_root_hdf5_physical = (not forward_output and backend == 'hdf5'
                                     and comm.Get_rank() == 0)
            if is_root_hdf5_physical:
                generate_xdmf(uv_name)
                generate_xdmf(v_name, periodic=False)
                generate_xdmf(v_name, periodic=(True, True, True))
                generate_xdmf(v_name, order='visit')

            u0 = newDistArray(T, forward_output=forward_output, rank=rank)
            read = reader[backend](uv_name)
            read.read(u0, 'u', step=0)
            assert np.allclose(u0, u)
            read.read(u0, 'v', step=0)
            assert np.allclose(u0, v)
    T.destroy()
Exemplo n.º 18
0
from mpi4py_fft import PFFT, newDistArray

# Physical / time-stepping parameters: viscosity, final time, time step
nu = 0.000625
T = 0.1
dt = 0.01

# Global grid: 2**M points in each of the three directions
M = 6
N = [2**M] * 3
# Box lengths; each must be an even multiple of pi because the initial condition
# assumes periodicity on such a box.
L = np.array([2, 4, 4], dtype=float) * np.pi

# Parallel FFT. The padded (3/2-rule, dealiased) transform is currently disabled;
# to enable it, build FFT_pad as PFFT(MPI.COMM_WORLD, N, padding=[1.5, 1.5, 1.5]).
FFT = PFFT(MPI.COMM_WORLD, N, collapse=False)
FFT_pad = FFT

# Solution arrays (real space: forward_output=False; spectral space: default True)
U = newDistArray(FFT, forward_output=False, rank=1, view=True)  # velocity
U_hat = newDistArray(FFT, rank=1, view=True)  # velocity, spectral
P = newDistArray(FFT, forward_output=False, view=True)  # pressure (scalar)
P_hat = newDistArray(FFT, view=True)  # pressure, spectral

# Classical 4-stage Runge-Kutta: work arrays and coefficients
U_hat0 = newDistArray(FFT, rank=1, view=True)
U_hat1 = newDistArray(FFT, rank=1, view=True)
a = [1 / 6, 1 / 3, 1 / 3, 1 / 6]
b = [0.5, 0.5, 1.0]

# Right-hand side of the ODE system, vorticity and padded velocity work arrays
dU = newDistArray(FFT, rank=1, view=True)
curl = newDistArray(FFT, forward_output=False, rank=1, view=True)
U_pad = newDistArray(FFT_pad, forward_output=False, rank=1, view=True)
Exemplo n.º 19
0
class grayscott_imex_diffusion(ptype):
    """
    Gray-Scott reaction-diffusion problem in 2-3D, discretized with mpi4py-fft.

    IMEX splitting: the diffusion terms are treated implicitly (solved exactly in Fourier
    space), the reaction terms explicitly.

    mpi4py-fft: https://mpi4py-fft.readthedocs.io/en/latest/

    Attributes:
        fft: PFFT object used for all transforms
        X: local real-space grid coordinates on the box [-L/2, L/2)^ndim
        ndim: number of spatial dimensions
        Ku: Laplace operator in spectral space times Du (u component)
        Kv: Laplace operator in spectral space times Dv (v component)
    """
    def __init__(self,
                 problem_params,
                 dtype_u=parallel_mesh,
                 dtype_f=parallel_imex_mesh):
        """
        Initialization routine

        Args:
            problem_params (dict): custom parameters for the example; 'L' defaults to 2.0
                and 'comm' to MPI.COMM_WORLD (the dict is updated in place)
            dtype_u: fft data type (will be passed to parent class)
            dtype_f: fft data type with implicit and explicit parts (will be passed to parent class)

        Raises:
            ParameterError: if one of 'nvars', 'Du', 'Dv', 'A', 'B', 'spectral' is missing
            ProblemError: if 'nvars' is not a tuple with at least two entries
        """
        problem_params.setdefault('L', 2.0)
        problem_params.setdefault('comm', MPI.COMM_WORLD)

        # parameters used further down must be present
        for key in ('nvars', 'Du', 'Dv', 'A', 'B', 'spectral'):
            if key not in problem_params:
                raise ParameterError('need %s to instantiate problem, only got %s' % (
                    key, str(problem_params.keys())))

        nvars = problem_params['nvars']
        if not isinstance(nvars, tuple) or len(nvars) <= 1:
            raise ProblemError('Need at least two dimensions')

        # parallel real-to-complex FFT over all spatial axes
        self.ndim = len(nvars)
        self.fft = PFFT(problem_params['comm'],
                        list(nvars),
                        axes=tuple(range(self.ndim)),
                        dtype=np.float64,
                        collapse=True,
                        backend='fftw')

        # a sample array tells us the local shape and dtype in the chosen space
        sample = newDistArray(self.fft, problem_params['spectral'])

        # trailing axis holds the two species u and v
        self.ncomp = 2
        local_shape = sample.shape + (self.ncomp, )

        super(grayscott_imex_diffusion,
              self).__init__(init=(local_shape, problem_params['comm'], sample.dtype),
                             dtype_u=dtype_u,
                             dtype_f=dtype_f,
                             params=problem_params)

        box = np.array([self.params.L] * self.ndim, dtype=float)

        # local real-space grid, shifted to [-L/2, L/2)
        N = self.fft.global_shape()
        coords = np.ogrid[self.fft.local_slice(False)]
        phys_shape = self.fft.shape(False)
        self.X = [np.broadcast_to(-box[i] / 2 + (coords[i] * box[i] / N[i]), phys_shape)
                  for i in range(len(N))]

        # local integer wavenumbers (half spectrum along the last axis), then scaled
        freqs = [np.fft.fftfreq(n, 1. / n).astype(int) for n in N[:-1]]
        freqs.append(np.fft.rfftfreq(N[-1], 1. / N[-1]).astype(int))
        local_k = [f[sl] for f, sl in zip(freqs, self.fft.local_slice())]
        Ks = np.meshgrid(*local_k, indexing='ij', sparse=True)
        Lp = 2 * np.pi / box
        spec_shape = self.fft.shape(True)
        K = np.array([np.broadcast_to((Ks[i] * Lp[i]).astype(float), spec_shape)
                      for i in range(self.ndim)]).astype(float)
        self.K2 = np.sum(K * K, 0, dtype=float)
        self.Ku = -self.K2 * self.params.Du
        self.Kv = -self.K2 * self.params.Dv

        # grid spacings, used for diagnostics
        self.dx = self.params.L / problem_params['nvars'][0]
        self.dy = self.params.L / problem_params['nvars'][1]

    def eval_f(self, u, t):
        """
        Routine to evaluate the RHS

        Args:
            u (dtype_u): current values (component 0 is u, component 1 is v)
            t (float): current time (unused)

        Returns:
            dtype_f: the RHS; ``impl`` holds the diffusion terms, ``expl`` the reaction
            terms -u*v^2 + A*(1-u) and u*v^2 - B*v
        """
        f = self.dtype_f(self.init)
        A = self.params.A
        B = self.params.B

        if self.params.spectral:
            # diffusion: pointwise multiplication in Fourier space
            f.impl[..., 0] = self.Ku * u[..., 0]
            f.impl[..., 1] = self.Kv * u[..., 1]
            # reaction: evaluate in real space, then transform back
            u_phys = newDistArray(self.fft, False)
            v_phys = newDistArray(self.fft, False)
            u_phys[:] = self.fft.backward(u[..., 0], u_phys)
            v_phys[:] = self.fft.backward(u[..., 1], v_phys)
            uv2 = u_phys * v_phys**2
            react_u = -uv2 + A * (1 - u_phys)
            react_v = uv2 - B * v_phys
            f.expl[..., 0] = self.fft.forward(react_u)
            f.expl[..., 1] = self.fft.forward(react_v)
            return f

        # real-space data: diffusion through a forward/backward FFT round trip
        for comp, lap in ((0, self.Ku), (1, self.Kv)):
            u_hat = self.fft.forward(u[..., comp])
            f.impl[..., comp] = self.fft.backward(lap * u_hat, f.impl[..., comp])
        uu = u[..., 0]
        vv = u[..., 1]
        uv2 = uu * vv**2
        f.expl[..., 0] = -uv2 + A * (1 - uu)
        f.expl[..., 1] = uv2 - B * vv
        return f

    def solve_system(self, rhs, factor, u0, t):
        """
        Simple FFT solver for the diffusion part

        Solves (I - factor * K) x = rhs component-wise, with K = Ku for u and Kv for v.

        Args:
            rhs: right-hand side for the linear system (indexed per component like dtype_u)
            factor (float) : abbrev. for the node-to-node stepsize (or any other factor required)
            u0 (dtype_u): initial guess for the iterative solver (not used here so far)
            t (float): current time (e.g. for time-dependent BCs)

        Returns:
            dtype_u: solution as mesh
        """
        me = self.dtype_u(self.init)
        operators = ((0, self.Ku), (1, self.Kv))

        if self.params.spectral:
            for comp, lap in operators:
                me[..., comp] = rhs[..., comp] / (1.0 - factor * lap)
            return me

        for comp, lap in operators:
            rhs_hat = self.fft.forward(rhs[..., comp])
            rhs_hat /= (1.0 - factor * lap)
            me[..., comp] = self.fft.backward(rhs_hat, me[..., comp])
        return me

    def u_exact(self, t):
        """
        Routine to compute the exact solution at time t=0, see https://www.chebfun.org/examples/pde/GrayScott.html

        Two offset Gaussians on the box [-L/2, L/2]^2: u = 1 - bump centred at (-0.05, -0.02),
        v = bump centred at (0.05, 0.02).

        Args:
            t (float): current time (must be 0)

        Returns:
            dtype_u: exact solution
        """
        assert t == 0.0, 'Exact solution only valid as initial condition'
        assert self.ndim == 2, 'The initial conditions are 2D for now..'

        me = self.dtype_u(self.init, val=0.0)

        x, y = self.X[0], self.X[1]
        u_init = 1.0 - np.exp(-80.0 * ((x + 0.05)**2 + (y + 0.02)**2))
        v_init = np.exp(-80.0 * ((x - 0.05)**2 + (y - 0.02)**2))

        if self.params.spectral:
            me[..., 0] = self.fft.forward(u_init)
            me[..., 1] = self.fft.forward(v_init)
        else:
            me[..., 0] = u_init
            me[..., 1] = v_init

        return me