Exemplo n.º 1
0
def test_ransac():
    """Basic tests for Ransac.

    Checks that fitting on a single channel type (magnetometers) keeps the
    number of epochs, and that invalid inputs are rejected: a bare NumPy
    array (AttributeError), picks spanning both MEG sensor types, and picks
    mixing EEG with stim channels (both ValueError).
    """
    event_id = {'Visual/Left': 3}
    tmin, tmax = -0.2, 0.5

    events = mne.find_events(raw)
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                        baseline=(None, 0), decim=8,
                        reject=None, preload=True)
    # normal case: magnetometers only
    picks = mne.pick_types(epochs.info, meg='mag', eeg=False, stim=False,
                           eog=False, exclude=[])

    ransac = Ransac(picks=picks)
    epochs_clean = ransac.fit_transform(epochs)
    assert_true(len(epochs_clean) == len(epochs))
    # Pass numpy instead of epochs
    X = epochs.get_data()
    assert_raises(AttributeError, ransac.fit, X)

    # should not contain both channel types (meg=True selects mag + grad)
    picks = mne.pick_types(epochs.info, meg=True, eeg=False, stim=False,
                           eog=False, exclude=[])
    ransac = Ransac(picks=picks)
    assert_raises(ValueError, ransac.fit, epochs)

    # should not contain other channel types (EEG + stim).
    # Picks are positional indices into ``epochs``, so compute them from
    # ``epochs.info`` like the cases above rather than from ``raw.info``.
    picks = mne.pick_types(epochs.info, meg=False, eeg=True, stim=True,
                           eog=False, exclude=[])
    ransac = Ransac(picks=picks)
    assert_raises(ValueError, ransac.fit, epochs)
Exemplo n.º 2
0
def test_ransac():
    """Exercise Ransac input validation and one successful fit_transform."""
    ransac = Ransac()

    event_id = {'Visual/Left': 3}
    tmin, tmax = -0.2, 0.5
    events = mne.find_events(raw)
    eeg_names = [u'EEG %03d' % num for num in range(1, 15)]

    def _epochs(sel, **extra):
        # Every Epochs object below shares these settings; only the channel
        # selection (and, at the end, preloading) differs.
        return mne.Epochs(raw, events, event_id, tmin, tmax,
                          picks=sel, baseline=(None, 0), decim=8,
                          reject=None, add_eeg_ref=False, **extra)

    meg_and_eeg = mne.pick_types(raw.info, meg=True, eeg=True, stim=False,
                                 eog=False, include=eeg_names, exclude=[])
    epochs = _epochs(meg_and_eeg)

    # A plain ndarray is rejected.
    assert_raises(ValueError, ransac.fit, epochs.get_data())
    # MEG and EEG channels together are rejected.
    assert_raises(ValueError, ransac.fit, epochs)

    # EEG mixed with stim channels is rejected as well.
    eeg_and_stim = mne.pick_types(raw.info, meg=False, eeg=True, stim=True,
                                  eog=False, include=eeg_names, exclude=[])
    assert_raises(ValueError, ransac.fit, _epochs(eeg_and_stim))

    # A single channel type fits and keeps every epoch.
    eeg_only = mne.pick_types(raw.info, meg=False, eeg=True)
    epochs = _epochs(eeg_only, preload=True)
    epochs_clean = ransac.fit_transform(epochs)
    assert_true(len(epochs_clean) == len(epochs))
Exemplo n.º 3
0
def mne_ransac_bad_channels(raw, overwrite=False):
    """Mark bad channels in ``raw`` with RANSAC, caching the result as a TSV.

    When the cache file does not exist (or ``overwrite`` is true), the data
    are epoched with ``mne_epoch``, a seeded ``Ransac`` is fitted, and
    ``raw.info['bads']`` is set to the detected channels. A copy of the
    channels table (path derived by ``tb.fileparts``) is then written to the
    cache path, with ``status`` set to ``'bad'`` for the detected channels.
    Otherwise the cached table is read back and every channel whose
    ``status`` is ``'bad'`` is marked bad.

    Parameters
    ----------
    raw : mne.io.Raw
        Continuous data; ``raw.filenames[0]`` is used to derive both TSV paths.
    overwrite : bool
        Recompute RANSAC even if the cached TSV already exists.

    Returns
    -------
    raw
        The same object, with ``info['bads']`` updated in place.
    """
    bids_chan_file = tb.fileparts(raw.filenames[0], '_channels.tsv', -4)
    ransacfile = tb.fileparts(raw.filenames[0], '_channels_ransac.tsv')
    if overwrite or not pathlib.Path(ransacfile).is_file():
        epochs = mne_epoch(raw).drop_bad()
        epochs.load_data()
        ransac = Ransac(random_state=999)  # fixed seed -> reproducible cache
        ransac.fit(epochs)
        # Copy so later edits to raw.info['bads'] cannot mutate the fitted
        # estimator's ``bad_chs_`` list (and vice versa).
        raw.info['bads'] = list(ransac.bad_chs_)
        chans = pd.read_csv(bids_chan_file, delimiter='\t')
        chans.loc[chans['name'].isin(ransac.bad_chs_), 'status'] = 'bad'
        # index=False keeps the same column layout as the channels file that
        # was read (no extra unnamed index column in the TSV).
        chans.to_csv(ransacfile, sep='\t', index=False)
    else:
        chans = pd.read_csv(ransacfile, delimiter='\t')
        # NOTE(review): this also includes channels that were already 'bad'
        # in the source channels table, whereas the fresh-fit branch above
        # uses only the RANSAC detections -- confirm which is intended.
        raw.info['bads'] = list(chans.loc[chans['status'] == 'bad', 'name'])
    return raw
Exemplo n.º 4
0
def robust_avg_ref(epochs, ransac_parameters, apply=True):
    """
    Create a robust average reference by first interpolating the bad channels
    to exclude outliers. The reference is applied as a projection. Return
    epochs with reference projection applied if apply=True

    Parameters
    ----------
    epochs : mne.Epochs
        Data to re-reference; its ``info["projs"]`` is extended in place.
    ransac_parameters : dict
        Keyword arguments forwarded to ``Ransac``.
    apply : bool
        If True, call ``epochs.apply_proj()`` after adding the projector.

    Returns
    -------
    epochs
        The same object that was passed in.

    Raises
    ------
    RuntimeError
        If ``set_eeg_reference`` did not add a new projector to the
        interpolated copy.
    """
    ransac = Ransac(**ransac_parameters, verbose="tqdm")
    # Fit/interpolate on a copy so RANSAC never touches the caller's data;
    # only the resulting reference projector is transferred back.
    epochs_tmp = ransac.fit_transform(epochs.copy())
    n_projs_before = len(epochs_tmp.info["projs"])
    set_eeg_reference(epochs_tmp, ref_channels="average", projection=True)
    # Take the projector that was just added instead of assuming the list
    # was empty beforehand (index 0 would pick up any pre-existing SSP).
    new_projs = epochs_tmp.info["projs"][n_projs_before:]
    if not new_projs:
        raise RuntimeError(
            "set_eeg_reference did not add a projector; the data may "
            "already carry an average reference projection.")
    robust_avg_proj = new_projs[0]
    del epochs_tmp
    epochs.info["projs"].append(robust_avg_proj)
    if apply:
        epochs.apply_proj()
    return epochs
Exemplo n.º 5
0
def interpolate_bads(epochs, ransac_parameters):
    """Interpolate RANSAC-detected bad channels and save a before/after plot.

    The figure (evoked response before and after RANSAC, with the detected
    channels marked bad but still drawn) is written to
    ``_out_folder / "interpolate_bad_channels.pdf"``.

    Parameters
    ----------
    epochs : mne.Epochs
        Input data.
    ransac_parameters : dict
        Keyword arguments forwarded to ``Ransac``.

    Returns
    -------
    mne.Epochs
        The epochs returned by ``Ransac.fit_transform``.
    """
    ransac = Ransac(**ransac_parameters, verbose="tqdm")
    evoked_before = epochs.average()  # for plotting
    epochs = ransac.fit_transform(epochs)
    evoked_after = epochs.average()  # for plotting
    # Mark the detected channels on both evokeds; exclude=[] below keeps
    # them drawn so they can be compared.
    evoked_before.info["bads"] = ransac.bad_chs_
    evoked_after.info["bads"] = ransac.bad_chs_

    fig, ax = plt.subplots(2)
    try:
        evoked_before.plot(exclude=[], axes=ax[0], show=False)
        ax[0].set_title('Before RANSAC')
        evoked_after.plot(exclude=[], axes=ax[1], show=False)
        ax[1].set_title('After RANSAC')
        fig.tight_layout()
        fig.savefig(_out_folder / Path("interpolate_bad_channels.pdf"),
                    dpi=800)
    finally:
        # Close the figure created here explicitly (plt.close() with no
        # argument closes whatever figure is current), even if plotting or
        # saving fails.
        plt.close(fig)
    return epochs
Exemplo n.º 6
0
def test_ransac():
    """Basic tests for Ransac.

    Checks that fitting on a single channel type (magnetometers) keeps the
    number of epochs, and that invalid inputs are rejected: a bare NumPy
    array (AttributeError), picks spanning both MEG sensor types, and picks
    mixing EEG with stim channels (both ValueError).
    """
    event_id = {'Visual/Left': 3}
    tmin, tmax = -0.2, 0.5

    events = mne.find_events(raw)
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                        baseline=(None, 0), decim=8,
                        reject=None, preload=True)
    # normal case: magnetometers only
    picks = mne.pick_types(epochs.info, meg='mag', eeg=False, stim=False,
                           eog=False, exclude=[])

    ransac = Ransac(picks=picks)
    epochs_clean = ransac.fit_transform(epochs)
    assert len(epochs_clean) == len(epochs)
    # Pass numpy instead of epochs
    X = epochs.get_data()
    with pytest.raises(AttributeError):
        ransac.fit(X)

    # should not contain both channel types (meg=True selects mag + grad)
    picks = mne.pick_types(epochs.info, meg=True, eeg=False, stim=False,
                           eog=False, exclude=[])
    ransac = Ransac(picks=picks)
    with pytest.raises(ValueError):
        ransac.fit(epochs)

    # should not contain other channel types (EEG + stim).
    # Picks are positional indices into ``epochs``, so compute them from
    # ``epochs.info`` like the cases above rather than from ``raw.info``.
    picks = mne.pick_types(epochs.info, meg=False, eeg=True, stim=True,
                           eog=False, exclude=[])
    ransac = Ransac(picks=picks)
    with pytest.raises(ValueError):
        ransac.fit(epochs)
Exemplo n.º 7
0
    def applyRansac(self):
        '''
        Detect bad EEG channels with RAndom SAmple Consensus (RANSAC) and mark
        them in ``self.info['bads']``.

        Only EEG channels that are not already marked bad are given to
        RANSAC. Newly detected channels are appended to the existing
        ``self.info['bads']``; channels that were already marked bad stay
        marked.

        Returns
        - - - -
        None. ``self.info['bads']`` is updated in place and the detected
        channel names are printed.

        '''

        # RANSAC runs on EEG channels only, skipping channels already marked bad
        picks = mne.pick_types(self.info, eeg=True, exclude='bads')

        # Only the list of detected channels is used, so fit without the
        # interpolation step that fit_transform would add on top.
        ransac = Ransac(verbose=False, picks=picks, n_jobs=1)
        ransac.fit(self)
        print('The following electrodes are selected as bad by Ransac:')
        print('\n'.join(ransac.bad_chs_))
        # Append rather than overwrite: channels excluded from ``picks``
        # above were never evaluated, so replacing the list would silently
        # un-mark them.
        already_bad = list(self.info['bads'])
        self.info['bads'] = already_bad + [
            ch for ch in ransac.bad_chs_ if ch not in already_bad]
Exemplo n.º 8
0
                detrend=0,
                preload=True)
picks = mne.pick_types(epochs.info, meg='grad', eeg=False, stim=False,
                       eog=False, include=[], exclude=[])

###############################################################################
# Import ``Ransac`` and fit it to the gradiometer data with the usual
# ``fit_transform`` call, which also interpolates the detected channels.
from autoreject import Ransac  # noqa
from autoreject.utils import interpolate_bads  # noqa

ransac = Ransac(picks=picks, n_jobs=1, verbose='progressbar')
epochs_clean = ransac.fit_transform(epochs)

###############################################################################
# ``Ransac`` stores the names of the channels it flagged in ``bad_chs_``.

print('\n'.join(ransac.bad_chs_))

###############################################################################
# Average the data before and after interpolation to obtain two ``evoked``
# objects for comparison.

evoked = epochs.average()
evoked_clean = epochs_clean.average()

###############################################################################
# We will manually mark the bad channels just for plotting.
    # Step 2: Estimate ICA
    # Apply HPF to all channels and a 60Hz Notch filter to eogs
    raw.filter(preprocess_options['ica_highpass'], None,
               skip_by_annotation=['boundary'])
    raw.filter(None, 40, picks=['eog'])
    raw.notch_filter([60, 120], picks=['eog'])

    # Make ICA Epochs
    epochs = mne.Epochs(raw, events, event_id=event_id,
                        tmin=preprocess_options['tmin'],
                        tmax=preprocess_options['tmax'],
                        baseline=(None, None), reject=None, preload=True)
    epochs.set_montage(bv_montage)

    # Autodetect bad channels
    ransac = Ransac(verbose=False, n_jobs=4,
                    min_corr=.60, unbroken_time=.6)
    ransac.fit(epochs.copy().filter(None, 40))
    if len(ransac.bad_chs_):
        for chan in ransac.bad_chs_:
            epochs.info['bads'].append(chan)
    print(f'RANSAC Bad Channels: {ransac.bad_chs_}')

    # Save RANSAC
    ransac_file = deriv_path / f'{sub}_task-{task}_ref_FCz_ransac.pkl'
    with open(ransac_file, 'wb') as f:
        pickle.dump(ransac, f)

    # Make RANSAC json
    json_info = {
        'Description': 'RANSAC object computed from epoched data',
        'parameteres': {
Exemplo n.º 10
0
def preprocessing(file):
    """Preprocess one EDF recording and write SSVEP SNR / FOOOF results.

    Steps performed below:
    read the EDF (biosemi64 montage, EXG1-EXG8 as EOG); store ids sliced
    from the file name in ``raw.info['subject_info']``; save raw PSD and
    trace plots under ``FOOOF/<pid>/``; epoch -1..15 s around events with
    codes below 255; re-reference with ``set_eeg_reference()`` defaults,
    resample to 256 Hz and apply projectors; optionally interpolate
    RANSAC-detected bad EEG channels; compute SSVEP spectra/SNR for 1.2 Hz
    and 6 Hz; fit FOOOF per condition and channel; and write the merged
    table to ``FOOOF/<pid>/output_<filename>.csv``.

    Parameters
    ----------
    file : path-like with a ``.name`` attribute
        EDF file to process; presumably a ``pathlib.Path`` (``file.name`` is
        sliced to get the participant id and group).

    Returns
    -------
    None
        Results are only written to disk.
    """
    print("setting default objects")
    # avoid MNE being too verbose
    mne.set_log_level('ERROR')
    # Hard-coded switch for the RANSAC bad-channel step further down.
    use_autoreject = True
    # NOTE(review): the five lists below are never used in this function.
    electrodes_snr = []
    ssveps = []
    snrs_faces = []
    snrs_ob = []
    Report_pp = []

    # A single FOOOF model, refitted for every condition/channel pair below.
    fooof_object = fooof.FOOOF(background_mode='fixed', )

    #Start file by file analysis
    print("reading raw EDF")
    raw = mne.io.read_raw_edf(file,
                              montage=mne.channels.read_montage('biosemi64'),
                              eog=[f'EXG{n}' for n in range(1, 9)])
    # Participant id and group come from fixed character positions in the
    # file name. NOTE(review): assumes a fixed naming scheme -- confirm.
    raw.info['subject_info'] = {
        'pid': file.name[7:11],
        'group': file.name[3:6],
        'filename': file.name
    }

    # make raw plots
    file_id = raw.info['subject_info']['filename']
    pid = raw.info['subject_info']['pid']
    path = ('FOOOF/%s' % (pid))

    if not os.path.exists(path):
        os.makedirs(path)

    # NOTE(review): clf() clears each figure but does not close it; when
    # many files are processed, open figures may accumulate.
    plot = raw.plot_psd()
    plot.savefig('%s/rawplot_psd_%s.png' % (path, file_id))
    plot.clf()
    # NOTE(review): block=True presumably makes this call wait until the
    # browser window is closed -- confirm this is wanted in batch runs.
    plot = raw.plot(duration=30., start=60., block=True)
    plot.savefig('%s/rawplot_%s.png' % (path, file_id))
    plot.clf()

    #Define Events
    events = mne.find_events(raw)
    # Keep only events whose trigger code is below 255.
    events = events[events[:, 2] < 255, :]

    # NOTE(review): this value is stored in info but never read below --
    # confirm it is needed.
    raw.info['events'] = mne.find_events(raw, stim_channel='STI 014')

    # Epoch -1..15 s around every kept event, re-reference with the default
    # set_eeg_reference() arguments, downsample to 256 Hz, apply projectors.
    epochs = (mne.Epochs(
        raw,
        events,
        tmin=-1,
        tmax=15,
        preload=True,
    ).set_eeg_reference().load_data().resample(256).apply_proj())

    #Evoked Plots
    evoked = epochs.average()
    fig = mne.viz.plot_evoked(evoked, spatial_colors=True, selectable=True)
    fig.savefig('%s/evoked_plot_%s.png' % (path, file_id))
    fig.clf()

    if use_autoreject:
        #assess sensors
        print("Running ransac autoreject")
        # RANSAC restricted to EEG channels; fit_transform returns epochs
        # with the detected bad channels interpolated.
        picks = mne.pick_types(
            epochs.info,
            eeg=True,
        )
        ransac = Ransac(verbose=False, picks=picks, n_jobs=1)
        epochs = ransac.fit_transform(epochs)

    # SSVEP spectra between 0.5 and 45 Hz for the two stimulation
    # frequencies 1.2 and 6; noisebandwidth=3 (presumably Hz) is the
    # neighbourhood used for SNR.
    ssvep = ssvepy.Ssvep(
        epochs,
        [1.2, 6],
        compute_tfr=False,
        fmin=0.5,
        fmax=45,
        noisebandwidth=3,
    )

    # Split epochs by index (< 200 vs >= 200), average within each group and
    # rename the grouping dimension to "faces".
    # NOTE(review): presumably the first 200 epochs are the face condition --
    # confirm against the stimulus protocol.
    ssvep.psd = (ssvep.psd.groupby(
        ssvep.psd.coords["epoch"] < 200).mean("epoch").rename(
            {"epoch": "faces"}))
    # Recompute SNR on the condition-averaged spectrum and keep a copy of it
    # before the FOOOF background is subtracted in the loop below.
    ssvep.snr = ssvep._get_snr(ssvep.psd.coords["frequency"].data)
    ssvep.original_psd = ssvep.psd.copy()

    import scipy.linalg

    # One NaN-initialised array per FOOOF output, shaped like the PSD
    # interpolated at frequencies [1.2, 6] (presumably faces x channel x
    # frequency).
    fooof_data = xr.Dataset({
        key: xr.full_like(ssvep.psd.interp(frequency=[1.2, 6]), np.nan)
        for key in ['peak_amp', 'r_squared', 'offset', 'slope']
    })

    for faces_present, cond_data in ssvep.psd.groupby("faces"):

        for channel, data in cond_data.groupby("channel"):

            try:
                print("Running fooof on channel: {}".format(channel))
                fooof_object.fit(data.coords["frequency"].data,
                                 data.data.squeeze())
            except scipy.linalg.LinAlgError:
                # Skip this channel; its fooof_data entries stay NaN.
                print(" Fooof_object failed for %s" % (file_id))
                continue

            # Subtract the fitted background (10** of the private _bg_fit,
            # presumably a log10-power fit) from this condition/channel PSD.
            ssvep.psd.loc[faces_present,
                          channel, :] -= (10**fooof_object._bg_fit)

            # Peak amplitude at each stimulation frequency (up to 3
            # harmonics considered by match_peaks), plus fit diagnostics.
            fooof_data.peak_amp.loc[faces_present, channel,
                                    1.2] = match_peaks(1.2,
                                                       fooof_object,
                                                       max_harmonics=3)
            fooof_data.peak_amp.loc[faces_present, channel,
                                    6] = match_peaks(6,
                                                     fooof_object,
                                                     max_harmonics=3)
            fooof_data.r_squared.loc[faces_present,
                                     channel, :] = fooof_object.r_squared_
            fooof_data.offset.loc[
                faces_present, channel, :] = fooof_object.background_params_[0]
            fooof_data.slope.loc[
                faces_present, channel, :] = fooof_object.background_params_[1]

    # average the SNRs over the first three harmonics of each stimulation
    # frequency (1.2/2.4/3.6 and 6/12/18) and label them 1.2 and 6:
    snr_data = xr.concat(
        (ssvep.snr.interp(
            frequency=[1.2, 2.4, 3.6]).mean("frequency").expand_dims(dim={
                "frequency": [1.2]
            }).transpose("faces", "channel", "frequency"),
         ssvep.snr.interp(frequency=[6, 12, 18]).mean("frequency").expand_dims(
             dim={
                 "frequency": [6]
             }).transpose("faces", "channel", "frequency")), "frequency")

    # convert to df: one row per (frequency, faces, channel), tagged with the
    # participant id and group
    participant_df = (pd.merge(snr_data.to_dataframe(name='snr').reset_index(),
                               fooof_data.to_dataframe().reset_index(),
                               on=['frequency', 'faces', 'channel']).assign(
                                   participant=raw.info['subject_info']['pid'],
                                   group=raw.info['subject_info']['group'],
                               ))

    participant_df.to_csv('%s/output_%s.csv' % (path, file_id), header=True)
Exemplo n.º 11
0
                detrend=0,
                preload=True)
picks = mne.pick_types(epochs.info, meg='grad', eeg=False, stim=False,
                       eog=False, include=[], exclude=[])

###############################################################################
# Import ``Ransac`` and fit it to the gradiometer data with the usual
# ``fit_transform`` call, which also interpolates the detected channels.
from autoreject import Ransac  # noqa
from autoreject.utils import interpolate_bads  # noqa

ransac = Ransac(picks=picks, n_jobs=1, verbose=True)
epochs_clean = ransac.fit_transform(epochs)

###############################################################################
# ``Ransac`` stores the names of the channels it flagged in ``bad_chs_``.

print('\n'.join(ransac.bad_chs_))

###############################################################################
# Average the data before and after interpolation to obtain two ``evoked``
# objects for comparison.

evoked = epochs.average()
evoked_clean = epochs_clean.average()

###############################################################################
# We will manually mark the bad channels just for plotting.
Exemplo n.º 12
0
raw.info['projs'] = []  # drop projectors so none are applied during interpolation
epochs = Epochs(raw, events, event_id, tmin, tmax, baseline=(None, 0),
                reject=None, verbose=False, detrend=0, preload=True)
picks = mne.pick_types(epochs.info, meg='grad', eeg=False, stim=False,
                       eog=False, include=[], exclude=[])


###############################################################################
# Import ``Ransac`` and fit it to the gradiometer data with the usual
# ``fit_transform`` call, which also interpolates the detected channels.
from autoreject import Ransac  # noqa
from autoreject.utils import interpolate_bads  # noqa

ransac = Ransac(picks=picks, n_jobs=1, verbose='progressbar')
epochs_clean = ransac.fit_transform(epochs)

###############################################################################
# ``Ransac`` stores the names of the channels it flagged in ``bad_chs_``.

print('\n'.join(ransac.bad_chs_))

###############################################################################
# Average the data before and after interpolation to obtain two ``evoked``
# objects for comparison.

evoked = epochs.average()
evoked_clean = epochs_clean.average()

###############################################################################
# We will manually mark the bad channels just for plotting.
Exemplo n.º 13
0
def run_epochs(subject):
    """Run epochs.

    Transform raw data into epochs, match them with the experimental
    conditions from the E-Prime log, then detect and interpolate bad
    channels with RANSAC and save diagnostic figures. No amplitude-based
    epoch rejection is performed here.

    Parameters
    ----------
    *subject: string
        The participant reference

    Files written to the '3_epochs' directory:
    ``<subject>-epo.fif`` (epochs before RANSAC), ``<subject>-evoked.png``
    (evoked response before/after RANSAC), ``<subject>-heatmap.png``
    (per-trial bad sensors) and ``<subject>clean-epo.fif`` (RANSAC-cleaned
    epochs).
    """
    # Load filtered data
    input_path = root + '/2_rawfilter/' + subject + '-raw.fif'
    raw = mne.io.read_raw_fif(input_path)

    # Load e-prime df
    eprime_df = data_path + subject + '/' + subject + fname['eprime']
    eprime = pd.read_csv(eprime_df, skiprows=1, sep='\t')
    eprime = eprime[behav_var]

    # Remove the training rows that follow each pause in the TNT phase
    eprime = eprime.drop(eprime.index[[97, 195, 293]])
    eprime.reset_index(inplace=True)

    # Find stim presentation in raw data
    events = mne.find_events(raw, stim_channel='STI 014')

    # Compensate for a 15 ms presentation delay (measured manually with a
    # photodiode)
    events[:, 0] += int(.015 * raw.info['sfreq'])

    # Keep only Obj Pres triggers
    events = events[events[:, 2] == 7, :]

    # Match stim presentation with conditions from eprime df; row i of the
    # log is taken to describe the i-th kept trigger. Code 5 marks any other
    # condition and is not part of ``event_id`` below.
    for i in range(len(events)):
        if eprime['Cond1'][i] == 'Think':
            if eprime['Cond2'][i] == 'Emotion':
                events[i, 2] = 1
            else:
                events[i, 2] = 2
        elif eprime['Cond1'][i] == 'No-Think':
            if eprime['Cond2'][i] == 'Emotion':
                events[i, 2] = 3
            else:
                events[i, 2] = 4
        else:
            events[i, 2] = 5
    # Set event id (named ``event_id`` so the builtin ``id`` is not shadowed)
    event_id = {'Think/EMO': 1, 'Think/NEU': 2,
                'No-Think/EMO': 3, 'No-Think/NEU': 4}

    # Epoch raw data
    tmin, tmax = -1.5, 4
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax, preload=True)

    out_prefix = root + '/3_epochs/' + subject

    # Save epochs (before RANSAC)
    epochs.save(out_prefix + '-epo.fif')

    epochs.info['projs'] = list()  # remove proj before interpolating

    ransac = Ransac(verbose='progressbar', n_jobs=1)
    epochs_clean = ransac.fit_transform(epochs)

    evoked = (epochs.average().crop(-0.2, 3.0)
              .apply_baseline(baseline=(None, 0)))
    evoked_clean = (epochs_clean.average().crop(-0.2, 3.0)
                    .apply_baseline(baseline=(None, 0)))

    # Mark the detected channels on both evokeds; exclude=[] below keeps
    # them drawn so they can be compared.
    evoked.info['bads'] = ransac.bad_chs_
    evoked_clean.info['bads'] = ransac.bad_chs_

    # Evoked differences
    fig, axes = plt.subplots(2, 1, figsize=(6, 6))

    for ax in axes:
        # Booleans, not 'off' strings: the 'on'/'off' string form is
        # deprecated in Matplotlib, and a non-empty string is truthy.
        ax.tick_params(axis='x', which='both', bottom=False, top=False)
        ax.tick_params(axis='y', which='both', left=False, right=False)

    ylim = dict(grad=(-200, 200))
    evoked.plot(exclude=[], axes=axes[0], ylim=ylim, show=False)
    axes[0].set_title('Before RANSAC')
    # show=False like the plot above: this runs unattended and only saves
    # the figure, so it must not pop up (or block on) a window.
    evoked_clean.plot(exclude=[], axes=axes[1], ylim=ylim, show=False)
    axes[1].set_title('After RANSAC')
    fig.tight_layout()
    fig.savefig(out_prefix + '-evoked.png')
    plt.close(fig)

    # Heatmap of per-trial bad sensors; sensors run along the y axis and
    # every 10th sensor (starting at the 8th) is labelled.
    ch_names = [epochs.ch_names[ii] for ii in ransac.picks][7::10]
    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.imshow(ransac.bad_log.T, cmap='Reds', interpolation='nearest')
    ax.grid(False)
    ax.set_xlabel('Trials', size=15)
    ax.set_ylabel('Sensors', size=15)
    plt.setp(ax, yticks=range(7, len(ransac.picks), 10), yticklabels=ch_names)
    ax.tick_params(axis=u'both', which=u'both', length=0)
    fig.tight_layout(rect=[None, None, None, 1.1])
    ax.set_title('Bad sensors', size=25)
    fig.savefig(out_prefix + '-heatmap.png')
    plt.close(fig)

    # Save cleaned epochs. NOTE: no '-' before 'clean' -- kept as-is so
    # existing downstream paths continue to match.
    epochs_clean.save(out_prefix + 'clean-epo.fif')