Example #1
0
def test_run_databricks_validations(
        tmpdir, cluster_spec_mock,  # pylint: disable=unused-argument
        tracking_uri_mock, dbfs_mocks, set_tag_mock):  # pylint: disable=unused-argument
    """
    Verifies that Databricks runs abort during validation — before any API
    request is issued — for a bad tracking URI, misspecified parameters, or
    a missing cluster spec.
    """
    env_patch = mock.patch.dict(
        os.environ, {'DATABRICKS_HOST': 'test-host', 'DATABRICKS_TOKEN': 'foo'})
    api_patch = mock.patch(
        "mlflow.projects.databricks.DatabricksJobRunner._databricks_api_request")
    with env_patch, api_patch as api_request_mock:
        # A local-file tracking URI is rejected before any request goes out.
        tracking_uri_mock.return_value = tmpdir.strpath
        with pytest.raises(ExecutionException):
            run_databricks_project(cluster_spec_mock, synchronous=True)
        assert api_request_mock.call_count == 0
        api_request_mock.reset_mock()
        mlflow_service = mlflow.tracking.MlflowClient()
        run_infos = mlflow_service.list_run_infos(
            experiment_id=FileStore.DEFAULT_EXPERIMENT_ID)
        assert len(run_infos) == 0
        tracking_uri_mock.return_value = "http://"
        # Misspecified entry-point parameters also fail validation without
        # reaching the API.
        with pytest.raises(ExecutionException):
            mlflow.projects.run(
                TEST_PROJECT_DIR, backend="databricks", entry_point="greeter",
                backend_config=cluster_spec_mock)
        assert api_request_mock.call_count == 0
        api_request_mock.reset_mock()
        # A missing cluster spec is likewise rejected up front.
        with pytest.raises(ExecutionException):
            mlflow.projects.run(TEST_PROJECT_DIR, backend="databricks", synchronous=True,
                                backend_config=None)
        assert api_request_mock.call_count == 0
        api_request_mock.reset_mock()
        # Good tracking URIs pass validation without raising.
        databricks.before_run_validations("http://", cluster_spec_mock)
        databricks.before_run_validations("databricks", cluster_spec_mock)
Example #2
0
def test_run_databricks_validations(
    tmpdir,
    cluster_spec_mock,
    dbfs_mocks,
    set_tag_mock,
):  # pylint: disable=unused-argument
    """
    Verifies that Databricks runs abort during validation — before any API
    request is issued — for a bad tracking URI, misspecified parameters, or
    a missing cluster spec.
    """
    env_patch = mock.patch.dict(
        os.environ,
        {"DATABRICKS_HOST": "test-host", "DATABRICKS_TOKEN": "foo"},
    )
    api_patch = mock.patch(
        "mlflow.projects.databricks.DatabricksJobRunner._databricks_api_request")
    with env_patch, api_patch as api_request_mock:
        # A local-file tracking URI is rejected before any request goes out.
        mlflow.set_tracking_uri(tmpdir.strpath)
        with pytest.raises(ExecutionException,
                           match="MLflow tracking URI must be of"):
            run_databricks_project(cluster_spec_mock, synchronous=True)
        assert api_request_mock.call_count == 0
        api_request_mock.reset_mock()
        mlflow_service = mlflow.tracking.MlflowClient()
        run_infos = mlflow_service.list_run_infos(
            experiment_id=FileStore.DEFAULT_EXPERIMENT_ID)
        assert len(run_infos) == 0
        mlflow.set_tracking_uri("databricks")
        # Misspecified entry-point parameters also fail validation without
        # reaching the API.
        with pytest.raises(
                ExecutionException,
                match="No value given for missing parameters: 'name'"):
            mlflow.projects.run(
                TEST_PROJECT_DIR,
                backend="databricks",
                entry_point="greeter",
                backend_config=cluster_spec_mock,
            )
        assert api_request_mock.call_count == 0
        api_request_mock.reset_mock()
        # A missing cluster spec is likewise rejected up front.
        with pytest.raises(ExecutionException,
                           match="Backend spec must be provided"):
            mlflow.projects.run(TEST_PROJECT_DIR,
                                backend="databricks",
                                synchronous=True,
                                backend_config=None)
        assert api_request_mock.call_count == 0
        api_request_mock.reset_mock()
        # Good tracking URIs pass validation without raising.
        databricks.before_run_validations("http://", cluster_spec_mock)
        databricks.before_run_validations("databricks", cluster_spec_mock)