def test_mask_out_artifacts():
    # Create noisy array
    samplerate = int(48e3)
    duration = 30  # seconds
    n_samples = samplerate * duration
    noise_amplitude = 5
    noise = noise_amplitude * np.random.normal(0, 1, n_samples)
    standard_dev = np.std(noise)

    # add three artefacts
    n_artifacts = 3
    artifacts = np.zeros_like(noise)
    artifact_duration = int(0.2 * samplerate)  # samples
    artifact_signal = np.zeros((n_artifacts, artifact_duration))
    for i in np.arange(n_artifacts):
        artifact_signal[i, :] = noise_amplitude * np.random.normal(0, 6, artifact_duration)
    artifact_indices = np.tile(np.arange(artifact_duration), (3, 1))
    artifact_shift = np.array([int(n_samples * 0.10), int(n_samples * 0.20), int(n_samples * 0.70)])
    artifact_indices += artifact_shift.reshape((-1, 1))
    for i, indices in enumerate(artifact_indices):
        artifacts[indices] = artifact_signal[i, :]
    signal = noise + artifacts

    timeseries = 'test_mask.mda'
    timeseries_out = 'masked.mda'

    # write as mda
    mdaio.writemda32(signal.reshape((1, -1)), timeseries)

    # run the mask artefacts
    mask_out_artifacts(timeseries=timeseries, timeseries_out=timeseries_out,
                       threshold=6, chunk_size=2000, num_write_chunks=150)

    # check that they are gone
    read_data = mdaio.readmda(timeseries).reshape((-1, 1))
    masked_data = mdaio.readmda(timeseries_out).reshape((-1, 1))
    indices_masked = sum(masked_data[artifact_indices, 0].flatten() == 0)
    total_indices_to_mask = len(artifact_indices.flatten())
    masked = indices_masked == total_indices_to_mask

    os.remove(timeseries)
    os.remove(timeseries_out)

    if masked:
        print('Artifacts 100% masked')
        return True
    else:
        print('Artifacts %.2f%% masked' % (100 * (indices_masked / total_indices_to_mask)))
        return False
def compare_ground_truth(*, firings, firings_true, json_out, opts={}):
    Ft = mdaio.readmda(mlp.realizeFile(firings_true))
    F = mdaio.readmda(mlp.realizeFile(firings))
    times1 = Ft[1, :]
    labels1 = Ft[2, :]
    times2 = F[1, :]
    labels2 = F[2, :]
    out = compare_ground_truth_helper(times1, labels1, times2, labels2)
    with open(json_out, 'w') as outfile:
        json.dump(out, outfile, indent=4)
def test_convert_array(dtype='int32', shape=[12, 3, 7]):
    X = np.array(np.random.normal(0, 1, shape), dtype=dtype)
    np.save('test_convert1.npy', X)
    convert_array(input='test_convert1.npy', output='test_convert2.mda')  # npy -> mda
    convert_array(input='test_convert2.mda', output='test_convert3.npy')  # mda -> npy
    convert_array(input='test_convert3.npy', output='test_convert4.dat')  # npy -> dat
    convert_array(input='test_convert4.dat', output='test_convert5.npy', dtype=dtype,
                  dimensions=','.join(str(entry) for entry in X.shape))  # dat -> npy
    convert_array(input='test_convert5.npy', output='test_convert6.mda')  # npy -> mda
    convert_array(input='test_convert6.mda', output='test_convert7.dat')  # mda -> dat
    convert_array(input='test_convert7.dat', output='test_convert8.mda', dtype=dtype,
                  dimensions=','.join(str(entry) for entry in X.shape))  # dat -> mda
    Y = mdaio.readmda('test_convert8.mda')
    print(np.max(np.abs(X - Y)), Y.dtype)
def convert_firings(*, firings, params, res_fname, clu_fname, tetrode):
    """
    Export firings (either .mda or .npy) to .res & .clu format.

    Parameters
    ----------
    firings : INPUT
        Path of firings mda (or npy) file (RxL) where R>=3 and L is the number of events.
        Second row are timestamps, third row are integer labels.
    params : INPUT
        params_json including channel map for tetrode.
    res_fname : OUTPUT
        Path of output spike times .res file.
    clu_fname : OUTPUT
        Path of output spike ids .clu file.
    tetrode : int
        Which tetrode to write out firings for.
    """
    F = mdaio.readmda(firings)
    L = F.shape[1]  # number of events
    whch = F[0, :].ravel()[:]
    times = F[1, :].ravel()[:]
    labels = F[2, :].ravel().astype(int)[:]
    K = np.max(labels)

    with open(params, 'r') as fp:
        params_obj = json.load(fp)
    tetmap = []
    pre = 1
    for i, x in enumerate(params_obj['tetrodes']):
        tetmap.append(list(range(pre, pre + len(x))))
        pre += len(x)
    which_tet = np.array([np.where(w == tetmap)[0][0] + 1 for w in whch])
    inds_k = np.where(which_tet == int(tetrode))[0]

    ###### Write .res file ######
    res = times[inds_k]
    np.savetxt(res_fname, res, fmt='%d')

    ###### Write .clu file ######
    labels_tet = labels[inds_k]
    u_labels_tet = np.unique(labels_tet)
    K = len(u_labels_tet)
    u_new_labels = np.argsort(u_labels_tet).argsort() + 2
    new_labels = [u_new_labels[(l == np.array(u_labels_tet)).argmax()] for l in labels_tet]
    np.savetxt(clu_fname, np.concatenate((np.array([K + 1]), new_labels)), fmt='%d')
    return True
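# A minimal usage sketch for convert_firings() above, assuming it is importable
# from this module. All file paths and the tetrode number are hypothetical; the
# params JSON must contain a 'tetrodes' channel map as described in the docstring.
# convert_firings(
#     firings='output/firings.curated.mda',   # RxL firings array (hypothetical path)
#     params='dataset/params.json',           # hypothetical params_json with 'tetrodes'
#     res_fname='output/tetrode3.res',        # spike times, one integer per line
#     clu_fname='output/tetrode3.clu',        # first line = cluster count, then one id per spike
#     tetrode=3)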
def validate_sorting_results(*, dataset_dir, sorting_output_dir, output_dir):
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    compare_ground_truth(
        firings=sorting_output_dir + '/firings.mda',
        firings_true=dataset_dir + '/firings_true.mda',
        json_out=output_dir + '/compare_ground_truth.json',
    )
    compute_templates(
        timeseries=dataset_dir + '/raw.mda',
        firings=dataset_dir + '/firings_true.mda',
        templates_out=output_dir + '/templates_true.mda.prv'
    )
    mlp.runPipeline()
    templates_true = mdaio.readmda(mlp.realizeFile(output_dir + '/templates_true.mda'))
    amplitudes_true = np.max(np.max(np.abs(templates_true), axis=1), axis=0)
    accuracies = get_accuracies(output_dir + '/compare_ground_truth.json')
    return dict(
        accuracies=accuracies,
        amplitudes_true=amplitudes_true
    )
def compute_templates_helper(*, timeseries, firings, clip_size=100):
    X = mdaio.DiskReadMda(timeseries)
    M, N = X.N1(), X.N2()
    F = mdaio.readmda(firings)
    L = F.shape[1]  # number of events
    T = clip_size
    Tmid = int(np.floor((T + 1) / 2) - 1)
    times = F[1, :].ravel()
    labels = F[2, :].ravel().astype(int)
    K = np.max(labels)

    sums = np.zeros((M, T, K), dtype='float64')
    counts = np.zeros(K)
    for k in range(1, K + 1):
        inds_k = np.where(labels == k)[0]
        # TODO: subsample
        for ind_k in inds_k:
            t0 = int(times[ind_k])
            if (clip_size <= t0) and (t0 < N - clip_size):
                clip0 = X.readChunk(i1=0, N1=M, i2=t0 - Tmid, N2=T)
                sums[:, :, k - 1] += clip0
                counts[k - 1] += 1
    templates = np.zeros((M, T, K))
    for k in range(K):
        templates[:, :, k] = sums[:, :, k] / counts[k]
    return templates
def load_marks(marks_path, spikes_df):
    marks = mdaio.readmda(marks_path)
    if marks.shape[-1] == spikes_df.shape[0]:
        ch_cols = [f'ch{c:>02d}' for c in np.arange(marks.shape[0])]
        marks_df = pd.DataFrame(marks[:, 0, :].squeeze().T, columns=ch_cols,
                                index=spikes_df.index)
        return marks_df
    return
def test_synthesize_random_firings():
    K = 10
    synthesize_random_firings(K=K, firings_out='tmp.firings.mda')
    firings = mdaio.readmda('tmp.firings.mda')
    labels = firings[2, :]
    assert (max(labels) == K)
    assert (firings.shape[0] == 3)
    return True
def __init__(self, firings_fname):
    OutputExtractor.__init__(self)
    print('Downloading file if needed: ' + firings_fname)
    self._firings_path = mlp.realizeFile(firings_fname)
    print('Done.')
    self._firings = mdaio.readmda(self._firings_path)
    self._times = self._firings[1, :]
    self._labels = self._firings[2, :]
    self._num_units = np.max(self._labels)
def apply_label_map(*, firings, label_map, firings_out):
    """
    Apply a label map to a given firings file, including masking and merging.

    Parameters
    ----------
    firings : INPUT
        Path of input firings mda file
    label_map : INPUT
        Path of input label map mda file [base 1, mapping to zero removes from firings]
    firings_out : OUTPUT
        ...
    """
    firings = mdaio.readmda(firings)
    label_map = mdaio.readmda(label_map)
    label_map = np.reshape(label_map, (-1, 2))
    label_map = label_map[np.argsort(label_map[:, 0])]  # Assure input is sorted

    # Propagate merge pairs to lowest label number
    for idx, label in enumerate(label_map[:, 1]):
        # jfm changed on 12/8/17 because isin() is not in older versions of numpy. :)
        # label_map[np.isin(label_map[:, 0], label), 0] = label_map[idx, 0]  # Input should be sorted
        label_map[np.where(label_map[:, 0] == label)[0], 0] = label_map[idx, 0]  # Input should be sorted

    # Apply label map
    for label_pair in range(label_map.shape[0]):
        # jfm changed on 12/8/17 because isin() is not in older versions of numpy. :)
        # firings[2, np.isin(firings[2, :], label_map[label_pair, 1])] = label_map[label_pair, 0]
        firings[2, np.where(firings[2, :] == label_map[label_pair, 1])[0]] = label_map[label_pair, 0]

    # Mask out all labels mapped to zero
    firings = firings[:, firings[2, :] != 0]

    # Write remapped firings
    return mdaio.writemda64(firings, firings_out)
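# Hedged usage sketch for apply_label_map() above: build a tiny label map
# (hypothetical file names) that merges cluster 4 into cluster 2 and drops
# cluster 5, then apply it to an existing firings file. Column 0 is the new
# label, column 1 the old label; new label 0 masks the cluster out.
# import numpy as np
# from mountainlab_pytools import mdaio
#
# label_map = np.array([[2, 4],    # old label 4 -> new label 2 (merge)
#                       [0, 5]],   # old label 5 -> 0 (mask out)
#                      dtype=np.float64)
# mdaio.writemda64(label_map, 'label_map.mda')          # hypothetical path
# apply_label_map(firings='firings.mda',                # hypothetical path
#                 label_map='label_map.mda',
#                 firings_out='firings.remapped.mda')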
def load_clips(clips_path, spikes_df):
    # load clips
    clips = mdaio.readmda(clips_path)
    if clips.shape[-1] == spikes_df.shape[0]:
        clips_2d_list = [
            clips[int(primary_chan - 1), :, i]
            for i, primary_chan in enumerate(spikes_df.primary_chan.values)
        ]
        clips_2d = np.array(clips_2d_list)
        cl_cols = [f'{c:>03d}' for c in np.arange(clips_2d.shape[1])]
        clips_df = pd.DataFrame(clips_2d, columns=cl_cols,
                                index=spikes_df.index)
        return clips_df
    return
def initialize(self):
    print('Downloading timeseries (if needed): {}'.format(self._timeseries))
    if self._timeseries is not None:
        timeseries_path = mlp.realizeFile(self._timeseries)
    print('Downloading firings (if needed): {}'.format(self._firings))
    firings_path = mlp.realizeFile(self._firings)
    if self._geom is not None:
        print('Downloading geom (if needed): {}'.format(self._geom))
        geom_path = mlp.realizeFile(self._geom)
        self._G = np.genfromtxt(geom_path, delimiter=',').T
        self._G = np.flip(self._G, axis=0)
    else:
        self._G = None
    print('Reading arrays into memory...')
    if self._timeseries is not None:
        self._X = mdaio.readmda(timeseries_path)
    else:
        self._X = None
    self._F = mdaio.readmda(firings_path)
    self._times = self._F[1, :]
    self._labels = self._F[2, :]
    self._K = int(self._labels.max())
def get_result_from_folder(output_folder):
    # overwrite the SorterBase.get_result
    from mountainlab_pytools import mdaio
    result_fname = Path(output_folder) / 'firings.mda'
    assert result_fname.exists(), 'Result file does not exist: {}'.format(str(result_fname))
    firings = mdaio.readmda(str(result_fname))
    sorting = se.NumpySortingExtractor()
    sorting.set_times_labels(firings[1, :], firings[2, :])
    return sorting
def test_compute_templates():
    M, N, K, T, L = 5, 1000, 6, 50, 100
    X = np.random.rand(M, N)
    mdaio.writemda32(X, 'tmp.mda')
    F = np.zeros((3, L))
    F[1, :] = 1 + np.random.randint(N, size=(1, L))
    F[2, :] = 1 + np.random.randint(K, size=(1, L))
    mdaio.writemda64(F, 'tmp2.mda')
    ret = compute_templates(timeseries='tmp.mda', firings='tmp2.mda',
                            templates_out='tmp3.mda', clip_size=T)
    assert (ret)
    templates0 = mdaio.readmda('tmp3.mda')
    assert (templates0.shape == (M, T, K))
    return True
def make_toes(fp):
    d = mdaio.readmda(os.path.join(fp, 'mountainout/firings.curated.mda'))
    alignfile = glob.glob(os.path.join(fp, '*.align'))
    align = pd.read_csv(alignfile[0])
    clusters = np.unique(d[2])
    Fs = 30000
    for cluster in clusters:
        spikes = d[1, d[2] == cluster]
        channel = channel_map(d[0, d[2] == cluster][0])
        starts = np.asarray(align.total_start, dtype=int)
        # starts = np.asarray(align.total_pulse, dtype=int)
        stops = np.asarray(align.total_stop, dtype=int)
        # stops = np.asarray(align.total_pulse + (1.5 * Fs), dtype=int)
        offsets = np.asarray(align.total_pulse, dtype=int)
        indices = [(spikes >= start) & (spikes <= stop)
                   for start, stop in zip(starts, stops)]
        spiketrains = [spikes[ind] - offset
                       for ind, offset in zip(indices, offsets)]
        dirname = os.path.join(fp, "ch%d_c%d" % (channel, cluster))
        if not os.path.exists(dirname):
            os.makedirs(dirname)
        for stim in align['stim'].unique():
            f = [spiketrains[rec] for rec in align[align['stim'] == stim].rec]
            # put together an informative name for the file
            song = list(align.song[align.stim == stim].unique())[0]
            cond = list(align.condition[align.stim == stim].unique())[0]
            name = "%s_%s" % (song, cond)
            toefile = os.path.join(dirname, "%s.toe_lis" % name)
            with open(toefile, "wt") as ftl:
                tl.write(ftl, np.asarray(f) / Fs * 1000)
        # for song in align['song'].unique():
        #     f = [spiketrains[rec] for rec in align[align['song'] == song][align['condition'] == 'no-scene'].rec]
        #     name = song + '_no-scene'
        #     toefile = os.path.join(dirname, "%s.toe_lis" % name)
        #     with open(toefile, "wt") as ftl:
        #         tl.write(ftl, np.asarray(f) / Fs * 1000)
        #     f = [spiketrains[rec] for rec in align[align['song'] == song][align['condition'].str.contains('scene63')].rec]
        #     name = song + '_scene63'
        #     toefile = os.path.join(dirname, "%s.toe_lis" % name)
        #     with open(toefile, "wt") as ftl:
        #         tl.write(ftl, np.asarray(f) / Fs * 1000)
        auditory_plot(dirname)
        cluster_info(dirname, fp, cluster)
def compute_templates_helper(*, timeseries, firings, clip_size=100):
    X = DiskReadMda(timeseries)
    M, N = X.N1(), X.N2()
    F = mdaio.readmda(firings)
    L = F.shape[1]  # number of events
    T = clip_size
    times = F[1, :]
    labels = F[2, :].astype(int)
    K = np.max(labels)
    compute_templates._sums = np.zeros((M, T, K))
    compute_templates._counts = np.zeros(K)

    def _kernel(chunk, info):
        inds = np.where((info.t1 <= times) & (times <= info.t2))[0]
        times0 = (times[inds] - info.t1 + info.t1a).astype(np.int32)
        labels0 = labels[inds]
        clips0 = np.zeros((M, clip_size, len(inds)), dtype=np.float32, order='F')
        cpp.extract_clips(clips0, chunk, times0, clip_size)
        for k in range(1, K + 1):
            inds_kk = np.where(labels0 == k)[0]
            compute_templates._sums[:, :, k - 1] = compute_templates._sums[:, :, k - 1] + np.sum(clips0[:, :, inds_kk], axis=2)
            compute_templates._counts[k - 1] = compute_templates._counts[k - 1] + len(inds_kk)
        return True

    TCR = TimeseriesChunkReader(chunk_size_mb=40, overlap_size=clip_size * 2)
    if not TCR.run(timeseries, _kernel):
        return None
    templates = np.zeros((M, T, K))
    for k in range(1, K + 1):
        if compute_templates._counts[k - 1]:
            templates[:, :, k - 1] = compute_templates._sums[:, :, k - 1] / compute_templates._counts[k - 1]
    return templates
def loadSpikeTimestamps(data_file=None):
    """
    Load spike timestamps, extracted from the rec file with exportmda

    returns: Timestamps for spike data
    """
    if data_file is None:
        ts_file = QtHelperUtils.get_open_file_name(
            message="Select Timestamp MDA Files", file_format="LFP Data (.dat)")
    else:
        # use the provided file if one was given
        ts_file = data_file

    # Now that we have the timestamp file, we have to read it
    try:
        timestamps = mdaio.readmda(ts_file)
    except (FileNotFoundError, IOError) as err:
        print(err)
        print('Unable to read TIMESTAMPS file.')
        timestamps = None
    finally:
        return timestamps
def load_firings(firings_path, samples_offset, ep_samp_start_offset, animal,
                 date, ntrode, config_path, fs=30000):
    # load firings
    firings = mdaio.readmda(firings_path)

    # create spikes_df with firing times with inter-epoch gap
    spike_cols = ['animal', 'day', 'epoch', 'ntrode', 'cluster',
                  'timedelta', 'sampleindex']
    spikes_df = pd.DataFrame(columns=spike_cols)
    spikes_sampleindex = samples_offset[firings[1, :].astype(int)]
    spikes_timedelta = pd.TimedeltaIndex(spikes_sampleindex / fs, unit='s', name='time')
    spikes_df['sampleindex'] = spikes_sampleindex
    spikes_df['timedelta'] = spikes_timedelta
    spikes_df['cluster'] = firings[2, :].astype(int)
    spikes_df['primary_chan'] = firings[0, :].astype(int)

    # add animal, day, epoch cols:
    spikes_df['animal'] = animal
    spikes_df['day'] = core.convert_dates_to_days(animal, date, config_path)
    spikes_df['ntrode'] = ntrode
    spikes_df['epoch'] = 0
    ep_samp_start_timedelta = pd.TimedeltaIndex(
        [ep_s / fs for ep_s in ep_samp_start_offset], unit='s', name='time')
    for epn, ep_samp_td in enumerate(ep_samp_start_timedelta):
        spikes_df.loc[spikes_df['timedelta'] >= ep_samp_td, 'epoch'] = epn + 1
    spikes_df.set_index(['animal', 'day', 'epoch', 'timedelta',
                         'ntrode', 'cluster'], inplace=True)
    return spikes_df
def ironclust(*,
              recording,  # Recording object
              tmpdir,  # Temporary working directory
              detect_sign=-1,  # Polarity of the spikes, -1, 0, or 1
              adjacency_radius=-1,  # Channel neighborhood adjacency radius corresponding to geom file
              detect_threshold=5,  # Threshold for detection
              merge_thresh=.98,  # Cluster merging threshold 0..1
              freq_min=300,  # Lower frequency limit for band-pass filter
              freq_max=6000,  # Upper frequency limit for band-pass filter
              pc_per_chan=3,  # Number of PCs per channel
              prm_template_name,  # Name of the template file
              ironclust_src=None
              ):
    if ironclust_src is None:
        ironclust_src = os.getenv('IRONCLUST_SRC', None)
    if not ironclust_src:
        raise Exception('You must either set the IRONCLUST_SRC environment variable, or pass the ironclust_src parameter')
    source_dir = os.path.dirname(os.path.realpath(__file__))

    dataset_dir = tmpdir + '/ironclust_dataset'
    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    si.MdaRecordingExtractor.writeRecording(recording=recording, save_path=dataset_dir)

    samplerate = recording.getSamplingFrequency()

    print('Reading timeseries header...')
    HH = mdaio.readmda_header(dataset_dir + '/raw.mda')
    num_channels = HH.dims[0]
    num_timepoints = HH.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.format(
        num_channels, num_timepoints, duration_minutes))

    print('Creating .params file...')
    txt = ''
    txt += 'samplerate={}\n'.format(samplerate)
    txt += 'detect_sign={}\n'.format(detect_sign)
    txt += 'adjacency_radius={}\n'.format(adjacency_radius)
    txt += 'detect_threshold={}\n'.format(detect_threshold)
    txt += 'merge_thresh={}\n'.format(merge_thresh)
    txt += 'freq_min={}\n'.format(freq_min)
    txt += 'freq_max={}\n'.format(freq_max)
    txt += 'pc_per_chan={}\n'.format(pc_per_chan)
    txt += 'prm_template_name={}\n'.format(prm_template_name)
    _write_text_file(dataset_dir + '/argfile.txt', txt)

    print('Running IronClust...')
    cmd_path = "addpath('{}', '{}/matlab', '{}/mdaio');".format(ironclust_src, ironclust_src, ironclust_src)
    # "p_ironclust('$(tempdir)','$timeseries$','$geom$','$prm$','$firings_true$','$firings_out$','$(argfile)');"
    cmd_call = "p_ironclust('{}', '{}', '{}', '', '', '{}', '{}');" \
        .format(tmpdir, dataset_dir + '/raw.mda', dataset_dir + '/geom.csv',
                tmpdir + '/firings.mda', dataset_dir + '/argfile.txt')
    cmd = 'matlab -nosplash -nodisplay -r "{} {} quit;"'.format(cmd_path, cmd_call)
    print(cmd)
    retcode = _run_command_and_print_output(cmd)

    if retcode != 0:
        raise Exception('IronClust returned a non-zero exit code')

    # parse output
    result_fname = tmpdir + '/firings.mda'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)

    firings = mdaio.readmda(result_fname)
    sorting = si.NumpySortingExtractor()
    sorting.setTimesLabels(firings[1, :], firings[2, :])
    return sorting
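# Hedged usage sketch for the ironclust() wrapper above, assuming `recording`
# is an existing (spikeextractors-style) RecordingExtractor and that MATLAB plus
# an IronClust checkout are available locally. All paths/names below are hypothetical.
# sorting = ironclust(recording=recording,
#                     tmpdir='/tmp/ironclust_run',               # hypothetical working directory
#                     detect_threshold=5,
#                     prm_template_name='tetrode_template.prm',  # hypothetical template name
#                     ironclust_src='/opt/ironclust')            # or set IRONCLUST_SRC instead
# print('Sorted units:', sorting.getUnitIds())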
def read_data(self, dataset_id, file_id):
    return readmda(self.directories[dataset_id][file_id])
def export_mountain_cells():
    print('expecting call from channel level')
    init_path = os.getcwd()
    current_path = os.getcwd()
    path_split = str.split(current_path, '/')
    if path_split[-2] == 'mountains':
        channel_path = path_split[-1]
        os.chdir('output')
    else:
        os.chdir('../../../mountains/')
        os.chdir(path_split[-1])
        channel_path = path_split[-1]
        os.chdir('output')
    original = mdaio.readmda('firings.mda')
    os.chdir('..')

    start_time = {}
    with open('start_indices.csv') as csvfile:
        csvreader = csv.reader(csvfile, delimiter=',')
        linecount = 0
        for line in csvreader:
            if linecount == 0:
                temp = []
                for term in line[1:]:
                    temp.append(int(term.strip(" '][")))
            else:
                temp = []
                for term in line[1:]:
                    temp.append(term.strip(" ']["))
            start_time[line[0]] = temp
            linecount += 1

    current_path = os.getcwd()
    ch = current_path.partition('channel')[1] + current_path.partition('channel')[2]
    total_ind = len(original[1])
    marker = 0
    count = 0
    for i in range(total_ind):
        if count == len(start_time.get('2')) - 1:
            exporting_arr = [original[1][marker:total_ind], original[2][marker:total_ind]]
            location = os.getcwd()
            temp = location.split('/')
            upper_folder1 = location[0:len(location) - len(temp[len(temp) - 1]) - 1]
            os.chdir(upper_folder1)
            temp1 = upper_folder1.split('/')
            upper_folder2 = upper_folder1[0:len(upper_folder1) - len(temp1[len(temp1) - 1]) - 1]
            os.chdir(upper_folder2)
            session_path = upper_folder2 + '/' + start_time.get('2')[count]
            os.chdir(session_path)
            full_list = []
            for name in os.listdir("."):
                if os.path.isdir(name):
                    full_list.append(name)
            for folder in full_list:
                array_path = session_path + '/' + folder
                os.chdir(array_path)
                if os.path.isdir(channel_path):
                    os.chdir(array_path)
                    split_into_cells_intra_session(ch, exporting_arr, start_time.get('1')[count])
            os.chdir(location)
            break
        if original[1][i] >= start_time.get('1')[count + 1]:
            exporting_arr = [original[1][marker:i - 1], original[2][marker:i - 1]]
            location = os.getcwd()
            temp = location.split('/')
            upper_folder1 = location[0:len(location) - len(temp[len(temp) - 1]) - 1]
            os.chdir(upper_folder1)
            temp1 = upper_folder1.split('/')
            upper_folder2 = upper_folder1[0:len(upper_folder1) - len(temp1[len(temp1) - 1]) - 1]
            os.chdir(upper_folder2)
            session_path = upper_folder2 + '/' + start_time.get('2')[count]
            os.chdir(session_path)
            full_list = []
            for name in os.listdir("."):
                if os.path.isdir(name):
                    full_list.append(name)
            for folder in full_list:
                array_path = session_path + '/' + folder
                os.chdir(array_path)
                if os.path.isdir(channel_path):
                    os.chdir(array_path)
                    split_into_cells_intra_session(ch, exporting_arr, start_time.get('1')[count])
            marker = i
            count += 1
            os.chdir(location)
    os.chdir(init_path)
def convert_array(*, input, output, format='', format_out='', dimensions='', dtype='', dtype_out='', channels=''):
    """
    Convert a multi-dimensional array between various formats ('.mda', '.npy', '.dat')
    based on the file extensions of the input/output files

    Parameters
    ----------
    input : INPUT
        Path of input array file (can be repeated for concatenation).
    output : OUTPUT
        Path of the output array file.
    format : string
        The format for the input array (mda, npy, dat), or determined from the file extension if empty
    format_out : string
        The format for the output array (mda, npy, dat), or determined from the file extension if empty
    dimensions : string
        Comma-separated list of dimensions (shape). If empty, it is auto-determined, if possible, by the input array.
        If the second dim is -1, then it will be extrapolated from file size / first dim.
    dtype : string
        The data format for the input array. Choices: int8, int16, int32, uint16, uint32, float32, float64 (possibly float16 in the future).
    dtype_out : string
        The data format for the output array. If empty, the dtype for the input array is used.
    channels : string
        Comma-separated list of channels to keep in output. Zero-based indexing. Only works for .dat to .mda conversions.
    """
    if isinstance(input, (list,)):
        multifile = True
        inputs = input
        input = inputs[0]
    else:
        multifile = False

    format_in = format
    if not format_in:
        format_in = determine_file_format(file_extension(input), dimensions)
    if not format_out:
        format_out = determine_file_format(file_extension(output), dimensions)
    print('Input/output formats: {}/{}'.format(format_in, format_out))

    ext_in = file_extension(input)

    dims = None
    if (format_in == 'mda') and (dtype == ''):
        header = mdaio.readmda_header(input)
        dtype = header.dt
        dims = header.dims
    if (format_in == 'npy') and (dtype == ''):
        A = np.load(input, mmap_mode='r')
        dtype = npy_dtype_to_string(A.dtype)
        dims = A.shape
        A = 0
    if dimensions:
        dims2 = [int(entry) for entry in dimensions.split(',')]
        if dims:
            if len(dims) != len(dims2):
                raise Exception('Inconsistent number of dimensions for input array')
            if not np.all(np.array(dims) == np.array(dims2)):
                raise Exception('Inconsistent dimensions for input array')
        dims = dims2

    if not dtype_out:
        dtype_out = dtype
    if not dtype:
        raise Exception('Unable to determine datatype for input array')
    if not dtype_out:
        raise Exception('Unable to determine datatype for output array')

    if (dims[1] == -1) and (dims[0] > 0):
        if (dtype) and (format_in == 'dat'):
            bits = int(dtype[-2:])  # number of bits per entry of dtype, TODO: make this smarter
            if not multifile:
                filebytes = os.stat(input).st_size  # bytes in input file
            else:
                dims1 = np.copy(dims)
                filebytes1 = os.stat(input).st_size  # bytes in input file
                entries1 = int(filebytes1 / (int(bits / 8)))
                dims1[1] = int(entries1 / dims1[0])
                filebytes = sum([os.stat(inp).st_size for inp in inputs])
            entries = int(filebytes / (int(bits / 8)))  # entries in input file
            dims[1] = int(entries / dims[0])  # calculated second dimension
            if DEBUG:
                print(bits)
                print(filebytes)
                print(int(filebytes / (int(bits / 8))))
                print(dims)
        else:
            raise Exception('Could not infer dimensions')

    if not dims:
        raise Exception('Unable to determine dimensions for input array')

    if not channels:
        channels = range(0, dims[0])
    else:
        channels = np.array([int(entry) for entry in channels.split(',')])
    if DEBUG:
        print(channels)

    print('Using dtype={}, dtype_out={}, dimensions={}'.format(
        dtype, dtype_out, ','.join(str(item) for item in dims)))

    if (format_in == format_out) and ((dtype == dtype_out) or (dtype_out == '')):
        if multifile and (format_in == 'dat'):
            print('Concatenating Files')
            with open(output, "wb") as outfile:
                for input_file in inputs:
                    with open(input_file, "rb") as inpt:
                        outfile.write(inpt.read())
        elif not multifile:
            print('Simply copying file...')
            shutil.copyfile(input, output)
            print('Done.')
        return True

    if format_out == 'dat' and not multifile:
        if format_in == 'mda':
            H = mdaio.readmda_header(input)
            copy_raw_file_data(input, output, start_byte=H.header_size,
                               num_entries=np.product(dims), dtype=dtype, dtype_out=dtype_out)
            return True
        elif format_in == 'npy':
            print('Warning: loading entire array into memory. This should be avoided in the future.')
            A = np.load(input, mmap_mode='r').astype(dtype=dtype_out, order='F', copy=False)
            A = A.ravel(order='F')
            A.tofile(output)
            # The following was problematic because of row-major ordering, I think
            # header_size = determine_npy_header_size(input)
            # copy_raw_file_data(input, output, start_byte=header_size, num_entries=np.product(dims), dtype=dtype, dtype_out=dtype_out)
            return True
        elif format_in == 'dat':
            raise Exception('This case not yet implemented.')
        else:
            raise Exception('Unexpected case.')
    elif (format_out == 'mda') or (format_out == 'npy'):
        if format_in == 'dat' and multifile:
            print('Warning: loading entire array into memory. This should be avoided in the future.')
            print(channels)  # DEBUG
            A = np.fromfile(inputs[0], dtype=dtype, count=np.product(dims))
            A = A.reshape(tuple(dims1), order='F')
            A = A[channels, :]
            for inputn in inputs[1:]:
                An = np.fromfile(inputn, dtype=dtype, count=np.product(dims))
                dimsN = np.copy(dims1)
                dimsN[1] = An.size / dims1[0]
                An = An.reshape(tuple(dimsN), order='F')
                An = An[channels, :]
                print(A.shape)
                print(An.shape)
                A = np.concatenate((A, An), axis=1)
            print(A.shape)  # DEBUG
            if format_out == 'mda':
                mdaio.writemda(A, output, dtype=dtype_out)
            else:
                mdaio.writenpy(A, output, dtype=dtype_out)
            return True
        elif format_in == 'dat' and not multifile:
            print('Warning: loading entire array into memory. This should be avoided in the future.')
            A = np.fromfile(input, dtype=dtype, count=np.product(dims))
            A = A.reshape(tuple(dims), order='F')
            A = A[channels, :]
            if format_out == 'mda':
                mdaio.writemda(A, output, dtype=dtype_out)
            else:
                mdaio.writenpy(A, output, dtype=dtype_out)
            return True
        elif format_in == 'mda' and not multifile:
            print('Warning: loading entire array into memory. This should be avoided in the future.')
            A = mdaio.readmda(input)
            if format_out == 'mda':
                mdaio.writemda(A, output, dtype=dtype_out)
            else:
                mdaio.writenpy(A, output, dtype=dtype_out)
            return True
        elif format_in == 'mda' and multifile:
            print('Warning: loading entire array into memory. This should be avoided in the future.')
            A = mdaio.readmda(inputs[0])
            A = A[channels, :]
            for inputn in inputs[1:]:
                An = mdaio.readmda(inputn)
                An = An[channels, :]
                A = np.concatenate((A, An), axis=0)
            if format_out == 'mda':
                mdaio.writemda(A, output, dtype=dtype_out)
            else:
                mdaio.writenpy(A, output, dtype=dtype_out)
            return True
        elif format_in == 'npy' and not multifile:
            print('Warning: loading entire array into memory. This should be avoided in the future.')
            A = np.load(input, mmap_mode='r').astype(dtype=dtype, order='F', copy=False)
            if format_out == 'mda':
                mdaio.writemda(A, output, dtype=dtype_out)
            else:
                mdaio.writenpy(A, output, dtype=dtype_out)
            return True
        elif format_in == 'npy' and multifile:
            print('Warning: loading entire array into memory. This should be avoided in the future.')
            A = np.load(inputs[0], mmap_mode='r').astype(dtype=dtype, order='F', copy=False)
            A = A[channels, :]
            for inputn in inputs[1:]:
                An = np.load(inputn, mmap_mode='r').astype(dtype=dtype, order='F', copy=False)
                An = An[channels, :]
                A = np.concatenate((A, An), axis=0)
            if format_out == 'mda':
                mdaio.writemda(A, output, dtype=dtype_out)
            else:
                mdaio.writenpy(A, output, dtype=dtype_out)
            return True
        else:
            raise Exception('Unexpected case.')
    else:
        raise Exception('Unexpected output format: {}'.format(format_out))

    raise Exception('Unexpected error.')
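# Hedged usage sketch for convert_array() above: converting a raw int16 .dat
# file into .mda while keeping only the first four channels. File names, dtype
# and channel count are hypothetical; dimensions uses -1 so the number of
# samples is inferred from the file size.
# convert_array(input='raw_tetrode.dat',     # hypothetical flat binary file
#               output='raw_tetrode.mda',
#               dtype='int16',
#               dimensions='4,-1',           # 4 channels, samples inferred
#               channels='0,1,2,3')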
def separateSpikesInEpochs(data_dir=None, firings_file='firings.curated.mda',
                           timestamp_files=None, write_separated_spikes=True):
    """
    Takes curated spikes from MountainSort and combines this information with
    spike timestamps to create separate curated spikes for each epoch

    :firings_file: Curated firings file
    :timestamp_files: Spike timestamps file list
    :write_separated_spikes: If the separated spikes should be written back to the data directory.
    :returns: List of spikes for each epoch
    """
    if data_dir is None:
        # Get the firings file
        data_dir = QtHelperUtils.get_directory(message="Select Curated firings location")

    separated_tetrodes = []
    curated_firings = []
    merged_curated_firings = []
    tetrode_list = os.listdir(data_dir)
    for tt_dir in tetrode_list:
        try:
            if firings_file in os.listdir(data_dir + '/' + tt_dir):
                curated_firings.append([])
                separated_tetrodes.append(tt_dir)
                firings_file_location = '/'.join([data_dir, tt_dir, firings_file])
                merged_curated_firings.append(mdaio.readmda(firings_file_location))
                print(MODULE_IDENTIFIER + 'Read merged firings file for tetrode %s!' % tt_dir)
            else:
                print(MODULE_IDENTIFIER + 'Merged firings %s not found for tetrode %s!' % (firings_file, tt_dir))
        except (FileNotFoundError, IOError) as err:
            print(MODULE_IDENTIFIER + 'Unable to read merged firings file for tetrode %s!' % tt_dir)
            print(err)

    gui_root = Tk()
    gui_root.withdraw()
    timestamp_headers = []
    if timestamp_files is None:
        # Read all the timestamp files
        timestamp_files = filedialog.askopenfilenames(initialdir=DEFAULT_SEARCH_PATH,
                                                      title="Select all timestamp files",
                                                      filetypes=(("Timestamps", ".mda"), ("All Files", "*.*")))
    gui_root.destroy()

    for ts_file in timestamp_files:
        timestamp_headers.append(mdaio.readmda_header(ts_file))

    # Now that we have both the timestamp headers and the timestamp files, we
    # can separate spikes out. It is important here for the timestamp files to
    # be in the same order as the curated firings as that is the only way for
    # us to tell that the firings are being split up correctly.
    print(MODULE_IDENTIFIER + 'Looking at spike timestamps in order')
    print(timestamp_files)

    # First splice up curated spikes into individual epochs
    for tt_idx, tt_firings in enumerate(merged_curated_firings):
        for ep_idx, ts_header in enumerate(timestamp_headers):
            if tt_firings is None:
                curated_firings[tt_idx].append(None)
                continue

            n_data_points = ts_header.dims[0]
            print(MODULE_IDENTIFIER + 'Epoch ' + str(ep_idx) + ': ' + str(n_data_points) + ' samples.')
            last_spike_from_epoch = np.searchsorted(tt_firings[1], n_data_points, side='left') - 1

            # If there are no spikes in this epoch, there might still be some in future epochs!
            if last_spike_from_epoch < 0:
                tt_firings[1] = tt_firings[1] - float(n_data_points)
                curated_firings[tt_idx].append(None)
                continue

            last_spike_sample_number = tt_firings[1][last_spike_from_epoch]
            print(MODULE_IDENTIFIER + separated_tetrodes[tt_idx] + ': First spike ' + str(tt_firings[1][0])
                  + ', Last spike ' + str(last_spike_sample_number))
            epoch_spikes = tt_firings[:, :last_spike_from_epoch]
            curated_firings[tt_idx].append(epoch_spikes)

            if last_spike_from_epoch < (len(tt_firings[1]) - 1):
                # Slice the merged curated firings to only have the remaining spikes
                tt_firings = tt_firings[:, last_spike_from_epoch + 1:]
                print(MODULE_IDENTIFIER + 'Trimming curated spikes. Aggregate sample start ' + str(tt_firings[1][0]))
                tt_firings[1] = tt_firings[1] - float(n_data_points)
                print(MODULE_IDENTIFIER + 'Sample number trimmed to ' + str(tt_firings[1][0]))
            else:
                print(MODULE_IDENTIFIER + separated_tetrodes[tt_idx] +
                      ', Reached end of curated firings at Epoch ' + str(ep_idx))
                tt_firings = None

    print(MODULE_IDENTIFIER + 'Spikes separated in epochs. Substituting timestamps!')

    # For each epoch, replace the sample numbers with the corresponding
    # timestamps. We are going through multiple revisions for this so that we
    # only have to load one timestamp file at a time
    for ep_idx, ts_file in enumerate(timestamp_files):
        epoch_timestamps = mdaio.readmda(ts_file)
        print(MODULE_IDENTIFIER + 'Epoch ' + str(ep_idx))
        for tt_idx, tt_curated_firings in enumerate(curated_firings):
            if tt_curated_firings[ep_idx] is None:
                continue
            # Need to use the original array because changes to tt_curated_firings do not get copied back
            curated_firings[tt_idx][ep_idx][1] = epoch_timestamps[np.array(tt_curated_firings[ep_idx][1], dtype=int)]
            print(MODULE_IDENTIFIER + separated_tetrodes[tt_idx] + ': Samples (' +
                  str(tt_curated_firings[ep_idx][1][0]) + ', ' + str(tt_curated_firings[ep_idx][1][-1]) + ')')

    if write_separated_spikes:
        try:
            for tt_idx, tet in enumerate(separated_tetrodes):
                for ep_idx in range(len(timestamp_files)):
                    if curated_firings[tt_idx][ep_idx] is not None:
                        ep_firings_file_name = data_dir + '/' + tet + '/firings-' + \
                            str(ep_idx + 1) + '.curated.mda'
                        mdaio.writemda64(curated_firings[tt_idx][ep_idx], ep_firings_file_name)
        except OSError as exception:
            if exception.errno != errno.EEXIST:
                print(MODULE_IDENTIFIER + 'Unable to write timestamped firings!')
                print(exception)

    return curated_firings
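# Hedged usage sketch for separateSpikesInEpochs() above. The directory and
# timestamp file names are hypothetical; the timestamp files must be listed in
# the same epoch order as the recordings were concatenated.
# epoch_firings = separateSpikesInEpochs(
#     data_dir='/data/animal1/day04/mountainsort',   # hypothetical
#     firings_file='firings.curated.mda',
#     timestamp_files=['epoch1.timestamps.mda',      # hypothetical
#                      'epoch2.timestamps.mda'],
#     write_separated_spikes=False)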
def compare_ground_truth(*, firings_true, firings, json_out, max_dt=20):
    """
    Compare a sorting (firings) with ground truth (firings_true)

    Parameters
    ----------
    firings_true : INPUT
        Path of true firings file (RxL) R>=3, L=#evts
    firings : INPUT
        Path of sorted firings file (RxL) R>=3, L=#evts
    json_out : OUTPUT
        Path of the output file containing the results of the comparison
    max_dt : int
        Tolerance for matching events (in timepoints)
    """
    print('Reading arrays...')
    F = mdaio.readmda(firings)
    Ft = mdaio.readmda(firings_true)
    print('Initializing data...')
    L = F.shape[1]
    Lt = Ft.shape[1]
    times = F[1, :]
    times_true = Ft[1, :]
    labels = F[2, :].astype(np.int32)
    labels_true = Ft[2, :].astype(np.int32)
    F = 0  # free memory? Alex: depends if python decides to call garbage collection
    Ft = 0  # free memory?
    N = np.maximum(np.max(times), np.max(times_true))
    K = np.max(labels)
    Kt = np.max(labels_true)

    # todo: subsample in first pass to get the best

    # First we split into segments
    print('Splitting into segments')
    segment_size = max_dt
    num_segments = int(np.ceil(N / segment_size))
    N = num_segments * segment_size
    segments = []
    # occupancy: K x num_segments
    occupancy, counts = create_occupancy_array(times, labels, segment_size, num_segments,
                                               K, spread=False, multiplicity=True)
    # occupancy_true: Kt x num_segments
    occupancy_true, counts_true = create_occupancy_array(times_true, labels_true, segment_size, num_segments,
                                                         Kt, spread=True, multiplicity=False)
    # Note: we spread the occupancy_true but not the occupancy
    # Note: we count the occupancy with multiplicity but not occupancy_true

    print('Computing pairwise counts and accuracies...')
    pairwise_counts = occupancy_true @ occupancy.transpose()  # Kt x K
    pairwise_accuracies = np.zeros((Kt, K))
    for k1 in range(1, Kt + 1):
        for k2 in range(1, K + 1):
            # jfm fixed +1 bug on 8/29/18
            numer = pairwise_counts[k1 - 1, k2 - 1]
            denom = counts_true[k1 - 1] + counts[k2 - 1] - numer
            if denom > 0:
                pairwise_accuracies[k1 - 1, k2 - 1] = numer / denom

    print('Preparing output...')
    ret = {"true_units": {}}
    for k1 in range(1, Kt + 1):
        k2_match = int(1 + np.argmax(pairwise_accuracies[k1 - 1, :].ravel()))
        # todo: compute accuracy more precisely here
        num_matches = int(pairwise_counts[k1 - 1, k2_match - 1])
        num_false_positives = int(counts[k2_match - 1] - num_matches)
        num_false_negatives = int(counts_true[k1 - 1] - num_matches)
        unit = {
            "best_match": k2_match,
            "accuracy": pairwise_accuracies[k1 - 1, k2_match - 1],
            "num_matches": num_matches,
            "num_false_positives": num_false_positives,
            "num_false_negatives": num_false_negatives
        }
        ret['true_units'][k1] = unit

    print('Writing output...')
    json_str = json.dumps(ret, indent=4)
    with open(json_out, 'w') as out:
        out.write(json_str)
    print('Done.')
    return True
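# Hedged usage sketch for compare_ground_truth() above: run the comparison and
# print per-unit accuracies from the output JSON. All file paths are hypothetical.
# import json
#
# compare_ground_truth(firings_true='dataset/firings_true.mda',   # hypothetical
#                      firings='output/firings.mda',              # hypothetical
#                      json_out='output/compare_ground_truth.json',
#                      max_dt=20)
# with open('output/compare_ground_truth.json') as f:
#     result = json.load(f)
# for k, unit in result['true_units'].items():
#     print('True unit {}: best match {} accuracy {:.3f}'.format(
#         k, unit['best_match'], unit['accuracy']))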
import sys

import numpy as np
from mountainlab_pytools import mdaio
from scipy.interpolate import interp1d

if len(sys.argv) < 2:
    print('Not enough input. I need the tetrode number #t (folder tet+#t where to find the files firings.mda and raw_filt.mda)')
    exit()

tet = sys.argv[1]
fold = 'tet' + tet

fpar = open(fold + '/' + 'tet' + tet + '.par.' + tet)
nchan = int(fpar.readline().split()[1])
print('Number of channels:', nchan)

a = mdaio.readmda(fold + '/raw_filt.mda')
b = mdaio.readmda(fold + '/firings.mda')
nC, L = a.shape  # number of channels, length of recording
res = b[1, :].astype(int)
clu = b[2, :].astype(int)
print(res[-1])

c = np.zeros([len(res), nchan, 32])  # was 32
ce = np.zeros([len(res), nchan, 33])
for ir, r in enumerate(res):
    if r > 15 and r < L - 18:
        c[ir, :, :] = a[:, r - 15:r + 17]  # was 17
def __init__(self, path):
    self._timeseries_path = path
    self._timeseries = mdaio.readmda(path)
def compute_cluster_metrics(*, timeseries='', firings, metrics_out, clip_size=100, samplerate=0,
                            refrac_msec=1):
    """
    Compute cluster metrics for a spike sorting output

    Parameters
    ----------
    firings : INPUT
        Path of firings mda file (RxL) where R>=3 and L is the number of events.
        Second row are timestamps, third row are integer cluster labels.
    timeseries : INPUT
        Optional path of timeseries mda file (MxN) which could be raw or preprocessed
    metrics_out : OUTPUT
        Path of output json file containing the metrics.
    clip_size : int
        (Optional) clip size, aka snippet size (used when computing the templates, or average waveforms)
    samplerate : float
        Optional sample rate in Hz
    refrac_msec : float
        (Optional) Define interval (in ms) as refractory period. If 0 then don't compute the refractory period metric.
    """
    print('Reading firings...')
    F = mdaio.readmda(firings)

    print('Initializing...')
    R = F.shape[0]
    L = F.shape[1]
    assert (R >= 3)
    times = F[1, :]
    labels = F[2, :].astype(int)
    K = np.max(labels)
    N = 0
    if timeseries:
        X = mdaio.DiskReadMda(timeseries)
        N = X.N2()

    if (samplerate > 0) and (N > 0):
        duration_sec = N / samplerate
    else:
        duration_sec = 0

    clusters = []
    for k in range(1, K + 1):
        inds_k = np.where(labels == k)[0]
        metrics_k = {
            "num_events": len(inds_k)
        }
        if duration_sec:
            metrics_k['firing_rate'] = len(inds_k) / duration_sec
        cluster = {
            "label": k,
            "metrics": metrics_k
        }
        clusters.append(cluster)

    if timeseries:
        print('Computing templates...')
        templates = compute_templates_helper(timeseries=timeseries, firings=firings, clip_size=clip_size)
        for k in range(1, K + 1):
            template_k = templates[:, :, k - 1]
            # subtract mean on each channel (todo: vectorize this operation)
            # Alex: like this?
            # template_k[m, :] -= np.mean(template_k[m, :])
            for m in range(templates.shape[0]):
                template_k[m, :] = template_k[m, :] - np.mean(template_k[m, :])
            peak_amplitude = np.max(np.abs(template_k))
            clusters[k - 1]['metrics']['peak_amplitude'] = peak_amplitude
        # todo: subtract template means, compute peak amplitudes

    if refrac_msec > 0:
        print('Computing Refractory Period Violations')
        msec_string = '{0:.5f}'.format(refrac_msec).rstrip('0').rstrip('.')
        rr_string = 'refractory_violation_{}msec'.format(msec_string)
        max_dt_msec = 50
        bin_size_msec = refrac_msec
        auto_cors = compute_cross_correlograms_helper(firings=firings, mode='autocorrelograms',
                                                      samplerate=samplerate, max_dt_msec=max_dt_msec,
                                                      bin_size_msec=bin_size_msec)
        mid_ind = np.floor(max_dt_msec / bin_size_msec).astype(int)
        mid_inds = np.arange(mid_ind - 2, mid_ind + 2)
        for k0, auto_cor_obj in enumerate(auto_cors['correlograms']):
            auto = auto_cor_obj['bin_counts']
            k = auto_cor_obj['k']
            peak = np.percentile(auto, 75)
            mid = np.mean(auto[mid_inds])
            rr = safe_div(mid, peak)
            clusters[k - 1]['metrics'][rr_string] = rr

    ret = {
        "clusters": clusters
    }

    print('Writing output...')
    json_str = json.dumps(ret, indent=4)
    with open(metrics_out, 'w') as out:
        out.write(json_str)
    print('Done.')
    return True
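# Hedged usage sketch for compute_cluster_metrics() above: compute per-cluster
# metrics (firing rate, peak amplitude, refractory violations) and read them
# back. Paths and the sample rate are hypothetical.
# import json
#
# compute_cluster_metrics(timeseries='dataset/raw.mda',     # optional, hypothetical
#                         firings='output/firings.mda',     # hypothetical
#                         metrics_out='output/cluster_metrics.json',
#                         clip_size=100,
#                         samplerate=30000,
#                         refrac_msec=1)
# with open('output/cluster_metrics.json') as f:
#     metrics = json.load(f)
# for cluster in metrics['clusters']:
#     print(cluster['label'], cluster['metrics'])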
bandpass_filter.version = '0.1'

if __name__ == "__main__":
    samplerate = int(3e4)
    freq_min = 250
    freq_max = 6000
    data_dir = '../ephys_preprocessing/'

    raw_data_ch1 = np.asarray(
        sio.loadmat(os.path.join(data_dir, 'raw_data_ch1.mat'))['data'])
    mdaio.writemda32(raw_data_ch1, os.path.join(data_dir, 'raw_data_ch1.mda'))
    timeseries = os.path.join(data_dir, 'raw_data_ch1.mda')
    timeseries_out = os.path.join(data_dir, 'filtered_raw_data_ch1.mda')
    bandpass_filter(timeseries, timeseries_out, samplerate, freq_min, freq_max)
    filtered_data = mdaio.readmda(
        os.path.join(data_dir, 'filtered_raw_data_ch1.mda'))

    detrended_data_ch1 = np.asarray(
        sio.loadmat(os.path.join(data_dir, 'detrended_data_ch1.mat'))['copy'])
    mdaio.writemda32(detrended_data_ch1, os.path.join(data_dir, 'detrended_data_ch1.mda'))
    timeseries = os.path.join(data_dir, 'detrended_data_ch1.mda')
    timeseries_out = os.path.join(data_dir, 'filtered_detrended_data_ch1.mda')
    bandpass_filter(timeseries, timeseries_out, samplerate, freq_min, freq_max)
    filtered_data_detrended = mdaio.readmda(
        os.path.join(data_dir, 'filtered_detrended_data_ch1.mda'))

    plt.plot(raw_data_ch1[0, :], label='raw')
    plt.plot(filtered_data_detrended[0, :], label='filtered_detrended')
    # plt.plot(detrended_data_ch1[0, :], label='raw_detrended')
def read_mda_timestamps(file):
    return readmda(file)
def loadClusteredData(data_location=None, firings_file='firings.curated.mda',
                      helper_file='hand_curated.mv2', time_limits=None):
    """
    Load up clustered data and pool it.

    :data_location: Directory which has clustered data from all the tetrodes.
    :time_limits: (Floating Point) Real time limits within which the data should be extracted.
    :returns: Spike data, separated into containers for individual units.
    """
    if data_location is None:
        # Get the location using a file dialog
        data_location = QtHelperUtils.get_directory("Select data location.")

    clustered_spikes = []
    tt_cl_to_unique_cluster_id = {}
    unique_cluster_id = 0
    tetrode_list = os.listdir(data_location)
    for tt_dir in tetrode_list:
        tt_dir_path = os.path.join(data_location, tt_dir)
        if not os.path.isdir(tt_dir_path):
            continue
        if firings_file in os.listdir(tt_dir_path):
            tt_idx = tt_dir.split('nt')[1]
            firings_file_path = os.path.join(tt_dir_path, firings_file)
            curation_file_path = os.path.join(tt_dir_path, helper_file)
            try:
                # Read the firings file
                firing_data = mdaio.readmda(firings_file_path)
            except Exception as err:
                print('Tetrode ' + tt_dir + ': Unable to read firings file!')
                print(err)
                continue

            try:
                # Read the curation file for info on spike clusters
                with open(curation_file_path, 'r') as f:
                    curation_file = json.load(f)
                # Get all cluster IDs. This includes noise, mua, everything!
                cluster_ids = [int(cl) for cl in curation_file['cluster_attributes'].keys()]
            except (FileNotFoundError, IOError) as err:
                print('Tetrode ' + tt_dir + ': Unable to read curation file.')
                continue

            # Read off spikes for individual clusters and assign unique cluster IDs to them
            if time_limits is not None:
                firing_times = (firing_data[1] - firing_data[1][0]) / SPIKE_SAMPLING_RATE
                time_limit_start_idx = np.searchsorted(firing_times, time_limits[0], side='left')
                time_limit_finish_idx = np.searchsorted(firing_times, time_limits[1], side='right')
                firing_data = firing_data[:, time_limit_start_idx:time_limit_finish_idx]

            firing_clusters = firing_data[2]
            n_clusters = 0
            for unit_id in cluster_ids:
                unit_spikes = firing_data[1][firing_clusters == unit_id]
                if len(unit_spikes) > 0:
                    tt_cl_to_unique_cluster_id[(tt_idx, unit_id)] = unique_cluster_id
                    n_clusters += 1
                    unique_cluster_id += 1
                    clustered_spikes.append(unit_spikes)
                else:
                    clustered_spikes.append(None)
            n_spikes = len(firing_data[1])
            print('Tetrode %s loaded %d spikes from %d clusters.' % (tt_dir, n_spikes, n_clusters))
        else:
            print('Tetrode ' + tt_dir + ': Firings file not found!')
    return clustered_spikes
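# Hedged usage sketch for loadClusteredData() above. The data directory is
# hypothetical; each returned entry is an array of spike sample numbers for one
# curated unit (or None for empty clusters).
# unit_spike_trains = loadClusteredData(
#     data_location='/data/animal1/day04/mountainsort',   # hypothetical
#     firings_file='firings.curated.mda',
#     helper_file='hand_curated.mv2',
#     time_limits=(0.0, 1800.0))                           # first 30 minutes (seconds)
# n_units = sum(1 for u in unit_spike_trains if u is not None)
# print('Loaded {} non-empty units'.format(n_units))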