示例#1
0
def raster_from_epoch(signals, epoch, occurrences=0, channels=0,
                      xlabel='Time', ylabel='Value',
                      linestyle='-', linewidth=1,
                      ax=None, title=None):
    """Plot one occurrence/channel of `epoch` from each signal (not implemented).

    This function currently always raises NotImplementedError. The code after
    the ``raise`` is an unreachable draft kept for reference: for every signal
    it would extract `epoch`, take the requested occurrence and channel, and
    plot it against time in seconds relative to the epoch start
    (``bin / s.fs``), with one legend entry per signal name.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError

    # ---- unreachable draft below ----
    if occurrences is None:
        return
    occurrences = pad_to_signals(signals, occurrences)
    channels = pad_to_signals(signals, channels)

    legend = [sig.name for sig in signals]
    times, values = [], []
    for sig, occ, chan in zip(signals, occurrences, channels):
        # extract_epoch result indexed as [occurrence][channel] -> trace
        trace = sig.extract_epoch(epoch)[occ][chan]
        # TODO: want this to be absolute time relative to start of data?
        times.append(np.arange(len(trace)) / sig.fs)
        values.append(trace)
    plot_timeseries(times, values, xlabel, ylabel, legend=legend,
                    linestyle=linestyle, linewidth=linewidth,
                    ax=ax, title=title)
示例#2
0
def raster_from_vectors(vectors, xlabel='Time', ylabel='Value', fs=None,
                        linestyle='-', linewidth=1, legend=None,
                        ax=None, title=None):
    """Plot each vector against its sample index (not implemented).

    This function currently always raises NotImplementedError. The code after
    the ``raise`` is an unreachable draft kept for reference: it would plot
    every vector against ``arange(len(v))``, divided by `fs` when `fs` is
    given.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError

    # ---- unreachable draft below ----
    values = list(vectors)
    if fs is None:
        times = [np.arange(len(vec)) for vec in values]
    else:
        times = [np.arange(len(vec)) / fs for vec in values]
    plot_timeseries(times, values, xlabel, ylabel, legend=legend,
                    linestyle=linestyle, linewidth=linewidth,
                    ax=ax, title=title)
示例#3
0
def strf_timeseries(modelspec,
                    ax=None,
                    clim=None,
                    show_factorized=True,
                    show_fir_only=True,
                    title=None,
                    fs=1,
                    chans=None):
    """
    Plot the STRF of `modelspec` as gain vs. time lag.

    The STRF matrix is built from the model's weight-channel (wc) and FIR
    coefficients:

    - neither available: log a warning and draw nothing;
    - only wc available: the transposed wc coefficients;
    - FIR available and (``show_fir_only`` or no wc): the FIR coefficients;
    - both available otherwise: ``wc.T @ fir`` when the inner dimensions
      agree, else the FIR coefficients alone.

    The matrix is passed to ``plot_timeseries`` (presumably one trace per
    STRF row), and a thin gray line is drawn at zero gain across the lag
    range.

    modelspec:
       model specification passed to ``_get_wc_coefficients`` and
       ``_get_fir_coefficients``.
    ax: matplotlib Axes or None
       axes forwarded to ``plot_timeseries``; the zero line is drawn on it
       too. When None, the zero line goes to the current pyplot axes.
    clim:
       currently unused; kept for API compatibility.
    show_factorized:
       currently unused; kept for API compatibility.
    show_fir_only: bool
       if True and FIR coefficients exist, plot them without applying the
       wc weights.
    title: str or None
       forwarded to ``plot_timeseries``.
    fs: number
       sampling rate; lag bin k is plotted at time ``k / fs``.
    chans: list
       if not None, label each row of the strf with the corresponding
       channel name
    """
    wcc = _get_wc_coefficients(modelspec)
    firc = _get_fir_coefficients(modelspec)
    if wcc is None and firc is None:
        log.warning('Unable to generate STRF.')
        return
    elif firc is None:
        # Only wc coefficients exist. This must be tested before the
        # show_fir_only branch: with the default show_fir_only=True that
        # branch would build np.array(None), which has no shape[1].
        strf = np.array(wcc).T
    elif show_fir_only or (wcc is None):
        strf = np.array(firc)
    else:
        wc_coefs = np.array(wcc).T
        fir_coefs = np.array(firc)
        if wc_coefs.shape[1] == fir_coefs.shape[0]:
            strf = wc_coefs @ fir_coefs
        else:
            # Shapes don't line up for the weighted product; fall back to
            # the raw FIR filters.
            strf = fir_coefs

    times = np.arange(strf.shape[1]) / fs
    plot_timeseries([times], [strf.T],
                    xlabel='Time lag',
                    ylabel='Gain',
                    legend=chans,
                    linestyle='-',
                    linewidth=1,
                    ax=ax,
                    title=title)
    # Zero-gain reference line. Use the requested axes when given; plt.plot
    # would draw on whichever axes happen to be current.
    target = plt if ax is None else ax
    target.plot(times[[0, len(times) - 1]],
                np.array([0, 0]),
                linewidth=0.5,
                color='gray')
示例#4
0
def raster_from_signals(signals, channels=0, xlabel='Time', ylabel='Value',
                        linestyle='-', linewidth=1,
                        ax=None, title=None):
    """Plot one channel of each signal's continuous data (not implemented).

    This function currently always raises NotImplementedError. The code after
    the ``raise`` is an unreachable draft kept for reference: for every signal
    it would take the requested channel of ``as_continuous()`` and plot it
    against time in seconds (``index / s.fs``), labelled
    ``"<signal name> <channel name>"``.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError

    # ---- unreachable draft below ----
    channels = pad_to_signals(signals, channels)

    times, values, legend = [], [], []
    for sig, chan in zip(signals, channels):
        trace = sig.as_continuous()[chan]
        values.append(trace)
        times.append(np.arange(len(trace)) / sig.fs)
        legend.append(sig.name + ' ' + sig.chans[chan])

    plot_timeseries(times, values, xlabel, ylabel, legend=legend,
                    linestyle=linestyle, linewidth=linewidth,
                    ax=ax, title=title)
示例#5
0
def psth_from_raster(times, values, xlabel='Time', ylabel='Value',
                     legend=None, linestyle='-', linewidth=1,
                     ax=None, title=None, facecolor='lightblue',
                     binsize=1):
    """
    Plot the mean of a raster across rows with a shaded +/- SEM band.

    times: array-like
       time of each column of `values`.
    values: 2-D array-like (rows x time)
       raster to summarize; NaNs are ignored. Rows are presumably
       trials/repetitions -- confirm against callers.
    xlabel, ylabel, legend, linestyle, linewidth, title:
       forwarded to ``plot_timeseries`` for the mean trace.
    ax: matplotlib Axes or None
       if given, made the current pyplot axes before drawing.
    facecolor:
       fill color of the SEM band.
    binsize: int
       if > 1, average (ignoring NaNs) each group of `binsize` adjacent
       columns before summarizing; each bin is plotted at the time of
       sample ``int(binsize / 2)`` within the bin. Trailing columns that
       do not fill a whole bin are dropped.

    At each time point the SEM is ``nanstd / sqrt(n)``, where ``n`` is the
    number of finite rows at that same time point.
    """
    # Accept lists as well as arrays; .shape is needed below.
    values = np.asarray(values)

    if binsize > 1:
        # Trim to a whole number of bins so the reshape cannot fail.
        n_bins = values.shape[1] // binsize
        usable = n_bins * binsize
        x = np.reshape(values[:, :usable], [values.shape[0], n_bins, binsize])
        x = np.nanmean(x, axis=2)
        t = times[int(binsize/2):usable:binsize]
    else:
        x = values
        t = times

    m = np.nanmean(x, axis=0)
    # Count finite rows per time point (not just in the first column), so
    # the SEM denominator matches the per-column nanstd.
    n_finite = np.sum(np.isfinite(x), axis=0)
    e = np.nanstd(x, axis=0) / np.sqrt(n_finite)

    if ax is not None:
        plt.sca(ax)

    plt.fill_between(t, m-e, m+e, facecolor=facecolor)

    plot_timeseries([t], [m], xlabel=xlabel, ylabel=ylabel,
                    legend=legend, linestyle=linestyle,
                    linewidth=linewidth, ax=ax, title=title)
示例#6
0
文件: heatmap.py 项目: LBHB/NEMS
def strf_timeseries(modelspec,
                    ax=None,
                    show_factorized=True,
                    show_fir_only=True,
                    title=None,
                    fs=1,
                    chans=None,
                    colors=None,
                    **options):
    """
    Plot the STRF of `modelspec` as gain vs. time lag, one line per row.

    The STRF matrix is built from the model's weight-channel (wc) and FIR
    coefficients:

    - neither available: log a warning and draw nothing;
    - only wc available: the transposed wc coefficients;
    - FIR available and (``show_fir_only`` or no wc): the FIR coefficients;
    - both available otherwise: ``wc.T @ fir`` when the inner dimensions
      agree, else the FIR coefficients alone.

    Each STRF row is passed to ``plot_timeseries`` as its own trace, and a
    thin gray line is drawn at zero gain across the lag range.

    modelspec:
       model specification passed to ``_get_wc_coefficients`` and
       ``_get_fir_coefficients``.
    ax: matplotlib Axes or None
       axes forwarded to ``plot_timeseries``. When None, the zero line (and
       the factorized overlay, if drawn) go to the current pyplot axes.
    show_factorized: bool
       when True, ``show_fir_only`` is False and both wc and FIR
       coefficients exist, also overlay the raw FIR filters (dashed) and
       plot the wc weights to the left of lag 0.
    show_fir_only: bool
       if True and FIR coefficients exist, plot them without applying the
       wc weights.
    title: str or None
       forwarded to ``plot_timeseries``.
    fs: number
       sampling rate; lag bin k is plotted at time ``k / fs``.
    chans: list
       if not None, label each row of the strf with the corresponding
       channel name
    colors: list of RGB triples or None
       per-row colors forwarded to ``plot_timeseries``; when None a default
       palette is chosen from the number of STRF rows.
    **options:
       currently unused; accepted for API compatibility.
    """
    wcc = _get_wc_coefficients(modelspec)
    firc = _get_fir_coefficients(modelspec)
    if wcc is None and firc is None:
        log.warning('Unable to generate STRF.')
        return
    elif firc is None:
        # Only wc coefficients exist. This must be tested before the
        # show_fir_only branch: with the default show_fir_only=True that
        # branch would build np.array(None), which has no shape[1].
        strf = np.array(wcc).T
    elif show_fir_only or (wcc is None):
        strf = np.array(firc)
    else:
        wc_coefs = np.array(wcc).T
        fir_coefs = np.array(firc)
        if wc_coefs.shape[1] == fir_coefs.shape[0]:
            strf = wc_coefs @ fir_coefs
        else:
            # Shapes don't line up for the weighted product; fall back to
            # the raw FIR filters.
            strf = fir_coefs

    n_lags = strf.shape[1]
    lag_times = np.arange(n_lags) / fs
    # One (shared) time vector and one filter row per STRF row.
    times = [lag_times] * strf.shape[0]
    filters = list(strf)
    if colors is None:
        # NOTE(review): at most 5 default colors are defined; if
        # plot_timeseries needs one color per row, STRFs with > 5 rows
        # may need a caller-supplied palette -- confirm.
        if strf.shape[0] == 1:
            colors = [[0, 0, 0]]
        elif strf.shape[0] == 2:
            colors = [[254 / 255, 15 / 255, 6 / 255],
                      [129 / 255, 201 / 255, 224 / 255]]
        elif strf.shape[0] == 3:
            colors = [[254 / 255, 15 / 255, 6 / 255],
                      [217 / 255, 217 / 255, 217 / 255],
                      [129 / 255, 201 / 255, 224 / 255]]
        elif strf.shape[0] > 3:
            colors = [[254 / 255, 15 / 255, 6 / 255],
                      [217 / 255, 217 / 255, 217 / 255],
                      [129 / 255, 201 / 255, 224 / 255],
                      [128 / 255, 128 / 255, 128 / 255],
                      [32 / 255, 32 / 255, 32 / 255]]
    _, strf_h = plot_timeseries(times,
                                filters,
                                xlabel='Time lag',
                                ylabel='Gain',
                                legend=chans,
                                linestyle='-',
                                linewidth=1,
                                ax=ax,
                                title=title,
                                colors=colors)
    # Zero-gain reference line. Use the requested axes when given; plt.plot
    # would draw on whichever axes happen to be current.
    target = plt if ax is None else ax
    target.plot(lag_times[[0, n_lags - 1]],
                np.array([0, 0]),
                linewidth=0.5,
                color='gray')

    # The overlay needs both coefficient sets; without either, the
    # wcc.shape / firc.T accesses below cannot work.
    if (show_factorized and not show_fir_only
            and wcc is not None and firc is not None):
        if ax is None:
            # The overlay calls Axes methods directly.
            ax = plt.gca()
        # Coefficients may come back as lists; normalize to arrays.
        wc_arr = np.asarray(wcc)
        fir_arr = np.asarray(firc)
        wcN = wc_arr.shape[0]

        ax.set_prop_cycle(None)
        # A single lag vector shared by all FIR columns (not the per-row
        # `times` list).
        _, fir_h = plot_timeseries([lag_times], [fir_arr.T],
                                   xlabel='Time lag',
                                   ylabel='Gain',
                                   legend=chans,
                                   linestyle='--',
                                   linewidth=1,
                                   ax=ax,
                                   title=title)

        ax.set_prop_cycle(None)
        # wc weights are drawn at x = -wcN .. -1, left of lag 0.
        weight_x = np.arange(-1 * wcN, 0)
        w_h = ax.plot(weight_x, wc_arr)
        # Zero line spanning the weight region (two endpoints, like the
        # lag-region zero line above).
        ax.plot(weight_x[[0, -1]], np.array([0, 0]),
                linewidth=0.5, color='gray')
        # NOTE(review): limits/ticks below are in lag *bins*, which match
        # lag_times only when fs == 1 -- confirm intended units.
        ax.set_xlim((-1 * wcN, n_lags))
        strf_l = ['Weighted FIR {}'.format(n + 1) for n in range(wcN)]
        fir_l = ['Raw FIR {}'.format(n + 1) for n in range(wcN)]
        ax.legend(strf_h + fir_h, strf_l + fir_l, loc=1, fontsize='x-small')
        ax.set_xticks(
            np.hstack((np.arange(-1 * wcN, 0), np.arange(0, n_lags + 1, 2))))
        ax.set_xticklabels(
            np.hstack((np.arange(1, wcN + 1), np.arange(0, n_lags + 1, 2))))