Example no. 1
def run_autoreject(subject):
    """Interpolate bad epochs/sensors using Autoreject.

    Parameters
    ----------
    subject : str
        The participant reference

    Save the resulting *-epo.fif file in the '5_autoreject' directory.
    Save .png of ERP difference and heatmap plots.

    References
    ----------
    [1] Mainak Jas, Denis Engemann, Federico Raimondo, Yousra Bekhti, and
    Alexandre Gramfort, “Automated rejection and repair of bad trials in
    MEG/EEG.” In 6th International Workshop on Pattern Recognition in
    Neuroimaging (PRNI), 2016.

    [2] Mainak Jas, Denis Engemann, Yousra Bekhti, Federico Raimondo, and
    Alexandre Gramfort. 2017. “Autoreject: Automated artifact rejection for
    MEG and EEG data”. NeuroImage, 159, 417-429.

    """
    # Import data
    input_path = root + '/4_ICA/' + subject + '-epo.fif'
    epochs = mne.read_epochs(input_path)

    # Autoreject
    ar = AutoReject(random_state=42,
                    n_jobs=4)

    epochs_clean = ar.fit_transform(epochs)

    # Plot difference
    evoked = epochs.average()
    evoked_clean = epochs_clean.average()

    fig, axes = plt.subplots(2, 1, figsize=(6, 6))

    for ax in axes:
        ax.tick_params(axis='x', which='both', bottom=False, top=False)
        ax.tick_params(axis='y', which='both', left=False, right=False)

    evoked.plot(exclude=[], axes=axes[0], ylim=dict(eeg=(-30, 30)), show=False)
    axes[0].set_title('Before autoreject')
    evoked_clean.plot(exclude=[], axes=axes[1], ylim=dict(eeg=(-30, 30)))
    axes[1].set_title('After autoreject')
    plt.tight_layout()
    plt.savefig(root + '/5_autoreject/' +
                subject + '-autoreject.png')
    plt.close()

    # Plot heatmap
    ar.get_reject_log(epochs).plot()
    plt.savefig(root + '/5_autoreject/' + subject + '-heatmap.png')
    plt.close()

    # Save epoch data
    out_epoch = root + '/5_autoreject/' + subject + '-epo.fif'
    epochs_clean.save(out_epoch)
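
A minimal usage sketch (hedged: the module-level imports and the root path that run_autoreject relies on are assumptions about the surrounding script, and the subject IDs are placeholders):

import mne
import matplotlib.pyplot as plt
from autoreject import AutoReject

root = '/path/to/study'         # hypothetical project root used to build the paths
for subject in ('P01', 'P02'):  # placeholder participant references
    run_autoreject(subject)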
Example no. 2
def epoch_and_clean_trials(subject,
                           diagdir,
                           bidsdir,
                           datadir,
                           derivdir,
                           epochlength=3,
                           eventid={'visualfix/fixCross': 10}):
    """
    Chunk the data into epochs starting at the eventid specified per trial,
    lasting epochlength seconds (an epoch length of 7 seconds would include
    all trial elements).
    Do automatic artifact detection, rejection and fixing for eyeblinks,
    heartbeat, and high- and low-amplitude artifacts.
    :param subject: str, subject identifier. takes the form '001'
    :param diagdir: str, path to a directory where diagnostic plots can be saved
    :param bidsdir: str, path to a directory with BIDS data. Needed to load
    event logs from the experiment
    :param datadir: str, path to a directory with SSS-processed data
    :param derivdir: str, path to a directory where cleaned epochs can be saved
    :param epochlength: int, length of epoch
    :param eventid: dict, the event to start an Epoch from
    """
    # construct name of the first split
    raw_fname = Path(datadir) / f'sub-{subject}/meg' / \
                f'sub-{subject}_task-memento_proc-sss_meg.fif'
    logging.info(f"Reading in SSS-processed data from subject sub-{subject}. "
                 f"Attempting the following path: {raw_fname}")
    raw = mne.io.read_raw_fif(raw_fname)
    events, event_dict = get_events(raw)
    # filter the data to remove high-frequency noise. Minimal high-pass filter
    # based on
    # https://www.sciencedirect.com/science/article/pii/S0165027021000157
    # ensure the data is loaded prior to filtering
    raw.load_data()
    if subject == '017':
        logging.info('Setting additional bad channels for subject 17')
        raw.info['bads'] = ['MEG0313', 'MEG0513', 'MEG0523']
        raw.interpolate_bads()
    # high-pass doesn't make sense, raw data has 0.1Hz high-pass filter already!
    _filter_data(raw, h_freq=100)
    # ICA to detect and repair artifacts
    logging.info('Removing eyeblink and heartbeat artifacts')
    rng = np.random.RandomState(28)
    remove_eyeblinks_and_heartbeat(
        raw=raw,
        subject=subject,
        figdir=diagdir,
        events=events,
        eventid=eventid,
        rng=rng,
    )
    # get the actual epochs: chunk the trial into epochs starting from the
    # event ID. Do not baseline correct the data.
    logging.info(f'Creating epochs of length {epochlength}')
    if eventid == {'press/left': 1, 'press/right': 4}:
        # when centered on the response, move back in time
        epochs = mne.Epochs(raw,
                            events,
                            event_id=eventid,
                            tmin=-epochlength,
                            tmax=0,
                            picks='meg',
                            baseline=None)
    else:
        epochs = mne.Epochs(raw,
                            events,
                            event_id=eventid,
                            tmin=0,
                            tmax=epochlength,
                            picks='meg',
                            baseline=None)
    # ADD SUBJECT SPECIFIC TRIAL NUMBER TO THE EPOCH! ONLY THIS WAY WE CAN
    # LATER RECOVER WHICH TRIAL PARAMETERS WE'RE LOOKING AT BASED ON THE LOGS AS
    # THE EPOCH REJECTION WILL REMOVE TRIALS
    logging.info("Retrieving trial metadata.")
    from pymento_meg.proc.epoch import get_trial_features
    metadata = get_trial_features(bids_path=bidsdir,
                                  subject=subject,
                                  column='trial_no')
    # transform to integers
    metadata = metadata.astype(int)
    # this does not work if we start at fixation cross for subject 5, because 1
    # fixation cross trigger is missing from the data, and it becomes impossible
    # to associate the trial metadata to the correct trials in the data
    epochs.metadata = metadata
    epochs.load_data()
    ## downsample the data to 200Hz
    #logging.info('Resampling epoched data down to 200 Hz')
    #epochs.resample(sfreq=200, verbose=True)
    # use autoreject to repair bad epochs
    ar = AutoReject(
        random_state=rng,
        n_interpolate=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
    epochs_clean, reject_log = ar.fit_transform(epochs, return_log=True)
    # save the cleaned, epoched data to disk.
    outpath = _construct_path([
        Path(derivdir),
        f"sub-{subject}",
        "meg",
        f"sub-{subject}_task-memento_cleaned_epo.fif",
    ])
    logging.info(f"Saving cleaned, epoched data to {outpath}")
    epochs_clean.save(outpath, overwrite=True)
    # visualize the bad sensors for each trial
    fig = ar.get_reject_log(epochs).plot()
    fname = _construct_path([
        Path(diagdir),
        f"sub-{subject}",
        "meg",
        f"epoch-rejectlog_sub-{subject}.png",
    ])
    fig.savefig(fname)
    # plot the average of all cleaned epochs
    fig = epochs_clean.average().plot()
    fname = _construct_path([
        Path(diagdir),
        f"sub-{subject}",
        "meg",
        f"clean-epoch_average_sub-{subject}.png",
    ])
    fig.savefig(fname)
    # plot psd of cleaned epochs
    psd = epochs_clean.plot_psd()
    fname = _construct_path([
        Path(diagdir),
        f"sub-{subject}",
        "meg",
        f"psd_cleaned-epochs-{subject}.png",
    ])
    psd.savefig(fname)
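
A call sketch for the function above; the directory arguments are placeholders, not the project's actual layout:

epoch_and_clean_trials(subject='001',
                       diagdir='/tmp/diagnostics',   # placeholder path
                       bidsdir='/tmp/bids',          # placeholder path
                       datadir='/tmp/sss',           # placeholder path
                       derivdir='/tmp/derivatives',  # placeholder path
                       epochlength=3,
                       eventid={'visualfix/fixCross': 10})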
Example no. 3
def AR_local(cleaned_epochs_ICA, verbose=False):
    """
    Applies local Autoreject to correct or reject bad epochs.

    Arguments:
        cleaned_epochs_ICA: list of Epochs after global Autoreject and ICA
        verbose: whether to plot the data before and after AR; boolean, set
          to False by default.

    Returns:
        cleaned_epochs_AR: list of Epochs after local Autoreject.
    """
    bad_epochs_AR = []
    AR = []

    # default values for n_interpolates and consensus_percs
    n_interpolates = np.array([1, 4, 32])
    consensus_percs = np.linspace(0, 1.0, 11)

    for clean_epochs in cleaned_epochs_ICA:  # per subj

        picks = mne.pick_types(clean_epochs[0].info,
                               meg=False,
                               eeg=True,
                               stim=False,
                               eog=False,
                               exclude=[])

        if verbose:
            ar_verbose = 'progressbar'
        else:
            ar_verbose = False

        ar = AutoReject(n_interpolates,
                        consensus_percs,
                        picks=picks,
                        thresh_method='random_search',
                        random_state=42,
                        verbose=ar_verbose)

        # fitting AR to get bad epochs
        ar.fit(clean_epochs)
        AR.append(ar)  # keep each subject's fitted AutoReject for the transform step
        reject_log = ar.get_reject_log(clean_epochs, picks=picks)
        bad_epochs_AR.append(reject_log)

    # keeping only the epochs that are bad for both subjects of the dyad
    log1 = bad_epochs_AR[0]
    log2 = bad_epochs_AR[1]

    bad1 = np.where(log1.bad_epochs == True)
    bad2 = np.where(log2.bad_epochs == True)

    bad = list(set(bad1[0].tolist()).intersection(bad2[0].tolist()))
    if verbose:
        print('%s percent of bad epochs' %
              int(len(bad) / len(list(log1.bad_epochs)) * 100))

    # picking good epochs for the two subj
    cleaned_epochs_AR = []
    for clean_epochs in cleaned_epochs_ICA:  # per subj
        clean_epochs_ep = clean_epochs.drop(indices=bad)
        # interpolating bads or removing epochs with the AutoReject instance
        # fitted on this subject
        ar = AR[cleaned_epochs_ICA.index(clean_epochs)]
        clean_epochs_AR = ar.transform(clean_epochs_ep)
        cleaned_epochs_AR.append(clean_epochs_AR)
    # equalizing epochs length between two subjects
    mne.epochs.equalize_epoch_counts(cleaned_epochs_AR)

    # Visualization before/after AR
    evoked_before = []
    for clean_epochs in cleaned_epochs_ICA:  # per subj
        evoked_before.append(clean_epochs.average())

    evoked_after_AR = []
    for clean in cleaned_epochs_AR:
        evoked_after_AR.append(clean.average())

    if verbose:
        for i, j in zip(evoked_before, evoked_after_AR):
            fig, axes = plt.subplots(2, 1, figsize=(6, 6))
            for ax in axes:
                ax.tick_params(axis='x', which='both', bottom=False, top=False)
                ax.tick_params(axis='y', which='both', left=False, right=False)

            ylim = dict(eeg=(-170, 200))
            i.pick_types(eeg=True, exclude=[])
            i.plot(exclude=[], axes=axes[0], ylim=ylim, show=False)
            axes[0].set_title('Before autoreject')
            j.pick_types(eeg=True, exclude=[])
            j.plot(exclude=[], axes=axes[1], ylim=ylim)
            # Problem: the title is not displayed for the second axis!
            axes[1].set_title('After autoreject')
            plt.tight_layout()

    return cleaned_epochs_AR
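
A dyad-level usage sketch (epochs_sub1 and epochs_sub2 stand in for the per-subject Epochs that already went through global Autoreject and ICA):

cleaned_epochs_AR = AR_local([epochs_sub1, epochs_sub2], verbose=True)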
Example no. 4
                    preload=True)

ar = AutoReject(n_interpolates,
                consensus_percs,
                picks=picks,
                thresh_method='random_search',
                random_state=42)

# Note that fitting and transforming can be done on different compatible
# portions of data if needed.
ar.fit(epochs)

# epochs_ar, reject_log = ar.fit_transform(epochs, return_log=True)

epochs_clean = ar.transform(epochs)
reject_log = ar.get_reject_log(epochs)
evoked_clean = epochs_clean.average()
evoked = epochs.average()

# visualize rejected epochs
scalings = dict(eeg=40e-6)
reject_log.plot_epochs(epochs, scalings=scalings)

# epochs after cleaning
epochs_clean.plot(scalings=scalings)

# notify when done
os.system('say "... I am ready for you Neil."')

# Visualize repaired data
ylim = dict(eeg=(-2, 2))
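
The snippet stops where the repaired data would be drawn; a plausible continuation, mirroring the before/after comparison used in the other examples here (assuming matplotlib.pyplot is imported as plt):

fig, axes = plt.subplots(2, 1, figsize=(6, 6))
evoked.plot(exclude=[], axes=axes[0], ylim=ylim, show=False)
axes[0].set_title('Before autoreject')
evoked_clean.plot(exclude=[], axes=axes[1], ylim=ylim)
axes[1].set_title('After autoreject')
plt.tight_layout()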
Example no. 5
def test_autoreject():
    """Test basic _AutoReject functionality."""
    event_id = None
    tmin, tmax = -0.2, 0.5
    events = mne.find_events(raw)

    ##########################################################################
    # picking epochs
    include = [u'EEG %03d' % i for i in range(1, 45, 3)]
    picks = mne.pick_types(raw.info, meg=False, eeg=False, stim=False,
                           eog=True, include=include, exclude=[])
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                        picks=picks, baseline=(None, 0), decim=10,
                        reject=None, preload=False)[:10]

    ar = _AutoReject()
    assert_raises(ValueError, ar.fit, epochs)
    epochs.load_data()

    ar.fit(epochs)
    assert_true(len(ar.picks_) == len(picks) - 1)

    # epochs with no picks.
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                        baseline=(None, 0), decim=10,
                        reject=None, preload=True)[:20]
    # let's drop some channels to speed up
    pre_picks = mne.pick_types(epochs.info, meg=True, eeg=True)
    pre_picks = np.r_[
        mne.pick_types(epochs.info, meg='mag', eeg=False)[::15],
        mne.pick_types(epochs.info, meg='grad', eeg=False)[::60],
        mne.pick_types(epochs.info, meg=False, eeg=True)[::16],
        mne.pick_types(epochs.info, meg=False, eeg=False, eog=True)]
    pick_ch_names = [epochs.ch_names[pp] for pp in pre_picks]
    bad_ch_names = [epochs.ch_names[ix] for ix in range(len(epochs.ch_names))
                    if ix not in pre_picks]
    epochs_with_bads = epochs.copy()
    epochs_with_bads.info['bads'] = bad_ch_names
    epochs.pick_channels(pick_ch_names)

    epochs_fit = epochs[:12]  # make sure to use different size of epochs
    epochs_new = epochs[12:]
    epochs_with_bads_fit = epochs_with_bads[:12]

    X = epochs_fit.get_data()
    n_epochs, n_channels, n_times = X.shape
    X = X.reshape(n_epochs, -1)

    ar = _GlobalAutoReject()
    assert_raises(ValueError, ar.fit, X)
    ar = _GlobalAutoReject(n_channels=n_channels)
    assert_raises(ValueError, ar.fit, X)
    ar = _GlobalAutoReject(n_times=n_times)
    assert_raises(ValueError, ar.fit, X)
    ar_global = _GlobalAutoReject(
        n_channels=n_channels, n_times=n_times, thresh=40e-6)
    ar_global.fit(X)

    param_name = 'thresh'
    param_range = np.linspace(40e-6, 200e-6, 10)
    assert_raises(ValueError, validation_curve, X, None,
                  param_name, param_range)

    ##########################################################################
    # picking AutoReject

    picks = mne.pick_types(
        epochs.info, meg='mag', eeg=True, stim=False, eog=False,
        include=[], exclude=[])
    non_picks = mne.pick_types(
        epochs.info, meg='grad', eeg=False, stim=False, eog=False,
        include=[], exclude=[])
    ch_types = ['mag', 'eeg']

    ar = _AutoReject(picks=picks)  # XXX : why do we need this??

    ar = AutoReject(cv=3, picks=picks, random_state=42,
                    n_interpolate=[1, 2], consensus=[0.5, 1])
    assert_raises(AttributeError, ar.fit, X)
    assert_raises(ValueError, ar.transform, X)
    assert_raises(ValueError, ar.transform, epochs)

    ar.fit(epochs_fit)
    reject_log = ar.get_reject_log(epochs_fit)
    for ch_type in ch_types:
        # test that kappa & rho are selected
        assert_true(
            ar.n_interpolate_[ch_type] in ar.n_interpolate)
        assert_true(
            ar.consensus_[ch_type] in ar.consensus)

        assert_true(
            ar.n_interpolate_[ch_type] ==
            ar.local_reject_[ch_type].n_interpolate_[ch_type])
        assert_true(
            ar.consensus_[ch_type] ==
            ar.local_reject_[ch_type].consensus_[ch_type])

    # test complementarity of goods and bads
    assert_array_equal(len(reject_log.bad_epochs), len(epochs_fit))

    # test that transform does not change state of ar
    epochs_clean = ar.transform(epochs_fit)  # apply same data
    assert_true(repr(ar))
    assert_true(repr(ar.local_reject_))
    reject_log2 = ar.get_reject_log(epochs_fit)
    assert_array_equal(reject_log.labels, reject_log2.labels)
    assert_array_equal(reject_log.bad_epochs, reject_log2.bad_epochs)
    assert_array_equal(reject_log.ch_names, reject_log2.ch_names)

    epochs_new_clean = ar.transform(epochs_new)  # apply to new data

    reject_log_new = ar.get_reject_log(epochs_new)
    assert_array_equal(len(reject_log_new.bad_epochs), len(epochs_new))

    assert_true(
        len(reject_log_new.bad_epochs) != len(reject_log.bad_epochs))

    picks_by_type = _get_picks_by_type(epochs.info, ar.picks)
    # test correct entries in fix log
    assert_true(
        np.isnan(reject_log_new.labels[:, non_picks]).sum() > 0)
    assert_true(
        np.isnan(reject_log_new.labels[:, picks]).sum() == 0)
    assert_equal(reject_log_new.labels.shape,
                 (len(epochs_new), len(epochs_new.ch_names)))

    # test correct interpolations by type
    for ch_type, this_picks in picks_by_type:
        interp_counts = np.sum(
            reject_log_new.labels[:, this_picks] == 2, axis=1)
        labels = reject_log_new.labels.copy()
        not_this_picks = np.setdiff1d(np.arange(labels.shape[1]), this_picks)
        labels[:, not_this_picks] = np.nan
        interp_channels = _get_interp_chs(
            labels, reject_log.ch_names, this_picks)
        assert_array_equal(
            interp_counts, [len(cc) for cc in interp_channels])

    # the repaired data must differ from the original data
    is_same = np.array_equal(epochs_new_clean.get_data(), epochs_new.get_data())
    assert_true(not is_same)

    # test that transform ignores bad channels
    epochs_with_bads_fit.pick_types(meg='mag', eeg=True, eog=True, exclude=[])
    ar_bads = AutoReject(cv=3, random_state=42,
                         n_interpolate=[1, 2], consensus=[0.5, 1])
    ar_bads.fit(epochs_with_bads_fit)
    epochs_with_bads_clean = ar_bads.transform(epochs_with_bads_fit)

    good_w_bads_ix = mne.pick_types(epochs_with_bads_clean.info,
                                    meg='mag', eeg=True, eog=True,
                                    exclude='bads')
    good_wo_bads_ix = mne.pick_types(epochs_clean.info,
                                     meg='mag', eeg=True, eog=True,
                                     exclude='bads')
    assert_array_equal(epochs_with_bads_clean.get_data()[:, good_w_bads_ix, :],
                       epochs_clean.get_data()[:, good_wo_bads_ix, :])

    bad_ix = [epochs_with_bads_clean.ch_names.index(ch)
              for ch in epochs_with_bads_clean.info['bads']]
    epo_ix = ~ar_bads.get_reject_log(epochs_with_bads_fit).bad_epochs
    assert_array_equal(
        epochs_with_bads_clean.get_data()[:, bad_ix, :],
        epochs_with_bads_fit.get_data()[epo_ix, :, :][:, bad_ix, :])

    assert_equal(epochs_clean.ch_names, epochs_fit.ch_names)

    assert_true(isinstance(ar.threshes_, dict))
    assert_true(len(ar.picks) == len(picks))
    assert_true(len(ar.threshes_.keys()) == len(ar.picks))
    pick_eog = mne.pick_types(epochs.info, meg=False, eeg=False, eog=True)[0]
    assert_true(epochs.ch_names[pick_eog] not in ar.threshes_.keys())
    assert_raises(
        IndexError, ar.transform,
        epochs.copy().pick_channels(
            [epochs.ch_names[pp] for pp in picks[:3]]))

    epochs.load_data()
    assert_raises(ValueError, compute_thresholds, epochs, 'dfdfdf')
    index, ch_names = zip(*[(ii, epochs_fit.ch_names[pp])
                          for ii, pp in enumerate(picks)])
    threshes_a = compute_thresholds(
        epochs_fit, picks=picks, method='random_search')
    assert_equal(set(threshes_a.keys()), set(ch_names))
    threshes_b = compute_thresholds(
        epochs_fit, picks=picks, method='bayesian_optimization')
    assert_equal(set(threshes_b.keys()), set(ch_names))
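
Outside the test, compute_thresholds can also be called directly to inspect the per-channel rejection thresholds; a minimal sketch, assuming epochs is a preloaded Epochs object:

from autoreject import compute_thresholds

threshes = compute_thresholds(epochs, method='bayesian_optimization',
                              random_state=42)
print(list(threshes.items())[:5])  # (channel name, amplitude threshold) pairs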
Example no. 6
def run_preproc(datadir='/data'):

    print('data directory: {}'.format(datadir))
    conf_file_path = join(datadir, 'eegprep.conf')
    config = Configuration()
    config.setDefaults(defaults)
    if os.path.isfile(conf_file_path):
        with open(conf_file_path) as fh:
            conf_string = fh.read()
        config.updateFromString(conf_string)
    print('configuration:')
    print(config)

    bidsdir = join(datadir, 'BIDS')
    eegprepdir = join(bidsdir, 'derivatives', 'eegprep')

    
    subjectdirs = sorted(glob.glob(join(bidsdir, 'sub-*')))
    for subjectdir in subjectdirs:
        assert os.path.isdir(subjectdir)
        
        sub = basename(subjectdir)[4:]

        # prepare derivatives directory
        derivdir = join(eegprepdir, 'sub-' + sub)
        os.makedirs(derivdir, exist_ok=True)
        reportsdir = join(eegprepdir, 'reports', 'sub-' + sub)
        os.makedirs(reportsdir, exist_ok=True)


        subject_epochs = {}
        rawtypes = {'.set': mne.io.read_raw_eeglab, '.bdf': mne.io.read_raw_edf}
        for fname in sorted(glob.glob(join(subjectdir, 'eeg', '*'))):
            _, ext = splitext(fname)
            if ext not in rawtypes.keys():
                continue
            sub, ses, task, run = filename2tuple(basename(fname))

            print('\nProcessing raw file: ' + basename(fname))

            # read data
            raw = rawtypes[ext](fname, preload=True, verbose=False)
            events = mne.find_events(raw)  # alternative: find_events(raw, consecutive=False, min_duration=0.005)

            # Set channel types and select reference channels
            channelFile = fname.replace('eeg' + ext, 'channels.tsv')
            channels = pandas.read_csv(channelFile, index_col='name', sep='\t')
            bids2mne = {
                'MISC': 'misc',
                'EEG': 'eeg',
                'VEOG': 'eog',
                'TRIG': 'stim',
                'REF': 'eeg',
            }
            channels['mne'] = channels.type.replace(bids2mne)
            
            # the below fails if the specified channels are not in the data
            raw.set_channel_types(channels.mne.to_dict())

            # set bad channels
            raw.info['bads'] = channels[channels.status=='bad'].index.tolist()

            # pick channels to use for epoching
            epoching_picks = mne.pick_types(raw.info, eeg=True, eog=False, stim=False, exclude='bads')


            # Filtering
            #raw.filter(l_freq=0.05, h_freq=40, fir_design='firwin')

            montage = mne.channels.read_montage(guess_montage(raw.ch_names))
            print(montage)
            raw.set_montage(montage)

            # plot raw data
            nchans = len(raw.ch_names)
            pick_channels = numpy.arange(0, nchans, numpy.floor(nchans/20)).astype(int)
            start = numpy.round(raw.times.max()/2)
            fig = raw.plot(start=start, order=pick_channels)
            fname_plot = 'sub-{}_ses-{}_task-{}_run-{}_raw.png'.format(sub, ses, task, run)
            fig.savefig(join(reportsdir, fname_plot))

            # Set reference
            refChannels = channels[channels.type=='REF'].index.tolist()
            raw.set_eeg_reference(ref_channels=refChannels)

            ##  epoching
            epochs_params = dict(
                events=events,
                tmin=-0.1,
                tmax=0.8,
                reject=None,  # dict(eeg=250e-6, eog=150e-6)
                picks=epoching_picks,
                detrend=0,
            )
            file_epochs = mne.Epochs(raw, preload=True, **epochs_params)
            file_epochs.drop_channels(refChannels)

            # autoreject (under development)
            ar = AutoReject(n_jobs=4)
            clean_epochs = ar.fit_transform(file_epochs)

            rejectlog = ar.get_reject_log(clean_epochs)
            fname_log = 'sub-{}_ses-{}_task-{}_run-{}_reject-log.npz'.format(sub, ses, task, run)
            save_rejectlog(join(reportsdir, fname_log), rejectlog)
            fig = plot_rejectlog(rejectlog)
            fname_plot = 'sub-{}_ses-{}_task-{}_run-{}_bad-epochs.png'.format(sub, ses, task, run)
            fig.savefig(join(reportsdir, fname_plot))


            # store for now
            subject_epochs[(ses, task, run)] = clean_epochs

            # create evoked plots
            conds = list(clean_epochs.event_id.keys())
            selected_conds = random.sample(conds, min(len(conds), 6))
            picks = mne.pick_types(clean_epochs.info, eeg=True)
            for cond in selected_conds:
                evoked = clean_epochs[cond].average()
                fname_plot = 'sub-{}_ses-{}_task-{}_run-{}_evoked-{}.png'.format(sub, ses, task, run, cond)
                fig = evoked.plot_joint(picks=picks)
                fig.savefig(join(reportsdir, fname_plot))



        sessSeg = 0
        sessions = sorted(list(set([k[sessSeg] for k in subject_epochs.keys()])))
        for session in sessions:
            taskSeg = 1
            tasks = list(set([k[taskSeg] for k in subject_epochs.keys() if k[sessSeg]==session]))
            for task in tasks:
                print('\nGathering epochs for session {} task {}'.format(session, task))
                epochs_selection = [v for (k, v) in subject_epochs.items() if k[:2]==(session, task)]

                task_epochs = mne.epochs.concatenate_epochs(epochs_selection)
                
                # downsample if configured to do so
                # important to do this after concatenation because 
                # downsampling may cause rejection for 'TOOSHORT'
                if config['downsample'] < task_epochs.info['sfreq']:
                    task_epochs = task_epochs.copy().resample(config['downsample'], npad='auto')

                ext = config['out_file_format']
                fname = join(derivdir, 'sub-{}_ses-{}_task-{}_epo.{}'.format(sub, session, task, ext))
                variables = {
                    'epochs': task_epochs.get_data(),
                    'events': task_epochs.events,
                    'timepoints': task_epochs.times
                }
                if ext == 'fif':
                    task_epochs.save(fname)
                elif ext == 'mat':
                    scipy.io.savemat(fname, mdict=variables)
                elif ext == 'npy':
                    numpy.savez(fname, **variables)
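
A typical entry point for the function above (a sketch; '/data' mirrors the container-style default in the signature):

if __name__ == '__main__':
    run_preproc(datadir='/data')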
Example no. 7
evoked.info['bads'] = ['MEG 2443']
evoked_clean.info['bads'] = ['MEG 2443']

# %%
# Let us plot the results.

import matplotlib.pyplot as plt  # noqa
set_matplotlib_defaults(plt)

fig, axes = plt.subplots(2, 1, figsize=(6, 6))

for ax in axes:
    ax.tick_params(axis='x', which='both', bottom=False, top=False)
    ax.tick_params(axis='y', which='both', left=False, right=False)

ylim = dict(grad=(-170, 200))
evoked.pick_types(meg='grad', exclude=[])
evoked.plot(exclude=[], axes=axes[0], ylim=ylim, show=False)
axes[0].set_title('Before autoreject')
evoked_clean.pick_types(meg='grad', exclude=[])
evoked_clean.plot(exclude=[], axes=axes[1], ylim=ylim)
axes[1].set_title('After autoreject')
plt.tight_layout()

# %%
# To top things off, we can also visualize the bad sensors for each trial using
# a heatmap.

ar.get_reject_log(epochs['Auditory/Left']).plot()
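
Beyond plotting, the reject log can be used to index out the flagged trials directly; a short sketch (bad_epochs is a boolean mask over the epochs):

reject_log = ar.get_reject_log(epochs)
epochs_good = epochs[~reject_log.bad_epochs]  # keep only epochs not marked bad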
Example no. 8
def AR_local(cleaned_epochs_ICA: list, strategy: str = 'union',
             threshold: float = 50.0, verbose: bool = False) -> list:
    """
    Applies local Autoreject to repair or reject bad epochs.

    Arguments:
        cleaned_epochs_ICA: list of Epochs after global Autoreject and ICA.
        strategy: more or less generous strategy to reject bad epochs: 'union'
          or 'intersection'. 'union' rejects any epoch that is bad for subject 1
          or subject 2, whereas 'intersection' rejects only the bad epochs shared
          between subjects, tries to repair the remaining bad epochs per subject,
          rejects the non-repairable ones per subject and finally equalizes the
          number of epochs between subjects. Set to 'union' by default.
        threshold: maximum accepted percentage of removed epochs. Above this
          threshold, the data are considered too short a sample for further
          analyses. Set to 50.0 by default.
        verbose: whether to plot the data before and after AR; boolean, set to
          False by default (keep verbose=False until the next Autoreject update).

    Note:
        To reject or repair epochs, the parameters can be set more or less
        conservatively,
        see http://autoreject.github.io/generated/autoreject.AutoReject.

    Returns:
        cleaned_epochs_AR: list of Epochs after local Autoreject.
        dic_AR: dictionary with the percentage of rejected epochs
          for each subject and for the dyad.
    """
    bad_epochs_AR = []
    AR = []
    dic_AR = {}
    dic_AR['strategy'] = strategy
    dic_AR['threshold'] = threshold

    # default values for n_interpolates and consensus_percs
    n_interpolates = np.array([1, 4, 32])
    consensus_percs = np.linspace(0, 1.0, 11)
    # more generous values
    # n_interpolates = np.array([16, 32, 64])
    # n_interpolates = np.array([1, 4, 8, 16, 32, 64])
    # consensus_percs = np.linspace(0.5, 1.0, 11)

    for clean_epochs in cleaned_epochs_ICA:  # per subj

        picks = mne.pick_types(
            clean_epochs[0].info,
            meg=False,
            eeg=True,
            stim=False,
            eog=False,
            exclude=[])

        ar = AutoReject(n_interpolates, consensus_percs, picks=picks,
                        thresh_method='random_search', random_state=42,
                        verbose='tqdm_notebook')
        AR.append(ar)

        # fitting AR to get bad epochs
        ar.fit(clean_epochs)
        reject_log = ar.get_reject_log(clean_epochs, picks=picks)
        bad_epochs_AR.append(reject_log)

    # combining the bad epochs of the two subjects of the dyad
    log1 = bad_epochs_AR[0]
    log2 = bad_epochs_AR[1]

    bad1 = np.where(log1.bad_epochs == True)
    bad2 = np.where(log2.bad_epochs == True)

    if strategy == 'union':
        bad = list(set(bad1[0].tolist()).union(set(bad2[0].tolist())))
    elif strategy == 'intersection':
        bad = list(set(bad1[0].tolist()).intersection(set(bad2[0].tolist())))
    else:
        raise TypeError("strategy must be 'union' or 'intersection'")

    # storing the percentage of epochs rejection
    dic_AR['S1'] = float((len(bad1[0].tolist())/len(cleaned_epochs_ICA[0]))*100)
    dic_AR['S2'] = float((len(bad2[0].tolist())/len(cleaned_epochs_ICA[1]))*100)

    # picking good epochs for the two subj
    cleaned_epochs_AR = []
    for clean_epochs in cleaned_epochs_ICA:  # per subj
        # keep a copy of the original data
        clean_epochs_ep = copy.deepcopy(clean_epochs)
        clean_epochs_ep = clean_epochs_ep.drop(indices=bad)
        # interpolating bads or removing epochs
        ar = AR[cleaned_epochs_ICA.index(clean_epochs)]
        clean_epochs_AR = ar.transform(clean_epochs_ep)
        cleaned_epochs_AR.append(clean_epochs_AR)

    if strategy == 'intersection':
        # equalizing epochs length between two participants
        mne.epochs.equalize_epoch_counts(cleaned_epochs_AR)

    dic_AR['dyad'] = float(((len(cleaned_epochs_ICA[0])-len(cleaned_epochs_AR[0]))/len(cleaned_epochs_ICA[0]))*100)
    if dic_AR['dyad'] >= threshold:
        raise TypeError('percentage of rejected epochs above threshold!')
    if verbose:
        print('%s percent of bad epochs' % dic_AR['dyad'])

    # Visualization before/after AR
    evoked_before = []
    for clean_epochs in cleaned_epochs_ICA:  # per subj
        evoked_before.append(clean_epochs.average())

    evoked_after_AR = []
    for clean in cleaned_epochs_AR:
        evoked_after_AR.append(clean.average())

    if verbose:
        for i, j in zip(evoked_before, evoked_after_AR):
            fig, axes = plt.subplots(2, 1, figsize=(6, 6))
            for ax in axes:
                ax.tick_params(axis='x', which='both', bottom=False, top=False)
                ax.tick_params(axis='y', which='both', left=False, right=False)

            ylim = dict(eeg=(-170, 200))
            i.pick_types(eeg=True, exclude=[])
            i.plot(exclude=[], axes=axes[0], ylim=ylim, show=False)
            axes[0].set_title('Before autoreject')
            j.pick_types(eeg=True, exclude=[])
            j.plot(exclude=[], axes=axes[1], ylim=ylim)
            # Problem: the title is not displayed for the second axis!
            axes[1].set_title('After autoreject')
            plt.tight_layout()

    return cleaned_epochs_AR, dic_AR
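
A dyad-level usage sketch with the stricter 'intersection' strategy (epochs_sub1 and epochs_sub2 stand in for the per-subject Epochs after global Autoreject and ICA):

cleaned_epochs_AR, dic_AR = AR_local([epochs_sub1, epochs_sub2],
                                     strategy='intersection',
                                     threshold=50.0,
                                     verbose=False)
print(dic_AR)  # rejection percentages per subject and for the dyad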