def test_autolog_obeys_disabled():
    """``disable`` / ``disable_for_unsupported_versions`` passed to ``mlflow.autolog`` or
    ``mlflow.sklearn.autolog`` must be reflected in the stored sklearn autologging config.

    After every ``mlflow.autolog`` call, ``notify_module_loaded(sklearn)`` is invoked so the
    sklearn setup registered by ``mlflow.autolog`` actually runs (presumably through the
    post-import hooks in ``mlflow.utils.import_hooks`` -- confirm against that module).
    """
    from mlflow.utils.autologging_utils import AUTOLOGGING_INTEGRATIONS

    def autolog_and_notify(**autolog_kwargs):
        # Configure global autologging, then simulate sklearn having been imported.
        mlflow.autolog(**autolog_kwargs)
        mlflow.utils.import_hooks.notify_module_loaded(sklearn)

    # --- ``disable`` flag ---
    autolog_and_notify(disable=True)
    assert get_autologging_config("sklearn", "disable")

    autolog_and_notify()
    # Re-enabling with default arguments must clear ``disable`` (previously unchecked).
    assert not get_autologging_config("sklearn", "disable")
    autolog_and_notify(disable=True)
    assert get_autologging_config("sklearn", "disable")

    autolog_and_notify(disable=False)
    assert not get_autologging_config("sklearn", "disable")
    # A flavor-specific call overrides the globally configured value.
    mlflow.sklearn.autolog(disable=True)
    assert get_autologging_config("sklearn", "disable")

    # --- ``disable_for_unsupported_versions`` flag, starting from an empty config ---
    AUTOLOGGING_INTEGRATIONS.clear()
    autolog_and_notify(disable_for_unsupported_versions=False)
    assert not get_autologging_config("sklearn", "disable_for_unsupported_versions")
    autolog_and_notify(disable_for_unsupported_versions=True)
    assert get_autologging_config("sklearn", "disable_for_unsupported_versions")

    mlflow.sklearn.autolog(disable_for_unsupported_versions=False)
    assert not get_autologging_config("sklearn", "disable_for_unsupported_versions")
    mlflow.sklearn.autolog(disable_for_unsupported_versions=True)
    assert get_autologging_config("sklearn", "disable_for_unsupported_versions")
def reset_global_states():
    """Yield-style fixture body that wipes autologging state around each test.

    Before and after the ``yield`` it empties every per-integration config held in
    ``AUTOLOGGING_INTEGRATIONS``, drops the post-import hook registered under each library's
    ``__name__`` from ``library_to_mlflow_module``, and asserts both registries are empty.
    """
    from mlflow.utils.autologging_utils import AUTOLOGGING_INTEGRATIONS

    def wipe_and_verify():
        for integration_config in AUTOLOGGING_INTEGRATIONS.values():
            integration_config.clear()

        for library in library_to_mlflow_module:
            try:
                del mlflow.utils.import_hooks._post_import_hooks[library.__name__]
            except Exception:
                # Best-effort: a library with no registered hook (KeyError) or a key without
                # ``__name__`` (AttributeError) is simply skipped.
                pass

        assert all(v == {} for v in AUTOLOGGING_INTEGRATIONS.values())
        assert mlflow.utils.import_hooks._post_import_hooks == {}

    wipe_and_verify()
    yield
    wipe_and_verify()
def test_autolog_globally_configured_flag_set_correctly():
    """``mlflow.autolog()`` flags integrations as globally configured; flavor calls unflag them.

    After ``mlflow.autolog()``, the sklearn, spark and pyspark.ml entries of
    ``AUTOLOGGING_INTEGRATIONS`` must hold a truthy ``"globally_configured"`` value. After each
    flavor's own ``autolog()`` is called, that key must be gone from all three entries.
    """
    from mlflow.utils.autologging_utils import AUTOLOGGING_INTEGRATIONS

    AUTOLOGGING_INTEGRATIONS.clear()
    # The bound names are unused (see the pylint pragmas); presumably these imports make sure
    # the modules are loaded before ``mlflow.autolog()`` runs.
    import sklearn  # pylint: disable=unused-import,unused-variable
    import pyspark  # pylint: disable=unused-import,unused-variable
    import pyspark.ml  # pylint: disable=unused-import,unused-variable

    flavor_names = ("sklearn", "spark", "pyspark.ml")

    mlflow.autolog()
    assert all(
        AUTOLOGGING_INTEGRATIONS[name]["globally_configured"] for name in flavor_names
    )

    mlflow.sklearn.autolog()
    mlflow.spark.autolog()
    mlflow.pyspark.ml.autolog()
    assert not any(
        "globally_configured" in AUTOLOGGING_INTEGRATIONS[name] for name in flavor_names
    )
def test_disable_for_unsupported_versions_warning_sklearn_integration():
    """Check sklearn's unsupported-version warning and "autologging enabled" info log.

    ``sklearn.__version__`` is patched first to "0.20.3" and then to "0.20.2". The assertions
    expect the former to be treated as supported and the latter as unsupported. For each
    version, ``mlflow.autolog`` and ``mlflow.sklearn.autolog`` are exercised with
    ``disable_for_unsupported_versions`` set to True and to False.
    """
    # Imported locally, like the other tests in this file, instead of relying on a
    # module-level binding of AUTOLOGGING_INTEGRATIONS.
    from mlflow.utils.autologging_utils import AUTOLOGGING_INTEGRATIONS

    log_warn_fn_name = "mlflow.utils.autologging_utils._logger.warning"
    log_info_fn_name = "mlflow.tracking.fluent._logger.info"

    def is_sklearn_warning_fired(log_warn_fn_args):
        # ``log_warn_fn_args`` is a mock ``call``; index 0 holds its positional arguments.
        return (
            "You are using an unsupported version of" in log_warn_fn_args[0][0]
            and log_warn_fn_args[0][1] == "sklearn"
        )

    def is_sklearn_autolog_enabled_info_fired(log_info_fn_args):
        return (
            "Autologging successfully enabled for " in log_info_fn_args[0][0]
            and log_info_fn_args[0][1] == "sklearn"
        )

    def run_global_autolog(**autolog_kwargs):
        """Call ``mlflow.autolog`` with both loggers mocked; return (warn calls, info calls)."""
        with mock.patch(log_warn_fn_name) as log_warn_fn, mock.patch(
            log_info_fn_name
        ) as log_info_fn:
            mlflow.autolog(**autolog_kwargs)
        return log_warn_fn.call_args_list, log_info_fn.call_args_list

    def run_sklearn_autolog(**autolog_kwargs):
        """Call ``mlflow.sklearn.autolog`` with the warning logger mocked; return that mock."""
        with mock.patch(log_warn_fn_name) as log_warn_fn:
            mlflow.sklearn.autolog(**autolog_kwargs)
        return log_warn_fn

    with mock.patch("sklearn.__version__", "0.20.3"):
        AUTOLOGGING_INTEGRATIONS.clear()

        # Supported version: never warn, always report autologging as enabled.
        for disable_flag in (True, False):
            warn_calls, info_calls = run_global_autolog(
                disable_for_unsupported_versions=disable_flag
            )
            assert not any(is_sklearn_warning_fired(args) for args in warn_calls)
            assert any(is_sklearn_autolog_enabled_info_fired(args) for args in info_calls)

        run_sklearn_autolog(disable_for_unsupported_versions=True).assert_not_called()
        run_sklearn_autolog(disable_for_unsupported_versions=False).assert_not_called()

    with mock.patch("sklearn.__version__", "0.20.2"):
        AUTOLOGGING_INTEGRATIONS.clear()

        # Unsupported version with disable_for_unsupported_versions=True:
        # no warning, and autologging is not reported as enabled.
        warn_calls, info_calls = run_global_autolog(disable_for_unsupported_versions=True)
        assert not any(is_sklearn_warning_fired(args) for args in warn_calls)
        assert not any(is_sklearn_autolog_enabled_info_fired(args) for args in info_calls)

        # Unsupported version with disable_for_unsupported_versions=False:
        # the warning fires, and autologging is still reported as enabled.
        warn_calls, info_calls = run_global_autolog(disable_for_unsupported_versions=False)
        assert any(is_sklearn_warning_fired(args) for args in warn_calls)
        assert any(is_sklearn_autolog_enabled_info_fired(args) for args in info_calls)

        run_sklearn_autolog(disable_for_unsupported_versions=True).assert_not_called()

        # Separate asserts so a failure shows which condition broke.
        log_warn_fn = run_sklearn_autolog(disable_for_unsupported_versions=False)
        assert log_warn_fn.call_count == 1
        assert is_sklearn_warning_fired(log_warn_fn.call_args)