def fft(a, nfft=None, axis=-1):
    """Plan a discrete Fourier transform function.

    Parameters
    ----------
    a : number, shape or array-like
        An input array, its shape or length. Only consulted when `nfft`
        is None.
    nfft : Optional[int]
        Number of FFT points. Default is input size along specified axis.
    axis : Optional[int]
        Axis along which to perform the fft. Default is -1.

    Returns
    -------
    planned_fft : callable
        Planned fft function; ``planned_fft(x)`` returns
        ``sfft.fft(x, nfft, axis)``.
    """
    if nfft is None:
        shape = _utils.get_shape(a)
        nfft = shape[axis]

    # Define fft function
    def planned_fft(x):
        return sfft.fft(x, nfft, axis)

    return planned_fft
def irfft(a, nfft=None, axis=-1, fft_pair=False):
    """Returns a planned function that computes the real-valued 1-D inverse DFT
    of a sequence or array.

    Parameters
    ----------
    a : number, shape or array-like
        An input array, its shape or length. The size along `axis` is the
        number of one-sided spectrum bins ``m``; the signal length is taken
        to be ``2 * (m - 1)``.
    nfft : Optional[int]
        Number of FFT points. Default is ``2 * (m - 1)``.
    axis : Optional[int]
        Axis along which to perform the fft. Default is -1.
    fft_pair : Optional[boolean]
        Indicates whether or not to also return a forward rfft function.
        Default is False.

    Returns
    -------
    planned_irfft : callable
        Planned irfft function.
    planned_rfft : callable
        Planned rfft function. Returned only if fft_pair is True. When
        `nfft` exceeds the signal length, its output is cropped along `axis`
        to ``m`` bins so it matches the input of planned_irfft.

    Raises
    ------
    ValueError
        If fft_pair is True and `nfft` is smaller than the signal length.
    """
    # Get shape
    shape = _utils.get_shape(a)
    n = 2 * (shape[axis] - 1)
    n_even = 1 - (n % 2)

    if nfft is None:
        nfft = n

    # Define ifft function
    def planned_irfft(x):
        x_sfft = _utils.complex_to_scipy_rfft(x, axis=axis, even=n_even)
        return sfft.irfft(x_sfft, nfft, axis)

    if fft_pair is not True:
        return planned_irfft

    # Define fft function
    if n > nfft:
        raise ValueError("NFFT must be at least equal to signal length "
                         "when returning an FFT pair.")

    if n == nfft:
        def planned_rfft(x):
            return _utils.scipy_rfft_to_complex(sfft.rfft(x, nfft, axis),
                                                axis=axis)
    else:
        # A longer transform yields more bins than planned_irfft consumes
        # (assuming scipy_rfft_to_complex returns nfft // 2 + 1 bins); crop
        # back to the input bin count. NumPy requires a tuple index here.
        slices = [slice(None)] * len(shape)
        slices[axis] = slice(None, shape[axis])
        slices = tuple(slices)

        def planned_rfft(x):
            return _utils.scipy_rfft_to_complex(sfft.rfft(x, nfft, axis),
                                                axis=axis)[slices]

    return planned_irfft, planned_rfft
def rfft(a, nfft=None, axis=-1):
    """Returns a planned function that computes the 1-D DFT of a real-valued
    sequence or array.

    Parameters
    ----------
    a : number, shape or array-like
        An input array, its shape or length. Only consulted when `nfft`
        is None.
    nfft : Optional[int]
        Number of FFT points. Default is input size along specified axis.
    axis : Optional[int]
        Axis along which to perform the fft. Default is -1.

    Returns
    -------
    planned_rfft : callable
        Planned rfft function. The result of ``sfft.rfft`` is passed through
        ``_utils.scipy_rfft_to_complex`` (presumably converting it to a
        complex one-sided spectrum).
    """
    if nfft is None:
        shape = _utils.get_shape(a)
        nfft = shape[axis]

    # Define fft function
    def planned_rfft(x):
        return _utils.scipy_rfft_to_complex(sfft.rfft(x, nfft, axis), axis)

    return planned_rfft
def ifft_pair(a, nfft=None, axis=-1, crop_fft=False):
    """Returns a planned function that computes the 1-D inverse DFT of a
    sequence or array, together with the matching forward DFT function.

    Parameters
    ----------
    a : number, shape or array-like
        An input array, its shape or length.
    nfft : Optional[int]
        Number of FFT points. Default is input size along specified axis.
    axis : Optional[int]
        Axis along which to perform the fft. Default is -1.
    crop_fft : Optional[boolean]
        Indicates that the fft function should crop its output to match
        the shape of the input to the ifft function. Default is False.

    Returns
    -------
    planned_ifft : callable
        Planned ifft function.
    planned_fft : callable
        Planned fft function.

    Raises
    ------
    ValueError
        If `nfft` is smaller than the input size along `axis`.
    """
    # Get shape
    shape = _utils.get_shape(a)
    n = shape[axis]

    if nfft is None:
        nfft = n

    # Define ifft function
    def planned_ifft(x):
        return sfft.ifft(x, nfft, axis)

    # Define fft function
    if n > nfft:
        raise ValueError("NFFT must be at least equal to signal length "
                         "when returning an FFT pair.")

    if n < nfft and crop_fft:
        # Keep only the first n points along axis (tuple index for NumPy).
        slices = [slice(None)] * len(shape)
        slices[axis] = slice(None, n)
        slices = tuple(slices)

        def planned_fft(x):
            return sfft.fft(x, nfft, axis)[slices]
    else:
        def planned_fft(x):
            return sfft.fft(x, nfft, axis)

    return planned_ifft, planned_fft
def rfft_pair(a, nfft=None, axis=-1, crop_ifft=False):
    """Returns a planned function that computes the 1-D DFT of a real-valued
    sequence or array, together with the matching real-valued inverse DFT.

    Parameters
    ----------
    a : number, shape or array-like
        An input array, its shape or length.
    nfft : Optional[int]
        Number of FFT points. Default is input size along specified axis.
    axis : Optional[int]
        Axis along which to perform the fft. Default is -1.
    crop_ifft : Optional[boolean]
        Indicates whether the planned ifft function should crop its
        output to match input size. Default is False.

    Returns
    -------
    planned_rfft : callable
        Planned rfft function.
    planned_irfft : callable
        Planned irfft function.

    Raises
    ------
    ValueError
        If `nfft` is smaller than the input size along `axis`.
    """
    # Get shape
    shape = _utils.get_shape(a)
    n = shape[axis]

    if nfft is None:
        nfft = n

    # The spectra fed to planned_irfft come from an nfft-point transform,
    # so the even/odd layout follows nfft rather than the signal length n.
    n_even = 1 - (nfft % 2)

    # Define fft function
    def planned_rfft(x):
        return _utils.scipy_rfft_to_complex(sfft.rfft(x, nfft, axis), axis)

    # Define ifft function
    if n > nfft:
        raise ValueError("NFFT must be at least equal to signal "
                         "length when returning an FFT pair.")

    if n < nfft and crop_ifft:
        # Drop the zero-padding region along axis (tuple index for NumPy).
        slices = [slice(None)] * len(shape)
        slices[axis] = slice(None, n)
        slices = tuple(slices)

        def planned_irfft(x):
            x_sfft = _utils.complex_to_scipy_rfft(x, axis=axis, even=n_even)
            return sfft.irfft(x_sfft, nfft, axis)[slices]
    else:
        def planned_irfft(x):
            x_sfft = _utils.complex_to_scipy_rfft(x, axis=axis, even=n_even)
            return sfft.irfft(x_sfft, nfft, axis)

    return planned_rfft, planned_irfft
def irfft_pair(a, nfft=None, axis=-1):
    """Returns a planned function that computes the real-valued 1-D inverse DFT
    of a sequence or array, together with the matching forward rfft.

    Parameters
    ----------
    a : number, shape or array-like
        An input array, its shape or length. The size along `axis` is the
        number of one-sided spectrum bins ``m``; the signal length is taken
        to be ``2 * (m - 1)``.
    nfft : Optional[int]
        Number of FFT points. Default is ``2 * (m - 1)``.
    axis : Optional[int]
        Axis along which to perform the fft. Default is -1.

    Returns
    -------
    planned_irfft : callable
        Planned irfft function.
    planned_rfft : callable
        Planned rfft function. When `nfft` exceeds the signal length, its
        output is cropped along `axis` to ``m`` bins so it matches the input
        of planned_irfft.

    Raises
    ------
    ValueError
        If `nfft` is smaller than the signal length.
    """
    # Get shape
    shape = _utils.get_shape(a)
    n = 2 * (shape[axis] - 1)
    n_even = 1 - (n % 2)

    if nfft is None:
        nfft = n

    # Define ifft function
    def planned_irfft(x):
        x_sfft = _utils.complex_to_scipy_rfft(x, axis=axis, even=n_even)
        return sfft.irfft(x_sfft, nfft, axis)

    # Define fft function
    if n > nfft:
        raise ValueError("NFFT must be at least equal to signal length "
                         "when returning an FFT pair.")

    if n == nfft:
        def planned_rfft(x):
            return _utils.scipy_rfft_to_complex(sfft.rfft(x, nfft, axis),
                                                axis=axis)
    else:
        # A longer transform yields more bins than planned_irfft consumes
        # (assuming scipy_rfft_to_complex returns nfft // 2 + 1 bins); crop
        # back to the input bin count. NumPy requires a tuple index here.
        slices = [slice(None)] * len(shape)
        slices[axis] = slice(None, shape[axis])
        slices = tuple(slices)

        def planned_rfft(x):
            return _utils.scipy_rfft_to_complex(sfft.rfft(x, nfft, axis),
                                                axis=axis)[slices]

    return planned_irfft, planned_rfft
def fft_pair(a, nfft=None, axis=-1, crop_ifft=False):
    """Plan a discrete Fourier transform function pair.

    Parameters
    ----------
    a : number, shape or array-like
        An input array, its shape or length.
    nfft : Optional[int]
        Number of FFT points. Default is input size along specified axis.
    axis : Optional[int]
        Axis along which to perform the fft. Default is -1.
    crop_ifft : Optional[boolean]
        Indicates whether the planned ifft function should crop its
        output to match input size. Default is False.

    Returns
    -------
    planned_fft : callable
        Planned fft function.
    planned_ifft : callable
        Planned ifft function.

    Raises
    ------
    ValueError
        If `nfft` is smaller than the input size along `axis`.
    """
    # Get shape
    shape = _utils.get_shape(a)
    n = shape[axis]

    if nfft is None:
        nfft = n

    # Define fft function
    def planned_fft(x):
        return sfft.fft(x, nfft, axis)

    # Define ifft function
    if n > nfft:
        raise ValueError("NFFT must be at least equal to signal "
                         "length when returning an FFT pair.")

    if n < nfft and crop_ifft:
        # Drop the zero-padding region along axis (tuple index for NumPy).
        slices = [slice(None)] * len(shape)
        slices[axis] = slice(None, n)
        slices = tuple(slices)

        def planned_ifft(x):
            return sfft.ifft(x, nfft, axis)[slices]
    else:
        def planned_ifft(x):
            return sfft.ifft(x, nfft, axis)

    return planned_fft, planned_ifft
def irfftn(a, shape=None, axes=None):
    """Returns a planned function that computes the real-valued N-D inverse
    DFT of an array.

    Parameters
    ----------
    a : array-like or shape
        An input array or its shape.
    shape : Optional[sequence of ints]
        Shape of the real output along the transformed axes. Default is the
        input size along the specified axes, except for the last transformed
        axis, where ``m`` input bins give ``2 * (m - 1)`` points.
    axes : Optional[sequence of ints]
        Axes along which to perform the fft. Default is all axes (or the
        trailing ``len(shape)`` axes when only `shape` is given).

    Returns
    -------
    planned_irfftn : callable
        Planned irfftn function.
    """
    # Get shape of input
    a_shape = _utils.get_shape(a)
    n_dim = len(a_shape)

    # Accept a bare int for either argument; normalise before use below.
    if axes is not None and np.asarray(axes).ndim == 0:
        axes = (axes, )
    if shape is not None and np.asarray(shape).ndim == 0:
        shape = (shape, )

    if shape is None:
        if axes is None:
            shape = list(a_shape)
        else:
            shape = [a_shape[axis] for axis in axes]
        # `shape` is indexed per transformed axis, so the last transformed
        # axis is always shape[-1] (not shape[axes[-1]]).
        shape[-1] = 2 * (shape[-1] - 1)

    if axes is None:
        # Transform the trailing len(shape) axes.
        axes = list(range(n_dim - len(shape), n_dim))

    # Define ifft function
    def planned_irfftn(x):
        # scipy.fftpack.fftn doesn't handle mixed sign axes well
        return np.fft.irfftn(x, shape, axes)

    return planned_irfftn
def fftn(a, shape=None, axes=None):
    """Returns a planned function that computes the N-D DFT of an array.

    Parameters
    ----------
    a : array-like or shape
        An input array or its shape.
    shape : Optional[List[int]]
        Number of FFT points along each transformed axis. Default is input
        size along specified axes.
    axes : Optional[sequence of ints]
        Axes along which to perform the fft. Default is all axes (or the
        trailing ``len(shape)`` axes when only `shape` is given).

    Returns
    -------
    planned_fftn : callable
        Planned fftn function.
    """
    # Get shape of input
    a_shape = _utils.get_shape(a)
    n_dim = len(a_shape)

    # Accept a bare int for either argument; normalise before use below.
    if axes is not None and np.asarray(axes).ndim == 0:
        axes = (axes,)
    if shape is not None and np.asarray(shape).ndim == 0:
        shape = (shape,)

    if shape is None:
        if axes is None:
            shape = list(a_shape)
        else:
            shape = [a_shape[axis] for axis in axes]

    if axes is None:
        # Transform the trailing len(shape) axes.
        axes = list(range(n_dim - len(shape), n_dim))

    # Define fft function
    def planned_fftn(x):
        return np.fft.fftn(x, shape, axes)

    return planned_fftn
def fft(a, nfft=None, axis=-1, fft_pair=False, crop_ifft=False):
    """Returns a planned function that computes the 1-D DFT of a sequence
    or array.

    Parameters
    ----------
    a : number, shape or array-like
        An input array, its shape or length.
    nfft : Optional[int]
        Number of FFT points. Default is input size along specified axis.
    axis : Optional[int]
        Axis along which to perform the fft. Default is -1.
    fft_pair : Optional[boolean]
        Indicates whether or not to also return an ifft function.
        Default is False.
    crop_ifft : Optional[boolean]
        Indicates whether the planned ifft function should crop its
        output to match input size. Default is False.

    Returns
    -------
    planned_fft : callable
        Planned fft function.
    planned_ifft : callable
        Planned ifft function. Returned only if fft_pair is True.

    Raises
    ------
    ValueError
        If fft_pair is True and `nfft` is smaller than the input size.
    """
    # Get shape
    shape = _utils.get_shape(a)
    n = shape[axis]

    if nfft is None:
        nfft = n

    # Define fft function
    def planned_fft(x):
        return np.fft.fft(x, nfft, axis)

    if fft_pair is not True:
        return planned_fft

    # Define ifft function
    if n > nfft:
        raise ValueError("NFFT must be at least equal to signal "
                         "length when returning an FFT pair.")

    if n < nfft and crop_ifft:
        # Drop the zero-padding region along axis (tuple index for NumPy).
        slices = [slice(None)] * len(shape)
        slices[axis] = slice(None, n)
        slices = tuple(slices)

        def planned_ifft(x):
            return np.fft.ifft(x, nfft, axis)[slices]
    else:
        def planned_ifft(x):
            return np.fft.ifft(x, nfft, axis)

    return planned_fft, planned_ifft
def ifftn(a, shape=None, axes=None, fft_pair=False, crop_fft=False):
    """Returns a planned function that computes the N-D inverse DFT of an
    array.

    Parameters
    ----------
    a : array-like or shape
        An input array or its shape.
    shape : Optional[sequence of ints]
        Number of FFT points along each transformed axis. Default is input
        size along specified axes.
    axes : Optional[sequence of ints]
        Axes along which to perform the fft. Default is all axes (or the
        trailing ``len(shape)`` axes when only `shape` is given).
    fft_pair : Optional[boolean]
        Indicates whether or not to also return a forward fft function.
        Default is False.
    crop_fft : Optional[boolean]
        Indicates that output from the fft function should be cropped to
        match the input to the ifft function. Default is False.

    Returns
    -------
    planned_ifftn : callable
        Planned ifftn function.
    planned_fftn : callable
        Planned fftn function. Returned only if fft_pair is True.

    Raises
    ------
    ValueError
        If fft_pair is True and `shape` is smaller than the input along any
        transformed axis.
    """
    # Get shape of input
    a_shape = _utils.get_shape(a)
    n_dim = len(a_shape)

    # Accept a bare int for either argument; normalise before use below.
    if axes is not None and np.asarray(axes).ndim == 0:
        axes = (axes,)
    if shape is not None and np.asarray(shape).ndim == 0:
        shape = (shape,)

    if shape is None:
        if axes is None:
            shape = list(a_shape)
        else:
            shape = [a_shape[axis] for axis in axes]

    if axes is None:
        # Transform the trailing len(shape) axes.
        axes = list(range(n_dim - len(shape), n_dim))

    # Full-rank shape of the transform output
    fft_shape = list(a_shape)
    for size, axis in zip(shape, axes):
        fft_shape[axis] = size

    has_smaller_axis = any(s1 < s2 for s1, s2 in zip(a_shape, fft_shape))
    has_larger_axis = any(s1 > s2 for s1, s2 in zip(a_shape, fft_shape))

    # Define ifft function
    def planned_ifftn(x):
        return np.fft.ifftn(x, shape, axes)

    if fft_pair is not True:
        return planned_ifftn

    # Define fftn function
    if has_larger_axis:
        raise ValueError("Number of FFT points must be equal to or greater "
                         "than the signal length for each axis when "
                         "returning an FFT pair.")

    if has_smaller_axis and crop_fft:
        # Crop each transformed axis back to the input size (tuple index).
        slices = [slice(None)] * n_dim
        for axis in axes:
            slices[axis] = slice(0, a_shape[axis])
        slices = tuple(slices)

        def planned_fftn(x):
            return np.fft.fftn(x, shape, axes)[slices]
    else:
        def planned_fftn(x):
            return np.fft.fftn(x, shape, axes)

    return planned_ifftn, planned_fftn
def irfftn(a, shape=None, axes=None, fft_pair=False):
    """Returns a planned function that computes the real-valued N-D inverse
    DFT of an array.

    Parameters
    ----------
    a : array-like or shape
        An input array or its shape.
    shape : Optional[sequence of ints]
        Shape of the output (length of each transformed axis).
        Default is input size along specified axes, except for the last
        transformed axis, where ``m`` input bins give ``2 * (m - 1)``
        points.
    axes : Optional[sequence of ints]
        Axes along which to perform the fft. Default is all axes (or the
        trailing ``len(shape)`` axes when only `shape` is given).
    fft_pair : Optional[boolean]
        Indicates whether or not to also return a forward rfft function.
        Default is False.

    Returns
    -------
    planned_irfftn : callable
        Planned irfftn function.
    planned_rfftn : callable
        Planned rfftn function. Returned only if fft_pair is True.

    Raises
    ------
    ValueError
        If fft_pair is True and `shape` is smaller than the implied signal
        length along any transformed axis.
    """
    # Get shape of input
    a_shape = _utils.get_shape(a)
    n_dim = len(a_shape)

    # Accept a bare int for either argument; normalise before use below.
    if axes is not None and np.asarray(axes).ndim == 0:
        axes = (axes,)
    if shape is not None and np.asarray(shape).ndim == 0:
        shape = (shape,)

    if shape is None:
        if axes is None:
            shape = list(a_shape)
        else:
            shape = [a_shape[axis] for axis in axes]
        # `shape` is indexed per transformed axis, so the last transformed
        # axis is always shape[-1] (not shape[axes[-1]]).
        shape[-1] = 2 * (shape[-1] - 1)

    if axes is None:
        # Transform the trailing len(shape) axes.
        axes = list(range(n_dim - len(shape), n_dim))

    # Real-signal shape implied by the input spectrum
    a_shape_out = list(a_shape)
    a_shape_out[axes[-1]] = 2 * (a_shape_out[axes[-1]] - 1)

    # Compare axis-by-axis: `shape` has one entry per transformed axis.
    has_larger_axis = any(a_shape_out[axis] > size
                          for size, axis in zip(shape, axes))

    # Define ifft function
    def planned_irfftn(x):
        return np.fft.irfftn(x, shape, axes)

    if fft_pair is not True:
        return planned_irfftn

    # Define fftn function
    if has_larger_axis:
        raise ValueError("Number of FFT points must be equal to or greater "
                         "than the signal length for each axis when "
                         "returning an FFT pair.")

    def planned_rfftn(x):
        return np.fft.rfftn(x, shape, axes)

    return planned_irfftn, planned_rfftn
def irfftn(a, shape=None, axes=None, fft_pair=False):
    """Returns a planned function that computes the real-valued N-D inverse
    DFT of an array, backed by pre-planned pyfftw objects.

    Parameters
    ----------
    a : array-like or shape
        An input array or its shape.
    shape : Optional[sequence of ints]
        Shape of the real output along the transformed axes. Default is the
        input size along the specified axes, except for the last transformed
        axis, where ``m`` input bins give ``2 * (m - 1)`` points.
    axes : Optional[sequence of ints]
        Axes along which to perform the fft. Default is all axes (or the
        trailing ``len(shape)`` axes when only `shape` is given).
    fft_pair : Optional[boolean]
        Indicates whether or not to also return a forward rfft function.
        Default is False.

    Returns
    -------
    planned_irfftn : callable
        Planned irfftn function.
    planned_rfftn : callable
        Planned rfftn function. Returned only if fft_pair is True.

    Raises
    ------
    ValueError
        If fft_pair is True and the input is larger than the planned
        spectrum along any axis.
    """
    # Get shape of input
    a_shape = _utils.get_shape(a)
    n_dim = len(a_shape)

    # Accept a bare int for either argument; normalise before use below.
    if axes is not None and np.asarray(axes).ndim == 0:
        axes = (axes,)
    if shape is not None and np.asarray(shape).ndim == 0:
        shape = (shape,)

    if shape is None:
        if axes is None:
            shape = list(a_shape)
        else:
            shape = [a_shape[axis] for axis in axes]
        # `shape` is indexed per transformed axis, so the last transformed
        # axis is always shape[-1] (not shape[axes[-1]]).
        shape[-1] = 2 * (shape[-1] - 1)
    shape = list(shape)

    if axes is None:
        # Transform the trailing len(shape) axes.
        axes = list(range(n_dim - len(shape), n_dim))

    # Compute FFTW shapes: u is the real signal, v the one-sided spectrum.
    u_shape = list(a_shape)
    for size, axis in zip(shape, axes):
        u_shape[axis] = size
    v_shape = list(u_shape)
    v_shape[axes[-1]] = v_shape[axes[-1]] // 2 + 1

    # Crop oversized input down to the planned spectrum (tuple index).
    slices = [slice(None)] * n_dim
    for axis in axes:
        slices[axis] = slice(0, v_shape[axis])
    slices = tuple(slices)

    # Set data types
    if np.asarray(a).dtype.name in ('float32', 'int32', 'complex64'):
        dtype = 'float32'
        fft_dtype = 'complex64'
    else:
        dtype = 'float64'
        fft_dtype = 'complex128'

    # NOTE(review): n_byte_align_empty is deprecated in newer pyfftw in
    # favour of pyfftw.empty_aligned — confirm against the installed version.
    u = pyfftw.n_byte_align_empty(u_shape, 16, dtype)
    v = pyfftw.n_byte_align_empty(v_shape, 16, fft_dtype)
    ifft_obj = pyfftw.FFTW(v, u, direction='FFTW_BACKWARD', axes=axes)

    has_smaller_axis = any(s1 < s2 for s1, s2 in zip(a_shape, v_shape))
    has_larger_axis = any(s1 > s2 for s1, s2 in zip(a_shape, v_shape))

    # Define ifft function: crop and/or zero-pad x to fit the v buffer.
    if has_smaller_axis and has_larger_axis:
        def planned_irfftn(x):
            v[:] = _utils.pad_array(x[slices], v_shape)
            return ifft_obj().copy()

    elif has_larger_axis:
        def planned_irfftn(x):
            v[:] = x[slices]
            return ifft_obj().copy()

    elif has_smaller_axis:
        def planned_irfftn(x):
            v[:] = _utils.pad_array(x, v_shape)
            return ifft_obj().copy()

    else:
        def planned_irfftn(x):
            v[:] = x
            return ifft_obj().copy()

    if fft_pair is not True:
        return planned_irfftn

    # Define rfftn function
    if has_larger_axis:
        raise ValueError("Number of FFT points must be equal to or greater "
                         "than the signal length for each axis when "
                         "returning an FFT pair")

    fft_obj = pyfftw.FFTW(u, v, direction="FFTW_FORWARD", axes=axes)

    def planned_rfftn(x):
        u[:] = x
        return fft_obj().copy()

    return planned_irfftn, planned_rfftn
def irfftn_pair(a, shape=None, axes=None):
    """Returns a planned function that computes the real-valued N-D inverse
    DFT of an array, together with the matching forward rfftn.

    Parameters
    ----------
    a : array-like or shape
        An input array or its shape.
    shape : Optional[sequence of ints]
        Shape of the real output along the transformed axes. Default is the
        input size along the specified axes, except for the last transformed
        axis, where ``m`` input bins give ``2 * (m - 1)`` points.
    axes : Optional[sequence of ints]
        Axes along which to perform the fft. Default is all axes (or the
        trailing ``len(shape)`` axes when only `shape` is given).

    Returns
    -------
    planned_irfftn : callable
        Planned irfftn function.
    planned_rfftn : callable
        Planned rfftn function.

    Raises
    ------
    ValueError
        If `shape` is smaller than the implied signal length along any
        transformed axis.
    """
    # Get shape of input
    a_shape = _utils.get_shape(a)
    n_dim = len(a_shape)

    # Accept a bare int for either argument; normalise before use below.
    if axes is not None and np.asarray(axes).ndim == 0:
        axes = (axes, )
    if shape is not None and np.asarray(shape).ndim == 0:
        shape = (shape, )

    if shape is None:
        if axes is None:
            shape = list(a_shape)
        else:
            shape = [a_shape[axis] for axis in axes]
        # `shape` is indexed per transformed axis, so the last transformed
        # axis is always shape[-1] (not shape[axes[-1]]).
        shape[-1] = 2 * (shape[-1] - 1)

    if axes is None:
        # Transform the trailing len(shape) axes.
        axes = list(range(n_dim - len(shape), n_dim))

    # Real-signal shape implied by the input spectrum
    a_shape_out = list(a_shape)
    a_shape_out[axes[-1]] = 2 * (a_shape_out[axes[-1]] - 1)

    # Compare axis-by-axis: `shape` has one entry per transformed axis.
    has_larger_axis = any(a_shape_out[axis] > size
                          for size, axis in zip(shape, axes))

    # Define ifft function
    def planned_irfftn(x):
        # scipy.fftpack.fftn doesn't handle mixed sign axes well
        return np.fft.irfftn(x, shape, axes)

    # Define fftn function
    if has_larger_axis:
        raise ValueError("Number of FFT points must be equal to or greater "
                         "than the signal length for each axis when "
                         "returning an FFT pair.")

    def planned_rfftn(x):
        # scipy.fftpack.fftn doesn't handle mixed sign axes well
        return np.fft.rfftn(x, shape, axes)

    return planned_irfftn, planned_rfftn
def fft(a, nfft=None, axis=-1, fft_pair=False, crop_ifft=False):
    """Returns a planned function that computes the 1-D DFT of a sequence
    or array, backed by pre-planned pyfftw objects.

    Parameters
    ----------
    a : number, shape or array-like
        An input array, its shape or length.
    nfft : Optional[int]
        Number of FFT points. Default is input size along specified axis.
        Longer inputs are truncated and shorter inputs zero-padded.
    axis : Optional[int]
        Axis along which to perform the fft. Default is -1.
    fft_pair : Optional[boolean]
        Indicates whether or not to also return an ifft function.
        Default is False.
    crop_ifft : Optional[boolean]
        Indicates whether the planned ifft function should crop its
        output to match input size. Default is False.

    Returns
    -------
    planned_fft : callable
        Planned fft function.
    planned_ifft : callable
        Planned ifft function. Returned only if fft_pair is True.

    Raises
    ------
    ValueError
        If fft_pair is True and `nfft` is smaller than the input size.
    """
    # Get shape
    shape = list(_utils.get_shape(a))
    n = shape[axis]

    if nfft is None:
        nfft = n
    # The FFTW buffers must hold nfft points along the transform axis,
    # whether nfft was given explicitly or defaulted.
    shape[axis] = nfft

    # Set data type
    if np.asarray(a).dtype.name in ('float32', 'int32'):
        dtype = 'complex64'
    else:
        dtype = 'complex128'

    # Set up input and output arrays and FFT object
    u = pyfftw.n_byte_align_empty(shape, 16, dtype)
    v = pyfftw.n_byte_align_empty(shape, 16, dtype)
    fft_obj = pyfftw.FFTW(u, v, direction='FFTW_FORWARD', axes=(axis,))

    # Define fft function: fit x into the nfft-long u buffer.
    if n == nfft:
        def planned_fft(x):
            u[:] = x
            return fft_obj().copy()

    elif n < nfft:
        def planned_fft(x):
            u[:] = _utils.pad_array(x, shape)
            return fft_obj().copy()

    else:
        # Truncate longer input to nfft points (tuple index for NumPy).
        in_slices = [slice(None)] * len(shape)
        in_slices[axis] = slice(0, nfft)
        in_slices = tuple(in_slices)

        def planned_fft(x):
            u[:] = x[in_slices]
            return fft_obj().copy()

    if fft_pair is not True:
        return planned_fft

    # Define ifft function
    if n > nfft:
        raise ValueError("NFFT must be at least equal to signal "
                         "length when returning an FFT pair.")

    ifft_obj = pyfftw.FFTW(v, u, direction='FFTW_BACKWARD', axes=(axis,))

    if n < nfft and crop_ifft:
        # Drop the zero-padding region along axis.
        out_slices = [slice(None)] * len(shape)
        out_slices[axis] = slice(None, n)
        out_slices = tuple(out_slices)

        def planned_ifft(x):
            v[:] = x
            return ifft_obj().copy()[out_slices]
    else:
        def planned_ifft(x):
            v[:] = x
            return ifft_obj().copy()

    return planned_fft, planned_ifft
def ifftn_pair(a, shape=None, axes=None, crop_fft=False):
    """Returns a planned function that computes the N-D inverse DFT of an
    array, together with the matching forward DFT.

    Parameters
    ----------
    a : array-like or shape
        An input array or its shape.
    shape : Optional[sequence of ints]
        Number of FFT points along each transformed axis. Default is input
        size along specified axes.
    axes : Optional[sequence of ints]
        Axes along which to perform the fft. Default is all axes (or the
        trailing ``len(shape)`` axes when only `shape` is given).
    crop_fft : Optional[boolean]
        Indicates that the fft function should crop its output to match
        the shape of the input to the ifft function. Default is False.

    Returns
    -------
    planned_ifftn : callable
        Planned ifftn function.
    planned_fftn : callable
        Planned fftn function.

    Raises
    ------
    ValueError
        If `shape` is smaller than the input along any transformed axis.
    """
    # Get shape of input
    a_shape = _utils.get_shape(a)
    n_dim = len(a_shape)

    # Accept a bare int for either argument; normalise before use below.
    if axes is not None and np.asarray(axes).ndim == 0:
        axes = (axes, )
    if shape is not None and np.asarray(shape).ndim == 0:
        shape = (shape, )

    if shape is None:
        if axes is None:
            shape = list(a_shape)
        else:
            shape = [a_shape[axis] for axis in axes]

    if axes is None:
        # Transform the trailing len(shape) axes.
        axes = list(range(n_dim - len(shape), n_dim))

    # Full-rank shape of the transform output
    fft_shape = list(a_shape)
    for size, axis in zip(shape, axes):
        fft_shape[axis] = size

    has_smaller_axis = any(s1 < s2 for s1, s2 in zip(a_shape, fft_shape))
    has_larger_axis = any(s1 > s2 for s1, s2 in zip(a_shape, fft_shape))

    # Define ifft function
    def planned_ifftn(x):
        # scipy.fftpack.fftn doesn't handle mixed sign axes well
        return np.fft.ifftn(x, shape, axes)

    # Define fftn function
    if has_larger_axis:
        raise ValueError("Number of FFT points must be equal to or greater "
                         "than the signal length for each axis when "
                         "returning an FFT pair.")

    if has_smaller_axis and crop_fft:
        # Crop each transformed axis back to the input size (tuple index).
        slices = [slice(None)] * n_dim
        for axis in axes:
            slices[axis] = slice(0, a_shape[axis])
        slices = tuple(slices)

        def planned_fftn(x):
            # scipy.fftpack.fftn doesn't handle mixed sign axes well
            return np.fft.fftn(x, shape, axes)[slices]
    else:
        def planned_fftn(x):
            # scipy.fftpack.fftn doesn't handle mixed sign axes well
            return np.fft.fftn(x, shape, axes)

    return planned_ifftn, planned_fftn