Beispiel #1
0
def test_plot_ica_overlay():
    """Test plotting of ICA cleaning (overlay on Evoked and Raw)."""
    import matplotlib.pyplot as plt
    raw = _get_raw(preload=True)
    picks = _get_picks(raw)
    ica = ICA(noise_cov=read_cov(cov_fname), n_components=2,
              max_pca_components=3, n_pca_components=3)
    # can't use info.normalize_proj here because of how and when ICA and Epochs
    # objects do picking of Raw data
    with pytest.warns(RuntimeWarning, match='projection'):
        ica.fit(raw, picks=picks)
    # overlays on averaged ECG / EOG epochs (Evoked instances)
    with pytest.warns(RuntimeWarning, match='projection'):
        ecg_epochs = create_ecg_epochs(raw, picks=picks)
    ica.plot_overlay(ecg_epochs.average())
    with pytest.warns(RuntimeWarning, match='projection'):
        eog_epochs = create_eog_epochs(raw, picks=picks)
    ica.plot_overlay(eog_epochs.average())
    # plain data taken out of Raw (not an MNE instance) must be rejected
    pytest.raises(TypeError, ica.plot_overlay, raw[:2, :3][0])
    # overlay on the (preloaded) Raw instance itself
    ica.plot_overlay(raw)
    plt.close('all')

    # smoke test for CTF
    raw = read_raw_fif(raw_ctf_fname)
    raw.apply_gradient_compensation(3)
    picks = pick_types(raw.info, meg=True, ref_meg=False)
    ica = ICA(n_components=2, max_pca_components=3, n_pca_components=3)
    ica.fit(raw, picks=picks)
    with pytest.warns(RuntimeWarning, match='longer than'):
        ecg_epochs = create_ecg_epochs(raw)
    ica.plot_overlay(ecg_epochs.average())
    plt.close('all')
Beispiel #2
0
def test_plot_instance_components():
    """Exercise plot_sources interactively for Raw and Epochs sources."""
    import matplotlib.pyplot as plt
    keys = ('down', 'up', 'right', 'left', 'o', '-', '+', '=', 'pageup',
            'pagedown', 'home', 'end', 'f11', 'b')

    def _interact(fig):
        """Press every key, click the first trace and a y-label, escape."""
        for key in keys:
            fig.canvas.key_press_event(key)
        first_ax = fig.get_axes()[0]
        trace = first_ax.lines[0]
        point = [trace.get_xdata()[0], trace.get_ydata()[0]]
        _fake_click(fig, first_ax, point, 'data')
        _fake_click(fig, first_ax, [-0.1, 0.9])  # click on y-label
        fig.canvas.key_press_event('escape')

    raw = _get_raw()
    picks = _get_picks(raw)
    ica = ICA(noise_cov=read_cov(cov_fname), n_components=2,
              max_pca_components=3, n_pca_components=3)
    with pytest.warns(RuntimeWarning, match='projection'):
        ica.fit(raw, picks=picks)

    # sources of the continuous data
    _interact(ica.plot_sources(raw, exclude=[0], title='Components'))
    plt.close('all')

    # sources of the epoched data
    epochs = _get_epochs()
    _interact(ica.plot_sources(epochs, exclude=[0], title='Components'))
    plt.close('all')
def apply_ica(fname_filtered, n_components=0.99, decim=None):

    ''' Applies ICA to a list of (filtered) raw files. '''

    import mne
    from mne.preprocessing import ICA
    import os


    if isinstance(fname_filtered, list):
        fnfilt = fname_filtered
    else:
        if isinstance(fname_filtered, str):
            fnfilt = list([fname_filtered]) 
        else:
            fnfilt = list(fname_filtered)

    # loop across all filenames
    for fname in fnfilt:                    
        name  = os.path.split(fname)[1]
        print ">>>> perform ICA signal decomposition on :  "+name
        # load filtered data
        raw = mne.io.Raw(fname,preload=True)
        picks = mne.pick_types(raw.info, meg=True, exclude='bads')
        # ICA decomposition
        ica = ICA(n_components=n_components, max_pca_components=None)

        ica.fit(raw, picks=picks, decim=decim, reject={'mag': 5e-12})

        # save ICA object 
        fnica_out = fname.strip('-raw.fif') + '-ica.fif'
        # fnica_out = fname[0:len(fname)-4]+'-ica.fif'
        ica.save(fnica_out)
def run_ica(method):
    """Fit a 20-component ICA with ``method`` and plot its components.

    Relies on ``raw``, ``picks`` and ``reject`` from the enclosing module;
    the plot title reports how long the fit took.
    """
    decomposition = ICA(n_components=20, method=method, random_state=0)
    started = time()
    decomposition.fit(raw, picks=picks, reject=reject)
    elapsed = time() - started
    decomposition.plot_components(
        title='ICA decomposition using %s (took %.1fs)' % (method, elapsed))
Beispiel #5
0
def test_ica_ctf():
    """Run ICA on CTF data with and without gradient compensation."""
    method = 'fastica'
    raw = read_raw_ctf(ctf_fname, preload=True)
    events = make_fixed_length_events(raw, 99999)

    def _make_ica():
        """Build the small, deliberately non-converging ICA used here."""
        return ICA(n_components=2, random_state=0, max_iter=2, method=method)

    for grade in (0, 1):
        raw.apply_gradient_compensation(grade)
        epochs = Epochs(raw, events, None, -0.2, 0.2, preload=True)
        evoked = epochs.average()

        # fit on continuous and on epoched data
        for fit_inst in (raw, epochs):
            ica = _make_ica()
            with pytest.warns(UserWarning, match='did not converge'):
                ica.fit(fit_inst)

        # the last fitted ICA must apply to / decompose every instance type
        for target in (raw, epochs, evoked):
            ica.apply(target)
            ica.get_sources(target)

    # mixed case: fit at grade 0, then change the data's grade to 1
    raw.apply_gradient_compensation(0)
    ica = _make_ica()
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(raw)
    raw.apply_gradient_compensation(1)
    epochs = Epochs(raw, events, None, -0.2, 0.2, preload=True)
    evoked = epochs.average()
    for target in (raw, epochs, evoked):
        for func in (ica.apply, ica.get_sources):
            with pytest.raises(RuntimeError,
                               match='Compensation grade of ICA'):
                func(target)
Beispiel #6
0
def test_n_components_and_max_pca_components_none(method):
    """Test n_components and max_pca_components=None."""
    import warnings

    _skip_check_picard(method)
    raw = read_raw_fif(raw_fname).crop(1.5, stop).load_data()
    events = read_events(event_name)
    picks = pick_types(raw.info, eeg=True, meg=False)
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), preload=True)

    max_pca_components = None
    n_components = None
    random_state = 12345

    tempdir = _TempDir()
    output_fname = op.join(tempdir, 'test_ica-ica.fif')
    ica = ICA(max_pca_components=max_pca_components, method=method,
              n_components=n_components, random_state=random_state)
    # A convergence warning may or may not be emitted. pytest.warns(None)
    # is deprecated since pytest 7 and an error in pytest 8, so record and
    # ignore warnings instead.
    with warnings.catch_warnings(record=True):
        ica.fit(epochs)
    ica.save(output_fname)

    ica = read_ica(output_fname)

    # ICA.fit() replaced max_pca_components, which was previously None,
    # with the appropriate integer value.
    assert_equal(ica.max_pca_components, epochs.info['nchan'])
    assert ica.n_components is None
Beispiel #7
0
def test_n_components_none():
    """Test n_components=None with an explicit max_pca_components."""
    raw = read_raw_fif(raw_fname).crop(1.5, stop).load_data()
    events = read_events(event_name)
    picks = pick_types(raw.info, eeg=True, meg=False)
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), preload=True)

    max_pca_components = 10
    n_components = None
    random_state = 12345

    tempdir = _TempDir()
    output_fname = op.join(tempdir, 'test_ica-ica.fif')

    ica = ICA(max_pca_components=max_pca_components,
              n_components=n_components, random_state=random_state)
    with warnings.catch_warnings(record=True):  # convergence
        ica.fit(epochs)
    ica.save(output_fname)

    ica = read_ica(output_fname)

    # The explicitly given max_pca_components (10) and n_components=None
    # must both survive fitting and the save/read round trip.
    assert_equal(ica.max_pca_components, 10)
    assert_is_none(ica.n_components)
Beispiel #8
0
def test_eog_channel(method):
    """Check that an EOG channel enters the ICA exactly when it is picked."""
    _skip_check_picard(method)
    raw = read_raw_fif(raw_fname, preload=True)
    events = read_events(event_name)
    epoch_picks = pick_types(raw.info, meg=True, stim=True, ecg=False,
                             eog=True, exclude='bads')
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=epoch_picks,
                    baseline=(None, 0), preload=True)
    ica = ICA(n_components=0.9, method=method)

    def _has_eog():
        """Return True if any fitted ICA channel name contains 'EOG'."""
        return any('EOG' in ch_name for ch_name in ica.ch_names)

    # four MEG channels plus the EOG channel(s): EOG must be kept
    for inst in (raw, epochs):
        meg = pick_types(inst.info, meg=True, stim=False, ecg=False,
                         eog=False, exclude='bads')[:4]
        eog = pick_types(inst.info, meg=False, stim=False, ecg=False,
                         eog=True, exclude='bads')
        ica.fit(inst, picks=np.append(meg, eog))
        assert _has_eog()

    # five MEG channels only: no EOG channel expected
    for inst in (raw, epochs):
        meg_only = pick_types(inst.info, meg=True, stim=False, ecg=False,
                              eog=False, exclude='bads')[:5]
        ica.fit(inst, picks=meg_only)
        assert not _has_eog()
Beispiel #9
0
def test_ica_eeg():
    """Test ICA on EEG (FIF, EEGLAB and CTF data)."""
    import warnings

    method = 'fastica'
    raw_fif = read_raw_fif(fif_fname, preload=True)
    with pytest.warns(RuntimeWarning, match='events'):
        raw_eeglab = read_raw_eeglab(input_fname=eeglab_fname,
                                     montage=eeglab_montage, preload=True)
    for raw in [raw_fif, raw_eeglab]:
        events = make_fixed_length_events(raw, 99999, start=0, stop=0.3,
                                          duration=0.1)
        picks_meg = pick_types(raw.info, meg=True, eeg=False)[:2]
        picks_eeg = pick_types(raw.info, meg=False, eeg=True)[:2]
        picks_all = list(picks_meg) + list(picks_eeg)
        epochs = Epochs(raw, events, None, -0.1, 0.1, preload=True)
        evoked = epochs.average()

        for picks in [picks_meg, picks_eeg, picks_all]:
            if len(picks) == 0:
                continue
            # test fit
            for inst in [raw, epochs]:
                ica = ICA(n_components=2, random_state=0, max_iter=2,
                          method=method)
                # max_iter=2 may or may not trigger a convergence warning;
                # pytest.warns(None) is deprecated (pytest 7) and an error
                # in pytest 8, so record and ignore warnings instead.
                with warnings.catch_warnings(record=True):
                    ica.fit(inst, picks=picks)

            # test apply and get_sources
            for inst in [raw, epochs, evoked]:
                ica.apply(inst)
                ica.get_sources(inst)

    with pytest.warns(RuntimeWarning, match='MISC channel'):
        raw = read_raw_ctf(ctf_fname2, preload=True)
    events = make_fixed_length_events(raw, 99999, start=0, stop=0.2,
                                      duration=0.1)
    picks_meg = pick_types(raw.info, meg=True, eeg=False)[:2]
    picks_eeg = pick_types(raw.info, meg=False, eeg=True)[:2]
    # pick_types returns index arrays, so ``picks_meg + picks_eeg`` would
    # add indices element-wise (or fail to broadcast) instead of
    # concatenating them; concatenate as lists like the FIF branch does.
    picks_all = list(picks_meg) + list(picks_eeg)
    for comp in [0, 1]:
        raw.apply_gradient_compensation(comp)
        epochs = Epochs(raw, events, None, -0.1, 0.1, preload=True)
        evoked = epochs.average()

        for picks in [picks_meg, picks_eeg, picks_all]:
            if len(picks) == 0:
                continue
            # test fit
            for inst in [raw, epochs]:
                ica = ICA(n_components=2, random_state=0, max_iter=2,
                          method=method)
                # NOTE(review): ``picks`` is not passed to fit here, so each
                # iteration fits the default channel selection; confirm
                # whether that is intended for compensated CTF data.
                with warnings.catch_warnings(record=True):
                    ica.fit(inst)

            # test apply and get_sources
            for inst in [raw, epochs, evoked]:
                ica.apply(inst)
                ica.get_sources(inst)
Beispiel #10
0
def test_plot_ica_properties():
    """Exercise ICA.plot_properties on Raw/Epochs, incl. argument checks."""
    import matplotlib.pyplot as plt

    raw = _get_raw(preload=True)
    raw.add_proj([], remove_existing=True)
    events = _get_events()
    first_six = _get_picks(raw)[:6]
    raw.pick_channels([raw.ch_names[idx] for idx in first_six])

    with warnings.catch_warnings(record=True):  # bad proj
        epochs = Epochs(raw, events[:10], event_id, tmin, tmax,
                        baseline=(None, 0), preload=True)

    ica = ICA(noise_cov=read_cov(cov_fname), n_components=2,
              max_pca_components=2, n_pca_components=2)
    with warnings.catch_warnings(record=True):  # bad proj
        ica.fit(raw)

    # the properties layout provides five axes
    _, layout_axes = _create_properties_layout()
    assert_equal(len(layout_axes), 5)

    # cheap topomap settings keep the plots fast
    fast_topo = {'res': 8, 'contours': 0, "sensors": False}
    ica.plot_properties(raw, picks=0, topomap_args=fast_topo)
    ica.plot_properties(epochs, picks=1, dB=False, plot_std=1.5,
                        topomap_args=fast_topo)
    ica.plot_properties(epochs, picks=1, image_args={'sigma': 1.5},
                        topomap_args={'res': 10, 'colorbar': True},
                        psd_args={'fmax': 65.}, plot_std=False,
                        figsize=[4.5, 4.5])
    plt.close('all')

    invalid_calls = [
        (ica.plot_properties, (epochs,), {'dB': list('abc')}),
        (ica.plot_properties, (epochs,), {'plot_std': []}),
        (ica.plot_properties, (ica,), {}),
        (ica.plot_properties, ([0.2],), {}),
        (plot_ica_properties, (epochs, epochs), {}),
        (ica.plot_properties, (epochs,), {'psd_args': 'not dict'}),
    ]
    for func, args, kwargs in invalid_calls:
        assert_raises(ValueError, func, *args, **kwargs)

    _, grid = plt.subplots(2, 3)
    five_axes = grid.ravel()[:-1]
    ica.plot_properties(epochs, picks=1, axes=five_axes,
                        topomap_args=fast_topo)
    figs = ica.plot_properties(raw, picks=[0, 1], topomap_args=fast_topo)
    assert_equal(len(figs), 2)
    assert_raises(ValueError, plot_ica_properties, epochs, ica, picks=[0, 1],
                  axes=five_axes)
    assert_raises(ValueError, ica.plot_properties, epochs, axes='not axes')
    plt.close('all')

    # Test merging grads.
    raw = _get_raw(preload=True)
    grad_picks = pick_types(raw.info, meg='grad')[:10]
    ica = ICA(n_components=2)
    ica.fit(raw, picks=grad_picks)
    ica.plot_properties(raw)
    plt.close('all')
Beispiel #11
0
 def decompose(self, X, y=None):
     """Decompose ``X`` with extended-infomax ICA.

     Parameters
     ----------
     X : array
         Data handed to ``RawArray`` as ``X.T``; presumably shaped
         (n_samples, n_channels) -- TODO confirm against callers.
     y : ignored

     Returns
     -------
     scores, filters, topographies
         ``scores`` come from ``self.get_scores(X, filters)``; ``filters``
         are the spatial unmixing filters (one column per ICA component);
         ``topographies`` are the corresponding spatial patterns.
     """
     raw_inst = RawArray(X.T, create_info(self.channel_names, self.fs,
                                          'eeg', None))
     ica = ICA(method='extended-infomax')
     ica.fit(raw_inst)
     filters = np.dot(ica.unmixing_matrix_,
                      ica.pca_components_[:ica.n_components_]).T
     # pinv equals inv for a square, full-rank filter matrix but also works
     # when ICA keeps fewer components than channels (non-square filters),
     # where np.linalg.inv would raise LinAlgError.
     topographies = np.linalg.pinv(filters).T
     scores = self.get_scores(X, filters)
     return scores, filters, topographies
Beispiel #12
0
def test_ica_full_data_recovery(method):
    """Test recovery of full data when no source is rejected."""
    import warnings

    # Most basic recovery
    _skip_check_picard(method)
    raw = read_raw_fif(raw_fname).crop(0.5, stop).load_data()
    events = read_events(event_name)
    picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
                       eog=False, exclude='bads')[:10]
    with pytest.warns(RuntimeWarning, match='projection'):
        epochs = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks,
                        baseline=(None, 0), preload=True)
    evoked = epochs.average()
    n_channels = 5
    data = raw._data[:n_channels].copy()
    data_epochs = epochs.get_data()
    data_evoked = evoked.data
    raw.set_annotations(Annotations([0.5], [0.5], ['BAD']))
    # (n_components, n_pca_components, exact recovery expected): keeping
    # all n_channels PCA components must reproduce the data, keeping only
    # half of them must not.
    cases = [(2, n_channels, True), (2, n_channels // 2, False)]
    for n_components, n_pca_components, ok in cases:
        ica = ICA(n_components=n_components, random_state=0,
                  max_pca_components=n_pca_components,
                  n_pca_components=n_pca_components,
                  method=method, max_iter=1)
        # requires a UserWarning (e.g. non-convergence with max_iter=1)
        with pytest.warns(UserWarning, match=None):
            ica.fit(raw, picks=list(range(n_channels)))
        raw2 = ica.apply(raw.copy(), exclude=[])
        if ok:
            assert_allclose(data[:n_channels], raw2._data[:n_channels],
                            rtol=1e-10, atol=1e-15)
        else:
            diff = np.abs(data[:n_channels] - raw2._data[:n_channels])
            assert (np.max(diff) > 1e-14)

        ica = ICA(n_components=n_components, method=method,
                  max_pca_components=n_pca_components,
                  n_pca_components=n_pca_components, random_state=0)
        # may or may not warn; pytest.warns(None) is deprecated (pytest 7)
        # and an error in pytest 8, so record and ignore warnings instead
        with warnings.catch_warnings(record=True):
            ica.fit(epochs, picks=list(range(n_channels)))
        epochs2 = ica.apply(epochs.copy(), exclude=[])
        data2 = epochs2.get_data()[:, :n_channels]
        if ok:
            assert_allclose(data_epochs[:, :n_channels], data2,
                            rtol=1e-10, atol=1e-15)
        else:
            diff = np.abs(data_epochs[:, :n_channels] - data2)
            assert (np.max(diff) > 1e-14)

        evoked2 = ica.apply(evoked.copy(), exclude=[])
        data2 = evoked2.data[:n_channels]
        if ok:
            assert_allclose(data_evoked[:n_channels], data2,
                            rtol=1e-10, atol=1e-15)
        else:
            diff = np.abs(data_evoked[:n_channels] - data2)
            assert (np.max(diff) > 1e-14)
    pytest.raises(ValueError, ICA, method='pizza-decomposision')
Beispiel #13
0
def test_ica_full_data_recovery():
    """Test recovery of full data when no source is rejected."""
    # Most basic recovery
    raw = Raw(raw_fname).crop(0.5, stop, False)
    raw.load_data()
    events = read_events(event_name)
    picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
                       eog=False, exclude='bads')[:10]
    with warnings.catch_warnings(record=True):  # bad proj
        epochs = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks,
                        baseline=(None, 0), preload=True)
    evoked = epochs.average()
    n_channels = 5
    data = raw._data[:n_channels].copy()
    data_epochs = epochs.get_data()
    data_evoked = evoked.data
    for method in ['fastica']:
        # (n_components, n_pca_components, exact recovery expected)
        stuff = [(2, n_channels, True), (2, n_channels // 2, False)]
        for n_components, n_pca_components, ok in stuff:
            ica = ICA(n_components=n_components,
                      max_pca_components=n_pca_components,
                      n_pca_components=n_pca_components,
                      method=method, max_iter=1)
            with warnings.catch_warnings(record=True):
                ica.fit(raw, picks=list(range(n_channels)))
            raw2 = ica.apply(raw, exclude=[], copy=True)
            if ok:
                assert_allclose(data[:n_channels], raw2._data[:n_channels],
                                rtol=1e-10, atol=1e-15)
            else:
                diff = np.abs(data[:n_channels] - raw2._data[:n_channels])
                assert_true(np.max(diff) > 1e-14)

            # use the same decomposition method as the Raw fit above so the
            # loop over ``method`` actually covers the Epochs/Evoked path
            ica = ICA(n_components=n_components,
                      max_pca_components=n_pca_components,
                      n_pca_components=n_pca_components,
                      method=method)
            with warnings.catch_warnings(record=True):
                ica.fit(epochs, picks=list(range(n_channels)))
            epochs2 = ica.apply(epochs, exclude=[], copy=True)
            data2 = epochs2.get_data()[:, :n_channels]
            if ok:
                assert_allclose(data_epochs[:, :n_channels], data2,
                                rtol=1e-10, atol=1e-15)
            else:
                diff = np.abs(data_epochs[:, :n_channels] - data2)
                assert_true(np.max(diff) > 1e-14)

            evoked2 = ica.apply(evoked, exclude=[], copy=True)
            data2 = evoked2.data[:n_channels]
            if ok:
                assert_allclose(data_evoked[:n_channels], data2,
                                rtol=1e-10, atol=1e-15)
            else:
                diff = np.abs(data_evoked[:n_channels] - data2)
                assert_true(np.max(diff) > 1e-14)
    assert_raises(ValueError, ICA, method='pizza-decomposision')
Beispiel #14
0
def test_plot_ica_scores():
    """Plot ICA scores and check the score-count validation."""
    raw = _get_raw()
    meg_picks = pick_types(raw.info, meg=True, eeg=False, stim=False,
                           ecg=False, eog=False, exclude="bads")
    noise_cov = read_cov(cov_fname)
    ica = ICA(noise_cov=noise_cov, n_components=2, max_pca_components=3,
              n_pca_components=3)
    ica.fit(raw, picks=meg_picks)
    ica.plot_scores([0.3, 0.2], axhline=[0.1, -0.1])
    # one score for a two-component ICA must be rejected
    assert_raises(ValueError, ica.plot_scores, [0.2])
    plt.close("all")
Beispiel #15
0
def test_plot_ica_scores():
    """Plot ICA scores and check the score-count validation."""
    raw = _get_raw()
    fit_picks = _get_picks(raw)
    noise_cov = read_cov(cov_fname)
    ica = ICA(noise_cov=noise_cov, n_components=2, max_pca_components=3,
              n_pca_components=3)
    ica.fit(raw, picks=fit_picks)
    ica.plot_scores([0.3, 0.2], axhline=[0.1, -0.1])
    # one score for a two-component ICA must be rejected
    assert_raises(ValueError, ica.plot_scores, [0.2])
    plt.close('all')
Beispiel #16
0
def test_plot_ica_components():
    """Test plotting of ICA solutions, including interactive clicks."""
    import matplotlib.pyplot as plt
    res = 8
    fast_test = {"res": res, "contours": 0, "sensors": False}
    raw = _get_raw()
    ica = ICA(noise_cov=read_cov(cov_fname), n_components=2,
              max_pca_components=3, n_pca_components=3)
    ica_picks = _get_picks(raw)
    with warnings.catch_warnings(record=True):
        ica.fit(raw, picks=ica_picks)
    with warnings.catch_warnings(record=True):
        # Set the filter inside the context manager: catch_warnings restores
        # the previous filters on exit, so the 'always' rule no longer leaks
        # into the rest of the test session.
        warnings.simplefilter('always', UserWarning)
        for components in [0, [0], [0, 1], [0, 1] * 2, None]:
            ica.plot_components(components, image_interp='bilinear',
                                colorbar=True, **fast_test)
        plt.close('all')

        # test interactive mode (passing 'inst' arg)
        ica.plot_components([0, 1], image_interp='bilinear', inst=raw, res=16)
        fig = plt.gcf()

        # test title click
        # ----------------
        lbl = fig.axes[1].get_label()
        ica_idx = int(lbl[-3:])
        titles = [ax.title for ax in fig.axes]
        title_pos_midpoint = (titles[1].get_window_extent().extents
                              .reshape((2, 2)).mean(axis=0))
        # first click adds to exclude
        _fake_click(fig, fig.axes[1], title_pos_midpoint, xform='pix')
        assert ica_idx in ica.exclude
        # clicking again removes from exclude
        _fake_click(fig, fig.axes[1], title_pos_midpoint, xform='pix')
        assert ica_idx not in ica.exclude

        # test topo click
        # ---------------
        _fake_click(fig, fig.axes[1], (0., 0.), xform='data')

        c_fig = plt.gcf()
        labels = [ax.get_label() for ax in c_fig.axes]

        for label in ['topomap', 'image', 'erp', 'spectrum', 'variance']:
            assert_true(label in labels)

        topomap_ax = c_fig.axes[labels.index('topomap')]
        title = topomap_ax.get_title()
        assert_true(lbl == title)

    ica.info = None
    assert_raises(ValueError, ica.plot_components, 1)
    assert_raises(RuntimeError, ica.plot_components, 1, ch_type='mag')
    plt.close('all')
Beispiel #17
0
def apply_ica(fname_filtered, n_components=0.99, decim=None,
              reject={'mag': 5e-12}, ica_method='fastica',
              flow=None, fhigh=None, verbose=True):

    ''' Applies ICA to a list of (filtered) raw files. '''

    from mne.preprocessing import ICA

    fnfilt = get_files_from_list(fname_filtered)

    # loop across all filenames
    for fname in fnfilt:
        name = os.path.split(fname)[1]
        print ">>>> perform ICA signal decomposition on :  " + name
        # load filtered data
        raw = mne.io.Raw(fname, preload=True)
        picks = mne.pick_types(raw.info, meg=True, ref_meg=False, exclude='bads')

        # check if data to estimate the optimal
        # de-mixing matrix should be filtered
        if flow or fhigh:
            from jumeg.filter import jumeg_filter

            # define filter type
            if not flow:
                filter_type = 'lp'
                filter_info = "     --> filter parameter    : filter type=low pass %dHz" % flow
            elif not fhigh:
                filter_type = 'hp'
                filter_info = "     --> filter parameter    : filter type=high pass %dHz" % flow
            else:
                filter_type = 'bp'
                filter_info = "     --> filter parameter: filter type=band pass %d-%dHz" % (flow, fhigh)

            if verbose:
                print ">>>> NOTE: Optimal cleaning parameter are estimated from filtered data!"
                print filter_info

            fi_mne_notch = jumeg_filter(fcut1=flow, fcut2=fhigh, filter_type=filter_type,
                                        remove_dcoffset=False,
                                        sampling_frequency=raw.info['sfreq'])
            fi_mne_notch.apply_filter(raw._data, picks=picks)

        # ICA decomposition
        ica = ICA(method=ica_method, n_components=n_components,
                  max_pca_components=None)

        ica.fit(raw, picks=picks, decim=decim, reject=reject)

        # save ICA object
        fnica_out = fname[:fname.rfind(ext_raw)] + ext_ica
        # fnica_out = fname[0:len(fname)-4]+'-ica.fif'
        ica.save(fnica_out)
Beispiel #18
0
def test_plot_ica_properties():
    """Exercise ICA.plot_properties on Raw/Epochs, incl. argument checks."""
    import matplotlib.pyplot as plt

    raw = _get_raw(preload=True)
    raw.add_proj([], remove_existing=True)
    events = _get_events()
    first_six = _get_picks(raw)[:6]
    raw.pick_channels([raw.ch_names[idx] for idx in first_six])

    with warnings.catch_warnings(record=True):  # bad proj
        epochs = Epochs(raw, events[:10], event_id, tmin, tmax,
                        baseline=(None, 0), preload=True)

    ica = ICA(noise_cov=read_cov(cov_fname), n_components=2,
              max_pca_components=2, n_pca_components=2)
    with warnings.catch_warnings(record=True):  # bad proj
        ica.fit(raw)

    # the properties layout provides five axes
    _, layout_axes = _create_properties_layout()
    assert_equal(len(layout_axes), 5)

    topo = {"res": 10}
    ica.plot_properties(raw, picks=0, topomap_args=topo)
    ica.plot_properties(epochs, picks=1, dB=False, plot_std=1.5,
                        topomap_args=topo)
    ica.plot_properties(epochs, picks=1, image_args={"sigma": 1.5},
                        topomap_args={"res": 10, "colorbar": True},
                        psd_args={"fmax": 65.0}, plot_std=False,
                        figsize=[4.5, 4.5])
    plt.close("all")

    invalid_calls = [
        (ica.plot_properties, (epochs,), {"dB": list("abc")}),
        (ica.plot_properties, (epochs,), {"plot_std": []}),
        (ica.plot_properties, (ica,), {}),
        (ica.plot_properties, ([0.2],), {}),
        (plot_ica_properties, (epochs, epochs), {}),
        (ica.plot_properties, (epochs,), {"psd_args": "not dict"}),
    ]
    for func, args, kwargs in invalid_calls:
        assert_raises(ValueError, func, *args, **kwargs)

    _, grid = plt.subplots(2, 3)
    five_axes = grid.ravel()[:-1]
    ica.plot_properties(epochs, picks=1, axes=five_axes)
    figs = ica.plot_properties(raw, picks=[0, 1], topomap_args=topo)
    assert_equal(len(figs), 2)
    assert_raises(ValueError, plot_ica_properties, epochs, ica, picks=[0, 1],
                  axes=five_axes)
    assert_raises(ValueError, ica.plot_properties, epochs, axes="not axes")
    plt.close("all")
Beispiel #19
0
def test_plot_instance_components():
    """Exercise plot_sources interactively for Raw and Epochs sources."""
    import matplotlib.pyplot as plt

    keys = ("down", "up", "right", "left", "o", "-", "+", "=", "pageup",
            "pagedown", "home", "end", "f11")

    def _interact(fig):
        """Press every key, click the first trace and a y-label, escape."""
        for key in keys:
            fig.canvas.key_press_event(key)
        first_ax = fig.get_axes()[0]
        trace = first_ax.lines[0]
        point = [trace.get_xdata()[0], trace.get_ydata()[0]]
        _fake_click(fig, first_ax, point, "data")
        _fake_click(fig, first_ax, [-0.1, 0.9])  # click on y-label
        fig.canvas.key_press_event("escape")

    raw = _get_raw()
    picks = _get_picks(raw)
    ica = ICA(noise_cov=read_cov(cov_fname), n_components=2,
              max_pca_components=3, n_pca_components=3)
    with warnings.catch_warnings(record=True):  # bad proj
        ica.fit(raw, picks=picks)

    # sources of the continuous data
    _interact(ica.plot_sources(raw, exclude=[0], title="Components"))
    plt.close("all")

    # sources of the epoched data
    epochs = _get_epochs()
    _interact(ica.plot_sources(epochs, exclude=[0], title="Components"))
    plt.close("all")
Beispiel #20
0
def test_plot_instance_components():
    """Exercise plot_sources interactively for Raw and Epochs sources."""
    import matplotlib.pyplot as plt
    key_sequence = ['down', 'up', 'right', 'left', 'o', '-', '+', '=',
                    'pageup', 'pagedown', 'home', 'end', 'f11']

    def _drive(fig):
        """Send all keys, click the first trace and a y-label, escape."""
        canvas = fig.canvas
        for key in key_sequence:
            canvas.key_press_event(key)
        axes0 = fig.get_axes()[0]
        first_line = axes0.lines[0]
        _fake_click(fig, axes0,
                    [first_line.get_xdata()[0], first_line.get_ydata()[0]],
                    'data')
        _fake_click(fig, axes0, [-0.1, 0.9])  # click on y-label
        canvas.key_press_event('escape')

    raw = _get_raw()
    picks = _get_picks(raw)
    ica = ICA(noise_cov=read_cov(cov_fname), n_components=2,
              max_pca_components=3, n_pca_components=3)
    with warnings.catch_warnings(record=True):  # bad proj
        ica.fit(raw, picks=picks)

    for make_inst in (lambda: raw, _get_epochs):
        inst = make_inst()
        _drive(ica.plot_sources(inst, exclude=[0], title='Components'))
        plt.close('all')
Beispiel #21
0
def test_ica_reject_buffer():
    """Test ICA data raw buffer rejection."""
    raw = io.Raw(raw_fname, preload=True).crop(0, stop, False).crop(1.5)
    picks = pick_types(raw.info, meg=True, stim=False, ecg=False, eog=False,
                       exclude="bads")
    ica = ICA(n_components=3, max_pca_components=4, n_pca_components=4)
    raw._data[2, 1000:1005] = 5e-12
    drop_log = op.join(op.dirname(tempdir), "ica_drop.log")
    set_log_file(drop_log, overwrite=True)
    try:
        with warnings.catch_warnings(record=True):
            ica.fit(raw, picks[:5], reject=dict(mag=2.5e-12), decim=2,
                    tstep=0.01, verbose=True)
    finally:
        # Send MNE logging back to stdout; otherwise every later test keeps
        # writing into this log file.
        set_log_file(None)
    assert_true(raw._data[:5, ::2].shape[1] - 4 == ica.n_samples_)
    with open(drop_log) as fid:
        log = [line for line in fid if "detected" in line]
    assert_equal(len(log), 1)
Beispiel #22
0
def test_plot_ica_components():
    """Test plotting of ICA solutions."""
    raw = _get_raw()
    ica = ICA(noise_cov=read_cov(cov_fname), n_components=2,
              max_pca_components=3, n_pca_components=3)
    ica_picks = pick_types(raw.info, meg=True, eeg=False, stim=False,
                           ecg=False, eog=False, exclude="bads")
    ica.fit(raw, picks=ica_picks)
    with warnings.catch_warnings(record=True):
        # Set the filter *inside* catch_warnings so it is reverted on exit
        # instead of changing the warning filters for every later test.
        warnings.simplefilter("always", UserWarning)
        for components in [0, [0], [0, 1], [0, 1] * 7, None]:
            ica.plot_components(components)
    # Without measurement info, plotting must fail.
    ica.info = None
    assert_raises(RuntimeError, ica.plot_components, 1)
    plt.close("all")
Beispiel #23
0
def test_ica_reject_buffer():
    """Test ICA data raw buffer rejection."""
    raw = read_raw_fif(raw_fname).crop(1.5, stop).load_data()
    meg_picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
                           eog=False, exclude='bads')
    ica = ICA(n_components=3, max_pca_components=4, n_pca_components=4)
    # plant a brief high-amplitude burst on channel index 2
    raw._data[2, 1000:1005] = 5e-12
    reject = dict(mag=2.5e-12)
    with catch_logging() as drop_log:
        with warnings.catch_warnings(record=True):
            ica.fit(raw, meg_picks[:5], reject=reject, decim=2,
                    tstep=0.01, verbose=True)
        # four decimated samples are expected to be dropped
        expected_n_samples = raw._data[:5, ::2].shape[1] - 4
        assert_true(expected_n_samples == ica.n_samples_)
    log_lines = drop_log.getvalue().split('\n')
    detected = [entry for entry in log_lines if 'detected' in entry]
    assert_equal(len(detected), 1)
Beispiel #24
0
def test_ica_labels():
    """Test ICA labels."""
    # The CTF data are uniquely well suited to testing the ICA.find_bads_
    # methods
    raw = read_raw_ctf(ctf_fname, preload=True)

    # Fit a small ICA on the reference-MEG channels only and append its
    # sources to ``raw`` as extra channels.
    ref_ica = ICA(n_components=2, random_state=0, max_iter=2,
                  allow_ref_meg=True)
    with pytest.warns(UserWarning, match='did not converge'):
        ref_ica.fit(raw.copy().pick_types(meg=False, ref_meg=True))
    ref_sources = ref_ica.get_sources(raw)
    # the 'REF_' prefix lets find_bads_ref auto-detect these channels
    ref_sources.rename_channels(
        {name: 'REF_' + name for name in ref_sources.ch_names})
    raw.add_channels([ref_sources])
    # relabel three EEG channels as EOG / ECG
    raw.set_channel_types(
        {'EEG057': 'eog', 'EEG058': 'eog', 'EEG059': 'ecg'})

    ica = ICA(n_components=4, random_state=0, max_iter=2, method='fastica')
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(raw)

    def _assert_channel_labels(kind, channel_picks):
        """Assert a '<kind>/<idx>/<channel name>' label exists per pick."""
        for idx, ch in enumerate(channel_picks):
            assert f'{kind}/{idx}/{raw.ch_names[ch]}' in ica.labels_

    def _assert_keys(present=(), absent=()):
        """Assert which keys are / are not in ``ica.labels_``."""
        for key in present:
            assert key in ica.labels_
        for key in absent:
            assert key not in ica.labels_

    # EOG detection only adds EOG labels.
    ica.find_bads_eog(raw, l_freq=None, h_freq=None)
    eog_picks = list(pick_types(raw.info, meg=False, eog=True))
    _assert_channel_labels('eog', eog_picks)
    _assert_keys(present=('eog',),
                 absent=('ecg', 'ref_meg', 'ecg/ECG-MAG'))

    # Correlation-based ECG detection adds ECG labels.
    ica.find_bads_ecg(raw, l_freq=None, h_freq=None, method='correlation')
    ecg_picks = list(pick_types(raw.info, meg=False, ecg=True))
    _assert_channel_labels('ecg', ecg_picks)
    _assert_keys(present=('ecg', 'eog'),
                 absent=('ref_meg', 'ecg/ECG-MAG'))

    # Reference detection uses the appended REF_ICA channels.
    ica.find_bads_ref(raw, l_freq=None, h_freq=None)
    ref_picks = pick_channels_regexp(raw.ch_names, 'REF_ICA*')
    _assert_channel_labels('ref_meg', ref_picks)
    _assert_keys(present=('ecg', 'eog', 'ref_meg'),
                 absent=('ecg/ECG-MAG',))

    # Default ECG detection finally adds the 'ecg/ECG-MAG' key.
    ica.find_bads_ecg(raw, l_freq=None, h_freq=None)
    _assert_keys(present=('ecg', 'eog', 'ref_meg', 'ecg/ECG-MAG'))
Beispiel #25
0
def test_plot_ica_sources():
    """Test plotting of ICA panel."""
    raw = read_raw_fif(raw_fname).crop(0, 1).load_data()
    picks = _get_picks(raw)
    epochs = _get_epochs()
    kept_names = [raw.ch_names[k] for k in picks]
    raw.pick_channels(kept_names)
    meg_picks = pick_types(raw.info, meg=True, eeg=False, stim=False,
                           ecg=False, eog=False, exclude='bads')
    ica = ICA(n_components=2, max_pca_components=3, n_pca_components=3)
    ica.fit(raw, picks=meg_picks)

    # After pressing escape, ``exclude`` is expected to be unchanged.
    ica.exclude = [1]
    fig = ica.plot_sources(raw)
    fig.canvas.key_press_event('escape')
    # Sadly close_event isn't called on Agg backend and the test always passes.
    assert_array_equal(ica.exclude, [1])
    plt.close('all')

    # n_components_ may change int -> NumPy int after loading; cover it here.
    ica.n_components_ = np.int64(ica.n_components_)
    fig = ica.plot_sources(raw, [1])
    _fake_click(fig, fig.axes[0], [-0.1, 0.9])  # click on y-label

    # Instances with bad channels are rejected.
    raw.info['bads'] = ['MEG 0113']
    pytest.raises(RuntimeError, ica.plot_sources, inst=raw)
    ica.plot_sources(epochs)
    epochs.info['bads'] = ['MEG 0113']
    pytest.raises(RuntimeError, ica.plot_sources, inst=epochs)
    epochs.info['bads'] = []

    ica.plot_sources(epochs.average())
    evoked = epochs.average()
    fig = ica.plot_sources(evoked)
    ax = fig.get_axes()[0]
    first_line = ax.lines[0]
    # click on the first data point, then on the top-left corner
    first_point = [first_line.get_xdata()[0], first_line.get_ydata()[0]]
    _fake_click(fig, ax, first_point, 'data')
    corner = [ax.get_xlim()[0], ax.get_ylim()[1]]
    _fake_click(fig, ax, corner, 'data')

    # explicit ``exclude`` and ``ica.exclude`` should both be accepted
    ica.plot_sources(evoked, exclude=[0])
    ica.exclude = [0]
    ica.plot_sources(evoked)
    ica.labels_ = {'eog': [0], 'eog/0/crazy-channel': [0]}
    ica.plot_sources(evoked)  # now with labels
    pytest.raises(ValueError, ica.plot_sources, 'meeow')
    plt.close('all')
Beispiel #26
0
def test_plot_ica_components():
    """Test plotting of ICA solutions."""
    raw = _get_raw()
    ica = ICA(noise_cov=read_cov(cov_fname), n_components=2,
              max_pca_components=3, n_pca_components=3)
    ica_picks = _get_picks(raw)
    ica.fit(raw, picks=ica_picks)
    with warnings.catch_warnings(record=True):
        # Set the filter *inside* catch_warnings so it is reverted on exit
        # instead of changing the warning filters for every later test.
        warnings.simplefilter('always', UserWarning)
        for components in [0, [0], [0, 1], [0, 1] * 2, None]:
            ica.plot_components(components, image_interp='bilinear', res=16)
    # Without measurement info, plotting must fail.
    ica.info = None
    assert_raises(RuntimeError, ica.plot_components, 1)
    plt.close('all')
Beispiel #27
0
def test_plot_ica_overlay():
    """Test plotting of ICA cleaning."""
    raw = _get_raw()
    picks = _get_picks(raw)
    noise_cov = read_cov(cov_fname)
    ica = ICA(noise_cov=noise_cov, n_components=2, max_pca_components=3,
              n_pca_components=3)
    ica.fit(raw, picks=picks)
    # Overlay on averaged ECG then EOG epochs; raw itself is not tested
    # here (it would need to be preloaded).
    for create_epochs in (create_ecg_epochs, create_eog_epochs):
        artifact_evoked = create_epochs(raw, picks=picks).average()
        ica.plot_overlay(artifact_evoked)
    # A bare data array is not a valid instance.
    assert_raises(ValueError, ica.plot_overlay, raw[:2, :3][0])
    plt.close('all')
Beispiel #28
0
def test_plot_ica_overlay():
    """Test plotting of ICA cleaning."""
    raw = _get_raw()
    picks = _get_picks(raw)
    fit_picks = pick_types(raw.info, meg=True, eeg=False, stim=False,
                           ecg=False, eog=False, exclude="bads")
    ica = ICA(noise_cov=read_cov(cov_fname), n_components=2,
              max_pca_components=3, n_pca_components=3)
    ica.fit(raw, picks=fit_picks)
    # Overlay on averaged ECG then EOG epochs; raw itself is not tested
    # here (it would need to be preloaded).
    for create_epochs in (create_ecg_epochs, create_eog_epochs):
        artifact_evoked = create_epochs(raw, picks=picks).average()
        ica.plot_overlay(artifact_evoked)
    # A bare data array is not a valid instance.
    assert_raises(ValueError, ica.plot_overlay, raw[:2, :3][0])
    plt.close("all")
Beispiel #29
0
def test_plot_ica_sources():
    """Test plotting of ICA panel."""
    raw = io.Raw(raw_fname, preload=True)
    n_picks = len(_get_picks(raw))
    epochs = _get_epochs()
    # NOTE(review): this spreads n_chan indices over 0..n_picks + 1 and uses
    # them directly as raw channel indices (not as positions within the
    # picks); the upper end also overshoots the last pick position. Confirm
    # this selection is intended.
    chan_idx = np.round(np.linspace(0, n_picks + 1, n_chan)).astype(int)
    raw.pick_channels([raw.ch_names[k] for k in chan_idx])
    fit_picks = pick_types(raw.info, meg=True, eeg=False, stim=False,
                           ecg=False, eog=False, exclude="bads")
    ica = ICA(n_components=2, max_pca_components=3, n_pca_components=3)
    ica.fit(raw, picks=fit_picks)
    # Sources can be plotted from Raw, Epochs and Evoked instances.
    ica.plot_sources(raw)
    ica.plot_sources(epochs)
    evoked = epochs.average()
    ica.plot_sources(evoked)
    assert_raises(ValueError, ica.plot_sources, "meeow")
    plt.close("all")
Beispiel #30
0
def test_plot_ica_sources():
    """Test plotting of ICA panel."""
    raw = io.Raw(raw_fname, preload=True)
    picks = _get_picks(raw)
    epochs = _get_epochs()
    raw.pick_channels([raw.ch_names[idx] for idx in picks])
    fit_picks = pick_types(raw.info, meg=True, eeg=False, stim=False,
                           ecg=False, eog=False, exclude='bads')
    ica = ICA(n_components=2, max_pca_components=3, n_pca_components=3)
    ica.fit(raw, picks=fit_picks)
    # Sources can be plotted from Raw, Epochs and Evoked instances.
    ica.plot_sources(raw)
    ica.plot_sources(epochs)
    with warnings.catch_warnings(record=True):  # no labeled objects mpl
        evoked = epochs.average()
        ica.plot_sources(evoked)
    assert_raises(ValueError, ica.plot_sources, 'meeow')
    plt.close('all')
Beispiel #31
0
def run_ica(subject, session=None):
    """Run ICA.

    Load the filtered runs of ``subject`` (and ``session``), concatenate
    them, fit an ICA on epochs of a 1 Hz high-pass filtered copy, and save
    the fitted solution. If ``config.interactive`` is set, also write an
    HTML report with the properties of every fitted component.

    Parameters
    ----------
    subject
        Subject identifier (passed to ``BIDSPath``).
    session
        Session identifier (passed to ``BIDSPath``), or None.

    Raises
    ------
    ValueError
        If ``config.ica_algorithm`` is not 'picard', 'extended_infomax'
        or 'fastica'.
    """
    kind = config.get_kind()

    # Resolve algorithm-specific fit parameters first, so an unsupported
    # algorithm fails immediately with a clear message. Previously it
    # surfaced as an UnboundLocalError only after all data were loaded.
    if config.ica_algorithm == 'picard':
        fit_params = dict(fastica_it=5)
    elif config.ica_algorithm == 'extended_infomax':
        fit_params = dict(extended=True)
    elif config.ica_algorithm == 'fastica':
        fit_params = None
    else:
        raise ValueError(f'Unsupported ICA algorithm '
                         f'{config.ica_algorithm!r}; expected one of '
                         f"'picard', 'extended_infomax', 'fastica'.")

    deriv_path = config.get_subject_deriv_path(subject=subject,
                                               session=session,
                                               kind=kind)
    bids_basename = BIDSPath(subject=subject,
                             session=session,
                             task=config.get_task(),
                             acquisition=config.acq,
                             recording=config.rec,
                             space=config.space,
                             prefix=deriv_path)

    msg = 'Loading filtered raw data'
    logger.info(
        gen_log_message(message=msg, step=4, subject=subject, session=session))

    raw_list = list()
    for run in config.get_runs():
        raw_fname_in = bids_basename.copy().update(run=run,
                                                   processing='filt',
                                                   kind=kind,
                                                   extension='.fif')
        raw_list.append(mne.io.read_raw_fif(raw_fname_in, preload=True))

    msg = 'Concatenating runs'
    logger.info(
        gen_log_message(message=msg, step=4, subject=subject, session=session))
    raw = mne.concatenate_raws(raw_list)
    del raw_list

    events, event_id = mne.events_from_annotations(raw)

    if kind == 'eeg':
        # Average reference added as a projector; it is applied when the
        # epochs below are created with proj=True.
        raw.set_eeg_reference(projection=True)

    # don't reject based on EOG to keep blink artifacts
    # in the ICA computation.
    reject_ica = config.get_reject()
    if reject_ica and 'eog' in reject_ica:
        reject_ica = dict(reject_ica)  # copy: don't mutate the config value
        del reject_ica['eog']

    # produce high-pass filtered version of the data for ICA
    raw_ica = raw.copy().filter(l_freq=1., h_freq=None)
    epochs_for_ica = mne.Epochs(raw_ica,
                                events,
                                event_id,
                                config.tmin,
                                config.tmax,
                                proj=True,
                                baseline=config.baseline,
                                preload=True,
                                decim=config.decim,
                                reject=reject_ica)

    # get number of components for ICA
    # compute_rank requires 0.18
    # n_components_meg = (mne.compute_rank(epochs_for_ica.copy()
    #                        .pick_types(meg=True)))['meg']

    # A float n_components selects as many components as are needed to
    # explain that fraction of the variance.
    n_components_meg = 0.999

    n_components = {'meg': n_components_meg, 'eeg': 0.999}

    msg = f'Running ICA for {kind}'
    logger.info(
        gen_log_message(message=msg, step=4, subject=subject, session=session))

    ica = ICA(method=config.ica_algorithm,
              random_state=config.random_state,
              n_components=n_components[kind],
              fit_params=fit_params,
              max_iter=config.ica_max_iterations)

    ica.fit(epochs_for_ica, decim=config.ica_decim)

    msg = (f'Fit {ica.n_components_} components (explaining at least '
           f'{100*n_components[kind]:.1f}% of the variance)')
    logger.info(
        gen_log_message(message=msg, step=4, subject=subject, session=session))

    # Save ICA
    ica_fname = bids_basename.copy().update(run=None,
                                            kind=f'{kind}-ica',
                                            extension='.fif')
    ica.save(ica_fname)

    if config.interactive:
        # plot ICA components to html report
        report_fname = bids_basename.copy().update(run=None,
                                                   kind=f'{kind}-ica',
                                                   extension='.html')
        report = Report(report_fname, verbose=False)

        for idx in range(0, ica.n_components_):
            figure = ica.plot_properties(epochs_for_ica,
                                         picks=idx,
                                         psd_args={'fmax': 60},
                                         show=False)

            report.add_figs_to_section(figure,
                                       section=subject,
                                       captions=(kind.upper() +
                                                 ' - ICA Components'))

        report.save(report_fname, overwrite=True, open_browser=False)
Beispiel #32
0
        print(subj, ix, "filtered")
        raw_out_path = op.join(meg_subj_path,
                               "raw-{}-raw.fif".format(str(ix).zfill(3)))
        events_out_path = op.join(meg_subj_path,
                                  "{}-eve.fif".format(str(ix).zfill(3)))

        ica_out_path = op.join(meg_subj_path,
                               "{}-ica.fif".format(str(ix).zfill(3)))

        n_components = 50
        method = "fastica"
        reject = dict(mag=4e-12)

        ica = ICA(n_components=n_components, method=method)

        ica.fit(raw, picks=picks_meg, reject=reject, verbose=verb)
        print(subj, ix, "ICA_fit")
        raw.save(raw_out_path, overwrite=True)
        mne.write_events(events_out_path, events)
        ica.save(ica_out_path)
        print(subj, ix, "saved")

if pipeline_params["apply_ICA"]:
    # Apply stage: look up the files for this subject in meg_subj_path.
    # NOTE(review): ``files.get_files(...)[2]`` is presumably the list of
    # matching file paths; ``[2][0]`` raises IndexError if no
    # "ica-rej.json" file exists — confirm that is intended.
    ica_json = files.get_files(meg_subj_path, "", "ica-rej.json")[2][0]

    # raw files (the fitting stage saves "raw-XXX-raw.fif" here)
    raw_files = files.get_files(meg_subj_path, "raw", "-raw.fif", wp=False)[2]

    # per-subject JSON, presumably listing the ICA components to reject
    comp_ICA_json_path = op.join(meg_subj_path,
                                 "{}-ica-rej.json".format(str(subj).zfill(3)))

    # ICA solutions (the fitting stage saves "XXX-ica.fif" here)
    ica_files = files.get_files(meg_subj_path, "", "-ica.fif", wp=False)[2]
Beispiel #33
0
# The PSD of these data shows the noise as clear peaks.
raw.plot_psd(fmax=30)

# %%
# Run the "together" algorithm: a single ICA is fitted jointly on the MEG
# and reference-MEG channels (``meg=True, ref_meg=True`` with
# ``allow_ref_meg=True``).
raw_tog = raw.copy()
ica_kwargs = dict(
    method='picard',
    # loose (large) convergence tolerance to keep fitting fast
    fit_params=dict(tol=1e-4),  # use a high tol here for speed
)
all_picks = mne.pick_types(raw_tog.info, meg=True, ref_meg=True)
ica_tog = ICA(n_components=60,
              max_iter='auto',
              allow_ref_meg=True,
              **ica_kwargs)
ica_tog.fit(raw_tog, picks=all_picks)
# low threshold (2.0) here because of cropped data, entire recording can use
# a higher threshold (2.5)
bad_comps, scores = ica_tog.find_bads_ref(raw_tog, threshold=2.0)

# Plot scores with bad components marked.
ica_tog.plot_scores(scores, bad_comps)

# Examine the properties of removed components. It's clear from the time
# courses and topographies that these components represent external,
# intermittent noise.
ica_tog.plot_properties(raw_tog, picks=bad_comps)

# Remove the components.
raw_tog = ica_tog.apply(raw_tog, exclude=bad_comps)
Beispiel #34
0
for condition in conditions:
    # Load this subject/condition's epochs (the "_ar" suffix presumably marks
    # autoreject-cleaned epochs).
    # NOTE(review): the path is built by string concatenation, so
    # epochs_folder is assumed to end with a path separator — confirm.
    epochs = mne.read_epochs(epochs_folder + "%s_%s_ar-epo.fif" %
                             (subject, condition))

    # ICA Part
    # float n_components: keep as many components as needed to explain 99%
    # of the variance
    ica = ICA(n_components=0.99, method='fastica', max_iter=256)

    # MEG and EOG channels (bad channels included, exclude=[])
    picks = mne.pick_types(epochs.info,
                           meg=True,
                           eeg=False,
                           eog=True,
                           emg=False,
                           stim=False,
                           exclude=[])

    # ``decim`` comes from the enclosing (module) scope
    ica.fit(epochs, picks=picks, decim=decim, reject=None)

    # maximum number of components to reject
    n_max_eog = 1

    ##########################################################################
    # 2) identify bad components by analyzing latent sources.

    # DETECT EOG BY CORRELATION
    # NOTE(review): ``title`` is assigned but not used in this snippet.
    title = "ICA: %s for %s"

    # EOG
    eog_inds, scores = ica.find_bads_eog(epochs)

    # keep at most n_max_eog components and mark them for exclusion
    eog_inds = eog_inds[:n_max_eog]
    ica.exclude += eog_inds
Beispiel #35
0
def remove_eyeblinks_and_heartbeat(raw, subject, figdir, events, eventid, rng):
    """
    Find and repair eyeblink and heartbeat artifacts in the data.

    Data should be filtered. The ICA is fitted on a temporarily 1 Hz
    high-pass filtered copy of the data, chunked into epochs (0-3 s after
    each event in ``events``), and the fit is then applied to ``raw``.
    (An autoreject-based rejection of high-amplitude epochs prior to the fit
    is currently commented out, so all epochs are used for fitting.)
    Diagnostic figures are saved under ``figdir/sub-<subject>/meg``.

    :param raw: Raw data; loaded and modified in place.
    :param subject: str, subject identifier, e.g., '001'
    :param figdir: base directory for the diagnostic figures
    :param events: events used to epoch the data for the ICA fit
    :param eventid: event id mapping passed to ``mne.Epochs``
    :param rng: random state instance
    """
    def _fig_path(fname):
        # All figures of this function go to <figdir>/sub-<subject>/meg/.
        return _construct_path([
            Path(figdir),
            f"sub-{subject}",
            "meg",
            fname,
        ])

    # prior to an ICA, it is recommended to high-pass filter the data
    # as low frequency artifacts can alter the ICA solution. We fit the ICA
    # to high-pass filtered (1Hz) data, and apply it to non-highpass-filtered
    # data
    logging.info("Applying a temporary high-pass filtering prior to ICA")
    filt_raw = raw.copy()
    filt_raw.load_data().filter(l_freq=1., h_freq=None)

    # evoked eyeblinks and heartbeats for diagnostic plots
    eog_evoked = create_eog_epochs(filt_raw).average()
    eog_evoked.apply_baseline(baseline=(None, -0.2))
    if subject == '008':
        # subject 008's ECG channel is flat. It will not find any heartbeats by
        # default. We let it estimate heartbeat from magnetometers. For this,
        # we'll drop the ECG channel
        filt_raw.drop_channels('ECG003')
    ecg_evoked = create_ecg_epochs(filt_raw).average()
    ecg_evoked.apply_baseline(baseline=(None, -0.2))

    # make sure that we actually found sensible artifacts here
    for i, fig in enumerate(eog_evoked.plot_joint()):
        fig.savefig(_fig_path(f"evoked-artifact_eog_sub-{subject}_{i}.png"))
    for i, fig in enumerate(ecg_evoked.plot_joint()):
        fig.savefig(_fig_path(f"evoked-artifact_ecg_sub-{subject}_{i}.png"))

    # Chunk raw data into epochs to fit the ICA
    # No baseline correction as it would interfere with ICA.
    logging.info("Epoching filtered data")
    epochs = mne.Epochs(filt_raw,
                        events,
                        event_id=eventid,
                        tmin=0,
                        tmax=3,
                        picks='meg',
                        baseline=None)

    # Disabled: estimate rejection criteria for high-amplitude artifacts
    # via autoreject and fit the ICA on the non-rejected epochs only.
    # logging.info('Estimating bad epochs quick-and-dirty, to improve ICA')
    # ar = AutoReject(random_state=rng)
    # # fit on first 200 epochs to save (a bit of) time
    # epochs.load_data()
    # ar.fit(epochs[:200])
    # epochs_ar, reject_log = ar.transform(epochs, return_log=True)

    # run an ICA to capture heartbeat and eyeblink artifacts.
    # set a seed for reproducibility.
    # When left to figure out the component number by itself, it ends up with
    # about 80. n_components is set to 45 to have a chance at checking them
    # by hand.
    # See https://github.com/autoreject/autoreject/blob/dfbc64f49eddeda53c5868290a6792b5233843c6/examples/plot_autoreject_workflow.py
    logging.info('Fitting the ICA')
    ica = ICA(max_iter='auto', n_components=45, random_state=rng)
    ica.fit(epochs)  # (autoreject variant: epochs[~reject_log.bad_epochs])

    logging.info("Searching for eyeblink and heartbeat artifacts in the data")
    if subject == '008':
        # manually identified components for this subject
        eog_indices = [10]
        ecg_indices = [29]
    else:
        # An initially manual component detection did not reproduce after a
        # software update - for now, automatic detection is used for all but
        # subject 008.
        eog_indices, eog_scores = ica.find_bads_eog(filt_raw)
        ecg_indices, ecg_scores = ica.find_bads_ecg(filt_raw)
        logging.info(f"Identified the following EOG components: {eog_indices}")
        logging.info(f"Identified the following ECG components: {ecg_indices}")

    # visualize the components
    for i, fig in enumerate(ica.plot_components()):
        fig.savefig(_fig_path(f"ica-components_sub-{subject}_{i}.png"))

    # visualize the time series of components with a large figure size;
    # rc_context restores the previous rcParams even if plotting fails
    with plt.rc_context({'figure.figsize': (30, 20)}):
        comp_sources = ica.plot_sources(epochs)
    comp_sources.savefig(
        _fig_path(f"ica-components_sources_sub-{subject}.png"))

    # Overlay the selected components on the averaged artifact epochs. Use
    # the indices actually selected above; the former ``ica_comps`` lookup
    # exists only in commented-out code and ignored the automatic detection.
    overlay_eog = ica.plot_overlay(eog_evoked, exclude=eog_indices)
    overlay_eog.savefig(_fig_path(
        f"ica-eog-components_over-avg-epochs_sub-{subject}.png"))
    overlay_ecg = ica.plot_overlay(ecg_evoked, exclude=ecg_indices)
    overlay_ecg.savefig(_fig_path(
        f"ica-ecg-components_over-avg-epochs_sub-{subject}.png"))

    # plot EOG component properties
    figs = ica.plot_properties(filt_raw, picks=eog_indices)
    for i, fig in enumerate(figs):
        fig.savefig(_fig_path(f"ica-property{i}_artifact-eog_sub-{subject}.png"))
    # plot ECG component properties
    figs = ica.plot_properties(filt_raw, picks=ecg_indices)
    for i, fig in enumerate(figs):
        fig.savefig(_fig_path(f"ica-property{i}_artifact-ecg_sub-{subject}.png"))

    # Set the indices to be excluded; build a new list so that extending it
    # does not also mutate ``eog_indices``.
    ica.exclude = list(eog_indices) + list(ecg_indices)

    # plot ICs applied to the averaged EOG epochs, with EOG matches highlighted
    sources = ica.plot_sources(eog_evoked)
    sources.savefig(_fig_path(f"ica-sources_artifact-eog_sub-{subject}.png"))

    # plot ICs applied to the averaged ECG epochs, with ECG matches highlighted
    sources = ica.plot_sources(ecg_evoked)
    sources.savefig(_fig_path(f"ica-sources_artifact-ecg_sub-{subject}.png"))

    # apply the ICA to the raw data
    logging.info('Applying ICA to the raw data.')
    raw.load_data()
    ica.apply(raw)

    # NOTE(review): the statements below look spliced in from a different
    # (EEG) script: they append to ``raws``/``icas``, which are not defined
    # in this function, and overwrite ``ica`` with a new, unapplied fit.
    # Kept unchanged; confirm whether they belong here.
    eegbci.standardize(raw)  # standardize channel names
    montage = make_standard_montage('standard_1005')  # load the montage
    raw.set_montage(montage)  # set the montage
    # band-pass filter, 1-79 Hz
    raw.filter(1.0, 79.0, fir_design='firwin', skip_by_annotation='edge')
    raw.notch_filter(freqs=60)  # 60 Hz notch filter
    raw.plot_psd(area_mode=None, show=False, average=False, fmin=1.0,
                 fmax=80.0, dB=False, n_fft=160)
    # TODO: save the PSD plot here
    # ICA
    ica = ICA(n_components=64, random_state=42, method="fastica",
              max_iter=1000)

    # drop "BAD boundary" / "EDGE boundary" annotations before fitting
    ind = [index for index, value in enumerate(raw.annotations.description)
           if value in ("BAD boundary", "EDGE boundary")]
    raw.annotations.delete(ind)

    ica.fit(raw)
    raws.append(raw)
    icas.append(ica)

#%%
# Inspect the properties of components 2 and 3 of the first ICA / recording.
icas[0].plot_properties(raws[0], picks = [2,3], dB = False)

# Detect EOG-like components, using the Fp1 channel as the EOG reference.
eog_inds, eog_scores = icas[0].find_bads_eog(raws[0],ch_name='Fp1')

icas[0].plot_components(picks = eog_inds)

# NOTE(review): exc_0 is initialised but not used in this snippet.
exc_0 = []


# Use the first detected EOG component of ICA #0 as a template to find
# similar components across all ICAs in ``icas``.
# NOTE(review): eog_inds[0] raises IndexError if no EOG component was found.
corr_map = corrmap(icas, template=(0, eog_inds[0]))
Beispiel #37
0
def test_ica_full_data_recovery(method):
    """Test recovery of full data when no source is rejected."""
    import warnings

    # Most basic recovery
    _skip_check_picard(method)
    raw = read_raw_fif(raw_fname).crop(0.5, stop).load_data()
    events = read_events(event_name)
    picks = pick_types(raw.info,
                       meg=True,
                       stim=False,
                       ecg=False,
                       eog=False,
                       exclude='bads')[:10]
    with pytest.warns(RuntimeWarning, match='projection'):
        epochs = Epochs(raw,
                        events[:4],
                        event_id,
                        tmin,
                        tmax,
                        picks=picks,
                        baseline=(None, 0),
                        preload=True)
    evoked = epochs.average()
    n_channels = 5
    data = raw._data[:n_channels].copy()
    data_epochs = epochs.get_data()
    data_evoked = evoked.data
    raw.set_annotations(Annotations([0.5], [0.5], ['BAD']))
    # (n_components, n_pca_components, ok): keeping all n_channels PCA
    # components must reproduce the data exactly; keeping only half must not.
    configs = [(2, n_channels, True), (2, n_channels // 2, False)]
    for n_components, n_pca_components, ok in configs:
        ica = ICA(n_components=n_components,
                  random_state=0,
                  max_pca_components=n_pca_components,
                  n_pca_components=n_pca_components,
                  method=method,
                  max_iter=1)
        # pytest.warns requires at least one UserWarning here (presumably
        # non-convergence, given max_iter=1).
        with pytest.warns(UserWarning, match=None):
            ica.fit(raw, picks=list(range(n_channels)))
        raw2 = ica.apply(raw.copy(), exclude=[])
        if ok:
            assert_allclose(data[:n_channels],
                            raw2._data[:n_channels],
                            rtol=1e-10,
                            atol=1e-15)
        else:
            diff = np.abs(data[:n_channels] - raw2._data[:n_channels])
            assert np.max(diff) > 1e-14

        ica = ICA(n_components=n_components,
                  method=method,
                  max_pca_components=n_pca_components,
                  n_pca_components=n_pca_components,
                  random_state=0)
        # Fitting may or may not warn. pytest.warns(None) is deprecated in
        # pytest 7 and removed in pytest 8, so record (and ignore) any
        # warnings explicitly instead.
        with warnings.catch_warnings(record=True):
            warnings.simplefilter('always')
            ica.fit(epochs, picks=list(range(n_channels)))
        epochs2 = ica.apply(epochs.copy(), exclude=[])
        data2 = epochs2.get_data()[:, :n_channels]
        if ok:
            assert_allclose(data_epochs[:, :n_channels],
                            data2,
                            rtol=1e-10,
                            atol=1e-15)
        else:
            diff = np.abs(data_epochs[:, :n_channels] - data2)
            assert np.max(diff) > 1e-14

        evoked2 = ica.apply(evoked.copy(), exclude=[])
        data2 = evoked2.data[:n_channels]
        if ok:
            assert_allclose(data_evoked[:n_channels],
                            data2,
                            rtol=1e-10,
                            atol=1e-15)
        else:
            diff = np.abs(data_evoked[:n_channels] - data2)
            assert np.max(diff) > 1e-14
    pytest.raises(ValueError, ICA, method='pizza-decomposision')
Beispiel #38
0
def test_ica_additional(method):
    """Test additional ICA functionality.

    For a single ICA ``method`` this exercises: fitting with
    ``n_components=None``, fitting with a noise covariance, ``corrmap``,
    filename warnings on save/read, decimation, explained variance,
    component sorting, save/read round-trips, score functions,
    ``detect_artifacts``, ``find_bads_ecg``/``find_bads_eog`` and exporting
    sources to Raw/Epochs.

    Every file written here lives inside the test's own temporary directory,
    so nothing is left behind in its parent directory or in the current
    working directory.

    Parameters
    ----------
    method : str
        ICA algorithm name passed to the :class:`ICA` instances created here
        (``ica_different_channels`` uses the default method).
    """
    _skip_check_picard(method)

    tempdir = _TempDir()
    stop2 = 500
    raw = read_raw_fif(raw_fname).crop(1.5, stop).load_data()
    raw.del_proj()  # avoid warnings
    raw.set_annotations(Annotations([0.5], [0.5], ['BAD']))
    # XXX This breaks the tests :(
    # raw.info['bads'] = [raw.ch_names[1]]
    test_cov = read_cov(test_cov_name)
    events = read_events(event_name)
    # every second good MEG channel
    picks = pick_types(raw.info,
                       meg=True,
                       stim=False,
                       ecg=False,
                       eog=False,
                       exclude='bads')[1::2]
    epochs = Epochs(raw,
                    events,
                    None,
                    tmin,
                    tmax,
                    picks=picks,
                    baseline=(None, 0),
                    preload=True,
                    proj=False)
    epochs.decimate(3, verbose='error')
    assert len(epochs) == 4

    # test if n_components=None works
    ica = ICA(n_components=None,
              max_pca_components=None,
              n_pca_components=None,
              random_state=0,
              method=method,
              max_iter=1)
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(epochs)
    # for testing eog functionality
    picks2 = np.concatenate([picks, pick_types(raw.info, False, eog=True)])
    epochs_eog = Epochs(raw,
                        events[:4],
                        event_id,
                        tmin,
                        tmax,
                        picks=picks2,
                        baseline=(None, 0),
                        preload=True)
    del picks2

    test_cov2 = test_cov.copy()
    ica = ICA(noise_cov=test_cov2,
              n_components=3,
              max_pca_components=4,
              n_pca_components=4,
              method=method)
    assert (ica.info is None)
    with pytest.warns(RuntimeWarning, match='normalize_proj'):
        ica.fit(raw, picks[:5])
    assert (isinstance(ica.info, Info))
    assert (ica.n_components_ < 5)

    ica = ICA(n_components=3,
              max_pca_components=4,
              method=method,
              n_pca_components=4,
              random_state=0)
    # saving an unfitted ICA must fail
    pytest.raises(RuntimeError, ica.save, '')

    ica.fit(raw, picks=[1, 2, 3, 4, 5], start=start, stop=stop2)

    # check passing a ch_name to find_bads_ecg
    with pytest.warns(RuntimeWarning, match='longer'):
        _, scores_1 = ica.find_bads_ecg(raw)
        _, scores_2 = ica.find_bads_ecg(raw, raw.ch_names[1])
    assert scores_1[0] != scores_2[0]

    # test corrmap
    ica2 = ica.copy()
    ica3 = ica.copy()
    corrmap([ica, ica2], (0, 0),
            threshold='auto',
            label='blinks',
            plot=True,
            ch_type="mag")
    corrmap([ica, ica2], (0, 0), threshold=2, plot=False, show=False)
    assert (ica.labels_["blinks"] == ica2.labels_["blinks"])
    assert (0 in ica.labels_["blinks"])
    # test retrieval of component maps as arrays
    components = ica.get_components()
    template = components[:, 0]
    EvokedArray(components, ica.info, tmin=0.).plot_topomap([0], time_unit='s')

    # an array template must give the same labels as the (0, 0) template
    corrmap([ica, ica3],
            template,
            threshold='auto',
            label='blinks',
            plot=True,
            ch_type="mag")
    assert (ica2.labels_["blinks"] == ica3.labels_["blinks"])

    plt.close('all')

    # make sure a single threshold in a list works
    corrmap([ica, ica3],
            template,
            threshold=[0.5],
            label='blinks',
            plot=True,
            ch_type="mag")

    # ICAs fitted on different channel sets cannot be compared
    ica_different_channels = ICA(n_components=2,
                                 random_state=0).fit(raw, picks=[2, 3, 4, 5])
    pytest.raises(ValueError, corrmap, [ica_different_channels, ica], (0, 0))

    # test warnings on bad filenames (written inside tempdir, not its parent)
    ica_badname = op.join(tempdir, 'test-bad-name.fif.gz')
    with pytest.warns(RuntimeWarning, match='-ica.fif'):
        ica.save(ica_badname)
    with pytest.warns(RuntimeWarning, match='-ica.fif'):
        read_ica(ica_badname)

    # test decim
    ica = ICA(n_components=3,
              max_pca_components=4,
              n_pca_components=4,
              method=method,
              max_iter=1)
    raw_ = raw.copy()
    for _ in range(3):
        raw_.append(raw_)
    n_samples = raw_._data.shape[1]
    # fit the appended copy itself, so the assertion below really checks
    # that fitting with ``decim`` leaves the instance's data untouched
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(raw_, picks=picks[:5], decim=3)
    assert raw_._data.shape[1] == n_samples

    # test expl var
    ica = ICA(n_components=1.0,
              max_pca_components=4,
              n_pca_components=4,
              method=method,
              max_iter=1)
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(raw, picks=None, decim=3)
    assert (ica.n_components_ == 4)
    ica_var = _ica_explained_variance(ica, raw, normalize=True)
    assert (np.all(ica_var[:-1] >= ica_var[1:]))

    # test ica sorting: exclude/labels_ indices must follow the new order
    ica.exclude = [0]
    ica.labels_ = dict(blink=[0], think=[1])
    ica_sorted = _sort_components(ica, [3, 2, 1, 0], copy=True)
    assert_equal(ica_sorted.exclude, [3])
    assert_equal(ica_sorted.labels_, dict(blink=[3], think=[2]))

    # epochs extraction from raw fit
    pytest.raises(RuntimeError, ica.get_sources, epochs)
    # test reading and writing (inside tempdir, not its parent)
    test_ica_fname = op.join(tempdir, 'test-ica.fif')
    for cov in (None, test_cov):
        ica = ICA(noise_cov=cov,
                  n_components=2,
                  max_pca_components=4,
                  n_pca_components=4,
                  method=method,
                  max_iter=1)
        with pytest.warns(None):  # ICA does not converge
            ica.fit(raw, picks=picks[:10], start=start, stop=stop2)
        sources = ica.get_sources(epochs).get_data()
        assert (ica.mixing_matrix_.shape == (2, 2))
        assert (ica.unmixing_matrix_.shape == (2, 2))
        assert (ica.pca_components_.shape == (4, 10))
        assert (sources.shape[1] == ica.n_components_)

        for exclude in [[], [0], np.array([1, 2, 3])]:
            ica.exclude = exclude
            ica.labels_ = {'foo': [0]}
            ica.save(test_ica_fname)
            ica_read = read_ica(test_ica_fname)
            assert (list(ica.exclude) == ica_read.exclude)
            assert_equal(ica.labels_, ica_read.labels_)
            ica.apply(raw)
            # an ``exclude`` argument to apply must not alter ica.exclude
            ica.exclude = []
            ica.apply(raw, exclude=[1])
            assert (ica.exclude == [])

            ica.exclude = [0, 1]
            ica.apply(raw, exclude=[1])
            assert (ica.exclude == [0, 1])

            # excluded components are marked as bad source channels
            ica_raw = ica.get_sources(raw)
            assert (ica.exclude == [
                ica_raw.ch_names.index(e) for e in ica_raw.info['bads']
            ])

        # test filtering
        d1 = ica_raw._data[0].copy()
        ica_raw.filter(4, 20, fir_design='firwin2')
        assert_equal(ica_raw.info['lowpass'], 20.)
        assert_equal(ica_raw.info['highpass'], 4.)
        assert ((d1 != ica_raw._data[0]).any())
        d1 = ica_raw._data[0].copy()
        ica_raw.notch_filter([10], trans_bandwidth=10, fir_design='firwin')
        assert ((d1 != ica_raw._data[0]).any())

        ica.n_pca_components = 2
        ica.method = 'fake'
        ica.save(test_ica_fname)
        ica_read = read_ica(test_ica_fname)
        assert (ica.n_pca_components == ica_read.n_pca_components)
        assert_equal(ica.method, ica_read.method)
        assert_equal(ica.labels_, ica_read.labels_)

        # check type consistency
        attrs = ('mixing_matrix_ unmixing_matrix_ pca_components_ '
                 'pca_explained_variance_ pre_whitener_')

        def f(x, y):
            return getattr(x, y).dtype

        for attr in attrs.split():
            assert_equal(f(ica_read, attr), f(ica, attr))

        ica.n_pca_components = 4
        ica_read.n_pca_components = 4

        ica.exclude = []
        ica.save(test_ica_fname)
        ica_read = read_ica(test_ica_fname)
        for attr in [
                'mixing_matrix_', 'unmixing_matrix_', 'pca_components_',
                'pca_mean_', 'pca_explained_variance_', 'pre_whitener_'
        ]:
            assert_array_almost_equal(getattr(ica, attr),
                                      getattr(ica_read, attr))

        assert (ica.ch_names == ica_read.ch_names)
        assert (isinstance(ica_read.info, Info))

        sources = ica.get_sources(raw)[:, :][0]
        sources2 = ica_read.get_sources(raw)[:, :][0]
        assert_array_almost_equal(sources, sources2)

        _raw1 = ica.apply(raw, exclude=[1])
        _raw2 = ica_read.apply(raw, exclude=[1])
        assert_array_almost_equal(_raw1[:, :][0], _raw2[:, :][0])

    os.remove(test_ica_fname)
    # check score funcs
    for name, func in get_score_funcs().items():
        if name in score_funcs_unsuited:
            continue
        scores = ica.score_sources(raw,
                                   target='EOG 061',
                                   score_func=func,
                                   start=0,
                                   stop=10)
        assert (ica.n_components_ == len(scores))

    # check univariate stats
    scores = ica.score_sources(raw, start=0, stop=50, score_func=stats.skew)
    # check exception handling
    pytest.raises(ValueError, ica.score_sources, raw, target=np.arange(1))

    params = []
    params += [(None, -1, slice(2), [0, 1])]  # variance, kurtosis params
    params += [(None, 'MEG 1531')]  # ECG / EOG channel params
    for idx, ch_name in product(*params):
        ica.detect_artifacts(raw,
                             start_find=0,
                             stop_find=50,
                             ecg_ch=ch_name,
                             eog_ch=ch_name,
                             skew_criterion=idx,
                             var_criterion=idx,
                             kurt_criterion=idx)

    # Make sure detect_artifacts marks the right components.
    # For int criterion, the doc says "E.g. range(2) would return the two
    # sources with the highest score". Assert that's what it does.
    # Only test for skew, since it's always the same code.
    ica.exclude = []
    ica.detect_artifacts(raw,
                         start_find=0,
                         stop_find=50,
                         ecg_ch=None,
                         eog_ch=None,
                         skew_criterion=0,
                         var_criterion=None,
                         kurt_criterion=None)
    assert np.abs(scores[ica.exclude]) == np.max(np.abs(scores))

    # snapshot the data to verify below that the find_bads_* calls do not
    # modify their inputs
    evoked = epochs.average()
    evoked_data = evoked.data.copy()
    raw_data = raw[:][0].copy()
    epochs_data = epochs.get_data().copy()

    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_ecg(raw, method='ctps')
    assert_equal(len(scores), ica.n_components_)
    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_ecg(raw, method='correlation')
    assert_equal(len(scores), ica.n_components_)

    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_eog(raw)
    assert_equal(len(scores), ica.n_components_)

    idx, scores = ica.find_bads_ecg(epochs, method='ctps')

    assert_equal(len(scores), ica.n_components_)
    pytest.raises(ValueError,
                  ica.find_bads_ecg,
                  epochs.average(),
                  method='ctps')
    pytest.raises(ValueError, ica.find_bads_ecg, raw, method='crazy-coupling')

    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_eog(raw)
    assert_equal(len(scores), ica.n_components_)

    # change the kind of the channel before 'EOG 061' to 202 (presumably the
    # EOG kind), so that find_bads_eog returns a list of per-channel scores
    raw.info['chs'][raw.ch_names.index('EOG 061') - 1]['kind'] = 202
    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_eog(raw)
    assert (isinstance(scores, list))
    assert_equal(len(scores[0]), ica.n_components_)

    idx, scores = ica.find_bads_eog(evoked, ch_name='MEG 1441')
    assert_equal(len(scores), ica.n_components_)

    idx, scores = ica.find_bads_ecg(evoked, method='correlation')
    assert_equal(len(scores), ica.n_components_)

    assert_array_equal(raw_data, raw[:][0])
    assert_array_equal(epochs_data, epochs.get_data())
    assert_array_equal(evoked_data, evoked.data)

    # check score funcs
    for name, func in get_score_funcs().items():
        if name in score_funcs_unsuited:
            continue
        scores = ica.score_sources(epochs_eog,
                                   target='EOG 061',
                                   score_func=func)
        assert (ica.n_components_ == len(scores))

    # check univariate stats
    scores = ica.score_sources(epochs, score_func=stats.skew)

    # check exception handling
    pytest.raises(ValueError, ica.score_sources, epochs, target=np.arange(1))

    # ecg functionality
    ecg_scores = ica.score_sources(raw,
                                   target='MEG 1531',
                                   score_func='pearsonr')

    # ``sources`` is the raw source array from the last save/read iteration
    with pytest.warns(RuntimeWarning, match='longer'):
        ecg_events = ica_find_ecg_events(raw,
                                         sources[np.abs(ecg_scores).argmax()])
    assert (ecg_events.ndim == 2)

    # eog functionality
    eog_scores = ica.score_sources(raw,
                                   target='EOG 061',
                                   score_func='pearsonr')
    with pytest.warns(RuntimeWarning, match='longer'):
        eog_events = ica_find_eog_events(raw,
                                         sources[np.abs(eog_scores).argmax()])
    assert (eog_events.ndim == 2)

    # Test ica fiff export (written inside tempdir, not the working dir)
    ica_raw = ica.get_sources(raw, start=0, stop=100)
    assert (ica_raw.last_samp - ica_raw.first_samp == 100)
    assert_equal(len(ica_raw._filenames), 1)  # API consistency
    ica_chans = [ch for ch in ica_raw.ch_names if 'ICA' in ch]
    assert (ica.n_components_ == len(ica_chans))
    test_ica_fname = op.join(tempdir, 'test-ica_raw.fif')
    ica.n_components = np.int32(ica.n_components)
    ica_raw.save(test_ica_fname, overwrite=True)
    ica_raw2 = read_raw_fif(test_ica_fname, preload=True)
    assert_allclose(ica_raw._data, ica_raw2._data, rtol=1e-5, atol=1e-4)
    ica_raw2.close()
    os.remove(test_ica_fname)

    # Test ica epochs export
    ica_epochs = ica.get_sources(epochs)
    assert (ica_epochs.events.shape == epochs.events.shape)
    ica_chans = [ch for ch in ica_epochs.ch_names if 'ICA' in ch]
    assert (ica.n_components_ == len(ica_chans))
    assert (ica.n_components_ == ica_epochs.get_data().shape[1])
    assert (ica_epochs._raw is None)
    assert (ica_epochs.preload is True)

    # test float n pca components
    ica.pca_explained_variance_ = np.array([0.2] * 5)
    ica.n_components_ = 0
    for ncomps, expected in [[0.3, 1], [0.9, 4], [1, 1]]:
        ncomps_ = ica._check_n_pca_components(ncomps)
        assert (ncomps_ == expected)

    ica = ICA(method=method)
    with pytest.warns(None):  # sometimes does not converge
        ica.fit(raw, picks=picks[:5])
    with pytest.warns(RuntimeWarning, match='longer'):
        ica.find_bads_ecg(raw)
    ica.find_bads_eog(epochs, ch_name='MEG 0121')
    assert_array_equal(raw_data, raw[:][0])

    # once a fitted channel is missing, find_bads_* must refuse to run
    raw.drop_channels(['MEG 0122'])
    pytest.raises(RuntimeError, ica.find_bads_eog, raw)
    with pytest.warns(RuntimeWarning, match='longer'):
        pytest.raises(RuntimeError, ica.find_bads_ecg, raw)
Beispiel #39
0
def test_ica_core(method):
    """Test ICA on raw and epochs.

    Iterates over every combination of noise covariance (``None`` or a test
    covariance), ``n_components`` (an int or an explained-variance float)
    and ``max_pca_components``. For each combination it checks fitting,
    deterministic re-fitting, source extraction, plotting scale and error
    handling on Raw and Epochs.

    Parameters
    ----------
    method : str
        ICA algorithm passed to every :class:`ICA` created here.
    """
    _skip_check_picard(method)
    raw = read_raw_fif(raw_fname).crop(1.5, stop).load_data()

    # XXX. The None cases helped revealing bugs but are time consuming.
    test_cov = read_cov(test_cov_name)
    events = read_events(event_name)
    picks = pick_types(raw.info,
                       meg=True,
                       stim=False,
                       ecg=False,
                       eog=False,
                       exclude='bads')
    epochs = Epochs(raw,
                    events[:4],
                    event_id,
                    tmin,
                    tmax,
                    picks=picks,
                    baseline=(None, 0),
                    preload=True)
    noise_cov = [None, test_cov]
    # removed None cases to speed up...
    n_components = [2, 1.0]  # for future dbg add cases
    max_pca_components = [3]
    picks_ = [picks]
    methods = [method]
    iter_ica_params = product(noise_cov, n_components, max_pca_components,
                              picks_, methods)

    # # test init catchers
    # n_components larger than max_pca_components, or a float > 1, is invalid
    pytest.raises(ValueError, ICA, n_components=3, max_pca_components=2)
    pytest.raises(ValueError, ICA, n_components=2.3, max_pca_components=2)

    # test essential core functionality
    for n_cov, n_comp, max_n, pcks, method in iter_ica_params:
        # Test ICA raw
        ica = ICA(noise_cov=n_cov,
                  n_components=n_comp,
                  max_pca_components=max_n,
                  n_pca_components=max_n,
                  random_state=0,
                  method=method,
                  max_iter=1)
        # channel-type membership ('mag' in ica) is an error before fitting
        pytest.raises(ValueError, ica.__contains__, 'mag')

        print(ica)  # to test repr

        # test fit checker
        pytest.raises(RuntimeError, ica.get_sources, raw)
        pytest.raises(RuntimeError, ica.get_sources, epochs)

        # Test error upon empty epochs fitting
        with pytest.raises(RuntimeError, match='none were found'):
            ica.fit(epochs[0:0])

        # test decomposition
        with pytest.warns(UserWarning, match='did not converge'):
            ica.fit(raw, picks=pcks, start=start, stop=stop)
        repr(ica)  # to test repr
        assert ('mag' in ica)  # should now work without error

        # test re-fit: with random_state=0 the unmixing must be reproduced
        unmixing1 = ica.unmixing_matrix_
        with pytest.warns(UserWarning, match='did not converge'):
            ica.fit(raw, picks=pcks, start=start, stop=stop)
        assert_array_almost_equal(unmixing1, ica.unmixing_matrix_)

        raw_sources = ica.get_sources(raw)
        # test for #3804
        assert_equal(raw_sources._filenames, [None])
        print(raw_sources)

        # test for gh-6271 (scaling of ICA traces)
        fig = raw_sources.plot()
        assert len(fig.axes[0].lines) in (4, 5, 6)
        for line in fig.axes[0].lines:
            y = line.get_ydata()
            if len(y) > 2:  # actual data, not markers
                assert np.ptp(y) < 15
        plt.close('all')

        sources = raw_sources[:, :][0]
        assert (sources.shape[0] == ica.n_components_)

        # test preload filter: apply on non-preloaded data must fail
        raw3 = raw.copy()
        raw3.preload = False
        pytest.raises(RuntimeError, ica.apply, raw3, include=[1, 2])

        #######################################################################
        # test epochs decomposition
        ica = ICA(noise_cov=n_cov,
                  n_components=n_comp,
                  max_pca_components=max_n,
                  n_pca_components=max_n,
                  random_state=0,
                  method=method)
        with pytest.warns(None):  # sometimes warns
            ica.fit(epochs, picks=picks)
        # n_samples_ counts samples across all epochs
        data = epochs.get_data()[:, 0, :]
        n_samples = np.prod(data.shape)
        assert_equal(ica.n_samples_, n_samples)
        print(ica)  # to test repr

        sources = ica.get_sources(epochs).get_data()
        assert (sources.shape[1] == ica.n_components_)

        pytest.raises(ValueError,
                      ica.score_sources,
                      epochs,
                      target=np.arange(1))

        # test preload filter
        epochs3 = epochs.copy()
        epochs3.preload = False
        pytest.raises(RuntimeError, ica.apply, epochs3, include=[1, 2])

    # test for bug with whitener updating: applying to data with a large
    # outlier must leave the fitted pre-whitener unchanged
    _pre_whitener = ica.pre_whitener_.copy()
    epochs._data[:, 0, 10:15] *= 1e12
    ica.apply(epochs.copy())
    assert_array_equal(_pre_whitener, ica.pre_whitener_)

    # test expl. var threshold leading to empty sel
    ica.n_components = 0.1
    pytest.raises(RuntimeError, ica.fit, epochs)

    # a plain tuple is not a valid data instance
    offender = 1, 2, 3,
    pytest.raises(ValueError, ica.get_sources, offender)
    pytest.raises(TypeError, ica.fit, offender)
    pytest.raises(TypeError, ica.apply, offender)
Beispiel #40
0
def test_ica_eeg():
    """Test ICA on EEG.

    For FIF, EEGLAB and CTF recordings (the CTF one with gradient
    compensation grades 0 and 1), fit ICA on the first two MEG channels,
    the first two EEG channels and both together. Then check that ``apply``
    and ``get_sources`` run on Raw, Epochs and Evoked.
    """
    method = 'fastica'
    raw_fif = read_raw_fif(fif_fname, preload=True)
    raw_eeglab = read_raw_eeglab(input_fname=eeglab_fname, preload=True)
    for raw in [raw_fif, raw_eeglab]:
        events = make_fixed_length_events(raw,
                                          99999,
                                          start=0,
                                          stop=0.3,
                                          duration=0.1)
        picks_meg = pick_types(raw.info, meg=True, eeg=False)[:2]
        picks_eeg = pick_types(raw.info, meg=False, eeg=True)[:2]
        # build a plain list: if the picks are index arrays, ``+`` would add
        # them element-wise instead of concatenating
        picks_all = list(picks_meg) + list(picks_eeg)
        epochs = Epochs(raw, events, None, -0.1, 0.1, preload=True)
        evoked = epochs.average()

        for picks in [picks_meg, picks_eeg, picks_all]:
            if len(picks) == 0:
                continue  # recording has no channels of this type
            # test fit; only the ICA fitted last (on epochs) is used below
            for inst in [raw, epochs]:
                ica = ICA(n_components=2,
                          random_state=0,
                          max_iter=2,
                          method=method)
                with pytest.warns(None):
                    ica.fit(inst, picks=picks)

            # test apply and get_sources
            for inst in [raw, epochs, evoked]:
                ica.apply(inst)
                ica.get_sources(inst)

    with pytest.warns(RuntimeWarning, match='MISC channel'):
        raw = read_raw_ctf(ctf_fname2, preload=True)
    events = make_fixed_length_events(raw,
                                      99999,
                                      start=0,
                                      stop=0.2,
                                      duration=0.1)
    picks_meg = pick_types(raw.info, meg=True, eeg=False)[:2]
    picks_eeg = pick_types(raw.info, meg=False, eeg=True)[:2]
    # concatenate (not element-wise add) the two channel selections
    picks_all = list(picks_meg) + list(picks_eeg)
    for comp in [0, 1]:
        raw.apply_gradient_compensation(comp)
        epochs = Epochs(raw, events, None, -0.1, 0.1, preload=True)
        evoked = epochs.average()

        for picks in [picks_meg, picks_eeg, picks_all]:
            if len(picks) == 0:
                continue  # recording has no channels of this type
            # test fit; pass the selection so each iteration fits different
            # channels (as in the FIF/EEGLAB section above)
            for inst in [raw, epochs]:
                ica = ICA(n_components=2,
                          random_state=0,
                          max_iter=2,
                          method=method)
                with pytest.warns(None):
                    ica.fit(inst, picks=picks)

            # test apply and get_sources
            for inst in [raw, epochs, evoked]:
                ica.apply(inst)
                ica.get_sources(inst)
Beispiel #41
0
# We also fix the state of the random number generator: ICA is a
# non-deterministic algorithm, but we want the same decomposition and the
# same order of components every time this tutorial is run.
random_state = 23

###############################################################################
# Define the ICA object instance
ica = ICA(n_components=n_components, method=method, random_state=random_state)
print(ica)

###############################################################################
# We avoid fitting ICA on extreme environmental artifacts, which would
# dominate the variance and therefore the decomposition; ``reject`` gives the
# per-channel-type (mag / grad) amplitude limits used to leave them out.
reject = dict(mag=5e-12, grad=4000e-13)
ica.fit(raw, picks='meg', decim=decim, reject=reject)
print(ica)

###############################################################################
# Plot ICA components
ica.plot_components()  # can you spot some potential bad guys?


###############################################################################
# Component properties
# --------------------
#
# Let's take a closer look at the properties of the first three independent
# components.

# first, component 0:
ica.plot_properties(raw, picks=0)
Beispiel #42
0
def test_ica_core():
    """Test ICA on raw and epochs.

    Iterates over every combination of noise covariance (``None`` or a test
    covariance), ``n_components`` (an int or an explained-variance float),
    ``max_pca_components`` and ICA method. For each combination it checks
    fitting, deterministic re-fitting, source extraction and error handling
    on Raw and Epochs.
    """
    raw = Raw(raw_fname).crop(1.5, stop, False)
    raw.load_data()
    # XXX. The None cases helped revealing bugs but are time consuming.
    test_cov = read_cov(test_cov_name)
    events = read_events(event_name)
    picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
                       eog=False, exclude='bads')
    epochs = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), preload=True)
    noise_cov = [None, test_cov]
    # removed None cases to speed up...
    n_components = [2, 1.0]  # for future dbg add cases
    max_pca_components = [3]
    picks_ = [picks]
    methods = ['fastica']
    iter_ica_params = product(noise_cov, n_components, max_pca_components,
                              picks_, methods)

    # # test init catchers
    # n_components larger than max_pca_components, or a float > 1, is invalid
    assert_raises(ValueError, ICA, n_components=3, max_pca_components=2)
    assert_raises(ValueError, ICA, n_components=2.3, max_pca_components=2)

    # test essential core functionality
    for n_cov, n_comp, max_n, pcks, method in iter_ica_params:
        # Test ICA raw
        ica = ICA(noise_cov=n_cov, n_components=n_comp,
                  max_pca_components=max_n, n_pca_components=max_n,
                  random_state=0, method=method, max_iter=1)
        # channel-type membership ('mag' in ica) is an error before fitting
        assert_raises(ValueError, ica.__contains__, 'mag')

        print(ica)  # to test repr

        # test fit checker
        assert_raises(RuntimeError, ica.get_sources, raw)
        assert_raises(RuntimeError, ica.get_sources, epochs)

        # test decomposition
        with warnings.catch_warnings(record=True):
            ica.fit(raw, picks=pcks, start=start, stop=stop)
            repr(ica)  # to test repr
        assert_true('mag' in ica)  # should now work without error

        # test re-fit: with random_state=0 the unmixing must be reproduced
        unmixing1 = ica.unmixing_matrix_
        with warnings.catch_warnings(record=True):
            ica.fit(raw, picks=pcks, start=start, stop=stop)
        assert_array_almost_equal(unmixing1, ica.unmixing_matrix_)

        sources = ica.get_sources(raw)[:, :][0]
        assert_true(sources.shape[0] == ica.n_components_)

        # test preload filter: apply on non-preloaded data must fail
        raw3 = raw.copy()
        raw3.preload = False
        assert_raises(ValueError, ica.apply, raw3,
                      include=[1, 2])

        #######################################################################
        # test epochs decomposition (same method as the raw decomposition)
        ica = ICA(noise_cov=n_cov, n_components=n_comp,
                  max_pca_components=max_n, n_pca_components=max_n,
                  random_state=0, method=method)
        with warnings.catch_warnings(record=True):
            ica.fit(epochs, picks=picks)
        # n_samples_ counts samples across all epochs
        data = epochs.get_data()[:, 0, :]
        n_samples = np.prod(data.shape)
        assert_equal(ica.n_samples_, n_samples)
        print(ica)  # to test repr

        sources = ica.get_sources(epochs).get_data()
        assert_true(sources.shape[1] == ica.n_components_)

        assert_raises(ValueError, ica.score_sources, epochs,
                      target=np.arange(1))

        # test preload filter
        epochs3 = epochs.copy()
        epochs3.preload = False
        assert_raises(ValueError, ica.apply, epochs3,
                      include=[1, 2])

    # test for bug with whitener updating: applying to data with a large
    # outlier must leave the fitted pre-whitener unchanged
    _pre_whitener = ica._pre_whitener.copy()
    epochs._data[:, 0, 10:15] *= 1e12
    ica.apply(epochs, copy=True)
    assert_array_equal(_pre_whitener, ica._pre_whitener)

    # test expl. var threshold leading to empty sel
    ica.n_components = 0.1
    assert_raises(RuntimeError, ica.fit, epochs)

    # a plain tuple is not a valid data instance
    offender = 1, 2, 3,
    assert_raises(ValueError, ica.get_sources, offender)
    assert_raises(ValueError, ica.fit, offender)
    assert_raises(ValueError, ica.apply, offender)
# Other available choices are `infomax` or `extended-infomax`.
# We pass a float value between 0 and 1 as ``n_components`` to select the
# number of components by the fraction of variance explained by the PCA
# components (here 95%).

ica = ICA(n_components=0.95, method='fastica', random_state=0, max_iter=100)

picks = mne.pick_types(raw.info,
                       meg=True,
                       eeg=False,
                       eog=False,
                       stim=False,
                       exclude='bads')

ica.fit(raw,
        picks=picks,
        decim=3,
        reject=dict(mag=4e-12, grad=4000e-13),
        verbose='warning')  # few iterations -> may not fully converge

# maximum number of components to reject
n_max_ecg, n_max_eog = 3, 1  # here we don't expect horizontal EOG components

###############################################################################
# 2) identify bad components by analyzing latent sources.

title = 'Sources related to %s artifacts (red)'

# generate ECG epochs for detection via phase statistics

ecg_epochs = create_ecg_epochs(raw, tmin=-.5, tmax=.5, picks=picks)
Beispiel #44
0
# a huge number of components to do a good job of isolating our artifacts
# (though it is usually preferable to include more components for a more
# accurate solution). As a first guess, we'll run ICA with ``n_components=15``
# (use only the first 15 PCA components to compute the ICA decomposition) — a
# very small number given that our data has over 300 channels, but with the
# advantage that it will run quickly and we will be able to tell easily whether
# it worked or not (because we already know what the EOG / ECG artifacts
# should look like).
#
# ICA fitting is not deterministic (e.g., the components may get a sign
# flip on different runs, or may not always be returned in the same order), so
# we'll also specify a `random seed`_ so that we get identical results each
# time this tutorial is built by our web servers.

ica = ICA(n_components=15, random_state=97)
ica.fit(filt_raw)

###############################################################################
# Some optional parameters that we could have passed to the
# :meth:`~mne.preprocessing.ICA.fit` method include ``decim`` (to use only
# every Nth sample in computing the ICs, which can yield a considerable
# speed-up) and ``reject`` (for providing a rejection dictionary of maximum
# acceptable peak-to-peak amplitudes for each channel type, just like we used
# when creating epoched data in the :ref:`tut-overview` tutorial).
#
# Now we can examine the ICs to see what they captured.
# :meth:`~mne.preprocessing.ICA.plot_sources` will show the time series of the
# ICs. Note that in our call to :meth:`~mne.preprocessing.ICA.plot_sources` we
# can use the original, unfiltered :class:`~mne.io.Raw` object:

raw.load_data()
Beispiel #45
0
def preprocess_eeg(id_num, random_seed=None):
    """Run the full preprocessing pipeline for a single participant.

    Steps:

    1. Load the participant's BIDS EEG recording.
    2. Clean up its event annotations.
    3. Remove line noise CleanLine-style (trend removal plus spectrum-fit
       notch filtering).
    4. Run PREP robust re-referencing.
    5. Filter the data from 1 to 50 Hz.
    6. Remove blink components with ICA.
    7. Optionally compute CSD estimates.
    8. Write the result to EDF.

    PSD and channel plots for each stage are saved under
    ``plotdir/sub_<id_num>``; any existing folder at that path is deleted
    first.

    Relies on module-level configuration defined elsewhere in the file:
    ``task``, ``datatype``, ``bids_root``, ``plotdir``, ``interpolate_bads``,
    ``max_interpolated``, ``perform_csd``, ``noisy_bad_dir``, ``outdir``,
    ``outfile_fmt`` and ``ica_err_dir``.

    Parameters
    ----------
    id_num : str
        Participant ID, used to build the BIDS path and output file names
        (presumably a string, as passed to ``BIDSPath`` -- confirm).
    random_seed : int | None
        Seed passed to ``random.seed``, PREP's RANSAC and ICA. If ``None``,
        a random 32-bit seed is drawn from ``os.urandom`` and recorded in the
        returned info so the run can be reproduced. ``0`` is a valid seed
        and is used as-is.

    Returns
    -------
    id_info : dict
        Summary for the participant: seed, annotation clean-up counts, and
        bad/interpolated channel lists.

        It is returned early, without writing an output file, in two cases:
        for incomplete recordings (fewer than 600 annotations), where the
        channel fields are set to "NA"; and when ICA finds no EOG-related
        components, where the data are saved via ``save_bad_fif`` instead.
    """

    # Set important variables
    bids_path = BIDSPath(id_num, task=task, datatype=datatype, root=bids_root)
    plot_path = os.path.join(plotdir, "sub_{0}".format(id_num))
    if os.path.exists(plot_path):
        shutil.rmtree(plot_path)
    os.mkdir(plot_path)
    # 0 is a valid seed, so only draw a new one when no seed was given at all
    if random_seed is None:
        random_seed = int(binascii.b2a_hex(os.urandom(4)), 16)
    random.seed(random_seed)
    id_info = {"id": id_num, "random_seed": random_seed}

    ### Load and prepare EEG data #############################################

    header = "### Processing sub-{0} (seed: {1}) ###".format(
        id_num, random_seed)
    print("\n" + "#" * len(header))
    print(header)
    print("#" * len(header) + "\n")

    # Load EEG data
    raw = read_raw_bids(bids_path, verbose=True)

    # Check if recording is complete
    complete = len(raw.annotations) >= 600

    # Add a montage to the data
    montage_kind = "standard_1005"
    montage = mne.channels.make_standard_montage(montage_kind)
    mne.datasets.eegbci.standardize(raw)
    raw.set_montage(montage)

    # Extract some info
    eeg_index = mne.pick_types(raw.info, eeg=True, eog=False, meg=False)
    ch_names = raw.info["ch_names"]
    ch_names_eeg = list(np.asarray(ch_names)[eeg_index])
    sample_rate = raw.info["sfreq"]

    # Make a copy of the data
    raw_copy = raw.copy()
    raw_copy.load_data()

    # Trim duplicated data (only needed for sub-005)
    annot = raw_copy.annotations
    file_starts = [a for a in annot if a['description'] == "file start"]
    if len(file_starts):
        duplicate_start = file_starts[0]['onset']
        raw_copy.crop(tmax=duplicate_start)

    # Make backup of EOG and EMG channels to re-append after PREP
    raw_other = raw_copy.copy()
    raw_other.pick_types(eog=True, emg=True, stim=False)

    # Prepare copy of raw data for PREP
    raw_copy.pick_types(eeg=True)

    # Plot data prior to any processing
    if complete:
        save_psd_plot(id_num, "psd_0_raw", plot_path, raw_copy)
        save_channel_plot(id_num, "ch_0_raw", plot_path, raw_copy)

    ### Clean up events #######################################################

    print("\n\n=== Processing Event Annotations... ===\n")

    event_names = [
        "stim_on", "red_on", "trace_start", "trace_end", "accuracy_submit",
        "vividness_submit"
    ]
    doubled = []
    wrong_label = []
    new_onsets = []
    new_durations = []
    new_descriptions = []

    # Find and flag any duplicate triggers. An annotation whose description
    # repeats the previous one is classified in one of two ways:
    #   - "doubled" (dropped later) when the next trigger follows within 2 ms;
    #   - "mislabeled" (renamed later to the next event type) otherwise.
    # The loop includes the final annotation. It has no successor, so its
    # "next onset" is pushed far away and it can only count as mislabeled.
    annot = raw_copy.annotations
    trigger_count = len(annot)
    for i in range(1, trigger_count):
        a = annot[i]
        on_last = i + 1 == trigger_count
        prev_trigger = annot[i - 1]['description']
        next_onset = annot[i + 1]['onset'] if not on_last else a['onset'] + 100
        # Determine whether duplicates are doubles or mislabeled
        if a['description'] == prev_trigger:
            if (next_onset - a['onset']) < 0.002:
                doubled.append(a)
            else:
                wrong_label.append(a)

    # Rename annotations to have meaningful names & fix duplicates
    for a in raw_copy.annotations:
        if a in doubled or a['description'] not in event_names:
            continue
        if a in wrong_label:
            index = event_names.index(a['description'])
            # The last event type has no successor to rename to; keep its
            # label instead of raising an IndexError.
            if index + 1 < len(event_names):
                a['description'] = event_names[index + 1]
        new_onsets.append(a['onset'])
        new_durations.append(a['duration'])
        new_descriptions.append(a['description'])

    # Replace old annotations with new fixed ones
    if len(annot):
        new_annot = mne.Annotations(
            new_onsets,
            new_durations,
            new_descriptions,
            orig_time=raw_copy.annotations[0]['orig_time'])
        raw_copy.set_annotations(new_annot)

    # Check annotations to verify we have equal numbers of each
    orig_counts = Counter(annot.description)
    counts = Counter(raw_copy.annotations.description)
    print("Updated Annotation Counts:")
    for a in event_names:
        out = " - '{0}': {1} -> {2}"
        print(out.format(a, orig_counts[a], counts[a]))

    # Get info
    id_info['annot_doubled'] = len(doubled)
    id_info['annot_wrong'] = len(wrong_label)

    # Counts equal to the 'vividness_submit' count are left out of the
    # equality check; the check is True when no counts remain.
    count_vals = [
        n for n in counts.values() if n != counts['vividness_submit']
    ]
    id_info['equal_triggers'] = all(x == count_vals[0] for x in count_vals)
    id_info['stim_on'] = counts['stim_on']
    id_info['red_on'] = counts['red_on']
    id_info['trace_start'] = counts['trace_start']
    id_info['trace_end'] = counts['trace_end']
    id_info['acc_submit'] = counts['accuracy_submit']
    id_info['vivid_submit'] = counts['vividness_submit']

    if not complete:
        remaining_info = {
            'initial_bad': "NA",
            'num_initial_bad': "NA",
            'interpolated': "NA",
            'num_interpolated': "NA",
            'remaining_bad': "NA",
            'num_remaining_bad': "NA"
        }
        id_info.update(remaining_info)
        e = "\n\n### Incomplete recording for sub-{0}, skipping... ###\n\n"
        print(e.format(id_num))
        return id_info

    ### Run components of PREP manually #######################################

    print("\n\n=== Performing CleanLine... ===")

    # Try to remove line noise using CleanLine approach: notch-filter the
    # detrended signal at 60 Hz harmonics below Nyquist, then add the trend
    # back. The data are scaled by 1e6 for processing and scaled back after.
    linenoise = np.arange(60, sample_rate / 2, 60)
    EEG_raw = raw_copy.get_data() * 1e6
    EEG_new = removeTrend(EEG_raw, sample_rate=raw.info["sfreq"])
    EEG_clean = mne.filter.notch_filter(
        EEG_new,
        Fs=raw.info["sfreq"],
        freqs=linenoise,
        filter_length="10s",
        method="spectrum_fit",
        mt_bandwidth=2,
        p_value=0.01,
    )
    EEG_final = EEG_raw - EEG_new + EEG_clean
    raw_copy._data = EEG_final * 1e-6
    del linenoise, EEG_raw, EEG_new, EEG_clean, EEG_final

    # Plot data following cleanline
    save_psd_plot(id_num, "psd_1_cleanline", plot_path, raw_copy)
    save_channel_plot(id_num, "ch_1_cleanline", plot_path, raw_copy)

    # Perform robust re-referencing
    prep_params = {"ref_chs": ch_names_eeg, "reref_chs": ch_names_eeg}
    reference = Reference(raw_copy,
                          prep_params,
                          ransac=True,
                          random_state=random_seed)
    print("\n\n=== Performing Robust Re-referencing... ===\n")
    reference.perform_reference()

    # If not interpolating bad channels, use pre-interpolation channel data
    if not interpolate_bads:
        reference.raw._data = reference.EEG_before_interpolation * 1e-6
        reference.interpolated_channels = []
        reference.still_noisy_channels = reference.bad_before_interpolation
        reference.raw.info["bads"] = reference.bad_before_interpolation

    # Plot data following robust re-reference
    save_psd_plot(id_num, "psd_2_reref", plot_path, reference.raw)
    save_channel_plot(id_num, "ch_2_reref", plot_path, reference.raw)

    # Re-append the EOG/EMG channels set aside before PREP (stim channels
    # were excluded from that backup)
    raw_prepped = reference.raw.add_channels([raw_other])

    # Get info
    initial_bad = reference.noisy_channels_original["bad_all"]
    id_info['initial_bad'] = " ".join(initial_bad)
    id_info['num_initial_bad'] = len(initial_bad)

    interpolated = reference.interpolated_channels
    id_info['interpolated'] = " ".join(interpolated)
    id_info['num_interpolated'] = len(interpolated)

    remaining_bad = reference.still_noisy_channels
    id_info['remaining_bad'] = " ".join(remaining_bad)
    id_info['num_remaining_bad'] = len(remaining_bad)

    # Print re-referencing info
    print("\nRe-Referencing Info:")
    print(" - Bad channels original: {0}".format(initial_bad))
    if interpolate_bads:
        print(" - Bad channels after re-referencing: {0}".format(interpolated))
        print(" - Bad channels after interpolation: {0}".format(remaining_bad))
    else:
        print(
            " - Bad channels after re-referencing: {0}".format(remaining_bad))

    # Check if too many channels were interpolated for the participant
    prop_interpolated = len(
        reference.interpolated_channels) / len(ch_names_eeg)
    e = "### NOTE: Too many interpolated channels for sub-{0} ({1}) ###"
    if max_interpolated < prop_interpolated:
        print("\n")
        print(e.format(id_num, len(reference.interpolated_channels)))
        print("\n")

    ### Filter data and apply ICA to remove blinks ############################

    # Apply highpass & lowpass filters
    print("\n\n=== Applying Highpass & Lowpass Filters... ===")
    raw_prepped.filter(1.0, 50.0, fir_design='firwin')

    # Plot data following frequency filters
    save_psd_plot(id_num, "psd_3_filtered", plot_path, raw_prepped)
    save_channel_plot(id_num, "ch_3_filtered", plot_path, raw_prepped)

    # Perform ICA using EOG data on eye blinks
    print("\n\n=== Removing Blinks Using ICA... ===\n")
    ica = ICA(n_components=20, random_state=random_seed, method='picard')
    ica.fit(raw_prepped, decim=5)
    eog_indices, eog_scores = ica.find_bads_eog(raw_prepped)
    ica.exclude = eog_indices

    # No EOG-related component found: save the data for manual inspection
    # and stop processing this participant.
    if not len(ica.exclude):
        err = " - Encountered an ICA error for sub-{0}, skipping for now..."
        print("\n")
        print(err.format(id_num))
        print("\n")
        save_bad_fif(raw_prepped, id_num, ica_err_dir)
        return id_info

    # Plot ICA info & diagnostics before removing from signal
    save_ica_plots(id_num, plot_path, raw_prepped, ica, eog_scores)

    # Remove eye blink independent components based on ICA
    ica.apply(raw_prepped)

    # Plot data following ICA
    save_psd_plot(id_num, "psd_4_ica", plot_path, raw_prepped)
    save_channel_plot(id_num, "ch_4_ica", plot_path, raw_prepped)

    ### Compute Current Source Density (CSD) estimates ########################

    if perform_csd:
        print("\n")
        print("=== Computing Current Source Density (CSD) Estimates... ===\n")
        raw_prepped = mne.preprocessing.compute_current_source_density(
            raw_prepped.drop_channels(remaining_bad))

        # Plot data following CSD
        save_psd_plot(id_num, "psd_5_csd", plot_path, raw_prepped)
        save_channel_plot(id_num, "ch_5_csd", plot_path, raw_prepped)

    ### Write preprocessed data to new EDF ####################################

    # Participants with too many interpolated channels go to a separate folder
    if max_interpolated < prop_interpolated:
        if not os.path.isdir(noisy_bad_dir):
            os.makedirs(noisy_bad_dir)
        outpath = os.path.join(noisy_bad_dir, outfile_fmt.format(id_num))
    else:
        outpath = os.path.join(outdir, outfile_fmt.format(id_num))
    write_mne_edf(outpath, raw_prepped)

    print("\n\n### sub-{0} complete! ###\n\n".format(id_num))

    return id_info
Beispiel #46
0
def test_plot_ica_properties():
    """Test plotting of ICA properties.

    Coverage:

    - ``ICA.plot_properties`` and ``plot_ica_properties`` on Raw and Epochs.
    - The private ``_create_properties_layout`` helper.
    - Argument validation errors.
    - User-supplied axes.
    - Merged gradiometers.
    - Epochs containing all-zero data.
    - Raw data carrying BAD annotations.
    """
    raw = _get_raw(preload=True).crop(0, 5)
    raw.add_proj([], remove_existing=True)
    with raw.info._unlock():
        raw.info['highpass'] = 1.0  # fake high-pass filtering
    events = make_fixed_length_events(raw)
    # keep only six channels
    picks = _get_picks(raw)[:6]
    pick_names = [raw.ch_names[k] for k in picks]
    raw.pick_channels(pick_names)
    reject = dict(grad=4000e-13, mag=4e-12)

    # only the first three events are epoched
    epochs = Epochs(raw,
                    events[:3],
                    event_id,
                    tmin,
                    tmax,
                    baseline=(None, 0),
                    preload=True)

    # max_iter=1 keeps the fit cheap; the result need not be meaningful here
    ica = ICA(noise_cov=read_cov(cov_fname),
              n_components=2,
              max_iter=1,
              random_state=0)
    with pytest.warns(RuntimeWarning, match='projection'):
        ica.fit(raw)

    # test _create_properties_layout: five axes by default, and fig and
    # figsize may not be passed together
    fig, ax = _create_properties_layout()
    assert_equal(len(ax), 5)
    with pytest.raises(ValueError, match='specify both fig and figsize'):
        _create_properties_layout(figsize=(2, 2), fig=fig)

    # low-resolution topomap without contours/sensors
    topoargs = dict(topomap_args={'res': 4, 'contours': 0, "sensors": False})
    with catch_logging() as log:
        ica.plot_properties(raw, picks=0, verbose='debug', **topoargs)
    log = log.getvalue()
    assert raw.ch_names[0] == 'MEG 0113'
    # the debug log must report the topomap interpolation mode
    assert 'Interpolation mode local to mean' in log, log
    ica.plot_properties(epochs, picks=1, dB=False, plot_std=1.5, **topoargs)
    ica.plot_properties(epochs,
                        picks=1,
                        image_args={'sigma': 1.5},
                        topomap_args={
                            'res': 4,
                            'colorbar': True
                        },
                        psd_args={'fmax': 65.},
                        plot_std=False,
                        figsize=[4.5, 4.5],
                        reject=reject)
    plt.close('all')

    # invalid argument types must be rejected with a TypeError
    with pytest.raises(TypeError, match='must be an instance'):
        ica.plot_properties(epochs, dB=list('abc'))
    with pytest.raises(TypeError, match='must be an instance'):
        ica.plot_properties(ica)
    with pytest.raises(TypeError, match='must be an instance'):
        ica.plot_properties([0.2])
    with pytest.raises(TypeError, match='must be an instance'):
        plot_ica_properties(epochs, epochs)
    with pytest.raises(TypeError, match='must be an instance'):
        ica.plot_properties(epochs, psd_args='not dict')
    with pytest.raises(TypeError, match='must be an instance'):
        ica.plot_properties(epochs, plot_std=[])

    # user-supplied axes: a 2x3 grid minus its last cell gives the five axes
    # expected by the layout checked above
    fig, ax = plt.subplots(2, 3)
    ax = ax.ravel()[:-1]
    ica.plot_properties(epochs, picks=1, axes=ax, **topoargs)
    pytest.raises(TypeError,
                  plot_ica_properties,
                  epochs,
                  ica,
                  picks=[0, 1],
                  axes=ax)
    pytest.raises(ValueError, ica.plot_properties, epochs, axes='not axes')
    plt.close('all')

    # Test merging grads.
    # select the channels currently in ``raw`` again, from a fresh copy of
    # the raw data, with the even-indexed ones listed first
    pick_names = raw.ch_names[:15:2] + raw.ch_names[1:15:2]
    raw = _get_raw(preload=True).pick_channels(pick_names).crop(0, 5)
    raw.info.normalize_proj()
    ica = ICA(random_state=0, max_iter=1)
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(raw)
    ica.plot_properties(raw)
    plt.close('all')

    # Test handling of zeros
    ica = ICA(random_state=0, max_iter=1)
    epochs.pick_channels(pick_names)
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(epochs)
    # zero out the whole first epoch after fitting
    epochs._data[0] = 0
    # Usually UserWarning: Infinite value .* for epo
    with _record_warnings():
        ica.plot_properties(epochs, **topoargs)
    plt.close('all')

    # Test Raw with annotations
    # a 1 s BAD segment starting at 1 s, on 8 s of data with 10 channels
    annot = Annotations(onset=[1], duration=[1], description=['BAD'])
    raw_annot = _get_raw(preload=True).set_annotations(annot).crop(0, 8)
    raw_annot.pick(np.arange(10))
    raw_annot.del_proj()

    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(raw_annot)
    # drop bad data segments
    fig = ica.plot_properties(raw_annot, picks=[0, 1], **topoargs)
    # one figure per picked component
    assert_equal(len(fig), 2)
    # don't drop
    ica.plot_properties(raw_annot, reject_by_annotation=False, **topoargs)
Beispiel #47
0
def test_plot_ica_sources():
    """Test plotting of ICA panel.

    Coverage of ``ICA.plot_sources``:

    - Raw, Epochs and Evoked input.
    - Toggling ``ica.exclude`` by clicking.
    - Right-clicking a y-label to open a child figure.
    - Annotation rendering.
    - Errors when the instance no longer matches the fitted channels.
    - Component labels.
    - Rejection of an invalid ``inst``.
    """
    raw = read_raw_fif(raw_fname).crop(0, 1).load_data()
    picks = _get_picks(raw)
    epochs = _get_epochs()
    raw.pick_channels([raw.ch_names[k] for k in picks])
    # fit on MEG channels only, excluding bads
    ica_picks = pick_types(raw.info,
                           meg=True,
                           eeg=False,
                           stim=False,
                           ecg=False,
                           eog=False,
                           exclude='bads')
    ica = ICA(n_components=2)
    ica.fit(raw, picks=ica_picks)
    ica.exclude = [1]
    fig = ica.plot_sources(raw)
    assert len(plt.get_fignums()) == 1
    # change which component is in ICA.exclude (click data trace to remove
    # current one; click name to add other one)
    fig.canvas.draw()
    x = fig.mne.traces[1].get_xdata()[5]
    y = fig.mne.traces[1].get_ydata()[5]
    _fake_click(fig, fig.mne.ax_main, (x, y), xform='data')  # exclude = []
    _click_ch_name(fig, ch_index=0, button=1)  # exclude = [0]
    fig.canvas.key_press_event(fig.mne.close_key)
    _close_event(fig)
    assert len(plt.get_fignums()) == 0
    # the clicks must have been written back to the ICA object on close
    assert_array_equal(ica.exclude, [0])
    # test when picks does not include ica.exclude.
    fig = ica.plot_sources(raw, picks=[1])
    assert len(plt.get_fignums()) == 1
    plt.close('all')

    # dtype can change int->np.int64 after load, test it explicitly
    ica.n_components_ = np.int64(ica.n_components_)

    # test clicks on y-label (need >2 secs for plot_properties() to work)
    long_raw = read_raw_fif(raw_fname).crop(0, 5).load_data()
    fig = ica.plot_sources(long_raw)
    assert len(plt.get_fignums()) == 1
    fig.canvas.draw()
    # right-click (button=3) left of the data area, on the y-label
    _fake_click(fig, fig.mne.ax_main, (-0.1, 0), xform='data', button=3)
    assert len(fig.mne.child_figs) == 1
    assert len(plt.get_fignums()) == 2
    # close child fig directly (workaround for mpl issue #18609)
    fig.mne.child_figs[0].canvas.key_press_event('escape')
    assert len(plt.get_fignums()) == 1
    fig.canvas.key_press_event(fig.mne.close_key)
    assert len(plt.get_fignums()) == 0
    del long_raw

    # test with annotations
    orig_annot = raw.annotations
    raw.set_annotations(Annotations([0.2], [0.1], 'Test'))
    fig = ica.plot_sources(raw)
    # the annotation is drawn in both the main axes and the scrollbar
    assert len(fig.mne.ax_main.collections) == 1
    assert len(fig.mne.ax_hscroll.collections) == 1
    raw.set_annotations(orig_annot)

    # test error handling
    # marking a fitted channel as bad makes the instance mismatch the fit
    raw.info['bads'] = ['MEG 0113']
    with pytest.raises(RuntimeError, match="Raw doesn't match fitted data"):
        ica.plot_sources(inst=raw)
    epochs.info['bads'] = ['MEG 0113']
    with pytest.raises(RuntimeError, match="Epochs don't match fitted data"):
        ica.plot_sources(inst=epochs)
    epochs.info['bads'] = []

    # test w/ epochs and evokeds
    ica.plot_sources(epochs)
    ica.plot_sources(epochs.average())
    evoked = epochs.average()
    fig = ica.plot_sources(evoked)
    # Test a click
    # click on the first data point of the first line, then on an empty
    # corner of the axes
    ax = fig.get_axes()[0]
    line = ax.lines[0]
    _fake_click(fig, ax, [line.get_xdata()[0], line.get_ydata()[0]], 'data')
    _fake_click(fig, ax, [ax.get_xlim()[0], ax.get_ylim()[1]], 'data')
    # plot with bad channels excluded
    ica.exclude = [0]
    ica.plot_sources(evoked)
    ica.labels_ = dict(eog=[0])
    ica.labels_['eog/0/crazy-channel'] = [0]
    ica.plot_sources(evoked)  # now with labels
    with pytest.raises(ValueError, match='must be of Raw or Epochs type'):
        ica.plot_sources('meeow')
Beispiel #48
0
def test_plot_ica_sources(raw_orig, mpl_backend):
    """Test plotting of ICA panel.

    Coverage of ``ICA.plot_sources`` through the ``mpl_backend`` fixture:

    - Raw, Epochs and Evoked input.
    - Toggling ``ica.exclude`` by clicking.
    - Right-clicking a y-label to open a child figure.
    - Annotation rendering.
    - Errors when channels were dropped after fitting.
    - Component labels.
    - Rejection of an invalid ``inst``.

    ``raw_orig`` is only ever copied, never modified in place.
    """
    raw = raw_orig.copy().crop(0, 1)
    picks = _get_picks(raw)
    epochs = _get_epochs()
    raw.pick_channels([raw.ch_names[k] for k in picks])
    # fit on MEG channels only, excluding bads
    ica_picks = pick_types(raw.info,
                           meg=True,
                           eeg=False,
                           stim=False,
                           ecg=False,
                           eog=False,
                           exclude='bads')
    ica = ICA(n_components=2)
    ica.fit(raw, picks=ica_picks)
    ica.exclude = [1]
    fig = ica.plot_sources(raw)
    assert mpl_backend._get_n_figs() == 1
    # change which component is in ICA.exclude (click data trace to remove
    # current one; click name to add other one)
    fig._redraw()
    # ToDo: This will be different methods in pyqtgraph
    x = fig.mne.traces[1].get_xdata()[5]
    y = fig.mne.traces[1].get_ydata()[5]
    fig._fake_click((x, y), xform='data')  # exclude = []
    _click_ch_name(fig, ch_index=0, button=1)  # exclude = [0]
    fig._fake_keypress(fig.mne.close_key)
    fig._close_event()
    assert mpl_backend._get_n_figs() == 0
    # the clicks must have been written back to the ICA object on close
    assert_array_equal(ica.exclude, [0])
    # test when picks does not include ica.exclude.
    fig = ica.plot_sources(raw, picks=[1])
    assert len(plt.get_fignums()) == 1
    mpl_backend._close_all()

    # dtype can change int->np.int64 after load, test it explicitly
    ica.n_components_ = np.int64(ica.n_components_)

    # test clicks on y-label (need >2 secs for plot_properties() to work)
    # copy first: crop() works in place and must not alter the fixture
    long_raw = raw_orig.copy().crop(0, 5)
    fig = ica.plot_sources(long_raw)
    assert len(plt.get_fignums()) == 1
    fig._redraw()
    _click_ch_name(fig, ch_index=0, button=3)
    assert len(fig.mne.child_figs) == 1
    assert len(plt.get_fignums()) == 2
    # close child fig directly (workaround for mpl issue #18609)
    fig._fake_keypress('escape', fig=fig.mne.child_figs[0])
    assert len(plt.get_fignums()) == 1
    fig._fake_keypress(fig.mne.close_key)
    assert len(plt.get_fignums()) == 0
    del long_raw

    # test with annotations
    orig_annot = raw.annotations
    raw.set_annotations(Annotations([0.2], [0.1], 'Test'))
    fig = ica.plot_sources(raw)
    # the annotation is drawn in both the main axes and the scrollbar
    assert len(fig.mne.ax_main.collections) == 1
    assert len(fig.mne.ax_hscroll.collections) == 1
    raw.set_annotations(orig_annot)

    # test error handling
    # dropping a fitted channel makes the instance mismatch the fit
    raw_ = raw.copy().load_data()
    raw_.drop_channels('MEG 0113')
    with pytest.raises(RuntimeError, match="Raw doesn't match fitted data"), \
         pytest.warns(RuntimeWarning, match='could not be picked'):
        ica.plot_sources(inst=raw_)
    epochs_ = epochs.copy().load_data()
    epochs_.drop_channels('MEG 0113')
    with pytest.raises(RuntimeError, match="Epochs don't match fitted data"), \
         pytest.warns(RuntimeWarning, match='could not be picked'):
        ica.plot_sources(inst=epochs_)
    del raw_
    del epochs_

    # test w/ epochs and evokeds
    ica.plot_sources(epochs)
    ica.plot_sources(epochs.average())
    evoked = epochs.average()
    fig = ica.plot_sources(evoked)
    # Test a click
    ax = fig.get_axes()[0]
    line = ax.lines[0]
    _fake_click(fig, ax, [line.get_xdata()[0], line.get_ydata()[0]], 'data')
    _fake_click(fig, ax, [ax.get_xlim()[0], ax.get_ylim()[1]], 'data')

    # plot with bad channels excluded
    ica.exclude = [0]
    ica.plot_sources(evoked)

    # pretend find_bads_eog() yielded some results
    ica.labels_ = {'eog': [0], 'eog/0/crazy-channel': [0]}
    ica.plot_sources(evoked)  # now with labels

    # pass an invalid inst
    with pytest.raises(ValueError, match='must be of Raw or Epochs type'):
        ica.plot_sources('meeow')
Beispiel #49
0
def test_plot_ica_properties():
    """Test plotting of ICA properties.

    Coverage:

    - ``ICA.plot_properties`` and ``plot_ica_properties`` on Raw and Epochs.
    - The private ``_create_properties_layout`` helper.
    - Argument validation errors.
    - User-supplied axes.
    - Merged gradiometers.
    """
    # topomap resolution kept low for speed
    res = 8
    raw = _get_raw(preload=True)
    raw.add_proj([], remove_existing=True)
    events = _get_events()
    # keep only six channels
    picks = _get_picks(raw)[:6]
    pick_names = [raw.ch_names[k] for k in picks]
    raw.pick_channels(pick_names)
    reject = dict(grad=4000e-13, mag=4e-12)

    epochs = Epochs(raw,
                    events[:10],
                    event_id,
                    tmin,
                    tmax,
                    baseline=(None, 0),
                    preload=True)

    ica = ICA(noise_cov=read_cov(cov_fname),
              n_components=2,
              max_pca_components=2,
              n_pca_components=2)
    with pytest.warns(RuntimeWarning, match='projection'):
        ica.fit(raw)

    # test _create_properties_layout: five axes are expected
    fig, ax = _create_properties_layout()
    assert_equal(len(ax), 5)

    topoargs = dict(topomap_args={'res': res, 'contours': 0, "sensors": False})
    ica.plot_properties(raw, picks=0, **topoargs)
    ica.plot_properties(epochs, picks=1, dB=False, plot_std=1.5, **topoargs)
    ica.plot_properties(epochs,
                        picks=1,
                        image_args={'sigma': 1.5},
                        topomap_args={
                            'res': 10,
                            'colorbar': True
                        },
                        psd_args={'fmax': 65.},
                        plot_std=False,
                        figsize=[4.5, 4.5],
                        reject=reject)
    plt.close('all')

    # invalid arguments must be rejected
    pytest.raises(TypeError, ica.plot_properties, epochs, dB=list('abc'))
    pytest.raises(TypeError, ica.plot_properties, ica)
    pytest.raises(TypeError, ica.plot_properties, [0.2])
    pytest.raises(TypeError, plot_ica_properties, epochs, epochs)
    pytest.raises(TypeError, ica.plot_properties, epochs, psd_args='not dict')
    pytest.raises(ValueError, ica.plot_properties, epochs, plot_std=[])

    # user-supplied axes: a 2x3 grid minus its last cell gives five axes
    fig, ax = plt.subplots(2, 3)
    ax = ax.ravel()[:-1]
    ica.plot_properties(epochs, picks=1, axes=ax, **topoargs)
    # one figure per picked component
    fig = ica.plot_properties(raw, picks=[0, 1], **topoargs)
    assert_equal(len(fig), 2)
    pytest.raises(TypeError,
                  plot_ica_properties,
                  epochs,
                  ica,
                  picks=[0, 1],
                  axes=ax)
    pytest.raises(ValueError, ica.plot_properties, epochs, axes='not axes')
    plt.close('all')

    # Test merging grads.
    # fit on the first ten gradiometers of a fresh copy of the data
    raw = _get_raw(preload=True)
    picks = pick_types(raw.info, meg='grad')[:10]
    ica = ICA(n_components=2)
    ica.fit(raw, picks=picks)
    ica.plot_properties(raw)
    plt.close('all')
Beispiel #50
0
def test_plot_ica_components():
    """Test plotting of ICA solutions.

    Coverage of ``ICA.plot_components``:

    - Several ``picks`` forms.
    - Interactive mode: clicking a title toggles the component in
      ``ica.exclude``; clicking a topomap opens a properties figure.
    - The error raised once the ICA's ``info`` is cleared.
    """
    res = 8
    # low-resolution topomaps without contours/sensors
    fast_test = {"res": res, "contours": 0, "sensors": False}
    raw = _get_raw()
    ica = ICA(noise_cov=read_cov(cov_fname), n_components=2)
    ica_picks = _get_picks(raw)
    with pytest.warns(RuntimeWarning, match='projection'):
        ica.fit(raw, picks=ica_picks)

    # single int, list, pair, repeated pair and "all components" (None)
    for components in [0, [0], [0, 1], [0, 1] * 2, None]:
        ica.plot_components(components,
                            image_interp='bilinear',
                            colorbar=True,
                            **fast_test)
    plt.close('all')

    # test interactive mode (passing 'inst' arg)
    with catch_logging() as log:
        ica.plot_components([0, 1],
                            image_interp='bilinear',
                            inst=raw,
                            res=16,
                            verbose='debug',
                            ch_type='grad')
    log = log.getvalue()
    assert 'grad data' in log
    assert 'Interpolation mode local to mean' in log
    fig = plt.gcf()

    # test title click
    # ----------------
    # assumes the axes label ends in a 3-digit component index
    lbl = fig.axes[1].get_label()
    ica_idx = int(lbl[-3:])
    titles = [ax.title for ax in fig.axes]
    # pixel centre of the second title's bounding box
    title_pos_midpoint = (titles[1].get_window_extent().extents.reshape(
        (2, 2)).mean(axis=0))
    # first click adds to exclude
    _fake_click(fig, fig.axes[1], title_pos_midpoint, xform='pix')
    assert ica_idx in ica.exclude
    # clicking again removes from exclude
    _fake_click(fig, fig.axes[1], title_pos_midpoint, xform='pix')
    assert ica_idx not in ica.exclude

    # test topo click
    # ---------------
    _fake_click(fig, fig.axes[1], (0., 0.), xform='data')

    # the click must open a properties figure with all five panels
    c_fig = plt.gcf()
    labels = [ax.get_label() for ax in c_fig.axes]

    for label in ['topomap', 'image', 'erp', 'spectrum', 'variance']:
        assert label in labels

    topomap_ax = c_fig.axes[labels.index('topomap')]
    title = topomap_ax.get_title()
    assert (lbl == title)

    # without measurement info the ICA is treated as unfitted
    ica.info = None
    with pytest.raises(RuntimeError, match='fit the ICA'):
        ica.plot_components(1, ch_type='mag')
Beispiel #51
0
    print("Finding events...")
    events = find_events(raw)  # the output of this is a 3 x n_trial np array

    print("Epoching data...")
    epochs = Epochs(raw, events, tmin=tmin, tmax=tmax, decim=5, baseline=None)

    print("Creating ICA object...")
    # apply ICA to the conjoint data
    picks = pick_types(raw.info, meg=True, exclude='bads')
    ica = ICA(n_components=0.9, method='fastica', max_iter=300)

    print("Fitting epochs...")
    # get ica components
    ica.exclude = []
    ica.fit(epochs, picks=picks)

    print("Saving ICA solution...")
    ica.save(ica_fname)  # save solution

# ---------------------------------------------------------------------------
# Subject-specific file locations and filter settings.
subject = 'A0307'

# All files live in the subject's folder and are prefixed with the subject ID.
meg_dir = '/Users/ea84/Dropbox/shepard_sourceloc/%s/' % subject
raw_fname = ''.join([meg_dir, subject, '_shepard-raw.fif'])
ica_fname = ''.join([meg_dir, subject, '_shepard_ica1-ica.fif'])
# raw data with the ICA solution applied
ica_raw_fname = ''.join([meg_dir, subject, '_ica_shepard-raw.fif'])

# params (filter cut-offs; presumably low / high in Hz -- confirm)
filt_l = 1  # same as acquisition
filt_h = 60
Beispiel #52
0
                           exclude='bads')
# ICA settings
n_components = 25  # if float, select n_components by explained variance of PCA
method = 'fastica'  # for comparison with EEGLAB try "extended-infomax" here
decim = 3  # we need sufficient statistics, not all time points -> saves time
random_state = 23

ica = ICA(n_components=n_components,
          method=method,
          random_state=random_state,
          max_iter=200)
print(ica)

# peak-to-peak rejection thresholds used for the fit and the EOG epochs
reject = dict(
    grad=4000e-13,  # T / m (gradiometers)
    mag=4e-12)  # T (magnetometers)
ica.fit(raw_sss, picks=picks_meg, decim=decim, reject=reject)
print(ica)
raw_sss.plot_psd(fmax=100)

# averaged EOG response (with a stricter mag threshold); only used by the
# commented-out plot_sources call below
eog_average = create_eog_epochs(raw_sss,
                                reject=dict(mag=5e-12, grad=4000e-13),
                                picks=picks_meg).average()

# NOTE(review): n_max_eog is not used anywhere in this excerpt
n_max_eog = 1  # here we bet on finding the vertical EOG components
eog_epochs = create_eog_epochs(raw_sss, reject=reject)  # get single EOG trials
eog_inds, scores = ica.find_bads_eog(eog_epochs)  # find via correlation
# NOTE(review): eog_inds is never assigned to ica.exclude here and raw_sss is
# not modified, so this second plot_psd shows the same data as the first
print(ica)
raw_sss.plot_psd(fmax=100)

#ica.plot_scores(scores, exclude=eog_inds)
#ica.plot_sources(eog_average, exclude=eog_inds)  # look at source time course
def apply_ica_data(fname, raw=None, do_run=False, verbose=False, save=True,
                   fif_extention=".fif", fif_postfix="-ica", **kwargs):
    """Fit an MNE ICA decomposition to MEG data and optionally save it.

    Parameters
    ----------
    fname : str | None
        Path of the raw FIF file. It is read (``preload=True``) when ``raw``
        is None, and it is used to derive the output ICA file name: a
        trailing ``-raw.fif`` (or, failing that, the file extension) is
        replaced by ``fif_postfix + fif_extention``.
    raw : mne Raw | None
        Already-loaded raw data. If None, it is loaded from ``fname``.
    do_run : bool
        If False, nothing is computed and ``(None, raw, None)`` is returned.
    verbose : bool | str | int | None
        Forwarded as ``verbose`` to the ``ICA`` constructor and ``ICA.fit``.
    save : bool
        If True, save the fitted ICA object to the derived file name.
    fif_extention : str
        Extension of the output ICA file.
    fif_postfix : str
        Postfix of the output ICA file (before the extension).
    **kwargs
        ``global_parameter`` (dict, optional): keyword arguments for
        ``mne.preprocessing.ICA``. ``fit_parameter`` (dict, optional):
        keyword arguments for ``ICA.fit``. The caller's dicts are copied,
        not modified.

    Returns
    -------
    fnica_out : str | None
        File name of the ICA object, or None if no ``fname`` was given or
        ``do_run`` is False.
    raw : mne Raw | None
        The raw data that was fitted (or the ``raw`` argument as given).
    ICAobj : mne.preprocessing.ICA | None
        The fitted ICA object, or None if ``do_run`` is False.

    Returns None instead of a tuple when ``do_run`` is True but neither
    ``raw`` nor ``fname`` was given.
    """
    import os

    ICAobj = None
    fnica_out = None

    if not do_run:
        return (fnica_out, raw, ICAobj)

    if raw is None:
        if fname is None:
            print("ERROR no file foumd!!\n")
            return
        raw = mne.io.Raw(fname, preload=True)
        print("\n")

    from mne.preprocessing import ICA
    picks = jumeg_base.pick_meg_nobads(raw)

    # --- init MNE ICA obj (copy so the caller's dict is left untouched)
    global_parameter = dict(kwargs.get('global_parameter') or {})
    global_parameter['verbose'] = verbose
    ICAobj = ICA(**global_parameter)

    # --- run mne ica
    fit_parameter = dict(kwargs.get('fit_parameter') or {})
    fit_parameter['verbose'] = verbose
    ICAobj.fit(raw, picks=picks, **fit_parameter)

    # --- derive output name; rfind() returns -1 when '-raw.fif' is absent,
    # which must not be used as a slice end (it would chop one character)
    if fname is not None:
        idx = fname.rfind('-raw.fif')
        stem = fname[:idx] if idx != -1 else os.path.splitext(fname)[0]
        fnica_out = stem + fif_postfix + fif_extention

        # --- save ICA object
        if save:
            ICAobj.save(fnica_out)
    elif save:
        print("WARNING: no fname given, ICA object is not saved\n")

    print("===> Done JuMEG MNE ICA : " + str(fnica_out))
    print("\n")

    return (fnica_out, raw, ICAobj)
Beispiel #54
0
def compute_ica(fif_file, ecg_ch_name, eog_ch_name, n_components, reject):
    """Compute an ICA solution and remove ECG/EOG-related components.

    The raw file is decomposed with FastICA on the MEG channels (reference
    and bad channels excluded). Up to 3 ECG-related components (CTPS method)
    and, when all requested EOG channels exist, up to 2 EOG-related
    components are added to ``ica.exclude``; the solution is then applied to
    the data and all products are saved.

    Parameters
    ----------
    fif_file : str
        Path of the raw FIF file to process.
    ecg_ch_name : str | None
        Name of the ECG channel. If it is not present in the recording,
        ``None`` is used when creating the ECG epochs.
    eog_ch_name : str | None
        Comma-separated EOG channel name(s); spaces and empty entries are
        ignored. EOG-based detection is skipped if this is ``None``/empty or
        if any listed channel is missing from the recording.
    n_components : int | float | None
        Forwarded to ``mne.preprocessing.ICA``.
    reject : dict | None
        Rejection thresholds forwarded to ``ICA.fit``.

    Returns
    -------
    raw_ica_file : str
        Absolute path of the ICA-cleaned raw file.
    ica_sol_file : str
        Absolute path of the saved ICA solution.
    ica_ts_file : str
        Absolute path of the saved ICA source time series.
    report_file : str
        Absolute path of the report returned by ``generate_report``.
    """

    import os

    import mne
    from mne.io import read_raw_fif
    from mne.preprocessing import ICA
    from mne.preprocessing import create_ecg_epochs, create_eog_epochs

    from nipype.utils.filemanip import split_filename as split_f

    subj_path, basename, ext = split_f(fif_file)

    raw = read_raw_fif(fif_file, preload=True)

    # select MEG sensors (no reference channels, no bad channels)
    select_sensors = mne.pick_types(raw.info, meg=True,
                                    ref_meg=False, exclude='bads')

    # 1) Fit ICA model using the FastICA algorithm
    # Other available choices are `infomax` or `extended-infomax`
    # A float n_components between 0 and 1 selects the number of components
    # based on the percentage of variance explained by the PCA components.

    # reject = dict(mag=1e-1, grad=1e-9)
    flat = dict(mag=1e-13, grad=1e-13)

    ica = ICA(n_components=n_components, method='fastica', max_iter=500)

    ica.fit(raw, picks=select_sensors, reject=reject, flat=flat)
    # -------------------- Save ica timeseries ---------------------------- #
    ica_ts_file = os.path.abspath(basename + "_ica-tseries.fif")
    ica_src = ica.get_sources(raw)
    ica_src.save(ica_ts_file)
    ica_src = None  # drop the reference to free memory early
    # --------------------------------------------------------------------- #

    # 2) identify bad components by analyzing latent sources.
    # generate ECG epochs use detection via phase statistics

    # maximum number of components excluded per artifact type
    n_max_ecg = 3
    n_max_eog = 2

    # check if ecg_ch_name is in the raw channels
    if ecg_ch_name in raw.info['ch_names']:
        raw.set_channel_types({ecg_ch_name: 'ecg'})
    else:
        ecg_ch_name = None

    ecg_epochs = create_ecg_epochs(raw, tmin=-0.5, tmax=0.5,
                                   picks=select_sensors,
                                   ch_name=ecg_ch_name)

    ecg_inds, ecg_scores = ica.find_bads_ecg(ecg_epochs, method='ctps')

    ecg_evoked = ecg_epochs.average()
    ecg_epochs = None  # free memory

    ecg_inds = ecg_inds[:n_max_ecg]
    ica.exclude += ecg_inds

    # Normalise the comma-separated EOG channel list. A None/empty value or
    # empty entries (e.g. a trailing comma) must not crash or block detection.
    if eog_ch_name:
        eog_ch_name = eog_ch_name.replace(' ', '')
    eog_names = [name for name in (eog_ch_name or '').split(',') if name]

    if eog_names and set(eog_names).issubset(raw.info['ch_names']):
        print('*** EOG CHANNELS FOUND ***')
        eog_ch_name = ','.join(eog_names)
        eog_inds, eog_scores = ica.find_bads_eog(raw, ch_name=eog_ch_name)
        eog_inds = eog_inds[:n_max_eog]
        ica.exclude += eog_inds
        eog_evoked = create_eog_epochs(raw, tmin=-0.5, tmax=0.5,
                                       picks=select_sensors,
                                       ch_name=eog_ch_name).average()
    else:
        print('*** NO EOG CHANNELS FOUND!!! ***')
        eog_inds = eog_scores = eog_evoked = None

    report_file = generate_report(raw=raw, ica=ica, subj_name=fif_file,
                                  basename=basename,
                                  ecg_evoked=ecg_evoked, ecg_scores=ecg_scores,
                                  ecg_inds=ecg_inds, ecg_ch_name=ecg_ch_name,
                                  eog_evoked=eog_evoked, eog_scores=eog_scores,
                                  eog_inds=eog_inds, eog_ch_name=eog_ch_name)

    report_file = os.path.abspath(report_file)

    ica_sol_file = os.path.abspath(basename + '_ica_solution.fif')

    ica.save(ica_sol_file)
    raw_ica = ica.apply(raw)
    raw_ica_file = os.path.abspath(basename + '_ica' + ext)
    raw_ica.save(raw_ica_file)

    return raw_ica_file, ica_sol_file, ica_ts_file, report_file
Beispiel #55
0
# Plot the clusters

# smica.plot_extended(sort=False)

# SMICA: remove the sources listed in ``bad_sources`` with a Wiener filter and
# store the cleaned data of the picked channels in a copy of ``raw``.
bad_sources = [17, 18, 19]

X_filtered = smica.filter(
    raw._data[picks], bad_sources=bad_sources, method="wiener"
)
raw_filtered = raw.copy()
raw_filtered._data[picks] = X_filtered

# NOTE(review): ``raw`` is filtered here, *after* ``raw_filtered`` was copied,
# so the SMICA-cleaned copy is unfiltered while the MNE ICA below is fitted on
# 1-70 Hz data -- confirm this asymmetry is intended.
raw.filter(1, 70)
ica = ICA_mne(n_components=20, method="fastica", random_state=0)
ica.fit(raw, picks=picks)

# Compute the ICA source time courses once and reuse them, instead of
# unmixing the whole recording a second time for ``transfer_to_ica``.
sources = ica.get_sources(raw).get_data()
ica_mne = transfer_to_ica(
    raw, picks, freqs, sources, ica.get_components()
)

# ica_mne.plot_extended(sources, sort=False)
# NOTE(review): despite its name, this list refers to the FastICA fit above.
bads_infomax = [0, 1, 2]
X_ifmx = ica_mne.filter(
    raw._data[picks], bad_sources=bads_infomax, method="pinv"
)
raw_ifmx = raw.copy()
raw_ifmx._data[picks] = X_ifmx
# We identify that clusters 6, 7, 8, 9 correspond to noise
    # --- 2) ICA DECOMPOSITION --------------------------------
    # ICA parameters
    n_components = 25
    method = 'picard'
    fit_params = dict(extended=True, ortho=False)
    # decim = None
    reject = dict(eeg=300e-6)

    # Pick electrodes to use
    picks = pick_types(raw.info, meg=False, eeg=True, eog=False, stim=False)

    # ICA parameters
    ica = ICA(n_components=n_components, method=method, fit_params=fit_params)

    # Fit ICA
    ica.fit(raw.copy().filter(1.0, 40.0), picks=picks, reject=reject)

    # -- 3) save solution -------------------------------------
    # create directory for save
    if not op.exists(op.join(output_path, 'sub-%s' % subj)):
        mkdir(op.join(output_path, 'sub-%s' % subj))

    # save file
    ica.save(op.join(output_path, 'sub-%s' % subj, 'sub-%s-ica.fif' % subj))

    # --- 3) PLOT RESULTING COMPONENTS ------------------------
    # Plot components
    ica_fig = ica.plot_components(picks=range(0, 25), show=True)
    ica_fig.savefig(op.join(output_path, 'sub-%s' % subj, '%s_ica.pdf' % subj))
# Next, you have to specify the ICA algorithm you wish to use. Without going any deeper into differences between
# algorithms, I recommend the 'extended-infomax' algorithm implemented in MNE for (in my experience) very robust results.
method = 'extended-infomax'

# Finally, you can pick a decimation rate for your ICA. Only every nth time sample of the data is used, which saves
# computation time at the cost of precision; 'decim' is that increment n. Please be careful: higher increments save time
# but decrease the accuracy of the derived statistics.
decim = 3

# Additionally, you can specify data rejection parameters with the 'reject' argument to avoid the distortion
# of ICA components by large artifacts (e.g., upper body movement, slight head shaking). For EEG, reject has to be
# something along the lines of dict(eeg=100e-6), depending on your amplitude threshold for distortion.
reject = None

ica = ICA(n_components=n_components, method=method)
ica.fit(raw, picks=picks, decim=decim, reject=reject)

# You can now plot ICA component topographies to get an idea of the decomposition.
ica.plot_components()
# The topographies do not show the components' distinct properties. To plot the power spectrum and other
# characteristics of components, use the command below with picks as a list of component numbers
# (the empty list is only a placeholder - fill in the components you want to inspect).
ica.plot_properties(raw, picks=[])

# In order to remove components, you have to specify component numbers when back-projecting the decomposition
# to a continuous raw signal. The specified components' signal contribution is then excluded from the data
# (the empty list is only a placeholder - fill in the components to remove).
ica.apply(raw, exclude=[])


# For the second part, the type of loop you build to read and save the data depends on your creativity and on what the
# naming convention for your data sets looks like. Here are two examples that work fine for me.
def run_ica(subject, tsss=config.mf_st_duration):
    """Fit and save separate ICA decompositions for MEG and EEG channels.

    All runs in ``config.runs`` are loaded for ``subject``, marked with the
    subject's bad channels, concatenated, high-pass filtered at 1 Hz and
    epoched. ICA is then fitted separately on the MEG and on the EEG channels
    and each solution is saved as
    ``<subject>_<study_name>_<ch_type>-ica.fif`` in the subject's MEG
    directory. When ``config.plot`` is true, the properties of every
    component are also written to a ``.html`` report next to it.

    Parameters
    ----------
    subject : str
        Subject identifier, used to build paths and to look up
        ``config.bads``.
    tsss : float | None
        Not used by this function; kept for interface compatibility.
    """
    print("Processing subject: %s" % subject)

    meg_subject_dir = op.join(config.meg_dir, subject)

    raw_list = list()
    events_list = list()
    print("  Loading raw data")

    for run in config.runs:
        # ``extension`` (like ``subject`` and ``run``) is consumed by
        # ``config.base_fname.format(**locals())`` below.
        extension = run + '_sss_raw'
        raw_fname_in = op.join(meg_subject_dir,
                               config.base_fname.format(**locals()))
        eve_fname = op.splitext(raw_fname_in)[0] + '-eve.fif'
        print("Input: ", raw_fname_in, eve_fname)

        raw = mne.io.read_raw_fif(raw_fname_in, preload=True)

        events = mne.read_events(eve_fname)
        events_list.append(events)

        # XXX mark bads from any run – is it a problem for ICA
        # if we just exclude the bads shared by all runs ?
        # ``info['bads']`` always receives a fresh list: sorted instead of an
        # unordered set, and a copy rather than an alias of the config entry
        # (so nothing done to ``raw.info`` can leak back into ``config``).
        if run:
            bads = sorted(set(chain(*config.bads[subject].values())))
        else:
            bads = list(config.bads[subject])

        raw.info['bads'] = bads
        print("added bads: ", raw.info['bads'])

        raw_list.append(raw)

    print('  Concatenating runs')
    raw, events = mne.concatenate_raws(raw_list, events_list=events_list)
    raw.set_eeg_reference(projection=True)
    del raw_list

    # produce high-pass filtered version of the data for ICA
    epochs_for_ica = mne.Epochs(raw.copy().filter(l_freq=1., h_freq=None),
                                events,
                                config.event_id,
                                config.tmin,
                                config.tmax,
                                proj=True,
                                baseline=config.baseline,
                                preload=True,
                                decim=config.decim,
                                reject=config.reject)

    # run ICA on MEG and EEG
    picks_meg = mne.pick_types(epochs_for_ica.info,
                               meg=True,
                               eeg=False,
                               eog=False,
                               stim=False,
                               exclude='bads')
    picks_eeg = mne.pick_types(epochs_for_ica.info,
                               meg=False,
                               eeg=True,
                               eog=False,
                               stim=False,
                               exclude='bads')
    all_picks = {'meg': picks_meg, 'eeg': picks_eeg}

    # get number of components for ICA
    # compute_rank requires 0.18
    # n_components_meg = (mne.compute_rank(epochs_for_ica.copy()
    #                        .pick_types(meg=True)))['meg']

    n_components_meg = 0.999

    n_components = {'meg': n_components_meg, 'eeg': 0.999}

    for ch_type in ['meg', 'eeg']:
        print('Running ICA for ' + ch_type)

        ica = ICA(method='fastica',
                  random_state=config.random_state,
                  n_components=n_components[ch_type])

        picks = all_picks[ch_type]

        # NOTE(review): ``decim`` is not defined in this function, so it must
        # come from module scope -- confirm. The epochs were already
        # decimated by ``config.decim`` above, so this decimates again.
        ica.fit(epochs_for_ica, picks=picks, decim=decim)

        print(
            '  Fit %d components (explaining at least %0.1f%% of the variance)'
            % (ica.n_components_, 100 * n_components[ch_type]))

        ica_name = op.join(
            meg_subject_dir,
            '{0}_{1}_{2}-ica.fif'.format(subject, config.study_name, ch_type))
        ica.save(ica_name)

        if config.plot:
            # plot ICA components to html report
            import matplotlib.pyplot as plt
            from mne.report import Report
            report_name = op.join(
                meg_subject_dir,
                '{0}_{1}_{2}-ica.html'.format(subject, config.study_name,
                                              ch_type))
            report = Report(report_name, verbose=False)
            caption = ch_type.upper() + ' - ICA Components'

            # Plot one component at a time and close each figure once it has
            # been added to the report. Requesting all components at once
            # opened every figure simultaneously (> 20 open figures).
            for comp_idx in range(ica.n_components_):
                for figure in ica.plot_properties(epochs_for_ica,
                                                  picks=[comp_idx],
                                                  psd_args={'fmax': 60},
                                                  show=False):
                    report.add_figs_to_section(figure,
                                               section=subject,
                                               captions=caption)
                    plt.close(figure)

            report.save(report_name, overwrite=True, open_browser=False)
Beispiel #59
0
def test_ica_additional():
    """Test additional ICA functionality.

    Covers fitting on raw/epochs, corrmap, file naming warnings, decimation,
    explained-variance component selection, saving/reading round trips,
    scoring/artifact detection and export of sources to raw/epochs.
    All files are written inside the test's temporary directory.
    """
    tempdir = _TempDir()
    stop2 = 500
    raw = Raw(raw_fname).crop(1.5, stop, False)
    raw.load_data()
    picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
                       eog=False, exclude='bads')
    test_cov = read_cov(test_cov_name)
    events = read_events(event_name)
    epochs = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), preload=True)
    # test if n_components=None works
    with warnings.catch_warnings(record=True):
        ica = ICA(n_components=None,
                  max_pca_components=None,
                  n_pca_components=None, random_state=0)
        ica.fit(epochs, picks=picks, decim=3)
    # for testing eog functionality
    picks2 = pick_types(raw.info, meg=True, stim=False, ecg=False,
                        eog=True, exclude='bads')
    epochs_eog = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks2,
                        baseline=(None, 0), preload=True)

    test_cov2 = test_cov.copy()
    ica = ICA(noise_cov=test_cov2, n_components=3, max_pca_components=4,
              n_pca_components=4)
    assert_true(ica.info is None)
    with warnings.catch_warnings(record=True):
        ica.fit(raw, picks[:5])
    assert_true(isinstance(ica.info, Info))
    assert_true(ica.n_components_ < 5)

    ica = ICA(n_components=3, max_pca_components=4,
              n_pca_components=4)
    assert_raises(RuntimeError, ica.save, '')
    with warnings.catch_warnings(record=True):
        ica.fit(raw, picks=[1, 2, 3, 4, 5], start=start, stop=stop2)

    # test corrmap
    ica2 = ica.copy()
    corrmap([ica, ica2], (0, 0), threshold='auto', label='blinks', plot=True,
            ch_type="mag")
    corrmap([ica, ica2], (0, 0), threshold=2, plot=False, show=False)
    assert_true(ica.labels_["blinks"] == ica2.labels_["blinks"])
    assert_true(0 in ica.labels_["blinks"])
    plt.close('all')

    # test warnings on bad filenames
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        ica_badname = op.join(tempdir, 'test-bad-name.fif.gz')
        ica.save(ica_badname)
        read_ica(ica_badname)
    assert_naming(w, 'test_ica.py', 2)

    # test decim: fitting with decimation must not alter the fitted data
    ica = ICA(n_components=3, max_pca_components=4,
              n_pca_components=4)
    raw_ = raw.copy()
    for _ in range(3):
        raw_.append(raw_)
    n_samples = raw_._data.shape[1]
    with warnings.catch_warnings(record=True):
        ica.fit(raw_, picks=None, decim=3)
    assert_equal(raw_._data.shape[1], n_samples)

    # test expl var
    ica = ICA(n_components=1.0, max_pca_components=4,
              n_pca_components=4)
    with warnings.catch_warnings(record=True):
        ica.fit(raw, picks=None, decim=3)
    assert_true(ica.n_components_ == 4)

    # epochs extraction from raw fit
    assert_raises(RuntimeError, ica.get_sources, epochs)
    # test reading and writing
    test_ica_fname = op.join(tempdir, 'test-ica.fif')
    for cov in (None, test_cov):
        ica = ICA(noise_cov=cov, n_components=2, max_pca_components=4,
                  n_pca_components=4)
        with warnings.catch_warnings(record=True):  # ICA does not converge
            ica.fit(raw, picks=picks, start=start, stop=stop2)
        sources = ica.get_sources(epochs).get_data()
        assert_true(ica.mixing_matrix_.shape == (2, 2))
        assert_true(ica.unmixing_matrix_.shape == (2, 2))
        assert_true(ica.pca_components_.shape == (4, len(picks)))
        assert_true(sources.shape[1] == ica.n_components_)

        # round-trip both an empty and a non-empty exclude list
        for exclude in [[], [0]]:
            ica.exclude = exclude
            ica.labels_ = {'foo': [0]}
            ica.save(test_ica_fname)
            ica_read = read_ica(test_ica_fname)
            assert_true(ica.exclude == ica_read.exclude)
            assert_equal(ica.labels_, ica_read.labels_)
            ica.exclude = []
            ica.apply(raw, exclude=[1])
            assert_true(ica.exclude == [])

            ica.exclude = [0, 1]
            ica.apply(raw, exclude=[1])
            assert_true(ica.exclude == [0, 1])

            ica_raw = ica.get_sources(raw)
            assert_true(ica.exclude == [ica_raw.ch_names.index(e) for e in
                                        ica_raw.info['bads']])

        # test filtering
        d1 = ica_raw._data[0].copy()
        with warnings.catch_warnings(record=True):  # dB warning
            ica_raw.filter(4, 20)
        assert_true((d1 != ica_raw._data[0]).any())
        d1 = ica_raw._data[0].copy()
        with warnings.catch_warnings(record=True):  # dB warning
            ica_raw.notch_filter([10])
        assert_true((d1 != ica_raw._data[0]).any())

        ica.n_pca_components = 2
        ica.save(test_ica_fname)
        ica_read = read_ica(test_ica_fname)
        assert_true(ica.n_pca_components == ica_read.n_pca_components)

        # check type consistency
        attrs = ('mixing_matrix_ unmixing_matrix_ pca_components_ '
                 'pca_explained_variance_ _pre_whitener')

        def f(x, y):
            return getattr(x, y).dtype

        for attr in attrs.split():
            assert_equal(f(ica_read, attr), f(ica, attr))

        ica.n_pca_components = 4
        ica_read.n_pca_components = 4

        ica.exclude = []
        ica.save(test_ica_fname)
        ica_read = read_ica(test_ica_fname)
        for attr in ['mixing_matrix_', 'unmixing_matrix_', 'pca_components_',
                     'pca_mean_', 'pca_explained_variance_',
                     '_pre_whitener']:
            assert_array_almost_equal(getattr(ica, attr),
                                      getattr(ica_read, attr))

        assert_true(ica.ch_names == ica_read.ch_names)
        assert_true(isinstance(ica_read.info, Info))

        sources = ica.get_sources(raw)[:, :][0]
        sources2 = ica_read.get_sources(raw)[:, :][0]
        assert_array_almost_equal(sources, sources2)

        _raw1 = ica.apply(raw, exclude=[1])
        _raw2 = ica_read.apply(raw, exclude=[1])
        assert_array_almost_equal(_raw1[:, :][0], _raw2[:, :][0])

    os.remove(test_ica_fname)
    # check score funcs
    for name, func in get_score_funcs().items():
        if name in score_funcs_unsuited:
            continue
        scores = ica.score_sources(raw, target='EOG 061', score_func=func,
                                   start=0, stop=10)
        assert_true(ica.n_components_ == len(scores))

    # check univariate stats
    scores = ica.score_sources(raw, score_func=stats.skew)
    # check exception handling
    assert_raises(ValueError, ica.score_sources, raw,
                  target=np.arange(1))

    params = []
    params += [(None, -1, slice(2), [0, 1])]  # variance, kurtosis idx params
    params += [(None, 'MEG 1531')]  # ECG / EOG channel params
    for idx, ch_name in product(*params):
        ica.detect_artifacts(raw, start_find=0, stop_find=50, ecg_ch=ch_name,
                             eog_ch=ch_name, skew_criterion=idx,
                             var_criterion=idx, kurt_criterion=idx)
    with warnings.catch_warnings(record=True):
        idx, scores = ica.find_bads_ecg(raw, method='ctps')
        assert_equal(len(scores), ica.n_components_)
        idx, scores = ica.find_bads_ecg(raw, method='correlation')
        assert_equal(len(scores), ica.n_components_)
        idx, scores = ica.find_bads_ecg(epochs, method='ctps')
        assert_equal(len(scores), ica.n_components_)
        assert_raises(ValueError, ica.find_bads_ecg, epochs.average(),
                      method='ctps')
        assert_raises(ValueError, ica.find_bads_ecg, raw,
                      method='crazy-coupling')

        idx, scores = ica.find_bads_eog(raw)
        assert_equal(len(scores), ica.n_components_)
        # turn the channel preceding 'EOG 061' into a second EOG channel
        raw.info['chs'][raw.ch_names.index('EOG 061') - 1]['kind'] = 202
        idx, scores = ica.find_bads_eog(raw)
        assert_true(isinstance(scores, list))
        assert_equal(len(scores[0]), ica.n_components_)

    # check score funcs
    for name, func in get_score_funcs().items():
        if name in score_funcs_unsuited:
            continue
        scores = ica.score_sources(epochs_eog, target='EOG 061',
                                   score_func=func)
        assert_true(ica.n_components_ == len(scores))

    # check univariate stats
    scores = ica.score_sources(epochs, score_func=stats.skew)

    # check exception handling
    assert_raises(ValueError, ica.score_sources, epochs,
                  target=np.arange(1))

    # ecg functionality (``sources`` comes from the last read/write loop)
    ecg_scores = ica.score_sources(raw, target='MEG 1531',
                                   score_func='pearsonr')

    with warnings.catch_warnings(record=True):  # filter attenuation warning
        ecg_events = ica_find_ecg_events(raw,
                                         sources[np.abs(ecg_scores).argmax()])

    assert_true(ecg_events.ndim == 2)

    # eog functionality
    eog_scores = ica.score_sources(raw, target='EOG 061',
                                   score_func='pearsonr')
    with warnings.catch_warnings(record=True):  # filter attenuation warning
        eog_events = ica_find_eog_events(raw,
                                         sources[np.abs(eog_scores).argmax()])

    assert_true(eog_events.ndim == 2)

    # Test ica fiff export
    ica_raw = ica.get_sources(raw, start=0, stop=100)
    assert_true(ica_raw.last_samp - ica_raw.first_samp == 100)
    assert_true(len(ica_raw._filenames) == 0)  # API consistency
    ica_chans = [ch for ch in ica_raw.ch_names if 'ICA' in ch]
    assert_true(ica.n_components_ == len(ica_chans))
    test_ica_fname = op.join(tempdir, 'test-ica_raw.fif')
    ica.n_components = np.int32(ica.n_components)
    ica_raw.save(test_ica_fname, overwrite=True)
    ica_raw2 = Raw(test_ica_fname, preload=True)
    assert_allclose(ica_raw._data, ica_raw2._data, rtol=1e-5, atol=1e-4)
    ica_raw2.close()
    os.remove(test_ica_fname)

    # Test ica epochs export
    ica_epochs = ica.get_sources(epochs)
    assert_true(ica_epochs.events.shape == epochs.events.shape)
    ica_chans = [ch for ch in ica_epochs.ch_names if 'ICA' in ch]
    assert_true(ica.n_components_ == len(ica_chans))
    assert_true(ica.n_components_ == ica_epochs.get_data().shape[1])
    assert_true(ica_epochs._raw is None)
    assert_true(ica_epochs.preload is True)

    # test float n pca components
    ica.pca_explained_variance_ = np.array([0.2] * 5)
    ica.n_components_ = 0
    for ncomps, expected in [[0.3, 1], [0.9, 4], [1, 1]]:
        ncomps_ = ica._check_n_pca_components(ncomps)
        assert_true(ncomps_ == expected)
Beispiel #60
0
def main():

    #################################################
    ## SETUP

    ## Get list of subject files
    subj_files = listdir(DAT_PATH)
    subj_files = [file for file in subj_files if EXT.lower() in file.lower()]

    ## Set up FOOOF Objects
    # Initialize FOOOF settings & objects objects
    fooof_settings = FOOOFSettings(peak_width_limits=PEAK_WIDTH_LIMITS, max_n_peaks=MAX_N_PEAKS,
                                   min_peak_amplitude=MIN_PEAK_AMP, peak_threshold=PEAK_THRESHOLD,
                                   aperiodic_mode=APERIODIC_MODE)
    fm = FOOOF(*fooof_settings, verbose=False)
    fg = FOOOFGroup(*fooof_settings, verbose=False)

    # Save out a settings file
    fg.save('0-FOOOF_Settings', pjoin(RES_PATH, 'FOOOF'), save_settings=True)

    # Set up the dictionary to store all the FOOOF results
    fg_dict = dict()
    for load_label in LOAD_LABELS:
        fg_dict[load_label] = dict()
        for side_label in SIDE_LABELS:
            fg_dict[load_label][side_label] = dict()
            for seg_label in SEG_LABELS:
                fg_dict[load_label][side_label][seg_label] = []

    ## Initialize group level data stores
    n_subjs, n_conds, n_times = len(subj_files), 3, N_TIMES
    group_fooofed_alpha_freqs = np.zeros(shape=[n_subjs])
    dropped_components = np.ones(shape=[n_subjs, 50]) * 999
    dropped_trials = np.ones(shape=[n_subjs, 1500]) * 999
    canonical_group_avg_dat = np.zeros(shape=[n_subjs, n_conds, n_times])
    fooofed_group_avg_dat = np.zeros(shape=[n_subjs, n_conds, n_times])

    # Set channel types
    ch_types = {'LHor' : 'eog', 'RHor' : 'eog', 'IVer' : 'eog', 'SVer' : 'eog',
                'LMas' : 'misc', 'RMas' : 'misc', 'Nose' : 'misc', 'EXG8' : 'misc'}

    #################################################
    ## RUN ACROSS ALL SUBJECTS

    # Run analysis across each subject
    for s_ind, subj_file in enumerate(subj_files):

        # Get subject label and print status
        subj_label = subj_file.split('.')[0]
        print('\nCURRENTLY RUNNING SUBJECT: ', subj_label, '\n')

        #################################################
        ## LOAD / ORGANIZE / SET-UP DATA

        # Load subject of data, apply apply fixes for channels, etc
        eeg_dat = mne.io.read_raw_edf(pjoin(DAT_PATH, subj_file),
                                      preload=True, verbose=False)

        # Fix channel name labels
        eeg_dat.info['ch_names'] = [chl[2:] for chl in \
            eeg_dat.ch_names[:-1]] + [eeg_dat.ch_names[-1]]
        for ind, chi in enumerate(eeg_dat.info['chs']):
            eeg_dat.info['chs'][ind]['ch_name'] = eeg_dat.info['ch_names'][ind]

        # Update channel types
        eeg_dat.set_channel_types(ch_types)

        # Set reference - average reference
        eeg_dat = eeg_dat.set_eeg_reference(ref_channels='average',
                                            projection=False, verbose=False)

        # Set channel montage
        chs = mne.channels.read_montage('standard_1020', eeg_dat.ch_names)
        eeg_dat.set_montage(chs)

        # Get event information & check all used event codes
        evs = mne.find_events(eeg_dat, shortest_event=1, verbose=False)

        # Pull out sampling rate
        srate = eeg_dat.info['sfreq']

        #################################################
        ## Pre-Processing: ICA

        # High-pass filter data for running ICA
        eeg_dat.filter(l_freq=1., h_freq=None, fir_design='firwin')

        if RUN_ICA:

            print("\nICA: CALCULATING SOLUTION\n")

            # ICA settings
            method = 'fastica'
            n_components = 0.99
            random_state = 47
            reject = {'eeg': 20e-4}

            # Initialize ICA object
            ica = ICA(n_components=n_components, method=method,
                      random_state=random_state)

            # Fit ICA
            ica.fit(eeg_dat, reject=reject)

            # Save out ICA solution
            ica.save(pjoin(RES_PATH, 'ICA', subj_label + '-ica.fif'))

        # Otherwise: load previously saved ICA to apply
        else:
            print("\nICA: USING PRECOMPUTED\n")
            ica = read_ica(pjoin(RES_PATH, 'ICA', subj_label + '-ica.fif'))

        # Find components to drop, based on correlation with EOG channels
        drop_inds = []
        for chi in EOG_CHS:
            inds, _ = ica.find_bads_eog(eeg_dat, ch_name=chi, threshold=2.5,
                                             l_freq=1, h_freq=10, verbose=False)
            drop_inds.extend(inds)
        drop_inds = list(set(drop_inds))

        # Set which components to drop, and collect record of this
        ica.exclude = drop_inds
        dropped_components[s_ind, 0:len(drop_inds)] = drop_inds

        # Apply ICA to data
        eeg_dat = ica.apply(eeg_dat)

        #################################################
        ## SORT OUT EVENT CODES

        # Extract a list of all the event labels
        all_trials = [it for it2 in EV_DICT.values() for it in it2]

        # Create list of new event codes to be used to label correct trials (300s)
        all_trials_new = [it + 100 for it in all_trials]
        # This is an annoying way to collapse across the doubled event markers from above
        all_trials_new = [it - 1 if not ind%2 == 0 else it for ind, it in enumerate(all_trials_new)]
        # Get labelled dictionary of new event names
        ev_dict2 = {k:v for k, v in zip(EV_DICT.keys(), set(all_trials_new))}

        # Initialize variables to store new event definitions
        evs2 = np.empty(shape=[0, 3], dtype='int64')
        lags = np.array([])

        # Loop through, creating new events for all correct trials
        t_min, t_max = -0.4, 3.0
        for ref_id, targ_id, new_id in zip(all_trials, CORR_CODES * 6, all_trials_new):

            t_evs, t_lags = mne.event.define_target_events(evs, ref_id, targ_id, srate,
                                                           t_min, t_max, new_id)

            if len(t_evs) > 0:
                evs2 = np.vstack([evs2, t_evs])
                lags = np.concatenate([lags, t_lags])

        #################################################
        ## FOOOF

        # Set channel of interest
        ch_ind = eeg_dat.ch_names.index(CHL)

        # Calculate PSDs over ~ first 2 minutes of data, for specified channel
        fmin, fmax = 1, 50
        tmin, tmax = 5, 125
        psds, freqs = mne.time_frequency.psd_welch(eeg_dat, fmin=fmin, fmax=fmax,
                                                   tmin=tmin, tmax=tmax,
                                                   n_fft=int(2*srate), n_overlap=int(srate),
                                                   n_per_seg=int(2*srate),
                                                   verbose=False)

        # Fit FOOOF across all channels
        fg.fit(freqs, psds, FREQ_RANGE, n_jobs=-1)

        # Save out FOOOF results
        fg.save(subj_label + '_fooof', pjoin(RES_PATH, 'FOOOF'), save_results=True)

        # Extract individualized CF from specified channel, add to group collection
        fm = fg.get_fooof(ch_ind, False)
        fooof_freq, _, _ = get_band_peak(fm.peak_params_, [7, 14])
        group_fooofed_alpha_freqs[s_ind] = fooof_freq

        # If not FOOOF alpha extracted, reset to 10
        if np.isnan(fooof_freq):
            fooof_freq = 10

        #################################################
        ## ALPHA FILTERING

        # CANONICAL: Filter data to canonical alpha band: 8-12 Hz
        alpha_dat = eeg_dat.copy()
        alpha_dat.filter(8, 12, fir_design='firwin', verbose=False)
        alpha_dat.apply_hilbert(envelope=True, verbose=False)

        # FOOOF: Filter data to FOOOF derived alpha band
        fooof_dat = eeg_dat.copy()
        fooof_dat.filter(fooof_freq-2, fooof_freq+2, fir_design='firwin')
        fooof_dat.apply_hilbert(envelope=True)

        #################################################
        ## EPOCH TRIALS

        # Set epoch timings
        tmin, tmax = -0.85, 1.1

        # Epoch trials - raw data for trial rejection
        epochs = mne.Epochs(eeg_dat, evs2, ev_dict2, tmin=tmin, tmax=tmax,
                            baseline=None, preload=True, verbose=False)

        # Epoch trials - filtered version
        epochs_alpha = mne.Epochs(alpha_dat, evs2, ev_dict2, tmin=tmin, tmax=tmax,
                                  baseline=(-0.5, -0.35), preload=True, verbose=False)
        epochs_fooof = mne.Epochs(fooof_dat, evs2, ev_dict2, tmin=tmin, tmax=tmax,
                                  baseline=(-0.5, -0.35), preload=True, verbose=False)

        #################################################
        ## PRE-PROCESSING: AUTO-REJECT
        if RUN_AUTOREJECT:

            print('\nAUTOREJECT: CALCULATING SOLUTION\n')

            # Initialize and run autoreject across epochs
            ar = AutoReject(n_jobs=4, verbose=False)
            ar.fit(epochs)

            # Save out AR solution
            ar.save(pjoin(RES_PATH, 'AR', subj_label + '-ar.hdf5'), overwrite=True)

        # Otherwise: load & apply previously saved AR solution
        else:
            print('\nAUTOREJECT: USING PRECOMPUTED\n')
            ar = read_auto_reject(pjoin(RES_PATH, 'AR', subj_label + '-ar.hdf5'))
            ar.verbose = 'tqdm'

        # Apply autoreject to the original epochs object it was learnt on
        epochs, rej_log = ar.transform(epochs, return_log=True)

        # Apply autoreject to the copies of the data - apply interpolation, then drop same epochs
        _apply_interp(rej_log, epochs_alpha, ar.threshes_, ar.picks_, ar.verbose)
        epochs_alpha.drop(rej_log.bad_epochs)
        _apply_interp(rej_log, epochs_fooof, ar.threshes_, ar.picks_, ar.verbose)
        epochs_fooof.drop(rej_log.bad_epochs)

        # Collect which epochs were dropped
        dropped_trials[s_ind, 0:sum(rej_log.bad_epochs)] = np.where(rej_log.bad_epochs)[0]

        #################################################
        ## SET UP CHANNEL CLUSTERS

        # Set channel clusters - take channels contralateral to stimulus presentation
        #  Note: channels will be used to extract data contralateral to stimulus presentation
        le_chs = ['P3', 'P5', 'P7', 'P9', 'O1', 'PO3', 'PO7']       # Left Side Channels
        le_inds = [epochs.ch_names.index(chn) for chn in le_chs]
        ri_chs = ['P4', 'P6', 'P8', 'P10', 'O2', 'PO4', 'PO8']      # Right Side Channels
        ri_inds = [epochs.ch_names.index(chn) for chn in ri_chs]

        #################################################
        ## TRIAL-RELATED ANALYSIS: CANONICAL vs. FOOOF

        ## Pull out channels of interest for each load level
        #  Channels extracted are those contralateral to stimulus presentation

        # Canonical Data
        lo1_a = np.concatenate([epochs_alpha['LeLo1']._data[:, ri_inds, :],
                                epochs_alpha['RiLo1']._data[:, le_inds, :]], 0)
        lo2_a = np.concatenate([epochs_alpha['LeLo2']._data[:, ri_inds, :],
                                epochs_alpha['RiLo2']._data[:, le_inds, :]], 0)
        lo3_a = np.concatenate([epochs_alpha['LeLo3']._data[:, ri_inds, :],
                                epochs_alpha['RiLo3']._data[:, le_inds, :]], 0)

        # FOOOFed data
        lo1_f = np.concatenate([epochs_fooof['LeLo1']._data[:, ri_inds, :],
                                epochs_fooof['RiLo1']._data[:, le_inds, :]], 0)
        lo2_f = np.concatenate([epochs_fooof['LeLo2']._data[:, ri_inds, :],
                                epochs_fooof['RiLo2']._data[:, le_inds, :]], 0)
        lo3_f = np.concatenate([epochs_fooof['LeLo3']._data[:, ri_inds, :],
                                epochs_fooof['RiLo3']._data[:, le_inds, :]], 0)

        ## Calculate average across trials and channels - add to group data collection

        # Canonical data
        canonical_group_avg_dat[s_ind, 0, :] = np.mean(lo1_a, 1).mean(0)
        canonical_group_avg_dat[s_ind, 1, :] = np.mean(lo2_a, 1).mean(0)
        canonical_group_avg_dat[s_ind, 2, :] = np.mean(lo3_a, 1).mean(0)

        # FOOOFed data
        fooofed_group_avg_dat[s_ind, 0, :] = np.mean(lo1_f, 1).mean(0)
        fooofed_group_avg_dat[s_ind, 1, :] = np.mean(lo2_f, 1).mean(0)
        fooofed_group_avg_dat[s_ind, 2, :] = np.mean(lo3_f, 1).mean(0)

        #################################################
        ## FOOOFING TRIAL AVERAGED DATA

        # Loop loop loads & trials segments
        for seg_label, seg_time in zip(SEG_LABELS, SEG_TIMES):
            tmin, tmax = seg_time[0], seg_time[1]

            # Calculate PSDs across trials, fit FOOOF models to averages
            for le_label, ri_label, load_label in zip(['LeLo1', 'LeLo2', 'LeLo3'],
                                                      ['RiLo1', 'RiLo2', 'RiLo3'],
                                                      LOAD_LABELS):

                ## Calculate trial wise PSDs for left & right side trials
                trial_freqs, le_trial_psds = periodogram(
                    epochs[le_label]._data[:, :, _time_mask(epochs.times, tmin, tmax, srate)],
                    srate, window='hann', nfft=4*srate)
                trial_freqs, ri_trial_psds = periodogram(
                    epochs[ri_label]._data[:, :, _time_mask(epochs.times, tmin, tmax, srate)],
                    srate, window='hann', nfft=4*srate)

                ## FIT ALL CHANNELS VERSION
                if FIT_ALL_CHANNELS:

                    ## Average spectra across trials within a given load & side
                    le_avg_psd_contra = avg_func(le_trial_psds[:, ri_inds, :], 0)
                    le_avg_psd_ipsi = avg_func(le_trial_psds[:, le_inds, :], 0)
                    ri_avg_psd_contra = avg_func(ri_trial_psds[:, le_inds, :], 0)
                    ri_avg_psd_ipsi = avg_func(ri_trial_psds[:, ri_inds, :], 0)

                    ## Combine spectra across left & right trials for given load
                    ch_psd_contra = np.vstack([le_avg_psd_contra, ri_avg_psd_contra])
                    ch_psd_ipsi = np.vstack([le_avg_psd_ipsi, ri_avg_psd_ipsi])

                    ## Fit FOOOFGroup to all channels, average & and collect results
                    fg.fit(trial_freqs, ch_psd_contra, FREQ_RANGE)
                    fm = avg_fg(fg)
                    fg_dict[load_label]['Contra'][seg_label].append(fm.copy())
                    fg.fit(trial_freqs, ch_psd_ipsi, FREQ_RANGE)
                    fm = avg_fg(fg)
                    fg_dict[load_label]['Ipsi'][seg_label].append(fm.copy())

                ## COLLAPSE ACROSS CHANNELS VERSION
                else:

                    ## Average spectra across trials and channels within a given load & side
                    le_avg_psd_contra = avg_func(avg_func(le_trial_psds[:, ri_inds, :], 0), 0)
                    le_avg_psd_ipsi = avg_func(avg_func(le_trial_psds[:, le_inds, :], 0), 0)
                    ri_avg_psd_contra = avg_func(avg_func(ri_trial_psds[:, le_inds, :], 0), 0)
                    ri_avg_psd_ipsi = avg_func(avg_func(ri_trial_psds[:, ri_inds, :], 0), 0)

                    ## Collapse spectra across left & right trials for given load
                    avg_psd_contra = avg_func(np.vstack([le_avg_psd_contra, ri_avg_psd_contra]), 0)
                    avg_psd_ipsi = avg_func(np.vstack([le_avg_psd_ipsi, ri_avg_psd_ipsi]), 0)

                    ## Fit FOOOF, and collect results
                    fm.fit(trial_freqs, avg_psd_contra, FREQ_RANGE)
                    fg_dict[load_label]['Contra'][seg_label].append(fm.copy())
                    fm.fit(trial_freqs, avg_psd_ipsi, FREQ_RANGE)
                    fg_dict[load_label]['Ipsi'][seg_label].append(fm.copy())

    #################################################
    ## SAVE OUT RESULTS

    # Save out group data
    np.save(pjoin(RES_PATH, 'Group', 'alpha_freqs_group'), group_fooofed_alpha_freqs)
    np.save(pjoin(RES_PATH, 'Group', 'canonical_group'), canonical_group_avg_dat)
    np.save(pjoin(RES_PATH, 'Group', 'fooofed_group'), fooofed_group_avg_dat)
    np.save(pjoin(RES_PATH, 'Group', 'dropped_trials'), dropped_trials)
    np.save(pjoin(RES_PATH, 'Group', 'dropped_components'), dropped_components)

    # Save out second round of FOOOFing
    for load_label in LOAD_LABELS:
        for side_label in SIDE_LABELS:
            for seg_label in SEG_LABELS:
                fg = combine_fooofs(fg_dict[load_label][side_label][seg_label])
                fg.save('Group_' + load_label + '_' + side_label + '_' + seg_label,
                        pjoin(RES_PATH, 'FOOOF'), save_results=True)