def run(self):
    R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
    if (self.channel_ids) and (len(self.channel_ids) > 0):
        R0 = si.SubRecordingExtractor(parent_recording=R0, channel_ids=self.channel_ids)
    recording = sw.lazyfilters.bandpass_filter(recording=R0, freq_min=300, freq_max=6000)
    sorting = si.MdaSortingExtractor(firings_file=self.firings)

    # compute templates on the first 1e6 frames only, cached in memory
    ef = int(1e6)
    recording_sub = si.SubRecordingExtractor(parent_recording=recording, start_frame=0, end_frame=ef)
    recording_sub = MemoryRecordingExtractor(parent_recording=recording_sub)
    sorting_sub = si.SubSortingExtractor(parent_sorting=sorting, start_frame=0, end_frame=ef)

    unit_ids = self.unit_ids
    if (not unit_ids) or (len(unit_ids) == 0):
        unit_ids = sorting.getUnitIds()

    channel_noise_levels = compute_channel_noise_levels(recording=recording)
    print('computing templates...')
    templates = compute_unit_templates(recording=recording_sub, sorting=sorting_sub, unit_ids=unit_ids)
    print('.')

    ret = []
    for i, unit_id in enumerate(unit_ids):
        template = templates[i]
        info0 = dict()
        info0['unit_id'] = int(unit_id)
        info0['snr'] = compute_template_snr(template, channel_noise_levels)
        peak_channel_index = np.argmax(np.max(np.abs(template), axis=1))
        info0['peak_channel'] = int(recording.getChannelIds()[peak_channel_index])
        train = sorting.getUnitSpikeTrain(unit_id=unit_id)
        info0['num_events'] = int(len(train))
        info0['firing_rate'] = float(len(train) / (recording.getNumFrames() / recording.getSamplingFrequency()))
        ret.append(info0)
    write_json_file(self.json_out, ret)
def run(self):
    print('GenSortingComparisonTable: firings={}, firings_true={}, units_true={}'
          .format(self.firings, self.firings_true, self.units_true))
    sorting = SFMdaSortingExtractor(firings_file=self.firings)
    sorting_true = SFMdaSortingExtractor(firings_file=self.firings_true)
    if (self.units_true is not None) and (len(self.units_true) > 0):
        sorting_true = si.SubSortingExtractor(parent_sorting=sorting_true, unit_ids=self.units_true)
    SC = st.comparison.compare_sorter_to_ground_truth(
        gt_sorting=sorting_true,
        tested_sorting=sorting,
        delta_time=0.3,
        min_accuracy=0,
        compute_misclassification=False,
        exhaustive_gt=False  # Fix this in future
    )
    df = pd.concat([SC.count, SC.get_performance()], axis=1).reset_index()
    df = df.rename(columns=dict(
        gt_unit_id='unit_id',
        fp='num_false_positives',
        fn='num_false_negatives',
        tested_id='best_unit',
        tp='num_matches'
    ))
    df['matched_unit'] = df['best_unit']
    df['f_p'] = 1 - df['precision']
    df['f_n'] = 1 - df['recall']
    # sw.SortingComparisonTable(comparison=SC).getDataframe()
    json = df.transpose().to_dict()
    html = df.to_html(index=False)
    _write_json_file(json, self.json_out)
    _write_json_file(html, self.html_out)
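# A minimal standalone sketch of the comparison step in run() above, assuming
# the same spiketoolkit version: it reuses compare_sorter_to_ground_truth and
# the SC.count / SC.get_performance() accessors already used there, on a
# synthetic recording/sorting pair from spikeextractors' toy_example.
import pandas as pd
import spikeextractors as se
import spiketoolkit as st

recording, sorting_true = se.example_datasets.toy_example(duration=10)
SC = st.comparison.compare_sorter_to_ground_truth(
    gt_sorting=sorting_true,
    tested_sorting=sorting_true,  # compare ground truth to itself, for illustration only
    delta_time=0.3,
)
df = pd.concat([SC.count, SC.get_performance()], axis=1).reset_index()
print(df.head())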
def old_fetch_average_waveform_plot_data(recording_object, sorting_object, unit_id):
    import labbox_ephys as le
    R = le.LabboxEphysRecordingExtractor(recording_object)
    S = le.LabboxEphysSortingExtractor(sorting_object)

    # restrict to the first 30 seconds of the recording
    start_frame = 0
    end_frame = int(R.get_sampling_frequency() * 30)  # frame indices must be integers
    R0 = se.SubRecordingExtractor(parent_recording=R, start_frame=start_frame, end_frame=end_frame)
    S0 = se.SubSortingExtractor(parent_sorting=S, start_frame=start_frame, end_frame=end_frame)

    times0 = S0.get_unit_spike_train(unit_id=unit_id)
    if len(times0) == 0:
        # no waveforms found
        return dict(channel_id=None, average_waveform=None)
    try:
        average_waveform = st.postprocessing.get_unit_templates(
            recording=R0, sorting=S0, unit_ids=[unit_id])[0]
    except Exception:
        raise Exception(f'Error getting unit templates for unit {unit_id}')

    # report the waveform on the channel with the largest absolute peak
    channel_maximums = np.max(np.abs(average_waveform), axis=1)
    maxchan_index = np.argmax(channel_maximums)
    maxchan_id = R0.get_channel_ids()[maxchan_index]
    return dict(
        channel_id=maxchan_id,
        average_waveform=average_waveform[maxchan_index, :].tolist()
    )
def test_multi_sub_sorting_extractor(self):
    N = self.RX.get_num_frames()
    SX_multi = se.MultiSortingExtractor(sortings=[self.SX, self.SX, self.SX])
    SX_multi.set_unit_property(unit_id=1, property_name='dummy', value=5)
    SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=0)
    self._check_sortings_equal(SX_multi, SX_sub)
    self.assertEqual(SX_multi.get_unit_property(1, 'dummy'),
                     SX_sub.get_unit_property(1, 'dummy'))

    SX_multi = se.MultiSortingExtractor(sortings=[self.SX, self.SX2])
    SX_sub1 = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=0, end_frame=N)
    self._check_sortings_equal(SX_multi, SX_sub1)
def _keep_good_units(sorting_obj, cluster_groups_csv_uri):
    sorting = LabboxEphysSortingExtractor(sorting_obj)
    # the cluster_groups file is tab-delimited with 'cluster_id' and 'group' columns
    df = pd.read_csv(kp.load_file(cluster_groups_csv_uri), delimiter='\t')
    df_good = df.loc[df['group'] == 'good']
    good_unit_ids = df_good['cluster_id'].to_numpy().tolist()
    sorting_good = se.SubSortingExtractor(parent_sorting=sorting, unit_ids=good_unit_ids)
    return _create_npy1_sorting_object(sorting=sorting_good)
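# A runnable sketch of the same filtering logic using plain spikeextractors
# objects in place of labbox_ephys/kachery-p2p, so it works standalone. The
# inline TSV mimics the tab-delimited cluster_groups file (cluster_id, group
# columns) that phy writes; the URI-based loading above is unchanged.
import io
import pandas as pd
import spikeextractors as se

recording, sorting = se.example_datasets.toy_example(duration=10)
tsv = 'cluster_id\tgroup\n1\tgood\n2\tnoise\n3\tgood\n'
df = pd.read_csv(io.StringIO(tsv), delimiter='\t')
good_unit_ids = df.loc[df['group'] == 'good']['cluster_id'].to_numpy().tolist()
sorting_good = se.SubSortingExtractor(parent_sorting=sorting, unit_ids=good_unit_ids)
print('kept units:', sorting_good.get_unit_ids())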
def run(self):
    sorting = si.MdaSortingExtractor(firings_file=self.firings)
    sorting_true = si.MdaSortingExtractor(firings_file=self.firings_true)
    if len(self.units_true) > 0:
        sorting_true = si.SubSortingExtractor(parent_sorting=sorting_true, unit_ids=self.units_true)
    SC = st.comparison.SortingComparison(sorting_true, sorting)
    df = sw.SortingComparisonTable(comparison=SC).getDataframe()
    json = df.transpose().to_dict()
    html = df.to_html(index=False)
    _write_json_file(json, self.json_out)
    _write_json_file(html, self.html_out)
def test_multi_sub_sorting_extractor(self):
    N = self.RX.get_num_frames()
    SX_multi = se.MultiSortingExtractor(
        sortings=[self.SX, self.SX, self.SX],
        start_frames=[0, N, 2 * N])
    SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=N, end_frame=2 * N)
    self._check_sortings_equal(self.SX, SX_sub)
    self.assertEqual(SX_multi.get_sampling_frequency(), self.SX.get_sampling_frequency())
    self.assertEqual(SX_sub.get_sampling_frequency(), self.SX.get_sampling_frequency())

    SX_multi = se.MultiSortingExtractor(
        sortings=[self.SX, self.SX, self.SX],
        start_frames=[0, N, 2 * N])
    SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=0)
    self._check_sortings_equal(SX_multi, SX_sub)

    SX_multi = se.MultiSortingExtractor(
        sortings=[self.SX, self.SX, self.SX],
        start_frames=[2 * N, 0, N])
    SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=N, end_frame=2 * N)
    self._check_sortings_equal(self.SX, SX_sub)

    SX_multi = se.MultiSortingExtractor(
        sortings=[self.SX, self.SX, self.SX],
        start_frames=[0, 0, 0])
    SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=0)
    self._check_sortings_equal(SX_multi, SX_sub)

    SX_multi = se.MultiSortingExtractor(sortings=[self.SX, self.SX2], start_frames=[0, 0])
    SX_sub1 = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=0, end_frame=N)
    self._check_sortings_equal(SX_multi, SX_sub1)
def get_unit_waveforms(recording, sorting, unit_ids, channel_ids_by_unit, snippet_len):
    if not isinstance(snippet_len, list) and not isinstance(snippet_len, tuple):
        b = int(snippet_len / 2)
        a = int(snippet_len) - b
        snippet_len = [a, b]

    num_channels = recording.get_num_channels()
    num_frames = recording.get_num_frames()
    num_bytes_per_chunk = 1000 * 1000 * 1000  # ? how to choose this
    num_bytes_per_frame = num_channels * 2
    chunk_size = num_bytes_per_chunk // num_bytes_per_frame  # integer division so frame counts stay ints
    padding_size = 100 + snippet_len[0] + snippet_len[1]  # a bit excess padding
    chunks = _divide_recording_into_time_chunks(
        num_frames=num_frames,
        chunk_size=chunk_size,
        padding_size=padding_size)

    all_unit_waveforms = [[] for ii in range(len(unit_ids))]
    for ii, chunk in enumerate(chunks):
        # chunk: {istart, iend, istart_with_padding, iend_with_padding} -- includes padding
        print(f'Processing chunk {ii + 1} of {len(chunks)}; chunk-range: {chunk["istart_with_padding"]} {chunk["iend_with_padding"]}; num-frames: {num_frames}')
        recording_chunk = se.SubRecordingExtractor(
            parent_recording=recording,
            start_frame=chunk['istart_with_padding'],
            end_frame=chunk['iend_with_padding'])
        # note that the efficiency of this operation may need improvement
        # (really depends on the sorting extractor implementation)
        sorting_chunk = se.SubSortingExtractor(
            parent_sorting=sorting,
            start_frame=chunk['istart'],
            end_frame=chunk['iend'])
        print(f'Getting unit waveforms for chunk {ii + 1} of {len(chunks)}')
        # num_events_in_chunk x num_channels_in_nbhd[unit_id] x len_of_one_snippet
        unit_waveforms = _get_unit_waveforms_for_chunk(
            recording=recording_chunk,
            sorting=sorting_chunk,
            frame_offset=chunk['istart'] - chunk['istart_with_padding'],  # just the padding size (0 for first chunk)
            unit_ids=unit_ids,
            snippet_len=snippet_len,
            channel_ids_by_unit=channel_ids_by_unit)
        for i_unit, x in enumerate(unit_waveforms):
            all_unit_waveforms[i_unit].append(x)

    # concatenate the results over the chunks
    unit_waveforms = [
        # tot_num_events_for_unit x num_channels_in_nbhd[unit_id] x len_of_one_snippet
        np.concatenate(all_unit_waveforms[i_unit], axis=0)
        for i_unit in range(len(unit_ids))
    ]
    return unit_waveforms
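# A usage sketch for get_unit_waveforms, assuming it and its private helpers
# (_divide_recording_into_time_chunks, _get_unit_waveforms_for_chunk) are
# importable from the module above. toy_example supplies synthetic data; here
# every unit is simply assigned the full channel set as its neighborhood.
import spikeextractors as se

recording, sorting = se.example_datasets.toy_example(num_channels=4, duration=10)
unit_ids = sorting.get_unit_ids()
channel_ids_by_unit = {int(u): recording.get_channel_ids() for u in unit_ids}
waveforms = get_unit_waveforms(
    recording=recording,
    sorting=sorting,
    unit_ids=unit_ids,
    channel_ids_by_unit=channel_ids_by_unit,
    snippet_len=(30, 30),  # frames before / after each spike
)
for u, w in zip(unit_ids, waveforms):
    print(f'unit {u}: {w.shape}')  # num_events x num_channels x snippet_frames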
def test_example(self):
    self.assertEqual(self.RX.get_channel_ids(), self.example_info['channel_ids'])
    self.assertEqual(self.RX.get_num_channels(), self.example_info['num_channels'])
    self.assertEqual(self.RX.get_num_frames(), self.example_info['num_frames'])
    self.assertEqual(self.RX.get_sampling_frequency(), self.example_info['sampling_frequency'])
    self.assertEqual(self.SX.get_unit_ids(), self.example_info['unit_ids'])
    self.assertEqual(self.RX.get_channel_locations(0)[0][0], self.example_info['channel_prop'][0])
    self.assertEqual(self.RX.get_channel_locations(0)[0][1], self.example_info['channel_prop'][1])
    self.assertEqual(self.SX.get_unit_property(unit_id=1, property_name='stability'), self.example_info['unit_prop'])
    self.assertTrue(np.array_equal(self.SX.get_unit_spike_train(1), self.example_info['train1']))
    self.assertTrue(issubclass(self.SX.get_unit_spike_train(1).dtype.type, np.integer))
    # assertEqual, not assertTrue: assertTrue(a, b) treats b as a failure message and passes for any non-empty a
    self.assertEqual(self.RX.get_shared_channel_property_names(), ['group', 'location', 'shared_channel_prop'])
    self.assertEqual(self.RX.get_channel_property_names(0), ['group', 'location', 'shared_channel_prop'])
    self.assertEqual(self.SX2.get_shared_unit_property_names(), ['shared_unit_prop'])
    self.assertEqual(self.SX2.get_unit_property_names(4), ['shared_unit_prop', 'stability'])
    self.assertEqual(self.SX2.get_shared_unit_spike_feature_names(), ['shared_unit_feature'])
    self.assertEqual(self.SX2.get_unit_spike_feature_names(3), ['shared_channel_prop', 'widths'])
    self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy'), self.example_info['features3']))
    self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy', start_frame=4), self.example_info['features3'][1:]))
    self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy', end_frame=4), self.example_info['features3'][:1]))
    self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy', start_frame=20, end_frame=46), self.example_info['features3'][1:6]))
    self.assertTrue('dummy2' in self.SX3.get_unit_spike_feature_names(0))
    self.assertTrue('dummy2_idxs' in self.SX3.get_unit_spike_feature_names(0))
    sub_extractor_full = se.SubSortingExtractor(self.SX3)
    sub_extractor_partial = se.SubSortingExtractor(self.SX3, start_frame=20, end_frame=46)
    self.assertTrue(np.array_equal(sub_extractor_full.get_unit_spike_features(0, 'dummy'), self.SX3.get_unit_spike_features(0, 'dummy')))
    self.assertTrue(np.array_equal(sub_extractor_partial.get_unit_spike_features(0, 'dummy'), self.SX3.get_unit_spike_features(0, 'dummy', start_frame=20, end_frame=46)))
    check_recording_return_types(self.RX)
def test_multi_sub_sorting_extractor(self):
    N = self.RX.getNumFrames()
    SX_multi = se.MultiSortingExtractor(
        sortings=[self.SX, self.SX, self.SX],
        start_frames=[0, N, 2 * N])
    SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=N, end_frame=2 * N)
    self._check_sortings_equal(self.SX, SX_sub)

    SX_multi = se.MultiSortingExtractor(
        sortings=[self.SX, self.SX, self.SX],
        start_frames=[0, N, 2 * N])
    SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=0)
    self._check_sortings_equal(SX_multi, SX_sub)

    SX_multi = se.MultiSortingExtractor(
        sortings=[self.SX, self.SX, self.SX],
        start_frames=[2 * N, 0, N])
    SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=N, end_frame=2 * N)
    self._check_sortings_equal(self.SX, SX_sub)

    SX_multi = se.MultiSortingExtractor(
        sortings=[self.SX, self.SX, self.SX],
        start_frames=[0, 0, 0])
    SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=0)
    self._check_sortings_equal(SX_multi, SX_sub)

    SX_multi = se.MultiSortingExtractor(sortings=[self.SX, self.SX2], start_frames=[0, 0])
    SX_sub1 = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=0, end_frame=N)
    self._check_sortings_equal(SX_multi, SX_sub1)
def mountainsort4_curation(*, recording, sorting, noise_overlap_threshold=None):
    if noise_overlap_threshold is not None:
        units = sorting.get_unit_ids()
        noise_overlap_scores = compute_noise_overlap(recording=recording,
                                                     sorting=sorting,
                                                     unit_ids=units)
        inds = np.where(np.array(noise_overlap_scores) <= noise_overlap_threshold)[0]
        new_units = list(np.array(units)[inds])
        sorting = se.SubSortingExtractor(parent_sorting=sorting, unit_ids=new_units)
    return sorting
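# Hypothetical usage sketch for mountainsort4_curation, assuming the
# compute_noise_overlap helper it calls is available in the same module; the
# threshold value is illustrative, not a recommendation.
import spikeextractors as se

recording, sorting = se.example_datasets.toy_example(num_channels=4, duration=10)
curated = mountainsort4_curation(
    recording=recording,
    sorting=sorting,
    noise_overlap_threshold=0.1,  # keep only units with noise overlap <= 0.1
)
print('kept units:', curated.get_unit_ids())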
def run(self):
    print('GenSortingComparisonTable: firings={}, firings_true={}, units_true={}'
          .format(self.firings, self.firings_true, self.units_true))
    sorting = SFMdaSortingExtractor(firings_file=self.firings)
    sorting_true = SFMdaSortingExtractor(firings_file=self.firings_true)
    if (self.units_true is not None) and (len(self.units_true) > 0):
        sorting_true = si.SubSortingExtractor(parent_sorting=sorting_true, unit_ids=self.units_true)
    SC = SortingComparison(sorting_true, sorting, delta_tp=30)
    df = get_comparison_data_frame(comparison=SC)
    # sw.SortingComparisonTable(comparison=SC).getDataframe()
    json = df.transpose().to_dict()
    html = df.to_html(index=False)
    _write_json_file(json, self.json_out)
    _write_json_file(html, self.html_out)
def test_dump_load_multi_sub_extractor(self):
    # generate dumpable formats
    path1 = self.test_dir + '/mda'
    path2 = path1 + '/firings_true.mda'
    se.MdaRecordingExtractor.write_recording(self.RX, path1)
    se.MdaSortingExtractor.write_sorting(self.SX, path2)
    RX_mda = se.MdaRecordingExtractor(path1)
    SX_mda = se.MdaSortingExtractor(path2)

    RX_multi_chan = se.MultiRecordingChannelExtractor(recordings=[RX_mda, RX_mda, RX_mda])
    check_dumping(RX_multi_chan)
    RX_multi_time = se.MultiRecordingTimeExtractor(recordings=[RX_mda, RX_mda, RX_mda])
    check_dumping(RX_multi_time)
    RX_sub = se.SubRecordingExtractor(RX_mda, channel_ids=[0, 1])
    check_dumping(RX_sub)
    SX_sub = se.SubSortingExtractor(SX_mda, unit_ids=[1, 2])
    check_dumping(SX_sub)
    SX_multi = se.MultiSortingExtractor(sortings=[SX_mda, SX_mda, SX_mda])
    check_dumping(SX_multi)
def prepare_snippets_nwb_from_extractors(
    recording: se.RecordingExtractor,
    sorting: se.SortingExtractor,
    nwb_file_path: str,
    nwb_object_prefix: str,
    start_frame,
    end_frame,
    max_neighborhood_size: int,
    max_events_per_unit: Union[None, int] = None,
    snippet_len=(50, 80),
):
    import pynwb
    from labbox_ephys import (SubsampledSortingExtractor,
                              find_unit_neighborhoods,
                              find_unit_peak_channels, get_unit_waveforms)

    if start_frame is not None:
        recording = se.SubRecordingExtractor(parent_recording=recording,
                                             start_frame=start_frame,
                                             end_frame=end_frame)
        sorting = se.SubSortingExtractor(parent_sorting=sorting,
                                         start_frame=start_frame,
                                         end_frame=end_frame)

    unit_ids = sorting.get_unit_ids()
    samplerate = recording.get_sampling_frequency()

    # Use this optimized function rather than spiketoolkit's version for
    # efficiency with long recordings and/or many channels, units, or spikes;
    # we should submit this to the spiketoolkit project as a PR
    print('Subsampling sorting')
    if max_events_per_unit is not None:
        sorting_subsampled = SubsampledSortingExtractor(
            parent_sorting=sorting,
            max_events_per_unit=max_events_per_unit,
            method='random')
    else:
        sorting_subsampled = sorting

    print('Finding unit peak channels')
    peak_channels_by_unit = find_unit_peak_channels(recording=recording,
                                                    sorting=sorting,
                                                    unit_ids=unit_ids)

    print('Finding unit neighborhoods')
    channel_ids_by_unit = find_unit_neighborhoods(
        recording=recording,
        peak_channels_by_unit=peak_channels_by_unit,
        max_neighborhood_size=max_neighborhood_size)

    print(f'Getting unit waveforms for {len(unit_ids)} units')
    unit_waveforms = get_unit_waveforms(
        recording=recording,
        sorting=sorting_subsampled,
        unit_ids=unit_ids,
        channel_ids_by_unit=channel_ids_by_unit,
        snippet_len=snippet_len)
    # unit_waveforms = st.postprocessing.get_unit_waveforms(
    #     recording=recording,
    #     sorting=sorting,
    #     unit_ids=unit_ids,
    #     ms_before=1,
    #     ms_after=1.5,
    #     max_spikes_per_unit=500
    # )

    with pynwb.NWBHDF5IO(path=nwb_file_path, mode='a') as io:
        nwbf = io.read()
        nwbf.add_scratch(name=f'{nwb_object_prefix}_unit_ids',
                         data=np.array(unit_ids).astype(np.int32),
                         notes='sorted waveform unit ids')
        nwbf.add_scratch(name=f'{nwb_object_prefix}_sampling_frequency',
                         data=np.array([samplerate]).astype(np.float64),
                         notes='sorted waveform sampling frequency')
        nwbf.add_scratch(name=f'{nwb_object_prefix}_channel_ids',
                         data=np.array(recording.get_channel_ids()),
                         notes='sorted waveform channel ids')
        nwbf.add_scratch(name=f'{nwb_object_prefix}_num_frames',
                         data=np.array([recording.get_num_frames()]).astype(np.int32),
                         notes='sorted waveform number of frames')
        channel_locations = recording.get_channel_locations()
        nwbf.add_scratch(name=f'{nwb_object_prefix}_channel_locations',
                         data=np.array(channel_locations),
                         notes='sorted waveform channel locations')
        for ii, unit_id in enumerate(unit_ids):
            x = sorting.get_unit_spike_train(unit_id=unit_id)
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_spike_trains',
                data=np.array(x).astype(np.float64),
                notes=f'sorted spike trains for unit {unit_id}')
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_waveforms',
                data=unit_waveforms[ii].astype(np.float32),
                notes=f'sorted waveforms for unit {unit_id}')
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_channel_ids',
                data=np.array(channel_ids_by_unit[int(unit_id)]).astype(int),
                notes=f'sorted channel ids for unit {unit_id}')
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_sub_spike_train',
                data=np.array(sorting_subsampled.get_unit_spike_train(unit_id=unit_id)).astype(np.float64),
                notes=f'sorted subsampled spike train for unit {unit_id}')
        io.write(nwbf)
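# Sketch of driving prepare_snippets_nwb_from_extractors end to end: create an
# empty NWB file with pynwb, then append the snippet data to it. All names and
# parameter values below are illustrative assumptions.
from datetime import datetime
import pynwb
import spikeextractors as se

recording, sorting = se.example_datasets.toy_example(num_channels=4, duration=10)
nwbf = pynwb.NWBFile(session_description='toy session',
                     identifier='toy-0001',
                     session_start_time=datetime.now().astimezone())
with pynwb.NWBHDF5IO(path='snippets.nwb', mode='w') as io:
    io.write(nwbf)
prepare_snippets_nwb_from_extractors(
    recording=recording,
    sorting=sorting,
    nwb_file_path='snippets.nwb',
    nwb_object_prefix='snippets',
    start_frame=None,  # None: use the full recording
    end_frame=None,
    max_neighborhood_size=4,
    max_events_per_unit=500,
)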
def test_example(self):
    self.assertEqual(self.RX.get_channel_ids(), self.example_info['channel_ids'])
    self.assertEqual(self.RX.get_num_channels(), self.example_info['num_channels'])
    self.assertEqual(self.RX.get_num_frames(), self.example_info['num_frames'])
    self.assertEqual(self.RX.get_sampling_frequency(), self.example_info['sampling_frequency'])
    self.assertEqual(self.SX.get_unit_ids(), self.example_info['unit_ids'])
    self.assertEqual(self.RX.get_channel_locations(0)[0][0], self.example_info['channel_prop'][0])
    self.assertEqual(self.RX.get_channel_locations(0)[0][1], self.example_info['channel_prop'][1])
    self.assertTrue(np.array_equal(self.RX.get_ttl_events()[0], self.example_info['ttls']))
    self.assertEqual(self.SX.get_unit_property(unit_id=1, property_name='stability'), self.example_info['unit_prop'])
    self.assertTrue(np.array_equal(self.SX.get_unit_spike_train(1), self.example_info['train1']))
    self.assertTrue(issubclass(self.SX.get_unit_spike_train(1).dtype.type, np.integer))
    # assertEqual, not assertTrue: assertTrue(a, b) treats b as a failure message and passes for any non-empty a
    self.assertEqual(self.RX.get_shared_channel_property_names(), ['group', 'location', 'shared_channel_prop'])
    self.assertEqual(self.RX.get_channel_property_names(0), ['group', 'location', 'shared_channel_prop'])
    self.assertEqual(self.SX2.get_shared_unit_property_names(), ['shared_unit_prop'])
    self.assertEqual(self.SX2.get_unit_property_names(4), ['shared_unit_prop', 'stability'])
    self.assertEqual(self.SX2.get_shared_unit_spike_feature_names(), ['shared_unit_feature'])
    self.assertEqual(self.SX2.get_unit_spike_feature_names(3), ['shared_channel_prop', 'widths'])
    self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy'), self.example_info['features3']))
    self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy', start_frame=4), self.example_info['features3'][1:]))
    self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy', end_frame=4), self.example_info['features3'][:1]))
    self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy', start_frame=20, end_frame=46), self.example_info['features3'][1:6]))
    self.assertTrue('dummy2' in self.SX3.get_unit_spike_feature_names(0))
    self.assertTrue('dummy2_idxs' in self.SX3.get_unit_spike_feature_names(0))
    sub_extractor_full = se.SubSortingExtractor(self.SX3)
    sub_extractor_partial = se.SubSortingExtractor(self.SX3, start_frame=20, end_frame=46)
    self.assertTrue(np.array_equal(sub_extractor_full.get_unit_spike_features(0, 'dummy'), self.SX3.get_unit_spike_features(0, 'dummy')))
    self.assertTrue(np.array_equal(sub_extractor_partial.get_unit_spike_features(0, 'dummy'), self.SX3.get_unit_spike_features(0, 'dummy', start_frame=20, end_frame=46)))
    self.assertEqual(tuple(self.RX.get_epoch_info("epoch1").values()), self.example_info['epochs_info'][0])
    self.assertEqual(tuple(self.RX.get_epoch_info("epoch2").values()), self.example_info['epochs_info'][1])
    self.assertEqual(tuple(self.SX.get_epoch_info("epoch1").values()), self.example_info['epochs_info'][0])
    self.assertEqual(tuple(self.SX.get_epoch_info("epoch2").values()), self.example_info['epochs_info'][1])
    self.assertEqual(tuple(self.RX.get_epoch_info("epoch1").values()), tuple(self.RX2.get_epoch_info("epoch1").values()))
    self.assertEqual(tuple(self.RX.get_epoch_info("epoch2").values()), tuple(self.RX2.get_epoch_info("epoch2").values()))
    self.assertEqual(tuple(self.SX.get_epoch_info("epoch1").values()), tuple(self.SX2.get_epoch_info("epoch1").values()))
    self.assertEqual(tuple(self.SX.get_epoch_info("epoch2").values()), tuple(self.SX2.get_epoch_info("epoch2").values()))
    self.assertTrue(np.array_equal(self.RX2.frame_to_time(np.arange(self.RX2.get_num_frames())), self.example_info['times']))
    self.assertTrue(np.array_equal(self.SX2.get_unit_spike_train(3) / self.SX2.get_sampling_frequency() + 5, self.SX2.frame_to_time(self.SX2.get_unit_spike_train(3))))
    self.RX3.clear_channel_locations()
    self.assertTrue('location' not in self.RX3.get_shared_channel_property_names())
    self.RX3.set_channel_locations(self.example_info['geom'])
    self.assertTrue(np.array_equal(self.RX3.get_channel_locations(), self.RX2.get_channel_locations()))
    self.RX3.set_channel_groups(groups=[1], channel_ids=[1])
    self.assertEqual(self.RX3.get_channel_groups(channel_ids=[1]), 1)
    self.RX3.clear_channel_groups()
    self.assertEqual(self.RX3.get_channel_groups(channel_ids=[1]), 0)
    self.RX3.set_channel_locations(locations=[[np.nan, np.nan, np.nan]], channel_ids=[1])
    self.assertTrue('location' not in self.RX3.get_shared_channel_property_names())
    self.RX3.set_channel_locations(locations=[[0, 0, 0]], channel_ids=[1])
    self.assertTrue('location' in self.RX3.get_shared_channel_property_names())
    check_recording_return_types(self.RX)
def run_conversion(self, nwbfile: NWBFile, metadata: dict, stub_test: bool = False,
                   write_ecephys_metadata: bool = False):
    if 'UnitProperties' not in metadata:
        metadata['UnitProperties'] = []
    if write_ecephys_metadata and 'Ecephys' in metadata:
        n_channels = max([len(x['data']) for x in metadata['Ecephys']['Electrodes']])
        recording = se.NumpyRecordingExtractor(timeseries=np.array(range(n_channels)),
                                               sampling_frequency=1)
        se.NwbRecordingExtractor.add_devices(recording=recording, nwbfile=nwbfile, metadata=metadata)
        se.NwbRecordingExtractor.add_electrode_groups(recording=recording, nwbfile=nwbfile, metadata=metadata)
        se.NwbRecordingExtractor.add_electrodes(recording=recording, nwbfile=nwbfile, metadata=metadata)

    property_descriptions = dict()
    if stub_test:
        # earliest spike time across units, used to truncate the stub sorting
        max_min_spike_time = max([
            min(x)
            for y in self.sorting_extractor.get_unit_ids()
            for x in [self.sorting_extractor.get_unit_spike_train(y)]
            if any(x)
        ])
        stub_sorting_extractor = se.SubSortingExtractor(
            self.sorting_extractor,
            unit_ids=self.sorting_extractor.get_unit_ids(),
            start_frame=0,
            end_frame=int(1.1 * max_min_spike_time))  # frame indices must be integers
        sorting_extractor = stub_sorting_extractor
    else:
        sorting_extractor = self.sorting_extractor

    for metadata_column in metadata['UnitProperties']:
        assert len(metadata_column['data']) == len(sorting_extractor.get_unit_ids()), \
            f"The metadata_column '{metadata_column['name']}' data must have the same dimension as the sorting IDs!"
        property_descriptions.update({metadata_column['name']: metadata_column['description']})
        for unit_idx, unit_id in enumerate(sorting_extractor.get_unit_ids()):
            if metadata_column['name'] == 'electrode_group':
                if nwbfile.electrode_groups:
                    data = nwbfile.electrode_groups[metadata_column['data'][unit_idx]]
                    sorting_extractor.set_unit_property(unit_id, metadata_column['name'], data)
            else:
                data = metadata_column['data'][unit_idx]
                sorting_extractor.set_unit_property(unit_id, metadata_column['name'], data)

    se.NwbSortingExtractor.write_sorting(
        sorting_extractor,
        property_descriptions=property_descriptions,
        nwbfile=nwbfile)
def prepare_snippets_h5_from_extractors(recording: se.RecordingExtractor,
                                        sorting: se.SortingExtractor,
                                        output_h5_path: str,
                                        start_frame,
                                        end_frame,
                                        max_neighborhood_size: int,
                                        max_events_per_unit: Union[None, int] = None,
                                        snippet_len=(50, 80)):
    import h5py
    from labbox_ephys import (SubsampledSortingExtractor,
                              find_unit_neighborhoods,
                              find_unit_peak_channels, get_unit_waveforms)

    if start_frame is not None:
        recording = se.SubRecordingExtractor(parent_recording=recording,
                                             start_frame=start_frame,
                                             end_frame=end_frame)
        sorting = se.SubSortingExtractor(parent_sorting=sorting,
                                         start_frame=start_frame,
                                         end_frame=end_frame)

    unit_ids = sorting.get_unit_ids()
    samplerate = recording.get_sampling_frequency()

    # Use this optimized function rather than spiketoolkit's version for
    # efficiency with long recordings and/or many channels, units, or spikes;
    # we should submit this to the spiketoolkit project as a PR
    print('Subsampling sorting')
    if max_events_per_unit is not None:
        sorting_subsampled = SubsampledSortingExtractor(
            parent_sorting=sorting,
            max_events_per_unit=max_events_per_unit,
            method='random')
    else:
        sorting_subsampled = sorting

    print('Finding unit peak channels')
    peak_channels_by_unit = find_unit_peak_channels(recording=recording,
                                                    sorting=sorting,
                                                    unit_ids=unit_ids)

    print('Finding unit neighborhoods')
    channel_ids_by_unit = find_unit_neighborhoods(
        recording=recording,
        peak_channels_by_unit=peak_channels_by_unit,
        max_neighborhood_size=max_neighborhood_size)

    print(f'Getting unit waveforms for {len(unit_ids)} units')
    unit_waveforms = get_unit_waveforms(
        recording=recording,
        sorting=sorting_subsampled,
        unit_ids=unit_ids,
        channel_ids_by_unit=channel_ids_by_unit,
        snippet_len=snippet_len)
    # unit_waveforms = st.postprocessing.get_unit_waveforms(
    #     recording=recording,
    #     sorting=sorting,
    #     unit_ids=unit_ids,
    #     ms_before=1,
    #     ms_after=1.5,
    #     max_spikes_per_unit=500
    # )

    save_path = output_h5_path
    with h5py.File(save_path, 'w') as f:
        f.create_dataset('unit_ids', data=np.array(unit_ids).astype(np.int32))
        f.create_dataset('sampling_frequency', data=np.array([samplerate]).astype(np.float64))
        f.create_dataset('channel_ids', data=np.array(recording.get_channel_ids()))
        f.create_dataset('num_frames', data=np.array([recording.get_num_frames()]).astype(np.int32))
        channel_locations = recording.get_channel_locations()
        f.create_dataset('channel_locations', data=np.array(channel_locations))
        for ii, unit_id in enumerate(unit_ids):
            x = sorting.get_unit_spike_train(unit_id=unit_id)
            f.create_dataset(f'unit_spike_trains/{unit_id}',
                             data=np.array(x).astype(np.float64))
            f.create_dataset(f'unit_waveforms/{unit_id}/waveforms',
                             data=unit_waveforms[ii].astype(np.float32))
            f.create_dataset(f'unit_waveforms/{unit_id}/channel_ids',
                             data=np.array(channel_ids_by_unit[int(unit_id)]).astype(int))
            f.create_dataset(f'unit_waveforms/{unit_id}/spike_train',
                             data=np.array(sorting_subsampled.get_unit_spike_train(unit_id=unit_id)).astype(np.float64))
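# Companion sketch for the HDF5 variant: run it on toy data, then read back
# one unit's waveform array from the file it wrote. File name and parameter
# values are illustrative.
import h5py
import spikeextractors as se

recording, sorting = se.example_datasets.toy_example(num_channels=4, duration=10)
prepare_snippets_h5_from_extractors(
    recording=recording,
    sorting=sorting,
    output_h5_path='snippets.h5',
    start_frame=None,  # None: use the full recording
    end_frame=None,
    max_neighborhood_size=4,
    max_events_per_unit=500,
)
with h5py.File('snippets.h5', 'r') as f:
    unit_ids = f['unit_ids'][:]
    print('units:', unit_ids)
    print('waveforms shape for first unit:',
          f[f'unit_waveforms/{unit_ids[0]}/waveforms'].shape)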