Example #1
import numpy as np
from spikeextractors import RecordingExtractor  # assumed source of RecordingExtractor


def intervals_from_traces(recording: RecordingExtractor):
    """Extract interval times from TTL pulses."""
    traces = recording.get_traces(channel_ids=[1, 2])
    sf = recording.get_sampling_frequency()

    ttls = []
    states = []
    for tr in traces:
        threshold = np.ptp(tr) / 2 + np.min(tr)
        crossings = np.array(tr > threshold).astype("int8")

        rising = np.nonzero(np.diff(crossings, 1) > 0)[0]
        falling = np.nonzero(np.diff(crossings, 1) < 0)[0]

        ttl = np.concatenate((rising, falling))
        sort_order = np.argsort(ttl)
        ttl = ttl[sort_order]
        state = np.array([1] * len(rising) + [-1] * len(falling))[sort_order]

        ttls.append(ttl)
        states.append(state)

    conditions = []
    for ttl, state in zip(ttls, states):
        assert len(ttl[state == 1]) == len(ttl[state == -1]), \
            "Different number of rising/falling edges!"
        condition = np.zeros((len(ttl[state == 1]), 2), dtype="int")

        condition[:, 0] = ttl[state == 1] / sf
        condition[:, 1] = ttl[state == -1] / sf

        conditions.append(condition)

    return conditions
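
A minimal usage sketch (hypothetical: the synthetic square-wave traces and the NumpyRecordingExtractor construction are illustrative assumptions, not part of the original example):

import numpy as np
import spikeextractors as se

sf = 30000.0
t = np.arange(int(sf * 2.5))
# 0.5 s on/off square wave that starts low and ends low, so the rising and
# falling edge counts match on channels 1 and 2
ttl_wave = ((t // int(sf * 0.5)) % 2).astype("float32")
traces = np.vstack([np.zeros_like(ttl_wave), ttl_wave, ttl_wave])
rec = se.NumpyRecordingExtractor(timeseries=traces, sampling_frequency=sf)

conditions = intervals_from_traces(rec)
print(conditions[0])  # [start, stop] pairs in seconds for channel 1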
Example #2
def from_memory(recording: se.RecordingExtractor,
                serialize=False,
                serialize_dtype=None):
    if serialize:
        if serialize_dtype is None:
            raise Exception(
                'You must specify the serialize_dtype when serializing '
                'recording extractor in from_memory()')
        with hi.TemporaryDirectory() as tmpdir:
            fname = tmpdir + '/' + _random_string(10) + '_recording.mda'
            se.BinDatRecordingExtractor.write_recording(
                recording=recording,
                save_path=fname,
                time_axis=0,
                dtype=serialize_dtype)
            with ka.config(use_hard_links=True):
                uri = ka.store_file(fname, basename='raw.mda')
            num_channels = recording.get_num_channels()
            channel_ids = [int(a) for a in recording.get_channel_ids()]
            xcoords = [recording.get_channel_property(a, 'location')[0]
                       for a in channel_ids]
            ycoords = [recording.get_channel_property(a, 'location')[1]
                       for a in channel_ids]
            recording = LabboxEphysRecordingExtractor({
                'recording_format': 'bin1',
                'data': {
                    'raw': uri,
                    'raw_num_channels': num_channels,
                    'num_frames': int(recording.get_num_frames()),
                    'samplerate': float(recording.get_sampling_frequency()),
                    'channel_ids': channel_ids,
                    'channel_map': dict(zip(
                        [str(c) for c in channel_ids],
                        [int(i) for i in range(num_channels)])),
                    'channel_positions': dict(zip(
                        [str(c) for c in channel_ids],
                        [[float(xcoords[i]), float(ycoords[i])]
                         for i in range(num_channels)]))
                }
            })
            return recording
    obj = {
        'recording_format': 'in_memory',
        'data': register_in_memory_object(recording)
    }
    return LabboxEphysRecordingExtractor(obj)
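
A hedged call sketch (assumes an existing se.RecordingExtractor plus the hither/kachery/labbox_ephys environment this helper depends on):

# Serialize an in-memory recording so it can be shared via kachery storage.
le_recording = from_memory(recording, serialize=True, serialize_dtype='int16')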
Example #3
def estimate_noise_level(recording: se.RecordingExtractor):
    N = recording.get_num_frames()
    samplerate = recording.get_sampling_frequency()
    start_frame = 0
    end_frame = int(np.minimum(samplerate * 1, N))
    X = recording.get_traces(
        channel_ids=[int(id) for id in recording.get_channel_ids()],
        start_frame=start_frame,
        end_frame=end_frame)
    # median absolute deviation (MAD) estimate of the stdev
    est_noise_level = np.median(np.abs(X.squeeze())) / 0.6745
    if est_noise_level == 0:
        est_noise_level = 1
    return est_noise_level
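
The 0.6745 factor comes from the Gaussian distribution, where median(|x|) ≈ 0.6745·σ. A quick sanity check on synthetic data:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(0, 5.0, size=1_000_000)
print(np.median(np.abs(x)) / 0.6745)  # ≈ 5.0, the true stdev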
Example #4
    def _run(self, recording: se.RecordingExtractor, output_folder: Path):
        dataset_dir = output_folder / 'ironclust_dataset'
        source_dir = Path(__file__).parent

        samplerate = recording.get_sampling_frequency()

        num_channels = recording.get_num_channels()
        num_timepoints = recording.get_num_frames()
        duration_minutes = num_timepoints / samplerate / 60
        if self.verbose:
            print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.format(
                num_channels, num_timepoints, duration_minutes))

        if self.verbose:
            print('Creating argfile.txt...')
        txt = ''
        for key0, val0 in self.params.items():
            txt += '{}={}\n'.format(key0, val0)
        txt += 'samplerate={}\n'.format(samplerate)
        with (dataset_dir / 'argfile.txt').open('w') as f:
            f.write(txt)

        tmpdir = output_folder / 'tmp'
        os.makedirs(str(tmpdir), exist_ok=True)
        if self.verbose:
            print('Running ironclust in {tmpdir}...'.format(tmpdir=str(tmpdir)))

        shell_cmd = '''
            #!/bin/bash
            cd {tmpdir}
            /run_irc {dataset_dir} {tmpdir} {dataset_dir}/argfile.txt
        '''.format(tmpdir=str(tmpdir), dataset_dir=str(dataset_dir))

        shell_script = ShellScript(shell_cmd)
        shell_script.start()

        retcode = shell_script.wait()

        if retcode != 0:
            raise Exception('ironclust returned a non-zero exit code')

        result_fname = str(tmpdir / 'firings.mda')
        if not os.path.exists(result_fname):
            raise Exception('Result file does not exist: ' + result_fname)

        samplerate_fname = str(tmpdir / 'samplerate.txt')
        with open(samplerate_fname, 'w') as f:
            f.write('{}'.format(samplerate))
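
These _run() methods are invoked by the sorter framework rather than called directly. A hedged sketch of the typical entry point (the spikesorters wrapper and the toy recording are assumptions, not part of this example):

import numpy as np
import spikeextractors as se
import spikesorters as ss

# Toy recording; real use would load an actual dataset with channel geometry.
traces = np.random.randn(4, 30000).astype('float32')
rec = se.NumpyRecordingExtractor(timeseries=traces, sampling_frequency=30000)
sorting = ss.run_ironclust(rec, output_folder='/tmp/irc_output', verbose=True)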
Example #5
def __init__(self, *, recording: se.RecordingExtractor, freq_min, freq_max,
             freq_wid):
    self._padding = 3000
    target_ram = 100 * 1000 * 1000
    target_chunk_size = math.ceil(
        min(recording.get_sampling_frequency() * 30,
            target_ram / (recording.get_num_channels() * 4)))
    # The FFTs must have size 2^x, so pick chunk_size = 2^x - 2*padding:
    # each chunk plus its padding on both sides is then a power of two.
    chunk_size = int(2**math.ceil(np.log2(target_chunk_size)) -
                     self._padding * 2)
    FilterRecording.__init__(self,
                             recording=recording,
                             chunk_size=chunk_size)
    self._params = dict(name='bandpass_filter',
                        freq_min=freq_min,
                        freq_max=freq_max,
                        freq_wid=freq_wid)
    self._recording = recording
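
The chunk-size arithmetic can be sanity-checked in isolation (the 30 s at 30 kHz target below is an illustrative assumption):

import math

padding = 3000
target_chunk_size = 900000  # e.g. 30 s at 30 kHz
fft_size = 2 ** math.ceil(math.log2(target_chunk_size))
chunk_size = fft_size - 2 * padding
# chunk_size plus padding on both sides is exactly fft_size, a power of two
assert (chunk_size + 2 * padding) & (chunk_size + 2 * padding - 1) == 0
print(fft_size, chunk_size)  # 1048576 1042576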
Example #6
    def write_recording(recording: RecordingExtractor,
                        save_path: PathType,
                        dtype: DtypeType = None,
                        **write_binary_kwargs):
        """
        Convert and save the recording extractor to Neuroscope format.

        Parameters
        ----------
        recording: RecordingExtractor
            The recording extractor to be converted and saved.
        save_path: str
            Path to desired target folder. The name of the files will be the same as the final directory.
        dtype: dtype
            Optional. Data type to be used in writing; must be int16 or int32 (default).
            A warning is printed if the dtype of the traces returned by get_traces() does not match.
        **write_binary_kwargs: keyword arguments for write_to_binary_dat_format function
            - chunk_size
            - chunk_mb
        """
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        if save_path.suffix == "":
            recording_name = save_path.name
        else:
            recording_name = save_path.stem
        xml_name = recording_name

        save_xml_filepath = save_path / f"{xml_name}.xml"
        recording_filepath = save_path / recording_name

        # create parameters file if none exists
        if save_xml_filepath.is_file():
            raise FileExistsError(f"{save_xml_filepath} already exists!")

        xml_root = et.Element('xml')
        et.SubElement(xml_root, 'acquisitionSystem')
        et.SubElement(xml_root.find('acquisitionSystem'), 'nBits')
        et.SubElement(xml_root.find('acquisitionSystem'), 'nChannels')
        et.SubElement(xml_root.find('acquisitionSystem'), 'samplingRate')

        recording_dtype = str(recording.get_dtype())
        int_loc = recording_dtype.find('int')
        recording_n_bits = recording_dtype[(int_loc + 3):(int_loc + 5)]

        valid_dtype = ["16", "32"]
        if dtype is None:
            if int_loc != -1 and recording_n_bits in valid_dtype:
                n_bits = recording_n_bits
            else:
                print(
                    "Warning: Recording data type must be int16 or int32! Defaulting to int32."
                )
                n_bits = "32"
            dtype = f"int{n_bits}"  # update dtype in pass to BinDatRecordingExtractor.write_recording
        else:
            dtype = str(dtype)  # if user passed numpy data type
            int_loc = dtype.find('int')
            assert int_loc != -1, "Data type must be int16 or int32! Non-integer received."
            n_bits = dtype[(int_loc + 3):(int_loc + 5)]
            assert n_bits in valid_dtype, "Data type must be int16 or int32!"

        xml_root.find('acquisitionSystem').find('nBits').text = n_bits
        xml_root.find('acquisitionSystem').find('nChannels').text = str(
            recording.get_num_channels())
        xml_root.find('acquisitionSystem').find('samplingRate').text = str(
            recording.get_sampling_frequency())

        et.ElementTree(xml_root).write(str(save_xml_filepath),
                                       pretty_print=True)

        recording.write_to_binary_dat_format(recording_filepath,
                                             dtype=dtype,
                                             **write_binary_kwargs)
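
A hedged usage sketch (the class name NeuroscopeRecordingExtractor and the toy recording are assumptions for illustration):

import numpy as np
import spikeextractors as se

# Toy int16 recording; in practice this comes from a real dataset.
traces = (np.random.randn(8, 60000) * 100).astype('int16')
rec = se.NumpyRecordingExtractor(timeseries=traces, sampling_frequency=20000)

# Writes my_session.xml plus the raw binary into ./my_session/
NeuroscopeRecordingExtractor.write_recording(rec, save_path='./my_session')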
Example #7
    def _run(self, recording: se.RecordingExtractor, output_folder: Path):
        recording = recover_recording(recording)
        dataset_dir = output_folder / 'ironclust_dataset'
        source_dir = Path(__file__).parent

        samplerate = recording.get_sampling_frequency()

        if recording.is_filtered and self.params['filter']:
            print("Warning! The recording is already filtered, but Ironclust filter is enabled. You can disable "
                  "filters by setting 'filter' parameter to False")

        num_channels = recording.get_num_channels()
        num_timepoints = recording.get_num_frames()
        duration_minutes = num_timepoints / samplerate / 60
        if self.verbose:
            print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.format(
                num_channels, num_timepoints, duration_minutes))

        if self.verbose:
            print('Creating argfile.txt...')
        txt = ''
        for key0, val0 in self.params.items():
            txt += '{}={}\n'.format(key0, val0)
        txt += 'samplerate={}\n'.format(samplerate)
        with (dataset_dir / 'argfile.txt').open('w') as f:
            f.write(txt)

        tmpdir = output_folder / 'tmp'
        os.makedirs(str(tmpdir), exist_ok=True)
        if self.verbose:
            print('Running ironclust in {tmpdir}...'.format(tmpdir=str(tmpdir)))
        cmd = '''
            addpath('{source_dir}');
            addpath('{ironclust_path}', '{ironclust_path}/matlab', '{ironclust_path}/matlab/mdaio');
            try
                p_ironclust('{tmpdir}', '{dataset_dir}/raw.mda', '{dataset_dir}/geom.csv', '', '', '{tmpdir}/firings.mda', '{dataset_dir}/argfile.txt');
            catch
                fprintf('----------------------------------------');
                fprintf(lasterr());
                quit(1);
            end
            quit(0);
        '''
        cmd = cmd.format(ironclust_path=IronClustSorter.ironclust_path, tmpdir=str(tmpdir),
                         dataset_dir=str(dataset_dir), source_dir=str(source_dir))

        matlab_cmd = ShellScript(cmd, script_path=str(tmpdir / 'run_ironclust.m'))
        matlab_cmd.write()

        if 'win' in sys.platform and sys.platform != 'darwin':  # 'darwin' also contains 'win'
            shell_cmd = '''
                cd {tmpdir}
                matlab -nosplash -wait -log -r run_ironclust
            '''.format(tmpdir=tmpdir)
        else:
            shell_cmd = '''
                #!/bin/bash
                cd "{tmpdir}"
                matlab -nosplash -nodisplay -log -r run_ironclust
            '''.format(tmpdir=tmpdir)

        shell_script = ShellScript(shell_cmd, script_path=output_folder / f'run_{self.sorter_name}',
                                   log_path=output_folder / f'{self.sorter_name}.log', verbose=self.verbose)
        shell_script.start()

        retcode = shell_script.wait()

        if retcode != 0:
            raise Exception('ironclust returned a non-zero exit code')

        result_fname = str(tmpdir / 'firings.mda')
        if not os.path.exists(result_fname):
            raise Exception('Result file does not exist: ' + result_fname)

        samplerate_fname = str(tmpdir / 'samplerate.txt')
        with open(samplerate_fname, 'w') as f:
            f.write('{}'.format(samplerate))
Example #8
    def _run(self, recording: se.RecordingExtractor, output_folder: Path):
        dataset_dir = output_folder / 'ironclust_dataset'
        source_dir = Path(__file__).parent

        samplerate = recording.get_sampling_frequency()

        num_channels = recording.get_num_channels()
        num_timepoints = recording.get_num_frames()
        duration_minutes = num_timepoints / samplerate / 60
        if self.verbose:
            print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.format(
                num_channels, num_timepoints, duration_minutes))

        if self.verbose:
            print('Creating argfile.txt...')
        txt = ''
        for key0, val0 in self.params.items():
            txt += '{}={}\n'.format(key0, val0)
        txt += 'samplerate={}\n'.format(samplerate)
        with (dataset_dir / 'argfile.txt').open('w') as f:
            f.write(txt)

        tmpdir = output_folder / 'tmp'
        os.makedirs(str(tmpdir), exist_ok=True)
        if self.verbose:
            print('Running ironclust in {tmpdir}...'.format(tmpdir=str(tmpdir)))

        if os.getenv('IRONCLUST_BINARY_PATH', None):
            shell_cmd = f'''
            #!/bin/bash
            cd {tmpdir}
            exec $IRONCLUST_BINARY_PATH {dataset_dir} {tmpdir} {dataset_dir}/argfile.txt
            '''
        else:
            matlab_script = f'''
            try
                addpath(genpath('{self.ironclust_path}'));
                irc2('{dataset_dir}', '{str(tmpdir)}', '{dataset_dir}/argfile.txt')
            catch
                fprintf('----------------------------------------');
                fprintf(lasterr());
                quit(1);
            end
            quit(0);
            '''
            ShellScript(matlab_script).write(
                str(output_folder / 'ironclust_script.m'))

            if "win" in sys.platform:
                shell_cmd = f'''
                            cd {str(output_folder)}
                            matlab -nosplash -wait -batch ironclust_script
                        '''
            else:
                shell_cmd = f'''
                            #!/bin/bash
                            cd "{str(output_folder)}"
                            matlab -nosplash -nodisplay -r ironclust_script
                        '''
        shell_script = ShellScript(shell_cmd, redirect_output_to_stdout=True)
        shell_script.start()

        retcode = shell_script.wait()

        if retcode != 0:
            raise Exception('ironclust returned a non-zero exit code')

        result_fname = str(tmpdir / 'firings.mda')
        if not os.path.exists(result_fname):
            raise Exception('Result file does not exist: ' + result_fname)

        samplerate_fname = str(tmpdir / 'samplerate.txt')
        with open(samplerate_fname, 'w') as f:
            f.write('{}'.format(samplerate))
Example #9
def prepare_snippets_h5_from_extractors(recording: se.RecordingExtractor,
                                        sorting: se.SortingExtractor,
                                        output_h5_path: str,
                                        start_frame,
                                        end_frame,
                                        max_neighborhood_size: int,
                                        max_events_per_unit: Union[None, int] = None,
                                        snippet_len=(50, 80)):
    import h5py
    from labbox_ephys import (SubsampledSortingExtractor,
                              find_unit_neighborhoods, find_unit_peak_channels,
                              get_unit_waveforms)
    if start_frame is not None:
        recording = se.SubRecordingExtractor(parent_recording=recording,
                                             start_frame=start_frame,
                                             end_frame=end_frame)
        sorting = se.SubSortingExtractor(parent_sorting=sorting,
                                         start_frame=start_frame,
                                         end_frame=end_frame)

    unit_ids = sorting.get_unit_ids()
    samplerate = recording.get_sampling_frequency()

    # Use this optimized function rather than spiketoolkit's version
    # for efficiency with long recordings and/or many channels, units or spikes
    # we should submit this to the spiketoolkit project as a PR
    print('Subsampling sorting')
    if max_events_per_unit is not None:
        sorting_subsampled = SubsampledSortingExtractor(
            parent_sorting=sorting,
            max_events_per_unit=max_events_per_unit,
            method='random')
    else:
        sorting_subsampled = sorting
    print('Finding unit peak channels')
    peak_channels_by_unit = find_unit_peak_channels(recording=recording,
                                                    sorting=sorting,
                                                    unit_ids=unit_ids)
    print('Finding unit neighborhoods')
    channel_ids_by_unit = find_unit_neighborhoods(
        recording=recording,
        peak_channels_by_unit=peak_channels_by_unit,
        max_neighborhood_size=max_neighborhood_size)
    print(f'Getting unit waveforms for {len(unit_ids)} units')
    unit_waveforms = get_unit_waveforms(
        recording=recording,
        sorting=sorting_subsampled,
        unit_ids=unit_ids,
        channel_ids_by_unit=channel_ids_by_unit,
        snippet_len=snippet_len)
    # unit_waveforms = st.postprocessing.get_unit_waveforms(
    #     recording=recording,
    #     sorting=sorting,
    #     unit_ids=unit_ids,
    #     ms_before=1,
    #     ms_after=1.5,
    #     max_spikes_per_unit=500
    # )

    save_path = output_h5_path
    with h5py.File(save_path, 'w') as f:
        f.create_dataset('unit_ids', data=np.array(unit_ids).astype(np.int32))
        f.create_dataset('sampling_frequency',
                         data=np.array([samplerate]).astype(np.float64))
        f.create_dataset('channel_ids',
                         data=np.array(recording.get_channel_ids()))
        f.create_dataset('num_frames',
                         data=np.array([recording.get_num_frames()]).astype(np.int32))
        channel_locations = recording.get_channel_locations()
        f.create_dataset('channel_locations',
                         data=np.array(channel_locations))
        for ii, unit_id in enumerate(unit_ids):
            x = sorting.get_unit_spike_train(unit_id=unit_id)
            f.create_dataset(f'unit_spike_trains/{unit_id}',
                             data=np.array(x).astype(np.float64))
            f.create_dataset(f'unit_waveforms/{unit_id}/waveforms',
                             data=unit_waveforms[ii].astype(np.float32))
            f.create_dataset(
                f'unit_waveforms/{unit_id}/channel_ids',
                data=np.array(channel_ids_by_unit[int(unit_id)]).astype(int))
            f.create_dataset(f'unit_waveforms/{unit_id}/spike_train',
                             data=np.array(
                                 sorting_subsampled.get_unit_spike_train(
                                     unit_id=unit_id)).astype(np.float64))
    def write_recording(recording: RecordingExtractor, save_path: PathType, dtype: DtypeType = None,
                        **write_binary_kwargs):
        """
        Convert and save the recording extractor to Neuroscope format.

        Parameters
        ----------
        recording: RecordingExtractor
            The recording extractor to be converted and saved.
        save_path: str
            Path to desired target folder. The name of the files will be the same as the final directory.
        dtype: dtype
            Optional. Data type to be used in writing; must be int16 or int32 (default).
            A warning is printed if the dtype of the traces returned by get_traces() does not match.
        **write_binary_kwargs: keyword arguments for write_to_binary_dat_format function
            - chunk_size
            - chunk_mb
        """
        save_path = Path(save_path)

        if not save_path.is_dir():
            os.makedirs(save_path)

        if save_path.suffix == '':
            recording_name = save_path.name
        else:
            recording_name = save_path.stem
        xml_name = recording_name

        save_xml_filepath = save_path / (str(xml_name) + '.xml')
        recording_filepath = save_path / recording_name

        # create parameters file if none exists
        if save_xml_filepath.is_file():
            raise FileExistsError(f'{save_xml_filepath} already exists!')

        soup = BeautifulSoup("", 'xml')
        new_tag = soup.new_tag('nbits')
        recording_dtype = str(recording.get_dtype())
        int_loc = recording_dtype.find('int')
        recording_n_bits = recording_dtype[(int_loc + 3):(int_loc + 5)]

        if dtype is None:  # user did not specify data type
            if int_loc != -1 and recording_n_bits in ['16', '32']:
                n_bits = recording_n_bits
            else:
                print('Warning: Recording data type must be int16 or int32! Defaulting to int32.')
                n_bits = '32'
            dtype = 'int' + n_bits  # update dtype to pass to BinDatRecordingExtractor.write_recording
        else:
            dtype = str(dtype)  # if user passed numpy data type
            int_loc = dtype.find('int')
            assert int_loc != -1, 'Data type must be int16 or int32! Non-integer received.'
            n_bits = dtype[(int_loc + 3):(int_loc + 5)]
            assert n_bits in ['16', '32'], 'Data type must be int16 or int32!'

        new_tag.string = n_bits
        soup.append(new_tag)

        new_tag = soup.new_tag('nchannels')
        new_tag.string = str(recording.get_num_channels())
        soup.append(new_tag)

        new_tag = soup.new_tag('samplingrate')
        new_tag.string = str(recording.get_sampling_frequency())
        soup.append(new_tag)

        # write parameters file
        # create parameters file if none exists
        with save_xml_filepath.open("w") as f:
            f.write(str(soup))

        recording.write_to_binary_dat_format(recording_filepath, dtype=dtype, **write_binary_kwargs)
Example #11
def ephys_nlm_v1(recording: se.RecordingExtractor, *, opts: EphysNlmV1Opts, device: Union[None, str], verbose: int = 1) -> Tuple[OutputRecordingExtractor, EphysNlmV1Info]:
    """Denoise an ephys recording using non-local means.
    
    The input and output recordings are RecordingExtractors from SpikeInterface.

    Parameters
    ----------
    recording : se.RecordingExtractor
        The ephys recording to denoise (see SpikeInterface)
    opts : EphysNlmV1Opts
        Options created using EphysNlmV1Opts(...)
    device : Union[str, None]
        Either cuda or cpu (cuda is highly recommended, but you need to have
        CUDA/PyTorch working on your system). If None, then the EPHYS_NLM_DEVICE
        environment variable will be used.
    verbose : int, optional
        Verbosity level, by default 1

    Returns
    -------
    Tuple[OutputRecordingExtractor, EphysNlmV1Info]
        The output (denoised) recording extractor (see SpikeInterface)
        and info about the run
    """
    channel_ids = recording.get_channel_ids()
    M = len(channel_ids)
    N = recording.get_num_frames()
    T = opts.clip_size
    assert T % 2 == 0, 'clip size must be divisible by 2.'
    info = EphysNlmV1Info()
    info.recording = recording
    info.opts = opts
    info.start_time = time.time()
    if opts.block_size is None:
        if opts.block_size_sec is None:
            raise Exception('block_size and block_size_sec are both None')
        opts.block_size = int(
            recording.get_sampling_frequency() * opts.block_size_sec)
    block_size = opts.block_size
    num_blocks = max(1, math.floor(N / block_size))
    assert opts.sigma == 'auto', 'Only sigma=auto allowed at this time'
    assert opts.whitening == 'auto', 'Only whitening=auto allowed at this time'

    if device is None:
        device = os.getenv('EPHYS_NLM_DEVICE', None)
        if device is None or device == '':
            print('Warning: EPHYS_NLM_DEVICE not set -- defaulting to cpu. To use GPU, set EPHYS_NLM_DEVICE=cuda')
            device = 'cpu'
    elif device == 'cpu':
        print('Using device=cpu. Warning: GPU is much faster. To use GPU, set device=cuda')
    if device == 'cuda':
        assert torch.cuda.is_available(), 'Cannot use device=cuda. PyTorch/CUDA is not configured properly -- torch.cuda.is_available() is returning False.'
        print('Using device=cuda')
    elif device == 'cpu':
        print('Using device=cpu')
    else:
        raise Exception(f'Invalid device: {device}') # pragma: no cover
    info.device = device
    opts._device = device  # for convenience

    neighborhoods = _get_neighborhoods(recording=recording, opts=opts)

    if verbose >= 1:
        print(f'Denoising recording of size {M} x {N} using {len(neighborhoods)} neighborhoods and {num_blocks} time blocks')

    initial_traces = recording.get_traces(
        start_frame=0, end_frame=min(N, block_size))
    _estimate_sigma_and_whitening(
        traces=initial_traces, neighborhoods=neighborhoods, opts=opts, verbose=verbose)

    recording_out = OutputRecordingExtractor(
        base_recording=recording, block_size=opts.block_size)

    for ii in range(num_blocks):
        if verbose >= 1:
            print(f'Denoising block {ii} of {num_blocks}')

        t1 = ii * block_size
        t2 = t1 + block_size
        # The last block is potentially larger
        if ii == num_blocks - 1:
            t2 = N

        block_traces = recording.get_traces(start_frame=t1, end_frame=t2)
        block_traces_denoised, block_info = _denoise_block(
            traces=block_traces,
            opts=opts,
            neighborhoods=neighborhoods,
            verbose=verbose
        )
        info.blocks.append(block_info)
        recording_out.add_block(block_traces_denoised)

    info.end_time = time.time()
    info.elapsed_time = info.end_time - info.start_time

    return recording_out, info
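
A hedged usage sketch. The EphysNlmV1Opts constructor arguments below are guesses based only on the fields the function reads (clip_size, block_size_sec, sigma, whitening); check the real class for the full set:

import numpy as np
import spikeextractors as se

# Toy recording; real use would load an actual ephys dataset.
traces = np.random.randn(4, 30000 * 10).astype('float32')
rec = se.NumpyRecordingExtractor(timeseries=traces, sampling_frequency=30000)

opts = EphysNlmV1Opts(block_size_sec=30, clip_size=30,
                      sigma='auto', whitening='auto')
rec_denoised, info = ephys_nlm_v1(rec, opts=opts, device='cpu', verbose=1)
print(info.elapsed_time)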
Example #12
def prepare_snippets_nwb_from_extractors(
        recording: se.RecordingExtractor,
        sorting: se.SortingExtractor,
        nwb_file_path: str,
        nwb_object_prefix: str,
        start_frame,
        end_frame,
        max_neighborhood_size: int,
        max_events_per_unit: Union[None, int] = None,
        snippet_len=(50, 80),
):
    import pynwb
    from labbox_ephys import (SubsampledSortingExtractor,
                              find_unit_neighborhoods, find_unit_peak_channels,
                              get_unit_waveforms)
    if start_frame is not None:
        recording = se.SubRecordingExtractor(parent_recording=recording,
                                             start_frame=start_frame,
                                             end_frame=end_frame)
        sorting = se.SubSortingExtractor(parent_sorting=sorting,
                                         start_frame=start_frame,
                                         end_frame=end_frame)

    unit_ids = sorting.get_unit_ids()
    samplerate = recording.get_sampling_frequency()

    # Use this optimized function rather than spiketoolkit's version
    # for efficiency with long recordings and/or many channels, units or spikes
    # we should submit this to the spiketoolkit project as a PR
    print('Subsampling sorting')
    if max_events_per_unit is not None:
        sorting_subsampled = SubsampledSortingExtractor(
            parent_sorting=sorting,
            max_events_per_unit=max_events_per_unit,
            method='random')
    else:
        sorting_subsampled = sorting
    print('Finding unit peak channels')
    peak_channels_by_unit = find_unit_peak_channels(recording=recording,
                                                    sorting=sorting,
                                                    unit_ids=unit_ids)
    print('Finding unit neighborhoods')
    channel_ids_by_unit = find_unit_neighborhoods(
        recording=recording,
        peak_channels_by_unit=peak_channels_by_unit,
        max_neighborhood_size=max_neighborhood_size)
    print(f'Getting unit waveforms for {len(unit_ids)} units')
    unit_waveforms = get_unit_waveforms(
        recording=recording,
        sorting=sorting_subsampled,
        unit_ids=unit_ids,
        channel_ids_by_unit=channel_ids_by_unit,
        snippet_len=snippet_len)
    # unit_waveforms = st.postprocessing.get_unit_waveforms(
    #     recording=recording,
    #     sorting=sorting,
    #     unit_ids=unit_ids,
    #     ms_before=1,
    #     ms_after=1.5,
    #     max_spikes_per_unit=500
    # )
    with pynwb.NWBHDF5IO(path=nwb_file_path, mode='a') as io:
        nwbf = io.read()
        nwbf.add_scratch(name=f'{nwb_object_prefix}_unit_ids',
                         data=np.array(unit_ids).astype(np.int32),
                         notes='sorted waveform unit ids')
        nwbf.add_scratch(name=f'{nwb_object_prefix}_sampling_frequency',
                         data=np.array([samplerate]).astype(np.float64),
                         notes='sorted waveform sampling frequency')
        nwbf.add_scratch(name=f'{nwb_object_prefix}_channel_ids',
                         data=np.array(recording.get_channel_ids()),
                         notes='sorted waveform channel ids')
        nwbf.add_scratch(name=f'{nwb_object_prefix}_num_frames',
                         data=np.array([recording.get_num_frames()]).astype(np.int32),
                         notes='sorted waveform number of frames')
        channel_locations = recording.get_channel_locations()
        nwbf.add_scratch(name=f'{nwb_object_prefix}_channel_locations',
                         data=np.array(channel_locations),
                         notes='sorted waveform channel locations')
        for ii, unit_id in enumerate(unit_ids):
            x = sorting.get_unit_spike_train(unit_id=unit_id)
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_spike_trains',
                data=np.array(x).astype(np.float64),
                notes=f'sorted spike trains for unit {unit_id}')
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_waveforms',
                data=unit_waveforms[ii].astype(np.float32),
                notes=f'sorted waveforms for unit {unit_id}')
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_channel_ids',
                data=np.array(channel_ids_by_unit[int(unit_id)]).astype(int),
                notes=f'sorted channel ids for unit {unit_id}')
            nwbf.add_scratch(
                name=f'{nwb_object_prefix}_unit_{unit_id}_sub_spike_train',
                data=np.array(
                    sorting_subsampled.get_unit_spike_train(
                        unit_id=unit_id)).astype(np.float64),
                notes=f'sorted subsampled spike train for unit {unit_id}')
        io.write(nwbf)
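
A hedged call sketch (the recording/sorting extractors and an existing NWB file at session.nwb are assumed; argument values are illustrative):

prepare_snippets_nwb_from_extractors(
    recording=recording,
    sorting=sorting,
    nwb_file_path='session.nwb',
    nwb_object_prefix='snippets',
    start_frame=None,
    end_frame=None,
    max_neighborhood_size=12,
    max_events_per_unit=1000,
)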