示例#1
0
def test_run_databricks_failed(_):
    """A failed Databricks REST call should surface to the caller as MlflowException."""
    error_body = '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "Node type not supported"}'
    with mock.patch('mlflow.utils.rest_utils.http_request') as http_request_mock:
        # Simulate the REST layer returning HTTP 400 with a structured error payload.
        http_request_mock.return_value = mock.Mock(text=error_body, status_code=400)
        job_runner = DatabricksJobRunner(construct_db_uri_from_profile('profile'))
        with pytest.raises(MlflowException):
            job_runner._run_shell_command_job('/project', 'command', {}, {})
示例#2
0
def test_databricks_http_request_integration(get_config, request):
    """Confirms that the databricks http request params can in fact be used as an HTTP request"""
    def confirm_request_params(**kwargs):
        # The runner must have assembled exactly these request parameters.
        expected_headers = dict(_DEFAULT_HEADERS)
        expected_headers['Authorization'] = 'Basic dXNlcjpwYXNz'
        expected = {
            'method': 'PUT',
            'url': 'host/clusters/list',
            'headers': expected_headers,
            'verify': True,
            'json': {'a': 'b'},
        }
        assert kwargs == expected
        fake_response = mock.MagicMock()
        fake_response.status_code = 200
        fake_response.text = '{"OK": "woo"}'
        return fake_response

    request.side_effect = confirm_request_params
    get_config.return_value = DatabricksConfig(
        "host", "user", "pass", None, insecure=False)

    # First call: no profile URI, so creds come from the default config provider.
    runner = DatabricksJobRunner(databricks_profile_uri=None)
    response = runner._databricks_api_request('/clusters/list', 'PUT', json={'a': 'b'})
    assert json.loads(response.text) == {'OK': 'woo'}

    # Second call: an explicit profile URI must not consult get_config at all.
    get_config.reset_mock()
    runner = DatabricksJobRunner(
        databricks_profile_uri=construct_db_uri_from_profile("my-profile"))
    response = runner._databricks_api_request('/clusters/list', 'PUT', json={'a': 'b'})
    assert json.loads(response.text) == {'OK': 'woo'}
    assert get_config.call_count == 0
示例#3
0
def test_databricks_params_custom_profile(ProfileConfigProvider):
    """A profile-based URI should resolve host creds through the named config profile."""
    provider_instance = mock.MagicMock()
    provider_instance.get_config.return_value = DatabricksConfig(
        "host", "user", "pass", None, insecure=True)
    ProfileConfigProvider.return_value = provider_instance
    host_creds = databricks_utils.get_databricks_host_creds(
        construct_db_uri_from_profile("profile"))
    # insecure=True in the profile config maps to skipping TLS verification.
    assert host_creds.ignore_tls_verification
    ProfileConfigProvider.assert_called_with("profile")
示例#4
0
def test_upload_existing_project_to_dbfs(dbfs_path_exists_mock):  # pylint: disable=unused-argument
    """No upload should occur when the project tarball already exists on DBFS."""
    dbfs_path_exists_mock.return_value = True
    with mock.patch("mlflow.projects.databricks.DatabricksJobRunner._upload_to_dbfs") \
            as upload_to_dbfs_mock:
        job_runner = DatabricksJobRunner(
            databricks_profile_uri=construct_db_uri_from_profile("DEFAULT"))
        job_runner._upload_project_to_dbfs(
            project_dir=TEST_PROJECT_DIR,
            experiment_id=FileStore.DEFAULT_EXPERIMENT_ID)
        # The existence check should short-circuit the actual upload.
        assert upload_to_dbfs_mock.call_count == 0
示例#5
0
def test_upload_project_to_dbfs(dbfs_root_mock, tmpdir, dbfs_path_exists_mock,
                                upload_to_dbfs_mock):  # pylint: disable=unused-argument
    """Uploading a fresh project should write a tarball identical to one built locally."""
    dbfs_path_exists_mock.return_value = False
    job_runner = DatabricksJobRunner(
        databricks_profile_uri=construct_db_uri_from_profile("DEFAULT"))
    dbfs_uri = job_runner._upload_project_to_dbfs(
        project_dir=TEST_PROJECT_DIR,
        experiment_id=FileStore.DEFAULT_EXPERIMENT_ID)
    # Locate the tar the runner wrote under the mocked DBFS root.
    uploaded_tar = os.path.join(dbfs_root_mock, dbfs_uri.split("/dbfs/")[1])
    # Build a reference tarball from the same project directory.
    reference_tar = str(tmpdir.join("expected.tar.gz"))
    file_utils.make_tarfile(output_filename=reference_tar,
                            source_dir=TEST_PROJECT_DIR,
                            archive_name=databricks.DB_TARFILE_ARCHIVE_NAME)
    # Byte-for-byte comparison of the uploaded vs reference tarball contents.
    assert filecmp.cmp(uploaded_tar, reference_tar, shallow=False)