def test_start_and_end_run():
    """Log a metric inside a ``start_run()`` context and verify it is persisted.

    NOTE(review): the original comment claimed this exercised start_run()/
    end_run() without a `with` block, but the code actually uses the context
    manager (the no-`with` variant is a separate test); the comment was stale
    and has been corrected.
    """
    with start_run() as active_run:
        mlflow.log_metric("name_1", 25)
    finished_run = tracking.MlflowClient().get_run(active_run.info.run_id)
    # Validate metrics
    assert len(finished_run.data.metrics) == 1
    assert finished_run.data.metrics["name_1"] == 25
def test_start_run_context_manager():
    """Check run persistence and terminal statuses for passing and failing runs."""
    with start_run() as first_run:
        first_uuid = first_run.info.run_id
        # Entering the context manager should already have persisted the run.
        persisted_run = tracking.MlflowClient().get_run(first_uuid)
        assert persisted_run is not None
        assert persisted_run.info == first_run.info
    # Exiting normally should mark the run FINISHED.
    finished_run = tracking.MlflowClient().get_run(first_uuid)
    assert finished_run.info.status == RunStatus.to_string(RunStatus.FINISHED)
    # A run whose body raises should end up FAILED, under a fresh run UUID.
    with pytest.raises(Exception):
        with start_run() as second_run:
            second_run_id = second_run.info.run_id
            raise Exception("Failing run!")
    assert second_run_id != first_uuid
    finished_run2 = tracking.MlflowClient().get_run(second_run_id)
    assert finished_run2.info.status == RunStatus.to_string(RunStatus.FAILED)
def test_log_param(tracking_uri_mock):
    """log_param should persist values (stringified) including nested keys."""
    with start_run() as run:
        the_run_id = run.info.run_id
        mlflow.log_param("name_1", "a")
        mlflow.log_param("name_2", "b")
        mlflow.log_param("nested/nested/name", 5)
    stored_run = tracking.MlflowClient().get_run(the_run_id)
    # The integer 5 is persisted as the string "5" by the tracking store.
    assert stored_run.data.params == {
        "name_1": "a",
        "name_2": "b",
        "nested/nested/name": "5",
    }
def test_log_artifact_with_dirs(tmpdir):
    """log_artifact should copy whole directory trees, optionally under a parent path."""
    # Case 1: log a directory (two files plus an empty child dir), no parent path.
    art_dir = tmpdir.mkdir("parent")
    file0 = art_dir.join("file0")
    file0.write("something")
    file1 = art_dir.join("file1")
    file1.write("something")
    sub_dir = art_dir.mkdir("child")
    with start_run():
        artifact_uri = mlflow.get_artifact_uri()
        run_artifact_dir = local_file_uri_to_path(artifact_uri)
        mlflow.log_artifact(str(art_dir))
        base = os.path.basename(str(art_dir))
        assert os.listdir(run_artifact_dir) == [base]
        assert set(os.listdir(os.path.join(run_artifact_dir, base))) == {
            "child",
            "file0",
            "file1",
        }
        with open(os.path.join(run_artifact_dir, base, "file0")) as f:
            assert f.read() == "something"
    # Case 2: log a directory under an explicitly specified parent folder.
    art_dir = tmpdir.mkdir("dir")
    with start_run():
        artifact_uri = mlflow.get_artifact_uri()
        run_artifact_dir = local_file_uri_to_path(artifact_uri)
        mlflow.log_artifact(str(art_dir), "some_parent")
        assert os.listdir(run_artifact_dir) == [os.path.basename("some_parent")]
        assert os.listdir(os.path.join(run_artifact_dir, "some_parent")) == [
            os.path.basename(str(art_dir))
        ]
    # Case 3: log a directory containing a subdirectory under a nested parent path.
    sub_dir = art_dir.mkdir("another_dir")
    with start_run():
        artifact_uri = mlflow.get_artifact_uri()
        run_artifact_dir = local_file_uri_to_path(artifact_uri)
        mlflow.log_artifact(str(art_dir), "parent/and_child")
        nested = os.path.join(run_artifact_dir, "parent", "and_child")
        assert os.listdir(nested) == [os.path.basename(str(art_dir))]
        copied = os.path.join(nested, os.path.basename(str(art_dir)))
        assert set(os.listdir(copied)) == {os.path.basename(str(sub_dir))}
def test_set_experiment_with_deleted_experiment_name(tracking_uri_mock):
    """set_experiment must raise once the named experiment has been deleted."""
    name = "dead_exp"
    mlflow.set_experiment(name)
    with start_run() as run:
        exp_id = run.info.experiment_id
    # Deleting the experiment makes it unusable for subsequent set_experiment calls.
    tracking.MlflowClient().delete_experiment(exp_id)
    with pytest.raises(MlflowException):
        mlflow.set_experiment(name)
def test_log_batch_duplicate_entries_raises():
    """log_batch must reject a param list that contains duplicate keys."""
    with start_run() as active_run:
        run_id = active_run.info.run_id
        with pytest.raises(
            MlflowException, match=r"Duplicate parameter keys have been submitted."
        ) as e:
            tracking.MlflowClient().log_batch(
                run_id=run_id, params=[Param("a", "1"), Param("a", "2")]
            )
        # The failure should surface as an INVALID_PARAMETER_VALUE error code.
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
def test_log_batch(tracking_uri_mock, tmpdir):
    """log_batch should persist metrics, params, and tags in one call (FileStore)."""
    expected_metrics = {"metric-key0": 1.0, "metric-key1": 4.0}
    expected_params = {"param-key0": "param-val0", "param-key1": "param-val1"}
    exact_expected_tags = {"tag-key0": "tag-val0", "tag-key1": "tag-val1"}
    approx_expected_tags = set([MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE])
    t = int(time.time())
    sorted_expected_metrics = sorted(expected_metrics.items(), key=lambda kv: kv[0])
    metrics = [
        Metric(key=key, value=value, timestamp=t, step=i)
        for i, (key, value) in enumerate(sorted_expected_metrics)
    ]
    params = [Param(key=key, value=value) for key, value in expected_params.items()]
    tags = [RunTag(key=key, value=value) for key, value in exact_expected_tags.items()]
    with start_run() as active_run:
        run_uuid = active_run.info.run_uuid
        mlflow.tracking.MlflowClient().log_batch(
            run_id=run_uuid, metrics=metrics, params=params, tags=tags
        )
    finished_run = tracking.MlflowClient().get_run(run_uuid)
    # Every batched metric should be retrievable on the finished run.
    assert len(finished_run.data.metrics) == 2
    for key, value in finished_run.data.metrics.items():
        assert expected_metrics[key] == value
    # TODO: use client get_metric_history API here instead once it exists
    fs = FileStore(os.path.join(tmpdir.strpath, "mlruns"))
    history0 = fs.get_metric_history(run_uuid, "metric-key0")
    assert set([(m.value, m.timestamp, m.step) for m in history0]) == set([(1.0, t, 0)])
    history1 = fs.get_metric_history(run_uuid, "metric-key1")
    assert set([(m.value, m.timestamp, m.step) for m in history1]) == set([(4.0, t, 1)])
    # Tags: exact user-supplied tags plus automatically-set context tags.
    assert len(finished_run.data.tags) == len(exact_expected_tags) + len(
        approx_expected_tags
    )
    for tag_key, tag_value in finished_run.data.tags.items():
        if tag_key in approx_expected_tags:
            pass
        else:
            assert exact_expected_tags[tag_key] == tag_value
    # Validate params
    assert finished_run.data.params == expected_params
def test_start_run_overrides_databricks_notebook(empty_active_run_stack):
    """Explicit start_run() arguments should override Databricks notebook
    defaults, while notebook context tags are still attached to the run."""
    databricks_notebook_patch = mock.patch(
        "mlflow.utils.databricks_utils.is_in_databricks_notebook", return_value=True)
    mock_notebook_id = mock.Mock()
    notebook_id_patch = mock.patch(
        "mlflow.utils.databricks_utils.get_notebook_id", return_value=mock_notebook_id)
    mock_notebook_path = mock.Mock()
    notebook_path_patch = mock.patch(
        "mlflow.utils.databricks_utils.get_notebook_path", return_value=mock_notebook_path)
    mock_webapp_url = mock.Mock()
    webapp_url_patch = mock.patch(
        "mlflow.utils.databricks_utils.get_webapp_url", return_value=mock_webapp_url)
    create_run_patch = mock.patch.object(MlflowClient, "create_run")
    # Caller-supplied values that must win over the notebook-derived defaults.
    mock_experiment_id = mock.Mock()
    mock_source_name = mock.Mock()
    mock_source_type = mock.Mock()
    mock_source_version = mock.Mock()
    mock_entry_point_name = mock.Mock()
    mock_run_name = mock.Mock()
    expected_tags = {
        mlflow_tags.MLFLOW_SOURCE_NAME: mock_source_name,
        mlflow_tags.MLFLOW_SOURCE_TYPE: mock_source_type,
        mlflow_tags.MLFLOW_GIT_COMMIT: mock_source_version,
        mlflow_tags.MLFLOW_PROJECT_ENTRY_POINT: mock_entry_point_name,
        mlflow_tags.MLFLOW_DATABRICKS_NOTEBOOK_ID: mock_notebook_id,
        mlflow_tags.MLFLOW_DATABRICKS_NOTEBOOK_PATH: mock_notebook_path,
        mlflow_tags.MLFLOW_DATABRICKS_WEBAPP_URL: mock_webapp_url,
    }
    with databricks_notebook_patch, create_run_patch, notebook_id_patch, \
            notebook_path_patch, webapp_url_patch:
        active_run = start_run(
            experiment_id=mock_experiment_id,
            source_name=mock_source_name,
            source_version=mock_source_version,
            entry_point_name=mock_entry_point_name,
            source_type=mock_source_type,
            run_name=mock_run_name,
        )
        MlflowClient.create_run.assert_called_once_with(
            experiment_id=mock_experiment_id,
            run_name=mock_run_name,
            source_name=mock_source_name,
            source_version=mock_source_version,
            entry_point_name=mock_entry_point_name,
            source_type=mock_source_type,
            tags=expected_tags,
            parent_run_id=None,
        )
        assert is_from_run(active_run, MlflowClient.create_run.return_value)
def test_log_artifact():
    """log_artifact/log_artifacts should copy files into the run's artifact dir.

    Covers both the default artifact root and a user-specified parent directory.
    """
    artifact_src_dir = tempfile.mkdtemp()
    # Create two small artifact files with distinct contents.
    _, path0 = tempfile.mkstemp(dir=artifact_src_dir)
    _, path1 = tempfile.mkstemp(dir=artifact_src_dir)
    for i, path in enumerate([path0, path1]):
        with open(path, "w") as handle:
            handle.write("%s" % str(i))
    # Log an artifact, verify it exists in the directory returned by get_artifact_uri
    # after the run finishes
    artifact_parent_dirs = ["some_parent_dir", None]
    for parent_dir in artifact_parent_dirs:
        with start_run():
            artifact_uri = mlflow.get_artifact_uri()
            run_artifact_dir = local_file_uri_to_path(artifact_uri)
            mlflow.log_artifact(path0, parent_dir)
            expected_dir = (
                os.path.join(run_artifact_dir, parent_dir)
                if parent_dir is not None
                else run_artifact_dir
            )
            assert os.listdir(expected_dir) == [os.path.basename(path0)]
            # BUGFIX: os.path.join(expected_dir, path0) with the absolute path0
            # discarded expected_dir entirely, so filecmp compared the source
            # file with itself. Join with the basename to check the logged copy.
            logged_artifact_path = os.path.join(
                expected_dir, os.path.basename(path0)
            )
            assert filecmp.cmp(logged_artifact_path, path0, shallow=False)
    # Log multiple artifacts, verify they exist in the directory returned by get_artifact_uri
    for parent_dir in artifact_parent_dirs:
        with start_run():
            artifact_uri = mlflow.get_artifact_uri()
            run_artifact_dir = local_file_uri_to_path(artifact_uri)
            mlflow.log_artifacts(artifact_src_dir, parent_dir)
            # Check that the logged artifacts match
            expected_artifact_output_dir = (
                os.path.join(run_artifact_dir, parent_dir)
                if parent_dir is not None
                else run_artifact_dir
            )
            dir_comparison = filecmp.dircmp(
                artifact_src_dir, expected_artifact_output_dir
            )
            assert len(dir_comparison.left_only) == 0
            assert len(dir_comparison.right_only) == 0
            assert len(dir_comparison.diff_files) == 0
            assert len(dir_comparison.funny_files) == 0
def test_start_run_creates_new_run_with_user_specified_tags():
    """User-supplied tags should be merged with context tags on run creation,
    with user values (e.g. MLFLOW_USER) taking precedence."""
    mock_experiment_id = mock.Mock()
    experiment_id_patch = mock.patch(
        "mlflow.tracking.fluent._get_experiment_id", return_value=mock_experiment_id
    )
    databricks_notebook_patch = mock.patch(
        "mlflow.tracking.fluent.is_in_databricks_notebook", return_value=False
    )
    mock_user = mock.Mock()
    user_patch = mock.patch(
        "mlflow.tracking.context.default_context._get_user", return_value=mock_user
    )
    mock_source_name = mock.Mock()
    source_name_patch = mock.patch(
        "mlflow.tracking.context.default_context._get_source_name",
        return_value=mock_source_name,
    )
    source_type_patch = mock.patch(
        "mlflow.tracking.context.default_context._get_source_type",
        return_value=SourceType.NOTEBOOK,
    )
    mock_source_version = mock.Mock()
    source_version_patch = mock.patch(
        "mlflow.tracking.context.git_context._get_source_version",
        return_value=mock_source_version,
    )
    user_specified_tags = {
        "ml_task": "regression",
        "num_layers": 7,
        mlflow_tags.MLFLOW_USER: "******",
    }
    expected_tags = {
        mlflow_tags.MLFLOW_SOURCE_NAME: mock_source_name,
        mlflow_tags.MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.NOTEBOOK),
        mlflow_tags.MLFLOW_GIT_COMMIT: mock_source_version,
        mlflow_tags.MLFLOW_USER: "******",
        "ml_task": "regression",
        "num_layers": 7,
    }
    create_run_patch = mock.patch.object(MlflowClient, "create_run")
    with multi_context(
        experiment_id_patch,
        databricks_notebook_patch,
        user_patch,
        source_name_patch,
        source_type_patch,
        source_version_patch,
        create_run_patch,
    ):
        active_run = start_run(tags=user_specified_tags)
        MlflowClient.create_run.assert_called_once_with(
            experiment_id=mock_experiment_id, tags=expected_tags
        )
        assert is_from_run(active_run, MlflowClient.create_run.return_value)
def test_start_and_end_run(tracking_uri_mock):
    """Exercise start_run()/end_run() without a `with` block and check the metric."""
    run = start_run()
    mlflow.log_metric("name_1", 25)
    end_run()
    stored = tracking.MlflowClient().get_run(run.info.run_uuid)
    # Exactly one metric should have been recorded, with the value we logged.
    assert len(stored.data.metrics) == 1
    expected_pairs = {"name_1": 25}
    for metric in stored.data.metrics:
        assert expected_pairs[metric.key] == metric.value
def test_start_run_existing_run_from_environment_with_set_environment(empty_active_run_stack):
    """Calling set_experiment while MLFLOW_RUN_ID points at an existing run must fail."""
    mock_run = mock.Mock()
    mock_run.info.lifecycle_stage = LifecycleStage.ACTIVE
    run_id = uuid.uuid4().hex
    env_patch = mock.patch.dict("os.environ", {_RUN_ID_ENV_VAR: run_id})
    with env_patch, mock.patch.object(MlflowClient, "get_run", return_value=mock_run):
        with pytest.raises(MlflowException):
            set_experiment("test-run")
            active_run = start_run()
def test_log_metric_validation():
    """A non-numeric metric value should be rejected and never persisted."""
    try:
        tracking.set_tracking_uri(tempfile.mkdtemp())
        run = start_run()
        run_uuid = run.info.run_uuid
        with run:
            mlflow.log_metric("name_1", "apple")
        # The invalid metric must not appear on the finished run.
        stored = tracking.MlflowClient().get_run(run_uuid)
        assert len(stored.data.metrics) == 0
    finally:
        # Restore the default tracking URI regardless of outcome.
        tracking.set_tracking_uri(None)
def test_log_params(tracking_uri_mock):
    """log_params should persist every entry of the supplied dict."""
    expected_params = {"name_1": "c", "name_2": "b", "nested/nested/name": "5"}
    run = start_run()
    run_uuid = run.info.run_uuid
    with run:
        mlflow.log_params(expected_params)
    stored = tracking.MlflowClient().get_run(run_uuid)
    # Every param we logged should come back with its original value.
    assert len(stored.data.params) == 3
    for param in stored.data.params:
        assert expected_params[param.key] == param.value
def test_start_run_existing_run(empty_active_run_stack):
    """start_run(run_id) should resume an existing active run via get_run."""
    mock_run = mock.Mock()
    mock_run.info.lifecycle_stage = LifecycleStage.ACTIVE
    run_id = uuid.uuid4().hex
    with mock.patch.object(MlflowClient, "get_run", return_value=mock_run):
        resumed = start_run(run_id)
        # The returned active run wraps the fetched run, looked up exactly once.
        assert is_from_run(resumed, mock_run)
        MlflowClient.get_run.assert_called_once_with(run_id)
def test_start_run_existing_run_from_environment(empty_active_run_stack):
    """start_run() with MLFLOW_RUN_ID set should resume that run."""
    mock_run = mock.Mock()
    mock_run.info.lifecycle_stage = LifecycleStage.ACTIVE
    run_id = uuid.uuid4().hex
    env_patch = mock.patch.dict("os.environ", {_RUN_ID_ENV_VAR: run_id})
    with env_patch, mock.patch.object(MlflowClient, "get_run", return_value=mock_run):
        resumed = start_run()
        # The environment-provided run id drives the single get_run lookup.
        assert is_from_run(resumed, mock_run)
        MlflowClient.get_run.assert_called_once_with(run_id)
def test_log_params(tracking_uri_mock):
    """log_params should persist the whole dict on the finished run."""
    expected_params = {"name_1": "c", "name_2": "b", "nested/nested/name": "5"}
    with start_run() as run:
        run_uuid = run.info.run_uuid
        mlflow.log_params(expected_params)
    stored = tracking.MlflowClient().get_run(run_uuid)
    # The stored params dict should match exactly what was logged.
    assert stored.data.params == {
        "name_1": "c",
        "name_2": "b",
        "nested/nested/name": "5",
    }
def test_set_experiment(tracking_uri_mock, reset_active_experiment):
    """set_experiment should validate its argument and route new runs correctly."""
    # Invalid arguments: missing, None, and empty string must all raise.
    with pytest.raises(TypeError):
        mlflow.set_experiment()
    with pytest.raises(Exception):
        mlflow.set_experiment(None)
    with pytest.raises(Exception):
        mlflow.set_experiment("")
    # Runs started after set_experiment land in that experiment.
    name = "random_exp"
    exp_id = mlflow.create_experiment(name)
    mlflow.set_experiment(name)
    with start_run() as run:
        assert run.info.experiment_id == exp_id
    # Setting an unknown name implicitly creates a new experiment.
    another_name = "another_experiment"
    mlflow.set_experiment(another_name)
    exp_id2 = mlflow.tracking.MlflowClient().get_experiment_by_name(another_name)
    with start_run() as another_run:
        assert another_run.info.experiment_id == exp_id2.experiment_id
def test_set_experiment():
    """set_experiment should reject bad input and direct new runs appropriately."""
    with pytest.raises(TypeError):
        mlflow.set_experiment()  # pylint: disable=no-value-for-parameter
    with pytest.raises(Exception):
        mlflow.set_experiment(None)
    with pytest.raises(Exception):
        mlflow.set_experiment("")
    # A run started after set_experiment is attached to that experiment.
    name = "random_exp"
    exp_id = mlflow.create_experiment(name)
    mlflow.set_experiment(name)
    with start_run() as run:
        assert run.info.experiment_id == exp_id
    # Switching to a fresh name creates the experiment and routes runs to it.
    another_name = "another_experiment"
    mlflow.set_experiment(another_name)
    exp_id2 = mlflow.tracking.MlflowClient().get_experiment_by_name(another_name)
    with start_run() as another_run:
        assert another_run.info.experiment_id == exp_id2.experiment_id
def test_start_run_existing_run(empty_active_run_stack):  # pylint: disable=unused-argument
    """Resuming a run by id should fetch it through the client with the store mocked."""
    mock_run = mock.Mock()
    mock_run.info.lifecycle_stage = LifecycleStage.ACTIVE
    run_id = uuid.uuid4().hex
    mock_get_store = mock.patch("mlflow.tracking.fluent._get_store")
    with mock_get_store, mock.patch.object(MlflowClient, "get_run", return_value=mock_run):
        resumed = start_run(run_id)
        assert is_from_run(resumed, mock_run)
        MlflowClient.get_run.assert_called_with(run_id)
def test_log_params_duplicate_keys_raises():
    """Re-logging an existing param key must raise and leave stored params intact."""
    params = {"a": "1", "b": "2"}
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.log_params(params)
        with pytest.raises(
            expected_exception=MlflowException,
            match=r"Changing param values is not allowed. Param with key=",
        ) as e:
            mlflow.log_param("a", "3")
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    # The failed update must not have modified the originally logged params.
    stored = tracking.MlflowClient().get_run(run_id)
    assert stored.data.params == params
def test_log_metrics(tracking_uri_mock):
    """log_metrics should record all values with a single shared timestamp."""
    run = start_run()
    run_uuid = run.info.run_uuid
    expected_metrics = {"name_1": 30, "name_2": -3, "nested/nested/name": 40}
    with run:
        mlflow.log_metrics(expected_metrics)
    stored = tracking.MlflowClient().get_run(run_uuid)
    # All metrics logged in one batch should carry the same timestamp.
    common_timestamp = stored.data.metrics[0].timestamp
    assert len(stored.data.metrics) == len(expected_metrics)
    for metric in stored.data.metrics:
        assert expected_metrics[metric.key] == metric.value
        assert metric.timestamp == common_timestamp
def test_start_run_defaults_databricks_notebook(empty_active_run_stack):
    """Inside a Databricks notebook, start_run() should derive source info and
    notebook tags from the notebook context."""
    mock_experiment_id = mock.Mock()
    experiment_id_patch = mock.patch(
        "mlflow.tracking.fluent._get_experiment_id", return_value=mock_experiment_id
    )
    databricks_notebook_patch = mock.patch(
        "mlflow.utils.databricks_utils.is_in_databricks_notebook", return_value=True
    )
    mock_source_version = mock.Mock()
    source_version_patch = mock.patch(
        "mlflow.tracking.context._get_source_version", return_value=mock_source_version
    )
    mock_notebook_id = mock.Mock()
    notebook_id_patch = mock.patch(
        "mlflow.utils.databricks_utils.get_notebook_id", return_value=mock_notebook_id
    )
    mock_notebook_path = mock.Mock()
    notebook_path_patch = mock.patch(
        "mlflow.utils.databricks_utils.get_notebook_path", return_value=mock_notebook_path
    )
    mock_webapp_url = mock.Mock()
    webapp_url_patch = mock.patch(
        "mlflow.utils.databricks_utils.get_webapp_url", return_value=mock_webapp_url
    )
    # Notebook path doubles as the source name; notebook metadata becomes tags.
    expected_tags = {
        mlflow_tags.MLFLOW_SOURCE_NAME: mock_notebook_path,
        mlflow_tags.MLFLOW_SOURCE_TYPE: "NOTEBOOK",
        mlflow_tags.MLFLOW_GIT_COMMIT: mock_source_version,
        mlflow_tags.MLFLOW_DATABRICKS_NOTEBOOK_ID: mock_notebook_id,
        mlflow_tags.MLFLOW_DATABRICKS_NOTEBOOK_PATH: mock_notebook_path,
        mlflow_tags.MLFLOW_DATABRICKS_WEBAPP_URL: mock_webapp_url,
    }
    create_run_patch = mock.patch.object(MlflowClient, "create_run")
    with experiment_id_patch, databricks_notebook_patch, source_version_patch, \
            notebook_id_patch, notebook_path_patch, webapp_url_patch, create_run_patch:
        active_run = start_run()
        MlflowClient.create_run.assert_called_once_with(
            experiment_id=mock_experiment_id,
            run_name=None,
            source_name=mock_notebook_path,
            source_version=mock_source_version,
            entry_point_name=None,
            source_type=SourceType.NOTEBOOK,
            tags=expected_tags,
            parent_run_id=None,
        )
        assert is_from_run(active_run, MlflowClient.create_run.return_value)
def test_log_batch(tracking_uri_mock, tmpdir):
    """log_batch should persist metrics, params, and tags, and accept partial batches."""
    expected_metrics = {"metric-key0": 1.0, "metric-key1": 4.0}
    expected_params = {"param-key0": "param-val0", "param-key1": "param-val1"}
    exact_expected_tags = {"tag-key0": "tag-val0", "tag-key1": "tag-val1"}
    approx_expected_tags = set([MLFLOW_USER, MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE])
    t = int(time.time())
    sorted_expected_metrics = sorted(expected_metrics.items(), key=lambda kv: kv[0])
    metrics = [
        Metric(key=key, value=value, timestamp=t, step=i)
        for i, (key, value) in enumerate(sorted_expected_metrics)
    ]
    params = [Param(key=key, value=value) for key, value in expected_params.items()]
    tags = [RunTag(key=key, value=value) for key, value in exact_expected_tags.items()]
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.tracking.MlflowClient().log_batch(
            run_id=run_id, metrics=metrics, params=params, tags=tags
        )
    client = tracking.MlflowClient()
    finished_run = client.get_run(run_id)
    # All batched metrics should be present with the logged values.
    assert len(finished_run.data.metrics) == 2
    for key, value in finished_run.data.metrics.items():
        assert expected_metrics[key] == value
    history0 = client.get_metric_history(run_id, "metric-key0")
    assert set([(m.value, m.timestamp, m.step) for m in history0]) == set([(1.0, t, 0)])
    history1 = client.get_metric_history(run_id, "metric-key1")
    assert set([(m.value, m.timestamp, m.step) for m in history1]) == set([(4.0, t, 1)])
    # Tags: user tags plus automatically-set context tags.
    assert len(finished_run.data.tags) == len(exact_expected_tags) + len(
        approx_expected_tags
    )
    for tag_key, tag_value in finished_run.data.tags.items():
        if tag_key in approx_expected_tags:
            pass
        else:
            assert exact_expected_tags[tag_key] == tag_value
    # Validate params
    assert finished_run.data.params == expected_params
    # test that log_batch works with fewer params
    new_tags = {"1": "2", "3": "4", "5": "6"}
    tags = [RunTag(key=key, value=value) for key, value in new_tags.items()]
    client.log_batch(run_id=run_id, tags=tags)
    finished_run_2 = client.get_run(run_id)
    # The tag-only batch should add exactly the three new tags.
    assert len(finished_run_2.data.tags) == len(finished_run.data.tags) + 3
    for tag_key, tag_value in finished_run_2.data.tags.items():
        if tag_key in new_tags:
            assert new_tags[tag_key] == tag_value
def test_log_batch_validates_entity_names_and_values():
    """log_batch must reject bad names, values, and timestamps with
    INVALID_PARAMETER_VALUE errors."""
    with start_run() as active_run:
        run_id = active_run.info.run_id
        # Metric name containing path traversal is invalid.
        bad_name_metrics = [Metric(key="../bad/metric/name", value=0.3, timestamp=3, step=0)]
        with pytest.raises(MlflowException, match="Invalid metric name") as e:
            tracking.MlflowClient().log_batch(run_id, metrics=bad_name_metrics)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
        # Metric value must be numeric.
        bad_value_metrics = [
            Metric(key="ok-name", value="non-numerical-value", timestamp=3, step=0)
        ]
        with pytest.raises(MlflowException, match="Got invalid value") as e:
            tracking.MlflowClient().log_batch(run_id, metrics=bad_value_metrics)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
        # Metric timestamp must be numeric.
        bad_ts_metrics = [
            Metric(key="ok-name", value=0.3, timestamp="non-numerical-timestamp", step=0)
        ]
        with pytest.raises(MlflowException, match="Got invalid timestamp") as e:
            tracking.MlflowClient().log_batch(run_id, metrics=bad_ts_metrics)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
        # Param and tag names are validated the same way.
        bad_params = [Param(key="../bad/param/name", value="my-val")]
        with pytest.raises(MlflowException, match="Invalid parameter name") as e:
            tracking.MlflowClient().log_batch(run_id, params=bad_params)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
        bad_tags = [Param(key="../bad/tag/name", value="my-val")]
        with pytest.raises(MlflowException, match="Invalid tag name") as e:
            tracking.MlflowClient().log_batch(run_id, tags=bad_tags)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
        # A None metric key is rejected with an explicit message.
        none_key_metrics = [Metric(key=None, value=42.0, timestamp=4, step=1)]
        with pytest.raises(
            MlflowException, match="Metric name cannot be None. A key name must be provided."
        ) as e:
            tracking.MlflowClient().log_batch(run_id, metrics=none_key_metrics)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
def test_start_run_context_manager():
    """Verify persistence and terminal statuses using a temporary tracking URI."""
    try:
        tracking.set_tracking_uri(tempfile.mkdtemp())
        first_run = start_run()
        first_uuid = first_run.info.run_uuid
        with first_run:
            # start_run() should already have persisted the run in the store.
            persisted_run = tracking.MlflowClient().get_run(first_uuid)
            assert persisted_run is not None
            assert persisted_run.info == first_run.info
        finished_run = tracking.MlflowClient().get_run(first_uuid)
        assert finished_run.info.status == RunStatus.FINISHED
        # A second run that raises should be FAILED and have a distinct UUID.
        second_run = start_run()
        assert second_run.info.run_uuid != first_uuid
        with pytest.raises(Exception):
            with second_run:
                raise Exception("Failing run!")
        finished_run2 = tracking.MlflowClient().get_run(second_run.info.run_uuid)
        assert finished_run2.info.status == RunStatus.FAILED
    finally:
        # Restore the default tracking URI regardless of outcome.
        tracking.set_tracking_uri(None)
def test_start_run_defaults_databricks_notebook(
    empty_active_run_stack,
):  # pylint: disable=unused-argument
    """In a Databricks notebook, start_run() should attach user, source, git,
    and notebook context tags derived from the notebook environment."""
    mock_experiment_id = mock.Mock()
    experiment_id_patch = mock.patch(
        "mlflow.tracking.fluent._get_experiment_id", return_value=mock_experiment_id
    )
    databricks_notebook_patch = mock.patch(
        "mlflow.utils.databricks_utils.is_in_databricks_notebook", return_value=True
    )
    mock_user = mock.Mock()
    user_patch = mock.patch(
        "mlflow.tracking.context.default_context._get_user", return_value=mock_user
    )
    mock_source_version = mock.Mock()
    source_version_patch = mock.patch(
        "mlflow.tracking.context.git_context._get_source_version",
        return_value=mock_source_version,
    )
    mock_notebook_id = mock.Mock()
    notebook_id_patch = mock.patch(
        "mlflow.utils.databricks_utils.get_notebook_id", return_value=mock_notebook_id
    )
    mock_notebook_path = mock.Mock()
    notebook_path_patch = mock.patch(
        "mlflow.utils.databricks_utils.get_notebook_path", return_value=mock_notebook_path
    )
    mock_webapp_url = mock.Mock()
    webapp_url_patch = mock.patch(
        "mlflow.utils.databricks_utils.get_webapp_url", return_value=mock_webapp_url
    )
    expected_tags = {
        mlflow_tags.MLFLOW_USER: mock_user,
        mlflow_tags.MLFLOW_SOURCE_NAME: mock_notebook_path,
        mlflow_tags.MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.NOTEBOOK),
        mlflow_tags.MLFLOW_GIT_COMMIT: mock_source_version,
        mlflow_tags.MLFLOW_DATABRICKS_NOTEBOOK_ID: mock_notebook_id,
        mlflow_tags.MLFLOW_DATABRICKS_NOTEBOOK_PATH: mock_notebook_path,
        mlflow_tags.MLFLOW_DATABRICKS_WEBAPP_URL: mock_webapp_url,
    }
    create_run_patch = mock.patch.object(MlflowClient, "create_run")
    with experiment_id_patch, databricks_notebook_patch, user_patch, source_version_patch, notebook_id_patch, notebook_path_patch, webapp_url_patch, create_run_patch:  # noqa
        active_run = start_run()
        MlflowClient.create_run.assert_called_once_with(
            experiment_id=mock_experiment_id, tags=expected_tags
        )
        assert is_from_run(active_run, MlflowClient.create_run.return_value)
def test_log_metric(tracking_uri_mock):
    """log_metric should keep the latest value per key, including nested keys."""
    run = start_run()
    run_uuid = run.info.run_uuid
    with run:
        mlflow.log_metric("name_1", 25)
        mlflow.log_metric("name_2", -3)
        mlflow.log_metric("name_1", 30)  # overwrites the earlier name_1 value
        mlflow.log_metric("nested/nested/name", 40)
    stored = tracking.MlflowClient().get_run(run_uuid)
    # Three distinct keys; name_1 reflects the most recent value.
    assert len(stored.data.metrics) == 3
    expected_pairs = {"name_1": 30, "name_2": -3, "nested/nested/name": 40}
    for metric in stored.data.metrics:
        assert expected_pairs[metric.key] == metric.value
def test_set_tags():
    """set_tags should stringify values and coexist with auto-set context tags."""
    exact_expected_tags = {"name_1": "c", "name_2": "b", "nested/nested/name": 5}
    approx_expected_tags = set([MLFLOW_USER, MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE])
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.set_tags(exact_expected_tags)
    stored = tracking.MlflowClient().get_run(run_id)
    # Stored tags = user tags plus automatically-set context tags.
    assert len(stored.data.tags) == len(exact_expected_tags) + len(approx_expected_tags)
    for tag_key, tag_val in stored.data.tags.items():
        if tag_key in approx_expected_tags:
            pass
        else:
            # Non-string values (the int 5) are persisted as strings.
            assert str(exact_expected_tags[tag_key]) == tag_val
def test_start_run_overrides_databricks_notebook(empty_active_run_stack):
    """Inside a Databricks notebook, caller-supplied source_name/source_type are
    ignored in favor of the notebook path and NOTEBOOK type, while notebook
    context tags are attached."""
    databricks_notebook_patch = mock.patch(
        "mlflow.tracking.fluent.is_in_databricks_notebook", return_value=True
    )
    mock_notebook_id = mock.Mock()
    notebook_id_patch = mock.patch(
        "mlflow.tracking.fluent.get_notebook_id", return_value=mock_notebook_id
    )
    mock_notebook_path = mock.Mock()
    notebook_path_patch = mock.patch(
        "mlflow.tracking.fluent.get_notebook_path", return_value=mock_notebook_path
    )
    mock_webapp_url = mock.Mock()
    webapp_url_patch = mock.patch(
        "mlflow.tracking.fluent.get_webapp_url", return_value=mock_webapp_url
    )
    expected_tags = {
        mlflow_tags.MLFLOW_DATABRICKS_NOTEBOOK_ID: mock_notebook_id,
        mlflow_tags.MLFLOW_DATABRICKS_NOTEBOOK_PATH: mock_notebook_path,
        mlflow_tags.MLFLOW_DATABRICKS_WEBAPP_URL: mock_webapp_url,
    }
    create_run_patch = mock.patch.object(MlflowClient, "create_run")
    mock_experiment_id = mock.Mock()
    mock_source_version = mock.Mock()
    mock_entry_point_name = mock.Mock()
    mock_run_name = mock.Mock()
    with databricks_notebook_patch, create_run_patch, notebook_id_patch, \
            notebook_path_patch, webapp_url_patch:
        active_run = start_run(
            experiment_id=mock_experiment_id,
            source_name="ignored",
            source_version=mock_source_version,
            entry_point_name=mock_entry_point_name,
            source_type="ignored",
            run_name=mock_run_name,
        )
        MlflowClient.create_run.assert_called_once_with(
            experiment_id=mock_experiment_id,
            run_name=mock_run_name,
            source_name=mock_notebook_path,
            source_version=mock_source_version,
            entry_point_name=mock_entry_point_name,
            source_type=SourceType.NOTEBOOK,
            tags=expected_tags,
            parent_run_id=None,
        )
        assert is_from_run(active_run, MlflowClient.create_run.return_value)