# Imports assumed from the spark-rapids integration-test harness; shared helpers
# such as _regexp_conf, _enable_all_types_conf, _sortmerge_join_conf,
# allow_negative_scale_of_decimal_conf, and create_df are defined elsewhere in
# that harness.
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error, assert_gpu_fallback_collect
from data_gen import *
from spark_session import with_cpu_session
from pyspark.sql.functions import broadcast
from pyspark.sql.types import *
from pyspark.sql.window import Window
import pyspark.sql.functions as f

# Patterns containing a NUL character are not supported on the GPU, so RLike must fall back.
def test_rlike_fallback_null_pattern():
    gen = mk_str_gen('[abcd]{1,3}')
    assert_gpu_fallback_collect(
        lambda spark: unary_op_df(spark, gen).selectExpr(
            'a rlike "a\u0000"'),
        'RLike',
        conf=_regexp_conf)
def test_date_formats_round_trip(spark_tmp_path, date_format, v1_enabled_list, ansi_enabled, time_parser_policy):
    gen = StructGen([('a', DateGen())], nullable=False)
    data_path = spark_tmp_path + '/CSV_DATA'
    schema = gen.data_type
    updated_conf = copy_and_update(_enable_all_types_conf, {
        'spark.sql.sources.useV1SourceList': v1_enabled_list,
        'spark.sql.ansi.enabled': ansi_enabled,
        'spark.sql.legacy.timeParserPolicy': time_parser_policy})
    with_cpu_session(
        lambda spark: gen_df(spark, gen).write \
            .option('dateFormat', date_format) \
            .csv(data_path))
    if time_parser_policy == 'LEGACY':
        # The LEGACY time parser policy is not supported on the GPU, so the scan
        # falls back; which scan exec appears depends on whether the v1 source list is in use.
        expected_class = 'FileSourceScanExec'
        if v1_enabled_list == '':
            expected_class = 'BatchScanExec'
        assert_gpu_fallback_collect(
            lambda spark: spark.read \
                .schema(schema) \
                .option('dateFormat', date_format) \
                .csv(data_path),
            expected_class,
            conf=updated_conf)
    else:
        assert_gpu_and_cpu_are_equal_collect(
            lambda spark: spark.read \
                .schema(schema) \
                .option('dateFormat', date_format) \
                .csv(data_path),
            conf=updated_conf)
def test_round_robin_sort_fallback(data_gen):
    from pyspark.sql.functions import lit
    assert_gpu_fallback_collect(
        # Add a computed column to avoid shuffle being optimized back to a CPU
        # shuffle like in test_repartition_df
        lambda spark: gen_df(spark, data_gen).withColumn('extra', lit(1)).repartition(13),
        'ShuffleExchangeExec')
# Possessive quantifiers such as "a*+" are not supported on the GPU, so RLike must fall back.
def test_rlike_fallback_possessive_quantifier():
    gen = mk_str_gen('(\u20ac|\\w){0,3}a[|b*.$\r\n]{0,2}c\\w{0,3}')
    assert_gpu_fallback_collect(
        lambda spark: unary_op_df(spark, gen).selectExpr(
            'a rlike "a*+"'),
        'RLike',
        conf=_regexp_conf)
# Empty groups such as "a()?" are not supported on the GPU, so RLike must fall back.
def test_rlike_fallback_empty_group():
    gen = mk_str_gen('[abcd]{1,3}')
    assert_gpu_fallback_collect(
        lambda spark: unary_op_df(spark, gen).selectExpr(
            'a rlike "a()?"'),
        'RLike',
        conf=_regexp_conf)
def test_single_nested_orderby_fallback_for_nullorder(data_gen, order):
    assert_gpu_fallback_collect(
        lambda spark: unary_op_df(spark, data_gen).orderBy(order),
        "SortExec",
        conf={
            **allow_negative_scale_of_decimal_conf,
        })
# With createMap disabled, map() built from multiple non-literal keys falls back to the CPU.
def test_map_expr_multi_non_literal_keys_fallback():
    data_gen = [('a', StringGen(nullable=False)), ('b', StringGen(nullable=False))]
    assert_gpu_fallback_collect(
        lambda spark: gen_df(spark, data_gen).selectExpr(
            'map(a, b, b, a) as m1'),
        'ProjectExec',
        conf={'spark.rapids.sql.createMap.enabled': False})
# Structs as join keys are not supported for sort-merge join on the GPU, so it falls back.
def test_sortmerge_join_struct_as_key_fallback(data_gen, join_type):
    def do_join(spark):
        left, right = create_df(spark, data_gen, 500, 500)
        return left.join(right, left.a == right.r_a, join_type)
    assert_gpu_fallback_collect(do_join, 'SortMergeJoinExec', conf=_sortmerge_join_conf)
# In explain-only mode the query still executes on the CPU, so the join reports as a fallback.
def test_explain_only_sortmerge_join(data_gen, join_type):
    def do_join(spark):
        left, right = create_df(spark, data_gen, 500, 500)
        return left.join(right, left.a == right.r_a, join_type)
    assert_gpu_fallback_collect(do_join, 'SortMergeJoinExec', conf=_explain_mode_conf)
# Partitioning a window by a nested type is not supported, so WindowExec falls back.
def test_nested_part_fallback(part_gen):
    data_gen = [('a', RepeatSeqGen(part_gen, length=20)),
                ('b', LongRangeGen()),
                ('c', int_gen)]
    window_spec = Window.partitionBy('a').orderBy('b').rowsBetween(-5, 5)
    def do_it(spark):
        return gen_df(spark, data_gen, length=2048) \
            .withColumn('rn', f.count('c').over(window_spec))
    assert_gpu_fallback_collect(do_it, 'WindowExec')
# A conditional broadcast nested loop join built on the right side falls back for these join types.
def test_broadcast_nested_loop_with_conditionals_build_right_fallback(data_gen, join_type):
    def do_join(spark):
        left, right = create_df(spark, data_gen, 50, 25)
        return left.join(broadcast(right), (left.b >= right.r_b), join_type)
    conf = allow_negative_scale_of_decimal_conf
    assert_gpu_fallback_collect(do_join, 'BroadcastNestedLoopJoinExec', conf=conf)
# Duplicate literal keys under the LAST_WIN dedup policy are not supported on
# the GPU, so the projection falls back.
def test_map_expr_dupe_keys_fallback():
    data_gen = [('a', StringGen(nullable=False)), ('b', StringGen(nullable=False))]
    assert_gpu_fallback_collect(
        lambda spark: gen_df(spark, data_gen).selectExpr(
            'map("key1", b, "key1", a) as m1'),
        'ProjectExec',
        conf={
            'spark.rapids.sql.createMap.enabled': True,
            'spark.sql.mapKeyDedupPolicy': 'LAST_WIN'
        })
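# A minimal CPU-only sketch (hypothetical helper, not part of the test suite)
# of the LAST_WIN semantics the test above exercises: when keys collide, the
# later value wins, so map("key1", b, "key1", a) keeps the value of a.
def example_last_win_dedup(spark):
    spark.conf.set('spark.sql.mapKeyDedupPolicy', 'LAST_WIN')
    # Duplicate key "key1": under LAST_WIN the second value "a" is kept.
    return spark.sql('SELECT map("key1", "b", "key1", "a") AS m1').collect()
    # -> [Row(m1={'key1': 'a'})]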
def test_broadcast_nested_loop_join_with_condition_fallback(data_gen, join_type):
    def do_join(spark):
        left, right = create_df(spark, data_gen, 50, 25)
        # AST does not support cast or logarithm yet
        return broadcast(left).join(right, left.a > f.log(right.r_a), join_type)
    conf = allow_negative_scale_of_decimal_conf
    assert_gpu_fallback_collect(do_join, 'BroadcastNestedLoopJoinExec', conf=conf)
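# For contrast, a hedged sketch (illustrative only, using the same harness
# helpers) of a condition the AST can compile: a plain column comparison with
# no cast or logarithm, which would not be expected to trigger the fallback above.
def example_broadcast_nested_loop_join_ast_condition(spark, data_gen, join_type):
    left, right = create_df(spark, data_gen, 50, 25)
    # A simple comparison contains no unsupported AST expressions.
    return broadcast(left).join(right, left.a > right.r_a, join_type)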
# Disabling the scan-related conf forces the Parquet read onto the CPU while
# the added projection can still be planned for the GPU.
def test_parquet_fallback(spark_tmp_path, read_func, disable_conf):
    data_gens = [string_gen, byte_gen, short_gen, int_gen, long_gen, boolean_gen]
    gen_list = [('_c' + str(i), gen) for i, gen in enumerate(data_gens)]
    gen = StructGen(gen_list, nullable=False)
    data_path = spark_tmp_path + '/PARQUET_DATA'
    reader = read_func(data_path)
    with_cpu_session(
        lambda spark: gen_df(spark, gen).write.parquet(data_path))
    assert_gpu_fallback_collect(
        lambda spark: reader(spark).select(f.col('*'), f.col('_c2') + f.col('_c3')),
        'FileSourceScanExec',
        conf={disable_conf: 'false'})
def test_json_read_valid_dates(std_input_path, filename, schema, read_func, ansi_enabled, time_parser_policy, spark_tmp_table_factory):
    updated_conf = copy_and_update(_enable_all_types_conf, {
        'spark.sql.ansi.enabled': ansi_enabled,
        'spark.sql.legacy.timeParserPolicy': time_parser_policy})
    f = read_func(std_input_path + '/' + filename, schema, spark_tmp_table_factory, {})
    if time_parser_policy == 'LEGACY' and ansi_enabled == 'true':
        # LEGACY parsing combined with ANSI mode is not supported on the GPU
        assert_gpu_fallback_collect(f, 'FileSourceScanExec', conf=updated_conf)
    else:
        assert_gpu_and_cpu_are_equal_collect(f, conf=updated_conf)
def test_csv_fallback(spark_tmp_path, read_func, disable_conf):
    data_gens = [
        StringGen('(\\w| |\t|\ud720){0,10}', nullable=False),
        byte_gen, short_gen, int_gen, long_gen, boolean_gen, date_gen
    ]
    gen_list = [('_c' + str(i), gen) for i, gen in enumerate(data_gens)]
    gen = StructGen(gen_list, nullable=False)
    data_path = spark_tmp_path + '/CSV_DATA'
    schema = gen.data_type
    reader = read_func(data_path, schema, False, ',')
    with_cpu_session(lambda spark: gen_df(spark, gen).write.csv(data_path))
    assert_gpu_fallback_collect(
        lambda spark: reader(spark).select(f.col('*'), f.col('_c2') + f.col('_c3')),
        'FileSourceScanExec',
        conf={disable_conf: 'false'})
def test_json_read_invalid_dates(std_input_path, filename, schema, read_func, ansi_enabled, time_parser_policy, spark_tmp_table_factory):
    updated_conf = copy_and_update(_enable_all_types_conf, {
        'spark.sql.ansi.enabled': ansi_enabled,
        'spark.sql.legacy.timeParserPolicy': time_parser_policy})
    f = read_func(std_input_path + '/' + filename, schema, spark_tmp_table_factory, {})
    if time_parser_policy == 'EXCEPTION':
        assert_gpu_and_cpu_error(df_fun=lambda spark: f(spark).collect(),
            conf=updated_conf,
            error_message='DateTimeException')
    elif time_parser_policy == 'LEGACY' and ansi_enabled == 'true':
        assert_gpu_fallback_collect(f, 'FileSourceScanExec', conf=updated_conf)
    else:
        assert_gpu_and_cpu_are_equal_collect(f, conf=updated_conf)
def test_csv_fallback(spark_tmp_path, read_func, disable_conf):
    data_gens = [
        StringGen('(\\w| |\t|\ud720){0,10}', nullable=False),
        byte_gen, short_gen, int_gen, long_gen, boolean_gen, date_gen
    ]
    gen_list = [('_c' + str(i), gen) for i, gen in enumerate(data_gens)]
    gen = StructGen(gen_list, nullable=False)
    data_path = spark_tmp_path + '/CSV_DATA'
    schema = gen.data_type
    updated_conf = _enable_all_types_conf.copy()
    updated_conf[disable_conf] = 'false'
    reader = read_func(data_path, schema)
    with_cpu_session(lambda spark: gen_df(spark, gen).write.csv(data_path))
    assert_gpu_fallback_collect(
        lambda spark: reader(spark).select(f.col('*'), f.col('_c2') + f.col('_c3')),
        # TODO add support for lists
        cpu_fallback_class_name=get_non_gpu_allowed()[0],
        conf=updated_conf)
def test_parquet_read_field_id(spark_tmp_path):
    data_path = spark_tmp_path + '/PARQUET_DATA'
    schema = StructType([
        StructField("c1", IntegerType(), metadata={'parquet.field.id': 1}),
    ])
    data = [(1,), (2,), (3,)]
    # write parquet with field IDs
    with_cpu_session(lambda spark: spark.createDataFrame(data, schema)
        .coalesce(1).write.mode("overwrite").parquet(data_path))
    readSchema = StructType([
        StructField("mapped_name_xxx", IntegerType(), metadata={'parquet.field.id': 1}),
    ])
    assert_gpu_fallback_collect(
        lambda spark: spark.read.schema(readSchema).parquet(data_path),
        'FileSourceScanExec',
        {"spark.sql.parquet.fieldId.read.enabled": "true"})  # default is false
# Structs as join keys are not supported for broadcast hash join on the GPU, so it falls back.
def test_broadcast_join_right_struct_as_key(data_gen, join_type):
    def do_join(spark):
        left, right = create_df(spark, data_gen, 500, 50)
        return left.join(broadcast(right), left.a == right.r_a, join_type)
    assert_gpu_fallback_collect(do_join, 'BroadcastHashJoinExec',
        conf=allow_negative_scale_of_decimal_conf)
def test_cast_string_date_fallback():
    assert_gpu_fallback_collect(
        # Cast back to String because this goes beyond what Python can support for years
        lambda spark: unary_op_df(spark, StringGen('([0-9]|-|\\+){4,12}')).select(
            f.col('a').cast(DateType()).cast(StringType())),
        'Cast')
# Sorting binary data is not supported on the GPU, so SortExec falls back.
def test_sort_binary_fallback(data_gen, order):
    assert_gpu_fallback_collect(
        lambda spark: unary_op_df(spark, data_gen).orderBy(order),
        "SortExec")
def test_single_nested_orderby_with_limit_fallback(data_gen, order):
    assert_gpu_fallback_collect(
        lambda spark: unary_op_df(spark, data_gen).orderBy(order).limit(100),
        "TakeOrderedAndProjectExec",
        conf={'spark.rapids.allowCpuRangePartitioning': False})
def test_single_nested_orderby_fallback_for_nullorder(data_gen, order):
    assert_gpu_fallback_collect(
        lambda spark: unary_op_df(spark, data_gen).orderBy(order),
        "SortExec")
def test_cast_string_timestamp_fallback():
    assert_gpu_fallback_collect(
        # Cast back to String because this goes beyond what Python can support for years
        lambda spark: unary_op_df(spark, StringGen('([0-9]|-|\\+){4,12}')).select(
            f.col('a').cast(TimestampType()).cast(StringType())),
        'Cast',
        conf={'spark.rapids.sql.castStringToTimestamp.enabled': 'true'})
# Mapping every key to 1 creates duplicate keys; the LAST_WIN dedup policy is
# not supported on the GPU, so TransformKeys falls back.
def test_transform_keys_last_win_fallback(data_gen):
    assert_gpu_fallback_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(
            'transform_keys(a, (key, value) -> 1)'),
        'TransformKeys',
        conf={'spark.sql.mapKeyDedupPolicy': 'LAST_WIN'})
def test_array_max_fallback(data_gen):
    assert_gpu_fallback_collect(
        lambda spark: unary_op_df(spark, ArrayGen(data_gen)).selectExpr(
            'array_max(a)'),
        "ArrayMax")
def test_date_format_f_incompat(data_gen, date_format):
    # note that we can't support it even with incompatibleDateFormats enabled
    conf = {"spark.rapids.sql.incompatibleDateFormats.enabled": "true"}
    assert_gpu_fallback_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(
            "date_format(a, '{}')".format(date_format)),
        'ProjectExec',
        conf)
def test_date_format_maybe(data_gen, date_format):
    assert_gpu_fallback_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(
            "date_format(a, '{}')".format(date_format)),
        'ProjectExec')
def test_cast_nested_fallback(data_gen, to_type):
    assert_gpu_fallback_collect(
        lambda spark: unary_op_df(spark, data_gen).select(
            f.col('a').cast(to_type)),
        'Cast')
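# The tests in this section take arguments such as data_gen, join_type, and
# order; those are assumed to come from pytest parametrization in the real
# harness. A hypothetical sketch of that wiring for one of them follows; the
# generator list and orderings are illustrative, not the repo's actual values.
import pytest

@pytest.mark.parametrize('order', [f.col('a').asc(), f.col('a').desc()], ids=idfn)
@pytest.mark.parametrize('data_gen', [binary_gen], ids=idfn)
def test_sort_binary_fallback_sketch(data_gen, order):
    assert_gpu_fallback_collect(
        lambda spark: unary_op_df(spark, data_gen).orderBy(order),
        "SortExec")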