def test_XdawnTransformer():
    """Test _XdawnTransformer fit, transform and inverse_transform."""
    # Build epochs from the shared test data, without projectors
    raw, events, picks = _get_data()
    raw.del_proj()
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    preload=True, baseline=None, verbose=False)
    data = epochs._data
    labels = epochs.events[:, -1]

    # Plain fit; mismatched labels and non-array input are rejected
    transformer = _XdawnTransformer()
    transformer.fit(data, labels)
    with pytest.raises(ValueError):
        transformer.fit(data, labels[1:])
    with pytest.raises(ValueError):
        transformer.fit('foo')

    # A Covariance object and a square ndarray are accepted as signal_cov
    for good_cov in (compute_raw_covariance(raw, picks=picks),
                     np.eye(len(picks))):
        _XdawnTransformer(signal_cov=good_cov).fit(data, labels)

    # A wrongly shaped ndarray or an unsupported type makes fit raise
    for bad_cov in (np.eye(len(picks) - 1), 42):
        transformer = _XdawnTransformer(signal_cov=bad_cov)
        with pytest.raises(ValueError):
            transformer.fit(data, labels)

    # Fitting without labels is allowed
    _XdawnTransformer().fit(data)

    # The filters must agree with those computed by Xdawn
    xdawn = Xdawn(correct_overlap=False)
    xdawn.fit(epochs)
    transformer = _XdawnTransformer()
    transformer.fit(data, labels)
    assert_array_almost_equal(xdawn.filters_['cond2'][:2, :],
                              transformer.filters_.reshape(2, 2, 8)[0])

    # transform: fewer epochs / time samples are fine, fewer channels are not
    transformer.transform(data[1:, ...])
    transformer.transform(data[:, :, 1:])
    with pytest.raises(ValueError):
        transformer.transform(data[:, 1:, :])
    transformed = transformer.transform(data)
    with pytest.raises(ValueError):
        transformer.transform(42)

    # inverse_transform restores the original shape
    restored = transformer.inverse_transform(transformed)
    assert restored.shape == data.shape
    transformer.inverse_transform(transformed[1:, ...])
    transformer.inverse_transform(transformed[:, :, 1:])
    # A wrong number of components (or a non-array) must be rejected
    with pytest.raises(ValueError):
        transformer.inverse_transform(transformed[:, 1:, :])
    with pytest.raises(ValueError):
        transformer.inverse_transform(42)
def test_XdawnTransformer():
    """Test _XdawnTransformer fit, transform and inverse_transform."""
    # Build epochs from the shared test data
    raw, events, picks = _get_data()
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    preload=True, baseline=None, verbose=False,
                    add_eeg_ref=False)
    data = epochs._data
    labels = epochs.events[:, -1]

    # Plain fit; mismatched labels and non-array input are rejected
    transformer = _XdawnTransformer()
    transformer.fit(data, labels)
    assert_raises(ValueError, transformer.fit, data, labels[1:])
    assert_raises(ValueError, transformer.fit, 'foo')

    # A Covariance object and a square ndarray are accepted as signal_cov
    for good_cov in (compute_raw_covariance(raw, picks=picks),
                     np.eye(len(picks))):
        _XdawnTransformer(signal_cov=good_cov).fit(data, labels)

    # A wrongly shaped ndarray or an unsupported type makes fit raise
    for bad_cov in (np.eye(len(picks) - 1), 42):
        transformer = _XdawnTransformer(signal_cov=bad_cov)
        assert_raises(ValueError, transformer.fit, data, labels)

    # Fitting without labels is allowed
    _XdawnTransformer().fit(data)

    # The filters must agree with those computed by Xdawn
    xdawn = Xdawn(correct_overlap=False)
    xdawn.fit(epochs)
    transformer = _XdawnTransformer()
    transformer.fit(data, labels)
    assert_array_almost_equal(xdawn.filters_['cond2'][:, :2],
                              transformer.filters_.reshape(2, 2, 8)[0].T)

    # transform: fewer epochs / time samples are fine, fewer channels are not
    transformer.transform(data[1:, ...])
    transformer.transform(data[:, :, 1:])
    assert_raises(ValueError, transformer.transform, data[:, 1:, :])
    transformed = transformer.transform(data)
    assert_raises(ValueError, transformer.transform, 42)

    # inverse_transform restores the original shape
    restored = transformer.inverse_transform(transformed)
    assert_equal(restored.shape, data.shape)
    transformer.inverse_transform(transformed[1:, ...])
    transformer.inverse_transform(transformed[:, :, 1:])
    # A wrong number of components (or a non-array) must be rejected
    assert_raises(ValueError, transformer.inverse_transform,
                  transformed[:, 1:, :])
    assert_raises(ValueError, transformer.inverse_transform, 42)
def test_xdawn_picks():
    """Test Xdawn.apply on epochs mixing an EEG and a misc channel."""
    rng = np.random.RandomState(0)
    data = rng.randn(10, 2, 10)
    info = create_info(2, 1000., ('eeg', 'misc'))
    epochs = EpochsArray(data, info)

    xdawn = Xdawn(correct_overlap=False)
    xdawn.fit(epochs)
    denoised = xdawn.apply(epochs)['1']

    # Channel names are preserved
    assert denoised.info['ch_names'] == epochs.ch_names
    # Both channels are expected to come back unchanged
    assert (denoised.get_data()[:, 0] == data[:, 0]).all()
    assert_array_equal(denoised.get_data()[:, 1], data[:, 1])
def test_xdawn_regularization():
    """Test Xdawn with overlap correction and covariance regularization."""
    # Get data
    raw, events, picks = _get_data()
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    preload=True, baseline=None, verbose=False)

    # Test with overlapping events.
    # Move the second event with id 2 to one sample after the first one.
    # ``.copy()`` matters: indexing a single row returns a view, so without it
    # ``modified_event[0] += 1`` would also shift the first event.
    events = epochs.events
    sel = np.where(events[:, 2] == 2)[0][:2]
    modified_event = events[sel[0]].copy()
    modified_event[0] += 1
    epochs.events[sel[1]] = modified_event

    # Fit and check that overlap was found and applied
    xd = Xdawn(n_components=2, correct_overlap='auto', reg='oas')
    xd.fit(epochs)
    assert_equal(xd.correct_overlap_, True)
    evoked = epochs['cond2'].average()
    assert_true(np.sum(np.abs(evoked.data - xd.evokeds_['cond2'].data)))

    # With covariance regularization: fixed shrinkage and estimator names
    for reg in [0.1, 'ledoit_wolf', 'oas']:
        xd = Xdawn(n_components=2, correct_overlap=False,
                   signal_cov=np.eye(len(picks)), reg=reg)
        xd.fit(epochs)

    # With bad shrinkage
    xd = Xdawn(n_components=2, correct_overlap=False,
               signal_cov=np.eye(len(picks)), reg=2)
    assert_raises(ValueError, xd.fit, epochs)
def test_xdawn_regularization():
    """Test Xdawn with overlap correction and covariance regularization."""
    # Get data
    raw, events, picks = _get_data()
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    preload=True, baseline=None, verbose=False,
                    add_eeg_ref=False)

    # Test with overlapping events.
    # Move the second event with id 2 to one sample after the first one.
    # ``.copy()`` matters: indexing a single row returns a view, so without it
    # ``modified_event[0] += 1`` would also shift the first event.
    events = epochs.events
    sel = np.where(events[:, 2] == 2)[0][:2]
    modified_event = events[sel[0]].copy()
    modified_event[0] += 1
    epochs.events[sel[1]] = modified_event

    # Fit and check that overlap was found and applied
    xd = Xdawn(n_components=2, correct_overlap='auto', reg='oas')
    xd.fit(epochs)
    assert_equal(xd.correct_overlap_, True)
    evoked = epochs['cond2'].average()
    assert_true(np.sum(np.abs(evoked.data - xd.evokeds_['cond2'].data)))

    # With covariance regularization: fixed shrinkage and estimator names
    for reg in [0.1, 'ledoit_wolf', 'oas']:
        xd = Xdawn(n_components=2, correct_overlap=False,
                   signal_cov=np.eye(len(picks)), reg=reg)
        xd.fit(epochs)

    # With bad shrinkage
    xd = Xdawn(n_components=2, correct_overlap=False,
               signal_cov=np.eye(len(picks)), reg=2)
    assert_raises(ValueError, xd.fit, epochs)
def test_xdawn_decoding_performance():
    """Test decoding performance and extracted pattern on synthetic data."""
    from sklearn.model_selection import KFold
    from sklearn.pipeline import make_pipeline
    from sklearn.linear_model import LogisticRegression
    from sklearn.preprocessing import MinMaxScaler
    from sklearn.metrics import accuracy_score

    n_xdawn_comps = 3
    expected_accuracy = 0.98

    epochs, mixing_mat = _simulate_erplike_mixed_data(n_epochs=100)
    y = epochs.events[:, 2]

    def _make_pipe(spatial_filter):
        # spatial filter -> flatten -> scale to [0, 1] -> linear classifier
        return make_pipeline(spatial_filter, Vectorizer(), MinMaxScaler(),
                             LogisticRegression(solver='liblinear'))

    # Results of Xdawn (on Epochs) and _XdawnTransformer (on arrays) must match
    candidates = [
        (_make_pipe(Xdawn(n_components=n_xdawn_comps)), epochs),
        (_make_pipe(_XdawnTransformer(n_components=n_xdawn_comps)),
         epochs.get_data()),
    ]
    cv = KFold(n_splits=3, shuffle=False)

    for pipe, X in candidates:
        # Out-of-fold predictions for every sample
        predictions = np.empty_like(y, dtype=float)
        for train, test in cv.split(X, y):
            pipe.fit(X[train], y[train])
            predictions[test] = pipe.predict(X[test])

        accuracy = accuracy_score(y, predictions)
        assert_allclose(accuracy, expected_accuracy, atol=0.01)

        # For both event types the first component should "match" the mixing
        fitted = pipe.steps[0][1]
        if isinstance(fitted, Xdawn):
            first_patterns = np.concatenate(
                [comps[[0]] for comps in fitted.patterns_.values()])
        else:
            first_patterns = fitted.patterns_[::n_xdawn_comps]

        for pattern in first_patterns:
            r, _ = stats.pearsonr(pattern, mixing_mat[0, :])
            assert np.abs(r) > 0.99
def test_xdawn_apply_transform():
    """Test Xdawn apply and transform."""
    # Get data
    raw, events, picks = _get_data()
    raw.pick_types(eeg=True, meg=False)
    epochs = Epochs(raw, events, event_id, tmin, tmax, proj=False,
                    preload=True, baseline=None, verbose=False)
    n_components = 2
    # Fit Xdawn
    xd = Xdawn(n_components=n_components, correct_overlap=False)
    xd.fit(epochs)

    # Apply on different types of instances. The Epochs result is kept
    # explicitly as the reference for the shuffled-refit check below.
    xd.apply(raw)
    xd.apply(epochs.average())
    denoise = xd.apply(epochs)
    # Apply on other thing should raise an error
    pytest.raises(ValueError, xd.apply, 42)

    # Transform on Epochs
    xd.transform(epochs)
    # Transform on Evoked
    xd.transform(epochs.average())
    # Transform on ndarray
    xd.transform(epochs._data)
    xd.transform(epochs._data[0])
    # Transform on something else
    pytest.raises(ValueError, xd.transform, 42)

    # Check numerical results with shuffled epochs. A fixed seed is used
    # because a random order can make the linalg unstable. A dedicated
    # RandomState yields the same permutation as seeding the global RNG,
    # without changing numpy's global random state for later code.
    idx = np.arange(len(epochs))
    np.random.RandomState(0).shuffle(idx)
    xd.fit(epochs[idx])
    denoise_shfl = xd.apply(epochs)
    assert_array_almost_equal(denoise['cond2']._data,
                              denoise_shfl['cond2']._data)
def test_xdawn_apply_transform():
    """Test Xdawn apply and transform on supported and unsupported inputs."""
    raw, events, picks = _get_data()
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    preload=True, baseline=None, verbose=False)

    xdawn = Xdawn(n_components=2, correct_overlap='auto')
    xdawn.fit(epochs)

    # apply() on Raw, Epochs and Evoked; anything else raises
    xdawn.apply(raw)
    xdawn.apply(epochs)
    xdawn.apply(epochs.average())
    assert_raises(ValueError, xdawn.apply, 42)

    # transform() on Epochs and on an ndarray; something else raises
    xdawn.transform(epochs)
    xdawn.transform(epochs._data)
    assert_raises(ValueError, xdawn.transform, 42)
def test_xdawn_fit():
    """Test Xdawn fit with default settings and various signal_cov inputs."""
    raw, events, picks = _get_data()

    def _make_epochs(baseline):
        # Epochs built from the shared data, differing only by baseline
        return Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                      preload=True, baseline=baseline, verbose=False)

    epochs = _make_epochs(None)

    # Basic fit: with these parameters the overlap correction must be False
    xd = Xdawn(n_components=2, correct_overlap='auto', signal_cov=None,
               reg=None)
    xd.fit(epochs)
    assert_equal(xd.correct_overlap, False)
    # Without overlap correction the evoked equals the plain average
    evoked = epochs['cond2'].average()
    assert_array_equal(evoked.data, xd.evokeds_['cond2'].data)

    # A Covariance object and a square ndarray are accepted as signal_cov
    for good_cov in (compute_raw_covariance(raw, picks=picks),
                     np.eye(len(picks))):
        xd = Xdawn(n_components=2, correct_overlap=False,
                   signal_cov=good_cov, reg=None)
        xd.fit(epochs)

    # A wrongly shaped ndarray or an unsupported type makes fit raise
    for bad_cov in (np.eye(len(picks) - 1), 42):
        xd = Xdawn(n_components=2, correct_overlap=False,
                   signal_cov=bad_cov, reg=None)
        assert_raises(ValueError, xd.fit, epochs)

    # Baseline correction combined with overlap correction should raise
    epochs = _make_epochs((None, 0))
    xd = Xdawn(n_components=2, correct_overlap=True)
    assert_raises(ValueError, xd.fit, epochs)
def test_xdawn_regularization():
    """Test Xdawn with overlap correction and covariance regularization."""
    raw, events, picks = _get_data()
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    preload=True, baseline=None, verbose=False)

    # Overlap correction combined with a fixed shrinkage
    xdawn = Xdawn(n_components=2, correct_overlap=True, signal_cov=None,
                  reg=0.1)
    xdawn.fit(epochs)

    # Covariance regularization: estimator names and a valid shrinkage
    for reg in ('ledoit_wolf', 'oas', 0.1):
        xdawn = Xdawn(n_components=2, correct_overlap=False,
                      signal_cov=np.eye(len(picks)), reg=reg)
        xdawn.fit(epochs)

    # A shrinkage value out of range must be rejected
    xdawn = Xdawn(n_components=2, correct_overlap=False,
                  signal_cov=np.eye(len(picks)), reg=2)
    assert_raises(ValueError, xdawn.fit, epochs)
def test_xdawn_apply_transform():
    """Test Xdawn apply and transform."""
    # Get data
    raw, events, picks = _get_data()
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    preload=True, baseline=None, verbose=False)
    n_components = 2
    # Fit Xdawn
    xd = Xdawn(n_components=n_components, correct_overlap='auto')
    xd.fit(epochs)
    # Apply on raw
    xd.apply(raw)
    # Apply on epochs (kept as the reference for the shuffle check below)
    denoise = xd.apply(epochs)
    # Apply on evoked
    xd.apply(epochs.average())
    # Apply on other thing should raise an error
    assert_raises(ValueError, xd.apply, 42)
    # Transform on epochs
    xd.transform(epochs)
    # Transform on ndarray
    xd.transform(epochs._data)
    # Transform on something else
    assert_raises(ValueError, xd.transform, 42)

    # Check numerical results with shuffled epochs. The permutation is seeded
    # through a dedicated RandomState so the test is reproducible without
    # touching numpy's global RNG. A refit on reordered epochs need not be
    # bit-identical, so compare to a tolerance rather than exactly.
    idx = np.arange(len(epochs))
    np.random.RandomState(0).shuffle(idx)
    xd.fit(epochs[idx])
    denoise_shfl = xd.apply(epochs)
    assert_array_almost_equal(denoise['cond2']._data,
                              denoise_shfl['cond2']._data)
def test_xdawn_regularization():
    """Test Xdawn with regularization on MEG data.

    MEG channels are used so that regularization together with channel-type
    handling is exercised.
    """
    # Get data, this time MEG so we can test proper reg/ch type support
    raw = read_raw_fif(raw_fname, verbose=False, preload=True)
    events = read_events(event_name)
    picks = pick_types(raw.info, meg=True, eeg=False, stim=False,
                       ecg=False, eog=False, exclude='bads')[::8]
    raw.pick_channels([raw.ch_names[pick] for pick in picks])
    del picks
    raw.info.normalize_proj()
    epochs = Epochs(raw, events, event_id, tmin, tmax,
                    preload=True, baseline=None, verbose=False)

    # Test with overlapping events.
    # Move the second event with id 2 to one sample after the first one.
    # ``.copy()`` matters: indexing a single row returns a view, so without it
    # ``modified_event[0] += 1`` would also shift the first event.
    events = epochs.events
    sel = np.where(events[:, 2] == 2)[0][:2]
    modified_event = events[sel[0]].copy()
    modified_event[0] += 1
    epochs.events[sel[1]] = modified_event

    # Fit and check that overlap was found and applied
    xd = Xdawn(n_components=2, correct_overlap='auto', reg='oas')
    xd.fit(epochs)
    assert xd.correct_overlap_
    evoked = epochs['cond2'].average()
    assert np.sum(np.abs(evoked.data - xd.evokeds_['cond2'].data))

    # With covariance regularization: fixed shrinkage and estimator names
    for reg in [0.1, 'ledoit_wolf', 'oas']:
        xd = Xdawn(n_components=2, correct_overlap=False,
                   signal_cov=np.eye(len(epochs.ch_names)), reg=reg)
        xd.fit(epochs)

    # With bad shrinkage
    xd = Xdawn(n_components=2, correct_overlap=False,
               signal_cov=np.eye(len(epochs.ch_names)), reg=2)
    with pytest.raises(ValueError, match='shrinkage must be'):
        xd.fit(epochs)

    # With rank-deficient input: `epochs` was built from `raw` with its
    # projectors active, which (as the expected failure below shows) leaves
    # the data rank-deficient. An unregularized fit must fail, and
    # regularized fits must succeed.
    xd = Xdawn(correct_overlap=False, reg=None)
    with pytest.raises(ValueError, match='Could not compute eigenvalues'):
        xd.fit(epochs)
    xd = Xdawn(correct_overlap=False, reg=0.5)
    xd.fit(epochs)
    xd = Xdawn(correct_overlap=False, reg='diagonal_fixed')
    xd.fit(epochs)
def test_xdawn_regularization():
    """Test Xdawn with regularization on MEG data.

    MEG channels are used so that regularization together with channel-type
    handling is exercised.
    """
    # Get data, this time MEG so we can test proper reg/ch type support
    raw = read_raw_fif(raw_fname, verbose=False, preload=True)
    events = read_events(event_name)
    picks = pick_types(raw.info, meg=True, eeg=False, stim=False,
                       ecg=False, eog=False, exclude='bads')[::8]
    raw.pick_channels([raw.ch_names[pick] for pick in picks])
    del picks
    raw.info.normalize_proj()
    epochs = Epochs(raw, events, event_id, tmin, tmax,
                    preload=True, baseline=None, verbose=False)

    # Test with overlapping events.
    # Move the second event with id 2 to one sample after the first one.
    # ``.copy()`` matters: indexing a single row returns a view, so without it
    # ``modified_event[0] += 1`` would also shift the first event.
    events = epochs.events
    sel = np.where(events[:, 2] == 2)[0][:2]
    modified_event = events[sel[0]].copy()
    modified_event[0] += 1
    epochs.events[sel[1]] = modified_event

    # Fit and check that overlap was found and applied
    xd = Xdawn(n_components=2, correct_overlap='auto', reg='oas')
    xd.fit(epochs)
    assert xd.correct_overlap_
    evoked = epochs['cond2'].average()
    assert np.sum(np.abs(evoked.data - xd.evokeds_['cond2'].data))

    # With covariance regularization: fixed shrinkage and estimator names
    for reg in [0.1, 'ledoit_wolf', 'oas']:
        xd = Xdawn(n_components=2, correct_overlap=False,
                   signal_cov=np.eye(len(epochs.ch_names)), reg=reg)
        xd.fit(epochs)

    # With bad shrinkage
    xd = Xdawn(n_components=2, correct_overlap=False,
               signal_cov=np.eye(len(epochs.ch_names)), reg=2)
    with pytest.raises(ValueError, match='shrinkage must be'):
        xd.fit(epochs)

    # With rank-deficient input: `epochs` was built from `raw` with its
    # projectors active, which (as the expected failure at the end shows)
    # leaves the data rank-deficient. Regularized fits must succeed.
    xd = Xdawn(correct_overlap=False, reg=0.5)
    xd.fit(epochs)
    xd = Xdawn(correct_overlap=False, reg='diagonal_fixed')
    xd.fit(epochs)
    # some problem with MKL on Win
    bad_eig = (sys.platform.startswith('win') and
               check_version('numpy', '1.16.5') and
               'mkl_rt' in _get_numpy_libs())
    if bad_eig:
        pytest.skip('Unknown MKL+Windows error fails for eig check')
    # Without regularization the eigenvalue computation must fail
    xd = Xdawn(correct_overlap=False, reg=None)
    with pytest.raises(ValueError, match='Could not compute eigenvalues'):
        xd.fit(epochs)
def test_xdawn():
    """Test Xdawn construction with valid and invalid parameters."""
    # Valid parameters construct without error
    Xdawn(n_components=2, correct_overlap='auto', signal_cov=None, reg=None)
    # An invalid correct_overlap value is rejected at construction time
    with pytest.raises(ValueError):
        Xdawn(correct_overlap=42)
def test_xdawn_fit():
    """Test Xdawn fit with default settings and various signal_cov inputs."""
    raw, events, picks = _get_data()
    raw.del_proj()

    def _make_epochs(baseline):
        # Epochs built from the shared data, differing only by baseline
        return Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                      preload=True, baseline=baseline, verbose=False)

    epochs = _make_epochs(None)

    # Basic fit: with these parameters the overlap correction must be False
    xd = Xdawn(n_components=2, correct_overlap='auto')
    xd.fit(epochs)
    assert not xd.correct_overlap_
    # Without overlap correction the evoked equals the plain average
    evoked = epochs['cond2'].average()
    assert_array_equal(evoked.data, xd.evokeds_['cond2'].data)
    # Every row of the fitted filters has unit norm
    assert_allclose(np.linalg.norm(xd.filters_['cond2'], axis=1), 1)

    # A Covariance object and a square ndarray are accepted as signal_cov
    for good_cov in (compute_raw_covariance(raw, picks=picks),
                     np.eye(len(picks))):
        xd = Xdawn(n_components=2, correct_overlap=False,
                   signal_cov=good_cov)
        xd.fit(epochs)

    # A wrongly shaped ndarray or an unsupported type makes fit raise
    for bad_cov in (np.eye(len(picks) - 1), 42):
        xd = Xdawn(n_components=2, correct_overlap=False,
                   signal_cov=bad_cov)
        with pytest.raises(ValueError):
            xd.fit(epochs)

    # Baseline correction combined with overlap correction should raise
    # XXX This is a buggy test, the epochs here don't overlap
    epochs = _make_epochs((None, 0))
    xd = Xdawn(n_components=2, correct_overlap=True)
    with pytest.raises(ValueError):
        xd.fit(epochs)