def test_parameter_validation():
    """Test parameter validation checks for consistent assignment.

    Each entry in ``test_cases`` is a set of keyword arguments forwarded to
    ``make_classifier_and_data``. Every combination is expected to be
    rejected: calling ``fit`` on the resulting classifier must raise
    ``TypeError``.

    The cases are checked in a loop, so each assertion carries the offending
    kwargs as its failure reason. That way a failing run reports exactly
    which combination was not rejected.
    """
    test_cases = [
        # prediction_depth / stopping_criteria combinations.
        # 'nmlnp' with no stopping criterion.
        dict(
            prediction_depth="nmlnp",
            stopping_criteria=None,
        ),
        # 'nmlnp' with a stopping criterion that is neither float nor callable.
        dict(
            prediction_depth="nmlnp",
            stopping_criteria="not_a_float_or_a_callable",
        ),
        # 'mlnp' combined with a stopping criterion.
        dict(
            prediction_depth="mlnp",
            stopping_criteria=123.4,
        ),
        # Unrecognized prediction_depth value.
        dict(prediction_depth="some_invalid_prediction_depth_value"),
        # algorithm / training_strategy combinations.
        # 'lcn' with no training strategy.
        dict(
            algorithm="lcn",
            training_strategy=None,
        ),
        # 'lcn' with an unrecognized training strategy.
        dict(
            algorithm="lcn",
            training_strategy="some_invalid_training_strategy",
        ),
        # 'lcpn' combined with the 'exclusive' training strategy.
        dict(
            algorithm="lcpn",
            training_strategy="exclusive",
        ),
        # Unrecognized algorithm value.
        dict(algorithm="some_invalid_algorithm_value"),
    ]

    for classifier_kwargs in test_cases:
        clf, (X, y) = make_classifier_and_data(**classifier_kwargs)
        # The reason is passed positionally (third argument of assert_that)
        # so it works with both PyHamcrest 1.x and 2.x signatures.
        assert_that(
            calling(clf.fit).with_args(X=X, y=y),
            raises(TypeError),
            "fit() should raise TypeError for {!r}".format(classifier_kwargs),
        )


def test_trivial_hierarchy_classification():
    """Check that a trivial (degenerate) hierarchy classifies accurately.

    A 5-class classifier from ``make_classifier_and_data`` is trained on 70%
    of the generated data. Its accuracy on the held-out 30% must lie within
    0.05 of perfect (1.0).
    """
    clf, (features, labels) = make_classifier_and_data(n_classes=5)

    # Seeded split so the train/test partition is reproducible across runs.
    split = train_test_split(
        features,
        labels,
        test_size=0.30,
        random_state=RANDOM_STATE,
    )
    features_train, features_test, labels_train, labels_test = split

    clf.fit(features_train, labels_train)
    predicted = clf.predict(features_test)

    assert_that(
        accuracy_score(labels_test, predicted),
        is_(close_to(1., delta=0.05)),
    )


def test_fitted_attributes():
    """Check the attributes a classifier exposes once it has been fitted.

    After ``fit`` the classifier must expose:

    - ``class_hierarchy_`` and ``graph_``, both matching ``class_hierarchy``;
    - ``classes_``, holding every label in ``range(n_classes)``;
    - ``n_classes_``, equal to ``n_classes``;
    - ``metafeatures`` on the ``ROOT`` node of ``graph_``, with the sample
      and target counts.
    """
    n_classes = 10
    clf, (X, y) = make_classifier_and_data(n_classes=n_classes)
    clf.fit(X, y)

    # Both fitted views of the hierarchy must match the `class_hierarchy`
    # attribute; build the reference graph once and reuse it.
    expected_graph = DiGraph(clf.class_hierarchy)
    assert_that(DiGraph(clf.class_hierarchy_), matches_graph(expected_graph))
    assert_that(clf.graph_, matches_graph(expected_graph))

    # Label bookkeeping: order of classes_ is not significant.
    assert_that(clf.classes_, contains_inanyorder(*range(n_classes)))
    assert_that(clf.n_classes_, is_(equal_to(n_classes)))

    # The root node records how many samples and targets it was fitted on.
    root_metafeatures = has_entries(
        n_samples=X.shape[0],
        n_targets=n_classes,
    )
    assert_that(
        clf.graph_.nodes[ROOT],
        has_entries(metafeatures=root_metafeatures),
    )