def test_add_cycler_if_not_none(cycler: Cycler | None, expected_cycler: cycle):
    """Axis cycles through the default cycler if ``cycler`` is None, else ``cycler``.

    The first ten entries of the axis property cycler are compared one by one
    against ``expected_cycler``.
    """
    axis = plt.subplot()
    add_cycler_if_not_none(axis, cycler)
    lines_helper = axis._get_lines
    for _ in range(10):
        actual_props = next(lines_helper.prop_cycler)
        assert actual_props == next(expected_cycler)
def plot_das(
    res: xr.Dataset,
    ax: Axis,
    title: str = "DAS",
    cycler: Cycler | None = PlotStyle().cycler,
) -> None:
    """Plot DAS (Decay Associated Spectra) on ``ax``.

    Every data variable of ``res`` whose name starts with ``"decay_associated_spectra"``
    or ``"species_spectra"`` is plotted as lines over the ``spectral`` coordinate.
    Any legend created by the plotting is removed again.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    ax : Axis
        Axis to plot on.
    title : str
        Title of the plot. Defaults to "DAS".
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    """
    add_cycler_if_not_none(ax, cycler)
    keys = [
        v for v in res.data_vars if v.startswith(("decay_associated_spectra", "species_spectra"))
    ]
    for key in keys:
        das = res[key]
        das.plot.line(x="spectral", ax=ax)
        ax.set_title(title)
        # xarray only adds a legend when there is a hue dimension (e.g. not for 1D data),
        # so ``get_legend`` may return None.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
def plot_norm_sas(
    res: xr.Dataset,
    ax: Axis,
    title: str = "norm SAS",
    cycler: Cycler | None = PlotStyle().cycler,
) -> None:
    """Plot normalized SAS (Species Associated Spectra) on ``ax``.

    Every data variable of ``res`` whose name starts with ``"species_associated_spectra"``
    or ``"species_spectra"`` is divided by its maximum absolute value along ``spectral``
    and plotted as lines over the ``spectral`` coordinate. Any legend created by the
    plotting is removed again.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    ax : Axis
        Axis to plot on.
    title : str
        Title of the plot. Defaults to "norm SAS".
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    """
    add_cycler_if_not_none(ax, cycler)
    keys = [
        v
        for v in res.data_vars
        if v.startswith(("species_associated_spectra", "species_spectra"))
    ]
    for key in keys:
        sas = res[key]
        # Scale each spectrum so its largest absolute value along ``spectral`` is 1.
        (sas / np.abs(sas).max(dim="spectral")).plot.line(x="spectral", ax=ax)
        ax.set_title(title)
        # xarray only adds a legend when there is a hue dimension (e.g. not for 1D data),
        # so ``get_legend`` may return None.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
def plot_sv_residual(
    res: xr.Dataset,
    ax: Axis,
    indices: Sequence[int] = range(10),
    cycler: Cycler | None = PlotStyle().cycler,
) -> None:
    """Plot singular values of the residual matrix.

    The weighted residual singular values are used if present in ``res``; otherwise
    the unweighted residual singular values are used. Values are drawn as red
    circles connected by lines on a logarithmic y-axis.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    ax : Axis
        Axis to plot on.
    indices : Sequence[int]
        Indices of the singular values to plot. The sequence is truncated to the
        number of available singular values. Defaults to range(10).
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    """
    add_cycler_if_not_none(ax, cycler)
    if "weighted_residual_singular_values" in res:
        rSV = res.weighted_residual_singular_values
    else:
        rSV = res.residual_singular_values
    # Truncate ``indices`` so we never request more singular values than exist.
    rSV.sel(singular_value_index=indices[: len(rSV.singular_value_index)]).plot.line(
        "ro-", yscale="log", ax=ax
    )
    ax.set_title("res. log(SV)")
def plot_rsv_residual(
    res: xr.Dataset,
    ax: Axis,
    indices: Sequence[int] = range(2),
    cycler: Cycler | None = PlotStyle().cycler,
    show_legend: bool = True,
) -> None:
    """Plot right singular vectors (spectra) of the residual matrix.

    The weighted residual right singular vectors are used if present in ``res``;
    otherwise the unweighted ones are used.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    ax : Axis
        Axis to plot on.
    indices : Sequence[int]
        Indices of the singular vector to plot. Defaults to range(2).
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    show_legend: bool
        Whether or not to show the legend. Defaults to True.
    """
    add_cycler_if_not_none(ax, cycler)
    if "weighted_residual_right_singular_vectors" in res:
        rRSV = res.weighted_residual_right_singular_vectors
    else:
        rRSV = res.residual_right_singular_vectors
    _plot_svd_vetors(rRSV, indices, "right_singular_value_index", ax, show_legend)
    ax.set_title("res. RSV")
def plot_concentrations(
    res: xr.Dataset,
    ax: Axis,
    center_λ: float | None,
    linlog: bool = False,
    linthresh: float = 1,
    linscale: float = 1,
    main_irf_nr: int = 0,
    cycler: Cycler | None = PlotStyle().cycler,
    title: str = "Concentrations",
) -> None:
    """Plot traces on the given axis ``ax``.

    The traces are obtained via ``get_shifted_traces``. If they carry a ``spectral``
    coordinate, only the traces at the spectral value nearest to ``center_λ`` are
    plotted.

    Parameters
    ----------
    res: xr.Dataset
        Result dataset from a pyglotaran optimization.
    ax: Axis
        Axis to plot the traces on
    center_λ: float | None
        Center wavelength (λ in nm)
    linlog: bool
        Whether to use 'symlog' scale or not. Defaults to False.
    linthresh: float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    linscale: float
        This allows the linear range (-linthresh to linthresh) to be stretched
        relative to the logarithmic range.
        Its value is the number of decades to use for each half of the linear range.
        For example, when linscale == 1.0 (the default), the space used for the
        positive and negative halves of the linear range will be equal to one
        decade in the logarithmic range. Defaults to 1.
    main_irf_nr: int
        Index of the main ``irf`` component when using an ``irf``
        parametrized with multiple peaks. Defaults to 0.
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    title: str
        Title used for the plot axis. Defaults to "Concentrations".

    See Also
    --------
    get_shifted_traces
    """
    add_cycler_if_not_none(ax, cycler)
    traces = get_shifted_traces(res, center_λ, main_irf_nr)

    if "spectral" in traces.coords:
        traces.sel(spectral=center_λ, method="nearest").plot.line(x="time", ax=ax)
    else:
        traces.plot.line(x="time", ax=ax)
    ax.set_title(title)

    if linlog:
        ax.set_xscale("symlog", linthresh=linthresh, linscale=linscale)
def plot_residual(
    res: xr.Dataset,
    ax: Axis,
    linlog: bool = False,
    linthresh: float = 1,
    show_data: bool = False,
    cycler: Cycler | None = PlotStyle().cycler,
) -> None:
    """Plot either the data or the residual of ``res`` on ``ax``.

    If one dimension has length 1, the values are drawn as lines along the longest
    dimension. Otherwise a 2D plot with ``time`` on the x-axis is drawn; its colorbar
    is suppressed when every dimension has at least 5 entries.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    ax : Axis
        Axis to plot on.
    linlog : bool
        Whether to use 'symlog' scale or not. Defaults to False.
    linthresh : float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    show_data : bool
        Whether to show the data (True) or the residual (False). Defaults to False.
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    """
    add_cycler_if_not_none(ax, cycler)

    if show_data:
        data, title = res.data, "dataset"
    else:
        data, title = res.residual, "residual"

    extents = np.array(data.shape)
    smallest_extent = min(extents)
    if smallest_extent == 1:
        # Degenerate 2D data: draw it as lines along the longest dimension.
        longest_dim = data.coords.dims[extents.argmax()]
        data.plot.line(x=longest_dim, ax=ax)
    elif smallest_extent < 5:
        data.plot(x="time", ax=ax)
    else:
        data.plot(x="time", ax=ax, add_colorbar=False)

    if linlog:
        ax.set_xscale("symlog", linthresh=linthresh)
    ax.set_title(title)
def plot_lsv_residual(
    res: xr.Dataset,
    ax: Axis,
    indices: Sequence[int] = range(2),
    linlog: bool = False,
    linthresh: float = 1,
    cycler: Cycler | None = PlotStyle().cycler,
    show_legend: bool = True,
) -> None:
    """Plot left singular vectors (time) of the residual matrix.

    The weighted residual left singular vectors are used if present in ``res``;
    otherwise the unweighted ones are used.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    ax : Axis
        Axis to plot on.
    indices : Sequence[int]
        Indices of the singular vector to plot. Defaults to range(2).
    linlog : bool
        Whether to use 'symlog' scale or not. Defaults to False.
    linthresh : float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    show_legend: bool
        Whether or not to show the legend. Defaults to True.
    """
    add_cycler_if_not_none(ax, cycler)
    if "weighted_residual_left_singular_vectors" in res:
        rLSV = res.weighted_residual_left_singular_vectors
    else:
        rLSV = res.residual_left_singular_vectors
    _plot_svd_vetors(rLSV, indices, "left_singular_value_index", ax, show_legend)
    ax.set_title("res. LSV")
    if linlog:
        ax.set_xscale("symlog", linthresh=linthresh)
def plot_simple_overview(
    result: DatasetConvertible,
    title: str | None = None,
    figsize: tuple[int, int] = (12, 6),
    cycler: Cycler | None = PlotStyle().cycler,
    figure_only: bool = True,
) -> Figure | tuple[Figure, Axes]:
    """Plot simple overview.

    The overview is a 2x3 grid showing:

    - the species concentrations (at the first spectral value if they are spectrally
      resolved)
    - the species associated spectra
    - the data
    - up to two left and two right singular vectors of the residual
    - the residual

    Parameters
    ----------
    result: DatasetConvertible
        Result from a pyglotaran optimization as dataset, Path or Result object.
    title: str | None
        Title of the figure. Defaults to None.
    figsize : tuple[int, int]
        Size of the figure (N, M) in inches. Defaults to (12, 6).
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    figure_only: bool
        Whether or not to only return the figure.
        This is a deprecation helper argument to transition to a consistent return value
        consisting of the :class:`Figure` and the :class:`Axes`. Defaults to True.

    Returns
    -------
    Figure|tuple[Figure, Axes]
        If ``figure_only`` is True, Figure object which contains the plots (deprecated).
        If ``figure_only`` is False, Figure object which contains the plots and the Axes.
    """
    res = load_data(result)

    fig, axes = plt.subplots(2, 3, figsize=figsize, constrained_layout=True)
    for ax in axes.flatten():
        add_cycler_if_not_none(ax, cycler)
    if title:
        fig.suptitle(title, fontsize=16)

    sas = res.species_associated_spectra
    traces = res.species_concentration
    if "spectral" in traces.coords:
        traces.sel(spectral=res.spectral.values[0], method="nearest").plot.line(
            x="time", ax=axes[0, 0]
        )
    else:
        traces.plot.line(x="time", ax=axes[0, 0])
    sas.plot.line(x="spectral", ax=axes[0, 1])

    # Count singular vectors via the size of their index dimension: ``len()`` of a
    # DataArray is the length of its FIRST dimension, which need not be that index.
    rLSV = res.residual_left_singular_vectors
    n_lsv = min(2, rLSV.sizes["left_singular_value_index"])
    rLSV.isel(left_singular_value_index=range(n_lsv)).plot.line(x="time", ax=axes[1, 0])
    axes[1, 0].set_title("res. LSV")

    rRSV = res.residual_right_singular_vectors
    n_rsv = min(2, rRSV.sizes["right_singular_value_index"])
    rRSV.isel(right_singular_value_index=range(n_rsv)).plot.line(
        x="spectral", ax=axes[1, 1]
    )
    axes[1, 1].set_title("res. RSV")

    res.data.plot(x="time", ax=axes[0, 2])
    axes[0, 2].set_title("data")
    res.residual.plot(x="time", ax=axes[1, 2])
    axes[1, 2].set_title("residual")

    if figure_only is True:
        warn(PyglotaranExtrasApiDeprecationWarning(FIG_ONLY_WARNING), stacklevel=2)
        return fig
    else:
        return fig, axes
def plot_doas(
    result: DatasetConvertible,
    figsize: tuple[int, int] = (25, 25),
    cycler: Cycler | None = PlotStyle().cycler,
    figure_only: bool = True,
) -> Figure | tuple[Figure, Axes]:
    """Plot Damped oscillation associated spectra (DOAS).

    Draws a 6x3 grid with the SAS, the DAS, the species concentrations and, if present,
    the damped oscillation quantities, plus residual SVD information.

    Parameters
    ----------
    result: DatasetConvertible
        Result from a pyglotaran optimization as dataset, Path or Result object.
    figsize : tuple[int, int]
        Size of the figure (N, M) in inches. Defaults to (25, 25).
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    figure_only: bool
        Whether or not to only return the figure.
        This is a deprecation helper argument to transition to a consistent return value
        consisting of the :class:`Figure` and the :class:`Axes`. Defaults to True.

    Returns
    -------
    Figure|tuple[Figure, Axes]
        If ``figure_only`` is True, Figure object which contains the plots (deprecated).
        If ``figure_only`` is False, Figure object which contains the plots and the Axes.
    """
    dataset = load_data(result)

    # Create M x N plotting grid
    M = 6
    N = 3

    fig, axes = plt.subplots(M, N, figsize=figsize)

    for ax in axes.flatten():
        add_cycler_if_not_none(ax, cycler)

    # Plot data
    dataset.species_associated_spectra.plot.line(x="spectral", ax=axes[0, 0])
    dataset.decay_associated_spectra.plot.line(x="spectral", ax=axes[0, 1])

    if "spectral" in dataset.species_concentration.coords:
        dataset.species_concentration.isel(spectral=0).plot.line(x="time", ax=axes[1, 0])
    else:
        dataset.species_concentration.plot.line(x="time", ax=axes[1, 0])
    # ``linthresh`` is the current matplotlib keyword (``linthreshx`` was removed).
    axes[1, 0].set_xscale("symlog", linthresh=1)

    if "dampened_oscillation_associated_spectra" in dataset:
        dataset.dampened_oscillation_cos.isel(spectral=0).sel(time=slice(-1, 10)).plot.line(
            x="time", ax=axes[1, 1]
        )
        dataset.dampened_oscillation_associated_spectra.plot.line(x="spectral", ax=axes[2, 0])
        dataset.dampened_oscillation_phase.plot.line(x="spectral", ax=axes[2, 1])

    dataset.residual_left_singular_vectors.isel(left_singular_value_index=0).plot(
        ax=axes[0, 2]
    )
    dataset.residual_singular_values.plot.line("ro-", yscale="log", ax=axes[1, 2])
    dataset.residual_right_singular_vectors.isel(right_singular_value_index=0).plot(
        ax=axes[2, 2]
    )

    fig.tight_layout(pad=5, w_pad=2.0, h_pad=2.0)
    if figure_only is True:
        warn(PyglotaranExtrasApiDeprecationWarning(FIG_ONLY_WARNING), stacklevel=2)
        return fig
    else:
        return fig, axes
def plot_data_and_fits(
    result: ResultLike,
    wavelength: float,
    axis: Axis,
    center_λ: float | None = None,
    main_irf_nr: int = 0,
    linlog: bool = False,
    linthresh: float = 1,
    divide_by_scale: bool = True,
    per_axis_legend: bool = False,
    y_label: str = "a.u.",
    cycler: Cycler | None = PlotStyle().data_cycler_solid,
) -> None:
    """Plot data and fits for a given ``wavelength`` on a given ``axis``.

    If the wavelength isn't part of a dataset, that dataset will be skipped.
    For a skipped dataset the axis property cycler is still advanced by two entries
    (data + fit), so later datasets keep the cycler position they would have had.

    Parameters
    ----------
    result : ResultLike
        Data structure which can be converted to a mapping.
    wavelength : float
        Wavelength to plot data and fits for.
    axis: Axis
        Axis to plot the data and fits on.
    center_λ: float | None
        Center wavelength (λ in nm)
    main_irf_nr : int
        Index of the main ``irf`` component when using an ``irf``
        parametrized with multiple peaks. Defaults to 0.
    linlog : bool
        Whether to use 'symlog' scale or not. Defaults to False.
    linthresh : float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    divide_by_scale : bool
        Whether or not to divide the data by the dataset scale used for optimization.
        Defaults to True.
    per_axis_legend: bool
        Whether to use a legend per plot or for the whole figure. Defaults to False.
    y_label: str
        Label used for the y-axis of each subplot. Defaults to "a.u.".
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().data_cycler_solid.

    See Also
    --------
    plot_fit_overview
    """
    result_map = result_dataset_mapping(result)
    add_cycler_if_not_none(axis, cycler)
    for dataset_name, dataset in result_map.items():
        spectral_coords = dataset.coords["spectral"].values
        if spectral_coords.min() <= wavelength <= spectral_coords.max():
            result_data = dataset.sel(spectral=[wavelength], method="nearest")
            scale = extract_dataset_scale(result_data, divide_by_scale)
            irf_loc = extract_irf_location(result_data, center_λ, main_irf_nr)
            # Shift the time axis so the IRF location is at 0.
            result_data = result_data.assign_coords(
                time=result_data.coords["time"] - irf_loc
            )
            (result_data.data / scale).plot(x="time", ax=axis, label=f"{dataset_name}_data")
            (result_data.fitted_data / scale).plot(
                x="time", ax=axis, label=f"{dataset_name}_fit"
            )
        else:
            # Skip the two cycler entries (data + fit) this dataset would have used.
            for _ in range(2):
                next(axis._get_lines.prop_cycler)
    if linlog:
        axis.set_xscale("symlog", linthresh=linthresh)
    axis.set_ylabel(y_label)
    if per_axis_legend is True:
        axis.legend()