import pytest

from sklearn.utils.estimator_checks import check_estimator

from skltemplate import TemplateEstimator
from skltemplate import TemplateClassifier
from skltemplate import TemplateTransformer


@pytest.mark.parametrize(
    "estimator",
    [TemplateEstimator(), TemplateTransformer(), TemplateClassifier()]
)
def test_all_estimators(estimator):
    """Run scikit-learn's ``check_estimator`` on one template estimator.

    Parameters
    ----------
    estimator : object
        An estimator instance supplied by the ``parametrize`` decorator.
    """
    # The result is deliberately not returned: pytest expects test functions
    # to return None and warns when they return anything else.
    check_estimator(estimator)
"""
===========================
Plotting Template Estimator
===========================

An example that fits a TemplateEstimator and plots its predictions.
"""
import numpy as np
from skltemplate import TemplateEstimator
from matplotlib import pyplot as plt

# Training data: 100 samples with a single feature and an all-zero target.
X = np.arange(100).reshape(-1, 1)
y = np.zeros(100)

# Fit the estimator on the data above.
estimator = TemplateEstimator()
estimator.fit(X, y)

# Plot the predictions made for the training inputs and display the figure.
y_pred = estimator.predict(X)
plt.plot(y_pred)
plt.show()
def test_demo():
    """Check TemplateEstimator's predictions on random input.

    Fits the estimator using the first feature column as the target, then
    asserts that ``predict`` returns the square of that column (compared
    with ``assert_almost_equal``).
    """
    # Imported locally so the test does not rely on a file-level import of
    # the assertion helper.
    from numpy.testing import assert_almost_equal

    # A seeded generator keeps the input identical across runs, so any
    # failure is reproducible.
    rng = np.random.RandomState(0)
    X = rng.random_sample((100, 10))
    estimator = TemplateEstimator()
    estimator.fit(X, X[:, 0])
    assert_almost_equal(estimator.predict(X), X[:, 0]**2)