def setup_comparison_study(study_folder, gt_dict):
    """
    Based on a dict of (recording, sorting) pairs, create the study folder.

    Parameters
    ----------
    study_folder: str
        The study folder.
    gt_dict : dict of tuple (recording, sorting_gt)
        Dict of tuples that contain a recording and its ground-truth sorting.
    """
    study_folder = Path(study_folder)

    assert not os.path.exists(study_folder), 'study_folder already exists'

    os.makedirs(str(study_folder))
    os.makedirs(str(study_folder / 'raw_files'))
    os.makedirs(str(study_folder / 'ground_truth'))

    for rec_name, (recording, sorting_gt) in gt_dict.items():
        # write recording as binary format + json + prb
        raw_filename = study_folder / 'raw_files' / (rec_name + '.dat')
        prb_filename = study_folder / 'raw_files' / (rec_name + '.prb')
        json_filename = study_folder / 'raw_files' / (rec_name + '.json')
        num_chan = recording.get_num_channels()
        chunksize = 2**24 // num_chan
        sr = recording.get_sampling_frequency()

        se.write_binary_dat_format(recording, raw_filename, time_axis=0,
                                   dtype='float32', chunksize=chunksize)
        se.save_probe_file(recording, prb_filename, format='spyking_circus')
        with open(json_filename, 'w', encoding='utf8') as f:
            info = dict(sample_rate=sr, num_chan=num_chan,
                        dtype='float32', frames_first=True)
            json.dump(info, f, indent=4)

        # write the ground-truth sorting in npz format
        se.NpzSortingExtractor.write_sorting(
            sorting_gt, study_folder / 'ground_truth' / (rec_name + '.npz'))

    # make an index of recording names
    with open(study_folder / 'names.txt', mode='w', encoding='utf8') as f:
        for rec_name in gt_dict:
            f.write(rec_name + '\n')
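# Example usage (a minimal sketch, not part of the source; it assumes
# spikeextractors' `example_datasets.toy_example()` is available to generate
# (recording, sorting_gt) pairs, and its signature may differ by version):
#
#     rec0, gt_sort0 = se.example_datasets.toy_example(num_channels=4, duration=30)
#     rec1, gt_sort1 = se.example_datasets.toy_example(num_channels=4, duration=30)
#     gt_dict = {'rec0': (rec0, gt_sort0), 'rec1': (rec1, gt_sort1)}
#     setup_comparison_study('my_study', gt_dict)  # 'my_study' must not exist yet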
def _setup_recording(self, recording, output_folder):
    source_dir = Path(__file__).parent

    # alias to params
    p = self.params

    experiment_name = output_folder / 'recording'

    # save prb file
    if p['probe_file'] is None:
        p['probe_file'] = output_folder / 'probe.prb'
        se.save_probe_file(recording, p['probe_file'], format='klusta',
                           radius=p['adjacency_radius'])

    # source file
    if isinstance(recording, se.BinDatRecordingExtractor) and recording._frame_first and \
            recording._timeseries.offset == 0:
        # no need to copy
        raw_filename = str(Path(recording._datfile).resolve())
        dtype = recording._timeseries.dtype.str
        nb_chan = len(recording._channels)
    else:
        # save binary file (chunk by chunk) into a new file
        raw_filename = output_folder / 'recording.dat'
        n_chan = recording.get_num_channels()
        chunksize = 2**24 // n_chan
        dtype = 'int16'
        se.write_binary_dat_format(recording, raw_filename, time_axis=0,
                                   dtype=dtype, chunksize=chunksize)

    if p['detect_sign'] < 0:
        detect_sign = 'negative'
    elif p['detect_sign'] > 0:
        detect_sign = 'positive'
    else:
        detect_sign = 'both'

    # set up klusta config file
    with (source_dir / 'config_default.prm').open('r') as f:
        klusta_config = f.readlines()

    # Note: should use str.format with a dict of named fields here
    klusta_config = ''.join(klusta_config).format(
        experiment_name, p['probe_file'], raw_filename,
        float(recording.get_sampling_frequency()),
        recording.get_num_channels(), "'{}'".format(dtype),
        p['threshold_strong_std_factor'], p['threshold_weak_std_factor'],
        "'" + detect_sign + "'",
        p['extract_s_before'], p['extract_s_after'],
        p['n_features_per_channel'],
        p['pca_n_waveforms_max'], p['num_starting_clusters'])

    with (output_folder / 'config.prm').open('w') as f:
        f.writelines(klusta_config)
def _setup_recording(self, recording, output_folder):
    # reset the output folder
    if output_folder.is_dir():
        shutil.rmtree(str(output_folder))
    os.makedirs(str(output_folder))

    # save prb file
    probe_file = output_folder / 'probe.prb'
    se.save_probe_file(recording, probe_file, format='spyking_circus')

    # source file
    if isinstance(recording, se.BinDatRecordingExtractor) and recording._frame_first:
        # no need to copy
        raw_filename = recording._datfile
        dtype = recording._timeseries.dtype.str
        nb_chan = len(recording._channels)
        offset = recording._timeseries.offset
    else:
        if self.debug:
            print('Local copy of recording')
        # save binary file (chunk by chunk) into a new file
        raw_filename = output_folder / 'raw_signals.raw'
        n_chan = recording.get_num_channels()
        chunksize = 2**24 // n_chan
        se.write_binary_dat_format(recording, raw_filename, time_axis=0,
                                   dtype='float32', chunksize=chunksize)
        dtype = 'float32'
        offset = 0

    # initialize source and probe file
    tdc_dataio = tdc.DataIO(dirname=str(output_folder))
    nb_chan = recording.get_num_channels()

    tdc_dataio.set_data_source(type='RawData', filenames=[str(raw_filename)],
                               dtype=dtype,
                               sample_rate=recording.get_sampling_frequency(),
                               total_channel=nb_chan, offset=offset)
    tdc_dataio.set_probe_file(str(probe_file))
    if self.debug:
        print(tdc_dataio)
def test_write_dat_file(self):
    nb_sample = self.RX.get_num_frames()
    nb_chan = self.RX.get_num_channels()

    # time_axis=0, chunksize=None
    se.write_binary_dat_format(self.RX, self.test_dir + 'rec.dat',
                               time_axis=0, dtype='float32', chunksize=None)
    data = np.memmap(open(self.test_dir + 'rec.dat'), dtype='float32',
                     mode='r', shape=(nb_sample, nb_chan)).T
    assert np.allclose(data, self.RX.get_traces())
    del data  # this closes the file

    # time_axis=1, chunksize=None
    se.write_binary_dat_format(self.RX, self.test_dir + 'rec.dat',
                               time_axis=1, dtype='float32', chunksize=None)
    data = np.memmap(open(self.test_dir + 'rec.dat'), dtype='float32',
                     mode='r', shape=(nb_chan, nb_sample))
    assert np.allclose(data, self.RX.get_traces())
    del data  # this closes the file

    # time_axis=0, chunksize=99
    se.write_binary_dat_format(self.RX, self.test_dir + 'rec.dat',
                               time_axis=0, dtype='float32', chunksize=99)
    data = np.memmap(open(self.test_dir + 'rec.dat'), dtype='float32',
                     mode='r', shape=(nb_sample, nb_chan)).T
    assert np.allclose(data, self.RX.get_traces())
    del data  # this closes the file

    # time_axis=1 with a chunksize is not supported and must raise
    with self.assertRaises(Exception) as context:
        se.write_binary_dat_format(self.RX, self.test_dir + 'rec.dat',
                                   time_axis=1, dtype='float32', chunksize=99)
def run(self):
    import tridesclous as tdc

    tmpdir = Path(_get_tmpdir('tdc'))
    recording = SFMdaRecordingExtractor(self.recording_dir)

    params = {
        'fullchain_kargs': {
            'duration': 300.,
            'preprocessor': {
                'highpass_freq': self.freq_min,
                'lowpass_freq': self.freq_max,
                'smooth_size': 0,
                'chunksize': 1024,
                'lostfront_chunksize': 128,
                'signalpreprocessor_engine': 'numpy',
                'common_ref_removal': self.common_ref_removal,
            },
            'peak_detector': {
                'peakdetector_engine': 'numpy',
                'peak_sign': '-',
                'relative_threshold': self.detection_threshold,
                'peak_span': self.peak_span,
            },
            'noise_snippet': {
                'nb_snippet': 300,
            },
            'extract_waveforms': {
                'n_left': self.waveforms_n_left,
                'n_right': self.waveforms_n_right,
                'mode': 'rand',
                'nb_max': 20000,
                'align_waveform': self.align_waveform,
            },
            'clean_waveforms': {
                'alien_value_threshold': self.alien_value_threshold,
            },
        },
        'feat_method': 'peak_max',
        'feat_kargs': {},
        'clust_method': 'sawchaincut',
        'clust_kargs': {
            'kde_bandwith': 1.  # keyword spelled as in tridesclous
        },
    }

    # save prb file
    probe_file = tmpdir / 'probe.prb'
    se.save_probe_file(recording, probe_file, format='spyking_circus')

    # source file
    if isinstance(recording, se.BinDatRecordingExtractor) and recording._frame_first:
        # no need to copy
        raw_filename = recording._datfile
        dtype = recording._timeseries.dtype.str
        nb_chan = len(recording._channels)
        offset = recording._timeseries.offset
    else:
        # save binary file (chunk by chunk) into a new file
        raw_filename = tmpdir / 'raw_signals.raw'
        n_chan = recording.get_num_channels()
        chunksize = 2**24 // n_chan
        se.write_binary_dat_format(recording, raw_filename, time_axis=0,
                                   dtype='float32', chunksize=chunksize)
        dtype = 'float32'
        offset = 0

    # initialize source and probe file
    tdc_dataio = tdc.DataIO(dirname=str(tmpdir))
    nb_chan = recording.get_num_channels()

    tdc_dataio.set_data_source(type='RawData', filenames=[str(raw_filename)],
                               dtype=dtype,
                               sample_rate=recording.get_sampling_frequency(),
                               total_channel=nb_chan, offset=offset)
    tdc_dataio.set_probe_file(str(probe_file))

    try:
        sorting = tdc_helper(tmpdir=tmpdir, params=params, recording=recording)
        SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                            save_path=self.firings_out)
    except Exception:
        if os.path.exists(tmpdir) and not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
        raise

    if not getattr(self, '_keep_temp_files', False):
        shutil.rmtree(tmpdir)
def import_from_spike_interface(recording, sorting, tdc_dirname,
                                highpass_freq=300., relative_threshold=5.):
    output_folder = Path(tdc_dirname)

    # save prb file
    probe_file = output_folder / 'probe.prb'
    se.save_probe_file(recording, probe_file, format='spyking_circus')

    # save binary file (chunk by chunk) into a new file
    raw_filename = output_folder / 'raw_signals.raw'
    n_chan = recording.get_num_channels()
    chunksize = 2**24 // n_chan
    se.write_binary_dat_format(recording, raw_filename, time_axis=0,
                               dtype='float32', chunksize=chunksize)
    dtype = 'float32'
    offset = 0
    sr = recording.get_sampling_frequency()

    # initialize source and probe file
    dataio = DataIO(dirname=str(output_folder))
    nb_chan = recording.get_num_channels()
    dataio.set_data_source(type='RawData', filenames=[str(raw_filename)],
                           dtype=dtype, sample_rate=sr,
                           total_channel=nb_chan, offset=offset)
    dataio.set_probe_file(str(probe_file))

    cc = CatalogueConstructor(dataio=dataio)

    cc.set_preprocessor_params(
        chunksize=1024,
        memory_mode='memmap',

        # signal preprocessor
        signalpreprocessor_engine='numpy',
        highpass_freq=highpass_freq,
        lowpass_freq=None,
        common_ref_removal=False,
        lostfront_chunksize=None,

        # peak detector
        peakdetector_engine='numpy',
        peak_sign='-',
        relative_threshold=relative_threshold,
        peak_span=1. / sr,
    )

    t1 = time.perf_counter()
    cc.estimate_signals_noise(seg_num=0, duration=30.)
    t2 = time.perf_counter()
    print('estimate_signals_noise', t2 - t1)

    duration = dataio.get_segment_length(0) / dataio.sample_rate
    t1 = time.perf_counter()
    cc.run_signalprocessor(duration=duration, detect_peak=False)
    t2 = time.perf_counter()
    print('run_signalprocessor', t2 - t1)

    n_right = 60
    n_left = -45

    sig_size = dataio.get_segment_length(seg_num=0)

    # build the peak array from the ground-truth spike trains,
    # discarding spikes too close to the segment borders
    all_peaks = []
    for label in sorting.get_unit_ids():
        indexes = sorting.get_unit_spike_train(label)
        indexes = indexes[indexes < (sig_size - n_right - 1)]
        indexes = indexes[indexes > (-n_left + 1)]
        peaks = np.zeros(indexes.shape, dtype=_dtype_peak)
        peaks['index'][:] = indexes
        peaks['cluster_label'][:] = label
        peaks['segment'][:] = 0
        all_peaks.append(peaks)

    all_peaks = np.concatenate(all_peaks)
    order = np.argsort(all_peaks['index'])
    all_peaks = all_peaks[order]

    nb_peak = all_peaks.size
    cc.arrays.create_array('all_peaks', _dtype_peak, (nb_peak, ), 'memmap')
    cc.all_peaks[:] = all_peaks
    cc.on_new_cluster()
    #~ print(cc.clusters)

    t1 = time.perf_counter()
    cc.extract_some_waveforms(n_left=n_left, n_right=n_right,
                              mode='rand', nb_max=10000)
    cc.clean_waveforms(alien_value_threshold=100.)
    t2 = time.perf_counter()
    print('extract_some_waveforms', t2 - t1)

    cc.project(method='peak_max')

    # put back labels
    cc.all_peaks['cluster_label'][cc.some_peaks_index] = \
        all_peaks[cc.some_peaks_index]['cluster_label']
    cc.on_new_cluster()
    cc.compute_all_centroid()
    cc.refresh_colors()

    print(cc)

    return cc
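# Illustrative usage (not from the source): feed a recording and its
# ground-truth sorting into tridesclous. `toy_example` is assumed from
# spikeextractors' example_datasets.
#
#     rec, gt_sorting = se.example_datasets.toy_example(num_channels=4, duration=60)
#     cc = import_from_spike_interface(rec, gt_sorting, 'tdc_workdir',
#                                      highpass_freq=300., relative_threshold=5.)
#     # cc is a tridesclous CatalogueConstructor seeded with the ground-truth
#     # spike trains; it can then be inspected with the tridesclous GUI.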
def run_sorters(sorter_list, recording_dict_or_list, working_folder,
                grouping_property=None, shared_binary_copy=False,
                engine=None, engine_kargs={}, debug=False, write_log=True):
    """
    Run several sorters on several recordings.

    This is a simple implementation with nested loops; use
    engine='multiprocessing' to parallelise across tasks.

    Note: engine='multiprocessing' uses the python multiprocessing module,
    which does not allow nested subprocesses. Sorters that already use
    multiprocessing internally will therefore fail with this engine.

    Parameters
    ----------
    sorter_list: list of str
        List of sorter names.
    recording_dict_or_list: dict or list
        A dict of recordings. The keys are the names of the recordings.
        If a list is given, the names will be recording_0, recording_1, ...
    working_folder: str
        The working directory. This must not exist before calling this
        function.
    grouping_property: str
        The property of grouping given to sorters.
    shared_binary_copy: bool (default False)
        Before running each sorter, all recordings are copied inside the
        working_folder in raw binary format (BinDatRecordingExtractor) and
        new recordings are instantiated as BinDatRecordingExtractor. This
        avoids multiple copies of the same file inside each sorter, but
        implies a global copy of all files.
    engine: str
        'loop' (the default when None) or 'multiprocessing'.
    engine_kargs: dict
        Kargs specific to the launcher engine:
            * 'loop' : no kargs
            * 'multiprocessing' : {'processes': int} number of processes
              (if None, then processes=os.cpu_count())
    debug: bool (default False)
        Controls sorter verbosity.
    write_log: bool (default True)
        If True, the run time of each (recording, sorter) pair is written
        to run_time.csv.

    Returns
    -------
    results : dict
        Nested dict[rec_name][sorter_name] of SortingExtractor.
    """
    assert not os.path.exists(working_folder), \
        'working_folder already exists, please remove it'
    working_folder = Path(working_folder)

    for sorter_name in sorter_list:
        assert sorter_name in sorter_dict, \
            '{} is not in sorter list'.format(sorter_name)

    if isinstance(recording_dict_or_list, list):
        # in case of list
        recording_dict = {'recording_{}'.format(i): rec
                          for i, rec in enumerate(recording_dict_or_list)}
    elif isinstance(recording_dict_or_list, dict):
        recording_dict = recording_dict_or_list
    else:
        raise ValueError('recording_dict_or_list must be a dict or a list')

    if shared_binary_copy:
        os.makedirs(working_folder / 'raw_files')
        old_rec_dict = dict(recording_dict)
        recording_dict = {}
        for rec_name, recording in old_rec_dict.items():
            if grouping_property is not None:
                recording_list = se.get_sub_extractors_by_property(
                    recording, grouping_property)
                n_group = len(recording_list)
                assert n_group == 1, 'shared_binary_copy works only with a single group'
                recording = recording_list[0]
                grouping_property = None

            raw_filename = working_folder / 'raw_files' / (rec_name + '.raw')
            prb_filename = working_folder / 'raw_files' / (rec_name + '.prb')
            n_chan = recording.get_num_channels()
            chunksize = 2**24 // n_chan
            sr = recording.get_sampling_frequency()

            # save binary
            se.write_binary_dat_format(recording, raw_filename, time_axis=0,
                                       dtype='float32', chunksize=chunksize)
            # save location (with PRB format)
            se.save_probe_file(recording, prb_filename, format='spyking_circus')
            # make new recording
            new_rec = se.BinDatRecordingExtractor(raw_filename, sr, n_chan,
                                                  'float32', frames_first=True)
            se.load_probe_file(new_rec, prb_filename)
            recording_dict[rec_name] = new_rec

    task_list = []
    for rec_name, recording in recording_dict.items():
        for sorter_name in sorter_list:
            output_folder = working_folder / 'output_folders' / rec_name / sorter_name
            task_list.append((rec_name, recording, sorter_name, output_folder,
                              grouping_property, debug, write_log))

    if engine is None or engine == 'loop':
        # simple loop in main process
        for arg_list in task_list:
            # print(arg_list)
            _run_one(arg_list)
    elif engine == 'multiprocessing':
        # use mp.Pool
        processes = engine_kargs.get('processes', None)
        pool = multiprocessing.Pool(processes)
        pool.map(_run_one, task_list)

    if write_log:
        # collect run times and write them to csv
        with open(working_folder / 'run_time.csv', mode='w') as f:
            for task in task_list:
                rec_name = task[0]
                sorter_name = task[2]
                output_folder = task[3]
                if os.path.exists(output_folder / 'run_log.txt'):
                    with open(output_folder / 'run_log.txt', mode='r') as logfile:
                        run_time = float(logfile.readline().replace('run_time:', ''))
                    txt = '{}\t{}\t{}\n'.format(rec_name, sorter_name, run_time)
                    f.write(txt)

    results = collect_results(working_folder)
    return results
def _setup_recording(self, recording, output_folder):
    source_dir = Path(__file__).parent
    p = self.params

    if not check_if_installed(Kilosort2Sorter.kilosort2_path,
                              Kilosort2Sorter.npy_matlab_path):
        raise Exception(Kilosort2Sorter.installation_mesg)

    # save binary file; normalize file_name to a Path without the .dat suffix
    if p['file_name'] is None:
        self.file_name = Path('recording')
    elif p['file_name'].suffix == '.dat':
        self.file_name = p['file_name'].with_suffix('')
    else:
        self.file_name = p['file_name']
    p['file_name'] = self.file_name
    se.write_binary_dat_format(recording, output_folder / self.file_name,
                               dtype='int16')

    # set up kilosort2 config files and run kilosort2 on data
    with (source_dir / 'kilosort2_master.txt').open('r') as f:
        kilosort2_master = f.readlines()
    with (source_dir / 'kilosort2_config.txt').open('r') as f:
        kilosort2_config = f.readlines()
    with (source_dir / 'kilosort2_channelmap.txt').open('r') as f:
        kilosort2_channelmap = f.readlines()

    nchan = recording.get_num_channels()
    dat_file = (output_folder / (self.file_name.name + '.dat')).absolute()
    kilo_thresh = p['detect_threshold']
    sample_rate = recording.get_sampling_frequency()

    if not Kilosort2Sorter.installed:
        raise ImportError('Kilosort2 is not installed',
                          Kilosort2Sorter.installation_mesg)

    abs_channel = (output_folder / 'kilosort2_channelmap.m').absolute()
    abs_config = (output_folder / 'kilosort2_config.m').absolute()
    kilosort2_path = Path(Kilosort2Sorter.kilosort2_path).absolute()
    npy_matlab_path = Path(Kilosort2Sorter.npy_matlab_path).absolute() / 'npy-matlab'

    if p['car']:
        use_car = 1
    else:
        use_car = 0

    kilosort2_master = ''.join(kilosort2_master).format(
        kilosort2_path, npy_matlab_path, output_folder, abs_channel, abs_config)
    kilosort2_config = ''.join(kilosort2_config).format(
        nchan, nchan, sample_rate, dat_file, p['minFR'], kilo_thresh, use_car)

    electrode_dimensions = p['electrode_dimensions']

    if 'group' in recording.get_channel_property_names():
        groups = [recording.get_channel_property(ch, 'group')
                  for ch in recording.get_channel_ids()]
    else:
        groups = 'ones(1, Nchannels)'

    if 'location' not in recording.get_channel_property_names():
        print("'location' information is not found. Using linear configuration")
        for i_ch, ch in enumerate(recording.get_channel_ids()):
            recording.set_channel_property(ch, 'location', [0, i_ch])

    positions = np.array([recording.get_channel_property(chan, 'location')
                          for chan in recording.get_channel_ids()])
    if electrode_dimensions is None:
        kilosort2_channelmap = ''.join(kilosort2_channelmap).format(
            nchan, list(positions[:, 0]), list(positions[:, 1]),
            groups, sample_rate)
    elif len(electrode_dimensions) == 2:
        kilosort2_channelmap = ''.join(kilosort2_channelmap).format(
            nchan, list(positions[:, electrode_dimensions[0]]),
            list(positions[:, electrode_dimensions[1]]),
            groups, recording.get_sampling_frequency())
    else:
        raise Exception("electrode_dimensions should be a list of len 2")

    for fname, value in zip(['kilosort2_master.m', 'kilosort2_config.m',
                             'kilosort2_channelmap.m'],
                            [kilosort2_master, kilosort2_config,
                             kilosort2_channelmap]):
        with (output_folder / fname).open('w') as f:
            f.writelines(value)
def export_to_phy(recording, sorting, output_folder, nPC=3,
                  electrode_dimensions=None, grouping_property=None,
                  start_frame=None, end_frame=None, ms_before=1., ms_after=2.,
                  dtype=None, max_num_waveforms=np.inf,
                  max_num_pca_waveforms=np.inf, save_waveforms=False,
                  verbose=False):
    '''
    Exports paired recording and sorting extractors to the phy template-gui
    format.

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor
    sorting: SortingExtractor
        The sorting extractor
    output_folder: str
        The output folder where the phy template-gui files are saved
    nPC: int
        nPCFeatures in template-gui format
    electrode_dimensions: list
        If electrode locations are 3D, it indicates the 2D dimensions to use
        as channel locations
    grouping_property: str
        Property to group channels. E.g. if the recording extractor has the
        'group' property and 'grouping_property' is 'group', then waveforms
        are computed group-wise.
    start_frame: int
        The start frame for computing waveforms
    end_frame: int
        The end frame for computing waveforms
    ms_before: float
        Time period in ms to cut waveforms before the spike events
    ms_after: float
        Time period in ms to cut waveforms after the spike events
    dtype: dtype
        The numpy dtype of the waveforms
    max_num_waveforms: int
        The maximum number of waveforms to extract (default is np.inf)
    max_num_pca_waveforms: int
        The maximum number of waveforms to use to compute PCA
        (default is np.inf)
    save_waveforms: bool
        If True, waveforms are saved as waveforms.npy
    verbose: bool
        If True, output is verbose
    '''
    if not isinstance(recording, se.RecordingExtractor) \
            or not isinstance(sorting, se.SortingExtractor):
        raise AttributeError('recording and sorting must be RecordingExtractor '
                             'and SortingExtractor objects')
    output_folder = Path(output_folder).absolute()
    if output_folder.is_dir():
        shutil.rmtree(output_folder)
    output_folder.mkdir()

    # save dat file
    se.write_binary_dat_format(recording, output_folder / 'recording.dat',
                               dtype='int16')

    # write params.py
    with (output_folder / 'params.py').open('w') as f:
        f.write("dat_path = '" + str(output_folder / 'recording.dat') + "'\n")
        f.write('n_channels_dat = ' + str(recording.get_num_channels()) + '\n')
        f.write("dtype = 'int16'\n")
        f.write('offset = 0\n')
        f.write('sample_rate = ' + str(recording.get_sampling_frequency()) + '\n')
        f.write('hp_filtered = False')

    # pc_features.npy - [nSpikes, nFeaturesPerChannel, nPCFeatures] single
    if grouping_property in recording.get_channel_property_names():
        groups, num_chans_in_group = np.unique(
            [recording.get_channel_property(ch, grouping_property)
             for ch in recording.get_channel_ids()], return_counts=True)
        max_num_chans_in_group = np.max(num_chans_in_group)
        channel_groups = np.array(
            [recording.get_channel_property(ch, grouping_property)
             for ch in recording.get_channel_ids()])
    else:
        max_num_chans_in_group = recording.get_num_channels()
        channel_groups = np.array([0] * recording.get_num_channels())

    if nPC > max_num_chans_in_group:
        nPC = max_num_chans_in_group
        if verbose:
            print("Changed number of PC to number of channels: ", nPC)

    if 'waveforms' not in sorting.get_unit_spike_feature_names():
        waveforms = get_unit_waveforms(recording, sorting,
                                       start_frame=start_frame,
                                       end_frame=end_frame,
                                       max_num_waveforms=max_num_waveforms,
                                       ms_before=ms_before, ms_after=ms_after,
                                       dtype=dtype, verbose=verbose)
    # PCA scores are needed in all cases (stored waveform features are
    # reused when available)
    pc_scores = compute_unit_pca_scores(
        recording, sorting, n_comp=nPC, by_electrode=True,
        start_frame=start_frame, end_frame=end_frame,
        max_num_waveforms=max_num_waveforms,
        ms_before=ms_before, ms_after=ms_after, dtype=dtype,
        max_num_pca_waveforms=max_num_pca_waveforms, verbose=verbose)

    # spike_times.npy and spike_clusters.npy
    spike_times = np.array([])
    spike_clusters = np.array([])
    pc_features = np.array([])

    for i_u, unit_id in enumerate(sorting.get_unit_ids()):
        st = sorting.get_unit_spike_train(unit_id)
        cl = [i_u] * len(st)
        pc = pc_scores[i_u]
        spike_times = np.concatenate((spike_times, np.array(st)))
        spike_clusters = np.concatenate((spike_clusters, np.array(cl)))
        if i_u == 0:
            pc_features = np.array(pc)
        else:
            pc_features = np.vstack((pc_features, np.array(pc)))

    sorting_idxs = np.argsort(spike_times)
    spike_times = spike_times[sorting_idxs, np.newaxis]
    spike_clusters = spike_clusters[sorting_idxs, np.newaxis]
    # pc_features (nSpikes, nPC, nPCFeatures)
    pc_features = pc_features[sorting_idxs].swapaxes(1, 2)

    # amplitudes.npy
    amplitudes = np.ones((len(spike_times), 1), dtype='int16')

    # channel_map.npy
    channel_map = np.arange(recording.get_num_channels())
    channel_map_si = np.array(recording.get_channel_ids())

    # channel_positions.npy
    if 'location' in recording.get_channel_property_names():
        positions = np.array([recording.get_channel_property(chan, 'location')
                              for chan in recording.get_channel_ids()])
        if electrode_dimensions is not None:
            positions = positions[:, electrode_dimensions]
    else:
        if verbose:
            print("'location' property is not available. "
                  "Using a linear channel layout.")
        positions = np.zeros((recording.get_num_channels(), 2))
        positions[:, 1] = np.arange(recording.get_num_channels())

    # similar_templates.npy - [nTemplates, nTemplates] single
    templates = get_unit_templates(recording, sorting)
    if not isinstance(templates, list):
        if len(templates.shape) == 2:
            # single unit
            templates = templates.reshape(1, templates.shape[0],
                                          templates.shape[1])
    similar_templates = _compute_templates_similarity(templates)

    # templates.npy
    templates = np.array(templates, dtype='float32').swapaxes(1, 2)

    if grouping_property in recording.get_channel_property_names():
        if grouping_property not in sorting.get_unit_property_names():
            set_unit_properties_by_max_channel_properties(
                recording, sorting, grouping_property)
        # pc_feature_ind mirrors templates_ind in the grouped case
        pc_feature_ind = np.zeros(
            (len(sorting.get_unit_ids()), int(max_num_chans_in_group)),
            dtype=int)
        templates_ind = np.zeros(
            (len(sorting.get_unit_ids()), int(max_num_chans_in_group)),
            dtype=int)
        templates_red = np.zeros((templates.shape[0], templates.shape[1],
                                  int(max_num_chans_in_group)))

        for u_i, u in enumerate(sorting.get_unit_ids()):
            group = sorting.get_unit_property(u, 'group')
            unit_chans = []
            for ch in recording.get_channel_ids():
                if recording.get_channel_property(ch, 'group') == group:
                    unit_chans.append(list(channel_map_si).index(ch))
            if len(unit_chans) == 0:
                raise Exception("Sorting extractor has a different property "
                                "than the recording extractor. They should "
                                "correspond.")
            if len(unit_chans) != max_num_chans_in_group:
                # append closest channel
                if list(channel_map).index(int(np.max(unit_chans))) + 1 < np.max(channel_map):
                    unit_chans.append(
                        list(channel_map).index(int(np.max(unit_chans)) + 1))
                else:
                    unit_chans.append(
                        list(channel_map).index(int(np.min(unit_chans)) - 1))
            unit_chans = np.array(unit_chans)
            templates_ind[u_i] = unit_chans
            # keep pc_feature_ind in sync with templates_ind so it is
            # defined for the grouped branch as well
            pc_feature_ind[u_i] = unit_chans
            templates_red[u_i, :] = templates[u_i, :, unit_chans].T
        templates = templates_red
    else:
        templates_ind = np.tile(np.arange(recording.get_num_channels()),
                                (len(sorting.get_unit_ids()), 1))
        pc_feature_ind = np.tile(np.arange(recording.get_num_channels()),
                                 (len(sorting.get_unit_ids()), 1))

    # spike_templates.npy - [nSpikes, ] uint32
    spike_templates = spike_clusters

    # Save channel_group and second_max_channel to .tsv metadata
    second_max_channel = []
    for t in templates:
        second_max_channel.append(
            np.argsort(np.abs(np.min(t, axis=0)))[::-1][1])

    with (output_folder / 'cluster_second_max_chans.tsv').open('w') as tsvfile:
        writer = csv.writer(tsvfile, delimiter='\t', lineterminator='\n')
        writer.writerow(['cluster_id', 'sec_channel'])
        for i, (u, ch) in enumerate(zip(sorting.get_unit_ids(),
                                        second_max_channel)):
            writer.writerow([i, ch])

    with (output_folder / 'cluster_group.tsv').open('w') as tsvfile:
        writer = csv.writer(tsvfile, delimiter='\t', lineterminator='\n')
        writer.writerow(['cluster_id', 'group'])
        for i, u in enumerate(sorting.get_unit_ids()):
            writer.writerow([i, 'unsorted'])

    if 'group' in sorting.get_unit_property_names():
        with (output_folder / 'cluster_chan_grp.tsv').open('w') as tsvfile:
            writer = csv.writer(tsvfile, delimiter='\t', lineterminator='\n')
            writer.writerow(['cluster_id', 'chan_grp'])
            for i, u in enumerate(sorting.get_unit_ids()):
                writer.writerow([i, sorting.get_unit_property(u, 'group')])
    else:
        with (output_folder / 'cluster_channel_group.tsv').open('w') as tsvfile:
            writer = csv.writer(tsvfile, delimiter='\t', lineterminator='\n')
            writer.writerow(['cluster_id', 'ch_group'])
            for i, u in enumerate(sorting.get_unit_ids()):
                writer.writerow([i, 0])

    # Save .tsv metadata
    max_amplitudes = [np.min(t) for t in templates]
    second_max_channel = []
    for t in templates:
        second_max_channel.append(
            np.argsort(np.abs(np.min(t, axis=0)))[::-1][1])

    with (output_folder / 'cluster_amps.tsv').open('w') as tsvfile:
        writer = csv.writer(tsvfile, delimiter='\t', lineterminator='\n')
        writer.writerow(['cluster_id', 'max_amp'])
        for i, (u, amp) in enumerate(zip(sorting.get_unit_ids(),
                                         max_amplitudes)):
            writer.writerow([i, amp])

    with (output_folder / 'cluster_second_max_chan.tsv').open('w') as tsvfile:
        writer = csv.writer(tsvfile, delimiter='\t', lineterminator='\n')
        writer.writerow(['cluster_id', 'sec_channel'])
        for i, (u, ch) in enumerate(zip(sorting.get_unit_ids(),
                                        second_max_channel)):
            writer.writerow([i, ch])

    np.save(str(output_folder / 'amplitudes.npy'), amplitudes)
    np.save(str(output_folder / 'spike_times.npy'),
            spike_times.astype('int64'))
    np.save(str(output_folder / 'spike_templates.npy'),
            spike_templates.astype('int64'))
    np.save(str(output_folder / 'spike_clusters.npy'),
            spike_clusters.astype('int64'))
    np.save(str(output_folder / 'pc_features.npy'), pc_features)
    np.save(str(output_folder / 'pc_feature_ind.npy'),
            pc_feature_ind.astype('int64'))
    np.save(str(output_folder / 'templates.npy'), templates)
    np.save(str(output_folder / 'template_ind.npy'),
            templates_ind.astype('int64'))
    np.save(str(output_folder / 'similar_templates.npy'), similar_templates)
    np.save(str(output_folder / 'channel_map.npy'),
            channel_map.astype('int64'))
    np.save(str(output_folder / 'channel_map_si.npy'),
            channel_map_si.astype('int64'))
    np.save(str(output_folder / 'channel_positions.npy'), positions)
    np.save(str(output_folder / 'channel_groups.npy'), channel_groups)

    if save_waveforms:
        np.save(str(output_folder / 'waveforms.npy'), np.array(waveforms))

    if verbose:
        print('Saved phy format to: ', output_folder)
        print('Run:\n\nphy template-gui ', str(output_folder / 'params.py'))
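# Example usage (an illustrative sketch; the toy recording/sorting pair is
# assumed from spikeextractors' example_datasets):
#
#     rec, sort = se.example_datasets.toy_example(num_channels=4, duration=30)
#     export_to_phy(rec, sort, output_folder='phy_export', nPC=3, verbose=True)
#     # then, from a shell:
#     #     phy template-gui phy_export/params.py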