def autolog(
    log_input_examples=False,
    log_model_signatures=True,
    log_models=True,
    disable=False,
    exclusive=False,
):  # pylint: disable=unused-argument
    """
    Enables (or disables) and configures autologging for scikit-learn estimators.

    **When is autologging performed?**
    Autologging is performed when you call:

    - ``estimator.fit()``
    - ``estimator.fit_predict()``
    - ``estimator.fit_transform()``

    **Logged information**

    **Parameters**

    - Parameters obtained by ``estimator.get_params(deep=True)``. Note that ``get_params``
      is called with ``deep=True``. This means when you fit a meta estimator that chains
      a series of estimators, the parameters of these child estimators are also logged.

    **Metrics**

    - A training score obtained by ``estimator.score``. Note that the training score is
      computed using parameters given to ``fit()``.
    - Common metrics for classifier:

      - `precision score`_

      .. _precision score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html

      - `recall score`_

      .. _recall score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html

      - `f1 score`_

      .. _f1 score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html

      - `accuracy score`_

      .. _accuracy score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html

      If the classifier has method ``predict_proba``, we additionally log:

      - `log loss`_

      .. _log loss:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.log_loss.html

      - `roc auc score`_

      .. _roc auc score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html

    - Common metrics for regressor:

      - `mean squared error`_

      .. _mean squared error:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html

      - root mean squared error

      - `mean absolute error`_

      .. _mean absolute error:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_error.html

      - `r2 score`_

      .. _r2 score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html

    **Tags**

    - An estimator class name (e.g. "LinearRegression").
    - A fully qualified estimator class name
      (e.g. "sklearn.linear_model._base.LinearRegression").

    **Artifacts**

    - An MLflow Model with the :py:mod:`mlflow.sklearn` flavor containing a fitted estimator
      (logged by :py:func:`mlflow.sklearn.log_model()`). The Model also contains the
      :py:mod:`mlflow.pyfunc` flavor when the scikit-learn estimator defines `predict()`.

    **How does autologging work for meta estimators?**
    When a meta estimator (e.g. `Pipeline`_, `GridSearchCV`_) calls ``fit()``, it internally
    calls ``fit()`` on its child estimators. Autologging does NOT perform logging on these
    constituent ``fit()`` calls.

    **Parameter search**
    In addition to recording the information discussed above, autologging for parameter
    search meta estimators (`GridSearchCV`_ and `RandomizedSearchCV`_) records child runs
    with metrics for each set of explored parameters, as well as artifacts and parameters
    for the best model (if available).

    **Supported estimators**

    - All estimators obtained by `sklearn.utils.all_estimators`_ (including meta estimators).
    - `Pipeline`_
    - Parameter search estimators (`GridSearchCV`_ and `RandomizedSearchCV`_)

    .. _sklearn.utils.all_estimators:
        https://scikit-learn.org/stable/modules/generated/sklearn.utils.all_estimators.html

    .. _Pipeline:
        https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html

    .. _GridSearchCV:
        https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html

    .. _RandomizedSearchCV:
        https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html

    **Example**

    `See more examples <https://github.com/mlflow/mlflow/blob/master/examples/sklearn_autolog>`_

    .. code-block:: python

        from pprint import pprint
        import numpy as np
        from sklearn.linear_model import LinearRegression
        import mlflow

        def fetch_logged_data(run_id):
            client = mlflow.tracking.MlflowClient()
            data = client.get_run(run_id).data
            tags = {k: v for k, v in data.tags.items() if not k.startswith("mlflow.")}
            artifacts = [f.path for f in client.list_artifacts(run_id, "model")]
            return data.params, data.metrics, tags, artifacts

        # enable autologging
        mlflow.sklearn.autolog()

        # prepare training data
        X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
        y = np.dot(X, np.array([1, 2])) + 3

        # train a model
        model = LinearRegression()
        with mlflow.start_run() as run:
            model.fit(X, y)

        # fetch logged data
        params, metrics, tags, artifacts = fetch_logged_data(run.info.run_id)

        pprint(params)
        # {'copy_X': 'True',
        #  'fit_intercept': 'True',
        #  'n_jobs': 'None',
        #  'normalize': 'False'}

        pprint(metrics)
        # {'training_score': 1.0,
        #  'training_mae': 2.220446049250313e-16,
        #  'training_mse': 1.9721522630525295e-31,
        #  'training_r2_score': 1.0,
        #  'training_rmse': 4.440892098500626e-16}

        pprint(tags)
        # {'estimator_class': 'sklearn.linear_model._base.LinearRegression',
        #  'estimator_name': 'LinearRegression'}

        pprint(artifacts)
        # ['model/MLmodel', 'model/conda.yaml', 'model/model.pkl']

    :param log_input_examples: If ``True``, input examples from training datasets are collected and
                               logged along with scikit-learn model artifacts during training. If
                               ``False``, input examples are not logged.
                               Note: Input examples are MLflow model attributes
                               and are only collected if ``log_models`` is also ``True``.
    :param log_model_signatures: If ``True``,
                                 :py:class:`ModelSignatures <mlflow.models.ModelSignature>`
                                 describing model inputs and outputs are collected and logged along
                                 with scikit-learn model artifacts during training. If ``False``,
                                 signatures are not logged.
                                 Note: Model signatures are MLflow model attributes
                                 and are only collected if ``log_models`` is also ``True``.
    :param log_models: If ``True``, trained models are logged as MLflow model artifacts.
                       If ``False``, trained models are not logged. Input examples and model
                       signatures, which are attributes of MLflow models, are also omitted when
                       ``log_models`` is ``False``.
    :param disable: If ``True``, disables the scikit-learn autologging integration. If ``False``,
                    enables the scikit-learn autologging integration.
    :param exclusive: If ``True``, autologged content is not logged to user-created fluent runs.
                      If ``False``, autologged content is logged to the active fluent run,
                      which may be user-created.
    """
    # Imports are deferred to call time so that merely importing this module does not
    # require pandas/sklearn to be installed.
    import pandas as pd
    import sklearn

    from mlflow.models import infer_signature
    from mlflow.sklearn.utils import (
        _MIN_SKLEARN_VERSION,
        _is_supported_version,
        _chunk_dict,
        _get_args_for_score,
        _log_specialized_estimator_content,
        _get_Xy,
        _all_estimators,
        _truncate_dict,
        _get_arg_names,
        _get_estimator_info_tags,
        _get_meta_estimators_for_autologging,
        _is_parameter_search_estimator,
        _log_parameter_search_results_as_artifact,
        _create_child_runs_for_parameter_search,
    )
    from mlflow.tracking.context import registry as context_registry
    from mlflow.utils.validation import (
        MAX_PARAMS_TAGS_PER_BATCH,
        MAX_PARAM_VAL_LENGTH,
        MAX_ENTITY_KEY_LENGTH,
    )

    # Warn (but do not fail) on scikit-learn versions older than the minimum we test against.
    if not _is_supported_version():
        warnings.warn(
            "Autologging utilities may not work properly on scikit-learn < {} ".format(
                _MIN_SKLEARN_VERSION
            )
            + "(current version: {})".format(sklearn.__version__),
            stacklevel=2,
        )

    def fit_mlflow(original, self, *args, **kwargs):
        """
        Autologging function that performs model training by invoking `original` (the
        unpatched training routine) on the estimator instance referred to by `self`,
        recording MLflow parameters, metrics, tags, and artifacts for the training
        session to a corresponding MLflow Run.

        :param original: The original, unpatched training method (e.g. ``fit``).
        :param self: The scikit-learn estimator instance being trained.
        """
        _log_pretraining_metadata(self, *args, **kwargs)
        fit_output = original(self, *args, **kwargs)
        _log_posttraining_metadata(self, *args, **kwargs)
        return fit_output

    def _log_pretraining_metadata(
        estimator, *args, **kwargs
    ):  # pylint: disable=unused-argument
        """
        Records metadata (e.g., params and tags) for a scikit-learn estimator prior to training.
        This is intended to be invoked within a patched scikit-learn training routine
        (e.g., `fit()`, `fit_transform()`, ...) and assumes the existence of an active
        MLflow run that can be referenced via the fluent Tracking API.

        :param estimator: The scikit-learn estimator for which to log metadata.
        :param args: The arguments passed to the scikit-learn training routine
                     (e.g., `fit()`, `fit_transform()`, ...).
        :param kwargs: The keyword arguments passed to the scikit-learn training routine.
        """
        # Deep parameter logging includes parameters from children of a given
        # estimator. For some meta estimators (e.g., pipelines), recording
        # these parameters is desirable. For parameter search estimators,
        # however, child estimators act as seeds for the parameter search
        # process; accordingly, we avoid logging initial, untuned parameters
        # for these seed estimators.
        should_log_params_deeply = not _is_parameter_search_estimator(estimator)
        # Chunk model parameters to avoid hitting the log_batch API limit
        for chunk in _chunk_dict(
            estimator.get_params(deep=should_log_params_deeply),
            chunk_size=MAX_PARAMS_TAGS_PER_BATCH,
        ):
            # Truncate keys/values that exceed the tracking store's entity length limits.
            truncated = _truncate_dict(chunk, MAX_ENTITY_KEY_LENGTH, MAX_PARAM_VAL_LENGTH)
            try_mlflow_log(mlflow.log_params, truncated)

        try_mlflow_log(mlflow.set_tags, _get_estimator_info_tags(estimator))

    def _log_posttraining_metadata(estimator, *args, **kwargs):
        """
        Records metadata for a scikit-learn estimator after training has completed.
        This is intended to be invoked within a patched scikit-learn training routine
        (e.g., `fit()`, `fit_transform()`, ...) and assumes the existence of an active
        MLflow run that can be referenced via the fluent Tracking API.

        :param estimator: The scikit-learn estimator for which to log metadata.
        :param args: The arguments passed to the scikit-learn training routine
                     (e.g., `fit()`, `fit_transform()`, ...).
        :param kwargs: The keyword arguments passed to the scikit-learn training routine.
        """
        if hasattr(estimator, "score"):
            try:
                # Reconstruct the argument list that `score` expects from the `fit` call's
                # arguments (e.g., X and y), then compute a training score.
                score_args = _get_args_for_score(estimator.score, estimator.fit, args, kwargs)
                training_score = estimator.score(*score_args)
            except Exception as e:
                msg = (
                    estimator.score.__qualname__
                    + " failed. The 'training_score' metric will not be recorded. Scoring error: "
                    + str(e)
                )
                _logger.warning(msg)
            else:
                try_mlflow_log(mlflow.log_metric, "training_score", training_score)

        # log common metrics and artifacts for estimators (classifier, regressor)
        _log_specialized_estimator_content(
            estimator, mlflow.active_run().info.run_id, args, kwargs
        )

        def get_input_example():
            # Fetch an input example using the first several rows of the array-like
            # training data supplied to the training routine (e.g., `fit()`)
            fit_arg_names = _get_arg_names(estimator.fit)
            X_var_name, y_var_name = fit_arg_names[:2]
            input_example = _get_Xy(args, kwargs, X_var_name, y_var_name)[0][
                :INPUT_EXAMPLE_SAMPLE_ROWS
            ]
            return input_example

        def infer_model_signature(input_example):
            # A signature requires model output, which we obtain via `predict`; estimators
            # without `predict` (e.g. pure transformers) cannot have a signature inferred.
            if not hasattr(estimator, "predict"):
                raise Exception(
                    "the trained model does not specify a `predict` function, "
                    + "which is required in order to infer the signature"
                )

            return infer_signature(input_example, estimator.predict(input_example))

        if log_models:
            # Will only resolve `input_example` and `signature` if `log_models` is `True`.
            input_example, signature = resolve_input_example_and_signature(
                get_input_example,
                infer_model_signature,
                log_input_examples,
                log_model_signatures,
                _logger,
            )

            try_mlflow_log(
                log_model,
                estimator,
                artifact_path="model",
                signature=signature,
                input_example=input_example,
            )

        if _is_parameter_search_estimator(estimator):
            # `signature`/`input_example` are only bound when `log_models` is True, which
            # the conjunction below guarantees before they are referenced.
            if hasattr(estimator, "best_estimator_") and log_models:
                try_mlflow_log(
                    log_model,
                    estimator.best_estimator_,
                    artifact_path="best_estimator",
                    signature=signature,
                    input_example=input_example,
                )

            if hasattr(estimator, "best_score_"):
                try_mlflow_log(mlflow.log_metric, "best_cv_score", estimator.best_score_)

            if hasattr(estimator, "best_params_"):
                # Prefix tuned parameter names so they don't collide with the params
                # logged during pre-training metadata collection.
                best_params = {
                    "best_{param_name}".format(param_name=param_name): param_value
                    for param_name, param_value in estimator.best_params_.items()
                }
                try_mlflow_log(mlflow.log_params, best_params)

            if hasattr(estimator, "cv_results_"):
                try:
                    # Fetch environment-specific tags (e.g., user and source) to ensure that lineage
                    # information is consistent with the parent run
                    child_tags = context_registry.resolve_tags()
                    child_tags.update({MLFLOW_AUTOLOGGING: FLAVOR_NAME})
                    _create_child_runs_for_parameter_search(
                        cv_estimator=estimator,
                        parent_run=mlflow.active_run(),
                        child_tags=child_tags,
                    )
                except Exception as e:
                    msg = (
                        "Encountered exception during creation of child runs for parameter search."
                        " Child runs may be missing. Exception: {}".format(str(e))
                    )
                    _logger.warning(msg)

                try:
                    cv_results_df = pd.DataFrame.from_dict(estimator.cv_results_)
                    _log_parameter_search_results_as_artifact(
                        cv_results_df, mlflow.active_run().info.run_id
                    )
                except Exception as e:
                    msg = (
                        "Failed to log parameter search results as an artifact."
                        " Exception: {}".format(str(e))
                    )
                    _logger.warning(msg)

    def patched_fit(original, self, *args, **kwargs):
        """
        Autologging patch function to be applied to a sklearn model class that defines a `fit`
        method and inherits from `BaseEstimator` (thereby defining the `get_params()` method).

        :param original: The original, unpatched training method being wrapped.
        :param self: The estimator instance on which the training method was invoked.
        """
        # The training session guard suppresses autologging for nested `fit()` calls made
        # internally by meta estimators (only the outermost call is logged).
        with _SklearnTrainingSession(clazz=self.__class__, allow_children=False) as t:
            if t.should_log():
                return fit_mlflow(original, self, *args, **kwargs)
            else:
                return original(self, *args, **kwargs)

    _, estimators_to_patch = zip(*_all_estimators())
    # Ensure that relevant meta estimators (e.g. GridSearchCV, Pipeline) are selected
    # for patching if they are not already included in the output of `all_estimators()`
    estimators_to_patch = set(estimators_to_patch).union(
        set(_get_meta_estimators_for_autologging())
    )
    # Exclude certain preprocessing & feature manipulation estimators from patching. These
    # estimators represent data manipulation routines (e.g., normalization, label encoding)
    # rather than ML algorithms. Accordingly, we should not create MLflow runs and log
    # parameters / metrics for these routines, unless they are captured as part of an ML pipeline
    # (via `sklearn.pipeline.Pipeline`)
    excluded_module_names = [
        "sklearn.preprocessing",
        "sklearn.impute",
        "sklearn.feature_extraction",
        "sklearn.feature_selection",
    ]

    estimators_to_patch = [
        estimator
        for estimator in estimators_to_patch
        if not any(
            [
                estimator.__module__.startswith(excluded_module_name)
                for excluded_module_name in excluded_module_names
            ]
        )
    ]

    for class_def in estimators_to_patch:
        for func_name in ["fit", "fit_transform", "fit_predict"]:
            if hasattr(class_def, func_name):
                original = getattr(class_def, func_name)

                # A couple of estimators use property methods to return fitting functions,
                # rather than defining the fitting functions on the estimator class directly.
                #
                # Example: https://github.com/scikit-learn/scikit-learn/blob/0.23.2/sklearn/neighbors/_lof.py#L183  # noqa
                #
                # We currently exclude these property fitting methods from patching because
                # it's challenging to patch them correctly.
                #
                # Excluded fitting methods:
                # - sklearn.cluster._agglomerative.FeatureAgglomeration.fit_predict
                # - sklearn.neighbors._lof.LocalOutlierFactor.fit_predict
                #
                # You can list property fitting methods by inserting "print(class_def, func_name)"
                # in the if clause below.
                if isinstance(original, property):
                    continue

                safe_patch(
                    FLAVOR_NAME, class_def, func_name, patched_fit, manage_run=True,
                )
def autolog():
    """
    Enables autologging for scikit-learn estimators.

    **When is autologging performed?**
    Autologging is performed when you call:

    - ``estimator.fit()``
    - ``estimator.fit_predict()``
    - ``estimator.fit_transform()``

    **Logged information**

    **Parameters**

    - Parameters obtained by ``estimator.get_params(deep=True)``. Note that ``get_params``
      is called with ``deep=True``. This means when you fit a meta estimator that chains
      a series of estimators, the parameters of these child estimators are also logged.

    **Metrics**

    - A training score obtained by ``estimator.score``. Note that the training score is
      computed using parameters given to ``fit()``.

    **Tags**

    - An estimator class name (e.g. "LinearRegression").
    - A fully qualified estimator class name
      (e.g. "sklearn.linear_model._base.LinearRegression").

    **Artifacts**

    - A fitted estimator (logged by :py:func:`mlflow.sklearn.log_model()`).

    **How does autologging work for meta estimators?**
    When a meta estimator (e.g. `Pipeline`_, `GridSearchCV`_) calls ``fit()``, it internally
    calls ``fit()`` on its child estimators. Autologging does NOT perform logging on these
    constituent ``fit()`` calls.

    **Parameter search**
    In addition to recording the information discussed above, autologging for parameter
    search meta estimators (`GridSearchCV`_ and `RandomizedSearchCV`_) records child runs
    with metrics for each set of explored parameters, as well as artifacts and parameters
    for the best model (if available).

    **Supported estimators**

    - All estimators obtained by `sklearn.utils.all_estimators`_ (including meta estimators).
    - `Pipeline`_
    - Parameter search estimators (`GridSearchCV`_ and `RandomizedSearchCV`_)

    .. _sklearn.utils.all_estimators:
        https://scikit-learn.org/stable/modules/generated/sklearn.utils.all_estimators.html

    .. _Pipeline:
        https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html

    .. _GridSearchCV:
        https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html

    .. _RandomizedSearchCV:
        https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html

    **Example**

    .. code-block:: python

        from pprint import pprint
        import numpy as np
        import sklearn.linear_model
        import mlflow

        # enable autologging
        mlflow.sklearn.autolog()

        # prepare training data
        X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
        y = np.dot(X, np.array([1, 2])) + 3

        # train a model
        with mlflow.start_run() as run:
            reg = sklearn.linear_model.LinearRegression().fit(X, y)

        def fetch_logged_data(run_id):
            client = mlflow.tracking.MlflowClient()
            data = client.get_run(run_id).data
            tags = {k: v for k, v in data.tags.items() if not k.startswith("mlflow.")}
            artifacts = [f.path for f in client.list_artifacts(run_id, "model")]
            return data.params, data.metrics, tags, artifacts

        # fetch logged data
        params, metrics, tags, artifacts = fetch_logged_data(run.info.run_id)

        pprint(params)
        # {'copy_X': 'True',
        #  'fit_intercept': 'True',
        #  'n_jobs': 'None',
        #  'normalize': 'False'}

        pprint(metrics)
        # {'training_score': 1.0}

        pprint(tags)
        # {'estimator_class': 'sklearn.linear_model._base.LinearRegression',
        #  'estimator_name': 'LinearRegression'}

        pprint(artifacts)
        # ['model/MLmodel', 'model/conda.yaml', 'model/model.pkl']
    """
    # Imports are deferred to call time so that merely importing this module does not
    # require pandas/sklearn to be installed.
    import pandas as pd
    import sklearn

    from mlflow.models import infer_signature
    from mlflow.sklearn.utils import (
        _MIN_SKLEARN_VERSION,
        _is_supported_version,
        _chunk_dict,
        _get_args_for_score,
        _get_Xy,
        _all_estimators,
        _truncate_dict,
        _get_arg_names,
        _get_estimator_info_tags,
        _get_meta_estimators_for_autologging,
        _is_parameter_search_estimator,
        _log_parameter_search_results_as_artifact,
        _create_child_runs_for_parameter_search,
    )
    from mlflow.tracking.context import registry as context_registry
    from mlflow.utils.validation import (
        MAX_PARAMS_TAGS_PER_BATCH,
        MAX_PARAM_VAL_LENGTH,
        MAX_ENTITY_KEY_LENGTH,
    )

    # Warn (but do not fail) on scikit-learn versions older than the minimum we test against.
    if not _is_supported_version():
        warnings.warn(
            "Autologging utilities may not work properly on scikit-learn < {} ".format(
                _MIN_SKLEARN_VERSION
            )
            + "(current version: {})".format(sklearn.__version__),
            stacklevel=2,
        )

    def fit_mlflow(self, func_name, *args, **kwargs):
        """
        Performs model training by invoking the original (unpatched) method named
        `func_name` on `self`, recording MLflow params, metrics, tags, and artifacts
        for the training session. Starts (and later ends) an MLflow run if no run is
        currently active; the run is marked FAILED if training raises.

        :param self: The scikit-learn estimator instance being trained.
        :param func_name: The name of the patched training routine (e.g. "fit").
        """
        should_start_run = mlflow.active_run() is None
        if should_start_run:
            try_mlflow_log(mlflow.start_run)

        _log_pretraining_metadata(self, *args, **kwargs)

        original_fit = gorilla.get_original_attribute(self, func_name)
        try:
            fit_output = original_fit(*args, **kwargs)
        except Exception:
            # Close the run we opened as FAILED before propagating the training error.
            if should_start_run:
                try_mlflow_log(mlflow.end_run, RunStatus.to_string(RunStatus.FAILED))

            # Bare `raise` preserves the original traceback; `raise e` would make the
            # traceback point here instead of at the true failure site.
            raise

        _log_posttraining_metadata(self, *args, **kwargs)

        if should_start_run:
            try_mlflow_log(mlflow.end_run)

        return fit_output

    def _log_pretraining_metadata(
        estimator, *args, **kwargs
    ):  # pylint: disable=unused-argument
        """
        Records metadata (e.g., params and tags) for a scikit-learn estimator prior to training.
        This is intended to be invoked within a patched scikit-learn training routine
        (e.g., `fit()`, `fit_transform()`, ...) and assumes the existence of an active
        MLflow run that can be referenced via the fluent Tracking API.

        :param estimator: The scikit-learn estimator for which to log metadata.
        :param args: The arguments passed to the scikit-learn training routine
                     (e.g., `fit()`, `fit_transform()`, ...).
        :param kwargs: The keyword arguments passed to the scikit-learn training routine.
        """
        # Deep parameter logging includes parameters from children of a given
        # estimator. For some meta estimators (e.g., pipelines), recording
        # these parameters is desirable. For parameter search estimators,
        # however, child estimators act as seeds for the parameter search
        # process; accordingly, we avoid logging initial, untuned parameters
        # for these seed estimators.
        should_log_params_deeply = not _is_parameter_search_estimator(estimator)
        # Chunk model parameters to avoid hitting the log_batch API limit
        for chunk in _chunk_dict(
            estimator.get_params(deep=should_log_params_deeply),
            chunk_size=MAX_PARAMS_TAGS_PER_BATCH,
        ):
            # Truncate keys/values that exceed the tracking store's entity length limits.
            truncated = _truncate_dict(chunk, MAX_ENTITY_KEY_LENGTH, MAX_PARAM_VAL_LENGTH)
            try_mlflow_log(mlflow.log_params, truncated)

        try_mlflow_log(mlflow.set_tags, _get_estimator_info_tags(estimator))

    def _log_posttraining_metadata(estimator, *args, **kwargs):
        """
        Records metadata for a scikit-learn estimator after training has completed.
        This is intended to be invoked within a patched scikit-learn training routine
        (e.g., `fit()`, `fit_transform()`, ...) and assumes the existence of an active
        MLflow run that can be referenced via the fluent Tracking API.

        :param estimator: The scikit-learn estimator for which to log metadata.
        :param args: The arguments passed to the scikit-learn training routine
                     (e.g., `fit()`, `fit_transform()`, ...).
        :param kwargs: The keyword arguments passed to the scikit-learn training routine.
        """
        if hasattr(estimator, "score"):
            try:
                # Reconstruct the argument list that `score` expects from the `fit` call's
                # arguments (e.g., X and y), then compute a training score.
                score_args = _get_args_for_score(estimator.score, estimator.fit, args, kwargs)
                training_score = estimator.score(*score_args)
            except Exception as e:  # pylint: disable=broad-except
                msg = (
                    estimator.score.__qualname__
                    + " failed. The 'training_score' metric will not be recorded. Scoring error: "
                    + str(e)
                )
                _logger.warning(msg)
            else:
                try_mlflow_log(mlflow.log_metric, "training_score", training_score)

        input_example = None
        signature = None
        if hasattr(estimator, "predict"):
            try:
                # Fetch an input example using the first several rows of the array-like
                # training data supplied to the training routine (e.g., `fit()`)
                SAMPLE_ROWS = 5
                fit_arg_names = _get_arg_names(estimator.fit)
                X_var_name, y_var_name = fit_arg_names[:2]
                input_example = _get_Xy(args, kwargs, X_var_name, y_var_name)[0][:SAMPLE_ROWS]

                model_output = estimator.predict(input_example)
                signature = infer_signature(input_example, model_output)
            except Exception as e:  # pylint: disable=broad-except
                # Best-effort: a model is still logged below, just without an example/signature.
                input_example = None
                msg = "Failed to infer an input example and model signature: " + str(e)
                _logger.warning(msg)

        try_mlflow_log(
            log_model,
            estimator,
            artifact_path="model",
            signature=signature,
            input_example=input_example,
        )

        if _is_parameter_search_estimator(estimator):
            if hasattr(estimator, "best_estimator_"):
                try_mlflow_log(
                    log_model, estimator.best_estimator_, artifact_path="best_estimator"
                )

            if hasattr(estimator, "best_params_"):
                # Prefix tuned parameter names so they don't collide with the params
                # logged during pre-training metadata collection.
                best_params = {
                    "best_{param_name}".format(param_name=param_name): param_value
                    for param_name, param_value in estimator.best_params_.items()
                }
                try_mlflow_log(mlflow.log_params, best_params)

            if hasattr(estimator, "cv_results_"):
                try:
                    # Fetch environment-specific tags (e.g., user and source) to ensure that lineage
                    # information is consistent with the parent run
                    environment_tags = context_registry.resolve_tags()
                    _create_child_runs_for_parameter_search(
                        cv_estimator=estimator,
                        parent_run=mlflow.active_run(),
                        child_tags=environment_tags,
                    )
                except Exception as e:  # pylint: disable=broad-except
                    msg = (
                        "Encountered exception during creation of child runs for parameter search."
                        " Child runs may be missing. Exception: {}".format(str(e))
                    )
                    _logger.warning(msg)

                try:
                    cv_results_df = pd.DataFrame.from_dict(estimator.cv_results_)
                    _log_parameter_search_results_as_artifact(
                        cv_results_df, mlflow.active_run().info.run_id
                    )
                except Exception as e:  # pylint: disable=broad-except
                    msg = (
                        "Failed to log parameter search results as an artifact."
                        " Exception: {}".format(str(e))
                    )
                    _logger.warning(msg)

    def patched_fit(self, func_name, *args, **kwargs):
        """
        To be applied to a sklearn model class that defines a `fit` method and
        inherits from `BaseEstimator` (thereby defining the `get_params()` method).

        :param self: The estimator instance on which the training method was invoked.
        :param func_name: The name of the patched training routine (e.g. "fit").
        """
        # The training session guard suppresses autologging for nested `fit()` calls made
        # internally by meta estimators (only the outermost call is logged).
        with _SklearnTrainingSession(clazz=self.__class__, allow_children=False) as t:
            if t.should_log():
                return fit_mlflow(self, func_name, *args, **kwargs)
            else:
                original_fit = gorilla.get_original_attribute(self, func_name)
                return original_fit(*args, **kwargs)

    def create_patch_func(func_name):
        # Factory that binds `func_name` per patched method, avoiding the late-binding
        # closure pitfall of defining `f` directly inside the patching loop below.
        def f(self, *args, **kwargs):
            return patched_fit(self, func_name, *args, **kwargs)

        return f

    patch_settings = gorilla.Settings(allow_hit=True, store_hit=True)

    _, estimators_to_patch = zip(*_all_estimators())
    # Ensure that relevant meta estimators (e.g. GridSearchCV, Pipeline) are selected
    # for patching if they are not already included in the output of `all_estimators()`
    estimators_to_patch = set(estimators_to_patch).union(
        set(_get_meta_estimators_for_autologging())
    )

    for class_def in estimators_to_patch:
        for func_name in ["fit", "fit_transform", "fit_predict"]:
            if hasattr(class_def, func_name):
                original = getattr(class_def, func_name)

                # A couple of estimators use property methods to return fitting functions,
                # rather than defining the fitting functions on the estimator class directly.
                #
                # Example: https://github.com/scikit-learn/scikit-learn/blob/0.23.2/sklearn/neighbors/_lof.py#L183  # noqa
                #
                # We currently exclude these property fitting methods from patching because
                # it's challenging to patch them correctly.
                #
                # Excluded fitting methods:
                # - sklearn.cluster._agglomerative.FeatureAgglomeration.fit_predict
                # - sklearn.neighbors._lof.LocalOutlierFactor.fit_predict
                #
                # You can list property fitting methods by inserting "print(class_def, func_name)"
                # in the if clause below.
                if isinstance(original, property):
                    continue

                patch_func = create_patch_func(func_name)
                # TODO(harupy): Package this wrap & patch routine into a utility function so we can
                # reuse it in other autologging integrations.
                # preserve original function attributes
                patch_func = functools.wraps(original)(patch_func)
                patch = gorilla.Patch(class_def, func_name, patch_func, settings=patch_settings)
                gorilla.apply(patch)