def repartition(self, df: DataFrame, partition_spec: PartitionSpec) -> DataFrame:
    """Repartition ``df`` according to ``partition_spec``.

    The input is first normalized with ``self.to_df``. The target number of
    partitions is obtained from ``partition_spec.get_num_partitions``, which
    receives a callback (keyed by ``KEYWORD_ROWCOUNT``) that persists the
    dataframe and returns its row count when invoked. The native dataframe is
    then handed to the repartition routine selected by ``partition_spec.algo``
    and, if the spec yields any sort keys, sorted within each partition.

    :param df: the dataframe to repartition
    :param partition_spec: supplies the algorithm name (``algo``), the
        partition keys (``partition_by``), the partition count and the
        within-partition sort order
    :raises NotImplementedError: if ``partition_spec.algo`` is not one of
        ``"hash"``, ``"rand"`` or ``"even"``
    :return: the repartitioned dataframe, wrapped by ``self.to_df`` with the
        schema and metadata of the normalized input
    """
    frame = self.to_df(df)

    def _row_count() -> int:
        # Persist before counting; only runs if the partition spec
        # actually asks for the row count.
        return self.persist(frame).count()

    num = partition_spec.get_num_partitions(**{KEYWORD_ROWCOUNT: _row_count})

    algo = partition_spec.algo
    repartitioners = {
        "hash": hash_repartition,
        "rand": rand_repartition,
        "even": even_repartition,
    }
    if algo not in repartitioners:  # pragma: no cover
        raise NotImplementedError(algo + " is not supported")
    if algo == "even":
        # The even algorithm works on a persisted dataframe
        # (presumably because it scans the data more than once -- confirm).
        frame = self.persist(frame)
    sdf = repartitioners[algo](
        self.spark_session, frame.native, num, partition_spec.partition_by
    )

    sorts = partition_spec.get_sorts(frame.schema)
    if len(sorts) > 0:
        sort_columns = list(sorts.keys())
        sort_ascending = list(sorts.values())
        sdf = sdf.sortWithinPartitions(*sort_columns, ascending=sort_ascending)
    return self.to_df(sdf, frame.schema, frame.metadata)
def test_rand_repartition(spark_session):
    """Exercise ``rand_repartition`` over several partition counts and keys.

    For partition counts 0 through 3, with no partition key and with key
    ``"a"``, the repartitioned dataframe's RDD is mapped through ``_pc`` and
    collected; the collected result must always contain 7 elements.
    """
    rows = [[0, 1], [0, 2], [0, 3], [0, 4], [1, 1], [1, 2], [1, 3]]
    df = _df(rows, "a:int,b:int")
    for num_partitions in range(4):
        # Fresh key lists are built on every outer iteration.
        for partition_by in ([], ["a"]):
            repartitioned = rand_repartition(
                spark_session, df, num_partitions, partition_by
            )
            collected = repartitioned.rdd.mapPartitions(_pc).collect()
            assert len(collected) == 7