Exemple #1
0
def autoreject_marmouset(subject):
    """Run AutoReject on one subject's item epochs and save the cleaned data.

    The epochs object is rebuilt from three files stored under
    ``<neural_data_path>/<subject>/``: ``epoch_items_data.npy``,
    ``epoch_items_info.npy`` and ``epoch_items_metadata.pkl``. After cleaning,
    the epochs (``.fif``), the reject log (pickled ``.obj``) and the cleaned
    data / metadata / info arrays are written next to the inputs.

    Parameters
    ----------
    subject : str
        Name of the subject folder inside the neural-data directory.
        (The previous implementation overwrote this argument with ``'Nr'``;
        the argument is now honoured.)
    """
    root_path = '/neurospin/unicog/protocols/ABSeq_marmousets/'
    neural_data_path = root_path + 'neural_data/'

    epoch_name = '/epoch_items'
    tmin = -0.099
    # Common prefix of every input and output file for this subject.
    file_prefix = neural_data_path + subject + epoch_name

    # ======== rebuild the epoch object and run autoreject ========
    epoch_data = np.load(file_prefix + '_data.npy')
    # The info object was stored in a 0-d object array, hence ``.item()``.
    info = np.load(file_prefix + '_info.npy', allow_pickle=True).item()
    metadata = np.load(file_prefix + '_metadata.pkl', allow_pickle=True)
    epochs = mne.EpochsArray(epoch_data, info=info, tmin=tmin)
    epochs.metadata = metadata
    epochs.load_data()

    # ======== ======== ======== ======== ======== ======== ========
    ar = AutoReject()
    epochs, reject_log = ar.fit_transform(epochs, return_log=True)
    epochs_clean_fname = file_prefix + '_clean.fif'
    print("Output: ", epochs_clean_fname)
    epochs.save(epochs_clean_fname, overwrite=True)

    # Save autoreject reject_log; the context manager closes the file handle.
    with open(epochs_clean_fname[:-4] + '_reject_log.obj', 'wb') as fid:
        pickle.dump(reject_log, fid)

    np.save(file_prefix + '_data_clean.npy', epochs.get_data())
    epochs.metadata.to_pickle(file_prefix + '_metadata_clean.pkl')
    np.save(file_prefix + '_info_clean.npy', epochs.info)
def autoreject_repair_epochs(epochs, reject_plot=False):
    """Clean epochs with the AutoReject algorithm.

    Parameters
    ----------
    epochs : mne epoch object
        Epoched, filtered eeg data.
    reject_plot : bool
        When true, plot the reject log on top of the input epochs.

    Returns
    -------
    cleaned : mne epoch object
        Result of ``AutoReject.fit_transform`` applied to ``epochs``.

    """
    # NOTE(review): `thresh_func` receives a method name here; confirm this
    # keyword is the one expected by the installed autoreject version.
    ar_settings = dict(
        n_interpolate=[1, 2, 3],
        n_jobs=6,
        picks=mne.pick_types(epochs.info, eeg=True),  # EEG channels only
        thresh_func='bayesian_optimization',
        cv=3,
        random_state=42,
        verbose=False,
    )
    cleaner = AutoReject(**ar_settings)

    cleaned, log = cleaner.fit_transform(epochs, return_log=True)
    if reject_plot:
        log.plot_epochs(epochs, scalings=dict(eeg=40e-6))
    return cleaned
Exemple #3
0
def segment_files(bids_filepath, tmin=0, tmax=0.8):
    """Epoch a raw FIF recording around 'Freq'/'Rare' events and run AutoReject.

    Parameters
    ----------
    bids_filepath : path-like
        Raw FIF file, read with ``preload=True``.
    tmin, tmax : float
        Epoch window passed to ``mne.Epochs``.

    Returns
    -------
    epochs_clean, autoreject_log
        The two values returned by ``AutoReject.fit_transform`` with
        ``return_log=True``.
    """
    raw = read_raw_fif(bids_filepath, preload=True)
    meg_picks = mne.pick_types(raw.info, meg=True, ref_meg=True, eeg=False,
                               eog=False, stim=False)

    # Events: first try a one-sample minimum duration, and fall back to two
    # samples when find_events raises ValueError.
    try:
        one_sample = 1 / raw.info['sfreq']
        events = mne.find_events(raw, min_duration=one_sample, verbose=False)
    except ValueError:
        two_samples = 2 / raw.info['sfreq']
        events = mne.find_events(raw, min_duration=two_samples, verbose=False)

    # No baseline correction and no amplitude-based rejection at epoching
    # time; AutoReject handles the cleaning afterwards.
    epoch_kwargs = dict(events=events,
                        event_id={'Freq': 21, 'Rare': 31},
                        tmin=tmin,
                        tmax=tmax,
                        baseline=None,
                        reject=None,
                        picks=meg_picks,
                        preload=True)
    epochs = mne.Epochs(raw, **epoch_kwargs)

    cleaner = AutoReject(n_jobs=6)
    epochs_clean, autoreject_log = cleaner.fit_transform(epochs,
                                                         return_log=True)
    return epochs_clean, autoreject_log
Exemple #4
0
def segment_files(bids_filepath, tmin=0, tmax=0.8):
    """Epoch a raw FIF recording around task events and clean with AutoReject.

    Parameters
    ----------
    bids_filepath : path-like
        Raw FIF file, read with ``preload=True``.
    tmin, tmax : float
        Epoch window passed to ``mne.Epochs``. The defaults (0 and 0.8) are
        the values that used to be hard-coded in the body.

    Returns
    -------
    epochs_clean
        Value returned by ``AutoReject.fit_transform`` on the epochs.
    """
    raw = read_raw_fif(bids_filepath, preload=True)
    picks = mne.pick_types(raw.info,
                           meg=True,
                           ref_meg=False,
                           eeg=False,
                           eog=False,
                           stim=False)

    # Events must last at least two samples to be detected.
    events = mne.find_events(raw, min_duration=2 / raw.info['sfreq'])
    event_id = {'Freq': 21, 'Rare': 31, 'Resp': 99}

    # No baseline correction and no fixed amplitude threshold: the unused
    # ``reject`` dict was removed, cleaning is delegated to AutoReject.
    epochs = mne.Epochs(raw,
                        events=events,
                        event_id=event_id,
                        tmin=tmin,
                        tmax=tmax,
                        baseline=None,
                        reject=None,
                        picks=picks,
                        preload=True)
    ar = AutoReject()
    epochs_clean = ar.fit_transform(epochs)
    return epochs_clean
def test_io():
    """Check that AutoReject objects survive a save / read round trip."""
    tmin, tmax = -0.2, 0.5
    event_id = None
    events = mne.find_events(raw)
    # Keep a reference to the temporary directory for the whole test.
    savedir = _TempDir()
    fname = op.join(savedir, 'autoreject.hdf5')

    include = [u'EEG %03d' % i for i in range(1, 45, 3)]
    picks = mne.pick_types(raw.info, meg=False, eeg=False, stim=False,
                           eog=True, include=include, exclude=[])

    # Preloaded epochs, restricted to the first ten
    all_epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                            picks=picks, baseline=(None, 0), decim=4,
                            reject=None, preload=True)
    epochs = all_epochs[:10]

    ar = AutoReject(cv=2, random_state=42, n_interpolate=[1],
                    consensus=[0.5], verbose=False)
    # An unfitted object can be written to disk
    ar.save(fname)

    # Fitting the reloaded (unfitted) object must give the same thresholds
    # as fitting the original one
    ar2 = read_auto_reject(fname)
    ar.fit(epochs)
    ar2.fit(epochs)
    thresh_diffs = [ar.threshes_[ch] - ar2.threshes_[ch]
                    for ch in ar.threshes_.keys()]
    assert np.sum(thresh_diffs) == 0.

    # Saving onto an existing file requires overwrite=True
    pytest.raises(ValueError, ar.save, fname)
    ar.save(fname, overwrite=True)

    # A fitted object reloaded from disk must transform identically
    ar3 = read_auto_reject(fname)
    clean_a, log_a = ar.transform(epochs, return_log=True)
    clean_b, log_b = ar3.transform(epochs, return_log=True)
    assert_array_equal(clean_a.get_data(), clean_b.get_data())
    assert_array_equal(log_a.labels, log_b.labels)
def run_autoreject(subject, epoch_on_first_element):
    """Run AutoReject on a subject's uncleaned epochs and save the results.

    Parameters
    ----------
    subject : str
        Subject identifier; also the name of the sub-folder of
        ``config.meg_dir`` where the output is written.
    epoch_on_first_element : bool
        If true, the 'full sequences' epochs are loaded and the output name
        uses the ``_1st_element_clean_epo`` extension; otherwise the 'items'
        epochs are loaded and the extension is ``_clean_epo``.

    Notes
    -----
    The cleaned epochs are saved without ``overwrite``. The reject log is
    pickled next to the epochs file with a ``_reject_log.obj`` suffix.
    """
    N_JOBS_ar = 1  # "The number of thresholds to compute in parallel."

    print(
        '#########################################################################################'
    )
    print(
        '########################## Processing subject: %s ##########################'
        % subject)
    print(
        '#########################################################################################'
    )

    if epoch_on_first_element:
        print("  Loading 'full sequences' epochs")
        epochs = epoching_funcs.load_epochs_full_sequence(subject,
                                                          cleaned=False)
    else:
        print("  Loading 'items' epochs")
        epochs = epoching_funcs.load_epochs_items(subject, cleaned=False)

    # Running AutoReject (https://autoreject.github.io)
    epochs.load_data()
    ar = AutoReject(n_jobs=N_JOBS_ar)
    epochs, reject_log = ar.fit_transform(epochs, return_log=True)

    # Save epochs (after AutoReject)
    print('  Writing cleaned epochs to disk')
    meg_subject_dir = op.join(config.meg_dir, subject)
    # `subject` and `extension` must keep these exact names: the file-name
    # template is filled from locals().
    if epoch_on_first_element:
        extension = subject + '_1st_element_clean_epo'
    else:
        extension = subject + '_clean_epo'
    epochs_fname = op.join(meg_subject_dir,
                           config.base_fname.format(**locals()))
    print("Output: ", epochs_fname)
    epochs.save(epochs_fname)  # , overwrite=True)

    # Save autoreject reject_log; the context manager closes the file handle
    # that the previous implementation leaked.
    with open(epochs_fname[:-4] + '_reject_log.obj', 'wb') as fid:
        pickle.dump(reject_log, fid)
Exemple #7
0
def runautoreject(epochs,
                  fiffile,
                  senstype,
                  bads=[],
                  n_interpolates=np.array([1, 4, 32]),
                  consensus_percs=np.linspace(0, 1, 11)):
    """Fit AutoReject on ``epochs`` and return only the reject log.

    Parameters
    ----------
    epochs : epochs object
        Epochs to analyse; its ``info`` and several attributes are
        overwritten in place before fitting.
    fiffile : path-like
        Raw FIF file whose (MEG-only) measurement info is copied onto
        ``epochs``.
    senstype : bool | str
        Passed as the ``meg`` argument of ``mne.pick_types``.
    bads : list
        Channels passed as ``exclude`` to ``mne.pick_types``.
    n_interpolates, consensus_percs : array-like
        First two positional arguments of ``AutoReject``.

    Returns
    -------
    reject_log
        Second value returned by ``AutoReject.fit_transform``.
    """
    # NOTE(review): the return value of this call is discarded.
    check_random_state(42)

    # Borrow the channel information from the raw file: bads and projectors
    # are cleared and only MEG channels are kept.
    template = mne.io.read_raw_fif(fiffile, preload=True)
    template.info['bads'] = list()
    template.pick_types(meg=True)
    template.info['projs'] = list()
    epochs.info = template.info  # required since no channel infos

    del template

    selection = mne.pick_types(epochs.info,
                               meg=senstype,
                               eeg=False,
                               stim=False,
                               eog=False,
                               include=[],
                               exclude=bads)

    # Attributes are forced directly on the epochs object.
    for attribute, value in (('verbose', False),
                             ('baseline', (None, 0)),
                             ('preload', True),
                             ('detrend', 0)):
        setattr(epochs, attribute, value)

    rejector = AutoReject(n_interpolates,
                          consensus_percs,
                          picks=selection,
                          thresh_method='bayesian_optimization',
                          random_state=42,
                          verbose=False)

    # The cleaned epochs are not returned to the caller, only the log.
    _cleaned, reject_log = rejector.fit_transform(epochs, return_log=True)
    return reject_log
def run_autoreject(epochs, show_figs=False, results_dir=None):
    """Clean ``epochs`` with a default ``AutoReject`` instance.

    Parameters
    ----------
    epochs : epochs object
        Data handed to ``AutoReject.fit_transform``.
    show_figs : bool
        Currently has no effect.
    results_dir : path-like | None
        Currently has no effect.

    Returns
    -------
    cleaned
        Value returned by ``AutoReject.fit_transform``.
    """
    from autoreject import AutoReject

    cleaned = AutoReject().fit_transform(epochs)

    # Plotting the reject log (when show_figs / results_dir is set) and
    # saving it as '4a_bad_epochs.png' in results_dir is not implemented:
    # both options are accepted but ignored.
    return cleaned
Exemple #9
0
def autore(epo_eeg_cust):
    """
       Artifact correction/rejection with AutoReject (4 parallel jobs).
       ----------
       epo_eeg_cust: MNE.Epochs
            Epochs data

       Returns
       -------
       clean: MNE.Epochs
           Copy of the epochs returned by ``AutoReject.transform``

       """
    cleaner = AutoReject(n_jobs=4)
    cleaner.fit(epo_eeg_cust)
    # The reject log is requested but not used further here.
    repaired, _reject_log = cleaner.transform(epo_eeg_cust, return_log=True)
    return repaired.copy()
Exemple #10
0
def reject_epochs(epochs, autoreject_parameters):
    """Clean epochs with AutoReject and save a two-panel diagnostic figure.

    Parameters
    ----------
    epochs : epochs object
        Data handed to ``AutoReject.fit_transform``.
    autoreject_parameters : dict
        Keyword arguments forwarded to ``AutoReject``; ``verbose="tqdm"`` is
        added on top of them.

    Returns
    -------
    epochs
        Value returned by ``AutoReject.fit_transform``.
    """
    ar = AutoReject(**autoreject_parameters, verbose="tqdm")
    epochs = ar.fit_transform(epochs)

    fig, (ax_thresh, ax_loss) = plt.subplots(2)

    # Top panel: histogram of the rejection thresholds (scaled by 1e6)
    thresholds = np.array(list(ar.threshes_.values()))
    ax_thresh.set_title("Rejection Thresholds")
    ax_thresh.hist(1e6 * thresholds, 30, color='g', alpha=0.4)
    ax_thresh.set(xlabel='Threshold (μV)', ylabel='Number of sensors')

    # Bottom panel: cross validation error, averaged over the last axis
    # (losses are stored by channel type).
    loss = ar.loss_['eeg'].mean(axis=-1)
    im = ax_loss.matshow(loss.T * 1e6, cmap=plt.get_cmap('viridis'))
    ax_loss.set_xticks(range(len(ar.consensus)))
    ax_loss.set_xticklabels(['%.1f' % c for c in ar.consensus])
    ax_loss.set_yticks(range(len(ar.n_interpolate)))
    ax_loss.set_yticklabels(ar.n_interpolate)

    # Red frame around the cell holding the smallest loss
    best_row, best_col = np.unravel_index(loss.argmin(), loss.shape)
    best_marker = patches.Rectangle((best_row - 0.5, best_col - 0.5),
                                    1,
                                    1,
                                    linewidth=2,
                                    edgecolor='r',
                                    facecolor='none')
    ax_loss.add_patch(best_marker)
    ax_loss.xaxis.set_ticks_position('bottom')
    ax_loss.set(xlabel=r'Consensus percentage $\kappa$',
                ylabel=r'Max sensors interpolated $\rho$',
                title='Mean cross validation error (x 1e6)')

    fig.colorbar(im)
    fig.tight_layout()
    fig.savefig(_out_folder / Path("reject_epochs.pdf"), dpi=800)
    plt.close()
    return epochs
Exemple #11
0
def test_fnirs():
    """Check that autoreject can be applied to fNIRS recordings."""
    participant_dir = os.path.join(mne.datasets.fnirs_motor.data_path(),
                                   'Participant-1')
    raw = mne.io.read_raw_nirx(participant_dir)
    raw.crop(tmax=1200)
    # Raw intensities -> optical density -> haemoglobin concentrations
    raw = mne.preprocessing.nirs.optical_density(raw)
    raw = mne.preprocessing.nirs.beer_lambert_law(raw)

    annotation_codes = {'1.0': 1, '2.0': 2, '3.0': 3}
    events, _ = mne.events_from_annotations(raw, event_id=annotation_codes)
    event_dict = {'Control': 1, 'Tapping/Left': 2, 'Tapping/Right': 3}
    epochs = mne.Epochs(raw, events, event_id=event_dict,
                        tmin=-5, tmax=15,
                        proj=True, baseline=(None, 0), preload=True,
                        detrend=None, verbose=True)

    # AutoReject must drop at least one of the 37 epochs
    ar = AutoReject()
    assert len(epochs) == 37
    epochs_clean = ar.fit_transform(epochs)
    assert len(epochs_clean) < len(epochs)

    # Global thresholds must exist for both channel types
    reject = get_rejection_threshold(epochs)
    print(reject)
    for ch_type in ("hbo", "hbr"):
        assert ch_type in reject.keys()
    # 0.001 is a very high value, used only as a sanity check
    assert reject["hbo"] < 0.001
    assert reject["hbr"] < 0.001
    assert reject["hbr"] > 0.0
Exemple #12
0
def run_epochs(subject,
               epoch_on_first_element,
               baseline=True,
               tmin=None,
               tmax=None,
               whattoreturn=None):
    """Resample, concatenate and epoch all runs of a subject.

    Parameters
    ----------
    subject : str
        Subject identifier (key of ``config.runs_dict`` and folder name).
    epoch_on_first_element : bool
        If true, epoch only on events relabelled as code 1 (every 16th
        event), with default window -0.2 s to 0.25 * 17 s. Otherwise epoch
        on every event, with default window -0.05 s to 0.6 s.
    baseline : bool | None
        ``None`` or ``False`` disables baseline correction; any other value
        applies the baseline ``(tmin, 0)``. This is now honoured in both
        epoching modes (it used to be ignored when
        ``epoch_on_first_element`` was true).
    tmin, tmax : float | None
        Epoch window; ``None`` selects the mode-specific default.
    whattoreturn : None | str
        ``None``: write every stage to disk and return nothing.
        ``''``: return the uncleaned epochs (data loaded).
        ``'ARglobal'``: return the epochs cleaned with the global thresholds
        (only reached when ``config.autoreject`` is set).
        Any other value: return the "AR local" epochs, but only if the
        hard-coded ``ARlocal`` switch below is enabled.

    Raises
    ------
    Exception
        If a run does not contain exactly 46 * 16 events.
    """
    # Set this param to True if you want to run autoreject locally too when
    # config.autoreject = True
    # NOTE(review): `now` and `events_list` are not used below; they are left
    # in place because every local is forwarded to
    # config.base_fname.format(**locals()).
    from datetime import datetime
    now = datetime.now().time()

    ARlocal = False

    print("Processing subject: %s" % subject)
    meg_subject_dir = op.join(config.meg_dir, subject)
    run_info_subject_dir = op.join(config.run_info_dir, subject)
    raw_list = list()
    events_list = list()

    if config.noEEG:
        output_dir = op.join(meg_subject_dir, 'noEEG')
        utils.create_folder(output_dir)
    else:
        output_dir = meg_subject_dir

    print("  Loading raw data")
    runs = config.runs_dict[subject]
    n_expected_events = 46 * 16
    for run in runs:
        # `run` and `extension` are consumed by the file-name template.
        extension = run + '_ica_raw'
        print(extension)
        raw_fname_in = op.join(meg_subject_dir,
                               config.base_fname.format(**locals()))
        raw = mne.io.read_raw_fif(raw_fname_in, preload=True)

        # ------------------------------------------------------------------ #
        # RESAMPLING EACH RUN BEFORE CONCAT & EPOCHING
        # Resampling the raw data while keeping events from original raw
        # data, to avoid potential loss of events when downsampling:
        # https://www.nmr.mgh.harvard.edu/mne/dev/auto_examples/preprocessing/plot_resample.html
        # Find events
        events = mne.find_events(raw,
                                 stim_channel=config.stim_channel,
                                 consecutive=True,
                                 min_duration=config.min_event_duration,
                                 shortest_event=config.shortest_event)

        print('  Downsampling raw data')
        raw, events = raw.resample(config.resample_sfreq,
                                   npad='auto',
                                   events=events)

        # Only a warning: epochs starting too close to the end of the run.
        times_between_events_and_end = (raw.last_samp -
                                        events[:, 0]) / raw.info['sfreq']
        if np.sum(times_between_events_and_end < 0.6) > 0:
            print("=== some events are too close to the end ====")

        if len(events) != n_expected_events:
            raise Exception('We expected %i events but we got %i' %
                            (n_expected_events, len(events)))

        raw_list.append(raw)
        # ------------------------------------------------------------------ #

    if subject == 'sub08-cc_150418':
        # For this participant, we had some problems when concatenating the
        # raws for run08. The error message said that raw08._cals didn't
        # match the other ones. We saw that it is the 'calibration' for the
        # channel EOG061 that was different with respect to run09._cals.
        raw_list[7]._cals = raw_list[8]._cals
        print(
            'Warning: corrected an issue with subject08 run08 ica_raw data file...'
        )

    print('Concatenating runs')
    raw = mne.concatenate_raws(raw_list)
    if "eeg" in config.ch_types:
        raw.set_eeg_reference(projection=True)
    del raw_list

    # EEG channels are picked only when config.noEEG is not set.
    picks = mne.pick_types(raw.info,
                           meg=True,
                           eeg=not config.noEEG,
                           stim=True,
                           eog=True,
                           exclude=())

    # Construct metadata from csv events file
    metadata = convert_csv_info_to_metadata(run_info_subject_dir)
    metadata_pandas = pd.DataFrame.from_dict(metadata, orient='index')
    metadata_pandas = pd.DataFrame.transpose(metadata_pandas)

    # ====== Epoching the data
    print('  Epoching')

    # Events (recomputed on the concatenated, resampled recording)
    events = mne.find_events(raw,
                             stim_channel=config.stim_channel,
                             consecutive=True,
                             min_duration=config.min_event_duration,
                             shortest_event=config.shortest_event)

    if epoch_on_first_element:
        if tmin is None:
            tmin = -0.200
        if tmax is None:
            tmax = 0.25 * 17
        # Bug fix: the baseline used to be forced to (tmin, 0) before the
        # None/False check, which made `baseline=None/False` a no-op here.
        if (baseline is None) or (baseline is False):
            baseline = None
        else:
            baseline = (tmin, 0)
        # Relabel event codes as 1..16 (cycling), so that code 1 selects
        # every 16th event.
        events[:, 2] = np.arange(len(events)) % 16 + 1
        epochs = mne.Epochs(raw,
                            events, {'sequence_starts': 1},
                            tmin,
                            tmax,
                            proj=True,
                            picks=picks,
                            baseline=baseline,
                            preload=False,
                            decim=config.decim,
                            reject=None)
        epochs.metadata = metadata_pandas[metadata_pandas['StimPosition'] ==
                                          1.0]
    else:
        if tmin is None:
            tmin = -0.050
        if tmax is None:
            tmax = 0.600
        if (baseline is None) or (baseline is False):
            baseline = None
        else:
            baseline = (tmin, 0)

        epochs = mne.Epochs(raw,
                            events,
                            None,
                            tmin,
                            tmax,
                            proj=True,
                            picks=picks,
                            baseline=baseline,
                            preload=False,
                            decim=config.decim,
                            reject=None)

        # Add metadata to epochs
        epochs.metadata = metadata_pandas

    # Save epochs (before AutoReject)
    if whattoreturn is None:
        print('  Writing epochs to disk')
        if epoch_on_first_element:
            extension = subject + '_1st_element_epo'
        else:
            extension = subject + '_epo'
        epochs_fname = op.join(output_dir,
                               config.base_fname.format(**locals()))
        print("Output: ", epochs_fname)
        epochs.save(epochs_fname, overwrite=True)
    elif whattoreturn == '':
        epochs.load_data()
        return epochs
    else:
        print("=== we continue on the autoreject part ===")

    if config.autoreject:
        epochs.load_data()
        # Running AutoReject "global" (https://autoreject.github.io)
        # -> just get the thresholds
        from autoreject import get_rejection_threshold
        reject = get_rejection_threshold(epochs, ch_types=config.ch_types)
        epochsARglob = epochs.copy().drop_bad(reject=reject)
        print('  Writing "AR global" cleaned epochs to disk')
        if epoch_on_first_element:
            extension = subject + '_1st_element_ARglob_epo'
        else:
            extension = subject + '_ARglob_epo'
        epochs_fname = op.join(output_dir,
                               config.base_fname.format(**locals()))
        if whattoreturn is None:
            print("Output: ", epochs_fname)
            epochsARglob.save(epochs_fname, overwrite=True)
            # Save autoreject thresholds (file handle closed on exit)
            thresholds_fname = epochs_fname[:-4] + '_ARglob_thresholds.obj'
            with open(thresholds_fname, 'wb') as fid:
                pickle.dump(reject, fid)
        elif whattoreturn == 'ARglobal':
            return epochsARglob
        else:
            print("==== continue to ARlocal ====")

        # Running AutoReject "local" (https://autoreject.github.io)
        if ARlocal:
            ar = AutoReject()
            epochsAR, reject_log = ar.fit_transform(epochs, return_log=True)
            print('  Writing "AR local" cleaned epochs to disk')
            if epoch_on_first_element:
                extension = subject + '_1st_element_clean_epo'
            else:
                extension = subject + '_clean_epo'
            epochs_fname = op.join(output_dir,
                                   config.base_fname.format(**locals()))
            if whattoreturn is None:
                print("Output: ", epochs_fname)
                epochsAR.save(epochs_fname, overwrite=True)
                # Save autoreject reject_log (file handle closed on exit)
                log_fname = epochs_fname[:-4] + '_reject_local_log.obj'
                with open(log_fname, 'wb') as fid:
                    pickle.dump(reject_log, fid)
            else:
                return epochsAR
Exemple #13
0
# Overview plot of all events; event ids missing from the data are ignored.
fig = mne.viz.plot_events(
    events,
    sfreq=raw.info['sfreq'],
    first_samp=raw.first_samp,
    event_id=event_id,
    on_missing='ignore',
)
fig.subplots_adjust(right=0.7)  # make room for legend

# Report how many times each event code occurs (codes are in the last
# column of the events array).
for (e, i) in event_id.items():
    a = (events[:, -1] == i).sum()
    print(f"event {e} is present {a} times")

# %% use autoreject local to clean the data from remaining artifacts
# When AUTOREJECT is false the epochs are passed through unchanged.
if AUTOREJECT:
    ar = AutoReject()
    epochs.load_data()
    epochs_clean = ar.fit_transform(epochs)
else:
    epochs_clean = epochs

# %%
# Rejection thresholds computed on the original (not the cleaned) epochs;
# they are only printed here, not applied.
# NOTE(review): the original author asked whether this is related to the
# Bonferroni correction -- left as an open question.
reject = get_rejection_threshold(epochs)
print(reject)

# %%
# Average and plot the cleaned epochs of the 'audiovis/1200Hz' condition.
evoked = epochs_clean['audiovis/1200Hz'].average()
evoked.plot()

# %%
Exemple #14
0
def test_autoreject():
    """Test basic _AutoReject functionality.

    Covers, in order: fitting ``_AutoReject`` (which must fail on
    non-preloaded epochs), argument validation of ``_GlobalAutoReject`` and
    ``validation_curve``, fitting/transforming with ``AutoReject`` restricted
    to picked channels, the contents of the reject log, the handling of
    channels marked as bad, and ``compute_thresholds``.
    Relies on a module-level ``raw`` object.
    """
    event_id = None
    tmin, tmax = -0.2, 0.5
    events = mne.find_events(raw)

    ##########################################################################
    # picking epochs
    include = [u'EEG %03d' % i for i in range(1, 45, 3)]
    picks = mne.pick_types(raw.info, meg=False, eeg=False, stim=False,
                           eog=True, include=include, exclude=[])
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                        picks=picks, baseline=(None, 0), decim=10,
                        reject=None, preload=False)[:10]

    # fitting must fail as long as the data are not preloaded
    ar = _AutoReject()
    assert_raises(ValueError, ar.fit, epochs)
    epochs.load_data()

    ar.fit(epochs)
    # one of the picked channels is expected to be left out of ar.picks_
    # (presumably the EOG channel -- TODO confirm)
    assert_true(len(ar.picks_) == len(picks) - 1)

    # epochs with no picks.
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                        baseline=(None, 0), decim=10,
                        reject=None, preload=True)[:20]
    # let's drop some channels to speed up
    # NOTE(review): this first assignment is overwritten by the next
    # statement and has no effect.
    pre_picks = mne.pick_types(epochs.info, meg=True, eeg=True)
    pre_picks = np.r_[
        mne.pick_types(epochs.info, meg='mag', eeg=False)[::15],
        mne.pick_types(epochs.info, meg='grad', eeg=False)[::60],
        mne.pick_types(epochs.info, meg=False, eeg=True)[::16],
        mne.pick_types(epochs.info, meg=False, eeg=False, eog=True)]
    pick_ch_names = [epochs.ch_names[pp] for pp in pre_picks]
    bad_ch_names = [epochs.ch_names[ix] for ix in range(len(epochs.ch_names))
                    if ix not in pre_picks]
    # Two variants of the same data: one keeps all channels but marks the
    # unselected ones as bad, the other physically drops them.
    epochs_with_bads = epochs.copy()
    epochs_with_bads.info['bads'] = bad_ch_names
    epochs.pick_channels(pick_ch_names)

    epochs_fit = epochs[:12]  # make sure to use different size of epochs
    epochs_new = epochs[12:]
    epochs_with_bads_fit = epochs_with_bads[:12]

    # flatten each epoch to one row for _GlobalAutoReject
    X = epochs_fit.get_data()
    n_epochs, n_channels, n_times = X.shape
    X = X.reshape(n_epochs, -1)

    # _GlobalAutoReject.fit must raise unless both n_channels and n_times
    # are given
    ar = _GlobalAutoReject()
    assert_raises(ValueError, ar.fit, X)
    ar = _GlobalAutoReject(n_channels=n_channels)
    assert_raises(ValueError, ar.fit, X)
    ar = _GlobalAutoReject(n_times=n_times)
    assert_raises(ValueError, ar.fit, X)
    ar_global = _GlobalAutoReject(
        n_channels=n_channels, n_times=n_times, thresh=40e-6)
    ar_global.fit(X)

    # validation_curve must reject a None second argument
    param_name = 'thresh'
    param_range = np.linspace(40e-6, 200e-6, 10)
    assert_raises(ValueError, validation_curve, X, None,
                  param_name, param_range)

    ##########################################################################
    # picking AutoReject

    picks = mne.pick_types(
        epochs.info, meg='mag', eeg=True, stim=False, eog=False,
        include=[], exclude=[])
    non_picks = mne.pick_types(
        epochs.info, meg='grad', eeg=False, stim=False, eog=False,
        include=[], exclude=[])
    ch_types = ['mag', 'eeg']

    ar = _AutoReject(picks=picks)  # XXX : why do we need this??

    ar = AutoReject(cv=3, picks=picks, random_state=42,
                    n_interpolate=[1, 2], consensus=[0.5, 1])
    # invalid usage: fitting/transforming an array, or transforming before
    # the object has been fitted
    assert_raises(AttributeError, ar.fit, X)
    assert_raises(ValueError, ar.transform, X)
    assert_raises(ValueError, ar.transform, epochs)

    ar.fit(epochs_fit)
    reject_log = ar.get_reject_log(epochs_fit)
    for ch_type in ch_types:
        # test that kappa & rho are selected
        assert_true(
            ar.n_interpolate_[ch_type] in ar.n_interpolate)
        assert_true(
            ar.consensus_[ch_type] in ar.consensus)

        # the selected values must agree with those of the per-type
        # local_reject_ objects
        assert_true(
            ar.n_interpolate_[ch_type] ==
            ar.local_reject_[ch_type].n_interpolate_[ch_type])
        assert_true(
            ar.consensus_[ch_type] ==
            ar.local_reject_[ch_type].consensus_[ch_type])

    # test complementarity of goods and bads
    assert_array_equal(len(reject_log.bad_epochs), len(epochs_fit))

    # test that transform does not change state of ar
    epochs_clean = ar.transform(epochs_fit)  # apply same data
    assert_true(repr(ar))
    assert_true(repr(ar.local_reject_))
    reject_log2 = ar.get_reject_log(epochs_fit)
    assert_array_equal(reject_log.labels, reject_log2.labels)
    assert_array_equal(reject_log.bad_epochs, reject_log2.bad_epochs)
    assert_array_equal(reject_log.ch_names, reject_log2.ch_names)

    epochs_new_clean = ar.transform(epochs_new)  # apply to new data

    reject_log_new = ar.get_reject_log(epochs_new)
    assert_array_equal(len(reject_log_new.bad_epochs), len(epochs_new))

    # epochs_fit and epochs_new were built with different numbers of epochs
    assert_true(
        len(reject_log_new.bad_epochs) != len(reject_log.bad_epochs))

    picks_by_type = _get_picks_by_type(epochs.info, ar.picks)
    # test correct entries in fix log
    # (labels are NaN for channels outside `picks`, never NaN inside)
    assert_true(
        np.isnan(reject_log_new.labels[:, non_picks]).sum() > 0)
    assert_true(
        np.isnan(reject_log_new.labels[:, picks]).sum() == 0)
    assert_equal(reject_log_new.labels.shape,
                 (len(epochs_new), len(epochs_new.ch_names)))

    # test correct interpolations by type
    for ch_type, this_picks in picks_by_type:
        # number of channels labelled 2 per epoch, for this channel type
        interp_counts = np.sum(
            reject_log_new.labels[:, this_picks] == 2, axis=1)
        labels = reject_log_new.labels.copy()
        not_this_picks = np.setdiff1d(np.arange(labels.shape[1]), this_picks)
        labels[:, not_this_picks] = np.nan
        interp_channels = _get_interp_chs(
            labels, reject_log.ch_names, this_picks)
        assert_array_equal(
            interp_counts, [len(cc) for cc in interp_channels])

    # NOTE(review): when the comparison yields an array, is_same is replaced
    # by np.isscalar(<array>), i.e. False, so the assertion below passes
    # whatever the data contain -- confirm this is the intended check.
    is_same = epochs_new_clean.get_data() == epochs_new.get_data()
    if not np.isscalar(is_same):
        is_same = np.isscalar(is_same)
    assert_true(not is_same)

    # test that transform ignores bad channels
    epochs_with_bads_fit.pick_types(meg='mag', eeg=True, eog=True, exclude=[])
    ar_bads = AutoReject(cv=3, random_state=42,
                         n_interpolate=[1, 2], consensus=[0.5, 1])
    ar_bads.fit(epochs_with_bads_fit)
    epochs_with_bads_clean = ar_bads.transform(epochs_with_bads_fit)

    # good channels must come out identical whether the other channels were
    # marked as bad or dropped beforehand
    good_w_bads_ix = mne.pick_types(epochs_with_bads_clean.info,
                                    meg='mag', eeg=True, eog=True,
                                    exclude='bads')
    good_wo_bads_ix = mne.pick_types(epochs_clean.info,
                                     meg='mag', eeg=True, eog=True,
                                     exclude='bads')
    assert_array_equal(epochs_with_bads_clean.get_data()[:, good_w_bads_ix, :],
                       epochs_clean.get_data()[:, good_wo_bads_ix, :])

    # data of the bad channels must be untouched in the retained epochs
    bad_ix = [epochs_with_bads_clean.ch_names.index(ch)
              for ch in epochs_with_bads_clean.info['bads']]
    epo_ix = ~ar_bads.get_reject_log(epochs_with_bads_fit).bad_epochs
    assert_array_equal(
        epochs_with_bads_clean.get_data()[:, bad_ix, :],
        epochs_with_bads_fit.get_data()[epo_ix, :, :][:, bad_ix, :])

    assert_equal(epochs_clean.ch_names, epochs_fit.ch_names)

    # thresholds: one entry per picked channel, none for the EOG channel
    assert_true(isinstance(ar.threshes_, dict))
    assert_true(len(ar.picks) == len(picks))
    assert_true(len(ar.threshes_.keys()) == len(ar.picks))
    pick_eog = mne.pick_types(epochs.info, meg=False, eeg=False, eog=True)[0]
    assert_true(epochs.ch_names[pick_eog] not in ar.threshes_.keys())
    # transforming epochs that only contain three of the picked channels
    # must fail
    assert_raises(
        IndexError, ar.transform,
        epochs.copy().pick_channels(
            [epochs.ch_names[pp] for pp in picks[:3]]))

    epochs.load_data()
    # unknown method name
    assert_raises(ValueError, compute_thresholds, epochs, 'dfdfdf')
    index, ch_names = zip(*[(ii, epochs_fit.ch_names[pp])
                          for ii, pp in enumerate(picks)])
    # both threshold methods must return one entry per picked channel name
    threshes_a = compute_thresholds(
        epochs_fit, picks=picks, method='random_search')
    assert_equal(set(threshes_a.keys()), set(ch_names))
    threshes_b = compute_thresholds(
        epochs_fit, picks=picks, method='bayesian_optimization')
    assert_equal(set(threshes_b.keys()), set(ch_names))
Exemple #15
0
def run_epochs(subject, autoreject=True):
    """Clean one subject's filtered raw data with ICA and save its epochs.

    The previously fitted ICA solution is loaded, components matching ECG and
    EOG artifacts are added to ``ica.exclude`` and the ICA is applied to the
    raw data. The cleaned data are then epoched and, optionally, passed
    through autoreject (local) before being written to disk.

    Parameters
    ----------
    subject : str
        Subject identifier; used as sub-directory of ``meg_dir`` and as file
        name prefix.
    autoreject : bool
        If True, fit ``AutoReject`` on the epochs, transform them and save
        the result to ``*-filt-sss-ar-epo.fif``; otherwise save the epochs
        as they are to ``*-filt-sss-epo.fif``.

    Notes
    -----
    Relies on module-level settings that are defined outside this function:
    ``meg_dir``, ``n_max_ecg``, ``n_max_eog``, ``event_id``, ``tmin``,
    ``tmax``, ``baseline``, ``reject_tmax`` and ``n_jobs``.
    """
    raw_fname = op.join(meg_dir, subject, f'{subject}_audvis-filt_raw_sss.fif')
    annot_fname = op.join(meg_dir, subject, f'{subject}_audvis-annot.fif')
    raw = mne.io.read_raw_fif(raw_fname, preload=False)
    annot = mne.read_annotations(annot_fname)
    raw.set_annotations(annot)
    if autoreject:
        epo_fname = op.join(meg_dir, subject,
                            f'{subject}_audvis-filt-sss-ar-epo.fif')
    else:
        epo_fname = op.join(meg_dir, subject,
                            f'{subject}_audvis-filt-sss-epo.fif')

    # ICA (read once; the solution is reused for ECG and EOG detection)
    ica_fname = op.join(meg_dir, subject, f'{subject}_audvis-ica.fif')
    ica = mne.preprocessing.read_ica(ica_fname)

    try:
        # ECG
        ecg_epochs = mne.preprocessing.create_ecg_epochs(raw,
                                                         l_freq=10,
                                                         h_freq=20,
                                                         baseline=(None, None),
                                                         preload=True)
        ecg_inds, scores_ecg = ica.find_bads_ecg(ecg_epochs,
                                                 method='ctps',
                                                 threshold='auto',
                                                 verbose='INFO')
    except ValueError as err:
        # best effort: ECG detection is skipped, but no longer silently
        print(f'Skipping ECG artifact detection for {subject}: {err}')
    else:
        print(f'Found {len(ecg_inds)} ({ecg_inds}) ECG indices for {subject}')
        if len(ecg_inds) != 0:
            ica.exclude.extend(ecg_inds[:n_max_ecg])
            # for future inspection
            ecg_epochs.average().save(
                op.join(meg_dir, subject, f'{subject}_audvis-ecg-ave.fif'))
        # release memory
        del ecg_epochs, ecg_inds, scores_ecg

    try:
        # EOG
        eog_epochs = mne.preprocessing.create_eog_epochs(raw,
                                                         baseline=(None, None),
                                                         preload=True)
        eog_inds, scores_eog = ica.find_bads_eog(eog_epochs)
    except ValueError as err:
        # best effort: EOG detection is skipped, but no longer silently
        print(f'Skipping EOG artifact detection for {subject}: {err}')
    else:
        print(f'Found {len(eog_inds)} ({eog_inds}) EOG indices for {subject}')
        if len(eog_inds) != 0:
            ica.exclude.extend(eog_inds[:n_max_eog])
            # for future inspection
            eog_epochs.average().save(
                op.join(meg_dir, subject, f'{subject}_audvis-eog-ave.fif'))
        # release memory even when no EOG component was found
        # (mirrors the ECG branch above)
        del eog_epochs, eog_inds, scores_eog

    # applying ICA on Raw
    raw.load_data()
    ica.apply(raw)

    # extract events for epoching
    # modify stim_channel for your need
    events = mne.find_events(raw, stim_channel="STI 014")
    picks = mne.pick_types(raw.info, meg=True)
    epochs = mne.Epochs(
        raw,
        events=events,
        picks=picks,
        event_id=event_id,
        tmin=tmin,
        tmax=tmax,
        baseline=baseline,
        decim=4,  # raw sampling rate is 600 Hz, subsample to 150 Hz
        preload=True,  # for autoreject
        reject_tmax=reject_tmax,
        reject_by_annotation=True)
    del raw, annot

    # autoreject (local)
    if autoreject:
        # local reject
        # keep the bad sensors/channels because autoreject can repair it via
        # interpolation
        picks = mne.pick_types(epochs.info, meg=True, exclude=[])
        ar = AutoReject(picks=picks, n_jobs=n_jobs, verbose=False)
        print(f'Run autoreject (local) for {subject} (it takes a long time)')
        ar.fit(epochs)
        print(f'Drop bad epochs and interpolate bad sensors for {subject}')
        epochs = ar.transform(epochs)

    print(f'Dropped {round(epochs.drop_log_stats(), 2)}% epochs for {subject}')
    epochs.save(epo_fname, overwrite=True)
                       eog=False, exclude=exclude)

# %%
# Note that :class:`autoreject.AutoReject` by design supports multiple
# channels. If no picks are passed, separate solutions will be computed for
# each channel type and internally combined. This then readily supports
# cleaning unseen epochs from the different channel types used during fit.
# Here we only use a subset of channels to save time.

# %%
# Also note that once the parameters are learned, any data can be repaired
# that contains channels that were used during fit. This also means that time
# may be saved by fitting :class:`autoreject.AutoReject` on a
# representative subsample of the data.

# Build the estimator with a fixed random state; `picks` is defined outside
# this excerpt.
ar = AutoReject(picks=picks, random_state=42, n_jobs=1, verbose=True)

# Fit on `this_epoch` and transform it in one call, keeping the reject log.
epochs_ar, reject_log = ar.fit_transform(this_epoch, return_log=True)

# %%
# We can visualize the cross validation curve over two variables

import numpy as np  # noqa
import matplotlib.pyplot as plt  # noqa
import matplotlib.patches as patches  # noqa
from autoreject import set_matplotlib_defaults  # noqa

set_matplotlib_defaults(plt, style='seaborn-white')
# Average the stored EEG loss over its last axis before plotting.
loss = ar.loss_['eeg'].mean(axis=-1)  # losses are stored by channel type.

# Show the transposed, rescaled (x 1e6) loss as an image.
plt.matshow(loss.T * 1e6, cmap=plt.get_cmap('viridis'))
Exemple #17
0
def epoch_and_clean_trials(subject,
                           diagdir,
                           bidsdir,
                           datadir,
                           derivdir,
                           epochlength=3,
                           eventid={'visualfix/fixCross': 10}):
    """
    Chunk the data into epochs anchored at the eventid specified per trial,
    spanning ``epochlength`` seconds (3 by default).
    Do automatic artifact detection, rejection and fixing for eyeblinks,
    heartbeat, and high- and low-amplitude artifacts.
    The cleaned epochs are saved below ``derivdir``; the reject log, the
    average of the cleaned epochs and their PSD are saved as figures below
    ``diagdir``.
    :param subject: str, subject identifier. takes the form '001'
    :param diagdir: str, path to a directory where diagnostic plots can be saved
    :param bidsdir: str, path to a directory with BIDS data. Needed to load
    event logs from the experiment
    :param datadir: str, path to a directory with SSS-processed data
    :param derivdir: str, path to a directory where cleaned epochs can be saved
    :param epochlength: int, length of epoch in seconds. Epochs run from the
    event to ``epochlength`` after it, except for the response events
    ``{'press/left': 1, 'press/right': 4}``, where they run from
    ``-epochlength`` up to the event
    :param eventid: dict, the event to start an Epoch from
    """
    # NOTE(review): `eventid` has a mutable default. It is not modified in
    # this function; confirm that remove_eyeblinks_and_heartbeat() does not
    # modify it either.
    # construct name of the first split
    raw_fname = Path(datadir) / f'sub-{subject}/meg' / \
                f'sub-{subject}_task-memento_proc-sss_meg.fif'
    logging.info(f"Reading in SSS-processed data from subject sub-{subject}. "
                 f"Attempting the following path: {raw_fname}")
    raw = mne.io.read_raw_fif(raw_fname)
    events, event_dict = get_events(raw)
    # filter the data to remove high-frequency noise. Minimal high-pass filter
    # based on
    # https://www.sciencedirect.com/science/article/pii/S0165027021000157
    # ensure the data is loaded prior to filtering
    raw.load_data()
    if subject == '017':
        logging.info('Setting additional bad channels for subject 17')
        raw.info['bads'] = ['MEG0313', 'MEG0513', 'MEG0523']
        raw.interpolate_bads()
    # high-pass doesn't make sense, raw data has 0.1Hz high-pass filter already!
    _filter_data(raw, h_freq=100)
    # ICA to detect and repair artifacts
    logging.info('Removing eyeblink and heartbeat artifacts')
    rng = np.random.RandomState(28)
    remove_eyeblinks_and_heartbeat(
        raw=raw,
        subject=subject,
        figdir=diagdir,
        events=events,
        eventid=eventid,
        rng=rng,
    )
    # get the actual epochs: chunk the trial into epochs starting from the
    # event ID. Do not baseline correct the data.
    logging.info(f'Creating epochs of length {epochlength}')
    if eventid == {'press/left': 1, 'press/right': 4}:
        # when centered on the response, move back in time
        epochs = mne.Epochs(raw,
                            events,
                            event_id=eventid,
                            tmin=-epochlength,
                            tmax=0,
                            picks='meg',
                            baseline=None)
    else:
        epochs = mne.Epochs(raw,
                            events,
                            event_id=eventid,
                            tmin=0,
                            tmax=epochlength,
                            picks='meg',
                            baseline=None)
    # ADD SUBJECT SPECIFIC TRIAL NUMBER TO THE EPOCH! ONLY THIS WAY WE CAN
    # LATER RECOVER WHICH TRIAL PARAMETERS WE'RE LOOKING AT BASED ON THE LOGS AS
    # THE EPOCH REJECTION WILL REMOVE TRIALS
    logging.info("Retrieving trial metadata.")
    from pymento_meg.proc.epoch import get_trial_features
    metadata = get_trial_features(bids_path=bidsdir,
                                  subject=subject,
                                  column='trial_no')
    # transform to integers
    metadata = metadata.astype(int)
    # this does not work if we start at fixation cross for subject 5, because 1
    # fixation cross trigger is missing from the data, and it becomes impossible
    # to associate the trial metadata to the correct trials in the data
    epochs.metadata = metadata
    epochs.load_data()
    # use autoreject to repair bad epochs; candidate numbers of channels to
    # interpolate are 1..15
    ar = AutoReject(
        random_state=rng,
        n_interpolate=list(range(1, 16)))
    epochs_clean, reject_log = ar.fit_transform(epochs, return_log=True)
    # save the cleaned, epoched data to disk.
    outpath = _construct_path([
        Path(derivdir),
        f"sub-{subject}",
        "meg",
        f"sub-{subject}_task-memento_cleaned_epo.fif",
    ])
    logging.info(f"Saving cleaned, epoched data to {outpath}")
    epochs_clean.save(outpath, overwrite=True)
    # visualize the bad sensors for each trial. The log returned by
    # fit_transform() was computed on these same epochs, so reuse it instead
    # of recomputing it with ar.get_reject_log(epochs).
    fig = reject_log.plot()
    fname = _construct_path([
        Path(diagdir),
        f"sub-{subject}",
        "meg",
        f"epoch-rejectlog_sub-{subject}.png",
    ])
    fig.savefig(fname)
    # plot the average of all cleaned epochs
    fig = epochs_clean.average().plot()
    fname = _construct_path([
        Path(diagdir),
        f"sub-{subject}",
        "meg",
        f"clean-epoch_average_sub-{subject}.png",
    ])
    fig.savefig(fname)
    # plot psd of cleaned epochs
    psd = epochs_clean.plot_psd()
    fname = _construct_path([
        Path(diagdir),
        f"sub-{subject}",
        "meg",
        f"psd_cleaned-epochs-{subject}.png",
    ])
    psd.savefig(fname)
                    baseline=(None, 0), reject=None,
                    verbose=False, detrend=0, preload=True)

###############################################################################
# :class:`autoreject.AutoReject` internally does cross-validation to
# determine the optimal values :math:`\rho^{*}` and :math:`\kappa^{*}`

###############################################################################
# Note that :class:`autoreject.AutoReject` by design supports
# multiple channels.
# If no picks are passed, separate solutions will be computed for each channel
# type and internally combined. This then readily supports cleaning
# unseen epochs from the different channel types used during fit.
# Here we only use a subset of channels to save time.

# `n_interpolates`, `consensus_percs` and `picks` are defined outside this
# excerpt; the first two are passed positionally.
ar = AutoReject(n_interpolates, consensus_percs, picks=picks,
                thresh_method='random_search', random_state=42)

# Note that fitting and transforming can be done on different compatible
# portions of data if needed.
ar.fit(epochs['Auditory/Left'])
epochs_clean = ar.transform(epochs['Auditory/Left'])
# Average the transformed and the original epochs of the same condition.
evoked_clean = epochs_clean.average()
evoked = epochs['Auditory/Left'].average()

###############################################################################
# Now, we will manually mark the bad channels just for plotting.

evoked.info['bads'] = ['MEG 2443']
evoked_clean.info['bads'] = ['MEG 2443']

###############################################################################
Exemple #19
0
def run_autoreject(subject):
    """Interpolate bad epochs/sensors using Autoreject.

    Parameters
    ----------
    subject : str
        The participant reference

    Reads ``<root>/4_ICA/<subject>-epo.fif`` (``root`` is a module-level
    setting) and saves the resulting *-epo.fif file in the '5_autoreject'
    directory, together with .png files of the before/after ERP plot and of
    the reject-log heatmap.

    References
    ----------
    [1] Mainak Jas, Denis Engemann, Federico Raimondo, Yousra Bekhti, and
    Alexandre Gramfort, “Automated rejection and repair of bad trials in
    MEG/EEG.” In 6th International Workshop on Pattern Recognition in
    Neuroimaging (PRNI), 2016.

    [2] Mainak Jas, Denis Engemann, Yousra Bekhti, Federico Raimondo, and
    Alexandre Gramfort. 2017. “Autoreject: Automated artifact rejection for
    MEG and EEG data”. NeuroImage, 159, 417-429.

    """
    # Import data
    input_path = root + '/4_ICA/' + subject + '-epo.fif'
    epochs = mne.read_epochs(input_path)

    # Autoreject
    ar = AutoReject(random_state=42,
                    n_jobs=4)

    # fit_transform() already returns the cleaned epochs; keep them instead
    # of discarding the result and transforming a second time.
    epochs_clean = ar.fit_transform(epochs)

    # Plot difference
    evoked = epochs.average()
    evoked_clean = epochs_clean.average()

    fig, axes = plt.subplots(2, 1, figsize=(6, 6))

    for ax in axes:
        # booleans, not the strings 'off': a non-empty string is truthy, so
        # 'off' does not hide the ticks where only booleans are interpreted
        ax.tick_params(axis='x', which='both', bottom=False, top=False)
        ax.tick_params(axis='y', which='both', left=False, right=False)

    evoked.plot(exclude=[], axes=axes[0], ylim=[-30, 30], show=False)
    axes[0].set_title('Before autoreject')
    evoked_clean.plot(exclude=[], axes=axes[1], ylim=[-30, 30])
    axes[1].set_title('After autoreject')
    plt.tight_layout()
    plt.savefig(root + '/5_autoreject/' +
                subject + '-autoreject.png')
    plt.close()

    # Plot heatmap
    ar.get_reject_log(epochs).plot()
    plt.savefig(root + '/5_autoreject/' + subject + '-heatmap.png')
    plt.close()

    # Save epoch data
    out_epoch = root + '/5_autoreject/' + subject + '-epo.fif'
    epochs_clean.save(out_epoch)
Exemple #20
0
def run_preproc(datadir='/data'):
    """Preprocess every subject found in ``<datadir>/BIDS``.

    For each ``sub-*`` directory, every raw file with a supported extension
    (``.set`` or ``.bdf``) in its ``eeg`` folder is read, re-referenced,
    epoched and cleaned with AutoReject. Diagnostic figures and reject logs
    go to ``derivatives/eegprep/reports/sub-<sub>``. The cleaned epochs are
    then concatenated per (session, task), optionally downsampled, and saved
    to ``derivatives/eegprep/sub-<sub>`` in the configured output format.

    Parameters
    ----------
    datadir : str
        Directory containing the ``BIDS`` folder and, optionally, an
        ``eegprep.conf`` file whose contents update the default
        configuration.
    """
    print('data directory: {}'.format(datadir))
    conf_file_path = join(datadir, 'eegprep.conf')
    config = Configuration()
    config.setDefaults(defaults)
    if os.path.isfile(conf_file_path):
        with open(conf_file_path) as fh:
            conf_string = fh.read()
        config.updateFromString(conf_string)
    print('configuration:')
    print(config)

    bidsdir = join(datadir, 'BIDS')
    eegprepdir = join(bidsdir, 'derivatives', 'eegprep')

    subjectdirs = sorted(glob.glob(join(bidsdir, 'sub-*')))
    for subjectdir in subjectdirs:
        assert os.path.isdir(subjectdir)

        sub = basename(subjectdir)[4:]

        # prepare derivatives directory
        derivdir = join(eegprepdir, 'sub-' + sub)
        os.makedirs(derivdir, exist_ok=True)
        reportsdir = join(eegprepdir, 'reports', 'sub-' + sub)
        os.makedirs(reportsdir, exist_ok=True)

        # cleaned epochs of this subject, keyed by (session, task, run)
        subject_epochs = {}
        rawtypes = {'.set': mne.io.read_raw_eeglab, '.bdf': mne.io.read_raw_edf}
        for fname in sorted(glob.glob(join(subjectdir, 'eeg', '*'))):
            _, ext = splitext(fname)
            if ext not in rawtypes:
                continue
            sub, ses, task, run = filename2tuple(basename(fname))

            print('\nProcessing raw file: ' + basename(fname))

            # read data
            raw = rawtypes[ext](fname, preload=True, verbose=False)
            events = mne.find_events(raw)  #raw, consecutive=False, min_duration=0.005)

            # Set channel types and select reference channels
            channelFile = fname.replace('eeg' + ext, 'channels.tsv')
            channels = pandas.read_csv(channelFile, index_col='name', sep='\t')
            bids2mne = {
                'MISC': 'misc',
                'EEG': 'eeg',
                'VEOG': 'eog',
                'TRIG': 'stim',
                'REF': 'eeg',
            }
            channels['mne'] = channels.type.replace(bids2mne)

            # the below fails if the specified channels are not in the data
            raw.set_channel_types(channels.mne.to_dict())

            # set bad channels
            raw.info['bads'] = channels[channels.status=='bad'].index.tolist()

            # pick channels to use for epoching
            epoching_picks = mne.pick_types(raw.info, eeg=True, eog=False, stim=False, exclude='bads')

            # Filtering
            #raw.filter(l_freq=0.05, h_freq=40, fir_design='firwin')

            montage = mne.channels.read_montage(guess_montage(raw.ch_names))
            print(montage)
            raw.set_montage(montage)

            # plot raw data: show roughly every (nchans / 20)-th channel.
            # The step is clamped to 1 so recordings with fewer than 20
            # channels do not produce a zero step.
            nchans = len(raw.ch_names)
            pick_channels = numpy.arange(0, nchans, max(1, nchans // 20))
            start = numpy.round(raw.times.max()/2)
            fig = raw.plot(start=start, order=pick_channels)
            fname_plot = 'sub-{}_ses-{}_task-{}_run-{}_raw.png'.format(sub, ses, task, run)
            fig.savefig(join(reportsdir, fname_plot))

            # Set reference
            refChannels = channels[channels.type=='REF'].index.tolist()
            raw.set_eeg_reference(ref_channels=refChannels)

            ##  epoching
            epochs_params = dict(
                events=events,
                tmin=-0.1,
                tmax=0.8,
                reject=None,  # dict(eeg=250e-6, eog=150e-6)
                picks=epoching_picks,
                detrend=0,
            )
            file_epochs = mne.Epochs(raw, preload=True, **epochs_params)
            file_epochs.drop_channels(refChannels)

            # autoreject (under development)
            ar = AutoReject(n_jobs=4)
            clean_epochs = ar.fit_transform(file_epochs)

            # NOTE(review): the reject log is computed on the already cleaned
            # epochs, not on `file_epochs` - confirm this is intended.
            rejectlog = ar.get_reject_log(clean_epochs)
            fname_log = 'sub-{}_ses-{}_task-{}_run-{}_reject-log.npz'.format(sub, ses, task, run)
            save_rejectlog(join(reportsdir, fname_log), rejectlog)
            fig = plot_rejectlog(rejectlog)
            fname_plot = 'sub-{}_ses-{}_task-{}_run-{}_bad-epochs.png'.format(sub, ses, task, run)
            fig.savefig(join(reportsdir, fname_plot))

            # store for now
            subject_epochs[(ses, task, run)] = clean_epochs

            # create evoked plots for at most 6 randomly chosen conditions.
            # random.sample() needs a sequence; a dict keys view is rejected
            # on Python >= 3.11, hence the list.
            conds = list(clean_epochs.event_id)
            selected_conds = random.sample(conds, min(len(conds), 6))
            picks = mne.pick_types(clean_epochs.info, eeg=True)
            for cond in selected_conds:
                evoked = clean_epochs[cond].average()
                fname_plot = 'sub-{}_ses-{}_task-{}_run-{}_evoked-{}.png'.format(sub, ses, task, run, cond)
                fig = evoked.plot_joint(picks=picks)
                fig.savefig(join(reportsdir, fname_plot))

        # positions of session and task inside the (ses, task, run) keys
        sessSeg = 0
        sessions = sorted({k[sessSeg] for k in subject_epochs})
        for session in sessions:
            taskSeg = 1
            tasks = list({k[taskSeg] for k in subject_epochs if k[sessSeg]==session})
            for task in tasks:
                print('\nGathering epochs for session {} task {}'.format(session, task))
                epochs_selection = [v for (k, v) in subject_epochs.items() if k[:2]==(session, task)]

                task_epochs = mne.epochs.concatenate_epochs(epochs_selection)

                # downsample if configured to do so
                # important to do this after concatenation because
                # downsampling may cause rejection for 'TOOSHORT'
                if config['downsample'] < task_epochs.info['sfreq']:
                    task_epochs = task_epochs.copy().resample(config['downsample'], npad='auto')

                ext = config['out_file_format']
                fname = join(derivdir, 'sub-{}_ses-{}_task-{}_epo.{}'.format(sub, session, task, ext))
                variables = {
                    'epochs': task_epochs.get_data(),
                    'events': task_epochs.events,
                    'timepoints': task_epochs.times
                }
                if ext == 'fif':
                    task_epochs.save(fname)
                elif ext == 'mat':
                    scipy.io.savemat(fname, mdict=variables)
                elif ext == 'npy':
                    numpy.savez(fname, **variables)
Exemple #21
0
                    baseline=(None, 0), reject=None,
                    verbose=False, detrend=0, preload=True)

# %%
# :class:`autoreject.AutoReject` internally does cross-validation to
# determine the optimal values :math:`\rho^{*}` and :math:`\kappa^{*}`

# %%
# Note that :class:`autoreject.AutoReject` by design supports
# multiple channels.
# If no picks are passed, separate solutions will be computed for each channel
# type and internally combined. This then readily supports cleaning
# unseen epochs from the different channel types used during fit.
# Here we only use a subset of channels to save time.

# `n_interpolates`, `consensus_percs` and `picks` are defined outside this
# excerpt; the first two are passed positionally.
ar = AutoReject(n_interpolates, consensus_percs, picks=picks,
                thresh_method='random_search', random_state=42)

# Note that fitting and transforming can be done on different compatible
# portions of data if needed.
ar.fit(epochs['Auditory/Left'])
epochs_clean = ar.transform(epochs['Auditory/Left'])
# Average the transformed and the original epochs of the same condition.
evoked_clean = epochs_clean.average()
evoked = epochs['Auditory/Left'].average()

# %%
# Now, we will manually mark the bad channels just for plotting.

evoked.info['bads'] = ['MEG 2443']
evoked_clean.info['bads'] = ['MEG 2443']

# %%
# Downsample and add an average-reference projection; `raw` is loaded
# outside this excerpt.
raw.resample(100, npad='auto')
raw.set_eeg_reference('average', projection=True)


# Create 30s chunks of data
events = mne.event.make_fixed_length_events(
    raw, id=9999, start=0, stop=None, duration=30.0,
    first_samp=True)
epochs = mne.epochs.Epochs(raw, events, tmin=0, tmax=30.0,
                           baseline=None, preload=True)

# Run autoreject
# (a second, mistyped copy of this assignment - `thresh_functhresh_` - was
# never used and has been removed)
thresh_func = partial(compute_thresholds, random_state=42, n_jobs=1)
ar = AutoReject(thresh_func=thresh_func, verbose='tqdm')

# Fit on a random 10% of the epochs (drawn without replacement), then apply
# the fitted model to all epochs.
index = np.random.choice(np.arange(len(epochs)),
                         size=int(np.floor(len(epochs) * 0.1)),
                         replace=False)
ar.fit(epochs[index])
epochs_clean = ar.transform(epochs)

# N is the number of epochs left in `epochs_clean`.
print("{:.2f}% epochs rejected (N={})".format(
      epochs_clean.drop_log_stats(), len(epochs_clean)))


# Save cleaned epochs
epochs_clean.save('../data/derived/cleaned_sleep_scorer_epo.fif')
epochs_clean[:11].save('../data/derived/cleaned_subset_sleep_scorer_epo.fif')
    # print(intersect_order, "ix intersect")
    # print(epoch_order, "ix base")
    # print(epochs_2_drop, "drop")
    epochs.load_data()
    epochs = epochs.drop(epochs_2_drop, reason="bad behaviour")
    epochs.save(op.join(sub_path, "clean-" + epo.split(sep)[-1]),
                overwrite=True)
    print("AMOUNT OF EPOCHS AFTER MATCHING WITH BEH:", len(epochs))
    print("DOES IT MATCH?", len(beh_ixs) == len(epochs))
    print("\n")

    if len(beh_ixs) == len(epochs):
        ar = AutoReject(consensus=np.linspace(0, 1.0, 27),
                        n_interpolate=np.array([1, 4, 32]),
                        thresh_method="bayesian_optimization",
                        cv=10,
                        n_jobs=-1,
                        random_state=42,
                        verbose="progressbar")
        ar.fit(epochs)

        epo_type = epo.split(sep)[-1].split("-")[3]
        name = "{}-{}-{}".format(subject_id, numero, epo_type)
        ar_fname = op.join(qc_folder, "{}-autoreject.h5".format(name))
        ar.save(ar_fname, overwrite=True)
        epochs_ar, rej_log = ar.transform(epochs, return_log=True)
        rej_log.plot(show=False)
        plt.savefig(op.join(qc_folder, "{}-autoreject-log.png".format(name)))
        plt.close("all")
        epo.split(sep)[-1]
        cleaned = op.join(sub_path, "autoreject-" + epo.split(sep)[-1])
# Select the EEG channels only. `exclude=[]` means channels listed in
# raw.info['bads'] are not removed from the selection.
picks = mne.pick_types(raw.info,
                       meg=False,
                       eeg=True,
                       stim=False,
                       eog=False,
                       include=[],
                       exclude=[])

# Make epochs from the raw data
# (`events`, `event_id`, `tmin` and `tmax` are defined outside this excerpt;
# reject=None disables amplitude-based rejection at this stage)
epochs = mne.Epochs(raw,
                    picks=picks,
                    events=events,
                    event_id=event_id,
                    tmin=tmin,
                    tmax=tmax,
                    preload=True,
                    reject=None)

# Setup AutoReject
# (`n_interpolates`, `consensus_percs` and `seed` are defined outside this
# excerpt)
ar = AutoReject(n_interpolates,
                consensus_percs,
                thresh_method='random_search',
                random_state=seed)

# Fit, i.e. calculate AutoReject
ar.fit(epochs)

epochs_clean = ar.transform(epochs)  # Clean the epochs
epochs_clean.save(data_path + '%s-epo.fif' % subject)  # Save the epochs
Exemple #25
0
#%% Fit autoreject

# Cut the continuous recording into consecutive fixed-length epochs of
# `tstep` seconds (`raw` and `tstep` are defined outside this excerpt).
events = mne.make_fixed_length_events(raw, duration=tstep)
epochs = mne.Epochs(raw,
                    events,
                    tmin=0.0,
                    tmax=tstep,
                    baseline=(0, 0),
                    reject=None,
                    verbose=False,
                    detrend=0,
                    preload=True)

# `n_interpolates`, `consensus_percs` and `picks` are defined outside this
# excerpt.
ar = AutoReject(n_interpolates,
                consensus_percs,
                picks=picks,
                thresh_method='random_search',
                random_state=42)

# Note that fitting and transforming can be done on different compatible
# portions of data if needed.
ar.fit(epochs)

# epochs_ar, reject_log = ar.fit_transform(epochs, return_log=True)

# Transform the same epochs and get their reject log in two separate calls.
epochs_clean = ar.transform(epochs)
reject_log = ar.get_reject_log(epochs)
# Averages after and before cleaning.
evoked_clean = epochs_clean.average()
evoked = epochs.average()

# visualize rejected epochs
def main():
    """Pre-process every subject's EEG file and fit FOOOF to its PSDs.

    For each subject in ``SUBJ_DAT_NUM`` whose ``<sub>_resampled.set`` file
    exists under ``BASE_PATH``: load the data, add the block events,
    re-reference, optionally run ICA (``RUN_ICA``) and autoreject
    (``RUN_AUTOREJECT``), epoch the rest and trial blocks, compute Welch PSDs
    and save one FOOOFGroup result file per epoch. Subjects without a data
    file are reported and skipped.

    All upper-case names are module-level settings defined outside this
    function.
    """

    # Initialize fg
    # TODO: add any settings we want to use
    fg = FOOOFGroup(peak_width_limits=[1, 6],
                    min_peak_amplitude=0.075,
                    max_n_peaks=6,
                    peak_threshold=1,
                    verbose=False)

    # Save out a settings file
    fg.save(file_name=GROUP + '_fooof_group_settings',
            file_path=SAVE_PATH,
            save_settings=True)

    # START LOOP
    for sub in SUBJ_DAT_NUM:

        print('Current Subject' + str(sub))

        # load subject data
        subj_dat_fname = str(sub) + "_resampled.set"
        full_path = os.path.join(BASE_PATH, subj_dat_fname)
        path_check = Path(full_path)

        if path_check.is_file():

            eeg_dat = mne.io.read_raw_eeglab(full_path,
                                             event_id_func=None,
                                             preload=True)
            evs = mne.io.eeglab.read_events_eeglab(full_path, EV_DICT)

            # Collect only the events whose code belongs to a block label
            new_evs = np.empty(shape=(0, 3))

            for ev_label in BLOCK_EVS:
                ev_code = EV_DICT[ev_label]
                temp = evs[evs[:, 2] == ev_code]
                new_evs = np.vstack([new_evs, temp])

            eeg_dat.add_events(new_evs)

            # set EEG average reference
            eeg_dat.set_eeg_reference()

            ## PRE-PROCESSING: ICA
            if RUN_ICA:

                # ICA settings
                method = 'fastica'
                n_components = 0.99
                random_state = 47
                reject = {'eeg': 20e-4}

                # Initialize ICA object
                ica = ICA(n_components=n_components,
                          method=method,
                          random_state=random_state)

                # High-pass filter data for running ICA
                eeg_dat.filter(l_freq=1., h_freq=None, fir_design='firwin')

                # Fit ICA
                ica.fit(eeg_dat, reject=reject)

                # Find components to drop, based on correlation with EOG channels
                drop_inds = []
                for chi in EOG_CHS:
                    inds, scores = ica.find_bads_eog(eeg_dat,
                                                     ch_name=chi,
                                                     threshold=2.5,
                                                     l_freq=1,
                                                     h_freq=10,
                                                     verbose=False)
                    drop_inds.extend(inds)
                # de-duplicate components flagged by more than one EOG channel
                drop_inds = list(set(drop_inds))

                # Set which components to drop, and collect record of this
                ica.exclude = drop_inds
                #dropped_components[s_ind, 0:len(drop_inds)] = drop_inds

                # Save out ICA solution
                ica.save(pjoin(ICA_PATH, str(sub) + '-ica.fif'))

                # Apply ICA to data
                eeg_dat = ica.apply(eeg_dat)

            ## EPOCH BLOCKS
            events = mne.find_events(eeg_dat)

            #epochs = mne.Epochs(eeg_dat, events=events, tmin=5, tmax=125, baseline=None, preload=True)
            rest_epochs = mne.Epochs(eeg_dat,
                                     events=events,
                                     event_id=REST_EVENT_ID,
                                     tmin=5,
                                     tmax=125,
                                     baseline=None,
                                     preload=True)
            trial_epochs = mne.Epochs(eeg_dat,
                                      events=events,
                                      event_id=TRIAL_EVENT_ID,
                                      tmin=5,
                                      tmax=125,
                                      baseline=None,
                                      preload=True)

            ## PRE-PROCESSING: AUTO-REJECT
            # NOTE(review): `epochs` is not assigned anywhere in this function
            # (its only assignment is the commented-out line above), and
            # `dropped_trials` / `s_ind` are not assigned here either. Unless
            # they exist at module level, this branch fails when
            # RUN_AUTOREJECT is true - confirm the intended inputs.
            if RUN_AUTOREJECT:

                # Initialize and run autoreject across epochs
                ar = AutoReject(n_jobs=4, verbose=False)
                epochs, rej_log = ar.fit_transform(epochs, True)

                # Drop same trials from filtered data
                rest_epochs.drop(rej_log.bad_epochs)
                trial_epochs.drop(rej_log.bad_epochs)

                # Collect list of dropped trials
                dropped_trials[s_ind, 0:sum(rej_log.bad_epochs)] = np.where(
                    rej_log.bad_epochs)[0]

            # Set montage
            # (channel names are taken from all but the last channel)
            chs = mne.channels.read_montage('standard_1020',
                                            rest_epochs.ch_names[:-1])
            rest_epochs.set_montage(chs)
            trial_epochs.set_montage(chs)

            # Calculate PSDs
            rest_psds, rest_freqs = mne.time_frequency.psd_welch(rest_epochs,
                                                                 fmin=1.,
                                                                 fmax=50.,
                                                                 n_fft=2000,
                                                                 n_overlap=250,
                                                                 n_per_seg=500)
            trial_psds, trial_freqs = mne.time_frequency.psd_welch(
                trial_epochs,
                fmin=1.,
                fmax=50.,
                n_fft=2000,
                n_overlap=250,
                n_per_seg=500)

            # Setting frequency range
            freq_range = [3, 30]

            ## FOOOF the Data

            # Rest Data
            # (one fit and one saved results file per entry along the first
            # axis of the PSD array)
            for ind, entry in enumerate(rest_psds):
                rest_fooof_psds = rest_psds[ind, :, :]
                fg.fit(rest_freqs, rest_fooof_psds, freq_range)
                fg.save(file_name=str(sub) + 'fooof_group_results' + str(ind),
                        file_path=REST_SAVE_PATH,
                        save_results=True)

            # Trial Data
            for ind, entry in enumerate(trial_psds):
                trial_fooof_psds = trial_psds[ind, :, :]
                fg.fit(trial_freqs, trial_fooof_psds, freq_range)
                fg.save(file_name=str(sub) + 'fooof_group_results' + str(ind),
                        file_path=TRIAL_SAVE_PATH,
                        save_results=True)

            print('Subject Saved')

        else:

            print('Current Subject' + str(sub) + ' does not exist')
            print(path_check)

    print('Pre-processing Complete')
    def run(self):
        """Read the EDF recording, preprocess it and cut it into epochs.

        Steps, in order: read the raw file (an unprocessed deep copy is kept
        in ``self.raw``), build the event array according to the
        ``marker_detection`` / ``event_detection`` flags in ``self.info``,
        apply montage, resampling, re-referencing, filtering and ICA-based
        ocular correction, then epoch the data, drop the epochs listed in
        ``self.bads`` and run AutoReject.

        The preprocessing steps are best-effort: a step that raises (for
        instance because the matching attribute was never set on the object)
        is skipped silently and the pipeline carries on.

        Returns
        -------
        self
            The same object. ``self.raw``, ``self.epochs`` and ``self.bads``
            are always set; ``self.ica`` and ``self.autoreject_log`` are set
            only when the corresponding step ran successfully.

        Notes
        -----
        If neither ``marker_detection`` nor ``event_detection`` is enabled,
        no event array is built and epoching fails with a ``NameError``.
        """
        eog = self.info['channel_info']['EOG']
        misc = self.info['channel_info']['Misc']
        stim = self.info['channel_info']['Stim']

        # External per-participant files, consumed only by the
        # event-detection branch below. If the lookup fails, ``ext_files``
        # stays undefined, exactly as before.
        try:
            ext_files = glob.glob(self.info['ext_file_folder'] + '/' +
                                  self.participant + '/*axis0.dat')
        except Exception:
            pass

        tmin = self.t_epoch[0]
        tmax = self.t_epoch[1]

        raw = read_raw_edf(self.file, eog=eog, misc=misc)
        # Keep a copy taken before any processing is applied.
        self.raw = cp.deepcopy(raw)
        raw.load_data()

        # marker detection (one marker, continuous trial)
        if self.info['marker_detection'] == True:
            starts = find_trialstart(raw,
                                     stim_channel=raw.ch_names[stim[0]],
                                     new_samplin_rate=self.sr_new)
            # Force the second marker to sit 30 * 200 samples after the
            # first one (presumably 30 s at 200 Hz -- TODO confirm); append
            # it when only a single marker was found.
            try:
                starts[1] = starts[0] + 30 * 200
            except Exception:
                starts = np.r_[starts, (starts[0] + 30 * 200)]
            events = np.zeros((len(starts), 3))
            events[:, 0] = starts
            events[:, 2] = list(self.info['event_dict'].values())
            # ``np.int`` was removed from NumPy; the builtin is equivalent.
            events = events.astype(int)

        # event detection (one marker, regular events)
        if self.info['event_detection'] == True:
            starts = find_trialstart(raw,
                                     stim_channel=raw.ch_names[stim[0]],
                                     new_samplin_rate=self.sr_new)

            events = force_events(ext_files, self.info['event_dict'],
                                  self.sr_new, self.info['trial_length'],
                                  self.info['trials'],
                                  starts[:len(self.info['event_dict'])])

        # Optional estimators: ``None`` means "step not requested".
        ica = ICA(method='fastica') if self.info['ICA'] == True else None
        ar = AutoReject() if self.info['Autoreject'] == True else None

        ## EEG preprocessing options are applied only if the corresponding
        ## parameters are set on the object

        # read montage
        try:
            montage = make_standard_montage(self.montage)
            raw.set_montage(montage)
        except Exception:
            pass

        # resampling
        try:
            raw.resample(sfreq=self.sr_new)
        except Exception:
            pass

        # rereferencing
        try:
            raw, _ = mne.set_eeg_reference(raw, ref_channels=['EXG5', 'EXG6'])
        except Exception:
            pass

        # filter
        try:
            low = self.filter_freqs[0]
            high = self.filter_freqs[1]
            raw.filter(low, high, fir_design='firwin')
        except Exception:
            pass

        # ocular correction
        if ica is not None:
            try:
                ica.fit(raw)
                ica.exclude = []
                eog_indices, eog_scores = ica.find_bads_eog(raw)
                ica.exclude = eog_indices
                ica.apply(raw)
                self.ica = ica
            except Exception:
                pass

        picks = mne.pick_types(raw.info,
                               meg=False,
                               eeg=True,
                               stim=False,
                               eog=False,
                               exclude='bads')

        event_id = self.info['event_dict']
        epochs = mne.Epochs(raw,
                            events,
                            event_id,
                            tmin,
                            tmax,
                            proj=True,
                            baseline=None,
                            preload=True,
                            picks=picks)

        # epoch rejection: first the manually listed epochs (if any) ...
        try:
            epochs = epochs.drop(indices=self.bads)
        except Exception:
            pass

        # ... then AutoReject, when requested.
        if ar is not None:
            try:
                epochs, self.autoreject_log = ar.fit_transform(
                    epochs, return_log=True)
            except Exception:
                pass

        # Flag epochs dropped by the user or by AutoReject. ``list(entry)``
        # makes the comparison work whether drop_log entries are lists
        # (older MNE) or tuples (newer MNE).
        bads = np.asarray([
            list(entry) in (['USER'], ['AUTOREJECT'])
            for entry in epochs.drop_log
        ])
        self.bads = np.where(bads)
        self.epochs = epochs
        return self
Exemple #28
0
def run_epochs(subject,
               epoch_on_first_element,
               baseline=True,
               l_freq=None,
               h_freq=None,
               suffix='_eeg_1Hz'):
    """Epoch the ICA-cleaned runs of one subject and save them to disk.

    Each run is loaded, resampled to ``config.resample_sfreq`` together with
    its events, high-pass filtered at 1 Hz, and the runs are concatenated.
    Epochs are then built either on every item or only on the first item of
    each sequence, metadata from the run-info CSV files is attached, and the
    epochs are written to ``config.meg_dir/<subject>``. When
    ``config.autoreject`` is set, "global" and "local" AutoReject versions
    are saved as well, along with their thresholds / reject log.

    Parameters
    ----------
    subject : str
        Subject identifier, used for directory and file names.
    epoch_on_first_element : bool
        If true, epoch from -0.2 s to 0.25 * 17 s around the first item of
        each 16-item sequence; otherwise epoch every item from -0.05 s to
        0.6 s.
    baseline : object
        Only ``None`` disables baseline correction; any other value
        (including the default ``True``) applies ``(config.tmin, 0)``.
    l_freq, h_freq : float | None
        Currently unused: the filter is hard-coded to a 1 Hz high-pass.
    suffix : str
        Appended to the output file-name stem.

    Raises
    ------
    Exception
        If a run does not contain exactly 46 * 16 events after resampling.

    Notes
    -----
    ``config.tmin``, ``config.tmax`` and ``config.baseline`` are overwritten
    as a side effect. File names are produced with
    ``config.base_fname.format(**locals())``, so the local names (notably
    ``subject`` and ``extension``) must not be renamed.
    """
    print("Processing subject: %s" % subject)
    meg_subject_dir = op.join(config.meg_dir, subject)
    run_info_subject_dir = op.join(config.run_info_dir, subject)
    raw_list = list()

    print("  Loading raw data")
    runs = config.runs_dict[subject]
    for run in runs:
        extension = run + '_ica_raw'
        raw_fname_in = op.join(meg_subject_dir,
                               config.base_fname.format(**locals()))
        raw = mne.io.read_raw_fif(raw_fname_in, preload=True)

        # ---------------------------------------------------------------------------------------------------------------- #
        # RESAMPLING EACH RUN BEFORE CONCAT & EPOCHING
        # Resampling the raw data while keeping events from original raw data, to avoid potential loss of
        # events when downsampling: https://www.nmr.mgh.harvard.edu/mne/dev/auto_examples/preprocessing/plot_resample.html
        # Find events
        events = mne.find_events(raw,
                                 stim_channel=config.stim_channel,
                                 consecutive=True,
                                 min_duration=config.min_event_duration,
                                 shortest_event=config.shortest_event)

        print('  Downsampling raw data')
        raw, events = raw.resample(config.resample_sfreq,
                                   npad='auto',
                                   events=events)
        if len(events) != 46 * 16:
            raise Exception('We expected %i events but we got %i' %
                            (46 * 16, len(events)))
        raw.filter(l_freq=1, h_freq=None)
        raw_list.append(raw)
        # ---------------------------------------------------------------------------------------------------------------- #

    if subject == 'sub08-cc_150418':
        # For this participant, we had some problems when concatenating the raws for run08. The error message said that raw08._cals didn't match the other ones.
        # We saw that it is the 'calibration' for the channel EOG061 that was different with respect to run09._cals.
        raw_list[7]._cals = raw_list[8]._cals
        print(
            'Warning: corrected an issue with subject08 run08 ica_raw data file...'
        )

    print('Concatenating runs')
    raw = mne.concatenate_raws(raw_list)
    if "eeg" in config.ch_types:
        raw.set_eeg_reference(projection=True)
    del raw_list

    meg = False
    if 'meg' in config.ch_types:
        meg = True
    elif 'grad' in config.ch_types:
        meg = 'grad'
    elif 'mag' in config.ch_types:
        meg = 'mag'
    eeg = 'eeg' in config.ch_types
    picks = mne.pick_types(raw.info,
                           meg=meg,
                           eeg=eeg,
                           stim=True,
                           eog=True,
                           exclude=())

    # Construct metadata from csv events file
    metadata = epoching_funcs.convert_csv_info_to_metadata(
        run_info_subject_dir)
    metadata_pandas = pd.DataFrame.from_dict(metadata, orient='index')
    metadata_pandas = pd.DataFrame.transpose(metadata_pandas)

    # ====== Epoching the data
    print('  Epoching')

    # Events are searched again on the concatenated (already resampled) data
    events = mne.find_events(raw,
                             stim_channel=config.stim_channel,
                             consecutive=True,
                             min_duration=config.min_event_duration,
                             shortest_event=config.shortest_event)

    if epoch_on_first_element:
        # fosca 06012020
        config.tmin = -0.200
        config.tmax = 0.25 * 17
        config.baseline = (config.tmin, 0)
        if baseline is None:
            config.baseline = None
        # Relabel events with their position (1..16) within the sequence so
        # that only the first item (code 1) is epoched.
        for k in range(len(events)):
            events[k, 2] = k % 16 + 1
        epochs = mne.Epochs(raw,
                            events, {'sequence_starts': 1},
                            config.tmin,
                            config.tmax,
                            proj=True,
                            picks=picks,
                            baseline=config.baseline,
                            preload=False,
                            decim=config.decim,
                            reject=None)
        epochs.metadata = metadata_pandas[metadata_pandas['StimPosition'] ==
                                          1.0]
    else:
        config.tmin = -0.050
        config.tmax = 0.600
        config.baseline = (config.tmin, 0)
        if baseline is None:
            config.baseline = None
        epochs = mne.Epochs(raw,
                            events,
                            None,
                            config.tmin,
                            config.tmax,
                            proj=True,
                            picks=picks,
                            baseline=config.baseline,
                            preload=False,
                            decim=config.decim,
                            reject=None)

        # Add metadata to epochs
        epochs.metadata = metadata_pandas

    # Save epochs (before AutoReject)
    print('  Writing epochs to disk')
    if epoch_on_first_element:
        extension = subject + '_1st_element_epo' + suffix
    else:
        extension = subject + '_epo' + suffix
    epochs_fname = op.join(meg_subject_dir,
                           config.base_fname.format(**locals()))

    print("Output: ", epochs_fname)
    epochs.save(epochs_fname, overwrite=True)

    if config.autoreject:
        epochs.load_data()

        # Running AutoReject "global" (https://autoreject.github.io) -> just get the thresholds
        from autoreject import get_rejection_threshold
        reject = get_rejection_threshold(epochs,
                                         ch_types=['mag', 'grad', 'eeg'])
        epochsARglob = epochs.copy().drop_bad(reject=reject)
        print('  Writing "AR global" cleaned epochs to disk')
        if epoch_on_first_element:
            extension = subject + '_1st_element_ARglob_epo' + suffix
        else:
            extension = subject + '_ARglob_epo' + suffix
        epochs_fname = op.join(meg_subject_dir,
                               config.base_fname.format(**locals()))
        print("Output: ", epochs_fname)
        epochsARglob.save(epochs_fname, overwrite=True)
        # Save autoreject thresholds; the context manager guarantees the
        # file is flushed and closed.
        with open(epochs_fname[:-4] + '_ARglob_thresholds.obj', 'wb') as fid:
            pickle.dump(reject, fid)

        # Running AutoReject "local" (https://autoreject.github.io)
        ar = AutoReject()
        epochsAR, reject_log = ar.fit_transform(epochs, return_log=True)
        print('  Writing "AR local" cleaned epochs to disk')
        if epoch_on_first_element:
            extension = subject + '_1st_element_clean_epo' + suffix
        else:
            extension = subject + '_clean_epo' + suffix
        epochs_fname = op.join(meg_subject_dir,
                               config.base_fname.format(**locals()))
        print("Output: ", epochs_fname)
        epochsAR.save(epochs_fname, overwrite=True)
        # Save autoreject reject_log
        with open(epochs_fname[:-4] + '_reject_local_log.obj', 'wb') as fid:
            pickle.dump(reject_log, fid)
Exemple #29
0
def AR_local(cleaned_epochs_ICA, verbose=False):
    """
    Applies local Autoreject to correct or reject bad epochs.

    One AutoReject model is fitted per subject. Epochs flagged as bad for
    *both* subjects of the dyad (intersection of the two reject logs) are
    dropped from every subject, and each subject's remaining epochs are then
    repaired with that subject's own fitted model.

    Arguments:
        cleaned_epochs_ICA: list of Epochs after global Autoreject and ICA.
          Exactly the first two entries are used to compute the common bad
          epochs. The Epochs are modified in place (bad epochs are dropped).
        verbose: to plot data before and after AR, boolean set to False
          by default.

    Returns:
        cleaned_epochs_AR: list of Epochs after local Autoreject, with
          equalized epoch counts.
    """
    bad_epochs_AR = []
    fitted_ARs = []

    # defaults values for n_interpolates and consensus_percs
    n_interpolates = np.array([1, 4, 32])
    consensus_percs = np.linspace(0, 1.0, 11)

    ar_verbose = 'progressbar' if verbose else False

    for clean_epochs in cleaned_epochs_ICA:  # per subj

        picks = mne.pick_types(clean_epochs[0].info,
                               meg=False,
                               eeg=True,
                               stim=False,
                               eog=False,
                               exclude=[])

        ar = AutoReject(n_interpolates,
                        consensus_percs,
                        picks=picks,
                        thresh_method='random_search',
                        random_state=42,
                        verbose=ar_verbose)

        # fitting AR to get bad epochs
        ar.fit(clean_epochs)
        reject_log = ar.get_reject_log(clean_epochs, picks=picks)
        bad_epochs_AR.append(reject_log)
        # Keep the model so each subject is later repaired with the
        # thresholds learnt on its own data, not on the last subject's.
        fitted_ARs.append(ar)

    # epochs that are bad for both subjects of the dyad
    # NOTE(review): an earlier comment said "min 1 subj" (i.e. a union), but
    # the code has always taken the intersection -- confirm which is wanted.
    log1 = bad_epochs_AR[0]
    log2 = bad_epochs_AR[1]

    bad1 = np.where(log1.bad_epochs == True)
    bad2 = np.where(log2.bad_epochs == True)

    bad = list(set(bad1[0].tolist()).intersection(bad2[0].tolist()))
    if verbose:
        print('%s percent of bad epochs' %
              int(len(bad) / len(list(log1.bad_epochs)) * 100))

    # picking good epochs for the two subj
    cleaned_epochs_AR = []
    for subject_ar, clean_epochs in zip(fitted_ARs, cleaned_epochs_ICA):
        clean_epochs_ep = clean_epochs.drop(indices=bad)
        # interpolating bads or removing epochs
        clean_epochs_AR = subject_ar.transform(clean_epochs_ep)
        cleaned_epochs_AR.append(clean_epochs_AR)
    # equalizing epochs length between two subjects
    mne.epochs.equalize_epoch_counts(cleaned_epochs_AR)

    if verbose:
        # Visualisation before / after AR (averages are only needed here)
        evoked_before = [
            clean_epochs.average() for clean_epochs in cleaned_epochs_ICA
        ]
        evoked_after_AR = [clean.average() for clean in cleaned_epochs_AR]

        for i, j in zip(evoked_before, evoked_after_AR):
            fig, axes = plt.subplots(2, 1, figsize=(6, 6))
            for ax in axes:
                # tick_params expects booleans; the string 'off' is truthy
                ax.tick_params(axis='x', which='both', bottom=False,
                               top=False)
                ax.tick_params(axis='y', which='both', left=False,
                               right=False)

            ylim = dict(grad=(-170, 200))
            i.pick_types(eeg=True, exclude=[])
            i.plot(exclude=[], axes=axes[0], ylim=ylim, show=False)
            axes[0].set_title('Before autoreject')
            j.pick_types(eeg=True, exclude=[])
            # Defer showing until both titles are set; showing inside
            # plot() was why the second title never appeared.
            j.plot(exclude=[], axes=axes[1], ylim=ylim, show=False)
            axes[1].set_title('After autoreject')
            plt.tight_layout()
            plt.show()

    return cleaned_epochs_AR
Exemple #30
0
import pickle

# Script: compare "global" and "local" AutoReject on one subject's item
# epochs, save both cleaned versions, then reload them for inspection.

# subject = config.subjects_list[11]
subject = 'sub08-cc_150418'
meg_subject_dir = op.join(config.meg_dir, subject)
epochs = epoching_funcs.load_epochs_items(subject, cleaned=False)

# run autoreject "global" -> just get the thresholds
reject = get_rejection_threshold(epochs, ch_types=['mag', 'grad', 'eeg'])
epochs1 = epochs.copy().drop_bad(reject=reject)
fname = op.join(meg_subject_dir, 'epochs_globalAR-epo.fif')
print("Saving: ", fname)
epochs1.save(fname, overwrite=True)

# run autoreject "local"
ar = AutoReject()
epochs2, reject_log = ar.fit_transform(epochs, return_log=True)
fname = op.join(meg_subject_dir, 'epochs_localAR-epo.fif')
print("Saving: ", fname)
epochs2.save(fname, overwrite=True)
# Save autoreject reject_log; the context manager closes the file even if
# pickling fails.
with open(fname[:-4] + '_reject_log.obj', 'wb') as fid:
    pickle.dump(reject_log, fid)

######################
# Reload the "global" version and plot the average of the violation trials
fname = op.join(meg_subject_dir, 'epochs_globalAR-epo.fif')
epochs1 = mne.read_epochs(fname, preload=True)
epochs1['ViolationOrNot == 1'].copy().average().plot_joint()

# Reload the "local" version
fname = op.join(meg_subject_dir, 'epochs_localAR-epo.fif')
epochs2 = mne.read_epochs(fname, preload=True)
Exemple #31
0
def main():
    """Run the per-subject EEG pipeline and save group-level results.

    For every data file in ``DAT_PATH`` whose name contains ``EXT``:

    1. load the EDF file, fix channel names / types, apply an average
       reference and the standard 10-20 montage;
    2. high-pass filter at 1 Hz, fit (``RUN_ICA``) or load an ICA solution,
       and remove the components that correlate with the ``EOG_CHS``;
    3. build new event codes for correctly answered trials;
    4. fit FOOOF to resting PSDs and extract an individual alpha peak
       (falling back to 10 Hz when none is found);
    5. band-pass filter around the canonical (8-12 Hz) and the
       individualised alpha band and take the Hilbert envelope;
    6. epoch the three versions of the data, fit (``RUN_AUTOREJECT``) or
       load an AutoReject solution and apply it to all of them;
    7. average contralateral alpha per load, and fit FOOOF models to
       trial-averaged spectra of each trial segment.

    Group arrays and combined FOOOF results are written under ``RES_PATH``.
    The function takes no arguments and returns nothing; all configuration
    comes from module-level constants.
    """

    #################################################
    ## SETUP

    ## Get list of subject files
    subj_files = listdir(DAT_PATH)
    subj_files = [file for file in subj_files if EXT.lower() in file.lower()]

    ## Set up FOOOF Objects
    # Initialize FOOOF settings & objects
    fooof_settings = FOOOFSettings(peak_width_limits=PEAK_WIDTH_LIMITS, max_n_peaks=MAX_N_PEAKS,
                                   min_peak_amplitude=MIN_PEAK_AMP, peak_threshold=PEAK_THRESHOLD,
                                   aperiodic_mode=APERIODIC_MODE)
    fm = FOOOF(*fooof_settings, verbose=False)
    fg = FOOOFGroup(*fooof_settings, verbose=False)

    # Save out a settings file
    fg.save('0-FOOOF_Settings', pjoin(RES_PATH, 'FOOOF'), save_settings=True)

    # Set up the dictionary to store all the FOOOF results
    fg_dict = dict()
    for load_label in LOAD_LABELS:
        fg_dict[load_label] = dict()
        for side_label in SIDE_LABELS:
            fg_dict[load_label][side_label] = dict()
            for seg_label in SEG_LABELS:
                fg_dict[load_label][side_label][seg_label] = []

    ## Initialize group level data stores
    #  999 is used as a fill value for "no entry" in the dropped_* arrays
    n_subjs, n_conds, n_times = len(subj_files), 3, N_TIMES
    group_fooofed_alpha_freqs = np.zeros(shape=[n_subjs])
    dropped_components = np.ones(shape=[n_subjs, 50]) * 999
    dropped_trials = np.ones(shape=[n_subjs, 1500]) * 999
    canonical_group_avg_dat = np.zeros(shape=[n_subjs, n_conds, n_times])
    fooofed_group_avg_dat = np.zeros(shape=[n_subjs, n_conds, n_times])

    # Set channel types
    ch_types = {'LHor' : 'eog', 'RHor' : 'eog', 'IVer' : 'eog', 'SVer' : 'eog',
                'LMas' : 'misc', 'RMas' : 'misc', 'Nose' : 'misc', 'EXG8' : 'misc'}

    #################################################
    ## RUN ACROSS ALL SUBJECTS

    # Run analysis across each subject
    for s_ind, subj_file in enumerate(subj_files):

        # Get subject label and print status
        subj_label = subj_file.split('.')[0]
        print('\nCURRENTLY RUNNING SUBJECT: ', subj_label, '\n')

        #################################################
        ## LOAD / ORGANIZE / SET-UP DATA

        # Load subject data, apply fixes for channels, etc
        eeg_dat = mne.io.read_raw_edf(pjoin(DAT_PATH, subj_file),
                                      preload=True, verbose=False)

        # Fix channel name labels: strip the first two characters from every
        # channel name except the last one
        eeg_dat.info['ch_names'] = [chl[2:] for chl in \
            eeg_dat.ch_names[:-1]] + [eeg_dat.ch_names[-1]]
        for ind, chi in enumerate(eeg_dat.info['chs']):
            eeg_dat.info['chs'][ind]['ch_name'] = eeg_dat.info['ch_names'][ind]

        # Update channel types
        eeg_dat.set_channel_types(ch_types)

        # Set reference - average reference
        eeg_dat = eeg_dat.set_eeg_reference(ref_channels='average',
                                            projection=False, verbose=False)

        # Set channel montage
        chs = mne.channels.read_montage('standard_1020', eeg_dat.ch_names)
        eeg_dat.set_montage(chs)

        # Get event information & check all used event codes
        evs = mne.find_events(eeg_dat, shortest_event=1, verbose=False)

        # Pull out sampling rate
        srate = eeg_dat.info['sfreq']

        #################################################
        ## Pre-Processing: ICA

        # High-pass filter data for running ICA
        eeg_dat.filter(l_freq=1., h_freq=None, fir_design='firwin')

        if RUN_ICA:

            print("\nICA: CALCULATING SOLUTION\n")

            # ICA settings
            method = 'fastica'
            n_components = 0.99
            random_state = 47
            reject = {'eeg': 20e-4}

            # Initialize ICA object
            ica = ICA(n_components=n_components, method=method,
                      random_state=random_state)

            # Fit ICA
            ica.fit(eeg_dat, reject=reject)

            # Save out ICA solution
            ica.save(pjoin(RES_PATH, 'ICA', subj_label + '-ica.fif'))

        # Otherwise: load previously saved ICA to apply
        else:
            print("\nICA: USING PRECOMPUTED\n")
            ica = read_ica(pjoin(RES_PATH, 'ICA', subj_label + '-ica.fif'))

        # Find components to drop, based on correlation with EOG channels
        drop_inds = []
        for chi in EOG_CHS:
            inds, _ = ica.find_bads_eog(eeg_dat, ch_name=chi, threshold=2.5,
                                        l_freq=1, h_freq=10, verbose=False)
            drop_inds.extend(inds)
        drop_inds = list(set(drop_inds))

        # Set which components to drop, and collect record of this
        ica.exclude = drop_inds
        dropped_components[s_ind, 0:len(drop_inds)] = drop_inds

        # Apply ICA to data
        eeg_dat = ica.apply(eeg_dat)

        #################################################
        ## SORT OUT EVENT CODES

        # Extract a list of all the event labels
        all_trials = [it for it2 in EV_DICT.values() for it in it2]

        # Create list of new event codes to be used to label correct trials (300s)
        all_trials_new = [it + 100 for it in all_trials]
        # Collapse the doubled event markers from above: every odd-indexed
        # code is lowered by one so that it matches the code before it
        all_trials_new = [it - 1 if ind % 2 else it
                          for ind, it in enumerate(all_trials_new)]
        # Get labelled dictionary of new event names. The codes are
        # de-duplicated in order of first appearance so that each label is
        # paired with the code derived from its own EV_DICT entry; a plain
        # set() gives no such ordering guarantee.
        ev_dict2 = dict(zip(EV_DICT.keys(), dict.fromkeys(all_trials_new)))

        # Initialize variables to store new event definitions
        evs2 = np.empty(shape=[0, 3], dtype='int64')
        lags = np.array([])

        # Loop through, creating new events for all correct trials
        t_min, t_max = -0.4, 3.0
        for ref_id, targ_id, new_id in zip(all_trials, CORR_CODES * 6, all_trials_new):

            t_evs, t_lags = mne.event.define_target_events(evs, ref_id, targ_id, srate,
                                                           t_min, t_max, new_id)

            if len(t_evs) > 0:
                evs2 = np.vstack([evs2, t_evs])
                lags = np.concatenate([lags, t_lags])

        #################################################
        ## FOOOF

        # Set channel of interest
        ch_ind = eeg_dat.ch_names.index(CHL)

        # Calculate PSDs over ~ first 2 minutes of data, for all channels
        fmin, fmax = 1, 50
        tmin, tmax = 5, 125
        psds, freqs = mne.time_frequency.psd_welch(eeg_dat, fmin=fmin, fmax=fmax,
                                                   tmin=tmin, tmax=tmax,
                                                   n_fft=int(2*srate), n_overlap=int(srate),
                                                   n_per_seg=int(2*srate),
                                                   verbose=False)

        # Fit FOOOF across all channels
        fg.fit(freqs, psds, FREQ_RANGE, n_jobs=-1)

        # Save out FOOOF results
        fg.save(subj_label + '_fooof', pjoin(RES_PATH, 'FOOOF'), save_results=True)

        # Extract individualized CF from specified channel, add to group collection
        fm = fg.get_fooof(ch_ind, False)
        fooof_freq, _, _ = get_band_peak(fm.peak_params_, [7, 14])
        group_fooofed_alpha_freqs[s_ind] = fooof_freq

        # If no FOOOF alpha was extracted, reset to 10
        #  (the group array above keeps the NaN on purpose)
        if np.isnan(fooof_freq):
            fooof_freq = 10

        #################################################
        ## ALPHA FILTERING

        # CANONICAL: Filter data to canonical alpha band: 8-12 Hz
        alpha_dat = eeg_dat.copy()
        alpha_dat.filter(8, 12, fir_design='firwin', verbose=False)
        alpha_dat.apply_hilbert(envelope=True, verbose=False)

        # FOOOF: Filter data to FOOOF derived alpha band
        fooof_dat = eeg_dat.copy()
        fooof_dat.filter(fooof_freq-2, fooof_freq+2, fir_design='firwin')
        fooof_dat.apply_hilbert(envelope=True)

        #################################################
        ## EPOCH TRIALS

        # Set epoch timings
        tmin, tmax = -0.85, 1.1

        # Epoch trials - raw data for trial rejection
        epochs = mne.Epochs(eeg_dat, evs2, ev_dict2, tmin=tmin, tmax=tmax,
                            baseline=None, preload=True, verbose=False)

        # Epoch trials - filtered version
        epochs_alpha = mne.Epochs(alpha_dat, evs2, ev_dict2, tmin=tmin, tmax=tmax,
                                  baseline=(-0.5, -0.35), preload=True, verbose=False)
        epochs_fooof = mne.Epochs(fooof_dat, evs2, ev_dict2, tmin=tmin, tmax=tmax,
                                  baseline=(-0.5, -0.35), preload=True, verbose=False)

        #################################################
        ## PRE-PROCESSING: AUTO-REJECT
        if RUN_AUTOREJECT:

            print('\nAUTOREJECT: CALCULATING SOLUTION\n')

            # Initialize and run autoreject across epochs
            ar = AutoReject(n_jobs=4, verbose=False)
            ar.fit(epochs)

            # Save out AR solution
            ar.save(pjoin(RES_PATH, 'AR', subj_label + '-ar.hdf5'), overwrite=True)

        # Otherwise: load & apply previously saved AR solution
        else:
            print('\nAUTOREJECT: USING PRECOMPUTED\n')
            ar = read_auto_reject(pjoin(RES_PATH, 'AR', subj_label + '-ar.hdf5'))
            ar.verbose = 'tqdm'

        # Apply autoreject to the original epochs object it was learnt on
        epochs, rej_log = ar.transform(epochs, return_log=True)

        # Apply autoreject to the copies of the data - apply interpolation, then drop same epochs
        _apply_interp(rej_log, epochs_alpha, ar.threshes_, ar.picks_, ar.verbose)
        epochs_alpha.drop(rej_log.bad_epochs)
        _apply_interp(rej_log, epochs_fooof, ar.threshes_, ar.picks_, ar.verbose)
        epochs_fooof.drop(rej_log.bad_epochs)

        # Collect which epochs were dropped
        dropped_trials[s_ind, 0:sum(rej_log.bad_epochs)] = np.where(rej_log.bad_epochs)[0]

        #################################################
        ## SET UP CHANNEL CLUSTERS

        # Set channel clusters - take channels contralateral to stimulus presentation
        #  Note: channels will be used to extract data contralateral to stimulus presentation
        le_chs = ['P3', 'P5', 'P7', 'P9', 'O1', 'PO3', 'PO7']       # Left Side Channels
        le_inds = [epochs.ch_names.index(chn) for chn in le_chs]
        ri_chs = ['P4', 'P6', 'P8', 'P10', 'O2', 'PO4', 'PO8']      # Right Side Channels
        ri_inds = [epochs.ch_names.index(chn) for chn in ri_chs]

        #################################################
        ## TRIAL-RELATED ANALYSIS: CANONICAL vs. FOOOF

        ## Pull out channels of interest for each load level
        #  Channels extracted are those contralateral to stimulus presentation

        # Canonical Data
        lo1_a = np.concatenate([epochs_alpha['LeLo1']._data[:, ri_inds, :],
                                epochs_alpha['RiLo1']._data[:, le_inds, :]], 0)
        lo2_a = np.concatenate([epochs_alpha['LeLo2']._data[:, ri_inds, :],
                                epochs_alpha['RiLo2']._data[:, le_inds, :]], 0)
        lo3_a = np.concatenate([epochs_alpha['LeLo3']._data[:, ri_inds, :],
                                epochs_alpha['RiLo3']._data[:, le_inds, :]], 0)

        # FOOOFed data
        lo1_f = np.concatenate([epochs_fooof['LeLo1']._data[:, ri_inds, :],
                                epochs_fooof['RiLo1']._data[:, le_inds, :]], 0)
        lo2_f = np.concatenate([epochs_fooof['LeLo2']._data[:, ri_inds, :],
                                epochs_fooof['RiLo2']._data[:, le_inds, :]], 0)
        lo3_f = np.concatenate([epochs_fooof['LeLo3']._data[:, ri_inds, :],
                                epochs_fooof['RiLo3']._data[:, le_inds, :]], 0)

        ## Calculate average across trials and channels - add to group data collection

        # Canonical data
        canonical_group_avg_dat[s_ind, 0, :] = np.mean(lo1_a, 1).mean(0)
        canonical_group_avg_dat[s_ind, 1, :] = np.mean(lo2_a, 1).mean(0)
        canonical_group_avg_dat[s_ind, 2, :] = np.mean(lo3_a, 1).mean(0)

        # FOOOFed data
        fooofed_group_avg_dat[s_ind, 0, :] = np.mean(lo1_f, 1).mean(0)
        fooofed_group_avg_dat[s_ind, 1, :] = np.mean(lo2_f, 1).mean(0)
        fooofed_group_avg_dat[s_ind, 2, :] = np.mean(lo3_f, 1).mean(0)

        #################################################
        ## FOOOFING TRIAL AVERAGED DATA

        # Loop across loads & trial segments
        for seg_label, seg_time in zip(SEG_LABELS, SEG_TIMES):
            tmin, tmax = seg_time[0], seg_time[1]

            # Calculate PSDs across trials, fit FOOOF models to averages
            for le_label, ri_label, load_label in zip(['LeLo1', 'LeLo2', 'LeLo3'],
                                                      ['RiLo1', 'RiLo2', 'RiLo3'],
                                                      LOAD_LABELS):

                ## Calculate trial wise PSDs for left & right side trials
                trial_freqs, le_trial_psds = periodogram(
                    epochs[le_label]._data[:, :, _time_mask(epochs.times, tmin, tmax, srate)],
                    srate, window='hann', nfft=4*srate)
                trial_freqs, ri_trial_psds = periodogram(
                    epochs[ri_label]._data[:, :, _time_mask(epochs.times, tmin, tmax, srate)],
                    srate, window='hann', nfft=4*srate)

                ## FIT ALL CHANNELS VERSION
                if FIT_ALL_CHANNELS:

                    ## Average spectra across trials within a given load & side
                    le_avg_psd_contra = avg_func(le_trial_psds[:, ri_inds, :], 0)
                    le_avg_psd_ipsi = avg_func(le_trial_psds[:, le_inds, :], 0)
                    ri_avg_psd_contra = avg_func(ri_trial_psds[:, le_inds, :], 0)
                    ri_avg_psd_ipsi = avg_func(ri_trial_psds[:, ri_inds, :], 0)

                    ## Combine spectra across left & right trials for given load
                    ch_psd_contra = np.vstack([le_avg_psd_contra, ri_avg_psd_contra])
                    ch_psd_ipsi = np.vstack([le_avg_psd_ipsi, ri_avg_psd_ipsi])

                    ## Fit FOOOFGroup to all channels, average and collect results
                    fg.fit(trial_freqs, ch_psd_contra, FREQ_RANGE)
                    fm = avg_fg(fg)
                    fg_dict[load_label]['Contra'][seg_label].append(fm.copy())
                    fg.fit(trial_freqs, ch_psd_ipsi, FREQ_RANGE)
                    fm = avg_fg(fg)
                    fg_dict[load_label]['Ipsi'][seg_label].append(fm.copy())

                ## COLLAPSE ACROSS CHANNELS VERSION
                else:

                    ## Average spectra across trials and channels within a given load & side
                    le_avg_psd_contra = avg_func(avg_func(le_trial_psds[:, ri_inds, :], 0), 0)
                    le_avg_psd_ipsi = avg_func(avg_func(le_trial_psds[:, le_inds, :], 0), 0)
                    ri_avg_psd_contra = avg_func(avg_func(ri_trial_psds[:, le_inds, :], 0), 0)
                    ri_avg_psd_ipsi = avg_func(avg_func(ri_trial_psds[:, ri_inds, :], 0), 0)

                    ## Collapse spectra across left & right trials for given load
                    avg_psd_contra = avg_func(np.vstack([le_avg_psd_contra, ri_avg_psd_contra]), 0)
                    avg_psd_ipsi = avg_func(np.vstack([le_avg_psd_ipsi, ri_avg_psd_ipsi]), 0)

                    ## Fit FOOOF, and collect results
                    fm.fit(trial_freqs, avg_psd_contra, FREQ_RANGE)
                    fg_dict[load_label]['Contra'][seg_label].append(fm.copy())
                    fm.fit(trial_freqs, avg_psd_ipsi, FREQ_RANGE)
                    fg_dict[load_label]['Ipsi'][seg_label].append(fm.copy())

    #################################################
    ## SAVE OUT RESULTS

    # Save out group data
    np.save(pjoin(RES_PATH, 'Group', 'alpha_freqs_group'), group_fooofed_alpha_freqs)
    np.save(pjoin(RES_PATH, 'Group', 'canonical_group'), canonical_group_avg_dat)
    np.save(pjoin(RES_PATH, 'Group', 'fooofed_group'), fooofed_group_avg_dat)
    np.save(pjoin(RES_PATH, 'Group', 'dropped_trials'), dropped_trials)
    np.save(pjoin(RES_PATH, 'Group', 'dropped_components'), dropped_components)

    # Save out second round of FOOOFing
    for load_label in LOAD_LABELS:
        for side_label in SIDE_LABELS:
            for seg_label in SEG_LABELS:
                fg = combine_fooofs(fg_dict[load_label][side_label][seg_label])
                fg.save('Group_' + load_label + '_' + side_label + '_' + seg_label,
                        pjoin(RES_PATH, 'FOOOF'), save_results=True)
                       eog=False, exclude=exclude)

###############################################################################
# Note that :class:`autoreject.AutoReject` by design supports multiple
# channels. If no picks are passed, separate solutions will be computed for
# each channel type and combined internally. This then readily supports
# cleaning unseen epochs from the different channel types used during fit.
# Here we only use a subset of channels to save time.

###############################################################################
# Also note that once the parameters are learned, any data can be repaired
# that contains channels that were used during fit. This also means that time
# may be saved by fitting :class:`autoreject.AutoReject` on a
# representative subsample of the data.

ar = AutoReject(picks=picks, random_state=42, n_jobs=1, verbose='tqdm')

epochs_ar, reject_log = ar.fit_transform(this_epoch, return_log=True)

###############################################################################
# We can visualize the cross validation curve over two variables

import numpy as np  # noqa
import matplotlib.pyplot as plt  # noqa
import matplotlib.patches as patches  # noqa
from autoreject import set_matplotlib_defaults  # noqa

set_matplotlib_defaults(plt, style='seaborn-white')
loss = ar.loss_['eeg'].mean(axis=-1)  # losses are stored by channel type.

plt.matshow(loss.T * 1e6, cmap=plt.get_cmap('viridis'))