def test_nest_level():
    """Check ``nest_level`` on a scalar, on flat lists and on nested lists."""
    # (argument, expected nesting depth), evaluated in this order.
    cases = [
        (0, 0),
        ([], 1),
        (['a'], 1),
        (['a', 'b'], 1),
        ([['a'], 'b'], 2),
        (['a', ['b']], 2),
        ([['a'], ['b']], 2),
    ]
    for argument, expected in cases:
        assert (nest_level(argument) == expected)
def make_2d_axes(params, **kwargs):
    """Create a set of axes for plotting 2D marginalised posteriors.

    Parameters
    ----------
    params: lists of parameters
        Can be either:
        * list(str) if the x and y axes are the same
        * [list(str),list(str)] if the x and y axes are different
        Strings indicate the names of the parameters

    tex: dict(str:str), optional
        Dictionary mapping params to tex plot labels.
        Default: params

    upper, lower, diagonal: logical, optional
        Whether to create 2D marginalised plots above or below the
        diagonal, or to create a 1D marginalised plot on the diagonal.
        Default: True

    fig: matplotlib.figure.Figure, optional
        Figure to plot on.
        Default: matplotlib.pyplot.figure()

    subplot_spec: matplotlib.gridspec.GridSpec, optional
        gridspec to plot array as part of a subfigure.
        Default: None

    Returns
    -------
    fig: matplotlib.figure.Figure
        New or original (if supplied) figure object

    axes: pandas.DataFrame(matplotlib.axes.Axes)
        Pandas array of axes objects

    Raises
    ------
    TypeError
        If any keyword argument other than those listed above is supplied.

    """
    if nest_level(params) == 2:
        xparams, yparams = params
    else:
        xparams = yparams = params

    upper = kwargs.pop('upper', True)
    lower = kwargs.pop('lower', True)
    diagonal = kwargs.pop('diagonal', True)

    # Rows are y-parameters, columns are x-parameters.
    # NOTE(review): the chained ``axes[x][y] = ...`` writes below rely on
    # pandas handing back a view of the column; confirm this still holds if
    # the project moves to pandas copy-on-write.
    axes = pandas.DataFrame(index=numpy.atleast_1d(yparams),
                            columns=numpy.atleast_1d(xparams),
                            dtype=object)
    axes[:][:] = None

    # Position of the first occurrence of every parameter in (x-params,
    # y-params); comparing these decides lower / upper / diagonal.
    all_params = list(axes.columns) + list(axes.index)
    first_pos = {}
    for n, p in enumerate(all_params):
        first_pos.setdefault(p, n)

    # First pass: mark each cell with -1 (lower), +1 (upper) or 0 (diagonal);
    # cells that are not requested stay None.
    for y in axes.index:
        for x in axes.columns:
            if first_pos[x] < first_pos[y]:
                if lower:
                    axes[x][y] = -1
            elif first_pos[x] > first_pos[y]:
                if upper:
                    axes[x][y] = +1
            elif diagonal:
                axes[x][y] = 0

    # Remove rows and columns in which no plot was requested.
    axes.dropna(axis=0, how='all', inplace=True)
    axes.dropna(axis=1, how='all', inplace=True)

    tex = kwargs.pop('tex', {})
    tex = {p: tex[p] if p in tex else p for p in all_params}
    # Only create a new figure when none was supplied.
    fig = kwargs.pop('fig') if 'fig' in kwargs else plt.figure()
    spec = kwargs.pop('subplot_spec', None)

    if kwargs:
        raise TypeError('Unexpected **kwargs: %r' % kwargs)

    if axes.size == 0:
        return fig, axes

    # The grid is built only once the layout is known to be non-empty, so
    # that a zero-sized grid specification is never requested.
    if spec is not None:
        grid = SGS(*axes.shape, hspace=0, wspace=0, subplot_spec=spec)
    else:
        grid = GS(*axes.shape, hspace=0, wspace=0)

    # Second pass: replace the position markers by actual subplots.
    position = axes.copy()
    axes[:][:] = None
    for j, y in enumerate(axes.index):
        for i, x in enumerate(axes.columns):
            if position[x][y] is None:
                continue

            # Share x with the first axes already present in this column and
            # y with the first axes already present in this row, if any.
            sx = list(axes[x].dropna())
            sx = sx[0] if sx else None
            sy = list(axes.T[y].dropna())
            sy = sy[0] if sy else None
            axes[x][y] = fig.add_subplot(grid[j, i], sharex=sx, sharey=sy)
            ax = axes[x][y]

            if position[x][y] == 0:
                # Diagonal: a twin y-axis with fixed limits and no ticks is
                # attached, the parent is raised above it, and the parent's
                # artist lists are aliased to those of the twin.
                ax.twin = ax.twinx()
                ax.twin.set_yticks([])
                ax.twin.set_ylim(0, 1.1)
                ax.set_zorder(ax.twin.get_zorder() + 1)
                ax.lines = ax.twin.lines
                ax.patches = ax.twin.patches
                ax.collections = ax.twin.collections
                ax.containers = ax.twin.containers
                make_diagonal(ax)
                ax.position = 'diagonal'
            elif position[x][y] == 1:
                ax.position = 'upper'
            elif position[x][y] == -1:
                ax.position = 'lower'

    # First axes of every row and last axes of every column.
    left_axes = axes.bfill(axis=1).iloc[:, 0].dropna()
    bottom_axes = axes.ffill(axis=0).iloc[-1, :].dropna()

    for y, ax in left_axes.items():
        ax.set_ylabel(tex[y])

    for x, ax in bottom_axes.items():
        ax.set_xlabel(tex[x])

    # Hide y ticks on everything but the first axes of each row ...
    for y, ax in axes.iterrows():
        ax_ = ax.dropna()
        if len(ax_):
            for a in ax_[1:]:
                a.tick_params('y', left=False, labelleft=False)

    # ... and x ticks on everything but the last axes of each column.
    for x, ax in axes.items():
        ax_ = ax.dropna()
        if len(ax_):
            for a in ax_[:-1]:
                a.tick_params('x', bottom=False, labelbottom=False)

    for y, ax in left_axes.items():
        ax.yaxis.set_major_locator(MaxNLocator(3, prune='both'))

    for x, ax in bottom_axes.items():
        ax.xaxis.set_major_locator(MaxNLocator(3, prune='both'))

    return fig, axes
def make_2d_axes(params, **kwargs):
    """Create a set of axes for plotting 2D marginalised posteriors.

    Parameters
    ----------
    params: lists of parameters
        Can be either:
        * list(str) if the x and y axes are the same
        * [list(str),list(str)] if the x and y axes are different
        Strings indicate the names of the parameters

    tex: dict(str:str), optional
        Dictionary mapping params to tex plot labels.
        Default: params

    upper, lower, diagonal: logical, optional
        Whether to create 2D marginalised plots above or below the
        diagonal, or to create a 1D marginalised plot on the diagonal.
        Default: True

    fig: matplotlib.figure.Figure, optional
        Figure to plot on.
        Default: matplotlib.pyplot.figure()

    ticks: str
        If 'outer', plot ticks only on the very left and very bottom.
        If 'inner', plot ticks also in inner subplots.
        If None, plot no ticks at all.
        Default: 'outer'

    subplot_spec: matplotlib.gridspec.GridSpec, optional
        gridspec to plot array as part of a subfigure.
        Default: None

    Returns
    -------
    fig: matplotlib.figure.Figure
        New or original (if supplied) figure object

    axes: pandas.DataFrame(matplotlib.axes.Axes)
        Pandas array of axes objects

    Raises
    ------
    TypeError
        If any keyword argument other than those listed above is supplied.
    ValueError
        If there are axes to create and `ticks` is not one of
        'outer', 'inner' or None.

    """
    if nest_level(params) == 2:
        xparams, yparams = params
    else:
        xparams = yparams = params

    ticks = kwargs.pop('ticks', 'outer')
    upper = kwargs.pop('upper', True)
    lower = kwargs.pop('lower', True)
    diagonal = kwargs.pop('diagonal', True)

    # Rows are y-parameters, columns are x-parameters.
    # NOTE(review): the chained ``axes[x][y] = ...`` writes below rely on
    # the frame handing back a view of the column; confirm this still holds
    # if the project moves to pandas copy-on-write.
    axes = AxesDataFrame(index=np.atleast_1d(yparams),
                         columns=np.atleast_1d(xparams),
                         dtype=object)
    axes[:][:] = None

    # Position of the first occurrence of every parameter in (x-params,
    # y-params); comparing these decides lower / upper / diagonal.
    all_params = list(axes.columns) + list(axes.index)
    first_pos = {}
    for n, p in enumerate(all_params):
        first_pos.setdefault(p, n)

    # First pass: mark each cell with -1 (lower), +1 (upper) or 0 (diagonal);
    # cells that are not requested stay None.
    for y in axes.index:
        for x in axes.columns:
            if first_pos[x] < first_pos[y]:
                if lower:
                    axes[x][y] = -1
            elif first_pos[x] > first_pos[y]:
                if upper:
                    axes[x][y] = +1
            elif diagonal:
                axes[x][y] = 0

    # Remove rows and columns in which no plot was requested.
    axes.dropna(axis=0, how='all', inplace=True)
    axes.dropna(axis=1, how='all', inplace=True)

    tex = kwargs.pop('tex', {})
    tex = {p: tex[p] if p in tex else p for p in all_params}
    # Only create a new figure when none was supplied.
    fig = kwargs.pop('fig') if 'fig' in kwargs else plt.figure()
    spec = kwargs.pop('subplot_spec', None)
    if axes.shape[0] != 0 and axes.shape[1] != 0:
        if spec is not None:
            grid = SGS(*axes.shape, hspace=0, wspace=0, subplot_spec=spec)
        else:
            grid = GS(*axes.shape, hspace=0, wspace=0)

    if kwargs:
        raise TypeError('Unexpected **kwargs: %r' % kwargs)

    if axes.size == 0:
        return fig, axes

    # Reject an invalid `ticks` before any subplot is added, so a bad call
    # does not leave a populated figure behind.
    if ticks is not None and ticks not in ('outer', 'inner'):
        raise ValueError(
            "ticks=%s was requested, but ticks can only be one of "
            "['outer', 'inner', None]." % (ticks,))

    # Second pass: replace the position markers by actual subplots, visiting
    # the rows from the last to the first.
    position = axes.copy()
    axes[:][:] = None
    for j, y in reversed(list(enumerate(axes.index))):
        for i, x in enumerate(axes.columns):
            if position[x][y] is None:
                continue

            # Share x with the first axes already present in this column and
            # y with the first axes already present in this row, if any.
            sx = list(axes[x].dropna())
            sx = sx[0] if sx else None
            sy = list(axes.T[y].dropna())
            sy = sy[0] if sy else None
            axes[x][y] = fig.add_subplot(grid[j, i], sharex=sx, sharey=sy)
            ax = axes[x][y]

            if position[x][y] == 0:
                # Diagonal: a twin y-axis with fixed limits and no ticks is
                # attached, the parent is raised above it, and the parent's
                # artist lists are aliased to those of the twin.
                ax.twin = ax.twinx()
                ax.twin.set_yticks([])
                ax.twin.set_ylim(0, 1.1)
                ax.set_zorder(ax.twin.get_zorder() + 1)
                ax.lines = ax.twin.lines
                ax.patches = ax.twin.patches
                ax.collections = ax.twin.collections
                ax.containers = ax.twin.containers
                make_diagonal(ax)
                ax.position = 'diagonal'
                ax.twin.xaxis.set_major_locator(
                    MaxNLocator(3, prune='both'))
            else:
                if position[x][y] == 1:
                    ax.position = 'upper'
                elif position[x][y] == -1:
                    ax.position = 'lower'
                ax.yaxis.set_major_locator(
                    MaxNLocator(3, prune='both'))
            ax.xaxis.set_major_locator(
                MaxNLocator(3, prune='both'))

    for y, ax in axes.bfill(axis=1).iloc[:, 0].dropna().items():
        ax.set_ylabel(tex[y])

    for x, ax in axes.ffill(axis=0).iloc[-1, :].dropna().items():
        ax.set_xlabel(tex[x])

    # left and right ticks and labels
    for y, ax in axes.iterrows():
        ax_ = ax.dropna()
        if not len(ax_):
            continue
        if ticks == 'inner':
            for i, a in enumerate(ax_):
                if i == 0:  # first column
                    if a.position == 'diagonal' and len(ax_) == 1:
                        a.tick_params('y', left=False, labelleft=False)
                    else:
                        a.tick_params('y', left=True, labelleft=True)
                elif a.position == 'diagonal':  # not first column
                    # Half-length outward ticks on inner diagonal axes.
                    tl = a.yaxis.majorTicks[0].tick1line.get_markersize()
                    a.tick_params('y', direction='out', length=tl/2,
                                  left=True, labelleft=False)
                else:  # not diagonal and not first column
                    a.tick_params('y', direction='inout',
                                  left=True, labelleft=False)
        elif ticks == 'outer':  # no inner ticks
            for a in ax_[1:]:
                a.tick_params('y', left=False, labelleft=False)
        else:  # ticks is None: no ticks at all
            for a in ax_:
                a.tick_params('y', left=False, right=False,
                              labelleft=False, labelright=False)

    # bottom and top ticks and labels
    for x, ax in axes.items():
        ax_ = ax.dropna()
        if not len(ax_):
            continue
        if ticks == 'inner':
            for i, a in enumerate(ax_):
                if i == len(ax_) - 1:  # bottom row
                    a.tick_params('x', bottom=True, labelbottom=True)
                else:  # not bottom row
                    a.tick_params('x', direction='inout',
                                  bottom=True, labelbottom=False)
                if a.position == 'diagonal':
                    a.twin.tick_params('x', direction='inout',
                                       bottom=True, labelbottom=False)
        elif ticks == 'outer':  # no inner ticks
            for a in ax_[:-1]:
                a.tick_params('x', bottom=False, labelbottom=False)
        else:  # ticks is None: no ticks at all
            for a in ax_:
                a.tick_params('x', bottom=False, top=False,
                              labelbottom=False, labeltop=False)

    return fig, axes