예제 #1
0
def test_delete_default():
    """Assert that delete() without arguments clears models and metric."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.run(["LR", "LDA"])
    trainer.delete()

    # Neither a model nor a metric may be left behind.
    assert not trainer.models
    assert not trainer.metric
    # The results table must have no rows left either.
    assert trainer.results.empty
예제 #2
0
def test_delete_models_is_sequence():
    """Assert that delete() removes every model named in a sequence."""
    to_drop = ["Tree", "RF"]

    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.run(["LR", *to_drop])
    trainer.delete(to_drop)

    # Only LR is expected to remain, which also makes it the winner.
    assert trainer.models == "LR"
    assert trainer.winner is trainer.LR
    assert len(trainer.results) == 1
예제 #3
0
def test_plot_partial_dependence(features):
    """Assert that the plot_partial_dependence method works as intended.

    Parameters
    ----------
    features
        Value forwarded as the features argument of the final multiclass
        plot call. Presumably supplied by a fixture or parametrization
        defined outside this block -- confirm against the test module.

    """
    # ---- Binary classification task ----
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)

    # Plotting before any model has been run must raise
    with pytest.raises(NotFittedError):
        trainer.plot_partial_dependence()

    trainer.run(["Tree", "LGB"], metric="f1")

    # Four features requested at once
    too_many = [0, 1, 2, 3]
    with pytest.raises(ValueError, match=r".*Maximum 3 allowed.*"):
        trainer.plot_partial_dependence(features=too_many, display=False)

    # A tuple holding three features
    with_triple = [(0, 1, 2), 2]
    with pytest.raises(ValueError, match=r".*should be single or in pairs.*"):
        trainer.lgb.plot_partial_dependence(
            features=with_triple,
            display=False,
        )

    # A feature pair while more than one model is being plotted
    with_pair = [(0, 2), 2]
    with pytest.raises(ValueError, match=r".*when plotting multiple models.*"):
        trainer.plot_partial_dependence(features=with_pair, display=False)

    # A feature name that is not present
    with_unknown = ["test", 2]
    with pytest.raises(ValueError, match=r".*not found in the dataset.*"):
        trainer.plot_partial_dependence(features=with_unknown, display=False)

    # An invalid feature index
    with_bad_index = [120, 2]
    with pytest.raises(ValueError, match=r".*got index.*"):
        trainer.plot_partial_dependence(
            features=with_bad_index,
            display=False,
        )

    # Train an extra model on a second branch after PCA feature selection,
    # so that the models no longer share the same features
    trainer.branch = "branch_2"
    trainer.feature_selection(strategy="pca", n_features=5)
    trainer.run(["tree2"])
    with pytest.raises(ValueError, match=r".*models use the same features.*"):
        trainer.plot_partial_dependence(features=(0, 1), display=False)

    # Remove the model and the branch that only existed for the check above
    trainer.delete("Tree2")
    trainer.branch.delete()

    # Valid calls on the remaining models
    trainer.plot_partial_dependence(display=False)
    trainer.lgb.plot_feature_importance(show=5, display=False)
    trainer.lgb.plot_partial_dependence(display=False)

    # ---- Multiclass classification task ----
    trainer = ATOMClassifier(X_class, y_class, random_state=1)
    trainer.run(["Tree", "LGB"], metric="f1_macro")

    # Integer target that is rejected
    with pytest.raises(ValueError, match=r".*classes, got .*"):
        trainer.plot_partial_dependence(target=5, display=False)

    # String target that is rejected
    with pytest.raises(ValueError, match=r".*not found in the mapping.*"):
        trainer.plot_partial_dependence(target="Yes", display=False)

    # Valid call with explicit features, target and title
    trainer.lgb.plot_partial_dependence(
        features,
        target=2,
        title="title",
        display=False,
    )
예제 #4
0
def test_delete_models_is_str():
    """Assert that a string argument to delete() removes a single model."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.run(["LR", "Tree"])

    # The assertions below expect "winner" to select LR for removal.
    trainer.delete("winner")

    remaining = trainer.models
    assert remaining == "Tree"
    assert trainer.winner is trainer.Tree
    assert len(trainer.results) == 1
    # The deleted model must no longer be reachable as an attribute.
    assert not hasattr(trainer, "LR")
예제 #5
0
def test_delete_duplicates():
    """Assert that naming the same model twice in delete() is harmless."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.run("LR")

    # Same model listed twice; the call is expected to leave no models.
    repeated = ["LR"] * 2
    trainer.delete(repeated)

    assert not trainer.models
예제 #6
0
def test_delete_general_number():
    """Assert that a bare number removes all models carrying that number."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.run(["LR0", "RF0"])

    # Both models share the "0" suffix, so none should survive.
    shared_number = "0"
    trainer.delete(shared_number)

    assert not trainer.models
예제 #7
0
def test_delete_general_name():
    """Assert that a bare acronym removes all models of that acronym."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.run(["LR1", "LR2"])

    # Both models share the "LR" acronym, so none should survive.
    shared_acronym = "LR"
    trainer.delete(shared_acronym)

    assert not trainer.models