Example #1
def test_get_autologging_config_returns_configured_values_or_defaults_as_expected():

    assert get_autologging_config("nonexistent_integration", "foo") is None

    @autologging_integration("test_integration_for_config")
    def autolog(foo="bar", t=7, disable=False):
        pass

    # Before `autolog()` has been invoked, config values should not be available
    assert get_autologging_config("test_integration_for_config", "foo") is None
    assert get_autologging_config("test_integration_for_config",
                                  "disable") is None
    assert get_autologging_config("test_integration_for_config", "t", 10) == 10

    autolog()

    assert get_autologging_config("test_integration_for_config",
                                  "foo") == "bar"
    assert get_autologging_config("test_integration_for_config",
                                  "disable") is False
    assert get_autologging_config("test_integration_for_config", "t", 10) == 7
    assert get_autologging_config("test_integration_for_config",
                                  "nonexistent") is None

    autolog(foo="baz")

    assert get_autologging_config("test_integration_for_config",
                                  "foo") == "baz"
Example #2
def test_autologging_integrations_expose_configs_and_support_disablement(integration):
    integration.autolog(disable=False)

    assert not autologging_is_disabled(integration.FLAVOR_NAME)
    assert not get_autologging_config(integration.FLAVOR_NAME, "disable", True)

    integration.autolog(disable=True)

    assert autologging_is_disabled(integration.FLAVOR_NAME)
    assert get_autologging_config(integration.FLAVOR_NAME, "disable", False)
Example #3
def test_autolog_obeys_disabled():
    from mlflow.utils.autologging_utils import AUTOLOGGING_INTEGRATIONS

    mlflow.autolog(disable=True)
    mlflow.utils.import_hooks.notify_module_loaded(sklearn)
    assert get_autologging_config("sklearn", "disable")

    mlflow.autolog()
    mlflow.utils.import_hooks.notify_module_loaded(sklearn)
    mlflow.autolog(disable=True)
    mlflow.utils.import_hooks.notify_module_loaded(sklearn)
    assert get_autologging_config("sklearn", "disable")

    mlflow.autolog(disable=False)
    mlflow.utils.import_hooks.notify_module_loaded(sklearn)
    assert not get_autologging_config("sklearn", "disable")
    mlflow.sklearn.autolog(disable=True)
    assert get_autologging_config("sklearn", "disable")

    AUTOLOGGING_INTEGRATIONS.clear()
    mlflow.autolog(disable_for_unsupported_versions=False)
    mlflow.utils.import_hooks.notify_module_loaded(sklearn)
    assert not get_autologging_config("sklearn", "disable_for_unsupported_versions")
    mlflow.autolog(disable_for_unsupported_versions=True)
    mlflow.utils.import_hooks.notify_module_loaded(sklearn)
    assert get_autologging_config("sklearn", "disable_for_unsupported_versions")

    mlflow.sklearn.autolog(disable_for_unsupported_versions=False)
    assert not get_autologging_config("sklearn", "disable_for_unsupported_versions")
    mlflow.sklearn.autolog(disable_for_unsupported_versions=True)
    assert get_autologging_config("sklearn", "disable_for_unsupported_versions")
Example #4
    def wrapper_fit(original, self, *args, **kwargs):

        should_autolog = False
        if AutologHelpers.should_autolog:
            AutologHelpers.should_autolog = False
            should_autolog = True

        try:
            if should_autolog:
                # This may generate warnings due to collisions in already-logged param names
                log_fn_args_as_params(original, args, kwargs)

            # training model
            model = original(self, *args, **kwargs)

            if should_autolog:
                # Log the model
                if get_autologging_config(FLAVOR_NAME, "log_models", True):
                    try_mlflow_log(log_model, model, artifact_path="model")

                # Log the most common metrics
                if isinstance(model, statsmodels.base.wrapper.ResultsWrapper):
                    metrics_dict = results_to_dict(model)
                    try_mlflow_log(mlflow.log_metrics, metrics_dict)

            return model

        finally:
            # Clean the shared flag for future calls in case it had been set here ...
            if should_autolog:
                AutologHelpers.should_autolog = True
Example #5
def test_universal_autolog_calls_specific_autologs_correctly(
        library, mlflow_module):
    integrations_with_additional_config = [xgboost, lightgbm, sklearn]
    args_to_test = {
        "log_models": False,
        "disable": True,
        "exclusive": True,
        "disable_for_unsupported_versions": True,
        "silent": True,
    }
    if library in integrations_with_additional_config:
        args_to_test.update({
            "log_input_examples": True,
            "log_model_signatures": True
        })

    mlflow.autolog(**args_to_test)

    if mlflow_module == mlflow.tensorflow and Version(
            tensorflow.__version__) >= Version("2.6.0"):
        # NB: In TensorFlow >= 2.6.0, TensorFlow unconditionally imports Keras. Fluent
        # autologging enablement logic relies on this import behavior.
        mlflow.utils.import_hooks.notify_module_loaded(keras)
        mlflow.utils.import_hooks.notify_module_loaded(tensorflow)
    else:
        mlflow.utils.import_hooks.notify_module_loaded(library)

    for arg_key, arg_value in args_to_test.items():
        assert (get_autologging_config(mlflow_module.autolog.integration_name,
                                       arg_key, None) == arg_value)
Example #6
def on_train_end(self, logs=None):  # pylint: disable=unused-argument
    if self.log_models:
        registered_model_name = get_autologging_config(
            mlflow.tensorflow.FLAVOR_NAME, "registered_model_name", None)
        mlflow.keras.log_model(self.model,
                               artifact_path="model",
                               registered_model_name=registered_model_name)
Example #7
def train_end(self, estimator, *args, **kwargs):
    if isinstance(estimator.net, HybridSequential) and self.log_models:
        registered_model_name = get_autologging_config(
            mlflow.gluon.FLAVOR_NAME, "registered_model_name", None)
        mlflow.gluon.log_model(estimator.net,
                               artifact_path="model",
                               registered_model_name=registered_model_name)
Example #8
    def wrapper_fit(original, self, *args, **kwargs):

        should_autolog = False
        if AutologHelpers.should_autolog:
            AutologHelpers.should_autolog = False
            should_autolog = True

        try:
            if should_autolog:
                # This may generate warnings due to collisions in already-logged param names
                log_fn_args_as_params(original, args, kwargs)

            # training model
            model = original(self, *args, **kwargs)

            if should_autolog:
                # Log the model
                if get_autologging_config(FLAVOR_NAME, "log_models", True):
                    global _save_model_called_from_autolog
                    _save_model_called_from_autolog = True
                    registered_model_name = get_autologging_config(
                        FLAVOR_NAME, "registered_model_name", None
                    )
                    try:
                        log_model(
                            model,
                            artifact_path="model",
                            registered_model_name=registered_model_name,
                        )
                    finally:
                        _save_model_called_from_autolog = False

                # Log the most common metrics
                if isinstance(model, statsmodels.base.wrapper.ResultsWrapper):
                    metrics_dict = _get_autolog_metrics(model)
                    mlflow.log_metrics(metrics_dict)

                    model_summary = model.summary().as_text()
                    mlflow.log_text(model_summary, "model_summary.txt")

            return model

        finally:
            # Clean the shared flag for future calls in case it had been set here ...
            if should_autolog:
                AutologHelpers.should_autolog = True
Example #9
def on_train_end(self, logs=None):
    if log_models:
        registered_model_name = get_autologging_config(
            FLAVOR_NAME, "registered_model_name", None)
        log_model(
            self.model,
            artifact_path="model",
            registered_model_name=registered_model_name,
        )
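Examples 6 through 9 follow one pattern: consult the integration's autologging config for registered_model_name (and, directly or through the callback's log_models flag, whether to log at all), then call the flavor-specific log_model. Below is a generic sketch of that pattern, assuming only that get_autologging_config is importable from mlflow.utils.autologging_utils (as in Example 3); the maybe_log_model helper itself is hypothetical.

from mlflow.utils.autologging_utils import get_autologging_config


def maybe_log_model(flavor_name, log_model_fn, model):
    # log_model_fn is a flavor-specific function such as mlflow.keras.log_model
    # or mlflow.gluon.log_model in the snippets above
    if get_autologging_config(flavor_name, "log_models", True):
        registered_model_name = get_autologging_config(
            flavor_name, "registered_model_name", None)
        log_model_fn(model,
                     artifact_path="model",
                     registered_model_name=registered_model_name)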
Example #10
def patched_fit(original, self, *args, **kwargs):
    run_id = mlflow.active_run().info.run_id
    tracking_uri = mlflow.get_tracking_uri()
    client = MlflowAutologgingQueueingClient(tracking_uri)
    metrics_logger = BatchMetricsLogger(run_id, tracking_uri)

    log_models = get_autologging_config(mlflow.paddle.FLAVOR_NAME,
                                        "log_models", True)
    log_every_n_epoch = get_autologging_config(mlflow.paddle.FLAVOR_NAME,
                                               "log_every_n_epoch", 1)

    early_stop_callback = None
    mlflow_callback = __MLflowPaddleCallback(client, metrics_logger, run_id,
                                             log_models, log_every_n_epoch)
    if "callbacks" in kwargs:
        callbacks = kwargs["callbacks"]
        for callback in callbacks:
            if isinstance(callback, paddle.callbacks.EarlyStopping):
                early_stop_callback = callback
                _log_early_stop_params(early_stop_callback, client, run_id)
                break
        kwargs["callbacks"].append(mlflow_callback)
    else:
        kwargs["callbacks"] = [mlflow_callback]
    client.flush(synchronous=False)

    result = original(self, *args, **kwargs)

    if early_stop_callback is not None:
        _log_early_stop_metrics(early_stop_callback, client, run_id)

    mlflow.log_text(str(self.summary()), "model_summary.txt")

    if log_models:
        mlflow.paddle.log_model(pd_model=self, artifact_path="model")

    client.flush(synchronous=True)

    return result
Example #11
def test_universal_autolog_calls_specific_autologs_correctly(library, mlflow_module):
    integrations_with_additional_config = [xgboost, lightgbm, sklearn]
    args_to_test = {
        "log_models": False,
        "disable": True,
        "exclusive": True,
    }
    if library in integrations_with_additional_config:
        args_to_test.update({"log_input_examples": True, "log_model_signatures": True})

    mlflow.autolog(**args_to_test)
    mlflow.utils.import_hooks.notify_module_loaded(library)

    for arg_key, arg_value in args_to_test.items():
        assert get_autologging_config(mlflow_module.FLAVOR_NAME, arg_key, None) == arg_value
Example #12
    def after_fit(self):
        from fastai.callback.all import SaveModelCallback

        # Do not log model in case of predicting
        if hasattr(self, "lr_finder") or hasattr(self, "gather_preds"):
            return

        # Workaround to log model from SaveModelCallback
        # Use this till able to set order between SaveModelCallback and EarlyStoppingCallback
        for cb in self.cbs:
            if isinstance(cb, SaveModelCallback):
                cb("after_fit")

        if self.log_models:
            registered_model_name = get_autologging_config(
                mlflow.fastai.FLAVOR_NAME, "registered_model_name", None
            )
            log_model(
                self.learn, artifact_path="model", registered_model_name=registered_model_name
            )
Example #13
    def safe_patch_function(*args, **kwargs):
        """
        A safe wrapper around the specified `patch_function` implementation designed to
        handle exceptions thrown during the execution of `patch_function`. This wrapper
        distinguishes exceptions thrown from the underlying / original function
        (`<destination>.<function_name>`) from exceptions thrown from other parts of
        `patch_function`. This distinction is made by passing an augmented version of the
        underlying / original function to `patch_function` that uses nonlocal state to track
        whether or not it has been executed and whether or not it threw an exception.
        Exceptions thrown from the underlying / original function are propagated to the caller,
        while exceptions thrown from other parts of `patch_function` are caught and logged as
        warnings.
        """
        # Reroute warnings encountered during the patch function implementation to an MLflow event
        # logger, and enforce silent mode if applicable (i.e. if the corresponding autologging
        # integration was called with `silent=True`), hiding MLflow event logging statements and
        # hiding all warnings in the autologging preamble and postamble (i.e. the code surrounding
        # the user's original / underlying ML function). Non-MLflow warnings are enabled during the
        # execution of the original / underlying ML function
        #
        # Note that we've opted *not* to apply this context manager as a decorator on
        # `safe_patch_function` because the context-manager-as-decorator pattern uses
        # `contextlib.ContextDecorator`, which creates generator expressions that cannot be pickled
        # during model serialization by ML frameworks such as scikit-learn
        is_silent_mode = get_autologging_config(autologging_integration, "silent", False)
        with set_mlflow_events_and_warnings_behavior_globally(
            # MLflow warnings emitted during autologging training sessions are likely not
            # actionable and result from the autologging implementation invoking another MLflow
            # API. Accordingly, we reroute these warnings to the MLflow event logger with level
            # WARNING. For reference, see recommended warning and event logging behaviors from
            # https://docs.python.org/3/howto/logging.html#when-to-use-logging
            reroute_warnings=True,
            disable_event_logs=is_silent_mode,
            disable_warnings=is_silent_mode,
        ), set_non_mlflow_warnings_behavior_for_current_thread(
            # non-MLflow Warnings emitted during the autologging preamble (before the original /
            # underlying ML function is called) and postamble (after the original / underlying ML
            # function is called) are likely not actionable and result from the autologging
            # implementation invoking an API from a dependent library. Accordingly, we reroute
            # these warnings to the MLflow event logger with level WARNING. For reference, see
            # recommended warning and event logging behaviors from
            # https://docs.python.org/3/howto/logging.html#when-to-use-logging
            reroute_warnings=True,
            disable_warnings=is_silent_mode,
        ):

            if is_testing():
                preexisting_run_for_testing = mlflow.active_run()

            # Whether or not to exclude autologged content from user-created fluent runs
            # (i.e. runs created manually via `mlflow.start_run()`)
            exclusive = get_autologging_config(autologging_integration, "exclusive", False)
            user_created_fluent_run_is_active = (
                mlflow.active_run() and not _AutologgingSessionManager.active_session()
            )
            active_session_failed = (
                _AutologgingSessionManager.active_session() is not None
                and _AutologgingSessionManager.active_session().state == "failed"
            )

            if (
                active_session_failed
                or autologging_is_disabled(autologging_integration)
                or (user_created_fluent_run_is_active and exclusive)
                or mlflow.utils.autologging_utils._AUTOLOGGING_GLOBALLY_DISABLED
            ):
                # If the autologging integration associated with this patch is disabled,
                # or if the current autologging integration is in exclusive mode and a user-created
                # fluent run is active, call the original function and return. Restore the original
                # warning behavior during original function execution, since autologging is being
                # skipped
                with set_non_mlflow_warnings_behavior_for_current_thread(
                    disable_warnings=False, reroute_warnings=False,
                ):
                    return original(*args, **kwargs)

            # Whether or not the original / underlying function has been called during the
            # execution of patched code
            original_has_been_called = False
            # The value returned by the call to the original / underlying function during
            # the execution of patched code
            original_result = None
            # Whether or not an exception was raised from within the original / underlying function
            # during the execution of patched code
            failed_during_original = False
            # The active MLflow run (if any) associated with patch code execution
            patch_function_run_for_testing = None

            def try_log_autologging_event(log_fn, *args):
                try:
                    log_fn(*args)
                except Exception as e:
                    _logger.debug(
                        "Failed to log autologging event via '%s'. Exception: %s", log_fn, e,
                    )

            with _AutologgingSessionManager.start_session(autologging_integration) as session:
                try:

                    def call_original(*og_args, **og_kwargs):
                        try:
                            try_log_autologging_event(
                                AutologgingEventLogger.get_logger().log_original_function_start,
                                session,
                                destination,
                                function_name,
                                og_args,
                                og_kwargs,
                            )

                            if is_testing():
                                _validate_args(args, kwargs, og_args, og_kwargs)
                                # By the time `original` is called by the patch implementation, we
                                # assume that either: 1. the patch implementation has already
                                # created an MLflow run or 2. the patch code will not create an
                                # MLflow run during the current execution. Here, we capture a
                                # reference to the active run, which we will use later on to
                                # determine whether or not the patch implementation created
                                # a run and perform validation if necessary
                                nonlocal patch_function_run_for_testing
                                patch_function_run_for_testing = mlflow.active_run()

                            nonlocal original_has_been_called
                            original_has_been_called = True

                            nonlocal original_result
                            # Show all non-MLflow warnings as normal (i.e. not as event logs)
                            # during original function execution, even if silent mode is enabled
                            # (`silent=True`), since these warnings originate from the ML framework
                            # or one of its dependencies and are likely relevant to the caller
                            with set_non_mlflow_warnings_behavior_for_current_thread(
                                disable_warnings=False, reroute_warnings=False,
                            ):
                                original_result = original(*og_args, **og_kwargs)

                            try_log_autologging_event(
                                AutologgingEventLogger.get_logger().log_original_function_success,
                                session,
                                destination,
                                function_name,
                                og_args,
                                og_kwargs,
                            )

                            return original_result
                        except Exception as e:
                            try_log_autologging_event(
                                AutologgingEventLogger.get_logger().log_original_function_error,
                                session,
                                destination,
                                function_name,
                                og_args,
                                og_kwargs,
                                e,
                            )

                            nonlocal failed_during_original
                            failed_during_original = True
                            raise

                    # Apply the name, docstring, and signature of `original` to `call_original`.
                    # This is important because several autologging patch implementations inspect
                    # the signature of the `original` argument during execution
                    call_original = update_wrapper_extended(call_original, original)

                    try_log_autologging_event(
                        AutologgingEventLogger.get_logger().log_patch_function_start,
                        session,
                        destination,
                        function_name,
                        args,
                        kwargs,
                    )

                    if patch_is_class:
                        patch_function.call(call_original, *args, **kwargs)
                    else:
                        patch_function(call_original, *args, **kwargs)

                    session.state = "succeeded"

                    try_log_autologging_event(
                        AutologgingEventLogger.get_logger().log_patch_function_success,
                        session,
                        destination,
                        function_name,
                        args,
                        kwargs,
                    )
                except Exception as e:
                    session.state = "failed"

                    # Exceptions thrown during execution of the original function should be
                    # propagated to the caller. Additionally, exceptions encountered during test
                    # mode should be reraised to detect bugs in autologging implementations
                    if failed_during_original or is_testing():
                        raise

                    try_log_autologging_event(
                        AutologgingEventLogger.get_logger().log_patch_function_error,
                        session,
                        destination,
                        function_name,
                        args,
                        kwargs,
                        e,
                    )

                    _logger.warning(
                        "Encountered unexpected error during %s autologging: %s",
                        autologging_integration,
                        e,
                    )

                if is_testing() and not preexisting_run_for_testing:
                    # If an MLflow run was created during the execution of patch code, verify that
                    # it is no longer active and that it contains expected autologging tags
                    assert not mlflow.active_run(), (
                        "Autologging integration %s leaked an active run" % autologging_integration
                    )
                    if patch_function_run_for_testing:
                        _validate_autologging_run(
                            autologging_integration, patch_function_run_for_testing.info.run_id
                        )

                if original_has_been_called:
                    return original_result
                else:
                    return original(*args, **kwargs)
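The contract described in the docstring of Example 13 can be reduced to a small, framework-free sketch (illustrative names only, not MLflow's API): the wrapper hands the patch implementation an instrumented call_original, re-raises exceptions that originated inside the underlying function, demotes exceptions raised by the patch code itself to warnings, and falls back to invoking the original directly if the patch never did.

import logging

_logger = logging.getLogger(__name__)


def make_safe_patch(original, patch_function):
    # Hypothetical helper illustrating the structure of safe_patch_function above
    def safe_patch_function(*args, **kwargs):
        original_called = False
        original_result = None
        failed_during_original = False

        def call_original(*og_args, **og_kwargs):
            nonlocal original_called, original_result, failed_during_original
            try:
                original_called = True
                original_result = original(*og_args, **og_kwargs)
                return original_result
            except Exception:
                failed_during_original = True
                raise

        try:
            patch_function(call_original, *args, **kwargs)
        except Exception as e:
            # Errors from `original` must reach the caller; errors from the
            # surrounding patch code are logged as warnings instead
            if failed_during_original:
                raise
            _logger.warning("Patch code failed; falling back to original: %s", e)

        return original_result if original_called else original(*args, **kwargs)

    return safe_patch_function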
Example #14
def patched_fit(original, self, *args, **kwargs):
    """
    A patched implementation of `pytorch_lightning.Trainer.fit` which enables logging the
    following parameters, metrics and artifacts:

    - Training epochs
    - Optimizer parameters
    - `EarlyStoppingCallback`_ parameters
    - Metrics stored in `trainer.callback_metrics`
    - Model checkpoints
    - Trained model

    .. _EarlyStoppingCallback:
        https://pytorch-lightning.readthedocs.io/en/latest/early_stopping.html
    """
    run_id = mlflow.active_run().info.run_id
    tracking_uri = mlflow.get_tracking_uri()
    client = MlflowAutologgingQueueingClient(tracking_uri)
    metrics_logger = BatchMetricsLogger(run_id, tracking_uri)

    log_models = get_autologging_config(mlflow.pytorch.FLAVOR_NAME,
                                        "log_models", True)
    log_every_n_epoch = get_autologging_config(mlflow.pytorch.FLAVOR_NAME,
                                               "log_every_n_epoch", 1)
    log_every_n_step = get_autologging_config(mlflow.pytorch.FLAVOR_NAME,
                                              "log_every_n_step", None)

    early_stop_callback = None
    for callback in self.callbacks:
        if isinstance(callback, pl.callbacks.early_stopping.EarlyStopping):
            early_stop_callback = callback
            _log_early_stop_params(early_stop_callback, client, run_id)

    if not any(
            isinstance(callbacks, __MLflowPLCallback)
            for callbacks in self.callbacks):
        self.callbacks += [
            __MLflowPLCallback(client, metrics_logger, run_id, log_models,
                               log_every_n_epoch, log_every_n_step)
        ]

    client.flush(synchronous=False)

    result = original(self, *args, **kwargs)

    if early_stop_callback is not None:
        _log_early_stop_metrics(early_stop_callback, client, run_id)

    if Version(pl.__version__) < Version("1.4.0"):
        summary = str(ModelSummary(self.model, mode="full"))
    else:
        summary = str(ModelSummary(self.model, max_depth=-1))

    tempdir = tempfile.mkdtemp()
    try:
        summary_file = os.path.join(tempdir, "model_summary.txt")
        with open(summary_file, "w") as f:
            f.write(summary)

        mlflow.log_artifact(local_path=summary_file)
    finally:
        shutil.rmtree(tempdir)

    if log_models:
        registered_model_name = get_autologging_config(
            mlflow.pytorch.FLAVOR_NAME, "registered_model_name", None)
        mlflow.pytorch.log_model(
            pytorch_model=self.model,
            artifact_path="model",
            registered_model_name=registered_model_name,
        )

        if early_stop_callback is not None and self.checkpoint_callback.best_model_path:
            mlflow.log_artifact(
                local_path=self.checkpoint_callback.best_model_path,
                artifact_path="restored_model_checkpoint",
            )

    client.flush(synchronous=True)

    return result
Example #15
    def train(_log_models, original, *args, **kwargs):
        def record_eval_results(eval_results, metrics_logger):
            """
            Create a callback function that records evaluation results.
            """
            # TODO: Remove `replace("SNAPSHOT", "dev")` once the following issue is addressed:
            #       https://github.com/dmlc/xgboost/issues/6984
            from mlflow.xgboost._autolog import IS_TRAINING_CALLBACK_SUPPORTED

            if IS_TRAINING_CALLBACK_SUPPORTED:
                from mlflow.xgboost._autolog import AutologCallback

                # In xgboost >= 1.3.0, user-defined callbacks should inherit
                # `xgboost.callback.TrainingCallback`:
                # https://xgboost.readthedocs.io/en/latest/python/callbacks.html#defining-your-own-callback
                return AutologCallback(metrics_logger, eval_results)
            else:
                from mlflow.xgboost._autolog import autolog_callback

                return picklable_exception_safe_function(
                    functools.partial(autolog_callback,
                                      metrics_logger=metrics_logger,
                                      eval_results=eval_results))

        def log_feature_importance_plot(features, importance, importance_type):
            """
            Log feature importance plot.
            """
            import matplotlib.pyplot as plt
            from cycler import cycler

            features = np.array(features)

            # Structure the supplied `importance` values as a `num_features`-by-`num_classes` matrix
            importances_per_class_by_feature = np.array(importance)
            if importances_per_class_by_feature.ndim <= 1:
                # In this case, the supplied `importance` values are not given per class. Rather,
                # one importance value is given per feature. For consistency with the assumed
                # `num_features`-by-`num_classes` matrix structure, we coerce the importance
                # values to a `num_features`-by-1 matrix
                indices = np.argsort(importance)
                # Sort features and importance values by magnitude during transformation to a
                # `num_features`-by-`num_classes` matrix
                features = features[indices]
                importances_per_class_by_feature = np.array([[
                    importance
                ] for importance in importances_per_class_by_feature[indices]])
                # In this case, do not include class labels on the feature importance plot because
                # only one importance value has been provided per feature, rather than
                # one importance value for each class per feature
                label_classes_on_plot = False
            else:
                importance_value_magnitudes = np.abs(
                    importances_per_class_by_feature).sum(axis=1)
                indices = np.argsort(importance_value_magnitudes)
                features = features[indices]
                importances_per_class_by_feature = importances_per_class_by_feature[
                    indices]
                label_classes_on_plot = True

            num_classes = importances_per_class_by_feature.shape[1]
            num_features = len(features)

            # If num_features > 10, increase the figure height to prevent the plot
            # from being too dense.
            w, h = [6.4, 4.8]  # matplotlib's default figure size
            h = h + 0.1 * num_features if num_features > 10 else h
            h = h + 0.1 * num_classes if num_classes > 1 else h
            fig, ax = plt.subplots(figsize=(w, h))
            # When importance values are provided for each class per feature, we want to ensure
            # that the same color is used for all bars in the bar chart that have the same class
            colors_to_cycle = plt.rcParams["axes.prop_cycle"].by_key(
            )["color"][:num_classes]
            color_cycler = cycler(color=colors_to_cycle)
            ax.set_prop_cycle(color_cycler)

            # The following logic operates on one feature at a time, adding a bar to the bar chart
            # for each class that reflects the importance of the feature to predictions of that
            # class
            feature_ylocs = np.arange(num_features)
            # Define offsets on the y-axis that are used to evenly space the bars for each class
            # around the y-axis position of each feature
            offsets_per_yloc = np.linspace(
                -0.5, 0.5, num_classes) / 2 if num_classes > 1 else [0]
            for feature_idx, (feature_yloc,
                              importances_per_class) in enumerate(
                                  zip(feature_ylocs,
                                      importances_per_class_by_feature)):
                for class_idx, (offset, class_importance) in enumerate(
                        zip(offsets_per_yloc, importances_per_class)):
                    (bar, ) = ax.barh(
                        feature_yloc + offset,
                        class_importance,
                        align="center",
                        # Set the bar height such that importance value bars for a particular
                        # feature are spaced properly relative to each other (no overlap or gaps)
                        # and relative to importance value bars for other features
                        height=(0.5 / max(num_classes - 1, 1)),
                    )
                    if label_classes_on_plot and feature_idx == 0:
                        # Only set a label the first time a bar for a particular class is plotted to
                        # avoid duplicate legend entries. If we were to set a label for every bar,
                        # the legend would contain `num_features` labels for each class.
                        bar.set_label("Class {}".format(class_idx))

            ax.set_yticks(feature_ylocs)
            ax.set_yticklabels(features)
            ax.set_xlabel("Importance")
            ax.set_title("Feature Importance ({})".format(importance_type))
            if label_classes_on_plot:
                ax.legend()
            fig.tight_layout()

            tmpdir = tempfile.mkdtemp()
            try:
                # pylint: disable=undefined-loop-variable
                filepath = os.path.join(
                    tmpdir, "feature_importance_{}.png".format(imp_type))
                fig.savefig(filepath)
                mlflow.log_artifact(filepath)
            finally:
                plt.close(fig)
                shutil.rmtree(tmpdir)

        autologging_client = MlflowAutologgingQueueingClient()
        # logging booster params separately to extract key/value pairs and make it easier to
        # compare them across runs.
        booster_params = args[0] if len(args) > 0 else kwargs["params"]
        autologging_client.log_params(run_id=mlflow.active_run().info.run_id,
                                      params=booster_params)

        unlogged_params = [
            "params",
            "dtrain",
            "evals",
            "obj",
            "feval",
            "evals_result",
            "xgb_model",
            "callbacks",
            "learning_rates",
        ]
        params_to_log_for_fn = get_mlflow_run_params_for_fn_args(
            original, args, kwargs, unlogged_params)
        autologging_client.log_params(run_id=mlflow.active_run().info.run_id,
                                      params=params_to_log_for_fn)

        param_logging_operations = autologging_client.flush(synchronous=False)

        all_arg_names = _get_arg_names(original)
        num_pos_args = len(args)

        # adding a callback that records evaluation results.
        eval_results = []
        callbacks_index = all_arg_names.index("callbacks")

        run_id = mlflow.active_run().info.run_id
        with batch_metrics_logger(run_id) as metrics_logger:
            callback = record_eval_results(eval_results, metrics_logger)
            if num_pos_args >= callbacks_index + 1:
                tmp_list = list(args)
                tmp_list[callbacks_index] += [callback]
                args = tuple(tmp_list)
            elif "callbacks" in kwargs and kwargs["callbacks"] is not None:
                kwargs["callbacks"] += [callback]
            else:
                kwargs["callbacks"] = [callback]

            # training model
            model = original(*args, **kwargs)

            # If early_stopping_rounds is present, logging metrics at the best iteration
            # as extra metrics with the max step + 1.
            early_stopping_index = all_arg_names.index("early_stopping_rounds")
            early_stopping = num_pos_args >= early_stopping_index + 1 or kwargs.get(
                "early_stopping_rounds")
            if early_stopping:
                extra_step = len(eval_results)
                autologging_client.log_metrics(
                    run_id=mlflow.active_run().info.run_id,
                    metrics={
                        "stopped_iteration": extra_step - 1,
                        "best_iteration": model.best_iteration,
                    },
                )
                autologging_client.log_metrics(
                    run_id=mlflow.active_run().info.run_id,
                    metrics=eval_results[model.best_iteration],
                    step=extra_step,
                )
                early_stopping_logging_operations = autologging_client.flush(
                    synchronous=False)

        # logging feature importance as artifacts.
        for imp_type in importance_types:
            imp = None
            try:
                imp = model.get_score(importance_type=imp_type)
                features, importance = zip(*imp.items())
                log_feature_importance_plot(features, importance, imp_type)
            except Exception:
                _logger.exception(
                    "Failed to log feature importance plot. XGBoost autologging "
                    "will ignore the failure and continue. Exception: ")

            if imp is not None:
                tmpdir = tempfile.mkdtemp()
                try:
                    filepath = os.path.join(
                        tmpdir, "feature_importance_{}.json".format(imp_type))
                    with open(filepath, "w") as f:
                        json.dump(imp, f)
                    mlflow.log_artifact(filepath)
                finally:
                    shutil.rmtree(tmpdir)

        # dtrain must exist as the original train function already ran successfully
        dtrain = args[1] if len(args) > 1 else kwargs.get("dtrain")

        # it is possible that the dataset was constructed before the patched
        #   constructor was applied, so we cannot assume the input_example_info exists
        input_example_info = getattr(dtrain, "input_example_info", None)

        def get_input_example():
            if input_example_info is None:
                raise Exception(ENSURE_AUTOLOGGING_ENABLED_TEXT)
            if input_example_info.error_msg is not None:
                raise Exception(input_example_info.error_msg)
            return input_example_info.input_example

        def infer_model_signature(input_example):
            model_output = model.predict(xgboost.DMatrix(input_example))
            model_signature = infer_signature(input_example, model_output)
            return model_signature

        # Only log the model if the autolog() param log_models is set to True.
        if _log_models:
            # Will only resolve `input_example` and `signature` if `log_models` is `True`.
            input_example, signature = resolve_input_example_and_signature(
                get_input_example,
                infer_model_signature,
                log_input_examples,
                log_model_signatures,
                _logger,
            )

            registered_model_name = get_autologging_config(
                FLAVOR_NAME, "registered_model_name", None)
            log_model(
                model,
                artifact_path="model",
                signature=signature,
                input_example=input_example,
                registered_model_name=registered_model_name,
            )

        param_logging_operations.await_completion()
        if early_stopping:
            early_stopping_logging_operations.await_completion()

        return model