Esempio n. 1
0
def plot_interactions(locations: List[str],
                      latent_graph: ndarray,
                      map: Basemap,
                      ax: Axis,
                      skip_first: bool = False):
    """
    Plot stations on a map and draw the latent-graph edges between them.

    Each edge type gets its own colour from the ``Set1`` colour map and a
    single legend entry labelled with the edge-type index.

    :param locations: one entry per station; each entry is unpacked into
        ``map(...)`` and its first two items are passed to
        ``map.drawgreatcircle`` (presumably ``(lon, lat)`` pairs - confirm
        against the caller, the ``List[str]`` annotation does not match usage)
    :param latent_graph: array indexed as ``[edge_type, i, j]``; an edge from
        station ``i`` to station ``j`` is drawn when the value exceeds 0.5
    :param map: Basemap instance used for projection and drawing
    :param ax: axes on which the station markers, labels and legend are drawn
    :param skip_first: when True, edge type 0 is not drawn
    :return: the ``ax`` that was passed in
    """
    # Project every location into the map's own coordinate system
    pixel_coords = [map(*coords) for coords in locations]

    # Draw contours and borders
    map.shadedrelief()
    map.drawcountries()

    # Plot locations of the stations. The marker format is just 'o': the
    # colour comes from the keyword, so a colour in the format string would be
    # redundant and trigger a Matplotlib warning.
    for i, (x, y) in enumerate(pixel_coords):
        ax.plot(x, y, 'o', markersize=10, color='yellow')
        ax.text(x + 10, y + 10, "Station " + str(i), fontsize=20, color='yellow')

    # Infer number of edge types and atoms from latent graph
    n_atoms = latent_graph.shape[-1]
    n_edge_types = latent_graph.shape[0]

    color_map = get_cmap('Set1')

    # Offsetting by the first drawn edge type keeps colour indices
    # non-negative: without it edge type 0 would be looked up at index -1
    # and end up with the same colour as edge type 1.
    first_edge_type = 1 if skip_first else 0

    for i in range(n_atoms):
        for j in range(n_atoms):
            for edge_type in range(first_edge_type, n_edge_types):
                if latent_graph[edge_type, i, j] > 0.5:
                    # Draw line between points
                    start = locations[i]
                    end = locations[j]
                    map.drawgreatcircle(start[0], start[1], end[0], end[1],
                                        color=color_map(edge_type - first_edge_type),
                                        label=str(edge_type))

    # Every drawn edge carries a label; keep only the first handle per label
    # so each edge type shows up once in the legend.
    handles, labels = ax.get_legend_handles_labels()
    first_handle_by_label = {}
    for handle, label in zip(handles, labels):
        first_handle_by_label.setdefault(label, handle)

    if first_handle_by_label:
        ax.legend(list(first_handle_by_label.values()),
                  list(first_handle_by_label.keys()))
    else:
        ax.legend()
    return ax
Esempio n. 2
0
def _plot_svd_vetors(
    vector_data: xr.DataArray,
    indices: Sequence[int],
    sv_index_dim: str,
    ax: Axis,
    show_legend: bool,
) -> None:
    """Draw the selected SVD vectors on ``ax``, earlier ones on top.

    Parameters
    ----------
    vector_data: xr.DataArray
        DataArray holding the SVD vector data; its first two dimensions are
        inspected to find the one that is not ``sv_index_dim``.
    indices: Sequence[int]
        Indices of the singular vectors to plot. Only the first
        ``len(vector_data[sv_index_dim])`` entries of this sequence are used.
    sv_index_dim: str
        Name of the singular value index dimension.
    ax: Axis
        Axis to plot on.
    show_legend: bool
        Whether or not to show the legend (only the literal ``True`` enables it).

    Notes
    -----
    Legend labels are the running position of each plotted vector
    (0, 1, 2, ...), not the values in ``indices``. At most 100 vectors are
    drawn because the zorder sequence has 100 entries.

    See Also
    --------
    plot_lsv_data
    plot_rsv_data
    plot_lsv_residual
    plot_rsv_residual
    """
    n_available = len(getattr(vector_data, sv_index_dim))
    selection = {sv_index_dim: indices[:n_available]}
    vectors = vector_data.isel(**selection)

    x_dim = vector_data.dims[1]
    if x_dim == sv_index_dim:
        # Transpose so that iterating yields one vector per singular value
        vectors = vectors.T
        x_dim = vector_data.dims[0]

    # zorder counts down from 99 so that earlier vectors are drawn on top
    descending_zorders = range(99, -1, -1)
    for position, (zorder, vector) in enumerate(zip(descending_zorders, vectors)):
        vector.plot.line(x=x_dim, ax=ax, zorder=zorder, label=position)

    if show_legend is True:
        ax.legend(title=sv_index_dim)
Esempio n. 3
0
def control_plot(data: (List[int], List[float], pd.Series, np.array),
                 upper_control_limit: (int, float),
                 lower_control_limit: (int, float),
                 highlight_beyond_limits: bool = True,
                 highlight_zone_a: bool = True,
                 highlight_zone_b: bool = True,
                 highlight_zone_c: bool = True,
                 highlight_trend: bool = True,
                 highlight_mixture: bool = True,
                 highlight_stratification: bool = True,
                 highlight_overcontrol: bool = True,
                 ax: Axis = None):
    """
    Draw a zone control chart of the data and mark rule violations.

    The span between the control limits is split into zones C, B and A
    (inner to outer third on each side of the centre line); each zone
    boundary is drawn as a dashed red line and labelled at the right edge.
    Every enabled check that reports at least one violation is overlaid as
    a set of large translucent markers behind the data line.

    :param data: a list, pandas.Series, or numpy.array representing the data set
    :param upper_control_limit: an integer or float which represents the upper control limit, commonly called the UCL
    :param lower_control_limit: an integer or float which represents the lower control limit, commonly called the LCL
    :param highlight_beyond_limits: True if points beyond limits are to be highlighted
    :param highlight_zone_a: True if points that are zone A violations are to be highlighted
    :param highlight_zone_b: True if points that are zone B violations are to be highlighted
    :param highlight_zone_c: True if points that are zone C violations are to be highlighted
    :param highlight_trend: True if points that are trend violations are to be highlighted
    :param highlight_mixture: True if points that are mixture violations are to be highlighted
    :param highlight_stratification: True if points that are stratification violations are to be highlighted
    :param highlight_overcontrol: True if points that are overcontrol violations are to be highlighted
    :param ax: the matplotlib axes to draw on; when None a new figure and axes are created with ``plt.subplots()``
    :return: None
    """

    data = coerce(data)

    if ax is None:
        _, ax = plt.subplots()

    ax.plot(data)
    ax.set_title('Zone Control Chart')

    # Zone geometry: half-width of the limits and the boundaries of each zone
    spec_range = (upper_control_limit - lower_control_limit) / 2
    spec_center = lower_control_limit + spec_range
    zone_c_upper_limit = spec_center + spec_range / 3
    zone_c_lower_limit = spec_center - spec_range / 3
    zone_b_upper_limit = spec_center + 2 * spec_range / 3
    zone_b_lower_limit = spec_center - 2 * spec_range / 3
    zone_a_upper_limit = spec_center + spec_range
    zone_a_lower_limit = spec_center - spec_range

    # Boundary lines fade out the further they are from the centre
    boundary_lines = (
        (spec_center, 0.6),
        (zone_c_upper_limit, 0.5),
        (zone_c_lower_limit, 0.5),
        (zone_b_upper_limit, 0.3),
        (zone_b_lower_limit, 0.3),
        (zone_a_upper_limit, 0.2),
        (zone_a_lower_limit, 0.2),
    )
    for level, opacity in boundary_lines:
        ax.axhline(level, linestyle='--', color='red', alpha=opacity)

    # Captions sit just right of the current x-range (1% of its width)
    left, right = ax.get_xlim()
    right_plus = (right - left) * 0.01 + right

    captions = (
        (upper_control_limit, 'UCL'),
        (lower_control_limit, 'LCL'),
        ((spec_center + zone_c_upper_limit) / 2, 'Zone C'),
        ((spec_center + zone_c_lower_limit) / 2, 'Zone C'),
        ((zone_b_upper_limit + zone_c_upper_limit) / 2, 'Zone B'),
        ((zone_b_lower_limit + zone_c_lower_limit) / 2, 'Zone B'),
        ((zone_a_upper_limit + zone_b_upper_limit) / 2, 'Zone A'),
        ((zone_a_lower_limit + zone_b_lower_limit) / 2, 'Zone A'),
    )
    for height, caption in captions:
        ax.text(right_plus, height, s=caption, va='center')

    limit_kwargs = {'upper_control_limit': upper_control_limit,
                    'lower_control_limit': lower_control_limit}

    # (enabled, detector, extra detector kwargs, marker colour, legend label)
    checks = (
        (highlight_beyond_limits, control_beyond_limits, limit_kwargs,
         'red', 'beyond limits'),
        (highlight_zone_a, control_zone_a, limit_kwargs,
         'orange', 'zone a violations'),
        (highlight_zone_b, control_zone_b, limit_kwargs,
         'blue', 'zone b violations'),
        (highlight_zone_c, control_zone_c, limit_kwargs,
         'green', 'zone c violations'),
        (highlight_trend, control_zone_trend, {},
         'purple', 'trend violations'),
        (highlight_mixture, control_zone_mixture, limit_kwargs,
         'brown', 'mixture violations'),
        (highlight_stratification, control_zone_stratification, limit_kwargs,
         'orange', 'stratification violations'),
        (highlight_overcontrol, control_zone_overcontrol, limit_kwargs,
         'blue', 'overcontrol violations'),
    )

    plot_params = {'alpha': 0.3, 'zorder': -10, 'markersize': 14}

    for enabled, detect, detector_kwargs, marker_color, legend_label in checks:
        if not enabled:
            continue

        violations = detect(data=data, **detector_kwargs)
        if len(violations):
            # Each plotted violation set is drawn one layer further back and
            # one point smaller, so overlapping markers remain distinguishable
            plot_params['zorder'] -= 1
            plot_params['markersize'] -= 1
            ax.plot(violations,
                    'o',
                    color=marker_color,
                    label=legend_label,
                    **plot_params)

    ax.legend()
Esempio n. 4
0
def ppk_plot(data: (List[int], List[float], pd.Series, np.array),
             upper_control_limit: (int, float),
             lower_control_limit: (int, float),
             threshold_percent: float = 0.001,
             ax: Axis = None):
    """
    Show the distribution of the data along with its Ppk and control limits.

    Draws a density histogram of the data, a normal curve fitted with the
    data's mean and standard deviation, dashed guide lines at the mean and
    at +/-1, 2 and 3 standard deviations, red shading of the fitted curve
    outside the control limits, and a text box with Ppk, mean, standard
    deviation and the expected share of units outside each limit.

    :param data: a list, pandas.Series, or numpy.array representing the data set
    :param upper_control_limit: an integer or float which represents the upper control limit, commonly called the UCL
    :param lower_control_limit: an integer or float which represents the lower control limit, commonly called the LCL
    :param threshold_percent: value on the 0-100 percent scale; the share of units below the LCL / above the UCL
        is only shown in the text box when it exceeds this value
    :param ax: the matplotlib axes to draw on; when None a new figure and axes are created with ``plt.subplots()``
    :return: None
    """

    data = coerce(data)
    mean = data.mean()
    std = data.std()

    if ax is None:
        _, ax = plt.subplots()

    ax.hist(data, density=True, label='data', alpha=0.3)
    x = np.linspace(mean - 4 * std, mean + 4 * std, 100)
    pdf = stats.norm.pdf(x, mean, std)
    ax.plot(x, pdf, label='normal fit', alpha=0.7)

    # Guide-line captions sit just above the top of the y-range as it is now
    _, top = ax.get_ylim()
    caption_y = top * 1.01

    ax.axvline(mean, linestyle='--')
    ax.text(mean, caption_y, s=r'$\mu$', ha='center')

    # (multiple of std, line alpha, caption). Raw strings keep the TeX
    # backslashes from being parsed as (invalid) Python escape sequences.
    sigma_guides = (
        (1, 0.6, r'$\sigma$'),
        (-1, 0.6, r'$-\sigma$'),
        (2, 0.4, r'$2\sigma$'),
        (-2, 0.4, r'-$2\sigma$'),
        (3, 0.2, r'$3\sigma$'),
        (-3, 0.2, r'-$3\sigma$'),
    )
    for multiple, line_alpha, caption in sigma_guides:
        position = mean + multiple * std
        ax.axvline(position, alpha=line_alpha, linestyle='--')
        ax.text(position, caption_y, s=caption, ha='center')

    # Shade the parts of the fitted curve that fall outside the limits
    ax.fill_between(x,
                    pdf,
                    where=x < lower_control_limit,
                    facecolor='red',
                    alpha=0.5)
    ax.fill_between(x,
                    pdf,
                    where=x > upper_control_limit,
                    facecolor='red',
                    alpha=0.5)

    # Expected share (in percent) of the fitted distribution beyond each limit
    lower_percent = 100.0 * stats.norm.cdf(lower_control_limit, mean, std)
    lower_percent_text = f'{lower_percent:.02f}% < LCL' if lower_percent > threshold_percent else None

    higher_percent = 100.0 - 100.0 * stats.norm.cdf(upper_control_limit, mean,
                                                    std)
    higher_percent_text = f'{higher_percent:.02f}% > UCL' if higher_percent > threshold_percent else None

    # Re-read the axis ranges: the artists added above may have changed them
    left, right = ax.get_xlim()
    _, top = ax.get_ylim()
    ppk = calc_ppk(data,
                   upper_control_limit=upper_control_limit,
                   lower_control_limit=lower_control_limit)

    # A limit is only drawn when it lies within 6 standard deviations of the
    # mean; otherwise a note is placed at the matching edge of the plot.
    limits_labelled = False

    lower_sigma_level = (mean - lower_control_limit) / std
    if lower_sigma_level < 6.0:
        ax.axvline(lower_control_limit,
                   color='red',
                   alpha=0.25,
                   label='limits')
        limits_labelled = True
        ax.text(lower_control_limit,
                top * 0.95,
                s=rf'$-{lower_sigma_level:.01f}\sigma$',
                ha='center')
    else:
        ax.text(left, top * 0.95, s=r'limit > $-6\sigma$', ha='left')

    upper_sigma_level = (upper_control_limit - mean) / std
    if upper_sigma_level < 6.0:
        # Only one limit line carries the legend label; give it to the upper
        # line when the lower one was not drawn so the entry is not lost.
        upper_line_kwargs = {} if limits_labelled else {'label': 'limits'}
        ax.axvline(upper_control_limit,
                   color='red',
                   alpha=0.25,
                   **upper_line_kwargs)
        ax.text(upper_control_limit,
                top * 0.95,
                s=rf'${upper_sigma_level:.01f}\sigma$',
                ha='center')
    else:
        ax.text(right, top * 0.95, s=r'limit > $6\sigma$', ha='right')

    # Summary box in the upper right corner
    strings = [
        f'Ppk = {ppk:.02f}',
        rf'$\mu = {mean:.3g}$',
        rf'$\sigma = {std:.3g}$',
    ]
    if lower_percent_text:
        strings.append(lower_percent_text)
    if higher_percent_text:
        strings.append(higher_percent_text)

    props = dict(boxstyle='round',
                 facecolor='white',
                 alpha=0.75,
                 edgecolor='grey')
    ax.text(right - (right - left) * 0.05,
            0.85 * top,
            '\n'.join(strings),
            bbox=props,
            ha='right',
            va='top')

    ax.legend(loc='lower right')
Esempio n. 5
0
def plot_data_and_fits(
    result: ResultLike,
    wavelength: float,
    axis: Axis,
    center_λ: float | None = None,
    main_irf_nr: int = 0,
    linlog: bool = False,
    linthresh: float = 1,
    divide_by_scale: bool = True,
    per_axis_legend: bool = False,
    y_label: str = "a.u.",
    cycler: Cycler | None = PlotStyle().data_cycler_solid,
) -> None:
    """Plot the data and fit traces at ``wavelength`` for every dataset on ``axis``.

    For each dataset whose spectral range contains ``wavelength``, the trace
    nearest to that wavelength is selected, its time axis is shifted by the
    IRF location and data and fit are plotted (each divided by the dataset
    scale). Datasets that do not cover the wavelength are skipped, but the
    style cycle is still advanced by two entries for them.

    Parameters
    ----------
    result : ResultLike
        Data structure which can be converted to a mapping.
    wavelength : float
        Wavelength to plot data and fits for.
    axis: Axis
        Axis to plot the data and fits on.
    center_λ: float | None
        Center wavelength (λ in nm), forwarded to ``extract_irf_location``.
    main_irf_nr : int
        Index of the main ``irf`` component when using an ``irf``
        parametrized with multiple peaks. Defaults to 0.
    linlog : bool
        Whether to use 'symlog' scale for the x-axis or not. Defaults to False.
    linthresh : float
        A single float which defines the range (-x, x), within which the plot is linear.
        This avoids having the plot go to infinity around zero. Defaults to 1.
    divide_by_scale : bool
        Whether or not to divide the data by the dataset scale used for optimization.
        Defaults to True.
    per_axis_legend: bool
        Whether to use a legend per plot or for the whole figure. Defaults to False.
    y_label: str
        Label used for the y-axis of each subplot.
    cycler : Cycler | None
        Plot style cycler to use. Defaults to PlotStyle().data_cycler_solid.

    See Also
    --------
    plot_fit_overview
    """
    result_map = result_dataset_mapping(result)
    add_cycler_if_not_none(axis, cycler)

    for dataset_name, dataset in result_map.items():
        spectral_axis = dataset.coords["spectral"].values
        covers_wavelength = spectral_axis.min() <= wavelength <= spectral_axis.max()

        if not covers_wavelength:
            # Consume the two style entries this dataset's data and fit lines
            # would have used, so the remaining datasets keep their styles.
            # NOTE(review): ``_get_lines.prop_cycler`` is a private Matplotlib
            # attribute - confirm it exists in the supported versions.
            for _ in range(2):
                next(axis._get_lines.prop_cycler)
            continue

        trace = dataset.sel(spectral=[wavelength], method="nearest")
        scale = extract_dataset_scale(trace, divide_by_scale)
        irf_loc = extract_irf_location(trace, center_λ, main_irf_nr)
        # Shift the time axis so that it is relative to the IRF location
        shifted_time = trace.coords["time"] - irf_loc
        trace = trace.assign_coords(time=shifted_time)

        scaled_data = trace.data / scale
        scaled_data.plot(x="time", ax=axis, label=f"{dataset_name}_data")
        scaled_fit = trace.fitted_data / scale
        scaled_fit.plot(x="time", ax=axis, label=f"{dataset_name}_fit")

    if linlog:
        axis.set_xscale("symlog", linthresh=linthresh)
    axis.set_ylabel(y_label)
    if per_axis_legend is True:
        axis.legend()