def compute_templates_helper(*, timeseries, firings, clip_size=100):
    X = mdaio.DiskReadMda(timeseries)
    M, N = X.N1(), X.N2()
    F = mdaio.readmda(firings)
    L = F.shape[1]  # number of events
    T = clip_size
    Tmid = int(np.floor((T + 1) / 2) - 1)
    times = F[1, :].ravel()
    labels = F[2, :].ravel().astype(int)
    K = np.max(labels)
    sums = np.zeros((M, T, K), dtype='float64')
    counts = np.zeros(K)
    for k in range(1, K + 1):
        inds_k = np.where(labels == k)[0]
        # TODO: subsample (see the sketch below)
        for ind_k in inds_k:
            t0 = int(times[ind_k])
            if (clip_size <= t0) and (t0 < N - clip_size):
                clip0 = X.readChunk(i1=0, N1=M, i2=t0 - Tmid, N2=T)
                sums[:, :, k - 1] += clip0
                counts[k - 1] += 1
    templates = np.zeros((M, T, K))
    for k in range(K):
        if counts[k]:  # leave empty clusters as all-zero templates
            templates[:, :, k] = sums[:, :, k] / counts[k]
    return templates

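# A minimal sketch for the "subsample" TODO in compute_templates_helper: cap the
# number of events averaged per cluster. max_events and the fixed seed are
# hypothetical choices, not part of the original pipeline.
def subsample_event_indices(inds_k, max_events=500, seed=0):
    # Return at most max_events event indices, drawn without replacement.
    if len(inds_k) <= max_events:
        return inds_k
    rng = np.random.RandomState(seed)
    return np.sort(rng.choice(inds_k, size=max_events, replace=False))
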
def prepare_timeseries_hdf5(timeseries_fname, timeseries_hdf5_fname, *, chunk_size, padding):
    chunk_size_with_padding = chunk_size + 2 * padding
    with h5py.File(timeseries_hdf5_fname, "w") as f:
        X = mdaio.DiskReadMda(timeseries_fname)
        M = X.N1()  # Number of channels
        N = X.N2()  # Number of timepoints
        num_chunks = int(np.ceil(N / chunk_size))
        f.create_dataset('chunk_size', data=[chunk_size])
        f.create_dataset('num_chunks', data=[num_chunks])
        f.create_dataset('padding', data=[padding])
        f.create_dataset('num_channels', data=[M])
        f.create_dataset('num_timepoints', data=[N])
        for j in range(num_chunks):
            padded_chunk = np.zeros((M, chunk_size_with_padding), dtype=X.dt())
            t1 = int(j * chunk_size)  # first timepoint of the chunk
            t2 = int(np.minimum(N, t1 + chunk_size))  # last timepoint of chunk (+1)
            s1 = int(np.maximum(0, t1 - padding))  # first timepoint including the padding
            s2 = int(np.minimum(N, t2 + padding))  # last timepoint (+1) including the padding
            # determine aa so that t1-s1+aa = padding
            # so, aa = padding-(t1-s1)
            aa = padding - (t1 - s1)
            padded_chunk[:, aa:aa + s2 - s1] = X.readChunk(i1=0, N1=M, i2=s1, N2=s2 - s1)  # Read the padded chunk
            for m in range(M):
                f.create_dataset('part-{}-{}'.format(m, j), data=padded_chunk[m, :].ravel())

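# A minimal read-back sketch, assuming only the dataset names written above:
# recover one channel's chunk from the HDF5 file and strip the padding.
# 'ts.h5' is a hypothetical path.
def example_read_hdf5_chunk(hdf5_fname='ts.h5', channel=0, chunk=0):
    with h5py.File(hdf5_fname, 'r') as f:
        padding = int(f['padding'][0])
        chunk_size = int(f['chunk_size'][0])
        part = f['part-{}-{}'.format(channel, chunk)][...]  # padded on both ends
        return part[padding:padding + chunk_size]  # the chunk itself
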
def bandpass_filter(*, timeseries, timeseries_out, samplerate, freq_min, freq_max, freq_wid=1000, padding=3000, chunk_size=3000 * 10, num_processes=os.cpu_count()):
    """
    Apply a bandpass filter to a multi-channel timeseries

    Parameters
    ----------
    timeseries : INPUT
        MxN raw timeseries array (M = #channels, N = #timepoints)
    timeseries_out : OUTPUT
        Filtered output (MxN array)

    samplerate : float
        The sampling rate in Hz
    freq_min : float
        The lower endpoint of the frequency band (Hz)
    freq_max : float
        The upper endpoint of the frequency band (Hz)
    freq_wid : float
        (Optional) The width of the roll-off (Hz)
    padding : int
        (Optional) Number of padding timepoints on each side of each chunk
    chunk_size : int
        (Optional) Number of timepoints per chunk
    num_processes : int
        (Optional) Number of worker processes
    """
    X = mdaio.DiskReadMda(timeseries)
    M = X.N1()  # Number of channels
    N = X.N2()  # Number of timepoints
    num_chunks = int(np.ceil(N / chunk_size))
    print('Chunk size: {}, Padding: {}, Num chunks: {}, Num processes: {}'.format(
        chunk_size, padding, num_chunks, num_processes))
    opts = {
        "timeseries": timeseries,
        "timeseries_out": timeseries_out,
        "samplerate": samplerate,
        "freq_min": freq_min,
        "freq_max": freq_max,
        "freq_wid": freq_wid,
        "chunk_size": chunk_size,
        "padding": padding,
        "num_processes": num_processes,
        "num_chunks": num_chunks
    }
    global g_shared_data
    g_shared_data = SharedChunkInfo(num_chunks)
    global g_opts
    g_opts = opts
    mdaio.writemda32(np.zeros([M, 0]), timeseries_out)  # create/clear the output file
    pool = multiprocessing.Pool(processes=num_processes)
    pool.map(filter_chunk, range(num_chunks), chunksize=1)
    return True

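# For intuition only: a smooth band-pass frequency response with an erf-style
# roll-off of width freq_wid, similar in spirit to what create_filter_kernel
# (defined elsewhere in this module) supplies to filter_chunk. This is an
# illustrative sketch under that assumption, not the actual kernel used above.
from scipy.special import erf  # assumption: scipy is available

def example_bandpass_response(n, samplerate, freq_min, freq_max, freq_wid):
    freqs = np.fft.rfftfreq(n, d=1.0 / samplerate)
    rise = 0.5 * (1 + erf((freqs - freq_min) / freq_wid))  # high-pass edge
    fall = 0.5 * (1 - erf((freqs - freq_max) / freq_wid))  # low-pass edge
    return rise * fall
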
def view_timeseries(timeseries, trange=None, channels=None, samplerate=30000, title='', fig_size=[18, 6]):
    # timeseries may be a path to an .mda file or an in-memory MxN array
    if type(timeseries) == str:
        Y = mdaio.DiskReadMda(timeseries)
        M, N = Y.N1(), Y.N2()
    else:
        M, N = timeseries.shape[0], timeseries.shape[1]
    if not trange:
        trange = [0, np.minimum(1000, N)]
    if channels is None:
        channels = np.arange(M).tolist()  # default: all channels (set before indexing below)
    if type(timeseries) == str:
        X = Y.readChunk(i1=0, N1=M, i2=int(trange[0]), N2=int(trange[1] - trange[0]))[channels, :]
    else:
        X = timeseries[channels][:, int(trange[0]):int(trange[1])]
    set_fig_size(fig_size[0], fig_size[1])
    channel_colors = _get_channel_colors(M)
    spacing_between_channels = np.max(np.abs(X.ravel()))
    y_offset = 0
    for m in range(len(channels)):
        A = X[m, :]
        plt.plot(np.arange(trange[0], trange[1]), A + y_offset, color=channel_colors[channels[m]])
        y_offset -= spacing_between_channels
    ax = plt.gca()
    ax.axes.get_xaxis().set_visible(False)
    ax.axes.get_yaxis().set_visible(False)
    if title:
        plt.title(title, fontsize=title_fontsize)
    plt.show()
    return ax

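# Hypothetical usage (paths and values are placeholders): plot the first second
# of a 30 kHz recording stored in 'pre.mda'.
#   view_timeseries('pre.mda', trange=[0, 30000], samplerate=30000, title='Preprocessed')
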
def compute_AAt_matrix_for_chunk(num):
    opts = g_opts
    in_fname = opts['timeseries']  # The entire (large) input file
    chunk_size = opts['chunk_size']
    X = mdaio.DiskReadMda(in_fname)
    t1 = int(num * chunk_size)  # first timepoint of the chunk
    t2 = int(np.minimum(X.N2(), t1 + chunk_size))  # last timepoint of chunk (+1)
    chunk = X.readChunk(i1=0, N1=X.N1(), i2=t1, N2=t2 - t1)  # Read the chunk
    ret = chunk @ np.transpose(chunk)  # MxM matrix of channel inner products
    return ret

def whiten_chunk(num, W):
    opts = g_opts
    in_fname = opts['timeseries']  # The entire (large) input file
    out_fname = opts['timeseries_out']  # The entire (large) output file
    chunk_size = opts['chunk_size']
    X = mdaio.DiskReadMda(in_fname)
    t1 = int(num * chunk_size)  # first timepoint of the chunk
    t2 = int(np.minimum(X.N2(), t1 + chunk_size))  # last timepoint of chunk (+1)
    chunk = X.readChunk(i1=0, N1=X.N1(), i2=t1, N2=t2 - t1)  # Read the chunk
    chunk = W @ chunk  # Apply the whitening matrix
    ###########################################################################
    # Now we wait until we are ready to append to the output file
    # Note that we need to append in order, thus the shared_data object
    ###########################################################################
    g_shared_data.reportChunkCompleted(num)  # Report that we have completed this chunk
    while True:
        if num == g_shared_data.lastAppendedChunk() + 1:
            break
        time.sleep(0.005)  # so we don't saturate the CPU unnecessarily
    # Append the whitened chunk to the output file
    mdaio.appendmda(chunk, out_fname)
    # Report that we have appended so the next chunk can proceed
    g_shared_data.reportChunkAppended(num)
    # Print status if it has been long enough
    if g_shared_data.elapsedTime() > 4:
        g_shared_data.printStatus()
        g_shared_data.resetTimer()

def run(self, mdafile_path_or_diskreadmda, func):
    if type(mdafile_path_or_diskreadmda) == str:
        X = mdaio.DiskReadMda(mdafile_path_or_diskreadmda)
    else:
        X = mdafile_path_or_diskreadmda
    M, N = X.N1(), X.N2()
    # chunk size: at least self._chunk_size timepoints, at least
    # self._chunk_size_mb megabytes (assuming 4-byte samples), and at least M
    cs = max([self._chunk_size, int(self._chunk_size_mb * 1e6 / (M * 4)), M])
    if self._t1 < 0:
        self._t1 = 0
    if self._t2 < 0:
        self._t2 = N - 1
    t = self._t1
    while t <= self._t2:
        t1 = t
        t2 = min(self._t2, t + cs - 1)
        s1 = max(0, t1 - self._overlap_size)
        s2 = min(N - 1, t2 + self._overlap_size)
        timer = time.time()
        chunk = X.readChunk(i1=0, N1=M, i2=s1, N2=s2 - s1 + 1)
        self._elapsed_reading += time.time() - timer
        info = TimeseriesChunkInfo()
        info.t1 = t1        # first timepoint of the chunk (excluding overlap)
        info.t2 = t2        # last timepoint of the chunk (excluding overlap)
        info.t1a = t1 - s1  # index of t1 within the returned chunk
        info.t2a = t2 - s1  # index of t2 within the returned chunk
        info.size = t2 - t1 + 1
        timer = time.time()
        if not func(chunk, info):
            return False
        self._elapsed_running += time.time() - timer
        t = t + cs
    if self._verbose:
        print('Elapsed for TimeseriesChunkReader: %g sec reading, %g sec running'
              % (self._elapsed_reading, self._elapsed_running))
    return True

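# A minimal sketch of the callback contract for run() above: func receives
# (chunk, info) and must return True to continue. Here we accumulate per-channel
# sums over the non-overlap core of each chunk (info.t1a through info.t2a). The
# reader's construction is assumed from the attribute names used above; the
# function name is illustrative.
def example_channel_sums(reader, mdafile_path, num_channels):
    totals = np.zeros(num_channels)
    def collect(chunk, info):
        core = chunk[:, info.t1a:info.t2a + 1]  # exclude the overlap regions
        totals[:] += core.sum(axis=1)
        return True  # keep iterating
    if not reader.run(mdafile_path, collect):
        return None
    return totals
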
def filter_chunk(num):
    opts = g_opts
    in_fname = opts['timeseries']  # The entire (large) input file
    out_fname = opts['timeseries_out']  # The entire (large) output file
    samplerate = opts['samplerate']
    freq_min = opts['freq_min']
    freq_max = opts['freq_max']
    freq_wid = opts['freq_wid']
    chunk_size = opts['chunk_size']
    padding = opts['padding']
    X = mdaio.DiskReadMda(in_fname)
    chunk_size_with_padding = chunk_size + 2 * padding
    padded_chunk = np.zeros((X.N1(), chunk_size_with_padding), dtype='float32')
    t1 = int(num * chunk_size)  # first timepoint of the chunk
    t2 = int(np.minimum(X.N2(), t1 + chunk_size))  # last timepoint of chunk (+1)
    s1 = int(np.maximum(0, t1 - padding))  # first timepoint including the padding
    s2 = int(np.minimum(X.N2(), t2 + padding))  # last timepoint (+1) including the padding
    # determine aa so that t1-s1+aa = padding
    # so, aa = padding-(t1-s1)
    # (e.g., for chunk 0: t1=0 and s1=0, so aa=padding and the left pad stays zero)
    aa = padding - (t1 - s1)
    padded_chunk[:, aa:aa + s2 - s1] = X.readChunk(i1=0, N1=X.N1(), i2=s1, N2=s2 - s1)  # Read the padded chunk

    # Do the actual filtering with a DFT with real input
    padded_chunk = np.fft.rfft(padded_chunk)
    # Subtract off the mean of each channel unless we are doing only a low-pass filter
    if freq_min != 0:
        for m in range(padded_chunk.shape[0]):
            padded_chunk[m, :] = padded_chunk[m, :] - np.mean(padded_chunk[m, :])
    kernel = create_filter_kernel(chunk_size_with_padding, samplerate, freq_min, freq_max, freq_wid)
    kernel = kernel[0:padded_chunk.shape[1]]  # because this is the DFT of real data
    padded_chunk = padded_chunk * np.tile(kernel, (padded_chunk.shape[0], 1))
    padded_chunk = np.fft.irfft(padded_chunk)

    ###########################################################################
    # Now we wait until we are ready to append to the output file
    # Note that we need to append in order, thus the shared_data object
    ###########################################################################
    g_shared_data.reportChunkCompleted(num)  # Report that we have completed this chunk
    while True:  # Alex: maybe there should be a timeout here in case ...
        if num == g_shared_data.lastAppendedChunk() + 1:
            break
        time.sleep(0.005)  # so we don't saturate the CPU unnecessarily
    # Append the filtered chunk (excluding the padding) to the output file
    mdaio.appendmda(padded_chunk[:, padding:padding + (t2 - t1)], out_fname)
    # Report that we have appended so the next chunk can proceed
    g_shared_data.reportChunkAppended(num)
    # Print status if it has been long enough
    if g_shared_data.elapsedTime() > 4:
        g_shared_data.printStatus()
        g_shared_data.resetTimer()

def convert_clips_fet_spk(*, timeseries, firings, waveforms_out, ntro_nchannels, clip_size=32, nPC=4, DEBUG=True):
    """
    Compute templates (average waveforms) for clusters defined by the labeled events in firings. One .spk.n file per n-trode.

    Parameters
    ----------
    timeseries : INPUT
        Path of timeseries mda file (MxN) from which to draw the event clips (snippets) for computing the templates. M is number of channels, N is number of timepoints.
    firings : INPUT
        Path of firings mda file (RxL) where R>=3 and L is the number of events. First row are channels, second row are timestamps, third row are integer labels.
    ntro_nchannels : INPUT
        Comma-separated list giving the number of channels to take for each n-trode.

    waveforms_out : OUTPUT
        Base path for the output files; one .spk.n (waveforms) and one .fet.n (features) file is written per n-trode.

    clip_size : int
        (Optional) clip size, aka snippet size, number of timepoints in a single template
    nPC : int
        (Optional) Number of principal components *per channel* for .fet files.
    DEBUG : bool
        (Optional) Verbose output for debugging purposes.
    """
    X = mdaio.DiskReadMda(timeseries)
    M, N = X.N1(), X.N2()
    F = mdaio.readmda(firings)
    L = F.shape[1]  # number of events
    T = clip_size
    Tmid = int(np.floor((T + 1) / 2) - 1)
    whch = F[0, :].ravel()
    times = F[1, :].ravel()
    labels = F[2, :].ravel().astype(int)
    K = np.max(labels)

    # Map each recording channel to its n-trode (both 1-based)
    tetmap = list()
    i = 0
    for nch in ntro_nchannels.split(','):
        tp = i + 1 + np.arange(0, int(nch))
        tetmap.append(tp)
        i = tp[-1]
    which_tet = np.array([np.where([w in t for t in tetmap])[0][0] + 1 for w in whch])

    print("Starting:")
    for tro in np.arange(1, len(tetmap) + 1):
        chans = tetmap[tro - 1]
        inds_k = np.where(which_tet == tro)[0]
        if DEBUG:
            print("Tetrode: " + str(tro))
            print("Chans: " + str(chans))
            print("Create Waveforms Array: " + str(len(chans)) + "," + str(T) + "," + str(len(inds_k)))
        waveforms = np.zeros((len(chans), T, len(inds_k)), dtype='int16')
        for i, ind_k in enumerate(inds_k):  # for each spike
            t0 = int(times[ind_k])
            if (clip_size <= t0) and (t0 < N - clip_size):
                clip0 = X.readChunk(i1=0, N1=M, i2=t0 - Tmid, N2=T)
                clip0 = clip0[chans - 1, :] * 100  # chans is 1-based; rows are 0-based
                clip_int = clip0.astype(dtype='int16')
                waveforms[:, :, i] = clip_int
        fname = waveforms_out + '.spk.' + str(tro)
        if DEBUG:
            print("Writing Waveforms to File: " + fname)
        waveforms.tofile(fname)

        if DEBUG:
            print("Calculating Feature Array")
        fet = np.zeros((np.shape(waveforms)[2], (len(chans) * nPC) + 1))
        for c in np.arange(len(chans)):
            pca = decomposition.PCA(n_components=nPC)
            x_std = StandardScaler().fit_transform(np.transpose(waveforms[c, :, :]).astype(dtype='float64'))
            fpos = c * nPC
            fet[:, fpos:fpos + nPC] = pca.fit_transform(x_std)
        fet[:, len(chans) * nPC] = times[inds_k]  # last column: spike times
        fet *= 1000
        fet = fet.astype(dtype='int64')
        fname = waveforms_out + '.fet.' + str(tro)
        if DEBUG:
            print("Writing Features to File: " + fname)
        np.savetxt(fname, fet, fmt='%d')
    return True

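# Reading back a .spk file written above (a sketch): waveforms.tofile writes the
# (nchannels, clip_size, nspikes) int16 array in C order, so the same reshape
# recovers it. The function name and arguments are illustrative.
def load_spk(fname, nchannels, clip_size):
    data = np.fromfile(fname, dtype='int16')
    nspikes = data.size // (nchannels * clip_size)
    return data.reshape(nchannels, clip_size, nspikes)
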
def compute_cluster_metrics(*, timeseries='', firings, metrics_out, clip_size=100, samplerate=0, refrac_msec=1):
    """
    Compute cluster metrics for a spike sorting output

    Parameters
    ----------
    firings : INPUT
        Path of firings mda file (RxL) where R>=3 and L is the number of events. Second row are timestamps, third row are integer cluster labels.
    timeseries : INPUT
        (Optional) Path of timeseries mda file (MxN) which could be raw or preprocessed

    metrics_out : OUTPUT
        Path of output json file containing the metrics.

    clip_size : int
        (Optional) clip size, aka snippet size (used when computing the templates, or average waveforms)
    samplerate : float
        (Optional) Sample rate in Hz
    refrac_msec : float
        (Optional) Length of the interval (in ms) treated as the refractory period. If 0 then don't compute the refractory period metric.
    """
    print('Reading firings...')
    F = mdaio.readmda(firings)

    print('Initializing...')
    R = F.shape[0]
    L = F.shape[1]
    assert R >= 3
    times = F[1, :]
    labels = F[2, :].astype(int)
    K = np.max(labels)

    N = 0
    if timeseries:
        X = mdaio.DiskReadMda(timeseries)
        N = X.N2()

    if (samplerate > 0) and (N > 0):
        duration_sec = N / samplerate
    else:
        duration_sec = 0

    clusters = []
    for k in range(1, K + 1):
        inds_k = np.where(labels == k)[0]
        metrics_k = {"num_events": len(inds_k)}
        if duration_sec:
            metrics_k['firing_rate'] = len(inds_k) / duration_sec
        cluster = {"label": k, "metrics": metrics_k}
        clusters.append(cluster)

    if timeseries:
        print('Computing templates...')
        templates = compute_templates_helper(timeseries=timeseries, firings=firings, clip_size=clip_size)
        for k in range(1, K + 1):
            template_k = templates[:, :, k - 1]
            # subtract the mean on each channel (todo: vectorize this operation)
            for m in range(templates.shape[0]):
                template_k[m, :] = template_k[m, :] - np.mean(template_k[m, :])
            peak_amplitude = np.max(np.abs(template_k))
            clusters[k - 1]['metrics']['peak_amplitude'] = peak_amplitude

    if refrac_msec > 0:
        print('Computing Refractory Period Violations')
        msec_string = '{0:.5f}'.format(refrac_msec).rstrip('0').rstrip('.')
        rr_string = 'refractory_violation_{}msec'.format(msec_string)
        max_dt_msec = 50
        bin_size_msec = refrac_msec
        auto_cors = compute_cross_correlograms_helper(firings=firings, mode='autocorrelograms', samplerate=samplerate, max_dt_msec=max_dt_msec, bin_size_msec=bin_size_msec)
        mid_ind = np.floor(max_dt_msec / bin_size_msec).astype(int)
        mid_inds = np.arange(mid_ind - 2, mid_ind + 2)
        for k0, auto_cor_obj in enumerate(auto_cors['correlograms']):
            auto = auto_cor_obj['bin_counts']
            k = auto_cor_obj['k']
            peak = np.percentile(auto, 75)
            mid = np.mean(auto[mid_inds])
            rr = safe_div(mid, peak)
            clusters[k - 1]['metrics'][rr_string] = rr

    ret = {"clusters": clusters}

    print('Writing output...')
    json_str = json.dumps(ret, indent=4)  # avoid shadowing the builtin 'str'
    with open(metrics_out, 'w') as out:
        out.write(json_str)
    print('Done.')
    return True

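# Shape of the metrics_out JSON written above (illustrative values only; the
# keys match those assigned in the code):
# {
#   "clusters": [
#     {"label": 1,
#      "metrics": {"num_events": 812,
#                  "firing_rate": 2.7,
#                  "peak_amplitude": 14.2,
#                  "refractory_violation_1msec": 0.013}}
#   ]
# }
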
def whiten(*, timeseries, timeseries_out, chunk_size=30000 * 10, num_processes=os.cpu_count()):
    """
    Whiten a multi-channel timeseries

    Parameters
    ----------
    timeseries : INPUT
        MxN raw timeseries array (M = #channels, N = #timepoints)
    timeseries_out : OUTPUT
        Whitened output (MxN array)
    """
    X = mdaio.DiskReadMda(timeseries)
    M = X.N1()  # Number of channels
    N = X.N2()  # Number of timepoints
    num_chunks_for_computing_cov_matrix = 10
    num_chunks = int(np.ceil(N / chunk_size))
    print('Chunk size: {}, Num chunks: {}, Num processes: {}'.format(chunk_size, num_chunks, num_processes))

    opts = {
        "timeseries": timeseries,
        "timeseries_out": timeseries_out,
        "chunk_size": chunk_size,
        "num_processes": num_processes,
        "num_chunks": num_chunks
    }
    global g_opts
    g_opts = opts

    pool = multiprocessing.Pool(processes=num_processes)
    step = int(np.maximum(1, np.floor(num_chunks / num_chunks_for_computing_cov_matrix)))
    AAt_matrices = pool.map(compute_AAt_matrix_for_chunk, range(0, num_chunks, step), chunksize=1)

    AAt = np.zeros((M, M), dtype='float64')
    for M0 in AAt_matrices:
        # important: need to fix the denominator here to account for a possibly smaller final chunk
        AAt += M0 / (len(AAt_matrices) * chunk_size)

    U, S, Ut = np.linalg.svd(AAt, full_matrices=True)
    W = (U @ np.diag(1 / np.sqrt(S))) @ Ut

    global g_shared_data
    g_shared_data = SharedChunkInfo(num_chunks)
    mdaio.writemda32(np.zeros([M, 0]), timeseries_out)  # create/clear the output file
    pool = multiprocessing.Pool(processes=num_processes)
    pool.starmap(whiten_chunk, [(num, W) for num in range(0, num_chunks)], chunksize=1)
    return True

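# Sanity-check sketch for the whitening transform above: for Y = W @ X, the
# covariance (Y Y^T)/n should be close to the identity. Standalone illustration
# with random data; the function name and parameters are hypothetical.
def check_whitening(M=4, n=100000, seed=0):
    rng = np.random.RandomState(seed)
    A = rng.randn(M, M)
    X = A @ rng.randn(M, n)    # correlated channels
    AAt = (X @ X.T) / n        # covariance estimate, as accumulated in whiten()
    U, S, Ut = np.linalg.svd(AAt, full_matrices=True)
    W = (U @ np.diag(1 / np.sqrt(S))) @ Ut
    Y = W @ X
    return np.allclose((Y @ Y.T) / n, np.eye(M), atol=1e-2)
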
def sort(*, timeseries, geom='', firings_out, adjacency_radius, detect_sign, detect_interval=10, detect_threshold=3, clip_size=50, num_workers=multiprocessing.cpu_count()):
    """
    MountainSort spike sorting (version 4)

    Parameters
    ----------
    timeseries : INPUT
        MxN raw timeseries array (M = #channels, N = #timepoints)
    geom : INPUT
        Optional geometry file (.csv format)

    firings_out : OUTPUT
        Firings array channels/times/labels (3xL, L = num. events)

    adjacency_radius : float
        Radius of local sorting neighborhood, corresponding to the geometry file (same units). 0 means each channel is sorted independently. -1 means all channels are included in every neighborhood.
    detect_sign : int
        Use 1, -1, or 0 to detect positive peaks, negative peaks, or both, respectively
    detect_threshold : float
        Threshold for event detection, corresponding to the input file. So if the input file is normalized to have noise standard deviation 1 (e.g., whitened), then this is in units of std. deviations away from the mean.
    detect_interval : int
        The minimum number of timepoints between adjacent spikes detected in the same channel neighborhood.
    clip_size : int
        Size of extracted clips or snippets, used throughout
    num_workers : int
        Number of simultaneous workers (or processes). The default is multiprocessing.cpu_count().
    """
    tempdir = os.environ.get('ML_PROCESSOR_TEMPDIR')
    if not tempdir:
        print('Warning: environment variable ML_PROCESSOR_TEMPDIR not set. Using current directory.')
        tempdir = '.'
    print('Using tempdir={}'.format(tempdir))

    os.environ['OMP_NUM_THREADS'] = '1'

    # Read the header of the timeseries input to get the num. channels and num. timepoints
    X = mdaio.DiskReadMda(timeseries)
    M = X.N1()  # Number of channels
    N = X.N2()  # Number of timepoints

    # Read the geometry file
    if geom:
        Geom = np.genfromtxt(geom, delimiter=',')
    else:
        Geom = np.zeros((M, 2))

    if Geom.shape[0] != M:
        raise Exception('Incompatible dimensions between geom and timeseries: {} != {}'.format(Geom.shape[0], M))

    MS4 = ms4alg.MountainSort4()
    MS4.setGeom(Geom)
    MS4.setSortingOpts(clip_size=clip_size, adjacency_radius=adjacency_radius, detect_sign=detect_sign, detect_interval=detect_interval, detect_threshold=detect_threshold)
    MS4.setNumWorkers(num_workers)
    MS4.setTimeseriesPath(timeseries)
    MS4.setFiringsOutPath(firings_out)
    MS4.setTemporaryDirectory(tempdir)
    MS4.sort()
    return True

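# Hypothetical end-to-end usage of the processors in this module. The .mda/.json
# paths are placeholders; each stage reads and writes files on disk rather than
# passing arrays in memory.
#   bandpass_filter(timeseries='raw.mda', timeseries_out='filt.mda', samplerate=30000, freq_min=300, freq_max=6000)
#   whiten(timeseries='filt.mda', timeseries_out='pre.mda')
#   sort(timeseries='pre.mda', firings_out='firings.mda', adjacency_radius=-1, detect_sign=-1)
#   compute_cluster_metrics(timeseries='pre.mda', firings='firings.mda', metrics_out='metrics.json', samplerate=30000)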