def test_get_autologging_config_returns_configured_values_or_defaults_as_expected():
    """Verify that `get_autologging_config` reflects the most recent `autolog()` call's
    arguments for an integration, and falls back to the supplied default (or None)
    for unconfigured integrations and unknown keys.
    """

    def cfg(key, default=None):
        # Shorthand for querying the test integration's recorded configuration
        return get_autologging_config("test_integration_for_config", key, default)

    assert get_autologging_config("nonexistent_integration", "foo") is None

    @autologging_integration("test_integration_for_config")
    def autolog(foo="bar", t=7, disable=False):
        pass

    # Until `autolog()` has been invoked, no configuration values are recorded
    assert cfg("foo") is None
    assert cfg("disable") is None
    assert cfg("t", 10) == 10

    # Invoking `autolog()` records the effective argument values (defaults included)
    autolog()
    assert cfg("foo") == "bar"
    assert cfg("disable") is False
    assert cfg("t", 10) == 7
    assert cfg("nonexistent") is None

    # A subsequent invocation overwrites previously recorded values
    autolog(foo="baz")
    assert cfg("foo") == "baz"
def test_autologging_integrations_expose_configs_and_support_disablement():
    """Verify that every autologging integration records its `disable` configuration and
    that `autologging_is_disabled` agrees with it.

    Fix: the function previously declared an `integration` parameter that was immediately
    shadowed by the `for integration in ...` loop variable; under pytest, that parameter
    would be resolved as a fixture named `integration` (which does not exist) and fail
    collection/execution. The unused parameter is removed.
    """
    for integration in AUTOLOGGING_INTEGRATIONS_TO_TEST:
        # Enabled: not disabled, and the recorded `disable` config is falsy even when
        # the lookup default is True
        integration.autolog(disable=False)
        assert not autologging_is_disabled(integration.FLAVOR_NAME)
        assert not get_autologging_config(integration.FLAVOR_NAME, "disable", True)

        # Disabled: the inverse holds even when the lookup default is False
        integration.autolog(disable=True)
        assert autologging_is_disabled(integration.FLAVOR_NAME)
        assert get_autologging_config(integration.FLAVOR_NAME, "disable", False)
def test_autolog_obeys_disabled():
    """Verify that `disable` and `disable_for_unsupported_versions` set via the fluent
    `mlflow.autolog()` API (applied on module-load notification) and via the
    flavor-specific `mlflow.sklearn.autolog()` API are reflected in the sklearn
    autologging config, with the most recent call winning.
    """
    from mlflow.utils.autologging_utils import AUTOLOGGING_INTEGRATIONS

    # Fluent disable takes effect once the sklearn module-load hook fires
    mlflow.autolog(disable=True)
    mlflow.utils.import_hooks.notify_module_loaded(sklearn)
    assert get_autologging_config("sklearn", "disable")

    # Re-enable then disable again; the latest fluent setting must win
    mlflow.autolog()
    mlflow.utils.import_hooks.notify_module_loaded(sklearn)
    mlflow.autolog(disable=True)
    mlflow.utils.import_hooks.notify_module_loaded(sklearn)
    assert get_autologging_config("sklearn", "disable")

    mlflow.autolog(disable=False)
    mlflow.utils.import_hooks.notify_module_loaded(sklearn)
    assert not get_autologging_config("sklearn", "disable")

    # The flavor-specific API overrides the fluent setting (no load hook needed)
    mlflow.sklearn.autolog(disable=True)
    assert get_autologging_config("sklearn", "disable")

    # Clear all recorded configs before exercising disable_for_unsupported_versions
    AUTOLOGGING_INTEGRATIONS.clear()
    mlflow.autolog(disable_for_unsupported_versions=False)
    mlflow.utils.import_hooks.notify_module_loaded(sklearn)
    assert not get_autologging_config("sklearn", "disable_for_unsupported_versions")
    mlflow.autolog(disable_for_unsupported_versions=True)
    mlflow.utils.import_hooks.notify_module_loaded(sklearn)
    assert get_autologging_config("sklearn", "disable_for_unsupported_versions")

    # Flavor-specific API again overrides the fluent value, in both directions
    mlflow.sklearn.autolog(disable_for_unsupported_versions=False)
    assert not get_autologging_config("sklearn", "disable_for_unsupported_versions")
    mlflow.sklearn.autolog(disable_for_unsupported_versions=True)
    assert get_autologging_config("sklearn", "disable_for_unsupported_versions")
def wrapper_fit(original, self, *args, **kwargs):
    """Patched `fit` that autologs call parameters, the fitted model, and common metrics.

    Only the outermost `fit` call in a chain of nested fits performs autologging; the
    shared `AutologHelpers.should_autolog` flag is claimed on entry and restored on exit
    so that internal re-entrant `fit` calls do not double-log.
    """
    # Claim the shared flag: only the call that finds it set does any logging
    is_outermost_fit_call = AutologHelpers.should_autolog
    if is_outermost_fit_call:
        AutologHelpers.should_autolog = False

    try:
        if is_outermost_fit_call:
            # May emit warnings when a param name collides with one already logged
            log_fn_args_as_params(original, args, kwargs)

        # Run the actual training
        model = original(self, *args, **kwargs)

        if is_outermost_fit_call:
            # Log the fitted model unless model logging has been turned off
            if get_autologging_config(FLAVOR_NAME, "log_models", True):
                try_mlflow_log(log_model, model, artifact_path="model")

            # Log the most common metrics for results-wrapper outputs
            if isinstance(model, statsmodels.base.wrapper.ResultsWrapper):
                try_mlflow_log(mlflow.log_metrics, results_to_dict(model))

        return model
    finally:
        # Restore the shared flag so future top-level `fit` calls autolog again
        if is_outermost_fit_call:
            AutologHelpers.should_autolog = True
def test_universal_autolog_calls_specific_autologs_correctly(library, mlflow_module):
    """`mlflow.autolog` should forward every supported argument to the flavor-specific
    autolog, verified via the integration's recorded configuration.
    """
    extended_config_integrations = [xgboost, lightgbm, sklearn]
    expected_config = {
        "log_models": False,
        "disable": True,
        "exclusive": True,
        "disable_for_unsupported_versions": True,
        "silent": True,
    }
    if library in extended_config_integrations:
        # These flavors additionally accept input-example / signature flags
        expected_config["log_input_examples"] = True
        expected_config["log_model_signatures"] = True

    mlflow.autolog(**expected_config)

    is_modern_tensorflow = mlflow_module == mlflow.tensorflow and Version(
        tensorflow.__version__
    ) >= Version("2.6.0")
    if is_modern_tensorflow:
        # NB: In TensorFlow >= 2.6.0, TensorFlow unconditionally imports Keras. Fluent
        # autologging enablement logic relies on this import behavior.
        mlflow.utils.import_hooks.notify_module_loaded(keras)
        mlflow.utils.import_hooks.notify_module_loaded(tensorflow)
    else:
        mlflow.utils.import_hooks.notify_module_loaded(library)

    integration_name = mlflow_module.autolog.integration_name
    for key, expected_value in expected_config.items():
        assert get_autologging_config(integration_name, key, None) == expected_value
def on_train_end(self, logs=None):  # pylint: disable=unused-argument
    """At the end of training, log the trained Keras model to MLflow (honoring the
    integration's `log_models` setting and any configured registered model name)."""
    if not self.log_models:
        return
    mlflow.keras.log_model(
        self.model,
        artifact_path="model",
        registered_model_name=get_autologging_config(
            mlflow.tensorflow.FLAVOR_NAME, "registered_model_name", None
        ),
    )
def train_end(self, estimator, *args, **kwargs):
    """After training completes, log the estimator's Gluon network to MLflow — but only
    when model logging is enabled and the network is a `HybridSequential`."""
    if not (self.log_models and isinstance(estimator.net, HybridSequential)):
        return
    mlflow.gluon.log_model(
        estimator.net,
        artifact_path="model",
        registered_model_name=get_autologging_config(
            mlflow.gluon.FLAVOR_NAME, "registered_model_name", None
        ),
    )
def wrapper_fit(original, self, *args, **kwargs):
    """Patched `fit` that autologs call params, the fitted model, common metrics, and the
    model summary text.

    Only the outermost `fit` call autologs: the shared `AutologHelpers.should_autolog`
    flag is claimed on entry and restored in the `finally` block so nested fits are
    skipped.
    """
    should_autolog = False
    if AutologHelpers.should_autolog:
        # Claim the shared flag so nested/re-entrant fit calls do not double-log
        AutologHelpers.should_autolog = False
        should_autolog = True

    try:
        if should_autolog:
            # This may generate warnings due to collisions in already-logged param names
            log_fn_args_as_params(original, args, kwargs)

        # training model
        model = original(self, *args, **kwargs)

        if should_autolog:
            # Log the model
            if get_autologging_config(FLAVOR_NAME, "log_models", True):
                # Signal (via module-level flag) that the upcoming save originates from
                # autologging; always reset it, even if log_model raises
                global _save_model_called_from_autolog
                _save_model_called_from_autolog = True
                registered_model_name = get_autologging_config(
                    FLAVOR_NAME, "registered_model_name", None
                )
                try:
                    log_model(
                        model,
                        artifact_path="model",
                        registered_model_name=registered_model_name,
                    )
                finally:
                    _save_model_called_from_autolog = False

            # Log the most common metrics
            if isinstance(model, statsmodels.base.wrapper.ResultsWrapper):
                metrics_dict = _get_autolog_metrics(model)
                mlflow.log_metrics(metrics_dict)

                # Also capture the textual model summary as an artifact
                model_summary = model.summary().as_text()
                mlflow.log_text(model_summary, "model_summary.txt")

        return model
    finally:
        # Clean the shared flag for future calls in case it had been set here
        if should_autolog:
            AutologHelpers.should_autolog = True
def on_train_end(self, logs=None):
    """At the end of training, log the trained model to MLflow when model logging is
    enabled (note: `log_models` and `FLAVOR_NAME` come from the enclosing scope)."""
    if not log_models:
        return
    log_model(
        self.model,
        artifact_path="model",
        registered_model_name=get_autologging_config(
            FLAVOR_NAME, "registered_model_name", None
        ),
    )
def patched_fit(original, self, *args, **kwargs):
    """Patched `paddle.Model.fit` that injects an MLflow callback for metric logging,
    records early-stopping params/metrics, and logs the model summary and (optionally)
    the trained model.
    """
    run_id = mlflow.active_run().info.run_id
    tracking_uri = mlflow.get_tracking_uri()
    # Queueing client batches logging calls; BatchMetricsLogger handles per-epoch metrics
    client = MlflowAutologgingQueueingClient(tracking_uri)
    metrics_logger = BatchMetricsLogger(run_id, tracking_uri)

    log_models = get_autologging_config(mlflow.paddle.FLAVOR_NAME, "log_models", True)
    log_every_n_epoch = get_autologging_config(mlflow.paddle.FLAVOR_NAME, "log_every_n_epoch", 1)

    early_stop_callback = None
    mlflow_callback = __MLflowPaddleCallback(
        client, metrics_logger, run_id, log_models, log_every_n_epoch
    )
    if "callbacks" in kwargs:
        callbacks = kwargs["callbacks"]
        # If the user supplied an EarlyStopping callback, log its params up front
        # (only the first one found is considered)
        for callback in callbacks:
            if isinstance(callback, paddle.callbacks.EarlyStopping):
                early_stop_callback = callback
                _log_early_stop_params(early_stop_callback, client, run_id)
                break
        kwargs["callbacks"].append(mlflow_callback)
    else:
        kwargs["callbacks"] = [mlflow_callback]
    # Flush queued param logging asynchronously before training starts
    client.flush(synchronous=False)

    result = original(self, *args, **kwargs)

    if early_stop_callback is not None:
        _log_early_stop_metrics(early_stop_callback, client, run_id)

    # Capture the paddle model summary as a text artifact
    mlflow.log_text(str(self.summary()), "model_summary.txt")

    if log_models:
        mlflow.paddle.log_model(pd_model=self, artifact_path="model")

    # Final synchronous flush so all queued logging completes before returning
    client.flush(synchronous=True)

    return result
def test_universal_autolog_calls_specific_autologs_correctly(library, mlflow_module):
    """`mlflow.autolog` should propagate its arguments to the flavor-specific autolog,
    verified via the flavor's recorded configuration."""
    extended_config_integrations = [xgboost, lightgbm, sklearn]
    expected_config = {"log_models": False, "disable": True, "exclusive": True}
    if library in extended_config_integrations:
        # These flavors additionally accept input-example / signature flags
        expected_config["log_input_examples"] = True
        expected_config["log_model_signatures"] = True

    mlflow.autolog(**expected_config)
    mlflow.utils.import_hooks.notify_module_loaded(library)

    for key, expected_value in expected_config.items():
        assert get_autologging_config(mlflow_module.FLAVOR_NAME, key, None) == expected_value
def after_fit(self):
    """Runs after `fit`: triggers deferred SaveModelCallback work, then logs the learner
    to MLflow when model logging is enabled. Skipped entirely for prediction and
    LR-finder runs."""
    from fastai.callback.all import SaveModelCallback

    # Do not log model in case of predicting
    if hasattr(self, "lr_finder") or hasattr(self, "gather_preds"):
        return

    # Workaround to log model from SaveModelCallback
    # Use this till able to set order between SaveModelCallback and EarlyStoppingCallback
    for callback in self.cbs:
        if isinstance(callback, SaveModelCallback):
            callback("after_fit")

    if not self.log_models:
        return
    log_model(
        self.learn,
        artifact_path="model",
        registered_model_name=get_autologging_config(
            mlflow.fastai.FLAVOR_NAME, "registered_model_name", None
        ),
    )
def safe_patch_function(*args, **kwargs):
    """
    A safe wrapper around the specified `patch_function` implementation designed to handle
    exceptions thrown during the execution of `patch_function`. This wrapper distinguishes
    exceptions thrown from the underlying / original function
    (`<destination>.<function_name>`) from exceptions thrown from other parts of
    `patch_function`. This distinction is made by passing an augmented version of the
    underlying / original function to `patch_function` that uses nonlocal state to track
    whether or not it has been executed and whether or not it threw an exception.
    Exceptions thrown from the underlying / original function are propagated to the
    caller, while exceptions thrown from other parts of `patch_function` are caught and
    logged as warnings.
    """
    # Reroute warnings encountered during the patch function implementation to an MLflow
    # event logger, and enforce silent mode if applicable (i.e. if the corresponding
    # autologging integration was called with `silent=True`), hiding MLflow event logging
    # statements and hiding all warnings in the autologging preamble and postamble (i.e.
    # the code surrounding the user's original / underlying ML function). Non-MLflow
    # warnings are enabled during the execution of the original / underlying ML function.
    #
    # Note that we've opted *not* to apply this context manager as a decorator on
    # `safe_patch_function` because the context-manager-as-decorator pattern uses
    # `contextlib.ContextDecorator`, which creates generator expressions that cannot be
    # pickled during model serialization by ML frameworks such as scikit-learn
    is_silent_mode = get_autologging_config(autologging_integration, "silent", False)
    with set_mlflow_events_and_warnings_behavior_globally(
        # MLflow warnings emitted during autologging training sessions are likely not
        # actionable and result from the autologging implementation invoking another
        # MLflow API. Accordingly, we reroute these warnings to the MLflow event logger
        # with level WARNING. For reference, see recommended warning and event logging
        # behaviors from https://docs.python.org/3/howto/logging.html#when-to-use-logging
        reroute_warnings=True,
        disable_event_logs=is_silent_mode,
        disable_warnings=is_silent_mode,
    ), set_non_mlflow_warnings_behavior_for_current_thread(
        # non-MLflow warnings emitted during the autologging preamble (before the
        # original / underlying ML function is called) and postamble (after the original /
        # underlying ML function is called) are likely not actionable and result from the
        # autologging implementation invoking an API from a dependent library.
        # Accordingly, we reroute these warnings to the MLflow event logger with level
        # WARNING. For reference, see recommended warning and event logging behaviors from
        # https://docs.python.org/3/howto/logging.html#when-to-use-logging
        reroute_warnings=True,
        disable_warnings=is_silent_mode,
    ):

        if is_testing():
            # Capture the pre-call active run so leak detection (below) can tell whether
            # the patch created a run that it failed to terminate
            preexisting_run_for_testing = mlflow.active_run()

        # Whether or not to exclude autologged content from user-created fluent runs
        # (i.e. runs created manually via `mlflow.start_run()`)
        exclusive = get_autologging_config(autologging_integration, "exclusive", False)
        user_created_fluent_run_is_active = (
            mlflow.active_run() and not _AutologgingSessionManager.active_session()
        )
        active_session_failed = (
            _AutologgingSessionManager.active_session() is not None
            and _AutologgingSessionManager.active_session().state == "failed"
        )

        if (
            active_session_failed
            or autologging_is_disabled(autologging_integration)
            or (user_created_fluent_run_is_active and exclusive)
            or mlflow.utils.autologging_utils._AUTOLOGGING_GLOBALLY_DISABLED
        ):
            # If the autologging integration associated with this patch is disabled, or if
            # the current autologging integration is in exclusive mode and a user-created
            # fluent run is active, call the original function and return. Restore the
            # original warning behavior during original function execution, since
            # autologging is being skipped
            with set_non_mlflow_warnings_behavior_for_current_thread(
                disable_warnings=False,
                reroute_warnings=False,
            ):
                return original(*args, **kwargs)

        # Whether or not the original / underlying function has been called during the
        # execution of patched code
        original_has_been_called = False
        # The value returned by the call to the original / underlying function during
        # the execution of patched code
        original_result = None
        # Whether or not an exception was raised from within the original / underlying
        # function during the execution of patched code
        failed_during_original = False
        # The active MLflow run (if any) associated with patch code execution
        patch_function_run_for_testing = None

        def try_log_autologging_event(log_fn, *args):
            # Best-effort event logging: failures are only recorded at debug level so
            # they never interfere with patching or the user's training code
            try:
                log_fn(*args)
            except Exception as e:
                _logger.debug(
                    "Failed to log autologging event via '%s'. Exception: %s",
                    log_fn,
                    e,
                )

        with _AutologgingSessionManager.start_session(autologging_integration) as session:
            try:

                def call_original(*og_args, **og_kwargs):
                    try:
                        try_log_autologging_event(
                            AutologgingEventLogger.get_logger().log_original_function_start,
                            session,
                            destination,
                            function_name,
                            og_args,
                            og_kwargs,
                        )

                        if is_testing():
                            _validate_args(args, kwargs, og_args, og_kwargs)
                            # By the time `original` is called by the patch
                            # implementation, we assume that either: 1. the patch
                            # implementation has already created an MLflow run or
                            # 2. the patch code will not create an MLflow run during the
                            # current execution. Here, we capture a reference to the
                            # active run, which we will use later on to determine whether
                            # or not the patch implementation created a run and perform
                            # validation if necessary
                            nonlocal patch_function_run_for_testing
                            patch_function_run_for_testing = mlflow.active_run()

                        nonlocal original_has_been_called
                        original_has_been_called = True

                        nonlocal original_result
                        # Show all non-MLflow warnings as normal (i.e. not as event logs)
                        # during original function execution, even if silent mode is
                        # enabled (`silent=True`), since these warnings originate from the
                        # ML framework or one of its dependencies and are likely relevant
                        # to the caller
                        with set_non_mlflow_warnings_behavior_for_current_thread(
                            disable_warnings=False,
                            reroute_warnings=False,
                        ):
                            original_result = original(*og_args, **og_kwargs)

                        try_log_autologging_event(
                            AutologgingEventLogger.get_logger().log_original_function_success,
                            session,
                            destination,
                            function_name,
                            og_args,
                            og_kwargs,
                        )

                        return original_result
                    except Exception as e:
                        try_log_autologging_event(
                            AutologgingEventLogger.get_logger().log_original_function_error,
                            session,
                            destination,
                            function_name,
                            og_args,
                            og_kwargs,
                            e,
                        )

                        # Mark that the failure came from the original function so the
                        # outer handler re-raises instead of swallowing it
                        nonlocal failed_during_original
                        failed_during_original = True
                        raise

                # Apply the name, docstring, and signature of `original` to
                # `call_original`. This is important because several autologging patch
                # implementations inspect the signature of the `original` argument during
                # execution
                call_original = update_wrapper_extended(call_original, original)

                try_log_autologging_event(
                    AutologgingEventLogger.get_logger().log_patch_function_start,
                    session,
                    destination,
                    function_name,
                    args,
                    kwargs,
                )

                if patch_is_class:
                    patch_function.call(call_original, *args, **kwargs)
                else:
                    patch_function(call_original, *args, **kwargs)

                session.state = "succeeded"

                try_log_autologging_event(
                    AutologgingEventLogger.get_logger().log_patch_function_success,
                    session,
                    destination,
                    function_name,
                    args,
                    kwargs,
                )
            except Exception as e:
                session.state = "failed"

                # Exceptions thrown during execution of the original function should be
                # propagated to the caller. Additionally, exceptions encountered during
                # test mode should be reraised to detect bugs in autologging
                # implementations
                if failed_during_original or is_testing():
                    raise

                try_log_autologging_event(
                    AutologgingEventLogger.get_logger().log_patch_function_error,
                    session,
                    destination,
                    function_name,
                    args,
                    kwargs,
                    e,
                )

                _logger.warning(
                    "Encountered unexpected error during %s autologging: %s",
                    autologging_integration,
                    e,
                )

            if is_testing() and not preexisting_run_for_testing:
                # If an MLflow run was created during the execution of patch code, verify
                # that it is no longer active and that it contains expected autologging
                # tags
                assert not mlflow.active_run(), (
                    "Autologging integration %s leaked an active run" % autologging_integration
                )
                if patch_function_run_for_testing:
                    _validate_autologging_run(
                        autologging_integration, patch_function_run_for_testing.info.run_id
                    )

            # If the patch never invoked the original (e.g. it errored before doing so),
            # fall back to calling the original directly so user code still runs
            if original_has_been_called:
                return original_result
            else:
                return original(*args, **kwargs)
def patched_fit(original, self, *args, **kwargs):
    """
    A patched implementation of `pytorch_lightning.Trainer.fit` which enables logging the
    following parameters, metrics and artifacts:

    - Training epochs
    - Optimizer parameters
    - `EarlyStoppingCallback`_ parameters
    - Metrics stored in `trainer.callback_metrics`
    - Model checkpoints
    - Trained model

    .. _EarlyStoppingCallback:
        https://pytorch-lightning.readthedocs.io/en/latest/early_stopping.html
    """
    run_id = mlflow.active_run().info.run_id
    tracking_uri = mlflow.get_tracking_uri()
    # Queueing client batches logging calls; BatchMetricsLogger handles metric batching
    client = MlflowAutologgingQueueingClient(tracking_uri)
    metrics_logger = BatchMetricsLogger(run_id, tracking_uri)

    log_models = get_autologging_config(mlflow.pytorch.FLAVOR_NAME, "log_models", True)
    log_every_n_epoch = get_autologging_config(mlflow.pytorch.FLAVOR_NAME, "log_every_n_epoch", 1)
    log_every_n_step = get_autologging_config(mlflow.pytorch.FLAVOR_NAME, "log_every_n_step", None)

    # If the trainer was configured with an EarlyStopping callback, log its params now
    # (note: unlike some other integrations, every matching callback is visited; the last
    # one found wins)
    early_stop_callback = None
    for callback in self.callbacks:
        if isinstance(callback, pl.callbacks.early_stopping.EarlyStopping):
            early_stop_callback = callback
            _log_early_stop_params(early_stop_callback, client, run_id)

    # Inject the MLflow callback exactly once
    if not any(isinstance(callbacks, __MLflowPLCallback) for callbacks in self.callbacks):
        self.callbacks += [
            __MLflowPLCallback(
                client, metrics_logger, run_id, log_models, log_every_n_epoch, log_every_n_step
            )
        ]

    # Flush queued param logging asynchronously before training starts
    client.flush(synchronous=False)

    result = original(self, *args, **kwargs)

    if early_stop_callback is not None:
        _log_early_stop_metrics(early_stop_callback, client, run_id)

    # ModelSummary's API changed in pytorch-lightning 1.4.0 (`mode` -> `max_depth`)
    if Version(pl.__version__) < Version("1.4.0"):
        summary = str(ModelSummary(self.model, mode="full"))
    else:
        summary = str(ModelSummary(self.model, max_depth=-1))

    # Write the summary to a temp file so it can be logged as an artifact; the temp
    # directory is always removed afterwards
    tempdir = tempfile.mkdtemp()
    try:
        summary_file = os.path.join(tempdir, "model_summary.txt")
        with open(summary_file, "w") as f:
            f.write(summary)
        mlflow.log_artifact(local_path=summary_file)
    finally:
        shutil.rmtree(tempdir)

    if log_models:
        registered_model_name = get_autologging_config(
            mlflow.pytorch.FLAVOR_NAME, "registered_model_name", None
        )
        mlflow.pytorch.log_model(
            pytorch_model=self.model,
            artifact_path="model",
            registered_model_name=registered_model_name,
        )

        # When early stopping restored a checkpoint, preserve it as an artifact too
        if early_stop_callback is not None and self.checkpoint_callback.best_model_path:
            mlflow.log_artifact(
                local_path=self.checkpoint_callback.best_model_path,
                artifact_path="restored_model_checkpoint",
            )

    # Final synchronous flush so all queued logging completes before returning
    client.flush(synchronous=True)

    return result
def train(_log_models, original, *args, **kwargs):
    """Patched `xgboost.train` that autologs params, per-iteration evaluation metrics,
    early-stopping metrics, feature importances (plot + JSON), and optionally the
    trained model (with input example / signature when configured).

    :param _log_models: Whether to log the trained model (bound by the autolog setup).
    :param original: The unpatched `xgboost.train` function.
    """

    def record_eval_results(eval_results, metrics_logger):
        """
        Create a callback function that records evaluation results.
        """
        # TODO: Remove `replace("SNAPSHOT", "dev")` once the following issue is addressed:
        # https://github.com/dmlc/xgboost/issues/6984
        from mlflow.xgboost._autolog import IS_TRAINING_CALLBACK_SUPPORTED

        if IS_TRAINING_CALLBACK_SUPPORTED:
            from mlflow.xgboost._autolog import AutologCallback

            # In xgboost >= 1.3.0, user-defined callbacks should inherit
            # `xgboost.callback.TrainingCallback`:
            # https://xgboost.readthedocs.io/en/latest/python/callbacks.html#defining-your-own-callback
            return AutologCallback(metrics_logger, eval_results)
        else:
            from mlflow.xgboost._autolog import autolog_callback

            # Older xgboost versions take a plain function callback; wrap it so exceptions
            # in the callback cannot break training and so it stays picklable
            return picklable_exception_safe_function(
                functools.partial(
                    autolog_callback, metrics_logger=metrics_logger, eval_results=eval_results
                )
            )

    def log_feature_importance_plot(features, importance, importance_type):
        """
        Log feature importance plot.

        NOTE(review): the artifact filename below intentionally uses `imp_type` from the
        enclosing loop (see the pylint disable), not this function's `importance_type`
        parameter — the two are the same value at every call site in this function.
        """
        import matplotlib.pyplot as plt
        from cycler import cycler

        features = np.array(features)

        # Structure the supplied `importance` values as a `num_features`-by-`num_classes`
        # matrix
        importances_per_class_by_feature = np.array(importance)
        if importances_per_class_by_feature.ndim <= 1:
            # In this case, the supplied `importance` values are not given per class.
            # Rather, one importance value is given per feature. For consistency with the
            # assumed `num_features`-by-`num_classes` matrix structure, we coerce the
            # importance values to a `num_features`-by-1 matrix
            indices = np.argsort(importance)
            # Sort features and importance values by magnitude during transformation to a
            # `num_features`-by-`num_classes` matrix
            features = features[indices]
            importances_per_class_by_feature = np.array(
                [[importance] for importance in importances_per_class_by_feature[indices]]
            )
            # In this case, do not include class labels on the feature importance plot
            # because only one importance value has been provided per feature, rather than
            # one importance value for each class per feature
            label_classes_on_plot = False
        else:
            # Per-class importances: order features by total importance magnitude
            importance_value_magnitudes = np.abs(importances_per_class_by_feature).sum(axis=1)
            indices = np.argsort(importance_value_magnitudes)
            features = features[indices]
            importances_per_class_by_feature = importances_per_class_by_feature[indices]
            label_classes_on_plot = True

        num_classes = importances_per_class_by_feature.shape[1]
        num_features = len(features)

        # If num_features > 10, increase the figure height to prevent the plot
        # from being too dense.
        w, h = [6.4, 4.8]  # matplotlib's default figure size
        h = h + 0.1 * num_features if num_features > 10 else h
        h = h + 0.1 * num_classes if num_classes > 1 else h
        fig, ax = plt.subplots(figsize=(w, h))
        # When importance values are provided for each class per feature, we want to
        # ensure that the same color is used for all bars in the bar chart that have the
        # same class
        colors_to_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"][:num_classes]
        color_cycler = cycler(color=colors_to_cycle)
        ax.set_prop_cycle(color_cycler)

        # The following logic operates on one feature at a time, adding a bar to the bar
        # chart for each class that reflects the importance of the feature to predictions
        # of that class
        feature_ylocs = np.arange(num_features)
        # Define offsets on the y-axis that are used to evenly space the bars for each
        # class around the y-axis position of each feature
        offsets_per_yloc = np.linspace(-0.5, 0.5, num_classes) / 2 if num_classes > 1 else [0]
        for feature_idx, (feature_yloc, importances_per_class) in enumerate(
            zip(feature_ylocs, importances_per_class_by_feature)
        ):
            for class_idx, (offset, class_importance) in enumerate(
                zip(offsets_per_yloc, importances_per_class)
            ):
                (bar,) = ax.barh(
                    feature_yloc + offset,
                    class_importance,
                    align="center",
                    # Set the bar height such that importance value bars for a particular
                    # feature are spaced properly relative to each other (no overlap or
                    # gaps) and relative to importance value bars for other features
                    height=(0.5 / max(num_classes - 1, 1)),
                )
                if label_classes_on_plot and feature_idx == 0:
                    # Only set a label the first time a bar for a particular class is
                    # plotted to avoid duplicate legend entries. If we were to set a label
                    # for every bar, the legend would contain `num_features` labels for
                    # each class.
                    bar.set_label("Class {}".format(class_idx))

        ax.set_yticks(feature_ylocs)
        ax.set_yticklabels(features)
        ax.set_xlabel("Importance")
        ax.set_title("Feature Importance ({})".format(importance_type))
        if label_classes_on_plot:
            ax.legend()
        fig.tight_layout()

        # Save the figure to a temp dir, log it, then always close the figure and remove
        # the temp dir
        tmpdir = tempfile.mkdtemp()
        try:
            # pylint: disable=undefined-loop-variable
            filepath = os.path.join(tmpdir, "feature_importance_{}.png".format(imp_type))
            fig.savefig(filepath)
            mlflow.log_artifact(filepath)
        finally:
            plt.close(fig)
            shutil.rmtree(tmpdir)

    autologging_client = MlflowAutologgingQueueingClient()
    # logging booster params separately to extract key/value pairs and make it easier to
    # compare them across runs.
    booster_params = args[0] if len(args) > 0 else kwargs["params"]
    autologging_client.log_params(run_id=mlflow.active_run().info.run_id, params=booster_params)

    # Arguments that are unsuitable to log as run params (objects, callables, params dict
    # handled above)
    unlogged_params = [
        "params",
        "dtrain",
        "evals",
        "obj",
        "feval",
        "evals_result",
        "xgb_model",
        "callbacks",
        "learning_rates",
    ]
    params_to_log_for_fn = get_mlflow_run_params_for_fn_args(
        original, args, kwargs, unlogged_params
    )
    autologging_client.log_params(
        run_id=mlflow.active_run().info.run_id, params=params_to_log_for_fn
    )

    # Kick off param logging asynchronously; awaited at the end of this function
    param_logging_operations = autologging_client.flush(synchronous=False)

    all_arg_names = _get_arg_names(original)
    num_pos_args = len(args)

    # adding a callback that records evaluation results.
    eval_results = []
    callbacks_index = all_arg_names.index("callbacks")

    run_id = mlflow.active_run().info.run_id
    with batch_metrics_logger(run_id) as metrics_logger:
        callback = record_eval_results(eval_results, metrics_logger)
        # Append our callback to whichever form (positional / keyword / absent) the
        # caller used to supply `callbacks`
        if num_pos_args >= callbacks_index + 1:
            tmp_list = list(args)
            tmp_list[callbacks_index] += [callback]
            args = tuple(tmp_list)
        elif "callbacks" in kwargs and kwargs["callbacks"] is not None:
            kwargs["callbacks"] += [callback]
        else:
            kwargs["callbacks"] = [callback]

        # training model
        model = original(*args, **kwargs)

        # If early_stopping_rounds is present, logging metrics at the best iteration
        # as extra metrics with the max step + 1.
        early_stopping_index = all_arg_names.index("early_stopping_rounds")
        early_stopping = num_pos_args >= early_stopping_index + 1 or kwargs.get(
            "early_stopping_rounds"
        )
        if early_stopping:
            extra_step = len(eval_results)
            autologging_client.log_metrics(
                run_id=mlflow.active_run().info.run_id,
                metrics={
                    "stopped_iteration": extra_step - 1,
                    "best_iteration": model.best_iteration,
                },
            )
            autologging_client.log_metrics(
                run_id=mlflow.active_run().info.run_id,
                metrics=eval_results[model.best_iteration],
                step=extra_step,
            )
            early_stopping_logging_operations = autologging_client.flush(synchronous=False)

    # logging feature importance as artifacts.
    for imp_type in importance_types:
        imp = None
        try:
            imp = model.get_score(importance_type=imp_type)
            features, importance = zip(*imp.items())
            log_feature_importance_plot(features, importance, imp_type)
        except Exception:
            # Best-effort: importance plotting failures must not break training
            _logger.exception(
                "Failed to log feature importance plot. XGBoost autologging "
                "will ignore the failure and continue. Exception: "
            )

        if imp is not None:
            # Also log the raw importance scores as JSON
            tmpdir = tempfile.mkdtemp()
            try:
                filepath = os.path.join(tmpdir, "feature_importance_{}.json".format(imp_type))
                with open(filepath, "w") as f:
                    json.dump(imp, f)
                mlflow.log_artifact(filepath)
            finally:
                shutil.rmtree(tmpdir)

    # dtrain must exist as the original train function already ran successfully
    dtrain = args[1] if len(args) > 1 else kwargs.get("dtrain")

    # it is possible that the dataset was constructed before the patched
    # constructor was applied, so we cannot assume the input_example_info exists
    input_example_info = getattr(dtrain, "input_example_info", None)

    def get_input_example():
        # Raise if input-example capture was unavailable or failed; callers treat the
        # exception as "no input example"
        if input_example_info is None:
            raise Exception(ENSURE_AUTOLOGGING_ENABLED_TEXT)
        if input_example_info.error_msg is not None:
            raise Exception(input_example_info.error_msg)
        return input_example_info.input_example

    def infer_model_signature(input_example):
        # Infer the model signature from a prediction on the captured input example
        model_output = model.predict(xgboost.DMatrix(input_example))
        model_signature = infer_signature(input_example, model_output)
        return model_signature

    # Only log the model if the autolog() param log_models is set to True.
    if _log_models:
        # Will only resolve `input_example` and `signature` if `log_models` is `True`.
        input_example, signature = resolve_input_example_and_signature(
            get_input_example,
            infer_model_signature,
            log_input_examples,
            log_model_signatures,
            _logger,
        )

        registered_model_name = get_autologging_config(FLAVOR_NAME, "registered_model_name", None)
        log_model(
            model,
            artifact_path="model",
            signature=signature,
            input_example=input_example,
            registered_model_name=registered_model_name,
        )

    # Wait for the asynchronous logging operations queued above to complete
    param_logging_operations.await_completion()
    if early_stopping:
        early_stopping_logging_operations.await_completion()

    return model