示例#1
0
def test_Osiris_Hdf5_GridFile_check_is_valid_backend(os_hdf5_grid_444_file):
    """Only the 'Osiris_Hdf5_GridFile' backend may accept the fixture file."""
    target = Osiris_Hdf5_GridFile
    assert target.is_valid_backend(os_hdf5_grid_444_file) is True

    # GridDataset registers every backend automatically, so each of the other
    # registered backends has to reject the file
    other_backends = [
        candidate
        for key, candidate in GridDataset.get_backends().items()
        if key != target.name
    ]
    for candidate in other_backends:
        assert candidate.is_valid_backend(os_hdf5_grid_444_file) is False
示例#2
0
def test_dataset_slice_dimensionality():
    """Slicing at a fixed grid coordinate removes exactly that axis."""
    values = np.arange(96).reshape((8, 4, 3))
    original = GridDataset.from_array(values)

    result = original.slice(constant="axis0", value=0)

    # the time axis (index 0) and the remaining grid axis (index 2) survive
    assert result.ndim == 2
    for kept, source in ((0, 0), (1, 2)):
        assert result.axes[kept].name == original.axes[source].name

    np.testing.assert_array_equal(result, values[:, 0, :])
示例#3
0
def test_GridDataset_getitem(case):
    """Indexing a GridDataset yields the expected container type and values."""
    source = np.array(case["arr"])
    index = case["indexing"]
    expected_type = case["instance_after_indexing"]

    # only fall back to numpy indexing when no explicit expectation is given
    if "expected_arr" in case:
        expected_values = case["expected_arr"]
    else:
        expected_values = source[index]

    expected_axes = None
    if "expected_axes" in case:
        expected_axes = case["expected_axes"]

    indexed = GridDataset.from_array(source)[index]
    reference = expected_type.from_array(expected_values, axes=expected_axes)

    assert isinstance(indexed, expected_type)
    assert hash(indexed) == hash(reference)
    np.testing.assert_array_equal(indexed, reference)
示例#4
0
def test_dataset_transpose_shape():
    """Transposing grid axes matches numpy's transpose, time staying first."""
    values = np.zeros((7, 6, 5))
    original = GridDataset.from_array(values)

    for grid_order in (None, [1, 2], [2, 1]):
        transposed = original.transpose(axes=grid_order)
        # the time axis is never moved; without an explicit order the grid
        # axes are expected to be reversed
        if grid_order:
            full_order = [0] + grid_order
        else:
            full_order = [0, 2, 1]
        assert transposed.shape == np.transpose(values, axes=full_order).shape
示例#5
0
def test_streak_axes_shape():
    """Streaking keeps the data shape and reduces grid axes to one dimension."""
    t = np.arange(3)
    xs = np.arange(10)

    data = np.tile(xs, (len(t), 1))
    axes = [Axis.from_array(t), Axis.from_array([xs for _ in t])]
    grid = GridDataset.from_array(data, axes=axes)

    streaked = grid.streak()

    assert streaked.shape == grid.shape
    assert streaked.axes[0].shape == t.shape
    assert streaked.axes[1].shape == xs.shape
示例#6
0
def test_GridDataset_from_array_raise_invalid_axes():
    """Inconsistent ``axes`` arguments are rejected with descriptive errors."""
    invalid_cases = [
        # number of axes does not match the dimensionality of the data
        ([], [0, 1], "mismatches with dimensionality of data"),
        # time axis that is not one-dimensional
        ([0, 1], [[[0, 1]]], "time axis has to be 1D"),
        # grid axis with more than two dimensions
        ([[0, 1]], [[0], [[[0, 1]]]], "axis for GridDataset are supported"),
        # grid axis whose shape disagrees with the data
        (
            [[0, 1]],
            [[0], [[0, 1, 2, 3]]],
            "inconsistency between data and axis shape",
        ),
    ]

    for data, axes, message in invalid_cases:
        with pytest.raises(ValueError, match=message):
            GridDataset.from_array(data, axes=axes)
示例#7
0
def test_GridDataset_from_array_check_axes():
    """Explicitly passed axes are wrapped with default metadata."""
    grid = GridDataset.from_array(
        [[0, 1], [1, 2]],
        axes=[[0, 1], [[10, 20], [30, 40]]],
    )

    # position -> (values, name, label, unit, shape)
    expectations = {
        0: ([0, 1], "time", "time", "", (2, )),
        1: ([[10, 20], [30, 40]], "axis0", "unlabeled", "", (2, 2)),
    }

    for position, (values, name, label, unit, shape) in expectations.items():
        axis = grid.axes[position]
        np.testing.assert_array_equal(axis, values)
        assert axis.name == name
        assert axis.label == label
        assert axis.unit == unit
        assert axis.shape == shape
示例#8
0
def test_dataset_fft_peak_1d():
    """The FFT magnitude of sin(k x) peaks where the wavenumber axis is -k."""
    t = np.arange(1, 3)
    x = np.linspace(0, 10 * np.pi, 101)
    wavenumbers = np.arange(len(t)) + 1

    signal = [np.sin(k * x) for k in wavenumbers]
    axes = [Axis.from_array(t), Axis.from_array(np.tile(x, (len(t), 1)))]
    spectrum = GridDataset.from_array(signal, axes=axes).fft()

    for k, spectrum_k in zip(wavenumbers, spectrum):
        peak = spectrum_k.to_dask().argmax()
        k_axis = spectrum_k.axes[0].to_dask()
        assert peak == np.abs(k_axis + k).argmin()
示例#9
0
def streak_grid_array(grid: GridDataset) -> GridArray:
    """Converts a `GridDataset` to a `GridArray`. Only `GridDataset` with axes
    that do not change over time are supported.

    Parameters
    ----------
    grid: :class:`nata.containers.GridDataset`
        Dataset whose first axis is time, followed by at least one grid axis.

    Returns
    -------
    :class:`nata.containers.GridArray`:
        Streak of ``grid``. The time axis is kept and every grid axis is
        replaced by its values at the first time step.

    Raises
    ------
    ValueError
        If ``grid`` has no grid axis (``grid.ndim < 2``) or if any grid axis
        changes over time.

    Examples
    --------
    Convert a one-dimensional dataset with time dependence to a two-dimensional
    array.

    >>> from nata.containers import GridDataset
    >>> import numpy as np
    >>> data = np.arange(5*7).reshape((5, 7))
    >>> grid = GridDataset.from_array(data)
    >>> stk_grid = grid.streak()
    >>> stk_grid.shape
    (5, 7)
    >>> [axis.shape for axis in stk_grid.axes]
    [(5,), (7,)]

    """

    if grid.ndim < 2:
        raise ValueError(
            "streak is not available for 0 dimensional GridDatasets")

    # every time step of a grid axis has to match its first time step,
    # otherwise a single time-independent axis cannot represent it
    for axis in grid.axes[1:]:
        # loop-invariant: build the reference values once per axis
        reference = axis[0].to_dask()
        for axis_t in axis:
            if np.any(axis_t.to_dask() != reference):
                raise ValueError("invalid axes for streak")

    return GridArray.from_array(
        grid.to_dask(),
        name=grid.name,
        label=grid.label,
        unit=grid.unit,
        axes=[grid.time] + [axis[0] for axis in grid.axes[1:]],
    )
示例#10
0
def test_GridDataset_from_path(grid_files: Path):
    """Reading the dummy grid files restores data, metadata and axes."""
    grid = GridDataset.from_path(grid_files / "*")

    expected_metadata = {
        "name": "dummy_grid",
        "label": "dummy grid label",
        "unit": "dummy unit",
    }
    for attribute, expected in expected_metadata.items():
        assert getattr(grid, attribute) == expected

    time_axis = grid.axes[0]
    assert time_axis.name == "time"
    assert time_axis.label == "time"
    assert time_axis.unit == "dummy time unit"
    np.testing.assert_array_equal(time_axis, [1.0, 1.0, 1.0])

    for position, suffix in ((1, "axis0"), (2, "axis1")):
        grid_axis = grid.axes[position]
        assert grid_axis.name == f"dummy_{suffix}"
        assert grid_axis.label == f"dummy label {suffix}"
        assert grid_axis.unit == f"dummy unit {suffix}"

    expected_data = np.tile(np.arange(32).reshape((4, 8)), (3, 1, 1))
    np.testing.assert_array_equal(grid, expected_data)
示例#11
0
def test_GridDataset_from_array_default():
    """``from_array`` fills in default metadata and axes for plain data."""
    source = da.arange(12, dtype=int).reshape((4, 3))
    grid = GridDataset.from_array(source)

    assert grid.shape == (4, 3)
    assert grid.ndim == 2
    assert grid.dtype == int

    time_axis = grid.axes[0]
    assert time_axis.name == "time"
    assert time_axis.label == "time"
    assert time_axis.unit == ""
    assert time_axis.shape == (4, )

    space_axis = grid.axes[1]
    assert space_axis.name == "axis0"
    assert space_axis.label == "unlabeled"
    assert space_axis.unit == ""
    assert space_axis.shape == (4, 3)
    expected_space = np.tile(np.arange(3), (4, 1))
    np.testing.assert_array_equal(space_axis, expected_space)

    # the ``time`` property has to be the very same object as the first axis
    assert grid.time is grid.axes[0]

    assert grid.name == "unnamed"
    assert grid.label == "unlabeled"
    assert grid.unit == ""
示例#12
0
def test_dataset_slice_selection():
    """Slicing at a fixed value of the time axis is rejected."""
    data = np.arange(12).reshape((4, 3))
    grid = GridDataset.from_array(data)
    time_name = grid.time.name

    expected_message = "slice along the time axis is not supported"
    with pytest.raises(ValueError, match=expected_message):
        grid.slice(constant=time_name, value=0)
示例#13
0
def test_streak_shape():
    """Streaking a dataset does not change its shape."""
    data = np.arange(12).reshape((4, 3))
    grid = GridDataset.from_array(data)
    streaked = grid.streak()
    assert streaked.shape == grid.shape
示例#14
0
def fft_grid_dataset(
    dataset: GridDataset, type: Optional[str] = "abs",
) -> GridDataset:
    """Computes the Fast Fourier Transform (FFT) of a single/multiple
    iteration :class:`nata.containers.GridDataset` along all grid axes
    using `numpy's fft module`__.

    .. _fft: https://numpy.org/doc/stable/reference/routines.fft.html
    __ fft_

    Parameters
    ----------
    type: ``{'abs', 'real', 'imag', 'full'}``, optional
        Defines the component of the FFT selected for output. Available
        values are ``'abs'`` (default), ``'real'``, ``'imag'`` and
        ``'full'``, which correspond to the absolute value, real component,
        imaginary component and full (complex) result of the FFT,
        respectively. Any other value is treated like ``'full'``.

    Returns
    -------
    :class:`nata.containers.GridDataset`:
        Selected FFT component along all grid axes of ``dataset``. The
        spectrum is passed through ``fftshift`` so the zero-frequency
        component is centred. Each grid axis ``a`` is replaced by a
        wavenumber axis named ``k_<a.name>`` with unit ``1/(<a.unit>)``;
        the time and iteration axes are taken over from ``dataset``.

    Examples
    --------
    To obtain the FFT of a :class:`nata.containers.GridDataset`, a simple
    call to the ``fft()`` method is enough. In the following example, we
    compute the FFT of a one-dimensional
    :class:`nata.containers.GridDataset`.

    >>> from nata.containers import GridDataset
    >>> import numpy as np
    >>> x = np.arange(100)
    >>> arr = np.exp(-(x-len(x)/2)**2)
    >>> ds = GridDataset(arr[np.newaxis])
    >>> ds_fft = ds.fft()

    """

    fft_data = np.array(dataset)
    # grid axes follow the leading iteration axis
    fft_axes = np.arange(len(dataset.grid_shape)) + 1

    # a single-iteration array has no leading iteration axis; add one so that
    # ``fft_axes`` addresses the grid axes in both cases
    fft_data = fft.fftn(
        fft_data if len(dataset) > 1 else fft_data[np.newaxis], axes=fft_axes
    )
    fft_data = fft.fftshift(fft_data, axes=fft_axes)

    if type == "real":
        fft_data = np.real(fft_data)
        label = f"Re(FFT({dataset.label}))"
        name = f"fftr_{dataset.name}"
    elif type == "imag":
        fft_data = np.imag(fft_data)
        label = f"Im(FFT({dataset.label}))"
        name = f"ffti_{dataset.name}"
    elif type == "abs":
        fft_data = np.abs(fft_data)
        label = f"|FFT({dataset.label})|"
        name = f"ffta_{dataset.name}"
    else:
        # 'full' and any unrecognised value keep the complex result
        label = f"FFT({dataset.label})"
        name = f"fft_{dataset.name}"

    axes = []
    for a in dataset.axes["grid_axes"]:
        freqs = []
        # one frequency array per time step of the axis
        for a_ts in a:
            n_points = len(np.array(a_ts))
            # dividing the spacing by 2*pi makes ``fftfreq`` return angular
            # wavenumbers.
            # NOTE(review): (max - min) / n is the spacing only if the axis
            # values span n cells (not n - 1 intervals) — confirm against the
            # axis convention.
            delta = (np.max(a_ts) - np.min(a_ts)) / n_points / (2.0 * np.pi)
            freqs.append(fft.fftfreq(n_points, delta))
        axis_data = fft.fftshift(freqs, axes=-1)
        axes.append(
            GridAxis(
                axis_data,
                name=f"k_{a.name}",
                label=f"k_{{{a.label}}}",
                unit=f"1/({a.unit})",
            )
        )

    return GridDataset(
        fft_data,
        name=name,
        label=label,
        unit=dataset.unit,
        grid_axes=axes,
        time=dataset.axes["time"],
        iteration=dataset.axes["iteration"],
    )
示例#15
0
def test_dataset_slice_invalid_ndim():
    """Slicing a dataset that has no grid axis is rejected."""
    expected_message = "0 dimensional GridDatasets"
    with pytest.raises(ValueError, match=expected_message):
        time_only = GridDataset.from_array(np.arange(5))
        time_only.slice(constant="axis0", value=1)
示例#16
0
def test_dataset_fft_selection():
    """Requesting an FFT along the time axis is rejected."""
    data = np.arange(12).reshape((4, 3))
    grid = GridDataset.from_array(data)
    time_axis_only = [0]

    expected_message = "fft along the time axis is not supported"
    with pytest.raises(ValueError, match=expected_message):
        grid.fft(axes=time_axis_only)
示例#17
0
def test_GridDataset_raise_invalid_new_axis():
    """Indexing with ``np.newaxis`` to insert an axis is not supported."""
    values = np.arange(3 * 4 * 5).reshape((3, 4, 5))
    grid = GridDataset.from_array(values)
    with pytest.raises(IndexError):
        _ = grid[np.newaxis]
示例#18
0
def test_GridDataset_from_array_check_name():
    """A custom ``name`` passed to ``from_array`` is stored unchanged."""
    custom = "custom_name"
    grid = GridDataset.from_array([], name=custom)
    assert grid.name == custom
示例#19
0
def test_GridDatasets_backends_are_registered():
    """Every Osiris grid backend is present in the GridDataset registry."""
    registry = GridDataset.get_backends()
    expected_backends = (
        Osiris_Hdf5_GridFile,
        Osiris_Dev_Hdf5_GridFile,
        Osiris_zdf_GridFile,
    )
    for backend in expected_backends:
        assert registry[backend.name] is backend
示例#20
0
def test_GridDataset_from_array_raise_invalid_name():
    """A ``name`` that is not a valid identifier is rejected."""
    bad_name = "invalid name"  # contains a space
    expected_message = "'name' has to be a valid identifier"
    with pytest.raises(ValueError, match=expected_message):
        GridDataset.from_array([], name=bad_name)
示例#21
0
文件: lineout.py 项目: tsung1029/nata
def lineout_grid_dataset(
    dataset: GridDataset,
    fixed: Union[str, int],
    value: float,
) -> GridDataset:
    """Takes a lineout across a two-dimensional, single/multiple iteration
    :class:`nata.containers.GridDataset`.

    Parameters
    ----------
    fixed: :class:`str` or :class:`int`
        Selection of the axis along which the taken lineout is constant.

        * if it is a string, then it must match the ``name`` property of an
          existing grid axis in ``dataset``.
        * if it is an integer, then it must be the index of a grid axis in
          ``dataset``, i.e. ``0`` or ``1`` (negative indices ``-2`` and
          ``-1`` are accepted as well).

    value: scalar
        Value between the minimum and maximum of the axis selected through
        ``fixed`` over which the lineout is taken. The grid point closest to
        ``value`` is used.

    Returns
    -------
    :class:`nata.containers.GridDataset`:
        One-dimensional :class:`nata.containers.GridDataset`.

    Raises
    ------
    ValueError
        If ``dataset`` is not two-dimensional, if no grid axis is named
        ``fixed``, or if ``value`` lies outside the range of the selected
        axis.
    IndexError
        If ``fixed`` is an integer outside ``[-2, 1]``.

    Examples
    --------
    The following example shows how to obtain a lineout from a
    two-dimensional :class:`nata.containers.GridDataset`. Since no axes are
    attributed to the dataset in this example, they are automatically
    generated with no names, and ``fixed`` must be an integer.

    >>> from nata.containers import GridDataset
    >>> import numpy as np
    >>> arr = np.arange(25).reshape((5,5))
    >>> ds = GridDataset(arr[np.newaxis])
    >>> lo = ds.lineout(fixed=0, value=2)
    >>> lo.data
    array([10, 11, 12, 13, 14])

    """
    n_grid_axes = len(dataset.grid_shape)
    if n_grid_axes != 2:
        raise ValueError(
            "Grid lineouts are only supported for two-dimensional grid datasets"
        )

    # get handle for grid axes
    axes = dataset.axes["grid_axes"]

    if isinstance(fixed, str):
        # resolve the axis index from the axis name
        ax_idx = next(
            (key for key, ax in enumerate(axes) if ax.name == fixed), None
        )
        if ax_idx is None:
            raise ValueError(
                f"Axis `{fixed}` could not be found in dataset `{dataset}`")
    else:
        if not -n_grid_axes <= fixed < n_grid_axes:
            raise IndexError(
                f"Axis index `{fixed}` is out of range for a two-dimensional "
                "grid dataset"
            )
        # normalize negative indices (-2 -> 0, -1 -> 1); previously e.g.
        # ``fixed=-1`` selected a valid axis but matched neither lineout
        # branch below, leaving the lineout data undefined
        ax_idx = fixed % n_grid_axes

    # build axis values
    axis = axes[ax_idx]

    if value < np.min(axis) or value > np.max(axis):
        raise ValueError(f"Out of range value for fixed `{fixed}`")

    values = np.array(axis)
    # ``argmin`` returns a flat index; map it back to a position along the
    # axis so it stays valid should ``values`` be multi-dimensional (e.g. one
    # row per iteration — unverified). For 1D ``values`` this equals argmin.
    idx = np.unravel_index(np.abs(values - value).argmin(), values.shape)[-1]

    data = np.array(dataset)

    # get lineout; multi-iteration data carries a leading iteration axis
    if ax_idx == 0:
        lo_data = data[:, idx, :] if len(dataset) > 1 else data[idx, :]
        lo_axis = axes[1]
    else:
        lo_data = data[:, :, idx] if len(dataset) > 1 else data[:, idx]
        lo_axis = axes[0]

    return GridDataset(
        lo_data if len(dataset) > 1 else lo_data[np.newaxis],
        name=f"{dataset.name}_lineout",
        label=dataset.label,
        unit=dataset.unit,
        grid_axes=[lo_axis],
        time=dataset.axes["time"],
        iteration=dataset.axes["iteration"],
    )
示例#22
0
def test_GridDataset_from_array_check_unit():
    """A custom ``unit`` passed to ``from_array`` is stored unchanged."""
    custom = "custom unit"
    grid = GridDataset.from_array([], unit=custom)
    assert grid.unit == custom
示例#23
0
def test_GridDataset_from_array_check_label():
    """A custom ``label`` passed to ``from_array`` is stored unchanged."""
    custom = "custom label"
    grid = GridDataset.from_array([], label=custom)
    assert grid.label == custom
示例#24
0
def test_streak_invalid_ndim():
    """Streaking a dataset that has no grid axis is rejected."""
    expected_message = "0 dimensional GridDatasets"
    with pytest.raises(ValueError, match=expected_message):
        time_only = GridDataset.from_array(np.arange(5))
        time_only.streak()
示例#25
0
def test_streak_type():
    """Streaking a GridDataset produces a GridArray."""
    data = np.arange(12).reshape((4, 3))
    streaked = GridDataset.from_array(data).streak()
    assert isinstance(streaked, GridArray)
示例#26
0
def plot_grid_dataset(
    dataset: GridDataset,
    fig: Optional[Figure] = None,
    axes: Optional[Axes] = None,
    style: Optional[dict] = None,
    interactive: Optional[bool] = True,
    n: Optional[int] = 0,
) -> Union[Figure, None]:
    """Plots a single/multiple iteration :class:`nata.containers.GridDataset`
    using a :class:`nata.plots.types.LinePlot` or
    :class:`nata.plots.types.ColorPlot` if the dataset is one- or
    two-dimensional, respectively.

    Parameters
    ----------
    fig: :class:`nata.plots.Figure`, optional
        If provided, the plot is drawn on ``fig``. The plot is drawn on
        ``axes`` if it is a child axes of ``fig``, otherwise a new axes
        is created on ``fig``. If ``fig`` is not provided, a new
        :class:`nata.plots.Figure` is created.

    axes: :class:`nata.plots.Axes`, optional
        If provided, the plot is drawn on ``axes``, which must be an axes
        of ``fig``. If ``axes`` is not provided or is provided without a
        corresponding ``fig``, a new :class:`nata.plots.Axes` is created in
        a new :class:`nata.plots.Figure`.

    style: ``dict``, optional
        Dictionary that takes a mix of style properties of
        :class:`nata.plots.Figure`, :class:`nata.plots.Axes` and any plot
        type (see :class:`nata.plots.types.LinePlot` or
        :class:`nata.plots.types.ColorPlot`). Defaults to no style
        properties.

    interactive: ``bool``, optional
        Controls whether interactive widgets should be shown with the plot
        to allow for temporal navigation. Only applicable if ``dataset``
        has multiple iterations and the call is made inside a notebook.

    n: ``int``, optional
        Selects the index of the iteration to be shown initially. Only
        applicable if ``dataset`` has multiple iterations.

    Returns
    -------
    :class:`nata.plots.Figure` or ``None``:
        Figure with plot built based on ``dataset``. ``None`` is returned
        only when ``dataset`` has multiple iterations, the call is made
        inside a notebook and ``interactive`` is true; in that case the
        figure is shown together with interactive widgets. Outside a
        notebook, only iteration ``n`` of a multi-iteration dataset is
        plotted and a warning is issued.

    Examples
    --------
    To get a plot with default style properties in a new figure, simply
    call the ``.plot()`` method of the dataset.

    >>> from nata.containers import GridDataset
    >>> import numpy as np
    >>> arr = np.arange(10)
    >>> ds = GridDataset.from_array(arr)
    >>> fig = ds.plot()

    In case a :class:`nata.plots.Figure` is returned by the method, it can
    be shown by calling the :func:`nata.plots.Figure.show` method.

    >>> fig.show()

    To draw a new plot on ``fig``, we can pass it as an argument to the
    ``.plot()`` method. If ``axes`` is provided, the new plot is drawn on
    the selected axes.

    >>> ds2 = GridDataset.from_array(arr**2)
    >>> fig = ds2.plot(fig=fig, axes=fig.axes[0])

    """
    # a fresh dict per call instead of a shared mutable default argument
    if style is None:
        style = {}

    p_plan = PlotPlan(dataset=dataset,
                      style=filter_style(dataset.plot_type(), style))

    a_plan = AxesPlan(axes=axes,
                      plots=[p_plan],
                      style=filter_style(Axes, style))

    f_plan = FigurePlan(fig=fig,
                        axes=[a_plan],
                        style=filter_style(Figure, style))

    if len(dataset) <= 1:
        return f_plan.build()

    if not inside_notebook():
        # TODO: remove last line from warn
        warn(f"Plotting only iteration with index n={str(n)}." +
             " Interactive plots of multiple iteration datasets are not" +
             " supported outside notebook environments.")
        return f_plan[n].build()

    if interactive:
        # widgets display the figure themselves; nothing is returned
        f_plan.build_interactive(n)
        return None

    return f_plan[n].build()