def test_run_databricks_validations(  # pylint: disable=unused-argument
        tmpdir, cluster_spec_mock, tracking_uri_mock, dbfs_mocks, set_tag_mock):
    """
    Tests that running on Databricks fails before making any API requests if validations fail.

    Each failure case pins the expected error message with ``match=``. Without it, a case
    could pass because some *other* validation (or an unrelated error) happened to raise
    ``ExecutionException``, and the test would not check what its comment says it checks.
    """
    with mock.patch.dict(os.environ, {'DATABRICKS_HOST': 'test-host', 'DATABRICKS_TOKEN': 'foo'}),\
            mock.patch("mlflow.projects.databricks.DatabricksJobRunner._databricks_api_request")\
            as db_api_req_mock:
        # Test bad tracking URI: a local filesystem path must be rejected on Databricks.
        tracking_uri_mock.return_value = tmpdir.strpath
        with pytest.raises(ExecutionException, match="MLflow tracking URI must be of"):
            run_databricks_project(cluster_spec_mock, synchronous=True)
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        # The failed launch must not have left a run behind in the local store.
        mlflow_service = mlflow.tracking.MlflowClient()
        assert (len(mlflow_service.list_run_infos(experiment_id=FileStore.DEFAULT_EXPERIMENT_ID))
                == 0)
        # Switch to a valid remote URI so the next cases fail only for their own reason.
        tracking_uri_mock.return_value = "http://"
        # Test misspecified parameters: the "greeter" entry point requires 'name'.
        with pytest.raises(ExecutionException,
                           match="No value given for missing parameters: 'name'"):
            mlflow.projects.run(
                TEST_PROJECT_DIR, backend="databricks", entry_point="greeter",
                backend_config=cluster_spec_mock)
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        # Test bad cluster spec: a missing backend config must be rejected.
        with pytest.raises(ExecutionException, match="Backend spec must be provided"):
            mlflow.projects.run(TEST_PROJECT_DIR, backend="databricks", synchronous=True,
                                backend_config=None)
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        # Test that validations pass with good tracking URIs
        databricks.before_run_validations("http://", cluster_spec_mock)
        databricks.before_run_validations("databricks", cluster_spec_mock)
def test_run_databricks_validations(  # pylint: disable=unused-argument
        tmpdir, cluster_spec_mock, dbfs_mocks, set_tag_mock):
    """
    Tests that running on Databricks fails before making any API requests if validations fail.

    Three invalid launches are attempted: a local tracking URI, a missing entry-point
    parameter, and an absent backend config. Each must raise ``ExecutionException`` with the
    expected message while issuing zero Databricks API requests. Finally,
    ``before_run_validations`` is called directly with two acceptable tracking URIs.
    """
    fake_env = {
        "DATABRICKS_HOST": "test-host",
        "DATABRICKS_TOKEN": "foo"
    }
    api_request_target = (
        "mlflow.projects.databricks.DatabricksJobRunner._databricks_api_request")

    def _expect_no_api_calls(request_mock):
        # A validation failure must short-circuit before any request is sent;
        # reset afterwards so the next scenario starts from a clean count.
        assert request_mock.call_count == 0
        request_mock.reset_mock()

    with mock.patch.dict(os.environ, fake_env), \
            mock.patch(api_request_target) as api_request_mock:
        # Scenario 1: a local filesystem tracking URI is not allowed.
        mlflow.set_tracking_uri(tmpdir.strpath)
        with pytest.raises(ExecutionException, match="MLflow tracking URI must be of"):
            run_databricks_project(cluster_spec_mock, synchronous=True)
        _expect_no_api_calls(api_request_mock)

        # The rejected launch must not have created a run in the local store.
        client = mlflow.tracking.MlflowClient()
        run_infos = client.list_run_infos(experiment_id=FileStore.DEFAULT_EXPERIMENT_ID)
        assert (len(run_infos) == 0)

        # Remaining scenarios use an acceptable URI so they fail for their own reason.
        mlflow.set_tracking_uri("databricks")

        # Scenario 2: the "greeter" entry point is launched without its 'name' parameter.
        with pytest.raises(ExecutionException,
                           match="No value given for missing parameters: 'name'"):
            mlflow.projects.run(
                TEST_PROJECT_DIR,
                backend="databricks",
                entry_point="greeter",
                backend_config=cluster_spec_mock,
            )
        _expect_no_api_calls(api_request_mock)

        # Scenario 3: no backend config is supplied at all.
        with pytest.raises(ExecutionException, match="Backend spec must be provided"):
            mlflow.projects.run(
                TEST_PROJECT_DIR,
                backend="databricks",
                synchronous=True,
                backend_config=None,
            )
        _expect_no_api_calls(api_request_mock)

        # Both an HTTP URI and the "databricks" URI should pass validation outright.
        for acceptable_uri in ("http://", "databricks"):
            databricks.before_run_validations(acceptable_uri, cluster_spec_mock)