def test_silent_mode_and_warning_rerouting_respect_disabled_flag(
    patch_destination, silent, disable
):
    """Check that a patched function still warns normally for the given ``disable``/``silent``.

    The test replaces ``patch_destination.fn`` with a function that emits a ``UserWarning``,
    applies ``safe_patch`` through an autologging integration called with the supplied
    ``disable`` and ``silent`` values, and then asserts that the warning is recorded by the
    ``warnings`` module while nothing is written to the substituted ``sys.stderr`` stream.
    """
    stream = StringIO()
    original_stderr = sys.stderr
    sys.stderr = stream
    # `sys.stderr` is process-global state; restore it even when an assertion below fails so
    # that the replacement stream does not leak into subsequently executed code
    try:

        def original_fn():
            warnings.warn("Test warning", category=UserWarning)

        patch_destination.fn = original_fn

        @autologging_integration("test_integration")
        def test_autolog(disable=False, silent=False):
            safe_patch("test_integration", patch_destination, "fn", lambda original: original())

        test_autolog(disable=disable, silent=silent)

        with warnings.catch_warnings(record=True) as warnings_record:
            patch_destination.fn()

        # Verify that calling the patched instance method still emits the expected warning
        assert len(warnings_record) == 1
        assert warnings_record[0].message.args[0] == "Test warning"
        assert warnings_record[0].category == UserWarning

        # Verify that nothing is printed to the stderr-backed MLflow event logger, which would
        # indicate rerouting of warning content
        assert not stream.getvalue()
    finally:
        sys.stderr = original_stderr
def parallel_fn():
    # Stagger the call by a random delay so that concurrently running autologging sessions are
    # more likely to overlap in different stages (i.e. simultaneous preamble / postamble /
    # original function execution states across sessions)
    delay_seconds = np.random.random()
    time.sleep(delay_seconds)
    patch_destination.fn()
    return True
def test_safe_patch_makes_expected_event_logging_calls_for_successful_patch_invocation(
    patch_destination,
    test_autologging_integration,
    mock_event_logger,
):
    """Check the ordered start/success events recorded for a patched call that succeeds."""
    observed_session = None
    forwarded_kwargs = {}

    def patch_impl(original, *args, **kwargs):
        nonlocal observed_session, forwarded_kwargs
        # Add an extra keyword argument so that the kwargs seen by the original function differ
        # from the kwargs the caller supplied to the patched function
        kwargs.update({"extra_func": exception_safe_function(lambda k: "foo")})
        forwarded_kwargs = kwargs
        observed_session = _AutologgingSessionManager.active_session()
        original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    patch_destination.fn("a", 1, b=2)

    logged_calls = mock_event_logger.calls
    assert [logged.method for logged in logged_calls] == [
        "patch_start",
        "original_start",
        "original_success",
        "patch_success",
    ]
    assert all(logged.session == observed_session for logged in logged_calls)
    assert all(logged.patch_obj == patch_destination for logged in logged_calls)
    assert all(logged.function_name == "fn" for logged in logged_calls)

    patch_start, original_start, original_success, patch_success = mock_event_logger.calls
    # Patch-level events carry the caller's arguments ...
    assert patch_start.call_args == patch_success.call_args == ("a", 1)
    assert patch_start.call_kwargs == patch_success.call_kwargs == {"b": 2}
    # ... while original-level events carry the arguments forwarded by the patch implementation
    assert original_start.call_args == original_success.call_args == ("a", 1)
    assert original_start.call_kwargs == original_success.call_kwargs == forwarded_kwargs
    assert patch_start.exception is original_start.exception is None
    assert patch_success.exception is original_success.exception is None
def test_safe_patch_validates_autologging_runs_when_necessary_in_test_mode(
    patch_destination, test_autologging_integration
):
    """Check that test mode validates runs created by a patch but not pre-existing runs."""
    assert autologging_utils.is_testing()

    def untagged_run_patch(original, *args, **kwargs):
        # Creates a run around the original call without passing any tags to `start_run`
        with mlflow.start_run(nested=True):
            return original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", untagged_run_patch)

    validator_patch = mock.patch(
        "mlflow.utils.autologging_utils.safety._validate_autologging_run",
        wraps=_validate_autologging_run,
    )
    with validator_patch as validate_run_mock:
        expected_failure = pytest.raises(
            AssertionError, match="failed to set autologging tag with expected value"
        )
        with expected_failure:
            patch_destination.fn()
        assert validate_run_mock.call_count == 1

        validate_run_mock.reset_mock()

        with mlflow.start_run(nested=True):
            # If a user-generated run existed prior to the autologged training session, we expect
            # that safe patch will not attempt to validate it
            patch_destination.fn()
        assert not validate_run_mock.called
def test_safe_patch_manages_run_if_specified_and_sets_expected_run_tags(
    patch_destination, test_autologging_integration
):
    """Check that ``manage_run=True`` provides an active run tagged with the integration name."""
    captured_run = None

    def patch_impl(original, *args, **kwargs):
        nonlocal captured_run
        captured_run = mlflow.active_run()
        return original(*args, **kwargs)

    managed_run_patch = mock.patch(
        "mlflow.utils.autologging_utils.safety.with_managed_run", wraps=with_managed_run
    )
    with managed_run_patch as managed_run_mock:
        safe_patch(
            test_autologging_integration, patch_destination, "fn", patch_impl, manage_run=True
        )
        patch_destination.fn()
        assert managed_run_mock.call_count == 1
        assert captured_run is not None
        run_id = captured_run.info.run_id
        assert run_id is not None
        # Look the run up through a client to inspect the tags that were persisted for it
        client = MlflowClient()
        run_tags = client.get_run(captured_run.info.run_id).data.tags
        assert run_tags[MLFLOW_AUTOLOGGING] == "test_integration"
def test_safe_patch_propagates_exceptions_raised_from_original_function(
    patch_destination, test_autologging_integration
):
    """Check that an exception raised by the original function reaches the caller unchanged."""
    expected_error = Exception("Bad original function")

    def failing_original(*args, **kwargs):
        raise expected_error

    patch_destination.fn = failing_original

    patch_invocations = []

    def patch_impl(original, *args, **kwargs):
        patch_invocations.append(True)
        return original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    with pytest.raises(Exception) as raised:
        patch_destination.fn()

    assert raised.value == expected_error
    # The patch implementation must have run for the exception to originate from `original`
    assert patch_invocations
def test_safe_patch_forwards_expected_arguments_to_class_based_patch(
    patch_destination, test_autologging_integration
):
    """Check that keyword arguments reach a ``PatchFunction`` subclass used as the patch."""
    received = {"foo": None, "bar": None}

    class TestPatch(PatchFunction):
        def _patch_implementation(self, original, foo, bar=10):  # pylint: disable=arguments-differ
            # Record what was forwarded so that it can be asserted on after the call
            received["foo"] = foo
            received["bar"] = bar

        def _on_exception(self, exception):
            pass

    safe_patch(test_autologging_integration, patch_destination, "fn", TestPatch)

    call_patch = mock.patch(
        "mlflow.utils.autologging_utils.PatchFunction.call", wraps=TestPatch.call
    )
    with call_patch as call_mock:
        patch_destination.fn(foo=7, bar=11)
        assert call_mock.call_count == 1
        assert received["foo"] == 7
        assert received["bar"] == 11
def test_session_manager_exits_session_after_patch_executes(
    patch_destination, test_autologging_integration
):
    """Check that no autologging session remains active once the patched call has returned."""

    def patch_fn(original):
        # A session is expected to be active while the patch implementation is running
        session_during_patch = _AutologgingSessionManager.active_session()
        assert session_during_patch is not None

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_fn)
    patch_destination.fn()

    session_after_patch = _AutologgingSessionManager.active_session()
    assert session_after_patch is None
def test_session_manager_creates_session_before_patch_executes(
    patch_destination, test_autologging_integration
):
    """Check that an autologging session is already active when the patch implementation runs."""
    observed = {"session": None}

    def check_session_manager_status(original):
        # Capture the session visible from inside the patch implementation
        observed["session"] = _AutologgingSessionManager.active_session()

    safe_patch(test_autologging_integration, patch_destination, "fn", check_session_manager_status)
    patch_destination.fn()

    assert observed["session"] is not None
def test_session_manager_exits_session_if_error_in_patch(
    patch_destination, test_autologging_integration
):
    """Check that no autologging session is active after a directly assigned function raises.

    The failing function is assigned to ``patch_destination.fn`` directly rather than through
    ``safe_patch``. It is invoked without arguments, so ``original`` is optional here; with a
    required parameter the call would fail with a ``TypeError`` before reaching the intended
    ``raise``, and a bare ``pytest.raises(Exception)`` would still pass.
    """
    error_message = "Exception that should stop autologging session"

    def patch_fn(original=None):
        raise Exception(error_message)

    # If use safe_patch to patch, exception would not come from original fn and so would be logged
    patch_destination.fn = patch_fn
    # Match on the message so that only the intended exception satisfies the expectation
    with pytest.raises(Exception, match=error_message):
        patch_destination.fn()

    assert _AutologgingSessionManager.active_session() is None
def test_safe_patch_propagates_exceptions_raised_outside_of_original_function_in_test_mode(
    patch_destination, test_autologging_integration
):
    """Check that an exception raised by the patch implementation itself reaches the caller."""
    expected_error = Exception("Bad patch implementation")

    def failing_patch(original, *args, **kwargs):
        # Raise without ever invoking `original`
        raise expected_error

    safe_patch(test_autologging_integration, patch_destination, "fn", failing_patch)

    with pytest.raises(Exception) as raised:
        patch_destination.fn()

    assert raised.value == expected_error
def test_safe_patch_validates_arguments_to_original_function_in_test_mode(
    patch_destination, test_autologging_integration
):
    """Check that calling ``original`` with inputs that differ from the caller's is rejected."""

    def argument_swapping_patch(original, *args, **kwargs):
        # Deliberately ignore the caller's arguments and substitute different ones
        return original("1", "2", "3")

    safe_patch(test_autologging_integration, patch_destination, "fn", argument_swapping_patch)

    expected_failure = pytest.raises(Exception, match="does not match expected input")
    validator_patch = mock.patch(
        "mlflow.utils.autologging_utils.safety._validate_args",
        wraps=autologging_utils.safety._validate_args,
    )
    with expected_failure:
        with validator_patch as validate_mock:
            patch_destination.fn("a", "b", "c")

    assert validate_mock.call_count == 1
def test_safe_patch_does_not_throw_when_autologging_runs_are_leaked_in_standard_mode(
    patch_destination, test_autologging_integration
):
    """Check that, outside of test mode, a patch that leaves a run active does not raise."""
    assert not autologging_utils.is_testing()

    def run_leaking_patch(original, *args, **kwargs):
        # Start a run and never end it
        mlflow.start_run(nested=True)

    safe_patch(test_autologging_integration, patch_destination, "fn", run_leaking_patch)
    patch_destination.fn()

    leaked_run = mlflow.active_run()
    assert leaked_run

    # End the leaked run
    mlflow.end_run()

    assert not mlflow.active_run()
def test_safe_patch_forwards_expected_arguments_to_function_based_patch_implementation(
    patch_destination, test_autologging_integration
):
    """Check that keyword arguments reach a plain-function patch implementation."""
    received = {"foo": None, "bar": None}

    def patch_impl(original, foo, bar=10):
        # Record what was forwarded so that it can be asserted on after the call
        received["foo"] = foo
        received["bar"] = bar

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    patch_destination.fn(foo=7, bar=11)

    assert received["foo"] == 7
    assert received["bar"] == 11
def test_original_fn_runs_if_patch_should_not_be_applied(patch_destination):
    """Check that only the original function runs for an exclusive integration inside a run."""
    patch_invocations = []

    @autologging_integration("test_respects_exclusive")
    def autolog(disable=False, exclusive=False, silent=False):
        def patch_impl(original, *args, **kwargs):
            patch_invocations.append(True)
            return original(*args, **kwargs)

        safe_patch("test_respects_exclusive", patch_destination, "fn", patch_impl)

    autolog(exclusive=True)

    # Invoke the patched function while a run started outside of the patch is active
    with mlflow.start_run():
        patch_destination.fn()

    assert len(patch_invocations) == 0
    assert patch_destination.fn_call_count == 1
def test_safe_patch_provides_original_function_with_expected_signature(
    patch_destination, test_autologging_integration
):
    """Check that the ``original`` handed to a patch keeps the underlying function's signature."""

    def underlying_fn(a, b, c=10, *, d=11):
        return 10

    patch_destination.fn = underlying_fn

    signature_seen_by_patch = False

    def patch_impl(original, *args, **kwargs):
        nonlocal signature_seen_by_patch
        # Inspect the callable supplied by safe_patch, not the function defined above
        signature_seen_by_patch = inspect.signature(original)
        return original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    patch_destination.fn(1, 2)

    expected_signature = inspect.signature(underlying_fn)
    assert signature_seen_by_patch == expected_signature
def test_silent_mode_operates_independently_across_integrations(patch_destination, logger):
    """Check that ``silent=True`` on one integration does not silence a second integration.

    ``integration1`` is enabled with ``silent=True`` and ``integration2`` with ``silent=False``.
    The test asserts that only the warning emitted by ``integration2``'s autolog function is
    recorded, and that only ``integration2``'s log messages appear in the substituted
    ``sys.stderr`` stream.
    """
    stream = StringIO()
    original_stderr = sys.stderr
    sys.stderr = stream
    # `sys.stderr` is process-global state; restore it even when an assertion below fails
    try:
        patch_destination.fn2 = lambda *args, **kwargs: "fn2"

        def patch_impl1(original):
            warnings.warn("patchimpl1")
            original()

        @autologging_integration("integration1")
        def autolog1(disable=False, silent=False):
            logger.info("autolog1")
            safe_patch("integration1", patch_destination, "fn", patch_impl1)

        def patch_impl2(original):
            logger.info("patchimpl2")
            original()

        @autologging_integration("integration2")
        def autolog2(disable=False, silent=False):
            warnings.warn_explicit(
                "warn_autolog2", category=Warning, filename=mlflow.__file__, lineno=5
            )
            logger.info("event_autolog2")
            safe_patch("integration2", patch_destination, "fn2", patch_impl2)

        # `pytest.warns(None)` is deprecated in pytest 7 and raises in pytest 8; record warnings
        # with the standard library instead, using the "always" filter so none are deduplicated
        with warnings.catch_warnings(record=True) as warnings_record:
            warnings.simplefilter("always")
            autolog1(silent=True)
            autolog2(silent=False)

            patch_destination.fn()
            patch_destination.fn2()

        warning_messages = [str(w.message) for w in warnings_record]
        assert warning_messages == ["warn_autolog2"]

        assert "autolog1" not in stream.getvalue()
        assert "patchimpl1" not in stream.getvalue()

        assert "event_autolog2" in stream.getvalue()
        assert "patchimpl2" in stream.getvalue()
    finally:
        sys.stderr = original_stderr
def test_safe_patch_makes_expected_event_logging_calls_when_patch_implementation_throws(
    patch_destination,
    test_autologging_integration,
    mock_event_logger,
):
    """Check the events recorded when a patch raises before or after calling ``original``."""
    patch_session = None
    patch_error_to_raise = Exception("thrown from patch")

    def patch_impl(original, *args, **kwargs):
        nonlocal patch_session
        patch_session = _AutologgingSessionManager.active_session()
        # `throw_location` is assigned by the enclosing test before each patched call
        throws_in_preamble = throw_location == "before"
        if throws_in_preamble:
            raise patch_error_to_raise
        original(*args, **kwargs)
        if not throws_in_preamble:
            raise patch_error_to_raise

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    # Scenario 1: the patch raises before the original function is invoked
    throw_location = "before"
    patch_destination.fn()
    recorded_methods = [logged.method for logged in mock_event_logger.calls]
    assert recorded_methods == ["patch_start", "patch_error"]
    patch_start, patch_error = mock_event_logger.calls
    assert patch_start.exception is None
    assert patch_error.exception == patch_error_to_raise

    mock_event_logger.reset()

    # Scenario 2: the patch raises after the original function has completed
    throw_location = "after"
    patch_destination.fn()
    recorded_methods = [logged.method for logged in mock_event_logger.calls]
    assert recorded_methods == [
        "patch_start",
        "original_start",
        "original_success",
        "patch_error",
    ]
    patch_start, original_start, original_success, patch_error = mock_event_logger.calls
    assert patch_start.exception is original_start.exception is None
    assert original_success.exception is None
    assert patch_error.exception == patch_error_to_raise
def test_autologging_event_logging_and_warnings_respect_silent_mode(
    autolog_function, patch_destination, logger
):
    """Compare warnings and log output for ``silent=True`` and then ``silent=False``.

    With ``silent=True`` the test expects a single recorded warning and an empty stream; with
    ``silent=False`` it expects the enablement and patch log messages in the stream. After each
    phase it checks that ``warnings.showwarning`` is unchanged and that the logger still writes
    to the substituted ``sys.stderr`` stream.
    """
    og_showwarning = warnings.showwarning
    stream = StringIO()
    original_stderr = sys.stderr
    sys.stderr = stream
    # `sys.stderr` is process-global state; restore it even when an assertion below fails
    try:
        # `pytest.warns(None)` is deprecated in pytest 7 and raises in pytest 8; record warnings
        # with the standard library instead, using the "always" filter so none are deduplicated
        with warnings.catch_warnings(record=True) as silent_warnings_record:
            warnings.simplefilter("always")
            autolog_function(silent=True)
            patch_destination.fn()

        assert len(silent_warnings_record) == 1
        assert "Test warning from OG function" in str(silent_warnings_record[0].message)
        assert not stream.getvalue()

        # Verify that `warnings.showwarning` was restored to its original value after training
        # and that MLflow event logs are enabled
        assert warnings.showwarning == og_showwarning
        logger.info("verify that event logs are enabled")
        assert "verify that event logs are enabled" in stream.getvalue()

        # `truncate` does not move the stream position, so rewind first; otherwise later writes
        # would land past the new end of the buffer
        stream.seek(0)
        stream.truncate(0)

        with warnings.catch_warnings(record=True) as noisy_warnings_record:
            warnings.simplefilter("always")
            autolog_function(silent=False)
            patch_destination.fn()

        # Verify that calling the autolog function with `silent=False` and invoking the mock
        # training function produces event logs and warnings
        for item in ["enablement1", "enablement2", "enablement3", "enablement4"]:
            assert item in stream.getvalue()

        for item in ["patch1", "patch2", "patch3", "patch4"]:
            assert item in stream.getvalue()

        warning_messages = {str(w.message) for w in noisy_warnings_record}
        assert "enablement warning MLflow" in warning_messages

        # Verify that `warnings.showwarning` was restored to its original value after training
        # and that MLflow event logs are enabled
        assert warnings.showwarning == og_showwarning
        logger.info("verify that event logs are enabled")
        assert "verify that event logs are enabled" in stream.getvalue()
    finally:
        sys.stderr = original_stderr
def test_safe_patch_preserves_signature_of_patched_function(
    patch_destination, test_autologging_integration
):
    """Check that the attribute installed by ``safe_patch`` reports the original's signature."""

    def underlying_fn(a, b, c=10, *, d=11):
        return 10

    patch_destination.fn = underlying_fn

    patch_invocations = []

    def patch_impl(original, *args, **kwargs):
        patch_invocations.append(True)
        return original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    patch_destination.fn(1, 2)

    # Confirm the patch is really in place before comparing signatures
    assert patch_invocations
    patched_signature = inspect.signature(patch_destination.fn)
    assert patched_signature == inspect.signature(underlying_fn)
def test_safe_patch_does_not_manage_run_if_unspecified(
    patch_destination, test_autologging_integration
):
    """Check that ``manage_run=False`` neither wraps the patch in a managed run nor starts one."""
    active_run = None

    def patch_impl(original, *args, **kwargs):
        nonlocal active_run
        active_run = mlflow.active_run()
        return original(*args, **kwargs)

    # Patch `with_managed_run` in the `safety` module, the same target used by the
    # `manage_run=True` test in this file. Patching the name on the package instead would leave
    # the reference used by `safe_patch` untouched, so the `call_count == 0` assertion below
    # could never fail.
    with mock.patch(
        "mlflow.utils.autologging_utils.safety.with_managed_run", wraps=with_managed_run
    ) as managed_run_mock:
        safe_patch(
            test_autologging_integration, patch_destination, "fn", patch_impl, manage_run=False
        )
        patch_destination.fn()
        assert managed_run_mock.call_count == 0
        assert active_run is None
def test_safe_patch_respects_disable_flag(patch_destination):
    """Check that the patch implementation stops being invoked after ``autolog(disable=True)``."""
    patch_invocations = []

    @autologging_integration("test_respects_disable")
    def autolog(disable=False, silent=False):
        def patch_impl(original, *args, **kwargs):
            patch_invocations.append(True)
            return original(*args, **kwargs)

        safe_patch("test_respects_disable", patch_destination, "fn", patch_impl)

    # Enabled: the patch implementation runs once
    autolog(disable=False)
    patch_destination.fn()
    assert len(patch_invocations) == 1

    # Disabled: the invocation count must not grow
    autolog(disable=True)
    patch_destination.fn()
    assert len(patch_invocations) == 1
def test_safe_patch_makes_expected_event_logging_calls_when_original_function_throws(
    patch_destination,
    test_autologging_integration,
    mock_event_logger,
):
    """Check the events recorded when the original function raises inside the patch."""
    original_error = Exception("thrown from patch")

    def failing_original(*args, **kwargs):
        raise original_error

    patch_destination.fn = failing_original

    def passthrough_patch(original, *args, **kwargs):
        original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", passthrough_patch)

    with pytest.raises(Exception, match="thrown from patch"):
        patch_destination.fn()

    recorded_methods = [logged.method for logged in mock_event_logger.calls]
    assert recorded_methods == ["patch_start", "original_start", "original_error"]

    patch_start, original_start, original_error_event = mock_event_logger.calls
    assert patch_start.exception is original_start.exception is None
    assert original_error_event.exception == original_error
def test_safe_patch_throws_when_autologging_runs_are_leaked_in_test_mode(
    patch_destination, test_autologging_integration
):
    """Check that test mode raises for a leaked run unless a run was already active."""
    assert autologging_utils.is_testing()

    def run_leaking_patch(original, *args, **kwargs):
        # Start a run and never end it
        mlflow.start_run(nested=True)

    safe_patch(test_autologging_integration, patch_destination, "fn", run_leaking_patch)

    leak_failure = pytest.raises(AssertionError, match="leaked an active run")
    with leak_failure:
        patch_destination.fn()

    # End the leaked run
    mlflow.end_run()

    with mlflow.start_run():
        # If a user-generated run existed prior to the autologged training session, we expect
        # that safe patch will not throw a leaked run exception
        patch_destination.fn()
        # End the leaked nested run
        mlflow.end_run()

    assert not mlflow.active_run()
def test_autologging_warnings_are_redirected_as_expected(
    autolog_function, patch_destination, logger
):
    """Check which warnings are emitted normally and which are rerouted to the log stream.

    The autolog function is enabled with ``silent=False`` and the patched function is invoked
    while warnings are recorded and ``sys.stderr`` is replaced by an in-memory stream.
    """
    stream = StringIO()
    original_stderr = sys.stderr
    sys.stderr = stream
    # `sys.stderr` is process-global state; restore it even when an assertion below fails
    try:
        # `pytest.warns(None)` is deprecated in pytest 7 and raises in pytest 8; record warnings
        # with the standard library instead, using the "always" filter so none are deduplicated
        with warnings.catch_warnings(record=True) as warnings_record:
            warnings.simplefilter("always")
            autolog_function(silent=False)
            patch_destination.fn()

        # The following types of warnings are rerouted to MLflow's event loggers:
        # 1. All MLflow warnings emitted during patch function execution
        # 2. All warnings emitted during the patch function preamble (before the execution of the
        #    original / underlying function) and postamble (after the execution of the underlying
        #    function)
        # 3. non-MLflow warnings emitted during autologging setup / enablement
        #
        # Accordingly, we expect the following warnings to have been emitted normally: 1. MLflow
        # warnings emitted during autologging enablement, 2. non-MLflow warnings emitted during
        # original / underlying function execution
        warning_messages = {str(w.message) for w in warnings_record}
        assert warning_messages == {"enablement warning MLflow", "Test warning from OG function"}

        # Further, we expect MLflow's logging stream to contain content from all warnings emitted
        # during the autologging preamble and postamble and non-MLflow warnings emitted during
        # autologging enablement
        for item in [
            'MLflow autologging encountered a warning: "%s:5: Warning: preamble MLflow warning"',
            'MLflow autologging encountered a warning: "%s:10: Warning: postamble MLflow warning"',
        ]:
            assert (item % mlflow.__file__) in stream.getvalue()
        for item in [
            'MLflow autologging encountered a warning: "%s:7: UserWarning: preamble numpy warning"',
            'MLflow autologging encountered a warning: "%s:14: Warning: postamble numpy warning"',
            'MLflow autologging encountered a warning: "%s:30: Warning: enablement warning numpy"',
        ]:
            assert (item % np.__file__) in stream.getvalue()
    finally:
        sys.stderr = original_stderr
def test_safe_patch_calls_original_function_when_patch_preamble_throws(
    patch_destination, test_autologging_integration
):
    """Check that the original function still runs once when the patch raises before calling it."""
    patch_invocations = []

    def failing_preamble_patch(original, *args, **kwargs):
        patch_invocations.append(True)
        raise Exception("Bad patch preamble")

    safe_patch(test_autologging_integration, patch_destination, "fn", failing_preamble_patch)

    result = patch_destination.fn()
    assert result == PATCH_DESTINATION_FN_DEFAULT_RESULT
    assert patch_destination.fn_call_count == 1
    assert patch_invocations
def test_safe_patch_returns_original_result_and_ignores_patch_return_value(
    patch_destination, test_autologging_integration
):
    """Check that the caller gets the original's result even if the patch returns another value."""
    patch_invocations = []

    def constant_returning_patch(original, *args, **kwargs):
        patch_invocations.append(True)
        # Return a value of its own without ever calling `original`
        return 10

    safe_patch(test_autologging_integration, patch_destination, "fn", constant_returning_patch)

    result = patch_destination.fn()
    assert result == PATCH_DESTINATION_FN_DEFAULT_RESULT
    assert patch_destination.fn_call_count == 1
    assert patch_invocations
def test_safe_patch_does_not_validate_autologging_runs_in_standard_mode(
    patch_destination, test_autologging_integration
):
    """Check that run validation is never invoked when test mode is off."""
    assert not autologging_utils.is_testing()

    def untagged_run_patch(original, *args, **kwargs):
        # Creates a run around the original call without passing any tags to `start_run`
        with mlflow.start_run(nested=True):
            return original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", untagged_run_patch)

    validator_patch = mock.patch(
        "mlflow.utils.autologging_utils.safety._validate_autologging_run",
        wraps=_validate_autologging_run,
    )
    with validator_patch as validate_run_mock:
        patch_destination.fn()

        with mlflow.start_run(nested=True):
            # If a user-generated run existed prior to the autologged training session, we expect
            # that safe patch will not attempt to validate it
            patch_destination.fn()

        assert not validate_run_mock.called
def test_safe_patch_returns_original_result_without_second_call_when_patch_postamble_throws(
    patch_destination, test_autologging_integration
):
    """Check that a failure after ``original`` returns keeps the result and avoids a re-run."""
    patch_invocations = []

    def failing_postamble_patch(original, *args, **kwargs):
        patch_invocations.append(True)
        original(*args, **kwargs)
        raise Exception("Bad patch postamble")

    safe_patch(test_autologging_integration, patch_destination, "fn", failing_postamble_patch)

    result = patch_destination.fn()
    assert result == PATCH_DESTINATION_FN_DEFAULT_RESULT
    # Exactly one call: the original must not be invoked a second time after the patch fails
    assert patch_destination.fn_call_count == 1
    assert patch_invocations
def test_safe_patch_provides_expected_original_function(
    patch_destination, test_autologging_integration
):
    """Check that ``original`` inside the patch is the function that was on the destination."""

    def original_fn(foo, bar=10):
        return dict(foo=foo, bar=bar)

    patch_destination.fn = original_fn

    def offsetting_patch(original, foo, bar):
        # Shift both inputs so the result shows that the call went through this patch
        shifted_foo = foo + 1
        shifted_bar = bar + 2
        return original(shifted_foo, shifted_bar)

    safe_patch(test_autologging_integration, patch_destination, "fn", offsetting_patch)

    result = patch_destination.fn(1, 2)
    assert result == {"foo": 2, "bar": 4}