def test_make_2d_axes_behaviour():
    """Check how many upper/lower/diagonal axes make_2d_axes creates.

    Covers a single shared parameter list and several separate x/y parameter
    lists, for every combination of the ``upper``, ``lower`` and
    ``diagonal`` flags.
    """
    np.random.seed(0)

    def calc_n(axes):
        """Compute the number of upper, lower and diagonal plots."""
        n = {'upper': 0, 'lower': 0, 'diagonal': 0}
        # Series.iteritems was removed in pandas 2.0; items() works on all
        # pandas versions.
        for _, row in axes.iterrows():
            for _, ax in row.items():
                if ax is not None:
                    n[ax.position] += 1
        return n

    # Check for only paramnames_x
    paramnames_x = ['A', 'B', 'C', 'D']
    nx = len(paramnames_x)
    for upper in [True, False]:
        for lower in [True, False]:
            for diagonal in [True, False]:
                fig, axes = make_2d_axes(paramnames_x,
                                         upper=upper,
                                         lower=lower,
                                         diagonal=diagonal)
                ns = calc_n(axes)
                assert ns['upper'] == upper * nx * (nx - 1) // 2
                assert ns['lower'] == lower * nx * (nx - 1) // 2
                assert ns['diagonal'] == diagonal * nx
                plt.close("all")

    for paramnames_y in [['A', 'B', 'C', 'D'],
                         ['A', 'C', 'B', 'D'],
                         ['D', 'C', 'B', 'A'],
                         ['C', 'B', 'A'],
                         ['E', 'F', 'G', 'H'],
                         ['A', 'B', 'E', 'F'],
                         ['B', 'E', 'A', 'F'],
                         ['B', 'F', 'A', 'H', 'G'],
                         ['B', 'A', 'H', 'G']]:
        params = [paramnames_x, paramnames_y]
        all_params = paramnames_x + paramnames_y

        # Expected counts: a pair is diagonal when the names match, and
        # otherwise lower/upper depending on which name appears first in
        # the concatenated list paramnames_x + paramnames_y.
        nu, nl, nd = 0, 0, 0
        for x in paramnames_x:
            for y in paramnames_y:
                if x == y:
                    nd += 1
                elif all_params.index(x) < all_params.index(y):
                    nl += 1
                elif all_params.index(x) > all_params.index(y):
                    nu += 1

        for upper in [True, False]:
            for lower in [True, False]:
                for diagonal in [True, False]:
                    fig, axes = make_2d_axes(params,
                                             upper=upper,
                                             lower=lower,
                                             diagonal=diagonal)
                    ns = calc_n(axes)
                    assert ns['upper'] == upper * nu
                    assert ns['lower'] == lower * nl
                    assert ns['diagonal'] == diagonal * nd
                    plt.close("all")
def draw(self, labels, tex=None):
    """Draw a new triangular grid for list of parameters labels.

    Every axis currently held in ``self.ax`` is removed from ``self.fig``.
    That includes the ``twin`` axis of each diagonal cell. A new grid
    (without upper-triangle plots) is then created in ``self.gridspec``,
    and each cell receives an empty line.

    Parameters
    ----------
    labels: list(str)
        labels for the triangular grid.

    tex: dict, optional
        Forwarded unchanged to make_2d_axes as ``tex`` (presumably a mapping
        from label to display string -- confirm against make_2d_axes).
        Default: {}
    """
    if tex is None:
        # Fresh dict per call instead of a shared mutable default.
        tex = {}

    # Remove any existing axes
    for y, row in self.ax.iterrows():
        for x, ax in row.items():
            if ax is not None:
                if x == y:
                    # Diagonal cells carry an extra twin axis for 1D plots.
                    self.fig.delaxes(ax.twin)
                self.fig.delaxes(ax)

    # Set up the axes
    _, self.ax = make_2d_axes(labels, upper=False, tex=tex,
                              fig=self.fig, subplot_spec=self.gridspec)

    # Plot an empty (no data points) line on each axis: a solid line on
    # the diagonal twin axes, point markers on the 2D axes.
    for y, row in self.ax.iterrows():
        for x, ax in row.items():
            if ax is not None:
                if x == y:
                    ax.twin.plot([None], [None], 'k-')
                else:
                    ax.plot([None], [None], 'k.')
def test_2d_axlines_axspans(axesparams, params, values, upper):
    """axlines and axspans run on a 2D grid for the given params/values."""
    vals = np.array(values)
    grid_axes = make_2d_axes(axesparams, upper=upper)[1]

    # Dashed black guide lines at each value ...
    grid_axes.axlines(params, vals, c='k', ls='--', lw=0.5)
    # ... and semi-transparent bands of half-width 0.1 around them.
    grid_axes.axspans(params, vals - 0.1, vals + 0.1, c='k', alpha=0.5)

    plt.close("all")
def test_2d_axlines_axspans_error(params, values):
    """Invalid params/values passed to axlines/axspans raise ValueError."""
    vals = np.array(values)
    grid_axes = make_2d_axes(['A', 'B', 'C', 'D'])[1]

    with pytest.raises(ValueError):
        grid_axes.axlines(params, vals)

    with pytest.raises(ValueError):
        grid_axes.axspans(params, vals - 0.1, vals + 0.1)

    plt.close("all")
def test_make_2d_axes_ticks(upper, ticks):
    """Ticks set on a diagonal axis propagate along its row and column.

    For each parameter ``k``, the y ticks set on ``axes[k][k]`` must appear
    on every axis in row ``k``, and only there. Likewise, its x ticks must
    appear on every axis in column ``k``, and only there. An unrecognised
    ``ticks`` value must raise ValueError.
    """
    xticks = [0.1, 0.4, 0.7]
    yticks = [0.2, 0.5, 0.8]
    paramnames = ["x0", "x1", "x2", "x3"]
    for k in paramnames:
        fig, axes = make_2d_axes(paramnames, upper=upper, ticks=ticks)
        axes[k][k].set_xticks(xticks)
        axes[k][k].set_yticks(yticks)
        for i, row in axes.iterrows():
            # Series.iteritems was removed in pandas 2.0; use items().
            for j, ax in row.items():
                if ax is None:
                    # Stop scanning this row at the first missing axis
                    # (presumably the rest of the row is empty as well).
                    break
                if i == k:
                    assert np.array_equal(yticks, ax.get_yticks())
                else:
                    assert not np.array_equal(yticks, ax.get_yticks())
                if j == k:
                    assert np.array_equal(xticks, ax.get_xticks())
                else:
                    assert not np.array_equal(xticks, ax.get_xticks())
        plt.close("all")

    with pytest.raises(ValueError):
        make_2d_axes(paramnames, upper=upper, ticks='spam')
def test_2d_axes_limits():
    """Axis limits of one parameter are shared across its row and column."""
    numpy.random.seed(0)
    names = ['A', 'B', 'C', 'D']
    _, axes = make_2d_axes(names)
    for col in names:
        for row in names:
            x_lo, x_hi, y_lo, y_hi = numpy.random.rand(4)
            target = axes[col][row]

            # New x-limits for parameter `col` must show up on the x-axis of
            # every cell in column `col` and on the y-axis of every cell in
            # row `col`.
            target.set_xlim(x_lo, x_hi)
            for other in names:
                assert axes[col][other].get_xlim() == (x_lo, x_hi)
                assert axes[other][col].get_ylim() == (x_lo, x_hi)

            # The same holds for y-limits of parameter `row`.
            target.set_ylim(y_lo, y_hi)
            for other in names:
                assert axes[row][other].get_xlim() == (y_lo, y_hi)
                assert axes[other][row].get_ylim() == (y_lo, y_hi)
def test_make_2d_axes_inputs_outputs():
    """Check make_2d_axes return types, labels and fig/gridspec/kwargs.

    Verifies the returned Figure and DataFrame, and the index/column
    ordering for separate x and y parameter lists. Also checks that only
    the outer axes are labelled. Finally, checks that ``fig`` and
    ``subplot_spec`` are honoured and that unknown keyword arguments raise
    TypeError.
    """
    paramnames_x = ['A', 'B', 'C', 'D']
    paramnames_y = ['B', 'A', 'D', 'E']

    # 2D axes
    fig, axes = make_2d_axes([paramnames_x, paramnames_y])
    assert isinstance(fig, Figure)
    assert isinstance(axes, DataFrame)
    assert_array_equal(axes.index, paramnames_y)
    assert_array_equal(axes.columns, paramnames_x)

    # Axes labels (Series.iteritems was removed in pandas 2.0; use items())
    for p, ax in axes.iloc[:, 0].items():
        assert ax.get_ylabel() == p
    for p, ax in axes.iloc[-1].items():
        assert ax.get_xlabel() == p
    for ax in axes.iloc[:-1, 1:].values.flatten():
        assert ax.get_xlabel() == ''
        assert ax.get_ylabel() == ''

    # Check fig argument
    fig0 = plt.figure()
    fig, axes = make_2d_axes(paramnames_x)
    assert fig is not fig0
    fig, axes = make_2d_axes(paramnames_x, fig=fig0)
    assert fig is fig0
    plt.close("all")

    # Check gridspec argument
    grid = gs.GridSpec(2, 2, width_ratios=[3, 1], height_ratios=[3, 1])
    g00 = grid[0, 0]
    fig, axes = make_2d_axes(paramnames_x, subplot_spec=g00)
    assert g00 is axes.iloc[0, 0].get_subplotspec().get_topmost_subplotspec()

    # Check unexpected kwargs
    with pytest.raises(TypeError):
        make_2d_axes(paramnames_x, foo='bar')

    # Do not leak the figures created above into later tests.
    plt.close("all")
def plot_2d(self, axes, *args, **kwargs):
    """Create an array of 2D plots.

    To avoid interfering with y-axis sharing, one-dimensional plots are
    created on a separate axis, which is monkey-patched onto the argument
    ax as the attribute ax.twin.

    Parameters
    ----------
    axes: plotting axes
        Can be:
            - list(str) if the x and y axes are the same
            - [list(str),list(str)] if the x and y axes are different
            - pandas.DataFrame(matplotlib.axes.Axes)
        If a pandas.DataFrame is provided as an existing set of axes, then
        this is used for creating the plot. Otherwise a new set of axes are
        created using the list or lists of strings.

    types: dict, optional
        What type (or types) of plots to produce. Takes the keys 'diagonal'
        for the 1D plots and 'lower' and 'upper' for the 2D plots.
        The options for 'diagonal' are:
            - 'kde'
            - 'hist'
            - 'astropyhist'
        The options for 'lower' and 'upper' are:
            - 'kde'
            - 'scatter'
            - 'hist'
            - 'fastkde'
        Default: {'diagonal': 'kde', 'lower': 'kde', 'upper':'scatter'}

    diagonal_kwargs, lower_kwargs, upper_kwargs: dict, optional
        kwargs for the diagonal (1D)/lower or upper (2D) plots. This is
        useful when there is a conflict of kwargs for different types of
        plots. Note that any kwargs directly passed to plot_2d will
        overwrite any kwarg with the same key passed to <sub>_kwargs.
        The dictionaries passed in are not modified.
        Default: {}

    Returns
    -------
    fig: matplotlib.figure.Figure
        New or original (if supplied) figure object

    axes: pandas.DataFrame of matplotlib.axes.Axes
        Pandas array of axes objects
    """
    default_types = {'diagonal': 'kde', 'lower': 'kde', 'upper': 'scatter'}
    types = kwargs.pop('types', default_types)

    # Copy each <pos>_kwargs dict so that merging in the shared kwargs below
    # does not mutate dictionaries owned by the caller.
    local_kwargs = {pos: dict(kwargs.pop('%s_kwargs' % pos, {}))
                    for pos in default_types}
    for pos in local_kwargs:
        local_kwargs[pos].update(kwargs)

    if not isinstance(axes, pandas.DataFrame):
        fig, axes = make_2d_axes(axes, tex=self.tex,
                                 upper=('upper' in types),
                                 lower=('lower' in types),
                                 diagonal=('diagonal' in types))
    else:
        # Take the figure from the first axis that actually exists. Empty
        # cells are missing values, and an entire column may be empty
        # (e.g. upper-only grids), so cell [0, 0] cannot be relied upon.
        present = axes.to_numpy()[axes.notna().to_numpy()]
        fig = present[0].figure

    for y, row in axes.iterrows():
        for x, ax in row.items():
            if ax is not None:
                pos = ax.position
                # 1D (diagonal) plots are drawn on the twin axis.
                ax_ = ax.twin if x == y else ax
                plot_type = types.get(pos, None)
                lkwargs = local_kwargs.get(pos, {})
                self.plot(ax_, x, y, *args, plot_type=plot_type, **lkwargs)
    return fig, axes
def __init__(self, fig, gridspec):
    super().__init__(fig, gridspec)
    # Remove the axis held in self.ax (presumably created by the parent
    # initialiser) and replace it with an empty 2D axes grid occupying the
    # same gridspec.
    placeholder = self.ax
    self.fig.delaxes(placeholder)
    _unused_fig, grid_axes = make_2d_axes([], fig=self.fig,
                                          subplot_spec=self.gridspec)
    self.ax = grid_axes
def plot_2d(self, axes, *args, **kwargs):
    """Create an array of 2D plots.

    To avoid interfering with y-axis sharing, one-dimensional plots are
    created on a separate axis, which is monkey-patched onto the argument
    ax as the attribute ax.twin.

    Parameters
    ----------
    axes: plotting axes
        Can be:
            - list(str) if the x and y axes are the same
            - [list(str),list(str)] if the x and y axes are different
            - pandas.DataFrame(matplotlib.axes.Axes)
        If a pandas.DataFrame is provided as an existing set of axes, then
        this is used for creating the plot. Otherwise a new set of axes are
        created using the list or lists of strings.

    types: dict, optional
        What type (or types) of plots to produce. Takes the keys 'diagonal'
        for the 1D plots and 'lower' and 'upper' for the 2D plots.
        The options for 'diagonal' are:
            - 'kde'
            - 'hist'
            - 'astropyhist'
        The options for 'lower' and 'upper' are:
            - 'kde'
            - 'scatter'
            - 'hist'
            - 'fastkde'
        Default: {'diagonal': 'kde', 'lower': 'kde', 'upper':'scatter'}
        Deprecated (emits FutureWarning): a str sets 'lower'; a list sets
        'lower' from its first and 'upper' from its last element. In both
        cases 'diagonal' copies 'lower' unless ``diagonal=False``.

    diagonal: bool, optional
        Only consulted for the deprecated str/list form of ``types``:
        whether to also produce the diagonal plots.
        Default: True

    diagonal_kwargs, lower_kwargs, upper_kwargs: dict, optional
        kwargs for the diagonal (1D)/lower or upper (2D) plots. This is
        useful when there is a conflict of kwargs for different types of
        plots. Note that any kwargs directly passed to plot_2d will
        overwrite any kwarg with the same key passed to <sub>_kwargs.
        The dictionaries passed in are not modified.
        Default: {}

    Returns
    -------
    fig: matplotlib.figure.Figure
        New or original (if supplied) figure object

    axes: pandas.DataFrame of matplotlib.axes.Axes
        Pandas array of axes objects
    """
    default_types = {'diagonal': 'kde', 'lower': 'kde', 'upper': 'scatter'}
    types = kwargs.pop('types', default_types)
    diagonal = kwargs.pop('diagonal', True)
    if isinstance(types, (list, str)):
        from warnings import warn
        warn(
            "MCMCSamples.plot_2d's argument 'types' might stop accepting "
            "str or list(str) as input in the future. It takes a "
            "dictionary as input, now, with keys 'diagonal' for the 1D "
            "plots and 'lower' and 'upper' for the 2D plots. 'diagonal' "
            "accepts the values 'kde' or 'hist' and both 'lower' and "
            "'upper' accept the values 'kde' or 'scatter'. "
            "Default: {'diagonal': 'kde', 'lower': 'kde'}.",
            FutureWarning)
        # Translate the legacy forms into the dictionary form.
        if isinstance(types, str):
            types = {'lower': types}
        else:
            types = {'lower': types[0], 'upper': types[-1]}
        if diagonal:
            types['diagonal'] = types['lower']

    # Copy each <pos>_kwargs dict so that merging in the shared kwargs below
    # does not mutate dictionaries owned by the caller.
    local_kwargs = {pos: dict(kwargs.pop('%s_kwargs' % pos, {}))
                    for pos in default_types}
    for pos in local_kwargs:
        local_kwargs[pos].update(kwargs)

    if not isinstance(axes, pandas.DataFrame):
        fig, axes = make_2d_axes(axes, tex=self.tex,
                                 upper=('upper' in types),
                                 lower=('lower' in types),
                                 diagonal=('diagonal' in types))
    else:
        # Take the figure from the first axis that actually exists; empty
        # grid cells are missing values.
        present = axes.to_numpy()[axes.notna().to_numpy()]
        fig = present[0].figure

    for y, row in axes.iterrows():
        for x, ax in row.items():
            if ax is not None:
                pos = ax.position
                # 1D (diagonal) plots are drawn on the twin axis.
                ax_ = ax.twin if x == y else ax
                plot_type = types.get(pos, None)
                lkwargs = local_kwargs.get(pos, {})
                self.plot(ax_, x, y, *args, plot_type=plot_type, **lkwargs)
    return fig, axes