Ejemplo n.º 1
0
 def run(self):
   """Compute per-unit summary info (SNR, peak channel, event count, firing
   rate) for a sorting and write the results as JSON to self.json_out."""
   # Load the recording; download=True fetches remote MDA data if needed.
   R0=si.MdaRecordingExtractor(dataset_directory=self.recording_dir,download=True)
   # Optionally restrict to a requested subset of channels.
   if (self.channel_ids) and (len(self.channel_ids)>0):
     R0=si.SubRecordingExtractor(parent_recording=R0,channel_ids=self.channel_ids)
   # Bandpass 300-6000 Hz (typical spike band), applied lazily.
   recording=sw.lazyfilters.bandpass_filter(recording=R0,freq_min=300,freq_max=6000)
   sorting=si.MdaSortingExtractor(firings_file=self.firings)
   # Templates are computed on only the first 1e6 frames for speed.
   ef=int(1e6)
   recording_sub=si.SubRecordingExtractor(parent_recording=recording,start_frame=0,end_frame=ef)
   # Cache the truncated recording in memory for faster repeated reads.
   recording_sub=MemoryRecordingExtractor(parent_recording=recording_sub)
   sorting_sub=si.SubSortingExtractor(parent_sorting=sorting,start_frame=0,end_frame=ef)
   unit_ids=self.unit_ids
   # Empty/None unit list means "all units in the sorting".
   if (not unit_ids) or (len(unit_ids)==0):
     unit_ids=sorting.getUnitIds()

   channel_noise_levels=compute_channel_noise_levels(recording=recording)
   print('computing templates...')
   templates=compute_unit_templates(recording=recording_sub,sorting=sorting_sub,unit_ids=unit_ids)
   print('.')
   ret=[]
   for i,unit_id in enumerate(unit_ids):
     template=templates[i]
     info0=dict()
     info0['unit_id']=int(unit_id)
     info0['snr']=compute_template_snr(template,channel_noise_levels)
     # Peak channel = channel with the largest absolute template deflection.
     peak_channel_index=np.argmax(np.max(np.abs(template),axis=1))
     info0['peak_channel']=int(recording.getChannelIds()[peak_channel_index])
     train=sorting.getUnitSpikeTrain(unit_id=unit_id)
     info0['num_events']=int(len(train))
     # Firing rate in Hz: spike count / full recording duration in seconds.
     info0['firing_rate']=float(len(train)/(recording.getNumFrames()/recording.getSamplingFrequency()))
     ret.append(info0)
   write_json_file(self.json_out,ret)
    def run(self):
        """Compare the tested sorting against ground truth and write the
        resulting per-unit comparison table as JSON and HTML."""
        msg = 'GenSortingComparisonTable: firings={}, firings_true={}, units_true={}'
        print(msg.format(self.firings, self.firings_true, self.units_true))

        tested = SFMdaSortingExtractor(firings_file=self.firings)
        ground_truth = SFMdaSortingExtractor(firings_file=self.firings_true)
        # Optionally restrict the ground truth to the requested unit ids.
        if self.units_true is not None and len(self.units_true) > 0:
            ground_truth = si.SubSortingExtractor(parent_sorting=ground_truth,
                                                  unit_ids=self.units_true)

        comparison = st.comparison.compare_sorter_to_ground_truth(
            gt_sorting=ground_truth,
            tested_sorting=tested,
            delta_time=0.3,
            min_accuracy=0,
            compute_misclassification=False,
            exhaustive_gt=False  # Fix this in future
        )
        # Combine event counts with performance metrics into one table.
        table = pd.concat([comparison.count, comparison.get_performance()],
                          axis=1).reset_index()

        rename_map = dict(gt_unit_id='unit_id',
                          fp='num_false_positives',
                          fn='num_false_negatives',
                          tested_id='best_unit',
                          tp='num_matches')
        table = table.rename(columns=rename_map)
        table['matched_unit'] = table['best_unit']
        # False-positive/negative rates are the complements of precision/recall.
        table['f_p'] = 1 - table['precision']
        table['f_n'] = 1 - table['recall']

        # sw.SortingComparisonTable(comparison=SC).getDataframe()
        _write_json_file(table.transpose().to_dict(), self.json_out)
        _write_json_file(table.to_html(index=False), self.html_out)
Ejemplo n.º 3
0
def old_fetch_average_waveform_plot_data(recording_object, sorting_object,
                                         unit_id):
    """Compute the average waveform of one unit over the first 30 seconds
    of a recording.

    Returns a dict with:
      channel_id        -- id of the channel with the largest absolute
                           template deflection (None if the unit has no spikes)
      average_waveform  -- list of samples on that channel (None if no spikes)
    """
    import labbox_ephys as le
    R = le.LabboxEphysRecordingExtractor(recording_object)
    S = le.LabboxEphysSortingExtractor(sorting_object)

    # Restrict to the first 30 seconds for speed.  Cast to int: sampling
    # frequency is typically a float, and frame indices must be integral.
    start_frame = 0
    end_frame = int(R.get_sampling_frequency() * 30)
    R0 = se.SubRecordingExtractor(parent_recording=R,
                                  start_frame=start_frame,
                                  end_frame=end_frame)
    S0 = se.SubSortingExtractor(parent_sorting=S,
                                start_frame=start_frame,
                                end_frame=end_frame)

    times0 = S0.get_unit_spike_train(unit_id=unit_id)
    if len(times0) == 0:
        # no waveforms found
        return dict(channel_id=None, average_waveform=None)
    try:
        average_waveform = st.postprocessing.get_unit_templates(
            recording=R0, sorting=S0, unit_ids=[unit_id])[0]
    except Exception as e:
        # Chain the underlying error instead of discarding it; the previous
        # bare `except:` also intercepted KeyboardInterrupt/SystemExit.
        raise Exception(
            f'Error getting unit templates for unit {unit_id}') from e

    # Pick the channel with the largest absolute template amplitude.
    channel_maximums = np.max(np.abs(average_waveform), axis=1)
    maxchan_index = np.argmax(channel_maximums)
    maxchan_id = R0.get_channel_ids()[maxchan_index]

    return dict(channel_id=maxchan_id,
                average_waveform=average_waveform[maxchan_index, :].tolist())
Ejemplo n.º 4
0
    def test_multi_sub_sorting_extractor(self):
        """Check SubSortingExtractor views over a MultiSortingExtractor,
        including propagation of unit properties."""
        N = self.RX.get_num_frames()
        SX_multi = se.MultiSortingExtractor(
            sortings=[self.SX, self.SX, self.SX], )
        SX_multi.set_unit_property(unit_id=1, property_name='dummy', value=5)
        # A sub-extractor covering the whole range should equal its parent...
        SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=0)
        self._check_sortings_equal(SX_multi, SX_sub)
        # ...and expose the same unit properties.
        self.assertEqual(SX_multi.get_unit_property(1, 'dummy'),
                         SX_sub.get_unit_property(1, 'dummy'))

        # Same full-range equality with an explicit end_frame.
        N = self.RX.get_num_frames()
        SX_multi = se.MultiSortingExtractor(sortings=[self.SX, self.SX2], )
        SX_sub1 = se.SubSortingExtractor(parent_sorting=SX_multi,
                                         start_frame=0,
                                         end_frame=N)
        self._check_sortings_equal(SX_multi, SX_sub1)
Ejemplo n.º 5
0
def _keep_good_units(sorting_obj, cluster_groups_csv_uri):
    """Return a new npy1 sorting object containing only the units whose
    curation label is 'good' in the tab-separated cluster_groups file."""
    parent_sorting = LabboxEphysSortingExtractor(sorting_obj)
    cluster_groups = pd.read_csv(kp.load_file(cluster_groups_csv_uri),
                                 delimiter='\t')
    # Unit ids labeled 'good' by the curator.
    good_unit_ids = cluster_groups.loc[cluster_groups['group'] == 'good',
                                       'cluster_id'].to_numpy().tolist()
    filtered = se.SubSortingExtractor(parent_sorting=parent_sorting,
                                      unit_ids=good_unit_ids)
    return _create_npy1_sorting_object(sorting=filtered)
Ejemplo n.º 6
0
 def run(self):
     """Compare a sorting to ground-truth firings and write the comparison
     table to JSON and HTML outputs."""
     sorting=si.MdaSortingExtractor(firings_file=self.firings)
     sorting_true=si.MdaSortingExtractor(firings_file=self.firings_true)
     # Optionally restrict the ground truth to a subset of units.
     if len(self.units_true)>0:
         sorting_true=si.SubSortingExtractor(parent_sorting=sorting_true,unit_ids=self.units_true)
     SC=st.comparison.SortingComparison(sorting_true,sorting)
     df=sw.SortingComparisonTable(comparison=SC).getDataframe()
     # Emit both machine-readable (JSON) and human-readable (HTML) forms.
     json=df.transpose().to_dict()
     html=df.to_html(index=False)
     _write_json_file(json,self.json_out)
     _write_json_file(html,self.html_out)
Ejemplo n.º 7
0
    def test_multi_sub_sorting_extractor(self):
        """Exercise SubSortingExtractor slicing of MultiSortingExtractor
        under several start_frames layouts."""
        # Three copies of SX laid end to end; the middle window [N, 2N)
        # should reproduce SX exactly.
        N = self.RX.get_num_frames()
        SX_multi = se.MultiSortingExtractor(
            sortings=[self.SX, self.SX, self.SX], start_frames=[0, N, 2 * N])
        SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi,
                                        start_frame=N,
                                        end_frame=2 * N)
        self._check_sortings_equal(self.SX, SX_sub)
        # Sampling frequency must propagate through multi and sub extractors.
        self.assertEqual(SX_multi.get_sampling_frequency(),
                         self.SX.get_sampling_frequency())
        self.assertEqual(SX_sub.get_sampling_frequency(),
                         self.SX.get_sampling_frequency())

        # A sub-extractor from frame 0 with no end equals the full multi.
        N = self.RX.get_num_frames()
        SX_multi = se.MultiSortingExtractor(
            sortings=[self.SX, self.SX, self.SX], start_frames=[0, N, 2 * N])
        SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=0)
        self._check_sortings_equal(SX_multi, SX_sub)

        # Out-of-order start_frames: window [N, 2N) still maps back to SX.
        N = self.RX.get_num_frames()
        SX_multi = se.MultiSortingExtractor(
            sortings=[self.SX, self.SX, self.SX], start_frames=[2 * N, 0, N])
        SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi,
                                        start_frame=N,
                                        end_frame=2 * N)
        self._check_sortings_equal(self.SX, SX_sub)

        # All segments overlapping at frame 0.
        N = self.RX.get_num_frames()
        SX_multi = se.MultiSortingExtractor(
            sortings=[self.SX, self.SX, self.SX], start_frames=[0, 0, 0])
        SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=0)
        self._check_sortings_equal(SX_multi, SX_sub)

        # Mixed sortings with explicit zero offsets and a bounded window.
        N = self.RX.get_num_frames()
        SX_multi = se.MultiSortingExtractor(sortings=[self.SX, self.SX2],
                                            start_frames=[0, 0])
        SX_sub1 = se.SubSortingExtractor(parent_sorting=SX_multi,
                                         start_frame=0,
                                         end_frame=N)
        self._check_sortings_equal(SX_multi, SX_sub1)
def get_unit_waveforms(recording, sorting, unit_ids, channel_ids_by_unit,
                       snippet_len):
    """Extract per-unit spike waveform snippets from a recording.

    The recording is processed in large time chunks (padded on both sides)
    so memory stays bounded for long recordings.

    Parameters
    ----------
    recording, sorting : extractor objects
    unit_ids : list of unit ids to extract waveforms for
    channel_ids_by_unit : dict mapping unit id -> channel neighborhood ids
    snippet_len : int total length, or (frames_before, frames_after) pair

    Returns
    -------
    list parallel to unit_ids; each element has shape
    tot_num_events_for_unit x num_channels_in_nbhd[unit_id] x len_of_one_snippet
    """
    # Normalize a scalar snippet_len into [frames_before, frames_after].
    if not isinstance(snippet_len, (list, tuple)):
        b = int(snippet_len / 2)
        a = int(snippet_len) - b
        snippet_len = [a, b]

    num_channels = recording.get_num_channels()
    num_frames = recording.get_num_frames()
    num_bytes_per_chunk = 1000 * 1000 * 1000  # ? how to choose this
    num_bytes_per_frame = num_channels * 2  # assumes 2-byte (int16) samples -- TODO confirm
    # BUG FIX: integer division -- '/' produced a float frame count.
    chunk_size = num_bytes_per_chunk // num_bytes_per_frame
    padding_size = 100 + snippet_len[0] + snippet_len[
        1]  # a bit excess padding
    chunks = _divide_recording_into_time_chunks(num_frames=num_frames,
                                                chunk_size=chunk_size,
                                                padding_size=padding_size)
    all_unit_waveforms = [[] for ii in range(len(unit_ids))]
    for ii, chunk in enumerate(chunks):
        # chunk: {istart, iend, istart_with_padding, iend_with_padding} # include padding
        print(
            f'Processing chunk {ii + 1} of {len(chunks)}; chunk-range: {chunk["istart_with_padding"]} {chunk["iend_with_padding"]}; num-frames: {num_frames}'
        )
        recording_chunk = se.SubRecordingExtractor(
            parent_recording=recording,
            start_frame=chunk['istart_with_padding'],
            end_frame=chunk['iend_with_padding'])
        # note that the efficiency of this operation may need improvement (really depends on sorting extractor implementation)
        sorting_chunk = se.SubSortingExtractor(parent_sorting=sorting,
                                               start_frame=chunk['istart'],
                                               end_frame=chunk['iend'])
        print(f'Getting unit waveforms for chunk {ii + 1} of {len(chunks)}')
        # num_events_in_chunk x num_channels_in_nbhd[unit_id] x len_of_one_snippet
        unit_waveforms = _get_unit_waveforms_for_chunk(
            recording=recording_chunk,
            sorting=sorting_chunk,
            frame_offset=chunk['istart'] - chunk[
                'istart_with_padding'],  # just the padding size (except 0 for first chunk)
            unit_ids=unit_ids,
            snippet_len=snippet_len,
            channel_ids_by_unit=channel_ids_by_unit)
        for i_unit, x in enumerate(unit_waveforms):
            all_unit_waveforms[i_unit].append(x)

    # concatenate the results over the chunks
    unit_waveforms = [
        # tot_num_events_for_unit x num_channels_in_nbhd[unit_id] x len_of_one_snippet
        np.concatenate(all_unit_waveforms[i_unit], axis=0)
        for i_unit in range(len(unit_ids))
    ]
    return unit_waveforms
Ejemplo n.º 9
0
    def test_example(self):
        """Validate extractor accessors against the example_info fixture."""
        self.assertEqual(self.RX.get_channel_ids(), self.example_info['channel_ids'])
        self.assertEqual(self.RX.get_num_channels(), self.example_info['num_channels'])
        self.assertEqual(self.RX.get_num_frames(), self.example_info['num_frames'])
        self.assertEqual(self.RX.get_sampling_frequency(), self.example_info['sampling_frequency'])
        self.assertEqual(self.SX.get_unit_ids(), self.example_info['unit_ids'])
        self.assertEqual(self.RX.get_channel_locations(0)[0][0], self.example_info['channel_prop'][0])
        self.assertEqual(self.RX.get_channel_locations(0)[0][1], self.example_info['channel_prop'][1])
        self.assertEqual(self.SX.get_unit_property(unit_id=1, property_name='stability'),
                         self.example_info['unit_prop'])
        self.assertTrue(np.array_equal(self.SX.get_unit_spike_train(1), self.example_info['train1']))
        self.assertTrue(issubclass(self.SX.get_unit_spike_train(1).dtype.type, np.integer))
        # BUG FIX: the next six checks used assertTrue(a, b), which treats b
        # as the failure *message* and never compares; use assertEqual so the
        # expected lists are actually asserted.
        self.assertEqual(self.RX.get_shared_channel_property_names(), ['group', 'location', 'shared_channel_prop'])
        self.assertEqual(self.RX.get_channel_property_names(0), ['group', 'location', 'shared_channel_prop'])
        self.assertEqual(self.SX2.get_shared_unit_property_names(), ['shared_unit_prop'])
        self.assertEqual(self.SX2.get_unit_property_names(4), ['shared_unit_prop', 'stability'])
        self.assertEqual(self.SX2.get_shared_unit_spike_feature_names(), ['shared_unit_feature'])
        self.assertEqual(self.SX2.get_unit_spike_feature_names(3), ['shared_channel_prop', 'widths'])

        print(self.SX3.get_unit_spike_features(0, 'dummy'))
        # Spike features can be sliced by frame window.
        self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy'), self.example_info['features3']))
        self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy', start_frame=4),
                                       self.example_info['features3'][1:]))
        self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy', end_frame=4),
                                       self.example_info['features3'][:1]))
        self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy', start_frame=20, end_frame=46),
                                       self.example_info['features3'][1:6]))
        self.assertTrue('dummy2' in self.SX3.get_unit_spike_feature_names(0))
        self.assertTrue('dummy2_idxs' in self.SX3.get_unit_spike_feature_names(0))

        # Sub-extractors must expose the same (windowed) features as the parent.
        sub_extractor_full = se.SubSortingExtractor(self.SX3)
        sub_extractor_partial = se.SubSortingExtractor(self.SX3, start_frame=20, end_frame=46)

        self.assertTrue(np.array_equal(sub_extractor_full.get_unit_spike_features(0, 'dummy'),
                                       self.SX3.get_unit_spike_features(0, 'dummy')))
        self.assertTrue(np.array_equal(sub_extractor_partial.get_unit_spike_features(0, 'dummy'),
                                       self.SX3.get_unit_spike_features(0, 'dummy', start_frame=20, end_frame=46)))

        check_recording_return_types(self.RX)
Ejemplo n.º 10
0
    def test_multi_sub_sorting_extractor(self):
        """Exercise SubSortingExtractor slicing of MultiSortingExtractor
        (legacy camelCase API: getNumFrames)."""
        # Three copies of SX end to end; the middle window [N, 2N) should
        # reproduce SX exactly.
        N = self.RX.getNumFrames()
        SX_multi = se.MultiSortingExtractor(
            sortings=[self.SX, self.SX, self.SX], start_frames=[0, N, 2 * N])
        SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi,
                                        start_frame=N,
                                        end_frame=2 * N)
        self._check_sortings_equal(self.SX, SX_sub)

        # A sub-extractor from frame 0 with no end equals the full multi.
        N = self.RX.getNumFrames()
        SX_multi = se.MultiSortingExtractor(
            sortings=[self.SX, self.SX, self.SX], start_frames=[0, N, 2 * N])
        SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=0)
        self._check_sortings_equal(SX_multi, SX_sub)

        # Out-of-order start_frames: window [N, 2N) still maps back to SX.
        N = self.RX.getNumFrames()
        SX_multi = se.MultiSortingExtractor(
            sortings=[self.SX, self.SX, self.SX], start_frames=[2 * N, 0, N])
        SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi,
                                        start_frame=N,
                                        end_frame=2 * N)
        self._check_sortings_equal(self.SX, SX_sub)

        # All segments overlapping at frame 0.
        N = self.RX.getNumFrames()
        SX_multi = se.MultiSortingExtractor(
            sortings=[self.SX, self.SX, self.SX], start_frames=[0, 0, 0])
        SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=0)
        self._check_sortings_equal(SX_multi, SX_sub)

        # Mixed sortings with zero offsets and a bounded window.
        N = self.RX.getNumFrames()
        SX_multi = se.MultiSortingExtractor(sortings=[self.SX, self.SX2],
                                            start_frames=[0, 0])
        SX_sub1 = se.SubSortingExtractor(parent_sorting=SX_multi,
                                         start_frame=0,
                                         end_frame=N)
        self._check_sortings_equal(SX_multi, SX_sub1)
Ejemplo n.º 11
0
def mountainsort4_curation(*,
                           recording,
                           sorting,
                           noise_overlap_threshold=None):
    """Curate a sorting by dropping units whose noise-overlap score exceeds
    the given threshold; with no threshold, return the sorting unchanged."""
    if noise_overlap_threshold is None:
        return sorting
    unit_ids = sorting.get_unit_ids()
    scores = compute_noise_overlap(recording=recording,
                                   sorting=sorting,
                                   unit_ids=unit_ids)
    # Indices of units at or below the noise-overlap threshold.
    keep_inds = np.flatnonzero(
        np.array(scores) <= noise_overlap_threshold)
    kept_units = list(np.array(unit_ids)[keep_inds])
    return se.SubSortingExtractor(parent_sorting=sorting,
                                  unit_ids=kept_units)
    def run(self):
        """Generate a sorting-vs-ground-truth comparison table and write it
        as JSON and HTML."""
        print(
            'GenSortingComparisonTable: firings={}, firings_true={}, units_true={}'
            .format(self.firings, self.firings_true, self.units_true))
        sorting = SFMdaSortingExtractor(firings_file=self.firings)
        sorting_true = SFMdaSortingExtractor(firings_file=self.firings_true)
        # Optionally restrict ground truth to the requested unit ids.
        if (self.units_true is not None) and (len(self.units_true) > 0):
            sorting_true = si.SubSortingExtractor(parent_sorting=sorting_true,
                                                  unit_ids=self.units_true)

        # delta_tp=30: spike-match tolerance, presumably in frames -- TODO confirm
        SC = SortingComparison(sorting_true, sorting, delta_tp=30)
        df = get_comparison_data_frame(comparison=SC)
        # sw.SortingComparisonTable(comparison=SC).getDataframe()
        json = df.transpose().to_dict()
        html = df.to_html(index=False)
        _write_json_file(json, self.json_out)
        _write_json_file(html, self.html_out)
Ejemplo n.º 13
0
    def test_dump_load_multi_sub_extractor(self):
        """Check that multi/sub recording and sorting extractors built from
        dumpable MDA sources survive dump/load round-trips."""
        # generate dumpable formats
        path1 = self.test_dir + '/mda'
        path2 = path1 + '/firings_true.mda'
        se.MdaRecordingExtractor.write_recording(self.RX, path1)
        se.MdaSortingExtractor.write_sorting(self.SX, path2)
        RX_mda = se.MdaRecordingExtractor(path1)
        SX_mda = se.MdaSortingExtractor(path2)

        RX_multi_chan = se.MultiRecordingChannelExtractor(recordings=[RX_mda, RX_mda, RX_mda])
        check_dumping(RX_multi_chan)
        RX_multi_time = se.MultiRecordingTimeExtractor(recordings=[RX_mda, RX_mda, RX_mda], )
        check_dumping(RX_multi_time)
        # NOTE(review): RX_multi_chan is reused here for a SubRecordingExtractor.
        RX_multi_chan = se.SubRecordingExtractor(RX_mda, channel_ids=[0, 1])
        check_dumping(RX_multi_chan)

        SX_sub = se.SubSortingExtractor(SX_mda, unit_ids=[1, 2])
        check_dumping(SX_sub)
        SX_multi = se.MultiSortingExtractor(sortings=[SX_mda, SX_mda, SX_mda])
        check_dumping(SX_multi)
Ejemplo n.º 14
0
def prepare_snippets_nwb_from_extractors(
        recording: se.RecordingExtractor,
        sorting: se.SortingExtractor,
        nwb_file_path: str,
        nwb_object_prefix: str,
        start_frame,
        end_frame,
        max_neighborhood_size: int,
        max_events_per_unit: Union[None, int] = None,
        snippet_len=(50, 80),
):
    """Extract per-unit spike trains and waveform snippets and append them
    as scratch datasets (named with nwb_object_prefix) to an existing NWB
    file at nwb_file_path.

    start_frame/end_frame optionally window the data; max_events_per_unit
    caps events via random subsampling; snippet_len is (before, after) in
    frames; max_neighborhood_size bounds each unit's channel neighborhood.
    """
    import pynwb
    from labbox_ephys import (SubsampledSortingExtractor,
                              find_unit_neighborhoods, find_unit_peak_channels,
                              get_unit_waveforms)
    # Optionally restrict both recording and sorting to the same window.
    if start_frame is not None:
        recording = se.SubRecordingExtractor(parent_recording=recording,
                                             start_frame=start_frame,
                                             end_frame=end_frame)
        sorting = se.SubSortingExtractor(parent_sorting=sorting,
                                         start_frame=start_frame,
                                         end_frame=end_frame)

    unit_ids = sorting.get_unit_ids()
    samplerate = recording.get_sampling_frequency()

    # Use this optimized function rather than spiketoolkit's version
    # for efficiency with long recordings and/or many channels, units or spikes
    # we should submit this to the spiketoolkit project as a PR
    print('Subsampling sorting')
    # Cap the number of events per unit by random subsampling, if requested.
    if max_events_per_unit is not None:
        sorting_subsampled = SubsampledSortingExtractor(
            parent_sorting=sorting,
            max_events_per_unit=max_events_per_unit,
            method='random')
    else:
        sorting_subsampled = sorting
    print('Finding unit peak channels')
    # Peak channels are computed on the FULL (non-subsampled) sorting.
    peak_channels_by_unit = find_unit_peak_channels(recording=recording,
                                                    sorting=sorting,
                                                    unit_ids=unit_ids)
    print('Finding unit neighborhoods')
    channel_ids_by_unit = find_unit_neighborhoods(
        recording=recording,
        peak_channels_by_unit=peak_channels_by_unit,
        max_neighborhood_size=max_neighborhood_size)
    print(f'Getting unit waveforms for {len(unit_ids)} units')
    # Waveforms come from the (possibly subsampled) sorting.
    unit_waveforms = get_unit_waveforms(
        recording=recording,
        sorting=sorting_subsampled,
        unit_ids=unit_ids,
        channel_ids_by_unit=channel_ids_by_unit,
        snippet_len=snippet_len)
    # unit_waveforms = st.postprocessing.get_unit_waveforms(
    #     recording=recording,
    #     sorting=sorting,
    #     unit_ids=unit_ids,
    #     ms_before=1,
    #     ms_after=1.5,
    #     max_spikes_per_unit=500
    # )
    # Append all computed objects to the NWB file's scratch space.
    with pynwb.NWBHDF5IO(path=nwb_file_path, mode='a') as io:
        nwbf = io.read()
        nwbf.add_scratch(name=f'{nwb_object_prefix}_unit_ids',
                         data=np.array(unit_ids).astype(np.int32),
                         notes='sorted waveform unit ids')
        nwbf.add_scratch(name=f'{nwb_object_prefix}_sampling_frequency',
                         data=np.array([samplerate]).astype(np.float64),
                         notes='sorted waveform sampling frequency')
        nwbf.add_scratch(name=f'{nwb_object_prefix}_channel_ids',
                         data=np.array(recording.get_channel_ids()),
                         notes='sorted waveform channel ids')
        nwbf.add_scratch(name=f'{nwb_object_prefix}_num_frames',
                         data=np.array([recording.get_num_frames()
                                        ]).astype(np.int32),
                         notes='sorted waveform number of frames')
        channel_locations = recording.get_channel_locations()
        nwbf.add_scratch(name=f'{nwb_object_prefix}_channel_locations',
                         data=np.array(channel_locations),
                         notes='sorted waveform channel locations')
        # Per-unit datasets: full spike train, waveforms, neighborhood
        # channel ids, and the subsampled spike train used for waveforms.
        for ii, unit_id in enumerate(unit_ids):
            x = sorting.get_unit_spike_train(unit_id=unit_id)
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_spike_trains',
                data=np.array(x).astype(np.float64),
                notes=f'sorted spike trains for unit {unit_id}')
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_waveforms',
                data=unit_waveforms[ii].astype(np.float32),
                notes=f'sorted waveforms for unit {unit_id}')
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_channel_ids',
                data=np.array(channel_ids_by_unit[int(unit_id)]).astype(int),
                notes=f'sorted channel ids for unit {unit_id}')
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_sub_spike_train',
                data=np.array(
                    sorting_subsampled.get_unit_spike_train(
                        unit_id=unit_id)).astype(np.float64),
                notes=f'sorted subsampled spike train for unit {unit_id}')
        io.write(nwbf)
Ejemplo n.º 15
0
    def test_example(self):
        """Validate extractor accessors, epochs, times, and channel metadata
        against the example_info fixture."""
        self.assertEqual(self.RX.get_channel_ids(),
                         self.example_info['channel_ids'])
        self.assertEqual(self.RX.get_num_channels(),
                         self.example_info['num_channels'])
        self.assertEqual(self.RX.get_num_frames(),
                         self.example_info['num_frames'])
        self.assertEqual(self.RX.get_sampling_frequency(),
                         self.example_info['sampling_frequency'])
        self.assertEqual(self.SX.get_unit_ids(), self.example_info['unit_ids'])
        self.assertEqual(
            self.RX.get_channel_locations(0)[0][0],
            self.example_info['channel_prop'][0])
        self.assertEqual(
            self.RX.get_channel_locations(0)[0][1],
            self.example_info['channel_prop'][1])
        self.assertTrue(
            np.array_equal(self.RX.get_ttl_events()[0],
                           self.example_info['ttls']))
        self.assertEqual(
            self.SX.get_unit_property(unit_id=1, property_name='stability'),
            self.example_info['unit_prop'])
        self.assertTrue(
            np.array_equal(self.SX.get_unit_spike_train(1),
                           self.example_info['train1']))

        self.assertTrue(
            issubclass(self.SX.get_unit_spike_train(1).dtype.type, np.integer))
        # NOTE(review): the following assertTrue(a, b) calls pass b as the
        # failure *message*, so no comparison happens; assertEqual was
        # probably intended.
        self.assertTrue(self.RX.get_shared_channel_property_names(),
                        ['group', 'location', 'shared_channel_prop'])
        self.assertTrue(self.RX.get_channel_property_names(0),
                        ['group', 'location', 'shared_channel_prop'])
        self.assertTrue(self.SX2.get_shared_unit_property_names(),
                        ['shared_unit_prop'])
        self.assertTrue(self.SX2.get_unit_property_names(4),
                        ['shared_unit_prop', 'stability'])
        self.assertTrue(self.SX2.get_shared_unit_spike_feature_names(),
                        ['shared_unit_feature'])
        self.assertTrue(self.SX2.get_unit_spike_feature_names(3),
                        ['shared_channel_prop', 'widths'])

        # Spike features can be sliced by frame window.
        self.assertTrue(
            np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy'),
                           self.example_info['features3']))
        self.assertTrue(
            np.array_equal(
                self.SX3.get_unit_spike_features(0, 'dummy', start_frame=4),
                self.example_info['features3'][1:]))
        self.assertTrue(
            np.array_equal(
                self.SX3.get_unit_spike_features(0, 'dummy', end_frame=4),
                self.example_info['features3'][:1]))
        self.assertTrue(
            np.array_equal(
                self.SX3.get_unit_spike_features(0,
                                                 'dummy',
                                                 start_frame=20,
                                                 end_frame=46),
                self.example_info['features3'][1:6]))
        self.assertTrue('dummy2' in self.SX3.get_unit_spike_feature_names(0))
        self.assertTrue(
            'dummy2_idxs' in self.SX3.get_unit_spike_feature_names(0))

        # Sub-extractors must expose the same (windowed) features as parent.
        sub_extractor_full = se.SubSortingExtractor(self.SX3)
        sub_extractor_partial = se.SubSortingExtractor(self.SX3,
                                                       start_frame=20,
                                                       end_frame=46)

        self.assertTrue(
            np.array_equal(
                sub_extractor_full.get_unit_spike_features(0, 'dummy'),
                self.SX3.get_unit_spike_features(0, 'dummy')))
        self.assertTrue(
            np.array_equal(
                sub_extractor_partial.get_unit_spike_features(0, 'dummy'),
                self.SX3.get_unit_spike_features(0,
                                                 'dummy',
                                                 start_frame=20,
                                                 end_frame=46)))

        # Epoch info must match the fixture and be consistent across copies.
        self.assertEqual(tuple(self.RX.get_epoch_info("epoch1").values()),
                         self.example_info['epochs_info'][0])
        self.assertEqual(tuple(self.RX.get_epoch_info("epoch2").values()),
                         self.example_info['epochs_info'][1])
        self.assertEqual(tuple(self.SX.get_epoch_info("epoch1").values()),
                         self.example_info['epochs_info'][0])
        self.assertEqual(tuple(self.SX.get_epoch_info("epoch2").values()),
                         self.example_info['epochs_info'][1])

        self.assertEqual(tuple(self.RX.get_epoch_info("epoch1").values()),
                         tuple(self.RX2.get_epoch_info("epoch1").values()))
        self.assertEqual(tuple(self.RX.get_epoch_info("epoch2").values()),
                         tuple(self.RX2.get_epoch_info("epoch2").values()))
        self.assertEqual(tuple(self.SX.get_epoch_info("epoch1").values()),
                         tuple(self.SX2.get_epoch_info("epoch1").values()))
        self.assertEqual(tuple(self.SX.get_epoch_info("epoch2").values()),
                         tuple(self.SX2.get_epoch_info("epoch2").values()))

        # Frame-to-time mapping and its consistency with spike trains.
        self.assertTrue(
            np.array_equal(
                self.RX2.frame_to_time(np.arange(self.RX2.get_num_frames())),
                self.example_info['times']))
        self.assertTrue(
            np.array_equal(
                self.SX2.get_unit_spike_train(3) /
                self.SX2.get_sampling_frequency() + 5,
                self.SX2.frame_to_time(self.SX2.get_unit_spike_train(3))))

        # Channel location/group setters and clearers round-trip correctly.
        self.RX3.clear_channel_locations()
        self.assertTrue(
            'location' not in self.RX3.get_shared_channel_property_names())
        self.RX3.set_channel_locations(self.example_info['geom'])
        self.assertTrue(
            np.array_equal(self.RX3.get_channel_locations(),
                           self.RX2.get_channel_locations()))
        self.RX3.set_channel_groups(groups=[1], channel_ids=[1])
        self.assertEqual(self.RX3.get_channel_groups(channel_ids=[1]), 1)
        self.RX3.clear_channel_groups()
        self.assertEqual(self.RX3.get_channel_groups(channel_ids=[1]), 0)
        # NaN locations should not count as a shared 'location' property.
        self.RX3.set_channel_locations(locations=[[np.nan, np.nan, np.nan]],
                                       channel_ids=[1])
        self.assertTrue(
            'location' not in self.RX3.get_shared_channel_property_names())
        self.RX3.set_channel_locations(locations=[[0, 0, 0]], channel_ids=[1])
        self.assertTrue(
            'location' in self.RX3.get_shared_channel_property_names())
        check_recording_return_types(self.RX)
    def run_conversion(self,
                       nwbfile: NWBFile,
                       metadata: dict,
                       stub_test: bool = False,
                       write_ecephys_metadata: bool = False):
        """Write this interface's sorting (and optionally ecephys metadata)
        into the given in-memory NWBFile.

        stub_test truncates the sorting so the output stays small;
        write_ecephys_metadata additionally writes device/electrode-group/
        electrode metadata from metadata['Ecephys'].
        """
        if 'UnitProperties' not in metadata:
            metadata['UnitProperties'] = []
        if write_ecephys_metadata and 'Ecephys' in metadata:
            # Build a dummy recording sized to the electrode count, just to
            # drive the NwbRecordingExtractor metadata-writing helpers.
            n_channels = max(
                [len(x['data']) for x in metadata['Ecephys']['Electrodes']])
            recording = se.NumpyRecordingExtractor(timeseries=np.array(
                range(n_channels)),
                                                   sampling_frequency=1)
            se.NwbRecordingExtractor.add_devices(recording=recording,
                                                 nwbfile=nwbfile,
                                                 metadata=metadata)
            se.NwbRecordingExtractor.add_electrode_groups(recording=recording,
                                                          nwbfile=nwbfile,
                                                          metadata=metadata)
            se.NwbRecordingExtractor.add_electrodes(recording=recording,
                                                    nwbfile=nwbfile,
                                                    metadata=metadata)

        property_descriptions = dict()
        if stub_test:
            # Truncate just past the latest first-spike over all units, so
            # every spiking unit keeps at least one event in the stub.
            max_min_spike_time = max([
                min(x) for y in self.sorting_extractor.get_unit_ids()
                for x in [self.sorting_extractor.get_unit_spike_train(y)]
                if any(x)
            ])
            # NOTE(review): 1.1 * max_min_spike_time is a float end_frame --
            # presumably tolerated by SubSortingExtractor; confirm or cast.
            stub_sorting_extractor = se.SubSortingExtractor(
                self.sorting_extractor,
                unit_ids=self.sorting_extractor.get_unit_ids(),
                start_frame=0,
                end_frame=1.1 * max_min_spike_time)
            sorting_extractor = stub_sorting_extractor
        else:
            sorting_extractor = self.sorting_extractor

        # Attach each metadata-declared unit property to every unit.
        for metadata_column in metadata['UnitProperties']:
            assert len(metadata_column['data']) == len(sorting_extractor.get_unit_ids()), \
                f"The metadata_column '{metadata_column['name']}' data must have the same dimension as the sorting IDs!"

            property_descriptions.update(
                {metadata_column['name']: metadata_column['description']})
            for unit_idx, unit_id in enumerate(
                    sorting_extractor.get_unit_ids()):
                if metadata_column['name'] == 'electrode_group':
                    # electrode_group values are names resolved against the
                    # nwbfile's electrode_groups mapping.
                    if nwbfile.electrode_groups:
                        data = nwbfile.electrode_groups[metadata_column['data']
                                                        [unit_idx]]
                        sorting_extractor.set_unit_property(
                            unit_id, metadata_column['name'], data)
                else:
                    data = metadata_column['data'][unit_idx]
                    sorting_extractor.set_unit_property(
                        unit_id, metadata_column['name'], data)

        se.NwbSortingExtractor.write_sorting(
            sorting_extractor,
            property_descriptions=property_descriptions,
            nwbfile=nwbfile)
Ejemplo n.º 17
0
def prepare_snippets_h5_from_extractors(recording: se.RecordingExtractor,
                                        sorting: se.SortingExtractor,
                                        output_h5_path: str,
                                        start_frame,
                                        end_frame,
                                        max_neighborhood_size: int,
                                        max_events_per_unit: Union[None,
                                                                   int] = None,
                                        snippet_len=(50, 80)):
    """Extract per-unit spike-waveform snippets and write them to an HDF5 file.

    The output file contains, at the top level: ``unit_ids``,
    ``sampling_frequency``, ``channel_ids``, ``num_frames`` and
    ``channel_locations``; and per unit: ``unit_spike_trains/<id>`` (the full
    spike train) plus ``unit_waveforms/<id>/{waveforms,channel_ids,spike_train}``
    (the possibly-subsampled snippets restricted to the unit's channel
    neighborhood).

    Parameters
    ----------
    recording : se.RecordingExtractor
        Recording to extract waveform snippets from.
    sorting : se.SortingExtractor
        Sorting whose units and spike trains define the snippets.
    output_h5_path : str
        Destination HDF5 path (overwritten if it exists).
    start_frame, end_frame : int or None
        If ``start_frame`` is not None, both extractors are restricted to the
        window [start_frame, end_frame) before extraction.
        NOTE(review): ``end_frame`` is ignored when ``start_frame`` is None --
        confirm callers never pass ``end_frame`` alone.
    max_neighborhood_size : int
        Maximum number of channels kept around each unit's peak channel.
    max_events_per_unit : int or None
        If given, each unit's events are randomly subsampled to at most this
        many before waveform extraction; the full spike train is still stored.
    snippet_len : tuple of int
        (samples before, samples after) the spike time for each snippet.
    """
    import h5py
    from labbox_ephys import (SubsampledSortingExtractor,
                              find_unit_neighborhoods, find_unit_peak_channels,
                              get_unit_waveforms)

    # Restrict both extractors to the same frame window so spike times and
    # traces stay aligned.
    if start_frame is not None:
        recording = se.SubRecordingExtractor(parent_recording=recording,
                                             start_frame=start_frame,
                                             end_frame=end_frame)
        sorting = se.SubSortingExtractor(parent_sorting=sorting,
                                         start_frame=start_frame,
                                         end_frame=end_frame)

    unit_ids = sorting.get_unit_ids()
    samplerate = recording.get_sampling_frequency()

    # Use this optimized function rather than spiketoolkit's version
    # for efficiency with long recordings and/or many channels, units or spikes
    # we should submit this to the spiketoolkit project as a PR
    print('Subsampling sorting')
    if max_events_per_unit is not None:
        sorting_subsampled = SubsampledSortingExtractor(
            parent_sorting=sorting,
            max_events_per_unit=max_events_per_unit,
            method='random')
    else:
        sorting_subsampled = sorting
    print('Finding unit peak channels')
    # Peak channels come from the full (non-subsampled) sorting so the channel
    # neighborhood does not depend on the random subsample.
    peak_channels_by_unit = find_unit_peak_channels(recording=recording,
                                                    sorting=sorting,
                                                    unit_ids=unit_ids)
    print('Finding unit neighborhoods')
    channel_ids_by_unit = find_unit_neighborhoods(
        recording=recording,
        peak_channels_by_unit=peak_channels_by_unit,
        max_neighborhood_size=max_neighborhood_size)
    print(f'Getting unit waveforms for {len(unit_ids)} units')
    unit_waveforms = get_unit_waveforms(
        recording=recording,
        sorting=sorting_subsampled,
        unit_ids=unit_ids,
        channel_ids_by_unit=channel_ids_by_unit,
        snippet_len=snippet_len)

    with h5py.File(output_h5_path, 'w') as f:
        f.create_dataset('unit_ids', data=np.array(unit_ids).astype(np.int32))
        f.create_dataset('sampling_frequency',
                         data=np.array([samplerate]).astype(np.float64))
        f.create_dataset('channel_ids',
                         data=np.array(recording.get_channel_ids()))
        f.create_dataset('num_frames',
                         data=np.array([recording.get_num_frames()
                                        ]).astype(np.int32))
        channel_locations = recording.get_channel_locations()
        f.create_dataset('channel_locations',
                         data=np.array(channel_locations))
        for ii, unit_id in enumerate(unit_ids):
            # Full spike train (frame indices stored as float64).
            x = sorting.get_unit_spike_train(unit_id=unit_id)
            f.create_dataset(f'unit_spike_trains/{unit_id}',
                             data=np.array(x).astype(np.float64))
            f.create_dataset(f'unit_waveforms/{unit_id}/waveforms',
                             data=unit_waveforms[ii].astype(np.float32))
            f.create_dataset(
                f'unit_waveforms/{unit_id}/channel_ids',
                data=np.array(channel_ids_by_unit[int(unit_id)]).astype(int))
            # Spike train matching the (possibly subsampled) waveforms above.
            f.create_dataset(f'unit_waveforms/{unit_id}/spike_train',
                             data=np.array(
                                 sorting_subsampled.get_unit_spike_train(
                                     unit_id=unit_id)).astype(np.float64))