Example #1
    def _run_and_log_function(self, original, args, kwargs, unlogged_params, callback_arg_index):
        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            auto_end_run = True
        else:
            auto_end_run = False

        log_fn_args_as_params(original, [self] + list(args), kwargs, unlogged_params)

        callbacks = [cb(self) for cb in self.callback_fns] + (self.callbacks or [])

        # Checking if the 'callbacks' argument of the function is set
        if len(args) > callback_arg_index:
            tmp_list = list(args)
            callbacks += list(args[callback_arg_index])
            tmp_list[callback_arg_index] += [__MLflowFastaiCallback(self)]
            args = tuple(tmp_list)
        elif 'callbacks' in kwargs:
            callbacks += list(kwargs['callbacks'])
            kwargs['callbacks'] += [__MLflowFastaiCallback(self)]
        else:
            kwargs['callbacks'] = [__MLflowFastaiCallback(self)]

        early_stop_callback = _find_callback_of_type(EarlyStoppingCallback, callbacks)
        one_cycle_callback = _find_callback_of_type(OneCycleScheduler, callbacks)

        _log_early_stop_callback_params(early_stop_callback)
        _log_one_cycle_callback_params(one_cycle_callback)

        result = original(self, *args, **kwargs)

        if auto_end_run:
            try_mlflow_log(mlflow.end_run)

        return result
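
All of these patches share the same maneuver: find the 'callbacks' argument whether it was passed positionally or by keyword, append an MLflow callback, and delegate to the original function. A minimal, framework-agnostic sketch of just that maneuver (the names _inject_callback, mlflow_cb, and callback_arg_index are illustrative, not any library's API):

    def _inject_callback(args, kwargs, callback_arg_index, mlflow_cb):
        """Return (args, kwargs) with mlflow_cb appended to the 'callbacks' argument."""
        if len(args) > callback_arg_index:
            # 'callbacks' was passed positionally: rebuild the args tuple.
            args = list(args)
            args[callback_arg_index] = list(args[callback_arg_index]) + [mlflow_cb]
            args = tuple(args)
        else:
            # 'callbacks' was passed by keyword, or not at all.
            kwargs["callbacks"] = list(kwargs.get("callbacks") or []) + [mlflow_cb]
        return args, kwargs

Unlike several of the excerpts below, this sketch copies the caller's list rather than extending it in place with +=; see the note after Example #16.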
Example #2
    def fit_generator(self, *args, **kwargs):
        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            auto_end_run = True
        else:
            auto_end_run = False

        original = gorilla.get_original_attribute(keras.Model, 'fit_generator')

        unlogged_params = [
            'self', 'generator', 'callbacks', 'validation_data', 'verbose'
        ]

        log_fn_args_as_params(original, args, kwargs, unlogged_params)

        # Checking if the 'callbacks' argument of fit_generator() is set
        if len(args) >= 5:
            tmp_list = list(args)
            tmp_list[4] += [__MLflowKerasCallback()]
            args = tuple(tmp_list)
        elif 'callbacks' in kwargs:
            kwargs['callbacks'] += [__MLflowKerasCallback()]
        else:
            kwargs['callbacks'] = [__MLflowKerasCallback()]

        result = original(self, *args, **kwargs)
        if auto_end_run:
            try_mlflow_log(mlflow.end_run)
        return result
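
gorilla.get_original_attribute assumes the method has already been replaced through a gorilla patch. A minimal sketch of how such a patch is installed with gorilla's public API (illustrative only; MLflow's own patch-registration code is not shown in these excerpts):

    import gorilla
    import keras

    def fit_generator(self, *args, **kwargs):
        # Recover the unpatched method and delegate to it.
        original = gorilla.get_original_attribute(keras.Model, "fit_generator")
        return original(self, *args, **kwargs)

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patch = gorilla.Patch(keras.Model, "fit_generator", fit_generator, settings=settings)
    gorilla.apply(patch)  # keras.Model.fit_generator now routes through the wrapper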
Example #3
        def _patch_implementation(
            self, original, inst, *args, **kwargs
        ):  # pylint: disable=arguments-differ
            unlogged_params = ["self", "generator", "callbacks", "validation_data", "verbose"]

            log_fn_args_as_params(original, args, kwargs, unlogged_params)

            run_id = mlflow.active_run().info.run_id

            with batch_metrics_logger(run_id) as metrics_logger:
                # Checking if the 'callbacks' argument is set positionally
                if len(args) >= 5:
                    tmp_list = list(args)
                    tmp_list[4], self.log_dir = _setup_callbacks(
                        tmp_list[4], log_models, metrics_logger
                    )
                    args = tuple(tmp_list)
                elif kwargs.get("callbacks"):
                    kwargs["callbacks"], self.log_dir = _setup_callbacks(
                        kwargs["callbacks"], log_models, metrics_logger
                    )
                else:
                    kwargs["callbacks"], self.log_dir = _setup_callbacks(
                        [], log_models, metrics_logger
                    )
                result = original(inst, *args, **kwargs)

            _flush_queue()
            _log_artifacts_with_warning(
                local_dir=self.log_dir.location, artifact_path="tensorboard_logs"
            )
            if self.log_dir.is_temp:
                shutil.rmtree(self.log_dir.location)

            return result
Example #4
    def fit(self, *args, **kwargs):
        with _manage_active_run():
            original = gorilla.get_original_attribute(tensorflow.keras.Model, "fit")

            unlogged_params = ["self", "x", "y", "callbacks", "validation_data", "verbose"]

            log_fn_args_as_params(original, args, kwargs, unlogged_params)
            early_stop_callback = None

            # Checking if the 'callbacks' argument of fit() is set
            if len(args) >= 6:
                tmp_list = list(args)
                early_stop_callback = _early_stop_check(tmp_list[5])
                tmp_list[5], log_dir = _setup_callbacks(tmp_list[5], log_models)
                args = tuple(tmp_list)
            elif kwargs.get("callbacks"):
                early_stop_callback = _early_stop_check(kwargs["callbacks"])
                kwargs["callbacks"], log_dir = _setup_callbacks(kwargs["callbacks"], log_models)
            else:
                kwargs["callbacks"], log_dir = _setup_callbacks([], log_models)

            _log_early_stop_callback_params(early_stop_callback)

            history = original(self, *args, **kwargs)

            _log_early_stop_callback_metrics(early_stop_callback, history)

            _flush_queue()
            _log_artifacts_with_warning(
                local_dir=log_dir.location, artifact_path="tensorboard_logs"
            )
            if log_dir.is_temp:
                shutil.rmtree(log_dir.location)

            return history
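
The _manage_active_run() helper is not shown in these excerpts, but the earlier examples that track auto_end_run by hand make its intent clear: start a run if none is active, and end only runs it started. A plausible reconstruction as a context manager (hypothetical; not MLflow's actual implementation):

    import contextlib
    import mlflow

    @contextlib.contextmanager
    def _manage_active_run():
        # Start a run only if the user has not started one themselves.
        auto_end_run = mlflow.active_run() is None
        if auto_end_run:
            mlflow.start_run()
        try:
            yield mlflow.active_run()
        finally:
            # End only runs created here; leave user-managed runs alone.
            if auto_end_run:
                mlflow.end_run()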
Example #5
    def _run_and_log_function(self, original, args, kwargs, unlogged_params, callback_arg_index):
        log_fn_args_as_params(original, list(args), kwargs, unlogged_params)

        callbacks = [cb(self) for cb in self.callback_fns] + (self.callbacks or [])

        run_id = mlflow.active_run().info.run_id
        with batch_metrics_logger(run_id) as metrics_logger:
            mlflowFastaiCallback = getFastaiCallback(metrics_logger, self)

            # Checking if the 'callbacks' argument of the function is set
            if len(args) > callback_arg_index:
                tmp_list = list(args)
                callbacks += list(args[callback_arg_index])
                tmp_list[callback_arg_index] += [mlflowFastaiCallback]
                args = tuple(tmp_list)
            elif kwargs.get("callbacks"):
                callbacks += list(kwargs["callbacks"])
                kwargs["callbacks"] += [mlflowFastaiCallback]
            else:
                kwargs["callbacks"] = [mlflowFastaiCallback]

            early_stop_callback = _find_callback_of_type(EarlyStoppingCallback, callbacks)
            one_cycle_callback = _find_callback_of_type(OneCycleScheduler, callbacks)

            _log_early_stop_callback_params(early_stop_callback)
            _log_one_cycle_callback_params(one_cycle_callback)

            result = original(self, *args, **kwargs)

        return result
Example #6
    def fit_generator(self, *args, **kwargs):
        with _manage_active_run():
            original = gorilla.get_original_attribute(tensorflow.keras.Model,
                                                      "fit_generator")

            unlogged_params = [
                "self", "generator", "callbacks", "validation_data", "verbose"
            ]

            log_fn_args_as_params(original, args, kwargs, unlogged_params)

            # Checking if the 'callbacks' argument of fit_generator() is set
            if len(args) >= 5:
                tmp_list = list(args)
                tmp_list[4], log_dir = _setup_callbacks(tmp_list[4])
                args = tuple(tmp_list)
            elif "callbacks" in kwargs:
                kwargs["callbacks"], log_dir = _setup_callbacks(
                    kwargs["callbacks"])
            else:
                kwargs["callbacks"], log_dir = _setup_callbacks([])
            result = original(self, *args, **kwargs)
            _flush_queue()
            _log_artifacts_with_warning(local_dir=log_dir.location,
                                        artifact_path="tensorboard_logs")
            if log_dir.is_temp:
                shutil.rmtree(log_dir.location)

            return result
Example #7
File: keras.py, Project: radzak/mlflow
    def _run_and_log_function(self, original, args, kwargs, unlogged_params,
                              callback_arg_index):
        log_fn_args_as_params(original, args, kwargs, unlogged_params)
        early_stop_callback = None

        # Checking if the 'callbacks' argument of the function is set
        run_id = mlflow.active_run().info.run_id
        with batch_metrics_logger(run_id) as metrics_logger:
            mlflowKerasCallback = getKerasCallback(metrics_logger)
            if len(args) > callback_arg_index:
                tmp_list = list(args)
                early_stop_callback = _early_stop_check(
                    tmp_list[callback_arg_index])
                tmp_list[callback_arg_index] += [mlflowKerasCallback]
                args = tuple(tmp_list)
            elif kwargs.get("callbacks"):
                early_stop_callback = _early_stop_check(kwargs["callbacks"])
                kwargs["callbacks"] += [mlflowKerasCallback]
            else:
                kwargs["callbacks"] = [mlflowKerasCallback]

            try_mlflow_log(_log_early_stop_callback_params,
                           early_stop_callback)

            history = original(self, *args, **kwargs)

            try_mlflow_log(_log_early_stop_callback_metrics,
                           early_stop_callback, history, metrics_logger)

        return history
Example #8
    def _run_and_log_function(self, original, args, kwargs, unlogged_params,
                              callback_arg_index):
        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            auto_end_run = True
        else:
            auto_end_run = False

        log_fn_args_as_params(original, args, kwargs, unlogged_params)
        early_stop_callback = None

        # Checking if the 'callbacks' argument of the function is set
        if len(args) > callback_arg_index:
            tmp_list = list(args)
            early_stop_callback = _early_stop_check(
                tmp_list[callback_arg_index])
            tmp_list[callback_arg_index] += [__MLflowKerasCallback()]
            args = tuple(tmp_list)
        elif 'callbacks' in kwargs:
            early_stop_callback = _early_stop_check(kwargs['callbacks'])
            kwargs['callbacks'] += [__MLflowKerasCallback()]
        else:
            kwargs['callbacks'] = [__MLflowKerasCallback()]

        _log_early_stop_callback_params(early_stop_callback)

        history = original(self, *args, **kwargs)

        _log_early_stop_callback_metrics(early_stop_callback, history)

        if auto_end_run:
            try_mlflow_log(mlflow.end_run)

        return history
Example #9
    def wrapper_fit(original, self, *args, **kwargs):

        should_autolog = False
        if AutologHelpers.should_autolog:
            AutologHelpers.should_autolog = False
            should_autolog = True

        try:
            if should_autolog:
                # This may generate warnings due to collisions in already-logged param names
                log_fn_args_as_params(original, args, kwargs)

            # training model
            model = original(self, *args, **kwargs)

            if should_autolog:
                # Log the model
                if get_autologging_config(FLAVOR_NAME, "log_models", True):
                    try_mlflow_log(log_model, model, artifact_path="model")

                # Log the most common metrics
                if isinstance(model, statsmodels.base.wrapper.ResultsWrapper):
                    metrics_dict = results_to_dict(model)
                    try_mlflow_log(mlflow.log_metrics, metrics_dict)

            return model

        finally:
            # Clean the shared flag for future calls in case it had been set here ...
            if should_autolog:
                AutologHelpers.should_autolog = True
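
The AutologHelpers.should_autolog flag is a reentrancy guard: a patched fit() may internally call other patched methods, and only the outermost call should log. Stripped of the statsmodels specifics, the pattern looks like this (names are illustrative):

    class AutologHelpers:
        should_autolog = True  # True while no patched call is in flight

    def wrapper(original, self, *args, **kwargs):
        outermost = AutologHelpers.should_autolog
        if outermost:
            AutologHelpers.should_autolog = False  # silence nested patched calls
        try:
            model = original(self, *args, **kwargs)
            if outermost:
                pass  # log params, metrics, and the model here, exactly once
            return model
        finally:
            if outermost:
                AutologHelpers.should_autolog = True  # re-arm for future calls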
Example #10
def test_log_fn_args_as_params(args, kwargs, expected, start_run):  # pylint: disable=W0613
    log_fn_args_as_params(dummy_fn, args, kwargs)
    client = mlflow.tracking.MlflowClient()
    params = client.get_run(mlflow.active_run().info.run_id).data.params
    for arg, value in zip(["arg1", "arg2", "arg3"], expected):
        assert arg in params
        assert params[arg] == value
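
Together with Example #15 below, this test pins down the observable contract of log_fn_args_as_params: supplied arguments are paired with the function's parameter names and logged as run params, minus any names in unlogged_params. A rough sketch of that contract (not MLflow's actual implementation, which also logs defaults for unspecified parameters):

    import inspect
    import mlflow

    def log_fn_args_as_params_sketch(fn, args, kwargs, unlogged=()):
        # Pair positional args with the function's parameter names, then fold
        # in keyword args; drop anything listed in `unlogged`.
        arg_names = list(inspect.signature(fn).parameters)
        params = dict(zip(arg_names, args))
        params.update(kwargs)
        params = {k: v for k, v in params.items() if k not in unlogged}
        if params:
            mlflow.log_params(params)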
Example #11
    def _run_and_log_function(self, original, args, kwargs, unlogged_params, is_fine_tune=False):
        # Check whether this fit() call is happening inside a fine_tune() run
        mlflow_cbs = [cb for cb in self.cbs if cb.name == "___mlflow_fastai"]
        fit_in_fine_tune = (
            original.__name__ == "fit" and len(mlflow_cbs) > 0 and mlflow_cbs[0].is_fine_tune
        )

        if not fit_in_fine_tune:
            log_fn_args_as_params(original, list(args), kwargs, unlogged_params)

        run_id = mlflow.active_run().info.run_id
        with batch_metrics_logger(run_id) as metrics_logger:

            if not fit_in_fine_tune:
                early_stop_callback = _find_callback_of_type(EarlyStoppingCallback, self.cbs)
                _log_early_stop_callback_params(early_stop_callback)

                # First remove any MLflow callback that is already registered
                self.remove_cbs(mlflow_cbs)

                # Log information regarding model and data without bar and print-out
                with self.no_bar(), self.no_logging():
                    _log_model_info(learner=self)

                mlflowFastaiCallback = getFastaiCallback(
                    metrics_logger=metrics_logger, is_fine_tune=is_fine_tune
                )

                # Add the new callback
                self.add_cb(mlflowFastaiCallback)

            result = original(self, *args, **kwargs)

        return result
Example #12
    def fit_generator(self, *args, **kwargs):
        """
        NOTE: `fit_generator()` is deprecated in TF >= 2.1.0 and simply wraps `fit()`.
        To avoid unintentional creation of nested MLflow runs caused by a patched
        `fit_generator()` method calling a patched `fit()` method, we only patch
        `fit_generator()` in TF < 2.1.0.
        """
        with _manage_active_run():
            original = gorilla.get_original_attribute(tensorflow.keras.Model, "fit_generator")

            unlogged_params = ["self", "generator", "callbacks", "validation_data", "verbose"]

            log_fn_args_as_params(original, args, kwargs, unlogged_params)

            # Checking if the 'callbacks' argument of fit_generator() is set
            if len(args) >= 5:
                tmp_list = list(args)
                tmp_list[4], log_dir = _setup_callbacks(tmp_list[4], log_models)
                args = tuple(tmp_list)
            elif kwargs.get("callbacks"):
                kwargs["callbacks"], log_dir = _setup_callbacks(kwargs["callbacks"], log_models)
            else:
                kwargs["callbacks"], log_dir = _setup_callbacks([], log_models)
            result = original(self, *args, **kwargs)
            _flush_queue()
            _log_artifacts_with_warning(
                local_dir=log_dir.location, artifact_path="tensorboard_logs"
            )
            if log_dir.is_temp:
                shutil.rmtree(log_dir.location)

            return result
Example #13
    def fit_generator(self, *args, **kwargs):
        with _manage_active_run():
            original = gorilla.get_original_attribute(tensorflow.keras.Model,
                                                      'fit_generator')

            unlogged_params = [
                'self', 'generator', 'callbacks', 'validation_data', 'verbose'
            ]

            log_fn_args_as_params(original, args, kwargs, unlogged_params)

            # Checking if the 'callbacks' argument of fit_generator() is set
            if len(args) >= 5:
                tmp_list = list(args)
                tmp_list[4], log_dir = _setup_callbacks(tmp_list[4])
                args = tuple(tmp_list)
            elif 'callbacks' in kwargs:
                kwargs['callbacks'], log_dir = _setup_callbacks(
                    kwargs['callbacks'])
            else:
                kwargs['callbacks'], log_dir = _setup_callbacks([])
            result = original(self, *args, **kwargs)
            _flush_queue()
            _log_artifacts_with_warning(local_dir=log_dir,
                                        artifact_path='tensorboard_logs')
            shutil.rmtree(log_dir)

            return result
Example #14
    def fit(self, *args, **kwargs):
        with _manage_active_run():
            original = gorilla.get_original_attribute(tensorflow.keras.Model, 'fit')

            unlogged_params = ['self', 'x', 'y', 'callbacks', 'validation_data', 'verbose']

            log_fn_args_as_params(original, args, kwargs, unlogged_params)
            early_stop_callback = None

            # Checking if the 'callbacks' argument of fit() is set
            if len(args) >= 6:
                tmp_list = list(args)
                early_stop_callback = _early_stop_check(tmp_list[5])
                tmp_list[5], log_dir = _setup_callbacks(tmp_list[5])
                args = tuple(tmp_list)
            elif 'callbacks' in kwargs:
                early_stop_callback = _early_stop_check(kwargs['callbacks'])
                kwargs['callbacks'], log_dir = _setup_callbacks(kwargs['callbacks'])
            else:
                kwargs['callbacks'], log_dir = _setup_callbacks([])

            _log_early_stop_callback_params(early_stop_callback)

            history = original(self, *args, **kwargs)

            _log_early_stop_callback_metrics(early_stop_callback, history)

            _flush_queue()
            _log_artifacts_with_warning(
                local_dir=log_dir.location, artifact_path='tensorboard_logs')
            if log_dir.is_temp:
                shutil.rmtree(log_dir.location)

            return history
Example #15
def test_log_fn_args_as_params_ignores_unwanted_parameters(start_run):  # pylint: disable=W0613
    args, kwargs, unlogged = ("arg1", {
        "arg2": "value"
    }, ["arg1", "arg2", "arg3"])
    log_fn_args_as_params(dummy_fn, args, kwargs, unlogged)
    client = mlflow.tracking.MlflowClient()
    params = client.get_run(mlflow.active_run().info.run_id).data.params
    assert len(params.keys()) == 0
Example #16
        def _patch_implementation(
            self, original, inst, *args, **kwargs
        ):  # pylint: disable=arguments-differ
            unlogged_params = ["self", "x", "y", "callbacks", "validation_data", "verbose"]

            log_fn_args_as_params(original, args, kwargs, unlogged_params)
            early_stop_callback = None

            run_id = mlflow.active_run().info.run_id
            with batch_metrics_logger(run_id) as metrics_logger:
                # Check if the 'callbacks' argument of fit() is set positionally
                if len(args) >= 6:
                    # Convert the positional training function arguments to a list in order to
                    # mutate the contents
                    args = list(args)
                    # Make a shallow copy of the preexisting callbacks to avoid permanently
                    # modifying their contents for future training invocations. Introduce
                    # TensorBoard & tf.keras callbacks if necessary
                    callbacks = list(args[5])
                    callbacks, self.log_dir = _setup_callbacks(
                        callbacks, log_models, metrics_logger
                    )
                    # Replace the callbacks positional entry in the copied arguments and convert
                    # the arguments back to tuple form for usage in the training function
                    args[5] = callbacks
                    args = tuple(args)
                else:
                    # Make a shallow copy of the preexisting callbacks and introduce TensorBoard
                    # & tf.keras callbacks if necessary
                    callbacks = list(kwargs.get("callbacks") or [])
                    kwargs["callbacks"], self.log_dir = _setup_callbacks(
                        callbacks, log_models, metrics_logger
                    )

                early_stop_callback = _get_early_stop_callback(callbacks)
                _log_early_stop_callback_params(early_stop_callback)

                history = original(inst, *args, **kwargs)

                _log_early_stop_callback_metrics(
                    callback=early_stop_callback,
                    history=history,
                    metrics_logger=metrics_logger,
                )

            _flush_queue()
            mlflow.log_artifacts(
                local_dir=self.log_dir.location,
                artifact_path="tensorboard_logs",
            )
            if self.log_dir.is_temp:
                shutil.rmtree(self.log_dir.location)

            return history
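
Note the deliberate list(...) copies in this version. The statement kwargs["callbacks"] += [cb], as used in several earlier excerpts, extends the caller's own list in place, so the MLflow callback leaks into any later training call that reuses that list. A short demonstration of the difference:

    user_callbacks = []
    kwargs = {"callbacks": user_callbacks}
    kwargs["callbacks"] += ["mlflow_cb"]            # in-place: mutates the user's list
    assert user_callbacks == ["mlflow_cb"]          # leaked!

    user_callbacks = []
    kwargs = {"callbacks": user_callbacks}
    kwargs["callbacks"] = list(kwargs["callbacks"]) + ["mlflow_cb"]  # copy first
    assert user_callbacks == []                     # caller's list untouched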
Example #17
    def wrapper_fit(original, self, *args, **kwargs):

        should_autolog = False
        if AutologHelpers.should_autolog:
            AutologHelpers.should_autolog = False
            should_autolog = True

        try:
            if should_autolog:
                # This may generate warnings due to collisions in already-logged param names
                log_fn_args_as_params(original, args, kwargs)

            # training model
            model = original(self, *args, **kwargs)

            if should_autolog:
                # Log the model
                if get_autologging_config(FLAVOR_NAME, "log_models", True):
                    global _save_model_called_from_autolog
                    _save_model_called_from_autolog = True
                    registered_model_name = get_autologging_config(
                        FLAVOR_NAME, "registered_model_name", None
                    )
                    try:
                        log_model(
                            model,
                            artifact_path="model",
                            registered_model_name=registered_model_name,
                        )
                    finally:
                        _save_model_called_from_autolog = False

                # Log the most common metrics
                if isinstance(model, statsmodels.base.wrapper.ResultsWrapper):
                    metrics_dict = _get_autolog_metrics(model)
                    mlflow.log_metrics(metrics_dict)

                    model_summary = model.summary().as_text()
                    mlflow.log_text(model_summary, "model_summary.txt")

            return model

        finally:
            # Clean the shared flag for future calls in case it had been set here ...
            if should_autolog:
                AutologHelpers.should_autolog = True
Example #18
        def fit(self, *args, **kwargs):

            should_autolog = False
            if AutologHelpers.should_autolog:
                AutologHelpers.should_autolog = False
                should_autolog = True

            if not mlflow.active_run():
                try_mlflow_log(mlflow.start_run)
                auto_end_run = True
            else:
                auto_end_run = False

            try:
                # training model
                model = original_method(self, *args, **kwargs)

                if should_autolog:
                    # Log the model
                    try_mlflow_log(log_model, model, artifact_path="model")

                    # Log the most common metrics
                    if isinstance(model, statsmodels.base.wrapper.ResultsWrapper):
                        metrics_dict = results_to_dict(model)
                        try_mlflow_log(mlflow.log_metrics, metrics_dict)
                        # This may generate warnings due to collisions in already-logged param names
                        log_fn_args_as_params(original_method, args, kwargs)

                return model

            finally:
                # Clean the shared flag for future calls in case it had been set here ...
                if should_autolog:
                    AutologHelpers.should_autolog = True

                # End current run if it had been created here ...
                if auto_end_run:
                    try_mlflow_log(mlflow.end_run)
Example #19
    def train(*args, **kwargs):
        def record_eval_results(eval_results):
            """
            Create a callback function that records evaluation results.
            """
            def callback(env):
                eval_results.append(dict(env.evaluation_result_list))

            return callback

        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            auto_end_run = True
        else:
            auto_end_run = False

        def log_feature_importance_plot(features, importance, importance_type):
            """
            Log feature importance plot.
            """
            import matplotlib.pyplot as plt

            features = np.array(features)
            importance = np.array(importance)
            indices = np.argsort(importance)
            features = features[indices]
            importance = importance[indices]
            num_features = len(features)

            # If num_features > 10, increase the figure height to prevent the plot
            # from being too dense.
            w, h = [6.4, 4.8]  # matplotlib's default figure size
            h = h + 0.1 * num_features if num_features > 10 else h
            fig, ax = plt.subplots(figsize=(w, h))

            yloc = np.arange(num_features)
            ax.barh(yloc, importance, align="center", height=0.5)
            ax.set_yticks(yloc)
            ax.set_yticklabels(features)
            ax.set_xlabel("Importance")
            ax.set_title("Feature Importance ({})".format(importance_type))
            fig.tight_layout()

            tmpdir = tempfile.mkdtemp()
            try:
                # pylint: disable=undefined-loop-variable
                filepath = os.path.join(
                    tmpdir, "feature_importance_{}.png".format(imp_type))
                fig.savefig(filepath)
                try_mlflow_log(mlflow.log_artifact, filepath)
            finally:
                plt.close(fig)
                shutil.rmtree(tmpdir)

        original = gorilla.get_original_attribute(xgboost, "train")

        # logging booster params separately via mlflow.log_params to extract key/value pairs
        # and make it easier to compare them across runs.
        params = args[0] if len(args) > 0 else kwargs["params"]
        try_mlflow_log(mlflow.log_params, params)

        unlogged_params = [
            "params",
            "dtrain",
            "evals",
            "obj",
            "feval",
            "evals_result",
            "xgb_model",
            "callbacks",
            "learning_rates",
        ]
        log_fn_args_as_params(original, args, kwargs, unlogged_params)

        all_arg_names = inspect.getargspec(original)[0]  # pylint: disable=W1505
        num_pos_args = len(args)

        # adding a callback that records evaluation results.
        eval_results = []
        callbacks_index = all_arg_names.index("callbacks")
        callback = record_eval_results(eval_results)
        if num_pos_args >= callbacks_index + 1:
            tmp_list = list(args)
            tmp_list[callbacks_index] += [callback]
            args = tuple(tmp_list)
        elif "callbacks" in kwargs and kwargs["callbacks"] is not None:
            kwargs["callbacks"] += [callback]
        else:
            kwargs["callbacks"] = [callback]

        # training model
        model = original(*args, **kwargs)

        # logging metrics on each iteration.
        for idx, metrics in enumerate(eval_results):
            try_mlflow_log(mlflow.log_metrics, metrics, step=idx)

        # If early_stopping_rounds is present, logging metrics at the best iteration
        # as extra metrics with the max step + 1.
        early_stopping_index = all_arg_names.index("early_stopping_rounds")
        early_stopping = (num_pos_args >= early_stopping_index + 1
                          or "early_stopping_rounds" in kwargs)
        if early_stopping:
            extra_step = len(eval_results)
            try_mlflow_log(mlflow.log_metric, "stopped_iteration",
                           len(eval_results) - 1)
            try_mlflow_log(mlflow.log_metric, "best_iteration",
                           model.best_iteration)
            try_mlflow_log(mlflow.log_metrics,
                           eval_results[model.best_iteration],
                           step=extra_step)

        # logging feature importance as artifacts.
        for imp_type in importance_types:
            imp = None
            try:
                imp = model.get_score(importance_type=imp_type)
                features, importance = zip(*imp.items())
                log_feature_importance_plot(features, importance, imp_type)
            except Exception:  # pylint: disable=broad-except
                _logger.exception(
                    "Failed to log feature importance plot. XGBoost autologging "
                    "will ignore the failure and continue. Exception: ")

            if imp is not None:
                tmpdir = tempfile.mkdtemp()
                try:
                    filepath = os.path.join(
                        tmpdir, "feature_importance_{}.json".format(imp_type))
                    with open(filepath, "w") as f:
                        json.dump(imp, f)
                    try_mlflow_log(mlflow.log_artifact, filepath)
                finally:
                    shutil.rmtree(tmpdir)

        # dtrain must exist as the original train function already ran successfully
        dtrain = args[1] if len(args) > 1 else kwargs.get("dtrain")

        input_example = None
        signature = None
        try:
            # it is possible that the dataset was constructed before the patched
            #   constructor was applied, so we cannot assume the input_example_info exists
            input_example_info = getattr(dtrain, "input_example_info", None)

            if input_example_info is None:
                raise Exception("please ensure that autologging is " +
                                "enabled before constructing the dataset.")

            input_example = input_example_info.input_example
            if input_example is None:
                # input example collection failed
                raise Exception(input_example_info.error_msg)

            model_output = model.predict(xgboost.DMatrix(input_example))
            signature = infer_signature(input_example, model_output)
        except Exception as e:  # pylint: disable=broad-except
            input_example = None
            msg = "Failed to gather example input and model signature: " + str(
                e)
            _logger.warning(msg)

        try_mlflow_log(
            log_model,
            model,
            artifact_path="model",
            signature=signature,
            input_example=input_example,
        )

        if auto_end_run:
            try_mlflow_log(mlflow.end_run)
        return model
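
In the pre-1.3 XGBoost callback API used here, env.evaluation_result_list is a list of (name, value) pairs, which is why dict(env.evaluation_result_list) works. A self-contained run of the recording pattern on synthetic data, with no MLflow involved (assumes xgboost < 1.3, where plain-function callbacks are still accepted):

    import numpy as np
    import xgboost

    eval_results = []

    def record(env):
        # env.evaluation_result_list looks like [("train-rmse", 0.42), ...]
        eval_results.append(dict(env.evaluation_result_list))

    X, y = np.random.rand(100, 4), np.random.rand(100)
    dtrain = xgboost.DMatrix(X, label=y)
    xgboost.train(
        {"objective": "reg:squarederror"},
        dtrain,
        num_boost_round=5,
        evals=[(dtrain, "train")],
        callbacks=[record],
    )
    print(eval_results)  # one dict of metrics per boosting round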
Example #20
File: lightgbm.py, Project: vsoch/mlflow
    def train(*args, **kwargs):
        def record_eval_results(eval_results):
            """
            Create a callback function that records evaluation results.
            """

            def callback(env):
                res = {}
                for data_name, eval_name, value, _ in env.evaluation_result_list:
                    key = data_name + "-" + eval_name
                    res[key] = value

                eval_results.append(res)

            return callback

        def log_feature_importance_plot(features, importance, importance_type):
            """
            Log feature importance plot.
            """
            import matplotlib.pyplot as plt

            indices = np.argsort(importance)
            features = np.array(features)[indices]
            importance = importance[indices]
            num_features = len(features)

            # If num_features > 10, increase the figure height to prevent the plot
            # from being too dense.
            w, h = [6.4, 4.8]  # matplotlib's default figure size
            h = h + 0.1 * num_features if num_features > 10 else h
            fig, ax = plt.subplots(figsize=(w, h))

            yloc = np.arange(num_features)
            ax.barh(yloc, importance, align="center", height=0.5)
            ax.set_yticks(yloc)
            ax.set_yticklabels(features)
            ax.set_xlabel("Importance")
            ax.set_title("Feature Importance ({})".format(importance_type))
            fig.tight_layout()

            tmpdir = tempfile.mkdtemp()
            try:
                # pylint: disable=undefined-loop-variable
                filepath = os.path.join(tmpdir, "feature_importance_{}.png".format(imp_type))
                fig.savefig(filepath)
                try_mlflow_log(mlflow.log_artifact, filepath)
            finally:
                plt.close(fig)
                shutil.rmtree(tmpdir)

        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            auto_end_run = True
        else:
            auto_end_run = False

        original = gorilla.get_original_attribute(lightgbm, "train")

        # logging booster params separately via mlflow.log_params to extract key/value pairs
        # and make it easier to compare them across runs.
        params = args[0] if len(args) > 0 else kwargs["params"]
        try_mlflow_log(mlflow.log_params, params)

        unlogged_params = [
            "params",
            "train_set",
            "valid_sets",
            "valid_names",
            "fobj",
            "feval",
            "init_model",
            "evals_result",
            "learning_rates",
            "callbacks",
        ]

        log_fn_args_as_params(original, args, kwargs, unlogged_params)

        all_arg_names = inspect.getargspec(original)[0]  # pylint: disable=W1505
        num_pos_args = len(args)

        # adding a callback that records evaluation results.
        eval_results = []
        callbacks_index = all_arg_names.index("callbacks")
        callback = record_eval_results(eval_results)
        if num_pos_args >= callbacks_index + 1:
            tmp_list = list(args)
            tmp_list[callbacks_index] += [callback]
            args = tuple(tmp_list)
        elif "callbacks" in kwargs and kwargs["callbacks"] is not None:
            kwargs["callbacks"] += [callback]
        else:
            kwargs["callbacks"] = [callback]

        # training model
        model = original(*args, **kwargs)

        # logging metrics on each iteration.
        for idx, metrics in enumerate(eval_results):
            try_mlflow_log(mlflow.log_metrics, metrics, step=idx)

        # If early_stopping_rounds is present, logging metrics at the best iteration
        # as extra metrics with the max step + 1.
        early_stopping_index = all_arg_names.index("early_stopping_rounds")
        early_stopping = (
            num_pos_args >= early_stopping_index + 1 or "early_stopping_rounds" in kwargs
        )
        if early_stopping:
            extra_step = len(eval_results)
            try_mlflow_log(mlflow.log_metric, "stopped_iteration", len(eval_results))
            # best_iteration is set even if training does not stop early.
            try_mlflow_log(mlflow.log_metric, "best_iteration", model.best_iteration)
            # iteration starts from 1 in LightGBM.
            try_mlflow_log(
                mlflow.log_metrics, eval_results[model.best_iteration - 1], step=extra_step
            )

        # logging feature importance as artifacts.
        for imp_type in ["split", "gain"]:
            features = model.feature_name()
            importance = model.feature_importance(importance_type=imp_type)
            try:
                log_feature_importance_plot(features, importance, imp_type)
            except Exception:  # pylint: disable=broad-except
                _logger.exception(
                    "Failed to log feature importance plot. LightGBM autologging "
                    "will ignore the failure and continue. Exception: "
                )

            imp = {ft: imp for ft, imp in zip(features, importance.tolist())}
            tmpdir = tempfile.mkdtemp()
            try:
                filepath = os.path.join(tmpdir, "feature_importance_{}.json".format(imp_type))
                with open(filepath, "w") as f:
                    json.dump(imp, f, indent=2)
                try_mlflow_log(mlflow.log_artifact, filepath)
            finally:
                shutil.rmtree(tmpdir)

        # train_set must exist as the original train function already ran successfully
        train_set = args[1] if len(args) > 1 else kwargs.get("train_set")

        # it is possible that the dataset was constructed before the patched
        #   constructor was applied, so we cannot assume the input_example_info exists
        input_example_info = getattr(train_set, "input_example_info", None)

        def get_input_example():
            if input_example_info is None:
                raise Exception(ENSURE_AUTOLOGGING_ENABLED_TEXT)
            if input_example_info.error_msg is not None:
                raise Exception(input_example_info.error_msg)
            return input_example_info.input_example

        def infer_model_signature(input_example):
            model_output = model.predict(input_example)
            model_signature = infer_signature(input_example, model_output)
            return model_signature

        input_example, signature = resolve_input_example_and_signature(
            get_input_example,
            infer_model_signature,
            log_input_example,
            log_model_signature,
            _logger,
        )

        try_mlflow_log(
            log_model,
            model,
            artifact_path="model",
            signature=signature,
            input_example=input_example,
        )

        if auto_end_run:
            try_mlflow_log(mlflow.end_run)
        return model
Example #21
    def train(*args, **kwargs):

        def record_eval_results(eval_results):
            """
            Create a callback function that records evaluation results.
            """
            def callback(env):
                res = {}
                for data_name, eval_name, value, _ in env.evaluation_result_list:
                    key = data_name + '-' + eval_name
                    res[key] = value

                eval_results.append(res)
            return callback

        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            auto_end_run = True
        else:
            auto_end_run = False

        original = gorilla.get_original_attribute(lightgbm, 'train')

        # logging booster params separately via mlflow.log_params to extract key/value pairs
        # and make it easier to compare them across runs.
        params = args[0] if len(args) > 0 else kwargs['params']
        try_mlflow_log(mlflow.log_params, params)

        unlogged_params = ['params', 'train_set', 'valid_sets', 'valid_names', 'fobj', 'feval',
                           'init_model', 'evals_result', 'learning_rates', 'callbacks']

        log_fn_args_as_params(original, args, kwargs, unlogged_params)

        all_arg_names = inspect.getargspec(original)[0]  # pylint: disable=W1505
        num_pos_args = len(args)

        # adding a callback that records evaluation results.
        eval_results = []
        callbacks_index = all_arg_names.index('callbacks')
        callback = record_eval_results(eval_results)
        if num_pos_args >= callbacks_index + 1:
            tmp_list = list(args)
            tmp_list[callbacks_index] += [callback]
            args = tuple(tmp_list)
        elif 'callbacks' in kwargs and kwargs['callbacks'] is not None:
            kwargs['callbacks'] += [callback]
        else:
            kwargs['callbacks'] = [callback]

        # training model
        model = original(*args, **kwargs)

        # logging metrics on each iteration.
        for idx, metrics in enumerate(eval_results):
            try_mlflow_log(mlflow.log_metrics, metrics, step=idx)

        # If early_stopping_rounds is present, logging metrics at the best iteration
        # as extra metrics with the max step + 1.
        early_stopping_index = all_arg_names.index('early_stopping_rounds')
        early_stopping = (num_pos_args >= early_stopping_index + 1 or
                          'early_stopping_rounds' in kwargs)
        if early_stopping:
            extra_step = len(eval_results)
            try_mlflow_log(mlflow.log_metric, 'stopped_iteration', len(eval_results))
            # best_iteration is set even if training does not stop early.
            try_mlflow_log(mlflow.log_metric, 'best_iteration', model.best_iteration)
            # iteration starts from 1 in LightGBM.
            try_mlflow_log(mlflow.log_metrics, eval_results[model.best_iteration - 1],
                           step=extra_step)

        # logging feature importance as artifacts.
        for imp_type in ['split', 'gain']:
            features = model.feature_name()
            importance = model.feature_importance(importance_type=imp_type)
            imp = {ft: imp for ft, imp in zip(features, importance.tolist())}
            tmpdir = tempfile.mkdtemp()
            try:
                filepath = os.path.join(tmpdir, 'feature_importance_{}.json'.format(imp_type))
                with open(filepath, 'w') as f:
                    json.dump(imp, f)
                try_mlflow_log(mlflow.log_artifact, filepath)
            finally:
                shutil.rmtree(tmpdir)

        try_mlflow_log(log_model, model, artifact_path='model')

        if auto_end_run:
            try_mlflow_log(mlflow.end_run)
        return model
Example #22
    def train(original, *args, **kwargs):
        def record_eval_results(eval_results, metrics_logger):
            """
            Create a callback function that records evaluation results.
            """
            @exception_safe_function
            def callback(env):
                metrics_logger.record_metrics(dict(env.evaluation_result_list),
                                              env.iteration)
                eval_results.append(dict(env.evaluation_result_list))

            return callback

        def log_feature_importance_plot(features, importance, importance_type):
            """
            Log feature importance plot.
            """
            import matplotlib.pyplot as plt

            features = np.array(features)
            importance = np.array(importance)
            indices = np.argsort(importance)
            features = features[indices]
            importance = importance[indices]
            num_features = len(features)

            # If num_features > 10, increase the figure height to prevent the plot
            # from being too dense.
            w, h = [6.4, 4.8]  # matplotlib's default figure size
            h = h + 0.1 * num_features if num_features > 10 else h
            fig, ax = plt.subplots(figsize=(w, h))

            yloc = np.arange(num_features)
            ax.barh(yloc, importance, align="center", height=0.5)
            ax.set_yticks(yloc)
            ax.set_yticklabels(features)
            ax.set_xlabel("Importance")
            ax.set_title("Feature Importance ({})".format(importance_type))
            fig.tight_layout()

            tmpdir = tempfile.mkdtemp()
            try:
                # pylint: disable=undefined-loop-variable
                filepath = os.path.join(
                    tmpdir, "feature_importance_{}.png".format(imp_type))
                fig.savefig(filepath)
                try_mlflow_log(mlflow.log_artifact, filepath)
            finally:
                plt.close(fig)
                shutil.rmtree(tmpdir)

        # logging booster params separately via mlflow.log_params to extract key/value pairs
        # and make it easier to compare them across runs.
        params = args[0] if len(args) > 0 else kwargs["params"]
        try_mlflow_log(mlflow.log_params, params)

        unlogged_params = [
            "params",
            "dtrain",
            "evals",
            "obj",
            "feval",
            "evals_result",
            "xgb_model",
            "callbacks",
            "learning_rates",
        ]
        log_fn_args_as_params(original, args, kwargs, unlogged_params)

        all_arg_names = inspect.getargspec(original)[0]  # pylint: disable=W1505
        num_pos_args = len(args)

        # adding a callback that records evaluation results.
        eval_results = []
        callbacks_index = all_arg_names.index("callbacks")

        run_id = mlflow.active_run().info.run_id
        with batch_metrics_logger(run_id) as metrics_logger:
            callback = record_eval_results(eval_results, metrics_logger)
            if num_pos_args >= callbacks_index + 1:
                tmp_list = list(args)
                tmp_list[callbacks_index] += [callback]
                args = tuple(tmp_list)
            elif "callbacks" in kwargs and kwargs["callbacks"] is not None:
                kwargs["callbacks"] += [callback]
            else:
                kwargs["callbacks"] = [callback]

            # training model
            model = original(*args, **kwargs)

            # If early_stopping_rounds is present, logging metrics at the best iteration
            # as extra metrics with the max step + 1.
            early_stopping_index = all_arg_names.index("early_stopping_rounds")
            early_stopping = (num_pos_args >= early_stopping_index + 1
                              or "early_stopping_rounds" in kwargs)
            if early_stopping:
                extra_step = len(eval_results)
                metrics_logger.record_metrics(
                    {"stopped_iteration": extra_step - 1})
                metrics_logger.record_metrics(
                    {"best_iteration": model.best_iteration})
                metrics_logger.record_metrics(
                    eval_results[model.best_iteration], extra_step)

        # logging feature importance as artifacts.
        for imp_type in importance_types:
            imp = None
            try:
                imp = model.get_score(importance_type=imp_type)
                features, importance = zip(*imp.items())
                log_feature_importance_plot(features, importance, imp_type)
            except Exception:
                _logger.exception(
                    "Failed to log feature importance plot. XGBoost autologging "
                    "will ignore the failure and continue. Exception: ")

            if imp is not None:
                tmpdir = tempfile.mkdtemp()
                try:
                    filepath = os.path.join(
                        tmpdir, "feature_importance_{}.json".format(imp_type))
                    with open(filepath, "w") as f:
                        json.dump(imp, f)
                    try_mlflow_log(mlflow.log_artifact, filepath)
                finally:
                    shutil.rmtree(tmpdir)

        # dtrain must exist as the original train function already ran successfully
        dtrain = args[1] if len(args) > 1 else kwargs.get("dtrain")

        # it is possible that the dataset was constructed before the patched
        #   constructor was applied, so we cannot assume the input_example_info exists
        input_example_info = getattr(dtrain, "input_example_info", None)

        def get_input_example():
            if input_example_info is None:
                raise Exception(ENSURE_AUTOLOGGING_ENABLED_TEXT)
            if input_example_info.error_msg is not None:
                raise Exception(input_example_info.error_msg)
            return input_example_info.input_example

        def infer_model_signature(input_example):
            model_output = model.predict(xgboost.DMatrix(input_example))
            model_signature = infer_signature(input_example, model_output)
            return model_signature

        # Only log the model if the autolog() param log_models is set to True.
        if log_models:
            # Will only resolve `input_example` and `signature` if `log_models` is `True`.
            input_example, signature = resolve_input_example_and_signature(
                get_input_example,
                infer_model_signature,
                log_input_examples,
                log_model_signatures,
                _logger,
            )

            try_mlflow_log(
                log_model,
                model,
                artifact_path="model",
                signature=signature,
                input_example=input_example,
            )

        return model
Example #23
    def train(*args, **kwargs):

        def record_eval_results(eval_results):
            """
            Create a callback function that records evaluation results.
            """
            def callback(env):
                res = {}
                for data_name, eval_name, value, _ in env.evaluation_result_list:
                    key = data_name + '-' + eval_name
                    res[key] = value

                eval_results.append(res)
            return callback

        def log_feature_importance_plot(features, importance, importance_type):
            """
            Log feature importance plot.
            """
            import matplotlib.pyplot as plt

            indices = np.argsort(importance)
            features = np.array(features)[indices]
            importance = importance[indices]
            num_features = len(features)

            # If num_features > 10, increase the figure height to prevent the plot
            # from being too dense.
            w, h = [6.4, 4.8]  # matplotlib's default figure size
            h = h + 0.1 * num_features if num_features > 10 else h
            fig, ax = plt.subplots(figsize=(w, h))

            yloc = np.arange(num_features)
            ax.barh(yloc, importance, align='center', height=0.5)
            ax.set_yticks(yloc)
            ax.set_yticklabels(features)
            ax.set_xlabel('Importance')
            ax.set_title('Feature Importance ({})'.format(importance_type))
            fig.tight_layout()

            tmpdir = tempfile.mkdtemp()
            try:
                # pylint: disable=undefined-loop-variable
                filepath = os.path.join(tmpdir, 'feature_importance_{}.png'.format(imp_type))
                fig.savefig(filepath)
                try_mlflow_log(mlflow.log_artifact, filepath)
            finally:
                plt.close(fig)
                shutil.rmtree(tmpdir)

        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            auto_end_run = True
        else:
            auto_end_run = False

        original = gorilla.get_original_attribute(lightgbm, 'train')

        # logging booster params separately via mlflow.log_params to extract key/value pairs
        # and make it easier to compare them across runs.
        params = args[0] if len(args) > 0 else kwargs['params']
        try_mlflow_log(mlflow.log_params, params)

        unlogged_params = ['params', 'train_set', 'valid_sets', 'valid_names', 'fobj', 'feval',
                           'init_model', 'evals_result', 'learning_rates', 'callbacks']

        log_fn_args_as_params(original, args, kwargs, unlogged_params)

        all_arg_names = inspect.getargspec(original)[0]  # pylint: disable=W1505
        num_pos_args = len(args)

        # adding a callback that records evaluation results.
        eval_results = []
        callbacks_index = all_arg_names.index('callbacks')
        callback = record_eval_results(eval_results)
        if num_pos_args >= callbacks_index + 1:
            tmp_list = list(args)
            tmp_list[callbacks_index] += [callback]
            args = tuple(tmp_list)
        elif 'callbacks' in kwargs and kwargs['callbacks'] is not None:
            kwargs['callbacks'] += [callback]
        else:
            kwargs['callbacks'] = [callback]

        # training model
        model = original(*args, **kwargs)

        # logging metrics on each iteration.
        for idx, metrics in enumerate(eval_results):
            try_mlflow_log(mlflow.log_metrics, metrics, step=idx)

        # If early_stopping_rounds is present, logging metrics at the best iteration
        # as extra metrics with the max step + 1.
        early_stopping_index = all_arg_names.index('early_stopping_rounds')
        early_stopping = (num_pos_args >= early_stopping_index + 1 or
                          'early_stopping_rounds' in kwargs)
        if early_stopping:
            extra_step = len(eval_results)
            try_mlflow_log(mlflow.log_metric, 'stopped_iteration', len(eval_results))
            # best_iteration is set even if training does not stop early.
            try_mlflow_log(mlflow.log_metric, 'best_iteration', model.best_iteration)
            # iteration starts from 1 in LightGBM.
            try_mlflow_log(mlflow.log_metrics, eval_results[model.best_iteration - 1],
                           step=extra_step)

        # logging feature importance as artifacts.
        for imp_type in ['split', 'gain']:
            features = model.feature_name()
            importance = model.feature_importance(importance_type=imp_type)
            try:
                log_feature_importance_plot(features, importance, imp_type)
            except Exception:  # pylint: disable=broad-except
                _logger.exception('Failed to log feature importance plot. LightGBM autologging '
                                  'will ignore the failure and continue. Exception: ')

            imp = {ft: imp for ft, imp in zip(features, importance.tolist())}
            tmpdir = tempfile.mkdtemp()
            try:
                filepath = os.path.join(tmpdir, 'feature_importance_{}.json'.format(imp_type))
                with open(filepath, 'w') as f:
                    json.dump(imp, f, indent=2)
                try_mlflow_log(mlflow.log_artifact, filepath)
            finally:
                shutil.rmtree(tmpdir)

        try_mlflow_log(log_model, model, artifact_path='model')

        if auto_end_run:
            try_mlflow_log(mlflow.end_run)
        return model
Example #24
    def train(original, *args, **kwargs):
        def record_eval_results(eval_results, metrics_logger):
            """
            Create a callback function that records evaluation results.
            """
            # TODO: Remove `replace("SNAPSHOT", "dev")` once the following issue is addressed:
            #       https://github.com/dmlc/xgboost/issues/6984
            if Version(xgboost.__version__.replace("SNAPSHOT", "dev")) >= Version("1.3.0"):
                # In xgboost >= 1.3.0, user-defined callbacks should inherit
                # `xgboost.callback.TrainingCallback`:
                # https://xgboost.readthedocs.io/en/latest/python/callbacks.html#defining-your-own-callback  # noqa

                class Callback(
                    xgboost.callback.TrainingCallback, metaclass=ExceptionSafeAbstractClass,
                ):
                    def after_iteration(self, model, epoch, evals_log):
                        """
                        Run after each iteration. Return True when training should stop.
                        """
                        # `evals_log` is a nested dict (type: Dict[str, Dict[str, List[float]]])
                        # that looks like this:
                        # {
                        #   "train": {
                        #     "auc": [0.5, 0.6, 0.7, ...],
                        #     ...
                        #   },
                        #   ...
                        # }
                        evaluation_result_dict = {}
                        for data_name, metric_dict in evals_log.items():
                            for metric_name, metric_values_on_each_iter in metric_dict.items():
                                key = "{}-{}".format(data_name, metric_name)
                                # The last element in `metric_values_on_each_iter` corresponds to
                                # the metric on the current iteration
                                evaluation_result_dict[key] = metric_values_on_each_iter[-1]

                        metrics_logger.record_metrics(evaluation_result_dict, epoch)
                        eval_results.append(evaluation_result_dict)

                        # Return `False` to indicate training should not stop
                        return False

                return Callback()

            else:

                @exception_safe_function
                def callback(env):
                    metrics_logger.record_metrics(dict(env.evaluation_result_list), env.iteration)
                    eval_results.append(dict(env.evaluation_result_list))

                return callback

        def log_feature_importance_plot(features, importance, importance_type):
            """
            Log feature importance plot.
            """
            import matplotlib.pyplot as plt

            features = np.array(features)
            importance = np.array(importance)
            indices = np.argsort(importance)
            features = features[indices]
            importance = importance[indices]
            num_features = len(features)

            # If num_features > 10, increase the figure height to prevent the plot
            # from being too dense.
            w, h = [6.4, 4.8]  # matplotlib's default figure size
            h = h + 0.1 * num_features if num_features > 10 else h
            fig, ax = plt.subplots(figsize=(w, h))

            yloc = np.arange(num_features)
            ax.barh(yloc, importance, align="center", height=0.5)
            ax.set_yticks(yloc)
            ax.set_yticklabels(features)
            ax.set_xlabel("Importance")
            ax.set_title("Feature Importance ({})".format(importance_type))
            fig.tight_layout()

            tmpdir = tempfile.mkdtemp()
            try:
                # pylint: disable=undefined-loop-variable
                filepath = os.path.join(tmpdir, "feature_importance_{}.png".format(imp_type))
                fig.savefig(filepath)
                try_mlflow_log(mlflow.log_artifact, filepath)
            finally:
                plt.close(fig)
                shutil.rmtree(tmpdir)

        # logging booster params separately via mlflow.log_params to extract key/value pairs
        # and make it easier to compare them across runs.
        params = args[0] if len(args) > 0 else kwargs["params"]
        try_mlflow_log(mlflow.log_params, params)

        unlogged_params = [
            "params",
            "dtrain",
            "evals",
            "obj",
            "feval",
            "evals_result",
            "xgb_model",
            "callbacks",
            "learning_rates",
        ]
        log_fn_args_as_params(original, args, kwargs, unlogged_params)

        all_arg_names = inspect.getargspec(original)[0]  # pylint: disable=W1505
        num_pos_args = len(args)

        # adding a callback that records evaluation results.
        eval_results = []
        callbacks_index = all_arg_names.index("callbacks")

        run_id = mlflow.active_run().info.run_id
        with batch_metrics_logger(run_id) as metrics_logger:
            callback = record_eval_results(eval_results, metrics_logger)
            if num_pos_args >= callbacks_index + 1:
                tmp_list = list(args)
                tmp_list[callbacks_index] += [callback]
                args = tuple(tmp_list)
            elif "callbacks" in kwargs and kwargs["callbacks"] is not None:
                kwargs["callbacks"] += [callback]
            else:
                kwargs["callbacks"] = [callback]

            # training model
            model = original(*args, **kwargs)

            # If early_stopping_rounds is present, logging metrics at the best iteration
            # as extra metrics with the max step + 1.
            early_stopping_index = all_arg_names.index("early_stopping_rounds")
            early_stopping = (
                num_pos_args >= early_stopping_index + 1 or "early_stopping_rounds" in kwargs
            )
            if early_stopping:
                extra_step = len(eval_results)
                metrics_logger.record_metrics({"stopped_iteration": extra_step - 1})
                metrics_logger.record_metrics({"best_iteration": model.best_iteration})
                metrics_logger.record_metrics(eval_results[model.best_iteration], extra_step)

        # logging feature importance as artifacts.
        for imp_type in importance_types:
            imp = None
            try:
                imp = model.get_score(importance_type=imp_type)
                features, importance = zip(*imp.items())
                log_feature_importance_plot(features, importance, imp_type)
            except Exception:
                _logger.exception(
                    "Failed to log feature importance plot. XGBoost autologging "
                    "will ignore the failure and continue. Exception: "
                )

            if imp is not None:
                tmpdir = tempfile.mkdtemp()
                try:
                    filepath = os.path.join(tmpdir, "feature_importance_{}.json".format(imp_type))
                    with open(filepath, "w") as f:
                        json.dump(imp, f)
                    try_mlflow_log(mlflow.log_artifact, filepath)
                finally:
                    shutil.rmtree(tmpdir)

        # dtrain must exist as the original train function already ran successfully
        dtrain = args[1] if len(args) > 1 else kwargs.get("dtrain")

        # it is possible that the dataset was constructed before the patched
        #   constructor was applied, so we cannot assume the input_example_info exists
        input_example_info = getattr(dtrain, "input_example_info", None)

        def get_input_example():
            if input_example_info is None:
                raise Exception(ENSURE_AUTOLOGGING_ENABLED_TEXT)
            if input_example_info.error_msg is not None:
                raise Exception(input_example_info.error_msg)
            return input_example_info.input_example

        def infer_model_signature(input_example):
            model_output = model.predict(xgboost.DMatrix(input_example))
            model_signature = infer_signature(input_example, model_output)
            return model_signature

        # Only log the model if the autolog() param log_models is set to True.
        if log_models:
            # Will only resolve `input_example` and `signature` if `log_models` is `True`.
            input_example, signature = resolve_input_example_and_signature(
                get_input_example,
                infer_model_signature,
                log_input_examples,
                log_model_signatures,
                _logger,
            )

            try_mlflow_log(
                log_model,
                model,
                artifact_path="model",
                signature=signature,
                input_example=input_example,
            )

        return model
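
The version check in this last example reflects the xgboost >= 1.3 callback API, where callbacks subclass xgboost.callback.TrainingCallback. Here is that branch reduced to its essentials, without the MLflow plumbing (synthetic data for illustration):

    import numpy as np
    import xgboost

    class RecordingCallback(xgboost.callback.TrainingCallback):
        def __init__(self):
            self.eval_results = []

        def after_iteration(self, model, epoch, evals_log):
            # evals_log: {"train": {"rmse": [v0, v1, ...]}, ...}; keep the latest value.
            latest = {
                "{}-{}".format(data, metric): values[-1]
                for data, metrics in evals_log.items()
                for metric, values in metrics.items()
            }
            self.eval_results.append(latest)
            return False  # False means "keep training"

    X, y = np.random.rand(100, 4), np.random.rand(100)
    dtrain = xgboost.DMatrix(X, label=y)
    cb = RecordingCallback()
    xgboost.train(
        {"objective": "reg:squarederror"},
        dtrain,
        num_boost_round=5,
        evals=[(dtrain, "train")],
        callbacks=[cb],
    )
    print(cb.eval_results)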