def run(self):
    R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
    if (self.channel_ids) and (len(self.channel_ids) > 0):
        R0 = si.SubRecordingExtractor(parent_recording=R0, channel_ids=self.channel_ids)
    recording = sw.lazyfilters.bandpass_filter(recording=R0, freq_min=300, freq_max=6000)
    sorting = si.MdaSortingExtractor(firings_file=self.firings)
    # Use only the first 1e6 frames (loaded into memory) for template computation
    ef = int(1e6)
    recording_sub = si.SubRecordingExtractor(parent_recording=recording, start_frame=0, end_frame=ef)
    recording_sub = MemoryRecordingExtractor(parent_recording=recording_sub)
    sorting_sub = si.SubSortingExtractor(parent_sorting=sorting, start_frame=0, end_frame=ef)
    unit_ids = self.unit_ids
    if (not unit_ids) or (len(unit_ids) == 0):
        unit_ids = sorting.getUnitIds()
    channel_noise_levels = compute_channel_noise_levels(recording=recording)
    print('computing templates...')
    templates = compute_unit_templates(recording=recording_sub, sorting=sorting_sub, unit_ids=unit_ids)
    print('.')
    ret = []
    for i, unit_id in enumerate(unit_ids):
        template = templates[i]
        info0 = dict()
        info0['unit_id'] = int(unit_id)
        info0['snr'] = compute_template_snr(template, channel_noise_levels)
        peak_channel_index = np.argmax(np.max(np.abs(template), axis=1))
        info0['peak_channel'] = int(recording.getChannelIds()[peak_channel_index])
        train = sorting.getUnitSpikeTrain(unit_id=unit_id)
        info0['num_events'] = int(len(train))
        info0['firing_rate'] = float(len(train) / (recording.getNumFrames() / recording.getSamplingFrequency()))
        ret.append(info0)
    write_json_file(self.json_out, ret)
def run(self, payload, next_elem):
    spif_params_dict = {}
    probe_file = None
    channel_map = None
    channel_groups = None
    for param in self.param_list:
        if param["name"] == "probe_path":
            probe_file = param["value"]
        elif param["name"] == "channel_map":
            channel_map = param["value"]
        elif param["name"] == "channel_groups":
            channel_groups = param["value"]
        else:
            spif_params_dict[param["name"]] = param["value"]
    recording = self._spif_class(**spif_params_dict)
    if probe_file:
        recording = recording.load_probe_file(probe_file, channel_map, channel_groups)
    else:
        if channel_map:
            assert np.all([chan in recording.get_channel_ids() for chan in channel_map]), \
                "all channel_ids in 'channel_map' must be in recording channel ids"
            recording = se.SubRecordingExtractor(recording, channel_ids=channel_map)
        if channel_groups:
            recording.set_channel_groups(recording.get_channel_ids(), channel_groups)
    return recording
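# Illustrative shape of the self.param_list consumed by run() above; the
# parameter names and values here are hypothetical placeholders, not taken
# from the source:
#
#   param_list = [
#       {'name': 'file_path', 'value': '/path/to/raw.dat'},    # forwarded to the extractor class
#       {'name': 'probe_path', 'value': '/path/to/probe.prb'},
#       {'name': 'channel_map', 'value': [0, 1, 2, 3]},
#   ]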
def run(self):
    code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
    tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/yass-tmp-' + code
    # num_workers = os.environ.get('NUM_WORKERS', 1)
    # print('num_workers: {}'.format(num_workers))
    try:
        recording = SFMdaRecordingExtractor(self.recording_dir)
        if len(self.channels) > 0:
            recording = se.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
        if not os.path.exists(tmpdir):
            os.mkdir(tmpdir)
        sorting, _ = yass_helper(
            recording=recording,
            output_folder=tmpdir,
            probe_file=None,
            file_name=None,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            template_width_ms=self.template_width_ms,
            filter=self.filter)
        SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
        # shutil.copyfile(yaml_file, self.paramfile_out)
    except:
        if os.path.exists(tmpdir) and not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
        raise
    if not getattr(self, '_keep_temp_files', False):
        shutil.rmtree(tmpdir)
def old_fetch_average_waveform_plot_data(recording_object, sorting_object, unit_id):
    import labbox_ephys as le
    R = le.LabboxEphysRecordingExtractor(recording_object)
    S = le.LabboxEphysSortingExtractor(sorting_object)
    # Restrict to the first 30 seconds of the recording
    start_frame = 0
    end_frame = int(R.get_sampling_frequency() * 30)
    R0 = se.SubRecordingExtractor(parent_recording=R, start_frame=start_frame, end_frame=end_frame)
    S0 = se.SubSortingExtractor(parent_sorting=S, start_frame=start_frame, end_frame=end_frame)
    times0 = S0.get_unit_spike_train(unit_id=unit_id)
    if len(times0) == 0:
        # no waveforms found
        return dict(channel_id=None, average_waveform=None)
    try:
        average_waveform = st.postprocessing.get_unit_templates(
            recording=R0, sorting=S0, unit_ids=[unit_id])[0]
    except:
        raise Exception(f'Error getting unit templates for unit {unit_id}')
    channel_maximums = np.max(np.abs(average_waveform), axis=1)
    maxchan_index = np.argmax(channel_maximums)
    maxchan_id = R0.get_channel_ids()[maxchan_index]
    return dict(channel_id=maxchan_id, average_waveform=average_waveform[maxchan_index, :].tolist())
def run(self):
    code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
    tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/spyking-circus-tmp-' + code
    num_workers = int(os.environ.get('NUM_WORKERS', 2))
    try:
        recording = si.MdaRecordingExtractor(self.recording_dir)
        if len(self.channels) > 0:
            recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
        if not os.path.exists(tmpdir):
            os.mkdir(tmpdir)
        sorting = sf.sorters.spyking_circus(
            recording=recording,
            output_folder=tmpdir,
            probe_file=None,
            file_name=None,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            spike_thresh=self.spike_thresh,
            template_width_ms=self.template_width_ms,
            filter=self.filter,
            merge_spikes=True,
            n_cores=num_workers,
            electrode_dimensions=None,
            whitening_max_elts=self.whitening_max_elts,
            clustering_max_elts=self.clustering_max_elts
        )
        si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
    except:
        if os.path.exists(tmpdir):
            shutil.rmtree(tmpdir)
        raise
    shutil.rmtree(tmpdir)
def run(self):
    _keep_temp_files = True
    code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
    tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort2-tmp-' + code
    try:
        recording = SFMdaRecordingExtractor(self.recording_dir)
        if len(self.channels) > 0:
            recording = se.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
        if not os.path.exists(tmpdir):
            os.mkdir(tmpdir)
        sorting = kilosort2_helper(
            recording=recording,
            tmpdir=tmpdir,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            detect_threshold=self.detect_threshold,
            merge_thresh=self.merge_thresh,
            freq_min=self.freq_min,
            freq_max=self.freq_max,
            pc_per_chan=self.pc_per_chan,
            minFR=self.minFR)
        SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
    except:
        if os.path.exists(tmpdir):
            if not _keep_temp_files:
                print('removing tmpdir1')
                shutil.rmtree(tmpdir)
        raise
    if not _keep_temp_files:
        print('removing tmpdir2')
        shutil.rmtree(tmpdir)
def run(self):
    ironclust_src = os.environ.get('IRONCLUST_SRC', None)
    if not ironclust_src:
        raise Exception('Environment variable not set: IRONCLUST_SRC')
    code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
    tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-tmp-' + code
    try:
        recording = si.MdaRecordingExtractor(self.recording_dir)
        if len(self.channels) > 0:
            recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
        if not os.path.exists(tmpdir):
            os.mkdir(tmpdir)
        sorting = sf.sorters.ironclust(
            recording=recording,
            tmpdir=tmpdir,  # TODO
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            detect_threshold=self.detect_threshold,
            merge_thresh=self.merge_thresh,
            freq_min=self.freq_min,
            freq_max=self.freq_max,
            pc_per_chan=self.pc_per_chan,
            prm_template_name=self.prm_template_name,
            ironclust_src=ironclust_src
        )
        si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
    except:
        if os.path.exists(tmpdir):
            shutil.rmtree(tmpdir)
        raise
    shutil.rmtree(tmpdir)
def run(self):
    recording = si.MdaRecordingExtractor(dataset_directory=self.recording_dir)
    if len(self.channels) > 0:
        recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
    sorting = si.MdaSortingExtractor(firings_file=self.firings)
    sw.UnitWaveformsWidget(recording=recording, sorting=sorting).plot()
    fname = save_plot(self.plot_out)
def run(self):
    code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
    tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort-tmp-' + code
    try:
        recording = si.MdaRecordingExtractor(self.recording_dir)
        if len(self.channels) > 0:
            recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
        if not os.path.exists(tmpdir):
            os.mkdir(tmpdir)
        sorting = sorters.kilosort(
            recording=recording,
            tmpdir=tmpdir,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            detect_threshold=self.detect_threshold,
            merge_thresh=self.merge_thresh,
            freq_min=self.freq_min,
            freq_max=self.freq_max,
            pc_per_chan=self.pc_per_chan)
        si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
    except:
        if os.path.exists(tmpdir):
            shutil.rmtree(tmpdir)
        raise
    shutil.rmtree(tmpdir)
def run(self):
    tmpdir = _get_tmpdir('jrclust')
    try:
        recording = SFMdaRecordingExtractor(self.recording_dir)
        params = read_dataset_params(self.recording_dir)
        if len(self.channels) > 0:
            recording = se.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
        if not os.path.exists(tmpdir):
            os.mkdir(tmpdir)
        all_params = dict()
        for param0 in self.PARAMETERS:
            all_params[param0.name] = getattr(self, param0.name)
        sorting = jrclust_helper(
            recording=recording,
            tmpdir=tmpdir,
            params=params,
            **all_params,
        )
        SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
    except:
        if os.path.exists(tmpdir):
            if not getattr(self, '_keep_temp_files', False):
                shutil.rmtree(tmpdir)
        raise
    if not getattr(self, '_keep_temp_files', False):
        shutil.rmtree(tmpdir)
def run(self):
    recording = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=False)
    if len(self.channels) > 0:
        recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
    sorting = si.MdaSortingExtractor(firings_file=self.firings)
    sw.CrossCorrelogramsWidget(samplerate=recording.getSamplingFrequency(), sorting=sorting).plot()
    fname = save_plot(self.plot_out)
def _on_view_timeseries(self):
    rx = self._recording.recordingExtractor()
    sf = rx.getSamplingFrequency()
    # Show 10 seconds if the recording file is local, otherwise only 1 second
    if self._recording.recordingFileIsLocal():
        rx = se.SubRecordingExtractor(parent_recording=rx, start_frame=int(sf * 0), end_frame=int(sf * 10))
    else:
        rx = se.SubRecordingExtractor(parent_recording=rx, start_frame=int(sf * 0), end_frame=int(sf * 1))
    rx = st.preprocessing.bandpass_filter(recording=rx, freq_min=300, freq_max=6000)
    self._view = SFW.TimeseriesWidget(recording=rx)
    self.refresh()
def test_ttl_frames_in_sub_multi(self):
    # sub recording
    start_frame = self.example_info['num_frames'] // 3
    end_frame = 2 * self.example_info['num_frames'] // 3
    RX_sub = se.SubRecordingExtractor(self.RX, start_frame=start_frame, end_frame=end_frame)
    original_ttls = self.RX.get_ttl_events()[0]
    ttls_in_sub = original_ttls[np.where((original_ttls >= start_frame) & (original_ttls < end_frame))[0]]
    self.assertTrue(np.array_equal(RX_sub.get_ttl_events()[0], ttls_in_sub - start_frame))

    # multirecording
    RX_multi = se.MultiRecordingTimeExtractor(recordings=[self.RX, self.RX, self.RX])
    ttls_originals = self.RX.get_ttl_events()[0]
    num_ttls = len(ttls_originals)
    self.assertEqual(len(RX_multi.get_ttl_events()[0]), 3 * num_ttls)
    self.assertTrue(np.array_equal(RX_multi.get_ttl_events()[0][:num_ttls], ttls_originals))
    self.assertTrue(np.array_equal(RX_multi.get_ttl_events()[0][num_ttls:2 * num_ttls],
                                   ttls_originals + self.RX.get_num_frames()))
    self.assertTrue(np.array_equal(RX_multi.get_ttl_events()[0][2 * num_ttls:],
                                   ttls_originals + 2 * self.RX.get_num_frames()))
def recordingExtractor(self, download=False):
    X = si.MdaRecordingExtractor(dataset_directory=self.directory(), download=download)
    if 'channels' in self._obj:
        if self._obj['channels']:
            X = si.SubRecordingExtractor(parent_recording=X, channel_ids=self._obj['channels'])
    return X
def __init__(self, context):
    vd.Component.__init__(self)
    self._context = context
    rx = self._context.recording.recordingExtractor()
    sf = rx.getSamplingFrequency()
    print(self._context.recording.recordingFileIsLocal())
    if self._context.recording.recordingFileIsLocal():
        rx = se.SubRecordingExtractor(parent_recording=rx, start_frame=int(sf * 0), end_frame=int(sf * 10))
    else:
        rx = se.SubRecordingExtractor(parent_recording=rx, start_frame=int(sf * 0), end_frame=int(sf * 1))
    rx = st.preprocessing.bandpass_filter(recording=rx, freq_min=300, freq_max=6000)
    self._timeseries_widget = TimeseriesWidget(recording=rx)
def run(self):
    ret = {}
    recording = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=False)
    if len(self.channels) > 0:
        recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
    ret['samplerate'] = recording.getSamplingFrequency()
    ret['num_channels'] = len(recording.getChannelIds())
    ret['duration_sec'] = recording.getNumFrames() / ret['samplerate']
    write_json_file(self.json_out, ret)
def createSession(self):
    recording = SFMdaRecordingExtractor(dataset_directory=self._recording_directory, download=False)
    recording = se.SubRecordingExtractor(parent_recording=recording, start_frame=0, end_frame=10000)
    recording = se.NumpyRecordingExtractor(
        timeseries=recording.get_traces(),
        samplerate=recording.get_sampling_frequency())
    W = SFW.TimeseriesWidget(recording=recording)
    _make_full_browser(W)
    return W
def compute_units_info(*, recording, sorting, channel_ids=[], unit_ids=[]):
    if (channel_ids) and (len(channel_ids) > 0):
        recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=channel_ids)

    # load into memory
    print('Loading recording into RAM...')
    recording = si.NumpyRecordingExtractor(
        timeseries=recording.get_traces(),
        samplerate=recording.get_sampling_frequency())

    # do filtering
    print('Filtering...')
    recording = bandpass_filter(recording=recording, freq_min=300, freq_max=6000)
    recording = si.NumpyRecordingExtractor(
        timeseries=recording.get_traces(),
        samplerate=recording.get_sampling_frequency())

    if (not unit_ids) or (len(unit_ids) == 0):
        unit_ids = sorting.get_unit_ids()

    print('Computing channel noise levels...')
    channel_noise_levels = compute_channel_noise_levels(recording=recording)

    # No longer use subset to compute the templates
    print('Computing unit templates...')
    templates = compute_unit_templates(recording=recording, sorting=sorting, unit_ids=unit_ids, max_num=100)
    print(recording.get_channel_ids())

    ret = []
    for i, unit_id in enumerate(unit_ids):
        print('Unit {} of {} (id={})'.format(i + 1, len(unit_ids), unit_id))
        template = templates[i]
        max_p2p_amps_on_channels = np.max(template, axis=1) - np.min(template, axis=1)
        peak_channel_index = np.argmax(max_p2p_amps_on_channels)
        peak_channel = recording.get_channel_ids()[peak_channel_index]
        peak_signal = np.max(np.abs(template[peak_channel_index, :]))
        info0 = dict()
        info0['unit_id'] = int(unit_id)
        info0['snr'] = peak_signal / channel_noise_levels[peak_channel_index]
        info0['peak_channel'] = int(peak_channel)
        train = sorting.get_unit_spike_train(unit_id=unit_id)
        info0['num_events'] = int(len(train))
        info0['firing_rate'] = float(
            len(train) / (recording.get_num_frames() / recording.get_sampling_frequency()))
        ret.append(info0)
    return ret
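# A minimal usage sketch for compute_units_info, assuming the same `si` alias
# and helpers (write_json_file) that appear elsewhere in this file; the paths
# are hypothetical placeholders.
def _example_compute_units_info(recording_dir, firings_path, json_out):
    recording = si.MdaRecordingExtractor(dataset_directory=recording_dir, download=False)
    sorting = si.MdaSortingExtractor(firings_file=firings_path)
    # Compute per-unit metrics over all channels and all units
    info = compute_units_info(recording=recording, sorting=sorting)
    # Each entry has: unit_id, snr, peak_channel, num_events, firing_rate
    write_json_file(json_out, info)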
def run(self):
    R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=False)
    if len(self.channels) > 0:
        R0 = si.SubRecordingExtractor(parent_recording=R0, channel_ids=self.channels)
    R = sw.lazyfilters.bandpass_filter(recording=R0, freq_min=300, freq_max=6000)
    N = R.getNumFrames()
    N2 = int(N / 2)
    channels = R.getChannelIds()
    if len(channels) > 20:
        channels = channels[0:20]
    sw.TimeseriesWidget(recording=R, trange=[N2 - 4000, N2], channels=channels, width=12, height=5).plot()
    save_plot(self.jpg_out)
def run(self):
    import spikewidgets as sw
    R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
    if (self.channel_ids) and (len(self.channel_ids) > 0):
        R0 = si.SubRecordingExtractor(parent_recording=R0, channel_ids=self.channel_ids)
    recording = R0
    # recording = sw.lazyfilters.bandpass_filter(recording=R0, freq_min=300, freq_max=6000)
    sorting = si.MdaSortingExtractor(firings_file=self.firings)
    unit_ids = self.unit_ids
    if (not unit_ids) or (len(unit_ids) == 0):
        unit_ids = sorting.getUnitIds()
    channel_noise_levels = compute_channel_noise_levels(recording=recording)
    # No longer use subset to compute the templates
    templates = compute_unit_templates(recording=recording, sorting=sorting, unit_ids=unit_ids, max_num=100)
    ret = []
    for i, unit_id in enumerate(unit_ids):
        template = templates[i]
        max_p2p_amps_on_channels = np.max(template, axis=1) - np.min(template, axis=1)
        peak_channel_index = np.argmax(max_p2p_amps_on_channels)
        peak_channel = recording.getChannelIds()[peak_channel_index]
        # Recompute the template on the (bandpass-filtered) peak channel only
        R1 = si.SubRecordingExtractor(parent_recording=recording, channel_ids=[peak_channel])
        R1f = sw.lazyfilters.bandpass_filter(recording=R1, freq_min=300, freq_max=6000)
        templates2 = compute_unit_templates(recording=R1f, sorting=sorting, unit_ids=[unit_id], max_num=100)
        template2 = templates2[0]
        info0 = dict()
        info0['unit_id'] = int(unit_id)
        info0['snr'] = np.max(np.abs(template2)) / channel_noise_levels[peak_channel_index]
        # info0['snr'] = compute_template_snr(template, channel_noise_levels)
        # peak_channel_index = np.argmax(np.max(np.abs(template), axis=1))
        info0['peak_channel'] = int(peak_channel)
        train = sorting.getUnitSpikeTrain(unit_id=unit_id)
        info0['num_events'] = int(len(train))
        info0['firing_rate'] = float(len(train) / (recording.getNumFrames() / recording.getSamplingFrequency()))
        ret.append(info0)
    write_json_file(self.json_out, ret)
def __init__(self, arg: Union[str, dict], download: bool = False):
    super().__init__()
    obj = _create_object_for_arg(arg)
    assert obj is not None
    self._object: dict = obj
    recording_format = self._object['recording_format']
    data: dict = self._object['data']
    if recording_format == 'mda':
        self._recording: se.RecordingExtractor = MdaRecordingExtractor(
            timeseries_path=data['raw'],
            samplerate=data['params']['samplerate'],
            geom=np.array(data['geom']),
            download=download)
    elif recording_format == 'nrs':
        self._recording: se.RecordingExtractor = NrsRecordingExtractor(**data)
    # elif recording_format == 'nwb':
    #     path0 = kp.load_file(data['path'])
    #     self._recording: se.RecordingExtractor = NwbRecordingExtractor(path0, electrical_series_name='e-series')
    elif recording_format == 'bin1':
        self._recording: se.RecordingExtractor = Bin1RecordingExtractor(**data, p2p=True)
    elif recording_format == 'snippets1':
        self._recording: se.RecordingExtractor = Snippets1RecordingExtractor(
            snippets_h5_uri=data['snippets_h5_uri'], p2p=True)
    elif recording_format == 'subrecording':
        R = LabboxEphysRecordingExtractor(data['recording'], download=download)
        if 'channel_ids' in data:
            channel_ids = np.array(data['channel_ids'])
        elif 'group' in data:
            channel_ids = np.array(R.get_channel_ids())
            groups = R.get_channel_groups(channel_ids=R.get_channel_ids())
            group = int(data['group'])
            inds = np.where(np.array(groups) == group)[0]
            channel_ids = channel_ids[inds]
        elif 'groups' in data:
            raise Exception('This case not yet handled.')
        else:
            channel_ids = None
        if 'start_frame' in data:
            start_frame = data['start_frame']
            end_frame = data['end_frame']
        else:
            start_frame = None
            end_frame = None
        self._recording: se.RecordingExtractor = se.SubRecordingExtractor(
            parent_recording=R,
            channel_ids=channel_ids,
            start_frame=start_frame,
            end_frame=end_frame
        )
    elif recording_format == 'filtered':
        R = LabboxEphysRecordingExtractor(data['recording'], download=download)
        self._recording: se.RecordingExtractor = _apply_filters(recording=R, filters=data['filters'])
    else:
        raise Exception(f'Unexpected recording format: {recording_format}')
    self.copy_channel_properties(recording=self._recording)
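# Sketch of a dict accepted by the 'subrecording' branch above; the field
# values are hypothetical, but the keys follow the branches in __init__:
#
#   {
#       'recording_format': 'subrecording',
#       'data': {
#           'recording': {...},   # a nested recording object in any supported format
#           'group': 1,           # or 'channel_ids': [...], or neither (all channels)
#           'start_frame': 0,     # optional; 'end_frame' is required alongside it
#           'end_frame': 30000,
#       }
#   }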
def get_unit_waveforms(recording, sorting, unit_ids, channel_ids_by_unit, snippet_len):
    if not isinstance(snippet_len, list) and not isinstance(snippet_len, tuple):
        b = int(snippet_len / 2)
        a = int(snippet_len) - b
        snippet_len = [a, b]

    num_channels = recording.get_num_channels()
    num_frames = recording.get_num_frames()
    num_bytes_per_chunk = 1000 * 1000 * 1000  # ? how to choose this
    num_bytes_per_frame = num_channels * 2
    chunk_size = num_bytes_per_chunk // num_bytes_per_frame
    padding_size = 100 + snippet_len[0] + snippet_len[1]  # a bit excess padding
    chunks = _divide_recording_into_time_chunks(
        num_frames=num_frames,
        chunk_size=chunk_size,
        padding_size=padding_size)
    all_unit_waveforms = [[] for ii in range(len(unit_ids))]
    for ii, chunk in enumerate(chunks):
        # chunk: {istart, iend, istart_with_padding, iend_with_padding} # include padding
        print(f'Processing chunk {ii + 1} of {len(chunks)}; chunk-range: {chunk["istart_with_padding"]} {chunk["iend_with_padding"]}; num-frames: {num_frames}')
        recording_chunk = se.SubRecordingExtractor(
            parent_recording=recording,
            start_frame=chunk['istart_with_padding'],
            end_frame=chunk['iend_with_padding'])
        # note that the efficiency of this operation may need improvement (really depends on sorting extractor implementation)
        sorting_chunk = se.SubSortingExtractor(
            parent_sorting=sorting,
            start_frame=chunk['istart'],
            end_frame=chunk['iend'])
        print(f'Getting unit waveforms for chunk {ii + 1} of {len(chunks)}')
        # num_events_in_chunk x num_channels_in_nbhd[unit_id] x len_of_one_snippet
        unit_waveforms = _get_unit_waveforms_for_chunk(
            recording=recording_chunk,
            sorting=sorting_chunk,
            frame_offset=chunk['istart'] - chunk['istart_with_padding'],  # just the padding size (except 0 for first chunk)
            unit_ids=unit_ids,
            snippet_len=snippet_len,
            channel_ids_by_unit=channel_ids_by_unit)
        for i_unit, x in enumerate(unit_waveforms):
            all_unit_waveforms[i_unit].append(x)

    # concatenate the results over the chunks
    unit_waveforms = [
        # tot_num_events_for_unit x num_channels_in_nbhd[unit_id] x len_of_one_snippet
        np.concatenate(all_unit_waveforms[i_unit], axis=0)
        for i_unit in range(len(unit_ids))
    ]
    return unit_waveforms
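# Usage sketch for get_unit_waveforms, assuming `recording` and `sorting` are
# extractor objects already in scope; the argument values are illustrative.
def _example_get_unit_waveforms(recording, sorting):
    unit_ids = sorting.get_unit_ids()
    waveforms = get_unit_waveforms(
        recording=recording,
        sorting=sorting,
        unit_ids=unit_ids,
        channel_ids_by_unit=None,  # None: no per-unit channel neighborhoods
        snippet_len=50,            # a scalar is split into [25, 25] around each spike
    )
    # waveforms[i] has shape: num_events x num_channels x total_snippet_length
    return waveforms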
def run(self):
    R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
    if len(self.channels) > 0:
        R0 = si.SubRecordingExtractor(parent_recording=R0, channel_ids=self.channels)
    R = sw.lazyfilters.bandpass_filter(recording=R0, freq_min=300, freq_max=6000)
    S = si.MdaSortingExtractor(firings_file=self.firings)
    channels = R.getChannelIds()
    if len(channels) > 20:
        channels = channels[0:20]
    if len(self.units) > 0:
        units = self.units
    else:
        units = S.getUnitIds()
    if len(units) > 20:
        # Subsample to at most ~20 units, evenly spaced
        units = units[::int(len(units) / 20)]
    sw.UnitWaveformsWidget(recording=R, sorting=S, channels=channels, unit_ids=units).plot()
    save_plot(self.jpg_out)
def test_multi_sub_recording_extractor(self):
    RX_multi = se.MultiRecordingTimeExtractor(
        recordings=[self.RX, self.RX, self.RX],
        epoch_names=['A', 'B', 'C'])
    RX_sub = RX_multi.get_epoch('C')
    self._check_recordings_equal(self.RX, RX_sub)
    self.assertEqual(4, len(RX_sub.get_channel_ids()))

    RX_multi = se.MultiRecordingChannelExtractor(
        recordings=[self.RX, self.RX2, self.RX3],
        groups=[1, 2, 3])
    print(RX_multi.get_channel_groups())
    RX_sub = se.SubRecordingExtractor(RX_multi, channel_ids=[4, 5, 6, 7], renamed_channel_ids=[0, 1, 2, 3])
    self._check_recordings_equal(self.RX2, RX_sub)
    self.assertEqual([2, 2, 2, 2], RX_sub.get_channel_groups())
    self.assertEqual(12, len(RX_multi.get_channel_ids()))
def get_max_channels_per_waveforms(recording, grouping_property, channel_ids, max_channels_per_waveforms):
    if grouping_property is None:
        if max_channels_per_waveforms is None:
            n_channels = len(channel_ids)
        elif max_channels_per_waveforms >= len(channel_ids):
            n_channels = len(channel_ids)
        else:
            n_channels = max_channels_per_waveforms
    else:
        rec = se.SubRecordingExtractor(recording, channel_ids=channel_ids)
        rec_groups = np.array(rec.get_channel_groups())
        groups, count = np.unique(rec_groups, return_counts=True)
        if max_channels_per_waveforms is None:
            n_channels = np.max(count)
        elif max_channels_per_waveforms >= np.max(count):
            n_channels = np.max(count)
        else:
            n_channels = max_channels_per_waveforms
    return n_channels
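# Quick illustration of the branching above (values hypothetical): with no
# grouping property, the cap is applied directly to the channel count.
#
#   get_max_channels_per_waveforms(rec, None, [0, 1, 2, 3], None)  -> 4
#   get_max_channels_per_waveforms(rec, None, [0, 1, 2, 3], 2)     -> 2
#
# With a grouping property, the cap is applied to the size of the largest
# channel group within `channel_ids` instead.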
def test_dump_load_multi_sub_extractor(self):
    # generate dumpable formats
    path1 = self.test_dir + '/mda'
    path2 = path1 + '/firings_true.mda'
    se.MdaRecordingExtractor.write_recording(self.RX, path1)
    se.MdaSortingExtractor.write_sorting(self.SX, path2)
    RX_mda = se.MdaRecordingExtractor(path1)
    SX_mda = se.MdaSortingExtractor(path2)

    RX_multi_chan = se.MultiRecordingChannelExtractor(recordings=[RX_mda, RX_mda, RX_mda])
    check_dumping(RX_multi_chan)

    RX_multi_time = se.MultiRecordingTimeExtractor(recordings=[RX_mda, RX_mda, RX_mda])
    check_dumping(RX_multi_time)

    RX_sub = se.SubRecordingExtractor(RX_mda, channel_ids=[0, 1])
    check_dumping(RX_sub)

    SX_sub = se.SubSortingExtractor(SX_mda, unit_ids=[1, 2])
    check_dumping(SX_sub)

    SX_multi = se.MultiSortingExtractor(sortings=[SX_mda, SX_mda, SX_mda])
    check_dumping(SX_multi)
def run(self):
    recording = si.MdaRecordingExtractor(self.recording_dir)
    if len(self.channels) > 0:
        recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
    num_workers = os.environ.get('NUM_WORKERS', None)
    if num_workers:
        num_workers = int(num_workers)
    sorting = sf.sorters.mountainsort4(
        recording=recording,
        detect_sign=self.detect_sign,
        adjacency_radius=self.adjacency_radius,
        freq_min=self.freq_min,
        freq_max=self.freq_max,
        whiten=self.whiten,
        clip_size=self.clip_size,
        detect_threshold=self.detect_threshold,
        detect_interval=self.detect_interval,
        noise_overlap_threshold=self.noise_overlap_threshold,
        num_workers=num_workers
    )
    si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
def run(self):
    ironclust_path = os.environ.get('IRONCLUST_PATH', None)
    if not ironclust_path:
        raise Exception('Environment variable not set: IRONCLUST_PATH')
    code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
    tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-tmp-' + code
    try:
        recording = se.MdaRecordingExtractor(self.recording_dir)
        params = read_dataset_params(self.recording_dir)
        if len(self.channels) > 0:
            recording = se.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
        if not os.path.exists(tmpdir):
            os.mkdir(tmpdir)
        sorting = ironclust_helper(
            recording=recording,
            tmpdir=tmpdir,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            detect_threshold=self.detect_threshold,
            merge_thresh=self.merge_thresh,
            freq_min=self.freq_min,
            freq_max=self.freq_max,
            pc_per_chan=self.pc_per_chan,
            prm_template_name=self.prm_template_name,
            ironclust_path=ironclust_path,
            params=params,
        )
        se.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
    except:
        if os.path.exists(tmpdir):
            if not getattr(self, '_keep_temp_files', False):
                shutil.rmtree(tmpdir)
        raise
    if not getattr(self, '_keep_temp_files', False):
        shutil.rmtree(tmpdir)
def subset_recording(self, stub_test: bool = False):
    """
    Subset a recording extractor according to stub and channel subset options.

    Parameters
    ----------
    stub_test : bool, optional (default False)
        If True, truncate the recording to at most the first 100 frames.
    """
    kwargs = dict()
    if stub_test:
        num_frames = 100
        end_frame = min([num_frames, self.recording_extractor.get_num_frames()])
        kwargs.update(end_frame=end_frame)
    if self.subset_channels is not None:
        kwargs.update(channel_ids=self.subset_channels)
    recording_extractor = se.SubRecordingExtractor(self.recording_extractor, **kwargs)
    return recording_extractor
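# Hedged usage sketch for subset_recording; `interface` stands in for whatever
# object carries recording_extractor and subset_channels (hypothetical here):
#
#   interface.subset_channels = [0, 1, 2, 3]
#   stub = interface.subset_recording(stub_test=True)   # first 100 frames, 4 channels
#   full = interface.subset_recording(stub_test=False)  # all frames, 4 channels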
def find_unit_peak_channels(recording, sorting, unit_ids):
    # Use the first part of the recording to estimate the peak channels
    sorting_shortened = SubsampledSortingExtractor(
        parent_sorting=sorting, max_events_per_unit=20, method='truncate')
    max_time = 0
    for unit_id in sorting_shortened.get_unit_ids():
        st = sorting_shortened.get_unit_spike_train(unit_id=unit_id)
        if len(st) > 0:
            max_time = max(max_time, np.max(st))
    recording_shortened = se.SubRecordingExtractor(
        parent_recording=recording, start_frame=0, end_frame=max_time + 1)
    unit_waveforms = get_unit_waveforms(
        recording=recording_shortened,
        sorting=sorting_shortened,
        unit_ids=unit_ids,
        channel_ids_by_unit=None,
        snippet_len=(10, 10)
    )
    channel_ids = recording.get_channel_ids()
    peak_channels = {}
    for ii, unit_id in enumerate(unit_ids):
        average_waveform = np.median(unit_waveforms[ii], axis=0)
        peak_channel_index = int(np.argmax(
            np.max(average_waveform, axis=1) - np.min(average_waveform, axis=1)))
        peak_channels[unit_id] = int(channel_ids[peak_channel_index])
    return peak_channels
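# Usage sketch for find_unit_peak_channels, assuming extractor objects are in
# scope; the result maps each unit id to the channel id with the largest
# peak-to-peak median waveform, e.g. {1: 7, 2: 3} (values hypothetical).
def _example_find_unit_peak_channels(recording, sorting):
    peak_channels = find_unit_peak_channels(
        recording=recording,
        sorting=sorting,
        unit_ids=sorting.get_unit_ids())
    for unit_id, chan in peak_channels.items():
        print(f'unit {unit_id}: peak channel {chan}')
    return peak_channels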