Exemple #1
0
class TestClearMarkers(unittest.TestCase):
    """Exercise ``clear_markers`` against an epoched ``Data`` fixture."""

    def setUp(self):
        """Create a 4 x 30 x 5 ``Data`` object carrying six markers."""
        block = np.ones((10, 5))
        # continuous segment: ten rows each of 1s, 2s and 3s
        segment = np.concatenate([block, block * 2, block * 3], axis=0)
        # four epochs: the segment scaled by 0, 1, 2 and 0
        epochs = np.array([segment * scale for scale in (0, 1, 2, 0)])
        time = np.linspace(-1000, 2000, 30, endpoint=False)
        channels = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        classes = [0, 1, 2, 1]
        # markers that test_clear_markers expects to survive ...
        self.good_markers = [
            [-1000, 'a'],
            [-999, 'b'],
            [0, 'c'],
            [1999.9999999, 'd'],
        ]
        # ... plus two that it expects to be dropped
        all_markers = self.good_markers + [[-1001, 'x'], [2000, 'x']]
        self.dat = Data(
            epochs,
            [classes, time, channels],
            ['class', 'time', 'channel'],
            ['#', 'ms', '#'],
        )
        self.dat.markers = all_markers
        self.dat.fs = 10

    def test_clear_markers(self):
        """Only the markers listed in ``good_markers`` remain."""
        cleared = clear_markers(self.dat)
        self.assertEqual(cleared.markers, self.good_markers)

    def test_clear_emtpy_markers(self):
        """An empty marker list passes through unchanged."""
        source = self.dat.copy()
        source.markers = []
        self.assertEqual(source, clear_markers(source))

    def test_clear_nonexisting_markers(self):
        """An object without a ``markers`` attribute passes through unchanged."""
        source = self.dat.copy()
        delattr(source, 'markers')
        self.assertEqual(source, clear_markers(source))

    def test_clear_markers_w_empty_data(self):
        """With empty data no marker is kept."""
        source = self.dat.copy()
        source.data = np.array([])
        cleared = clear_markers(source)
        self.assertEqual(cleared.markers, [])

    def test_clear_markes_swapaxes(self):
        """The time axis may sit at a non-default position."""
        swapped = swapaxes(self.dat, 1, 2)
        via_swapped = swapaxes(clear_markers(swapped, timeaxis=2), 1, 2)
        direct = clear_markers(self.dat)
        self.assertEqual(via_swapped, direct)

    def test_clear_markers_copy(self):
        """The argument itself is left untouched."""
        snapshot = self.dat.copy()
        clear_markers(self.dat)
        self.assertEqual(self.dat, snapshot)
class TestClearMarkers(unittest.TestCase):
    """Unit tests for ``clear_markers``."""

    def setUp(self):
        """Prepare epoched data (4 x 30 x 5) with surviving and doomed markers."""
        base = np.ones((10, 5))
        # stack 1s, 2s and 3s along the first axis
        stacked = np.concatenate((base, base * 2, base * 3), axis=0)
        # epochs are the stack multiplied by 0, 1, 2 and 0
        data = np.array([stacked * 0, stacked * 1, stacked * 2, stacked * 0])
        axes = [
            [0, 1, 2, 1],
            np.linspace(-1000, 2000, 30, endpoint=False),
            ['ca1', 'ca2', 'cb1', 'cb2', 'cc1'],
        ]
        # the markers test_clear_markers expects in the result
        self.good_markers = [[-1000, 'a'], [-999, 'b'], [0, 'c'],
                             [1999.9999999, 'd']]
        # good markers followed by the ones expected to be removed
        markers = list(self.good_markers)
        markers += [[-1001, 'x'], [2000, 'x']]
        self.dat = Data(data, axes, ['class', 'time', 'channel'],
                        ['#', 'ms', '#'])
        self.dat.markers = markers
        self.dat.fs = 10

    def test_clear_markers(self):
        """The result carries exactly the good markers."""
        result = clear_markers(self.dat)
        self.assertEqual(result.markers, self.good_markers)

    def test_clear_emtpy_markers(self):
        """Clearing an empty marker list has no effect."""
        before = self.dat.copy()
        before.markers = []
        after = clear_markers(before)
        self.assertEqual(before, after)

    def test_clear_nonexisting_markers(self):
        """Clearing an object that has no markers attribute has no effect."""
        before = self.dat.copy()
        del before.markers
        after = clear_markers(before)
        self.assertEqual(before, after)

    def test_clear_markers_w_empty_data(self):
        """Clearing an object with empty data removes all markers."""
        emptied = self.dat.copy()
        emptied.data = np.array([])
        self.assertEqual(clear_markers(emptied).markers, [])

    def test_clear_markes_swapaxes(self):
        """clear_markers must work with a nonstandard timeaxis."""
        result = clear_markers(swapaxes(self.dat, 1, 2), timeaxis=2)
        result = swapaxes(result, 1, 2)
        reference = clear_markers(self.dat)
        self.assertEqual(result, reference)

    def test_clear_markers_copy(self):
        """clear_markers must not modify its argument."""
        untouched = self.dat.copy()
        clear_markers(self.dat)
        self.assertEqual(self.dat, untouched)
Exemple #3
0
 def test_copy(self):
     """A copy equals its source but shares no top-level attribute object."""
     original = Data(self.data, self.axes, self.names, self.units)
     clone = original.copy()
     self.assertEqual(original, clone)
     # a recursive identity check is impractical, so only the
     # first level of attributes is compared by identity
     for attr in vars(original):
         source_id = id(getattr(original, attr))
         clone_id = id(getattr(clone, attr))
         self.assertNotEqual(source_id, clone_id)
     # keyword arguments to copy() show up as attributes of the copy
     clone = original.copy(foo='bar')
     self.assertEqual(clone.foo, 'bar')
Exemple #4
0
 def test_copy(self):
     """copy() returns an equal object built from distinct attribute objects."""
     src = Data(self.data, self.axes, self.names, self.units)
     dup = src.copy()
     self.assertEqual(src, dup)
     # checking every nested reference is not feasible; compare the
     # identities of the first-level attributes only
     for name in src.__dict__:
         ids = (id(getattr(src, name)), id(getattr(dup, name)))
         self.assertNotEqual(ids[0], ids[1])
     # extra keyword arguments are set on the returned copy
     dup = src.copy(foo='bar')
     self.assertEqual(dup.foo, 'bar')
Exemple #5
0
class TestAppendEpo(unittest.TestCase):
    """Tests for ``append_epo`` on an epoched ``Data`` fixture."""

    def setUp(self):
        """Create a 4 x 30 x 5 epoched fixture with three class names."""
        ones = np.ones((10, 5))
        # continuous segment: ten rows each of 1s, 2s and 3s
        cnt = np.append(ones, ones*2, axis=0)
        cnt = np.append(cnt, ones*3, axis=0)
        channels = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        time = np.linspace(0, 3000, 30, endpoint=False)
        classes = [0, 1, 2, 1]
        # four epochs: the segment scaled by 0, 1, 2 and 0
        data = np.array([cnt * 0, cnt * 1, cnt * 2, cnt * 0])
        self.dat = Data(data, [classes, time, channels], ['class', 'time', 'channel'], ['#', 'ms', '#'])
        self.dat.class_names = ['zero', 'one', 'two']

    def test_append_epo(self):
        """Appending doubles the class axis and concatenates data and labels."""
        dat = append_epo(self.dat, self.dat)
        self.assertEqual(dat.data.shape[0], 2*self.dat.data.shape[0])
        self.assertEqual(len(dat.axes[0]), 2*len(self.dat.axes[0]))
        np.testing.assert_array_equal(dat.data, np.concatenate([self.dat.data, self.dat.data], axis=0))
        np.testing.assert_array_equal(dat.axes[0], np.concatenate([self.dat.axes[0], self.dat.axes[0]]))
        self.assertEqual(dat.class_names, self.dat.class_names)

    def test_append_epo_with_extra(self):
        """append_epo with extra must work with lists and ndarrays."""
        self.dat.a = list(range(10))
        self.dat.b = np.arange(10)
        dat = append_epo(self.dat, self.dat, extra=['a', 'b'])
        self.assertEqual(dat.a, list(range(10)) + list(range(10)))
        np.testing.assert_array_equal(dat.b, np.concatenate([np.arange(10), np.arange(10)]))

    def test_append_epo_with_different_class_names(self):
        """append_epo must raise a ValueError if class_names differ."""
        a = self.dat.copy()
        a.class_names = a.class_names[:-1]
        # Each argument order needs its own context manager: once the
        # first call raises, nothing else in the same block is executed.
        with self.assertRaises(ValueError):
            append_epo(a, self.dat)
        with self.assertRaises(ValueError):
            append_epo(self.dat, a)

    def test_append_epo_swapaxes(self):
        """append_epo must work with a nonstandard classaxis."""
        dat = append_epo(swapaxes(self.dat, 0, 2), swapaxes(self.dat, 0, 2), classaxis=2)
        dat = swapaxes(dat, 0, 2)
        dat2 = append_epo(self.dat, self.dat)
        self.assertEqual(dat, dat2)

    def test_append_epo_copy(self):
        """append_epo must not modify its arguments."""
        cpy = self.dat.copy()
        append_epo(self.dat, self.dat)
        self.assertEqual(self.dat, cpy)
class TestCorrectForBaseline(unittest.TestCase):
    """Checks for ``correct_for_baseline`` on epoched and continuous data."""

    def setUp(self):
        """Create three constant epochs (1s, -1s and 0s), each 10 x 5."""
        unit = np.ones((10, 5))
        epochs = np.array([unit, unit * -1, unit * 0])
        time = np.linspace(-1000, 0, 10, endpoint=False)
        channels = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        self.dat = Data(
            epochs,
            [[0, 0, 0], time, channels],
            ['class', 'time', 'channels'],
            ['#', 'ms', '#'],
        )

    def test_correct_for_baseline_epo(self):
        """Constant epochs become all-zero after baseline correction."""
        expected = np.zeros((3, 10, 5))
        # baseline taken from a sub-interval
        corrected = correct_for_baseline(self.dat, [-500, 0])
        np.testing.assert_array_equal(expected, corrected.data)
        # baseline spanning the first to the last time stamp
        time = corrected.axes[-2]
        corrected = correct_for_baseline(self.dat, [time[0], time[-1]])
        np.testing.assert_array_equal(expected, corrected.data)

    def test_correct_for_baseline_cnt(self):
        """Baseline correction also works on two-dimensional data."""
        flat = self.dat.data.reshape(30, 5)
        cnt_axes = [np.linspace(-1000, 2000, 30, endpoint=False), self.dat.axes[-1]]
        cnt = self.dat.copy(
            data=flat,
            axes=cnt_axes,
            names=self.dat.names[1:],
            units=self.dat.units[1:],
        )
        corrected = correct_for_baseline(cnt, [-1000, 0])
        np.testing.assert_array_equal(corrected.data, cnt.data - 1)

    def test_ival_checks(self):
        """Malformed intervals trigger an AssertionError."""
        time = self.dat.axes[-2]
        malformed = (
            [0, -1],
            [time[0] - 1, 0],
            [0, time[1] + 1],
        )
        for ival in malformed:
            with self.assertRaises(AssertionError):
                correct_for_baseline(self.dat, ival)

    def test_correct_for_baseline_copy(self):
        """The dat argument is not modified."""
        snapshot = self.dat.copy()
        correct_for_baseline(self.dat, [-1000, 0])
        self.assertEqual(snapshot, self.dat)

    def test_correct_for_baseline_swapaxes(self):
        """The time axis may sit at a non-default position."""
        swapped = swapaxes(self.dat, 0, 1)
        via_swapped = swapaxes(
            correct_for_baseline(swapped, [-1000, 0], timeaxis=0), 0, 1)
        direct = correct_for_baseline(self.dat, [-1000, 0])
        self.assertEqual(via_swapped, direct)
class TestSwapaxes(unittest.TestCase):
    """Tests for ``swapaxes`` on a two-dimensional ``Data`` object."""

    def setUp(self):
        """Create a 400 x 5 fixture with ``fs`` and two markers attached."""
        samples = np.arange(2000).reshape(-1, 5)
        time = np.linspace(0, 4000, 400, endpoint=False)
        channels = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        self.dat = Data(samples, [time, channels], ['time', 'channels'],
                        ['ms', '#'])
        self.dat.fs = 100
        self.dat.markers = [[100, 'foo'], [200, 'bar']]

    def test_swapaxes(self):
        """Axes, names, units and data all trade places."""
        swapped = swapaxes(self.dat, 0, 1)
        # per-axis metadata at position i must come from position j
        for i, j in ((0, 1), (1, 0)):
            self.assertTrue((swapped.axes[i] == self.dat.axes[j]).all())
            self.assertEqual(swapped.names[i], self.dat.names[j])
            self.assertEqual(swapped.units[i], self.dat.units[j])
        self.assertEqual(swapped.data.shape[::-1], self.dat.data.shape)
        np.testing.assert_array_equal(swapped.data.swapaxes(0, 1),
                                      self.dat.data)

    def test_swapaxes_copy(self):
        """The argument is not modified."""
        snapshot = self.dat.copy()
        swapaxes(self.dat, 0, 1)
        self.assertEqual(snapshot, self.dat)

    def test_swapaxes_twice(self):
        """Swapping the same pair of axes twice restores the original."""
        round_trip = swapaxes(swapaxes(self.dat, 0, 1), 0, 1)
        self.assertEqual(round_trip, self.dat)
class TestSelectChannels(unittest.TestCase):
    """Tests for ``select_channels`` on a small 4 x 5 ``Data`` object."""

    def setUp(self):
        """Create a 4 x 5 fixture whose values are 0..19."""
        samples = np.arange(20).reshape(4, 5)
        axes = [np.arange(4), ["ca1", "ca2", "cb1", "cb2", "cc1"]]
        self.dat = Data(samples, axes, ["time", "channels"], ["ms", "#"])

    def test_select_channels(self):
        """Channels matching any of the regexes are kept."""
        before = self.dat.data.copy()
        self.dat = select_channels(self.dat, ["ca.*", "cc1"])
        kept_names = np.array(["ca1", "ca2", "cc1"])
        np.testing.assert_array_equal(self.dat.axes[-1], kept_names)
        # columns 0, 1 and the last one correspond to ca1, ca2 and cc1
        np.testing.assert_array_equal(self.dat.data, before[:, [0, 1, -1]])

    def test_select_channels_inverse(self):
        """With invert=True the matching channels are removed instead."""
        before = self.dat.data.copy()
        self.dat = select_channels(self.dat, ["ca.*", "cc1"], invert=True)
        kept_names = np.array(["cb1", "cb2"])
        np.testing.assert_array_equal(self.dat.axes[-1], kept_names)
        np.testing.assert_array_equal(self.dat.data, before[:, [2, 3]])

    def test_select_channels_copy(self):
        """The original argument is not changed."""
        snapshot = self.dat.copy()
        select_channels(self.dat, ["ca.*"])
        self.assertEqual(snapshot, self.dat)

    def test_select_channels_swapaxis(self):
        """The channel axis may sit at a non-default position."""
        swapped = swapaxes(self.dat, 0, 1)
        via_swapped = swapaxes(
            select_channels(swapped, ["ca.*"], chanaxis=0), 0, 1)
        direct = select_channels(self.dat, ["ca.*"])
        self.assertEqual(via_swapped, direct)
Exemple #9
0
class TestSwapaxes(unittest.TestCase):
    """Unit tests for ``swapaxes``."""

    def setUp(self):
        """Build continuous data (400 samples x 5 channels) with fs and markers."""
        values = np.arange(2000).reshape(-1, 5)
        axes = [
            np.linspace(0, 4000, 400, endpoint=False),
            ['ca1', 'ca2', 'cb1', 'cb2', 'cc1'],
        ]
        self.dat = Data(values, axes, ['time', 'channels'], ['ms', '#'])
        self.dat.fs = 100
        self.dat.markers = [[100, 'foo'], [200, 'bar']]

    def test_swapaxes(self):
        """Swapping axes 0 and 1 exchanges axes, names, units and data."""
        old = self.dat
        new = swapaxes(old, 0, 1)
        # axis values
        self.assertTrue((new.axes[0] == old.axes[1]).all())
        self.assertTrue((new.axes[1] == old.axes[0]).all())
        # axis names
        self.assertEqual(new.names[0], old.names[1])
        self.assertEqual(new.names[1], old.names[0])
        # axis units
        self.assertEqual(new.units[0], old.units[1])
        self.assertEqual(new.units[1], old.units[0])
        # the data itself
        self.assertEqual(new.data.shape[::-1], old.data.shape)
        np.testing.assert_array_equal(new.data.swapaxes(0, 1), old.data)

    def test_swapaxes_copy(self):
        """swapaxes must not modify its argument."""
        reference = self.dat.copy()
        swapaxes(self.dat, 0, 1)
        self.assertEqual(reference, self.dat)

    def test_swapaxes_twice(self):
        """Applying the same swap twice yields the original object."""
        once = swapaxes(self.dat, 0, 1)
        twice = swapaxes(once, 0, 1)
        self.assertEqual(twice, self.dat)
Exemple #10
0
 def test_eq_and_ne(self):
     """Equal objects must never compare both equal and unequal."""
     first = Data(self.data, self.axes, self.names, self.units)
     second = first.copy()
     # with __eq__ defined but __ne__ missing, this expression is True
     contradictory = first == second and first != second
     self.assertFalse(contradictory)
Exemple #11
0
 def test_eq_and_ne(self):
     """__ne__ must be consistent with __eq__."""
     source = Data(self.data, self.axes, self.names, self.units)
     duplicate = source.copy()
     # this is True only when __eq__ is implemented and __ne__ is not
     eq_and_ne = source == duplicate and source != duplicate
     self.assertFalse(eq_and_ne)
Exemple #12
0
 def test_equality(self):
     """A copy is equal; any single difference makes two objects unequal."""
     reference = Data(self.data, self.axes, self.names, self.units)
     # extra attributes also set by other fixtures in this file
     reference.markers = [[123, 'foo'], [234, 'bar']]
     reference.fs = 100
     # an arbitrary extra attribute
     reference.foo = 'bar'

     # an unmodified copy compares equal
     variant = reference.copy()
     self.assertEqual(reference, variant)

     # data replaced by an array of another shape
     variant = reference.copy()
     variant.data = np.arange(20).reshape(5, 4)
     self.assertNotEqual(reference, variant)

     # a single data value changed
     variant = reference.copy()
     variant.data[0, 0] = 42
     self.assertNotEqual(reference, variant)

     # first axis replaced
     variant = reference.copy()
     variant.axes[0] = np.arange(100)
     self.assertNotEqual(reference, variant)

     # first name changed
     variant = reference.copy()
     variant.names[0] = 'baz'
     self.assertNotEqual(reference, variant)

     # first unit changed
     variant = reference.copy()
     variant.units[0] = 'u3'
     self.assertNotEqual(reference, variant)

     # one marker changed
     variant = reference.copy()
     variant.markers[0] = [123, 'baz']
     self.assertNotEqual(reference, variant)

     # sampling frequency changed
     variant = reference.copy()
     variant.fs = 10
     self.assertNotEqual(reference, variant)

     # an attribute named 'baz' that the reference does not have
     variant = reference.copy()
     variant.baz = 'baz'
     self.assertNotEqual(reference, variant)

     # an attribute named 'bar' that the reference does not have
     variant = reference.copy()
     variant.bar = 42
     self.assertNotEqual(reference, variant)
Exemple #13
0
 def test_equality(self):
     """Check equality of copies and inequality after each kind of change."""
     base = Data(self.data, self.axes, self.names, self.units)
     # markers and fs are extra attributes used elsewhere in this file
     base.markers = [[123, 'foo'], [234, 'bar']]
     base.fs = 100
     # foo is an arbitrary extra attribute
     base.foo = 'bar'
     # a fresh copy is equal to its source
     other = base.copy()
     self.assertEqual(base, other)
     # differently shaped data
     other = base.copy()
     other.data = np.arange(20).reshape(5, 4)
     self.assertNotEqual(base, other)
     # same shape, one differing data value
     other = base.copy()
     other.data[0, 0] = 42
     self.assertNotEqual(base, other)
     # differing axes
     other = base.copy()
     other.axes[0] = np.arange(100)
     self.assertNotEqual(base, other)
     # differing names
     other = base.copy()
     other.names[0] = 'baz'
     self.assertNotEqual(base, other)
     # differing units
     other = base.copy()
     other.units[0] = 'u3'
     self.assertNotEqual(base, other)
     # differing markers
     other = base.copy()
     other.markers[0] = [123, 'baz']
     self.assertNotEqual(base, other)
     # differing fs
     other = base.copy()
     other.fs = 10
     self.assertNotEqual(base, other)
     # extra attribute 'baz' present on one side only
     other = base.copy()
     other.baz = 'baz'
     self.assertNotEqual(base, other)
     # extra attribute 'bar' present on one side only
     other = base.copy()
     other.bar = 42
     self.assertNotEqual(base, other)
class TestSelectEpochs(unittest.TestCase):
    """Tests for ``select_epochs`` on a four-epoch ``Data`` fixture."""

    def setUp(self):
        """Create four constant epochs (0s, 1s, 2s, 0s), each 10 x 5."""
        unit = np.ones((10, 5))
        epochs = np.array([unit * 0, unit * 1, unit * 2, unit * 0])
        time = np.linspace(0, 1000, 10, endpoint=False)
        channels = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        self.dat = Data(
            epochs,
            [[0, 1, 2, 1], time, channels],
            ['class', 'time', 'channel'],
            ['#', 'ms', '#'],
        )
        self.dat.class_names = ['zeros', 'ones', 'twoes']

    def test_select_epochs(self):
        """Epochs can be picked, or dropped with ``invert=True``, by index."""
        epo = self.dat
        # pick a single epoch
        picked = select_epochs(epo, [0])
        self.assertEqual(picked.data.shape[0], 1)
        np.testing.assert_array_equal(picked.data, epo.data[[0]])
        np.testing.assert_array_equal(picked.axes[0], epo.axes[0][0])
        # pick every second epoch
        picked = select_epochs(epo, [0, 2])
        self.assertEqual(picked.data.shape[0], 2)
        np.testing.assert_array_equal(picked.data, epo.data[::2])
        np.testing.assert_array_equal(picked.axes[0], epo.axes[0][::2])
        # pick all epochs
        picked = select_epochs(epo, list(range(epo.data.shape[0])))
        np.testing.assert_array_equal(picked.data, epo.data)
        np.testing.assert_array_equal(picked.axes[0], epo.axes[0])
        # drop the first epoch
        picked = select_epochs(epo, [0], invert=True)
        self.assertEqual(picked.data.shape[0], 3)
        np.testing.assert_array_equal(picked.data, epo.data[1:])
        np.testing.assert_array_equal(picked.axes[0], epo.axes[0][1:])
        # drop every second epoch
        picked = select_epochs(epo, [0, 2], invert=True)
        self.assertEqual(picked.data.shape[0], 2)
        np.testing.assert_array_equal(picked.data, epo.data[1::2])
        np.testing.assert_array_equal(picked.axes[0], epo.axes[0][1::2])

    def test_select_epochs_with_cnt(self):
        """Without ``class_names`` an AssertionError is raised."""
        del self.dat.class_names
        with self.assertRaises(AssertionError):
            select_epochs(self.dat, [0, 1])

    def test_select_epochs_swapaxes(self):
        """The class axis may sit at a non-default position."""
        swapped = swapaxes(self.dat, 0, 2)
        via_swapped = swapaxes(
            select_epochs(swapped, [0, 1], classaxis=2), 0, 2)
        direct = select_epochs(self.dat, [0, 1])
        self.assertEqual(via_swapped, direct)

    def test_select_epochs_copy(self):
        """The argument is not modified."""
        snapshot = self.dat.copy()
        select_epochs(self.dat, [0, 1])
        self.assertEqual(self.dat, snapshot)
Exemple #15
0
class TestRereference(unittest.TestCase):
    """Tests for ``rereference`` on continuous and epoched fixtures."""

    def setUp(self):
        """Create ``self.cnt`` (SAMPLES x CHANS) and ``self.epo`` (EPOS copies)."""
        samples = np.zeros((SAMPLES, CHANS))
        # only the first channel carries a signal: a ramp centred on zero
        samples[:, 0] = np.arange(SAMPLES) - SAMPLES / 2
        time = np.arange(SAMPLES)
        channels = ['chan{i}'.format(i=i) for i in range(CHANS)]
        self.cnt = Data(samples, [time, channels], ['time', 'channels'],
                        ['ms', '#'])
        # epoch number i is the continuous data shifted by i
        epochs = np.array([samples + i for i in range(EPOS)])
        classes = ['class{i}'.format(i=i) for i in range(EPOS)]
        self.epo = Data(epochs, [classes, time, channels],
                        ['class', 'time', 'channels'], ['#', 'ms', '#'])

    def test_rereference_cnt(self):
        """Rereferencing continuous data to chan0."""
        result = rereference(self.cnt, 'chan0')
        ramp = np.linspace(SAMPLES / 2, -SAMPLES / 2, SAMPLES, endpoint=False)
        # every channel gets the negated ramp, the reference itself is zero
        expected = np.array([ramp] * CHANS).T
        expected[:, 0] = 0
        np.testing.assert_array_equal(result.data, expected)

    def test_rereference_epo(self):
        """Rereferencing epoched data to chan0."""
        result = rereference(self.epo, 'chan0')
        ramp = np.linspace(SAMPLES / 2, -SAMPLES / 2, SAMPLES, endpoint=False)
        single = np.array([ramp] * CHANS).T
        single[:, 0] = 0
        # the same expectation holds for every epoch
        expected = np.array([single] * EPOS)
        np.testing.assert_array_equal(result.data, expected)

    def test_raise_value_error(self):
        """An unknown channel name raises ValueError."""
        with self.assertRaises(ValueError):
            rereference(self.cnt, 'foo')

    def test_case_insensitivity(self):
        """The channel name is matched regardless of case."""
        try:
            rereference(self.cnt, 'ChAN0')
        except ValueError:
            self.fail()

    def test_rereference_copy(self):
        """The argument is not modified."""
        snapshot = self.cnt.copy()
        rereference(self.cnt, 'chan0')
        self.assertEqual(self.cnt, snapshot)

    def test_rereference_swapaxes(self):
        """The channel axis may sit at a non-default position."""
        swapped = swapaxes(self.epo, 1, 2)
        via_swapped = swapaxes(rereference(swapped, 'chan0', chanaxis=1), 1, 2)
        direct = rereference(self.epo, 'chan0')
        self.assertEqual(via_swapped, direct)
Exemple #16
0
class TestRereference(unittest.TestCase):
    """Unit tests for ``rereference``."""

    def setUp(self):
        """Build a continuous and an epoched fixture from one ramp channel."""
        raw = np.zeros((SAMPLES, CHANS))
        # channel 0 holds 0..SAMPLES-1 shifted down by SAMPLES/2
        raw[:, 0] = np.arange(SAMPLES) - SAMPLES/2
        chan_names = ['chan{i}'.format(i=i) for i in range(CHANS)]
        timestamps = np.arange(SAMPLES)
        self.cnt = Data(raw, [timestamps, chan_names], ['time', 'channels'], ['ms', '#'])
        # stack EPOS shifted copies of the continuous data
        stacked = np.array([raw + i for i in range(EPOS)])
        class_labels = ['class{i}'.format(i=i) for i in range(EPOS)]
        self.epo = Data(stacked, [class_labels, timestamps, chan_names],
                        ['class', 'time', 'channels'], ['#', 'ms', '#'])

    def test_rereference_cnt(self):
        """Rereference channels (cnt)."""
        rereferenced = rereference(self.cnt, 'chan0')
        column = np.linspace(SAMPLES/2, -SAMPLES/2, SAMPLES, endpoint=False)
        wanted = np.array([column for _ in range(CHANS)]).T
        # the reference channel minus itself is zero
        wanted[:, 0] = 0
        np.testing.assert_array_equal(rereferenced.data, wanted)

    def test_rereference_epo(self):
        """Rereference channels (epo)."""
        rereferenced = rereference(self.epo, 'chan0')
        column = np.linspace(SAMPLES/2, -SAMPLES/2, SAMPLES, endpoint=False)
        per_epoch = np.array([column for _ in range(CHANS)]).T
        per_epoch[:, 0] = 0
        wanted = np.array([per_epoch for _ in range(EPOS)])
        np.testing.assert_array_equal(rereferenced.data, wanted)

    def test_raise_value_error(self):
        """Raise ValueError if the channel is not found."""
        with self.assertRaises(ValueError):
            rereference(self.cnt, 'foo')

    def test_case_insensitivity(self):
        """rereference should not care about the case of the channel name."""
        try:
            rereference(self.cnt, 'ChAN0')
        except ValueError:
            self.fail()

    def test_rereference_copy(self):
        """rereference must not modify its arguments."""
        reference = self.cnt.copy()
        rereference(self.cnt, 'chan0')
        self.assertEqual(self.cnt, reference)

    def test_rereference_swapaxes(self):
        """rereference must work with a nonstandard chanaxis."""
        result = rereference(swapaxes(self.epo, 1, 2), 'chan0', chanaxis=1)
        result = swapaxes(result, 1, 2)
        expected = rereference(self.epo, 'chan0')
        self.assertEqual(result, expected)
Exemple #17
0
class TestVariance(unittest.TestCase):
    """Tests for ``variance`` on constant-valued epochs."""

    def setUp(self):
        """Create three constant epochs (0s, 1s and 2s), each 10 x 5."""
        unit = np.ones((10, 5))
        epochs = np.array([0*unit, unit, 2*unit])
        time = np.linspace(0, 1000, 10, endpoint=False)
        channels = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        self.dat = Data(
            epochs,
            [[0, 1, 2], time, channels],
            ['class', 'time', 'channel'],
            ['#', 'ms', '#'],
        )

    def test_variance(self):
        """The middle axis is removed and all variances are zero."""
        result = variance(self.dat)
        # the first and last dimensions of the input remain
        self.assertEqual(result.data.shape, self.dat.data.shape[::2])
        # constant epochs have zero variance, so the result is constant too
        self.assertEqual(result.data.var(), 0)
        self.assertEqual(len(result.axes), len(self.dat.axes)-1)

    def test_variance_with_cnt(self):
        """variance also accepts two-dimensional data."""
        cnt = self.dat.copy(
            data=self.dat.data[1],
            axes=self.dat.axes[1:],
            names=self.dat.names[1:],
            units=self.dat.units[1:],
        )
        result = variance(cnt)
        self.assertEqual(result.data.var(), 0)
        self.assertEqual(len(result.axes), len(self.dat.axes)-2)

    def test_variance_swapaxes(self):
        """The time axis may sit at a non-default position."""
        via_swapped = variance(swapaxes(self.dat, 1, 2), timeaxis=2)
        # no swap back: variance removes the time axis
        direct = variance(self.dat)
        self.assertEqual(via_swapped, direct)

    def test_variance_copy(self):
        """The argument is not modified."""
        snapshot = self.dat.copy()
        variance(self.dat)
        self.assertEqual(self.dat, snapshot)
class TestVariance(unittest.TestCase):
    """Unit tests for ``variance``."""

    def setUp(self):
        """Build an epoched fixture whose epochs are all 0s, all 1s and all 2s."""
        base = np.ones((10, 5))
        values = np.array([0 * base, base, 2 * base])
        axes = [
            [0, 1, 2],
            np.linspace(0, 1000, 10, endpoint=False),
            ['ca1', 'ca2', 'cb1', 'cb2', 'cc1'],
        ]
        self.dat = Data(values, axes, ['class', 'time', 'channel'],
                        ['#', 'ms', '#'])

    def test_variance(self):
        """Variance."""
        var = variance(self.dat)
        # the result has one axis less than the input (the middle one)
        expected_shape = self.dat.data.shape[::2]
        self.assertEqual(var.data.shape, expected_shape)
        # every epoch is constant, hence the variance over the result is 0
        self.assertEqual(var.data.var(), 0)
        self.assertEqual(len(var.axes), len(self.dat.axes) - 1)

    def test_variance_with_cnt(self):
        """variance must work with a cnt argument."""
        second_epoch = self.dat.data[1]
        tail_axes = self.dat.axes[1:]
        tail_names = self.dat.names[1:]
        tail_units = self.dat.units[1:]
        cnt = self.dat.copy(data=second_epoch, axes=tail_axes,
                            names=tail_names, units=tail_units)
        var = variance(cnt)
        self.assertEqual(var.data.var(), 0)
        self.assertEqual(len(var.axes), len(self.dat.axes) - 2)

    def test_variance_swapaxes(self):
        """variance must work with a nonstandard timeaxis."""
        swapped_var = variance(swapaxes(self.dat, 1, 2), timeaxis=2)
        # the time axis is gone afterwards, so there is nothing to swap back
        plain_var = variance(self.dat)
        self.assertEqual(swapped_var, plain_var)

    def test_variance_copy(self):
        """variance must not modify its argument."""
        reference = self.dat.copy()
        variance(self.dat)
        self.assertEqual(self.dat, reference)
Exemple #19
0
class TestSelectIval(unittest.TestCase):
    """Tests for ``select_ival`` on an epoched ``Data`` fixture."""

    def setUp(self):
        """Create three constant epochs (1s, -1s and 0s), each 10 x 5."""
        ones = np.ones((10, 5))
        channels = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        time = np.linspace(-1000, 0, 10, endpoint=False)
        classes = [0, 0, 0]
        # three epochs: 1s, -1s, and 0s
        data = np.array([ones, ones * -1, ones * 0])
        self.dat = Data(data, [classes, time, channels],
                        ['class', 'time', 'channel'], ['#', 'ms', '#'])
        self.dat.fs = 10

    def test_select_ival(self):
        """Selecting intervals."""
        # normal case
        dat = select_ival(self.dat, [-500, 0])
        self.assertEqual(dat.axes[1][0], -500)
        self.assertEqual(dat.axes[1][-1], -100)
        # the full dat interval
        dat = select_ival(self.dat,
                          [self.dat.axes[1][0], self.dat.axes[1][-1] + 1])
        self.assertEqual(dat.axes[1][0], self.dat.axes[1][0])
        self.assertEqual(dat.axes[1][-1], self.dat.axes[1][-1])
        np.testing.assert_array_equal(dat.data, self.dat.data)

    def test_select_ival_with_markers(self):
        """Selecting intervals with markers."""
        # Markers are [time, label] pairs. The first entry used to read
        # [-499, 99, 'x'], a decimal-comma typo that produced a
        # three-element marker.
        good_markers = [[-499.99, 'x'], [-500, 'x'], [-0.0001, 'x']]
        bad_markers = [[501, 'y'], [0, 'y'], [1, 'y']]
        self.dat.markers = good_markers[:]
        self.dat.markers.extend(bad_markers)
        dat = select_ival(self.dat, [-500, 0])
        self.assertEqual(dat.markers, good_markers)

    def test_ival_checks(self):
        """Test for malformed ival parameter."""
        with self.assertRaises(AssertionError):
            select_ival(self.dat, [0, -1])
        with self.assertRaises(AssertionError):
            select_ival(self.dat, [self.dat.axes[1][0] - 1, 0])
        with self.assertRaises(AssertionError):
            select_ival(self.dat, [0, self.dat.axes[1][-1] + 1])

    def test_select_ival_copy(self):
        """select_ival must not modify its argument."""
        # cpy is the argument here; self.dat serves as the untouched reference
        cpy = self.dat.copy()
        select_ival(cpy, [-500, 0])
        self.assertEqual(cpy, self.dat)

    def test_select_ival_swapaxes(self):
        """select_ival must work with a nonstandard timeaxis."""
        dat = select_ival(swapaxes(self.dat, 0, 1), [-500, 0], timeaxis=0)
        dat = swapaxes(dat, 0, 1)
        dat2 = select_ival(self.dat, [-500, 0])
        self.assertEqual(dat, dat2)
Exemple #20
0
class TestRectifytChannels(unittest.TestCase):
    """Tests for ``rectify_channels`` on a small 4 x 5 ``Data`` object."""

    def setUp(self):
        """Create a 4 x 5 fixture whose values are 0..19."""
        samples = np.arange(20).reshape(4, 5)
        axes = [np.arange(4), ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']]
        self.dat = Data(samples, axes, ['time', 'channels'], ['ms', '#'])

    def test_rectify_channels(self):
        """Negated and original data rectify to the same result."""
        negated = self.dat.copy(data=-self.dat.data)
        from_negative = rectify_channels(negated)
        from_positive = rectify_channels(self.dat)
        self.assertEqual(from_negative, from_positive)

    def test_rectify_channels_copy(self):
        """The original argument is not changed."""
        snapshot = self.dat.copy()
        rectify_channels(self.dat)
        self.assertEqual(snapshot, self.dat)
Exemple #21
0
class TestSelectIval(unittest.TestCase):
    """Tests for ``select_ival`` on epoched data (class x time x channel)."""

    def setUp(self):
        ones = np.ones((10, 5))
        channels = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        # ten time stamps: -1000, -900, ..., -100 (0 itself is excluded)
        time = np.linspace(-1000, 0, 10, endpoint=False)
        classes = [0, 0, 0]
        # three epochs filled with 1s, -1s and 0s
        data = np.array([ones, ones * -1, ones * 0])
        self.dat = Data(data, [classes, time, channels],
                        ['class', 'time', 'channel'], ['#', 'ms', '#'])
        self.dat.fs = 10

    def test_select_ival(self):
        """Selecting intervals keeps exactly the requested time range."""
        # normal case: the upper bound 0 is not part of the result
        dat = select_ival(self.dat, [-500, 0])
        self.assertEqual(dat.axes[1][0], -500)
        self.assertEqual(dat.axes[1][-1], -100)
        # the full interval: the upper bound is one past the last time stamp
        first, last = self.dat.axes[1][0], self.dat.axes[1][-1]
        dat = select_ival(self.dat, [first, last + 1])
        self.assertEqual(dat.axes[1][0], first)
        self.assertEqual(dat.axes[1][-1], last)
        np.testing.assert_array_equal(dat.data, self.dat.data)

    def test_select_ival_with_markers(self):
        """Selecting intervals keeps only the markers inside the interval."""
        # Each marker is written as a two-element [position, label] list.
        # The first one must be -499.99 (decimal point); writing it as
        # "-499,99" would create a malformed three-element marker.
        good_markers = [[-499.99, 'x'], [-500, 'x'], [-0.0001, 'x']]
        # markers at or beyond the upper bound 0 are expected to be dropped
        bad_markers = [[501, 'y'], [0, 'y'], [1, 'y']]
        self.dat.markers = good_markers[:]
        self.dat.markers.extend(bad_markers)
        dat = select_ival(self.dat, [-500, 0])
        self.assertEqual(dat.markers, good_markers)

    def test_ival_checks(self):
        """Malformed ival parameters must raise an AssertionError."""
        # start > end
        with self.assertRaises(AssertionError):
            select_ival(self.dat, [0, -1])
        # start below the first time stamp
        with self.assertRaises(AssertionError):
            select_ival(self.dat, [self.dat.axes[1][0] - 1, 0])
        # end above the last time stamp
        with self.assertRaises(AssertionError):
            select_ival(self.dat, [0, self.dat.axes[1][-1] + 1])

    def test_select_ival_copy(self):
        """select_ival must not modify the argument."""
        cpy = self.dat.copy()
        select_ival(cpy, [-500, 0])
        self.assertEqual(cpy, self.dat)

    def test_select_ival_swapaxes(self):
        """select_ival must work with nonstandard timeaxis."""
        dat = select_ival(swapaxes(self.dat, 0, 1), [-500, 0], timeaxis=0)
        dat = swapaxes(dat, 0, 1)
        dat2 = select_ival(self.dat, [-500, 0])
        self.assertEqual(dat, dat2)
class TestAppendCnt(unittest.TestCase):
    """Tests for ``append_cnt`` on continuous data with two markers."""

    def setUp(self):
        block = np.ones((10, 5))
        # 30 samples: ten rows of 1s, ten of 2s and ten of 3s
        cnt = np.concatenate([block, block * 2, block * 3], axis=0)
        channel_names = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        time_axis = np.linspace(0, 3000, 30, endpoint=False)
        self.dat = Data(cnt, [time_axis, channel_names],
                        ['time', 'channel'], ['ms', '#'])
        self.dat.markers = [[0, 'a'], [1, 'b']]
        self.dat.fs = 10

    def _expected_markers(self):
        """Return the original markers followed by copies shifted by 3000."""
        shifted = [[pos + 3000, label] for pos, label in self.dat.markers]
        return self.dat.markers + shifted

    def test_append_cnt(self):
        """Appending a cnt to itself doubles data, time axis and markers."""
        doubled = append_cnt(self.dat, self.dat)
        n_samples = self.dat.data.shape[0]
        self.assertEqual(doubled.data.shape[0], 2 * n_samples)
        self.assertEqual(len(doubled.axes[0]), 2 * len(self.dat.axes[0]))
        expected_data = np.concatenate([self.dat.data, self.dat.data], axis=0)
        np.testing.assert_array_equal(doubled.data, expected_data)
        expected_time = np.linspace(0, 6000, 60, endpoint=False)
        np.testing.assert_array_equal(doubled.axes[0], expected_time)
        self.assertEqual(doubled.markers, self._expected_markers())

    def test_append_cnt_with_extra(self):
        """The ``extra`` attributes may be lists as well as ndarrays."""
        self.dat.a = list(range(10))
        self.dat.b = np.arange(10)
        doubled = append_cnt(self.dat, self.dat, extra=['a', 'b'])
        self.assertEqual(doubled.a, list(range(10)) + list(range(10)))
        expected_b = np.concatenate([np.arange(10), np.arange(10)])
        np.testing.assert_array_equal(doubled.b, expected_b)
        self.assertEqual(doubled.markers, self._expected_markers())

    def test_append_cnt_swapaxes(self):
        """append_cnt must honour a non-default ``timeaxis``."""
        swapped_a = swapaxes(self.dat, 0, 1)
        swapped_b = swapaxes(self.dat, 0, 1)
        via_swapped = swapaxes(
            append_cnt(swapped_a, swapped_b, timeaxis=1), 0, 1)
        reference = append_cnt(self.dat, self.dat)
        self.assertEqual(via_swapped, reference)

    def test_append_cnt_copy(self):
        """append_cnt must not modify its arguments."""
        snapshot = self.dat.copy()
        append_cnt(self.dat, self.dat)
        self.assertEqual(self.dat, snapshot)
Exemple #23
0
class TestSpectrumWelch(unittest.TestCase):
    """Tests for ``spectrum_welch`` using a two-channel sum of sine waves."""

    def setUp(self):
        sampling_rate = 100
        duration = 5
        self.freqs = [2, 7, 15]
        self.amps = [30, 10, 2]
        t = np.linspace(0, duration, sampling_rate * duration)
        # superposition of one sine wave per (amplitude, frequency) pair
        waves = [
            amp * np.sin(2 * np.pi * t * freq)
            for amp, freq in zip(self.amps, self.freqs)
        ]
        signal = np.sum(waves, axis=0)
        signal = signal[:, np.newaxis]
        # both channels carry the same signal
        signal = np.concatenate([signal, signal], axis=1)
        channel_names = np.array(['ch1', 'ch2'])
        self.dat = Data(signal, [t, channel_names],
                        ['time', 'channel'], ['s', '#'])
        self.dat.fs = sampling_rate

    def test_spectrum_welch(self):
        """The spectrum is close to amp**2 / 2 at each sine frequency."""
        nperseg = 100
        spec = spectrum_welch(self.dat, nperseg=nperseg)
        n_channels = spec.data.shape[1]
        for freq, amp in zip(self.freqs, self.amps):
            # index of the frequency bin closest to the sine frequency
            nearest = np.argmin(np.abs(spec.axes[0] - freq))
            for chan in range(n_channels):
                self.assertAlmostEqual(
                    spec.data[nearest, chan], amp**2 / 2, delta=.15)
        # all frequencies must lie within [0, fs / 2]
        self.assertGreaterEqual(min(spec.axes[0]), 0)
        self.assertLessEqual(max(spec.axes[0]), self.dat.fs / 2)

    def test_spectrum_has_no_fs(self):
        """The result must not carry an ``fs`` attribute."""
        spec = spectrum_welch(self.dat)
        self.assertFalse(hasattr(spec, 'fs'))

    def test_spectrum_copy(self):
        """spectrum_welch must not modify its argument."""
        snapshot = self.dat.copy()
        spectrum_welch(self.dat)
        self.assertEqual(snapshot, self.dat)

    def test_spectrum_swapaxes(self):
        """spectrum_welch must honour a non-default ``timeaxis``."""
        swapped = swapaxes(self.dat, 0, 1)
        via_swapped = swapaxes(spectrum_welch(swapped, timeaxis=1), 0, 1)
        reference = spectrum_welch(self.dat)
        self.assertEqual(via_swapped, reference)
Exemple #24
0
class TestSpectrum(unittest.TestCase):
    """Tests for ``spectrum`` using a two-channel sum of sine waves."""

    def setUp(self):
        sampling_rate = 100
        duration = 5
        self.freqs = [2, 7, 15]
        self.amps = [30, 10, 2]
        t = np.linspace(0, duration, sampling_rate * duration)
        # superposition of one sine wave per (amplitude, frequency) pair
        waves = [amp * np.sin(2 * np.pi * t * freq)
                 for amp, freq in zip(self.amps, self.freqs)]
        signal = np.sum(waves, axis=0)
        signal = signal[:, np.newaxis]
        # both channels carry the same signal
        signal = np.concatenate([signal, signal], axis=1)
        channel_names = np.array(['ch1', 'ch2'])
        self.dat = Data(signal, [t, channel_names],
                        ['time', 'channel'], ['s', '#'])
        self.dat.fs = sampling_rate

    def test_spectrum(self):
        """The spectrum shows each sine amplitude and little else."""
        spec = spectrum(self.dat)
        freq_axis = spec.axes[0]
        n_channels = spec.data.shape[1]
        # the amplitudes at the sine frequencies must be almost correct
        for freq, amp in zip(self.freqs, self.amps):
            at_freq = freq_axis == freq
            for chan in range(n_channels):
                self.assertAlmostEqual(spec.data[at_freq, chan], amp,
                                       delta=.15)
        # everywhere else the amplitudes must stay small
        elsewhere = ((freq_axis != self.freqs[0])
                     & (freq_axis != self.freqs[1])
                     & (freq_axis != self.freqs[2]))
        self.assertFalse((spec.data[elsewhere] > .8).any())
        # frequencies must lie strictly between 0 and fs / 2
        self.assertGreater(min(freq_axis), 0)
        self.assertLess(max(freq_axis), self.dat.fs / 2)

    def test_spectrum_has_no_fs(self):
        """The result must not carry an ``fs`` attribute."""
        spec = spectrum(self.dat)
        self.assertFalse(hasattr(spec, 'fs'))

    def test_spectrum_copy(self):
        """spectrum must not modify its argument."""
        snapshot = self.dat.copy()
        spectrum(self.dat)
        self.assertEqual(snapshot, self.dat)

    def test_spectrum_swapaxes(self):
        """spectrum must honour a non-default ``timeaxis``."""
        swapped = swapaxes(self.dat, 0, 1)
        via_swapped = swapaxes(spectrum(swapped, timeaxis=1), 0, 1)
        reference = spectrum(self.dat)
        self.assertEqual(via_swapped, reference)
Exemple #25
0
class TestLFilter(unittest.TestCase):
    """Tests for ``lfilter`` using a two-channel sum of three sine waves."""

    def setUp(self):
        sampling_rate = 100
        duration = 5
        self.freqs = [2, 7, 15]
        amplitudes = [30, 10, 2]
        t = np.linspace(0, duration, sampling_rate * duration)
        # superposition of one sine wave per (amplitude, frequency) pair
        waves = [amp * np.sin(2 * np.pi * t * freq)
                 for amp, freq in zip(amplitudes, self.freqs)]
        signal = np.sum(waves, axis=0)
        signal = signal[:, np.newaxis]
        # both channels carry the same signal
        signal = np.concatenate([signal, signal], axis=1)
        channel_names = np.array(['ch1', 'ch2'])
        self.dat = Data(signal, [t, channel_names],
                        ['time', 'channel'], ['s', '#'])
        self.dat.fs = sampling_rate

    def _band_coefficients(self):
        """Return ``butter(4, [6 / fn, 8 / fn], btype='band')``, fn = fs / 2."""
        fn = self.dat.fs / 2
        return butter(4, [6 / fn, 8 / fn], btype='band')

    def test_bandpass(self):
        """Band pass filtering keeps the middle frequency."""
        b, a = self._band_coefficients()
        spec = spectrum(lfilter(self.dat, b, a))
        freq_axis = spec.axes[0]
        # the middle sine frequency must not be damped
        passband = freq_axis == 7
        self.assertTrue((spec.data[passband] > 6.5).all())
        # NOTE(review): no value is both <= 6 and > 8, so this mask is always
        # empty and the assertion below is vacuous.  ``|`` was probably
        # intended; check that the .5 limit holds before switching operators.
        stopband = (freq_axis <= 6) & (freq_axis > 8)
        self.assertTrue((spec.data[stopband] < .5).all())

    def test_lfilter_copy(self):
        """lfilter must not modify its argument."""
        snapshot = self.dat.copy()
        b, a = self._band_coefficients()
        lfilter(self.dat, b, a)
        self.assertEqual(snapshot, self.dat)

    def test_lfilter_swapaxes(self):
        """lfilter must honour a non-default ``timeaxis``."""
        b, a = self._band_coefficients()
        swapped = swapaxes(self.dat, 0, 1)
        via_swapped = swapaxes(lfilter(swapped, b, a, timeaxis=1), 0, 1)
        reference = lfilter(self.dat, b, a)
        self.assertEqual(via_swapped, reference)
Exemple #26
0
class TestFiltFilt(unittest.TestCase):
    """Tests for ``filtfilt`` using a two-channel sum of three sine waves."""

    def setUp(self):
        sampling_rate = 100
        duration = 5
        self.freqs = [2, 7, 15]
        amplitudes = [30, 10, 2]
        t = np.linspace(0, duration, sampling_rate * duration)
        # superposition of one sine wave per (amplitude, frequency) pair
        waves = [amp * np.sin(2 * np.pi * t * freq)
                 for amp, freq in zip(amplitudes, self.freqs)]
        signal = np.sum(waves, axis=0)
        signal = signal[:, np.newaxis]
        # both channels carry the same signal
        signal = np.concatenate([signal, signal], axis=1)
        channel_names = np.array(['ch1', 'ch2'])
        self.dat = Data(signal, [t, channel_names],
                        ['time', 'channel'], ['s', '#'])
        self.dat.fs = sampling_rate

    def _band_coefficients(self):
        """Return ``butter(4, [6 / fn, 8 / fn], btype='band')``, fn = fs / 2."""
        fn = self.dat.fs / 2
        return butter(4, [6 / fn, 8 / fn], btype='band')

    def test_bandpass(self):
        """Band pass filtering keeps the middle frequency."""
        b, a = self._band_coefficients()
        spec = spectrum(filtfilt(self.dat, b, a))
        freq_axis = spec.axes[0]
        # the middle sine frequency must not be damped
        passband = freq_axis == 7
        self.assertTrue((spec.data[passband] > 6.5).all())
        # NOTE(review): no value is both <= 6 and > 8, so this mask is always
        # empty and the assertion below is vacuous.  ``|`` was probably
        # intended; check that the .5 limit holds before switching operators.
        stopband = (freq_axis <= 6) & (freq_axis > 8)
        self.assertTrue((spec.data[stopband] < .5).all())

    def test_filtfilt_copy(self):
        """filtfilt must not modify its argument."""
        snapshot = self.dat.copy()
        b, a = self._band_coefficients()
        filtfilt(self.dat, b, a)
        self.assertEqual(snapshot, self.dat)

    def test_filtfilt_swapaxes(self):
        """filtfilt must honour a non-default ``timeaxis``."""
        b, a = self._band_coefficients()
        swapped = swapaxes(self.dat, 0, 1)
        via_swapped = swapaxes(filtfilt(swapped, b, a, timeaxis=1), 0, 1)
        reference = filtfilt(self.dat, b, a)
        self.assertEqual(via_swapped, reference)
class TestRemoveEpochs(unittest.TestCase):
    """Tests for ``remove_epochs`` on four epochs of constant data."""

    def setUp(self):
        template = np.ones((10, 5))
        # four epochs filled with 0s, 1s, 2s and 0s
        epochs = np.array([template * 0, template * 1,
                           template * 2, template * 0])
        channel_names = ["ca1", "ca2", "cb1", "cb2", "cc1"]
        time_axis = np.linspace(0, 1000, 10, endpoint=False)
        class_axis = [0, 1, 2, 1]
        self.dat = Data(epochs, [class_axis, time_axis, channel_names],
                        ["class", "time", "channel"], ["#", "ms", "#"])
        self.dat.class_names = ["zeros", "ones", "twoes"]

    def test_remove_epochs(self):
        """Removing epochs drops exactly the requested indices."""
        # remove only the first epoch
        remaining = remove_epochs(self.dat, [0])
        self.assertEqual(remaining.data.shape[0], 3)
        np.testing.assert_array_equal(remaining.data, self.dat.data[1:])
        # remove every second epoch, starting with the first
        remaining = remove_epochs(self.dat, [0, 2])
        self.assertEqual(remaining.data.shape[0], 2)
        np.testing.assert_array_equal(remaining.data, self.dat.data[1::2])
        # remove all epochs
        every_index = list(range(self.dat.data.shape[0]))
        remaining = remove_epochs(self.dat, every_index)
        self.assertEqual(remaining.data.shape[0], 0)

    def test_remove_epochs_with_cnt(self):
        """Without ``class_names`` an AssertionError must be raised."""
        del self.dat.class_names
        with self.assertRaises(AssertionError):
            remove_epochs(self.dat, [0, 1])

    def test_remove_epochs_swapaxes(self):
        """remove_epochs must honour a non-default ``classaxis``."""
        swapped = swapaxes(self.dat, 0, 2)
        via_swapped = swapaxes(
            remove_epochs(swapped, [0, 1], classaxis=2), 0, 2)
        reference = remove_epochs(self.dat, [0, 1])
        self.assertEqual(via_swapped, reference)

    def test_remove_epochs_copy(self):
        """remove_epochs must not modify its argument."""
        snapshot = self.dat.copy()
        remove_epochs(self.dat, [0, 1])
        self.assertEqual(self.dat, snapshot)
Exemple #28
0
class TestBandpass(unittest.TestCase):
    """Tests for ``band_pass`` using a two-channel sum of three sine waves."""

    def setUp(self):
        sampling_rate = 100
        duration = 5
        self.freqs = [2, 7, 15]
        amplitudes = [30, 10, 2]
        t = np.linspace(0, duration, sampling_rate * duration)
        # superposition of one sine wave per (amplitude, frequency) pair
        waves = [amp * np.sin(2 * np.pi * t * freq)
                 for amp, freq in zip(amplitudes, self.freqs)]
        signal = np.sum(waves, axis=0)
        signal = signal[:, np.newaxis]
        # both channels carry the same signal
        signal = np.concatenate([signal, signal], axis=1)
        channel_names = np.array(['ch1', 'ch2'])
        self.dat = Data(signal, [t, channel_names],
                        ['time', 'channel'], ['s', '#'])
        self.dat.fs = sampling_rate

    def test_band_pass(self):
        """Band pass filtering damps the outer frequencies."""
        # band pass around the middle frequency
        filtered = band_pass(self.dat, 6, 8)
        n_samples = self.dat.data.shape[0]
        # amplitude spectrum of the filtered signal
        amplitudes = np.abs(rfft(filtered.data, axis=0) * 2 / n_samples)
        freq_axis = rfftfreq(filtered.data.shape[0], 1 / filtered.fs)
        # the lowest and highest sine frequency must be close to zero in
        # every matching frequency bin and every channel
        for outer_freq in (self.freqs[0], self.freqs[-1]):
            for value in amplitudes[freq_axis == outer_freq].ravel():
                self.assertAlmostEqual(value, 0., delta=.1)

    def test_band_pass_copy(self):
        """band_pass must not modify its argument."""
        snapshot = self.dat.copy()
        band_pass(self.dat, 2, 3, timeaxis=0)
        self.assertEqual(snapshot, self.dat)

    def test_band_pass_swapaxes(self):
        """band_pass must honour a non-default ``timeaxis``."""
        swapped = swapaxes(self.dat, 0, 1)
        via_swapped = swapaxes(band_pass(swapped, 6, 8, timeaxis=1), 0, 1)
        reference = band_pass(self.dat, 6, 8)
        self.assertEqual(via_swapped, reference)
Exemple #29
0
class TestAppendCnt(unittest.TestCase):
    """Tests for ``append_cnt`` on continuous data with two markers."""

    def setUp(self):
        block = np.ones((10, 5))
        # 30 samples: ten rows of 1s, ten of 2s and ten of 3s
        cnt = np.concatenate([block, block * 2, block * 3], axis=0)
        channel_names = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        time_axis = np.linspace(0, 3000, 30, endpoint=False)
        self.dat = Data(cnt, [time_axis, channel_names],
                        ['time', 'channel'], ['ms', '#'])
        self.dat.markers = [[0, 'a'], [1, 'b']]
        self.dat.fs = 10

    def _expected_markers(self):
        """Return the original markers followed by copies shifted by 3000."""
        shifted = [[pos + 3000, label] for pos, label in self.dat.markers]
        return self.dat.markers + shifted

    def test_append_cnt(self):
        """Appending a cnt to itself doubles data, time axis and markers."""
        doubled = append_cnt(self.dat, self.dat)
        self.assertEqual(doubled.data.shape[0], 2 * self.dat.data.shape[0])
        self.assertEqual(len(doubled.axes[0]), 2 * len(self.dat.axes[0]))
        np.testing.assert_array_equal(
            doubled.data,
            np.concatenate([self.dat.data, self.dat.data], axis=0))
        np.testing.assert_array_equal(
            doubled.axes[0], np.linspace(0, 6000, 60, endpoint=False))
        self.assertEqual(doubled.markers, self._expected_markers())

    def test_append_cnt_with_extra(self):
        """The ``extra`` attributes may be lists as well as ndarrays."""
        self.dat.a = list(range(10))
        self.dat.b = np.arange(10)
        doubled = append_cnt(self.dat, self.dat, extra=['a', 'b'])
        self.assertEqual(doubled.a, list(range(10)) + list(range(10)))
        np.testing.assert_array_equal(
            doubled.b, np.concatenate([np.arange(10), np.arange(10)]))
        self.assertEqual(doubled.markers, self._expected_markers())

    def test_append_cnt_swapaxes(self):
        """append_cnt must honour a non-default ``timeaxis``."""
        first = swapaxes(self.dat, 0, 1)
        second = swapaxes(self.dat, 0, 1)
        via_swapped = swapaxes(append_cnt(first, second, timeaxis=1), 0, 1)
        reference = append_cnt(self.dat, self.dat)
        self.assertEqual(via_swapped, reference)

    def test_append_cnt_copy(self):
        """append_cnt must not modify its arguments."""
        snapshot = self.dat.copy()
        append_cnt(self.dat, self.dat)
        self.assertEqual(self.dat, snapshot)
Exemple #30
0
class TestSquare(unittest.TestCase):
    """Tests for ``square`` on a small 4x5 data set."""

    def setUp(self):
        # 4 time points x 5 channels holding the values 1..20
        samples = np.arange(1, 21).reshape(4, 5)
        time_axis = np.arange(4)
        channel_names = ["ca1", "ca2", "cb1", "cb2", "cc1"]
        self.dat = Data(samples, [time_axis, channel_names],
                        ["time", "channels"], ["ms", "#"])

    def test_square(self):
        """square works elementwise and yields the squared values."""
        squared = square(self.dat)
        # the shape must be unchanged
        self.assertEqual(self.dat.data.shape, squared.data.shape)
        # the values must match numpy's elementwise square
        expected = np.square(self.dat.data)
        np.testing.assert_array_almost_equal(squared.data, expected)

    def test_square_copy(self):
        """square must not modify its argument."""
        snapshot = self.dat.copy()
        square(self.dat)
        self.assertEqual(snapshot, self.dat)
class TestCreateFeatureVectors(unittest.TestCase):
    """Tests for ``create_feature_vectors`` on three constant epochs."""

    def setUp(self):
        # epoch i is filled with the value i (0s, 1s and 2s)
        template = np.ones((10, 3))
        epochs = np.array([0 * template, 1 * template, 2 * template])
        class_axis = np.arange(3)
        time_axis = np.arange(10)
        channel_names = np.array(['ch1', 'ch2', 'ch3'])
        self.dat = Data(epochs, [class_axis, time_axis, channel_names],
                        ['class', 'time', 'channel'], ['#', 'ms', '#'])

    def test_create_feature_vectors(self):
        """The result is 2-dimensional with one feature vector per epoch."""
        fv = create_feature_vectors(self.dat)
        # row i must contain only the value of epoch i
        for value in (0, 1, 2):
            self.assertTrue(all(fv.data[value] == value))
        # data and all metadata must be reduced to two dimensions
        self.assertEqual(fv.data.ndim, 2)
        self.assertEqual(len(fv.axes), 2)
        self.assertEqual(len(fv.names), 2)
        self.assertEqual(len(fv.units), 2)
        # the last axis is labelled as feature vector and simply enumerated
        self.assertEqual(fv.names[-1], 'feature vector')
        self.assertEqual(fv.units[-1], 'dl')
        n_features = fv.data.shape[-1]
        np.testing.assert_array_equal(fv.axes[-1], np.arange(n_features))

    def test_create_feature_vectors_copy(self):
        """create_feature_vectors must not modify its argument."""
        snapshot = self.dat.copy()
        create_feature_vectors(self.dat)
        self.assertEqual(snapshot, self.dat)

    def test_create_feature_vectors_swapaxes(self):
        """create_feature_vectors must honour a non-default ``classaxis``."""
        # No swapping back is done here: the result for the swapped input is
        # compared directly with the result for the unswapped input.
        swapped = swapaxes(self.dat, 0, 2)
        via_swapped = create_feature_vectors(swapped, classaxis=2)
        reference = create_feature_vectors(self.dat)
        self.assertEqual(via_swapped, reference)
Exemple #32
0
class TestCreateFeatureVectors(unittest.TestCase):
    """Tests for ``create_feature_vectors`` on three constant epochs."""

    def setUp(self):
        # epoch i is filled with the value i (0s, 1s and 2s)
        template = np.ones((10, 3))
        epochs = np.array([0 * template, 1 * template, 2 * template])
        axes = [np.arange(3), np.arange(10), np.array(['ch1', 'ch2', 'ch3'])]
        names = ['class', 'time', 'channel']
        units = ['#', 'ms', '#']
        self.dat = Data(epochs, axes, names, units)

    def test_create_feature_vectors(self):
        """The result is 2-dimensional with one feature vector per epoch."""
        result = create_feature_vectors(self.dat)
        # row i must contain only the value of epoch i
        for value in (0, 1, 2):
            self.assertTrue(all(result.data[value] == value))
        # data and all metadata must be reduced to two dimensions
        self.assertEqual(result.data.ndim, 2)
        self.assertEqual(len(result.axes), 2)
        self.assertEqual(len(result.names), 2)
        self.assertEqual(len(result.units), 2)
        # the last axis is labelled as feature vector and simply enumerated
        self.assertEqual(result.names[-1], 'feature vector')
        self.assertEqual(result.units[-1], 'dl')
        expected_axis = np.arange(result.data.shape[-1])
        np.testing.assert_array_equal(result.axes[-1], expected_axis)

    def test_create_feature_vectors_copy(self):
        """create_feature_vectors must not modify its argument."""
        snapshot = self.dat.copy()
        create_feature_vectors(self.dat)
        self.assertEqual(snapshot, self.dat)

    def test_create_feature_vectors_swapaxes(self):
        """create_feature_vectors must honour a non-default ``classaxis``."""
        # No swapping back is done here: the result for the swapped input is
        # compared directly with the result for the unswapped input.
        from_swapped = create_feature_vectors(
            swapaxes(self.dat, 0, 2), classaxis=2)
        from_original = create_feature_vectors(self.dat)
        self.assertEqual(from_swapped, from_original)
Exemple #33
0
class TestLogarithm(unittest.TestCase):
    """Tests for ``logarithm`` on a small 4x5 data set of positive values."""

    def setUp(self):
        # 4 time points x 5 channels holding the values 1..20
        samples = np.arange(1, 21).reshape(4, 5)
        time_axis = np.arange(4)
        channel_names = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        self.dat = Data(samples, [time_axis, channel_names],
                        ['time', 'channels'], ['ms', '#'])

    def test_logarithm(self):
        """logarithm works elementwise and is inverted by e**x."""
        logged = logarithm(self.dat)
        # the shape must be unchanged
        self.assertEqual(self.dat.data.shape, logged.data.shape)
        # exponentiating the result must give back the input
        restored = np.e**logged.data
        np.testing.assert_array_almost_equal(restored, self.dat.data)

    def test_logarithm_copy(self):
        """logarithm must not modify its argument."""
        snapshot = self.dat.copy()
        logarithm(self.dat)
        self.assertEqual(snapshot, self.dat)
class TestLogarithm(unittest.TestCase):
    """Tests for ``logarithm`` on a small 4x5 data set of positive values."""

    def setUp(self):
        # the values 1..20 arranged as 4 time points x 5 channels
        values = np.arange(1, 21).reshape(4, 5)
        axes = [np.arange(4), ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']]
        self.dat = Data(values, axes, ['time', 'channels'], ['ms', '#'])

    def test_logarithm(self):
        """logarithm works elementwise and is inverted by e**x."""
        result = logarithm(self.dat)
        # elementwise operation: the shape stays the same
        self.assertEqual(self.dat.data.shape, result.data.shape)
        # e raised to the result must reproduce the input
        np.testing.assert_array_almost_equal(np.e**result.data,
                                             self.dat.data)

    def test_logarithm_copy(self):
        """logarithm must not modify its argument."""
        before = self.dat.copy()
        logarithm(self.dat)
        self.assertEqual(before, self.dat)
Exemple #35
0
class TestSquare(unittest.TestCase):
    """Tests for ``square`` on a small 4x5 data set."""

    def setUp(self):
        # the values 1..20 arranged as 4 time points x 5 channels
        values = np.arange(1, 21).reshape(4, 5)
        axes = [np.arange(4), ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']]
        self.dat = Data(values, axes, ['time', 'channels'], ['ms', '#'])

    def test_square(self):
        """square works elementwise and yields the squared values."""
        result = square(self.dat)
        # elementwise operation: the shape stays the same
        self.assertEqual(self.dat.data.shape, result.data.shape)
        # the values must match numpy's elementwise square
        np.testing.assert_array_almost_equal(result.data,
                                             np.square(self.dat.data))

    def test_square_copy(self):
        """square must not modify its argument."""
        before = self.dat.copy()
        square(self.dat)
        self.assertEqual(before, self.dat)
class TestCalculateClasswiseAverage(unittest.TestCase):
    """Tests for ``calculate_classwise_average`` on seven constant epochs."""

    def setUp(self):
        ones = np.ones((10, 2))
        twos = ones * 2
        # 7 epochs: class 0 epochs hold 1s, class 1 epochs hold 2s
        class_axis = [0, 0, 1, 1, 0, 1, 1]
        epochs = np.array([ones, ones, twos, twos, ones, twos, twos])
        time_axis = np.linspace(0, 1000, 10)
        channel_names = ['c1', 'c2']
        self.dat = Data(epochs, [class_axis, time_axis, channel_names],
                        ['class', 'time', 'channel'], ['#', 'ms', '#'])
        self.dat.class_names = ['ones', 'twoes']

    def test_calculate_classwise_average(self):
        """The average contains one entry per class with correct values."""
        averaged = calculate_classwise_average(self.dat)
        n_classes = averaged.data.shape[0]
        # one averaged epoch per class
        self.assertEqual(n_classes, 2)
        # the averages must equal the constant value of each class
        self.assertEqual(np.average(averaged.data[0]), 1)
        self.assertEqual(np.average(averaged.data[1]), 2)
        # the class axis must be as long as the data's first dimension
        self.assertEqual(n_classes, len(averaged.axes[0]))
        # the class names must be carried over unchanged
        self.assertEqual(averaged.class_names, self.dat.class_names)

    def test_calculate_classwise_average_with_cnt(self):
        """Without ``class_names`` an AssertionError must be raised."""
        del self.dat.class_names
        with self.assertRaises(AssertionError):
            calculate_classwise_average(self.dat)

    def test_calculate_classwise_average_copy(self):
        """calculate_classwise_average must not modify its argument."""
        snapshot = self.dat.copy()
        calculate_classwise_average(self.dat)
        self.assertEqual(self.dat, snapshot)
class TestSortChannels(unittest.TestCase):
    """Tests for ``sort_channels`` using shuffled CHANNEL_10_20 names."""

    def setUp(self):
        # reference order: the names in the order CHANNEL_10_20 lists them
        self.sorted_channels = np.array([name for name, _ in CHANNEL_10_20])
        shuffled = self.sorted_channels.copy()
        random.shuffle(shuffled)
        # 5 epochs x 10 time points x one column per channel, random values
        samples = np.random.random((5, 10, len(shuffled)))
        time_axis = np.linspace(0, 1000, 10, endpoint=False)
        class_axis = np.array([0, 1, 0, 1, 0])
        self.dat = Data(samples, [class_axis, time_axis, shuffled],
                        ['class', 'time', 'channels'], ['#', 'ms', '#'])
        self.dat.fs = 100
        self.dat.markers = [[100, 'foo'], [200, 'bar']]

    def test_sort_channels(self):
        """sort_channels must restore the reference channel order."""
        result = sort_channels(self.dat)
        np.testing.assert_array_equal(result.axes[-1], self.sorted_channels)

    def test_sort_channels_with_unknown_channel(self):
        """A channel name not in the reference list must end up last."""
        self.dat.axes[-1][7] = 'XX'
        result = sort_channels(self.dat)
        self.assertEqual(result.axes[-1][-1], 'XX')

    def test_sort_channels_swapaxis(self):
        """sort_channels must honour a non-default channel axis."""
        swapped = swapaxes(self.dat, 1, -1)
        via_swapped = swapaxes(sort_channels(swapped, 1), 1, -1)
        reference = sort_channels(self.dat)
        self.assertEqual(via_swapped, reference)

    def test_sort_channels_copy(self):
        """sort_channels must not modify its argument."""
        snapshot = self.dat.copy()
        sort_channels(self.dat)
        self.assertEqual(self.dat, snapshot)
class TestSelectChannels(unittest.TestCase):
    """Tests for ``select_channels`` on a small 4x5 data set."""

    def setUp(self):
        # 4 time points x 5 channels holding the values 0..19
        samples = np.arange(20).reshape(4, 5)
        time_axis = np.arange(4)
        channel_names = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        self.dat = Data(samples, [time_axis, channel_names],
                        ['time', 'channels'], ['ms', '#'])

    def test_select_channels(self):
        """Channels matching any of the regexes are kept."""
        original = self.dat.data.copy()
        self.dat = select_channels(self.dat, ['ca.*', 'cc1'])
        kept_names = np.array(['ca1', 'ca2', 'cc1'])
        kept_columns = np.array([0, 1, -1])
        np.testing.assert_array_equal(self.dat.axes[-1], kept_names)
        np.testing.assert_array_equal(self.dat.data,
                                      original[:, kept_columns])

    def test_select_channels_inverse(self):
        """With ``invert=True`` the matching channels are removed."""
        original = self.dat.data.copy()
        self.dat = select_channels(self.dat, ['ca.*', 'cc1'], invert=True)
        kept_names = np.array(['cb1', 'cb2'])
        kept_columns = np.array([2, 3])
        np.testing.assert_array_equal(self.dat.axes[-1], kept_names)
        np.testing.assert_array_equal(self.dat.data,
                                      original[:, kept_columns])

    def test_select_channels_copy(self):
        """select_channels must not modify its argument."""
        snapshot = self.dat.copy()
        select_channels(self.dat, ['ca.*'])
        self.assertEqual(snapshot, self.dat)

    def test_select_channels_swapaxis(self):
        """select_channels must honour a non-default ``chanaxis``."""
        swapped = swapaxes(self.dat, 0, 1)
        via_swapped = swapaxes(
            select_channels(swapped, ['ca.*'], chanaxis=0), 0, 1)
        reference = select_channels(self.dat, ['ca.*'])
        self.assertEqual(via_swapped, reference)
Exemple #39
0
class TestCalculateClasswiseAverage(unittest.TestCase):
    """Tests for ``calculate_classwise_average`` on seven constant epochs."""

    def setUp(self):
        ones = np.ones((10, 2))
        twos = ones * 2
        # 7 epochs: class 0 epochs hold 1s, class 1 epochs hold 2s
        epochs = np.array([ones, ones, twos, twos, ones, twos, twos])
        axes = [[0, 0, 1, 1, 0, 1, 1], np.linspace(0, 1000, 10), ['c1', 'c2']]
        self.dat = Data(epochs, axes,
                        ['class', 'time', 'channel'], ['#', 'ms', '#'])
        self.dat.class_names = ['ones', 'twoes']

    def test_calculate_classwise_average(self):
        """The average contains one entry per class with correct values."""
        result = calculate_classwise_average(self.dat)
        # one averaged epoch per class
        self.assertEqual(result.data.shape[0], 2)
        # the averages must equal the constant value of each class
        self.assertEqual(np.average(result.data[0]), 1)
        self.assertEqual(np.average(result.data[1]), 2)
        # the class axis must be as long as the data's first dimension
        self.assertEqual(result.data.shape[0], len(result.axes[0]))
        # the class names must be carried over unchanged
        self.assertEqual(result.class_names, self.dat.class_names)

    def test_calculate_classwise_average_with_cnt(self):
        """Without ``class_names`` an AssertionError must be raised."""
        del self.dat.class_names
        with self.assertRaises(AssertionError):
            calculate_classwise_average(self.dat)

    def test_calculate_classwise_average_copy(self):
        """calculate_classwise_average must not modify its argument."""
        before = self.dat.copy()
        calculate_classwise_average(self.dat)
        self.assertEqual(self.dat, before)
Exemple #40
0
class TestSortChannels(unittest.TestCase):
    """Tests for ``sort_channels``."""

    def setUp(self):
        """Build epoched random data whose channel axis is a shuffled name list.

        The channel names are taken from ``CHANNEL_10_20``; their original
        order is kept in ``self.sorted_channels`` as the expected result.
        """
        self.sorted_channels = np.array([name for name, _ in CHANNEL_10_20])
        shuffled = self.sorted_channels.copy()
        random.shuffle(shuffled)
        n_channels = len(shuffled)
        axes = [np.array([0, 1, 0, 1, 0]),
                np.linspace(0, 1000, 10, endpoint=False),
                shuffled]
        self.dat = Data(np.random.random((5, 10, n_channels)),
                        axes,
                        ['class', 'time', 'channels'],
                        ['#', 'ms', '#'])
        self.dat.fs = 100
        self.dat.markers = [[100, 'foo'], [200, 'bar']]

    def test_sort_channels(self):
        """The last axis must come back in the reference order."""
        result = sort_channels(self.dat)
        np.testing.assert_array_equal(result.axes[-1], self.sorted_channels)

    def test_sort_channels_with_unknown_channel(self):
        """A channel name that is not known must end up in the last position."""
        self.dat.axes[-1][7] = 'XX'
        result = sort_channels(self.dat)
        self.assertEqual(result.axes[-1][-1], 'XX')

    def test_sort_channels_swapaxis(self):
        """sort_channels must work with a non-standard chanaxis."""
        # sort with the channels on axis 1, then move them back to the end
        via_swapped = swapaxes(sort_channels(swapaxes(self.dat, 1, -1), 1), 1, -1)
        direct = sort_channels(self.dat)
        self.assertEqual(via_swapped, direct)

    def test_sort_channels_copy(self):
        """sort_channels must not modify its argument."""
        snapshot = self.dat.copy()
        sort_channels(self.dat)
        self.assertEqual(self.dat, snapshot)
class TestCalculateWhitentingMatrix(unittest.TestCase):
    """Tests for ``calculate_whitening_matrix`` on correlated random data.

    The class name keeps its historical spelling so that existing test
    selectors continue to work.
    """

    def setUp(self):
        """Create continuous data where channels 1 and 2 depend on channel 0."""
        data = np.random.randn(SAMPLES, CHANS)
        # introduce correlations so that whitening has something to remove
        data[:, 1] += 0.5 * data[:, 0]
        data[:, 2] -= 0.5 * data[:, 0]
        t = np.arange(SAMPLES)
        chans = ['chan{i}'.format(i=i) for i in range(CHANS)]
        self.cnt = Data(data, [t, chans], ['time', 'channels'], ['ms', '#'])

    def test_shape(self):
        """The whitening filter should have the shape: CHANSxCHANS."""
        a = calculate_whitening_matrix(self.cnt)
        self.assertEqual(a.shape, (CHANS, CHANS))

    def test_diagonal(self):
        """The whitened data should have all 1s on the covariance diagonal."""
        a = calculate_whitening_matrix(self.cnt)
        dat2 = np.dot(self.cnt.data, a)
        vals = np.diag(np.cov(dat2.T))
        np.testing.assert_array_almost_equal(vals, np.ones(len(vals)))

    def test_zeros(self):
        """The whitened data should have all 0s off the covariance diagonal."""
        a = calculate_whitening_matrix(self.cnt)
        dat2 = np.dot(self.cnt.data, a)
        cov = np.cov(dat2.T)
        # subtract the diagonal so only the off-diagonal entries remain
        cov -= np.diag(np.diag(cov))
        # Sum the absolute values: a plain sum lets positive and negative
        # off-diagonal entries cancel and would hide a broken whitening.
        self.assertAlmostEqual(np.sum(np.abs(cov)), 0)

    def test_calculate_whitening_matrix_copy(self):
        """calculate_whitening_matrix must not modify arguments."""
        cpy = self.cnt.copy()
        calculate_whitening_matrix(self.cnt)
        self.assertEqual(self.cnt, cpy)
class TestCalculateWhitentingMatrix(unittest.TestCase):
    """Tests for ``calculate_whitening_matrix`` applied via ``apply_spatial_filter``.

    The class name keeps its historical spelling so that existing test
    selectors continue to work.
    """

    def setUp(self):
        """Create continuous data where channels 1 and 2 depend on channel 0."""
        data = np.random.randn(SAMPLES, CHANS)
        # introduce correlations so that whitening has something to remove
        data[:, 1] += 0.5 * data[:, 0]
        data[:, 2] -= 0.5 * data[:, 0]
        t = np.arange(SAMPLES)
        chans = ['chan{i}'.format(i=i) for i in range(CHANS)]
        self.cnt = Data(data, [t, chans], ['time', 'channels'], ['ms', '#'])

    def test_shape(self):
        """The whitening filter should have the shape: CHANSxCHANS."""
        a = calculate_whitening_matrix(self.cnt)
        self.assertEqual(a.shape, (CHANS, CHANS))

    def test_diagonal(self):
        """The whitened data should have all 1s on the covariance diagonal."""
        a = calculate_whitening_matrix(self.cnt)
        dat2 = apply_spatial_filter(self.cnt, a)
        vals = np.diag(np.cov(dat2.data.T))
        np.testing.assert_array_almost_equal(vals, np.ones(len(vals)))

    def test_zeros(self):
        """The whitened data should have all 0s off the covariance diagonal."""
        a = calculate_whitening_matrix(self.cnt)
        dat2 = apply_spatial_filter(self.cnt, a)
        cov = np.cov(dat2.data.T)
        # subtract the diagonal so only the off-diagonal entries remain
        cov -= np.diag(np.diag(cov))
        # Sum the absolute values: a plain sum lets positive and negative
        # off-diagonal entries cancel and would hide a broken whitening.
        self.assertAlmostEqual(np.sum(np.abs(cov)), 0)

    def test_calculate_whitening_matrix_copy(self):
        """calculate_whitening_matrix must not modify arguments."""
        cpy = self.cnt.copy()
        calculate_whitening_matrix(self.cnt)
        self.assertEqual(self.cnt, cpy)
class TestCalculateCSP(unittest.TestCase):
    """Tests for ``calculate_csp`` on synthetic data mixed from two sources."""

    EPOCHS = 50
    SAMPLES = 100
    SOURCES = 2
    CHANNELS = 10

    def setUp(self):
        """Mix two random sources into ``CHANNELS`` channels and wrap them as epo.

        Keeps the sources (``self.s``), the mixing matrix (``self.A``) and
        the mixed signal (``self.X``) so the tests can compare estimates
        against the ground truth.
        """
        # create a random noise signal with 50 epochs, 100 samples, and
        # 2 sources
        # even epochs and source 0: *= 5
        # odd epochs and source 1: *= 5
        self.s = np.random.randn(self.EPOCHS, self.SAMPLES, self.SOURCES)
        self.s[::2, :, 0] *= 5
        self.s[1::2, :, 1] *= 5
        # the mixmatrix which converts our sources to channels
        # X = As + noise
        self.A = np.random.randn(self.CHANNELS, self.SOURCES)
        # our 'signal' with 50 epochs, 100 samples and 10 channels
        self.X = np.empty((self.EPOCHS, self.SAMPLES, self.CHANNELS))
        for i in range(self.EPOCHS):
            self.X[i] = np.dot(self.A, self.s[i].T).T
        noise = np.random.randn(self.EPOCHS, self.SAMPLES, self.CHANNELS) * 0.01
        self.X += noise

        # class labels: 0 for even epochs, 1 for odd epochs
        a = np.array([1 for i in range(self.X.shape[0])])
        a[0::2] = 0
        axes = [a, np.arange(self.X.shape[1]), np.arange(self.X.shape[2])]
        self.epo = Data(self.X, axes=axes, names=['class', 'time', 'channel'], units=['#', 'ms', '#'])
        self.epo.class_names = ['foo', 'bar']

    def test_d(self):
        """Test if the first lambda is almost 1 and the last one almost -1."""
        W, A_est, d = calculate_csp(self.epo)
        epsilon = 0.1
        self.assertAlmostEqual(d[0], 1, delta=epsilon)
        self.assertAlmostEqual(d[-1], -1, delta=epsilon)

    def test_A(self):
        """Test if A_est is elementwise almost equal A."""
        W, A_est, d = calculate_csp(self.epo)
        # A and A_est can have a different scaling, after normalizing
        # and correcting for sign, they should be almost equal
        # normalize (we're only interested in the first and last column)
        for i in 0, -1:
            idx = np.argmax(np.abs(A_est[:, i]))
            A_est[:, i] /= A_est[idx, i]
            idx = np.argmax(np.abs(self.A[:, i]))
            self.A[:, i] /= self.A[idx, i]
        # check that the mean absolute difference between A[:, i] and
        # A_est[:, i] is small for the first and last column
        epsilon = 0.01
        for i in 0, -1:
            diff = self.A[:, i] - A_est[:, i]
            diff = np.abs(diff)
            diff = np.sum(diff) / self.A.shape[0]
            # assertLess reports both values when the check fails
            self.assertLess(diff, epsilon)

    def test_s(self):
        """Test if s_est is elementwise almost equal s."""
        W, A_est, d = calculate_csp(self.epo)
        # applying the filter to X gives us s_est which should be almost
        # equal s
        s_est = np.empty(self.s.shape)
        for i in range(self.EPOCHS):
            s_est[i] = np.dot(self.X[i], W[:, [0, -1]])
        # correct for scaling, and sign
        self.s = self.s.reshape(-1, self.SOURCES)
        s_est2 = s_est.reshape(-1, self.SOURCES)
        epsilon = 0.01
        for i in range(self.SOURCES):
            idx = np.argmax(np.abs(s_est2[:, i]))
            s_est2[:, i] /= s_est2[idx, i]
            idx = np.argmax(np.abs(self.s[:, i]))
            self.s[:, i] /= self.s[idx, i]
            diff = np.sum(np.abs(self.s[:, i] - s_est2[:, i])) / self.s.shape[0]
            # assertLess reports both values when the check fails
            self.assertLess(diff, epsilon)

    def test_manual_class_selection(self):
        """Manual class indices selection must work."""
        w, a, d = calculate_csp(self.epo)
        # explicitly naming the classes in the same order changes nothing
        w2, a2, d2 = calculate_csp(self.epo, [0, 1])
        np.testing.assert_array_equal(w, w2)
        np.testing.assert_array_equal(a, a2)
        np.testing.assert_array_equal(d, d2)
        # swapping the classes reverses the order of the components
        # (compared up to sign)
        w2, a2, d2 = calculate_csp(self.epo, [1, 0])
        np.testing.assert_array_almost_equal(np.abs(w), np.abs(w2[:, ::-1]))
        np.testing.assert_array_almost_equal(np.abs(a), np.abs(a2[:, ::-1]))
        np.testing.assert_array_almost_equal(np.abs(d), np.abs(d2[::-1]))

    def test_raise_error_on_wrong_manual_classes(self):
        """Raise error if classes not in epo."""
        # Each call needs its own context manager: once the first call
        # raises, any later statement in the same ``with`` block is skipped
        # and would never be checked.
        with self.assertRaises(AssertionError):
            calculate_csp(self.epo, [0, 2])
        with self.assertRaises(AssertionError):
            calculate_csp(self.epo, [0, -1])

    def test_raise_error_with_automatic_classes(self):
        """Raise error if not enough classes in epo."""
        # collapse all epochs into a single class
        self.epo.axes[0][:] = 0
        with self.assertRaises(AssertionError):
            calculate_csp(self.epo)

    # Disabled test, kept for reference:
    #def test_calculate_csp_swapaxes(self):
    #    """caluclate_csp must work with nonstandard classaxis."""
    #    dat = calculate_csp(swapaxes(self.epo, 0, 2), classaxis=2, chanaxis=0)
    #    dat2 = calculate_csp(self.epo)
    #    np.testing.assert_array_equal(dat[0], dat2[0])

    def test_calculate_csp_copy(self):
        """calculate_csp must not modify argument."""
        cpy = self.epo.copy()
        calculate_csp(self.epo)
        self.assertEqual(self.epo, cpy)
class TestSelectClasses(unittest.TestCase):
    """Tests for ``calculate_signed_r_square``.

    NOTE(review): the class name says "SelectClasses" although every test
    here exercises ``calculate_signed_r_square``; the name is kept so that
    existing test selectors continue to work.
    """

    def setUp(self):
        """Create two-class epoched data with class-specific offsets."""
        # create noisy data [40 epochs, 100 samples, 64 channels] with
        # values 0..1
        dat = np.random.uniform(size=(40, 100, 64))
        # the epochs alternate between class 1 (even) and class 0 (odd)
        # class 1 (even epochs): add 1 in interval 40..80
        # class 0 (odd epochs):  add 1 in interval 20..60
        #
        #        * .* .
        #
        # .* .* .      * .*
        # ------------------>
        # 0  20 40 60 80 100
        dat[::2, 40:80, :] += 1
        dat[1::2, 20:60, :] += 1
        time = np.arange(dat.shape[1])
        classes = np.zeros(dat.shape[0])
        classes[::2] = 1
        chans = np.arange(64)
        self.dat = Data(dat, [classes, time, chans], ['class', 'time', 'channel'], ['#', 'ms', '#'])
        self.dat.class_names = 'one', 'two'

    def test_calculate_signed_r_square(self):
        """Calculating signed r**2."""
        dat = calculate_signed_r_square(self.dat)
        # the class axis is consumed, so the result has one dimension less
        self.assertEqual(dat.ndim + 1, self.dat.data.ndim)
        # average over channels (one could also take just one channel)
        dat = dat.mean(axis=1)
        # check the intervals
        self.assertTrue(np.all(dat[0:20] < .2))
        self.assertTrue(np.all(dat[20:40] > .5))
        self.assertTrue(np.all(dat[40:60] < .2))
        # 60..80 mirrors 20..40 with the classes swapped, so the signed
        # value must be clearly negative; the former ``< .5`` bound was
        # satisfied by any negative or near-zero value and checked nothing
        self.assertTrue(np.all(dat[60:80] < -.5))
        self.assertTrue(np.all(dat[80:100] < .2))

    def test_calculate_signed_r_square_min_max(self):
        """Min and max values must be in [-1, 1]."""
        dat = calculate_signed_r_square(self.dat)
        self.assertTrue(-1 <= np.min(dat) <= 1)
        self.assertTrue(-1 <= np.max(dat) <= 1)

    def test_calculate_signed_r_square_with_cnt(self):
        """Data without ``class_names`` must trigger an AssertionError."""
        del self.dat.class_names
        with self.assertRaises(AssertionError):
            calculate_signed_r_square(self.dat)

    def test_calculate_signed_r_square_swapaxes(self):
        """calculate_signed_r_square must work with nonstandard classaxis."""
        dat = calculate_signed_r_square(swapaxes(self.dat, 0, 2), classaxis=2)
        # the class-axis just disappears during
        # calculate_signed_r_square, so axis 2 becomes axis 1
        dat = dat.swapaxes(0, 1)
        dat2 = calculate_signed_r_square(self.dat)
        np.testing.assert_array_equal(dat, dat2)

    def test_calculate_signed_r_square_copy(self):
        """calculate_signed_r_square must not modify argument."""
        cpy = self.dat.copy()
        calculate_signed_r_square(self.dat)
        self.assertEqual(self.dat, cpy)
class TestApplyCSP(unittest.TestCase):
    """Tests for ``apply_csp`` using synthetic epoched data and a random filter."""

    EPOCHS = 50
    SAMPLES = 100
    SOURCES = 2
    CHANNELS = 10

    def setUp(self):
        """Mix two random sources into channels and create a random filter."""
        # Random sources (EPOCHS x SAMPLES x SOURCES); source 0 is scaled
        # by 5 in even epochs, source 1 is scaled by 5 in odd epochs.
        self.s = np.random.randn(self.EPOCHS, self.SAMPLES, self.SOURCES)
        self.s[::2, :, 0] *= 5
        self.s[1::2, :, 1] *= 5
        # Mixing matrix turning sources into channels: X = As + noise
        self.A = np.random.randn(self.CHANNELS, self.SOURCES)
        # The mixed signal: EPOCHS x SAMPLES x CHANNELS
        self.X = np.empty((self.EPOCHS, self.SAMPLES, self.CHANNELS))
        for epoch_idx in range(self.EPOCHS):
            self.X[epoch_idx] = np.dot(self.A, self.s[epoch_idx].T).T
        self.X += 0.01 * np.random.randn(self.EPOCHS, self.SAMPLES,
                                         self.CHANNELS)

        # Class labels: 0 for even epochs, 1 for odd epochs.
        n_epochs, n_samples, n_channels = self.X.shape
        labels = np.array([1 for _ in range(n_epochs)])
        labels[0::2] = 0
        self.epo = Data(self.X,
                        axes=[labels, np.arange(n_samples), np.arange(n_channels)],
                        names=['class', 'time', 'channel'],
                        units=['#', 'ms', '#'])
        self.epo.class_names = ['foo', 'bar']

        self.filter = np.random.random((self.CHANNELS, self.CHANNELS))

    def test_apply_csp(self):
        """apply_csp projects the channels onto the first and last filter column."""
        result = apply_csp(self.epo, self.filter)
        in_shape = self.epo.data.shape
        out_shape = result.data.shape
        # only the channel axis shrinks (to 2), the other axes keep
        # their length
        self.assertEqual(in_shape[0], out_shape[0])
        self.assertEqual(in_shape[1], out_shape[1])
        self.assertEqual(2, out_shape[2])
        # the channel axis gets a new name
        self.assertEqual(result.names[-1], 'CSP Channel')
        # the data must equal the per-epoch dot product with the first
        # and last filter column
        expected = np.array([
            np.dot(self.epo.data[idx], self.filter[:, [0, -1]])
            for idx in range(in_shape[0])
        ])
        np.testing.assert_array_equal(expected, result.data)

    # Disabled test, kept for reference:
    #def test_apply_csp_swapaxes(self):
    #    """apply_csp must work with nonstandard classaxis."""
    #    dat = apply_csp(swapaxes(self.epo, 0, 1), self.filter.T, classaxis=1)
    #    #dat = swapaxes(dat, 0, 1)
    #    print dat.data.shape
    #    dat2 = apply_csp(self.epo, self.filter)
    #    print
    #    print dat2.data.shape
    #    self.assertEqual(dat, dat2)

    def test_apply_csp_copy(self):
        """apply_csp must leave its argument untouched."""
        snapshot = self.epo.copy()
        apply_csp(self.epo, self.filter)
        self.assertEqual(self.epo, snapshot)
class TestApplySpatialFilter(unittest.TestCase):
    """Tests for ``apply_spatial_filter`` on continuous and epoched data."""

    def setUp(self):
        """Create a cnt fixture, an epo fixture stacked from it, and a filter."""
        samples = np.random.randn(SAMPLES, CHANS)
        samples[:, 1] += 0.5 * samples[:, 0]
        samples[:, 2] -= 0.5 * samples[:, 0]
        timeline = np.arange(SAMPLES)
        chan_names = ['chan{i}'.format(i=i) for i in range(CHANS)]
        self.cnt = Data(samples, [timeline, chan_names],
                        ['time', 'channels'], ['ms', '#'])

        # the epo fixture repeats the continuous data EPOS times
        stacked = np.array([samples for _ in range(EPOS)])
        class_labels = ['class{i}'.format(i=i) for i in range(EPOS)]
        self.epo = Data(stacked, [class_labels, timeline, chan_names],
                        ['class', 'time', 'channels'], ['#', 'ms', '#'])

        # a small hand-written 3x3 spatial filter
        self.w = np.array([[0, 0.5, 1],
                           [-1, 0.5, 0],
                           [1, 0.5, 0]])

    def test_shape(self):
        """Filtering must not change the shape of the data."""
        filtered_cnt = apply_spatial_filter(self.cnt, self.w)
        filtered_epo = apply_spatial_filter(self.epo, self.w)
        self.assertEqual(self.cnt.data.shape, filtered_cnt.data.shape)
        self.assertEqual(self.epo.data.shape, filtered_epo.data.shape)

    def test_spatial_filter_cnt(self):
        """Spatial filtering must work with continuous data."""
        filtered = apply_spatial_filter(self.cnt, self.w).data
        raw = self.cnt.data
        # output channel 0: input channel 2 minus input channel 1
        np.testing.assert_array_equal(filtered[:, 0], raw[:, 2] - raw[:, 1])
        # output channel 1: half the sum over all input channels
        np.testing.assert_array_equal(filtered[:, 1],
                                      0.5 * np.sum(raw, axis=-1))
        # output channel 2: input channel 0
        np.testing.assert_array_equal(filtered[:, 2], raw[:, 0])

    def test_spatial_filter_epo(self):
        """Every filtered epoch must equal the filtered continuous data."""
        filtered_cnt = apply_spatial_filter(self.cnt, self.w)
        filtered_epo = apply_spatial_filter(self.epo, self.w)
        for epoch_idx in range(EPOS):
            np.testing.assert_array_equal(filtered_epo.data[epoch_idx, ...],
                                          filtered_cnt.data)

    def test_prefix_and_postfix(self):
        """Passing both prefix and postfix must raise a ValueError."""
        with self.assertRaises(ValueError):
            apply_spatial_filter(self.cnt, self.w, prefix='foo', postfix='bar')

    def test_prefix(self):
        """With a prefix the channels are named prefix + running index."""
        filtered = apply_spatial_filter(self.cnt, self.w, prefix='foo')
        expected_names = ['foo' + str(i) for i in range(CHANS)]
        self.assertEqual(filtered.axes[-1], expected_names)

    def test_prefix_w_wrong_type(self):
        """A prefix that is neither None nor str must raise a TypeError."""
        with self.assertRaises(TypeError):
            apply_spatial_filter(self.cnt, self.w, prefix=1)

    def test_postfix(self):
        """With a postfix the channels are named old name + postfix."""
        filtered = apply_spatial_filter(self.cnt, self.w, postfix='foo')
        expected_names = [name + 'foo' for name in self.cnt.axes[-1]]
        self.assertEqual(filtered.axes[-1], expected_names)

    def test_postfix_w_wrong_type(self):
        """A postfix that is neither None nor str must raise a TypeError."""
        with self.assertRaises(TypeError):
            apply_spatial_filter(self.cnt, self.w, postfix=1)

    def test_apply_spatial_filter_copy(self):
        """apply_spatial_filter must leave its argument untouched."""
        snapshot = self.cnt.copy()
        apply_spatial_filter(self.cnt, self.w)
        self.assertEqual(self.cnt, snapshot)

    def test_apply_spatial_filter_swapaxes(self):
        """apply_spatial_filter must work with a non-standard chanaxis."""
        # filter with the channels on axis 1, then move them back
        swapped = swapaxes(self.epo, 1, -1)
        via_swapped = swapaxes(
            apply_spatial_filter(swapped, self.w, chanaxis=1), 1, -1)
        direct = apply_spatial_filter(self.epo, self.w)
        self.assertEqual(via_swapped, direct)
Exemple #47
0
class TestJumpingMeans(unittest.TestCase):
    """Tests for ``jumping_means``."""

    def setUp(self):
        """Create four epochs built from a 1/2/3 step signal scaled by 0, 1, 2, 0."""
        ones = np.ones((10, 5))
        # step signal: ten samples of 1s, then 2s, then 3s
        cnt = np.concatenate([ones, ones * 2, ones * 3], axis=0)
        # four epochs: the step signal multiplied by 0, 1, 2 and 0
        epochs = np.array([cnt * factor for factor in (0, 1, 2, 0)])
        axes = [[0, 1, 2, 1],
                np.linspace(0, 3000, 30, endpoint=False),
                ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']]
        self.dat = Data(epochs, axes,
                        ['class', 'time', 'channel'], ['#', 'ms', '#'])

    def test_jumping_means(self):
        """The time axis collapses to one mean value per interval."""
        # several intervals
        result = jumping_means(self.dat, [[0, 1000], [1000, 2000], [2000, 3000]])
        expected_shape = list(self.dat.data.shape)
        expected_shape[1] = 3
        self.assertEqual(list(result.data.shape), expected_shape)
        # expected means of channel 0, one row per epoch and one
        # column per interval
        expected_means = [[0, 0, 0],
                          [1, 2, 3],
                          [2, 4, 6],
                          [0, 0, 0]]
        for epoch_idx, row in enumerate(expected_means):
            for ival_idx, mean in enumerate(row):
                self.assertEqual(result.data[epoch_idx, ival_idx, 0], mean)
        # a single interval
        result = jumping_means(self.dat, [[0, 1000]])
        expected_shape = list(self.dat.data.shape)
        expected_shape[1] = 1
        self.assertEqual(list(result.data.shape), expected_shape)
        # expected mean of channel 0 for each of the four epochs
        for epoch_idx, mean in enumerate([0, 1, 2, 0]):
            self.assertEqual(result.data[epoch_idx, 0, 0], mean)

    def test_jumping_means_with_cnt(self):
        """jumping_means must work with continuous data."""
        # turn the second epoch into a continuous Data object by dropping
        # the leading class axis
        cnt = self.dat.copy(data=self.dat.data[1],
                            axes=self.dat.axes[1:],
                            names=self.dat.names[1:],
                            units=self.dat.units[1:])
        result = jumping_means(cnt, [[0, 1000], [1000, 2000]])
        self.assertEqual(result.data[0, 0], 1)
        self.assertEqual(result.data[1, 0], 2)

    def test_jumping_means_swapaxes(self):
        """jumping_means must work with a non-standard timeaxis."""
        swapped = swapaxes(self.dat, 1, 2)
        via_swapped = swapaxes(jumping_means(swapped, [[0, 1000]], timeaxis=2), 1, 2)
        direct = jumping_means(self.dat, [[0, 1000]])
        self.assertEqual(via_swapped, direct)

    def test_jumping_means_copy(self):
        """jumping_means must leave its argument untouched."""
        snapshot = self.dat.copy()
        jumping_means(self.dat, [[0, 1000]])
        self.assertEqual(self.dat, snapshot)
Exemple #48
0
class TestSelectClasses(unittest.TestCase):
    """Tests for ``calculate_signed_r_square``.

    NOTE(review): the class name says "SelectClasses" although every test
    here exercises ``calculate_signed_r_square``; the name is kept so that
    existing test selectors continue to work.
    """

    def setUp(self):
        """Create two-class epoched data with class-specific offsets."""
        # create noisy data [40 epochs, 100 samples, 64 channels] with
        # values 0..1
        dat = np.random.uniform(size=(40, 100, 64))
        # the epochs alternate between class 1 (even) and class 0 (odd)
        # class 1 (even epochs): add 1 in interval 40..80
        # class 0 (odd epochs):  add 1 in interval 20..60
        #
        #        * .* .
        #
        # .* .* .      * .*
        # ------------------>
        # 0  20 40 60 80 100
        dat[::2, 40:80, :] += 1
        dat[1::2, 20:60, :] += 1
        time = np.arange(dat.shape[1])
        classes = np.zeros(dat.shape[0])
        classes[::2] = 1
        chans = np.arange(64)
        self.dat = Data(dat, [classes, time, chans],
                        ['class', 'time', 'channel'], ['#', 'ms', '#'])
        self.dat.class_names = 'one', 'two'

    def test_calculate_signed_r_square(self):
        """Calculating signed r**2."""
        dat = calculate_signed_r_square(self.dat)
        # the class axis is consumed, so the result has one dimension less
        self.assertEqual(dat.ndim + 1, self.dat.data.ndim)
        # average over channels (one could also take just one channel)
        dat = dat.mean(axis=1)
        # check the intervals
        self.assertTrue(np.all(dat[0:20] < .2))
        self.assertTrue(np.all(dat[20:40] > .5))
        self.assertTrue(np.all(dat[40:60] < .2))
        # 60..80 mirrors 20..40 with the classes swapped, so the signed
        # value must be clearly negative; the former ``< .5`` bound was
        # satisfied by any negative or near-zero value and checked nothing
        self.assertTrue(np.all(dat[60:80] < -.5))
        self.assertTrue(np.all(dat[80:100] < .2))

    def test_calculate_signed_r_square_min_max(self):
        """Min and max values must be in [-1, 1]."""
        dat = calculate_signed_r_square(self.dat)
        self.assertTrue(-1 <= np.min(dat) <= 1)
        self.assertTrue(-1 <= np.max(dat) <= 1)

    def test_calculate_signed_r_square_with_cnt(self):
        """Data without ``class_names`` must trigger an AssertionError."""
        del self.dat.class_names
        with self.assertRaises(AssertionError):
            calculate_signed_r_square(self.dat)

    def test_calculate_signed_r_square_swapaxes(self):
        """calculate_signed_r_square must work with nonstandard classaxis."""
        dat = calculate_signed_r_square(swapaxes(self.dat, 0, 2), classaxis=2)
        # the class-axis just disappears during
        # calculate_signed_r_square, so axis 2 becomes axis 1
        dat = dat.swapaxes(0, 1)
        dat2 = calculate_signed_r_square(self.dat)
        # this used to work with numpy 1.8, but with 1.9 the arrays
        # differ slightly after e-15, I don't see why this change
        # happened, but it is barely noticeable so we check for almost
        # equality
        # np.testing.assert_array_equal(dat, dat2)
        np.testing.assert_array_almost_equal(dat, dat2)

    def test_calculate_signed_r_square_copy(self):
        """calculate_signed_r_square must not modify argument."""
        cpy = self.dat.copy()
        calculate_signed_r_square(self.dat)
        self.assertEqual(self.dat, cpy)
class TestCalculateSpoc(unittest.TestCase):
    """Tests for ``calculate_spoc`` on synthetic variance-modulated sources."""

    EPOCHS = 50
    SAMPLES = 100
    SOURCES = 2
    CHANNELS = 10

    def setUp(self):
        """Mix variance-modulated sources into channels and wrap them as epo.

        The per-epoch modulation of the first source is used as the first
        axis ('target_variable') of the resulting ``Data`` object.
        """
        # independent per-epoch modulation factors for every source, each
        # column scaled to unit standard deviation; the first source is
        # the target source
        z = np.abs(np.random.randn(self.EPOCHS, self.SOURCES))
        for src in range(self.SOURCES):
            z[:, src] /= z[:, src].std()
        # random sources, every epoch scaled by its modulation factor
        self.s = np.random.randn(self.EPOCHS, self.SAMPLES, self.SOURCES)
        self.s *= z[:, np.newaxis, :]
        # mixing matrix turning sources into channels: X = As + noise
        self.A = np.random.randn(self.CHANNELS, self.SOURCES)
        # the mixed signal: EPOCHS x SAMPLES x CHANNELS
        self.X = np.empty((self.EPOCHS, self.SAMPLES, self.CHANNELS))
        for epoch_idx in range(self.EPOCHS):
            self.X[epoch_idx] = np.dot(self.A, self.s[epoch_idx].T).T
        self.X += 0.01 * np.random.randn(self.EPOCHS, self.SAMPLES,
                                         self.CHANNELS)
        # wrap as epo, with the target modulation as the first axis
        n_samples, n_channels = self.X.shape[1], self.X.shape[2]
        self.epo = Data(self.X,
                        axes=[z[:, 0], np.arange(n_samples), np.arange(n_channels)],
                        names=['target_variable', 'time', 'channel'],
                        units=['#', 'ms', '#'])

    def test_d(self):
        """The lambdas must be sorted in descending order, the first one > 0."""
        W, A_est, d = calculate_spoc(self.epo)
        self.assertTrue(d[0] > 0)
        descending = np.sort(d)[::-1]
        self.assertTrue(np.all(d == descending))

    def test_A(self):
        """The first estimated pattern must be almost equal to the true one."""
        W, A_est, d = calculate_spoc(self.epo)
        # A and A_est may differ in scale and sign, so both first columns
        # are divided by their entry of largest magnitude before comparing
        peak = np.argmax(np.abs(A_est[:, 0]))
        A_est[:, 0] /= A_est[peak, 0]
        peak = np.argmax(np.abs(self.A[:, 0]))
        self.A[:, 0] /= self.A[peak, 0]
        # the mean absolute difference of the first column must be small
        epsilon = 0.01
        mean_abs_diff = np.sum(np.abs(self.A[:, 0] - A_est[:, 0])) / self.A.shape[0]
        self.assertTrue(mean_abs_diff < epsilon)

    def test_s(self):
        """The first estimated source must be almost equal to the true one."""
        W, A_est, d = calculate_spoc(self.epo)
        # projecting X onto the first filter gives the estimate of the
        # target source
        s_est = np.empty(self.s.shape[:2])
        for epoch_idx in range(self.EPOCHS):
            s_est[epoch_idx] = np.dot(self.X[epoch_idx], W[:, 0])
        s_true = self.s[..., 0]
        epsilon = 0.001

        # bring both to unit standard deviation (in place)
        s_true /= s_true.std()
        s_est /= s_est.std()

        # remove the sign ambiguity
        s_true = np.abs(s_true)
        s_est = np.abs(s_est)

        # NOTE(review): the differences are summed without taking their
        # absolute value, so positive and negative deviations can cancel
        # and any negative sum passes -- confirm whether np.abs was intended.
        n_values = self.s.shape[0] * self.s.shape[1]
        diff = np.sum(s_true - s_est) / n_values
        self.assertTrue(diff < epsilon)

    # Disabled test, kept for reference:
    #def test_calculate_signed_r_square_swapaxes(self):
    #    """caluclate_r_square must work with nonstandard classaxis."""
    #    dat = calculate_signed_r_square(swapaxes(self.dat, 0, 2), classaxis=2)
    #    # the class-axis just dissapears during
    #    # calculate_signed_r_square, so axis 2 becomes axis 1
    #    dat = dat.swapaxes(0, 1)
    #    dat2 = calculate_signed_r_square(self.dat)
    #    np.testing.assert_array_equal(dat, dat2)

    def test_calculate_signed_r_square_copy(self):
        """calculate_spoc must not modify its argument.

        The method name still refers to signed r square; it is kept so
        that existing test selectors continue to work.
        """
        snapshot = self.epo.copy()
        calculate_spoc(self.epo)
        self.assertEqual(self.epo, snapshot)
Exemple #50
0
class TestAppend(unittest.TestCase):
    """Tests for ``append``."""

    def setUp(self):
        """Create four epochs built from a 1/2/3 step signal scaled by 0, 1, 2, 0."""
        ones = np.ones((10, 5))
        # cnt with 1, 2, 3
        cnt = np.append(ones, ones * 2, axis=0)
        cnt = np.append(cnt, ones * 3, axis=0)
        channels = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        time = np.linspace(0, 3000, 30, endpoint=False)
        classes = [0, 1, 2, 1]
        # four epochs: the cnt multiplied by 0, 1, 2 and 0
        data = np.array([cnt * 0, cnt * 1, cnt * 2, cnt * 0])
        self.dat = Data(data, [classes, time, channels],
                        ['class', 'time', 'channel'], ['#', 'ms', '#'])

    def test_append(self):
        """Appending doubles the first axis of both data and axes."""
        dat = append(self.dat, self.dat)
        self.assertEqual(dat.data.shape[0], 2 * self.dat.data.shape[0])
        self.assertEqual(len(dat.axes[0]), 2 * len(self.dat.axes[0]))
        np.testing.assert_array_equal(
            dat.data, np.concatenate([self.dat.data, self.dat.data], axis=0))
        np.testing.assert_array_equal(
            dat.axes[0], np.concatenate([self.dat.axes[0], self.dat.axes[0]]))

    def test_append_with_extra(self):
        """append with extra must work with list and ndarrays."""
        self.dat.a = list(range(10))
        self.dat.b = np.arange(10)
        dat = append(self.dat, self.dat, extra=['a', 'b'])
        self.assertEqual(dat.a, list(range(10)) + list(range(10)))
        np.testing.assert_array_equal(
            dat.b, np.concatenate([np.arange(10), np.arange(10)]))

    def test_append_with_different_extra_types(self):
        """append must throw a TypeError if extra-types don't match."""
        a = self.dat.copy()
        b = self.dat.copy()
        a.a = list(range(10))
        b.a = np.arange(10)
        with self.assertRaises(TypeError):
            append(a, b, extra=['a'])

    def test_append_with_wrong_known_attributes(self):
        """append must raise an AssertionError on mismatching attributes."""
        # .data dimensions must match
        a = self.dat.copy()
        a.data = a.data[np.newaxis, ...]
        with self.assertRaises(AssertionError):
            append(self.dat, a)
        # .data.shape must match for all axes except the ones being
        # appended
        a = self.dat.copy()
        a.data = a.data[..., :-1]
        with self.assertRaises(AssertionError):
            append(self.dat, a)
        # .axes must be equal for all axes except the ones being
        # appended
        a = self.dat.copy()
        a.axes[-1][0] = 'foo'
        with self.assertRaises(AssertionError):
            append(self.dat, a)
        # names must be equal
        a = self.dat.copy()
        a.names[0] = 'foo'
        with self.assertRaises(AssertionError):
            append(self.dat, a)
        # units must be equal
        a = self.dat.copy()
        a.units[0] = 'foo'
        with self.assertRaises(AssertionError):
            append(self.dat, a)

    def test_append_with_unsupported_extra_types(self):
        """append must throw a TypeError if extra-type is unsupported."""
        self.dat.a = {'foo': 'bar'}
        with self.assertRaises(TypeError):
            append(self.dat, self.dat, extra=['a'])

    def test_append_with_cnt(self):
        """append must work with cnt argument."""
        # turn the second epoch into a continuous Data object by dropping
        # the leading class axis
        data = self.dat.data[1]
        axes = self.dat.axes[1:]
        names = self.dat.names[1:]
        units = self.dat.units[1:]
        dat = self.dat.copy(data=data, axes=axes, names=names, units=units)
        dat2 = append(dat, dat)
        self.assertEqual(dat2.data.shape[0], 2 * dat.data.shape[0])
        self.assertEqual(len(dat2.axes[0]), 2 * len(dat.axes[0]))
        np.testing.assert_array_equal(
            dat2.data, np.concatenate([dat.data, dat.data], axis=0))
        np.testing.assert_array_equal(
            dat2.axes[0], np.concatenate([dat.axes[0], dat.axes[0]], axis=0))

    def test_append_with_negative_axis(self):
        """Append must work correctly with a negative axis."""
        dat2 = self.dat.copy()
        # drop the last epoch from the data AND from the matching axis 0;
        # slicing axes[2] (the channel names) here would leave four labels
        # for three epochs and put channel names on the class axis
        dat2.data = dat2.data[:-1, ...]
        dat2.axes[0] = dat2.axes[0][:-1]
        a = append(self.dat, dat2, axis=0)
        b = append(self.dat, dat2, axis=-3)
        self.assertEqual(a, b)

    def test_append_swapaxes(self):
        """append must work with a nonstandard axis."""
        dat = append(swapaxes(self.dat, 0, 2),
                     swapaxes(self.dat, 0, 2),
                     axis=2)
        dat = swapaxes(dat, 0, 2)
        dat2 = append(self.dat, self.dat)
        self.assertEqual(dat, dat2)

    def test_append_copy(self):
        """append must not modify argument."""
        cpy = self.dat.copy()
        append(self.dat, self.dat)
        self.assertEqual(self.dat, cpy)
Exemple #51
0
class TestCalculateCCA(unittest.TestCase):

    SAMPLES = 1000
    CHANNELS_X = 10
    CHANNELS_Y = 5
    NOISE_LEVEL = 0.1

    def setUp(self):
        # X is a random mixture matrix of random variables
        Sx = randn(self.SAMPLES, self.CHANNELS_X)
        Ax = randn(self.CHANNELS_X, self.CHANNELS_X)
        X = np.dot(Sx, Ax)
        # Y is a random mixture matrix of random variables except the
        # first component
        Sy = randn(self.SAMPLES, self.CHANNELS_Y)
        Sy[:, 0] = Sx[:, 0] + self.NOISE_LEVEL * randn(self.SAMPLES)
        Ay = randn(self.CHANNELS_Y, self.CHANNELS_Y)
        Y = np.dot(Sy, Ay)
        # generate Data object
        axes_x = [np.arange(X.shape[0]), np.arange(X.shape[1])]
        axes_y = [np.arange(Y.shape[0]), np.arange(Y.shape[1])]
        self.dat_x = Data(X, axes=axes_x, names=['time', 'channel'], units=['ms', '#'])
        self.dat_y = Data(Y, axes=axes_y, names=['time', 'channel'], units=['ms', '#'])

    def test_rho(self):
        """Test if the canonical correlation coefficient almost equals 1."""
        rho, w_x, w_y = calculate_cca(self.dat_x, self.dat_y)
        self.assertAlmostEqual(rho, 1.0, delta=0.01)

    def test_diff_between_canonical_variables(self):
        """Test if the scaled canonical variables are almost same."""
        rho, w_x, w_y = calculate_cca(self.dat_x, self.dat_y)
        cv_x = apply_spatial_filter(self.dat_x, w_x)
        cv_y = apply_spatial_filter(self.dat_y, w_y)

        def scale(x):
            tmp = x.data - x.data.mean()
            return tmp / tmp[np.argmax(np.abs(tmp))]

        diff = scale(cv_x) - scale(cv_y)
        diff = np.sum(np.abs(diff)) / self.SAMPLES
        self.assertTrue(diff < 0.1)

    def test_raise_error_with_non_continuous_data(self):
        """Raise error if ``dat_x`` is not continuous Data object."""
        dat = Data(randn(2, self.SAMPLES, self.CHANNELS_X),
                   axes=[[0, 1], self.dat_x.axes[0], self.dat_x.axes[1]],
                   names=['class', 'time', 'channel'],
                   units=['#', 'ms', '#'])
        with self.assertRaises(AssertionError):
            calculate_cca(dat, self.dat_x)

    def test_raise_error_with_different_length_data(self):
        """Raise error if the length of ``dat_x`` and ``dat_y`` is different."""
        dat = append(self.dat_x, self.dat_x)
        with self.assertRaises(AssertionError):
            calculate_cca(dat, self.dat_y)

    def test_calculate_cca_swapaxes(self):
        """caluclate_cca must work with nonstandard timeaxis."""
        res1 = calculate_cca(swapaxes(self.dat_x, 0, 1), swapaxes(self.dat_y, 0, 1), timeaxis=1)
        res2 = calculate_cca(self.dat_x, self.dat_y)
        np.testing.assert_array_equal(res1[0], res2[0])
        np.testing.assert_array_equal(res1[1], res2[1])
        np.testing.assert_array_equal(res1[2], res2[2])

    def test_calculate_cca_copy(self):
        """caluclate_cca must not modify argument."""
        cpy_x = self.dat_x.copy()
        cpy_y = self.dat_y.copy()
        calculate_cca(self.dat_x, self.dat_y)
        self.assertEqual(self.dat_x, cpy_x)
        self.assertEqual(self.dat_y, cpy_y)
class TestSubsample(unittest.TestCase):
    def setUp(self):
        raw = np.arange(2000).reshape(-1, 5)
        channels = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        time = np.linspace(0, 4000, 400, endpoint=False)
        fs = 100
        marker = [[100, 'foo'], [200, 'bar']]
        self.dat = Data(raw, [time, channels], ['time', 'channels'],
                        ['ms', '#'])
        self.dat.fs = fs
        self.dat.markers = marker

    def test_subsampling(self):
        """Subsampling to 10Hz."""
        dat = subsample(self.dat, 10)
        # check if the new fs is correct
        self.assertEqual(dat.fs, 10.)
        # check if data and axes have the same length
        self.assertEqual(len(dat.axes[-1]), dat.data.shape[-1])
        # no channels must have been deleted
        np.testing.assert_array_equal(self.dat.axes[-1], dat.axes[-1])
        self.assertEqual(self.dat.data.shape[-1], dat.data.shape[-1])
        # markers must not have been modified
        self.assertEqual(self.dat.markers, dat.markers)
        # no marker must have been deleted
        self.assertEqual(len(self.dat.markers), len(dat.markers))
        # check the actual data
        # after subsampling, data should look like:
        # [[0,   1,  2,  3,  4,  5]
        #  [50, 51, 52, 53, 54, 55]
        #  [...]]
        # so the first column of the resampled data should be all
        # multiples of 50.
        zeros = dat.data[:, 0] % 50
        self.assertFalse(np.any(zeros))

    def test_subsample_with_epo(self):
        """subsample must work with epoched data."""
        data = np.array([self.dat.data, self.dat.data, self.dat.data])
        axes = [np.arange(3), self.dat.axes[0], self.dat.axes[1]]
        names = ['class', 'time', 'channel']
        units = ['#', 'ms', '#']
        dat = self.dat.copy(data=data, axes=axes, names=names, units=units)
        dat = subsample(dat, 10)
        self.assertEqual(dat.fs, 10)
        self.assertEqual(dat.data.ndim, 3)
        self.assertEqual(len(dat.axes[1]), dat.data.shape[1])
        self.assertEqual(dat.data.shape[1], self.dat.data.shape[0] / 10)

    def test_whole_number_divisor_check(self):
        """Freq must be a whole number divisor of dat.fs"""
        with self.assertRaises(AssertionError):
            subsample(self.dat, 33)
        with self.assertRaises(AssertionError):
            subsample(self.dat, 101)

    def test_has_fs_check(self):
        """subsample must raise an exception if .fs attribute is not found."""
        with self.assertRaises(AssertionError):
            del (self.dat.fs)
            subsample(self.dat, 10)

    def test_axes_and_data_have_same_len_check(self):
        """subsample must raise an error if the timeaxis and data have not the same lengh."""
        with self.assertRaises(AssertionError):
            self.dat.axes[-2] = self.dat.axes[-2][1:]
            subsample(self.dat, 10)

    def test_subsample_copy(self):
        """Subsample must not modify argument."""
        cpy = self.dat.copy()
        subsample(self.dat, 10)
        self.assertEqual(cpy, self.dat)

    def test_subsample_swapaxes(self):
        """subsample must work with nonstandard timeaxis."""
        dat = subsample(swapaxes(self.dat, 0, 1), 10, timeaxis=1)
        dat = swapaxes(dat, 0, 1)
        dat2 = subsample(self.dat, 10)
        self.assertEqual(dat, dat2)
Exemple #53
0
class TestCalculateCCA(unittest.TestCase):

    SAMPLES = 1000
    CHANNELS_X = 10
    CHANNELS_Y = 5
    NOISE_LEVEL = 0.1

    def setUp(self):
        # X is a random mixture matrix of random variables
        Sx = randn(self.SAMPLES, self.CHANNELS_X)
        Ax = randn(self.CHANNELS_X, self.CHANNELS_X)
        self.X = np.dot(Sx, Ax)
        # Y is a random mixture matrix of random variables except the
        # first component
        Sy = randn(self.SAMPLES, self.CHANNELS_Y)
        Sy[:, 0] = Sx[:, 0] + self.NOISE_LEVEL * randn(self.SAMPLES)
        Ay = randn(self.CHANNELS_Y, self.CHANNELS_Y)
        self.Y = np.dot(Sy, Ay)
        # generate Data object
        axes_x = [np.arange(self.X.shape[0]), np.arange(self.X.shape[1])]
        axes_y = [np.arange(self.Y.shape[0]), np.arange(self.Y.shape[1])]
        self.dat_x = Data(self.X, axes=axes_x, names=['time', 'channel'], units=['ms', '#'])
        self.dat_y = Data(self.Y, axes=axes_y, names=['time', 'channel'], units=['ms', '#'])

    def test_rho(self):
        """Test if the canonical correlation coefficient almost equals 1."""
        rho, w_x, w_y = calculate_cca(self.dat_x, self.dat_y)
        self.assertAlmostEqual(rho, 1.0, delta=0.01)

    def test_diff_between_canonical_variables(self):
        """Test if the scaled canonical variables are almost same."""
        rho, w_x, w_y = calculate_cca(self.dat_x, self.dat_y)
        cv_x = np.dot(self.X, w_x)
        cv_y = np.dot(self.Y, w_y)

        def scale(x):
            tmp = x - x.mean()
            return tmp / tmp[np.argmax(np.abs(tmp))]

        diff = scale(cv_x) - scale(cv_y)
        diff = np.sum(np.abs(diff)) / self.SAMPLES
        self.assertTrue(diff < 0.1)

    def test_raise_error_with_non_continuous_data(self):
        """Raise error if ``dat_x`` is not continuous Data object."""
        dat = Data(randn(2, self.SAMPLES, self.CHANNELS_X),
                   axes=[[0, 1], self.dat_x.axes[0], self.dat_x.axes[1]],
                   names=['class', 'time', 'channel'],
                   units=['#', 'ms', '#'])
        with self.assertRaises(AssertionError):
            calculate_cca(dat, self.dat_x)

    def test_raise_error_with_different_length_data(self):
        """Raise error if the length of ``dat_x`` and ``dat_y`` is different."""
        dat = append(self.dat_x, self.dat_x)
        with self.assertRaises(AssertionError):
            calculate_cca(dat, self.dat_y)

    def test_calculate_cca_swapaxes(self):
        """caluclate_cca must work with nonstandard timeaxis."""
        res1 = calculate_cca(swapaxes(self.dat_x, 0, 1), swapaxes(self.dat_y, 0, 1), timeaxis=1)
        res2 = calculate_cca(self.dat_x, self.dat_y)
        np.testing.assert_array_equal(res1[0], res2[0])
        np.testing.assert_array_equal(res1[1], res2[1])
        np.testing.assert_array_equal(res1[2], res2[2])

    def test_calculate_cca_copy(self):
        """caluclate_cca must not modify argument."""
        cpy_x = self.dat_x.copy()
        cpy_y = self.dat_y.copy()
        calculate_cca(self.dat_x, self.dat_y)
        self.assertEqual(self.dat_x, cpy_x)
        self.assertEqual(self.dat_y, cpy_y)
Exemple #54
0
class TestAppend(unittest.TestCase):

    def setUp(self):
        ones = np.ones((10, 5))
        # cnt with 1, 2, 3
        cnt = np.append(ones, ones*2, axis=0)
        cnt = np.append(cnt, ones*3, axis=0)
        channels = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        time = np.linspace(0, 3000, 30, endpoint=False)
        classes = [0, 1, 2, 1]
        # four cnts: 1s, -1s, and 0s
        data = np.array([cnt * 0, cnt * 1, cnt * 2, cnt * 0])
        self.dat = Data(data, [classes, time, channels], ['class', 'time', 'channel'], ['#', 'ms', '#'])

    def test_append(self):
        """Append."""
        dat = append(self.dat, self.dat)
        self.assertEqual(dat.data.shape[0], 2*self.dat.data.shape[0])
        self.assertEqual(len(dat.axes[0]), 2*len(self.dat.axes[0]))
        np.testing.assert_array_equal(dat.data, np.concatenate([self.dat.data, self.dat.data], axis=0))
        np.testing.assert_array_equal(dat.axes[0], np.concatenate([self.dat.axes[0], self.dat.axes[0]]))

    def test_append_with_extra(self):
        """append with extra must work with list and ndarrays."""
        self.dat.a = range(10)
        self.dat.b = np.arange(10)
        dat = append(self.dat, self.dat, extra=['a', 'b'])
        self.assertEqual(dat.a, range(10) + range(10))
        np.testing.assert_array_equal(dat.b, np.concatenate([np.arange(10), np.arange(10)]))

    def test_append_with_different_extra_types(self):
        """append must throw a TypeError if extra-types don't match."""
        a = self.dat.copy()
        b = self.dat.copy()
        a.a = range(10)
        b.a = np.arange(10)
        with self.assertRaises(TypeError):
            append(a, b, extra=['a'])

    def test_append_with_wrong_known_attributes(self):
        # .data dimensions must match
        a = self.dat.copy()
        a.data = a.data[np.newaxis, ...]
        with self.assertRaises(AssertionError):
            append(self.dat, a)
        # .data.shape must match for all axes except the ones beeing
        # appended
        a = self.dat.copy()
        a.data = a.data[..., :-1]
        with self.assertRaises(AssertionError):
            append(self.dat, a)
        # .axes must be equal for all axes except the ones beeing
        # appended
        a = self.dat.copy()
        a.axes[-1][0] = 'foo'
        with self.assertRaises(AssertionError):
            append(self.dat, a)
        # names must be equal
        a = self.dat.copy()
        a.names[0] = 'foo'
        with self.assertRaises(AssertionError):
            append(self.dat, a)
        # units must be equal
        a = self.dat.copy()
        a.units[0] = 'foo'
        with self.assertRaises(AssertionError):
            append(self.dat, a)

    def test_append_with_unsupported_extra_types(self):
        """append must trhow a TypeError if extra-type is unsupported."""
        self.dat.a = {'foo' : 'bar'}
        with self.assertRaises(TypeError):
            append(self.dat, self.dat, extra=['a'])

    def test_append_with_cnt(self):
        """append must work with cnt argument."""
        data = self.dat.data[1]
        axes = self.dat.axes[1:]
        names = self.dat.names[1:]
        units = self.dat.units[1:]
        dat = self.dat.copy(data=data, axes=axes, names=names, units=units)
        dat2 = append(dat, dat)
        self.assertEqual(dat2.data.shape[0], 2*dat.data.shape[0])
        self.assertEqual(len(dat2.axes[0]), 2*len(dat.axes[0]))
        np.testing.assert_array_equal(dat2.data, np.concatenate([dat.data, dat.data], axis=0))
        np.testing.assert_array_equal(dat2.axes[0], np.concatenate([dat.axes[0], dat.axes[0]], axis=0))

    def test_append_swapaxes(self):
        """append must work with nonstandard timeaxis."""
        dat = append(swapaxes(self.dat, 0, 2), swapaxes(self.dat, 0, 2), axis=2)
        dat = swapaxes(dat, 0, 2)
        dat2 = append(self.dat, self.dat)
        self.assertEqual(dat, dat2)

    def test_append_copy(self):
        """append means must not modify argument."""
        cpy = self.dat.copy()
        append(self.dat, self.dat)
        self.assertEqual(self.dat, cpy)
class TestCalculateSpoc(unittest.TestCase):

    EPOCHS = 50
    SAMPLES = 100
    SOURCES = 2
    CHANNELS = 10

    def setUp(self):
        # generate sources with independent variance modulations, the
        # first source will be the target source
        z = np.abs(np.random.randn(self.EPOCHS, self.SOURCES))
        for i in range(self.SOURCES):
            z[:, i] /= z[:, i].std()
        self.s = np.random.randn(self.EPOCHS, self.SAMPLES, self.SOURCES)
        for i in range(self.SOURCES):
            for j in range(self.EPOCHS):
                self.s[j, :, i] *= z[j, i]
        # the mixmatrix which converts our sources to channels
        # X = As + noise
        self.A = np.random.randn(self.CHANNELS, self.SOURCES)
        # our 'signal' which 50 epochs, 100 samples and 10 channels
        self.X = np.empty((self.EPOCHS, self.SAMPLES, self.CHANNELS))
        for i in range(self.EPOCHS):
            self.X[i] = np.dot(self.A, self.s[i].T).T
        noise = np.random.randn(self.EPOCHS, self.SAMPLES, self.CHANNELS) * 0.01
        self.X += noise
        # convert to epo
        axes = [z[:, 0], np.arange(self.X.shape[1]), np.arange(self.X.shape[2])]
        self.epo = Data(self.X,
                axes=axes,
                names=['target_variable', 'time', 'channel'],
                units=['#', 'ms', '#'])

    def test_d(self):
        """Test if the list of lambdas is reverse-sorted and the first one > 0."""
        W, A_est, d = calculate_spoc(self.epo)
        self.assertTrue(d[0] > 0)
        self.assertTrue(np.all(d == np.sort(d)[::-1]))

    def test_A(self):
        """Test if A_est is elementwise almost equal A."""
        W, A_est, d = calculate_spoc(self.epo)
        # A and A_est can have a different scaling, after normalizing
        # and correcting for sign, the first pattern should be almost
        # equal the source pattern
        idx = np.argmax(np.abs(A_est[:, 0]))
        A_est[:, 0] /= A_est[idx, 0]
        idx = np.argmax(np.abs(self.A[:, 0]))
        self.A[:, 0] /= self.A[idx, 0]
        # check elementwise if A[:, 0] almost A_est[:, 0]
        epsilon = 0.01
        diff = self.A[:, 0] - A_est[:, 0]
        diff = np.abs(diff)
        diff = np.sum(diff) / self.A.shape[0]
        self.assertTrue(diff < epsilon)

    def test_s(self):
        """Test if s_est is elementwise almost equal s."""
        W, A_est, d = calculate_spoc(self.epo)
        # applying the filter to X gives us s_est which should be almost
        # equal s
        s_est = np.empty(self.s.shape[:2])
        for i in range(self.EPOCHS):
            s_est[i] = np.dot(self.X[i], W[:, 0])
        s_true = self.s[..., 0]
        epsilon = 0.001

        # correct for scale
        s_true /= s_true.std()
        s_est /= s_est.std()

        # correct for sign
        s_true = np.abs(s_true)
        s_est = np.abs(s_est)

        diff = np.sum(s_true - s_est) / (self.s.shape[0] * self.s.shape[1])
        self.assertTrue(diff < epsilon)

    #def test_calculate_signed_r_square_swapaxes(self):
    #    """caluclate_r_square must work with nonstandard classaxis."""
    #    dat = calculate_signed_r_square(swapaxes(self.dat, 0, 2), classaxis=2)
    #    # the class-axis just dissapears during
    #    # calculate_signed_r_square, so axis 2 becomes axis 1
    #    dat = dat.swapaxes(0, 1)
    #    dat2 = calculate_signed_r_square(self.dat)
    #    np.testing.assert_array_equal(dat, dat2)

    def test_calculate_signed_r_square_copy(self):
        """caluclate_r_square must not modify argument."""
        cpy = self.epo.copy()
        calculate_spoc(self.epo)
        self.assertEqual(self.epo, cpy)
class TestSegmentDat(unittest.TestCase):

    def setUp(self):
        # create 100 samples and tree channels data
        ones = np.ones((100, 3))
        data = np.array([ones, ones*2, ones*3]).reshape(-1, 3)
        time = np.linspace(0, 3000, 300, endpoint=False)
        channels = ['a', 'b', 'c']
        markers = [[500, 'M1'], [1500, 'M2'], [2500, 'M3']]
        self.dat = Data(data, [time, channels], ['time', 'channels'], ['ms', '#'])
        self.dat.markers = markers
        self.dat.fs = 100
        self.mrk_def = {'class 1': ['M1'],
                        'class 2': ['M2', 'M3']
                       }

    def test_segment_dat(self):
        """Test conversion from Continuous to Epoched data."""
        epo = segment_dat(self.dat, self.mrk_def, [-400, 400])
        # test if basic info was transferred from cnt
        self.assertEqual(self.dat.markers, epo.markers)
        self.assertEqual(self.dat.fs, epo.fs)
        np.testing.assert_array_equal(self.dat.axes[-1], epo.axes[-1])
        # test if the actual data is correct
        self.assertEqual(list(epo.axes[0]), [0, 1, 1])
        np.testing.assert_array_equal(epo.class_names, np.array(['class 1', 'class 2']))
        self.assertEqual(epo.data.shape, (3, 80, 3))
        for i in range(3):
            e = epo.data[i, ...]
            self.assertEqual(np.average(e), i+1)
        # test if the epo.ival is the same we cut out
        self.assertEqual(epo.axes[-2][0], -400)
        self.assertEqual(epo.axes[-2][-1], 390)

    def test_segment_dat_with_nonexisting_markers(self):
        """Segmentation without result should return empty .data"""
        mrk_def = {'class 1': ['FUU1'],
                   'class 2': ['FUU2', 'FUU3']
                  }
        epo = segment_dat(self.dat, mrk_def, [-400, 400])
        self.assertEqual(epo.data.shape[0], 0)

    def test_segment_dat_with_unequally_sized_data(self):
        """Segmentation must ignore too short or too long chunks in the result."""
        # We create a marker that is too close to the beginning of the
        # data, so its cnt will not bee of length [-400, 400] ms. It
        # should not appear in the resulting epo
        self.dat.markers.append([100, 'M1'])
        epo = segment_dat(self.dat, self.mrk_def, [-400, 400])
        self.assertEqual(epo.data.shape[0], 3)

    # the following tests
    # (test_segment_dat_with_restriction_to_new_data_ival...) work very
    # similar but test slightly different conditions. The basic idea is
    # always: we create a small cnt with three markers directly next to
    # each other, the only thing changing between the tests is the
    # interval. We test all possible combinations of the marker position
    # relative to the interval:
    #   [M----], M [---], [--M--], [----M], [---] M
    # we check in each test that with increasing number of new samples
    # the correct number of epochs is returned.
    # WARNING: This is fairly complicated to get right, if you want to
    # change something please make sure you fully understand the problem
    def test_segment_dat_with_restriction_to_new_data_ival_zero_pos(self):
        """Online Segmentation with ival 0..+something must work correctly."""
        data = np.ones((9, 3))
        time = np.linspace(0, 900, 9, endpoint=False)
        channels = 'a', 'b', 'c'
        markers = [[100, 'x'], [200, 'x'], [300, 'x']]
        dat = Data(data, [time, channels], ['time', 'channels'], ['ms', '#'])
        dat.fs = 10
        dat.markers = markers
        mrk_def = {'class 1': ['x']}
        # each tuple has (number of new samples, expected epocs)
        samples_epos = [(0, 0), (1, 0), (2, 1), (3, 2), (4, 3), (5, 3)]
        for s, e in samples_epos:
            epo = segment_dat(dat, mrk_def, [0, 500], newsamples=s)
            self.assertEqual(epo.data.shape[0], e)

    def test_segment_dat_with_restriction_to_new_data_ival_pos_pos(self):
        """Online Segmentation with ival +something..+something must work correctly."""
        data = np.ones((9, 3))
        time = np.linspace(0, 900, 9, endpoint=False)
        channels = 'a', 'b', 'c'
        markers = [[100, 'x'], [200, 'x'], [300, 'x']]
        dat = Data(data, [time, channels], ['time', 'channels'], ['ms', '#'])
        dat.fs = 10
        dat.markers = markers
        mrk_def = {'class 1': ['x']}
        # each tuple has (number of new samples, expected epocs)
        samples_epos = [(0, 0), (1, 0), (2, 1), (3, 2), (4, 3), (5, 3)]
        for s, e in samples_epos:
            epo = segment_dat(dat, mrk_def, [100, 500], newsamples=s)
            self.assertEqual(epo.data.shape[0], e)

    def test_segment_dat_with_restriction_to_new_data_ival_neg_pos(self):
        """Online Segmentation with ival -something..+something must work correctly."""
        data = np.ones((9, 3))
        time = np.linspace(0, 900, 9, endpoint=False)
        channels = 'a', 'b', 'c'
        markers = [[400, 'x'], [500, 'x'], [600, 'x']]
        dat = Data(data, [time, channels], ['time', 'channels'], ['ms', '#'])
        dat.fs = 10
        dat.markers = markers
        mrk_def = {'class 1': ['x']}
        # each tuple has (number of new samples, expected epocs)
        samples_epos = [(0, 0), (1, 0), (2, 1), (3, 2), (4, 3), (5, 3)]
        for s, e in samples_epos:
            epo = segment_dat(dat, mrk_def, [-300, 200], newsamples=s)
            self.assertEqual(epo.data.shape[0], e)

    def test_segment_dat_with_restriction_to_new_data_ival_neg_zero(self):
        """Online Segmentation with ival -something..0 must work correctly."""
        data = np.ones((9, 3))
        time = np.linspace(0, 900, 9, endpoint=False)
        channels = 'a', 'b', 'c'
        markers = [[500, 'x'], [600, 'x'], [700, 'x']]
        dat = Data(data, [time, channels], ['time', 'channels'], ['ms', '#'])
        dat.fs = 10
        dat.markers = markers
        mrk_def = {'class 1': ['x']}
        # each tuple has (number of new samples, expected epocs)
        samples_epos = [(0, 0), (1, 0), (2, 1), (3, 2), (4, 3), (5, 3)]
        for s, e in samples_epos:
            epo = segment_dat(dat, mrk_def, [-400, 0], newsamples=s)
            self.assertEqual(epo.data.shape[0], e)

    def test_segment_dat_with_restriction_to_new_data_ival_neg_neg(self):
        """Online Segmentation with ival -something..-something must work correctly."""
        data = np.ones((9, 3))
        time = np.linspace(0, 900, 9, endpoint=False)
        channels = 'a', 'b', 'c'
        markers = [[500, 'x'], [600, 'x'], [700, 'x']]
        dat = Data(data, [time, channels], ['time', 'channels'], ['ms', '#'])
        dat.fs = 10
        dat.markers = markers
        mrk_def = {'class 1': ['x']}
        # each tuple has (number of new samples, expected epocs)
        samples_epos = [(0, 0), (1, 0), (2, 1), (3, 2), (4, 3), (5, 3)]
        for s, e in samples_epos:
            epo = segment_dat(dat, mrk_def, [-400, -100], newsamples=s)
            self.assertEqual(epo.data.shape[0], e)

    def test_segment_dat_with_negative_newsamples(self):
        """Raise an error when newsamples is not positive or None."""
        with self.assertRaises(AssertionError):
            segment_dat(self.dat, self.mrk_def, [-400, 400], newsamples=-1)

    def test_segment_dat_copy(self):
        """segment_dat must not modify arguments."""
        cpy = self.dat.copy()
        segment_dat(self.dat, self.mrk_def, [-400, 400])
        self.assertEqual(cpy, self.dat)

    def test_segment_dat_swapaxes(self):
        """Segmentation must work with nonstandard axes."""
        epo = segment_dat(swapaxes(self.dat, 0, 1), self.mrk_def, [-400, 400], timeaxis=-1)
        # segment_dat added a new dimension
        epo = swapaxes(epo, 1, 2)
        epo2 = segment_dat(self.dat, self.mrk_def, [-400, 400])
        self.assertEqual(epo, epo2)