def test_silent_mode_and_warning_rerouting_respect_disabled_flag(
    patch_destination, silent, disable
):
    """Check that the patched ``fn`` still emits its warning normally and that nothing reaches
    stderr, for the ``silent`` / ``disable`` combination supplied by the test parameters."""
    captured_stderr = StringIO()
    sys.stderr = captured_stderr

    def emit_test_warning():
        warnings.warn("Test warning", category=UserWarning)

    patch_destination.fn = emit_test_warning

    @autologging_integration("test_integration")
    def test_autolog(disable=False, silent=False):
        safe_patch("test_integration", patch_destination, "fn", lambda original: original())

    test_autolog(disable=disable, silent=silent)

    with warnings.catch_warnings(record=True) as recorded:
        patch_destination.fn()

    # The patched function must still surface its warning to the caller unchanged
    assert len(recorded) == 1
    (only_warning,) = recorded
    assert only_warning.message.args[0] == "Test warning"
    assert only_warning.category == UserWarning

    # An empty stderr means no warning content was rerouted to the stderr-backed MLflow
    # event logger
    assert not captured_stderr.getvalue()
 def parallel_fn():
     """Call the patched ``fn`` after a random delay and return ``True``."""
     # Sleeping for a random fraction of a second makes overlapping session stages more likely
     # (simultaneous preamble / postamble / original-function execution across autologging
     # sessions)
     delay_seconds = np.random.random()
     time.sleep(delay_seconds)
     patch_destination.fn()
     return True
def test_safe_patch_makes_expected_event_logging_calls_for_successful_patch_invocation(
    patch_destination, test_autologging_integration, mock_event_logger,
):
    """A successful patched call logs patch/original start and success events in order.

    Each event carries the active session, the patched object, the function name, and the
    arguments of the corresponding call.
    """
    patch_session = None
    og_call_kwargs = {}

    def patch_impl(original, *args, **kwargs):
        nonlocal og_call_kwargs, patch_session
        # The original function receives an extra kwarg, so its events must record it
        kwargs["extra_func"] = exception_safe_function(lambda k: "foo")
        og_call_kwargs = kwargs
        patch_session = _AutologgingSessionManager.active_session()
        original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    patch_destination.fn("a", 1, b=2)

    logged = mock_event_logger.calls
    assert [event.method for event in logged] == [
        "patch_start",
        "original_start",
        "original_success",
        "patch_success",
    ]
    for event in logged:
        assert event.session == patch_session
        assert event.patch_obj == patch_destination
        assert event.function_name == "fn"

    start_event, original_start_event, original_success_event, success_event = logged
    # Patch-level events see the caller's arguments
    assert start_event.call_args == success_event.call_args == ("a", 1)
    assert start_event.call_kwargs == success_event.call_kwargs == {"b": 2}
    # Original-level events see the arguments forwarded by the patch
    assert original_start_event.call_args == original_success_event.call_args == ("a", 1)
    assert original_start_event.call_kwargs == original_success_event.call_kwargs == og_call_kwargs
    for event in (start_event, original_start_event, success_event, original_success_event):
        assert event.exception is None
def test_safe_patch_validates_autologging_runs_when_necessary_in_test_mode(
    patch_destination, test_autologging_integration
):
    """In test mode, safe_patch validates runs created during an autologging session.

    A run started by the patch without the expected autologging tag must fail validation,
    while a run the user started before the patched call must not be validated at all.
    """
    assert autologging_utils.is_testing()

    def no_tag_run_patch_impl(original, *args, **kwargs):
        # Starts a nested run without setting the autologging tag, so validation should fail
        with mlflow.start_run(nested=True):
            return original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", no_tag_run_patch_impl)

    with mock.patch(
        "mlflow.utils.autologging_utils.safety._validate_autologging_run",
        wraps=_validate_autologging_run,
    ) as validate_run_mock:

        with pytest.raises(
            AssertionError, match="failed to set autologging tag with expected value"
        ):
            patch_destination.fn()
        # Checked outside the `pytest.raises` block: any statement placed after the raising call
        # inside that block would never execute
        assert validate_run_mock.call_count == 1

        validate_run_mock.reset_mock()

        with mlflow.start_run(nested=True):
            # If a user-generated run existed prior to the autologged training session, we expect
            # that safe patch will not attempt to validate it
            patch_destination.fn()
            assert not validate_run_mock.called
def test_safe_patch_manages_run_if_specified_and_sets_expected_run_tags(
    patch_destination, test_autologging_integration
):
    """With ``manage_run=True``, the patch runs inside a managed run whose autologging tag is
    set to the integration name."""
    client = MlflowClient()
    observed_run = None

    def patch_impl(original, *args, **kwargs):
        nonlocal observed_run
        observed_run = mlflow.active_run()
        return original(*args, **kwargs)

    with mock.patch(
        "mlflow.utils.autologging_utils.safety.with_managed_run", wraps=with_managed_run
    ) as managed_run_mock:
        safe_patch(
            test_autologging_integration, patch_destination, "fn", patch_impl, manage_run=True
        )
        patch_destination.fn()

        assert managed_run_mock.call_count == 1
        assert observed_run is not None
        run_id = observed_run.info.run_id
        assert run_id is not None
        run_tags = client.get_run(run_id).data.tags
        assert run_tags[MLFLOW_AUTOLOGGING] == "test_integration"
def test_safe_patch_propagates_exceptions_raised_from_original_function(
    patch_destination, test_autologging_integration
):
    """An exception raised by the underlying function reaches the caller of the patched
    function, and the patch implementation is still invoked."""
    original_error = Exception("Bad original function")

    def failing_original(*args, **kwargs):
        raise original_error

    patch_destination.fn = failing_original

    patch_calls = []

    def patch_impl(original, *args, **kwargs):
        patch_calls.append(True)
        return original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    with pytest.raises(Exception) as exc_info:
        patch_destination.fn()

    # Exceptions compare by identity, so this pins the exact object raised above
    assert exc_info.value is original_error
    assert patch_calls
def test_safe_patch_forwards_expected_arguments_to_class_based_patch(
    patch_destination, test_autologging_integration
):
    """A ``PatchFunction`` subclass receives the caller's keyword arguments, and
    ``PatchFunction.call`` is invoked exactly once per patched call."""
    received = {}

    class TestPatch(PatchFunction):
        def _patch_implementation(self, original, foo, bar=10):  # pylint: disable=arguments-differ
            received["foo"] = foo
            received["bar"] = bar

        def _on_exception(self, exception):
            pass

    safe_patch(test_autologging_integration, patch_destination, "fn", TestPatch)
    with mock.patch(
        "mlflow.utils.autologging_utils.PatchFunction.call", wraps=TestPatch.call
    ) as call_mock:
        patch_destination.fn(foo=7, bar=11)
        assert call_mock.call_count == 1
        assert received.get("foo") == 7
        assert received.get("bar") == 11
def test_session_manager_exits_session_after_patch_executes(
    patch_destination, test_autologging_integration
):
    """A session is active while the patch runs and is gone once the patched call returns."""

    def patch_fn(original):
        current_session = _AutologgingSessionManager.active_session()
        assert current_session is not None

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_fn)
    patch_destination.fn()

    assert _AutologgingSessionManager.active_session() is None
def test_session_manager_creates_session_before_patch_executes(
    patch_destination, test_autologging_integration
):
    """By the time the patch body runs, the session manager has already opened a session."""
    sessions_seen = []

    def check_session_manager_status(original):
        sessions_seen.append(_AutologgingSessionManager.active_session())

    safe_patch(test_autologging_integration, patch_destination, "fn", check_session_manager_status)
    patch_destination.fn()

    assert sessions_seen
    assert sessions_seen[-1] is not None
def test_session_manager_exits_session_if_error_in_patch(
    patch_destination, test_autologging_integration
):
    """The autologging session is closed even when the patched call ends in an exception.

    The exception is raised from the original function rather than from the patch body,
    because errors raised by the patch implementation itself may be caught and logged by
    safe_patch instead of propagating to the caller.
    """
    exc_to_throw = Exception("Exception that should stop autologging session")
    session_during_patch = None

    def original(*args, **kwargs):
        raise exc_to_throw

    patch_destination.fn = original

    def patch_fn(original, *args, **kwargs):
        nonlocal session_during_patch
        session_during_patch = _AutologgingSessionManager.active_session()
        return original(*args, **kwargs)

    # Going through safe_patch is required: assigning a bare patch function to `fn` never opens
    # a session, which would make the final assertion pass trivially
    safe_patch(test_autologging_integration, patch_destination, "fn", patch_fn)
    with pytest.raises(Exception, match="Exception that should stop autologging session"):
        patch_destination.fn()

    # A session must have existed during the call for the exit check to be meaningful
    assert session_during_patch is not None
    assert _AutologgingSessionManager.active_session() is None
def test_safe_patch_propagates_exceptions_raised_outside_of_original_function_in_test_mode(
    patch_destination, test_autologging_integration
):
    """An error raised directly by the patch implementation reaches the caller of the patched
    function (the test-mode behavior this test is named for)."""
    patch_error = Exception("Bad patch implementation")

    def patch_impl(original, *args, **kwargs):
        raise patch_error

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    with pytest.raises(Exception) as exc_info:
        patch_destination.fn()
    assert exc_info.value is patch_error
def test_safe_patch_validates_arguments_to_original_function_in_test_mode(
    patch_destination, test_autologging_integration
):
    """Calling ``original`` with positional arguments different from the caller's is rejected
    by ``_validate_args``, which runs exactly once."""

    def patch_impl(original, *args, **kwargs):
        # Deliberately substitutes the caller's positional arguments
        return original("1", "2", "3")

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    validate_target = "mlflow.utils.autologging_utils.safety._validate_args"
    with pytest.raises(Exception, match="does not match expected input"):
        with mock.patch(
            validate_target, wraps=autologging_utils.safety._validate_args
        ) as validate_mock:
            patch_destination.fn("a", "b", "c")

    assert validate_mock.call_count == 1
def test_safe_patch_does_not_throw_when_autologging_runs_are_leaked_in_standard_mode(
    patch_destination, test_autologging_integration
):
    """Outside test mode, a run left open by the patch does not raise and stays active."""
    assert not autologging_utils.is_testing()

    def leak_run_patch_impl(original, *args, **kwargs):
        # Starts a run and never ends it
        mlflow.start_run(nested=True)

    safe_patch(test_autologging_integration, patch_destination, "fn", leak_run_patch_impl)
    patch_destination.fn()

    leaked_run = mlflow.active_run()
    assert leaked_run

    # Close the leaked run so it does not remain active afterwards
    mlflow.end_run()
    assert not mlflow.active_run()
def test_safe_patch_forwards_expected_arguments_to_function_based_patch_implementation(
    patch_destination, test_autologging_integration
):
    """A function-based patch receives the caller's keyword arguments by name."""
    received = {}

    def patch_impl(original, foo, bar=10):
        received.update(foo=foo, bar=bar)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    patch_destination.fn(foo=7, bar=11)

    assert received.get("foo") == 7
    assert received.get("bar") == 11
def test_original_fn_runs_if_patch_should_not_be_applied(patch_destination):
    """With ``exclusive=True`` and a user run active, only the original function executes."""
    patch_invocations = []

    @autologging_integration("test_respects_exclusive")
    def autolog(disable=False, exclusive=False, silent=False):
        def patch_impl(original, *args, **kwargs):
            patch_invocations.append(None)
            return original(*args, **kwargs)

        safe_patch("test_respects_exclusive", patch_destination, "fn", patch_impl)

    autolog(exclusive=True)
    with mlflow.start_run():
        patch_destination.fn()

    assert not patch_invocations
    assert patch_destination.fn_call_count == 1
def test_safe_patch_provides_original_function_with_expected_signature(
    patch_destination, test_autologging_integration
):
    """The ``original`` handed to the patch exposes the underlying function's signature."""

    def original(a, b, c=10, *, d=11):
        return 10

    patch_destination.fn = original

    captured = {}

    def patch_impl(original, *args, **kwargs):
        captured["signature"] = inspect.signature(original)
        return original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    patch_destination.fn(1, 2)

    expected_signature = inspect.signature(original)
    assert captured.get("signature") == expected_signature
Example #17
0
def test_silent_mode_operates_independently_across_integrations(
    patch_destination, logger
):
    """``silent=True`` on one integration does not silence another integration.

    ``integration1`` is enabled with ``silent=True`` and ``integration2`` with
    ``silent=False``. Only ``integration2``'s warnings and event logs should be visible.
    """
    stream = StringIO()
    sys.stderr = stream

    patch_destination.fn2 = lambda *args, **kwargs: "fn2"

    def patch_impl1(original):
        warnings.warn("patchimpl1")
        original()

    @autologging_integration("integration1")
    def autolog1(disable=False, silent=False):
        logger.info("autolog1")
        safe_patch("integration1", patch_destination, "fn", patch_impl1)

    def patch_impl2(original):
        logger.info("patchimpl2")
        original()

    @autologging_integration("integration2")
    def autolog2(disable=False, silent=False):
        warnings.warn_explicit(
            "warn_autolog2", category=Warning, filename=mlflow.__file__, lineno=5
        )
        logger.info("event_autolog2")
        safe_patch("integration2", patch_destination, "fn2", patch_impl2)

    # `pytest.warns(None)` is deprecated since pytest 7 and rejected by pytest 8; record every
    # warning with the standard library instead (this is what `pytest.warns(None)` did)
    with warnings.catch_warnings(record=True) as warnings_record:
        warnings.simplefilter("always")
        autolog1(silent=True)
        autolog2(silent=False)

        patch_destination.fn()
        patch_destination.fn2()

    warning_messages = [str(w.message) for w in warnings_record]
    assert warning_messages == ["warn_autolog2"]

    output = stream.getvalue()
    # Nothing from the silent integration reaches the event log
    assert "autolog1" not in output
    assert "patchimpl1" not in output

    # The non-silent integration logs normally
    assert "event_autolog2" in output
    assert "patchimpl2" in output
def test_safe_patch_makes_expected_event_logging_calls_when_patch_implementation_throws(
    patch_destination, test_autologging_integration, mock_event_logger,
):
    """A patch that raises produces ``patch_error`` instead of ``patch_success``.

    Both cases are covered: the patch raising before calling the original function, and the
    patch raising after it.
    """
    patch_session = None
    exc_to_raise = Exception("thrown from patch")
    throw_location = None

    def patch_impl(original, *args, **kwargs):
        nonlocal patch_session
        patch_session = _AutologgingSessionManager.active_session()

        if throw_location == "before":
            raise exc_to_raise
        original(*args, **kwargs)
        # Reached only when not raising "before", i.e. the patch raises after the original
        raise exc_to_raise

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    # Scenario 1: the patch raises before invoking the original function
    throw_location = "before"
    patch_destination.fn()
    assert [event.method for event in mock_event_logger.calls] == ["patch_start", "patch_error"]
    start_event, error_event = mock_event_logger.calls
    assert start_event.exception is None
    assert error_event.exception == exc_to_raise

    mock_event_logger.reset()

    # Scenario 2: the patch raises after the original function succeeded
    throw_location = "after"
    patch_destination.fn()
    assert [event.method for event in mock_event_logger.calls] == [
        "patch_start",
        "original_start",
        "original_success",
        "patch_error",
    ]
    (
        start_event,
        original_start_event,
        original_success_event,
        error_event,
    ) = mock_event_logger.calls
    assert start_event.exception is None
    assert original_start_event.exception is None
    assert original_success_event.exception is None
    assert error_event.exception == exc_to_raise
def test_autologging_event_logging_and_warnings_respect_silent_mode(
    autolog_function, patch_destination, logger
):
    """``silent=True`` suppresses autologging event logs and rerouted warnings; ``silent=False``
    does not. In both cases ``warnings.showwarning`` is restored and event logging works
    afterwards."""
    og_showwarning = warnings.showwarning
    stream = StringIO()
    sys.stderr = stream

    # `pytest.warns(None)` is deprecated since pytest 7 and rejected by pytest 8; record every
    # warning with the standard library instead (this is what `pytest.warns(None)` did)
    with warnings.catch_warnings(record=True) as silent_warnings_record:
        warnings.simplefilter("always")
        autolog_function(silent=True)
        patch_destination.fn()

    assert len(silent_warnings_record) == 1
    assert "Test warning from OG function" in str(silent_warnings_record[0].message)
    assert not stream.getvalue()

    # Verify that `warnings.showwarning` was restored to its original value after training
    # and that MLflow event logs are enabled
    assert warnings.showwarning == og_showwarning
    logger.info("verify that event logs are enabled")
    assert "verify that event logs are enabled" in stream.getvalue()

    # Rewind before truncating: `truncate(0)` does not move the stream position, so later
    # writes would otherwise be preceded by NUL padding
    stream.seek(0)
    stream.truncate(0)

    with warnings.catch_warnings(record=True) as noisy_warnings_record:
        warnings.simplefilter("always")
        autolog_function(silent=False)
        patch_destination.fn()

    # Verify that calling the autolog function with `silent=False` and invoking the mock training
    # function with autologging enabled produces event logs and warnings
    output = stream.getvalue()
    for item in ["enablement1", "enablement2", "enablement3", "enablement4"]:
        assert item in output

    for item in ["patch1", "patch2", "patch3", "patch4"]:
        assert item in output

    warning_messages = {str(w.message) for w in noisy_warnings_record}
    assert "enablement warning MLflow" in warning_messages

    # Verify that `warnings.showwarning` was restored to its original value after training
    # and that MLflow event logs are enabled
    assert warnings.showwarning == og_showwarning
    logger.info("verify that event logs are enabled")
    assert "verify that event logs are enabled" in stream.getvalue()
def test_safe_patch_preserves_signature_of_patched_function(
    patch_destination, test_autologging_integration
):
    """After patching, ``fn`` still reports the underlying function's signature."""

    def original(a, b, c=10, *, d=11):
        return 10

    patch_destination.fn = original
    patch_calls = []

    def patch_impl(original, *args, **kwargs):
        patch_calls.append((args, kwargs))
        return original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    patch_destination.fn(1, 2)

    assert patch_calls
    patched_signature = inspect.signature(patch_destination.fn)
    assert patched_signature == inspect.signature(original)
def test_safe_patch_does_not_manage_run_if_unspecified(
    patch_destination, test_autologging_integration
):
    """With ``manage_run=False``, safe_patch neither calls ``with_managed_run`` nor leaves a run
    active while the patch executes."""
    active_run = None

    def patch_impl(original, *args, **kwargs):
        nonlocal active_run
        active_run = mlflow.active_run()
        return original(*args, **kwargs)

    # Patch `with_managed_run` in the `safety` module, the same target used to observe it when
    # `manage_run=True`; patching the package-level name would not observe calls made from
    # `safety`, leaving the call count at 0 regardless of behavior
    with mock.patch(
        "mlflow.utils.autologging_utils.safety.with_managed_run", wraps=with_managed_run
    ) as managed_run_mock:
        safe_patch(
            test_autologging_integration, patch_destination, "fn", patch_impl, manage_run=False
        )
        patch_destination.fn()
        assert managed_run_mock.call_count == 0
        assert active_run is None
def test_safe_patch_respects_disable_flag(patch_destination):
    """Calling the autolog function again with ``disable=True`` stops the patch from running."""
    patch_calls = []

    @autologging_integration("test_respects_disable")
    def autolog(disable=False, silent=False):
        def patch_impl(original, *args, **kwargs):
            patch_calls.append(None)
            return original(*args, **kwargs)

        safe_patch("test_respects_disable", patch_destination, "fn", patch_impl)

    autolog(disable=False)
    patch_destination.fn()
    assert len(patch_calls) == 1

    # Once disabled, further calls must not invoke the patch again
    autolog(disable=True)
    patch_destination.fn()
    assert len(patch_calls) == 1
def test_safe_patch_makes_expected_event_logging_calls_when_original_function_throws(
    patch_destination, test_autologging_integration, mock_event_logger,
):
    """An exception from the original function is logged as ``original_error``, and no further
    patch-level event follows it."""
    exc_to_raise = Exception("thrown from patch")

    def raising_original(*args, **kwargs):
        raise exc_to_raise

    patch_destination.fn = raising_original

    def patch_impl(original, *args, **kwargs):
        original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    with pytest.raises(Exception, match="thrown from patch"):
        patch_destination.fn()

    logged = mock_event_logger.calls
    assert [event.method for event in logged] == [
        "patch_start",
        "original_start",
        "original_error",
    ]
    start_event, original_start_event, original_error_event = logged
    assert start_event.exception is None
    assert original_start_event.exception is None
    assert original_error_event.exception == exc_to_raise
def test_safe_patch_throws_when_autologging_runs_are_leaked_in_test_mode(
    patch_destination, test_autologging_integration
):
    """In test mode, a run left open by the patch is an error, unless the user already had a
    run active before the patched call."""
    assert autologging_utils.is_testing()

    def leak_run_patch_impl(original, *args, **kwargs):
        # Starts a run and never ends it
        mlflow.start_run(nested=True)

    safe_patch(test_autologging_integration, patch_destination, "fn", leak_run_patch_impl)

    with pytest.raises(AssertionError, match="leaked an active run"):
        patch_destination.fn()
    # Close the run the patch left open
    mlflow.end_run()

    with mlflow.start_run():
        # A user-created run that predates the autologged call means no leaked-run error
        patch_destination.fn()
        # Close the nested run the patch left open
        mlflow.end_run()

    assert not mlflow.active_run()
def test_autologging_warnings_are_redirected_as_expected(
    autolog_function, patch_destination, logger
):
    """Only the expected categories of warnings are rerouted to MLflow's event logger; the rest
    are emitted normally."""
    stream = StringIO()
    sys.stderr = stream

    # `pytest.warns(None)` is deprecated since pytest 7 and rejected by pytest 8; record every
    # warning with the standard library instead (this is what `pytest.warns(None)` did)
    with warnings.catch_warnings(record=True) as warnings_record:
        warnings.simplefilter("always")
        autolog_function(silent=False)
        patch_destination.fn()

    # The following types of warnings are rerouted to MLflow's event loggers:
    # 1. All MLflow warnings emitted during patch function execution
    # 2. All warnings emitted during the patch function preamble (before the execution of the
    #    original / underlying function) and postamble (after the execution of the underlying
    #    function)
    # 3. non-MLflow warnings emitted during autologging setup / enablement
    #
    # Accordingly, we expect the following warnings to have been emitted normally: 1. MLflow
    # warnings emitted during autologging enablement, 2. non-MLflow warnings emitted during original
    # / underlying function execution
    warning_messages = {str(w.message) for w in warnings_record}
    assert warning_messages == {"enablement warning MLflow", "Test warning from OG function"}

    # Further, we expect MLflow's logging stream to contain content from all warnings emitted during
    # the autologging preamble and postamble and non-MLflow warnings emitted during autologging
    # enablement
    output = stream.getvalue()
    for item in [
        'MLflow autologging encountered a warning: "%s:5: Warning: preamble MLflow warning"',
        'MLflow autologging encountered a warning: "%s:10: Warning: postamble MLflow warning"',
    ]:
        assert (item % mlflow.__file__) in output
    for item in [
        'MLflow autologging encountered a warning: "%s:7: UserWarning: preamble numpy warning"',
        'MLflow autologging encountered a warning: "%s:14: Warning: postamble numpy warning"',
        'MLflow autologging encountered a warning: "%s:30: Warning: enablement warning numpy"',
    ]:
        assert (item % np.__file__) in output
def test_safe_patch_calls_original_function_when_patch_preamble_throws(
    patch_destination, test_autologging_integration
):
    """If the patch raises before calling ``original``, the original still runs exactly once
    and its result is returned to the caller."""
    patch_calls = []

    def patch_impl(original, *args, **kwargs):
        patch_calls.append(True)
        raise Exception("Bad patch preamble")

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    result = patch_destination.fn()

    assert result == PATCH_DESTINATION_FN_DEFAULT_RESULT
    assert patch_destination.fn_call_count == 1
    assert patch_calls
def test_safe_patch_returns_original_result_and_ignores_patch_return_value(
    patch_destination, test_autologging_integration
):
    """A patch that returns its own value without calling ``original`` does not change what the
    caller receives: the original function's result is returned."""
    patch_calls = []

    def patch_impl(original, *args, **kwargs):
        patch_calls.append(True)
        return 10

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    result = patch_destination.fn()

    assert result == PATCH_DESTINATION_FN_DEFAULT_RESULT
    assert patch_destination.fn_call_count == 1
    assert patch_calls
def test_safe_patch_does_not_validate_autologging_runs_in_standard_mode(
    patch_destination, test_autologging_integration
):
    """Outside test mode, ``_validate_autologging_run`` is never called, with or without a
    pre-existing user run."""
    assert not autologging_utils.is_testing()

    def no_tag_run_patch_impl(original, *args, **kwargs):
        with mlflow.start_run(nested=True):
            return original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", no_tag_run_patch_impl)

    validate_target = "mlflow.utils.autologging_utils.safety._validate_autologging_run"
    with mock.patch(validate_target, wraps=_validate_autologging_run) as validate_run_mock:
        # No user run active
        patch_destination.fn()

        # A user-created run exists before the autologged call
        with mlflow.start_run(nested=True):
            patch_destination.fn()

        assert not validate_run_mock.called
def test_safe_patch_returns_original_result_without_second_call_when_patch_postamble_throws(
    patch_destination, test_autologging_integration
):
    """If the patch raises after calling ``original``, the caller gets the original result and
    the original function is not invoked a second time."""
    patch_calls = []

    def patch_impl(original, *args, **kwargs):
        patch_calls.append(True)
        original(*args, **kwargs)
        raise Exception("Bad patch postamble")

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    result = patch_destination.fn()

    assert result == PATCH_DESTINATION_FN_DEFAULT_RESULT
    assert patch_destination.fn_call_count == 1
    assert patch_calls
def test_safe_patch_provides_expected_original_function(
    patch_destination, test_autologging_integration
):
    """The ``original`` handed to the patch is the underlying function, so arguments rewritten
    by the patch are reflected in the returned value."""

    def original_fn(foo, bar=10):
        return dict(foo=foo, bar=bar)

    patch_destination.fn = original_fn

    def patch_impl(original, foo, bar):
        return original(foo + 1, bar + 2)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    result = patch_destination.fn(1, 2)
    assert result == {"foo": 2, "bar": 4}