def run(self):
    """Sort ``self.recording_dir`` with YASS and write the result to ``self.firings_out``.

    A scratch folder ``yass-tmp-<10 random uppercase letters>`` is created under
    ``$TEMPDIR`` (default ``/tmp``). Removal of that folder is currently disabled:
    on both the failure and the success path only a notice is printed and the
    folder is left in place.
    """
    suffix = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
    scratch_root = os.environ.get('TEMPDIR', '/tmp')
    tmpdir = scratch_root + '/yass-tmp-' + suffix
    try:
        rec = SFMdaRecordingExtractor(self.recording_dir)
        if len(self.channels) > 0:
            # Restrict sorting to the requested subset of channels.
            rec = se.SubRecordingExtractor(parent_recording=rec, channel_ids=self.channels)
        if not os.path.exists(tmpdir):
            os.mkdir(tmpdir)
        sorting, _unused = yass_helper(
            recording=rec,
            output_folder=tmpdir,
            probe_file=None,
            file_name=None,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            template_width_ms=self.template_width_ms,
            filter=self.filter,
        )
        SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
    except BaseException:
        # Report (but intentionally do not delete) the scratch folder, then re-raise.
        if os.path.exists(tmpdir):
            print('not deleted tmpdir1')
        raise
    keep_temp_files = getattr(self, '_keep_temp_files', False)
    if not keep_temp_files:
        print('not deleted tmpdir2')
def _generate_toy_recordings(): # generate toy recordings if not os.path.exists('toy_recordings'): os.mkdir('toy_recordings') replace_recordings = False ret = [] for K in [5, 10, 15, 20]: recpath = 'toy_recordings/example_K{}'.format(K) if os.path.exists(recpath) and (replace_recordings): print('Generating toy recording: {}'.format(recpath)) shutil.rmtree(recpath) else: print('Recording already exists: {}'.format(recpath)) if not os.path.exists(recpath): rx, sx_true = example_datasets.toy_example1(duration=60, num_channels=4, samplerate=30000, K=K) SFMdaRecordingExtractor.write_recording(recording=rx, save_path=recpath) SFMdaSortingExtractor.write_sorting(sorting=sx_true, save_path=recpath + '/firings_true.mda') ret.append( dict(name='example_K{}'.format(K), study='toy_study', directory=os.path.abspath(recpath), description='A toy recording with K={} units'.format(K))) return ret
def compute_score(self, sorting_extractor):
    """Score ``sorting_extractor`` against the ground truth and return the NEGATED score.

    Supported values of ``self.metric``:
      * 'accuracy', 'precision', 'recall': taken from
        ``sc.compare_sorter_to_ground_truth(...).get_performance(method='pooled_with_average')``.
      * 'f1': 2 * precision * recall / (precision + recall), or 0 when both are 0.
      * 'spikeforest': mean per-unit accuracy from ``sa.GenSortingComparisonTable``,
        using files in ``test_outputs_spikeforest``.

    The score is negated on return (presumably so a minimizing optimizer can be
    used -- confirm with the caller).

    Raises:
        ValueError: if ``self.metric`` is not one of the values above. This is
            checked before any comparison work; previously an unknown metric
            ran the comparison and then failed with UnboundLocalError.
    """
    known_metrics = ('accuracy', 'precision', 'recall', 'f1', 'spikeforest')
    if self.metric not in known_metrics:
        raise ValueError('Unknown metric {!r}; expected one of: {}'.format(
            self.metric, ', '.join(known_metrics)))

    if self.metric != 'spikeforest':
        comparison = sc.compare_sorter_to_ground_truth(
            self.gt_se, sorting_extractor, exhaustive_gt=True)
        d_results = comparison.get_performance(method='pooled_with_average', output='dict')
        print('results')
        print(d_results)
        if self.metric == 'f1':
            precision = d_results['precision']
            recall = d_results['recall']
            denom = precision + recall
            score = 2 * precision * recall / denom if denom > 0 else 0
        else:
            # 'accuracy', 'precision' and 'recall' are keys of d_results.
            score = d_results[self.metric]
        del comparison
    else:
        tmp_dir = 'test_outputs_spikeforest'
        firings_path = os.path.join(tmp_dir, 'firings.mda')
        SFMdaSortingExtractor.write_sorting(sorting=sorting_extractor, save_path=firings_path)
        print('Compare with ground truth...')
        # NOTE(review): firings_true.mda must already exist in tmp_dir; it is not written here.
        sa.GenSortingComparisonTable.execute(
            firings=firings_path,
            firings_true=os.path.join(tmp_dir, 'firings_true.mda'),
            units_true=self.true_units_above,  # only the selected ground-truth units
            json_out=os.path.join(tmp_dir, 'comparison.json'),
            html_out=os.path.join(tmp_dir, 'comparison.html'),
            _container=None,
        )
        comparison = mt.loadObject(path=os.path.join(tmp_dir, 'comparison.json'))
        score = np.mean([float(u['accuracy']) for u in comparison.values()])
    return -score
def run(self):
    """Sort ``self.recording_dir`` with JRCLUST and write firings to ``self.firings_out``.

    Every parameter listed in ``self.PARAMETERS`` is forwarded by name to
    ``jrclust_helper``, together with the dataset params read from the recording
    directory. If ``self.channels`` is non-empty, only those channels are sorted.
    The scratch directory from ``_get_tmpdir('jrclust')`` is removed afterwards,
    on success and on failure, unless ``self._keep_temp_files`` is truthy.
    """
    tmpdir = _get_tmpdir('jrclust')
    try:
        recording = SFMdaRecordingExtractor(self.recording_dir)
        params = read_dataset_params(self.recording_dir)
        if len(self.channels) > 0:
            recording = se.SubRecordingExtractor(
                parent_recording=recording, channel_ids=self.channels)
        if not os.path.exists(tmpdir):
            os.mkdir(tmpdir)
        all_params = {param0.name: getattr(self, param0.name) for param0 in self.PARAMETERS}
        sorting = jrclust_helper(
            recording=recording,
            tmpdir=tmpdir,
            params=params,
            **all_params,
        )
        SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
    finally:
        # Single cleanup point. The exists() guard covers failures that happen
        # before the directory is created and helpers that remove it themselves.
        if not getattr(self, '_keep_temp_files', False) and os.path.exists(tmpdir):
            shutil.rmtree(tmpdir)
def gen_synth_datasets(datasets, *, outdir, samplerate=32000):
    """Generate synthetic recordings plus ground-truth firings for each dataset spec.

    For every spec dict in ``datasets`` the recording is written to
    ``<outdir>/<name>`` and the ground truth to ``<outdir>/<name>/firings_true.mda``.

    Each spec must provide: name, duration, n_exc, n_inh, f_exc, f_inh,
    min_rate, st_exc, st_inh, templates, noise_level. ``seed`` is optional
    (default 0). The spec dicts are not modified; previously a missing seed
    was written back into the caller's dict.

    Args:
        datasets: iterable of spec dicts as described above.
        outdir: output directory; created, including parents, if missing.
        samplerate: sampling rate used for spike trains and recordings.
    """
    os.makedirs(outdir, exist_ok=True)
    for ds in datasets:
        ds_name = ds['name']
        print(ds_name)
        seed = ds.get('seed', 0)
        spiketrains = gen_spiketrains(
            duration=ds['duration'],
            n_exc=ds['n_exc'],
            n_inh=ds['n_inh'],
            f_exc=ds['f_exc'],
            f_inh=ds['f_inh'],
            min_rate=ds['min_rate'],
            st_exc=ds['st_exc'],
            st_inh=ds['st_inh'],
            seed=seed,
        )
        OX = NeoSpikeTrainsOutputExtractor(spiketrains=spiketrains, samplerate=samplerate)
        X, geom = gen_recording(
            templates=ds['templates'],
            output_extractor=OX,
            noise_level=ds['noise_level'],
            samplerate=samplerate,
            duration=ds['duration'],
        )
        IX = si.NumpyRecordingExtractor(timeseries=X, samplerate=samplerate, geom=geom)
        dataset_dir = os.path.join(outdir, ds_name)
        SFMdaRecordingExtractor.write_recording(IX, dataset_dir)
        SFMdaSortingExtractor.write_sorting(OX, os.path.join(dataset_dir, 'firings_true.mda'))
    print('Done.')
def run(self):
    """Compute an autocorrelogram for every unit in ``self.firings_path`` and save them as JSON.

    The lag window (``self.max_dt_msec``) and bin width (``self.bin_size_msec``)
    are given in milliseconds and converted to samples using ``self.samplerate``.
    The file written to ``self.json_out`` holds ``serialize_np`` applied to
    ``{'autocorrelograms': [{'unit_id', 'bin_counts', 'bin_edges'}, ...]}``.
    """
    from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor
    print('test1', self.firings_path, self.samplerate)
    sorting = SFMdaSortingExtractor(firings_file=self.firings_path)

    rate = self.samplerate
    sample_cap = self.max_samples
    # msec -> samples (same operation order as before, so float results are identical)
    max_dt_tp = self.max_dt_msec * rate / 1000
    bin_size_tp = self.bin_size_msec * rate / 1000

    correlograms = []
    for uid in sorting.get_unit_ids():
        print('Unit::g {}'.format(uid))
        counts, edges = compute_autocorrelogram(
            sorting.get_unit_spike_train(uid),
            max_dt_tp=max_dt_tp,
            bin_size_tp=bin_size_tp,
            max_samples=sample_cap,
        )
        correlograms.append({'unit_id': uid, 'bin_counts': counts, 'bin_edges': edges})

    payload = {'autocorrelograms': correlograms}
    with open(self.json_out, 'w') as fh:
        json.dump(serialize_np(payload), fh)
def run(self):
    """Sort ``self.recording_dir`` with Klusta via ``spikesorters`` and write ``self.firings_out``.

    The sorter writes into ``$TEMPDIR/klusta-<10 random uppercase letters>``
    (``TEMPDIR`` defaults to ``/tmp``) and is asked to delete that folder itself.
    The value returned by ``sorter.run()`` is printed as
    ``#SF-SORTER-RUNTIME#<value>#``.
    """
    import spikesorters as sorters
    print('Klusta......')
    recording = SFMdaRecordingExtractor(self.recording_dir)
    rand_tag = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
    output_folder = os.environ.get('TEMPDIR', '/tmp') + '/klusta-' + rand_tag
    klusta = sorters.KlustaSorter(
        recording=recording,
        output_folder=output_folder,
        debug=True,
        delete_output_folder=True,
    )
    # These attributes map 1:1 onto Klusta parameters of the same name.
    forwarded = (
        'adjacency_radius',
        'detect_sign',
        'threshold_strong_std_factor',
        'threshold_weak_std_factor',
        'n_features_per_channel',
        'num_starting_clusters',
        'extract_s_before',
        'extract_s_after',
    )
    klusta.set_params(**{name: getattr(self, name) for name in forwarded})
    elapsed = klusta.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))
    SFMdaSortingExtractor.write_sorting(sorting=klusta.get_result(), save_path=self.firings_out)
def run(self):
    """Sort ``self.recording_dir`` with Kilosort2 (``kilosort2_helper``) and write ``self.firings_out``.

    Scratch data goes to ``$TEMPDIR/kilosort2-tmp-<10 random uppercase letters>``.
    Temporary files are always kept: the local ``_keep_temp_files`` flag is
    hard-coded to True, so neither ``shutil.rmtree`` cleanup path runs, and
    ``self._keep_temp_files`` is not consulted.
    """
    _keep_temp_files = True
    tag = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
    tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort2-tmp-' + tag
    try:
        recording = SFMdaRecordingExtractor(self.recording_dir)
        if len(self.channels) > 0:
            recording = se.SubRecordingExtractor(
                parent_recording=recording, channel_ids=self.channels)
        if not os.path.exists(tmpdir):
            os.mkdir(tmpdir)
        helper_kwargs = {
            name: getattr(self, name)
            for name in ('detect_sign', 'adjacency_radius', 'detect_threshold',
                         'merge_thresh', 'freq_min', 'freq_max', 'pc_per_chan', 'minFR')
        }
        sorting = kilosort2_helper(recording=recording, tmpdir=tmpdir, **helper_kwargs)
        SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
    except BaseException:
        if os.path.exists(tmpdir) and not _keep_temp_files:
            print('removing tmpdir1')
            shutil.rmtree(tmpdir)
        raise
    if not _keep_temp_files:
        print('removing tmpdir2')
        shutil.rmtree(tmpdir)
def run(self):
    """Sort ``self.recording_dir`` with IronClust via ``spikesorters`` and write ``self.firings_out``.

    IronClust is first installed or located through ``IronClust.install()``. Any
    failure there prints a traceback and is re-raised as
    ``Exception('Problem installing ironclust.')``, chained to the original error.
    The sorter writes into ``$TEMPDIR/ironclust-<10 random uppercase letters>``
    and is told NOT to delete that folder. The value returned by ``sorter.run()``
    is printed as ``#SF-SORTER-RUNTIME#<value>#``.
    """
    import spikesorters as sorters
    print('IronClust......')
    try:
        ironclust_path = IronClust.install()
    except Exception as err:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit are
        # not converted into an installation error.
        traceback.print_exc()
        raise Exception('Problem installing ironclust.') from err
    sorters.IronClustSorter.set_ironclust_path(ironclust_path)

    recording = SFMdaRecordingExtractor(self.recording_dir)
    code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
    tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-' + code
    sorter = sorters.IronClustSorter(
        recording=recording,
        output_folder=tmpdir,
        debug=True,
        # NOTE(review): per the original comment, cleanup "will be taken care by
        # _keep_temp_files one step above"; that code is not visible here.
        delete_output_folder=False,
    )
    sorter.set_params(
        detect_sign=self.detect_sign,
        adjacency_radius=self.adjacency_radius,
        adjacency_radius_out=self.adjacency_radius_out,
        detect_threshold=self.detect_threshold,
        prm_template_name=self.prm_template_name,
        freq_min=self.freq_min,
        freq_max=self.freq_max,
        merge_thresh=self.merge_thresh,
        pc_per_chan=self.pc_per_chan,
        whiten=self.whiten,
        filter_type=self.filter_type,
        filter_detect_type=self.filter_detect_type,
        common_ref_type=self.common_ref_type,
        batch_sec_drift=self.batch_sec_drift,
        step_sec_drift=self.step_sec_drift,
        knn=self.knn,
        min_count=self.min_count,
        fGpu=self.fGpu,
        fft_thresh=self.fft_thresh,
        fft_thresh_low=self.fft_thresh_low,
        nSites_whiten=self.nSites_whiten,
        feature_type=self.feature_type,
        delta_cut=self.delta_cut,
        post_merge_mode=self.post_merge_mode,
        sort_mode=self.sort_mode,
    )
    timer = sorter.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()
    SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
def run(self):
    """Sort ``self.recording_dir`` with HerdingSpikes2 (spiketoolkit) and write ``self.firings_out``.

    HS2 runs with its default parameters, except that built-in filtering and
    pre-scaling are enabled (``filter=True``, ``pre_scale=True``,
    ``pre_scale_value=20``). When ``NUM_WORKERS`` is set, it is passed as
    ``clustering_n_jobs``. Output goes to ``<tmpdir>/hs2_sorting_output`` with
    ``tmpdir = $TEMPDIR/hs2-tmp-<random>``. tmpdir is removed afterwards, on
    success and on failure, unless ``self._keep_temp_files`` is truthy.
    """
    from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor
    # NOTE(review): unused (only a disabled auto-scaling step referred to it);
    # kept so import-time behavior is unchanged.
    from spikeforest_common import autoScaleRecordingToNoiseLevel  # noqa: F401
    import spiketoolkit as st

    clustering_n_jobs = os.environ.get('NUM_WORKERS', None)
    if clustering_n_jobs is not None:
        clustering_n_jobs = int(clustering_n_jobs)
    code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
    tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/hs2-tmp-' + code
    try:
        if not os.path.exists(tmpdir):
            os.mkdir(tmpdir)
        recording = SFMdaRecordingExtractor(self.recording_dir)
        print('Running HerdingspikesSorter...')
        # Process-wide side effect; presumably tells HS2 where to put its probe file.
        os.environ['HS2_PROBE_PATH'] = tmpdir
        st_sorter = st.sorters.HerdingspikesSorter(
            recording=recording, output_folder=tmpdir + '/hs2_sorting_output')
        print('Using builtin bandpass and normalisation')
        hs2_par = st_sorter.default_params()
        hs2_par['filter'] = True
        hs2_par['pre_scale_value'] = 20
        hs2_par['pre_scale'] = True
        st_sorter.set_params(**hs2_par)
        if clustering_n_jobs is not None:
            st_sorter.set_params(clustering_n_jobs=clustering_n_jobs)
        timer = st_sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
        sorting = st_sorter.get_result()
        SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
    finally:
        # Single, guarded cleanup point for both success and failure.
        if not getattr(self, '_keep_temp_files', False) and os.path.exists(tmpdir):
            shutil.rmtree(tmpdir)
def run(self):
    """Build spike-spray data for one true/sorted unit pair and save it as JSON.

    The recording (using the ``self.filtered_timeseries`` raw file) and both
    firings files are loaded. The object returned by ``create_spikesprays`` is
    written to ``self.json_out``.
    """
    recording = SFMdaRecordingExtractor(
        dataset_directory=self.recording_directory,
        download=True,
        raw_fname=self.filtered_timeseries,
    )
    true_sorting = SFMdaSortingExtractor(firings_file=self.firings_true)
    sorted_sorting = SFMdaSortingExtractor(firings_file=self.firings_sorted)
    spray = create_spikesprays(
        rx=recording,
        sx_true=true_sorting,
        sx_sorted=sorted_sorting,
        neighborhood_size=self.neighborhood_size,
        num_spikes=self.num_spikes,
        unit_id_true=self.unit_id_true,
        unit_id_sorted=self.unit_id_sorted,
    )
    with open(self.json_out, 'w') as fh:
        json.dump(spray, fh)
def run(self):
    """Sort ``self.recording_dir`` with MountainSort4 (``ml_ms4alg``) and write ``self.firings_out``.

    If ``self.throw_error`` is truthy, the method instead waits 3 seconds and
    raises ``Exception('Intentional error.')``, for exercising error handling.
    Optional preprocessing: a bandpass filter when ``freq_min`` or ``freq_max``
    is truthy, and whitening when ``self.whiten`` is truthy. A non-empty
    ``NUM_WORKERS`` is converted to int and forwarded. No curation is applied.
    """
    if self.throw_error:
        import time
        print(
            'Intentionally throwing an error in 3 seconds (MountainSort4TestError)...'
        )
        sys.stdout.flush()
        time.sleep(3)
        raise Exception('Intentional error.')

    import ml_ms4alg
    print('MountainSort4......')
    rec = SFMdaRecordingExtractor(self.recording_dir)
    workers_env = os.environ.get('NUM_WORKERS', None)
    # Convert only non-empty values; None / '' are passed through unchanged.
    num_workers = int(workers_env) if workers_env else workers_env

    if self.freq_min or self.freq_max:
        rec = bandpass_filter(recording=rec, freq_min=self.freq_min, freq_max=self.freq_max)
    if self.whiten:
        rec = whiten(recording=rec)

    ms4_kwargs = dict(
        recording=rec,
        detect_sign=self.detect_sign,
        adjacency_radius=self.adjacency_radius,
        clip_size=self.clip_size,
        detect_threshold=self.detect_threshold,
        detect_interval=self.detect_interval,
        num_workers=num_workers,
    )
    sorting = ml_ms4alg.mountainsort4(**ms4_kwargs)
    SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
def load_sorting_results_info(firings_path, *, recording_path, epoch_name, ntrode_name, curated=False):
    """Summarize a firings file as a ``sorting_results`` dict.

    Returns None when ``mt.findFile`` cannot locate ``firings_path``. Otherwise
    returns a dict with the given identifiers, the unit ids, and
    ``num_events``, the total number of spikes across all units.
    """
    if not mt.findFile(firings_path):
        return None
    sorting = SFMdaSortingExtractor(firings_file=firings_path)
    num_events = sum(
        len(sorting.get_unit_spike_train(unit_id=uid)) for uid in sorting.get_unit_ids()
    )
    info = {
        'type': 'sorting_results',
        'epoch_name': epoch_name,
        'ntrode_name': ntrode_name,
        'curated': curated,
        'firings_path': firings_path,
        'recording_path': recording_path,
        'unit_ids': sorting.get_unit_ids(),
        'num_events': num_events,
    }
    return info
def run(self):
    """Sort ``self.recording_dir`` with Tridesclous via ``spikesorters`` and write ``self.firings_out``.

    Default sorter parameters are used. Output goes to
    ``<tmpdir>/tdc_sorting_output`` with ``tmpdir = $TEMPDIR/tdc-tmp-<random>``.
    tmpdir is removed afterwards, on success and on failure, unless
    ``self._keep_temp_files`` is truthy.
    """
    print('Running Tridesclous...')
    from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor
    import spikesorters

    code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
    tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/tdc-tmp-' + code
    try:
        if not os.path.exists(tmpdir):
            os.mkdir(tmpdir)
        print('Loading recording...')
        recording = SFMdaRecordingExtractor(self.recording_dir)
        print('Running TridesclousSorter...')
        # NOTE(review): HS2_PROBE_PATH looks like a HerdingSpikes2 setting that was
        # copied into this wrapper; kept so the environment side effect is unchanged.
        os.environ['HS2_PROBE_PATH'] = tmpdir
        st_sorter = spikesorters.TridesclousSorter(
            recording=recording,
            output_folder=tmpdir + '/tdc_sorting_output',
            verbose=True,
        )
        timer = st_sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
        sorting = st_sorter.get_result()
        SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
    finally:
        # Single, guarded cleanup point for both success and failure.
        if not getattr(self, '_keep_temp_files', False) and os.path.exists(tmpdir):
            shutil.rmtree(tmpdir)
def run(self):
    """Sort ``self.recording_dir`` with MountainSort4 via ``spikesorters`` and write ``self.firings_out``.

    Filtering (``self.freq_min``/``self.freq_max``) and whitening are always
    enabled, and curation is disabled. A non-empty ``NUM_WORKERS`` is converted
    to int and forwarded as ``num_workers``. The sorter writes into
    ``$TEMPDIR/mountainsort4-<random>`` and deletes that folder itself. The
    value returned by ``sorter.run()`` is printed as
    ``#SF-SORTER-RUNTIME#<value>#``, matching the other spikesorters-based wrappers.
    """
    # from spikeinterface/spikesorters
    import spikesorters as sorters
    print('MountainSort4......')
    recording = SFMdaRecordingExtractor(self.recording_dir)
    num_workers = os.environ.get('NUM_WORKERS', None)
    if num_workers:
        num_workers = int(num_workers)
    code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
    tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/mountainsort4-' + code
    sorter = sorters.Mountainsort4Sorter(
        recording=recording,
        output_folder=tmpdir,
        debug=True,
        delete_output_folder=True,
    )
    sorter.set_params(
        detect_sign=self.detect_sign,
        adjacency_radius=self.adjacency_radius,
        clip_size=self.clip_size,
        detect_threshold=self.detect_threshold,
        detect_interval=self.detect_interval,
        num_workers=num_workers,
        curation=False,
        whiten=True,
        filter=True,
        freq_min=self.freq_min,
        freq_max=self.freq_max,
    )
    # Report elapsed time the same way as the other spikesorters-based wrappers.
    timer = sorter.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()
    SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
def run(self):
    """Sort ``self.recording_dir`` with MountainSort4 (``ml_ms4alg``) and write ``self.firings_out``.

    Optional preprocessing: a bandpass filter when ``self.freq_min`` or
    ``self.freq_max`` is truthy, and whitening when ``self.whiten`` is truthy.
    A non-empty ``NUM_WORKERS`` is converted to int and forwarded as
    ``num_workers``. No curation step is applied.
    """
    from .bandpass_filter import bandpass_filter
    from .whiten import whiten
    import ml_ms4alg

    print('MountainSort4......')
    rec = SFMdaRecordingExtractor(self.recording_dir)
    workers_env = os.environ.get('NUM_WORKERS', None)
    # Convert only non-empty values; None / '' are passed through unchanged.
    num_workers = int(workers_env) if workers_env else workers_env

    if self.freq_min or self.freq_max:
        rec = bandpass_filter(recording=rec, freq_min=self.freq_min, freq_max=self.freq_max)
    if self.whiten:
        rec = whiten(recording=rec)

    ms4_kwargs = dict(
        recording=rec,
        detect_sign=self.detect_sign,
        adjacency_radius=self.adjacency_radius,
        clip_size=self.clip_size,
        detect_threshold=self.detect_threshold,
        detect_interval=self.detect_interval,
        num_workers=num_workers,
    )
    sorting = ml_ms4alg.mountainsort4(**ms4_kwargs)
    SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
def run(self):
    """Save a representative waveform for ``self.unit_id`` as JSON.

    Despite the ``average_waveform`` key, the value stored is the median, taken
    over axis 2, of the waveforms returned by ``_get_random_spike_waveforms``.
    The channel ids of the recording are stored alongside it.
    """
    recording = SFMdaRecordingExtractor(
        dataset_directory=self.recording_directory, download=True)
    sorting = SFMdaSortingExtractor(firings_file=self.firings)
    waveforms = _get_random_spike_waveforms(
        recording=recording, sorting=sorting, unit=self.unit_id)
    channel_ids = recording.get_channel_ids()
    median_waveform = np.median(waveforms, axis=2)
    payload = {
        'channel_ids': channel_ids,
        'average_waveform': median_waveform.tolist(),
    }
    with open(self.json_out, 'w') as fh:
        json.dump(payload, fh)
def __init__(self, sorter, recording, gt_sorting, params_to_opt, space=None, run_schedule=None,
             metric='accuracy', recdir=None, outfile=None, x0=None, y0=None):
    """Configure a parameter-optimization run for one sorter on one ground-truth recording.

    Args:
        sorter: sorter name; lower-cased and looked up in ``ss.sorter_dict``.
        recording: recording extractor to sort.
        gt_sorting: ground-truth sorting extractor.
        params_to_opt: parameters to optimize; stored as an ``OrderedDict``.
        space: stored as-is.
        run_schedule: stored as-is; defaults to a fresh ``[50, 50]`` per instance.
            The old ``[50, 50]`` literal default was one list shared by all instances.
        metric: lower-cased. With 'spikeforest', the ground truth is written to
            ``test_outputs_spikeforest`` and units with SNR > 8 are selected
            into ``self.true_units_above``.
        recdir: recording directory; used by the 'spikeforest' metric.
        outfile, x0, y0: stored as-is.
    """
    if run_schedule is None:
        run_schedule = [50, 50]
    self.sorter = sorter.lower()
    self.re = recording
    self.gt_se = gt_sorting
    self.params_to_opt = OrderedDict(params_to_opt)
    self.outfile = outfile
    self.run_schedule = run_schedule
    self.space = space
    self.best_parameters = None
    self.iteration = 0
    self.metric = metric.lower()
    self.recdir = recdir
    self.results_obj = None
    self.SorterClass = ss.sorter_dict[self.sorter]
    self.true_units_above = None
    self.x0 = x0
    self.y0 = y0

    if self.metric == 'spikeforest':
        tmp_dir = 'test_outputs_spikeforest'
        if not os.path.exists(tmp_dir):
            print('Creating folder {} for temporary data - note this is not cleaned up.'.format(tmp_dir))
            os.makedirs(tmp_dir)
        firings_true_path = os.path.join(tmp_dir, 'firings_true.mda')
        units_info_path = os.path.join(tmp_dir, 'true_units_info.json')
        SFMdaSortingExtractor.write_sorting(sorting=self.gt_se, save_path=firings_true_path)
        print('Compute units info...')
        sa.ComputeUnitsInfo.execute(
            recording_dir=self.recdir, firings=firings_true_path, json_out=units_info_path)
        true_units_info = mt.loadObject(path=units_info_path)
        snrthresh = 8
        self.true_units_above = [u['unit_id'] for u in true_units_info if u['snr'] > snrthresh]
        print('Only testing ground truth units with snr > {}: '.format(snrthresh), self.true_units_above)
def yass_example(download=True, set_id=1):
    """Load one of the six ``visapy_mea`` ground-truth sets used as YASS examples.

    Args:
        download: forwarded to ``SFMdaRecordingExtractor``.
        set_id: selects ``kbucket://15734439d8cf/groundtruth/visapy_mea/set<set_id>``;
            must be one of 1..6.

    Returns:
        ``(recording, sorting_true)``.

    Raises:
        Exception: if ``set_id`` is not in 1..6.
    """
    if set_id not in range(1, 7):
        raise Exception(
            'Invalid ID for yass_example {} is not between 1..6'.format(set_id))
    dsdir = 'kbucket://15734439d8cf/groundtruth/visapy_mea/set{}'.format(set_id)
    IX = SFMdaRecordingExtractor(dataset_directory=dsdir, download=download)
    path1 = os.path.join(dsdir, 'firings_true.mda')
    print(path1)
    OX = SFMdaSortingExtractor(path1)
    return (IX, OX)
def run(self):
    """Sort ``self.recording_dir`` with SpyKING CIRCUS via ``spikesorters`` and write ``self.firings_out``.

    ``NUM_WORKERS`` (default ``'1'``) sets the worker count. Merging is fixed to
    ``merge_spikes=True, auto_merge=0.5``. The sorter writes into
    ``$TEMPDIR/spyking-circus-<random>`` and deletes that folder itself. The
    value returned by ``sorter.run()`` is printed as ``#SF-SORTER-RUNTIME#<value>#``.
    """
    import spikesorters as sorters
    print('SpyKING CIRCUS......')
    recording = SFMdaRecordingExtractor(self.recording_dir)
    letters = [random.choice(string.ascii_uppercase) for _ in range(10)]
    tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/spyking-circus-' + ''.join(letters)
    worker_count = int(os.environ.get('NUM_WORKERS', '1'))
    circus = sorters.SpykingcircusSorter(
        recording=recording,
        output_folder=tmpdir,
        verbose=True,
        delete_output_folder=True,
    )
    circus_params = dict(
        detect_sign=self.detect_sign,
        adjacency_radius=self.adjacency_radius,
        detect_threshold=self.detect_threshold,
        template_width_ms=self.template_width_ms,
        filter=self.filter,
        merge_spikes=True,
        auto_merge=0.5,
        num_workers=worker_count,
        electrode_dimensions=None,
        whitening_max_elts=self.whitening_max_elts,
        clustering_max_elts=self.clustering_max_elts,
    )
    circus.set_params(**circus_params)
    elapsed = circus.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))
    SFMdaSortingExtractor.write_sorting(sorting=circus.get_result(), save_path=self.firings_out)
def run(self):
    """Sort ``self.recording_dir`` with Kilosort2 via ``spikesorters`` and write ``self.firings_out``.

    Kilosort2 is first installed or located through ``KiloSort2.install()``.
    Failures there print a traceback and are re-raised as
    ``Exception('Problem installing kilosort.')``, chained to the cause.
    ``self.adjacency_radius`` is passed as ``sigmaMask`` and ``self.pc_per_chan``
    as ``nPCs``. The sorter writes into ``$TEMPDIR/kilosort2-<random>`` and
    deletes that folder itself. The value returned by ``sorter.run()`` is
    printed as ``#SF-SORTER-RUNTIME#<value>#``.
    """
    import spikesorters as sorters
    print('Kilosort2......')
    try:
        kilosort2_path = KiloSort2.install()
    except Exception as err:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit propagate unchanged.
        traceback.print_exc()
        raise Exception('Problem installing kilosort.') from err
    sorters.Kilosort2Sorter.set_kilosort2_path(kilosort2_path)

    recording = SFMdaRecordingExtractor(self.recording_dir)
    code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
    tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort2-' + code
    sorter = sorters.Kilosort2Sorter(
        recording=recording,
        output_folder=tmpdir,
        debug=True,
        delete_output_folder=True,
    )
    sorter.set_params(
        detect_threshold=self.detect_threshold,
        car=self.car,
        minFR=self.minFR,
        electrode_dimensions=None,
        freq_min=self.freq_min,
        sigmaMask=self.adjacency_radius,
        nPCs=self.pc_per_chan,
    )
    timer = sorter.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()
    SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
def run(self):
    """Sort ``self.recording_dir`` with KiloSort via ``spikesorters`` and write ``self.firings_out``.

    KiloSort is first installed or located through ``KiloSort.install()``.
    Failures there print a traceback and are re-raised as
    ``Exception('Problem installing kilosort.')``, chained to the cause.
    Common average referencing and the GPU are always enabled
    (``car=True, useGPU=True``). The sorter writes into
    ``$TEMPDIR/kilosort-<random>`` and deletes that folder itself. The value
    returned by ``sorter.run()`` is printed as ``#SF-SORTER-RUNTIME#<value>#``,
    matching the other spikesorters-based wrappers.
    """
    import spikesorters as sorters
    print('KiloSort......')
    try:
        kilosort_path = KiloSort.install()
    except Exception as err:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit propagate unchanged.
        traceback.print_exc()
        raise Exception('Problem installing kilosort.') from err
    sorters.KilosortSorter.set_kilosort_path(kilosort_path)

    recording = SFMdaRecordingExtractor(self.recording_dir)
    code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
    tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort-' + code
    sorter = sorters.KilosortSorter(
        recording=recording,
        output_folder=tmpdir,
        debug=True,
        delete_output_folder=True,
    )
    sorter.set_params(
        detect_threshold=self.detect_threshold,
        freq_min=self.freq_min,
        freq_max=self.freq_max,
        car=True,
        useGPU=True,
        electrode_dimensions=None,
    )
    timer = sorter.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()
    SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
def initialize(self):
    """Create the sorting extractor for this context the first time it is called.

    Later calls return immediately. ``self._sorting_extractor`` is set to an
    ``SFMdaSortingExtractor`` for ``self._sorting_result_object['firings']``
    when that entry is truthy, and to None otherwise.
    """
    if self._initialized:
        return
    self._initialized = True
    print('******** FORESTVIEW: Initializing sorting result context')
    firings = self._sorting_result_object['firings']
    self._sorting_extractor = (
        SFMdaSortingExtractor(firings_file=firings) if firings else None
    )
    print('******** FORESTVIEW: Done initializing sorting result context')
import os
import shutil

from spikeforest import example_datasets
from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor

# Build the synthetic toy recording together with its ground-truth sorting.
recording, sorting_true = example_datasets.toy_example1()
recdir = 'toy_example1'

# Start from an empty output directory so files from earlier runs cannot linger.
if os.path.exists(recdir):
    shutil.rmtree(recdir)

print('Preparing toy recording...')
SFMdaRecordingExtractor.write_recording(recording=recording, save_path=recdir)
firings_true_path = recdir + '/firings_true.mda'
SFMdaSortingExtractor.write_sorting(sorting=sorting_true, save_path=firings_true_path)
def run(self):
    """Sort ``self.recording_dir`` with Tridesclous directly and write ``self.firings_out``.

    Steps:
      1. Write a probe file (``spyking_circus`` format) into the scratch
         directory from ``_get_tmpdir('tdc')``.
      2. Use the recording's own binary file when it is a frame-first
         ``BinDatRecordingExtractor``; otherwise write a float32 copy
         (``raw_signals.raw``).
      3. Register both with a ``tdc.DataIO`` in the scratch directory, then run
         ``tdc_helper`` with the parameter tree built below.

    The scratch directory is removed afterwards unless ``self._keep_temp_files``
    is truthy. This now also happens when a setup step fails; previously only
    failures inside ``tdc_helper``/``write_sorting`` triggered cleanup.
    """
    import tridesclous as tdc

    tmpdir = Path(_get_tmpdir('tdc'))
    try:
        # Files are written into tmpdir below; make sure it exists.
        tmpdir.mkdir(parents=True, exist_ok=True)
        recording = SFMdaRecordingExtractor(self.recording_dir)
        params = {
            'fullchain_kargs': {
                'duration': 300.,
                'preprocessor': {
                    'highpass_freq': self.freq_min,
                    'lowpass_freq': self.freq_max,
                    'smooth_size': 0,
                    'chunksize': 1024,
                    'lostfront_chunksize': 128,
                    'signalpreprocessor_engine': 'numpy',
                    'common_ref_removal': self.common_ref_removal,
                },
                'peak_detector': {
                    'peakdetector_engine': 'numpy',
                    'peak_sign': '-',  # NOTE(review): negative peaks only; not configurable here
                    'relative_threshold': self.detection_threshold,
                    'peak_span': self.peak_span,
                },
                'noise_snippet': {
                    'nb_snippet': 300,
                },
                'extract_waveforms': {
                    'n_left': self.waveforms_n_left,
                    'n_right': self.waveforms_n_right,
                    'mode': 'rand',
                    'nb_max': 20000,
                    'align_waveform': self.align_waveform,
                },
                'clean_waveforms': {
                    'alien_value_threshold': self.alien_value_threshold,
                },
            },
            'feat_method': 'peak_max',
            'feat_kargs': {},
            'clust_method': 'sawchaincut',
            'clust_kargs': {
                'kde_bandwith': 1.
            },
        }

        # save prb file:
        probe_file = tmpdir / 'probe.prb'
        se.save_probe_file(recording, probe_file, format='spyking_circus')

        # source file
        if isinstance(recording, se.BinDatRecordingExtractor) and recording._frame_first:
            # Already a frame-first binary file: point tridesclous at it (no copy).
            raw_filename = recording._datfile
            dtype = recording._timeseries.dtype.str
            offset = recording._timeseries.offset
        else:
            # Save a binary copy, chunk by chunk, into a new file.
            raw_filename = tmpdir / 'raw_signals.raw'
            n_chan = recording.get_num_channels()
            chunksize = 2**24 // n_chan
            se.write_binary_dat_format(recording, raw_filename, time_axis=0,
                                       dtype='float32', chunksize=chunksize)
            dtype = 'float32'
            offset = 0

        # initialize source and probe file
        tdc_dataio = tdc.DataIO(dirname=str(tmpdir))
        nb_chan = recording.get_num_channels()
        tdc_dataio.set_data_source(type='RawData', filenames=[str(raw_filename)], dtype=dtype,
                                   sample_rate=recording.get_sampling_frequency(),
                                   total_channel=nb_chan, offset=offset)
        tdc_dataio.set_probe_file(str(probe_file))

        sorting = tdc_helper(tmpdir=tmpdir, params=params, recording=recording)
        SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
    finally:
        # Single, guarded cleanup point for both success and failure.
        if not getattr(self, '_keep_temp_files', False) and os.path.exists(tmpdir):
            shutil.rmtree(tmpdir)
def sorting(self):
    """Return a new ``SFMdaSortingExtractor`` for the firings file in ``self._obj['firings']``."""
    firings_path = self._obj['firings']
    return SFMdaSortingExtractor(firings_file=firings_path)
#!/usr/bin/env python from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor from mountaintools import client as mt # Configure to download from the public spikeforest kachery node mt.configDownloadFrom('spikeforest.public') # Load an example tetrode recording with its ground truth # You can also substitute any of the other available recordings recdir = 'sha1dir://fb52d510d2543634e247e0d2d1d4390be9ed9e20.synth_magland/datasets_noise10_K10_C4/001_synth' print('loading recording...') recording = SFMdaRecordingExtractor(dataset_directory=recdir, download=True) sorting_true = SFMdaSortingExtractor(firings_file=recdir + '/firings_true.mda') # import a spike sorter from the spikesorters module of spikeforest from spikeforestsorters import MountainSort4 import os import shutil # In place of MountainSort4 you could use any of the following: # # MountainSort4, SpykingCircus, KiloSort, KiloSort2, YASS # IronClust, HerdingSpikes2, JRClust, Tridesclous, Klusta # although the Matlab sorters require further setup. # clear and create an empty output directory (keep things tidy) if os.path.exists('test_outputs'): shutil.rmtree('test_outputs') os.makedirs('test_outputs', exist_ok=True)
def sortingTrue(self):
    """Return a ``SFMdaSortingExtractor`` for the ground truth at ``<self.directory()>/firings_true.mda``."""
    firings_true_path = self.directory() + '/firings_true.mda'
    return SFMdaSortingExtractor(firings_file=firings_true_path)
# Example script: compare a precomputed MountainSort4 sorting against ground
# truth for a hybrid_janelia drift recording.
from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor
from mountaintools import client as mt

# Configure to download from the public spikeforest kachery node
mt.configDownloadFrom('spikeforest.public')

# Load the recording with its ground truth
recdir = 'sha1dir://be6ce9f60fe1963af235862dc8197c9753b4b6f5.hybrid_janelia/drift_siprobe/rec_16c_1200s_11'
print('Loading recording...')
recording = SFMdaRecordingExtractor(dataset_directory=recdir, download=True)
sorting_true = SFMdaSortingExtractor(firings_file=recdir + '/firings_true.mda')
# Precomputed sorter output, addressed by its sha1 content hash.
sorting_ms4 = SFMdaSortingExtractor(
    firings_file=
    'sha1://f1c6fdf52a2873d6f746e44dab6bf7ccd2937d97/f1c6fdf52a2873d6f746e44dab6bf7ccd2937d97/firings.mda'
)

# import from the spikeforest package
import spikeforest_analysis as sa

# Write the ground truth firings file.
# NOTE(review): this script does not create 'test_outputs'; it must already exist.
SFMdaSortingExtractor.write_sorting(sorting=sorting_true, save_path='test_outputs/firings_true.mda')

# run the comparison
print('Compare with truth...')
import time
timer = time.time()

## Old method
# Example script: load a public tetrode recording together with its
# ground-truth sorting.
from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor
from mountaintools import client as mt

# Configure to download from the public spikeforest kachery node
mt.configDownloadFrom('spikeforest.public')

# Load an example tetrode recording with its ground truth
# You can also substitute any of the other available recordings
recdir = 'sha1dir://fb52d510d2543634e247e0d2d1d4390be9ed9e20.synth_magland/datasets_noise10_K10_C4/001_synth'

print('loading recording...')
recording = SFMdaRecordingExtractor(dataset_directory=recdir, download=True)
# Ground truth lives next to the recording as firings_true.mda.
sorting_true = SFMdaSortingExtractor(firings_file=recdir + '/firings_true.mda')