示例#1
0
def test_compute_whitener_rank():
    """Exercise rank estimation and rank overrides in ``compute_whitener``."""
    meg_info = read_info(ave_fname)
    meg_info = pick_info(meg_info, pick_types(meg_info, meg=True))
    with meg_info._unlock():
        meg_info['projs'] = []
    # The diagonal ad-hoc covariance lets compute_whitener take shortcuts, so
    # work on a dense (square) copy instead. compute_whitener is private API;
    # users should not normally need to call it.
    square_cov = make_ad_hoc_cov(meg_info)._as_square()
    assert len(square_cov['names']) == 306

    def _assert_estimated_rank(expected):
        # rank=None lets compute_whitener estimate the rank on its own; the
        # public compute_rank must report the very same value.
        _, _, estimated = compute_whitener(square_cov, meg_info, rank=None,
                                           return_rank=True)
        assert estimated == expected
        assert compute_rank(square_cov, info=meg_info,
                            verbose=True) == {'meg': estimated}

    _assert_estimated_rank(306)
    # Shrinking the last row makes the matrix trivially rank-deficient.
    square_cov['data'][-1] *= 1e-14
    _assert_estimated_rank(305)

    # Requesting a rank above the estimate must warn, yet still be used.
    with pytest.warns(RuntimeWarning, match='exceeds the estimated'):
        _, _, forced = compute_whitener(square_cov, meg_info,
                                        rank=dict(meg=306), return_rank=True)
    assert forced == 306
示例#2
0
def _compute_rank(p, subj, run_indices):
    """Compute the rank of a subject's epoched data.

    Parameters
    ----------
    p : object
        Parameters object; this function reads ``p.analyses``,
        ``p.cov_rank_method`` and ``p.cov_rank_tol``.
    subj : str
        Subject identifier passed to ``get_epochs_evokeds_fnames``.
    run_indices : object
        Not used by this function.

    Returns
    -------
    rank : dict
        Channel-type key -> rank. With ``'estimate_rank'`` the MEG key is
        ``'meg'`` when both gradiometers and magnetometers are present,
        otherwise ``'grad'`` or ``'mag'``; ``'eeg'`` is added when EEG is
        present. With ``'compute_rank'`` the keys are whatever
        ``compute_rank`` returns.

    Raises
    ------
    ValueError
        If ``p.cov_rank_method`` is neither ``'estimate_rank'`` nor
        ``'compute_rank'``.
    """
    # Explicit validation instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently fall through to compute_rank.
    # Checked before reading the epochs so bad configs fail fast.
    if p.cov_rank_method not in ('estimate_rank', 'compute_rank'):
        raise ValueError(
            "p.cov_rank_method must be 'estimate_rank' or 'compute_rank', "
            f'got {p.cov_rank_method!r}')
    epochs_fnames, _ = get_epochs_evokeds_fnames(p, subj, p.analyses)
    _, fif_file = epochs_fnames
    epochs = read_epochs(fif_file)  # .crop(p.bmin, p.bmax)  maybe someday...?
    epochs.apply_proj()
    if p.cov_rank_method == 'estimate_rank':
        # old way: estimate_rank per channel type on concatenated epochs
        rank = _estimate_rank_per_ch_type(epochs, p.cov_rank_tol)
    else:
        # new way
        rank = compute_rank(epochs, tol=p.cov_rank_tol, tol_kind='relative')
    for k, v in rank.items():
        print(' : %s rank %2d' % (k.upper(), v), end='')
    return rank


def _estimate_rank_per_ch_type(epochs, tol):
    """Estimate MEG and EEG ranks separately with ``estimate_rank``."""
    rank = dict()
    if 'meg' in epochs:
        if 'grad' in epochs and 'mag' in epochs:  # Neuromag
            key = 'meg'
        else:
            key = 'grad' if 'grad' in epochs else 'mag'
        rank[key] = _estimate_picked_rank(epochs, tol, meg=True, eeg=False)
    if 'eeg' in epochs:
        rank['eeg'] = _estimate_picked_rank(epochs, tol, meg=False, eeg=True)
    return rank


def _estimate_picked_rank(epochs, tol, **pick_kwargs):
    """Run ``estimate_rank`` on the channels selected by ``pick_types``."""
    data = epochs.copy().pick_types(**pick_kwargs).get_data()
    # Move channels to the first axis and concatenate all epochs along time,
    # so each row is one channel's full time course.
    data = data.transpose([1, 0, 2])
    return estimate_rank(data.reshape(len(data), -1), tol=tol)
示例#3
0
def test_lcmv_maxfiltered():
    """Build LCMV beamformers from Maxwell-filtered magnetometer data."""
    raw_sss = mne.preprocessing.maxwell_filter(
        mne.io.read_raw_fif(fname_raw).fix_mag_coil_types())
    events = mne.find_events(raw_sss)
    # Keep only the magnetometers (102 of them on this system).
    raw_sss.pick_types(meg='mag')
    assert len(raw_sss.ch_names) == 102
    epochs = mne.Epochs(raw_sss, events)
    data_cov = mne.compute_covariance(epochs, tmin=0)
    fwd = mne.read_forward_solution(fname_fwd)
    mag_rank = compute_rank(data_cov, info=epochs.info)
    assert mag_rank == {'mag': 71}
    # make_lcmv must accept each of these rank specifications.
    rank_options = ('info', mag_rank, 'full', None)
    for rank_option in rank_options:
        make_lcmv(epochs.info, fwd, data_cov, rank=rank_option)
示例#4
0
def run_maxwell_filter(subject, session=None):
    """Maxwell-filter (or pass through) every run of one subject/session.

    For each run returned by ``config.get_runs()`` the raw data are loaded,
    optionally corrected for stimulation artifacts, scanned for bad
    channels, Maxwell-filtered when ``config.use_maxwell_filter`` is set,
    and saved to the derivatives root keeping only the channels returned by
    ``config.get_channels_to_analyze``. When ``config.process_er`` is set,
    the empty-room recording is processed the same way once, together with
    the first run.

    Parameters
    ----------
    subject : str
        Subject identifier.
    session : str | None
        Session identifier.

    Raises
    ------
    ValueError
        If Maxwell filtering is requested for data whose ``proc`` label
        already contains ``'sss'``.
    RuntimeError
        If the Maxwell-filtered empty-room and experimental data end up with
        different MEG ranks.
    """
    if config.proc and 'sss' in config.proc and config.use_maxwell_filter:
        raise ValueError(f'You cannot set use_maxwell_filter to True '
                         f'if data have already been processed with '
                         f'Maxwell-filter. Got proc={config.proc}.')

    bids_path_in = BIDSPath(subject=subject,
                            session=session,
                            task=config.get_task(),
                            acquisition=config.acq,
                            processing=config.proc,
                            recording=config.rec,
                            space=config.space,
                            suffix=config.get_datatype(),
                            datatype=config.get_datatype(),
                            root=config.bids_root)
    bids_path_out = bids_path_in.copy().update(suffix='raw',
                                               root=config.deriv_root,
                                               check=False)

    # Load dev_head_t and digitization points from MaxFilter reference run.
    # Re-use in all runs and for processing empty-room recording.
    if config.use_maxwell_filter:
        reference_run = config.get_mf_reference_run()
        msg = f'Loading reference run: {reference_run}.'
        logger.info(
            gen_log_message(message=msg,
                            step=1,
                            subject=subject,
                            session=session))
        bids_path_in.update(run=reference_run)
        info = mne.io.read_info(bids_path_in.fpath)
        dev_head_t = info['dev_head_t']
        dig = info['dig']
        del reference_run, info

    for run_idx, run in enumerate(config.get_runs()):
        bids_path_in.update(run=run)
        bids_path_out.update(run=run)
        raw = load_data(bids_path_in)

        # Fix stimulation artifact
        if config.fix_stim_artifact:
            events, _ = mne.events_from_annotations(raw)
            raw = mne.preprocessing.fix_stim_artifact(
                raw,
                events=events,
                event_id=None,
                tmin=config.stim_artifact_tmin,
                tmax=config.stim_artifact_tmax,
                mode='linear')

        # Auto-detect bad channels.
        if config.find_flat_channels_meg or config.find_noisy_channels_meg:
            find_bad_channels(raw=raw,
                              subject=subject,
                              session=session,
                              task=config.get_task(),
                              run=run)

        # Maxwell-filter experimental data.
        if config.use_maxwell_filter:
            msg = 'Applying Maxwell filter to experimental data.'
            logger.info(
                gen_log_message(message=msg,
                                step=1,
                                subject=subject,
                                session=session))

            # Warn if no bad channels are set before Maxwell filter
            if not raw.info['bads']:
                msg = '\nFound no bad channels. \n '
                logger.warning(
                    gen_log_message(message=msg,
                                    subject=subject,
                                    step=1,
                                    session=session))

            if config.mf_st_duration:
                msg = '    st_duration=%d' % (config.mf_st_duration)
                logger.info(
                    gen_log_message(message=msg,
                                    step=1,
                                    subject=subject,
                                    session=session))

            # Keyword arguments shared between Maxwell filtering of the
            # experimental and the empty-room data.
            common_mf_kws = dict(
                calibration=get_mf_cal_fname(subject, session),
                cross_talk=get_mf_ctc_fname(subject, session),
                st_duration=config.mf_st_duration,
                origin=config.mf_head_origin,
                coord_frame='head',
                destination=dev_head_t)

            raw_sss = mne.preprocessing.maxwell_filter(raw, **common_mf_kws)
            raw_out = raw_sss
            raw_fname_out = (bids_path_out.copy().update(processing='sss',
                                                         extension='.fif'))
        elif config.ch_types == ['eeg']:
            msg = 'Not applying Maxwell filter to EEG data.'
            logger.info(
                gen_log_message(message=msg,
                                step=1,
                                subject=subject,
                                session=session))
            raw_out = raw
            raw_fname_out = bids_path_out.copy().update(extension='.fif')
        else:
            msg = ('Not applying Maxwell filter.\nIf you wish to apply it, '
                   'set use_maxwell_filter=True in your configuration.')
            logger.info(
                gen_log_message(message=msg,
                                step=1,
                                subject=subject,
                                session=session))
            raw_out = raw
            raw_fname_out = bids_path_out.copy().update(extension='.fif')

        # Save only the channel types we wish to analyze (including the
        # channels marked as "bad").
        # We do not run `raw_out.pick()` here because it uses too much memory.
        chs_to_include = config.get_channels_to_analyze(raw_out.info)
        raw_out.save(raw_fname_out,
                     picks=chs_to_include,
                     overwrite=True,
                     split_naming='bids')
        del raw_out
        if config.interactive:
            # Load the data we have just written, because it contains only
            # the relevant channels.
            raw = mne.io.read_raw_fif(raw_fname_out, allow_maxshield=True)
            raw.plot(n_channels=50, butterfly=True)

        # Empty-room processing.
        #
        # We pick the empty-room recording closest in time to the first run
        # of the experimental session.
        if run_idx == 0 and config.process_er:
            msg = 'Processing empty-room recording …'
            logger.info(
                gen_log_message(step=1,
                                subject=subject,
                                session=session,
                                message=msg))

            bids_path_er_in = bids_path_in.find_empty_room()
            raw_er = load_data(bids_path_er_in)
            raw_er.info['bads'] = [
                ch for ch in raw.info['bads'] if ch.startswith('MEG')
            ]

            # Maxwell-filter empty-room data.
            if config.use_maxwell_filter:
                msg = 'Applying Maxwell filter to empty-room recording'
                logger.info(
                    gen_log_message(message=msg,
                                    step=1,
                                    subject=subject,
                                    session=session))

                # We want to ensure we use the same coordinate frame origin in
                # empty-room and experimental data processing. To do this, we
                # inject the sensor locations and the head <> device transform
                # into the empty-room recording's info, and leave all other
                # parameters the same as for the experimental data. This is not
                # very clean, as we normally should not alter info manually,
                # except for info['bads']. Will need improvement upstream in
                # MNE-Python.
                raw_er.info['dig'] = dig
                raw_er.info['dev_head_t'] = dev_head_t
                raw_er_sss = mne.preprocessing.maxwell_filter(
                    raw_er, **common_mf_kws)

                # Perform a sanity check: empty-room rank should match the
                # experimental data rank after Maxwell filtering. Compare the
                # Maxwell-filtered objects; the unfiltered `raw` / `raw_er`
                # do not reflect the rank reduction introduced by SSS.
                rank_exp = mne.compute_rank(raw_sss, rank='info')['meg']
                rank_er = mne.compute_rank(raw_er_sss, rank='info')['meg']
                if not np.isclose(rank_exp, rank_er):
                    msg = (f'Experimental data rank {rank_exp:.1f} does not '
                           f'match empty-room data rank {rank_er:.1f} after '
                           f'Maxwell filtering. This indicates that the data '
                           f'were processed differently.')
                    raise RuntimeError(msg)

                raw_er_out = raw_er_sss
                raw_er_fname_out = bids_path_out.copy().update(
                    processing='sss')
            else:
                raw_er_out = raw_er
                raw_er_fname_out = bids_path_out.copy()

            raw_er_fname_out = raw_er_fname_out.update(task='noise',
                                                       extension='.fif',
                                                       run=None)

            # Save only the channel types we wish to analyze
            # (same as for experimental data above).
            raw_er_out.save(raw_er_fname_out,
                            picks=chs_to_include,
                            overwrite=True,
                            split_naming='bids')
            del raw_er_out
def run_maxwell_filter(subject, session=None):
    """Maxwell-filter (or pass through) every run of one subject/session.

    Each run from ``config.get_runs()`` is loaded, scanned for bad channels,
    Maxwell-filtered when ``config.use_maxwell_filter`` is set, and saved to
    the subject's derivatives folder as ``<basename>_sss_raw.fif`` or
    ``<basename>_nosss_raw.fif``. When ``config.noise_cov == 'emptyroom'``,
    the matched empty-room recording is processed the same way once,
    together with the first run.

    Parameters
    ----------
    subject : str
        Subject identifier.
    session : str | None
        Session identifier.

    Raises
    ------
    RuntimeError
        If the Maxwell-filtered empty-room and experimental data end up with
        different MEG ranks.
    """
    deriv_path = config.get_subject_deriv_path(subject=subject,
                                               session=session,
                                               kind=config.get_kind())
    os.makedirs(deriv_path, exist_ok=True)

    for run_idx, run in enumerate(config.get_runs()):
        bids_basename = make_bids_basename(subject=subject,
                                           session=session,
                                           task=config.get_task(),
                                           acquisition=config.acq,
                                           run=run,
                                           processing=config.proc,
                                           recording=config.rec,
                                           space=config.space)

        raw = load_data(bids_basename)
        if run_idx == 0:
            dev_head_t = raw.info['dev_head_t']  # Re-use in all runs.

        # Auto-detect bad channels.
        if config.find_flat_channels_meg or config.find_noisy_channels_meg:
            find_bad_channels(raw=raw,
                              subject=subject,
                              session=session,
                              task=config.get_task(),
                              run=run)

        # Maxwell-filter experimental data.
        if config.use_maxwell_filter:
            msg = 'Applying Maxwell filter to experimental data.'
            logger.info(
                gen_log_message(message=msg,
                                step=1,
                                subject=subject,
                                session=session))

            # Warn if no bad channels are set before Maxwell filter.
            # `warning` rather than the deprecated `warn` alias.
            if not raw.info['bads']:
                msg = '\nFound no bad channels. \n '
                logger.warning(
                    gen_log_message(message=msg,
                                    subject=subject,
                                    step=1,
                                    session=session))

            if config.mf_st_duration:
                msg = '    st_duration=%d' % (config.mf_st_duration)
                logger.info(
                    gen_log_message(message=msg,
                                    step=1,
                                    subject=subject,
                                    session=session))

            # Keyword arguments shared between Maxwell filtering of the
            # experimental and the empty-room data.
            common_mf_kws = dict(calibration=config.mf_cal_fname,
                                 cross_talk=config.mf_ctc_fname,
                                 st_duration=config.mf_st_duration,
                                 origin=config.mf_head_origin,
                                 coord_frame='head',
                                 destination=dev_head_t)

            raw_sss = mne.preprocessing.maxwell_filter(raw, **common_mf_kws)
            raw_out = raw_sss
            raw_fname_out = op.join(deriv_path, f'{bids_basename}_sss_raw.fif')
        else:
            msg = ('Not applying Maxwell filter.\nIf you wish to apply it, '
                   'set use_maxwell_filter=True in your configuration.')
            logger.info(
                gen_log_message(message=msg,
                                step=1,
                                subject=subject,
                                session=session))
            raw_out = raw
            raw_fname_out = op.join(deriv_path,
                                    f'{bids_basename}_nosss_raw.fif')
        raw_out.save(raw_fname_out, overwrite=True)
        if config.interactive:
            raw_out.plot(n_channels=50, butterfly=True)

        # Empty-room processing.
        #
        # We pick the empty-room recording closest in time to the first run
        # of the experimental session.
        if run_idx == 0 and config.noise_cov == 'emptyroom':
            msg = 'Processing empty-room recording …'
            logger.info(
                gen_log_message(step=1,
                                subject=subject,
                                session=session,
                                message=msg))

            bids_basename_er_in = get_matched_empty_room(
                bids_basename=bids_basename, bids_root=config.bids_root)
            raw_er = load_data(bids_basename_er_in)
            raw_er.info['bads'] = [
                ch for ch in raw.info['bads'] if ch.startswith('MEG')
            ]

            # Maxwell-filter empty-room data.
            if config.use_maxwell_filter:
                msg = 'Applying Maxwell filter to empty-room recording'
                logger.info(
                    gen_log_message(message=msg,
                                    step=1,
                                    subject=subject,
                                    session=session))

                # We want to ensure we use the same coordinate frame origin in
                # empty-room and experimental data processing. To do this, we
                # inject the sensor locations and the head <> device transform
                # into the empty-room recording's info, and leave all other
                # parameters the same as for the experimental data. This is not
                # very clean, as we normally should not alter info manually,
                # except for info['bads']. Will need improvement upstream in
                # MNE-Python.
                raw_er.info['dig'] = raw.info['dig']
                raw_er.info['dev_head_t'] = dev_head_t
                raw_er_sss = mne.preprocessing.maxwell_filter(
                    raw_er, **common_mf_kws)

                # Perform a sanity check: empty-room rank should match the
                # experimental data rank after Maxwell filtering. Compare the
                # Maxwell-filtered objects; the unfiltered `raw` / `raw_er`
                # do not reflect the rank reduction introduced by SSS.
                rank_exp = mne.compute_rank(raw_sss, rank='info')['meg']
                rank_er = mne.compute_rank(raw_er_sss, rank='info')['meg']
                if not np.isclose(rank_exp, rank_er):
                    msg = (f'Experimental data rank {rank_exp:.1f} does not '
                           f'match empty-room data rank {rank_er:.1f} after '
                           f'Maxwell filtering. This indicates that the data '
                           f'were processed differently.')
                    raise RuntimeError(msg)

                raw_er_out = raw_er_sss
                raw_er_fname_out = op.join(
                    deriv_path, f'{bids_basename}_emptyroom_sss_raw.fif')
            else:
                raw_er_out = raw_er
                raw_er_fname_out = op.join(
                    deriv_path, f'{bids_basename}_emptyroom_nosss_raw.fif')

            raw_er_out.save(raw_er_fname_out, overwrite=True)
                    'sub-{}_task-{}-fwd.fif'.format(subject, task))
subjects_dir = op.join(data_path, 'derivatives', 'freesurfer', 'subjects')

fwd = mne.read_forward_solution(fname_fwd)

# %%
# Compute covariances
# -------------------
# ERS activity begins 0.5 s after stimulus onset. These data were cleaned
# with MaxFilter itself (not MNE-Python's implementation), so the rank has
# to be computed with a more conservative threshold to obtain the correct
# data rank (64). Combined with an advanced covariance estimator such as
# "shrunk", this rank is then preserved correctly.

rank = mne.compute_rank(epochs, tol=1e-6, tol_kind='relative')
baseline_win = (-1, 0)
active_win = (0.5, 1.5)
# One shrunk covariance per window: baseline first, then active.
baseline_cov, active_cov = (
    compute_covariance(epochs,
                       tmin=win_start,
                       tmax=win_stop,
                       method='shrunk',
                       rank=rank,
                       verbose=True)
    for win_start, win_stop in (baseline_win, active_win)
)
def run_maxwell_filter(*, cfg, subject, session=None):
    """Maxwell-filter all runs (and optionally the empty-room data).

    The head <> device transform and digitization points are taken from the
    MaxFilter reference run and reused for every run and for the empty-room
    recording. Each Maxwell-filtered run is written to ``cfg.deriv_root``
    with ``processing='sss'``, keeping only the channels returned by
    ``config.get_channels_to_analyze``. When ``cfg.process_er`` is set, the
    empty-room recording is processed exactly once, during the iteration for
    ``cfg.mf_reference_run``.

    Parameters
    ----------
    cfg : object
        Configuration namespace (``proc``, ``use_maxwell_filter``, ``runs``,
        ``mf_*`` settings, ``process_er``, ``interactive``, ...).
    subject : str
        Subject identifier.
    session : str | None
        Session identifier.

    Raises
    ------
    ValueError
        If Maxwell filtering is requested for data whose ``proc`` label
        already contains ``'sss'``.
    RuntimeError
        If the Maxwell-filtered empty-room and experimental data end up with
        different MEG ranks.
    """
    if cfg.proc and 'sss' in cfg.proc and cfg.use_maxwell_filter:
        # Report the same `cfg.proc` value that the check above used.
        raise ValueError(f'You cannot set use_maxwell_filter to True '
                         f'if data have already been processed with '
                         f'Maxwell-filter. Got proc={cfg.proc}.')

    bids_path_out = BIDSPath(subject=subject,
                             session=session,
                             task=cfg.task,
                             acquisition=cfg.acq,
                             processing='sss',
                             recording=cfg.rec,
                             space=cfg.space,
                             suffix='raw',
                             extension='.fif',
                             datatype=cfg.datatype,
                             root=cfg.deriv_root,
                             check=False)

    # Load dev_head_t and digitization points from MaxFilter reference run.
    # Re-use in all runs and for processing empty-room recording.
    msg = f'Loading reference run: {cfg.mf_reference_run}.'
    logger.info(**gen_log_kwargs(message=msg, subject=subject,
                                 session=session))

    reference_run_info = get_reference_run_info(
        subject=subject, session=session, run=cfg.mf_reference_run
    )
    dev_head_t = reference_run_info['dev_head_t']
    dig = reference_run_info['dig']
    del reference_run_info

    for run in cfg.runs:
        bids_path_out.update(run=run)

        raw = import_experimental_data(
            cfg=cfg,
            subject=subject,
            session=session,
            run=run,
            save=False
        )

        # Maxwell-filter experimental data.
        msg = 'Applying Maxwell filter to experimental data.'
        logger.info(**gen_log_kwargs(message=msg, subject=subject,
                                     session=session, run=run))

        # Warn if no bad channels are set before Maxwell filter
        # Create a copy, we'll need this later for setting the bads of the
        # empty-room recording
        bads = raw.info['bads'].copy()
        if not bads:
            msg = 'Found no bad channels.'
            logger.warning(**gen_log_kwargs(message=msg, subject=subject,
                                            session=session, run=run))

        if cfg.mf_st_duration:
            msg = '    st_duration=%d' % (cfg.mf_st_duration)
            logger.info(**gen_log_kwargs(message=msg,
                                         subject=subject, session=session,
                                         run=run))

        # Keyword arguments shared between Maxwell filtering of the
        # experimental and the empty-room data.
        common_mf_kws = dict(
            calibration=cfg.mf_cal_fname,
            cross_talk=cfg.mf_ctc_fname,
            st_duration=cfg.mf_st_duration,
            origin=cfg.mf_head_origin,
            coord_frame='head',
            destination=dev_head_t
        )

        raw_sss = mne.preprocessing.maxwell_filter(raw, **common_mf_kws)

        # Save only the channel types we wish to analyze (including the
        # channels marked as "bad").
        # We do not run `raw_sss.pick()` here because it uses too much memory.
        # NOTE(review): this uses the module-level `config` rather than
        # `cfg`; confirm `cfg` does not expose an equivalent helper.
        picks = config.get_channels_to_analyze(raw.info)
        raw_sss.save(bids_path_out, picks=picks, split_naming='bids',
                     overwrite=True)
        del raw_sss

        if cfg.interactive:
            # Load the data we have just written, because it contains only
            # the relevant channels.
            raw = mne.io.read_raw_fif(bids_path_out, allow_maxshield=True)
            raw.plot(n_channels=50, butterfly=True)

        # Empty-room processing.
        # Only process empty-room data once – we ensure this by simply checking
        # if the current run is the reference run, and only then initiate
        # empty-room processing. No sophisticated logic behind this – it's just
        # convenient to code it this way.
        if cfg.process_er and run == cfg.mf_reference_run:
            msg = 'Processing empty-room recording …'
            logger.info(**gen_log_kwargs(subject=subject,
                                         session=session, message=msg))

            raw_er = import_er_data(
                cfg=cfg,
                subject=subject,
                session=session,
                bads=bads,
                save=False
            )

            # Maxwell-filter empty-room data.
            msg = 'Applying Maxwell filter to empty-room recording'
            logger.info(**gen_log_kwargs(message=msg,
                                         subject=subject, session=session))

            # We want to ensure we use the same coordinate frame origin in
            # empty-room and experimental data processing. To do this, we
            # inject the sensor locations and the head <> device transform
            # into the empty-room recording's info, and leave all other
            # parameters the same as for the experimental data. This is not
            # very clean, as we normally should not alter info manually,
            # except for info['bads']. Will need improvement upstream in
            # MNE-Python.
            raw_er.info['dig'] = dig
            raw_er.info['dev_head_t'] = dev_head_t
            raw_er_sss = mne.preprocessing.maxwell_filter(raw_er,
                                                          **common_mf_kws)

            # Perform a sanity check: empty-room rank should match the
            # experimental data rank after Maxwell filtering. The filtered
            # experimental data were deleted above to save memory, so re-open
            # the file just written.
            raw_sss = mne.io.read_raw_fif(bids_path_out)
            rank_exp = mne.compute_rank(raw_sss, rank='info')['meg']
            rank_er = mne.compute_rank(raw_er_sss, rank='info')['meg']
            if not np.isclose(rank_exp, rank_er):
                msg = (f'Experimental data rank {rank_exp:.1f} does not '
                       f'match empty-room data rank {rank_er:.1f} after '
                       f'Maxwell filtering. This indicates that the data '
                       f'were processed differently.')
                raise RuntimeError(msg)

            raw_er_fname_out = bids_path_out.copy().update(
                task='noise',
                run=None,
                processing='sss'
            )

            # Save only the channel types we wish to analyze
            # (same as for experimental data above).
            raw_er_sss.save(raw_er_fname_out, picks=picks,
                            overwrite=True, split_naming='bids')
            del raw_er_sss
示例#8
0
import os.path as op

import numpy as np
import matplotlib.pyplot as plt
import mne

import config_drago as cfg

# The .npy file holds a pickled ``Info`` dict (hence ``.item()`` on the 0-d
# object array). NumPy >= 1.16.3 refuses to unpickle object arrays unless
# ``allow_pickle=True`` is passed explicitly. Unpickling can execute code, so
# only load files produced by this project.
info = np.load(op.join(cfg.path_data, 'info_allch.npy'),
               allow_pickle=True).item()
picks = mne.pick_types(info, meg='mag')

fname = op.join(cfg.path_outputs, 'covs_allch_oas.h5')
# NOTE(review): `mne.externals.h5io` no longer exists in recent MNE releases;
# the standalone `h5io` package provides the same `read_hdf5` if needed.
data = mne.externals.h5io.read_hdf5(fname)

subject_entries = [d for d in data if 'subject' in d]
subjects = [d['subject'] for d in subject_entries]
covs = np.array([d['covs'][:, picks][:, :, picks]
                 for d in subject_entries])  # (sub, fb, chan, chan)

# Channel names are identical for every subject; build them once.
mag_names = np.array(info['ch_names'])[picks]
ranks = []
for sub_covs in covs:
    # Index 4 along the frequency-band axis; NOTE(review): which band this
    # is depends on how the covariances were computed — confirm.
    cov = mne.Covariance(sub_covs[4], mag_names, [], [], 1)
    ranks.append(mne.compute_rank(cov, info=info)['mag'])
plt.figure()
plt.hist(ranks)
def run_ica_correction(sub_id,
                       raw,
                       method='picard',
                       reject=None,
                       decim=3,
                       random_state=42,
                       show_figs=False,
                       results_dir=None):
    """Fit ICA and remove heartbeat (ECG) and blink (EOG) components.

    Parameters
    ----------
    sub_id : str
        Subject ID; also names the per-subject output folder and figures.
    raw : mne.io.Raw
        Data to fit ICA on. It is modified in place by ``ICA.apply``.
    method : str, default 'picard'
        ICA algorithm. See the MNE ``ICA`` documentation for other options.
    reject : dict | None, default None
        Rejection thresholds passed to ``ICA.fit``
        (e.g. ``{'mag': 5e-12, 'grad': 4000e-13}``).
    decim : int, default 3
        Decimation factor passed to ``ICA.fit``.
    random_state : int, default 42
        Seed for the ICA random state.
    show_figs : bool, default False
        Whether to display the EOG/ECG overlay figures.
    results_dir : str | None, default None
        If given, the overlay figures are saved to ``<results_dir>/<sub_id>/``
        (directories are created as needed).

    Returns
    -------
    sub_id : str
        Subject ID, unchanged.
    raw : mne.io.Raw
        The same object that was passed in, now without the excluded
        blink/heartbeat components (hopefully).
    ica : mne.preprocessing.ICA
        The fitted ICA object, with the excluded components in ``ica.exclude``.
    """
    # Use the data rank as the number of ICA components.
    # NOTE(review): assumes compute_rank reports a combined 'meg' key (data
    # with both gradiometers and magnetometers) — confirm for other systems.
    rank = mne.compute_rank(raw)

    ica = ICA(n_components=rank['meg'],
              method=method,
              random_state=random_state)
    ica.fit(raw, decim=decim, reject=reject)

    # Create EOG/ECG epochs from the dedicated channels, then find the ICA
    # components matching those artifacts.
    eog_epochs = create_eog_epochs(raw, ch_name='EOG061', reject=None)
    eog_inds, _ = ica.find_bads_eog(eog_epochs)

    ecg_epochs = create_ecg_epochs(raw, reject=None)
    ecg_inds, _ = ica.find_bads_ecg(ecg_epochs)

    if show_figs or results_dir is not None:
        # The averaged artifact epochs are only needed for these overlays.
        fig_eog = ica.plot_overlay(eog_epochs.average(), exclude=eog_inds,
                                   show=show_figs)
        fig_ecg = ica.plot_overlay(ecg_epochs.average(), exclude=ecg_inds,
                                   show=show_figs)
    if results_dir is not None:
        sub_dir = os.path.join(results_dir, sub_id)
        # makedirs also creates results_dir if missing; exist_ok avoids the
        # check-then-create race of exists() + mkdir().
        os.makedirs(sub_dir, exist_ok=True)

        fig_eog.savefig(os.path.join(sub_dir, f'{sub_id}_eog_correction.png'))
        fig_ecg.savefig(os.path.join(sub_dir, f'{sub_id}_ecg_correction.png'))

    # Mark the artifact components for exclusion and remove them from raw.
    ica.exclude.extend(eog_inds)
    ica.exclude.extend(ecg_inds)
    ica.apply(raw)

    return sub_id, raw, ica