示例#1
0
def segment_files(bids_filepath):
    """Epoch a FIF recording and clean the epochs with AutoReject.

    Parameters
    ----------
    bids_filepath : str | path-like
        Path of the raw FIF file; it is read with ``preload=True``.

    Returns
    -------
    epochs_clean : mne.Epochs
        Epochs from 0 to 0.8 s after each 'Freq' (21), 'Rare' (31) and
        'Resp' (99) event, after AutoReject repair/rejection.
    """
    raw = read_raw_fif(bids_filepath, preload=True)
    # MEG channels only: reference MEG, EEG, EOG and stim channels are left out.
    picks = mne.pick_types(raw.info,
                           meg=True,
                           ref_meg=False,
                           eeg=False,
                           eog=False,
                           stim=False)
    # Epoching constants. No baseline correction and no fixed peak-to-peak
    # threshold (reject=None): AutoReject estimates thresholds from the data.
    baseline = None
    tmin, tmax = 0, 0.8
    # Only events lasting at least two samples are kept.
    events = mne.find_events(raw, min_duration=2 / raw.info['sfreq'])
    event_id = {'Freq': 21, 'Rare': 31, 'Resp': 99}
    epochs = mne.Epochs(raw,
                        events=events,
                        event_id=event_id,
                        tmin=tmin,
                        tmax=tmax,
                        baseline=baseline,
                        reject=None,
                        picks=picks,
                        preload=True)
    ar = AutoReject()
    epochs_clean = ar.fit_transform(epochs)
    return epochs_clean
示例#2
0
文件: neuro.py 项目: hyruuk/saflow
def segment_files(bids_filepath, tmin=0, tmax=0.8):
    """Cut a FIF recording into stimulus epochs and clean them with AutoReject.

    Parameters
    ----------
    bids_filepath : str | path-like
        Path of the raw FIF file (loaded fully into memory).
    tmin, tmax : float
        Epoch window, in seconds, relative to each event onset.

    Returns
    -------
    epochs_clean : mne.Epochs
        Epochs for the 'Freq' (21) and 'Rare' (31) events after AutoReject.
    autoreject_log : autoreject.RejectLog
        Log of the epochs/channels AutoReject rejected or interpolated.
    """
    raw = read_raw_fif(bids_filepath, preload=True)
    sfreq = raw.info['sfreq']

    # Keep MEG and reference-MEG channels; leave out EEG, EOG and stim.
    meg_picks = mne.pick_types(raw.info, meg=True, ref_meg=True, eeg=False,
                               eog=False, stim=False)

    # Try a one-sample minimum event duration first and fall back to two
    # samples when mne.find_events raises ValueError.
    try:
        events = mne.find_events(raw, min_duration=1 / sfreq, verbose=False)
    except ValueError:
        events = mne.find_events(raw, min_duration=2 / sfreq, verbose=False)

    # No baseline correction and no fixed rejection threshold: AutoReject
    # estimates the thresholds from the data.
    epoch_kwargs = dict(
        events=events,
        event_id={'Freq': 21, 'Rare': 31},
        tmin=tmin,
        tmax=tmax,
        baseline=None,
        reject=None,
        picks=meg_picks,
        preload=True,
    )
    epochs = mne.Epochs(raw, **epoch_kwargs)

    autorejecter = AutoReject(n_jobs=6)
    epochs_clean, autoreject_log = autorejecter.fit_transform(
        epochs, return_log=True)
    return epochs_clean, autoreject_log
示例#3
0
def autoreject_marmouset(subject):
    """Rebuild a saved marmoset epochs object, clean it with AutoReject, save it.

    Inputs are read from
    ``<root>/neural_data/<subject>/epoch_items_{data.npy,info.npy,metadata.pkl}``.
    Outputs written next to them:
    ``epoch_items_clean.fif``, ``epoch_items_clean_reject_log.obj`` (pickled
    reject log), ``epoch_items_data_clean.npy``,
    ``epoch_items_metadata_clean.pkl`` and ``epoch_items_info_clean.npy``.

    Parameters
    ----------
    subject : str
        Sub-directory of ``neural_data`` to process (e.g. ``'Nr'``). The
        function previously overwrote this argument with ``'Nr'``, silently
        ignoring the caller's value.
    """
    root_path = '/neurospin/unicog/protocols/ABSeq_marmousets/'
    neural_data_path = root_path + 'neural_data/'

    epoch_name = '/epoch_items'
    tmin = -0.099
    # Common prefix of every input and output file of this subject.
    base_fname = neural_data_path + subject + epoch_name

    # ======== rebuild the epoch object and run autoreject ========
    epoch_data = np.load(base_fname + '_data.npy')
    # NOTE: allow_pickle=True unpickles file contents; only use on trusted data.
    info = np.load(base_fname + '_info.npy', allow_pickle=True).item()
    metadata = np.load(base_fname + '_metadata.pkl', allow_pickle=True)
    epochs = mne.EpochsArray(epoch_data, info=info, tmin=tmin)
    epochs.metadata = metadata
    epochs.load_data()

    # ======== clean and save ========
    ar = AutoReject()
    epochs, reject_log = ar.fit_transform(epochs, return_log=True)
    epochs_clean_fname = base_fname + '_clean.fif'
    print("Output: ", epochs_clean_fname)
    epochs.save(epochs_clean_fname, overwrite=True)
    # Save the autoreject reject_log; the context manager closes the file.
    with open(epochs_clean_fname[:-4] + '_reject_log.obj', 'wb') as fid:
        pickle.dump(reject_log, fid)
    np.save(base_fname + '_data_clean.npy', epochs.get_data())
    epochs.metadata.to_pickle(base_fname + '_metadata_clean.pkl')
    np.save(base_fname + '_info_clean.npy', epochs.info)
def autoreject_repair_epochs(epochs, reject_plot=False):
    """Rejects the bad epochs with AutoReject algorithm

    Parameters
    ----------
    epochs : mne epoch object
        Epoched, filtered eeg data.
    reject_plot : bool
        If True, plot the epochs with the AutoReject rejection log overlaid
        (EEG scaled at 40e-6).

    Returns
    -------
    epochs : mne epoch object
        Epoched data after rejection of bad epochs.

    """
    # Cleaning with autoreject
    picks = mne.pick_types(epochs.info, eeg=True)  # Pick EEG channels
    # ``thresh_method`` is AutoReject's string option selecting the threshold
    # search strategy; it was previously passed as ``thresh_func``, which is
    # not the string-valued option.
    ar = AutoReject(n_interpolate=[1, 2, 3],
                    n_jobs=6,
                    picks=picks,
                    thresh_method='bayesian_optimization',
                    cv=3,
                    random_state=42,
                    verbose=False)

    cleaned_epochs, reject_log = ar.fit_transform(epochs, return_log=True)

    if reject_plot:
        reject_log.plot_epochs(epochs, scalings=dict(eeg=40e-6))

    return cleaned_epochs
示例#5
0
def run_autoreject(subject, epoch_on_first_element):
    """Clean one subject's saved epochs with AutoReject and write them to disk.

    Parameters
    ----------
    subject : str
        Subject identifier (sub-directory of ``config.meg_dir``).
    epoch_on_first_element : bool
        If True, load the 'full sequences' epochs
        (``epoching_funcs.load_epochs_full_sequence``); otherwise the 'items'
        epochs (``epoching_funcs.load_epochs_items``).

    The cleaned epochs are saved in ``config.meg_dir/<subject>`` under a name
    built from ``config.base_fname``; the AutoReject reject log is pickled to
    the same path with the last 4 characters replaced by
    ``_reject_log.obj``.
    """
    N_JOBS_ar = 1  # "The number of thresholds to compute in parallel."

    print(
        '#########################################################################################'
    )
    print(
        '########################## Processing subject: %s ##########################'
        % subject)
    print(
        '#########################################################################################'
    )

    if epoch_on_first_element:
        print("  Loading 'full sequences' epochs")
        epochs = epoching_funcs.load_epochs_full_sequence(subject,
                                                          cleaned=False)
    else:
        print("  Loading 'items' epochs")
        epochs = epoching_funcs.load_epochs_items(subject, cleaned=False)

    # Running AutoReject (https://autoreject.github.io)
    epochs.load_data()
    ar = AutoReject(n_jobs=N_JOBS_ar)
    epochs, reject_log = ar.fit_transform(epochs, return_log=True)

    # Save epochs (after AutoReject)
    print('  Writing cleaned epochs to disk')
    meg_subject_dir = op.join(config.meg_dir, subject)
    if epoch_on_first_element:
        extension = subject + '_1st_element_clean_epo'
    else:
        extension = subject + '_clean_epo'
    # config.base_fname is a format template filled from this function's
    # local variables (presumably ``subject``/``extension``), so keep those
    # local names unchanged.
    epochs_fname = op.join(meg_subject_dir,
                           config.base_fname.format(**locals()))
    print("Output: ", epochs_fname)
    epochs.save(epochs_fname)  # , overwrite=True)

    # Save autoreject reject_log; the context manager closes the file.
    with open(epochs_fname[:-4] + '_reject_log.obj', 'wb') as fid:
        pickle.dump(reject_log, fid)
示例#6
0
def runautoreject(epochs,
                  fiffile,
                  senstype,
                  bads=None,
                  n_interpolates=np.array([1, 4, 32]),
                  consensus_percs=np.linspace(0, 1, 11)):
    """Fit AutoReject on ``epochs`` and return the resulting reject log.

    Before fitting, ``epochs.info`` is replaced by the MEG-only info of
    ``fiffile`` with its bad channels and projectors cleared, because the
    incoming epochs carry no usable channel info.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs to evaluate. Modified in place: ``info``, ``verbose``,
        ``baseline``, ``preload`` and ``detrend`` are overwritten.
    fiffile : str
        FIF file whose channel info is copied onto ``epochs``.
    senstype : bool | str
        Passed as ``meg=`` to ``mne.pick_types``.
    bads : list of str | None
        Channel names excluded from the picks. ``None`` (default) excludes
        none.
    n_interpolates : array-like
        Candidate numbers of channels to interpolate.
    consensus_percs : array-like
        Candidate consensus fractions.

    Returns
    -------
    reject_log : autoreject.RejectLog
    """
    # Avoid a shared mutable default list.
    if bads is None:
        bads = []

    raw = mne.io.read_raw_fif(fiffile, preload=True)
    raw.info['bads'] = list()
    raw.pick_types(meg=True)
    raw.info['projs'] = list()
    epochs.info = raw.info  # required since no channel infos

    del raw

    picks = mne.pick_types(epochs.info,
                           meg=senstype,
                           eeg=False,
                           stim=False,
                           eog=False,
                           include=[],
                           exclude=bads)

    epochs.verbose = False
    epochs.baseline = (None, 0)
    epochs.preload = True
    epochs.detrend = 0

    # Reproducibility comes from random_state=42; the former discarded
    # ``check_random_state(42)`` call had no effect.
    ar = AutoReject(n_interpolate=n_interpolates,
                    consensus=consensus_percs,
                    picks=picks,
                    thresh_method='bayesian_optimization',
                    random_state=42,
                    verbose=False)

    epochs, reject_log = ar.fit_transform(epochs, return_log=True)
    return reject_log
def run_autoreject(epochs, show_figs=False, results_dir=None):
    """Clean ``epochs`` with a default-configured AutoReject and return them.

    ``show_figs`` and ``results_dir`` are accepted for interface
    compatibility but currently have no effect.
    """
    from autoreject import AutoReject

    # TODO: when show_figs is True or results_dir is given, plot the reject
    # log (entries: 0 = good segment, 1 = bad not interpolated,
    # 2 = bad interpolated) and save it as '4a_bad_epochs.png' in results_dir.
    cleaner = AutoReject()
    return cleaner.fit_transform(epochs)
示例#8
0
def reject_epochs(epochs, autoreject_parameters):
    """Clean ``epochs`` with AutoReject and save a diagnostic figure.

    The figure (``_out_folder / "reject_epochs.pdf"``) has two panels:
    a histogram of the learned per-sensor thresholds in microvolts, and the
    mean cross-validation error of the EEG parameter grid (consensus x
    n_interpolate) with the best cell outlined in red.

    Parameters
    ----------
    epochs : mne.Epochs
    autoreject_parameters : dict
        Keyword arguments forwarded to ``AutoReject``.

    Returns
    -------
    mne.Epochs
        The cleaned epochs.
    """
    ar = AutoReject(**autoreject_parameters, verbose="tqdm")
    epochs = ar.fit_transform(epochs)

    fig, (ax_thresh, ax_loss) = plt.subplots(2)

    # Top panel: distribution of rejection thresholds, converted to microvolts.
    thresholds_uv = 1e6 * np.array(list(ar.threshes_.values()))
    ax_thresh.set_title("Rejection Thresholds")
    ax_thresh.hist(thresholds_uv, 30, color='g', alpha=0.4)
    ax_thresh.set(xlabel='Threshold (μV)', ylabel='Number of sensors')

    # Bottom panel: CV error averaged over the last axis; losses are stored
    # per channel type.
    loss = ar.loss_['eeg'].mean(axis=-1)
    im = ax_loss.matshow(loss.T * 1e6, cmap=plt.get_cmap('viridis'))
    consensus_labels = ['%.1f' % c for c in ar.consensus]
    ax_loss.set_xticks(range(len(ar.consensus)))
    ax_loss.set_xticklabels(consensus_labels)
    ax_loss.set_yticks(range(len(ar.n_interpolate)))
    ax_loss.set_yticklabels(ar.n_interpolate)

    # Outline the grid cell with the lowest error.
    best_col, best_row = np.unravel_index(loss.argmin(), loss.shape)
    best_cell = patches.Rectangle((best_col - 0.5, best_row - 0.5), 1, 1,
                                  linewidth=2, edgecolor='r',
                                  facecolor='none')
    ax_loss.add_patch(best_cell)
    ax_loss.xaxis.set_ticks_position('bottom')
    ax_loss.set(xlabel=r'Consensus percentage $\kappa$',
                ylabel=r'Max sensors interpolated $\rho$',
                title='Mean cross validation error (x 1e6)')

    fig.colorbar(im)
    fig.tight_layout()
    fig.savefig(_out_folder / Path("reject_epochs.pdf"), dpi=800)
    plt.close()
    return epochs
示例#9
0
def test_fnirs():
    """Test that autoreject runs on fNIRS data.

    Uses participant 1 of the MNE ``fnirs_motor`` dataset (first 1200 s),
    converted to haemoglobin concentrations and epoched from -5 s to 15 s
    around the three task annotations.
    """
    data_dir = os.path.join(mne.datasets.fnirs_motor.data_path(),
                            'Participant-1')
    raw = mne.io.read_raw_nirx(data_dir)
    raw.crop(tmax=1200)
    # Raw intensity -> optical density -> HbO/HbR concentrations.
    optical_density = mne.preprocessing.nirs.optical_density(raw)
    raw = mne.preprocessing.nirs.beer_lambert_law(optical_density)

    annotation_codes = {'1.0': 1, '2.0': 2, '3.0': 3}
    events, _ = mne.events_from_annotations(raw, event_id=annotation_codes)
    condition_codes = {'Control': 1, 'Tapping/Left': 2, 'Tapping/Right': 3}
    epochs = mne.Epochs(raw, events, event_id=condition_codes,
                        tmin=-5, tmax=15, proj=True, baseline=(None, 0),
                        preload=True, detrend=None, verbose=True)

    # AutoReject must drop at least one of the 37 epochs.
    ar = AutoReject()
    assert len(epochs) == 37
    epochs_clean = ar.fit_transform(epochs)
    assert len(epochs_clean) < len(epochs)

    # Global thresholds must exist for both haemoglobin channel types and be
    # small positive numbers.
    reject = get_rejection_threshold(epochs)
    print(reject)
    for ch_type in ("hbo", "hbr"):
        assert ch_type in reject
    assert reject["hbo"] < 0.001  # very loose upper bound, sanity check only
    assert reject["hbr"] < 0.001
    assert reject["hbr"] > 0.0
示例#10
0
def run_autoreject(subject):
    """Interpolate bad epochs/sensors using Autoreject.

    Parameters
    ----------
    subject : str
        The participant reference. Epochs are read from
        ``root/4_ICA/<subject>-epo.fif``.

    Saves the cleaned epochs as ``root/5_autoreject/<subject>-epo.fif`` and
    two .png figures in the same directory: the evoked response before/after
    autoreject (``-autoreject.png``) and the reject-log heatmap
    (``-heatmap.png``).

    References
    ----------
    [1] Mainak Jas, Denis Engemann, Federico Raimondo, Yousra Bekhti, and
    Alexandre Gramfort, “Automated rejection and repair of bad trials in
    MEG/EEG.” In 6th International Workshop on Pattern Recognition in
    Neuroimaging (PRNI), 2016.

    [2] Mainak Jas, Denis Engemann, Yousra Bekhti, Federico Raimondo, and
    Alexandre Gramfort. 2017. “Autoreject: Automated artifact rejection for
    MEG and EEG data”. NeuroImage, 159, 417-429.

    """
    # Import data
    input_path = root + '/4_ICA/' + subject + '-epo.fif'
    epochs = mne.read_epochs(input_path)

    # Autoreject: fit once, then repair. The earlier code discarded the
    # result of fit_transform and ran transform again; transform with
    # return_log=True also yields the reject log for the heatmap below.
    ar = AutoReject(random_state=42,
                    n_jobs=4)
    ar.fit(epochs)
    epochs_clean, reject_log = ar.transform(epochs, return_log=True)

    # Plot difference
    evoked = epochs.average()
    evoked_clean = epochs_clean.average()

    fig, axes = plt.subplots(2, 1, figsize=(6, 6))

    for ax in axes:
        # Booleans are required: the old 'off' strings are truthy in current
        # matplotlib, which left the ticks visible.
        ax.tick_params(axis='x', which='both', bottom=False, top=False)
        ax.tick_params(axis='y', which='both', left=False, right=False)

    # NOTE(review): MNE documents ``ylim`` as a dict per channel type
    # (e.g. dict(eeg=[-30, 30])); confirm a plain list is accepted by the
    # installed MNE version.
    evoked.plot(exclude=[], axes=axes[0], ylim=[-30, 30], show=False)
    axes[0].set_title('Before autoreject')
    # show=False as above: the figure is saved and closed, not displayed.
    evoked_clean.plot(exclude=[], axes=axes[1], ylim=[-30, 30], show=False)
    axes[1].set_title('After autoreject')
    plt.tight_layout()
    plt.savefig(root + '/5_autoreject/' +
                subject + '-autoreject.png')
    plt.close()

    # Plot heatmap
    reject_log.plot()
    plt.savefig(root + '/5_autoreject/' + subject + '-heatmap.png')
    plt.close()

    # Save epoch data
    out_epoch = root + '/5_autoreject/' + subject + '-epo.fif'
    epochs_clean.save(out_epoch)
# %%
# Note that :class:`autoreject.AutoReject` by design supports multiple
# channel types. If no picks are passed, separate solutions are computed for
# each channel type and combined internally. This readily supports cleaning
# unseen epochs that contain the different channel types used during fit.
# Here we only use a subset of channels to save time.

# %%
# Also note that once the parameters are learned, any data that contains the
# channels used during fit can be repaired. This also means that time
# may be saved by fitting :class:`autoreject.AutoReject` on a
# representative subsample of the data.

# Fit AutoReject on ``this_epoch`` and repair it in one call.
# NOTE(review): ``picks`` and ``this_epoch`` are defined earlier in this
# script (not visible in this excerpt).
ar = AutoReject(picks=picks, random_state=42, n_jobs=1, verbose=True)

epochs_ar, reject_log = ar.fit_transform(this_epoch, return_log=True)

# %%
# We can visualize the mean cross-validation error over the two parameters
# AutoReject searches: the consensus fraction and n_interpolate.

import numpy as np  # noqa
import matplotlib.pyplot as plt  # noqa
import matplotlib.patches as patches  # noqa
from autoreject import set_matplotlib_defaults  # noqa

# NOTE(review): matplotlib >= 3.6 renamed the 'seaborn-*' styles to
# 'seaborn-v0_8-*'; confirm 'seaborn-white' exists in the installed version.
set_matplotlib_defaults(plt, style='seaborn-white')
# Average the EEG loss over its last axis (presumably the CV folds).
loss = ar.loss_['eeg'].mean(axis=-1)  # losses are stored by channel type.

# Transposed so columns follow ``ar.consensus`` and rows ``ar.n_interpolate``
# (matching the tick labels below); scaled by 1e6 for readability.
plt.matshow(loss.T * 1e6, cmap=plt.get_cmap('viridis'))
plt.xticks(range(len(ar.consensus)), ['%.1f' % c for c in ar.consensus])
plt.yticks(range(len(ar.n_interpolate)), ar.n_interpolate)
    def run(self):
        """Read, preprocess, epoch and optionally clean ``self.file`` (EDF).

        Behaviour is driven by ``self.info`` (keys 'channel_info',
        'ext_file_folder', 'marker_detection', 'event_detection',
        'event_dict', 'trial_length', 'trials', 'ICA', 'Autoreject') and by
        attributes such as ``t_epoch``, ``sr_new``, ``montage``,
        ``filter_freqs`` and ``bads``. The optional preprocessing steps
        (montage, resampling, re-referencing, filtering, ICA, epoch dropping,
        AutoReject) are best effort: a step whose settings are missing or
        whose call raises is skipped.

        Side effects: sets ``self.raw`` (deep copy of the data as read),
        ``self.ica`` (when ICA ran), ``self.autoreject_log`` (when AutoReject
        ran), ``self.bads`` and ``self.epochs``.

        Returns
        -------
        self
        """
        eog = self.info['channel_info']['EOG']
        misc = self.info['channel_info']['Misc']
        stim = self.info['channel_info']['Stim']

        try:
            ext_files = glob.glob(self.info['ext_file_folder'] + '/' +
                                  self.participant + '/*axis0.dat')
        except Exception:
            pass

        tmin = self.t_epoch[0]
        tmax = self.t_epoch[1]

        raw = read_raw_edf(self.file, eog=eog, misc=misc)
        self.raw = cp.deepcopy(raw)
        raw.load_data()

        # marker detection (one marker continous trial)
        if self.info['marker_detection'] == True:
            starts = find_trialstart(raw,
                                     stim_channel=raw.ch_names[stim[0]],
                                     new_samplin_rate=self.sr_new)
            # Second trial starts 30 * 200 samples after the first
            # (presumably 30 s at 200 Hz — TODO confirm).
            try:
                starts[1] = starts[0] + 30 * 200
            except IndexError:
                # Only one start was found: append the second one.
                starts = np.r_[starts, (starts[0] + 30 * 200)]
            events = np.zeros((len(starts), 3))
            events[:, 0] = starts
            events[:, 2] = list(self.info['event_dict'].values())
            # ``np.int`` was removed in NumPy 1.24; it was an alias of int.
            events = events.astype(int)

        # event detection (one marker regular events)
        if self.info['event_detection'] == True:
            starts = find_trialstart(raw,
                                     stim_channel=raw.ch_names[stim[0]],
                                     new_samplin_rate=self.sr_new)

            events = force_events(ext_files, self.info['event_dict'],
                                  self.sr_new, self.info['trial_length'],
                                  self.info['trials'],
                                  starts[:len(self.info['event_dict'])])

        # Optional processors; None means "step disabled". Previously a
        # disabled step was skipped only via a NameError caught by a bare
        # except.
        ica = ICA(method='fastica') if self.info['ICA'] == True else None
        ar = AutoReject() if self.info['Autoreject'] == True else None

        ## EEG preprocessing options will applied if parameters are set in object

        #read montage
        try:
            montage = make_standard_montage(self.montage)
            raw.set_montage(montage)
        except Exception:
            pass

        #resampling
        try:
            raw.resample(sfreq=self.sr_new)
        except Exception:
            pass

        #rereferencing to the EXG5/EXG6 channels
        try:
            raw, _ = mne.set_eeg_reference(raw, ref_channels=['EXG5', 'EXG6'])
        except Exception:
            pass

        #filter
        try:
            low = self.filter_freqs[0]
            high = self.filter_freqs[1]
            raw.filter(low, high, fir_design='firwin')
        except Exception:
            pass

        # occular correction: remove ICA components matched to EOG
        if ica is not None:
            try:
                ica.fit(raw)
                ica.exclude = []
                eog_indices, eog_scores = ica.find_bads_eog(raw)
                ica.exclude = eog_indices
                ica.apply(raw)
                self.ica = ica
            except Exception:
                pass

        picks = mne.pick_types(raw.info,
                               meg=False,
                               eeg=True,
                               stim=False,
                               eog=False,
                               exclude='bads')

        event_id = self.info['event_dict']
        epochs = mne.Epochs(raw,
                            events,
                            event_id,
                            tmin,
                            tmax,
                            proj=True,
                            baseline=None,
                            preload=True,
                            picks=picks)

        #epoch rejection (indices from a previous run, if any)
        try:
            epochs = epochs.drop(indices=self.bads)
        except Exception:
            pass

        if ar is not None:
            try:
                epochs, self.autoreject_log = ar.fit_transform(
                    epochs, return_log=True)
            except Exception:
                pass

        # drop_log entries are tuples in current MNE and lists in older
        # versions; normalise to list so the comparison works with both.
        bads = np.asarray([
            list(log) == ['USER'] or list(log) == ['AUTOREJECT']
            for log in epochs.drop_log
        ])
        self.bads = np.where(bads == True)
        self.epochs = epochs
        return self
示例#13
0
def epoch_and_clean_trials(subject,
                           diagdir,
                           bidsdir,
                           datadir,
                           derivdir,
                           epochlength=3,
                           eventid=None):
    """
    Chunk the data into epochs of ``epochlength`` seconds locked to the
    eventid specified per trial. Epochs start at the event, except for
    response-locked epochs (eventid == {'press/left': 1, 'press/right': 4}),
    which end at the response.
    Do automatic artifact detection, rejection and fixing for eyeblinks,
    heartbeat, and high- and low-amplitude artifacts. Save the cleaned epochs
    and diagnostic plots (reject log, average of cleaned epochs, PSD).
    :param subject: str, subject identifier. takes the form '001'
    :param diagdir: str, path to a directory where diagnostic plots can be saved
    :param bidsdir: str, path to a directory with BIDS data. Needed to load
    event logs from the experiment
    :param datadir: str, path to a directory with SSS-processed data
    :param derivdir: str, path to a directory where cleaned epochs can be saved
    :param epochlength: int, length of epoch in seconds
    :param eventid: dict | None, the event(s) to lock epochs to. Defaults to
    {'visualfix/fixCross': 10}.
    """
    if eventid is None:
        # Avoid a shared mutable default argument; same default as before.
        eventid = {'visualfix/fixCross': 10}
    # construct name of the first split
    raw_fname = Path(datadir) / f'sub-{subject}/meg' / \
                f'sub-{subject}_task-memento_proc-sss_meg.fif'
    logging.info(f"Reading in SSS-processed data from subject sub-{subject}. "
                 f"Attempting the following path: {raw_fname}")
    raw = mne.io.read_raw_fif(raw_fname)
    events, event_dict = get_events(raw)
    # ensure the data is loaded prior to filtering
    raw.load_data()
    if subject == '017':
        logging.info('Setting additional bad channels for subject 17')
        raw.info['bads'] = ['MEG0313', 'MEG0513', 'MEG0523']
        raw.interpolate_bads()
    # Filter the data to remove high-frequency noise (h_freq=100); minimal
    # filtering based on
    # https://www.sciencedirect.com/science/article/pii/S0165027021000157
    # No high-pass: the raw data already has a 0.1Hz high-pass filter.
    _filter_data(raw, h_freq=100)
    # ICA to detect and repair artifacts
    logging.info('Removing eyeblink and heartbeat artifacts')
    rng = np.random.RandomState(28)
    remove_eyeblinks_and_heartbeat(
        raw=raw,
        subject=subject,
        figdir=diagdir,
        events=events,
        eventid=eventid,
        rng=rng,
    )
    # get the actual epochs: chunk the trial into epochs locked to the
    # event ID. Do not baseline correct the data.
    logging.info(f'Creating epochs of length {epochlength}')
    if eventid == {'press/left': 1, 'press/right': 4}:
        # when centered on the response, move back in time
        tmin, tmax = -epochlength, 0
    else:
        tmin, tmax = 0, epochlength
    epochs = mne.Epochs(raw,
                        events,
                        event_id=eventid,
                        tmin=tmin,
                        tmax=tmax,
                        picks='meg',
                        baseline=None)
    # ADD SUBJECT SPECIFIC TRIAL NUMBER TO THE EPOCH! ONLY THIS WAY WE CAN
    # LATER RECOVER WHICH TRIAL PARAMETERS WE'RE LOOKING AT BASED ON THE LOGS AS
    # THE EPOCH REJECTION WILL REMOVE TRIALS
    logging.info("Retrieving trial metadata.")
    from pymento_meg.proc.epoch import get_trial_features
    metadata = get_trial_features(bids_path=bidsdir,
                                  subject=subject,
                                  column='trial_no')
    # transform to integers
    metadata = metadata.astype(int)
    # this does not work if we start at fixation cross for subject 5, because 1
    # fixation cross trigger is missing from the data, and it becomes impossible
    # to associate the trial metadata to the correct trials in the data
    epochs.metadata = metadata
    epochs.load_data()
    # (Downsampling the epochs to 200 Hz is currently disabled.)
    # use autoreject to repair bad epochs, trying 1..15 interpolated channels
    ar = AutoReject(
        random_state=rng,
        n_interpolate=list(range(1, 16)))
    epochs_clean, reject_log = ar.fit_transform(epochs, return_log=True)
    # save the cleaned, epoched data to disk.
    outpath = _construct_path([
        Path(derivdir),
        f"sub-{subject}",
        "meg",
        f"sub-{subject}_task-memento_cleaned_epo.fif",
    ])
    logging.info(f"Saving cleaned, epoched data to {outpath}")
    epochs_clean.save(outpath, overwrite=True)
    # visualize the bad sensors for each trial; reuse the log returned by
    # fit_transform instead of recomputing it with ar.get_reject_log.
    fig = reject_log.plot()
    fname = _construct_path([
        Path(diagdir),
        f"sub-{subject}",
        "meg",
        f"epoch-rejectlog_sub-{subject}.png",
    ])
    fig.savefig(fname)
    # plot the average of all cleaned epochs
    fig = epochs_clean.average().plot()
    fname = _construct_path([
        Path(diagdir),
        f"sub-{subject}",
        "meg",
        f"clean-epoch_average_sub-{subject}.png",
    ])
    fig.savefig(fname)
    # plot psd of cleaned epochs
    psd = epochs_clean.plot_psd()
    fname = _construct_path([
        Path(diagdir),
        f"sub-{subject}",
        "meg",
        f"psd_cleaned-epochs-{subject}.png",
    ])
    psd.savefig(fname)
示例#14
0
def run_epochs(subject,
               epoch_on_first_element,
               baseline=True,
               tmin=None,
               tmax=None,
               whattoreturn=None):
    """Epoch a subject's ICA-cleaned runs and optionally apply AutoReject.

    Each run's ``<run>_ica_raw`` file is loaded and resampled to
    ``config.resample_sfreq`` (events are resampled alongside). Each run must
    contain 46 * 16 events. The runs are then concatenated and epoched, with
    metadata taken from the subject's run-info csv files.

    Parameters
    ----------
    subject : str
    epoch_on_first_element : bool
        True: one epoch per sequence ('sequence_starts'; default window
        -0.2 to 4.25 s). False: one epoch per event (default -0.05 to 0.6 s).
    baseline : bool | None
        None or False: no baseline correction. Otherwise the baseline is
        ``(tmin, 0)``.
    tmin, tmax : float | None
        Epoch window in seconds; None selects the defaults listed above.
    whattoreturn : str | None
        None: write the epochs (and AutoReject outputs) to disk.
        '': return the loaded epochs, skipping AutoReject.
        'ARglobal': return the epochs cleaned with global AutoReject
        thresholds.
        Any other value: nothing is written; the local-AutoReject epochs are
        returned when that step is enabled.

    Returns
    -------
    mne.Epochs | None
    """
    # Set this to True to also run local AutoReject when config.autoreject
    # is True.
    ARlocal = False
    # Expected events per run (presumably 46 sequences x 16 items).
    n_expected_events = 46 * 16

    print("Processing subject: %s" % subject)
    meg_subject_dir = op.join(config.meg_dir, subject)
    run_info_subject_dir = op.join(config.run_info_dir, subject)
    raw_list = list()

    if config.noEEG:
        output_dir = op.join(meg_subject_dir, 'noEEG')
        utils.create_folder(output_dir)
    else:
        output_dir = meg_subject_dir

    print("  Loading raw data")
    runs = config.runs_dict[subject]
    for run in runs:
        extension = run + '_ica_raw'
        print(extension)
        # config.base_fname is a format template filled from the local
        # variables (e.g. ``extension``), so keep these local names.
        raw_fname_in = op.join(meg_subject_dir,
                               config.base_fname.format(**locals()))
        raw = mne.io.read_raw_fif(raw_fname_in, preload=True)

        # RESAMPLING EACH RUN BEFORE CONCAT & EPOCHING
        # Find events first and resample them together with the data, to
        # avoid potential loss of events when downsampling:
        # https://www.nmr.mgh.harvard.edu/mne/dev/auto_examples/preprocessing/plot_resample.html
        events = mne.find_events(raw,
                                 stim_channel=config.stim_channel,
                                 consecutive=True,
                                 min_duration=config.min_event_duration,
                                 shortest_event=config.shortest_event)

        print('  Downsampling raw data')
        raw, events = raw.resample(config.resample_sfreq,
                                   npad='auto',
                                   events=events)

        times_between_events_and_end = (raw.last_samp -
                                        events[:, 0]) / raw.info['sfreq']
        if np.any(times_between_events_and_end < 0.6):
            print("=== some events are too close to the end ====")

        if len(events) != n_expected_events:
            raise Exception('We expected %i events but we got %i' %
                            (n_expected_events, len(events)))

        raw_list.append(raw)

    if subject == 'sub08-cc_150418':
        # For this participant, concatenating the raws failed for run08:
        # raw08._cals did not match the other runs because the calibration of
        # channel EOG061 differed from run09's. Copy run09's calibrations.
        raw_list[7]._cals = raw_list[8]._cals
        print(
            'Warning: corrected an issue with subject08 run08 ica_raw data file...'
        )

    print('Concatenating runs')
    raw = mne.concatenate_raws(raw_list)
    if "eeg" in config.ch_types:
        raw.set_eeg_reference(projection=True)
    del raw_list

    # EEG channels are picked only when config.noEEG is not set.
    picks = mne.pick_types(raw.info,
                           meg=True,
                           eeg=not config.noEEG,
                           stim=True,
                           eog=True,
                           exclude=())

    # Construct metadata from csv events file
    metadata = convert_csv_info_to_metadata(run_info_subject_dir)
    metadata_pandas = pd.DataFrame.from_dict(metadata, orient='index')
    metadata_pandas = pd.DataFrame.transpose(metadata_pandas)

    # ====== Epoching the data
    print('  Epoching')

    # Events
    events = mne.find_events(raw,
                             stim_channel=config.stim_channel,
                             consecutive=True,
                             min_duration=config.min_event_duration,
                             shortest_event=config.shortest_event)

    if epoch_on_first_element:
        # fosca 06012020
        if tmin is None:
            tmin = -0.200
        if tmax is None:
            tmax = 0.25 * 17
        # Relabel events cyclically 1..16 (position within the sequence) and
        # epoch only on the first one.
        events[:, 2] = np.arange(len(events)) % 16 + 1
        event_id = {'sequence_starts': 1}
    else:
        if tmin is None:
            tmin = -0.050
        if tmax is None:
            tmax = 0.600
        event_id = None

    # Same rule for both epoching modes. The first-element branch used to
    # overwrite ``baseline`` with (tmin, 0) before this check, so
    # baseline=None/False was ignored there.
    if (baseline is None) or (baseline is False):
        baseline = None
    else:
        baseline = (tmin, 0)

    epochs = mne.Epochs(raw,
                        events,
                        event_id,
                        tmin,
                        tmax,
                        proj=True,
                        picks=picks,
                        baseline=baseline,
                        preload=False,
                        decim=config.decim,
                        reject=None)

    # Add metadata to epochs
    if epoch_on_first_element:
        epochs.metadata = metadata_pandas[metadata_pandas['StimPosition'] ==
                                          1.0]
    else:
        epochs.metadata = metadata_pandas

    # Save epochs (before AutoReject)
    if whattoreturn is None:
        print('  Writing epochs to disk')
        if epoch_on_first_element:
            extension = subject + '_1st_element_epo'
        else:
            extension = subject + '_epo'
        epochs_fname = op.join(output_dir,
                               config.base_fname.format(**locals()))
        print("Output: ", epochs_fname)
        epochs.save(epochs_fname, overwrite=True)
    elif whattoreturn == '':
        epochs.load_data()
        return epochs
    else:
        print("=== we continue on the autoreject part ===")

    if config.autoreject:
        epochs.load_data()
        # Running AutoReject "global" (https://autoreject.github.io) -> just get the thresholds
        from autoreject import get_rejection_threshold
        reject = get_rejection_threshold(epochs, ch_types=config.ch_types)
        epochsARglob = epochs.copy().drop_bad(reject=reject)
        print('  Writing "AR global" cleaned epochs to disk')
        if epoch_on_first_element:
            extension = subject + '_1st_element_ARglob_epo'
        else:
            extension = subject + '_ARglob_epo'
        epochs_fname = op.join(output_dir,
                               config.base_fname.format(**locals()))
        if whattoreturn is None:
            print("Output: ", epochs_fname)
            epochsARglob.save(epochs_fname, overwrite=True)
            # Save autoreject thresholds; the context manager closes the file.
            with open(epochs_fname[:-4] + '_ARglob_thresholds.obj',
                      'wb') as fid:
                pickle.dump(reject, fid)
        elif whattoreturn == 'ARglobal':
            return epochsARglob
        else:
            print("==== continue to ARlocal ====")

        # Running AutoReject "local" (https://autoreject.github.io)
        if ARlocal:
            ar = AutoReject()
            epochsAR, reject_log = ar.fit_transform(epochs, return_log=True)
            print('  Writing "AR local" cleaned epochs to disk')
            if epoch_on_first_element:
                extension = subject + '_1st_element_clean_epo'
            else:
                extension = subject + '_clean_epo'
            epochs_fname = op.join(output_dir,
                                   config.base_fname.format(**locals()))
            if whattoreturn is None:
                print("Output: ", epochs_fname)
                epochsAR.save(epochs_fname, overwrite=True)
                # Save autoreject reject_log
                with open(epochs_fname[:-4] + '_reject_local_log.obj',
                          'wb') as fid:
                    pickle.dump(reject_log, fid)
            else:
                return epochsAR
示例#15
0
 # Loop-body fragment (enclosing function not visible): ``kk``, ``a_lst``,
 # ``raw``, ``eps_fname``, ``defaults``, ``config``, ``hp`` and ``lp`` come
 # from the surrounding scope — TODO confirm their meaning there.
 a_lst[kk] = list()
 # epoch raw into 5 sec trials
 events = mne.make_fixed_length_events(raw, duration=5.0)
 epochs = mne.Epochs(
     raw,
     events=events,
     tmin=0,
     tmax=5.0,
     baseline=None,
     reject=None,
     preload=True,
 )
 # AutoReject runs only when the cleaned-epochs file does not exist yet;
 # otherwise the cached file is reused.
 if not op.isfile(eps_fname):
     # k-fold CV thresholded artifact rejection
     ar = AutoReject()
     epochs = ar.fit_transform(epochs)
     print("      \nSaving ...%s" % op.relpath(eps_fname, defaults.megdata))
     epochs.save(eps_fname, overwrite=True)
 # Always (re)load the cleaned epochs from disk.
 epochs = read_epochs(eps_fname)
 # Report number of events and how many epochs were dropped.
 print(
     "%d, %d (Epochs, drops)"
     % (len(events), len(events) - len(epochs.selection))
 )
 # epochs.plot_psd()
 # Indices of events not in ``epochs.selection``, i.e. the dropped epochs.
 roi_nms = np.setdiff1d(np.arange(len(events)), epochs.selection)
 # raw = raw.copy().filter(lf, hf, fir_window='blackman',
 #                       method='iir', n_jobs=config.N_JOBS)
 # 4th-order Butterworth IIR band filter (second-order sections) applied to
 # a copy of the cleaned epochs.
 iir_params = dict(order=4, ftype="butter", output="sos")
 epochs_ = epochs.copy().filter(
     hp, lp, method="iir", iir_params=iir_params, n_jobs=config.N_JOBS
 )
示例#16
0
def run_epochs(subject,
               epoch_on_first_element,
               baseline=True,
               l_freq=None,
               h_freq=None,
               suffix='_eeg_1Hz'):
    """Epoch the ICA-cleaned runs of one subject and write them to disk.

    Each run in ``config.runs_dict[subject]`` is loaded from its ``_ica_raw``
    file, resampled to ``config.resample_sfreq`` together with its events (so
    no event is lost by downsampling), checked to contain ``46 * 16`` events
    and high-pass filtered at 1 Hz. The runs are then concatenated, epoched,
    given metadata from the subject's run-info csv files and saved. If
    ``config.autoreject`` is set, two cleaned versions are also saved: one
    using AutoReject "global" thresholds and one using AutoReject "local".

    Parameters
    ----------
    subject : str
        Subject identifier; selects the input directories and is part of
        every output file name.
    epoch_on_first_element : bool
        If True, every 16th event is kept as a ``'sequence_starts'`` event
        and epoched from -0.2 s to 4.25 s; otherwise every event is epoched
        from -0.05 s to 0.6 s.
    baseline : object
        ``None`` disables baseline correction; any other value applies a
        ``(tmin, 0)`` baseline.
    l_freq, h_freq : float | None
        NOTE(review): currently unused -- the raw data are always filtered
        with ``l_freq=1, h_freq=None``.
    suffix : str
        Appended to every output file name.

    Raises
    ------
    ValueError
        If a run does not contain exactly ``46 * 16`` events after resampling.

    Notes
    -----
    Overwrites ``config.tmin``, ``config.tmax`` and ``config.baseline`` as a
    side effect. File names are built with
    ``config.base_fname.format(**locals())``, so local names such as
    ``subject``, ``run`` and ``extension`` are part of the naming contract
    and must not be renamed.
    """
    print("Processing subject: %s" % subject)
    meg_subject_dir = op.join(config.meg_dir, subject)
    run_info_subject_dir = op.join(config.run_info_dir, subject)
    raw_list = list()
    events_list = list()

    print("  Loading raw data")
    runs = config.runs_dict[subject]
    for run in runs:
        extension = run + '_ica_raw'
        raw_fname_in = op.join(meg_subject_dir,
                               config.base_fname.format(**locals()))
        raw = mne.io.read_raw_fif(raw_fname_in, preload=True)

        # ---------------------------------------------------------------------------------------------------------------- #
        # RESAMPLING EACH RUN BEFORE CONCAT & EPOCHING
        # Resampling the raw data while keeping events from original raw data, to avoid potential loss of
        # events when downsampling: https://www.nmr.mgh.harvard.edu/mne/dev/auto_examples/preprocessing/plot_resample.html
        # Find events
        events = mne.find_events(raw,
                                 stim_channel=config.stim_channel,
                                 consecutive=True,
                                 min_duration=config.min_event_duration,
                                 shortest_event=config.shortest_event)

        print('  Downsampling raw data')
        raw, events = raw.resample(config.resample_sfreq,
                                   npad='auto',
                                   events=events)
        # Each run is expected to hold 46 sequences of 16 items.
        if len(events) != 46 * 16:
            raise ValueError('We expected %i events but we got %i' %
                             (46 * 16, len(events)))
        raw.filter(l_freq=1, h_freq=None)
        raw_list.append(raw)
        # ---------------------------------------------------------------------------------------------------------------- #

    if subject == 'sub08-cc_150418':
        # For this participant, we had some problems when concatenating the raws for run08. The error message said that raw08._cals didn't match the other ones.
        # We saw that it is the 'calibration' for the channel EOG061 that was different with respect to run09._cals.
        raw_list[7]._cals = raw_list[8]._cals
        print(
            'Warning: corrected an issue with subject08 run08 ica_raw data file...'
        )

    print('Concatenating runs')
    raw = mne.concatenate_raws(raw_list)
    if "eeg" in config.ch_types:
        raw.set_eeg_reference(projection=True)
    del raw_list

    # Translate config.ch_types into the `meg` argument of pick_types.
    meg = False
    if 'meg' in config.ch_types:
        meg = True
    elif 'grad' in config.ch_types:
        meg = 'grad'
    elif 'mag' in config.ch_types:
        meg = 'mag'
    eeg = 'eeg' in config.ch_types
    picks = mne.pick_types(raw.info,
                           meg=meg,
                           eeg=eeg,
                           stim=True,
                           eog=True,
                           exclude=())

    # Construct metadata from csv events file
    metadata = epoching_funcs.convert_csv_info_to_metadata(
        run_info_subject_dir)
    metadata_pandas = pd.DataFrame.from_dict(metadata, orient='index')
    metadata_pandas = pd.DataFrame.transpose(metadata_pandas)

    # ====== Epoching the data
    print('  Epoching')

    # Events
    events = mne.find_events(raw,
                             stim_channel=config.stim_channel,
                             consecutive=True,
                             min_duration=config.min_event_duration,
                             shortest_event=config.shortest_event)

    if epoch_on_first_element:
        # fosca 06012020
        config.tmin = -0.200
        config.tmax = 0.25 * 17
        config.baseline = (config.tmin, 0)
        if baseline is None:
            config.baseline = None
        # Relabel events by position within a 16-item sequence so that only
        # the first item of each sequence matches {'sequence_starts': 1}.
        for k in range(len(events)):
            events[k, 2] = k % 16 + 1
        epochs = mne.Epochs(raw,
                            events, {'sequence_starts': 1},
                            config.tmin,
                            config.tmax,
                            proj=True,
                            picks=picks,
                            baseline=config.baseline,
                            preload=False,
                            decim=config.decim,
                            reject=None)
        epochs.metadata = metadata_pandas[metadata_pandas['StimPosition'] ==
                                          1.0]
    else:
        config.tmin = -0.050
        config.tmax = 0.600
        config.baseline = (config.tmin, 0)
        if baseline is None:
            config.baseline = None
        epochs = mne.Epochs(raw,
                            events,
                            None,
                            config.tmin,
                            config.tmax,
                            proj=True,
                            picks=picks,
                            baseline=config.baseline,
                            preload=False,
                            decim=config.decim,
                            reject=None)

        # Add metadata to epochs
        epochs.metadata = metadata_pandas

    # Save epochs (before AutoReject)
    print('  Writing epochs to disk')
    if epoch_on_first_element:
        extension = subject + '_1st_element_epo' + suffix
    else:
        extension = subject + '_epo' + suffix
    epochs_fname = op.join(meg_subject_dir,
                           config.base_fname.format(**locals()))

    print("Output: ", epochs_fname)
    epochs.save(epochs_fname, overwrite=True)

    if config.autoreject:
        epochs.load_data()

        # Running AutoReject "global" (https://autoreject.github.io) -> just get the thresholds
        from autoreject import get_rejection_threshold
        reject = get_rejection_threshold(epochs,
                                         ch_types=['mag', 'grad', 'eeg'])
        epochsARglob = epochs.copy().drop_bad(reject=reject)
        print('  Writing "AR global" cleaned epochs to disk')
        if epoch_on_first_element:
            extension = subject + '_1st_element_ARglob_epo' + suffix
        else:
            extension = subject + '_ARglob_epo' + suffix
        epochs_fname = op.join(meg_subject_dir,
                               config.base_fname.format(**locals()))
        print("Output: ", epochs_fname)
        epochsARglob.save(epochs_fname, overwrite=True)
        # Save autoreject thresholds; the context manager closes the file.
        with open(epochs_fname[:-4] + '_ARglob_thresholds.obj',
                  'wb') as thresholds_file:
            pickle.dump(reject, thresholds_file)

        # Running AutoReject "local" (https://autoreject.github.io)
        ar = AutoReject()
        epochsAR, reject_log = ar.fit_transform(epochs, return_log=True)
        print('  Writing "AR local" cleaned epochs to disk')
        if epoch_on_first_element:
            extension = subject + '_1st_element_clean_epo' + suffix
        else:
            extension = subject + '_clean_epo' + suffix
        epochs_fname = op.join(meg_subject_dir,
                               config.base_fname.format(**locals()))
        print("Output: ", epochs_fname)
        epochsAR.save(epochs_fname, overwrite=True)
        # Save autoreject reject_log; the context manager closes the file.
        with open(epochs_fname[:-4] + '_reject_local_log.obj',
                  'wb') as log_file:
            pickle.dump(reject_log, log_file)
示例#17
0
def main():
    """Preprocess every subject's EEG recordings and save cleaned epochs.

    For each non-hidden subject directory in ``DATA_PATH``:

    * load the EGI files for ``TASK`` and set montage/channel types;
    * filter and concatenate the files;
    * build correct-response epochs;
    * mark bad channels with a kurtosis z-score method and re-reference to
      the average;
    * optionally run ICA (``RUN_ICA``) and AutoReject (``RUN_AUTOREJECT``);
    * equalise event counts and save the epochs to ``PROC_PATH``.

    Subjects with too few files or trials are skipped. A CSV of
    ``(subject, n_kept_epochs)`` rows is written to ``event_log.csv``.
    """
    # Initialize to keep track of how many events each subject has
    n_kept_events = []

    # Set the MNE print out level
    set_log_level('ERROR')

    # Get list of available subjects, skipping hidden entries (e.g. '.DS_Store')
    subnums = [
        name for name in os.listdir(DATA_PATH) if not name.startswith('.')
    ]

    # Loop across all subjects
    for subnum in subnums:

        # Add status updates
        print('\nRunning Subject # ', subnum)

        ## DATA LOADING

        print('\tData Wrangling')

        # Get the files that match the desired task, load and concatenate data
        subj_path = os.path.join(DATA_PATH, subnum, 'EEG', 'raw', 'raw_format')
        subj_files = list(fms.loc[(fms["SUBNUM"] == subnum)
                                  & (fms["TASK"] == TASK)]["FILE"].values)

        # Stop if subject doesn't have enough data, defined in terms of number of blocks ('runs')
        if len(subj_files) < N_BLOCKS:
            print('Subject does not have enough data')
            continue

        raws = [
            mne.io.read_raw_egi(os.path.join(subj_path, raw_file),
                                preload=True) for raw_file in subj_files
        ]

        # Set montage, drop misc channels, and filter
        montage = mne.channels.read_montage('GSN-HydroCel-129',
                                            ch_names=raws[0].ch_names)
        for raw in raws:
            raw.set_montage(montage)
            raw.set_channel_types(EOG_MAPPINGS)
            # Drop channels from index 128 up to (not including) the last one
            raw.drop_channels(raw.ch_names[128:-1])
            raw.filter(l_freq=L_FREQ, h_freq=H_FREQ, fir_design='firwin')

        # Concatenate raw objects into single new raw object
        raw = mne.concatenate_raws(raws)

        ## EVENT MANAGEMENT

        print('\tEvent Management')

        try:
            events = mne.find_events(raw, verbose=False)
        except Exception as err:
            # Best-effort: skip this subject but report why. Unlike a bare
            # `except:`, this no longer swallows KeyboardInterrupt/SystemExit.
            print('Subject has weird shortest_event error...skipping')
            print('\t%s' % err)
            continue

        # Create correct-response events
        new_events = []

        for trgt, rspn, corr in zip(TRGT_EVCS, RSPN_EVCS, CORR_EVCS):

            tmp, _ = define_target_events(events,
                                          rspn,
                                          trgt,
                                          raw.info['sfreq'],
                                          tmin=-2.,
                                          tmax=0.,
                                          new_id=corr,
                                          fill_na=None)
            new_events.append(tmp)

        # Collapse new events into an array
        new_events = np.concatenate(
            (np.array(new_events[0]), np.array(new_events[1])), axis=0)
        new_event_ids = dict(left=21, right=22)

        # Check if subject has enough epochs to continue
        if not np.any(new_events):
            print('Subject has no correct trials...')
            continue
        elif new_events.shape[0] < N_EPOCHS:
            print('Subject has too few trials for analysis: %d' %
                  new_events.shape[0])
            continue

        # Create epochs object
        epochs = mne.Epochs(raw,
                            new_events,
                            new_event_ids,
                            tmin=TMIN,
                            tmax=TMAX,
                            picks=None,
                            baseline=BASELINE,
                            reject=None,
                            preload=True)

        ## BAD CHANNELS & RE-REFERENCING

        print('\tBad Channels & Re-Referencing')

        # Mark bad channels via scrappy kurtosis z-score method
        bad_channels = faster_bad_channels(epochs, thres=5)
        raw.info['bads'] = bad_channels
        epochs.info['bads'] = bad_channels

        # Re-reference to average reference
        raw.set_eeg_reference('average', projection=False)
        epochs.set_eeg_reference('average', projection=False)

        ## ICA

        if RUN_ICA:

            print('\tRunning ICA')

            # High-pass filter for the purpose of ICA de-noising
            raw_hpf = raw.copy()
            raw_hpf.filter(l_freq=1.,
                           h_freq=None,
                           fir_design='firwin',
                           verbose=False)
            epochs_hpf = mne.Epochs(raw_hpf,
                                    new_events,
                                    new_event_ids,
                                    tmin=TMIN,
                                    tmax=TMAX,
                                    picks=None,
                                    baseline=BASELINE,
                                    reject=None,
                                    preload=True)

            # Get the EEG picks (eeg electrode indices), ignoring bad channels as marked by Faster
            eeg_ica_picks = mne.pick_types(epochs_hpf.info,
                                           meg=False,
                                           eeg=True)
            ica = ICA(random_state=1)
            ica.fit(epochs_hpf, picks=eeg_ica_picks)

            # Define bad components by correlating with channels near eyes
            eog_chs = [ch for ch in EOG_CHS if ch not in raw.info['bads']]

            bad_ica_comps = []
            for ch in eog_chs:
                inds, _ = ica.find_bads_eog(raw_hpf,
                                            ch_name=ch,
                                            threshold=4,
                                            l_freq=1,
                                            h_freq=8)
                bad_ica_comps.extend(inds)

            ica.exclude = list(set(bad_ica_comps))

            # Plot and save bad components
            if len(bad_ica_comps) > 0:
                fig = ica.plot_components(picks=np.array(bad_ica_comps),
                                          show=False)
                fig_name = os.path.join(FIG_PATH,
                                        subnum + '_ica_scalp_maps.png')
                fig.savefig(fig_name, dpi=150)

            # Save out ICA decomposition
            ica_filename = os.path.join(ICA_PATH, subnum + '-ica.fif')
            ica.save(ica_filename)

            # Apply ICA to both epoched and raw data
            epochs = ica.apply(epochs)
            raw = ica.apply(raw)

            raw_filename = os.path.join(PROC_PATH, subnum + '-raw.fif')
            raw.save(raw_filename, overwrite=True)

        ## Autoreject

        if RUN_AUTOREJECT:

            print('\tRunning AutoReject')

            # Use AutoReject to reject bad epochs and interpolate bad channels
            ar = AutoReject(n_jobs=4, random_state=1, verbose=False, cv=3)
            epochs, rej_log = ar.fit_transform(epochs, return_log=True)
            epochs.info['bads'] = [
            ]  # no need for bad channels after AutoReject

            # Save out autoreject log as a pickled object; `with` closes the file
            with open(os.path.join(ICA_PATH, subnum + "-ar.p"),
                      "wb") as ar_file:
                pickle.dump(rej_log, ar_file)

        ## SAVE OUT DATA

        # Enforce consistency in the number of events per condition
        epochs.equalize_event_counts(new_event_ids, method='mintime')

        # Don't save the subject's data if they don't have enough epochs
        if len(epochs) < N_EPOCHS:
            print(
                'Subject has too few trials for analysis after preprocessing')
            continue

        # Save out pre-processed data
        epochs_filename = os.path.join(PROC_PATH,
                                       subnum + '_preprocessed-epo.fif')
        epochs.save(epochs_filename)

        print('\tData Saved - {:2d} kept events'.format(len(epochs)))

        n_kept_events.append((subnum, len(epochs)))

        print('\nGreat Success.\n')

    # Save out the log file of number of good events per subject.
    # newline='' is required by the csv module to avoid blank lines on Windows.
    with open('event_log.csv', 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerows(n_kept_events)
###############################################################################
# Note that :class:`autoreject.AutoReject` by design supports multiple
# channel types. If no picks are passed, separate solutions are computed for
# each channel type and combined internally. This readily supports cleaning
# unseen epochs from the different channel types used during fit.
# Here we only use a subset of channels to save time.

###############################################################################
# Also note that once the parameters are learned, any data that contains the
# channels used during fit can be repaired. This also means that time may be
# saved by fitting :class:`autoreject.AutoReject` on a representative
# subsample of the data.

# Fit local AutoReject only on the channels in ``picks``. The fixed
# random_state makes the run reproducible; verbose='tqdm' shows a progress bar.
ar = AutoReject(picks=picks, random_state=42, n_jobs=1, verbose='tqdm')

# Learn the thresholds on ``this_epoch`` and return both the repaired epochs
# and the reject log (because return_log=True).
epochs_ar, reject_log = ar.fit_transform(this_epoch, return_log=True)

###############################################################################
# We can visualize the cross validation curve over two variables

import numpy as np  # noqa
import matplotlib.pyplot as plt  # noqa
import matplotlib.patches as patches  # noqa
from autoreject import set_matplotlib_defaults  # noqa

# NOTE(review): matplotlib >= 3.6 renamed this style to 'seaborn-v0_8-white';
# confirm the installed version still accepts 'seaborn-white'.
set_matplotlib_defaults(plt, style='seaborn-white')
# Average over the last axis (presumably the CV folds -- confirm).
loss = ar.loss_['eeg'].mean(axis=-1)  # losses are stored by channel type.

# Heat map of the mean loss. After .T, columns follow the first axis of
# ``loss`` (labelled with the ``consensus`` candidates) and rows follow the
# second axis (labelled with the ``n_interpolate`` candidates). The 1e6
# factor presumably rescales volts to microvolts for display -- confirm.
plt.matshow(loss.T * 1e6, cmap=plt.get_cmap('viridis'))
plt.xticks(range(len(ar.consensus)), ar.consensus)
plt.yticks(range(len(ar.n_interpolate)), ar.n_interpolate)
示例#19
0
    sfreq=raw.info['sfreq'],
    first_samp=raw.first_samp,
    event_id=event_id,
    on_missing='ignore',
)
fig.subplots_adjust(right=0.7)  # make room for legend

# Report how many times each requested event id occurs in ``events``
# (the id is read from the last column).
for (e, i) in event_id.items():
    a = (events[:, -1] == i).sum()
    print(f"event {e} is present {a} times")

# %% use autoreject local to clean the data from remaining artifacts
# When AUTOREJECT is False the uncleaned epochs are used unchanged.
if AUTOREJECT:
    ar = AutoReject()
    epochs.load_data()  # presumably AutoReject needs preloaded data
    epochs_clean = ar.fit_transform(epochs)
else:
    epochs_clean = epochs

# %%
# AutoReject "global": compute dataset-wide rejection thresholds on the
# *uncleaned* epochs. They are only printed for inspection here; they are
# not applied to ``epochs_clean``.
reject = get_rejection_threshold(epochs)
print(reject)

# %%
# Average the cleaned epochs of the 'audiovis/1200Hz' condition and plot.
evoked = epochs_clean['audiovis/1200Hz'].average()
evoked.plot()

# %%
# Write the results to disk.
# NOTE(review): only the epochs are saved with overwrite=True; re-running
# may fail if ``evoked_file`` already exists -- confirm this is intended.
epochs_clean.save(epochs_file, overwrite=True)
evoked.save(evoked_file)
示例#20
0
def run_preproc(datadir='/data'):
    """Preprocess every BIDS EEG recording under ``datadir`` and save epochs.

    Configuration comes from ``defaults``, updated by
    ``<datadir>/eegprep.conf`` if that file exists. Then, for every
    ``sub-*`` directory in ``<datadir>/BIDS``, each ``.set``/``.bdf``
    recording goes through these steps:

    * read the recording and find its events;
    * apply channel types and bad channels from the matching
      ``channels.tsv``;
    * set a guessed montage and re-reference to the REF channels;
    * epoch from -0.1 to 0.8 s;
    * clean with AutoReject;
    * write diagnostic figures and the reject log to
      ``derivatives/eegprep/reports/sub-<sub>``.

    Cleaned epochs are concatenated per (session, task), downsampled when
    ``config['downsample']`` is below the sampling rate, and written to
    ``derivatives/eegprep/sub-<sub>``. The output format depends on
    ``config['out_file_format']``: ``'fif'``, ``'mat'`` or ``'npy'`` (the
    last uses ``numpy.savez``). Any other value writes nothing.

    Parameters
    ----------
    datadir : str
        Root directory containing ``BIDS/`` and optionally ``eegprep.conf``.
    """
    print('data directory: {}'.format(datadir))
    conf_file_path = join(datadir, 'eegprep.conf')
    config = Configuration()
    config.setDefaults(defaults)
    if os.path.isfile(conf_file_path):
        with open(conf_file_path) as fh:
            conf_string = fh.read()
        config.updateFromString(conf_string)
    print('configuration:')
    print(config)

    bidsdir = join(datadir, 'BIDS')
    eegprepdir = join(bidsdir, 'derivatives', 'eegprep')

    subjectdirs = sorted(glob.glob(join(bidsdir, 'sub-*')))
    for subjectdir in subjectdirs:
        assert os.path.isdir(subjectdir)

        sub = basename(subjectdir)[4:]

        # prepare derivatives directory
        derivdir = join(eegprepdir, 'sub-' + sub)
        os.makedirs(derivdir, exist_ok=True)
        reportsdir = join(eegprepdir, 'reports', 'sub-' + sub)
        os.makedirs(reportsdir, exist_ok=True)

        subject_epochs = {}
        rawtypes = {'.set': mne.io.read_raw_eeglab, '.bdf': mne.io.read_raw_edf}
        for fname in sorted(glob.glob(join(subjectdir, 'eeg', '*'))):
            _, ext = splitext(fname)
            if ext not in rawtypes.keys():
                continue
            sub, ses, task, run = filename2tuple(basename(fname))

            print('\nProcessing raw file: ' + basename(fname))

            # read data
            raw = rawtypes[ext](fname, preload=True, verbose=False)
            events = mne.find_events(raw)  #raw, consecutive=False, min_duration=0.005)

            # Set channel types and select reference channels
            channelFile = fname.replace('eeg' + ext, 'channels.tsv')
            channels = pandas.read_csv(channelFile, index_col='name', sep='\t')
            bids2mne = {
                'MISC': 'misc',
                'EEG': 'eeg',
                'VEOG': 'eog',
                'TRIG': 'stim',
                'REF': 'eeg',
            }
            channels['mne'] = channels.type.replace(bids2mne)

            # the below fails if the specified channels are not in the data
            raw.set_channel_types(channels.mne.to_dict())

            # set bad channels
            raw.info['bads'] = channels[channels.status=='bad'].index.tolist()

            # pick channels to use for epoching
            epoching_picks = mne.pick_types(raw.info, eeg=True, eog=False, stim=False, exclude='bads')

            # Filtering
            #raw.filter(l_freq=0.05, h_freq=40, fir_design='firwin')

            montage = mne.channels.read_montage(guess_montage(raw.ch_names))
            print(montage)
            raw.set_montage(montage)

            # Plot about 20 evenly spaced channels around the middle of the
            # recording. The step is at least 1: with fewer than 20 channels
            # floor(nchans / 20) would be 0 and numpy.arange would fail.
            nchans = len(raw.ch_names)
            pick_channels = numpy.arange(0, nchans, max(1, nchans // 20))
            start = numpy.round(raw.times.max()/2)
            fig = raw.plot(start=start, order=pick_channels)
            fname_plot = 'sub-{}_ses-{}_task-{}_run-{}_raw.png'.format(sub, ses, task, run)
            fig.savefig(join(reportsdir, fname_plot))

            # Set reference
            refChannels = channels[channels.type=='REF'].index.tolist()
            raw.set_eeg_reference(ref_channels=refChannels)

            ##  epoching
            epochs_params = dict(
                events=events,
                tmin=-0.1,
                tmax=0.8,
                reject=None,  # dict(eeg=250e-6, eog=150e-6)
                picks=epoching_picks,
                detrend=0,
            )
            file_epochs = mne.Epochs(raw, preload=True, **epochs_params)
            file_epochs.drop_channels(refChannels)

            # autoreject (under development)
            # Take the reject log from the fit itself so that it describes
            # the original epochs. get_reject_log(clean_epochs) would
            # describe the already-repaired data instead.
            ar = AutoReject(n_jobs=4)
            clean_epochs, rejectlog = ar.fit_transform(file_epochs, return_log=True)

            fname_log = 'sub-{}_ses-{}_task-{}_run-{}_reject-log.npz'.format(sub, ses, task, run)
            save_rejectlog(join(reportsdir, fname_log), rejectlog)
            fig = plot_rejectlog(rejectlog)
            fname_plot = 'sub-{}_ses-{}_task-{}_run-{}_bad-epochs.png'.format(sub, ses, task, run)
            fig.savefig(join(reportsdir, fname_plot))

            # store for now
            subject_epochs[(ses, task, run)] = clean_epochs

            # Create evoked plots for up to 6 randomly chosen conditions.
            # random.sample needs a sequence (dict_keys raises TypeError on
            # Python >= 3.11), so materialise the keys as a list.
            conds = list(clean_epochs.event_id.keys())
            selected_conds = random.sample(conds, min(len(conds), 6))
            picks = mne.pick_types(clean_epochs.info, eeg=True)
            for cond in selected_conds:
                evoked = clean_epochs[cond].average()
                fname_plot = 'sub-{}_ses-{}_task-{}_run-{}_evoked-{}.png'.format(sub, ses, task, run, cond)
                fig = evoked.plot_joint(picks=picks)
                fig.savefig(join(reportsdir, fname_plot))

        # Group the per-run epochs by (session, task) and save each group.
        sessSeg = 0
        sessions = sorted(list(set([k[sessSeg] for k in subject_epochs.keys()])))
        for session in sessions:
            taskSeg = 1
            tasks = list(set([k[taskSeg] for k in subject_epochs.keys() if k[sessSeg]==session]))
            for task in tasks:
                print('\nGathering epochs for session {} task {}'.format(session, task))
                epochs_selection = [v for (k, v) in subject_epochs.items() if k[:2]==(session, task)]

                task_epochs = mne.epochs.concatenate_epochs(epochs_selection)

                # downsample if configured to do so
                # important to do this after concatenation because
                # downsampling may cause rejection for 'TOOSHORT'
                if config['downsample'] < task_epochs.info['sfreq']:
                    task_epochs = task_epochs.copy().resample(config['downsample'], npad='auto')

                ext = config['out_file_format']
                fname = join(derivdir, 'sub-{}_ses-{}_task-{}_epo.{}'.format(sub, session, task, ext))
                if ext == 'fif':
                    task_epochs.save(fname)
                elif ext in ('mat', 'npy'):
                    # Only extract the data arrays when a non-fif format needs them.
                    variables = {
                        'epochs': task_epochs.get_data(),
                        'events': task_epochs.events,
                        'timepoints': task_epochs.times
                    }
                    if ext == 'mat':
                        scipy.io.savemat(fname, mdict=variables)
                    else:
                        numpy.savez(fname, **variables)
ecg_event_id = 999
eog_event_id = 998

# Detect blinks and heartbeats. The ids are passed explicitly so the
# constants above are the single source of truth; they equal MNE's defaults,
# so the detected events are unchanged.
eog_events = mne.preprocessing.find_eog_events(raw, event_id=eog_event_id)
ecg_events = mne.preprocessing.find_ecg_events(raw, event_id=ecg_event_id)
# find_ecg_events returns a tuple; its first element is the events array.
ecg_events = np.asarray(ecg_events[0])

# Merge task, ECG and EOG events and order them by sample, so mne.Epochs
# receives chronologically ordered events. The sort is stable: events that
# share a sample keep their concatenation order (task, ECG, EOG), so
# event_repeated='drop' still keeps the same row.
all_events = np.concatenate((events, ecg_events, eog_events), axis=0)
all_events = all_events[np.argsort(all_events[:, 0], kind='stable')]

epochs_all = mne.Epochs(raw,
                        all_events,
                        reject=None,
                        preload=True,
                        event_repeated='drop')

ar = AutoReject(random_state=97, n_jobs=1)
epochs_ar, reject_log = ar.fit_transform(epochs_all, return_log=True)

n_blinks = len(eog_events)
#onset = eog_events[:, 0] / raw.info['sfreq'] - 0.25
#duration = np.repeat(0.5, n_blinks)
#description = ['blink'] * n_blinks
#annotations = mne.Annotations(onset, duration, description)
#raw.set_annotations(annotations)
#epochs_blink = mne.Epochs(raw, eog_events, eog_event_id, reject=None, preload=True)
# One-second epochs centred on each detected blink.
epochs_blink = mne.Epochs(raw,
                          eog_events,
                          eog_event_id,
                          reject=None,
                          preload=True,
                          tmin=-0.5,
                          tmax=0.5)
示例#22
0
# Compare AutoReject "global" (thresholds only) with AutoReject "local" for
# one subject: clean the item epochs both ways, save both to disk, then
# reload them and plot the evoked response for ViolationOrNot == 1.

# subject = config.subjects_list[11]
subject = 'sub08-cc_150418'
meg_subject_dir = op.join(config.meg_dir, subject)
epochs = epoching_funcs.load_epochs_items(subject, cleaned=False)

# run autoreject "global" -> just get the thresholds
reject = get_rejection_threshold(epochs, ch_types=['mag', 'grad', 'eeg'])
epochs1 = epochs.copy().drop_bad(reject=reject)
fname = op.join(meg_subject_dir, 'epochs_globalAR-epo.fif')
print("Saving: ", fname)
epochs1.save(fname, overwrite=True)

# run autoreject "local"
ar = AutoReject()
epochs2, reject_log = ar.fit_transform(epochs, return_log=True)
fname = op.join(meg_subject_dir, 'epochs_localAR-epo.fif')
print("Saving: ", fname)
epochs2.save(fname, overwrite=True)
# Save autoreject reject_log; the context manager closes the file handle.
with open(fname[:-4] + '_reject_log.obj', 'wb') as reject_log_file:
    pickle.dump(reject_log, reject_log_file)

######################
# Reload both versions from disk and plot the evoked response to violations.
fname = op.join(meg_subject_dir, 'epochs_globalAR-epo.fif')
epochs1 = mne.read_epochs(fname, preload=True)
epochs1['ViolationOrNot == 1'].copy().average().plot_joint()

fname = op.join(meg_subject_dir, 'epochs_localAR-epo.fif')
epochs2 = mne.read_epochs(fname, preload=True)
epochs2['ViolationOrNot == 1'].copy().average().plot_joint()
def _compute_psd(epochs):
    """Return ``(psds, freqs)`` for *epochs* from ``mne.time_frequency.psd_welch``.

    Every condition uses the same settings so rest and trial spectra are
    comparable: 1-50 Hz, 500-sample segments overlapping by 250 samples,
    ``n_fft=2000``.
    """
    return mne.time_frequency.psd_welch(epochs,
                                        fmin=1.,
                                        fmax=50.,
                                        n_fft=2000,
                                        n_overlap=250,
                                        n_per_seg=500)


def _fooof_and_save(fg, freqs, psds, freq_range, sub, save_path):
    """Fit *fg* to each epoch in *psds* and save one results file per epoch.

    ``psds[ind]`` is fitted as one group of spectra (presumably channels x
    freqs, as returned by ``psd_welch`` — confirm). Results are saved as
    ``<sub>fooof_group_results<ind>`` inside *save_path*.
    """
    for ind, epoch_psds in enumerate(psds):
        fg.fit(freqs, epoch_psds, freq_range)
        fg.save(file_name=str(sub) + 'fooof_group_results' + str(ind),
                file_path=save_path,
                save_results=True)


def main():
    """Pre-process every subject's EEG and save FOOOF fits of rest/trial PSDs.

    For each subject in SUBJ_DAT_NUM whose ``<sub>_resampled.set`` file exists
    in BASE_PATH:

    1. load the EEGLAB data and re-add only the events listed in BLOCK_EVS;
    2. apply an average EEG reference;
    3. optionally (RUN_ICA) fit ICA, exclude components correlated with the
       EOG_CHS channels, save the solution to ICA_PATH and apply it;
    4. epoch 5-125 s after each REST_EVENT_ID / TRIAL_EVENT_ID event;
    5. set the standard 10-20 montage, then optionally (RUN_AUTOREJECT)
       clean each condition with autoreject;
    6. compute Welch PSDs and save one FOOOFGroup result per epoch to
       REST_SAVE_PATH / TRIAL_SAVE_PATH.

    The FOOOF settings are saved once to SAVE_PATH before the loop. Subjects
    whose data file is missing are reported and skipped.
    """
    # Initialize fg
    # TODO: add any settings we want to use
    fg = FOOOFGroup(peak_width_limits=[1, 6],
                    min_peak_amplitude=0.075,
                    max_n_peaks=6,
                    peak_threshold=1,
                    verbose=False)

    # Save out a settings file
    fg.save(file_name=GROUP + '_fooof_group_settings',
            file_path=SAVE_PATH,
            save_settings=True)

    # Frequency range used for every FOOOF fit
    freq_range = [3, 30]

    for sub in SUBJ_DAT_NUM:

        print('Current Subject' + str(sub))

        # Load subject data; skip subjects whose file is missing
        subj_dat_fname = str(sub) + "_resampled.set"
        full_path = os.path.join(BASE_PATH, subj_dat_fname)
        path_check = Path(full_path)

        if not path_check.is_file():
            print('Current Subject' + str(sub) + ' does not exist')
            print(path_check)
            continue

        eeg_dat = mne.io.read_raw_eeglab(full_path,
                                         event_id_func=None,
                                         preload=True)
        evs = mne.io.eeglab.read_events_eeglab(full_path, EV_DICT)

        # Keep only the block-marker events (in BLOCK_EVS order) and add them
        # to the raw data so find_events can recover them below.
        block_evs = [evs[evs[:, 2] == EV_DICT[ev_label]]
                     for ev_label in BLOCK_EVS]
        new_evs = np.vstack([np.empty(shape=(0, 3))] + block_evs)
        eeg_dat.add_events(new_evs)

        # set EEG average reference
        eeg_dat.set_eeg_reference()

        ## PRE-PROCESSING: ICA
        if RUN_ICA:

            # ICA settings
            method = 'fastica'
            n_components = 0.99
            random_state = 47
            reject = {'eeg': 20e-4}

            # Initialize ICA object
            ica = ICA(n_components=n_components,
                      method=method,
                      random_state=random_state)

            # High-pass filter data for running ICA.
            # NOTE(review): this filters eeg_dat in place, so the 1 Hz
            # high-pass also affects the epochs/PSDs computed below.
            eeg_dat.filter(l_freq=1., h_freq=None, fir_design='firwin')

            # Fit ICA
            ica.fit(eeg_dat, reject=reject)

            # Exclude every component flagged by any EOG channel (deduplicated)
            drop_inds = set()
            for chi in EOG_CHS:
                inds, _ = ica.find_bads_eog(eeg_dat,
                                            ch_name=chi,
                                            threshold=2.5,
                                            l_freq=1,
                                            h_freq=10,
                                            verbose=False)
                drop_inds.update(inds)
            ica.exclude = sorted(drop_inds)

            # Save out ICA solution, then apply it to the data
            ica.save(os.path.join(ICA_PATH, str(sub) + '-ica.fif'))
            eeg_dat = ica.apply(eeg_dat)

        ## EPOCH BLOCKS
        events = mne.find_events(eeg_dat)

        rest_epochs = mne.Epochs(eeg_dat,
                                 events=events,
                                 event_id=REST_EVENT_ID,
                                 tmin=5,
                                 tmax=125,
                                 baseline=None,
                                 preload=True)
        trial_epochs = mne.Epochs(eeg_dat,
                                  events=events,
                                  event_id=TRIAL_EVENT_ID,
                                  tmin=5,
                                  tmax=125,
                                  baseline=None,
                                  preload=True)

        # Set montage. This must happen before autoreject, which needs channel
        # positions to interpolate bad channels. The last channel is left out
        # (presumably the stim channel — confirm).
        chs = mne.channels.read_montage('standard_1020',
                                        rest_epochs.ch_names[:-1])
        rest_epochs.set_montage(chs)
        trial_epochs.set_montage(chs)

        ## PRE-PROCESSING: AUTO-REJECT
        if RUN_AUTOREJECT:
            # Clean each condition with its own autoreject fit. A reject log
            # from one epochs set cannot be reused to drop epochs from a set
            # with different length/ordering.
            # NOTE(review): autoreject cross-validates (cv=10 by default), so
            # each condition needs enough epochs for the fit to succeed.
            ar = AutoReject(n_jobs=4, verbose=False)
            rest_epochs, rest_log = ar.fit_transform(rest_epochs,
                                                     return_log=True)
            trial_epochs, trial_log = ar.fit_transform(trial_epochs,
                                                       return_log=True)

            # Report which epochs were dropped
            print('Dropped rest epochs:', np.where(rest_log.bad_epochs)[0])
            print('Dropped trial epochs:', np.where(trial_log.bad_epochs)[0])

        # Calculate PSDs
        rest_psds, rest_freqs = _compute_psd(rest_epochs)
        trial_psds, trial_freqs = _compute_psd(trial_epochs)

        ## FOOOF the Data
        _fooof_and_save(fg, rest_freqs, rest_psds, freq_range, sub,
                        REST_SAVE_PATH)
        _fooof_and_save(fg, trial_freqs, trial_psds, freq_range, sub,
                        TRIAL_SAVE_PATH)

        print('Subject Saved')

    print('Pre-processing Complete')