コード例 #1
0
    def test_indexing_1(self):
        """Index a binned spike train with an (EpochArray, series) pair.

        Checks the indexed object's series count, binned support, bins and
        bin centers against hand-written expectations, then confirms that
        the source object's data and the ``_desc`` metadata are intact.
        """

        bst = (nel.SpikeTrainArray(
            [[1, 2, 3, 4, 5, 6, 7, 8, 9.5, 10, 10.5, 11.4, 15, 18, 19, 20, 21],
             [4, 8, 17]],
            support=nel.EpochArray([[0, 8], [12, 22]]),
            fs=1).bin(ds=1))
        # Snapshot with an explicit copy: if ``bst.data`` hands back the
        # underlying array, a plain alias would compare equal to itself even
        # after an in-place mutation, making the check below vacuous.
        data = np.array(bst.data)

        bst._desc = 'test case for bst'

        expected_bins = np.array([2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 20, 21, 22])
        expected_bin_centers = np.array(
            [2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 12.5, 13.5, 20.5, 21.5])
        expected_binned_support = np.array([[0, 5], [6, 7], [8, 9]])

        bst_indexed = bst[nel.EpochArray([[2, 8], [9, 14], [19.5, 25]]), 1]

        assert bst_indexed.n_series == 1

        # binned support is an int array and should be exact. The others
        # are floats so we use np.allclose
        assert bst_indexed.binned_support.dtype.kind in ('i', 'u')
        assert np.all(bst_indexed.binned_support == expected_binned_support)
        assert np.allclose(bst_indexed.bins, expected_bins)
        assert np.allclose(bst_indexed.bin_centers, expected_bin_centers)

        # make sure original object's data didn't get mutated when indexing
        assert np.all(bst.data == data)

        # make sure metadata didn't get lost!
        assert bst_indexed._desc == bst._desc
コード例 #2
0
def get_base_data(data_path, spike_path, session):
    """
    Load and format one session's data for replay analysis.

    Position, maze size and epochs are read from ``<data_path>/<session>.mat``
    through the ``functions`` helpers; spike times are read from
    ``<spike_path>/<session>.npy`` with ``np.load``.

    Returns
    -------
    tuple
        ``(maze_size_cm, pos, st)``: the value returned by
        ``functions.get_maze_size_cm``, a ``nel.AnalogSignalArray`` built
        from the rescaled ``x`` coordinate, and a ``nel.SpikeTrainArray``
        whose support spans the earliest to the latest spike.
    """
    mat_file = os.path.join(data_path, session) + '.mat'

    # session path lookup; the returned value is not used further down
    functions.get_session_path(mat_file)
    position_df = functions.load_position(mat_file)
    maze_size_cm = functions.get_maze_size_cm(mat_file)
    session_epochs = nel.EpochArray(functions.get_epochs(mat_file))

    # rescale epoch coordinates into cm
    position_df = rescale_coords(position_df, session_epochs, maze_size_cm)

    # sampling rate is the reciprocal of the most common timestamp step
    sampling_rate = 1 / statistics.mode(np.diff(position_df.ts))
    pos = nel.AnalogSignalArray(
        timestamps=position_df.ts,
        data=[position_df.x],
        fs=sampling_rate,
        support=session_epochs,
    )

    # allow_pickle=True lets np.load unpickle object arrays, so this should
    # only be pointed at trusted files
    npy_file = os.path.join(spike_path, session) + '.npy'
    spikes = np.load(npy_file, allow_pickle=True)

    # flatten every unit's spike times to find the overall first/last spike
    all_spike_times = list(itertools.chain.from_iterable(spikes))
    first_spike = min(all_spike_times)
    last_spike = max(all_spike_times)
    session_bounds = nel.EpochArray([first_spike, last_spike])

    st = nel.SpikeTrainArray(
        timestamps=spikes,
        support=session_bounds,
        fs=32000,
    )

    return maze_size_cm, pos, st
コード例 #3
0
    def test_empty(self):
        """Emptying a binned spike train clears data but keeps metadata.

        Exercises both ``empty(inplace=False)`` and ``empty(inplace=True)``
        and checks that the two results agree.
        """

        bst = (nel.SpikeTrainArray([[3, 4, 5, 6, 7], [2, 4, 5]],
                                   support=nel.EpochArray([0, 8]),
                                   fs=1).bin(ds=1))

        desc = 'test case for bst'
        bst._desc = desc
        n_series = bst.n_series

        bst1 = bst.empty(inplace=False)
        bst.empty(inplace=True)

        # Identity checks: ``== None`` on an array-like attribute would be
        # evaluated elementwise instead of testing for None.
        assert bst.binned_support is None
        assert bst.bin_centers is None
        assert bst.bins is None
        assert bst.eventarray.isempty
        assert bst.n_series == n_series
        assert bst._desc == desc
        assert bst.support.isempty

        # Emptying should be consistent whether we do it
        # in place or not
        assert bst1.binned_support is None
        assert bst1.bin_centers is None
        assert bst1.bins is None
        assert bst1.eventarray.isempty
        assert bst1._desc == bst._desc
        assert bst1.support.isempty
コード例 #4
0
def load_add_spikes(spike_path, session, fs=32000):
    """Load ``<spike_path>/<session>.npy`` into a ``nel.SpikeTrainArray``.

    The support of the returned array runs from the earliest to the latest
    spike time found across all entries of the loaded file; ``fs`` is
    forwarded to the ``SpikeTrainArray`` constructor.
    """
    # allow_pickle=True lets np.load unpickle object arrays, so this should
    # only be pointed at trusted files
    npy_file = os.path.join(spike_path, session) + '.npy'
    spikes = np.load(npy_file, allow_pickle=True)

    # flatten every entry's spike times to find the overall first/last spike
    all_spike_times = list(itertools.chain.from_iterable(spikes))
    first_spike = min(all_spike_times)
    last_spike = max(all_spike_times)
    session_bounds = nel.EpochArray([first_spike, last_spike])

    spike_trains = nel.SpikeTrainArray(
        timestamps=spikes,
        support=session_bounds,
        fs=fs,
    )
    return spike_trains
コード例 #5
0
 def test_asa_mean3(self):
     """Per-epoch means of a three-signal array match the expected table."""
     signals = nel.AnalogSignalArray([[1, 2, 4, 5], [7, 8, 9, 10]])
     signals.add_signal([3, 4, 5, 6])
     restricted = signals[nel.EpochArray([[0, 1.1], [1.9, 3.1]])]

     # one mean vector per epoch obtained by iterating the restricted array
     per_epoch_means = [segment.mean() for segment in restricted]
     expected = np.array([[1.5, 7.5, 3.5],
                          [4.5, 9.5, 5.5]])

     matches = np.array(per_epoch_means == expected)
     assert matches.all()
コード例 #6
0
    def test_indexing(self):
        """Index by (EpochArray, series) and check nothing else is disturbed.

        Verifies the indexed object keeps exactly one series, that the
        source object's data is unchanged, and that ``_desc`` is carried
        over to the indexed object.
        """

        sta = (nel.SpikeTrainArray(
            [[1, 2, 3, 4, 5, 6, 7, 8, 9.5, 10, 10.5, 11.4, 15, 18, 19, 20, 21],
             [4, 8, 17]],
            support=nel.EpochArray([[0, 8], [12, 22]]),
            fs=1).bin(ds=1))
        sta._desc = 'test case for sta'
        # Snapshot with an explicit copy: if ``sta.data`` hands back the
        # underlying array, a plain alias would compare equal to itself even
        # after an in-place mutation, making the check below vacuous.
        data = np.array(sta.data)

        sta_indexed = sta[nel.EpochArray([[2, 8], [9, 14], [19.5, 25]]), 1]

        assert sta_indexed.n_series == 1

        # make sure original object's data didn't get mutated when indexing
        assert np.all(sta.data == data)

        # make sure metadata didn't get lost!
        assert sta_indexed._desc == sta._desc
コード例 #7
0
    def test_merge(self):
        """Merging an EpochArray yields the expected starts and stops."""
        intervals = np.array([
            [1.0, 3.0], [4.0, 8.0], [12.0, 13.0], [20.0, 25.0],
            [1.0, 5.0], [6.0, 7.0], [15.0, 18.0], [30.0, 35.0],
        ])
        expected_starts = np.array([1.0, 12.0, 15.0, 20.0, 30.0])
        expected_stops = np.array([8.0, 13.0, 18.0, 25.0, 35.0])

        merged = nel.EpochArray(intervals).merge()

        assert np.allclose(merged.starts, expected_starts)
        assert np.allclose(merged.stops, expected_stops)
コード例 #8
0
    def test_intersection_of_contiguous_epochs(self):
        """Contiguous intervals must stay separate when intersected."""
        contiguous = nel.EpochArray([[2, 3], [3, 4], [5, 7]])
        enclosing = nel.EpochArray([2, 8])

        # restrict in both directions; each must report three intervals
        forward = contiguous[enclosing]
        assert forward.n_intervals == 3

        backward = enclosing[contiguous]
        assert backward.n_intervals == 3


# epochs_a = nel.EpochArray([[0, 5], [5,10], [10,12], [12,16], [14,18]])
# epochs_b = nel.EpochArray([[3, 12], [15,20], [15,18]])
# epochs_c = nel.EpochArray([[3,21]])

# epochs_a[epochs_b][epochs_c].time

# array([[ 3,  5],
#        [ 5, 10],
#        [10, 12],
#        [15, 16],
#        [15, 16],
#        [15, 18],
#        [15, 18]])
コード例 #9
0
    def test_copy_without_data(self):
        """``_copy_without_data`` keeps series count and metadata, drops data."""
        description = 'test case for sta'

        sta = nel.SpikeTrainArray(
            [[3, 4, 5, 6, 7], [2, 4, 5]],
            support=nel.EpochArray([0, 8]),
            fs=1,
        )
        sta._desc = description

        empty_copy = sta._copy_without_data()

        assert empty_copy.n_series == sta.n_series
        assert empty_copy._desc == sta._desc
        assert empty_copy.isempty
コード例 #10
0
    def copy_without_data(self):
        """Check ``_copy_without_data`` on a binned spike train.

        The copy must keep the series count and ``_desc`` metadata and be
        empty, while the source object keeps its data.

        NOTE(review): this method's name lacks the ``test_`` prefix, so
        pytest's default discovery will not collect it. The name is left
        unchanged here; rename it if the check is meant to run.
        """

        bst = (nel.SpikeTrainArray([[3, 4, 5, 6, 7], [2, 4, 5]],
                                   support=nel.EpochArray([0, 8]),
                                   fs=1).bin(ds=1))

        desc = 'test case for bst'
        bst._desc = desc

        bst_copied = bst._copy_without_data()

        # Assert on the copy: the earlier version inspected ``bst`` itself,
        # which never exercised the object returned by _copy_without_data.
        assert bst_copied.n_series == bst.n_series
        assert bst_copied._desc == desc
        assert bst_copied.isempty

        # The source object must be left untouched by the copy.
        assert bst._desc == desc
        assert not bst.isempty
        # NOTE(review): the earlier version also asserted
        # ``bst.eventarray.isempty`` on the source object; whether the copy's
        # ``eventarray`` is expected to be empty is not visible from here, so
        # that check is omitted - confirm against the library and re-add.
コード例 #11
0
    def test_indexing_2(self):
        """Indexing by a list of epoch indices selects those intervals."""
        spike_times = [1, 2, 3, 4, 5, 6, 7, 8, 9.5, 10, 10.5, 11.4,
                       15, 18, 19, 20, 21]
        support = nel.EpochArray([[0, 8], [10, 12], [15, 22]])
        bst = nel.SpikeTrainArray(spike_times, support=support, fs=1).bin(ds=1)

        # selecting every epoch by list keeps all three intervals
        every_epoch = bst[[0, 1, 2]]
        assert every_epoch.n_intervals == 3

        # selecting a subset must not alter the object being indexed
        last_two = bst[[1, 2]]
        assert last_two.n_intervals == 2
        assert bst.n_intervals == 3
コード例 #12
0
    def test_empty(self):
        """Emptying a spike train array drops data but keeps metadata.

        Runs ``empty`` both out of place and in place and checks that the
        two results look the same.
        """
        description = 'test case for sta'

        sta = nel.SpikeTrainArray(
            [[3, 4, 5, 6, 7], [2, 4, 5]],
            support=nel.EpochArray([0, 8]),
            fs=1,
        )
        sta._desc = description
        series_count = sta.n_series

        emptied_copy = sta.empty(inplace=False)
        sta.empty(inplace=True)

        # in-place result: same series count and metadata, no data/support
        assert sta.n_series == series_count
        assert sta._desc == description
        assert sta.isempty
        assert sta.support.isempty

        # out-of-place result must agree with the in-place one
        assert emptied_copy.n_series == sta.n_series
        assert emptied_copy._desc == sta._desc
        assert emptied_copy.isempty
        assert emptied_copy.support.isempty
コード例 #13
0
 def test_partition(self):
     """Partitioning returns five intervals and leaves the source at one."""
     whole = nel.EpochArray([0, 10])
     pieces = whole.partition(n_intervals=5)

     # the source epoch must still be a single interval afterwards
     assert whole.n_intervals == 1
     assert pieces.n_intervals == 5
コード例 #14
0
def run_all(session,
            data_path,
            spike_path,
            save_path,
            mua_df,
            df_cell_class,
            traj_shuff=1500,
            verbose=False):
    """
    Main function that conducts the replay analysis for one session.

    For every area listed for ``session`` in ``df_cell_class`` the function
    builds tuning curves from running epochs, selects place cells, checks
    decoding quality against shuffles, decodes and scores each candidate
    event taken from ``mua_df``, and collects the outcome.

    Parameters
    ----------
    session :
        Session identifier; forwarded to ``get_base_data`` and compared
        against the ``session`` column of ``mua_df`` and ``df_cell_class``.
    data_path, spike_path :
        Forwarded unchanged to ``get_base_data``.
    save_path :
        Not used inside this function.
    mua_df :
        DataFrame of candidate events. The columns read here are
        ``session``, ``ripple_duration``, ``start_time``, ``end_time`` and
        ``ep_type``.
    df_cell_class :
        DataFrame of unit metadata. The columns read here are ``session``,
        ``area``, ``cell_type`` and ``n_spikes``.
    traj_shuff : int, optional
        Passed as ``n_shuffles`` to ``replay.trajectory_score_bst``.
    verbose : bool, optional
        When true, progress messages are printed. The "warning: ..." skip
        messages are printed regardless.

    Returns
    -------
    dict
        Keyed by area. The value is an empty dict when the area was skipped
        (fewer than two qualifying units, decoding p-value above 0.05, or no
        events of at least 0.08 duration); otherwise a dict with the keys
        ``sta_placecells``, ``bst_placecells``, ``tc``, ``posteriors``,
        ``bdries``, ``mode_pth``, ``position``, ``df``, ``session``,
        ``decoding_r2``, ``decoding_r2_pval``, ``decoding_median_error`` and
        ``total_units``.
    """
    if verbose:
        print('loading data')
    maze_size_cm, pos, st_all = get_base_data(data_path, spike_path, session)

    # to make everything more simple, let's restrict to just the linear track
    pos = pos[0]
    st_all = st_all[0]
    maze_size_cm = maze_size_cm[0]

    # compute and smooth speed
    speed1 = nel.utils.ddt_asa(pos, smooth=True, sigma=0.1, norm=True)

    # find epochs where the animal ran > 4cm/sec
    run_epochs = nel.utils.get_run_epochs(speed1, v1=4, v2=4)

    # set up results
    results = {}

    # loop through each area separately
    areas = df_cell_class.area[df_cell_class.session == session]
    for current_area in pd.unique(areas):
        if verbose:
            print('running through: ', current_area)

        # subset units to current area
        st = st_all._unit_subset(np.where(areas == current_area)[0] + 1)
        # reset unit ids like the other units never existed
        st.series_ids = np.arange(0, len(st.series_ids)) + 1

        # restrict spike trains to those epochs during which the animal was running
        st_run = st[run_epochs]
        ds_run = 0.5
        ds_50ms = 0.05
        # smooth and re-bin:
        #     sigma = 0.3 # 300 ms spike smoothing
        bst_run = st_run.bin(ds=ds_50ms).smooth(
            sigma=0.3, inplace=True).rebin(w=ds_run / ds_50ms)

        sigma = 3  # smoothing std dev in cm
        tc = nel.TuningCurve1D(bst=bst_run,
                               extern=pos,
                               n_extern=40,
                               extmin=0,
                               extmax=maze_size_cm,
                               sigma=sigma,
                               min_duration=0)

        # locate pyr cells with >= 100 spikes, peak rate >= 1 Hz, peak/mean ratio >=1.5
        peak_firing_rates = tc.max(axis=1)
        mean_firing_rates = tc.mean(axis=1)
        ratio = peak_firing_rates / mean_firing_rates
        temp_df = df_cell_class[(df_cell_class.session == session)
                                & (df_cell_class.area == current_area)]
        # positional indices of qualifying rows, shifted to 1-based unit ids;
        # squeeze().tolist() yields a bare int when exactly one unit qualifies
        unit_ids_to_keep = (
            np.where(((temp_df.cell_type == "pyr")) & (temp_df.n_spikes >= 100)
                     & (tc.ratemap.max(axis=1) >= 1) & (ratio >= 1.5))[0] +
            1).squeeze().tolist()

        if isinstance(unit_ids_to_keep, int):
            print('warning: only 1 unit...skipping')
            results[current_area] = {}
            continue
        elif len(unit_ids_to_keep) == 0:
            print('warning: no units...skipping')
            results[current_area] = {}
            continue

        sta_placecells = st._unit_subset(unit_ids_to_keep)
        tc = tc._unit_subset(unit_ids_to_keep)
        total_units = sta_placecells.n_active
        # tc.reorder_units(inplace=True)

        if verbose:
            print('decoding and scoring position')

        # assess decoding accuracy on behavioral time scale
        decoding_r2, median_error, decoding_r2_shuff, _ = decode_and_shuff(
            bst_run.loc[:, unit_ids_to_keep], tc, pos, n_shuffles=1000)
        # check decoding quality against chance distribution
        _, decoding_r2_pval = get_significant_events(decoding_r2,
                                                     decoding_r2_shuff)

        if decoding_r2_pval > 0.05:
            print('warning: poor decoding...skipping')
            results[current_area] = {}
            continue

        # create intervals for PBEs epochs
        temp_df = mua_df[mua_df.session == session]

        # restrict to events at least 80ms
        temp_df = temp_df[temp_df.ripple_duration >= 0.08]

        if temp_df.shape[0] == 0:
            print('warning: no PBE events...skipping')
            results[current_area] = {}
            continue

        # make epoch object
        PBEs = nel.EpochArray(
            [np.array([temp_df.start_time, temp_df.end_time]).T])

        # bin data into 20ms
        bst_placecells = sta_placecells[PBEs].bin(ds=0.02)

        # count units per event
        n_active = [bst.n_active for bst in bst_placecells]
        n_active = np.array(n_active)
        # also count the proportion of bins in each event with 0 activity
        inactive_bin_prop = [
            sum(bst.n_active_per_bin == 0) / bst.lengths[0]
            for bst in bst_placecells
        ]
        inactive_bin_prop = np.array(inactive_bin_prop)
        # restrict bst to instances with >= 5 active units and < 50% inactive bins
        idx = (n_active >= 5) & (inactive_bin_prop < .5)
        bst_placecells = bst_placecells[np.where(idx)[0]]
        # restrict df to the same events; take an explicit copy so that the
        # column assignments further down write to an independent frame
        # rather than to a slice of mua_df (avoids SettingWithCopyWarning)
        temp_df = temp_df[idx].copy()
        n_active = n_active[idx]
        inactive_bin_prop = inactive_bin_prop[idx]

        # decode each event
        posteriors, bdries, mode_pth, mean_pth = nel.decoding.decode1D(
            bst_placecells, tc, xmin=0, xmax=maze_size_cm)

        # score each event using trajectory_score_bst (sums the posterior probability in a range (w) from the LS line)
        if verbose:
            print('scoring events')

        scores, scores_time_swap, scores_col_cycle = replay.trajectory_score_bst(
            bst_placecells, tc, w=3, n_shuffles=traj_shuff, normalize=True)

        # find sig events using time and column shuffle distributions
        _, score_pval_time_swap = get_significant_events(
            scores, scores_time_swap)
        _, score_pval_col_cycle = get_significant_events(
            scores, scores_col_cycle)

        if verbose:
            print('extracting features')
        (traj_dist, traj_speed, traj_step, replay_type, dist_rat_start,
         dist_rat_end, position) = get_features(bst_placecells, posteriors,
                                                bdries, mode_pth, pos,
                                                list(temp_df.ep_type))

        slope, intercept, r2values = replay.linregress_bst(bst_placecells, tc)

        # package data into results dictionary
        results[current_area] = {}

        results[current_area]['sta_placecells'] = sta_placecells
        results[current_area]['bst_placecells'] = bst_placecells
        results[current_area]['tc'] = tc
        results[current_area]['posteriors'] = posteriors
        results[current_area]['bdries'] = bdries
        results[current_area]['mode_pth'] = mode_pth
        results[current_area]['position'] = position

        # add event by event metrics to df
        temp_df['n_active'] = n_active
        temp_df['inactive_bin_prop'] = inactive_bin_prop
        temp_df['trajectory_score'] = scores
        temp_df['r_squared'] = r2values
        temp_df['slope'] = slope
        temp_df['intercept'] = intercept
        temp_df['score_pval_time_swap'] = score_pval_time_swap
        temp_df['score_pval_col_cycle'] = score_pval_col_cycle
        temp_df['traj_dist'] = traj_dist
        temp_df['traj_speed'] = traj_speed
        temp_df['traj_step'] = traj_step
        temp_df['replay_type'] = replay_type
        temp_df['dist_rat_start'] = dist_rat_start
        temp_df['dist_rat_end'] = dist_rat_end
        results[current_area]['df'] = temp_df

        results[current_area]['session'] = session
        results[current_area]['decoding_r2'] = decoding_r2
        results[current_area]['decoding_r2_pval'] = decoding_r2_pval
        results[current_area]['decoding_median_error'] = median_error
        results[current_area]['total_units'] = total_units

    return results
コード例 #15
0
 def test_asa_mean2(self):
     """Per-signal mean of the epoch-restricted array matches expectations."""
     signals = nel.AnalogSignalArray([[1, 2, 4, 5], [7, 8, 9, 10]])
     signals.add_signal([3, 4, 5, 6])
     restricted = signals[nel.EpochArray([[0, 1.1], [1.9, 3.1]])]

     expected = np.array([3., 8.5, 4.5])
     matches = np.array(restricted.mean() == expected)
     assert matches.all()