def test_ensure_n_jobs():
    recording = generate_recording()

    n_jobs = ensure_n_jobs(recording)
    assert n_jobs == 1

    n_jobs = ensure_n_jobs(recording, n_jobs=0)
    assert n_jobs == 1

    n_jobs = ensure_n_jobs(recording, n_jobs=1)
    assert n_jobs == 1

    # an in-memory recording is not dumpable, so n_jobs=-1 is forced back to 1
    n_jobs = ensure_n_jobs(recording, n_jobs=-1)
    assert n_jobs == 1

    # a saved recording is dumpable, so parallel jobs are allowed
    n_jobs = ensure_n_jobs(recording.save(), n_jobs=-1)
    assert n_jobs > 1
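The test above pins down the fallback logic: None, 0, and 1 all normalize to a single job, and n_jobs=-1 only expands when the recording can be dumped to worker processes. Below is a minimal sketch of such a helper, assuming the recording exposes an `is_dumpable` flag; the real SpikeInterface implementation may differ in detail.

import joblib

def ensure_n_jobs_sketch(recording, n_jobs=1):
    # -1 means "use all available cores"
    if n_jobs == -1:
        n_jobs = joblib.cpu_count()
    elif n_jobs is None or n_jobs == 0:
        n_jobs = 1
    # a recording that cannot be serialized cannot cross process boundaries,
    # so parallelism is silently disabled
    if n_jobs > 1 and not recording.is_dumpable:
        n_jobs = 1
    return int(n_jobs)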
Example #2
    def run_for_all_spikes(self, file_path, max_channels_per_template=16, peak_sign='neg', **job_kwargs):
        """
        Compute the PC projections for all spikes from the sorting.
        This is a long computation because waveforms need to be extracted from every spike.

        Used mainly for `export_to_phy()`.

        PCs are exported to a single .npy file.
        """
        p = self._params
        we = self.waveform_extractor
        sorting = we.sorting
        recording = we.recording

        
        assert sorting.get_num_segments() == 1
        assert p['mode'] in ('by_channel_local', 'by_channel_global')
        
        file_path = Path(file_path)

        all_spikes = sorting.get_all_spike_trains(outputs='unit_index')
        spike_times, spike_labels = all_spikes[0]        
        
        max_channels_per_template = min(max_channels_per_template, we.recording.get_num_channels())
        
        best_channels_index = get_template_best_channels(we, max_channels_per_template,
                                    peak_sign=peak_sign, outputs='index')
        unit_channels = [best_channels_index[unit_id] for unit_id in sorting.unit_ids]
        
        if p['mode'] == 'by_channel_local':
            all_pca = self._fit_by_channel_local()
        elif p['mode'] == 'by_channel_global':
            # a single global PCA model is shared across all channels
            one_pca = self._fit_by_channel_global()
            all_pca = [one_pca] * recording.get_num_channels()
        
        # nSpikes, nFeaturesPerChannel, nPCFeatures
        # this comes from the phy template-gui documentation
        # https://github.com/kwikteam/phy-contrib/blob/master/docs/template-gui.md#datasets
        shape = (spike_times.size, p['n_components'], max_channels_per_template)
        all_pcs = np.lib.format.open_memmap(file_path, mode='w+', dtype='float32', shape=shape)
        
        # and run
        func = _all_pc_extractor_chunk
        init_func = _init_work_all_pc_extractor
        n_jobs = ensure_n_jobs(recording, job_kwargs.get('n_jobs', None))
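        # a single job can use the recording object directly; parallel workers
        # receive its dict form so the recording can be re-instantiated in
        # each worker process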
        if n_jobs == 1:
            init_args = (recording, )
        else:
            init_args = (recording.to_dict(), )
        init_args = init_args + (all_pcs, spike_times, spike_labels, we.nbefore, we.nafter, unit_channels, all_pca)
        processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name='extract PCs', **job_kwargs)
        processor.run()
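For context, a hedged usage sketch: `pc` below stands for an already-computed principal-components extension object exposing this method; the file name and job parameters are illustrative only.

pc.run_for_all_spikes('pc_features.npy',
                      max_channels_per_template=16,
                      peak_sign='neg',
                      n_jobs=4, chunk_size=10000)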
Example #3
    def compute_amplitudes(self, **job_kwargs):

        we = self.waveform_extractor
        recording = we.recording
        sorting = we.sorting

        all_spikes = sorting.get_all_spike_trains(outputs='unit_index')
        self._all_spikes = all_spikes

        peak_sign = self._params['peak_sign']
        return_scaled = self._params['return_scaled']

        extremum_channels_index = get_template_extremum_channel(
            we, peak_sign=peak_sign, outputs='index')
        peak_shifts = get_template_extremum_channel_peak_shift(
            we, peak_sign=peak_sign)

        # vectorize extremum_channels_index and peak_shifts, ordered by unit
        extremum_channels_index = np.array(
            [extremum_channels_index[unit_id] for unit_id in sorting.unit_ids],
            dtype='int64')
        peak_shifts = np.array(
            [peak_shifts[unit_id] for unit_id in sorting.unit_ids],
            dtype='int64')

        if return_scaled:
            # check whether the recording actually has scaled traces
            if not we.recording.has_scaled_traces():
                print("Setting 'return_scaled' to False")
                return_scaled = False

        # and run
        func = _spike_amplitudes_chunk
        init_func = _init_worker_spike_amplitudes
        n_jobs = ensure_n_jobs(recording, job_kwargs.get('n_jobs', None))
        if n_jobs == 1:
            init_args = (recording, sorting)
        else:
            init_args = (recording.to_dict(), sorting.to_dict())
        init_args = init_args + (extremum_channels_index, peak_shifts,
                                 return_scaled)
        processor = ChunkRecordingExecutor(recording,
                                           func,
                                           init_func,
                                           init_args,
                                           handle_returns=True,
                                           job_name='extract amplitudes',
                                           **job_kwargs)
        out = processor.run()
        amps, segments = zip(*out)
        amps = np.concatenate(amps)
        segments = np.concatenate(segments)

        self._amplitudes = []
        for segment_index in range(recording.get_num_segments()):
            mask = segments == segment_index
            amps_seg = amps[mask]
            self._amplitudes.append(amps_seg)

            # save to folder
            file_amps = self.extension_folder / f'amplitude_segment_{segment_index}.npy'
            np.save(file_amps, amps_seg)
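Amplitudes are written one .npy file per segment, so they can be read back independently. A short sketch, with the extension folder path assumed for illustration:

import numpy as np
from pathlib import Path

extension_folder = Path('waveforms/spike_amplitudes')  # hypothetical location
amplitudes = [np.load(f)
              for f in sorted(extension_folder.glob('amplitude_segment_*.npy'))]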
Example #4
    def run_for_all_spikes(self, file_path, max_channels_per_template=16, peak_sign='neg',
                           **job_kwargs):
        """
        Project all spikes from the sorting onto the PCA model.
        This is a long computation because waveforms need to be extracted from every spike.

        Used mainly for `export_to_phy()`.

        PCs are exported to a single .npy file.

        Parameters
        ----------
        file_path : str or Path
            Path to npy file that will store the PCA projections
        max_channels_per_template : int, optional
            Maximum number of best channels to compute PCA projections on, by default 16
        peak_sign : str, optional
            Peak sign to get best channels ('neg', 'pos', 'both'), by default 'neg'
        {}
        """
        p = self._params
        we = self.waveform_extractor
        sorting = we.sorting
        recording = we.recording

        assert sorting.get_num_segments() == 1
        assert p['mode'] in ('by_channel_local', 'by_channel_global')

        file_path = Path(file_path)

        all_spikes = sorting.get_all_spike_trains(outputs='unit_index')
        spike_times, spike_labels = all_spikes[0]

        max_channels_per_template = min(max_channels_per_template, we.recording.get_num_channels())

        best_channels_index = get_template_channel_sparsity(we, outputs="index", peak_sign=peak_sign,
                                                            num_channels=max_channels_per_template)

        unit_channels = [best_channels_index[unit_id] for unit_id in sorting.unit_ids]

        pca_model = self.get_pca_model()
        if p['mode'] in ['by_channel_global', 'concatenated']:
            pca_model = [pca_model] * recording.get_num_channels()

        # nSpikes, nFeaturesPerChannel, nPCFeatures
        # this comes from the phy template-gui documentation
        # https://github.com/kwikteam/phy-contrib/blob/master/docs/template-gui.md#datasets
        shape = (spike_times.size, p['n_components'], max_channels_per_template)
        all_pcs = np.lib.format.open_memmap(filename=file_path, mode='w+', dtype='float32', shape=shape)
        all_pcs_args = dict(filename=file_path, mode='r+', dtype='float32', shape=shape)

        # and run
        func = _all_pc_extractor_chunk
        init_func = _init_work_all_pc_extractor
        n_jobs = ensure_n_jobs(recording, job_kwargs.get('n_jobs', None))
        if n_jobs == 1:
            init_args = (recording,)
        else:
            init_args = (recording.to_dict(),)
        init_args = init_args + (all_pcs_args, spike_times, spike_labels, we.nbefore, we.nafter, 
                                 unit_channels, pca_model)
        processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name='extract PCs', **job_kwargs)
        processor.run()
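Note the change from Example #2: instead of handing the open memmap to the workers, this version passes `all_pcs_args` so that each worker process reopens the shared .npy file itself in 'r+' mode. A minimal sketch of the worker-side init under that convention (a hypothetical helper, not the actual _init_work_all_pc_extractor):

import numpy as np

def _init_worker_sketch(recording, all_pcs_args, spike_times, spike_labels,
                        nbefore, nafter, unit_channels, pca_model):
    worker_ctx = {}
    # reopen the memmap created by the parent; chunks write to disjoint
    # spike rows, so no locking is needed
    worker_ctx['all_pcs'] = np.lib.format.open_memmap(**all_pcs_args)
    worker_ctx['spike_times'] = spike_times
    worker_ctx['spike_labels'] = spike_labels
    worker_ctx['nbefore'] = nbefore
    worker_ctx['nafter'] = nafter
    worker_ctx['unit_channels'] = unit_channels
    worker_ctx['pca_model'] = pca_model
    return worker_ctx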