fit_intercept),
    # Each regression factory below accepts ``fit_intercept`` (default True)
    # and forwards it unchanged to the cuml estimator's constructor.
    "LogisticRegression": lambda fit_intercept=True: cuml.LogisticRegression(
        fit_intercept=fit_intercept),
    "Lasso": lambda fit_intercept=True: cuml.Lasso(
        fit_intercept=fit_intercept),
    "Ridge": lambda fit_intercept=True: cuml.Ridge(
        fit_intercept=fit_intercept),
    "ElasticNet": lambda fit_intercept=True: cuml.ElasticNet(
        fit_intercept=fit_intercept)
}

# Zero-argument factories for the cuml solver estimators, keyed by class name.
solver_models = {
    "CD": lambda: cuml.CD(),
    # NOTE(review): eta0=0.005 is a hand-picked setting for these tests;
    # presumably the initial learning rate — confirm against cuml.SGD docs.
    "SGD": lambda: cuml.SGD(eta0=0.005),
    # QN is built with a softmax loss because some tests run multiclass
    # logistic regression, which requires a softmax loss.
    "QN": lambda: cuml.QN(loss="softmax")
}

# Zero-argument factory for the clustering estimator(s) under test.
cluster_models = {"KMeans": lambda: cuml.KMeans()}

# Zero-argument factories for the decomposition estimators under test.
decomposition_models = {
    "PCA": lambda: cuml.PCA(),
    "TruncatedSVD": lambda: cuml.TruncatedSVD(),
}

# Random-projection factories kept in a separate "xfail" group.
# NOTE(review): presumably these are expected to fail some tests — confirm
# in the tests that consume this dict.
decomposition_models_xfail = {
    "GaussianRandomProjection": lambda: cuml.GaussianRandomProjection(),
    "SparseRandomProjection": lambda: cuml.SparseRandomProjection()
}

# Zero-argument factory for the nearest-neighbors estimator.
neighbor_models = {"NearestNeighbors": lambda: cuml.NearestNeighbors()}
from cuml.test.utils import (array_equal, unit_param, stress_param,
                             ClassEnumerator, get_classes_from_package)
from cuml.test.test_svm import compare_svm
from sklearn.base import clone
from sklearn.datasets import load_iris, make_classification, make_regression
from sklearn.manifold.t_sne import trustworthiness
from sklearn.model_selection import train_test_split

# Model groups used by the tests. Each group builds a ClassEnumerator over
# one cuml sub-package and stores the result of ``get_models()`` in the
# matching ``*_models`` name (presumably a name -> estimator-factory mapping;
# confirm in cuml.test.utils.ClassEnumerator).
regression_config = ClassEnumerator(module=cuml.linear_model)
regression_models = regression_config.get_models()

solver_config = ClassEnumerator(
    module=cuml.solvers,
    # QN uses softmax here because some of the tests use multiclass
    # logistic regression, which requires a softmax loss.
    custom_constructors={"QN": lambda: cuml.QN(loss="softmax")})
solver_models = solver_config.get_models()

# DBSCAN is explicitly excluded from the enumerated cluster models.
cluster_config = ClassEnumerator(module=cuml.cluster,
                                 exclude_classes=[cuml.DBSCAN])
cluster_models = cluster_config.get_models()

decomposition_config = ClassEnumerator(module=cuml.decomposition)
decomposition_models = decomposition_config.get_models()

# Random-projection estimators are kept in a separate "xfail" group.
# NOTE(review): presumably these are expected to fail some tests — confirm
# in the tests that consume decomposition_models_xfail.
decomposition_config_xfail = ClassEnumerator(module=cuml.random_projection)
decomposition_models_xfail = decomposition_config_xfail.get_models()

neighbor_config = ClassEnumerator(module=cuml.neighbors)
neighbor_models = neighbor_config.get_models()