コード例 #1
0
def test_autoreject():
    """Some basic tests for autoreject.

    Checks argument validation of ``GlobalAutoReject``, that
    ``get_rejection_threshold`` returns a dict, that unsupported
    ``validation_curve`` / ``LocalAutoRejectCV`` calls raise, and that
    ``compute_thresholds`` runs with each supported search method.
    """

    event_id = {'Visual/Left': 3}
    tmin, tmax = -0.2, 0.5
    events = mne.find_events(raw)

    # Only a handful of EEG channels, selected by name via ``include``,
    # to keep the test fast.
    include = [u'EEG %03d' % i for i in range(1, 15)]
    picks = mne.pick_types(raw.info, meg=False, eeg=False, stim=False,
                           eog=False, include=include, exclude=[])
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                        picks=picks, baseline=(None, 0), decim=8,
                        reject=None, add_eeg_ref=False, preload=True)

    # GlobalAutoReject works on epochs flattened to
    # (n_epochs, n_channels * n_times).
    X = epochs.get_data()
    n_epochs, n_channels, n_times = X.shape
    X = X.reshape(n_epochs, -1)

    # Fitting is expected to raise unless both n_channels and n_times are
    # provided.
    ar = GlobalAutoReject()
    assert_raises(ValueError, ar.fit, X)
    ar = GlobalAutoReject(n_channels=n_channels)
    assert_raises(ValueError, ar.fit, X)
    ar = GlobalAutoReject(n_times=n_times)
    assert_raises(ValueError, ar.fit, X)
    ar = GlobalAutoReject(n_channels=n_channels, n_times=n_times,
                          thresh=40e-6)
    ar.fit(X)

    reject = get_rejection_threshold(epochs)
    # ``assert_true(expr, msg)`` treats its second argument as the failure
    # message, so the type check itself must be the first argument.
    assert_true(isinstance(reject, dict))

    param_name = 'thresh'
    param_range = np.linspace(40e-6, 200e-6, 10)
    assert_raises(ValueError, validation_curve, ar, X, None,
                  param_name, param_range)

    ar = LocalAutoReject()
    assert_raises(NotImplementedError, validation_curve, ar, epochs, None,
                  param_name, param_range)

    # Fitting a plain array and transforming before fit are expected to
    # raise ValueError.
    ar = LocalAutoRejectCV()
    assert_raises(ValueError, ar.fit, X)
    assert_raises(ValueError, ar.transform, X)
    assert_raises(ValueError, ar.transform, epochs)

    # The epochs were created with ``preload=True``, so this is expected to
    # be a no-op.
    epochs.load_data()
    assert_raises(ValueError, compute_thresholds, epochs, 'dfdfdf')
    for method in ['random_search', 'bayesian_optimization']:
        compute_thresholds(epochs, method=method)
コード例 #2
0
# Sensor-level thresholds are computed by random search over ``picks``,
# with a fixed random_state (42).
thresh_func = partial(compute_thresholds, picks=picks, method='random_search',
                      random_state=42)

###############################################################################
# :class:`autoreject.LocalAutoRejectCV` internally does cross-validation to
# determine the optimal values :math:`\rho^{*}` and :math:`\kappa^{*}`

###############################################################################
# Note that :class:`autoreject.LocalAutoRejectCV` by design supports
# multiple channel types.
# If no picks are passed, separate solutions will be computed for each channel
# type and internally combined. This then readily supports cleaning
# unseen epochs from the different channel types used during fit.
# Here we only use a subset of channels to save time.

ar = LocalAutoRejectCV(n_interpolates, consensus_percs, picks=picks,
                       thresh_func=thresh_func)

# Note that fitting and transforming can be done on different compatible
# portions of data if needed.
ar.fit(epochs['Auditory/Left'])
epochs_clean = ar.transform(epochs['Auditory/Left'])
# Average before and after cleaning to compare the evoked responses.
evoked_clean = epochs_clean.average()
evoked = epochs['Auditory/Left'].average()

###############################################################################
# Now, we will manually mark the bad channels just for plotting.

evoked.info['bads'] = ['MEG 2443']
evoked_clean.info['bads'] = ['MEG 2443']
###############################################################################
コード例 #3
0
# because we do not want epochs to be dropped when instantiating
# :class:`mne.Epochs`.

###############################################################################

# ``reject=None``: no epochs are dropped when the Epochs object is created,
# so every epoch is available to autoreject.
epochs = Epochs(raw, events, event_id, tmin, tmax,
                picks=picks, baseline=(None, 0), reject=None,
                verbose=False, detrend=0, preload=True)

###############################################################################
# :class:`autoreject.LocalAutoRejectCV` internally does cross-validation to
# determine the optimal values :math:`\rho^{*}` and :math:`\kappa^{*}`

###############################################################################

# ``compute_thresholds`` is passed positionally, presumably as the threshold
# function (``thresh_func``).
# NOTE(review): confirm the positional order against the installed autoreject
# signature; passing ``thresh_func=`` by keyword would be clearer.
ar = LocalAutoRejectCV(n_interpolates, consensus_percs, compute_thresholds)
epochs_clean = ar.fit_transform(epochs)

# Evoked responses before and after cleaning, for comparison.
evoked = epochs.average()
evoked_clean = epochs_clean.average()

###############################################################################
# Now, we will manually mark the bad channels just for plotting.

###############################################################################

evoked.info['bads'] = ['MEG 2443']
evoked_clean.info['bads'] = ['MEG 2443']

###############################################################################
# Let us plot the results.
コード例 #4
0
###############################################################################
# Note that :class:`autoreject.LocalAutoRejectCV` by design supports multiple
# channel types. If no picks are passed, separate solutions will be computed
# for each channel type and internally combined. This then readily supports
# cleaning unseen epochs from the different channel types used during fit.
# Here we only use a subset of channels to save time.

###############################################################################
# Also note that once the parameters are learned, any data can be repaired
# that contains channels that were used during fit. This also means that time
# may be saved by fitting :class:`autoreject.LocalAutoRejectCV` on a
# representative subsample of the data.


ar = LocalAutoRejectCV(thresh_func=thresh_func, verbose='tqdm', picks=picks)

# Fit and transform as two separate steps (here on the same data).
ar.fit(this_epoch)
epochs_ar = ar.transform(this_epoch)

###############################################################################
# We can visualize the cross validation curve over two variables

# ``# noqa`` silences the linter warning for imports below the top of the file.
import numpy as np  # noqa
import matplotlib.pyplot as plt  # noqa
import matplotlib.patches as patches  # noqa
from autoreject import set_matplotlib_defaults  # noqa

set_matplotlib_defaults(plt, style='seaborn-white')
# NOTE(review): averages over the last axis, presumably the CV folds -- confirm.
loss = ar.loss_['eeg'].mean(axis=-1)  # losses are stored by channel type.
コード例 #5
0
    #######################################################################
    # Threshold function: random search with a fixed random_state (42).
    from functools import partial
    thresh_func = partial(compute_thresholds,
                          method='random_search',
                          random_state=42)

    ######################################################################
    # :class:`autoreject.LocalAutoRejectCV` internally does cross-validation to
    # determine the optimal values :math:`\rho^{*}` and :math:`\kappa^{*}`

    #####################################################################
    # Gradiometers and magnetometers are cleaned by two independent
    # autoreject instances, each fitted on its own copy of the data.
    epochs_grad = epochs.copy().pick_types(meg="grad")
    epochs_mag = epochs.copy().pick_types(meg="mag")

    ar_grad = LocalAutoRejectCV(n_interpolates,
                                consensus_percs,
                                thresh_func=thresh_func,
                                verbose="progressbar")

    ar_mag = LocalAutoRejectCV(n_interpolates,
                               consensus_percs,
                               thresh_func=thresh_func,
                               verbose="progressbar")

    epochs_grad_clean = ar_grad.fit_transform(epochs_grad)
    epochs_mag_clean = ar_mag.fit_transform(epochs_mag)

    epochs_clean = epochs.copy()

    # Union of the epoch indices flagged bad by either instance; built via a
    # set, so the resulting order is arbitrary.
    bads_grads = ar_grad.bad_epochs_idx
    bads_mags = ar_mag.bad_epochs_idx
    bads_comb = list(set(list(bads_mags) + list(bads_grads)))
コード例 #6
0
def test_autoreject():
    """Test basic LocalAutoReject functionality.

    Covers argument validation of ``LocalAutoReject``, ``GlobalAutoReject``
    and ``LocalAutoRejectCV``; that the cross-validated kappa/rho values come
    from the candidate grids; that ``transform`` does not mutate the fitted
    state; that ``bad_segments`` only flags picked channels; and that
    ``compute_thresholds`` returns one threshold per picked channel for both
    search methods.
    """

    event_id = None
    tmin, tmax = -0.2, 0.5
    events = mne.find_events(raw)

    ##########################################################################
    # picking epochs
    include = [u'EEG %03d' % i for i in range(1, 45, 3)]
    picks = mne.pick_types(raw.info,
                           meg=False,
                           eeg=False,
                           stim=False,
                           eog=True,
                           include=include,
                           exclude=[])
    epochs = mne.Epochs(raw,
                        events,
                        event_id,
                        tmin,
                        tmax,
                        picks=picks,
                        baseline=(None, 0),
                        decim=10,
                        reject=None,
                        preload=False)[:10]

    # Fitting epochs that are not preloaded is expected to raise.
    ar = LocalAutoReject()
    assert_raises(ValueError, ar.fit, epochs)
    epochs.load_data()

    ar.fit(epochs)
    # One channel of ``picks`` is not used -- presumably the EOG channel.
    assert_true(len(ar.picks) == len(picks) - 1)

    # epochs with no picks.
    epochs = mne.Epochs(raw,
                        events,
                        event_id,
                        tmin,
                        tmax,
                        baseline=(None, 0),
                        decim=10,
                        reject=None,
                        preload=True)[:20]
    # let's drop some channels to speed up: keep a sparse subset of each
    # MEG/EEG channel type plus all EOG channels.
    pre_picks = np.r_[
        mne.pick_types(epochs.info, meg='mag', eeg=False)[::15],
        mne.pick_types(epochs.info, meg='grad', eeg=False)[::60],
        mne.pick_types(epochs.info, meg=False, eeg=True)[::16],
        mne.pick_types(epochs.info, meg=False, eeg=False, eog=True)]
    pick_ch_names = [epochs.ch_names[pp] for pp in pre_picks]
    epochs.pick_channels(pick_ch_names)
    epochs_fit = epochs[:10]
    epochs_new = epochs[10:]

    X = epochs_fit.get_data()
    n_epochs, n_channels, n_times = X.shape
    X = X.reshape(n_epochs, -1)

    # GlobalAutoReject on flattened data raises unless both n_channels and
    # n_times are given.
    ar = GlobalAutoReject()
    assert_raises(ValueError, ar.fit, X)
    ar = GlobalAutoReject(n_channels=n_channels)
    assert_raises(ValueError, ar.fit, X)
    ar = GlobalAutoReject(n_times=n_times)
    assert_raises(ValueError, ar.fit, X)
    ar_global = GlobalAutoReject(n_channels=n_channels,
                                 n_times=n_times,
                                 thresh=40e-6)
    ar_global.fit(X)

    param_name = 'thresh'
    param_range = np.linspace(40e-6, 200e-6, 10)
    assert_raises(ValueError, validation_curve, ar_global, X, None, param_name,
                  param_range)

    ##########################################################################
    # picking AutoReject
    picks = mne.pick_types(epochs.info,
                           meg='mag',
                           eeg=True,
                           stim=False,
                           eog=False,
                           include=[],
                           exclude=[])
    ch_types = ['mag', 'eeg']

    ar = LocalAutoReject(picks=picks)
    assert_raises(NotImplementedError, validation_curve, ar, epochs, None,
                  param_name, param_range)

    thresh_func = partial(compute_thresholds,
                          method='bayesian_optimization',
                          random_state=42)
    ar = LocalAutoRejectCV(cv=3,
                           picks=picks,
                           thresh_func=thresh_func,
                           n_interpolates=[1, 2],
                           consensus_percs=[0.5, 1])
    # Fitting a plain array and transforming before fit are expected to raise.
    assert_raises(AttributeError, ar.fit, X)
    assert_raises(ValueError, ar.transform, X)
    assert_raises(ValueError, ar.transform, epochs)

    ar.fit(epochs_fit)
    fix_log = ar.fix_log
    bad_epochs_idx = ar.local_reject_.bad_epochs_idx_
    good_epochs_idx = ar.local_reject_.good_epochs_idx_
    for ch_type in ch_types:
        # test that kappa & rho are selected
        assert_true(ar.n_interpolate_[ch_type] in ar.n_interpolates)
        assert_true(ar.consensus_perc_[ch_type] in ar.consensus_percs)
        # test that local autoreject is synced with AR-CV instance
        assert_equal(ar.n_interpolate_[ch_type],
                     ar.local_reject_.n_interpolate[ch_type])
        assert_equal(ar.consensus_perc_[ch_type],
                     ar.local_reject_.consensus_perc[ch_type])

    # test complementarity of goods and bads
    assert_array_equal(np.sort(np.r_[bad_epochs_idx, good_epochs_idx]),
                       np.arange(len(epochs_fit)))

    # test that transform does not change state of ar
    # NOTE(review): the purpose of setting ``fit_`` on the Epochs object is
    # not visible here; kept unchanged.
    epochs_fit.fit_ = True
    epochs_clean = ar.transform(epochs_fit)  # apply same data
    assert_array_equal(fix_log, ar.fix_log)
    assert_array_equal(bad_epochs_idx, ar.local_reject_.bad_epochs_idx_)
    assert_array_equal(good_epochs_idx, ar.local_reject_.good_epochs_idx_)

    epochs_new_clean = ar.transform(epochs_new)  # apply to new data
    assert_array_equal(fix_log, ar.fix_log)
    assert_array_equal(bad_epochs_idx, ar.local_reject_.bad_epochs_idx_)
    assert_array_equal(good_epochs_idx, ar.local_reject_.good_epochs_idx_)

    # Cleaning is expected to modify the new data. ``np.array_equal`` also
    # returns False on a shape mismatch (e.g. dropped epochs) rather than
    # raising or broadcasting.
    assert_true(not np.array_equal(epochs_new_clean.get_data(),
                                   epochs_new.get_data()))

    assert_equal(epochs_clean.ch_names, epochs_fit.ch_names)
    # Now we test that the .bad_segments has the shape
    # of n_trials, n_sensors, such that n_sensors is the
    # full number of sensors, before picking. We, hence,
    # expect nothing to be rejected outside of our picks
    # but rejections can occur inside our picks.
    assert_equal(ar.bad_segments.shape[1], len(epochs_fit.ch_names))
    assert_true(np.any(ar.bad_segments[:, picks]))
    non_picks = np.ones(len(epochs_fit.ch_names), dtype=bool)
    non_picks[picks] = False
    assert_true(not np.any(ar.bad_segments[:, non_picks]))

    assert_true(isinstance(ar.threshes_, dict))
    assert_true(len(ar.picks) == len(picks))
    assert_true(len(ar.threshes_.keys()) == len(ar.picks))
    # No threshold should be learned for any (unpicked) EOG channel.
    # ``pick_types`` returns an index array, and indexing a list with a numpy
    # array raises TypeError, so check each integer index separately.
    pick_eog = mne.pick_types(epochs.info, meg=False, eeg=False, eog=True)
    for pp in pick_eog:
        assert_true(epochs.ch_names[pp] not in ar.threshes_)
    assert_raises(
        IndexError, ar.transform,
        epochs.copy().pick_channels([epochs.ch_names[pp] for pp in picks[:3]]))

    epochs.load_data()
    assert_raises(ValueError, compute_thresholds, epochs, 'dfdfdf')
    # One threshold is expected per picked channel.
    ch_names = [epochs_fit.ch_names[pp] for pp in picks]
    threshes_a = compute_thresholds(epochs_fit,
                                    picks=picks,
                                    method='random_search')
    assert_equal(set(threshes_a.keys()), set(ch_names))
    threshes_b = compute_thresholds(epochs_fit,
                                    picks=picks,
                                    method='bayesian_optimization')
    assert_equal(set(threshes_b.keys()), set(ch_names))
コード例 #7
0
    #######################################################################
    # Threshold function: random search with a fixed random_state (42).
    from functools import partial
    thresh_func = partial(
        compute_thresholds, method='random_search', random_state=42)

    ######################################################################
    # :class:`autoreject.LocalAutoRejectCV` internally does cross-validation to
    # determine the optimal values :math:`\rho^{*}` and :math:`\kappa^{*}`

    #####################################################################
    # Gradiometers and magnetometers are cleaned by two independent
    # autoreject instances, each fitted on its own copy of the data.
    epochs_grad = epochs.copy().pick_types(meg="grad")
    epochs_mag = epochs.copy().pick_types(meg="mag")

    ar_grad = LocalAutoRejectCV(
        n_interpolates,
        consensus_percs,
        thresh_func=thresh_func,
        verbose="progressbar")

    ar_mag = LocalAutoRejectCV(
        n_interpolates,
        consensus_percs,
        thresh_func=thresh_func,
        verbose="progressbar")

    epochs_grad_clean = ar_grad.fit_transform(epochs_grad)
    epochs_mag_clean = ar_mag.fit_transform(epochs_mag)

    epochs_clean = epochs.copy()

    # Epoch indices flagged bad by the gradiometer instance.
    bads_grads = ar_grad.bad_epochs_idx
コード例 #8
0
# Threshold function with a fixed random_state (42) and a single job.
thresh_func = partial(compute_thresholds, random_state=42, n_jobs=1)

###############################################################################
# Note that :class:`autoreject.LocalAutoRejectCV` by design supports multiple
# channel types. If no picks are passed, separate solutions will be computed
# for each channel type and internally combined. This then readily supports
# cleaning unseen epochs from the different channel types used during fit.
# Here we only use a subset of channels to save time.

###############################################################################
# Also note that once the parameters are learned, any data can be repaired
# that contains channels that were used during fit. This also means that time
# may be saved by fitting :class:`autoreject.LocalAutoRejectCV` on a
# representative subsample of the data.

ar = LocalAutoRejectCV(thresh_func=thresh_func, picks=picks, verbose='tqdm')

# Fit and clean in one call; ``return_log=True`` also returns the reject log.
epochs_ar, reject_log = ar.fit_transform(this_epoch, return_log=True)

###############################################################################
# We can visualize the cross validation curve over two variables

# ``# noqa`` silences the linter warning for imports below the top of the file.
import numpy as np  # noqa
import matplotlib.pyplot as plt  # noqa
import matplotlib.patches as patches  # noqa
from autoreject import set_matplotlib_defaults  # noqa

set_matplotlib_defaults(plt, style='seaborn-white')
# NOTE(review): averages over the last axis, presumably the CV folds -- confirm.
loss = ar.loss_['eeg'].mean(axis=-1)  # losses are stored by channel type.

# Transposed loss grid scaled by 1e6 (presumably volts -> microvolts).
plt.matshow(loss.T * 1e6, cmap=plt.get_cmap('viridis'))
コード例 #9
0
                    picks=picks,
                    preload=True)

epochs.decimate(decim=4)  # decim from 1024 to 256

# set up function to compute sensor-level threshold
thresh_func = partial(compute_thresholds,
                      picks=picks,
                      method='bayesian_optimization')

#---------------------------------------
# Run autoreject
#---------------------------------------

# Bare expression: only displays the channel names in an interactive
# session; it has no effect when run as a script.
epochs.ch_names
ar = LocalAutoRejectCV(picks=picks, thresh_func=thresh_func)
epochs_clean = ar.fit_transform(epochs)

# Also compute rejection thresholds with ``get_rejection_threshold`` on the
# same (uncleaned) epochs; the bare ``reject`` below only displays the result
# in an interactive session.
from autoreject import get_rejection_threshold
reject = get_rejection_threshold(epochs)
reject

#---------------------------------------
# Check autoreject results
#---------------------------------------

# plot epochs

plot_epochs(epochs,
            bad_epochs_idx=ar.bad_epochs_idx,
            n_channels=64,
コード例 #10
0
ファイル: test_autoreject.py プロジェクト: cdla/autoreject
def test_autoreject():
    """Test basic LocalAutoReject functionality.

    Covers argument validation of ``LocalAutoReject``, ``GlobalAutoReject``
    and ``LocalAutoRejectCV``; that the cross-validated ``n_interpolate`` /
    ``consensus`` values come from the candidate grids; that ``transform``
    leaves the reject log of already-seen data unchanged; the shape and NaN
    pattern of reject-log labels; per-type interpolation counts; and that
    ``compute_thresholds`` returns one threshold per picked channel.
    """
    event_id = None
    tmin, tmax = -0.2, 0.5
    events = mne.find_events(raw)

    ##########################################################################
    # picking epochs
    include = [u'EEG %03d' % i for i in range(1, 45, 3)]
    picks = mne.pick_types(raw.info,
                           meg=False,
                           eeg=False,
                           stim=False,
                           eog=True,
                           include=include,
                           exclude=[])
    epochs = mne.Epochs(raw,
                        events,
                        event_id,
                        tmin,
                        tmax,
                        picks=picks,
                        baseline=(None, 0),
                        decim=10,
                        reject=None,
                        preload=False)[:10]

    # Fitting epochs that are not preloaded is expected to raise.
    ar = LocalAutoReject()
    assert_raises(ValueError, ar.fit, epochs)
    epochs.load_data()

    ar.fit(epochs)
    # One channel of ``picks`` is not used -- presumably the EOG channel.
    assert_true(len(ar.picks_) == len(picks) - 1)

    # epochs with no picks.
    epochs = mne.Epochs(raw,
                        events,
                        event_id,
                        tmin,
                        tmax,
                        baseline=(None, 0),
                        decim=10,
                        reject=None,
                        preload=True)[:20]
    # let's drop some channels to speed up: keep a sparse subset of each
    # MEG/EEG channel type plus all EOG channels.
    pre_picks = np.r_[
        mne.pick_types(epochs.info, meg='mag', eeg=False)[::15],
        mne.pick_types(epochs.info, meg='grad', eeg=False)[::60],
        mne.pick_types(epochs.info, meg=False, eeg=True)[::16],
        mne.pick_types(epochs.info, meg=False, eeg=False, eog=True)]
    pick_ch_names = [epochs.ch_names[pp] for pp in pre_picks]
    epochs.pick_channels(pick_ch_names)
    epochs_fit = epochs[:12]  # make sure to use different size of epochs
    epochs_new = epochs[12:]

    X = epochs_fit.get_data()
    n_epochs, n_channels, n_times = X.shape
    X = X.reshape(n_epochs, -1)

    # GlobalAutoReject on flattened data raises unless both n_channels and
    # n_times are given.
    ar = GlobalAutoReject()
    assert_raises(ValueError, ar.fit, X)
    ar = GlobalAutoReject(n_channels=n_channels)
    assert_raises(ValueError, ar.fit, X)
    ar = GlobalAutoReject(n_times=n_times)
    assert_raises(ValueError, ar.fit, X)
    ar_global = GlobalAutoReject(n_channels=n_channels,
                                 n_times=n_times,
                                 thresh=40e-6)
    ar_global.fit(X)

    param_name = 'thresh'
    param_range = np.linspace(40e-6, 200e-6, 10)
    assert_raises(ValueError, validation_curve, ar_global, X, None, param_name,
                  param_range)

    ##########################################################################
    # picking AutoReject

    picks = mne.pick_types(epochs.info,
                           meg='mag',
                           eeg=True,
                           stim=False,
                           eog=False,
                           include=[],
                           exclude=[])
    non_picks = mne.pick_types(epochs.info,
                               meg='grad',
                               eeg=False,
                               stim=False,
                               eog=False,
                               include=[],
                               exclude=[])
    ch_types = ['mag', 'eeg']

    ar = LocalAutoReject(picks=picks)  # XXX : why do we need this??
    assert_raises(NotImplementedError, validation_curve, ar, epochs, None,
                  param_name, param_range)

    thresh_func = partial(compute_thresholds,
                          method='bayesian_optimization',
                          random_state=42)
    ar = LocalAutoRejectCV(cv=3,
                           picks=picks,
                           thresh_func=thresh_func,
                           n_interpolate=[1, 2],
                           consensus=[0.5, 1])
    # Fitting a plain array and transforming before fit are expected to raise.
    assert_raises(AttributeError, ar.fit, X)
    assert_raises(ValueError, ar.transform, X)
    assert_raises(ValueError, ar.transform, epochs)

    ar.fit(epochs_fit)
    reject_log = ar.get_reject_log(epochs_fit)
    for ch_type in ch_types:
        # test that kappa & rho are selected
        assert_true(ar.n_interpolate_[ch_type] in ar.n_interpolate)
        assert_true(ar.consensus_[ch_type] in ar.consensus)

        # ... and that the per-type LocalAutoReject agrees with them
        assert_true(ar.n_interpolate_[ch_type] ==
                    ar.local_reject_[ch_type].n_interpolate_[ch_type])
        assert_true(ar.consensus_[ch_type] ==
                    ar.local_reject_[ch_type].consensus_[ch_type])

    # the reject log has one entry per fitted epoch
    assert_equal(len(reject_log.bad_epochs), len(epochs_fit))

    # test that transform does not change state of ar
    epochs_clean = ar.transform(epochs_fit)  # apply same data
    reject_log2 = ar.get_reject_log(epochs_fit)
    assert_array_equal(reject_log.labels, reject_log2.labels)
    assert_array_equal(reject_log.bad_epochs, reject_log2.bad_epochs)
    assert_array_equal(reject_log.ch_names, reject_log2.ch_names)

    epochs_new_clean = ar.transform(epochs_new)  # apply to new data

    reject_log_new = ar.get_reject_log(epochs_new)
    assert_equal(len(reject_log_new.bad_epochs), len(epochs_new))

    # the fit and new sets differ in size, so their logs must too
    assert_true(len(reject_log_new.bad_epochs) != len(reject_log.bad_epochs))

    picks_by_type = _get_picks_by_type(epochs.info, ar.picks)
    # test correct entries in the reject log: some labels of channels outside
    # ``picks`` are NaN, labels of picked channels never are
    assert_true(np.isnan(reject_log_new.labels[:, non_picks]).sum() > 0)
    assert_true(np.isnan(reject_log_new.labels[:, picks]).sum() == 0)
    assert_equal(reject_log_new.labels.shape,
                 (len(epochs_new), len(epochs_new.ch_names)))

    # test correct interpolations by type (labels equal to 2 are counted as
    # interpolations); labels and channel names come from the same log
    for ch_type, this_picks in picks_by_type:
        interp_counts = np.sum(reject_log_new.labels[:, this_picks] == 2,
                               axis=1)
        labels = reject_log_new.labels.copy()
        not_this_picks = np.setdiff1d(np.arange(labels.shape[1]), this_picks)
        labels[:, not_this_picks] = np.nan
        interp_channels = _get_interp_chs(labels, reject_log_new.ch_names,
                                          this_picks)
        assert_array_equal(interp_counts, [len(cc) for cc in interp_channels])

    # Cleaning is expected to modify the new data. ``np.array_equal`` also
    # returns False on a shape mismatch (e.g. dropped epochs) rather than
    # raising or broadcasting.
    assert_true(not np.array_equal(epochs_new_clean.get_data(),
                                   epochs_new.get_data()))
    assert_equal(epochs_clean.ch_names, epochs_fit.ch_names)

    assert_true(isinstance(ar.threshes_, dict))
    assert_true(len(ar.picks) == len(picks))
    assert_true(len(ar.threshes_.keys()) == len(ar.picks))
    # No threshold should be learned for any (unpicked) EOG channel.
    pick_eog = mne.pick_types(epochs.info, meg=False, eeg=False, eog=True)
    for pp in pick_eog:
        assert_true(epochs.ch_names[pp] not in ar.threshes_)
    assert_raises(
        IndexError, ar.transform,
        epochs.copy().pick_channels([epochs.ch_names[pp] for pp in picks[:3]]))

    epochs.load_data()
    assert_raises(ValueError, compute_thresholds, epochs, 'dfdfdf')
    # One threshold is expected per picked channel.
    ch_names = [epochs_fit.ch_names[pp] for pp in picks]
    threshes_a = compute_thresholds(epochs_fit,
                                    picks=picks,
                                    method='random_search')
    assert_equal(set(threshes_a.keys()), set(ch_names))
    threshes_b = compute_thresholds(epochs_fit,
                                    picks=picks,
                                    method='bayesian_optimization')
    assert_equal(set(threshes_b.keys()), set(ch_names))
コード例 #11
0
ファイル: plot_auto_repair.py プロジェクト: cdla/autoreject
                      random_state=42)

###############################################################################
# :class:`autoreject.LocalAutoRejectCV` internally does cross-validation to
# determine the optimal values :math:`\rho^{*}` and :math:`\kappa^{*}`

###############################################################################
# Note that :class:`autoreject.LocalAutoRejectCV` by design supports
# multiple channel types.
# If no picks are passed, separate solutions will be computed for each channel
# type and internally combined. This then readily supports cleaning
# unseen epochs from the different channel types used during fit.
# Here we only use a subset of channels to save time.

ar = LocalAutoRejectCV(n_interpolates,
                       consensus_percs,
                       picks=picks,
                       thresh_func=thresh_func)

# Note that fitting and transforming can be done on different compatible
# portions of data if needed.
ar.fit(epochs['Auditory/Left'])
epochs_clean = ar.transform(epochs['Auditory/Left'])
# Average before and after cleaning to compare the evoked responses.
evoked_clean = epochs_clean.average()
evoked = epochs['Auditory/Left'].average()

###############################################################################
# Now, we will manually mark the bad channels just for plotting.

evoked.info['bads'] = ['MEG 2443']
evoked_clean.info['bads'] = ['MEG 2443']
コード例 #12
0
                    preload=True)

###############################################################################
# First, we set up the function to compute the sensor-level thresholds.
# The thresholds are searched randomly, with a fixed ``random_state``.

from functools import partial  # noqa
thresh_func = partial(compute_thresholds,
                      method='random_search',
                      random_state=42)

###############################################################################
# :class:`autoreject.LocalAutoRejectCV` internally does cross-validation to
# determine the optimal values :math:`\rho^{*}` and :math:`\kappa^{*}`

ar = LocalAutoRejectCV(n_interpolates,
                       consensus_percs,
                       thresh_func=thresh_func)
epochs_clean = ar.fit_transform(epochs)

# Evoked responses before and after cleaning, for comparison.
evoked = epochs.average()
evoked_clean = epochs_clean.average()

###############################################################################
# Now, we will manually mark the bad channels just for plotting.

evoked.info['bads'] = ['MEG 2443']
evoked_clean.info['bads'] = ['MEG 2443']

###############################################################################
# Let us plot the results.
コード例 #13
0
    # Same `dev_head_t` for all runs so that we can concatenate them.
    epoch.info['dev_head_t'] = epochs[0].info['dev_head_t']

# Combine the per-run epochs into a single Epochs object.
epochs = mne.epochs.concatenate_epochs(epochs)

###############################################################################
# Now, we apply autoreject

from autoreject import LocalAutoRejectCV, compute_thresholds  # noqa
from functools import partial  # noqa

# Only the 'famous' condition is cleaned here.
this_epoch = epochs['famous']
thresh_func = partial(compute_thresholds, random_state=42)

ar = LocalAutoRejectCV(thresh_func=thresh_func, verbose='tqdm')
# A copy is passed so that ``this_epoch`` itself is not handed to
# fit_transform; the uncleaned ``this_epoch`` is plotted below.
epochs_ar = ar.fit_transform(this_epoch.copy())

###############################################################################
# ... and visualize the bad epochs and sensors. Bad sensors which have been
# interpolated are in blue. Bad sensors which are not interpolated are in red.
# Bad trials are also in red.

from autoreject import plot_epochs  # noqa
plot_epochs(this_epoch, bad_epochs_idx=ar.bad_epochs_idx,
            fix_log=ar.fix_log, scalings=dict(eeg=40e-6),
            title='')

###############################################################################
# ... and the epochs after cleaning with autoreject
コード例 #14
0
# Threshold function with a fixed random_state (42) and a single job.
thresh_func = partial(compute_thresholds, random_state=42, n_jobs=1)

###############################################################################
# Note that :class:`autoreject.LocalAutoRejectCV` by design supports multiple
# channel types. If no picks are passed, separate solutions will be computed
# for each channel type and internally combined. This then readily supports
# cleaning unseen epochs from the different channel types used during fit.
# Here we only use a subset of channels to save time.

###############################################################################
# Also note that once the parameters are learned, any data can be repaired
# that contains channels that were used during fit. This also means that time
# may be saved by fitting :class:`autoreject.LocalAutoRejectCV` on a
# representative subsample of the data.

ar = LocalAutoRejectCV(thresh_func=thresh_func, verbose='tqdm', picks=picks)

# Fit and transform as two separate steps (here on the same data).
ar.fit(this_epoch)
epochs_ar = ar.transform(this_epoch)

###############################################################################
# We can visualize the cross validation curve over two variables

# ``# noqa`` silences the linter warning for imports below the top of the file.
import numpy as np  # noqa
import matplotlib.pyplot as plt  # noqa
import matplotlib.patches as patches  # noqa
from autoreject import set_matplotlib_defaults  # noqa

set_matplotlib_defaults(plt, style='seaborn-white')
# NOTE(review): averages over the last axis, presumably the CV folds -- confirm.
loss = ar.loss_['eeg'].mean(axis=-1)  # losses are stored by channel type.