示例#1
0
    def plot_histograms(
        self,
        keys: Sequence[str],
        ax: Optional[plt.Axes] = None,
        *,
        labels: Optional[Sequence[str]] = None,
    ) -> plt.Axes:
        """Plot integrated histograms of the metric values stored under keys.

        Args:
            keys: Metric keys for which an integrated histogram should be
                plotted. A single string is treated as a one-element list.
            ax: The axis to plot on. If None, a new figure and axis are
                created and the figure is shown before returning.
            labels: Labels passed to the plotting call, paired positionally
                with ``keys``. If omitted or empty, the keys themselves are
                used. Pairing uses ``zip``, so keys beyond the length of
                ``labels`` are not plotted.

        Returns:
            The axis that was plotted on.

        Raises:
            ValueError: If any value of a requested metric does not have
                length 1.
        """
        # Only show the figure when this method created it; a caller that
        # supplies its own axis keeps control over when it is displayed.
        show_plot = ax is None
        if show_plot:
            fig, ax = plt.subplots(1, 1)

        if isinstance(keys, str):
            keys = [keys]
        if not labels:
            labels = keys
        colors = ['b', 'r', 'k', 'g', 'c', 'm']
        # cycle() lets the colors repeat when there are more keys than colors.
        for key, label, color in zip(keys, labels, cycle(colors)):
            metrics = self[key]
            values = metrics.values()
            if not all(len(v) == 1 for v in values):
                raise ValueError(
                    'Histograms are only supported if all values in a metric '
                    'are single metric values. '
                    f'{key} has metric values {values}'
                )
            vis.integrated_histogram(
                [self.value_to_float(v) for v in values],
                ax,
                label=label,
                color=color,
                title=key.replace('_', ' ').title(),
            )
        if show_plot:
            fig.show()

        return ax
示例#2
0
def test_multiple_plots():
    """Check title, line colors and line labels after two plots on one axis."""
    sample_count = 53
    _, shared_ax = plt.subplots(1, 1)
    first_series, second_series = np.random.random_sample((2, sample_count))

    integrated_histogram(
        first_series,
        shared_ax,
        color='r',
        label='data_1',
        median_line=False,
        mean_line=True,
        mean_label='mean_1',
    )
    integrated_histogram(
        second_series,
        shared_ax,
        color='k',
        label='data_2',
        median_label='median_2',
    )

    assert shared_ax.get_title() == 'N=53'

    # Every line on the axis must carry one of the requested colors and labels.
    allowed_colors = ['r', 'k']
    allowed_labels = ['data_1', 'data_2', 'mean_1', 'median_2']
    for drawn_line in shared_ax.get_lines():
        assert drawn_line.get_color() in allowed_colors
        assert drawn_line.get_label() in allowed_labels
示例#3
0
def test_integrated_histogram(data):
    """Check title, y-axis label, line count and line color of a single plot."""
    plot_options = dict(
        title='Test Plot',
        axis_label='Y Axis Label',
        color='r',
        label='line label',
        cdf_on_x=True,
        show_zero=True,
    )
    ax = integrated_histogram(data, **plot_options)

    assert ax.get_title() == 'Test Plot'
    assert ax.get_ylabel() == 'Y Axis Label'

    drawn_lines = ax.get_lines()
    assert len(drawn_lines) == 2
    for drawn_line in drawn_lines:
        assert drawn_line.get_color() == 'r'