def test_diff_between_canonical_variables(self):
    """Test if the scaled canonical variables are almost the same.

    Both data sets are projected with the filters returned by
    ``calculate_cca``. Each projection is mean-centred and divided by
    its element of largest magnitude; the summed absolute difference
    of the two scaled projections, divided by ``self.SAMPLES``, must
    stay below 0.1.
    """
    # The first return value is not needed by this test.
    _, w_x, w_y = calculate_cca(self.dat_x, self.dat_y)
    cv_x = apply_spatial_filter(self.dat_x, w_x)
    cv_y = apply_spatial_filter(self.dat_y, w_y)

    def scale(x):
        # Centre the data, then divide by the element picked by the
        # argmax of the absolute values, which removes differences in
        # scale and sign between the two variables.
        tmp = x.data - x.data.mean()
        return tmp / tmp[np.argmax(np.abs(tmp))]

    diff = scale(cv_x) - scale(cv_y)
    diff = np.sum(np.abs(diff)) / self.SAMPLES
    # assertLess reports the offending value on failure, whereas
    # assertTrue would only say "False is not true".
    self.assertLess(diff, 0.1)
def test_diagonal(self):
    """The whitened data should have all 1s on the covariance matrix.

    The diagonal of the covariance matrix of the whitened data is
    compared against a vector of ones of the same length.
    """
    whitener = calculate_whitening_matrix(self.cnt)
    whitened = apply_spatial_filter(self.cnt, whitener)
    variances = np.diag(np.cov(whitened.data.T))
    expected = np.ones(len(variances))
    np.testing.assert_array_almost_equal(variances, expected)
def test_zeros(self):
    """The whitened data should have all 0s on the non-diagonals of the
    covariance matrix.

    Every off-diagonal entry is checked individually rather than via
    the sum of all entries, because entries of opposite sign would
    cancel in a sum and let a non-zero covariance go unnoticed.
    """
    a = calculate_whitening_matrix(self.cnt)
    dat2 = apply_spatial_filter(self.cnt, a)
    cov = np.cov(dat2.data.T)
    # subtract the diagonal so that only the off-diagonal entries remain
    cov -= np.diag(np.diag(cov))
    np.testing.assert_array_almost_equal(cov, np.zeros_like(cov))
def test_spatial_filter_cnt(self):
    """Spatial filtering should work with cnt.

    Each of the first three output channels is compared against the
    combination of input channels it is expected to equal.
    """
    filtered = apply_spatial_filter(self.cnt, self.w).data
    raw = self.cnt.data
    # Expected output channels 0, 1 and 2, in that order.
    expected_channels = (
        raw[:, 2] - raw[:, 1],
        0.5 * np.sum(raw, axis=-1),
        raw[:, 0],
    )
    for chan, expected in enumerate(expected_channels):
        np.testing.assert_array_equal(filtered[:, chan], expected)
def test_spatial_filter_cnt(self):
    """Spatial filtering should work with cnt.

    Checks the first three channels of the filtered data against the
    expected combinations of the unfiltered channels.
    """
    result = apply_spatial_filter(self.cnt, self.w)
    source = self.cnt.data

    # chan 0: difference of input channels 2 and 1
    chan0 = source[:, 2] - source[:, 1]
    np.testing.assert_array_equal(result.data[:, 0], chan0)

    # chan 1: half of the sum over all input channels
    chan1 = 0.5 * np.sum(source, axis=-1)
    np.testing.assert_array_equal(result.data[:, 1], chan1)

    # chan 2: input channel 0 unchanged
    chan2 = source[:, 0]
    np.testing.assert_array_equal(result.data[:, 2], chan2)
def apply_csp(data, return_as='filtered', time_axis=1,
              columns_to_apply=(0, 1, -2, -1)):
    """Calculate CSP on ``data`` and apply the resulting spatial filter.

    Parameters
    ----------
    data
        Object handed to ``calculate_csp`` and ``apply_spatial_filter``.
    return_as : str, optional
        ``'patterns'`` returns the first two results of
        ``calculate_csp`` without filtering anything. ``'logvar'``
        returns the filtered data with ``.data`` replaced by the log of
        its variance along ``time_axis``. Any other value returns the
        filtered data as is.
    time_axis : int, optional
        Axis along which the variance is taken for ``'logvar'``.
    columns_to_apply : sequence of int, optional
        Column indices of the CSP filter matrix that are used for
        filtering.

    Returns
    -------
    Either the tuple ``(w, a)`` from ``calculate_csp`` or the filtered
    data object, depending on ``return_as``.
    """
    filters, patterns, _ = calculate_csp(data)
    if return_as == 'patterns':
        return filters, patterns

    selected = filters[:, list(columns_to_apply)]
    filtered = apply_spatial_filter(data, selected)
    if return_as != 'logvar':
        return filtered

    filtered.data = np.log(np.var(filtered.data, axis=time_axis))
    # NOTE(review): index 1 is hard-coded here although time_axis is a
    # parameter, and the new axis has the length of the first data
    # dimension -- confirm this is intended for time_axis != 1.
    filtered.axes[1] = np.arange(filtered.data.shape[0])
    return filtered
def test_prefix_and_postfix(self):
    """Prefix and postfix are mutually exclusive.

    Supplying both at the same time must raise a ValueError.
    """
    self.assertRaises(
        ValueError,
        apply_spatial_filter,
        self.cnt, self.w, prefix='foo', postfix='bar')
def test_prefix(self):
    """Apply prefix correctly.

    The last axis of the result must consist of the prefix followed by
    a running index, one entry per channel.
    """
    filtered = apply_spatial_filter(self.cnt, self.w, prefix='foo')
    expected_names = []
    for idx in range(CHANS):
        expected_names.append('foo' + str(idx))
    self.assertEqual(filtered.axes[-1], expected_names)
def test_apply_spatial_filter_copy(self):
    """apply_spatial_filter must not modify arguments.

    ``self.cnt`` is compared against a copy taken before the call.
    """
    snapshot = self.cnt.copy()
    # Only side effects on the input matter; the result is discarded.
    apply_spatial_filter(self.cnt, self.w)
    self.assertEqual(self.cnt, snapshot)
def test_postfix_w_wrong_type(self):
    """Raise TypeError if postfix is neither None nor a str."""
    not_a_string = 1
    with self.assertRaises(TypeError):
        apply_spatial_filter(self.cnt, self.w, postfix=not_a_string)
def test_postfix(self):
    """Apply postfix correctly.

    Every entry of the last axis must be the original entry with the
    postfix appended.
    """
    filtered = apply_spatial_filter(self.cnt, self.w, postfix='foo')
    expected_names = [name + 'foo' for name in self.cnt.axes[-1]]
    self.assertEqual(filtered.axes[-1], expected_names)
def test_spatial_filter_epo(self):
    """Spatial filtering should work with epo.

    Each of the first ``EPOS`` entries of the filtered epo data must
    equal the filtered cnt data.
    """
    reference = apply_spatial_filter(self.cnt, self.w)
    filtered_epo = apply_spatial_filter(self.epo, self.w)
    for epoch_idx in range(EPOS):
        epoch = filtered_epo.data[epoch_idx, ...]
        np.testing.assert_array_equal(epoch, reference.data)
def test_shape(self):
    """The spatial filtered data should keep its shape.

    Checked for both the cnt and the epo fixture.
    """
    # Filter both inputs first, then compare shapes in the same order.
    pairs = [(dat, apply_spatial_filter(dat, self.w))
             for dat in (self.cnt, self.epo)]
    for unfiltered, filtered in pairs:
        self.assertEqual(unfiltered.data.shape, filtered.data.shape)
def test_postfix(self):
    """Apply postfix correctly.

    The last axis of the result is compared against the original last
    axis with the postfix appended to each entry.
    """
    result = apply_spatial_filter(self.cnt, self.w, postfix='foo')
    wanted = []
    for original_name in self.cnt.axes[-1]:
        wanted.append(original_name + 'foo')
    self.assertEqual(result.axes[-1], wanted)
def test_prefix(self):
    """Apply prefix correctly.

    The last axis of the result must be the prefix followed by the
    indices ``0 .. CHANS - 1``.
    """
    result = apply_spatial_filter(self.cnt, self.w, prefix='foo')
    wanted = ['foo' + str(number) for number in range(CHANS)]
    self.assertEqual(result.axes[-1], wanted)
def test_apply_spatial_filter_swapaxes(self):
    """apply_spatial_filter must work with nonstandard chanaxis.

    Axes 1 and -1 of the epo are swapped, the filter is applied with
    ``chanaxis=1`` and the axes are swapped back; the outcome must
    equal filtering the untouched epo with the default channel axis.
    """
    swapped = swapaxes(self.epo, 1, -1)
    filtered_swapped = apply_spatial_filter(swapped, self.w, chanaxis=1)
    restored = swapaxes(filtered_swapped, 1, -1)
    reference = apply_spatial_filter(self.epo, self.w)
    self.assertEqual(restored, reference)