コード例 #1
0
def test_class_enumerator_actual_module():
    """Enumerate the real ``cuml.linear_model`` package.

    Three classes are excluded and ``LogisticRegression`` is given a custom
    factory; every other entry should come back as the plain class object.
    """
    excluded = [
        cuml.LinearRegression,
        cuml.MBSGDClassifier,
        cuml.MBSGDRegressor,
    ]
    enumerator = ClassEnumerator(
        module=cuml.linear_model,
        exclude_classes=excluded,
        custom_constructors={
            'LogisticRegression':
            lambda: cuml.LogisticRegression(handle=1)
        },
    )
    found = enumerator.get_models()

    expected = {
        'ElasticNet': cuml.ElasticNet,
        'Lasso': cuml.Lasso,
        'LogisticRegression': lambda: cuml.LogisticRegression(handle=1),
        'Ridge': cuml.Ridge,
    }

    # Two distinct lambdas never compare equal, so the factory entry is
    # checked through the object it builds and removed from both mappings
    # before the remaining class entries are compared as whole dicts.
    built_handle = found.pop('LogisticRegression')().handle
    expected_handle = expected.pop('LogisticRegression')().handle
    assert built_handle == expected_handle
    assert found == expected
コード例 #2
0
def test_class_enumerator():
    """ClassEnumerator should honour exclusions and custom constructors.

    A throw-away namespace holds three ``cuml.Base`` subclasses. The excluded
    one must be dropped, the plain one returned as the class itself, and the
    custom one returned as a factory that builds an equal instance.
    """
    class SomeModule:
        class SomeClass(cuml.Base):
            pass

        class ExcludedClass(cuml.Base):
            pass

        class CustomConstructorClass(cuml.Base):
            # The keyword-only argument has no default, so a bare ``cls()``
            # call would fail; this class relies on a custom constructor.
            def __init__(self, *, some_parameter):
                self.some_parameter = some_parameter

            # Value equality so two independently built instances compare.
            def __eq__(self, other):
                return self.some_parameter == other.some_parameter

    def make_custom():
        return SomeModule.CustomConstructorClass(some_parameter=1)

    enumerator = ClassEnumerator(
        module=SomeModule,
        exclude_classes=[SomeModule.ExcludedClass],
        custom_constructors={"CustomConstructorClass": make_custom},
    )
    found = enumerator.get_models()

    expected = {
        "SomeClass": SomeModule.SomeClass,
        "CustomConstructorClass": make_custom,
    }

    # Compare entry by entry: the custom factory is checked through the
    # instances it builds rather than by comparing the callables themselves.
    assert len(found) == len(expected) == 2
    assert found['SomeClass'] == expected['SomeClass']
    built = found['CustomConstructorClass']()
    assert built == expected['CustomConstructorClass']()
コード例 #3
0
from cuml import LinearRegression as reg
from cuml import PCA
from cuml.experimental.explainer.common import get_cai_ptr
from cuml.experimental.explainer.common import get_dtype_from_model_func
from cuml.experimental.explainer.common import get_handle_from_cuml_model_func
from cuml.experimental.explainer.common import get_link_fn_from_str_or_fn
from cuml.experimental.explainer.common import get_tag_from_model_func
from cuml.experimental.explainer.common import link_dict
from cuml.experimental.explainer.common import model_func_call
from cuml.test.utils import ClassEnumerator
from cuml.datasets import make_regression
from sklearn.linear_model import LinearRegression as skreg


# Collect the model classes ClassEnumerator finds in the top-level ``cuml``
# namespace (no exclusions, no custom constructors). ``models`` is the
# resulting name -> class mapping.
# NOTE(review): only ``from cuml import ...`` lines are visible above, so the
# name ``cuml`` is presumably bound by an ``import cuml`` elsewhere in the
# original file — confirm.
models_config = ClassEnumerator(module=cuml)
models = models_config.get_models()

_default_tags = [
    'preferred_input_order',
    'X_types_gpu',
    'non_deterministic',
    'requires_positive_X',
    'requires_positive_y',
    'X_types',
    'poor_score',
    'no_validation',
    'multioutput',
    'allow_nan',
    'stateless',
    'multilabel',
    '_skip_test',
コード例 #4
0
import numpy as np
import pickle
import pytest

from cuml.test import test_arima
from cuml.tsa.arima import ARIMA
from cuml.test.utils import array_equal, unit_param, stress_param, \
    ClassEnumerator, get_classes_from_package
from cuml.test.test_svm import compare_svm
from sklearn.base import clone
from sklearn.datasets import load_iris, make_classification, make_regression
from sklearn.manifold.t_sne import trustworthiness
from sklearn.model_selection import train_test_split

# Per-submodule model tables. Each ClassEnumerator describes which classes to
# pull from a cuml submodule; get_models() then turns it into a
# name -> class (or factory) mapping.
regression_config = ClassEnumerator(module=cuml.linear_model)

# QN is built with a softmax loss because some of the tests use multiclass
# logistic regression, which requires a softmax loss.
solver_config = ClassEnumerator(
    module=cuml.solvers,
    custom_constructors={"QN": lambda: cuml.QN(loss="softmax")},
)

# DBSCAN is left out of the cluster table.
cluster_config = ClassEnumerator(
    module=cuml.cluster,
    exclude_classes=[cuml.DBSCAN],
)

decomposition_config = ClassEnumerator(module=cuml.decomposition)

regression_models = regression_config.get_models()
solver_models = solver_config.get_models()
cluster_models = cluster_config.get_models()
decomposition_models = decomposition_config.get_models()