@pytest.mark.parametrize("disable", [False, True])
def test_enabling_autologging_before_spark_session_works(disable):
    mlflow.spark.autolog(disable=disable)

    # creating spark session AFTER autolog was enabled
    spark_session = _get_or_create_spark_session()

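    # Write a small DataFrame out as CSV so there is a datasource to read back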
    rows = [Row(100)]
    schema = StructType([StructField("number2", IntegerType())])
    rdd = spark_session.sparkContext.parallelize(rows)
    df = spark_session.createDataFrame(rdd, schema)
    tempdir = tempfile.mkdtemp()
    filepath = os.path.join(tempdir, "test-data")
    df.write.option("header", "true").format("csv").save(filepath)

    read_df = (
        spark_session.read.format("csv")
        .option("header", "true")
        .option("inferSchema", "true")
        .load(filepath)
    )

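    # Reading the DataFrame inside an active run should let the autologging
    # listener attribute the datasource read to that run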
    with mlflow.start_run():
        run_id = mlflow.active_run().info.run_id
        read_df.collect()
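        # Datasource events are emitted asynchronously by the Spark listener,
        # so give it a moment before ending the run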
        time.sleep(1)

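    # The run has ended; check whether the datasource read was logged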
    run = mlflow.get_run(run_id)
    if disable:
        _assert_spark_data_not_logged(run=run)
    else:
        _assert_spark_data_logged(run=run, path=filepath, data_format="csv")

    shutil.rmtree(tempdir)
    spark_session.stop()


def test_enabling_autologging_throws_for_missing_jar():
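    # jars="" creates the session without the mlflow-spark JAR attached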
    spark_session = _get_or_create_spark_session(jars="")
    try:
        with pytest.raises(MlflowException) as exc:
            mlflow.spark.autolog()
        assert "Please ensure you have the mlflow-spark JAR attached" in exc.value.message
    finally:
        spark_session.stop()


@pytest.fixture
def spark_session():
    session = _get_or_create_spark_session()
    yield session
    session.stop()