def test_csp():
    """Check CSP fitting, transforming, plotting and covariance options.

    Epochs are built from the sample recording on a small channel subset.
    The test then verifies the shape of the fitted filters and patterns,
    that ``fit_transform`` agrees with ``fit`` followed by ``transform``,
    that malformed inputs raise ``ValueError``, that both plotting helpers
    run, and that ``cov_est="epoch"`` yields filters and patterns that are
    highly correlated with those of the default-configured estimator.
    """
    raw = io.read_raw_fif(raw_fname, preload=False)
    events = read_events(event_name)
    meg_picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
                           eog=False, exclude='bads')
    # Channels are subselected, so projectors are disabled below.
    channel_subset = meg_picks[2:9:3]
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=channel_subset,
                    baseline=(None, 0), preload=True, proj=False)
    data = epochs.get_data()
    n_channels = data.shape[1]
    square_shape = (n_channels, n_channels)
    labels = epochs.events[:, -1]

    n_components = 3
    csp = CSP(n_components=n_components)
    csp.fit(data, labels)

    projected = csp.fit_transform(data, labels)
    assert_true(csp.filters_.shape == square_shape)
    assert_true(csp.patterns_.shape == square_shape)
    refitted = csp.fit(data, labels).transform(data)
    assert_array_almost_equal(refitted, projected)

    # A label array shaped like the whole events table, and an Epochs
    # object passed in place of the data array, must both be rejected.
    bad_labels = np.zeros_like(epochs.events)
    assert_raises(ValueError, csp.fit, data, bad_labels)
    assert_raises(ValueError, csp.fit, epochs, labels)
    assert_raises(ValueError, csp.transform, epochs, labels)

    csp.n_components = n_components
    sources = csp.transform(data)
    assert_true(sources.shape[1] == n_components)

    epochs.pick_types(meg='mag', copy=False)

    # Both plotting helpers are called with identical arguments.
    components = np.arange(n_components)
    for plot in (csp.plot_patterns, csp.plot_filters):
        plot(epochs.info, components=components, res=12, show=False)

    # Covariance estimation methods should give roughly equal results.
    csp_epochs = CSP(cov_est="epoch")
    csp_epochs.fit(data, labels)
    for attr in ('filters_', 'patterns_'):
        reference = getattr(csp, attr).ravel()
        per_epoch = getattr(csp_epochs, attr).ravel()
        corr = np.corrcoef(reference, per_epoch)[0, 1]
        assert_true(corr >= 0.95, msg='%s < 0.95' % corr)

    # An undefined estimation method must raise when fitting.
    csp_fail = CSP(cov_est="undefined")
    assert_raises(ValueError, csp_fail.fit, data, labels)
def test_csp_twoclass_symmetry():
    """Verify CSP log-power ratios are unchanged when classes are swapped."""

    def log_power_ratio(class_order):
        # Fit CSP on the toy data generated for ``class_order`` and return
        # the difference between the first two rows of the log average
        # power features.
        x, y = deterministic_toy_data(class_order)
        csp = CSP(norm_trace=False, transform_into='average_power', log=True)
        log_power = csp.fit_transform(x, y)
        return log_power[0] - log_power[1]

    ratio_ab = log_power_ratio(['class_a', 'class_b'])
    ratio_ba = log_power_ratio(['class_b', 'class_a'])
    assert_array_almost_equal(ratio_ab, ratio_ba)
def test_regularized_csp():
    """Run the basic CSP checks for several covariance regularizations.

    For each ``reg`` value the test verifies the shape of the fitted
    filters and patterns, that ``fit_transform`` agrees with ``fit``
    followed by ``transform``, that malformed inputs raise ``ValueError``,
    and that the transformed data has ``n_components`` columns.
    """
    raw = io.Raw(raw_fname, preload=False)
    events = read_events(event_name)
    meg_picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
                           eog=False, exclude='bads')
    channel_subset = meg_picks[1:13:3]
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=channel_subset,
                    baseline=(None, 0), preload=True)
    data = epochs.get_data()
    n_channels = data.shape[1]
    square_shape = (n_channels, n_channels)
    labels = epochs.events[:, -1]
    n_components = 3

    for reg in (None, 0.05, 'lws', 'oas'):
        csp = CSP(n_components=n_components, reg=reg)
        csp.fit(data, labels)

        projected = csp.fit_transform(data, labels)
        assert_true(csp.filters_.shape == square_shape)
        assert_true(csp.patterns_.shape == square_shape)
        refitted = csp.fit(data, labels).transform(data)
        assert_array_almost_equal(refitted, projected)

        # A label array shaped like the whole events table, and an Epochs
        # object passed in place of the data array, must both be rejected.
        bad_labels = np.zeros_like(epochs.events)
        assert_raises(ValueError, csp.fit, data, bad_labels)
        assert_raises(ValueError, csp.fit, epochs, labels)
        assert_raises(ValueError, csp.transform, epochs, labels)

        csp.n_components = n_components
        sources = csp.transform(data)
        assert_true(sources.shape[1] == n_components)
def test_csp_pipeline():
    """Check that a CSP parameter can be set through a sklearn Pipeline."""
    from sklearn.pipeline import Pipeline
    from sklearn.svm import SVC

    steps = [("CSP", CSP(reg=1, norm_trace=False)), ("SVC", SVC())]
    pipe = Pipeline(steps)
    # The double-underscore name addresses the ``reg`` parameter of the
    # step registered as "CSP".
    pipe.set_params(CSP__reg=0.2)
    updated_reg = pipe.get_params()["CSP__reg"]
    assert updated_reg == 0.2
def test_csp_component_ordering():
    """Check validation and effect of the CSP ``component_order`` option."""
    x, y = deterministic_toy_data(['class_a', 'class_b'])

    # An unknown ordering name must be rejected by the constructor.
    with pytest.raises(ValueError):
        CSP(component_order='invalid')

    # component_order='alternate' only works with two classes
    three_class_csp = CSP(component_order='alternate')
    pytest.raises(ValueError, three_class_csp.fit,
                  np.zeros((3, 0, 0)), ['a', 'b', 'c'])

    patterns = {
        order: CSP(component_order=order).fit(x, y).patterns_
        for order in ('alternate', 'mutual_info')
    }

    # The permutation relating the two orderings is explained by the
    # particular eigenvalues of the toy data: [0.06, 0.1, 0.5, 0.8].
    # 'alternate' arranges them to [0.8, 0.06, 0.5, 0.1]
    # 'mutual_info' arranges them to [0.06, 0.1, 0.8, 0.5]
    reordered_mutual_info = patterns['mutual_info'][[2, 0, 3, 1]]
    assert_array_almost_equal(patterns['alternate'], reordered_mutual_info)
def test_csp():
    """Check CSP fitting, transforming and plotting on epochs.

    Epochs are built from the sample recording on a small channel subset.
    The test then verifies the shape of the fitted filters and patterns,
    that ``fit_transform`` agrees with ``fit`` followed by ``transform``,
    that malformed inputs raise ``ValueError``, and that both plotting
    helpers run.
    """
    raw = io.Raw(raw_fname, preload=False)
    events = read_events(event_name)
    meg_picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
                           eog=False, exclude='bads')
    channel_subset = meg_picks[2:9:3]
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=channel_subset,
                    baseline=(None, 0), preload=True)
    data = epochs.get_data()
    n_channels = data.shape[1]
    square_shape = (n_channels, n_channels)
    labels = epochs.events[:, -1]

    n_components = 3
    csp = CSP(n_components=n_components)
    csp.fit(data, labels)

    projected = csp.fit_transform(data, labels)
    assert_true(csp.filters_.shape == square_shape)
    assert_true(csp.patterns_.shape == square_shape)
    refitted = csp.fit(data, labels).transform(data)
    assert_array_almost_equal(refitted, projected)

    # A label array shaped like the whole events table, and an Epochs
    # object passed in place of the data array, must both be rejected.
    bad_labels = np.zeros_like(epochs.events)
    assert_raises(ValueError, csp.fit, data, bad_labels)
    assert_raises(ValueError, csp.fit, epochs, labels)
    assert_raises(ValueError, csp.transform, epochs, labels)

    csp.n_components = n_components
    sources = csp.transform(data)
    assert_true(sources.shape[1] == n_components)

    epochs.pick_types(meg='mag', copy=False)

    # Both plotting helpers are called with identical arguments.
    components = np.arange(n_components)
    for plot in (csp.plot_patterns, csp.plot_filters):
        plot(epochs.info, components=components, res=12, show=False)
def test_csp():
    """Test Common Spatial Patterns algorithm on epochs.

    Sections, in order: constructor and fit-time argument validation,
    fitting with ``norm_trace`` on and off, transform consistency,
    rejection of malformed data, the plotting helpers, a four-class
    problem, the ``average_power`` and ``csp_space`` outputs, and a check
    of the first pattern against the mixing matrix of simulated data.
    """
    raw = io.read_raw_fif(raw_fname, preload=False)
    events = read_events(event_name)
    picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
                       eog=False, exclude='bads')
    picks = picks[2:12:3]  # subselect channels -> disable proj!
    raw.add_proj([], remove_existing=True)
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), preload=True, proj=False)
    epochs_data = epochs.get_data()
    n_channels = epochs_data.shape[1]
    y = epochs.events[:, -1]

    # Init
    pytest.raises(ValueError, CSP, n_components='foo', norm_trace=False)
    for reg in ['foo', -0.1, 1.1]:
        # Construction is expected to succeed; the bad value must make
        # ``fit`` raise.
        csp = CSP(reg=reg, norm_trace=False)
        pytest.raises(ValueError, csp.fit, epochs_data, y)
    for reg in ['oas', 'ledoit_wolf', 0, 0.5, 1.]:
        CSP(reg=reg, norm_trace=False)
    for cov_est in ['foo', None]:
        pytest.raises(ValueError, CSP, cov_est=cov_est, norm_trace=False)
    pytest.raises(ValueError, CSP, norm_trace='foo')
    for cov_est in ['concat', 'epoch']:
        CSP(cov_est=cov_est, norm_trace=False)

    n_components = 3
    # Fit
    for norm_trace in [True, False]:
        csp = CSP(n_components=n_components, norm_trace=norm_trace)
        csp.fit(epochs_data, y)
        # Checked on every iteration so both settings are verified.
        assert_equal(len(csp.mean_), n_components)
        assert_equal(len(csp.std_), n_components)

    # Transform (uses the estimator left over from the last iteration)
    X = csp.fit_transform(epochs_data, y)
    sources = csp.transform(epochs_data)
    assert sources.shape[1] == n_components
    assert csp.filters_.shape == (n_channels, n_channels)
    assert csp.patterns_.shape == (n_channels, n_channels)
    assert_array_almost_equal(sources, X)

    # Test data exception
    pytest.raises(ValueError, csp.fit, epochs_data,
                  np.zeros_like(epochs.events))
    pytest.raises(ValueError, csp.fit, epochs, y)
    pytest.raises(ValueError, csp.transform, epochs)

    # Test plots
    epochs.pick_types(meg='mag')
    cmap = ('RdBu', True)
    components = np.arange(n_components)
    for plot in (csp.plot_patterns, csp.plot_filters):
        plot(epochs.info, components=components, res=12, show=False,
             cmap=cmap)

    # Test with more than 2 classes
    epochs = Epochs(raw, events, tmin=tmin, tmax=tmax, picks=picks,
                    event_id=dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4),
                    baseline=(None, 0), proj=False, preload=True)
    epochs_data = epochs.get_data()
    n_channels = epochs_data.shape[1]
    for cov_est in ['concat', 'epoch']:
        csp = CSP(n_components=n_components, cov_est=cov_est,
                  norm_trace=False)
        csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data)
        assert_equal(len(csp._classes), 4)
        assert_array_equal(csp.filters_.shape, [n_channels, n_channels])
        assert_array_equal(csp.patterns_.shape, [n_channels, n_channels])

    # Test average power transform
    n_components = 2
    assert csp.transform_into == 'average_power'
    feature_shape = [len(epochs_data), n_components]
    X_trans = dict()
    for log in (None, True, False):
        csp = CSP(n_components=n_components, log=log, norm_trace=False)
        assert csp.log is log
        Xt = csp.fit_transform(epochs_data, epochs.events[:, 2])
        assert_array_equal(Xt.shape, feature_shape)
        X_trans[str(log)] = Xt
    # log=None => log=True
    assert_array_almost_equal(X_trans['None'], X_trans['True'])
    # Different normalization return different transform
    assert np.sum((X_trans['True'] - X_trans['False']) ** 2) > 1.
    # Check wrong inputs
    pytest.raises(ValueError, CSP, transform_into='average_power', log='foo')

    # Test csp space transform
    csp = CSP(transform_into='csp_space', norm_trace=False)
    assert csp.transform_into == 'csp_space'
    for log in ('foo', True, False):
        pytest.raises(ValueError, CSP, transform_into='csp_space', log=log,
                      norm_trace=False)
    n_components = 2
    csp = CSP(n_components=n_components, transform_into='csp_space',
              norm_trace=False)
    Xt = csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data)
    feature_shape = [len(epochs_data), n_components, epochs_data.shape[2]]
    assert_array_equal(Xt.shape, feature_shape)

    # Check mixing matrix on simulated data
    y = np.array([100] * 50 + [1] * 50)
    X, A = simulate_data(y)

    for cov_est in ['concat', 'epoch']:
        # fit csp
        csp = CSP(n_components=1, cov_est=cov_est, norm_trace=False)
        csp.fit(X, y)

        # check the first pattern match the mixing matrix
        # the sign might change, hence the absolute value
        corr = np.abs(np.corrcoef(csp.patterns_[0, :].T, A[:, 0])[0, 1])
        assert corr > 0.99

        # check output
        out = csp.transform(X)
        corr = np.abs(np.corrcoef(out[:, 0], y)[0, 1])
        assert corr > 0.95
def test_csp():
    """Test Common Spatial Patterns algorithm on epochs.

    Sections, in order: constructor argument validation, fitting,
    transform consistency, rejection of malformed data, the plotting
    helpers, agreement between covariance estimation methods, a four-class
    problem, and the ``average_power`` and ``csp_space`` outputs.
    """
    raw = io.read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
    events = read_events(event_name)
    picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
                       eog=False, exclude='bads')
    picks = picks[2:12:3]  # subselect channels -> disable proj!
    raw.add_proj([], remove_existing=True)
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), preload=True, proj=False,
                    add_eeg_ref=False)
    epochs_data = epochs.get_data()
    n_channels = epochs_data.shape[1]
    y = epochs.events[:, -1]

    # Init
    assert_raises(ValueError, CSP, n_components='foo')
    for reg in ['foo', -0.1, 1.1]:
        assert_raises(ValueError, CSP, reg=reg)
    for reg in ['oas', 'ledoit_wolf', 0, 0.5, 1.]:
        CSP(reg=reg)
    for cov_est in ['foo', None]:
        assert_raises(ValueError, CSP, cov_est=cov_est)
    for cov_est in ['concat', 'epoch']:
        CSP(cov_est=cov_est)

    n_components = 3
    csp = CSP(n_components=n_components)

    # Fit
    csp.fit(epochs_data, y)
    assert_equal(len(csp.mean_), n_components)
    assert_equal(len(csp.std_), n_components)

    # Transform
    X = csp.fit_transform(epochs_data, y)
    sources = csp.transform(epochs_data)
    assert_true(sources.shape[1] == n_components)
    assert_true(csp.filters_.shape == (n_channels, n_channels))
    assert_true(csp.patterns_.shape == (n_channels, n_channels))
    assert_array_almost_equal(sources, X)

    # Test data exception
    assert_raises(ValueError, csp.fit, epochs_data,
                  np.zeros_like(epochs.events))
    assert_raises(ValueError, csp.fit, epochs, y)
    assert_raises(ValueError, csp.transform, epochs)

    # Test plots
    epochs.pick_types(meg='mag')
    cmap = ('RdBu', True)
    components = np.arange(n_components)
    for plot in (csp.plot_patterns, csp.plot_filters):
        plot(epochs.info, components=components, res=12, show=False,
             cmap=cmap)

    # Test covariance estimation methods (results should be roughly equal)
    np.random.seed(0)
    csp_epochs = CSP(cov_est="epoch")
    csp_epochs.fit(epochs_data, y)
    for attr in ('filters_', 'patterns_'):
        corr = np.corrcoef(getattr(csp, attr).ravel(),
                           getattr(csp_epochs, attr).ravel())[0, 1]
        # Report the observed correlation when the threshold is missed.
        assert_true(corr >= 0.94, msg='%s < 0.94' % corr)

    # Test with more than 2 classes
    epochs = Epochs(raw, events, tmin=tmin, tmax=tmax, picks=picks,
                    event_id=dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4),
                    baseline=(None, 0), proj=False, preload=True,
                    add_eeg_ref=False)
    epochs_data = epochs.get_data()
    n_channels = epochs_data.shape[1]
    for cov_est in ['concat', 'epoch']:
        csp = CSP(n_components=n_components, cov_est=cov_est)
        csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data)
        assert_equal(len(csp._classes), 4)
        assert_array_equal(csp.filters_.shape, [n_channels, n_channels])
        assert_array_equal(csp.patterns_.shape, [n_channels, n_channels])

    # Test average power transform
    n_components = 2
    assert_true(csp.transform_into == 'average_power')
    feature_shape = [len(epochs_data), n_components]
    X_trans = dict()
    for log in (None, True, False):
        csp = CSP(n_components=n_components, log=log)
        assert_true(csp.log is log)
        Xt = csp.fit_transform(epochs_data, epochs.events[:, 2])
        assert_array_equal(Xt.shape, feature_shape)
        X_trans[str(log)] = Xt
    # log=None => log=True
    assert_array_almost_equal(X_trans['None'], X_trans['True'])
    # Different normalization return different transform
    assert_true(np.sum((X_trans['True'] - X_trans['False']) ** 2) > 1.)
    # Check wrong inputs
    assert_raises(ValueError, CSP, transform_into='average_power', log='foo')

    # Test csp space transform
    csp = CSP(transform_into='csp_space')
    assert_true(csp.transform_into == 'csp_space')
    for log in ('foo', True, False):
        assert_raises(ValueError, CSP, transform_into='csp_space', log=log)
    n_components = 2
    csp = CSP(n_components=n_components, transform_into='csp_space')
    Xt = csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data)
    feature_shape = [len(epochs_data), n_components, epochs_data.shape[2]]
    assert_array_equal(Xt.shape, feature_shape)