예제 #1
0
def test_validator_parameters_validate_error(params):
    """Validating ``params`` against the config below must raise.

    NOTE(review): any ``Exception`` satisfies this check, so an unrelated
    bug inside the validator (e.g. a NameError) would also make it pass.
    Consider narrowing once the exception type for each parametrized case
    is confirmed.
    """
    spec = dict(
        penalty=Union(Enum("l1", "l2", "elasticnet", "none")),
        dual=TypeOf(bool),
        tol=Interval(float, 0.0, None, lower_inclusive=False),
        solver=Enum("newton-cg", "lbfgs", "liblinear", "sag", "saga"),
        warm_start=TypeOf(bool),
        n_jobs=Union(Interval(int, 1, None), Const(-1), Const(None)),
    )
    expected_failure = pytest.raises(Exception)
    with expected_failure:
        validate_parameters(spec, params)
예제 #2
0
def test_validator_parameters_exist(params):
    """Valid ``params`` for the declared parameters pass validation."""
    schema = dict(
        penalty=Union(Enum("l1", "l2", "elasticnet", "none")),
        dual=TypeOf(bool),
        tol=Interval(float, 0.0, None, lower_inclusive=False),
        solver=Enum("newton-cg", "lbfgs", "liblinear", "sag", "saga"),
    )
    # The test succeeds as long as this call returns without raising.
    validate_parameters(schema, params)
예제 #3
0
def test_validator_parameters_not_exist(params, error_key):
    """An undeclared parameter name must be reported with a KeyError.

    ``params`` presumably contains ``error_key``, a name missing from
    ``config`` (TODO confirm against the parametrized cases). The raised
    KeyError's message must name that key.
    """
    import re

    config = {
        "penalty": Union(Enum("l1", "l2", "elasticnet", "none")),
        "dual": TypeOf(bool),
        "C": Interval(float, 0.0, None),
        # NOTE(review): a bare ``True`` rather than a Parameter object like
        # the other entries -- confirm this is intentional.
        "fit_intercept": True
    }
    # ``match`` is a regular expression (searched with re.search), so the
    # key is escaped to be matched literally even if it contains
    # metacharacters; str() keeps non-string keys working as before.
    msg = "{} is not a valid parameter".format(re.escape(str(error_key)))
    with pytest.raises(KeyError, match=msg):
        validate_parameters(config, params)
예제 #4
0
def test_validator_conditions_exists(params):
    """``params`` consistent with the kernel-dependent conditions validate."""
    # For each value of "kernel", a list of other parameter names. Presumably
    # these are the parameters relevant to that kernel -- confirm against
    # validate_parameters.
    kernel_rules = {
        "poly": ["degree", "gamma", "coef0"],
        "rbf": ["gamma"],
        "sigmoid": ["gamma", "coef0"],
    }
    spec = {
        # NOTE(review): "percomputed" looks like a typo for "precomputed";
        # kept as-is because parametrized cases not shown here may use it.
        "kernel": Enum("linear", "poly", "rbf", "sigmoid", "percomputed"),
        "degree": Interval(int, 1, None),
        "gamma": Union(Interval(float, 0, None), Const("auto")),
        "coef0": Interval(float, 0.0, None),
        "tol": Interval(float, 0.0, None, lower_inclusive=False),
        "_conditions": {"kernel": kernel_rules},
    }
    validate_parameters(spec, params)
예제 #5
0
def test_validator_conditions_error(params, msg):
    """Validation with this ``_conditions`` config must raise ValueError(msg)."""
    # Per-value parameter lists for "kernel" (same shape as the sibling
    # conditions test).
    kernel_rules = {
        "poly": ["degree", "gamma", "coef0"],
        "rbf": ["gamma"],
        "sigmoid": ["gamma", "coef0"],
    }
    spec = {
        "kernel": Enum("linear", "poly", "rbf", "sigmoid", "percomputed"),
        "degree": Interval(int, 1, None),
        "gamma": Union(Interval(float, 0, None), Const("auto")),
        "coef0": Interval(float, 0.0, None),
        "tol": Interval(float, 0.0, None, lower_inclusive=False),
        # "tol" has an empty condition mapping, listed before "kernel".
        "_conditions": {"tol": {}, "kernel": kernel_rules},
    }
    expectation = pytest.raises(ValueError, match=msg)
    with expectation:
        validate_parameters(spec, params)
예제 #6
0
def test_union_removes_tags():
    """Members of a Union end up without tags, even if created with tags."""
    tagged_members = (
        TypeOf(int, tags=['control']),
        Enum('a', 'b', tags=['not good']),
    )
    union = Union(*tagged_members, tags=['deprecated'])
    assert all(not member.tags for member in union.params)
예제 #7
0
def test_union_invalid_params_init(params, msg):
    """Building a Union from invalid members raises ValueError matching msg."""
    expected_error = pytest.raises(ValueError, match=msg)
    with expected_error:
        Union(*params)
예제 #8
0
        enum.validate(value, 'tol')


@pytest.mark.parametrize(
    'params, msg',
    [
        # An empty Union is rejected.
        ((), 'parameters must have at least one item'),
        # Every member must be a Parameter instance.
        (('hello', 'world'), 'all parameters must be of type Parameter'),
        ((TypeOf(int), 3), 'all parameters must be of type Parameter'),
        ((None, Enum('hello')), 'all parameters must be of type Parameter'),
    ],
)
def test_union_invalid_params_init(params, msg):
    """Building a Union from invalid members raises ValueError matching msg."""
    expected_error = pytest.raises(ValueError, match=msg)
    with expected_error:
        Union(*params)


@pytest.mark.parametrize('union, value, msg', [
    (
        Union(TypeOf(int), Enum('hello', 'world')),
        None,
        r'tol: None is not a int and is not in \[hello, world\]',
    ),
    (
        Union(TypeOf(int), Enum('hello', 'world')),
        0.4,
        # Searched as a regex, so this matches as a substring of the full
        # message; the unescaped '.' matches any character.
        'tol: 0.4 is not a int',
    ),
])
def test_union_invalid_values(union, value, msg):
    """A value accepted by no member of the Union is rejected."""
    checker = pytest.raises(ValueError, match=msg)
    with checker:
        union.validate(value, "tol")


@pytest.mark.parametrize('union, value', [
    (Union(TypeOf(int), Enum('hello', 'world')), 'hello'),
    (Union(TypeOf(int), Enum('hello', 'world')), 'world'),
    (Union(TypeOf(int), Enum('hello', 'world')), 10),
    (Union(TypeOf(int), Enum('hello', 'world'), Const(None)), None),
    (Union(TypeOf(float), TypeOf(int)), 10),
    (Union(TypeOf(float), TypeOf(int)), 10.3),