Beispiel #1
0
def test_nest_level():
    # (input, expected depth): a scalar has depth 0, a flat (possibly
    # empty) list has depth 1, and a list holding any list has depth 2.
    cases = [
        (0, 0),
        ([], 1),
        (['a'], 1),
        (['a', 'b'], 1),
        ([['a'], 'b'], 2),
        (['a', ['b']], 2),
        ([['a'], ['b']], 2),
    ]
    for value, expected in cases:
        assert nest_level(value) == expected
Beispiel #2
0
def make_2d_axes(params, **kwargs):
    """Create a set of axes for plotting 2D marginalised posteriors.

    Parameters
    ----------
        params: lists of parameters
            Can be either:
            * list(str) if the x and y axes are the same
            * [list(str),list(str)] if the x and y axes are different
            Strings indicate the names of the parameters

        tex: dict(str:str), optional
            Dictionary mapping params to tex plot labels.
            Default: params

        upper, lower, diagonal: logical, optional
            Whether to create 2D marginalised plots above or below the
            diagonal, or to create a 1D marginalised plot on the diagonal.
            Default: True

        fig: matplotlib.figure.Figure, optional
            Figure to plot on.
            Default: matplotlib.pyplot.figure()

        subplot_spec: matplotlib.gridspec.GridSpec, optional
            gridspec to plot array as part of a subfigure.
            Default: None

    Returns
    -------
    fig: matplotlib.figure.Figure
        New or original (if supplied) figure object

    axes: pandas.DataFrame(matplotlib.axes.Axes)
        Pandas array of axes objects, with y parameters as the index and
        x parameters as the columns. Cells without a plot hold None.
        Every created axis carries a ``position`` attribute ('upper',
        'lower' or 'diagonal'); diagonal axes also carry a ``twin``
        attribute holding their ``twinx`` axis. If no plot is requested
        at all, an empty DataFrame is returned and no subplots are added.

    Raises
    ------
    TypeError
        If unexpected keyword arguments are supplied.
    """
    if nest_level(params) == 2:
        xparams, yparams = params
    else:
        xparams = yparams = params

    upper = kwargs.pop('upper', True)
    lower = kwargs.pop('lower', True)
    diagonal = kwargs.pop('diagonal', True)

    axes = pandas.DataFrame(index=numpy.atleast_1d(yparams),
                            columns=numpy.atleast_1d(xparams),
                            dtype=object)
    # Explicit .iloc/.at setters are used throughout. Chained assignment
    # such as ``axes[x][y] = value`` writes to a temporary under pandas
    # Copy-on-Write and silently leaves ``axes`` unchanged.
    axes.iloc[:, :] = None
    all_params = list(axes.columns) + list(axes.index)

    # Mark each requested cell as lower (-1), upper (+1) or diagonal (0).
    # Parameters are ordered by first appearance in columns, then index.
    for y in axes.index:
        for x in axes.columns:
            if all_params.index(x) < all_params.index(y):
                if lower:
                    axes.at[y, x] = -1
            elif all_params.index(x) > all_params.index(y):
                if upper:
                    axes.at[y, x] = +1
            elif diagonal:
                axes.at[y, x] = 0

    # Discard rows and columns in which no plot was requested.
    axes.dropna(axis=0, how='all', inplace=True)
    axes.dropna(axis=1, how='all', inplace=True)

    tex = kwargs.pop('tex', {})
    tex = {p: tex[p] if p in tex else p for p in all_params}
    fig = kwargs.pop('fig') if 'fig' in kwargs else plt.figure()
    spec = kwargs.pop('subplot_spec', None)

    if kwargs:
        raise TypeError('Unexpected **kwargs: %r' % kwargs)

    # GridSpec rejects zero rows/columns, so return before building one.
    if axes.size == 0:
        return fig, axes

    if spec is not None:
        grid = SGS(*axes.shape, hspace=0, wspace=0, subplot_spec=spec)
    else:
        grid = GS(*axes.shape, hspace=0, wspace=0)

    position = axes.copy()
    axes.iloc[:, :] = None
    for j, y in enumerate(axes.index):
        for i, x in enumerate(axes.columns):
            pos = position.at[y, x]
            if pandas.isna(pos):
                continue
            # Share x with the first already-created axis in this column
            # and y with the first already-created axis in this row.
            sx = list(axes[x].dropna())
            sx = sx[0] if sx else None
            sy = list(axes.loc[y].dropna())
            sy = sy[0] if sy else None
            ax = fig.add_subplot(grid[j, i], sharex=sx, sharey=sy)
            axes.at[y, x] = ax

            if pos == 0:
                ax.twin = ax.twinx()
                ax.twin.set_yticks([])
                ax.twin.set_ylim(0, 1.1)
                # Draw the main axis above its twin.
                ax.set_zorder(ax.twin.get_zorder() + 1)
                # NOTE(review): assigning to lines/patches/collections/
                # containers may fail on newer matplotlib, where these are
                # read-only properties -- confirm the supported version.
                ax.lines = ax.twin.lines
                ax.patches = ax.twin.patches
                ax.collections = ax.twin.collections
                ax.containers = ax.twin.containers
                make_diagonal(ax)
                ax.position = 'diagonal'
            elif pos == 1:
                ax.position = 'upper'
            elif pos == -1:
                ax.position = 'lower'

    # Left-most axis of each row: y label and at most three y ticks.
    for y, ax in axes.bfill(axis=1).iloc[:, 0].dropna().items():
        ax.set_ylabel(tex[y])
        ax.yaxis.set_major_locator(MaxNLocator(3, prune='both'))

    # Bottom-most axis of each column: x label and at most three x ticks.
    for x, ax in axes.ffill(axis=0).iloc[-1, :].dropna().items():
        ax.set_xlabel(tex[x])
        ax.xaxis.set_major_locator(MaxNLocator(3, prune='both'))

    # Only the left-most axis of each row keeps its y ticks and labels.
    for _, row in axes.iterrows():
        for a in row.dropna().iloc[1:]:
            a.tick_params('y', left=False, labelleft=False)

    # Only the bottom-most axis of each column keeps its x ticks and labels.
    for _, col in axes.items():
        for a in col.dropna().iloc[:-1]:
            a.tick_params('x', bottom=False, labelbottom=False)

    return fig, axes
Beispiel #3
0
def make_2d_axes(params, **kwargs):
    """Create a set of axes for plotting 2D marginalised posteriors.

    Parameters
    ----------
        params: lists of parameters
            Can be either:
            * list(str) if the x and y axes are the same
            * [list(str),list(str)] if the x and y axes are different
            Strings indicate the names of the parameters

        tex: dict(str:str), optional
            Dictionary mapping params to tex plot labels.
            Default: params

        upper, lower, diagonal: logical, optional
            Whether to create 2D marginalised plots above or below the
            diagonal, or to create a 1D marginalised plot on the diagonal.
            Default: True

        fig: matplotlib.figure.Figure, optional
            Figure to plot on.
            Default: matplotlib.pyplot.figure()

        ticks: str
            If 'outer', plot ticks only on the very left and very bottom.
            If 'inner', plot ticks also in inner subplots.
            If None, plot no ticks at all.
            Default: 'outer'

        subplot_spec: matplotlib.gridspec.GridSpec, optional
            gridspec to plot array as part of a subfigure.
            Default: None

    Returns
    -------
    fig: matplotlib.figure.Figure
        New or original (if supplied) figure object

    axes: pandas.DataFrame(matplotlib.axes.Axes)
        Pandas array of axes objects, with y parameters as the index and
        x parameters as the columns. Cells without a plot hold None.
        Every created axis carries a ``position`` attribute ('upper',
        'lower' or 'diagonal'); diagonal axes also carry a ``twin``
        attribute holding their ``twinx`` axis. If no plot is requested
        at all, an empty frame is returned and no subplots are added.

    Raises
    ------
    ValueError
        If ``ticks`` is not one of 'outer', 'inner' or None. This check
        runs before any subplot is added to the figure.
    TypeError
        If unexpected keyword arguments are supplied.
    """
    if nest_level(params) == 2:
        xparams, yparams = params
    else:
        xparams = yparams = params

    ticks = kwargs.pop('ticks', 'outer')
    # Validate up front so an invalid value never leaves half-built
    # subplots on the figure.
    if ticks not in ('outer', 'inner', None):
        raise ValueError(
            "ticks=%s was requested, but ticks can only be one of "
            "['outer', 'inner', None]." % (ticks,))
    upper = kwargs.pop('upper', True)
    lower = kwargs.pop('lower', True)
    diagonal = kwargs.pop('diagonal', True)

    axes = AxesDataFrame(index=np.atleast_1d(yparams),
                         columns=np.atleast_1d(xparams),
                         dtype=object)
    # Explicit .iloc/.at setters are used throughout. Chained assignment
    # such as ``axes[x][y] = value`` writes to a temporary under pandas
    # Copy-on-Write and silently leaves ``axes`` unchanged.
    axes.iloc[:, :] = None
    all_params = list(axes.columns) + list(axes.index)

    # Mark each requested cell as lower (-1), upper (+1) or diagonal (0).
    # Parameters are ordered by first appearance in columns, then index.
    for y in axes.index:
        for x in axes.columns:
            if all_params.index(x) < all_params.index(y):
                if lower:
                    axes.at[y, x] = -1
            elif all_params.index(x) > all_params.index(y):
                if upper:
                    axes.at[y, x] = +1
            elif diagonal:
                axes.at[y, x] = 0

    # Discard rows and columns in which no plot was requested.
    axes.dropna(axis=0, how='all', inplace=True)
    axes.dropna(axis=1, how='all', inplace=True)

    tex = kwargs.pop('tex', {})
    tex = {p: tex[p] if p in tex else p for p in all_params}
    fig = kwargs.pop('fig') if 'fig' in kwargs else plt.figure()
    spec = kwargs.pop('subplot_spec', None)

    if kwargs:
        raise TypeError('Unexpected **kwargs: %r' % kwargs)

    # GridSpec rejects zero rows/columns, so return before building one.
    if axes.size == 0:
        return fig, axes

    if spec is not None:
        grid = SGS(*axes.shape, hspace=0, wspace=0, subplot_spec=spec)
    else:
        grid = GS(*axes.shape, hspace=0, wspace=0)

    position = axes.copy()
    axes.iloc[:, :] = None
    nrows = axes.shape[0]
    # Rows are created bottom-up. Each axis shares x with the first
    # already-created axis in its column and y with the first in its row.
    for j, y in enumerate(axes.index[::-1]):
        for i, x in enumerate(axes.columns):
            pos = position.at[y, x]
            if pos is None:
                continue
            sx = list(axes[x].dropna())
            sx = sx[0] if sx else None
            sy = list(axes.loc[y].dropna())
            sy = sy[0] if sy else None
            ax = fig.add_subplot(grid[nrows - 1 - j, i],
                                 sharex=sx, sharey=sy)
            axes.at[y, x] = ax

            if pos == 0:
                ax.twin = ax.twinx()
                ax.twin.set_yticks([])
                ax.twin.set_ylim(0, 1.1)
                # Draw the main axis above its twin.
                ax.set_zorder(ax.twin.get_zorder() + 1)
                # NOTE(review): assigning to lines/patches/collections/
                # containers may fail on newer matplotlib, where these are
                # read-only properties -- confirm the supported version.
                ax.lines = ax.twin.lines
                ax.patches = ax.twin.patches
                ax.collections = ax.twin.collections
                ax.containers = ax.twin.containers
                make_diagonal(ax)
                ax.position = 'diagonal'
                ax.twin.xaxis.set_major_locator(
                    MaxNLocator(3, prune='both'))
            else:
                if pos == 1:
                    ax.position = 'upper'
                elif pos == -1:
                    ax.position = 'lower'
                ax.yaxis.set_major_locator(MaxNLocator(3, prune='both'))
            ax.xaxis.set_major_locator(MaxNLocator(3, prune='both'))

    # The left-most axis of each row gets the y label.
    for y, ax in axes.bfill(axis=1).iloc[:, 0].dropna().items():
        ax.set_ylabel(tex[y])

    # The bottom-most axis of each column gets the x label.
    for x, ax in axes.ffill(axis=0).iloc[-1, :].dropna().items():
        ax.set_xlabel(tex[x])

    # Left and right ticks and labels, row by row.
    for _, row in axes.iterrows():
        row_axes = row.dropna()
        if ticks == 'inner':
            for i, a in enumerate(row_axes):
                if i == 0:  # first column
                    if a.position == 'diagonal' and len(row_axes) == 1:
                        a.tick_params('y', left=False, labelleft=False)
                    else:
                        a.tick_params('y', left=True, labelleft=True)
                elif a.position == 'diagonal':  # not first column
                    tl = a.yaxis.majorTicks[0].tick1line.get_markersize()
                    a.tick_params('y', direction='out', length=tl/2,
                                  left=True, labelleft=False)
                else:  # not diagonal and not first column
                    a.tick_params('y', direction='inout',
                                  left=True, labelleft=False)
        elif ticks == 'outer':  # no inner ticks
            for a in row_axes.iloc[1:]:
                a.tick_params('y', left=False, labelleft=False)
        else:  # ticks is None: no ticks at all
            for a in row_axes:
                a.tick_params('y', left=False, right=False,
                              labelleft=False, labelright=False)

    # Bottom and top ticks and labels, column by column.
    for _, col in axes.items():
        col_axes = col.dropna()
        if ticks == 'inner':
            for i, a in enumerate(col_axes):
                if i == len(col_axes) - 1:  # bottom row
                    a.tick_params('x', bottom=True, labelbottom=True)
                else:  # not bottom row
                    a.tick_params('x', direction='inout',
                                  bottom=True, labelbottom=False)
                    if a.position == 'diagonal':
                        a.twin.tick_params('x', direction='inout',
                                           bottom=True, labelbottom=False)
        elif ticks == 'outer':  # no inner ticks
            for a in col_axes.iloc[:-1]:
                a.tick_params('x', bottom=False, labelbottom=False)
        else:  # ticks is None: no ticks at all
            for a in col_axes:
                a.tick_params('x', bottom=False, top=False,
                              labelbottom=False, labeltop=False)

    return fig, axes