Example #1
    @mock.patch('airflow.providers.databricks.hooks.databricks.sleep')
    def test_do_api_call_waits_between_retries(self, mock_sleep):
        retry_delay = 5
        self.hook = DatabricksHook(retry_delay=retry_delay)

        for exception in [
                requests_exceptions.ConnectionError,
                requests_exceptions.SSLError,
                requests_exceptions.Timeout,
                requests_exceptions.ConnectTimeout,
                requests_exceptions.HTTPError,
        ]:
            with mock.patch(
                    'airflow.providers.databricks.hooks.databricks.requests'
            ) as mock_requests:
                with mock.patch.object(self.hook.log, 'error'):
                    mock_sleep.reset_mock()
                    setup_mock_requests(mock_requests, exception)

                    with pytest.raises(AirflowException):
                        self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

                    assert len(
                        mock_sleep.mock_calls) == self.hook.retry_limit - 1
                    calls = [mock.call(retry_delay), mock.call(retry_delay)]
                    mock_sleep.assert_has_calls(calls)
Example #2
    @provide_session
    def setUp(self, session=None):
        conn = session.query(Connection).filter(
            Connection.conn_id == DEFAULT_CONN_ID).first()
        conn.extra = json.dumps({'token': TOKEN, 'host': HOST})

        session.commit()

        self.hook = DatabricksHook()
Example #3
    @provide_session
    def setUp(self, session=None):
        conn = session.query(Connection).filter(
            Connection.conn_id == DEFAULT_CONN_ID).first()
        conn.host = HOST
        conn.login = LOGIN
        conn.password = PASSWORD
        conn.extra = None
        session.commit()

        self.hook = DatabricksHook(retry_delay=0)
Example #4
 def _get_hook(self) -> DatabricksHook:
     return DatabricksHook(
         self.databricks_conn_id,
         retry_limit=self.databricks_retry_limit,
         retry_delay=self.databricks_retry_delay,
         retry_args=self.databricks_retry_args,
     )
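
A side note on how such a factory is used: the sketch below constructs the hook directly with the same kinds of arguments and submits a run, mirroring what _get_hook() wires up. It is illustrative only; the connection id and task payload are placeholders, not taken from any example here.

from airflow.providers.databricks.hooks.databricks import DatabricksHook

# Hypothetical standalone usage: explicit retry settings, one-off run.
hook = DatabricksHook(
    'databricks_default',   # databricks_conn_id (placeholder)
    retry_limit=3,
    retry_delay=1.0,
)
run_id = hook.submit_run({
    'notebook_task': {'notebook_path': '/Users/someone/example'},
    'new_cluster': {
        'spark_version': '9.1.x-scala2.12',
        'node_type_id': 'i3.xlarge',
        'num_workers': 1,
    },
})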
Example #5
class TestDatabricksHookToken(unittest.TestCase):
    """
    Tests for DatabricksHook when auth is done with token.
    """
    @provide_session
    def setUp(self, session=None):
        conn = session.query(Connection).filter(
            Connection.conn_id == DEFAULT_CONN_ID).first()
        conn.extra = json.dumps({'token': TOKEN, 'host': HOST})

        session.commit()

        self.hook = DatabricksHook()

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_submit_run(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock
        data = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER}
        run_id = self.hook.submit_run(data)

        assert run_id == '1'
        args = mock_requests.post.call_args
        kwargs = args[1]
        assert kwargs['auth'].token == TOKEN
Example #6
 def _hook(self) -> DatabricksHook:
     return DatabricksHook(
         self.databricks_conn_id,
         retry_limit=self.databricks_retry_limit,
         retry_delay=self.databricks_retry_delay,
         caller="DatabricksReposDeleteOperator",
     )
Example #7
 def _get_hook(self, caller: str) -> DatabricksHook:
     return DatabricksHook(
         self.databricks_conn_id,
         retry_limit=self.databricks_retry_limit,
         retry_delay=self.databricks_retry_delay,
         retry_args=self.databricks_retry_args,
         caller=caller,
     )
Example #8
 def test_init_bad_retry_limit(self):
     with pytest.raises(ValueError):
         DatabricksHook(retry_limit=0)
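
This test pins down a guard in the hook's constructor. A minimal sketch of the check it implies (hypothetical; the real hook's wording may differ):

def validate_retry_limit(retry_limit: int) -> int:
    # Reject non-positive retry limits up front, as the test above expects.
    if retry_limit < 1:
        raise ValueError('Retry limit must be greater than or equal to 1')
    return retry_limit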
Example #9
class TestDatabricksHook(unittest.TestCase):
    """
    Tests for DatabricksHook.
    """
    @provide_session
    def setUp(self, session=None):
        conn = session.query(Connection).filter(
            Connection.conn_id == DEFAULT_CONN_ID).first()
        conn.host = HOST
        conn.login = LOGIN
        conn.password = PASSWORD
        conn.extra = None
        session.commit()

        self.hook = DatabricksHook(retry_delay=0)

    def test_parse_host_with_proper_host(self):
        host = self.hook._parse_host(HOST)
        assert host == HOST

    def test_parse_host_with_scheme(self):
        host = self.hook._parse_host(HOST_WITH_SCHEME)
        assert host == HOST

    def test_init_bad_retry_limit(self):
        with pytest.raises(ValueError):
            DatabricksHook(retry_limit=0)

    def test_do_api_call_retries_with_retryable_error(self):
        for exception in [
                requests_exceptions.ConnectionError,
                requests_exceptions.SSLError,
                requests_exceptions.Timeout,
                requests_exceptions.ConnectTimeout,
                requests_exceptions.HTTPError,
        ]:
            with mock.patch(
                    'airflow.providers.databricks.hooks.databricks.requests'
            ) as mock_requests:
                with mock.patch.object(self.hook.log, 'error') as mock_errors:
                    setup_mock_requests(mock_requests, exception)

                    with pytest.raises(AirflowException):
                        self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

                    assert mock_errors.call_count == self.hook.retry_limit

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_do_api_call_does_not_retry_with_non_retryable_error(
            self, mock_requests):
        setup_mock_requests(mock_requests,
                            requests_exceptions.HTTPError,
                            status_code=400)

        with mock.patch.object(self.hook.log, 'error') as mock_errors:
            with pytest.raises(AirflowException):
                self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

            mock_errors.assert_not_called()

    def test_do_api_call_succeeds_after_retrying(self):
        for exception in [
                requests_exceptions.ConnectionError,
                requests_exceptions.SSLError,
                requests_exceptions.Timeout,
                requests_exceptions.ConnectTimeout,
                requests_exceptions.HTTPError,
        ]:
            with mock.patch(
                    'airflow.providers.databricks.hooks.databricks.requests'
            ) as mock_requests:
                with mock.patch.object(self.hook.log, 'error') as mock_errors:
                    setup_mock_requests(mock_requests,
                                        exception,
                                        error_count=2,
                                        response_content={'run_id': '1'})

                    response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

                    assert mock_errors.call_count == 2
                    assert response == {'run_id': '1'}

    @mock.patch('airflow.providers.databricks.hooks.databricks.sleep')
    def test_do_api_call_waits_between_retries(self, mock_sleep):
        retry_delay = 5
        self.hook = DatabricksHook(retry_delay=retry_delay)

        for exception in [
                requests_exceptions.ConnectionError,
                requests_exceptions.SSLError,
                requests_exceptions.Timeout,
                requests_exceptions.ConnectTimeout,
                requests_exceptions.HTTPError,
        ]:
            with mock.patch(
                    'airflow.providers.databricks.hooks.databricks.requests'
            ) as mock_requests:
                with mock.patch.object(self.hook.log, 'error'):
                    mock_sleep.reset_mock()
                    setup_mock_requests(mock_requests, exception)

                    with pytest.raises(AirflowException):
                        self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

                    assert len(
                        mock_sleep.mock_calls) == self.hook.retry_limit - 1
                    calls = [mock.call(retry_delay), mock.call(retry_delay)]
                    mock_sleep.assert_has_calls(calls)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_do_api_call_patch(self, mock_requests):
        mock_requests.patch.return_value.json.return_value = {
            'cluster_name': 'new_name'
        }
        data = {'cluster_name': 'new_name'}
        patched_cluster_name = self.hook._do_api_call(
            ('PATCH', 'api/2.0/jobs/runs/submit'), data)

        assert patched_cluster_name['cluster_name'] == 'new_name'
        mock_requests.patch.assert_called_once_with(
            submit_run_endpoint(HOST),
            json={'cluster_name': 'new_name'},
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_submit_run(self, mock_requests):
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        data = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER}
        run_id = self.hook.submit_run(data)

        assert run_id == '1'
        mock_requests.post.assert_called_once_with(
            submit_run_endpoint(HOST),
            json={
                'notebook_task': NOTEBOOK_TASK,
                'new_cluster': NEW_CLUSTER,
            },
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_spark_python_submit_run(self, mock_requests):
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        data = {
            'spark_python_task': SPARK_PYTHON_TASK,
            'new_cluster': NEW_CLUSTER
        }
        run_id = self.hook.submit_run(data)

        assert run_id == '1'
        mock_requests.post.assert_called_once_with(
            submit_run_endpoint(HOST),
            json={
                'spark_python_task': SPARK_PYTHON_TASK,
                'new_cluster': NEW_CLUSTER,
            },
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_run_now(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock
        data = {
            'notebook_params': NOTEBOOK_PARAMS,
            'jar_params': JAR_PARAMS,
            'job_id': JOB_ID
        }
        run_id = self.hook.run_now(data)

        assert run_id == '1'

        mock_requests.post.assert_called_once_with(
            run_now_endpoint(HOST),
            json={
                'notebook_params': NOTEBOOK_PARAMS,
                'jar_params': JAR_PARAMS,
                'job_id': JOB_ID
            },
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_get_run_page_url(self, mock_requests):
        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE

        run_page_url = self.hook.get_run_page_url(RUN_ID)

        assert run_page_url == RUN_PAGE_URL
        mock_requests.get.assert_called_once_with(
            get_run_endpoint(HOST),
            json=None,
            params={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_get_job_id(self, mock_requests):
        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE

        job_id = self.hook.get_job_id(RUN_ID)

        assert job_id == JOB_ID
        mock_requests.get.assert_called_once_with(
            get_run_endpoint(HOST),
            json=None,
            params={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_get_run_state(self, mock_requests):
        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE

        run_state = self.hook.get_run_state(RUN_ID)

        assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE,
                                     STATE_MESSAGE)
        mock_requests.get.assert_called_once_with(
            get_run_endpoint(HOST),
            json=None,
            params={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_cancel_run(self, mock_requests):
        mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE

        self.hook.cancel_run(RUN_ID)

        mock_requests.post.assert_called_once_with(
            cancel_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_start_cluster(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        self.hook.start_cluster({"cluster_id": CLUSTER_ID})

        mock_requests.post.assert_called_once_with(
            start_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_restart_cluster(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        self.hook.restart_cluster({"cluster_id": CLUSTER_ID})

        mock_requests.post.assert_called_once_with(
            restart_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_terminate_cluster(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        self.hook.terminate_cluster({"cluster_id": CLUSTER_ID})

        mock_requests.post.assert_called_once_with(
            terminate_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_install_libs_on_cluster(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        data = {'cluster_id': CLUSTER_ID, 'libraries': LIBRARIES}
        self.hook.install(data)

        mock_requests.post.assert_called_once_with(
            install_endpoint(HOST),
            json={
                'cluster_id': CLUSTER_ID,
                'libraries': LIBRARIES
            },
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_uninstall_libs_on_cluster(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        data = {'cluster_id': CLUSTER_ID, 'libraries': LIBRARIES}
        self.hook.uninstall(data)

        mock_requests.post.assert_called_once_with(
            uninstall_endpoint(HOST),
            json={
                'cluster_id': CLUSTER_ID,
                'libraries': LIBRARIES
            },
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )
Example #10
 def _get_hook(self):
     return DatabricksHook(
         self.databricks_conn_id,
         retry_limit=self.databricks_retry_limit,
         retry_delay=self.databricks_retry_delay)
Example #11
 def __init__(self, run_id: int, databricks_conn_id: str, polling_period_seconds: int = 30) -> None:
     super().__init__()
     self.run_id = run_id
     self.databricks_conn_id = databricks_conn_id
     self.polling_period_seconds = polling_period_seconds
     self.hook = DatabricksHook(databricks_conn_id)
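
Given polling_period_seconds, the surrounding class presumably polls the run until it finishes. A hedged sketch of such a loop, reusing get_run_state from the examples above (the is_terminal attribute and the loop itself are assumptions, not shown in this listing):

from time import sleep

def wait_for_run(hook, run_id, polling_period_seconds=30):
    # Illustrative polling loop: ask the hook for the run state until the
    # run reaches a terminal life-cycle state, sleeping between checks.
    while True:
        run_state = hook.get_run_state(run_id)
        if run_state.is_terminal:  # assumed RunState helper
            return run_state
        sleep(polling_period_seconds)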
Example #12
 def test_init_bad_retry_limit(self):
     with self.assertRaises(ValueError):
         DatabricksHook(retry_limit=0)
Example #13
class TestDatabricksHook(unittest.TestCase):
    """
    Tests for DatabricksHook.
    """

    @provide_session
    def setUp(self, session=None):
        conn = session.query(Connection) \
            .filter(Connection.conn_id == DEFAULT_CONN_ID) \
            .first()
        conn.host = HOST
        conn.login = LOGIN
        conn.password = PASSWORD
        conn.extra = None
        session.commit()

        self.hook = DatabricksHook(retry_delay=0)

    def test_parse_host_with_proper_host(self):
        host = self.hook._parse_host(HOST)
        self.assertEqual(host, HOST)

    def test_parse_host_with_scheme(self):
        host = self.hook._parse_host(HOST_WITH_SCHEME)
        self.assertEqual(host, HOST)

    def test_init_bad_retry_limit(self):
        with self.assertRaises(ValueError):
            DatabricksHook(retry_limit=0)

    def test_do_api_call_retries_with_retryable_error(self):
        for exception in [requests_exceptions.ConnectionError,
                          requests_exceptions.SSLError,
                          requests_exceptions.Timeout,
                          requests_exceptions.ConnectTimeout,
                          requests_exceptions.HTTPError]:
            with mock.patch('airflow.providers.databricks.hooks.databricks.requests') as mock_requests:
                with mock.patch.object(self.hook.log, 'error') as mock_errors:
                    setup_mock_requests(mock_requests, exception)

                    with self.assertRaises(AirflowException):
                        self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

                    self.assertEqual(mock_errors.call_count, self.hook.retry_limit)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_requests):
        setup_mock_requests(
            mock_requests, requests_exceptions.HTTPError, status_code=400
        )

        with mock.patch.object(self.hook.log, 'error') as mock_errors:
            with self.assertRaises(AirflowException):
                self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

            mock_errors.assert_not_called()

    def test_do_api_call_succeeds_after_retrying(self):
        for exception in [requests_exceptions.ConnectionError,
                          requests_exceptions.SSLError,
                          requests_exceptions.Timeout,
                          requests_exceptions.ConnectTimeout,
                          requests_exceptions.HTTPError]:
            with mock.patch('airflow.providers.databricks.hooks.databricks.requests') as mock_requests:
                with mock.patch.object(self.hook.log, 'error') as mock_errors:
                    setup_mock_requests(
                        mock_requests,
                        exception,
                        error_count=2,
                        response_content={'run_id': '1'}
                    )

                    response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

                    self.assertEqual(mock_errors.call_count, 2)
                    self.assertEqual(response, {'run_id': '1'})

    @mock.patch('airflow.providers.databricks.hooks.databricks.sleep')
    def test_do_api_call_waits_between_retries(self, mock_sleep):
        retry_delay = 5
        self.hook = DatabricksHook(retry_delay=retry_delay)

        for exception in [requests_exceptions.ConnectionError,
                          requests_exceptions.SSLError,
                          requests_exceptions.Timeout,
                          requests_exceptions.ConnectTimeout,
                          requests_exceptions.HTTPError]:
            with mock.patch('airflow.providers.databricks.hooks.databricks.requests') as mock_requests:
                with mock.patch.object(self.hook.log, 'error'):
                    mock_sleep.reset_mock()
                    setup_mock_requests(mock_requests, exception)

                    with self.assertRaises(AirflowException):
                        self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

                    self.assertEqual(len(mock_sleep.mock_calls), self.hook.retry_limit - 1)
                    calls = [
                        mock.call(retry_delay),
                        mock.call(retry_delay)
                    ]
                    mock_sleep.assert_has_calls(calls)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_submit_run(self, mock_requests):
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        data = {
            'notebook_task': NOTEBOOK_TASK,
            'new_cluster': NEW_CLUSTER
        }
        run_id = self.hook.submit_run(data)

        self.assertEqual(run_id, '1')
        mock_requests.post.assert_called_once_with(
            submit_run_endpoint(HOST),
            json={
                'notebook_task': NOTEBOOK_TASK,
                'new_cluster': NEW_CLUSTER,
            },
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_run_now(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock
        data = {
            'notebook_params': NOTEBOOK_PARAMS,
            'jar_params': JAR_PARAMS,
            'job_id': JOB_ID
        }
        run_id = self.hook.run_now(data)

        self.assertEqual(run_id, '1')

        mock_requests.post.assert_called_once_with(
            run_now_endpoint(HOST),
            json={
                'notebook_params': NOTEBOOK_PARAMS,
                'jar_params': JAR_PARAMS,
                'job_id': JOB_ID
            },
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_get_run_page_url(self, mock_requests):
        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE

        run_page_url = self.hook.get_run_page_url(RUN_ID)

        self.assertEqual(run_page_url, RUN_PAGE_URL)
        mock_requests.get.assert_called_once_with(
            get_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_get_run_state(self, mock_requests):
        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE

        run_state = self.hook.get_run_state(RUN_ID)

        self.assertEqual(run_state, RunState(
            LIFE_CYCLE_STATE,
            RESULT_STATE,
            STATE_MESSAGE))
        mock_requests.get.assert_called_once_with(
            get_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_cancel_run(self, mock_requests):
        mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE

        self.hook.cancel_run(RUN_ID)

        mock_requests.post.assert_called_once_with(
            cancel_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_start_cluster(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        self.hook.start_cluster({"cluster_id": CLUSTER_ID})

        mock_requests.post.assert_called_once_with(
            start_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_restart_cluster(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        self.hook.restart_cluster({"cluster_id": CLUSTER_ID})

        mock_requests.post.assert_called_once_with(
            restart_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_terminate_cluster(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        self.hook.terminate_cluster({"cluster_id": CLUSTER_ID})

        mock_requests.post.assert_called_once_with(
            terminate_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)