Example no. 1
def test_map_element_at_ansi_fail(data_gen):
    message = "org.apache.spark.SparkNoSuchElementException" if (not is_before_spark_330() or is_databricks104_or_later()) else "java.util.NoSuchElementException"
    # For 3.3.0+ strictIndexOperator should not affect element_at
    test_conf=copy_and_update(ansi_enabled_conf, {'spark.sql.ansi.strictIndexOperator': 'false'})
    assert_gpu_and_cpu_error(
            lambda spark: unary_op_df(spark, data_gen).selectExpr(
                'element_at(a, "NOT_FOUND")').collect(),
            conf=test_conf,
            error_message=message)
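
For context, these tests build their Spark configuration from a shared ansi_enabled_conf dictionary via a copy_and_update helper. A minimal sketch of what those are assumed to look like in this test suite (the real definitions live elsewhere in the integration-test framework):

# Hedged sketch, not the framework's actual definitions.
ansi_enabled_conf = {'spark.sql.ansi.enabled': 'true'}

def copy_and_update(conf, *more_confs):
    # Return a copy of the base conf dict with the extra entries applied,
    # leaving the original dict untouched.
    local_conf = conf.copy()
    for more in more_confs:
        local_conf.update(more)
    return local_conf
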
Example no. 2
def test_array_element_at_ansi_fail_invalid_index(index):
    message = "ArrayIndexOutOfBoundsException" if is_before_spark_330() else "SparkArrayIndexOutOfBoundsException"
    if isinstance(index, int):
        test_func = lambda spark: unary_op_df(spark, ArrayGen(int_gen)).select(
            element_at(col('a'), index)).collect()
    else:
        test_func = lambda spark: two_col_df(spark, ArrayGen(int_gen), index).selectExpr(
            'element_at(a, b)').collect()
    # For 3.3.0+ strictIndexOperator should not affect element_at
    test_conf=copy_and_update(ansi_enabled_conf, {'spark.sql.ansi.strictIndexOperator': 'false'})
    assert_gpu_and_cpu_error(
        test_func,
        conf=test_conf,
        error_message=message)
Example no. 3
@pytest.mark.parametrize('data_gen', vals, ids=idfn)
def test_timeadd(data_gen):
    days, seconds = data_gen
    assert_gpu_and_cpu_are_equal_collect(
        # We are starting at year 0005 to make sure we don't go before year 0001
        # or beyond year 10000 while doing TimeAdd
        lambda spark: unary_op_df(
            spark,
            TimestampGen(start=datetime(5, 1, 1, tzinfo=timezone.utc),
                         end=datetime(15, 1, 1, tzinfo=timezone.utc)),
            seed=1).selectExpr("a + (interval {} days {} seconds)".format(
                days, seconds)))
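
For one parametrized case, say days=30 and seconds=10 (illustrative values; the actual ones come from the vals parametrization), the format() call above produces the following SQL expression:

# What the selectExpr string evaluates to for days=30, seconds=10:
"a + (interval {} days {} seconds)".format(30, 10)
# -> 'a + (interval 30 days 10 seconds)'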


@pytest.mark.skipif(
    is_before_spark_330(),
    reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_timeadd_daytime_column():
    gen_list = [
        # timestamp column max year is 1000
        ('t', TimestampGen(end=datetime(1000, 1, 1, tzinfo=timezone.utc))),
        # max days is 8000 years, so the added result will not be out of range
        ('d',
         DayTimeIntervalGen(min_value=timedelta(days=0),
                            max_value=timedelta(days=8000 * 365)))
    ]
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: gen_df(spark, gen_list).selectExpr(
            "t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND"))

Example no. 4
    all_confs = reader_confs.copy()
    all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list,
        'spark.sql.caseSensitive': case_sensitive})
    # This is a hack to get the type in a slightly less verbose way
    extra_struct_gen = StructGen([('nested_col', orc_gen), ("nested_non_existing", orc_gen)])
    extra_gen_list = [("top_pri", orc_gen),
                      ("top_non_existing_mid", orc_gen),
                      ("TOP_AR", ArrayGen(extra_struct_gen, max_length=10)),
                      ("top_ST", extra_struct_gen),
                      ("top_non_existing_end", orc_gen)]
    rs = StructGen(extra_gen_list, nullable=False).data_type
    assert_gpu_and_cpu_are_equal_collect(
            lambda spark : spark.read.schema(rs).orc(data_path),
            conf=all_confs)

@pytest.mark.skipif(is_before_spark_330(), reason='Hidden file metadata columns are a new feature of Spark 330')
@allow_non_gpu(any = True)
@pytest.mark.parametrize('metadata_column', ["file_path", "file_name", "file_size", "file_modification_time"])
def test_orc_scan_with_hidden_metadata_fallback(spark_tmp_path, metadata_column):
    data_path = spark_tmp_path + "/hidden_metadata.orc"
    with_cpu_session(lambda spark : spark.range(10) \
                     .selectExpr("id", "id % 3 as p") \
                     .write \
                     .partitionBy("p") \
                     .mode("overwrite") \
                     .orc(data_path))

    def do_orc_scan(spark):
        df = spark.read.orc(data_path).selectExpr("id", "_metadata.{}".format(metadata_column))
        return df
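
    # NOTE: the original snippet is cut off here. A hedged sketch of the assertion that
    # presumably follows; the exist/non_exist class names below are assumptions:
    assert_cpu_and_gpu_are_equal_collect_with_capture(
        do_orc_scan,
        exist_classes="FileSourceScanExec",
        non_exist_classes="GpuFileSourceScanExec")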
Example no. 5
# Merged "test_nested_array_item" with this one since arrays as literals is supported
@pytest.mark.parametrize('data_gen', array_item_test_gens, ids=idfn)
def test_array_item(data_gen):
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: two_col_df(spark, data_gen, array_index_gen).selectExpr(
            'a[b]',
            'a[0]',
            'a[1]',
            'a[null]',
            'a[3]',
            'a[50]',
            'a[-1]'))


# No need to test this for multiple array data types. Only one is enough.
@pytest.mark.skipif(is_before_spark_330(), reason="'strictIndexOperator' is introduced in Spark 3.3.0")
@pytest.mark.parametrize('strict_index_enabled', [True, False])
@pytest.mark.parametrize('index', [-2, 100, array_neg_index_gen, array_out_index_gen], ids=idfn)
def test_array_item_with_strict_index(strict_index_enabled, index):
    message = "SparkArrayIndexOutOfBoundsException"
    if isinstance(index, int):
        test_df = lambda spark: unary_op_df(spark, ArrayGen(int_gen)).select(col('a')[index])
    else:
        test_df = lambda spark: two_col_df(spark, ArrayGen(int_gen), index).selectExpr('a[b]')

    test_conf=copy_and_update(
        ansi_enabled_conf, {'spark.sql.ansi.strictIndexOperator': strict_index_enabled})

    if strict_index_enabled:
        assert_gpu_and_cpu_error(
            lambda spark: test_df(spark).collect(),
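            # NOTE: the original snippet is cut off here; the remaining arguments and the
            # non-strict branch below are a hedged reconstruction, not verbatim source.
            conf=test_conf,
            error_message=message)
    else:
        assert_gpu_and_cpu_are_equal_collect(test_df, conf=test_conf)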
Example no. 6
@pytest.mark.parametrize('parquet_gens', parquet_write_gens_list, ids=idfn)
def test_write_empty_parquet_round_trip(spark_tmp_path, parquet_gens):
    def create_empty_df(spark, path):
        gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
        return gen_df(spark, gen_list, length=0).write.parquet(path)

    data_path = spark_tmp_path + '/PARQUET_DATA'
    assert_gpu_and_cpu_writes_are_equal_collect(
        create_empty_df,
        lambda spark, path: spark.read.parquet(path),
        data_path,
        conf=writer_confs)


# Should fall back when trying to write field ID metadata
@pytest.mark.skipif(is_before_spark_330(),
                    reason='Field ID is not supported before Spark 330')
@allow_non_gpu('DataWritingCommandExec')
def test_parquet_write_field_id(spark_tmp_path):
    data_path = spark_tmp_path + '/PARQUET_DATA'
    schema = StructType([
        StructField("c1", IntegerType(), metadata={'parquet.field.id': 1}),
    ])
    data = [
        (1, ),
        (2, ),
        (3, ),
    ]
    assert_gpu_fallback_write(
        lambda spark, path: spark.createDataFrame(data, schema).coalesce(
            1).write.mode("overwrite").parquet(path),
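        # NOTE: the original snippet is cut off here; the remaining arguments are a hedged
        # reconstruction (the fallback class matches the @allow_non_gpu marker above).
        lambda spark, path: spark.read.parquet(path),
        data_path,
        'DataWritingCommandExec')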
Example no. 7
         "spark.rapids.sql.castFloatToString.enabled"       : "true", 
         "spark.sql.legacy.castComplexTypesToString.enabled": legacy}
    )


# The bug SPARK-37451 only affects the following versions
def is_neg_dec_scale_bug_version():
    return ("3.1.1" <= spark_version() < "3.1.3") or ("3.2.0" <= spark_version() < "3.2.1")

@pytest.mark.skipif(is_neg_dec_scale_bug_version(), reason="RAPIDS doesn't support casting string to decimal for negative scale decimal in this version of Spark because of SPARK-37451")
def test_cast_string_to_negative_scale_decimal():
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, StringGen("[0-9]{9}")).select(
            f.col('a').cast(DecimalType(8, -3))))

@pytest.mark.skipif(is_before_spark_330(), reason="ansi cast throws exception only in 3.3.0+")
@pytest.mark.parametrize('type', [DoubleType(), FloatType()])
def test_cast_double_to_timestamp(type):
    def fun(spark):
        data=[float("inf"),float("-inf"),float("nan")]
        df = spark.createDataFrame(data, DoubleType())
        return df.select(f.col('value').cast(TimestampType())).collect()
    assert_gpu_and_cpu_error(fun, {"spark.sql.ansi.enabled": True}, "java.time.DateTimeException")

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_cast_day_time_interval_to_string():
    _assert_cast_to_string_equal(DayTimeIntervalGen(start_field='day', end_field='day', special_cases=[MIN_DAY_TIME_INTERVAL, MAX_DAY_TIME_INTERVAL, timedelta(seconds=0)]), {})
    _assert_cast_to_string_equal(DayTimeIntervalGen(start_field='day', end_field='hour', special_cases=[MIN_DAY_TIME_INTERVAL, MAX_DAY_TIME_INTERVAL, timedelta(seconds=0)]), {})
    _assert_cast_to_string_equal(DayTimeIntervalGen(start_field='day', end_field='minute', special_cases=[MIN_DAY_TIME_INTERVAL, MAX_DAY_TIME_INTERVAL, timedelta(seconds=0)]), {})
    _assert_cast_to_string_equal(DayTimeIntervalGen(start_field='day', end_field='second', special_cases=[MIN_DAY_TIME_INTERVAL, MAX_DAY_TIME_INTERVAL, timedelta(seconds=0)]), {})
    _assert_cast_to_string_equal(DayTimeIntervalGen(start_field='hour', end_field='hour', special_cases=[MIN_DAY_TIME_INTERVAL, MAX_DAY_TIME_INTERVAL, timedelta(seconds=0)]), {})
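
The _assert_cast_to_string_equal helper used above is defined elsewhere in the test module; a minimal sketch of its assumed shape (cast the generated column to a string and compare the CPU and GPU results under the given conf):

# Hedged sketch of the helper's assumed shape, not the verbatim definition.
def _assert_cast_to_string_equal(data_gen, conf):
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).select(f.col('a').cast("STRING")),
        conf)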