Exemple #1
0
def test_plot_ica_sources():
    """Test plotting of ICA panel.

    Fits a 2-component ICA on a 1 s crop of the test recording and exercises
    ``ICA.plot_sources`` on Raw, Epochs and Evoked instances: keyboard and
    mouse interaction, the errors raised when the instance's channels no
    longer match the fit, component labels, and an invalid ``inst`` type.
    """
    raw = read_raw_fif(raw_fname).crop(0, 1).load_data()
    picks = _get_picks(raw)
    epochs = _get_epochs()
    raw.pick_channels([raw.ch_names[k] for k in picks])
    ica_picks = pick_types(raw.info,
                           meg=True,
                           eeg=False,
                           stim=False,
                           ecg=False,
                           eog=False,
                           exclude='bads')
    ica = ICA(n_components=2, max_pca_components=3, n_pca_components=3)
    ica.fit(raw, picks=ica_picks)
    ica.exclude = [1]
    fig = ica.plot_sources(raw)
    fig.canvas.key_press_event('escape')
    # Sadly close_event isn't called on Agg backend and the test always passes.
    assert_array_equal(ica.exclude, [1])
    plt.close('all')

    # dtype can change int->np.int after load, test it explicitly
    ica.n_components_ = np.int64(ica.n_components_)
    fig = ica.plot_sources(raw)
    # also test mouse clicks
    data_ax = fig.axes[0]
    assert len(plt.get_fignums()) == 1
    _fake_click(fig, data_ax, [-0.1, 0.9])  # click on y-label
    # The y-label click must open one extra figure (presumably the
    # component-properties plot -- TODO confirm).
    assert len(plt.get_fignums()) == 2
    ica.exclude = [1]
    ica.plot_sources(raw)

    # After marking a channel bad, plotting Raw/Epochs must raise.
    raw.info['bads'] = ['MEG 0113']
    with pytest.raises(RuntimeError, match="Raw doesn't match fitted data"):
        ica.plot_sources(inst=raw)
    ica.plot_sources(epochs)
    epochs.info['bads'] = ['MEG 0113']
    with pytest.raises(RuntimeError, match="Epochs don't match fitted data"):
        ica.plot_sources(inst=epochs)
    epochs.info['bads'] = []
    ica.plot_sources(epochs.average())
    evoked = epochs.average()
    fig = ica.plot_sources(evoked)
    # Test a click
    ax = fig.get_axes()[0]
    line = ax.lines[0]
    _fake_click(fig, ax, [line.get_xdata()[0], line.get_ydata()[0]], 'data')
    _fake_click(fig, ax, [ax.get_xlim()[0], ax.get_ylim()[1]], 'data')
    # plot with ICA component 0 excluded
    ica.exclude = [0]
    ica.plot_sources(evoked)
    ica.labels_ = dict(eog=[0])
    ica.labels_['eog/0/crazy-channel'] = [0]
    ica.plot_sources(evoked)  # now with labels
    # A non-Raw/Epochs/Evoked instance must be rejected.
    with pytest.raises(ValueError, match='must be of Raw or Epochs type'):
        ica.plot_sources('meeow')
    plt.close('all')
Exemple #2
0
def inscapesMEG_PP(fname, DATA_FOLDER, SAVE_FOLDER):
    """Preprocess one CTF MEG recording and save the cleaned data.

    Pipeline:
    1. Load the recording.
    2. Apply a 0.5-200 Hz band-pass and notch-filter 60 Hz harmonics.
    3. Fit a 20-component ICA.
    4. Detect ECG components (channel ``'EEG059'``) and EOG components
       (channel ``'EEG057'``).
    5. Remove both sets in a single ``ica.apply`` call.
    6. Save the result.

    Diagnostic plots are shown along the way.

    Parameters
    ----------
    fname : str
        Recording name. The input path is ``DATA_FOLDER + fname`` and the
        output file is ``SAVE_FOLDER + fname + '_preprocessed.fif.gz'``.
    DATA_FOLDER : str
        Input directory. Joined by plain string concatenation, so it must
        end with a path separator.
    SAVE_FOLDER : str
        Output directory, same convention as ``DATA_FOLDER``.
    """
    fpath = DATA_FOLDER + fname
    raw = read_raw_ctf(fpath, preload=True)
    picks = mne.pick_types(raw.info, meg=True, eog=True, exclude='bads')
    raw.plot()
    raw.plot_psd(average=False, picks=picks)

    ## Filtering
    high_cutoff = 200
    low_cutoff = 0.5
    raw.filter(low_cutoff, high_cutoff, fir_design="firwin")
    # Notch 60 Hz and its harmonics up to the low-pass edge (presumably
    # power-line noise).
    raw.notch_filter(np.arange(60, high_cutoff + 1, 60),
                     picks=picks,
                     filter_length='auto',
                     phase='zero',
                     fir_design="firwin")
    raw.plot_psd(average=False, picks=picks)

    ## ICA
    ica = ICA(n_components=20, random_state=0).fit(raw, decim=3)
    ica.plot_sources(raw)
    # Upper frequency limit (Hz) of the PSD panel drawn by plot_properties.
    fmax = 40.

    ## FIND ECG COMPONENTS
    # NOTE(review): 'EEG059' is presumably this setup's ECG electrode.
    ecg_epochs = create_ecg_epochs(raw, ch_name='EEG059')
    ecg_inds, ecg_scores = ica.find_bads_ecg(ecg_epochs, ch_name='EEG059')
    ica.plot_scores(ecg_scores, ecg_inds)
    ica.plot_properties(ecg_epochs,
                        picks=ecg_inds,
                        psd_args={'fmax': fmax},
                        image_args={'sigma': 1.})

    ## FIND EOG COMPONENTS
    # NOTE(review): 'EEG057' is presumably this setup's EOG electrode.
    eog_epochs = create_eog_epochs(raw, ch_name='EEG057')
    eog_inds, eog_scores = ica.find_bads_eog(eog_epochs, ch_name='EEG057')
    ica.plot_scores(eog_scores, eog_inds)
    ica.plot_properties(eog_epochs,
                        picks=eog_inds,
                        psd_args={'fmax': fmax},
                        image_args={'sigma': 1.})

    ## EXCLUDE COMPONENTS
    # Remove ECG and EOG components together in a single apply() call.
    # Afterwards ica.exclude records every removed component, and the
    # already-cleaned data are not re-projected a second time.
    ica.exclude = sorted(set(ecg_inds) | set(eog_inds))
    ica.apply(raw)
    # Plot the clean signal.
    raw.plot()

    ## SAVE PREPROCESSED FILE
    # NOTE(review): the purpose of these fixed pauses is not visible here;
    # kept unchanged.
    time.sleep(60)
    raw.save(SAVE_FOLDER + fname + '_preprocessed.fif.gz', overwrite=True)
    time.sleep(30)
    # Only referenced by the (disabled) notebook-export command below.
    filename = SAVE_FOLDER + fname + '_log.html'
    #!jupyter nbconvert inscapesMEG_preproc.ipynb --output $filename
    clear_output()
Exemple #3
0
def test_plot_ica_sources():
    """Test plotting of ICA panel.

    Exercises ``ICA.plot_sources`` on Raw, Epochs and Evoked data, including
    mouse clicks, the ``exclude`` argument, component labels, and the errors
    raised for mismatched bad channels and an invalid instance type.
    """
    import matplotlib.pyplot as plt
    raw = read_raw_fif(raw_fname).crop(0, 1).load_data()
    picks = _get_picks(raw)
    epochs = _get_epochs()
    raw.pick_channels([raw.ch_names[k] for k in picks])
    ica_picks = pick_types(raw.info,
                           meg=True,
                           eeg=False,
                           stim=False,
                           ecg=False,
                           eog=False,
                           exclude='bads')
    ica = ICA(n_components=2, max_pca_components=3, n_pca_components=3)
    ica.fit(raw, picks=ica_picks)
    ica.exclude = [1]
    fig = ica.plot_sources(raw)
    fig.canvas.key_press_event('escape')
    # Sadly close_event isn't called on Agg backend and the test always passes.
    assert_array_equal(ica.exclude, [1])

    # Second positional argument presumably selects which components to
    # plot (picks) -- TODO confirm against the plot_sources signature.
    fig = ica.plot_sources(raw, [1])
    # test mouse clicks
    data_ax = fig.axes[0]
    _fake_click(fig, data_ax, [-0.1, 0.9])  # click on y-label

    # After marking a channel bad, plotting Raw/Epochs must raise.
    raw.info['bads'] = ['MEG 0113']
    assert_raises(RuntimeError, ica.plot_sources, inst=raw)
    ica.plot_sources(epochs)
    epochs.info['bads'] = ['MEG 0113']
    assert_raises(RuntimeError, ica.plot_sources, inst=epochs)
    epochs.info['bads'] = []
    with warnings.catch_warnings(record=True):  # no labeled objects mpl
        ica.plot_sources(epochs.average())
        evoked = epochs.average()
        fig = ica.plot_sources(evoked)
        # Test a click
        ax = fig.get_axes()[0]
        line = ax.lines[0]
        _fake_click(
            fig, ax,
            [line.get_xdata()[0], line.get_ydata()[0]], 'data')
        _fake_click(fig, ax, [ax.get_xlim()[0], ax.get_ylim()[1]], 'data')
        # plot with ICA component 0 excluded (via argument, then attribute)
        ica.plot_sources(evoked, exclude=[0])
        ica.exclude = [0]
        ica.plot_sources(evoked)  # does the same thing
        ica.labels_ = dict(eog=[0])
        ica.labels_['eog/0/crazy-channel'] = [0]
        ica.plot_sources(evoked)  # now with labels
    # A string is not a valid instance to plot.
    assert_raises(ValueError, ica.plot_sources, 'meeow')
    plt.close('all')
Exemple #4
0
def test_plot_ica_sources():
    """Test plotting of ICA panel.

    Exercises ``ICA.plot_sources`` on Raw, Epochs and Evoked data with a
    numpy-integer ``n_components_``, mouse clicks, the ``exclude`` argument,
    component labels, and the errors raised for mismatched bad channels and
    an invalid instance type.
    """
    raw = read_raw_fif(raw_fname).crop(0, 1).load_data()
    picks = _get_picks(raw)
    epochs = _get_epochs()
    raw.pick_channels([raw.ch_names[k] for k in picks])
    ica_picks = pick_types(raw.info, meg=True, eeg=False, stim=False,
                           ecg=False, eog=False, exclude='bads')
    ica = ICA(n_components=2, max_pca_components=3, n_pca_components=3)
    ica.fit(raw, picks=ica_picks)
    ica.exclude = [1]
    fig = ica.plot_sources(raw)
    fig.canvas.key_press_event('escape')
    # Sadly close_event isn't called on Agg backend and the test always passes.
    assert_array_equal(ica.exclude, [1])
    plt.close('all')

    # dtype can change int->np.int after load, test it explicitly
    ica.n_components_ = np.int64(ica.n_components_)
    # Second positional argument presumably selects which components to
    # plot (picks) -- TODO confirm against the plot_sources signature.
    fig = ica.plot_sources(raw, [1])
    # also test mouse clicks
    data_ax = fig.axes[0]
    _fake_click(fig, data_ax, [-0.1, 0.9])  # click on y-label

    # After marking a channel bad, plotting Raw/Epochs must raise.
    raw.info['bads'] = ['MEG 0113']
    pytest.raises(RuntimeError, ica.plot_sources, inst=raw)
    ica.plot_sources(epochs)
    epochs.info['bads'] = ['MEG 0113']
    pytest.raises(RuntimeError, ica.plot_sources, inst=epochs)
    epochs.info['bads'] = []
    ica.plot_sources(epochs.average())
    evoked = epochs.average()
    fig = ica.plot_sources(evoked)
    # Test a click
    ax = fig.get_axes()[0]
    line = ax.lines[0]
    _fake_click(fig, ax,
                [line.get_xdata()[0], line.get_ydata()[0]], 'data')
    _fake_click(fig, ax,
                [ax.get_xlim()[0], ax.get_ylim()[1]], 'data')
    # plot with ICA component 0 excluded (via argument, then attribute)
    ica.plot_sources(evoked, exclude=[0])
    ica.exclude = [0]
    ica.plot_sources(evoked)  # does the same thing
    ica.labels_ = dict(eog=[0])
    ica.labels_['eog/0/crazy-channel'] = [0]
    ica.plot_sources(evoked)  # now with labels
    # A string is not a valid instance to plot.
    pytest.raises(ValueError, ica.plot_sources, 'meeow')
    plt.close('all')
Exemple #5
0
def _ask_component_indices():
    """Interactively read a list of ICA component indices from stdin.

    First asks how many indices to read, then reads that many integers.
    Non-integer answers are re-prompted instead of raising ValueError.
    """
    def read_int(prompt):
        # Keep asking until the answer is something int() accepts.
        while True:
            try:
                return int(input(prompt))
            except ValueError:
                print("valor no válido, introduzca un número entero")

    n_indices = read_int("elija el numero de indices a incluir: ")
    indices = []
    for _ in range(n_indices):
        print("introduzca el indice ")
        indices.append(read_int(": "))
    return indices


def run_ica(method, fit_params=None):
    """Fit ICA, let the user confirm EOG components, and return cleaned data.

    Works on the module-level ``raw`` recording (``events``, ``mapcol`` and
    ``plt`` are also read from module scope). Steps:
    1. Fit a 14-component ICA on a 2.5 Hz high-passed copy of ``raw``.
    2. Score components against channel ``'AF4'``, used as an EOG proxy.
    3. Let the user accept that selection or type their own.
    4. Remove the selected components from a copy of ``raw``.

    Parameters
    ----------
    method : str
        ICA algorithm, forwarded to ``ICA(method=...)``.
    fit_params : dict | None
        Extra parameters forwarded to the ICA algorithm.

    Returns
    -------
    reconst_raw : Raw
        Copy of ``raw`` with the selected components removed; ``raw``
        itself is not modified.
    """
    ica = ICA(n_components=14, method=method, fit_params=fit_params,
              random_state=75)
    # Fit on a high-passed copy so the module-level ``raw`` stays unfiltered.
    filt_raw = raw.copy()
    filt_raw.load_data().filter(l_freq=2.5, h_freq=None)
    t0 = time()
    ica.fit(filt_raw)
    fit_time = time() - t0
    title = ('ICA decomposition using %s (took %.1fs)' % (method, fit_time))
    ica.plot_sources(raw)
    ica.plot_components(title=title)
    ica.exclude = []

    # Find which ICs match the EOG pattern ('AF4' acts as the EOG channel).
    eog_indices, eog_scores = ica.find_bads_eog(raw, ch_name='AF4')
    ica.exclude = eog_indices
    # Bar plot of ICA component "EOG match" scores.
    ica.plot_scores(eog_scores)
    print(eog_indices)
    # Let the user accept the automatic selection or enter their own.
    answer = input("está de acuerdo con los indices?: (y/n)")
    if answer.strip().lower() == 'y':
        print('ok, sigamos')
    else:
        eog_indices = _ask_component_indices()
    # Nothing to show for an empty selection (an empty ``picks`` is
    # presumably rejected by MNE -- confirm).
    if eog_indices:
        ica.plot_properties(raw, picks=eog_indices)

    ica.exclude = eog_indices
    # Plot ICs applied to raw data, with EOG matches highlighted.
    ica.plot_sources(raw)
    # Overlay raw data before/after removing the selected components.
    ica.plot_overlay(raw, exclude=eog_indices)

    # ica.apply() changes the Raw object in-place, so make a copy first.
    print('copiando el original')
    reconst_raw = raw.copy()
    print('aplicando ICA')
    ica.apply(reconst_raw)

    raw.plot(events=events, event_color=mapcol)
    print('muestra reconstruida')
    reconst_raw.plot(events=events, event_color=mapcol)
    plt.show()
    return reconst_raw
Exemple #6
0
def ica_pipe(sample_raw_bandpass):
    """Fit ICA, flag EOG components with a pre-trained classifier, remove them.

    Parameters
    ----------
    sample_raw_bandpass : Raw
        Input recording. It is copied twice and never modified itself.

    Returns
    -------
    ica : ICA
        The fitted ICA, with ``exclude`` set to the components the
        classifier labelled ``1`` (EOG).
    sample_raw_train : Raw
        Untouched copy of the input.
    sample_raw_corrected : Raw
        Copy of the input with the flagged components removed.
    """
    eog_model = joblib.load("./models/eog_classifier_v2.joblib")

    raw_untouched = sample_raw_bandpass.copy()
    raw_clean = sample_raw_bandpass.copy()

    ica = ICA(method="extended-infomax", random_state=1)
    ica.fit(raw_clean)

    # One feature row per component: its spatial map scaled to unit L2 norm
    # (normalised column-wise on the transposed map, then transposed back).
    topographies = _get_ica_map(ica).T
    norms = np.linalg.norm(topographies, axis=0)
    topographies /= norms[None, :]
    features = topographies.T

    # The classifier labels EOG components with 1.
    predicted = eog_model.predict(features)
    ica.exclude = np.where(predicted == 1)[0]
    ica.apply(raw_clean)

    return ica, raw_untouched, raw_clean
Exemple #7
0
def test_plot_instance_components(browser_backend):
    """Test plotting of components as instances of raw and epochs.

    Opens the ICA source browser for Raw and then Epochs, replays a sequence
    of keyboard shortcuts, clicks a trace and a channel name, and closes the
    browser with 'escape'.
    """
    raw = _get_raw()
    picks = _get_picks(raw)
    ica = ICA(noise_cov=read_cov(cov_fname), n_components=2)
    # Fitting here must emit a RuntimeWarning mentioning 'projection'.
    with pytest.warns(RuntimeWarning, match='projection'):
        ica.fit(raw, picks=picks)
    ica.exclude = [0]
    fig = ica.plot_sources(raw, title='Components')
    # Browser shortcuts; the toggles (d, z, s) appear twice, presumably so
    # each mode is switched on and back off -- TODO confirm.
    keys = ('home', 'home', 'end', 'down', 'up', 'right', 'left', '-', '+',
            '=', 'd', 'd', 'pageup', 'pagedown', 'z', 'z', 's', 's', 'b')
    for key in keys:
        fig._fake_keypress(key)
    # Click on the first sample of the first trace (data coordinates).
    x = fig.mne.traces[0].get_xdata()[0]
    y = fig.mne.traces[0].get_ydata()[0]
    fig._fake_click((x, y), xform='data')
    fig._click_ch_name(ch_index=0, button=1)
    fig._fake_keypress('escape')
    browser_backend._close_all()

    # Same interaction sequence on Epochs.
    epochs = _get_epochs()
    fig = ica.plot_sources(epochs, title='Components')
    for key in keys:
        fig._fake_keypress(key)
    # Test a click
    x = fig.mne.traces[0].get_xdata()[0]
    y = fig.mne.traces[0].get_ydata()[0]
    fig._fake_click((x, y), xform='data')
    fig._click_ch_name(ch_index=0, button=1)
    fig._fake_keypress('escape')
Exemple #8
0
def test_plot_instance_components():
    """Test plotting of components as instances of raw and epochs.

    Opens the ICA source browser for Raw and then Epochs, replays a sequence
    of key presses, clicks on a trace and on the y-label area, and closes
    the browser with 'escape'.
    """
    raw = _get_raw()
    picks = _get_picks(raw)
    ica = ICA(noise_cov=read_cov(cov_fname), n_components=2)
    # Fitting here must emit a RuntimeWarning mentioning 'projection'.
    with pytest.warns(RuntimeWarning, match='projection'):
        ica.fit(raw, picks=picks)
    ica.exclude = [0]
    fig = ica.plot_sources(raw, title='Components')
    # Browser shortcuts; the toggles (d, z, s) appear twice, presumably so
    # each mode is switched on and back off -- TODO confirm.
    keys = ('home', 'home', 'end', 'down', 'up', 'right', 'left', '-', '+',
            '=', 'd', 'd', 'pageup', 'pagedown', 'z', 'z', 's', 's', 'f11',
            'b')
    for key in keys:
        fig.canvas.key_press_event(key)
    ax = fig.mne.ax_main
    line = ax.lines[0]
    # Click on the first sample of the first trace (data coordinates).
    _fake_click(fig, ax, [line.get_xdata()[0], line.get_ydata()[0]], 'data')
    _fake_click(fig, ax, [-0.1, 0.9])  # click on y-label
    fig.canvas.key_press_event('escape')
    plt.close('all')
    # Same interaction sequence on Epochs.
    epochs = _get_epochs()
    fig = ica.plot_sources(epochs, title='Components')
    for key in keys:
        fig.canvas.key_press_event(key)
    # Test a click
    ax = fig.get_axes()[0]
    line = ax.lines[0]
    _fake_click(fig, ax, [line.get_xdata()[0], line.get_ydata()[0]], 'data')
    _fake_click(fig, ax, [-0.1, 0.9])  # click on y-label
    fig.canvas.key_press_event('escape')
Exemple #9
0
def test_plot_ica_sources():
    """Test plotting of ICA panel.

    Exercises ``ICA.plot_sources`` on Raw, Epochs and Evoked data, including
    mouse clicks, the ``exclude`` argument, component labels, and the errors
    raised for mismatched bad channels and an invalid instance type.
    """
    import matplotlib.pyplot as plt
    raw = io.read_raw_fif(raw_fname,
                          preload=False).crop(0, 1, copy=False).load_data()
    picks = _get_picks(raw)
    epochs = _get_epochs()
    raw.pick_channels([raw.ch_names[k] for k in picks])
    ica_picks = pick_types(raw.info, meg=True, eeg=False, stim=False,
                           ecg=False, eog=False, exclude='bads')
    ica = ICA(n_components=2, max_pca_components=3, n_pca_components=3)
    ica.fit(raw, picks=ica_picks)
    ica.exclude = [1]
    fig = ica.plot_sources(raw)
    fig.canvas.key_press_event('escape')
    # Sadly close_event isn't called on Agg backend and the test always passes.
    assert_array_equal(ica.exclude, [1])

    # After marking a channel bad, plotting Raw/Epochs must raise.
    raw.info['bads'] = ['MEG 0113']
    assert_raises(RuntimeError, ica.plot_sources, inst=raw)
    ica.plot_sources(epochs)
    epochs.info['bads'] = ['MEG 0113']
    assert_raises(RuntimeError, ica.plot_sources, inst=epochs)
    epochs.info['bads'] = []
    with warnings.catch_warnings(record=True):  # no labeled objects mpl
        ica.plot_sources(epochs.average())
        evoked = epochs.average()
        fig = ica.plot_sources(evoked)
        # Test a click
        ax = fig.get_axes()[0]
        line = ax.lines[0]
        _fake_click(fig, ax,
                    [line.get_xdata()[0], line.get_ydata()[0]], 'data')
        _fake_click(fig, ax,
                    [ax.get_xlim()[0], ax.get_ylim()[1]], 'data')
        # plot with ICA component 0 excluded (via argument, then attribute)
        ica.plot_sources(evoked, exclude=[0])
        ica.exclude = [0]
        ica.plot_sources(evoked)  # does the same thing
        ica.labels_ = dict(eog=[0])
        ica.labels_['eog/0/crazy-channel'] = [0]
        ica.plot_sources(evoked)  # now with labels
    # A string is not a valid instance to plot.
    assert_raises(ValueError, ica.plot_sources, 'meeow')
    plt.close('all')
Exemple #10
0
def run_ica(method, fit_params=None):
    """Fit ICA on the module-level ``raw`` recording and show diagnostics.

    Steps:
    1. Fit a 14-component ICA on a 1 Hz high-passed copy of ``raw``.
    2. Score components against two frontal channels used as EOG proxies
       (``'F4'`` then ``'AF4'``) and plot those scores.
    3. Plot properties, sources and an overlay for the ``'AF4'`` matches.
    4. Remove the manually chosen components 1 and 2 from a copy of
       ``raw`` and plot both the original and the cleaned data.

    ``raw``, ``time`` and ``plt`` come from module scope. Returns ``None``.

    Parameters
    ----------
    method : str
        ICA algorithm, forwarded to ``ICA(method=...)``.
    fit_params : dict | None
        Extra parameters forwarded to the ICA algorithm.
    """
    ica = ICA(n_components=14, method=method, fit_params=fit_params,
              random_state=95)
    highpassed = raw.copy()
    highpassed.load_data().filter(l_freq=1., h_freq=None)

    started = time()
    ica.fit(highpassed)
    elapsed = time() - started
    # Only needed by the (currently disabled) plot_components call below.
    title = 'ICA decomposition using %s (took %.1fs)' % (method, elapsed)

    ica.plot_sources(raw)
    # ica.plot_components(title=title)
    ica.exclude = []

    # Score the components against each EOG proxy channel in turn; the
    # indices from the last one ('AF4') are what the plots below use.
    for proxy_channel in ('F4', 'AF4'):
        eog_indices, eog_scores = ica.find_bads_eog(raw, ch_name=proxy_channel)
        ica.exclude = eog_indices
        ica.plot_scores(eog_scores)

    ica.plot_properties(raw, picks=eog_indices)
    # Sources with the EOG matches highlighted, then a before/after overlay.
    ica.plot_sources(raw)
    # ica.plot_sources(eog_evoked)
    ica.plot_overlay(raw)
    # Manual choice made after inspecting the plots above; it overrides the
    # automatic selection.
    ica.exclude = [1, 2]

    # ica.apply() works in place, so clean a copy and keep ``raw`` intact.
    print('copiando el original')
    reconst_raw = raw.copy()
    print('aplicando ICA')
    ica.apply(reconst_raw)

    raw.plot()
    print('muestra reconstruida')
    reconst_raw.plot()
    plt.show()
Exemple #11
0
def run_events(subject_id):
    """Load one subject's EEG, annotate blinks, and fit an ICA for inspection.

    Steps:
    1. Read ``sub_XXX_raw.fif`` and ``events_XXX-eve.fif`` from
       ``<data_path>/EEG_Process``.
    2. Shift event onsets by 34.5 ms.
    3. Concatenate the runs.
    4. Annotate detected blinks as ``'bad blink'``.
    5. Fit a 15-component ICA on the EEG/EOG channels and show diagnostic
       plots, including the effect of excluding component 6.

    Parameters
    ----------
    subject_id : int
        Subject number, formatted as ``sub_%03d`` in file names.
    """
    subject = "sub_%03d" % subject_id
    print("processing subject: %s" % subject)
    # Inputs are read from <data_path>/EEG_Process (create it beforehand).
    in_path = op.join(data_path, "EEG_Process")
    raw_list = list()
    events_list = list()

    for run in range(1, 2):
        # NOTE(review): ``fname`` does not depend on ``run``; widening the
        # range would load the same file several times.
        fname = op.join(in_path, 'sub_%03d_raw.fif' % (subject_id, ))
        raw = mne.io.read_raw_fif(fname, preload=True)
        print("  S %s - R %s" % (subject, run))

        # Import events and shift them by a fixed 34.5 ms delay.
        delay = int(round(0.0345 * raw.info['sfreq']))
        events = mne.read_events(
            op.join(in_path, 'events_%03d-eve.fif' % (subject_id, )))
        events[:, 0] = events[:, 0] + delay
        events_list.append(events)
        raw_list.append(raw)

    # Concatenate once, after all runs are collected. concatenate_raws
    # appends to raw_list[0] in place, so calling it inside the loop would
    # duplicate earlier runs on later iterations.
    raw, events = mne.concatenate_raws(raw_list, events_list=events_list)

    ### Some visualizations of the blinks in the raw data file ###
    eog_events = mne.preprocessing.find_eog_events(raw)
    # 0.5 s window centred on each detected blink peak.
    onsets = eog_events[:, 0] / raw.info['sfreq'] - 0.25
    durations = [0.5] * len(eog_events)
    descriptions = ['bad blink'] * len(eog_events)
    blink_annot = mne.Annotations(onsets,
                                  durations,
                                  descriptions,
                                  orig_time=raw.info['meas_date'])
    raw.set_annotations(blink_annot)
    eeg_picks = mne.pick_types(raw.info, eeg=True)
    raw.plot(events=eog_events, order=eeg_picks)
    ### CONCLUSION: NOT THE BEST ALGORITHM

    ##### ICA #####
    ica = ICA(random_state=97, n_components=15)
    picks = mne.pick_types(raw.info,
                           eeg=True,
                           eog=True,
                           stim=False,
                           exclude='bads')
    # The blink annotations above start with 'bad', and ICA.fit skips such
    # segments by default. Keep them so ICA actually sees the blinks.
    ica.fit(raw, picks=picks, reject_by_annotation=False)
    ica.plot_sources(raw)
    ica.plot_components()
    ica.plot_overlay(raw, exclude=[6], picks='eeg')
    # Visualize the difference after removing component 6 from a copy.
    raw2 = raw.copy()
    ica.exclude = [6]
    ica.apply(raw2)
    raw2.plot()
    ica.plot_properties(raw, picks=[6])
Exemple #12
0
def test_plot_ica_sources():
    """Test plotting of ICA panel.

    Exercises ``ICA.plot_sources`` on Epochs and Evoked data, including
    mouse clicks, the ``exclude`` argument, component labels, and the errors
    raised for mismatched bad channels and an invalid instance type.
    """
    import matplotlib.pyplot as plt
    raw = io.Raw(raw_fname, preload=False)
    raw.crop(0, 1, copy=False)
    raw.load_data()
    picks = _get_picks(raw)
    epochs = _get_epochs()
    raw.pick_channels([raw.ch_names[k] for k in picks])
    ica_picks = pick_types(raw.info,
                           meg=True,
                           eeg=False,
                           stim=False,
                           ecg=False,
                           eog=False,
                           exclude='bads')
    ica = ICA(n_components=2, max_pca_components=3, n_pca_components=3)
    ica.fit(raw, picks=ica_picks)
    # After marking a channel bad, plotting Raw/Epochs must raise.
    raw.info['bads'] = ['MEG 0113']
    assert_raises(RuntimeError, ica.plot_sources, inst=raw)
    ica.plot_sources(epochs)
    epochs.info['bads'] = ['MEG 0113']
    assert_raises(RuntimeError, ica.plot_sources, inst=epochs)
    epochs.info['bads'] = []
    with warnings.catch_warnings(record=True):  # no labeled objects mpl
        ica.plot_sources(epochs.average())
        evoked = epochs.average()
        fig = ica.plot_sources(evoked)
        # Test a click
        ax = fig.get_axes()[0]
        line = ax.lines[0]
        _fake_click(
            fig, ax,
            [line.get_xdata()[0], line.get_ydata()[0]], 'data')
        _fake_click(fig, ax, [ax.get_xlim()[0], ax.get_ylim()[1]], 'data')
        # plot with ICA component 0 excluded (via argument, then attribute)
        ica.plot_sources(evoked, exclude=[0])
        ica.exclude = [0]
        ica.plot_sources(evoked)  # does the same thing
        ica.labels_ = dict(eog=[0])
        ica.labels_['eog/0/crazy-channel'] = [0]
        ica.plot_sources(evoked)  # now with labels
    # A string is not a valid instance to plot.
    assert_raises(ValueError, ica.plot_sources, 'meeow')
    plt.close('all')
Exemple #13
0
    def drop_artefacts(self, n_components, raw, a, b):
        """Remove two hand-picked ICA components from a copy of ``raw``.

        Despite the method name, no time windows are dropped. An ICA is
        fitted on ``raw`` and its sources and topographies are plotted for
        inspection. Components ``a`` and ``b`` (e.g. blink components) are
        then projected out of a copy of the data.

        Parameters
        ----------
        n_components : int
            Forwarded unchanged to ``ICA(n_components=...)``.
        raw : Raw
            Recording to clean. It is loaded into memory in place; its
            data are otherwise left unmodified.
        a, b : int
            Indices of the ICA components to exclude.

        Returns
        -------
        reconst_raw : Raw
            Copy of ``raw`` with components ``a`` and ``b`` removed.
        """
        ica = ICA(n_components=n_components)
        ica.fit(raw)
        raw.load_data()
        ica.plot_sources(raw)
        ica.plot_components(layout=read_layout('eeg1005'))

        ica.exclude = [a, b]
        # ica.apply() works in place, so clean a copy and keep ``raw`` intact.
        reconst_raw = raw.copy()
        ica.apply(reconst_raw)

        return reconst_raw
Exemple #14
0
def remove_eog_ica(raw, n_components, ch_name, threshold):
    """Remove EOG artifacts by ICA.

    Fits an ICA, then looks for EOG-related components with
    ``ICA.find_bads_eog``. While none is found, the detection threshold is
    lowered in steps of 0.3, as long as it stays above 1. The detected
    components are plotted and removed from ``raw``.

    Parameters
    ----------
    raw : instance of Raw
        The raw data. Cleaned in place by ``ica.apply``.
    n_components : int
        Number of principal components for ICA.
    ch_name : str
        The name of the channel to use for EOG peak detection.
    threshold : int | float
        Initial value above which a component score is classified as an
        outlier. It is decreased by 0.3 per retry, so it becomes a float.

    Returns
    -------
    raw : instance of Raw
        The same object that was passed in, now cleaned.

    Raises
    ------
    RuntimeError
        If no EOG component is found for any threshold tried. This includes
        ``threshold <= 1``, where no search is attempted at all.
    """
    ica = ICA(n_components=n_components, max_iter='auto')
    ica.fit(raw, verbose=0)
    # Initialise up front: with threshold <= 1 the loop never runs, and the
    # check below must then raise RuntimeError rather than UnboundLocalError.
    eog_inds = []
    while threshold > 1:
        eog_inds, _ = ica.find_bads_eog(raw,
                                        ch_name=ch_name,
                                        threshold=threshold,
                                        verbose=0)
        if eog_inds:
            break
        threshold -= 0.3

    if not eog_inds:
        raise RuntimeError('Didn\'t find a EOG component.')

    ica.plot_properties(raw, eog_inds)
    ica.exclude = eog_inds
    ica.apply(raw, verbose=0)
    return raw
def preprocess(eeg_rawdata):
    """
    Band-pass, downsample and EOG-clean a 62-channel EEG recording.

    The steps are:
    1. Trim the data to a multiple of 5 samples.
    2. Wrap it in an MNE ``RawArray`` (62 EEG channels at 1000 Hz).
    3. Band-pass filter 1-75 Hz and resample to 200 Hz.
    4. Fit a 5-component ICA and remove the single component whose EOG
       score against channel ``'FP1'`` has the largest magnitude.

    :param eeg_rawdata: array-like with the shape of (n_channels, n_samples);
        n_channels must be 62
    :return: filtered EEG data as a numpy array of shape (62, n_out_samples)
    """
    # Convert first so that nested lists are accepted too (``.shape`` only
    # exists on arrays).
    eeg_rawdata = np.array(eeg_rawdata)
    assert eeg_rawdata.shape[0] == 62
    # Drop trailing samples so the length is a multiple of 5, the ratio of
    # the 1000 Hz input rate to the 200 Hz output rate.
    n_samples = eeg_rawdata.shape[1]
    eeg_rawdata = eeg_rawdata[:, :n_samples - n_samples % 5]

    ch_names = [
        'FP1', 'FPZ', 'FP2', 'AF3', 'AF4', 'F7', 'F5', 'F3', 'F1', 'FZ', 'F2',
        'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4',
        'FC6', 'FT8', 'T7', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'T8',
        'TP7', 'CP5', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'CP6', 'TP8', 'P7',
        'P5', 'P3', 'P1', 'PZ', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO5', 'PO3',
        'POZ', 'PO4', 'PO6', 'PO8', 'CB1', 'O1', 'OZ', 'O2', 'CB2'
    ]
    info = mne.create_info(
        # channel names
        ch_names=ch_names,
        # channel types
        ch_types=['eeg'] * 62,
        # sampling frequency (Hz)
        sfreq=1000)
    raw_data = mne.io.RawArray(eeg_rawdata, info)
    raw_data.load_data().filter(l_freq=1., h_freq=75)
    raw_data.resample(200)
    ica = ICA(n_components=5, random_state=97)
    ica.fit(raw_data)
    # There is no dedicated EOG channel (all types are 'eeg'), so the
    # frontal 'FP1' electrode serves as the EOG reference.
    _, eog_scores = ica.find_bads_eog(raw_data, ch_name='FP1')
    # Always remove exactly one component: the one with the largest |score|,
    # whether or not it passed find_bads_eog's own threshold.
    ica.exclude = [int(np.argmax(np.abs(eog_scores)))]
    ica.apply(raw_data)
    return raw_data.get_data()
def run_ica(epochs: Epochs, fit_params: dict = None) -> ICA:
    """
    Runs ICA decomposition on Epochs instance.

    If there are no EOG channels, it tries to use 'Fp1' and 'Fp2' as EOG
    channels. If those two are not both present, it uses the first two
    channels instead: those of the montage when one is set, otherwise those
    of the data. The chosen channels are then used to identify EOG
    components with mne.preprocessing.ica.find_bads_eog().
    Channel types are only changed on an internal copy; ``epochs`` itself
    is left untouched.

    Parameters
    ----------
    epochs: the instance to be used for ICA decomposition
    fit_params: parameters to be passed to ICA fit (e.g. orthogonal picard, extended infomax)

    Returns
    -------
    ICA instance with ``exclude`` set to the detected EOG components
    """
    ica = ICA(
        n_components=settings["ica"]["n_components"],
        random_state=42,
        method=settings["ica"]["method"],
        fit_params=fit_params,
    )
    ica_epochs = epochs.copy()
    ica.fit(ica_epochs, decim=settings["ica"]["decim"])

    if "eog" not in epochs.get_channel_types():
        montage = epochs.get_montage()
        # get_montage() returns None when no montage is set; fall back to
        # the data's own channel names in that case.
        candidates = (montage.ch_names if montage is not None
                      else epochs.ch_names)
        # Both names must be present: ``"Fp1" and "Fp2" in x`` would only
        # test "Fp2", because ``in`` binds tighter than ``and``.
        if all(ch in candidates for ch in ("Fp1", "Fp2")):
            eog_channels = ["Fp1", "Fp2"]
        else:
            eog_channels = list(candidates[:2])
        logger.info("EOG channels are not found. Attempting to use "
                    f'{",".join(eog_channels)} channels as EOG channels.')
        ica_epochs.set_channel_types({ch: "eog" for ch in eog_channels})

    eog_indices, _ = ica.find_bads_eog(ica_epochs)
    ica.exclude = eog_indices

    return ica
Exemple #17
0
def test_plot_ica_sources():
    """Test plotting of ICA panel.

    Exercises ``ICA.plot_sources`` on Epochs and Evoked data, including
    mouse clicks and the ``exclude`` argument, and the errors raised for
    mismatched bad channels and an invalid instance type.
    """
    import matplotlib.pyplot as plt
    raw = io.Raw(raw_fname, preload=False)
    raw.crop(0, 1, copy=False)
    raw.preload_data()
    picks = _get_picks(raw)
    epochs = _get_epochs()
    raw.pick_channels([raw.ch_names[k] for k in picks])
    ica_picks = pick_types(raw.info, meg=True, eeg=False, stim=False,
                           ecg=False, eog=False, exclude='bads')
    ica = ICA(n_components=2, max_pca_components=3, n_pca_components=3)
    ica.fit(raw, picks=ica_picks)
    # After marking a channel bad, plotting Raw/Epochs must raise.
    raw.info['bads'] = ['MEG 0113']
    assert_raises(RuntimeError, ica.plot_sources, inst=raw)
    ica.plot_sources(epochs)
    epochs.info['bads'] = ['MEG 0113']
    assert_raises(RuntimeError, ica.plot_sources, inst=epochs)
    epochs.info['bads'] = []
    with warnings.catch_warnings(record=True):  # no labeled objects mpl
        ica.plot_sources(epochs.average())
        evoked = epochs.average()
        fig = ica.plot_sources(evoked)
        # Test a click
        ax = fig.get_axes()[0]
        line = ax.lines[0]
        _fake_click(fig, ax,
                    [line.get_xdata()[0], line.get_ydata()[0]], 'data')
        _fake_click(fig, ax,
                    [ax.get_xlim()[0], ax.get_ylim()[1]], 'data')
        # plot with ICA component 0 excluded (via argument, then attribute)
        ica.plot_sources(evoked, exclude=[0])
        ica.exclude = [0]
        ica.plot_sources(evoked)  # does the same thing
    # A string is not a valid instance to plot.
    assert_raises(ValueError, ica.plot_sources, 'meeow')
    plt.close('all')
Exemple #18
0
def test_ica_additional(method):
    """Test additional ICA functionality.

    For the given ICA ``method`` this exercises:

    - fitting on epochs and raw data, with and without a noise covariance;
    - ``find_bads_ecg`` / ``find_bads_eog`` and ``corrmap``;
    - saving and reading, including the bad-filename warnings;
    - decimation, explained variance and component sorting;
    - source scoring and ``detect_artifacts``;
    - filtering of source time courses and fiff export of sources.

    Every file written by this test lives inside ``tempdir``. That directory
    is cleaned up together with the test object, so runs cannot leave files
    behind or collide with each other.
    """
    _skip_check_picard(method)

    tempdir = _TempDir()
    stop2 = 500
    raw = read_raw_fif(raw_fname).crop(1.5, stop).load_data()
    raw.del_proj()  # avoid warnings
    raw.set_annotations(Annotations([0.5], [0.5], ['BAD']))
    # XXX This breaks the tests :(
    # raw.info['bads'] = [raw.ch_names[1]]
    test_cov = read_cov(test_cov_name)
    events = read_events(event_name)
    picks = pick_types(raw.info,
                       meg=True,
                       stim=False,
                       ecg=False,
                       eog=False,
                       exclude='bads')[1::2]
    epochs = Epochs(raw,
                    events,
                    None,
                    tmin,
                    tmax,
                    picks=picks,
                    baseline=(None, 0),
                    preload=True,
                    proj=False)
    epochs.decimate(3, verbose='error')
    assert len(epochs) == 4

    # test if n_components=None works
    ica = ICA(n_components=None,
              max_pca_components=None,
              n_pca_components=None,
              random_state=0,
              method=method,
              max_iter=1)
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(epochs)
    # for testing eog functionality
    picks2 = np.concatenate([picks, pick_types(raw.info, False, eog=True)])
    epochs_eog = Epochs(raw,
                        events[:4],
                        event_id,
                        tmin,
                        tmax,
                        picks=picks2,
                        baseline=(None, 0),
                        preload=True)
    del picks2

    test_cov2 = test_cov.copy()
    ica = ICA(noise_cov=test_cov2,
              n_components=3,
              max_pca_components=4,
              n_pca_components=4,
              method=method)
    assert (ica.info is None)
    with pytest.warns(RuntimeWarning, match='normalize_proj'):
        ica.fit(raw, picks[:5])
    assert (isinstance(ica.info, Info))
    assert (ica.n_components_ < 5)

    ica = ICA(n_components=3,
              max_pca_components=4,
              method=method,
              n_pca_components=4,
              random_state=0)
    # saving an unfitted ICA must fail
    pytest.raises(RuntimeError, ica.save, '')

    ica.fit(raw, picks=[1, 2, 3, 4, 5], start=start, stop=stop2)

    # check passing a ch_name to find_bads_ecg
    with pytest.warns(RuntimeWarning, match='longer'):
        _, scores_1 = ica.find_bads_ecg(raw)
        _, scores_2 = ica.find_bads_ecg(raw, raw.ch_names[1])
    assert scores_1[0] != scores_2[0]

    # test corrmap
    ica2 = ica.copy()
    ica3 = ica.copy()
    corrmap([ica, ica2], (0, 0),
            threshold='auto',
            label='blinks',
            plot=True,
            ch_type="mag")
    corrmap([ica, ica2], (0, 0), threshold=2, plot=False, show=False)
    assert (ica.labels_["blinks"] == ica2.labels_["blinks"])
    assert (0 in ica.labels_["blinks"])
    # test retrieval of component maps as arrays
    components = ica.get_components()
    template = components[:, 0]
    EvokedArray(components, ica.info, tmin=0.).plot_topomap([0], time_unit='s')

    corrmap([ica, ica3],
            template,
            threshold='auto',
            label='blinks',
            plot=True,
            ch_type="mag")
    assert (ica2.labels_["blinks"] == ica3.labels_["blinks"])

    plt.close('all')

    # make sure a single threshold in a list works
    corrmap([ica, ica3],
            template,
            threshold=[0.5],
            label='blinks',
            plot=True,
            ch_type="mag")

    ica_different_channels = ICA(n_components=2,
                                 random_state=0).fit(raw, picks=[2, 3, 4, 5])
    pytest.raises(ValueError, corrmap, [ica_different_channels, ica], (0, 0))

    # test warnings on bad filenames (written inside tempdir so the file is
    # cleaned up with it; previously this leaked into tempdir's parent)
    ica_badname = op.join(tempdir, 'test-bad-name.fif.gz')
    with pytest.warns(RuntimeWarning, match='-ica.fif'):
        ica.save(ica_badname)
    with pytest.warns(RuntimeWarning, match='-ica.fif'):
        read_ica(ica_badname)

    # test decim
    ica = ICA(n_components=3,
              max_pca_components=4,
              n_pca_components=4,
              method=method,
              max_iter=1)
    raw_ = raw.copy()
    for _ in range(3):
        raw_.append(raw_)
    n_samples = raw_._data.shape[1]
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(raw, picks=picks[:5], decim=3)
    assert raw_._data.shape[1] == n_samples

    # test expl var
    ica = ICA(n_components=1.0,
              max_pca_components=4,
              n_pca_components=4,
              method=method,
              max_iter=1)
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(raw, picks=None, decim=3)
    assert (ica.n_components_ == 4)
    ica_var = _ica_explained_variance(ica, raw, normalize=True)
    # explained variance must be sorted in non-increasing order
    assert (np.all(ica_var[:-1] >= ica_var[1:]))

    # test ica sorting
    ica.exclude = [0]
    ica.labels_ = dict(blink=[0], think=[1])
    ica_sorted = _sort_components(ica, [3, 2, 1, 0], copy=True)
    assert_equal(ica_sorted.exclude, [3])
    assert_equal(ica_sorted.labels_, dict(blink=[3], think=[2]))

    # epochs extraction from raw fit
    pytest.raises(RuntimeError, ica.get_sources, epochs)
    # test reading and writing (inside tempdir, not its parent)
    test_ica_fname = op.join(tempdir, 'test-ica.fif')
    for cov in (None, test_cov):
        ica = ICA(noise_cov=cov,
                  n_components=2,
                  max_pca_components=4,
                  n_pca_components=4,
                  method=method,
                  max_iter=1)
        with pytest.warns(None):  # ICA does not converge
            ica.fit(raw, picks=picks[:10], start=start, stop=stop2)
        sources = ica.get_sources(epochs).get_data()
        assert (ica.mixing_matrix_.shape == (2, 2))
        assert (ica.unmixing_matrix_.shape == (2, 2))
        assert (ica.pca_components_.shape == (4, 10))
        assert (sources.shape[1] == ica.n_components_)

        for exclude in [[], [0], np.array([1, 2, 3])]:
            ica.exclude = exclude
            ica.labels_ = {'foo': [0]}
            ica.save(test_ica_fname)
            ica_read = read_ica(test_ica_fname)
            assert (list(ica.exclude) == ica_read.exclude)
            assert_equal(ica.labels_, ica_read.labels_)
            ica.apply(raw)
            # passing ``exclude`` to apply must not mutate ica.exclude
            ica.exclude = []
            ica.apply(raw, exclude=[1])
            assert (ica.exclude == [])

            ica.exclude = [0, 1]
            ica.apply(raw, exclude=[1])
            assert (ica.exclude == [0, 1])

            # excluded components are reported as bad source channels
            ica_raw = ica.get_sources(raw)
            assert (ica.exclude == [
                ica_raw.ch_names.index(e) for e in ica_raw.info['bads']
            ])

        # test filtering
        d1 = ica_raw._data[0].copy()
        ica_raw.filter(4, 20, fir_design='firwin2')
        assert_equal(ica_raw.info['lowpass'], 20.)
        assert_equal(ica_raw.info['highpass'], 4.)
        assert ((d1 != ica_raw._data[0]).any())
        d1 = ica_raw._data[0].copy()
        ica_raw.notch_filter([10], trans_bandwidth=10, fir_design='firwin')
        assert ((d1 != ica_raw._data[0]).any())

        ica.n_pca_components = 2
        ica.method = 'fake'
        ica.save(test_ica_fname)
        ica_read = read_ica(test_ica_fname)
        assert (ica.n_pca_components == ica_read.n_pca_components)
        assert_equal(ica.method, ica_read.method)
        assert_equal(ica.labels_, ica_read.labels_)

        # check type consistency
        attrs = ('mixing_matrix_ unmixing_matrix_ pca_components_ '
                 'pca_explained_variance_ pre_whitener_')

        def f(x, y):
            return getattr(x, y).dtype

        for attr in attrs.split():
            assert_equal(f(ica_read, attr), f(ica, attr))

        ica.n_pca_components = 4
        ica_read.n_pca_components = 4

        ica.exclude = []
        ica.save(test_ica_fname)
        ica_read = read_ica(test_ica_fname)
        for attr in [
                'mixing_matrix_', 'unmixing_matrix_', 'pca_components_',
                'pca_mean_', 'pca_explained_variance_', 'pre_whitener_'
        ]:
            assert_array_almost_equal(getattr(ica, attr),
                                      getattr(ica_read, attr))

        assert (ica.ch_names == ica_read.ch_names)
        assert (isinstance(ica_read.info, Info))

        sources = ica.get_sources(raw)[:, :][0]
        sources2 = ica_read.get_sources(raw)[:, :][0]
        assert_array_almost_equal(sources, sources2)

        _raw1 = ica.apply(raw, exclude=[1])
        _raw2 = ica_read.apply(raw, exclude=[1])
        assert_array_almost_equal(_raw1[:, :][0], _raw2[:, :][0])

    os.remove(test_ica_fname)
    # check score funcs
    for name, func in get_score_funcs().items():
        if name in score_funcs_unsuited:
            continue
        scores = ica.score_sources(raw,
                                   target='EOG 061',
                                   score_func=func,
                                   start=0,
                                   stop=10)
        assert (ica.n_components_ == len(scores))

    # check univariate stats
    scores = ica.score_sources(raw, start=0, stop=50, score_func=stats.skew)
    # check exception handling
    pytest.raises(ValueError, ica.score_sources, raw, target=np.arange(1))

    params = []
    params += [(None, -1, slice(2), [0, 1])]  # variance, kurtosis params
    params += [(None, 'MEG 1531')]  # ECG / EOG channel params
    for idx, ch_name in product(*params):
        ica.detect_artifacts(raw,
                             start_find=0,
                             stop_find=50,
                             ecg_ch=ch_name,
                             eog_ch=ch_name,
                             skew_criterion=idx,
                             var_criterion=idx,
                             kurt_criterion=idx)

    # Make sure detect_artifacts marks the right components.
    # For int criterion, the doc says "E.g. range(2) would return the two
    # sources with the highest score". Assert that's what it does.
    # Only test for skew, since it's always the same code.
    ica.exclude = []
    ica.detect_artifacts(raw,
                         start_find=0,
                         stop_find=50,
                         ecg_ch=None,
                         eog_ch=None,
                         skew_criterion=0,
                         var_criterion=None,
                         kurt_criterion=None)
    assert np.abs(scores[ica.exclude]) == np.max(np.abs(scores))

    # snapshot the data so we can verify the find_bads_* calls don't mutate it
    evoked = epochs.average()
    evoked_data = evoked.data.copy()
    raw_data = raw[:][0].copy()
    epochs_data = epochs.get_data().copy()

    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_ecg(raw, method='ctps')
    assert_equal(len(scores), ica.n_components_)
    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_ecg(raw, method='correlation')
    assert_equal(len(scores), ica.n_components_)

    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_eog(raw)
    assert_equal(len(scores), ica.n_components_)

    idx, scores = ica.find_bads_ecg(epochs, method='ctps')

    assert_equal(len(scores), ica.n_components_)
    pytest.raises(ValueError,
                  ica.find_bads_ecg,
                  epochs.average(),
                  method='ctps')
    pytest.raises(ValueError, ica.find_bads_ecg, raw, method='crazy-coupling')

    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_eog(raw)
    assert_equal(len(scores), ica.n_components_)

    # retype the channel before 'EOG 061' as EOG (kind 202) so that
    # find_bads_eog sees two EOG channels and returns a list of scores
    raw.info['chs'][raw.ch_names.index('EOG 061') - 1]['kind'] = 202
    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_eog(raw)
    assert (isinstance(scores, list))
    assert_equal(len(scores[0]), ica.n_components_)

    idx, scores = ica.find_bads_eog(evoked, ch_name='MEG 1441')
    assert_equal(len(scores), ica.n_components_)

    idx, scores = ica.find_bads_ecg(evoked, method='correlation')
    assert_equal(len(scores), ica.n_components_)

    assert_array_equal(raw_data, raw[:][0])
    assert_array_equal(epochs_data, epochs.get_data())
    assert_array_equal(evoked_data, evoked.data)

    # check score funcs
    for name, func in get_score_funcs().items():
        if name in score_funcs_unsuited:
            continue
        scores = ica.score_sources(epochs_eog,
                                   target='EOG 061',
                                   score_func=func)
        assert (ica.n_components_ == len(scores))

    # check univariate stats
    scores = ica.score_sources(epochs, score_func=stats.skew)

    # check exception handling
    pytest.raises(ValueError, ica.score_sources, epochs, target=np.arange(1))

    # ecg functionality
    ecg_scores = ica.score_sources(raw,
                                   target='MEG 1531',
                                   score_func='pearsonr')

    with pytest.warns(RuntimeWarning, match='longer'):
        ecg_events = ica_find_ecg_events(raw,
                                         sources[np.abs(ecg_scores).argmax()])
    assert (ecg_events.ndim == 2)

    # eog functionality
    eog_scores = ica.score_sources(raw,
                                   target='EOG 061',
                                   score_func='pearsonr')
    with pytest.warns(RuntimeWarning, match='longer'):
        eog_events = ica_find_eog_events(raw,
                                         sources[np.abs(eog_scores).argmax()])
    assert (eog_events.ndim == 2)

    # Test ica fiff export
    ica_raw = ica.get_sources(raw, start=0, stop=100)
    assert (ica_raw.last_samp - ica_raw.first_samp == 100)
    assert_equal(len(ica_raw._filenames), 1)  # API consistency
    ica_chans = [ch for ch in ica_raw.ch_names if 'ICA' in ch]
    assert (ica.n_components_ == len(ica_chans))
    # write inside tempdir rather than the current working directory
    test_ica_fname = op.join(tempdir, 'test-ica_raw.fif')
    ica.n_components = np.int32(ica.n_components)
    ica_raw.save(test_ica_fname, overwrite=True)
    ica_raw2 = read_raw_fif(test_ica_fname, preload=True)
    assert_allclose(ica_raw._data, ica_raw2._data, rtol=1e-5, atol=1e-4)
    ica_raw2.close()
    os.remove(test_ica_fname)

    # Test ica epochs export
    ica_epochs = ica.get_sources(epochs)
    assert (ica_epochs.events.shape == epochs.events.shape)
    ica_chans = [ch for ch in ica_epochs.ch_names if 'ICA' in ch]
    assert (ica.n_components_ == len(ica_chans))
    assert (ica.n_components_ == ica_epochs.get_data().shape[1])
    assert (ica_epochs._raw is None)
    assert (ica_epochs.preload is True)

    # test float n pca components
    ica.pca_explained_variance_ = np.array([0.2] * 5)
    ica.n_components_ = 0
    for ncomps, expected in [[0.3, 1], [0.9, 4], [1, 1]]:
        ncomps_ = ica._check_n_pca_components(ncomps)
        assert (ncomps_ == expected)

    ica = ICA(method=method)
    with pytest.warns(None):  # sometimes does not converge
        ica.fit(raw, picks=picks[:5])
    with pytest.warns(RuntimeWarning, match='longer'):
        ica.find_bads_ecg(raw)
    ica.find_bads_eog(epochs, ch_name='MEG 0121')
    assert_array_equal(raw_data, raw[:][0])

    # dropping a fitted channel must make find_bads_* refuse the data
    raw.drop_channels(['MEG 0122'])
    pytest.raises(RuntimeError, ica.find_bads_eog, raw)
    with pytest.warns(RuntimeWarning, match='longer'):
        pytest.raises(RuntimeError, ica.find_bads_ecg, raw)
Exemple #19
0
#
#
# Selecting ICA components manually
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Once we're certain which components we want to exclude, we can specify that
# manually by setting the ``ica.exclude`` attribute. Similar to marking bad
# channels, merely setting ``ica.exclude`` doesn't do anything immediately (it
# just adds the excluded ICs to a list that will get used later when it's
# needed). Once the exclusions have been set, ICA methods like
# :meth:`~mne.preprocessing.ICA.plot_overlay` will exclude those component(s)
# even if no ``exclude`` parameter is passed, and the list of excluded
# components will be preserved when using :meth:`mne.preprocessing.ICA.save`
# and :func:`mne.preprocessing.read_ica`.

# Component indices picked from the diagnostic plots shown earlier.
ica.exclude = [0, 1]

###############################################################################
# Now that the exclusions have been set, we can reconstruct the sensor signals
# with artifacts removed using the :meth:`~mne.preprocessing.ICA.apply` method
# (remember, we're applying the ICA solution from the *filtered* data to the
# original *unfiltered* signal). Plotting the original raw data alongside the
# reconstructed data shows that the heartbeat and blink artifacts are repaired.

# ICA.apply() modifies the Raw object it is given in place, so work on a copy
# and keep ``raw`` unchanged for the before/after comparison.
reconst_raw = raw.copy()
ica.apply(reconst_raw)

n_artifact_channels = len(artifact_picks)
raw.plot(order=artifact_picks, n_channels=n_artifact_channels,
         show_scrollbars=False)
reconst_raw.plot(order=artifact_picks, n_channels=len(artifact_picks),
Exemple #20
0
def main():
    """Run the per-subject EEG pipeline and save group-level results.

    For every file in ``DAT_PATH`` whose name contains ``EXT``
    (case-insensitive), this runs the following steps:

    1. Load the EDF recording and fix channel names and types. Set an
       average EEG reference and the ``standard_1020`` montage.
    2. High-pass filter at 1 Hz. Then either fit and save (``RUN_ICA``) or
       load an ICA solution. Exclude the components flagged by
       ``find_bads_eog`` for any channel in ``EOG_CHS``.
    3. Re-code correct trials into new events with
       ``mne.event.define_target_events``. These are ``EV_DICT`` reference
       codes followed by a ``CORR_CODES`` target. Build one label -> code
       mapping per ``EV_DICT`` key.
    4. Fit FOOOF to Welch PSDs (1-50 Hz, 5-125 s) of all channels. Take the
       7-14 Hz peak frequency at channel ``CHL``, falling back to 10 Hz for
       filtering if none is found. Build canonical (8-12 Hz) and
       FOOOF-centred (peak +/- 2 Hz) Hilbert amplitude envelopes.
    5. Epoch the data. Either fit and save (``RUN_AUTOREJECT``) or load an
       autoreject solution, and apply the same interpolation and rejection
       to the filtered copies.
    6. Average contralateral envelopes per load. Fit FOOOF to
       trial-averaged periodograms per load / side / segment (all channels
       or collapsed, depending on ``FIT_ALL_CHANNELS``).

    Group arrays and combined FOOOF results are saved under ``RES_PATH``.
    """

    #################################################
    ## SETUP

    ## Get list of subject files (case-insensitive match on EXT)
    subj_files = [fname for fname in listdir(DAT_PATH)
                  if EXT.lower() in fname.lower()]

    ## Set up FOOOF Objects
    # Initialize FOOOF settings & objects
    fooof_settings = FOOOFSettings(peak_width_limits=PEAK_WIDTH_LIMITS, max_n_peaks=MAX_N_PEAKS,
                                   min_peak_amplitude=MIN_PEAK_AMP, peak_threshold=PEAK_THRESHOLD,
                                   aperiodic_mode=APERIODIC_MODE)
    fm = FOOOF(*fooof_settings, verbose=False)
    fg = FOOOFGroup(*fooof_settings, verbose=False)

    # Save out a settings file
    fg.save('0-FOOOF_Settings', pjoin(RES_PATH, 'FOOOF'), save_settings=True)

    # Set up the nested dictionary to store all the FOOOF results:
    #   fg_dict[load_label][side_label][seg_label] -> list of FOOOF models
    fg_dict = {load_label: {side_label: {seg_label: [] for seg_label in SEG_LABELS}
                            for side_label in SIDE_LABELS}
               for load_label in LOAD_LABELS}

    ## Initialize group level data stores
    n_subjs, n_conds, n_times = len(subj_files), 3, N_TIMES
    group_fooofed_alpha_freqs = np.zeros(shape=[n_subjs])
    # Slots that are not overwritten below keep the 999 filler value
    dropped_components = np.full([n_subjs, 50], 999.)
    dropped_trials = np.full([n_subjs, 1500], 999.)
    canonical_group_avg_dat = np.zeros(shape=[n_subjs, n_conds, n_times])
    fooofed_group_avg_dat = np.zeros(shape=[n_subjs, n_conds, n_times])

    # Set channel types
    ch_types = {'LHor' : 'eog', 'RHor' : 'eog', 'IVer' : 'eog', 'SVer' : 'eog',
                'LMas' : 'misc', 'RMas' : 'misc', 'Nose' : 'misc', 'EXG8' : 'misc'}

    #################################################
    ## RUN ACROSS ALL SUBJECTS

    # Run analysis across each subject
    for s_ind, subj_file in enumerate(subj_files):

        # Get subject label and print status
        subj_label = subj_file.split('.')[0]
        print('\nCURRENTLY RUNNING SUBJECT: ', subj_label, '\n')

        #################################################
        ## LOAD / ORGANIZE / SET-UP DATA

        # Load subject of data, apply fixes for channels, etc
        eeg_dat = mne.io.read_raw_edf(pjoin(DAT_PATH, subj_file),
                                      preload=True, verbose=False)

        # Fix channel name labels: drop the first 2 characters of every
        # channel name except the last one (presumably the status channel)
        eeg_dat.info['ch_names'] = ([chl[2:] for chl in eeg_dat.ch_names[:-1]]
                                    + [eeg_dat.ch_names[-1]])
        for ch_info, ch_name in zip(eeg_dat.info['chs'], eeg_dat.info['ch_names']):
            ch_info['ch_name'] = ch_name

        # Update channel types
        eeg_dat.set_channel_types(ch_types)

        # Set reference - average reference
        eeg_dat = eeg_dat.set_eeg_reference(ref_channels='average',
                                            projection=False, verbose=False)

        # Set channel montage
        chs = mne.channels.read_montage('standard_1020', eeg_dat.ch_names)
        eeg_dat.set_montage(chs)

        # Get event information & check all used event codes
        evs = mne.find_events(eeg_dat, shortest_event=1, verbose=False)

        # Pull out sampling rate
        srate = eeg_dat.info['sfreq']

        #################################################
        ## Pre-Processing: ICA

        # High-pass filter data for running ICA
        eeg_dat.filter(l_freq=1., h_freq=None, fir_design='firwin')

        if RUN_ICA:

            print("\nICA: CALCULATING SOLUTION\n")

            # ICA settings
            method = 'fastica'
            n_components = 0.99
            random_state = 47
            reject = {'eeg': 20e-4}

            # Initialize ICA object
            ica = ICA(n_components=n_components, method=method,
                      random_state=random_state)

            # Fit ICA
            ica.fit(eeg_dat, reject=reject)

            # Save out ICA solution
            ica.save(pjoin(RES_PATH, 'ICA', subj_label + '-ica.fif'))

        # Otherwise: load previously saved ICA to apply
        else:
            print("\nICA: USING PRECOMPUTED\n")
            ica = read_ica(pjoin(RES_PATH, 'ICA', subj_label + '-ica.fif'))

        # Find components to drop, based on correlation with EOG channels
        drop_inds = []
        for chi in EOG_CHS:
            inds, _ = ica.find_bads_eog(eeg_dat, ch_name=chi, threshold=2.5,
                                        l_freq=1, h_freq=10, verbose=False)
            drop_inds.extend(inds)
        drop_inds = list(set(drop_inds))

        # Set which components to drop, and collect record of this
        ica.exclude = drop_inds
        dropped_components[s_ind, 0:len(drop_inds)] = drop_inds

        # Apply ICA to data
        eeg_dat = ica.apply(eeg_dat)

        #################################################
        ## SORT OUT EVENT CODES

        # Extract a list of all the event labels
        all_trials = [it for it2 in EV_DICT.values() for it in it2]

        # Create list of new event codes to be used to label correct trials (300s)
        all_trials_new = [it + 100 for it in all_trials]
        # Collapse the doubled event markers: every odd-indexed code is shifted
        # down by one so it matches the code preceding it
        all_trials_new = [it - 1 if not ind%2 == 0 else it for ind, it in enumerate(all_trials_new)]
        # Get labelled dictionary of new event names. Deduplicate with
        # dict.fromkeys to keep first-appearance (i.e. EV_DICT) order; a set
        # has no ordering guarantee and could pair labels with the wrong codes
        ev_dict2 = dict(zip(EV_DICT.keys(), dict.fromkeys(all_trials_new)))

        # Initialize variables to store new event definitions
        evs2 = np.empty(shape=[0, 3], dtype='int64')
        lags = np.array([])

        # Loop through, creating new events for all correct trials
        t_min, t_max = -0.4, 3.0
        for ref_id, targ_id, new_id in zip(all_trials, CORR_CODES * 6, all_trials_new):

            t_evs, t_lags = mne.event.define_target_events(evs, ref_id, targ_id, srate,
                                                           t_min, t_max, new_id)

            if len(t_evs) > 0:
                evs2 = np.vstack([evs2, t_evs])
                lags = np.concatenate([lags, t_lags])

        #################################################
        ## FOOOF

        # Set channel of interest
        ch_ind = eeg_dat.ch_names.index(CHL)

        # Calculate PSDs over ~ first 2 minutes of data, for specified channel
        fmin, fmax = 1, 50
        tmin, tmax = 5, 125
        psds, freqs = mne.time_frequency.psd_welch(eeg_dat, fmin=fmin, fmax=fmax,
                                                   tmin=tmin, tmax=tmax,
                                                   n_fft=int(2*srate), n_overlap=int(srate),
                                                   n_per_seg=int(2*srate),
                                                   verbose=False)

        # Fit FOOOF across all channels
        fg.fit(freqs, psds, FREQ_RANGE, n_jobs=-1)

        # Save out FOOOF results
        fg.save(subj_label + '_fooof', pjoin(RES_PATH, 'FOOOF'), save_results=True)

        # Extract individualized CF from specified channel, add to group collection
        fm = fg.get_fooof(ch_ind, False)
        fooof_freq, _, _ = get_band_peak(fm.peak_params_, [7, 14])
        group_fooofed_alpha_freqs[s_ind] = fooof_freq

        # If no FOOOF alpha extracted, reset to 10 (the group store keeps NaN)
        if np.isnan(fooof_freq):
            fooof_freq = 10

        #################################################
        ## ALPHA FILTERING

        # CANONICAL: Filter data to canonical alpha band: 8-12 Hz
        alpha_dat = eeg_dat.copy()
        alpha_dat.filter(8, 12, fir_design='firwin', verbose=False)
        alpha_dat.apply_hilbert(envelope=True, verbose=False)

        # FOOOF: Filter data to FOOOF derived alpha band
        fooof_dat = eeg_dat.copy()
        fooof_dat.filter(fooof_freq-2, fooof_freq+2, fir_design='firwin')
        fooof_dat.apply_hilbert(envelope=True)

        #################################################
        ## EPOCH TRIALS

        # Set epoch timings
        tmin, tmax = -0.85, 1.1

        # Epoch trials - raw data for trial rejection
        epochs = mne.Epochs(eeg_dat, evs2, ev_dict2, tmin=tmin, tmax=tmax,
                            baseline=None, preload=True, verbose=False)

        # Epoch trials - filtered version
        epochs_alpha = mne.Epochs(alpha_dat, evs2, ev_dict2, tmin=tmin, tmax=tmax,
                                  baseline=(-0.5, -0.35), preload=True, verbose=False)
        epochs_fooof = mne.Epochs(fooof_dat, evs2, ev_dict2, tmin=tmin, tmax=tmax,
                                  baseline=(-0.5, -0.35), preload=True, verbose=False)

        #################################################
        ## PRE-PROCESSING: AUTO-REJECT
        if RUN_AUTOREJECT:

            print('\nAUTOREJECT: CALCULATING SOLUTION\n')

            # Initialize and run autoreject across epochs
            ar = AutoReject(n_jobs=4, verbose=False)
            ar.fit(epochs)

            # Save out AR solution
            ar.save(pjoin(RES_PATH, 'AR', subj_label + '-ar.hdf5'), overwrite=True)

        # Otherwise: load & apply previously saved AR solution
        else:
            print('\nAUTOREJECT: USING PRECOMPUTED\n')
            ar = read_auto_reject(pjoin(RES_PATH, 'AR', subj_label + '-ar.hdf5'))
            ar.verbose = 'tqdm'

        # Apply autoreject to the original epochs object it was learnt on
        epochs, rej_log = ar.transform(epochs, return_log=True)

        # Apply autoreject to the copies of the data - apply interpolation, then drop same epochs
        _apply_interp(rej_log, epochs_alpha, ar.threshes_, ar.picks_, ar.verbose)
        epochs_alpha.drop(rej_log.bad_epochs)
        _apply_interp(rej_log, epochs_fooof, ar.threshes_, ar.picks_, ar.verbose)
        epochs_fooof.drop(rej_log.bad_epochs)

        # Collect which epochs were dropped
        dropped_trials[s_ind, 0:sum(rej_log.bad_epochs)] = np.where(rej_log.bad_epochs)[0]

        #################################################
        ## SET UP CHANNEL CLUSTERS

        # Set channel clusters - take channels contralateral to stimulus presentation
        #  Note: channels will be used to extract data contralateral to stimulus presentation
        le_chs = ['P3', 'P5', 'P7', 'P9', 'O1', 'PO3', 'PO7']       # Left Side Channels
        le_inds = [epochs.ch_names.index(chn) for chn in le_chs]
        ri_chs = ['P4', 'P6', 'P8', 'P10', 'O2', 'PO4', 'PO8']      # Right Side Channels
        ri_inds = [epochs.ch_names.index(chn) for chn in ri_chs]

        #################################################
        ## TRIAL-RELATED ANALYSIS: CANONICAL vs. FOOOF

        ## Pull out channels of interest for each load level
        #  Channels extracted are those contralateral to stimulus presentation

        # Canonical Data
        lo1_a = np.concatenate([epochs_alpha['LeLo1']._data[:, ri_inds, :],
                                epochs_alpha['RiLo1']._data[:, le_inds, :]], 0)
        lo2_a = np.concatenate([epochs_alpha['LeLo2']._data[:, ri_inds, :],
                                epochs_alpha['RiLo2']._data[:, le_inds, :]], 0)
        lo3_a = np.concatenate([epochs_alpha['LeLo3']._data[:, ri_inds, :],
                                epochs_alpha['RiLo3']._data[:, le_inds, :]], 0)

        # FOOOFed data
        lo1_f = np.concatenate([epochs_fooof['LeLo1']._data[:, ri_inds, :],
                                epochs_fooof['RiLo1']._data[:, le_inds, :]], 0)
        lo2_f = np.concatenate([epochs_fooof['LeLo2']._data[:, ri_inds, :],
                                epochs_fooof['RiLo2']._data[:, le_inds, :]], 0)
        lo3_f = np.concatenate([epochs_fooof['LeLo3']._data[:, ri_inds, :],
                                epochs_fooof['RiLo3']._data[:, le_inds, :]], 0)

        ## Calculate average across trials and channels - add to group data collection

        # Canonical data
        canonical_group_avg_dat[s_ind, 0, :] = np.mean(lo1_a, 1).mean(0)
        canonical_group_avg_dat[s_ind, 1, :] = np.mean(lo2_a, 1).mean(0)
        canonical_group_avg_dat[s_ind, 2, :] = np.mean(lo3_a, 1).mean(0)

        # FOOOFed data
        fooofed_group_avg_dat[s_ind, 0, :] = np.mean(lo1_f, 1).mean(0)
        fooofed_group_avg_dat[s_ind, 1, :] = np.mean(lo2_f, 1).mean(0)
        fooofed_group_avg_dat[s_ind, 2, :] = np.mean(lo3_f, 1).mean(0)

        #################################################
        ## FOOOFING TRIAL AVERAGED DATA

        # FFT length for the trial periodograms; scipy needs an integer here
        # (srate is a float), e.g. for slicing when nfft < segment length
        nfft = int(4 * srate)

        # Loop across trial segments & loads
        for seg_label, seg_time in zip(SEG_LABELS, SEG_TIMES):
            tmin, tmax = seg_time[0], seg_time[1]

            # Time mask for this segment; it is the same for every load
            seg_mask = _time_mask(epochs.times, tmin, tmax, srate)

            # Calculate PSDs across trials, fit FOOOF models to averages
            for le_label, ri_label, load_label in zip(['LeLo1', 'LeLo2', 'LeLo3'],
                                                      ['RiLo1', 'RiLo2', 'RiLo3'],
                                                      LOAD_LABELS):

                ## Calculate trial wise PSDs for left & right side trials
                trial_freqs, le_trial_psds = periodogram(
                    epochs[le_label]._data[:, :, seg_mask],
                    srate, window='hann', nfft=nfft)
                trial_freqs, ri_trial_psds = periodogram(
                    epochs[ri_label]._data[:, :, seg_mask],
                    srate, window='hann', nfft=nfft)

                ## FIT ALL CHANNELS VERSION
                if FIT_ALL_CHANNELS:

                    ## Average spectra across trials within a given load & side
                    le_avg_psd_contra = avg_func(le_trial_psds[:, ri_inds, :], 0)
                    le_avg_psd_ipsi = avg_func(le_trial_psds[:, le_inds, :], 0)
                    ri_avg_psd_contra = avg_func(ri_trial_psds[:, le_inds, :], 0)
                    ri_avg_psd_ipsi = avg_func(ri_trial_psds[:, ri_inds, :], 0)

                    ## Combine spectra across left & right trials for given load
                    ch_psd_contra = np.vstack([le_avg_psd_contra, ri_avg_psd_contra])
                    ch_psd_ipsi = np.vstack([le_avg_psd_ipsi, ri_avg_psd_ipsi])

                    ## Fit FOOOFGroup to all channels, average & collect results
                    fg.fit(trial_freqs, ch_psd_contra, FREQ_RANGE)
                    fm = avg_fg(fg)
                    fg_dict[load_label]['Contra'][seg_label].append(fm.copy())
                    fg.fit(trial_freqs, ch_psd_ipsi, FREQ_RANGE)
                    fm = avg_fg(fg)
                    fg_dict[load_label]['Ipsi'][seg_label].append(fm.copy())

                ## COLLAPSE ACROSS CHANNELS VERSION
                else:

                    ## Average spectra across trials and channels within a given load & side
                    le_avg_psd_contra = avg_func(avg_func(le_trial_psds[:, ri_inds, :], 0), 0)
                    le_avg_psd_ipsi = avg_func(avg_func(le_trial_psds[:, le_inds, :], 0), 0)
                    ri_avg_psd_contra = avg_func(avg_func(ri_trial_psds[:, le_inds, :], 0), 0)
                    ri_avg_psd_ipsi = avg_func(avg_func(ri_trial_psds[:, ri_inds, :], 0), 0)

                    ## Collapse spectra across left & right trials for given load
                    avg_psd_contra = avg_func(np.vstack([le_avg_psd_contra, ri_avg_psd_contra]), 0)
                    avg_psd_ipsi = avg_func(np.vstack([le_avg_psd_ipsi, ri_avg_psd_ipsi]), 0)

                    ## Fit FOOOF, and collect results
                    fm.fit(trial_freqs, avg_psd_contra, FREQ_RANGE)
                    fg_dict[load_label]['Contra'][seg_label].append(fm.copy())
                    fm.fit(trial_freqs, avg_psd_ipsi, FREQ_RANGE)
                    fg_dict[load_label]['Ipsi'][seg_label].append(fm.copy())

    #################################################
    ## SAVE OUT RESULTS

    # Save out group data
    np.save(pjoin(RES_PATH, 'Group', 'alpha_freqs_group'), group_fooofed_alpha_freqs)
    np.save(pjoin(RES_PATH, 'Group', 'canonical_group'), canonical_group_avg_dat)
    np.save(pjoin(RES_PATH, 'Group', 'fooofed_group'), fooofed_group_avg_dat)
    np.save(pjoin(RES_PATH, 'Group', 'dropped_trials'), dropped_trials)
    np.save(pjoin(RES_PATH, 'Group', 'dropped_components'), dropped_components)

    # Save out second round of FOOOFing
    for load_label in LOAD_LABELS:
        for side_label in SIDE_LABELS:
            for seg_label in SEG_LABELS:
                fg = combine_fooofs(fg_dict[load_label][side_label][seg_label])
                fg.save('Group_' + load_label + '_' + side_label + '_' + seg_label,
                        pjoin(RES_PATH, 'FOOOF'), save_results=True)
Exemple #21
0
def test_ica_additional(method):
    """Test additional ICA functionality.

    Exercises many ICA features on a short crop of the test raw file:
    fitting with different n_components / noise_cov settings, corrmap,
    save/read round-trips, decimation, explained variance, component
    sorting, score functions, ``detect_artifacts``, ``find_bads_ecg`` /
    ``find_bads_eog`` and exporting sources to Raw / Epochs.

    Parameters
    ----------
    method : str
        ICA algorithm passed to ``ICA(method=...)``; presumably supplied by a
        pytest parametrization outside this view -- TODO confirm.
    """
    # NOTE(review): presumably skips the test when `method` needs the
    # optional picard package and it is unavailable -- verify helper.
    _skip_check_picard(method)

    tempdir = _TempDir()
    stop2 = 500
    raw = read_raw_fif(raw_fname).crop(1.5, stop).load_data()
    raw.del_proj()  # avoid warnings
    raw.set_annotations(Annotations([0.5], [0.5], ['BAD']))
    # XXX This breaks the tests :(
    # raw.info['bads'] = [raw.ch_names[1]]
    test_cov = read_cov(test_cov_name)
    events = read_events(event_name)
    # every second good MEG channel
    picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
                       eog=False, exclude='bads')[1::2]
    epochs = Epochs(raw, events, None, tmin, tmax, picks=picks,
                    baseline=(None, 0), preload=True, proj=False)
    epochs.decimate(3, verbose='error')
    assert len(epochs) == 4

    # test if n_components=None works
    ica = ICA(n_components=None, max_pca_components=None,
              n_pca_components=None, random_state=0, method=method, max_iter=1)
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(epochs)
    # for testing eog functionality
    picks2 = np.concatenate([picks, pick_types(raw.info, False, eog=True)])
    epochs_eog = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks2,
                        baseline=(None, 0), preload=True)
    del picks2

    # fitting with a noise covariance populates ica.info
    test_cov2 = test_cov.copy()
    ica = ICA(noise_cov=test_cov2, n_components=3, max_pca_components=4,
              n_pca_components=4, method=method)
    assert (ica.info is None)
    with pytest.warns(RuntimeWarning, match='normalize_proj'):
        ica.fit(raw, picks[:5])
    assert (isinstance(ica.info, Info))
    assert (ica.n_components_ < 5)

    ica = ICA(n_components=3, max_pca_components=4, method=method,
              n_pca_components=4, random_state=0)
    # saving an ICA that has not been fitted yet must raise
    pytest.raises(RuntimeError, ica.save, '')

    ica.fit(raw, picks=[1, 2, 3, 4, 5], start=start, stop=stop2)

    # check passing a ch_name to find_bads_ecg
    with pytest.warns(RuntimeWarning, match='longer'):
        _, scores_1 = ica.find_bads_ecg(raw)
        _, scores_2 = ica.find_bads_ecg(raw, raw.ch_names[1])
    assert scores_1[0] != scores_2[0]

    # test corrmap
    # (the (0, 0) template is presumably (ica index, component index))
    ica2 = ica.copy()
    ica3 = ica.copy()
    corrmap([ica, ica2], (0, 0), threshold='auto', label='blinks', plot=True,
            ch_type="mag")
    corrmap([ica, ica2], (0, 0), threshold=2, plot=False, show=False)
    assert (ica.labels_["blinks"] == ica2.labels_["blinks"])
    assert (0 in ica.labels_["blinks"])
    # test retrieval of component maps as arrays
    components = ica.get_components()
    template = components[:, 0]
    EvokedArray(components, ica.info, tmin=0.).plot_topomap([0], time_unit='s')

    # corrmap with an array template must label the same components
    corrmap([ica, ica3], template, threshold='auto', label='blinks', plot=True,
            ch_type="mag")
    assert (ica2.labels_["blinks"] == ica3.labels_["blinks"])

    plt.close('all')

    # corrmap must refuse ICAs fitted on different channel sets
    ica_different_channels = ICA(n_components=2, random_state=0).fit(
        raw, picks=[2, 3, 4, 5])
    pytest.raises(ValueError, corrmap, [ica_different_channels, ica], (0, 0))

    # test warnings on bad filenames
    # NOTE(review): files are written to the *parent* of tempdir, not into
    # tempdir itself -- confirm this is intended.
    ica_badname = op.join(op.dirname(tempdir), 'test-bad-name.fif.gz')
    with pytest.warns(RuntimeWarning, match='-ica.fif'):
        ica.save(ica_badname)
    with pytest.warns(RuntimeWarning, match='-ica.fif'):
        read_ica(ica_badname)

    # test decim
    # NOTE(review): `raw_` is grown by appending, but the fit below uses
    # `raw`, so the final shape assertion on `raw_` is trivially true.
    ica = ICA(n_components=3, max_pca_components=4,
              n_pca_components=4, method=method, max_iter=1)
    raw_ = raw.copy()
    for _ in range(3):
        raw_.append(raw_)
    n_samples = raw_._data.shape[1]
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(raw, picks=picks[:5], decim=3)
    assert raw_._data.shape[1] == n_samples

    # test expl var
    ica = ICA(n_components=1.0, max_pca_components=4,
              n_pca_components=4, method=method, max_iter=1)
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(raw, picks=None, decim=3)
    assert (ica.n_components_ == 4)
    ica_var = _ica_explained_variance(ica, raw, normalize=True)
    # explained variance must be non-increasing across components
    assert (np.all(ica_var[:-1] >= ica_var[1:]))

    # test ica sorting
    # reversing the order maps component 0 -> 3 and 1 -> 2 in exclude/labels_
    ica.exclude = [0]
    ica.labels_ = dict(blink=[0], think=[1])
    ica_sorted = _sort_components(ica, [3, 2, 1, 0], copy=True)
    assert_equal(ica_sorted.exclude, [3])
    assert_equal(ica_sorted.labels_, dict(blink=[3], think=[2]))

    # epochs extraction from raw fit
    pytest.raises(RuntimeError, ica.get_sources, epochs)
    # test reading and writing
    test_ica_fname = op.join(op.dirname(tempdir), 'test-ica.fif')
    for cov in (None, test_cov):
        ica = ICA(noise_cov=cov, n_components=2, max_pca_components=4,
                  n_pca_components=4, method=method, max_iter=1)
        # NOTE(review): pytest.warns(None) is deprecated in pytest >= 7
        with pytest.warns(None):  # ICA does not converge
            ica.fit(raw, picks=picks[:10], start=start, stop=stop2)
        sources = ica.get_sources(epochs).get_data()
        assert (ica.mixing_matrix_.shape == (2, 2))
        assert (ica.unmixing_matrix_.shape == (2, 2))
        assert (ica.pca_components_.shape == (4, 10))
        assert (sources.shape[1] == ica.n_components_)

        for exclude in [[], [0], np.array([1, 2, 3])]:
            ica.exclude = exclude
            ica.labels_ = {'foo': [0]}
            ica.save(test_ica_fname)
            ica_read = read_ica(test_ica_fname)
            assert (list(ica.exclude) == ica_read.exclude)
            assert_equal(ica.labels_, ica_read.labels_)
            ica.apply(raw)
            # passing exclude= to apply must not modify ica.exclude
            ica.exclude = []
            ica.apply(raw, exclude=[1])
            assert (ica.exclude == [])

            ica.exclude = [0, 1]
            ica.apply(raw, exclude=[1])
            assert (ica.exclude == [0, 1])

            # excluded components show up as bad channels of the source Raw
            ica_raw = ica.get_sources(raw)
            assert (ica.exclude == [ica_raw.ch_names.index(e) for e in
                                    ica_raw.info['bads']])

        # test filtering
        # (ica_raw is the source Raw from the last iteration above)
        d1 = ica_raw._data[0].copy()
        ica_raw.filter(4, 20, fir_design='firwin2')
        assert_equal(ica_raw.info['lowpass'], 20.)
        assert_equal(ica_raw.info['highpass'], 4.)
        assert ((d1 != ica_raw._data[0]).any())
        d1 = ica_raw._data[0].copy()
        ica_raw.notch_filter([10], trans_bandwidth=10, fir_design='firwin')
        assert ((d1 != ica_raw._data[0]).any())

        # round-trip non-default n_pca_components and method through save/read
        ica.n_pca_components = 2
        ica.method = 'fake'
        ica.save(test_ica_fname)
        ica_read = read_ica(test_ica_fname)
        assert (ica.n_pca_components == ica_read.n_pca_components)
        assert_equal(ica.method, ica_read.method)
        assert_equal(ica.labels_, ica_read.labels_)

        # check type consistency
        attrs = ('mixing_matrix_ unmixing_matrix_ pca_components_ '
                 'pca_explained_variance_ pre_whitener_')

        def f(x, y):
            return getattr(x, y).dtype

        for attr in attrs.split():
            assert_equal(f(ica_read, attr), f(ica, attr))

        ica.n_pca_components = 4
        ica_read.n_pca_components = 4

        ica.exclude = []
        ica.save(test_ica_fname)
        ica_read = read_ica(test_ica_fname)
        for attr in ['mixing_matrix_', 'unmixing_matrix_', 'pca_components_',
                     'pca_mean_', 'pca_explained_variance_',
                     'pre_whitener_']:
            assert_array_almost_equal(getattr(ica, attr),
                                      getattr(ica_read, attr))

        assert (ica.ch_names == ica_read.ch_names)
        assert (isinstance(ica_read.info, Info))

        # `sources` (data array of ica.get_sources(raw)) is reused further
        # below by the ECG/EOG event-finding checks
        sources = ica.get_sources(raw)[:, :][0]
        sources2 = ica_read.get_sources(raw)[:, :][0]
        assert_array_almost_equal(sources, sources2)

        _raw1 = ica.apply(raw, exclude=[1])
        _raw2 = ica_read.apply(raw, exclude=[1])
        assert_array_almost_equal(_raw1[:, :][0], _raw2[:, :][0])

    os.remove(test_ica_fname)
    # check score funcs
    for name, func in get_score_funcs().items():
        if name in score_funcs_unsuited:
            continue
        scores = ica.score_sources(raw, target='EOG 061', score_func=func,
                                   start=0, stop=10)
        assert (ica.n_components_ == len(scores))

    # check univariate stats
    scores = ica.score_sources(raw, start=0, stop=50, score_func=stats.skew)
    # check exception handling
    pytest.raises(ValueError, ica.score_sources, raw,
                  target=np.arange(1))

    # every combination of criterion type and ECG/EOG channel choice
    params = []
    params += [(None, -1, slice(2), [0, 1])]  # variance, kurtosis params
    params += [(None, 'MEG 1531')]  # ECG / EOG channel params
    for idx, ch_name in product(*params):
        ica.detect_artifacts(raw, start_find=0, stop_find=50, ecg_ch=ch_name,
                             eog_ch=ch_name, skew_criterion=idx,
                             var_criterion=idx, kurt_criterion=idx)

    # Make sure detect_artifacts marks the right components.
    # For int criterion, the doc says "E.g. range(2) would return the two
    # sources with the highest score". Assert that's what it does.
    # Only test for skew, since it's always the same code.
    # NOTE(review): the comparison below only yields a single truth value
    # when exactly one component ends up in ica.exclude.
    ica.exclude = []
    ica.detect_artifacts(raw, start_find=0, stop_find=50, ecg_ch=None,
                         eog_ch=None, skew_criterion=0,
                         var_criterion=None, kurt_criterion=None)
    assert np.abs(scores[ica.exclude]) == np.max(np.abs(scores))

    # snapshot the data so we can check below that find_bads_* leave the
    # Raw / Epochs / Evoked inputs untouched
    evoked = epochs.average()
    evoked_data = evoked.data.copy()
    raw_data = raw[:][0].copy()
    epochs_data = epochs.get_data().copy()

    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_ecg(raw, method='ctps')
    assert_equal(len(scores), ica.n_components_)
    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_ecg(raw, method='correlation')
    assert_equal(len(scores), ica.n_components_)

    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_eog(raw)
    assert_equal(len(scores), ica.n_components_)

    idx, scores = ica.find_bads_ecg(epochs, method='ctps')

    assert_equal(len(scores), ica.n_components_)
    pytest.raises(ValueError, ica.find_bads_ecg, epochs.average(),
                  method='ctps')
    pytest.raises(ValueError, ica.find_bads_ecg, raw,
                  method='crazy-coupling')

    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_eog(raw)
    assert_equal(len(scores), ica.n_components_)

    # retype the channel before 'EOG 061' (202 is presumably FIFFV_EOG_CH) so
    # there are two EOG channels; find_bads_eog then returns a list of scores
    raw.info['chs'][raw.ch_names.index('EOG 061') - 1]['kind'] = 202
    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_eog(raw)
    assert (isinstance(scores, list))
    assert_equal(len(scores[0]), ica.n_components_)

    idx, scores = ica.find_bads_eog(evoked, ch_name='MEG 1441')
    assert_equal(len(scores), ica.n_components_)

    idx, scores = ica.find_bads_ecg(evoked, method='correlation')
    assert_equal(len(scores), ica.n_components_)

    assert_array_equal(raw_data, raw[:][0])
    assert_array_equal(epochs_data, epochs.get_data())
    assert_array_equal(evoked_data, evoked.data)

    # check score funcs
    for name, func in get_score_funcs().items():
        if name in score_funcs_unsuited:
            continue
        scores = ica.score_sources(epochs_eog, target='EOG 061',
                                   score_func=func)
        assert (ica.n_components_ == len(scores))

    # check univariate stats
    scores = ica.score_sources(epochs, score_func=stats.skew)

    # check exception handling
    pytest.raises(ValueError, ica.score_sources, epochs,
                  target=np.arange(1))

    # ecg functionality
    ecg_scores = ica.score_sources(raw, target='MEG 1531',
                                   score_func='pearsonr')

    with pytest.warns(RuntimeWarning, match='longer'):
        ecg_events = ica_find_ecg_events(
            raw, sources[np.abs(ecg_scores).argmax()])
    assert (ecg_events.ndim == 2)

    # eog functionality
    eog_scores = ica.score_sources(raw, target='EOG 061',
                                   score_func='pearsonr')
    with pytest.warns(RuntimeWarning, match='longer'):
        eog_events = ica_find_eog_events(
            raw, sources[np.abs(eog_scores).argmax()])
    assert (eog_events.ndim == 2)

    # Test ica fiff export
    ica_raw = ica.get_sources(raw, start=0, stop=100)
    assert (ica_raw.last_samp - ica_raw.first_samp == 100)
    assert_equal(len(ica_raw._filenames), 1)  # API consistency
    ica_chans = [ch for ch in ica_raw.ch_names if 'ICA' in ch]
    assert (ica.n_components_ == len(ica_chans))
    test_ica_fname = op.join(op.abspath(op.curdir), 'test-ica_raw.fif')
    # presumably checks that a numpy integer n_components does not break
    # saving -- TODO confirm intent
    ica.n_components = np.int32(ica.n_components)
    ica_raw.save(test_ica_fname, overwrite=True)
    ica_raw2 = read_raw_fif(test_ica_fname, preload=True)
    assert_allclose(ica_raw._data, ica_raw2._data, rtol=1e-5, atol=1e-4)
    ica_raw2.close()
    os.remove(test_ica_fname)

    # Test ica epochs export
    ica_epochs = ica.get_sources(epochs)
    assert (ica_epochs.events.shape == epochs.events.shape)
    ica_chans = [ch for ch in ica_epochs.ch_names if 'ICA' in ch]
    assert (ica.n_components_ == len(ica_chans))
    assert (ica.n_components_ == ica_epochs.get_data().shape[1])
    assert (ica_epochs._raw is None)
    assert (ica_epochs.preload is True)

    # test float n pca components
    # (fake a flat explained-variance profile and check the resulting
    # component counts for a few thresholds)
    ica.pca_explained_variance_ = np.array([0.2] * 5)
    ica.n_components_ = 0
    for ncomps, expected in [[0.3, 1], [0.9, 4], [1, 1]]:
        ncomps_ = ica._check_n_pca_components(ncomps)
        assert (ncomps_ == expected)

    ica = ICA(method=method)
    with pytest.warns(None):  # sometimes does not converge
        ica.fit(raw, picks=picks[:5])
    with pytest.warns(RuntimeWarning, match='longer'):
        ica.find_bads_ecg(raw)
    ica.find_bads_eog(epochs, ch_name='MEG 0121')
    assert_array_equal(raw_data, raw[:][0])

    # after dropping a channel the fitted ICA no longer matches the Raw
    raw.drop_channels(['MEG 0122'])
    pytest.raises(RuntimeError, ica.find_bads_eog, raw)
    with pytest.warns(RuntimeWarning, match='longer'):
        pytest.raises(RuntimeError, ica.find_bads_ecg, raw)
Exemple #22
0
def remove_eyeblinks_and_heartbeat(raw, subject, figdir, events, eventid, rng):
    """
    Find and repair eyeblink and heartbeat artifacts in the data.

    The ICA is fitted on a temporarily high-pass filtered (1 Hz) copy of the
    data, cut into epochs from 0 to 3 s after each event (MEG channels only,
    no baseline correction). EOG and ECG components are then selected
    (hard-coded for subject '008', automatically via ``find_bads_eog`` /
    ``find_bads_ecg`` for everyone else), diagnostic figures are saved under
    ``<figdir>/sub-<subject>/meg/``, and the ICA is applied to ``raw``.

    Rejecting high-amplitude epochs with autoreject before the ICA fit is
    currently disabled (see the commented-out block in the body).

    :param raw: Raw data. It is loaded into memory and cleaned in place by
        ``ica.apply``; the 1 Hz high-pass is only applied to a copy.
    :param subject: str, subject identifier, e.g., '001'
    :param figdir: directory (str or path-like) under which figures are saved
    :param events: events array used to epoch the data for the ICA fit
    :param eventid: event id(s) passed to ``mne.Epochs`` as ``event_id``
    :param rng: random state instance, used to seed the ICA
    """

    def _fig_path(fname):
        # every diagnostic figure goes to <figdir>/sub-<subject>/meg/<fname>
        return _construct_path([
            Path(figdir),
            f"sub-{subject}",
            "meg",
            fname,
        ])

    def _as_list(figs):
        # MNE plotting methods return either a single Figure or a list of
        # Figures (e.g. one per channel type); always iterate over a list.
        return figs if isinstance(figs, list) else [figs]

    # prior to an ICA, it is recommended to high-pass filter the data
    # as low frequency artifacts can alter the ICA solution. We fit the ICA
    # to high-pass filtered (1Hz) data, and apply it to non-highpass-filtered
    # data
    logging.info("Applying a temporary high-pass filtering prior to ICA")
    filt_raw = raw.copy()
    filt_raw.load_data().filter(l_freq=1., h_freq=None)

    # evoked eyeblinks and heartbeats for diagnostic plots
    eog_evoked = create_eog_epochs(filt_raw).average()
    eog_evoked.apply_baseline(baseline=(None, -0.2))
    if subject == '008':
        # subject 008's ECG channel is flat. It will not find any heartbeats by
        # default. We let it estimate heartbeat from magnetometers. For this,
        # we'll drop the ECG channel
        filt_raw.drop_channels(['ECG003'])
    ecg_evoked = create_ecg_epochs(filt_raw).average()
    ecg_evoked.apply_baseline(baseline=(None, -0.2))

    # make sure that we actually found sensible artifacts here
    for kind, evoked in (('eog', eog_evoked), ('ecg', ecg_evoked)):
        for i, fig in enumerate(_as_list(evoked.plot_joint())):
            fig.savefig(_fig_path(
                f"evoked-artifact_{kind}_sub-{subject}_{i}.png"))

    # Chunk raw data into epochs to fit the ICA
    # No baseline correction as it would interfere with ICA.
    logging.info("Epoching filtered data")
    epochs = mne.Epochs(filt_raw,
                        events,
                        event_id=eventid,
                        tmin=0,
                        tmax=3,
                        picks='meg',
                        baseline=None)

    ## First, estimate rejection criteria for high-amplitude artifacts. This is
    ## done via autoreject
    #logging.info('Estimating bad epochs quick-and-dirty, to improve ICA')
    #ar = AutoReject(random_state=rng)
    # fit on first 200 epochs to save (a bit of) time
    #epochs.load_data()
    #ar.fit(epochs[:200])
    #epochs_ar, reject_log = ar.transform(epochs, return_log=True)

    # run an ICA to capture heartbeat and eyeblink artifacts.
    # set a seed for reproducibility.
    # When left to figure out the component number by itself, it ends up with
    # about 80. I'm setting n_components to 45 to have a chance at checking them
    # by hand.
    # We fit it on a set of epochs excluding the initial bad epochs following
    # https://github.com/autoreject/autoreject/blob/dfbc64f49eddeda53c5868290a6792b5233843c6/examples/plot_autoreject_workflow.py
    logging.info('Fitting the ICA')
    ica = ICA(max_iter='auto', n_components=45, random_state=rng)
    ica.fit(epochs)  #[~reject_log.bad_epochs])
    logging.info("Searching for eyeblink and heartbeat artifacts in the data")
    if subject == '008':
        # manually selected components for subject 008
        eog_indices = [10]
        ecg_indices = [29]
    else:
        # An initially manual component detection did not reproduce after a
        # software update - for now, we have to do the automatic detection for
        # all but sub 8
        eog_indices, _ = ica.find_bads_eog(filt_raw)
        ecg_indices, _ = ica.find_bads_ecg(filt_raw)
    logging.info(f"Identified the following EOG components: {eog_indices}")
    logging.info(f"Identified the following ECG components: {ecg_indices}")

    # visualize the components
    for i, fig in enumerate(_as_list(ica.plot_components())):
        fig.savefig(_fig_path(f"ica-components_sub-{subject}_{i}.png"))

    # visualize the time series of components and save it; the enlarged
    # figure size only applies inside this context and the previous
    # rcParams are restored afterwards (even if plotting fails)
    with plt.rc_context({'figure.figsize': (30, 20)}):
        comp_sources = ica.plot_sources(epochs)
        comp_sources.savefig(
            _fig_path(f"ica-components_sources_sub-{subject}.png"))

    # overlay the *selected* components on the averaged artifact epochs, so
    # the figures show exactly what is removed below
    overlay_eog = ica.plot_overlay(eog_evoked, exclude=eog_indices)
    overlay_eog.savefig(_fig_path(
        f"ica-eog-components_over-avg-epochs_sub-{subject}.png"))
    overlay_ecg = ica.plot_overlay(ecg_evoked, exclude=ecg_indices)
    overlay_ecg.savefig(_fig_path(
        f"ica-ecg-components_over-avg-epochs_sub-{subject}.png"))

    # plot EOG / ECG component properties (nothing to plot without picks)
    for kind, indices in (('eog', eog_indices), ('ecg', ecg_indices)):
        if len(indices) == 0:
            logging.info(f"No {kind.upper()} components found, "
                         f"skipping property plots")
            continue
        figs = ica.plot_properties(filt_raw, picks=indices)
        for i, fig in enumerate(figs):
            fig.savefig(_fig_path(
                f"ica-property{i}_artifact-{kind}_sub-{subject}.png"))

    # Set the indices to be excluded. Build a fresh list rather than
    # extending eog_indices in place, and skip components flagged as both.
    exclude = list(eog_indices)
    for idx in ecg_indices:
        if idx not in exclude:
            exclude.append(idx)
    ica.exclude = exclude

    # plot ICs applied to the averaged EOG / ECG epochs, with the excluded
    # components highlighted
    for kind, evoked in (('eog', eog_evoked), ('ecg', ecg_evoked)):
        sources = ica.plot_sources(evoked)
        sources.savefig(_fig_path(
            f"ica-sources_artifact-{kind}_sub-{subject}.png"))

    # apply the ICA to the raw data
    logging.info('Applying ICA to the raw data.')
    raw.load_data()
    ica.apply(raw)


def preprocess_eeg(id_num, random_seed=None):
    """Preprocess the EEG recording of a single participant.

    Steps, in order: load the BIDS recording, apply the ``standard_1005``
    montage, clean up duplicated / mislabelled trigger annotations, remove
    line noise (CleanLine-style spectrum fitting), robust re-referencing
    (PREP ``Reference``), 1-50 Hz filtering, ICA-based blink removal,
    optional CSD, and finally writing the result to EDF. Diagnostic PSD and
    channel plots are saved to a fresh ``sub_<id>`` folder inside
    ``plotdir`` along the way.

    Several module-level settings are read here: ``task``, ``datatype``,
    ``bids_root``, ``plotdir``, ``interpolate_bads``, ``max_interpolated``,
    ``ica_err_dir``, ``perform_csd``, ``noisy_bad_dir``, ``outdir`` and
    ``outfile_fmt``.

    Parameters
    ----------
    id_num
        Participant identifier; passed to ``BIDSPath`` as the subject and
        used in plot / output file names.
    random_seed : int or None, optional
        Seed for ``random``, RANSAC and the ICA. When ``None`` a random
        32-bit seed is drawn from ``os.urandom``. A seed of 0 is honoured.

    Returns
    -------
    dict
        Per-participant summary: seed, annotation fix-up counts, trigger
        counts and bad / interpolated channels. Returned early (with the bad
        channel fields set to "NA") for incomplete recordings, and returned
        before writing EDF output when the ICA finds no EOG components (the
        prepped data are then saved to ``ica_err_dir`` instead).
    """

    # Set important variables
    bids_path = BIDSPath(id_num, task=task, datatype=datatype, root=bids_root)
    plot_path = os.path.join(plotdir, "sub_{0}".format(id_num))
    # Start every run with an empty plot folder; makedirs also creates
    # plotdir itself if it does not exist yet
    if os.path.exists(plot_path):
        shutil.rmtree(plot_path)
    os.makedirs(plot_path)
    # Only draw a random seed when none was given (0 is a valid seed)
    if random_seed is None:
        random_seed = int(binascii.b2a_hex(os.urandom(4)), 16)
    random.seed(random_seed)
    id_info = {"id": id_num, "random_seed": random_seed}

    ### Load and prepare EEG data #############################################

    header = "### Processing sub-{0} (seed: {1}) ###".format(
        id_num, random_seed)
    print("\n" + "#" * len(header))
    print(header)
    print("#" * len(header) + "\n")

    # Load EEG data
    raw = read_raw_bids(bids_path, verbose=True)

    # Check if recording is complete (at least 600 annotations)
    complete = len(raw.annotations) >= 600

    # Add a montage to the data
    montage_kind = "standard_1005"
    montage = mne.channels.make_standard_montage(montage_kind)
    mne.datasets.eegbci.standardize(raw)
    raw.set_montage(montage)

    # Extract some info
    eeg_index = mne.pick_types(raw.info, eeg=True, eog=False, meg=False)
    ch_names = raw.info["ch_names"]
    ch_names_eeg = list(np.asarray(ch_names)[eeg_index])
    sample_rate = raw.info["sfreq"]

    # Make a copy of the data
    raw_copy = raw.copy()
    raw_copy.load_data()

    # Trim duplicated data (only needed for sub-005)
    annot = raw_copy.annotations
    file_starts = [a for a in annot if a['description'] == "file start"]
    if len(file_starts):
        duplicate_start = file_starts[0]['onset']
        raw_copy.crop(tmax=duplicate_start)

    # Make backup of EOG and EMG channels to re-append after PREP
    raw_other = raw_copy.copy()
    raw_other.pick_types(eog=True, emg=True, stim=False)

    # Prepare copy of raw data for PREP
    raw_copy.pick_types(eeg=True)

    # Plot data prior to any processing
    if complete:
        save_psd_plot(id_num, "psd_0_raw", plot_path, raw_copy)
        save_channel_plot(id_num, "ch_0_raw", plot_path, raw_copy)

    ### Clean up events #######################################################

    print("\n\n=== Processing Event Annotations... ===\n")

    # Event types in the order they follow each other; a mislabelled trigger
    # is renamed to the entry after its current label
    event_names = [
        "stim_on", "red_on", "trace_start", "trace_end", "accuracy_submit",
        "vividness_submit"
    ]
    doubled = []
    wrong_label = []
    new_onsets = []
    new_durations = []
    new_descriptions = []

    # Find and flag any duplicate triggers. A trigger repeating the previous
    # trigger's description counts as "doubled" when the next trigger follows
    # within 2 ms, and as mislabelled otherwise.
    annot = raw_copy.annotations
    trigger_count = len(annot)
    for i in range(1, trigger_count):
        a = annot[i]
        on_last = i + 1 == trigger_count
        prev_trigger = annot[i - 1]['description']
        # The final trigger has no successor: use a far-off dummy onset so it
        # can only be flagged as mislabelled
        next_onset = annot[i + 1]['onset'] if not on_last else a['onset'] + 100
        # Determine whether duplicates are doubles or mislabeled
        if a['description'] == prev_trigger:
            if (next_onset - a['onset']) < 0.002:
                doubled.append(a)
            else:
                wrong_label.append(a)

    # Rename annotations to have meaningful names & fix duplicates
    for a in raw_copy.annotations:
        if a in doubled or a['description'] not in event_names:
            continue
        if a in wrong_label:
            index = event_names.index(a['description'])
            # The last event type has no successor to relabel to; keep it
            if index + 1 < len(event_names):
                a['description'] = event_names[index + 1]
        new_onsets.append(a['onset'])
        new_durations.append(a['duration'])
        new_descriptions.append(a['description'])

    # Replace old annotations with new fixed ones
    if len(annot):
        new_annot = mne.Annotations(
            new_onsets,
            new_durations,
            new_descriptions,
            orig_time=raw_copy.annotations[0]['orig_time'])
        raw_copy.set_annotations(new_annot)

    # Check annotations to verify we have equal numbers of each
    orig_counts = Counter(annot.description)
    counts = Counter(raw_copy.annotations.description)
    print("Updated Annotation Counts:")
    for a in event_names:
        out = " - '{0}': {1} -> {2}"
        print(out.format(a, orig_counts[a], counts[a]))

    # Get info
    id_info['annot_doubled'] = len(doubled)
    id_info['annot_wrong'] = len(wrong_label)

    # Every event type except 'vividness_submit' should occur equally often
    # (filter on the key, not on values equal to the vividness count)
    count_vals = [
        n for name, n in counts.items() if name != 'vividness_submit'
    ]
    id_info['equal_triggers'] = all(x == count_vals[0] for x in count_vals)
    id_info['stim_on'] = counts['stim_on']
    id_info['red_on'] = counts['red_on']
    id_info['trace_start'] = counts['trace_start']
    id_info['trace_end'] = counts['trace_end']
    id_info['acc_submit'] = counts['accuracy_submit']
    id_info['vivid_submit'] = counts['vividness_submit']

    if not complete:
        remaining_info = {
            'initial_bad': "NA",
            'num_initial_bad': "NA",
            'interpolated': "NA",
            'num_interpolated': "NA",
            'remaining_bad': "NA",
            'num_remaining_bad': "NA"
        }
        id_info.update(remaining_info)
        e = "\n\n### Incomplete recording for sub-{0}, skipping... ###\n\n"
        print(e.format(id_num))
        return id_info

    ### Run components of PREP manually #######################################

    print("\n\n=== Performing CleanLine... ===")

    # Try to remove line noise using CleanLine approach: notch-fit the
    # detrended signal at 60 Hz harmonics below Nyquist, then add the trend
    # back. Data are converted to microvolts for processing and back after.
    linenoise = np.arange(60, sample_rate / 2, 60)
    EEG_raw = raw_copy.get_data() * 1e6
    EEG_new = removeTrend(EEG_raw, sample_rate=sample_rate)
    EEG_clean = mne.filter.notch_filter(
        EEG_new,
        Fs=sample_rate,
        freqs=linenoise,
        filter_length="10s",
        method="spectrum_fit",
        mt_bandwidth=2,
        p_value=0.01,
    )
    EEG_final = EEG_raw - EEG_new + EEG_clean
    raw_copy._data = EEG_final * 1e-6
    del linenoise, EEG_raw, EEG_new, EEG_clean, EEG_final

    # Plot data following cleanline
    save_psd_plot(id_num, "psd_1_cleanline", plot_path, raw_copy)
    save_channel_plot(id_num, "ch_1_cleanline", plot_path, raw_copy)

    # Perform robust re-referencing
    prep_params = {"ref_chs": ch_names_eeg, "reref_chs": ch_names_eeg}
    reference = Reference(raw_copy,
                          prep_params,
                          ransac=True,
                          random_state=random_seed)
    print("\n\n=== Performing Robust Re-referencing... ===\n")
    reference.perform_reference()

    # If not interpolating bad channels, use pre-interpolation channel data
    if not interpolate_bads:
        reference.raw._data = reference.EEG_before_interpolation * 1e-6
        reference.interpolated_channels = []
        reference.still_noisy_channels = reference.bad_before_interpolation
        reference.raw.info["bads"] = reference.bad_before_interpolation

    # Plot data following robust re-reference
    save_psd_plot(id_num, "psd_2_reref", plot_path, reference.raw)
    save_channel_plot(id_num, "ch_2_reref", plot_path, reference.raw)

    # Re-append removed EMG/EOG/trigger channels
    raw_prepped = reference.raw.add_channels([raw_other])

    # Get info
    initial_bad = reference.noisy_channels_original["bad_all"]
    id_info['initial_bad'] = " ".join(initial_bad)
    id_info['num_initial_bad'] = len(initial_bad)

    interpolated = reference.interpolated_channels
    id_info['interpolated'] = " ".join(interpolated)
    id_info['num_interpolated'] = len(interpolated)

    remaining_bad = reference.still_noisy_channels
    id_info['remaining_bad'] = " ".join(remaining_bad)
    id_info['num_remaining_bad'] = len(remaining_bad)

    # Print re-referencing info
    print("\nRe-Referencing Info:")
    print(" - Bad channels original: {0}".format(initial_bad))
    if interpolate_bads:
        print(" - Bad channels after re-referencing: {0}".format(interpolated))
        print(" - Bad channels after interpolation: {0}".format(remaining_bad))
    else:
        print(
            " - Bad channels after re-referencing: {0}".format(remaining_bad))

    # Check if too many channels were interpolated for the participant
    prop_interpolated = len(
        reference.interpolated_channels) / len(ch_names_eeg)
    e = "### NOTE: Too many interpolated channels for sub-{0} ({1}) ###"
    if max_interpolated < prop_interpolated:
        print("\n")
        print(e.format(id_num, len(reference.interpolated_channels)))
        print("\n")

    ### Filter data and apply ICA to remove blinks ############################

    # Apply highpass & lowpass filters
    print("\n\n=== Applying Highpass & Lowpass Filters... ===")
    raw_prepped.filter(1.0, 50.0, fir_design='firwin')

    # Plot data following frequency filters
    save_psd_plot(id_num, "psd_3_filtered", plot_path, raw_prepped)
    save_channel_plot(id_num, "ch_3_filtered", plot_path, raw_prepped)

    # Perform ICA using EOG data on eye blinks
    print("\n\n=== Removing Blinks Using ICA... ===\n")
    ica = ICA(n_components=20, random_state=random_seed, method='picard')
    ica.fit(raw_prepped, decim=5)
    eog_indices, eog_scores = ica.find_bads_eog(raw_prepped)
    ica.exclude = eog_indices

    # No blink component found: save the data for manual inspection instead
    if not len(ica.exclude):
        err = " - Encountered an ICA error for sub-{0}, skipping for now..."
        print("\n")
        print(err.format(id_num))
        print("\n")
        save_bad_fif(raw_prepped, id_num, ica_err_dir)
        return id_info

    # Plot ICA info & diagnostics before removing from signal
    save_ica_plots(id_num, plot_path, raw_prepped, ica, eog_scores)

    # Remove eye blink independent components based on ICA
    ica.apply(raw_prepped)

    # Plot data following ICA
    save_psd_plot(id_num, "psd_4_ica", plot_path, raw_prepped)
    save_channel_plot(id_num, "ch_4_ica", plot_path, raw_prepped)

    ### Compute Current Source Density (CSD) estimates ########################

    if perform_csd:
        print("\n")
        print("=== Computing Current Source Density (CSD) Estimates... ===\n")
        raw_prepped = mne.preprocessing.compute_current_source_density(
            raw_prepped.drop_channels(remaining_bad))

        # Plot data following CSD
        save_psd_plot(id_num, "psd_5_csd", plot_path, raw_prepped)
        save_channel_plot(id_num, "ch_5_csd", plot_path, raw_prepped)

    ### Write preprocessed data to new EDF ####################################

    # Participants with too many interpolated channels go to a separate folder
    if max_interpolated < prop_interpolated:
        if not os.path.isdir(noisy_bad_dir):
            os.makedirs(noisy_bad_dir)
        outpath = os.path.join(noisy_bad_dir, outfile_fmt.format(id_num))
    else:
        outpath = os.path.join(outdir, outfile_fmt.format(id_num))
    write_mne_edf(outpath, raw_prepped)

    print("\n\n### sub-{0} complete! ###\n\n".format(id_num))

    return id_info
    def run(self):
        """Load ``self.file``, apply the configured preprocessing and epoch it.

        Event detection is driven by ``self.info['marker_detection']`` and
        ``self.info['event_detection']``. The optional preprocessing steps are:

        - montage
        - resampling
        - EXG5/EXG6 re-referencing
        - band-pass filtering
        - ICA ocular correction
        - manual epoch rejection
        - AutoReject

        These steps are best-effort. A step whose setting is absent or None is
        skipped silently. A configured step that raises is skipped with a
        warning instead of aborting the run.

        Side effects: sets ``self.raw`` (copy of the unprocessed recording),
        ``self.epochs`` and ``self.bads``. When the corresponding steps
        succeed, it also sets ``self.ica`` and ``self.autoreject_log``.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If no events could be built (neither detection mode enabled), or
            if event detection is enabled but the external event files could
            not be located.
        """
        import warnings

        def warn_skipped(step, err):
            # Keep optional steps best-effort, but report failures instead of
            # hiding them behind a bare ``except: pass``.
            warnings.warn('{0} skipped: {1!r}'.format(step, err))

        eog = self.info['channel_info']['EOG']
        misc = self.info['channel_info']['Misc']
        stim = self.info['channel_info']['Stim']

        ext_files = None
        try:
            ext_files = glob.glob(self.info['ext_file_folder'] + '/' +
                                  self.participant + '/*axis0.dat')
        except (KeyError, AttributeError, TypeError):
            # No external event-file folder/participant configured.
            pass

        tmin = self.t_epoch[0]
        tmax = self.t_epoch[1]

        raw = read_raw_edf(self.file, eog=eog, misc=misc)
        self.raw = cp.deepcopy(raw)  # untouched copy, taken before load_data
        raw.load_data()

        events = None

        # marker detection (one marker, continuous trial)
        if self.info['marker_detection']:
            starts = find_trialstart(raw,
                                     stim_channel=raw.ch_names[stim[0]],
                                     new_samplin_rate=self.sr_new)
            # The second start is forced to 30 * 200 samples after the first
            # (presumably 30 s at 200 Hz -- TODO confirm against sr_new).
            try:
                starts[1] = starts[0] + 30 * 200
            except (IndexError, TypeError):
                # Only one start found (or immutable sequence): append it.
                starts = np.r_[starts, (starts[0] + 30 * 200)]
            events = np.zeros((len(starts), 3))
            events[:, 0] = starts
            events[:, 2] = list(self.info['event_dict'].values())
            # np.int was removed in NumPy 1.24; builtin int is equivalent.
            events = events.astype(int)

        # event detection (one marker, regular events)
        if self.info['event_detection']:
            if ext_files is None:
                raise ValueError("event_detection requires "
                                 "info['ext_file_folder'] and a participant")
            starts = find_trialstart(raw,
                                     stim_channel=raw.ch_names[stim[0]],
                                     new_samplin_rate=self.sr_new)

            events = force_events(ext_files, self.info['event_dict'],
                                  self.sr_new, self.info['trial_length'],
                                  self.info['trials'],
                                  starts[:len(self.info['event_dict'])])

        ica = ICA(method='fastica') if self.info['ICA'] else None
        ar = AutoReject() if self.info['Autoreject'] else None

        ## EEG preprocessing options are applied if parameters are set in object

        # read montage
        montage_name = getattr(self, 'montage', None)
        if montage_name is not None:
            try:
                raw.set_montage(make_standard_montage(montage_name))
            except Exception as err:
                warn_skipped('montage', err)

        # resampling
        if getattr(self, 'sr_new', None) is not None:
            try:
                raw.resample(sfreq=self.sr_new)
            except Exception as err:
                warn_skipped('resampling', err)

        # rereferencing (fixed EXG5/EXG6 reference channels)
        try:
            raw, _ = mne.set_eeg_reference(raw, ref_channels=['EXG5', 'EXG6'])
        except Exception as err:
            warn_skipped('re-referencing', err)

        # filter
        filter_freqs = getattr(self, 'filter_freqs', None)
        if filter_freqs is not None:
            try:
                raw.filter(filter_freqs[0], filter_freqs[1],
                           fir_design='firwin')
            except Exception as err:
                warn_skipped('filtering', err)

        # ocular correction: remove ICA components correlated with EOG
        if ica is not None:
            try:
                ica.fit(raw)
                ica.exclude = []
                eog_indices, _ = ica.find_bads_eog(raw)
                ica.exclude = eog_indices
                ica.apply(raw)
                self.ica = ica
            except Exception as err:
                warn_skipped('ocular correction (ICA)', err)

        picks = mne.pick_types(raw.info,
                               meg=False,
                               eeg=True,
                               stim=False,
                               eog=False,
                               exclude='bads')

        if events is None:
            raise ValueError('No events available: enable marker_detection '
                             'or event_detection in info')

        event_id = self.info['event_dict']
        epochs = mne.Epochs(raw,
                            events,
                            event_id,
                            tmin,
                            tmax,
                            proj=True,
                            baseline=None,
                            preload=True,
                            picks=picks)

        # epoch rejection (indices from a previous run / user selection)
        previous_bads = getattr(self, 'bads', None)
        if previous_bads is not None:
            try:
                epochs = epochs.drop(indices=previous_bads)
            except Exception as err:
                warn_skipped('manual epoch rejection', err)

        if ar is not None:
            try:
                epochs, self.autoreject_log = ar.fit_transform(
                    epochs, return_log=True)
            except Exception as err:
                warn_skipped('AutoReject', err)

        # drop_log entries are lists in older MNE and tuples in newer MNE;
        # normalise to tuples so the comparison works with both.
        rejected = np.asarray(
            [tuple(log) in (('USER',), ('AUTOREJECT',))
             for log in epochs.drop_log],
            dtype=bool)
        self.bads = np.where(rejected)
        self.epochs = epochs
        return self
# Script fragment: save ICA source and overlay figures (PNG files under
# `img_folder`) for ECG- and EOG-locked evoked responses.
# NOTE(review): `ecg_epochs`, `ica`, `raw`, `picks`, `img_folder`, `ecg_inds`,
# `veog_inds`, `heog_inds` and `eog_inds` are defined earlier in the script
# (not visible in this fragment).

# estimate average artifact
ecg_evoked = ecg_epochs.average()
fig = ica.plot_sources(ecg_evoked, exclude=ecg_inds)  # plot ECG sources + selection
fig.savefig(img_folder + '/ica_ecg_evoked_sources.png')
fig = ica.plot_overlay(ecg_evoked, exclude=ecg_inds)  # plot ECG cleaning
fig.savefig(img_folder + '/ica_ecg_evoked_overlay.png')

# Disabled: the same EOG plots for epochs created without an explicit ch_name.
#eog_evoked = create_eog_epochs(raw, tmin=-.5, tmax=.5, picks=picks).average()
#fig = ica.plot_sources(eog_evoked, exclude=eog_inds)  # plot EOG sources + selection
#fig.savefig(img_folder + '/ica_eog_evoked_sources.png')
#fig = ica.plot_overlay(eog_evoked, exclude=eog_inds)  # plot EOG cleaning
#fig.savefig(img_folder + '/ica_eog_evoked_overlay.png')

# Keep a reference to the current exclusion list before clearing it. `tmp` is
# the same list object (not a copy) and is not used within this fragment;
# presumably it is restored later -- TODO confirm.
tmp=ica.exclude
# Clear ica.exclude, presumably so that only the `exclude=` argument of each
# plot call below determines which components are shown as removed.
ica.exclude = []
# Evoked response locked to EOG events found on channel 'EOG001' (presumably
# the vertical EOG, per the `veog` naming), plotted once per candidate
# exclusion set (veog_inds, heog_inds, eog_inds).
veog_evoked = create_eog_epochs(raw, ch_name='EOG001', tmin=-.5, tmax=.5, picks=picks).average()
fig = ica.plot_sources(veog_evoked, exclude=veog_inds)  # plot EOG sources + selection
fig.savefig(img_folder + '/ica_veog_evoked_sources_veog_inds.png')
fig = ica.plot_overlay(veog_evoked, exclude=veog_inds)  # plot EOG cleaning
fig.savefig(img_folder + '/ica_veog_evoked_overlay_veog_inds.png')
fig = ica.plot_sources(veog_evoked, exclude=heog_inds)  # plot EOG sources + selection
fig.savefig(img_folder + '/ica_veog_evoked_sources_heog_inds.png')
fig = ica.plot_overlay(veog_evoked, exclude=heog_inds)  # plot EOG cleaning
fig.savefig(img_folder + '/ica_veog_evoked_overlay_heog_inds.png')
fig = ica.plot_sources(veog_evoked, exclude=eog_inds)  # plot EOG sources + selection
fig.savefig(img_folder + '/ica_veog_evoked_sources_eog_inds.png')
fig = ica.plot_overlay(veog_evoked, exclude=eog_inds)  # plot EOG cleaning
fig.savefig(img_folder + '/ica_veog_evoked_overlay_eog_inds.png')

# Same for channel 'EOG003' (presumably the horizontal EOG); its plots follow
# beyond this fragment.
heog_evoked = create_eog_epochs(raw, ch_name='EOG003', tmin=-.5, tmax=.5, picks=picks).average()
        json.dump(json_info, outfile, indent=4)
    del json_info, json_file

    # Estimate ICA
    ica = ICA(method='picard', max_iter=1000, random_state=97,
              fit_params=dict(ortho=True, extended=True),
              verbose=True)
    ica.fit(epochs)

    # Save ICA
    ica_file = deriv_path / f'{sub}_task-{task}_ref-FCz_desc-ica_ica.fif.gz'
    ica.save(ica_file)

    # Find EOG artifacts
    eog_inds, eog_scores = ica.find_bads_eog(raw)
    ica.exclude = eog_inds
    eog_ica_plot = ica.plot_scores(eog_scores, labels=['VEOG', 'HEOG'])
    eog_ica_file = fig_path / f'{sub}_task-{task}_ic_eog_scores.png'
    eog_ica_plot.savefig(eog_ica_file, dpi=600)
    plt.close(eog_ica_plot)

    # Plot all component properties
    ica.plot_components(inst=epochs, reject=None,
                        psd_args=dict(fmax=70))
    ica.exclude.sort()
    ica.save(ica_file)
    print(f'ICs Flagged for Removal: {ica.exclude}')

    # Make a JSON
    json_info = {
        'Description': 'ICA components',
Exemple #27
0
def test_ica_additional():
    """Test additional ICA functionality.

    Covers:

    - fitting with a noise covariance;
    - save/read round trips with and without a noise covariance;
    - merging of excluded sources;
    - score functions on raw data and epochs;
    - ECG/EOG event detection from sources;
    - FIF export of ICA sources;
    - validation of the ``order`` argument of the plot methods.
    """
    stop2 = 500

    test_cov2 = deepcopy(test_cov)
    ica = ICA(noise_cov=test_cov2, n_components=3, max_pca_components=4,
              n_pca_components=4)
    ica.decompose_raw(raw, picks[:5])
    assert_true(ica.n_components_ < 5)

    ica = ICA(n_components=3, max_pca_components=4,
              n_pca_components=4)
    assert_raises(RuntimeError, ica.save, '')
    ica.decompose_raw(raw, picks=None, start=start, stop=stop2)

    # epochs extraction from raw fit
    assert_raises(RuntimeError, ica.get_sources_epochs, epochs)

    # test reading and writing
    test_ica_fname = op.join(op.dirname(tempdir), 'ica_test.fif')
    for cov in (None, test_cov):
        ica = ICA(noise_cov=cov, n_components=3, max_pca_components=4,
                  n_pca_components=4)
        ica.decompose_raw(raw, picks=picks, start=start, stop=stop2)
        sources = ica.get_sources_epochs(epochs)
        assert_true(sources.shape[1] == ica.n_components_)

        # NOTE(review): the loop value is unused; every pass starts from
        # exclude=[0] because the merge assertions below depend on it.
        for _exclude in ([], [0]):
            ica.exclude = [0]
            ica.save(test_ica_fname)
            ica_read = read_ica(test_ica_fname)
            assert_true(ica.exclude == ica_read.exclude)
            # test pick merge -- add components
            ica.pick_sources_raw(raw, exclude=[1])
            assert_true(ica.exclude == [0, 1])
            #                 -- only as arg
            ica.exclude = []
            ica.pick_sources_raw(raw, exclude=[0, 1])
            assert_true(ica.exclude == [0, 1])
            #                 -- remove duplicates
            ica.exclude += [1]
            ica.pick_sources_raw(raw, exclude=[0, 1])
            assert_true(ica.exclude == [0, 1])

            ica_raw = ica.sources_as_raw(raw)
            assert_true(ica.exclude == [ica.ch_names.index(e) for e in
                                        ica_raw.info['bads']])

        ica.n_pca_components = 2
        ica.save(test_ica_fname)
        ica_read = read_ica(test_ica_fname)
        assert_true(ica.n_pca_components ==
                    ica_read.n_pca_components)
        ica.n_pca_components = 4
        ica_read.n_pca_components = 4

        ica.exclude = []
        ica.save(test_ica_fname)
        ica_read = read_ica(test_ica_fname)

        assert_true(ica.ch_names == ica_read.ch_names)

        assert_true(np.allclose(ica.mixing_matrix_, ica_read.mixing_matrix_,
                                rtol=1e-16, atol=1e-32))
        assert_array_equal(ica.pca_components_,
                           ica_read.pca_components_)
        assert_array_equal(ica.pca_mean_, ica_read.pca_mean_)
        assert_array_equal(ica.pca_explained_variance_,
                           ica_read.pca_explained_variance_)
        assert_array_equal(ica._pre_whitener, ica_read._pre_whitener)

        # assert_raises(RuntimeError, ica_read.decompose_raw, raw)
        sources = ica.get_sources_raw(raw)
        sources2 = ica_read.get_sources_raw(raw)
        assert_array_almost_equal(sources, sources2)

        _raw1 = ica.pick_sources_raw(raw, exclude=[1])
        _raw2 = ica_read.pick_sources_raw(raw, exclude=[1])
        assert_array_almost_equal(_raw1[:, :][0], _raw2[:, :][0])

    os.remove(test_ica_fname)

    ## score funcs raw ##

    # Record warnings since some funcs emit "ties preclude exact".
    # XXX this should be fixed by a future PR...
    # ``record`` is keyword-only in Python 3's catch_warnings.
    with warnings.catch_warnings(record=True):
        sfunc_test = [ica.find_sources_raw(raw, target='EOG 061',
                                           score_func=name, start=0, stop=10)
                      for name in score_funcs]

    # check length of scores
    for scores in sfunc_test:
        assert_true(ica.n_components_ == len(scores))

    # check univariate stats
    scores = ica.find_sources_raw(raw, score_func=stats.skew)
    # check exception handling
    assert_raises(ValueError, ica.find_sources_raw, raw,
                  target=np.arange(1))

    ## score funcs epochs ##

    # XXX this needs to be fixed, some of the score funcs don't seem to be
    # suited for the testing data.
    with warnings.catch_warnings(record=True):
        sfunc_test = [ica.find_sources_epochs(epochs_eog, target='EOG 061',
                                              score_func=name)
                      for name in score_funcs]

    # check length of scores
    for scores in sfunc_test:
        assert_true(ica.n_components_ == len(scores))

    # check univariate stats
    scores = ica.find_sources_epochs(epochs, score_func=stats.skew)

    # check exception handling
    assert_raises(ValueError, ica.find_sources_epochs, epochs,
                  target=np.arange(1))

    # ecg functionality
    ecg_scores = ica.find_sources_raw(raw, target='MEG 1531',
                                      score_func='pearsonr')

    ecg_events = ica_find_ecg_events(raw, sources[np.abs(ecg_scores).argmax()])

    assert_true(ecg_events.ndim == 2)

    # eog functionality
    eog_scores = ica.find_sources_raw(raw, target='EOG 061',
                                      score_func='pearsonr')
    eog_events = ica_find_eog_events(raw, sources[np.abs(eog_scores).argmax()])

    assert_true(eog_events.ndim == 2)

    # Test ica fiff export
    ica_raw = ica.sources_as_raw(raw, start=0, stop=100)
    assert_true(ica_raw.last_samp - ica_raw.first_samp == 100)
    ica_chans = [ch for ch in ica_raw.ch_names if 'ICA' in ch]
    assert_true(ica.n_components_ == len(ica_chans))
    test_ica_fname = op.join(op.abspath(op.curdir), 'test_ica.fif')
    ica_raw.save(test_ica_fname)
    ica_raw2 = fiff.Raw(test_ica_fname, preload=True)
    assert_array_almost_equal(ica_raw._data, ica_raw2._data)
    ica_raw2.close()
    os.remove(test_ica_fname)

    # regression test for plot method
    assert_raises(ValueError, ica.plot_sources_raw, raw,
                  order=np.arange(50))
    assert_raises(ValueError, ica.plot_sources_epochs, epochs,
                  order=np.arange(50))
def compute_ica(raw, subject, n_components=0.99, picks=None, decim=None,
                reject=None, ecg_tmin=-0.5, ecg_tmax=0.5, eog_tmin=-0.5,
                eog_tmax=0.5, n_max_ecg=3, n_max_eog=1,
                n_max_ecg_epochs=200, show=True, img_scale=1.0,
                random_state=None, report=None, artifact_stats=None):
    """Run ICA on raw data and flag ECG/EOG-related components.

    The ICA is fitted on ``raw``. ECG components are then detected with
    phase statistics (``ctps``) and EOG components by correlation. Up to
    ``n_max_ecg`` / ``n_max_eog`` of them are appended to ``ica.exclude``.
    Diagnostic figures are added to ``report``.

    Parameters
    ----------
    raw : instance of Raw
        Raw measurements to be decomposed.
    subject : str
        The name of the subject.
    n_components : int | float | None | 'rank'
        The number of components used for ICA decomposition.

        - If int, it must be smaller than max_pca_components.
        - If None, all PCA components will be used.
        - If float between 0 and 1, components will be selected by the
          cumulative percentage of explained variance.
        - If 'rank', the number of components equals the rank estimate.

        Defaults to 0.99.
    picks : array-like of int, shape(n_channels, ) | None
        Channels to be included. This selection remains throughout the
        initialized ICA solution. If None only good data channels are used.
        Defaults to None.
    decim : int | None
        Increment for selecting each nth time slice. If None, all samples
        are used. Defaults to None.
    reject : dict | None
        Rejection parameters based on peak-to-peak amplitude, passed to
        ``ICA.fit``. Valid keys are 'grad' | 'mag' | 'eeg' | 'eog' | 'ecg'.
        If None, no rejection is done while fitting.

        For the ECG/EOG epochs, these values update the defaults
        ``{'mag': 5e-12, 'grad': 5000e-13, 'eeg': 300e-6}``. Entries for
        channel types absent from the ICA are dropped.

        Defaults to None.
    ecg_tmin : float
        Start time before ECG event. Defaults to -0.5.
    ecg_tmax : float
        End time after ECG event. Defaults to 0.5.
    eog_tmin : float
        Start time before EOG event. Defaults to -0.5.
    eog_tmax : float
        End time after EOG event. Defaults to 0.5.
    n_max_ecg : int | None
        The maximum number of ECG components to exclude. Defaults to 3.
    n_max_eog : int | None
        The maximum number of EOG components to exclude. Defaults to 1.
    n_max_ecg_epochs : int
        The maximum number of ECG epochs to use for phase-consistency
        estimation. Defaults to 200.
    show : bool
        Show figures if True.
    img_scale : float
        The scaling factor for the report figures. Defaults to 1.0.
    random_state : None | int | instance of np.random.RandomState
        Random state used to initialize the ICA estimation. As the
        estimation is non-deterministic it can be useful to fix the seed
        to have reproducible results. Defaults to None.
    report : instance of Report | None
        The report object. If None, a new report will be generated.
    artifact_stats : None | dict
        A dict that receives info on amplitude ranges of artifacts and
        numbers of events, etc. by channel type. Updated in place if given.

    Returns
    -------
    ica : instance of ICA
        The ICA solution, with the selected ECG/EOG components appended to
        ``ica.exclude``.
    report : dict
        A dict with the report object ('html') and artifact statistics
        ('stats').
    """
    if report is None:
        report = Report(subject=subject, title='ICA preprocessing')
    if n_components == 'rank':
        n_components = raw.estimate_rank(picks=picks)
    ica = ICA(n_components=n_components, max_pca_components=None,
              random_state=random_state, max_iter=256)
    ica.fit(raw, picks=picks, decim=decim, reject=reject)

    # Channel types present in the ICA, e.g. 'MAG+GRAD ', used as a prefix
    # for report section names.
    present_types = [ch.upper() for ch in ('mag', 'grad', 'eeg') if ch in ica]
    comment = '+'.join(present_types) + ' ' if present_types else ''

    topo_ch_type = 'mag'
    if 'GRAD' in comment and 'MAG' not in comment:
        topo_ch_type = 'grad'
    elif 'EEG' in comment:
        topo_ch_type = 'eeg'

    ###########################################################################
    # 2) identify bad components by analyzing latent sources.

    # Template whose two '%s' slots are filled with (plot kind, artifact).
    title = '%s related to %s artifacts (red) ({})'.format(subject)

    # generate ECG epochs use detection via phase statistics
    reject_ = {'mag': 5e-12, 'grad': 5000e-13, 'eeg': 300e-6}
    if reject is not None:
        reject_.update(reject)
    for ch_type in ['mag', 'grad', 'eeg']:
        if ch_type not in ica:
            reject_.pop(ch_type)

    picks_ = np.array([raw.ch_names.index(k) for k in ica.ch_names])
    if 'eeg' in ica:
        if 'ecg' in raw:
            picks_ = np.append(picks_,
                               pick_types(raw.info, meg=False, ecg=True)[0])
        else:
            logger.info('There is no ECG channel, trying to guess ECG from '
                        'magnetormeters')

    if artifact_stats is None:
        artifact_stats = dict()

    ecg_epochs = create_ecg_epochs(raw, tmin=ecg_tmin, tmax=ecg_tmax,
                                   keep_ecg=True, picks=picks_, reject=reject_)

    n_ecg_epochs_found = len(ecg_epochs.events)
    artifact_stats['ecg_n_events'] = n_ecg_epochs_found
    n_max_ecg_epochs = min(n_max_ecg_epochs, n_ecg_epochs_found)
    artifact_stats['ecg_n_used'] = n_max_ecg_epochs

    # Random (but reproducible) subset of ECG epochs for the CTPS estimate.
    sel_ecg_epochs = np.arange(n_ecg_epochs_found)
    rng = np.random.RandomState(42)
    rng.shuffle(sel_ecg_epochs)
    ecg_ave = ecg_epochs.average()

    report.add_figs_to_section(ecg_ave.plot(), 'ECG-full', 'artifacts')
    ecg_epochs = ecg_epochs[sel_ecg_epochs[:n_max_ecg_epochs]]
    ecg_ave = ecg_epochs.average()
    report.add_figs_to_section(ecg_ave.plot(), 'ECG-used', 'artifacts')

    _put_artifact_range(artifact_stats, ecg_ave, kind='ecg')

    ecg_inds, scores = ica.find_bads_ecg(ecg_epochs, method='ctps')
    if len(ecg_inds) > 0:
        ecg_evoked = ecg_epochs.average()
        del ecg_epochs

        fig = ica.plot_scores(scores, exclude=ecg_inds, labels='ecg',
                              title='', show=show)

        report.add_figs_to_section(fig, 'scores ({})'.format(subject),
                                   section=comment + 'ECG',
                                   scale=img_scale)

        # Plotting may modify ica.exclude (MNE issue #2608): snapshot it and
        # restore a fresh copy so later mutations cannot touch the snapshot.
        current_exclude = list(ica.exclude)
        fig = ica.plot_sources(raw, ecg_inds, exclude=ecg_inds,
                               title=title % ('sources', 'ecg'), show=show)

        report.add_figs_to_section(fig, 'sources ({})'.format(subject),
                                   section=comment + 'ECG',
                                   scale=img_scale)
        ica.exclude = list(current_exclude)

        fig = ica.plot_components(ecg_inds, ch_type=topo_ch_type,
                                  title='', colorbar=True, show=show)
        report.add_figs_to_section(fig, title % ('components', 'ecg'),
                                   section=comment + 'ECG', scale=img_scale)
        ica.exclude = list(current_exclude)

        ecg_inds = ecg_inds[:n_max_ecg]
        ica.exclude += ecg_inds
        fig = ica.plot_sources(ecg_evoked, exclude=ecg_inds, show=show)
        report.add_figs_to_section(fig, 'evoked sources ({})'.format(subject),
                                   section=comment + 'ECG',
                                   scale=img_scale)

        fig = ica.plot_overlay(ecg_evoked, exclude=ecg_inds, show=show)
        report.add_figs_to_section(fig,
                                   'rejection overlay ({})'.format(subject),
                                   section=comment + 'ECG',
                                   scale=img_scale)

    # detect EOG by correlation
    picks_eog = np.concatenate(
        [picks_, pick_types(raw.info, meg=False, eeg=False, ecg=False,
                            eog=True)])

    eog_epochs = create_eog_epochs(raw, tmin=eog_tmin, tmax=eog_tmax,
                                   picks=picks_eog, reject=reject_)
    artifact_stats['eog_n_events'] = len(eog_epochs.events)
    artifact_stats['eog_n_used'] = artifact_stats['eog_n_events']
    eog_ave = eog_epochs.average()
    report.add_figs_to_section(eog_ave.plot(), 'EOG-used', 'artifacts')
    _put_artifact_range(artifact_stats, eog_ave, kind='eog')

    eog_inds = None
    if len(eog_epochs.events) > 0:
        eog_inds, scores = ica.find_bads_eog(eog_epochs)

    # Like the ECG branch, only plot/exclude when components were found.
    if eog_inds is not None and len(eog_inds) > 0:
        fig = ica.plot_scores(scores, exclude=eog_inds, labels='eog',
                              show=show, title='')
        report.add_figs_to_section(fig, 'scores ({})'.format(subject),
                                   section=comment + 'EOG',
                                   scale=img_scale)

        current_exclude = list(ica.exclude)  # issue #2608 MNE
        fig = ica.plot_sources(raw, eog_inds, exclude=eog_inds,
                               title=title % ('sources', 'eog'), show=show)
        report.add_figs_to_section(fig, 'sources', section=comment + 'EOG',
                                   scale=img_scale)
        ica.exclude = list(current_exclude)

        fig = ica.plot_components(eog_inds, ch_type=topo_ch_type,
                                  title='', colorbar=True, show=show)
        report.add_figs_to_section(fig, title % ('components', 'eog'),
                                   section=comment + 'EOG', scale=img_scale)
        ica.exclude = list(current_exclude)

        eog_inds = eog_inds[:n_max_eog]
        ica.exclude += eog_inds

        eog_evoked = eog_epochs.average()
        fig = ica.plot_sources(eog_evoked, exclude=eog_inds, show=show)
        report.add_figs_to_section(
            fig, 'evoked sources ({})'.format(subject),
            section=comment + 'EOG', scale=img_scale)

        fig = ica.plot_overlay(eog_evoked, exclude=eog_inds, show=show)
        report.add_figs_to_section(
            fig, 'rejection overlay({})'.format(subject),
            section=comment + 'EOG', scale=img_scale)
    else:
        del eog_epochs

    # check the amplitudes do not change
    if len(ica.exclude) > 0:
        # raw signals before vs. after removing the excluded components
        fig = ica.plot_overlay(raw, show=show)
        html = _render_components_table(ica)
        # Label with the ICA's channel types; the former `ch_type` was a
        # leftover loop variable that was always 'eeg'.
        report.add_htmls_to_section(
            html, captions='excluded components',
            section='ICA rejection summary (%s)' % comment.strip())
        report.add_figs_to_section(
            fig, 'rejection overlay({})'.format(subject),
            section=comment + 'RAW', scale=img_scale)
    return ica, dict(html=report, stats=artifact_stats)
def ICA_pipeline(mne_array,
                 regions,
                 chans_to_plot=20,
                 base_name="",
                 exclude=None,
                 skip_plots=False):
    """Remove ICA components from an MNE Raw recording and return a clean copy.

    The ICA (FastICA, ``random_state=97``) is fitted on a 1 Hz high-passed
    copy of the data. It is then applied to an unfiltered copy of
    ``mne_array``.

    Parameters
    ----------
    mne_array : mne.io.Raw
        Recording to clean. ``load_data()`` is called on it in place.
    regions : iterable of str
        Region names. Currently unused; only referenced by the disabled
        region-overlay code below.
    chans_to_plot : int
        Number of channels shown in the initial raw-signal plot.
    base_name : str
        Label used in the plot titles.
    exclude : list of int | None
        Indices of the ICA components to remove. If None, the source and
        component plots are shown regardless of ``skip_plots``, and the
        components removed are whatever ``ica.exclude`` holds when the ICA
        is applied.
    skip_plots : bool
        If True, skip the raw, excluded-component and reconstruction plots
        (and the source/component plots when ``exclude`` is given).

    Returns
    -------
    mne.io.Raw
        Copy of the recording with the excluded components removed.
    """
    raw = mne_array

    if not skip_plots:
        # Plot raw signal
        raw.plot(n_channels=chans_to_plot,
                 block=True,
                 duration=25,
                 show=True,
                 clipping="transparent",
                 title="Raw LFP Data from {}".format(base_name),
                 remove_dc=False,
                 scalings=dict(eeg=350e-6))

    # Perform ICA using mne; fit on high-passed data, which suits ICA better.
    from mne.preprocessing import ICA
    filt_raw = raw.copy()
    filt_raw.load_data().filter(l_freq=1., h_freq=None)
    ica = ICA(method='fastica', random_state=97)
    ica.fit(filt_raw)

    raw.load_data()
    if exclude is None:
        # Plot raw ICAs
        print('Select channels to exclude using this plot...')
        ica.plot_sources(raw,
                         block=False,
                         stop=25,
                         title='ICA from {}'.format(base_name))

        print('Click topo to get more ICA properties')
        ica.plot_components(inst=raw)

        # TODO Overlay ICA cleaned signal over raw, one plot per region
        # (would use `regions`); add a scroll bar or window selection option.
    else:
        # ICAs to exclude
        ica.exclude = exclude
        if not skip_plots:
            ica.plot_sources(raw,
                             block=False,
                             stop=25,
                             title='ICA from {}'.format(base_name))
            ica.plot_components(inst=raw)

    # Apply ICA exclusion
    reconst_raw = raw.copy()
    exclude_raw = raw.copy()
    print("ICAs excluded: ", ica.exclude)
    ica.apply(reconst_raw)

    if not skip_plots:
        # Exclude every component EXCEPT the chosen ones, so exclude_raw
        # shows only what was removed. A set-based complement is linear and
        # tolerates duplicate indices (list.remove raised on those).
        removed = set(ica.exclude)
        ica.exclude = [idx for idx in range(ica.n_components_)
                       if idx not in removed]
        ica.apply(exclude_raw)

        # Plot excluded ICAs
        exclude_raw.plot(block=True,
                         show=True,
                         clipping="transparent",
                         duration=25,
                         title="Excluded ICs from {}".format(base_name),
                         remove_dc=False,
                         scalings=dict(eeg=350e-6))

        # Plot reconstructed signals w/o excluded ICAs
        reconst_raw.plot(
            block=True,
            show=True,
            clipping="transparent",
            duration=25,
            title="Reconstructed LFP Data from {}".format(base_name),
            remove_dc=False,
            scalings=dict(eeg=350e-6))
    return reconst_raw
Exemple #30
0
def test_ica_additional():
    """Test additional ICA functionality
    """
    stop2 = 500
    raw = io.Raw(raw_fname, preload=True).crop(0, stop, False).crop(1.5)
    picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
                       eog=False, exclude='bads')
    test_cov = read_cov(test_cov_name)
    events = read_events(event_name)
    picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
                       eog=False, exclude='bads')
    epochs = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), preload=True)
    # for testing eog functionality
    picks2 = pick_types(raw.info, meg=True, stim=False, ecg=False,
                        eog=True, exclude='bads')
    epochs_eog = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks2,
                        baseline=(None, 0), preload=True)

    test_cov2 = deepcopy(test_cov)
    ica = ICA(noise_cov=test_cov2, n_components=3, max_pca_components=4,
              n_pca_components=4)
    assert_true(ica.info is None)
    ica.decompose_raw(raw, picks[:5])
    assert_true(isinstance(ica.info, Info))
    assert_true(ica.n_components_ < 5)

    ica = ICA(n_components=3, max_pca_components=4,
              n_pca_components=4)
    assert_raises(RuntimeError, ica.save, '')
    ica.decompose_raw(raw, picks=None, start=start, stop=stop2)

    # test warnings on bad filenames
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        ica_badname = op.join(op.dirname(tempdir), 'test-bad-name.fif.gz')
        ica.save(ica_badname)
        read_ica(ica_badname)
    assert_true(len(w) == 2)

    # test decim
    ica = ICA(n_components=3, max_pca_components=4,
              n_pca_components=4)
    raw_ = raw.copy()
    for _ in range(3):
        raw_.append(raw_)
    n_samples = raw_._data.shape[1]
    ica.decompose_raw(raw, picks=None, decim=3)
    assert_true(raw_._data.shape[1], n_samples)

    # test expl var
    ica = ICA(n_components=1.0, max_pca_components=4,
              n_pca_components=4)
    ica.decompose_raw(raw, picks=None, decim=3)
    assert_true(ica.n_components_ == 4)

    # epochs extraction from raw fit
    assert_raises(RuntimeError, ica.get_sources_epochs, epochs)
    # test reading and writing
    test_ica_fname = op.join(op.dirname(tempdir), 'test-ica.fif')
    for cov in (None, test_cov):
        ica = ICA(noise_cov=cov, n_components=2, max_pca_components=4,
                  n_pca_components=4)
        with warnings.catch_warnings(record=True):  # ICA does not converge
            ica.decompose_raw(raw, picks=picks, start=start, stop=stop2)
        sources = ica.get_sources_epochs(epochs)
        assert_true(ica.mixing_matrix_.shape == (2, 2))
        assert_true(ica.unmixing_matrix_.shape == (2, 2))
        assert_true(ica.pca_components_.shape == (4, len(picks)))
        assert_true(sources.shape[1] == ica.n_components_)

        for exclude in [[], [0]]:
            ica.exclude = [0]
            ica.save(test_ica_fname)
            ica_read = read_ica(test_ica_fname)
            assert_true(ica.exclude == ica_read.exclude)
            # test pick merge -- add components
            ica.pick_sources_raw(raw, exclude=[1])
            assert_true(ica.exclude == [0, 1])
            #                 -- only as arg
            ica.exclude = []
            ica.pick_sources_raw(raw, exclude=[0, 1])
            assert_true(ica.exclude == [0, 1])
            #                 -- remove duplicates
            ica.exclude += [1]
            ica.pick_sources_raw(raw, exclude=[0, 1])
            assert_true(ica.exclude == [0, 1])

            # test basic include
            ica.exclude = []
            ica.pick_sources_raw(raw, include=[1])

            ica_raw = ica.sources_as_raw(raw)
            assert_true(ica.exclude == [ica_raw.ch_names.index(e) for e in
                                        ica_raw.info['bads']])

        # test filtering
        d1 = ica_raw._data[0].copy()
        with warnings.catch_warnings(record=True):  # dB warning
            ica_raw.filter(4, 20)
        assert_true((d1 != ica_raw._data[0]).any())
        d1 = ica_raw._data[0].copy()
        with warnings.catch_warnings(record=True):  # dB warning
            ica_raw.notch_filter([10])
        assert_true((d1 != ica_raw._data[0]).any())

        ica.n_pca_components = 2
        ica.save(test_ica_fname)
        ica_read = read_ica(test_ica_fname)
        assert_true(ica.n_pca_components == ica_read.n_pca_components)

        # check type consistency
        attrs = ('mixing_matrix_ unmixing_matrix_ pca_components_ '
                 'pca_explained_variance_ _pre_whitener')
        f = lambda x, y: getattr(x, y).dtype
        for attr in attrs.split():
            assert_equal(f(ica_read, attr), f(ica, attr))

        ica.n_pca_components = 4
        ica_read.n_pca_components = 4

        ica.exclude = []
        ica.save(test_ica_fname)
        ica_read = read_ica(test_ica_fname)
        for attr in ['mixing_matrix_', 'unmixing_matrix_', 'pca_components_',
                     'pca_mean_', 'pca_explained_variance_',
                     '_pre_whitener']:
            assert_array_almost_equal(getattr(ica, attr),
                                      getattr(ica_read, attr))

        assert_true(ica.ch_names == ica_read.ch_names)
        assert_true(isinstance(ica_read.info, Info))

        assert_raises(RuntimeError, ica_read.decompose_raw, raw)
        sources = ica.get_sources_raw(raw)
        sources2 = ica_read.get_sources_raw(raw)
        assert_array_almost_equal(sources, sources2)

        _raw1 = ica.pick_sources_raw(raw, exclude=[1])
        _raw2 = ica_read.pick_sources_raw(raw, exclude=[1])
        assert_array_almost_equal(_raw1[:, :][0], _raw2[:, :][0])

    os.remove(test_ica_fname)
    # check scrore funcs
    for name, func in score_funcs.items():
        if name in score_funcs_unsuited:
            continue
        scores = ica.find_sources_raw(raw, target='EOG 061', score_func=func,
                                      start=0, stop=10)
        assert_true(ica.n_components_ == len(scores))

    # check univariate stats
    scores = ica.find_sources_raw(raw, score_func=stats.skew)
    # check exception handling
    assert_raises(ValueError, ica.find_sources_raw, raw,
                  target=np.arange(1))

    params = []
    params += [(None, -1, slice(2), [0, 1])]  # varicance, kurtosis idx params
    params += [(None, 'MEG 1531')]  # ECG / EOG channel params
    for idx, ch_name in product(*params):
        ica.detect_artifacts(raw, start_find=0, stop_find=50, ecg_ch=ch_name,
                             eog_ch=ch_name, skew_criterion=idx,
                             var_criterion=idx, kurt_criterion=idx)
    ## score funcs epochs ##

    # check score funcs
    for name, func in score_funcs.items():
        if name in score_funcs_unsuited:
            continue
        scores = ica.find_sources_epochs(epochs_eog, target='EOG 061',
                                         score_func=func)
        assert_true(ica.n_components_ == len(scores))

    # check univariate stats
    scores = ica.find_sources_epochs(epochs, score_func=stats.skew)

    # check exception handling
    assert_raises(ValueError, ica.find_sources_epochs, epochs,
                  target=np.arange(1))

    # ecg functionality
    ecg_scores = ica.find_sources_raw(raw, target='MEG 1531',
                                      score_func='pearsonr')

    with warnings.catch_warnings(record=True):  # filter attenuation warning
        ecg_events = ica_find_ecg_events(raw,
                                         sources[np.abs(ecg_scores).argmax()])

    assert_true(ecg_events.ndim == 2)

    # eog functionality
    eog_scores = ica.find_sources_raw(raw, target='EOG 061',
                                      score_func='pearsonr')
    with warnings.catch_warnings(record=True):  # filter attenuation warning
        eog_events = ica_find_eog_events(raw,
                                         sources[np.abs(eog_scores).argmax()])

    assert_true(eog_events.ndim == 2)

    # Test ica fiff export
    ica_raw = ica.sources_as_raw(raw, start=0, stop=100)
    assert_true(ica_raw.last_samp - ica_raw.first_samp == 100)
    assert_true(len(ica_raw._filenames) == 0)  # API consistency
    ica_chans = [ch for ch in ica_raw.ch_names if 'ICA' in ch]
    assert_true(ica.n_components_ == len(ica_chans))
    test_ica_fname = op.join(op.abspath(op.curdir), 'test-ica_raw.fif')
    ica.n_components = np.int32(ica.n_components)
    ica_raw.save(test_ica_fname, overwrite=True)
    ica_raw2 = io.Raw(test_ica_fname, preload=True)
    assert_allclose(ica_raw._data, ica_raw2._data, rtol=1e-5, atol=1e-4)
    ica_raw2.close()
    os.remove(test_ica_fname)

    # Test ica epochs export
    ica_epochs = ica.sources_as_epochs(epochs)
    assert_true(ica_epochs.events.shape == epochs.events.shape)
    sources_epochs = ica.get_sources_epochs(epochs)
    assert_array_equal(ica_epochs.get_data(), sources_epochs)
    ica_chans = [ch for ch in ica_epochs.ch_names if 'ICA' in ch]
    assert_true(ica.n_components_ == len(ica_chans))
    assert_true(ica.n_components_ == ica_epochs.get_data().shape[1])
    assert_true(ica_epochs.raw is None)
    assert_true(ica_epochs.preload is True)

    # test float n pca components
    ica.pca_explained_variance_ = np.array([0.2] * 5)
    ica.n_components_ = 0
    for ncomps, expected in [[0.3, 1], [0.9, 4], [1, 1]]:
        ncomps_ = _check_n_pca_components(ica, ncomps)
        assert_true(ncomps_ == expected)
Exemple #31
0
def test_plot_ica_sources(raw_orig, mpl_backend):
    """Test plotting of ICA panel.

    Exercises ``ICA.plot_sources`` on Raw, Epochs and Evoked data:
    toggling excluded components by clicking traces and channel names,
    opening a ``plot_properties`` child figure via a right-click on a
    channel name, annotation rendering, errors for instances whose channels
    do not match the fitted data, and the error for an invalid ``inst``.

    ``raw_orig`` is only ever copied here, so the fixture object itself is
    never cropped by this test.
    """
    raw = raw_orig.copy().crop(0, 1)
    picks = _get_picks(raw)
    epochs = _get_epochs()
    raw.pick_channels([raw.ch_names[k] for k in picks])
    ica_picks = pick_types(raw.info,
                           meg=True,
                           eeg=False,
                           stim=False,
                           ecg=False,
                           eog=False,
                           exclude='bads')
    ica = ICA(n_components=2)
    ica.fit(raw, picks=ica_picks)
    ica.exclude = [1]
    fig = ica.plot_sources(raw)
    assert mpl_backend._get_n_figs() == 1
    # change which component is in ICA.exclude (click data trace to remove
    # current one; click name to add other one)
    fig._redraw()
    # ToDo: This will be different methods in pyqtgraph
    x = fig.mne.traces[1].get_xdata()[5]
    y = fig.mne.traces[1].get_ydata()[5]
    fig._fake_click((x, y), xform='data')  # exclude = []
    _click_ch_name(fig, ch_index=0, button=1)  # exclude = [0]
    fig._fake_keypress(fig.mne.close_key)
    fig._close_event()
    assert mpl_backend._get_n_figs() == 0
    # closing the browser writes the click-edited selection back to the ICA
    assert_array_equal(ica.exclude, [0])
    # test when picks does not include ica.exclude.
    fig = ica.plot_sources(raw, picks=[1])
    assert len(plt.get_fignums()) == 1
    mpl_backend._close_all()

    # dtype can change int->np.int64 after load, test it explicitly
    ica.n_components_ = np.int64(ica.n_components_)

    # test clicks on y-label (need >2 secs for plot_properties() to work)
    # copy first: Raw.crop() operates in place, and cropping the fixture
    # directly would shorten it for any later user of the same object
    long_raw = raw_orig.copy().crop(0, 5)
    fig = ica.plot_sources(long_raw)
    assert len(plt.get_fignums()) == 1
    fig._redraw()
    # right-click (button 3) on a channel name opens one child figure
    _click_ch_name(fig, ch_index=0, button=3)
    assert len(fig.mne.child_figs) == 1
    assert len(plt.get_fignums()) == 2
    # close child fig directly (workaround for mpl issue #18609)
    fig._fake_keypress('escape', fig=fig.mne.child_figs[0])
    assert len(plt.get_fignums()) == 1
    fig._fake_keypress(fig.mne.close_key)
    assert len(plt.get_fignums()) == 0
    del long_raw

    # test with annotations: one annotation is drawn both in the main axes
    # and in the horizontal scrollbar
    orig_annot = raw.annotations
    raw.set_annotations(Annotations([0.2], [0.1], 'Test'))
    fig = ica.plot_sources(raw)
    assert len(fig.mne.ax_main.collections) == 1
    assert len(fig.mne.ax_hscroll.collections) == 1
    raw.set_annotations(orig_annot)

    # test error handling: dropping a fitted channel must be rejected
    raw_ = raw.copy().load_data()
    raw_.drop_channels('MEG 0113')
    with pytest.raises(RuntimeError, match="Raw doesn't match fitted data"), \
         pytest.warns(RuntimeWarning, match='could not be picked'):
        ica.plot_sources(inst=raw_)
    epochs_ = epochs.copy().load_data()
    epochs_.drop_channels('MEG 0113')
    with pytest.raises(RuntimeError, match="Epochs don't match fitted data"), \
         pytest.warns(RuntimeWarning, match='could not be picked'):
        ica.plot_sources(inst=epochs_)
    del raw_
    del epochs_

    # test w/ epochs and evokeds
    ica.plot_sources(epochs)
    ica.plot_sources(epochs.average())
    evoked = epochs.average()
    fig = ica.plot_sources(evoked)
    # Test a click: on the first data point of a trace, then on the
    # top-left corner of the axes
    ax = fig.get_axes()[0]
    line = ax.lines[0]
    _fake_click(fig, ax, [line.get_xdata()[0], line.get_ydata()[0]], 'data')
    _fake_click(fig, ax, [ax.get_xlim()[0], ax.get_ylim()[1]], 'data')

    # plot with a component marked as excluded
    ica.exclude = [0]
    ica.plot_sources(evoked)

    # pretend find_bads_eog() yielded some results
    ica.labels_ = {'eog': [0], 'eog/0/crazy-channel': [0]}
    ica.plot_sources(evoked)  # now with labels

    # pass an invalid inst
    with pytest.raises(ValueError, match='must be of Raw or Epochs type'):
        ica.plot_sources('meeow')
Exemple #32
0
def run_preprocessing_eog(raw_eeg,
                          tmin=0,
                          tmax=60,
                          remove_artefact=True,
                          show=True):
    """Remove ocular artefacts by excluding ICA components that match Fp2.

    The recording is copied, cropped to ``[tmin, tmax]`` and band-pass
    filtered (1-40 Hz). An ICA with ``nchan - 2`` components is fitted on it.
    The components that ``find_bads_eog`` associates with channel ``Fp2``
    (used as the EOG reference, threshold 1.4) are marked for exclusion.

    Parameters
    ----------
    raw_eeg : mne.io.Raw
        Recording to clean. It is copied before cropping, so the input
        object itself is not cropped or filtered by this function.
    tmin : float
        Start of the analysed time window (default 0).
    tmax : float
        End of the analysed time window (default 60).
    remove_artefact : bool
        Whether to remove the matching components from the data
        (default True).
    show : bool
        Whether to plot the sources, the components and the EOG match
        scores (default True).

    Returns
    -------
    mne.io.Raw
        The return value depends on the inputs and the detection result:

        * ``remove_artefact`` is true and at least one component matched:
          a copy of the cropped, filtered data with those components removed.
        * otherwise: the cropped, filtered data unchanged.
    """
    raw_inter = raw_eeg.copy().crop(tmin=tmin, tmax=tmax)
    # NOTE(review): Raw.filter presumably filters in place and returns the
    # same object, so `filtered_raw` and `raw_inter` likely alias and the
    # EOG detection below also sees the filtered data -- confirm intended.
    filtered_raw = raw_inter.filter(l_freq=1., h_freq=40.)

    ica = ICA(n_components=raw_eeg.info.get('nchan') - 2, random_state=9)

    print(raw_eeg.info)

    ica.fit(filtered_raw, verbose=False)
    ica.exclude = []
    # find which ICs match the EOG pattern (Fp2 is the EOG reference)
    eog_indices, eog_scores = ica.find_bads_eog(raw_inter,
                                                ch_name='Fp2',
                                                threshold=1.4,
                                                verbose=False)
    ica.exclude = eog_indices

    if show:
        ica.plot_sources(filtered_raw)
        ica.plot_components()
        # barplot of ICA component "EOG match" scores
        ica.plot_scores(eog_scores)

    if not eog_indices:
        print('No matching ICA component')
        return filtered_raw

    print('matching ICA component found')
    if show:
        # ICs applied to the data, with EOG matches highlighted
        ica.plot_sources(raw_inter)

    if not remove_artefact:
        return filtered_raw

    # apply on a copy so the returned filtered data stays untouched
    reconst_raw = filtered_raw.copy()
    ica.apply(reconst_raw)
    return reconst_raw
Exemple #33
0
def test_ica_additional():
    """Test additional ICA functionality with the legacy ``decompose_raw`` API.

    Covers fitting with and without a noise covariance, save/read round
    trips, merging of ``exclude`` in ``pick_sources_raw``, source scoring
    on raw data and epochs, ``detect_artifacts``, ECG/EOG event detection
    from sources, export of sources as Raw/Epochs, and argument checks of
    the plotting methods.

    Relies on names defined elsewhere in the module (``raw``, ``picks``,
    ``epochs``, ``epochs_eog``, ``test_cov``, ``start``, ``tempdir``,
    ``score_funcs``, ``score_funcs_unsuited``).
    """
    stop2 = 500

    # fit with a (copied) noise covariance on the first 5 picks
    test_cov2 = deepcopy(test_cov)
    ica = ICA(noise_cov=test_cov2, n_components=3, max_pca_components=4,
              n_pca_components=4)
    assert_true(ica.info is None)
    ica.decompose_raw(raw, picks[:5])
    assert_true(isinstance(ica.info, Info))
    assert_true(ica.n_components_ < 5)

    # saving an ICA that has not been fitted yet must fail
    ica = ICA(n_components=3, max_pca_components=4,
              n_pca_components=4)
    assert_raises(RuntimeError, ica.save, '')
    ica.decompose_raw(raw, picks=None, start=start, stop=stop2)

    # epochs extraction from raw fit
    assert_raises(RuntimeError, ica.get_sources_epochs, epochs)

    # test reading and writing
    # NOTE(review): the file goes to the parent of `tempdir`, not into it,
    # and is only removed if every assertion below passes.
    test_ica_fname = op.join(op.dirname(tempdir), 'ica_test.fif')
    for cov in (None, test_cov):
        ica = ICA(noise_cov=cov, n_components=3, max_pca_components=4,
                  n_pca_components=4)
        ica.decompose_raw(raw, picks=picks, start=start, stop=stop2)
        sources = ica.get_sources_epochs(epochs)
        assert_true(sources.shape[1] == ica.n_components_)

        # NOTE(review): the loop variable `exclude` is never used -- both
        # iterations set ica.exclude = [0], and the merge assertions below
        # rely on that starting value. Confirm whether this is intentional.
        for exclude in [[], [0]]:
            ica.exclude = [0]
            ica.save(test_ica_fname)
            ica_read = read_ica(test_ica_fname)
            assert_true(ica.exclude == ica_read.exclude)
            # test pick merge -- add components
            ica.pick_sources_raw(raw, exclude=[1])
            assert_true(ica.exclude == [0, 1])
            #                 -- only as arg
            ica.exclude = []
            ica.pick_sources_raw(raw, exclude=[0, 1])
            assert_true(ica.exclude == [0, 1])
            #                 -- remove duplicates
            ica.exclude += [1]
            ica.pick_sources_raw(raw, exclude=[0, 1])
            assert_true(ica.exclude == [0, 1])

            # excluded components show up as bad channels of the source Raw
            ica_raw = ica.sources_as_raw(raw)
            assert_true(ica.exclude == [ica_raw.ch_names.index(e) for e in
                                        ica_raw.info['bads']])

        # n_pca_components survives a save/read round trip
        ica.n_pca_components = 2
        ica.save(test_ica_fname)
        ica_read = read_ica(test_ica_fname)
        assert_true(ica.n_pca_components ==
                    ica_read.n_pca_components)
        ica.n_pca_components = 4
        ica_read.n_pca_components = 4

        # fitted matrices survive a save/read round trip
        ica.exclude = []
        ica.save(test_ica_fname)
        ica_read = read_ica(test_ica_fname)

        assert_true(ica.ch_names == ica_read.ch_names)
        assert_true(isinstance(ica_read.info, Info))  # XXX improve later
        assert_true(np.allclose(ica.mixing_matrix_, ica_read.mixing_matrix_,
                                rtol=1e-16, atol=1e-32))
        assert_array_equal(ica.pca_components_,
                           ica_read.pca_components_)
        assert_array_equal(ica.pca_mean_, ica_read.pca_mean_)
        assert_array_equal(ica.pca_explained_variance_,
                           ica_read.pca_explained_variance_)
        assert_array_equal(ica._pre_whitener, ica_read._pre_whitener)

        # the original and the re-read ICA produce the same sources/cleaning
        # assert_raises(RuntimeError, ica_read.decompose_raw, raw)
        sources = ica.get_sources_raw(raw)
        sources2 = ica_read.get_sources_raw(raw)
        assert_array_almost_equal(sources, sources2)

        _raw1 = ica.pick_sources_raw(raw, exclude=[1])
        _raw2 = ica_read.pick_sources_raw(raw, exclude=[1])
        assert_array_almost_equal(_raw1[:, :][0], _raw2[:, :][0])

    os.remove(test_ica_fname)
    # check score funcs
    for name, func in score_funcs.items():
        if name in score_funcs_unsuited:
            continue
        scores = ica.find_sources_raw(raw, target='EOG 061', score_func=func,
                                      start=0, stop=10)
        assert_true(ica.n_components_ == len(scores))

    # check univariate stats
    scores = ica.find_sources_raw(raw, score_func=stats.skew)
    # check exception handling
    assert_raises(ValueError, ica.find_sources_raw, raw,
                  target=np.arange(1))

    # every combination of criterion value and ECG/EOG channel
    params = []
    params += [(None, -1, slice(2), [0, 1])]  # variance, kurtosis idx params
    params += [(None, 'MEG 1531')]  # ECG / EOG channel params
    for idx, ch_name in product(*params):
        ica.detect_artifacts(raw, start_find=0, stop_find=50, ecg_ch=ch_name,
                             eog_ch=ch_name, skew_criterion=idx,
                             var_criterion=idx, kurt_criterion=idx)
    ## score funcs epochs ##

    # check score funcs
    for name, func in score_funcs.items():
        if name in score_funcs_unsuited:
            continue
        scores = ica.find_sources_epochs(epochs_eog, target='EOG 061',
                                         score_func=func)
        assert_true(ica.n_components_ == len(scores))

    # check univariate stats
    scores = ica.find_sources_epochs(epochs, score_func=stats.skew)

    # check exception handling
    assert_raises(ValueError, ica.find_sources_epochs, epochs,
                  target=np.arange(1))

    # ecg functionality
    # `sources` is the raw source array from the last iteration of the
    # read/write loop above (fit with cov=test_cov)
    ecg_scores = ica.find_sources_raw(raw, target='MEG 1531',
                                      score_func='pearsonr')

    ecg_events = ica_find_ecg_events(raw, sources[np.abs(ecg_scores).argmax()])

    assert_true(ecg_events.ndim == 2)

    # eog functionality
    eog_scores = ica.find_sources_raw(raw, target='EOG 061',
                                      score_func='pearsonr')
    eog_events = ica_find_eog_events(raw, sources[np.abs(eog_scores).argmax()])

    assert_true(eog_events.ndim == 2)

    # Test ica fiff export
    ica_raw = ica.sources_as_raw(raw, start=0, stop=100)
    assert_true(ica_raw.last_samp - ica_raw.first_samp == 100)
    ica_chans = [ch for ch in ica_raw.ch_names if 'ICA' in ch]
    assert_true(ica.n_components_ == len(ica_chans))
    test_ica_fname = op.join(op.abspath(op.curdir), 'test_ica.fif')
    ica_raw.save(test_ica_fname)
    ica_raw2 = fiff.Raw(test_ica_fname, preload=True)
    assert_array_almost_equal(ica_raw._data, ica_raw2._data)
    ica_raw2.close()
    os.remove(test_ica_fname)

    # Test ica epochs export
    ica_epochs = ica.sources_as_epochs(epochs)
    assert_true(ica_epochs.events.shape == epochs.events.shape)
    sources_epochs = ica.get_sources_epochs(epochs)
    assert_array_equal(ica_epochs.get_data(), sources_epochs)
    ica_chans = [ch for ch in ica_epochs.ch_names if 'ICA' in ch]
    assert_true(ica.n_components_ == len(ica_chans))
    assert_true(ica.n_components_ == ica_epochs.get_data().shape[1])
    assert_true(ica_epochs.raw is None)
    assert_true(ica_epochs.preload == True)

    # regression test for plot method: an `order` longer than the number
    # of components must be rejected
    assert_raises(ValueError, ica.plot_sources_raw, raw,
                  order=np.arange(50))
    assert_raises(ValueError, ica.plot_sources_epochs, epochs,
                  order=np.arange(50))
Exemple #34
0
def test_plot_ica_sources():
    """Test plotting of ICA panel.

    Exercises ``ICA.plot_sources`` on Raw, Epochs and Evoked data:
    toggling excluded components by clicking traces and channel names,
    opening a child figure with a right-click left of the traces (y-label
    area), annotation rendering, errors when ``info['bads']`` differs from
    the fitted data, and the error for an invalid ``inst``.
    """
    raw = read_raw_fif(raw_fname).crop(0, 1).load_data()
    picks = _get_picks(raw)
    epochs = _get_epochs()
    raw.pick_channels([raw.ch_names[k] for k in picks])
    ica_picks = pick_types(raw.info,
                           meg=True,
                           eeg=False,
                           stim=False,
                           ecg=False,
                           eog=False,
                           exclude='bads')
    ica = ICA(n_components=2)
    ica.fit(raw, picks=ica_picks)
    ica.exclude = [1]
    fig = ica.plot_sources(raw)
    assert len(plt.get_fignums()) == 1
    # change which component is in ICA.exclude (click data trace to remove
    # current one; click name to add other one)
    fig.canvas.draw()
    x = fig.mne.traces[1].get_xdata()[5]
    y = fig.mne.traces[1].get_ydata()[5]
    _fake_click(fig, fig.mne.ax_main, (x, y), xform='data')  # exclude = []
    _click_ch_name(fig, ch_index=0, button=1)  # exclude = [0]
    fig.canvas.key_press_event(fig.mne.close_key)
    _close_event(fig)
    assert len(plt.get_fignums()) == 0
    # closing the browser writes the click-edited selection back to the ICA
    assert_array_equal(ica.exclude, [0])
    # test when picks does not include ica.exclude.
    fig = ica.plot_sources(raw, picks=[1])
    assert len(plt.get_fignums()) == 1
    plt.close('all')

    # dtype can change int->np.int64 after load, test it explicitly
    ica.n_components_ = np.int64(ica.n_components_)

    # test clicks on y-label (need >2 secs for plot_properties() to work)
    long_raw = read_raw_fif(raw_fname).crop(0, 5).load_data()
    fig = ica.plot_sources(long_raw)
    assert len(plt.get_fignums()) == 1
    fig.canvas.draw()
    # right-click (button 3) just left of the traces opens one child figure
    _fake_click(fig, fig.mne.ax_main, (-0.1, 0), xform='data', button=3)
    assert len(fig.mne.child_figs) == 1
    assert len(plt.get_fignums()) == 2
    # close child fig directly (workaround for mpl issue #18609)
    fig.mne.child_figs[0].canvas.key_press_event('escape')
    assert len(plt.get_fignums()) == 1
    fig.canvas.key_press_event(fig.mne.close_key)
    assert len(plt.get_fignums()) == 0
    del long_raw

    # test with annotations: one annotation is drawn both in the main axes
    # and in the horizontal scrollbar
    orig_annot = raw.annotations
    raw.set_annotations(Annotations([0.2], [0.1], 'Test'))
    fig = ica.plot_sources(raw)
    assert len(fig.mne.ax_main.collections) == 1
    assert len(fig.mne.ax_hscroll.collections) == 1
    raw.set_annotations(orig_annot)

    # test error handling: a bad channel that was part of the fit makes the
    # instance incompatible with the fitted ICA
    # (raw keeps this bad channel afterwards; it is not used again below)
    raw.info['bads'] = ['MEG 0113']
    with pytest.raises(RuntimeError, match="Raw doesn't match fitted data"):
        ica.plot_sources(inst=raw)
    epochs.info['bads'] = ['MEG 0113']
    with pytest.raises(RuntimeError, match="Epochs don't match fitted data"):
        ica.plot_sources(inst=epochs)
    epochs.info['bads'] = []

    # test w/ epochs and evokeds
    ica.plot_sources(epochs)
    ica.plot_sources(epochs.average())
    evoked = epochs.average()
    fig = ica.plot_sources(evoked)
    # Test a click: on the first data point of a trace, then on the
    # top-left corner of the axes
    ax = fig.get_axes()[0]
    line = ax.lines[0]
    _fake_click(fig, ax, [line.get_xdata()[0], line.get_ydata()[0]], 'data')
    _fake_click(fig, ax, [ax.get_xlim()[0], ax.get_ylim()[1]], 'data')
    # plot with a component marked as excluded
    ica.exclude = [0]
    ica.plot_sources(evoked)
    # pretend find_bads_eog() yielded some labelled components
    ica.labels_ = dict(eog=[0])
    ica.labels_['eog/0/crazy-channel'] = [0]
    ica.plot_sources(evoked)  # now with labels
    # pass an invalid inst
    with pytest.raises(ValueError, match='must be of Raw or Epochs type'):
        ica.plot_sources('meeow')
def preprocess_ICA_fif_to_ts(fif_file, ECG_ch_name, EoG_ch_name, l_freq, h_freq, down_sfreq, variance, is_sensor_space, data_type):
    """Clean a MEG ``.fif`` recording with ICA and export sensor time series.

    All imports are done inside the function. Presumably this is so it can
    run as a self-contained nipype ``Function`` node -- confirm against the
    workflow that uses it before moving code out of this function.

    Steps:

    1. Read ``fif_file``.
       a. Save the MEG sensor coordinates and names to
          ``correct_channel_coords.txt`` / ``correct_channel_names.txt`` in
          the current working directory.
       b. Band-pass filter the M/EEG channels (IIR, ``l_freq``-``h_freq``).
    2. Fit a FastICA model (``n_components=variance``), or load
       ``<basename>-ica.fif`` from the subject directory if it exists.
    3. Detect ECG- and EOG-related components and write diagnostic figures
       to ``<basename>-report.html``. At most 3 ECG and 2 EOG components
       are excluded.
    4. Finish the ICA step:
       a. Apply the ICA, resample to ``down_sfreq`` and save
          ``<basename>-preproc-raw.fif``.
       b. Save the ICA solution if it was fitted here.
    5. Save the cleaned MEG sensor data as ``<basename>_ica.npy`` in the
       current working directory.

    Parameters
    ----------
    fif_file : str
        Path of the raw ``.fif`` file.
    ECG_ch_name : str
        ECG channel name. If it is not present in the data,
        ``create_ecg_epochs`` is called without ``ch_name``.
    EoG_ch_name : str
        EOG channel name. If it is not present in the data,
        ``find_bads_eog`` is called without ``ch_name``. When it is empty and
        ``data_type == 'ds'`` the EOG step is skipped.
    l_freq : float
        Low edge of the band-pass filter.
    h_freq : float
        High edge of the band-pass filter.
    down_sfreq : float
        Sampling frequency after ICA cleaning.
    variance : int | float
        Passed as ``ICA(n_components=...)``.
    is_sensor_space : bool
        Selects which data file is returned first (see Returns).
    data_type : str
        Data format tag; only ``'ds'`` is checked here.

    Returns
    -------
    data_file : str
        The ``.npy`` time-series path if ``is_sensor_space`` is true,
        otherwise the path of the ICA-cleaned raw ``.fif``.
    channel_coords_file : str
        Path of the saved sensor coordinates.
    channel_names_file : str
        Path of the saved sensor names.
    sfreq : float
        ``raw.info['sfreq']`` at the end of the function.
    """
    import os
    import numpy as np

    import mne
    from mne.io import Raw
    from mne.preprocessing import ICA, read_ica
    from mne.preprocessing import create_ecg_epochs, create_eog_epochs
    from mne.report import Report

    from nipype.utils.filemanip import split_filename as split_f

    report = Report()

    subj_path, basename, ext = split_f(fif_file)
    data_path = os.path.dirname(subj_path)
    print(data_path)

    # read raw data into memory (needed for filtering and ICA)
    raw = Raw(fif_file, preload=True)

    # select sensors
    select_sensors = mne.pick_types(raw.info, meg=True, ref_meg=False,
                                    exclude='bads')
    picks_meeg = mne.pick_types(raw.info, meg=True, eeg=True, exclude='bads')

    # save sensor locations (first three entries of each channel's 'loc')
    sens_loc = np.array([raw.info['chs'][i]['loc'][:3]
                         for i in select_sensors])

    channel_coords_file = os.path.abspath("correct_channel_coords.txt")
    print('*** ' + channel_coords_file + '***')
    np.savetxt(channel_coords_file, sens_loc, fmt='%s')

    # save sensor names
    sens_names = np.array([raw.ch_names[pos] for pos in select_sensors],
                          dtype="str")
    channel_names_file = os.path.abspath("correct_channel_names.txt")
    np.savetxt(channel_names_file, sens_names, fmt='%s')

    # filtering (downsampling is done after ICA, see step 3)
    raw.filter(l_freq=l_freq, h_freq=h_freq, picks=picks_meeg,
               method='iir', n_jobs=8)

    # 1) Fit ICA model using the FastICA algorithm.
    # A float `variance` between 0 and 1 selects n_components from the
    # fraction of variance explained by the PCA components.
    ICA_title = 'Sources related to %s artifacts (red)'
    is_show = False  # figures only go to the report, never on screen
    reject = dict(mag=4e-12, grad=4000e-13)

    # reuse a previously saved ICA solution if there is one
    ica_filename = os.path.join(subj_path, basename + "-ica.fif")
    has_ICA = os.path.exists(ica_filename)
    if not has_ICA:
        ica = ICA(n_components=variance, method='fastica', max_iter=500)
        ica.fit(raw, picks=select_sensors, reject=reject)
    else:
        print(ica_filename + '   exists!!!')
        ica = read_ica(ica_filename)
        ica.exclude = []

    # 2) identify bad components by analyzing latent sources.
    n_max_ecg = 3
    n_max_eog = 2

    # use the named ECG channel when present; otherwise let
    # create_ecg_epochs fall back to its default (synthetic ECG)
    if ECG_ch_name in raw.info['ch_names']:
        ecg_epochs = create_ecg_epochs(raw, tmin=-.5, tmax=.5,
                                       picks=select_sensors,
                                       ch_name=ECG_ch_name)
    else:
        ecg_epochs = create_ecg_epochs(raw, tmin=-.5, tmax=.5,
                                       picks=select_sensors)

    # ICA for ECG artifact, default threshold (originally noted as 0.25)
    ecg_inds, scores = ica.find_bads_ecg(ecg_epochs, method='ctps')
    print(scores)
    print('\n len ecg_inds *** ' + str(len(ecg_inds)) + '***\n')
    if len(ecg_inds) > 0:
        ecg_evoked = ecg_epochs.average()

        fig1 = ica.plot_scores(scores, exclude=ecg_inds,
                               title=ICA_title % 'ecg', show=is_show)

        # the five components with the largest absolute scores
        show_picks = np.abs(scores).argsort()[::-1][:5]

        # estimated latent sources over the first 30 s
        t_start = 0
        t_stop = 30
        fig2 = ica.plot_sources(raw, show_picks, exclude=ecg_inds,
                                title=ICA_title % 'ecg' + ' in 30s',
                                start=t_start, stop=t_stop, show=is_show)

        # topoplot of unmixing matrix columns
        fig3 = ica.plot_components(show_picks, title=ICA_title % 'ecg',
                                   colorbar=True, show=is_show)

        ecg_inds = ecg_inds[:n_max_ecg]
        ica.exclude += ecg_inds

        # ECG sources + selection, then ECG cleaning overlay
        fig4 = ica.plot_sources(ecg_evoked, exclude=ecg_inds, show=is_show)
        fig5 = ica.plot_overlay(ecg_evoked, exclude=ecg_inds, show=is_show)

        report.add_figs_to_section([fig1, fig2, fig3, fig4, fig5],
                                   captions=['Scores of ICs related to ECG',
                                             'Time Series plots of ICs (ECG)',
                                             'TopoMap of ICs (ECG)',
                                             'Time-locked ECG sources',
                                             'ECG overlay'],
                                   section='ICA - ECG')

    # EOG: skipped for 'ds' data without an EOG channel name; otherwise use
    # the named channel if present, else let find_bads_eog pick its own
    # EOG channels (originally noted as EEG61/EEG62 for fif data)
    if not EoG_ch_name and data_type == 'ds':
        eog_inds = []
    elif EoG_ch_name in raw.info['ch_names']:
        # eye blink artifact - detect EOG by correlation
        eog_inds, scores = ica.find_bads_eog(raw, ch_name=EoG_ch_name)
    else:
        eog_inds, scores = ica.find_bads_eog(raw)

    if len(eog_inds) > 0:
        fig6 = ica.plot_scores(scores, exclude=eog_inds,
                               title=ICA_title % 'eog', show=is_show)
        report.add_figs_to_section(fig6,
                                   captions=['Scores of ICs related to EOG'],
                                   section='ICA - EOG')

        # 2-D scores: one row per EOG channel that was scored
        rs = np.shape(scores)
        if len(rs) > 1:
            for i in range(rs[0]):
                show_picks = np.abs(scores[i]).argsort()[::-1][:5]
                fig7 = ica.plot_sources(raw, show_picks, exclude=eog_inds,
                                        start=raw.times[0],
                                        stop=raw.times[-1],
                                        title=ICA_title % 'eog',
                                        show=is_show)
                fig8 = ica.plot_components(show_picks,
                                           title=ICA_title % 'eog',
                                           colorbar=True, show=is_show)
                # captions in the same order as the figures
                report.add_figs_to_section([fig7, fig8],
                                           captions=['Time Series plots of ICs (EOG)',
                                                     'TopoMap of ICs (EOG)'],
                                           section='ICA - EOG')
        else:
            show_picks = np.abs(scores).argsort()[::-1][:5]
            fig7 = ica.plot_sources(raw, show_picks, exclude=eog_inds,
                                    title=ICA_title % 'eog', show=is_show)
            fig8 = ica.plot_components(show_picks, title=ICA_title % 'eog',
                                       colorbar=True, show=is_show)
            report.add_figs_to_section([fig7, fig8],
                                       captions=['Time Series plots of ICs (EOG)',
                                                 'TopoMap of ICs (EOG)'],
                                       section='ICA - EOG')

        eog_inds = eog_inds[:n_max_eog]
        ica.exclude += eog_inds

        if EoG_ch_name in raw.info['ch_names']:
            eog_evoked = create_eog_epochs(raw, tmin=-.5, tmax=.5,
                                           picks=select_sensors,
                                           ch_name=EoG_ch_name).average()
        else:
            eog_evoked = create_eog_epochs(raw, tmin=-.5, tmax=.5,
                                           picks=select_sensors).average()

        # EOG sources + selection, then EOG cleaning overlay
        fig9 = ica.plot_sources(eog_evoked, exclude=eog_inds, show=is_show)
        fig10 = ica.plot_overlay(eog_evoked, exclude=eog_inds, show=is_show)

        report.add_figs_to_section([fig9, fig10],
                                   captions=['Time-locked EOG sources',
                                             'EOG overlay'],
                                   section='ICA - EOG')

    fig11 = ica.plot_overlay(raw, show=is_show)
    report.add_figs_to_section(fig11, captions=['Signal'],
                               section='Signal quality')

    # plot topographies and time series of all ICA components: pages of
    # n_topo components first, then the remaining ones in pages of up to 5
    # (every component is plotted, including a remainder of 1-5)
    n_ica_components = ica.mixing_matrix_.shape[1]

    n_topo = 10
    n_fig = n_ica_components // n_topo  # integer division on Py2 and Py3
    n_plot = n_ica_components % n_topo

    print('*************** n_fig = ' + str(n_fig) + ' n_plot = ' +
          str(n_plot) + '********************')

    pages = [list(range(n_topo * n, n_topo * (n + 1))) for n in range(n_fig)]
    pages += [list(range(first, min(first + 5, n_ica_components)))
              for first in range(n_fig * n_topo, n_ica_components, 5)]

    t_start = 0
    t_stop = None  # e.g. 60 to plot only the first 60 s
    fig = []
    for page in pages:
        fig.append(ica.plot_components(page, title='ICA components',
                                       show=is_show))
        fig.append(ica.plot_sources(raw, page, start=t_start, stop=t_stop,
                                    title='ICA components', show=is_show))

    for page_fig in fig:
        report.add_figs_to_section(page_fig, captions=['TOPO'],
                                   section='ICA Topo Maps')

    report_filename = os.path.join(subj_path, basename + "-report.html")
    print('******* ' + report_filename)
    report.save(report_filename, open_browser=False, overwrite=True)

    # 3) apply ICA to raw data, downsample and save
    raw_ica_file = os.path.join(subj_path, basename + '-preproc-raw.fif')
    # NOTE(review): ICA.apply presumably modifies `raw` in place and returns
    # it, so `raw_ica` and `raw` likely refer to the same object -- confirm.
    raw_ica = ica.apply(raw)
    raw_ica.resample(sfreq=down_sfreq, npad=0)
    raw_ica.save(raw_ica_file, overwrite=True)

    # save ICA solution (only when it was fitted in this call)
    print(ica_filename)
    if not has_ICA:
        ica.save(ica_filename)

    # 4) save cleaned sensor data
    data, _ = raw_ica[select_sensors, :]

    print(data.shape)
    print(raw.info['sfreq'])

    ts_file = os.path.abspath(basename + "_ica.npy")
    np.save(ts_file, data)
    print('***** TS FILE ' + ts_file + '*****')

    if is_sensor_space:
        return ts_file, channel_coords_file, channel_names_file, raw.info['sfreq']
    return raw_ica_file, channel_coords_file, channel_names_file, raw.info['sfreq']
Exemple #36
0
def run_preprocessing_eog(raw_eeg, tmin=0, remove_artefact=True, show=False,
                          extra_exclude=(3,)):
    """Remove EOG artefacts from an EEG recording using ICA.

    A copy of ``raw_eeg`` is cropped at ``tmin``; a further copy of that is
    band-pass filtered (3-40 Hz) and used to fit an ICA with one component
    per channel. Components matching the ``'Fpz'`` channel are detected with
    ``ICA.find_bads_eog`` on the cropped (unfiltered) data.

    Parameters
    ----------
    raw_eeg : mne raw object
        Raw EEG recording. Only copies of it are cropped/filtered.
    tmin : float
        Start time passed to ``crop`` (default = 0).
    remove_artefact : bool
        Whether to remove the flagged components once found (default = True).
    show : bool
        Whether to show diagnostic plots (sources, components, overlay).
    extra_exclude : iterable of int
        Component indices always excluded in addition to those found by
        ``find_bads_eog`` (duplicates are skipped). Defaults to ``(3,)``,
        which reproduces the previously hard-coded behaviour; pass ``()``
        to rely on automatic detection only.

    Returns
    -------
    mne raw object
        If components were flagged and ``remove_artefact`` is True, the
        ICA-cleaned copy, filtered again at 3-40 Hz. Otherwise the cropped,
        3-40 Hz filtered copy (not the original recording).
    """
    raw_inter = raw_eeg.copy().crop(tmin=tmin)
    # Raw.filter works in place and returns the same object: copy first so
    # raw_inter keeps the unfiltered data used for EOG detection / overlay.
    filtered_raw = raw_inter.copy().filter(l_freq=3., h_freq=40.)
    print(filtered_raw.annotations)
    ica = ICA(n_components=raw_eeg.info.get('nchan'), random_state=9)

    ica.fit(filtered_raw, verbose=False)
    # find which ICs match the EOG pattern
    eog_indices, eog_scores = ica.find_bads_eog(raw_inter,
                                                threshold=1.6,
                                                ch_name='Fpz',
                                                verbose=False)
    eog_indices = list(eog_indices)
    # Manually selected components, added once each.
    for idx in extra_exclude:
        if idx not in eog_indices:
            eog_indices.append(idx)
    ica.exclude = eog_indices

    if show:
        ica.plot_sources(filtered_raw)
        ica.plot_components()
        ica.plot_overlay(raw_inter)

    if not eog_indices:
        print('No matching ICA component')
        return filtered_raw

    print('matching ICA component found')

    # plot ICs applied to raw data, with EOG matches highlighted
    if show:
        ica.plot_sources(raw_inter)

    if not remove_artefact:
        return filtered_raw

    reconst_raw = filtered_raw.copy()
    ica.apply(reconst_raw)
    # NOTE(review): reconst_raw is already 3-40 Hz filtered; this second
    # pass is kept from the original and may be redundant.
    return reconst_raw.filter(3, 40)
Exemple #37
0
    raw = raw.filter(filt_l, filt_h)

    print("Finding events...")
    events = find_events(raw)  # the output of this is a 3 x n_trial np array

    print("Epoching data...")
    epochs = Epochs(raw, events, tmin=tmin, tmax=tmax, decim=5, baseline=None)

    print("Creating ICA object...")
    # apply ICA to the conjoint data
    picks = pick_types(raw.info, meg=True, exclude='bads')
    ica = ICA(n_components=0.9, method='fastica', max_iter=300)

    print("Fitting epochs...")
    # get ica components
    ica.exclude = []
    ica.fit(epochs, picks=picks)

    print("Saving ICA solution...")
    ica.save(ica_fname)  # save solution

# ---------------------------------------------------------------------------
# Subject and file configuration
subject = 'A0307'

# Every per-subject file lives in meg_dir and starts with the subject code.
meg_dir = '/Users/ea84/Dropbox/shepard_sourceloc/{}/'.format(subject)
raw_fname = '{}{}_shepard-raw.fif'.format(meg_dir, subject)
ica_fname = '{}{}_shepard_ica1-ica.fif'.format(meg_dir, subject)
ica_raw_fname = '{}{}_ica_shepard-raw.fif'.format(meg_dir, subject)  # raw with ICA applied

# params
filt_l = 1  # same as acquisition
Exemple #38
0
# Inspect the fitted ICA: component time courses, then topographies.
ica.plot_sources(raw)

ica.plot_components()

# notify when done
# NOTE(review): `say` is a macOS text-to-speech command; elsewhere this
# presumably just fails (os.system returns a non-zero status) — confirm.
os.system('say "... I am ready for you Neil."')

#%% Select components to exclude and zero them out

# Component indices picked by hand from the plots above.
rm_select = [0, 3, 8]

# blinks
# Preview the EEG channels with the selected components removed.
ica.plot_overlay(raw, exclude=rm_select, picks='eeg')

# indices chosen based on various plots above
# NOTE: ica.exclude now refers to the same list object as rm_select, so an
# in-place change to one is visible through the other.
ica.exclude = rm_select

# apply ICA
# ica.apply() changes the Raw object in-place:
ica.apply(raw)

# plot raw data excluding ocular artifacts for quality check
raw.plot()

#%% Fit autoreject

# Fixed-length events with duration `tstep` (tstep is defined elsewhere;
# presumably in seconds as MNE expects — confirm).
events = mne.make_fixed_length_events(raw, duration=tstep)
epochs = mne.Epochs(raw,
                    events,
                    tmin=0.0,
                    tmax=tstep,
Exemple #39
0
def test_ica_additional():
    """Test additional ICA functionality.

    Covers fitting with ``n_components=None`` and with a noise covariance,
    ``corrmap``, bad-filename warnings, decimation, explained-variance based
    component selection, save/read round trips (including ``exclude`` and
    ``labels_``), filtering of sources, score functions, artifact detection
    helpers and raw/epochs export of sources.
    """
    tempdir = _TempDir()
    stop2 = 500
    raw = Raw(raw_fname).crop(1.5, stop, False)
    raw.load_data()
    test_cov = read_cov(test_cov_name)
    events = read_events(event_name)
    picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
                       eog=False, exclude='bads')
    epochs = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), preload=True)
    # test if n_components=None works
    with warnings.catch_warnings(record=True):
        ica = ICA(n_components=None,
                  max_pca_components=None,
                  n_pca_components=None, random_state=0)
        ica.fit(epochs, picks=picks, decim=3)
    # for testing eog functionality
    picks2 = pick_types(raw.info, meg=True, stim=False, ecg=False,
                        eog=True, exclude='bads')
    epochs_eog = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks2,
                        baseline=(None, 0), preload=True)

    test_cov2 = test_cov.copy()
    ica = ICA(noise_cov=test_cov2, n_components=3, max_pca_components=4,
              n_pca_components=4)
    assert_true(ica.info is None)
    with warnings.catch_warnings(record=True):
        ica.fit(raw, picks[:5])
    assert_true(isinstance(ica.info, Info))
    assert_true(ica.n_components_ < 5)

    ica = ICA(n_components=3, max_pca_components=4,
              n_pca_components=4)
    assert_raises(RuntimeError, ica.save, '')
    with warnings.catch_warnings(record=True):
        ica.fit(raw, picks=[1, 2, 3, 4, 5], start=start, stop=stop2)

    # test corrmap
    ica2 = ica.copy()
    corrmap([ica, ica2], (0, 0), threshold='auto', label='blinks', plot=True,
            ch_type="mag")
    corrmap([ica, ica2], (0, 0), threshold=2, plot=False, show=False)
    assert_true(ica.labels_["blinks"] == ica2.labels_["blinks"])
    assert_true(0 in ica.labels_["blinks"])
    plt.close('all')

    # test warnings on bad filenames
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        # write inside tempdir so the file is cleaned up with it
        ica_badname = op.join(tempdir, 'test-bad-name.fif.gz')
        ica.save(ica_badname)
        read_ica(ica_badname)
    assert_naming(w, 'test_ica.py', 2)

    # test decim: fitting with decim must not alter the data it is given
    ica = ICA(n_components=3, max_pca_components=4,
              n_pca_components=4)
    raw_ = raw.copy()
    for _ in range(3):
        raw_.append(raw_)
    n_samples = raw_._data.shape[1]
    with warnings.catch_warnings(record=True):
        ica.fit(raw_, picks=None, decim=3)
    assert_equal(raw_._data.shape[1], n_samples)

    # test expl var
    ica = ICA(n_components=1.0, max_pca_components=4,
              n_pca_components=4)
    with warnings.catch_warnings(record=True):
        ica.fit(raw, picks=None, decim=3)
    assert_true(ica.n_components_ == 4)

    # epochs extraction from raw fit
    assert_raises(RuntimeError, ica.get_sources, epochs)
    # test reading and writing
    test_ica_fname = op.join(tempdir, 'test-ica.fif')
    for cov in (None, test_cov):
        ica = ICA(noise_cov=cov, n_components=2, max_pca_components=4,
                  n_pca_components=4)
        with warnings.catch_warnings(record=True):  # ICA does not converge
            ica.fit(raw, picks=picks, start=start, stop=stop2)
        sources = ica.get_sources(epochs).get_data()
        assert_true(ica.mixing_matrix_.shape == (2, 2))
        assert_true(ica.unmixing_matrix_.shape == (2, 2))
        assert_true(ica.pca_components_.shape == (4, len(picks)))
        assert_true(sources.shape[1] == ica.n_components_)

        for exclude in [[], [0]]:
            # round-trip each exclusion list through save/read
            ica.exclude = exclude
            ica.labels_ = {'foo': [0]}
            ica.save(test_ica_fname)
            ica_read = read_ica(test_ica_fname)
            assert_true(ica.exclude == ica_read.exclude)
            assert_equal(ica.labels_, ica_read.labels_)
            ica.exclude = []
            ica.apply(raw, exclude=[1])
            assert_true(ica.exclude == [])

            ica.exclude = [0, 1]
            ica.apply(raw, exclude=[1])
            assert_true(ica.exclude == [0, 1])

            ica_raw = ica.get_sources(raw)
            assert_true(ica.exclude == [ica_raw.ch_names.index(e) for e in
                                        ica_raw.info['bads']])

        # test filtering
        d1 = ica_raw._data[0].copy()
        with warnings.catch_warnings(record=True):  # dB warning
            ica_raw.filter(4, 20)
        assert_true((d1 != ica_raw._data[0]).any())
        d1 = ica_raw._data[0].copy()
        with warnings.catch_warnings(record=True):  # dB warning
            ica_raw.notch_filter([10])
        assert_true((d1 != ica_raw._data[0]).any())

        ica.n_pca_components = 2
        ica.save(test_ica_fname)
        ica_read = read_ica(test_ica_fname)
        assert_true(ica.n_pca_components == ica_read.n_pca_components)

        # check type consistency
        attrs = ('mixing_matrix_ unmixing_matrix_ pca_components_ '
                 'pca_explained_variance_ _pre_whitener')

        def f(x, y):
            return getattr(x, y).dtype

        for attr in attrs.split():
            assert_equal(f(ica_read, attr), f(ica, attr))

        ica.n_pca_components = 4
        ica_read.n_pca_components = 4

        ica.exclude = []
        ica.save(test_ica_fname)
        ica_read = read_ica(test_ica_fname)
        for attr in ['mixing_matrix_', 'unmixing_matrix_', 'pca_components_',
                     'pca_mean_', 'pca_explained_variance_',
                     '_pre_whitener']:
            assert_array_almost_equal(getattr(ica, attr),
                                      getattr(ica_read, attr))

        assert_true(ica.ch_names == ica_read.ch_names)
        assert_true(isinstance(ica_read.info, Info))

        sources = ica.get_sources(raw)[:, :][0]
        sources2 = ica_read.get_sources(raw)[:, :][0]
        assert_array_almost_equal(sources, sources2)

        _raw1 = ica.apply(raw, exclude=[1])
        _raw2 = ica_read.apply(raw, exclude=[1])
        assert_array_almost_equal(_raw1[:, :][0], _raw2[:, :][0])

    os.remove(test_ica_fname)
    # check score funcs
    for name, func in get_score_funcs().items():
        if name in score_funcs_unsuited:
            continue
        scores = ica.score_sources(raw, target='EOG 061', score_func=func,
                                   start=0, stop=10)
        assert_true(ica.n_components_ == len(scores))

    # check univariate stats
    scores = ica.score_sources(raw, score_func=stats.skew)
    # check exception handling
    assert_raises(ValueError, ica.score_sources, raw,
                  target=np.arange(1))

    params = []
    params += [(None, -1, slice(2), [0, 1])]  # variance, kurtosis idx params
    params += [(None, 'MEG 1531')]  # ECG / EOG channel params
    for idx, ch_name in product(*params):
        ica.detect_artifacts(raw, start_find=0, stop_find=50, ecg_ch=ch_name,
                             eog_ch=ch_name, skew_criterion=idx,
                             var_criterion=idx, kurt_criterion=idx)
    with warnings.catch_warnings(record=True):
        idx, scores = ica.find_bads_ecg(raw, method='ctps')
        assert_equal(len(scores), ica.n_components_)
        idx, scores = ica.find_bads_ecg(raw, method='correlation')
        assert_equal(len(scores), ica.n_components_)
        idx, scores = ica.find_bads_ecg(epochs, method='ctps')
        assert_equal(len(scores), ica.n_components_)
        assert_raises(ValueError, ica.find_bads_ecg, epochs.average(),
                      method='ctps')
        assert_raises(ValueError, ica.find_bads_ecg, raw,
                      method='crazy-coupling')

        idx, scores = ica.find_bads_eog(raw)
        assert_equal(len(scores), ica.n_components_)
        # mark the channel preceding 'EOG 061' as kind 202 (presumably the
        # EOG channel kind) so that two EOG channels are found and the
        # scores come back as a list with one entry per channel
        raw.info['chs'][raw.ch_names.index('EOG 061') - 1]['kind'] = 202
        idx, scores = ica.find_bads_eog(raw)
        assert_true(isinstance(scores, list))
        assert_equal(len(scores[0]), ica.n_components_)

    # check score funcs
    for name, func in get_score_funcs().items():
        if name in score_funcs_unsuited:
            continue
        scores = ica.score_sources(epochs_eog, target='EOG 061',
                                   score_func=func)
        assert_true(ica.n_components_ == len(scores))

    # check univariate stats
    scores = ica.score_sources(epochs, score_func=stats.skew)

    # check exception handling
    assert_raises(ValueError, ica.score_sources, epochs,
                  target=np.arange(1))

    # ecg functionality
    ecg_scores = ica.score_sources(raw, target='MEG 1531',
                                   score_func='pearsonr')

    with warnings.catch_warnings(record=True):  # filter attenuation warning
        ecg_events = ica_find_ecg_events(raw,
                                         sources[np.abs(ecg_scores).argmax()])

    assert_true(ecg_events.ndim == 2)

    # eog functionality
    eog_scores = ica.score_sources(raw, target='EOG 061',
                                   score_func='pearsonr')
    with warnings.catch_warnings(record=True):  # filter attenuation warning
        eog_events = ica_find_eog_events(raw,
                                         sources[np.abs(eog_scores).argmax()])

    assert_true(eog_events.ndim == 2)

    # Test ica fiff export
    ica_raw = ica.get_sources(raw, start=0, stop=100)
    assert_true(ica_raw.last_samp - ica_raw.first_samp == 100)
    assert_true(len(ica_raw._filenames) == 0)  # API consistency
    ica_chans = [ch for ch in ica_raw.ch_names if 'ICA' in ch]
    assert_true(ica.n_components_ == len(ica_chans))
    # write inside tempdir rather than the current working directory
    test_ica_fname = op.join(tempdir, 'test-ica_raw.fif')
    ica.n_components = np.int32(ica.n_components)
    ica_raw.save(test_ica_fname, overwrite=True)
    ica_raw2 = Raw(test_ica_fname, preload=True)
    assert_allclose(ica_raw._data, ica_raw2._data, rtol=1e-5, atol=1e-4)
    ica_raw2.close()
    os.remove(test_ica_fname)

    # Test ica epochs export
    ica_epochs = ica.get_sources(epochs)
    assert_true(ica_epochs.events.shape == epochs.events.shape)
    ica_chans = [ch for ch in ica_epochs.ch_names if 'ICA' in ch]
    assert_true(ica.n_components_ == len(ica_chans))
    assert_true(ica.n_components_ == ica_epochs.get_data().shape[1])
    assert_true(ica_epochs._raw is None)
    assert_true(ica_epochs.preload is True)

    # test float n pca components
    ica.pca_explained_variance_ = np.array([0.2] * 5)
    ica.n_components_ = 0
    for ncomps, expected in [[0.3, 1], [0.9, 4], [1, 1]]:
        ncomps_ = ica._check_n_pca_components(ncomps)
        assert_true(ncomps_ == expected)
Exemple #40
0
print('Fit %d components (explaining at least %0.1f%% of the variance)'
      % (ica.n_components_, 100 * n_components))

# Find onsets of heart beats and blinks. Create epochs around them
ecg_epochs = create_ecg_epochs(raw, tmin=-.3, tmax=.3, preload=False)
eog_epochs = create_eog_epochs(raw, tmin=-.5, tmax=.5, preload=False)

# Find ICA components that correlate with heart beats.
ecg_epochs.decimate(5)
ecg_epochs.load_data()
ecg_epochs.apply_baseline((None, None))
ecg_inds, ecg_scores = ica.find_bads_ecg(ecg_epochs, method='ctps')
ecg_scores = np.abs(ecg_scores)
# Rank components by decreasing |score|; exclude at most n_ecg_components
# of those scoring above 0.05.
rank = np.argsort(ecg_scores)[::-1]
rank = [r for r in rank if ecg_scores[r] > 0.05]
ica.exclude = rank[:n_ecg_components]
print('    Found %d ECG indices' % (len(ecg_inds),))

# Find ICA components that correlate with eye blinks
eog_epochs.decimate(5)
eog_epochs.load_data()
eog_epochs.apply_baseline((None, None))
eog_inds, eog_scores = ica.find_bads_eog(eog_epochs)
# find_bads_eog returns a single score array for one EOG channel and a list
# of arrays for several; atleast_2d makes both cases (n_eog, n_components)
# so the max over channels always yields one score per component.
eog_scores = np.max(np.atleast_2d(np.abs(eog_scores)), axis=0)
# Remove all components with a correlation > 0.1 to the EOG channels and that
# have not already been flagged as ECG components (neither by find_bads_ecg
# nor by the ECG ranking above), so no component is excluded twice.
rank = np.argsort(eog_scores)[::-1]
rank = [r for r in rank
        if eog_scores[r] > 0.1 and r not in ecg_inds and r not in ica.exclude]
ica.exclude += rank[:n_eog_components]
print('    Found %d EOG indices' % (len(eog_inds),))