def test_autologging_of_datasources_with_different_formats(spark_session, format_to_file_path):
    """Datasource info is logged for reads in every supported format.

    For each (format, path) fixture entry, verifies that the source path and
    format are recorded not only for the directly-loaded DataFrame but also
    for frames derived from it: temp-table lookups, SQL query results, and
    filter/select/limit projections.
    """
    mlflow.spark.autolog()
    for data_format, file_path in format_to_file_path.items():
        base_df = (
            spark_session.read.format(data_format)
            .option("header", "true")
            .option("inferSchema", "true")
            .load(file_path)
        )
        base_df.createOrReplaceTempView("temptable")
        table_df0 = spark_session.table("temptable")
        table_df1 = spark_session.sql("SELECT number1, number2 from temptable LIMIT 5")
        # Every one of these frames ultimately reads from `file_path`, so each
        # should trigger the same datasource-logging behavior.
        derived_frames = [
            base_df,
            table_df0,
            table_df1,
            base_df.filter("number1 > 0"),
            base_df.select("number1"),
            base_df.limit(2),
            base_df.filter("number1 > 0").select("number1").limit(2),
        ]
        for frame in derived_frames:
            with mlflow.start_run():
                run_id = mlflow.active_run().info.run_id
                frame.collect()
                # Give the async datasource-event subscriber time to fire
                # before the run is closed.
                time.sleep(1)
            run = mlflow.get_run(run_id)
            _assert_spark_data_logged(run=run, path=file_path, data_format=data_format)
def test_autologging_multiple_reads_same_run(spark_session, format_to_file_path):
    """Reads of several datasources within one run accumulate into a single tag.

    Verifies that the Spark table-info tag on the run contains one line per
    (path, format) pair read while the run was active, in fixture order.
    """
    mlflow.spark.autolog()
    with mlflow.start_run():
        # The run is opened once, so its id is loop-invariant: fetch it a
        # single time instead of on every iteration.
        run_id = mlflow.active_run().info.run_id
        for data_format, file_path in format_to_file_path.items():
            df = spark_session.read.format(data_format).load(file_path)
            df.collect()
            # Give the async datasource-event subscriber time to log each read.
            time.sleep(1)
        run = mlflow.get_run(run_id)
        assert _SPARK_TABLE_INFO_TAG_NAME in run.data.tags
        table_info_tag = run.data.tags[_SPARK_TABLE_INFO_TAG_NAME]
        assert table_info_tag == "\n".join(
            [
                _get_expected_table_info_row(path, data_format)
                for data_format, path in format_to_file_path.items()
            ]
        )
def test_autologging_slow_api_requests(spark_session, format_to_file_path):
    """Datasource info still lands on the correct run when API calls are slow.

    Wraps the tracking-store HTTP layer so every POST takes an extra second,
    then checks the subscriber threads still tag the run that was active when
    they were notified.
    """
    import mlflow.utils.rest_utils

    original_http_request = mlflow.utils.rest_utils.http_request

    def _slow_api_req_mock(*args, **kwargs):
        # Artificially delay POSTs (the tag-setting calls) before delegating
        # to the real implementation.
        if kwargs.get("method") == "POST":
            print("Sleeping, %s, %s" % (args, kwargs))
            time.sleep(1)
        return original_http_request(*args, **kwargs)

    mlflow.spark.autolog()
    with mlflow.start_run():
        # Mock slow API requests to log Spark datasource information
        with mock.patch("mlflow.utils.rest_utils.http_request") as http_request_mock:
            http_request_mock.side_effect = _slow_api_req_mock
            run_id = mlflow.active_run().info.run_id
            for data_format, file_path in format_to_file_path.items():
                df = (
                    spark_session.read.format(data_format)
                    .option("header", "true")
                    .option("inferSchema", "true")
                    .load(file_path)
                )
                df.collect()
            # Sleep a bit prior to ending the run to guarantee that the Python process can
            # pick up on datasource read events (simulate the common case of doing work,
            # e.g. model training, on the DataFrame after reading from it)
            time.sleep(1)
    # Python subscriber threads should pick up the active run at the time they're notified
    # & make API requests against that run, even if those requests are slow.
    time.sleep(5)
    run = mlflow.get_run(run_id)
    assert _SPARK_TABLE_INFO_TAG_NAME in run.data.tags
    table_info_tag = run.data.tags[_SPARK_TABLE_INFO_TAG_NAME]
    assert table_info_tag == "\n".join(
        [
            _get_expected_table_info_row(path, data_format)
            for data_format, path in format_to_file_path.items()
        ]
    )
def test_autologging_disabled_logging_datasource_with_different_formats(
    spark_session, format_to_file_path
):
    """With autologging disabled, no datasource info is attached to runs.

    Mirrors the enabled-path test but calls ``autolog(disable=True)`` and
    asserts the absence of Spark table-info tags for every format.
    """
    mlflow.spark.autolog(disable=True)
    for fmt, path in format_to_file_path.items():
        frame = (
            spark_session.read.format(fmt)
            .option("header", "true")
            .option("inferSchema", "true")
            .load(path)
        )
        with mlflow.start_run():
            run_id = mlflow.active_run().info.run_id
            frame.collect()
            # Even with the grace period the subscriber would normally need,
            # nothing should be logged when autologging is disabled.
            time.sleep(1)
        run = mlflow.get_run(run_id)
        _assert_spark_data_not_logged(run=run)