Пример #1
0
def slice_grid_array(
    grid: GridArray,
    constant: Union[str, int],
    value: float,
) -> GridArray:
    """
    Takes a slice of a `GridArray` at a constant value of a given axis.

    The slice is taken at the axis point nearest to ``value`` and the sliced
    axis is removed from the result, which therefore has one dimension less
    than ``grid``. Name, label, unit and time of ``grid`` are carried over.

    Arguments:
        grid:
            Grid to take the slice of. It must have at least one dimension.
        constant:
            Name or index that defines the axis taken to be constant in the slice.
        value:
            Value of the axis at which the slice is taken. It must be greater
            than or equal to the minimum of the axis and strictly smaller than
            its maximum.

    Returns:
        Slice of ``grid``.

    Raises:
        ValueError: If ``grid`` is zero-dimensional, or if ``value`` lies
            outside of the accepted range of the selected axis.

    Examples:
        Obtain a slice of a two-dimensional array.

        ```pycon
        >>> from nata.containers import GridArray
        >>> from nata.containers import Axis
        >>> import numpy as np
        >>> x = np.arange(5)
        >>> data = np.arange(25).reshape((5, 5))
        >>> grid = GridArray.from_array(data, axes=[Axis(x), Axis(x)])
        >>> grid.slice(constant=0, value=1).to_numpy()
        array([5, 6, 7, 8, 9]) # data[1, :], the second row
        >>> grid.slice(constant=1, value=1).to_numpy()
        array([ 1,  6, 11, 16, 21]) # data[:, 1], the second column
        ```
    """

    if grid.ndim < 1:
        raise ValueError("slice is not available for 0 dimensional GridArrays")

    # resolve the axis (given by name or index) that is held constant
    slice_axis = get_slice_axis(grid, constant)
    axis = grid.axes[slice_axis]

    # fetch the axis points once; they feed both the range check and the
    # nearest-neighbour lookup
    axis_points = axis.to_dask()

    # NOTE(review): the upper bound is exclusive, i.e. a value equal to the
    # axis maximum is rejected -- confirm this is intended.
    if value < np.min(axis_points) or value >= np.max(axis_points):
        raise ValueError(f"out of range value for axis '{constant}'")

    # get index of nearest neighbour
    slice_idx = (np.abs(axis_points - value)).argmin(axis=-1)

    # build data slice: everything along the other axes, one index along the
    # sliced one
    n_axes = len(grid.axes)
    data_slice = [slice(None)] * n_axes
    data_slice[slice_axis] = slice_idx

    # Drop the sliced axis by position rather than by object identity, so that
    # exactly one axis is removed even when the same Axis object is used for
    # several dimensions. The modulo maps a negative index to its position.
    dropped = slice_axis % n_axes
    remaining_axes = [ax for idx, ax in enumerate(grid.axes) if idx != dropped]

    return GridArray.from_array(
        grid.to_dask()[tuple(data_slice)],
        name=grid.name,
        label=grid.label,
        unit=grid.unit,
        axes=remaining_axes,
        time=grid.time,
    )
Пример #2
0
def transpose_grid_array(
    grid: GridArray,
    axes: Optional[list] = None,
) -> GridArray:
    """Reverses or permutes the axes of a `GridArray`.

    Both the data and the list of axes of ``grid`` are permuted; name, label,
    unit and time are carried over unchanged.

    Parameters
    ----------
    grid: :class:`nata.containers.GridArray`
         Grid to be transposed.
    axes: ``list``, optional
         List of integers and/or strings that identify the permutation of the
         axes. The i'th axis of the returned `GridArray` will correspond to the
         axis numbered/labeled axes[i] of the input. If not specified, the
         order of the axes is reversed.

    Returns
    -------
    :class:`nata.containers.GridArray`:
        Transpose of ``grid``.

    Raises
    ------
    ValueError
        If the resolved axes do not name every axis of ``grid`` exactly once.

    Examples
    --------
    Transpose a three-dimensional array.

    >>> from nata.containers import GridArray
    >>> import numpy as np
    >>> data = np.arange(96).reshape((8, 4, 3))
    >>> grid = GridArray.from_array(data)
    >>> grid.transpose().shape
    (3, 4, 8)
    >>> grid.transpose(axes=[0,2,1]).shape
    (8, 3, 4)

    """

    # get transpose axes
    tr_axes = get_transpose_axes(grid, axes)

    # A valid permutation has one entry per dimension and no repeats. Compare
    # by value: an identity test on integers only works by accident of the
    # interpreter's small-integer caching.
    if len(tr_axes) != grid.ndim or len(set(tr_axes)) != grid.ndim:
        raise ValueError("invalid transpose axes")

    return GridArray.from_array(
        da.transpose(grid.to_dask(), axes=tr_axes),
        name=grid.name,
        label=grid.label,
        unit=grid.unit,
        axes=[grid.axes[axis] for axis in tr_axes],
        time=grid.time,
    )
Пример #3
0
def fft_grid_array(
    grid: GridArray,
    axes: Optional[list] = None,
    comp: str = "abs",
) -> GridArray:
    """Computes the Fast Fourier Transform (FFT) of a
    :class:`nata.containers.GridArray` using `numpy's fft module`__.

    .. _fft: https://numpy.org/doc/stable/reference/routines.fft.html
    __ fft_

    The axes over which the FFT is computed are transformed such that
    the zero frequency bins are centered. Each transformed axis is replaced
    by a new axis named ``k_<name>`` whose points are the shifted FFT
    frequencies, computed with a sample spacing equal to the extent of the
    original axis divided by its number of points and by ``2 * pi``. Axes
    that are not transformed are kept as they are.

    Parameters
    ----------
    grid: :class:`nata.containers.GridArray`
        Grid to be transformed. It must have at least one dimension.
    axes: ``list``, optional
        List of integers and/or strings that identify the axes over which
        to compute the FFT.
    comp: ``{'abs', 'real', 'imag', 'full'}``, optional
        Defines the component of the FFT selected for output. Available
        values are  ``'abs'`` (default), ``'real'``, ``'imag'`` and
        ``'full'``, which correspond to the absolute value, real component,
        imaginary component and full (complex) result of the FFT,
        respectively.

    Returns
    -------
    :class:`nata.containers.GridArray`:
        Selected FFT component of ``grid``.

    Raises
    ------
    ValueError
        If ``grid`` is zero-dimensional or ``comp`` is not one of the
        available values.

    Examples
    --------
    Obtain the FFT of a one-dimensional array.

    >>> from nata.containers import GridArray
    >>> import numpy as np
    >>> x = np.arange(100)
    >>> grid = GridArray.from_array(np.exp(-(x-len(x)/2)**2))
    >>> fft_grid = grid.fft()

    # TODO: add an image here with a plot of the FFT

    Compute the FFT over the first axis of a two-dimensional array.

    >>> from nata.containers import GridArray
    >>> import numpy as np
    >>> x = np.linspace(0, 10*np.pi)
    >>> y = np.linspace(0, 10*np.pi)
    >>> X, Y = np.meshgrid(x, y, indexing="ij")
    >>> grid = GridArray.from_array(np.sin(X) + np.sin(2*Y))
    >>> fft_grid = grid.fft(axes=[0])

    # TODO: add an image here with a plot of the FFT

    """

    if grid.ndim < 1:
        raise ValueError("fft is not available for 0 dimensional GridArrays")

    # resolve which axes take part in the transform
    fft_axes = get_fft_axes(grid, axes)

    # assemble the axes of the result, one per axis of the input
    new_axes = []
    for position, current in enumerate(grid.axes):
        if position not in fft_axes:
            # untouched by the transform: reuse the axis as it is
            new_axes.append(current)
            continue

        # transformed axis: build its Fourier counterpart
        n_points = current.shape[-1]
        extent = np.max(current.to_dask()) - np.min(current.to_dask())
        spacing = extent / n_points / (2.0 * np.pi)
        frequencies = fft.fftshift(fft.fftfreq(n_points, spacing))

        # only attach an inverse unit when the original axis has one
        if current.unit:
            k_unit = f"({current.unit})^{{-1}}"
        else:
            k_unit = ""

        k_axis = Axis(
            frequencies,
            name=f"k_{current.name}",
            label=f"k_{{{current.label}}}",
            unit=k_unit,
        )
        new_axes.append(k_axis)

    # transform the data and center the zero frequency bins
    spectrum = fft.fftshift(
        fft.fftn(grid.to_dask(), axes=fft_axes),
        axes=fft_axes,
    )

    # reject unknown components before picking one
    if comp not in ("abs", "real", "imag", "full"):
        raise ValueError(f"invalid fft component '{comp}'")

    # keep only the requested component and label it accordingly
    if comp == "full":
        label = f"{grid.label}"
    elif comp == "imag":
        spectrum = np.imag(spectrum)
        label = f"\\Im({grid.label})"
    elif comp == "real":
        spectrum = np.real(spectrum)
        label = f"\\Re({grid.label})"
    else:
        spectrum = np.abs(spectrum)
        label = f"|{grid.label}|"

    # wrap the result, carrying over the metadata of the input grid
    return GridArray.from_array(
        spectrum,
        name=grid.name,
        label=label,
        unit=grid.unit,
        axes=new_axes,
        time=grid.time,
    )