# Imports assumed from the spark-rapids integration-test framework and PySpark;
# adjust to match the surrounding test module if they are already present.
from datetime import datetime, timezone

from asserts import assert_gpu_fallback_write
from data_gen import *
from pyspark.sql.types import *


# 'fileFormat' is expected to be supplied via a @pytest.mark.parametrize decorator (omitted here).
def test_write_hive_bucketed_table_fallback(spark_tmp_path, spark_tmp_table_factory, fileFormat):
    """ fallback because GPU does not support Hive hash partition """
    src = spark_tmp_table_factory.get()
    table = spark_tmp_table_factory.get()

    def write_hive_table(spark):
        data = map(lambda i: (i % 13, str(i), i % 5), range(50))
        df = spark.createDataFrame(data, ["i", "j", "k"])
        df.write.mode("overwrite").partitionBy("k").bucketBy(
            8, "i", "j").format(fileFormat).saveAsTable(src)
        spark.sql("""
            create table if not exists {0}
            using hive options(fileFormat \"{1}\")
            as select * from {2}
            """.format(table, fileFormat, src))

    data_path = spark_tmp_path + '/HIVE_DATA'
    assert_gpu_fallback_write(
        lambda spark, _: write_hive_table(spark),
        lambda spark, _: spark.sql("SELECT * FROM {}".format(table)),
        data_path,
        'DataWritingCommandExec',
        conf={
            "hive.exec.dynamic.partition": "true",
            "hive.exec.dynamic.partition.mode": "nonstrict"
        })
def test_buckets_write_fallback(spark_tmp_path, spark_tmp_table_factory):
    data_path = spark_tmp_path + '/PARQUET_DATA'
    assert_gpu_fallback_write(
        lambda spark, path: spark.range(10e4).write.bucketBy(4, "id").sortBy("id").format('parquet').mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()),
        lambda spark, path: spark.read.parquet(path),
        data_path,
        'DataWritingCommandExec')
# ORC counterpart of the Parquet bucketed-write fallback test above; in the upstream repo
# the two live in separate test modules, so the identical function names do not collide.
def test_buckets_write_fallback(spark_tmp_path, spark_tmp_table_factory):
    data_path = spark_tmp_path + '/ORC_DATA'
    assert_gpu_fallback_write(
        lambda spark, path: spark.range(10e4).write.bucketBy(4, "id").sortBy("id").format('orc').mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()),
        lambda spark, path: spark.read.orc(path),
        data_path,
        'DataWritingCommandExec',
        conf={'spark.rapids.sql.format.orc.write.enabled': True})
def test_csv_save_as_table_fallback(spark_tmp_path, spark_tmp_table_factory):
    gen = TimestampGen()
    data_path = spark_tmp_path + '/CSV_DATA'
    assert_gpu_fallback_write(
        lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.format("csv").mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()),
        lambda spark, path: spark.read.csv(path),
        data_path,
        'DataWritingCommandExec')
def test_parquet_writeLegacyFormat_fallback(spark_tmp_path, spark_tmp_table_factory):
    gen = IntegerGen()
    data_path = spark_tmp_path + '/PARQUET_DATA'
    all_confs = {'spark.sql.parquet.writeLegacyFormat': 'true'}
    assert_gpu_fallback_write(
        lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.format("parquet").mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()),
        lambda spark, path: spark.read.parquet(path),
        data_path,
        'DataWritingCommandExec',
        conf=all_confs)
# 'codec' is expected to be supplied via a @pytest.mark.parametrize decorator (omitted here).
def test_orc_write_compression_fallback(spark_tmp_path, codec, spark_tmp_table_factory):
    gen = TimestampGen()
    data_path = spark_tmp_path + '/ORC_DATA'
    all_confs = {'spark.sql.orc.compression.codec': codec}
    assert_gpu_fallback_write(
        lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.format("orc").mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()),
        lambda spark, path: spark.read.orc(path),
        data_path,
        'DataWritingCommandExec',
        conf=all_confs)
# 'ts_write' and 'ts_rebase' are expected to be supplied via @pytest.mark.parametrize decorators (omitted here).
def test_parquet_write_legacy_fallback(spark_tmp_path, ts_write, ts_rebase, spark_tmp_table_factory):
    gen = TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))
    data_path = spark_tmp_path + '/PARQUET_DATA'
    all_confs = {'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
                 'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase,
                 'spark.sql.parquet.outputTimestampType': ts_write}
    assert_gpu_fallback_write(
        lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.format("parquet").mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()),
        lambda spark, path: spark.read.parquet(path),
        data_path,
        'DataWritingCommandExec',
        conf=all_confs)
def test_parquet_write_field_id(spark_tmp_path):
    data_path = spark_tmp_path + '/PARQUET_DATA'
    schema = StructType([
        StructField("c1", IntegerType(), metadata={'parquet.field.id': 1}),
    ])
    data = [
        (1,),
        (2,),
        (3,),
    ]
    assert_gpu_fallback_write(
        lambda spark, path: spark.createDataFrame(data, schema).coalesce(1).write.mode("overwrite").parquet(path),
        lambda spark, path: spark.read.parquet(path),
        data_path,
        'DataWritingCommandExec',
        conf={"spark.sql.parquet.fieldId.write.enabled": "true"})  # default is true