def test_branch_setter_change():
    """Check that assigning an existing branch name switches back to it."""
    trainer = ATOMClassifier(X10_nan, y10, random_state=1)

    # The cleaning step is applied after moving to the second branch.
    trainer.branch = "branch_2"
    trainer.clean()

    # Back on master, the pipeline should hold no clean estimator.
    trainer.branch = "master"
    pipeline = trainer.pipeline
    assert pipeline.empty
def test_branch_setter_from_valid():
    """Check that a branch can be created from a non-current branch."""
    trainer = ATOMClassifier(X10_nan, y10, random_state=1)
    trainer.branch = "branch_2"
    trainer.impute()

    # Request branch_3 split off master instead of the imputed branch_2.
    trainer.branch = "branch_3_from_master"

    current = trainer.branch
    assert current.name == "branch_3"

    # Missing values are still present, so the imputation was not inherited.
    remaining_nans = trainer.n_nans
    assert remaining_nans > 0
def test_attrs_are_passed():
    """Check that a new branch receives the attributes of its parent."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.balance()
    trainer.branch = "branch_2"

    child = trainer.branch_2
    parent = trainer.master

    # The index container must be a separate object for each branch...
    assert child.idx is not parent.idx
    # ...while the adasyn attribute is the very same object in both.
    assert child.adasyn is parent.adasyn
def test_transform_data_multiple_branches():
    """Check that loading with transform_data=True transforms every branch."""
    original = ATOMClassifier(X_bin, y_bin, random_state=1)
    original.prune()

    original.branch = "branch_2"
    original.balance()
    original.feature_generation(strategy="dfs", n_features=5)

    original.branch = "branch_3"
    original.feature_selection(strategy="sfm", solver="lgb", n_features=20)

    # The data is left out of the save file and handed to the loader
    # instead, together with the request to transform it.
    filename = FILE_DIR + "atom_2"
    original.save(filename, save_data=False)
    loaded = ATOMLoader(filename, data=(X_bin, y_bin), transform_data=True)

    # Every branch of the loaded instance must hold the same data.
    for name in original._branches:
        restored = loaded._branches[name].data
        assert restored.equals(original._branches[name].data)
def test_repr():
    """Check that the string representation lists the branches."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.scale()

    # With a single branch, the name is shown inline.
    single_branch = str(trainer)
    assert "Branches: master" in single_branch

    # With more branches, every branch gets its own line.
    trainer.branch = "branch_2"
    multiple_branches = str(trainer)
    assert "Branches:\n >>> master\n >>> branch_2 !" in multiple_branches
def test_plot_partial_dependence(features):
    """Check the errors and happy paths of plot_partial_dependence."""
    # ---- Binary classification task ----
    binary = ATOMClassifier(X_bin, y_bin, random_state=1)

    # The method is looked up first, so only the call is expected to raise.
    plot_unfitted = binary.plot_partial_dependence
    with pytest.raises(NotFittedError):
        plot_unfitted()

    binary.run(["Tree", "LGB"], metric="f1")

    # Invalid ``features`` arguments, checked in this order:
    # (plotting method, features, expected error message).
    invalid_calls = (
        # More than 3 features
        (binary.plot_partial_dependence, [0, 1, 2, 3], r".*Maximum 3 allowed.*"),
        # Triple feature
        (
            binary.lgb.plot_partial_dependence,
            [(0, 1, 2), 2],
            r".*should be single or in pairs.*",
        ),
        # Pair for multi-model
        (
            binary.plot_partial_dependence,
            [(0, 2), 2],
            r".*when plotting multiple models.*",
        ),
        # Unknown feature
        (
            binary.plot_partial_dependence,
            ["test", 2],
            r".*not found in the dataset.*",
        ),
        # Invalid index
        (binary.plot_partial_dependence, [120, 2], r".*got index.*"),
    )
    for plot, wrong_features, message in invalid_calls:
        with pytest.raises(ValueError, match=message):
            plot(features=wrong_features, display=False)

    # Different features for multiple models
    binary.branch = "branch_2"
    binary.feature_selection(strategy="pca", n_features=5)
    binary.run(["tree2"])
    with pytest.raises(ValueError, match=r".*models use the same features.*"):
        binary.plot_partial_dependence(features=(0, 1), display=False)

    # Clean up the model and the branch that were only created for this test.
    binary.delete("Tree2")
    binary.branch.delete()

    binary.plot_partial_dependence(display=False)
    binary.lgb.plot_feature_importance(show=5, display=False)
    binary.lgb.plot_partial_dependence(display=False)

    # ---- Multiclass classification task ----
    multiclass = ATOMClassifier(X_class, y_class, random_state=1)
    multiclass.run(["Tree", "LGB"], metric="f1_macro")

    # Invalid target int
    with pytest.raises(ValueError, match=r".*classes, got .*"):
        multiclass.plot_partial_dependence(target=5, display=False)

    # Invalid target str
    with pytest.raises(ValueError, match=r".*not found in the mapping.*"):
        multiclass.plot_partial_dependence(target="Yes", display=False)

    multiclass.lgb.plot_partial_dependence(
        features,
        target=2,
        title="title",
        display=False,
    )
def test_stacking_models_from_branch():
    """Check that stacking only uses the models of the current branch."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.run(["LR", "LGB"])  # Run before the second branch exists

    trainer.branch = "branch_2"
    trainer.balance()
    trainer.run(["RF", "ET"])  # Run after switching to branch_2

    trainer.stacking()
    stacked_models = trainer.stack.models
    assert stacked_models == ["RF", "ET"]
def test_vote_branch_transformation(pipeline):
    """Check that voting predicts on a branch that extends earlier steps."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.clean()
    trainer.impute()

    # The encoding step is added after switching to the new branch.
    trainer.branch = "branch_2"
    trainer.encode()
    trainer.run(models=["Tree", "LGB"])
    trainer.voting()

    predictions = trainer.vote.predict(X_bin, pipeline=pipeline)
    assert isinstance(predictions, np.ndarray)
def test_vote_prediction_methods():
    """Check the prediction methods of a voting model over two branches."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.clean()
    trainer.run(models=["Tree"])

    trainer.branch = "branch_2"
    trainer.impute(strat_num="mean", strat_cat="most_frequent")
    trainer.run(["LGB"])
    trainer.voting(models=["Tree", "LGB"])

    # Fetch the method up front: only calling it is expected to raise.
    decision_function = trainer.vote.decision_function
    with pytest.raises(AttributeError):
        decision_function(X_bin)

    predictions = trainer.vote.predict(X_bin)
    assert isinstance(predictions, np.ndarray)

    probabilities = trainer.vote.predict_proba(X_bin)
    assert isinstance(probabilities, np.ndarray)

    score = trainer.vote.score(X_bin, y_bin)
    assert isinstance(score, np.float64)
def test_branch_delete_current():
    """Check that the branch that is currently active can be deleted."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.branch = "branch_2"

    # Called without a name, delete is applied on the active branch.
    trainer.branch.delete()

    remaining = trainer._branches
    assert "branch_2" not in remaining
def test_branch_delete_depending_models():
    """Check that a branch with depending models can't be deleted."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.branch = "branch_2"
    trainer.run("LR")  # Model run while branch_2 is active

    # Fetch the method up front: only calling it is expected to raise.
    delete_branch = trainer.branch.delete
    with pytest.raises(PermissionError):
        delete_branch()
def test_getattr_branch():
    """Check that a branch is reachable as an attribute of the trainer."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.branch = "branch_2"

    # Attribute access must return the exact object stored in _branches.
    via_attribute = trainer.branch_2
    assert via_attribute is trainer._branches["branch_2"]
def test_branch_setter_from_invalid():
    """Check that splitting from a non-existent branch raises an error."""
    trainer = ATOMClassifier(X10_nan, y10, random_state=1)

    expected_message = r".*branch to split from does not exist.*"
    with pytest.raises(ValueError, match=expected_message):
        trainer.branch = "new_branch_from_invalid"
def test_branch_setter_new():
    """Check that assigning an unknown name creates a new branch."""
    trainer = ATOMClassifier(X10_nan, y10, random_state=1)
    trainer.clean()
    trainer.branch = "branch_2"

    # The new branch is listed after the one that already existed.
    branch_names = list(trainer._branches.keys())
    assert branch_names == ["master", "branch_2"]
def test_branch_delete_invalid_name():
    """Check that deleting a branch with an invalid name raises an error."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.branch = "branch_2"

    # Fetch the method up front: only calling it is expected to raise.
    delete_branch = trainer.branch.delete
    with pytest.raises(ValueError):
        delete_branch("invalid")
def test_delattr_branch():
    """Check that the del statement removes the current branch."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.branch = "branch_2"

    del trainer.branch

    # Only the initial branch is left.
    branch_names = list(trainer._branches.keys())
    assert branch_names == ["master"]
def test_branch_delete_not_current():
    """Check that a branch other than the current one can be deleted."""
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    trainer.branch = "branch_2"

    # Delete master by name while branch_2 is the active branch.
    trainer.branch.delete("master")

    remaining = trainer._branches
    assert "master" not in remaining
def test_branch_setter_empty():
    """Check that an empty branch name raises an error."""
    trainer = ATOMClassifier(X10_nan, y10, random_state=1)

    expected_message = r".*Can't create a branch.*"
    with pytest.raises(ValueError, match=expected_message):
        trainer.branch = ""