class ScikitTrestle(object):
    """Scikit-style wrapper around a ``TrestleTree`` using string labels.

    Labels are stored on each instance under the ``'_y_label'`` key,
    formatted as integer strings, and converted back to ``int`` when
    predicting.
    """

    def __init__(self, params=None):
        """Create the underlying tree.

        :param params: optional mapping of keyword arguments forwarded to
            ``TrestleTree``; when ``None`` the tree is built with defaults.
        """
        self.tree = TrestleTree() if params is None else TrestleTree(**params)

    def ifit(self, x, y):
        """Incrementally fit one instance ``x`` with label ``y``.

        A deep copy of ``x`` is labelled, so the caller's object is left
        untouched.
        """
        labelled = deepcopy(x)
        labelled['_y_label'] = "%i" % y
        self.tree.ifit(labelled)

    def fit(self, X, y):
        """Fit the tree on instances ``X`` with labels ``y``.

        ``y`` is indexed by position, so it must supply a label for every
        instance in ``X``. The instances are deep-copied before labelling.
        """
        labelled = deepcopy(X)
        for index, instance in enumerate(labelled):
            instance['_y_label'] = "%i" % y[index]
        self.tree.fit(labelled, randomize_first=False)

    def predict(self, X):
        """Return a list with the predicted integer label of each instance."""
        predictions = []
        for instance in X:
            concept = self.tree.categorize(instance)
            predictions.append(int(concept.predict('_y_label')))
        return predictions
# Example #2
0
class ScikitTrestle(object):
    """Scikit-style wrapper around a ``TrestleTree`` using float labels.

    Labels are stored on each instance under the ``'_y_label'`` key as
    floats; predictions are whatever the matched concept's ``predict``
    returns for that key.
    """

    def __init__(self, **kwargs):
        """Create the underlying tree.

        :param kwargs: keyword arguments forwarded unchanged to
            ``TrestleTree``.
        """
        self.tree = TrestleTree(**kwargs)
        self.state_format = "variablized_state"

    def ifit(self, x, y):
        """Incrementally fit one instance ``x`` with label ``y``.

        A deep copy of ``x`` is labelled, so the caller's object is left
        untouched.
        """
        x = deepcopy(x)
        x['_y_label'] = float(y)
        self.tree.ifit(x)

    def fit(self, X, y):
        """Fit the tree on instances ``X`` with labels ``y``.

        ``y`` is indexed by position, so it must supply a label for every
        instance in ``X``. The instances are deep-copied before labelling.
        """
        X = deepcopy(X)
        for i, x in enumerate(X):
            # Each instance gets its own label; converting the whole ``y``
            # sequence with float() fails for any list of labels.
            x['_y_label'] = float(y[i])
        self.tree.fit(X, randomize_first=False)

    def skill_info(self, X):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError("Not implemented Erik H. says there is a way \
             to serialize this -> TODO")

    def predict(self, X):
        """Return a list with the predicted ``'_y_label'`` of each instance."""
        return [self.tree.categorize(x).predict('_y_label') for x in X]
# Targets are sin(X); every 5th target is then perturbed by
# 3 * (0.5 - rng.rand(...)). The number of random draws is taken from the
# slice itself so it always matches the number of perturbed samples.
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(len(y[::5])))

# Fit regression models (Decision Tree and TRESTLE)
# For TRESTLE the y attribute is hidden, so only the X is used to make
# predictions.
dtree = DecisionTreeRegressor(max_depth=3)
dtree.fit(X, y)
ttree = TrestleTree()
# One instance dict per sample: the first feature of the row under 'x' and
# the target under '_y'.
training_data = [{
    'x': float(row[0]),
    '_y': float(target)
} for row, target in zip(X, y)]
ttree.fit(training_data, iterations=1)

# Predict
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_dtree = dtree.predict(X_test)
# Each row of X_test is a length-1 array; index it explicitly rather than
# relying on implicit array-to-scalar conversion (deprecated in NumPy 1.25).
y_trestle = [ttree.categorize({'x': float(v[0])}).predict('_y')
             for v in X_test]

# Plot the results
plt.figure()
plt.scatter(X, y, c="k", label="Data")
plt.plot(X_test, y_trestle, c="g", label="TRESTLE", linewidth=2)
plt.plot(X_test, y_dtree, c="r", label="Decision Tree (Depth=3)", linewidth=2)
plt.xlabel("Data")
plt.ylabel("Target")
plt.title("TRESTLE/Decision Tree Regression")
plt.legend(loc=3)
plt.show()