def test_with_managed_run_sets_specified_run_tags():
    """Runs created by ``with_managed_run`` carry the tags passed via ``tags=``.

    Covered for both patch forms: a plain function and a ``PatchFunction``
    subclass (invoked through ``.call``).
    """
    client = MlflowClient()
    tags_to_set = {
        "foo": "bar",
        "num_layers": "7",
    }

    # The integration name goes first, matching every other ``with_managed_run``
    # call in this module; the patch function is the second positional argument.
    patch_function_1 = with_managed_run(
        "test_integration",
        lambda original, *args, **kwargs: mlflow.active_run(),
        tags=tags_to_set,
    )
    run1 = patch_function_1(lambda: "foo")
    # Subset comparison: extra tags on the run are tolerated.
    assert tags_to_set.items() <= client.get_run(run1.info.run_id).data.tags.items()

    class PatchFunction2(PatchFunction):
        def _patch_implementation(self, original, *args, **kwargs):
            return mlflow.active_run()

        def _on_exception(self, exception):
            pass

    patch_function_2 = with_managed_run("test_integration", PatchFunction2, tags=tags_to_set)
    run2 = patch_function_2.call(lambda: "foo")
    assert tags_to_set.items() <= client.get_run(run2.info.run_id).data.tags.items()
def test_with_managed_runs_yields_functions_and_classes_as_expected():
    """Wrapping a function yields a callable; wrapping a ``PatchFunction``
    subclass yields a class."""

    def patch_function(original, *args, **kwargs):
        pass

    class TestPatch(PatchFunction):
        def _patch_implementation(self, original, *args, **kwargs):
            pass

        def _on_exception(self, exception):
            pass

    # Integration name first, consistent with the other tests in this module.
    assert callable(with_managed_run("test_integration", patch_function))
    assert inspect.isclass(with_managed_run("test_integration", TestPatch))
def test_with_managed_run_with_throwing_function_exhibits_expected_behavior():
    """A throwing patch function ends its managed run as FAILED, but a
    preexisting caller-owned run is left for the caller to finish."""
    client = MlflowClient()
    error_message = "bad implementation"
    patch_function_active_run = None

    def patch_function(original, *args, **kwargs):
        nonlocal patch_function_active_run
        patch_function_active_run = mlflow.active_run()
        raise Exception(error_message)

    patch_function = with_managed_run("test_integration", patch_function)

    # Match on the injected message so that an unrelated error (e.g. a bad
    # call signature) cannot satisfy the expectation.
    with pytest.raises(Exception, match=error_message):
        patch_function(lambda: "foo")
    assert patch_function_active_run is not None
    status1 = client.get_run(patch_function_active_run.info.run_id).info.status
    assert RunStatus.from_string(status1) == RunStatus.FAILED

    with mlflow.start_run() as active_run, pytest.raises(Exception, match=error_message):
        patch_function(lambda: "foo")
    assert patch_function_active_run == active_run
    # `with_managed_run` should not terminate a preexisting MLflow run,
    # even if the patch function throws
    status2 = client.get_run(active_run.info.run_id).info.status
    assert RunStatus.from_string(status2) == RunStatus.FINISHED
def test_with_managed_run_ends_run_on_keyboard_interrupt():
    """A ``KeyboardInterrupt`` raised by ``original`` ends the managed run as
    FAILED and leaves no active run behind.

    Covered for both a plain patch function and a ``PatchFunction`` subclass.
    """
    client = MlflowClient()
    run = None

    def original():
        nonlocal run
        run = mlflow.active_run()
        raise KeyboardInterrupt

    patch_function_1 = with_managed_run(
        "test_integration", lambda original, *args, **kwargs: original(*args, **kwargs)
    )

    with pytest.raises(KeyboardInterrupt):
        patch_function_1(original)

    assert mlflow.active_run() is None
    assert run is not None
    run_status_1 = client.get_run(run.info.run_id).info.status
    assert RunStatus.from_string(run_status_1) == RunStatus.FAILED
    first_run_id = run.info.run_id

    class PatchFunction2(PatchFunction):
        def _patch_implementation(self, original, *args, **kwargs):
            return original(*args, **kwargs)

        def _on_exception(self, exception):
            pass

    patch_function_2 = with_managed_run("test_integration", PatchFunction2)

    # Reset the captured run: otherwise, if `original` were never reached, the
    # assertions below would silently re-check the first (already FAILED) run.
    run = None
    with pytest.raises(KeyboardInterrupt):
        patch_function_2.call(original)

    assert mlflow.active_run() is None
    assert run is not None
    assert run.info.run_id != first_run_id
    run_status_2 = client.get_run(run.info.run_id).info.status
    assert RunStatus.from_string(run_status_2) == RunStatus.FAILED
def test_with_managed_run_with_non_throwing_function_exhibits_expected_behavior():
    """A successful patch function ends its managed run as FINISHED and reuses
    (without ending) a preexisting caller-owned run."""
    client = MlflowClient()

    def patch_function(original, *args, **kwargs):
        return mlflow.active_run()

    patch_function = with_managed_run("test_integration", patch_function)

    run1 = patch_function(lambda: "foo")
    run1_status = client.get_run(run1.info.run_id).info.status
    assert RunStatus.from_string(run1_status) == RunStatus.FINISHED

    with mlflow.start_run() as active_run:
        run2 = patch_function(lambda: "foo")
        # The caller-owned run must still be open after the patched call. The
        # post-`with` FINISHED check alone cannot tell whether `with_managed_run`
        # ended it early, because leaving the block finishes it either way.
        in_block_status = client.get_run(active_run.info.run_id).info.status
        assert RunStatus.from_string(in_block_status) == RunStatus.RUNNING

    assert run2 == active_run
    run2_status = client.get_run(run2.info.run_id).info.status
    assert RunStatus.from_string(run2_status) == RunStatus.FINISHED
def test_with_managed_run_with_non_throwing_class_exhibits_expected_behavior():
    """A successful ``PatchFunction`` subclass ends its managed run as FINISHED
    and reuses (without ending) a preexisting caller-owned run."""
    client = MlflowClient()

    class TestPatch(PatchFunction):
        def _patch_implementation(self, original, *args, **kwargs):
            return mlflow.active_run()

        def _on_exception(self, exception):
            pass

    TestPatch = with_managed_run("test_integration", TestPatch)

    run1 = TestPatch.call(lambda: "foo")
    run1_status = client.get_run(run1.info.run_id).info.status
    assert RunStatus.from_string(run1_status) == RunStatus.FINISHED

    with mlflow.start_run() as active_run:
        run2 = TestPatch.call(lambda: "foo")
        # The caller-owned run must still be open after the patched call. The
        # post-`with` FINISHED check alone cannot tell whether `with_managed_run`
        # ended it early, because leaving the block finishes it either way.
        in_block_status = client.get_run(active_run.info.run_id).info.status
        assert RunStatus.from_string(in_block_status) == RunStatus.RUNNING

    assert run2 == active_run
    run2_status = client.get_run(run2.info.run_id).info.status
    assert RunStatus.from_string(run2_status) == RunStatus.FINISHED