def test_use_repl_context_if_available(tmpdir):
    """Check that databricks_utils prefers the Databricks REPL context when present.

    Scenarios covered:

    1. ``dbruntime`` is not importable: ``get_job_id`` falls back to the
       command context.
    2. A fake ``dbruntime.databricks_repl_context`` module exists, but its
       ``get_context()`` returns ``None``: ``get_job_id`` consults it and then
       still falls back to the command context.
    3. ``get_context()`` returns an object: ``get_job_id``, ``get_notebook_id``
       and ``is_in_cluster`` read their value from it and never touch the
       dbutils / Spark-context / Spark-session fallbacks.

    The fake package directory is put on ``sys.path`` only for the duration of
    scenarios 2-3. It is removed afterwards, together with the ``dbruntime``
    modules cached in ``sys.modules``, so later tests still see ``dbruntime``
    as unavailable.
    """
    # Simulate a case where `dbruntime.databricks_repl_context.get_context` is unavailable.
    with pytest.raises(ModuleNotFoundError,
                       match="No module named 'dbruntime'"):
        from dbruntime.databricks_repl_context import get_context  # pylint: disable=unused-import

    # Without a REPL context, the job ID must come from the command context.
    command_context_mock = mock.MagicMock()
    command_context_mock.jobId().get.return_value = "job_id"
    with mock.patch(
            "mlflow.utils.databricks_utils._get_command_context",
            return_value=command_context_mock) as mock_get_command_context:
        assert databricks_utils.get_job_id() == "job_id"
        mock_get_command_context.assert_called_once()

    # Create a fake databricks_repl_context module
    tmpdir.mkdir("dbruntime").join("databricks_repl_context.py").write("""
def get_context():
    pass
""")
    fake_module_root = tmpdir.strpath
    sys.path.append(fake_module_root)
    try:
        # Simulate a case where the REPL context object is not initialized.
        with mock.patch(
                "dbruntime.databricks_repl_context.get_context",
                return_value=None,
        ) as mock_get_context, mock.patch(
                "mlflow.utils.databricks_utils._get_command_context",
                return_value=command_context_mock) as mock_get_command_context:
            assert databricks_utils.get_job_id() == "job_id"
            # The REPL context must be consulted before falling back.
            mock_get_context.assert_called_once()
            mock_get_command_context.assert_called_once()

        # From here on the REPL context is available; values must come from it
        # and the legacy lookups must not be used.
        with mock.patch(
                "dbruntime.databricks_repl_context.get_context",
                return_value=mock.MagicMock(jobId="job_id"),
        ) as mock_get_context, mock.patch(
                "mlflow.utils.databricks_utils._get_dbutils") as mock_dbutils:
            assert databricks_utils.get_job_id() == "job_id"
            mock_get_context.assert_called_once()
            mock_dbutils.assert_not_called()

        with mock.patch(
                "dbruntime.databricks_repl_context.get_context",
                return_value=mock.MagicMock(notebookId="notebook_id"),
        ) as mock_get_context, mock.patch(
                "mlflow.utils.databricks_utils._get_property_from_spark_context"
        ) as mock_spark_context:
            assert databricks_utils.get_notebook_id() == "notebook_id"
            mock_get_context.assert_called_once()
            mock_spark_context.assert_not_called()

        with mock.patch(
                "dbruntime.databricks_repl_context.get_context",
                return_value=mock.MagicMock(isInCluster=True),
        ) as mock_get_context, mock.patch(
                "mlflow.utils._spark_utils._get_active_spark_session"
        ) as mock_spark_session:
            assert databricks_utils.is_in_cluster()
            mock_get_context.assert_called_once()
            mock_spark_session.assert_not_called()
    finally:
        # Undo the global import-state changes. Otherwise the fake package stays
        # importable (and cached in sys.modules, since mock.patch imported it)
        # for the rest of the session, leaking into other tests.
        sys.path.remove(fake_module_root)
        for module_name in ("dbruntime.databricks_repl_context", "dbruntime"):
            sys.modules.pop(module_name, None)
    def get_experiment_id(self):
        """Return the ID of the experiment associated with the current Databricks job.

        The first resolved ID is stored in the module-level
        ``_resolved_job_experiment_id``. Later calls return that cached value
        directly.

        Resolution calls ``MlflowClient().create_experiment``. It passes the
        experiment name derived from the job ID, no artifact location, and tags
        that mark the job as the experiment's source.
        """
        global _resolved_job_experiment_id
        cached_experiment_id = _resolved_job_experiment_id
        if cached_experiment_id:
            return cached_experiment_id

        current_job_id = databricks_utils.get_job_id()
        experiment_tags = {
            MLFLOW_DATABRICKS_JOB_TYPE_INFO: databricks_utils.get_job_type_info(),
            MLFLOW_EXPERIMENT_SOURCE_TYPE: "JOB",
            MLFLOW_EXPERIMENT_SOURCE_ID: current_job_id,
        }

        # `create_experiment` acts as "get or create" here. It returns the job's
        # existing experiment if there is one, otherwise it creates a new one.
        client = MlflowClient()
        experiment_name = databricks_utils.get_experiment_name_from_job_id(
            current_job_id)
        resolved_id = client.create_experiment(experiment_name, None,
                                               experiment_tags)
        _logger.debug(
            "Job experiment with experiment ID '%s' fetched or created",
            resolved_id,
        )

        _resolved_job_experiment_id = resolved_id
        return resolved_id
 def tags(self):
     """Build the run tags describing the current Databricks job run.

     The result always holds ``MLFLOW_SOURCE_NAME`` and ``MLFLOW_SOURCE_TYPE``.
     ``MLFLOW_SOURCE_NAME`` is ``"jobs/<job_id>/run/<job_run_id>"`` when both
     IDs are known and ``None`` otherwise. ``MLFLOW_SOURCE_TYPE`` is the JOB
     source type. The job ID, job run ID, job type, webapp URL, workspace URL
     and workspace ID are added under their own tag keys only when they are
     not ``None``.
     """
     job_id = databricks_utils.get_job_id()
     job_run_id = databricks_utils.get_job_run_id()
     job_type = databricks_utils.get_job_type()
     webapp_url = databricks_utils.get_webapp_url()
     workspace_url, workspace_id = databricks_utils.get_workspace_info_from_dbutils()

     if job_id is None or job_run_id is None:
         source_name = None
     else:
         source_name = "jobs/{job_id}/run/{job_run_id}".format(
             job_id=job_id, job_run_id=job_run_id)

     run_tags = {
         MLFLOW_SOURCE_NAME: source_name,
         MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.JOB),
     }

     # Optional tags, in the order they should appear; unknown values are skipped.
     optional_tags = (
         (MLFLOW_DATABRICKS_JOB_ID, job_id),
         (MLFLOW_DATABRICKS_JOB_RUN_ID, job_run_id),
         (MLFLOW_DATABRICKS_JOB_TYPE, job_type),
         (MLFLOW_DATABRICKS_WEBAPP_URL, webapp_url),
         (MLFLOW_DATABRICKS_WORKSPACE_URL, workspace_url),
         (MLFLOW_DATABRICKS_WORKSPACE_ID, workspace_id),
     )
     run_tags.update(
         (tag_key, tag_value)
         for tag_key, tag_value in optional_tags
         if tag_value is not None
     )
     return run_tags
Example #4
0
    def request_headers(self):
        """Collect Databricks context identifiers to send as request headers.

        The headers added depend on the current context:

        - ``notebook_id`` when running in a Databricks notebook.
        - ``job_id``, ``job_run_id`` and ``job_type`` when running in a
          Databricks job.
        - ``cluster_id`` when running in a cluster.

        An empty dict is returned when none of these checks pass.
        """
        headers = {}

        if databricks_utils.is_in_databricks_notebook():
            headers.update({"notebook_id": databricks_utils.get_notebook_id()})

        if databricks_utils.is_in_databricks_job():
            # Dict-literal values are evaluated left to right, so the job
            # lookups run in the same order as listed here.
            headers.update({
                "job_id": databricks_utils.get_job_id(),
                "job_run_id": databricks_utils.get_job_run_id(),
                "job_type": databricks_utils.get_job_type(),
            })

        if databricks_utils.is_in_cluster():
            headers.update({"cluster_id": databricks_utils.get_cluster_id()})

        return headers