def test_global_autoreject():
    """Test global autoreject."""

    event_id = None
    tmin, tmax = -0.2, 0.5
    events = mne.find_events(raw)

    picks = mne.pick_types(raw.info,
                           meg=True,
                           eeg=True,
                           stim=False,
                           eog=True,
                           exclude=[])
    # raise error if preload is false
    epochs = mne.Epochs(raw,
                        events,
                        event_id,
                        tmin,
                        tmax,
                        picks=picks,
                        baseline=(None, 0),
                        reject=None,
                        preload=False)

    # Test get_rejection_threshold.
    reject1 = get_rejection_threshold(epochs, decim=1, random_state=42)
    reject2 = get_rejection_threshold(epochs, decim=1, random_state=42)
    reject3 = get_rejection_threshold(epochs, decim=2, random_state=42)
    tols = dict(eeg=5e-6, eog=5e-6, grad=10e-12, mag=5e-15)
    assert_true(isinstance(reject1, dict))
    for key, value in list(reject1.items()):
        assert_equal(reject1[key], reject2[key])
        assert_true(abs(reject1[key] - reject3[key]) < tols[key])
Example #2
def run_ica(subject):
    raw_fname = op.join(meg_dir, subject, f'{subject}_audvis-filt_raw_sss.fif')
    annot_fname = op.join(meg_dir, subject, f'{subject}_audvis-annot.fif')
    ica_name = op.join(meg_dir, subject, f'{subject}_audvis-ica.fif')

    raw = mne.io.read_raw_fif(raw_fname)
    if op.isfile(annot_fname):
        annot = mne.read_annotations(annot_fname)
        raw.set_annotations(annot)
    # Because the data were Maxwell filtered,
    # a higher threshold is reasonable.
    n_components = 0.999
    ica = mne.preprocessing.ICA(n_components=n_components)
    picks = mne.pick_types(raw.info, meg=True, eeg=False, eog=True)
    # use autoreject (global) to find the rejection threshold
    tstep = 1
    events = mne.make_fixed_length_events(raw, duration=tstep)
    # do not use baseline correction because autoreject (global) would be used
    even_epochs = mne.Epochs(raw, events, baseline=None, tmin=0, tmax=tstep)
    print(f'Run autoreject (global) on Raw for {subject}')
    reject = get_rejection_threshold(even_epochs,
                                     ch_types=['mag', 'grad'],
                                     verbose=False)
    ica.fit(raw, picks=picks, reject=reject, tstep=tstep)
    ica.save(ica_name)
    print(f'Finished computing ICA for {subject}')
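A hedged follow-up sketch (not part of the original example): reloading and applying the ICA that run_ica() saved. read_ica and apply are standard MNE calls; the excluded component indices are purely hypothetical and would normally come from visual inspection or find_bads_eog/find_bads_ecg.

def apply_saved_ica(subject, exclude=(0, 1)):
    # Hypothetical helper: reload the ICA saved by run_ica() and apply it.
    raw_fname = op.join(meg_dir, subject, f'{subject}_audvis-filt_raw_sss.fif')
    ica_name = op.join(meg_dir, subject, f'{subject}_audvis-ica.fif')
    raw = mne.io.read_raw_fif(raw_fname, preload=True)
    ica = mne.preprocessing.read_ica(ica_name)
    ica.exclude = list(exclude)  # hypothetical components to remove
    return ica.apply(raw)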
Example #3
def ICA_fit(epochs: list, n_components: int, method: str, fit_params: dict,
            random_state: int) -> list:
    """
    Computes global Autorejection to fit Independent Components Analysis
    on Epochs, for each participant.

    Prerequisite: install autoreject
    https://api.github.com/repos/autoreject/autoreject/zipball/master

    Arguments:
        epochs: list of 2 Epochs objects (for each participant).
          Epochs_S1 and Epochs_S2 correspond to a condition and can result
          from the concatenation of Epochs from different experimental
          realisations of the condition (Epochs are MNE objects).
        n_components: the number of principal components that are passed to the
          ICA algorithm during fitting, int. For a first estimation,
          n_components can be set to 15.
        method: the ICA method used, str 'fastica', 'infomax' or 'picard'.
          'fastica' is the most frequently used. Use the fit_params argument to set
           additional parameters. Specifically, if you want extended Infomax, set
           method='infomax' and fit_params=dict(extended=True) (this also works
           for method='picard').
        fit_params: Additional parameters passed to the ICA estimator
           as specified by method. None by default.
        random_state: the parameter used to compute random distributions
          for ICA calculation, int or None. It can be useful to fix
          random_state value to have reproducible results. For 15
          components, random_state can be set to 97, for 20 components to 0
          for example.

    Note:
        If Autoreject and ICA take too much time, change the decim value
        (see MNE documentation).
        Please filter the Epochs between 2 and 30 Hz before ICA fit
        (mne.Epochs.filter(epoch, 2, 30, method='fir')).

    Returns:
        icas: list of Independent Components for each participant (IC are MNE
          objects, see MNE documentation for more details).
    """
    icas = []
    for epoch in epochs:
        # per subj
        # applying AR to find global rejection threshold
        reject = get_rejection_threshold(epoch, ch_types='eeg')
        # if very long, can change decim value
        print('The rejection dictionary is %s' % reject)

        # fitting ICA on filt_raw after AR
        ica = ICA(n_components=n_components,
                  method=method,
                  fit_params=fit_params,
                  random_state=random_state)
        # take bad channels into account in ICA fit
        epoch_all_ch = epoch.copy()
        epoch_all_ch.info['bads'] = []
        icas.append(ica.fit(epoch_all_ch, reject=reject, tstep=1))

    return icas
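A minimal usage sketch for ICA_fit; the epoch file names here are hypothetical, and the epochs are assumed to be filtered 2-30 Hz as the docstring recommends.

epo_s1 = mne.read_epochs('participant1-epo.fif')  # hypothetical file names
epo_s2 = mne.read_epochs('participant2-epo.fif')
icas = ICA_fit([epo_s1, epo_s2],
               n_components=15,
               method='infomax',
               fit_params=dict(extended=True),  # extended Infomax
               random_state=97)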
Example #4
def main_preprocessing(currentDir, sub, winSize, filtOpt, chbFiles):
    myStructPath = os.path.join(currentDir, "myLabels/labels0/labels0.npy")

    labelStruct = np.load(myStructPath,
                          mmap_mode=None,
                          allow_pickle=True,
                          fix_imports=True,
                          encoding='ASCII')

    subTable = labelStruct[sub - 1]
    del labelStruct

    nFiles = len(subTable)
    labelsList = list()
    dataList = list()

    for kFile in range(len(chbFiles)):  # careful here change to nFiles

        newkFile = chbFiles[kFile] - 1
        epochs, labels = makeMyEpochs(currentDir, subTable, sub, newkFile,
                                      winSize, filtOpt)
        nEp = len(labels)

        print(" Performing Thresholding")
        # perform Thresholding
        #        nEpreject = nEp
        #        ind = rnd.randint(0 , nEp , nEpreject) # we get the threshold according to 50 random epochs
        reject = get_rejection_threshold(epochs)
        epochs.drop_bad(reject=reject, verbose=False)
        newLabels = labels[epochs.selection]
        labelsList.extend(newLabels)

        # loop over epochs for normalizing/standardizing and feature extraction
        print("rescaling epochs and extracting features")
        nEpochs = len(epochs)
        if len(newLabels) != nEpochs:
            raise ValueError(
                'number of labels differs from number of epochs')

        for k in range(nEpochs):  # here careful make it nEpochs
            start = time.perf_counter_ns()
            rawData = np.array(epochs[k].get_data()[0, :, :])
            # measure mean and variance before scaling
            myMean = np.mean(rawData, axis=1)
            myVar = np.var(rawData, axis=1)
            normData = scale(rawData, axis=1)
            print("Subject {}: file: {}/{}: extracting features for epoch {}/{}"
                  .format(sub, kFile + 1, len(chbFiles), k + 1, nEpochs))
            featVector = myfeature_extraction_functions.myFeaturesExtractor1(
                normData, myMean, myVar)
            #featVector = np.zeros((1 , 21*37)) # for debugging
            dataList.append(featVector)
            end = time.perf_counter_ns()
            print("one feature vector calculation time: {}".format(end - start))

    return np.vstack(dataList), np.asarray(labelsList)
Example #5
def _get_global_reject_epochs(raw, events, event_id, epochs_params):
    epochs = mne.Epochs(
        raw, events, event_id=event_id, proj=False,
        **epochs_params)
    epochs.load_data()
    epochs.pick_types(meg=True)
    epochs.apply_proj()
    reject = get_rejection_threshold(epochs, decim=8)
    return reject
Example #6
def test_global_autoreject():
    """Test global autoreject."""
    event_id = None
    tmin, tmax = -0.2, 0.5
    events = mne.find_events(raw)

    picks = mne.pick_types(raw.info,
                           meg=True,
                           eeg=True,
                           stim=False,
                           eog=True,
                           exclude=[])
    # raise error if preload is false
    epochs = mne.Epochs(raw,
                        events,
                        event_id,
                        tmin,
                        tmax,
                        picks=picks,
                        baseline=(None, 0),
                        reject=None,
                        preload=False)

    # Test get_rejection_threshold.
    reject1 = get_rejection_threshold(epochs, decim=1, random_state=42)
    reject2 = get_rejection_threshold(epochs, decim=1, random_state=42)
    reject3 = get_rejection_threshold(epochs, decim=2, random_state=42)
    tols = dict(eeg=5e-6, eog=5e-6, grad=10e-12, mag=5e-15)
    if platform.system().lower().startswith("win"):  # pragma: no cover
        # XXX: When testing on Windows, the precision seemed to be lower. Why?
        tols = dict(eeg=9e-5, eog=9e-5, grad=10e-12, mag=5e-15)
    assert isinstance(reject1, dict)
    for key, value in list(reject1.items()):
        assert reject1[key] == reject2[key]
        assert abs(reject1[key] - reject3[key]) < tols[key]

    reject = get_rejection_threshold(epochs, decim=4, ch_types='eeg')
    assert 'eog' not in reject
    assert 'eeg' in reject
    pytest.raises(ValueError,
                  get_rejection_threshold,
                  epochs,
                  decim=4,
                  ch_types=5)
Example #7
def _get_global_reject_ssp(raw):
    eog_epochs = mne.preprocessing.create_eog_epochs(raw)
    if len(eog_epochs) >= 5:
        reject_eog = get_rejection_threshold(eog_epochs, decim=8)
        del reject_eog['eog']
    else:
        reject_eog = None

    ecg_epochs = mne.preprocessing.create_ecg_epochs(raw)
    if len(ecg_epochs) >= 5:
        reject_ecg = get_rejection_threshold(ecg_epochs, decim=8)
    else:
        reject_ecg = None

    if reject_eog is None:
        reject_eog = reject_ecg
    if reject_ecg is None:
        reject_ecg = reject_eog
    return reject_eog, reject_ecg
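A hedged sketch of how the two thresholds might feed into SSP computation, assuming raw is a preloaded Raw object with EOG and ECG channels; compute_proj_eog and compute_proj_ecg are the standard MNE helpers.

reject_eog, reject_ecg = _get_global_reject_ssp(raw)
projs_eog, _ = mne.preprocessing.compute_proj_eog(
    raw, n_grad=1, n_mag=1, n_eeg=0, reject=reject_eog)
projs_ecg, _ = mne.preprocessing.compute_proj_ecg(
    raw, n_grad=1, n_mag=1, n_eeg=0, reject=reject_ecg)
raw.add_proj((projs_eog or []) + (projs_ecg or []))  # helpers may return None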
Example #8
def ICA_fit(epochs, n_components, method, random_state):
    """
    Computes global Autorejection to fit Independent Components Analysis
    on Epochs, for each subject.

    Prerequisite: install autoreject
    https://api.github.com/repos/autoreject/autoreject/zipball/master

    Arguments:
        epochs: list of 2 Epochs objects (for each subject).
          Epochs_S1 and Epochs_S2 correspond to a condition and can result
          from the concatenation of epochs from different occurrences of the
          condition across experiments.
          Epochs are MNE objects (data are stored in an array of shape
          (n_epochs, n_channels, n_times) and info is a dictionary of
          sampling parameters).
        n_components: the number of principal components that are passed to the
          ICA algorithm during fitting, int. For a first estimation,
          n_components can be set to 15.
        method: the ICA method used, str 'fastica', 'infomax' or 'picard'.
          'fastica' is the most frequently used.
        random_state: the parameter used to compute random distributions
          for ICA calculation, int or None. It can be useful to fix
          random_state value to have reproducible results. For 15
          components, random_state can be set to 97 for example.

    Note:
        If Autoreject and ICA take too much time, change the decim value
        (see MNE documentation).

    Returns:
        icas: list of independent components for each subject. IC are MNE
          objects, see MNE documentation for more details.
    """
    icas = []
    for epoch in epochs:
        # per subj
        # applying AR to find global rejection threshold
        reject = get_rejection_threshold(epoch, ch_types='eeg')
        # if very long, can change decim value
        print('The rejection dictionary is %s' % reject)

        # fitting ICA on filt_raw after AR
        ica = ICA(n_components=n_components,
                  method=method,
                  random_state=random_state)
        # take bad channels into account in ICA fit
        epoch_all_ch = epoch.copy()
        epoch_all_ch.info['bads'] = []
        icas.append(ica.fit(epoch_all_ch, reject=reject, tstep=1))

    return icas
Example #9
def _get_global_reject_ssp(raw):
    if 'eog' in raw:
        eog_epochs = mne.preprocessing.create_eog_epochs(raw)
    else:
        eog_epochs = []
    if len(eog_epochs) >= 5:
        reject_eog = get_rejection_threshold(eog_epochs, decim=8)
        del reject_eog['eog']  # we don't want to reject eog based on eog
    else:
        reject_eog = None

    ecg_epochs = mne.preprocessing.create_ecg_epochs(raw)
    # we will always have an ECG as long as there are magnetometers
    if len(ecg_epochs) >= 5:
        reject_ecg = get_rejection_threshold(ecg_epochs, decim=8)
        # here we want the eog
    else:
        reject_ecg = None

    if reject_eog is None and reject_ecg is not None:
        reject_eog = {k: v for k, v in reject_ecg.items() if k != 'eog'}
    return reject_eog, reject_ecg
Example #10
def _get_global_reject_epochs(raw):
    duration = 3.
    events = mne.make_fixed_length_events(
        raw, id=3000, start=0, duration=duration)

    epochs = mne.Epochs(
        raw, events, event_id=3000, tmin=0, tmax=duration, proj=False,
        baseline=None, reject=None)
    epochs.apply_proj()
    epochs.load_data()
    epochs.pick_types(meg=True)
    reject = get_rejection_threshold(epochs, decim=8)
    return reject
Example #11
def test_global_autoreject():
    """Test global autoreject."""

    event_id = None
    tmin, tmax = -0.2, 0.5
    events = mne.find_events(raw)

    picks = mne.pick_types(raw.info, meg=True, eeg=True, stim=False,
                           eog=True, exclude=[])
    # raise error if preload is false
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                        picks=picks, baseline=(None, 0),
                        reject=None, preload=False)

    # Test get_rejection_threshold.
    reject1 = get_rejection_threshold(epochs, decim=1, random_state=42)
    reject2 = get_rejection_threshold(epochs, decim=1, random_state=42)
    reject3 = get_rejection_threshold(epochs, decim=2, random_state=42)
    tols = dict(eeg=5e-6, eog=5e-6, grad=10e-12, mag=5e-15)
    assert_true(isinstance(reject1, dict))
    for key, value in list(reject1.items()):
        assert_equal(reject1[key], reject2[key])
        assert_true(abs(reject1[key] - reject3[key]) < tols[key])
Example #12
def test_autoreject():
    """Some basic tests for autoreject."""

    event_id = {'Visual/Left': 3}
    tmin, tmax = -0.2, 0.5
    events = mne.find_events(raw)

    include = [u'EEG %03d' % i for i in range(1, 15)]
    picks = mne.pick_types(raw.info, meg=False, eeg=False, stim=False,
                           eog=False, include=include, exclude=[])
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                        picks=picks, baseline=(None, 0), decim=8,
                        reject=None, add_eeg_ref=False, preload=True)

    X = epochs.get_data()
    n_epochs, n_channels, n_times = X.shape
    X = X.reshape(n_epochs, -1)

    ar = GlobalAutoReject()
    assert_raises(ValueError, ar.fit, X)
    ar = GlobalAutoReject(n_channels=n_channels)
    assert_raises(ValueError, ar.fit, X)
    ar = GlobalAutoReject(n_times=n_times)
    assert_raises(ValueError, ar.fit, X)
    ar = GlobalAutoReject(n_channels=n_channels, n_times=n_times,
                          thresh=40e-6)
    ar.fit(X)

    reject = get_rejection_threshold(epochs)
    assert_true(isinstance(reject, dict))

    param_name = 'thresh'
    param_range = np.linspace(40e-6, 200e-6, 10)
    assert_raises(ValueError, validation_curve, ar, X, None,
                  param_name, param_range)

    ar = LocalAutoReject()
    assert_raises(NotImplementedError, validation_curve, ar, epochs, None,
                  param_name, param_range)

    ar = LocalAutoRejectCV()
    assert_raises(ValueError, ar.fit, X)
    assert_raises(ValueError, ar.transform, X)
    assert_raises(ValueError, ar.transform, epochs)

    epochs.load_data()
    assert_raises(ValueError, compute_thresholds, epochs, 'dfdfdf')
    for method in ['random_search', 'bayesian_optimization']:
        compute_thresholds(epochs, method=method)
Example #13
def clean_with_ica(epochs, subject, hand, control, config, show_ica=False):
    """Clean epochs with ICA.

    Parameters
    ----------
    epochs : mne epoch object
        Epoched, filtered, and autorejected eeg data

    Returns
    -------
    epochs : mne epoch object
        ICA-cleaned epochs
    ica : mne ICA object
        Fitted ICA object

    """

    picks = mne.pick_types(epochs.info,
                           meg=False,
                           eeg=True,
                           eog=False,
                           stim=False,
                           exclude='bads')
    ica = mne.preprocessing.ICA(n_components=None,
                                method="picard",
                                verbose=False)
    # Get the rejection threshold using autoreject
    if config['use_previous_ica']:
        read_path = Path(__file__).parents[2] / config['previous_ica']
        data = dd.io.load(str(read_path))
        ica_previous = data[subject]['ica'][hand][control]
        ica_previous.apply(epochs)
        ica = ica_previous  # return the applied ICA, not the unfitted one
    else:
        reject_threshold = get_rejection_threshold(epochs)
        ica.fit(epochs, picks=picks, reject=reject_threshold)
        # mne pipeline to detect artifacts
        ica.detect_artifacts(epochs, eog_criterion=range(2))
        ica.apply(epochs)  # Apply the ICA

    if show_ica:
        ica.plot_components(inst=epochs)

    return epochs, ica
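A hedged usage sketch for clean_with_ica; the subject/hand/control labels and the config keys are hypothetical stand-ins for the study's own values.

config = {'use_previous_ica': False, 'previous_ica': ''}  # hypothetical config
epochs_clean, ica = clean_with_ica(epochs, subject='sub-01', hand='right',
                                   control='baseline', config=config,
                                   show_ica=False)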
Example #14
def create_epochs_from_raw(raw, events, metadata=None, meg_channels=True, tmin=-0.1, tmax=0.4, decim=10, reject=None, baseline=(None, 0)):
    """
    Create epochs for decoding

    :param raw:
    :type raw: mne.io.BaseRaw
    :param reject: Either of:
                'auto_global': Automatically compute rejection threshold based on all data
                'auto_channel': Automatically compute rejection threshold for each channel
                'default': Use default values
                None: no rejection
                A dict with the entries 'mag'/'grad'/both: set these rejection parameters (if mag/grad unspecified: no rejection for these channels)

    :param events: The definition of epochs and their event IDs (#epochs x 3 matrix)
    """

    events = np.array(events)

    picks_meg = mne.pick_types(raw.info, meg=meg_channels, eeg=False, eog=False, stim=False, exclude='bads')

    if reject == 'auto_global':
        epochs = mne.Epochs(raw, events=events, tmin=tmin, tmax=tmax, proj=True, picks=picks_meg, baseline=baseline)
        ep_reject = get_rejection_threshold(epochs, decim=2)

    elif reject == 'auto_channel':
        print('Auto-detecting rejection thresholds per channel...')
        epochs = mne.Epochs(raw, events=events, tmin=tmin, tmax=tmax, proj=True, picks=picks_meg, baseline=baseline)
        ep_reject = compute_thresholds(epochs, picks=picks_meg, method='random_search', augment=False, verbose='progressbar')

    else:
        ep_reject = _get_rejection_thresholds(reject, meg_channels)

    epochs = mne.Epochs(raw, events=events, metadata=metadata, tmin=tmin, tmax=tmax, proj=True, decim=decim,
                        picks=picks_meg, reject=ep_reject, preload=True, baseline=baseline)

    # print("\nEvent IDs:")
    # for cond, eid in epochs.event_id.items():
    #     print("Condition '%s' (event_id = %d): %d events" % (cond, eid, len(epochs[cond])))

    return epochs
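A brief sketch of the three rejection modes described in the docstring, assuming raw and events already exist:

epochs_auto = create_epochs_from_raw(raw, events, reject='auto_global')
epochs_per_chan = create_epochs_from_raw(raw, events, reject='auto_channel')
epochs_fixed = create_epochs_from_raw(raw, events,
                                      reject=dict(mag=4e-12, grad=4000e-13))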
Example #15
def autoreject_threshold(ft_file, out_file, raw_file=0):
    import sys
    import mne
    import autoreject
    import json

    PY3 = sys.version_info[0] == 3

    if PY3:
        string_types = str,
    else:
        string_types = basestring,

    if isinstance(raw_file, string_types):
        info = mne.io.read_info(raw_file)
        epochs = mne.read_epochs_fieldtrip(ft_file, info)
    else:
        epochs = mne.read_epochs_fieldtrip(ft_file, info=None)

    reject = autoreject.get_rejection_threshold(epochs)

    with open(out_file, 'w') as f:
        json.dump(reject, f)
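A hedged usage sketch with hypothetical file paths; passing a raw FIF file lets MNE attach channel info to the FieldTrip epochs before the threshold is computed.

import json

autoreject_threshold('epochs_ft.mat', 'reject.json', raw_file='raw.fif')
with open('reject.json') as f:
    print(json.load(f))  # e.g. {'eeg': ...}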
Example #16
def preprocess_raw(subject):
    raw_file = rawfile_of(subject)
    raw = mne.io.read_raw_edf(raw_file)
    raw.crop(tmin=60, tmax=540)  # 8 min of signal to be comparable with Cam-CAN
    raw.load_data().pick_channels(list(common_chs))
    raw.resample(250)  # max common sfreq

    # autoreject global (instead of clip at +-800uV proposed by Freiburg)
    duration = 3.
    events = mne.make_fixed_length_events(raw,
                                          id=3,
                                          start=0,
                                          duration=duration)
    epochs = mne.Epochs(raw,
                        events,
                        event_id=3,
                        tmin=0,
                        tmax=duration,
                        proj=False,
                        baseline=None,
                        reject=None)
    reject = get_rejection_threshold(epochs, decim=1)
    return raw, reject
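A hedged sketch of consuming the returned threshold (the subject ID is hypothetical; the fixed-length epoching mirrors the function body):

raw, reject = preprocess_raw('sub-01')  # hypothetical subject ID
events = mne.make_fixed_length_events(raw, id=3, start=0, duration=3.)
epochs = mne.Epochs(raw, events, event_id=3, tmin=0, tmax=3.,
                    baseline=None, reject=reject, preload=True)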
Example #17
def test_fnirs():
    """Test that autoreject runs on fNIRS data."""
    raw = mne.io.read_raw_nirx(
        os.path.join(mne.datasets.fnirs_motor.data_path(), 'Participant-1'))
    raw.crop(tmax=1200)
    raw = mne.preprocessing.nirs.optical_density(raw)
    raw = mne.preprocessing.nirs.beer_lambert_law(raw)
    events, _ = mne.events_from_annotations(raw,
                                            event_id={
                                                '1.0': 1,
                                                '2.0': 2,
                                                '3.0': 3
                                            })
    event_dict = {'Control': 1, 'Tapping/Left': 2, 'Tapping/Right': 3}
    epochs = mne.Epochs(raw,
                        events,
                        event_id=event_dict,
                        tmin=-5,
                        tmax=15,
                        proj=True,
                        baseline=(None, 0),
                        preload=True,
                        detrend=None,
                        verbose=True)
    # Test autoreject
    ar = AutoReject()
    assert len(epochs) == 37
    epochs_clean = ar.fit_transform(epochs)
    assert len(epochs_clean) < len(epochs)
    # Test threshold extraction
    reject = get_rejection_threshold(epochs)
    print(reject)
    assert "hbo" in reject.keys()
    assert "hbr" in reject.keys()
    assert reject["hbo"] < 0.001  # This is a very high value, as a sanity check
    assert reject["hbr"] < 0.001
    assert reject["hbr"] > 0.0
Example #18
def save_epochs(p, subjects, in_names, in_numbers, analyses, out_names,
                out_numbers, must_match, decim, run_indices):
    """Generate epochs from raw data based on events

    Can only be run after preprocessing is complete.

    Parameters
    ----------
    p : instance of Parameters
        Analysis parameters.
    subjects : list of str
        Subject names to analyze (e.g., ['Eric_SoP_001', ...]).
    in_names : list of str
        Names of input events.
    in_numbers : list of list of int
        Event numbers (in scored event files) associated with each name.
    analyses : list of str
        Lists of analyses of interest.
    out_names : list of list of str
        Event types to make out of old ones.
    out_numbers : list of list of int
        Event numbers to convert to (e.g., [[1, 1, 2, 3, 3], ...] would create
        three event types, where the first two and last two event types from
        the original list get collapsed over).
    must_match : list of int
        Indices from the original in_names that must match in event counts
        before collapsing. Should eventually be expanded to allow for
        ratio-based collapsing.
    decim : int | list of int
        Amount to decimate.
    run_indices : array-like | None
        Run indices to include.
    """
    in_names = np.asanyarray(in_names)
    old_dict = dict()
    for n, e in zip(in_names, in_numbers):
        old_dict[n] = e

    # let's do some sanity checks
    if len(in_names) != len(in_numbers):
        raise RuntimeError('in_names (%d) must have same length as '
                           'in_numbers (%d)' %
                           (len(in_names), len(in_numbers)))
    if np.any(np.array(in_numbers) <= 0):
        raise ValueError('in_numbers must all be > 0')
    if len(out_names) != len(out_numbers):
        raise RuntimeError('out_names must have same length as out_numbers')
    for name, num in zip(out_names, out_numbers):
        num = np.array(num)
        if len(name) != len(np.unique(num[num > 0])):
            raise RuntimeError('each entry in out_names must have length '
                               'equal to the number of unique elements in the '
                               'corresponding entry in out_numbers:\n%s\n%s' %
                               (name, np.unique(num[num > 0])))
        if len(num) != len(in_names):
            raise RuntimeError('each entry in out_numbers must have the same '
                               'length as in_names')
        if (np.array(num) == 0).any():
            raise ValueError('no element of out_numbers can be zero')

    ch_namess = list()
    drop_logs = list()
    sfreqs = set()
    for si, subj in enumerate(subjects):
        if p.disp_files:
            print('  Loading raw files for subject %s.' % subj)
        epochs_dir = op.join(p.work_dir, subj, p.epochs_dir)
        if not op.isdir(epochs_dir):
            os.mkdir(epochs_dir)
        evoked_dir = op.join(p.work_dir, subj, p.inverse_dir)
        if not op.isdir(evoked_dir):
            os.mkdir(evoked_dir)
        # read in raw files
        raw_names = get_raw_fnames(p, subj, 'pca', False, False,
                                   run_indices[si])
        first_samps = []
        last_samps = []
        for raw_fname in raw_names:
            raw = read_raw_fif(raw_fname, preload=False)
            first_samps.append(raw._first_samps[0])
            last_samps.append(raw._last_samps[-1])
        raw = [read_raw_fif(fname, preload=False) for fname in raw_names]
        _fix_raw_eog_cals(raw)  # EOG epoch scales might be bad!
        raw = concatenate_raws(raw)
        # read in events
        events = _read_events(p, subj, run_indices[si], raw)
        this_decim = _handle_decim(decim[si], raw.info['sfreq'])
        new_sfreq = raw.info['sfreq'] / this_decim
        if p.disp_files:
            print('    Epoching data (decim=%s -> sfreq=%0.1f Hz).' %
                  (this_decim, new_sfreq))
        if new_sfreq not in sfreqs:
            if len(sfreqs) > 0:
                warnings.warn('resulting new sampling frequency %s not equal '
                              'to previous values %s' % (new_sfreq, sfreqs))
            sfreqs.add(new_sfreq)
        epochs_fnames, evoked_fnames = get_epochs_evokeds_fnames(
            p, subj, analyses)
        mat_file, fif_file = epochs_fnames
        if p.autoreject_thresholds:
            assert len(p.autoreject_types) > 0
            assert all(a in ('mag', 'grad', 'eeg', 'ecg', 'eog')
                       for a in p.autoreject_types)
            from autoreject import get_rejection_threshold
            print('    Computing autoreject thresholds', end='')
            rtmin = p.reject_tmin if p.reject_tmin is not None else p.tmin
            rtmax = p.reject_tmax if p.reject_tmax is not None else p.tmax
            temp_epochs = Epochs(raw,
                                 events,
                                 event_id=None,
                                 tmin=rtmin,
                                 tmax=rtmax,
                                 baseline=_get_baseline(p),
                                 proj=True,
                                 reject=None,
                                 flat=None,
                                 preload=True,
                                 decim=this_decim,
                                 reject_by_annotation=p.reject_epochs_by_annot)
            kwargs = dict()
            if 'verbose' in get_args(get_rejection_threshold):
                kwargs['verbose'] = False
            new_dict = get_rejection_threshold(temp_epochs, **kwargs)
            use_reject = dict()
            msgs = list()
            for k in p.autoreject_types:
                msgs.append('%s=%d %s' % (k, DEFAULTS['scalings'][k] *
                                          new_dict[k], DEFAULTS['units'][k]))
                use_reject[k] = new_dict[k]
            print(': ' + ', '.join(msgs))
            hdf5_file = fif_file.replace('-epo.fif', '-reject.h5')
            assert hdf5_file.endswith('.h5')
            write_hdf5(hdf5_file, use_reject, overwrite=True)
        else:
            use_reject = _handle_dict(p.reject, subj)
        # create epochs
        flat = _handle_dict(p.flat, subj)
        use_reject, use_flat = _restrict_reject_flat(use_reject, flat, raw)
        epochs = Epochs(raw,
                        events,
                        event_id=old_dict,
                        tmin=p.tmin,
                        tmax=p.tmax,
                        baseline=_get_baseline(p),
                        reject=use_reject,
                        flat=use_flat,
                        proj=p.epochs_proj,
                        preload=True,
                        decim=this_decim,
                        on_missing=p.on_missing,
                        reject_tmin=p.reject_tmin,
                        reject_tmax=p.reject_tmax,
                        reject_by_annotation=p.reject_epochs_by_annot)
        del raw
        if epochs.events.shape[0] < 1:
            epochs.plot_drop_log()
            raise ValueError('No valid epochs')
        drop_logs.append(epochs.drop_log)
        ch_namess.append(epochs.ch_names)
        # only kept trials that were not dropped
        sfreq = epochs.info['sfreq']
        # now deal with conditions to save evoked
        if p.disp_files:
            print('    Matching trial counts and saving data to disk.')
        for var, name in ((out_names, 'out_names'), (out_numbers,
                                                     'out_numbers'),
                          (must_match, 'must_match'), (evoked_fnames,
                                                       'evoked_fnames')):
            if len(var) != len(analyses):
                raise ValueError('len(%s) (%s) != len(analyses) (%s)' %
                                 (name, len(var), len(analyses)))
        for analysis, names, numbers, match, fn in zip(analyses, out_names,
                                                       out_numbers, must_match,
                                                       evoked_fnames):
            # do matching
            numbers = np.asanyarray(numbers)
            nn = numbers[numbers >= 0]
            new_numbers = []
            for num in numbers:
                if num > 0 and num not in new_numbers:
                    # Eventually we could relax this requirement, but not
                    # having it in place is likely to cause people pain...
                    if any(num < n for n in new_numbers):
                        raise RuntimeError('each list of new_numbers must be '
                                           'monotonically increasing')
                    new_numbers.append(num)
            new_numbers = np.array(new_numbers)
            in_names_match = in_names[match]
            # use some variables to allow safe name re-use
            offset = max(epochs.events[:, 2].max(), new_numbers.max()) + 1
            safety_str = '__mnefun_copy__'
            assert len(new_numbers) == len(names)  # checked above
            if p.match_fun is None:
                # first, equalize trial counts (this will make a copy)
                e = epochs[list(in_names[numbers > 0])]
                if len(in_names_match) > 1:
                    e.equalize_event_counts(in_names_match)

                # second, collapse relevant types
                for num, name in zip(new_numbers, names):
                    collapse = [
                        x for x in in_names[num == numbers] if x in e.event_id
                    ]
                    combine_event_ids(e,
                                      collapse,
                                      {name + safety_str: num + offset},
                                      copy=False)
                for num, name in zip(new_numbers, names):
                    e.events[e.events[:, 2] == num + offset, 2] -= offset
                    e.event_id[name] = num
                    del e.event_id[name + safety_str]
            else:  # custom matching
                e = p.match_fun(epochs.copy(), analysis, nn, in_names_match,
                                names)

            # now make evoked for each out type
            evokeds = list()
            n_standard = 0
            kinds = ['standard']
            if p.every_other:
                kinds += ['even', 'odd']
            for kind in kinds:
                for name in names:
                    this_e = e[name]
                    if kind == 'even':
                        this_e = this_e[::2]
                    elif kind == 'odd':
                        this_e = this_e[1::2]
                    else:
                        assert kind == 'standard'
                    if len(this_e) > 0:
                        ave = this_e.average(picks='all')
                        stde = this_e.standard_error(picks='all')
                        if kind != 'standard':
                            ave.comment += ' %s' % (kind, )
                            stde.comment += ' %s' % (kind, )
                        evokeds.append(ave)
                        evokeds.append(stde)
                        if kind == 'standard':
                            n_standard += 2
            write_evokeds(fn, evokeds)
            naves = [
                str(n) for n in sorted(
                    set([evoked.nave for evoked in evokeds[:n_standard]]))
            ]
            naves = ', '.join(naves)
            if p.disp_files:
                print('      Analysis "%s": %s epochs / condition' %
                      (analysis, naves))

        if p.disp_files:
            print('    Saving epochs to disk.')
        if 'mat' in p.epochs_type:
            spio.savemat(mat_file,
                         dict(epochs=epochs.get_data(),
                              events=epochs.events,
                              sfreq=sfreq,
                              drop_log=epochs.drop_log),
                         do_compression=True,
                         oned_as='column')
        if 'fif' in p.epochs_type:
            epochs.save(fif_file, **_get_epo_kwargs())

    if p.plot_drop_logs:
        for subj, drop_log in zip(subjects, drop_logs):
            plot_drop_log(drop_log, threshold=p.drop_thresh, subject=subj)
Example #19
data_path = sample.data_path()
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
event_fname = data_path + ('/MEG/sample/sample_audvis_filt-0-40_raw-'
                           'eve.fif')

raw = io.read_raw_fif(raw_fname, preload=True)
events = mne.read_events(event_fname)

include = []
picks = mne.pick_types(raw.info, meg=True, eeg=True, stim=False,
                       eog=True, include=include, exclude='bads')
epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                    picks=picks, baseline=(None, 0), preload=True,
                    reject=None, verbose=False, detrend=1)

###############################################################################
# Now we get the rejection dictionary

from autoreject import get_rejection_threshold  # noqa
reject = get_rejection_threshold(epochs, decim=1)

###############################################################################
# and print it

print('The rejection dictionary is %s' % reject)

###############################################################################
# Finally, the cleaned epochs
epochs.drop_bad(reject=reject)
epochs.average().plot()
Example #20
epochs = mne.Epochs(raw,
                    events,
                    event_id,
                    tmin,
                    tmax,
                    picks=picks,
                    baseline=(None, 0),
                    preload=True,
                    reject=None,
                    verbose=False,
                    detrend=1)

###############################################################################
# Now we get the rejection dictionary

from autoreject import get_rejection_threshold  # noqa

# We can use the `decim` parameter to only take every nth time slice.
# This speeds up the computation time. Note however that for low sampling
# rates and high decimation parameters, you might not detect "peaky artifacts"
# (with a fast timecourse) in your data. A small amount of decimation, however,
# is almost always beneficial and costs essentially no accuracy.
reject = get_rejection_threshold(epochs, decim=2)

###############################################################################
# and print it

print('The rejection dictionary is %s' % reject)

###############################################################################
# Finally, the cleaned epochs
epochs.drop_bad(reject=reject)
epochs.average().plot()
Example #21
def do_preprocessing_combined(p, subjects, run_indices):
    """Do preprocessing on all raw files together.

    Calculates projection vectors to use to clean data.

    Parameters
    ----------
    p : instance of Parameters
        Analysis parameters.
    subjects : list of str
        Subject names to analyze (e.g., ['Eric_SoP_001', ...]).
    run_indices : array-like | None
        Run indices to include.
    """
    drop_logs = list()
    for si, subj in enumerate(subjects):
        proj_nums = _proj_nums(p, subj)
        ecg_channel = _handle_dict(p.ecg_channel, subj)
        flat = _handle_dict(p.flat, subj)
        if p.disp_files:
            print('  Preprocessing subject %g/%g (%s).' %
                  (si + 1, len(subjects), subj))
        pca_dir = _get_pca_dir(p, subj)
        bad_file = get_bad_fname(p, subj, check_exists=False)

        # Create SSP projection vectors after marking bad channels
        raw_names = get_raw_fnames(p, subj, 'sss', False, False,
                                   run_indices[si])
        empty_names = get_raw_fnames(p, subj, 'sss', 'only')
        for r in raw_names + empty_names:
            if not op.isfile(r):
                raise FileNotFoundError('File not found (' + r + ')')

        fir_kwargs, old_kwargs = _get_fir_kwargs(p.fir_design)
        if isinstance(p.auto_bad, float):
            print('    Creating post SSS bad channel file:\n'
                  '        %s' % bad_file)
            # do autobad
            raw = _raw_LRFCP(raw_names,
                             p.proj_sfreq,
                             None,
                             None,
                             p.n_jobs_fir,
                             p.n_jobs_resample,
                             list(),
                             None,
                             p.disp_files,
                             method='fir',
                             filter_length=p.filter_length,
                             apply_proj=False,
                             force_bads=False,
                             l_trans=p.hp_trans,
                             h_trans=p.lp_trans,
                             phase=p.phase,
                             fir_window=p.fir_window,
                             pick=True,
                             skip_by_annotation='edge',
                             **fir_kwargs)
            events = fixed_len_events(p, raw)
            rtmin = p.reject_tmin \
                if p.reject_tmin is not None else p.tmin
            rtmax = p.reject_tmax \
                if p.reject_tmax is not None else p.tmax
            # do not mark eog channels bad
            meg, eeg = 'meg' in raw, 'eeg' in raw
            picks = pick_types(raw.info,
                               meg=meg,
                               eeg=eeg,
                               eog=False,
                               exclude=[])
            assert p.auto_bad_flat is None or isinstance(p.auto_bad_flat, dict)
            assert p.auto_bad_reject is None or \
                isinstance(p.auto_bad_reject, dict) or \
                p.auto_bad_reject == 'auto'
            if p.auto_bad_reject == 'auto':
                print('    Auto bad channel selection active. '
                      'Will try using Autoreject module to '
                      'compute rejection criterion.')
                try:
                    from autoreject import get_rejection_threshold
                except ImportError:
                    raise ImportError('     Autoreject module not installed.\n'
                                      '     Noisy channel detection parameter '
                                      '     not defined. To use autobad '
                                      '     channel selection either define '
                                      '     rejection criteria or install '
                                      '     Autoreject module.\n')
                print('    Computing thresholds.\n', end='')
                temp_epochs = Epochs(raw,
                                     events,
                                     event_id=None,
                                     tmin=rtmin,
                                     tmax=rtmax,
                                     baseline=_get_baseline(p),
                                     proj=True,
                                     reject=None,
                                     flat=None,
                                     preload=True,
                                     decim=1)
                kwargs = dict()
                if 'verbose' in get_args(get_rejection_threshold):
                    kwargs['verbose'] = False
                reject = get_rejection_threshold(temp_epochs, **kwargs)
                reject = {kk: vv for kk, vv in reject.items()}
            elif p.auto_bad_reject is None and p.auto_bad_flat is None:
                raise RuntimeError('Auto bad channel detection active. Noisy '
                                   'and flat channel detection '
                                   'parameters not defined. '
                                   'At least one criterion must be defined.')
            else:
                reject = p.auto_bad_reject
            if 'eog' in reject.keys():
                reject.pop('eog', None)
            epochs = Epochs(raw,
                            events,
                            None,
                            tmin=rtmin,
                            tmax=rtmax,
                            baseline=_get_baseline(p),
                            picks=picks,
                            reject=reject,
                            flat=p.auto_bad_flat,
                            proj=True,
                            preload=True,
                            decim=1,
                            reject_tmin=rtmin,
                            reject_tmax=rtmax)
            # channel scores from drop log
            drops = Counter([ch for d in epochs.drop_log for ch in d])
            # get rid of non-channel reasons in drop log
            scores = {
                kk: vv
                for kk, vv in drops.items() if kk in epochs.ch_names
            }
            ch_names = np.array(list(scores.keys()))
            # channel scores expressed as percentile and rank ordered
            counts = (100 * np.array([scores[ch] for ch in ch_names], float) /
                      len(epochs.drop_log))
            order = np.argsort(counts)[::-1]
            # boolean array masking out channels with <= % epochs dropped
            mask = counts[order] > p.auto_bad
            badchs = ch_names[order[mask]]
            if len(badchs) > 0:
                # Make sure we didn't get too many bad MEG or EEG channels
                for m, e, thresh in zip(
                    [True, False], [False, True],
                    [p.auto_bad_meg_thresh, p.auto_bad_eeg_thresh]):
                    picks = pick_types(epochs.info, meg=m, eeg=e, exclude=[])
                    if len(picks) > 0:
                        ch_names = [epochs.ch_names[pp] for pp in picks]
                        n_bad_type = sum(ch in ch_names for ch in badchs)
                        if n_bad_type > thresh:
                            stype = 'meg' if m else 'eeg'
                            raise RuntimeError('Too many bad %s channels '
                                               'found: %s > %s' %
                                               (stype, n_bad_type, thresh))

                print('    The following channels resulted in greater than '
                      '{:.0f}% trials dropped:\n'.format(p.auto_bad * 100))
                print(badchs)
                with open(bad_file, 'w') as f:
                    f.write('\n'.join(badchs))
        if not op.isfile(bad_file):
            print('    Clearing bad channels (no file %s)' %
                  op.sep.join(bad_file.split(op.sep)[-3:]))
            bad_file = None

        ecg_t_lims = _handle_dict(p.ecg_t_lims, subj)
        ecg_f_lims = p.ecg_f_lims

        ecg_eve = op.join(pca_dir, 'preproc_ecg-eve.fif')
        ecg_epo = op.join(pca_dir, 'preproc_ecg-epo.fif')
        ecg_proj = op.join(pca_dir, 'preproc_ecg-proj.fif')
        all_proj = op.join(pca_dir, 'preproc_all-proj.fif')

        get_projs_from = _handle_dict(p.get_projs_from, subj)
        if get_projs_from is None:
            get_projs_from = np.arange(len(raw_names))
        pre_list = [
            r for ri, r in enumerate(raw_names) if ri in get_projs_from
        ]

        projs = list()
        raw_orig = _raw_LRFCP(raw_names=pre_list,
                              sfreq=p.proj_sfreq,
                              l_freq=None,
                              h_freq=None,
                              n_jobs=p.n_jobs_fir,
                              n_jobs_resample=p.n_jobs_resample,
                              projs=projs,
                              bad_file=bad_file,
                              disp_files=p.disp_files,
                              method='fir',
                              filter_length=p.filter_length,
                              force_bads=False,
                              l_trans=p.hp_trans,
                              h_trans=p.lp_trans,
                              phase=p.phase,
                              fir_window=p.fir_window,
                              pick=True,
                              skip_by_annotation='edge',
                              **fir_kwargs)

        # Apply any user-supplied extra projectors
        if p.proj_extra is not None:
            if p.disp_files:
                print('    Adding extra projectors from "%s".' % p.proj_extra)
            projs.extend(read_proj(op.join(pca_dir, p.proj_extra)))

        proj_kwargs, p_sl = _get_proj_kwargs(p)
        #
        # Calculate and apply ERM projectors
        #
        if not p.cont_as_esss:
            if any(proj_nums[2]):
                assert proj_nums[2][2] == 0  # no EEG projectors for ERM
                if len(empty_names) == 0:
                    raise RuntimeError('Cannot compute empty-room projectors '
                                       'from continuous raw data')
                if p.disp_files:
                    print('    Computing continuous projectors using ERM.')
                # Use empty room(s), but processed the same way
                projs.extend(_compute_erm_proj(p, subj, projs, 'sss',
                                               bad_file))
            else:
                cont_proj = op.join(pca_dir, 'preproc_cont-proj.fif')
                _safe_remove(cont_proj)

        #
        # Calculate and apply the ECG projectors
        #
        if any(proj_nums[0]):
            if p.disp_files:
                print('    Computing ECG projectors...', end='')
            raw = raw_orig.copy()

            raw.filter(ecg_f_lims[0],
                       ecg_f_lims[1],
                       n_jobs=p.n_jobs_fir,
                       method='fir',
                       filter_length=p.filter_length,
                       l_trans_bandwidth=0.5,
                       h_trans_bandwidth=0.5,
                       phase='zero-double',
                       fir_window='hann',
                       skip_by_annotation='edge',
                       **old_kwargs)
            raw.add_proj(projs)
            raw.apply_proj()
            find_kwargs = dict()
            if 'reject_by_annotation' in get_args(find_ecg_events):
                find_kwargs['reject_by_annotation'] = True
            elif len(raw.annotations) > 0:
                print('    WARNING: ECG event detection will not make use of '
                      'annotations, please update MNE-Python')
            # We've already filtered the data channels above, but this
            # filters the ECG channel
            ecg_events = find_ecg_events(raw,
                                         999,
                                         ecg_channel,
                                         0.,
                                         ecg_f_lims[0],
                                         ecg_f_lims[1],
                                         qrs_threshold='auto',
                                         return_ecg=False,
                                         **find_kwargs)[0]
            use_reject, use_flat = _restrict_reject_flat(
                _handle_dict(p.ssp_ecg_reject, subj), flat, raw)
            ecg_epochs = Epochs(raw,
                                ecg_events,
                                999,
                                ecg_t_lims[0],
                                ecg_t_lims[1],
                                baseline=None,
                                reject=use_reject,
                                flat=use_flat,
                                preload=True)
            print('  obtained %d epochs from %d events.' %
                  (len(ecg_epochs), len(ecg_events)))
            if len(ecg_epochs) >= 20:
                write_events(ecg_eve, ecg_epochs.events)
                ecg_epochs.save(ecg_epo, **_get_epo_kwargs())
                desc_prefix = 'ECG-%s-%s' % tuple(ecg_t_lims)
                pr = compute_proj_wrap(ecg_epochs,
                                       p.proj_ave,
                                       n_grad=proj_nums[0][0],
                                       n_mag=proj_nums[0][1],
                                       n_eeg=proj_nums[0][2],
                                       desc_prefix=desc_prefix,
                                       **proj_kwargs)
                assert len(pr) == np.sum(proj_nums[0][::p_sl])
                write_proj(ecg_proj, pr)
                projs.extend(pr)
            else:
                plot_drop_log(ecg_epochs.drop_log)
                raw.plot(events=ecg_epochs.events)
                raise RuntimeError('Only %d/%d good ECG epochs found' %
                                   (len(ecg_epochs), len(ecg_events)))
            del raw, ecg_epochs, ecg_events
        else:
            _safe_remove([ecg_proj, ecg_eve, ecg_epo])

        #
        # Next calculate and apply the EOG projectors
        #
        for idx, kind in ((1, 'EOG'), (3, 'HEOG'), (4, 'VEOG')):
            _compute_add_eog(p, subj, raw_orig, projs, proj_nums[idx], kind,
                             pca_dir, flat, proj_kwargs, old_kwargs, p_sl)
        del proj_nums

        # save the projectors
        write_proj(all_proj, projs)

        #
        # Look at raw_orig for trial DQs now, it will be quick
        #
        raw_orig.filter(p.hp_cut,
                        p.lp_cut,
                        n_jobs=p.n_jobs_fir,
                        method='fir',
                        filter_length=p.filter_length,
                        l_trans_bandwidth=p.hp_trans,
                        phase=p.phase,
                        h_trans_bandwidth=p.lp_trans,
                        fir_window=p.fir_window,
                        skip_by_annotation='edge',
                        **fir_kwargs)
        raw_orig.add_proj(projs)
        raw_orig.apply_proj()
        # now let's epoch with 1-sec windows to look for DQs
        events = fixed_len_events(p, raw_orig)
        reject = _handle_dict(p.reject, subj)
        use_reject, use_flat = _restrict_reject_flat(reject, flat, raw_orig)
        epochs = Epochs(raw_orig,
                        events,
                        None,
                        p.tmin,
                        p.tmax,
                        preload=False,
                        baseline=_get_baseline(p),
                        reject=use_reject,
                        flat=use_flat,
                        proj=True)
        try:
            epochs.drop_bad()
        except AttributeError:  # old way
            epochs.drop_bad_epochs()
        drop_logs.append(epochs.drop_log)
        del raw_orig
        del epochs
    if p.plot_drop_logs:
        for subj, drop_log in zip(subjects, drop_logs):
            plot_drop_log(drop_log, p.drop_thresh, subject=subj)
Example #22
include = []
picks = mne.pick_types(raw.info,
                       meg=True,
                       eeg=True,
                       stim=False,
                       eog=True,
                       include=include,
                       exclude='bads')
epochs = mne.Epochs(raw,
                    events,
                    event_id,
                    tmin,
                    tmax,
                    picks=picks,
                    baseline=(None, 0),
                    preload=True,
                    reject=None,
                    verbose=False,
                    detrend=1)

###############################################################################
# Now we get the rejection dictionary

from autoreject import get_rejection_threshold  # noqa
reject = get_rejection_threshold(epochs)

###############################################################################
# and print it

print('The rejection dictionary is %s' % reject)

###############################################################################
# Finally, the cleaned epochs
epochs.drop_bad(reject=reject)
epochs.average().plot()
Example #23
def run_epoch(subject_id):
    subject = "sub_%03d" % subject_id
    print("processing subject: %s" % subject)
    in_path = op.join(
        data_path, "EEG_Process")  # create the 'EEG_Process' folder in the cwd yourself
    process_path = op.join(
        data_path,
        "EEG_Process")  # create the 'EEG_Process' folder in the cwd yourself
    raw_list = list()
    events_list = list()

    for run in range(1, 2):
        fname = op.join(in_path, 'sub_%03d_raw.fif' % (subject_id, ))
        raw = mne.io.read_raw_fif(fname, preload=True)
        print("  S %s - R %s" % (subject, run))

        #####ICA#####
        ica = ICA(random_state=97, n_components=15)
        picks = mne.pick_types(raw.info,
                               eeg=True,
                               eog=True,
                               stim=False,
                               exclude='bads')
        ica.fit(raw, picks=picks)
        raw.load_data()

        #make epochs around stimuli events
        fname_events = op.join(process_path,
                               'events_%03d-eve.fif' % (subject_id, ))
        delay = int(round(0.0345 * raw.info['sfreq']))
        events = mne.read_events(fname_events)
        events[:, 0] = events[:, 0] + delay
        events_list.append(events)
        epochs = mne.Epochs(raw,
                            events,
                            events_id,
                            tmin=-0.2,
                            tmax=0.5,
                            proj=True,
                            picks=picks,
                            baseline=(None, 0),
                            preload=False,
                            reject=None)

        #get EOG epochs
        eog_epochs = create_eog_epochs(raw, tmin=-.5, tmax=.5, preload=False)
        n_max_eog = 3  # use max 3 components
        eog_epochs.load_data()
        eog_epochs.apply_baseline((None, None))
        eog_inds, scores_eog = ica.find_bads_eog(eog_epochs)
        print('    Found %d EOG indices' % (len(eog_inds), ))
        ica.exclude.extend(eog_inds[:n_max_eog])
        eog_epochs.average()
        del eog_epochs

        #apply ICA on epochs
        epochs.load_data()
        ica.apply(epochs)
        reject = get_rejection_threshold(epochs, random_state=97)
        epochs.drop_bad(reject=reject)
        print('  Dropped %0.1f%% of epochs' % (epochs.drop_log_stats(), ))
        #epochs.plot(picks=('Oz'), title='epochs, electrode Oz')
        #save epochs
        epochs.save(
            op.join(process_path, "sub_%03d_raw-epo.fif" % (subject_id, )))
Beispiel #24
def mne_get_rejection_threshold(raw):
    epochs = mne_epoch(raw)
    epochs.drop_bad()
    return get_rejection_threshold(epochs)
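# A minimal usage sketch, assuming `raw` is a preloaded mne.io.Raw and that
# the `mne_epoch` helper above returns an mne.Epochs object (the file name
# below is purely hypothetical):
raw = mne.io.read_raw_fif('sample_raw.fif', preload=True)  # hypothetical file
reject = mne_get_rejection_threshold(raw)
print(reject)  # dict mapping channel type to peak-to-peak threshold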
Beispiel #25
def run_ica(sub, eog_channel, ecg_channel, reject, flat, autoreject_interpolation,
            autoreject_threshold, save_plots, figures_path, pscripts_path):
    info = sub.load_info()

    ica_dict = ut.dict_filehandler(sub.name, f'ica_components_{sub.p_preset}',
                                   pscripts_path,
                                   onlyread=True)

    raw = sub.load_filtered()
    if raw.info['highpass'] < 1:
        raw.filter(l_freq=1., h_freq=None)
    epochs = sub.load_epochs()
    picks = mne.pick_types(raw.info, meg=True, eeg=False, eog=False,
                           stim=False, exclude=sub.bad_channels)

    if not isdir(join(figures_path, 'ica')):
        makedirs(join(figures_path, 'ica'))

    # Calculate ICA
    ica = mne.preprocessing.ICA(n_components=25, method='fastica', random_state=8)

    if autoreject_interpolation:
        # Avoid computing the rejection threshold on already-cleaned epochs:
        # build fresh fixed-length epochs from the raw data instead
        simulated_events = mne.make_fixed_length_events(raw, duration=5)
        simulated_epochs = mne.Epochs(raw, simulated_events, baseline=None, picks=picks, tmin=0, tmax=2)
        reject = ar.get_rejection_threshold(simulated_epochs)
        print(f'Autoreject Rejection-Threshold: {reject}')
    elif autoreject_threshold:
        reject = ut.autoreject_handler(sub.name, epochs, sub.p["highpass"], sub.p["lowpass"], sub.pr.pscripts_path,
                                       overwrite_ar=False, only_read=True)
        print(f'Autoreject Rejection-Threshold: {reject}')
    else:
        print(f'Chosen Rejection-Threshold: {reject}')

    ica.fit(raw, picks, reject=reject, flat=flat,
            reject_by_annotation=True)

    if ica_dict.get(sub.name):
        indices = ica_dict[sub.name]
        ica.exclude += indices
        print(f'{indices} added to ica.exclude from ica_components.py')
        sub.save_ica(ica)

        comp_list = list(range(ica.n_components))
        fig1 = ica.plot_components(picks=comp_list, title=sub.name, show=False)
        fig3 = ica.plot_sources(raw, picks=comp_list[:12], start=150, stop=200, title=sub.name, show=False)
        fig4 = ica.plot_sources(raw, picks=comp_list[12:], start=150, stop=200, title=sub.name, show=False)
        fig5 = ica.plot_overlay(epochs.average(), title=sub.name, show=False)
        if save_plots and save_plots != 'false':

            save_path = join(figures_path, 'ica', sub.name +
                             '_ica_comp' + '_' + sub.pr.p_preset + '.jpg')
            fig1.savefig(save_path, dpi=300)
            print('figure: ' + save_path + ' has been saved')

            save_path = join(figures_path, 'ica', sub.name +
                             '_ica_src' + '_' + sub.pr.p_preset + '_0.jpg')
            fig3.savefig(save_path, dpi=300)
            print('figure: ' + save_path + ' has been saved')

            save_path = join(figures_path, 'ica', sub.name +
                             '_ica_src' + '_' + sub.pr.p_preset + '_1.jpg')
            fig4.savefig(save_path, dpi=300)
            print('figure: ' + save_path + ' has been saved')
            if not exists(join(figures_path, 'ica/evoked_overlay')):
                makedirs(join(figures_path, 'ica/evoked_overlay'))
            save_path = join(figures_path, 'ica/evoked_overlay', sub.name +
                             '_ica_ovl' + '_' + sub.pr.p_preset + '.jpg')
            fig5.savefig(save_path, dpi=300)
            print('figure: ' + save_path + ' has been saved')

        else:
            print('Not saving plots; set "save_plots" to "True" to save')

    elif 'EEG 001' in info['ch_names']:
        eeg_picks = mne.pick_types(raw.info, meg=True, eeg=True, eog=True,
                                   stim=False, exclude=sub.bad_channels)

        eog_epochs = mne.preprocessing.create_eog_epochs(raw, picks=eeg_picks,
                                                         reject=reject, flat=flat, ch_name=eog_channel)
        ecg_epochs = mne.preprocessing.create_ecg_epochs(raw, picks=eeg_picks,
                                                         reject=reject, flat=flat, ch_name=ecg_channel)

        if len(eog_epochs) != 0:
            eog_indices, eog_scores = ica.find_bads_eog(eog_epochs, ch_name=eog_channel)
            ica.exclude.extend(eog_indices)
            print('EOG-Components: ', eog_indices)
            if len(eog_indices) != 0:
                # Plot EOG-Plots
                fig3 = ica.plot_scores(eog_scores, title=sub.name + '_eog', show=False)
                fig2 = ica.plot_properties(eog_epochs, eog_indices, psd_args={'fmax': sub.p["lowpass"]},
                                           image_args={'sigma': 1.}, show=False)
                fig7 = ica.plot_overlay(eog_epochs.average(), exclude=eog_indices, title=sub.name + '_eog',
                                        show=False)
                if save_plots and save_plots != 'false':
                    for f in fig2:
                        save_path = join(figures_path, 'ica', sub.name +
                                         '_ica_prop_eog' + '_' + sub.pr.p_preset +
                                         f'_{fig2.index(f)}.jpg')
                        f.savefig(save_path, dpi=300)
                        print('figure: ' + save_path + ' has been saved')

                    save_path = join(figures_path, 'ica', sub.name +
                                     '_ica_scor_eog' + '_' + sub.pr.p_preset + '.jpg')
                    fig3.savefig(save_path, dpi=300)
                    print('figure: ' + save_path + ' has been saved')

                    save_path = join(figures_path, 'ica', sub.name +
                                     '_ica_ovl_eog' + '_' + sub.pr.p_preset + '.jpg')
                    fig7.savefig(save_path, dpi=300)
                    print('figure: ' + save_path + ' has been saved')

        if len(ecg_epochs) != 0:
            ecg_indices, ecg_scores = ica.find_bads_ecg(ecg_epochs, ch_name=ecg_channel)
            ica.exclude.extend(ecg_indices)
            print('ECG-Components: ', ecg_indices)
            print(len(ecg_indices))
            if len(ecg_indices) != 0:
                # Plot ECG-Plots
                fig4 = ica.plot_scores(ecg_scores, title=sub.name + '_ecg', show=False)
                fig9 = ica.plot_properties(ecg_epochs, ecg_indices, psd_args={'fmax': sub.p["lowpass"]},
                                           image_args={'sigma': 1.}, show=False)
                fig8 = ica.plot_overlay(ecg_epochs.average(), exclude=ecg_indices, title=sub.name + '_ecg',
                                        show=False)
                if save_plots and save_plots != 'false':
                    for f in fig9:
                        save_path = join(figures_path, 'ica', sub.name +
                                         '_ica_prop_ecg' + '_' + sub.pr.p_preset +
                                         f'_{fig9.index(f)}.jpg')
                        f.savefig(save_path, dpi=300)
                        print('figure: ' + save_path + ' has been saved')

                    save_path = join(figures_path, 'ica', sub.name +
                                     '_ica_scor_ecg' + '_' + sub.pr.p_preset + '.jpg')
                    fig4.savefig(save_path, dpi=300)
                    print('figure: ' + save_path + ' has been saved')

                    save_path = join(figures_path, 'ica', sub.name +
                                     '_ica_ovl_ecg' + '_' + sub.pr.p_preset + '.jpg')
                    fig8.savefig(save_path, dpi=300)
                    print('figure: ' + save_path + ' has been saved')

        sub.save_ica(ica)

        # Reading and Writing ICA-Components to a .py-file
        exes = ica.exclude
        indices = []
        for i in exes:
            indices.append(int(i))

        ut.dict_filehandler(sub.name, f'ica_components_{sub.pr.p_preset}', pscripts_path,
                            values=indices, overwrite=True)

        # Plot ICA integrated
        comp_list = list(range(ica.n_components))
        fig1 = ica.plot_components(picks=comp_list, title=sub.name, show=False)
        fig5 = ica.plot_sources(raw, picks=comp_list[:12], start=150, stop=200, title=sub.name, show=False)
        fig6 = ica.plot_sources(raw, picks=comp_list[12:], start=150, stop=200, title=sub.name, show=False)
        fig10 = ica.plot_overlay(epochs.average(), title=sub.name, show=False)

        if save_plots and save_plots != 'false':
            save_path = join(figures_path, 'ica', sub.name +
                             '_ica_comp' + '_' + sub.pr.p_preset + '.jpg')
            fig1.savefig(save_path, dpi=300)
            print('figure: ' + save_path + ' has been saved')
            if not exists(join(figures_path, 'ica/evoked_overlay')):
                makedirs(join(figures_path, 'ica/evoked_overlay'))
            save_path = join(figures_path, 'ica/evoked_overlay', sub.name +
                             '_ica_ovl' + '_' + sub.pr.p_preset + '.jpg')
            fig10.savefig(save_path, dpi=300)
            print('figure: ' + save_path + ' has been saved')

            save_path = join(figures_path, 'ica', sub.name +
                             '_ica_src' + '_' + sub.pr.p_preset + '_0.jpg')
            fig5.savefig(save_path, dpi=300)
            print('figure: ' + save_path + ' has been saved')

            save_path = join(figures_path, 'ica', sub.name +
                             '_ica_src' + '_' + sub.pr.p_preset + '_1.jpg')
            fig6.savefig(save_path, dpi=300)
            print('figure: ' + save_path + ' has been saved')

        else:
            print('Not saving plots; set "save_plots" to "True" to save')

    # If no EEG was acquired during the measurement, the components
    # have to be selected manually in ica_components.py
    else:
        print('No EEG-Channels to read EOG/EEG from')
        meg_picks = mne.pick_types(raw.info, meg=True, eeg=False, eog=False,
                                   stim=False, exclude=sub.bad_channels)
        ecg_epochs = mne.preprocessing.create_ecg_epochs(raw, picks=meg_picks,
                                                         reject=reject, flat=flat)

        if len(ecg_epochs) != 0:
            ecg_indices, ecg_scores = ica.find_bads_ecg(ecg_epochs)
            print('ECG-Components: ', ecg_indices)
            if len(ecg_indices) != 0:
                fig4 = ica.plot_scores(ecg_scores, title=sub.name + '_ecg', show=False)
                fig5 = ica.plot_properties(ecg_epochs, ecg_indices, psd_args={'fmax': sub.p["lowpass"]},
                                           image_args={'sigma': 1.}, show=False)
                fig6 = ica.plot_overlay(ecg_epochs.average(), exclude=ecg_indices, title=sub.name + '_ecg',
                                        show=False)

                save_path = join(figures_path, 'ica', sub.name +
                                 '_ica_scor_ecg' + '_' + sub.pr.p_preset + '.jpg')
                fig4.savefig(save_path, dpi=300)
                print('figure: ' + save_path + ' has been saved')
                for f in fig5:
                    save_path = join(figures_path, 'ica', sub.name +
                                     '_ica_prop_ecg' + '_' + sub.pr.p_preset
                                     + f'_{fig5.index(f)}.jpg')
                    f.savefig(save_path, dpi=300)
                    print('figure: ' + save_path + ' has been saved')
                save_path = join(figures_path, 'ica', sub.name +
                                 '_ica_ovl_ecg' + '_' + sub.pr.p_preset + '.jpg')
                fig6.savefig(save_path, dpi=300)
                print('figure: ' + save_path + ' has been saved')

        ut.dict_filehandler(sub.name, f'ica_components_{sub.pr.p_preset}', pscripts_path, values=[])

        sub.save_ica(ica)
        comp_list = list(range(ica.n_components))
        fig1 = ica.plot_components(picks=comp_list, title=sub.name, show=False)
        fig2 = ica.plot_sources(raw, picks=comp_list[:12], start=150, stop=200, title=sub.name, show=False)
        fig3 = ica.plot_sources(raw, picks=comp_list[12:], start=150, stop=200, title=sub.name, show=False)

        if save_plots and save_plots != 'false':
            save_path = join(figures_path, 'ica', sub.name +
                             '_ica_comp' + '_' + sub.pr.p_preset + '.jpg')
            fig1.savefig(save_path, dpi=300)
            print('figure: ' + save_path + ' has been saved')

            save_path = join(figures_path, 'ica', sub.name +
                             '_ica_src' + '_' + sub.pr.p_preset + '_0.jpg')
            fig2.savefig(save_path, dpi=300)
            print('figure: ' + save_path + ' has been saved')

            save_path = join(figures_path, 'ica', sub.name +
                             '_ica_src' + '_' + sub.pr.p_preset + '_1.jpg')
            fig3.savefig(save_path, dpi=300)
            print('figure: ' + save_path + ' has been saved')

        else:
            print('Not saving plots; set "save_plots" to "True" to save')
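# A minimal reload sketch (the file name below is hypothetical; `sub.save_ica`
# above decides the real path): read the fitted ICA back and apply the stored
# component exclusions to the epochs.
ica = mne.preprocessing.read_ica('sub1-ica.fif')  # hypothetical path
print('Excluded components:', ica.exclude)
epochs_clean = ica.apply(epochs.copy())  # epochs must be preloaded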
Beispiel #26
     myMap, myValidMap = myFunc.getTimes2LabelsMap(
         windowLength, winStep, fileLength, numSeizure, seizuresInfo,
         postSeizureMargin, predHorizon, preIctalMargin,
         interIctalMargin, endingFileMargin)
     # n, d = myMap.shape
     # myMap = np.hstack((myMap, np.ones((n, 1))))
     finalMap = myFunc.getFinalMap(myValidMap, myMap)
     subDict["seizurefree_data"].append({
         "fileName": relativeFilePath,
         "map": finalMap
     })
     # now we get the autoreject threshold
     times = myValidMap[:, 0]
     events = times2events(times)
     myEpochs = mne.Epochs(raw, events, event_id=None, tmin=0,
                           tmax=windowLength, preload=True,
                           baseline=None, decim=decim)
     reject = get_rejection_threshold(myEpochs)
     # append the rejection threshold
     subDict['autoreject_threshold'].append(reject["eeg"])
 else:
     """==================== one seizure or more ====================="""
     # first we get the seizures info
     for kSeizure in np.arange(numSeizure):
         startTime = myFunc.getTimeFromString(
             txtLines[seizureLineInd + 2 * kSeizure + 1])
         endTime = myFunc.getTimeFromString(
             txtLines[seizureLineInd + 2 * kSeizure + 2])
         seizuresInfo.append(np.array([startTime, endTime]))
     myMap, myValidMap, postSeizure, postSeizureValid = myFunc.getTimes2LabelsMap(
         windowLength, winStep, fileLength, numSeizure, seizuresInfo,
         postSeizureMargin, predHorizon, preIctalMargin,
         interIctalMargin, endingFileMargin)
Beispiel #27
    n_good_channels = len(
        mne.pick_types(epochs.info, eeg=True, eog=False, exclude='bads'))
    print(f'# of good channels: {n_good_channels}')
    n_thresh_channels = (
        n_good_channels * preprocess_options['perc_good_chans'])

    # Exclude epochs with voltage beyond +/- 100 microvolts
    max_voltage = np.abs(epoch_data).max(axis=2)
    ext_val_nchan = (
        np.sum(max_voltage > preprocess_options['ext_val_thresh'], axis=1))
    ext_val_bad = ext_val_nchan >= n_thresh_channels
    print('Epochs with extreme voltage on more than '
          f'{n_thresh_channels} channels:',
          ext_val_bad.nonzero()[0])

    # Exclude epochs based on the global rejection threshold
    reject = get_rejection_threshold(epochs, ch_types='eeg')
    p2p_vals = np.abs(epoch_data.max(axis=2) - epoch_data.min(axis=2))
    p2p_nchan = np.sum(p2p_vals >= reject['eeg'], axis=1)
    p2p_bad = p2p_nchan > n_thresh_channels
    print('Epochs exceeding global P2P on more than '
          f'{n_thresh_channels} channels:', p2p_bad.nonzero()[0])

    # Detect eog at stim onsets
    veog_data = epochs.copy().apply_baseline((None, None)).crop(
        tmin=-.1, tmax=.1).pick_channels(['VEOG']).get_data()
    veog_diff = np.abs(veog_data.max(axis=2) - veog_data.min(axis=2))
    blink_inds = np.where(veog_diff.squeeze() >
                          preprocess_options['blink_thresh'])[0]
    print('Epochs with blink at stim onset:', blink_inds)

    # Make color index
Beispiel #28
raw.info['bads'] = ['MEG1031', 'MEG1111', 'MEG1941']

sss_params_dir = '/storage/local/camcan/maxfilter'
cal = op.join(sss_params_dir, 'sss_params', 'sss_cal.dat')
ctc = op.join(sss_params_dir, 'sss_params', 'ct_sparse.fif')
raw = mne.preprocessing.maxwell_filter(raw,
                                       calibration=cal,
                                       cross_talk=ctc,
                                       st_duration=10.,
                                       st_correlation=.98,
                                       destination=None,
                                       coord_frame='head')

eog_epochs = mne.preprocessing.create_eog_epochs(raw)
if len(eog_epochs) >= 5:
    reject_eog = get_rejection_threshold(eog_epochs, decim=8)
    del reject_eog['eog']  # we don't want to reject EOG epochs based on the EOG channel itself.
else:
    reject_eog = None

ecg_epochs = mne.preprocessing.create_ecg_epochs(raw)
if len(ecg_epochs) >= 5:
    reject_ecg = get_rejection_threshold(ecg_epochs, decim=8)
    # here we do want to keep the eog threshold.
else:
    reject_ecg = None

if reject_eog is None and reject_ecg is not None:
    reject_eog = {k: v for k, v in reject_ecg.items() if k != 'eog'}

proj_eog, _ = mne.preprocessing.compute_proj_eog(raw,
Beispiel #29
def run_epochs(subject_id, tsss=False):
    subject = "sub%03d" % subject_id
    print("Processing subject: %s%s" % (subject,
                                        (' (tSSS=%d)' % tsss) if tsss else ''))

    data_path = op.join(meg_dir, subject)

    # map to correct subject for bad channels
    mapping = map_subjects[subject_id]

    raw_list = list()
    events_list = list()
    print("  Loading raw data")
    for run in range(1, 7):
        bads = list()
        bad_name = op.join('bads', mapping, 'run_%02d_raw_tr.fif_bad' % run)
        if os.path.exists(bad_name):
            with open(bad_name) as f:
                for line in f:
                    bads.append(line.strip())

        if tsss:
            run_fname = op.join(data_path,
                                'run_%02d_filt_tsss_%d_raw.fif' % (run, tsss))
        else:
            run_fname = op.join(
                data_path, 'run_%02d_filt_sss_'
                'highpass-%sHz_raw.fif' % (run, l_freq))

        raw = mne.io.read_raw_fif(run_fname, preload=True)

        delay = int(round(0.0345 * raw.info['sfreq']))
        events = mne.read_events(op.join(data_path, 'run_%02d-eve.fif' % run))
        events[:, 0] = events[:, 0] + delay
        events_list.append(events)

        raw.info['bads'] = bads
        raw.interpolate_bads()
        raw_list.append(raw)

    raw, events = mne.concatenate_raws(raw_list, events_list=events_list)
    raw.set_eeg_reference(projection=True)
    del raw_list

    picks = mne.pick_types(raw.info,
                           meg=True,
                           eeg=True,
                           stim=True,
                           eog=True,
                           exclude=())

    # Epoch the data
    print('  Epoching')
    epochs = mne.Epochs(raw,
                        events,
                        events_id,
                        tmin,
                        tmax,
                        proj=True,
                        picks=picks,
                        baseline=baseline,
                        preload=False,
                        decim=5,
                        reject=None,
                        reject_tmax=reject_tmax)

    # ICA
    if tsss:
        ica_name = op.join(meg_dir, subject,
                           'run_concat-tsss_%d-ica.fif' % (tsss, ))
        ica_out_name = ica_name
    else:
        ica_name = op.join(meg_dir, subject, 'run_concat-ica.fif')
        ica_out_name = op.join(meg_dir, subject,
                               'run_concat_highpass-%sHz-ica.fif' % (l_freq, ))
    print('  Using ICA')
    ica = read_ica(ica_name)
    ica.exclude = []

    filter_label = '-tsss_%d' % tsss if tsss else '_highpass-%sHz' % l_freq
    ecg_epochs = create_ecg_epochs(raw, tmin=-.3, tmax=.3, preload=False)
    eog_epochs = create_eog_epochs(raw, tmin=-.5, tmax=.5, preload=False)
    del raw

    n_max_ecg = 3  # use max 3 components
    ecg_epochs.decimate(5)
    ecg_epochs.load_data()
    ecg_epochs.apply_baseline((None, None))
    ecg_inds, scores_ecg = ica.find_bads_ecg(ecg_epochs,
                                             method='ctps',
                                             threshold=0.8)
    print('    Found %d ECG indices' % (len(ecg_inds), ))
    ica.exclude.extend(ecg_inds[:n_max_ecg])
    ecg_epochs.average().save(
        op.join(data_path, '%s%s-ecg-ave.fif' % (subject, filter_label)))
    np.save(
        op.join(data_path, '%s%s-ecg-scores.npy' % (subject, filter_label)),
        scores_ecg)
    del ecg_epochs

    n_max_eog = 3  # use at most 3 components
    eog_epochs.decimate(5)
    eog_epochs.load_data()
    eog_epochs.apply_baseline((None, None))
    eog_inds, scores_eog = ica.find_bads_eog(eog_epochs)
    print('    Found %d EOG indices' % (len(eog_inds), ))
    ica.exclude.extend(eog_inds[:n_max_eog])
    eog_epochs.average().save(
        op.join(data_path, '%s%s-eog-ave.fif' % (subject, filter_label)))
    np.save(
        op.join(data_path, '%s%s-eog-scores.npy' % (subject, filter_label)),
        scores_eog)
    del eog_epochs

    ica.save(ica_out_name)
    epochs.load_data()
    ica.apply(epochs)

    print('  Getting rejection thresholds')
    reject = get_rejection_threshold(epochs.copy().crop(None, reject_tmax))
    epochs.drop_bad(reject=reject)
    print('  Dropped %0.1f%% of epochs' % (epochs.drop_log_stats(), ))

    print('  Writing to disk')
    if tsss:
        epochs.save(op.join(data_path, '%s-tsss_%d-epo.fif' % (subject, tsss)))
    else:
        epochs.save(
            op.join(data_path, '%s_highpass-%sHz-epo.fif' % (subject, l_freq)))
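# A minimal sketch for reading the saved epochs back, following the naming
# scheme above (`data_path`, `subject` and `l_freq` are assumed to be defined
# as in the example):
epo_fname = op.join(data_path, '%s_highpass-%sHz-epo.fif' % (subject, l_freq))
epochs = mne.read_epochs(epo_fname)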
Beispiel #30
def run_epochs(subject,
               epoch_on_first_element,
               baseline=True,
               l_freq=None,
               h_freq=None,
               suffix='_eeg_1Hz'):

    print("Processing subject: %s" % subject)
    meg_subject_dir = op.join(config.meg_dir, subject)
    run_info_subject_dir = op.join(config.run_info_dir, subject)
    raw_list = list()
    events_list = list()

    print("  Loading raw data")
    runs = config.runs_dict[subject]
    for run in runs:
        extension = run + '_ica_raw'
        raw_fname_in = op.join(meg_subject_dir,
                               config.base_fname.format(**locals()))
        raw = mne.io.read_raw_fif(raw_fname_in, preload=True)

        # ---------------------------------------------------------------------------------------------------------------- #
        # RESAMPLING EACH RUN BEFORE CONCAT & EPOCHING
        # Resampling the raw data while keeping events from original raw data, to avoid potential loss of
        # events when downsampling: https://www.nmr.mgh.harvard.edu/mne/dev/auto_examples/preprocessing/plot_resample.html
        # Find events
        events = mne.find_events(raw,
                                 stim_channel=config.stim_channel,
                                 consecutive=True,
                                 min_duration=config.min_event_duration,
                                 shortest_event=config.shortest_event)

        print('  Downsampling raw data')
        raw, events = raw.resample(config.resample_sfreq,
                                   npad='auto',
                                   events=events)
        if len(events) != 46 * 16:
            raise Exception('We expected %i events but we got %i' %
                            (46 * 16, len(events)))
        raw.filter(l_freq=1, h_freq=None)
        raw_list.append(raw)
        # ---------------------------------------------------------------------------------------------------------------- #

    if subject == 'sub08-cc_150418':
        # For this participant, we had some problems when concatenating the raws for run08. The error message said that raw08._cals didn't match the other ones.
        # We saw that it is the 'calibration' for the channel EOG061 that was different with respect to run09._cals.
        raw_list[7]._cals = raw_list[8]._cals
        print(
            'Warning: corrected an issue with subject08 run08 ica_raw data file...'
        )

    print('Concatenating runs')
    raw = mne.concatenate_raws(raw_list)
    if "eeg" in config.ch_types:
        raw.set_eeg_reference(projection=True)
    del raw_list

    meg = False
    if 'meg' in config.ch_types:
        meg = True
    elif 'grad' in config.ch_types:
        meg = 'grad'
    elif 'mag' in config.ch_types:
        meg = 'mag'
    eeg = 'eeg' in config.ch_types
    picks = mne.pick_types(raw.info,
                           meg=meg,
                           eeg=eeg,
                           stim=True,
                           eog=True,
                           exclude=())

    # Construct metadata from csv events file
    metadata = epoching_funcs.convert_csv_info_to_metadata(
        run_info_subject_dir)
    metadata_pandas = pd.DataFrame.from_dict(metadata, orient='index')
    metadata_pandas = pd.DataFrame.transpose(metadata_pandas)

    # ====== Epoching the data
    print('  Epoching')

    # Events
    events = mne.find_events(raw,
                             stim_channel=config.stim_channel,
                             consecutive=True,
                             min_duration=config.min_event_duration,
                             shortest_event=config.shortest_event)

    if epoch_on_first_element:
        # fosca 06012020
        config.tmin = -0.200
        config.tmax = 0.25 * 17
        config.baseline = (config.tmin, 0)
        if baseline is None:
            config.baseline = None
        for k in range(len(events)):
            events[k, 2] = k % 16 + 1
        epochs = mne.Epochs(raw,
                            events, {'sequence_starts': 1},
                            config.tmin,
                            config.tmax,
                            proj=True,
                            picks=picks,
                            baseline=config.baseline,
                            preload=False,
                            decim=config.decim,
                            reject=None)
        epochs.metadata = metadata_pandas[metadata_pandas['StimPosition'] ==
                                          1.0]
    else:
        config.tmin = -0.050
        config.tmax = 0.600
        config.baseline = (config.tmin, 0)
        if baseline is None:
            config.baseline = None
        epochs = mne.Epochs(raw,
                            events,
                            None,
                            config.tmin,
                            config.tmax,
                            proj=True,
                            picks=picks,
                            baseline=config.baseline,
                            preload=False,
                            decim=config.decim,
                            reject=None)

        # Add metadata to epochs
        epochs.metadata = metadata_pandas

    # Save epochs (before AutoReject)
    print('  Writing epochs to disk')
    if epoch_on_first_element:
        extension = subject + '_1st_element_epo' + suffix
    else:
        extension = subject + '_epo' + suffix
    epochs_fname = op.join(meg_subject_dir,
                           config.base_fname.format(**locals()))

    print("Output: ", epochs_fname)
    epochs.save(epochs_fname, overwrite=True)
    # epochs.save(epochs_fname)

    if config.autoreject:
        epochs.load_data()

        # Running AutoReject "global" (https://autoreject.github.io) -> just get the thresholds
        from autoreject import get_rejection_threshold
        reject = get_rejection_threshold(epochs,
                                         ch_types=['mag', 'grad', 'eeg'])
        epochsARglob = epochs.copy().drop_bad(reject=reject)
        print('  Writing "AR global" cleaned epochs to disk')
        if epoch_on_first_element:
            extension = subject + '_1st_element_ARglob_epo' + suffix
        else:
            extension = subject + '_ARglob_epo' + suffix
        epochs_fname = op.join(meg_subject_dir,
                               config.base_fname.format(**locals()))
        print("Output: ", epochs_fname)
        epochsARglob.save(epochs_fname, overwrite=True)
        # Save autoreject thresholds
        pickle.dump(reject,
                    open(epochs_fname[:-4] + '_ARglob_thresholds.obj', 'wb'))

        # Running AutoReject "local" (https://autoreject.github.io)
        ar = AutoReject()
        epochsAR, reject_log = ar.fit_transform(epochs, return_log=True)
        print('  Writing "AR local" cleaned epochs to disk')
        if epoch_on_first_element:
            extension = subject + '_1st_element_clean_epo' + suffix
        else:
            extension = subject + '_clean_epo' + suffix
        epochs_fname = op.join(meg_subject_dir,
                               config.base_fname.format(**locals()))
        print("Output: ", epochs_fname)
        epochsAR.save(epochs_fname, overwrite=True)
        # Save autoreject reject_log
        pickle.dump(reject_log,
                    open(epochs_fname[:-4] + '_reject_local_log.obj', 'wb'))
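# A minimal sketch for reading the pickled "AR global" thresholds back,
# assuming the same `epochs_fname` naming scheme as above:
import pickle
with open(epochs_fname[:-4] + '_ARglob_thresholds.obj', 'rb') as f:
    saved_reject = pickle.load(f)
print('Saved AR-global thresholds:', saved_reject)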
Beispiel #31
def run_epochs(subject,
               epoch_on_first_element,
               baseline=True,
               tmin=None,
               tmax=None,
               whattoreturn=None):

    # Set this param to True if you want to run autoreject locally too when config.autoreject = True
    from datetime import datetime
    now = datetime.now().time()

    ARlocal = False

    print("Processing subject: %s" % subject)
    meg_subject_dir = op.join(config.meg_dir, subject)
    run_info_subject_dir = op.join(config.run_info_dir, subject)
    raw_list = list()
    events_list = list()

    if config.noEEG:
        output_dir = op.join(meg_subject_dir, 'noEEG')
        utils.create_folder(output_dir)
    else:
        output_dir = meg_subject_dir

    print("  Loading raw data")
    runs = config.runs_dict[subject]
    for run in runs:
        extension = run + '_ica_raw'
        print(extension)
        raw_fname_in = op.join(meg_subject_dir,
                               config.base_fname.format(**locals()))
        raw = mne.io.read_raw_fif(raw_fname_in, preload=True)

        # ---------------------------------------------------------------------------------------------------------------- #
        # RESAMPLING EACH RUN BEFORE CONCAT & EPOCHING
        # Resampling the raw data while keeping events from original raw data, to avoid potential loss of
        # events when downsampling: https://www.nmr.mgh.harvard.edu/mne/dev/auto_examples/preprocessing/plot_resample.html
        # Find events
        events = mne.find_events(raw,
                                 stim_channel=config.stim_channel,
                                 consecutive=True,
                                 min_duration=config.min_event_duration,
                                 shortest_event=config.shortest_event)

        print('  Downsampling raw data')
        raw, events = raw.resample(config.resample_sfreq,
                                   npad='auto',
                                   events=events)

        times_between_events_and_end = (raw.last_samp -
                                        events[:, 0]) / raw.info['sfreq']
        if np.sum(times_between_events_and_end < 0.6) > 0:
            print("=== some events are too close to the end ====")

        if len(events) != 46 * 16:
            raise Exception('We expected %i events but we got %i' %
                            (46 * 16, len(events)))

        raw_list.append(raw)
        # ---------------------------------------------------------------------------------------------------------------- #

    if subject == 'sub08-cc_150418':
        # For this participant, we had some problems when concatenating the raws for run08. The error message said that raw08._cals didn't match the other ones.
        # We saw that it is the 'calibration' for the channel EOG061 that was different with respect to run09._cals.
        raw_list[7]._cals = raw_list[8]._cals
        print(
            'Warning: corrected an issue with subject08 run08 ica_raw data file...'
        )

    print('Concatenating runs')
    raw = mne.concatenate_raws(raw_list)
    # raw.set_annotations(None)
    if "eeg" in config.ch_types:
        raw.set_eeg_reference(projection=True)
    del raw_list

    # Save resampled, concatenated runs (in case we need it)
    # print('Saving concatenated runs')
    # fname = op.join(meg_subject_dir, subject + '_allruns_final_raw.fif')
    # raw.save(fname, overwrite=True)

    if config.noEEG:
        picks = mne.pick_types(raw.info,
                               meg=True,
                               eeg=False,
                               stim=True,
                               eog=True,
                               exclude=())
    else:
        picks = mne.pick_types(raw.info,
                               meg=True,
                               eeg=True,
                               stim=True,
                               eog=True,
                               exclude=())

    # Construct metadata from csv events file
    metadata = convert_csv_info_to_metadata(run_info_subject_dir)
    metadata_pandas = pd.DataFrame.from_dict(metadata, orient='index')
    metadata_pandas = pd.DataFrame.transpose(metadata_pandas)

    # ====== Epoching the data
    print('  Epoching')

    # Events
    events = mne.find_events(raw,
                             stim_channel=config.stim_channel,
                             consecutive=True,
                             min_duration=config.min_event_duration,
                             shortest_event=config.shortest_event)

    if epoch_on_first_element:
        # fosca 06012020
        if tmin is None:
            tmin = -0.200
        if tmax is None:
            tmax = 0.25 * 17
        if (baseline is None) or (baseline is False):
            baseline = None
        else:
            baseline = (tmin, 0)
        for k in range(len(events)):
            events[k, 2] = k % 16 + 1
        epochs = mne.Epochs(raw,
                            events, {'sequence_starts': 1},
                            tmin,
                            tmax,
                            proj=True,
                            picks=picks,
                            baseline=baseline,
                            preload=False,
                            decim=config.decim,
                            reject=None)
        epochs.metadata = metadata_pandas[metadata_pandas['StimPosition'] ==
                                          1.0]
    else:
        if tmin is None:
            tmin = -0.050
        if tmax is None:
            tmax = 0.600
        if (baseline is None) or (baseline is False):
            baseline = None
        else:
            baseline = (tmin, 0)

        epochs = mne.Epochs(raw,
                            events,
                            None,
                            tmin,
                            tmax,
                            proj=True,
                            picks=picks,
                            baseline=baseline,
                            preload=False,
                            decim=config.decim,
                            reject=None)

        # Add metadata to epochs
        epochs.metadata = metadata_pandas

    # Save epochs (before AutoReject)

    if whattoreturn is None:
        print('  Writing epochs to disk')
        if epoch_on_first_element:
            extension = subject + '_1st_element_epo'
        else:
            extension = subject + '_epo'
        epochs_fname = op.join(output_dir,
                               config.base_fname.format(**locals()))
        print("Output: ", epochs_fname)
        epochs.save(epochs_fname, overwrite=True)
    elif whattoreturn == '':
        epochs.load_data()
        return epochs
    else:
        print("=== we continue on the autoreject part ===")

    if config.autoreject:
        epochs.load_data()
        # Running AutoReject "global" (https://autoreject.github.io) -> just get the thresholds
        from autoreject import get_rejection_threshold
        reject = get_rejection_threshold(epochs, ch_types=config.ch_types)
        epochsARglob = epochs.copy().drop_bad(reject=reject)
        print('  Writing "AR global" cleaned epochs to disk')
        if epoch_on_first_element:
            extension = subject + '_1st_element_ARglob_epo'
        else:
            extension = subject + '_ARglob_epo'
        epochs_fname = op.join(output_dir,
                               config.base_fname.format(**locals()))
        if whattoreturn is None:
            print("Output: ", epochs_fname)
            epochsARglob.save(epochs_fname, overwrite=True)
            pickle.dump(
                reject, open(epochs_fname[:-4] + '_ARglob_thresholds.obj',
                             'wb'))
        elif whattoreturn == 'ARglobal':
            return epochsARglob
        else:
            print("==== continue to ARlocal ====")
        # Save autoreject thresholds

        # Running AutoReject "local" (https://autoreject.github.io)
        if ARlocal:
            ar = AutoReject()
            epochsAR, reject_log = ar.fit_transform(epochs, return_log=True)
            print('  Writing "AR local" cleaned epochs to disk')
            if epoch_on_first_element:
                extension = subject + '_1st_element_clean_epo'
            else:
                extension = subject + '_clean_epo'
            epochs_fname = op.join(output_dir,
                                   config.base_fname.format(**locals()))
            if whattoreturn is None:
                print("Output: ", epochs_fname)
                epochsAR.save(epochs_fname, overwrite=True)
                # Save autoreject reject_log
                pickle.dump(
                    reject_log,
                    open(epochs_fname[:-4] + '_reject_local_log.obj', 'wb'))
            else:
                return epochsAR
from autoreject import validation_curve  # noqa
from autoreject import get_rejection_threshold  # noqa

_, test_scores, param_range = validation_curve(epochs,
                                               param_range=param_range,
                                               cv=5,
                                               return_param_range=True,
                                               n_jobs=1)

test_scores = -test_scores.mean(axis=1)
best_thresh = param_range[np.argmin(test_scores)]

###############################################################################
# We can also get the best threshold more efficiently using Bayesian
# optimization
reject2 = get_rejection_threshold(epochs, random_state=0, cv=5)
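###############################################################################
# As a quick, optional check (not in the original example), the two estimates
# can be compared side by side; indexing with 'eeg' assumes EEG epochs, as the
# micro-volt scaling below suggests.

print('Grid-search threshold: %s' % best_thresh)
print('Bayesian-optimization threshold: %s' % reject2['eeg'])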

###############################################################################
# Now let us plot the RMSE values against the candidate thresholds.

import matplotlib.pyplot as plt  # noqa
from autoreject import set_matplotlib_defaults  # noqa
set_matplotlib_defaults(plt)

human_thresh = 80e-6  # this is a threshold determined visually by a human
unit = r'$\mu$V'
scaling = 1e6

plt.figure(figsize=(8, 5))
plt.tick_params(axis='x', which='both', bottom=False, top=False)
plt.tick_params(axis='y', which='both', left=False, right=False)