Example 1
def test_databricks_http_request_integration(get_config, request):
    """Confirms that the databricks http request params can in fact be used as an HTTP request"""
    def check_request_params(**kwargs):
        # Build the header set we expect the runner to send: the defaults
        # plus a basic-auth header derived from user/pass credentials.
        expected_headers = dict(_DEFAULT_HEADERS)
        expected_headers['Authorization'] = 'Basic dXNlcjpwYXNz'
        expected_kwargs = {
            'method': 'PUT',
            'url': 'host/clusters/list',
            'headers': expected_headers,
            'verify': True,
            'json': {'a': 'b'},
        }
        assert kwargs == expected_kwargs
        fake_response = mock.MagicMock()
        fake_response.status_code = 200
        fake_response.text = '{"OK": "woo"}'
        return fake_response

    request.side_effect = check_request_params
    get_config.return_value = DatabricksConfig(
        "host", "user", "pass", None, insecure=False)

    # Without a profile, credentials come from get_config.
    runner = DatabricksJobRunner(databricks_profile=None)
    response = runner._databricks_api_request(
        '/clusters/list', 'PUT', json={'a': 'b'})
    assert json.loads(response.text) == {'OK': 'woo'}

    get_config.reset_mock()
    # With an explicit profile, get_config must not be consulted.
    profiled_runner = DatabricksJobRunner(databricks_profile="my-profile")
    response = profiled_runner._databricks_api_request(
        '/clusters/list', 'PUT', json={'a': 'b'})
    assert json.loads(response.text) == {'OK': 'woo'}
    assert get_config.call_count == 0
Example 2
def test_run_databricks_failed(_):
    """A failing HTTP request makes _run_shell_command_job raise MlflowException."""
    error_body = '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "Node type not supported"}'
    with mock.patch('mlflow.utils.rest_utils.http_request') as http_mock:
        http_mock.return_value = mock.Mock(text=error_body, status_code=400)
        with pytest.raises(MlflowException):
            DatabricksJobRunner('profile')._run_shell_command_job(
                '/project', 'command', {}, {})
Example 3
def test_databricks_http_request_integration(get_config, request):
    """Confirms that the databricks http request params can in fact be used as an HTTP request"""

    def check_request_params(*args, **kwargs):
        # Expected headers: defaults plus basic auth from user/pass.
        expected_headers = dict(_DEFAULT_HEADERS)
        expected_headers["Authorization"] = "Basic dXNlcjpwYXNz"
        # Method and URL arrive positionally; the rest as keywords.
        assert args == ("PUT", "host/clusters/list")
        expected_kwargs = {
            "headers": expected_headers,
            "verify": True,
            "json": {"a": "b"},
            "timeout": 120,
        }
        assert kwargs == expected_kwargs
        fake_response = mock.MagicMock()
        fake_response.status_code = 200
        fake_response.text = '{"OK": "woo"}'
        return fake_response

    request.side_effect = check_request_params
    get_config.return_value = DatabricksConfig.from_password("host", "user", "pass", insecure=False)

    # Without a profile URI, credentials come from get_config.
    runner = DatabricksJobRunner(databricks_profile_uri=None)
    response = runner._databricks_api_request("/clusters/list", "PUT", json={"a": "b"})
    assert json.loads(response.text) == {"OK": "woo"}

    get_config.reset_mock()
    # With an explicit profile URI, get_config must not be consulted.
    profiled_runner = DatabricksJobRunner(
        databricks_profile_uri=construct_db_uri_from_profile("my-profile"))
    response = profiled_runner._databricks_api_request("/clusters/list", "PUT", json={"a": "b"})
    assert json.loads(response.text) == {"OK": "woo"}
    assert get_config.call_count == 0
Example 4
def test_run_databricks_failed(_):
    """A failing HTTP request makes _run_shell_command_job raise MlflowException."""
    error_body = '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "Node type not supported"}'
    with mock.patch("mlflow.utils.rest_utils.http_request") as http_mock:
        http_mock.return_value = mock.Mock(text=error_body, status_code=400)
        runner = DatabricksJobRunner(construct_db_uri_from_profile("profile"))
        with pytest.raises(MlflowException):
            runner._run_shell_command_job("/project", "command", {}, {})
Example 5
def test_upload_existing_project_to_dbfs(dbfs_path_exists_mock):  # pylint: disable=unused-argument
    """The project is not re-uploaded when its tarball already exists on DBFS."""
    upload_target = "mlflow.projects.databricks.DatabricksJobRunner._upload_to_dbfs"
    with mock.patch(upload_target) as upload_mock:
        # Pretend the archive is already present on DBFS.
        dbfs_path_exists_mock.return_value = True
        runner = DatabricksJobRunner(databricks_profile="DEFAULT")
        runner._upload_project_to_dbfs(
            project_dir=TEST_PROJECT_DIR, experiment_id=0)
        # No upload call should have been made.
        assert upload_mock.call_count == 0
Example 6
def test_run_databricks_failed():
    """A 400 API response makes _run_shell_command_job raise MlflowException."""
    api_target = 'mlflow.projects.databricks.DatabricksJobRunner._databricks_api_request'
    with mock.patch(api_target) as api_mock:
        api_mock.return_value = mock.Mock(
            text="{'message': 'Node type not supported'}", status_code=400)
        runner = DatabricksJobRunner('profile')
        with pytest.raises(MlflowException):
            runner._run_shell_command_job('/project', 'command', {}, {})
Example 7
def test_upload_project_to_dbfs(dbfs_root_mock, tmpdir, dbfs_path_exists_mock,
                                upload_to_dbfs_mock):  # pylint: disable=unused-argument
    """Uploading a project places a tarball of it under the mocked DBFS root."""
    # Force the "not yet on DBFS" path so the upload actually happens.
    dbfs_path_exists_mock.return_value = False
    runner = DatabricksJobRunner(databricks_profile="DEFAULT")
    dbfs_uri = runner._upload_project_to_dbfs(
        project_dir=TEST_PROJECT_DIR, experiment_id=0)
    # Resolve the uploaded archive's location inside the mock DBFS root.
    uploaded_tar_path = os.path.join(dbfs_root_mock, dbfs_uri.split("/dbfs/")[1])
    # Build the tarball we expect the runner to have produced.
    expected_tar_path = str(tmpdir.join("expected.tar.gz"))
    file_utils.make_tarfile(
        output_filename=expected_tar_path,
        source_dir=TEST_PROJECT_DIR,
        archive_name=databricks.DB_TARFILE_ARCHIVE_NAME)
    # The uploaded archive must be byte-identical to the expected one.
    assert filecmp.cmp(uploaded_tar_path, expected_tar_path, shallow=False)
Example 8
def test_dbfs_path_exists_error_response_handling(response_mock):
    """_dbfs_path_exists raises MlflowException on an error response from the API."""
    creds_target = "mlflow.utils.databricks_utils.get_databricks_host_creds"
    request_target = "mlflow.utils.rest_utils.http_request"
    with mock.patch(creds_target) as creds_mock, \
            mock.patch(request_target) as http_mock:
        # databricks_profile is None because clients using a profile are mocked.
        job_runner = DatabricksJobRunner(databricks_profile=None)

        # The dbfs-path check gets back a 400 response whose body may or may
        # not be well-formed JSON (response_mock covers both shapes).
        creds_mock.return_value = None
        http_mock.return_value = response_mock

        # Either way the failure must surface as an MlflowException.
        with pytest.raises(MlflowException):
            job_runner._dbfs_path_exists('some/path')
Example 9
def test_run_databricks_validations(
        tmpdir, cluster_spec_mock,  # pylint: disable=unused-argument
        tracking_uri_mock, dbfs_mocks):  # pylint: disable=unused-argument
    """
    Tests that running on Databricks fails before making any API requests if validations fail.
    """
    auth_target = "mlflow.projects.databricks.DatabricksJobRunner._check_auth_available"
    api_target = "mlflow.projects.databricks.DatabricksJobRunner.databricks_api_request"
    with mock.patch(auth_target), mock.patch(api_target) as api_request_mock:
        # A local-filesystem tracking URI fails validation.
        tracking_uri_mock.return_value = tmpdir.strpath
        with pytest.raises(ExecutionException):
            run_databricks_project(cluster_spec_mock, block=True)
        assert api_request_mock.call_count == 0
        api_request_mock.reset_mock()
        tracking_uri_mock.return_value = "http://"
        # Misspecified parameters fail validation.
        with pytest.raises(ExecutionException):
            mlflow.projects.run(
                TEST_PROJECT_DIR, mode="databricks", entry_point="greeter",
                cluster_spec=cluster_spec_mock)
        assert api_request_mock.call_count == 0
        api_request_mock.reset_mock()
        # A missing cluster spec fails validation.
        with pytest.raises(ExecutionException):
            mlflow.projects.run(TEST_PROJECT_DIR, mode="databricks", block=True, cluster_spec=None)
        assert api_request_mock.call_count == 0
        api_request_mock.reset_mock()
        # Good tracking URIs pass validation without raising.
        runner = DatabricksJobRunner(databricks_profile="DEFAULT")
        runner._before_run_validations("http://", cluster_spec_mock)
        runner._before_run_validations("databricks", cluster_spec_mock)