Example #1
    def run_for_all_spikes(self, file_path, max_channels_per_template=16, peak_sign='neg', **job_kwargs):
        """
        This runs the PCA projection on all spikes from the sorting.
        This is a long computation because waveforms need to be extracted from every spike.
        
        Used mainly for `export_to_phy()`
        
        PCs are exported to a single .npy file.
        
        """
        p = self._params
        we = self.waveform_extractor
        sorting = we.sorting
        recording = we.recording

        
        assert sorting.get_num_segments() == 1
        assert p['mode'] in ('by_channel_local', 'by_channel_global')
        
        file_path = Path(file_path)

        all_spikes = sorting.get_all_spike_trains(outputs='unit_index')
        spike_times, spike_labels = all_spikes[0]        
        
        max_channels_per_template = min(max_channels_per_template, we.recording.get_num_channels())
        
        best_channels_index = get_template_best_channels(we, max_channels_per_template,
                                    peak_sign=peak_sign, outputs='index')
        unit_channels = [best_channels_index[unit_id] for unit_id in sorting.unit_ids]
        
        if p['mode'] == 'by_channel_local':
            all_pca = self._fit_by_channel_local()
        elif p['mode'] == 'by_channel_global':
            one_pca = self._fit_by_channel_global()
            all_pca = [one_pca] * recording.get_num_channels()
        
        # nSpikes, nFeaturesPerChannel, nPCFeatures
        # this come from  phy template-gui
        # https://github.com/kwikteam/phy-contrib/blob/master/docs/template-gui.md#datasets
        shape = (spike_times.size, p['n_components'], max_channels_per_template)
        all_pcs = np.lib.format.open_memmap(file_path, mode='w+', dtype='float32', shape=shape)
        
        # and run
        func = _all_pc_extractor_chunk
        init_func = _init_work_all_pc_extractor
        n_jobs = ensure_n_jobs(recording, job_kwargs.get('n_jobs', None))
        if n_jobs == 1:
            init_args = (recording, )
        else:
            init_args = (recording.to_dict(), )
        init_args = init_args + (all_pcs, spike_times, spike_labels, we.nbefore, we.nafter, unit_channels, all_pca)
        processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name='extract PCs', **job_kwargs)
        processor.run()
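A minimal sketch of reading back the exported file; the file name and sizes below are illustrative only, and the axis order follows the phy template-gui convention referenced above.

import numpy as np

# Illustrative sizes only: 1000 spikes, 5 PCA components, 16 channels per template.
n_spikes, n_components, max_channels_per_template = 1000, 5, 16

# Create a dummy file the same way run_for_all_spikes() does (open_memmap, float32)...
all_pcs = np.lib.format.open_memmap('pc_features.npy', mode='w+', dtype='float32',
                                     shape=(n_spikes, n_components, max_channels_per_template))
all_pcs[:] = 0.0
all_pcs.flush()

# ...then read it back lazily; axis order is (spike, PC component, best channel).
pcs = np.load('pc_features.npy', mmap_mode='r')
print(pcs.shape)  # (1000, 5, 16)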
Example #2
def get_unit_amplitudes(waveform_extractor,
                        peak_sign='neg',
                        outputs='concatenated',
                        **job_kwargs):
    """
    Computes the spike amplitudes from a WaveformExtractor.
    Amplitudes can be computed in absolute value (uV) or relative to the template amplitude.
    
    1. The waveform extractor is used to determine the max channel per unit.
    2. Then a "peak_shift" is estimated because for some sorter the spike index is not always at the
       extremum.
    3. Extract all epak chunk by chunk (parallel or not)

    """
    we = waveform_extractor
    recording = we.recording
    sorting = we.sorting

    all_spikes = sorting.get_all_spike_trains()

    extremum_channels_index = get_template_extremum_channel(
        waveform_extractor, peak_sign=peak_sign, outputs='index')
    peak_shifts = get_template_extremum_channel_peak_shift(waveform_extractor,
                                                           peak_sign=peak_sign)

    # and run
    func = _unit_amplitudes_chunk
    init_func = _init_worker_unit_amplitudes
    init_args = (recording.to_dict(), sorting.to_dict(),
                 extremum_channels_index, peak_shifts)
    processor = ChunkRecordingExecutor(recording,
                                       func,
                                       init_func,
                                       init_args,
                                       handle_returns=True,
                                       job_name='extract amplitudes',
                                       **job_kwargs)
    out = processor.run()
    amps, segments = zip(*out)
    amps = np.concatenate(amps)
    segments = np.concatenate(segments)

    amplitudes = []
    for segment_index in range(recording.get_num_segments()):
        mask = segments == segment_index
        amplitudes.append(amps[mask])

    if outputs == 'concatenated':
        return amplitudes
    elif outputs == 'by_units':
        amplitudes_by_units = []
        for segment_index in range(recording.get_num_segments()):
            amplitudes_by_units.append({})
            for unit_id in sorting.unit_ids:
                spike_times, spike_labels = all_spikes[segment_index]
                mask = spike_labels == unit_id
                amps = amplitudes[segment_index][mask]
                amplitudes_by_units[segment_index][unit_id] = amps
        return amplitudes_by_units
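A self-contained numpy sketch of step 2 above: each spike amplitude is read on the unit's extremum channel at spike_index + peak_shift. All values below are toy data, not anything produced by this function.

import numpy as np

rng = np.random.default_rng(0)
traces = rng.normal(0, 1, size=(1000, 4))       # toy traces, shape (samples, channels)
spike_indices = np.array([100, 250, 600])       # toy spike sample indices
spike_unit_index = np.array([0, 1, 0])          # which unit fired each spike
extremum_channel = np.array([2, 3])             # best channel per unit
peak_shift = np.array([-1, 0])                  # per-unit shift of the true extremum

chans = extremum_channel[spike_unit_index]
shifts = peak_shift[spike_unit_index]
amplitudes = traces[spike_indices + shifts, chans]
print(amplitudes)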
def localize_peaks(recording, peaks, method='center_of_mass',
                   local_radius_um=150, ms_before=0.3, ms_after=0.6,
                   **job_kwargs):
    """
    Localize peaks (spikes) in 2D or 3D depending on the probe.ndim of the recording.

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor object
    peaks: numpy array
        Peak vector given by detect_peaks() with the 'numpy_compact' output.
    method: str
        Method to be used ('center_of_mass')
    local_radius_um: float
        Radius in micrometers used to build the channel neighborhood
        around the peak
    ms_before: float
        The left window, before a peak, in milliseconds
    ms_after: float
        The right window, after a peak, in milliseconds
    {}

    Returns
    -------
    peak_locations: np.array
        Array with estimated x-y location for each spike
    """
    assert method in ('center_of_mass',)

    # find channel neighbours
    assert local_radius_um is not None
    channel_distance = get_channel_distances(recording)
    neighbours_mask = channel_distance < local_radius_um

    nbefore = int(ms_before * recording.get_sampling_frequency() / 1000.)
    nafter = int(ms_after * recording.get_sampling_frequency() / 1000.)

    contact_locations = recording.get_channel_locations()

    # TODO
    #  make a memmap for peaks to avoid serialization

    # and run
    func = _localize_peaks_chunk
    init_func = _init_worker_localize_peaks
    init_args = (recording.to_dict(), peaks, method, nbefore, nafter, neighbours_mask, contact_locations)
    processor = ChunkRecordingExecutor(recording, func, init_func, init_args, handle_returns=True,
                                       job_name='localize peaks', **job_kwargs)
    peak_locations = processor.run()

    peak_locations = np.concatenate(peak_locations)

    return peak_locations
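The 'center_of_mass' method boils down to an amplitude-weighted average of the channel positions inside the local radius; a toy numpy sketch of that idea (geometry and amplitudes are illustrative only).

import numpy as np

# Toy geometry: 4 channels on a vertical line, and the peak amplitude seen on each.
contact_locations = np.array([[0., 0.], [0., 20.], [0., 40.], [0., 60.]])
peak_amplitudes = np.array([5., 30., 20., 2.])

weights = np.abs(peak_amplitudes)
center_of_mass = (weights[:, None] * contact_locations).sum(axis=0) / weights.sum()
print(center_of_mass)  # estimated (x, y) location of the spike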
Example #4
def test_ChunkRecordingExecutor():
    recording = generate_recording(num_channels=2)
    # make dumpable
    recording = recording.save()

    def func(segment_index, start_frame, end_frame, worker_ctx):
        import os, time
        # print('func', segment_index, start_frame, end_frame, worker_ctx, os.getpid())
        time.sleep(0.010)
        # time.sleep(1.0)
        return os.getpid()

    def init_func(arg1, arg2, arg3):
        worker_ctx = {}
        worker_ctx['arg1'] = arg1
        worker_ctx['arg2'] = arg2
        worker_ctx['arg3'] = arg3
        return worker_ctx

    init_args = 'a', 120, 'yep'

    # no chunk
    processor = ChunkRecordingExecutor(recording,
                                       func,
                                       init_func,
                                       init_args,
                                       verbose=True,
                                       progress_bar=False,
                                       n_jobs=1,
                                       chunk_size=None)
    processor.run()

    # chunk + loop
    processor = ChunkRecordingExecutor(recording,
                                       func,
                                       init_func,
                                       init_args,
                                       verbose=True,
                                       progress_bar=False,
                                       n_jobs=1,
                                       chunk_memory="500k")
    processor.run()

    # chunk + parallel
    processor = ChunkRecordingExecutor(recording,
                                       func,
                                       init_func,
                                       init_args,
                                       verbose=True,
                                       progress_bar=True,
                                       n_jobs=2,
                                       total_memory="200k",
                                       job_name='job_name')
    processor.run()
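The three runs above only differ in how the chunk size is chosen (explicit chunk_size, chunk_memory per chunk, or total_memory shared across workers). A rough sketch of the arithmetic idea; the exact conversion rule used by ChunkRecordingExecutor is not shown here, and the float32 assumption is mine.

# Rough sketch only: a "500k" per-chunk memory budget divided by bytes per frame.
num_channels = 2
bytes_per_sample = 4                 # assuming float32 traces
chunk_memory_bytes = 500 * 1024      # "500k"
chunk_size = chunk_memory_bytes // (num_channels * bytes_per_sample)
print(chunk_size)                    # frames per chunk under this budget

# With total_memory="200k" and n_jobs=2, the per-job budget would be half of that.
total_memory_bytes = 200 * 1024
n_jobs = 2
chunk_size_parallel = total_memory_bytes // (n_jobs * num_channels * bytes_per_sample)
print(chunk_size_parallel)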
def find_spike_from_templates(recording,
                              waveform_extractor,
                              method='simple',
                              method_kwargs={},
                              **job_kwargs):
    """
    Find spikes in a recording from known templates.
    Templates are represented as a WaveformExtractor so statistics can be extracted.

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor object
    waveform_extractor: WaveformExtractor
        The waveform extractor holding the templates
    method: str
        Method to be used ('simple')
    method_kwargs: dict
        Params for the given method
    job_kwargs: dict
        Parameters for ChunkRecordingExecutor
    """
    assert method in ('simple', )

    if method == 'simple':
        method_kwargs = check_kwargs_simple_matching(recording,
                                                     waveform_extractor,
                                                     method_kwargs)

    # and run
    func = _find_spike_chunk
    init_func = _init_worker_find_spike
    init_args = (recording.to_dict(), method, method_kwargs)
    processor = ChunkRecordingExecutor(recording,
                                       func,
                                       init_func,
                                       init_args,
                                       handle_returns=True,
                                       job_name='find spikes',
                                       **job_kwargs)
    spikes = processor.run()

    spikes = np.concatenate(spikes)
    return spikes
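A hypothetical call sketch; `recording` and `we` (a WaveformExtractor holding the templates) are assumed to already exist, and the keyword arguments are forwarded to ChunkRecordingExecutor as job_kwargs.

# Hypothetical usage; `recording` and `we` are assumed to already exist.
spikes = find_spike_from_templates(recording, we, method='simple', method_kwargs={},
                                   n_jobs=2, chunk_size=10000, progress_bar=True)
print(spikes.size)   # total number of spikes found across all chunks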
def test_ChunkRecordingExecutor():
    recording = generate_recording(num_channels=2)
    # make dumpable
    recording = recording.save()

    def func(segment_index, start_frame, end_frame, worker_ctx):
        import os
        import time
        time.sleep(0.010)
        return os.getpid()

    def init_func(arg1, arg2, arg3):
        # build the per-worker context from the init args
        return dict(arg1=arg1, arg2=arg2, arg3=arg3)

    init_args = 'a', 120, 'yep'

    # no chunk
    processor = ChunkRecordingExecutor(recording,
                                       func,
                                       init_func,
                                       init_args,
                                       verbose=True,
                                       progress_bar=False,
                                       n_jobs=1,
                                       chunk_size=None)
    processor.run()

    # chunk + loop
    processor = ChunkRecordingExecutor(recording,
                                       func,
                                       init_func,
                                       init_args,
                                       verbose=True,
                                       progress_bar=False,
                                       n_jobs=1,
                                       chunk_memory="500k")
    processor.run()

    # chunk + parallel
    processor = ChunkRecordingExecutor(recording,
                                       func,
                                       init_func,
                                       init_args,
                                       verbose=True,
                                       progress_bar=True,
                                       n_jobs=2,
                                       total_memory="200k",
                                       job_name='job_name')
    processor.run()
Example #7
def detect_peaks(recording,
                 method='by_channel',
                 peak_sign='neg',
                 detect_threshold=5,
                 n_shifts=2,
                 local_radius_um=100,
                 noise_levels=None,
                 random_chunk_kwargs={},
                 outputs='numpy_compact',
                 **job_kwargs):
    """
    Peak detection ported from tridesclous into spikeinterface.

    Peak detection based on threshold crossing in terms of k x MAD.

    If the MAD is not provided then it is estimated from random snippets.

    Several methods:

      * 'by_channel' : peaks are detected in each channel independently
      * 'locally_exclusive' : given a radius, only the best peak among
          neighboring channels is kept

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor object

    method: 'by_channel' / 'locally_exclusive'
        Method to be used

    peak_sign: 'neg' / 'pos' / 'both'
        Sign of the peak.
    detect_threshold: float
        Threshold in median absolute deviations (MAD) to detect peaks
    n_shifts: int
        Number of shifts to find peak. E.g. if n_shifts is 2, a peak is detected (if detect_sign is 'negative') if
        a sample is below the threshold, the two samples before are higher than the sample, and the two samples after
        the sample are higher than the sample.
    noise_levels: np.array
        noise_levels can be provided externally if already computed.
    random_chunk_kwargs: dict
        A dict that contains options to randomize chunks for get_noise_levels().
        Only used if noise_levels is None.
    outputs: str, 'numpy_compact' / 'numpy_split' / 'sorting'
        The type of the output. By default "numpy_compact"
        gives a structured array (one vector with a compound dtype).
    """
    assert method in ('by_channel', 'locally_exclusive')
    assert peak_sign in ('both', 'neg', 'pos')
    assert outputs in ('numpy_compact', 'numpy_split', 'sorting')

    if method == 'locally_exclusive' and not HAVE_NUMBA:
        raise ModuleNotFoundError(
            '"locally_exclusive" need numba which is not installed')

    if noise_levels is None:
        noise_levels = get_noise_levels(recording, **random_chunk_kwargs)

    abs_thresholds = noise_levels * detect_threshold

    if method == 'locally_exclusive':
        assert local_radius_um is not None
        channel_distance = get_channel_distances(recording)
        neighbours_mask = channel_distance < local_radius_um
    else:
        neighbours_mask = None

    # and run
    func = _detect_peaks_chunk
    init_func = _init_worker_detect_peaks
    init_args = (recording.to_dict(), method, peak_sign, abs_thresholds,
                 n_shifts, neighbours_mask)
    processor = ChunkRecordingExecutor(recording,
                                       func,
                                       init_func,
                                       init_args,
                                       handle_returns=True,
                                       **job_kwargs)
    peaks = processor.run()

    peak_sample_inds, peak_chan_inds, peak_amplitudes, peak_segments = zip(
        *peaks)
    peak_sample_inds = np.concatenate(peak_sample_inds)
    peak_chan_inds = np.concatenate(peak_chan_inds)
    peak_amplitudes = np.concatenate(peak_amplitudes)
    peak_segments = np.concatenate(peak_segments)

    if outputs == 'numpy_compact':
        dtype = [('sample_ind', 'int64'), ('channel_ind', 'int64'),
                 ('amplitude', 'float64'), ('segment_ind', 'int64')]
        peaks = np.zeros(peak_sample_inds.size, dtype=dtype)
        peaks['sample_ind'] = peak_sample_inds
        peaks['channel_ind'] = peak_chan_inds
        peaks['amplitude'] = peak_amplitudes
        peaks['segment_ind'] = peak_segments
        return peaks
    elif outputs == 'numpy_split':
        return peak_sample_inds, peak_chan_inds, peak_amplitudes, peak_segments
    elif outputs == 'sorting':
        #@alessio : here we can do what you did in old API
        # the output is a sorting where unit_id is in fact one channel
        raise NotImplementedError
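A self-contained toy sketch of the 'by_channel' rule described in the docstring: a negative peak must cross the threshold and be a strict local minimum over n_shifts samples on each side. This only illustrates the rule, it is not the library's vectorized implementation.

import numpy as np

trace = np.array([0., -1., -2., -6., -2., -1., 0., -5.5, -6.5, -5.0, 0.])
abs_threshold = 5.0
n_shifts = 2

candidates = np.flatnonzero(trace < -abs_threshold)
peaks = []
for i in candidates:
    if i - n_shifts < 0 or i + n_shifts >= trace.size:
        continue  # too close to the trace border
    left = trace[i - n_shifts:i]
    right = trace[i + 1:i + 1 + n_shifts]
    if np.all(left > trace[i]) and np.all(right > trace[i]):
        peaks.append(i)
print(peaks)  # sample indices detected as negative peaks -> [3, 8]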
def detect_peaks(recording,
                 method='by_channel',
                 peak_sign='neg',
                 detect_threshold=5,
                 n_shifts=2,
                 local_radius_um=50,
                 noise_levels=None,
                 random_chunk_kwargs={},
                 outputs='numpy_compact',
                 localization_dict=None,
                 **job_kwargs):
    """Peak detection based on threshold crossing in term of k x MAD.

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor object.
    method: 'by_channel', 'locally_exclusive'
        Method to use. Options:
            * 'by_channel' : peaks are detected in each channel independently
            * 'locally_exclusive' : a single best peak is taken from a set of neighboring channels
    peak_sign: 'neg', 'pos', 'both'
        Sign of the peak.
    detect_threshold: float
        Threshold, in median absolute deviations (MAD), to use to detect peaks.
    n_shifts: int
        Number of shifts to find peak.
        For example, if `n_shifts` is 2, a peak is detected if a sample crosses the threshold,
        and the two samples before and after are above the sample.
    local_radius_um: float
        The radius to use for detection across local channels.
    noise_levels: array, optional
        Estimated noise levels to use, if already computed.
        If not provided, it is estimated from a random snippet of the data.
    random_chunk_kwargs: dict, optional
        A dict that contain option to randomize chunk for get_noise_levels().
        Only used if noise_levels is None.
    outputs: 'numpy_compact', 'numpy_split', 'sorting'
        The type of the output. By default, "numpy_compact" returns an array with a structured dtype.
        In case of 'sorting', each unit corresponds to a recording channel.
    localization_dict : dict, optional
        Can optionally do peak localization at the same time as detection.
        This avoids running `localize_peaks` separately and re-reading the entire dataset.
    {}

    Returns
    -------
    peaks: array
        Detected peaks.

    Notes
    -----
    This peak detection ported from tridesclous into spikeinterface.
    """

    assert method in ('by_channel', 'locally_exclusive')
    assert peak_sign in ('both', 'neg', 'pos')
    assert outputs in ('numpy_compact', 'numpy_split', 'sorting')

    if method == 'locally_exclusive' and not HAVE_NUMBA:
        raise ModuleNotFoundError(
            '"locally_exclusive" need numba which is not installed')

    if noise_levels is None:
        noise_levels = get_noise_levels(recording,
                                        return_scaled=False,
                                        **random_chunk_kwargs)

    abs_thresholds = noise_levels * detect_threshold

    if method == 'locally_exclusive':
        assert local_radius_um is not None
        channel_distance = get_channel_distances(recording)
        neighbours_mask = channel_distance < local_radius_um
    else:
        neighbours_mask = None

    # deal with margin
    if localization_dict is None:
        extra_margin = 0
    else:
        assert isinstance(localization_dict, dict)
        assert localization_dict['method'] in dtype_localize_by_method.keys()
        localization_dict = init_kwargs_dict(localization_dict['method'],
                                             localization_dict)

        nbefore = int(localization_dict['ms_before'] *
                      recording.get_sampling_frequency() / 1000.)
        nafter = int(localization_dict['ms_after'] *
                     recording.get_sampling_frequency() / 1000.)
        extra_margin = max(nbefore, nafter)

    # and run
    func = _detect_peaks_chunk
    init_func = _init_worker_detect_peaks
    init_args = (recording.to_dict(), method, peak_sign, abs_thresholds,
                 n_shifts, neighbours_mask, extra_margin, localization_dict)
    processor = ChunkRecordingExecutor(recording,
                                       func,
                                       init_func,
                                       init_args,
                                       handle_returns=True,
                                       job_name='detect peaks',
                                       **job_kwargs)
    peaks = processor.run()
    peaks = np.concatenate(peaks)

    if outputs == 'numpy_compact':
        return peaks
    elif outputs == 'sorting':
        return NumpySorting.from_peaks(
            peaks, sampling_frequency=recording.get_sampling_frequency())
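A hypothetical call sketch with joint localization; `recording` is assumed to already exist, and any localization_dict key beyond 'method', 'ms_before' and 'ms_after' (the ones read in the margin handling above) would be an assumption.

# Hypothetical usage; `recording` is assumed to already exist.
localization_dict = dict(method='center_of_mass', ms_before=0.3, ms_after=0.6)
peaks = detect_peaks(recording, method='locally_exclusive', detect_threshold=5,
                     localization_dict=localization_dict,
                     n_jobs=2, chunk_size=10000, progress_bar=True)
print(peaks.dtype.names)   # fields of the structured peak array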
Example #9
    def compute_amplitudes(self, **job_kwargs):

        we = self.waveform_extractor
        recording = we.recording
        sorting = we.sorting

        all_spikes = sorting.get_all_spike_trains(outputs='unit_index')
        self._all_spikes = all_spikes

        peak_sign = self._params['peak_sign']
        return_scaled = self._params['return_scaled']

        extremum_channels_index = get_template_extremum_channel(
            we, peak_sign=peak_sign, outputs='index')
        peak_shifts = get_template_extremum_channel_peak_shift(
            we, peak_sign=peak_sign)

        # put extremum_channels_index and peak_shifts into vectors ordered by unit_ids
        extremum_channels_index = np.array(
            [extremum_channels_index[unit_id] for unit_id in sorting.unit_ids],
            dtype='int64')
        peak_shifts = np.array(
            [peak_shifts[unit_id] for unit_id in sorting.unit_ids],
            dtype='int64')

        if return_scaled:
            # check if has scaled values:
            if not we.recording.has_scaled_traces():
                print("Setting 'return_scaled' to False")
                return_scaled = False

        # and run
        func = _spike_amplitudes_chunk
        init_func = _init_worker_spike_amplitudes
        n_jobs = ensure_n_jobs(recording, job_kwargs.get('n_jobs', None))
        if n_jobs == 1:
            init_args = (recording, sorting)
        else:
            init_args = (recording.to_dict(), sorting.to_dict())
        init_args = init_args + (extremum_channels_index, peak_shifts,
                                 return_scaled)
        processor = ChunkRecordingExecutor(recording,
                                           func,
                                           init_func,
                                           init_args,
                                           handle_returns=True,
                                           job_name='extract amplitudes',
                                           **job_kwargs)
        out = processor.run()
        amps, segments = zip(*out)
        amps = np.concatenate(amps)
        segments = np.concatenate(segments)

        self._amplitudes = []
        for segment_index in range(recording.get_num_segments()):
            mask = segments == segment_index
            amps_seg = amps[mask]
            self._amplitudes.append(amps_seg)

            # save to folder
            file_amps = self.extension_folder / f'amplitude_segment_{segment_index}.npy'
            np.save(file_amps, amps_seg)
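The extension stores one file per segment (amplitude_segment_{i}.npy) in its extension folder; a minimal read-back sketch, with an illustrative folder path and toy data written first so it runs standalone.

import numpy as np
from pathlib import Path

# Illustrative only: write then re-read the per-segment files the extension produces.
extension_folder = Path('spike_amplitudes_folder')
extension_folder.mkdir(exist_ok=True)
np.save(extension_folder / 'amplitude_segment_0.npy', np.array([-10.2, -7.5, -13.1]))

num_segments = 1
amplitudes = [np.load(extension_folder / f'amplitude_segment_{i}.npy')
              for i in range(num_segments)]
print(amplitudes[0])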
Example #10
def detect_peaks(recording,
                 method='by_channel',
                 peak_sign='neg',
                 detect_threshold=5,
                 n_shifts=2,
                 local_radius_um=50,
                 noise_levels=None,
                 random_chunk_kwargs={},
                 outputs='numpy_compact',
                 localization_dict=None,
                 **job_kwargs):
    """
    Peak detection ported from tridesclous into spikeinterface.

    Peak detection based on threshold crossing in terms of k x MAD.

    If the MAD is not provided then it is estimated from random snippets.

    Several methods:

      * 'by_channel' : peaks are detected in each channel independently
      * 'locally_exclusive' : given a radius, only the best peak among
        neighboring channels is kept

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor object

    method: 'by_channel' / 'locally_exclusive'
        Method to be used

    peak_sign: 'neg' / 'pos' / 'both'
        Sign of the peak.
    detect_threshold: float
        Threshold in median absolute deviations (MAD) to detect peaks
    n_shifts: int
        Number of shifts to find peak. E.g. if n_shifts is 2, a peak is detected (if detect_sign is 'negative') if
        a sample is below the threshold, the two samples before are higher than the sample, and the two samples after
        the sample are higher than the sample.
    noise_levels: np.array
        noise_levels can be provided externally if already computed.
    random_chunk_kwargs: dict
        A dict that contains options to randomize chunks for get_noise_levels().
        Only used if noise_levels is None.
    outputs: str, 'numpy_compact' / 'numpy_split' / 'sorting'
        The type of the output. By default "numpy_compact"
        gives a structured array (one vector with a compound dtype).
    localization_dict : None or dict
        Can optionally do peak localization at the same time as detection.
        This avoids running localize_peaks separately and re-reading the entire dataset.

    job_kwargs: dict
        Parameters for ChunkRecordingExecutor
    """
    assert method in ('by_channel', 'locally_exclusive')
    assert peak_sign in ('both', 'neg', 'pos')
    assert outputs in ('numpy_compact', 'numpy_split', 'sorting')

    if method == 'locally_exclusive' and not HAVE_NUMBA:
        raise ModuleNotFoundError(
            '"locally_exclusive" need numba which is not installed')

    if noise_levels is None:
        noise_levels = get_noise_levels(recording, **random_chunk_kwargs)

    abs_thresholds = noise_levels * detect_threshold

    if method == 'locally_exclusive':
        assert local_radius_um is not None
        channel_distance = get_channel_distances(recording)
        neighbours_mask = channel_distance < local_radius_um
    else:
        neighbours_mask = None

    # deal with margin
    if localization_dict is None:
        extra_margin = 0
    else:
        assert isinstance(localization_dict, dict)
        assert localization_dict['method'] in dtype_localize_by_method.keys()
        localization_dict = init_kwargs_dict(localization_dict['method'],
                                             localization_dict)

        nbefore = int(localization_dict['ms_before'] *
                      recording.get_sampling_frequency() / 1000.)
        nafter = int(localization_dict['ms_after'] *
                     recording.get_sampling_frequency() / 1000.)
        extra_margin = max(nbefore, nafter)

    # and run
    func = _detect_peaks_chunk
    init_func = _init_worker_detect_peaks
    init_args = (recording.to_dict(), method, peak_sign, abs_thresholds,
                 n_shifts, neighbours_mask, extra_margin, localization_dict)
    processor = ChunkRecordingExecutor(recording,
                                       func,
                                       init_func,
                                       init_args,
                                       handle_returns=True,
                                       job_name='detect peaks',
                                       **job_kwargs)
    peaks = processor.run()
    peaks = np.concatenate(peaks)

    if outputs == 'numpy_compact':
        return peaks
    elif outputs == 'sorting':
        # @alessio : here we can do what you did in old API
        # the output is a sorting where unit_id is in fact one channel
        raise NotImplementedError
Example #11
    def run_for_all_spikes(self, file_path, max_channels_per_template=16, peak_sign='neg',
                           **job_kwargs):
        """
        Project all spikes from the sorting on the PCA model.
        This is a long computation because waveforms need to be extracted from every spike.
        
        Used mainly for `export_to_phy()`
        
        PCs are exported to a single .npy file.

        Parameters
        ----------
        file_path : str or Path
            Path to npy file that will store the PCA projections
        max_channels_per_template : int, optional
            Maximum number of best channels to compute PCA projections on
        peak_sign : str, optional
            Peak sign to get best channels ('neg', 'pos', 'both'), by default 'neg'
        {}
        """
        p = self._params
        we = self.waveform_extractor
        sorting = we.sorting
        recording = we.recording

        assert sorting.get_num_segments() == 1
        assert p['mode'] in ('by_channel_local', 'by_channel_global')

        file_path = Path(file_path)

        all_spikes = sorting.get_all_spike_trains(outputs='unit_index')
        spike_times, spike_labels = all_spikes[0]

        max_channels_per_template = min(max_channels_per_template, we.recording.get_num_channels())

        best_channels_index = get_template_channel_sparsity(we, outputs="index", peak_sign=peak_sign,
                                                            num_channels=max_channels_per_template)

        unit_channels = [best_channels_index[unit_id] for unit_id in sorting.unit_ids]

        pca_model = self.get_pca_model()
        if p['mode'] in ['by_channel_global', 'concatenated']:
            pca_model = [pca_model] * recording.get_num_channels()

        # nSpikes, nFeaturesPerChannel, nPCFeatures
        # this come from  phy template-gui
        # https://github.com/kwikteam/phy-contrib/blob/master/docs/template-gui.md#datasets
        shape = (spike_times.size, p['n_components'], max_channels_per_template)
        all_pcs = np.lib.format.open_memmap(filename=file_path, mode='w+', dtype='float32', shape=shape)
        all_pcs_args = dict(filename=file_path, mode='r+', dtype='float32', shape=shape)

        # and run
        func = _all_pc_extractor_chunk
        init_func = _init_work_all_pc_extractor
        n_jobs = ensure_n_jobs(recording, job_kwargs.get('n_jobs', None))
        if n_jobs == 1:
            init_args = (recording,)
        else:
            init_args = (recording.to_dict(),)
        init_args = init_args + (all_pcs_args, spike_times, spike_labels, we.nbefore, we.nafter, 
                                 unit_channels, pca_model)
        processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name='extract PCs', **job_kwargs)
        processor.run()
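Unlike the earlier variant, this version passes the memmap arguments (all_pcs_args) instead of the memmap itself, so each worker can reopen the same file in 'r+' mode. A minimal sketch of that reopen pattern; the file name and sizes are illustrative only.

import numpy as np

# Illustrative: the parent process creates the file once...
shape = (100, 5, 16)
np.lib.format.open_memmap('pc_features.npy', mode='w+', dtype='float32', shape=shape)
all_pcs_args = dict(filename='pc_features.npy', mode='r+', dtype='float32', shape=shape)

# ...and each worker reopens the same file and writes into its own slice.
all_pcs = np.lib.format.open_memmap(**all_pcs_args)
all_pcs[0, :, :] = 1.0
all_pcs.flush()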
def localize_peaks(recording, peaks, ms_before=0.1, ms_after=0.3,
                   method='center_of_mass', method_kwargs={},
                   **job_kwargs):
    # ~ local_radius_um=150,
    """
    Localize peaks (spikes) in 2D or 3D depending on the method.
    When a probe is 2D then:
       * X is axis 0 of the probe
       * Y is axis 1 of the probe
       * Z is orthogonal to the plane of the probe

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor object
    peaks: numpy array
        Peak vector given by detect_peaks() with the 'numpy_compact' output.
    ms_before: float
        The left window, before a peak, in milliseconds
    ms_after: float
        The right window, after a peak, in milliseconds
    method: str
        Method to be used ('center_of_mass' or 'monopolar_triangulation')
    method_kwargs: dict of kwargs for the method
        'center_of_mass':
            * local_radius_um: float
                For channel sparsity
        'monopolar_triangulation' also has:
            * local_radius_um: float
                For channel sparsity
            * max_distance_um: float, default 1000
                boundary for distance estimation

    Returns
    -------
    peak_locations: np.array
        Array with estimated location for each spike
        The dtype depends on the method: ('x', 'y') or ('x', 'y', 'z', 'alpha')
    """
    assert method in possible_localization_methods, f"Method {method} is not supported. Choose from {possible_localization_methods}"
    
    # handle default method_kwargs
    method_kwargs = init_kwargs_dict(method, method_kwargs)
    
    nbefore = int(ms_before * recording.get_sampling_frequency() / 1000.)
    nafter = int(ms_after * recording.get_sampling_frequency() / 1000.)

    contact_locations = recording.get_channel_locations()
    
    # margin at border for get_trace
    margin = max(nbefore, nafter)

    # TODO
    #  make a memmap for peaks to avoid serialization

    # and run
    func = _localize_peaks_chunk
    init_func = _init_worker_localize_peaks
    init_args = (recording.to_dict(), peaks, method, method_kwargs, nbefore, nafter, contact_locations, margin)
    processor = ChunkRecordingExecutor(recording, func, init_func, init_args, handle_returns=True,
                                       job_name='localize peaks', **job_kwargs)
    peak_locations = processor.run()

    peak_locations = np.concatenate(peak_locations)

    return peak_locations
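For the 'monopolar_triangulation' method mentioned in the docstring, the idea is to fit a point source whose amplitude falls off as alpha / distance. A toy least-squares sketch of that model; the cost function, initialization and optimizer used by the library are assumptions, this only shows the principle.

import numpy as np
from scipy.optimize import least_squares

# Toy geometry and a toy ground-truth source above the probe plane.
contact_locations = np.array([[0., 0.], [0., 20.], [0., 40.], [20., 20.]])
contacts_3d = np.c_[contact_locations, np.zeros(len(contact_locations))]
true_source = np.array([5., 22., 15.])      # x, y, and z orthogonal to the probe plane
alpha = 500.
observed = alpha / np.linalg.norm(contacts_3d - true_source, axis=1)

def residuals(params):
    x, y, z, a = params
    dist = np.linalg.norm(contacts_3d - np.array([x, y, z]), axis=1)
    return a / dist - observed

fit = least_squares(residuals, x0=[0., 20., 10., 100.])
print(fit.x)   # recovered (x, y, z, alpha), close to the toy ground truth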
def get_spike_amplitudes(waveform_extractor,
                         peak_sign='neg',
                         outputs='concatenated',
                         return_scaled=True,
                         **job_kwargs):
    """
    Computes the spike amplitudes from a WaveformExtractor.

    1. The waveform extractor is used to determine the max channel per unit.
    2. Then a "peak_shift" is estimated because for some sorters the spike index is not always at the
       peak.
    3. Amplitudes are extracted in chunks (parallel or not)

    Parameters
    ----------
    waveform_extractor: WaveformExtractor
        The waveform extractor object
    peak_sign: str
        The sign to compute maximum channel:
            - 'neg'
            - 'pos'
            - 'both'
    return_scaled: bool
        If True and recording has gain_to_uV/offset_to_uV properties, amplitudes are converted to uV.
    outputs: str
        How the output should be returned:
            - 'concatenated'
            - 'by_unit'
    {}

    Returns
    -------
    amplitudes: np.array
        The spike amplitudes.
            - If 'concatenated' all amplitudes for all spikes and all units are concatenated
            - If 'by_unit', amplitudes are returned as a list (for segments) of dictionaries (for units)
    """
    we = waveform_extractor
    recording = we.recording
    sorting = we.sorting

    all_spikes = sorting.get_all_spike_trains()

    extremum_channels_index = get_template_extremum_channel(
        waveform_extractor, peak_sign=peak_sign, outputs='index')
    peak_shifts = get_template_extremum_channel_peak_shift(waveform_extractor,
                                                           peak_sign=peak_sign)

    if return_scaled:
        # check if has scaled values:
        if not waveform_extractor.recording.has_scaled_traces():
            print("Setting 'return_scaled' to False")
            return_scaled = False

    # and run
    func = _spike_amplitudes_chunk
    init_func = _init_worker_spike_amplitudes
    init_args = (recording.to_dict(), sorting.to_dict(),
                 extremum_channels_index, peak_shifts, return_scaled)
    processor = ChunkRecordingExecutor(recording,
                                       func,
                                       init_func,
                                       init_args,
                                       handle_returns=True,
                                       job_name='extract amplitudes',
                                       **job_kwargs)
    out = processor.run()
    amps, segments = zip(*out)
    amps = np.concatenate(amps)
    segments = np.concatenate(segments)

    amplitudes = []
    for segment_index in range(recording.get_num_segments()):
        mask = segments == segment_index
        amplitudes.append(amps[mask])

    if outputs == 'concatenated':
        return amplitudes
    elif outputs == 'by_unit':
        amplitudes_by_unit = []
        for segment_index in range(recording.get_num_segments()):
            amplitudes_by_unit.append({})
            for unit_id in sorting.unit_ids:
                spike_times, spike_labels = all_spikes[segment_index]
                mask = spike_labels == unit_id
                amps = amplitudes[segment_index][mask]
                amplitudes_by_unit[segment_index][unit_id] = amps
        return amplitudes_by_unit
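A hypothetical end-to-end call sketch; the commented-out extract_waveforms helper and import path are assumptions about the surrounding package, and recording/sorting are assumed to already exist.

# Hypothetical usage sketch; the commented imports are assumptions, not confirmed here.
# import spikeinterface as si
# we = si.extract_waveforms(recording, sorting, 'waveforms_folder', ms_before=1., ms_after=2.)

amplitudes = get_spike_amplitudes(we, peak_sign='neg', outputs='by_unit',
                                  return_scaled=True,
                                  n_jobs=2, chunk_size=10000, progress_bar=True)
unit0 = sorting.unit_ids[0]
print(amplitudes[0][unit0])   # amplitudes of the first unit in segment 0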