def test_get_columns_remove_duplicates():
    """Assert that repeated column indices yield each column only once.

    Index 0 appears twice in the request; the result should list it a
    single time, in the order it was first requested.
    """
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    selected = trainer._get_columns([0, 1, 0])
    expected = ["mean radius", "mean texture"]
    assert selected == expected
def test_get_columns_by_name():
    """Assert that columns are looked up by name and unknown names fail.

    An unknown column name must raise ValueError; a known name must come
    back wrapped in a single-element list.
    """
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)

    # An unknown column name is rejected.
    with pytest.raises(ValueError):
        trainer._get_columns("invalid")

    # A known column name is returned as a one-element list.
    found = trainer._get_columns("mean radius")
    assert found == ["mean radius"]
def test_get_columns_slice():
    """Assert that a slice selects the matching range of columns.

    slice(2, 6) spans positions 2, 3, 4 and 5, i.e. four columns.
    """
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)
    sliced = trainer._get_columns(slice(2, 6))
    assert len(sliced) == 4
def test_get_columns_by_index():
    """Assert that columns are looked up by position and bad indices fail.

    Index 40 must raise ValueError for this dataset; index 0 must map to
    the first column's name.
    """
    trainer = ATOMClassifier(X_bin, y_bin, random_state=1)

    # A position past the available columns is rejected.
    with pytest.raises(ValueError):
        trainer._get_columns(40)

    # Position 0 resolves to the first column's name.
    first = trainer._get_columns(0)
    assert first == ["mean radius"]
def test_get_columns_is_None():
    """Assert that all or only numerical columns are returned.

    With ``columns=None`` the helper should return every column when
    ``only_numerical=False`` (4 on this dataset) and only the numerical
    subset when ``only_numerical=True`` (3 on this dataset).
    """
    atom = ATOMClassifier(X10_str, y10, random_state=1)
    # The original also made a bare call whose result was discarded; it
    # exactly duplicated the asserted call below, so it has been removed.
    numerical = atom._get_columns(columns=None, only_numerical=True)
    assert len(numerical) == 3
    everything = atom._get_columns(columns=None, only_numerical=False)
    assert len(everything) == 4