Example #1
def get_epoch_offsets(dataset_dir, opts=None):
    '''
    Parameters
    ----------
    dataset_dir : str
        Directory containing raw.mda.prv (used when opts has no 'mda_list')
    opts : None or dict, optional
        If it contains 'mda_list', a list of per-epoch .mda file paths, those files are used directly

    Returns
    -------
    sample_offsets : list of int
        Start sample (offset) of each epoch; the first entry is 0
    total_samples : int
        Total number of samples summed across all epochs

    '''
    if opts is None:
        opts = {}
    if 'mda_list' in opts:
        # initialize with 0 (first start time)
        lengths = [0]

        for idx in range(len(opts['mda_list'])):
            ep_path = opts['mda_list'][idx]
            ep_mda = mdaio.DiskReadMda(ep_path)
            # get length of the mda (N dimension)
            samplength = ep_mda.N2()
            # add to prior sum and append
            lengths.append(samplength + lengths[idx])

    else:

        prv_list = os.path.join(dataset_dir, 'raw.mda.prv')

        with open(prv_list, 'r') as f:
            ep_files = json.load(f)

        # initialize with 0 (first start time)
        lengths = [0]

        for idx in range(len(ep_files['files'])):
            ep_path = ep_files['files'][idx]['prv']['original_path']
            ep_mda = mdaio.DiskReadMda(ep_path)
            # get length of the mda (N dimension)
            samplength = ep_mda.N2()
            # add to prior sum and append
            lengths.append(samplength + lengths[idx])

    # first entries (incl 0) are starttimes; last is total time
    total_samples = lengths[-1]
    sample_offsets = lengths[0:-1]

    return sample_offsets, total_samples
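# Usage sketch (not part of the original example; the dataset path is
# hypothetical): the offsets returned above let you map a global sample index
# in the concatenated recording back to (epoch, sample-within-epoch).
def locate_sample(global_index, sample_offsets):
    # sample_offsets[i] is the first global sample of epoch i, in ascending order
    epoch = max(i for i, off in enumerate(sample_offsets) if off <= global_index)
    return epoch, global_index - sample_offsets[epoch]

# sample_offsets, total_samples = get_epoch_offsets('/path/to/dataset')
# epoch, local_index = locate_sample(123456, sample_offsets)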
Example #2
def compute_templates_helper(*, timeseries, firings, clip_size=100):
    X = mdaio.DiskReadMda(timeseries)
    M, N = X.N1(), X.N2()
    F = mdaio.readmda(firings)
    L = F.shape[1]
    T = clip_size
    Tmid = int(np.floor((T + 1) / 2) - 1)
    times = F[1, :].ravel()
    labels = F[2, :].ravel().astype(int)
    K = np.max(labels)

    sums = np.zeros((M, T, K), dtype='float64')
    counts = np.zeros(K)

    for k in range(1, K + 1):
        inds_k = np.where(labels == k)[0]
        #TODO: subsample
        for ind_k in inds_k:
            t0 = int(times[ind_k])
            if (clip_size <= t0) and (t0 < N - clip_size):
                clip0 = X.readChunk(i1=0, N1=M, i2=t0 - Tmid, N2=T)
                sums[:, :, k - 1] += clip0
                counts[k - 1] += 1
    templates = np.zeros((M, T, K))
    for k in range(K):
        if counts[k] > 0:  # leave empty clusters as all-zero templates
            templates[:, :, k] = sums[:, :, k] / counts[k]
    return templates
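# Usage sketch (illustrative; the .mda file names are hypothetical): the result
# is an M x T x K array holding one clip_size-long average waveform per channel
# and per cluster label 1..K.
#
# templates = compute_templates_helper(timeseries='pre.mda',
#                                      firings='firings.mda',
#                                      clip_size=100)
# print(templates.shape)  # (num_channels, clip_size, num_clusters)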
Example #3
def prepare_timeseries_hdf5(timeseries_fname, timeseries_hdf5_fname, *,
                            chunk_size, padding):
    chunk_size_with_padding = chunk_size + 2 * padding
    with h5py.File(timeseries_hdf5_fname, "w") as f:
        X = mdaio.DiskReadMda(timeseries_fname)
        M = X.N1()  # Number of channels
        N = X.N2()  # Number of timepoints
        num_chunks = int(np.ceil(N / chunk_size))
        f.create_dataset('chunk_size', data=[chunk_size])
        f.create_dataset('num_chunks', data=[num_chunks])
        f.create_dataset('padding', data=[padding])
        f.create_dataset('num_channels', data=[M])
        f.create_dataset('num_timepoints', data=[N])
        for j in range(num_chunks):
            padded_chunk = np.zeros((X.N1(), chunk_size_with_padding),
                                    dtype=X.dt())
            t1 = int(j * chunk_size)  # first timepoint of the chunk
            t2 = int(np.minimum(
                X.N2(), (t1 + chunk_size)))  # last timepoint of chunk (+1)
            s1 = int(np.maximum(
                0, t1 - padding))  # first timepoint including the padding
            s2 = int(np.minimum(
                X.N2(),
                t2 + padding))  # last timepoint (+1) including the padding

            # determine aa so that t1-s1+aa = padding
            # so, aa = padding-(t1-s1)
            aa = padding - (t1 - s1)
            padded_chunk[:, aa:aa + s2 - s1] = X.readChunk(
                i1=0, N1=X.N1(), i2=s1, N2=s2 - s1)  # Read the padded chunk

            for m in range(M):
                f.create_dataset('part-{}-{}'.format(m, j),
                                 data=padded_chunk[m, :].ravel())
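# Companion sketch (not from the original source): read one channel back out of
# the HDF5 file written above. Within each stored 'part-{m}-{j}' dataset the
# unpadded samples of chunk j occupy columns [padding, padding + len_j).
import h5py
import numpy as np

def read_channel_from_hdf5(timeseries_hdf5_fname, m):
    with h5py.File(timeseries_hdf5_fname, 'r') as f:
        chunk_size = int(f['chunk_size'][0])
        num_chunks = int(f['num_chunks'][0])
        padding = int(f['padding'][0])
        N = int(f['num_timepoints'][0])
        parts = []
        for j in range(num_chunks):
            len_j = min(chunk_size, N - j * chunk_size)  # last chunk may be shorter
            part = f['part-{}-{}'.format(m, j)][...]
            parts.append(part[padding:padding + len_j])
        return np.concatenate(parts)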
Example #4
def bandpass_filter(timeseries,
                    timeseries_out,
                    samplerate,
                    freq_min,
                    freq_max,
                    freq_wid=1000,
                    padding=3000,
                    chunk_size=3000 * 10,
                    num_processes=os.cpu_count()):
    """
    Apply a bandpass filter to a multi-channel timeseries

    Parameters
    ----------
    timeseries : INPUT
        MxN raw timeseries array (M = #channels, N = #timepoints)
        
    timeseries_out : OUTPUT
        Filtered output (MxN array)
        
    samplerate : float
        The sampling rate in Hz
    freq_min : float
        The lower endpoint of the frequency band (Hz)
    freq_max : float
        The upper endpoint of the frequency band (Hz)
    freq_wid : float
        The optional width of the roll-off (Hz)
    """
    X = mdaio.DiskReadMda(timeseries)
    M = X.N1()  # Number of channels
    N = X.N2()  # Number of timepoints

    num_chunks = int(np.ceil(N / chunk_size))
    print('Chunk size: {}, Padding: {}, Num chunks: {}, Num processes: {}'.
          format(chunk_size, padding, num_chunks, num_processes))

    opts = {
        "timeseries": timeseries,
        "timeseries_out": timeseries_out,
        "samplerate": samplerate,
        "freq_min": freq_min,
        "freq_max": freq_max,
        "freq_wid": freq_wid,
        "chunk_size": chunk_size,
        "padding": padding,
        "num_processes": num_processes,
        "num_chunks": num_chunks
    }

    global g_shared_data
    g_shared_data = SharedChunkInfo(num_chunks)
    global g_opts
    g_opts = opts
    mdaio.writemda32(np.zeros([M, 0]), timeseries_out)

    pool = multiprocessing.Pool(processes=num_processes)
    pool.map(filter_chunk, range(num_chunks), chunksize=1)
    return True
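# The chunked processors above rely on a SharedChunkInfo object that is not
# shown in these examples. A minimal sketch of what such a class might look
# like is given below; it assumes a fork-based multiprocessing start method so
# the shared values are inherited by the worker processes. The method names
# follow the calls made in filter_chunk / whiten_chunk / mask_chunk, but the
# internals are an assumption, not the original implementation.
import multiprocessing
import time

class SharedChunkInfo():
    def __init__(self, num_chunks):
        self._num_chunks = num_chunks
        self._timer_timestamp = multiprocessing.Value('d', time.time(), lock=False)
        self._last_appended_chunk = multiprocessing.Value('l', -1, lock=False)
        self._num_completed_chunks = multiprocessing.Value('l', 0, lock=False)
        self._lock = multiprocessing.Lock()

    def reportChunkCompleted(self, num):
        with self._lock:
            self._num_completed_chunks.value += 1

    def reportChunkAppended(self, num):
        with self._lock:
            self._last_appended_chunk.value = num

    def lastAppendedChunk(self):
        with self._lock:
            return self._last_appended_chunk.value

    def resetTimer(self):
        with self._lock:
            self._timer_timestamp.value = time.time()

    def elapsedTime(self):
        with self._lock:
            return time.time() - self._timer_timestamp.value

    def printStatus(self):
        with self._lock:
            print('Processed {} of {} chunks...'.format(
                self._num_completed_chunks.value, self._num_chunks))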
Example #5
 def getRawTraces(self, t_start=None, t_end=None, electrode_ids=None):
     X = mdaio.DiskReadMda(self._timeseries_path)
     if t_start is None:
         t_start = 0
     if t_end is None:
         t_end = X.N2()
     recordings = X.readChunk(i1=0,
                              i2=t_start,
                              N1=X.N1(),
                              N2=t_end - t_start)
     times = np.arange(t_start, t_end) / self._samplerate
     return recordings, times
Example #6
def whiten(*,
        timeseries,timeseries_out,
        chunk_size=30000*10,num_processes=os.cpu_count()
        ):
    """
    Whiten a multi-channel timeseries

    Parameters
    ----------
    timeseries : INPUT
        MxN raw timeseries array (M = #channels, N = #timepoints)
        
    timeseries_out : OUTPUT
        Whitened output (MxN array)

    """
    X=mdaio.DiskReadMda(timeseries)
    M=X.N1() # Number of channels
    N=X.N2() # Number of timepoints

    num_chunks_for_computing_cov_matrix=10
    
    num_chunks=int(np.ceil(N/chunk_size))
    print ('Chunk size: {}, Num chunks: {}, Num processes: {}'.format(chunk_size,num_chunks,num_processes))
    
    opts={
        "timeseries":timeseries,
        "timeseries_out":timeseries_out,
        "chunk_size":chunk_size,
        "num_processes":num_processes,
        "num_chunks":num_chunks
    }
    global g_opts
    g_opts=opts
    
    pool = multiprocessing.Pool(processes=num_processes)
    step=int(np.maximum(1,np.floor(num_chunks/num_chunks_for_computing_cov_matrix)))
    AAt_matrices=pool.map(compute_AAt_matrix_for_chunk,range(0,num_chunks,step),chunksize=1)
    
    AAt=np.zeros((M,M),dtype='float64')
    
    for M0 in AAt_matrices:
        AAt+=M0/(len(AAt_matrices)*chunk_size) ##important: need to fix the denominator here to account for possible smaller chunk
    
    U, S, Ut = np.linalg.svd(AAt, full_matrices=True)
    
    W = (U @ np.diag(1/np.sqrt(S))) @ Ut
    #print ('Whitening matrix:')
    #print (W)
    
    global g_shared_data
    g_shared_data=SharedChunkInfo(num_chunks)
    mdaio.writemda32(np.zeros([M,0]),timeseries_out)
    
    pool = multiprocessing.Pool(processes=num_processes)
    pool.starmap(whiten_chunk,[(num,W) for num in range(0,num_chunks)],chunksize=1)
    
    return True
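# Quick numerical check (illustrative only, not part of the original example):
# the whitening matrix constructed above should map the covariance-like matrix
# AAt to (approximately) the identity.
import numpy as np

rng = np.random.default_rng(0)
data = rng.standard_normal((4, 10000))
AAt = (data @ data.T) / data.shape[1]
U, S, Ut = np.linalg.svd(AAt, full_matrices=True)
W = (U @ np.diag(1 / np.sqrt(S))) @ Ut
print(np.allclose(W @ AAt @ W.T, np.eye(4), atol=1e-8))  # expect True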
Example #7
 def getChunk(self, *, trange=None, channels=None):
     if not channels:
         channels = range(1, self._num_channels + 1)
     if not trange:
         trange = [0, self._num_timepoints]
     X = mdaio.DiskReadMda(self._mda_path)
     chunk = X.readChunk(i1=0,
                         i2=trange[0],
                         N1=self._num_channels,
                         N2=trange[1] - trange[0])
     return chunk[np.array(channels) - 1, :]
Example #8
def view_timeseries(timeseries,
                    trange=None,
                    channels=None,
                    samplerate=30000,
                    title='',
                    fig_size=[18, 6]):
    # timeseries=mls.loadMdaFile(timeseries)
    if type(timeseries) == str:
        X = mdaio.DiskReadMda(timeseries)
        M = X.N1()
        N = X.N2()
        if not trange:
            trange = [0, np.minimum(1000, N)]
        X = X.readChunk(i1=0,
                        N1=X.N1(),
                        i2=int(trange[0]),
                        N2=int(trange[1] - trange[0]))
    else:
        M = timeseries.shape[0]
        N = timeseries.shape[1]
        if not channels:
            channels = range(M)
        if not trange:
            trange = [0, N]
        X = timeseries[channels][:, int(trange[0]):int(trange[1])]

    set_fig_size(fig_size[0], fig_size[1])

    channel_colors = _get_channel_colors(M)
    if not channels:
        channels = np.arange(M).tolist()

    spacing_between_channels = np.max(np.abs(X.ravel()))

    y_offset = 0
    for m in range(len(channels)):
        A = X[m, :]
        plt.plot(np.arange(trange[0], trange[1]),
                 A + y_offset,
                 color=channel_colors[channels[m]])
        y_offset -= spacing_between_channels

    ax = plt.gca()
    ax.axes.get_xaxis().set_visible(False)
    ax.axes.get_yaxis().set_visible(False)

    if title:
        plt.title(title, fontsize=title_fontsize)

    plt.show()
    return ax
Example #9
 def getRawTraces(self, start_frame=None, end_frame=None, channel_ids=None):
     if start_frame is None:
         start_frame = 0
     if end_frame is None:
         end_frame = self.getNumFrames()
     if channel_ids is None:
         channel_ids = range(self.getNumChannels())
     X = mdaio.DiskReadMda(self._timeseries_path)
     recordings = X.readChunk(i1=0,
                              i2=start_frame,
                              N1=X.N1(),
                              N2=end_frame - start_frame)
     recordings = recordings[channel_ids, :]
     return recordings
Example #10
def get_epoch_offsets(*, dataset_dir, opts=None):
    if opts is None:
        opts = {}

    if 'mda_list' in opts:
        # initialize with 0 (first start time)
        lengths = [0]

        for idx in range(len(opts['mda_list'])):
            ep_path = opts['mda_list'][idx]
            ep_mda = mdaio.DiskReadMda(ep_path)
            # get length of the mda (N dimension)
            samplength = ep_mda.N2()
            # add to prior sum and append
            lengths.append(samplength + lengths[idx])

    else:

        prv_list = dataset_dir + '/raw.mda.prv'

        with open(prv_list, 'r') as f:
            ep_files = json.load(f)

        # initialize with 0 (first start time)
        lengths = [0]

        for idx in range(len(ep_files['files'])):
            ep_path = ep_files['files'][idx]['prv']['original_path']
            ep_mda = mdaio.DiskReadMda(ep_path)
            # get length of the mda (N dimension)
            samplength = ep_mda.N2()
            # add to prior sum and append
            lengths.append(samplength + lengths[idx])

    # first entries (incl 0) are starttimes; last is total time
    total_samples = lengths[-1]
    sample_offsets = lengths[0:-1]

    return sample_offsets, total_samples
Example #11
 def __init__(self, dataset_directory, download=True):
     InputExtractor.__init__(self)
     self._dataset_directory = dataset_directory
     timeseries0 = dataset_directory + '/raw.mda'
     if download:
         print('Downloading file if needed: ' + timeseries0)
         self._timeseries_path = mlp.realizeFile(timeseries0)
         print('Done.')
     else:
         self._timeseries_path = mlp.locateFile(timeseries0)
     X = mdaio.DiskReadMda(self._timeseries_path)
     self._num_channels = X.N1()
     self._num_timepoints = X.N2()
     self._dataset_params = read_dataset_params(dataset_directory)
     self._samplerate = self._dataset_params['samplerate']
Example #12
def mask_chunk(num, use_it):
    opts = g_opts

    in_fname = opts['timeseries']  # The entire (large) input file
    out_fname = opts['timeseries_out']  # The entire (large) output file
    chunk_size = opts['chunk_size']
    num_write_chunks = opts['num_write_chunks']
    write_chunk_size = opts['write_chunk_size']

    X = mdaio.DiskReadMda(in_fname)

    # t1=int(num*chunk_size) # first timepoint of the chunk
    # t2=int(np.minimum(X.N2(),(t1+chunk_size))) # last timepoint of chunk (+1)

    t1 = int(num * write_chunk_size)  # first timepoint of the chunk
    t2 = int(np.minimum(
        X.N2(), (t1 + write_chunk_size)))  # last timepoint of chunk (+1)

    chunk = X.readChunk(i1=0, N1=X.N1(), i2=t1,
                        N2=t2 - t1).astype(np.float32)  # Read the chunk

    if sum(use_it) != len(use_it):
        chunk[:,
              get_masked_indices(use_it, write_chunk_size, chunk_size,
                                 num_write_chunks)] = 0

    ###########################################################################################
    # Now we wait until we are ready to append to the output file
    # Note that we need to append in order, thus the shared_data object
    ###########################################################################################
    g_shared_data.reportChunkCompleted(
        num)  # Report that we have completed this chunk
    while True:
        if num == g_shared_data.lastAppendedChunk() + 1:
            break
        time.sleep(0.005)  # so we don't saturate the CPU unnecessarily

    # Append the masked chunk to the output file
    mdaio.appendmda(chunk, out_fname)

    # Report that we have appended so the next chunk can proceed
    g_shared_data.reportChunkAppended(num)

    # Print status if it has been long enough
    if g_shared_data.elapsedTime() > 4:
        g_shared_data.printStatus()
        g_shared_data.resetTimer()
Example #13
def compute_AAt_matrix_for_chunk(num):
    opts=g_opts
    in_fname=opts['timeseries'] # The entire (large) input file
    out_fname=opts['timeseries_out'] # The entire (large) output file
    chunk_size=opts['chunk_size']
    
    X=mdaio.DiskReadMda(in_fname)
    
    t1=int(num*opts['chunk_size']) # first timepoint of the chunk
    t2=int(np.minimum(X.N2(),(t1+chunk_size))) # last timepoint of chunk (+1)
    
    # Cast the chunk to floating point before forming the outer product
    # (AAt itself is accumulated in float64 by the caller)
    chunk=X.readChunk(i1=0,N1=X.N1(),i2=t1,N2=t2-t1).astype(np.float32) # Read the chunk
    
    ret=chunk @ np.transpose(chunk)
    
    return ret
Example #14
    def run(self, mdafile_path_or_diskreadmda, func):
        if (type(mdafile_path_or_diskreadmda) == str):
            X = mdaio.DiskReadMda(mdafile_path_or_diskreadmda)
        else:
            X = mdafile_path_or_diskreadmda
        M, N = X.N1(), X.N2()
        cs = max(
            [self._chunk_size,
             int(self._chunk_size_mb * 1e6 / (M * 4)), M])
        if self._t1 < 0:
            self._t1 = 0
        if self._t2 < 0:
            self._t2 = N - 1
        t = self._t1
        while t <= self._t2:
            t1 = t
            t2 = min(self._t2, t + cs - 1)
            s1 = max(0, t1 - self._overlap_size)
            s2 = min(N - 1, t2 + self._overlap_size)

            timer = time.time()
            chunk = X.readChunk(i1=0, N1=M, i2=s1, N2=s2 - s1 + 1)
            self._elapsed_reading += time.time() - timer

            info = TimeseriesChunkInfo()
            info.t1 = t1
            info.t2 = t2
            info.t1a = t1 - s1
            info.t2a = t2 - s1
            info.size = t2 - t1 + 1

            timer = time.time()
            if not func(chunk, info):
                return False
            self._elapsed_running += time.time() - timer

            t = t + cs
        if self._verbose:
            print(
                'Elapsed for TimeseriesChunkReader: %g sec reading, %g sec running'
                % (self._elapsed_reading, self._elapsed_running))
        return True
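# Usage sketch (illustrative; 'raw.mda' is a hypothetical file): a kernel
# receives each chunk plus a TimeseriesChunkInfo, and columns
# [info.t1a, info.t2a] of the chunk are the non-overlapping samples it owns.
# Here we accumulate per-channel sums over the whole timeseries. The
# constructor keywords follow the call made in extract_clips_helper below.
import numpy as np

def compute_channel_sums(timeseries_path):
    X = mdaio.DiskReadMda(timeseries_path)
    sums = np.zeros(X.N1())

    def _kernel(chunk, info):
        sums[:] += np.sum(chunk[:, info.t1a:info.t2a + 1], axis=1)
        return True

    TCR = TimeseriesChunkReader(chunk_size_mb=100, overlap_size=0)
    if not TCR.run(timeseries_path, _kernel):
        return None
    return sums

# channel_sums = compute_channel_sums('raw.mda')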
Example #15
 def __init__(self, *, dataset_directory, download=True):
     InputExtractor.__init__(self)
     self._dataset_directory = dataset_directory
     timeseries0 = dataset_directory + '/raw.mda'
     self._dataset_params = read_dataset_params(dataset_directory)
     self._samplerate = self._dataset_params['samplerate']
     if download:
         print('Downloading file if needed: ' + timeseries0)
         self._timeseries_path = mlp.realizeFile(timeseries0)
         print('Done.')
     else:
         self._timeseries_path = mlp.locateFile(timeseries0)
     geom0 = dataset_directory + '/geom.csv'
     self._geom_fname = mlp.realizeFile(geom0)
     self._geom = np.genfromtxt(self._geom_fname, delimiter=',')
     X = mdaio.DiskReadMda(self._timeseries_path)
     if self._geom.shape[0] != X.N1():
         raise Exception(
             'Incompatible dimensions between geom.csv and timeseries file {} <> {}'
             .format(self._geom.shape[0], X.N1()))
     self._num_channels = X.N1()
     self._num_timepoints = X.N2()
Example #16
def whiten_chunk(num,W):
    #print('Whitening {}'.format(num))
    opts=g_opts
    #print('Whitening chunk {} of {}'.format(num,opts['num_chunks']))
    in_fname=opts['timeseries'] # The entire (large) input file
    out_fname=opts['timeseries_out'] # The entire (large) output file
    chunk_size=opts['chunk_size']
    
    X=mdaio.DiskReadMda(in_fname)
    
    t1=int(num*opts['chunk_size']) # first timepoint of the chunk
    t2=int(np.minimum(X.N2(),(t1+chunk_size))) # last timepoint of chunk (+1)
    
    chunk=X.readChunk(i1=0,N1=X.N1(),i2=t1,N2=t2-t1) # Read the chunk
    
    chunk=W @ chunk
    
    ###########################################################################################
    # Now we wait until we are ready to append to the output file
    # Note that we need to append in order, thus the shared_data object
    ###########################################################################################
    g_shared_data.reportChunkCompleted(num) # Report that we have completed this chunk
    while True:
        if num == g_shared_data.lastAppendedChunk()+1:
            break
        time.sleep(0.005) # so we don't saturate the CPU unnecessarily
    
    # Append the whitened chunk to the output file
    mdaio.appendmda(chunk,out_fname)
    
    # Report that we have appended so the next chunk can proceed
    g_shared_data.reportChunkAppended(num)

    # Print status if it has been long enough
    if g_shared_data.elapsedTime()>4:
        g_shared_data.printStatus()
        g_shared_data.resetTimer()
Example #17
def extract_clips_helper(*, timeseries, times, clip_size=100, verbose=False):
    X = mdaio.DiskReadMda(timeseries)
    M, N = X.N1(), X.N2()
    L = times.size
    T = clip_size
    extract_clips_helper._clips = np.zeros((M, T, L))

    def _kernel(chunk, info):
        inds = np.where((info.t1 <= times) & (times <= info.t2))[0]
        times0 = times[inds] - info.t1 + info.t1a
        clips0 = np.zeros((M, clip_size, len(inds)),
                          dtype=np.float32,
                          order='F')
        cpp.extract_clips(clips0, chunk, times0, clip_size)

        extract_clips_helper._clips[:, :, inds] = clips0
        return True

    TCR = TimeseriesChunkReader(chunk_size_mb=100,
                                overlap_size=clip_size * 2,
                                verbose=verbose)
    if not TCR.run(timeseries, _kernel):
        return None
    return extract_clips_helper._clips
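# Pure-numpy sketch of what the cpp.extract_clips call above is expected to do
# (illustrative only; the real routine is a compiled extension and may differ):
import numpy as np

def extract_clips_numpy(clips, chunk, times0, clip_size):
    # clips: M x clip_size x len(times0), filled in place like cpp.extract_clips
    Tmid = int(np.floor((clip_size + 1) / 2) - 1)
    for i, t0 in enumerate(times0.astype(int)):
        clips[:, :, i] = chunk[:, t0 - Tmid:t0 - Tmid + clip_size]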
Example #18
def filter_chunk(num):
    #print('Filtering {}'.format(num))
    opts = g_opts
    #print('Filtering chunk {} of {}'.format(num,opts['num_chunks']))
    in_fname = opts['timeseries']  # The entire (large) input file
    out_fname = opts['timeseries_out']  # The entire (large) output file
    samplerate = opts['samplerate']
    freq_min = opts['freq_min']
    freq_max = opts['freq_max']
    freq_wid = opts['freq_wid']
    chunk_size = opts['chunk_size']
    padding = opts['padding']

    X = mdaio.DiskReadMda(in_fname)

    chunk_size_with_padding = chunk_size + 2 * padding
    padded_chunk = np.zeros((X.N1(), chunk_size_with_padding), dtype='float32')

    t1 = int(num * opts['chunk_size'])  # first timepoint of the chunk
    t2 = int(np.minimum(X.N2(),
                        (t1 + chunk_size)))  # last timepoint of chunk (+1)
    s1 = int(np.maximum(0,
                        t1 - padding))  # first timepoint including the padding
    s2 = int(np.minimum(X.N2(), t2 +
                        padding))  # last timepoint (+1) including the padding

    # determine aa so that t1-s1+aa = padding
    # so, aa = padding-(t1-s1)
    aa = padding - (t1 - s1)
    padded_chunk[:, aa:aa + s2 - s1] = X.readChunk(i1=0,
                                                   N1=X.N1(),
                                                   i2=s1,
                                                   N2=s2 -
                                                   s1)  # Read the padded chunk

    # Do the actual filtering with a DFT with real input
    padded_chunk = np.fft.rfft(padded_chunk)
    # Subtract off the mean of each channel unless we are doing only a low-pass filter
    if freq_min != 0:
        for m in range(padded_chunk.shape[0]):
            padded_chunk[m, :] = padded_chunk[m, :] - np.mean(
                padded_chunk[m, :])
    kernel = create_filter_kernel(chunk_size_with_padding, samplerate,
                                  freq_min, freq_max, freq_wid)
    kernel = kernel[
        0:padded_chunk.shape[1]]  # because this is the DFT of real data
    padded_chunk = padded_chunk * np.tile(kernel, (padded_chunk.shape[0], 1))
    padded_chunk = np.fft.irfft(padded_chunk)

    ###########################################################################################
    # Now we wait until we are ready to append to the output file
    # Note that we need to append in order, thus the shared_data object
    ###########################################################################################
    g_shared_data.reportChunkCompleted(
        num)  # Report that we have completed this chunk
    while True:  # Alex: maybe there should be a timeout here in case ...
        if num == g_shared_data.lastAppendedChunk() + 1:
            break
        time.sleep(0.005)  # so we don't saturate the CPU unnecessarily

    # Append the filtered chunk (excluding the padding) to the output file
    mdaio.appendmda(padded_chunk[:, padding:padding + (t2 - t1)], out_fname)

    # Report that we have appended so the next chunk can proceed
    g_shared_data.reportChunkAppended(num)

    # Print status if it has been long enough
    if g_shared_data.elapsedTime() > 4:
        g_shared_data.printStatus()
        g_shared_data.resetTimer()
Example #19
 def __init__(self, path, *, samplerate):
     self._samplerate = samplerate
     self._mda_path = path
     X = mdaio.DiskReadMda(path)
     self._num_channels = X.N1()
     self._num_timepoints = X.N2()
Example #20
def sort(*,
         timeseries,
         geom='',
         firings_out,
         adjacency_radius,
         detect_sign,
         detect_interval=10,
         detect_threshold=3,
         clip_size=50,
         num_workers=multiprocessing.cpu_count()):
    """
    MountainSort spike sorting (version 4)

    Parameters
    ----------
    timeseries : INPUT
        MxN raw timeseries array (M = #channels, N = #timepoints)
    geom : INPUT
        Optional geometry file (.csv format)
        
    firings_out : OUTPUT
        Firings array channels/times/labels (3xL, L = num. events)
        
    adjacency_radius : float
        Radius of local sorting neighborhood, corresponding to the geometry file (same units). 0 means each channel is sorted independently. -1 means all channels are included in every neighborhood.
    detect_sign : int
        Use 1, -1, or 0 to detect positive peaks, negative peaks, or both, respectively
    detect_threshold : float
        Threshold for event detection, corresponding to the input file. So if the input file is normalized to have noise standard deviation 1 (e.g., whitened), then this is in units of std. deviations away from the mean.
    detect_interval : int
        The minimum number of timepoints between adjacent spikes detected in the same channel neighborhood.
    clip_size : int
        Size of extracted clips or snippets, used throughout
    num_workers : int
        Number of simultaneous workers (or processes). The default is multiprocessing.cpu_count().
    """

    tempdir = os.environ.get('ML_PROCESSOR_TEMPDIR')
    if not tempdir:
        print(
            'Warning: environment variable ML_PROCESSOR_TEMPDIR not set. Using current directory.'
        )
        tempdir = '.'
    print('Using tempdir={}'.format(tempdir))

    os.environ['OMP_NUM_THREADS'] = '1'

    # Read the header of the timeseries input to get the num. channels and num. timepoints
    X = mdaio.DiskReadMda(timeseries)
    M = X.N1()  # Number of channels
    N = X.N2()  # Number of timepoints

    # Read the geometry file
    if geom:
        Geom = np.genfromtxt(geom, delimiter=',')
    else:
        Geom = np.zeros((M, 2))

    if Geom.shape[0] != M:
        raise Exception(
            'Incompatible dimensions between geom and timeseries: {} != {}'.
            format(Geom.shape[0], M))

    MS4 = ms4alg.MountainSort4()
    MS4.setGeom(Geom)
    MS4.setSortingOpts(clip_size=clip_size,
                       adjacency_radius=adjacency_radius,
                       detect_sign=detect_sign,
                       detect_interval=detect_interval,
                       detect_threshold=detect_threshold)
    MS4.setNumWorkers(num_workers)
    MS4.setTimeseriesPath(timeseries)
    MS4.setFiringsOutPath(firings_out)
    MS4.setTemporaryDirectory(tempdir)
    MS4.sort()
    return True
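# Example invocation (illustrative; file paths are hypothetical and assume the
# input has already been bandpass filtered and whitened, so detect_threshold is
# in units of noise standard deviations). ML_PROCESSOR_TEMPDIR should point at
# a writable scratch directory, otherwise the current directory is used.
#
# sort(timeseries='pre.mda', geom='geom.csv', firings_out='firings.mda',
#      adjacency_radius=100, detect_sign=-1, detect_threshold=3,
#      detect_interval=10, clip_size=50)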
Example #21
def mask_out_artifacts(*,
                       timeseries,
                       timeseries_out,
                       threshold=6,
                       chunk_size=2000,
                       num_write_chunks=150,
                       num_processes=os.cpu_count()):
    """
    Masks out artifacts. Each chunk is analyzed per channel; if the root-sum-square (RSS)
    amplitude of a chunk is more than `threshold` standard deviations above the mean RSS
    across chunks, all samples in that chunk (and its neighboring chunks) are set to zero.

    Parameters
    ----------
    timeseries : INPUT
        MxN raw timeseries array (M = #channels, N = #timepoints)

    timeseries_out : OUTPUT
        masked output (MxN array)

    threshold : int
        Number of standard deviations away from the mean to consider as artifacts (default of 6).
    chunk_size : int
        This chunk size will be the number of samples that will be set to zero if the square root RSS of this chunk is above threshold.
    num_write_chunks : int
        How many chunks will be simultaneously written to the timeseries_out path (default of 150).
    """

    if threshold == 0 or chunk_size == 0 or num_write_chunks == 0:
        print(
            "Problem with input parameters. Either threshold, num_write_chunks, or chunk_size is zero.\n"
        )
        return False

    write_chunk_size = chunk_size * num_write_chunks

    opts = {
        "timeseries": timeseries,
        "timeseries_out": timeseries_out,
        "chunk_size": chunk_size,
        "num_processes": num_processes,
        "num_write_chunks": num_write_chunks,
        "write_chunk_size": write_chunk_size,
    }

    global g_opts
    g_opts = opts

    X = mdaio.DiskReadMda(timeseries)

    M = X.N1()  # Number of channels
    N = X.N2()  # Number of timepoints

    # compute norms of chunks
    num_chunks = int(np.ceil(N / chunk_size))
    num_write = int(np.ceil(N / write_chunk_size))

    norms = np.zeros((M, num_chunks))  # num channels x num_chunks

    for i in np.arange(num_chunks):
        t1 = int(i * chunk_size)  # first timepoint of the chunk
        t2 = int(np.minimum(N,
                            (t1 + chunk_size)))  # last timepoint of chunk (+1)

        chunk = X.readChunk(i1=0, N1=X.N1(), i2=t1,
                            N2=t2 - t1).astype(np.float32)  # Read the chunk

        norms[:, i] = np.sqrt(np.sum(chunk**2,
                                     axis=1))  # num_channels x num_chunks

    # determine which chunks to use
    use_it = np.ones(num_chunks)  # initialize use_it array

    for m in np.arange(M):
        vals = norms[m, :]

        sigma0 = np.std(vals)
        mean0 = np.mean(vals)

        artifact_indices = np.where(vals > mean0 + sigma0 * threshold)[0]

        # check if the first chunk is above threshold, ensure that we don't use negative indices later
        negIndBool = np.where(artifact_indices > 0)[0]

        # check if the last chunk is above threshold to avoid a IndexError
        maxIndBool = np.where(artifact_indices < num_chunks - 1)[0]

        use_it[artifact_indices] = 0
        use_it[artifact_indices[negIndBool] -
               1] = 0  # don't use the neighbor chunks either
        use_it[artifact_indices[maxIndBool] +
               1] = 0  # don't use the neighbor chunks either

        print("For channel %d: mean=%.2f, stdev=%.2f, chunk size = %d\n" %
              (m, mean0, sigma0, chunk_size))

    global g_shared_data
    g_shared_data = SharedChunkInfo(num_write)

    mdaio.writemda32(
        np.zeros([M, 0]), timeseries_out
    )  # create initial file w/ empty array so we can append to it

    pool = multiprocessing.Pool(processes=num_processes)
    # pool.starmap(mask_chunk,[(num,use_it[num]) for num in range(0,num_chunks)],chunksize=1)
    pool.starmap(
        mask_chunk,
        [(num, use_it[num * num_write_chunks:(num + 1) * num_write_chunks])
         for num in range(0, num_write)],
        chunksize=1)

    num_timepoints_used = sum(use_it)
    num_timepoints_not_used = sum(use_it == 0)
    print("Using %.2f%% of all timepoints.\n" %
          (num_timepoints_used * 100.0 /
           (num_timepoints_used + num_timepoints_not_used)))
    return True
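# Sketch of the get_masked_indices helper referenced by mask_chunk above (it is
# not shown in these examples, so this is an assumption about its behavior):
# expand the per-chunk use_it flags of one write-chunk into the sample indices
# that should be zeroed. It assumes write_chunk_size == chunk_size * num_write_chunks
# and that only full chunks are passed in.
import numpy as np

def get_masked_indices(use_it, write_chunk_size, chunk_size, num_write_chunks):
    masked = np.zeros(write_chunk_size, dtype=bool)
    for i, flag in enumerate(np.asarray(use_it)):
        if flag == 0:
            masked[i * chunk_size:(i + 1) * chunk_size] = True
    return np.where(masked)[0]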
Example #22
def convert_clips_fet_spk(*,
                          timeseries,
                          firings,
                          waveforms_out,
                          ntro_nchannels,
                          clip_size=32,
                          nPC=4,
                          DEBUG=True):
    """
    Compute templates (average waveforms) for clusters defined by the labeled events in firings. One .spk.n file per n-trode.

    Parameters
    ----------
    timeseries : INPUT
        Path of timeseries mda file (MxN) from which to draw the event clips (snippets) for computing the templates. M is number of channels, N is number of timepoints.
    firings : INPUT
        Path of firings mda file (RxL) where R>=3 and L is the number of events. Second row are timestamps, third row are integer labels.
    ntro_nchannels : INPUT
        Comma-separated list giving the number of channels to take for each n-trode.

    waveforms_out : OUTPUT
        Base path for the per-n-trode output files (.spk.n, .fet.n, .res.n, .clu.n). Waveform arrays are MxTxK (T = clip_size, K = maximum cluster label); empty clusters correspond to templates of all zeros.
        
    clip_size : int
        (Optional) clip size, aka snippet size, number of timepoints in a single template
    
    nPC : int
        (Optional) Number of principal components *per channel* for .fet files.
        
    DEBUG : bool
        (Optional) Verbose output for debugging purposes.
    """
    X = mdaio.DiskReadMda(timeseries)
    M, N = X.N1(), X.N2()
    F = mdaio.readmda(firings)
    L = F.shape[1]
    T = clip_size
    Tmid = int(np.floor((T + 1) / 2) - 1)
    whch = F[0, :].ravel()[:]
    times = F[1, :].ravel()[:]
    labels = F[2, :].ravel().astype(int)[:]
    K = np.max(labels)

    tetmap = list()
    i = 0
    for nch in ntro_nchannels.split(','):
        tp = i + 1 + np.arange(0, int(nch))
        tetmap.append(tp)
        i = tp[-1]
    # array (not list) so that (which_tet == tro) below is an elementwise comparison
    which_tet = np.array([np.where(w == tetmap)[0][0] + 1 for w in whch])

    print("Starting:")
    for tro in np.arange(1, len(tetmap) + 1):
        chans = tetmap[tro - 1]
        inds_k = np.where(which_tet == tro)[0]

        ###### Write .res file ######
        res = times[inds_k]
        res_fname = waveforms_out + '.res.' + str(tro)
        np.savetxt(res_fname, res, fmt='%d')

        labels_tet = labels[inds_k]
        u_labels_tet = np.unique(labels_tet)
        K = len(u_labels_tet)
        u_new_labels = np.argsort(u_labels_tet).argsort() + 2
        new_labels = [
            u_new_labels[(l == np.array(u_labels_tet)).argmax()]
            for l in labels_tet
        ]
        clu_fname = waveforms_out + '.clu.' + str(tro)
        np.savetxt(clu_fname,
                   np.concatenate((np.array([K]), new_labels)),
                   fmt='%d')

        if DEBUG:
            print("Tetrode: " + str(tro))
            print("Chans: " + str(chans))
            print("Create Waveforms Array: " + str(len(chans)) + "," + str(T) +
                  "," + str(len(inds_k)))

        waveforms = np.zeros((len(chans), T, len(inds_k)), dtype='int16')
        for i, ind_k in enumerate(inds_k):  # for each spike
            t0 = int(times[ind_k])
            if (clip_size <= t0) and (t0 < N - clip_size):
                clip0 = X.readChunk(i1=0, N1=M, i2=t0 - Tmid, N2=T)
                clip0 = clip0[chans - 1, :] * 100  # chans holds 1-based channel numbers
                clip_int = clip0.astype(dtype='int16')
                waveforms[:, :, i] = clip_int

        fname = waveforms_out + '.spk.' + str(tro)
        if DEBUG:
            print("Writing Waveforms to File: " + fname)
        waveforms.tofile(fname, format='')

        if DEBUG:
            print("Calculating Feature Array")
        fet = np.zeros((np.shape(waveforms)[2], (len(chans) * nPC) + 1))
        for c in np.arange(len(chans)):
            pca = decomposition.PCA(n_components=nPC)
            x_std = StandardScaler().fit_transform(
                np.transpose(waveforms[c, :, :]).astype(dtype='float64'))
            fpos = (c * nPC)
            fet[:, fpos:fpos + nPC] = pca.fit_transform(x_std)

        fet[:, (len(chans) * nPC)] = times[inds_k]
        fet *= 1000
        fet = fet.astype(dtype='int64')
        fname = waveforms_out + '.fet.' + str(tro)
        if DEBUG:
            print("Writing Features to File: " + fname)
        np.savetxt(fname, fet, fmt='%d')

    return True
Example #23
def compute_cluster_metrics(*,timeseries='',firings,metrics_out,clip_size=100,samplerate=0,
        refrac_msec=1):
    """
    Compute cluster metrics for a spike sorting output

    Parameters
    ----------
    firings : INPUT
        Path of firings mda file (RxL) where R>=3 and L is the number of events. Second row are timestamps, third row are integer cluster labels.
    timeseries : INPUT
        Optional path of timeseries mda file (MxN) which could be raw or preprocessed   
        
    metrics_out : OUTPUT
        Path of output json file containing the metrics.
        
    clip_size : int
        (Optional) clip size, aka snippet size (used when computing the templates, or average waveforms)
    samplerate : float
        Optional sample rate in Hz

    refrac_msec : float
        (Optional) Define interval (in ms) as refractory period. If 0 then don't compute the refractory period metric.
    """    
    print('Reading firings...')
    F=mdaio.readmda(firings)

    print('Initializing...')
    R=F.shape[0]
    L=F.shape[1]
    assert(R>=3)
    times=F[1,:]
    labels=F[2,:].astype(int)
    K=np.max(labels)
    N=0
    if timeseries:
        X=mdaio.DiskReadMda(timeseries)
        N=X.N2()

    if (samplerate>0) and (N>0):
        duration_sec=N/samplerate
    else:
        duration_sec=0

    clusters=[]
    for k in range(1,K+1):
        inds_k=np.where(labels==k)[0]
        metrics_k={
            "num_events":len(inds_k)
        }
        if duration_sec:
            metrics_k['firing_rate']=len(inds_k)/duration_sec
        cluster={
            "label":k,
            "metrics":metrics_k
        }
        clusters.append(cluster)

    if timeseries:
        print('Computing templates...')
        templates=compute_templates_helper(timeseries=timeseries,firings=firings,clip_size=clip_size)
        for k in range(1,K+1):
            template_k=templates[:,:,k-1]
            # subtract mean on each channel (todo: vectorize this operation)
            # Alex: like this?
            # template_k[m,:] -= np.mean(template_k[m,:])
            for m in range(templates.shape[0]):
                template_k[m,:]=template_k[m,:]-np.mean(template_k[m,:])
            peak_amplitude=np.max(np.abs(template_k))
            clusters[k-1]['metrics']['peak_amplitude']=peak_amplitude
        ## todo: subtract template means, compute peak amplitudes

    if refrac_msec > 0:
        print('Computing Refractory Period Violations')
        msec_string = '{0:.5f}'.format(refrac_msec).rstrip('0').rstrip('.')
        rr_string = 'refractory_violation_{}msec'.format(msec_string)
        max_dt_msec=50
        bin_size_msec=refrac_msec
        auto_cors = compute_cross_correlograms_helper(firings=firings,mode='autocorrelograms',samplerate=samplerate,max_dt_msec=max_dt_msec,bin_size_msec=bin_size_msec)
        mid_ind  = np.floor(max_dt_msec/bin_size_msec).astype(int)
        mid_inds = np.arange(mid_ind-2,mid_ind+2)
        for k0,auto_cor_obj in enumerate(auto_cors['correlograms']):
            auto = auto_cor_obj['bin_counts']
            k    = auto_cor_obj['k']
            peak = np.percentile(auto, 75)
            mid  = np.mean(auto[mid_inds])
            rr   = safe_div(mid,peak)
            clusters[k-1]['metrics'][rr_string] = rr

    ret={
        "clusters":clusters
    }

    print('Writing output...')
    json_str=json.dumps(ret,indent=4)
    with open(metrics_out, 'w') as out:
        out.write(json_str)
    print('Done.')
    return True
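# Example invocation (illustrative; the file names are hypothetical). Passing a
# timeseries and samplerate enables the firing-rate and peak-amplitude metrics;
# refrac_msec controls the refractory-violation metric.
#
# compute_cluster_metrics(timeseries='pre.mda', firings='firings.mda',
#                         metrics_out='metrics.json', clip_size=100,
#                         samplerate=30000, refrac_msec=1)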