def test_Osiris_Hdf5_GridFile_check_is_valid_backend(os_hdf5_grid_444_file):
    """Check that only 'Osiris_Hdf5_GridFile' accepts the given grid file."""
    assert Osiris_Hdf5_GridFile.is_valid_backend(os_hdf5_grid_444_file) is True

    # Backends are registered automatically for GridDatasets, so every other
    # registered backend has to reject the very same file.
    registry = GridDataset.get_backends()
    for key, candidate in registry.items():
        if key != Osiris_Hdf5_GridFile.name:
            assert candidate.is_valid_backend(os_hdf5_grid_444_file) is False
def test_dataset_slice_dimensionality():
    """Slicing at a constant 'axis0' value removes that axis from the grid."""
    raw = np.arange(96).reshape((8, 4, 3))
    dataset = GridDataset.from_array(raw)

    result = dataset.slice(constant="axis0", value=0)

    assert result.ndim == 2
    # The surviving axes are the first and the last axis of the source grid.
    assert result.axes[0].name == dataset.axes[0].name
    assert result.axes[1].name == dataset.axes[2].name
    np.testing.assert_array_equal(result, raw[:, 0, :])
def test_GridDataset_getitem(case):
    """Indexing a GridDataset yields the expected container type and content.

    ``case`` is a mapping with the keys ``arr``, ``indexing`` and
    ``instance_after_indexing``; ``expected_arr`` and ``expected_axes`` are
    optional overrides.
    """
    source = np.array(case["arr"])
    index = case["indexing"]
    expected_type = case["instance_after_indexing"]

    # Without an explicit expectation, plain numpy indexing is the reference.
    if "expected_arr" in case:
        expected_data = case["expected_arr"]
    else:
        expected_data = source[index]

    if "expected_axes" in case:
        expected_axes = case["expected_axes"]
    else:
        expected_axes = None

    dataset = GridDataset.from_array(source)
    selection = dataset[index]
    reference = expected_type.from_array(expected_data, axes=expected_axes)

    assert isinstance(selection, expected_type)
    assert hash(selection) == hash(reference)
    np.testing.assert_array_equal(selection, reference)
def test_dataset_transpose_shape():
    """Transposing keeps the leading axis fixed and permutes the others."""
    raw = np.zeros((7, 6, 5))
    dataset = GridDataset.from_array(raw)

    for requested in (None, [1, 2], [2, 1]):
        transposed = dataset.transpose(axes=requested)

        # Without an explicit order the two trailing axes are swapped.
        if requested:
            full_order = [0] + requested
        else:
            full_order = [0, 2, 1]

        assert transposed.shape == np.transpose(raw, axes=full_order).shape
def test_streak_axes_shape():
    """A streak keeps the data shape and reduces each axis to one dimension."""
    time = np.arange(3)
    x = np.arange(10)
    n_steps = len(time)

    time_axis = Axis.from_array(time)
    # The same spatial axis is repeated for every time step.
    space_axis = Axis.from_array([x] * n_steps)
    dataset = GridDataset.from_array(
        np.tile(x, (n_steps, 1)),
        axes=[time_axis, space_axis],
    )

    streak = dataset.streak()

    assert streak.shape == dataset.shape
    assert streak.axes[0].shape == time.shape
    assert streak.axes[1].shape == x.shape
def test_GridDataset_from_array_raise_invalid_axes():
    """`from_array` raises a ValueError for axes inconsistent with the data."""
    invalid_cases = (
        # invalid number of axes
        ("mismatches with dimensionality of data", [], [0, 1]),
        # axes which are not 1D dimensional
        ("time axis has to be 1D", [0, 1], [[[0, 1]]]),
        # only 2D axes for GridDataset are supported
        ("axis for GridDataset are supported", [[0, 1]], [[0], [[[0, 1]]]]),
        # axis mismatch with shape of data
        (
            "inconsistency between data and axis shape",
            [[0, 1]],
            [[0], [[0, 1, 2, 3]]],
        ),
    )

    for message, data, axes in invalid_cases:
        with pytest.raises(ValueError, match=message):
            GridDataset.from_array(data, axes=axes)
def test_GridDataset_from_array_check_axes():
    """Axes passed as plain nested lists get default names, labels and units."""
    time_values = [0, 1]
    space_values = [[10, 20], [30, 40]]
    dataset = GridDataset.from_array(
        [[0, 1], [1, 2]],
        axes=[time_values, space_values],
    )

    # first axis: time
    np.testing.assert_array_equal(dataset.axes[0], [0, 1])
    assert dataset.axes[0].name == "time"
    assert dataset.axes[0].label == "time"
    assert dataset.axes[0].unit == ""
    assert dataset.axes[0].shape == (2,)

    # second axis: one row of grid points per time step
    np.testing.assert_array_equal(dataset.axes[1], [[10, 20], [30, 40]])
    assert dataset.axes[1].name == "axis0"
    assert dataset.axes[1].label == "unlabeled"
    assert dataset.axes[1].unit == ""
    assert dataset.axes[1].shape == (2, 2)
def test_dataset_fft_peak_1d():
    """The FFT of ``sin(k * x)`` peaks where the k-axis is closest to ``-k``."""
    time = np.arange(1, 3)
    x = np.linspace(0, 10 * np.pi, 101)
    n_steps = len(time)
    k_modes = np.arange(n_steps) + 1

    # One sine wave per time step, each with its own wave number.
    signals = [np.sin(k_i * x) for k_i in k_modes]
    dataset = GridDataset.from_array(
        signals,
        axes=[
            Axis.from_array(time),
            Axis.from_array(np.tile(x, (n_steps, 1))),
        ],
    )

    spectrum = dataset.fft()

    for k_i, spectrum_i in zip(k_modes, spectrum):
        peak_index = spectrum_i.to_dask().argmax()
        k_axis = spectrum_i.axes[0].to_dask()
        expected_index = (np.abs(k_axis + k_i)).argmin()
        assert peak_index == expected_index
def streak_grid_array(
    grid: GridDataset,
) -> GridArray:
    """Converts a `GridDataset` to a `GridArray`.

    Only `GridDataset` with axes that do not change over time are supported.

    Parameters
    ----------
    grid: :class:`nata.containers.GridDataset`
        Dataset to convert. Its first axis is used as the time axis of the
        result.

    Returns
    -------
    :class:`nata.containers.GridArray`:
        Streak of ``grid``.

    Raises
    ------
    ValueError
        If ``grid.ndim`` is smaller than two, or if any non-time axis differs
        between time steps.

    Examples
    --------
    Convert a one-dimensional dataset with time dependence to a
    two-dimensional array.

    >>> from nata.containers import GridDataset
    >>> import numpy as np
    >>> data = np.arange(5*7).reshape((5, 7))
    >>> grid = GridDataset.from_array(data)
    >>> stk_grid = grid.streak()
    >>> stk_grid.shape
    (5, 7)
    >>> [axis.shape for axis in stk_grid.axes]
    [(5,), (7,)]

    """
    if grid.ndim < 2:
        raise ValueError(
            "streak is not available for 0 dimensional GridDatasets")

    for axis in grid.axes[1:]:
        # The first time step is the reference; build it once per axis instead
        # of once per comparison.
        reference = axis[0].to_dask()
        for axis_i in axis:
            if np.any(axis_i.to_dask() != reference):
                raise ValueError("invalid axes for streak")

    # Every non-time axis is identical across time steps at this point, so its
    # first entry represents the whole axis.
    return GridArray.from_array(
        grid.to_dask(),
        name=grid.name,
        label=grid.label,
        unit=grid.unit,
        axes=[grid.time] + [axis[0] for axis in grid.axes[1:]],
    )
def test_GridDataset_from_path(grid_files: Path):
    """Loading all files of ``grid_files`` yields one multi-step dataset."""
    dataset = GridDataset.from_path(grid_files / "*")

    # dataset attributes
    assert dataset.name == "dummy_grid"
    assert dataset.label == "dummy grid label"
    assert dataset.unit == "dummy unit"

    # time axis
    assert dataset.axes[0].name == "time"
    assert dataset.axes[0].label == "time"
    assert dataset.axes[0].unit == "dummy time unit"
    np.testing.assert_array_equal(dataset.axes[0], [1.0, 1.0, 1.0])

    # grid axes
    assert dataset.axes[1].name == "dummy_axis0"
    assert dataset.axes[2].name == "dummy_axis1"
    assert dataset.axes[1].label == "dummy label axis0"
    assert dataset.axes[2].label == "dummy label axis1"
    assert dataset.axes[1].unit == "dummy unit axis0"
    assert dataset.axes[2].unit == "dummy unit axis1"

    # data: the same (4, 8) block repeated for each of the three time steps
    single_step = np.arange(32).reshape((4, 8))
    expected = np.tile(single_step, (3, 1, 1))
    np.testing.assert_array_equal(dataset, expected)
def test_GridDataset_from_array_default():
    """`from_array` fills in default axes, name, label and unit."""
    source = da.arange(12, dtype=int).reshape((4, 3))
    dataset = GridDataset.from_array(source)

    # array properties
    assert dataset.shape == (4, 3)
    assert dataset.ndim == 2
    assert dataset.dtype == int

    # generated time axis
    assert dataset.axes[0].name == "time"
    assert dataset.axes[0].label == "time"
    assert dataset.axes[0].unit == ""
    assert dataset.axes[0].shape == (4,)

    # generated grid axis: indices 0..2 repeated for each of the four steps
    assert dataset.axes[1].name == "axis0"
    assert dataset.axes[1].label == "unlabeled"
    assert dataset.axes[1].unit == ""
    assert dataset.axes[1].shape == (4, 3)
    expected_axis = np.tile(np.arange(3), (4, 1))
    np.testing.assert_array_equal(dataset.axes[1], expected_axis)

    assert dataset.time is dataset.axes[0]

    # dataset attributes
    assert dataset.name == "unnamed"
    assert dataset.label == "unlabeled"
    assert dataset.unit == ""
def test_dataset_slice_selection():
    """Slicing along the time axis is rejected with a ValueError."""
    dataset = GridDataset.from_array(np.arange(12).reshape((4, 3)))
    expected_message = "slice along the time axis is not supported"

    with pytest.raises(ValueError, match=expected_message):
        dataset.slice(constant=dataset.time.name, value=0)
def test_streak_shape():
    """A streak has the same shape as the dataset it was built from."""
    dataset = GridDataset.from_array(np.arange(12).reshape((4, 3)))
    streak = dataset.streak()
    assert streak.shape == dataset.shape
def fft_grid_dataset(
    dataset: GridDataset,
    type: Optional[str] = "abs",
) -> GridDataset:
    """Computes the Fast Fourier Transform (FFT) of a single/multiple
    iteration :class:`nata.containers.GridDataset` along all grid axes
    using `numpy's fft module`__.

    .. _fft: https://numpy.org/doc/stable/reference/routines.fft.html
    __ fft_

    Parameters
    ----------
    dataset: :class:`nata.containers.GridDataset`
        Dataset to transform. The transform is taken over every grid axis,
        never over the leading iteration axis.

    type: ``{'abs', 'real', 'imag', 'full'}``, optional
        Defines the component of the FFT selected for output. Available
        values are ``'abs'`` (default), ``'real'``, ``'imag'`` and
        ``'full'``, which correspond to the absolute value, real component,
        imaginary component and full (complex) result of the FFT,
        respectively. Any other value is treated like ``'full'``.

    Returns
    -------
    :class:`nata.containers.GridDataset`:
        Selected FFT component along all grid axes of ``dataset``.

    Examples
    --------
    To obtain the FFT of a :class:`nata.containers.GridDataset`, a simple
    call to the ``fft()`` method is enough. In the following example, we
    compute the FFT of a one-dimensional
    :class:`nata.containers.GridDataset`.

    >>> from nata.containers import GridDataset
    >>> import numpy as np
    >>> x = np.arange(100)
    >>> arr = np.exp(-(x-len(x)/2)**2)
    >>> ds = GridDataset(arr[np.newaxis])
    >>> ds_fft = ds.fft()

    """
    # Axis 0 is skipped by the transform, hence the offset of one.
    fft_axes = np.arange(len(dataset.grid_shape)) + 1

    fft_data = np.array(dataset)
    if not len(dataset) > 1:
        # Add a leading axis so that `fft_axes` lines up for a single
        # iteration as well.
        fft_data = fft_data[np.newaxis]

    fft_data = fft.fftn(fft_data, axes=fft_axes)
    fft_data = fft.fftshift(fft_data, axes=fft_axes)

    if type == "real":
        fft_data = np.real(fft_data)
        label = f"Re(FFT({dataset.label}))"
        name = f"fftr_{dataset.name}"
    elif type == "imag":
        fft_data = np.imag(fft_data)
        label = f"Im(FFT({dataset.label}))"
        name = f"ffti_{dataset.name}"
    elif type == "abs":
        fft_data = np.abs(fft_data)
        label = f"|FFT({dataset.label})|"
        name = f"ffta_{dataset.name}"
    else:
        # 'full' and every unrecognized value keep the complex result.
        label = f"FFT({dataset.label})"
        name = f"fft_{dataset.name}"

    axes = []
    for a in dataset.axes["grid_axes"]:
        frequencies = []
        for a_ts in a:
            samples = np.array(a_ts)
            n_samples = len(samples)
            # Sample spacing divided by 2*pi, as passed to `fftfreq`.
            delta = (
                (np.max(samples) - np.min(samples)) / n_samples / (2.0 * np.pi)
            )
            frequencies.append(fft.fftfreq(n_samples, delta))

        # Shift along the last axis only, i.e. within each time step.
        axis_data = fft.fftshift(frequencies, axes=-1)

        axes.append(
            GridAxis(
                axis_data,
                name=f"k_{a.name}",
                label=f"k_{{{a.label}}}",
                unit=f"1/({a.unit})",
            )
        )

    return GridDataset(
        fft_data,
        name=name,
        label=label,
        unit=dataset.unit,
        grid_axes=axes,
        time=dataset.axes["time"],
        iteration=dataset.axes["iteration"],
    )
def test_dataset_slice_invalid_ndim():
    """Slicing a dataset built from a 1D array raises a ValueError."""
    dataset = GridDataset.from_array(np.arange(5))

    with pytest.raises(ValueError, match="0 dimensional GridDatasets"):
        dataset.slice(constant="axis0", value=1)
def test_dataset_fft_selection():
    """Requesting an FFT along the time axis is rejected with a ValueError."""
    dataset = GridDataset.from_array(np.arange(12).reshape((4, 3)))
    expected_message = "fft along the time axis is not supported"

    with pytest.raises(ValueError, match=expected_message):
        dataset.fft(axes=[0])
def test_GridDataset_raise_invalid_new_axis():
    """Indexing a GridDataset with ``np.newaxis`` raises an IndexError."""
    shape = (3, 4, 5)
    dataset = GridDataset.from_array(np.arange(3 * 4 * 5).reshape(shape))

    with pytest.raises(IndexError):
        dataset[np.newaxis]
def test_GridDataset_from_array_check_name():
    """A name passed to `from_array` is exposed through ``.name``."""
    custom = "custom_name"
    dataset = GridDataset.from_array([], name=custom)
    assert dataset.name == "custom_name"
def test_GridDatasets_backends_are_registered():
    """Each grid backend is registered in GridDataset under its own name."""
    registry = GridDataset.get_backends()

    expected_backends = (
        Osiris_Hdf5_GridFile,
        Osiris_Dev_Hdf5_GridFile,
        Osiris_zdf_GridFile,
    )
    for backend_cls in expected_backends:
        assert registry[backend_cls.name] is backend_cls
def test_GridDataset_from_array_raise_invalid_name():
    """A name that is not a valid identifier is rejected by `from_array`."""
    expected_message = "'name' has to be a valid identifier"

    with pytest.raises(ValueError, match=expected_message):
        GridDataset.from_array([], name="invalid name")
def lineout_grid_dataset(
    dataset: GridDataset,
    fixed: Union[str, int],
    value: float,
) -> GridDataset:
    """Takes a lineout across a two-dimensional, single/multiple iteration
    :class:`nata.containers.GridDataset`.

    Parameters
    ----------
    dataset: :class:`nata.containers.GridDataset`
        Dataset with exactly two grid axes from which the lineout is taken.

    fixed: ``str`` or ``int``
        Selection of the axes along which the taken lineout is constant.

        * if it is a string, then it must match the ``name`` property of an
          existing grid axis in ``dataset``.
        * if it is an integer, then it must match the index of a grid axis
          in ``dataset`` (i.e. `0` or `1`).

    value: scalar
        Value between the minimum and maximum of the axes selected through
        ``fixed`` over which the lineout is taken.

    Returns
    -------
    :class:`nata.containers.GridDataset`:
        One-dimensional :class:`nata.containers.GridDataset`.

    Raises
    ------
    ValueError
        If ``dataset`` does not have two grid axes, if ``fixed`` does not
        select one of them, or if ``value`` lies outside the range of the
        selected axis.

    Examples
    --------
    The following example shows how to obtain a lineout from a
    two-dimensional :class:`nata.containers.GridDataset`. Since no axes are
    attributed to the dataset in this example, they are automatically
    generated with no names, and ``fixed`` must be an integer.

    >>> from nata.containers import GridDataset
    >>> import numpy as np
    >>> arr = np.arange(25).reshape((5,5))
    >>> ds = GridDataset(arr[np.newaxis])
    >>> lo = ds.lineout(fixed=0, value=2)
    >>> lo.data
    array([10, 11, 12, 13, 14])

    """
    if len(dataset.grid_shape) != 2:
        raise ValueError(
            "Grid lineouts are only supported for two-dimensional grid datasets"
        )

    # get handle for grid axes
    axes = dataset.axes["grid_axes"]

    if isinstance(fixed, str):
        # resolve the axis name to its position among the grid axes
        ax_idx = next(
            (key for key, ax in enumerate(axes) if ax.name == fixed),
            None,
        )
        if ax_idx is None:
            raise ValueError(
                f"Axis `{fixed}` could not be found in dataset `{dataset}`")
    else:
        ax_idx = fixed
        # Only 0 and 1 are handled below. Any other integer (negative ones
        # included) would otherwise leave the lineout undefined.
        if ax_idx not in (0, 1):
            raise ValueError(
                f"Axis index `{fixed}` is invalid, it has to be 0 or 1")

    # build axis values
    axis = axes[ax_idx]
    if value < np.min(axis) or value > np.max(axis):
        raise ValueError(f"Out of range value for fixed `{fixed}`")

    values = np.array(axis)
    # NOTE(review): argmin() returns an index into the flattened array. If the
    # axis stores one row per iteration, this presumably relies on the closest
    # value being found in the first row -- TODO confirm.
    idx = (np.abs(values - value)).argmin()

    data = np.array(dataset)
    has_multiple_iterations = len(dataset) > 1

    # get lineout; the remaining grid axis becomes the axis of the result
    if ax_idx == 0:
        lo_data = data[:, idx, :] if has_multiple_iterations else data[idx, :]
        lo_axis = axes[1]
    else:
        lo_data = data[:, :, idx] if has_multiple_iterations else data[:, idx]
        lo_axis = axes[0]

    return GridDataset(
        # a single iteration gets its leading axis back
        lo_data if has_multiple_iterations else lo_data[np.newaxis],
        name=f"{dataset.name}_lineout",
        label=dataset.label,
        unit=dataset.unit,
        grid_axes=[lo_axis],
        time=dataset.axes["time"],
        iteration=dataset.axes["iteration"],
    )
def test_GridDataset_from_array_check_unit():
    """A unit passed to `from_array` is exposed through ``.unit``."""
    custom = "custom unit"
    dataset = GridDataset.from_array([], unit=custom)
    assert dataset.unit == "custom unit"
def test_GridDataset_from_array_check_label():
    """A label passed to `from_array` is exposed through ``.label``."""
    custom = "custom label"
    dataset = GridDataset.from_array([], label=custom)
    assert dataset.label == "custom label"
def test_streak_invalid_ndim():
    """Taking a streak of a dataset built from a 1D array raises ValueError."""
    dataset = GridDataset.from_array(np.arange(5))

    with pytest.raises(ValueError, match="0 dimensional GridDatasets"):
        dataset.streak()
def test_streak_type():
    """The streak of a GridDataset is a GridArray."""
    dataset = GridDataset.from_array(np.arange(12).reshape((4, 3)))
    streak = dataset.streak()
    assert isinstance(streak, GridArray)
def plot_grid_dataset(
    dataset: GridDataset,
    fig: Optional[Figure] = None,
    axes: Optional[Axes] = None,
    style: Optional[dict] = None,
    interactive: Optional[bool] = True,
    n: Optional[int] = 0,
) -> Union[Figure, None]:
    """Plots a single/multiple iteration :class:`nata.containers.GridDataset`
    using a :class:`nata.plots.types.LinePlot` or
    :class:`nata.plots.types.ColorPlot` if the dataset is one- or
    two-dimensional, respectively.

    Parameters
    ----------
    dataset: :class:`nata.containers.GridDataset`
        Dataset to be plotted.

    fig: :class:`nata.plots.Figure`, optional
        If provided, the plot is drawn on ``fig``. The plot is drawn on
        ``axes`` if it is a child axes of ``fig``, otherwise a new axes
        is created on ``fig``. If ``fig`` is not provided, a new
        :class:`nata.plots.Figure` is created.

    axes: :class:`nata.plots.Axes`, optional
        If provided, the plot is drawn on ``axes``, which must be an axes
        of ``fig``. If ``axes`` is not provided or is provided without a
        corresponding ``fig``, a new :class:`nata.plots.Axes` is created in
        a new :class:`nata.plots.Figure`.

    style: ``dict``, optional
        Dictionary that takes a mix of style properties of
        :class:`nata.plots.Figure`, :class:`nata.plots.Axes` and any plot
        type (see :class:`nata.plots.types.LinePlot` or
        :class:`nata.plots.types.ColorPlot`). ``None`` (default) is treated
        as an empty dictionary.

    interactive: ``bool``, optional
        Controls whether interactive widgets should be shown with the plot
        to allow for temporal navigation. Only applicable if ``dataset``
        has multiple iterations.

    n: ``int``, optional
        Selects the index of the iteration to be shown initially. Only
        applicable if ``dataset`` has multiple iterations.

    Returns
    -------
    :class:`nata.plots.Figure` or ``None``:
        Figure with plot built based on ``dataset``. Interactive widgets
        are shown with the figure if ``dataset`` has multiple iterations,
        in which case this method returns ``None``.

    Examples
    --------
    To get a plot with default style properties in a new figure, simply
    call the ``.plot()`` method of the dataset.

    >>> from nata.containers import GridDataset
    >>> import numpy as np
    >>> arr = np.arange(10)
    >>> ds = GridDataset.from_array(arr)
    >>> fig = ds.plot()

    In case a :class:`nata.plots.Figure` is returned by the method, it can
    be shown by calling the :func:`nata.plots.Figure.show` method.

    >>> fig.show()

    To draw a new plot on ``fig``, we can pass it as an argument to the
    ``.plot()`` method. If ``axes`` is provided, the new plot is drawn on
    the selected axes.

    >>> ds2 = GridDataset.from_array(arr**2)
    >>> fig = ds2.plot(fig=fig, axes=fig.axes[0])

    """
    # A fresh dictionary per call avoids sharing one mutable default between
    # calls and makes an explicit `style=None` behave like an omitted one.
    if style is None:
        style = {}

    p_plan = PlotPlan(
        dataset=dataset,
        style=filter_style(dataset.plot_type(), style),
    )
    a_plan = AxesPlan(
        axes=axes,
        plots=[p_plan],
        style=filter_style(Axes, style),
    )
    f_plan = FigurePlan(
        fig=fig,
        axes=[a_plan],
        style=filter_style(Figure, style),
    )

    # single iteration: nothing to navigate, build the figure directly
    if not len(dataset) > 1:
        return f_plan.build()

    if inside_notebook():
        if interactive:
            # the interactive build provides no figure to hand back
            f_plan.build_interactive(n)
            return None
        return f_plan[n].build()

    # TODO: remove last line from warn
    warn(f"Plotting only iteration with index n={n}." +
         " Interactive plots of multiple iteration datasets are not" +
         " supported outside notebook environments.")
    return f_plan[n].build()