コード例 #1
0
    def repartition(self, df: DataFrame,
                    partition_spec: PartitionSpec) -> DataFrame:
        """Repartition ``df`` as described by ``partition_spec``.

        The input is first converted with ``self.to_df``. The number of
        partitions comes from ``partition_spec.get_num_partitions``, which
        is handed a row-count callback under ``KEYWORD_ROWCOUNT``. The
        native dataframe is then repartitioned by the function matching
        ``partition_spec.algo`` (``"hash"``, ``"rand"`` or ``"even"``) and,
        when the spec yields sort keys, sorted within each partition.

        Args:
            df: The dataframe to repartition.
            partition_spec: Supplies the algorithm, the partition keys, the
                partition count and the optional in-partition sort order.

        Returns:
            The repartitioned data wrapped by ``self.to_df`` with the
            schema and metadata of the converted input.

        Raises:
            NotImplementedError: If ``partition_spec.algo`` is not one of
                the three handled values.
        """
        frame = self.to_df(df)

        def _cached_row_count() -> int:
            # Persist before counting; presumably this keeps the count from
            # forcing a second computation later -- TODO confirm.
            return self.persist(frame).count()

        num = partition_spec.get_num_partitions(
            **{KEYWORD_ROWCOUNT: _cached_row_count})

        algo = partition_spec.algo
        if algo == "hash":
            partitioner = hash_repartition
        elif algo == "rand":
            partitioner = rand_repartition
        elif algo == "even":
            # Only the "even" algorithm persists the frame up front.
            frame = self.persist(frame)
            partitioner = even_repartition
        else:  # pragma: no cover
            raise NotImplementedError(algo + " is not supported")
        sdf = partitioner(self.spark_session, frame.native, num,
                          partition_spec.partition_by)

        sorts = partition_spec.get_sorts(frame.schema)
        if len(sorts) > 0:
            sort_keys = list(sorts.keys())
            directions = list(sorts.values())
            sdf = sdf.sortWithinPartitions(*sort_keys, ascending=directions)
        return self.to_df(sdf, frame.schema, frame.metadata)
コード例 #2
0
def test_hash_repartition(spark_session):
    """Check ``hash_repartition`` on a 7-row frame with keys 0 (x4), 1 (x3).

    Every variant must keep all 7 rows. Index 2 of each row returned
    through ``_pc`` is presumably the size of the partition that row
    landed in -- confirm against ``_pc``.
    """
    rows = [[0, 1], [0, 2], [0, 3], [0, 4], [1, 1], [1, 2], [1, 3]]
    source = _df(rows, "a:int,b:int")

    def _tagged(num, keys):
        # Repartition, run ``_pc`` over each partition, and collect.
        repartitioned = hash_repartition(spark_session, source, num, keys)
        return repartitioned.rdd.mapPartitions(_pc).collect()

    def _count_with_size(tagged, size):
        # Number of collected rows whose third field equals ``size``.
        return len([row for row in tagged if row[2] == size])

    # num == 0, no keys: only the row count is checked.
    plain = hash_repartition(spark_session, source, 0, []).collect()
    assert len(plain) == 7

    # One partition without keys: all rows report a value of 7.
    tagged = _tagged(1, [])
    assert len(tagged) == 7
    assert _count_with_size(tagged, 7) == 7

    # One partition keyed on "a": same expectation as above.
    tagged = _tagged(1, ["a"])
    assert len(tagged) == 7
    assert _count_with_size(tagged, 7) == 7

    # num == 0 keyed on "a": only the row count is checked.
    tagged = _tagged(0, ["a"])
    assert len(tagged) == 7

    # Ten partitions keyed on "a": the two key groups report 4 and 3.
    tagged = _tagged(10, ["a"])
    assert len(tagged) == 7
    assert _count_with_size(tagged, 4) == 4
    assert _count_with_size(tagged, 3) == 3