예제 #1
0
def score_plot(data, pipelines=None):
    """
    Strip plot of scores per dataset, one colour per pipeline.

    In:
        data: output of Results.to_dataframe()
        pipelines: list of string|None, pipelines to include in this plot
    Out:
        fig: pyplot Figure holding the plot
        color_dict: dict mapping each pipeline label shown in the legend to
            the RGBA face colour of its markers
    """
    data = collapse_session_scores(data)
    data["dataset"] = data["dataset"].apply(_simplify_names)
    if pipelines is not None:
        data = data[data.pipeline.isin(pipelines)]
    fig = plt.figure(figsize=(8.5, 11))
    ax = fig.add_subplot(111)
    sea.stripplot(
        data=data,
        y="dataset",
        x="score",
        jitter=0.15,
        palette=PIPELINE_PALETTE,
        hue="pipeline",
        dodge=True,
        ax=ax,
        alpha=0.7,
    )
    ax.set_xlim([0, 1])
    # Dashed vertical reference line at score 0.5 (presumably chance level).
    ax.axvline(0.5, linestyle="--", color="k", linewidth=2)
    ax.set_title("Scores per dataset and algorithm")
    # Recover each pipeline's marker colour from the legend handles.
    handles, labels = ax.get_legend_handles_labels()
    color_dict = {lb: h.get_facecolor()[0] for lb, h in zip(labels, handles)}
    # Lay out this figure explicitly instead of whatever pyplot considers current.
    fig.tight_layout()
    return fig, color_dict
예제 #2
0
def score_plot(data, pipelines=None):
    '''
    Strip plot of scores per dataset, coloured by pipeline.

    In:
        data: output of Results.to_dataframe()
        pipelines: list of string|None, pipelines to include in this plot
    Out:
        fig: pyplot Figure holding the plot
        color_dict: dict mapping each pipeline label shown in the legend to
            the RGBA face colour of its markers
    '''
    data = collapse_session_scores(data)
    data['dataset'] = data['dataset'].apply(_simplify_names)
    if pipelines is not None:
        data = data[data.pipeline.isin(pipelines)]
    fig = plt.figure(figsize=(8.5, 11))
    ax = fig.add_subplot(111)
    sea.stripplot(data=data,
                  y="dataset",
                  x="score",
                  jitter=True,
                  dodge=True,
                  hue="pipeline",
                  zorder=1,
                  alpha=0.7,
                  ax=ax)
    # Show the full [0, 1] score range so scores below 0.5 stay visible.
    ax.set_xlim([0, 1])
    ax.set_title('Scores per dataset and algorithm')
    # Map each legend label (pipeline) to the face colour of its markers.
    handles, labels = ax.get_legend_handles_labels()
    color_dict = {label: handle.get_facecolor()[0]
                  for label, handle in zip(labels, handles)}
    return fig, color_dict
예제 #3
0
def paired_plot(data, alg1, alg2):
    """Scatter the per-subject scores of two pipelines against each other.

    Parameters
    ----------
    data: DataFrame
        dataframe obtained from evaluation
    alg1: str
        Value of ``data.pipeline`` plotted on the x axis
    alg2: str
        Value of ``data.pipeline`` plotted on the y axis

    Returns
    -------
    fig: Figure
        Pyplot handle of the newly created figure
    """
    scores = collapse_session_scores(data)
    selected = scores[scores.pipeline.isin([alg1, alg2])]
    # Wide layout: one row per (subject, dataset), one column per pipeline.
    wide = selected.pivot_table(
        values="score", columns="pipeline", index=["subject", "dataset"]
    ).reset_index()

    figure = plt.figure(figsize=(11, 8.5))
    axes = figure.add_subplot(111)
    wide.plot.scatter(x=alg1, y=alg2, ax=axes)
    # Dashed identity line: points above it mean alg2 scored higher than alg1.
    axes.plot([0, 1], [0, 1], linestyle="--", color="k")
    axes.set_xlim([0.5, 1])
    axes.set_ylim([0.5, 1])
    return figure
예제 #4
0
def paired_plot(data, alg1, alg2):
    """Return a figure whose single axis holds a paired plot of two pipelines.

    Parameters
    ----------
    data : DataFrame
        Dataframe from Results.
    alg1 : str
        Value of ``data.pipeline`` shown on the x axis.
    alg2 : str
        Value of ``data.pipeline`` shown on the y axis.

    Returns
    -------
    Figure
        The newly created pyplot figure.
    """
    collapsed = collapse_session_scores(data)
    keep = collapsed.pipeline.isin([alg1, alg2])
    table = collapsed[keep].pivot_table(values="score",
                                        columns="pipeline",
                                        index=["subject", "dataset"])
    table = table.reset_index()

    out_fig = plt.figure(figsize=(11, 8.5))
    out_ax = out_fig.add_subplot(111)
    table.plot.scatter(x=alg1, y=alg2, ax=out_ax)
    # Dashed diagonal y = x as the reference for equal scores.
    out_ax.plot([0, 1], [0, 1], linestyle="--", color="k")
    for set_limits in (out_ax.set_xlim, out_ax.set_ylim):
        set_limits([0.5, 1])
    return out_fig
예제 #5
0
파일: plotting.py 프로젝트: TateXu/moabb
def paired_plot(data, alg1, alg2, fig_ax=None, P=None, T=None, task=None):
    '''
    Draw a seaborn pair plot of pipeline scores, one panel per x pipeline.

    Each panel scatters the score of one pipeline from ``alg1`` (x axis)
    against ``alg2`` (y axis), coloured and marked by dataset, with a dashed
    y = x line. When both ``P`` and ``T`` are given, each panel is annotated
    with its t statistic and p-value (bold green for p < 0.05 and t >= 0,
    bold red for p < 0.05 and t < 0). When ``task`` is given, the figure is
    saved to ``Figure/color_pair_<task>.pdf``.

    Parameters
    ----------
    data : DataFrame
        Dataframe from Results (collapsed with ``collapse_session_scores``).
    alg1 : str or list of str
        Pipeline name(s) plotted on the x axes, one panel each.
    alg2 : str or list of str
        Pipeline name(s) plotted on the y axis.
    fig_ax : Axes, optional
        Kept for backward compatibility. NOTE(review): nothing is drawn on
        it -- ``sea.pairplot`` always creates its own figure.
    P, T : DataFrame, optional
        p-value and t-statistic tables, both read as
        ``X.T['RE'].values[::-1][1:]``. The resulting arrays are assumed to
        be aligned with the ``alg1`` panels -- TODO confirm against caller.
    task : str, optional
        Suffix of the saved file name; if None, nothing is saved.

    Returns
    -------
    Figure or Axes
        The pair-plot figure when ``fig_ax`` is None; otherwise the last
        annotated pair-plot panel (``fig_ax`` itself if there is none).
    '''
    # Accept a single pipeline name as well as a list of names.
    x_vars = [alg1] if isinstance(alg1, str) else list(alg1)
    y_vars = [alg2] if isinstance(alg2, str) else list(alg2)

    p_values = t_values = None
    if P is not None and T is not None:
        # Reverse the 'RE' column and drop its first entry; presumably this
        # lines the statistics up with the x_vars panels.
        p_values = P.T['RE'].values[::-1][1:]
        t_values = T.T['RE'].values[::-1][1:]

    data = collapse_session_scores(data)
    data = data[data.pipeline.isin(x_vars + y_vars)]
    data = data.pivot_table(values='score', columns='pipeline',
                            index=['subject', 'dataset'])
    data = data.reset_index(level=['dataset'])

    # pairplot needs exactly one marker per hue level, so trim to the
    # number of datasets actually present.
    markers = ["v", "^", "<", ">", "o", "+", "s", "D"]
    n_levels = data['dataset'].nunique()
    grid = sea.pairplot(data, x_vars=x_vars, y_vars=y_vars, hue='dataset',
                        size=6,
                        palette=sea.color_palette("deep", 8),
                        plot_kws={"s": 180},
                        markers=markers[:n_levels])

    ticks = [0.5, 0.6, 0.7, 0.8, 0.9, 1]
    last_ax = fig_ax
    for ind, ax in enumerate(grid.axes[0]):
        ax.set_xlabel(ax.get_xlabel(), fontsize=35)
        ax.set_ylabel(ax.get_ylabel(), fontsize=35)
        ax.set_xlim([0.5, 1])
        ax.set_ylim([0.5, 1])
        # Pin tick positions before labelling them so labels cannot drift
        # onto different ticks chosen by the automatic locator.
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(ticks, fontsize=30)
        ax.set_yticklabels(ticks, fontsize=30)
        ax.plot([0, 1], [0, 1], ls='--', c='k')
        if ind == 3:
            # Shade the fourth panel with #EAEAF2.
            # NOTE(review): why panel index 3 is highlighted is not clear here.
            ax.set_facecolor((0.9176470588235294, 0.9176470588235294,
                              0.9490196078431372, 1.0))
        if p_values is not None:
            p, t = p_values[ind], t_values[ind]
            txt = 't={:.2f}\np={:1.0e}'.format(t, p)
            style = {}
            if p < 0.05 and t >= 0:
                style = {'fontweight': 'black', 'color': 'green'}
            elif p < 0.05 and t < 0:
                style = {'fontweight': 'black', 'color': 'red'}
            ax.text(0.51, 0.9, txt, fontsize=35, **style)
        last_ax = ax

    mpl.rc('font', family='serif', serif='Times New Roman')
    # _legend_data is a private seaborn attribute mapping label -> handle.
    legend_data = grid._legend_data
    grid.fig.legend(handles=list(legend_data.values()),
                    labels=list(legend_data.keys()),
                    loc='upper center', ncol=4, fontsize=25)
    grid.fig.subplots_adjust(top=0.8, bottom=0.08)
    # Embed TrueType (Type 42) fonts in PDF/PS output.
    mpl.rcParams['pdf.fonttype'] = 42
    mpl.rcParams['ps.fonttype'] = 42
    if task is not None:
        import os
        os.makedirs('Figure', exist_ok=True)
        grid.savefig(f'Figure/color_pair_{task}.pdf')

    if fig_ax is None:
        return grid.fig
    return last_ax