def test_has_same_base_dtype():
    """Exercise ``has_same_base_dtype`` on integer, float and string column mixes."""
    # NOTE(review): an identical `test_has_same_base_dtype` follows this one in
    # the reviewed source. If both end up in the same module, Python keeps only
    # the later definition and pytest never runs this copy — confirm and drop
    # one of them.
    ints32 = np.array([1, 2, 3], dtype=np.int32)
    ints64 = np.array([1, 2, 3], dtype=np.int64)
    floats32 = np.array([1, 2, 3], dtype=np.float32)
    floats64 = np.array([1, 2, 3], dtype=np.float64)
    strings = ["a", "b", "c"]

    int_mix_a = pd.DataFrame({"a": ints32, "b": ints64})
    int_mix_b = pd.DataFrame({"a": ints64, "b": ints32})
    float_int = pd.DataFrame({"a": floats32, "b": ints32})
    floats_ab = pd.DataFrame({"a": floats64, "b": floats64})
    floats_abc = pd.DataFrame({"a": floats64, "b": floats64, "c": floats64})
    text_a = pd.DataFrame({"a": strings})

    # int32 vs int64 in every column: expected to count as the same base dtype.
    assert has_same_base_dtype(int_mix_a, int_mix_b)
    # Only column "a" is compared: float32 vs float64 is expected to match.
    assert has_same_base_dtype(float_int, floats_ab, columns=["a"])
    # All columns compared: "b" is int32 in one frame and float64 in the other.
    assert not has_same_base_dtype(float_int, floats_ab)
    # Column "a": int32 vs float32 is expected to mismatch.
    assert not has_same_base_dtype(int_mix_a, float_int, columns=["a"])
    # The frames do not share every column ("c" exists only in floats_abc).
    assert not has_same_base_dtype(floats_ab, floats_abc)
    # Column "a": float64 vs Python strings is expected to mismatch.
    assert not has_same_base_dtype(floats_abc, text_a, columns=["a"])
    # A frame with a string column compared against itself should match.
    assert has_same_base_dtype(text_a, text_a)
def test_has_same_base_dtype():
    """Check ``has_same_base_dtype`` across int/float width changes and strings."""
    # NOTE(review): an identical `test_has_same_base_dtype` precedes this one in
    # the reviewed source. If both live in one module, this later definition
    # replaces the earlier one — confirm and remove the duplicate.
    frames = {
        'i32_i64': pd.DataFrame({
            'a': np.array([1, 2, 3], dtype=np.int32),
            'b': np.array([1, 2, 3], dtype=np.int64),
        }),
        'i64_i32': pd.DataFrame({
            'a': np.array([1, 2, 3], dtype=np.int64),
            'b': np.array([1, 2, 3], dtype=np.int32),
        }),
        'f32_i32': pd.DataFrame({
            'a': np.array([1, 2, 3], dtype=np.float32),
            'b': np.array([1, 2, 3], dtype=np.int32),
        }),
        'f64_f64': pd.DataFrame({
            'a': np.array([1, 2, 3], dtype=np.float64),
            'b': np.array([1, 2, 3], dtype=np.float64),
        }),
        'f64_x3': pd.DataFrame({
            'a': np.array([1, 2, 3], dtype=np.float64),
            'b': np.array([1, 2, 3], dtype=np.float64),
            'c': np.array([1, 2, 3], dtype=np.float64),
        }),
        'str': pd.DataFrame({'a': ['a', 'b', 'c']}),
    }

    # Integer widths differ per column, but both frames hold only integers.
    assert has_same_base_dtype(frames['i32_i64'], frames['i64_i32'])
    # Restricted to column 'a': float32 vs float64.
    assert has_same_base_dtype(frames['f32_i32'], frames['f64_f64'], columns=['a'])
    # Unrestricted: column 'b' is int32 vs float64.
    assert not has_same_base_dtype(frames['f32_i32'], frames['f64_f64'])
    # Column 'a': int32 vs float32.
    assert not has_same_base_dtype(frames['i32_i64'], frames['f32_i32'], columns=['a'])
    # Column sets differ: 'c' is present in only one frame.
    assert not has_same_base_dtype(frames['f64_f64'], frames['f64_x3'])
    # Column 'a': float64 vs strings.
    assert not has_same_base_dtype(frames['f64_x3'], frames['str'], columns=['a'])
    # A string-column frame compared with itself.
    assert has_same_base_dtype(frames['str'], frames['str'])
def check_column_dtypes_wrapper(
    rating_true,
    rating_pred,
    col_user=DEFAULT_USER_COL,
    col_item=DEFAULT_ITEM_COL,
    col_rating=DEFAULT_RATING_COL,
    col_prediction=DEFAULT_PREDICTION_COL,
    *args,
    **kwargs
):
    """Check columns of DataFrame inputs, then call the wrapped ``func``.

    ``rating_true`` must contain the user, item and rating columns, and
    ``rating_pred`` must contain the user, item and prediction columns.
    ``has_same_base_dtype`` must also accept the user and item columns of the
    two frames.

    Args:
        rating_true (pd.DataFrame): True data
        rating_pred (pd.DataFrame): Predicted data
        col_user (str): column name for user
        col_item (str): column name for item
        col_rating (str): column name for rating
        col_prediction (str): column name for prediction
        *args: extra positional arguments, forwarded to ``func`` after the six
            named arguments above.
        **kwargs: extra keyword arguments, forwarded to ``func`` unchanged.

    Returns:
        Whatever ``func`` returns.

    Raises:
        ValueError: if a required column is missing from either DataFrame, or
            if ``has_same_base_dtype`` rejects the user/item columns.
    """
    if not has_columns(rating_true, [col_user, col_item, col_rating]):
        raise ValueError("Missing columns in true rating DataFrame")
    if not has_columns(rating_pred, [col_user, col_item, col_prediction]):
        raise ValueError("Missing columns in predicted rating DataFrame")
    if not has_same_base_dtype(
        rating_true, rating_pred, columns=[col_user, col_item]
    ):
        raise ValueError("Columns in provided DataFrames are not the same datatype")

    if args:
        # Unpacking *args into a call that also names the leading parameters as
        # keywords binds args[0] to func's first positional parameter. That is
        # presumably `rating_true`, already given by keyword, so Python raises
        # TypeError ("got multiple values for argument"). Forward everything
        # positionally instead. This assumes func's leading parameters follow
        # this wrapper's order — TODO confirm for each decorated function.
        return func(
            rating_true,
            rating_pred,
            col_user,
            col_item,
            col_rating,
            col_prediction,
            *args,
            **kwargs
        )

    # No extra positional arguments: keep the original keyword-based call so
    # func's parameter order does not matter.
    return func(
        rating_true=rating_true,
        rating_pred=rating_pred,
        col_user=col_user,
        col_item=col_item,
        col_rating=col_rating,
        col_prediction=col_prediction,
        **kwargs
    )