def test_make_2d_axes_behaviour():
    """Check how many upper/lower/diagonal panels make_2d_axes creates.

    Covers both the square case (a single list of parameter names) and the
    rectangular case ([paramnames_x, paramnames_y]) for every combination of
    the ``upper``, ``lower`` and ``diagonal`` flags.
    """
    np.random.seed(0)

    def calc_n(axes):
        """Compute the number of upper, lower and diagonal plots."""
        n = {'upper': 0, 'lower': 0, 'diagonal': 0}
        # Only the panel objects matter, not their row/column labels, so
        # walk the raw values (``Series.iteritems`` was removed in pandas 2).
        for ax in axes.to_numpy().ravel():
            if ax is not None:
                n[ax.position] += 1
        return n

    # Check for only paramnames_x
    paramnames_x = ['A', 'B', 'C', 'D']
    nx = len(paramnames_x)
    for upper in [True, False]:
        for lower in [True, False]:
            for diagonal in [True, False]:
                fig, axes = make_2d_axes(paramnames_x,
                                         upper=upper,
                                         lower=lower,
                                         diagonal=diagonal)
                ns = calc_n(axes)
                assert (ns['upper'] == upper * nx * (nx - 1) // 2)
                assert (ns['lower'] == lower * nx * (nx - 1) // 2)
                assert (ns['diagonal'] == diagonal * nx)

    plt.close("all")

    for paramnames_y in [['A', 'B', 'C', 'D'], ['A', 'C', 'B', 'D'],
                         ['D', 'C', 'B', 'A'], ['C', 'B', 'A'],
                         ['E', 'F', 'G', 'H'], ['A', 'B', 'E', 'F'],
                         ['B', 'E', 'A', 'F'], ['B', 'F', 'A', 'H', 'G'],
                         ['B', 'A', 'H', 'G']]:
        params = [paramnames_x, paramnames_y]
        all_params = paramnames_x + paramnames_y

        # Expected counts: (x, y) is diagonal when x == y; otherwise the
        # pair is lower/upper depending on which name occurs first in
        # paramnames_x + paramnames_y (list.index gives the first match).
        nu, nl, nd = 0, 0, 0
        for x in paramnames_x:
            for y in paramnames_y:
                if x == y:
                    nd += 1
                elif all_params.index(x) < all_params.index(y):
                    nl += 1
                elif all_params.index(x) > all_params.index(y):
                    nu += 1

        for upper in [True, False]:
            for lower in [True, False]:
                for diagonal in [True, False]:
                    fig, axes = make_2d_axes(params,
                                             upper=upper,
                                             lower=lower,
                                             diagonal=diagonal)
                    ns = calc_n(axes)
                    assert (ns['upper'] == upper * nu)
                    assert (ns['lower'] == lower * nl)
                    assert (ns['diagonal'] == diagonal * nd)
        plt.close("all")
Esempio n. 2
0
    def draw(self, labels, tex=None):
        """Draw a new triangular grid for list of parameters labels.

        Axes left over from a previous call are removed from ``self.fig``,
        then a fresh grid without upper panels (``upper=False``) is built
        inside ``self.gridspec``. Finally every panel is given an empty
        placeholder line.

        Parameters
        ----------
            labels: list(str)
                labels for the triangular grid.

            tex: dict, optional
                Passed straight through to ``make_2d_axes``. When omitted
                (or None) a new empty dict is used, so no dictionary object
                is shared between calls.

        """
        if tex is None:
            tex = {}

        # Remove any existing axes. Diagonal panels (x == y) also own a
        # ``twin`` axis, which has to be removed from the figure as well.
        for y, row in self.ax.iterrows():
            for x, ax in row.items():
                if ax is not None:
                    if x == y:
                        self.fig.delaxes(ax.twin)
                    self.fig.delaxes(ax)

        # Set up the axes
        _, self.ax = make_2d_axes(labels,
                                  upper=False,
                                  tex=tex,
                                  fig=self.fig,
                                  subplot_spec=self.gridspec)

        # Plot an empty line on every panel (on the twin axis for diagonal
        # panels). NOTE(review): presumably so later updates can modify an
        # existing line artist instead of creating one — confirm.
        for y, row in self.ax.iterrows():
            for x, ax in row.items():
                if ax is not None:
                    if x == y:
                        ax.twin.plot([None], [None], 'k-')
                    else:
                        ax.plot([None], [None], 'k.')
def test_2d_axlines_axspans(axesparams, params, values, upper):
    """Smoke test: draw lines and +/-0.1 spans at ``values`` on a 2D grid.

    Only verifies that ``axlines`` and ``axspans`` run without raising;
    every figure is closed afterwards.
    """
    centres = np.array(values)
    fig, axes = make_2d_axes(axesparams, upper=upper)
    axes.axlines(params, centres, c='k', ls='--', lw=0.5)
    axes.axspans(params, centres - 0.1, centres + 0.1, c='k', alpha=0.5)
    plt.close("all")
def test_2d_axlines_axspans_error(params, values):
    """Both ``axlines`` and ``axspans`` must raise ValueError.

    ``params``/``values`` are presumably invalid combinations supplied by
    test parametrization (not visible here). Figures are closed at the end.
    """
    values = np.array(values)
    fig, axes = make_2d_axes(['A', 'B', 'C', 'D'])
    # Each call is wrapped so it is evaluated inside its own raises-block.
    bad_calls = (
        lambda: axes.axlines(params, values),
        lambda: axes.axspans(params, values - 0.1, values + 0.1),
    )
    for bad_call in bad_calls:
        with pytest.raises(ValueError):
            bad_call()
    plt.close("all")
def test_make_2d_axes_ticks(upper, ticks):
    """Ticks set on the diagonal panel (k, k) appear only in row/column k.

    For each parameter ``k`` a new grid is made, custom x/y ticks are set on
    ``axes[k][k]``, and every panel is checked: y-ticks must match exactly in
    row ``k`` and x-ticks exactly in column ``k``, and nowhere else. An
    unknown ``ticks`` option must raise ValueError.
    """
    xticks = [0.1, 0.4, 0.7]
    yticks = [0.2, 0.5, 0.8]
    paramnames = ["x0", "x1", "x2", "x3"]
    for k in paramnames:
        fig, axes = make_2d_axes(paramnames, upper=upper, ticks=ticks)
        axes[k][k].set_xticks(xticks)
        axes[k][k].set_yticks(yticks)
        for i, row in axes.iterrows():
            # ``Series.iteritems`` was removed in pandas 2.0; ``items`` is
            # the long-standing equivalent.
            for j, ax in row.items():
                if ax is None:
                    # The test treats the first missing panel as the end of
                    # the populated part of this row.
                    break
                if i == k:
                    assert np.array_equal(yticks, ax.get_yticks())
                else:
                    assert not np.array_equal(yticks, ax.get_yticks())
                if j == k:
                    assert np.array_equal(xticks, ax.get_xticks())
                else:
                    assert not np.array_equal(xticks, ax.get_xticks())
        plt.close("all")
    with pytest.raises(ValueError):
        make_2d_axes(paramnames, upper=upper, ticks='spam')
Esempio n. 6
0
def test_2d_axes_limits():
    """Axis limits set on one panel propagate along shared rows/columns.

    ``axes[col][row]`` selects the panel in column ``col`` and row ``row``.
    After setting its x-limits, every panel in column ``col`` must report
    them as x-limits, and every panel in row ``col`` must report them as
    y-limits. The same check is done for its y-limits with ``row``.
    """
    numpy.random.seed(0)
    paramnames = ['A', 'B', 'C', 'D']
    fig, axes = make_2d_axes(paramnames)
    for col in paramnames:
        for row in paramnames:
            xmin, xmax, ymin, ymax = numpy.random.rand(4)

            axes[col][row].set_xlim(xmin, xmax)
            for other in paramnames:
                assert axes[col][other].get_xlim() == (xmin, xmax)
                assert axes[other][col].get_ylim() == (xmin, xmax)

            axes[col][row].set_ylim(ymin, ymax)
            for other in paramnames:
                assert axes[row][other].get_xlim() == (ymin, ymax)
                assert axes[other][row].get_ylim() == (ymin, ymax)
Esempio n. 7
0
def test_make_2d_axes_inputs_outputs():
    """Check return types, labels and the fig/subplot_spec/kwargs options.

    - The result is a (Figure, DataFrame) pair whose index holds the y
      parameters and whose columns hold the x parameters.
    - Only the first column carries y-labels and only the last row x-labels.
    - ``fig=`` reuses a given figure, ``subplot_spec=`` places the grid in a
      gridspec cell, and unknown keyword arguments raise TypeError.
    All figures created here are closed before returning.
    """
    paramnames_x = ['A', 'B', 'C', 'D']
    paramnames_y = ['B', 'A', 'D', 'E']

    # 2D axes
    fig, axes = make_2d_axes([paramnames_x, paramnames_y])
    assert(isinstance(fig, Figure))
    assert(isinstance(axes, DataFrame))
    assert_array_equal(axes.index, paramnames_y)
    assert_array_equal(axes.columns, paramnames_x)

    # Axes labels (``Series.iteritems`` was removed in pandas 2.0)
    for p, ax in axes.iloc[:, 0].items():
        assert(ax.get_ylabel() == p)

    for p, ax in axes.iloc[-1].items():
        assert(ax.get_xlabel() == p)

    for ax in axes.iloc[:-1, 1:].values.flatten():
        assert(ax.get_xlabel() == '')
        assert(ax.get_ylabel() == '')

    # Check fig argument
    fig0 = plt.figure()
    fig, axes = make_2d_axes(paramnames_x)
    assert(fig is not fig0)
    fig, axes = make_2d_axes(paramnames_x, fig=fig0)
    assert(fig is fig0)
    plt.close("all")

    # Check gridspec argument
    grid = gs.GridSpec(2, 2, width_ratios=[3, 1], height_ratios=[3, 1])
    g00 = grid[0, 0]
    fig, axes = make_2d_axes(paramnames_x, subplot_spec=g00)
    assert(g00 is axes.iloc[0, 0].get_subplotspec().get_topmost_subplotspec())

    # Check unexpected kwargs
    with pytest.raises(TypeError):
        make_2d_axes(paramnames_x, foo='bar')

    # Don't leak the figures created above into later tests.
    plt.close("all")
Esempio n. 8
0
    def plot_2d(self, axes, *args, **kwargs):
        """Create an array of 2D plots.

        To avoid interfering with y-axis sharing, one-dimensional plots are
        created on a separate axis, which is monkey-patched onto the argument
        ax as the attribute ax.twin.

        Parameters
        ----------
        axes: plotting axes
            Can be:
                - list(str) if the x and y axes are the same
                - [list(str),list(str)] if the x and y axes are different
                - pandas.DataFrame(matplotlib.axes.Axes)
            If a pandas.DataFrame is provided as an existing set of axes, then
            this is used for creating the plot. Otherwise a new set of axes are
            created using the list or lists of strings.

        types: dict, optional
            What type (or types) of plots to produce. Takes the keys 'diagonal'
            for the 1D plots and 'lower' and 'upper' for the 2D plots.
            The options for 'diagonal' are:
                - 'kde'
                - 'hist'
                - 'astropyhist'
            The options for 'lower' and 'upper' are:
                - 'kde'
                - 'scatter'
                - 'hist'
                - 'fastkde'
            Default: {'diagonal': 'kde', 'lower': 'kde', 'upper':'scatter'}

        diagonal_kwargs, lower_kwargs, upper_kwargs: dict, optional
            kwargs for the diagonal (1D)/lower or upper (2D) plots. This is
            useful when there is a conflict of kwargs for different types of
            plots.  Note that any kwargs directly passed to plot_2d will
            overwrite any kwarg with the same key passed to <sub>_kwargs.
            The dictionaries supplied by the caller are not modified.
            Default: {}

        Returns
        -------
        fig: matplotlib.figure.Figure
            New or original (if supplied) figure object

        axes: pandas.DataFrame of matplotlib.axes.Axes
            Pandas array of axes objects

        Raises
        ------
        ValueError
            If a pandas.DataFrame is supplied as ``axes`` but it contains no
            axes at all.

        """
        default_types = {'diagonal': 'kde', 'lower': 'kde', 'upper': 'scatter'}
        types = kwargs.pop('types', default_types)

        # Pop every '<pos>_kwargs' entry first, so none of them ends up in
        # the shared kwargs that are merged in below.
        local_kwargs = {
            pos: kwargs.pop('%s_kwargs' % pos, {})
            for pos in default_types
        }
        # Build new dicts instead of calling update() in place: the
        # originals may belong to the caller. Direct kwargs take precedence.
        local_kwargs = {
            pos: {**pos_kwargs, **kwargs}
            for pos, pos_kwargs in local_kwargs.items()
        }

        if not isinstance(axes, pandas.DataFrame):
            fig, axes = make_2d_axes(axes,
                                     tex=self.tex,
                                     upper=('upper' in types),
                                     lower=('lower' in types),
                                     diagonal=('diagonal' in types))
        else:
            # Take the figure from the first panel that actually exists.
            # Any row or column may be empty (e.g. without 'diagonal' and
            # 'lower' panels the first column can hold no axes at all).
            present = (ax for ax in axes.to_numpy().ravel()
                       if pandas.notna(ax))
            first = next(present, None)
            if first is None:
                raise ValueError("axes contains no axes to plot on")
            fig = first.figure

        for y, row in axes.iterrows():
            for x, ax in row.items():
                if ax is not None:
                    pos = ax.position
                    # Diagonal (1D) plots go on the twin axis, see above.
                    ax_ = ax.twin if x == y else ax
                    plot_type = types.get(pos, None)
                    lkwargs = local_kwargs.get(pos, {})
                    self.plot(ax_, x, y, plot_type=plot_type, *args, **lkwargs)

        return fig, axes
Esempio n. 9
0
 def __init__(self, fig, gridspec):
     """Run the base initialiser, then replace its axes with an empty grid.

     NOTE(review): assumes the base-class ``__init__`` provides ``self.ax``,
     ``self.fig`` and ``self.gridspec``; all three are read below.
     """
     super().__init__(fig, gridspec)
     # The axes made by the base class are dropped from the figure and
     # replaced by a make_2d_axes grid with no parameters yet.
     old_ax = self.ax
     self.fig.delaxes(old_ax)
     _fig, self.ax = make_2d_axes([], fig=self.fig,
                                  subplot_spec=self.gridspec)
Esempio n. 10
0
    def plot_2d(self, axes, *args, **kwargs):
        """Create an array of 2D plots.

        To avoid interfering with y-axis sharing, one-dimensional plots are
        created on a separate axis, which is monkey-patched onto the argument
        ax as the attribute ax.twin.

        Parameters
        ----------
        axes: plotting axes
            Can be:
                - list(str) if the x and y axes are the same
                - [list(str),list(str)] if the x and y axes are different
                - pandas.DataFrame(matplotlib.axes.Axes)
            If a pandas.DataFrame is provided as an existing set of axes, then
            this is used for creating the plot. Otherwise a new set of axes are
            created using the list or lists of strings.

        types: dict, optional
            What type (or types) of plots to produce. Takes the keys 'diagonal'
            for the 1D plots and 'lower' and 'upper' for the 2D plots.
            The options for 'diagonal' are:
                - 'kde'
                - 'hist'
                - 'astropyhist'
            The options for 'lower' and 'upper' are:
                - 'kde'
                - 'scatter'
                - 'hist'
                - 'fastkde'
            Default: {'diagonal': 'kde', 'lower': 'kde', 'upper':'scatter'}
            Deprecated: a str or list(str) is still accepted (with a
            FutureWarning). A str becomes {'lower': types}; a list becomes
            {'lower': types[0], 'upper': types[-1]}. In both cases the
            'diagonal' entry copies 'lower' unless ``diagonal=False`` is
            passed.

        diagonal_kwargs, lower_kwargs, upper_kwargs: dict, optional
            kwargs for the diagonal (1D)/lower or upper (2D) plots. This is
            useful when there is a conflict of kwargs for different types of
            plots.  Note that any kwargs directly passed to plot_2d will
            overwrite any kwarg with the same key passed to <sub>_kwargs.
            The dictionaries supplied by the caller are not modified.
            Default: {}

        Returns
        -------
        fig: matplotlib.figure.Figure
            New or original (if supplied) figure object

        axes: pandas.DataFrame of matplotlib.axes.Axes
            Pandas array of axes objects

        Raises
        ------
        ValueError
            If a pandas.DataFrame is supplied as ``axes`` but it contains no
            axes at all.

        """
        default_types = {'diagonal': 'kde', 'lower': 'kde', 'upper': 'scatter'}
        types = kwargs.pop('types', default_types)
        # Only consulted by the legacy str/list conversion below.
        diagonal = kwargs.pop('diagonal', True)
        if isinstance(types, (list, str)):
            from warnings import warn
            warn(
                "MCMCSamples.plot_2d's argument 'types' might stop accepting "
                "str or list(str) as input in the future. It takes a "
                "dictionary as input, now, with keys 'diagonal' for the 1D "
                "plots and 'lower' and 'upper' for the 2D plots. 'diagonal' "
                "accepts the values 'kde' or 'hist' and both 'lower' and "
                "'upper' accept the values 'kde' or 'scatter'. "
                "Default: {'diagonal': 'kde', 'lower': 'kde'}.", FutureWarning,
                stacklevel=2)  # point the warning at the caller

            if isinstance(types, str):
                types = {'lower': types}
                if diagonal:
                    types['diagonal'] = types['lower']
            elif isinstance(types, list):
                types = {'lower': types[0], 'upper': types[-1]}
                if diagonal:
                    types['diagonal'] = types['lower']

        # Pop every '<pos>_kwargs' entry first, so none of them ends up in
        # the shared kwargs that are merged in below.
        local_kwargs = {
            pos: kwargs.pop('%s_kwargs' % pos, {})
            for pos in default_types
        }
        # Build new dicts instead of calling update() in place: the
        # originals may belong to the caller. Direct kwargs take precedence.
        local_kwargs = {
            pos: {**pos_kwargs, **kwargs}
            for pos, pos_kwargs in local_kwargs.items()
        }

        if not isinstance(axes, pandas.DataFrame):
            fig, axes = make_2d_axes(axes,
                                     tex=self.tex,
                                     upper=('upper' in types),
                                     lower=('lower' in types),
                                     diagonal=('diagonal' in types))
        else:
            # Take the figure from the first panel that actually exists;
            # empty cells may hold None or NaN.
            present = (ax for ax in axes.to_numpy().ravel()
                       if pandas.notna(ax))
            first = next(present, None)
            if first is None:
                raise ValueError("axes contains no axes to plot on")
            fig = first.figure

        for y, row in axes.iterrows():
            for x, ax in row.items():
                if ax is not None:
                    pos = ax.position
                    # Diagonal (1D) plots go on the twin axis, see above.
                    ax_ = ax.twin if x == y else ax
                    plot_type = types.get(pos, None)
                    lkwargs = local_kwargs.get(pos, {})
                    self.plot(ax_, x, y, plot_type=plot_type, *args, **lkwargs)

        return fig, axes