def compute_templates_helper(*, timeseries, firings, clip_size=100):
    X = mdaio.DiskReadMda(timeseries)
    M, N = X.N1(), X.N2()  # number of channels, number of timepoints
    F = mdaio.readmda(firings)
    L = F.shape[1]  # number of events
    T = clip_size
    Tmid = int(np.floor((T + 1) / 2) - 1)
    times = F[1, :].ravel()
    labels = F[2, :].ravel().astype(int)
    K = np.max(labels)

    sums = np.zeros((M, T, K), dtype='float64')
    counts = np.zeros(K)

    for k in range(1, K + 1):
        inds_k = np.where(labels == k)[0]
        # TODO: subsample
        for ind_k in inds_k:
            t0 = int(times[ind_k])
            if (clip_size <= t0) and (t0 < N - clip_size):
                clip0 = X.readChunk(i1=0, N1=M, i2=t0 - Tmid, N2=T)
                sums[:, :, k - 1] += clip0
                counts[k - 1] += 1
    templates = np.zeros((M, T, K))
    for k in range(K):
        if counts[k]:  # empty clusters keep an all-zero template
            templates[:, :, k] = sums[:, :, k] / counts[k]
    return templates
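
# Usage sketch (not part of the original source): build a tiny synthetic recording and
# firings file on disk, then average the clips per cluster. The mdaio import path and the
# file names are assumptions.
import numpy as np
from mountainlab_pytools import mdaio

def _demo_compute_templates():
    M, N, T = 4, 10000, 100
    rng = np.random.RandomState(0)
    mdaio.writemda32(rng.randn(M, N).astype('float32'), 'demo_raw.mda')
    times = rng.randint(T, N - T, size=50)     # event times, safely away from the edges
    labels = rng.randint(1, 4, size=50)        # cluster labels 1..3
    firings = np.stack([np.zeros(50), times, labels]).astype('float64')
    mdaio.writemda32(firings, 'demo_firings.mda')
    templates = compute_templates_helper(timeseries='demo_raw.mda',
                                         firings='demo_firings.mda',
                                         clip_size=T)
    print(templates.shape)                     # (channels, clip_size, K) == (4, 100, 3)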
Example #2
def prepare_timeseries_hdf5(timeseries_fname,timeseries_hdf5_fname,*,chunk_size,padding):
    chunk_size_with_padding=chunk_size+2*padding
    with h5py.File(timeseries_hdf5_fname,"w") as f:
        X=mdaio.DiskReadMda(timeseries_fname)
        M=X.N1() # Number of channels
        N=X.N2() # Number of timepoints
        num_chunks=int(np.ceil(N/chunk_size))
        f.create_dataset('chunk_size',data=[chunk_size])
        f.create_dataset('num_chunks',data=[num_chunks])
        f.create_dataset('padding',data=[padding])
        f.create_dataset('num_channels',data=[M])
        f.create_dataset('num_timepoints',data=[N])
        for j in range(num_chunks):
            padded_chunk=np.zeros((X.N1(),chunk_size_with_padding),dtype=X.dt())    
            t1=int(j*chunk_size) # first timepoint of the chunk
            t2=int(np.minimum(X.N2(),(t1+chunk_size))) # last timepoint of chunk (+1)
            s1=int(np.maximum(0,t1-padding)) # first timepoint including the padding
            s2=int(np.minimum(X.N2(),t2+padding)) # last timepoint (+1) including the padding
            
            # determine aa so that t1-s1+aa = padding
            # so, aa = padding-(t1-s1)
            aa = padding-(t1-s1)
            padded_chunk[:,aa:aa+s2-s1]=X.readChunk(i1=0,N1=X.N1(),i2=s1,N2=s2-s1) # Read the padded chunk

            for m in range(M):
                f.create_dataset('part-{}-{}'.format(m,j),data=padded_chunk[m,:].ravel())
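
# Companion sketch (not part of the original source): read one chunk back from the HDF5
# file written by prepare_timeseries_hdf5 and strip the padding again. The dataset names
# match those created above; the function name is hypothetical.
import h5py
import numpy as np

def read_chunk_without_padding(timeseries_hdf5_fname, j):
    with h5py.File(timeseries_hdf5_fname, 'r') as f:
        chunk_size = int(f['chunk_size'][0])
        padding = int(f['padding'][0])
        M = int(f['num_channels'][0])
        N = int(f['num_timepoints'][0])
        t1 = j * chunk_size                    # first timepoint of the chunk
        t2 = min(N, t1 + chunk_size)           # last timepoint of the chunk (+1)
        rows = [f['part-{}-{}'.format(m, j)][()] for m in range(M)]
        padded_chunk = np.stack(rows)          # M x (chunk_size + 2*padding)
        # Timepoint t1 always lands at column `padding` (that is what aa enforces above),
        # so dropping `padding` columns on each side recovers the un-padded chunk.
        return padded_chunk[:, padding:padding + (t2 - t1)]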
Example #3
def bandpass_filter(*,
                    timeseries,
                    timeseries_out,
                    samplerate,
                    freq_min,
                    freq_max,
                    freq_wid=1000,
                    padding=3000,
                    chunk_size=3000 * 10,
                    num_processes=os.cpu_count()):
    """
    Apply a bandpass filter to a multi-channel timeseries

    Parameters
    ----------
    timeseries : INPUT
        MxN raw timeseries array (M = #channels, N = #timepoints)
        
    timeseries_out : OUTPUT
        Filtered output (MxN array)
        
    samplerate : float
        The sampling rate in Hz
    freq_min : float
        The lower endpoint of the frequency band (Hz)
    freq_max : float
        The upper endpoint of the frequency band (Hz)
    freq_wid : float
        The optional width of the roll-off (Hz)
    """
    X = mdaio.DiskReadMda(timeseries)
    M = X.N1()  # Number of channels
    N = X.N2()  # Number of timepoints

    num_chunks = int(np.ceil(N / chunk_size))
    print('Chunk size: {}, Padding: {}, Num chunks: {}, Num processes: {}'.
          format(chunk_size, padding, num_chunks, num_processes))

    opts = {
        "timeseries": timeseries,
        "timeseries_out": timeseries_out,
        "samplerate": samplerate,
        "freq_min": freq_min,
        "freq_max": freq_max,
        "freq_wid": freq_wid,
        "chunk_size": chunk_size,
        "padding": padding,
        "num_processes": num_processes,
        "num_chunks": num_chunks
    }
    global g_shared_data
    g_shared_data = SharedChunkInfo(num_chunks)
    global g_opts
    g_opts = opts
    mdaio.writemda32(np.zeros([M, 0]), timeseries_out)

    pool = multiprocessing.Pool(processes=num_processes)
    pool.map(filter_chunk, range(num_chunks), chunksize=1)
    return True
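
# Usage sketch (not part of the original source): filter a whole recording in one call.
# The file names are hypothetical; 300-6000 Hz is a typical extracellular band at 30 kHz.
# The __main__ guard matters because bandpass_filter spawns a multiprocessing.Pool.
if __name__ == '__main__':
    bandpass_filter(timeseries='raw.mda',
                    timeseries_out='filtered.mda',
                    samplerate=30000,
                    freq_min=300,
                    freq_max=6000)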
def view_timeseries(timeseries,
                    trange=None,
                    channels=None,
                    samplerate=30000,
                    title='',
                    fig_size=[18, 6]):
    if type(timeseries) == str:
        X = mdaio.DiskReadMda(timeseries)
        M = X.N1()
        N = X.N2()
    else:
        M = timeseries.shape[0]
        N = timeseries.shape[1]

    # Fill in the defaults before slicing so they also apply to in-memory arrays
    if not trange:
        trange = [0, np.minimum(1000, N)]
    if not channels:
        channels = np.arange(M).tolist()

    if type(timeseries) == str:
        X = X.readChunk(i1=0,
                        N1=M,
                        i2=int(trange[0]),
                        N2=int(trange[1] - trange[0]))
        X = X[channels, :]  # keep only the requested channels
    else:
        X = timeseries[channels][:, int(trange[0]):int(trange[1])]

    set_fig_size(fig_size[0], fig_size[1])

    channel_colors = _get_channel_colors(M)

    spacing_between_channels = np.max(np.abs(X.ravel()))

    y_offset = 0
    for m in range(len(channels)):
        A = X[m, :]
        plt.plot(np.arange(trange[0], trange[1]),
                 A + y_offset,
                 color=channel_colors[channels[m]])
        y_offset -= spacing_between_channels

    ax = plt.gca()
    ax.axes.get_xaxis().set_visible(False)
    ax.axes.get_yaxis().set_visible(False)

    if title:
        plt.title(title, fontsize=title_fontsize)

    plt.show()
    return ax
Example #5
def compute_AAt_matrix_for_chunk(num):
    opts = g_opts
    in_fname = opts['timeseries']  # The entire (large) input file
    out_fname = opts['timeseries_out']  # The entire (large) output file
    chunk_size = opts['chunk_size']

    X = mdaio.DiskReadMda(in_fname)

    t1 = int(num * opts['chunk_size'])  # first timepoint of the chunk
    t2 = int(np.minimum(X.N2(),
                        (t1 + chunk_size)))  # last timepoint of chunk (+1)

    chunk = X.readChunk(i1=0, N1=X.N1(), i2=t1, N2=t2 - t1)  # Read the chunk

    ret = chunk @ np.transpose(chunk)

    return ret
Example #6
def whiten_chunk(num, W):
    #print('Whitening {}'.format(num))
    opts = g_opts
    #print('Whitening chunk {} of {}'.format(num,opts['num_chunks']))
    in_fname = opts['timeseries']  # The entire (large) input file
    out_fname = opts['timeseries_out']  # The entire (large) output file
    chunk_size = opts['chunk_size']

    X = mdaio.DiskReadMda(in_fname)

    t1 = int(num * opts['chunk_size'])  # first timepoint of the chunk
    t2 = int(np.minimum(X.N2(),
                        (t1 + chunk_size)))  # last timepoint of chunk (+1)

    chunk = X.readChunk(i1=0, N1=X.N1(), i2=t1, N2=t2 - t1)  # Read the chunk

    chunk = W @ chunk

    ###########################################################################################
    # Now we wait until we are ready to append to the output file
    # Note that we need to append in order, thus the shared_data object
    ###########################################################################################
    g_shared_data.reportChunkCompleted(
        num)  # Report that we have completed this chunk
    while True:
        if num == g_shared_data.lastAppendedChunk() + 1:
            break
        time.sleep(0.005)  # so we don't saturate the CPU unnecessarily

    # Append the filtered chunk (excluding the padding) to the output file
    mdaio.appendmda(chunk, out_fname)

    # Report that we have appended so the next chunk can proceed
    g_shared_data.reportChunkAppended(num)

    # Print status if it has been long enough
    if g_shared_data.elapsedTime() > 4:
        g_shared_data.printStatus()
        g_shared_data.resetTimer()
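
# SharedChunkInfo is not defined in these examples. The chunk workers above only need a
# small interface: report a chunk as computed, find out which chunk was appended last,
# report a chunk as appended, and keep a timer for periodic status printing. A minimal
# sketch of that interface, assuming a multiprocessing.Manager-backed implementation
# (the original class may differ):
import time
import multiprocessing

class SharedChunkInfoSketch:
    def __init__(self, num_chunks):
        self._num_chunks = num_chunks
        manager = multiprocessing.Manager()
        self._state = manager.dict(num_completed=0, last_appended=-1, timer=time.time())
        self._lock = manager.Lock()

    def reportChunkCompleted(self, num):
        with self._lock:
            self._state['num_completed'] += 1

    def reportChunkAppended(self, num):
        with self._lock:
            self._state['last_appended'] = num

    def lastAppendedChunk(self):
        return self._state['last_appended']

    def elapsedTime(self):
        return time.time() - self._state['timer']

    def resetTimer(self):
        with self._lock:
            self._state['timer'] = time.time()

    def printStatus(self):
        print('Processed {} of {} chunks...'.format(self._state['num_completed'],
                                                     self._num_chunks))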
Example #7
    def run(self, mdafile_path_or_diskreadmda, func):
        if (type(mdafile_path_or_diskreadmda)==str):
            X=mdaio.DiskReadMda(mdafile_path_or_diskreadmda)
        else:
            X=mdafile_path_or_diskreadmda
        M,N = X.N1(),X.N2()
        cs=max([self._chunk_size,int(self._chunk_size_mb*1e6/(M*4)),M])
        if self._t1<0:
            self._t1=0
        if self._t2<0:
            self._t2=N-1
        t=self._t1
        while t <= self._t2:
            t1=t
            t2=min(self._t2,t+cs-1)
            s1=max(0,t1-self._overlap_size)
            s2=min(N-1,t2+self._overlap_size)

            timer=time.time()
            chunk=X.readChunk(i1=0, N1=M, i2=s1, N2=s2-s1+1)
            self._elapsed_reading+=time.time()-timer

            info=TimeseriesChunkInfo()
            info.t1=t1
            info.t2=t2
            info.t1a=t1-s1
            info.t2a=t2-s1
            info.size=t2-t1+1

            timer=time.time()
            if not func(chunk, info):
                return False
            self._elapsed_running+=time.time()-timer

            t=t+cs
        if self._verbose:
            print('Elapsed for TimeseriesChunkReader: %g sec reading, %g sec running' % (self._elapsed_reading,self._elapsed_running))
        return True
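
# Usage sketch (not part of the original source): a callback for TimeseriesChunkReader.run.
# Columns info.t1a .. info.t2a of each chunk are the chunk proper; anything outside that
# range is overlap context. The reader construction below (and the attribute names it
# relies on) is an assumption based on the run() method above.
import numpy as np

def accumulate_sumsquares(mda_path):
    total = {'value': None}

    def _process(chunk, info):
        core = chunk[:, info.t1a:info.t2a + 1]        # drop any overlap columns
        ss = np.sum(core.astype('float64') ** 2, axis=1)
        total['value'] = ss if total['value'] is None else total['value'] + ss
        return True                                   # returning False aborts run()

    reader = TimeseriesChunkReader(chunk_size_mb=100, overlap_size=0)
    if not reader.run(mda_path, _process):
        raise RuntimeError('TimeseriesChunkReader aborted')
    return total['value']                             # per-channel sum of squares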
Example #8
def filter_chunk(num):
    #print('Filtering {}'.format(num))
    opts = g_opts
    #print('Filtering chunk {} of {}'.format(num,opts['num_chunks']))
    in_fname = opts['timeseries']  # The entire (large) input file
    out_fname = opts['timeseries_out']  # The entire (large) output file
    samplerate = opts['samplerate']
    freq_min = opts['freq_min']
    freq_max = opts['freq_max']
    freq_wid = opts['freq_wid']
    chunk_size = opts['chunk_size']
    padding = opts['padding']

    X = mdaio.DiskReadMda(in_fname)

    chunk_size_with_padding = chunk_size + 2 * padding
    padded_chunk = np.zeros((X.N1(), chunk_size_with_padding), dtype='float32')

    t1 = int(num * opts['chunk_size'])  # first timepoint of the chunk
    t2 = int(np.minimum(X.N2(),
                        (t1 + chunk_size)))  # last timepoint of chunk (+1)
    s1 = int(np.maximum(0,
                        t1 - padding))  # first timepoint including the padding
    s2 = int(np.minimum(X.N2(), t2 +
                        padding))  # last timepoint (+1) including the padding

    # determine aa so that t1-s1+aa = padding
    # so, aa = padding-(t1-s1)
    aa = padding - (t1 - s1)
    padded_chunk[:, aa:aa + s2 - s1] = X.readChunk(i1=0,
                                                   N1=X.N1(),
                                                   i2=s1,
                                                   N2=s2 -
                                                   s1)  # Read the padded chunk

    # Do the actual filtering with a DFT with real input
    padded_chunk = np.fft.rfft(padded_chunk)
    # Subtract off the mean of each channel unless we are doing only a low-pass filter
    if freq_min != 0:
        for m in range(padded_chunk.shape[0]):
            padded_chunk[m, :] = padded_chunk[m, :] - np.mean(
                padded_chunk[m, :])
    kernel = create_filter_kernel(chunk_size_with_padding, samplerate,
                                  freq_min, freq_max, freq_wid)
    kernel = kernel[
        0:padded_chunk.shape[1]]  # because this is the DFT of real data
    padded_chunk = padded_chunk * np.tile(kernel, (padded_chunk.shape[0], 1))
    padded_chunk = np.fft.irfft(padded_chunk)

    ###########################################################################################
    # Now we wait until we are ready to append to the output file
    # Note that we need to append in order, thus the shared_data object
    ###########################################################################################
    g_shared_data.reportChunkCompleted(
        num)  # Report that we have completed this chunk
    while True:  # Alex: maybe there should be a timeout here in case ...
        if num == g_shared_data.lastAppendedChunk() + 1:
            break
        time.sleep(0.005)  # so we don't saturate the CPU unnecessarily

    # Append the filtered chunk (excluding the padding) to the output file
    mdaio.appendmda(padded_chunk[:, padding:padding + (t2 - t1)], out_fname)

    # Report that we have appended so the next chunk can proceed
    g_shared_data.reportChunkAppended(num)

    # Print status if it has been long enough
    if g_shared_data.elapsedTime() > 4:
        g_shared_data.printStatus()
        g_shared_data.resetTimer()
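
# create_filter_kernel is not shown in these examples. One plausible construction of a
# smooth frequency-domain band-pass kernel with roll-off width freq_wid is sketched below;
# it is an assumption about the interface (a length-N response, later truncated to the
# rfft length by filter_chunk), not the original implementation.
import numpy as np
from scipy.special import erf

def create_filter_kernel_sketch(N, samplerate, freq_min, freq_max, freq_wid=1000):
    freqs = np.abs(np.fft.fftfreq(N, d=1.0 / samplerate))  # frequency (Hz) of each DFT bin
    kernel = np.ones(N)
    if freq_min:  # smooth high-pass edge; freq_min == 0 leaves a pure low-pass
        kernel *= 0.5 * (1 + erf((freqs - freq_min) / freq_wid))
    if freq_max:  # smooth low-pass edge
        kernel *= 0.5 * (1 - erf((freqs - freq_max) / freq_wid))
    return kernel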
Example #9
def convert_clips_fet_spk(*,
                          timeseries,
                          firings,
                          waveforms_out,
                          ntro_nchannels,
                          clip_size=32,
                          nPC=4,
                          DEBUG=True):
    """
    Extract spike waveform clips and PCA features for the labeled events in firings, writing one .spk.n file and one .fet.n file per n-trode.

    Parameters
    ----------
    timeseries : INPUT
        Path of timeseries mda file (MxN) from which to draw the event clips (snippets) for computing the templates. M is number of channels, N is number of timepoints.
    firings : INPUT
        Path of firings mda file (RxL) where R>=3 and L is the number of events. Second row are timestamps, third row are integer labels.
    ntro_nchannels : INPUT
        Comma-separated list giving the number of channels in each n-trode (e.g. '4,4,4,4' for four tetrodes).

    waveforms_out : OUTPUT
        Base path for the outputs. One .spk.n file (int16 waveform clips) and one .fet.n file (PCA features plus event times) is written per n-trode n.
        
    clip_size : int
        (Optional) clip size, aka snippet size, number of timepoints in a single template
    
    nPC : int
        (Optional) Number of principal components *per channel* for .fet files.
        
    DEBUG : bool
        (Optional) Verbose output for debugging purposes.
    """
    X = mdaio.DiskReadMda(timeseries)
    M, N = X.N1(), X.N2()
    F = mdaio.readmda(firings)
    L = F.shape[1]  # number of events
    T = clip_size
    Tmid = int(np.floor((T + 1) / 2) - 1)
    whch = F[0, :].ravel()  # channel on which each event was detected (1-based)
    times = F[1, :].ravel()
    labels = F[2, :].ravel().astype(int)
    K = np.max(labels)

    tetmap = list()
    i = 0
    for nch in ntro_nchannels.split(','):
        tp = i + 1 + np.arange(0, int(nch))
        tetmap.append(tp)
        i = tp[-1]
    # n-trode (1-based) containing each event's channel
    which_tet = np.array(
        [next(i + 1 for i, tp in enumerate(tetmap) if w in tp) for w in whch])

    print("Starting:")
    for tro in np.arange(1, len(tetmap) + 1):  # one pass per n-trode
        chans = tetmap[tro - 1]
        inds_k = np.where(which_tet == tro)[0]

        if DEBUG:
            print("Tetrode: " + str(tro))
            print("Chans: " + str(chans))
            print("Create Waveforms Array: " + str(len(chans)) + "," + str(T) +
                  "," + str(len(inds_k)))

        waveforms = np.zeros((len(chans), T, len(inds_k)), dtype='int16')
        for i, ind_k in enumerate(inds_k):  # for each spike
            t0 = int(times[ind_k])
            if (clip_size <= t0) and (t0 < N - clip_size):
                clip0 = X.readChunk(i1=0, N1=M, i2=t0 - Tmid, N2=T)
                clip0 = clip0[chans - 1, :] * 100  # chans holds 1-based channel numbers
                clip_int = clip0.astype(dtype='int16')
                waveforms[:, :, i] = clip_int

        fname = waveforms_out + '.spk.' + str(tro)
        if DEBUG:
            print("Writing Waveforms to File: " + fname)
        waveforms.tofile(fname, format='')

        if DEBUG:
            print("Calculating Feature Array")
        fet = np.zeros((np.shape(waveforms)[2], (len(chans) * nPC) + 1))
        for c in np.arange(len(chans)):
            pca = decomposition.PCA(n_components=nPC)
            x_std = StandardScaler().fit_transform(
                np.transpose(waveforms[c, :, :]).astype(dtype='float64'))
            fpos = (c * nPC)
            fet[:, fpos:fpos + nPC] = pca.fit_transform(x_std)

        fet[:, (len(chans) * nPC)] = times[inds_k]
        fet *= 1000
        fet = fet.astype('int64')  # astype returns a copy; keep the result
        fname = waveforms_out + '.fet.' + str(tro)
        if DEBUG:
            print("Writing Features to File: " + fname)
        np.savetxt(fname, fet, fmt='%d')

    return True
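
# Companion sketch (not part of the original source): read one of the .spk.n files written
# above back into a (channels x clip_size x spikes) int16 array. The shape arguments must
# match what convert_clips_fet_spk used for that n-trode; the function name is hypothetical.
import numpy as np

def read_spk_file(fname, n_channels, clip_size):
    data = np.fromfile(fname, dtype='int16')
    n_spikes = data.size // (n_channels * clip_size)
    return data.reshape(n_channels, clip_size, n_spikes)  # same C-order layout as tofile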
def compute_cluster_metrics(*,
                            timeseries='',
                            firings,
                            metrics_out,
                            clip_size=100,
                            samplerate=0,
                            refrac_msec=1):
    """
    Compute cluster metrics for a spike sorting output

    Parameters
    ----------
    firings : INPUT
        Path of firings mda file (RxL) where R>=3 and L is the number of events. Second row are timestamps, third row are integer cluster labels.
    timeseries : INPUT
        Optional path of timeseries mda file (MxN) which could be raw or preprocessed   
        
    metrics_out : OUTPUT
        Path of output json file containing the metrics.
        
    clip_size : int
        (Optional) clip size, aka snippet size (used when computing the templates, or average waveforms)
    samplerate : float
        Optional sample rate in Hz

    refrac_msec : float
        (Optional) Define interval (in ms) as refractory period. If 0 then don't compute the refractory period metric.
    """
    print('Reading firings...')
    F = mdaio.readmda(firings)

    print('Initializing...')
    R = F.shape[0]
    L = F.shape[1]
    assert (R >= 3)
    times = F[1, :]
    labels = F[2, :].astype(int)  # np.int is removed in recent NumPy versions
    K = np.max(labels)
    N = 0
    if timeseries:
        X = mdaio.DiskReadMda(timeseries)
        N = X.N2()

    if (samplerate > 0) and (N > 0):
        duration_sec = N / samplerate
    else:
        duration_sec = 0

    clusters = []
    for k in range(1, K + 1):
        inds_k = np.where(labels == k)[0]
        metrics_k = {"num_events": len(inds_k)}
        if duration_sec:
            metrics_k['firing_rate'] = len(inds_k) / duration_sec
        cluster = {"label": k, "metrics": metrics_k}
        clusters.append(cluster)

    if timeseries:
        print('Computing templates...')
        templates = compute_templates_helper(timeseries=timeseries,
                                             firings=firings,
                                             clip_size=clip_size)
        for k in range(1, K + 1):
            template_k = templates[:, :, k - 1]
            # Subtract the mean of each channel (TODO: vectorize this operation)
            for m in range(templates.shape[0]):
                template_k[m, :] = template_k[m, :] - np.mean(template_k[m, :])
            peak_amplitude = np.max(np.abs(template_k))
            clusters[k - 1]['metrics']['peak_amplitude'] = peak_amplitude

    if refrac_msec > 0:
        print('Computing Refractory Period Violations')
        msec_string = '{0:.5f}'.format(refrac_msec).rstrip('0').rstrip('.')
        rr_string = 'refractory_violation_{}msec'.format(msec_string)
        max_dt_msec = 50
        bin_size_msec = refrac_msec
        auto_cors = compute_cross_correlograms_helper(
            firings=firings,
            mode='autocorrelograms',
            samplerate=samplerate,
            max_dt_msec=max_dt_msec,
            bin_size_msec=bin_size_msec)
        mid_ind = np.floor(max_dt_msec / bin_size_msec).astype(int)
        mid_inds = np.arange(mid_ind - 2, mid_ind + 2)
        for k0, auto_cor_obj in enumerate(auto_cors['correlograms']):
            auto = auto_cor_obj['bin_counts']
            k = auto_cor_obj['k']
            peak = np.percentile(auto, 75)
            mid = np.mean(auto[mid_inds])
            rr = safe_div(mid, peak)
            clusters[k - 1]['metrics'][rr_string] = rr

    ret = {"clusters": clusters}

    print('Writing output...')
    json_str = json.dumps(ret, indent=4)  # avoid shadowing the builtin str
    with open(metrics_out, 'w') as out:
        out.write(json_str)
    print('Done.')
    return True
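
# safe_div is not defined in these examples; a minimal definition consistent with how it
# is used above (return 0 instead of raising when the denominator is 0):
def safe_div(numerator, denominator):
    return numerator / denominator if denominator else 0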
Example #11
def whiten(*,
           timeseries,
           timeseries_out,
           chunk_size=30000 * 10,
           num_processes=os.cpu_count()):
    """
    Whiten a multi-channel timeseries

    Parameters
    ----------
    timeseries : INPUT
        MxN raw timeseries array (M = #channels, N = #timepoints)
        
    timeseries_out : OUTPUT
        Whitened output (MxN array)

    """
    X = mdaio.DiskReadMda(timeseries)
    M = X.N1()  # Number of channels
    N = X.N2()  # Number of timepoints

    num_chunks_for_computing_cov_matrix = 10

    num_chunks = int(np.ceil(N / chunk_size))
    print('Chunk size: {}, Num chunks: {}, Num processes: {}'.format(
        chunk_size, num_chunks, num_processes))

    opts = {
        "timeseries": timeseries,
        "timeseries_out": timeseries_out,
        "chunk_size": chunk_size,
        "num_processes": num_processes,
        "num_chunks": num_chunks
    }
    global g_opts
    g_opts = opts

    pool = multiprocessing.Pool(processes=num_processes)
    step = int(
        np.maximum(1,
                   np.floor(num_chunks / num_chunks_for_computing_cov_matrix)))
    AAt_matrices = pool.map(compute_AAt_matrix_for_chunk,
                            range(0, num_chunks, step),
                            chunksize=1)

    AAt = np.zeros((M, M), dtype='float64')

    for M0 in AAt_matrices:
        AAt += M0 / (
            len(AAt_matrices) * chunk_size
        )  ##important: need to fix the denominator here to account for possible smaller chunk

    U, S, Ut = np.linalg.svd(AAt, full_matrices=True)

    W = (U @ np.diag(1 / np.sqrt(S))) @ Ut
    #print ('Whitening matrix:')
    #print (W)

    global g_shared_data
    g_shared_data = SharedChunkInfo(num_chunks)
    mdaio.writemda32(np.zeros([M, 0]), timeseries_out)

    pool = multiprocessing.Pool(processes=num_processes)
    pool.starmap(whiten_chunk, [(num, W) for num in range(0, num_chunks)],
                 chunksize=1)

    return True
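
# Quick numerical check (not part of the original source) of the whitening construction
# used above: for a symmetric positive-definite AAt with SVD U*S*Ut, the matrix
# W = U diag(1/sqrt(S)) Ut satisfies W @ AAt @ W.T ~= identity.
import numpy as np

def _check_whitening_matrix(M=4, seed=0):
    rng = np.random.RandomState(seed)
    A = rng.randn(M, 5 * M)
    AAt = A @ A.T / A.shape[1]      # a random symmetric positive-definite matrix
    U, S, Ut = np.linalg.svd(AAt, full_matrices=True)
    W = (U @ np.diag(1 / np.sqrt(S))) @ Ut
    assert np.allclose(W @ AAt @ W.T, np.eye(M))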
Example #12
def sort(*,
         timeseries,
         geom='',
         firings_out,
         adjacency_radius,
         detect_sign,
         detect_interval=10,
         detect_threshold=3,
         clip_size=50,
         num_workers=multiprocessing.cpu_count()):
    """
    MountainSort spike sorting (version 4)

    Parameters
    ----------
    timeseries : INPUT
        MxN raw timeseries array (M = #channels, N = #timepoints)
    geom : INPUT
        Optional geometry file (.csv format)
        
    firings_out : OUTPUT
        Firings array channels/times/labels (3xL, L = num. events)
        
    adjacency_radius : float
        Radius of local sorting neighborhood, corresponding to the geometry file (same units). 0 means each channel is sorted independently. -1 means all channels are included in every neighborhood.
    detect_sign : int
        Use 1, -1, or 0 to detect positive peaks, negative peaks, or both, respectively
    detect_threshold : float
        Threshold for event detection, corresponding to the input file. So if the input file is normalized to have noise standard deviation 1 (e.g., whitened), then this is in units of std. deviations away from the mean.
    detect_interval : int
        The minimum number of timepoints between adjacent spikes detected in the same channel neighborhood.
    clip_size : int
        Size of extracted clips or snippets, used throughout
    num_workers : int
        Number of simultaneous workers (or processes). The default is multiprocessing.cpu_count().
    """

    tempdir = os.environ.get('ML_PROCESSOR_TEMPDIR')
    if not tempdir:
        print(
            'Warning: environment variable ML_PROCESSOR_TEMPDIR not set. Using current directory.'
        )
        tempdir = '.'
    print('Using tempdir={}'.format(tempdir))

    os.environ['OMP_NUM_THREADS'] = '1'

    # Read the header of the timeseries input to get the num. channels and num. timepoints
    X = mdaio.DiskReadMda(timeseries)
    M = X.N1()  # Number of channels
    N = X.N2()  # Number of timepoints

    # Read the geometry file
    if geom:
        Geom = np.genfromtxt(geom, delimiter=',')
    else:
        Geom = np.zeros((M, 2))

    if Geom.shape[0] != M:
        raise Exception(
            'Incompatible dimensions between geom and timeseries: {} != {}'.
            format(Geom.shape[0], M))

    MS4 = ms4alg.MountainSort4()
    MS4.setGeom(Geom)
    MS4.setSortingOpts(clip_size=clip_size,
                       adjacency_radius=adjacency_radius,
                       detect_sign=detect_sign,
                       detect_interval=detect_interval,
                       detect_threshold=detect_threshold)
    MS4.setNumWorkers(num_workers)
    MS4.setTimeseriesPath(timeseries)
    MS4.setFiringsOutPath(firings_out)
    MS4.setTemporaryDirectory(tempdir)
    MS4.sort()
    return True
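
# Usage sketch (not part of the original source): a typical call after filtering and
# whitening. The file names are hypothetical; ML_PROCESSOR_TEMPDIR may be set beforehand
# to control where ms4alg writes its temporary files.
if __name__ == '__main__':
    sort(timeseries='whitened.mda',
         firings_out='firings.mda',
         geom='geom.csv',
         adjacency_radius=100,
         detect_sign=-1,
         detect_threshold=3)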