Example #1
def test_concatenate_firings():
    M, N1, N2 = 4, 2000, 30000
    test_offset_str = '300000,123456789'
    test_offset = [300000, 123456789]
    fir1 = np.around(np.random.rand(M, N1), decimals=3)
    mlpy.writemda64(fir1, 'tmp.fir1.mda')
    fir2 = np.around(np.random.rand(M, N2), decimals=3)
    mlpy.writemda64(fir2, 'tmp.fir2.mda')
    fir1_incr = fir1.copy()  # copy before mutating so the arrays already on disk stay pristine
    fir2_incr = fir2.copy()
    # Time offsets (row 1) are applied regardless of increment_labels, so both
    # expected arrays start from the offset versions
    fir1_incr[1, :] += test_offset[0]
    fir2_incr[1, :] += test_offset[1]
    fir12 = np.around(np.append(fir1_incr, fir2_incr, axis=1), decimals=3)
    # With increment_labels='true', the labels (row 2) of the second file are
    # shifted up by the largest label of the first file
    fir12_incr = fir12.copy()
    fir12_incr[2, N1:] += np.max(fir1[2, :])
    concatenate_firings(firings_list=['tmp.fir1.mda', 'tmp.fir2.mda'],
                        firings_out='tmp.test_fir12.mda',
                        time_offsets=test_offset_str,
                        increment_labels='false')
    concatenate_firings(firings_list=['tmp.fir1.mda', 'tmp.fir2.mda'],
                        firings_out='tmp.test_fir12_incr.mda',
                        time_offsets=test_offset_str,
                        increment_labels='true')
    test_fir12 = mlpy.readmda('tmp.test_fir12.mda')
    test_fir12 = np.around(test_fir12, decimals=3)
    test_fir12_incr = mlpy.readmda('tmp.test_fir12_incr.mda')
    test_fir12_incr = np.around(test_fir12_incr, decimals=3)
    np.testing.assert_array_almost_equal(fir12, test_fir12, decimal=3)
    np.testing.assert_array_almost_equal(fir12_incr,
                                         test_fir12_incr,
                                         decimal=3)
    return True
Example #2
def concatenate_firings(*,
                        firings_list,
                        firings_out,
                        time_offsets,
                        increment_labels='false'):
    """
    Combine a list of firings files to form a single firings file

    Parameters
    ----------
    firings_list : INPUT
        A list of paths of firings mda files to be concatenated
    firings_out : OUTPUT
        ...

    time_offsets : string
        A comma-separated list of time offsets, one for each firings file.
        ...
    increment_labels : string
        ...
    """
    if time_offsets:
        time_offsets = np.fromstring(time_offsets, dtype=float, sep=',')
    else:
        time_offsets = np.zeros(len(firings_list))
    if len(firings_list) == len(time_offsets):
        concatenated_firings = np.zeros((3, 0))  # default, for the case where the list is empty
        first = True
        for idx, firings in enumerate(firings_list):
            to_append = mlpy.readmda(firings)
            to_append[1, :] += time_offsets[idx]
            if not first and increment_labels == 'true':
                to_append[2, :] += np.max(concatenated_firings[2, :])  # add the Kmax from the previous files
            if first:
                concatenated_firings = to_append
            else:
                concatenated_firings = np.append(concatenated_firings,
                                                 to_append,
                                                 axis=1)
            first = False
        mlpy.writemda64(concatenated_firings, firings_out)
        return True
    else:
        print('Mismatch between number of firings files and number of offsets')
        return False
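As a quick illustration of the two options above, here is a minimal sketch (toy arrays instead of mda files) of how the time_offsets string is parsed and how increment_labels shifts the label row:

import numpy as np

# Parse the comma-separated offsets string the same way the processor does
offsets = np.fromstring('300000,123456789', dtype=float, sep=',')

# Two toy firings arrays: row 1 holds event times, row 2 holds unit labels
f1 = np.array([[0, 0], [10, 20], [1, 2]], dtype=float)
f2 = np.array([[0, 0], [5, 15], [1, 2]], dtype=float)

f1[1, :] += offsets[0]        # offsets are always applied to the time row
f2[1, :] += offsets[1]
f2[2, :] += np.max(f1[2, :])  # increment_labels='true': labels 1,2 become 3,4
combined = np.append(f1, f2, axis=1)
print(combined[2, :])         # [1. 2. 3. 4.]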
Example #3
def join_segments(*, timeseries_list, firings_list, dmatrix_out,
                  templates_out):
    """
    Join the results of spike sorting on a sequence of time segments to form a single firings file

    Parameters
    ----------
    timeseries_list : INPUT
        A list of paths of adjacent preprocessed timeseries segment files
    firings_list : INPUT
        A list of paths to corresponding firings files
        
    dmatrix_out : OUTPUT
        dmatrix for debugging    
    templates_out : OUTPUT
        templates for debugging

    """
    X = DiskReadMda(timeseries_list[0])
    M = X.N1()
    clip_size = 100
    num_segments = len(timeseries_list)
    firings_arrays = []
    for j in range(num_segments):
        F = readmda(firings_list[j])
        firings_arrays.append(F)
    Kmax = 0
    for j in range(num_segments):
        F = firings_arrays[j]
        labels = F[2, :]
        Kmax = int(max(Kmax, np.max(labels)))
    dmatrix = np.ones((Kmax, Kmax, num_segments - 1)) * (-1)
    templates = np.zeros((M, clip_size, Kmax, 2 * (num_segments - 1)))

    for j in range(num_segments - 1):
        print('Computing dmatrix between segments %d and %d' % (j, j + 1))
        (dmatrix0, templates1,
         templates2) = compute_dmatrix(timeseries_list[j],
                                       timeseries_list[j + 1],
                                       firings_arrays[j],
                                       firings_arrays[j + 1],
                                       clip_size=clip_size)
        dmatrix[0:dmatrix0.shape[0], 0:dmatrix0.shape[1], j] = dmatrix0
        templates[:, :, 0:dmatrix0.shape[0], j * 2] = templates1
        templates[:, :, 0:dmatrix0.shape[1], j * 2 + 1] = templates2

    writemda64(templates, templates_out)
    return writemda64(dmatrix, dmatrix_out)
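A hedged usage sketch for the processor above; the segment file names are hypothetical and assume spike sorting has already been run on each segment:

join_segments(
    timeseries_list=['pre.seg1.mda', 'pre.seg2.mda'],
    firings_list=['firings.seg1.mda', 'firings.seg2.mda'],
    dmatrix_out='dmatrix.mda',
    templates_out='dmatrix_templates.mda')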
Example #4
def test_compute_templates():
    M, N, K, T, L = 5, 1000, 6, 50, 100
    X = np.random.rand(M, N)
    writemda32(X, 'tmp.mda')
    F = np.zeros((3, L))
    F[1, :] = 1 + np.random.randint(N, size=(1, L))
    F[2, :] = 1 + np.random.randint(K, size=(1, L))
    writemda64(F, 'tmp2.mda')
    ret = compute_templates(timeseries='tmp.mda',
                            firings='tmp2.mda',
                            templates_out='tmp3.mda',
                            clip_size=T)
    assert (ret)
    templates0 = readmda('tmp3.mda')
    assert (templates0.shape == (M, T, K))
    return True
Example #5
def create_label_map(*,
                     metrics,
                     label_map_out,
                     firing_rate_thresh=.05,
                     isolation_thresh=0.95,
                     noise_overlap_thresh=.03,
                     peak_snr_thresh=1.5):
    """
    Generate a label map based on the metrics file; labels mapped to zero are to be removed.

    Parameters
    ----------
    metrics : INPUT
        Path of metrics json file to be used for generating the label map
    label_map_out : OUTPUT
        Path to mda file where the second column is the present label, and the first column is the new label
        ...
    firing_rate_thresh : float64
        (Optional) firing rate must be above this
    isolation_thresh : float64
        (Optional) isolation must be above this
    noise_overlap_thresh : float64
        (Optional) noise overlap must be below this
    peak_snr_thresh : float64
        (Optional) peak snr must be above this
    """
    #TODO: Way to pass in logic or thresholds flexibly

    label_map = []

    #Load json
    with open(metrics) as metrics_json:
        metrics_data = json.load(metrics_json)

    #Iterate through all clusters
    for cluster in metrics_data['clusters']:
        m = cluster['metrics']
        if m['firing_rate'] < firing_rate_thresh or \
                m['isolation'] < isolation_thresh or \
                m['noise_overlap'] > noise_overlap_thresh or \
                m['peak_snr'] < peak_snr_thresh:
            #Map to zero (mask out)
            label_map.append([0, cluster['label']])
        elif m['bursting_parent']:  #Map to the bursting parent, if one exists
            label_map.append([m['bursting_parent'], cluster['label']])
        else:
            label_map.append([cluster['label'], cluster['label']])  # otherwise, map to itself!

    #Writeout
    return writemda64(np.array(label_map), label_map_out)
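The structure of the metrics file is implied by the reads above; the following minimal sketch (all values hypothetical) builds such a file and runs the processor on it:

import json

metrics = {'clusters': [
    {'label': 1, 'metrics': {'firing_rate': 0.50, 'isolation': 0.98,
                             'noise_overlap': 0.01, 'peak_snr': 4.2,
                             'bursting_parent': 0}},
    {'label': 2, 'metrics': {'firing_rate': 0.01, 'isolation': 0.99,
                             'noise_overlap': 0.00, 'peak_snr': 5.0,
                             'bursting_parent': 0}},
]}
with open('tmp.metrics.json', 'w') as f:
    json.dump(metrics, f)

create_label_map(metrics='tmp.metrics.json', label_map_out='tmp.label_map.mda')
# Cluster 1 passes every threshold and maps to itself; cluster 2 fires too
# slowly (0.01 < 0.05) and maps to zero, so apply_label_map would mask it out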
Example #6
def test_extract_clips():
    M, T, L, N = 5, 100, 100, 1000
    X = np.random.rand(M, N).astype(np.float32)
    writemda32(X, 'tmp.mda')
    F = np.zeros((2, L))
    F[1, :] = 200 + np.random.randint(N - 400, size=(1, L))
    writemda64(F, 'tmp2.mda')
    ret = extract_clips(timeseries='tmp.mda',
                        firings='tmp2.mda',
                        clips_out='tmp3.mda',
                        clip_size=T)
    assert (ret)
    clips0 = readmda('tmp3.mda')
    assert (clips0.shape == (M, T, L))
    t0 = int(F[1, 10])
    a = int(np.floor((T + 1) / 2 - 1))  # clip offset: for T=100 the clip spans [t0-49, t0+51)
    assert np.array_equal(clips0[:, :, 10], X[:, t0 - a:t0 - a + T])
    #np.testing.assert_almost_equal(clips0[:,:,10],X[:,t0-a:t0-a+T],decimal=4)
    return True
Example #7
def test_extract_timeseries():
    M, N = 4, 10000
    X = np.random.rand(M, N)
    X.astype('float64').transpose().tofile('tmp.dat')
    ret = extract_timeseries(timeseries="tmp.dat",
                             timeseries_out="tmp2.mda",
                             channels="1,3",
                             t1=-1,
                             t2=-1,
                             timeseries_num_channels=M,
                             timeseries_dtype='float64')
    writemda64(X, 'tmp.mda')
    #ret=extract_timeseries(timeseries="tmp.mda",timeseries_out="tmp2.mda",channels="1,3",t1=-1,t2=-1)
    assert (ret)
    B = readmda('tmp2.mda')
    assert (B.shape[0] == 2)
    assert (B.shape[1] == N)
    assert (np.array_equal(X[[0, 2], :], B))
    return True
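Note the transpose before tofile: the raw file stores samples time-major (M channel values per timepoint), which is evidently the layout extract_timeseries expects for a raw .dat input, since the test's final assertion depends on it. A one-line sanity check, assuming the file written above (N=10000, M=4):

import numpy as np

Y = np.fromfile('tmp.dat', dtype='float64').reshape(10000, 4).T  # back to (M, N)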
Example #8
def synthesize_random_firings(*, firings_out, K=20, samplerate=30000, duration=60):
    """
    Synthesize random waveforms for use in creating a synthetic timeseries dataset

    Parameters
    ----------
    firings_out : OUTPUT
        Path to output firings mda file. 3xL, where L is the number of events; the second row holds timestamps and the third row holds integer unit labels
    
    K : int
        (Optional) number of simulated units
    samplerate : double
        (Optional) sampling frequency in Hz
    duration : double
        (Optional) duration of the simulated acquisition in seconds
    """
    firing_rates = 3 * np.ones((K))
    refr = 10

    N = np.int64(duration * samplerate)

    # events/sec * sec/timepoint * N
    populations = np.ceil(firing_rates / samplerate * N).astype('int')
    times = np.zeros(0)
    labels = np.zeros(0)
    for k in range(1, K + 1):
        refr_timepoints = refr / 1000 * samplerate

        times0 = np.random.rand(populations[k - 1]) * (N - 1) + 1

        ## make an interesting autocorrelogram shape
        times0 = np.hstack((times0, times0 + rand_distr2(refr_timepoints, refr_timepoints * 20, times0.size)))
        times0 = times0[np.random.choice(times0.size, int(times0.size / 2))]
        times0 = times0[np.where((0 <= times0) & (times0 < N))]

        times0 = enforce_refractory_period(times0, refr_timepoints)
        times = np.hstack((times, times0))
        labels = np.hstack((labels, k * np.ones(times0.shape)))

    sort_inds = np.argsort(times)
    times = times[sort_inds]
    labels = labels[sort_inds]

    firings = np.zeros((3, times.size), dtype=np.float64)
    firings[1, :] = times
    firings[2, :] = labels
    return writemda64(firings, firings_out)
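The helpers rand_distr2 and enforce_refractory_period are not shown on this page. Purely as a plausible sketch (not the repository's actual implementation), the refractory step could keep an event only if it falls at least refr timepoints after the previously kept event:

import numpy as np

def enforce_refractory_period(times, refr):
    # Hypothetical reconstruction: greedily drop any event closer than
    # `refr` timepoints to the last event that was kept
    times = np.sort(times)
    keep = []
    last = -np.inf
    for t in times:
        if t - last >= refr:
            keep.append(t)
            last = t
    return np.array(keep)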
Example #9
def apply_label_map(*, firings, label_map, firings_out):
    """
    Apply a label map to a given firings, including masking and merging

    Parameters
    ----------
    firings : INPUT
        Path of input firings mda file
    label_map : INPUT
        Path of input label map mda file [base 1, mapping to zero removes from firings]
    firings_out : OUTPUT
        ...
    """
    firings = readmda(firings)
    label_map = readmda(label_map)
    label_map = np.reshape(label_map, (-1, 2))
    label_map = label_map[np.argsort(label_map[:, 0])]  # Ensure input is sorted

    #Propagate merge pairs to lowest label number
    for idx, label in enumerate(label_map[:, 1]):
        # jfm changed on 12/8/17 because isin() is not in older versions of numpy. :)
        #label_map[np.isin(label_map[:,0],label),0] = label_map[idx,0] # Input should be sorted
        label_map[np.where(label_map[:, 0] == label)[0], 0] = label_map[idx, 0]  # Input should be sorted

    #Apply label map
    for label_pair in range(label_map.shape[0]):
        # jfm changed on 12/8/17 because isin() is not in older versions of numpy. :)
        #firings[2, np.isin(firings[2, :], label_map[label_pair, 1])] = label_map[label_pair,0]
        firings[2, np.where(firings[2, :] == label_map[label_pair, 1])[0]] = label_map[label_pair, 0]

    #Mask out all labels mapped to zero
    firings = firings[:, firings[2, :] != 0]

    #Write remapped firings
    return writemda64(firings, firings_out)
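A tiny worked example of the mapping semantics, with the arrays built inline instead of read from mda files (column 0 is the new label, column 1 the present label, matching create_label_map above):

import numpy as np

label_map = np.array([[0, 1],    # label 1 -> masked out
                      [2, 3]])   # label 3 -> merged into label 2
firings = np.array([[0, 0, 0, 0],
                    [10, 20, 30, 40],
                    [1, 2, 3, 2]], dtype=float)
for new_label, old_label in label_map:
    firings[2, firings[2, :] == old_label] = new_label
firings = firings[:, firings[2, :] != 0]  # mask out labels mapped to zero
print(firings[2, :])                      # [2. 2. 2.]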
Example #10
def extract_subfirings(*, firings, t1='', t2='', channels='', channels_array='', timeseries='', firings_out):
    """
    Extract a firings subset based on times and/or channels.
    If a time subset is extracted, the firings are adjusted to t_new = t_original - t1.
    If channel(s) are extracted along with a timeseries, only clusters with their largest amplitude on the given channels (as determined by the average waveform in the time range) will be extracted.
    First developed for use with extract_timeseries when inspecting very large datasets.

    Parameters
    ----------
    firings : INPUT
        A path of a firings file from which a subset is extracted
    t1 : INPUT
        Start time for extracted firings
    t2 : INPUT
        End time for extracted firings; use -1 OR no value for end of timeseries
    channels : INPUT
        A string of channels from which clusters with maximal energy (based on template) will be extracted
    channels_array : INPUT
        An array of channels from which clusters with maximal energy (based on template) will be extracted
    timeseries : INPUT
        A path of a timeseries file from which templates will be calculated if a subset of channels is given
    firings_out : OUTPUT
        The extracted subfirings path
        ...
    """
    firings = readmda(firings)

    if channels:
        _channels = np.fromstring(channels, dtype=int, sep=',')
    elif len(channels_array):
        _channels = np.asarray(channels_array)
    else:
        _channels = np.empty(0)

    if t1:
        print('Time extraction...')
        t_valid = (t1 < firings[1, :])  # Boolean mask of events after t1
        if t2 and t2 > 0:
            t_valid = t_valid * (firings[1, :] < t2)
        firings = firings[:, t_valid]
    else:
        print('Using full time chunk')

    if _channels.size and timeseries:
        print('Channels extraction...')
        # Get only amplitudes; returns zeros if empty (M x T x K)
        amps = compute_templates_helper(timeseries, firings, clip_size=1)
        # Get the index of the max-amplitude channel for each cluster
        K = int(np.max(firings[2, :]))
        main_chan = np.zeros(K)
        for k in range(K):
            if np.max(amps[:, :, k]):
                main_chan[k] = np.argmax(amps[:, :, k]) + 1  # base 1 adj
        labels_valid = np.argwhere(np.isin(main_chan, _channels)) + 1  # base 1 adj again
        k_valid = np.isin(firings[2, :], labels_valid)
        firings = firings[:, k_valid]
    else:
        print('Using all channels')

    if t1:
        firings[1, :] -= t1  # shift times so t1 becomes 0

    return writemda64(firings, firings_out)
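A hedged usage sketch (file names hypothetical, 30 kHz sampling assumed): extract the events between minutes 1 and 2 whose clusters peak on channels 1 or 2:

extract_subfirings(
    firings='firings.mda',
    t1=30000 * 60,
    t2=30000 * 120,
    channels='1,2',
    timeseries='pre.mda',
    firings_out='firings.sub.mda')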
Example #11
def handle_drift_in_segment(*, timeseries, firings, firings_out):
    """
    Handle drift in segment.

    Parameters
    ----------
    timeseries : INPUT
        Path to the preprocessed timeseries from which the events are extracted (MxN)
    firings : INPUT
        Path of input firings mda file
    firings_out : OUTPUT
        Path of output drift-adjusted firings mda file
        ...
    """
    subcluster_size = 500  # Size of subclusters for comparison of merge candidate pairs
    bin_factor = 10  # subcluster_size / bin_factor = numbins for hist
    corr_comp_thresh = 0.95  # Minimum correlation in templates to consider as merge candidate
    clip_size = 50
    n_pca_dim = 10

    ## compute the templates
    templates = compute_templates_helper(timeseries=timeseries,
                                         firings=firings,
                                         clip_size=clip_size)
    templates = np.swapaxes(templates, 0, 1)
    templates = np.swapaxes(
        templates, 2, 0)  #Makes templates of form Clust x Chan x Clipsize
    firings = mlpy.readmda(firings)
    print('templates', templates.shape)

    ## Determine the merge candidate pairs based on correlation
    subflat_templates = np.reshape(templates, (templates.shape[0], -1))  # flatten templates from Clust x Chan x Clipsize to Clust x flat
    # Generate a 1D array of all possible pairwise comparisons of clusters ([0 1 2] --> [0 1 0 2 1 2])
    pairwise_idxs = np.array(
        list(
            it.chain.from_iterable(
                it.combinations(range(templates.shape[0]), 2))))
    pairwise_idxs = pairwise_idxs.reshape(-1, 2)  # Reshape to one pair per row: [[0,1],[0,2],[1,2]]
    # Calculate the correlation coefficient for each pair of flattened templates
    # (clusters are rows of subflat_templates, so index rows rather than columns)
    pairwise_corrcoef = np.zeros(pairwise_idxs.shape[0])
    for row in range(pairwise_idxs.shape[0]):
        pairwise_corrcoef[row] = np.corrcoef(
            subflat_templates[pairwise_idxs[row, 0], :],
            subflat_templates[pairwise_idxs[row, 1], :])[1, 0]
    # Threshold the correlation array and use it to index the pairwise comparison array
    pairs_for_eval = np.array(pairwise_idxs[pairwise_corrcoef >= corr_comp_thresh])
    pairs_to_merge = np.array([])  # holder variable for merge pairs
    ## Loop through the pairs that scored above the correlation threshold
    for pair_to_test in range(pairs_for_eval.shape[0]):

        ## Extract the times and labels corresponding to the pair
        # Generate subfirings of only events from the given pair; correct for the base 0 vs. 1 difference
        firings_subset = firings[:, np.isin(firings[2, :],
                                            pairs_for_eval[pair_to_test, :] + 1)]
        test_labels = firings_subset[2, :]  # Labels from the pair of clusters
        test_eventtimes = firings_subset[1, :]  # Times from the pair of clusters
        # There is no strict guarantee the firing times are sorted, so sort for safety
        sort_indices = np.argsort(test_eventtimes)
        test_labels = test_labels[sort_indices]
        test_eventtimes = test_eventtimes[sort_indices]

        ## find the subcluster times and labels
        subcluster_event_indices = find_random_paired_events(
            test_eventtimes, test_labels, subcluster_size)
        subcluster_times = test_eventtimes[subcluster_event_indices]
        subcluster_labels = test_labels[subcluster_event_indices]

        ## Extract the clips for the subcluster
        subcluster_clips = extract_clips_helper(timeseries=timeseries,
                                                times=subcluster_times,
                                                clip_size=clip_size)

        ## Compute the centroids and project the clips onto the direction of the line connecting the two centroids

        # PCA to extract features of the clips (n_pca_dim dimensions);
        # flatten the clips first since PCA expects a 2d array
        subcluster_clips = np.reshape(subcluster_clips,
                                      (subcluster_clips.shape[0], -1))
        dimenReduc = PCA(n_components=n_pca_dim, whiten=True)
        clip_features = dimenReduc.fit_transform(subcluster_clips)

        # Use label data to separate clips into two groups, and adjust for base 0 vs base 1 difference
        A_indices = np.isin(subcluster_labels,
                            pairs_for_eval[pair_to_test, 0] + 1)
        B_indices = np.isin(subcluster_labels,
                            pairs_for_eval[pair_to_test, 1] + 1)
        clip_features_A = clip_features[A_indices, :]
        clip_features_B = clip_features[B_indices, :]

        # Calculate centroid
        centroidA = np.mean(clip_features_A, axis=0)
        centroidB = np.mean(clip_features_B, axis=0)

        # Project points onto line
        V = centroidA - centroidB
        V = np.tile(V, (clip_features.shape[0], 1))
        clip_1d_projs = np.einsum('ij,ij->i', clip_features, V)

        #TODO: Test for merge subprocess
        #If the clusters are to be merged, add to the cluster to merge list
        if test_for_merge(clip_1d_projs, A_indices, B_indices):
            pairs_to_merge = np.append(pairs_to_merge,
                                       pairs_for_eval[pair_to_test, :] +
                                       1)  #Base 1 correction

    pairs_to_merge = np.reshape(pairs_to_merge, (-1, 2))  # one pair per row
    pairs_to_merge = pairs_to_merge[np.argsort(pairs_to_merge[:, 0])]  # Ensure input is sorted

    #Propagate merge pairs to lowest label number
    for idx, label in enumerate(pairs_to_merge[:, 1]):
        pairs_to_merge[np.isin(pairs_to_merge[:, 0], label),
                       0] = pairs_to_merge[idx, 0]  #Input should be sorted

    #Merge firing labels
    for merge_pair in range(pairs_to_merge.shape[0]):
        firings[2, np.isin(firings[2, :],
                           pairs_to_merge[merge_pair, 1])] = pairs_to_merge[merge_pair, 0]  # Already base 1 corrected

    #Write merged firings
    return mlpy.writemda64(firings, firings_out)
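The pairwise-index trick used at the top of this function, in isolation:

import itertools as it
import numpy as np

# All unordered pairs of cluster indices, flattened, then one pair per row
pairs = np.array(list(it.chain.from_iterable(it.combinations(range(3), 2))))
print(pairs)                 # [0 1 0 2 1 2]
print(pairs.reshape(-1, 2))  # [[0 1] [0 2] [1 2]]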
Example #12
def anneal_segments(*,
                    timeseries_list,
                    firings_list,
                    firings_out,
                    dmatrix_out='',
                    k1_dmatrix_out='',
                    k2_dmatrix_out='',
                    dmatrix_templates_out='',
                    time_offsets):
    """
    Combine a list of firings files to form a single firings file.
    Labels are linked to the first firings.mda; all other firings labels are incremented.

    Parameters
    ----------
    timeseries_list : INPUT
        A list of paths of timeseries mda files to be used for drift adjustment / time offsets
    firings_list : INPUT
        A list of paths of firings mda files to be concatenated/drift adjusted
    firings_out : OUTPUT
        The output firings
    dmatrix_out : OUTPUT
        The distance matrix used
    k1_dmatrix_out : OUTPUT
        The mean distances of k1 templates to k1 spikes
    k2_dmatrix_out : OUTPUT
        The mean distances of k2 templates to k2 spikes
    dmatrix_templates_out : OUTPUT
        The templates used to compute the distance matrix
        ...
        

    time_offsets : string
        A comma-separated list of time offsets, one for each firings file.
        ...
    """
    print('timeseries_list: ' + str(timeseries_list))
    print('firings_list: ' + str(firings_list))
    print('firings_out: ' + str(firings_out))
    print('time_offsets: ' + str(time_offsets))
    if time_offsets:
        time_offsets = np.fromstring(time_offsets, dtype=float, sep=',')
    else:
        print('No time offsets provided - assuming zero time gap/continuously recorded data')
        time_offsets = np.zeros(len(timeseries_list))
        # Offsets are based on the length of the preceding timeseries - the first is left as zero
        for idx in range(len(timeseries_list) - 1):
            X = DiskReadMda(timeseries_list[idx])
            time_offsets[idx + 1] = time_offsets[idx] + X.N2()

    concatenated_firings = concat_and_increment(firings_list, time_offsets)

    (dmatrix, k1_dmatrix, k2_dmatrix, templates,
     Kmaxes) = get_dmatrix_templates(timeseries_list, firings_list)
    dmatrix[np.isnan(dmatrix)] = -1  # set NaNs to -1 to avoid a runtime error in the comparisons below
    # Replace all negative distances (no comparison made) with NaN
    k1_dmatrix[dmatrix < 0] = np.nan
    k2_dmatrix[dmatrix < 0] = np.nan
    dmatrix[dmatrix < 0] = np.nan

    #TODO: Improve join function
    pairs_to_merge = get_join_matrix(dmatrix, k1_dmatrix, templates,
                                     Kmaxes)  # Returns with base 1 adjustment

    pairs_to_merge = np.reshape(pairs_to_merge, (-1, 2))
    pairs_to_merge = pairs_to_merge[~np.isnan(pairs_to_merge).any(axis=1)]  # Eliminate all rows with NaN
    pairs_to_merge = pairs_to_merge[np.argsort(pairs_to_merge[:, 0])]  # Ensure input is sorted

    #Propagate merge pairs to lowest label number
    for idx, label in enumerate(pairs_to_merge[:, 1]):
        pairs_to_merge[np.isin(pairs_to_merge[:, 0], label), 0] = pairs_to_merge[idx, 0]  # Input should be sorted

    #Merge firing labels
    for merge_pair in range(pairs_to_merge.shape[0]):
        concatenated_firings[2, np.isin(concatenated_firings[2, :],
                                        pairs_to_merge[merge_pair, 1])] = pairs_to_merge[merge_pair, 0]  # Already base 1 corrected

    writemda64(dmatrix, dmatrix_out)
    writemda32(templates, dmatrix_templates_out)
    writemda64(k1_dmatrix, k1_dmatrix_out)
    writemda64(k2_dmatrix, k2_dmatrix_out)

    #Write
    return writemda64(concatenated_firings, firings_out)
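The label-propagation step that several of these processors share, in isolation: chained merge pairs collapse onto the lowest label, so merging (2,5) and (5,9) relabels both 5 and 9 as 2:

import numpy as np

pairs = np.array([[2., 5.], [5., 9.]])
pairs = pairs[np.argsort(pairs[:, 0])]
for idx, label in enumerate(pairs[:, 1]):
    pairs[np.isin(pairs[:, 0], label), 0] = pairs[idx, 0]
print(pairs)  # [[2. 5.] [2. 9.]] -- both 5 and 9 now map to 2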
Example #13
def reptrack(*,
             timeseries,
             firings_out,
             detect_threshold=3,
             detect_sign=0,
             section_size=60 * 30000,
             detect_interval=20,
             detect_channel=0):
    """
    Find representative spikes for the single "best" unit that stretches all the way through the dataset

    Parameters
    ----------
    timeseries : INPUT
        The preprocessed timeseries array
    firings_out : OUTPUT
        The firings file (for the single unit)

    detect_channel : int
        Channel for detection (1-based indexing) or 0 to detect on max over all channels
    detect_threshold : float
        Threshold for detection
    detect_sign : int
        Sign for the detection -1, 0, or 1
    section_size : int
        Size of each section (in timepoints)
    """

    X = DiskReadMda(timeseries)
    M = X.N1()
    N = X.N2()
    num_sections = int(np.floor(N / section_size))
    chunk_infos = []

    S = 3  #number of scores to track

    clips_prev = np.zeros(0)
    for ii in range(0, num_sections):
        # Read the current chunk
        chunk0 = X.readChunk(i1=0, i2=ii * section_size, N1=M, N2=section_size)

        # Detect the events during this chunk and offset the times
        if (detect_channel > 0):
            signal_for_detect = chunk0[detect_channel - 1, :]
        else:
            if detect_sign == 0:
                signal_for_detect = np.max(np.abs(chunk0), axis=0)
            elif detect_sign > 0:
                signal_for_detect = np.max(chunk0, axis=0)
            else:
                signal_for_detect = np.min(chunk0, axis=0)
        times0 = detect(signal_for_detect, detect_threshold, detect_sign,
                        detect_interval)
        times0 = times0 + ii * section_size
        L0 = len(times0)

        # Extract the clips for this chunk
        clips0 = extract_clips_helper(timeseries=timeseries,
                                      times=times0,
                                      clip_size=50)
        if ii == 0:
            # If this is the first chunk, initialize things
            scores0 = np.zeros((S, L0))
            connections0 = np.ones(L0) * -1
        else:
            # Some results from the previous chunk
            times_prev = chunk_infos[ii - 1]['times']
            scores_prev = chunk_infos[ii - 1]['scores']

            # Compute PCA features on the clips from this and the previous chunk combined
            clips_combined = np.concatenate((clips_prev, clips0), axis=2)
            features_combined = compute_clips_features(clips_combined,
                                                       num_features=10)
            features0 = features_combined[:, len(times_prev):]
            features_prev = features_combined[:, 0:len(times_prev)]

            # Compute the nearest neighbors (candidates for connections)
            nbrs = NearestNeighbors(n_neighbors=50, algorithm='ball_tree')
            nbrs.fit(features_prev.transpose())
            nearest_inds = nbrs.kneighbors(features0.transpose(),
                                           return_distance=False)

            # For each, find the best connection among the candidates
            scores0 = np.zeros((S, L0))
            connections0 = np.zeros(L0)
            maxmins_prev = scores_prev[0, :]
            averages_prev = scores_prev[1, :]
            for jj in range(len(times0)):
                tmp = features0[:, jj]
                nearest_inds_jj = nearest_inds[jj, :].tolist()
                dists = np.linalg.norm(features_prev[:, nearest_inds_jj] -
                                       tmp.reshape((len(tmp), 1)),
                                       axis=0)
                normalized_distances = dists / np.linalg.norm(tmp)
                maxmins = np.maximum(normalized_distances,
                                     maxmins_prev[nearest_inds_jj])
                averages = (normalized_distances +
                            averages_prev[nearest_inds_jj] *
                            (ii + 1)) / (ii + 2)
                overall_scores = maxmins + averages * 0.1
                ind0 = np.argmin(overall_scores)
                scores0[0, jj] = maxmins[ind0]
                scores0[1, jj] = averages[ind0]
                scores0[2, jj] = overall_scores[ind0]
                connections0[jj] = nearest_inds_jj[ind0]

        clips_prev = clips0

        # Store the results for this chunk
        info0 = {
            'times': times0,
            'connections': connections0,
            'scores': scores0
        }
        chunk_infos.append(info0)

    last_chunk_info = chunk_infos[len(chunk_infos) - 1]

    last_times = last_chunk_info['times']
    last_overall_scores = last_chunk_info['scores'][S - 1, :]
    last_to_first_connections = np.zeros(len(last_times))
    for kk in range(0, len(last_times)):
        ind0 = kk
        for ii in range(len(chunk_infos) - 2, -1, -1):
            ind0 = int(chunk_infos[ii + 1]['connections'][ind0])
        last_to_first_connections[kk] = ind0

    print('Unique:')
    unique1 = np.unique(last_to_first_connections)
    print(len(unique1))
    print(len(chunk_infos[0]['times']))

    rep_times = []
    rep_labels = []
    for aa in range(0, len(unique1)):
        bb = np.where(last_to_first_connections == unique1[aa])[0]
        cc = np.argmax(last_overall_scores[bb])
        ind0 = bb[cc]
        rep_times.append(last_chunk_info['times'][ind0])
        rep_labels.append(aa)
        for ii in range(len(chunk_infos) - 1, 0, -1):
            ind0 = int(chunk_infos[ii]['connections'][ind0])
            rep_times.append(chunk_infos[ii - 1]['times'][ind0])
            rep_labels.append(aa)

    #ind0=np.argmin(last_chunk_info['scores'][S-1,:]) #Overall score is in row S-1
    #rep_times[len(chunk_infos)-1]=last_chunk_info['times'][ind0]
    #for ii in range(len(chunk_infos)-1,0,-1):
    #    ind0=int(chunk_infos[ii]['connections'][ind0])
    #    rep_times[ii-1]=chunk_infos[ii-1]['times'][ind0]

    firings = np.zeros((3, len(rep_times)))
    for jj in range(len(rep_times)):
        firings[1, jj] = rep_times[jj]
        firings[2, jj] = rep_labels[jj]
    return writemda64(firings, firings_out)
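The back-tracing at the heart of reptrack, in isolation: each chunk stores, for every detected event, the index of its best-matching event in the previous chunk, so following connections backwards from any event in the last chunk yields one candidate spike chain (toy numbers below):

import numpy as np

chunk_infos = [
    {'times': np.array([5, 9])},                             # chunk 0
    {'times': np.array([105, 111]), 'connections': [1, 0]},  # chunk 1
    {'times': np.array([207]), 'connections': [0]},          # chunk 2
]
ind0 = 0                                 # start from event 0 of the last chunk
chain = [chunk_infos[2]['times'][ind0]]
for ii in range(len(chunk_infos) - 1, 0, -1):
    ind0 = int(chunk_infos[ii]['connections'][ind0])
    chain.append(chunk_infos[ii - 1]['times'][ind0])
print(chain[::-1])  # [9, 105, 207]: one representative event per chunk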