import pytest

from sklearn.utils.estimator_checks import check_estimator

from skltemplate import TemplateEstimator
from skltemplate import TemplateClassifier
from skltemplate import TemplateTransformer


@pytest.mark.parametrize(
    "estimator",
    [TemplateEstimator(), TemplateTransformer(), TemplateClassifier()]
)
def test_all_estimators(estimator):
    """Run scikit-learn's ``check_estimator`` on one template estimator.

    Parameters
    ----------
    estimator : estimator instance
        Supplied by the ``parametrize`` decorator: a default-constructed
        ``TemplateEstimator``, ``TemplateTransformer`` or
        ``TemplateClassifier``.
    """
    # Deliberately no ``return``: pytest test functions are expected to
    # return None (a non-None value triggers PytestReturnNotNoneWarning),
    # and whatever ``check_estimator`` hands back is not needed here.
    check_estimator(estimator)
"""
===========================
Plotting Template Estimator
===========================

An example plot of TemplateEstimator
"""
import numpy as np
from skltemplate import TemplateEstimator
from matplotlib import pyplot as plt

# Training data: a single feature column holding 0..99 and an all-zero target.
X = np.arange(100).reshape(-1, 1)
y = np.zeros(100)

# Fit the template estimator on the data above.
estimator = TemplateEstimator()
estimator.fit(X, y)

# Plot the predictions for the training inputs and display the figure.
predictions = estimator.predict(X)
plt.plot(predictions)
plt.show()
def test_demo():
    """Check ``TemplateEstimator.predict`` against the squared first column.

    A random ``(100, 10)`` array is used as the input, its first column is
    used as the fit target, and the prediction on the same array is compared
    (via ``assert_almost_equal``) with the element-wise square of that
    first column.
    """
    samples = np.random.random((100, 10))
    first_feature = samples[:, 0]

    model = TemplateEstimator()
    model.fit(samples, first_feature)

    # Predict first, then build the expected values, mirroring the order in
    # which the two arguments of the assertion are evaluated.
    predicted = model.predict(samples)
    expected = samples[:, 0] ** 2
    assert_almost_equal(predicted, expected)