def test_depth(self):
    """Check ``depth`` for every internal node after repeatedly splitting.

    The first leaf returned by ``tree.leaves()`` is split ``n_splits`` times.
    The asserted depths (root at depth 1) follow a heap layout, where node
    ``i`` sits on level ``(i + 1).bit_length()``: node 0 -> 1, nodes 1-2 -> 2,
    nodes 3-6 -> 3, nodes 7-14 -> 4.

    NOTE(review): this assumes ``leaves()`` yields the lowest-id leaf first
    and that children of node k are 2k+1 / 2k+2. The original hard-coded
    expectations imply this; confirm against BinaryDecisionTree.
    """
    n_splits = 15
    tree = BinaryDecisionTree(n_features=1)
    split = BinaryDecisionTreeSplit(feature_id=0, value=0.0)
    for _ in range(n_splits):
        tree.split_node(tree.leaves()[0], split)
    # Derive the expected level instead of listing 15 literal assertions;
    # subTest reports which node id failed.
    for node_id in range(n_splits):
        with self.subTest(node_id=node_id):
            self.assertEqual((node_id + 1).bit_length(), tree.depth(node_id))
def test_multiple_splits(self):
    """Check that each leaf split adds exactly one leaf to the tree.

    Splitting a leaf replaces it with two children, so after ``k`` splits a
    tree that started as a single root leaf has ``k + 1`` leaves.
    """
    tree = BinaryDecisionTree(n_features=1)
    split = BinaryDecisionTreeSplit(feature_id=0, value=0.0)
    # Before any split, the tree is one leaf. This is the base case of the
    # k + 1 invariant checked below.
    self.assertEqual(tree.num_of_leaves(), 1)
    for split_count in range(1, 10):
        tree.split_node(tree.leaves()[0], split)
        with self.subTest(split_count=split_count):
            self.assertEqual(tree.num_of_leaves(), split_count + 1)