コード例 #1
0
def bench(train_data, test_data, n_trees):
    """
    Benchmark Extratrees

    Trains an ExtraForest with `n_trees` trees, prints progress and timings,
    and returns training time, evaluation time and the test score.

    Args:
        train_data (extratrees.Dataset): Train set
        test_data (extratrees.Dataset): Test set
        n_trees (int): Number of trees

    Returns:
        tuple: ``(t_train, t_test, test_score)`` -- seconds spent in
        ``forest.fit``, seconds spent in ``score``, and the value returned by
        ``score(forest, test_data)``.
    """

    print('-' * 20, "This implementation")
    print("Going to train %d trees" % n_trees)
    forest = ExtraForest(n_trees=n_trees, n_min=10, criterion="entropy")

    print("Training...")
    # perf_counter is monotonic and high-resolution, unlike time.time(), so
    # measured intervals are not skewed by system clock adjustments.
    t_start = time.perf_counter()
    forest.fit(train_data)
    t_train = time.perf_counter() - t_start
    print("Trained in %.3fs" % t_train)

    print("Testing")
    t_start = time.perf_counter()
    test_score = score(forest, test_data)
    t_test = time.perf_counter() - t_start
    print("Evaluated in %.3fs" % t_test)
    print("Score: %.3f" % test_score)

    # Hand the measurements back so callers can aggregate them instead of
    # scraping stdout; this is what the docstring has always advertised.
    return t_train, t_test, test_score
コード例 #2
0
ファイル: test_extratrees.py プロジェクト: jwmng/extratrees
    def test_fit(self):
        """Fitting on `self.fourths` yields a classifier forest of 3 ExtraTrees."""
        forest = ExtraForest(n_trees=3)
        forest.fit(self.fourths)

        self.assertTrue(forest._is_classifier)
        # assertTrue(x, 3) treats 3 as the failure *message* and passes for any
        # non-empty list; assertEqual actually checks the number of trees.
        self.assertEqual(len(forest.trees), 3)

        for tree in forest.trees:
            self.assertIsInstance(tree, ExtraTree)
            self.assertTrue(tree._is_classifier)
コード例 #3
0
ファイル: test_extratrees.py プロジェクト: jwmng/extratrees
    def test_predict_forest_classify(self):
        """Seeded 3-tree forest on `fourths`: check each root's children and the predictions."""
        random.seed(0)
        model = ExtraForest(n_trees=3, n_min=1)
        model.fit(self.fourths)
        predictions = model.predict(self.fourths[0])

        # Every tree is expected to end up with the same root split.
        expected_left, expected_right = [0.75, 0.25], [0.25, 0.75]
        for fitted_tree in model.trees:
            root = fitted_tree.root_node
            self.assertEqual(root.left, expected_left)
            self.assertEqual(root.right, expected_right)

        self.assertEqual(predictions, [1, 0, 0, 0, 0, 1, 1, 1])
コード例 #4
0
ファイル: test_extratrees.py プロジェクト: jwmng/extratrees
    def test_predict_forest_regression(self):
        """A regression forest predicts the mean ydata value for each x value."""
        xdata = ([0], [0], [0], [0], [1], [1], [1], [1])
        ydata = (2.0, 2.1, 2.2, 2.3, 3.0, 3.1, 3.2, 3.3)
        forest = ExtraForest()
        forest.fit((xdata, ydata))
        pred = forest.predict(xdata)

        # Guard against a vacuous pass: without this, a short or empty
        # prediction list would make the loops below check nothing.
        self.assertEqual(len(pred), len(xdata))

        # Each prediction should be the mean of the ydata values that share its
        # x value (2.15 for x=0, 3.15 for x=1), since after the first split the
        # attributes are uniform and no further split is possible.
        for zval in pred[:4]:
            self.assertAlmostEqual(zval, 2.15)

        for zval in pred[4:]:
            self.assertAlmostEqual(zval, 3.15)
コード例 #5
0
ファイル: test_extratrees.py プロジェクト: jwmng/extratrees
    def test_predict(self):
        """Seeded 3-tree forest: check root splits and predictions on fourths[0]."""
        random.seed(0)
        forest = ExtraForest(n_trees=3, n_min=1)
        forest.fit(self.fourths)
        samples = self.fourths[0]
        result = forest.predict(samples)

        # For every tree, the first split falls somewhere in between the single
        # attribute values of `fourths[0]`. No second split follows, because
        # the attributes are uniform on each side after the first split.
        roots = [t.root_node for t in forest.trees]
        for node in roots:
            self.assertEqual(node.left, [0.75, 0.25])
            self.assertEqual(node.right, [0.25, 0.75])

        self.assertEqual(result, [1, 0, 0, 0, 0, 1, 1, 1])
コード例 #6
0
ファイル: test_extratrees.py プロジェクト: jwmng/extratrees
 def test_init_forest(self):
     """Constructor stores its hyper-parameters, with no trees and _is_classifier False."""
     forest = ExtraForest(n_trees=5, k_value=3, n_min=1)
     expected_params = (("n_trees", 5), ("k_value", 3), ("n_min", 1))
     for attr_name, expected in expected_params:
         self.assertEqual(getattr(forest, attr_name), expected)
     self.assertFalse(forest._is_classifier)
     self.assertEqual(forest.trees, [])
コード例 #7
0
def get_train_time(train_data, n_trees=10):
    """
    Benchmark Extratrees

    Trains an ExtraForest of `n_trees` trees and returns the mean training
    time per tree.

    Args:
        train_data (extratrees.Dataset): Train set
        n_trees (int): Number of trees to train (default 10). The total fit
            time is divided by this same count.

    Returns:
        float: Seconds spent in ``forest.fit`` divided by ``n_trees``.

    Raises:
        ValueError: If ``n_trees`` is smaller than 1.
    """
    if n_trees < 1:
        raise ValueError("n_trees must be a positive integer, got %r" % (n_trees,))

    forest = ExtraForest(n_trees=n_trees, n_min=10, criterion="entropy")

    # perf_counter is monotonic and high-resolution, suitable for benchmarks.
    t_start = time.perf_counter()
    forest.fit(train_data)
    # Divide by the same count given to the constructor so the per-tree
    # average cannot drift out of sync with the number of trees trained.
    return (time.perf_counter() - t_start) / n_trees
コード例 #8
0
ファイル: test_extratrees.py プロジェクト: jwmng/extratrees
 def test_repr_no_fit(self):
     """A default, unfitted forest renders as '<Forest (10 trees), not fitted>'."""
     unfitted = ExtraForest()
     expected = "<Forest (10 trees), not fitted>"
     self.assertEqual(str(unfitted), expected)
コード例 #9
0
ファイル: compare.py プロジェクト: jwmng/extratrees
""" compare.py - Compare training time of extratrees and sklearn """

import time
import matplotlib.pyplot as plt

from sklearn.ensemble import ExtraTreesClassifier

from src.extratrees import ExtraForest
from docs.helpers import load_data, TRAIN_FILE, TEST_FILE

if __name__ == '__main__':
    # Load the train/test splits with the project's helpers.
    # NOTE(review): TEST_SET is not used in the visible part of this script.
    TRAIN_SET = load_data(TRAIN_FILE)
    TEST_SET = load_data(TEST_FILE)

    # A single-tree ExtraForest vs. sklearn's ExtraTreesClassifier, both using
    # the gini criterion.
    # NOTE(review): skl_clf leaves n_estimators at sklearn's default (not 1),
    # and min_samples_split may not mean exactly the same thing as n_min --
    # confirm both classifiers do comparable work before comparing timings.
    this_clf = ExtraForest(n_trees=1, n_min=10, criterion="gini")
    skl_clf = ExtraTreesClassifier(criterion="gini", min_samples_split=10)

    # Training-set prefix sizes to time, plus one fit duration (seconds) per
    # size for each implementation.
    sizes = (100, 1000, 2000, 5000, 10000, 20000, 50000)
    times_this = [0] * len(sizes)
    times_skl = [0] * len(sizes)

    for idx, size in enumerate(sizes):

        # This clf
        # ExtraForest.fit takes a single (features, labels) tuple; TRAIN_SET is
        # presumably (features, labels) as returned by load_data -- confirm.
        t0 = time.time()
        this_clf.fit((TRAIN_SET[0][:size], TRAIN_SET[1][:size]))
        times_this[idx] = time.time() - t0

        # Sklearn
        # sklearn's fit takes features and labels as separate arguments.
        # NOTE(review): this snippet appears truncated here; the assignment to
        # times_skl[idx] is not visible.
        t0 = time.time()
        skl_clf.fit(TRAIN_SET[0][:size], TRAIN_SET[1][:size])
コード例 #10
0
ファイル: example.py プロジェクト: jwmng/extratrees
            grid_pred = classifier.predict([[x, y]])
            col = (grid_pred[0]), 0, (1 - grid_pred[0])
            ax.add_patch(
                patches.Rectangle((x, y), (2 / gridsize), (2 / gridsize),
                                  facecolor=col))


# Create data
# NOTE(review): data is used below as a pair whose first element holds the
# feature rows (passed to predict) and which zips into (coords, class) pairs --
# presumably (features, labels); confirm against create_dataset.
data = create_dataset()

# Train tree
# A single ExtraTree with n_min=5, fitted on the full dataset.
tree = ExtraTree(n_min=5)
tree.fit(data)

# Train forest
# A 10-tree ExtraForest with the same n_min, fitted on the same data.
forest = ExtraForest(n_trees=10, n_min=5)
forest.fit(data)

# Predict
# Predictions are made on the training features themselves (in-sample).
pred_tree = tree.predict(data[0])
pred_forest = forest.predict(data[0])

# 2x3 grid of panels, indexed as axes[row][col].
fig, axes = plt.subplots(2, 3)

# Plot data
# Top-left panel: one marker per sample, coloured by its class label.
axes[0][0].set_title('True')
for coords, cls in zip(*data):
    axes[0][0].plot(*coords, marker='o', color=colorize_class(cls))

# Plot tree result
axes[0][1].set_title("Single tree\n (min_samples=5, k=*)")