コード例 #1
0
def test_universal_autolog_makes_expected_event_logging_calls():
    """`mlflow.autolog()` should record exactly one "mlflow" autolog-called event.

    A recording logger is installed as the global ``AutologgingEventLogger`` only for
    the duration of the ``mlflow.autolog()`` call. The previously installed logger is
    restored afterwards so the replacement does not leak into other tests.
    """

    class TestLogger(AutologgingEventLogger):
        """Records every ``log_autolog_called`` invocation for later inspection."""

        LoggerCall = namedtuple("LoggerCall", ["integration", "call_args", "call_kwargs"])

        def __init__(self):
            self.calls = []

        def reset(self):
            self.calls = []

        def log_autolog_called(self, integration, call_args, call_kwargs):
            self.calls.append(TestLogger.LoggerCall(integration, call_args, call_kwargs))

    prev_logger = AutologgingEventLogger.get_logger()
    logger = TestLogger()
    AutologgingEventLogger.set_logger(logger)
    try:
        mlflow.autolog(exclusive=True, disable=True)
    finally:
        AutologgingEventLogger.set_logger(prev_logger)

    # The logger may also record events for other integrations; only the universal
    # "mlflow" event is checked here.
    universal_autolog_event_logging_calls = [
        call for call in logger.calls if call.integration == "mlflow"
    ]
    assert len(universal_autolog_event_logging_calls) == 1
    call = universal_autolog_event_logging_calls[0]
    # Subset check: further autolog() parameters may also be forwarded as kwargs.
    assert {"disable": True, "exclusive": True}.items() <= call.call_kwargs.items()
コード例 #2
0
def test_autologging_integration_makes_expected_event_logging_calls():
    """Autologging integrations report their arguments to the event logger.

    Covers both a successful and a failing ``autolog()``. Each must produce exactly one
    ``log_autolog_called`` event whose ``call_kwargs`` hold every argument, including
    defaulted ones. The previously installed global event logger is restored on exit.
    """

    @autologging_integration("test_success")
    def autolog_success(foo, bar=7, disable=False, silent=False):
        pass

    @autologging_integration("test_failure")
    def autolog_failure(biz, baz="val", disable=False, silent=False):
        raise Exception("autolog failed")

    class TestLogger(AutologgingEventLogger):
        """Records every ``log_autolog_called`` invocation for later inspection."""

        LoggerCall = namedtuple("LoggerCall", ["integration", "call_args", "call_kwargs"])

        def __init__(self):
            self.calls = []

        def reset(self):
            self.calls = []

        def log_autolog_called(self, integration, call_args, call_kwargs):
            self.calls.append(TestLogger.LoggerCall(integration, call_args, call_kwargs))

    prev_logger = AutologgingEventLogger.get_logger()
    logger = TestLogger()
    AutologgingEventLogger.set_logger(logger)
    try:
        autolog_success("a", bar=9, disable=True)
        assert len(logger.calls) == 1
        call = logger.calls[0]
        assert call.integration == "test_success"
        # NB: In MLflow > 1.13.1, the `call_args` argument to `log_autolog_called` is deprecated.
        # Positional arguments passed to `autolog()` should be forwarded to `log_autolog_called`
        # in keyword format
        assert call.call_args == ()
        assert call.call_kwargs == {
            "foo": "a",
            "bar": 9,
            "disable": True,
            "silent": False,
        }

        logger.reset()

        with pytest.raises(Exception, match="autolog failed"):
            autolog_failure(82, disable=False, silent=True)
        assert len(logger.calls) == 1
        call = logger.calls[0]
        assert call.integration == "test_failure"
        # NB: In MLflow > 1.13.1, the `call_args` argument to `log_autolog_called` is deprecated.
        # Positional arguments passed to `autolog()` should be forwarded to `log_autolog_called`
        # in keyword format
        assert call.call_args == ()
        assert call.call_kwargs == {
            "biz": 82,
            "baz": "val",
            "disable": False,
            "silent": True,
        }
    finally:
        AutologgingEventLogger.set_logger(prev_logger)
コード例 #3
0
def mock_event_logger():
    """Install a ``MockEventLogger`` as the global autologging event logger.

    Yields:
        The installed ``MockEventLogger``. On teardown the logger that was installed
        before this generator ran is reinstated.
    """
    # Capture the current logger *before* entering `try`. If this call happened inside
    # the `try` and raised, the `finally` clause would reference an unbound
    # `prev_logger` and raise UnboundLocalError, masking the original exception.
    prev_logger = AutologgingEventLogger.get_logger()
    try:
        logger = MockEventLogger()
        AutologgingEventLogger.set_logger(logger)
        yield logger
    finally:
        AutologgingEventLogger.set_logger(prev_logger)
コード例 #4
0
def test_autologging_event_logger_default_impl_warns_for_log_autolog_called_with_deprecated_args():
    """The default logger warns when deprecated positional ``call_args`` are supplied.

    The previously installed global event logger is restored on exit.
    """
    prev_logger = AutologgingEventLogger.get_logger()
    AutologgingEventLogger.set_logger(AutologgingEventLogger())
    try:
        with pytest.warns(DeprecationWarning, match="Received 1 positional arguments"):
            AutologgingEventLogger.get_logger().log_autolog_called(
                "test_integration",
                # call_args is deprecated in MLflow > 1.13.1; specifying a non-empty
                # value for this parameter should emit a warning. The trailing comma
                # matters: `("a")` is just the string "a", not a one-element tuple.
                call_args=("a",),
                call_kwargs={"b": "c"},
            )
    finally:
        AutologgingEventLogger.set_logger(prev_logger)
コード例 #5
0
def test_autologging_integration_succeeds_when_event_logging_throws_in_standard_mode():
    """An exception raised by the event logger must not break the ``autolog()`` call.

    The previously installed global event logger is restored on exit.
    """

    @autologging_integration("test")
    def autolog(disable=False, silent=False):
        return "result"

    class ThrowingLogger(AutologgingEventLogger):
        """Logger that records that it was invoked and then raises."""

        def __init__(self):
            self.logged_event = False

        def log_autolog_called(self, integration, call_args, call_kwargs):
            self.logged_event = True
            raise Exception("autolog failed")

    prev_logger = AutologgingEventLogger.get_logger()
    logger = ThrowingLogger()
    AutologgingEventLogger.set_logger(logger)
    try:
        assert autolog() == "result"
        # Confirms the logger was actually reached. The exception was raised and
        # tolerated; the logging call was not simply skipped.
        assert logger.logged_event
    finally:
        AutologgingEventLogger.set_logger(prev_logger)
コード例 #6
0
def test_autologging_integration_makes_expected_event_logging_calls():
    """Autologging integrations report their call arguments to the event logger.

    Covers both a successful and a failing ``autolog()``. Each must produce exactly one
    ``log_autolog_called`` event carrying the positional arguments in ``call_args`` and
    the explicitly passed keyword arguments in ``call_kwargs``. The previously installed
    global event logger is restored on exit.
    """

    @autologging_integration("test_success")
    def autolog_success(foo, bar=7, disable=False):
        pass

    @autologging_integration("test_failure")
    def autolog_failure(biz, baz="val", disable=False):
        raise Exception("autolog failed")

    class TestLogger(AutologgingEventLogger):
        """Records every ``log_autolog_called`` invocation for later inspection."""

        LoggerCall = namedtuple("LoggerCall", ["integration", "call_args", "call_kwargs"])

        def __init__(self):
            self.calls = []

        def reset(self):
            self.calls = []

        def log_autolog_called(self, integration, call_args, call_kwargs):
            self.calls.append(TestLogger.LoggerCall(integration, call_args, call_kwargs))

    prev_logger = AutologgingEventLogger.get_logger()
    logger = TestLogger()
    AutologgingEventLogger.set_logger(logger)
    try:
        autolog_success("a", bar=9, disable=True)
        assert len(logger.calls) == 1
        call = logger.calls[0]
        assert call.integration == "test_success"
        assert call.call_args == ("a",)
        assert call.call_kwargs == {"bar": 9, "disable": True}

        logger.reset()

        with pytest.raises(Exception, match="autolog failed"):
            autolog_failure(82, baz="b", disable=False)
        assert len(logger.calls) == 1
        call = logger.calls[0]
        assert call.integration == "test_failure"
        assert call.call_args == (82,)
        assert call.call_kwargs == {"baz": "b", "disable": False}
    finally:
        AutologgingEventLogger.set_logger(prev_logger)
コード例 #7
0
def test_autologging_event_logger_default_implementation_does_not_throw_for_valid_inputs():
    """Every hook of the default ``AutologgingEventLogger`` accepts valid inputs.

    Drives the hooks through three workflows:
    - a successful call;
    - a failure inside the patch function;
    - a failure inside the original function.

    The previously installed global event logger is restored on exit.
    """

    class PatchObj:
        def test_fn(self):
            pass

    prev_logger = AutologgingEventLogger.get_logger()
    AutologgingEventLogger.set_logger(AutologgingEventLogger())
    try:
        logger = AutologgingEventLogger.get_logger()
        # Positional call arguments must be a real tuple: `(1000)` is just the int 1000.
        hook_args = (
            AutologgingSession("test_integration", "123"),
            PatchObj(),
            "test_fn",
            (1000,),
            {"a": 2},
        )

        # Test successful autologging workflow. `call_args` is left empty because
        # positional `call_args` are deprecated in MLflow > 1.13.1.
        logger.log_autolog_called("test_integration", (), {"b": 1, "c": "d"})
        logger.log_patch_function_start(*hook_args)
        logger.log_original_function_start(*hook_args)
        logger.log_original_function_success(*hook_args)
        logger.log_patch_function_success(*hook_args)

        # Test patch function failure autologging workflow
        logger.log_patch_function_start(*hook_args)
        logger.log_patch_function_error(*hook_args, Exception("patch error"))

        # Test original function failure autologging workflow: the original's error is
        # reported by `log_original_function_error` and then surfaces as a patch error.
        original_error = Exception("original error")
        logger.log_patch_function_start(*hook_args)
        logger.log_original_function_start(*hook_args)
        logger.log_original_function_error(*hook_args, original_error)
        logger.log_patch_function_error(*hook_args, original_error)
    finally:
        AutologgingEventLogger.set_logger(prev_logger)
コード例 #8
0
def test_safe_patch_succeeds_when_event_logging_throws_in_standard_mode(
    patch_destination, test_autologging_integration,
):
    """Exceptions raised by every event-logger hook must not break a safe_patch'd call.

    The patched function must still return the original result. Both halves of the
    patch implementation must run, and every expected event must reach the logger.
    The previously installed global event logger is restored on exit.
    """
    patch_preamble_called = False
    patch_postamble_called = False

    def patch_impl(original, *args, **kwargs):
        nonlocal patch_preamble_called, patch_postamble_called
        patch_preamble_called = True
        original(*args, **kwargs)
        patch_postamble_called = True

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    class ThrowingLogger(MockEventLogger):
        """Delegates each hook to ``MockEventLogger`` (which records it) and then raises."""

        def log_patch_function_start(
            self, session, patch_obj, function_name, call_args, call_kwargs
        ):
            super().log_patch_function_start(
                session, patch_obj, function_name, call_args, call_kwargs
            )
            raise Exception("failed")

        def log_patch_function_success(
            self, session, patch_obj, function_name, call_args, call_kwargs
        ):
            super().log_patch_function_success(
                session, patch_obj, function_name, call_args, call_kwargs
            )
            raise Exception("failed")

        def log_patch_function_error(
            self, session, patch_obj, function_name, call_args, call_kwargs, exception
        ):
            super().log_patch_function_error(
                session, patch_obj, function_name, call_args, call_kwargs, exception
            )
            raise Exception("failed")

        def log_original_function_start(
            self, session, patch_obj, function_name, call_args, call_kwargs
        ):
            super().log_original_function_start(
                session, patch_obj, function_name, call_args, call_kwargs
            )
            raise Exception("failed")

        def log_original_function_success(
            self, session, patch_obj, function_name, call_args, call_kwargs
        ):
            super().log_original_function_success(
                session, patch_obj, function_name, call_args, call_kwargs
            )
            raise Exception("failed")

        def log_original_function_error(
            self, session, patch_obj, function_name, call_args, call_kwargs, exception
        ):
            super().log_original_function_error(
                session, patch_obj, function_name, call_args, call_kwargs, exception
            )
            raise Exception("failed")

    prev_logger = AutologgingEventLogger.get_logger()
    logger = ThrowingLogger()
    AutologgingEventLogger.set_logger(logger)
    try:
        assert patch_destination.fn() == PATCH_DESTINATION_FN_DEFAULT_RESULT
        assert patch_preamble_called
        assert patch_postamble_called
        expected_calls = ["patch_start", "original_start", "original_success", "patch_success"]
        assert [call.method for call in logger.calls] == expected_calls
    finally:
        AutologgingEventLogger.set_logger(prev_logger)