Esempio n. 1
0
def plot_norm_sas(res: xr.Dataset,
                  ax: Axis,
                  title: str = "norm SAS",
                  cycler: Cycler | None = PlotStyle().cycler) -> None:
    """Plot normalized SAS (Species Associated Spectra) on ``ax``.

    Every data variable whose name starts with ``species_associated_spectra`` or
    ``species_spectra`` is divided by its maximum absolute value along the
    ``spectral`` dimension and plotted as lines over ``spectral``.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    ax : Axis
        Axis to plot on.
    title : str
        Title of the plot. Defaults to "norm SAS".
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    """
    add_cycler_if_not_none(ax, cycler)
    keys = [
        v for v in res.data_vars
        if v.startswith(("species_associated_spectra", "species_spectra"))
    ]
    for key in keys:
        sas = res[key]
        (sas / np.abs(sas).max(dim="spectral")).plot.line(x="spectral", ax=ax)
        ax.set_title(title)
        # A legend is not guaranteed to exist (e.g. a single 1D spectrum), and
        # ``get_legend`` returns None in that case, so guard before removing it.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
Esempio n. 2
0
def plot_das(res: xr.Dataset,
             ax: Axis,
             title: str = "DAS",
             cycler: Cycler | None = PlotStyle().cycler) -> None:
    """Plot DAS (Decay Associated Spectra) on ``ax``.

    Every data variable whose name starts with ``decay_associated_spectra`` or
    ``species_spectra`` is plotted as lines over the ``spectral`` dimension.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    ax : Axis
        Axis to plot on.
    title : str
        Title of the plot. Defaults to "DAS".
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    """
    add_cycler_if_not_none(ax, cycler)
    keys = [
        v for v in res.data_vars
        if v.startswith(("decay_associated_spectra", "species_spectra"))
    ]
    for key in keys:
        das = res[key]
        das.plot.line(x="spectral", ax=ax)
        ax.set_title(title)
        # ``get_legend`` returns None when no legend was drawn (e.g. a single
        # 1D spectrum), so only remove a legend that actually exists.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
Esempio n. 3
0
def plot_sv_residual(
    res: xr.Dataset,
    ax: Axis,
    indices: Sequence[int] = range(10),
    cycler: Cycler | None = PlotStyle().cycler,
) -> None:
    """Plot singular values of the residual matrix on a log scale.

    The weighted residual singular values are used if present in ``res``,
    otherwise the plain residual singular values.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    ax : Axis
        Axis to plot on.
    indices : Sequence[int]
        Indices of the singular values to plot. Defaults to range(10).
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    """
    add_cycler_if_not_none(ax, cycler)
    if "weighted_residual_singular_values" in res:
        rSV = res.weighted_residual_singular_values
    else:
        rSV = res.residual_singular_values
    # Truncate ``indices`` to the number of available singular values so that
    # a default larger than the matrix rank does not fail the selection
    # (assumes ``indices`` starts at 0 and is contiguous, like the default).
    rSV.sel(singular_value_index=indices[:len(rSV.singular_value_index)]
            ).plot.line("ro-", yscale="log", ax=ax)
    ax.set_title("res. log(SV)")
Esempio n. 4
0
def plot_rsv_residual(
    res: xr.Dataset,
    ax: Axis,
    indices: Sequence[int] = range(2),
    cycler: Cycler | None = PlotStyle().cycler,
    show_legend: bool = True,
) -> None:
    """Plot right singular vectors (spectra) of the residual matrix.

    The weighted residual right singular vectors are used if present in ``res``,
    otherwise the plain residual right singular vectors.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    ax : Axis
        Axis to plot on.
    indices : Sequence[int]
        Indices of the singular vector to plot. Defaults to range(2).
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    show_legend: bool
        Whether or not to show the legend. Defaults to True.
    """
    add_cycler_if_not_none(ax, cycler)
    if "weighted_residual_right_singular_vectors" in res:
        rRSV = res.weighted_residual_right_singular_vectors
    else:
        rRSV = res.residual_right_singular_vectors
    _plot_svd_vetors(rRSV, indices, "right_singular_value_index", ax,
                     show_legend)
    ax.set_title("res. RSV")
def plot_concentrations(
    res: xr.Dataset,
    ax: Axis,
    center_λ: float | None,
    linlog: bool = False,
    linthresh: float = 1,
    linscale: float = 1,
    main_irf_nr: int = 0,
    cycler: Cycler | None = PlotStyle().cycler,
    title: str = "Concentrations",
) -> None:
    """Plot traces on the given axis ``ax``.

    The traces are obtained from ``get_shifted_traces``. If they still depend on
    the ``spectral`` coordinate, only the trace nearest to ``center_λ`` is plotted.

    Parameters
    ----------
    res: xr.Dataset
        Result dataset from a pyglotaran optimization.
    ax: Axis
        Axis to plot the traces on
    center_λ: float | None
        Center wavelength (λ in nm)
    linlog: bool
        Whether to use 'symlog' scale or not. Defaults to False.
    linthresh: float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    linscale: float
        This allows the linear range (-linthresh to linthresh) to be stretched
        relative to the logarithmic range.
        Its value is the number of decades to use for each half of the linear range.
        For example, when linscale == 1.0 (the default), the space used for the
        positive and negative halves of the linear range will be equal to one
        decade in the logarithmic range. Defaults to 1.
    main_irf_nr: int
        Index of the main ``irf`` component when using an ``irf``
        parametrized with multiple peaks. Defaults to 0.
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    title: str
        Title used for the plot axis. Defaults to "Concentrations".

    See Also
    --------
    get_shifted_traces
    """
    add_cycler_if_not_none(ax, cycler)
    traces = get_shifted_traces(res, center_λ, main_irf_nr)

    # Wavelength dependent traces (e.g. with a dispersed irf) are reduced to the
    # trace closest to the center wavelength.
    if "spectral" in traces.coords:
        traces.sel(spectral=center_λ, method="nearest").plot.line(x="time",
                                                                  ax=ax)
    else:
        traces.plot.line(x="time", ax=ax)
    ax.set_title(title)

    if linlog:
        ax.set_xscale("symlog", linthresh=linthresh, linscale=linscale)
Esempio n. 6
0
def plot_spectra(res: xr.Dataset,
                 axes: Axes,
                 cycler: Cycler | None = PlotStyle().cycler) -> None:
    """Plot SAS, DAS and their normalized versions on a 2x2 block of ``axes``.

    Layout: the top row shows SAS (left) and DAS (right), the bottom row the
    normalized SAS (left) and normalized DAS (right).

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    axes : Axes
        Axes to plot the spectra on (needs to be at least 2x2).
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    """
    # (plot function, (row, column)) for each panel, in drawing order
    panels = (
        (plot_sas, (0, 0)),
        (plot_das, (0, 1)),
        (plot_norm_sas, (1, 0)),
        (plot_norm_das, (1, 1)),
    )
    for plot_panel, position in panels:
        plot_panel(res, axes[position], cycler=cycler)
def plot_residual(
    res: xr.Dataset,
    ax: Axis,
    linlog: bool = False,
    linthresh: float = 1,
    show_data: bool = False,
    cycler: Cycler | None = PlotStyle().cycler,
) -> None:
    """Plot either the input data or the residual of ``res`` on ``ax``.

    Degenerate matrices (one dimension of length 1) are drawn as line plots along
    their longest dimension; everything else is drawn as a 2D plot over ``time``,
    with a colorbar only when the smallest dimension has fewer than 5 entries.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    ax : Axis
        Axis to plot on.
    linlog : bool
        Whether to use 'symlog' scale or not. Defaults to False.
    linthresh : float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    show_data : bool
        Whether to show the data or the residual. Defaults to False.
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    """
    add_cycler_if_not_none(ax, cycler)
    if show_data:
        matrix, label = res.data, "dataset"
    else:
        matrix, label = res.residual, "residual"

    smallest_extent = min(matrix.shape)
    if smallest_extent == 1:
        # Effectively 1D: plot lines along whichever dimension is the longest.
        longest_dim = matrix.dims[int(np.argmax(matrix.shape))]
        matrix.plot.line(x=longest_dim, ax=ax)
    elif smallest_extent < 5:
        matrix.plot(x="time", ax=ax)
    else:
        matrix.plot(x="time", ax=ax, add_colorbar=False)

    if linlog:
        ax.set_xscale("symlog", linthresh=linthresh)
    ax.set_title(label)
def load_and_plot_results():
    """Load the optimized parameters and save overview plots of both datasets as PDF.

    Reads ``optimized_parameters.csv``, ``dataset1.nc`` and ``dataset2.nc`` from the
    module level ``output_folder``. It writes one timestamped overview PDF per dataset
    back into that folder and finally shows all figures.
    """
    # %% Plot and save as PDF
    # This set subsequent plots to the glotaran style
    plot_style = PlotStyle()
    plt.rc("axes", prop_cycle=plot_style.cycler)

    parameter_file = output_folder.joinpath("optimized_parameters.csv")
    parameters = read_parameters_from_csv_file(str(parameter_file))
    print(f"Optimized parameters loaded:\n {parameters}")

    # Compute the timestamp once so both PDFs of one run share the same suffix,
    # even if the minute changes between creating the two plots.
    timestamp = datetime.today().strftime("%y%m%d_%H%M")

    # ``figure_only=False`` avoids the deprecated figure-only return value.
    result1 = output_folder.joinpath("dataset1.nc")
    fig1, _ = plot_overview(result1, linlog=True, show_data=True, figure_only=False)
    fig1.savefig(output_folder.joinpath(f"plot_overview_1of2_{timestamp}.pdf"),
                 bbox_inches="tight")

    result2 = output_folder.joinpath("dataset2.nc")
    fig2, _ = plot_overview(result2, linlog=True, figure_only=False)
    fig2.savefig(output_folder.joinpath(f"plot_overview_2of2_{timestamp}.pdf"),
                 bbox_inches="tight")
    plt.show()
Esempio n. 9
0
def plot_lsv_residual(
    res: xr.Dataset,
    ax: Axis,
    indices: Sequence[int] = range(2),
    linlog: bool = False,
    linthresh: float = 1,
    cycler: Cycler | None = PlotStyle().cycler,
    show_legend: bool = True,
) -> None:
    """Plot left singular vectors (time) of the residual matrix.

    The weighted residual left singular vectors are used if present in ``res``,
    otherwise the plain residual left singular vectors.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    ax : Axis
        Axis to plot on.
    indices : Sequence[int]
        Indices of the singular vector to plot. Defaults to range(2).
    linlog : bool
        Whether to use 'symlog' scale or not. Defaults to False.
    linthresh : float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    show_legend: bool
        Whether or not to show the legend. Defaults to True.
    """
    add_cycler_if_not_none(ax, cycler)
    if "weighted_residual_left_singular_vectors" in res:
        rLSV = res.weighted_residual_left_singular_vectors
    else:
        rLSV = res.residual_left_singular_vectors
    _plot_svd_vetors(rLSV, indices, "left_singular_value_index", ax,
                     show_legend)
    ax.set_title("res. LSV")
    if linlog:
        ax.set_xscale("symlog", linthresh=linthresh)
Esempio n. 10
0
def plot_overview(
    result: DatasetConvertible,
    center_λ: float | None = None,
    linlog: bool = True,
    linthresh: float = 1,
    linscale: float = 1,
    show_data: bool = False,
    main_irf_nr: int = 0,
    figsize: tuple[int, int] = (18, 16),
    cycler: Cycler | None = PlotStyle().cycler,
    figure_only: bool = True,
    nr_of_data_svd_vectors: int = 4,
    nr_of_residual_svd_vectors: int = 2,
    show_data_svd_legend: bool = True,
    show_residual_svd_legend: bool = True,
) -> Figure | tuple[Figure, Axes]:
    """Plot overview of the optimization result.

    Parameters
    ----------
    result: DatasetConvertible
        Result from a pyglotaran optimization as dataset, Path or Result object.
    center_λ: float | None
        Center wavelength (λ in nm). Defaults to None, which uses the value of the
        ``spectral`` coordinate in the middle of the spectral axis.
    linlog: bool
        Whether to use 'symlog' scale or not. Defaults to True.
    linthresh: float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    linscale: float
        This allows the linear range (-linthresh to linthresh) to be stretched
        relative to the logarithmic range.
        Its value is the number of decades to use for each half of the linear range.
        For example, when linscale == 1.0 (the default), the space used for the
        positive and negative halves of the linear range will be equal to one
        decade in the logarithmic range. Defaults to 1.
    show_data: bool
        Whether to show the input data or residual. Defaults to False.
    main_irf_nr: int
        Index of the main ``irf`` component when using an ``irf``
        parametrized with multiple peaks. Defaults to 0.
    figsize : tuple[int, int]
        Size of the figure (N, M) in inches. Defaults to (18, 16).
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    figure_only: bool
        Whether or not to only return the figure.
        This is a deprecation helper argument to transition to a consistent return value
        consisting of the :class:`Figure` and the :class:`Axes`. Defaults to True.
    nr_of_data_svd_vectors: int
        Number of data SVD vector to plot. Defaults to 4.
    nr_of_residual_svd_vectors: int
        Number of residual SVD vector to plot. Defaults to 2.
    show_data_svd_legend: bool
        Whether or not to show the data SVD legend. Defaults to True.
    show_residual_svd_legend: bool
        Whether or not to show the residual SVD legend. Defaults to True.

    Returns
    -------
    Figure|tuple[Figure, Axes]
        If ``figure_only`` is True, Figure object which contains the plots (deprecated).
        If ``figure_only`` is False, Figure object which contains the plots and the Axes.
    """
    res = load_data(result)

    # Plot dimensions
    M = 4
    N = 3
    fig, axes = plt.subplots(M, N, figsize=figsize, constrained_layout=True)

    if center_λ is None:
        # ``center_λ`` is a wavelength, not an index: use the spectral coordinate
        # value at the middle position of the spectral axis.
        spectral_values = res.coords["spectral"].values
        center_λ = float(spectral_values[len(spectral_values) // 2])

    # First and second row: concentrations - SAS/EAS - DAS
    plot_concentrations(
        res,
        axes[0, 0],
        center_λ,
        linlog=linlog,
        linthresh=linthresh,
        linscale=linscale,
        main_irf_nr=main_irf_nr,
        cycler=cycler,
    )
    plot_spectra(res, axes[0:2, 1:3], cycler=cycler)
    # Third and fourth row: SVD of residual and data
    plot_svd(
        res,
        axes[2:4, 0:3],
        linlog=linlog,
        linthresh=linthresh,
        cycler=cycler,
        nr_of_data_svd_vectors=nr_of_data_svd_vectors,
        nr_of_residual_svd_vectors=nr_of_residual_svd_vectors,
        show_data_svd_legend=show_data_svd_legend,
        show_residual_svd_legend=show_residual_svd_legend,
    )
    plot_residual(
        res, axes[1, 0], linlog=linlog, linthresh=linthresh, show_data=show_data, cycler=cycler
    )
    if figure_only is True:
        warn(PyglotaranExtrasApiDeprecationWarning(FIG_ONLY_WARNING), stacklevel=2)
        return fig
    else:
        return fig, axes
Esempio n. 11
0
def plot_simple_overview(
    result: DatasetConvertible,
    title: str | None = None,
    figsize: tuple[int, int] = (12, 6),
    cycler: Cycler | None = PlotStyle().cycler,
    figure_only: bool = True,
) -> Figure | tuple[Figure, Axes]:
    """Plot simple overview.

    The 2x3 grid shows concentrations, SAS and the data in the top row. The bottom
    row shows the first (up to) two residual left/right singular vectors and the
    residual.

    Parameters
    ----------
    result: DatasetConvertible
        Result from a pyglotaran optimization as dataset, Path or Result object.
    title: str | None
        Title of the figure. Defaults to None.
    figsize : tuple[int, int]
        Size of the figure (N, M) in inches. Defaults to (12, 6).
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    figure_only: bool
        Whether or not to only return the figure.
        This is a deprecation helper argument to transition to a consistent return value
        consisting of the :class:`Figure` and the :class:`Axes`. Defaults to True.

    Returns
    -------
    Figure|tuple[Figure, Axes]
        If ``figure_only`` is True, Figure object which contains the plots (deprecated).
        If ``figure_only`` is False, Figure object which contains the plots and the Axes.
    """
    res = load_data(result)

    fig, axes = plt.subplots(2, 3, figsize=figsize, constrained_layout=True)
    for ax in axes.flatten():
        add_cycler_if_not_none(ax, cycler)
    if title:
        fig.suptitle(title, fontsize=16)
    sas = res.species_associated_spectra
    traces = res.species_concentration
    if "spectral" in traces.coords:
        traces.sel(spectral=res.spectral.values[0], method="nearest").plot.line(
            x="time", ax=axes[0, 0]
        )
    else:
        traces.plot.line(x="time", ax=axes[0, 0])
    sas.plot.line(x="spectral", ax=axes[0, 1])

    # Count singular vectors by their index dimension explicitly; ``len()`` would
    # return the size of the first dimension, which for the LSV is ``time``.
    rLSV = res.residual_left_singular_vectors
    nr_of_lsv = min(2, rLSV.sizes["left_singular_value_index"])
    rLSV.isel(left_singular_value_index=range(nr_of_lsv)).plot.line(
        x="time", ax=axes[1, 0]
    )
    axes[1, 0].set_title("res. LSV")

    rRSV = res.residual_right_singular_vectors
    nr_of_rsv = min(2, rRSV.sizes["right_singular_value_index"])
    rRSV.isel(right_singular_value_index=range(nr_of_rsv)).plot.line(
        x="spectral", ax=axes[1, 1]
    )
    axes[1, 1].set_title("res. RSV")

    res.data.plot(x="time", ax=axes[0, 2])
    axes[0, 2].set_title("data")
    res.residual.plot(x="time", ax=axes[1, 2])
    axes[1, 2].set_title("residual")
    if figure_only is True:
        warn(PyglotaranExtrasApiDeprecationWarning(FIG_ONLY_WARNING), stacklevel=2)
        return fig
    else:
        return fig, axes
Esempio n. 12
0
def plot_doas(
    result: DatasetConvertible,
    figsize: tuple[int, int] = (25, 25),
    cycler: Cycler | None = PlotStyle().cycler,
    figure_only: bool = True,
) -> Figure | tuple[Figure, Axes]:
    """Plot Damped oscillation associated spectra (DOAS).

    Parameters
    ----------
    result: DatasetConvertible
        Result from a pyglotaran optimization as dataset, Path or Result object.
    figsize : tuple[int, int]
        Size of the figure (N, M) in inches. Defaults to (25, 25).
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    figure_only: bool
        Whether or not to only return the figure.
        This is a deprecation helper argument to transition to a consistent return value
        consisting of the :class:`Figure` and the :class:`Axes`. Defaults to True.

    Returns
    -------
    Figure|tuple[Figure, Axes]
        If ``figure_only`` is True, Figure object which contains the plots (deprecated).
        If ``figure_only`` is False, Figure object which contains the plots and the Axes.
    """
    dataset = load_data(result)

    # Create M x N plotting grid
    M = 6
    N = 3

    fig, axes = plt.subplots(M, N, figsize=figsize)

    for ax in axes.flatten():
        add_cycler_if_not_none(ax, cycler)

    # SAS and DAS
    dataset.species_associated_spectra.plot.line(x="spectral", ax=axes[0, 0])
    dataset.decay_associated_spectra.plot.line(x="spectral", ax=axes[0, 1])

    # Concentrations; wavelength dependent ones are shown for the first spectral index
    if "spectral" in dataset.species_concentration.coords:
        dataset.species_concentration.isel(spectral=0).plot.line(x="time",
                                                                 ax=axes[1, 0])
    else:
        dataset.species_concentration.plot.line(x="time", ax=axes[1, 0])
    # ``linthreshx`` was removed from matplotlib's symlog scale; ``linthresh`` is the
    # supported keyword (and the one used by the other plot functions).
    axes[1, 0].set_xscale("symlog", linthresh=1)

    if "dampened_oscillation_associated_spectra" in dataset:
        dataset.dampened_oscillation_cos.isel(spectral=0).sel(
            time=slice(-1, 10)).plot.line(x="time", ax=axes[1, 1])
        dataset.dampened_oscillation_associated_spectra.plot.line(x="spectral",
                                                                  ax=axes[2, 0])
        dataset.dampened_oscillation_phase.plot.line(x="spectral",
                                                     ax=axes[2, 1])

    # Residual SVD: first LSV, singular values (log scale), first RSV
    dataset.residual_left_singular_vectors.isel(
        left_singular_value_index=0).plot(ax=axes[0, 2])
    dataset.residual_singular_values.plot.line("ro-",
                                               yscale="log",
                                               ax=axes[1, 2])
    dataset.residual_right_singular_vectors.isel(
        right_singular_value_index=0).plot(ax=axes[2, 2])

    plt.tight_layout(pad=5, w_pad=2.0, h_pad=2.0)
    if figure_only is True:
        warn(PyglotaranExtrasApiDeprecationWarning(FIG_ONLY_WARNING),
             stacklevel=2)
        return fig
    else:
        return fig, axes
Esempio n. 13
0
"""Tests for pyglotaran_extras.plotting.utils"""
from __future__ import annotations

import matplotlib
import matplotlib.pyplot as plt
import pytest
from cycler import Cycler
from cycler import cycle

from pyglotaran_extras.plotting.style import PlotStyle
from pyglotaran_extras.plotting.utils import add_cycler_if_not_none

# Select the non-interactive "Agg" backend so the tests can run without a display.
matplotlib.use("Agg")
# Matplotlib's default property cycler, captured at import time; it is the expected
# cycler when ``None`` is passed to ``add_cycler_if_not_none``.
DEFAULT_CYCLER = plt.rcParams["axes.prop_cycle"]


@pytest.mark.parametrize(
    "cycler,expected_cycler",
    ((None, DEFAULT_CYCLER()), (PlotStyle().cycler, PlotStyle().cycler())),
)
def test_add_cycler_if_not_none(cycler: Cycler | None, expected_cycler: cycle):
    """Axis keeps the default cycler if ``cycler`` is None and uses ``cycler`` otherwise."""
    axis = plt.subplot()
    add_cycler_if_not_none(axis, cycler)

    # The first ten style entries drawn by the axis must match the expected cycle.
    remaining = 10
    while remaining > 0:
        assert next(axis._get_lines.prop_cycler) == next(expected_cycler)
        remaining -= 1
Esempio n. 14
0
def plot_svd(
    res: xr.Dataset,
    axes: Axes,
    linlog: bool = False,
    linthresh: float = 1,
    cycler: Cycler | None = PlotStyle().cycler,
    nr_of_data_svd_vectors: int = 4,
    nr_of_residual_svd_vectors: int = 2,
    show_data_svd_legend: bool = True,
    show_residual_svd_legend: bool = True,
) -> None:
    """Plot the SVD (Singular Value Decomposition) of the residual and the data.

    The SVD is added to ``res`` first (weighted residual if available). The first row of
    ``axes`` then shows residual LSV, RSV and singular values, and the second row
    shows the same for the data.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    axes : Axes
        Axes to plot the SVDs on (needs to be at least 2x3).
    linlog : bool
        Whether to use 'symlog' scale or not. Defaults to False.
    linthresh : float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    nr_of_data_svd_vectors: int
        Number of data SVD vector to plot. Defaults to 4.
    nr_of_residual_svd_vectors: int
        Number of residual SVD vector to plot. Defaults to 2.
    show_data_svd_legend: bool
        Whether or not to show the data SVD legend. Defaults to True.
    show_residual_svd_legend: bool
        Whether or not to show the residual SVD legend. Defaults to True.
    """
    # Residual row
    residual_name = "weighted_residual" if "weighted_residual" in res else "residual"
    add_svd_to_dataset(dataset=res, name=residual_name)
    residual_indices = range(nr_of_residual_svd_vectors)
    plot_lsv_residual(
        res,
        axes[0, 0],
        linlog=linlog,
        linthresh=linthresh,
        cycler=cycler,
        indices=residual_indices,
        show_legend=show_residual_svd_legend,
    )
    plot_rsv_residual(
        res,
        axes[0, 1],
        cycler=cycler,
        indices=residual_indices,
        show_legend=show_residual_svd_legend,
    )
    plot_sv_residual(res, axes[0, 2], cycler=cycler)

    # Data row
    add_svd_to_dataset(dataset=res, name="data")
    data_indices = range(nr_of_data_svd_vectors)
    plot_lsv_data(
        res,
        axes[1, 0],
        linlog=linlog,
        linthresh=linthresh,
        cycler=cycler,
        indices=data_indices,
        show_legend=show_data_svd_legend,
    )
    plot_rsv_data(
        res,
        axes[1, 1],
        cycler=cycler,
        indices=data_indices,
        show_legend=show_data_svd_legend,
    )
    plot_sv_data(res, axes[1, 2], cycler=cycler)
Esempio n. 15
0
def plot_fitted_traces(
    result: ResultLike,
    wavelengths: Iterable[float],
    axes_shape: tuple[int, int] = (4, 4),
    center_λ: float | None = None,
    main_irf_nr: int = 0,
    linlog: bool = False,
    linthresh: float = 1,
    divide_by_scale: bool = True,
    per_axis_legend: bool = False,
    figsize: tuple[int, int] = (30, 15),
    title: str = "Fit overview",
    y_label: str = "a.u.",
    cycler: Cycler | None = PlotStyle().data_cycler_solid,
) -> tuple[Figure, Axes]:
    """Plot data and their fit in per wavelength plot grid.

    Parameters
    ----------
    result : ResultLike
        Data structure which can be converted to a mapping of datasets.
    wavelengths: Iterable[float]
        Wavelength which should be used for each subplot, should be of length N*M
        with ``axes_shape`` being of shape (N, M), else it will result in missing plots.
    axes_shape : tuple[int, int]
        Shape of the plot grid (N, M). Defaults to (4, 4).
    center_λ: float | None
        Center wavelength of the IRF (λ in nm).
    main_irf_nr : int
        Index of the main ``irf`` component when using an ``irf``
        parametrized with multiple peaks. Defaults to 0.
    linlog : bool
        Whether to use 'symlog' scale or not. Defaults to False.
    linthresh : float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    divide_by_scale : bool
        Whether or not to divide the data by the dataset scale used for optimization.
        Defaults to True.
    per_axis_legend : bool
        Whether to use a legend per plot or for the whole figure. Defaults to False.
    figsize : tuple[int, int]
        Size of the figure (N, M) in inches. Defaults to (30, 15).
    title : str
        Title to add to the figure. Defaults to "Fit overview".
    y_label: str
        Label used for the y-axis of each subplot. Defaults to "a.u.".
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().data_cycler_solid.

    Returns
    -------
    tuple[Figure, Axes]
        Figure and axes which can then be refined by the user.

    See Also
    --------
    maximum_coordinate_range
    add_unique_figure_legend
    plot_data_and_fits
    calculate_wavelengths
    """
    result_map = result_dataset_mapping(result)
    fig, axes = plt.subplots(*axes_shape, figsize=figsize)
    nr_of_plots = len(axes.flatten())
    # With more subplots than spectral points, ``method="nearest"`` selections in
    # ``plot_data_and_fits`` necessarily hit some wavelengths more than once.
    max_spectral_values = max(
        len(dataset.coords["spectral"]) for dataset in result_map.values()
    )
    if nr_of_plots > max_spectral_values:
        warn(
            PlotDuplicationWarning(
                f"The number of plots ({nr_of_plots}) exceeds the maximum number of "
                f"spectral data points ({max_spectral_values}), "
                "which will lead in duplicated plots."),
            stacklevel=2,
        )
    # ``zip`` stops at the shorter input, so surplus axes simply stay empty.
    for wavelength, axis in zip(wavelengths, axes.flatten()):
        plot_data_and_fits(
            result=result_map,
            wavelength=wavelength,
            axis=axis,
            center_λ=center_λ,
            main_irf_nr=main_irf_nr,
            linlog=linlog,
            linthresh=linthresh,
            divide_by_scale=divide_by_scale,
            per_axis_legend=per_axis_legend,
            y_label=y_label,
            cycler=cycler,
        )
    if per_axis_legend is False:
        add_unique_figure_legend(fig, axes)
    fig.suptitle(title, fontsize=28)
    fig.tight_layout()
    return fig, axes
Esempio n. 16
0
        "dataset1": dataset1,
        "dataset2": dataset2,
        "dataset3": dataset3
    },
    maximum_number_function_evaluations=99,
    non_negative_least_squares=True,
    # optimization_method="Levenberg-Marquardt",
)
# Run the optimization described by ``scheme``
result = optimize(scheme)
# %% Save results
# NOTE(review): presumably this writes the per-dataset result files (dataset1.nc, ...)
# read below into ``output_folder`` — confirm against ``Result.save``.
result.save(str(output_folder))

# %% Plot results
# Set subsequent plots to the glotaran style (global matplotlib rc setting)
plot_style = PlotStyle()
plt.rc("axes", prop_cycle=plot_style.cycler)

# TODO: enhance plot_overview to handle multiple datasets
# Until then every saved dataset is loaded and plotted separately.
result_datafile1 = output_folder.joinpath("dataset1.nc")
result_datafile2 = output_folder.joinpath("dataset2.nc")
result_datafile3 = output_folder.joinpath("dataset3.nc")
# NOTE(review): plot_overview defaults to figure_only=True, which emits a deprecation
# warning and returns only the Figure.
fig1 = plot_overview(result_datafile1, linlog=True, linthresh=1)
fig1.savefig(
    output_folder.joinpath("plot_overview_sim3d_d1.pdf"),
    bbox_inches="tight",
)

fig2 = plot_overview(result_datafile2, linlog=True, linthresh=1)
fig2.savefig(
    output_folder.joinpath("plot_overview_sim3d_d2.pdf"),
Esempio n. 17
0
def plot_data_and_fits(
    result: ResultLike,
    wavelength: float,
    axis: Axis,
    center_λ: float | None = None,
    main_irf_nr: int = 0,
    linlog: bool = False,
    linthresh: float = 1,
    divide_by_scale: bool = True,
    per_axis_legend: bool = False,
    y_label: str = "a.u.",
    cycler: Cycler | None = PlotStyle().data_cycler_solid,
) -> None:
    """Plot data and fits for a given ``wavelength`` on a given ``axis``.

    If the wavelength isn't part of a dataset, that dataset will be skipped. Its two
    entries of the property cycler (data and fit) are still consumed, so each dataset
    keeps the same style across all subplots.

    Parameters
    ----------
    result : ResultLike
        Data structure which can be converted to a mapping.
    wavelength : float
        Wavelength to plot data and fits for.
    axis: Axis
        Axis to plot the data and fits on.
    center_λ: float | None
        Center wavelength (λ in nm)
    main_irf_nr : int
        Index of the main ``irf`` component when using an ``irf``
        parametrized with multiple peaks. Defaults to 0.
    linlog : bool
        Whether to use 'symlog' scale or not. Defaults to False.
    linthresh : float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    divide_by_scale : bool
        Whether or not to divide the data by the dataset scale used for optimization.
        Defaults to True.
    per_axis_legend: bool
        Whether to use a legend per plot or for the whole figure. Defaults to False.
    y_label: str
        Label used for the y-axis of each subplot. Defaults to "a.u.".
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().data_cycler_solid.

    See Also
    --------
    plot_fitted_traces
    plot_fit_overview
    """
    result_map = result_dataset_mapping(result)
    add_cycler_if_not_none(axis, cycler)
    for dataset_name, dataset in result_map.items():
        spectral_coords = dataset.coords["spectral"].values
        if spectral_coords.min() <= wavelength <= spectral_coords.max():
            result_data = dataset.sel(spectral=[wavelength], method="nearest")
            scale = extract_dataset_scale(result_data, divide_by_scale)
            # Shift the time axis so the irf location is at time zero
            irf_loc = extract_irf_location(result_data, center_λ, main_irf_nr)
            result_data = result_data.assign_coords(
                time=result_data.coords["time"] - irf_loc)
            (result_data.data / scale).plot(x="time",
                                            ax=axis,
                                            label=f"{dataset_name}_data")
            (result_data.fitted_data / scale).plot(x="time",
                                                   ax=axis,
                                                   label=f"{dataset_name}_fit")
        else:
            # Skip the two cycler entries (data + fit) this dataset would have used
            for _ in range(2):
                next(axis._get_lines.prop_cycler)
    if linlog:
        axis.set_xscale("symlog", linthresh=linthresh)
    axis.set_ylabel(y_label)
    if per_axis_legend is True:
        axis.legend()