Example #1
0
def test_attrs_are_passed():
    """Check that a new branch inherits the attributes of its parent.

    The index must be a distinct object on the new branch, while the
    balancer estimator attribute must be the very same object as on the
    parent branch.
    """
    atom = ATOMClassifier(X_bin, y_bin, random_state=1)
    atom.balance()
    atom.branch = "branch_2"

    child = atom.branch_2
    parent = atom.master
    assert child.idx is not parent.idx
    assert child.adasyn is parent.adasyn
Example #2
0
def test_stacking_models_from_branch():
    """Check that stacking only uses the models of the active branch.

    Models trained on the initial branch ("LR", "LGB") must not end up
    in the stack built on the second branch.
    """
    atom = ATOMClassifier(X_bin, y_bin, random_state=1)
    atom.run(["LR", "LGB"])

    # Switch to a new branch and train a different set of models there.
    atom.branch = "branch_2"
    atom.balance()
    atom.run(["RF", "ET"])

    atom.stacking()
    expected = ["RF", "ET"]
    assert atom.stack.models == expected
Example #3
0
def test_transform_data_multiple_branches():
    """Check that reloading with ``transform_data=True`` rebuilds every branch.

    Transformations are applied across three branches and the instance is
    saved without its data (``save_data=False``). Reloading it with the
    raw data must reproduce each branch's dataset exactly.
    """
    atom = ATOMClassifier(X_bin, y_bin, random_state=1)
    atom.prune()

    atom.branch = "branch_2"
    atom.balance()
    atom.feature_generation(strategy="dfs", n_features=5)

    atom.branch = "branch_3"
    atom.feature_selection(strategy="sfm", solver="lgb", n_features=20)

    path = FILE_DIR + "atom_2"
    atom.save(path, save_data=False)

    loaded = ATOMLoader(path, data=(X_bin, y_bin), transform_data=True)
    for name in atom._branches:
        actual = loaded._branches[name].data
        expected = atom._branches[name].data
        assert actual.equals(expected)
Example #4
0
def test_balance_attribute():
    """Check that balancing exposes the chosen estimator as an ATOM attribute."""
    atom = ATOMClassifier(X_bin, y_bin, random_state=1)
    atom.balance(strategy="NearMiss")
    estimator = atom.nearmiss
    assert type(estimator).__name__ == "NearMiss"
Example #5
0
def test_balance_mapping():
    """Check that the balancer step receives ATOM's ``mapping`` attribute."""
    atom = ATOMClassifier(X_bin, y_bin, random_state=1)
    atom.balance()
    # The balancer is the first (and only) step added to the pipeline here.
    balancer = atom.pipeline[0]
    assert balancer.mapping == atom.mapping
Example #6
0
def test_balance():
    """Check that balancing changes the number of class-0 training samples."""
    atom = ATOMClassifier(X_bin, y_bin, random_state=1)
    n_class_0_before = (atom.y_train == 0).sum()
    atom.balance()
    n_class_0_after = (atom.y_train == 0).sum()
    assert n_class_0_after != n_class_0_before