def run(self):
    """Run SpyKING CIRCUS spike sorting on the recording and write the result.

    Reads the recording from self.recording_dir, optionally restricts it to
    self.channels, sorts inside a throwaway temp directory, and writes the
    firings to self.firings_out. The temp directory is always removed.
    """
    # Random suffix so concurrent runs do not collide on the temp directory.
    code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
    # Fixed: name the temp dir after this sorter (was 'ironclust-tmp-',
    # copied from a different sorter's runner).
    tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/spyking-circus-tmp-' + code
    # Fixed: os.environ.get returns a string; n_cores must be an int (the
    # original passed the raw env string through).
    num_workers = int(os.environ.get('NUM_WORKERS', 2))
    try:
        recording = si.MdaRecordingExtractor(self.recording_dir)
        if len(self.channels) > 0:
            recording = si.SubRecordingExtractor(
                parent_recording=recording, channel_ids=self.channels)
        if not os.path.exists(tmpdir):
            os.mkdir(tmpdir)
        sorting = sf.sorters.spyking_circus(
            recording=recording,
            output_folder=tmpdir,
            probe_file=None,
            file_name=None,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            spike_thresh=self.spike_thresh,
            template_width_ms=self.template_width_ms,
            filter=self.filter,
            merge_spikes=True,
            n_cores=num_workers,
            electrode_dimensions=None,
            whitening_max_elts=self.whitening_max_elts,
            clustering_max_elts=self.clustering_max_elts
        )
        si.MdaSortingExtractor.writeSorting(
            sorting=sorting, save_path=self.firings_out)
    finally:
        # One cleanup path replaces the duplicated bare-except/success-path
        # cleanup; any exception propagates unchanged.
        if os.path.exists(tmpdir):
            shutil.rmtree(tmpdir)
def run(self):
    """Run YASS spike sorting and write the firings to self.firings_out.

    Sorting happens in a random temp directory that is removed afterwards
    unless self._keep_temp_files is set.
    """
    # Random suffix so concurrent runs do not collide on the temp directory.
    code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
    tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/yass-tmp-' + code
    try:
        recording = se.MdaRecordingExtractor(self.recording_dir)
        if len(self.channels) > 0:
            recording = se.SubRecordingExtractor(
                parent_recording=recording, channel_ids=self.channels)
        if not os.path.exists(tmpdir):
            os.mkdir(tmpdir)
        sorting, yaml_file = yass_helper(
            recording=recording,
            output_folder=tmpdir,
            probe_file=None,
            file_name=None,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            template_width_ms=self.template_width_ms,
            filter=self.filter)
        se.MdaSortingExtractor.writeSorting(
            sorting=sorting, save_path=self.firings_out)
    finally:
        # Fixed: the failure path previously removed the temp dir even when
        # _keep_temp_files was set; now the flag is honored on both paths.
        if os.path.exists(tmpdir) and not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
def test_remove_bad_channels():
    """Verify remove_bad_channels drops both explicit and threshold-detected channels."""
    recording, _ = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=2, num_channels=4, seed=0)

    # Removal by explicit channel ids.
    cleaned = remove_bad_channels(recording, bad_channel_ids=[0])
    assert 0 not in cleaned.get_channel_ids()
    cleaned = remove_bad_channels(recording, bad_channel_ids=[1, 2])
    remaining = cleaned.get_channel_ids()
    assert 1 not in remaining and 2 not in remaining
    check_dumping(cleaned)
    shutil.rmtree('test')

    # Removal by noise threshold: channel 1 gets 10x the noise level.
    traces = np.random.randn(4, 60000)
    traces[1] = 10 * traces[1]
    rec_np = se.NumpyRecordingExtractor(timeseries=traces, sampling_frequency=30000)
    rec_np.set_channel_locations(np.ones((rec_np.get_num_channels(), 2)))
    se.MdaRecordingExtractor.write_recording(rec_np, 'test')
    recording = se.MdaRecordingExtractor('test')

    cleaned = remove_bad_channels(recording, bad_channel_ids=None, bad_threshold=2)
    assert 1 not in cleaned.get_channel_ids()
    check_dumping(cleaned)
    # Threshold detection should also work on short and clipped-long windows.
    for seconds in (0.1, 10):
        cleaned = remove_bad_channels(recording, bad_channel_ids=None,
                                      bad_threshold=2, seconds=seconds)
        assert 1 not in cleaned.get_channel_ids()
        check_dumping(cleaned)
    shutil.rmtree('test')
def run(self):
    """Run IronClust spike sorting and write the firings to self.firings_out.

    Requires the IRONCLUST_SRC environment variable. Sorting happens in a
    random temp directory that is always removed afterwards.

    Raises
    ------
    Exception
        If IRONCLUST_SRC is not set.
    """
    ironclust_src = os.environ.get('IRONCLUST_SRC', None)
    if not ironclust_src:
        raise Exception('Environment variable not set: IRONCLUST_SRC')
    # Random suffix so concurrent runs do not collide on the temp directory.
    code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
    tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-tmp-' + code
    try:
        recording = si.MdaRecordingExtractor(self.recording_dir)
        if len(self.channels) > 0:
            recording = si.SubRecordingExtractor(
                parent_recording=recording, channel_ids=self.channels)
        if not os.path.exists(tmpdir):
            os.mkdir(tmpdir)
        sorting = sf.sorters.ironclust(
            recording=recording,
            tmpdir=tmpdir,  # TODO (carried over from original)
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            detect_threshold=self.detect_threshold,
            merge_thresh=self.merge_thresh,
            freq_min=self.freq_min,
            freq_max=self.freq_max,
            pc_per_chan=self.pc_per_chan,
            prm_template_name=self.prm_template_name,
            ironclust_src=ironclust_src
        )
        si.MdaSortingExtractor.writeSorting(
            sorting=sorting, save_path=self.firings_out)
    finally:
        # One cleanup path replaces the duplicated bare-except/success-path
        # cleanup; any exception propagates unchanged.
        if os.path.exists(tmpdir):
            shutil.rmtree(tmpdir)
def run(self):
    """Plot cross-correlograms for the sorting and save the figure to self.plot_out."""
    recording = si.MdaRecordingExtractor(
        dataset_directory=self.recording_dir, download=False)
    if len(self.channels) > 0:
        recording = si.SubRecordingExtractor(
            parent_recording=recording, channel_ids=self.channels)
    sorting = si.MdaSortingExtractor(firings_file=self.firings)
    widget = sw.CrossCorrelogramsWidget(
        samplerate=recording.getSamplingFrequency(), sorting=sorting)
    widget.plot()
    save_plot(self.plot_out)
def run(self):
    """Compute per-unit metrics (SNR, peak channel, event count, firing rate)
    and write them as a JSON list to self.json_out."""
    raw = si.MdaRecordingExtractor(
        dataset_directory=self.recording_dir, download=True)
    if (self.channel_ids) and (len(self.channel_ids) > 0):
        raw = si.SubRecordingExtractor(
            parent_recording=raw, channel_ids=self.channel_ids)
    recording = sw.lazyfilters.bandpass_filter(
        recording=raw, freq_min=300, freq_max=6000)
    sorting = si.MdaSortingExtractor(firings_file=self.firings)
    # Templates are computed on only the first 1e6 frames, cached in memory.
    end_frame = int(1e6)
    recording_sub = si.SubRecordingExtractor(
        parent_recording=recording, start_frame=0, end_frame=end_frame)
    recording_sub = MemoryRecordingExtractor(parent_recording=recording_sub)
    sorting_sub = si.SubSortingExtractor(
        parent_sorting=sorting, start_frame=0, end_frame=end_frame)
    unit_ids = self.unit_ids
    if (not unit_ids) or (len(unit_ids) == 0):
        unit_ids = sorting.getUnitIds()
    channel_noise_levels = compute_channel_noise_levels(recording=recording)
    print('computing templates...')
    templates = compute_unit_templates(
        recording=recording_sub, sorting=sorting_sub, unit_ids=unit_ids)
    print('.')
    total_sec = recording.getNumFrames() / recording.getSamplingFrequency()
    infos = []
    for template, unit_id in zip(templates, unit_ids):
        # Peak channel: largest absolute template amplitude across time.
        peak_channel_index = np.argmax(np.max(np.abs(template), axis=1))
        train = sorting.getUnitSpikeTrain(unit_id=unit_id)
        info0 = dict()
        info0['unit_id'] = int(unit_id)
        info0['snr'] = compute_template_snr(template, channel_noise_levels)
        info0['peak_channel'] = int(recording.getChannelIds()[peak_channel_index])
        info0['num_events'] = int(len(train))
        info0['firing_rate'] = float(len(train) / total_sec)
        infos.append(info0)
    write_json_file(self.json_out, infos)
def real(name='franklab_tetrode', download=True):
    """Return (recording, sorting) for a named real dataset.

    The sorting element is None: no ground truth is available for real data.
    Raises Exception for unrecognized names.
    """
    if name != 'franklab_tetrode':
        raise Exception('Unrecognized name for real dataset: ' + name)
    dsdir = 'kbucket://b5ecdf1474c5/datasets/neuron_paper/franklab_tetrode'
    return (se.MdaRecordingExtractor(dir_path=dsdir), None)
def run(self):
    """Run KiloSort spike sorting and write the firings to self.firings_out.

    Sorting happens in a random temp directory that is always removed.
    """
    # Random suffix so concurrent runs do not collide on the temp directory.
    code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
    tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort-tmp-' + code
    try:
        recording = si.MdaRecordingExtractor(self.recording_dir)
        if len(self.channels) > 0:
            recording = si.SubRecordingExtractor(
                parent_recording=recording, channel_ids=self.channels)
        if not os.path.exists(tmpdir):
            os.mkdir(tmpdir)
        sorting = sorters.kilosort(
            recording=recording,
            tmpdir=tmpdir,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            detect_threshold=self.detect_threshold,
            merge_thresh=self.merge_thresh,
            freq_min=self.freq_min,
            freq_max=self.freq_max,
            pc_per_chan=self.pc_per_chan)
        si.MdaSortingExtractor.writeSorting(sorting=sorting,
                                            save_path=self.firings_out)
    finally:
        # One cleanup path replaces the duplicated bare-except/success-path
        # cleanup; any exception propagates unchanged.
        if os.path.exists(tmpdir):
            shutil.rmtree(tmpdir)
def run(self):
    """Write basic recording info (samplerate, channel count, duration) as JSON."""
    recording = si.MdaRecordingExtractor(
        dataset_directory=self.recording_dir, download=False)
    samplerate = recording.getSamplingFrequency()
    info = {
        'samplerate': samplerate,
        'num_channels': len(recording.getChannelIds()),
        'duration_sec': recording.getNumFrames() / samplerate,
    }
    write_json_file(self.json_out, info)
def spikeforest_sort(
        recording_dirname,  # path to the recording directory (hashed for the cache key)
        sorter,
        sorting_params,
        _force_run=False,
        _force_save=False
):
    """Sort a recording, caching the firings in kbucket.

    The cache key combines the sorter name/version, a hash of the recording
    directory, and the sorting parameters. A cache hit returns an
    MdaSortingExtractor; otherwise the sorter is run and its output is
    written to 'tmp_firings.mda' and saved under the key.
    """
    cache_key = dict(
        sorter_name=sorter.name,
        sorter_version=sorter.version,
        recording=kb.computeDirHash(recording_dirname),
        sorting_params=sorting_params
    )
    if not _force_run:
        print('Looking up in cache...')
        firings = kb.realizeFile(key=cache_key)
        if firings:
            print('Found')
            if _force_save:
                print('Saving')
                kb.saveFile(fname=firings, key=cache_key)
            return si.MdaSortingExtractor(firings_file=firings)
    recording = si.MdaRecordingExtractor(recording_dirname)
    sorting = sorter(recording=recording, **sorting_params)
    si.MdaSortingExtractor.writeSorting(
        sorting=sorting, save_path='tmp_firings.mda')
    kb.saveFile(fname='tmp_firings.mda', key=cache_key)
    return sorting
def run(self):
    """Plot unit waveforms for the sorting and save the figure to self.plot_out."""
    recording = si.MdaRecordingExtractor(dataset_directory=self.recording_dir)
    if len(self.channels) > 0:
        recording = si.SubRecordingExtractor(
            parent_recording=recording, channel_ids=self.channels)
    sorting = si.MdaSortingExtractor(firings_file=self.firings)
    sw.UnitWaveformsWidget(recording=recording, sorting=sorting).plot()
    save_plot(self.plot_out)
def recordingExtractor(self, download=False):
    """Return a recording extractor for this object's directory, restricted to
    the 'channels' subset when one is present and non-empty."""
    extractor = si.MdaRecordingExtractor(
        dataset_directory=self.directory(), download=download)
    # A missing or empty/falsy 'channels' entry means "use all channels".
    channels = self._obj.get('channels')
    if channels:
        extractor = si.SubRecordingExtractor(
            parent_recording=extractor, channel_ids=channels)
    return extractor
def run(self):
    """Plot a bandpass-filtered timeseries snippet (4000 frames ending at the
    recording midpoint) and save it to self.jpg_out."""
    raw = si.MdaRecordingExtractor(
        dataset_directory=self.recording_dir, download=False)
    filtered = st.filters.bandpass_filter(
        recording=raw, freq_min=300, freq_max=6000)
    midpoint = int(filtered.getNumFrames() / 2)
    channels = filtered.getChannelIds()
    # Cap at 20 channels to keep the plot readable.
    if len(channels) > 20:
        channels = channels[0:20]
    sw.TimeseriesWidget(recording=filtered,
                        trange=[midpoint - 4000, midpoint],
                        channels=channels, width=12, height=5).plot()
    save_plot(self.jpg_out)
def createSession(self):
    """Build a timeseries widget over the first 10000 frames of the recording,
    materialized in memory via a numpy-backed extractor."""
    recording = se.MdaRecordingExtractor(
        dataset_directory=self._recording_directory, download=False)
    recording = se.SubRecordingExtractor(
        parent_recording=recording, start_frame=0, end_frame=10000)
    # Copy the clipped traces into memory so the widget does not re-read disk.
    recording = se.NumpyRecordingExtractor(
        timeseries=recording.getTraces(),
        samplerate=recording.getSamplingFrequency())
    return SFW.TimeseriesWidget(recording=recording)
def test_mda_extractor(self):
    """Round-trip the recording and sorting through MDA files and verify equality."""
    rec_path = self.test_dir + '/mda'
    sort_path = rec_path + '/firings_true.mda'
    se.MdaRecordingExtractor.write_recording(self.RX, rec_path)
    se.MdaSortingExtractor.write_sorting(self.SX, sort_path)
    rx_roundtrip = se.MdaRecordingExtractor(rec_path)
    sx_roundtrip = se.MdaSortingExtractor(sort_path)
    self._check_recording_return_types(rx_roundtrip)
    self._check_recordings_equal(self.RX, rx_roundtrip)
    self._check_sorting_return_types(sx_roundtrip)
    self._check_sortings_equal(self.SX, sx_roundtrip)
def create_dumpable_extractors_from_existing(folder, RX, SX):
    """Write RX/SX to `folder` (MDA + NPZ) and return dumpable extractors.

    Random channel locations are filled in when RX has none, since the MDA
    writer requires a 'location' channel property.
    """
    folder = Path(folder)
    if 'location' not in RX.get_shared_channel_property_names():
        RX.set_channel_locations(np.random.randn(RX.get_num_channels(), 2))
    se.MdaRecordingExtractor.write_recording(RX, folder)
    se.NpzSortingExtractor.write_sorting(SX, folder / 'sorting.npz')
    return (se.MdaRecordingExtractor(folder),
            se.NpzSortingExtractor(folder / 'sorting.npz'))
def run(self):
    """Plot unit waveforms (bandpass-filtered, capped at 20 channels and ~20
    units) and save the figure to self.jpg_out."""
    raw = si.MdaRecordingExtractor(
        dataset_directory=self.recording_dir, download=True)
    filtered = st.filters.bandpass_filter(
        recording=raw, freq_min=300, freq_max=6000)
    sorting = si.MdaSortingExtractor(firings_file=self.firings)
    channels = filtered.getChannelIds()
    if len(channels) > 20:
        channels = channels[0:20]
    units = sorting.getUnitIds()
    if len(units) > 20:
        # Subsample evenly down to roughly 20 units.
        units = units[::int(len(units) / 20)]
    sw.UnitWaveformsWidget(recording=filtered, sorting=sorting,
                           channels=channels, unit_ids=units).plot()
    save_plot(self.jpg_out)
def yass_example(download=True, set_id=1):
    """Return (recording, ground-truth sorting) for a VisAPy MEA dataset.

    Parameters
    ----------
    download: bool
        Whether to download the recording data.
    set_id: int
        Dataset number, 1 through 6.

    Raises
    ------
    Exception
        If set_id is not in 1..6.
    """
    if set_id not in range(1, 7):
        raise Exception(
            # Fixed typo in the message ('betewen' -> 'between').
            'Invalid ID for yass_example {} is not between 1..6'.format(set_id))
    dsdir = 'kbucket://15734439d8cf/groundtruth/visapy_mea/set{}'.format(set_id)
    IX = se.MdaRecordingExtractor(dataset_directory=dsdir, download=download)
    path1 = os.path.join(dsdir, 'firings_true.mda')
    print(path1)
    OX = se.MdaSortingExtractor(path1)
    return (IX, OX)
def toy_example(duration=10, num_channels=4, sampling_frequency=30000.0, K=10,
                dumpable=False, dump_folder=None, seed=None):
    """Create a synthetic (recording, sorting) extractor pair.

    When dumpable is True, the pair is written to dump_folder (default
    'toy_example') and re-opened as MdaRecordingExtractor /
    NpzSortingExtractor so both objects can be serialized.
    """
    upsamplefac = 13
    waveforms, geom = synthesize_random_waveforms(
        K=K, M=num_channels, average_peak_amplitude=-100,
        upsamplefac=upsamplefac, seed=seed)
    times, labels = synthesize_random_firings(
        K=K, duration=duration, sampling_frequency=sampling_frequency, seed=seed)
    labels = labels.astype(np.int64)
    SX = se.NumpySortingExtractor()
    SX.set_times_labels(times, labels)
    X = synthesize_timeseries(
        sorting=SX, waveforms=waveforms, noise_level=10,
        sampling_frequency=sampling_frequency, duration=duration,
        waveform_upsamplefac=upsamplefac, seed=seed)
    SX.set_sampling_frequency(sampling_frequency)
    RX = se.NumpyRecordingExtractor(
        timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
    RX.is_filtered = True
    if dumpable:
        folder = Path(dump_folder if dump_folder is not None else 'toy_example')
        se.MdaRecordingExtractor.write_recording(RX, folder)
        RX = se.MdaRecordingExtractor(folder)
        se.NpzSortingExtractor.write_sorting(SX, folder / 'sorting.npz')
        SX = se.NpzSortingExtractor(folder / 'sorting.npz')
    return RX, SX
def run(self):
    """Run MountainSort4 on the dataset and write firings to self.firings_out.

    A non-positive (or unset) NUM_WORKERS means "let the sorter decide".
    """
    recording = si.MdaRecordingExtractor(self.dataset_dir)
    num_workers = int(os.environ.get('NUM_WORKERS', -1))
    if num_workers <= 0:
        num_workers = None
    sorting = sf.sorters.mountainsort4(
        recording=recording,
        detect_sign=self.detect_sign,
        adjacency_radius=self.adjacency_radius,
        freq_min=self.freq_min,
        freq_max=self.freq_max,
        whiten=self.whiten,
        clip_size=self.clip_size,
        detect_threshold=self.detect_threshold,
        detect_interval=self.detect_interval,
        noise_overlap_threshold=self.noise_overlap_threshold,
        num_workers=num_workers)
    si.MdaSortingExtractor.writeSorting(sorting=sorting,
                                        save_path=self.firings_out)
def run(self):
    """Preprocess (optional bandpass filter and whitening) and run
    MountainSort4, writing firings to self.firings_out."""
    import spikeextractors as se
    import spiketoolkit as st
    import ml_ms4alg
    print('MountainSort4......')
    recording = se.MdaRecordingExtractor(self.recording_dir)
    num_workers = os.environ.get('NUM_WORKERS', None)
    if num_workers:
        num_workers = int(num_workers)
    # Bandpass filter when either cutoff is provided.
    if self.freq_min or self.freq_max:
        recording = st.preprocessing.bandpass_filter(
            recording=recording, freq_min=self.freq_min, freq_max=self.freq_max)
    # Optional whitening.
    if self.whiten:
        recording = st.preprocessing.whiten(recording=recording)
    # Sort.
    sorting = ml_ms4alg.mountainsort4(
        recording=recording,
        detect_sign=self.detect_sign,
        adjacency_radius=self.adjacency_radius,
        clip_size=self.clip_size,
        detect_threshold=self.detect_threshold,
        detect_interval=self.detect_interval,
        num_workers=num_workers)
    se.MdaSortingExtractor.writeSorting(sorting=sorting,
                                        save_path=self.firings_out)
def create_dumpable_extractors(folder, duration=10, num_channels=4,
                               sampling_frequency=30000.0, K=10, seed=None):
    """Generate a toy recording/sorting pair and return dumpable extractors
    backed by MDA/NPZ files written under `folder`."""
    RX, SX = toy_example(duration=duration, num_channels=num_channels, K=K,
                         sampling_frequency=sampling_frequency, seed=seed)
    folder = Path(folder)
    se.MdaRecordingExtractor.write_recording(RX, folder)
    se.NpzSortingExtractor.write_sorting(SX, folder / 'sorting.npz')
    return (se.MdaRecordingExtractor(folder),
            se.NpzSortingExtractor(folder / 'sorting.npz'))
def test_dump_load_multi_sub_extractor(self):
    """Check dumpability of multi-, sub-, and combined extractors built on
    MDA-backed recordings and sortings."""
    # Generate dumpable (file-backed) base extractors first.
    rec_path = self.test_dir + '/mda'
    sort_path = rec_path + '/firings_true.mda'
    se.MdaRecordingExtractor.write_recording(self.RX, rec_path)
    se.MdaSortingExtractor.write_sorting(self.SX, sort_path)
    rx = se.MdaRecordingExtractor(rec_path)
    sx = se.MdaSortingExtractor(sort_path)
    check_dumping(se.MultiRecordingChannelExtractor(recordings=[rx, rx, rx]))
    check_dumping(se.MultiRecordingTimeExtractor(recordings=[rx, rx, rx], ))
    check_dumping(se.SubRecordingExtractor(rx, channel_ids=[0, 1]))
    check_dumping(se.SubSortingExtractor(sx, unit_ids=[1, 2]))
    check_dumping(se.MultiSortingExtractor(sortings=[sx, sx, sx]))
def run(self):
    """Compute per-unit metrics and write them as a JSON list to self.json_out.

    SNR is computed per unit from a template recomputed on a bandpass-filtered
    single-channel view of that unit's peak channel.
    """
    import spikewidgets as sw
    R0 = si.MdaRecordingExtractor(
        dataset_directory=self.recording_dir, download=True)
    if (self.channel_ids) and (len(self.channel_ids) > 0):
        R0 = si.SubRecordingExtractor(
            parent_recording=R0, channel_ids=self.channel_ids)
    recording = R0
    sorting = si.MdaSortingExtractor(firings_file=self.firings)
    unit_ids = self.unit_ids
    if (not unit_ids) or (len(unit_ids) == 0):
        unit_ids = sorting.getUnitIds()
    channel_noise_levels = compute_channel_noise_levels(recording=recording)
    # Templates on the full (unfiltered) recording, capped at 100 events/unit.
    templates = compute_unit_templates(
        recording=recording, sorting=sorting, unit_ids=unit_ids, max_num=100)
    ret = []
    for i, unit_id in enumerate(unit_ids):
        template = templates[i]
        # Peak channel = largest peak-to-peak template amplitude.
        max_p2p_amps_on_channels = np.max(template, axis=1) - np.min(template, axis=1)
        peak_channel_index = np.argmax(max_p2p_amps_on_channels)
        peak_channel = recording.getChannelIds()[peak_channel_index]
        # Recompute the template on a bandpass-filtered single-channel view
        # for the SNR estimate.
        # NOTE(review): channel_ids is given the positional index here, not
        # the id — this only matches when ids are 0..N-1; confirm upstream.
        R1 = si.SubRecordingExtractor(
            parent_recording=recording, channel_ids=[peak_channel_index])
        R1f = sw.lazyfilters.bandpass_filter(
            recording=R1, freq_min=300, freq_max=6000)
        templates2 = compute_unit_templates(
            recording=R1f, sorting=sorting, unit_ids=[unit_id], max_num=100)
        template2 = templates2[0]
        info0 = dict()
        info0['unit_id'] = int(unit_id)
        info0['snr'] = np.max(np.abs(template2)) / channel_noise_levels[peak_channel_index]
        # Fixed: peak_channel is already a channel id; the original indexed
        # getChannelIds() with it a second time, mis-reporting the channel
        # (or raising) whenever ids are not the identity mapping 0..N-1.
        info0['peak_channel'] = int(peak_channel)
        train = sorting.getUnitSpikeTrain(unit_id=unit_id)
        info0['num_events'] = int(len(train))
        info0['firing_rate'] = float(
            len(train) / (recording.getNumFrames() / recording.getSamplingFrequency()))
        ret.append(info0)
    write_json_file(self.json_out, ret)
def run(self):
    """Run MountainSort4 on the (optionally channel-restricted) recording and
    write firings to self.firings_out."""
    recording = si.MdaRecordingExtractor(self.recording_dir)
    if len(self.channels) > 0:
        recording = si.SubRecordingExtractor(
            parent_recording=recording, channel_ids=self.channels)
    # Unset/empty NUM_WORKERS -> None, i.e. sorter default.
    num_workers = os.environ.get('NUM_WORKERS', None)
    if num_workers:
        num_workers = int(num_workers)
    sorting = sf.sorters.mountainsort4(
        recording=recording,
        detect_sign=self.detect_sign,
        adjacency_radius=self.adjacency_radius,
        freq_min=self.freq_min,
        freq_max=self.freq_max,
        whiten=self.whiten,
        clip_size=self.clip_size,
        detect_threshold=self.detect_threshold,
        detect_interval=self.detect_interval,
        noise_overlap_threshold=self.noise_overlap_threshold,
        num_workers=num_workers
    )
    si.MdaSortingExtractor.writeSorting(sorting=sorting,
                                        save_path=self.firings_out)
def run(self):
    """Run IronClust spike sorting and write the firings to self.firings_out.

    Requires the IRONCLUST_PATH environment variable. Dataset parameters are
    read from the recording directory. The temp directory is removed unless
    self._keep_temp_files is set.

    Raises
    ------
    Exception
        If IRONCLUST_PATH is not set.
    """
    ironclust_path = os.environ.get('IRONCLUST_PATH', None)
    if not ironclust_path:
        raise Exception('Environment variable not set: IRONCLUST_PATH')
    # Random suffix so concurrent runs do not collide on the temp directory.
    code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
    tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-tmp-' + code
    try:
        recording = se.MdaRecordingExtractor(self.recording_dir)
        params = read_dataset_params(self.recording_dir)
        if len(self.channels) > 0:
            recording = se.SubRecordingExtractor(
                parent_recording=recording, channel_ids=self.channels)
        if not os.path.exists(tmpdir):
            os.mkdir(tmpdir)
        sorting = ironclust_helper(
            recording=recording,
            tmpdir=tmpdir,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            detect_threshold=self.detect_threshold,
            merge_thresh=self.merge_thresh,
            freq_min=self.freq_min,
            freq_max=self.freq_max,
            pc_per_chan=self.pc_per_chan,
            prm_template_name=self.prm_template_name,
            ironclust_path=ironclust_path,
            params=params,
        )
        se.MdaSortingExtractor.writeSorting(sorting=sorting,
                                            save_path=self.firings_out)
    finally:
        # One cleanup path replaces the duplicated bare-except/success-path
        # cleanup; honors _keep_temp_files on both paths as before.
        if os.path.exists(tmpdir) and not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
def getRecording(self, download=True):
    """Return an MDA recording extractor for this object's kbucket path."""
    return si.MdaRecordingExtractor(
        dataset_directory=self._kbucket_path, download=download)
def toy_example(duration: float = 10., num_channels: int = 4, sampling_frequency: float = 30000.,
                K: int = 10, dumpable: bool = False,
                dump_folder: Optional[Union[str, Path]] = None,
                seed: Optional[int] = None):
    """
    Create toy recording and sorting extractors.

    Parameters
    ----------
    duration: float
        Duration in s (default 10)
    num_channels: int
        Number of channels (default 4)
    sampling_frequency: float
        Sampling frequency (default 30000)
    K: int
        Number of units (default 10)
    dumpable: bool
        If True, objects are dumped to file and become 'dumpable'
    dump_folder: str or Path
        Path to dump folder (if None, 'toy_example' is used)
    seed: int
        Seed for random initialization

    Returns
    -------
    recording: RecordingExtractor
        The output recording extractor. If dumpable is False it's a
        NumpyRecordingExtractor, otherwise it's an MdaRecordingExtractor
    sorting: SortingExtractor
        The output sorting extractor. If dumpable is False it's a
        NumpySortingExtractor, otherwise it's an NpzSortingExtractor
    """
    # Docstring fixes: the actual default dump folder is 'toy_example' (the
    # text claimed 'test'), and the non-dumpable sorting is a
    # NumpySortingExtractor (the text said NumpyRecordingExtractor).
    upsamplefac = 13
    waveforms, geom = synthesize_random_waveforms(K=K, M=num_channels,
                                                  average_peak_amplitude=-100,
                                                  upsamplefac=upsamplefac,
                                                  seed=seed)
    times, labels = synthesize_random_firings(
        K=K, duration=duration, sampling_frequency=sampling_frequency, seed=seed)
    labels = labels.astype(np.int64)
    SX = se.NumpySortingExtractor()
    SX.set_times_labels(times, labels)
    X = synthesize_timeseries(sorting=SX, waveforms=waveforms, noise_level=10,
                              sampling_frequency=sampling_frequency,
                              duration=duration,
                              waveform_upsamplefac=upsamplefac, seed=seed)
    SX.set_sampling_frequency(sampling_frequency)
    RX = se.NumpyRecordingExtractor(timeseries=X,
                                    sampling_frequency=sampling_frequency,
                                    geom=geom)
    RX.is_filtered = True
    if dumpable:
        if dump_folder is None:
            dump_folder = 'toy_example'
        dump_folder = Path(dump_folder)
        se.MdaRecordingExtractor.write_recording(RX, dump_folder)
        RX = se.MdaRecordingExtractor(dump_folder)
        se.NpzSortingExtractor.write_sorting(SX, dump_folder / 'sorting.npz')
        SX = se.NpzSortingExtractor(dump_folder / 'sorting.npz')
    return RX, SX