def fft_grid_dataset(
    dataset: GridDataset,
    type: Optional[str] = "abs",
) -> GridDataset:
    """Computes the Fast Fourier Transform (FFT) of a single/multiple
    iteration :class:`nata.containers.GridDataset` along all grid axes
    using `numpy's fft module`__.

    .. _fft: https://numpy.org/doc/stable/reference/routines.fft.html

    __ fft_

    Parameters
    ----------
    dataset: :class:`nata.containers.GridDataset`
        Dataset to transform. All grid axes are transformed; the time
        (iteration) dimension is not.

    type: ``{'abs', 'real', 'imag', 'full'}``, optional
        Defines the component of the FFT selected for output. Available
        values are ``'abs'`` (default), ``'real'``, ``'imag'`` and
        ``'full'``, which correspond to the absolute value, real component,
        imaginary component and full (complex) result of the FFT,
        respectively. ``None`` is treated as ``'full'``.

    Returns
    -------
    :class:`nata.containers.GridDataset`:
        Selected FFT component along all grid axes of ``dataset``. The
        spectrum is shifted so that the zero wavenumber is centered, and each
        grid axis ``x`` is replaced by an angular-wavenumber axis ``k_x``
        with unit ``1/(unit of x)``.

    Raises
    ------
    ValueError
        If ``type`` is not one of the supported values.

    Examples
    --------
    To obtain the FFT of a :class:`nata.containers.GridDataset`, a simple
    call to the ``fft()`` method is enough. In the following example, we
    compute the FFT of a one-dimensional
    :class:`nata.containers.GridDataset`.

    >>> from nata.containers import GridDataset
    >>> import numpy as np
    >>> x = np.arange(100)
    >>> arr = np.exp(-(x-len(x)/2)**2)
    >>> ds = GridDataset(arr[np.newaxis])
    >>> ds_fft = ds.fft()
    """
    # Validate the requested component up front so a typo fails loudly
    # instead of silently returning the complex ("full") result.
    component = "full" if type is None else type
    if component not in ("abs", "real", "imag", "full"):
        raise ValueError(
            f"Invalid FFT component `{component}`; expected one of "
            "'abs', 'real', 'imag' or 'full'"
        )

    fft_data = np.array(dataset)
    if not len(dataset) > 1:
        # Single-iteration data has no leading time dimension; add one so
        # the grid dimensions always start at index 1.
        fft_data = fft_data[np.newaxis]

    fft_axes = np.arange(len(dataset.grid_shape)) + 1
    fft_data = fft.fftn(fft_data, axes=fft_axes)
    # Move the zero-wavenumber component to the center of each grid axis.
    fft_data = fft.fftshift(fft_data, axes=fft_axes)

    if component == "real":
        fft_data = np.real(fft_data)
        label = f"Re(FFT({dataset.label}))"
        name = f"fftr_{dataset.name}"
    elif component == "imag":
        fft_data = np.imag(fft_data)
        label = f"Im(FFT({dataset.label}))"
        name = f"ffti_{dataset.name}"
    elif component == "abs":
        fft_data = np.abs(fft_data)
        label = f"|FFT({dataset.label})|"
        name = f"ffta_{dataset.name}"
    else:  # "full": keep the complex result
        label = f"FFT({dataset.label})"
        name = f"fft_{dataset.name}"

    axes = []
    for a in dataset.axes["grid_axes"]:
        # One wavenumber array per time step of the axis.
        wavenumbers = []
        for a_ts in a:
            values = np.array(a_ts)
            n_points = len(values)
            if n_points < 2:
                # A single sample only has the zero-frequency (DC) term.
                wavenumbers.append(np.zeros(n_points))
                continue
            # n uniformly spaced points span n - 1 intervals.
            spacing = (np.max(values) - np.min(values)) / (n_points - 1)
            # Dividing the spacing by 2*pi turns fftfreq's cycles per unit
            # into angular wavenumbers (radians per unit).
            wavenumbers.append(fft.fftfreq(n_points, spacing / (2.0 * np.pi)))

        axis_data = fft.fftshift(wavenumbers, axes=-1)
        axes.append(
            GridAxis(
                axis_data,
                name=f"k_{a.name}",
                label=f"k_{{{a.label}}}",
                unit=f"1/({a.unit})",
            )
        )

    return GridDataset(
        fft_data,
        name=name,
        label=label,
        unit=dataset.unit,
        grid_axes=axes,
        time=dataset.axes["time"],
        iteration=dataset.axes["iteration"],
    )
def lineout_grid_dataset(
    dataset: GridDataset,
    fixed: Union[str, int],
    value: float,
) -> GridDataset:
    """Takes a lineout across a two-dimensional, single/multiple iteration
    :class:`nata.containers.GridDataset`.

    Parameters
    ----------
    dataset: :class:`nata.containers.GridDataset`
        Two-dimensional dataset from which the lineout is taken.

    fixed: :class:`str` or :class:`int`
        Selection of the axis along which the taken lineout is constant.

        * if it is a string, then it must match the ``name`` property of an
          existing grid axis in ``dataset``.
        * if it is an integer, then it must match the index of a grid axis
          in ``dataset`` (i.e. ``0`` or ``1``; the negative indices ``-2``
          and ``-1`` are also accepted).

    value: scalar
        Value between the minimum and maximum of the axis selected through
        ``fixed`` over which the lineout is taken. The grid point closest to
        ``value`` is used.

    Returns
    -------
    :class:`nata.containers.GridDataset`:
        One-dimensional :class:`nata.containers.GridDataset`.

    Raises
    ------
    ValueError
        If ``dataset`` is not two-dimensional, if ``fixed`` does not identify
        one of its grid axes, or if ``value`` lies outside the range of the
        selected axis.

    Examples
    --------
    The following example shows how to obtain a lineout from a
    two-dimensional :class:`nata.containers.GridDataset`. Since no axes are
    attributed to the dataset in this example, they are automatically
    generated with no names, and ``fixed`` must be an integer.

    >>> from nata.containers import GridDataset
    >>> import numpy as np
    >>> arr = np.arange(25).reshape((5,5))
    >>> ds = GridDataset(arr[np.newaxis])
    >>> lo = ds.lineout(fixed=0, value=2)
    >>> lo.data
    array([10, 11, 12, 13, 14])
    """
    if len(dataset.grid_shape) != 2:
        raise ValueError(
            "Grid lineouts are only supported for two-dimensional grid datasets"
        )

    # get handle for grid axes
    axes = dataset.axes["grid_axes"]

    if isinstance(fixed, str):
        # first grid axis whose name matches, or -1 if none does
        ax_idx = next((i for i, ax in enumerate(axes) if ax.name == fixed), -1)
        if ax_idx < 0:
            raise ValueError(
                f"Axis `{fixed}` could not be found in dataset `{dataset}`"
            )
    else:
        # Only two grid axes exist; normalize negative indices so the
        # branch selection below always matches one of them.
        if fixed not in (-2, -1, 0, 1):
            raise ValueError(
                f"Axis index `{fixed}` is out of range for a "
                "two-dimensional grid dataset"
            )
        ax_idx = fixed % 2

    axis = axes[ax_idx]
    if value < np.min(axis) or value > np.max(axis):
        raise ValueError(f"Out of range value for fixed `{fixed}`")

    values = np.array(axis)
    # Nearest grid point to ``value``. Reducing along the last dimension
    # gives one index per iteration when the axis values are stored per
    # iteration, instead of an index into the flattened array.
    idx = np.abs(values - value).argmin(axis=-1)

    data = np.array(dataset)
    if not len(dataset) > 1:
        # Single-iteration data has no leading time dimension; add one so
        # both cases are indexed the same way.
        data = data[np.newaxis]

    steps = np.arange(data.shape[0])
    idx = np.broadcast_to(idx, steps.shape)

    # get lineout: one row per iteration, shape (iterations, points)
    if ax_idx == 0:
        lo_data = data[steps, idx, :]
        lo_axis = axes[1]
    else:
        lo_data = data[steps, :, idx]
        lo_axis = axes[0]

    return GridDataset(
        lo_data,
        name=f"{dataset.name}_lineout",
        label=dataset.label,
        unit=dataset.unit,
        grid_axes=[lo_axis],
        time=dataset.axes["time"],
        iteration=dataset.axes["iteration"],
    )