def test_ssd_epoched_data():
    """Test Spatio-Spectral Decomposition (SSD) on epoched data.

    Compare the outputs when raw data is used.
    """
    n_trials, n_samples = 100, 500
    X, _, _ = simulate_data(n_trials=n_trials, n_channels=20,
                            n_samples=n_samples)
    sf = 250
    n_channels = X.shape[0]
    info = create_info(ch_names=n_channels, sfreq=sf, ch_types='eeg')
    n_components_true = 5

    # Build epochs as consecutive, non-overlapping windows over the
    # continuous (n_channels, n_times) data. A direct reshape to
    # (n_trials, n_channels, n_samples) would interleave channels and time
    # (each "epoch" would then come from a single original channel), so split
    # the time axis first and then move the epoch axis to the front.
    # Each epoch has n_samples samples (2 seconds at sf=250).
    X_e = X.reshape(n_channels, n_trials, n_samples).transpose(1, 0, 2)

    # Fit
    filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1],
                              l_trans_bandwidth=4, h_trans_bandwidth=4)
    filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1],
                             l_trans_bandwidth=4, h_trans_bandwidth=4)

    # ssd on epochs
    ssd_e = SSD(info, filt_params_signal, filt_params_noise)
    ssd_e.fit(X_e)
    # ssd on raw
    ssd = SSD(info, filt_params_signal, filt_params_noise)
    ssd.fit(X)

    # Check that the first n_components_true components are ordered the same
    # way for both the epoched and the continuous fit
    _, sorter_spec_e = ssd_e.get_spectral_ratio(ssd_e.transform(X_e))
    _, sorter_spec = ssd.get_spectral_ratio(ssd.transform(X))
    assert_array_equal(sorter_spec_e[:n_components_true],
                       sorter_spec[:n_components_true])
def test_sorting():
    """Test that the spectral-ratio sorting is learned during fit."""
    n_trials, n_samples = 100, 500
    X, _, _ = simulate_data(n_trials=n_trials, n_channels=20,
                            n_samples=n_samples)
    # Split the continuous (n_channels, n_times) data into consecutive epochs
    # of n_samples each (2 seconds at sf=250), giving
    # (n_trials, n_channels, n_samples). Splitting the time axis before
    # transposing keeps channels and time from being interleaved.
    X = X.reshape(X.shape[0], n_trials, n_samples).transpose(1, 0, 2)
    # split data along the epoch axis
    Xtr, Xte = X[:80], X[80:]
    sf = 250
    n_channels = Xtr.shape[1]
    info = create_info(ch_names=n_channels, sfreq=sf, ch_types='eeg')

    filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1],
                              l_trans_bandwidth=4, h_trans_bandwidth=4)
    filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1],
                             l_trans_bandwidth=4, h_trans_bandwidth=4)

    # check sort_by_spectral_ratio set to False: the ordering computed on
    # held-out data is expected to differ from the training ordering
    ssd = SSD(info, filt_params_signal, filt_params_noise,
              n_components=None, sort_by_spectral_ratio=False)
    ssd.fit(Xtr)
    _, sorter_tr = ssd.get_spectral_ratio(ssd.transform(Xtr))
    _, sorter_te = ssd.get_spectral_ratio(ssd.transform(Xte))
    assert np.any(sorter_tr != sorter_te)

    # check sort_by_spectral_ratio set to True
    ssd = SSD(info, filt_params_signal, filt_params_noise,
              n_components=None, sort_by_spectral_ratio=True)
    ssd.fit(Xtr)

    # check that the sorter stored during fit matches the one recomputed from
    # an unsorted fit on the same training data
    sorter_in = ssd.sorter_spec
    ssd = SSD(info, filt_params_signal, filt_params_noise,
              n_components=None, sort_by_spectral_ratio=False)
    ssd.fit(Xtr)
    _, sorter_out = ssd.get_spectral_ratio(ssd.transform(Xtr))

    assert np.all(sorter_in == sorter_out)
def test_ssd():
    """Test Spatio-Spectral Decomposition (SSD) on raw data."""
    X, A, S = simulate_data()
    sf = 250
    n_channels = X.shape[0]
    info = create_info(ch_names=n_channels, sfreq=sf, ch_types='eeg')
    n_components_true = 5

    # Init
    filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1],
                              l_trans_bandwidth=1, h_trans_bandwidth=1)
    filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1],
                             l_trans_bandwidth=1, h_trans_bandwidth=1)
    ssd = SSD(info, filt_params_signal, filt_params_noise)

    # Non-numeric frequency must be rejected
    freq = 'foo'
    filt_params_signal = dict(l_freq=freq, h_freq=freqs_sig[1],
                              l_trans_bandwidth=1, h_trans_bandwidth=1)
    filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1],
                             l_trans_bandwidth=1, h_trans_bandwidth=1)
    with pytest.raises(TypeError, match='must be an instance '):
        ssd = SSD(info, filt_params_signal, filt_params_noise)

    # Wrongly specified noise band (signal l_freq moved to 2 Hz)
    freq = 2
    filt_params_signal = dict(l_freq=freq, h_freq=freqs_sig[1],
                              l_trans_bandwidth=1, h_trans_bandwidth=1)
    filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1],
                             l_trans_bandwidth=1, h_trans_bandwidth=1)
    with pytest.raises(ValueError, match='Wrongly specified '):
        ssd = SSD(info, filt_params_signal, filt_params_noise)

    # Filter params that are not dicts must be rejected
    filt_params_signal = freqs_sig
    filt_params_noise = freqs_noise
    with pytest.raises(ValueError, match='must be defined'):
        ssd = SSD(info, filt_params_signal, filt_params_noise)

    # Data type: fitting on a Raw object instead of an array is rejected
    filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1],
                              l_trans_bandwidth=1, h_trans_bandwidth=1)
    filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1],
                             l_trans_bandwidth=1, h_trans_bandwidth=1)
    ssd = SSD(info, filt_params_signal, filt_params_noise)
    raw = io.RawArray(X, info)
    pytest.raises(TypeError, ssd.fit, raw)

    # More than 1 channel type
    ch_types = np.reshape([['mag'] * 10, ['eeg'] * 10], n_channels)
    info_2 = create_info(ch_names=n_channels, sfreq=sf, ch_types=ch_types)
    with pytest.raises(ValueError, match='At this point SSD'):
        ssd = SSD(info_2, filt_params_signal, filt_params_noise)

    # Number of channels in info must match the data
    info_3 = create_info(ch_names=n_channels + 1, sfreq=sf, ch_types='eeg')
    ssd = SSD(info_3, filt_params_signal, filt_params_noise)
    pytest.raises(ValueError, ssd.fit, X)

    # Fit
    n_components = 10
    ssd = SSD(info, filt_params_signal, filt_params_noise,
              n_components=n_components)

    # Call transform before fit
    pytest.raises(AttributeError, ssd.transform, X)

    # Check outputs
    ssd.fit(X)

    assert ssd.filters_.shape == (n_channels, n_channels)
    assert ssd.patterns_.shape == (n_channels, n_channels)

    # Transform
    X_ssd = ssd.fit_transform(X)
    assert X_ssd.shape[0] == n_components

    # back and forward: keeping all components must reconstruct the data
    ssd = SSD(info, filt_params_signal, filt_params_noise,
              n_components=None, sort_by_spectral_ratio=False)
    ssd.fit(X)
    X_denoised = ssd.inverse_transform(X)
    assert_array_almost_equal(X_denoised, X)

    # Power ratio ordering
    spec_ratio, _ = ssd.get_spectral_ratio(ssd.transform(X))
    # since we know that the number of true components is 5, the relative
    # difference should be low for the first 5 components and then increases
    index_diff = np.argmax(-np.diff(spec_ratio))
    assert index_diff == n_components_true - 1

    # Check detected peaks
    # fit ssd
    n_components = n_components_true
    filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1],
                              l_trans_bandwidth=1, h_trans_bandwidth=1)
    filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1],
                             l_trans_bandwidth=1, h_trans_bandwidth=1)
    ssd = SSD(info, filt_params_signal, filt_params_noise,
              n_components=n_components, sort_by_spectral_ratio=False)
    ssd.fit(X)

    out = ssd.transform(X)
    # Use the same sampling frequency as the info for the PSD estimates
    psd_out, _ = psd_array_welch(out[0], sfreq=sf, n_fft=sf)
    psd_S, _ = psd_array_welch(S[0], sfreq=sf, n_fft=sf)
    corr = np.abs(np.corrcoef((psd_out, psd_S))[0, 1])
    assert corr > 0.95

    # Check pattern estimation
    # Since there is no exact ordering of the recovered patterns, compare
    # every recovered pattern with the first true mixing pattern and keep
    # the best match.
    error = [1 - np.abs(np.corrcoef(ssd.patterns_[ii, :].T, A[:, 0])[0, 1])
             for ii in range(n_channels)]
    min_err = np.min(error)
    assert min_err < 0.3  # threshold taken from SSD original paper