def python_random_split(data, ratio=0.75, seed=42):
    """Pandas random splitter.

    The splitter randomly splits the input data.

    Args:
        data (pandas.DataFrame): Pandas DataFrame to be split.
        ratio (float or list): Ratio for splitting data. If it is a single float number
            it splits data into two sets and the ratio argument indicates the ratio of
            training data set; if it is a list of float numbers, the splitter splits
            data into several portions corresponding to the split ratios. If a list is
            provided and the ratios do not sum to 1, they will be normalized.
        seed (int): Seed.

    Returns:
        list: Splits of the input data as pandas.DataFrame.
    """
    multi_split, ratio = process_split_ratio(ratio)

    if multi_split:
        splits = split_pandas_data_with_ratios(data, ratio, shuffle=True, seed=seed)
        splits_new = [x.drop("split_index", axis=1) for x in splits]

        return splits_new
    else:
        return sk_split(data, test_size=None, train_size=ratio, random_state=seed)
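# Illustrative usage sketch (not part of the module's API): shows how the single-float
# and list-of-floats forms of ``ratio`` behave. The toy DataFrame and column names
# below are hypothetical.
def _example_python_random_split():
    import pandas as pd

    df = pd.DataFrame({"userID": range(8), "itemID": range(8), "rating": [3.0] * 8})

    # A single float yields two splits with ~75% / ~25% of the rows.
    train, test = python_random_split(df, ratio=0.75, seed=42)

    # A list of floats yields one split per ratio; ratios that do not sum to 1
    # are normalized by process_split_ratio.
    splits = python_random_split(df, ratio=[0.6, 0.2, 0.2], seed=42)
    assert len(splits) == 3

    return train, test, splits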
def spark_random_split(data, ratio=0.75, seed=42):
    """Spark random splitter.

    Randomly split the data into several splits.

    Args:
        data (pyspark.sql.DataFrame): Spark DataFrame to be split.
        ratio (float or list): Ratio for splitting data. If it is a single float number
            it splits data into two sets and the ratio argument indicates the ratio of
            training data set; if it is a list of float numbers, the splitter splits
            data into several portions corresponding to the split ratios. If a list is
            provided and the ratios do not sum to 1, they will be normalized.
        seed (int): Seed.

    Returns:
        list: Splits of the input data as pyspark.sql.DataFrame.
    """
    multi_split, ratio = process_split_ratio(ratio)

    if multi_split:
        return data.randomSplit(ratio, seed=seed)
    else:
        return data.randomSplit([ratio, 1 - ratio], seed=seed)
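# Illustrative usage sketch for spark_random_split; assumes an active SparkSession is
# passed in by the caller. The example rows and column names are hypothetical.
def _example_spark_random_split(spark):
    df = spark.createDataFrame(
        [(1, 10, 4.0), (1, 11, 3.0), (2, 10, 5.0), (2, 12, 2.0)],
        ["userID", "itemID", "rating"],
    )

    # A single float is expanded to randomSplit([ratio, 1 - ratio], seed).
    train, test = spark_random_split(df, ratio=0.75, seed=42)

    # A list of ratios is normalized and passed to randomSplit directly.
    splits = spark_random_split(df, ratio=[0.7, 0.2, 0.1], seed=42)

    return train, test, splits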
def _do_stratification_spark(
    data,
    ratio=0.75,
    min_rating=1,
    filter_by="user",
    is_partitioned=True,
    is_random=True,
    seed=42,
    col_user=DEFAULT_USER_COL,
    col_item=DEFAULT_ITEM_COL,
    col_timestamp=DEFAULT_TIMESTAMP_COL,
):
    """Helper function to perform stratified splits.

    This function splits data in a stratified manner. That is, the same values for the
    filter_by column are retained in each split, but the corresponding set of entries
    are divided according to the ratio provided.

    Args:
        data (pyspark.sql.DataFrame): Spark DataFrame to be split.
        ratio (float or list): Ratio for splitting data. If it is a single float number
            it splits data into two sets and the ratio argument indicates the ratio of
            training data set; if it is a list of float numbers, the splitter splits
            data into several portions corresponding to the split ratios. If a list is
            provided and the ratios do not sum to 1, they will be normalized.
        min_rating (int): minimum number of ratings for user or item.
        filter_by (str): either "user" or "item", depending on which of the two is to
            filter with min_rating.
        is_partitioned (bool): flag to partition data by filter_by column.
        is_random (bool): flag to make split randomly or use timestamp column.
        seed (int): Seed.
        col_user (str): column name of user IDs.
        col_item (str): column name of item IDs.
        col_timestamp (str): column name of timestamps.

    Returns:
        list: Splits of the input data as pyspark.sql.DataFrame.
    """
    # A few preliminary checks.
    if filter_by not in ["user", "item"]:
        raise ValueError("filter_by should be either 'user' or 'item'.")

    if min_rating < 1:
        raise ValueError("min_rating should be integer and larger than or equal to 1.")

    if col_user not in data.columns:
        raise ValueError("Schema of data not valid. Missing User Col")

    if col_item not in data.columns:
        raise ValueError("Schema of data not valid. Missing Item Col")

    if not is_random:
        if col_timestamp not in data.columns:
            raise ValueError("Schema of data not valid. Missing Timestamp Col")

    if min_rating > 1:
        data = min_rating_filter_spark(
            data=data,
            min_rating=min_rating,
            filter_by=filter_by,
            col_user=col_user,
            col_item=col_item,
        )

    split_by = col_user if filter_by == "user" else col_item
    partition_by = split_by if is_partitioned else []
    order_by = F.rand(seed=seed) if is_random else F.col(col_timestamp)

    window_count = Window.partitionBy(partition_by)
    window_spec = Window.partitionBy(partition_by).orderBy(order_by)

    # Percentile rank of each row within its partition, ordered randomly or by timestamp.
    data = (
        data.withColumn("_count", F.count(split_by).over(window_count))
        .withColumn("_rank", F.row_number().over(window_spec) / F.col("_count"))
        .drop("_count")
    )

    multi_split, ratio = process_split_ratio(ratio)
    ratio = ratio if multi_split else [ratio, 1 - ratio]

    # Cut each partition at the cumulative ratios.
    splits = []
    prev_split = None
    for split in np.cumsum(ratio):
        condition = F.col("_rank") <= split
        if prev_split is not None:
            condition &= F.col("_rank") > prev_split
        splits.append(data.filter(condition).drop("_rank"))
        prev_split = split

    return splits
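# Illustrative sketch of the Spark stratification helper (an internal function):
# every user appears in every split, and with is_random=False the cut is made along
# the timestamp column. The SparkSession, rows, and column names are hypothetical.
def _example_do_stratification_spark(spark):
    df = spark.createDataFrame(
        [
            (1, 10, 4.0, 1), (1, 11, 3.0, 2), (1, 12, 5.0, 3), (1, 13, 2.0, 4),
            (2, 10, 5.0, 1), (2, 11, 4.0, 2), (2, 12, 3.0, 3), (2, 13, 1.0, 4),
        ],
        ["userID", "itemID", "rating", "timestamp"],
    )

    # Chronological stratified split: within each user, the earliest ~75% of
    # interactions go to train and the remainder to test.
    train, test = _do_stratification_spark(
        df,
        ratio=0.75,
        is_random=False,
        col_user="userID",
        col_item="itemID",
        col_timestamp="timestamp",
    )

    return train, test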
def _do_stratification(
    data,
    ratio=0.75,
    min_rating=1,
    filter_by="user",
    is_random=True,
    seed=42,
    col_user=DEFAULT_USER_COL,
    col_item=DEFAULT_ITEM_COL,
    col_timestamp=DEFAULT_TIMESTAMP_COL,
):
    """Helper function to perform stratified splits.

    This function splits data in a stratified manner. That is, the same values for the
    filter_by column are retained in each split, but the corresponding set of entries
    are divided according to the ratio provided.

    Args:
        data (pandas.DataFrame): Pandas DataFrame to be split.
        ratio (float or list): Ratio for splitting data. If it is a single float number
            it splits data into two sets and the ratio argument indicates the ratio of
            training data set; if it is a list of float numbers, the splitter splits
            data into several portions corresponding to the split ratios. If a list is
            provided and the ratios do not sum to 1, they will be normalized.
        min_rating (int): minimum number of ratings for user or item.
        filter_by (str): either "user" or "item", depending on which of the two is to
            filter with min_rating.
        is_random (bool): flag to make split randomly or use timestamp column.
        seed (int): Seed.
        col_user (str): column name of user IDs.
        col_item (str): column name of item IDs.
        col_timestamp (str): column name of timestamps.

    Returns:
        list: Splits of the input data as pandas.DataFrame.
    """
    # A few preliminary checks.
    if filter_by not in ["user", "item"]:
        raise ValueError("filter_by should be either 'user' or 'item'.")

    if min_rating < 1:
        raise ValueError("min_rating should be integer and larger than or equal to 1.")

    if col_user not in data.columns:
        raise ValueError("Schema of data not valid. Missing User Col")

    if col_item not in data.columns:
        raise ValueError("Schema of data not valid. Missing Item Col")

    if not is_random:
        if col_timestamp not in data.columns:
            raise ValueError("Schema of data not valid. Missing Timestamp Col")

    multi_split, ratio = process_split_ratio(ratio)

    split_by_column = col_user if filter_by == "user" else col_item

    ratio = ratio if multi_split else [ratio, 1 - ratio]

    if min_rating > 1:
        data = min_rating_filter_pandas(
            data,
            min_rating=min_rating,
            filter_by=filter_by,
            col_user=col_user,
            col_item=col_item,
        )

    # Split by each group and aggregate splits together.
    splits = []

    # For chronological splitting, sort each group by timestamp before cutting it;
    # otherwise group without sorting and shuffle within each group.
    df_grouped = (
        data.sort_values(col_timestamp).groupby(split_by_column)
        if is_random is False
        else data.groupby(split_by_column)
    )

    for _, group in df_grouped:
        group_splits = split_pandas_data_with_ratios(
            group, ratio, shuffle=is_random, seed=seed
        )

        # Concatenate the list of split dataframes.
        concat_group_splits = pd.concat(group_splits)

        splits.append(concat_group_splits)

    # Concatenate splits for all the groups together.
    splits_all = pd.concat(splits)

    # Take each split by its split_index.
    splits_list = [
        splits_all[splits_all["split_index"] == x].drop("split_index", axis=1)
        for x in range(len(ratio))
    ]

    return splits_list
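# Illustrative sketch of the pandas stratification helper (an internal function):
# mirrors the Spark example above, with a chronological split per user. The toy data
# and column names are hypothetical.
def _example_do_stratification_pandas():
    import pandas as pd

    df = pd.DataFrame(
        {
            "userID": [1, 1, 1, 1, 2, 2, 2, 2],
            "itemID": [10, 11, 12, 13, 10, 11, 12, 13],
            "rating": [4.0, 3.0, 5.0, 2.0, 5.0, 4.0, 3.0, 1.0],
            "timestamp": [1, 2, 3, 4, 1, 2, 3, 4],
        }
    )

    train, test = _do_stratification(
        df,
        ratio=0.75,
        is_random=False,
        col_user="userID",
        col_item="itemID",
        col_timestamp="timestamp",
    )

    # Both users contribute rows to both splits (stratification), with the earliest
    # interactions of each user ending up in train.
    assert set(train["userID"]) == set(test["userID"]) == {1, 2}

    return train, test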