def test_attrs_are_passed():
    """Check that a new branch receives the attributes of its parent branch.

    The index must be a distinct object on the child branch, while the
    ``adasyn`` attribute must be the very same object as on ``master``.
    """
    atom = ATOMClassifier(X_bin, y_bin, random_state=1)
    atom.balance()
    atom.branch = "branch_2"  # branches off from master

    # Identity checks: idx is copied, adasyn is shared.
    assert atom.branch_2.idx is not atom.master.idx
    child = atom.branch_2
    parent = atom.master
    assert child.adasyn is parent.adasyn
def test_stacking_models_from_branch():
    """Check that stacking only uses the models trained on the active branch."""
    atom = ATOMClassifier(X_bin, y_bin, random_state=1)

    # Models fitted on the master branch; these must NOT end up in the stack.
    atom.run(["LR", "LGB"])

    # Switch to a fresh branch, alter its data and fit different models.
    atom.branch = "branch_2"
    atom.balance()
    atom.run(["RF", "ET"])

    atom.stacking()
    expected_models = ["RF", "ET"]
    assert atom.stack.models == expected_models
def test_transform_data_multiple_branches():
    """Check that reloading with transform_data=True reproduces every branch.

    Three branches are built (master -> branch_2 -> branch_3), atom is saved
    without its data, and then reloaded with the raw data. Each branch's
    dataset must match the one held by the original instance.
    """
    atom = ATOMClassifier(X_bin, y_bin, random_state=1)
    atom.prune()

    atom.branch = "branch_2"
    atom.balance()
    atom.feature_generation(strategy="dfs", n_features=5)

    atom.branch = "branch_3"
    atom.feature_selection(strategy="sfm", solver="lgb", n_features=20)

    save_path = FILE_DIR + "atom_2"
    atom.save(save_path, save_data=False)
    atom2 = ATOMLoader(save_path, data=(X_bin, y_bin), transform_data=True)

    for name in atom._branches:
        reloaded = atom2._branches[name].data
        expected = atom._branches[name].data
        assert reloaded.equals(expected)
def test_balance_attribute():
    """Check that balance() attaches the chosen estimator to atom.

    With ``strategy="NearMiss"`` the estimator is reachable as
    ``atom.nearmiss``.
    """
    atom = ATOMClassifier(X_bin, y_bin, random_state=1)
    atom.balance(strategy="NearMiss")
    estimator = atom.nearmiss
    assert type(estimator).__name__ == "NearMiss"
def test_balance_mapping():
    """Check that the step added by balance() inherits atom's ``mapping``."""
    atom = ATOMClassifier(X_bin, y_bin, random_state=1)
    atom.balance()
    # balance() is the only transformation applied, so it is pipeline step 0.
    first_step = atom.pipeline[0]
    assert first_step.mapping == atom.mapping
def test_balance():
    """Check that balance() changes the class-0 count in the training set."""
    atom = ATOMClassifier(X_bin, y_bin, random_state=1)
    n_zeros_before = (atom.y_train == 0).sum()

    atom.balance()

    n_zeros_after = (atom.y_train == 0).sum()
    assert n_zeros_after != n_zeros_before