# Example No. 1
def test_with_managed_run_sets_specified_run_tags():
    """Tags passed to ``with_managed_run`` must be recorded on the managed run.

    Exercised for both patch forms: a plain patch function and a
    ``PatchFunction`` subclass. Each patch returns ``mlflow.active_run()`` so
    the test can look up the run that was used and inspect its tags.
    """
    client = MlflowClient()
    tags_to_set = {
        "foo": "bar",
        "num_layers": "7",
    }

    # The integration name is passed first, matching every other
    # `with_managed_run` call in this module.
    patch_function_1 = with_managed_run(
        "test_integration",
        lambda original, *args, **kwargs: mlflow.active_run(),
        tags=tags_to_set,
    )
    run1 = patch_function_1(lambda: "foo")
    # Subset check: the run may carry extra (e.g. system) tags besides ours.
    assert tags_to_set.items() <= client.get_run(
        run1.info.run_id).data.tags.items()

    class PatchFunction2(PatchFunction):
        def _patch_implementation(self, original, *args, **kwargs):
            return mlflow.active_run()

        def _on_exception(self, exception):
            pass

    patch_function_2 = with_managed_run(
        "test_integration", PatchFunction2, tags=tags_to_set
    )
    run2 = patch_function_2.call(lambda: "foo")
    assert tags_to_set.items() <= client.get_run(
        run2.info.run_id).data.tags.items()
# Example No. 2
def test_with_managed_runs_yields_functions_and_classes_as_expected():
    """``with_managed_run`` preserves the kind of patch it wraps.

    Wrapping a plain function must yield a callable, and wrapping a
    ``PatchFunction`` subclass must yield a class.
    """
    def patch_function(original, *args, **kwargs):
        pass

    class TestPatch(PatchFunction):
        def _patch_implementation(self, original, *args, **kwargs):
            pass

        def _on_exception(self, exception):
            pass

    # The integration name is passed first, matching every other
    # `with_managed_run` call in this module.
    assert callable(with_managed_run("test_integration", patch_function))
    assert inspect.isclass(with_managed_run("test_integration", TestPatch))
def test_with_managed_run_with_throwing_function_exhibits_expected_behavior():
    """A throwing patch function fails its own run but leaves a caller's run alone.

    Without an active run, the run used by the patch must end FAILED. Inside
    ``mlflow.start_run()``, the patch must see the caller's run, and that run
    must still be RUNNING after the patch raises.
    """
    client = MlflowClient()
    patch_function_active_run = None

    def patch_function(original, *args, **kwargs):
        # Record which run the patch observed, then fail.
        nonlocal patch_function_active_run
        patch_function_active_run = mlflow.active_run()
        raise Exception("bad implementation")

    patch_function = with_managed_run("test_integration", patch_function)

    with pytest.raises(Exception, match="bad implementation"):
        patch_function(lambda: "foo")

    assert patch_function_active_run is not None
    status1 = client.get_run(patch_function_active_run.info.run_id).info.status
    assert RunStatus.from_string(status1) == RunStatus.FAILED

    with mlflow.start_run() as active_run:
        # Only the raising call belongs inside `pytest.raises`. Any statement
        # after it in that context would be skipped once the exception fires.
        with pytest.raises(Exception, match="bad implementation"):
            patch_function(lambda: "foo")
        assert patch_function_active_run == active_run
        # `with_managed_run` should not terminate a preexisting MLflow run,
        # even if the patch function throws; the run is still open here.
        status2 = client.get_run(active_run.info.run_id).info.status
        assert RunStatus.from_string(status2) == RunStatus.RUNNING

    # Leaving the `start_run()` block normally is what finishes the caller's run.
    status3 = client.get_run(active_run.info.run_id).info.status
    assert RunStatus.from_string(status3) == RunStatus.FINISHED
def test_with_managed_run_ends_run_on_keyboard_interrupt():
    """A ``KeyboardInterrupt`` from the original function ends the run as FAILED.

    Checked for both patch forms: a plain patch function that forwards to
    ``original``, and a ``PatchFunction`` subclass that does the same. After
    the interrupt propagates, no run may remain active.
    """
    client = MlflowClient()
    run = None

    def original():
        # Remember the run in effect during the call, then simulate Ctrl-C.
        nonlocal run
        run = mlflow.active_run()
        raise KeyboardInterrupt

    def expect_failed_run_after_interrupt(invoke):
        # Drive the wrapped patch and require that the interrupt escapes,
        # the active run is cleared, and the recorded run is FAILED.
        with pytest.raises(KeyboardInterrupt):
            invoke(original)
        assert not mlflow.active_run()
        recorded_status = client.get_run(run.info.run_id).info.status
        assert RunStatus.from_string(recorded_status) == RunStatus.FAILED

    forwarding_patch = with_managed_run(
        "test_integration", lambda original, *args, **kwargs: original(*args, **kwargs)
    )
    expect_failed_run_after_interrupt(forwarding_patch)

    class PatchFunction2(PatchFunction):
        def _patch_implementation(self, original, *args, **kwargs):
            return original(*args, **kwargs)

        def _on_exception(self, exception):
            pass

    forwarding_patch_cls = with_managed_run("test_integration", PatchFunction2)
    expect_failed_run_after_interrupt(forwarding_patch_cls.call)
def test_with_managed_run_with_non_throwing_function_exhibits_expected_behavior():
    """A successful patch function runs under a run that ends FINISHED.

    With no active run, the patch observes a run that is FINISHED afterwards.
    Inside ``mlflow.start_run()``, the patch observes the caller's own run.
    """
    client = MlflowClient()

    def status_of(some_run):
        return RunStatus.from_string(client.get_run(some_run.info.run_id).info.status)

    def patch_function(original, *args, **kwargs):
        return mlflow.active_run()

    managed_patch = with_managed_run("test_integration", patch_function)

    # No preexisting run.
    standalone_run = managed_patch(lambda: "foo")
    assert status_of(standalone_run) == RunStatus.FINISHED

    # Preexisting run: the patch must see exactly that run.
    with mlflow.start_run() as active_run:
        nested_run = managed_patch(lambda: "foo")

    assert nested_run == active_run
    assert status_of(nested_run) == RunStatus.FINISHED
def test_with_managed_run_with_non_throwing_class_exhibits_expected_behavior():
    """A successful ``PatchFunction`` subclass runs under a run that ends FINISHED.

    With no active run, the patch observes a run that is FINISHED afterwards.
    Inside ``mlflow.start_run()``, the patch observes the caller's own run.
    """
    client = MlflowClient()

    class TestPatch(PatchFunction):
        def _patch_implementation(self, original, *args, **kwargs):
            return mlflow.active_run()

        def _on_exception(self, exception):
            pass

    managed_patch_cls = with_managed_run("test_integration", TestPatch)

    def status_of(some_run):
        return RunStatus.from_string(client.get_run(some_run.info.run_id).info.status)

    # No preexisting run.
    standalone_run = managed_patch_cls.call(lambda: "foo")
    assert status_of(standalone_run) == RunStatus.FINISHED

    # Preexisting run: the patch must see exactly that run.
    with mlflow.start_run() as active_run:
        nested_run = managed_patch_cls.call(lambda: "foo")

    assert nested_run == active_run
    assert status_of(nested_run) == RunStatus.FINISHED