Example #1
def from_memory(recording: se.RecordingExtractor,
                serialize=False,
                serialize_dtype=None):
    if serialize:
        if serialize_dtype is None:
            raise Exception(
                'You must specify serialize_dtype when serializing a recording extractor in from_memory()'
            )
        with hi.TemporaryDirectory() as tmpdir:
            fname = tmpdir + '/' + _random_string(10) + '_recording.mda'
            se.BinDatRecordingExtractor.write_recording(
                recording=recording,
                save_path=fname,
                time_axis=0,
                dtype=serialize_dtype)
            with ka.config(use_hard_links=True):
                uri = ka.store_file(fname, basename='raw.mda')
            num_channels = recording.get_num_channels()
            channel_ids = [int(a) for a in recording.get_channel_ids()]
            xcoords = [
                recording.get_channel_property(a, 'location')[0]
                for a in channel_ids
            ]
            ycoords = [
                recording.get_channel_property(a, 'location')[1]
                for a in channel_ids
            ]
            recording = LabboxEphysRecordingExtractor({
                'recording_format': 'bin1',
                'data': {
                    'raw': uri,
                    'raw_num_channels': num_channels,
                    'num_frames': int(recording.get_num_frames()),
                    'samplerate': float(recording.get_sampling_frequency()),
                    'channel_ids': channel_ids,
                    'channel_map': dict(
                        zip([str(c) for c in channel_ids],
                            list(range(num_channels)))),
                    'channel_positions': dict(
                        zip([str(c) for c in channel_ids],
                            [[float(xcoords[i]), float(ycoords[i])]
                             for i in range(num_channels)]))
                }
            })
            return recording
    obj = {
        'recording_format': 'in_memory',
        'data': register_in_memory_object(recording)
    }
    return LabboxEphysRecordingExtractor(obj)
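A minimal usage sketch (not from the source): it assumes se = spikeextractors, a working hither/kachery setup for the serialize path, and channels with a 'location' property, which from_memory() reads when serialize=True.

import numpy as np
import spikeextractors as se

traces = np.random.randn(4, 30000).astype(np.float32)  # 4 channels, 1 s at 30 kHz
rec = se.NumpyRecordingExtractor(timeseries=traces, sampling_frequency=30000)
for ch in rec.get_channel_ids():
    # from_memory() reads the 'location' property when serialize=True
    rec.set_channel_property(ch, 'location', [float(ch), 0.0])

le_rec = from_memory(rec, serialize=True, serialize_dtype=np.float32)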
Example #2
def _get_geom_from_recording(recording: se.RecordingExtractor):
    channel_ids = recording.get_channel_ids()
    M = len(channel_ids)
    location0 = recording.get_channel_property(channel_ids[0], 'location')
    nd = len(location0)
    geom = np.zeros((M, nd))
    for i in range(M):
        location_i = recording.get_channel_property(channel_ids[i], 'location')
        geom[i, :] = location_i
    return geom
Example #3
def estimate_noise_level(recording: se.RecordingExtractor):
    N = recording.get_num_frames()
    samplerate = recording.get_sampling_frequency()
    start_frame = 0
    end_frame = int(np.minimum(samplerate * 1, N))
    X = recording.get_traces(
        channel_ids=[int(id) for id in recording.get_channel_ids()],
        start_frame=start_frame,
        end_frame=end_frame)
    # median absolute deviation (MAD) estimate of the standard deviation
    est_noise_level = np.median(np.abs(X.squeeze())) / 0.6745
    if est_noise_level == 0:
        est_noise_level = 1
    return est_noise_level
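The 0.6745 factor is the median of |X| for a standard normal (the inverse CDF at 0.75), so median(|X|) / 0.6745 robustly recovers the noise standard deviation: spikes barely shift the median. A quick check on synthetic Gaussian noise (a sketch, assuming se = spikeextractors):

import numpy as np
import spikeextractors as se

sigma = 5.0
traces = (sigma * np.random.randn(4, 30000)).astype(np.float32)
rec = se.NumpyRecordingExtractor(timeseries=traces, sampling_frequency=30000)
print(estimate_noise_level(rec))  # approximately 5.0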
Example #4
def _get_geom_from_recording(recording: se.RecordingExtractor) -> np.ndarray:
    """Retrieve the electrode locations for a recording

    Parameters
    ----------
    recording : se.RecordingExtractor
        Recording extractor (SpikeInterface)

    Returns
    -------
    np.ndarray
        Geom array (M x D) where the dimension D is either 2 or 3
    """
    channel_ids = recording.get_channel_ids()
    M = len(channel_ids)
    location0 = recording.get_channel_property(channel_ids[0], 'location')
    nd = len(location0)
    geom = np.zeros((M, nd))
    for ii, ch_id in enumerate(channel_ids):
        location_ii = recording.get_channel_property(ch_id, 'location')
        geom[ii, :] = list(location_ii)
    return geom
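A usage sketch (hypothetical, assuming se.NumpyRecordingExtractor accepts a geom argument that populates the per-channel 'location' property):

import numpy as np
import spikeextractors as se

geom_in = np.array([[0.0, 0.0], [0.0, 20.0], [0.0, 40.0]])  # 3 channels, 2-D
rec = se.NumpyRecordingExtractor(timeseries=np.zeros((3, 1000), dtype=np.float32),
                                 sampling_frequency=30000, geom=geom_in)
geom_out = _get_geom_from_recording(rec)  # (3, 2) array matching geom_in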
Example #5
def _get_neighborhoods(*, recording: se.RecordingExtractor, opts: EphysNlmV1Opts) -> List[Dict]:
    """Get a list of neighborhoods from a recording extractor based on the ephys_nlm options

    Parameters
    ----------
    recording : se.RecordingExtractor
        Recording extractor (SpikeInterface)
    opts : EphysNlmV1Opts
        Denoising options

    Returns
    -------
    List[Dict]
        List of dictionaries representing neighborhoods. Each dictionary contains information about the neighborhood, such as number of channels.
    """
    M = len(recording.get_channel_ids())
    if opts.multi_neighborhood is False:
        # A single neighborhood
        return [
            dict(
                channel_indices=np.arange(M),
                target_indices=np.arange(M)
            )
        ]
    geom: np.ndarray = _get_geom_from_recording(recording=recording)
    adjacency_radius = opts.neighborhood_adjacency_radius
    assert adjacency_radius is not None, 'You need to provide neighborhood_adjacency_radius when multi_neighborhood is True'
    ret = []
    for m in range(M):
        channel_indices = _get_channel_neighborhood(
            m=m, geom=geom, adjacency_radius=adjacency_radius)
        ret.append(dict(
            channel_indices=channel_indices,
            target_indices=[m]
        ))
    return ret
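A hypothetical sketch of the two modes, given a recording extractor rec whose channels have 'location' properties (e.g. as constructed in the sketch after Example #4). Only the attribute names read above (multi_neighborhood, neighborhood_adjacency_radius) come from the source; the EphysNlmV1Opts construction is an assumption.

opts = EphysNlmV1Opts(multi_neighborhood=True, neighborhood_adjacency_radius=60.0)
neighborhoods = _get_neighborhoods(recording=rec, opts=opts)
# One dict per channel m: channel_indices holds the channels within 60 (geom
# units) of channel m, and target_indices == [m], so each channel is denoised
# from its own local neighborhood. With multi_neighborhood=False a single
# neighborhood covering all M channels is returned instead.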
Example #6
def prepare_snippets_h5_from_extractors(recording: se.RecordingExtractor,
                                        sorting: se.SortingExtractor,
                                        output_h5_path: str,
                                        start_frame,
                                        end_frame,
                                        max_neighborhood_size: int,
                                        max_events_per_unit: Union[None, int] = None,
                                        snippet_len=(50, 80)):
    import h5py
    from labbox_ephys import (SubsampledSortingExtractor,
                              find_unit_neighborhoods, find_unit_peak_channels,
                              get_unit_waveforms)
    if start_frame is not None:
        recording = se.SubRecordingExtractor(parent_recording=recording,
                                             start_frame=start_frame,
                                             end_frame=end_frame)
        sorting = se.SubSortingExtractor(parent_sorting=sorting,
                                         start_frame=start_frame,
                                         end_frame=end_frame)

    unit_ids = sorting.get_unit_ids()
    samplerate = recording.get_sampling_frequency()

    # Use this optimized function rather than spiketoolkit's version for
    # efficiency with long recordings and/or many channels, units, or spikes.
    # We should submit this to the spiketoolkit project as a PR.
    print('Subsampling sorting')
    if max_events_per_unit is not None:
        sorting_subsampled = SubsampledSortingExtractor(
            parent_sorting=sorting,
            max_events_per_unit=max_events_per_unit,
            method='random')
    else:
        sorting_subsampled = sorting
    print('Finding unit peak channels')
    peak_channels_by_unit = find_unit_peak_channels(recording=recording,
                                                    sorting=sorting,
                                                    unit_ids=unit_ids)
    print('Finding unit neighborhoods')
    channel_ids_by_unit = find_unit_neighborhoods(
        recording=recording,
        peak_channels_by_unit=peak_channels_by_unit,
        max_neighborhood_size=max_neighborhood_size)
    print(f'Getting unit waveforms for {len(unit_ids)} units')
    unit_waveforms = get_unit_waveforms(
        recording=recording,
        sorting=sorting_subsampled,
        unit_ids=unit_ids,
        channel_ids_by_unit=channel_ids_by_unit,
        snippet_len=snippet_len)
    # unit_waveforms = st.postprocessing.get_unit_waveforms(
    #     recording=recording,
    #     sorting=sorting,
    #     unit_ids=unit_ids,
    #     ms_before=1,
    #     ms_after=1.5,
    #     max_spikes_per_unit=500
    # )

    save_path = output_h5_path
    with h5py.File(save_path, 'w') as f:
        f.create_dataset('unit_ids', data=np.array(unit_ids).astype(np.int32))
        f.create_dataset('sampling_frequency',
                         data=np.array([samplerate]).astype(np.float64))
        f.create_dataset('channel_ids',
                         data=np.array(recording.get_channel_ids()))
        f.create_dataset('num_frames',
                         data=np.array([recording.get_num_frames()]).astype(np.int32))
        channel_locations = recording.get_channel_locations()
        f.create_dataset('channel_locations',
                         data=np.array(channel_locations))
        for ii, unit_id in enumerate(unit_ids):
            x = sorting.get_unit_spike_train(unit_id=unit_id)
            f.create_dataset(f'unit_spike_trains/{unit_id}',
                             data=np.array(x).astype(np.float64))
            f.create_dataset(f'unit_waveforms/{unit_id}/waveforms',
                             data=unit_waveforms[ii].astype(np.float32))
            f.create_dataset(
                f'unit_waveforms/{unit_id}/channel_ids',
                data=np.array(channel_ids_by_unit[int(unit_id)]).astype(int))
            f.create_dataset(f'unit_waveforms/{unit_id}/spike_train',
                             data=np.array(
                                 sorting_subsampled.get_unit_spike_train(
                                     unit_id=unit_id)).astype(np.float64))
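The HDF5 layout written above can be read back as follows. A sketch: the group and dataset names come from the writer, while the waveform array shape of (num_events, num_channels_in_neighborhood, snippet_length) is an assumption about what get_unit_waveforms returns.

import h5py

with h5py.File('snippets.h5', 'r') as f:
    unit_ids = f['unit_ids'][:]
    samplerate = f['sampling_frequency'][0]
    channel_locations = f['channel_locations'][:]
    for unit_id in unit_ids:
        spike_train = f[f'unit_spike_trains/{unit_id}'][:]
        waveforms = f[f'unit_waveforms/{unit_id}/waveforms'][:]
        neighborhood = f[f'unit_waveforms/{unit_id}/channel_ids'][:]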
Example #7
    def check_metadata_write(self, metadata: dict, nwbfile_path: Path, recording: se.RecordingExtractor):
        standard_metadata = get_nwb_metadata(recording=recording)
        device_defaults = dict(  # from the individual add_devices function
            name="Device",
            description="no description"
        )
        electrode_group_defaults = dict(  # from the individual add_electrode_groups function
            name="Electrode Group",
            description="no description",
            location="unknown",
            device="Device"
        )

        with NWBHDF5IO(path=nwbfile_path, mode="r", load_namespaces=True) as io:
            nwbfile = io.read()

            device_source = metadata["Ecephys"].get("Device", standard_metadata["Ecephys"]["Device"])
            self.assertEqual(len(device_source), len(nwbfile.devices))
            for device in device_source:
                device_name = device.get("name", device_defaults["name"])
                self.assertIn(device_name, nwbfile.devices)
                self.assertEqual(
                    device.get("description", device_defaults["description"]), nwbfile.devices[device_name].description
                )
                self.assertEqual(device.get("manufacturer"), nwbfile.devices[device["name"]].manufacturer)

            electrode_group_source = metadata["Ecephys"].get(
                "ElectrodeGroup",
                standard_metadata["Ecephys"]["ElectrodeGroup"]
            )
            self.assertEqual(len(electrode_group_source), len(nwbfile.electrode_groups))
            for group in electrode_group_source:
                group_name = group.get("name", electrode_group_defaults["name"])
                self.assertIn(group_name, nwbfile.electrode_groups)
                self.assertEqual(
                    group.get("description", electrode_group_defaults["description"]),
                    nwbfile.electrode_groups[group_name].description
                )
                self.assertEqual(
                    group.get("location", electrode_group_defaults["location"]),
                    nwbfile.electrode_groups[group_name].location
                )
                device_name = group.get("device", electrode_group_defaults["device"])
                self.assertIn(device_name, nwbfile.devices)
                self.assertEqual(nwbfile.electrode_groups[group_name].device, nwbfile.devices[device_name])

            n_channels = len(recording.get_channel_ids())
            electrode_source = metadata["Ecephys"].get("Electrodes", [])
            self.assertEqual(n_channels, len(nwbfile.electrodes))
            for column in electrode_source:
                column_name = column["name"]
                self.assertIn(column_name, nwbfile.electrodes)
                self.assertEqual(column["description"], getattr(nwbfile.electrodes, column_name).description)
                if column_name in ["x", "y", "z", "rel_x", "rel_y", "rel_z"]:
                    for j in range(n_channels):
                        self.assertEqual(column["data"][j], getattr(nwbfile.electrodes[j], column_name).values[0])
                else:
                    for j in range(n_channels):
                        self.assertTrue(
                            column["data"][j] == getattr(nwbfile.electrodes[j], column_name).values[0]
                            or (
                                    np.isnan(column["data"][j])
                                    and np.isnan(getattr(nwbfile.electrodes[j], column_name).values[0])
                            )
                        )
Example #8
def ephys_nlm_v1(recording: se.RecordingExtractor, *, opts: EphysNlmV1Opts, device: Union[None, str], verbose: int = 1) -> Tuple[OutputRecordingExtractor, EphysNlmV1Info]:
    """Denoise an ephys recording using non-local means.
    
    The input and output recordings are RecordingExtractors from SpikeInterface.

    Parameters
    ----------
    recording : se.RecordingExtractor
        The ephys recording to denoise (see SpikeInterface)
    opts : EphysNlmV1Opts
        Options created using EphysNlmV1Opts(...)
    device : Union[str, None]
        Either cuda or cpu (cuda is highly recommended, but you need to have
        CUDA/PyTorch working on your system). If None, then the EPHYS_NLM_DEVICE
        environment variable will be used.
    verbose : int, optional
        Verbosity level, by default 1

    Returns
    -------
    Tuple[OutputRecordingExtractor, EphysNlmV1Info]
        The output recording extractor (see SpikeInterface)
        and info about the run
    """
    channel_ids = recording.get_channel_ids()
    M = len(channel_ids)
    N = recording.get_num_frames()
    T = opts.clip_size
    assert T % 2 == 0, 'clip size must be divisible by 2.'
    info = EphysNlmV1Info()
    info.recording = recording
    info.opts = opts
    info.start_time = time.time()
    if opts.block_size is None:
        if opts.block_size_sec is None:
            raise Exception('block_size and block_size_sec are both None')
        opts.block_size = int(
            recording.get_sampling_frequency() * opts.block_size_sec)
    block_size = opts.block_size
    num_blocks = max(1, math.floor(N / block_size))
    assert opts.sigma == 'auto', 'Only sigma=auto allowed at this time'
    assert opts.whitening == 'auto', 'Only whitening=auto allowed at this time'

    if device is None:
        device = os.getenv('EPHYS_NLM_DEVICE', None)
        if device is None or device == '':
            print('Warning: EPHYS_NLM_DEVICE not set -- defaulting to cpu. To use GPU, set EPHYS_NLM_DEVICE=cuda')
            device = 'cpu'
    elif device == 'cpu':
        print('Using device=cpu. Warning: GPU is much faster. To use GPU, set device=cuda')
    if device == 'cuda':
        assert torch.cuda.is_available(), 'Cannot use device=cuda. PyTorch/CUDA is not configured properly -- torch.cuda.is_available() is returning False.'
        print('Using device=cuda')
    elif device == 'cpu':
        print('Using device=cpu')
    else:
        raise Exception(f'Invalid device: {device}') # pragma: no cover
    info.device = device
    opts._device = device  # for convenience

    neighborhoods = _get_neighborhoods(recording=recording, opts=opts)

    if verbose >= 1:
        print(f'Denoising recording of size {M} x {N} using {len(neighborhoods)} neighborhoods and {num_blocks} time blocks')

    initial_traces = recording.get_traces(
        start_frame=0, end_frame=min(N, block_size))
    _estimate_sigma_and_whitening(
        traces=initial_traces, neighborhoods=neighborhoods, opts=opts, verbose=verbose)

    recording_out = OutputRecordingExtractor(
        base_recording=recording, block_size=opts.block_size)

    for ii in range(num_blocks):
        if verbose >= 1:
            print(f'Denoising block {ii} of {num_blocks}')

        t1 = ii * block_size
        t2 = t1 + block_size
        # The last block is potentially larger
        if ii == num_blocks - 1:
            t2 = N

        block_traces = recording.get_traces(start_frame=t1, end_frame=t2)
        block_traces_denoised, block_info = _denoise_block(
            traces=block_traces,
            opts=opts,
            neighborhoods=neighborhoods,
            verbose=verbose
        )
        info.blocks.append(block_info)
        recording_out.add_block(block_traces_denoised)

    info.end_time = time.time()
    info.elapsed_time = info.end_time - info.start_time

    return recording_out, info
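An end-to-end sketch, given a recording extractor named recording. The EphysNlmV1Opts keyword names below mirror the attributes read by ephys_nlm_v1 above, but the constructor signature itself is an assumption.

opts = EphysNlmV1Opts(
    sigma='auto',            # only 'auto' is allowed at this time
    whitening='auto',        # likewise
    clip_size=30,            # must be even (asserted above)
    block_size_sec=30,       # block_size is derived from the sampling frequency
    multi_neighborhood=False
)
recording_denoised, info = ephys_nlm_v1(recording, opts=opts, device='cpu', verbose=1)
print(f'Denoised {len(info.blocks)} blocks in {info.elapsed_time:.1f} s')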
Example #9
def prepare_snippets_nwb_from_extractors(
        recording: se.RecordingExtractor,
        sorting: se.SortingExtractor,
        nwb_file_path: str,
        nwb_object_prefix: str,
        start_frame,
        end_frame,
        max_neighborhood_size: int,
        max_events_per_unit: Union[None, int] = None,
        snippet_len=(50, 80),
):
    import pynwb
    from labbox_ephys import (SubsampledSortingExtractor,
                              find_unit_neighborhoods, find_unit_peak_channels,
                              get_unit_waveforms)
    if start_frame is not None:
        recording = se.SubRecordingExtractor(parent_recording=recording,
                                             start_frame=start_frame,
                                             end_frame=end_frame)
        sorting = se.SubSortingExtractor(parent_sorting=sorting,
                                         start_frame=start_frame,
                                         end_frame=end_frame)

    unit_ids = sorting.get_unit_ids()
    samplerate = recording.get_sampling_frequency()

    # Use this optimized function rather than spiketoolkit's version for
    # efficiency with long recordings and/or many channels, units, or spikes.
    # We should submit this to the spiketoolkit project as a PR.
    print('Subsampling sorting')
    if max_events_per_unit is not None:
        sorting_subsampled = SubsampledSortingExtractor(
            parent_sorting=sorting,
            max_events_per_unit=max_events_per_unit,
            method='random')
    else:
        sorting_subsampled = sorting
    print('Finding unit peak channels')
    peak_channels_by_unit = find_unit_peak_channels(recording=recording,
                                                    sorting=sorting,
                                                    unit_ids=unit_ids)
    print('Finding unit neighborhoods')
    channel_ids_by_unit = find_unit_neighborhoods(
        recording=recording,
        peak_channels_by_unit=peak_channels_by_unit,
        max_neighborhood_size=max_neighborhood_size)
    print(f'Getting unit waveforms for {len(unit_ids)} units')
    unit_waveforms = get_unit_waveforms(
        recording=recording,
        sorting=sorting_subsampled,
        unit_ids=unit_ids,
        channel_ids_by_unit=channel_ids_by_unit,
        snippet_len=snippet_len)
    # unit_waveforms = st.postprocessing.get_unit_waveforms(
    #     recording=recording,
    #     sorting=sorting,
    #     unit_ids=unit_ids,
    #     ms_before=1,
    #     ms_after=1.5,
    #     max_spikes_per_unit=500
    # )
    with pynwb.NWBHDF5IO(path=nwb_file_path, mode='a') as io:
        nwbf = io.read()
        nwbf.add_scratch(name=f'{nwb_object_prefix}_unit_ids',
                         data=np.array(unit_ids).astype(np.int32),
                         notes='sorted waveform unit ids')
        nwbf.add_scratch(name=f'{nwb_object_prefix}_sampling_frequency',
                         data=np.array([samplerate]).astype(np.float64),
                         notes='sorted waveform sampling frequency')
        nwbf.add_scratch(name=f'{nwb_object_prefix}_channel_ids',
                         data=np.array(recording.get_channel_ids()),
                         notes='sorted waveform channel ids')
        nwbf.add_scratch(name=f'{nwb_object_prefix}_num_frames',
                         data=np.array([recording.get_num_frames()]).astype(np.int32),
                         notes='sorted waveform number of frames')
        channel_locations = recording.get_channel_locations()
        nwbf.add_scratch(name=f'{nwb_object_prefix}_channel_locations',
                         data=np.array(channel_locations),
                         notes='sorted waveform channel locations')
        for ii, unit_id in enumerate(unit_ids):
            x = sorting.get_unit_spike_train(unit_id=unit_id)
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_spike_trains',
                data=np.array(x).astype(np.float64),
                notes=f'sorted spike trains for unit {unit_id}')
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_waveforms',
                data=unit_waveforms[ii].astype(np.float32),
                notes=f'sorted waveforms for unit {unit_id}')
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_channel_ids',
                data=np.array(channel_ids_by_unit[int(unit_id)]).astype(int),
                notes=f'sorted channel ids for unit {unit_id}')
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_sub_spike_train',
                data=np.array(
                    sorting_subsampled.get_unit_spike_train(
                        unit_id=unit_id)).astype(np.float64),
                notes=f'sorted subsampled spike train for unit {unit_id}')
        io.write(nwbf)
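A read-back sketch for the scratch objects written above (hypothetical: the scratch names come from the writer, and NWBFile.get_scratch is assumed available in the pynwb version in use):

import pynwb

with pynwb.NWBHDF5IO(path=nwb_file_path, mode='r') as io:
    nwbf = io.read()
    unit_ids = nwbf.get_scratch(f'{nwb_object_prefix}_unit_ids')
    for unit_id in unit_ids:
        waveforms = nwbf.get_scratch(f'{nwb_object_prefix}_unit_{unit_id}_waveforms')
        spike_train = nwbf.get_scratch(f'{nwb_object_prefix}_unit_{unit_id}_spike_trains')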