def slice_grid_array(
    grid: GridArray,
    constant: Union[str, int],
    value: float,
) -> GridArray:
    """
    Takes a slice of a `GridArray` at a constant value of a given axis.

    The slice is taken at the grid point whose axis value is nearest to
    ``value``; the sliced axis is removed from the returned array.

    Arguments:
        grid: `GridArray` to be sliced. Must be at least one-dimensional.
        constant: Name or index that defines the axis taken to be constant
            in the slice.
        value: Value of the axis at which the slice is taken. Must lie in
            the closed range ``[min, max]`` of the selected axis values.

    Returns:
        Slice of ``grid``.

    Raises:
        ValueError: If ``grid`` is 0-dimensional or if ``value`` lies
            outside the range of the selected axis.

    Examples:
        Obtain a slice of a two-dimensional array.

        ```pycon
        >>> from nata.containers import GridArray
        >>> from nata.containers import Axis
        >>> import numpy as np
        >>> x = np.arange(5)
        >>> data = np.arange(25).reshape((5, 5))
        >>> grid = GridArray.from_array(data, axes=[Axis(x), Axis(x)])
        >>> grid.slice(constant=0, value=1).to_numpy()  # the second row
        array([5, 6, 7, 8, 9])
        >>> grid.slice(constant=1, value=1).to_numpy()  # the second column
        array([ 1,  6, 11, 16, 21])
        ```
    """
    if grid.ndim < 1:
        raise ValueError("slice is not available for 0 dimensional GridArrays")

    # resolve the axis name/index given by the caller into a position
    slice_axis = get_slice_axis(grid, constant)
    n_axes = len(grid.axes)
    axis = grid.axes[slice_axis]
    # normalise a possibly negative index so it can be compared with the
    # non-negative positions produced by ``enumerate`` below
    slice_axis = slice_axis % n_axes

    axis_values = axis.to_dask()

    # both bounds are inclusive: the first and last axis values are valid
    # grid points and must be accepted
    if value < np.min(axis_values) or value > np.max(axis_values):
        raise ValueError(f"out of range value for axis '{constant}'")

    # get index of nearest neighbour
    slice_idx = np.abs(axis_values - value).argmin(axis=-1)

    # keep every dimension, except pick a single index along the sliced one
    data_slice = [slice(None)] * n_axes
    data_slice[slice_axis] = slice_idx

    return GridArray.from_array(
        grid.to_dask()[tuple(data_slice)],
        name=grid.name,
        label=grid.label,
        unit=grid.unit,
        # drop the sliced axis by position rather than identity, so an Axis
        # object shared by several dimensions is only removed once
        axes=[ax for idx, ax in enumerate(grid.axes) if idx != slice_axis],
        time=grid.time,
    )
def transpose_grid_array(
    grid: GridArray,
    axes: Optional[list] = None,
) -> GridArray:
    """Reverses or permutes the axes of a `GridArray`.

    Parameters
    ----------
    axes: ``list``, optional
        List of integers and/or strings that identify the permutation of the
        axes. The i'th axis of the returned `GridArray` will correspond to the
        axis numbered/labeled axes[i] of the input. If not specified, the
        order of the axes is reversed.

    Returns
    -------
    :class:`nata.containers.GridArray`:
        Transpose of ``grid``.

    Raises
    ------
    ValueError
        If the resolved axes do not consist of exactly ``grid.ndim``
        distinct entries.

    Examples
    --------
    Transpose a three-dimensional array.

    >>> from nata.containers import GridArray
    >>> import numpy as np
    >>> data = np.arange(96).reshape((8, 4, 3))
    >>> grid = GridArray.from_array(data)
    >>> grid.transpose().shape
    (3, 4, 8)
    >>> grid.transpose(axes=[0, 2, 1]).shape
    (8, 3, 4)

    """
    # resolve the axis names/indices given by the caller into positions
    tr_axes = get_transpose_axes(grid, axes)

    # every dimension must appear exactly once. Integers are compared by
    # value: ``is not`` only worked by accident of CPython's small-int cache.
    # The length check also catches duplicated entries such as [0, 1, 2, 2].
    if len(tr_axes) != grid.ndim or len(set(tr_axes)) != grid.ndim:
        raise ValueError("invalid transpose axes")

    return GridArray.from_array(
        da.transpose(grid.to_dask(), axes=tr_axes),
        name=grid.name,
        label=grid.label,
        unit=grid.unit,
        axes=[grid.axes[axis] for axis in tr_axes],
        time=grid.time,
    )
def fft_grid_array(
    grid: GridArray,
    axes: Optional[list] = None,
    comp: str = "abs",
) -> GridArray:
    """Computes the Fast Fourier Transform (FFT) of a
    :class:`nata.containers.GridArray` using `numpy's fft module`__.

    .. _fft: https://numpy.org/doc/stable/reference/routines.fft.html
    __ fft_

    The axes over which the FFT is computed are transformed such that the
    zero frequency bins are centered. The transformed axes hold angular
    wavenumbers, ``k = 2 * pi * f``, where the sample spacing is taken from
    the extent of the original axis values.

    Parameters
    ----------
    axes: ``list``, optional
        List of integers and/or strings that identify the axes over which to
        compute the FFT.

    comp: ``{'abs', 'real', 'imag', 'full'}``, optional
        Defines the component of the FFT selected for output. Available values
        are ``'abs'`` (default), ``'real'``, ``'imag'`` and ``'full'``, which
        correspond to the absolute value, real component, imaginary component
        and full (complex) result of the FFT, respectively.

    Returns
    -------
    :class:`nata.containers.GridArray`:
        Selected FFT component of ``grid``.

    Raises
    ------
    ValueError
        If ``grid`` is 0-dimensional or ``comp`` is not one of the values
        listed above.

    Examples
    --------
    Obtain the FFT of a one-dimensional array.

    >>> from nata.containers import GridArray
    >>> import numpy as np
    >>> x = np.arange(100)
    >>> grid = GridArray.from_array(np.exp(-(x-len(x)/2)**2))
    >>> fft_grid = grid.fft()

    # TODO: add an image here with a plot of the FFT

    Compute the FFT over the first axis of a two-dimensional array.

    >>> from nata.containers import GridArray
    >>> import numpy as np
    >>> x = np.linspace(0, 10*np.pi)
    >>> y = np.linspace(0, 10*np.pi)
    >>> X, Y = np.meshgrid(x, y, indexing="ij")
    >>> grid = GridArray.from_array(np.sin(X) + np.sin(2*Y))
    >>> fft_grid = grid.fft(axes=[0])

    # TODO: add an image here with a plot of the FFT

    """
    if grid.ndim < 1:
        raise ValueError("fft is not available for 0 dimensional GridArrays")

    # build fft axes
    fft_axes = get_fft_axes(grid, axes)

    # component selector and the matching label for each supported `comp`;
    # looked up up front so an invalid value fails before the FFT is built
    components = {
        "abs": (np.abs, f"|{grid.label}|"),
        "real": (np.real, f"\\Re({grid.label})"),
        "imag": (np.imag, f"\\Im({grid.label})"),
        "full": (lambda data: data, f"{grid.label}"),
    }
    if comp not in components:
        raise ValueError(f"invalid fft component '{comp}'")
    select, label = components[comp]

    # build new axes
    new_axes = []
    for idx, axis in enumerate(grid.axes):
        if idx not in fft_axes:
            # axis is not fft axis, stays the same
            new_axes.append(axis)
            continue

        # axis is fft axis, determine its fourier counterpart
        n_points = axis.shape[-1]
        axis_values = axis.to_dask()
        # N samples spanning [min, max] are (max - min) / (N - 1) apart; the
        # max() guard keeps a single-point axis from dividing by zero
        spacing = (np.max(axis_values) - np.min(axis_values)) / max(n_points - 1, 1)
        # dividing the spacing by 2 pi turns fftfreq's cyclic frequencies
        # into angular wavenumbers
        delta = spacing / (2.0 * np.pi)
        axis_data = fft.fftshift(fft.fftfreq(n_points, delta))
        new_axes.append(
            Axis(
                axis_data,
                name=f"k_{axis.name}",
                label=f"k_{{{axis.label}}}",
                unit=f"({axis.unit})^{{-1}}" if axis.unit else "",
            )
        )

    # do the data fft, centering the zero-frequency bins
    fft_data = fft.fftn(grid.to_dask(), axes=fft_axes)
    fft_data = fft.fftshift(fft_data, axes=fft_axes)

    # build return grid with only the selected component
    return GridArray.from_array(
        select(fft_data),
        name=grid.name,
        label=label,
        unit=grid.unit,
        axes=new_axes,
        time=grid.time,
    )