コード例 #1
0
def plot_confusion_matrix(cm,
                          labels,
                          cmap=plt.cm.Blues,
                          ax=None,
                          colorbar_label=None,
                          discrete=False):
    """Draw a confusion matrix as a heat map with a colorbar.

    Parameters
    ----------
    cm : 2-D array-like
        Matrix values. Drawn with ``ax.pcolormesh`` (rows along y, columns
        along x), or with ``discrete_matshow`` when ``discrete`` is true.
    labels : sequence of str
        Class names, used for both the x and the y tick labels.
    cmap : matplotlib colormap, optional
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to the current axes.
    colorbar_label : str, optional
        Label for the colorbar, if given.
    discrete : bool, optional
        Use ``discrete_matshow`` instead of a continuous ``pcolormesh``.

    Returns
    -------
    ax : matplotlib Axes
        The axes drawn on (the one returned by ``discrete_matshow`` when
        ``discrete`` is true).
    """
    if ax is None:
        ax = plt.gca()
    if discrete:
        ax, cbar = discrete_matshow(cm, cmap=cmap, ax=ax)
    else:
        mesh = ax.pcolormesh(cm, cmap=cmap)
        # Attach the colorbar to ``ax`` explicitly; plt.colorbar(mesh) may
        # take its space from whichever axes happens to be current.
        cbar = ax.figure.colorbar(mesh, ax=ax)
    if colorbar_label:
        cbar.set_label(colorbar_label)

    # pcolormesh cell k spans [k, k + 1), so ticks sit at the cell centres.
    tick_marks = np.arange(len(labels)) + .5
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(labels, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(labels)
    # NOTE(review): rows of ``cm`` end up on the y axis, labelled
    # 'Predicted label'; confirm that matches how callers build ``cm``.
    ax.set_xlabel('True label')
    ax.set_ylabel('Predicted label')

    sns.despine(ax=ax)
    # Lay out only after the tick labels and axis labels exist.
    ax.figure.tight_layout()

    return ax
コード例 #2
0
 def plot_lung_on_subplot(self,
                          ax,
                          x_index,
                          y_index,
                          lung,
                          skip_n=1,  # plot every skip_n-th point (skip_n - 1 skipped in between)
                          lung_color=blue,
                          lung_alpha=0.0058,
                          lung_markersize=2,
                          rasterize_lung=True,  # lets PDFs avoid storing every single point as a vector
                          rasterize_order=0,  # zorder threshold passed to ax.set_rasterization_zorder
                          hide_axes=(True, True, False, False)):  # (top, right, left, bottom)
     """Draw the lung point cloud ``lung`` onto ``ax`` as dots.

     Parameters
     ----------
     ax : matplotlib Axes
         Target axes.
     x_index, y_index : int
         Columns of ``lung`` plotted on the horizontal and vertical axes.
     lung : 2-D ndarray
         Point cloud; indexed as ``lung[rows, column]``.
     skip_n : int
         Plot every ``skip_n``-th point.
     lung_color, lung_alpha, lung_markersize :
         Marker styling.
     rasterize_lung : bool
         Passed to ``ax.set_rasterized``.
     rasterize_order : float
         Passed to ``ax.set_rasterization_zorder``: in vector output, artists
         with a zorder below this value are rasterized.
     hide_axes : sequence of 4 bools
         Which spines to remove, as (top, right, left, bottom).
     """
     ax.plot(lung[::skip_n, x_index],
             lung[::skip_n, y_index],
             '.',
             color=lung_color,
             alpha=lung_alpha,
             markersize=lung_markersize)
     ax.set_rasterized(rasterize_lung)
     ax.set_rasterization_zorder(rasterize_order)
     ax.tick_params(labelsize=icra.label_fontsize)
     # Axis limits hug the full (unsubsampled) extent of the cloud.
     ax.set_xlim(lung[:, x_index].min(), lung[:, x_index].max())
     ax.set_ylim(lung[:, y_index].min(), lung[:, y_index].max())
     ax.set_aspect('equal')
     ax.invert_yaxis()
     top, right, left, bottom = hide_axes
     sns.despine(ax=ax, top=top, right=right, left=left, bottom=bottom)
コード例 #3
0
def plot_results(name, results, fractions, in_figures=False):
    """Scatter automatic predictions against human labels and save the plot.

    Parameters
    ----------
    name : str
        Base file name; ``<name>.eps`` and ``<name>.pdf`` are written.
    results : array-like
        Human-labelled values (x axis).
    fractions : array-like
        Automatically predicted values (y axis).
    in_figures : bool
        Write into ``figures/`` instead of ``outputs/``.
    """
    import os
    from matplotlib import pyplot as plt
    try:
        import seaborn.apionly as sns
    except ImportError:
        # seaborn.apionly was removed in seaborn 0.9; from 0.8 on a plain
        # import no longer changes the global style, so this is equivalent.
        import seaborn as sns
    fig, ax = plt.subplots(
        figsize=[plotinfo.TEXTWIDTH_1_2_IN, plotinfo.TEXTWIDTH_1_2_IN * .75])

    # identity line = perfect agreement, drawn underneath the points
    ax.plot([0, 1], [0, 1], c=plotinfo.color_blue, zorder=0)
    ax.scatter(results,
               fractions,
               s=16,
               lw=1,
               edgecolor='k',
               c=plotinfo.color_red,
               zorder=1)
    ax.set_ylim(-.04, 1.04)
    ax.set_xlim(-.04, 1.04)
    ax.grid(False)
    ax.set_xlabel('Human labeled')
    ax.set_ylabel('Automatic prediction')
    sns.despine(ax=ax, trim=True)
    fig.tight_layout()

    # os.path.join avoids the 'figures//name.eps' double slash
    directory = 'figures' if in_figures else 'outputs'
    mkdir_p(directory)
    fig.savefig(os.path.join(directory, '{}.eps'.format(name)))
    fig.savefig(os.path.join(directory, '{}.pdf'.format(name)))
    plt.close(fig)
コード例 #4
0
def plot_enrichment(ax, enrichment, color, title='', rad=True):
    """Plot mean enrichment per iteration with a dashed 5-95 percentile band.

    Parameters
    ----------
    ax : matplotlib Axes
    enrichment : 2-D array-like
        Statistics are taken over axis 0; each column is one iteration
        ('session') on the x axis. The iteration count comes from the number
        of columns (previously hard-coded to 9).
    color : matplotlib colour
        Used for the mean line, the percentile lines and the shaded band.
    title : str
    rad : bool
        True: y axis in radians. False: y ticks relabelled as a fraction of
        the belt (tick value * 2*pi).
    """
    n_iter = np.shape(enrichment)[1]
    iterations = np.arange(n_iter)
    mean = np.mean(enrichment, axis=0)
    lower = np.percentile(enrichment, 5, axis=0)
    upper = np.percentile(enrichment, 95, axis=0)

    ax.plot(iterations, mean, color=color)
    ax.plot(iterations, lower, ls='--', color=color)
    ax.plot(iterations, upper, ls='--', color=color)
    ax.fill_between(iterations, lower, upper, facecolor=color, alpha=0.5)

    sns.despine(ax=ax)
    ax.tick_params(length=3, pad=2, direction='out')
    ax.set_xlim(-0.5, n_iter - 0.5)
    if rad:
        ax.set_ylim(-0.15, 0.5)
        ax.set_ylabel('Enrichment (rad)')
    else:
        ax.set_ylim(-0.15, 0.10 * 2 * np.pi)
        # tick labels are belt fractions; positions are in radians
        y_ticks = np.array(['0', '0.05', '0.10'])
        ax.set_yticks(y_ticks.astype('float') * 2 * np.pi)
        ax.set_yticklabels(y_ticks)
        ax.set_ylabel('Enrichment (fraction of belt)')
    ax.set_xlabel("Iteration ('session' #)")
    ax.set_title(title)
コード例 #5
0
def plot(rs):
    """Compare two human labelers' NET fractions and save a scatter plot.

    Parameters
    ----------
    rs : ndarray
        Unpacked as ``real, alt, _ = rs.T``, so it must have exactly three
        columns: labeler 1, labeler 2, and one that is unused here.

    The squared Pearson correlation (R^2) and Q^2 of labeler 2 predicting
    labeler 1 are printed on the plot as rounded percentages. Writes
    ``figures/human-comparison.{pdf,eps,png}``.
    """
    from matplotlib import pyplot as plt
    try:
        import seaborn.apionly as sns
    except ImportError:
        # seaborn.apionly was removed in seaborn 0.9; from 0.8 on a plain
        # import no longer changes the global style, so this is equivalent.
        import seaborn as sns
    import plotinfo

    real, alt, _ = rs.T
    rh2 = np.corrcoef(real, alt)[0, 1]**2
    # Q^2 = 1 - SS_residual / SS_total, treating ``alt`` as a prediction of ``real``
    resid = real - alt
    centred = real - real.mean()
    q2 = 1. - np.dot(resid, resid) / np.dot(centred, centred)

    w = plotinfo.TEXTWIDTH_1_2_IN
    fig, ax = plt.subplots(figsize=[w, w * 0.75])
    ax.text(.2, .84, r'$R^2: {}\%$'.format(int(np.round(rh2 * 100))), fontsize=14)
    ax.text(.2, .70, r'$Q^2: {}\%$'.format(int(np.round(q2 * 100))), fontsize=14)
    ax.set_xlabel(r'${}$Human labeler 1 (fraction assigned to NET)${}$')
    ax.set_ylabel(r'${}$Human labeler 2 (fraction assigned to NET)${}$')
    ax.set_ylim(-.04, 1.04)
    ax.set_xlim(-.04, 1.04)

    ax.scatter(real, alt, s=16, lw=1, edgecolor='k', c=plotinfo.color_red, zorder=1)
    # identity line across the joint range of both labelers
    both = rs.T[:2]
    lo, hi = both.min(), both.max()
    ax.plot([lo, hi], [lo, hi], c=plotinfo.color_blue, zorder=0)
    sns.despine(ax=ax, offset=True, trim=True)
    # Lay out after despine: offset/trim move the spines and tick labels.
    fig.tight_layout()

    fig.savefig('figures/human-comparison.pdf')
    fig.savefig('figures/human-comparison.eps')
    fig.savefig('figures/human-comparison.png', dpi=1200)
コード例 #6
0
ファイル: startup.py プロジェクト: frtennis1/am221-project
def ft_ax(ax=None,
          y=1.03,
          yy=1.1,
          title=None,
          subtitle=None,
          source=None,
          add_box=False,
          left_axis=False):
    """
    Format either a desired axis or the current axis as an FT plot.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to format; defaults to the current axes.
    y : float
        Vertical position (axes fraction) of the left-aligned title. It is
        also used to place the decorative bar when ``add_box`` is true.
    yy : float
        Vertical position (axes fraction) of the y-axis label when it is
        moved to the right (``left_axis=False``).
    title, subtitle : str, optional
        Left-aligned title, and a subtitle annotated 11 points below it.
    source : str, optional
        Small note annotated below the bottom-left corner of the axes.
    add_box : bool
        Draw a thick black bar above the title on a transparent overlay axes.
    left_axis : bool
        Keep y ticks on the left. Otherwise ticks and label move to the right.

    Returns
    -------
    tuple or list
        The artists drawn outside the axes, in the order (bar line, source
        annotation), containing whichever were drawn. For example, pass it to
        ``savefig(bbox_extra_artists=...)``. An empty list when neither was
        drawn.
    """
    if ax is None:
        ax = plt.gca()
    fig = ax.figure

    ax.set_axisbelow(True)

    title_text = None
    if title is not None:
        title_text = ax.set_title(title, y=y, loc='left')
    if subtitle is not None:
        # Anchor under the title. Without a title, use the spot a left title
        # would occupy; this used to crash on ``None.get_position()``.
        anchor = title_text.get_position() if title_text is not None else (0, y)
        ax.annotate(subtitle, xy=anchor,
                    xycoords='axes fraction', xytext=(0, -11),
                    textcoords='offset points', size='large')

    src = None
    if source is not None:
        src = ax.annotate(source, xy=(0, 0),
                          xycoords='axes fraction', xytext=(0, -35),
                          textcoords='offset points', ha='left', va='top',
                          size='small')

    # axes and grid-lines (only on ``ax``, not every axes in the figure)
    ax.grid(axis='y', linewidth=.5)
    sns.despine(ax=ax, left=True)
    if not left_axis:
        ax.yaxis.tick_right()
        ax.yaxis.set_label_position('right')
        ax.yaxis.set_label_coords(1, yy)
        ax.yaxis.get_label().set_rotation(0)
    ax.tick_params('y', length=0)

    fig.tight_layout()

    line = None
    if add_box:
        # Transparent overlay with the same bounds so the bar can be placed
        # in axes-fraction coordinates above the plotting area.
        ax2 = fig.add_axes(ax.get_position().bounds, facecolor=(1, 1, 1, 0))
        ax2.xaxis.set_visible(False)
        ax2.yaxis.set_visible(False)
        bar_x, bar_y = np.array([[.01, 0.15], [y + .12, y + .12]])
        line = matplotlib.lines.Line2D(bar_x, bar_y, lw=6., color='k')
        ax2.add_line(line)
        line.set_clip_on(False)

    extras = tuple(artist for artist in (line, src) if artist is not None)
    return extras if extras else []
コード例 #7
0
def plot_final_results(results):
    """Bar-plot raw vs corrected Q^2 per feature method and save the figure.

    Parameters
    ----------
    results : dict
        Maps result names to objects with a ``q2`` attribute. For each method
        ``m`` in ``FEATURES`` the keys ``m + '.origins'`` and
        ``m + '.origins.raw'`` are read, plus ``'average.loo.origins'``.

    Writes ``figures/final-plot.{pdf,svg,eps}``.
    """
    import numpy as np
    from matplotlib import pyplot as plt
    try:
        import seaborn.apionly as sns
    except ImportError:
        # seaborn.apionly was removed in seaborn 0.9; from 0.8 on a plain
        # import no longer changes the global style, so this is equivalent.
        import seaborn as sns

    # iteritems() exists only on Python 2; items() works on both.
    results = {k: (100 * v.q2) for k, v in results.items()}

    # Methods ordered by corrected Q^2. The leave-one-out average goes last
    # and has no raw counterpart.
    methods = sorted((m for m, _ in FEATURES),
                     key=lambda m: results[m + '.origins'])
    methods.append('average.loo')

    color0 = plotinfo.color_red
    color1 = plotinfo.color_blue

    fig, ax = plt.subplots(
        figsize=[plotinfo.TEXTWIDTH_1_2_IN, plotinfo.TEXTWIDTH_1_2_IN / 1.6])

    # Raw bars span [i, i+.4) and corrected bars [i+.4, i+.8). The tick and
    # text positions below assume left-edge alignment: the pre-matplotlib-2.0
    # default, now the explicit align='edge'.
    ax.bar(np.arange(len(methods)) + .4,
           [results[m + '.origins'] for m in methods],
           width=.4,
           align='edge',
           color=color0,
           label='Corrected')
    ax.bar(np.arange(len(methods) - 1),
           [results[m + '.origins.raw'] for m in methods[:-1]],
           width=.4,
           align='edge',
           color=color1,
           label='Raw')
    ax.set_ylim(40, 100)
    ax.set_xlabel('Method')
    ax.set_ylabel(r'$Q^2 (\%)$')
    # one tick per method, between its raw and corrected bars
    ax.set_xticks(np.arange(len(methods)) + .4)
    ax.set_xticklabels(methods[:-1] + ['avg'], fontsize=7)
    ax.grid(False)

    # rounded value labels centred above each bar
    for i, m in enumerate(methods[:-1]):
        v = results[m + '.origins.raw']
        ax.text(i + .2,
                v + 0.5,
                '{}'.format(int(np.round(v))),
                fontsize=8,
                horizontalalignment='center')
    for i, m in enumerate(methods):
        v = results[m + '.origins']
        ax.text(i + .6,
                v + 0.5,
                '{}'.format(int(np.round(v))),
                fontsize=8,
                horizontalalignment='center')
    sns.despine(ax=ax)

    ax.legend(loc='upper left', fontsize=7)

    fig.tight_layout()
    fig.savefig('figures/final-plot.pdf')
    fig.savefig('figures/final-plot.svg')
    fig.savefig('figures/final-plot.eps')
コード例 #8
0
ファイル: figutils.py プロジェクト: SeanMcKnight/wqio
def parallel_coordinates(dataframe, hue, cols=None, palette=None, **subplot_kws):
    """ Produce a parallel coordinates plot from a dataframe.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        The data to be plotted.
    hue : string
        The column whose values determine the lines' colors.
    cols : list of strings, optional
        The non-hue columns to include, in plotting order (any iterable of
        column names is accepted). If None, all other columns are used.
    palette : string, optional
        Name of the seaborn color palette to use.
    **subplot_kws : keyword arguments
        Options passed directly to plt.subplots()

    Returns
    -------
    fig : matplotlib Figure

    """

    # get the columns to plot (DataFrame.select was removed in pandas 1.0)
    if cols is None:
        cols = [c for c in dataframe.columns if c != hue]
    else:
        # private copy; also lets tuples and other iterables be extended
        cols = list(cols)

    # subset the data; hue goes last so that row[-1] is the hue value
    data = dataframe[cols + [hue]]

    # these plots look ridiculous in anything other than 'ticks'
    with seaborn.axes_style('ticks'):
        fig, axes = plt.subplots(ncols=len(cols), **subplot_kws)
        hue_vals = dataframe[hue].unique()
        # palette passed positionally: the keyword is ``name`` in old seaborn
        # and ``palette`` in current releases
        colors = seaborn.color_palette(palette, n_colors=len(hue_vals))
        color_dict = dict(zip(hue_vals, colors))

        # one vertical axis per column, scaled to that column's data range
        for col, ax in zip(cols, axes):
            data_limits = [(0, dataframe[col].min()), (0, dataframe[col].max())]
            ax.set_xticks([0])
            ax.update_datalim(data_limits)
            ax.set_xticklabels([col])
            ax.autoscale(axis='y')
            ax.tick_params(axis='y', direction='inout')
            ax.tick_params(axis='x', direction='in')

        # join each row's values across every pair of neighbouring axes
        for row in data.values:
            for n, (ax1, ax2) in enumerate(zip(axes[:-1], axes[1:])):
                _connect_spines(ax1, ax2, row[n], row[n + 1],
                                color=color_dict[row[-1]])

    fig.subplots_adjust(wspace=0)
    seaborn.despine(fig=fig, bottom=True, trim=True)
    return fig
コード例 #9
0
ファイル: resultsPlotter.py プロジェクト: YesperP/Master
def _plotTrackingPercentage(plotData):
    """Plot average tracking percentage against lambda_phi.

    ``plotData`` is read as ``{P_d: {N: {lambda_phi: trackingPercentage}}}``.
    One line is drawn per (P_d, N), ordered by descending P_d and then
    descending N. The line style follows the P_d rank and the colour follows
    the N rank within that P_d. Returns the new figure.
    """
    figure = plt.figure(figsize=(figureWidth, halfPageHeight), dpi=600)
    sns.set_style(style='white')
    ax = figure.gca()

    lambdaPhiSet = set()
    trackingPercentageList = []
    for P_d, d1 in plotData.items():
        for N, d2 in d1.items():
            if not d2:
                continue  # nothing to draw; unpacking an empty zip would raise
            # sort points by lambda_phi so each line is drawn left to right
            x, y = (list(t) for t in zip(*sorted(d2.items())))
            lambdaPhiSet.update(x)
            trackingPercentageList.append((P_d, N, x, y))

    # Descending P_d, then descending N (equivalent to the former two
    # successive stable sorts).
    trackingPercentageList.sort(key=lambda tup: (float(tup[0]), float(tup[1])),
                                reverse=True)

    pdSet = set()
    nSet = set()
    for P_d, N, x, y in trackingPercentageList:
        if P_d not in pdSet:
            nSet.clear()  # colour index restarts for each new P_d
        pdSet.add(P_d)
        nSet.add(N)
        ax.plot(x, y,
                label="$P_D$={0:}, N={1:.0f}".format(P_d, N),
                c=colors[len(nSet) - 1],
                linestyle=linestyleList[len(pdSet) - 1],
                linewidth=linewidth,
                marker='*' if len(x) == 1 else None)

    lambdaPhiList = sorted(lambdaPhiSet)

    ax.legend(loc=0, ncol=len(pdSet), fontsize=legendFontsize)
    # raw string: "\l" and "\p" are invalid escape sequences in a normal literal
    ax.set_xlabel(r"$\lambda_{\phi}$", fontsize=labelFontsize)
    ax.set_ylabel("\nAverage tracking percentage", fontsize=labelFontsize)
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.1e'))
    ax.set_ylim(0.0, 100.01)
    ax.tick_params(labelsize=labelFontsize)
    sns.despine(ax=ax, offset=0)
    ax.xaxis.set_ticks(lambdaPhiList)
    figure.tight_layout(pad=0.8, h_pad=0.8, w_pad=0.8)
    return figure
コード例 #10
0
def multi_colorbar(cmaps, vmins, vmaxs, ax=None, orientation='vertical'):
    """Plot several colorbars with different vmins and vmaxs on the same axis.

    Each colormap is drawn as one strip covering the shared value range
    ``[min(vmins), max(vmaxs)]`` and normalised by its own vmin/vmax.

    Parameters
    ----------
    cmaps : sequence of callables
        Colormaps mapping normalised values to RGBA.
    vmins, vmaxs : sequence of float
        Per-colormap normalisation limits; same length as ``cmaps``.
    ax : matplotlib Axes, optional
        Target axes. A new 3x1 (horizontal) or 1x3 (vertical) inch figure is
        created if omitted.
    orientation : {'vertical', 'horizontal'}

    Returns
    -------
    ax : matplotlib Axes

    Raises
    ------
    ValueError
        On an unknown orientation or mismatched sequence lengths.

    Example
    -------
    ::

        elect_colors = ('blue', 'red', 'green')
        vmins = (.5, -.5, .5)
        vmaxs = (3., 3., 2.)
        cmaps = [sns.light_palette(ecolor, as_cmap=True)
                 for ecolor in elect_colors]
        multi_colorbar(cmaps, vmins, vmaxs, orientation='horizontal')
    """
    if orientation not in ('vertical', 'horizontal'):
        raise ValueError('orientation must be either vertical or horizontal')

    if not len(cmaps) == len(vmins) == len(vmaxs):
        raise ValueError(
            'cmaps, vmins, and vmaxs must all be iterables of the same length')

    if ax is None:
        figsize = (3, 1) if orientation == 'horizontal' else (1, 3)
        _, ax = plt.subplots(figsize=figsize)

    maxmax = max(vmaxs)
    minmin = min(vmins)

    if orientation == 'vertical':
        extent = [-.5, len(cmaps) - .5, maxmax, minmin]
    else:
        extent = [minmin, maxmax, -.5, len(cmaps) - .5]

    # sample points shared by every strip
    values = np.linspace(minmin, maxmax, 100)
    strips = np.array([cmap((values - vmin) / (vmax - vmin))
                       for cmap, vmin, vmax in zip(cmaps, vmins, vmaxs)])
    if orientation == 'vertical':
        strips = strips.swapaxes(0, 1)
    ax.imshow(strips, interpolation='nearest', extent=extent)
    ax.invert_yaxis()
    # act on ``ax``: plt.axis() would target whatever axes is current
    ax.axis('tight')
    if orientation == 'vertical':
        ax.set_xticks([])
    else:
        ax.set_yticks([])
    sns.despine(ax=ax)

    return ax
コード例 #11
0
def plot_angle(data,
               N=50,
               title=None,
               ax1=None,
               ax2=None,
               color=None,
               wrap=True):
    """Plot the distribution of angles as a polar histogram plus a linear density.

    Parameters
    ----------
    data : array-like
        Angles. Each value is passed element-wise through ``wrapAngle``
        (``wrap=True``, axis -180..180) or ``constrainAngle`` (axis 0..360).
        They are treated as degrees (converted with ``np.deg2rad``).
    N : int
        Number of histogram bins.
    title : str, optional
        Figure suptitle.
    ax1, ax2 : matplotlib Axes, optional
        Polar axes and linear axes. Both are created on a 2x6 GridSpec
        unless both are given.
    color : matplotlib colour, optional
    wrap : bool

    Returns
    -------
    fig, (ax1, ax2)
    """
    if ax1 is None or ax2 is None:
        gs = gridspec.GridSpec(2, 6)
        ax1 = plt.subplot(gs[:1, :2], polar=True)
        ax2 = plt.subplot(gs[:1, 2:])

    angle_fn = wrapAngle if wrap else constrainAngle
    angles = np.vectorize(angle_fn)(data)

    sns.distplot(angles, bins=N, ax=ax2, color=color, kde=True)
    # ``normed`` was removed from np.histogram in NumPy 1.24; for the
    # equal-width bins used here ``density`` gives identical values.
    radii, edges = np.histogram(angles, bins=N, density=True)
    ax1.set_yticklabels([])

    if wrap:
        ax1ticks = [0, 45, 90, 135, 180, -135, -90, -45]
        ax2ticks = list(range(-180, 180 + 45, 45))
        ax1.set_xticklabels(['{}°'.format(t) for t in ax1ticks])
        ax2.set_xlim(-180, 180)
        ax2.set_xticks(ax2ticks)
        ax2.set_xticklabels(['{}°'.format(t) for t in ax2ticks])

    else:
        ax2ticks = list(range(0, 360 + 45, 45))
        ax2.set_xlim(0, 360)
        ax2.set_xticks(ax2ticks)
        ax2.set_xticklabels(['{}°'.format(t) for t in ax2ticks])

    ax2.set_yticks([])
    ax2.set(xlabel='Angle', ylabel='Density')

    sns.despine(ax=ax2)

    # Centre each polar bar on its bin and use the bin's real angular width.
    # Previously bars were centred on the right bin edges with width 2*pi/N,
    # which only fits when the data span the full circle.
    centers = np.deg2rad((edges[:-1] + edges[1:]) / 2)
    widths = np.deg2rad(np.diff(edges))
    ax1.bar(centers, radii, width=widths, color=color, alpha=.5)

    if title is not None:
        plt.suptitle(title)

    plt.tight_layout()

    f = plt.gcf()
    return f, (ax1, ax2)
コード例 #12
0
def make_map(state2poly, states, label, figsize=(12, 9)):
    """
    Draw a choropleth map that colours each United States state by a value.

    Inputs
    -------
    state2poly: state geometry dictionary, indexed by state name
    states : Column of a DataFrame
        The value for each state, to display on the map; looked up with
        ``states.loc[name]`` for every name in ``states_abbrev`` (presumably a
        pandas Series indexed by full state name; verify against callers).
        If every value is below 2 it is treated as a probability and
        coloured with RdBu over [0, 1]. Otherwise it is coloured with binary
        over [states.min(), states.max()].
    label : str
        Label of the color bar
    figsize : tuple
        Figure size in inches.

    Returns
    --------
    The map axes
    """
    fig = plt.figure(figsize=figsize)
    ax = fig.gca()

    if states.max() < 2:  # colormap for election probabilities
        cmap = cm.RdBu
        vmin, vmax = 0, 1
    else:  # colormap for electoral votes, or other values
        cmap = cm.binary
        # Span the data range. The former ``0, states.min(), states.max()``
        # raised "too many values to unpack" on every call in this branch.
        vmin, vmax = states.min(), states.max()
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

    skip = {
        'National', 'District of Columbia', 'Guam', 'Puerto Rico',
        'Virgin Islands', 'American Samoa', 'Northern Mariana Islands'
    }
    for state in states_abbrev.values():
        if state in skip:
            continue
        color = cmap(norm(states.loc[state]))
        draw_state(ax, state2poly[state], color=color)

    # add an inset colorbar
    cbar_ax = fig.add_axes([0.45, 0.70, 0.4, 0.02])
    mpl.colorbar.ColorbarBase(cbar_ax,
                              cmap=cmap,
                              norm=norm,
                              orientation='horizontal')
    cbar_ax.set_title(label)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlim(-180, -60)
    ax.set_ylim(15, 75)
    sns.despine(left=True, bottom=True)
    return ax
コード例 #13
0
def plot_waveform(wf_header, wf_data,
                  fig=None, savename=None,
                  use_mv_and_ns=True,
                  color=None):
    """
    Make a plot of a single acquisition

    Args:
        wf_header (dict): custom waveform header; the "xunit", "yunit" and
            "xs" entries are read
        wf_data (np.ndarray): waveform data

    Keyword Args:
        fig (pylab.figure): A figure instance
        savename (str): where to save the figure (full path)
        use_mv_and_ns (bool): use mV and ns instead of V and s
        color: line colour; defaults to the first colour of seaborn's "dark"
            palette
    Returns:
        pylab.fig
    """
    if color is None:
        color = sb.color_palette("dark")[0]

    if fig is None:
        fig = p.figure()
    ax = fig.gca()

    xlabel = wf_header["xunit"]
    ylabel = wf_header["yunit"]
    xs = copy(wf_header["xs"])
    ys = copy(wf_data)

    if xlabel == "s" and ylabel == "V" and use_mv_and_ns:
        # Out-of-place on purpose: ``*=`` on an integer array raises a
        # casting error, while ``*`` upcasts to float.
        xs = xs * 1e9
        ys = ys * 1e3
        xlabel = "ns"
        ylabel = "mV"
    ax.plot(xs, ys, color=color)
    ax.grid()
    sb.despine(fig)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    # lay out the figure drawn on, not whichever figure pyplot considers current
    fig.tight_layout()
    if savename is not None:
        fig.savefig(savename)
    return fig
def _label_drug_conditions(group, saline_condition, muscimol_condition):
    """Set ``attrib['drug_condition']`` on every experiment in ``group``.

    An experiment whose 'drug' attribute contains 'saline' gets
    ``saline_condition``. Otherwise, one containing 'muscimol' gets
    ``muscimol_condition``. Anything else is left untouched.
    """
    for experiment in group:
        drug = experiment.get('drug')
        if 'saline' in drug:
            experiment.attrib['drug_condition'] = saline_condition
        elif 'muscimol' in drug:
            experiment.attrib['drug_condition'] = muscimol_condition


def main():
    """Plot the fraction of licks in the reward zone by drug condition and save it."""
    metadata_file = os.path.join(df.metadata_path, 'expt_metadata.xml')
    expts = lab.ExperimentSet(
        metadata_file,
        behaviorDataPath=os.path.join(df.data_path, 'behavior'),
        dataPath=os.path.join(df.data_path, 'imaging'))

    group_cls = lab.classes.HiddenRewardExperimentGroup
    sal_grp = group_cls.from_json(sal_json, expts, label='saline to muscimol')
    mus_grp = group_cls.from_json(mus_json, expts, label='muscimol to saline')

    fig = plt.figure(figsize=(8.5, 11))
    grid = plt.GridSpec(1, 1, top=0.9, bottom=0.7, left=0.1, right=0.4)
    ax = fig.add_subplot(grid[0, 0])

    # Muscimol-first group: muscimol sessions are 'learning', saline ones
    # 'reversal'. Saline-first group: the other way round.
    _label_drug_conditions(mus_grp, saline_condition='reversal',
                           muscimol_condition='learning')
    _label_drug_conditions(sal_grp, saline_condition='learning',
                           muscimol_condition='reversal')

    plotting.plot_metric(
        ax, [sal_grp, mus_grp], metric_fn=ra.fraction_licks_in_reward_zone,
        label_groupby=False, plotby=['X_drug_condition'],
        plot_method='swarm', rotate_labels=False,
        activity_label='Fraction of licks in reward zone',
        colors=sns.color_palette('deep'), plot_bar=True)
    ax.set_yticks([0, 0.1, 0.2, 0.3, 0.4])
    ax.set_ylim(top=0.4)
    ax.set_xticklabels(['Days 1-3', 'Day 4'])

    sns.despine(fig)
    for blank_setter in (ax.set_title, ax.set_xlabel):
        blank_setter('')

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
コード例 #15
0
ファイル: samples.py プロジェクト: Geosyntec/pycvc
    def make_samplefig(self, **figkwargs):
        """ Build a figure of the storm's hyetograph and hydrograph with the
        timing of the water quality samples overlaid, then save it.

        Parameters
        ----------
        figkwargs : keyword arguments
            Figure options, forwarded as ``figopts`` to
            ``self.storm.summaryPlot``.

        Writes
        ------
        The figure, via ``viz.savefig`` to ``self.storm_figure`` with
        ``asPDF=True`` (presumably a .png plus a .pdf).

        Returns
        -------
        fig : matplotlib.figure
            The instance of the figure.

        """
        serieslabels = dict([
            (self.storm.outflowcol, 'Effluent (L/s)'),
            (self.storm.precipcol, '10-min Precip Depth (mm)'),
        ])

        fig, artists, labels = self.storm.summaryPlot(
            inflow=False,
            showLegend=False,
            figopts=figkwargs,
            serieslabels=serieslabels,
        )
        rug = self.plot_ts(ax=fig.axes[1], isFocus=True, asrug=False)

        precip_ax, flow_ax = fig.axes[0], fig.axes[1]
        precip_ax.set_ylabel('Precip (mm)')
        flow_ax.set_ylabel('BMP Effluent (L/s)')
        seaborn.despine(ax=flow_ax)
        # precip axes keeps its top spine and loses the bottom one
        seaborn.despine(ax=precip_ax, bottom=True, top=False)

        artists.append(rug)
        labels.append('Samples')

        legend = precip_ax.legend(
            artists, labels,
            fontsize=7, ncol=1, markerscale=0.75,
            frameon=False, loc='lower right',
        )
        legend.get_frame().set_zorder(25)

        viz.savefig(fig, self.storm_figure, extra='Storm', asPDF=True)
        return fig
コード例 #16
0
def plot_final_results(results):
    """Bar-plot raw vs corrected Q^2 per feature method and save the figure.

    Parameters
    ----------
    results : dict
        Maps result names to objects with a ``q2`` attribute. For each method
        ``m`` in ``FEATURES`` the keys ``m + '.origins'`` and
        ``m + '.origins.raw'`` are read, plus ``'average.loo.origins'``.

    Writes ``figures/final-plot.{pdf,svg,eps}``.
    """
    import numpy as np
    from matplotlib import pyplot as plt
    try:
        import seaborn.apionly as sns
    except ImportError:
        # seaborn.apionly was removed in seaborn 0.9; from 0.8 on a plain
        # import no longer changes the global style, so this is equivalent.
        import seaborn as sns

    # iteritems() exists only on Python 2; items() works on both.
    results = {k: (100 * v.q2) for k, v in results.items()}

    # Methods ordered by corrected Q^2. The leave-one-out average goes last
    # and has no raw counterpart.
    methods = sorted((m for m, _ in FEATURES),
                     key=lambda m: results[m + '.origins'])
    methods.append('average.loo')

    color0 = plotinfo.color_red
    color1 = plotinfo.color_blue

    fig, ax = plt.subplots(
        figsize=[plotinfo.TEXTWIDTH_1_2_IN, plotinfo.TEXTWIDTH_1_2_IN / 1.6])

    # Raw bars span [i, i+.4) and corrected bars [i+.4, i+.8). The tick and
    # text positions below assume left-edge alignment: the pre-matplotlib-2.0
    # default, now the explicit align='edge'.
    ax.bar(np.arange(len(methods)) + .4,
           [results[m + '.origins'] for m in methods],
           width=.4, align='edge', color=color0, label='Corrected')
    ax.bar(np.arange(len(methods) - 1),
           [results[m + '.origins.raw'] for m in methods[:-1]],
           width=.4, align='edge', color=color1, label='Raw')
    ax.set_ylim(40, 100)
    ax.set_xlabel('Method')
    ax.set_ylabel(r'$Q^2 (\%)$')
    # one tick per method, between its raw and corrected bars
    ax.set_xticks(np.arange(len(methods)) + .4)
    ax.set_xticklabels(methods[:-1] + ['avg'], fontsize=7)
    ax.grid(False)

    # rounded value labels centred above each bar
    for i, m in enumerate(methods[:-1]):
        v = results[m + '.origins.raw']
        ax.text(i + .2, v + 0.5, '{}'.format(int(np.round(v))),
                fontsize=8, horizontalalignment='center')
    for i, m in enumerate(methods):
        v = results[m + '.origins']
        ax.text(i + .6, v + 0.5, '{}'.format(int(np.round(v))),
                fontsize=8, horizontalalignment='center')
    sns.despine(ax=ax)

    ax.legend(loc='upper left', fontsize=7)

    fig.tight_layout()
    fig.savefig('figures/final-plot.pdf')
    fig.savefig('figures/final-plot.svg')
    fig.savefig('figures/final-plot.eps')
コード例 #17
0
def plot_results(name, results, fractions, in_figures=False):
    """Scatter automatic predictions against human labels and save the plot.

    Parameters
    ----------
    name : str
        Base file name; ``<name>.eps`` and ``<name>.pdf`` are written.
    results : array-like
        Human-labelled values (x axis).
    fractions : array-like
        Automatically predicted values (y axis).
    in_figures : bool
        Write into ``figures/`` instead of ``outputs/``.
    """
    import os
    from matplotlib import pyplot as plt
    try:
        import seaborn.apionly as sns
    except ImportError:
        # seaborn.apionly was removed in seaborn 0.9; from 0.8 on a plain
        # import no longer changes the global style, so this is equivalent.
        import seaborn as sns
    fig, ax = plt.subplots(
        figsize=[plotinfo.TEXTWIDTH_1_2_IN, plotinfo.TEXTWIDTH_1_2_IN * .75])

    # identity line = perfect agreement, drawn underneath the points
    ax.plot([0, 1], [0, 1], c=plotinfo.color_blue, zorder=0)
    ax.scatter(results, fractions, s=16, lw=1, edgecolor='k',
               c=plotinfo.color_red, zorder=1)
    ax.set_ylim(-.04, 1.04)
    ax.set_xlim(-.04, 1.04)
    ax.grid(False)
    ax.set_xlabel('Human labeled')
    ax.set_ylabel('Automatic prediction')
    sns.despine(ax=ax, trim=True)
    fig.tight_layout()

    # os.path.join avoids the 'figures//name.eps' double slash
    directory = 'figures' if in_figures else 'outputs'
    mkdir_p(directory)
    fig.savefig(os.path.join(directory, '{}.eps'.format(name)))
    fig.savefig(os.path.join(directory, '{}.pdf'.format(name)))
    plt.close(fig)
コード例 #18
0
def plot(rs):
    """Compare two human labelers' NET fractions and save a scatter plot.

    Parameters
    ----------
    rs : ndarray
        Unpacked as ``real, alt, _ = rs.T``, so it must have exactly three
        columns: labeler 1, labeler 2, and one that is unused here.

    The squared Pearson correlation (R^2) and Q^2 of labeler 2 predicting
    labeler 1 are printed on the plot as rounded percentages. Writes
    ``figures/human-comparison.{pdf,eps,png}``.
    """
    from matplotlib import pyplot as plt
    try:
        import seaborn.apionly as sns
    except ImportError:
        # seaborn.apionly was removed in seaborn 0.9; from 0.8 on a plain
        # import no longer changes the global style, so this is equivalent.
        import seaborn as sns
    import plotinfo

    real, alt, _ = rs.T
    rh2 = np.corrcoef(real, alt)[0, 1]**2
    # Q^2 = 1 - SS_residual / SS_total, treating ``alt`` as a prediction of ``real``
    resid = real - alt
    centred = real - real.mean()
    q2 = 1. - np.dot(resid, resid) / np.dot(centred, centred)

    w = plotinfo.TEXTWIDTH_1_2_IN
    fig, ax = plt.subplots(figsize=[w, w * 0.75])
    ax.text(.2,
            .84,
            r'$R^2: {}\%$'.format(int(np.round(rh2 * 100))),
            fontsize=14)
    ax.text(.2,
            .70,
            r'$Q^2: {}\%$'.format(int(np.round(q2 * 100))),
            fontsize=14)
    ax.set_xlabel(r'${}$Human labeler 1 (fraction assigned to NET)${}$')
    ax.set_ylabel(r'${}$Human labeler 2 (fraction assigned to NET)${}$')
    ax.set_ylim(-.04, 1.04)
    ax.set_xlim(-.04, 1.04)

    ax.scatter(real,
               alt,
               s=16,
               lw=1,
               edgecolor='k',
               c=plotinfo.color_red,
               zorder=1)
    # identity line across the joint range of both labelers
    both = rs.T[:2]
    lo, hi = both.min(), both.max()
    ax.plot([lo, hi], [lo, hi], c=plotinfo.color_blue, zorder=0)
    sns.despine(ax=ax, offset=True, trim=True)
    # Lay out after despine: offset/trim move the spines and tick labels.
    fig.tight_layout()

    fig.savefig('figures/human-comparison.pdf')
    fig.savefig('figures/human-comparison.eps')
    fig.savefig('figures/human-comparison.png', dpi=1200)
コード例 #19
0
    def plot_lung_cloud(self,
                        x_index,
                        y_index,
                        ax=None,
                        color=blue,
                        alpha=0.002,
                        slice_index=2,
                        slice_max=1e3,
                        slice_min=-1e3,
                        do_slice=False,
                        skip_n=1):
        """Plot ``self.lung`` as a dot cloud projected onto two of its columns.

        Parameters
        ----------
        x_index, y_index : int
            Columns of ``self.lung`` used for the horizontal and vertical axes.
        ax : matplotlib Axes, optional
            Target axes. Defaults to ``plt.subplot(111)`` on the current
            figure, created at call time.
        color, alpha :
            Marker colour and transparency.
        slice_index, slice_min, slice_max, do_slice :
            When ``do_slice`` is true, only rows with
            ``slice_min < lung[:, slice_index] < slice_max`` (strict) are
            plotted.
        skip_n : int
            Plot every ``skip_n``-th point.

        Returns
        -------
        ax : matplotlib Axes
            The axes drawn on; both of its axes are inverted.
        """
        if ax is None:
            # Created per call. As a default argument, ``plt.subplot(111)``
            # ran once at import and every call drew onto that same axes.
            ax = plt.subplot(111)

        lung = self.lung
        if do_slice:
            in_slab = ((lung[:, slice_index] < slice_max)
                       & (lung[:, slice_index] > slice_min))
            lung = lung[in_slab]

        ax.plot(lung[::skip_n, x_index], lung[::skip_n, y_index], '.',
                color=color, alpha=alpha, markersize=3)
        ax.set_aspect('equal')
        ax.invert_yaxis()
        ax.invert_xaxis()
        sns.despine(ax=ax, top=True, right=True, left=False, bottom=False)
        return ax
コード例 #20
0
def plot_histogram(bincenters,bincontent,\
                   fig=None,savename="test.png",\
                   remove_empty_bins=True):
    """
    Plot a histogram returned by TektronixDPO4104B.get_histogram
    Use pylab.plot

    Args:
        bincenters (np.ndarray); bincenters (x)
        bincontent (np.ndarray): bincontent (y)

    Keyword Args:
        fig (pylab.figure): A figure instance
        savename (str): where to save the figure (full path)
        remove_empty_bins (bool): Cut away preceeding and trailing zero bins


    """

    if fig is None:
        fig = p.figure()
    ax = fig.gca()
    if remove_empty_bins:
        bmin = min(bincenters[bincontent > 0])
        bmax = max(bincenters[bincontent > 0])
        bincenters = bincenters[np.logical_and(bincenters >= bmin,
                                               bincenters <= bmax)]
        bincontent = bincontent[np.logical_and(bincenters >= bmin,
                                               bincenters <= bmax)]

    ax.plot(bincenters, bincontent, color=sb.color_palette("dark")[0])
    ax.grid()
    sb.despine(fig)
    ax.set_xlabel("amplitude")
    ax.set_ylabel("log nevents ")
    p.tight_layout()
    fig.savefig(savename)
    return fig
コード例 #21
0
def calculate_peak_to_valley_ratio(bestfitmodel,
                                   mu_ped,
                                   mu_spe,
                                   control_plot=False):
    """
    Calculate the peak to valley ratio

    The valley is the minimum of the model prediction strictly between
    ``mu_ped`` and ``mu_spe``. The peak is the maximum of the prediction over
    all x beyond the valley position.

    Args:
        bestfitmodel (fit.Model): A fitted model to charge response data;
            its ``xs`` attribute and ``prediction(xs)`` method are used
        mu_ped (float): The x value of the fitted pedestal
        mu_spe (float): The x value of the fitted spe peak

    Keyword Args:
        control_plot (bool): Show control plot to see if correct values are found

    Returns:
        float: peak / valley

    Raises:
        ValueError: if no model point lies strictly between mu_ped and
            mu_spe, or none lies beyond the valley
    """
    tmpdata = np.asarray(bestfitmodel.prediction(bestfitmodel.xs))
    xs = np.asarray(bestfitmodel.xs)

    between = np.logical_and(xs > mu_ped, xs < mu_spe)
    if not between.any():
        raise ValueError("no model points between mu_ped and mu_spe")
    # Use a single valley position. Selecting every x whose value equals the
    # minimum broke the comparison below when the minimum value repeated.
    valley_idx = np.flatnonzero(between)[np.argmin(tmpdata[between])]
    valley = tmpdata[valley_idx]
    valley_x = xs[valley_idx]

    beyond = xs > valley_x
    if not beyond.any():
        raise ValueError("no model points beyond the valley")
    peak_idx = np.flatnonzero(beyond)[np.argmax(tmpdata[beyond])]
    peak = tmpdata[peak_idx]
    peak_x = xs[peak_idx]
    peak_v_ratio = peak / valley

    if control_plot:
        fig = p.figure()
        ax = fig.gca()
        ax.plot(xs, tmpdata)

        ax.scatter(valley_x, valley, marker="o")
        ax.scatter(peak_x, peak, marker="o")
        ax.set_yscale("log")
        ax.set_ylim(bottom=1e-4)
        ax.grid(1)
        sb.despine()

    return peak_v_ratio
コード例 #22
0
ファイル: resultsPlotter.py プロジェクト: YesperP/Master
def _plotTimeLog(plotData):
    """Plot average tracking-iteration runtime against N.

    ``plotData`` is read as
    ``{P_d: {lambda_phi: {N: (meanRuntime, percentiles)}}}``; the percentiles
    are not used here. The colour is indexed by the rank of P_d (descending)
    and the line style by the rank of lambda_phi (ascending).
    Returns the new figure.
    """
    figure = plt.figure(figsize=(figureWidth, halfPageHeight), dpi=600)
    # Set the style before the axes exists: axes read the style at creation,
    # so setting it after ``gca()`` had no effect on ``ax``.
    sns.set_style(style='white')
    ax = figure.gca()

    nSet = set()
    for j, P_d in enumerate(sorted(plotData, reverse=True)):
        d1 = plotData[P_d]
        for i, lambda_phi in enumerate(sorted(d1)):
            d2 = d1[lambda_phi]
            if not d2:
                continue  # nothing to draw; unpacking an empty zip would raise
            # points sorted by N so each line is drawn left to right
            points = sorted((N, meanRuntime) for N, (meanRuntime, _) in d2.items())
            x, y = (list(t) for t in zip(*points))
            nSet.update(x)
            # raw string: "\l" and "\p" are invalid escapes in a normal literal
            ax.plot(x, y, linestyle=linestyleList[i], color=colors[j],
                    label=r"$P_D$={0:}, $\lambda_\phi$={1:}".format(P_d, lambda_phi),
                    linewidth=linewidth)

    ax.set_xlim(ax.get_xlim()[0] - 0.5, ax.get_xlim()[1] + 0.5)
    ax.set_ylim(0, 1)
    ax.set_title("Tracking iteration runtime", fontsize=titleFontsize)
    ax.set_xlabel("N", fontsize=labelFontsize, labelpad=0)
    ax.set_ylabel("Average iteration time [s]", fontsize=labelFontsize)
    ax.xaxis.set_ticks(sorted(nSet))
    ax.tick_params(labelsize=labelFontsize)

    ax.legend(loc=0, fontsize=legendFontsize)
    ax.grid(False)

    sns.despine(ax=ax)
    figure.tight_layout(pad=0.5, h_pad=0.5, w_pad=0.5)
    return figure
コード例 #23
0
ファイル: calc_rmsf.py プロジェクト: JHP4911/script
def plot_rmsf(data, resid_data):
    """Plot per-residue RMSF as a filled line, save it to rmsf.svg and show it.

    Args:
        data: RMSF values, one per residue (in Å according to the axis label).
        resid_data: object whose ``resids`` attribute gives the residue
            numbers for the x axis (presumably an MDAnalysis atom/residue
            group; TODO confirm).
    """
    line_color = '#810f7c'

    # global style changes: ggplot base, then seaborn 'ticks'
    plt.style.use('ggplot')
    sns.set_style('ticks')

    ax = plt.subplot(111)
    ax.plot(resid_data.resids, data, '-', linewidth=1, color=line_color)
    ax.fill_between(resid_data.resids, data, alpha=0.1, color=line_color)

    sns.despine(ax=ax, offset=2.5)
    ax.set_xlabel("Residue number")
    ax.set_ylabel(r"RMSF ($\AA$)")

    # No legend: nothing here carries a label, so plt.legend() only produced
    # a "no artists with labels" warning and an empty legend box.
    ax.figure.savefig('rmsf.svg', format='svg')
    plt.show()
コード例 #24
0
def cornerplot(datas,
               labels=None,
               interact_plot=None,
               lims=(-.1, .2),
               highlight=None):
    """Draw a lower-triangular grid comparing every pair of datasets.

    Diagonal cells show a histogram of each dataset. The cell in row ``i``,
    column ``j`` (``i > j``) plots dataset ``j`` on x against dataset ``i``
    on y via ``interact_plot``. Cells above the diagonal are hidden.

    Parameters
    ----------
    datas : sized sequence of array-like
        Datasets to compare (``len(datas)`` sets the grid size).
    labels : sequence of str, optional
        One label per dataset, placed on the left column and bottom row.
    interact_plot : callable, optional
        ``interact_plot(data2, data1)`` draws into the current axes. Defaults
        to a dot scatter with a grey dashed identity line over the current
        x range.
    lims : (float, float) or None, optional
        Limits applied to every visible cell; ``None`` keeps autoscaling.
    highlight : index or boolean mask, optional
        Points re-drawn in red by the default ``interact_plot``.

    Returns
    -------
    (matplotlib.figure.Figure, numpy.ndarray)
        The figure and its 2-D array of axes.
    """
    if interact_plot is None:

        def interact_plot(data2, data1):
            plt.plot(data2, data1, '.')
            if highlight is not None:
                plt.plot(data2[highlight], data1[highlight], 'r.')
            # y = x reference line across the current x range.
            xlim = plt.xlim()
            plt.plot(xlim, xlim, '--', color='grey')
            sns.despine(ax=plt.gca())

    n = len(datas)
    # squeeze=False keeps `axs` 2-D even for a single dataset; otherwise
    # plt.subplots(1, 1) returns a bare Axes and axs[i, j] fails.
    fig, axs = plt.subplots(n, n, figsize=(8, 8), squeeze=False)
    plt.subplots_adjust(hspace=.3, wspace=.3)

    for i, data1 in enumerate(datas):
        if labels:
            axs[i, 0].set_ylabel(labels[i])
        for j, data2 in enumerate(datas):
            if labels and i == n - 1:
                axs[n - 1, j].set_xlabel(labels[j])
            ax = axs[i, j]
            if i == j:
                ax.hist(data1)
                if lims is not None:
                    ax.set_xlim(lims)
                sns.despine(ax=ax)
            elif i > j:
                # interact_plot draws via pyplot, so make this cell current.
                plt.sca(ax)
                interact_plot(data2, data1)
                if lims is not None:
                    ax.set_xlim(lims)
                    ax.set_ylim(lims)
            else:
                ax.axis('off')

    return fig, axs
def _format_belt_xaxis(ax):
    """Show one full belt lap (-pi..pi rad) on ``ax``, labelled as belt fractions."""
    ax.set_xlim(-np.pi, np.pi)
    ax.set_xticks([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])
    ax.set_xticklabels(['-0.50', '-0.25', '0', '0.25', '0.50'])


def show_parameters(axs, model, enrich, color='b'):
    """Plot a model's position-dependent parameters and the enrichment curve.

    Parameters
    ----------
    axs : sequence of 4 matplotlib Axes
        Targets, in order: recurrence probability, mean position shift,
        position-shift variance, and enrichment across iterations.
    model : object
        Must provide ``shift_mean_var(positions)`` returning ``(bs, ks)`` and
        ``recur_by_position(positions)``. Both are evaluated on 1000 positions
        in radians over [-pi, pi]. ``1 / ks`` is plotted as the variance
        (presumably ``ks`` is a concentration/precision -- confirm).
    enrich : 2-D array-like
        Enrichment values. Columns are iterations; rows are presumably
        independent runs (verify). The mean and 5th/95th percentiles over
        rows (axis 0) are plotted.
    color : matplotlib color, optional
        Colour used for every curve.
    """
    positions = np.linspace(-np.pi, np.pi, 1000)

    bs, ks = model.shift_mean_var(positions)
    recur = model.recur_by_position(positions)

    # Recurrence probability vs. distance from reward.
    axs[0].plot(positions, recur, color=color)
    axs[0].axvline(ls='--', color='0.4', lw=0.5)
    _format_belt_xaxis(axs[0])
    axs[0].set_ylim(-0.3, 1.3)
    axs[0].set_yticks([0, 0.5, 1])
    axs[0].tick_params(length=3, pad=1, top=False)
    axs[0].set_xlabel('Distance from reward (fraction of belt)')
    axs[0].set_ylabel('Recurrence probability')

    # Mean position shift. Values are radians; the ticks are placed at
    # fraction * 2*pi so they can be labelled as fractions of the belt.
    axs[1].plot(positions, bs, color=color)
    axs[1].axvline(ls='--', color='0.4', lw=0.5)
    axs[1].axhline(ls='--', color='0.4', lw=0.5)
    axs[1].tick_params(length=3, pad=1, top=False)
    _format_belt_xaxis(axs[1])
    axs[1].set_ylim(-0.10 * 2 * np.pi, 0.10 * 2 * np.pi)
    shift_ticks = np.array(['-0.10', '-0.05', '0', '0.05', '0.10'])
    axs[1].set_yticks(shift_ticks.astype('float') * 2 * np.pi)
    axs[1].set_yticklabels(shift_ticks)
    axs[1].set_xlabel('Initial distance from reward (fraction of belt)')
    axs[1].set_ylabel(r'$\Delta$ position (fraction of belt)')

    # Position-shift variance; a variance scales with (2*pi)**2.
    axs[2].plot(positions, 1 / ks, color=color)
    axs[2].axvline(ls='--', color='0.4', lw=0.5)
    axs[2].tick_params(length=3, pad=1, top=False)
    _format_belt_xaxis(axs[2])
    axs[2].set_ylim(0, 1)
    var_ticks = np.array(['0', '0.005', '0.010', '0.015', '0.020', '0.025'])
    axs[2].set_yticks(var_ticks.astype('float') * (2 * np.pi)**2)
    axs[2].set_yticklabels(var_ticks)
    axs[2].set_xlabel('Initial distance from reward (fraction of belt)')
    axs[2].set_ylabel(r'$\Delta$ position variance')

    # Enrichment: mean over rows plus a shaded 5th-95th percentile band.
    # The iteration count comes from the data instead of a hard-coded 9.
    enrich = np.asarray(enrich)
    iterations = np.arange(enrich.shape[1])
    low, high = np.percentile(enrich, [5, 95], axis=0)

    axs[3].plot(iterations, np.mean(enrich, axis=0), color=color)
    axs[3].plot(iterations, low, ls='--', color=color)
    axs[3].plot(iterations, high, ls='--', color=color)
    axs[3].fill_between(iterations, low, high, facecolor=color, alpha=0.5)
    axs[3].axhline(0, ls='--', color='0.4', lw=0.5)

    sns.despine(ax=axs[3])
    axs[3].tick_params(length=3, pad=2, direction='out')
    axs[3].set_xlabel("Iteration ('session' #)")
    axs[3].set_ylabel('Enrichment (fraction of belt)')
    axs[3].set_xlim(-0.5, len(iterations) - 0.5)
    axs[3].set_xticks(iterations[::2])
    axs[3].set_ylim(-0.15, 0.10 * 2 * np.pi)
    enrich_ticks = np.array(['0', '0.05', '0.10'])
    axs[3].set_yticks(enrich_ticks.astype('float') * 2 * np.pi)
    axs[3].set_yticklabels(enrich_ticks)
コード例 #26
0
ファイル: recsys_compare.py プロジェクト: BigR-Lab/modl
def plot_benchs():
    """Plot test RMSE against CPU time per benchmark and save ``bench.pdf``.

    For each version in ('1m', '10m', 'netflix'), results are read from
    ``<trace_dir>/benches/results_<version>.json``. A version whose file
    cannot be opened is skipped, leaving its panel empty. Each file maps an
    algorithm key ('dl_partial', 'dl' or 'cd') to a dict holding 'timings'
    and 'rmse' sequences. The figure is written to ``<trace_dir>/bench.pdf``.
    """
    output_dir = join(trace_dir, 'benches')

    fig = plt.figure()
    fig.subplots_adjust(left=.06, right=.9, bottom=.12, top=.905)
    fig.set_figheight(fig.get_figheight() * 0.66)
    gs = gridspec.GridSpec(1, 3, width_ratios=[1, 1, 1.5])

    ylims = {'100k': [.90, .96], '1m': [.864, .915], '10m': [.80, .868],
             'netflix': [.93, .99]}
    xlims = {'100k': [0.0001, 10], '1m': [0.1, 20], '10m': [1, 400],
             'netflix': [30, 3000]}

    names = {'dl_partial': 'Proposed \n(partial projection)',
             'dl': 'Proposed \n(full projection)',
             'cd': 'Coordinate descent'}
    # Coordinate descent is drawn above the proposed methods.
    zorder = {'cd': 10,
              'dl': 1,
              'dl_partial': 5}
    # The palette is identical for every panel, so build it once.
    palette = sns.cubehelix_palette(3, start=0, rot=.5, hue=1, dark=.3,
                                    light=.7,
                                    reverse=False)
    color = {'dl_partial': palette[2], 'dl': palette[1], 'cd': palette[0]}
    grey = (.6, .6, .6)

    for i, version in enumerate(['1m', '10m', 'netflix']):
        try:
            with open(join(output_dir, 'results_%s.json' % version), 'r') as f:
                results = json.load(f)
        except IOError:
            # No readable results file for this version: skip its panel.
            continue

        ax_time = fig.add_subplot(gs[0, i])
        # A bare grid() toggles visibility, so calling it twice turns the
        # grid back off; request it explicitly instead.
        ax_time.grid(True)
        sns.despine(ax=ax_time)

        ax_time.spines['left'].set_color(grey)
        ax_time.spines['bottom'].set_color(grey)
        ax_time.xaxis.set_tick_params(color=grey, which='both')
        ax_time.yaxis.set_tick_params(color=grey, which='both')
        # Tick.label is deprecated (removed in recent matplotlib).
        # tick_params also styles ticks created later by set_xticks.
        ax_time.tick_params(axis='both', which='major', labelsize=7,
                            labelcolor='black')

        if i == 0:
            ax_time.set_ylabel('RMSE on test set')
        if i == 2:
            ax_time.set_xlabel('CPU time')
            ax_time.xaxis.set_label_coords(1.14, -0.06)

        for idx in sorted(results):
            this_result = results[idx]
            ax_time.plot(this_result['timings'], this_result['rmse'],
                         label=names[idx], color=color[idx],
                         linewidth=2,
                         linestyle='--' if idx == 'cd' else '-',
                         zorder=zorder[idx])
        if version == 'netflix':
            ax_time.legend(loc='upper left', bbox_to_anchor=(.65, 1.1),
                           numpoints=1,
                           frameon=False)
        ax_time.set_xscale('log')
        ax_time.set_ylim(ylims[version])
        ax_time.set_xlim(xlims[version])
        if version == '1m':
            ax_time.set_xticks([.1, 1, 10])
            ax_time.set_xticklabels(['0.1 s', '1 s', '10 s'])
        elif version == '10m':
            ax_time.set_xticks([1, 10, 100])
            ax_time.set_xticklabels(['1 s', '10 s', '100 s'])
        else:
            ax_time.set_xticks([100, 1000])
            ax_time.set_xticklabels(['100 s', '1000 s'])

        if version == 'netflix':
            title, title_x = 'Netflix (140M)', .4
        else:
            title, title_x = 'MovieLens %s' % version.upper(), .5
        ax_time.annotate(title, xy=(title_x, 1),
                         xycoords='axes fraction', ha='center', va='bottom')
    fig.savefig(join(trace_dir, 'bench.pdf'))
コード例 #27
0
ファイル: recsys_compare.py プロジェクト: BigR-Lab/modl
def plot_learning_rate():
    """Plot test RMSE per epoch for several learning rates; save the PDF.

    Reads ``<trace_dir>/learning_rate/results_<version>.json`` for the
    versions '10m' and 'netflix'. A missing file raises. Each file maps a
    learning rate (string key) to a dict with an 'rmse' sequence. That
    sequence is spread evenly over ``[0, n_epochs]``, where ``n_epochs``
    comes from ``_get_hyperparams()['dl_partial'][version]``. The figure is
    written to ``<trace_dir>/learning_rate.pdf``.
    """
    output_dir = join(trace_dir, 'learning_rate')
    fig = plt.figure()
    fig.subplots_adjust(bottom=0.33, top=0.99, right=0.98)

    fig.set_figwidth(3.25653379549)
    fig.set_figheight(1.25)
    gs = gridspec.GridSpec(1, 2)
    palette = sns.cubehelix_palette(10, start=0, rot=3, hue=1, dark=.3,
                                    light=.7,
                                    reverse=False)
    grey = (.6, .6, .6)

    ax = []
    for j, version in enumerate(['10m', 'netflix']):
        with open(join(output_dir, 'results_%s.json' % version), 'r') as f:
            data = json.load(f)
        this_ax = fig.add_subplot(gs[j])
        ax.append(this_ax)
        n_epochs = _get_hyperparams()['dl_partial'][version]['n_epochs']
        # JSON object keys are strings; order them numerically.
        for i, learning_rate in enumerate(sorted(data, key=float)):
            this_data = data[learning_rate]
            this_ax.plot(np.linspace(0, n_epochs, len(this_data['rmse'])),
                         this_data['rmse'],
                         label='%.2f' % float(learning_rate),
                         color=palette[i],
                         # larger learning rates are drawn on top
                         zorder=int(100 * float(learning_rate)))
        # Set once per axes rather than once per plotted curve.
        this_ax.set_xscale('log')
        # Despine this axes only (passing the dict of axes as `ax` was
        # silently ignored by seaborn in favour of `fig`).
        sns.despine(ax=this_ax)

        this_ax.spines['left'].set_color(grey)
        this_ax.spines['bottom'].set_color(grey)
        this_ax.xaxis.set_tick_params(color=grey, which='both')
        this_ax.yaxis.set_tick_params(color=grey, which='both')
        this_ax.tick_params(axis='y', labelsize=6)

    ax[0].set_ylabel('RMSE on test set')
    ax[0].set_xlabel('Epoch', ha='left', va='top')
    ax[0].xaxis.set_label_coords(-.18, -0.055)

    ax[0].set_xlim([.1, 40])
    ax[0].set_xticks([1, 10, 40])
    ax[0].set_xticklabels(['1', '10', '40'])
    ax[1].set_xlim([.1, 25])
    ax[1].set_xticks([.1, 1, 10, 20])
    ax[1].set_xticklabels(['.1', '1', '10', '20'])

    ax[0].annotate('MovieLens 10M', xy=(.95, .9), ha='right',
                   xycoords='axes fraction', zorder=100)
    ax[1].annotate('Netflix', xy=(.95, .9), ha='right',
                   xycoords='axes fraction', zorder=100)

    ax[0].set_ylim([0.795, 0.877])
    ax[1].set_ylim([0.93, .999])
    # A single shared legend below the first panel (fontsize already 7).
    ax[0].legend(ncol=4, loc='upper left', bbox_to_anchor=(-0.09, -.13),
                 fontsize=7, numpoints=1, columnspacing=.3, frameon=False)
    ax[0].annotate('Learning rate $\\beta$', xy=(1.6, -.38),
                   xycoords='axes fraction')

    fig.savefig(join(trace_dir, 'learning_rate.pdf'))
コード例 #28
0
def main():
    """Build and save the condition 'C' ("Condition III") place-cell figure.

    Panels:
      * place-cell position heatmaps for the WT and Df groups on X_day '0'
        (titled Day 1) and X_day '2' (titled Day 3) of condition 'C', with
        one colorbar per group row;
      * fraction of place cells near the reward, by condition and day;
      * that fraction paired against the fraction of licks in the reward
        zone, with a regression fit.

    Relies on module-level names defined elsewhere in this file (MALES_ONLY,
    WT_color, Df_color, WT_filter, Df_filter, roi_filters, colors, markers,
    linestyles, filename, save_dir and the helpers fix_heatmap_ax,
    day_number_only_label, label_conditions). The figure is written with
    ``misc.save_figure`` and every open figure is then closed.
    """
    all_grps = df.loadExptGrps('GOL')

    WT_expt_grp = all_grps['WT_place_set']
    Df_expt_grp = all_grps['Df_place_set']
    expt_grps = [WT_expt_grp, Df_expt_grp]
    if MALES_ONLY:
        # Keep only experiments whose parent is marked sex 'M'. The return
        # value of filter() is discarded, so it presumably filters in place.
        for expt_grp in expt_grps:
            expt_grp.filter(lambda expt: expt.parent.get('sex') == 'M')

    WT_label = WT_expt_grp.label()
    Df_label = Df_expt_grp.label()

    fig = plt.figure(figsize=(8.5, 11))

    # Heatmap grids: day-1 column (gs1) on the left, day-3 column (gs1_2)
    # on the right; heatmaps use the first 4 of each grid's 5 columns.
    gs1 = plt.GridSpec(2,
                       5,
                       left=0.1,
                       right=0.3,
                       top=0.90,
                       bottom=0.67,
                       hspace=0.2)
    gs1_2 = plt.GridSpec(2,
                         5,
                         left=0.3,
                         right=0.5,
                         top=0.90,
                         bottom=0.67,
                         hspace=0.2)
    WT_1_heatmap_ax = fig.add_subplot(gs1[0, :-1])
    WT_3_heatmap_ax = fig.add_subplot(gs1_2[0, :-1])
    Df_1_heatmap_ax = fig.add_subplot(gs1[1, :-1])
    Df_3_heatmap_ax = fig.add_subplot(gs1_2[1, :-1])

    # Same extent as gs1_2 but 10 columns: the last column (a thin strip
    # inside the space the day-3 heatmaps leave free) holds the colorbars.
    gs_cbar = plt.GridSpec(2,
                           10,
                           left=0.3,
                           right=0.5,
                           top=0.90,
                           bottom=0.67,
                           hspace=0.2)
    WT_colorbar_ax = fig.add_subplot(gs_cbar[0, -1])
    Df_colorbar_ax = fig.add_subplot(gs_cbar[1, -1])

    # Lower row: fraction-near-reward panel (left) and its correlation with
    # licking behaviour (right), separated by one empty column.
    gs2 = plt.GridSpec(1, 10, left=0.1, right=0.5, top=0.6, bottom=0.45)
    pf_close_fraction_ax = fig.add_subplot(gs2[0, :4])
    pf_close_behav_corr_ax = fig.add_subplot(gs2[0, 5:])

    # Axis ranges for the correlation panel, padded slightly past [0, 0.55].
    frac_near_range_2 = (-0.051, 0.551)
    behav_range_2 = (-0.051, 0.551)

    #
    # Heatmaps
    #

    WT_cmap = sns.light_palette(WT_color, as_cmap=True)
    WT_dataframe = lab.ExperimentGroup.dataframe(
        WT_expt_grp, include_columns=['X_condition', 'X_day', 'X_session'])

    # WT, condition 'C', X_day '0' (titled Day 1), X_session '0'.
    WT_1_expt_grp = WT_expt_grp.subGroup(
        list(WT_dataframe[(WT_dataframe['X_condition'] == 'C')
                          & (WT_dataframe['X_day'] == '0') &
                          (WT_dataframe['X_session'] == '0')]['expt']))
    place.plotPositionHeatmap(WT_1_expt_grp,
                              roi_filter=WT_filter,
                              ax=WT_1_heatmap_ax,
                              norm='individual',
                              cbar_visible=False,
                              cmap=WT_cmap,
                              plotting_order='place_cells_only',
                              show_belt=False,
                              reward_in_middle=True)
    fix_heatmap_ax(WT_1_heatmap_ax, WT_1_expt_grp)
    WT_1_heatmap_ax.set_title(r'Condition $\mathrm{III}$: Day 1')
    WT_1_heatmap_ax.set_ylabel(WT_label)
    WT_1_heatmap_ax.set_xlabel('')

    # WT, condition 'C', X_day '2' (titled Day 3); this call also draws the
    # WT colorbar.
    WT_3_expt_grp = WT_expt_grp.subGroup(
        list(WT_dataframe[(WT_dataframe['X_condition'] == 'C')
                          & (WT_dataframe['X_day'] == '2') &
                          (WT_dataframe['X_session'] == '0')]['expt']))
    place.plotPositionHeatmap(WT_3_expt_grp,
                              roi_filter=WT_filter,
                              ax=WT_3_heatmap_ax,
                              norm='individual',
                              cbar_visible=True,
                              cax=WT_colorbar_ax,
                              cmap=WT_cmap,
                              plotting_order='place_cells_only',
                              show_belt=False,
                              reward_in_middle=True)
    fix_heatmap_ax(WT_3_heatmap_ax, WT_3_expt_grp)
    WT_3_heatmap_ax.set_title(r'Condition $\mathrm{III}$: Day 3')
    WT_3_heatmap_ax.set_ylabel('')
    WT_3_heatmap_ax.set_xlabel('')
    # Only relative Min/Max are labelled (presumably because
    # norm='individual' rescales each map -- confirm).
    WT_colorbar_ax.set_yticklabels(['Min', 'Max'])

    Df_cmap = sns.light_palette(Df_color, as_cmap=True)
    Df_dataframe = lab.ExperimentGroup.dataframe(
        Df_expt_grp, include_columns=['X_condition', 'X_day', 'X_session'])

    # NOTE(review): the Df day-1 subgroup selects X_session '2', while the
    # other three heatmaps use X_session '0'. Confirm this is intentional.
    Df_1_expt_grp = Df_expt_grp.subGroup(
        list(Df_dataframe[(Df_dataframe['X_condition'] == 'C')
                          & (Df_dataframe['X_day'] == '0') &
                          (Df_dataframe['X_session'] == '2')]['expt']))
    place.plotPositionHeatmap(Df_1_expt_grp,
                              roi_filter=Df_filter,
                              ax=Df_1_heatmap_ax,
                              norm='individual',
                              cbar_visible=False,
                              cmap=Df_cmap,
                              plotting_order='place_cells_only',
                              show_belt=False,
                              reward_in_middle=True)
    fix_heatmap_ax(Df_1_heatmap_ax, Df_1_expt_grp)
    Df_1_heatmap_ax.set_ylabel(Df_label)

    # Df, condition 'C', X_day '2' (Day 3); also draws the Df colorbar.
    Df_3_expt_grp = Df_expt_grp.subGroup(
        list(Df_dataframe[(Df_dataframe['X_condition'] == 'C')
                          & (Df_dataframe['X_day'] == '2') &
                          (Df_dataframe['X_session'] == '0')]['expt']))
    place.plotPositionHeatmap(Df_3_expt_grp,
                              roi_filter=Df_filter,
                              ax=Df_3_heatmap_ax,
                              norm='individual',
                              cbar_visible=True,
                              cax=Df_colorbar_ax,
                              cmap=Df_cmap,
                              plotting_order='place_cells_only',
                              show_belt=False,
                              reward_in_middle=True)
    fix_heatmap_ax(Df_3_heatmap_ax, Df_3_expt_grp)
    Df_3_heatmap_ax.set_ylabel('')
    Df_colorbar_ax.set_yticklabels(['Min', 'Max'])

    #
    # Fraction of PCs near reward
    #

    # Place-cell centroids within pi/8 of the reward count as "near"
    # (assumes positions are in radians -- TODO confirm).
    activity_metric = place.centroid_to_position_threshold
    activity_kwargs = {
        'method': 'resultant_vector',
        'positions': 'reward',
        'pcs_only': True,
        'threshold': np.pi / 8
    }
    behavior_fn = ra.fraction_licks_in_reward_zone
    behavior_kwargs = {}
    behavior_label = 'Fraction of licks in reward zone'

    plotting.plot_metric(pf_close_fraction_ax,
                         expt_grps,
                         metric_fn=activity_metric,
                         roi_filters=roi_filters,
                         groupby=[['expt', 'X_condition', 'X_day']],
                         plotby=['X_condition', 'X_day'],
                         plot_abs=False,
                         plot_method='line',
                         activity_kwargs=activity_kwargs,
                         rotate_labels=False,
                         activity_label='Fraction of place cells near reward',
                         label_every_n=1,
                         colors=colors,
                         markers=markers,
                         markersize=5,
                         return_full_dataframes=False,
                         linestyles=linestyles)
    # Dashed reference at 1/8: presumably chance level, since a +/- pi/8
    # window covers 1/8 of a 2*pi belt.
    pf_close_fraction_ax.axhline(1 / 8., linestyle='--', color='k')
    pf_close_fraction_ax.set_title('')
    sns.despine(ax=pf_close_fraction_ax)
    pf_close_fraction_ax.set_xlabel('Day in Condition')
    day_number_only_label(pf_close_fraction_ax)
    label_conditions(pf_close_fraction_ax)
    pf_close_fraction_ax.legend(loc='upper left', fontsize=6)
    pf_close_fraction_ax.set_ylim(0, 0.40)
    pf_close_fraction_ax.set_yticks([0, 0.1, 0.2, 0.3, 0.4])

    scatter_kws = {'s': 5}
    # One (group label, 'C') colour key per experiment group.
    colorby_list = [(expt_grp.label(), 'C') for expt_grp in expt_grps]
    # Limits are set before plotting (presumably so the regplot fit with
    # truncate=False spans this range) and re-applied afterwards.
    pf_close_behav_corr_ax.set_xlim(frac_near_range_2)
    pf_close_behav_corr_ax.set_ylim(behav_range_2)
    plotting.plot_paired_metrics(
        expt_grps,
        first_metric_fn=place.centroid_to_position_threshold,
        second_metric_fn=behavior_fn,
        roi_filters=roi_filters,
        groupby=(('expt', ), ),
        colorby=('expt_grp', 'X_condition'),
        filter_fn=lambda df: df['X_condition'] == 'C',
        filter_columns=['X_condition'],
        first_metric_kwargs=activity_kwargs,
        second_metric_kwargs=behavior_kwargs,
        first_metric_label='Fraction of place cells near reward',
        second_metric_label=behavior_label,
        shuffle_colors=False,
        fit_reg=True,
        plot_method='regplot',
        colorby_list=colorby_list,
        colors=colors,
        markers=markers,
        ax=pf_close_behav_corr_ax,
        scatter_kws=scatter_kws,
        truncate=False,
        linestyles=linestyles)
    pf_close_behav_corr_ax.set_xlim(frac_near_range_2)
    pf_close_behav_corr_ax.set_ylim(behav_range_2)
    pf_close_behav_corr_ax.tick_params(direction='in')
    # Hide the legend the plotting helper drew, then add a compact one.
    pf_close_behav_corr_ax.get_legend().set_visible(False)
    pf_close_behav_corr_ax.legend(loc='upper left', fontsize=6)

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
コード例 #29
0
ファイル: hydro.py プロジェクト: SeanMcKnight/wqio
    def summaryPlot(self, axratio=2, filename=None, showLegend=True,
                    precip=True, inflow=True, outflow=True, figopts=None,
                    serieslabels=None):
        '''
        Creates a figure showing the hydrologic record (flow and
            precipitation) of the storm.

        Input:
            axratio : optional float or int (default = 2)
                Relative height of the flow axis compared to the
                precipitation axis.

            filename : optional string (default = None)
                Filename to which the figure will be saved (300 dpi,
                transparent, tight bounding box).

            showLegend : optional bool (default = True)
                Whether to draw a legend on the precipitation axis.

            precip, inflow, outflow : optional bools (default = True)
                Toggle plotting of each series. Inflow and outflow go on
                the lower (flow) axis; precipitation goes on the upper
                axis, which is then inverted.

            figopts : optional dict (default = None)
                Keyword arguments passed on to `plt.figure`.

            serieslabels : optional dict (default = None)
                Maps a column name to its legend label. Columns without an
                entry are labelled with the column name itself. The
                caller's dict is not modified.

        Writes:
            Figure of flow and precipitation for a storm (only when
            `filename` is given).

        Returns:
            fig : the figure returned by the last `plot_hydroquantity`
                call (or the newly created figure if nothing was plotted)
            artists : list of legend artists
            labels : list of legend labels
        '''
        # Private copies: avoids shared mutable defaults and never mutates
        # the caller's dicts.
        figopts = dict(figopts or {})
        serieslabels = dict(serieslabels or {})

        fig = plt.figure(**figopts)
        gs = gridspec.GridSpec(nrows=2, ncols=1, height_ratios=[1, axratio],
                               hspace=0.12)
        rainax = fig.add_subplot(gs[0])
        rainax.yaxis.set_major_locator(plt.MaxNLocator(5))
        flowax = fig.add_subplot(gs[1], sharex=rainax)

        # Legend artists and labels, accumulated across the series below.
        # Both are always rebound from the return value so they stay in sync.
        artists = []
        labels = []

        if inflow:
            fig, labels, artists = self.plot_hydroquantity(
                self.inflowcol,
                ax=flowax,
                label=serieslabels.get(self.inflowcol, self.inflowcol),
                otherlabels=labels,
                artists=artists,
            )

        if outflow:
            fig, labels, artists = self.plot_hydroquantity(
                self.outflowcol,
                ax=flowax,
                label=serieslabels.get(self.outflowcol, self.outflowcol),
                otherlabels=labels,
                artists=artists,
            )

        if precip:
            fig, labels, artists = self.plot_hydroquantity(
                self.precipcol,
                ax=rainax,
                label=serieslabels.get(self.precipcol, self.precipcol),
                otherlabels=labels,
                artists=artists,
            )
            # Rain is drawn hanging down from the top of its axis.
            rainax.invert_yaxis()

        if showLegend:
            leg = rainax.legend(artists, labels, fontsize=7, ncol=1,
                                markerscale=0.75, frameon=False,
                                loc='lower right')
            leg.get_frame().set_zorder(25)
            # Passed to savefig so the tight bbox includes the legend.
            _leg = [leg]
        else:
            _leg = None

        seaborn.despine(ax=rainax, bottom=True, top=False)
        seaborn.despine(ax=flowax)
        flowax.set_xlabel('')
        rainax.set_xlabel('')

        if filename is not None:
            fig.savefig(filename, dpi=300, transparent=True,
                        bbox_inches='tight', bbox_extra_artists=_leg)

        return fig, artists, labels
コード例 #30
0
def pairedcontrast(data, x, y, idcol, reps = 3000,
statfunction = None, idx = None, figsize = None,
beforeAfterSpacer = 0.01, 
violinWidth = 0.005, 
floatOffset = 0.05, 
showRawData = False,
showAllYAxes = False,
floatContrast = True,
smoothboot = False,
floatViolinOffset = None, 
showConnections = True,
summaryBar = False,
contrastYlim = None,
swarmYlim = None,
barWidth = 0.005,
rawMarkerSize = 8,
rawMarkerType = 'o',
summaryMarkerSize = 10,
summaryMarkerType = 'o',
summaryBarColor = 'grey',
meansSummaryLineStyle = 'solid', 
contrastZeroLineStyle = 'solid', contrastEffectSizeLineStyle = 'solid',
contrastZeroLineColor = 'black', contrastEffectSizeLineColor = 'black',
pal = None,
legendLoc = 2, legendFontSize = 12, legendMarkerScale = 1,
axis_title_size = None,
yticksize = None,
xticksize = None,
tickAngle=45,
tickAlignment='right',
**kwargs):

    # Preliminaries.
    data = data.dropna()

    # plot params
    if axis_title_size is None:
        axis_title_size = 15
    if yticksize is None:
        yticksize = 12
    if xticksize is None:
        xticksize = 12

    axisTitleParams = {'labelsize' : axis_title_size}
    xtickParams = {'labelsize' : xticksize}
    ytickParams = {'labelsize' : yticksize}

    rc('axes', **axisTitleParams)
    rc('xtick', **xtickParams)
    rc('ytick', **ytickParams)

    ## If `idx` is not specified, just take the FIRST TWO levels alphabetically.
    if idx is None:
        idx = tuple(np.unique(data[x])[0:2],)
    else:
        # check if multi-plot or not
        if all(isinstance(element, str) for element in idx):
            # if idx is supplied but not a multiplot (ie single list or tuple)
            if len(idx) != 2:
                print(idx, "does not have length 2.")
                sys.exit(0)
            else:
                idx = (tuple(idx, ),)
        elif all(isinstance(element, tuple) for element in idx):
            # if idx is supplied, and it is a list/tuple of tuples or lists, we have a multiplot!
            if ( any(len(element) != 2 for element in idx) ):
                # If any of the tuples contain more than 2 elements.
                print(element, "does not have length 2.")
                sys.exit(0)
    if floatViolinOffset is None:
        floatViolinOffset = beforeAfterSpacer/2
    if contrastYlim is not None:
        contrastYlim = np.array([contrastYlim[0],contrastYlim[1]])
    if swarmYlim is not None:
        swarmYlim = np.array([swarmYlim[0],swarmYlim[1]])

    ## Here we define the palette on all the levels of the 'x' column.
    ## Thus, if the same pandas dataframe is re-used across different plots,
    ## the color identity of each group will be maintained.
    ## Set palette based on total number of categories in data['x'] or data['hue_column']
    if 'hue' in kwargs:
        u = kwargs['hue']
    else:
        u = x
    if ('color' not in kwargs and 'hue' not in kwargs):
        kwargs['color'] = 'k'

    if pal is None:
        pal = dict( zip( data[u].unique(), sns.color_palette(n_colors = len(data[u].unique())) ) 
                      )
    else:
        pal = pal

    # Initialise figure.
    if figsize is None:
        if len(idx) > 2:
            figsize = (12,(12/np.sqrt(2)))
        else:
            figsize = (6,6)
    fig = plt.figure(figsize = figsize)

    # Initialise GridSpec based on `levs_tuple` shape.
    gsMain = gridspec.GridSpec( 1, np.shape(idx)[0]) # 1 row; columns based on number of tuples in tuple.
    # Set default statfunction
    if statfunction is None:
        statfunction = np.mean
    # Create list to collect all the contrast DataFrames generated.
    contrastList = list()
    contrastListNames = list()

    for gsIdx, xlevs in enumerate(idx):
        ## Pivot tempdat to get before and after lines.
        data_pivot = data.pivot_table(index = idcol, columns = x, values = y)

        # Start plotting!!
        if floatContrast is True:
            ax_raw = fig.add_subplot(gsMain[gsIdx], frame_on = False)
            ax_contrast = ax_raw.twinx()
        else:
            gsSubGridSpec = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec = gsMain[gsIdx])
            ax_raw = plt.Subplot(fig, gsSubGridSpec[0, 0], frame_on = False)
            ax_contrast = plt.Subplot(fig, gsSubGridSpec[1, 0], sharex = ax_raw, frame_on = False)

        ## Plot raw data as swarmplot or stripplot.
        if showRawData is True:
            swarm_raw = sns.swarmplot(data = data, 
                                     x = x, y = y, 
                                     order = xlevs,
                                     ax = ax_raw,
                                     palette = pal,
                                     size = rawMarkerSize,
                                     marker = rawMarkerType,
                                     **kwargs)
        else:
            swarm_raw = sns.stripplot(data = data, 
                                     x = x, y = y, 
                                     order = xlevs,
                                     ax = ax_raw,
                                     palette = pal,
                                     **kwargs)
        swarm_raw.set_ylim(swarmYlim)
           
        ## Get some details about the raw data.
        maxXBefore = max(swarm_raw.collections[0].get_offsets().T[0])
        minXAfter = min(swarm_raw.collections[1].get_offsets().T[0])
        if showRawData is True:
            #beforeAfterSpacer = (getSwarmSpan(swarm_raw, 0) + getSwarmSpan(swarm_raw, 1))/2
            beforeAfterSpacer = 1
        xposAfter = maxXBefore + beforeAfterSpacer
        xAfterShift = minXAfter - xposAfter

        ## shift the after swarmpoints closer for aesthetic purposes.
        offsetSwarmX(swarm_raw.collections[1], -xAfterShift)

        ## pandas DataFrame of 'before' group
        x1 = pd.DataFrame({str(xlevs[0] + '_x') : pd.Series(swarm_raw.collections[0].get_offsets().T[0]),
                       xlevs[0] : pd.Series(swarm_raw.collections[0].get_offsets().T[1]),
                       '_R_' : pd.Series(swarm_raw.collections[0].get_facecolors().T[0]),
                       '_G_' : pd.Series(swarm_raw.collections[0].get_facecolors().T[1]),
                       '_B_' : pd.Series(swarm_raw.collections[0].get_facecolors().T[2]),
                      })
        ## join the RGB columns into a tuple, then assign to a column.
        x1['_hue_'] = x1[['_R_', '_G_', '_B_']].apply(tuple, axis=1) 
        x1 = x1.sort_values(by = xlevs[0])
        x1.index = data_pivot.sort_values(by = xlevs[0]).index

        ## pandas DataFrame of 'after' group
        ### create convenient signifiers for column names.
        befX = str(xlevs[0] + '_x')
        aftX = str(xlevs[1] + '_x')

        x2 = pd.DataFrame( {aftX : pd.Series(swarm_raw.collections[1].get_offsets().T[0]),
            xlevs[1] : pd.Series(swarm_raw.collections[1].get_offsets().T[1])} )
        x2 = x2.sort_values(by = xlevs[1])
        x2.index = data_pivot.sort_values(by = xlevs[1]).index

        ## Join x1 and x2, on both their indexes.
        plotPoints = x1.merge(x2, left_index = True, right_index = True, how='outer')

        ## Add the hue column if hue argument was passed.
        if 'hue' in kwargs:
            h = kwargs['hue']
            plotPoints[h] = data.pivot(index = idcol, columns = x, values = h)[xlevs[0]]
            swarm_raw.legend(loc = legendLoc, 
                fontsize = legendFontSize, 
                markerscale = legendMarkerScale)

        ## Plot the lines to join the 'before' points to their respective 'after' points.
        if showConnections is True:
            for i in plotPoints.index:
                ax_raw.plot([ plotPoints.ix[i, befX],
                    plotPoints.ix[i, aftX] ],
                    [ plotPoints.ix[i, xlevs[0]], 
                    plotPoints.ix[i, xlevs[1]] ],
                    linestyle = 'solid',
                    color = plotPoints.ix[i, '_hue_'],
                    linewidth = 0.75,
                    alpha = 0.75
                    )

        ## Hide the raw swarmplot data if so desired.
        if showRawData is False:
            swarm_raw.collections[0].set_visible(False)
            swarm_raw.collections[1].set_visible(False)

        if showRawData is True:
            #maxSwarmSpan = max(np.array([getSwarmSpan(swarm_raw, 0), getSwarmSpan(swarm_raw, 1)]))/2
            maxSwarmSpan = 0.5
        else:
            maxSwarmSpan = barWidth            

        ## Plot Summary Bar.
        if summaryBar is True:
            # Calculate means
            means = data.groupby([x], sort = True).mean()[y]
            # # Calculate medians
            # medians = data.groupby([x], sort = True).median()[y]

            ## Draw summary bar.
            bar_raw = sns.barplot(x = means.index, 
                        y = means.values, 
                        order = xlevs,
                        ax = ax_raw,
                        ci = 0,
                        facecolor = summaryBarColor, 
                        alpha = 0.25)
            ## Draw zero reference line.
            ax_raw.add_artist(Line2D(
                (ax_raw.xaxis.get_view_interval()[0], 
                    ax_raw.xaxis.get_view_interval()[1]), 
                (0,0),
                color='black', linewidth=0.75
                )
            )       

            ## get swarm with largest span, set as max width of each barplot.
            for i, bar in enumerate(bar_raw.patches):
                x_width = bar.get_x()
                width = bar.get_width()
                centre = x_width + width/2.
                if i == 0:
                    bar.set_x(centre - maxSwarmSpan/2.)
                else:
                    bar.set_x(centre - xAfterShift - maxSwarmSpan/2.)
                bar.set_width(maxSwarmSpan)

        # Get y-limits of the treatment swarm points.
        beforeRaw = pd.DataFrame( swarm_raw.collections[0].get_offsets() )
        afterRaw = pd.DataFrame( swarm_raw.collections[1].get_offsets() )
        before_leftx = min(beforeRaw[0])
        after_leftx = min(afterRaw[0])
        after_rightx = max(afterRaw[0])
        after_stat_summary = statfunction(beforeRaw[1])

        # Calculate the summary difference and CI.
        plotPoints['delta_y'] = plotPoints[xlevs[1]] - plotPoints[xlevs[0]]
        plotPoints['delta_x'] = [0] * np.shape(plotPoints)[0]

        tempseries = plotPoints['delta_y'].tolist()
        test = tempseries.count(tempseries[0]) != len(tempseries)

        bootsDelta = bootstrap(plotPoints['delta_y'],
            statfunction = statfunction, 
            smoothboot = smoothboot,
            reps = reps)
        summDelta = bootsDelta['summary']
        lowDelta = bootsDelta['bca_ci_low']
        highDelta = bootsDelta['bca_ci_high']

        # set new xpos for delta violin.
        if floatContrast is True:
            if showRawData is False:
                xposPlusViolin = deltaSwarmX = after_rightx + floatViolinOffset
            else:
                xposPlusViolin = deltaSwarmX = after_rightx + maxSwarmSpan
        else:
            xposPlusViolin = xposAfter
        if showRawData is True:
            # If showRawData is True and floatContrast is True, 
            # set violinwidth to the barwidth.
            violinWidth = maxSwarmSpan

        xmaxPlot = xposPlusViolin + violinWidth

        # Plot the summary measure.
        ax_contrast.plot(xposPlusViolin, summDelta,
            marker = 'o',
            markerfacecolor = 'k', 
            markersize = summaryMarkerSize,
            alpha = 0.75
            )

        # Plot the CI.
        ax_contrast.plot([xposPlusViolin, xposPlusViolin],
            [lowDelta, highDelta],
            color = 'k', 
            alpha = 0.75,
            linestyle = 'solid'
            )

        # Plot the violin-plot.
        v = ax_contrast.violinplot(bootsDelta['stat_array'], [xposPlusViolin], 
                                 widths = violinWidth, 
                                 showextrema = False, 
                                 showmeans = False)
        halfviolin(v, half = 'right', color = 'k')

        # Remove left axes x-axis title.
        ax_raw.set_xlabel("")
        # Remove floating axes y-axis title.
        ax_contrast.set_ylabel("")

        # Set proper x-limits
        ax_raw.set_xlim(before_leftx - beforeAfterSpacer/2, xmaxPlot)
        ax_raw.get_xaxis().set_view_interval(before_leftx - beforeAfterSpacer/2, 
            after_rightx + beforeAfterSpacer/2)
        ax_contrast.set_xlim(ax_raw.get_xlim())

        if floatContrast is True:
            # Set the ticks locations for ax_raw.
            ax_raw.get_xaxis().set_ticks((0, xposAfter))

            # Make sure they have the same y-limits.
            ax_contrast.set_ylim(ax_raw.get_ylim())
            
            # Drawing in the x-axis for ax_raw.
            ## Set the tick labels!
            ax_raw.set_xticklabels(xlevs, rotation = tickAngle, horizontalalignment = tickAlignment)
            ## Get lowest y-value for ax_raw.
            y = ax_raw.get_yaxis().get_view_interval()[0] 

            # Align the left axes and the floating axes.
            align_yaxis(ax_raw, statfunction(plotPoints[xlevs[0]]),
                           ax_contrast, 0)

            # Add label to floating axes. But on ax_raw!
            ax_raw.text(x = deltaSwarmX,
                          y = ax_raw.get_yaxis().get_view_interval()[0],
                          horizontalalignment = 'left',
                          s = 'Difference',
                          fontsize = 15)        

            # Set reference lines
            ## zero line
            ax_contrast.hlines(0,                                           # y-coordinate
                            ax_contrast.xaxis.get_majorticklocs()[0],       # x-coordinates, start and end.
                            ax_raw.xaxis.get_view_interval()[1],   
                            linestyle = 'solid',
                            linewidth = 0.75,
                            color = 'black')

            ## effect size line
            ax_contrast.hlines(summDelta, 
                            ax_contrast.xaxis.get_majorticklocs()[1],
                            ax_raw.xaxis.get_view_interval()[1],
                            linestyle = 'solid',
                            linewidth = 0.75,
                            color = 'black')

            # Align the left axes and the floating axes.
            align_yaxis(ax_raw, after_stat_summary, ax_contrast, 0.)
        else:
            # Set the ticks locations for ax_raw.
            ax_raw.get_xaxis().set_ticks((0, xposAfter))
            
            fig.add_subplot(ax_raw)
            fig.add_subplot(ax_contrast)
        ax_contrast.set_ylim(contrastYlim)
        # Calculate p-values.
        # 1-sample t-test to see if the mean of the difference is different from 0.
        ttestresult = ttest_1samp(plotPoints['delta_y'], popmean = 0)[1]
        bootsDelta['ttest_pval'] = ttestresult
        contrastList.append(bootsDelta)
        contrastListNames.append( str(xlevs[1])+' v.s. '+str(xlevs[0]) )

    # Turn contrastList into a pandas DataFrame,
    contrastList = pd.DataFrame(contrastList).T
    contrastList.columns = contrastListNames

    # Now we iterate thru the contrast axes to normalize all the ylims.
    for j,i in enumerate(range(1, len(fig.get_axes()), 2)):
        axx=fig.get_axes()[i]
        ## Get max and min of the dataset.
        lower = np.min(contrastList.ix['stat_array',j])
        upper = np.max(contrastList.ix['stat_array',j])
        meandiff = contrastList.ix['summary', j]

        ## Make sure we have zero in the limits.
        if lower > 0:
            lower = 0.
        if upper < 0:
            upper = 0.

        ## Get tick distance on raw axes.
        ## This will be the tick distance for the contrast axes.
        rawAxesTicks = fig.get_axes()[i-1].yaxis.get_majorticklocs()
        rawAxesTickDist = rawAxesTicks[1] - rawAxesTicks[0]

        ## First re-draw of axis with new tick interval
        axx.yaxis.set_major_locator(MultipleLocator(rawAxesTickDist))
        newticks1 = fig.get_axes()[i].get_yticks()

        if floatContrast is False:
            if (showAllYAxes is False and i in range( 2, len(fig.get_axes())) ):
                axx.get_yaxis().set_visible(showAllYAxes)
            else:
                ## Obtain major ticks that comfortably encompass lower and upper.
                newticks2 = list()
                for a,b in enumerate(newticks1):
                    if (b >= lower and b <= upper):
                        # if the tick lies within upper and lower, take it.
                        newticks2.append(b)
                # if the meandiff falls outside of the newticks2 set, add a tick in the right direction.
                if np.max(newticks2) < meandiff:
                    ind = np.where(newticks1 == np.max(newticks2))[0][0] # find out the max tick index in newticks1.
                    newticks2.append( newticks1[ind+1] )
                elif meandiff < np.min(newticks2):
                    ind = np.where(newticks1 == np.min(newticks2))[0][0] # find out the min tick index in newticks1.
                    newticks2.append( newticks1[ind-1] )
                newticks2 = np.array(newticks2)
                newticks2.sort()
                axx.yaxis.set_major_locator(FixedLocator(locs = newticks2))

                ## Draw zero reference line.
                axx.hlines(y = 0,
                    xmin = fig.get_axes()[i].get_xaxis().get_view_interval()[0], 
                    xmax = fig.get_axes()[i].get_xaxis().get_view_interval()[1],
                    linestyle = contrastZeroLineStyle,
                    linewidth = 0.75,
                    color = contrastZeroLineColor)

                sns.despine(ax = fig.get_axes()[i], trim = True, 
                    bottom = False, right = True,
                    left = False, top = True)

                ## Draw back the lines for the relevant y-axes.
                drawback_y(axx)

                ## Draw back the lines for the relevant x-axes.
                drawback_x(axx)

        elif floatContrast is True:
            ## Get the original ticks on the floating y-axis.
            newticks1 = fig.get_axes()[i].get_yticks()

            ## Obtain major ticks that comfortably encompass lower and upper.
            newticks2 = list()
            for a,b in enumerate(newticks1):
                if (b >= lower and b <= upper):
                    # if the tick lies within upper and lower, take it.
                    newticks2.append(b)
            # if the meandiff falls outside of the newticks2 set, add a tick in the right direction.
            if np.max(newticks2) < meandiff:
                ind = np.where(newticks1 == np.max(newticks2))[0][0] # find out the max tick index in newticks1.
                newticks2.append( newticks1[ind+1] )
            elif meandiff < np.min(newticks2):
                ind = np.where(newticks1 == np.min(newticks2))[0][0] # find out the min tick index in newticks1.
                newticks2.append( newticks1[ind-1] )
            newticks2 = np.array(newticks2)
            newticks2.sort()

            ## Re-draw the axis.
            axx.yaxis.set_major_locator(FixedLocator(locs = newticks2)) 

            ## Despine and trim the axes.
            sns.despine(ax = axx, trim = True, 
                bottom = False, right = False,
                left = True, top = True)

    for i in range(0, len(fig.get_axes()), 2):
        # Loop through the raw data swarmplots and despine them appropriately.
        if floatContrast is True:
            sns.despine(ax = fig.get_axes()[i], trim = True, right = True)

        else:
            sns.despine(ax = fig.get_axes()[i], trim = True, bottom = True, right = True)
            fig.get_axes()[i].get_xaxis().set_visible(False)

        # Draw back the lines for the relevant y-axes.
        ymin = fig.get_axes()[i].get_yaxis().get_majorticklocs()[0]
        ymax = fig.get_axes()[i].get_yaxis().get_majorticklocs()[-1]
        x, _ = fig.get_axes()[i].get_xaxis().get_view_interval()
        fig.get_axes()[i].add_artist(Line2D((x, x), (ymin, ymax), color='black', linewidth=1.5))    

    # Zero gaps between plots on the same row, if floatContrast is False
    if (floatContrast is False and showAllYAxes is False):
        gsMain.update(wspace = 0)
    else:    
        # Tight Layout!
        gsMain.tight_layout(fig)

    # And we're done.
    rcdefaults() # restore matplotlib defaults.
    sns.set() # restore seaborn defaults.
    return fig, contrastList
コード例 #31
0
ファイル: hcp_plot.py プロジェクト: BigR-Lab/modl
def display_explained_variance_density(output_dir):
    """Plot convergence speed and final decomposition quality of benchmark runs.

    Every sub-directory of ``output_dir`` that contains both an
    ``analysis.json`` and a ``results.json`` file is treated as one run.
    Two figures are written into ``output_dir``:

    * ``hcp_bench.pdf``: objective value (``analysis['objectives'] / 4``)
      against CPU time in hours, one panel per ``alpha`` in
      ``{1e-3, 1e-4}`` and one curve per ``reduction`` factor.
    * ``bar_plot.pdf``: the final objective relative to the run with
      ``reduction == 1`` and the same ``alpha``, and the final
      ``analysis['densities']`` value, grouped by ``alpha`` and reduction.

    Runs with ``alpha == 1e-2`` or ``reduction == 2`` are left out of both
    figures.

    Parameters
    ----------
    output_dir : str
        Directory holding one sub-directory per run; also the destination
        of the generated PDF files.
    """
    dir_list = [join(output_dir, f) for f in os.listdir(output_dir) if
                os.path.isdir(join(output_dir, f))]

    fig = plt.figure(figsize=(fig_width * 0.73, fig_height))
    gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1])
    fig.subplots_adjust(bottom=0.29)
    fig.subplots_adjust(left=0.075)
    fig.subplots_adjust(right=.92)

    results = []
    analyses = []
    # Time unit used below to turn record counts into CPU time: the
    # smallest value found among the runs with reduction == 12.
    ref_time = 1000000
    for dir_name in dir_list:
        try:
            with open(join(dir_name, 'analysis.json'), 'r') as f:
                analysis = json.load(f)
            with open(join(dir_name, 'results.json'), 'r') as f:
                result = json.load(f)
        except IOError:
            # Incomplete run: skip it as a whole so that `results` and
            # `analyses` stay index-aligned for the zip() calls below.
            continue
        analyses.append(analysis)
        results.append(result)
        if result['reduction'] == 12:
            timings = np.array(result['timings'])
            if timings.size > 1:
                # Elapsed time of each record since the first one (same
                # values as the former ``timings[1:] - timings[:1]``).
                # NOTE(review): if the shortest interval between
                # consecutive records was meant, use np.diff(timings).
                diff = timings[1:] - timings[0]
                ref_time = min(ref_time, np.min(diff))
    print(ref_time)

    grey = (.6, .6, .6)
    h_reductions = []
    ax = {}
    ylim = {1e-2: [2.455e8, 2.525e8], 1e-3: [2.3e8, 2.47e8],
            1e-4: [2.16e8, 2.42e8]}
    for i, alpha in enumerate([1e-3, 1e-4]):
        ax[alpha] = fig.add_subplot(gs[:, i])
        if i == 0:
            ax[alpha].set_ylabel('Objective value on test set')
        ax[alpha].annotate('$\\lambda  = 10^{%.0f}$' % log(alpha, 10),
                           xy=(.65, .85),
                           fontsize=8,
                           xycoords='axes fraction')
        ax[alpha].set_xlim([.05, 200])
        ax[alpha].set_ylim(ylim[alpha])

        # `Tick.label` is a deprecated alias of `label1` that newer
        # matplotlib releases removed; `label1` works everywhere.
        for tick in ax[alpha].xaxis.get_major_ticks():
            tick.label1.set_fontsize(7)
        ax[alpha].set_xscale('log')

        ax[alpha].set_xticks([.1, 1, 10, 100])
        ax[alpha].set_xticklabels(['.1 h', '1 h', '10 h', '100 h'])

        sns.despine(fig=fig, ax=ax[alpha])

        ax[alpha].spines['left'].set_color(grey)
        ax[alpha].spines['bottom'].set_color(grey)
        ax[alpha].xaxis.set_tick_params(color=grey, which='both')
        ax[alpha].yaxis.set_tick_params(color=grey, which='both')
        for tick in ax[alpha].xaxis.get_major_ticks():
            tick.label1.set_color('black')
        for tick in ax[alpha].yaxis.get_major_ticks():
            tick.label1.set_fontsize(6)
            tick.label1.set_color('black')
        t = ax[alpha].yaxis.get_offset_text()
        t.set_size(5)
    ax[1e-4].set_xlabel('CPU\ntime', ha='right')
    ax[1e-4].xaxis.set_label_coords(1.15, -0.05)

    colormap = sns.cubehelix_palette(4, start=0, rot=0., hue=1, dark=.3,
                                     light=.7,
                                     reverse=False)
    other_colormap = sns.cubehelix_palette(4, start=0, rot=.5, hue=1, dark=.3,
                                           light=.7,
                                           reverse=False)
    # The non-reduced run (reduction == 1) gets a distinct hue.
    colormap[0] = other_colormap[0]
    colormap_dict = {reduction: color for reduction, color in
                     zip([1, 4, 8, 12], colormap)}

    for result, analysis in zip(results, analyses):
        if result['alpha'] != 1e-2 and result['reduction'] != 2:
            reduction = result['reduction']
            print("%s %s" % (result['alpha'], reduction))
            # Convert record indices to hours, scaling by the reduction
            # factor relative to the reference reduction of 12.
            timings = (np.array(analysis['records']) + 1) / int(
                reduction) * 12 * ref_time / 3600
            s, = ax[result['alpha']].plot(
                timings,
                np.array(analysis['objectives']) / 4,
                color=colormap_dict[int(reduction)],
                linewidth=2,
                linestyle='--' if reduction == 1 else '-',
                # Keep the non-reduced (reference) curve on top.
                zorder=reduction if reduction != 1 else 100)
            if result['alpha'] == 1e-3:
                h_reductions.append((s, '%.0f' % reduction))

    handles, labels = list(zip(*h_reductions[::-1]))
    argsort = sorted(range(len(labels)), key=lambda t: int(labels[t]))
    handles = [handles[i] for i in argsort]
    labels = [labels[i] for i in argsort]

    offset = .3
    yoffset = -.05
    legend_vanilla = mlegend.Legend(ax[1e-3], handles[:1], ['No reduction'],
                                    loc='lower left',
                                    ncol=5,
                                    numpoints=1,
                                    handlelength=2,
                                    markerscale=1.4,
                                    bbox_to_anchor=(
                                        0.3 + offset, -.39 + yoffset),
                                    fontsize=8,
                                    frameon=False
                                    )

    legend_ratio = mlegend.Legend(ax[1e-3], handles[1:], labels[1:],
                                  loc='lower left',
                                  ncol=5,
                                  markerscale=1.4,
                                  handlelength=2,
                                  fontsize=8,
                                  bbox_to_anchor=(
                                      0.3 + offset, -.54 + yoffset),
                                  frameon=False
                                  )
    ax[1e-3].annotate('Original online algorithm',
                      xy=(0.28 + offset, -.27 + yoffset),
                      xycoords='axes fraction',
                      horizontalalignment='right', verticalalignment='bottom',
                      fontsize=8)
    ax[1e-3].annotate('Proposed reduction factor $r$',
                      xy=(0.28 + offset, -.42 + yoffset),
                      xycoords='axes fraction',
                      horizontalalignment='right', verticalalignment='bottom',
                      fontsize=8)
    ax[1e-3].add_artist(legend_ratio)
    ax[1e-3].add_artist(legend_vanilla)

    ax[1e-3].annotate('(a) Convergence speed', xy=(0.7, 1.02), ha='center',
                      fontsize=9, va='bottom', xycoords='axes fraction')

    fig.savefig(join(output_dir, 'hcp_bench.pdf'))

    # Bar-plot data: one entry per kept run (collected once, shared by the
    # objective and density panels).
    x_bar = []
    y_bar_objective = []
    y_bar_density = []
    hue_bar = []
    for result, analysis in zip(results, analyses):
        if result['alpha'] != 1e-2 and result['reduction'] != 2:
            x_bar.append(result['alpha'])
            y_bar_objective.append(analysis['objectives'][-1])
            y_bar_density.append(analysis['densities'][-1])
            hue_bar.append(result['reduction'])

    # Express each final objective relative to the non-reduced run that
    # shares its alpha.
    ref_objective = {alpha: objective for objective, alpha, reduction
                     in zip(y_bar_objective, x_bar, hue_bar)
                     if reduction == 1}
    y_bar_objective = [objective / ref_objective[alpha] - 1
                       for objective, alpha in zip(y_bar_objective, x_bar)]

    ####################### Final objective
    fig = plt.figure(figsize=(fig_width * 0.27, fig_height))
    fig.subplots_adjust(bottom=0.29)
    fig.subplots_adjust(left=0.05)
    fig.subplots_adjust(right=1.2)
    fig.subplots_adjust(top=0.85)
    # A single-column grid takes exactly one width ratio.
    gs = gridspec.GridSpec(2, 1, width_ratios=[1], height_ratios=[1.2, 0.8])
    ax_bar_objective = fig.add_subplot(gs[0])
    ax_bar_objective.set_ylim(-0.007, 0.007)
    ax_bar_objective.set_yticks([-0.005, 0, 0.005])
    # Raw strings: same text as before, without invalid '\%' escapes.
    ax_bar_objective.set_yticklabels([r'-0.5\%', r'0\%', r'0.5\%'])
    ax_bar_objective.tick_params(axis='y', labelsize=6)

    sns.despine(fig=fig, ax=ax_bar_objective, left=True, right=False)

    sns.barplot(x=x_bar, y=y_bar_objective, hue=hue_bar, ax=ax_bar_objective,
                order=[1e-3, 1e-4],
                palette=colormap)
    plt.setp(ax_bar_objective.patches, linewidth=0.1)
    ax_bar_objective.legend_ = None
    ax_bar_objective.get_xaxis().set_visible(False)
    ax_bar_objective.set_xlim([-.5, 1.6])
    ax_bar_objective.annotate('Final\nobjective\ndeviation\n(relative)',
                              xy=(1.28, 0.45), fontsize=7, va='center',
                              xycoords='axes fraction')
    ax_bar_objective.annotate('(Less is better)', xy=(.06, 0.1), fontsize=7,
                              va='center', xycoords='axes fraction')
    ax_bar_objective.yaxis.set_label_position('right')

    ################################## Density
    ax_bar_density = fig.add_subplot(gs[1])
    ax_bar_density.set_yscale('log')
    ax_bar_density.set_ylim(100, 1000)
    ax_bar_density.set_yticks([100, 1000])
    ax_bar_density.set_yticklabels(['100', '1000'])
    ax_bar_density.tick_params(axis='y', labelsize=6)

    sns.barplot(x=x_bar, y=y_bar_density, hue=hue_bar, ax=ax_bar_density,
                order=[1e-3, 1e-4],
                palette=colormap)
    # One label per category in `order` (1e-2 runs are filtered out above).
    ax_bar_density.set_xticklabels(['$10^{-3}$', '$10^{-4}$'])
    sns.despine(fig=fig, ax=ax_bar_density, left=True, right=False)
    ax_bar_density.set_xlim([-.5, 1.6])
    ax_bar_density.set_xlabel('Regularization $\\lambda$')
    ax_bar_density.annotate('$\\frac{\\ell_1}{\\ell_2}(\\mathbf D)$',
                            xy=(1.26, 0.45),
                            fontsize=7, va='center', xycoords='axes fraction')
    ax_bar_density.yaxis.set_label_position('right')

    plt.setp(ax_bar_density.patches, linewidth=0.1)
    ax_bar_density.legend_ = None

    for bar_ax in [ax_bar_density, ax_bar_objective]:
        bar_ax.spines['right'].set_color(grey)
        bar_ax.spines['bottom'].set_color(grey)
        bar_ax.xaxis.set_tick_params(color=grey, which='both')
        bar_ax.yaxis.set_tick_params(color=grey, which='both')

    # Hide the x tick marks of the density panel. Assigning `tick1On` /
    # `tick2On` is a silent no-op on current matplotlib.
    ax_bar_density.tick_params(axis='x', which='major', bottom=False,
                               top=False)
    ax_bar_objective.spines['bottom'].set_position(('data', 0))
    ax_bar_objective.spines['bottom'].set_linewidth(.3)
    ax_bar_objective.annotate('(b) Decomposition quality', xy=(0.7, 1.21),
                              ha='center', va='bottom', fontsize=9,
                              xycoords='axes fraction')

    fig.savefig(expanduser(join(output_dir, 'bar_plot.pdf')))
コード例 #32
0
ファイル: hcp_plot.py プロジェクト: BigR-Lab/modl
def display_explained_variance_epoch(output_dir):
    """Plot the test objective against the number of records seen.

    Every sub-directory of ``output_dir`` holding both ``analysis.json`` and
    ``results.json`` is treated as one run. Runs with ``alpha == 1e-4`` and
    ``reduction`` in ``{1, 4, 8, 12}`` are drawn as one curve each
    (``analysis['objectives'] / 4`` against ``analysis['records']``). The
    figure is saved as ``hcp_epoch.pdf`` in ``output_dir``.

    Parameters
    ----------
    output_dir : str
        Directory holding one sub-directory per run; also the destination
        of the generated PDF file.
    """
    dir_list = [join(output_dir, f) for f in os.listdir(output_dir) if
                os.path.isdir(join(output_dir, f))]

    fig = plt.figure()
    gs = gridspec.GridSpec(1, 1, width_ratios=[1])
    fig.set_figwidth(3.25653379549)
    fig.set_figheight(1.3)
    fig.subplots_adjust(bottom=0.105)
    fig.subplots_adjust(top=0.9)
    fig.subplots_adjust(left=0.12)
    fig.subplots_adjust(right=.95)

    results = []
    analyses = []
    for dir_name in dir_list:
        try:
            with open(join(dir_name, 'analysis.json'), 'r') as f:
                analysis = json.load(f)
            with open(join(dir_name, 'results.json'), 'r') as f:
                result = json.load(f)
        except IOError:
            # Incomplete run: skip it as a whole so that `results` and
            # `analyses` stay index-aligned for the zip() below.
            continue
        analyses.append(analysis)
        results.append(result)

    # Only the alpha = 1e-4 panel is drawn.
    alpha = 1e-4
    grey = (.6, .6, .6)
    h_reductions = []
    ax = fig.add_subplot(gs[:, 0])
    ax.set_ylabel('Objective value on test set')
    ax.set_xlim([50, 4000])

    # `Tick.label` is a deprecated alias of `label1` that newer matplotlib
    # releases removed; `label1` works everywhere.
    for tick in ax.xaxis.get_major_ticks():
        tick.label1.set_fontsize(7)
    ax.set_xscale('log')

    # The tick at 1947 records is labelled 'Epoch' (presumably one full
    # pass over the data -- TODO confirm).
    ax.set_xticks([100, 1000, 1947, 4000])
    ax.set_xticklabels(['100', '1000', 'Epoch', '4000'])

    ax.set_ylim([2.16e8, 2.24e8])
    sns.despine(fig=fig, ax=ax)

    ax.spines['left'].set_color(grey)
    ax.spines['bottom'].set_color(grey)
    ax.xaxis.set_tick_params(color=grey, which='both')
    ax.yaxis.set_tick_params(color=grey, which='both')
    for tick in ax.xaxis.get_major_ticks():
        tick.label1.set_color('black')
    for tick in ax.yaxis.get_major_ticks():
        tick.label1.set_fontsize(7)
        tick.label1.set_color('black')
    t = ax.yaxis.get_offset_text()
    t.set_size(6)
    ax.set_xlabel('Records')
    ax.xaxis.set_label_coords(-0.04, -0.047)

    colormap = sns.cubehelix_palette(4, start=0, rot=0., hue=1, dark=.3,
                                     light=.7,
                                     reverse=False)

    other_colormap = sns.cubehelix_palette(4, start=0, rot=.5, hue=1, dark=.3,
                                           light=.7,
                                           reverse=False)
    # The non-reduced run (reduction == 1) gets a distinct hue.
    colormap[0] = other_colormap[0]
    colormap_dict = {reduction: color for reduction, color in
                     zip([1, 4, 8, 12], colormap)}
    for result, analysis in zip(results, analyses):
        if result['alpha'] in [1e-4] and result['reduction'] in [1, 4, 8, 12]:
            reduction = result['reduction']
            print("%s %s" % (result['alpha'], reduction))
            s, = ax.plot(np.array(analysis['records']),
                         np.array(analysis['objectives']) / 4,
                         color=colormap_dict[reduction],
                         linewidth=1.5,
                         linestyle='--' if reduction == 1 else '-',
                         # Keep the non-reduced (reference) curve on top.
                         zorder=reduction if reduction > 1 else 100)
            h_reductions.append((s, reduction))

    handles, labels = list(zip(*h_reductions[::-1]))
    argsort = sorted(range(len(labels)), key=lambda t: int(labels[t]))
    handles = [handles[i] for i in argsort]
    labels = [('$r=%i$' % labels[i]) for i in argsort]
    # After sorting, the first entry is the smallest reduction factor.
    labels[0] = 'No reduction\n(original alg.)'

    ax.annotate('$\\lambda  = 10^{%.0f}$' % log(alpha, 10),
                xy=(0.07, 0.07),
                ha='left',
                va='bottom',
                fontsize=8,
                xycoords='axes fraction')
    legend_ratio = mlegend.Legend(ax, handles, labels,
                                  loc='upper right',
                                  ncol=1,
                                  numpoints=1,
                                  handlelength=2,
                                  frameon=False,
                                  bbox_to_anchor=(1, 1.15)
                                  )
    ax.add_artist(legend_ratio)

    fig.savefig(join(output_dir, 'hcp_epoch.pdf'))
コード例 #33
0
color_dict[1] = ref_colormap[0]

# One curve per benchmark row. SGD rows share a single colour; the other
# methods are coloured and labelled by their reduction factor.
for method, sub_data in data.groupby('method'):
    for _, this_data in sub_data.iterrows():
        if method == 'sgd':
            label = 'SGD (best step-size)'
            color = sgd_colormap[0]
        else:
            reduction = this_data['reduction']
            color = color_dict[reduction]
            if reduction == 1:
                label = 'Online matrix factorization'
            else:
                label = 'SOMF ($r = %i$)' % reduction
        # Offset the times by 4 before plotting on the log-scaled x axis
        # (presumably so that early/zero timestamps stay visible). The
        # offset array used to be computed and then ignored; it is now
        # what gets plotted, under a name that does not shadow `time`.
        offset_time = np.asarray(this_data['time']) + 4
        ax.plot(offset_time, this_data['score'],
                label=label,
                color=color,
                linestyle='-')
ax.set_xscale('log')
ax.set_ylabel('Test objective value')
ax.yaxis.set_label_coords(-0.13, 0.38)
ax.set_xlabel('Time')
ax.ticklabel_format(axis='y', style='sci', scilimits=(-2, 2))
ax.legend(frameon=False, loc='upper right', bbox_to_anchor=(1., 1.))
sns.despine(fig, ax)
plt.savefig(join(analysis_dir, 'bench.pdf'))
plt.show()
def main():
    """Build the GOL/RF place-cell stability figure and save it.

    Panels drawn on one 8.5x11 inch figure:
      * stability across the A -> B condition transition (GOL groups),
      * pie charts of cue / position / neither cells for WT, Df and shuffle,
      * per-mouse cue-to-position ratio bar/swarm plot,
      * CDF of stability in the RF groups, with a per-mouse swarm inset,
      * RF-vs-GOL stability comparison (box-and-line).

    Relies on module-level names defined elsewhere in the file (``df``,
    ``place``, ``plotting``, ``misc``, ``filename``, ``save_dir``,
    ``roi_filters``, ``colors``, ``linestyles``, ``markers``, ``WT_filter``,
    ``Df_filter``, ``WT_color``, ``Df_color``); their exact types are not
    visible here -- confirm against the module header.
    """
    hidden_grps = df.loadExptGrps('GOL')

    WT_expt_grp_hidden = hidden_grps['WT_place_set']
    Df_expt_grp_hidden = hidden_grps['Df_place_set']
    expt_grps_hidden = [WT_expt_grp_hidden, Df_expt_grp_hidden]

    acute_grps = df.loadExptGrps('RF')

    WT_expt_grp_acute = acute_grps['WT_place_set'].unpair()
    Df_expt_grp_acute = acute_grps['Df_place_set'].unpair()
    expt_grps_acute = [WT_expt_grp_acute, Df_expt_grp_acute]

    WT_label = WT_expt_grp_hidden.label()
    Df_label = Df_expt_grp_hidden.label()
    labels = [WT_label, Df_label]

    fig = plt.figure(figsize=(8.5, 11))

    gs1 = plt.GridSpec(1, 1, top=0.9, bottom=0.7, left=0.1, right=0.20)
    across_ctx_ax = fig.add_subplot(gs1[0, 0])

    gs2 = plt.GridSpec(3, 1, top=0.9, bottom=0.7, left=0.25, right=0.35)
    wt_pie_ax = fig.add_subplot(gs2[0, 0])
    df_pie_ax = fig.add_subplot(gs2[1, 0])
    shuffle_pie_ax = fig.add_subplot(gs2[2, 0])
    pie_axs = (wt_pie_ax, df_pie_ax, shuffle_pie_ax)

    gs3 = plt.GridSpec(1, 1, top=0.9, bottom=0.7, left=0.4, right=0.5)
    cue_cell_bar_ax = fig.add_subplot(gs3[0, 0])

    gs5 = plt.GridSpec(1, 1, top=0.5, bottom=0.3, left=0.1, right=0.3)
    acute_stability_ax = fig.add_subplot(gs5[0, 0])

    acute_stability_inset_ax = fig.add_axes([0.23, 0.32, 0.05, 0.08])

    gs6 = plt.GridSpec(1, 1, top=0.5, bottom=0.3, left=0.4, right=0.5)
    task_compare_ax = fig.add_subplot(gs6[0, 0])

    #
    # RF Compare
    #
    params = {}
    params['filename'] = filename

    # Alternative stability-metric parameter sets; only
    # `params_cent_shift_all` is merged into `params` below.
    params_cent_shift_pc = {}
    params_cent_shift_pc['stability_fn'] = place.activity_centroid_shift
    params_cent_shift_pc['stability_kwargs'] = {
        'activity_filter': 'pc_both',
        'circ_var_pcs': False,
        'units': 'norm',
        'shuffle': True
    }
    params_cent_shift_pc['stability_label'] = \
        'Centroid shift (fraction of belt)'

    params_cent_shift_all = {}
    params_cent_shift_all['stability_fn'] = place.activity_centroid_shift
    params_cent_shift_all['stability_kwargs'] = {
        'activity_filter': 'active_both',
        'circ_var_pcs': False,
        'units': 'norm',
        'shuffle': True
    }
    params_cent_shift_all['stability_label'] = \
        'Centroid shift (fraction of belt)'
    params_cent_shift_all['stability_inset_ylim'] = (0.15, 0.30)
    params_cent_shift_all['stability_cdf_range'] = (0.15, 0.35)
    params_cent_shift_all['stability_cdf_ticks'] = \
        (0.15, 0.20, 0.25, 0.30, 0.35)
    params_cent_shift_all['stability_compare_ylim'] = (0.15, 0.27)
    params_cent_shift_all['stability_compare_yticks'] = (0.15, 0.20, 0.25)
    params_cent_shift_all['ctx_compare_ylim'] = (0.10, 0.30)
    params_cent_shift_all['ctx_compare_yticks'] = \
        (0.10, 0.15, 0.20, 0.25, 0.30)

    params_cent_shift_cm = {}
    params_cent_shift_cm['stability_fn'] = place.activity_centroid_shift
    params_cent_shift_cm['stability_kwargs'] = {
        'activity_filter': 'active_both',
        'circ_var_pcs': False,
        'units': 'cm',
        'shuffle': True
    }
    params_cent_shift_cm['stability_label'] = 'Centroid shift (cm)'

    params_pop_vect_corr = {}
    params_pop_vect_corr['stability_fn'] = place.population_vector_correlation
    params_pop_vect_corr['stability_kwargs'] = {
        'method': 'corr',
        'activity_filter': 'pc_both',
        'min_pf_density': 0.05,
        'circ_var_pcs': False
    }
    params_pop_vect_corr['stability_label'] = 'Population vector correlation'

    params_pf_corr = {}
    params_pf_corr['stability_fn'] = place.place_field_correlation
    params_pf_corr['stability_kwargs'] = {'activity_filter': 'pc_either'}
    params_pf_corr['stability_label'] = 'Place field correlation'
    params_pf_corr['stability_inset_ylim'] = (0, 0.50)
    params_pf_corr['stability_cdf_range'] = (0.15, 0.55)
    params_pf_corr['stability_cdf_ticks'] = (0.15, 0.25, 0.35, 0.45, 0.55)
    params_pf_corr['stability_compare_ylim'] = (0.22, 0.40)
    params_pf_corr['stability_compare_yticks'] = (0.25, 0.30, 0.35, 0.40)
    params_pf_corr['hidden_ctx_compare_ylim'] = (0.22, 0.40)
    params_pf_corr['hidden_ctx_compare_yticks'] = (0.25, 0.30, 0.35, 0.40)
    params_pf_corr['ctx_compare_ylim'] = (0.22, 0.40)
    params_pf_corr['ctx_compare_yticks'] = (0.25, 0.30, 0.35, 0.40)

    params.update(params_cent_shift_all)

    day_paired_grps_acute = [
        grp.pair('consecutive groups', groupby=['day_in_df'])
        for grp in expt_grps_acute
    ]
    paired_grps_acute = day_paired_grps_acute
    paired_grps_hidden = [
        grp.pair('consecutive groups', groupby=['X_condition', 'X_day'])
        for grp in expt_grps_hidden
    ]

    filter_fn = lambda df: (df['expt_pair_label'] == 'SameAll')
    filter_columns = ['expt_pair_label']

    acute_stability = plotting.plot_metric(
        acute_stability_ax,
        paired_grps_acute,
        metric_fn=params['stability_fn'],
        groupby=[['expt_pair_label', 'second_expt']],
        plotby=None,
        plot_method='cdf',
        plot_abs=True,
        roi_filters=roi_filters,
        activity_kwargs=params['stability_kwargs'],
        plot_shuffle=True,
        shuffle_plotby=False,
        pool_shuffle=True,
        activity_label=params['stability_label'],
        colors=colors,
        rotate_labels=False,
        filter_fn=filter_fn,
        filter_columns=filter_columns,
        return_full_dataframes=False,
        linestyles=linestyles)
    acute_stability_ax.set_xlabel(params['stability_label'])
    acute_stability_ax.set_title('')
    sns.despine(ax=acute_stability_ax)
    acute_stability_ax.set_xlim(params['stability_cdf_range'])
    acute_stability_ax.set_xticks(params['stability_cdf_ticks'])
    acute_stability_ax.legend(loc='upper left', fontsize=6)

    plotting.plot_metric(acute_stability_inset_ax,
                         paired_grps_acute,
                         metric_fn=params['stability_fn'],
                         groupby=[['second_expt'], ['second_mouseID']],
                         plotby=None,
                         plot_method='swarm',
                         plot_abs=True,
                         roi_filters=roi_filters,
                         activity_kwargs=params['stability_kwargs'],
                         plot_shuffle=True,
                         shuffle_plotby=False,
                         pool_shuffle=True,
                         activity_label=params['stability_label'],
                         colors=colors,
                         rotate_labels=False,
                         filter_fn=filter_fn,
                         filter_columns=filter_columns,
                         linewidth=0.2,
                         edgecolor='gray',
                         plot_shuffle_as_hline=True)
    acute_stability_inset_ax.get_legend().set_visible(False)
    sns.despine(ax=acute_stability_inset_ax)
    acute_stability_inset_ax.set_title('')
    acute_stability_inset_ax.set_ylabel('')
    acute_stability_inset_ax.set_xlabel('')
    acute_stability_inset_ax.tick_params(bottom=False, labelbottom=False)
    acute_stability_inset_ax.set_ylim(params['stability_inset_ylim'])
    acute_stability_inset_ax.set_yticks(params['stability_inset_ylim'])

    # Compute the GOL stability data on a throwaway figure: only the returned
    # dataframes are needed for the RF-vs-GOL comparison panel.
    tmp_fig = plt.figure()
    tmp_ax = tmp_fig.add_subplot(111)
    hidden_stability = plotting.plot_metric(
        tmp_ax,
        paired_grps_hidden,
        metric_fn=params['stability_fn'],
        groupby=[['expt_pair_label', 'second_expt']],
        plotby=('expt_pair_label', ),
        plot_method='line',
        plot_abs=True,
        roi_filters=roi_filters,
        activity_kwargs=params['stability_kwargs'],
        plot_shuffle=True,
        shuffle_plotby=False,
        pool_shuffle=True,
        activity_label=params['stability_label'],
        colors=colors,
        rotate_labels=False,
        filter_fn=filter_fn,
        filter_columns=filter_columns,
        return_full_dataframes=False)
    plt.close(tmp_fig)

    wt_acute = acute_stability[WT_label]['dataframe']
    wt_acute_shuffle = acute_stability[WT_label]['shuffle']
    df_acute = acute_stability[Df_label]['dataframe']
    df_acute_shuffle = acute_stability[Df_label]['shuffle']

    wt_hidden = hidden_stability[WT_label]['dataframe']
    wt_hidden_shuffle = hidden_stability[WT_label]['shuffle']
    df_hidden = hidden_stability[Df_label]['dataframe']
    df_hidden_shuffle = hidden_stability[Df_label]['shuffle']

    for dataframe in (wt_acute, wt_acute_shuffle, df_acute, df_acute_shuffle):
        dataframe['task'] = 'RF'

    for dataframe in (wt_hidden, wt_hidden_shuffle, df_hidden,
                      df_hidden_shuffle):
        dataframe['task'] = 'GOL'

    # `DataFrame.append` was removed in pandas 2.0; `pd.concat` is the
    # equivalent row-wise concatenation.
    WT_data = pd.concat([wt_acute, wt_hidden], ignore_index=True)
    Df_data = pd.concat([df_acute, df_hidden], ignore_index=True)

    WT_shuffle = pd.concat([wt_acute_shuffle, wt_hidden_shuffle],
                           ignore_index=True)
    Df_shuffle = pd.concat([df_acute_shuffle, df_hidden_shuffle],
                           ignore_index=True)

    filter_columns = ('expt_pair_label', )
    filter_fn = lambda df: (df['expt_pair_label'] == 'SameAll')

    # Numeric sort key so RF is drawn before GOL.
    order_dict = {'RF': 0, 'GOL': 1}
    WT_data['order'] = WT_data['task'].map(order_dict)
    Df_data['order'] = Df_data['task'].map(order_dict)
    line_kwargs = {'markersize': 4}
    plotting.plot_dataframe(task_compare_ax, [WT_data, Df_data],
                            [WT_shuffle, Df_shuffle],
                            labels=labels,
                            activity_label='',
                            groupby=[['task', 'second_mouseID']],
                            plotby=('task', ),
                            plot_method='box_and_line',
                            colors=colors,
                            filter_fn=filter_fn,
                            filter_columns=filter_columns,
                            plot_shuffle=True,
                            shuffle_plotby=False,
                            pool_shuffle=True,
                            orderby='order',
                            notch=False,
                            plot_shuffle_as_hline=True,
                            markers=markers,
                            linestyles=linestyles,
                            line_kwargs=line_kwargs,
                            flierprops={
                                'markersize': 3,
                                'marker': 'o'
                            },
                            whis='range')
    task_compare_ax.set_title('')
    sns.despine(ax=task_compare_ax)
    task_compare_ax.set_ylim(params['stability_compare_ylim'])
    task_compare_ax.set_yticks(params['stability_compare_yticks'])
    task_compare_ax.set_xlabel('')
    task_compare_ax.set_ylabel(params['stability_label'])
    task_compare_ax.legend(loc='upper right', fontsize=6)

    #
    # Stability across transition
    #
    groupby = [['second_expt'], ['second_mouse']]
    filter_fn = lambda df: (df['X_first_condition'] == 'A') \
        & (df['X_second_condition'] == 'B')
    filter_columns = ('X_first_condition', 'X_second_condition')
    plotting.plot_metric(across_ctx_ax,
                         paired_grps_hidden,
                         metric_fn=params['stability_fn'],
                         groupby=groupby,
                         plotby=None,
                         plot_method='swarm',
                         activity_kwargs=params['stability_kwargs'],
                         plot_shuffle=True,
                         shuffle_plotby=False,
                         pool_shuffle=True,
                         colors=colors,
                         activity_label=params['stability_label'],
                         rotate_labels=False,
                         filter_fn=filter_fn,
                         filter_columns=filter_columns,
                         plot_shuffle_as_hline=True,
                         return_full_dataframes=False,
                         plot_bar=True,
                         roi_filters=roi_filters)
    sns.despine(ax=across_ctx_ax)
    across_ctx_ax.set_ylim(0.0, 0.3)
    across_ctx_ax.set_yticks([0.0, 0.1, 0.2, 0.3])
    across_ctx_ax.set_xticklabels([])
    across_ctx_ax.set_xlabel('')
    across_ctx_ax.set_title('')
    across_ctx_ax.get_legend().set_visible(False)
    plotting.stackedText(across_ctx_ax, labels, colors=colors, loc=2, size=10)

    #
    # Cue remapping
    #

    THRESHOLD = 0.05 * 2 * np.pi
    CUENESS_THRESHOLD = 0.33

    def first_cue_position(row):
        """Return the first experiment's cue centre as a unit complex number."""
        expt = row['first_expt']
        cue = row['cue']
        cues = expt.belt().cues(normalized=True)
        # `.ix` was removed in pandas 1.0; boolean-mask selection is the same
        # with `.loc`.
        first_cue = cues.loc[cues['cue'] == cue]
        pos = (first_cue['start'] + first_cue['stop']) / 2
        angle = pos * 2 * np.pi
        # Builtin `complex` replaces `np.complex`, an alias of it that was
        # removed in NumPy 1.24.
        return complex(np.cos(angle), np.sin(angle))

    def dotproduct(v1, v2):
        return sum((a * b) for a, b in zip(v1, v2))

    def length(v):
        return math.sqrt(dotproduct(v, v))

    def angle(v1, v2):
        # Rounding keeps the cosine inside acos's [-1, 1] domain despite
        # floating-point error.
        return math.acos(
            np.round(dotproduct(v1, v2) / (length(v1) * length(v2)), 3))

    def distance_to_first_cue(row):
        """Angle (radians) between the second centroid and the first cue."""
        centroid = row['second_centroid']
        pos = row['first_cue_position']
        return angle((pos.real, pos.imag), (centroid.real, centroid.imag))

    WT_copy = copy(WT_expt_grp_hidden)
    WT_copy.filterby(lambda df: ~df['X_condition'].str.contains('C'),
                     ['X_condition'])
    WT_paired = WT_copy.pair('consecutive groups',
                             groupby=['X_condition',
                                      'X_day']).pair('consecutive groups',
                                                     groupby=['X_condition'])

    Df_copy = copy(Df_expt_grp_hidden)
    Df_copy.filterby(lambda df: ~df['X_condition'].str.contains('C'),
                     ['X_condition'])
    Df_paired = Df_copy.pair('consecutive groups',
                             groupby=['X_condition',
                                      'X_day']).pair('consecutive groups',
                                                     groupby=['X_condition'])

    WT_df, WT_shuffle_df = place.cue_cell_remapping(
        WT_paired,
        roi_filter=WT_filter,
        near_threshold=THRESHOLD,
        activity_filter='active_both',
        circ_var_pcs=False,
        shuffle=True)
    Df_df, Df_shuffle_df = place.cue_cell_remapping(
        Df_paired,
        roi_filter=Df_filter,
        near_threshold=THRESHOLD,
        activity_filter='active_both',
        circ_var_pcs=False,
        shuffle=True)

    shuffle_df = pd.concat([WT_shuffle_df, Df_shuffle_df], ignore_index=True)

    cueness, cueness_fraction = [], []
    cue_n, place_n, neither_n = [], [], []

    # One entry per group, in the order WT, Df, shuffle.
    for grp_df in (WT_df, Df_df, shuffle_df):

        grp_df['first_cue_position'] = grp_df.apply(first_cue_position, axis=1)

        grp_df['second_distance_to_first_cue_position'] = grp_df.apply(
            distance_to_first_cue, axis=1)

        # Fraction of the combined (cue distance + 'value') attributable to
        # the cue distance; near 1 -> counted as cue cell, near 0 -> position.
        grp_df['cueness'] = grp_df['second_distance_to_first_cue_position'] / \
            (grp_df['value'] + grp_df['second_distance_to_first_cue_position'])

        plotting.prepare_dataframe(grp_df, ['first_mouse'])
        cueness_fraction.append([[]])
        cue_n.append([])
        place_n.append([])
        neither_n.append([])

        for mouse, mouse_df in grp_df.groupby('first_mouse'):
            cue_n[-1].append((mouse_df['cueness'] >
                              (1 - CUENESS_THRESHOLD)).sum())
            place_n[-1].append((mouse_df['cueness'] < CUENESS_THRESHOLD).sum())
            neither_n[-1].append(mouse_df.shape[0] - cue_n[-1][-1] -
                                 place_n[-1][-1])
            cueness_fraction[-1][0].append(cue_n[-1][-1] /
                                           float(place_n[-1][-1]))
        cueness.append([grp_df['cueness']])

    cue_labels = labels + ['shuffle']

    plotting.swarm_plot(cue_cell_bar_ax,
                        cueness_fraction[:2],
                        condition_labels=labels,
                        colors=colors,
                        plot_bar=True)
    # Shuffle mean drawn as a reference line.
    cue_cell_bar_ax.axhline(np.mean(cueness_fraction[-1][0]),
                            ls='--',
                            color='k')
    sns.despine(ax=cue_cell_bar_ax)
    cue_cell_bar_ax.set_ylim(0, 1.5)
    cue_cell_bar_ax.set_yticks([0, 0.5, 1.0, 1.5])
    cue_cell_bar_ax.set_xticklabels([])
    cue_cell_bar_ax.set_xlabel('')
    cue_cell_bar_ax.set_ylabel('Cue-to-position ratio')
    cue_cell_bar_ax.get_legend().set_visible(False)
    plotting.stackedText(cue_cell_bar_ax,
                         labels,
                         colors=colors,
                         loc=2,
                         size=10)

    WT_colors = sns.light_palette(WT_color, 7)[:-6:-2]
    Df_colors = sns.light_palette(Df_color, 7)[:-6:-2]
    shuffle_colors = sns.light_palette('k', 7)[:-6:-2]
    pie_colors = (WT_colors, Df_colors, shuffle_colors)
    pie_labels = ['cue', 'position', 'neither']
    # Temporarily shrink the global tick-label size for the pie charts.
    orig_size = mpl.rcParams.get('xtick.labelsize')
    mpl.rcParams['xtick.labelsize'] = 5
    try:
        for grp_ax, grp_label, grp_cue_n, grp_place_n, grp_neither_n, p_cs in zip(
                pie_axs, cue_labels, cue_n, place_n, neither_n, pie_colors):
            grp_ax.pie([sum(grp_cue_n),
                        sum(grp_place_n),
                        sum(grp_neither_n)],
                       autopct='%1.0f%%',
                       shadow=False,
                       frame=False,
                       labels=pie_labels,
                       colors=p_cs,
                       textprops={'fontsize': 5})
            grp_ax.set_title(grp_label)
            plotting.square_axis(grp_ax)
    finally:
        # Restore the global setting even if a pie chart fails.
        mpl.rcParams['xtick.labelsize'] = orig_size

    misc.save_figure(fig, params['filename'], save_dir=save_dir)

    plt.close('all')
コード例 #35
0
dcd_file = sys.argv[3]
dat_file = sys.argv[4] + '.dat'
out_file = sys.argv[4] + '.png'

u = mda.Universe(psf_file, dcd_file)
ref = mda.Universe(psf_file, pdb_file)  # default the 0th frame

# Backbone RMSD of the trajectory against the reference structure.
# NOTE(review): the `filename=` argument and `RMSD.save()` are not available
# in newer MDAnalysis releases; this script presumably targets an older one.
R = MDAnalysis.analysis.rms.RMSD(u, ref, select="backbone", filename=dat_file)

R.run()
R.save()

# After transposing, row 1 is used as the time axis (ps) and row 2 as the
# backbone RMSD (Angstrom), matching the axis labels below.
rmsd = R.rmsd.T
time = rmsd[1]
# `seaborn.apionly` was deprecated and later removed from seaborn; fall back
# to the top-level package, which no longer restyles matplotlib on import.
try:
    import seaborn.apionly as sns
except ImportError:
    import seaborn as sns
plt.style.use('ggplot')
rcParams.update({'figure.autolayout': True})
sns.set_style('ticks')
fig = plt.figure(figsize=(5, 3))
ax = fig.add_subplot(111)
color = sns.color_palette()[2]
ax.plot(time, rmsd[2], lw=1, color=color)
sns.despine(ax=ax)
ax.set_xlabel("Time (ps)")
ax.set_ylabel(r"RMSD ($\AA$)")
ax.set_xlim(0, max(time))
ax.set_ylim(0, 10)
fig.savefig(out_file)
コード例 #36
0
ファイル: resultsPlotter.py プロジェクト: YesperP/Master
def _plotGroupedSeries(plotFn, timeArray, seriesList, offset=None):
    """Draw (P_d, lambda_phi, values) series on one axis, grouped by P_d.

    ``seriesList`` is sorted in place by P_d (descending), then by lambda_phi
    (ascending).  Each distinct P_d takes the next entry of the module-level
    ``colors`` list; within one P_d group each distinct lambda_phi takes the
    next entry of ``linestyleList``.

    Args:
        plotFn: bound axis plotting method, e.g. ``ax.plot`` or ``ax.semilogy``.
        timeArray: x values shared by every series.
        seriesList: list of ``(P_d, lambda_phi, values)`` tuples.
        offset: optional constant added to every y value (e.g. to keep zeros
            finite on a logarithmic axis).

    Returns:
        The number of distinct P_d values plotted.
    """
    # Two stable sorts: secondary key (lambda_phi) first, then primary (P_d).
    seriesList.sort(key=lambda tup: float(tup[1]))
    seriesList.sort(key=lambda tup: float(tup[0]), reverse=True)

    pdSet = set()
    lambdaPhiSet = set()
    for P_d, lambda_phi, values in seriesList:
        if P_d not in pdSet:
            # New P_d group: restart the linestyle sequence.
            lambdaPhiSet.clear()
        pdSet.add(P_d)
        lambdaPhiSet.add(lambda_phi)
        yValues = values if offset is None else values + offset
        plotFn(timeArray,
               yValues,
               label=r"$P_D$ = {0:}, $\lambda_\phi$ = {1:}".format(P_d, float(lambda_phi)),
               c=colors[len(pdSet) - 1],
               linestyle=linestyleList[len(lambdaPhiSet) - 1],
               linewidth=linewidth)
    return len(pdSet)


def _plotInitializationTime2D(plotData, loadFilePath, simLength, timeStep, nTargets):
    """Save one three-panel initialization-time figure per (M_init, N_init).

    ``plotData`` is nested as
    ``{M_init: {N_init: {lambda_phi: {P_d: (correctInitTimeLog, falseInitTimeLog)}}}}``.
    Both logs are keyed by ``str(time)`` for the times in
    ``np.arange(0, simLength, timeStep)``; entries of ``falseInitTimeLog`` are
    indexable pairs whose items 0 and 1 feed panels 2 and 3 (their meaning is
    inferred from the panel titles -- confirm against the code that writes the
    logs).

    Panels:
        1. cumulative sum of ``correctInitTimeLog`` divided by ``nTargets``;
        2. cumulative sum of item 0 of ``falseInitTimeLog`` (log y axis);
        3. cumulative sum of item 1 of ``falseInitTimeLog``.

    Each figure is written to ``_getSavePath(loadFilePath, "Time(M-N)")``.
    """
    timeArray = np.arange(0, simLength, timeStep)
    for M_init, d1 in plotData.items():
        for N_init, d2 in d1.items():
            # Axes read most style rcParams when they are created, so the
            # style must be set before the figure and subplots are built.
            sns.set_style(style='white')
            figure = plt.figure(figsize=(figureWidth, fullPageHeight), dpi=600)

            ax11 = figure.add_subplot(311)
            ax12 = figure.add_subplot(312)
            ax13 = figure.add_subplot(313)

            savePath = _getSavePath(loadFilePath, "Time({0:}-{1:})".format(M_init, N_init))
            cpfmList = []
            falseCPFMlist = []
            accFalseTrackList = []
            for lambda_phi, d3 in d2.items():
                for P_d, (correctInitTimeLog, falseInitTimeLog) in d3.items():
                    # Float accumulators: `timeArray` is integer-typed when
                    # simLength and timeStep are ints, which would silently
                    # truncate fractional (averaged) log values.
                    falsePFM = np.zeros_like(timeArray, dtype=float)
                    pmf = np.zeros_like(timeArray, dtype=float)
                    falseTrackDelta = np.zeros_like(timeArray, dtype=float)
                    for i, time in enumerate(timeArray):
                        timeKey = str(time)
                        if timeKey in correctInitTimeLog:
                            pmf[i] = correctInitTimeLog[timeKey]
                        if timeKey in falseInitTimeLog:
                            falsePFM[i] = falseInitTimeLog[timeKey][0]
                            falseTrackDelta[i] = falseInitTimeLog[timeKey][1]
                    cpmf = np.cumsum(pmf) / float(nTargets)
                    falseCPFM = np.cumsum(falsePFM)
                    falseTrackDelta = np.cumsum(falseTrackDelta)
                    cpfmList.append((P_d, lambda_phi, cpmf))
                    falseCPFMlist.append((P_d, lambda_phi, falseCPFM))
                    accFalseTrackList.append((P_d, lambda_phi, falseTrackDelta))

            nPd = _plotGroupedSeries(ax11.plot, timeArray, cpfmList)
            # Tiny offset keeps zero counts finite on the logarithmic y axis.
            _plotGroupedSeries(ax12.semilogy, timeArray, falseCPFMlist, offset=1e-10)
            _plotGroupedSeries(ax13.plot, timeArray, accFalseTrackList)

            ax11.set_xlabel("Time [s]", fontsize=labelFontsize)
            ax11.set_ylabel("Average cpfm", fontsize=labelFontsize)
            ax11.set_title("Cumulative Probability Mass Function", fontsize=titleFontsize)
            ax11.legend(loc=4, ncol=nPd, fontsize=legendFontsize)
            ax11.grid(False)
            ax11.set_xlim(0, simLength)
            ax11.set_ylim(0, 1)
            ax11.xaxis.set_ticks(np.arange(0, simLength + 15, 15))
            ax11.tick_params(labelsize=labelFontsize)
            sns.despine(ax=ax11, offset=0)

            ax12.set_xlabel("Time [s]", fontsize=labelFontsize)
            ax12.set_ylabel("Average number of tracks", fontsize=labelFontsize)
            ax12.set_title("Accumulative number of erroneous tracks", fontsize=titleFontsize)
            ax12.grid(False)
            ax12.set_xlim(0, simLength)
            ax12.set_ylim(1e-3, 1000)
            ax12.xaxis.set_ticks(np.arange(0, simLength + 15, 15))
            ax12.tick_params(labelsize=labelFontsize)
            sns.despine(ax=ax12, offset=0)

            ax13.set_xlabel("Time [s]", fontsize=labelFontsize)
            ax13.set_ylabel("Average number of tracks", fontsize=labelFontsize)
            ax13.set_title("Number of erroneous tracks alive", fontsize=titleFontsize)
            ax13.grid(False)
            ax13.set_xlim(0, simLength)
            ax13.set_ylim(-0.02, max(1, ax13.get_ylim()[1]))
            ax13.xaxis.set_ticks(np.arange(0, simLength + 15, 15))
            ax13.tick_params(labelsize=labelFontsize)
            sns.despine(ax=ax13, offset=0)

            figure.tight_layout(pad=0.8, h_pad=0.8, w_pad=0.8)
            figure.savefig(savePath)
            figure.clf()

            # Close this specific figure rather than whichever is "current".
            plt.close(figure)
コード例 #37
0
source = ['Lustre', 'SSD']
size = ['9000', '900']
analysis = ['RDF', 'RMS']
scheduler = ['distr', 'multi']
nodes = ['3nodes', '6nodes']
scheduler_full_name = {'distr': 'distributed', 'multi': 'multiprocessing'}


def _add_log_timing_axis(figure, spec, ylabel, title):
    """Add a despined log-log subplot at `spec` for a per-core timing metric."""
    axis = figure.add_subplot(spec)
    axis.set_xlabel('Number of cores')
    axis.set_ylabel(ylabel)
    sns.despine(offset=10, ax=axis)
    axis.set_title(title)
    axis.set_xscale("log")
    axis.set_yscale("log")
    return axis


# for Lustre distributed
for ana in analysis:
    fig = plt.figure(figsize=(12, 8), tight_layout=True)
    gs = gridspec.GridSpec(5, 9)
    panel_title = 'pmda.{} on 900 frames'.format(ana.lower())

    # Top row of the 5x9 grid: three panels, each three columns wide.
    ax0 = _add_log_timing_axis(fig, gs[0:2, 0:3], 'Wait (s)', panel_title)
    ax1 = _add_log_timing_axis(fig, gs[0:2, 3:6], 'Compute per block (s)',
                               panel_title)

    ax2 = fig.add_subplot(gs[0:2, 6:9])
コード例 #38
0
def pairedcontrast(data,
                   x,
                   y,
                   idcol,
                   reps=3000,
                   statfunction=None,
                   idx=None,
                   figsize=None,
                   beforeAfterSpacer=0.01,
                   violinWidth=0.005,
                   floatOffset=0.05,
                   showRawData=False,
                   showAllYAxes=False,
                   floatContrast=True,
                   smoothboot=False,
                   floatViolinOffset=None,
                   showConnections=True,
                   summaryBar=False,
                   contrastYlim=None,
                   swarmYlim=None,
                   barWidth=0.005,
                   rawMarkerSize=8,
                   rawMarkerType='o',
                   summaryMarkerSize=10,
                   summaryMarkerType='o',
                   summaryBarColor='grey',
                   meansSummaryLineStyle='solid',
                   contrastZeroLineStyle='solid',
                   contrastEffectSizeLineStyle='solid',
                   contrastZeroLineColor='black',
                   contrastEffectSizeLineColor='black',
                   pal=None,
                   legendLoc=2,
                   legendFontSize=12,
                   legendMarkerScale=1,
                   axis_title_size=None,
                   yticksize=None,
                   xticksize=None,
                   tickAngle=45,
                   tickAlignment='right',
                   **kwargs):

    # Preliminaries.
    data = data.dropna()

    # plot params
    if axis_title_size is None:
        axis_title_size = 15
    if yticksize is None:
        yticksize = 12
    if xticksize is None:
        xticksize = 12

    axisTitleParams = {'labelsize': axis_title_size}
    xtickParams = {'labelsize': xticksize}
    ytickParams = {'labelsize': yticksize}

    rc('axes', **axisTitleParams)
    rc('xtick', **xtickParams)
    rc('ytick', **ytickParams)

    ## If `idx` is not specified, just take the FIRST TWO levels alphabetically.
    if idx is None:
        idx = tuple(np.unique(data[x])[0:2], )
    else:
        # check if multi-plot or not
        if all(isinstance(element, str) for element in idx):
            # if idx is supplied but not a multiplot (ie single list or tuple)
            if len(idx) != 2:
                print(idx, "does not have length 2.")
                sys.exit(0)
            else:
                idx = (tuple(idx, ), )
        elif all(isinstance(element, tuple) for element in idx):
            # if idx is supplied, and it is a list/tuple of tuples or lists, we have a multiplot!
            if (any(len(element) != 2 for element in idx)):
                # If any of the tuples contain more than 2 elements.
                print(element, "does not have length 2.")
                sys.exit(0)
    if floatViolinOffset is None:
        floatViolinOffset = beforeAfterSpacer / 2
    if contrastYlim is not None:
        contrastYlim = np.array([contrastYlim[0], contrastYlim[1]])
    if swarmYlim is not None:
        swarmYlim = np.array([swarmYlim[0], swarmYlim[1]])

    ## Here we define the palette on all the levels of the 'x' column.
    ## Thus, if the same pandas dataframe is re-used across different plots,
    ## the color identity of each group will be maintained.
    ## Set palette based on total number of categories in data['x'] or data['hue_column']
    if 'hue' in kwargs:
        u = kwargs['hue']
    else:
        u = x
    if ('color' not in kwargs and 'hue' not in kwargs):
        kwargs['color'] = 'k'

    if pal is None:
        pal = dict(
            zip(data[u].unique(),
                sns.color_palette(n_colors=len(data[u].unique()))))
    else:
        pal = pal

    # Initialise figure.
    if figsize is None:
        if len(idx) > 2:
            figsize = (12, (12 / np.sqrt(2)))
        else:
            figsize = (6, 6)
    fig = plt.figure(figsize=figsize)

    # Initialise GridSpec based on `levs_tuple` shape.
    gsMain = gridspec.GridSpec(
        1,
        np.shape(idx)[0])  # 1 row; columns based on number of tuples in tuple.
    # Set default statfunction
    if statfunction is None:
        statfunction = np.mean
    # Create list to collect all the contrast DataFrames generated.
    contrastList = list()
    contrastListNames = list()

    for gsIdx, xlevs in enumerate(idx):
        ## Pivot tempdat to get before and after lines.
        data_pivot = data.pivot_table(index=idcol, columns=x, values=y)

        # Start plotting!!
        if floatContrast is True:
            ax_raw = fig.add_subplot(gsMain[gsIdx], frame_on=False)
            ax_contrast = ax_raw.twinx()
        else:
            gsSubGridSpec = gridspec.GridSpecFromSubplotSpec(
                2, 1, subplot_spec=gsMain[gsIdx])
            ax_raw = plt.Subplot(fig, gsSubGridSpec[0, 0], frame_on=False)
            ax_contrast = plt.Subplot(fig,
                                      gsSubGridSpec[1, 0],
                                      sharex=ax_raw,
                                      frame_on=False)

        ## Plot raw data as swarmplot or stripplot.
        if showRawData is True:
            swarm_raw = sns.swarmplot(data=data,
                                      x=x,
                                      y=y,
                                      order=xlevs,
                                      ax=ax_raw,
                                      palette=pal,
                                      size=rawMarkerSize,
                                      marker=rawMarkerType,
                                      **kwargs)
        else:
            swarm_raw = sns.stripplot(data=data,
                                      x=x,
                                      y=y,
                                      order=xlevs,
                                      ax=ax_raw,
                                      palette=pal,
                                      **kwargs)
        swarm_raw.set_ylim(swarmYlim)

        ## Get some details about the raw data.
        maxXBefore = max(swarm_raw.collections[0].get_offsets().T[0])
        minXAfter = min(swarm_raw.collections[1].get_offsets().T[0])
        if showRawData is True:
            #beforeAfterSpacer = (getSwarmSpan(swarm_raw, 0) + getSwarmSpan(swarm_raw, 1))/2
            beforeAfterSpacer = 1
        xposAfter = maxXBefore + beforeAfterSpacer
        xAfterShift = minXAfter - xposAfter

        ## shift the after swarmpoints closer for aesthetic purposes.
        offsetSwarmX(swarm_raw.collections[1], -xAfterShift)

        ## pandas DataFrame of 'before' group
        x1 = pd.DataFrame({
            str(xlevs[0] + '_x'):
            pd.Series(swarm_raw.collections[0].get_offsets().T[0]),
            xlevs[0]:
            pd.Series(swarm_raw.collections[0].get_offsets().T[1]),
            '_R_':
            pd.Series(swarm_raw.collections[0].get_facecolors().T[0]),
            '_G_':
            pd.Series(swarm_raw.collections[0].get_facecolors().T[1]),
            '_B_':
            pd.Series(swarm_raw.collections[0].get_facecolors().T[2]),
        })
        ## join the RGB columns into a tuple, then assign to a column.
        x1['_hue_'] = x1[['_R_', '_G_', '_B_']].apply(tuple, axis=1)
        x1 = x1.sort_values(by=xlevs[0])
        x1.index = data_pivot.sort_values(by=xlevs[0]).index

        ## pandas DataFrame of 'after' group
        ### create convenient signifiers for column names.
        befX = str(xlevs[0] + '_x')
        aftX = str(xlevs[1] + '_x')

        x2 = pd.DataFrame({
            aftX:
            pd.Series(swarm_raw.collections[1].get_offsets().T[0]),
            xlevs[1]:
            pd.Series(swarm_raw.collections[1].get_offsets().T[1])
        })
        x2 = x2.sort_values(by=xlevs[1])
        x2.index = data_pivot.sort_values(by=xlevs[1]).index

        ## Join x1 and x2, on both their indexes.
        plotPoints = x1.merge(x2,
                              left_index=True,
                              right_index=True,
                              how='outer')

        ## Add the hue column if hue argument was passed.
        if 'hue' in kwargs:
            h = kwargs['hue']
            plotPoints[h] = data.pivot(index=idcol, columns=x,
                                       values=h)[xlevs[0]]
            swarm_raw.legend(loc=legendLoc,
                             fontsize=legendFontSize,
                             markerscale=legendMarkerScale)

        ## Plot the lines to join the 'before' points to their respective 'after' points.
        if showConnections is True:
            for i in plotPoints.index:
                ax_raw.plot(
                    [plotPoints.ix[i, befX], plotPoints.ix[i, aftX]],
                    [plotPoints.ix[i, xlevs[0]], plotPoints.ix[i, xlevs[1]]],
                    linestyle='solid',
                    color=plotPoints.ix[i, '_hue_'],
                    linewidth=0.75,
                    alpha=0.75)

        ## Hide the raw swarmplot data if so desired.
        if showRawData is False:
            swarm_raw.collections[0].set_visible(False)
            swarm_raw.collections[1].set_visible(False)

        if showRawData is True:
            #maxSwarmSpan = max(np.array([getSwarmSpan(swarm_raw, 0), getSwarmSpan(swarm_raw, 1)]))/2
            maxSwarmSpan = 0.5
        else:
            maxSwarmSpan = barWidth

        ## Plot Summary Bar.
        if summaryBar is True:
            # Calculate means
            means = data.groupby([x], sort=True).mean()[y]
            # # Calculate medians
            # medians = data.groupby([x], sort = True).median()[y]

            ## Draw summary bar.
            bar_raw = sns.barplot(x=means.index,
                                  y=means.values,
                                  order=xlevs,
                                  ax=ax_raw,
                                  ci=0,
                                  facecolor=summaryBarColor,
                                  alpha=0.25)
            ## Draw zero reference line.
            ax_raw.add_artist(
                Line2D((ax_raw.xaxis.get_view_interval()[0],
                        ax_raw.xaxis.get_view_interval()[1]), (0, 0),
                       color='black',
                       linewidth=0.75))

            ## get swarm with largest span, set as max width of each barplot.
            for i, bar in enumerate(bar_raw.patches):
                x_width = bar.get_x()
                width = bar.get_width()
                centre = x_width + width / 2.
                if i == 0:
                    bar.set_x(centre - maxSwarmSpan / 2.)
                else:
                    bar.set_x(centre - xAfterShift - maxSwarmSpan / 2.)
                bar.set_width(maxSwarmSpan)

        # Get y-limits of the treatment swarm points.
        beforeRaw = pd.DataFrame(swarm_raw.collections[0].get_offsets())
        afterRaw = pd.DataFrame(swarm_raw.collections[1].get_offsets())
        before_leftx = min(beforeRaw[0])
        after_leftx = min(afterRaw[0])
        after_rightx = max(afterRaw[0])
        after_stat_summary = statfunction(beforeRaw[1])

        # Calculate the summary difference and CI.
        plotPoints['delta_y'] = plotPoints[xlevs[1]] - plotPoints[xlevs[0]]
        plotPoints['delta_x'] = [0] * np.shape(plotPoints)[0]

        tempseries = plotPoints['delta_y'].tolist()
        test = tempseries.count(tempseries[0]) != len(tempseries)

        bootsDelta = bootstrap(plotPoints['delta_y'],
                               statfunction=statfunction,
                               smoothboot=smoothboot,
                               reps=reps)
        summDelta = bootsDelta['summary']
        lowDelta = bootsDelta['bca_ci_low']
        highDelta = bootsDelta['bca_ci_high']

        # set new xpos for delta violin.
        if floatContrast is True:
            if showRawData is False:
                xposPlusViolin = deltaSwarmX = after_rightx + floatViolinOffset
            else:
                xposPlusViolin = deltaSwarmX = after_rightx + maxSwarmSpan
        else:
            xposPlusViolin = xposAfter
        if showRawData is True:
            # If showRawData is True and floatContrast is True,
            # set violinwidth to the barwidth.
            violinWidth = maxSwarmSpan

        xmaxPlot = xposPlusViolin + violinWidth

        # Plot the summary measure.
        ax_contrast.plot(xposPlusViolin,
                         summDelta,
                         marker='o',
                         markerfacecolor='k',
                         markersize=summaryMarkerSize,
                         alpha=0.75)

        # Plot the CI.
        ax_contrast.plot([xposPlusViolin, xposPlusViolin],
                         [lowDelta, highDelta],
                         color='k',
                         alpha=0.75,
                         linestyle='solid')

        # Plot the violin-plot.
        v = ax_contrast.violinplot(bootsDelta['stat_array'], [xposPlusViolin],
                                   widths=violinWidth,
                                   showextrema=False,
                                   showmeans=False)
        halfviolin(v, half='right', color='k')

        # Remove left axes x-axis title.
        ax_raw.set_xlabel("")
        # Remove floating axes y-axis title.
        ax_contrast.set_ylabel("")

        # Set proper x-limits
        ax_raw.set_xlim(before_leftx - beforeAfterSpacer / 2, xmaxPlot)
        ax_raw.get_xaxis().set_view_interval(
            before_leftx - beforeAfterSpacer / 2,
            after_rightx + beforeAfterSpacer / 2)
        ax_contrast.set_xlim(ax_raw.get_xlim())

        if floatContrast is True:
            # Set the ticks locations for ax_raw.
            ax_raw.get_xaxis().set_ticks((0, xposAfter))

            # Make sure they have the same y-limits.
            ax_contrast.set_ylim(ax_raw.get_ylim())

            # Drawing in the x-axis for ax_raw.
            ## Set the tick labels!
            ax_raw.set_xticklabels(xlevs,
                                   rotation=tickAngle,
                                   horizontalalignment=tickAlignment)
            ## Get lowest y-value for ax_raw.
            y = ax_raw.get_yaxis().get_view_interval()[0]

            # Align the left axes and the floating axes.
            align_yaxis(ax_raw, statfunction(plotPoints[xlevs[0]]),
                        ax_contrast, 0)

            # Add label to floating axes. But on ax_raw!
            ax_raw.text(x=deltaSwarmX,
                        y=ax_raw.get_yaxis().get_view_interval()[0],
                        horizontalalignment='left',
                        s='Difference',
                        fontsize=15)

            # Set reference lines
            ## zero line
            ax_contrast.hlines(
                0,  # y-coordinate
                ax_contrast.xaxis.get_majorticklocs()
                [0],  # x-coordinates, start and end.
                ax_raw.xaxis.get_view_interval()[1],
                linestyle='solid',
                linewidth=0.75,
                color='black')

            ## effect size line
            ax_contrast.hlines(summDelta,
                               ax_contrast.xaxis.get_majorticklocs()[1],
                               ax_raw.xaxis.get_view_interval()[1],
                               linestyle='solid',
                               linewidth=0.75,
                               color='black')

            # Align the left axes and the floating axes.
            align_yaxis(ax_raw, after_stat_summary, ax_contrast, 0.)
        else:
            # Set the ticks locations for ax_raw.
            ax_raw.get_xaxis().set_ticks((0, xposAfter))

            fig.add_subplot(ax_raw)
            fig.add_subplot(ax_contrast)
        ax_contrast.set_ylim(contrastYlim)
        # Calculate p-values.
        # 1-sample t-test to see if the mean of the difference is different from 0.
        ttestresult = ttest_1samp(plotPoints['delta_y'], popmean=0)[1]
        bootsDelta['ttest_pval'] = ttestresult
        contrastList.append(bootsDelta)
        contrastListNames.append(str(xlevs[1]) + ' v.s. ' + str(xlevs[0]))

    # Turn contrastList into a pandas DataFrame,
    contrastList = pd.DataFrame(contrastList).T
    contrastList.columns = contrastListNames

    # Now we iterate thru the contrast axes to normalize all the ylims.
    for j, i in enumerate(range(1, len(fig.get_axes()), 2)):
        axx = fig.get_axes()[i]
        ## Get max and min of the dataset.
        lower = np.min(contrastList.ix['stat_array', j])
        upper = np.max(contrastList.ix['stat_array', j])
        meandiff = contrastList.ix['summary', j]

        ## Make sure we have zero in the limits.
        if lower > 0:
            lower = 0.
        if upper < 0:
            upper = 0.

        ## Get tick distance on raw axes.
        ## This will be the tick distance for the contrast axes.
        rawAxesTicks = fig.get_axes()[i - 1].yaxis.get_majorticklocs()
        rawAxesTickDist = rawAxesTicks[1] - rawAxesTicks[0]

        ## First re-draw of axis with new tick interval
        axx.yaxis.set_major_locator(MultipleLocator(rawAxesTickDist))
        newticks1 = fig.get_axes()[i].get_yticks()

        if floatContrast is False:
            if (showAllYAxes is False and i in range(2, len(fig.get_axes()))):
                axx.get_yaxis().set_visible(showAllYAxes)
            else:
                ## Obtain major ticks that comfortably encompass lower and upper.
                newticks2 = list()
                for a, b in enumerate(newticks1):
                    if (b >= lower and b <= upper):
                        # if the tick lies within upper and lower, take it.
                        newticks2.append(b)
                # if the meandiff falls outside of the newticks2 set, add a tick in the right direction.
                if np.max(newticks2) < meandiff:
                    ind = np.where(newticks1 == np.max(newticks2))[0][
                        0]  # find out the max tick index in newticks1.
                    newticks2.append(newticks1[ind + 1])
                elif meandiff < np.min(newticks2):
                    ind = np.where(newticks1 == np.min(newticks2))[0][
                        0]  # find out the min tick index in newticks1.
                    newticks2.append(newticks1[ind - 1])
                newticks2 = np.array(newticks2)
                newticks2.sort()
                axx.yaxis.set_major_locator(FixedLocator(locs=newticks2))

                ## Draw zero reference line.
                axx.hlines(
                    y=0,
                    xmin=fig.get_axes()[i].get_xaxis().get_view_interval()[0],
                    xmax=fig.get_axes()[i].get_xaxis().get_view_interval()[1],
                    linestyle=contrastZeroLineStyle,
                    linewidth=0.75,
                    color=contrastZeroLineColor)

                sns.despine(ax=fig.get_axes()[i],
                            trim=True,
                            bottom=False,
                            right=True,
                            left=False,
                            top=True)

                ## Draw back the lines for the relevant y-axes.
                drawback_y(axx)

                ## Draw back the lines for the relevant x-axes.
                drawback_x(axx)

        elif floatContrast is True:
            ## Get the original ticks on the floating y-axis.
            newticks1 = fig.get_axes()[i].get_yticks()

            ## Obtain major ticks that comfortably encompass lower and upper.
            newticks2 = list()
            for a, b in enumerate(newticks1):
                if (b >= lower and b <= upper):
                    # if the tick lies within upper and lower, take it.
                    newticks2.append(b)
            # if the meandiff falls outside of the newticks2 set, add a tick in the right direction.
            if np.max(newticks2) < meandiff:
                ind = np.where(newticks1 == np.max(newticks2))[0][
                    0]  # find out the max tick index in newticks1.
                newticks2.append(newticks1[ind + 1])
            elif meandiff < np.min(newticks2):
                ind = np.where(newticks1 == np.min(newticks2))[0][
                    0]  # find out the min tick index in newticks1.
                newticks2.append(newticks1[ind - 1])
            newticks2 = np.array(newticks2)
            newticks2.sort()

            ## Re-draw the axis.
            axx.yaxis.set_major_locator(FixedLocator(locs=newticks2))

            ## Despine and trim the axes.
            sns.despine(ax=axx,
                        trim=True,
                        bottom=False,
                        right=False,
                        left=True,
                        top=True)

    for i in range(0, len(fig.get_axes()), 2):
        # Loop through the raw data swarmplots and despine them appropriately.
        if floatContrast is True:
            sns.despine(ax=fig.get_axes()[i], trim=True, right=True)

        else:
            sns.despine(ax=fig.get_axes()[i],
                        trim=True,
                        bottom=True,
                        right=True)
            fig.get_axes()[i].get_xaxis().set_visible(False)

        # Draw back the lines for the relevant y-axes.
        ymin = fig.get_axes()[i].get_yaxis().get_majorticklocs()[0]
        ymax = fig.get_axes()[i].get_yaxis().get_majorticklocs()[-1]
        x, _ = fig.get_axes()[i].get_xaxis().get_view_interval()
        fig.get_axes()[i].add_artist(
            Line2D((x, x), (ymin, ymax), color='black', linewidth=1.5))

    # Zero gaps between plots on the same row, if floatContrast is False
    if (floatContrast is False and showAllYAxes is False):
        gsMain.update(wspace=0)
    else:
        # Tight Layout!
        gsMain.tight_layout(fig)

    # And we're done.
    rcdefaults()  # restore matplotlib defaults.
    sns.set()  # restore seaborn defaults.
    return fig, contrastList
コード例 #39
0
ファイル: fit_plot.py プロジェクト: VOD555/Useful-Script
import numpy as np
from sklearn import linear_model
import matplotlib.pyplot as plt
import seaborn.apionly as sns

# Fit a straight line to the mean-squared displacement (MSD) curve stored in
# `msd.xvg` and save a plot of the data together with the fit to `fit.png`.
#
# Column 0 of the file is used as time (ps) and column 1 as the MSD.
# `.xvg` files may contain xmgrace directive lines starting with '@' in
# addition to '#' comment lines; both are skipped so they don't break parsing.
data = np.loadtxt('msd.xvg', comments=('#', '@'))

# Single-feature design matrix: time converted from ps to ns.  The scaling
# only changes the units of regr.coef_ (nm^2/ns instead of nm^2/ps), not the
# fitted line itself.
x = data[:, [0]] / 1000
regr = linear_model.LinearRegression()
regr.fit(x, data[:, 1])

# Use predict() so the fitted intercept (regr.intercept_) is included;
# a bare dot product with coef_ would draw a line forced through the origin.
y_hat = regr.predict(x)

fig = plt.figure(figsize=(8, 4))
ax = fig.add_subplot(1, 1, 1)

ax.plot(data[:, 0], data[:, 1], color='k', label='Data')
ax.plot(data[:, 0], y_hat, color='blue', label='Fitted line')

ax.legend(loc='best')
ax.set_xlabel(r"time  $t$ (ps)")
# MSD is a squared length, hence nm^2.
ax.set_ylabel(r"msd (nm$^2$)")

sns.despine(offset=10, ax=ax)

plt.tight_layout()

fig.savefig('fit.png')
def _split_by_mouse(expt_grp):
    """Return a list with one sub-group of `expt_grp` per mouse.

    Experiments are grouped on the 'mouseID' column of the group's dataframe;
    each sub-group is labelled with its mouse ID.
    """
    return [
        expt_grp.subGroup(list(expts['expt']), label=mouse)
        for mouse, expts in expt_grp.dataframe(
            expt_grp, include_columns=['mouseID']).groupby('mouseID')
    ]


def _plot_behavior_panel(ax, expt_grp, mouse_grps, colors, markers,
                         behavior_fn, behavior_kwargs, activity_label):
    """Draw the per-mouse, per-condition-day behavior line plot on `ax`.

    Args:
        ax: matplotlib Axes to draw on.
        expt_grp: the parent experiment group; only its label() is used, as
            the panel title.
        mouse_grps: per-mouse sub-groups of `expt_grp` (see _split_by_mouse).
        colors: one color per mouse line.
        markers: marker cycle for the mouse lines.
        behavior_fn, behavior_kwargs: metric function and its kwargs passed
            through to plotting.plot_metric.
        activity_label: y-axis label passed through to plotting.plot_metric.
    """
    plotting.plot_metric(ax,
                         mouse_grps,
                         metric_fn=behavior_fn,
                         activity_kwargs=behavior_kwargs,
                         groupby=[['expt'], ['condition_day']],
                         plotby=['condition_day'],
                         plot_method='line',
                         ms=5,
                         activity_label=activity_label,
                         colors=colors,
                         markers=markers,
                         label_every_n=1,
                         label_groupby=False,
                         rotate_labels=False)
    ax.set_xlabel('Day in Condition')
    ax.set_title(expt_grp.label())
    ax.set_yticks([0, 0.1, 0.2, 0.3, 0.4, 0.5])
    # Nine x positions, labelled as days 1-3 repeated three times.
    ax.set_xticklabels(['1', '2', '3', '1', '2', '3', '1', '2', '3'])
    label_conditions(ax)
    ax.get_legend().set_visible(False)
    ax.tick_params(length=3, pad=2)


def main():
    """Plot fraction of licks in the reward zone per mouse for WT and Df.

    Loads the 'GOL' experiment groups, draws one panel per genotype in the
    top row of a 4x2 figure (the remaining axes are hidden) and saves it.
    NOTE(review): `filename` and `save_dir` are read from module scope —
    confirm they are defined before main() is called.
    """
    all_grps = df.loadExptGrps('GOL')

    WT_expt_grp = all_grps['WT_hidden_behavior_set']
    Df_expt_grp = all_grps['Df_hidden_behavior_set']

    behavior_fn = ra.fraction_licks_in_reward_zone
    behavior_kwargs = {}
    activity_label = 'Fraction of licks in reward zone'

    WT_colors = sns.light_palette(WT_color, 8)[::-1]
    Df_colors = sns.light_palette(Df_color, 7)[::-1]
    markers = ('o', 'v', '^', 'D', '*', 's')

    fig, axs = plt.subplots(4, 2, figsize=(8.5, 11))

    sns.despine(fig)

    wt_ax = axs[0, 0]
    df_ax = axs[0, 1]

    # Only the top row is used.
    for ax in list(axs.flat)[2:]:
        ax.set_visible(False)

    wt_expt_grps = _split_by_mouse(WT_expt_grp)
    df_expt_grps = _split_by_mouse(Df_expt_grp)

    _plot_behavior_panel(wt_ax, WT_expt_grp, wt_expt_grps, WT_colors, markers,
                         behavior_fn, behavior_kwargs, activity_label)
    _plot_behavior_panel(df_ax, Df_expt_grp, df_expt_grps, Df_colors, markers,
                         behavior_fn, behavior_kwargs, activity_label)

    misc.save_figure(fig, filename, save_dir=save_dir)
コード例 #41
0
ファイル: _old.py プロジェクト: josesho/bootstrapContrast
def contrastplot(
    data, x=None, y=None, idx=None, idcol=None,

    alpha=0.75, 
    axis_title_size=None,

    ci=95,
    contrastShareY=True,
    contrastEffectSizeLineStyle='solid',
    contrastEffectSizeLineColor='black',

    contrastYlim=None,
    contrastZeroLineStyle='solid', 
    contrastZeroLineColor='black', 
    connectPairs=True,

    effectSizeYLabel="Effect Size", 

    figsize=None, 
    floatContrast=True,
    floatSwarmSpacer=0.2,

    heightRatio=(1, 1),

    lineWidth=2,
    legend=True,
    legendFontSize=14,
    legendFontProps={},

    paired=False,
    pairedDeltaLineAlpha=0.3,
    pairedDeltaLineWidth=1.2,
    pal=None, 

    rawMarkerSize=8,
    rawMarkerType='o',
    reps=3000,
    
    showGroupCount=True,
    showCI=False, 
    showAllYAxes=False,
    showRawData=True,
    smoothboot=False, 
    statfunction=None, 

    summaryBar=False, 
    summaryBarColor='grey',
    summaryBarAlpha=0.25,

    summaryColour='black', 
    summaryLine=True, 
    summaryLineStyle='solid', 
    summaryLineWidth=0.25, 

    summaryMarkerSize=10, 
    summaryMarkerType='o',

    swarmShareY=True, 
    swarmYlim=None, 

    tickAngle=45,
    tickAlignment='right',

    violinOffset=0.375,
    violinWidth=0.2, 
    violinColor='k',

    xticksize=None,
    yticksize=None,

    **kwargs):

    '''Takes a pandas DataFrame and produces a contrast plot:
    either a Cummings hub-and-spoke plot or a Gardner-Altman contrast plot.
    Paired and unpaired options available.

    Keyword arguments:
        data: pandas DataFrame
            
        x: string
            column name containing categories to be plotted on the x-axis.

        y: string
            column name containing values to be plotted on the y-axis.

        idx: tuple
            flxible declaration of groupwise comparisons.

        idcol: string
            for paired plots.

        alpha: float
            alpha (transparency) of raw swarmed data points.
            
        axis_title_size=None
        ci=95
        contrastShareY=True
        contrastEffectSizeLineStyle='solid'
        contrastEffectSizeLineColor='black'
        contrastYlim=None
        contrastZeroLineStyle='solid'
        contrastZeroLineColor='black'
        effectSizeYLabel="Effect Size"
        figsize=None
        floatContrast=True
        floatSwarmSpacer=0.2
        heightRatio=(1,1)
        lineWidth=2
        legend=True
        legendFontSize=14
        legendFontProps={}
        paired=False
        pairedDeltaLineAlpha=0.3
        pairedDeltaLineWidth=1.2
        pal=None
        rawMarkerSize=8
        rawMarkerType='o'
        reps=3000
        showGroupCount=True
        showCI=False
        showAllYAxes=False
        showRawData=True
        smoothboot=False
        statfunction=None
        summaryBar=False
        summaryBarColor='grey'
        summaryBarAlpha=0.25
        summaryColour='black'
        summaryLine=True
        summaryLineStyle='solid'
        summaryLineWidth=0.25
        summaryMarkerSize=10
        summaryMarkerType='o'
        swarmShareY=True
        swarmYlim=None
        tickAngle=45
        tickAlignment='right'
        violinOffset=0.375
        violinWidth=0.2
        violinColor='k'
        xticksize=None
        yticksize=None

    Returns:
        An matplotlib Figure.
        Organization of figure Axes.
    '''

    # Check that `data` is a pandas dataframe
    if 'DataFrame' not in str(type(data)):
        raise TypeError("The object passed to the command is not not a pandas DataFrame.\
         Please convert it to a pandas DataFrame.")

    # make sure that at least x, y, and idx are specified.
    if x is None and y is None and idx is None:
        raise ValueError('You need to specify `x` and `y`, or `idx`. Neither has been specifed.')

    if x is None:
        # if x is not specified, assume this is a 'wide' dataset, with each idx being the name of a column.
        datatype='wide'
        # Check that the idx are legit columns.
        all_idx=np.unique([element for tupl in idx for element in tupl])
        # # melt the data.
        # data=pd.melt(data,value_vars=all_idx)
        # x='variable'
        # y='value'
    else:
        # if x is specified, assume this is a 'long' dataset with each row corresponding to one datapoint.
        datatype='long'
        # make sure y is not none.
        if y is None:
            raise ValueError("`paired` is false, but no y-column given.")
        # Calculate Ns.
        counts=data.groupby(x)[y].count()

    # Get and set levels of data[x]
    if paired is True:
        violinWidth=0.1
        # # Calculate Ns--which should be simply the number of rows in data.
        # counts=len(data)
        # is idcol supplied?
        if idcol is None and datatype=='long':
            raise ValueError('`idcol` has not been supplied but a paired plot is desired; please specify the `idcol`.')
        if idx is not None:
            # check if multi-plot or not
            if all(isinstance(element, str) for element in idx):
                # check that every idx is a column name.
                idx_not_in_cols=[n
                for n in idx
                if n not in data[x].unique()]
                if len(idx_not_in_cols)!=0:
                    raise ValueError(str(idx_not_in_cols)+" cannot be found in the columns of `data`.")
                # data_wide_cols=[n for n in idx if n in data.columns]
                # if idx is supplied but not a multiplot (ie single list or tuple)
                if len(idx) != 2:
                    raise ValueError(idx+" does not have length 2.")
                else:
                    tuple_in=(tuple(idx, ),)
                widthratio=[1]
            elif all(isinstance(element, tuple) for element in idx):
                # if idx is supplied, and it is a list/tuple of tuples or lists, we have a multiplot!
                idx_not_in_cols=[n
                for tup in idx
                for n in tup
                if n not in data[x].unique()]
                if len(idx_not_in_cols)!=0:
                    raise ValueError(str(idx_not_in_cols)+" cannot be found in the column "+x)
                # data_wide_cols=[n for tup in idx for n in tup if n in data.columns]
                if ( any(len(element) != 2 for element in idx) ):
                    # If any of the tuples does not contain exactly 2 elements.
                    raise ValueError(element+" does not have length 2.")
                # Make sure the widthratio of the seperate multiplot corresponds to how 
                # many groups there are in each one.
                tuple_in=idx
                widthratio=[]
                for i in tuple_in:
                    widthratio.append(len(i))
        elif idx is None:
            raise ValueError('Please specify idx.')
        showRawData=False # Just show lines, do not show data.
        showCI=False # wait till I figure out how to plot this for sns.barplot.
        if datatype=='long':
            if idx is None:
                ## If `idx` is not specified, just take the FIRST TWO levels alphabetically.
                tuple_in=tuple(np.sort(np.unique(data[x]))[0:2],)
            # pivot the dataframe if it is long!
            data_pivot=data.pivot_table(index = idcol, columns = x, values = y)

    elif paired is False:
        if idx is None:
            widthratio=[1]
            tuple_in=( tuple(data[x].unique()) ,)
            if len(tuple_in[0])>2:
                floatContrast=False
        else:
            if all(isinstance(element, str) for element in idx):
                # if idx is supplied but not a multiplot (ie single list or tuple)
                # check all every idx specified can be found in data[x]
                idx_not_in_x=[n for n in idx 
                if n not in data[x].unique()]
                if len(idx_not_in_x)!=0:
                    raise ValueError(str(idx_not_in_x)+" cannot be found in the column "+x)
                tuple_in=(idx, )
                widthratio=[1]
                if len(idx)>2:
                    floatContrast=False
            elif all(isinstance(element, tuple) for element in idx):
                # if idx is supplied, and it is a list/tuple of tuples or lists, we have a multiplot!
                idx_not_in_x=[n
                for tup in idx
                for n in tup
                if n not in data[x].unique()]
                if len(idx_not_in_x)!=0:
                    raise ValueError(str(idx_not_in_x)+" cannot be found in the column "+x)
                tuple_in=idx

                if ( any(len(element)>2 for element in tuple_in) ):
                    # if any of the tuples in idx has more than 2 groups, we turn set floatContrast as False.
                    floatContrast=False
                # Make sure the widthratio of the seperate multiplot corresponds to how 
                # many groups there are in each one.
                widthratio=[]
                for i in tuple_in:
                    widthratio.append(len(i))
            else:
                raise TypeError("The object passed to `idx` consists of a mixture of single strings and tuples. \
                    Please make sure that `idx` is either a tuple of column names, or a tuple of tuples, for plotting.")

    # Ensure summaryLine and summaryBar are not displayed together.
    if summaryLine is True and summaryBar is True:
        summaryBar=True
        summaryLine=False
    # Turn off summary line if floatContrast is true
    if floatContrast:
        summaryLine=False
    # initialise statfunction
    if statfunction == None:
        statfunction=np.mean
    # Create list to collect all the contrast DataFrames generated.
    contrastList=list()
    contrastListNames=list()

    # Setting color palette for plotting.
    if pal is None:
        if 'hue' in kwargs:
            colorCol=kwargs['hue']
            if colorCol not in data.columns:
                raise ValueError(colorCol+' is not a column name.')
            colGrps=data[colorCol].unique()#.tolist()
            plotPal=dict( zip( colGrps, sns.color_palette(n_colors=len(colGrps)) ) )
        else:
            if datatype=='long':
                colGrps=data[x].unique()#.tolist()
                plotPal=dict( zip( colGrps, sns.color_palette(n_colors=len(colGrps)) ) )
            if datatype=='wide':
                plotPal=np.repeat('k',len(data))
    else:
        if datatype=='long':
            plotPal=pal
        if datatype=='wide':
            plotPal=list(map(lambda x:pal[x], data[hue]))

    if swarmYlim is None:
        # get range of _selected groups_.
        # u = list()
        # for t in tuple_in:
        #     for i in np.unique(t):
        #         u.append(i)
        # u = np.unique(u)
        u=np.unique([element for tupl in tuple_in for element in tupl])
        if datatype=='long':
            tempdat=data[data[x].isin(u)]
            swarm_ylim=np.array([np.min(tempdat[y]), np.max(tempdat[y])])
        if datatype=='wide':
            allMin=list()
            allMax=list()
            for col in u:
                allMin.append(np.min(data[col]))
                allMax.append(np.max(data[col]))
            swarm_ylim=np.array( [np.min(allMin),np.max(allMax)] )
        swarm_ylim=np.round(swarm_ylim)
    else:
        swarm_ylim=np.array([swarmYlim[0],swarmYlim[1]])

    if summaryBar is True:
        lims=swarm_ylim
        # check that 0 lies within the desired limits.
        # if not, extend (upper or lower) limit to zero.
        if 0 not in range( int(round(lims[0])),int(round(lims[1])) ): # turn swarm_ylim to integer range.
            # check if all negative:.
            if lims[0]<0. and lims[1]<0.:
                swarm_ylim=np.array([np.min(lims),0.])
            # check if all positive.
            elif lims[0]>0. and lims[1]>0.:
                swarm_ylim=np.array([0.,np.max(lims)])

    if contrastYlim is not None:
        contrastYlim=np.array([contrastYlim[0],contrastYlim[1]])

    # plot params
    if axis_title_size is None:
        axis_title_size=27
    if yticksize is None:
        yticksize=22
    if xticksize is None:
        xticksize=22

    # Set clean style
    sns.set(style='ticks')

    axisTitleParams={'labelsize' : axis_title_size}
    xtickParams={'labelsize' : xticksize}
    ytickParams={'labelsize' : yticksize}
    svgParams={'fonttype' : 'none'}

    rc('axes', **axisTitleParams)
    rc('xtick', **xtickParams)
    rc('ytick', **ytickParams)
    rc('svg', **svgParams) 

    if figsize is None:
        if len(tuple_in)>2:
            figsize=(12,(12/np.sqrt(2)))
        else:
            figsize=(8,(8/np.sqrt(2)))
    
    # calculate CI.
    if ci<0 or ci>100:
        raise ValueError('ci should be between 0 and 100.')
    alpha_level=(100.-ci)/100.

    # Initialise figure, taking into account desired figsize.
    fig=plt.figure(figsize=figsize)

    # Initialise GridSpec based on `tuple_in` shape.
    gsMain=gridspec.GridSpec( 
        1, np.shape(tuple_in)[0], 
         # 1 row; columns based on number of tuples in tuple.
         width_ratios=widthratio,
         wspace=0 )

    for gsIdx, current_tuple in enumerate(tuple_in):
        #### FOR EACH TUPLE IN IDX
        if datatype=='long':
            plotdat=data[data[x].isin(current_tuple)]
            plotdat[x]=plotdat[x].astype("category")
            plotdat[x].cat.set_categories(
                current_tuple,
                ordered=True,
                inplace=True)
            plotdat.sort_values(by=[x])
            # # Drop all nans. 
            # plotdat.dropna(inplace=True)
            summaries=plotdat.groupby(x)[y].apply(statfunction)
        if datatype=='wide':
            plotdat=data[list(current_tuple)]
            summaries=statfunction(plotdat)
            plotdat=pd.melt(plotdat) ##### NOW I HAVE MELTED THE WIDE DATA.
            
        if floatContrast is True:
            # Use fig.add_subplot instead of plt.Subplot.
            ax_raw=fig.add_subplot(gsMain[gsIdx],
                frame_on=False)
            ax_contrast=ax_raw.twinx()
        else:
        # Create subGridSpec with 2 rows and 1 column.
            subGridSpec=gridspec.GridSpecFromSubplotSpec(2, 1,
                subplot_spec=gsMain[gsIdx],
                wspace=0)
            # Use plt.Subplot instead of fig.add_subplot
            ax_raw=plt.Subplot(fig,
                subGridSpec[0, 0],
                frame_on=False)
            ax_contrast=plt.Subplot(fig,
                subGridSpec[1, 0],
                sharex=ax_raw,
                frame_on=False)
        # Calculate the boostrapped contrast
        bscontrast=list()
        if paired is False:
            tempplotdat=plotdat[[x,y]] # only select the columns used for x and y plotting.
            for i in range (1, len(current_tuple)):
                # Note that you start from one. No need to do auto-contrast!
                # if datatype=='long':aas
                    tempbs=bootstrap_contrast(
                        data=tempplotdat.dropna(), 
                        x=x,
                        y=y,
                        idx=[current_tuple[0], current_tuple[i]],
                        statfunction=statfunction,
                        smoothboot=smoothboot,
                        alpha_level=alpha_level,
                        reps=reps)
                    bscontrast.append(tempbs)
                    contrastList.append(tempbs)
                    contrastListNames.append(current_tuple[i]+' vs. '+current_tuple[0])

        #### PLOT RAW DATA.
        ax_raw.set_ylim(swarm_ylim)
        # ax_raw.yaxis.set_major_locator(MaxNLocator(n_bins='auto'))
        # ax_raw.yaxis.set_major_locator(LinearLocator())
        if paired is False and showRawData is True:
            # Seaborn swarmplot doc says to set custom ylims first.
            sw=sns.swarmplot(
                data=plotdat, 
                x=x, y=y, 
                order=current_tuple, 
                ax=ax_raw, 
                alpha=alpha, 
                palette=plotPal,
                size=rawMarkerSize,
                marker=rawMarkerType,
                **kwargs)

            if floatContrast:
                # Get horizontal offset values.
                maxXBefore=max(sw.collections[0].get_offsets().T[0])
                minXAfter=min(sw.collections[1].get_offsets().T[0])
                xposAfter=maxXBefore+floatSwarmSpacer
                xAfterShift=minXAfter-xposAfter
                # shift the (second) swarmplot
                offsetSwarmX(sw.collections[1], -xAfterShift)
                # shift the tick.
                ax_raw.set_xticks([0.,1-xAfterShift])

        elif paired is True:
            if showRawData is True:
                sw=sns.swarmplot(data=plotdat, 
                    x=x, y=y, 
                    order=current_tuple, 
                    ax=ax_raw, 
                    alpha=alpha, 
                    palette=plotPal,
                    size=rawMarkerSize,
                    marker=rawMarkerType,
                **kwargs)
            if connectPairs is True:
                # Produce paired plot with lines.
                before=plotdat[plotdat[x]==current_tuple[0]][y].tolist()
                after=plotdat[plotdat[x]==current_tuple[1]][y].tolist()
                linedf=pd.DataFrame(
                    {'before':before,
                    'after':after}
                    )
                # to get color, need to loop thru each line and plot individually.
                for ii in range(0,len(linedf)):
                    ax_raw.plot( [0,0.25], [ linedf.loc[ii,'before'],
                                            linedf.loc[ii,'after'] ],
                                linestyle='solid',
                                linewidth=pairedDeltaLineWidth,
                                color=plotPal[current_tuple[0]],
                                alpha=pairedDeltaLineAlpha,
                               )
                ax_raw.set_xlim(-0.25,0.5)
                ax_raw.set_xticks([0,0.25])
                ax_raw.set_xticklabels([current_tuple[0],current_tuple[1]])

        # if swarmYlim is None:
        #     # if swarmYlim was not specified, tweak the y-axis 
        #     # to show all the data without losing ticks and range.
        #     ## Get all yticks.
        #     axxYTicks=ax_raw.yaxis.get_majorticklocs()
        #     ## Get ytick interval.
        #     YTickInterval=axxYTicks[1]-axxYTicks[0]
        #     ## Get current ylim
        #     currentYlim=ax_raw.get_ylim()
        #     ## Extend ylim by adding a fifth of the tick interval as spacing at both ends.
        #     ax_raw.set_ylim(
        #         currentYlim[0]-(YTickInterval/5),
        #         currentYlim[1]+(YTickInterval/5)
        #         )
        #     ax_raw.yaxis.set_major_locator(MaxNLocator(nbins='auto'))
        # ax_raw.yaxis.set_major_locator(MaxNLocator(nbins='auto'))
        # ax_raw.yaxis.set_major_locator(LinearLocator())

        if summaryBar is True:
            if paired is False:
                bar_raw=sns.barplot(
                    x=summaries.index.tolist(),
                    y=summaries.values,
                    facecolor=summaryBarColor,
                    ax=ax_raw,
                    alpha=summaryBarAlpha)
                if floatContrast is True:
                    maxSwarmSpan=2/10.
                    xlocs=list()
                    for i, bar in enumerate(bar_raw.patches):
                        x_width=bar.get_x()
                        width=bar.get_width()
                        centre=x_width + (width/2.)
                        if i == 0:
                            bar.set_x(centre-maxSwarmSpan/2.)
                            xlocs.append(centre)
                        else:
                            bar.set_x(centre-xAfterShift-maxSwarmSpan/2.)
                            xlocs.append(centre-xAfterShift)
                        bar.set_width(maxSwarmSpan)
                    ax_raw.set_xticks(xlocs) # make sure xticklocs match the barplot.
                elif floatContrast is False:
                    maxSwarmSpan=4/10.
                    xpos=ax_raw.xaxis.get_majorticklocs()
                    for i, bar in enumerate(bar_raw.patches):
                        bar.set_x(xpos[i]-maxSwarmSpan/2.)
                        bar.set_width(maxSwarmSpan)
            else:
                # if paired is true
                ax_raw.bar([0,0.25], 
                    [ statfunction(plotdat[current_tuple[0]]),
                    statfunction(plotdat[current_tuple[1]]) ],
                    color=summaryBarColor,
                    alpha=0.5,
                    width=0.05)
                ## Draw zero reference line.
                ax_raw.add_artist(Line2D(
                    (ax_raw.xaxis.get_view_interval()[0],
                     ax_raw.xaxis.get_view_interval()[1]),
                    (0,0),
                    color='k', linewidth=1.25)
                                 )

        if summaryLine is True:
            if paired is True:
                xdelta=0
            else:
                xdelta=summaryLineWidth
            for i, m in enumerate(summaries):
                ax_raw.plot(
                    (i-xdelta, 
                    i+xdelta), # x-coordinates
                    (m, m),
                    color=summaryColour, 
                    linestyle=summaryLineStyle)

        if showCI is True:
                sns.barplot(
                    data=plotdat, 
                    x=x, y=y, 
                    ax=ax_raw, 
                    alpha=0, ci=95)

        ax_raw.set_xlabel("")
        if floatContrast is False:
            fig.add_subplot(ax_raw)

        #### PLOT CONTRAST DATA.
        if len(current_tuple)==2:
            if paired is False:
                # Plot the CIs on the contrast axes.
                plotbootstrap(sw.collections[1],
                              bslist=tempbs,
                              ax=ax_contrast, 
                              violinWidth=violinWidth,
                              violinOffset=violinOffset,
                              markersize=summaryMarkerSize,
                              marker=summaryMarkerType,
                              offset=floatContrast,
                              color=violinColor,
                              linewidth=1)
            else:
                bootsDelta = bootstrap(
                    plotdat[current_tuple[1]]-plotdat[current_tuple[0]],
                    statfunction=statfunction,
                    smoothboot=smoothboot,
                    alpha_level=alpha_level,
                    reps=reps)
                contrastList.append(bootsDelta)
                contrastListNames.append(current_tuple[1]+' vs. '+current_tuple[0])
                summDelta = bootsDelta['summary']
                lowDelta = bootsDelta['bca_ci_low']
                highDelta = bootsDelta['bca_ci_high']

                if floatContrast:
                    xpos=0.375
                else:
                    xpos=0.25

                # Plot the summary measure.
                ax_contrast.plot(xpos, bootsDelta['summary'],
                         marker=summaryMarkerType,
                         markerfacecolor='k',
                         markersize=summaryMarkerSize,
                         alpha=0.75
                        )
                # Plot the CI.
                ax_contrast.plot([xpos, xpos],
                         [lowDelta, highDelta],
                         color='k',
                         alpha=0.75,
                         # linewidth=1,
                         linestyle='solid'
                        )
                
                # Plot the violin-plot.
                v = ax_contrast.violinplot(bootsDelta['stat_array'], [xpos], 
                                           widths = violinWidth, 
                                           showextrema = False, 
                                           showmeans = False)
                halfviolin(v, half = 'right', color = 'k')

            if floatContrast:
                # Set reference lines
                if paired is False:
                    ## First get leftmost limit of left reference group
                    xtemp, _=np.array(sw.collections[0].get_offsets()).T
                    leftxlim=xtemp.min()
                    ## Then get leftmost limit of right test group
                    xtemp, _=np.array(sw.collections[1].get_offsets()).T
                    rightxlim=xtemp.min()
                    ref=tempbs['summary']
                else:
                    leftxlim=0
                    rightxlim=0.25
                    ref=bootsDelta['summary']
                    ax_contrast.set_xlim(-0.25, 0.5) # does this work?

                ## zero line
                ax_contrast.hlines(0,                   # y-coordinates
                                leftxlim, 3.5,       # x-coordinates, start and end.
                                linestyle=contrastZeroLineStyle,
                                linewidth=1,
                                color=contrastZeroLineColor)

                ## effect size line
                ax_contrast.hlines(ref, 
                                rightxlim, 3.5,        # x-coordinates, start and end.
                                linestyle=contrastEffectSizeLineStyle,
                                linewidth=1,
                                color=contrastEffectSizeLineColor)


                if paired is False:
                    es=float(tempbs['summary'])
                    refSum=tempbs['statistic_ref']
                else:
                    es=float(bootsDelta['summary'])
                    refSum=statfunction(plotdat[current_tuple[0]])
                ## If the effect size is positive, shift the right axis up.
                if es>0:
                    rightmin=ax_raw.get_ylim()[0]-es
                    rightmax=ax_raw.get_ylim()[1]-es
                ## If the effect size is negative, shift the right axis down.
                elif es<0:
                    rightmin=ax_raw.get_ylim()[0]+es
                    rightmax=ax_raw.get_ylim()[1]+es
                ax_contrast.set_ylim(rightmin, rightmax)

                if gsIdx>0:
                    ax_contrast.set_ylabel('')
                align_yaxis(ax_raw, refSum, ax_contrast, 0.)

            else:
                # Set bottom axes ybounds
                if contrastYlim is not None:
                    ax_contrast.set_ylim(contrastYlim)

                if paired is False:
                    # Set xlims so everything is properly visible!
                    swarm_xbounds=ax_raw.get_xbound()
                    ax_contrast.set_xbound(swarm_xbounds[0] -(summaryLineWidth * 1.1), 
                        swarm_xbounds[1] + (summaryLineWidth * 1.1))
                else:
                    ax_contrast.set_xlim(-0.05,0.25+violinWidth)

        else:
            # Plot the CIs on the bottom axes.
            plotbootstrap_hubspoke(
                bslist=bscontrast,
                ax=ax_contrast,
                violinWidth=violinWidth,
                violinOffset=violinOffset,
                markersize=summaryMarkerSize,
                marker=summaryMarkerType,
                linewidth=lineWidth)

        if floatContrast is False:
            fig.add_subplot(ax_contrast)

        if gsIdx>0:
            ax_raw.set_ylabel('')
            ax_contrast.set_ylabel('')

    # Turn contrastList into a pandas DataFrame,
    contrastList=pd.DataFrame(contrastList).T
    contrastList.columns=contrastListNames

    # Get number of axes in figure for aesthetic tweaks.
    axesCount=len(fig.get_axes())
    for i in range(0, axesCount, 2):
        # Set new tick labels.
        # The tick labels belong to the SWARM axes
        # for both floating and non-floating plots.
        # This is because `sharex` was invoked.
        axx=fig.axes[i]
        newticklabs=list()
        for xticklab in axx.xaxis.get_ticklabels():
            t=xticklab.get_text()
            if paired:
                N=str(counts)
            else:
                N=str(counts.ix[t])

            if showGroupCount:
                newticklabs.append(t+' n='+N)
            else:
                newticklabs.append(t)
            axx.set_xticklabels(
                newticklabs,
                rotation=tickAngle,
                horizontalalignment=tickAlignment)

    ## Loop thru SWARM axes for aesthetic touchups.
    for i in range(0, axesCount, 2):
        axx=fig.axes[i]

        if floatContrast is False:
            axx.xaxis.set_visible(False)
            sns.despine(ax=axx, trim=True, bottom=False, left=False)
        else:
            sns.despine(ax=axx, trim=True, bottom=True, left=True)

        if i==0:
            drawback_y(axx)

        if i!=axesCount-2 and 'hue' in kwargs:
            # If this is not the final swarmplot, remove the hue legend.
            axx.legend().set_visible(False)

        if showAllYAxes is False:
            if i in range(2, axesCount):
                axx.yaxis.set_visible(False)
            else:
                # Draw back the lines for the relevant y-axes.
                # Not entirely sure why I have to do this.
                drawback_y(axx)
        else:
            drawback_y(axx)

        # Add zero reference line for swarmplots with bars.
        if summaryBar is True:
            axx.add_artist(Line2D(
                (axx.xaxis.get_view_interval()[0], 
                    axx.xaxis.get_view_interval()[1]), 
                (0,0),
                color='black', linewidth=0.75
                )
            )
        
        if legend is False:
            axx.legend().set_visible(False)
        else:
            if i==axesCount-2: # the last (rightmost) swarm axes.
                axx.legend(loc='top right',
                    bbox_to_anchor=(1.1,1.0),
                    fontsize=legendFontSize,
                    **legendFontProps)

    ## Loop thru the CONTRAST axes and perform aesthetic touch-ups.
    ## Get the y-limits:
    for j,i in enumerate(range(1, axesCount, 2)):
        axx=fig.get_axes()[i]

        if floatContrast is False:
            xleft, xright=axx.xaxis.get_view_interval()
            # Draw zero reference line.
            axx.hlines(y=0,
                xmin=xleft-1, 
                xmax=xright+1,
                linestyle=contrastZeroLineStyle,
                linewidth=0.75,
                color=contrastZeroLineColor)
            # reset view interval.
            axx.set_xlim(xleft, xright)

            if showAllYAxes is False:
                if i in range(2, axesCount):
                    axx.yaxis.set_visible(False)
                else:
                    # Draw back the lines for the relevant y-axes, only is axesCount is 2.
                    # Not entirely sure why I have to do this.
                    if axesCount==2:
                        drawback_y(axx)

            sns.despine(ax=axx, 
                top=True, right=True, 
                left=False, bottom=False, 
                trim=True)
            if j==0 and axesCount==2:
                # Draw back x-axis lines connecting ticks.
                drawback_x(axx)
            # Rotate tick labels.
            rotateTicks(axx,tickAngle,tickAlignment)

        elif floatContrast is True:
            if paired is True:
                # Get the bootstrapped contrast range.
                lower=np.min(contrastList.ix['stat_array',j])
                upper=np.max(contrastList.ix['stat_array',j])
            else:
                lower=np.min(contrastList.ix['diffarray',j])
                upper=np.max(contrastList.ix['diffarray',j])
            meandiff=contrastList.ix['summary', j]

            ## Make sure we have zero in the limits.
            if lower>0:
                lower=0.
            if upper<0:
                upper=0.

            ## Get the tick interval from the left y-axis.
            leftticks=fig.get_axes()[i-1].get_yticks()
            tickstep=leftticks[1] -leftticks[0]

            ## First re-draw of axis with new tick interval
            axx.yaxis.set_major_locator(MultipleLocator(base=tickstep))
            newticks1=axx.get_yticks()

            ## Obtain major ticks that comfortably encompass lower and upper.
            newticks2=list()
            for a,b in enumerate(newticks1):
                if (b >= lower and b <= upper):
                    # if the tick lies within upper and lower, take it.
                    newticks2.append(b)
            # if the meandiff falls outside of the newticks2 set, add a tick in the right direction.
            if np.max(newticks2)<meandiff:
                ind=np.where(newticks1 == np.max(newticks2))[0][0] # find out the max tick index in newticks1.
                newticks2.append( newticks1[ind+1] )
            elif meandiff<np.min(newticks2):
                ind=np.where(newticks1 == np.min(newticks2))[0][0] # find out the min tick index in newticks1.
                newticks2.append( newticks1[ind-1] )
            newticks2=np.array(newticks2)
            newticks2.sort()

            ## Second re-draw of axis to shrink it to desired limits.
            axx.yaxis.set_major_locator(FixedLocator(locs=newticks2))
            
            ## Despine the axes.
            sns.despine(ax=axx, trim=True, 
                bottom=False, right=False,
                left=True, top=True)

    # Normalize bottom/right Contrast axes to each other for Cummings hub-and-spoke plots.
    if (axesCount>2 and 
        contrastShareY is True and 
        floatContrast is False):

        # Set contrast ylim as max ticks of leftmost swarm axes.
        if contrastYlim is None:
            lower=list()
            upper=list()
            for c in range(0,len(contrastList.columns)):
                lower.append( np.min(contrastList.ix['bca_ci_low',c]) )
                upper.append( np.max(contrastList.ix['bca_ci_high',c]) )
            lower=np.min(lower)
            upper=np.max(upper)
        else:
            lower=contrastYlim[0]
            upper=contrastYlim[1]

        normalizeContrastY(fig, 
            contrast_ylim = contrastYlim, 
            show_all_yaxes = showAllYAxes)

    # Zero gaps between plots on the same row, if floatContrast is False
    if (floatContrast is False and showAllYAxes is False):
        gsMain.update(wspace=0.)

    else:    
        # Tight Layout!
        gsMain.tight_layout(fig)
    
    # And we're all done.
    rcdefaults() # restore matplotlib defaults.
    sns.set() # restore seaborn defaults.
    return fig, contrastList
コード例 #42
0
def plot_learning_rate(
        output_dir=expanduser('~/output/recommender/learning_rate'),
        output_file=None):
    """Plot test-set RMSE against epoch for several learning rates.

    Reads ``results_netflix.json`` and ``results_10m.json`` from
    ``output_dir``. Each JSON object maps a learning-rate string to a
    record from which ``'time'`` and ``'rmse'`` sequences are read.
    Only runs with a learning rate above 0.6 are drawn. The figure has two
    panels (MovieLens 10M on the left, Netflix on the right) and is saved
    as a PDF.

    Parameters
    ----------
    output_dir : str
        Directory containing the two JSON result files.
    output_file : str or None
        Path of the saved figure. Defaults to
        ``~/output/icml/learning_rate.pdf`` (expanded).
    """
    if output_file is None:
        output_file = expanduser('~/output/icml/learning_rate.pdf')

    with open(join(output_dir, 'results_netflix.json'), 'r') as f:
        data_netflix = json.load(f)
    with open(join(output_dir, 'results_10m.json'), 'r') as f:
        data_10m = json.load(f)

    # Shift every Netflix run so it starts at the earliest start time among
    # all runs (never later than 400). The offset is computed once per run,
    # before any sample is modified, and applied to every time sample.
    min_time = min([400] + [run['time'][0] for run in data_netflix.values()])
    for run in data_netflix.values():
        offset = run['time'][0] - min_time
        run['time'] = [t - offset for t in run['time']]
    # NOTE(review): the aligned times are not used by the plots below, whose
    # x axis is epochs generated with np.linspace(0, 20, ...).

    fig = plt.figure()
    fig.subplots_adjust(bottom=0.33, top=0.99, right=0.98)
    fig.set_figwidth(3.25653379549)
    fig.set_figheight(1.25)

    ax = {}
    gs = gridspec.GridSpec(1, 2)
    palette = sns.cubehelix_palette(10,
                                    start=0,
                                    rot=3,
                                    hue=1,
                                    dark=.3,
                                    light=.7,
                                    reverse=False)
    grey = (.6, .6, .6)

    for j, data in enumerate([data_10m, data_netflix]):
        ax[j] = fig.add_subplot(gs[j])
        # `i` counts every learning rate in sorted order, including the
        # skipped ones, so each rate keeps a fixed palette colour.
        # NOTE(review): palette holds 10 colours; a results file with more
        # than 10 learning rates would raise IndexError here.
        for i, learning_rate in enumerate(sorted(data, key=float)):
            if float(learning_rate) > .6:
                this_data = data[learning_rate]
                ax[j].plot(np.linspace(0., 20, len(this_data['rmse'])),
                           this_data['rmse'],
                           label='%.2f' % float(learning_rate),
                           color=palette[i],
                           zorder=int(100 * float(learning_rate)))
                ax[j].set_xscale('log')
        sns.despine(ax=ax[j])

        ax[j].spines['left'].set_color(grey)
        ax[j].spines['bottom'].set_color(grey)
        ax[j].xaxis.set_tick_params(color=grey, which='both')
        ax[j].yaxis.set_tick_params(color=grey, which='both')
        ax[j].tick_params(axis='y', labelsize=6)

    ax[0].set_ylabel('RMSE on test set')
    ax[0].set_xlabel('Epoch', ha='left', va='top')
    ax[0].xaxis.set_label_coords(-.18, -0.055)

    ax[0].set_xlim([.1, 20])
    ax[0].set_xticks([1, 10, 20])
    ax[0].set_xticklabels(['1', '10', '20'])
    ax[1].set_xlim([.1, 20])
    ax[1].set_xticks([.1, 1, 10, 20])
    ax[1].set_xticklabels(['.1', '1', '10', '20'])

    ax[0].annotate('MovieLens 10M',
                   xy=(.95, .8),
                   ha='right',
                   xycoords='axes fraction')
    ax[1].annotate('Netflix',
                   xy=(.95, .8),
                   ha='right',
                   xycoords='axes fraction')

    ax[0].set_ylim([0.795, 0.863])
    ax[1].set_ylim([0.93, 0.983])
    ax[0].legend(ncol=4,
                 loc='upper left',
                 bbox_to_anchor=(0., -.13),
                 fontsize=6,
                 numpoints=1,
                 columnspacing=.3,
                 frameon=False)
    ax[0].annotate('Learning rate $\\beta$',
                   xy=(1.6, -.38),
                   xycoords='axes fraction')
    ltext = ax[0].get_legend().get_texts()
    plt.setp(ltext, fontsize=7)
    fig.savefig(output_file)
コード例 #43
0
def plot_learning_rate(output_dir=expanduser('~/output/recommender/learning_rate'),
                       output_file=None):
    """Plot test RMSE vs. epoch for learning rates > 0.6 on MovieLens 10M and Netflix.

    ``output_dir`` must contain ``results_netflix.json`` and ``results_10m.json``.
    Each maps a learning-rate string to a record with ``'time'`` and ``'rmse'``
    sequences. The two-panel figure is written to ``output_file``, which
    defaults to ``~/output/icml/learning_rate.pdf``.
    """
    if output_file is None:
        output_file = expanduser('~/output/icml/learning_rate.pdf')
    with open(join(output_dir, 'results_netflix.json'), 'r') as f:
        data_netflix = json.load(f)
    with open(join(output_dir, 'results_10m.json'), 'r') as f:
        data_10m = json.load(f)

    # Align each Netflix run to start at the earliest start time (at most 400).
    # The offset is taken once per run, before any sample is shifted.
    min_time = min([400] + [run['time'][0] for run in data_netflix.values()])
    for run in data_netflix.values():
        offset = run['time'][0] - min_time
        run['time'] = [t - offset for t in run['time']]
    # NOTE(review): these aligned times are not used below; x is epochs via linspace.

    fig = plt.figure()
    fig.subplots_adjust(bottom=0.33, top=0.99, right=0.98)

    fig.set_figwidth(3.25653379549)
    fig.set_figheight(1.25)
    ax = {}
    gs = gridspec.GridSpec(1, 2)
    palette = sns.cubehelix_palette(10, start=0, rot=3, hue=1, dark=.3,
                                    light=.7,
                                    reverse=False)
    grey = (.6, .6, .6)
    for j, data in enumerate([data_10m, data_netflix]):
        ax[j] = fig.add_subplot(gs[j])
        # `i` indexes the palette over ALL sorted rates so colours stay stable.
        # NOTE(review): more than 10 learning rates would overrun the 10-colour palette.
        for i, learning_rate in enumerate(sorted(data, key=float)):
            if float(learning_rate) > .6:
                this_data = data[learning_rate]
                ax[j].plot(np.linspace(0., 20, len(this_data['rmse'])),
                           this_data['rmse'],
                           label='%.2f' % float(learning_rate),
                           color=palette[i],
                           zorder=int(100 * float(learning_rate)))
                ax[j].set_xscale('log')
        sns.despine(ax=ax[j])

        ax[j].spines['left'].set_color(grey)
        ax[j].spines['bottom'].set_color(grey)
        ax[j].xaxis.set_tick_params(color=grey, which='both')
        ax[j].yaxis.set_tick_params(color=grey, which='both')
        ax[j].tick_params(axis='y', labelsize=6)

    ax[0].set_ylabel('RMSE on test set')
    ax[0].set_xlabel('Epoch', ha='left', va='top')
    ax[0].xaxis.set_label_coords(-.18, -0.055)

    ax[0].set_xlim([.1, 20])
    ax[0].set_xticks([1, 10, 20])
    ax[0].set_xticklabels(['1', '10', '20'])
    ax[1].set_xlim([.1, 20])
    ax[1].set_xticks([.1, 1, 10, 20])
    ax[1].set_xticklabels(['.1', '1', '10', '20'])

    ax[0].annotate('MovieLens 10M', xy=(.95, .8), ha='right', xycoords='axes fraction')
    ax[1].annotate('Netflix', xy=(.95, .8), ha='right', xycoords='axes fraction')

    ax[0].set_ylim([0.795, 0.863])
    ax[1].set_ylim([0.93, 0.983])
    ax[0].legend(ncol=4, loc='upper left', bbox_to_anchor=(0., -.13), fontsize=6, numpoints=1, columnspacing=.3, frameon=False)
    ax[0].annotate('Learning rate $\\beta$', xy=(1.6, -.38), xycoords='axes fraction')
    ltext = ax[0].get_legend().get_texts()
    plt.setp(ltext, fontsize=7)
    fig.savefig(output_file)
コード例 #44
0
ファイル: sandbox.py プロジェクト: josesho/bootstrapContrast
def contrastplot_test(
    data, x, y, idx=None, 
    
    alpha=0.75, 
    axis_title_size=None,

    barWidth=5,

    contrastShareY=True,
    contrastEffectSizeLineStyle='solid',
    contrastEffectSizeLineColor='black',
    contrastYlim=None,
    contrastZeroLineStyle='solid', 
    contrastZeroLineColor='black', 

    effectSizeYLabel="Effect Size", 

    figsize=None, 
    floatContrast=True,
    floatSwarmSpacer=0.2,

    heightRatio=(1, 1),

    idcol=None,

    lineWidth=2,
    legend=True,
    legendFontSize=14,
    legendFontProps={},

    paired=False,
    pal=None, 

    rawMarkerSize=8,
    rawMarkerType='o',
    reps=3000,
    
    showGroupCount=True,
    show95CI=False, 
    showAllYAxes=False,
    showRawData=True,
    smoothboot=False, 
    statfunction=None, 

    summaryBar=False, 
    summaryBarColor='grey',
    summaryBarAlpha=0.25,

    summaryColour='black', 
    summaryLine=True, 
    summaryLineStyle='solid', 
    summaryLineWidth=0.25, 

    summaryMarkerSize=10, 
    summaryMarkerType='o',

    swarmShareY=True, 
    swarmYlim=None, 

    tickAngle=45,
    tickAlignment='right',

    violinOffset=0.375,
    violinWidth=0.2, 
    violinColor='k',

    xticksize=None,
    yticksize=None,

    **kwargs):

    '''Takes a pandas dataframe and produces a contrast plot:
    either a Cummings hub-and-spoke plot or a Gardner-Altman contrast plot.
    -----------------------------------------------------------------------
    Description of flags upcoming.'''

    # Check that `data` is a pandas dataframe
    if 'DataFrame' not in str(type(data)):
        raise TypeError("The object passed to the command is not not a pandas DataFrame.\
         Please convert it to a pandas DataFrame.")

    # Get and set levels of data[x]    
    if idx is None:
        widthratio=[1]
        allgrps=np.sort(data[x].unique())
        if paired:
            # If `idx` is not specified, just take the FIRST TWO levels alphabetically.
            tuple_in=tuple(allgrps[0:2],)
        else:
            # No idx is given, so all groups are compared to the first one in the DataFrame column.
            tuple_in=(tuple(allgrps), )
            if len(allgrps)>2:
                floatContrast=False

    else:
        if all(isinstance(element, str) for element in idx):
            # if idx is supplied but not a multiplot (ie single list or tuple) 
            tuple_in=(idx, )
            widthratio=[1]
            if len(idx)>2:
                floatContrast=False
        elif all(isinstance(element, tuple) for element in idx):
            # if idx is supplied, and it is a list/tuple of tuples or lists, we have a multiplot!
            tuple_in=idx
            if ( any(len(element)>2 for element in tuple_in) ):
                # if any of the tuples in idx has more than 2 groups, we turn set floatContrast as False.
                floatContrast=False
            # Make sure the widthratio of the seperate multiplot corresponds to how 
            # many groups there are in each one.
            widthratio=[]
            for i in tuple_in:
                widthratio.append(len(i))
        else:
            raise TypeError("The object passed to `idx` consists of a mixture of single strings and tuples. \
                Please make sure that `idx` is either a tuple of column names, or a tuple of tuples for plotting.")

    # initialise statfunction
    if statfunction == None:
        statfunction=np.mean

    # Create list to collect all the contrast DataFrames generated.
    contrastList=list()
    contrastListNames=list()
    # # Calculate the bootstraps according to idx.
    # for ix, current_tuple in enumerate(tuple_in):
    #     bscontrast=list()
    #     for i in range (1, len(current_tuple)):
    #     # Note that you start from one. No need to do auto-contrast!
    #         tempbs=bootstrap_contrast(
    #             data=data,
    #             x=x,
    #             y=y,
    #             idx=[current_tuple[0], current_tuple[i]],
    #             statfunction=statfunction,
    #             smoothboot=smoothboot,
    #             reps=reps)
    #         bscontrast.append(tempbs)
    #         contrastList.append(tempbs)
    #         contrastListNames.append(current_tuple[i]+' vs. '+current_tuple[0])

    # Setting color palette for plotting.
    if pal is None:
        if 'hue' in kwargs:
            colorCol=kwargs['hue']
            colGrps=data[colorCol].unique()
            nColors=len(colGrps)
        else:
            colorCol=x
            colGrps=data[x].unique()
            nColors=len([element for tupl in tuple_in for element in tupl])
        plotPal=dict( zip( colGrps, sns.color_palette(n_colors=nColors) ) )
    else:
        plotPal=pal

    # Ensure summaryLine and summaryBar are not displayed together.
    if summaryLine is True and summaryBar is True:
        summaryBar=True
        summaryLine=False
    # Turn off summary line if floatContrast is true
    if floatContrast:
        summaryLine=False

    if swarmYlim is None:
        # get range of _selected groups_.
        u = list()
        for t in idx:
            for i in np.unique(t):
                u.append(i)
        u = np.unique(u)
        tempdat=data[data[x].isin(u)]
        swarm_ylim=np.array([np.min(tempdat[y]), np.max(tempdat[y])])
    else:
        swarm_ylim=np.array([swarmYlim[0],swarmYlim[1]])

    if contrastYlim is not None:
        contrastYlim=np.array([contrastYlim[0],contrastYlim[1]])

    barWidth=barWidth/1000 # Not sure why have to reduce the barwidth by this much! 
    if showRawData is True:
        maxSwarmSpan=0.25
    else:
        maxSwarmSpan=barWidth

    # Expand the ylim in both directions.
    ## Find half of the range of swarm_ylim.
    swarmrange=swarm_ylim[1] -swarm_ylim[0]
    pad=0.1*swarmrange
    x2=np.array([swarm_ylim[0]-pad, swarm_ylim[1]+pad])
    swarm_ylim=x2

    # plot params
    if axis_title_size is None:
        axis_title_size=25
    if yticksize is None:
        yticksize=18
    if xticksize is None:
        xticksize=18

    # Set clean style
    sns.set(style='ticks')

    axisTitleParams={'labelsize' : axis_title_size}
    xtickParams={'labelsize' : xticksize}
    ytickParams={'labelsize' : yticksize}
    svgParams={'fonttype' : 'none'}

    rc('axes', **axisTitleParams)
    rc('xtick', **xtickParams)
    rc('ytick', **ytickParams)
    rc('svg', **svgParams) 

    if figsize is None:
        if len(tuple_in)>2:
            figsize=(12,(12/np.sqrt(2)))
        else:
            figsize=(8,(8/np.sqrt(2)))
    
    # Initialise figure, taking into account desired figsize.
    fig=plt.figure(figsize=figsize)

    # Initialise GridSpec based on `tuple_in` shape.
    gsMain=gridspec.GridSpec( 
        1, np.shape(tuple_in)[0], 
         # 1 row; columns based on number of tuples in tuple.
         width_ratios=widthratio,
         wspace=0 )

    for gsIdx, current_tuple in enumerate(tuple_in):
        #### FOR EACH TUPLE IN IDX
        plotdat=data[data[x].isin(current_tuple)]
        plotdat[x]=plotdat[x].astype("category")
        plotdat[x].cat.set_categories(
            current_tuple,
            ordered=True,
            inplace=True)
        plotdat.sort_values(by=[x])
        # Drop all nans. 
        plotdat=plotdat.dropna()

        # Calculate summaries.
        summaries=plotdat.groupby([x],sort=True)[y].apply(statfunction)

        if floatContrast is True:
            # Use fig.add_subplot instead of plt.Subplot
            ax_raw=fig.add_subplot(gsMain[gsIdx],
                frame_on=False)
            ax_contrast=ax_raw.twinx()
        else:
        # Create subGridSpec with 2 rows and 1 column.
            subGridSpec=gridspec.GridSpecFromSubplotSpec(2, 1,
                subplot_spec=gsMain[gsIdx],
                wspace=0)
            # Use plt.Subplot instead of fig.add_subplot
            ax_raw=plt.Subplot(fig,
                subGridSpec[0, 0],
                frame_on=False)
            ax_contrast=plt.Subplot(fig,
                subGridSpec[1, 0],
                sharex=ax_raw,
                frame_on=False)
        # Calculate the boostrapped contrast
        bscontrast=list()
        for i in range (1, len(current_tuple)):
        # Note that you start from one. No need to do auto-contrast!
            tempbs=bootstrap_contrast(
                data=data,
                x=x,
                y=y,
                idx=[current_tuple[0], current_tuple[i]],
                statfunction=statfunction,
                smoothboot=smoothboot,
                reps=reps)
            bscontrast.append(tempbs)
            contrastList.append(tempbs)
            contrastListNames.append(current_tuple[i]+' vs. '+current_tuple[0])
        
        #### PLOT RAW DATA.
        if showRawData is True:
            # Seaborn swarmplot doc says to set custom ylims first.
            ax_raw.set_ylim(swarm_ylim)
            sw=sns.swarmplot(
                data=plotdat, 
                x=x, y=y, 
                order=current_tuple, 
                ax=ax_raw, 
                alpha=alpha, 
                palette=plotPal,
                size=rawMarkerSize,
                marker=rawMarkerType,
                **kwargs)

        if summaryBar is True:
            bar_raw=sns.barplot(
                x=summaries.index.tolist(),
                y=summaries.values,
                facecolor=summaryBarColor,
                ax=ax_raw,
                alpha=summaryBarAlpha)
        
        if floatContrast:
            # Get horizontal offset values.
            maxXBefore=max(sw.collections[0].get_offsets().T[0])
            minXAfter=min(sw.collections[1].get_offsets().T[0])
            xposAfter=maxXBefore+floatSwarmSpacer
            xAfterShift=minXAfter-xposAfter
            # shift the swarmplots
            offsetSwarmX(sw.collections[1], -xAfterShift)

            ## get swarm with largest span, set as max width of each barplot.
            for i, bar in enumerate(bar_raw.patches):
                x_width=bar.get_x()
                width=bar.get_width()
                centre=x_width + (width/2.)
                if i == 0:
                    bar.set_x(centre-maxSwarmSpan/2.)
                else:
                    bar.set_x(centre-xAfterShift-maxSwarmSpan/2.)
                bar.set_width(maxSwarmSpan)

            ## Set the ticks locations for ax_raw.
            ax_raw.xaxis.set_ticks((0, xposAfter))
            firstTick=ax_raw.xaxis.get_ticklabels()[0].get_text()
            secondTick=ax_raw.xaxis.get_ticklabels()[1].get_text()
            ax_raw.set_xticklabels([firstTick,#+' n='+count[firstTick],
                                     secondTick],#+' n='+count[secondTick]],
                                   rotation=tickAngle,
                                   horizontalalignment=tickAlignment)

        if summaryLine is True:
            for i, m in enumerate(summaries):
                ax_raw.plot(
                    (i -summaryLineWidth, 
                    i + summaryLineWidth), # x-coordinates
                    (m, m),
                    color=summaryColour, 
                    linestyle=summaryLineStyle)

        if show95CI is True:
                sns.barplot(
                    data=plotdat, 
                    x=x, y=y, 
                    ax=ax_raw, 
                    alpha=0, ci=95)

        ax_raw.set_xlabel("")
        if floatContrast is False:
            fig.add_subplot(ax_raw)

        #### PLOT CONTRAST DATA.
        if len(current_tuple)==2:
            # Plot the CIs on the contrast axes.
            plotbootstrap(sw.collections[1],
                          bslist=tempbs,
                          ax=ax_contrast, 
                          violinWidth=violinWidth,
                          violinOffset=violinOffset,
                          markersize=summaryMarkerSize,
                          marker=summaryMarkerType,
                          offset=floatContrast,
                          color=violinColor,
                          linewidth=1)
            if floatContrast:
                # Set reference lines
                ## First get leftmost limit of left reference group
                xtemp, _=np.array(sw.collections[0].get_offsets()).T
                leftxlim=xtemp.min()
                ## Then get leftmost limit of right test group
                xtemp, _=np.array(sw.collections[1].get_offsets()).T
                rightxlim=xtemp.min()

                ## zero line
                ax_contrast.hlines(0,                   # y-coordinates
                                leftxlim, 3.5,       # x-coordinates, start and end.
                                linestyle=contrastZeroLineStyle,
                                linewidth=0.75,
                                color=contrastZeroLineColor)

                ## effect size line
                ax_contrast.hlines(tempbs['summary'], 
                                rightxlim, 3.5,        # x-coordinates, start and end.
                                linestyle=contrastEffectSizeLineStyle,
                                linewidth=0.75,
                                color=contrastEffectSizeLineColor)

                
                ## If the effect size is positive, shift the right axis up.
                if float(tempbs['summary'])>0:
                    rightmin=ax_raw.get_ylim()[0] -float(tempbs['summary'])
                    rightmax=ax_raw.get_ylim()[1] -float(tempbs['summary'])
                ## If the effect size is negative, shift the right axis down.
                elif float(tempbs['summary'])<0:
                    rightmin=ax_raw.get_ylim()[0] + float(tempbs['summary'])
                    rightmax=ax_raw.get_ylim()[1] + float(tempbs['summary'])

                ax_contrast.set_ylim(rightmin, rightmax)

                    
                if gsIdx>0:
                    ax_contrast.set_ylabel('')

                align_yaxis(ax_raw, tempbs['statistic_ref'], ax_contrast, 0.)

            else:
                # Set bottom axes ybounds
                if contrastYlim is not None:
                    ax_contrast.set_ylim(contrastYlim)
                
                # Set xlims so everything is properly visible!
                swarm_xbounds=ax_raw.get_xbound()
                ax_contrast.set_xbound(swarm_xbounds[0] -(summaryLineWidth * 1.1), 
                    swarm_xbounds[1] + (summaryLineWidth * 1.1))

        else:
            # Plot the CIs on the bottom axes.
            plotbootstrap_hubspoke(
                bslist=bscontrast,
                ax=ax_contrast,
                violinWidth=violinWidth,
                violinOffset=violinOffset,
                markersize=summaryMarkerSize,
                marker=summaryMarkerType,
                linewidth=lineWidth)

        if floatContrast is False:
            fig.add_subplot(ax_contrast)

        if gsIdx>0:
            ax_raw.set_ylabel('')
            ax_contrast.set_ylabel('')

    # Turn contrastList into a pandas DataFrame,
    contrastList=pd.DataFrame(contrastList).T
    contrastList.columns=contrastListNames
    
    ########
    axesCount=len(fig.get_axes())

    ## Loop thru SWARM axes for aesthetic touchups.
    for i in range(0, axesCount, 2):
        axx=fig.axes[i]

        if i!=axesCount-2 and 'hue' in kwargs:
            # If this is not the final swarmplot, remove the hue legend.
            axx.legend().set_visible(False)

        if floatContrast is False:
            axx.xaxis.set_visible(False)
            sns.despine(ax=axx, trim=True, bottom=False, left=False)
        else:
            sns.despine(ax=axx, trim=True, bottom=True, left=True)

        if showAllYAxes is False:
            if i in range(2, axesCount):
                axx.yaxis.set_visible(showAllYAxes)
            else:
                # Draw back the lines for the relevant y-axes.
                # Not entirely sure why I have to do this.
                drawback_y(axx)

        # Add zero reference line for swarmplots with bars.
        if summaryBar is True:
            axx.add_artist(Line2D(
                (axx.xaxis.get_view_interval()[0], 
                    axx.xaxis.get_view_interval()[1]), 
                (0,0),
                color='black', linewidth=0.75
                )
            )

        # I don't know why the swarm axes controls the contrast axes ticks....
        if showGroupCount:
            count=data.groupby(x).count()[y]
            newticks=list()
            for ix, t in enumerate(axx.xaxis.get_ticklabels()):
                t_text=t.get_text()
                nt=t_text+' n='+str(count[t_text])
                newticks.append(nt)
            axx.xaxis.set_ticklabels(newticks)

        if legend is False:
            axx.legend().set_visible(False)
        else:
            if i==axesCount-2: # the last (rightmost) swarm axes.
                axx.legend(loc='top right',
                    bbox_to_anchor=(1.1,1.0),
                    fontsize=legendFontSize,
                    **legendFontProps)

    ## Loop thru the CONTRAST axes and perform aesthetic touch-ups.
    ## Get the y-limits:
    for j,i in enumerate(range(1, axesCount, 2)):
        axx=fig.get_axes()[i]

        if floatContrast is False:
            xleft, xright=axx.xaxis.get_view_interval()
            # Draw zero reference line.
            axx.hlines(y=0,
                xmin=xleft-1, 
                xmax=xright+1,
                linestyle=contrastZeroLineStyle,
                linewidth=0.75,
                color=contrastZeroLineColor)
            # reset view interval.
            axx.set_xlim(xleft, xright)
            # # Draw back x-axis lines connecting ticks.
            # drawback_x(axx)

            if showAllYAxes is False:
                if i in range(2, axesCount):
                    axx.yaxis.set_visible(False)
                else:
                    # Draw back the lines for the relevant y-axes.
                    # Not entirely sure why I have to do this.
                    drawback_y(axx)

            sns.despine(ax=axx, 
                top=True, right=True, 
                left=False, bottom=False, 
                trim=True)

            # Rotate tick labels.
            rotateTicks(axx,tickAngle,tickAlignment)

        else:
            # Re-draw the floating axis to the correct limits.
            lower=np.min(contrastList.ix['diffarray',j])
            upper=np.max(contrastList.ix['diffarray',j])
            meandiff=contrastList.ix['summary', j]

            ## Make sure we have zero in the limits.
            if lower>0:
                lower=0.
            if upper<0:
                upper=0.

            ## Get the tick interval from the left y-axis.
            leftticks=fig.get_axes()[i-1].get_yticks()
            tickstep=leftticks[1] -leftticks[0]

            ## First re-draw of axis with new tick interval
            axx.yaxis.set_major_locator(MultipleLocator(base=tickstep))
            newticks1=axx.get_yticks()

            ## Obtain major ticks that comfortably encompass lower and upper.
            newticks2=list()
            for a,b in enumerate(newticks1):
                if (b >= lower and b <= upper):
                    # if the tick lies within upper and lower, take it.
                    newticks2.append(b)
            # if the meandiff falls outside of the newticks2 set, add a tick in the right direction.
            if np.max(newticks2)<meandiff:
                ind=np.where(newticks1 == np.max(newticks2))[0][0] # find out the max tick index in newticks1.
                newticks2.append( newticks1[ind+1] )
            elif meandiff<np.min(newticks2):
                ind=np.where(newticks1 == np.min(newticks2))[0][0] # find out the min tick index in newticks1.
                newticks2.append( newticks1[ind-1] )
            newticks2=np.array(newticks2)
            newticks2.sort()

            ## Second re-draw of axis to shrink it to desired limits.
            axx.yaxis.set_major_locator(FixedLocator(locs=newticks2))
            
            ## Despine the axes.
            sns.despine(ax=axx, trim=True, 
                bottom=False, right=False,
                left=True, top=True)

    # Normalize bottom/right Contrast axes to each other for Cummings hub-and-spoke plots.
    if (axesCount>2 and 
        contrastShareY is True and 
        floatContrast is False):

        # Set contrast ylim as max ticks of leftmost swarm axes.
        if contrastYlim is None:
            lower=list()
            upper=list()
            for c in range(0,len(contrastList.columns)):
                lower.append( np.min(contrastList.ix['bca_ci_low',c]) )
                upper.append( np.max(contrastList.ix['bca_ci_high',c]) )
            lower=np.min(lower)
            upper=np.max(upper)
        else:
            lower=contrastYlim[0]
            upper=contrastYlim[1]

        normalizeContrastY(fig, 
            contrast_ylim = contrastYlim, 
            show_all_yaxes = showAllYAxes)

    # if (axesCount==2 and 
    #     floatContrast is False):
    #     drawback_x(fig.get_axes()[1])
    #     drawback_y(fig.get_axes()[1])

    # if swarmShareY is False:
    #     for i in range(0, axesCount, 2):
    #         drawback_y(fig.get_axes()[i])
                       
    # if contrastShareY is False:
    #     for i in range(1, axesCount, 2):
    #         if floatContrast is True:
    #             sns.despine(ax=fig.get_axes()[i], 
    #                        top=True, right=False, left=True, bottom=True, 
    #                        trim=True)
    #         else:
    #             sns.despine(ax=fig.get_axes()[i], trim=True)

    # Zero gaps between plots on the same row, if floatContrast is False
    if (floatContrast is False and showAllYAxes is False):
        gsMain.update(wspace=0.)

    else:    
        # Tight Layout!
        gsMain.tight_layout(fig)
    
    # And we're all done.
    rcdefaults() # restore matplotlib defaults.
    sns.set() # restore seaborn defaults.
    return fig, contrastList
コード例 #45
0
def plot_benchs(output_dir=expanduser('~/output/recommender/benches')):
    """Plot RMSE-vs-CPU-time benchmark curves and save them as a PDF.

    Reads ``results_<version>.json`` for the ``'1m'``, ``'10m'`` and
    ``'netflix'`` benchmarks from ``output_dir``. It draws one log-x panel
    per dataset, with one line per estimator found in the file, and writes
    the figure to ``~/output/icml/rec_bench.pdf``.

    Each JSON file is expected to map estimator keys (``'dl_partial'``,
    ``'dl'``, ``'cd'``) to dicts with ``'time'`` and ``'rmse'`` sequences.
    Any other estimator key raises ``KeyError`` when its legend name or
    colour is looked up.
    """
    fig = plt.figure()

    fig.subplots_adjust(left=.08, right=.9, bottom=.12, top=.915)
    fig.set_figheight(fig.get_figheight() * 0.66)
    gs = gridspec.GridSpec(1, 3, width_ratios=[1, 1, 1.5])

    ylims = {
        '100k': [.90, .96],
        '1m': [.865, .915],
        '10m': [.795, .868],
        'netflix': [.928, .99]
    }
    xlims = {
        '100k': [0.0001, 10],
        '1m': [0.1, 15],
        '10m': [1, 200],
        'netflix': [30, 4000]
    }

    names = {
        'dl_partial': 'Proposed \n(partial projection)',
        'dl': 'Proposed \n(full projection)',
        'cd': 'Coordinate descent'
    }

    # Loop-invariant: the same estimator colours are used in every panel.
    palette = sns.cubehelix_palette(3,
                                    start=0,
                                    rot=.5,
                                    hue=1,
                                    dark=.3,
                                    light=.7,
                                    reverse=False)
    color = {'dl_partial': palette[2], 'dl': palette[1], 'cd': palette[0]}
    grey = (.6, .6, .6)

    for i, version in enumerate(['1m', '10m', 'netflix']):
        with open(join(output_dir, 'results_%s.json' % version), 'r') as f:
            data = json.load(f)
        ax_time = fig.add_subplot(gs[0, i])
        # NOTE(review): a bare ``grid()`` toggles grid visibility, and it is
        # called twice in this loop, so the calls cancel out and the grid
        # follows the active style's default. Kept as-is to preserve output.
        ax_time.grid()
        sns.despine(fig, ax_time)

        ax_time.spines['left'].set_color(grey)
        ax_time.spines['bottom'].set_color(grey)
        ax_time.xaxis.set_tick_params(color=grey, which='both')
        ax_time.yaxis.set_tick_params(color=grey, which='both')

        # `Tick.label1` is the primary tick label; the `Tick.label` alias
        # used previously was removed in newer matplotlib.
        for tick in ax_time.xaxis.get_major_ticks():
            tick.label1.set_fontsize(7)
            tick.label1.set_color('black')
        for tick in ax_time.yaxis.get_major_ticks():
            tick.label1.set_fontsize(7)
            tick.label1.set_color('black')

        if i == 0:
            ax_time.set_ylabel('RMSE on test set')
        if i == 2:
            ax_time.set_xlabel('CPU time')
            ax_time.xaxis.set_label_coords(1.12, -0.045)

        ax_time.grid()
        for estimator in sorted(data):
            this_data = data[estimator]
            ax_time.plot(this_data['time'],
                         this_data['rmse'],
                         label=names[estimator],
                         color=color[estimator],
                         linewidth=2,
                         linestyle='-' if estimator != 'cd' else '--')
        if version == 'netflix':
            ax_time.legend(loc='upper left',
                           bbox_to_anchor=(.65, 1.1),
                           numpoints=1,
                           frameon=False)
        ax_time.set_xscale('log')
        ax_time.set_ylim(ylims[version])
        ax_time.set_xlim(xlims[version])
        if version == '1m':
            ax_time.set_xticks([.1, 1, 10])
            ax_time.set_xticklabels(['0.1 s', '1 s', '10 s'])
        elif version == '10m':
            ax_time.set_xticks([1, 10, 100])
            ax_time.set_xticklabels(['1 s', '10 s', '100 s'])
        else:
            ax_time.set_xticks([100, 1000])
            ax_time.set_xticklabels(['100 s', '1000 s'])

        title = ('MovieLens %s' % version.upper()
                 if version != 'netflix' else 'Netflix (140M)')
        ax_time.annotate(
            title,
            xy=(.5 if version != 'netflix' else .4, 1),
            xycoords='axes fraction',
            ha='center',
            va='bottom')
    plt.savefig(expanduser('~/output/icml/rec_bench.pdf'))