Example 1
    def repartition(self, df: DataFrame,
                    partition_spec: PartitionSpec) -> DataFrame:
        """Repartition *df* according to *partition_spec*.

        Supports the ``hash``, ``rand`` and ``even`` algorithms; the
        ``even`` algorithm requires the dataframe to be persisted first
        so the row distribution can be computed. Sorting within
        partitions is applied afterwards when the spec defines sorts.

        :param df: the input dataframe (converted via ``self.to_df``)
        :param partition_spec: how to partition (algo, num, by, sorts)
        :return: the repartitioned dataframe, same schema and metadata
        :raises NotImplementedError: for an unrecognized algorithm
        """
        df = self.to_df(df)
        # ROWCOUNT is resolved lazily: persist + count happens only if
        # the spec's partition-number expression actually needs it.
        num = partition_spec.get_num_partitions(
            **{KEYWORD_ROWCOUNT: lambda: self.persist(df).count()})

        algo = partition_spec.algo
        if algo == "even":
            # even repartition scans the data, so persist up front
            df = self.persist(df)
            sdf = even_repartition(self.spark_session, df.native, num,
                                   partition_spec.partition_by)
        elif algo == "hash":
            sdf = hash_repartition(self.spark_session, df.native, num,
                                   partition_spec.partition_by)
        elif algo == "rand":
            sdf = rand_repartition(self.spark_session, df.native, num,
                                   partition_spec.partition_by)
        else:  # pragma: no cover
            raise NotImplementedError(algo + " is not supported")

        sorts = partition_spec.get_sorts(df.schema)
        if sorts:
            sdf = sdf.sortWithinPartitions(
                *sorts.keys(), ascending=list(sorts.values()))
        return self.to_df(sdf, df.schema, df.metadata)
Example 2
def test_even_repartition_no_cols(spark_session):
    """Even repartition without key columns spreads rows across partitions."""
    df = _df([[0, 1], [0, 2], [0, 3], [0, 4], [1, 1], [1, 2], [1, 3]],
             "a:int,b:int")

    def _sizes(num):
        # each collected row carries its partition's row count in x[2]
        return (even_repartition(spark_session, df, num, [])
                .rdd.mapPartitions(_pc).collect())

    # num=0: no forced partition count, all rows preserved
    assert len(even_repartition(spark_session, df, 0, []).collect()) == 7

    # num=1: all 7 rows land in a single partition
    rows = _sizes(1)
    assert len(rows) == 7
    assert sum(1 for x in rows if x[2] == 7) == 7

    # num=6: 7 rows over 6 partitions -> five singletons, one pair
    rows = _sizes(6)
    assert len(rows) == 7
    assert sum(1 for x in rows if x[2] == 1) == 5
    assert sum(1 for x in rows if x[2] == 2) == 2

    # num >= row count: exactly one row per non-empty partition
    for num in (7, 8):
        rows = _sizes(num)
        assert len(rows) == 7
        assert all(x[2] == 1 for x in rows)
Example 3
def test_even_repartition_with_cols(spark_session):
    """Even repartition keyed by columns groups rows by key values."""
    df = _df([[0, 1], [0, 2], [0, 3], [0, 4], [1, 1], [1, 2], [1, 3]],
             "a:int,b:int")

    def _sizes(data, num, cols):
        # each collected row carries its partition's row count in x[2]
        return (even_repartition(spark_session, data, num, cols)
                .rdd.mapPartitions(_pc).collect())

    # key "a": two groups of sizes 4 and 3 (num=0 means one per key)
    rows = _sizes(df, 0, ["a"])
    assert len(rows) == 7
    assert sum(1 for x in rows if x[2] == 4) == 4
    assert sum(1 for x in rows if x[2] == 3) == 3

    # composite key ("a", "b"): every row is its own group
    rows = _sizes(df, 0, ["a", "b"])
    assert len(rows) == 7
    assert sum(1 for x in rows if x[2] == 1) == 7

    # num=1 collapses all key groups into one partition
    rows = _sizes(df, 1, ["a", "b"])
    assert len(rows) == 7
    assert sum(1 for x in rows if x[2] == 7) == 7

    # an explicit num no smaller than the key count keeps one group each
    rows = _sizes(df, 3, ["a"])
    assert len(rows) == 7
    assert sum(1 for x in rows if x[2] == 4) == 4
    assert sum(1 for x in rows if x[2] == 3) == 3

    # key "b": values 1,2,3 appear twice and 4 once
    rows = _sizes(df, 0, ["b"])
    assert len(rows) == 7
    assert sum(1 for x in rows if x[2] == 2) == 6
    assert sum(1 for x in rows if x[2] == 1) == 1

    # test with multiple keys and that are not the first positions
    df = _df(
        [
            [1, "a", 1],
            [2, "b", 2],
            [3, "c", 3],
            [4, "d", 4],
            [5, "e", 1],
            [6, "f", 2],
            [7, "g", 3],
        ],
        "z:int,a:str,b:int",
    )
    # every (b, z) combination is unique -> all partition sizes are 1
    rows = _sizes(df, 0, ["b", "z"])
    expected = [
        [1, "a", 1, 1],
        [2, "b", 2, 1],
        [3, "c", 3, 1],
        [4, "d", 4, 1],
        [5, "e", 1, 1],
        [6, "f", 2, 1],
        [7, "g", 3, 1],
    ]
    assert sorted(rows) == sorted(expected)