def run(self):
    """Compare a sorting against ground truth and write the comparison table.

    Loads ``self.firings`` and ``self.firings_true`` as MDA sortings. When
    ``self.units_true`` is non-empty, the ground truth is restricted to those
    unit ids. The comparison table is written twice through
    ``_write_json_file``: as a dict to ``self.json_out`` and as an HTML
    string to ``self.html_out``.
    """
    sorted_units = si.MdaSortingExtractor(firings_file=self.firings)
    ground_truth = si.MdaSortingExtractor(firings_file=self.firings_true)
    if len(self.units_true) > 0:
        ground_truth = si.SubSortingExtractor(
            parent_sorting=ground_truth,
            unit_ids=self.units_true,
        )

    comparison = st.comparison.SortingComparison(ground_truth, sorted_units)
    table = sw.SortingComparisonTable(comparison=comparison).getDataframe()

    # Build both renderings before writing anything.
    table_as_dict = table.transpose().to_dict()
    table_as_html = table.to_html(index=False)

    _write_json_file(table_as_dict, self.json_out)
    # NOTE(review): the HTML string is also written with the JSON writer.
    _write_json_file(table_as_html, self.html_out)
def spikeforest_sort(
    recording_dirname,
    sorter,
    sorting_params,
    _force_run=False,
    _force_save=False
):
    """Run ``sorter`` on an MDA recording directory, reusing cached firings.

    The cache key is built from the sorter name and version, the hash of
    ``recording_dirname`` and ``sorting_params``.

    Args:
        recording_dirname: Directory passed to ``si.MdaRecordingExtractor``
            (and hashed for the cache key).
        sorter: Callable invoked as ``sorter(recording=..., **sorting_params)``;
            must expose ``name`` and ``version`` attributes.
        sorting_params: Keyword arguments forwarded to ``sorter``.
        _force_run: When true, skip the cache lookup and always run the sorter.
        _force_save: When true and a cached result is found, save it again
            under the same key.

    Returns:
        An ``si.MdaSortingExtractor`` over the cached firings file when the
        cache lookup succeeds, otherwise the object returned by ``sorter``.
    """
    recording_hash = kb.computeDirHash(recording_dirname)
    cache_key = {
        'sorter_name': sorter.name,
        'sorter_version': sorter.version,
        'recording': recording_hash,
        'sorting_params': sorting_params,
    }

    if not _force_run:
        print('Looking up in cache...')
        cached_firings = kb.realizeFile(key=cache_key)
        if cached_firings:
            print('Found')
            if _force_save:
                print('Saving')
                kb.saveFile(fname=cached_firings, key=cache_key)
            return si.MdaSortingExtractor(firings_file=cached_firings)

    # Cache miss (or forced run): sort, then store the firings under the key.
    # NOTE(review): the firings are staged in a fixed file name in the current
    # working directory; concurrent runs in the same directory would collide.
    staged_firings = 'tmp_firings.mda'
    result = sorter(
        recording=si.MdaRecordingExtractor(recording_dirname),
        **sorting_params
    )
    si.MdaSortingExtractor.writeSorting(sorting=result, save_path=staged_firings)
    kb.saveFile(fname=staged_firings, key=cache_key)
    return result
def run(self):
    """Plot cross-correlograms for ``self.firings`` and save the figure.

    The recording in ``self.recording_dir`` is opened without downloading and,
    when ``self.channels`` is non-empty, restricted to those channel ids. It
    is used only to supply the sampling frequency. The plot is saved through
    ``save_plot`` to ``self.plot_out``.
    """
    source = si.MdaRecordingExtractor(
        dataset_directory=self.recording_dir,
        download=False,
    )
    if len(self.channels) > 0:
        source = si.SubRecordingExtractor(
            parent_recording=source,
            channel_ids=self.channels,
        )

    unit_sorting = si.MdaSortingExtractor(firings_file=self.firings)
    widget = sw.CrossCorrelogramsWidget(
        samplerate=source.getSamplingFrequency(),
        sorting=unit_sorting,
    )
    widget.plot()
    save_plot(self.plot_out)
def run(self):
    """Plot unit waveforms for ``self.firings`` and save the figure.

    The recording in ``self.recording_dir`` is restricted to
    ``self.channels`` when that list is non-empty. The plot is saved through
    ``save_plot`` to ``self.plot_out``.
    """
    source = si.MdaRecordingExtractor(dataset_directory=self.recording_dir)
    if len(self.channels) > 0:
        source = si.SubRecordingExtractor(
            parent_recording=source,
            channel_ids=self.channels,
        )

    unit_sorting = si.MdaSortingExtractor(firings_file=self.firings)
    widget = sw.UnitWaveformsWidget(recording=source, sorting=unit_sorting)
    widget.plot()
    save_plot(self.plot_out)
def run(self):
    """Compute per-unit summary info and write it to ``self.json_out``.

    The recording is optionally restricted to ``self.channel_ids`` and
    band-pass filtered (300-6000 Hz). Templates are computed from an
    in-memory copy of the first 1e6 frames; noise levels, spike counts and
    firing rates use the full filtered recording and full sorting.

    Each output entry holds ``unit_id``, ``snr``, ``peak_channel``,
    ``num_events`` and ``firing_rate``.
    """
    raw = si.MdaRecordingExtractor(
        dataset_directory=self.recording_dir,
        download=True,
    )
    if (self.channel_ids) and (len(self.channel_ids) > 0):
        raw = si.SubRecordingExtractor(
            parent_recording=raw,
            channel_ids=self.channel_ids,
        )
    recording = sw.lazyfilters.bandpass_filter(
        recording=raw,
        freq_min=300,
        freq_max=6000,
    )
    sorting = si.MdaSortingExtractor(firings_file=self.firings)

    # Templates are estimated from a leading segment held in memory.
    segment_end = 10 ** 6
    recording_head = MemoryRecordingExtractor(
        parent_recording=si.SubRecordingExtractor(
            parent_recording=recording,
            start_frame=0,
            end_frame=segment_end,
        )
    )
    sorting_head = si.SubSortingExtractor(
        parent_sorting=sorting,
        start_frame=0,
        end_frame=segment_end,
    )

    unit_ids = self.unit_ids
    if (not unit_ids) or (len(unit_ids) == 0):
        # No explicit selection: report on every unit in the sorting.
        unit_ids = sorting.getUnitIds()

    noise_levels = compute_channel_noise_levels(recording=recording)

    print('computing templates...')
    templates = compute_unit_templates(
        recording=recording_head,
        sorting=sorting_head,
        unit_ids=unit_ids,
    )
    print('.')

    unit_infos = []
    for position, unit_id in enumerate(unit_ids):
        template = templates[position]
        unit_number = int(unit_id)
        snr = compute_template_snr(template, noise_levels)
        # The row with the largest absolute amplitude marks the peak channel.
        peak_row = np.argmax(np.max(np.abs(template), axis=1))
        peak_channel = int(recording.getChannelIds()[peak_row])
        spike_train = sorting.getUnitSpikeTrain(unit_id=unit_id)
        event_count = int(len(spike_train))
        duration_sec = recording.getNumFrames() / recording.getSamplingFrequency()
        unit_infos.append({
            'unit_id': unit_number,
            'snr': snr,
            'peak_channel': peak_channel,
            'num_events': event_count,
            'firing_rate': float(len(spike_train) / duration_sec),
        })

    write_json_file(self.json_out, unit_infos)
def getSortingTrue(self):
    """Return the ground-truth sorting for this dataset, or ``None``.

    The directory listing of ``self._kbucket_path`` is checked for a
    ``firings_true.mda`` entry; when present it is opened as an MDA sorting.
    """
    listing = mlp.readDir(self._kbucket_path)
    if 'firings_true.mda' not in listing['files']:
        return None
    return si.MdaSortingExtractor(
        firings_file=self._kbucket_path + '/firings_true.mda'
    )
def test_mda_extractor(self):
    """Round-trip the fixture recording and sorting through MDA files."""
    mda_dir = self.test_dir + '/mda'
    firings_file = mda_dir + '/firings_true.mda'

    # Write both fixtures out, then read them back through the extractors.
    se.MdaRecordingExtractor.write_recording(self.RX, mda_dir)
    se.MdaSortingExtractor.write_sorting(self.SX, firings_file)
    recording_back = se.MdaRecordingExtractor(mda_dir)
    sorting_back = se.MdaSortingExtractor(firings_file)

    self._check_recording_return_types(recording_back)
    self._check_recordings_equal(self.RX, recording_back)
    self._check_sorting_return_types(sorting_back)
    self._check_sortings_equal(self.SX, sorting_back)
def get_result_from_folder(output_folder: Union[str, Path]):
    """Load the sorting result stored under ``<output_folder>/tmp``.

    Reads the sampling rate from ``samplerate.txt`` and opens ``firings.mda``
    with that rate.

    Args:
        output_folder: Sorter output folder, as a string or ``Path``.

    Returns:
        An ``se.MdaSortingExtractor`` over the firings file.
    """
    tmp_folder = Path(output_folder) / 'tmp'
    firings_path = str(tmp_folder / 'firings.mda')
    # The file holds a single number: the sampling rate.
    sampling_rate = float((tmp_folder / 'samplerate.txt').read_text())
    return se.MdaSortingExtractor(
        file_path=firings_path,
        sampling_frequency=sampling_rate,
    )
def run(self):
    """Plot unit waveforms on a band-pass filtered recording and save them.

    The recording is filtered to 300-6000 Hz. With more than 20 channels only
    the first 20 are plotted; with more than 20 units every
    ``len(units) // 20``-th unit is plotted. The figure is saved through
    ``save_plot`` to ``self.jpg_out``.
    """
    limit = 20

    raw = si.MdaRecordingExtractor(
        dataset_directory=self.recording_dir,
        download=True,
    )
    filtered = st.filters.bandpass_filter(
        recording=raw,
        freq_min=300,
        freq_max=6000,
    )
    sorting = si.MdaSortingExtractor(firings_file=self.firings)

    shown_channels = filtered.getChannelIds()
    if len(shown_channels) > limit:
        shown_channels = shown_channels[:limit]

    shown_units = sorting.getUnitIds()
    if len(shown_units) > limit:
        # Evenly strided subsample rather than the first few units.
        stride = len(shown_units) // limit
        shown_units = shown_units[::stride]

    widget = sw.UnitWaveformsWidget(
        recording=filtered,
        sorting=sorting,
        channels=shown_channels,
        unit_ids=shown_units,
    )
    widget.plot()
    save_plot(self.jpg_out)
def yass_example(download=True, set_id=1):
    """Load one of the ``visapy_mea`` ground-truth example sets.

    Args:
        download: Forwarded to ``se.MdaRecordingExtractor``.
        set_id: Which set to load; must be one of 1..6.

    Returns:
        A ``(recording, sorting_true)`` tuple: the MDA recording extractor for
        the set and the MDA sorting extractor over its ``firings_true.mda``.

    Raises:
        ValueError: If ``set_id`` is not one of 1..6. ``ValueError`` is an
            ``Exception`` subclass, so existing ``except Exception`` callers
            keep working.
    """
    if set_id not in range(1, 7):
        raise ValueError(
            'Invalid ID for yass_example {} is not between 1..6'.format(
                set_id))

    dsdir = 'kbucket://15734439d8cf/groundtruth/visapy_mea/set{}'.format(
        set_id)
    recording = se.MdaRecordingExtractor(dataset_directory=dsdir,
                                         download=download)
    # dsdir is a URL, so join with '/' rather than the OS path separator
    # (os.path.join would insert a backslash on Windows).
    firings_path = dsdir + '/firings_true.mda'
    print(firings_path)
    sorting_true = se.MdaSortingExtractor(firings_path)
    return (recording, sorting_true)
def test_dump_load_multi_sub_extractor(self):
    """Check that multi- and sub-extractors built on MDA files can be dumped."""
    # Write the fixtures as MDA so the base extractors are dumpable.
    mda_dir = self.test_dir + '/mda'
    firings_file = mda_dir + '/firings_true.mda'
    se.MdaRecordingExtractor.write_recording(self.RX, mda_dir)
    se.MdaSortingExtractor.write_sorting(self.SX, firings_file)
    recording = se.MdaRecordingExtractor(mda_dir)
    sorting = se.MdaSortingExtractor(firings_file)

    # Recording wrappers: channel-wise multi, time-wise multi, channel subset.
    multi_channel = se.MultiRecordingChannelExtractor(
        recordings=[recording, recording, recording]
    )
    check_dumping(multi_channel)

    multi_time = se.MultiRecordingTimeExtractor(
        recordings=[recording, recording, recording]
    )
    check_dumping(multi_time)

    sub_recording = se.SubRecordingExtractor(recording, channel_ids=[0, 1])
    check_dumping(sub_recording)

    # Sorting wrappers: unit subset and multi-sorting.
    sub_sorting = se.SubSortingExtractor(sorting, unit_ids=[1, 2])
    check_dumping(sub_sorting)

    multi_sorting = se.MultiSortingExtractor(
        sortings=[sorting, sorting, sorting]
    )
    check_dumping(multi_sorting)
def run(self):
    """Compute per-unit summary info and write it to ``self.json_out``.

    The recording is optionally restricted to ``self.channel_ids``. Templates
    (at most 100 events per unit) are computed on the unfiltered recording to
    find each unit's peak channel by peak-to-peak amplitude. The SNR is then
    taken from a template recomputed on that single channel after a
    300-6000 Hz band-pass, divided by that channel's noise level.

    Each output entry holds ``unit_id``, ``snr``, ``peak_channel``,
    ``num_events`` and ``firing_rate``.
    """
    import spikewidgets as sw

    recording = si.MdaRecordingExtractor(
        dataset_directory=self.recording_dir,
        download=True,
    )
    if (self.channel_ids) and (len(self.channel_ids) > 0):
        recording = si.SubRecordingExtractor(
            parent_recording=recording,
            channel_ids=self.channel_ids,
        )
    sorting = si.MdaSortingExtractor(firings_file=self.firings)

    unit_ids = self.unit_ids
    if (not unit_ids) or (len(unit_ids) == 0):
        # No explicit selection: report on every unit in the sorting.
        unit_ids = sorting.getUnitIds()

    channel_ids = recording.getChannelIds()
    channel_noise_levels = compute_channel_noise_levels(recording=recording)

    # Templates come from the whole recording (no leading-segment subset).
    templates = compute_unit_templates(
        recording=recording,
        sorting=sorting,
        unit_ids=unit_ids,
        max_num=100,
    )

    ret = []
    for i, unit_id in enumerate(unit_ids):
        template = templates[i]
        max_p2p_amps_on_channels = np.max(template, axis=1) - np.min(template, axis=1)
        # peak_channel_index is a position into channel_ids and
        # channel_noise_levels; peak_channel is the actual channel id.
        peak_channel_index = int(np.argmax(max_p2p_amps_on_channels))
        peak_channel = channel_ids[peak_channel_index]

        # SubRecordingExtractor selects by channel id, not by position.
        peak_recording = si.SubRecordingExtractor(
            parent_recording=recording,
            channel_ids=[peak_channel],
        )
        peak_filtered = sw.lazyfilters.bandpass_filter(
            recording=peak_recording,
            freq_min=300,
            freq_max=6000,
        )
        peak_template = compute_unit_templates(
            recording=peak_filtered,
            sorting=sorting,
            unit_ids=[unit_id],
            max_num=100,
        )[0]

        train = sorting.getUnitSpikeTrain(unit_id=unit_id)

        info0 = dict()
        info0['unit_id'] = int(unit_id)
        # Cast to a plain float, like the other fields, before serialization.
        info0['snr'] = float(
            np.max(np.abs(peak_template)) / channel_noise_levels[peak_channel_index]
        )
        # peak_channel is already an id; do not index channel_ids with it.
        info0['peak_channel'] = int(peak_channel)
        info0['num_events'] = int(len(train))
        info0['firing_rate'] = float(
            len(train) / (recording.getNumFrames() / recording.getSamplingFrequency())
        )
        ret.append(info0)

    write_json_file(self.json_out, ret)
def get_result_from_folder(output_folder):
    """Load the sorting stored as ``firings.mda`` inside ``output_folder``.

    Args:
        output_folder: Sorter output folder, as a string or ``pathlib.Path``.
            A plain string is accepted because it is coerced to ``Path``
            before joining; ``str / str`` would raise ``TypeError``.

    Returns:
        An ``se.MdaSortingExtractor`` over ``<output_folder>/firings.mda``.
    """
    from pathlib import Path

    firings_path = Path(output_folder) / 'firings.mda'
    return se.MdaSortingExtractor(str(firings_path))
def sorting(self):
    """Return an MDA sorting extractor over this object's ``firings`` entry."""
    firings_path = self._obj['firings']
    return si.MdaSortingExtractor(firings_file=firings_path)
def sortingTrue(self):
    """Return an MDA sorting extractor over ``firings_true.mda`` in ``self.directory()``."""
    firings_true_path = self.directory() + '/firings_true.mda'
    return si.MdaSortingExtractor(firings_file=firings_true_path)