def test_predict_forest_classify(self):
    """A seeded 3-tree forest on `fourths` yields identical root splits
    in every tree and the expected class predictions."""
    random.seed(0)
    forest = ExtraForest(n_trees=3, n_min=1)
    forest.fit(self.fourths)
    pred = forest.predict(self.fourths[0])
    expected_left = [0.75, 0.25]
    expected_right = [0.25, 0.75]
    # Each tree stops after one split, so the root children hold the
    # class-probability leaves directly.
    for tree in forest.trees:
        self.assertEqual(tree.root_node.left, expected_left)
        self.assertEqual(tree.root_node.right, expected_right)
    self.assertEqual(pred, [1, 0, 0, 0, 0, 1, 1, 1])
def test_predict_forest_regression(self):
    """Regression forest predicts the per-class mean of the targets."""
    xdata = ([0], [0], [0], [0], [1], [1], [1], [1])
    ydata = (2.0, 2.1, 2.2, 2.3, 3.0, 3.1, 3.2, 3.3)
    forest = ExtraForest()
    forest.fit((xdata, ydata))
    pred = forest.predict(xdata)
    # These are the means of the xdata classes, since after the first split
    # the attributes are uniform.
    expected = [2.15] * 4 + [3.15] * 4
    for zval, want in zip(pred, expected):
        self.assertAlmostEqual(zval, want)
def test_predict(self):
    """Seeded forest on `fourths`: one split per tree, fixed predictions."""
    random.seed(0)
    forest = ExtraForest(n_trees=3, n_min=1)
    forest.fit(self.fourths)
    pred = forest.predict(self.fourths[0])
    # For every tree the first split falls somewhere between the single
    # attribute values of `fourths[0]`. No second split occurs because
    # the attributes are uniform after the first one.
    for tree in forest.trees:
        root = tree.root_node
        self.assertEqual(root.left, [0.75, 0.25])
        self.assertEqual(root.right, [0.25, 0.75])
    self.assertEqual(pred, [1, 0, 0, 0, 0, 1, 1, 1])
# Create data data = create_dataset() # Train tree tree = ExtraTree(n_min=5) tree.fit(data) # Train forest forest = ExtraForest(n_trees=10, n_min=5) forest.fit(data) # Predict pred_tree = tree.predict(data[0]) pred_forest = forest.predict(data[0]) fig, axes = plt.subplots(2, 3) # Plot data axes[0][0].set_title('True') for coords, cls in zip(*data): axes[0][0].plot(*coords, marker='o', color=colorize_class(cls)) # Plot tree result axes[0][1].set_title("Single tree\n (min_samples=5, k=*)") for idx, cls in enumerate(pred_tree): axes[0][1].plot(*data[0][idx], marker='o', color=colorize_class(cls)) # Plot forest result axes[0][2].set_title("10 trees\n (min_samples=5, k=*)")