예제 #1
0
def test_start_run_with_parent():
    """Check that a nested ``start_run`` records the parent run id as a tag.

    The active run stack is patched to contain one mock parent run, and the
    default-context user / source-name lookups are patched so that the exact
    tag dict passed to ``MlflowClient.create_run`` is known up front.
    """
    experiment_id = "123456"
    parent = mock.Mock()
    user = mock.Mock()
    source_name = mock.Mock()

    # Order matches the original call: run stack, create_run, user, source name.
    patches = (
        mock.patch("mlflow.tracking.fluent._active_run_stack", [parent]),
        mock.patch.object(MlflowClient, "create_run"),
        mock.patch(
            "mlflow.tracking.context.default_context._get_user",
            return_value=user,
        ),
        mock.patch(
            "mlflow.tracking.context.default_context._get_source_name",
            return_value=source_name,
        ),
    )

    tags_expected = {
        mlflow_tags.MLFLOW_USER: user,
        mlflow_tags.MLFLOW_SOURCE_NAME: source_name,
        mlflow_tags.MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.LOCAL),
        mlflow_tags.MLFLOW_PARENT_RUN_ID: parent.info.run_id,
    }

    with multi_context(*patches):
        run = start_run(experiment_id=experiment_id, nested=True)
        MlflowClient.create_run.assert_called_once_with(
            experiment_id=experiment_id, tags=tags_expected)
        created = MlflowClient.create_run.return_value
        assert is_from_run(run, created)
예제 #2
0
def test_databricks_repo_run_context_tags_nones():
    """Check the repo run context yields no tags when every lookup is ``None``.

    Each ``get_git_repo_*`` helper in ``mlflow.utils.databricks_utils`` is
    patched to return ``None``; ``DatabricksRepoRunContext().tags()`` must
    then equal an empty dict.
    """
    lookup_targets = (
        "mlflow.utils.databricks_utils.get_git_repo_url",
        "mlflow.utils.databricks_utils.get_git_repo_provider",
        "mlflow.utils.databricks_utils.get_git_repo_commit",
        "mlflow.utils.databricks_utils.get_git_repo_relative_path",
        "mlflow.utils.databricks_utils.get_git_repo_reference",
        "mlflow.utils.databricks_utils.get_git_repo_reference_type",
        "mlflow.utils.databricks_utils.get_git_repo_status",
    )
    none_patches = [
        mock.patch(target, return_value=None) for target in lookup_targets
    ]

    with multi_context(*none_patches):
        observed = DatabricksRepoRunContext().tags()
        assert observed == {}
def test_databricks_job_default_experiment_id():
    """Check the job experiment provider creates an experiment for the job.

    With the job id, job type info and experiment-name helpers patched, the
    provider must return the id produced by the patched
    ``MlflowClient.create_experiment`` and must have called it once with the
    job-derived experiment name, ``None``, and the expected tag dict.
    """
    job_identifier = "job_id"
    expected_experiment_id = "experiment_id"
    experiment_name = "jobs:/" + str(job_identifier)

    patches = (
        mock.patch(
            "mlflow.utils.databricks_utils.get_job_id",
            return_value=job_identifier,
        ),
        mock.patch(
            "mlflow.utils.databricks_utils.get_job_type_info",
            return_value="NORMAL",
        ),
        mock.patch(
            "mlflow.utils.databricks_utils.get_experiment_name_from_job_id",
            return_value=experiment_name,
        ),
        mock.patch.object(
            MlflowClient,
            "create_experiment",
            return_value=expected_experiment_id,
        ),
    )

    with multi_context(*patches) as (
            get_job_id,
            get_job_type_info,
            get_experiment_name,
            create_experiment_call,
    ):
        # Same keys, in the same insertion order, as the original test.
        expected_tags = {
            MLFLOW_DATABRICKS_JOB_TYPE_INFO: get_job_type_info.return_value,
            MLFLOW_EXPERIMENT_SOURCE_TYPE: SourceType.to_string(SourceType.JOB),
            MLFLOW_EXPERIMENT_SOURCE_ID: get_job_id.return_value,
        }

        resolved_id = DatabricksJobExperimentProvider().get_experiment_id()
        assert resolved_id == expected_experiment_id
        create_experiment_call.assert_called_once_with(
            get_experiment_name.return_value, None, expected_tags)
def test_databricks_notebook_run_context_tags():
    """Check the notebook run context tags, including workspace URL and id.

    The notebook id / path and webapp URL helpers are replaced with plain
    mocks, and ``get_workspace_info_from_dbutils`` is patched to return a
    fixed ``(url, id)`` pair; the resulting tags must mirror those values.
    """
    patches = (
        mock.patch("mlflow.utils.databricks_utils.get_notebook_id"),
        mock.patch("mlflow.utils.databricks_utils.get_notebook_path"),
        mock.patch("mlflow.utils.databricks_utils.get_webapp_url"),
        mock.patch(
            "mlflow.utils.databricks_utils.get_workspace_info_from_dbutils",
            return_value=("https://databricks.com", "123456"),
        ),
    )

    with multi_context(*patches) as (
            get_notebook_id,
            get_notebook_path,
            get_webapp_url,
            get_workspace_info,
    ):
        notebook_path = get_notebook_path.return_value
        workspace_info = get_workspace_info.return_value
        expected = {
            MLFLOW_SOURCE_NAME: notebook_path,
            MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.NOTEBOOK),
            MLFLOW_DATABRICKS_NOTEBOOK_ID: get_notebook_id.return_value,
            MLFLOW_DATABRICKS_NOTEBOOK_PATH: notebook_path,
            MLFLOW_DATABRICKS_WEBAPP_URL: get_webapp_url.return_value,
            MLFLOW_DATABRICKS_WORKSPACE_URL: workspace_info[0],
            MLFLOW_DATABRICKS_WORKSPACE_ID: workspace_info[1],
        }
        assert DatabricksNotebookRunContext().tags() == expected
예제 #5
0
def test_databricks_job_run_context_tags():
    """Check the job run context tags built from the Databricks job helpers.

    The job id, job run id, job type and webapp URL helpers are patched with
    plain mocks; the tags must carry their return values, with the source
    name formatted as ``jobs/<job_id>/run/<job_run_id>``.
    """
    patches = (
        mock.patch("mlflow.utils.databricks_utils.get_job_id"),
        mock.patch("mlflow.utils.databricks_utils.get_job_run_id"),
        mock.patch("mlflow.utils.databricks_utils.get_job_type"),
        mock.patch("mlflow.utils.databricks_utils.get_webapp_url"),
    )

    with multi_context(*patches) as (
            get_job_id,
            get_job_run_id,
            get_job_type,
            get_webapp_url,
    ):
        job_id = get_job_id.return_value
        job_run_id = get_job_run_id.return_value
        source_name = "jobs/{job_id}/run/{job_run_id}".format(
            job_id=job_id, job_run_id=job_run_id)

        expected = {
            MLFLOW_SOURCE_NAME: source_name,
            MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.JOB),
            MLFLOW_DATABRICKS_JOB_ID: job_id,
            MLFLOW_DATABRICKS_JOB_RUN_ID: job_run_id,
            MLFLOW_DATABRICKS_JOB_TYPE: get_job_type.return_value,
            MLFLOW_DATABRICKS_WEBAPP_URL: get_webapp_url.return_value,
        }
        assert DatabricksJobRunContext().tags() == expected
예제 #6
0
def test_start_run_defaults_databricks_notebook(empty_active_run_stack):  # pylint: disable=unused-argument
    """Check the default tags ``start_run`` uses when in a Databricks notebook.

    ``is_in_databricks_notebook`` is patched to return ``True`` and the user,
    git commit, notebook id / path and webapp URL lookups are patched with
    mocks; ``MlflowClient.create_run`` must be called once with the patched
    experiment id and a tag dict built from those mocks.
    """
    experiment_id = mock.Mock()
    user = mock.Mock()
    source_version = mock.Mock()
    notebook_id = mock.Mock()
    notebook_path = mock.Mock()
    webapp_url = mock.Mock()

    # Order matches the original multi_context call.
    patches = (
        mock.patch(
            "mlflow.tracking.fluent._get_experiment_id",
            return_value=experiment_id,
        ),
        mock.patch(
            "mlflow.utils.databricks_utils.is_in_databricks_notebook",
            return_value=True,
        ),
        mock.patch(
            "mlflow.tracking.context.default_context._get_user",
            return_value=user,
        ),
        mock.patch(
            "mlflow.tracking.context.git_context._get_source_version",
            return_value=source_version,
        ),
        mock.patch(
            "mlflow.utils.databricks_utils.get_notebook_id",
            return_value=notebook_id,
        ),
        mock.patch(
            "mlflow.utils.databricks_utils.get_notebook_path",
            return_value=notebook_path,
        ),
        mock.patch(
            "mlflow.utils.databricks_utils.get_webapp_url",
            return_value=webapp_url,
        ),
        mock.patch.object(MlflowClient, "create_run"),
    )

    tags_expected = {
        mlflow_tags.MLFLOW_USER: user,
        mlflow_tags.MLFLOW_SOURCE_NAME: notebook_path,
        mlflow_tags.MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.NOTEBOOK),
        mlflow_tags.MLFLOW_GIT_COMMIT: source_version,
        mlflow_tags.MLFLOW_DATABRICKS_NOTEBOOK_ID: notebook_id,
        mlflow_tags.MLFLOW_DATABRICKS_NOTEBOOK_PATH: notebook_path,
        mlflow_tags.MLFLOW_DATABRICKS_WEBAPP_URL: webapp_url,
    }

    with multi_context(*patches):
        run = start_run()
        MlflowClient.create_run.assert_called_once_with(
            experiment_id=experiment_id, tags=tags_expected)
        created = MlflowClient.create_run.return_value
        assert is_from_run(run, created)
예제 #7
0
def test_start_run_creates_new_run_with_user_specified_tags():
    """Check that tags passed to ``start_run`` are merged into the run's tags.

    The expected dict combines the patched source name / type / git commit
    with the caller-supplied tags; the caller's ``MLFLOW_USER`` value is
    expected in place of the patched ``_get_user`` result.
    """
    experiment_id = mock.Mock()
    user = mock.Mock()
    source_name = mock.Mock()
    source_version = mock.Mock()

    # Order matches the original multi_context call.
    patches = (
        mock.patch(
            "mlflow.tracking.fluent._get_experiment_id",
            return_value=experiment_id,
        ),
        mock.patch(
            "mlflow.tracking.fluent.is_in_databricks_notebook",
            return_value=False,
        ),
        mock.patch(
            "mlflow.tracking.context.default_context._get_user",
            return_value=user,
        ),
        mock.patch(
            "mlflow.tracking.context.default_context._get_source_name",
            return_value=source_name,
        ),
        mock.patch(
            "mlflow.tracking.context.default_context._get_source_type",
            return_value=SourceType.NOTEBOOK,
        ),
        mock.patch(
            "mlflow.tracking.context.git_context._get_source_version",
            return_value=source_version,
        ),
        mock.patch.object(MlflowClient, "create_run"),
    )

    caller_tags = {
        "ml_task": "regression",
        "num_layers": 7,
        mlflow_tags.MLFLOW_USER: "******",
    }
    tags_expected = {
        mlflow_tags.MLFLOW_SOURCE_NAME: source_name,
        mlflow_tags.MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.NOTEBOOK),
        mlflow_tags.MLFLOW_GIT_COMMIT: source_version,
        mlflow_tags.MLFLOW_USER: "******",
        "ml_task": "regression",
        "num_layers": 7,
    }

    with multi_context(*patches):
        run = start_run(tags=caller_tags)
        MlflowClient.create_run.assert_called_once_with(
            experiment_id=experiment_id, tags=tags_expected
        )
        created = MlflowClient.create_run.return_value
        assert is_from_run(run, created)
예제 #8
0
def test_databricks_repo_run_context_tags():
    """Check the repo run context tags mirror the patched git-repo helpers.

    Each ``get_git_repo_*`` helper in ``mlflow.utils.databricks_utils`` is
    replaced by a plain mock; every tag must equal the return value of the
    corresponding mock.
    """
    helper_targets = (
        "mlflow.utils.databricks_utils.get_git_repo_url",
        "mlflow.utils.databricks_utils.get_git_repo_provider",
        "mlflow.utils.databricks_utils.get_git_repo_commit",
        "mlflow.utils.databricks_utils.get_git_repo_relative_path",
        "mlflow.utils.databricks_utils.get_git_repo_reference",
        "mlflow.utils.databricks_utils.get_git_repo_reference_type",
        "mlflow.utils.databricks_utils.get_git_repo_status",
    )
    helper_patches = [mock.patch(target) for target in helper_targets]

    with multi_context(*helper_patches) as (
            url_mock,
            provider_mock,
            commit_mock,
            relative_path_mock,
            reference_mock,
            reference_type_mock,
            status_mock,
    ):
        expected = {
            MLFLOW_DATABRICKS_GIT_REPO_URL: url_mock.return_value,
            MLFLOW_DATABRICKS_GIT_REPO_PROVIDER: provider_mock.return_value,
            MLFLOW_DATABRICKS_GIT_REPO_COMMIT: commit_mock.return_value,
            MLFLOW_DATABRICKS_GIT_REPO_RELATIVE_PATH:
                relative_path_mock.return_value,
            MLFLOW_DATABRICKS_GIT_REPO_REFERENCE: reference_mock.return_value,
            MLFLOW_DATABRICKS_GIT_REPO_REFERENCE_TYPE:
                reference_type_mock.return_value,
            MLFLOW_DATABRICKS_GIT_REPO_STATUS: status_mock.return_value,
        }
        assert DatabricksRepoRunContext().tags() == expected
예제 #9
0
def test_start_run_defaults(empty_active_run_stack):  # pylint: disable=unused-argument
    """Check the default tags ``start_run`` uses outside a Databricks notebook.

    ``is_in_databricks_notebook`` is patched to return ``False`` and the user,
    source name, source type and git commit lookups are patched;
    ``MlflowClient.create_run`` must be called once with the patched
    experiment id and a tag dict built from those values.
    """
    experiment_id = mock.Mock()
    user = mock.Mock()
    source_name = mock.Mock()
    source_version = mock.Mock()

    # Order matches the original multi_context call.
    patches = (
        mock.patch(
            "mlflow.tracking.fluent._get_experiment_id",
            return_value=experiment_id,
        ),
        mock.patch(
            "mlflow.tracking.fluent.is_in_databricks_notebook",
            return_value=False,
        ),
        mock.patch(
            "mlflow.tracking.context.default_context._get_user",
            return_value=user,
        ),
        mock.patch(
            "mlflow.tracking.context.default_context._get_source_name",
            return_value=source_name,
        ),
        mock.patch(
            "mlflow.tracking.context.default_context._get_source_type",
            return_value=SourceType.NOTEBOOK,
        ),
        mock.patch(
            "mlflow.tracking.context.git_context._get_source_version",
            return_value=source_version,
        ),
        mock.patch.object(MlflowClient, "create_run"),
    )

    tags_expected = {
        mlflow_tags.MLFLOW_USER: user,
        mlflow_tags.MLFLOW_SOURCE_NAME: source_name,
        mlflow_tags.MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.NOTEBOOK),
        mlflow_tags.MLFLOW_GIT_COMMIT: source_version,
    }

    with multi_context(*patches):
        run = start_run()
        MlflowClient.create_run.assert_called_once_with(
            experiment_id=experiment_id, tags=tags_expected
        )
        created = MlflowClient.create_run.return_value
        assert is_from_run(run, created)
예제 #10
0
def test_databricks_notebook_run_context_tags():
    """Check the notebook run context tags mirror the patched notebook helpers.

    The notebook id, notebook path and webapp URL helpers are replaced with
    plain mocks; the tags must carry their return values, with the notebook
    path used for both the source name and the notebook path tag.
    """
    patches = (
        mock.patch("mlflow.utils.databricks_utils.get_notebook_id"),
        mock.patch("mlflow.utils.databricks_utils.get_notebook_path"),
        mock.patch("mlflow.utils.databricks_utils.get_webapp_url"),
    )

    with multi_context(*patches) as (
        get_notebook_id,
        get_notebook_path,
        get_webapp_url,
    ):
        notebook_path = get_notebook_path.return_value
        expected = {
            MLFLOW_SOURCE_NAME: notebook_path,
            MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.NOTEBOOK),
            MLFLOW_DATABRICKS_NOTEBOOK_ID: get_notebook_id.return_value,
            MLFLOW_DATABRICKS_NOTEBOOK_PATH: notebook_path,
            MLFLOW_DATABRICKS_WEBAPP_URL: get_webapp_url.return_value,
        }
        assert DatabricksNotebookRunContext().tags() == expected
def test_databricks_job_run_context_tags():
    """Check the job run context tags, including workspace URL and id.

    The job id, job run id, job type and webapp URL helpers are patched with
    plain mocks, and ``get_workspace_info_from_dbutils`` is patched to return
    a fixed ``(url, id)`` pair; the tags must mirror those values, with the
    source name formatted as ``jobs/<job_id>/run/<job_run_id>``.
    """
    patches = (
        mock.patch("mlflow.utils.databricks_utils.get_job_id"),
        mock.patch("mlflow.utils.databricks_utils.get_job_run_id"),
        mock.patch("mlflow.utils.databricks_utils.get_job_type"),
        mock.patch("mlflow.utils.databricks_utils.get_webapp_url"),
        mock.patch(
            "mlflow.utils.databricks_utils.get_workspace_info_from_dbutils",
            return_value=("https://databricks.com", "123456"),
        ),
    )

    with multi_context(*patches) as (
            get_job_id,
            get_job_run_id,
            get_job_type,
            get_webapp_url,
            get_workspace_info,
    ):
        job_id = get_job_id.return_value
        job_run_id = get_job_run_id.return_value
        workspace_info = get_workspace_info.return_value
        source_name = "jobs/{job_id}/run/{job_run_id}".format(
            job_id=job_id, job_run_id=job_run_id)

        expected = {
            MLFLOW_SOURCE_NAME: source_name,
            MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.JOB),
            MLFLOW_DATABRICKS_JOB_ID: job_id,
            MLFLOW_DATABRICKS_JOB_RUN_ID: job_run_id,
            MLFLOW_DATABRICKS_JOB_TYPE: get_job_type.return_value,
            MLFLOW_DATABRICKS_WEBAPP_URL: get_webapp_url.return_value,
            MLFLOW_DATABRICKS_WORKSPACE_URL: workspace_info[0],
            MLFLOW_DATABRICKS_WORKSPACE_ID: workspace_info[1],
        }
        assert DatabricksJobRunContext().tags() == expected