def test_autologging_dedups_multiple_reads_of_same_datasource(
    spark_session, format_to_file_path
):
    mlflow.spark.autolog()
    data_format = list(format_to_file_path.keys())[0]
    file_path = format_to_file_path[data_format]
    df = (
        spark_session.read.format(data_format)
        .option("header", "true")
        .option("inferSchema", "true")
        .load(file_path)
    )
    with mlflow.start_run():
        run_id = mlflow.active_run().info.run_id
        # Multiple reads of the same datasource should be deduplicated into a
        # single datasource-info entry on the run.
        df.collect()
        df.filter("number1 > 0").collect()
        df.limit(2).collect()
        df.collect()
    time.sleep(1)
    run = mlflow.get_run(run_id)
    _assert_spark_data_logged(run=run, path=file_path, data_format=data_format)

    # Test context provider flow: reads that happen before a run starts should
    # be attributed to the next run that is created.
    df.filter("number1 > 0").collect()
    df.limit(2).collect()
    df.collect()
    with mlflow.start_run():
        run_id2 = mlflow.active_run().info.run_id
    time.sleep(1)
    run2 = mlflow.get_run(run_id2)
    _assert_spark_data_logged(run=run2, path=file_path, data_format=data_format)

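# NOTE: `spark_session` and `format_to_file_path` are pytest fixtures defined
# elsewhere in this suite. The sketch below is an assumption about what
# `format_to_file_path` roughly looks like, inferred from how the tests use it
# (the "number1" column and the csv/parquet/json formats are guesses, and it
# assumes the suite already imports pytest); it is not the suite's actual
# fixture definition.
@pytest.fixture
def format_to_file_path(spark_session, tmp_path):
    df = spark_session.createDataFrame(
        [(1, 10), (2, 20), (-3, 30)], ["number1", "number2"]
    )
    paths = {}
    for data_format in ["csv", "parquet", "json"]:
        path = str(tmp_path / "data.{}".format(data_format))
        df.write.option("header", "true").format(data_format).save(path)
        paths[data_format] = path
    return paths
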
def test_autologging_disabled_then_enabled(spark_session, format_to_file_path):
    mlflow.spark.autolog(disable=True)
    data_format = list(format_to_file_path.keys())[0]
    file_path = format_to_file_path[data_format]
    df = (
        spark_session.read.format(data_format)
        .option("header", "true")
        .option("inferSchema", "true")
        .load(file_path)
    )
    # Logging is disabled here.
    with mlflow.start_run():
        run_id = mlflow.active_run().info.run_id
        df.collect()
    time.sleep(1)
    run = mlflow.get_run(run_id)
    _assert_spark_data_not_logged(run=run)
    # Logging is enabled here.
    mlflow.spark.autolog(disable=False)
    with mlflow.start_run():
        run_id = mlflow.active_run().info.run_id
        df.filter("number1 > 0").collect()
    time.sleep(1)
    run = mlflow.get_run(run_id)
    _assert_spark_data_logged(run=run, path=file_path, data_format=data_format)

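# NOTE: `_assert_spark_data_logged` and `_assert_spark_data_not_logged` are
# helpers defined elsewhere in this suite. A minimal sketch of the checks they
# presumably perform is below; the "sparkDatasourceInfo" tag name and the
# substring checks are assumptions about MLflow's Spark datasource tag, not
# the suite's actual implementation.
def _assert_spark_data_logged(run, path, data_format):
    assert "sparkDatasourceInfo" in run.data.tags
    table_info_tag = run.data.tags["sparkDatasourceInfo"]
    # A deduplicated datasource should appear exactly once in the tag value.
    assert table_info_tag.count(path) == 1
    assert data_format in table_info_tag


def _assert_spark_data_not_logged(run):
    assert "sparkDatasourceInfo" not in run.data.tags
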
def test_autologging_disabled_logging_with_or_without_active_run(
    spark_session, format_to_file_path
):
    mlflow.spark.autolog(disable=True)
    data_format = list(format_to_file_path.keys())[0]
    file_path = format_to_file_path[data_format]
    df = (
        spark_session.read.format(data_format)
        .option("header", "true")
        .option("inferSchema", "true")
        .load(file_path)
    )
    # Reading the data source before starting a run.
    df.filter("number1 > 0").collect()
    df.limit(2).collect()
    df.collect()
    # If any tag info had been collected, it would be logged here.
    with mlflow.start_run():
        run_id = mlflow.active_run().info.run_id
    time.sleep(1)
    # Confirm nothing was logged.
    run = mlflow.get_run(run_id)
    _assert_spark_data_not_logged(run=run)
    # Reading the data source during an active run.
    with mlflow.start_run():
        run_id = mlflow.active_run().info.run_id
        df.collect()
    time.sleep(1)
    run = mlflow.get_run(run_id)
    _assert_spark_data_not_logged(run=run)

def test_autologging_does_not_start_run(spark_session, format_to_file_path):
    try:
        mlflow.spark.autolog()
        data_format = list(format_to_file_path.keys())[0]
        file_path = format_to_file_path[data_format]
        df = (
            spark_session.read.format(data_format)
            .option("header", "true")
            .option("inferSchema", "true")
            .load(file_path)
        )
        df.collect()
        time.sleep(1)
        # Autologging on its own should neither start nor create a run.
        active_run = mlflow.active_run()
        assert active_run is None
        assert len(mlflow.search_runs()) == 0
    finally:
        mlflow.end_run()

def test_autologging_multiple_runs_same_data(spark_session, format_to_file_path):
    mlflow.spark.autolog()
    data_format = list(format_to_file_path.keys())[0]
    file_path = format_to_file_path[data_format]
    df = (
        spark_session.read.format(data_format)
        .option("header", "true")
        .option("inferSchema", "true")
        .load(file_path)
    )
    df.collect()
    # The same datasource read should be attributed to each subsequent run.
    for _ in range(2):
        with mlflow.start_run():
            time.sleep(1)
            run_id = mlflow.active_run().info.run_id
            run = mlflow.get_run(run_id)
            _assert_spark_data_logged(run=run, path=file_path, data_format=data_format)

def test_autologging_does_not_throw_on_api_failures(
    spark_session, format_to_file_path, mlflow_client
):  # pylint: disable=unused-argument
    mlflow.spark.autolog()

    def failing_req_mock(*args, **kwargs):
        raise Exception("API request failed!")

    with mlflow.start_run():
        with mock.patch("mlflow.utils.rest_utils.http_request") as http_request_mock:
            # Every tracking API request fails, but the datasource reads below
            # should still succeed without raising.
            http_request_mock.side_effect = failing_req_mock
            data_format = list(format_to_file_path.keys())[0]
            file_path = format_to_file_path[data_format]
            df = (
                spark_session.read.format(data_format)
                .option("header", "true")
                .option("inferSchema", "true")
                .load(file_path)
            )
            df.collect()
            df.filter("number1 > 0").collect()
            df.limit(2).collect()
            df.collect()
            time.sleep(1)

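# NOTE: `mlflow_client` is another fixture from this suite, unused directly by
# the test above but presumably requested so a tracking client is set up for
# the test. A plausible one-line sketch, assuming the standard `MlflowClient`
# constructor; this is not necessarily the suite's actual fixture.
@pytest.fixture
def mlflow_client():
    from mlflow.tracking import MlflowClient

    return MlflowClient()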