Code example #1
0
File: utils.py  Project: s-weigand/pyglotaran-extras
def result_dataset_mapping(result: ResultLike) -> Mapping[str, xr.Dataset]:
    """Turn a ``ResultLike`` object into a mapping of dataset name to result data.

    Parameters
    ----------
    result : ResultLike
        Data structure which can be converted to a mapping.

    Returns
    -------
    Mapping[str, Dataset]
        Per dataset mapping of result like data. A :class:`Result` returns its own
        ``data`` attribute, a single dataset-like value is stored under the key
        ``"dataset"``, the items of a sequence are keyed ``"dataset0"``,
        ``"dataset1"``, ... and a mapping keeps its original keys.

    Raises
    ------
    TypeError
        If any value of a ``result`` isn't of :class:`DatasetConvertible`.
    TypeError
        If ``result`` isn't a :class:`ResultLike` object.
    """
    # A full optimization result already carries the per-dataset mapping.
    if isinstance(result, Result):
        return result.data

    # Single dataset-like values; checked before ``Sequence`` because ``str`` is one.
    if isinstance(result, (xr.Dataset, xr.DataArray, Path, str)):
        return {"dataset": load_data(result)}

    if isinstance(result, Sequence):
        return {f"dataset{position}": load_data(item) for position, item in enumerate(result)}

    if isinstance(result, Mapping):
        return {label: load_data(item) for label, item in result.items()}

    raise TypeError(f"Result needs to be of type {ResultLike!r}, but was {result!r}.")
Code example #2
0
def plot_overview(
    result: DatasetConvertible,
    center_λ: float | None = None,
    linlog: bool = True,
    linthresh: float = 1,
    linscale: float = 1,
    show_data: bool = False,
    main_irf_nr: int = 0,
    figsize: tuple[int, int] = (18, 16),
    cycler: Cycler | None = PlotStyle().cycler,
    figure_only: bool = True,
    nr_of_data_svd_vectors: int = 4,
    nr_of_residual_svd_vectors: int = 2,
    show_data_svd_legend: bool = True,
    show_residual_svd_legend: bool = True,
) -> Figure | tuple[Figure, Axes]:
    """Plot overview of the optimization result.

    The overview is a 4 x 3 grid. Concentrations are in the top left and the
    residual is in the second row on the left. The spectra fill the top two rows
    of the middle and right columns. The SVD of data and residual fills the bottom
    two rows.

    Parameters
    ----------
    result: DatasetConvertible
        Result from a pyglotaran optimization as dataset, Path or Result object.
    center_λ: float | None
        Center wavelength (λ in nm) passed on to the concentration plot.
        If None, ``round(n / 2)`` is used, where ``n`` is the number of points
        along the ``spectral`` dimension. Defaults to None.
    linlog: bool
        Whether to use 'symlog' scale or not. Defaults to True.
    linthresh: float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    linscale: float
        This allows the linear range (-linthresh to linthresh) to be stretched
        relative to the logarithmic range.
        Its value is the number of decades to use for each half of the linear range.
        For example, when linscale == 1.0 (the default), the space used for the
        positive and negative halves of the linear range will be equal to one
        decade in the logarithmic range. Only passed to the concentration plot.
        Defaults to 1.
    show_data: bool
        Whether to show the input data or residual. Defaults to False.
    main_irf_nr: int
        Index of the main ``irf`` component when using an ``irf``
        parametrized with multiple peaks. Defaults to 0.
    figsize : tuple[int, int]
        Size of the figure (N, M) in inches. Defaults to (18, 16).
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    figure_only: bool
        Whether or not to only return the figure.
        This is a deprecation helper argument to transition to a consistent return value
        consisting of the :class:`Figure` and the :class:`Axes`. Defaults to True.
    nr_of_data_svd_vectors: int
        Number of data SVD vector to plot. Defaults to 4.
    nr_of_residual_svd_vectors: int
        Number of residual SVD vector to plot. Defaults to 2.
    show_data_svd_legend: bool
        Whether or not to show the data SVD legend. Defaults to True.
    show_residual_svd_legend: bool
        Whether or not to show the residual SVD legend. Defaults to True.

    Returns
    -------
    Figure|tuple[Figure, Axes]
        If ``figure_only`` is True, Figure object which contains the plots (deprecated).
        If ``figure_only`` is False, Figure object which contains the plots and the Axes.
    """
    res = load_data(result)

    # Plot dimensions
    M = 4
    N = 3
    fig, axes = plt.subplots(M, N, figsize=figsize, constrained_layout=True)

    if center_λ is None:
        # ``sizes`` is the stable way to get a dimension length; ``Dataset.dims`` is
        # slated to return only a set of dimension names in newer xarray versions.
        # NOTE(review): this default is half the number of spectral points, not a
        # wavelength in nm; confirm how plot_concentrations interprets it.
        center_λ = round(res.sizes["spectral"] / 2)

    # First and second row: concentrations - SAS/EAS - DAS
    plot_concentrations(
        res,
        axes[0, 0],
        center_λ,
        linlog=linlog,
        linthresh=linthresh,
        linscale=linscale,
        main_irf_nr=main_irf_nr,
        cycler=cycler,
    )
    plot_spectra(res, axes[0:2, 1:3], cycler=cycler)
    # Third and fourth row: SVD of data and residual
    plot_svd(
        res,
        axes[2:4, 0:3],
        linlog=linlog,
        linthresh=linthresh,
        cycler=cycler,
        nr_of_data_svd_vectors=nr_of_data_svd_vectors,
        nr_of_residual_svd_vectors=nr_of_residual_svd_vectors,
        show_data_svd_legend=show_data_svd_legend,
        show_residual_svd_legend=show_residual_svd_legend,
    )
    plot_residual(
        res, axes[1, 0], linlog=linlog, linthresh=linthresh, show_data=show_data, cycler=cycler
    )

    if figure_only is True:
        warn(PyglotaranExtrasApiDeprecationWarning(FIG_ONLY_WARNING), stacklevel=2)
        return fig
    return fig, axes
Code example #3
0
def plot_simple_overview(
    result: DatasetConvertible,
    title: str | None = None,
    figsize: tuple[int, int] = (12, 6),
    cycler: Cycler | None = PlotStyle().cycler,
    figure_only: bool = True,
) -> Figure | tuple[Figure, Axes]:
    """Plot simple overview.

    The figure is a 2 x 3 grid. The top row shows the concentrations, the SAS and
    the data. The bottom row shows up to two residual left and right singular
    vectors and the residual.

    Parameters
    ----------
    result: DatasetConvertible
        Result from a pyglotaran optimization as dataset, Path or Result object.
    title: str | None
        Title of the figure. Defaults to None.
    figsize : tuple[int, int]
        Size of the figure (N, M) in inches. Defaults to (12, 6).
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    figure_only: bool
        Whether or not to only return the figure.
        This is a deprecation helper argument to transition to a consistent return value
        consisting of the :class:`Figure` and the :class:`Axes`. Defaults to True.

    Returns
    -------
    Figure|tuple[Figure, Axes]
        If ``figure_only`` is True, Figure object which contains the plots (deprecated).
        If ``figure_only`` is False, Figure object which contains the plots and the Axes.
    """
    res = load_data(result)

    fig, axes = plt.subplots(2, 3, figsize=figsize, constrained_layout=True)
    for ax in axes.flatten():
        add_cycler_if_not_none(ax, cycler)
    if title:
        fig.suptitle(title, fontsize=16)

    # Top left: concentrations. If they depend on the spectral axis, only the trace
    # nearest to the first spectral point is shown.
    sas = res.species_associated_spectra
    traces = res.species_concentration
    if "spectral" in traces.coords:
        traces.sel(spectral=res.spectral.values[0], method="nearest").plot.line(
            x="time", ax=axes[0, 0]
        )
    else:
        traces.plot.line(x="time", ax=axes[0, 0])
    # Top middle: species associated spectra.
    sas.plot.line(x="spectral", ax=axes[0, 1])

    # Bottom left/middle: the first (up to) two residual singular vectors.
    # The count is clipped by the size of the singular value index dimension itself;
    # ``len()`` of a DataArray only reports the size of its first dimension.
    rLSV = res.residual_left_singular_vectors
    nr_of_lsv = min(2, rLSV.sizes["left_singular_value_index"])
    rLSV.isel(left_singular_value_index=range(nr_of_lsv)).plot.line(x="time", ax=axes[1, 0])
    axes[1, 0].set_title("res. LSV")

    rRSV = res.residual_right_singular_vectors
    nr_of_rsv = min(2, rRSV.sizes["right_singular_value_index"])
    rRSV.isel(right_singular_value_index=range(nr_of_rsv)).plot.line(
        x="spectral", ax=axes[1, 1]
    )
    axes[1, 1].set_title("res. RSV")

    # Right column: data (top) and residual (bottom).
    res.data.plot(x="time", ax=axes[0, 2])
    axes[0, 2].set_title("data")
    res.residual.plot(x="time", ax=axes[1, 2])
    axes[1, 2].set_title("residual")

    if figure_only is True:
        warn(PyglotaranExtrasApiDeprecationWarning(FIG_ONLY_WARNING), stacklevel=2)
        return fig
    return fig, axes
Code example #4
0
    rRSV.isel(right_singular_value_index=range(min(2, len(rRSV)))).plot.line(
        x="spectral", ax=axes[1, 1]
    )

    axes[1, 1].set_title("res. RSV")
    res.data.plot(x="time", ax=axes[0, 2])
    axes[0, 2].set_title("data")
    res.residual.plot(x="time", ax=axes[1, 2])
    axes[1, 2].set_title("residual")
    if figure_only is True:
        warn(PyglotaranExtrasApiDeprecationWarning(FIG_ONLY_WARNING), stacklevel=2)
        return fig
    else:
        return fig, axes


if __name__ == "__main__":
    # Usage: python <this file> RESULT_PATH [OUTPUT_FIGURE_PATH]
    # Without an output path the overview is shown interactively.
    import sys

    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} RESULT_PATH [OUTPUT_FIGURE_PATH]")

    result_path = Path(sys.argv[1])
    res = load_data(result_path)
    print(res)

    # Unpack into a local name; assigning to ``plt.axes`` would overwrite the
    # ``matplotlib.pyplot.axes`` function.
    fig, _ = plot_overview(res, figure_only=False)
    if len(sys.argv) > 2:
        fig.savefig(sys.argv[2], bbox_inches="tight")
        print(f"Saved figure to: {sys.argv[2]}")
    else:
        plt.show(block=False)
        input("press <ENTER> to continue")
Code example #5
0
def plot_doas(
    result: DatasetConvertible,
    figsize: tuple[int, int] = (25, 25),
    cycler: Cycler | None = PlotStyle().cycler,
    figure_only: bool = True,
) -> Figure | tuple[Figure, Axes]:
    """Plot Damped oscillation associated spectra (DOAS).

    The figure uses a 6 x 3 grid, of which the first three rows are filled.
    The left and middle columns show the SAS and DAS, the concentrations and
    (if present in the dataset) the DOAS-related quantities. The right column
    shows the first residual LSV, the residual singular values and the first
    residual RSV.

    Parameters
    ----------
    result: DatasetConvertible
        Result from a pyglotaran optimization as dataset, Path or Result object.
    figsize : tuple[int, int]
        Size of the figure (N, M) in inches. Defaults to (25, 25).
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    figure_only: bool
        Whether or not to only return the figure.
        This is a deprecation helper argument to transition to a consistent return value
        consisting of the :class:`Figure` and the :class:`Axes`. Defaults to True.

    Returns
    -------
    Figure|tuple[Figure, Axes]
        If ``figure_only`` is True, Figure object which contains the plots (deprecated).
        If ``figure_only`` is False, Figure object which contains the plots and the Axes.
    """
    dataset = load_data(result)

    # Create M x N plotting grid
    M = 6
    N = 3

    fig, axes = plt.subplots(M, N, figsize=figsize)

    for ax in axes.flatten():
        add_cycler_if_not_none(ax, cycler)

    # First row: species and decay associated spectra.
    dataset.species_associated_spectra.plot.line(x="spectral", ax=axes[0, 0])
    dataset.decay_associated_spectra.plot.line(x="spectral", ax=axes[0, 1])

    # Concentrations; if they depend on the spectral axis, only the first point is shown.
    if "spectral" in dataset.species_concentration.coords:
        dataset.species_concentration.isel(spectral=0).plot.line(x="time", ax=axes[1, 0])
    else:
        dataset.species_concentration.plot.line(x="time", ax=axes[1, 0])
    # ``linthreshx`` was renamed to ``linthresh`` in matplotlib 3.3 and later removed.
    axes[1, 0].set_xscale("symlog", linthresh=1)

    if "dampened_oscillation_associated_spectra" in dataset:
        dataset.dampened_oscillation_cos.isel(spectral=0).sel(time=slice(-1, 10)).plot.line(
            x="time", ax=axes[1, 1]
        )
        dataset.dampened_oscillation_associated_spectra.plot.line(x="spectral", ax=axes[2, 0])
        dataset.dampened_oscillation_phase.plot.line(x="spectral", ax=axes[2, 1])

    # Right column: first residual LSV, residual singular values, first residual RSV.
    dataset.residual_left_singular_vectors.isel(left_singular_value_index=0).plot(ax=axes[0, 2])
    dataset.residual_singular_values.plot.line("ro-", yscale="log", ax=axes[1, 2])
    dataset.residual_right_singular_vectors.isel(right_singular_value_index=0).plot(
        ax=axes[2, 2]
    )

    # Lay out this figure explicitly rather than whatever pyplot considers "current".
    fig.tight_layout(pad=5, w_pad=2.0, h_pad=2.0)
    if figure_only is True:
        warn(PyglotaranExtrasApiDeprecationWarning(FIG_ONLY_WARNING), stacklevel=2)
        return fig
    return fig, axes
Code example #6
0
def plot_data_overview(
    dataset: DatasetConvertible,
    title: str = "Data overview",
    linlog: bool = False,
    linthresh: float = 1,
    figsize: tuple[int, int] = (15, 10),
    nr_of_data_svd_vectors: int = 4,
    show_data_svd_legend: bool = True,
) -> tuple[Figure, Axes]:
    """Plot data as filled contour plot and SVD components.

    The data plot spans the top three rows of a 4 x 3 grid. The bottom row shows
    the left singular vectors, the singular values and the right singular vectors
    of the data.

    Parameters
    ----------
    dataset : DatasetConvertible
        Dataset containing data and SVD of the data.
    title : str
        Title to add to the figure. Defaults to "Data overview".
    linlog : bool
        Whether to use 'symlog' scale for the time axis of the data plot or not.
        Defaults to False.
    linthresh : float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    figsize : tuple[int, int]
        Size of the figure (N, M) in inches. Defaults to (15, 10).
    nr_of_data_svd_vectors: int
        Number of data SVD vector to plot. Defaults to 4.
    show_data_svd_legend: bool
        Whether or not to show the data SVD legend. Defaults to True.

    Returns
    -------
    tuple[Figure, Axes]
        The figure, and a tuple ``(data_ax, lsv_ax, sv_ax, rsv_ax)`` of its four
        axes, which can then be refined by the user.
    """
    dataset = load_data(dataset)

    fig = plt.figure(figsize=figsize)
    grid_shape = (4, 3)
    # String casts: these are Axes objects (not ``matplotlib.axis.Axis``); a string
    # avoids needing the name at runtime, since ``cast`` does nothing at runtime anyway.
    data_ax = cast("Axes", plt.subplot2grid(grid_shape, (0, 0), colspan=3, rowspan=3, fig=fig))
    lsv_ax = cast("Axes", plt.subplot2grid(grid_shape, (3, 0), fig=fig))
    sv_ax = cast("Axes", plt.subplot2grid(grid_shape, (3, 1), fig=fig))
    rsv_ax = cast("Axes", plt.subplot2grid(grid_shape, (3, 2), fig=fig))

    # With a single time point there is nothing to put on a "time" x-axis.
    if len(dataset.data.time) > 1:
        dataset.data.plot(x="time", ax=data_ax, center=False)
    else:
        dataset.data.plot(ax=data_ax)

    add_svd_to_dataset(dataset=dataset, name="data")
    plot_lsv_data(
        dataset,
        lsv_ax,
        indices=range(nr_of_data_svd_vectors),
        show_legend=show_data_svd_legend,
    )
    plot_sv_data(dataset, sv_ax)
    plot_rsv_data(
        dataset,
        rsv_ax,
        indices=range(nr_of_data_svd_vectors),
        show_legend=show_data_svd_legend,
    )
    fig.suptitle(title, fontsize=16)

    # Apply the final axis scale before computing the layout, so tight_layout
    # accounts for the symlog tick labels.
    if linlog:
        data_ax.set_xscale("symlog", linthresh=linthresh)
    fig.tight_layout()

    return fig, (data_ax, lsv_ax, sv_ax, rsv_ax)