def test_split_threshold(self):
    """Check split_threshold on list inputs: returned states and in-place update.

    Each case pairs an initial state list with the states expected after
    applying split_threshold to the input [1, -1, 0]. The returned value
    and the (mutated) state list must both equal the expected states.
    """
    cases = (
        ([0, 0, 0], [1, 0, 0]),
        ([1, 1, 1], [1, 0, 1]),
    )
    for states, expected in cases:
        # A fresh input list per call, so no case can observe changes
        # another case made to its input argument.
        result = WTNetwork.split_threshold([1, -1, 0], states)
        self.assertEqual(expected, result)
        # The state list passed in must have been updated in place.
        self.assertEqual(expected, states)
def test_split_threshold_scalar(self):
    """Check split_threshold on scalar (input, previous state) pairs.

    The table below encodes the expected rule: a positive input gives 1,
    a negative input gives 0, and a zero input keeps the previous state.
    """
    # Maps (input x, previous state s) -> expected new state.
    expected_states = {
        (1, 0): 1,
        (0, 0): 0,
        (-1, 0): 0,
        (1, 1): 1,
        (0, 1): 1,
        (-1, 1): 0,
    }
    for (x, s), expected in expected_states.items():
        # The message identifies which pair failed; str.format keeps
        # this portable across Python versions.
        self.assertEqual(
            expected, WTNetwork.split_threshold(x, s),
            msg="split_threshold({}, {})".format(x, s))