示例#1
0
    def test_predict_forest_classify(self):
        # Seed the global RNG first so the forest built below is reproducible
        # (presumably ExtraForest draws its split points from `random`).
        random.seed(0)
        model = ExtraForest(n_trees=3, n_min=1)
        model.fit(self.fourths)
        # NOTE(review): `self.fourths` appears to be an (inputs, labels) pair,
        # given that only element 0 is passed to predict -- confirm in setUp.
        features = self.fourths[0]
        predictions = model.predict(features)

        # Every tree's root must hold the same pair of leaf distributions.
        expected_left = [0.75, 0.25]
        expected_right = [0.25, 0.75]
        for member in model.trees:
            self.assertEqual(member.root_node.left, expected_left)
            self.assertEqual(member.root_node.right, expected_right)

        self.assertEqual(predictions, [1, 0, 0, 0, 0, 1, 1, 1])
示例#2
0
    def test_predict_forest_regression(self):
        # A single attribute taking two values (0 and 1). Targets in the
        # x == 0 group are 2.0..2.3; targets in the x == 1 group are 3.0..3.3.
        xdata = ([0], [0], [0], [0], [1], [1], [1], [1])
        ydata = (2.0, 2.1, 2.2, 2.3, 3.0, 3.1, 3.2, 3.3)
        forest = ExtraForest()
        forest.fit((xdata, ydata))
        pred = forest.predict(xdata)

        # One prediction per input row. Without this check, an empty or
        # short result would skip the loops below and the test would pass
        # vacuously.
        self.assertEqual(len(pred), len(xdata))

        # Expected values are the mean target of each x-group:
        # mean(2.0..2.3) == 2.15 and mean(3.0..3.3) == 3.15.
        # After the first split the attribute is uniform on each side, so no
        # further split refines these means.
        for zval in pred[:4]:
            self.assertAlmostEqual(zval, 2.15)

        for zval in pred[4:]:
            self.assertAlmostEqual(zval, 3.15)
示例#3
0
    def test_predict(self):
        # Fix the global RNG so the trees grown below are reproducible.
        random.seed(0)
        tree_count, min_samples = 3, 1
        ensemble = ExtraForest(n_trees=tree_count, n_min=min_samples)
        ensemble.fit(self.fourths)
        inputs = self.fourths[0]
        outcome = ensemble.predict(inputs)

        # In every tree, the first split falls somewhere in between the
        # single attribute values of `fourths[0]`. No second split follows,
        # because the attribute is uniform on each side after the first one.
        for estimator in ensemble.trees:
            root = estimator.root_node
            self.assertEqual(root.left, [0.75, 0.25])
            self.assertEqual(root.right, [0.25, 0.75])

        self.assertEqual(outcome, [1, 0, 0, 0, 0, 1, 1, 1])
示例#4
0

# Demo script: fit a single ExtraTree and a 10-tree ExtraForest on the same
# dataset, then plot the true classes next to each model's predictions.
# NOTE(review): this chunk ends mid-script; the loop that draws the forest
# predictions (and any remaining subplot code) continues past this point.

# Create data
# Presumably create_dataset() returns an (inputs, labels) pair: the code below
# zips data[0] coordinates with data[1] class labels. Confirm against its
# definition.
data = create_dataset()

# Train tree
# n_min=5 is presumably the minimum sample count needed to split a node
# (the plot titles below label it "min_samples"). Confirm in ExtraTree.
tree = ExtraTree(n_min=5)
tree.fit(data)

# Train forest
forest = ExtraForest(n_trees=10, n_min=5)
forest.fit(data)

# Predict
# Both models are evaluated on the same inputs they were trained on.
pred_tree = tree.predict(data[0])
pred_forest = forest.predict(data[0])

# 2x3 grid of subplots; only the top row is drawn in the part shown here.
fig, axes = plt.subplots(2, 3)

# Plot data
# Ground truth: one marker per sample, colored by its true class.
axes[0][0].set_title('True')
for coords, cls in zip(*data):
    axes[0][0].plot(*coords, marker='o', color=colorize_class(cls))

# Plot tree result
# Same sample coordinates, colored by the single tree's predicted class.
axes[0][1].set_title("Single tree\n (min_samples=5, k=*)")
for idx, cls in enumerate(pred_tree):
    axes[0][1].plot(*data[0][idx], marker='o', color=colorize_class(cls))

# Plot forest result
axes[0][2].set_title("10 trees\n (min_samples=5, k=*)")