def _listen_for_spark_activity(spark_context):
    global _spark_table_info_listener
    if _get_current_listener() is not None:
        return

    if _get_spark_major_version(spark_context) < 3:
        raise MlflowException("Spark autologging unsupported for Spark versions < 3")

    gw = spark_context._gateway
    params = gw.callback_server_parameters
    callback_server_params = CallbackServerParameters(
        address=params.address,
        port=params.port,
        daemonize=True,
        daemonize_connections=True,
        eager_load=params.eager_load,
        ssl_context=params.ssl_context,
        accept_timeout=params.accept_timeout,
        read_timeout=params.read_timeout,
        auth_token=params.auth_token,
    )
    callback_server_started = gw.start_callback_server(callback_server_params)

    try:
        event_publisher = _get_jvm_event_publisher()
        event_publisher.init(1)
        _spark_table_info_listener = PythonSubscriber()
        event_publisher.register(_spark_table_info_listener)
    except Exception as e:
        # Initialization failed: tear down the callback server we just started
        # so a retry in a fresh session begins from a clean state. Use a
        # distinct name for the inner exception; reusing `e` would shadow the
        # original error and leave it unbound when re-raised below.
        if callback_server_started:
            try:
                gw.shutdown_callback_server()
            except Exception as shutdown_exc:
                _logger.warning(
                    "Failed to shut down Spark callback server for autologging: %s",
                    str(shutdown_exc),
                )
        _spark_table_info_listener = None
        raise MlflowException(
            "Exception while attempting to initialize JVM-side state for "
            "Spark datasource autologging. Please create a new Spark session "
            "and ensure you have the mlflow-spark JAR attached to your Spark "
            "session as described in "
            "http://mlflow.org/docs/latest/tracking.html#"
            "automatic-logging-from-spark-experimental. "
            "Exception:\n%s" % e
        )

    # Register context provider for Spark autologging
    from mlflow.tracking.context.registry import _run_context_provider_registry

    _run_context_provider_registry.register(SparkAutologgingContext)

    _logger.info("Autologging successfully enabled for spark.")
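# A minimal usage sketch (illustrative only, not part of the module's API):
# how _listen_for_spark_activity might be driven from an active SparkSession.
# `_example_enable_listener` is a hypothetical helper; it assumes pyspark is
# installed and the mlflow-spark JAR is attached to the session.
def _example_enable_listener():
    from pyspark.sql import SparkSession

    # Reuse the active session if one exists; otherwise create one.
    spark = SparkSession.builder.getOrCreate()
    # Attach the Python-side subscriber to the session's SparkContext; this
    # is a no-op if a listener is already registered.
    _listen_for_spark_activity(spark.sparkContext)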
def autolog():
    """Implementation of Spark datasource autologging"""
    global _spark_table_info_listener
    if _get_current_listener() is None:
        active_session = _get_active_spark_session()
        if active_session is None:
            raise MlflowException(
                "No active SparkContext found, refusing to enable Spark datasource "
                "autologging. Please create a SparkSession e.g. via "
                "SparkSession.builder.getOrCreate() (see API docs at "
                "https://spark.apache.org/docs/latest/api/python/"
                "pyspark.sql.html#pyspark.sql.SparkSession) "
                "before attempting to enable autologging"
            )

        # We know SparkContext exists here already, so get it
        sc = SparkContext.getOrCreate()
        if _get_spark_major_version(sc) < 3:
            raise MlflowException("Spark autologging unsupported for Spark versions < 3")

        gw = active_session.sparkContext._gateway
        params = gw.callback_server_parameters
        callback_server_params = CallbackServerParameters(
            address=params.address,
            port=params.port,
            daemonize=True,
            daemonize_connections=True,
            eager_load=params.eager_load,
            ssl_context=params.ssl_context,
            accept_timeout=params.accept_timeout,
            read_timeout=params.read_timeout,
            auth_token=params.auth_token,
        )
        gw.start_callback_server(callback_server_params)

        event_publisher = _get_jvm_event_publisher()
        try:
            event_publisher.init(1)
            _spark_table_info_listener = PythonSubscriber()
            _spark_table_info_listener.register()
        except Exception as e:
            raise MlflowException(
                "Exception while attempting to initialize JVM-side state for "
                "Spark datasource autologging. Please ensure you have the "
                "mlflow-spark JAR attached to your Spark session as described "
                "in http://mlflow.org/docs/latest/tracking.html#"
                "automatic-logging-from-spark-experimental. Exception:\n%s" % e
            )

        # Register context provider for Spark autologging
        from mlflow.tracking.context.registry import _run_context_provider_registry

        _run_context_provider_registry.register(SparkAutologgingContext)
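# An end-to-end usage sketch (illustrative only): enable autologging, then
# read a datasource inside an MLflow run so the read is recorded on the run.
# `_example_autolog_datasource_read` is a hypothetical helper; the Maven
# coordinate for the mlflow-spark JAR and the CSV path are assumptions and
# must be adjusted to match your environment.
def _example_autolog_datasource_read():
    import mlflow
    from pyspark.sql import SparkSession

    spark = (
        SparkSession.builder
        # Assumed coordinate; pick the version matching your mlflow install.
        .config("spark.jars.packages", "org.mlflow:mlflow-spark:1.11.0")
        .getOrCreate()
    )
    autolog()  # registers the Python-side listener defined above
    with mlflow.start_run():
        # Datasource reads from here on are recorded as tags on the run.
        spark.read.format("csv").load("/tmp/example.csv")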