def outside_legend(
    fig: plt.Figure,
    ax: plt.Axes,
    legend_padding: float = 0.04,
    legend_height: float = 0.3,
    **kwargs,
) -> None:
    """Draw a figure-level legend directly above ``ax``, spanning its width.

    The legend box is expressed in figure-fraction coordinates. It starts
    ``legend_padding`` inches above the top edge of ``ax``, is
    ``legend_height`` inches tall, and spans the horizontal extent of ``ax``.
    Entries are spread across the box with ``mode="expand"``.

    Args:
        fig: Figure that receives the legend (via ``fig.legend``).
        ax: Axes whose top edge and horizontal extent position the legend.
        legend_padding: Gap between the top of the axes and the bottom of the
            legend, in inches.
        legend_height: Height of the legend box, in inches.
        **kwargs: Forwarded unchanged to ``fig.legend``.
    """
    _, fig_height_in = fig.get_size_inches()
    axes_box = ax.get_position()

    left = axes_box.x0
    right = axes_box.x0 + axes_box.width

    # Inch quantities are divided by the figure height to turn them into
    # figure fractions, matching the units of ``ax.get_position()``.
    bottom = axes_box.y0 + axes_box.height + legend_padding / fig_height_in
    box_height = legend_height / fig_height_in

    fig.legend(
        loc="lower left",
        bbox_to_anchor=(left, bottom, right - left, box_height),
        bbox_transform=fig.transFigure,
        mode="expand",
        **kwargs,
    )
def plot_2H(ax: plt.Axes, pids=(3, 10)):
    """Plot human and model confusion matrices for a few participants.

    ``ax`` becomes a frameless container. It is shifted up by 0.08 (figure
    fraction), its ticks are removed, and it only carries the shared axis
    labels. Its area is split into a grid with one column per entry of
    ``pids`` and two rows:

    - row 0 shows each participant's own confusion matrix;
    - row 1 shows the cross-validated ``models.ChoiceModel4Param`` prediction.

    Args:
        ax: Axes whose area hosts the grid of confusion matrices.
        pids: 1-based participant numbers; each ``pid`` selects
            ``DataExp1.pids[pid - 1]``.

    ``font.size``, ``xtick.labelsize`` and ``ytick.labelsize`` are reduced by
    2 while drawing and restored afterwards, even if drawing raises.
    """
    pos = ax.get_position()
    ax.set_position((pos.x0, pos.y0 + 0.08, pos.width, pos.height))
    ax.set_frame_on(False)
    ax.set_xticks([])
    ax.set_xlabel('Human choice/model prediction', labelpad=10)
    ax.set_yticks([])
    ax.set_ylabel('True Structure', labelpad=16)

    # Draw into the figure that owns ``ax``, not whichever figure is current.
    fig = ax.figure
    # Two rows (human, model) by one column per participant, matching the
    # gs[row, col] indexing below for any number of pids.
    gs = gridspec.GridSpecFromSubplotSpec(2, len(pids), subplot_spec=ax,
                                          wspace=0.1, hspace=0.1)

    def label_row(sub_ax, col, text):
        # Only the leftmost column keeps its y ticks and gets a row label.
        if col == 0:
            sub_ax.set_ylabel(text)
        else:
            sub_ax.set_yticks([])
            sub_ax.set_ylabel('')

    font_keys = ('font.size', 'xtick.labelsize', 'ytick.labelsize')
    saved_fonts = {key: mpl.rcParams[key] for key in font_keys}
    try:
        for key in font_keys:
            mpl.rcParams[key] -= 2

        for col, pid in enumerate(pids):
            human_ax = plt.Subplot(fig, gs[0, col])
            data = DataExp1(DataExp1.pids[pid - 1])
            data.plot_confusion_matrix(human_ax)
            # '\\#' is a literal backslash + '#' (escaped '#' in math text);
            # doubling the backslash avoids an invalid-escape SyntaxWarning.
            human_ax.set_title(f'\uf007$\\#${pid}', size=7, fontproperties=fp)
            human_ax.set_xticks([])
            human_ax.set_xlabel('')
            label_row(human_ax, col, 'Human')
            fig.add_subplot(human_ax)

            model_ax = plt.Subplot(fig, gs[1, col])
            model = data.build_model(models.ChoiceModel4Param)
            model.plot_confusion_matrix(
                data.cross_validate(models.ChoiceModel4Param), model_ax)
            model_ax.set_xlabel('')
            label_row(model_ax, col, 'Model')
            fig.add_subplot(model_ax)
    finally:
        # Restore the global font sizes even if plotting failed part-way.
        mpl.rcParams.update(saved_fonts)
def break_axis(
    amin,
    amax=None,
    xy='x',
    ax: plt.Axes = None,
    fun_draw: Callable = None,
    margin=0.05,
) -> (plt.Axes, plt.Axes):
    """Replace ``ax`` by two axes that omit the data interval (amin, amax).

    ``ax`` is hidden, and two new axes are added inside its figure rectangle:
    ``ax1`` shows [lower limit, amin] and ``ax2`` shows [amax, upper limit].

    :param amin: data coordinate to start breaking from
    :param amax: data coordinate to end breaking at (defaults to ``amin``)
    :param xy: 'x' or 'y' -- the axis to break
    :param ax: the axes to break (defaults to ``plt.gca()``)
    :param fun_draw: if not None, fun_draw(ax1) and fun_draw(ax2) will be run
        to recreate the content of ax. Use the same function that was used to
        draw on ax, e.g., fun_draw=lambda ax: ax.plot(x, y)
    :param margin: used only when xy == 'y': the fraction of the axes height
        left empty between the two halves. When xy == 'x', the gap width is
        instead proportional to (amax - amin).
    :return: axs: the list [ax1, ax2] of axes created
    :raises ValueError: if xy is neither 'x' nor 'y'
    """
    if amax is None:
        amax = amin
    if ax is None:
        ax = plt.gca()
    if xy not in ('x', 'y'):
        raise ValueError(f"xy must be 'x' or 'y', got {xy!r}")

    left, bottom, width, height = ax.get_position().bounds
    fig = ax.figure  # type: plt.Figure

    def new_half(rect):
        # Add an axes at ``rect`` that copies ``ax``'s properties and redraw
        # the original content on it.
        half = fig.add_axes(plt.Axes(fig=fig, rect=rect))
        half.update_from(ax)
        if fun_draw is not None:
            fun_draw(half)
        return half

    if xy == 'x':
        lo, hi = ax.get_xlim()
        prop_min = (amin - lo) / (hi - lo)
        prop_max = (amax - lo) / (hi - lo)
        rect1 = [left, bottom, width * prop_min, height]
        rect2 = [left + width * prop_max, bottom, width * (1 - prop_max), height]

        ax1 = new_half(rect1)
        ax1.set_xticks(ax.get_xticks())
        ax1.set_xlim(lo, amin)
        ax1.spines['right'].set_visible(False)

        ax2 = new_half(rect2)
        ax2.set_xticks(ax.get_xticks())
        ax2.set_xlim(amax, hi)
        ax2.spines['left'].set_visible(False)
        ax2.set_yticks([])
    else:
        lo, hi = ax.get_ylim()
        # Data units per unit of axes height, chosen so that both halves
        # together fill (1 - margin) of the height; the rest is the gap.
        data_per_height = ((amin - lo) + (hi - amax)) / (1 - margin)
        prop_min = (amin - lo) / data_per_height
        prop_max = (hi - amax) / data_per_height
        rect1 = [left, bottom, width, height * prop_min]
        # The upper half occupies the top ``prop_max`` of the rectangle, so its
        # top edge coincides with the top of ``ax``.
        rect2 = [left, bottom + height * (1 - prop_max), width, height * prop_max]

        ax1 = new_half(rect1)
        ax1.set_yticks(ax.get_yticks())
        ax1.set_ylim(lo, amin)
        ax1.spines['top'].set_visible(False)

        ax2 = new_half(rect2)
        ax2.set_yticks(ax.get_yticks())
        ax2.set_ylim(amax, hi)
        ax2.spines['bottom'].set_visible(False)
        ax2.set_xticks([])

    ax.set_visible(False)
    return [ax1, ax2]
def animation_model_states(models: Union[Dict[str, MALA], MALA],
                           function_array: np.ndarray,
                           function_coords: tuple,
                           n_start: int = 0,
                           n_end: int = 100,
                           fig: plt.Figure = None,
                           ax: plt.Axes = None,
                           interval: float = 200,
                           colors=('b', 'k', 'c', 'y', 'g', 'r'),
                           plot_covariance=True,
                           **kwargs):
    """
    Make an animation, displaying one after the other the states of a model.

    Frame ``i`` shows, for every model, the states visited at steps
    ``n_start .. n_start + i - 1`` as star markers. Optionally, it also shows
    an ellipse drawn by ``_plot_ellipse_covariance`` around the state at step
    ``n_start + i``.

    Parameters
    ----------
    models: Union[Dict[str, MALA], MALA]
        The dictionary of models or a single model. It is passed through
        ``_check_models_dims`` before use.
    function_array
        The array displayed in background. Should represent the target pdf
        we're trying to sample from.
    function_coords
        Extent of the background grid: (x_start, x_end, y_start, y_end).
    n_start
        When to start displaying the steps.
    n_end
        When to stop displaying the steps (one frame per step in
        ``[n_start, n_end)``).
    fig
        Matplotlib Figure or None. If None, the figure owning ``ax`` is used
        when ``ax`` is given; otherwise a new 7x7 inch figure is created.
    ax
        Matplotlib Axes or None. If None, the current axes of ``fig`` is used.
    interval
        The interval between each image, in milliseconds.
    colors
        The colors of each model, in the iteration order of ``models``.
    plot_covariance
        Whether to plot covariance ellipses. Only models with a ``gamma``
        attribute get one. The covariance at a step is
        ``params_history['gamma'][step] * params_history['sigma'][step]**2``.
        The last recorded gamma is used when the gamma history is shorter.
    kwargs
        Other kwargs passed to FuncAnimation.

    Returns
    -------
    A FuncAnimation object.
    """
    models = _check_models_dims(models, dims=2, n_start=n_start, n_end=n_end,
                                colors=colors, function_coords=function_coords)
    if fig is None:
        # Animate the figure that actually contains ``ax``.
        fig = ax.figure if ax is not None else plt.figure(figsize=(7, 7))
    if ax is None:
        ax = fig.gca()

    ax.imshow(function_array, extent=function_coords, cmap='coolwarm',
              origin='lower')

    # Turn the recorded history of each model's states into an array.
    models_states = {k: np.array(models[k].history['state']) for k in models}

    lines = {}
    for c, k in zip(colors, models):
        lines[k], = ax.plot([], [], f"*{c}")

    legend_elements = [
        Line2D([0], [0], color=c, markerfacecolor=c, marker='*', label=k)
        for c, k in zip(colors, models)
    ]
    # Shrink the axes to make room for a legend outside the plot.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height * 0.8])
    ax.legend(handles=legend_elements, loc='upper center',
              bbox_to_anchor=(1.2, 0.8), shadow=True, ncol=1)

    def covariance_at(model, step):
        """Return the covariance to draw for ``model`` at absolute ``step``."""
        gamma_history = model.params_history['gamma']
        # Samplers that never update gamma keep a short history: fall back to
        # the last recorded value. Indexed by the same absolute step as the
        # states and sigma.
        gamma = gamma_history[min(step, len(gamma_history) - 1)]
        return gamma * model.params_history['sigma'][step] ** 2

    # Ellipse artists drawn for the previous frame. Tracking them explicitly
    # lets each frame remove exactly what it added, which works on recent
    # matplotlib (where ``ax.patches`` cannot be mutated) and when frame 0 is
    # drawn more than once (init draw, repeat).
    drawn_patches = []

    def update_frame(iteration):
        step = n_start + iteration
        while drawn_patches:
            drawn_patches.pop().remove()

        for i, k in enumerate(models):
            model = models[k]
            states = models_states[k]
            lines[k].set_data(states[n_start:step, 0], states[n_start:step, 1])
            # Models without ``gamma`` have no covariance to show.
            if plot_covariance and hasattr(model, 'gamma'):
                drawn_patches.extend(_plot_ellipse_covariance(
                    states[step], cov=covariance_at(model, step), ax=ax,
                    edgecolor=colors[i], lw=1.2))

        title = ax.set_title("Step {}".format(step))
        return list(lines.values()) + drawn_patches + [title]

    return FuncAnimation(fig, update_frame, frames=n_end - n_start, blit=True,
                         interval=interval, **kwargs)