Example #1
def init_spark_session(conf=None, app_name="rikai", rikai_version=None):
    # get_default_jar_version() and init() are rikai helpers, assumed to be
    # imported elsewhere in the module this example was taken from.
    from pyspark.sql import SparkSession

    if not rikai_version:
        rikai_version = get_default_jar_version(use_snapshot=True)
    builder = (
        SparkSession.builder.appName(app_name)
        # Resolve the Rikai Scala 2.12 artifact via its Maven coordinates.
        .config(
            "spark.jars.packages", "ai.eto:rikai_2.12:{}".format(rikai_version)
        )
        # Register Rikai's Spark SQL extensions (e.g. the CREATE MODEL syntax
        # used in Example #3).
        .config(
            "spark.sql.extensions",
            "ai.eto.rikai.sql.spark.RikaiSparkSessionExtensions",
        )
        # Allow Netty reflective access on JDK 9+, commonly needed for
        # Arrow-backed data exchange in Spark.
        .config(
            "spark.driver.extraJavaOptions",
            "-Dio.netty.tryReflectionSetAccessible=true",
        )
        .config(
            "spark.executor.extraJavaOptions",
            "-Dio.netty.tryReflectionSetAccessible=true",
        )
    )
    # Layer any caller-supplied Spark settings on top of the defaults.
    conf = conf or {}
    for k, v in conf.items():
        builder = builder.config(k, v)
    session = builder.master("local[2]").getOrCreate()
    init(session)
    return session
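
A minimal usage sketch for the helper above; the settings passed through conf here are illustrative assumptions, not options required by rikai:

spark = init_spark_session(
    conf={
        "spark.ui.port": "4050",  # hypothetical override
        "spark.sql.warehouse.dir": "/tmp/rikai-warehouse",
    },
    app_name="rikai-demo",
)
spark.sql("SELECT 1 AS x").show()
spark.stop()
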
Example #2
def spark(tmp_path_factory) -> SparkSession:
    # pytest fixture (the decorator is not shown in this excerpt);
    # tmp_path_factory is pytest's built-in temporary-directory factory.
    version = get_default_jar_version(use_snapshot=True)
    session = (
        SparkSession.builder.appName("spark-test")
        .config("spark.jars.packages", f"ai.eto:rikai_2.12:{version}")
        .config(
            "spark.sql.extensions",
            "ai.eto.rikai.sql.spark.RikaiSparkSessionExtensions",
        )
        # Map model-registry schemes (test, file, mlflow) to implementations;
        # Example #3 resolves a 'test://...' URI through TestRegistry.
        .config(
            "rikai.sql.ml.registry.test.impl",
            "ai.eto.rikai.sql.model.testing.TestRegistry",
        )
        .config(
            "rikai.sql.ml.registry.file.impl",
            "ai.eto.rikai.sql.model.fs.FileSystemRegistry",
        )
        .config(
            "rikai.sql.ml.registry.mlflow.impl",
            "ai.eto.rikai.sql.model.mlflow.MlflowRegistry",
        )
        .config(
            "spark.driver.extraJavaOptions",
            "-Dio.netty.tryReflectionSetAccessible=true",
        )
        .config(
            "spark.executor.extraJavaOptions",
            "-Dio.netty.tryReflectionSetAccessible=true",
        )
        .master("local[2]")
        .getOrCreate()
    )
    init(session)
    return session
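
As noted above, this function is a pytest fixture. A minimal sketch of how such a fixture is declared and consumed, assuming it lives in a conftest.py and reuses the helper from Example #1; the decorator scope is an assumption as well:

import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope="module")
def spark() -> SparkSession:
    # Build the rikai-enabled session with the helper from Example #1.
    return init_spark_session(app_name="spark-test")

def test_smoke(spark: SparkSession):
    # Any test that declares a `spark` parameter receives the shared session.
    assert spark.sql("SELECT 1 AS x").count() == 1
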
Example #3
def test_model_codegen_registered(spark: SparkSession):
    # The boolean passed to init() appears to toggle dynamic vs. static codegen
    # registration (compare the foo_dynamic / foo_static model names below).
    init(spark, True)

    spark.sql(
        """CREATE MODEL foo_dynamic OPTIONS (foo="str",bar=True,max_score=1.23)
         USING 'test://model/a/b/c'"""
    ).count()  # .count() forces the CREATE MODEL statement to execute

    init(spark, False)

    spark.sql(
        """CREATE MODEL foo_static OPTIONS (foo="str",bar=True,max_score=1.23)
         USING 'test://model/a/b/c'"""
    ).count()
Example #4
def spark() -> SparkSession:
    session = (
        SparkSession.builder.appName("spark-test")
        .config("spark.jars.packages", "ai.eto:rikai_2.12:0.0.3-SNAPSHOT")
        .config(
            "spark.sql.extensions",
            "ai.eto.rikai.sql.spark.RikaiSparkSessionExtensions",
        )
        .config(
            "rikai.sql.ml.registry.test.impl",
            "ai.eto.rikai.sql.model.testing.TestRegistry",
        )
        .config(
            "rikai.sql.ml.registry.file.impl",
            "ai.eto.rikai.sql.model.fs.FileSystemRegistry",
        )
        .config(
            "spark.driver.extraJavaOptions",
            "-Dio.netty.tryReflectionSetAccessible=true",
        )
        .config(
            "spark.executor.extraJavaOptions",
            "-Dio.netty.tryReflectionSetAccessible=true",
        )
        .master("local[2]")
        .getOrCreate()
    )
    init(session)
    return session