Example #1
def test_write_hive_bucketed_table_fallback(spark_tmp_path,
                                            spark_tmp_table_factory,
                                            fileFormat):
    """
    fallback because GPU does not support Hive hash partition
    """
    src = spark_tmp_table_factory.get()
    table = spark_tmp_table_factory.get()

    def write_hive_table(spark):
        data = map(lambda i: (i % 13, str(i), i % 5), range(50))
        df = spark.createDataFrame(data, ["i", "j", "k"])
        df.write.mode("overwrite").partitionBy("k").bucketBy(
            8, "i", "j").format(fileFormat).saveAsTable(src)

        spark.sql("""
            create table if not exists {0} 
            using hive options(fileFormat \"{1}\")
            as select * from {2} 
            """.format(table, fileFormat, src))

    data_path = spark_tmp_path + '/HIVE_DATA'

    assert_gpu_fallback_write(
        lambda spark, _: write_hive_table(spark),
        lambda spark, _: spark.sql("SELECT * FROM {}".format(table)),
        data_path,
        'DataWritingCommandExec',
        conf={
            "hive.exec.dynamic.partition": "true",
            "hive.exec.dynamic.partition.mode": "nonstrict"
        })
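
These snippets are lifted from a pytest suite, so the module-level imports and the parametrize decorators that feed arguments such as fileFormat, codec, ts_write and ts_rebase are not shown. Below is a minimal sketch of that missing scaffolding; the helper modules (asserts, data_gen) follow the spark-rapids integration-test layout, and the parameter values are illustrative assumptions, not the suite's actual lists.

# Sketch of the setup the examples assume (not part of the snippets themselves).
import pytest
from datetime import datetime, timezone                             # Example #7
from pyspark.sql.types import StructType, StructField, IntegerType  # Example #8

from asserts import assert_gpu_fallback_write               # CPU-fallback checker
from data_gen import TimestampGen, IntegerGen, unary_op_df  # random-data helpers

# Parametrized arguments come from decorators along these lines
# (the values here are assumptions):
# @pytest.mark.parametrize('fileFormat', ['parquet', 'orc'])
# @pytest.mark.parametrize('codec', ['zlib', 'lzo'])
# @pytest.mark.parametrize('ts_write', ['INT96', 'TIMESTAMP_MICROS'])
# @pytest.mark.parametrize('ts_rebase', ['LEGACY'])
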
Example #2
def test_buckets_write_fallback(spark_tmp_path, spark_tmp_table_factory):
    data_path = spark_tmp_path + '/PARQUET_DATA'
    assert_gpu_fallback_write(
            lambda spark, path: spark.range(10e4).write.bucketBy(4, "id").sortBy("id").format('parquet').mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()),
            lambda spark, path: spark.read.parquet(path),
            data_path,
            'DataWritingCommandExec')
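
Every example hands the same things to assert_gpu_fallback_write: a write callback taking (spark, path), a read callback, a base directory, the name of the exec expected to fall back to the CPU, and optional extra configs. The helper lives in the spark-rapids test suite; the sketch below is only an assumed outline of its contract (write with the plugin off and on, compare what is read back), not the real implementation, and it omits the plan inspection that verifies DataWritingCommandExec actually fell back.

from pyspark.sql import SparkSession

def _session(rapids_enabled):
    # assumption: the plugin is toggled via spark.rapids.sql.enabled; stop any
    # active session first so the new config takes effect
    active = SparkSession.getActiveSession()
    if active is not None:
        active.stop()
    return (SparkSession.builder
            .config("spark.rapids.sql.enabled", str(rapids_enabled).lower())
            .getOrCreate())

def sketch_gpu_fallback_write(write_func, read_func, base_path):
    # write the same data twice: once on the CPU, once with the plugin enabled
    for enabled, sub in ((False, "/CPU"), (True, "/GPU")):
        write_func(_session(enabled), base_path + sub)
    # read both copies back on the CPU and compare row for row
    spark = _session(False)
    cpu_rows = sorted(read_func(spark, base_path + "/CPU").collect())
    gpu_rows = sorted(read_func(spark, base_path + "/GPU").collect())
    assert cpu_rows == gpu_rows
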
Example #3
def test_buckets_write_fallback(spark_tmp_path, spark_tmp_table_factory):
    data_path = spark_tmp_path + '/ORC_DATA'
    assert_gpu_fallback_write(
            lambda spark, path: spark.range(10e4).write.bucketBy(4, "id").sortBy("id").format('orc').mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()),
            lambda spark, path: spark.read.orc(path),
            data_path,
            'DataWritingCommandExec',
            conf={'spark.rapids.sql.format.orc.write.enabled': True})
Example #4
def test_csv_save_as_table_fallback(spark_tmp_path, spark_tmp_table_factory):
    gen = TimestampGen()
    data_path = spark_tmp_path + '/CSV_DATA'
    assert_gpu_fallback_write(
        lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.format(
            "csv").mode('overwrite').option("path", path).saveAsTable(
                spark_tmp_table_factory.get()),
        lambda spark, path: spark.read.csv(path), data_path,
        'DataWritingCommandExec')
Example #5
def test_parquet_writeLegacyFormat_fallback(spark_tmp_path, spark_tmp_table_factory):
    gen = IntegerGen()
    data_path = spark_tmp_path + '/PARQUET_DATA'
    all_confs = {'spark.sql.parquet.writeLegacyFormat': 'true'}
    assert_gpu_fallback_write(
            lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.format("parquet").mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()),
            lambda spark, path: spark.read.parquet(path),
            data_path,
            'DataWritingCommandExec',
            conf=all_confs)
Example #6
def test_orc_write_compression_fallback(spark_tmp_path, codec, spark_tmp_table_factory):
    gen = TimestampGen()
    data_path = spark_tmp_path + '/ORC_DATA'
    all_confs = {'spark.sql.orc.compression.codec': codec}
    assert_gpu_fallback_write(
            lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.format("orc").mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()),
            lambda spark, path: spark.read.orc(path),
            data_path,
            'DataWritingCommandExec',
            conf=all_confs)
Example #7
def test_parquet_write_legacy_fallback(spark_tmp_path, ts_write, ts_rebase, spark_tmp_table_factory):
    gen = TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))
    data_path = spark_tmp_path + '/PARQUET_DATA'
    all_confs = {'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
                 'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase,
                 'spark.sql.parquet.outputTimestampType': ts_write}
    assert_gpu_fallback_write(
            lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.format("parquet").mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()),
            lambda spark, path: spark.read.parquet(path),
            data_path,
            'DataWritingCommandExec',
            conf=all_confs)
Example #8
def test_parquet_write_field_id(spark_tmp_path):
    data_path = spark_tmp_path + '/PARQUET_DATA'
    schema = StructType([
        StructField("c1", IntegerType(), metadata={'parquet.field.id': 1}),
    ])
    data = [
        (1, ),
        (2, ),
        (3, ),
    ]
    assert_gpu_fallback_write(
        lambda spark, path: spark.createDataFrame(data, schema).coalesce(
            1).write.mode("overwrite").parquet(path),
        lambda spark, path: spark.read.parquet(path),
        data_path,
        'DataWritingCommandExec',
        conf={"spark.sql.parquet.fieldId.write.enabled":
              "true"})  # default is true