def test_autolog_obeys_disabled():
    """The sklearn autologging config must track the most recent autolog flag values.

    Exercises ``disable`` and ``disable_for_unsupported_versions`` through both the global
    ``mlflow.autolog`` entry point (followed by re-notifying that sklearn is loaded) and
    the flavor-specific ``mlflow.sklearn.autolog`` entry point, checking the stored sklearn
    config value after each call.
    """
    from mlflow.utils.autologging_utils import AUTOLOGGING_INTEGRATIONS

    def announce_sklearn_loaded():
        # Tell the import-hook machinery that sklearn is loaded, so any post-import
        # hooks registered for it run again.
        mlflow.utils.import_hooks.notify_module_loaded(sklearn)

    def sklearn_setting(key):
        return get_autologging_config("sklearn", key)

    # ``disable`` through the global entry point.
    mlflow.autolog(disable=True)
    announce_sklearn_loaded()
    assert sklearn_setting("disable")

    mlflow.autolog()
    announce_sklearn_loaded()
    mlflow.autolog(disable=True)
    announce_sklearn_loaded()
    assert sklearn_setting("disable")

    mlflow.autolog(disable=False)
    announce_sklearn_loaded()
    assert not sklearn_setting("disable")

    # ``disable`` through the sklearn-specific entry point.
    mlflow.sklearn.autolog(disable=True)
    assert sklearn_setting("disable")

    # ``disable_for_unsupported_versions``, starting from an empty integration registry.
    AUTOLOGGING_INTEGRATIONS.clear()
    for flag in (False, True):
        mlflow.autolog(disable_for_unsupported_versions=flag)
        announce_sklearn_loaded()
        assert bool(sklearn_setting("disable_for_unsupported_versions")) == flag

    for flag in (False, True):
        mlflow.sklearn.autolog(disable_for_unsupported_versions=flag)
        assert bool(sklearn_setting("disable_for_unsupported_versions")) == flag


def reset_global_states():
    """Reset global autologging state before and after a test (generator fixture).

    Before the ``yield`` and again after it, this empties every per-integration config
    dict in ``AUTOLOGGING_INTEGRATIONS``. It also removes the post-import hook registered
    under each library's module name in ``library_to_mlflow_module``. Finally it asserts
    that both registries are empty.

    NOTE(review): no ``@pytest.fixture`` decorator is visible on this definition; confirm
    it is registered as a fixture (e.g. ``autouse=True``) where it is defined.
    """
    from mlflow.utils.autologging_utils import AUTOLOGGING_INTEGRATIONS

    def _clear_autologging_state():
        # Empty each integration's config in place; the integration keys themselves stay.
        for integration_config in AUTOLOGGING_INTEGRATIONS.values():
            integration_config.clear()

        # Hooks are keyed by module name. A library may have no hook registered, so a
        # missing key is expected and ignored (and only that case is ignored).
        post_import_hooks = mlflow.utils.import_hooks._post_import_hooks
        for library in library_to_mlflow_module:
            post_import_hooks.pop(library.__name__, None)

        assert all(v == {} for v in AUTOLOGGING_INTEGRATIONS.values())
        assert mlflow.utils.import_hooks._post_import_hooks == {}

    _clear_autologging_state()

    yield

    _clear_autologging_state()
# Example #3
# 0
def test_autolog_globally_configured_flag_set_correctly():
    """``globally_configured`` is set by ``mlflow.autolog`` and cleared by flavor autolog.

    After ``mlflow.autolog()``, the sklearn, spark and pyspark.ml integration configs must
    each hold a truthy ``globally_configured`` entry. After each flavor's own ``autolog()``
    has been called, that entry must be absent from all three.
    """
    from mlflow.utils.autologging_utils import AUTOLOGGING_INTEGRATIONS

    AUTOLOGGING_INTEGRATIONS.clear()

    # Presumably the libraries must be imported for mlflow.autolog() to configure their
    # integrations; the imported names themselves are never used.
    import sklearn  # pylint: disable=unused-import,unused-variable
    import pyspark  # pylint: disable=unused-import,unused-variable
    import pyspark.ml  # pylint: disable=unused-import,unused-variable

    checked_integrations = ("sklearn", "spark", "pyspark.ml")

    # Global configuration marks every integration as globally configured.
    mlflow.autolog()
    for name in checked_integrations:
        assert AUTOLOGGING_INTEGRATIONS[name]["globally_configured"]

    # Flavor-specific configuration replaces the global one.
    mlflow.sklearn.autolog()
    mlflow.spark.autolog()
    mlflow.pyspark.ml.autolog()

    for name in checked_integrations:
        assert "globally_configured" not in AUTOLOGGING_INTEGRATIONS[name]
# Example #4
# 0
def test_disable_for_unsupported_versions_warning_sklearn_integration():
    """Check sklearn's "unsupported version" warning and "enabled" info logging.

    ``sklearn.__version__`` is patched to two values.

    The test treats ``0.20.3`` as supported. For both ``disable_for_unsupported_versions``
    values, neither ``mlflow.autolog`` nor ``mlflow.sklearn.autolog`` may emit the sklearn
    warning. ``mlflow.autolog`` must also log that sklearn autologging was enabled.

    The test treats ``0.20.2`` as unsupported:

    * With ``disable_for_unsupported_versions=True``, ``mlflow.autolog`` logs neither the
      sklearn warning nor the "enabled" message, and ``mlflow.sklearn.autolog`` logs no
      warning at all.
    * With ``False``, ``mlflow.autolog`` logs both, and ``mlflow.sklearn.autolog`` logs
      exactly one warning, which is the sklearn one.
    """
    from contextlib import contextmanager

    # Imported locally, like the sibling tests, so this test does not rely on a
    # module-level binding of the registry.
    from mlflow.utils.autologging_utils import AUTOLOGGING_INTEGRATIONS

    log_warn_fn_name = "mlflow.utils.autologging_utils._logger.warning"
    log_info_fn_name = "mlflow.tracking.fluent._logger.info"

    def _is_sklearn_log_call(call_args, message_fragment):
        # ``call_args[0]`` holds a logger call's positional args: the format string
        # followed by its %-arguments. The test expects the first of those to be the
        # flavor name. A call with too few positional args or a non-string message is
        # not a match, rather than raising IndexError/TypeError mid-assertion.
        positional = call_args[0]
        return (
            len(positional) >= 2
            and isinstance(positional[0], str)
            and message_fragment in positional[0]
            and positional[1] == "sklearn"
        )

    def is_sklearn_warning_fired(log_warn_fn_args):
        return _is_sklearn_log_call(log_warn_fn_args, "You are using an unsupported version of")

    def is_sklearn_autolog_enabled_info_fired(log_info_fn_args):
        return _is_sklearn_log_call(log_info_fn_args, "Autologging successfully enabled for ")

    @contextmanager
    def patched_warn_and_info_loggers():
        # Fresh mocks per step, so each step only inspects the log calls it triggered.
        with mock.patch(log_warn_fn_name) as log_warn_fn, mock.patch(
            log_info_fn_name
        ) as log_info_fn:
            yield log_warn_fn, log_info_fn

    # Supported version: same expectations for both flag values.
    with mock.patch("sklearn.__version__", "0.20.3"):
        AUTOLOGGING_INTEGRATIONS.clear()
        for disable in (True, False):
            with patched_warn_and_info_loggers() as (log_warn_fn, log_info_fn):
                mlflow.autolog(disable_for_unsupported_versions=disable)
                assert not any(
                    is_sklearn_warning_fired(args) for args in log_warn_fn.call_args_list
                )
                assert any(
                    is_sklearn_autolog_enabled_info_fired(args)
                    for args in log_info_fn.call_args_list
                )

        for disable in (True, False):
            with mock.patch(log_warn_fn_name) as log_warn_fn:
                mlflow.sklearn.autolog(disable_for_unsupported_versions=disable)
                log_warn_fn.assert_not_called()

    # Unsupported version: expectations depend on the flag value.
    with mock.patch("sklearn.__version__", "0.20.2"):
        AUTOLOGGING_INTEGRATIONS.clear()
        with patched_warn_and_info_loggers() as (log_warn_fn, log_info_fn):
            mlflow.autolog(disable_for_unsupported_versions=True)
            assert not any(is_sklearn_warning_fired(args) for args in log_warn_fn.call_args_list)
            assert not any(
                is_sklearn_autolog_enabled_info_fired(args) for args in log_info_fn.call_args_list
            )
        with patched_warn_and_info_loggers() as (log_warn_fn, log_info_fn):
            mlflow.autolog(disable_for_unsupported_versions=False)
            assert any(is_sklearn_warning_fired(args) for args in log_warn_fn.call_args_list)
            assert any(
                is_sklearn_autolog_enabled_info_fired(args) for args in log_info_fn.call_args_list
            )
        with mock.patch(log_warn_fn_name) as log_warn_fn:
            mlflow.sklearn.autolog(disable_for_unsupported_versions=True)
            log_warn_fn.assert_not_called()
        with mock.patch(log_warn_fn_name) as log_warn_fn:
            mlflow.sklearn.autolog(disable_for_unsupported_versions=False)
            assert log_warn_fn.call_count == 1 and is_sklearn_warning_fired(log_warn_fn.call_args)