def test_do_api_call_waits_between_retries(self, mock_sleep):
    """Verify that ``_do_api_call`` sleeps ``retry_delay`` seconds between attempts."""
    delay = 5
    self.hook = DatabricksHook(retry_delay=delay)
    retryable_exceptions = (
        requests_exceptions.ConnectionError,
        requests_exceptions.SSLError,
        requests_exceptions.Timeout,
        requests_exceptions.ConnectTimeout,
        requests_exceptions.HTTPError,
    )
    for exc in retryable_exceptions:
        with mock.patch(
            'airflow.providers.databricks.hooks.databricks.requests'
        ) as mock_requests, mock.patch.object(self.hook.log, 'error'):
            mock_sleep.reset_mock()
            setup_mock_requests(mock_requests, exc)
            with pytest.raises(AirflowException):
                self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
            # retry_limit attempts means retry_limit - 1 waits in between.
            assert len(mock_sleep.mock_calls) == self.hook.retry_limit - 1
            mock_sleep.assert_has_calls([mock.call(delay)] * 2)
def setUp(self, session=None):
    """Configure the default Databricks connection for token-based auth."""
    connection = (
        session.query(Connection)
        .filter(Connection.conn_id == DEFAULT_CONN_ID)
        .first()
    )
    connection.extra = json.dumps({'token': TOKEN, 'host': HOST})
    session.commit()
    self.hook = DatabricksHook()
def setUp(self, session=None):
    """Point the default Databricks connection at HOST with basic-auth credentials."""
    connection = (
        session.query(Connection)
        .filter(Connection.conn_id == DEFAULT_CONN_ID)
        .first()
    )
    connection.host = HOST
    connection.login = LOGIN
    connection.password = PASSWORD
    connection.extra = None
    session.commit()
    # retry_delay=0 keeps retry-related tests from sleeping for real.
    self.hook = DatabricksHook(retry_delay=0)
def _get_hook(self) -> DatabricksHook:
    """Build a DatabricksHook configured with this operator's retry settings."""
    retry_options = {
        'retry_limit': self.databricks_retry_limit,
        'retry_delay': self.databricks_retry_delay,
        'retry_args': self.databricks_retry_args,
    }
    return DatabricksHook(self.databricks_conn_id, **retry_options)
class TestDatabricksHookToken(unittest.TestCase):
    """Tests for DatabricksHook when auth is done with token."""

    @provide_session
    def setUp(self, session=None):
        connection = (
            session.query(Connection)
            .filter(Connection.conn_id == DEFAULT_CONN_ID)
            .first()
        )
        connection.extra = json.dumps({'token': TOKEN, 'host': HOST})
        session.commit()
        self.hook = DatabricksHook()

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_submit_run(self, mock_requests):
        """submit_run should send the bearer token from the connection extras."""
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        type(mock_requests.post.return_value).status_code = mock.PropertyMock(
            return_value=200
        )

        payload = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER}
        run_id = self.hook.submit_run(payload)

        assert run_id == '1'
        # The token must have been handed to requests via the auth object.
        _, kwargs = mock_requests.post.call_args
        assert kwargs['auth'].token == TOKEN
def _hook(self) -> DatabricksHook:
    """Create a DatabricksHook identifying this operator as the API caller."""
    hook_params = dict(
        retry_limit=self.databricks_retry_limit,
        retry_delay=self.databricks_retry_delay,
        caller="DatabricksReposDeleteOperator",
    )
    return DatabricksHook(self.databricks_conn_id, **hook_params)
def _get_hook(self, caller: str) -> DatabricksHook:
    """Return a DatabricksHook carrying this operator's retry config and *caller* tag."""
    hook_params = dict(
        retry_limit=self.databricks_retry_limit,
        retry_delay=self.databricks_retry_delay,
        retry_args=self.databricks_retry_args,
        caller=caller,
    )
    return DatabricksHook(self.databricks_conn_id, **hook_params)
def test_init_bad_retry_limit(self):
    """A retry_limit of zero must be rejected at construction time."""
    # Constructing the hook with no retries allowed is a configuration error.
    with pytest.raises(ValueError):
        DatabricksHook(retry_limit=0)
class TestDatabricksHook(unittest.TestCase):
    """
    Tests for DatabricksHook.
    """

    @provide_session
    def setUp(self, session=None):
        # Configure the default connection for basic auth; retry_delay=0 keeps
        # the retry tests from actually sleeping.
        conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
        conn.host = HOST
        conn.login = LOGIN
        conn.password = PASSWORD
        conn.extra = None
        session.commit()
        self.hook = DatabricksHook(retry_delay=0)

    def test_parse_host_with_proper_host(self):
        """A bare host string is returned unchanged."""
        host = self.hook._parse_host(HOST)
        assert host == HOST

    def test_parse_host_with_scheme(self):
        """A host given with a URL scheme is normalized to the bare host."""
        host = self.hook._parse_host(HOST_WITH_SCHEME)
        assert host == HOST

    def test_init_bad_retry_limit(self):
        """retry_limit=0 is invalid and must raise at construction."""
        with pytest.raises(ValueError):
            DatabricksHook(retry_limit=0)

    def test_do_api_call_retries_with_retryable_error(self):
        """Each retryable exception type is retried up to retry_limit times."""
        for exception in [
            requests_exceptions.ConnectionError,
            requests_exceptions.SSLError,
            requests_exceptions.Timeout,
            requests_exceptions.ConnectTimeout,
            requests_exceptions.HTTPError,
        ]:
            with mock.patch(
                'airflow.providers.databricks.hooks.databricks.requests'
            ) as mock_requests:
                with mock.patch.object(self.hook.log, 'error') as mock_errors:
                    setup_mock_requests(mock_requests, exception)
                    with pytest.raises(AirflowException):
                        self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
                    # One error log per failed attempt.
                    assert mock_errors.call_count == self.hook.retry_limit

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_do_api_call_does_not_retry_with_non_retryable_error(
            self, mock_requests):
        """A 4xx HTTP error fails immediately without any retry logging."""
        setup_mock_requests(mock_requests, requests_exceptions.HTTPError, status_code=400)
        with mock.patch.object(self.hook.log, 'error') as mock_errors:
            with pytest.raises(AirflowException):
                self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
            mock_errors.assert_not_called()

    def test_do_api_call_succeeds_after_retrying(self):
        """After two transient failures the third attempt's response is returned."""
        for exception in [
            requests_exceptions.ConnectionError,
            requests_exceptions.SSLError,
            requests_exceptions.Timeout,
            requests_exceptions.ConnectTimeout,
            requests_exceptions.HTTPError,
        ]:
            with mock.patch(
                'airflow.providers.databricks.hooks.databricks.requests'
            ) as mock_requests:
                with mock.patch.object(self.hook.log, 'error') as mock_errors:
                    setup_mock_requests(
                        mock_requests, exception, error_count=2, response_content={'run_id': '1'}
                    )
                    response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
                    assert mock_errors.call_count == 2
                    assert response == {'run_id': '1'}

    @mock.patch('airflow.providers.databricks.hooks.databricks.sleep')
    def test_do_api_call_waits_between_retries(self, mock_sleep):
        """_do_api_call sleeps retry_delay seconds between consecutive attempts."""
        retry_delay = 5
        self.hook = DatabricksHook(retry_delay=retry_delay)
        for exception in [
            requests_exceptions.ConnectionError,
            requests_exceptions.SSLError,
            requests_exceptions.Timeout,
            requests_exceptions.ConnectTimeout,
            requests_exceptions.HTTPError,
        ]:
            with mock.patch(
                'airflow.providers.databricks.hooks.databricks.requests'
            ) as mock_requests:
                with mock.patch.object(self.hook.log, 'error'):
                    mock_sleep.reset_mock()
                    setup_mock_requests(mock_requests, exception)
                    with pytest.raises(AirflowException):
                        self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
                    # retry_limit attempts -> retry_limit - 1 sleeps in between.
                    assert len(
                        mock_sleep.mock_calls) == self.hook.retry_limit - 1
                    calls = [mock.call(retry_delay), mock.call(retry_delay)]
                    mock_sleep.assert_has_calls(calls)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_do_api_call_patch(self, mock_requests):
        """_do_api_call dispatches the PATCH verb to requests.patch."""
        mock_requests.patch.return_value.json.return_value = {
            'cluster_name': 'new_name'
        }
        data = {'cluster_name': 'new_name'}
        patched_cluster_name = self.hook._do_api_call(
            ('PATCH', 'api/2.0/jobs/runs/submit'), data)
        assert patched_cluster_name['cluster_name'] == 'new_name'
        mock_requests.patch.assert_called_once_with(
            submit_run_endpoint(HOST),
            json={'cluster_name': 'new_name'},
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_submit_run(self, mock_requests):
        """submit_run POSTs the payload to the submit endpoint and returns run_id."""
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        data = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER}
        run_id = self.hook.submit_run(data)
        assert run_id == '1'
        mock_requests.post.assert_called_once_with(
            submit_run_endpoint(HOST),
            json={
                'notebook_task': NOTEBOOK_TASK,
                'new_cluster': NEW_CLUSTER,
            },
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_spark_python_submit_run(self, mock_requests):
        """submit_run supports spark_python_task payloads as well."""
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        data = {
            'spark_python_task': SPARK_PYTHON_TASK,
            'new_cluster': NEW_CLUSTER
        }
        run_id = self.hook.submit_run(data)
        assert run_id == '1'
        mock_requests.post.assert_called_once_with(
            submit_run_endpoint(HOST),
            json={
                'spark_python_task': SPARK_PYTHON_TASK,
                'new_cluster': NEW_CLUSTER,
            },
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_run_now(self, mock_requests):
        """run_now POSTs to the run-now endpoint and returns the run_id."""
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock
        data = {
            'notebook_params': NOTEBOOK_PARAMS,
            'jar_params': JAR_PARAMS,
            'job_id': JOB_ID
        }
        run_id = self.hook.run_now(data)
        assert run_id == '1'
        mock_requests.post.assert_called_once_with(
            run_now_endpoint(HOST),
            json={
                'notebook_params': NOTEBOOK_PARAMS,
                'jar_params': JAR_PARAMS,
                'job_id': JOB_ID
            },
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_get_run_page_url(self, mock_requests):
        """get_run_page_url extracts the run page URL from the get-run response."""
        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
        run_page_url = self.hook.get_run_page_url(RUN_ID)
        assert run_page_url == RUN_PAGE_URL
        mock_requests.get.assert_called_once_with(
            get_run_endpoint(HOST),
            json=None,
            params={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_get_job_id(self, mock_requests):
        """get_job_id resolves a run_id to its owning job_id."""
        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
        job_id = self.hook.get_job_id(RUN_ID)
        assert job_id == JOB_ID
        mock_requests.get.assert_called_once_with(
            get_run_endpoint(HOST),
            json=None,
            params={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_get_run_state(self, mock_requests):
        """get_run_state builds a RunState from the get-run response fields."""
        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
        run_state = self.hook.get_run_state(RUN_ID)
        assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE)
        mock_requests.get.assert_called_once_with(
            get_run_endpoint(HOST),
            json=None,
            params={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_cancel_run(self, mock_requests):
        """cancel_run POSTs the run_id to the cancel endpoint."""
        mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE
        self.hook.cancel_run(RUN_ID)
        mock_requests.post.assert_called_once_with(
            cancel_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_start_cluster(self, mock_requests):
        """start_cluster POSTs the cluster_id to the start endpoint."""
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock
        self.hook.start_cluster({"cluster_id": CLUSTER_ID})
        mock_requests.post.assert_called_once_with(
            start_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_restart_cluster(self, mock_requests):
        """restart_cluster POSTs the cluster_id to the restart endpoint."""
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock
        self.hook.restart_cluster({"cluster_id": CLUSTER_ID})
        mock_requests.post.assert_called_once_with(
            restart_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_terminate_cluster(self, mock_requests):
        """terminate_cluster POSTs the cluster_id to the delete endpoint."""
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock
        self.hook.terminate_cluster({"cluster_id": CLUSTER_ID})
        mock_requests.post.assert_called_once_with(
            terminate_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_install_libs_on_cluster(self, mock_requests):
        """install POSTs the cluster_id and libraries to the install endpoint."""
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock
        data = {'cluster_id': CLUSTER_ID, 'libraries': LIBRARIES}
        self.hook.install(data)
        mock_requests.post.assert_called_once_with(
            install_endpoint(HOST),
            json={
                'cluster_id': CLUSTER_ID,
                'libraries': LIBRARIES
            },
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_uninstall_libs_on_cluster(self, mock_requests):
        """uninstall POSTs the cluster_id and libraries to the uninstall endpoint."""
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock
        data = {'cluster_id': CLUSTER_ID, 'libraries': LIBRARIES}
        self.hook.uninstall(data)
        mock_requests.post.assert_called_once_with(
            uninstall_endpoint(HOST),
            json={
                'cluster_id': CLUSTER_ID,
                'libraries': LIBRARIES
            },
            params=None,
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds,
        )
def _get_hook(self):
    """Construct a DatabricksHook from this operator's connection and retry settings."""
    retry_options = {
        'retry_limit': self.databricks_retry_limit,
        'retry_delay': self.databricks_retry_delay,
    }
    return DatabricksHook(self.databricks_conn_id, **retry_options)
def __init__(self, run_id: int, databricks_conn_id: str, polling_period_seconds: int = 30) -> None:
    """Store run identifiers and build the hook used to poll the run.

    :param run_id: id of the Databricks run to track
    :param databricks_conn_id: Airflow connection id for the Databricks API
    :param polling_period_seconds: seconds to wait between state polls
    """
    super().__init__()
    self.databricks_conn_id = databricks_conn_id
    self.polling_period_seconds = polling_period_seconds
    self.run_id = run_id
    self.hook = DatabricksHook(databricks_conn_id)
def test_init_bad_retry_limit(self):
    """DatabricksHook must reject a zero retry_limit."""
    # Callable form of assertRaises: invoke the constructor with the bad value.
    self.assertRaises(ValueError, DatabricksHook, retry_limit=0)
class TestDatabricksHook(unittest.TestCase):
    """
    Tests for DatabricksHook.
    """

    @provide_session
    def setUp(self, session=None):
        # Configure the default connection for basic auth; retry_delay=0 keeps
        # retry tests from actually sleeping.
        conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
        conn.host = HOST
        conn.login = LOGIN
        conn.password = PASSWORD
        conn.extra = None
        session.commit()
        self.hook = DatabricksHook(retry_delay=0)

    def test_parse_host_with_proper_host(self):
        """A bare host string is returned unchanged."""
        host = self.hook._parse_host(HOST)
        self.assertEqual(host, HOST)

    def test_parse_host_with_scheme(self):
        """A host given with a URL scheme is normalized to the bare host."""
        host = self.hook._parse_host(HOST_WITH_SCHEME)
        self.assertEqual(host, HOST)

    def test_init_bad_retry_limit(self):
        """retry_limit=0 is invalid and must raise at construction."""
        with self.assertRaises(ValueError):
            DatabricksHook(retry_limit=0)

    def test_do_api_call_retries_with_retryable_error(self):
        """Each retryable exception type is retried up to retry_limit times."""
        for exception in [requests_exceptions.ConnectionError,
                          requests_exceptions.SSLError,
                          requests_exceptions.Timeout,
                          requests_exceptions.ConnectTimeout,
                          requests_exceptions.HTTPError]:
            with mock.patch('airflow.providers.databricks.hooks.databricks.requests') as mock_requests:
                with mock.patch.object(self.hook.log, 'error') as mock_errors:
                    setup_mock_requests(mock_requests, exception)
                    with self.assertRaises(AirflowException):
                        self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
                    # One error log per failed attempt.
                    self.assertEqual(mock_errors.call_count, self.hook.retry_limit)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_requests):
        """A 4xx HTTP error fails immediately without any retry logging."""
        setup_mock_requests(
            mock_requests, requests_exceptions.HTTPError, status_code=400
        )
        with mock.patch.object(self.hook.log, 'error') as mock_errors:
            with self.assertRaises(AirflowException):
                self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
            mock_errors.assert_not_called()

    def test_do_api_call_succeeds_after_retrying(self):
        """After two transient failures the third attempt's response is returned."""
        for exception in [requests_exceptions.ConnectionError,
                          requests_exceptions.SSLError,
                          requests_exceptions.Timeout,
                          requests_exceptions.ConnectTimeout,
                          requests_exceptions.HTTPError]:
            with mock.patch('airflow.providers.databricks.hooks.databricks.requests') as mock_requests:
                with mock.patch.object(self.hook.log, 'error') as mock_errors:
                    setup_mock_requests(
                        mock_requests, exception, error_count=2, response_content={'run_id': '1'}
                    )
                    response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
                    self.assertEqual(mock_errors.call_count, 2)
                    self.assertEqual(response, {'run_id': '1'})

    @mock.patch('airflow.providers.databricks.hooks.databricks.sleep')
    def test_do_api_call_waits_between_retries(self, mock_sleep):
        """_do_api_call sleeps retry_delay seconds between consecutive attempts."""
        retry_delay = 5
        self.hook = DatabricksHook(retry_delay=retry_delay)
        for exception in [requests_exceptions.ConnectionError,
                          requests_exceptions.SSLError,
                          requests_exceptions.Timeout,
                          requests_exceptions.ConnectTimeout,
                          requests_exceptions.HTTPError]:
            with mock.patch('airflow.providers.databricks.hooks.databricks.requests') as mock_requests:
                with mock.patch.object(self.hook.log, 'error'):
                    mock_sleep.reset_mock()
                    setup_mock_requests(mock_requests, exception)
                    with self.assertRaises(AirflowException):
                        self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
                    # retry_limit attempts -> retry_limit - 1 sleeps in between.
                    self.assertEqual(len(mock_sleep.mock_calls), self.hook.retry_limit - 1)
                    calls = [
                        mock.call(retry_delay),
                        mock.call(retry_delay)
                    ]
                    mock_sleep.assert_has_calls(calls)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_submit_run(self, mock_requests):
        """submit_run POSTs the payload to the submit endpoint and returns run_id."""
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        data = {
            'notebook_task': NOTEBOOK_TASK,
            'new_cluster': NEW_CLUSTER
        }
        run_id = self.hook.submit_run(data)
        self.assertEqual(run_id, '1')
        mock_requests.post.assert_called_once_with(
            submit_run_endpoint(HOST),
            json={
                'notebook_task': NOTEBOOK_TASK,
                'new_cluster': NEW_CLUSTER,
            },
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_run_now(self, mock_requests):
        """run_now POSTs to the run-now endpoint and returns the run_id."""
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock
        data = {
            'notebook_params': NOTEBOOK_PARAMS,
            'jar_params': JAR_PARAMS,
            'job_id': JOB_ID
        }
        run_id = self.hook.run_now(data)
        self.assertEqual(run_id, '1')
        mock_requests.post.assert_called_once_with(
            run_now_endpoint(HOST),
            json={
                'notebook_params': NOTEBOOK_PARAMS,
                'jar_params': JAR_PARAMS,
                'job_id': JOB_ID
            },
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_get_run_page_url(self, mock_requests):
        """get_run_page_url extracts the run page URL from the get-run response."""
        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
        run_page_url = self.hook.get_run_page_url(RUN_ID)
        self.assertEqual(run_page_url, RUN_PAGE_URL)
        mock_requests.get.assert_called_once_with(
            get_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_get_run_state(self, mock_requests):
        """get_run_state builds a RunState from the get-run response fields."""
        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
        run_state = self.hook.get_run_state(RUN_ID)
        self.assertEqual(run_state, RunState(
            LIFE_CYCLE_STATE,
            RESULT_STATE,
            STATE_MESSAGE))
        mock_requests.get.assert_called_once_with(
            get_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_cancel_run(self, mock_requests):
        """cancel_run POSTs the run_id to the cancel endpoint."""
        mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE
        self.hook.cancel_run(RUN_ID)
        mock_requests.post.assert_called_once_with(
            cancel_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_start_cluster(self, mock_requests):
        """start_cluster POSTs the cluster_id to the start endpoint."""
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock
        self.hook.start_cluster({"cluster_id": CLUSTER_ID})
        mock_requests.post.assert_called_once_with(
            start_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_restart_cluster(self, mock_requests):
        """restart_cluster POSTs the cluster_id to the restart endpoint."""
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock
        self.hook.restart_cluster({"cluster_id": CLUSTER_ID})
        mock_requests.post.assert_called_once_with(
            restart_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_terminate_cluster(self, mock_requests):
        """terminate_cluster POSTs the cluster_id to the delete endpoint."""
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock
        self.hook.terminate_cluster({"cluster_id": CLUSTER_ID})
        mock_requests.post.assert_called_once_with(
            terminate_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)