def _log_pretraining_metadata(estimator, *args, **kwargs):  # pylint: disable=unused-argument
    """
    Log an estimator's parameters and descriptive tags to the active MLflow run
    before the estimator is trained.

    Meant to be called from inside a patched scikit-learn training routine
    (e.g., `fit()`, `fit_transform()`, ...); an active MLflow run reachable through
    the fluent Tracking API is assumed to exist.

    :param estimator: The scikit-learn estimator whose metadata is logged.
    :param args: Positional arguments given to the training routine (unused here).
    :param kwargs: Keyword arguments given to the training routine (unused here).
    """
    # Nested ("deep") parameters are worth recording for most meta estimators
    # such as pipelines. Parameter search estimators are the exception: their
    # child estimators only seed the search, so the children's initial, untuned
    # parameters are left out.
    if _is_parameter_search_estimator(estimator):
        include_nested_params = False
    else:
        include_nested_params = True

    all_params = estimator.get_params(deep=include_nested_params)

    # Log in batches so that a single call never exceeds the log_batch API limit;
    # over-long keys and values are truncated before being sent.
    for param_batch in _chunk_dict(all_params, chunk_size=MAX_PARAMS_TAGS_PER_BATCH):
        try_mlflow_log(
            mlflow.log_params,
            _truncate_dict(param_batch, MAX_ENTITY_KEY_LENGTH, MAX_PARAM_VAL_LENGTH),
        )

    estimator_tags = _get_estimator_info_tags(estimator)
    try_mlflow_log(mlflow.set_tags, estimator_tags)
def fit_mlflow(self, func_name, *args, **kwargs):
    """
    Run an estimator's original training routine while logging parameters, tags,
    the training score and the fitted model to MLflow.

    If no MLflow run is active when this is called, a run is started here and is
    ended here as well (with a FAILED status if the training routine raises).
    A run that was already active is left open for the caller to manage.

    :param self: The scikit-learn estimator being trained.
    :param func_name: Name of the patched attribute on the estimator; its original
                      implementation is looked up with
                      `gorilla.get_original_attribute` and invoked.
    :param args: Positional arguments forwarded to the original training routine.
    :param kwargs: Keyword arguments forwarded to the original training routine.
    :return: Whatever the original training routine returns.
    :raises Exception: Any exception raised by the original training routine is
                       re-raised unchanged.
    """
    should_start_run = mlflow.active_run() is None
    if should_start_run:
        try_mlflow_log(mlflow.start_run)

    # Parameters and estimator tags are logged by the shared pre-training helper.
    # It chunks and truncates parameters to respect the log_batch API limits and
    # skips nested (seed) estimator parameters for parameter search estimators.
    _log_pretraining_metadata(self, *args, **kwargs)

    original_fit = gorilla.get_original_attribute(self, func_name)
    try:
        fit_output = original_fit(*args, **kwargs)
    except Exception:
        # Only close the run if this call opened it; then propagate the original
        # error with its traceback intact.
        if should_start_run:
            try_mlflow_log(mlflow.end_run, RunStatus.to_string(RunStatus.FAILED))
        raise

    if hasattr(self, "score"):
        try:
            score_args = _get_args_for_score(self.score, self.fit, args, kwargs)
            training_score = self.score(*score_args)
        except Exception as e:  # pylint: disable=broad-except
            # Scoring is best-effort: a failure must not break the user's training
            # call, so it is reported as a warning and the metric is skipped.
            msg = (
                self.score.__qualname__
                + " failed. The 'training_score' metric will not be recorded. Scoring error: "
                + str(e)
            )
            _logger.warning(msg)
        else:
            try_mlflow_log(mlflow.log_metric, "training_score", training_score)

    try_mlflow_log(log_model, self, artifact_path="model")

    if should_start_run:
        try_mlflow_log(mlflow.end_run)

    return fit_output