Пример #1
0
 def test_threading_equiv(self):
     """
     Check that the number of worker threads has no effect on the
     structure of the resulting tree.

     The same regression data is fit with 1, 4 and 12 threads, and each
     tree is compared for equivalence with the one built next.
     """
     inputs, outputs = regression_data()
     trees = [build_tree(inputs, outputs, min_leaf=6, max_depth=3,
                         num_threads=n_threads)
              for n_threads in (1, 4, 12)]
     # Adjacent pairs: (1 vs 4 threads), then (4 vs 12 threads).
     for earlier, later in zip(trees, trees[1:]):
         self.assertTrue(_trees_equivalent(earlier, later))
Пример #2
0
 def test_two_branches(self):
     """
     Build a depth-2 tree on the uint8 regression data and check that both
     levels split (feature and threshold) and that the four leaves carry
     the expected outputs.
     """
     inputs, outputs = uint8_regression_data()
     tree = build_tree(inputs, outputs, min_leaf=6, max_depth=2)
     self.assertIsInstance(tree, TreeBranch)

     left, right = tree.less_than, tree.greater_equal
     for branch in (left, right):
         self.assertIsInstance(branch, TreeBranch)
     for leaf in (right.less_than, right.greater_equal,
                  left.less_than, left.greater_equal):
         self.assertIsInstance(leaf, TreeLeaf)

     # (node, expected split feature, expected threshold)
     split_expectations = [
         (tree, 5, 126.5),
         (left, 0, 153.5),
         (right, 6, 157),
     ]
     for node, feature, threshold in split_expectations:
         self.assertEqual(node.split_feature, feature)
         self.assertTrue(np.allclose(node.threshold, threshold))

     # (leaf, expected output vector)
     leaf_expectations = [
         (left.less_than, [0.05919572, 0.0693915, 0.65720719]),
         (left.greater_equal, [1.13262165, -0.10673304, 0.12196609]),
         (right.less_than, [0.05945089, -1.83148921, -0.04628649]),
         (right.greater_equal, [-0.77123165, -0.0870612, -0.13615353]),
     ]
     for leaf, expected in leaf_expectations:
         self.assertTrue(np.allclose(leaf.output, expected))
Пример #3
0
 def test_leaf(self):
     """
     With max_depth=0 no split is allowed, so the builder must return a
     single leaf whose output equals the mean of the outputs over axis 0.
     """
     inputs, outputs = regression_data()
     leaf = build_tree(inputs, outputs, min_leaf=3, max_depth=0)
     self.assertIsInstance(leaf, TreeLeaf)
     mean_output = np.mean(outputs, axis=0)
     self.assertTrue(np.allclose(leaf.output, mean_output))
Пример #4
0
 def test_two_branches_permuted(self):
     """
     Test that the split features from test_two_branches follow the data
     when columns are rearranged.

     This is not a true permutation. Column 6 is copied into column 3, and
     column 6 is then overwritten with a copy of column 1. The values
     originally at index 6 therefore only exist at index 3. The tree is
     expected to split on feature 3 at both the root and the less-than
     child.
     """
     inputs, outputs = regression_data()
     # Work on a private copy: the columns are overwritten in place below,
     # and nothing here guarantees regression_data() returns a fresh array
     # on every call, so mutating it could leak into other tests.
     inputs = inputs.copy()
     inputs[:, 3] = inputs[:, 6]
     # Replace the original column 6 so its values are only available at
     # index 3 (column 1 is now duplicated at index 6).
     inputs[:, 6] = inputs[:, 1]
     tree = build_tree(inputs, outputs, min_leaf=6, max_depth=2)
     self.assertEqual(tree.split_feature, 3)
     self.assertEqual(tree.less_than.split_feature, 3)
Пример #5
0
 def test_one_branch(self):
     """
     A depth-1 tree on the float regression data should split once, on
     feature 6, and produce two leaves with known outputs.
     """
     inputs, outputs = regression_data()
     root = build_tree(inputs, outputs, min_leaf=3, max_depth=1)
     self.assertIsInstance(root, TreeBranch)
     lower, upper = root.less_than, root.greater_equal
     self.assertIsInstance(lower, TreeLeaf)
     self.assertIsInstance(upper, TreeLeaf)
     self.assertEqual(root.split_feature, 6)
     self.assertTrue(np.allclose(root.threshold, 0.696461959))
     expected_lower = [0.15382855, -0.82658136, -0.04233113]
     expected_upper = [0.0946953, 0.30802816, 0.39473894]
     self.assertTrue(np.allclose(lower.output, expected_lower))
     self.assertTrue(np.allclose(upper.output, expected_upper))
Пример #6
0
 def test_predictable_split(self):
     """
     Test that a predictable split actually occurs.

     A depth-1 tree is fit on 20000 samples from _predictable_data(). The
     test expects a split on feature 1 with a threshold near 5, and leaf
     outputs near [-10, 7] (less-than side) and [-7, -10] (greater-equal
     side). All comparisons use rtol=atol=1e-2, presumably to tolerate
     noise in the generated data.
     """
     inputs, outputs = _predictable_data(20000)
     node = build_tree(inputs, outputs, min_leaf=5, max_depth=1)
     expected_leaves = [[-10, 7], [-7, -10]]
     self.assertIsInstance(node, TreeBranch)
     self.assertIsInstance(node.less_than, TreeLeaf)
     self.assertIsInstance(node.greater_equal, TreeLeaf)
     self.assertEqual(node.split_feature, 1)
     self.assertTrue(np.allclose(node.threshold, 5, rtol=1e-2, atol=1e-2),
                     msg="threshold %r is not close to 5" % (node.threshold,))
     children = [("less_than", node.less_than),
                 ("greater_equal", node.greater_equal)]
     for (name, child), expected in zip(children, expected_leaves):
         # Name the child and show its actual output so a failure is
         # diagnosable (a bare assertTrue only reports "False is not true").
         self.assertTrue(np.allclose(child.output, expected,
                                     rtol=1e-2, atol=1e-2),
                         msg="%s output %r is not close to %r"
                             % (name, child.output, expected))
Пример #7
0
 def test_one_branch(self):
     """
     A depth-1 tree on the uint8 regression data should split once, on
     feature 5 at 126.5, and produce two leaves with known outputs.
     """
     inputs, outputs = uint8_regression_data()
     tree = build_tree(inputs, outputs, min_leaf=3, max_depth=1)
     self.assertIsInstance(tree, TreeBranch)
     for child in (tree.less_than, tree.greater_equal):
         self.assertIsInstance(child, TreeLeaf)
     self.assertEqual(tree.split_feature, 5)
     self.assertTrue(np.allclose(tree.threshold, 126.5))
     leaf_expectations = [
         (tree.less_than, [0.70325124, -0.03628322, 0.33606252]),
         (tree.greater_equal, [-0.43895862, -0.78483254, -0.10020673]),
     ]
     for leaf, expected in leaf_expectations:
         self.assertTrue(np.allclose(leaf.output, expected))
Пример #8
0
 def test_two_branches(self):
     """
     Build a depth-2 tree on the float regression data. Only the less-than
     side is expected to split again (both splits on feature 6); the
     greater-equal side of the root stays a leaf.
     """
     inputs, outputs = regression_data()
     tree = build_tree(inputs, outputs, min_leaf=6, max_depth=2)
     self.assertIsInstance(tree, TreeBranch)
     inner, outer = tree.less_than, tree.greater_equal
     self.assertIsInstance(inner, TreeBranch)
     self.assertIsInstance(outer, TreeLeaf)
     for leaf in (inner.less_than, inner.greater_equal):
         self.assertIsInstance(leaf, TreeLeaf)

     # Root split and the leaf on its greater-equal side.
     self.assertEqual(tree.split_feature, 6)
     self.assertTrue(np.allclose(tree.threshold, 0.696461959))
     self.assertTrue(np.allclose(outer.output,
                                 [0.0946953, 0.30802816, 0.39473894]))

     # Second-level split on the less-than side and its two leaves.
     self.assertEqual(inner.split_feature, 6)
     self.assertTrue(np.allclose(inner.threshold, -0.6259399115000001))
     inner_leaves = [
         (inner.less_than, [0.66664267, -0.3186112, -0.53174299]),
         (inner.greater_equal, [-0.14531307, -1.12289727, 0.24315912]),
     ]
     for leaf, expected in inner_leaves:
         self.assertTrue(np.allclose(leaf.output, expected))