Example 1
 def test_bincount_2d(self):
     '''By Olivier to test Bincount2D, moved to test_processing.py by Berk'''
     # first test simple with indices
     x = np.array([0, 1, 1, 2, 2, 3, 3, 3])
     y = np.array([3, 2, 2, 1, 1, 0, 0, 0])
     r, xscale, yscale = processing.bincount2D(x, y, xbin=1, ybin=1)
     r_ = np.zeros_like(r)
     # sometimes life would have been simpler in C:
     for ix, iy in zip(x, y):
         r_[iy, ix] += 1
     self.assertTrue(np.all(np.equal(r_, r)))
     # test with negative values
     y = np.array([3, 2, 2, 1, 1, 0, 0, 0]) - 5
     r, xscale, yscale = processing.bincount2D(x, y, xbin=1, ybin=1)
     self.assertTrue(np.all(np.equal(r_, r)))
     # test unequal bins
     r, xscale, yscale = processing.bincount2D(x / 2, y / 2, xbin=1, ybin=2)
     r_ = np.zeros_like(r)
     for ix, iy in zip(np.floor(x / 2), np.floor((y / 2 + 2.5) / 2)):
         r_[int(iy), int(ix)] += 1
     self.assertTrue(np.all(r_ == r))
     # test with weights
     w = np.ones_like(x) * 2
     r, xscale, yscale = processing.bincount2D(x / 2, y / 2, xbin=1, ybin=2, weights=w)
     self.assertTrue(np.all(r_ * 2 == r))
     # test aggregation instead of binning
     x = np.array([0, 1, 1, 2, 2, 4, 4, 4])
     y = np.array([4, 2, 2, 1, 1, 0, 0, 0])
     r, xscale, yscale = processing.bincount2D(x, y)
     self.assertTrue(np.all(xscale == yscale) and np.all(xscale == np.array([0, 1, 2, 4])))
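For reference, the unit-bin case exercised above can be reproduced with plain NumPy. A minimal sketch (not part of the test suite), assuming bincount2D acts as a 2D histogram over unit bins:

import numpy as np

x = np.array([0, 1, 1, 2, 2, 3, 3, 3])
y = np.array([3, 2, 2, 1, 1, 0, 0, 0])
# unit-sized bin edges covering every integer value
xedges = np.arange(x.min(), x.max() + 2)
yedges = np.arange(y.min(), y.max() + 2)
r, _, _ = np.histogram2d(y, x, bins=(yedges, xedges))
# r[iy, ix] counts the (x, y) pairs per bin, matching the loop-built r_ above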
Example 2
def bin_types(spikes, trials, wheel):
    T_BIN = 0.2  # [sec]
    # TO GET MEAN: bincount2D(..., weights=positions) / bincount2D(..., weights=None)
    reward_times = trials['feedback_times'][trials['feedbackType'] == 1]
    trial_start_times = trials['intervals'][:, 0]
    # trial_end_times = trials['intervals'][:, 1]  # not working as there are nans
    # compute raster map as a function of cluster number

    R1, times1, _ = bincount2D(spikes['times'],
                               spikes['clusters'],
                               T_BIN,
                               weights=spikes['amps'])
    R2, times2, _ = bincount2D(reward_times, np.array([0] * len(reward_times)),
                               T_BIN)
    R3, times3, _ = bincount2D(trial_start_times,
                               np.array([0] * len(trial_start_times)), T_BIN)
    R4, times4, _ = bincount2D(wheel['times'],
                               np.array([0] * len(wheel['times'])),
                               T_BIN,
                               weights=wheel['position'])
    R5, times5, _ = bincount2D(wheel['times'],
                               np.array([0] * len(wheel['times'])),
                               T_BIN,
                               weights=wheel['velocity'])
    #R6, times6, _ = bincount2D(trial_end_times, np.array([0]*len(trial_end_times)), T_BIN)
    start = max(times1[0], times2[0], times3[0], times4[0], times5[0])
    stop = min(times1[-1], times2[-1], times3[-1], times4[-1], times5[-1])
    time_points = np.linspace(start, stop, int((stop - start) / T_BIN))
    binned_data = {}
    binned_data['wheel_position'] = np.interp(time_points, wheel['times'],
                                              wheel['position'])
    binned_data['wheel_velocity'] = np.interp(time_points, wheel['times'],
                                              wheel['velocity'])
    binned_data['summed_spike_amps'] = R1[:,
                                          find_nearest(times1, start):
                                          find_nearest(times1, stop)]
    binned_data['reward_event'] = R2[
        0, find_nearest(times2, start):find_nearest(times2, stop)]
    binned_data['trial_start_event'] = R3[
        0, find_nearest(times3, start):find_nearest(times3, stop)]
    # binned_data['trial_end_event']=R6[0,find_nearest(times6,start):find_nearest(times6,stop)]
    # np.vstack([R1,R2,R3,R4])
    return binned_data
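The mean-position trick noted in the comment at the top of this function can be made concrete: a weighted bincount divided by an unweighted one yields the per-bin average of the weights. A sketch with hypothetical wheel samples (bincount2D imported from brainbox.processing, as elsewhere in these examples):

import numpy as np
from brainbox.processing import bincount2D

times = np.array([0.05, 0.10, 0.15, 0.30, 0.35])
positions = np.array([1.0, 2.0, 3.0, 4.0, 6.0])
rows = np.zeros_like(times)  # single row, as in the wheel calls above

summed, t, _ = bincount2D(times, rows, xbin=0.2, weights=positions)
counts, _, _ = bincount2D(times, rows, xbin=0.2)
with np.errstate(divide='ignore', invalid='ignore'):
    mean_pos = summed / counts  # NaN/inf where a bin holds no samples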
Example 3
def _bin_window_licks(lick_times, trials_df):
    """
    Helper function to bin and window the lick times and get them into trials df for plotting

    :param lick_times: np.array, timestamps of lick events
    :param trials_df: pd.DataFrame, with column 'feedback_times' (time of feedback for each trial)
    :returns: pd.DataFrame with binned, windowed lick times for plotting
    """
    # Bin the licks
    lick_bins, bin_times, _ = bincount2D(lick_times, np.ones(len(lick_times)),
                                         T_BIN)
    lick_bins = np.squeeze(lick_bins)
    start_window, end_window = plt_window(trials_df['feedback_times'])
    # Translating the time window into an index window
    try:
        start_idx = insert_idx(bin_times, start_window)
    except ValueError:
        logger.error('Lick time stamps are outside of the trials windows')
        raise
    end_idx = np.array(start_idx + int(WINDOW_LEN / T_BIN), dtype='int64')
    # Get the binned licks for each window
    trials_df['lick_bins'] = [
        lick_bins[start_idx[i]:end_idx[i]] for i in range(len(start_idx))
    ]
    # Remove windows that exceed the length of the binned lick train
    trials_df['end_idx'] = end_idx
    trials_df = trials_df[trials_df['end_idx'] <= len(lick_bins)]
    return trials_df
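insert_idx is an external helper not shown here. A minimal stand-in, under the assumption that it maps each window start time onto an index into the vector of bin times and raises when a window falls outside the binned range:

import numpy as np

def insert_idx_sketch(bin_times, start_window):
    # index of the first bin at or after each window start
    idx = np.searchsorted(bin_times, start_window)
    if np.any(idx >= len(bin_times)):
        raise ValueError('Lick time stamps are outside of the trials windows')
    return idx.astype('int64')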
Example 4
def line_fr_plot(spike_depths,
                 spike_times,
                 chn_coords,
                 d_bin=10,
                 display=False):
    """
    Prepare data for 1D line plot of average firing rate across depth

    :param spike_depths:
    :param spike_times:
    :param chn_coords:
    :param d_bin: depth bin to average across (see also brainbox.processing.bincount2D)
    :param display:
    :return:
    """
    t_bin = np.max(spike_times)
    n, x, y = bincount2D(spike_times,
                         spike_depths,
                         t_bin,
                         d_bin,
                         ylim=[0, np.max(chn_coords[:, 1])])
    mean_fr = n[:, 0] / t_bin

    data = LinePlot(x=mean_fr, y=y)
    data.set_xlim((0, np.max(mean_fr)))
    data.set_labels(title='Avg Firing Rate',
                    xlabel='Firing Rate (Hz)',
                    ylabel='Distance from probe tip (um)')

    if display:
        fig, ax = plot_line(data.convert2dict())
        return data.convert2dict(), fig, ax

    return data
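The t_bin = np.max(spike_times) trick above collapses the time axis into a single bin, so n[:, 0] holds each depth bin's total spike count and dividing by the recording duration converts counts to Hz. A toy check of the arithmetic (illustration only):

import numpy as np

n = np.array([[12], [30], [0]])  # counts per depth bin (shape: depths x 1 time bin)
duration = 60.0                  # np.max(spike_times), in seconds
mean_fr = n[:, 0] / duration     # -> array([0.2, 0.5, 0. ]) Hz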
Example 5
    def get_fr_img(self):
        if not self.spike_data_status:
            data_img = None
            return data_img
        else:
            T_BIN = 0.05
            D_BIN = 5
            n, times, depths = bincount2D(
                self.spikes['times'][self.spike_idx][self.kp_idx],
                self.spikes['depths'][self.spike_idx][self.kp_idx],
                T_BIN,
                D_BIN,
                ylim=[0, np.max(self.chn_coords[:, 1])])
            img = n.T / T_BIN
            xscale = (times[-1] - times[0]) / img.shape[0]
            yscale = (depths[-1] - depths[0]) / img.shape[1]

            data_img = {
                'img': img,
                'scale': np.array([xscale, yscale]),
                'levels': np.quantile(np.mean(img, axis=0), [0, 1]),
                'offset': np.array([0, 0]),
                'xrange': np.array([times[0], times[-1]]),
                'xaxis': 'Time (s)',
                'cmap': 'binary',
                'title': 'Firing Rate'
            }

            return data_img
Example 6
 def _bin_spike_trains(self):
     """
     Bins spike times passed to class at instantiation. Will not bin spike trains which did
     not meet the criteria for minimum number of spiking trials. Must be run before the
     NeuralGLM.fit() method is called.
     """
     spkarrs = []
     arrdiffs = []
     for i in self.trialsdf.index:
         duration = self.trialsdf.loc[i, 'duration']
         durmod = duration % self.binwidth
         if durmod > (self.binwidth / 2):
             duration = duration - (self.binwidth / 2)
         if len(self.spikes[i]) == 0:
             arr = np.zeros((self.binf(duration), len(self.clu_ids)))
             spkarrs.append(arr)
             continue
         spks = self.spikes[i]
         clu = self.clu[i]
         arr = bincount2D(spks,
                          clu,
                          xbin=self.binwidth,
                          ybin=self.clu_ids,
                          xlim=[0, duration])[0]
         arrdiffs.append(arr.shape[1] - self.binf(duration))
         spkarrs.append(arr.T)
     y = np.vstack(spkarrs)
     if hasattr(self, 'dm'):
         assert y.shape[0] == self.dm.shape[0], "Indexing error: binned spikes and design matrix differ in length."
     self.binnedspikes = y
     return
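self.binf is set at class construction and is not shown in this snippet. A stand-in consistent with its use above, i.e. the number of bins needed to cover a duration (the binwidth value is hypothetical):

import numpy as np

def binf(duration, binwidth=0.02):
    # number of bins of size binwidth covering `duration` seconds
    return int(np.ceil(duration / binwidth))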
Example 7
def spike_sorting_metrics(spike_times,
                          spike_clusters,
                          spike_amplitudes,
                          params=METRICS_PARAMS,
                          epochs=None):
    """ Spike sorting QC metrics """
    cluster_ids = np.unique(spike_clusters)
    nclust = cluster_ids.size
    r = Bunch({
        'cluster_id': cluster_ids,
        'num_spikes': np.zeros(nclust, ) + np.nan,
        'firing_rate': np.zeros(nclust, ) + np.nan,
        'presence_ratio': np.zeros(nclust, ) + np.nan,
        'presence_ratio_std': np.zeros(nclust, ) + np.nan,
        'isi_viol': np.zeros(nclust, ) + np.nan,
        'amplitude_cutoff': np.zeros(nclust, ) + np.nan,
        'amplitude_std': np.zeros(nclust, ) + np.nan,
        # 'isolation_distance': np.zeros(nclust, ) + np.nan,
        # 'l_ratio': np.zeros(nclust, ) + np.nan,
        # 'd_prime': np.zeros(nclust, ) + np.nan,
        # 'nn_hit_rate': np.zeros(nclust, ) + np.nan,
        # 'nn_miss_rate': np.zeros(nclust, ) + np.nan,
        # 'silhouette_score': np.zeros(nclust, ) + np.nan,
        # 'max_drift': np.zeros(nclust, ) + np.nan,
        # 'cumulative_drift': np.zeros(nclust, ) + np.nan,
        'epoch_name': np.zeros(nclust, dtype='object'),
    })

    tmin = 0
    tmax = spike_times[-1]
    """computes basic metrics such as spike rate and presence ratio"""
    presence_ratio = bincount2D(spike_times,
                                spike_clusters,
                                xbin=params['presence_bin_length_secs'],
                                ybin=cluster_ids,
                                xlim=[tmin, tmax])[0]
    r.num_spikes = np.sum(presence_ratio > 0, axis=1)
    r.firing_rate = r.num_spikes / (tmax - tmin)
    r.presence_ratio = np.sum(presence_ratio > 0,
                              axis=1) / presence_ratio.shape[1]
    r.presence_ratio_std = np.std(presence_ratio, axis=1)

    # loop over each cluster
    for ic in np.arange(nclust):
        # slice the spike_times array
        ispikes = spike_clusters == cluster_ids[ic]
        st = spike_times[ispikes]
        sa = spike_amplitudes[ispikes]
        # compute metrics
        r.isi_viol[ic], _ = isi_violations(
            st,
            tmin,
            tmax,
            isi_threshold=params['isi_threshold'],
            min_isi=params['min_isi'])
        r.amplitude_cutoff[ic] = amplitude_cutoff(amplitudes=sa)
        r.amplitude_std[ic] = np.std(sa)

    return pd.DataFrame(r)
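The presence ratio computed above is simply the fraction of time bins in which a cluster fired at least once. In isolation, with a toy count matrix:

import numpy as np

counts = np.array([[0, 2, 1, 0],    # cluster A: active in 2 of 4 bins
                   [3, 1, 2, 4]])   # cluster B: active in all 4 bins
presence_ratio = np.sum(counts > 0, axis=1) / counts.shape[1]
# -> array([0.5, 1. ])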
Example 8
def line_amp_plot(spike_amps,
                  spike_depths,
                  spike_times,
                  chn_coords,
                  d_bin=10,
                  display=False,
                  title=None,
                  **kwargs):
    """
    Prepare data for 1D line plot of average firing rate across depth
    :param spike_amps:
    :param spike_depths:
    :param spike_times:
    :param chn_coords:
    :param d_bin: depth bin to average across (see also brainbox.processing.bincount2D)
    :param display:
    :return:
    """
    title = title or 'Avg Amplitude'
    t_bin = np.max(spike_times)
    n, _, _ = bincount2D(spike_times,
                         spike_depths,
                         t_bin,
                         d_bin,
                         ylim=[0, np.max(chn_coords[:, 1])])
    amp, x, y = bincount2D(spike_times,
                           spike_depths,
                           t_bin,
                           d_bin,
                           ylim=[0, np.max(chn_coords[:, 1])],
                           weights=spike_amps)

    mean_amp = np.divide(amp[:, 0], n[:, 0]) * 1e6
    mean_amp[np.isnan(mean_amp)] = 0
    remove_bins = np.where(n[:, 0] < 50)[0]
    mean_amp[remove_bins] = 0

    data = LinePlot(x=mean_amp, y=y)
    data.set_xlim((0, np.max(mean_amp)))
    data.set_labels(title=title,
                    xlabel='Amplitude (uV)',
                    ylabel='Distance from probe tip (um)')
    if display:
        ax, fig = plot_line(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax
    return data
Example 9
def lick_raster(eid, combine=True):

    #plt.figure(figsize=(4,4))

    T_BIN = 0.02
    rt = 2  # window length (s)
    st = -0.5  # window start relative to feedback (s)

    if combine:
        # combine licking events from left and right cam
        lick_times = []
        for video_type in ['right', 'left']:
            times, XYs = get_dlc_XYs(eid, video_type)
            lick_times.append(times[get_licks(XYs)])

        lick_times = sorted(np.concatenate(lick_times))

    else:
        times, XYs = get_dlc_XYs(eid, 'left')
        lick_times = times[get_licks(XYs)]

    R, t, _ = bincount2D(lick_times, np.ones(len(lick_times)), T_BIN)
    D = R[0]

    # that's centered at feedback time
    d = constant_reaction_time(eid, rt, st, stype='feedback')

    licks_pos = []
    licks_neg = []

    for i in d:

        start_idx = find_nearest(t, d[i][0])
        end_idx = start_idx + int(d[i][1] / T_BIN)

        if end_idx > len(D):
            break

        # split by feedback type
        if d[i][5] == 1:
            licks_pos.append(D[start_idx:end_idx])

    licks_pos_ = np.array(licks_pos).mean(axis=0)

    y_dims, x_dims = len(licks_pos), len(licks_pos[0])
    plt.imshow(licks_pos,
               aspect='auto',
               extent=[-0.5, 1.5, y_dims, 0],
               cmap='gray_r')

    ax = plt.gca()
    ax.set_xticks([-0.5, 0, 0.5, 1, 1.5])
    ax.set_xticklabels([-0.5, 0, 0.5, 1, 1.5])
    plt.ylabel('trials')
    plt.xlabel('time [sec]')
    ax.axvline(x=0, label='feedback time', linestyle='--', c='r')
    plt.title('lick events per correct trial')
    plt.tight_layout()
Example 10
def get_stim_aligned_activity(stim_events, spike_times, spike_depths, z_score_flag=True, d_bin=20,
                              t_bin=0.01, pre_stim=0.4, post_stim=1, base_stim=1,
                              y_lim=[0, 3840], x_lim=None):
    """

    Parameters
    ----------
    stim_events: dict of different stim events. Each key contains time of stimulus onset
    spike_times: array of spike times
    spike_depths: array of spike depths along probe
    z_score_flag: whether to return values as z_score of firing rate
    t_bin: bin size along time dimension
    d_bin: bin size along depth dimension
    pre_stim: time period before rf map stim onset to epoch around
    post_stim: time period after rf map onset to epoch around
    base_stim: time period before rf map stim to use as baseline for z_score correction
    y_lim: values to limit to in depth direction
    x_lim: values to limit in time direction

    Returns
    -------
    stim_activity: stimulus aligned activity for each stimulus type, returned as z_score of firing
    rate
    """

    binned_array, times, depths = bincount2D(spike_times, spike_depths, t_bin, d_bin,
                                             ylim=y_lim, xlim=x_lim)
    n_bins = int((pre_stim + post_stim) / t_bin)
    n_bins_base = int(np.ceil((base_stim - pre_stim) / t_bin))

    stim_activity = {}
    for stim_type, stim_times in stim_events.items():

        stim_intervals = np.c_[stim_times - pre_stim, stim_times + post_stim]
        base_intervals = np.c_[stim_times - base_stim, stim_times - pre_stim]
        idx_stim = np.searchsorted(times, stim_intervals)
        idx_base = np.searchsorted(times, base_intervals)

        stim_trials = np.zeros((depths.shape[0], n_bins, idx_stim.shape[0]))
        noise_trials = np.zeros((depths.shape[0], n_bins_base, idx_stim.shape[0]))
        for i, (st, ba) in enumerate(zip(idx_stim, idx_base)):
            stim_trials[:, :, i] = binned_array[:, st[0]:st[1]]
            noise_trials[:, :, i] = binned_array[:, ba[0]:ba[1]]

        # Average across trials
        avg_stim_trials = np.mean(stim_trials, axis=2)
        if z_score_flag:
            # Average across trials and time
            avg_base_trials = np.mean(np.mean(noise_trials, axis=2), axis=1)[:, np.newaxis]
            std_base_trials = np.std(np.mean(noise_trials, axis=2), axis=1)[:, np.newaxis]
            z_score = (avg_stim_trials - avg_base_trials) / std_base_trials
            z_score[np.isnan(z_score)] = 0
            avg_stim_trials = z_score

        stim_activity[stim_type] = avg_stim_trials

    return stim_activity
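The z-scoring step expresses stimulus-aligned activity in units of the baseline's standard deviation. With toy numbers for a single depth bin (illustration only):

import numpy as np

avg_stim = np.array([2.0, 4.0])  # trial-averaged activity in two time bins
base_mean, base_std = 1.0, 0.5   # from the pre-stimulus baseline window
z = (avg_stim - base_mean) / base_std
# -> array([2., 6.]): 2 and 6 baseline SDs above the baseline mean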
Example 11
def estimate_drift(spike_times, spike_amps, spike_depths, display=False):
    """
    Estimate drift for spike sorted data.
    :param spike_times:
    :param spike_amps:
    :param spike_depths:
    :param display:
    :return:
    """
    # binning parameters
    DT_SECS = 1  # output sampling rate of the depth estimation (seconds)
    DEPTH_BIN_UM = 2  # binning parameter for depth
    AMP_RES_V = 100 * 1e-6  # binning parameter for amplitudes
    NXCORR = 50  # positive and negative lag in depth samples to look for depth
    NT_SMOOTH = 9  # length of the Gaussian smoothing window in samples (DT_SECS rate)

    # experimental: try the amp with a log scale
    na = int(np.ceil(np.nanmax(spike_amps) / AMP_RES_V))
    nd = int(np.ceil(np.nanmax(spike_depths) / DEPTH_BIN_UM))
    nt = int(np.ceil(np.max(spike_times) / DT_SECS))

    # 3d histogram of spikes along amplitude, depths and time
    atd_hist = np.zeros((na, nt, nd))
    abins = np.ceil(spike_amps / AMP_RES_V)
    for i, abin in enumerate(np.unique(abins)):
        inds = np.where(np.logical_and(abins == abin,
                                       ~np.isnan(spike_depths)))[0]
        a, _, _ = bincount2D(spike_depths[inds], spike_times[inds],
                             DEPTH_BIN_UM, DT_SECS, [0, nd * DEPTH_BIN_UM],
                             [0, nt * DT_SECS])
        atd_hist[i] = a[:-1, :-1]

    # compute the depth lag by xcorr
    # experimental: LP the fft for a better tracking ?
    atd_ = np.fft.fft(atd_hist, axis=-1)
    xcorr = np.real(
        np.fft.ifft(atd_ * np.conj(np.median(atd_, axis=1))[:, np.newaxis, :]))
    xcorr = np.sum(xcorr, axis=0)
    xcorr = np.c_[xcorr[:, -NXCORR:], xcorr[:, :NXCORR + 1]]

    # experimental: parabolic fit to get max values
    raw_drift = (np.argmax(xcorr, axis=-1) - NXCORR) * DEPTH_BIN_UM
    drift = smooth.rolling_window(raw_drift,
                                  window_len=NT_SMOOTH,
                                  window='hanning')

    if display:
        import matplotlib.pyplot as plt
        from brainbox.plot import driftmap
        _, axs = plt.subplots(2,
                              1,
                              gridspec_kw={'height_ratios': [.15, .85]},
                              sharex=True)
        axs[0].plot(DT_SECS * np.arange(drift.size), drift)
        driftmap(spike_times, spike_depths, t_bin=0.1, d_bin=5, ax=axs[1])

    return drift
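The depth-lag computation relies on the identity that ifft(fft(a) * conj(fft(b))) is the circular cross-correlation of a and b, so the location of its peak recovers a relative shift. A standalone numerical check (illustration only, not library code):

import numpy as np

rng = np.random.default_rng(0)
a = rng.random(64)
b = np.roll(a, -5)  # b leads a by 5 samples
xcorr = np.real(np.fft.ifft(np.fft.fft(a) * np.conj(np.fft.fft(b))))
lag = np.argmax(xcorr)
# lag == 5: the cross-correlation peak recovers the shift between the signals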
Example 12
def get_rf_map_over_depth(rf_map_times, rf_map_pos, rf_stim_frames, spike_times, spike_depths,
                          t_bin=0.01, d_bin=80, pre_stim=0.05, post_stim=1.5, y_lim=[0, 3840],
                          x_lim=None):
    """
    Compute receptive field map for each stimulus onset binned across depth
    Parameters
    ----------
    rf_map_times: array of times of each receptive field map frame
    rf_map_pos: (n_positions, 2) array of x, y stimulus pixel positions
    rf_stim_frames: dict with 'on' and 'off' keys; for each position, the frame indices at
        which that pixel flipped
    spike_times: array of spike times
    spike_depths: array of spike depths along probe
    t_bin: bin size along time dimension
    d_bin: bin size along depth dimension
    pre_stim: time period before rf map stim onset to epoch around
    post_stim: time period after rf map onset to epoch around
    y_lim: values to limit to in depth direction
    x_lim: values to limit in time direction

    Returns
    -------
    rfmap: receptive field map for 'on' 'off' stimuli.
    Each rfmap has shape (depths, x_pos, y_pos, epoch_window)
    depths: depths between which receptive field map has been computed
    """

    binned_array, times, depths = bincount2D(spike_times, spike_depths, t_bin, d_bin,
                                             ylim=y_lim, xlim=x_lim)

    x_bin = len(np.unique(rf_map_pos[:, 0]))
    y_bin = len(np.unique(rf_map_pos[:, 1]))
    n_bins = int((pre_stim + post_stim) / t_bin)

    rf_map = {}

    for stim_type, stims in rf_stim_frames.items():
        _rf_map = np.zeros(shape=(depths.shape[0], x_bin, y_bin, n_bins))
        for pos, stim_frame in zip(rf_map_pos, stims):

            x_pos = pos[0]
            y_pos = pos[1]

            stim_on_times = rf_map_times[stim_frame[0]]
            stim_intervals = np.c_[stim_on_times - pre_stim, stim_on_times + post_stim]

            idx_intervals = np.searchsorted(times, stim_intervals)

            stim_trials = np.zeros((depths.shape[0], n_bins, idx_intervals.shape[0]))
            for i, on in enumerate(idx_intervals):
                stim_trials[:, :, i] = binned_array[:, on[0]:on[1]]
            avg_stim_trials = np.mean(stim_trials, axis=2)

            _rf_map[:, x_pos, y_pos, :] = avg_stim_trials

        rf_map[stim_type] = _rf_map

    return rf_map, depths
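np.searchsorted applied to an (n, 2) array of [start, stop] times returns the matching (n, 2) pairs of indices into the bin-time vector, which is how the epoching above slices binned_array. A toy check:

import numpy as np

times = np.arange(0, 1, 0.1)
intervals = np.c_[[0.15], [0.55]]
np.searchsorted(times, intervals)
# -> array([[2, 6]]): the slice binned_array[:, 2:6] covers the 0.15-0.55 s window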
Example 13
    def get_fr_amp_data_line(self):
        if not self.spike_data_status:
            data_fr_line = None
            data_amp_line = None
            return data_fr_line, data_amp_line
        else:
            T_BIN = np.max(self.spikes['times'])
            D_BIN = 10
            nspikes, times, depths = bincount2D(
                self.spikes['times'][self.spike_idx][self.kp_idx],
                self.spikes['depths'][self.spike_idx][self.kp_idx],
                T_BIN,
                D_BIN,
                ylim=[0, np.max(self.chn_coords[:, 1])])

            # one time bin spanning the whole recording, weighted by amplitude
            amp, times, depths = bincount2D(
                self.spikes['times'][self.spike_idx][self.kp_idx],
                self.spikes['depths'][self.spike_idx][self.kp_idx],
                T_BIN,
                D_BIN,
                ylim=[0, np.max(self.chn_coords[:, 1])],
                weights=self.spikes['amps'][self.spike_idx][self.kp_idx])
            mean_fr = nspikes[:, 0] / T_BIN
            mean_amp = np.divide(amp[:, 0], nspikes[:, 0]) * 1e6
            mean_amp[np.isnan(mean_amp)] = 0
            remove_bins = np.where(nspikes[:, 0] < 50)[0]
            mean_amp[remove_bins] = 0

            data_fr_line = {
                'x': mean_fr,
                'y': depths,
                'xrange': np.array([0, np.max(mean_fr)]),
                'xaxis': 'Firing Rate (Sp/s)'
            }

            data_amp_line = {
                'x': mean_amp,
                'y': depths,
                'xrange': np.array([0, np.max(mean_amp)]),
                'xaxis': 'Amplitude (uV)'
            }

            return data_fr_line, data_amp_line
Example 14
def firing_rates(spike_times, spike_clusters, bin_size):
    """Return the time-dependent firing rate of a population of neurons.

    :param spike_times: the spike times of all neurons, in seconds
    :param spike_clusters: the cluster numbers of all spikes
    :param bin_size: the bin size, in seconds
    :return: a (n_clusters, n_bins) array with the spike counts of every cluster per bin;
        divide by `bin_size` to obtain rates in Hz

    """
    rates, times, clusters = bincount2D(spike_times, spike_clusters, bin_size)
    return rates
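A minimal usage sketch with synthetic data (bincount2D is assumed importable from brainbox.processing, as in the other examples):

import numpy as np

spike_times = np.array([0.01, 0.02, 0.30, 0.31, 0.32, 0.75])
spike_clusters = np.array([0, 1, 0, 1, 1, 0])
rates = firing_rates(spike_times, spike_clusters, bin_size=0.25)
# one row per cluster, one column per 0.25 s bin of the recording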
Example 15
def check_for_saturation(eid, probes):
    '''
    Reads in the spikes for a given session, bins them into time bins and
    computes how many bins show so little activity across all channels
    that they must be an artefact (saturation).
    '''

    T_BIN = 0.2  # time bin in sec
    ACT_THR = 0.05  # maximal activity for saturated segment
    print('Bin size: %s [sec]' % T_BIN)
    print('Activity threshold: %s [fraction]' % ACT_THR)

    #probes = ['probe00', 'probe01']
    probeDict = {'probe00': 'probe_left', 'probe01': 'probe_right'}

    one = ONE()
    dataset_types = ['spikes.times', 'spikes.clusters']
    D = one.load(eid, dataset_types=dataset_types, dclass_output=True)
    alf_path = Path(D.local_path[0]).parent.parent
    print(alf_path)

    l = []
    for probe in probes:
        probe_path = alf_path / probe
        if not probe_path.exists():
            probe_path = alf_path / probeDict[probe]
            if not probe_path.exists():
                print("% s doesn't exist..." % probe)
                continue
        try:
            spikes = alf.io.load_object(probe_path, 'spikes')
        except Exception:
            continue

        # bin spikes
        R, times, Clusters = bincount2D(spikes['times'], spikes['clusters'],
                                        T_BIN)

        saturated_bins = np.where(np.mean(R, axis=0) < ACT_THR)[0]

        if len(saturated_bins) > 1:
            print('WARNING: Saturation present!')
            print(probe)
            print('Number of saturated bins: %s of %s' %
                  (len(saturated_bins), len(times)))

        l.append(['%s_%s' % (eid, probe), times[saturated_bins]])

    np.save('/home/mic/saturation_scan2/%s.npy' % eid, l)

    return l
Example 16
def image_crosscorr_plot(spike_depths,
                         spike_times,
                         chn_coords,
                         t_bin=0.05,
                         d_bin=40,
                         cmap='viridis',
                         display=False,
                         title=None,
                         **kwargs):
    """
    Prepare data for 2D cross correlation plot of data across depth

    :param spike_depths:
    :param spike_times:
    :param chn_coords:
    :param t_bin: time bin to average across (see also brainbox.processing.bincount2D)
    :param d_bin: depth bin to average across (see also brainbox.processing.bincount2D)
    :param cmap:
    :param display: generate figure
    :return: ImagePlot object, if display=True also returns matplotlib fig and ax objects
    """

    title = title or 'Correlation'
    n, x, y = bincount2D(spike_times,
                         spike_depths,
                         t_bin,
                         d_bin,
                         ylim=[0, np.max(chn_coords[:, 1])])
    corr = np.corrcoef(n)
    corr[np.isnan(corr)] = 0

    data = ImagePlot(corr, x=y, y=y, cmap=cmap)
    data.set_labels(title=title,
                    xlabel='Distance from probe tip (um)',
                    ylabel='Distance from probe tip (um)',
                    clabel='Correlation')

    if display:
        ax, fig = plot_image(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax

    return data
Example 17
def image_fr_plot(spike_depths,
                  spike_times,
                  chn_coords,
                  t_bin=0.05,
                  d_bin=5,
                  cmap='binary',
                  display=False,
                  title=None,
                  **kwargs):
    """
    Prepare data for 2D raster plot of firing rate across the recording

    :param spike_depths:
    :param spike_times:
    :param chn_coords:
    :param t_bin: time bin to average across (see also brainbox.processing.bincount2D)
    :param d_bin: depth bin to average across (see also brainbox.processing.bincount2D)
    :param cmap:
    :param display: generate figure
    :return: ImagePlot object, if display=True also returns matplotlib fig and ax objects
    """

    title = title or 'Firing Rate'
    n, x, y = bincount2D(spike_times,
                         spike_depths,
                         t_bin,
                         d_bin,
                         ylim=[0, np.max(chn_coords[:, 1])])
    fr = n.T / t_bin

    data = ImagePlot(fr, x=x, y=y, cmap=cmap)
    data.set_labels(title=title,
                    xlabel='Time (s)',
                    ylabel='Distance from probe tip (um)',
                    clabel='Firing Rate (Hz)')
    data.set_clim(clim=(np.min(np.mean(fr, axis=0)),
                        np.max(np.mean(fr, axis=0))))
    if display:
        ax, fig = plot_image(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax

    return data
Example 18
def bin_spikes_trials(spikes, trials, bin_size=0.01):
    """
    Binarizes the spike times into a raster and assigns a trial number to each bin

    :param spikes: spikes object
    :type spikes: Bunch
    :param trials: trials object
    :type trials: Bunch
    :param bin_size: size, in s, of the bins
    :type bin_size: float
    :return: a matrix (bins, SpikeCounts), and a vector of bins size with trial ID,
    and a vector bins size with the time that the bins start
    """
    binned_spikes, bin_times, _ = bincount2D(spikes['times'],
                                             spikes['clusters'], bin_size)
    trial_start_times = trials['intervals'][:, 0]
    binned_trialIDs = np.digitize(bin_times, trial_start_times)
    # correct, as index 0 is whatever happens before the first trial
    binned_trialIDs_corrected = binned_trialIDs - 1

    return binned_spikes.T, binned_trialIDs_corrected, bin_times
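How the trial assignment works: np.digitize returns 0 for bins that start before the first trial, hence the -1 correction above. A toy illustration:

import numpy as np

trial_start_times = np.array([1.0, 2.0, 3.0])
bin_times = np.array([0.5, 1.2, 2.7, 3.1])
trial_ids = np.digitize(bin_times, trial_start_times) - 1
# -> array([-1, 0, 1, 2]); -1 marks the period before the first trial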
Example 19
 def get_correlation_data_img(self):
     if not self.spike_data_status:
         data_img = None
         return data_img
     else:
         T_BIN = 0.05
         D_BIN = 40
         R, times, depths = bincount2D(
             self.spikes['times'][self.spike_idx][self.kp_idx],
             self.spikes['depths'][self.spike_idx][self.kp_idx],
             T_BIN,
             D_BIN,
             ylim=[0, np.max(self.chn_coords[:, 1])])
         corr = np.corrcoef(R)
         corr[np.isnan(corr)] = 0
         scale = (np.max(depths) - np.min(depths)) / corr.shape[0]
         data_img = {
             'img': corr,
             'scale': np.array([scale, scale]),
             'levels': np.array([np.min(corr), np.max(corr)]),
             'offset': np.array([0, 0]),
             'xrange': np.array([np.min(self.chn_coords[:, 1]),
                                 np.max(self.chn_coords[:, 1])]),
             'cmap': 'viridis',
             'title': 'Correlation',
             'xaxis': 'Distance from probe tip (um)'
         }
         return data_img
Example 20
def plot_grating_figures(
    session_path, cluster_ids_summary, cluster_ids_selected, save_dir=None, format='png',
        pre_time=0.5, post_time=2.5, bin_size=0.005, smoothing=0.025, n_rand_clusters=20,
        plot_summary=True, plot_selected=True):
    """
    Produces two summary figures for the oriented grating protocol; the first summary figure
    contains plots that compare different measures during the first and second grating protocols,
    such as orientation selectivity index (OSI), orientation preference, fraction of visual
    clusters, PSTHs, firing rate histograms, etc. The second summary figure contains plots of polar
    PSTHs and corresponding rasters for a random subset of visually responsive clusters.

    Parameters
    ----------
    session_path : str
        absolute path to experimental session directory
    cluster_ids_summary : list
        the clusters for which to plot summary psths/rasters; if empty, all clusters with responses
        during the grating presentations are used
    cluster_ids_selected : list
        the clusters for which to plot individual psths/rasters; if empty, `n_rand_clusters` are
        randomly chosen
    save_dir : str or NoneType
        if NoneType, figures are displayed; else a string defining the absolute filepath to the
        directory in which figures will be saved
    format : str
        file format, i.e. 'png' | 'pdf' | 'jpg'
    pre_time : float
        time (sec) to plot before grating presentation onset
    post_time : float
        time (sec) to plot after grating presentation onset (should include length of stimulus)
    bin_size : float
        size of bins for raster plots/psths
    smoothing : float
        size of smoothing kernel (sec)
    n_rand_clusters : int
        the number of random clusters to choose for which to plot psths/rasters if
        `cluster_ids_selected` is empty
    plot_summary : bool
        a flag for plotting the summary figure
    plot_selected : bool
        a flag for plotting the selected units figure

    Returns
    -------
    metrics : dict
        - 'osi' (dict): keys 'beg', 'end' point to arrays of osis during these epochs
        - 'orientation_pref' (dict): keys 'beg', 'end' point to arrays of orientation preference
        - 'frac_resp_by_depth' (dict): fraction of responsive clusters by depth

    fig_dict : dict
        A dict whose values are handles to one or both figures generated.
    """

    fig_dict = {}
    cluster_ids = cluster_ids_summary
    cluster_idxs = cluster_ids_selected
    epochs = ['beg', 'end']

    # -------------------------
    # load required alf objects
    # -------------------------
    print('loading alf objects...', end='', flush=True)
    spikes = ioalf.load_object(session_path, 'spikes')
    clusters = ioalf.load_object(session_path, 'clusters')
    gratings = ioalf.load_object(session_path, '_iblcertif_.odsgratings')
    spontaneous = ioalf.load_object(session_path, '_iblcertif_.spontaneous')
    grating_times = {
        'beg': gratings['odsgratings.times.00'],
        'end': gratings['odsgratings.times.01']}
    grating_vals = {
        'beg': gratings['odsgratings.stims.00'],
        'end': gratings['odsgratings.stims.01']}
    spont_times = {
        'beg': spontaneous['spontaneous.times.00'],
        'end': spontaneous['spontaneous.times.01']}

    # --------------------------
    # calculate relevant metrics
    # --------------------------
    print('calculating mean responses to gratings...', end='', flush=True)
    # calculate mean responses to gratings
    mask_clust = np.isin(spikes.clusters, cluster_ids)  # update mask for responsive clusters
    mask_times = np.full(spikes.times.shape, fill_value=False)
    for epoch in epochs:
        mask_times |= (spikes.times >= grating_times[epoch].min()) & \
                      (spikes.times <= grating_times[epoch].max())
    resp = {epoch: [] for epoch in epochs}
    for epoch in epochs:
        resp[epoch] = are_neurons_responsive(
            spikes.times[mask_clust], spikes.clusters[mask_clust], grating_times[epoch],
            grating_vals[epoch], spont_times[epoch])
    responses = {epoch: [] for epoch in epochs}
    for epoch in epochs:
        responses[epoch] = bin_responses(
            spikes.times[mask_clust], spikes.clusters[mask_clust], grating_times[epoch],
            grating_vals[epoch])
    responses_mean = {epoch: np.mean(responses[epoch], axis=2) for epoch in epochs}
    # responses_se = {epoch: np.std(responses[epoch], axis=2) / np.sqrt(responses[epoch].shape[2])
    #                 for epoch in responses.keys()}
    print('done')

    # calculate osi and orientation preference
    print('calculating osi/orientation preference...', end='', flush=True)
    ori_pref = {epoch: [] for epoch in epochs}
    osi = {epoch: [] for epoch in epochs}
    for epoch in epochs:
        osi[epoch], ori_pref[epoch] = compute_selectivity(
            responses_mean[epoch], np.unique(grating_vals[epoch]), 'ori')
    print('done')

    # calculate depth vs osi ratio (osi_beg/osi_end)
    print('calculating osi ratio as a function of depth...', end='', flush=True)
    depths = np.array([clusters.depths[c] for c in cluster_ids])
    ratios = np.array([osi['beg'][c] / osi['end'][c] for c in range(len(cluster_ids))])
    print('done')

    # calculate fraction of visual neurons by depth
    print('calculating fraction of visual clusters by depth...', end='', flush=True)
    n_bins = 10
    min_depth = np.min(clusters['depths'])
    max_depth = np.max(clusters['depths'])
    depth_limits = np.linspace(min_depth - 1, max_depth, n_bins + 1)
    depth_avg = (depth_limits[:-1] + depth_limits[1:]) / 2
    # aggregate clusters
    clusters_binned = {epoch: [] for epoch in epochs}
    frac_responsive = {epoch: [] for epoch in epochs}
    cids = cluster_ids
    for epoch in epochs:
        # just look at responsive clusters during this epoch
        cids_tmp = cids[resp[epoch]]
        for d in range(n_bins):
            lo_limit = depth_limits[d]
            up_limit = depth_limits[d + 1]
            # clusters.depth index is cluster id
            cids_curr_depth = np.where(
                (lo_limit < clusters.depths) & (clusters.depths <= up_limit))[0]
            clusters_binned[epoch].append(cids_curr_depth)
            frac_responsive[epoch].append(np.sum(
                np.isin(cids_tmp, cids_curr_depth)) / len(cids_curr_depth))
    # package for plotting
    responsive = {'fraction': frac_responsive, 'depth': depth_avg}
    print('done')

    # calculate PSTH averaged over all clusters/orientations
    print('calculating average PSTH...', end='', flush=True)
    peths = {epoch: [] for epoch in epochs}
    peths_avg = {epoch: [] for epoch in epochs}
    for epoch in epochs:
        stim_ids = np.unique(grating_vals[epoch])
        peths[epoch] = {i: None for i in range(len(stim_ids))}
        peths_avg_tmp = []
        for i, stim_id in enumerate(stim_ids):
            curr_stim_idxs = np.where(grating_vals[epoch] == stim_id)
            align_times = grating_times[epoch][curr_stim_idxs, 0][0]
            peths[epoch][i], _ = calculate_peths(
                spikes.times[mask_times], spikes.clusters[mask_times], cluster_ids,
                align_times, pre_time=pre_time, post_time=post_time, bin_size=bin_size,
                smoothing=smoothing, return_fr=True)
            peths_avg_tmp.append(
                np.mean(peths[epoch][i]['means'], axis=0, keepdims=True))
        peths_avg_tmp = np.vstack(peths_avg_tmp)
        peths_avg[epoch] = {
            'mean': np.mean(peths_avg_tmp, axis=0),
            'std': np.std(peths_avg_tmp, axis=0) / np.sqrt(peths_avg_tmp.shape[0])}
    peths_avg['bin_size'] = bin_size
    peths_avg['on_idx'] = int(pre_time / bin_size)
    peths_avg['off_idx'] = peths_avg['on_idx'] + int(2 / bin_size)
    print('done')

    # compute rasters for entire orientation sequence at beg/end epoch
    if plot_summary:
        print('computing rasters for example stimulus sequences...', end='', flush=True)
        r = {epoch: None for epoch in epochs}
        r_times = {epoch: None for epoch in epochs}
        r_clusters = {epoch: None for epoch in epochs}
        for epoch in epochs:
            # restrict activity to a single stim series; assumes each possible grating direction
            # is displayed before repeating
            n_stims = len(np.unique(grating_vals[epoch]))
            mask_idxs_e = (spikes.times >= grating_times[epoch][:n_stims].min()) & \
                          (spikes.times <= grating_times[epoch][:n_stims].max())
            r_tmp, r_times[epoch], r_clusters[epoch] = bincount2D(
                spikes.times[mask_idxs_e], spikes.clusters[mask_idxs_e], bin_size)
            # order activity by anatomical depth of neurons
            d = dict(zip(spikes.clusters[mask_idxs_e], spikes.depths[mask_idxs_e]))
            y = sorted([[i, d[i]] for i in d])
            isort = np.argsort([x[1] for x in y])
            r[epoch] = r_tmp[isort, :]
        # package for plotting
        rasters = {'spikes': r, 'times': r_times, 'clusters': r_clusters, 'bin_size': bin_size}
        print('done')

    # -------------------------------------------------
    # compute psths and rasters for individual clusters
    # -------------------------------------------------
    if plot_selected:
        print('computing psths and rasters for clusters...', end='', flush=True)
        if len(cluster_ids_selected) == 0:
            if (n_rand_clusters < len(cluster_ids)):
                cluster_idxs = np.random.choice(cluster_ids, size=n_rand_clusters, replace=False)
            else:
                cluster_idxs = cluster_ids
        else:
            cluster_idxs = cluster_ids_selected
        mean_responses = {cluster: {epoch: [] for epoch in epochs} for cluster in cluster_idxs}
        osis = {cluster: {epoch: [] for epoch in epochs} for cluster in cluster_idxs}
        binned = {cluster: {epoch: [] for epoch in epochs} for cluster in cluster_idxs}
        for cluster_idx in cluster_idxs:
            cluster = np.where(cluster_ids == cluster_idx)[0]
            for epoch in epochs:
                mean_responses[cluster_idx][epoch] = responses_mean[epoch][cluster, :][0]
                osis[cluster_idx][epoch] = osi[epoch][cluster]
                stim_ids = np.unique(grating_vals[epoch])
                binned[cluster_idx][epoch] = {j: None for j in range(len(stim_ids))}
                for j, stim_id in enumerate(stim_ids):
                    curr_stim_idxs = np.where(grating_vals[epoch] == stim_id)
                    align_times = grating_times[epoch][curr_stim_idxs, 0][0]
                    _, binned[cluster_idx][epoch][j] = calculate_peths(
                        spikes.times[mask_times], spikes.clusters[mask_times], [cluster_idx],
                        align_times, pre_time=pre_time, post_time=post_time, bin_size=bin_size)
        print('done')

    # --------------
    # output figures
    # --------------
    print('producing figures...', end='')
    if plot_summary:
        if save_dir is None:
            save_file = None
        else:
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            save_file = os.path.join(save_dir, 'grating_summary_figure.' + format)
        fig_gr_summary = plot_summary_figure(
            ratios=ratios, depths=depths, responsive=responsive, peths_avg=peths_avg, osi=osi,
            ori_pref=ori_pref, responses_mean=responses_mean, rasters=rasters, save_file=save_file)
        fig_gr_summary.suptitle('Summary Grating Responses')
        fig_dict['gr_summary'] = fig_gr_summary

    if plot_selected:
        if save_dir is None:
            save_file = None
        else:
            save_file = os.path.join(save_dir, 'grating_random_responses.' + format)
        fig_gr_selected = plot_psths_and_rasters(
            mean_responses, binned, osis, grating_vals, on_idx=peths_avg['on_idx'],
            off_idx=peths_avg['off_idx'], bin_size=bin_size, save_file=save_file)
        fig_gr_selected.suptitle('Selected Units Grating Responses')
        print('done')
        fig_dict['gr_selected'] = fig_gr_selected

    # -----------------------------
    # package up and return metrics
    # -----------------------------
    metrics = {
        'osi': osi,
        'orientation_pref': ori_pref,
        'frac_resp_by_depth': responsive,
    }
    return fig_dict, metrics
Example 21
def quick_unit_metrics(spike_clusters,
                       spike_times,
                       spike_amps,
                       spike_depths,
                       params=METRICS_PARAMS,
                       cluster_ids=None,
                       tbounds=None):
    """
    Computes single unit metrics from only the spike times, amplitudes, and
    depths for a set of units.

    Metrics computed:
        'amp_max',
        'amp_min',
        'amp_median',
        'amp_std_dB',
        'contamination',
        'contamination_alt',
        'drift',
        'missed_spikes_est',
        'noise_cutoff',
        'presence_ratio',
        'presence_ratio_std',
        'slidingRP_viol',
        'spike_count'

    Parameters (see the METRICS_PARAMS constant)
    ----------
    spike_clusters : ndarray_like
        A vector of the unit ids for a set of spikes.
    spike_times : ndarray_like
        A vector of the timestamps for a set of spikes.
    spike_amps : ndarray_like
        A vector of the amplitudes for a set of spikes.
    spike_depths : ndarray_like
        A vector of the depths for a set of spikes.
    cluster_ids: (optional) list of cluster ids. If not all clusters are represented in
    spike_clusters (i.e. a cluster has no spikes), this ensures the output size is
    consistent with the input arrays.
    tbounds: (optional) list or 2-element array containing a time selection to perform the
     metrics computation on.
    params : dict (optional)
        Parameters used for computing some of the metrics in the function:
            'presence_window': float
                The time window (in s) used to look for spikes when computing the presence ratio.
            'refractory_period': float
                The refractory period used when computing isi violations and the contamination
                estimate.
            'min_isi': float
                The minimum interspike-interval (in s) for counting duplicate spikes when computing
                the contamination estimate.
            'spks_per_bin_for_missed_spks_est': int
                The number of spikes per bin used to compute the spike amplitude pdf for a unit,
                when computing the missed spikes estimate.
            'std_smoothing_kernel_for_missed_spks_est': float
                The standard deviation for the gaussian kernel used to compute the spike amplitude
                pdf for a unit, when computing the missed spikes estimate.
            'min_num_bins_for_missed_spks_est': int
                The minimum number of bins used to compute the spike amplitude pdf for a unit,
                when computing the missed spikes estimate.

    Returns
    -------
    r : bunch
        A bunch whose keys are the computed spike metrics.

    Notes
    -----
    This function is called by `ephysqc.unit_metrics_ks2` which is called by `spikes.ks2_to_alf`
    during alf extraction of an ephys dataset in the ibl ephys extraction pipeline.

    Examples
    --------
    1) Compute quick metrics from a ks2 output directory:
        >>> from ibllib.ephys.ephysqc import phy_model_from_ks2_path
        >>> m = phy_model_from_ks2_path(path_to_ks2_out)
        >>> cluster_ids = m.spike_clusters
        >>> ts = m.spike_times
        >>> amps = m.amplitudes
        >>> depths = m.depths
        >>> r = bb.metrics.quick_unit_metrics(cluster_ids, ts, amps, depths)
    """
    metrics_list = [
        'cluster_id', 'amp_max', 'amp_min', 'amp_median', 'amp_std_dB',
        'contamination', 'contamination_alt', 'drift', 'missed_spikes_est',
        'noise_cutoff', 'presence_ratio', 'presence_ratio_std',
        'slidingRP_viol', 'spike_count'
    ]
    if tbounds:
        ispi = between_sorted(spike_times, tbounds)
        spike_times = spike_times[ispi]
        spike_clusters = spike_clusters[ispi]
        spike_amps = spike_amps[ispi]
        spike_depths = spike_depths[ispi]

    if cluster_ids is None:
        cluster_ids = np.unique(spike_clusters)
    nclust = cluster_ids.size

    r = Bunch({k: np.full((nclust, ), np.nan) for k in metrics_list})
    r['cluster_id'] = cluster_ids

    # vectorized computation of basic metrics such as presence ratio and firing rate
    tmin = spike_times[0]
    tmax = spike_times[-1]
    presence_ratio = bincount2D(spike_times,
                                spike_clusters,
                                xbin=params['presence_window'],
                                ybin=cluster_ids,
                                xlim=[tmin, tmax])[0]
    r.presence_ratio = np.sum(presence_ratio > 0,
                              axis=1) / presence_ratio.shape[1]
    r.presence_ratio_std = np.std(presence_ratio, axis=1)
    r.spike_count = np.sum(presence_ratio, axis=1)
    r.firing_rate = r.spike_count / (tmax - tmin)

    # computing amplitude statistical indicators by aggregating over cluster id
    camp = pd.DataFrame(np.c_[spike_amps, 20 * np.log10(spike_amps),
                              spike_clusters],
                        columns=['amps', 'log_amps', 'clusters'])
    camp = camp.groupby('clusters')
    ir, ib = ismember(r.cluster_id, camp.clusters.unique())
    r.amp_min[ir] = np.array(camp['amps'].min())
    r.amp_max[ir] = np.array(camp['amps'].max())
    # this is the geometric median
    r.amp_median[ir] = np.array(10**(camp['log_amps'].median() / 20))
    r.amp_std_dB[ir] = np.array(camp['log_amps'].std())

    # loop over each cluster to compute the rest of the metrics
    for ic in np.arange(nclust):
        # slice the spike_times array
        ispikes = spike_clusters == cluster_ids[ic]
        if np.all(~ispikes):  # if this cluster has no spikes, continue
            continue
        ts = spike_times[ispikes]
        amps = spike_amps[ispikes]
        depths = spike_depths[ispikes]

        # compute metrics
        r.contamination_alt[ic] = contamination_alt(
            ts, rp=params['refractory_period'])
        r.contamination[ic], _ = contamination(ts,
                                               tmin,
                                               tmax,
                                               rp=params['refractory_period'],
                                               min_isi=params['min_isi'])
        r.slidingRP_viol[ic] = slidingRP_viol(
            ts,
            bin_size=params['bin_size'],
            thresh=params['RPslide_thresh'],
            acceptThresh=params['acceptable_contamination'])
        r.noise_cutoff[ic] = noise_cutoff(
            amps,
            quartile_length=params['nc_quartile_length'],
            n_bins=params['nc_bins'],
            n_low_bins=params['nc_n_low_bins'])
        r.missed_spikes_est[ic], _, _ = missed_spikes_est(
            amps,
            spks_per_bin=params['spks_per_bin_for_missed_spks_est'],
            sigma=params['std_smoothing_kernel_for_missed_spks_est'],
            min_num_bins=params['min_num_bins_for_missed_spks_est'])

        # wonder if there is a need to low-cut this
        r.drift[ic] = np.sum(np.abs(np.diff(depths))) / (tmax - tmin) * 3600

    r.label = compute_labels(r)
    return r
Example 22
def compute_rfs(spike_times,
                spike_clusters,
                stimulus_times,
                stimulus,
                lags=8,
                binsize=0.025):
    """
    Compute receptive fields from locally sparse noise stimulus for all recorded neurons; uses a
    PSTH-like approach that averages responses from each neuron for each pixel flip
    :param spike_times: array of spike times
    :param spike_clusters: array of cluster ids associated with each entry of `spike_times`
    :param stimulus_times: (M,) array of stimulus presentation times
    :param stimulus: (M, y_pix, x_pix) array of pixel values
    :param lags: temporal dimension of receptive field
    :param binsize: length of each lag (seconds)
    :return: dictionary of "on" and "off" receptive fields (values are lists); each rf is
        [t, y_pix, x_pix]
    """

    from brainbox.processing import bincount2D

    cluster_ids = np.unique(spike_clusters)
    n_clusters = len(cluster_ids)
    _, y_pix, x_pix = stimulus.shape
    stimulus = stimulus.astype('float')
    subs = ['on', 'off']
    rfs = {
        sub: np.zeros(shape=(n_clusters, y_pix, x_pix, lags + 1))
        for sub in subs
    }
    flips = {sub: np.zeros(shape=(y_pix, x_pix)) for sub in subs}

    gray = np.median(stimulus)
    # loop over time points
    for i, t in enumerate(stimulus_times):
        # skip first frame since we're looking for pixels that flipped
        if i == 0:
            continue
        # find pixels that flipped
        frame_change = stimulus[i, :, :] - gray
        ys, xs = np.where((frame_change != 0)
                          & (stimulus[i - 1, :, :] == gray))
        # loop over changed pixels
        for y, x in zip(ys, xs):
            if frame_change[y, x] > 0:  # gray -> white
                sub = 'on'
            else:  # gray -> black
                sub = 'off'
            # bin spikes in the binsize*lags seconds following this flip
            t_beg = t
            t_end = t + binsize * lags
            idxs_t = (spike_times >= t_beg) & (spike_times < t_end)
            binned_spikes, _, cluster_idxs = bincount2D(spike_times[idxs_t],
                                                        spike_clusters[idxs_t],
                                                        xbin=binsize,
                                                        xlim=[t_beg, t_end])
            # insert these binned spikes into the rfs
            _, cluster_idxs, _ = np.intersect1d(cluster_ids,
                                                cluster_idxs,
                                                return_indices=True)
            rfs[sub][cluster_idxs, y, x, :] += binned_spikes
            # record flip
            flips[sub][y, x] += 1

    # normalize spikes by number of flips
    for sub in rfs:
        for y in range(y_pix):
            for x in range(x_pix):
                if flips[sub][y, x] != 0:
                    rfs[sub][:, y, x, :] /= flips[sub][y, x]

    # turn into list
    rfs_list = {}
    for sub in subs:
        rfs_list[sub] = [rfs[sub][i, :, :, :] for i in range(n_clusters)]
    return rfs_list
Example 23
def compute_rfs_corr(spike_times,
                     spike_clusters,
                     stimulus_times,
                     stimulus,
                     lags=8,
                     binsize=0.025):
    """
    Compute receptive fields from locally sparse noise stimulus for all
    recorded neurons; uses a reverse correlation approach
    :param spike_times: array of spike times
    :param spike_clusters: array of cluster ids associated with each entry of `spikes`
    :param stimulus_times: (M,) array of stimulus presentation times
    :param stimulus: (M, y_pix, x_pix) array of pixel values
    :param lags: temporal dimension of receptive field
    :param binsize: length of each lag (seconds)
    :return: dictionary of "on" and "off" receptive fields (values are lists); each
        rf is [t, y_pix, x_pix]
    """

    from brainbox.processing import bincount2D
    from scipy.signal import correlate

    # bin spikes
    indx_t = (spike_times > np.min(stimulus_times)) & \
             (spike_times < np.max(stimulus_times))
    binned_spikes, ts_binned_spikes, cluster_ids = bincount2D(
        spike_times[indx_t], spike_clusters[indx_t], xbin=binsize)
    n_clusters = len(cluster_ids)

    _, y_pix, x_pix = stimulus.shape
    stimulus = stimulus.astype('float')
    gray = np.median(stimulus)

    subs = ['on', 'off']
    rfs = {
        sub: np.zeros(shape=(n_clusters, y_pix, x_pix, lags + 1))
        for sub in subs
    }

    # for indexing output of convolution
    i_end = binned_spikes.shape[1]
    i_beg = i_end - lags

    for sub in subs:
        for y in range(y_pix):
            for x in range(x_pix):
                # find times that pixels flipped
                diffs = np.concatenate([np.diff(stimulus[:, y, x]), [0]])
                if sub == 'on':  # gray -> white
                    changes = (diffs > 0) & (stimulus[:, y, x] == gray)
                else:
                    changes = (diffs < 0) & (stimulus[:, y, x] == gray)
                t_change = np.where(changes)[0]

                # put on same timescale as neural activity
                binned_stim = np.zeros(shape=ts_binned_spikes.shape)
                for t in t_change:
                    stim_time = stimulus_times[t]
                    # find nearest time in binned spikes
                    idx = np.argmin((ts_binned_spikes - stim_time)**2)
                    binned_stim[idx] = 1

                for n in range(n_clusters):
                    # cross correlate signal with spiking activity
                    # NOTE: scipy's correlate function is approx. two orders of
                    # magnitude faster than numpy's correlate function on
                    # relevant data sizes; perhaps scipy uses FFT? Not in docs.
                    cc = correlate(binned_stim,
                                   binned_spikes[n, :],
                                   mode='full')
                    rfs[sub][n, y, x, :] = cc[i_beg:i_end + 1]

    # turn into list
    rfs_list = {}
    for sub in subs:
        rfs_list[sub] = [rfs[sub][i, :, :, :] for i in range(n_clusters)]
    return rfs_list
Example 24
def bin_types(spikes, trials, t_bin, clusters):
    '''
    This creates a dictionary of binned time series,
    all having the same number of observations.

    INPUT:

    spikes: alf.io.load_object(alf_path, 'spikes')
    trials: alf.io.load_object(alf_path, '_ibl_trials')
    t_bin: float, time bin in sec
    clusters: alf.io.load_object(alf_path, 'clusters')

    OUTPUT:

    binned_data['summed_spike_amps']: clusters x observations
    binned_data['reward']: observations
    binned_data['choice']: observations
    binned_data['trial_number']: observations

    '''

    # TO GET MEAN: bincount2D(..., weights=positions) / bincount2D(...,
    # weights=None)

    # Bin spikes (summing together, i.e. not averaging per bin)
    R1, times1, _ = bincount2D(spikes['times'],
                               spikes['clusters'],
                               t_bin,
                               weights=spikes['amps'])

    # Get choice per bin
    R6, times6, _ = bincount2D(trials['goCue_times'], trials['choice'], t_bin)
    # Flatten choice: -1 = left, 1 = right
    R6 = np.sum(R6 * np.array([[-1], [1]]), axis=0)
    R6 = np.expand_dims(R6, axis=0)
    # Fill 0 between trials with choice outcome of trial
    R6 = filliti(R6)
    R6[R6 == -1] = 0
    R6 = R6[0]

    # Get reward per bin
    R7, times7, _ = bincount2D(trials['goCue_times'], trials['feedbackType'],
                               t_bin)
    # Flatten reward: -1 = error, 1 = reward
    R7 = np.sum(R7 * np.array([[-1], [1]]), axis=0)
    R7 = np.expand_dims(R7, axis=0)
    # Fill 0 between trials with reward outcome of trial
    R7 = filliti(R7)
    R7[R7 == -1] = 0
    R7 = R7[0]

    # restrict each time series to the same time bins
    start = max([x for x in [times1[0], times6[0], times7[0]]])
    stop = min([x for x in [times1[-1], times6[-1], times7[-1]]])

    time_points = np.linspace(start, stop, int((stop - start) / t_bin))

    binned_data = {}
    binned_data['summed_spike_amps'] = R1[:,
                                          find_nearest(times1, start):
                                          find_nearest(times1, stop)]
    binned_data['choice'] = R6[find_nearest(times6, start):
                               find_nearest(times6, stop)]
    binned_data['reward'] = R7[find_nearest(times7, start):
                               find_nearest(times7, stop)]
    binned_data['trial_number'] = np.digitize(time_points,
                                              trials['goCue_times'])

    # check lengths again for potential jumps
    chans, obs = binned_data['summed_spike_amps'].shape
    l_choice = len(binned_data['choice'])
    l_reward = len(binned_data['reward'])
    l_trial = len(binned_data['trial_number'])

    MIN = min([obs, l_choice, l_reward, l_trial])

    w = binned_data['summed_spike_amps'][:, :MIN]
    binned_data['summed_spike_amps'] = w
    binned_data['reward'] = binned_data['reward'][:MIN]
    binned_data['choice'] = binned_data['choice'][:MIN]
    binned_data['trial_number'] = binned_data['trial_number'][:MIN]

    print('Range of trials: ',
          [binned_data['trial_number'][0], binned_data['trial_number'][-1]])

    return binned_data
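# Hedged usage sketch for bin_types; the ALF path is a placeholder and the
# alf.io loading mirrors the other examples in this document.
import alf.io

alf_path = '/path/to/session/alf'  # placeholder
spikes = alf.io.load_object(alf_path, 'spikes')
trials = alf.io.load_object(alf_path, '_ibl_trials')
clusters = alf.io.load_object(alf_path, 'clusters')

binned = bin_types(spikes, trials, t_bin=0.02, clusters=clusters)
print(binned['summed_spike_amps'].shape)  # clusters x observations
print(binned['reward'].shape, binned['choice'].shape, binned['trial_number'].shape)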
Example n. 25
0
def estimate_drift(spike_times, spike_amps, spike_depths, display=False):
    """
    Electrode drift for spike sorted data.
    :param spike_times:
    :param spike_amps:
    :param spike_depths:
    :param display:
    :return: drift (ntimes vector) in input units (usually um)
    :return: ts (ntimes vector) time scale in seconds

    """
    # binning parameters
    DT_SECS = 1  # output sampling rate of the depth estimation (seconds)
    DEPTH_BIN_UM = 2  # binning parameter for depth
    AMP_BIN_LOG10 = [1.25, 3.25]  # binning parameter for amplitudes (log10 in uV)
    N_AMP = 1  # number of amplitude bins

    NXCORR = 50  # positive and negative lag in depth samples over which to search for drift
    NT_SMOOTH = 9  # length of the Gaussian smoothing window in samples (DT_SECS rate)

    # experimental: amplitudes are binned on a log10 scale
    nd = int(np.ceil(np.nanmax(spike_depths) / DEPTH_BIN_UM))
    tmin, tmax = (np.min(spike_times), np.max(spike_times))
    nt = int((np.ceil(tmax) - np.floor(tmin)) / DT_SECS)

    # 3d histogram of spikes along amplitude, depths and time
    atd_hist = np.zeros((N_AMP, nt, nd), dtype=np.single)
    abins = (np.log10(spike_amps * 1e6) -
             AMP_BIN_LOG10[0]) / np.diff(AMP_BIN_LOG10) * N_AMP
    abins = np.minimum(np.maximum(0, np.floor(abins)), N_AMP - 1)

    for i, abin in enumerate(np.unique(abins)):
        inds = np.where(np.logical_and(abins == abin,
                                       ~np.isnan(spike_depths)))[0]
        a, _, _ = bincount2D(spike_depths[inds], spike_times[inds],
                             DEPTH_BIN_UM, DT_SECS, [0, nd * DEPTH_BIN_UM],
                             [np.floor(tmin), np.ceil(tmax)])
        atd_hist[i] = a[:-1, :-1]

    fdscale = np.abs(np.fft.fftfreq(nd, d=DEPTH_BIN_UM))
    # k-filter along the depth direction
    lp = dsp.fourier._freq_vector(fdscale, np.array([1 / 16, 1 / 8]), typ='lp')
    # compute the depth lag by cross-correlation
    # to experiment: low-pass the FFT for better tracking?
    atd_ = np.fft.fft(atd_hist, axis=-1)
    # cross-correlation against the median reference
    xcorr = np.real(
        np.fft.ifft(lp * atd_ *
                    np.conj(np.median(atd_, axis=1))[:, np.newaxis, :]))
    xcorr = np.sum(xcorr, axis=0)
    xcorr = np.c_[xcorr[:, -NXCORR:], xcorr[:, :NXCORR + 1]]
    xcorr = xcorr - np.mean(xcorr, 1)[:, np.newaxis]

    # parabolic fit to refine the location of each cross-correlation maximum
    raw_drift = (parabolic_max(xcorr)[0] - NXCORR) * DEPTH_BIN_UM
    drift = smooth.rolling_window(raw_drift,
                                  window_len=NT_SMOOTH,
                                  window='hanning')
    drift = drift - np.mean(drift)
    ts = DT_SECS * np.arange(drift.size)
    if display:
        import matplotlib.pyplot as plt
        from brainbox.plot import driftmap
        fig1, axs = plt.subplots(2,
                                 1,
                                 gridspec_kw={'height_ratios': [.15, .85]},
                                 sharex=True,
                                 figsize=(20, 10))
        axs[0].plot(ts, drift)
        driftmap(spike_times, spike_depths, t_bin=0.1, d_bin=5, ax=axs[1])
        axs[1].set_ylim([-NXCORR * 2, 3840 + NXCORR * 2])
        fig2, axs = plt.subplots(2,
                                 1,
                                 gridspec_kw={'height_ratios': [.15, .85]},
                                 sharex=True,
                                 figsize=(20, 10))
        axs[0].plot(ts, drift)
        dd = np.interp(spike_times, ts, drift)
        driftmap(spike_times, spike_depths - dd, t_bin=0.1, d_bin=5, ax=axs[1])
        axs[1].set_ylim([-NXCORR * 2, 3840 + NXCORR * 2])
        return drift, ts, [fig1, fig2]

    return drift, ts
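# A minimal sketch of applying the drift estimate to correct spike depths,
# assuming `spikes` was loaded as in the other examples in this document.
import numpy as np

drift, ts = estimate_drift(spikes['times'], spikes['amps'], spikes['depths'])
# interpolate the per-second drift onto each spike time and subtract it
corrected_depths = spikes['depths'] - np.interp(spikes['times'], ts, drift)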
Example n. 26
0
T_BIN = 0.01

# get the data from flatiron and the current folder
one = ONE()
eid = one.search(subject='ZM_1150', date='2019-05-07', number=1)
D = one.load(eid[0], clobber=False, download_only=True)
session_path = Path(D.local_path[0]).parent

# load objects
spikes = ioalf.load_object(session_path, 'spikes')
clusters = ioalf.load_object(session_path, 'clusters')
channels = ioalf.load_object(session_path, 'channels')
trials = ioalf.load_object(session_path, '_ibl_trials')

# compute raster map as a function of cluster number
R, times, clusters = bincount2D(spikes['times'], spikes['clusters'], T_BIN)

# plot raster map
plt.imshow(R,
           aspect='auto',
           cmap='binary',
           vmax=T_BIN / 0.001 / 4,
           extent=np.r_[times[[0, -1]], clusters[[0, -1]]],
           origin='lower')
# plot trial start and reward time
reward = trials['feedback_times'][trials['feedbackType'] == 1]
iblplt.vertical_lines(trials['intervals'][:, 0],
                      ymin=0,
                      ymax=clusters[-1],
                      color='k',
                      linewidth=0.5)
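# To finish the figure, one would typically also mark the reward times computed
# above; a minimal sketch (colour and axis labels are illustrative assumptions).
iblplt.vertical_lines(reward, ymin=0, ymax=clusters[-1],
                      color='m', linewidth=0.5)
plt.xlabel('time (s)')
plt.ylabel('cluster id')
plt.show()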
Example n. 27
0
alf_path = '.../ZM_1735/2019-08-01/001/alf'

spikes = alf.io.load_object(alf_path, 'spikes')
clusters = alf.io.load_object(alf_path, 'clusters')
channels = alf.io.load_object(alf_path, 'channels')
trials = alf.io.load_object(alf_path, 'trials')

T_BIN = 0.01  # time bin in sec

# keep only spikes from probe 0, as this session has two probes
probe_id = clusters['probes'][spikes['clusters']]
restrict = np.where(probe_id == 0)[0]

# bin spikes
R, times, Clusters = bincount2D(
    spikes['times'][restrict], spikes['clusters'][restrict], T_BIN)

# Order activity by cortical depth of neurons
d = dict(zip(spikes['clusters'][restrict], spikes['depths'][restrict]))
y = sorted([[i, d[i]] for i in d])
isort = np.argsort([x[1] for x in y])
R = R[isort, :]

# get trial number for each time bin
trial_numbers = np.digitize(times, trials['goCue_times'])
print('Range of trials: ', [trial_numbers[0], trial_numbers[-1]])


def add_stim_off_times(trials):
    on = 'stimOn_times'
    off = 'stimOff_times'
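# The function body is not shown beyond this point; below is a minimal sketch
# of one common way to derive stimulus-off times from feedback times. The
# 1 s / 2 s offsets are illustrative assumptions, not taken from this code.
import numpy as np

def add_stim_off_times_sketch(trials):
    off = np.zeros_like(trials['feedback_times'])
    correct = trials['feedbackType'] == 1
    off[correct] = trials['feedback_times'][correct] + 1.0    # assumed offset
    off[~correct] = trials['feedback_times'][~correct] + 2.0  # assumed offset
    trials['stimOff_times'] = off
    return trials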
Example n. 28
0
def driftmap(ts,
             feat,
             ax=None,
             plot_style='bincount',
             t_bin=0.01,
             d_bin=20,
             weights=None,
             vmax=None,
             **kwargs):
    """
    Plots the values of a spike feature array (y-axis) over time (x-axis).
    Two plot styles are available for the drift map:
    - 'scatter' : each value is plotted as an individual marker (up to 100,000 data points)
    - 'bincount' : values are binned (optimised for representing spike rasters)

    Parameters
    ----------
    feat : ndarray
        The spikes' feature values.
    ts : ndarray
        The spike timestamps from which to compute the firing rate.
    ax : axessubplot (optional)
        The axis handle to plot the histogram on (if `None`, a new figure and axis are created).
    t_bin : float
        Time bin (seconds) used when plot_style='bincount'.
    d_bin : float
        Feature bin used when plot_style='bincount'.
    weights : ndarray (optional)
        Weights applied to each spike when binning.
    vmax : float (optional)
        Upper colour limit for imshow; defaults to four standard deviations of the binned map.
    plot_style : str
        'scatter' or 'bincount'.
    **kwargs : dict
        matplotlib.pyplot.imshow arguments.

    Returns
    -------
    ax : axessubplot
        The axis handle on which the driftmap was plotted.

    See Also
    --------
    metrics.cum_drift
    metrics.max_drift

    Examples
    --------
    1) Plot the amplitude driftmap for unit 1.
        >>> ts = units_b['times']['1']
        >>> amps = units_b['amps']['1']
        >>> ax = bb.plot.driftmap(ts, amps)
    2) Plot the depth driftmap for unit 1.
        >>> ts = units_b['times']['1']
        >>> depths = units_b['depths']['1']
        >>> ax = bb.plot.driftmap(ts, depths)
    """
    iok = ~np.isnan(feat)
    if ax is None:
        fig, ax = plt.subplots()

    if plot_style == 'scatter' and len(ts) < 100000:
        if 'color' not in kwargs.keys():
            kwargs['color'] = 'k'
        # plot each spike as an individual marker
        ax.plot(ts, feat, '.', **kwargs)
    else:
        # compute raster map as a function of site depth
        R, times, depths = bincount2D(ts[iok],
                                      feat[iok],
                                      t_bin,
                                      d_bin,
                                      weights=weights)
        # plot raster map
        ax.imshow(R,
                  aspect='auto',
                  cmap='binary',
                  vmin=0,
                  vmax=vmax or np.std(R) * 4,
                  extent=np.r_[times[[0, -1]], depths[[0, -1]]],
                  origin='lower',
                  **kwargs)
    ax.set_xlabel('time (secs)')
    ax.set_ylabel('depth (um)')
    return ax
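# Usage note: the `weights` argument makes the map encode summed spike
# amplitude rather than raw counts; a hedged sketch using variable names from
# the earlier loading examples.
ax = driftmap(spikes['times'], spikes['depths'],
              t_bin=0.1, d_bin=5, weights=spikes['amps'])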
Example n. 29
0
def quick_unit_metrics(spike_clusters,
                       spike_times,
                       spike_amps,
                       spike_depths,
                       params=METRICS_PARAMS):
    """
    Computes single unit metrics from only the spike times, amplitudes, and depths for a set of
    units.

    Metrics computed:
        num_spikes
        firing_rate
        presence_ratio
        presence_ratio_std
        frac_isi_viol (see `isi_viol`)
        contamination_est (see `contamination_est`)
        contamination_est2 (see `contamination_est2`)
        missed_spikes_est (see `missed_spikes_est`)
        cum_amp_drift (see `cum_drift`)
        max_amp_drift (see `max_drift`)
        cum_depth_drift (see `cum_drift`)
        max_depth_drift (see `max_drift`)

    Parameters
    ----------
    spike_clusters : ndarray_like
        A vector of the unit ids for a set of spikes.
    spike_times : ndarray_like
        A vector of the timestamps for a set of spikes.
    spike_amps : ndarray_like
        A vector of the amplitudes for a set of spikes.
    spike_depths : ndarray_like
        A vector of the depths for a set of spikes.
    params : dict (optional)
        Parameters used for computing some of the metrics in the function:
            'presence_window': float
                The time window (in s) used to look for spikes when computing the presence ratio.
            'refractory_period': float
                The refractory period used when computing isi violations and the contamination
                estimate.
            'min_isi': float
                The minimum interspike-interval (in s) for counting duplicate spikes when computing
                the contamination estimate.
            'spks_per_bin_for_missed_spks_est': int
                The number of spikes per bin used to compute the spike amplitude pdf for a unit,
                when computing the missed spikes estimate.
            'std_smoothing_kernel_for_missed_spks_est': float
                The standard deviation for the gaussian kernel used to compute the spike amplitude
                pdf for a unit, when computing the missed spikes estimate.
            'min_num_bins_for_missed_spks_est': int
                The minimum number of bins used to compute the spike amplitude pdf for a unit,
                when computing the missed spikes estimate.

    Returns
    -------
    r : bunch
        A bunch whose keys are the computed spike metrics.

    Notes
    -----
    This function is called by `ephysqc.unit_metrics_ks2` which is called by `spikes.ks2_to_alf`
    during alf extraction of an ephys dataset in the ibl ephys extraction pipeline.

    Examples
    --------
    1) Compute quick metrics from a ks2 output directory:
        >>> from ibllib.ephys.ephysqc import phy_model_from_ks2_path
        >>> m = phy_model_from_ks2_path(path_to_ks2_out)
        >>> cluster_ids = m.spike_clusters
        >>> ts = m.spike_times
        >>> amps = m.amplitudes
        >>> depths = m.depths
        >>> r = bb.metrics.quick_unit_metrics(cluster_ids, ts, amps, depths)
    """

    cluster_ids = np.arange(np.max(spike_clusters) + 1)
    nclust = cluster_ids.size
    r = Bunch({
        'cluster_id': cluster_ids,
        'num_spikes': np.full((nclust, ), np.nan),
        'firing_rate': np.full((nclust, ), np.nan),
        'presence_ratio': np.full((nclust, ), np.nan),
        'presence_ratio_std': np.full((nclust, ), np.nan),
        'frac_isi_viol': np.full((nclust, ), np.nan),
        'contamination_est': np.full((nclust, ), np.nan),
        'contamination_est2': np.full((nclust, ), np.nan),
        'missed_spikes_est': np.full((nclust, ), np.nan),
        'cum_amp_drift': np.full((nclust, ), np.nan),
        'max_amp_drift': np.full((nclust, ), np.nan),
        'cum_depth_drift': np.full((nclust, ), np.nan),
        'max_depth_drift': np.full((nclust, ), np.nan),
        # could add 'epoch_name' in future:
        # 'epoch_name': np.zeros(nclust, dtype='object'),
    })

    # vectorized computation of basic metrics such as presence ratio and firing rate
    tmin = spike_times[0]
    tmax = spike_times[-1]
    presence_ratio = bincount2D(spike_times,
                                spike_clusters,
                                xbin=params['presence_window'],
                                ybin=cluster_ids,
                                xlim=[tmin, tmax])[0]
    r.presence_ratio = np.sum(presence_ratio > 0,
                              axis=1) / presence_ratio.shape[1]
    r.presence_ratio_std = np.std(presence_ratio, axis=1)
    r.num_spikes = np.sum(presence_ratio, axis=1)
    r.firing_rate = r.num_spikes / (tmax - tmin)

    # loop over each cluster to compute the rest of the metrics
    for ic in np.arange(nclust):
        # slice the spike_times array
        ispikes = spike_clusters == cluster_ids[ic]
        if np.all(~ispikes):  # if this cluster has no spikes, continue
            continue
        ts = spike_times[ispikes]
        amps = spike_amps[ispikes]
        depths = spike_depths[ispikes]

        # compute metrics
        r.frac_isi_viol[ic], _, _ = isi_viol(ts,
                                             rp=params['refractory_period'])
        r.contamination_est[ic] = contamination_est(
            ts, rp=params['refractory_period'])
        r.contamination_est2[ic], _ = contamination_est2(
            ts,
            tmin,
            tmax,
            rp=params['refractory_period'],
            min_isi=params['min_isi'])
        try:  # this may fail because `missed_spikes_est` requires a min number of spikes
            r.missed_spikes_est[ic], _, _ = missed_spikes_est(
                amps,
                spks_per_bin=params['spks_per_bin_for_missed_spks_est'],
                sigma=params['std_smoothing_kernel_for_missed_spks_est'],
                min_num_bins=params['min_num_bins_for_missed_spks_est'])
        except AssertionError:
            pass
        r.cum_amp_drift[ic] = cum_drift(amps)
        r.max_amp_drift[ic] = max_drift(amps)
        r.cum_depth_drift[ic] = cum_drift(depths)
        r.max_depth_drift[ic] = max_drift(depths)

    return r
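# A short, hedged example of filtering units with the returned bunch; the
# input arrays are assumed to be loaded as in the docstring example, and the
# thresholds below are illustrative, not recommended defaults.
import numpy as np

r = quick_unit_metrics(spike_clusters, spike_times, spike_amps, spike_depths)
# keep units present through most of the session with few ISI violations
good = np.where((r.presence_ratio > 0.9) & (r.frac_isi_viol < 0.1))[0]
print(f'{good.size}/{r.cluster_id.size} units pass the illustrative criteria')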
Example n. 30
0
    def __init__(self,
                 design_matrix,
                 spk_times,
                 spk_clu,
                 binwidth=0.02,
                 mintrials=100,
                 stepwise=False):
        """
        Construct GLM object using information about all trials, and the relevant spike times.
        Only ingests data, and further object methods must be called to describe kernels, gain
        terms, etc. as components of the model.

        Parameters
        ----------
        design_matrix: brainbox.modeling.design_matrix.DesignMatrix
            Design matrix object which has already been compiled for use with neural data.
        spk_times: numpy.array of floats
            1-D array of times at which spiking events were detected, in seconds.
        spk_clu: numpy.array of integers
            1-D array of same shape as spk_times, with integer cluster IDs identifying which
            cluster a spike time belonged to.
        binwidth : float
            Size of bins to put spikes in to, in seconds.
        mintrials: int
            Minimum number of trials in which neurons fired a spike in order to be fit. Defaults
            to 100 trials.
        stepwise: bool
            Whether or not to perform stepwise regression, in which the model is built up
            iteratively starting from the mean rate alone. This allows comparison of D^2 scores
            between sub-models that incorporate only some parameters, to see which regressors
            actually improve the fit. Defaults to False.
        """
        # Data checks #
        if not len(spk_times) == len(spk_clu):
            raise IndexError("Spike times and cluster IDs are not same length")
        if not design_matrix.compiled:
            raise AttributeError(
                'Design matrix object must be compiled before passing to fit')

        # Filter out cells which don't meet the criteria for minimum spiking, while doing trial
        # assignment
        base_df = design_matrix.base_df
        clu_ids = np.unique(spk_clu).flatten()
        trbounds = base_df[['trial_start',
                            'trial_end']]  # Get the start/end of trials
        # Initialize a Cells x Trials bool array to easily see how many trials a clu spiked
        trialspiking = np.zeros((base_df.index.max() + 1, clu_ids.max() + 1),
                                dtype=bool)
        # Iterate through each trial, and store the relevant spikes and cluster labels
        # for that trial into dicts; this makes binning and accessing spikes easier.
        spks = {}
        clu = {}
        st_endlast = 0
        for i, (start, end) in trbounds.iterrows():
            st_startind = np.searchsorted(spk_times[st_endlast:],
                                          start) + st_endlast
            st_endind = np.searchsorted(
                spk_times[st_endlast:], end, side='right') + st_endlast
            st_endlast = st_endind
            trial_clu = np.unique(spk_clu[st_startind:st_endind])
            trialspiking[i, trial_clu] = True
            spks[i] = spk_times[st_startind:st_endind] - start
            clu[i] = spk_clu[st_startind:st_endind]

        # Set model parameters to begin with
        self.design = design_matrix
        self.spikes = spks
        self.clu = clu
        self.clu_ids = np.argwhere(
            np.sum(trialspiking, axis=0) > mintrials).flatten()
        self.stepwise = stepwise
        self.binwidth = binwidth

        if len(self.clu_ids) == 0:
            raise UserWarning('No neuron fired spikes in the minimum number of trials.')

        # Bin spikes
        spkarrs, arrdiffs = [], []
        for i in self.design.trialsdf.index:
            duration = self.design.trialsdf.loc[i, 'duration']
            durmod = duration % self.binwidth
            if durmod > (self.binwidth / 2):
                duration = duration - (self.binwidth / 2)
            if len(self.spikes[i]) == 0:
                arr = np.zeros((self.binf(duration), len(self.clu_ids)))
                spkarrs.append(arr)
                continue
            spks = self.spikes[i]
            clu = self.clu[i]
            arr = bincount2D(spks,
                             clu,
                             xbin=self.binwidth,
                             ybin=self.clu_ids,
                             xlim=[0, duration])[0]
            arrdiffs.append(arr.shape[1] - self.binf(duration))
            spkarrs.append(arr.T)
        y = np.vstack(spkarrs)
        if hasattr(self.design, 'dm'):
            assert y.shape[0] == self.design.dm.shape[0], \
                "Mismatch between binned spikes and design matrix rows."
        self.binnedspikes = y
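# Hedged construction sketch: the class name `NeuralModel`, the
# `compile_design_matrix` call, and the `trialsdf`/`vartypes` inputs are
# assumptions about the surrounding module, which is not shown here.
from brainbox.modeling.design_matrix import DesignMatrix  # path from docstring

dm = DesignMatrix(trialsdf, vartypes)  # hypothetical inputs
# ... add covariates/kernels here, then compile before fitting ...
dm.compile_design_matrix()             # assumed compile step

model = NeuralModel(dm, spk_times, spk_clu,  # `NeuralModel` is an assumed name
                    binwidth=0.02, mintrials=100, stepwise=False)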