Ejemplo n.º 1
0
def test_io():
    """Test that AutoReject objects round-trip through save / read.

    Checks that:
    - an unfitted object saved to disk and read back fits to exactly the
      same per-channel thresholds as the original;
    - saving over an existing file without ``overwrite=True`` raises
      ``ValueError``;
    - a fitted object read back from disk transforms epochs identically
      (same data and same reject-log labels).
    """
    event_id = None
    tmin, tmax = -0.2, 0.5
    events = mne.find_events(raw)
    savedir = _TempDir()
    fname = op.join(savedir, 'autoreject.hdf5')

    include = [u'EEG %03d' % i for i in range(1, 45, 3)]
    picks = mne.pick_types(raw.info, meg=False, eeg=False, stim=False,
                           eog=True, include=include, exclude=[])

    # Use preloaded epochs so that the AutoReject objects can be fitted.
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                        picks=picks, baseline=(None, 0), decim=4,
                        reject=None, preload=True)[:10]
    ar = AutoReject(cv=2, random_state=42, n_interpolate=[1],
                    consensus=[0.5], verbose=False)
    ar.save(fname)  # save without fitting

    # check that fit after saving is the same as fit without saving
    ar2 = read_auto_reject(fname)
    ar.fit(epochs)
    ar2.fit(epochs)
    # Compare channel by channel: summing signed differences could let
    # non-zero differences cancel out, and would not catch missing keys.
    assert set(ar.threshes_) == set(ar2.threshes_)
    for ch_name in ar.threshes_:
        assert_array_equal(ar.threshes_[ch_name], ar2.threshes_[ch_name])

    # saving over an existing file requires overwrite=True
    pytest.raises(ValueError, ar.save, fname)
    ar.save(fname, overwrite=True)
    ar3 = read_auto_reject(fname)
    epochs_clean1, reject_log1 = ar.transform(epochs, return_log=True)
    epochs_clean2, reject_log2 = ar3.transform(epochs, return_log=True)
    assert_array_equal(epochs_clean1.get_data(), epochs_clean2.get_data())
    assert_array_equal(reject_log1.labels, reject_log2.labels)
Ejemplo n.º 2
0
def autore(epo_eeg_cust, n_jobs=4):
    """Correct or reject artifact-contaminated epochs with AutoReject.

    Fits an ``AutoReject`` model on the given epochs and applies it to the
    same epochs.

    Parameters
    ----------
    epo_eeg_cust : mne.Epochs
        Epochs data to clean.
    n_jobs : int
        Number of parallel jobs passed to ``AutoReject`` (default 4).

    Returns
    -------
    clean : mne.Epochs
        Epochs after AutoReject interpolation / rejection.
    """
    ar = AutoReject(n_jobs=n_jobs)
    ar.fit(epo_eeg_cust)
    # ``transform`` already returns a new Epochs object, so no extra copy is
    # needed. For diagnostics, call ``ar.transform(..., return_log=True)`` and
    # use e.g. ``reject_log.plot_epochs(epo_eeg_cust, scalings=dict(eeg=50))``.
    clean = ar.transform(epo_eeg_cust)
    return clean
Ejemplo n.º 3
0
def main():
    """Run the per-subject EEG pipeline and save group-level results.

    For every file in ``DAT_PATH`` whose name contains ``EXT``:

    1. Load the EDF recording, strip the 2-character prefix from channel
       names (all but the last channel), set channel types, apply an average
       reference and the ``standard_1020`` montage.
    2. High-pass filter at 1 Hz, then compute (``RUN_ICA``) or load an ICA
       solution and exclude the components flagged by ``find_bads_eog`` for
       each channel in ``EOG_CHS``.
    3. Define new events for correct trials (a reference event followed by
       a ``CORR_CODES`` event within -0.4 to 3.0 s).
    4. Fit FOOOF to Welch PSDs of seconds 5-125 of the recording and take the
       7-14 Hz peak at channel ``CHL`` (falling back to 10 Hz for filtering
       when no peak is found).
    5. Compute Hilbert amplitude envelopes in the canonical 8-12 Hz band and
       in a +/-2 Hz band around the FOOOF alpha peak.
    6. Epoch all three versions of the data, compute (``RUN_AUTOREJECT``) or
       load an AutoReject solution on the unfiltered epochs, and apply the
       same interpolation / epoch rejection to the envelope epochs.
    7. Average contralateral envelopes per load and fit FOOOF to
       trial-averaged spectra per segment, load and side.

    Group arrays are saved under ``RES_PATH/Group`` and FOOOF results under
    ``RES_PATH/FOOOF``. Upper-case names are module-level settings that are
    not shown here.
    """

    #################################################
    ## SETUP

    ## Get list of subject files
    subj_files = [fname for fname in listdir(DAT_PATH)
                  if EXT.lower() in fname.lower()]

    ## Set up FOOOF Objects
    # Initialize FOOOF settings & objects
    fooof_settings = FOOOFSettings(peak_width_limits=PEAK_WIDTH_LIMITS, max_n_peaks=MAX_N_PEAKS,
                                   min_peak_amplitude=MIN_PEAK_AMP, peak_threshold=PEAK_THRESHOLD,
                                   aperiodic_mode=APERIODIC_MODE)
    fm = FOOOF(*fooof_settings, verbose=False)
    fg = FOOOFGroup(*fooof_settings, verbose=False)

    # Save out a settings file
    fg.save('0-FOOOF_Settings', pjoin(RES_PATH, 'FOOOF'), save_settings=True)

    # Set up the dictionary to store all the FOOOF results:
    #  fg_dict[load_label][side_label][seg_label] -> list of FOOOF objects
    fg_dict = dict()
    for load_label in LOAD_LABELS:
        fg_dict[load_label] = dict()
        for side_label in SIDE_LABELS:
            fg_dict[load_label][side_label] = dict()
            for seg_label in SEG_LABELS:
                fg_dict[load_label][side_label][seg_label] = []

    ## Initialize group level data stores
    #  999 marks unused slots in the dropped components / trials records
    n_subjs, n_conds, n_times = len(subj_files), 3, N_TIMES
    group_fooofed_alpha_freqs = np.zeros(shape=[n_subjs])
    dropped_components = np.ones(shape=[n_subjs, 50]) * 999
    dropped_trials = np.ones(shape=[n_subjs, 1500]) * 999
    canonical_group_avg_dat = np.zeros(shape=[n_subjs, n_conds, n_times])
    fooofed_group_avg_dat = np.zeros(shape=[n_subjs, n_conds, n_times])

    # Set channel types
    ch_types = {'LHor' : 'eog', 'RHor' : 'eog', 'IVer' : 'eog', 'SVer' : 'eog',
                'LMas' : 'misc', 'RMas' : 'misc', 'Nose' : 'misc', 'EXG8' : 'misc'}

    ## Set up event codes - these depend only on module settings, so they
    #  are computed once rather than once per subject

    # Extract a list of all the event labels
    all_trials = [code for codes in EV_DICT.values() for code in codes]

    # Create list of new event codes to be used to label correct trials (300s)
    all_trials_new = [code + 100 for code in all_trials]
    # Collapse across the doubled event markers: every odd-position code is
    #  shifted down by one (assumes paired markers are consecutive codes)
    all_trials_new = [code - 1 if ind % 2 else code
                      for ind, code in enumerate(all_trials_new)]
    # Get labelled dictionary of new event names. De-duplication must keep
    #  order so the i-th unique code pairs with the i-th key of EV_DICT;
    #  iterating a ``set`` gives no such ordering guarantee.
    ev_dict2 = dict(zip(EV_DICT.keys(), dict.fromkeys(all_trials_new)))

    #################################################
    ## RUN ACROSS ALL SUBJECTS

    # Run analysis across each subject
    for s_ind, subj_file in enumerate(subj_files):

        # Get subject label and print status
        subj_label = subj_file.split('.')[0]
        print('\nCURRENTLY RUNNING SUBJECT: ', subj_label, '\n')

        #################################################
        ## LOAD / ORGANIZE / SET-UP DATA

        # Load subject of data, apply fixes for channels, etc
        eeg_dat = mne.io.read_raw_edf(pjoin(DAT_PATH, subj_file),
                                      preload=True, verbose=False)

        # Fix channel name labels: drop the first two characters of every
        #  channel name except the last one
        eeg_dat.info['ch_names'] = [chl[2:] for chl in eeg_dat.ch_names[:-1]] \
            + [eeg_dat.ch_names[-1]]
        for ch_info, ch_name in zip(eeg_dat.info['chs'], eeg_dat.info['ch_names']):
            ch_info['ch_name'] = ch_name

        # Update channel types
        eeg_dat.set_channel_types(ch_types)

        # Set reference - average reference
        eeg_dat = eeg_dat.set_eeg_reference(ref_channels='average',
                                            projection=False, verbose=False)

        # Set channel montage
        chs = mne.channels.read_montage('standard_1020', eeg_dat.ch_names)
        eeg_dat.set_montage(chs)

        # Get event information & check all used event codes
        evs = mne.find_events(eeg_dat, shortest_event=1, verbose=False)

        # Pull out sampling rate
        srate = eeg_dat.info['sfreq']

        #################################################
        ## Pre-Processing: ICA

        # High-pass filter data for running ICA
        #  (the filtered data is also used for everything downstream)
        eeg_dat.filter(l_freq=1., h_freq=None, fir_design='firwin')

        if RUN_ICA:

            print("\nICA: CALCULATING SOLUTION\n")

            # ICA settings
            method = 'fastica'
            n_components = 0.99
            random_state = 47
            reject = {'eeg': 20e-4}

            # Initialize ICA object
            ica = ICA(n_components=n_components, method=method,
                      random_state=random_state)

            # Fit ICA
            ica.fit(eeg_dat, reject=reject)

            # Save out ICA solution
            ica.save(pjoin(RES_PATH, 'ICA', subj_label + '-ica.fif'))

        # Otherwise: load previously saved ICA to apply
        else:
            print("\nICA: USING PRECOMPUTED\n")
            ica = read_ica(pjoin(RES_PATH, 'ICA', subj_label + '-ica.fif'))

        # Find components to drop, based on correlation with EOG channels
        drop_inds = []
        for chi in EOG_CHS:
            inds, _ = ica.find_bads_eog(eeg_dat, ch_name=chi, threshold=2.5,
                                        l_freq=1, h_freq=10, verbose=False)
            drop_inds.extend(inds)
        drop_inds = list(set(drop_inds))

        # Set which components to drop, and collect record of this
        ica.exclude = drop_inds
        dropped_components[s_ind, 0:len(drop_inds)] = drop_inds

        # Apply ICA to data
        eeg_dat = ica.apply(eeg_dat)

        #################################################
        ## DEFINE EVENTS FOR CORRECT TRIALS

        # Initialize variables to store new event definitions
        evs2 = np.empty(shape=[0, 3], dtype='int64')
        lags = np.array([])

        # Loop through, creating new events for all correct trials
        #  NOTE(review): ``* 6`` presumably repeats CORR_CODES once per
        #  condition so it lines up with all_trials - confirm
        t_min, t_max = -0.4, 3.0
        for ref_id, targ_id, new_id in zip(all_trials, CORR_CODES * 6, all_trials_new):

            t_evs, t_lags = mne.event.define_target_events(evs, ref_id, targ_id, srate,
                                                           t_min, t_max, new_id)

            if len(t_evs) > 0:
                evs2 = np.vstack([evs2, t_evs])
                lags = np.concatenate([lags, t_lags])

        #################################################
        ## FOOOF

        # Set channel of interest
        ch_ind = eeg_dat.ch_names.index(CHL)

        # Calculate PSDs over ~ first 2 minutes of data, for all channels
        fmin, fmax = 1, 50
        tmin, tmax = 5, 125
        psds, freqs = mne.time_frequency.psd_welch(eeg_dat, fmin=fmin, fmax=fmax,
                                                   tmin=tmin, tmax=tmax,
                                                   n_fft=int(2*srate), n_overlap=int(srate),
                                                   n_per_seg=int(2*srate),
                                                   verbose=False)

        # Fit FOOOF across all channels
        fg.fit(freqs, psds, FREQ_RANGE, n_jobs=-1)

        # Save out FOOOF results
        fg.save(subj_label + '_fooof', pjoin(RES_PATH, 'FOOOF'), save_results=True)

        # Extract individualized CF from specified channel, add to group collection
        fm = fg.get_fooof(ch_ind, False)
        fooof_freq, _, _ = get_band_peak(fm.peak_params_, [7, 14])
        group_fooofed_alpha_freqs[s_ind] = fooof_freq

        # If no FOOOF alpha extracted, reset to 10 (the group record keeps NaN)
        if np.isnan(fooof_freq):
            fooof_freq = 10

        #################################################
        ## ALPHA FILTERING

        # CANONICAL: Filter data to canonical alpha band: 8-12 Hz
        alpha_dat = eeg_dat.copy()
        alpha_dat.filter(8, 12, fir_design='firwin', verbose=False)
        alpha_dat.apply_hilbert(envelope=True, verbose=False)

        # FOOOF: Filter data to FOOOF derived alpha band
        fooof_dat = eeg_dat.copy()
        fooof_dat.filter(fooof_freq-2, fooof_freq+2, fir_design='firwin')
        fooof_dat.apply_hilbert(envelope=True)

        #################################################
        ## EPOCH TRIALS

        # Set epoch timings
        tmin, tmax = -0.85, 1.1

        # Epoch trials - raw data for trial rejection
        epochs = mne.Epochs(eeg_dat, evs2, ev_dict2, tmin=tmin, tmax=tmax,
                            baseline=None, preload=True, verbose=False)

        # Epoch trials - filtered version
        epochs_alpha = mne.Epochs(alpha_dat, evs2, ev_dict2, tmin=tmin, tmax=tmax,
                                  baseline=(-0.5, -0.35), preload=True, verbose=False)
        epochs_fooof = mne.Epochs(fooof_dat, evs2, ev_dict2, tmin=tmin, tmax=tmax,
                                  baseline=(-0.5, -0.35), preload=True, verbose=False)

        #################################################
        ## PRE-PROCESSING: AUTO-REJECT
        if RUN_AUTOREJECT:

            print('\nAUTOREJECT: CALCULATING SOLUTION\n')

            # Initialize and run autoreject across epochs
            ar = AutoReject(n_jobs=4, verbose=False)
            ar.fit(epochs)

            # Save out AR solution
            ar.save(pjoin(RES_PATH, 'AR', subj_label + '-ar.hdf5'), overwrite=True)

        # Otherwise: load & apply previously saved AR solution
        else:
            print('\nAUTOREJECT: USING PRECOMPUTED\n')
            ar = read_auto_reject(pjoin(RES_PATH, 'AR', subj_label + '-ar.hdf5'))
            ar.verbose = 'tqdm'

        # Apply autoreject to the original epochs object it was learnt on
        epochs, rej_log = ar.transform(epochs, return_log=True)

        # Apply autoreject to the copies of the data - apply interpolation, then drop same epochs
        _apply_interp(rej_log, epochs_alpha, ar.threshes_, ar.picks_, ar.verbose)
        epochs_alpha.drop(rej_log.bad_epochs)
        _apply_interp(rej_log, epochs_fooof, ar.threshes_, ar.picks_, ar.verbose)
        epochs_fooof.drop(rej_log.bad_epochs)

        # Collect which epochs were dropped
        dropped_trials[s_ind, 0:sum(rej_log.bad_epochs)] = np.where(rej_log.bad_epochs)[0]

        #################################################
        ## SET UP CHANNEL CLUSTERS

        # Set channel clusters - take channels contralateral to stimulus presentation
        #  Note: channels will be used to extract data contralateral to stimulus presentation
        le_chs = ['P3', 'P5', 'P7', 'P9', 'O1', 'PO3', 'PO7']       # Left Side Channels
        le_inds = [epochs.ch_names.index(chn) for chn in le_chs]
        ri_chs = ['P4', 'P6', 'P8', 'P10', 'O2', 'PO4', 'PO8']      # Right Side Channels
        ri_inds = [epochs.ch_names.index(chn) for chn in ri_chs]

        #################################################
        ## TRIAL-RELATED ANALYSIS: CANONICAL vs. FOOOF

        ## Pull out channels of interest for each load level
        #  Channels extracted are those contralateral to stimulus presentation

        # Canonical Data
        lo1_a = np.concatenate([epochs_alpha['LeLo1']._data[:, ri_inds, :],
                                epochs_alpha['RiLo1']._data[:, le_inds, :]], 0)
        lo2_a = np.concatenate([epochs_alpha['LeLo2']._data[:, ri_inds, :],
                                epochs_alpha['RiLo2']._data[:, le_inds, :]], 0)
        lo3_a = np.concatenate([epochs_alpha['LeLo3']._data[:, ri_inds, :],
                                epochs_alpha['RiLo3']._data[:, le_inds, :]], 0)

        # FOOOFed data
        lo1_f = np.concatenate([epochs_fooof['LeLo1']._data[:, ri_inds, :],
                                epochs_fooof['RiLo1']._data[:, le_inds, :]], 0)
        lo2_f = np.concatenate([epochs_fooof['LeLo2']._data[:, ri_inds, :],
                                epochs_fooof['RiLo2']._data[:, le_inds, :]], 0)
        lo3_f = np.concatenate([epochs_fooof['LeLo3']._data[:, ri_inds, :],
                                epochs_fooof['RiLo3']._data[:, le_inds, :]], 0)

        ## Calculate average across trials and channels - add to group data collection

        # Canonical data
        canonical_group_avg_dat[s_ind, 0, :] = np.mean(lo1_a, 1).mean(0)
        canonical_group_avg_dat[s_ind, 1, :] = np.mean(lo2_a, 1).mean(0)
        canonical_group_avg_dat[s_ind, 2, :] = np.mean(lo3_a, 1).mean(0)

        # FOOOFed data
        fooofed_group_avg_dat[s_ind, 0, :] = np.mean(lo1_f, 1).mean(0)
        fooofed_group_avg_dat[s_ind, 1, :] = np.mean(lo2_f, 1).mean(0)
        fooofed_group_avg_dat[s_ind, 2, :] = np.mean(lo3_f, 1).mean(0)

        #################################################
        ## FOOOFING TRIAL AVERAGED DATA

        # scipy's periodogram documents ``nfft`` as an int; srate is a float
        n_fft_trial = int(4 * srate)

        # Loop across trial segments & loads
        for seg_label, seg_time in zip(SEG_LABELS, SEG_TIMES):
            tmin, tmax = seg_time[0], seg_time[1]
            # Time mask for this segment is the same for every load
            seg_mask = _time_mask(epochs.times, tmin, tmax, srate)

            # Calculate PSDs across trials, fit FOOOF models to averages
            for le_label, ri_label, load_label in zip(['LeLo1', 'LeLo2', 'LeLo3'],
                                                      ['RiLo1', 'RiLo2', 'RiLo3'],
                                                      LOAD_LABELS):

                ## Calculate trial wise PSDs for left & right side trials
                trial_freqs, le_trial_psds = periodogram(
                    epochs[le_label]._data[:, :, seg_mask],
                    srate, window='hann', nfft=n_fft_trial)
                trial_freqs, ri_trial_psds = periodogram(
                    epochs[ri_label]._data[:, :, seg_mask],
                    srate, window='hann', nfft=n_fft_trial)

                ## FIT ALL CHANNELS VERSION
                if FIT_ALL_CHANNELS:

                    ## Average spectra across trials within a given load & side
                    le_avg_psd_contra = avg_func(le_trial_psds[:, ri_inds, :], 0)
                    le_avg_psd_ipsi = avg_func(le_trial_psds[:, le_inds, :], 0)
                    ri_avg_psd_contra = avg_func(ri_trial_psds[:, le_inds, :], 0)
                    ri_avg_psd_ipsi = avg_func(ri_trial_psds[:, ri_inds, :], 0)

                    ## Combine spectra across left & right trials for given load
                    ch_psd_contra = np.vstack([le_avg_psd_contra, ri_avg_psd_contra])
                    ch_psd_ipsi = np.vstack([le_avg_psd_ipsi, ri_avg_psd_ipsi])

                    ## Fit FOOOFGroup to all channels, average & collect results
                    fg.fit(trial_freqs, ch_psd_contra, FREQ_RANGE)
                    fm = avg_fg(fg)
                    fg_dict[load_label]['Contra'][seg_label].append(fm.copy())
                    fg.fit(trial_freqs, ch_psd_ipsi, FREQ_RANGE)
                    fm = avg_fg(fg)
                    fg_dict[load_label]['Ipsi'][seg_label].append(fm.copy())

                ## COLLAPSE ACROSS CHANNELS VERSION
                else:

                    ## Average spectra across trials and channels within a given load & side
                    le_avg_psd_contra = avg_func(avg_func(le_trial_psds[:, ri_inds, :], 0), 0)
                    le_avg_psd_ipsi = avg_func(avg_func(le_trial_psds[:, le_inds, :], 0), 0)
                    ri_avg_psd_contra = avg_func(avg_func(ri_trial_psds[:, le_inds, :], 0), 0)
                    ri_avg_psd_ipsi = avg_func(avg_func(ri_trial_psds[:, ri_inds, :], 0), 0)

                    ## Collapse spectra across left & right trials for given load
                    avg_psd_contra = avg_func(np.vstack([le_avg_psd_contra, ri_avg_psd_contra]), 0)
                    avg_psd_ipsi = avg_func(np.vstack([le_avg_psd_ipsi, ri_avg_psd_ipsi]), 0)

                    ## Fit FOOOF, and collect results
                    fm.fit(trial_freqs, avg_psd_contra, FREQ_RANGE)
                    fg_dict[load_label]['Contra'][seg_label].append(fm.copy())
                    fm.fit(trial_freqs, avg_psd_ipsi, FREQ_RANGE)
                    fg_dict[load_label]['Ipsi'][seg_label].append(fm.copy())

    #################################################
    ## SAVE OUT RESULTS

    # Save out group data
    np.save(pjoin(RES_PATH, 'Group', 'alpha_freqs_group'), group_fooofed_alpha_freqs)
    np.save(pjoin(RES_PATH, 'Group', 'canonical_group'), canonical_group_avg_dat)
    np.save(pjoin(RES_PATH, 'Group', 'fooofed_group'), fooofed_group_avg_dat)
    np.save(pjoin(RES_PATH, 'Group', 'dropped_trials'), dropped_trials)
    np.save(pjoin(RES_PATH, 'Group', 'dropped_components'), dropped_components)

    # Save out second round of FOOOFing
    for load_label in LOAD_LABELS:
        for side_label in SIDE_LABELS:
            for seg_label in SEG_LABELS:
                fg = combine_fooofs(fg_dict[load_label][side_label][seg_label])
                fg.save('Group_' + load_label + '_' + side_label + '_' + seg_label,
                        pjoin(RES_PATH, 'FOOOF'), save_results=True)
Ejemplo n.º 4
0
def run_autoreject(subject):
    """Interpolate bad epochs/sensors using Autoreject.

    Reads ``root + '/4_ICA/' + subject + '-epo.fif'``, fits AutoReject on
    it, and writes into the ``root + '/5_autoreject/'`` directory:

    - ``<subject>-autoreject.png``: average response before / after cleaning,
    - ``<subject>-heatmap.png``: the AutoReject reject-log plot,
    - ``<subject>-epo.fif``: the cleaned epochs.

    Parameters
    ----------
    subject : str
        The participant reference.

    References
    ----------
    [1] Mainak Jas, Denis Engemann, Federico Raimondo, Yousra Bekhti, and
    Alexandre Gramfort, “Automated rejection and repair of bad trials in
    MEG/EEG.” In 6th International Workshop on Pattern Recognition in
    Neuroimaging (PRNI), 2016.

    [2] Mainak Jas, Denis Engemann, Yousra Bekhti, Federico Raimondo, and
    Alexandre Gramfort. 2017. “Autoreject: Automated artifact rejection for
    MEG and EEG data”. NeuroImage, 159, 417-429.

    """
    out_dir = root + '/5_autoreject/'

    # Import data
    input_path = root + '/4_ICA/' + subject + '-epo.fif'
    epochs = mne.read_epochs(input_path)

    # Autoreject: fit once, then repair / reject in a single transform.
    # ``return_log=True`` gives the reject log of this same transform, which
    # is plotted below.
    ar = AutoReject(random_state=42,
                    n_jobs=4)
    ar.fit(epochs)
    epochs_clean, reject_log = ar.transform(epochs, return_log=True)

    # Plot difference
    evoked = epochs.average()
    evoked_clean = epochs_clean.average()

    fig, axes = plt.subplots(2, 1, figsize=(6, 6))

    for ax in axes:
        # tick_params expects booleans; the string 'off' would be truthy
        ax.tick_params(axis='x', which='both', bottom=False, top=False)
        ax.tick_params(axis='y', which='both', left=False, right=False)

    # show=False on both plots: the figure is only saved. With an interactive
    # backend, showing it first could close it before it is written to disk.
    # NOTE(review): MNE documents ``ylim`` as a dict keyed by channel type
    # (e.g. dict(eeg=[-30, 30])); confirm the list form has the intended effect.
    evoked.plot(exclude=[], axes=axes[0], ylim=[-30, 30], show=False)
    axes[0].set_title('Before autoreject')
    evoked_clean.plot(exclude=[], axes=axes[1], ylim=[-30, 30], show=False)
    axes[1].set_title('After autoreject')
    fig.tight_layout()
    fig.savefig(out_dir + subject + '-autoreject.png')
    plt.close(fig)

    # Plot heatmap
    reject_log.plot()
    plt.savefig(out_dir + subject + '-heatmap.png')
    plt.close()

    # Save epoch data
    out_epoch = out_dir + subject + '-epo.fif'
    epochs_clean.save(out_epoch)
Ejemplo n.º 5
0
def AR_local(cleaned_epochs_ICA, verbose=False, strategy='intersection'):
    """
    Applies local Autoreject to correct or reject bad epochs.

    One AutoReject model is fitted per participant (EEG channels only).
    Epochs marked bad are combined across the two participants according to
    ``strategy`` and dropped from both. Each participant's remaining epochs
    are then repaired with that participant's own model. Finally, epoch
    counts are equalized between participants.

    Arguments:
        cleaned_epochs_ICA: list of Epochs after global Autoreject and ICA,
          one per participant of the dyad (the bad-epoch combination uses
          the first two). The input Epochs are not modified.
        verbose: to plot data before and after AR, boolean set to False
          by default.
        strategy: how the bad epochs of the two participants are combined:
          'intersection' (default; drop epochs that are bad for both
          participants) or 'union' (drop epochs bad for at least one).

    Returns:
        cleaned_epochs_AR: list of Epochs after local Autoreject.

    Raises:
        ValueError: if ``strategy`` is not 'intersection' or 'union'.
    """
    if strategy not in ('intersection', 'union'):
        raise ValueError("strategy must be 'intersection' or 'union', "
                         "got %r" % (strategy,))

    # defaults values for n_interpolates and consensus_percs
    n_interpolates = np.array([1, 4, 32])
    consensus_percs = np.linspace(0, 1.0, 11)
    ar_verbose = 'progressbar' if verbose else False

    # Fit one model per participant and keep it, so that each participant's
    # epochs are later repaired with their own thresholds.
    fitted_ars = []
    bad_epochs_AR = []
    for clean_epochs in cleaned_epochs_ICA:  # per subj

        picks = mne.pick_types(clean_epochs.info,
                               meg=False,
                               eeg=True,
                               stim=False,
                               eog=False,
                               exclude=[])

        ar = AutoReject(n_interpolates,
                        consensus_percs,
                        picks=picks,
                        thresh_method='random_search',
                        random_state=42,
                        verbose=ar_verbose)

        # fitting AR to get bad epochs
        ar.fit(clean_epochs)
        fitted_ars.append(ar)
        bad_epochs_AR.append(ar.get_reject_log(clean_epochs, picks=picks))

    # combining the bad epochs of the two participants (dyad)
    log1 = bad_epochs_AR[0]
    log2 = bad_epochs_AR[1]

    bad1 = set(np.flatnonzero(log1.bad_epochs).tolist())
    bad2 = set(np.flatnonzero(log2.bad_epochs).tolist())

    if strategy == 'union':
        bad = sorted(bad1 | bad2)
    else:
        bad = sorted(bad1 & bad2)

    if verbose:
        print('%s percent of bad epochs' %
              int(len(bad) / len(list(log1.bad_epochs)) * 100))

    # picking good epochs for the two subj
    cleaned_epochs_AR = []
    for clean_epochs, ar in zip(cleaned_epochs_ICA, fitted_ars):  # per subj
        # Epochs.drop works in place, so drop from a copy to leave the
        # caller's epochs (and the "before" plots below) untouched
        clean_epochs_ep = clean_epochs.copy().drop(indices=bad)
        # interpolating bads or removing epochs with this subject's model
        clean_epochs_AR = ar.transform(clean_epochs_ep)
        cleaned_epochs_AR.append(clean_epochs_AR)
    # equalizing epochs length between two subjects
    mne.epochs.equalize_epoch_counts(cleaned_epochs_AR)

    # Visualization before / after AR
    if verbose:
        evoked_before = [epo.average() for epo in cleaned_epochs_ICA]
        evoked_after_AR = [epo.average() for epo in cleaned_epochs_AR]

        for before, after in zip(evoked_before, evoked_after_AR):
            fig, axes = plt.subplots(2, 1, figsize=(6, 6))
            for ax in axes:
                # tick_params expects booleans; the string 'off' is truthy
                ax.tick_params(axis='x', which='both', bottom=False, top=False)
                ax.tick_params(axis='y', which='both', left=False, right=False)

            # NOTE(review): these limits are keyed on 'grad' while only EEG
            # channels are plotted - confirm the intended limits
            ylim = dict(grad=(-170, 200))
            before.pick_types(eeg=True, exclude=[])
            before.plot(exclude=[], axes=axes[0], ylim=ylim, show=False)
            axes[0].set_title('Before autoreject')
            after.pick_types(eeg=True, exclude=[])
            # show=False here too: showing the figure before the second
            # title is set is why that title did not appear
            after.plot(exclude=[], axes=axes[1], ylim=ylim, show=False)
            axes[1].set_title('After autoreject')
            plt.tight_layout()
            plt.show()

    return cleaned_epochs_AR
Ejemplo n.º 6
0
                    detrend=0,
                    preload=True)

ar = AutoReject(n_interpolates,
                consensus_percs,
                picks=picks,
                thresh_method='random_search',
                random_state=42)

# Note that fitting and transforming can be done on different compatible
# portions of data if needed.
ar.fit(epochs)

# Repair / reject epochs and get the reject log from the same call, instead
# of a separate ``get_reject_log`` pass over the same data.
epochs_clean, reject_log = ar.transform(epochs, return_log=True)
evoked_clean = epochs_clean.average()
evoked = epochs.average()

# visualize rejected epochs
scalings = dict(eeg=40e-6)
reject_log.plot_epochs(epochs, scalings=scalings)

# epochs after cleaning
epochs_clean.plot(scalings=scalings)

# notify when done (relies on the macOS ``say`` command)
os.system('say "... I am ready for you Neil."')

# Visualize repaired data
    # Drop behaviourally bad epochs and save the behaviour-matched epochs
    epochs.load_data()
    epochs = epochs.drop(epochs_2_drop, reason="bad behaviour")
    epochs.save(op.join(sub_path, "clean-" + epo.split(sep)[-1]),
                overwrite=True)
    print("AMOUNT OF EPOCHS AFTER MATCHING WITH BEH:", len(epochs))
    print("DOES IT MATCH?", len(beh_ixs) == len(epochs))
    print("\n")

    # Only run AutoReject when epochs and behavioural records line up
    if len(beh_ixs) == len(epochs):
        ar = AutoReject(consensus=np.linspace(0, 1.0, 27),
                        n_interpolate=np.array([1, 4, 32]),
                        thresh_method="bayesian_optimization",
                        cv=10,
                        n_jobs=-1,
                        random_state=42,
                        verbose="progressbar")
        ar.fit(epochs)

        epo_basename = epo.split(sep)[-1]
        epo_type = epo_basename.split("-")[3]
        name = "{}-{}-{}".format(subject_id, numero, epo_type)
        ar_fname = op.join(qc_folder, "{}-autoreject.h5".format(name))
        ar.save(ar_fname, overwrite=True)
        epochs_ar, rej_log = ar.transform(epochs, return_log=True)
        rej_log.plot(show=False)
        plt.savefig(op.join(qc_folder, "{}-autoreject-log.png".format(name)))
        plt.close("all")
        # Save the AutoReject-cleaned epochs (not the uncleaned ``epochs``)
        cleaned = op.join(sub_path, "autoreject-" + epo_basename)
        epochs_ar.save(cleaned, overwrite=True)
        print("CLEANED EPOCHS SAVED:", cleaned)
Ejemplo n.º 8
0
def test_autoreject():
    """Test basic _AutoReject functionality."""
    event_id = None
    tmin, tmax = -0.2, 0.5
    events = mne.find_events(raw)

    ##########################################################################
    # picking epochs
    include = [u'EEG %03d' % i for i in range(1, 45, 3)]
    picks = mne.pick_types(raw.info, meg=False, eeg=False, stim=False,
                           eog=True, include=include, exclude=[])
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                        picks=picks, baseline=(None, 0), decim=10,
                        reject=None, preload=False)[:10]

    ar = _AutoReject()
    assert_raises(ValueError, ar.fit, epochs)
    epochs.load_data()

    ar.fit(epochs)
    assert_true(len(ar.picks_) == len(picks) - 1)

    # epochs with no picks.
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                        baseline=(None, 0), decim=10,
                        reject=None, preload=True)[:20]
    # let's drop some channels to speed up
    pre_picks = mne.pick_types(epochs.info, meg=True, eeg=True)
    pre_picks = np.r_[
        mne.pick_types(epochs.info, meg='mag', eeg=False)[::15],
        mne.pick_types(epochs.info, meg='grad', eeg=False)[::60],
        mne.pick_types(epochs.info, meg=False, eeg=True)[::16],
        mne.pick_types(epochs.info, meg=False, eeg=False, eog=True)]
    pick_ch_names = [epochs.ch_names[pp] for pp in pre_picks]
    bad_ch_names = [epochs.ch_names[ix] for ix in range(len(epochs.ch_names))
                    if ix not in pre_picks]
    epochs_with_bads = epochs.copy()
    epochs_with_bads.info['bads'] = bad_ch_names
    epochs.pick_channels(pick_ch_names)

    epochs_fit = epochs[:12]  # make sure to use different size of epochs
    epochs_new = epochs[12:]
    epochs_with_bads_fit = epochs_with_bads[:12]

    X = epochs_fit.get_data()
    n_epochs, n_channels, n_times = X.shape
    X = X.reshape(n_epochs, -1)

    ar = _GlobalAutoReject()
    assert_raises(ValueError, ar.fit, X)
    ar = _GlobalAutoReject(n_channels=n_channels)
    assert_raises(ValueError, ar.fit, X)
    ar = _GlobalAutoReject(n_times=n_times)
    assert_raises(ValueError, ar.fit, X)
    ar_global = _GlobalAutoReject(
        n_channels=n_channels, n_times=n_times, thresh=40e-6)
    ar_global.fit(X)

    param_name = 'thresh'
    param_range = np.linspace(40e-6, 200e-6, 10)
    assert_raises(ValueError, validation_curve, X, None,
                  param_name, param_range)

    ##########################################################################
    # picking AutoReject

    picks = mne.pick_types(
        epochs.info, meg='mag', eeg=True, stim=False, eog=False,
        include=[], exclude=[])
    non_picks = mne.pick_types(
        epochs.info, meg='grad', eeg=False, stim=False, eog=False,
        include=[], exclude=[])
    ch_types = ['mag', 'eeg']

    ar = _AutoReject(picks=picks)  # XXX : why do we need this??

    ar = AutoReject(cv=3, picks=picks, random_state=42,
                    n_interpolate=[1, 2], consensus=[0.5, 1])
    assert_raises(AttributeError, ar.fit, X)
    assert_raises(ValueError, ar.transform, X)
    assert_raises(ValueError, ar.transform, epochs)

    ar.fit(epochs_fit)
    reject_log = ar.get_reject_log(epochs_fit)
    for ch_type in ch_types:
        # test that kappa & rho are selected
        assert_true(
            ar.n_interpolate_[ch_type] in ar.n_interpolate)
        assert_true(
            ar.consensus_[ch_type] in ar.consensus)

        assert_true(
            ar.n_interpolate_[ch_type] ==
            ar.local_reject_[ch_type].n_interpolate_[ch_type])
        assert_true(
            ar.consensus_[ch_type] ==
            ar.local_reject_[ch_type].consensus_[ch_type])

    # test complementarity of goods and bads
    assert_array_equal(len(reject_log.bad_epochs), len(epochs_fit))

    # test that transform does not change state of ar
    epochs_clean = ar.transform(epochs_fit)  # apply same data
    assert_true(repr(ar))
    assert_true(repr(ar.local_reject_))
    reject_log2 = ar.get_reject_log(epochs_fit)
    assert_array_equal(reject_log.labels, reject_log2.labels)
    assert_array_equal(reject_log.bad_epochs, reject_log2.bad_epochs)
    assert_array_equal(reject_log.ch_names, reject_log2.ch_names)

    epochs_new_clean = ar.transform(epochs_new)  # apply to new data

    reject_log_new = ar.get_reject_log(epochs_new)
    assert_array_equal(len(reject_log_new.bad_epochs), len(epochs_new))

    assert_true(
        len(reject_log_new.bad_epochs) != len(reject_log.bad_epochs))

    picks_by_type = _get_picks_by_type(epochs.info, ar.picks)
    # test correct entries in fix log
    assert_true(
        np.isnan(reject_log_new.labels[:, non_picks]).sum() > 0)
    assert_true(
        np.isnan(reject_log_new.labels[:, picks]).sum() == 0)
    assert_equal(reject_log_new.labels.shape,
                 (len(epochs_new), len(epochs_new.ch_names)))

    # test correct interpolations by type
    for ch_type, this_picks in picks_by_type:
        interp_counts = np.sum(
            reject_log_new.labels[:, this_picks] == 2, axis=1)
        labels = reject_log_new.labels.copy()
        not_this_picks = np.setdiff1d(np.arange(labels.shape[1]), this_picks)
        labels[:, not_this_picks] = np.nan
        interp_channels = _get_interp_chs(
            labels, reject_log.ch_names, this_picks)
        assert_array_equal(
            interp_counts, [len(cc) for cc in interp_channels])

    is_same = epochs_new_clean.get_data() == epochs_new.get_data()
    if not np.isscalar(is_same):
        is_same = np.isscalar(is_same)
    assert_true(not is_same)

    # test that transform ignores bad channels
    epochs_with_bads_fit.pick_types(meg='mag', eeg=True, eog=True, exclude=[])
    ar_bads = AutoReject(cv=3, random_state=42,
                         n_interpolate=[1, 2], consensus=[0.5, 1])
    ar_bads.fit(epochs_with_bads_fit)
    epochs_with_bads_clean = ar_bads.transform(epochs_with_bads_fit)

    good_w_bads_ix = mne.pick_types(epochs_with_bads_clean.info,
                                    meg='mag', eeg=True, eog=True,
                                    exclude='bads')
    good_wo_bads_ix = mne.pick_types(epochs_clean.info,
                                     meg='mag', eeg=True, eog=True,
                                     exclude='bads')
    assert_array_equal(epochs_with_bads_clean.get_data()[:, good_w_bads_ix, :],
                       epochs_clean.get_data()[:, good_wo_bads_ix, :])

    bad_ix = [epochs_with_bads_clean.ch_names.index(ch)
              for ch in epochs_with_bads_clean.info['bads']]
    epo_ix = ~ar_bads.get_reject_log(epochs_with_bads_fit).bad_epochs
    assert_array_equal(
        epochs_with_bads_clean.get_data()[:, bad_ix, :],
        epochs_with_bads_fit.get_data()[epo_ix, :, :][:, bad_ix, :])

    assert_equal(epochs_clean.ch_names, epochs_fit.ch_names)

    assert_true(isinstance(ar.threshes_, dict))
    assert_true(len(ar.picks) == len(picks))
    assert_true(len(ar.threshes_.keys()) == len(ar.picks))
    pick_eog = mne.pick_types(epochs.info, meg=False, eeg=False, eog=True)[0]
    assert_true(epochs.ch_names[pick_eog] not in ar.threshes_.keys())
    assert_raises(
        IndexError, ar.transform,
        epochs.copy().pick_channels(
            [epochs.ch_names[pp] for pp in picks[:3]]))

    epochs.load_data()
    assert_raises(ValueError, compute_thresholds, epochs, 'dfdfdf')
    index, ch_names = zip(*[(ii, epochs_fit.ch_names[pp])
                          for ii, pp in enumerate(picks)])
    threshes_a = compute_thresholds(
        epochs_fit, picks=picks, method='random_search')
    assert_equal(set(threshes_a.keys()), set(ch_names))
    threshes_b = compute_thresholds(
        epochs_fit, picks=picks, method='bayesian_optimization')
    assert_equal(set(threshes_b.keys()), set(ch_names))
# Keep only EEG channels; stim/EOG are left out and bad channels are kept
# (exclude=[]).
picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                       include=[], exclude=[])

# Make epochs from the raw data (preloaded, no amplitude rejection here:
# cleaning is left to AutoReject).
epochs = mne.Epochs(raw, events=events, event_id=event_id,
                    tmin=tmin, tmax=tmax, picks=picks,
                    reject=None, preload=True)

# Set up AutoReject with a random search over thresholds, seeded with `seed`
# so that the result is reproducible; then fit it and clean the same epochs.
ar = AutoReject(n_interpolates, consensus_percs,
                thresh_method='random_search', random_state=seed)
ar.fit(epochs)
epochs_clean = ar.transform(epochs)

# Save the cleaned epochs to <data_path><subject>-epo.fif
epochs_clean.save(data_path + '%s-epo.fif' % subject)
Ejemplo n.º 10
0
# %%
# Note that :class:`autoreject.AutoReject` by design supports
# multiple channels.
# If no picks are passed, separate solutions will be computed for each channel
# type and internally combined. This then readily supports cleaning
# unseen epochs from the different channel types used during fit.
# Here we only use a subset of channels to save time.

# Local autoreject restricted to ``picks``; the random search is seeded so
# the learned thresholds are reproducible.
ar = AutoReject(
    n_interpolates,
    consensus_percs,
    picks=picks,
    thresh_method='random_search',
    random_state=42,
)

# Note that fitting and transforming can be done on different compatible
# portions of data if needed.
ar.fit(epochs['Auditory/Left'])
epochs_clean = ar.transform(epochs['Auditory/Left'])

# Average the original and the cleaned selection for comparison.
evoked = epochs['Auditory/Left'].average()
evoked_clean = epochs_clean.average()

# %%
# Now, we will manually mark the bad channels just for plotting.

evoked_clean.info['bads'] = ['MEG 2443']
evoked.info['bads'] = ['MEG 2443']

# %%
# Let us plot the results.

import matplotlib.pyplot as plt  # noqa
set_matplotlib_defaults(plt)
Ejemplo n.º 11
0
def run_epochs(subject, autoreject=True):
    """Clean one subject's raw data with ICA, epoch it and save the epochs.

    The filtered/SSS raw file and its annotations are read from
    ``meg_dir/<subject>``. ICA components flagged by ``find_bads_ecg`` and
    ``find_bads_eog`` (at most ``n_max_ecg`` and ``n_max_eog`` of them) are
    added to ``ica.exclude``, the ICA is applied to the raw data, the data
    are epoched and, when ``autoreject`` is True, local ``AutoReject`` drops
    bad epochs and interpolates bad sensors before the epochs are saved.

    Parameters
    ----------
    subject : str
        Subject identifier, used for both the directory and file names.
    autoreject : bool
        Whether to run local autoreject. Also selects the output file name
        (``*-filt-sss-ar-epo.fif`` vs ``*-filt-sss-epo.fif``).

    Notes
    -----
    Relies on module-level settings: ``meg_dir``, ``n_max_ecg``,
    ``n_max_eog``, ``event_id``, ``tmin``, ``tmax``, ``baseline``,
    ``reject_tmax`` and ``n_jobs``.
    """
    subject_dir = op.join(meg_dir, subject)
    raw_fname = op.join(subject_dir, f'{subject}_audvis-filt_raw_sss.fif')
    annot_fname = op.join(subject_dir, f'{subject}_audvis-annot.fif')
    raw = mne.io.read_raw_fif(raw_fname, preload=False)
    annot = mne.read_annotations(annot_fname)
    raw.set_annotations(annot)
    if autoreject:
        epo_fname = op.join(subject_dir,
                            f'{subject}_audvis-filt-sss-ar-epo.fif')
    else:
        epo_fname = op.join(subject_dir,
                            f'{subject}_audvis-filt-sss-epo.fif')

    # ICA: loaded once, then used for the ECG/EOG searches and apply().
    ica_fname = op.join(subject_dir, f'{subject}_audvis-ica.fif')
    ica = mne.preprocessing.read_ica(ica_fname)

    # ECG: exclude components flagged by find_bads_ecg (at most n_max_ecg).
    # A ValueError is tolerated (best-effort); processing continues without.
    try:
        ecg_epochs = mne.preprocessing.create_ecg_epochs(
            raw, l_freq=10, h_freq=20, baseline=(None, None), preload=True)
        ecg_inds, scores_ecg = ica.find_bads_ecg(
            ecg_epochs, method='ctps', threshold='auto', verbose='INFO')
    except ValueError:
        # not found: continue without excluding ECG components
        pass
    else:
        print(f'Found {len(ecg_inds)} ({ecg_inds}) ECG indices for {subject}')
        if len(ecg_inds) != 0:
            ica.exclude.extend(ecg_inds[:n_max_ecg])
            # for future inspection
            ecg_epochs.average().save(
                op.join(subject_dir, f'{subject}_audvis-ecg-ave.fif'))
        # release memory
        del ecg_epochs, ecg_inds, scores_ecg

    # EOG: exclude components flagged by find_bads_eog (at most n_max_eog).
    try:
        eog_epochs = mne.preprocessing.create_eog_epochs(
            raw, baseline=(None, None), preload=True)
        eog_inds, scores_eog = ica.find_bads_eog(eog_epochs)
    except ValueError:
        # not found: continue without excluding EOG components
        pass
    else:
        print(f'Found {len(eog_inds)} ({eog_inds}) EOG indices for {subject}')
        if len(eog_inds) != 0:
            ica.exclude.extend(eog_inds[:n_max_eog])
            # for future inspection
            eog_epochs.average().save(
                op.join(subject_dir, f'{subject}_audvis-eog-ave.fif'))
        # release memory whether or not components were found (as for ECG)
        del eog_epochs, eog_inds, scores_eog

    # applying ICA on Raw
    raw.load_data()
    ica.apply(raw)

    # extract events for epoching
    # modify stim_channel for your need
    events = mne.find_events(raw, stim_channel="STI 014")
    # NOTE(review): pick_types uses exclude='bads' by default, so channels in
    # raw.info['bads'] are never epoched and the exclude=[] used for
    # autoreject below cannot bring them back. Confirm this is intended.
    picks = mne.pick_types(raw.info, meg=True)
    epochs = mne.Epochs(
        raw,
        events=events,
        picks=picks,
        event_id=event_id,
        tmin=tmin,
        tmax=tmax,
        baseline=baseline,
        decim=4,  # raw sampling rate is 600 Hz, subsample to 150 Hz
        preload=True,  # for autoreject
        reject_tmax=reject_tmax,
        reject_by_annotation=True)
    del raw, annot

    # autoreject (local)
    if autoreject:
        # local reject
        # keep the bad sensors/channels because autoreject can repair it via
        # interpolation
        picks = mne.pick_types(epochs.info, meg=True, exclude=[])
        ar = AutoReject(picks=picks, n_jobs=n_jobs, verbose=False)
        print(f'Run autoreject (local) for {subject} (it takes a long time)')
        ar.fit(epochs)
        print(f'Drop bad epochs and interpolate bad sensors for {subject}')
        epochs = ar.transform(epochs)

    print(f'Dropped {round(epochs.drop_log_stats(), 2)}% epochs for {subject}')
    epochs.save(epo_fname, overwrite=True)
Ejemplo n.º 12
0
def AR_local(cleaned_epochs_ICA: list, strategy: str = 'union',
             threshold: float = 50.0, verbose: bool = False) -> tuple:
    """
    Applies local Autoreject to repair or reject bad epochs.

    Arguments:
        cleaned_epochs_ICA: list of Epochs after global Autoreject and ICA,
          one per subject. The bad epochs used for rejection are taken from
          the first two entries (the function targets a dyad).
        strategy: more or less generous strategy to reject bad epochs: 'union'
          or 'intersection'. 'union' rejects bad epochs from subject 1 and
          subject 2 immediately, whereas 'intersection' rejects shared bad
          epochs between subjects, tries to repair remaining bad epochs per
          subject, rejects the non-reparable per subject and finally
          equalizes the number of epochs between subjects. Set to 'union'
          by default.
        threshold: percentage of epochs removed that is accepted. Above
          this threshold, data are considered as a too shortened sample
          for further analyses. Set to 50.0 by default.
        verbose: if True, print the percentage of rejected epochs and plot
          the evoked responses before and after AR. Set to False by default
          (use verbose=False until the next Autoreject update).

    Note:
        To reject or repair epochs, parameters are more or less conservative,
        see http://autoreject.github.io/generated/autoreject.AutoReject.

    Returns:
        cleaned_epochs_AR: list of Epochs after local Autoreject.
        dic_AR: dictionary with the strategy, the threshold, and the
          percentage of rejected epochs for each subject ('S1', 'S2') and
          for the dyad ('dyad', computed from the first subject's counts).

    Raises:
        ValueError: if ``strategy`` is neither 'union' nor 'intersection'.
        RuntimeError: if the percentage of rejected epochs for the dyad
          reaches ``threshold``.
    """
    # Validate before the (expensive) fitting.
    if strategy not in ('union', 'intersection'):
        raise ValueError("not good strategy input! Expected 'union' or "
                         "'intersection', got %r" % (strategy,))

    dic_AR = {'strategy': strategy, 'threshold': threshold}

    # defaults values for n_interpolates and consensus_percs
    n_interpolates = np.array([1, 4, 32])
    consensus_percs = np.linspace(0, 1.0, 11)
    # more generous values
    # n_interpolates = np.array([16, 32, 64])
    # n_interpolates = np.array([1, 4, 8, 16, 32, 64])
    # consensus_percs = np.linspace(0.5, 1.0, 11)

    # Fit one local AutoReject per subject and collect its bad epoch indices.
    ars = []
    bad_sets = []
    for clean_epochs in cleaned_epochs_ICA:  # per subj
        picks = mne.pick_types(clean_epochs.info, meg=False, eeg=True,
                               stim=False, eog=False, exclude=[])
        ar = AutoReject(n_interpolates, consensus_percs, picks=picks,
                        thresh_method='random_search', random_state=42,
                        verbose='tqdm_notebook')
        ar.fit(clean_epochs)
        reject_log = ar.get_reject_log(clean_epochs, picks=picks)
        ars.append(ar)
        bad_sets.append(set(np.flatnonzero(reject_log.bad_epochs).tolist()))

    # Combine the bad epochs of the two subjects of the dyad.
    bad1, bad2 = bad_sets[0], bad_sets[1]
    if strategy == 'union':
        bad = sorted(bad1 | bad2)
    else:  # 'intersection'
        bad = sorted(bad1 & bad2)

    # storing the percentage of epochs rejection
    dic_AR['S1'] = float(len(bad1) / len(cleaned_epochs_ICA[0]) * 100)
    dic_AR['S2'] = float(len(bad2) / len(cleaned_epochs_ICA[1]) * 100)

    # Drop the shared bad epochs, then let each subject's AutoReject
    # interpolate bad channels or remove the remaining bad epochs.
    cleaned_epochs_AR = []
    for ar, clean_epochs in zip(ars, cleaned_epochs_ICA):  # per subj
        # Work on a copy so the caller's Epochs are left untouched.
        clean_epochs_ep = clean_epochs.copy().drop(indices=bad)
        cleaned_epochs_AR.append(ar.transform(clean_epochs_ep))

    if strategy == 'intersection':
        # equalizing epochs length between two participants
        mne.epochs.equalize_epoch_counts(cleaned_epochs_AR)

    n_before = len(cleaned_epochs_ICA[0])
    dic_AR['dyad'] = float(
        (n_before - len(cleaned_epochs_AR[0])) / n_before * 100)
    if dic_AR['dyad'] >= threshold:
        raise RuntimeError(
            'percentage of rejected epochs above threshold! '
            '(%.2f%% >= %.2f%%)' % (dic_AR['dyad'], threshold))

    if verbose:
        print('%s percent of bad epochs' % dic_AR['dyad'])
        # Visualization before/after AR (only computed when requested).
        evoked_before = [epo.average() for epo in cleaned_epochs_ICA]
        evoked_after_AR = [epo.average() for epo in cleaned_epochs_AR]
        _plot_evoked_before_after(evoked_before, evoked_after_AR)

    return cleaned_epochs_AR, dic_AR


def _plot_evoked_before_after(evoked_before, evoked_after):
    """Plot each subject's EEG evoked response before and after autoreject."""
    for before, after in zip(evoked_before, evoked_after):
        fig, axes = plt.subplots(2, 1, figsize=(6, 6))
        for ax in axes:
            # Booleans are required: the strings 'off'/'on' are truthy.
            ax.tick_params(axis='x', which='both', bottom=False, top=False)
            ax.tick_params(axis='y', which='both', left=False, right=False)

        # NOTE(review): only 'grad' limits are set, which do not apply to the
        # EEG channels plotted here.
        ylim = dict(grad=(-170, 200))
        before.pick_types(eeg=True, exclude=[])
        before.plot(exclude=[], axes=axes[0], ylim=ylim, show=False)
        axes[0].set_title('Before autoreject')
        after.pick_types(eeg=True, exclude=[])
        # show=False so the second title is set before the figure is shown
        # (otherwise it does not appear on the second axis).
        after.plot(exclude=[], axes=axes[1], ylim=ylim, show=False)
        axes[1].set_title('After autoreject')
        plt.tight_layout()
        plt.show()
###############################################################################
# Note that :class:`autoreject.AutoReject` by design supports
# multiple channels.
# If no picks are passed, separate solutions will be computed for each channel
# type and internally combined. This then readily supports cleaning
# unseen epochs from the different channel types used during fit.
# Here we only use a subset of channels to save time.

# AutoReject limited to ``picks``, random-search thresholds with a fixed
# seed for reproducibility.
ar = AutoReject(
    n_interpolates,
    consensus_percs,
    picks=picks,
    thresh_method='random_search',
    random_state=42,
)

# Note that fitting and transforming can be done on different compatible
# portions of data if needed.
ar.fit(epochs['Auditory/Left'])
epochs_clean = ar.transform(epochs['Auditory/Left'])

# Evoked responses of the uncleaned and cleaned selection.
evoked = epochs['Auditory/Left'].average()
evoked_clean = epochs_clean.average()

###############################################################################
# Now, we will manually mark the bad channels just for plotting.

evoked_clean.info['bads'] = ['MEG 2443']
evoked.info['bads'] = ['MEG 2443']

###############################################################################
# Let us plot the results.

import matplotlib.pyplot as plt  # noqa
set_matplotlib_defaults(plt)