Example #1
def threshold_drift_metrics(sorting,
                            recording,
                            threshold,
                            threshold_sign,
                            metric_name="max_drift",
                            drift_metrics_interval_s=DriftMetric.params['drift_metrics_interval_s'],
                            drift_metrics_min_spikes_per_interval=DriftMetric.params['drift_metrics_min_spikes_per_interval'],
                            recording_params=get_recording_params(),
                            pca_scores_params=get_pca_scores_params(),
                            feature_params=get_feature_params(),
                            save_as_property=True,
                            seed=DriftMetric.params['seed'],
                            verbose=DriftMetric.params['verbose']):
    """
    Computes and thresholds the specified drift metric for the sorted dataset with the given sign and value.

    Parameters
    ----------
    sorting: SortingExtractor
        The sorting result to be evaluated.

    recording: RecordingExtractor
        The recording extractor from which the sorting was obtained.

    threshold: int or float
        The threshold for the given metric.

    threshold_sign: str
        If 'less', will threshold any metric less than the given threshold.
        If 'less_or_equal', will threshold any metric less than or equal to the given threshold.
        If 'greater', will threshold any metric greater than the given threshold.
        If 'greater_or_equal', will threshold any metric greater than or equal to the given threshold.

    metric_name: str
        The name of the drift metric to be thresholded (either "max_drift" or "cumulative_drift").

    drift_metrics_interval_s: float
        Time period for evaluating drift.

    drift_metrics_min_spikes_per_interval: int
        Minimum number of spikes for evaluating drift metrics per interval.

    recording_params: dict
        This dictionary should contain any subset of the following parameters:
            apply_filter: bool
                If True, recording is bandpass-filtered.
            freq_min: float
                High-pass frequency for optional filter (default 300 Hz).
            freq_max: float
                Low-pass frequency for optional filter (default 6000 Hz).

    pca_scores_params: dict
        This dictionary should contain any subset of the following parameters:
            ms_before: float
                Time period in ms to cut waveforms before the spike events
            ms_after: float
                Time period in ms to cut waveforms after the spike events
            dtype: dtype
                The numpy dtype of the waveforms
            max_spikes_per_unit: int
                The maximum number of spikes to extract per unit.
            max_spikes_for_pca: int
                The maximum number of spikes to use to compute PCA.

    feature_params: dict
        This dictionary should contain any subset of the following parameters:
            save_features_props: bool
                If True, features are saved in the sorting extractor.
            recompute_info: bool
                If True, waveforms are recomputed.
            max_spikes_per_unit: int
                The maximum number of spikes to extract per unit.

    save_as_property: bool
        If True, the metric is saved as a sorting property.
    
    seed: int
        Random seed for reproducibility

    verbose: bool
        If True, will be verbose in metric computation.

    Returns
    -------
    threshold_sorting: SortingExtractor
        The sorting extractor after thresholding by the given metric.
    """
    rp_dict, ps_dict, fp_dict = update_param_dicts(
        recording_params=recording_params,
        pca_scores_params=pca_scores_params,
        feature_params=feature_params)

    md = MetricData(sorting=sorting,
                    sampling_frequency=recording.get_sampling_frequency(),
                    recording=recording,
                    apply_filter=rp_dict["apply_filter"],
                    freq_min=rp_dict["freq_min"],
                    freq_max=rp_dict["freq_max"],
                    unit_ids=None,
                    epoch_tuples=None,
                    epoch_names=None,
                    verbose=verbose)

    md.compute_pca_scores(
        n_comp=ps_dict["n_comp"],
        ms_before=ps_dict["ms_before"],
        ms_after=ps_dict["ms_after"],
        dtype=ps_dict["dtype"],
        max_spikes_per_unit=fp_dict["max_spikes_per_unit"],
        max_spikes_for_pca=ps_dict["max_spikes_for_pca"],
        save_features_props=fp_dict['save_features_props'],
        recompute_info=fp_dict['recompute_info'],
        seed=seed,
    )

    dm = DriftMetric(metric_data=md)
    threshold_sorting = dm.threshold_metric(
        threshold, threshold_sign, metric_name, drift_metrics_interval_s,
        drift_metrics_min_spikes_per_interval, save_as_property)
    return threshold_sorting
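A minimal usage sketch for the example above. The toy dataset from spikeextractors and the threshold value are illustrative assumptions, not part of the original snippet.

import spikeextractors as se

# Hypothetical data: a paired recording/sorting from the spikeextractors toy dataset.
recording, sorting = se.example_datasets.toy_example(duration=10, seed=0)

# Curate out units whose maximum drift exceeds 2 (arbitrary threshold for illustration).
curated_sorting = threshold_drift_metrics(sorting=sorting,
                                          recording=recording,
                                          threshold=2,
                                          threshold_sign='greater',
                                          metric_name='max_drift')
print(curated_sorting.get_unit_ids())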
Example #2
def threshold_nn_metrics(
        sorting,
        recording,
        threshold,
        threshold_sign,
        metric_name="nn_hit_rate",
        num_channels_to_compare=NearestNeighbor.params['num_channels_to_compare'],
        max_spikes_per_cluster=NearestNeighbor.params['max_spikes_per_cluster'],
        max_spikes_for_nn=NearestNeighbor.params['max_spikes_for_nn'],
        n_neighbors=NearestNeighbor.params['n_neighbors'],
        recording_params=get_recording_params(),
        pca_scores_params=get_pca_scores_params(),
        feature_params=get_feature_params(),
        save_as_property=True,
        seed=NearestNeighbor.params['seed'],
        verbose=NearestNeighbor.params['verbose']):
    """
    Computes and thresholds the specified nearest neighbor metric for the sorted dataset with the given sign and value.

    Parameters
    ----------
    sorting: SortingExtractor
        The sorting result to be evaluated.

    recording: RecordingExtractor
        The recording extractor from which the sorting was obtained.

    threshold: int or float
        The threshold for the given metric.

    threshold_sign: str
        If 'less', will threshold any metric less than the given threshold.
        If 'less_or_equal', will threshold any metric less than or equal to the given threshold.
        If 'greater', will threshold any metric greater than the given threshold.
        If 'greater_or_equal', will threshold any metric greater than or equal to the given threshold.

    metric_name: str
        The name of the nearest neighbor metric to be thresholded (either "nn_hit_rate" or "nn_miss_rate").

    num_channels_to_compare: int
        The number of channels to be used for the PC extraction and comparison.

    max_spikes_per_cluster: int
        Max spikes to be used from each unit.

    max_spikes_for_nn: int
        Max spikes to be used for nearest-neighbors calculation.
    
    n_neighbors: int
        Number of neighbors to compare.

    recording_params: dict
        This dictionary should contain any subset of the following parameters:
            apply_filter: bool
                If True, recording is bandpass-filtered.
            freq_min: float
                High-pass frequency for optional filter (default 300 Hz).
            freq_max: float
                Low-pass frequency for optional filter (default 6000 Hz).

    pca_scores_params: dict
        This dictionary should contain any subset of the following parameters:
            ms_before: float
                Time period in ms to cut waveforms before the spike events
            ms_after: float
                Time period in ms to cut waveforms after the spike events
            dtype: dtype
                The numpy dtype of the waveforms
            max_spikes_per_unit: int
                The maximum number of spikes to extract per unit.
            max_spikes_for_pca: int
                The maximum number of spikes to use to compute PCA.

    feature_params: dict
        This dictionary should contain any subset of the following parameters:
            save_features_props: bool
                If True, features are saved in the sorting extractor.
            recompute_info: bool
                If True, waveforms are recomputed.
            max_spikes_per_unit: int
                The maximum number of spikes to extract per unit.

    save_as_property: bool
        If True, the metric is saved as a sorting property.
    
    seed: int
        Random seed for reproducibility

    verbose: bool
        If True, will be verbose in metric computation.

    Returns
    -------
    threshold_sorting: SortingExtractor
        The sorting extractor after thresholding by the given metric.
    """
    rp_dict, ps_dict, fp_dict = update_param_dicts(
        recording_params=recording_params,
        pca_scores_params=pca_scores_params,
        feature_params=feature_params)

    md = MetricData(sorting=sorting,
                    sampling_frequency=recording.get_sampling_frequency(),
                    recording=recording,
                    apply_filter=rp_dict["apply_filter"],
                    freq_min=rp_dict["freq_min"],
                    freq_max=rp_dict["freq_max"],
                    unit_ids=None,
                    epoch_tuples=None,
                    epoch_names=None,
                    verbose=verbose)
    md.compute_pca_scores(
        n_comp=ps_dict["n_comp"],
        ms_before=ps_dict["ms_before"],
        ms_after=ps_dict["ms_after"],
        dtype=ps_dict["dtype"],
        max_spikes_per_unit=fp_dict["max_spikes_per_unit"],
        max_spikes_for_pca=ps_dict["max_spikes_for_pca"],
        save_features_props=fp_dict['save_features_props'],
        recompute_info=fp_dict['recompute_info'],
        seed=seed,
    )

    nn = NearestNeighbor(metric_data=md)
    threshold_sorting = nn.threshold_metric(threshold, threshold_sign,
                                            metric_name,
                                            num_channels_to_compare,
                                            max_spikes_per_cluster,
                                            max_spikes_for_nn, n_neighbors,
                                            seed, save_as_property)
    return threshold_sorting
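A minimal usage sketch for the example above, under the same assumptions (toy data from spikeextractors, illustrative threshold value).

import spikeextractors as se

# Hypothetical data: a paired recording/sorting from the spikeextractors toy dataset.
recording, sorting = se.example_datasets.toy_example(duration=10, seed=0)

# Curate out units whose nearest-neighbor hit rate falls below 0.8 (illustrative value).
curated_sorting = threshold_nn_metrics(sorting=sorting,
                                       recording=recording,
                                       threshold=0.8,
                                       threshold_sign='less',
                                       metric_name='nn_hit_rate')
print(curated_sorting.get_unit_ids())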
Example #3

def threshold_drift_metrics(
        sorting,
        recording,
        threshold,
        threshold_sign,
        metric_name="max_drift",
        drift_metrics_interval_s=DriftMetric.params['drift_metrics_interval_s'],
        drift_metrics_min_spikes_per_interval=DriftMetric.params['drift_metrics_min_spikes_per_interval'],
        **kwargs
):
    """
    Computes and thresholds the specified drift metric for the sorted dataset with the given sign and value.

    Parameters
    ----------
    sorting: SortingExtractor
        The sorting result to be evaluated.
    recording: RecordingExtractor
        The recording extractor from which the sorting was obtained.
    threshold: int or float
        The threshold for the given metric.
    threshold_sign: str
        If 'less', will threshold any metric less than the given threshold.
        If 'less_or_equal', will threshold any metric less than or equal to the given threshold.
        If 'greater', will threshold any metric greater than the given threshold.
        If 'greater_or_equal', will threshold any metric greater than or equal to the given threshold.
    metric_name: str
        The name of the drift metric to be thresholded (either "max_drift" or "cumulative_drift").
    drift_metrics_interval_s: float
        Time period for evaluating drift.
    drift_metrics_min_spikes_per_interval: int
        Minimum number of spikes for evaluating drift metrics per interval.
    **kwargs: keyword arguments
        Keyword arguments among the following:
            method: str
                If 'absolute' (default), absolute amplitudes in uV are returned.
                If 'relative', amplitudes are returned as ratios between waveform amplitudes and template amplitudes
            peak: str
                Whether the maximum channel is found among negative peaks ('neg'), positive peaks ('pos'),
                or both ('both', default)
            frames_before: int
                Frames before peak to compute amplitude
            frames_after: int
                Frames after peak to compute amplitude
            apply_filter: bool
                If True, recording is bandpass-filtered
            freq_min: float
                High-pass frequency for optional filter (default 300 Hz)
            freq_max: float
                Low-pass frequency for optional filter (default 6000 Hz)
            grouping_property: str
                Property to group channels. E.g. if the recording extractor has the 'group' property and
                'grouping_property' is 'group', then waveforms are computed group-wise.
            ms_before: float
                Time period in ms to cut waveforms before the spike events
            ms_after: float
                Time period in ms to cut waveforms after the spike events
            dtype: dtype
                The numpy dtype of the waveforms
            compute_property_from_recording: bool
                If True and 'grouping_property' is given, the property of each unit is assigned as the corresponding
                property of the recording extractor channel on which the average waveform is the largest
            max_channels_per_waveforms: int or None
                Maximum channels per waveforms to return. If None, all channels are returned
            n_jobs: int
                Number of parallel jobs (default 1)
            memmap: bool
                If True, waveforms are saved as memmap object (recommended for long recordings with many channels)
            save_property_or_features: bool
                If True, features are saved in the sorting extractor
            recompute_info: bool
                If True, waveforms are recomputed
            max_spikes_per_unit: int
                The maximum number of spikes to extract per unit
            seed: int
                Random seed for reproducibility
            verbose: bool
                If True, will be verbose in metric computation

    Returns
    -------
    threshold_sorting: SortingExtractor
        The sorting extractor after thresholding by the given metric.
    """
    params_dict = update_all_param_dicts_with_kwargs(kwargs)

    md = MetricData(sorting=sorting, sampling_frequency=recording.get_sampling_frequency(), recording=recording,
                    apply_filter=params_dict["apply_filter"], freq_min=params_dict["freq_min"],
                    duration_in_frames=None, freq_max=params_dict["freq_max"], unit_ids=None, verbose=params_dict['verbose'])

    md.compute_pca_scores(**kwargs)

    dm = DriftMetric(metric_data=md)
    threshold_sorting = dm.threshold_metric(threshold, threshold_sign, metric_name, drift_metrics_interval_s,
                                            drift_metrics_min_spikes_per_interval, **kwargs)
    return threshold_sorting
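A minimal usage sketch of the **kwargs variant above. The keyword arguments shown are taken from the docstring; the data and values are illustrative assumptions.

import spikeextractors as se

# Hypothetical data: a paired recording/sorting from the spikeextractors toy dataset.
recording, sorting = se.example_datasets.toy_example(duration=10, seed=0)

# Curate out units whose cumulative drift exceeds 3; waveform options pass through **kwargs.
curated_sorting = threshold_drift_metrics(sorting=sorting,
                                          recording=recording,
                                          threshold=3,
                                          threshold_sign='greater',
                                          metric_name='cumulative_drift',
                                          apply_filter=True,  # bandpass-filter before PC scores
                                          ms_before=1.0,      # waveform window, ms before each spike
                                          ms_after=2.0,       # waveform window, ms after each spike
                                          seed=0)
print(curated_sorting.get_unit_ids())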