Example #1
0
def plot_wfs(waveforms,
             tvec,
             cols=None,
             lw=0.1,
             alpha=0.05,
             xlim=None,
             ylim=None,
             title=None,
             xlab=None,
             ylab=None,
             ffig=None,
             ax=None,
             **kwargs):
    """
    Plot large set of waveforms efficiently as LineCollections.

    waveforms: waveform matrix of (waveform x time sample); any 2-D
               array-like is accepted.
    tvec: vector of sample times, one per column of `waveforms`; any 1-D
          array-like is accepted.
    cols: color matrix of (waveform x RGBA); defaults to green for all
          waveforms.
    lw: line width of every waveform.
    alpha: NOTE(review): not applied to the plot; transparency currently
           comes only from the alpha channel of `cols`. Confirm intent
           before wiring it into the LineCollection.
    xlim, ylim, xlab, ylab, title: passed to putil.format_plot.
    ffig: passed to putil.save_fig.
    ax: axes to plot into (resolved by putil.axes).
    kwargs: extra keyword arguments forwarded to LineCollection.

    Returns the axes plotted into.
    """

    # Init. Convert to plain arrays up front so list / table inputs work
    # (calling `tvec.size` on a list would fail).
    waveforms = np.asarray(waveforms)
    tvec = np.asarray(tvec)
    nwfs = waveforms.shape[0]
    ntsamp = tvec.size
    ax = putil.axes(ax)

    if cols is None:
        cols = np.tile(putil.convert_to_rgba('g'), (nwfs, 1))

    # Plot all waveforms at once as a single LineCollection. Each segment
    # joins two consecutive samples of the same waveform:
    # [[(t0, v0), (t1, v1)], [(t1, v1), (t2, v2)], ...], waveform by
    # waveform. Building segments per waveform (rather than over the
    # flattened matrix) means no segment ever connects the last sample of
    # one waveform to the first sample of the next.
    tmat = np.broadcast_to(tvec, waveforms.shape)
    points = np.stack([tmat, waveforms], axis=-1)  # (wf, sample, xy)
    tv_segments = np.stack([points[:, :-1], points[:, 1:]], axis=2)
    tv_segments = tv_segments.reshape(-1, 2, 2)

    # One color per segment: each waveform's color repeated for each of its
    # segments (clamped so an empty time vector does not give a negative
    # repeat count).
    nseg_per_wf = max(ntsamp - 1, 0)
    cols_segments = np.repeat(cols, nseg_per_wf, axis=0)

    # Create and add LineCollection to axes.
    lc = LineCollection(tv_segments,
                        linewidths=lw,
                        colors=cols_segments,
                        **kwargs)
    ax.add_collection(lc)
    # Adding a collection does not update the data limits' view; do it
    # manually.
    ax.autoscale_view()

    # Format plot.
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title)

    # Save and return plot.
    putil.save_fig(ffig)
    return ax
Example #2
0
def raster(spk_trains,
           t_unit=ms,
           prds=None,
           c='b',
           xlim=None,
           title=None,
           xlab=None,
           ylab=None,
           ffig=None,
           ax=None):
    """
    Plot rasterplot.

    The i-th spike train is drawn on row i + 1, with rows ordered from top
    to bottom. Each spike is a rectangle of color `c` sized in data
    coordinates.

    spk_trains: sized sequence of spike trains; each must support
                `.rescale(t_unit)`.
    t_unit: time unit that spike times are rescaled to.
    prds: periods passed to putil.plot_periods.
    c: face and edge color of the spike markers.
    xlim, title, xlab, ylab: plot formatting; xlab is formatted into
                             putil.t_lbl.
    ffig: passed to putil.save_fig.
    ax: axes to plot into (resolved by putil.axes).

    Returns the axes. When `spk_trains` is empty, it returns right after
    drawing the periods and x limits (no formatting, no saving).
    """

    ax = putil.axes(ax)

    # Periods and x limits are drawn even when there are no spikes.
    putil.plot_periods(prds, ax=ax)
    putil.set_limits(ax, xlim)

    n_trains = len(spk_trains)
    if n_trains == 0:
        return ax

    for row, spk_tr in enumerate(spk_trains, start=1):
        spk_times = np.array(spk_tr.rescale(t_unit))
        row_levels = row * np.ones_like(spk_times)

        # Markers are sized relative to the axes (data coordinates), not the
        # figure. An absolute-size alternative is
        # ax.scatter(x, y, c=c, s=1.8, edgecolor=c, marker='|').
        # NOTE(review): `wsp` / `hsp` are not defined in this function;
        # presumably module-level marker width / height constants.
        rects = [Rectangle((t - wsp / 2, lvl - hsp / 2), wsp, hsp)
                 for t, lvl in zip(spk_times, row_levels)]
        ax.add_collection(PatchCollection(rects, facecolor=c, edgecolor=c))

    # Rows sit at 1..n, so pad half a row on both ends.
    ylim = [0.5, n_trains + 0.5]
    xlab = None if xlab is None else putil.t_lbl.format(xlab)
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title)
    putil.hide_axes(ax, show_x=True)
    putil.hide_spines(ax)

    # First trial on top; must come after the axis limits are set.
    ax.invert_yaxis()

    putil.save_fig(ffig)
    return ax
Example #3
0
def plot_tuning(xfit,
                yfit,
                vals=None,
                meanr=None,
                semr=None,
                color='b',
                baseline=None,
                xticks=None,
                xlim=None,
                ylim=None,
                xlab=None,
                ylab=None,
                title=None,
                ffig=None,
                ax=None,
                **kwargs):
    """
    Plot tuning curve, optionally with data samples.

    xfit, yfit: points of the fitted curve, drawn as a line.
    vals, meanr, semr: sample positions, means and SEMs. They are drawn as
                       error bars only when both `meanr` and `semr` are
                       given.
    color: color of both the curve and the samples.
    baseline: if not None, drawn with putil.add_baseline.
    xticks: x tick labels; if None, `vals` is used when given.
    xlim, ylim, xlab, ylab, title: passed to putil.format_plot.
    ffig: passed to putil.save_fig.
    ax: axes to plot into (resolved by putil.axes).
    kwargs: forwarded to the errorbar call.

    Returns the axes plotted into.
    """

    ax = putil.axes(ax)

    if baseline is not None:
        putil.add_baseline(baseline, ax=ax)

    # Fitted curve first, so that the samples are drawn on top of it.
    pplot.lines(xfit, yfit, color=color, ax=ax)

    have_samples = not (meanr is None or semr is None)
    if have_samples:
        pplot.errorbar(vals, meanr, yerr=semr, fmt='o', color=color,
                       ax=ax, **kwargs)

    # Explicit ticks win; otherwise fall back to the sample positions.
    tick_vals = vals if xticks is None else xticks
    if tick_vals is not None:
        putil.set_xtick_labels(ax, tick_vals)

    putil.format_plot(ax, xlim, ylim, xlab, ylab, title)
    putil.save_fig(ffig)
    return ax
Example #4
0
def joint_scatter(x,
                  y,
                  is_sign=None,
                  kind='reg',
                  stat_func=util.pearson_r,
                  c='b',
                  xlim=None,
                  ylim=None,
                  xlab=None,
                  ylab=None,
                  title=None,
                  ytitle=None,
                  ffig=None,
                  **kwargs):
    """
    Plot paired data on scatter plot with
        - marginal distributions added to the side
        - linear regression on center scatter
        - N, r and p values reported

    Additional parameters for scatter (e.g. size) can be passed as kwargs.

    is_sign: optional per-point significance. When given, or when any
             scatter kwargs are given, the seaborn points are hidden and
             redrawn with `scatter`, so that non-significant points can be
             hollow and the kwargs take effect.
    stat_func: statistic reported in the legend by seaborn.
               NOTE(review): `stat_func` was removed from sns.jointplot in
               seaborn 0.11; this call requires an older seaborn.

    Returns the central (joint) scatter axes.
    """

    # Create scatter plot and distribution plots on the side.
    g = sns.jointplot(x=x,
                      y=y,
                      color=c,
                      kind=kind,
                      stat_func=stat_func,
                      xlim=xlim,
                      ylim=ylim)
    ax = g.ax_joint  # scatter axes

    # Redraw the points only when custom point styling was requested.
    # `kwargs` is always a dict (never None), so its truthiness tells
    # whether any scatter parameters were passed.
    if is_sign is not None or kwargs:
        ax.collections[0].set_visible(False)  # hide seaborn's points
        scatter(x, y, is_sign, c=c, ax=ax, **kwargs)

    # Prepend N to the statistics legend, if seaborn created one.
    lgd = ax.get_legend()
    if lgd is not None and lgd.texts:
        leg_txt = lgd.texts[0]
        new_txt = 'n = {}\n'.format(len(x)) + leg_txt.get_text()
        leg_txt.set_text(new_txt)

    # Format and save figure.
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title, ytitle)
    putil.save_fig(ffig)

    return ax
Example #5
0
def scatter(x,
            y,
            is_sign=None,
            c='b',
            bc='w',
            nc='grey',
            ec='k',
            alpha=0.5,
            xlim=None,
            ylim=None,
            xlab=None,
            ylab=None,
            title=None,
            ytitle=None,
            polar=False,
            id_line=True,
            ffig=None,
            ax=None,
            **kwargs):
    """
    Plot paired data on scatter plot.
    Color settings:
        - c:  face color of significant points (is_sign == True)
        - bc: face color of non-significant points (is_sign == False)
        - nc: face color of untested/untestable points (is_sign == None)

    Per-point colors are derived from `is_sign` only when it is given and
    both `c` and `bc` are color strings; otherwise `c` is used as is.
    `ec` is the edge color, and `alpha` the transparency of all points.
    When `id_line` is set, an identity line is drawn on top (zorder 99).
    Extra kwargs go to ax.scatter. Returns the axes.
    """

    ax = putil.axes(ax, polar=polar)

    per_point = (is_sign is not None
                 and isinstance(c, str)
                 and isinstance(bc, str))
    face_cols = putil.get_cmat(is_sign, c, bc, nc) if per_point else c

    ax.scatter(x, y, c=face_cols, edgecolor=ec, alpha=alpha, **kwargs)

    if id_line:
        putil.add_identity_line(equal_xy=True, zorder=99, ax=ax)

    putil.format_plot(ax, xlim, ylim, xlab, ylab, title, ytitle)
    putil.save_fig(ffig)
    return ax
Example #6
0
def plot_spike_count_hist(spk_cnt, title, ax, hist=True,
                          kde_alpha=1.0, kde_color='b', kde_lw=2):
    """
    Plot spike count histogram.

    Draws a histogram (optional, via `hist`) and a KDE line of `spk_cnt`
    into `ax` with seaborn.distplot, using one unit-wide bin centred on
    each integer count from 0 to the maximum count.

    spk_cnt: spike counts; must support `.max()`. Float-typed counts
             (e.g. from pandas) are accepted.
    title: plot title.
    ax: axes to plot into.
    hist: whether to draw the histogram bars.
    kde_alpha, kde_color, kde_lw: style of the KDE line.
    """

    # Set histogram parameters. Bin edges are -0.5, 0.5, ..., max + 0.5.
    # Cast the maximum to int, because the number of bins must be an
    # integer; float counts would otherwise make numpy raise TypeError.
    max_spk = int(np.ceil(spk_cnt.max()))
    bins = np.arange(max_spk + 2) - 0.5
    hist_kws = {'edgecolor': 'grey'}
    kde_kws = {'alpha': kde_alpha, 'color': kde_color, 'lw': kde_lw}

    # Plot spike count histogram.
    sns.distplot(spk_cnt, bins=bins, ax=ax, hist=hist,
                 hist_kws=hist_kws, kde_kws=kde_kws)

    # Format plot.
    xlim = None
    ylim = None
    xlab = '# spikes'
    ylab = 'density'
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title)
Example #7
0
def heatmap(mat,
            vmin=None,
            vmax=None,
            cmap=None,
            cbar=True,
            cbar_ax=None,
            annot=None,
            square=False,
            xlab=None,
            ylab=None,
            title=None,
            ytitle=None,
            xlim=None,
            ylim=None,
            xticklabels=True,
            yticklabels=True,
            ffig=None,
            ax=None):
    """
    Plot rectangular data as heatmap.

    mat: 2-D data passed to seaborn.heatmap.
    vmin, vmax, cmap, cbar, cbar_ax, annot, square, xticklabels,
    yticklabels: forwarded to seaborn.heatmap.
    xlab, ylab, title, ytitle, xlim, ylim: passed to putil.format_plot.
    ffig: passed to putil.save_fig.
    ax: axes to plot into (resolved by putil.axes).

    Returns the axes plotted into.
    """

    # Plot data. Every option is passed by keyword: recent seaborn versions
    # make all heatmap parameters after `data` keyword-only, so passing
    # vmin/vmax/cmap positionally fails there.
    ax = putil.axes(ax)
    sns.heatmap(mat,
                vmin=vmin,
                vmax=vmax,
                cmap=cmap,
                annot=annot,
                cbar=cbar,
                cbar_ax=cbar_ax,
                square=square,
                xticklabels=xticklabels,
                yticklabels=yticklabels,
                ax=ax)

    # Format and save figure.
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title, ytitle)
    putil.save_fig(ffig)

    return ax
Example #8
0
def multi_hist(vals,
               xlim=None,
               ylim=None,
               xlab=None,
               ylab='n',
               title=None,
               ytitle=None,
               polar=False,
               ffig=None,
               ax=None,
               **kwargs):
    """
    Plot histogram of multiple samples side by side.

    vals: data passed straight to Axes.hist (with **kwargs).
    xlim, ylim, xlab, ylab, title, ytitle: passed to putil.format_plot.
    polar: whether the axes are created in polar projection.
    ffig: passed to putil.save_fig.
    ax: axes to plot into (resolved by putil.axes).

    Returns the axes plotted into.
    """

    target_ax = putil.axes(ax, polar=polar)
    target_ax.hist(vals, **kwargs)

    fmt_args = (xlim, ylim, xlab, ylab, title, ytitle)
    putil.format_plot(target_ax, *fmt_args)
    putil.save_fig(ffig)
    return target_ax
Example #9
0
def cat_hist(vals,
             xlim=None,
             ylim=None,
             xlab=None,
             ylab='n',
             title=None,
             ytitle=None,
             polar=False,
             ffig=None,
             ax=None,
             **kwargs):
    """
    Plot histogram of categorical data.

    vals: vector of categorical values; one bar per category shows its
          count (seaborn.countplot).
    xlim, ylim, xlab, ylab, title, ytitle: passed to putil.format_plot.
    polar: whether the axes are created in polar projection.
    ffig: passed to putil.save_fig.
    ax: axes to plot into (resolved by putil.axes).
    kwargs: NOTE(review): accepted but currently not used.

    Returns the axes plotted into.
    """

    # Plot data. Pass the values explicitly as `x`: in seaborn >= 0.12 the
    # first positional parameter of countplot is `data`, not `x`.
    ax = putil.axes(ax, polar=polar)
    ax = sns.countplot(x=vals, ax=ax)

    # Format and save figure.
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title, ytitle)
    putil.save_fig(ffig)

    return ax
Example #10
0
def bars(x,
         y,
         ylim=None,
         xlim=None,
         xlab=None,
         ylab=None,
         title=None,
         ytitle=None,
         polar=False,
         ffig=None,
         ax=None,
         **kwargs):
    """
    Plot bar plot.

    x, y: bar positions and heights, passed to Axes.bar with **kwargs.
    ylim, xlim, xlab, ylab, title, ytitle: passed to putil.format_plot.
    polar: whether the axes are created in polar projection.
    ffig: passed to putil.save_fig.
    ax: axes to plot into (resolved by putil.axes).

    Returns the axes plotted into.
    """

    bar_ax = putil.axes(ax, polar=polar)
    bar_ax.bar(x, y, **kwargs)

    putil.format_plot(bar_ax, xlim, ylim, xlab, ylab, title, ytitle)
    putil.save_fig(ffig)
    return bar_ax
Example #11
0
def mean_err(x,
             ymean,
             ystd,
             ylim=None,
             xlim=None,
             xlab=None,
             ylab=None,
             title=None,
             ytitle=None,
             polar=False,
             ffig=None,
             ax=None,
             mean_kws=None,
             band_kws=None):
    """
    Plot mean and highlighted band area around it.

    The mean `ymean` is drawn as a line, and the band spans
    ymean - ystd .. ymean + ystd.

    mean_kws: extra kwargs for the mean line (`lines`).
    band_kws: extra kwargs for the band (`band`).
    ylim, xlim, xlab, ylab, title, ytitle: passed to putil.format_plot.
    polar: whether the axes are created in polar projection.
    ffig: passed to putil.save_fig.
    ax: axes to plot into (resolved by putil.axes).

    Returns the axes plotted into.
    """

    line_opts = dict() if mean_kws is None else mean_kws
    band_opts = dict() if band_kws is None else band_kws

    # Band limits are computed before any axes work, as in the plotting
    # order below.
    lo, hi = ymean - ystd, ymean + ystd

    ax = putil.axes(ax, polar=polar)
    lines(x, ymean, ax=ax, **line_opts)
    band(x, lo, hi, ax=ax, **band_opts)

    putil.format_plot(ax, xlim, ylim, xlab, ylab, title, ytitle)
    putil.save_fig(ffig)
    return ax
Example #12
0
def band(x,
         ylower,
         yupper,
         ylim=None,
         xlim=None,
         xlab=None,
         ylab=None,
         title=None,
         ytitle=None,
         polar=False,
         ffig=None,
         ax=None,
         **kwargs):
    """
    Plot highlighted band area.

    Fills the area between `ylower` and `yupper` along `x` using
    Axes.fill_between; extra kwargs are forwarded to it.
    ylim, xlim, xlab, ylab, title, ytitle: passed to putil.format_plot.
    polar: whether the axes are created in polar projection.
    ffig: passed to putil.save_fig.
    ax: axes to plot into (resolved by putil.axes).

    Returns the axes plotted into.
    """

    fill_ax = putil.axes(ax, polar=polar)
    fill_ax.fill_between(x, ylower, yupper, **kwargs)

    putil.format_plot(fill_ax, xlim, ylim, xlab, ylab, title, ytitle)
    putil.save_fig(ffig)
    return fill_ax
Example #13
0
def plot_DR_tuning(DS,
                   title=None,
                   labels=True,
                   baseline=None,
                   DR_legend=True,
                   tuning_legend=True,
                   ffig=None,
                   fig=None,
                   sps=None):
    """
    Plot direction selectivity on polar plot and tuning curve.

    Left panel: polar direction-response plot per stimulus. Right panel:
    fitted Gaussian tuning curve per stimulus, centred on its preferred
    direction, with mean +- SEM data points. Per-stimulus parameters are
    listed in a legend under each panel.

    DS: direction selectivity results. Fields read here: DS.DSI (its index
        lists the stimuli; column 'wDS'), DS.TP (7 tuning-fit values per
        stimulus), DS.PD (row (stim, 'weighted'), columns 'PD', 'cPD') and
        DS.DR (per stimulus, columns 'mean' and 'sem' indexed by direction).
    title: optional figure super title.
    labels: whether to draw axis labels, titles and legend titles.
            Unlabelled (summary) plots get framed legends instead.
    baseline: baseline firing rate passed to both panels.
    DR_legend, tuning_legend: whether to draw each panel's legend.
    ffig: if not None, the figure is saved to this file.
    fig, sps: figure and subplot spec to embed into (via putil.sps_fig).

    Returns (ax_DR, ax_tuning).
    """

    # Init subplots.
    sps, fig = putil.sps_fig(sps, fig)
    gsp = putil.embed_gsp(sps, 1, 2, wspace=0.3)
    ax_DR = fig.add_subplot(gsp[0], polar=True)
    ax_tuning = fig.add_subplot(gsp[1])
    DR_patches = []
    tuning_patches = []

    # Settings that are the same for every stimulus.
    DR_ttl = 'Direction response' if labels else None
    xticks = np.arange(-180, 180 + 1, 45)

    for stim in DS.DSI.index:

        # Extract params and results (RMSE is not displayed).
        a, b, x0, sigma, FWHM, R2, _RMSE = DS.TP.loc[stim]
        DSI = DS.DSI.wDS[stim]
        PD, cPD = DS.PD.loc[(stim, 'weighted'), ['PD', 'cPD']]
        dirs = np.array(DS.DR[stim].index) * deg
        SR, SRsem = [DS.DR[stim][col] for col in ('mean', 'sem')]

        # Center directions on the preferred direction.
        dirsc = direction.center_to_dir(dirs, PD)

        # Generate data points for plotting fitted tuning curve.
        x, y = tuning.gen_fit_curve(tuning.gaus,
                                    -180,
                                    180,
                                    a=a,
                                    b=b,
                                    x0=x0,
                                    sigma=sigma)

        # Plot DR polar plot.
        color = putil.stim_colors.loc[stim]
        plot_DR(dirs, SR, DSI, PD, baseline, title=DR_ttl, color=color,
                ax=ax_DR)

        # Collect parameters of DR plot (stimulus - response).
        s_pd = str(float(round(PD, 1)))
        s_pd_c = str(int(cPD)) if not np.isnan(float(cPD)) else 'nan'
        lgd_lbl = '{}:   {:.3f}'.format(stim, DSI)
        # Raw string: `\c` is not a valid escape sequence (SyntaxWarning in
        # Python 3.12+); the runtime text is unchanged.
        lgd_lbl += r'     {:>5}$^\circ$ --> {:>3}$^\circ$ '.format(s_pd,
                                                                    s_pd_c)
        DR_patches.append(putil.get_artist(lgd_lbl, color))

        # Plot direction tuning curve with data samples.
        plot_tuning(x,
                    y,
                    dirsc,
                    SR,
                    SRsem,
                    color,
                    baseline,
                    xticks,
                    ax=ax_tuning)

        # Collect parameters of the tuning curve fit (fixed-width columns
        # aligned with the legend title below).
        s_a, s_b, s_x0, s_sigma, s_FWHM = [
            str(float(round(p, 1))) for p in (a, b, x0, sigma, FWHM)
        ]
        s_R2 = format(R2, '.2f')
        lgd_lbl = '{}:{}{:>6}{}{:>6}'.format(stim, 5 * ' ', s_a, 5 * ' ', s_b)
        lgd_lbl += '{}{:>6}{}{:>6}'.format(5 * ' ', s_x0, 8 * ' ', s_sigma)
        lgd_lbl += '{}{:>6}{}{:>6}'.format(8 * ' ', s_FWHM, 8 * ' ', s_R2)
        tuning_patches.append(putil.get_artist(lgd_lbl, color))

    # Format tuning curve plot (only after all stimuli have been plotted!).
    xlab = 'Difference from preferred direction (deg)' if labels else None
    ylab = putil.FR_lbl if labels else None
    xlim = [-180 - 5, 180 + 5]  # degrees
    ylim = [0, None]
    ttl = 'Tuning curve' if labels else None
    putil.format_plot(ax_tuning, xlim, ylim, xlab, ylab, ttl)
    putil.add_zero_line('y', ax=ax_tuning)  # add zero reference line

    # Set super title.
    if title is not None:
        fig.suptitle(title, y=1.1, fontsize='x-large')

    # Set legends: below each panel, monospace so the columns line up; they
    # are framed only in the unlabelled (summary) variant.
    ylegend = -0.38 if labels else -0.15
    lgd_kws = dict(fancybox=True,
                   shadow=False,
                   frameon=not labels,
                   framealpha=1.0,
                   loc='lower center',
                   bbox_to_anchor=[0., ylegend, 1., .0],
                   prop={'family': 'monospace'})
    DR_lgn_ttl = 'DSI'.rjust(20) + 'PD'.rjust(14) + 'PD8'.rjust(14)
    tuning_lgd_ttl = ('a (sp/s)'.rjust(35) + 'b (sp/s)'.rjust(15) +
                      'x0 (deg)'.rjust(13) + 'sigma (deg)'.rjust(15) +
                      'FWHM (deg)'.rjust(15) + 'R-squared'.rjust(15))

    lgn_params = [(DR_legend, DR_lgn_ttl, DR_patches, ax_DR),
                  (tuning_legend, tuning_lgd_ttl, tuning_patches, ax_tuning)]

    for (plot_legend, lgd_ttl, patches, ax) in lgn_params:
        if not plot_legend:
            continue
        if not labels:  # customize for summary plot
            lgd_ttl = None
        lgd = putil.set_legend(ax, handles=patches, title=lgd_ttl, **lgd_kws)
        lgd.get_title().set_ha('left')
        if lgd_kws['frameon']:
            lgd.get_frame().set_linewidth(.5)

    # Save figure.
    if ffig is not None:
        putil.save_fig(ffig, fig, rect_height=0.55)

    return ax_DR, ax_tuning
Example #14
0
def rate(rate_list,
         names=None,
         prds=None,
         evts=None,
         cols=None,
         baseline=None,
         pval=0.05,
         test='mann_whitney_u',
         test_kws=None,
         xlim=None,
         ylim=None,
         title=None,
         xlab=None,
         ylab=putil.FR_lbl,
         add_lgn=True,
         lgn_lbl='trs',
         ffig=None,
         ax=None):
    """
    Plot firing rate.

    For each rate table in `rate_list`, plots the mean +- SEM as a line with
    a shaded band. It then formats the axes and adds a legend, a
    significance line (only when exactly two rate tables are given) and
    event markers.

    rate_list: list of rate tables. Each must provide `.shape[0]` (number of
        trials), `.columns` (time points), `.mean()` and `.sem()`;
        presumably pandas DataFrames of (trial x time). TODO confirm.
    names: per-table legend names (default: empty strings).
    prds: periods passed to putil.plot_periods.
    evts: events passed to putil.plot_event_markers.
    cols: per-table colors (default: putil.get_colors(as_cycle=False)).
        NOTE(review): indexed by position, so it must be at least as long
        as `rate_list`.
    baseline: if not None, drawn with putil.add_baseline.
    pval: significance threshold; None disables the significance line.
    test, test_kws: test name and extra kwargs for stats.sign_periods.
    xlim, ylim: axis limits; defaults are derived from the plotted data.
    title, xlab, ylab: labels; xlab is formatted into putil.t_lbl.
    add_lgn: whether to add a legend.
    lgn_lbl: suffix after the trial count in legend labels; None omits the
        count.
    ffig: passed to putil.save_fig.
    ax: axes to plot into (resolved by putil.axes).

    Returns the axes. If `rate_list` is empty, returns right after drawing
    periods, baseline and x limits (no formatting, legend or saving).
    """

    # Init.
    ax = putil.axes(ax)
    if test_kws is None:
        test_kws = dict()

    # Plot periods and baseline first (drawn underneath the rate curves).
    putil.plot_periods(prds, ax=ax)
    if baseline is not None:
        putil.add_baseline(baseline, ax=ax)
    putil.set_limits(ax, xlim)

    if not len(rate_list):
        return ax

    if cols is None:
        cols = putil.get_colors(as_cycle=False)
    if names is None:
        names = len(rate_list) * ['']

    # Iterate through list of rate arrays, tracking the data extent to set
    # default axis limits afterwards (None = nothing plotted yet).
    xmin, xmax, ymax = None, None, None
    for i, rts in enumerate(rate_list):

        # Init.
        name = names[i]
        col = cols[i]

        # Skip empty array (no trials).
        if not rts.shape[0]:
            continue

        # Set line label. Convert to Numpy array to format floats nicely.
        lbl = str(np.array(name)) if util.is_iterable(name) else str(name)

        # Append trial count, e.g. " (12 trs)".
        if lgn_lbl is not None:
            lbl += ' ({} {})'.format(rts.shape[0], lgn_lbl)

        # Plot mean +- SEM of rate vectors.
        tvec, meanr, semr = rts.columns, rts.mean(), rts.sem()
        ax.plot(tvec, meanr, label=lbl, color=col)
        ax.fill_between(tvec,
                        meanr - semr,
                        meanr + semr,
                        alpha=0.2,
                        facecolor=col,
                        edgecolor=col)

        # Update limits.
        tmin, tmax, rmax = tvec.min(), tvec.max(), (meanr + semr).max()
        xmin = np.min([xmin, tmin]) if xmin is not None else tmin
        xmax = np.max([xmax, tmax]) if xmax is not None else tmax
        ymax = np.max([ymax, rmax]) if ymax is not None else rmax

    # Set ticks, labels and axis limits.
    if xlim is None:
        if xmin == xmax:  # avoid setting identical limits
            xmax = None
        xlim = (xmin, xmax)
    if ylim is None:
        # 2% headroom above the highest mean + SEM; leave the top automatic
        # if nothing positive was plotted.
        ymax = 1.02 * ymax if (ymax is not None) and (ymax > 0) else None
        ylim = (0, ymax)
    if xlab is not None:
        xlab = putil.t_lbl.format(xlab)
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title)
    # Read back the actual x limits, since either end may have been None.
    t1, t2 = ax.get_xlim()  # in case it was set to None
    tmarks, tlbls = putil.get_tick_marks_and_labels(t1, t2)
    putil.set_xtick_labels(ax, tmarks, tlbls)
    putil.set_max_n_ticks(ax, 7, 'y')

    # Add legend.
    if add_lgn and len(rate_list):
        putil.set_legend(ax,
                         loc=1,
                         borderaxespad=0.0,
                         handletextpad=0.4,
                         handlelength=0.6)

    # Add significance line to top of axes (pairwise comparison only).
    if (pval is not None) and (len(rate_list) == 2):
        rates1, rates2 = rate_list
        sign_prds = stats.sign_periods(rates1, rates2, pval, test, **test_kws)
        putil.plot_signif_prds(sign_prds, color='m', linewidth=4.0, ax=ax)

    # Plot event markers.
    putil.plot_event_markers(evts, ax=ax)

    # Save and return plot.
    putil.save_fig(ffig)
    return ax
Example #15
0
def plot_up_down_raster(Spikes, task, rec, itrs, c='b'):
    """
    Plot spike raster for up-down dynamics analysis.

    Each selected trial is a band of `nunits` rows (one per unit),
    separated from the next by a shaded gap of nunits / 2 rows. Stimulus
    start/stop lines are drawn, and the figure is saved to
    results/UpDown/UpDown_dynamics_<rec>_<task>.pdf.

    Spikes: table indexed by unit ids (index) and trials (columns); each
            cell holds a spike train supporting `.rescale('ms')`.
    task, rec: task and recording names, used in the title and file name.
    itrs: positional indices of the trials to plot (into Spikes.columns).
    c: face and edge color of the spike markers. It was previously an
       undefined name; the default matches `raster`.
    """

    # Set params for plotting.
    uids, trs = Spikes.index, Spikes.columns
    plot_trs = trs[itrs]
    ntrs = len(plot_trs)
    nunits = len(uids)
    tr_gap = nunits / 2  # rows of empty space between trials

    # Init figure.
    putil.set_style('notebook', 'ticks')
    fig = putil.figure(figsize=(10, ntrs))
    ax = fig.add_subplot(111)

    # Per trial, per unit.
    for itr, tr in enumerate(plot_trs):
        for iu, uid in enumerate(uids):

            # Init y level and spike times.
            i = (tr_gap + nunits) * itr + iu
            spk_tr = Spikes.loc[uid, tr]

            # Plot (spike time, y-level) pairs.
            x = np.array(spk_tr.rescale('ms'))
            y = (i+1) * np.ones_like(x)

            # NOTE(review): `wsp` / `hsp` are not defined in this function;
            # presumably module-level marker width / height constants.
            patches = [Rectangle((xi-wsp/2, yi-hsp/2), wsp, hsp)
                       for xi, yi in zip(x, y)]
            collection = PatchCollection(patches, facecolor=c, edgecolor=c)
            ax.add_collection(collection)

    # Add stimulus lines.
    for stim in constants.stim_dur.index:
        t_start, t_stop = constants.fixed_tr_prds.loc[stim]
        events = pd.DataFrame([(t_start, 't_start'), (t_stop, 't_stop')],
                              index=['start', 'stop'],
                              columns=['time', 'label'])
        putil.plot_events(events, add_names=False, color='grey',
                          alpha=0.5, ls='-', lw=0.5, ax=ax)

    # Add inter-trial shading.
    for itr in range(ntrs+1):
        ymin = itr * (tr_gap + nunits) - tr_gap + 0.5
        ax.axhspan(ymin, ymin+tr_gap, alpha=.05, color='grey')

    # Set tick labels: one 1-based trial number at the middle of each band.
    pos = np.arange(ntrs) * (tr_gap + nunits) + nunits/2
    lbls = plot_trs + 1
    putil.set_ytick_labels(ax, pos, lbls)
    putil.hide_tick_marks(ax, show_x_tick_mrks=True)

    # Format plot.
    xlim = constants.fixed_tr_prds.loc['whole trial']
    ylim = [-tr_gap/2, ntrs * (nunits+tr_gap)-tr_gap/2]
    xlab = 'Time since S1 onset (ms)'
    ylab = 'Trial number'
    title = '{} {}'.format(rec, task)
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title)
    putil.set_spines(ax, True, False, False, False)

    # Save figure.
    fname = 'UpDown_dynamics_{}_{}.pdf'.format(rec, task)
    ffig = util.join(['results', 'UpDown', fname])
    putil.save_fig(ffig, fig, dpi=600)