def test_legend_uses_class_numbers_when_no_labels_are_passed(self, ax: Axes):
    """Without explicit labels, legend entries should use numbered class names."""
    # Collect the text of every legend entry; a set makes the check
    # independent of the order in which entries were added.
    legend_entries = set()
    for text_artist in ax.get_legend().get_texts():
        legend_entries.add(text_artist.get_text())

    assert legend_entries == {
        "Class 0 $Lift = 3.17$ ",
        "Class 1 $Lift = 2.70$ ",
        "Class 2 $Lift = 2.70$ ",
        "Baseline",
    }
def test_legend_is_labelled_correctly_when_labels_are_passed(self, ax_labels: Axes):
    """When labels are supplied, each legend entry should carry its class label."""
    # Gather every legend entry's text; comparing as a set ignores ordering.
    legend_entries = set()
    for text_artist in ax_labels.get_legend().get_texts():
        legend_entries.add(text_artist.get_text())

    assert legend_entries == {
        "Class Setosa $Lift = 3.17$ ",
        "Class Versicolor $Lift = 2.70$ ",
        "Class Virginica $Lift = 2.70$ ",
        "Baseline",
    }
def _upsample_others(ax: Axes, freq, kwargs):
    """Replot ``ax`` (and a twinned axis, if any) at ``freq``; rebuild its legend.

    Lines from ``ax`` and from its ``right_ax``/``left_ax`` twin are gathered
    via ``_replot_ax``. The legend is redrawn only when ``ax`` already had one,
    ``kwargs`` does not set a falsy ``"legend"`` entry, and at least one line
    was returned.
    """
    existing_legend = ax.get_legend()

    handles, names = _replot_ax(ax, freq, kwargs)
    # NOTE(review): ``_replot_ax`` is invoked a second time on the same axis
    # and its return value is discarded. Kept to preserve behaviour -- confirm
    # whether this repeat call is intentional.
    _replot_ax(ax, freq, kwargs)

    # When both attributes exist, ``right_ax`` overrides ``left_ax``.
    twin = None
    if hasattr(ax, "left_ax"):
        twin = ax.left_ax
    if hasattr(ax, "right_ax"):
        twin = ax.right_ax

    if twin is not None:
        twin_handles, twin_names = _replot_ax(twin, freq, kwargs)
        handles.extend(twin_handles)
        names.extend(twin_names)

    # Guard clause: nothing to redraw.
    if existing_legend is None or not kwargs.get("legend", True) or len(handles) == 0:
        return

    # A title whose text is the string "None" is treated as having no title.
    previous_title = existing_legend.get_title().get_text()
    ax.legend(
        handles,
        names,
        loc="best",
        title=None if previous_title == "None" else previous_title,
    )
def plot_1d_to_axis(spectra: Union[Spectrum1D, Spectrum1DCollection],
                    ax: Axes,
                    labels: Optional[Sequence[str]] = None,
                    **mplargs) -> None:
    """Plot a (collection of) 1D spectrum lines to matplotlib axis

    Where there are two identical x-values in a row, plotting will restart
    to avoid creating a vertical line

    Parameters
    ----------
    spectra
        Spectrum1D or Spectrum1DCollection to plot
    ax
        Matplotlib axes to which spectra will be drawn
    labels
        A sequence of labels corresponding to the sequence of lines in
        spectra, used to label each line. If this is None, the
        label(s) contained in spectra.metadata['label'] (Spectrum1D) or
        spectra.metadata['line_data'][i]['label']
        (Spectrum1DCollection) will be used. To disable labelling for a
        specific line, pass an empty string. A single string is treated
        as a one-element sequence.
    **mplargs
        Keyword arguments passed to Axes.plot. A ``color`` entry is used
        as the colour of every segment of every line.

    Raises
    ------
    TypeError
        If spectra is neither a Spectrum1D nor a Spectrum1DCollection
    ValueError
        If labels is given and its length differs from the number of
        lines in spectra
    """
    if isinstance(spectra, Spectrum1D):
        return plot_1d_to_axis(Spectrum1DCollection.from_spectra([spectra]),
                               ax=ax, labels=labels, **mplargs)

    # Explicit check rather than ``assert``: assertions are stripped when
    # Python runs with -O, which would silently skip this validation.
    if not isinstance(spectra, Spectrum1DCollection):
        raise TypeError("spectra should be a Spectrum1D or "
                        "Spectrum1DCollection")

    if isinstance(labels, str):
        labels = [labels]

    if labels is not None and len(labels) != len(spectra):
        raise ValueError(
            f"The length of labels (got {len(labels)}) should be the "
            f"same as the number of lines to plot (got {len(spectra)})")

    # Find where there are two identical x_data points in a row; each such
    # index starts a new segment so no vertical joining line is drawn
    x_magnitude = spectra.x_data.magnitude
    breakpoints = (np.where(x_magnitude[:-1] == x_magnitude[1:])[0]
                   + 1).tolist()
    breakpoints = [0] + breakpoints + [None]

    if labels is None:
        labels = [spec.metadata.get('label', None) for spec in spectra]

    # A user-supplied colour would otherwise clash with the explicit
    # ``color=`` keyword passed to ax.plot below and raise TypeError
    base_color = mplargs.pop('color', None)

    for label, spectrum in zip(labels, spectra):
        # Loop-invariant for the segment loop: slice the same arrays each time
        x_centres = spectrum.get_bin_centres().magnitude
        y_values = spectrum.y_data.magnitude
        # Plot each line in segments
        for x0, x1 in zip(breakpoints[:-1], breakpoints[1:]):
            # Keep colour consistent across segments
            if x0 == 0:
                color = base_color
            else:
                # Only add legend label to the first segment
                label = None
                color = segment_lines[-1].get_color()

            # breakpoints always starts at 0, so segment_lines is assigned on
            # the first iteration before any later segment reads it
            segment_lines = ax.plot(x_centres[x0:x1], y_values[x0:x1],
                                    color=color, label=label, **mplargs)

    # Update legend if it exists, in case new labels have been added
    if ax.get_legend() is not None:
        ax.legend()

    ax.set_xlim(left=min(x_magnitude), right=max(x_magnitude))

    _set_x_tick_labels(ax, spectra.x_tick_labels, spectra.x_data)