예제 #1
0
def test_XdawnTransformer():
    """Test _XdawnTransformer.

    Exercises fitting (with and without labels, and with several kinds of
    ``signal_cov``), agreement of the spatial filters with ``Xdawn``, and
    the input validation of ``transform`` / ``inverse_transform``.
    """
    raw, events, picks = _get_data()
    raw.del_proj()
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    preload=True, baseline=None, verbose=False)
    data = epochs._data
    labels = epochs.events[:, -1]

    # A plain fit works; a label vector of the wrong length or a
    # non-array input is rejected.
    transformer = _XdawnTransformer()
    transformer.fit(data, labels)
    pytest.raises(ValueError, transformer.fit, data, labels[1:])
    pytest.raises(ValueError, transformer.fit, 'foo')

    # Accepted signal covariances: a Covariance object and an ndarray whose
    # size matches the picked channels.
    for cov in (compute_raw_covariance(raw, picks=picks),
                np.eye(len(picks))):
        _XdawnTransformer(signal_cov=cov).fit(data, labels)
    # Rejected signal covariances: an ndarray of the wrong size and an
    # unsupported type.
    for cov in (np.eye(len(picks) - 1), 42):
        bad = _XdawnTransformer(signal_cov=cov)
        pytest.raises(ValueError, bad.fit, data, labels)

    # Labels are optional.
    _XdawnTransformer().fit(data)

    # The filters must match those of Xdawn fitted on the same epochs.
    reference = Xdawn(correct_overlap=False)
    reference.fit(epochs)
    transformer = _XdawnTransformer()
    transformer.fit(data, labels)
    assert_array_almost_equal(reference.filters_['cond2'][:2, :],
                              transformer.filters_.reshape(2, 2, 8)[0])

    # transform(): the number of epochs and of time samples may differ
    # from the training data, the number of channels may not, and
    # non-arrays are rejected.
    transformer.transform(data[1:, ...])
    transformer.transform(data[:, :, 1:])
    pytest.raises(ValueError, transformer.transform, data[:, 1:, :])
    sources = transformer.transform(data)
    pytest.raises(ValueError, transformer.transform, 42)

    # inverse_transform(): restores the input shape; the number of epochs
    # and of time samples may differ, the number of components may not.
    restored = transformer.inverse_transform(sources)
    assert restored.shape == data.shape
    transformer.inverse_transform(sources[1:, ...])
    transformer.inverse_transform(sources[:, :, 1:])
    pytest.raises(ValueError, transformer.inverse_transform,
                  sources[:, 1:, :])
    pytest.raises(ValueError, transformer.inverse_transform, 42)
예제 #2
0
def test_XdawnTransformer():
    """Test _XdawnTransformer.

    Covers fitting (with and without labels, and with several kinds of
    ``signal_cov``), agreement of the filters with ``Xdawn``, and input
    validation in ``transform`` / ``inverse_transform``.
    """
    raw, events, picks = _get_data()
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    preload=True, baseline=None, verbose=False,
                    add_eeg_ref=False)
    data = epochs._data
    labels = epochs.events[:, -1]

    # A plain fit works; mismatched labels or a non-array input fail.
    transformer = _XdawnTransformer()
    transformer.fit(data, labels)
    assert_raises(ValueError, transformer.fit, data, labels[1:])
    assert_raises(ValueError, transformer.fit, 'foo')

    # Accepted signal covariances: a Covariance object and an ndarray whose
    # size matches the picked channels.
    for cov in (compute_raw_covariance(raw, picks=picks),
                np.eye(len(picks))):
        _XdawnTransformer(signal_cov=cov).fit(data, labels)
    # Rejected signal covariances: wrong-sized ndarray, unsupported type.
    for cov in (np.eye(len(picks) - 1), 42):
        bad = _XdawnTransformer(signal_cov=cov)
        assert_raises(ValueError, bad.fit, data, labels)

    # Labels are optional.
    _XdawnTransformer().fit(data)

    # The filters must match those of Xdawn fitted on the same epochs.
    reference = Xdawn(correct_overlap=False)
    reference.fit(epochs)
    transformer = _XdawnTransformer()
    transformer.fit(data, labels)
    assert_array_almost_equal(reference.filters_['cond2'][:, :2],
                              transformer.filters_.reshape(2, 2, 8)[0].T)

    # transform(): epochs and time samples may vary, channels may not,
    # and non-arrays are rejected.
    transformer.transform(data[1:, ...])
    transformer.transform(data[:, :, 1:])
    assert_raises(ValueError, transformer.transform, data[:, 1:, :])
    sources = transformer.transform(data)
    assert_raises(ValueError, transformer.transform, 42)

    # inverse_transform(): restores the input shape; epochs and time
    # samples may vary, the number of components may not.
    restored = transformer.inverse_transform(sources)
    assert_equal(restored.shape, data.shape)
    transformer.inverse_transform(sources[1:, ...])
    transformer.inverse_transform(sources[:, :, 1:])
    assert_raises(ValueError, transformer.inverse_transform,
                  sources[:, 1:, :])
    assert_raises(ValueError, transformer.inverse_transform, 42)
예제 #3
0
def test_xdawn_picks():
    """Test picking with Xdawn.

    On data with one EEG and one misc channel, ``apply`` must keep the
    channel names and return both channels unchanged.
    """
    rng = np.random.RandomState(0)
    data = rng.randn(10, 2, 10)
    epochs = EpochsArray(data, create_info(2, 1000., ('eeg', 'misc')))
    xdawn = Xdawn(correct_overlap=False)
    xdawn.fit(epochs)
    denoised = xdawn.apply(epochs)['1']
    assert denoised.info['ch_names'] == epochs.ch_names
    out = denoised.get_data()
    # EEG channel: every sample must equal the input exactly.
    assert (out[:, 0] == data[:, 0]).all()
    # Misc channel: must also be unchanged.
    assert_array_equal(out[:, 1], data[:, 1])
예제 #4
0
def test_xdawn_picks():
    """Test picking with Xdawn.

    On data with one EEG and one misc channel, ``apply`` must keep the
    channel names and return both channels unchanged.
    """
    rng = np.random.RandomState(0)
    data = rng.randn(10, 2, 10)
    epochs = EpochsArray(data, create_info(2, 1000., ('eeg', 'misc')))
    xdawn = Xdawn(correct_overlap=False)
    xdawn.fit(epochs)
    denoised = xdawn.apply(epochs)['1']
    assert denoised.info['ch_names'] == epochs.ch_names
    out = denoised.get_data()
    # EEG channel: every sample must equal the input exactly.
    assert (out[:, 0] == data[:, 0]).all()
    # Misc channel: must also be unchanged.
    assert_array_equal(out[:, 1], data[:, 1])
예제 #5
0
def test_xdawn_regularization():
    """Test Xdawn with regularization.

    Covers automatic overlap detection combined with OAS regularization,
    fitting with each covariance regularizer on a user-supplied signal
    covariance, and rejection of a shrinkage of 2.
    """
    # Get data
    raw, events, picks = _get_data()
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    preload=True, baseline=None, verbose=False)

    # Test with overlapping events.
    # modify events to simulate one overlap
    events = epochs.events
    sel = np.where(events[:, 2] == 2)[0][:2]
    # NOTE: ``events[sel[0]]`` is a row view (assuming ``epochs.events``
    # returns the underlying array), so the shift below also moves event
    # ``sel[0]``; after the copy both selected events share one sample.
    modified_event = events[sel[0]]
    modified_event[0] += 1
    epochs.events[sel[1]] = modified_event
    # Fit and check that overlap was found and applied
    xd = Xdawn(n_components=2, correct_overlap='auto', reg='oas')
    xd.fit(epochs)
    assert_equal(xd.correct_overlap_, True)
    # With overlap correction the Xdawn evoked must differ from the plain
    # average.
    evoked = epochs['cond2'].average()
    assert_true(np.sum(np.abs(evoked.data - xd.evokeds_['cond2'].data)))

    # With covariance regularization: a fixed shrinkage and the two
    # data-driven estimators, each fitted once (``.1`` and ``0.1`` were the
    # same value listed twice).
    for reg in [0.1, 'ledoit_wolf', 'oas']:
        xd = Xdawn(n_components=2, correct_overlap=False,
                   signal_cov=np.eye(len(picks)), reg=reg)
        xd.fit(epochs)
    # With bad shrinkage
    xd = Xdawn(n_components=2, correct_overlap=False,
               signal_cov=np.eye(len(picks)), reg=2)
    assert_raises(ValueError, xd.fit, epochs)
예제 #6
0
def test_xdawn_regularization():
    """Test Xdawn with regularization.

    Covers automatic overlap detection combined with OAS regularization,
    fitting with each covariance regularizer on a user-supplied signal
    covariance, and rejection of a shrinkage of 2.
    """
    # Get data
    raw, events, picks = _get_data()
    epochs = Epochs(raw,
                    events,
                    event_id,
                    tmin,
                    tmax,
                    picks=picks,
                    preload=True,
                    baseline=None,
                    verbose=False,
                    add_eeg_ref=False)

    # Test with overlapping events.
    # modify events to simulate one overlap
    events = epochs.events
    sel = np.where(events[:, 2] == 2)[0][:2]
    # NOTE: ``events[sel[0]]`` is a row view (assuming ``epochs.events``
    # returns the underlying array), so the shift below also moves event
    # ``sel[0]``; after the copy both selected events share one sample.
    modified_event = events[sel[0]]
    modified_event[0] += 1
    epochs.events[sel[1]] = modified_event
    # Fit and check that overlap was found and applied
    xd = Xdawn(n_components=2, correct_overlap='auto', reg='oas')
    xd.fit(epochs)
    assert_equal(xd.correct_overlap_, True)
    # With overlap correction the Xdawn evoked must differ from the plain
    # average.
    evoked = epochs['cond2'].average()
    assert_true(np.sum(np.abs(evoked.data - xd.evokeds_['cond2'].data)))

    # With covariance regularization: a fixed shrinkage and the two
    # data-driven estimators, each fitted once (``.1`` and ``0.1`` were the
    # same value listed twice).
    for reg in [0.1, 'ledoit_wolf', 'oas']:
        xd = Xdawn(n_components=2,
                   correct_overlap=False,
                   signal_cov=np.eye(len(picks)),
                   reg=reg)
        xd.fit(epochs)
    # With bad shrinkage
    xd = Xdawn(n_components=2,
               correct_overlap=False,
               signal_cov=np.eye(len(picks)),
               reg=2)
    assert_raises(ValueError, xd.fit, epochs)
예제 #7
0
def test_xdawn_decoding_performance():
    """Test decoding performance and extracted pattern on synthetic data.

    ``Xdawn`` (fed Epochs) and ``_XdawnTransformer`` (fed the ndarray) are
    each used as the first step of the same classification pipeline. Both
    must reach the expected cross-validated accuracy. For each event type,
    the first pattern must correlate strongly with the first row of the
    simulated mixing matrix.
    """
    from sklearn.model_selection import KFold
    from sklearn.pipeline import make_pipeline
    from sklearn.linear_model import LogisticRegression
    from sklearn.preprocessing import MinMaxScaler
    from sklearn.metrics import accuracy_score

    n_xdawn_comps = 3
    expected_accuracy = 0.98

    epochs, mixing_mat = _simulate_erplike_mixed_data(n_epochs=100)
    y = epochs.events[:, 2]

    def _make_pipe(spatial_filter):
        # Identical downstream steps so the two Xdawn variants are
        # compared like for like.
        return make_pipeline(spatial_filter,
                             Vectorizer(),
                             MinMaxScaler(),
                             LogisticRegression(solver='liblinear'))

    # results of Xdawn and _XdawnTransformer should match
    xdawn_pipe = _make_pipe(Xdawn(n_components=n_xdawn_comps))
    xdawn_trans_pipe = _make_pipe(
        _XdawnTransformer(n_components=n_xdawn_comps))

    cv = KFold(n_splits=3, shuffle=False)
    for pipe, X in ((xdawn_pipe, epochs),
                    (xdawn_trans_pipe, epochs.get_data())):
        predictions = np.empty_like(y, dtype=float)
        for train, test in cv.split(X, y):
            pipe.fit(X[train], y[train])
            predictions[test] = pipe.predict(X[test])

        cv_accuracy = accuracy_score(y, predictions)
        assert_allclose(cv_accuracy, expected_accuracy, atol=0.01)

        # For both event types, the first component should "match" the
        # mixing. The step inspected here is the one fitted on the last fold.
        fitted_xdawn = pipe.steps[0][1]
        if isinstance(fitted_xdawn, Xdawn):
            # patterns_ is mapping-like here (one entry per event type);
            # keep the first row of each entry.
            relev_patterns = np.concatenate(
                [comps[[0]] for comps in fitted_xdawn.patterns_.values()])
        else:
            # presumably n_xdawn_comps consecutive rows per class; take the
            # first row of each group — TODO confirm layout
            relev_patterns = fitted_xdawn.patterns_[::n_xdawn_comps]

        for pattern in relev_patterns:
            r, _ = stats.pearsonr(pattern, mixing_mat[0, :])
            assert np.abs(r) > 0.99
예제 #8
0
def test_xdawn_apply_transform():
    """Test Xdawn apply and transform.

    Fits Xdawn on EEG-only epochs and checks the following:
    - ``apply`` and ``transform`` accept Raw/Evoked/Epochs (and ndarrays
      for ``transform``) and reject other input.
    - Refitting on shuffled epochs gives almost the same denoised data.
    """
    # Get data; ``picks`` is not needed because the raw is reduced to EEG.
    raw, events, _ = _get_data()
    raw.pick_types(eeg=True, meg=False)
    epochs = Epochs(raw,
                    events,
                    event_id,
                    tmin,
                    tmax,
                    proj=False,
                    preload=True,
                    baseline=None,
                    verbose=False)
    n_components = 2
    # Fit Xdawn
    xd = Xdawn(n_components=n_components, correct_overlap=False)
    xd.fit(epochs)

    # Apply on different types of instances
    xd.apply(raw)
    xd.apply(epochs.average())
    # The Epochs result is kept explicitly: it is the reference for the
    # shuffled-fit comparison at the end of the test.
    denoise = xd.apply(epochs)
    # Apply on other thing should raise an error
    pytest.raises(ValueError, xd.apply, 42)

    # Transform on Epochs
    xd.transform(epochs)
    # Transform on Evoked
    xd.transform(epochs.average())
    # Transform on ndarray
    xd.transform(epochs._data)
    xd.transform(epochs._data[0])
    # Transform on something else
    pytest.raises(ValueError, xd.transform, 42)

    # Check numerical results with shuffled epochs
    np.random.seed(0)  # random makes unstable linalg
    idx = np.arange(len(epochs))
    np.random.shuffle(idx)
    xd.fit(epochs[idx])
    denoise_shfl = xd.apply(epochs)
    assert_array_almost_equal(denoise['cond2']._data,
                              denoise_shfl['cond2']._data)
예제 #9
0
def test_xdawn_apply_transform():
    """Check Xdawn.apply and Xdawn.transform on valid and invalid inputs."""
    raw, events, picks = _get_data()
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    preload=True, baseline=None, verbose=False)
    xd = Xdawn(n_components=2, correct_overlap='auto')
    xd.fit(epochs)

    # apply() works on Raw and Epochs, then on an Evoked ...
    for inst in (raw, epochs):
        xd.apply(inst)
    xd.apply(epochs.average())
    # ... and rejects unsupported input (here an int).
    assert_raises(ValueError, xd.apply, 42)

    # transform() works on Epochs and on a plain ndarray, and rejects
    # unsupported input (here an int).
    for inst in (epochs, epochs._data):
        xd.transform(inst)
    assert_raises(ValueError, xd.transform, 42)
예제 #10
0
def test_xdawn_fit():
    """Test Xdawn fit.

    Covers a basic fit with automatic overlap handling, the accepted and
    rejected kinds of ``signal_cov``, and the error raised when baseline
    correction is combined with overlap correction.
    """
    raw, events, picks = _get_data()
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    preload=True, baseline=None, verbose=False)

    # Basic fit: with these parameters the 'auto' overlap setting is
    # expected to resolve to False (checked on the attribute after fit),
    # so the Xdawn evoked must equal the plain average.
    xd = Xdawn(n_components=2, correct_overlap='auto',
               signal_cov=None, reg=None)
    xd.fit(epochs)
    assert_equal(xd.correct_overlap, False)
    evoked = epochs['cond2'].average()
    assert_array_equal(evoked.data, xd.evokeds_['cond2'].data)

    def _make_xdawn(cov):
        # Xdawn without overlap correction or regularization.
        return Xdawn(n_components=2, correct_overlap=False,
                     signal_cov=cov, reg=None)

    # Accepted signal covariances: a Covariance object and an ndarray
    # whose size matches the picked channels.
    for cov in (compute_raw_covariance(raw, picks=picks),
                np.eye(len(picks))):
        _make_xdawn(cov).fit(epochs)
    # Rejected signal covariances: wrong-sized ndarray, unsupported type.
    for cov in (np.eye(len(picks) - 1), 42):
        assert_raises(ValueError, _make_xdawn(cov).fit, epochs)

    # Baseline correction combined with overlap correction must raise.
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    preload=True, baseline=(None, 0), verbose=False)
    xd = Xdawn(n_components=2, correct_overlap=True)
    assert_raises(ValueError, xd.fit, epochs)
예제 #11
0
def test_xdawn_regularization():
    """Check Xdawn.fit with each covariance regularization option."""
    raw, events, picks = _get_data()
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    preload=True, baseline=None, verbose=False)

    # Overlap correction together with a fixed shrinkage.
    Xdawn(n_components=2, correct_overlap=True,
          signal_cov=None, reg=0.1).fit(epochs)

    # Each regularizer on top of a user-supplied covariance: Ledoit-Wolf,
    # OAS and a fixed shrinkage coefficient (fresh identity every time).
    for reg in ('ledoit_wolf', 'oas', 0.1):
        Xdawn(n_components=2, correct_overlap=False,
              signal_cov=np.eye(len(picks)), reg=reg).fit(epochs)

    # A shrinkage of 2 must be rejected.
    bad = Xdawn(n_components=2, correct_overlap=False,
                signal_cov=np.eye(len(picks)), reg=2)
    assert_raises(ValueError, bad.fit, epochs)
예제 #12
0
def test_xdawn_fit():
    """Test Xdawn fit.

    Covers a basic fit with automatic overlap handling, the accepted and
    rejected kinds of ``signal_cov``, and the error raised when baseline
    correction is combined with overlap correction.
    """
    raw, events, picks = _get_data()
    epochs = Epochs(raw, events, event_id, tmin, tmax,
                    picks=picks, preload=True, baseline=None,
                    verbose=False)

    # Basic fit: with these parameters the 'auto' overlap setting is
    # expected to resolve to False (checked on the attribute after fit),
    # so the Xdawn evoked must equal the plain average.
    xd = Xdawn(n_components=2, correct_overlap='auto',
               signal_cov=None, reg=None)
    xd.fit(epochs)
    assert_equal(xd.correct_overlap, False)
    evoked = epochs['cond2'].average()
    assert_array_equal(evoked.data, xd.evokeds_['cond2'].data)

    def _make_xdawn(cov):
        # Xdawn without overlap correction or regularization.
        return Xdawn(n_components=2, correct_overlap=False,
                     signal_cov=cov, reg=None)

    # Accepted signal covariances: a Covariance object and an ndarray
    # whose size matches the picked channels.
    for cov in (compute_raw_covariance(raw, picks=picks),
                np.eye(len(picks))):
        _make_xdawn(cov).fit(epochs)
    # Rejected signal covariances: wrong-sized ndarray, unsupported type.
    for cov in (np.eye(len(picks) - 1), 42):
        assert_raises(ValueError, _make_xdawn(cov).fit, epochs)

    # Baseline correction combined with overlap correction must raise.
    epochs = Epochs(raw, events, event_id, tmin, tmax,
                    picks=picks, preload=True, baseline=(None, 0),
                    verbose=False)
    xd = Xdawn(n_components=2, correct_overlap=True)
    assert_raises(ValueError, xd.fit, epochs)
예제 #13
0
def test_xdawn_apply_transform():
    """Test Xdawn apply and transform.

    Also checks that refitting on shuffled epochs yields almost the same
    denoised data.
    """
    from numpy.testing import assert_array_almost_equal

    # get data
    raw, events, picks = _get_data()
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    preload=True, baseline=None, verbose=False)
    n_components = 2
    # Fit Xdawn
    xd = Xdawn(n_components=n_components, correct_overlap='auto')
    xd.fit(epochs)

    # apply on raw
    xd.apply(raw)
    # apply on epochs (kept as the reference for the shuffle check below)
    denoise = xd.apply(epochs)
    # apply on evoked
    xd.apply(epochs.average())
    # apply on other thing should raise an error
    assert_raises(ValueError, xd.apply, 42)

    # transform on epochs
    xd.transform(epochs)
    # transform on ndarray
    xd.transform(epochs._data)
    # transform on something else
    assert_raises(ValueError, xd.transform, 42)

    # Check numerical results with shuffled epochs. A seeded local
    # generator makes the permutation reproducible. The comparison uses a
    # tolerance because refitting on reordered epochs need not be
    # bit-identical (different floating-point summation order).
    idx = np.arange(len(epochs))
    np.random.RandomState(0).shuffle(idx)
    xd.fit(epochs[idx])
    denoise_shfl = xd.apply(epochs)
    assert_array_almost_equal(denoise['cond2']._data,
                              denoise_shfl['cond2']._data)
예제 #14
0
def test_xdawn_regularization():
    """Test Xdawn with regularization.

    Uses MEG data so that regularization and channel-type support are
    exercised. Covers the following:
    - overlap detection with OAS;
    - each covariance regularizer;
    - rejection of a shrinkage of 2;
    - the rank-deficient case with and without regularization.
    """
    # Get data, this time MEG so we can test proper reg/ch type support
    raw = read_raw_fif(raw_fname, verbose=False, preload=True)
    events = read_events(event_name)
    picks = pick_types(raw.info, meg=True, eeg=False, stim=False,
                       ecg=False, eog=False,
                       exclude='bads')[::8]
    raw.pick_channels([raw.ch_names[pick] for pick in picks])
    del picks
    raw.info.normalize_proj()
    epochs = Epochs(raw, events, event_id, tmin, tmax,
                    preload=True, baseline=None, verbose=False)

    # Test with overlapping events.
    # modify events to simulate one overlap
    events = epochs.events
    sel = np.where(events[:, 2] == 2)[0][:2]
    # NOTE: ``events[sel[0]]`` is a row view (assuming ``epochs.events``
    # returns the underlying array), so the shift below also moves event
    # ``sel[0]``; after the copy both selected events share one sample.
    modified_event = events[sel[0]]
    modified_event[0] += 1
    epochs.events[sel[1]] = modified_event
    # Fit and check that overlap was found and applied
    xd = Xdawn(n_components=2, correct_overlap='auto', reg='oas')
    xd.fit(epochs)
    assert xd.correct_overlap_
    evoked = epochs['cond2'].average()
    assert np.sum(np.abs(evoked.data - xd.evokeds_['cond2'].data))

    # With covariance regularization: a fixed shrinkage and the two
    # data-driven estimators, each fitted once (``.1`` and ``0.1`` were the
    # same value listed twice).
    for reg in [0.1, 'ledoit_wolf', 'oas']:
        xd = Xdawn(n_components=2, correct_overlap=False,
                   signal_cov=np.eye(len(epochs.ch_names)), reg=reg)
        xd.fit(epochs)
    # With bad shrinkage
    xd = Xdawn(n_components=2, correct_overlap=False,
               signal_cov=np.eye(len(epochs.ch_names)), reg=2)
    with pytest.raises(ValueError, match='shrinkage must be'):
        xd.fit(epochs)
    # With rank-deficient input
    # NOTE(review): the Maxwell-filtered ``raw`` below is not used to build
    # ``epochs``; the fits that follow still use the original epochs.
    raw = maxwell_filter(raw, int_order=4, ext_order=2)
    xd = Xdawn(correct_overlap=False, reg=None)
    # this is a bit wacky because `epochs` has projectors on from the old raw
    # but it works as a rank-deficient test case
    with pytest.raises(ValueError, match='Could not compute eigenvalues'):
        xd.fit(epochs)
    xd = Xdawn(correct_overlap=False, reg=0.5)
    xd.fit(epochs)
    xd = Xdawn(correct_overlap=False, reg='diagonal_fixed')
    xd.fit(epochs)
예제 #15
0
def test_xdawn_apply_transform():
    """Test Xdawn apply and transform.

    Fits Xdawn on EEG-only epochs and checks the following:
    - ``apply`` and ``transform`` accept Raw/Evoked/Epochs (and ndarrays
      for ``transform``) and reject other input.
    - Refitting on shuffled epochs gives almost the same denoised data.
    """
    # Get data; ``picks`` is not needed because the raw is reduced to EEG.
    raw, events, _ = _get_data()
    raw.pick_types(eeg=True, meg=False)
    epochs = Epochs(raw, events, event_id, tmin, tmax, proj=False,
                    preload=True, baseline=None,
                    verbose=False)
    n_components = 2
    # Fit Xdawn
    xd = Xdawn(n_components=n_components, correct_overlap=False)
    xd.fit(epochs)

    # Apply on different types of instances
    xd.apply(raw)
    xd.apply(epochs.average())
    # The Epochs result is kept explicitly: it is the reference for the
    # shuffled-fit comparison at the end of the test.
    denoise = xd.apply(epochs)
    # Apply on other thing should raise an error
    pytest.raises(ValueError, xd.apply, 42)

    # Transform on Epochs
    xd.transform(epochs)
    # Transform on Evoked
    xd.transform(epochs.average())
    # Transform on ndarray
    xd.transform(epochs._data)
    xd.transform(epochs._data[0])
    # Transform on something else
    pytest.raises(ValueError, xd.transform, 42)

    # Check numerical results with shuffled epochs
    np.random.seed(0)  # random makes unstable linalg
    idx = np.arange(len(epochs))
    np.random.shuffle(idx)
    xd.fit(epochs[idx])
    denoise_shfl = xd.apply(epochs)
    assert_array_almost_equal(denoise['cond2']._data,
                              denoise_shfl['cond2']._data)
예제 #16
0
def test_xdawn_regularization():
    """Test Xdawn with regularization.

    Uses MEG data so that regularization and channel-type support are
    exercised. Covers the following:
    - overlap detection with OAS;
    - each covariance regularizer;
    - rejection of a shrinkage of 2;
    - the rank-deficient case, where the unregularized check is skipped
      on Windows+MKL.
    """
    # Get data, this time MEG so we can test proper reg/ch type support
    raw = read_raw_fif(raw_fname, verbose=False, preload=True)
    events = read_events(event_name)
    picks = pick_types(raw.info,
                       meg=True,
                       eeg=False,
                       stim=False,
                       ecg=False,
                       eog=False,
                       exclude='bads')[::8]
    raw.pick_channels([raw.ch_names[pick] for pick in picks])
    del picks
    raw.info.normalize_proj()
    epochs = Epochs(raw,
                    events,
                    event_id,
                    tmin,
                    tmax,
                    preload=True,
                    baseline=None,
                    verbose=False)

    # Test with overlapping events.
    # modify events to simulate one overlap
    events = epochs.events
    sel = np.where(events[:, 2] == 2)[0][:2]
    # NOTE: ``events[sel[0]]`` is a row view (assuming ``epochs.events``
    # returns the underlying array), so the shift below also moves event
    # ``sel[0]``; after the copy both selected events share one sample.
    modified_event = events[sel[0]]
    modified_event[0] += 1
    epochs.events[sel[1]] = modified_event
    # Fit and check that overlap was found and applied
    xd = Xdawn(n_components=2, correct_overlap='auto', reg='oas')
    xd.fit(epochs)
    assert xd.correct_overlap_
    evoked = epochs['cond2'].average()
    assert np.sum(np.abs(evoked.data - xd.evokeds_['cond2'].data))

    # With covariance regularization: a fixed shrinkage and the two
    # data-driven estimators, each fitted once (``.1`` and ``0.1`` were the
    # same value listed twice).
    for reg in [0.1, 'ledoit_wolf', 'oas']:
        xd = Xdawn(n_components=2,
                   correct_overlap=False,
                   signal_cov=np.eye(len(epochs.ch_names)),
                   reg=reg)
        xd.fit(epochs)
    # With bad shrinkage
    xd = Xdawn(n_components=2,
               correct_overlap=False,
               signal_cov=np.eye(len(epochs.ch_names)),
               reg=2)
    with pytest.raises(ValueError, match='shrinkage must be'):
        xd.fit(epochs)
    # With rank-deficient input
    # this is a bit wacky because `epochs` has projectors on from the old raw
    # but it works as a rank-deficient test case
    xd = Xdawn(correct_overlap=False, reg=0.5)
    xd.fit(epochs)
    xd = Xdawn(correct_overlap=False, reg='diagonal_fixed')
    xd.fit(epochs)
    bad_eig = (sys.platform.startswith('win')
               and check_version('numpy', '1.16.5')
               and 'mkl_rt' in _get_numpy_libs()
               )  # some problem with MKL on Win
    if bad_eig:
        pytest.skip('Unknown MKL+Windows error fails for eig check')
    xd = Xdawn(correct_overlap=False, reg=None)
    with pytest.raises(ValueError, match='Could not compute eigenvalues'):
        xd.fit(epochs)
예제 #17
0
def test_xdawn():
    """Check Xdawn construction with valid and invalid arguments."""
    # Valid parameters: construction must succeed.
    Xdawn(n_components=2, correct_overlap='auto', signal_cov=None, reg=None)
    # An unsupported correct_overlap value is rejected at construction.
    with pytest.raises(ValueError):
        Xdawn(correct_overlap=42)
예제 #18
0
def test_xdawn_apply_transform():
    """Check Xdawn.apply and Xdawn.transform on valid and invalid inputs."""
    raw, events, picks = _get_data()
    epochs = Epochs(raw, events, event_id, tmin, tmax,
                    picks=picks, preload=True, baseline=None,
                    verbose=False)
    xd = Xdawn(n_components=2, correct_overlap='auto')
    xd.fit(epochs)

    # apply() works on Raw and Epochs, then on an Evoked ...
    for inst in (raw, epochs):
        xd.apply(inst)
    xd.apply(epochs.average())
    # ... and rejects unsupported input (here an int).
    assert_raises(ValueError, xd.apply, 42)

    # transform() works on Epochs and on a plain ndarray, and rejects
    # unsupported input (here an int).
    for inst in (epochs, epochs._data):
        xd.transform(inst)
    assert_raises(ValueError, xd.transform, 42)
예제 #19
0
def test_xdawn_fit():
    """Test Xdawn fit.

    Covers a basic fit (no overlap correction, unit-norm filters), the
    accepted and rejected kinds of ``signal_cov``, and the error expected
    when baseline correction is combined with overlap correction.
    """
    raw, events, picks = _get_data()
    raw.del_proj()
    # Options shared by both Epochs objects built in this test.
    common = dict(picks=picks, preload=True, verbose=False)
    epochs = Epochs(raw, events, event_id, tmin, tmax, baseline=None,
                    **common)

    # Basic fit: with these parameters no overlap correction is applied,
    # so the Xdawn evoked must equal the plain average.
    xd = Xdawn(n_components=2, correct_overlap='auto')
    xd.fit(epochs)
    assert not xd.correct_overlap_
    assert_array_equal(epochs['cond2'].average().data,
                       xd.evokeds_['cond2'].data)
    # Every row of the filter matrix must have unit norm.
    assert_allclose(np.linalg.norm(xd.filters_['cond2'], axis=1), 1)

    # Accepted signal covariances: a Covariance object and an ndarray
    # whose size matches the picked channels.
    for cov in (compute_raw_covariance(raw, picks=picks),
                np.eye(len(picks))):
        Xdawn(n_components=2, correct_overlap=False,
              signal_cov=cov).fit(epochs)
    # Rejected signal covariances: wrong-sized ndarray, unsupported type.
    for cov in (np.eye(len(picks) - 1), 42):
        xd = Xdawn(n_components=2, correct_overlap=False, signal_cov=cov)
        with pytest.raises(ValueError):
            xd.fit(epochs)

    # Baseline correction combined with overlap correction should raise.
    # XXX This is a buggy test, the epochs here don't overlap
    epochs = Epochs(raw, events, event_id, tmin, tmax, baseline=(None, 0),
                    **common)
    xd = Xdawn(n_components=2, correct_overlap=True)
    with pytest.raises(ValueError):
        xd.fit(epochs)
예제 #20
0
def test_xdawn_regularization():
    """Check Xdawn.fit with each covariance regularization option."""
    raw, events, picks = _get_data()
    epochs = Epochs(raw, events, event_id, tmin, tmax,
                    picks=picks, preload=True, baseline=None,
                    verbose=False)

    # Overlap correction together with a fixed shrinkage.
    Xdawn(n_components=2, correct_overlap=True,
          signal_cov=None, reg=0.1).fit(epochs)

    # Each regularizer on top of a user-supplied covariance: Ledoit-Wolf,
    # OAS and a fixed shrinkage coefficient (fresh identity every time).
    for reg in ('ledoit_wolf', 'oas', 0.1):
        Xdawn(n_components=2,
              correct_overlap=False,
              signal_cov=np.eye(len(picks)),
              reg=reg).fit(epochs)

    # A shrinkage of 2 must be rejected.
    bad = Xdawn(n_components=2,
                correct_overlap=False,
                signal_cov=np.eye(len(picks)),
                reg=2)
    assert_raises(ValueError, bad.fit, epochs)