def test_transform_data_multiple_branches():
    """Check that a reloaded instance reproduces the data of every branch."""
    save_path = FILE_DIR + "atom_2"

    atom = ATOMClassifier(X_bin, y_bin, random_state=1)
    atom.prune()

    # Each assignment to `branch` switches the branch that the following
    # transformer calls are applied to.
    atom.branch = "branch_2"
    atom.balance()
    atom.feature_generation(strategy="dfs", n_features=5)

    atom.branch = "branch_3"
    atom.feature_selection(strategy="sfm", solver="lgb", n_features=20)

    # Saved without the data, so the loader is handed the raw data again.
    atom.save(save_path, save_data=False)
    atom2 = ATOMLoader(
        save_path,
        data=(X_bin, y_bin),
        transform_data=True,
    )

    for name in atom._branches:
        reloaded = atom2._branches[name].data
        original = atom._branches[name].data
        assert reloaded.equals(original)
def test_transform_data():
    """Check the loaded dataset's shape with and without transform_data."""
    save_path = FILE_DIR + "atom"

    atom = ATOMClassifier(X_bin, y_bin, random_state=1)
    atom.prune(columns=slice(3, 10))
    atom.apply(lambda x: x + 2, column="mean radius")
    atom.feature_generation(strategy="dfs", n_features=5)
    atom.feature_selection(strategy="sfm", solver="lgb", n_features=10)
    atom.save(save_path, save_data=False)

    # With transform_data=True the loaded dataset must match the shape of
    # the dataset held by the instance that was saved.
    atom2 = ATOMLoader(
        save_path,
        data=(X_bin, y_bin),
        transform_data=True,
    )
    assert atom2.dataset.shape == atom.dataset.shape

    # With transform_data=False the loaded dataset must match the shape of
    # the merged raw inputs.
    atom3 = ATOMLoader(
        save_path,
        data=(X_bin, y_bin),
        transform_data=False,
    )
    raw_shape = merge(X_bin, y_bin).shape
    assert atom3.dataset.shape == raw_shape
def test_feature_generation_attributes():
    """Check that feature_generation exposes its attributes on atom."""
    atom = ATOMClassifier(X_bin, y_bin, random_state=1)
    atom.feature_generation(
        "GFG",
        n_features=2,
        generations=5,
        population=200,
    )

    # Both attributes must be reachable on the atom instance afterwards.
    for attr_name in ("symbolic_transformer", "genetic_features"):
        assert hasattr(atom, attr_name)
def test_feature_generation():
    """Check that feature_generation adds the requested number of features."""
    atom = ATOMClassifier(X_bin, y_bin, random_state=1)
    atom.feature_generation(n_features=2, generations=5, population=200)

    # Two features were requested, so the feature set must be two columns
    # wider than the input data.
    n_columns_after = atom.X.shape[1]
    n_columns_expected = X_bin.shape[1] + 2
    assert n_columns_after == n_columns_expected