Ejemplo n.º 1
0
def test_stratified_fold_split():
    """Default stratified split yields 5 integer folds shared by both tags."""
    df = _setup_data()

    split_df = dataset.stratified_fold_split(dataframe=df,
                                             class_column="class")

    # Compare the dtype *kind* instead of ``int == dtype``: ``np.dtype(int)``
    # is platform dependent (int32 on Windows with NumPy < 2), so the old
    # equality check could fail for a perfectly valid int64 fold column.
    assert split_df["fold"].dtype.kind == "i"
    assert set(range(5)) == set(split_df["fold"].unique())
    ants_folds = set(split_df[split_df["tag"] == "ants"]["fold"])
    bees_folds = set(split_df[split_df["tag"] == "bees"]["fold"])
    # Both tags are expected to appear in exactly the same set of folds.
    assert ants_folds == bees_folds
Ejemplo n.º 2
0
def test_stratified_fold_split_num_folds():
    """Requesting ``n_folds=2`` must produce exactly folds 0 and 1."""
    split_df = dataset.stratified_fold_split(_setup_data(), "class",
                                             n_folds=2)

    observed_folds = set(split_df["fold"].unique())
    assert observed_folds == set(range(2))
Ejemplo n.º 3
0
def split_dataframe(
    dataframe: pd.DataFrame,
    train_folds: List[int],
    valid_folds: Optional[List[int]] = None,
    infer_folds: Optional[List[int]] = None,
    tag2class: Optional[Dict[str, int]] = None,
    tag_column: Optional[str] = None,
    class_column: Optional[str] = None,
    seed: int = 42,
    n_folds: int = 5
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Split a Pandas DataFrame into folds and select train/valid/infer parts.

    Args:
        dataframe (pd.DataFrame): input dataframe
        train_folds (List[int]): train folds
        valid_folds (List[int], optional): valid folds.
            If None, takes all folds not included in ``train_folds``
            (note: this may overlap with ``infer_folds``)
        infer_folds (List[int], optional): infer folds.
            If None or empty, the infer part is an empty dataframe
        tag2class (Dict[str, int], optional): mapping from label names into int;
            applied only together with ``tag_column`` and ``class_column``
        tag_column (str, optional): column with label names
        class_column (str, optional): column to use for a stratified split;
            if None, ``default_fold_split`` is used instead
        seed (int): seed for split
        n_folds (int): number of folds
    Returns:
        (tuple): tuple with 4 dataframes
            whole dataframe (with a ``fold`` column), train part,
            valid part and infer part
    """

    # Optionally map label names to integer classes before splitting.
    if args_are_not_none(tag2class, tag_column, class_column):
        dataframe = map_dataframe(dataframe, tag_column, class_column,
                                  tag2class)

    # Stratify by class when a class column is known;
    # otherwise use ``default_fold_split``.
    if class_column is not None:
        result_dataframe = stratified_fold_split(dataframe,
                                                 class_column=class_column,
                                                 random_state=seed,
                                                 n_folds=n_folds)
    else:
        result_dataframe = default_fold_split(dataframe,
                                              random_state=seed,
                                              n_folds=n_folds)

    fold_series = result_dataframe["fold"]

    train_folds = folds_to_list(train_folds)
    df_train = result_dataframe[fold_series.isin(train_folds)]

    if valid_folds is None:
        # Every fold that is not a train fold becomes a valid fold.
        valid_folds = fold_series[~fold_series.isin(train_folds)]

    valid_folds = folds_to_list(valid_folds)
    df_valid = result_dataframe[fold_series.isin(valid_folds)]

    # ``or []``: no infer folds requested -> empty infer part.
    infer_folds = folds_to_list(infer_folds or [])
    df_infer = result_dataframe[fold_series.isin(infer_folds)]

    return result_dataframe, df_train, df_valid, df_infer