def test_legend_uses_class_numbers_when_no_labels_are_passed(self, ax: Axes):
     """Legend entries should fall back to class numbers when no labels are given."""
     legend_entries = ax.get_legend().get_texts()
     shown = set()
     for entry in legend_entries:
         shown.add(entry.get_text())

     # Compared as sets, so the order of the legend entries does not matter.
     wanted = {
         "Baseline",
         "Class 0 $Lift = 3.17$ ",
         "Class 1 $Lift = 2.70$ ",
         "Class 2 $Lift = 2.70$ ",
     }
     assert shown == wanted
 def test_legend_is_labelled_correctly_when_labels_are_passed(self, ax_labels: Axes):
     """Legend entries should carry the supplied class labels."""
     legend_entries = ax_labels.get_legend().get_texts()
     shown = set()
     for entry in legend_entries:
         shown.add(entry.get_text())

     # Compared as sets, so the order of the legend entries does not matter.
     wanted = {
         "Baseline",
         "Class Setosa $Lift = 3.17$ ",
         "Class Versicolor $Lift = 2.70$ ",
         "Class Virginica $Lift = 2.70$ ",
     }
     assert shown == wanted
Exemple #3
0
def _upsample_others(ax: Axes, freq, kwargs):
    """Re-run ``_replot_ax`` on ``ax`` and its paired axis, then rebuild the legend.

    The legend present on ``ax`` *before* replotting is captured first; if one
    existed, ``kwargs`` does not disable it via a falsy ``"legend"`` entry, and
    at least one line came back from replotting, a new legend is drawn on
    ``ax`` with the combined lines/labels and the previous legend's title.

    Parameters
    ----------
    ax
        Axes to replot. Its ``left_ax`` / ``right_ax`` attribute, when
        present, is replotted as well.
    freq
        Passed through unchanged to ``_replot_ax``.
    kwargs
        Passed through unchanged to ``_replot_ax``; its ``"legend"`` key
        (default ``True``) controls whether the legend is rebuilt.
    """
    previous_legend = ax.get_legend()
    handles, names = _replot_ax(ax, freq, kwargs)
    # NOTE(review): this second call discards its result and looks redundant;
    # kept because _replot_ax is defined elsewhere and may have side effects
    # that callers rely on -- confirm before removing.
    _replot_ax(ax, freq, kwargs)

    # ``right_ax`` takes precedence over ``left_ax`` when both are present.
    paired_ax = getattr(ax, "right_ax", getattr(ax, "left_ax", None))
    if paired_ax is not None:
        paired_handles, paired_names = _replot_ax(paired_ax, freq, kwargs)
        handles.extend(paired_handles)
        names.extend(paired_names)

    if previous_legend is None or not kwargs.get("legend", True) or len(handles) == 0:
        return

    # A title whose text is the literal string "None" is treated as no title.
    title_text = previous_legend.get_title().get_text()
    ax.legend(
        handles,
        names,
        loc="best",
        title=None if title_text == "None" else title_text,
    )
Exemple #4
0
def plot_1d_to_axis(spectra: Union[Spectrum1D, Spectrum1DCollection],
                    ax: Axes,
                    labels: Optional[Sequence[str]] = None,
                    **mplargs) -> None:
    """Plot a (collection of) 1D spectrum lines to matplotlib axis

    Where there are two identical x-values in a row, plotting will restart
    to avoid creating a vertical line

    Parameters
    ----------
    spectra
        Spectrum1D or Spectrum1DCollection to plot
    ax
        Matplotlib axes to which spectra will be drawn
    labels
        A sequence of labels corresponding to the sequence of lines in
        spectra, used to label each line. A single string is treated as
        a one-element sequence. If this is None, the
        label(s) contained in spectra.metadata['label'] (Spectrum1D) or
        spectra.metadata['line_data'][i]['label']
        (Spectrum1DCollection) will be used. To disable labelling for a
        specific line, pass an empty string.
    **mplargs
        Keyword arguments passed to Axes.plot. ``color`` and ``label``
        are already supplied by this function, so passing either of
        them here raises a TypeError.

    Raises
    ------
    TypeError
        If spectra is neither a Spectrum1D nor a Spectrum1DCollection
    ValueError
        If labels is provided and its length differs from the number of
        lines in spectra
    """

    if isinstance(spectra, Spectrum1D):
        return plot_1d_to_axis(Spectrum1DCollection.from_spectra([spectra]),
                               ax=ax,
                               labels=labels,
                               **mplargs)

    # Explicit check rather than ``assert``: assert statements are removed
    # when Python runs with -O, which would silently skip this validation
    if not isinstance(spectra, Spectrum1DCollection):
        raise TypeError("spectra should be a Spectrum1D or "
                        "Spectrum1DCollection")

    # A bare string is one label, not a sequence of single-character labels
    if isinstance(labels, str):
        labels = [labels]
    if labels is not None and len(labels) != len(spectra):
        raise ValueError(
            f"The length of labels (got {len(labels)}) should be the "
            f"same as the number of lines to plot (got {len(spectra)})")

    x_values = spectra.x_data.magnitude

    # Find where there are two identical x_data points in a row
    breakpoints = (np.where(x_values[:-1] == x_values[1:])[0] + 1).tolist()
    # Bracket with 0 and None so consecutive pairs slice every segment,
    # including the last one which runs to the end of the data
    breakpoints = [0] + breakpoints + [None]

    if labels is None:
        labels = [spec.metadata.get('label', None) for spec in spectra]

    for label, spectrum in zip(labels, spectra):
        # Loop-invariant per spectrum, so fetch once rather than per segment
        bin_centres = spectrum.get_bin_centres().magnitude
        y_values = spectrum.y_data.magnitude

        # Plot each line in segments
        for x0, x1 in zip(breakpoints[:-1], breakpoints[1:]):
            # Keep colour consistent across segments
            if x0 == 0:
                # First segment: let matplotlib pick the colour
                color = None
            else:
                # Only add legend label to the first segment
                label = None
                color = segment_lines[-1].get_color()

            segment_lines = ax.plot(bin_centres[x0:x1],
                                    y_values[x0:x1],
                                    color=color,
                                    label=label,
                                    **mplargs)

    # Update legend if it exists, in case new labels have been added
    legend = ax.get_legend()
    if legend is not None:
        ax.legend()

    ax.set_xlim(left=min(x_values), right=max(x_values))

    _set_x_tick_labels(ax, spectra.x_tick_labels, spectra.x_data)