Example #1
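Enables mlflow.spark.autolog() and verifies that the datasource path and format are logged to each run, both for DataFrames read directly in every available format and for DataFrames derived from them via temp views, SQL queries, filter, select, and limit.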
def test_autologging_of_datasources_with_different_formats(
    spark_session, format_to_file_path
):
    mlflow.spark.autolog()
    for data_format, file_path in format_to_file_path.items():
        base_df = (
            spark_session.read.format(data_format)
            .option("header", "true")
            .option("inferSchema", "true")
            .load(file_path)
        )
        base_df.createOrReplaceTempView("temptable")
        table_df0 = spark_session.table("temptable")
        table_df1 = spark_session.sql("SELECT number1, number2 from temptable LIMIT 5")
        # The base DataFrame and every DataFrame derived from it should be traced
        # back to the same underlying datasource
        dfs = [
            base_df,
            table_df0,
            table_df1,
            base_df.filter("number1 > 0"),
            base_df.select("number1"),
            base_df.limit(2),
            base_df.filter("number1 > 0").select("number1").limit(2),
        ]

        for df in dfs:
            with mlflow.start_run():
                run_id = mlflow.active_run().info.run_id
                df.collect()
                # Give the asynchronous event-logging threads time to finish
                time.sleep(1)
            run = mlflow.get_run(run_id)
            _assert_spark_data_logged(run=run, path=file_path, data_format=data_format)
Example #2
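Verifies that several datasource reads performed within a single run are all recorded on that run, joined into one table-info tag with one row per datasource.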
def test_autologging_multiple_reads_same_run(spark_session, format_to_file_path):
    mlflow.spark.autolog()
    with mlflow.start_run():
        for data_format, file_path in format_to_file_path.items():
            run_id = mlflow.active_run().info.run_id
            df = spark_session.read.format(data_format).load(file_path)
            df.collect()
            # Give the asynchronous event-logging threads time to finish
            time.sleep(1)
        run = mlflow.get_run(run_id)
        # Every datasource read during the run should appear in a single tag,
        # one table-info row per datasource
        assert _SPARK_TABLE_INFO_TAG_NAME in run.data.tags
        table_info_tag = run.data.tags[_SPARK_TABLE_INFO_TAG_NAME]
        assert table_info_tag == "\n".join(
            [
                _get_expected_table_info_row(path, data_format)
                for data_format, path in format_to_file_path.items()
            ]
        )
Example #3
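Verifies that datasource information is attributed to the run that was active at read time even when the MLflow API calls that record it are slow; http_request is patched with a wrapper that delays POST requests before delegating to the real function.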
def test_autologging_slow_api_requests(spark_session, format_to_file_path):
    import mlflow.utils.rest_utils

    orig = mlflow.utils.rest_utils.http_request

    def _slow_api_req_mock(*args, **kwargs):
        # Delay POST requests (the calls that set run tags) to simulate a slow
        # tracking backend, then delegate to the real implementation
        if kwargs.get("method") == "POST":
            print("Sleeping, %s, %s" % (args, kwargs))
            time.sleep(1)
        return orig(*args, **kwargs)

    mlflow.spark.autolog()
    with mlflow.start_run():
        # Mock slow API requests to log Spark datasource information
        with mock.patch("mlflow.utils.rest_utils.http_request") as http_request_mock:
            http_request_mock.side_effect = _slow_api_req_mock
            run_id = mlflow.active_run().info.run_id
            for data_format, file_path in format_to_file_path.items():
                df = (
                    spark_session.read.format(data_format)
                    .option("header", "true")
                    .option("inferSchema", "true")
                    .load(file_path)
                )
                df.collect()
        # Sleep a bit prior to ending the run to guarantee that the Python process
        # can pick up on datasource read events (simulate the common case of doing
        # work, e.g. model training, on the DataFrame after reading from it)
        time.sleep(1)

    # Python subscriber threads should pick up the active run at the time they're
    # notified & make API requests against that run, even if those requests are slow.
    time.sleep(5)
    run = mlflow.get_run(run_id)
    assert _SPARK_TABLE_INFO_TAG_NAME in run.data.tags
    table_info_tag = run.data.tags[_SPARK_TABLE_INFO_TAG_NAME]
    assert table_info_tag == "\n".join(
        [
            _get_expected_table_info_row(path, data_format)
            for data_format, path in format_to_file_path.items()
        ]
    )
Example #4
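Verifies that no datasource information is logged after autologging is disabled with mlflow.spark.autolog(disable=True).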
def test_autologging_disabled_logging_datasource_with_different_formats(
    spark_session, format_to_file_path
):
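    # With autologging disabled, the reads below should not attach any
    # datasource tags to the runs they occur in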
    mlflow.spark.autolog(disable=True)
    for data_format, file_path in format_to_file_path.items():
        df = (
            spark_session.read.format(data_format)
            .option("header", "true")
            .option("inferSchema", "true")
            .load(file_path)
        )

        with mlflow.start_run():
            run_id = mlflow.active_run().info.run_id
            df.collect()
            time.sleep(1)
        run = mlflow.get_run(run_id)
        _assert_spark_data_not_logged(run=run)
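All four examples are pytest tests excerpted from a larger module: they rely on the spark_session and format_to_file_path fixtures plus helpers (_assert_spark_data_logged, _assert_spark_data_not_logged, _SPARK_TABLE_INFO_TAG_NAME, _get_expected_table_info_row) defined elsewhere in that module, and on module-level imports of time, mock, and mlflow. The sketch below shows one minimal, hypothetical way to supply the imports and fixtures; the column names number1/number2 match the queries above, while the chosen formats, row values, and paths are assumptions, not the test suite's actual setup.

import time
from unittest import mock  # used by Example #3

import mlflow
import mlflow.spark
import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope="module")
def spark_session():
    # Hypothetical stand-in: a local SparkSession. MLflow's real test suite also
    # attaches the mlflow-spark listener JAR to the session so that datasource
    # read events actually reach the autologger.
    session = (
        SparkSession.builder.master("local[*]")
        .appName("spark-autologging-examples")
        .getOrCreate()
    )
    yield session
    session.stop()


@pytest.fixture
def format_to_file_path(spark_session, tmp_path):
    # Hypothetical stand-in: write a small DataFrame out in several formats and
    # map each format name to the path it was written to.
    df = spark_session.createDataFrame(
        [(1, 10), (2, 20), (3, 30)], schema="number1 int, number2 int"
    )
    mapping = {}
    for data_format in ("csv", "parquet", "json"):
        path = str(tmp_path / "data.{}".format(data_format))
        df.write.format(data_format).option("header", "true").save(path)
        mapping[data_format] = path
    return mapping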