Beispiel #1
0
    def _log_pretraining_metadata(estimator, *args, **kwargs):  # pylint: disable=unused-argument
        """
        Log params and tags for a scikit-learn estimator before it is trained.

        Meant to be called from inside a patched scikit-learn training routine
        (e.g., `fit()`, `fit_transform()`, ...). Everything is logged through the
        fluent Tracking API, so an MLflow run is expected to be active already.

        :param estimator: The scikit-learn estimator whose metadata is logged.
        :param args: Positional arguments of the scikit-learn training routine
                     (not used by this function).
        :param kwargs: Keyword arguments of the scikit-learn training routine
                       (not used by this function).
        """
        # For meta estimators such as pipelines, the params of child estimators
        # are worth recording, so params are fetched deeply. Parameter search
        # estimators are the exception: their children only seed the search, and
        # logging those initial, untuned params would be misleading.
        log_deep = not _is_parameter_search_estimator(estimator)
        all_params = estimator.get_params(deep=log_deep)

        # Split the params into batches of at most MAX_PARAMS_TAGS_PER_BATCH
        # entries so that no single call exceeds the log_batch API limit.
        param_batches = _chunk_dict(all_params, chunk_size=MAX_PARAMS_TAGS_PER_BATCH)
        for param_batch in param_batches:
            try_mlflow_log(
                mlflow.log_params,
                _truncate_dict(param_batch, MAX_ENTITY_KEY_LENGTH, MAX_PARAM_VAL_LENGTH),
            )

        estimator_tags = _get_estimator_info_tags(estimator)
        try_mlflow_log(mlflow.set_tags, estimator_tags)
Beispiel #2
0
    def fit_mlflow(self, func_name, *args, **kwargs):
        """
        Run the original (unpatched) training method ``func_name`` on ``self`` and
        record params, tags, a training score (when available) and the model.

        If no MLflow run is active, one is started here. That run is ended before
        returning, or ended with status FAILED if the training call raises. The
        training call itself is not wrapped, so its exceptions propagate to the
        caller.

        :param func_name: Name of the training method on ``self`` (for example
                          ``"fit"``) whose original implementation is invoked.
        :param args: Positional arguments forwarded to the original method.
        :param kwargs: Keyword arguments forwarded to the original method.
        :return: Whatever the original training method returns.
        """
        # Resolve the unpatched method before touching MLflow state. If this
        # lookup fails, no run has been started yet, so none is left dangling.
        original_fit = gorilla.get_original_attribute(self, func_name)

        should_start_run = mlflow.active_run() is None
        if should_start_run:
            try_mlflow_log(mlflow.start_run)

        # TODO: We should not log nested estimator parameters for
        # parameter search estimators (GridSearchCV, RandomizedSearchCV)

        # Chunk and truncate model parameters to avoid hitting the log_batch API limit
        for chunk in _chunk_dict(self.get_params(deep=True),
                                 chunk_size=MAX_PARAMS_TAGS_PER_BATCH):
            truncated = _truncate_dict(chunk, MAX_ENTITY_KEY_LENGTH,
                                       MAX_PARAM_VAL_LENGTH)
            try_mlflow_log(mlflow.log_params, truncated)

        try_mlflow_log(
            mlflow.set_tags,
            {
                "estimator_name": self.__class__.__name__,
                "estimator_class": self.__class__.__module__ + "." + self.__class__.__name__,
            },
        )

        try:
            fit_output = original_fit(*args, **kwargs)
        except Exception:
            # Close the run only if this call opened it, and mark it FAILED. Then
            # re-raise with a bare ``raise`` so the original traceback is kept.
            if should_start_run:
                try_mlflow_log(mlflow.end_run,
                               RunStatus.to_string(RunStatus.FAILED))
            raise

        if hasattr(self, "score"):
            try:
                # NOTE(review): scoring args are derived from `self.fit`'s
                # signature even when `func_name` names a different method
                # (e.g. fit_transform) -- confirm this is intended.
                score_args = _get_args_for_score(self.score, self.fit, args,
                                                 kwargs)
                training_score = self.score(*score_args)
            except Exception as e:  # pylint: disable=broad-except
                # Scoring is best-effort: a failure only skips the metric.
                _logger.warning(
                    "%s failed. The 'training_score' metric will not be recorded. "
                    "Scoring error: %s",
                    self.score.__qualname__,
                    e,
                )
            else:
                try_mlflow_log(mlflow.log_metric, "training_score",
                               training_score)

        try_mlflow_log(log_model, self, artifact_path="model")

        if should_start_run:
            try_mlflow_log(mlflow.end_run)

        return fit_output
Beispiel #3
0
def truncate_dict(d):
    """
    Return ``_truncate_dict`` applied to *d* with the MLflow length limits
    MAX_ENTITY_KEY_LENGTH (keys) and MAX_PARAM_VAL_LENGTH (values).
    """
    key_limit = MAX_ENTITY_KEY_LENGTH
    value_limit = MAX_PARAM_VAL_LENGTH
    return _truncate_dict(d, key_limit, value_limit)