from pathlib import Path
from typing import Union

import scipy.io
from spikeextractors import SortingExtractor

PathType = Union[str, Path]  # alias for filesystem paths


def write_sorting(sorting: SortingExtractor, save_path: PathType):
    # The target file must follow the CellExplorer naming convention:
    # <sorting_id>.spikes.cellinfo.mat
    save_path = Path(save_path)
    assert save_path.suffixes == [
        ".spikes",
        ".cellinfo",
        ".mat",
    ], "The save_path must correspond to the CellExplorer format of sorting_id.spikes.cellinfo.mat!"

    base_path = save_path.parent
    sorting_id = save_path.name.split(".")[0]
    session_info_save_path = base_path / f"{sorting_id}.sessionInfo.mat"
    spikes_save_path = save_path
    base_path.mkdir(parents=True, exist_ok=True)

    # Write the session-level metadata (<sorting_id>.sessionInfo.mat),
    # which records the wideband sampling rate.
    sampling_frequency = sorting.get_sampling_frequency()
    session_info_mat_dict = dict(
        sessionInfo=dict(rates=dict(wideband=sampling_frequency))
    )
    scipy.io.savemat(file_name=session_info_save_path, mdict=session_info_mat_dict)

    # Write the spikes struct: unit IDs plus spike times converted
    # from frame indices to seconds.
    spikes_mat_dict = dict(
        spikes=dict(
            UID=sorting.get_unit_ids(),
            times=[
                [[y / sampling_frequency] for y in x]
                for x in sorting.get_units_spike_train()
            ],
        )
    )
    # If, in the future, it is ever desired to allow this to write unit properties, they must conform
    # to the format here: https://cellexplorer.org/datastructure/data-structure-and-format/
    scipy.io.savemat(file_name=spikes_save_path, mdict=spikes_mat_dict)
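
A minimal usage sketch, assuming the old spikeextractors API (NumpySortingExtractor); the output directory and the "my_sorting" id are hypothetical:

import numpy as np
from pathlib import Path
from spikeextractors import NumpySortingExtractor

# Hypothetical synthetic sorting: ~200 random spike frames over one
# minute at 30 kHz, assigned to two units labeled 1 and 2.
sorting = NumpySortingExtractor()
sorting.set_times_labels(
    times=np.sort(np.random.randint(0, 30000 * 60, size=200)),
    labels=np.random.randint(1, 3, size=200),
)
sorting.set_sampling_frequency(30000.0)

# The file name must match <sorting_id>.spikes.cellinfo.mat exactly.
write_sorting(sorting, Path("output") / "my_sorting.spikes.cellinfo.mat")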
Example #2
from typing import Union

import numpy as np
import spikeextractors as se


def prepare_snippets_h5_from_extractors(recording: se.RecordingExtractor,
                                        sorting: se.SortingExtractor,
                                        output_h5_path: str,
                                        start_frame,
                                        end_frame,
                                        max_neighborhood_size: int,
                                        max_events_per_unit: Union[None, int] = None,
                                        snippet_len=(50, 80)):
    import h5py
    from labbox_ephys import (SubsampledSortingExtractor,
                              find_unit_neighborhoods, find_unit_peak_channels,
                              get_unit_waveforms)
    # Optionally restrict both extractors to the requested frame window.
    if start_frame is not None:
        recording = se.SubRecordingExtractor(parent_recording=recording,
                                             start_frame=start_frame,
                                             end_frame=end_frame)
        sorting = se.SubSortingExtractor(parent_sorting=sorting,
                                         start_frame=start_frame,
                                         end_frame=end_frame)

    unit_ids = sorting.get_unit_ids()
    samplerate = recording.get_sampling_frequency()

    # Use this optimized function rather than spiketoolkit's version for
    # efficiency with long recordings and/or many channels, units, or spikes.
    # We should submit this to the spiketoolkit project as a PR.
    print('Subsampling sorting')
    if max_events_per_unit is not None:
        sorting_subsampled = SubsampledSortingExtractor(
            parent_sorting=sorting,
            max_events_per_unit=max_events_per_unit,
            method='random')
    else:
        sorting_subsampled = sorting
    print('Finding unit peak channels')
    peak_channels_by_unit = find_unit_peak_channels(recording=recording,
                                                    sorting=sorting,
                                                    unit_ids=unit_ids)
    print('Finding unit neighborhoods')
    channel_ids_by_unit = find_unit_neighborhoods(
        recording=recording,
        peak_channels_by_unit=peak_channels_by_unit,
        max_neighborhood_size=max_neighborhood_size)
    print(f'Getting unit waveforms for {len(unit_ids)} units')
    unit_waveforms = get_unit_waveforms(
        recording=recording,
        sorting=sorting_subsampled,
        unit_ids=unit_ids,
        channel_ids_by_unit=channel_ids_by_unit,
        snippet_len=snippet_len)
    # unit_waveforms = st.postprocessing.get_unit_waveforms(
    #     recording=recording,
    #     sorting=sorting,
    #     unit_ids=unit_ids,
    #     ms_before=1,
    #     ms_after=1.5,
    #     max_spikes_per_unit=500
    # )

    # Write global metadata plus per-unit spike trains and waveform
    # snippets into the output HDF5 file.
    save_path = output_h5_path
    with h5py.File(save_path, 'w') as f:
        f.create_dataset('unit_ids', data=np.array(unit_ids).astype(np.int32))
        f.create_dataset('sampling_frequency',
                         data=np.array([samplerate]).astype(np.float64))
        f.create_dataset('channel_ids',
                         data=np.array(recording.get_channel_ids()))
        f.create_dataset('num_frames',
                         data=np.array([recording.get_num_frames()
                                        ]).astype(np.int32))
        channel_locations = recording.get_channel_locations()
        f.create_dataset('channel_locations',
                         data=np.array(channel_locations))
        for ii, unit_id in enumerate(unit_ids):
            x = sorting.get_unit_spike_train(unit_id=unit_id)
            f.create_dataset(f'unit_spike_trains/{unit_id}',
                             data=np.array(x).astype(np.float64))
            f.create_dataset(f'unit_waveforms/{unit_id}/waveforms',
                             data=unit_waveforms[ii].astype(np.float32))
            f.create_dataset(
                f'unit_waveforms/{unit_id}/channel_ids',
                data=np.array(channel_ids_by_unit[int(unit_id)]).astype(int))
            f.create_dataset(f'unit_waveforms/{unit_id}/spike_train',
                             data=np.array(
                                 sorting_subsampled.get_unit_spike_train(
                                     unit_id=unit_id)).astype(np.float64))
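
To sanity-check the output, the file can be read back with h5py. A minimal sketch, assuming a hypothetical 'snippets.h5' produced by the function above:

import h5py

# Hypothetical path; assumes prepare_snippets_h5_from_extractors ran first.
with h5py.File('snippets.h5', 'r') as f:
    unit_ids = f['unit_ids'][:]
    samplerate = f['sampling_frequency'][0]
    print(f'{len(unit_ids)} units at {samplerate} Hz')
    uid = unit_ids[0]
    waveforms = f[f'unit_waveforms/{uid}/waveforms'][:]
    # Shape depends on get_unit_waveforms; typically
    # (num_events, num_neighborhood_channels, snippet_len[0] + snippet_len[1]).
    print('waveforms shape:', waveforms.shape)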
Example #3
from typing import Union

import numpy as np
import spikeextractors as se


def prepare_snippets_nwb_from_extractors(
        recording: se.RecordingExtractor,
        sorting: se.SortingExtractor,
        nwb_file_path: str,
        nwb_object_prefix: str,
        start_frame,
        end_frame,
        max_neighborhood_size: int,
        max_events_per_unit: Union[None, int] = None,
        snippet_len=(50, 80),
):
    import pynwb
    from labbox_ephys import (SubsampledSortingExtractor,
                              find_unit_neighborhoods, find_unit_peak_channels,
                              get_unit_waveforms)
    # Optionally restrict both extractors to the requested frame window.
    if start_frame is not None:
        recording = se.SubRecordingExtractor(parent_recording=recording,
                                             start_frame=start_frame,
                                             end_frame=end_frame)
        sorting = se.SubSortingExtractor(parent_sorting=sorting,
                                         start_frame=start_frame,
                                         end_frame=end_frame)

    unit_ids = sorting.get_unit_ids()
    samplerate = recording.get_sampling_frequency()

    # Use this optimized function rather than spiketoolkit's version for
    # efficiency with long recordings and/or many channels, units, or spikes.
    # We should submit this to the spiketoolkit project as a PR.
    print('Subsampling sorting')
    if max_events_per_unit is not None:
        sorting_subsampled = SubsampledSortingExtractor(
            parent_sorting=sorting,
            max_events_per_unit=max_events_per_unit,
            method='random')
    else:
        sorting_subsampled = sorting
    print('Finding unit peak channels')
    peak_channels_by_unit = find_unit_peak_channels(recording=recording,
                                                    sorting=sorting,
                                                    unit_ids=unit_ids)
    print('Finding unit neighborhoods')
    channel_ids_by_unit = find_unit_neighborhoods(
        recording=recording,
        peak_channels_by_unit=peak_channels_by_unit,
        max_neighborhood_size=max_neighborhood_size)
    print(f'Getting unit waveforms for {len(unit_ids)} units')
    unit_waveforms = get_unit_waveforms(
        recording=recording,
        sorting=sorting_subsampled,
        unit_ids=unit_ids,
        channel_ids_by_unit=channel_ids_by_unit,
        snippet_len=snippet_len)
    # unit_waveforms = st.postprocessing.get_unit_waveforms(
    #     recording=recording,
    #     sorting=sorting,
    #     unit_ids=unit_ids,
    #     ms_before=1,
    #     ms_after=1.5,
    #     max_spikes_per_unit=500
    # )
    # Append the snippets to an existing NWB file's scratch space,
    # prefixing each object name with nwb_object_prefix.
    with pynwb.NWBHDF5IO(path=nwb_file_path, mode='a') as io:
        nwbf = io.read()
        nwbf.add_scratch(name=f'{nwb_object_prefix}_unit_ids',
                         data=np.array(unit_ids).astype(np.int32),
                         notes='sorted waveform unit ids')
        nwbf.add_scratch(name=f'{nwb_object_prefix}_sampling_frequency',
                         data=np.array([samplerate]).astype(np.float64),
                         notes='sorted waveform sampling frequency')
        nwbf.add_scratch(name=f'{nwb_object_prefix}_channel_ids',
                         data=np.array(recording.get_channel_ids()),
                         notes='sorted waveform channel ids')
        nwbf.add_scratch(name=f'{nwb_object_prefix}_num_frames',
                         data=np.array([recording.get_num_frames()
                                        ]).astype(np.int32),
                         notes='sorted waveform number of frames')
        channel_locations = recording.get_channel_locations()
        nwbf.add_scratch(name=f'{nwb_object_prefix}_channel_locations',
                         data=np.array(channel_locations),
                         notes='sorted waveform channel locations')
        for ii, unit_id in enumerate(unit_ids):
            x = sorting.get_unit_spike_train(unit_id=unit_id)
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_spike_trains',
                data=np.array(x).astype(np.float64),
                notes=f'sorted spike trains for unit {unit_id}')
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_waveforms',
                data=unit_waveforms[ii].astype(np.float32),
                notes=f'sorted waveforms for unit {unit_id}')
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_channel_ids',
                data=np.array(channel_ids_by_unit[int(unit_id)]).astype(int),
                notes=f'sorted channel ids for unit {unit_id}')
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_sub_spike_train',
                data=np.array(
                    sorting_subsampled.get_unit_spike_train(
                        unit_id=unit_id)).astype(np.float64),
                notes=f'sorted subsampled spike train for unit {unit_id}')
        io.write(nwbf)
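
A read-back sketch using pynwb's scratch space; 'analysis.nwb' and the 'snip' prefix are hypothetical, and the file must already contain an NWB session (the writer opens it in append mode and reads it first):

import pynwb

# Hypothetical file and prefix; assumes the writer above ran already.
with pynwb.NWBHDF5IO(path='analysis.nwb', mode='r') as io:
    nwbf = io.read()
    unit_ids = nwbf.get_scratch('snip_unit_ids')
    for unit_id in unit_ids:
        waveforms = nwbf.get_scratch(f'snip_unit_{unit_id}_waveforms')
        print(unit_id, waveforms.shape)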