示例#1
0
def fft_grid_dataset(
    dataset: GridDataset, type: Optional[str] = "abs",
) -> GridDataset:
    """Computes the Fast Fourier Transform (FFT) of a single/multiple\
       iteration :class:`nata.containers.GridDataset` along all grid axes\
       using `numpy's fft module`__.

        .. _fft: https://numpy.org/doc/stable/reference/routines.fft.html
        __ fft_

        Parameters
        ----------
        dataset: :class:`nata.containers.GridDataset`
            Dataset to transform. The FFT is taken along every grid axis;
            the leading (iteration) axis of the data is not transformed.

        type: ``{'abs', 'real', 'imag', 'full'}``, optional
            Defines the component of the FFT selected for output. Available
            values are ``'abs'`` (default), ``'real'``, ``'imag'`` and
            ``'full'``, which correspond to the absolute value, real component,
            imaginary component and full (complex) result of the FFT,
            respectively. ``None`` is treated as ``'full'``.

        Returns
        -------
        :class:`nata.containers.GridDataset`:
            Selected FFT component along all grid axes of ``dataset``, with
            the zero-frequency component shifted to the center of each axis.
            Each grid axis is replaced by its (shifted) wavenumber axis.

        Raises
        ------
        ValueError
            If ``type`` is not one of the values listed above.

        Examples
        --------
        To obtain the FFT of a :class:`nata.containers.GridDataset`, a simple
        call to the ``fft()`` method is enough. In the following example, we
        compute the FFT of a one-dimensional
        :class:`nata.containers.GridDataset`.

        >>> from nata.containers import GridDataset
        >>> import numpy as np
        >>> x = np.linspace(0, 99, 100)
        >>> arr = np.exp(-(x-len(x)/2)**2)
        >>> ds = GridDataset(arr[np.newaxis])
        >>> ds_fft = ds.fft()

    """
    # Validate the requested component up front so a typo does not silently
    # return the full complex FFT (and before doing any FFT work).
    component = "full" if type is None else type
    if component not in ("abs", "real", "imag", "full"):
        raise ValueError(
            f"Invalid FFT component `{type}`, "
            "expected one of 'abs', 'real', 'imag' or 'full'"
        )

    fft_data = np.array(dataset)
    # Axis 0 of the (possibly expanded) data is the iteration axis; the grid
    # axes follow it and are the ones transformed.
    fft_axes = tuple(range(1, len(dataset.grid_shape) + 1))

    # A single-iteration dataset has no leading iteration axis, so add one to
    # keep the axis layout identical to the multi-iteration case.
    if len(dataset) <= 1:
        fft_data = fft_data[np.newaxis]

    fft_data = fft.fftn(fft_data, axes=fft_axes)
    # Move the zero-frequency component to the center of each grid axis.
    fft_data = fft.fftshift(fft_data, axes=fft_axes)

    if component == "real":
        fft_data = np.real(fft_data)
        label = f"Re(FFT({dataset.label}))"
        name = f"fftr_{dataset.name}"
    elif component == "imag":
        fft_data = np.imag(fft_data)
        label = f"Im(FFT({dataset.label}))"
        name = f"ffti_{dataset.name}"
    elif component == "abs":
        fft_data = np.abs(fft_data)
        label = f"|FFT({dataset.label})|"
        name = f"ffta_{dataset.name}"
    else:  # "full": keep the complex result
        label = f"FFT({dataset.label})"
        name = f"fft_{dataset.name}"

    axes = []
    for a in dataset.axes["grid_axes"]:
        # One row of wavenumbers per entry of the axis (single pass over `a`).
        freqs = []
        for a_ts in a:
            values = np.array(a_ts)
            n_points = len(values)
            # Dividing the spacing by 2*pi makes fftfreq return angular
            # wavenumbers (k = 2*pi*f) instead of cycles per unit length.
            # NOTE(review): the spacing uses N rather than N - 1 points; this
            # is exact only if the axis values span the full cell extent —
            # confirm against how grid axes are constructed.
            delta = (np.max(values) - np.min(values)) / n_points / (2.0 * np.pi)
            freqs.append(fft.fftfreq(n_points, delta))

        # Shift the wavenumbers the same way as the data so they line up.
        axis_data = fft.fftshift(freqs, axes=-1)
        axes.append(
            GridAxis(
                axis_data,
                name=f"k_{a.name}",
                label=f"k_{{{a.label}}}",
                unit=f"1/({a.unit})",
            )
        )

    return GridDataset(
        fft_data,
        name=name,
        label=label,
        unit=dataset.unit,
        grid_axes=axes,
        time=dataset.axes["time"],
        iteration=dataset.axes["iteration"],
    )
示例#2
0
文件: lineout.py 项目: tsung1029/nata
def lineout_grid_dataset(
    dataset: GridDataset,
    fixed: Union[str, int],
    value: float,
) -> GridDataset:
    """Takes a lineout across a two-dimensional, single/multiple iteration\
       :class:`nata.containers.GridDataset`.

        Parameters
        ----------
        dataset: :class:`nata.containers.GridDataset`
            Two-dimensional grid dataset from which the lineout is taken.

        fixed: :class:`str` or :class:`int`
            Selection of the axis along which the taken lineout is constant.

            * if it is a string, then it must match the ``name`` property of an
              existing grid axis in ``dataset``.
            * if it is an integer, then it must match the index of a grid axis
              in ``dataset`` (i.e. `0` or `1`; negative indices `-2` and `-1`
              are also accepted).

        value: scalar
            Value between the minimum and maximum of the axis selected through
            ``fixed`` over which the lineout is taken. The grid point closest
            to ``value`` is used.

        Returns
        -------
        :class:`nata.containers.GridDataset`:
            One-dimensional :class:`nata.containers.GridDataset`.

        Raises
        ------
        ValueError
            If ``dataset`` is not two-dimensional, if ``fixed`` does not
            identify one of its two grid axes, or if ``value`` lies outside
            the range of the selected axis.

        Examples
        --------
        The following example shows how to obtain a lineout from a
        two-dimensional :class:`nata.containers.GridDataset`. Since no axes are
        attributed to the dataset in this example, they are automatically
        generated with no names, and ``fixed`` must be an integer.

        >>> from nata.containers import GridDataset
        >>> import numpy as np
        >>> arr = np.arange(25).reshape((5,5))
        >>> ds = GridDataset(arr[np.newaxis])
        >>> lo = ds.lineout(fixed=0, value=2)
        >>> lo.data
        array([10, 11, 12, 13, 14])

    """
    if len(dataset.grid_shape) != 2:
        raise ValueError(
            "Grid lineouts are only supported for two-dimensional grid datasets"
        )

    # get handle for grid axes
    axes = dataset.axes["grid_axes"]

    if isinstance(fixed, str):
        # resolve the axis index from its name
        ax_idx = next(
            (key for key, ax in enumerate(axes) if ax.name == fixed), None
        )
        if ax_idx is None:
            raise ValueError(
                f"Axis `{fixed}` could not be found in dataset `{dataset}`")
    else:
        # Only two grid axes exist. Previously any other index either failed
        # with an opaque IndexError or (for -1/-2) matched neither branch
        # below and crashed with an UnboundLocalError on `lo_data`.
        if fixed not in (0, 1, -1, -2):
            raise ValueError(
                f"Invalid axis index `{fixed}`, expected 0 or 1")
        # normalize Python-style negative indices to 0/1
        ax_idx = int(fixed) % 2

    # build axis values
    axis = axes[ax_idx]

    if value < np.min(axis) or value > np.max(axis):
        raise ValueError(f"Out of range value for fixed `{fixed}`")

    values = np.array(axis)
    # index of the grid point closest to `value`
    idx = int(np.abs(values - value).argmin())
    if values.ndim > 1:
        # argmin returns a flat index; map it back to a position along the
        # last dimension so it can index the data.
        # NOTE(review): assumes extra leading dimensions of the axis values
        # index iterations with identical grids — TODO confirm.
        idx = int(np.unravel_index(idx, values.shape)[-1])

    data = np.array(dataset)
    # single-iteration data has no leading iteration axis
    multi_iteration = len(dataset) > 1

    # get lineout
    if ax_idx == 0:
        lo_data = data[:, idx, :] if multi_iteration else data[idx, :]
        lo_axis = axes[1]
    else:
        lo_data = data[:, :, idx] if multi_iteration else data[:, idx]
        lo_axis = axes[0]

    return GridDataset(
        # restore a leading iteration axis for single-iteration datasets
        lo_data if multi_iteration else lo_data[np.newaxis],
        name=f"{dataset.name}_lineout",
        label=dataset.label,
        unit=dataset.unit,
        grid_axes=[lo_axis],
        time=dataset.axes["time"],
        iteration=dataset.axes["iteration"],
    )