def patch_class_tree(klass):
    """
    Monkey-patch every auto-loggable method overridden anywhere below ``klass``.

    The subclasses of ``klass`` are collected via ``find_subclasses``; for each of
    them, every method listed in the auto-loggable table that the subclass
    overrides is replaced through ``safe_patch`` with ``manage_run=True``.

    :param klass: root in the class hierarchy to be analyzed and patched recursively
    """
    # TODO: add more autologgable methods here (e.g. fit_regularized, from_formula, etc)
    # See https://www.statsmodels.org/dev/api.html
    autologgable_methods = {"fit": wrapper_fit}
    subclasses = set(find_subclasses(klass))

    # First collect every (class, method name, wrapper) triple that needs a patch,
    # i.e. those classes which actually override an autologgable method. The
    # collection is completed before any patch is applied.
    pending_patches = []
    for subclass in subclasses:
        for method_name, wrapper in autologgable_methods.items():
            if overrides(subclass, method_name):
                pending_patches.append((subclass, method_name, wrapper))

    # Then apply the patches.
    for subclass, method_name, wrapper in pending_patches:
        safe_patch(FLAVOR_NAME, subclass, method_name, wrapper, manage_run=True)
def autolog2(disable=False, silent=False):
    """
    Emit one warning and one info-level event, then patch ``patch_destination.fn2``.

    Neither ``disable`` nor ``silent`` is read inside this body.
    """
    # Positional form of warn_explicit(message, category, filename, lineno).
    warnings.warn_explicit("warn_autolog2", Warning, mlflow.__file__, 5)
    logger.info("event_autolog2")
    safe_patch("integration2", patch_destination, "fn2", patch_impl2)
def test_safe_patch_manages_run_if_specified_and_sets_expected_run_tags(
    patch_destination, test_autologging_integration
):
    """
    With ``manage_run=True`` the patch must see an active run carrying the
    ``MLFLOW_AUTOLOGGING`` tag, and ``with_managed_run`` must be used exactly once.
    """
    tracking_client = MlflowClient()
    observed_run = None

    def patch_impl(original, *args, **kwargs):
        # Record whichever run is active while the patch body executes.
        nonlocal observed_run
        observed_run = mlflow.active_run()
        return original(*args, **kwargs)

    managed_run_patcher = mock.patch(
        "mlflow.utils.autologging_utils.safety.with_managed_run", wraps=with_managed_run
    )
    with managed_run_patcher as managed_run_mock:
        safe_patch(
            test_autologging_integration, patch_destination, "fn", patch_impl, manage_run=True
        )
        patch_destination.fn()

        assert managed_run_mock.call_count == 1
        assert observed_run is not None
        assert observed_run.info.run_id is not None

        run_tags = tracking_client.get_run(observed_run.info.run_id).data.tags
        assert run_tags[MLFLOW_AUTOLOGGING] == "test_integration"
def test_safe_patch_makes_expected_event_logging_calls_for_successful_patch_invocation(
    patch_destination,
    test_autologging_integration,
    mock_event_logger,
):
    """
    A successful patched call must log ``patch_start``, ``original_start``,
    ``original_success`` and ``patch_success`` in that order, each carrying the
    session, patched object, function name and arguments asserted below.
    """
    captured_session = None
    forwarded_kwargs = {}

    def patch_impl(original, *args, **kwargs):
        nonlocal captured_session, forwarded_kwargs
        # Add an extra keyword so the kwargs seen by `original` differ from the
        # kwargs the user passed to the patched function.
        kwargs["extra_func"] = exception_safe_function(lambda k: "foo")
        forwarded_kwargs = kwargs
        captured_session = _AutologgingSessionManager.active_session()
        original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    patch_destination.fn("a", 1, b=2)

    expected_order = ["patch_start", "original_start", "original_success", "patch_success"]
    assert [call.method for call in mock_event_logger.calls] == expected_order
    assert all(call.session == captured_session for call in mock_event_logger.calls)
    assert all(call.patch_obj == patch_destination for call in mock_event_logger.calls)
    assert all(call.function_name == "fn" for call in mock_event_logger.calls)

    patch_start, original_start, original_success, patch_success = mock_event_logger.calls

    # Patch-level events report the user's arguments ...
    assert patch_start.call_args == patch_success.call_args == ("a", 1)
    assert patch_start.call_kwargs == patch_success.call_kwargs == {"b": 2}
    # ... while original-level events report what the patch forwarded.
    assert original_start.call_args == original_success.call_args == ("a", 1)
    assert original_start.call_kwargs == original_success.call_kwargs == forwarded_kwargs

    assert patch_start.exception is original_start.exception is None
    assert patch_success.exception is original_success.exception is None
def autolog(disable=False):
    """
    Patch ``patch_destination.fn`` with a call-counting wrapper under the
    ``"test_respects_disable"`` integration name.

    ``disable`` is not read inside this body.
    """

    def patch_impl(original, *args, **kwargs):
        # NOTE(review): `patch_impl_call_count` is bound in an enclosing scope that
        # is not visible in this snippet -- confirm it is defined there.
        nonlocal patch_impl_call_count
        patch_impl_call_count += 1
        return original(*args, **kwargs)

    safe_patch("test_respects_disable", patch_destination, "fn", patch_impl)
def test_safe_patch_validates_autologging_runs_when_necessary_in_test_mode(
    patch_destination, test_autologging_integration
):
    """
    In test mode, a run created by the patch without the autologging tag must fail
    validation, whereas no validation is attempted when a user run already exists.
    """
    assert autologging_utils.is_testing()

    def no_tag_run_patch_impl(original, *args, **kwargs):
        # Starts a run directly, so no autologging tag is attached to it.
        with mlflow.start_run(nested=True):
            return original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", no_tag_run_patch_impl)

    validation_patcher = mock.patch(
        "mlflow.utils.autologging_utils.safety._validate_autologging_run",
        wraps=_validate_autologging_run,
    )
    with validation_patcher as validate_run_mock:
        expected_failure = pytest.raises(
            AssertionError, match="failed to set autologging tag with expected value"
        )
        with expected_failure:
            patch_destination.fn()

        assert validate_run_mock.call_count == 1

        validate_run_mock.reset_mock()

        with mlflow.start_run(nested=True):
            # If a user-generated run existed prior to the autologged training session,
            # we expect that safe patch will not attempt to validate it
            patch_destination.fn()

        assert not validate_run_mock.called
def test_safe_patch_forwards_expected_arguments_to_class_based_patch(
    patch_destination, test_autologging_integration
):
    """
    Keyword arguments given to the patched function must reach the
    ``_patch_implementation`` of a ``PatchFunction`` subclass unchanged.
    """
    received_foo = None
    received_bar = None

    class TestPatch(PatchFunction):
        def _patch_implementation(self, original, foo, bar=10):  # pylint: disable=arguments-differ
            nonlocal received_foo, received_bar
            received_foo = foo
            received_bar = bar

        def _on_exception(self, exception):
            pass

    safe_patch(test_autologging_integration, patch_destination, "fn", TestPatch)

    call_patcher = mock.patch(
        "mlflow.utils.autologging_utils.PatchFunction.call", wraps=TestPatch.call
    )
    with call_patcher as call_mock:
        patch_destination.fn(foo=7, bar=11)
        assert call_mock.call_count == 1
        assert received_foo == 7
        assert received_bar == 11
def test_safe_patch_propagates_exceptions_raised_from_original_function(
    patch_destination, test_autologging_integration
):
    """
    An exception raised by the original function must surface to the caller as
    the very same exception object, after the patch implementation has run.
    """
    exc_to_throw = Exception("Bad original function")

    def original(*args, **kwargs):
        raise exc_to_throw

    patch_destination.fn = original

    patch_was_invoked = False

    def patch_impl(original, *args, **kwargs):
        nonlocal patch_was_invoked
        patch_was_invoked = True
        return original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    with pytest.raises(Exception) as exc:
        patch_destination.fn()

    # Checked after the `raises` block so the assertions actually execute.
    assert exc.value == exc_to_throw
    assert patch_was_invoked
def test_session_manager_exits_session_after_patch_executes(
    patch_destination, test_autologging_integration
):
    """
    A session must be active while the patch runs and must be gone once the
    patched call has returned.
    """

    def patch_fn(original):
        session_during_patch = _AutologgingSessionManager.active_session()
        assert session_during_patch is not None

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_fn)
    patch_destination.fn()

    session_after_patch = _AutologgingSessionManager.active_session()
    assert session_after_patch is None
def autolog(disable=False, exclusive=False, silent=False):
    """
    Patch ``fn`` and ``new_fn`` on ``patch_obj`` under the
    ``"test_respects_exclusive"`` integration name.

    None of ``disable``, ``exclusive`` or ``silent`` is read inside this body.
    """

    def patch_impl(original, *args, **kwargs):
        # Only counts invocations; `original` is not called from here.
        # NOTE(review): `patch_impl_call_count` is bound in an enclosing scope that
        # is not visible in this snippet -- confirm it is defined there.
        nonlocal patch_impl_call_count
        patch_impl_call_count += 1

    def new_fn_patch(original, *args, **kwargs):
        # Intentionally a no-op patch.
        pass

    safe_patch("test_respects_exclusive", patch_obj, "fn", patch_impl)
    safe_patch("test_respects_exclusive", patch_obj, "new_fn", new_fn_patch)
def autolog(
    log_models=True,
    disable=False,
    exclusive=False,
    disable_for_unsupported_versions=False,
    silent=False,
    registered_model_name=None,
):  # pylint: disable=unused-argument
    """
    Enables (or disables) and configures autologging from Gluon to MLflow.
    Logs loss and any other metrics specified in the fit
    function, and optimizer data as parameters. Model checkpoints
    are logged as artifacts to a 'models' directory.

    :param log_models: If ``True``, trained models are logged as MLflow model artifacts.
                       If ``False``, trained models are not logged.
    :param disable: If ``True``, disables the MXNet Gluon autologging integration. If ``False``,
                    enables the MXNet Gluon autologging integration.
    :param exclusive: If ``True``, autologged content is not logged to user-created fluent runs.
                      If ``False``, autologged content is logged to the active fluent run,
                      which may be user-created.
    :param disable_for_unsupported_versions: If ``True``, disable autologging for versions of
                      gluon that have not been tested against this version of the MLflow client
                      or are incompatible.
    :param silent: If ``True``, suppress all event logs and warnings from MLflow during MXNet Gluon
                   autologging. If ``False``, show all events and warnings during MXNet Gluon
                   autologging.
    :param registered_model_name: If given, each time a model is trained, it is registered as a
                                  new model version of the registered model with this name.
                                  The registered model is created if it does not already exist.
    """
    from mxnet.gluon.contrib.estimator import Estimator
    from mlflow.gluon._autolog import __MLflowGluonCallback

    def getGluonCallback(metrics_logger):
        return __MLflowGluonCallback(log_models, metrics_logger)

    def _with_callback(event_handlers, callback):
        # Return a NEW list holding the user's handlers followed by `callback`.
        # The caller's own list is never mutated, so reusing one handler list across
        # several `fit` calls does not accumulate MLflow callbacks. `None` and a
        # single (non-list) handler are normalized to a list as well.
        if event_handlers is None:
            return [callback]
        if isinstance(event_handlers, (list, tuple)):
            return [*event_handlers, callback]
        return [event_handlers, callback]

    def fit(original, self, *args, **kwargs):
        # Wrap `fit` execution within a batch metrics logger context.
        run_id = mlflow.active_run().info.run_id
        with batch_metrics_logger(run_id) as metrics_logger:
            mlflowGluonCallback = getGluonCallback(metrics_logger)
            if len(args) >= 4:
                # `event_handlers` was passed positionally (fourth argument after `self`).
                args = (*args[:3], _with_callback(args[3], mlflowGluonCallback), *args[4:])
            else:
                # Passed by keyword (possibly as None) or not passed at all.
                kwargs["event_handlers"] = _with_callback(
                    kwargs.get("event_handlers"), mlflowGluonCallback
                )
            result = original(self, *args, **kwargs)

        return result

    safe_patch(FLAVOR_NAME, Estimator, "fit", fit, manage_run=True)
def test_autolog(disable=False, silent=False):
    """
    Emit output through ``eprint``, three logger levels and two explicit warnings
    (one attributed to the MLflow package file, one to the NumPy package file),
    then patch ``patch_destination.fn``.

    Neither ``disable`` nor ``silent`` is read inside this body.
    """
    eprint("enablement1")
    logger.info("enablement2")
    logger.warning("enablement3")
    logger.critical("enablement4")

    # (message, module whose file the warning is attributed to, line number)
    explicit_warnings = (
        ("enablement warning MLflow", mlflow, 15),
        ("enablement warning numpy", np, 30),
    )
    for message, source_module, line_number in explicit_warnings:
        warnings.warn_explicit(
            message, category=Warning, filename=source_module.__file__, lineno=line_number
        )

    safe_patch("test_integration", patch_destination, "fn", patch_impl)
def test_session_manager_creates_session_before_patch_executes(
    patch_destination, test_autologging_integration
):
    """A session must already be active at the moment the patch implementation runs."""
    session_seen_by_patch = None

    def check_session_manager_status(original):
        nonlocal session_seen_by_patch
        session_seen_by_patch = _AutologgingSessionManager.active_session()

    safe_patch(test_autologging_integration, patch_destination, "fn", check_session_manager_status)
    patch_destination.fn()

    assert session_seen_by_patch is not None
def test_safe_patch_validates_arguments_to_original_function_in_test_mode(
    patch_destination, test_autologging_integration
):
    """
    Calling ``original`` with arguments unrelated to the user's input must trigger
    exactly one ``_validate_args`` call and raise a mismatch error.
    """

    def patch_impl(original, *args, **kwargs):
        # Deliberately ignores the user's arguments.
        return original("1", "2", "3")

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    validate_patcher = mock.patch(
        "mlflow.utils.autologging_utils._validate_args", wraps=autologging_utils._validate_args
    )
    with pytest.raises(Exception, match="does not match expected input"):
        with validate_patcher as validate_mock:
            patch_destination.fn("a", "b", "c")

    assert validate_mock.call_count == 1
def test_safe_patch_propagates_exceptions_raised_outside_of_original_function_in_test_mode(
    patch_destination, test_autologging_integration
):
    """An exception raised by the patch code itself must reach the caller unchanged."""
    exc_to_throw = Exception("Bad patch implementation")

    def patch_impl(original, *args, **kwargs):
        raise exc_to_throw

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    with pytest.raises(Exception) as raised:
        patch_destination.fn()

    assert raised.value == exc_to_throw
def test_safe_patch_calls_original_function_when_patch_preamble_throws(
    patch_destination, test_autologging_integration
):
    """
    When the patch fails before invoking ``original``, the patched call must still
    return the original function's result, having called it exactly once.
    """
    patch_was_invoked = False

    def patch_impl(original, *args, **kwargs):
        nonlocal patch_was_invoked
        patch_was_invoked = True
        raise Exception("Bad patch preamble")

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    result = patch_destination.fn()
    assert result == PATCH_DESTINATION_FN_DEFAULT_RESULT
    assert patch_destination.fn_call_count == 1
    assert patch_was_invoked
def test_safe_patch_returns_original_result_and_ignores_patch_return_value(
    patch_destination, test_autologging_integration
):
    """
    A patch that returns its own value without calling ``original`` must not change
    the result seen by the caller; the original function is called exactly once.
    """
    patch_was_invoked = False

    def patch_impl(original, *args, **kwargs):
        nonlocal patch_was_invoked
        patch_was_invoked = True
        return 10

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    result = patch_destination.fn()
    assert result == PATCH_DESTINATION_FN_DEFAULT_RESULT
    assert patch_destination.fn_call_count == 1
    assert patch_was_invoked
def test_safe_patch_does_not_throw_when_autologging_runs_are_leaked_in_standard_mode(
    patch_destination, test_autologging_integration
):
    """
    Outside test mode, a patch that starts a run and never ends it must not make
    the patched call raise; the leaked run simply stays active.
    """
    assert not autologging_utils.is_testing()

    def leak_run_patch_impl(original, *args, **kwargs):
        # Starts a run and intentionally never ends it.
        mlflow.start_run(nested=True)

    safe_patch(test_autologging_integration, patch_destination, "fn", leak_run_patch_impl)
    patch_destination.fn()

    leaked_run = mlflow.active_run()
    assert leaked_run

    # End the leaked run
    mlflow.end_run()

    remaining_run = mlflow.active_run()
    assert not remaining_run
def test_safe_patch_provides_expected_original_function(
    patch_destination, test_autologging_integration
):
    """
    The ``original`` handed to the patch must behave like the pre-patch function:
    modified arguments passed to it show up in the returned result.
    """

    def original_fn(foo, bar=10):
        return dict(foo=foo, bar=bar)

    patch_destination.fn = original_fn

    def patch_impl(original, foo, bar):
        shifted_foo = foo + 1
        shifted_bar = bar + 2
        return original(shifted_foo, shifted_bar)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    result = patch_destination.fn(1, 2)
    assert result == {"foo": 2, "bar": 4}
def test_safe_patch_returns_original_result_without_second_call_when_patch_postamble_throws(
    patch_destination, test_autologging_integration
):
    """
    When the patch fails after it has already invoked ``original``, the caller must
    still get the original result and the original must not be invoked again.
    """
    patch_was_invoked = False

    def patch_impl(original, *args, **kwargs):
        nonlocal patch_was_invoked
        patch_was_invoked = True
        original(*args, **kwargs)
        raise Exception("Bad patch postamble")

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    result = patch_destination.fn()
    assert result == PATCH_DESTINATION_FN_DEFAULT_RESULT
    assert patch_destination.fn_call_count == 1
    assert patch_was_invoked
def test_safe_patch_forwards_expected_arguments_to_function_based_patch_implementation(
    patch_destination, test_autologging_integration
):
    """Keyword arguments given to the patched function must reach a function-based patch."""
    received_foo = None
    received_bar = None

    def patch_impl(original, foo, bar=10):
        nonlocal received_foo, received_bar
        received_foo = foo
        received_bar = bar

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    patch_destination.fn(foo=7, bar=11)

    assert received_foo == 7
    assert received_bar == 11
def test_safe_patch_logs_exceptions_raised_outside_of_original_function_as_warnings(
    patch_destination, test_autologging_integration
):
    """
    A failure inside the patch code must be reported through exactly one
    ``_logger.warning`` call naming the integration and the exception, while the
    patched call still returns the original result.
    """
    exc_to_throw = Exception("Bad patch implementation")

    def patch_impl(original, *args, **kwargs):
        raise exc_to_throw

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    warning_patcher = mock.patch("mlflow.utils.autologging_utils._logger.warning")
    with warning_patcher as logger_mock:
        result = patch_destination.fn()
        assert result == PATCH_DESTINATION_FN_DEFAULT_RESULT
        assert logger_mock.call_count == 1

        # Positional arguments of the warning call: the format string followed by
        # its two formatting arguments.
        message, logged_integration, logged_exception = logger_mock.call_args[0]
        assert "Encountered unexpected error" in message
        assert logged_integration == test_autologging_integration
        assert logged_exception == exc_to_throw
def test_safe_patch_provides_original_function_with_expected_signature(
    patch_destination, test_autologging_integration
):
    """
    The ``original`` callable received by the patch must expose the same signature
    as the function that was patched.
    """

    def original(a, b, c=10, *, d=11):
        return 10

    patch_destination.fn = original

    signature_seen_by_patch = False

    def patch_impl(original, *args, **kwargs):
        # `original` here is the callable supplied by safe_patch; it shadows the
        # outer function of the same name.
        nonlocal signature_seen_by_patch
        signature_seen_by_patch = inspect.signature(original)
        return original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    patch_destination.fn(1, 2)

    expected_signature = inspect.signature(original)
    assert signature_seen_by_patch == expected_signature
def test_safe_patch_preserves_signature_of_patched_function(
    patch_destination, test_autologging_integration
):
    """After patching, ``patch_destination.fn`` must keep the original function's signature."""

    def original(a, b, c=10, *, d=11):
        return 10

    patch_destination.fn = original

    patch_was_invoked = False

    def patch_impl(original, *args, **kwargs):
        nonlocal patch_was_invoked
        patch_was_invoked = True
        return original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    patch_destination.fn(1, 2)
    assert patch_was_invoked

    patched_signature = inspect.signature(patch_destination.fn)
    assert patched_signature == inspect.signature(original)
def test_safe_patch_manages_run_if_specified(patch_destination, test_autologging_integration):
    """
    With ``manage_run=True`` the patch must observe an active run with a run ID,
    and ``with_managed_run`` must be used exactly once.
    """
    observed_run = None

    def patch_impl(original, *args, **kwargs):
        # Record whichever run is active while the patch body executes.
        nonlocal observed_run
        observed_run = mlflow.active_run()
        return original(*args, **kwargs)

    managed_run_patcher = mock.patch(
        "mlflow.utils.autologging_utils.with_managed_run", wraps=with_managed_run
    )
    with managed_run_patcher as managed_run_mock:
        safe_patch(
            test_autologging_integration, patch_destination, "fn", patch_impl, manage_run=True
        )
        patch_destination.fn()

        assert managed_run_mock.call_count == 1
        assert observed_run is not None
        assert observed_run.info.run_id is not None
def test_safe_patch_makes_expected_event_logging_calls_when_patch_implementation_throws(
    patch_destination,
    test_autologging_integration,
    mock_event_logger,
):
    """
    Check the logged event sequence when the patch raises, both before and after
    it has invoked the original function.
    """
    observed_session = None
    exc_to_raise = Exception("thrown from patch")

    def patch_impl(original, *args, **kwargs):
        # `throw_location` is read from the enclosing test at call time; it is
        # assigned below, before each invocation of the patched function.
        nonlocal observed_session
        observed_session = _AutologgingSessionManager.active_session()

        if throw_location == "before":
            raise exc_to_raise

        original(*args, **kwargs)

        if throw_location != "before":
            raise exc_to_raise

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    # Scenario 1: the patch raises before calling the original function.
    throw_location = "before"
    patch_destination.fn()

    expected_order_throw_before = ["patch_start", "patch_error"]
    logged_methods = [call.method for call in mock_event_logger.calls]
    assert logged_methods == expected_order_throw_before
    patch_start, patch_error = mock_event_logger.calls
    assert patch_start.exception is None
    assert patch_error.exception == exc_to_raise

    mock_event_logger.reset()

    # Scenario 2: the patch raises after the original function has succeeded.
    throw_location = "after"
    patch_destination.fn()

    expected_order_throw_after = [
        "patch_start",
        "original_start",
        "original_success",
        "patch_error",
    ]
    logged_methods = [call.method for call in mock_event_logger.calls]
    assert logged_methods == expected_order_throw_after
    patch_start, original_start, original_success, patch_error = mock_event_logger.calls
    assert patch_start.exception is original_start.exception is None
    assert original_success.exception is None
    assert patch_error.exception == exc_to_raise
def test_safe_patch_makes_expected_event_logging_calls_when_original_function_throws(
    patch_destination,
    test_autologging_integration,
    mock_event_logger,
):
    """
    When the original function raises, the logged events must be ``patch_start``,
    ``original_start`` and ``original_error``, the last carrying the exception.
    """
    exc_to_raise = Exception("thrown from patch")

    def original(*args, **kwargs):
        raise exc_to_raise

    patch_destination.fn = original

    def patch_impl(original, *args, **kwargs):
        original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    with pytest.raises(Exception, match="thrown from patch"):
        patch_destination.fn()

    expected_order = ["patch_start", "original_start", "original_error"]
    logged_methods = [call.method for call in mock_event_logger.calls]
    assert logged_methods == expected_order

    patch_start, original_start, original_error = mock_event_logger.calls
    assert patch_start.exception is original_start.exception is None
    assert original_error.exception == exc_to_raise
def test_safe_patch_does_not_validate_autologging_runs_in_standard_mode(
    patch_destination, test_autologging_integration
):
    """
    Outside test mode, ``_validate_autologging_run`` must never be called, whether
    or not a user-created run is active around the patched call.
    """
    assert not autologging_utils._is_testing()

    def no_tag_run_patch_impl(original, *args, **kwargs):
        # Starts a run directly, so no autologging tag is attached to it.
        with mlflow.start_run(nested=True):
            return original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", no_tag_run_patch_impl)

    validation_patcher = mock.patch(
        "mlflow.utils.autologging_utils._validate_autologging_run",
        wraps=_validate_autologging_run,
    )
    with validation_patcher as validate_run_mock:
        patch_destination.fn()

        with mlflow.start_run(nested=True):
            # If a user-generated run existed prior to the autologged training session,
            # we expect that safe patch will not attempt to validate it
            patch_destination.fn()

        assert not validate_run_mock.called
def test_safe_patch_throws_when_autologging_runs_are_leaked_in_test_mode(
    patch_destination, test_autologging_integration
):
    """
    In test mode, a patch that leaves a run active must raise a "leaked an active
    run" assertion, unless a user-created run was already active beforehand.
    """
    assert autologging_utils.is_testing()

    def leak_run_patch_impl(original, *args, **kwargs):
        # Starts a run and intentionally never ends it.
        mlflow.start_run(nested=True)

    safe_patch(test_autologging_integration, patch_destination, "fn", leak_run_patch_impl)

    leak_expectation = pytest.raises(AssertionError, match="leaked an active run")
    with leak_expectation:
        patch_destination.fn()

    # End the leaked run
    mlflow.end_run()

    with mlflow.start_run():
        # If a user-generated run existed prior to the autologged training session, we
        # expect that safe patch will not throw a leaked run exception
        patch_destination.fn()
        # End the leaked nested run
        mlflow.end_run()

    assert not mlflow.active_run()
def test_safe_patch_provides_expected_original_function_to_class_based_patch(
    patch_destination, test_autologging_integration
):
    """
    The ``original`` handed to a class-based patch must behave like the pre-patch
    function, and ``PatchFunction.call`` must be used exactly once.
    """

    def original_fn(foo, bar=10):
        return dict(foo=foo, bar=bar)

    patch_destination.fn = original_fn

    class TestPatch(PatchFunction):
        def _patch_implementation(self, original, foo, bar=10):  # pylint: disable=arguments-differ
            shifted_foo = foo + 1
            shifted_bar = bar + 2
            return original(shifted_foo, shifted_bar)

        def _on_exception(self, exception):
            pass

    safe_patch(test_autologging_integration, patch_destination, "fn", TestPatch)

    call_patcher = mock.patch(
        "mlflow.utils.autologging_utils.PatchFunction.call", wraps=TestPatch.call
    )
    with call_patcher as call_mock:
        result = patch_destination.fn(1, 2)
        assert result == {"foo": 2, "bar": 4}
        assert call_mock.call_count == 1