def test_use_repl_context_if_available(tmpdir):
    """Verify that Databricks job/notebook/cluster lookups prefer the
    ``dbruntime.databricks_repl_context`` REPL context when it is importable
    and initialized, and fall back to the command context otherwise.

    Fix over previous version: the fake ``dbruntime`` package path appended to
    ``sys.path`` (and the modules it caused to be imported) are now removed in
    a ``finally`` block so this test no longer leaks state into other tests.
    """
    # Simulate a case where `dbruntime.databricks_repl_context.get_context`
    # is unavailable (module not importable at all).
    with pytest.raises(ModuleNotFoundError, match="No module named 'dbruntime'"):
        from dbruntime.databricks_repl_context import get_context  # pylint: disable=unused-import

    command_context_mock = mock.MagicMock()
    command_context_mock.jobId().get.return_value = "job_id"
    with mock.patch(
        "mlflow.utils.databricks_utils._get_command_context",
        return_value=command_context_mock,
    ) as mock_get_command_context:
        assert databricks_utils.get_job_id() == "job_id"
        mock_get_command_context.assert_called_once()

    # Create a fake databricks_repl_context module
    tmpdir.mkdir("dbruntime").join("databricks_repl_context.py").write(
        """
def get_context():
    pass
"""
    )
    sys.path.append(tmpdir.strpath)
    try:
        # Simulate a case where the REPL context object is not initialized:
        # get_context() is importable but returns None, so the command
        # context fallback must be used.
        with mock.patch(
            "dbruntime.databricks_repl_context.get_context",
            return_value=None,
        ), mock.patch(
            "mlflow.utils.databricks_utils._get_command_context",
            return_value=command_context_mock,
        ) as mock_get_command_context:
            assert databricks_utils.get_job_id() == "job_id"
            mock_get_command_context.assert_called_once()

        # Initialized REPL context: job id comes from it and dbutils is
        # never touched.
        with mock.patch(
            "dbruntime.databricks_repl_context.get_context",
            return_value=mock.MagicMock(jobId="job_id"),
        ) as mock_get_context, mock.patch(
            "mlflow.utils.databricks_utils._get_dbutils"
        ) as mock_dbutils:
            assert databricks_utils.get_job_id() == "job_id"
            mock_get_context.assert_called_once()
            mock_dbutils.assert_not_called()

        # Notebook id likewise short-circuits the Spark-context lookup.
        with mock.patch(
            "dbruntime.databricks_repl_context.get_context",
            return_value=mock.MagicMock(notebookId="notebook_id"),
        ) as mock_get_context, mock.patch(
            "mlflow.utils.databricks_utils._get_property_from_spark_context"
        ) as mock_spark_context:
            assert databricks_utils.get_notebook_id() == "notebook_id"
            mock_get_context.assert_called_once()
            mock_spark_context.assert_not_called()

        # Cluster detection short-circuits the active-session lookup.
        with mock.patch(
            "dbruntime.databricks_repl_context.get_context",
            return_value=mock.MagicMock(isInCluster=True),
        ) as mock_get_context, mock.patch(
            "mlflow.utils._spark_utils._get_active_spark_session"
        ) as mock_spark_session:
            assert databricks_utils.is_in_cluster()
            mock_get_context.assert_called_once()
            mock_spark_session.assert_not_called()
    finally:
        # Don't leak the fake package path or its cached imports into other
        # tests in the same session.
        sys.path.remove(tmpdir.strpath)
        sys.modules.pop("dbruntime.databricks_repl_context", None)
        sys.modules.pop("dbruntime", None)
def get_experiment_id(self):
    """Return the experiment ID tied to the current Databricks job.

    The first call resolves (or creates) the job-scoped experiment and
    caches the ID in the module-level ``_resolved_job_experiment_id``;
    later calls return the cached value without any client round trip.
    """
    global _resolved_job_experiment_id
    # Fast path: a previous call already resolved the experiment.
    if _resolved_job_experiment_id:
        return _resolved_job_experiment_id

    job_id = databricks_utils.get_job_id()
    tags = {
        MLFLOW_DATABRICKS_JOB_TYPE_INFO: databricks_utils.get_job_type_info(),
        MLFLOW_EXPERIMENT_SOURCE_TYPE: "JOB",
        MLFLOW_EXPERIMENT_SOURCE_ID: job_id,
    }
    # `create_experiment` acts as get-or-create here: it returns the
    # experiment already associated with this job when one exists, and
    # otherwise creates and returns a fresh one.
    experiment_id = MlflowClient().create_experiment(
        databricks_utils.get_experiment_name_from_job_id(job_id), None, tags
    )
    _logger.debug(
        "Job experiment with experiment ID '%s' fetched or created", experiment_id
    )
    _resolved_job_experiment_id = experiment_id
    return experiment_id
def tags(self):
    """Build the MLflow context tags describing the current Databricks job.

    The source name/type tags are always present (source name may be None
    when job identifiers are unavailable); the remaining Databricks tags are
    added only when their value could be resolved.
    """
    job_id = databricks_utils.get_job_id()
    job_run_id = databricks_utils.get_job_run_id()
    job_type = databricks_utils.get_job_type()
    webapp_url = databricks_utils.get_webapp_url()
    workspace_url, workspace_id = databricks_utils.get_workspace_info_from_dbutils()

    if job_id is not None and job_run_id is not None:
        source_name = "jobs/{job_id}/run/{job_run_id}".format(
            job_id=job_id, job_run_id=job_run_id
        )
    else:
        source_name = None

    result = {
        MLFLOW_SOURCE_NAME: source_name,
        MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.JOB),
    }
    # Optional tags, inserted in the same order as before and skipped when
    # the value is unknown.
    optional_tags = (
        (MLFLOW_DATABRICKS_JOB_ID, job_id),
        (MLFLOW_DATABRICKS_JOB_RUN_ID, job_run_id),
        (MLFLOW_DATABRICKS_JOB_TYPE, job_type),
        (MLFLOW_DATABRICKS_WEBAPP_URL, webapp_url),
        (MLFLOW_DATABRICKS_WORKSPACE_URL, workspace_url),
        (MLFLOW_DATABRICKS_WORKSPACE_ID, workspace_id),
    )
    for key, value in optional_tags:
        if value is not None:
            result[key] = value
    return result
def request_headers(self):
    """Assemble Databricks-context request headers for this process.

    Each header group is included only when the process is actually running
    in that environment (notebook, job, cluster).
    """
    headers = {}
    if databricks_utils.is_in_databricks_notebook():
        headers["notebook_id"] = databricks_utils.get_notebook_id()
    if databricks_utils.is_in_databricks_job():
        headers.update(
            job_id=databricks_utils.get_job_id(),
            job_run_id=databricks_utils.get_job_run_id(),
            job_type=databricks_utils.get_job_type(),
        )
    if databricks_utils.is_in_cluster():
        headers["cluster_id"] = databricks_utils.get_cluster_id()
    return headers