def test_calculate_cca_swapaxes(self):
    """calculate_cca must give the same results with a nonstandard timeaxis."""
    # Feed CCA transposed inputs and tell it that time lives on axis 1.
    swapped_res = calculate_cca(swapaxes(self.dat_x, 0, 1),
                                swapaxes(self.dat_y, 0, 1),
                                timeaxis=1)
    reference_res = calculate_cca(self.dat_x, self.dat_y)
    # Compare the first three entries of both returned results pairwise.
    for idx in range(3):
        np.testing.assert_array_equal(swapped_res[idx], reference_res[idx])
def test_segment_dat_swapaxes(self):
    """segment_dat must work when time is not on the default axis."""
    transposed = swapaxes(self.dat, 0, 1)
    segmented = segment_dat(transposed, self.mrk_def, [-400, 400],
                            timeaxis=-1)
    # segment_dat adds a new dimension (presumably in front, which is why
    # axes 1 and 2 -- not 0 and 1 -- are swapped back here).
    segmented = swapaxes(segmented, 1, 2)
    expected = segment_dat(self.dat, self.mrk_def, [-400, 400])
    self.assertEqual(segmented, expected)
def test_correct_for_baseline_swapaxes(self):
    """correct_for_baseline must work with a nonstandard timeaxis."""
    # Swap axes 0/1 so time sits on axis 0, correct there, then swap back.
    corrected = swapaxes(
        correct_for_baseline(swapaxes(self.dat, 0, 1), [-1000, 0],
                             timeaxis=0),
        0, 1)
    self.assertEqual(corrected, correct_for_baseline(self.dat, [-1000, 0]))
def test_apply_spatial_filter_swapaxes(self):
    """apply_spatial_filter must work with a nonstandard chanaxis."""
    # Exchange axis 1 with the last axis and tell the filter that the
    # channels now live on axis 1; undo the exchange afterwards.
    moved = swapaxes(self.epo, 1, -1)
    filtered = swapaxes(apply_spatial_filter(moved, self.w, chanaxis=1),
                        1, -1)
    expected = apply_spatial_filter(self.epo, self.w)
    self.assertEqual(filtered, expected)
def test_append_cnt_swapaxes(self):
    """append_cnt must work with a nonstandard timeaxis."""
    # Two independently swapped copies, exactly as in the reference call
    # which passes self.dat twice.
    first, second = swapaxes(self.dat, 0, 1), swapaxes(self.dat, 0, 1)
    joined = swapaxes(append_cnt(first, second, timeaxis=1), 0, 1)
    self.assertEqual(joined, append_cnt(self.dat, self.dat))
def test_append_swapaxes(self):
    """append must work along a nonstandard axis."""
    first, second = swapaxes(self.dat, 0, 2), swapaxes(self.dat, 0, 2)
    # Append along axis 2 (where axis 0 was moved to), then move it back.
    joined = swapaxes(append(first, second, axis=2), 0, 2)
    self.assertEqual(joined, append(self.dat, self.dat))
def test_filtfilt_swapaxes(self):
    """filtfilt must work with a nonstandard timeaxis."""
    nyquist = self.dat.fs / 2
    # 4th-order Butterworth band-pass with edges 6 and 8 (same units as
    # fs), normalised by the Nyquist frequency as butter() expects.
    b, a = butter(4, [6 / nyquist, 8 / nyquist], btype='band')
    filtered = filtfilt(swapaxes(self.dat, 0, 1), b, a, timeaxis=1)
    filtered = swapaxes(filtered, 0, 1)
    self.assertEqual(filtered, filtfilt(self.dat, b, a))
def test_create_feature_vectors_swapaxes(self):
    """create_feature_vectors must work with a nonstandard classaxis."""
    # No swap-back here: create_feature_vectors already moves the class
    # axis to position 0 internally, which is why the results are
    # compared directly.
    from_swapped = create_feature_vectors(swapaxes(self.dat, 0, 2),
                                          classaxis=2)
    self.assertEqual(from_swapped, create_feature_vectors(self.dat))
def test_calculate_signed_r_square_swapaxes(self):
    """calculate_signed_r_square must work with a nonstandard classaxis."""
    dat = calculate_signed_r_square(swapaxes(self.dat, 0, 2), classaxis=2)
    # The class axis disappears during calculate_signed_r_square, so the
    # two remaining axes come out in reverse order relative to the
    # reference call and have to be swapped back.
    dat = dat.swapaxes(0, 1)
    dat2 = calculate_signed_r_square(self.dat)
    # Exact equality is too strict: since numpy 1.9 the two computation
    # paths can differ in the last digits (around 1e-15), so compare for
    # almost-equality instead.
    np.testing.assert_array_almost_equal(dat, dat2)
def test_swapaxes(self):
    """Swapping two axes must exchange their data, axes, names and units."""
    new = swapaxes(self.dat, 0, 1)
    # Axis values: assert_array_equal reports mismatches (including shape
    # mismatches) explicitly instead of a bare "False is not true".
    np.testing.assert_array_equal(new.axes[0], self.dat.axes[1])
    np.testing.assert_array_equal(new.axes[1], self.dat.axes[0])
    self.assertEqual(new.names[0], self.dat.names[1])
    self.assertEqual(new.names[1], self.dat.names[0])
    self.assertEqual(new.units[0], self.dat.units[1])
    self.assertEqual(new.units[1], self.dat.units[0])
    # Expected shape is the original with entries 0 and 1 exchanged;
    # reversing the whole shape would only be correct for 2-d data.
    expected_shape = list(self.dat.data.shape)
    expected_shape[0], expected_shape[1] = expected_shape[1], expected_shape[0]
    self.assertEqual(new.data.shape, tuple(expected_shape))
    np.testing.assert_array_equal(new.data.swapaxes(0, 1), self.dat.data)
def test_calculate_signed_r_square_swapaxes(self):
    """calculate_signed_r_square must work with a nonstandard classaxis."""
    swapped_input = swapaxes(self.dat, 0, 2)
    result = calculate_signed_r_square(swapped_input, classaxis=2)
    # Removing the class axis leaves the other two axes in reverse order
    # compared to the reference call, so swap them back.
    result = result.swapaxes(0, 1)
    expected = calculate_signed_r_square(self.dat)
    # Exact equality held with numpy 1.8, but from 1.9 on the results
    # differ slightly (around 1e-15), so only almost-equality is checked.
    np.testing.assert_array_almost_equal(result, expected)
def test_rereference_swapaxes(self):
    """rereference must work with a nonstandard chanaxis."""
    # Swap axes 1/2 so channels sit on axis 1, rereference, swap back.
    rerefed = swapaxes(
        rereference(swapaxes(self.epo, 1, 2), 'chan0', chanaxis=1),
        1, 2)
    self.assertEqual(rerefed, rereference(self.epo, 'chan0'))
def test_spectrum_swapaxes(self):
    """spectrum_welch must work with a nonstandard timeaxis."""
    transposed = swapaxes(self.dat, 0, 1)
    spec = swapaxes(spectrum_welch(transposed, timeaxis=1), 0, 1)
    reference = spectrum_welch(self.dat)
    self.assertEqual(spec, reference)
def test_select_ival_swapaxes(self):
    """select_ival must work with a nonstandard timeaxis."""
    selected = select_ival(swapaxes(self.dat, 0, 1), [-500, 0], timeaxis=0)
    # bring the axes back into the default order before comparing
    selected = swapaxes(selected, 0, 1)
    self.assertEqual(selected, select_ival(self.dat, [-500, 0]))
def test_clear_markes_swapaxes(self):
    """clear_markers must work with a nonstandard timeaxis."""
    cleared = swapaxes(clear_markers(swapaxes(self.dat, 1, 2), timeaxis=2),
                       1, 2)
    expected = clear_markers(self.dat)
    self.assertEqual(cleared, expected)
def test_swapaxes_twice(self):
    """Swapping the same pair of axes twice must restore the original."""
    round_trip = swapaxes(swapaxes(self.dat, 0, 1), 0, 1)
    self.assertEqual(round_trip, self.dat)
def test_jumping_means_swapaxes(self):
    """jumping_means must work with a nonstandard timeaxis."""
    moved = swapaxes(self.dat, 1, 2)
    means = swapaxes(jumping_means(moved, [[0, 1000]], timeaxis=2), 1, 2)
    self.assertEqual(means, jumping_means(self.dat, [[0, 1000]]))
def test_sort_channels_swapaxis(self):
    """sort_channels must work with a nonstandard chanaxis."""
    # Channel axis is passed positionally as 1 after swapping 1 and -1.
    result = swapaxes(sort_channels(swapaxes(self.dat, 1, -1), 1), 1, -1)
    expected = sort_channels(self.dat)
    self.assertEqual(result, expected)
def test_band_pass_swapaxes(self):
    """band_pass must work with a nonstandard timeaxis."""
    transposed = swapaxes(self.dat, 0, 1)
    filtered = band_pass(transposed, 6, 8, timeaxis=1)
    self.assertEqual(swapaxes(filtered, 0, 1), band_pass(self.dat, 6, 8))
def test_subsample_swapaxes(self):
    """subsample must work with a nonstandard timeaxis."""
    downsampled = swapaxes(subsample(swapaxes(self.dat, 0, 1), 10,
                                     timeaxis=1),
                           0, 1)
    self.assertEqual(downsampled, subsample(self.dat, 10))
def test_variance_swapaxes(self):
    """variance must work with a nonstandard timeaxis."""
    var_swapped = variance(swapaxes(self.dat, 1, 2), timeaxis=2)
    # No swap-back: variance removes the time axis, so the result is
    # compared directly against the reference.
    var_default = variance(self.dat)
    self.assertEqual(var_swapped, var_default)
def test_select_channels_swapaxis(self):
    """select_channels must work with a non-default chanaxis."""
    picked = select_channels(swapaxes(self.dat, 0, 1), ["ca.*"], chanaxis=0)
    picked = swapaxes(picked, 0, 1)
    reference = select_channels(self.dat, ["ca.*"])
    self.assertEqual(picked, reference)
def test_select_classes_swapaxes(self):
    """select_classes must work with a nonstandard classaxis."""
    kept = swapaxes(select_classes(swapaxes(self.dat, 0, 2), [0],
                                   classaxis=2),
                    0, 2)
    self.assertEqual(kept, select_classes(self.dat, [0]))
def test_remove_channels_swapaxis(self):
    """remove_channels must work with a non-default chanaxis."""
    transposed = swapaxes(self.dat, 0, 1)
    pruned = swapaxes(remove_channels(transposed, ['ca.*'], chanaxis=0),
                      0, 1)
    reference = remove_channels(self.dat, ['ca.*'])
    self.assertEqual(pruned, reference)
def test_swapaxes_copy(self):
    """swapaxes must leave its input object untouched."""
    snapshot = self.dat.copy()
    swapaxes(self.dat, 0, 1)  # result intentionally discarded
    self.assertEqual(snapshot, self.dat)
def test_remove_classes_swapaxes(self):
    """remove_classes must work with a nonstandard classaxis."""
    # Move the class axis from 0 to 2, remove classes 0 and 2 there,
    # then restore the default axis order before comparing.
    dat = remove_classes(swapaxes(self.dat, 0, 2), [0, 2], classaxis=2)
    dat = swapaxes(dat, 0, 2)
    dat2 = remove_classes(self.dat, [0, 2])
    self.assertEqual(dat, dat2)
def test_remove_epochs_swapaxes(self):
    """remove_epochs must work with a nonstandard classaxis."""
    swapped = swapaxes(self.dat, 0, 2)
    remaining = swapaxes(remove_epochs(swapped, [0, 1], classaxis=2), 0, 2)
    self.assertEqual(remaining, remove_epochs(self.dat, [0, 1]))