def test_has_same_base_dtype():
    """Check ``has_same_base_dtype`` against frames with matching and differing dtypes.

    The frames are built from int32/int64/float32/float64 arrays and one list
    of strings, so the assertions cover same-kind columns of different width,
    mixed-kind columns, differing column sets and string columns.
    """
    values = [1, 2, 3]
    int32_col = np.array(values, dtype=np.int32)
    int64_col = np.array(values, dtype=np.int64)
    float32_col = np.array(values, dtype=np.float32)
    float64_col = np.array(values, dtype=np.float64)
    str_col = ["a", "b", "c"]

    ints_narrow_wide = pd.DataFrame({"a": int32_col, "b": int64_col})
    ints_wide_narrow = pd.DataFrame({"a": int64_col, "b": int32_col})
    float_and_int = pd.DataFrame({"a": float32_col, "b": int32_col})
    floats_two_cols = pd.DataFrame({"a": float64_col, "b": float64_col})
    floats_three_cols = pd.DataFrame(
        {"a": float64_col, "b": float64_col, "c": float64_col}
    )
    strings_only = pd.DataFrame({"a": str_col})

    # int32 vs int64 in every column: expected to count as the same base dtype
    assert has_same_base_dtype(ints_narrow_wide, ints_wide_narrow)

    # restricted to "a": float32 vs float64 is expected to match
    assert has_same_base_dtype(float_and_int, floats_two_cols, columns=["a"])

    # over all columns, "b" is int32 vs float64, so the frames must not match
    assert not has_same_base_dtype(float_and_int, floats_two_cols)

    # "a" is int32 vs float32: different kinds, no match
    assert not has_same_base_dtype(ints_narrow_wide, float_and_int, columns=["a"])

    # the second frame has an extra column "c", so the column sets differ
    assert not has_same_base_dtype(floats_two_cols, floats_three_cols)

    # "a" is float64 vs strings: no match
    assert not has_same_base_dtype(floats_three_cols, strings_only, columns=["a"])

    # a string frame compared with itself matches
    assert has_same_base_dtype(strings_only, strings_only)
def test_has_same_base_dtype():
    """Run ``has_same_base_dtype`` over a table of frame pairs and expected outcomes.

    Each case is ``(expected, left, right, keyword_args)``; the cases are
    evaluated in order and the first unexpected outcome fails the test.
    """
    i32, i64, f32, f64 = (
        np.array([1, 2, 3], dtype=dtype)
        for dtype in (np.int32, np.int64, np.float32, np.float64)
    )
    text = ["a", "b", "c"]

    df_1 = pd.DataFrame({"a": i32, "b": i64})
    df_2 = pd.DataFrame({"a": i64, "b": i32})
    df_3 = pd.DataFrame({"a": f32, "b": i32})
    df_4 = pd.DataFrame({"a": f64, "b": f64})
    df_5 = pd.DataFrame({"a": f64, "b": f64, "c": f64})
    df_6 = pd.DataFrame({"a": text})

    cases = [
        # all columns match (integer columns of different width)
        (True, df_1, df_2, {}),
        # the selected column matches (float32 vs float64)
        (True, df_3, df_4, {"columns": ["a"]}),
        # column "b" differs (int32 vs float64)
        (False, df_3, df_4, {}),
        # the selected column differs (int32 vs float32)
        (False, df_1, df_3, {"columns": ["a"]}),
        # the frames do not have the same set of columns
        (False, df_4, df_5, {}),
        # the selected column differs (float64 vs strings)
        (False, df_5, df_6, {"columns": ["a"]}),
        # string columns match when compared with themselves
        (True, df_6, df_6, {}),
    ]

    for expected, left, right, options in cases:
        outcome = has_same_base_dtype(left, right, **options)
        if expected:
            assert outcome
        else:
            assert not outcome
    def check_column_dtypes_wrapper(
        rating_true,
        rating_pred,
        col_user=DEFAULT_USER_COL,
        col_item=DEFAULT_ITEM_COL,
        col_rating=DEFAULT_RATING_COL,
        col_prediction=DEFAULT_PREDICTION_COL,
        *args,
        **kwargs
    ):
        """Validate the two rating DataFrames, then call ``func`` with them.

        ``func`` is taken from the enclosing scope (not visible in this block).

        Args:
            rating_true (pd.DataFrame): True data
            rating_pred (pd.DataFrame): Predicted data
            col_user (str): column name for user
            col_item (str): column name for item
            col_rating (str): column name for rating
            col_prediction (str): column name for prediction
            *args: extra positional arguments forwarded to ``func``
            **kwargs: extra keyword arguments forwarded to ``func``

        Returns:
            Whatever ``func`` returns.

        Raises:
            ValueError: if ``rating_true`` lacks the user, item or rating
                column, if ``rating_pred`` lacks the user, item or prediction
                column, or if ``has_same_base_dtype`` reports a mismatch on
                the user and item columns.
        """
        true_columns = [col_user, col_item, col_rating]
        pred_columns = [col_user, col_item, col_prediction]
        shared_columns = [col_user, col_item]

        # Presence checks run first, in this order, so the error names the
        # first frame that is missing a column.
        if not has_columns(rating_true, true_columns):
            raise ValueError("Missing columns in true rating DataFrame")

        if not has_columns(rating_pred, pred_columns):
            raise ValueError("Missing columns in predicted rating DataFrame")

        # Only the user and item columns are compared for dtype.
        dtypes_agree = has_same_base_dtype(
            rating_true, rating_pred, columns=shared_columns
        )
        if not dtypes_agree:
            raise ValueError("Columns in provided DataFrames are not the same datatype")

        # The validated inputs are passed by keyword; *args is written first
        # but binds exactly as it does when written after the keywords.
        return func(
            *args,
            rating_true=rating_true,
            rating_pred=rating_pred,
            col_user=col_user,
            col_item=col_item,
            col_rating=col_rating,
            col_prediction=col_prediction,
            **kwargs
        )
Esempio n. 4
0
    def check_column_dtypes_wrapper(
        rating_true,
        rating_pred,
        col_user=DEFAULT_USER_COL,
        col_item=DEFAULT_ITEM_COL,
        col_rating=DEFAULT_RATING_COL,
        col_prediction=DEFAULT_PREDICTION_COL,
        *args,
        **kwargs
    ):
        """Check columns of the DataFrame inputs before delegating to ``func``.

        ``func`` comes from the enclosing scope, which is outside this block.

        Args:
            rating_true (pd.DataFrame): True data
            rating_pred (pd.DataFrame): Predicted data
            col_user (str): column name for user
            col_item (str): column name for item
            col_rating (str): column name for rating
            col_prediction (str): column name for prediction
            *args: additional positional arguments passed through to ``func``
            **kwargs: additional keyword arguments passed through to ``func``

        Returns:
            The return value of ``func``.

        Raises:
            ValueError: when either frame is missing one of its required
                columns, or when the user/item columns of the two frames fail
                the ``has_same_base_dtype`` check.
        """
        # (frame, columns it must contain, message raised when one is missing);
        # evaluated in order so the true-rating frame is reported first.
        presence_checks = (
            (
                rating_true,
                [col_user, col_item, col_rating],
                "Missing columns in true rating DataFrame",
            ),
            (
                rating_pred,
                [col_user, col_item, col_prediction],
                "Missing columns in predicted rating DataFrame",
            ),
        )
        for frame, required, message in presence_checks:
            if not has_columns(frame, required):
                raise ValueError(message)

        # The dtype comparison is limited to the user and item columns.
        if not has_same_base_dtype(
            rating_true, rating_pred, columns=[col_user, col_item]
        ):
            raise ValueError("Columns in provided DataFrames are not the same datatype")

        # Named arguments go first and the caller's extra keywords after them,
        # matching the original keyword order. The two sets cannot overlap
        # because these names are explicit parameters of this wrapper.
        forwarded = {
            "rating_true": rating_true,
            "rating_pred": rating_pred,
            "col_user": col_user,
            "col_item": col_item,
            "col_rating": col_rating,
            "col_prediction": col_prediction,
        }
        forwarded.update(kwargs)
        return func(*args, **forwarded)