Exemplo n.º 1
0
def faster_bad_channels(epochs, picks=None, thres=3, use_metrics=None):
    """Implements the first step of the FASTER algorithm.

    This function attempts to automatically mark bad EEG channels by
    performing outlier detection. It operates on epoched data, to make sure
    only relevant data is analyzed.

    Parameters
    ----------
    epochs : Instance of Epochs
        The epochs for which bad channels need to be marked.
    picks : list of int | None
        Channels to operate on. Defaults to EEG channels.
    thres : float
        The threshold value, in standard deviations, to apply. A channel
        crossing this threshold value is marked as bad. Defaults to 3.
    use_metrics : list of str
        List of metrics to use. Can be any combination of:
            'variance', 'correlation', 'hurst', 'kurtosis', 'line_noise'
        Defaults to all of them.

    Returns
    -------
    bads : list of str
        The names of the bad EEG channels.
    """
    # Each metric maps (n_channels, n_samples) data to one score per channel.
    metrics = {
        'variance':
        lambda x: np.var(x, axis=1),
        # Mean correlation with every *other* channel (diagonal masked out).
        'correlation':
        lambda x: np.nanmean(np.ma.masked_array(
            np.corrcoef(x), np.identity(len(x), dtype=bool)),
                             axis=0),
        'hurst':
        lambda x: hurst(x),
        'kurtosis':
        lambda x: kurtosis(x, axis=1),
        # Power at the 50 and 60 Hz mains frequencies.
        'line_noise':
        lambda x: _freqs_power(x, epochs.info['sfreq'], [50, 60]),
    }

    if picks is None:
        picks = mne.pick_types(epochs.info, meg=False, eeg=True, exclude=[])
    if use_metrics is None:
        use_metrics = metrics.keys()

    # Concatenate epochs in time: (n_channels, n_epochs * n_samples)
    data = epochs.get_data()
    data = data.transpose(1, 0, 2).reshape(data.shape[1], -1)
    data = data[picks]

    # Find bad channels
    bads = []
    for m in use_metrics:
        s = metrics[m](data)
        b = [epochs.ch_names[picks[i]] for i in find_outliers(s, thres)]
        # Log which channels each metric flagged, like the sibling steps do.
        logger.info('Bad by %s:\n\t%s' % (m, b))
        bads.append(b)

    return np.unique(np.concatenate(bads)).tolist()
Exemplo n.º 2
0
def faster_bad_components(ica, epochs, thres=3, use_metrics=None):
    """Implements the third step of the FASTER algorithm.

    This function attempts to automatically mark bad ICA components by
    performing outlier detection.

    Parameters
    ----------
    ica : Instance of ICA
        The ICA operator, already fitted to the supplied Epochs object.
    epochs : Instance of Epochs
        The untransformed epochs to analyze.
    thres : float
        The threshold value, in standard deviations, to apply. A component
        crossing this threshold value is marked as bad. Defaults to 3.
    use_metrics : list of str
        List of metrics to use. Can be any combination of:
            'eog_correlation', 'kurtosis', 'power_gradient', 'hurst',
            'median_gradient', 'line_noise'
        Defaults to all of them.

    Returns
    -------
    bads : list of int
        The indices of the bad components.

    See also
    --------
    ICA.find_bads_ecg
    ICA.find_bads_eog
    """
    # Flatten the ICA sources to (n_components, n_epochs * n_samples).
    source_data = ica.get_sources(epochs).get_data().transpose(1, 0, 2)
    source_data = source_data.reshape(source_data.shape[0], -1)

    # Each metric produces one (or several rows of) score(s) per component.
    metrics = {
        'eog_correlation': lambda x: x.find_bads_eog(epochs)[1],
        'kurtosis': lambda x: kurtosis(
            np.dot(x.mixing_matrix_.T,
                   x.pca_components_[:x.n_components_]),
            axis=1),
        'power_gradient': lambda x: _power_gradient(x, source_data),
        'hurst': lambda x: hurst(source_data),
        'median_gradient': lambda x: np.median(np.abs(np.diff(source_data)),
                                               axis=1),
        'line_noise': lambda x: _freqs_power(source_data,
                                             epochs.info['sfreq'], [50, 60]),
    }

    if use_metrics is None:
        use_metrics = metrics.keys()

    bads = []
    for metric in use_metrics:
        # Some metrics return several score rows; normalize to 2-D.
        all_scores = np.atleast_2d(metrics[metric](ica))
        for row in all_scores:
            outliers = find_outliers(row, thres)
            logger.info('Bad by %s:\n\t%s' % (metric, outliers))
            bads.append(outliers)

    return np.unique(np.concatenate(bads)).tolist()
Exemplo n.º 3
0
    def _update(self):
        """Accumulate channel statistics until enough samples have been
        collected, flag outlier channels once, then forward the parent's
        output (resampled to ``self._dsamp_freq`` when requested)."""
        # Have we collected enough samples without the new input?
        enough_collected = self._samples_collected >=\
                self._samples_to_be_collected
        if not enough_collected:
            if self.parent.output is not None and\
                    self.parent.output.shape[TIME_AXIS] > 0:
                self._update_statistics()

        elif not self._enough_collected:  # We just got enough samples
            self._enough_collected = True
            standard_deviations = self._calculate_standard_deviations()
            self._bad_channel_indices = find_outliers(standard_deviations)
            # BUG FIX: test the number of outliers, not any(); any([0]) is
            # False, so a lone bad channel at index 0 went undetected.
            if len(self._bad_channel_indices) > 0:
                # TODO: handle emergent bad channels on the go
                pass

        if self._dsamp_freq and self._dsamp_freq < self.mne_info['sfreq']:
            raw = mne.io.RawArray(self.parent.output, self.mne_info)
            raw.resample(self._dsamp_freq)
            self.output = raw.get_data()
            # BUG FIX: mne.io.RawArray exposes its measurement info as
            # ``.info``; ``.mne_info`` does not exist and raised
            # AttributeError on the downsampling path.
            self.mne_info = raw.info

        else:
            self.output = self.parent.output
Exemplo n.º 4
0
def faster_bad_channels_in_epochs(epochs,
                                  picks=None,
                                  thres=3,
                                  use_metrics=None):
    """Implements the fourth step of the FASTER algorithm.

    This function attempts to automatically mark bad channels in each epoch
    by performing outlier detection.

    Parameters
    ----------
    epochs : Instance of Epochs
        The epochs to analyze.
    picks : list of int | None
        Channels to operate on. Defaults to EEG channels.
    thres : float
        The threshold value, in standard deviations, to apply. A channel
        crossing this threshold value is marked as bad. Defaults to 3.
    use_metrics : list of str
        List of metrics to use. Can be any combination of:
            'amplitude', 'variance', 'deviation', 'median_gradient',
            'line_noise'
        Defaults to all of them.

    Returns
    -------
    bads : list of lists of str
        For each epoch, the names of the bad channels.
    """
    # Each metric maps (n_epochs, n_channels, n_samples) data to one score
    # per channel per epoch.
    metrics = {
        'amplitude': lambda x: np.ptp(x, axis=2),
        'deviation': lambda x: _deviation(x),
        'variance': lambda x: np.var(x, axis=2),
        'median_gradient': lambda x: np.median(np.abs(np.diff(x)), axis=2),
        'line_noise':
        lambda x: _freqs_power(x, epochs.info['sfreq'], [50, 60]),
    }

    if picks is None:
        picks = mne.pick_types(epochs.info,
                               meg=False,
                               eeg=True,
                               exclude='bads')
    if use_metrics is None:
        use_metrics = metrics.keys()

    data = epochs.get_data()[:, picks, :]

    bads = [[] for _ in range(len(epochs))]
    for m in use_metrics:
        s_epochs = metrics[m](data)  # shape: (n_epochs, n_channels)
        for i, s in enumerate(s_epochs):
            b = [epochs.ch_names[picks[j]] for j in find_outliers(s, thres)]
            logger.info('Epoch %d, Bad by %s:\n\t%s' % (i, m, b))
            bads[i].append(b)

    # Collapse each epoch's per-metric name lists into one unique list.
    for i, b in enumerate(bads):
        if len(b) > 0:
            bads[i] = np.unique(np.concatenate(b)).tolist()

    return bads
Exemplo n.º 5
0
def faster_bad_epochs(epochs, picks=None, thres=3, use_metrics=None):
    """Implements the second step of the FASTER algorithm.

    This function attempts to automatically mark bad epochs by performing
    outlier detection.

    Parameters
    ----------
    epochs : Instance of Epochs
        The epochs to analyze.
    picks : list of int | None
        Channels to operate on. Defaults to EEG channels.
    thres : float
        The threshold value, in standard deviations, to apply. An epoch
        crossing this threshold value is marked as bad. Defaults to 3.
    use_metrics : list of str
        List of metrics to use. Can be any combination of:
            'amplitude', 'variance', 'deviation'
        Defaults to all of them.

    Returns
    -------
    bads : list of int
        The indices of the bad epochs.
    """
    # Every metric reduces (n_epochs, n_channels, n_samples) data to a single
    # score per epoch by averaging a per-channel statistic over channels.
    metrics = {
        'amplitude': lambda d: np.mean(np.ptp(d, axis=2), axis=1),
        'deviation': lambda d: np.mean(_deviation(d), axis=1),
        'variance': lambda d: np.mean(np.var(d, axis=2), axis=1),
    }

    if picks is None:
        picks = mne.pick_types(epochs.info,
                               meg=False,
                               eeg=True,
                               exclude='bads')
    if use_metrics is None:
        use_metrics = metrics.keys()

    data = epochs.get_data()[:, picks, :]

    bads = []
    for metric_name in use_metrics:
        scores = metrics[metric_name](data)
        outlier_idx = find_outliers(scores, thres)
        logger.info('Bad by %s:\n\t%s' % (metric_name, outlier_idx))
        bads.append(outlier_idx)

    return np.unique(np.concatenate(bads)).tolist()
Exemplo n.º 6
0
    def _update(self):
        """Collect channel statistics until enough samples arrive, then
        detect bad channels once, build the interpolation matrix, and
        interpolate every subsequent chunk of input."""
        # Have we collected enough samples without the new input?
        enough_collected = self._samples_collected >= self._samples_to_be_collected
        if not enough_collected:
            if self.input_node.output is not None and self.input_node.output.shape[
                    TIME_AXIS] > 0:
                self._update_statistics()

        elif not self._enough_collected:  # We just got enough samples
            self._enough_collected = True
            standard_deviations = self._calculate_standard_deviations()
            self._bad_channel_indices = find_outliers(standard_deviations)
            # BUG FIX: any() is False when the only outlier is channel index
            # 0 (any([0]) == False), which silently skipped interpolation of
            # a bad first channel. Test the number of outliers instead.
            if len(self._bad_channel_indices) > 0:
                self._interpolation_matrix = self._calculate_interpolation_matrix(
                )
                message = Message(there_has_been_a_change=True,
                                  output_history_is_no_longer_valid=True)
                self._deliver_a_message_to_receivers(message)

        self.output = self._interpolate(self.input_node.output)
Exemplo n.º 7
0
def find_bad_epochs(epochs, picks=None, thresh=3.29053):
    """Find bad epochs based on amplitude, deviation, and variance.

    Inspired by [1], based on code by Marijn van Vliet [2]. This
    function is working on z-scores. You might want to select the
    thresholds according to how much of the data is expected to
    fall within the absolute bounds:

    95.0% --> 1.95996

    97.0% --> 2.17009

    99.0% --> 2.57583

    99.9% --> 3.29053

    Notes
    -----
    For this function to work, bad channels should have been identified
    and removed or interpolated beforehand. Additionally, baseline
    correction or highpass filtering is recommended to reduce signal
    drifts over time.

    Parameters
    ----------
    epochs : mne epochs object
        The epochs to analyze.

    picks : list of int | None
        Channels to operate on. Defaults to all clean EEG channels. Drops
        EEG channels marked as bad.

    thresh : float
        Epochs that surpass the threshold with their z-score based
        on amplitude, deviation, or variance, will be considered
        bad.

    Returns
    -------
    bads : list of int
        Indices of the bad epochs.

    References
    ----------
    .. [1] Nolan, H., Whelan, R., & Reilly, R. B. (2010). FASTER:
       fully automated statistical thresholding for EEG artifact
       rejection. Journal of neuroscience methods, 192(1), 152-162.

    .. [2] https://gist.github.com/wmvanvliet/d883c3fe1402c7ced6fc

    """
    if picks is None:
        picks = mne.pick_types(epochs.info,
                               meg=False,
                               eeg=True,
                               exclude="bads")

    def calc_deviation(data):
        ch_mean = np.mean(data, axis=2)
        return ch_mean - np.mean(ch_mean, axis=0)

    metrics = {
        "amplitude": lambda x: np.mean(np.ptp(x, axis=2), axis=1),
        "deviation": lambda x: np.mean(calc_deviation(x), axis=1),
        "variance": lambda x: np.mean(np.var(x, axis=2), axis=1),
    }

    data = epochs.get_data()[:, picks, :]

    bads = []
    for m in metrics.keys():
        signal = metrics[m](data)
        bad_idx = find_outliers(signal, thresh)
        bads.append(bad_idx)

    return np.unique(np.concatenate(bads)).tolist()