Exemple #1
0
def test_branch_setter_change():
    """Check that we can switch back to a branch that already exists.

    The cleaner is only fitted while ``branch_2`` is active, so after
    returning to ``master`` that branch's pipeline must still be empty.
    """
    trainer = ATOMClassifier(X10_nan, y10, random_state=1)
    trainer.branch = "branch_2"
    trainer.clean()  # fitted on branch_2 only
    trainer.branch = "master"
    assert trainer.pipeline.empty
Exemple #2
0
def test_branch_setter_from_valid():
    """Check that a new branch can be split from a non-current branch.

    ``branch_3`` is split from ``master`` rather than from ``branch_2``
    (where the imputation happened), so its data keeps the missing values.
    """
    trainer = ATOMClassifier(X10_nan, y10, random_state=1)
    trainer.branch = "branch_2"
    trainer.impute()
    # The "_from_master" suffix picks the parent; it is not part of the name.
    trainer.branch = "branch_3_from_master"
    new_branch_name = trainer.branch.name
    assert new_branch_name == "branch_3"
    assert trainer.n_nans > 0
Exemple #3
0
def test_attrs_are_passed():
    """Check that a child branch receives the attributes of its parent.

    The index must be a distinct object in the child, while the fitted
    ``adasyn`` transformer is shared by reference with the parent.
    """
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.balance()
    trainer.branch = "branch_2"
    child, parent = trainer.branch_2, trainer.master
    assert child.idx is not parent.idx
    assert child.adasyn is parent.adasyn
Exemple #4
0
def test_transform_data_multiple_branches():
    """Check that a loaded instance reproduces the data of every branch."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.prune()
    trainer.branch = "branch_2"
    trainer.balance()
    trainer.feature_generation(strategy="dfs", n_features=5)
    trainer.branch = "branch_3"  # split from the active branch_2
    trainer.feature_selection(strategy="sfm", solver="lgb", n_features=20)

    # Saved without data: with transform_data=True the loader has to
    # recompute each branch's data from the (X_bin, y_bin) we pass in.
    path = FILE_DIR + "atom_2"
    trainer.save(path, save_data=False)
    loaded = ATOMLoader(path, data=(X_bin, y_bin), transform_data=True)

    for name in trainer._branches:
        expected = trainer._branches[name].data
        assert loaded._branches[name].data.equals(expected)
Exemple #5
0
def test_repr():
    """Check that the string representation lists the branch(es)."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.scale()
    single_branch_repr = str(trainer)
    assert "Branches: master" in single_branch_repr

    trainer.branch = "branch_2"
    multi_branch_repr = str(trainer)
    # One line per branch; the trailing "!" presumably marks the active one.
    assert "Branches:\n   >>> master\n   >>> branch_2 !" in multi_branch_repr
Exemple #6
0
def test_plot_partial_dependence(features):
    """Assert that the plot_partial_dependence method works as intended.

    Exercises the input-validation errors (too many features, invalid
    feature groupings, unknown feature names/indices, models that use
    different features, invalid targets) for binary and multiclass tasks,
    and then checks that valid calls succeed.

    ``features`` is presumably supplied by a pytest parametrization or
    fixture defined elsewhere (TODO confirm); it is only used in the final
    multiclass plot.
    """
    # For binary classification tasks
    atom = ATOMClassifier(X_bin, y_bin, random_state=1)
    # Plotting before any model has been trained must fail
    pytest.raises(NotFittedError, atom.plot_partial_dependence)
    atom.run(["Tree", "LGB"], metric="f1")

    # More than 3 features
    with pytest.raises(ValueError, match=r".*Maximum 3 allowed.*"):
        atom.plot_partial_dependence(features=[0, 1, 2, 3], display=False)

    # Triple feature
    with pytest.raises(ValueError, match=r".*should be single or in pairs.*"):
        atom.lgb.plot_partial_dependence(features=[(0, 1, 2), 2],
                                         display=False)

    # Pair for multi-model
    with pytest.raises(ValueError, match=r".*when plotting multiple models.*"):
        atom.plot_partial_dependence(features=[(0, 2), 2], display=False)

    # Unknown feature
    with pytest.raises(ValueError, match=r".*not found in the dataset.*"):
        atom.plot_partial_dependence(features=["test", 2], display=False)

    # Invalid index
    with pytest.raises(ValueError, match=r".*got index.*"):
        atom.plot_partial_dependence(features=[120, 2], display=False)

    # Different features for multiple models
    # (PCA on branch_2 gives Tree2 a different feature set than Tree/LGB)
    atom.branch = "branch_2"
    atom.feature_selection(strategy="pca", n_features=5)
    atom.run(["tree2"])
    with pytest.raises(ValueError, match=r".*models use the same features.*"):
        atom.plot_partial_dependence(features=(0, 1), display=False)

    # Remove the model and branch first: a branch with depending models
    # can't be deleted, and the remaining plots need a shared feature set.
    atom.delete("Tree2")  # Drop model created for test
    atom.branch.delete()  # Drop branch created for test
    atom.plot_partial_dependence(display=False)
    atom.lgb.plot_feature_importance(show=5, display=False)
    atom.lgb.plot_partial_dependence(display=False)

    # For multiclass classification tasks
    atom = ATOMClassifier(X_class, y_class, random_state=1)
    atom.run(["Tree", "LGB"], metric="f1_macro")

    # Invalid target int
    with pytest.raises(ValueError, match=r".*classes, got .*"):
        atom.plot_partial_dependence(target=5, display=False)

    # Invalid target str
    with pytest.raises(ValueError, match=r".*not found in the mapping.*"):
        atom.plot_partial_dependence(target="Yes", display=False)

    atom.lgb.plot_partial_dependence(features,
                                     target=2,
                                     title="title",
                                     display=False)
Exemple #7
0
def test_stacking_models_from_branch():
    """Check that stacking only uses models trained on the active branch."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.run(["LR", "LGB"])  # trained on master

    trainer.branch = "branch_2"
    trainer.balance()
    trainer.run(["RF", "ET"])  # trained on branch_2

    trainer.stacking()
    stacked_models = trainer.stack.models
    assert stacked_models == ["RF", "ET"]
Exemple #8
0
def test_vote_branch_transformation(pipeline):
    """Check that voting across branches transforms with each estimator once.

    ``pipeline`` is presumably provided by a pytest parametrization defined
    elsewhere; it is forwarded unchanged to ``predict``.
    """
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.clean()
    trainer.impute()
    trainer.branch = "branch_2"
    trainer.encode()
    trainer.run(models=["Tree", "LGB"])
    trainer.voting()
    predictions = trainer.vote.predict(X_bin, pipeline=pipeline)
    assert isinstance(predictions, np.ndarray)
Exemple #9
0
def test_vote_prediction_methods():
    """Assert that the prediction methods work as intended.

    ``Tree`` is trained on ``master`` and ``LGB`` on ``branch_2`` (which adds
    an imputer), so the voting ensemble combines models from two branches.
    ``decision_function`` is expected to be unavailable on the ensemble.
    """
    atom = ATOMClassifier(X_bin, y_bin, random_state=1)
    atom.clean()
    atom.run(models=["Tree"])
    atom.branch = "branch_2"
    atom.impute(strat_num="mean", strat_cat="most_frequent")
    atom.run(["LGB"])
    atom.voting(models=["Tree", "LGB"])

    # Context-manager form: the AttributeError is caught whether it is raised
    # when the method is looked up or when it is called. The callable form
    # evaluates `atom.vote.decision_function` before pytest.raises is active.
    with pytest.raises(AttributeError):
        atom.vote.decision_function(X_bin)

    assert isinstance(atom.vote.predict(X_bin), np.ndarray)
    assert isinstance(atom.vote.predict_proba(X_bin), np.ndarray)
    assert isinstance(atom.vote.score(X_bin, y_bin), np.float64)
Exemple #10
0
def test_branch_delete_current():
    """Check that the currently active branch can be deleted."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.branch = "branch_2"
    active = trainer.branch
    active.delete()
    assert "branch_2" not in trainer._branches
Exemple #11
0
def test_branch_delete_depending_models():
    """Check that a branch that models depend on cannot be deleted."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.branch = "branch_2"
    trainer.run("LR")  # LR is now trained on branch_2
    with pytest.raises(PermissionError):
        trainer.branch.delete()
Exemple #12
0
def test_getattr_branch():
    """Check that a branch is reachable as an attribute of the trainer."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.branch = "branch_2"
    via_attribute = trainer.branch_2
    assert via_attribute is trainer._branches["branch_2"]
Exemple #13
0
def test_branch_setter_from_invalid():
    """Check that splitting from a non-existent branch raises an error."""
    trainer = ATOMClassifier(X10_nan, y10, random_state=1)
    expected_msg = r".*branch to split from does not exist.*"
    with pytest.raises(ValueError, match=expected_msg):
        trainer.branch = "new_branch_from_invalid"
Exemple #14
0
def test_branch_setter_new():
    """Check that assigning an unknown name creates a new branch."""
    trainer = ATOMClassifier(X10_nan, y10, random_state=1)
    trainer.clean()
    trainer.branch = "branch_2"
    branch_names = list(trainer._branches.keys())
    assert branch_names == ["master", "branch_2"]
Exemple #15
0
def test_branch_delete_invalid_name():
    """Check that deleting a branch under an unknown name raises an error."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.branch = "branch_2"
    with pytest.raises(ValueError):
        trainer.branch.delete("invalid")
Exemple #16
0
def test_delattr_branch():
    """Check that ``del`` on the branch attribute removes the active branch."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.branch = "branch_2"
    del trainer.branch
    remaining = list(trainer._branches.keys())
    assert remaining == ["master"]
Exemple #17
0
def test_branch_delete_not_current():
    """Check that a branch other than the active one can be deleted."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.branch = "branch_2"
    trainer.branch.delete("master")  # branch_2 remains the active branch
    assert "master" not in trainer._branches
Exemple #18
0
def test_branch_setter_empty():
    """Check that an empty branch name is rejected."""
    trainer = ATOMClassifier(X10_nan, y10, random_state=1)
    expected_msg = r".*Can't create a branch.*"
    with pytest.raises(ValueError, match=expected_msg):
        trainer.branch = ""