Esempio n. 1
0
def outside_legend(
    fig: plt.Figure,
    ax: plt.Axes,
    legend_padding: float = 0.04,
    legend_height: float = 0.3,
    **kwargs,
) -> None:
    """Draw a figure-level legend spanning the width of ``ax``, just above it.

    The legend box is expressed in figure-fraction coordinates. Its left and
    right edges match those of ``ax``, its bottom sits ``legend_padding``
    inches above the top of ``ax``, and it is ``legend_height`` inches tall.
    ``mode="expand"`` stretches the entries across that width.

    Args:
        fig: The figure that owns the legend.
        ax: The axes whose horizontal extent and top edge anchor the legend.
        legend_padding: Gap between top of axes and bottom of legend, in inches.
        legend_height: Height of the legend box, in inches.
        **kwargs: Forwarded unchanged to ``fig.legend``.
    """
    _, fig_height_in = fig.get_size_inches()
    box = ax.get_position()
    left = box.x0
    right = box.x0 + box.width
    # Inch quantities become figure fractions by dividing by the figure height.
    bottom = box.y0 + box.height + legend_padding / fig_height_in
    anchor = (left, bottom, right - left, legend_height / fig_height_in)
    fig.legend(
        loc="lower left",
        bbox_to_anchor=anchor,
        bbox_transform=fig.transFigure,
        mode="expand",
        **kwargs,
    )
Esempio n. 2
0
def plot_2H(ax: plt.Axes, pids=(3, 10)):
    """Plot human vs. model confusion matrices for a few participants.

    ``ax`` becomes a frameless container. It is shifted up by 0.08 (figure
    fraction), its ticks are removed, and it carries the shared outer axis
    labels. Inside it, a grid with two rows and one column per participant
    is laid out:

    * row 0: the participant's own confusion matrix
      (``DataExp1.plot_confusion_matrix``);
    * row 1: the cross-validated ``models.ChoiceModel4Param`` prediction
      for the same participant.

    ``font.size``, ``xtick.labelsize`` and ``ytick.labelsize`` are reduced by
    2 while the sub-axes are drawn. They are restored afterwards, even if
    drawing raises.

    Args:
        ax: Container axes to subdivide.
        pids: 1-based participant numbers, used to index ``DataExp1.pids``.
    """
    pos = ax.get_position()
    ax.set_position((pos.x0, pos.y0 + 0.08, pos.width, pos.height))
    ax.set_frame_on(False)
    ax.set_xticks([])
    ax.set_xlabel('Human choice/model prediction', labelpad=10)
    ax.set_yticks([])
    ax.set_ylabel('True Structure', labelpad=16)
    # Add the sub-axes to the figure that actually owns `ax`, which is not
    # necessarily the current pyplot figure.
    fig = ax.figure
    # Two rows (human, model) by one column per participant; the loop below
    # indexes gs[row, column], so the shape must be (2, len(pids)).
    # NOTE(review): recent Matplotlib versions require a SubplotSpec for
    # `subplot_spec` (e.g. ax.get_subplotspec()); passing the Axes itself is
    # kept as-is — confirm against the Matplotlib version in use.
    gs = gridspec.GridSpecFromSubplotSpec(2,
                                          len(pids),
                                          subplot_spec=ax,
                                          wspace=0.1,
                                          hspace=0.1)

    font_keys = ('font.size', 'xtick.labelsize', 'ytick.labelsize')
    shrunk_keys = []  # only restore what was actually decremented
    try:
        for key in font_keys:
            mpl.rcParams[key] -= 2
            shrunk_keys.append(key)

        for col, pid in enumerate(pids):
            human_ax = plt.Subplot(fig, gs[0, col])
            data = DataExp1(DataExp1.pids[pid - 1])
            data.plot_confusion_matrix(human_ax)
            # \uf007 must be decoded (icon glyph), so the LaTeX backslash in
            # "$\#$" is escaped explicitly instead of using a raw string.
            human_ax.set_title(f'\uf007$\\#${pid}', size=7, fontproperties=fp)
            human_ax.set_xticks([])
            human_ax.set_xlabel('')
            if col == 0:
                human_ax.set_ylabel('Human')
            else:
                # Only the leftmost column keeps y ticks / row label.
                human_ax.set_yticks([])
                human_ax.set_ylabel('')
            fig.add_subplot(human_ax)

            model_ax = plt.Subplot(fig, gs[1, col])
            model = data.build_model(models.ChoiceModel4Param)
            model.plot_confusion_matrix(
                data.cross_validate(models.ChoiceModel4Param), model_ax)
            model_ax.set_xlabel('')
            if col == 0:
                model_ax.set_ylabel('Model')
            else:
                model_ax.set_yticks([])
                model_ax.set_ylabel('')
            fig.add_subplot(model_ax)
    finally:
        for key in shrunk_keys:
            mpl.rcParams[key] += 2
Esempio n. 3
0
def break_axis(
    amin,
    amax=None,
    xy='x',
    ax: plt.Axes = None,
    fun_draw: Callable = None,
    margin=0.05,
) -> (plt.Axes, plt.Axes):
    """
    Replace ``ax`` by two axes that skip the data range [amin, amax].

    The original axes is hidden. Two new axes are added in its figure
    position: the first shows [lim_min, amin] and the second shows
    [amax, lim_max] along the chosen direction. Ticks are copied from the
    original axes, and the spines facing the break are hidden.

    :param amin: data coordinate to start breaking from
    :param amax: data coordinate to end breaking at; defaults to ``amin``
    :param xy: 'x' to break the horizontal axis, 'y' to break the vertical one
    :param ax: axes to break; defaults to ``plt.gca()``
    :param fun_draw: if not None, fun_draw(ax1) and fun_draw(ax2) will
        be run to recreate ax. Use the same function that was used to draw
        on ax, e.g., fun_draw=lambda ax: ax.plot(x, y)
    :param margin: only used for xy='y': fraction of the original height
        left empty between the two new axes. For xy='x' the gap width is
        proportional to (amax - amin) instead.
    :return: axs: a list [ax1, ax2] of the axes created (left/bottom first)
    :raises ValueError: if xy is neither 'x' nor 'y'
    """

    if amax is None:
        amax = amin

    if ax is None:
        ax = plt.gca()

    if xy == 'x':
        rect = ax.get_position().bounds
        lim = ax.get_xlim()
        # Fractions of the original width at which the break starts / ends.
        prop_min = (amin - lim[0]) / (lim[1] - lim[0])
        prop_max = (amax - lim[0]) / (lim[1] - lim[0])
        rect1 = np.array([rect[0], rect[1], rect[2] * prop_min, rect[3]])
        rect2 = [
            rect[0] + rect[2] * prop_max, rect[1], rect[2] * (1 - prop_max),
            rect[3]
        ]

        fig = ax.figure  # type: plt.Figure
        ax1 = fig.add_axes(plt.Axes(fig=fig, rect=rect1))
        ax1.update_from(ax)
        if fun_draw is not None:
            fun_draw(ax1)
        ax1.set_xticks(ax.get_xticks())
        ax1.set_xlim(lim[0], amin)
        ax1.spines['right'].set_visible(False)

        ax2 = fig.add_axes(plt.Axes(fig=fig, rect=rect2))
        ax2.update_from(ax)
        if fun_draw is not None:
            fun_draw(ax2)
        ax2.set_xticks(ax.get_xticks())
        ax2.set_xlim(amax, lim[1])
        ax2.spines['left'].set_visible(False)
        ax2.set_yticks([])

        ax.set_visible(False)
        axs = [ax1, ax2]

    elif xy == 'y':
        rect = ax.get_position().bounds
        lim = ax.get_ylim()
        # Scale the two kept data ranges so that together they fill
        # (1 - margin) of the original height; `margin` becomes the gap.
        prop_all = ((amin - lim[0]) + (lim[1] - amax)) / (1 - margin)
        prop_min = (amin - lim[0]) / prop_all
        prop_max = (lim[1] - amax) / prop_all
        rect1 = np.array([rect[0], rect[1], rect[2], rect[3] * prop_min])
        # The upper axes occupies the top `prop_max` of the original height,
        # so its top edge coincides with the original top edge.
        rect2 = [
            rect[0], rect[1] + rect[3] * (1 - prop_max), rect[2],
            rect[3] * prop_max
        ]

        fig = ax.figure  # type: plt.Figure
        ax1 = fig.add_axes(plt.Axes(fig=fig, rect=rect1))
        ax1.update_from(ax)
        if fun_draw is not None:
            fun_draw(ax1)
        ax1.set_yticks(ax.get_yticks())
        ax1.set_ylim(lim[0], amin)
        ax1.spines['top'].set_visible(False)

        ax2 = fig.add_axes(plt.Axes(fig=fig, rect=rect2))
        ax2.update_from(ax)
        if fun_draw is not None:
            fun_draw(ax2)
        ax2.set_yticks(ax.get_yticks())
        ax2.set_ylim(amax, lim[1])
        ax2.spines['bottom'].set_visible(False)
        ax2.set_xticks([])

        ax.set_visible(False)
        axs = [ax1, ax2]

    else:
        raise ValueError(f"xy must be 'x' or 'y', got {xy!r}")

    return axs
Esempio n. 4
0
def animation_model_states(models: Union[Dict[str, MALA], MALA],
                           function_array: np.ndarray,
                           function_coords: tuple,
                           n_start: int = 0,
                           n_end: int = 100,
                           fig: plt.Figure = None,
                           ax: plt.Axes = None,
                           interval: float = 200,
                           colors=('b', 'k', 'c', 'y', 'g', 'r'),
                           plot_covariance=True,
                           **kwargs):
    """
    Make an animation, displaying one after the other the states of a model.

    Parameters
    ----------
    models: Union[Dict[str, MALA], MALA]
        The dictionary of models or a single model (normalized by
        ``_check_models_dims``).
    function_array
        The array displayed in background. Should represent the target pdf we're trying to sample from.
    function_coords
        Extent of the background grid: (x_start, x_end, y_start, y_end).
    n_start
        When to start displaying the steps
    n_end
        When to stop displaying the steps
    fig
        Matplotlib Figure or None. If None, the figure owning ``ax`` is used,
        or a new 7x7 inch figure is created when ``ax`` is also None.
    ax
        Matplotlib Axes or None. If None, ``fig.gca()`` is used.
    interval
        The interval between each image, in milliseconds.
    colors
        The colors of each models.
    plot_covariance
        Whether to draw, at each step, an ellipse for each model's covariance
        ``gamma * sigma**2``. ``gamma`` is read from
        ``params_history['gamma']`` at the current step (clamped to the last
        recorded value), or from ``model.gamma`` when no history exists.
    kwargs
        Other kwargs passed to FuncAnimation.

    Returns
    -------
    A FuncAnimation object.
    """
    models = _check_models_dims(models,
                                dims=2,
                                n_start=n_start,
                                n_end=n_end,
                                colors=colors,
                                function_coords=function_coords)
    # Keep fig and ax consistent: an explicit ax determines the figure to
    # animate, and a missing ax is taken from the figure being animated.
    if fig is None:
        fig = ax.figure if ax is not None else plt.figure(figsize=(7, 7))
    if ax is None:
        ax = fig.gca()

    ax.imshow(function_array,
              extent=function_coords,
              cmap='coolwarm',
              origin='lower')

    # Just turning the history of each state in an array
    models_states = {k: np.array(models[k].history['state']) for k in models}
    lines = {}
    for c, k in zip(colors, models):
        lines[k], = ax.plot([], [], "*%s" % c)

    legend_elements = [
        Line2D([0], [0], color=c, markerfacecolor=c, marker='*', label=k)
        for c, k in zip(colors, models)
    ]

    # put legend outside the plot
    chart_box = ax.get_position()
    ax.set_position([
        chart_box.x0, chart_box.y0, chart_box.width * 0.8,
        chart_box.height * 0.8
    ])
    ax.legend(handles=legend_elements,
              loc='upper center',
              bbox_to_anchor=(1.2, 0.8),
              shadow=True,
              ncol=1)

    # Ellipses drawn by the previous call of update_frame.
    drawn_patches = []

    def update_frame(iteration):
        step = n_start + iteration
        # Remove exactly the ellipses drawn last time. This stays correct when
        # FuncAnimation draws frame 0 twice (init + first frame) or restarts
        # with `repeat`, and works where `ax.patches` is a read-only list.
        while drawn_patches:
            drawn_patches.pop().remove()

        for i, k in enumerate(models):
            model = models[k]
            model_state = models_states[k]
            lines[k].set_data(model_state[n_start:step, 0],
                              model_state[n_start:step, 1])

            if plot_covariance:
                gamma_history = model.params_history.get('gamma', [])
                if len(gamma_history) > 0:
                    # Samplers that do not update gamma keep a shorter
                    # history: fall back to the last recorded value. Indexed
                    # by the same absolute step as the state and sigma.
                    cov = gamma_history[min(step, len(gamma_history) - 1)]
                else:
                    cov = model.gamma
                cov = cov * model.params_history['sigma'][step]**2
                new_patch = _plot_ellipse_covariance(model_state[step],
                                                     cov=cov,
                                                     ax=ax,
                                                     edgecolor=colors[i],
                                                     lw=1.2)
                drawn_patches.extend(new_patch)
        return list(lines.values()) + list(drawn_patches) + [
            ax.set_title("Step {}".format(step))
        ]

    animation = FuncAnimation(fig,
                              update_frame,
                              frames=n_end - n_start,
                              blit=True,
                              interval=interval,
                              **kwargs)

    return animation