Example #1
0
def plot_das(res: xr.Dataset,
             ax: Axis,
             title: str = "DAS",
             cycler: Cycler | None = PlotStyle().cycler) -> None:
    """Plot DAS (Decay Associated Spectra) on ``ax``.

    Every data variable of ``res`` whose name starts with
    ``"decay_associated_spectra"`` or ``"species_spectra"`` is drawn as a
    line plot over the ``spectral`` coordinate. Any legend present on ``ax``
    after each plot is removed.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    ax : Axis
        Axis to plot on.
    title : str
        Title of the plot. Defaults to "DAS".
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    """
    add_cycler_if_not_none(ax, cycler)
    keys = [
        v for v in res.data_vars
        if v.startswith(("decay_associated_spectra", "species_spectra"))
    ]
    for key in keys:
        das = res[key]
        das.plot.line(x="spectral", ax=ax)
        ax.set_title(title)
        # ``get_legend()`` returns None when the axis has no legend (e.g. a
        # plot that drew a single line), so guard before removing it.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
Example #2
0
def plot_rsv_residual(
    res: xr.Dataset,
    ax: Axis,
    indices: Sequence[int] = range(2),
    cycler: Cycler | None = PlotStyle().cycler,
    show_legend: bool = True,
) -> None:
    """Plot right singular vectors (spectra) of the residual matrix.

    The weighted residual vectors are used when present in ``res``,
    otherwise the unweighted ``residual_right_singular_vectors``.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    ax : Axis
        Axis to plot on.
    indices : Sequence[int]
        Indices of the singular vectors to plot. Defaults to range(2).
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    show_legend : bool
        Whether or not to show the legend. Defaults to True.
    """
    add_cycler_if_not_none(ax, cycler)
    if "weighted_residual_right_singular_vectors" in res:
        rRSV = res.weighted_residual_right_singular_vectors
    else:
        rRSV = res.residual_right_singular_vectors
    _plot_svd_vetors(rRSV, indices, "right_singular_value_index", ax,
                     show_legend)
    ax.set_title("res. RSV")
Example #3
0
def plot_norm_sas(res: xr.Dataset,
                  ax: Axis,
                  title: str = "norm SAS",
                  cycler: Cycler | None = PlotStyle().cycler) -> None:
    """Plot normalized SAS (Species Associated Spectra) on ``ax``.

    Every data variable of ``res`` whose name starts with
    ``"species_associated_spectra"`` or ``"species_spectra"`` is divided by
    its maximum absolute value along ``spectral`` and drawn as a line plot
    over that coordinate. Any legend present on ``ax`` after each plot is
    removed.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    ax : Axis
        Axis to plot on.
    title : str
        Title of the plot. Defaults to "norm SAS".
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    """
    add_cycler_if_not_none(ax, cycler)
    keys = [
        v for v in res.data_vars
        if v.startswith(("species_associated_spectra", "species_spectra"))
    ]
    for key in keys:
        sas = res[key]
        (sas / np.abs(sas).max(dim="spectral")).plot.line(x="spectral", ax=ax)
        ax.set_title(title)
        # ``get_legend()`` returns None when the axis has no legend (e.g. a
        # plot that drew a single line), so guard before removing it.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
Example #4
0
def plot_sv_residual(
    res: xr.Dataset,
    ax: Axis,
    indices: Sequence[int] = range(10),
    cycler: Cycler | None = PlotStyle().cycler,
) -> None:
    """Plot singular values of the residual matrix on a log scale.

    The weighted residual singular values are used when present in ``res``,
    otherwise the unweighted ``residual_singular_values``.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    ax : Axis
        Axis to plot on.
    indices : Sequence[int]
        Indices (labels of ``singular_value_index``) of the singular values
        to plot. Only the first ``len(singular_value_index)`` entries are
        used. Defaults to range(10).
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    """
    add_cycler_if_not_none(ax, cycler)
    if "weighted_residual_singular_values" in res:
        rSV = res.weighted_residual_singular_values
    else:
        rSV = res.residual_singular_values
    rSV.sel(singular_value_index=indices[:len(rSV.singular_value_index)]
            ).plot.line("ro-", yscale="log", ax=ax)
    ax.set_title("res. log(SV)")
def plot_concentrations(
    res: xr.Dataset,
    ax: Axis,
    center_λ: float | None,
    linlog: bool = False,
    linthresh: float = 1,
    linscale: float = 1,
    main_irf_nr: int = 0,
    cycler: Cycler | None = PlotStyle().cycler,
    title: str = "Concentrations",
) -> None:
    """Plot traces on the given axis ``ax``.

    The traces returned by ``get_shifted_traces`` are plotted over ``time``.
    If they carry a ``spectral`` coordinate, only the trace nearest to
    ``center_λ`` is plotted.

    Parameters
    ----------
    res: xr.Dataset
        Result dataset from a pyglotaran optimization.
    ax: Axis
        Axis to plot the traces on
    center_λ: float | None
        Center wavelength (λ in nm)
    linlog: bool
        Whether to use 'symlog' scale or not. Defaults to False.
    linthresh: float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    linscale: float
        This allows the linear range (-linthresh to linthresh) to be stretched
        relative to the logarithmic range.
        Its value is the number of decades to use for each half of the linear range.
        For example, when linscale == 1.0 (the default), the space used for the
        positive and negative halves of the linear range will be equal to one
        decade in the logarithmic range. Defaults to 1.
    main_irf_nr: int
        Index of the main ``irf`` component when using an ``irf``
        parametrized with multiple peaks. Defaults to 0.
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    title: str
        Title used for the plot axis. Defaults to "Concentrations".

    See Also
    --------
    get_shifted_traces
    """
    add_cycler_if_not_none(ax, cycler)
    traces = get_shifted_traces(res, center_λ, main_irf_nr)

    if "spectral" in traces.coords:
        traces.sel(spectral=center_λ, method="nearest").plot.line(x="time",
                                                                  ax=ax)
    else:
        traces.plot.line(x="time", ax=ax)
    ax.set_title(title)

    if linlog:
        ax.set_xscale("symlog", linthresh=linthresh, linscale=linscale)
Example #6
0
 def _adjust(self, ax: Axis, to: Union[bool, float], default: float,
             side: int) -> float:
     """Resolve a limit value from the ``to`` setting.

     - ``to is False``: return ``default``.
     - ``to`` convertible by ``float()`` (and not ``True``): return it as float.
     - otherwise (including ``to is True``): return ``vals[side]`` where
       ``vals`` are the major tick locations of ``ax`` if ``self.major`` is
       truthy, else the minor tick locations.

     :param ax: Object providing ``get_majorticklocs``/``get_minorticklocs``.
     :param to: Limit setting (False, True, or a numeric value).
     :param default: Value returned when ``to is False``.
     :param side: Index into the tick locations (NOTE(review): presumably
         0 for the low end and -1 for the high end — confirm with callers).
     :return: The resolved limit.
     """
     if to is False:
         return default
     # ``float(True)`` is 1.0, so ``True`` must skip the numeric conversion;
     # otherwise the tick-based branch below could never be reached for it.
     if to is not True:
         try:
             return float(to)
         except ValueError:
             # Not convertible to a number: fall back to the tick locations.
             pass
     vals = ax.get_majorticklocs() if self.major else ax.get_minorticklocs()
     return vals[side]
Example #7
0
def set_axis_tick_label_rotation(ax: Axis, rotation: int):
    """
    Rotate the major and minor tick labels of ``ax``.

    Each label group (major, then minor) is only touched when it is
    non-empty.

    :param ax: The axis whose tick labels should be rotated.
    :param rotation: Rotation passed to ``plt.setp`` as ``rotation``.
    """
    for fetch_labels in (ax.get_majorticklabels, ax.get_minorticklabels):
        tick_labels = fetch_labels()
        if not tick_labels:
            continue
        plt.setp(tick_labels, rotation=rotation)
Example #8
0
def transform_axis_tick_labels(ax: Axis, transformation: FunctionType):
    """
    Rewrite every tick label of ``ax`` through ``transformation``.

    The canvas is drawn first so the label texts are populated. Each
    label's text is then replaced by ``transformation(old_text)``, and the
    labels are written back with ``ax.set_ticklabels``.

    :param ax: The axis whose tick labels to transform.
    :param transformation: Function mapping old label text to new text,
        e.g. `lambda t: t.split('T')[0]`.
    """
    canvas = ax.figure.canvas
    canvas.draw()  # tick label texts are only available after a draw
    tick_labels = ax.get_ticklabels()
    for tick_label in tick_labels:
        tick_label.set_text(transformation(tick_label.get_text()))
    ax.set_ticklabels(tick_labels)
Example #9
0
def _remove_labels_from_axis(axis: Axis):
    """Hide the major tick labels, minor tick labels and label of ``axis``."""

    def _hide(texts):
        # Make every text artist in ``texts`` invisible.
        for text in texts:
            text.set_visible(False)

    _hide(axis.get_majorticklabels())

    # With the default NullLocator / NullFormatter pair there are no minor
    # tick labels for set_visible to act on, so install a real locator and
    # an empty-string formatter before hiding them.
    needs_locator = isinstance(axis.get_minor_locator(), ticker.NullLocator)
    if needs_locator:
        axis.set_minor_locator(ticker.AutoLocator())
    needs_formatter = isinstance(axis.get_minor_formatter(), ticker.NullFormatter)
    if needs_formatter:
        axis.set_minor_formatter(ticker.FormatStrFormatter(""))
    _hide(axis.get_minorticklabels())

    _hide([axis.get_label()])
Example #10
0
    def add_to_plot(self,
                    ax: Axis,
                    N: int = 201,
                    xrange: Tuple[float, float] = None,
                    xscale: float = 1,
                    yscale: float = 1,
                    **kwargs):
        """Add fit to existing plot axis

        Args:
            ax: Axis to add plot to
            N: number of points to use as x values (to smooth the fit curve)
            xrange: Optional range for x values (min, max). When None, the
                range spanned by ``self.xvals`` is used.
            xscale: value to multiply x values by to rescale axis
            yscale: value to multiply y values by to rescale axis
            kwargs: Additional plot kwargs. They are merged over
                ``self.plot_kwargs`` (keys given here take precedence).

        Returns:
            plot_handle of fit curve (also stored as ``self.plot_handle``)
        """
        if xrange is None:
            x_vals = self.xvals
            x_vals_full = np.linspace(min(x_vals), max(x_vals), N)
        else:
            x_vals_full = np.linspace(*xrange, N)

        # The fit is evaluated on the unscaled x values; scaling is for
        # display only.
        y_vals_full = self.fit_result.eval(
            **{self.sweep_parameter: x_vals_full})
        # Scale out of place. ``*=`` writes into the arrays themselves: if
        # ``eval`` hands back (a view of) its x input, y would be scaled by
        # xscale too, and integer-valued results fail with float scales.
        x_vals_full = x_vals_full * xscale
        y_vals_full = y_vals_full * yscale
        plot_kwargs = {**self.plot_kwargs, **kwargs}
        self.plot_handle, = ax.plot(x_vals_full, y_vals_full, **plot_kwargs)
        return self.plot_handle
Example #11
0
def add_cycler_if_not_none(axis: Axis, cycler: Cycler | None) -> None:
    """Add cycler to an axis if it is not None.

    This is a convenience function that allows opting out of using
    a cycler, which is needed to run a plotting function in a loop
    where the cycler is controlled from the outside.

    Parameters
    ----------
    axis: Axis
        Axis whose property cycle is set.
    cycler: Cycler | None
        Plot style cycler to use. ``None`` leaves the axis' current
        property cycle untouched.
    """
    if cycler is not None:
        axis.set_prop_cycle(cycler)
def plot_residual(
    res: xr.Dataset,
    ax: Axis,
    linlog: bool = False,
    linthresh: float = 1,
    show_data: bool = False,
    cycler: Cycler | None = PlotStyle().cycler,
) -> None:
    """Plot either the data or the residual of ``res`` on ``ax``.

    The plot type depends on the smallest dimension size:

    - size 1: line plot along the longest dimension;
    - size below 5: default xarray plot over ``time``;
    - otherwise: default xarray plot over ``time`` without a colorbar.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    ax : Axis
        Axis to plot on.
    linlog : bool
        Whether to use 'symlog' scale or not. Defaults to False.
    linthresh : float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    show_data : bool
        If True plot ``res.data`` ("dataset"), else ``res.residual``
        ("residual"). Defaults to False.
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    """
    add_cycler_if_not_none(ax, cycler)
    if show_data:
        data, title = res.data, "dataset"
    else:
        data, title = res.residual, "residual"

    extents = np.array(data.shape)
    smallest_extent = min(extents)
    if smallest_extent == 1:
        longest_dim = data.coords.dims[extents.argmax()]
        data.plot.line(x=longest_dim, ax=ax)
    elif smallest_extent < 5:
        data.plot(x="time", ax=ax)
    else:
        data.plot(x="time", ax=ax, add_colorbar=False)

    if linlog:
        ax.set_xscale("symlog", linthresh=linthresh)
    ax.set_title(title)
Example #13
0
 def plot_track(
     self,
     axis: Axis = None,
     show: bool = False,
     color: str = 'k',
     coastline: bool = True,
     **kwargs,
 ):
     """Draw each row of ``self.data`` as a quiver arrow on ``axis``.

     Arrows are placed at (longitude, latitude) and point along
     ``direction`` (nautical degrees) with magnitude ``speed``. Every sixth
     row is annotated with its ``datetime``.

     :param axis: Axis to draw on; a new figure with one subplot is created
         when None.
     :param show: If True the axis is set to 'scaled'; also forwarded to
         ``plot_coastline``.
     :param color: Arrow colour; overrides any ``color`` in ``kwargs``.
     :param coastline: If truthy, ``plot_coastline(axis, show)`` is called.
     :param kwargs: Extra keyword arguments forwarded to ``axis.quiver``.
     """
     kwargs['color'] = color
     if axis is None:
         axis = pyplot.figure().add_subplot(111)
     for row_number, (_, row) in enumerate(self.data.iterrows()):
         position = (row['longitude'], row['latitude'])
         # Nautical degrees: U is the sine and V the cosine of the heading.
         heading = numpy.deg2rad(row['direction'])
         u_component = row['speed'] * numpy.sin(heading)
         v_component = row['speed'] * numpy.cos(heading)
         axis.quiver(*position, u_component, v_component, **kwargs)
         if row_number % 6 == 0:
             axis.annotate(row['datetime'], position)
     if show:
         axis.axis('scaled')
     if coastline:
         plot_coastline(axis, show)
Example #14
0
def _plot_svd_vetors(
    vector_data: xr.DataArray,
    indices: Sequence[int],
    sv_index_dim: str,
    ax: Axis,
    show_legend: bool,
) -> None:
    """Plot SVD vectors with decreasing zorder on axis ``ax``.

    Each selected vector is plotted against the non-singular-value
    dimension. Its legend label is the singular-value index it was selected
    with, and earlier vectors are drawn on top of later ones. At most 100
    vectors are drawn (the zorder sequence is ``range(100)`` reversed).

    Parameters
    ----------
    vector_data: xr.DataArray
        DataArray containing the SVD vector data.
    indices: Sequence[int]
        Indices of the singular vector to plot. Only the first
        ``len(vector_data[sv_index_dim])`` entries are used.
    sv_index_dim: str
        Name of the singular value index dimension.
    ax: Axis
        Axis to plot on.
    show_legend: bool
        Whether or not to show the legend.

    See Also
    --------
    plot_lsv_data
    plot_rsv_data
    plot_lsv_residual
    plot_rsv_residual
    """
    max_index = len(getattr(vector_data, sv_index_dim))
    selected_indices = indices[:max_index]
    values = vector_data.isel(**{sv_index_dim: selected_indices})
    x_dim = vector_data.dims[1]
    if x_dim == sv_index_dim:
        # Put the singular value index first so iteration yields one vector
        # per selected index.
        values = values.T
        x_dim = vector_data.dims[0]
    # Label with the actual singular value index, not the plotting position,
    # so custom ``indices`` (e.g. [2, 5]) get a correct legend.
    for zorder, sv_index, value in zip(range(100)[::-1], selected_indices,
                                       values):
        value.plot.line(x=x_dim, ax=ax, zorder=zorder, label=sv_index)
    if show_legend is True:
        ax.legend(title=sv_index_dim)
Example #15
0
def plot_lsv_residual(
    res: xr.Dataset,
    ax: Axis,
    indices: Sequence[int] = range(2),
    linlog: bool = False,
    linthresh: float = 1,
    cycler: Cycler | None = PlotStyle().cycler,
    show_legend: bool = True,
) -> None:
    """Plot left singular vectors (time) of the residual matrix.

    The weighted residual vectors are used when present in ``res``,
    otherwise the unweighted ``residual_left_singular_vectors``.

    Parameters
    ----------
    res : xr.Dataset
        Result dataset
    ax : Axis
        Axis to plot on.
    indices : Sequence[int]
        Indices of the singular vectors to plot. Defaults to range(2).
    linlog : bool
        Whether to use 'symlog' scale or not. Defaults to False.
    linthresh : float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().cycler.
    show_legend : bool
        Whether or not to show the legend. Defaults to True.
    """
    add_cycler_if_not_none(ax, cycler)
    if "weighted_residual_left_singular_vectors" in res:
        rLSV = res.weighted_residual_left_singular_vectors
    else:
        rLSV = res.residual_left_singular_vectors
    _plot_svd_vetors(rLSV, indices, "left_singular_value_index", ax,
                     show_legend)
    ax.set_title("res. LSV")
    if linlog:
        ax.set_xscale("symlog", linthresh=linthresh)
Example #16
0
def plot_interactions(locations: List[str],
                      latent_graph: ndarray,
                      map: Basemap,
                      ax: Axis,
                      skip_first: bool = False):
    """
    Plot stations and the edges of a latent interaction graph on a map.

    Each station is drawn as a yellow marker labelled "Station <i>". For
    every ordered pair (i, j) and edge type whose entry
    ``latent_graph[edge_type, i, j]`` exceeds 0.5, a great circle is drawn
    between the two stations, coloured and labelled by edge type. Duplicate
    legend labels are collapsed to their first occurrence.

    :param locations: Station coordinates. Each entry is unpacked into
        ``map(...)`` and its first two items are passed to
        ``map.drawgreatcircle`` (NOTE(review): annotated ``List[str]`` but
        used as coordinate pairs, presumably (lon, lat) — confirm).
    :param latent_graph: Array indexed ``[edge_type, i, j]``.
    :param map: Basemap used for projection and background drawing.
    :param ax: Axis that receives the station markers and the legend.
    :param skip_first: If True, edges of type 0 are not drawn.
    :return: ``ax``.
    """
    # Transform coordinates into map-projection (pixel) coordinates.
    pixel_coords = [map(*coords) for coords in locations]

    # Draw background relief and country borders.
    map.shadedrelief()
    map.drawcountries()

    # Plot locations of the weather stations. The format string only sets
    # the marker: also giving a colour there (e.g. 'ok') together with
    # ``color=`` makes matplotlib warn about a redundant definition.
    for i, (x, y) in enumerate(pixel_coords):
        ax.plot(x, y, 'o', markersize=10, color='yellow')
        ax.text(x + 10, y + 10, "Station " + str(i), fontsize=20, color='yellow')

    # Infer number of edge types and atoms from latent graph.
    n_atoms = latent_graph.shape[-1]
    n_edge_types = latent_graph.shape[0]

    color_map = get_cmap('Set1')
    first_edge_type = 1 if skip_first else 0

    for i in range(n_atoms):
        for j in range(n_atoms):
            for edge_type in range(first_edge_type, n_edge_types):
                if latent_graph[edge_type, i, j] > 0.5:
                    start = locations[i]
                    end = locations[j]
                    map.drawgreatcircle(start[0], start[1], end[0], end[1],
                                        color=color_map(edge_type - 1),
                                        label=str(edge_type))

    # Keep only the first handle for each label so every edge type appears
    # once in the legend.
    handles, labels = ax.get_legend_handles_labels()
    seen_labels = set()
    unique = []
    for handle, label in zip(handles, labels):
        if label not in seen_labels:
            seen_labels.add(label)
            unique.append((handle, label))
    ax.legend(*zip(*unique))
    return ax
Example #17
0
    def add_to_plot(self,
                    ax: Axis,
                    N: int = 201,
                    xrange: Tuple[float, float] = None,
                    xscale: float = 1,
                    yscale: float = 1,
                    **kwargs):
        """Add fit to existing plot axis

        Args:
            ax: Axis to add plot to
            N: number of points to use as x values (to smooth the fit curve)
            xrange: Optional range for x values (min, max). When None, the
                range spanned by ``self.xvals`` is used.
            xscale: value to multiply x values by to rescale axis
            yscale: value to multiply y values by to rescale axis
            kwargs: Additional plot kwargs. They are merged over
                ``self.default_plot_kwargs`` after alias normalization
                (keys given here take precedence).

        Returns:
            plot_handle of fit curve (also stored as ``self.plot_handle``)
        """
        if xrange is None:
            x_vals = self.xvals
            x_vals_full = np.linspace(min(x_vals), max(x_vals), N)
        else:
            x_vals_full = np.linspace(*xrange, N)

        # The fit is evaluated on the unscaled x values; scaling is for
        # display only.
        y_vals_full = self.fit_result.eval(
            **{self.sweep_parameter: x_vals_full})
        # Scale out of place. ``*=`` writes into the arrays themselves: if
        # ``eval`` hands back (a view of) its x input, y would be scaled by
        # xscale too, and integer-valued results fail with float scales.
        x_vals_full = x_vals_full * xscale
        y_vals_full = y_vals_full * yscale

        # Set default plot kwargs while de-aliasing (e.g. 'lw' -> 'linewidth')
        # kwargs to prevent duplicate keys
        kwargs = {
            **self.default_plot_kwargs,
            **cbook.normalize_kwargs(kwargs, mlines.Line2D)
        }
        self.plot_handle, = ax.plot(x_vals_full, y_vals_full, **kwargs)
        return self.plot_handle
Example #18
0
 def restore_axis_state(axis: Axis, state: dict):
     """Re-apply the grid setting stored in ``state['grid']`` to ``axis``.

     A truthy ``state['grid']`` is unpacked as keyword arguments for
     ``axis.grid(True, ...)``; a falsy one turns the grid off.
     """
     grid_kwargs = state['grid']
     if not grid_kwargs:
         axis.grid(False)
         return
     axis.grid(True, **grid_kwargs)
Example #19
0
def plot_data_and_fits(
    result: ResultLike,
    wavelength: float,
    axis: Axis,
    center_λ: float | None = None,
    main_irf_nr: int = 0,
    linlog: bool = False,
    linthresh: float = 1,
    divide_by_scale: bool = True,
    per_axis_legend: bool = False,
    y_label: str = "a.u.",
    cycler: Cycler | None = PlotStyle().data_cycler_solid,
) -> None:
    """Plot data and fits for a given ``wavelength`` on a given ``axis``.

    If the wavelength isn't part of a dataset, that dataset will be skipped.
    The property cycle is still advanced by two entries for it, so that each
    dataset keeps the same colours across axes.

    Parameters
    ----------
    result : ResultLike
        Data structure which can be converted to a mapping.
    wavelength : float
        Wavelength to plot data and fits for.
    axis: Axis
        Axis to plot the data and fits on.
    center_λ: float | None
        Center wavelength (λ in nm)
    main_irf_nr : int
        Index of the main ``irf`` component when using an ``irf``
        parametrized with multiple peaks. Defaults to 0.
    linlog : bool
        Whether to use 'symlog' scale or not. Defaults to False.
    linthresh : float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    divide_by_scale : bool
        Whether or not to divide the data by the dataset scale used for optimization.
        Defaults to True.
    per_axis_legend: bool
        Whether to use a legend per plot or for the whole figure. Defaults to False.
    y_label: str
        Label used for the y-axis of each subplot. Defaults to "a.u.".
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().data_cycler_solid.

    See Also
    --------
    plot_fit_overview
    """
    result_map = result_dataset_mapping(result)
    add_cycler_if_not_none(axis, cycler)
    for dataset_name in result_map.keys():
        spectral_coords = result_map[dataset_name].coords["spectral"].values
        if spectral_coords.min() <= wavelength <= spectral_coords.max():
            result_data = result_map[dataset_name].sel(spectral=[wavelength],
                                                       method="nearest")
            scale = extract_dataset_scale(result_data, divide_by_scale)
            irf_loc = extract_irf_location(result_data, center_λ, main_irf_nr)
            result_data = result_data.assign_coords(
                time=result_data.coords["time"] - irf_loc)
            (result_data.data / scale).plot(x="time",
                                            ax=axis,
                                            label=f"{dataset_name}_data")
            (result_data.fitted_data / scale).plot(x="time",
                                                   ax=axis,
                                                   label=f"{dataset_name}_fit")
        else:
            # Skip the two cycle entries (data + fit) this dataset would have
            # used. This relies on private matplotlib API: ``prop_cycler``
            # exists before matplotlib 3.8; newer versions only offer
            # ``get_next_color()``, which advances the cycle when it contains
            # a colour.
            line_factory = axis._get_lines
            for _ in range(2):
                if hasattr(line_factory, "prop_cycler"):
                    next(line_factory.prop_cycler)
                else:
                    line_factory.get_next_color()
    if linlog:
        axis.set_xscale("symlog", linthresh=linthresh)
    axis.set_ylabel(y_label)
    if per_axis_legend is True:
        axis.legend()
Example #20
0
def file_vs_value_plot(
    a_x: Axis, field_name: str, row: pd.DataFrame, range_columns: List[str], fontsize: float, pad: float
) -> None:
    """Create a dot plot with one point per file.

    The first ``len(range_columns)`` values of ``row`` are scattered against
    their file position.
    - For ``"rt_peak"``, a red line marks ``row["atlas RT peak"]``, and the
      y-limits span the values in ``range_columns`` plus that atlas value,
      with a 0.12 margin.
    - For ``"peak_height"``, the y-axis is logarithmic from 1e4 to 1e10.

    :param a_x: Axis to draw on.
    :param field_name: Either "rt_peak" or "peak_height".
    :param row: Row of values indexed by column name (used like a Series:
        ``row.name``, ``row[...]``, ``row.loc[...]``).
    :param range_columns: Names of the per-file columns; not modified.
    :param fontsize: Font size for tick labels, title and axis labels.
    :param pad: Padding for ticks, title and axis labels.
    """
    assert field_name in ["rt_peak", "peak_height"]
    a_x.tick_params(direction="in", length=1, pad=pad, width=0.1, labelsize=fontsize)
    num_files = len(range_columns)
    # Explicitly positional: the first ``num_files`` entries are the per-file values.
    a_x.scatter(range(num_files), row.iloc[:num_files], s=0.2)
    if field_name == "rt_peak":
        a_x.axhline(y=row["atlas RT peak"], color="r", linestyle="-", linewidth=0.2)
        # Build a new list: ``range_columns += [...]`` would extend the
        # caller's list on every call.
        limit_values = row.loc[[*range_columns, "atlas RT peak"]]
        a_x.set_ylim(np.nanmin(limit_values) - 0.12, np.nanmax(limit_values) + 0.12)
    else:
        a_x.set_yscale("log")
        a_x.set_ylim(bottom=1e4, top=1e10)
    a_x.set_xlim(-0.5, num_files + 0.5)
    a_x.xaxis.set_major_locator(mticker.FixedLocator(np.arange(0, num_files, 1.0)))
    for spine in a_x.spines.values():
        spine.set_linewidth(0.1)
    # truncate name so it fits above a single subplot
    a_x.set_title(row.name[:33], pad=pad, fontsize=fontsize)
    a_x.set_xlabel("Files", labelpad=pad, fontsize=fontsize)
    ylabel = "Actual RTs" if field_name == "rt_peak" else "Peak Height"
    a_x.set_ylabel(ylabel, labelpad=pad, fontsize=fontsize)
Example #21
0
# Bar chart of the daily Selyaninov hydrothermal coefficient over the
# growing season. NOTE(review): ``begin_date1``, ``end_date1`` and ``var1``
# must be defined earlier in the script; they are not visible here.
delta = dt.timedelta(days=1)
dates = pd.date_range(begin_date1, end_date1, freq='1D')
print(dates)

fig, ax = plt.subplots()
ax.bar(dates, var1)

# Extend the right edge by one day, presumably so the last bar is not
# clipped at the axis limit.
ax.set_xlim(dates[0], dates[-1] + delta)

# Major ticks every 7 days (labelled dd.mm), minor ticks every day.
ax.xaxis.set_major_locator(DayLocator(interval=7))
ax.xaxis.set_minor_locator(DayLocator())
ax.xaxis.set_major_formatter(DateFormatter('%d.%m'))

# Format used for the x value in the interactive coordinate readout.
ax.fmt_xdata = DateFormatter('%Y-%m-%d %H:%M:%S')
fig.autofmt_xdate()

# Title (Russian): "Diagram of the variability of the daily Selyaninov
# hydrothermal coefficient / <year> / growing season from <dd.mm> to <dd.mm>".
fig.suptitle(
    "Диаграмма изменчивости суточного гидротермического коэффициента Селянинова\n"
    + r"%s год" % (begin_date1.strftime("%Y")) + "\n" +
    r"вегетационный период с %s по %s" %
    (begin_date1.strftime("%d.%m"), end_date1.strftime("%d.%m")),
    fontsize=12,
    fontweight='bold')

plt.show()
Example #22
0
def plot_heated_stacked_area(df: pd.DataFrame,
                             lines: str,
                             heat: str,
                             backtest: str = None,
                             reset_y_lim: bool = False,
                             figsize: Tuple[int, int] = (16, 9),
                             color_map: str = 'afmhot',
                             ax: Axis = None,
                             upper_lower_missing_scale: float = 0.05) -> Axis:
    """Fill the bands between consecutive ``lines`` columns, coloured by ``heat``.

    Band ``ci`` (between line columns ``ci`` and ``ci + 1``) is filled
    segment by segment. Each segment is coloured by ``color_map`` applied to
    the matching ``heat`` values (NOTE(review): presumably already in
    [0, 1]). If ``heat`` has one more column than ``lines``, constant edges
    are added below the minimum and above the maximum of the lines, so every
    heat column gets a band. The final value of each band edge except the
    lowest is annotated.

    Args:
        df: Data frame whose index provides the x values. The last index
            value is passed to ``mdates.date2num`` for the annotations.
        lines: Column selection for the band edges; must yield a 2-D array
            (NOTE(review): annotated ``str``, but a single column name yields
            1-D data and fails the assertion — presumably a list of names).
        heat: Column selection for the band colours; must yield a 2-D array.
        backtest: Optional column name plotted as an extra line.
        reset_y_lim: If True, y limits are set to
            ``(y[:, 1].min(), y[:, -1].max())``.
        figsize: Figure size used when ``ax`` is None.
        color_map: Name of the matplotlib colormap used for the heat values.
        ax: Axis to draw on; a new figure is created when None.
        upper_lower_missing_scale: Relative margin of the added padding edges.

    Returns:
        The axis drawn on.
    """
    color_function = plt.get_cmap(color_map)
    x = df.index
    y = df[lines].values
    c = df[heat].values
    b = df[backtest].values if backtest is not None else None

    # assert enough data
    assert len(y.shape) > 1 and len(
        c.shape) > 1, "lines and heat need to be 2 dimensions!"

    # With one more heat column than lines, add a padding edge below and
    # above all lines so there is one more edge than heat columns.
    if c.shape[1] == y.shape[1] + 1:
        y_min, y_max = y.min(), y.max()
        # Offset by the magnitude so the padding lies outside the data even
        # for negative values (``y_min * (1 - s)`` would be above a negative
        # minimum).
        lower = np.full((c.shape[0], 1),
                        y_min - abs(y_min) * upper_lower_missing_scale)
        upper = np.full((c.shape[0], 1),
                        y_max + abs(y_max) * upper_lower_missing_scale)
        y = np.hstack([lower, y, upper])

    # check for matching columns
    assert y.shape[1] - 1 == c.shape[
        1], f'unexpeced shapes: {y.shape[1] - 1} != {c.shape[1]}'

    _, ax = plt.subplots(figsize=figsize) if ax is None else (None, ax)

    # Invisible lines so the axis autoscales to the full data range.
    ax.plot(x, y, color='k', alpha=0.0)

    for ci in range(c.shape[1]):
        # Each segment spans two consecutive points; starting at 1 avoids
        # the empty ``[-1:1]`` slice the first iteration used to produce.
        for xi in range(1, len(x)):
            ax.fill_between(x[xi - 1:xi + 1],
                            y[xi - 1:xi + 1, ci],
                            y[xi - 1:xi + 1, ci + 1],
                            facecolors=color_function(c[xi - 1:xi + 1, ci]))

        if ci > 0:
            # todo annotate all first last and only convert date if it is actually a date
            ax.annotate(f'{y[-1, ci]:.2f}',
                        xy=(mdates.date2num(x[-1]), y[-1, ci]),
                        xytext=(4, -4),
                        textcoords='offset pixels')

    # reset limits
    ax.autoscale(tight=True)
    if reset_y_lim:
        ax.set_ylim(bottom=y[:, 1].min(), top=y[:, -1].max())

    # backtest
    if backtest:
        ax.plot(x, b)

    return ax
Example #23
0
def control_plot(data: (List[int], List[float], pd.Series, np.array),
                 upper_control_limit: (int, float),
                 lower_control_limit: (int, float),
                 highlight_beyond_limits: bool = True,
                 highlight_zone_a: bool = True,
                 highlight_zone_b: bool = True,
                 highlight_zone_c: bool = True,
                 highlight_trend: bool = True,
                 highlight_mixture: bool = True,
                 highlight_stratification: bool = True,
                 highlight_overcontrol: bool = True,
                 ax: Axis = None):
    """
    Create a control plot based on the input data.

    The data is drawn together with dashed lines for the centre and the zone
    A/B/C boundaries, and text labels to the right of the plot. Each enabled
    rule is then checked in a fixed order. Every rule with violations adds a
    marker layer, drawn behind and one size smaller than the previous layer,
    so overlapping markers stay visible.

    :param data: a list, pandas.Series, or numpy.array representing the data set
    :param upper_control_limit: an integer or float which represents the upper control limit, commonly called the UCL
    :param lower_control_limit: an integer or float which represents the lower control limit, commonly called the LCL
    :param highlight_beyond_limits: True if points beyond limits are to be highlighted
    :param highlight_zone_a: True if points that are zone A violations are to be highlighted
    :param highlight_zone_b: True if points that are zone B violations are to be highlighted
    :param highlight_zone_c: True if points that are zone C violations are to be highlighted
    :param highlight_trend: True if points that are trend violations are to be highlighted
    :param highlight_mixture: True if points that are mixture violations are to be highlighted
    :param highlight_stratification: True if points that are stratification violations are to be highlighted
    :param highlight_overcontrol: True if points that are overcontrol violations are to be highlighted
    :param ax: an instance of matplotlib.axis.Axis; a new figure is created when None
    :return: None
    """

    data = coerce(data)

    if ax is None:
        _, ax = plt.subplots()

    ax.plot(data)
    ax.set_title('Zone Control Chart')

    spec_range = (upper_control_limit - lower_control_limit) / 2
    spec_center = lower_control_limit + spec_range
    zone_c_upper_limit = spec_center + spec_range / 3
    zone_c_lower_limit = spec_center - spec_range / 3
    zone_b_upper_limit = spec_center + 2 * spec_range / 3
    zone_b_lower_limit = spec_center - 2 * spec_range / 3
    zone_a_upper_limit = spec_center + spec_range
    zone_a_lower_limit = spec_center - spec_range

    ax.axhline(spec_center, linestyle='--', color='red', alpha=0.6)
    ax.axhline(zone_c_upper_limit, linestyle='--', color='red', alpha=0.5)
    ax.axhline(zone_c_lower_limit, linestyle='--', color='red', alpha=0.5)
    ax.axhline(zone_b_upper_limit, linestyle='--', color='red', alpha=0.3)
    ax.axhline(zone_b_lower_limit, linestyle='--', color='red', alpha=0.3)
    ax.axhline(zone_a_upper_limit, linestyle='--', color='red', alpha=0.2)
    ax.axhline(zone_a_lower_limit, linestyle='--', color='red', alpha=0.2)

    # Place the text labels 1 % of the x-range to the right of the plot.
    left, right = ax.get_xlim()
    right_plus = (right - left) * 0.01 + right

    ax.text(right_plus, upper_control_limit, s='UCL', va='center')
    ax.text(right_plus, lower_control_limit, s='LCL', va='center')

    ax.text(right_plus, (spec_center + zone_c_upper_limit) / 2,
            s='Zone C',
            va='center')
    ax.text(right_plus, (spec_center + zone_c_lower_limit) / 2,
            s='Zone C',
            va='center')
    ax.text(right_plus, (zone_b_upper_limit + zone_c_upper_limit) / 2,
            s='Zone B',
            va='center')
    ax.text(right_plus, (zone_b_lower_limit + zone_c_lower_limit) / 2,
            s='Zone B',
            va='center')
    ax.text(right_plus, (zone_a_upper_limit + zone_b_upper_limit) / 2,
            s='Zone A',
            va='center')
    ax.text(right_plus, (zone_a_lower_limit + zone_b_lower_limit) / 2,
            s='Zone A',
            va='center')

    plot_params = {'alpha': 0.3, 'zorder': -10, 'markersize': 14}
    limits = {
        'upper_control_limit': upper_control_limit,
        'lower_control_limit': lower_control_limit,
    }

    # (enabled, detector, detector takes the limits, marker colour, legend
    # label), checked and drawn in this order.
    highlights = [
        (highlight_beyond_limits, control_beyond_limits, True,
         'red', 'beyond limits'),
        (highlight_zone_a, control_zone_a, True,
         'orange', 'zone a violations'),
        (highlight_zone_b, control_zone_b, True,
         'blue', 'zone b violations'),
        (highlight_zone_c, control_zone_c, True,
         'green', 'zone c violations'),
        (highlight_trend, control_zone_trend, False,
         'purple', 'trend violations'),
        (highlight_mixture, control_zone_mixture, True,
         'brown', 'mixture violations'),
        (highlight_stratification, control_zone_stratification, True,
         'orange', 'stratification violations'),
        (highlight_overcontrol, control_zone_overcontrol, True,
         'blue', 'overcontrol violations'),
    ]

    for enabled, detector, takes_limits, color, label in highlights:
        if not enabled:
            continue
        if takes_limits:
            violations = detector(data=data, **limits)
        else:
            violations = detector(data=data)
        if len(violations):
            # Each drawn layer goes further back and gets smaller markers so
            # earlier (larger) markers remain visible around later ones.
            plot_params['zorder'] -= 1
            plot_params['markersize'] -= 1
            ax.plot(violations,
                    'o',
                    color=color,
                    label=label,
                    **plot_params)

    ax.legend()
Example #24
0
def ppk_plot(data: (List[int], List[float], pd.Series, np.array),
             upper_control_limit: (int, float),
             lower_control_limit: (int, float),
             threshold_percent: float = 0.001,
             ax: Axis = None):
    """
    Shows the statistical distribution of the data along with Ppk and limits.

    Draws a density histogram of the data with a fitted normal curve, dashed
    lines at the mean and at +/-1, 2 and 3 standard deviations, red shading of
    the fitted curve outside the control limits, and a text box summarising
    Ppk, the mean, the standard deviation and the estimated percentage of
    units beyond each limit.

    :param data: a list, pandas.Series, or numpy.array representing the data set
    :param upper_control_limit: an integer or float which represents the upper control limit, commonly called the UCL
    :param lower_control_limit: an integer or float which represents the lower control limit, commonly called the LCL
    :param threshold_percent: a percentage; the estimated % of units above the UCL / below the LCL is only displayed on the plot when it exceeds this value
    :param ax: an instance of matplotlib.axis.Axis; a new figure and axis are created when None
    :return: None
    """

    data = coerce(data)
    mean = data.mean()
    # NOTE(review): a std of 0 (constant data) or NaN (a single sample) makes
    # the fitted curve and the sigma levels below non-finite; confirm callers
    # always pass data with some spread.
    std = data.std()

    if ax is None:
        _, ax = plt.subplots()

    ax.hist(data, density=True, label='data', alpha=0.3)
    x = np.linspace(mean - 4 * std, mean + 4 * std, 100)
    pdf = stats.norm.pdf(x, mean, std)
    ax.plot(x, pdf, label='normal fit', alpha=0.7)

    # Sigma labels are placed just above the current top of the plot.
    _, top = ax.get_ylim()

    ax.axvline(mean, linestyle='--')
    ax.text(mean, top * 1.01, s=r'$\mu$', ha='center')

    # Dashed lines at +/-1, 2 and 3 sigma, fading further from the mean.
    # Drawn in the order +1, -1, +2, -2, +3, -3; the minus sign is kept
    # inside math mode so every negative label renders a proper minus.
    for multiple, alpha in ((1, 0.6), (2, 0.4), (3, 0.2)):
        coefficient = '' if multiple == 1 else str(multiple)
        for sign, prefix in ((1, ''), (-1, '-')):
            position = mean + sign * multiple * std
            ax.axvline(position, alpha=alpha, linestyle='--')
            ax.text(position,
                    top * 1.01,
                    s=rf'${prefix}{coefficient}\sigma$',
                    ha='center')

    # Shade the fitted distribution outside the control limits.
    ax.fill_between(x,
                    pdf,
                    where=x < lower_control_limit,
                    facecolor='red',
                    alpha=0.5)
    ax.fill_between(x,
                    pdf,
                    where=x > upper_control_limit,
                    facecolor='red',
                    alpha=0.5)

    # Estimated share of units beyond each limit under the normal fit.
    lower_percent = 100.0 * stats.norm.cdf(lower_control_limit, mean, std)
    higher_percent = 100.0 - 100.0 * stats.norm.cdf(upper_control_limit, mean,
                                                    std)

    left, right = ax.get_xlim()
    _, top = ax.get_ylim()
    ppk = calc_ppk(data,
                   upper_control_limit=upper_control_limit,
                   lower_control_limit=lower_control_limit)

    # Attach the 'limits' legend label to whichever limit line is drawn first,
    # so the legend keeps the entry even when only the UCL line is visible.
    limit_label = {'label': 'limits'}

    lower_sigma_level = (mean - lower_control_limit) / std
    if lower_sigma_level < 6.0:
        ax.axvline(lower_control_limit,
                   color='red',
                   alpha=0.25,
                   **limit_label)
        limit_label = {}
        # Signed position of the LCL relative to the mean, in sigmas; avoids
        # a double minus when the LCL lies above the mean.
        ax.text(lower_control_limit,
                top * 0.95,
                s=rf'${-lower_sigma_level:.01f}\sigma$',
                ha='center')
    else:
        ax.text(left, top * 0.95, s=r'limit > $-6\sigma$', ha='left')

    upper_sigma_level = (upper_control_limit - mean) / std
    if upper_sigma_level < 6.0:
        ax.axvline(upper_control_limit,
                   color='red',
                   alpha=0.25,
                   **limit_label)
        ax.text(upper_control_limit,
                top * 0.95,
                s=rf'${upper_sigma_level:.01f}\sigma$',
                ha='center')
    else:
        ax.text(right, top * 0.95, s=r'limit > $6\sigma$', ha='right')

    summary = [
        f'Ppk = {ppk:.02f}',
        rf'$\mu = {mean:.3g}$',
        rf'$\sigma = {std:.3g}$',
    ]
    if lower_percent > threshold_percent:
        summary.append(f'{lower_percent:.02f}% < LCL')
    if higher_percent > threshold_percent:
        summary.append(f'{higher_percent:.02f}% > UCL')

    props = dict(boxstyle='round',
                 facecolor='white',
                 alpha=0.75,
                 edgecolor='grey')
    ax.text(right - (right - left) * 0.05,
            0.85 * top,
            '\n'.join(summary),
            bbox=props,
            ha='right',
            va='top')

    ax.legend(loc='lower right')