Example 1
    def setUp(self, session=None):
        conn = session.query(Connection) \
            .filter(Connection.conn_id == DEFAULT_CONN_ID) \
            .first()
        conn.extra = json.dumps({'token': TOKEN})
        session.commit()

        self.hook = DatabricksHook()
Example 2
    def setUp(self, session=None):
        conn = session.query(Connection) \
            .filter(Connection.conn_id == DEFAULT_CONN_ID) \
            .first()
        conn.host = HOST
        conn.login = LOGIN
        conn.password = PASSWORD
        session.commit()

        self.hook = DatabricksHook()
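Both setUp variants above rewrite the default Airflow Connection in place before instantiating the hook: Example 1 stores a token in the connection's extra field, Example 2 stores basic-auth credentials. A standalone sketch of the same two configurations outside a test case follows; the helper name and defaults are placeholders, not from the source.

import json

from airflow import settings
from airflow.models import Connection


def configure_databricks_conn(conn_id='databricks_default', host=None,
                              token=None, login=None, password=None):
    # Hypothetical helper: persist either token or basic-auth credentials
    # on an existing Airflow connection, mirroring Examples 1 and 2.
    session = settings.Session()
    conn = session.query(Connection) \
        .filter(Connection.conn_id == conn_id) \
        .first()
    conn.host = host
    if token is not None:
        # Token auth: the hook reads the token from the 'extra' JSON,
        # as the token tests below suggest.
        conn.extra = json.dumps({'token': token})
    else:
        conn.login = login
        conn.password = password
    session.commit()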
Example 3
class DatabricksHookTokenTest(unittest.TestCase):
    """
    Tests for DatabricksHook when auth is done with token.
    """
    @db.provide_session
    def setUp(self, session=None):
        conn = session.query(Connection) \
            .filter(Connection.conn_id == DEFAULT_CONN_ID) \
            .first()
        conn.extra = json.dumps({'token': TOKEN})
        session.commit()

        self.hook = DatabricksHook()

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_submit_run(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock
        json = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER}
        run_id = self.hook.submit_run(json)

        self.assertEqual(run_id, '1')
        args = mock_requests.post.call_args
        kwargs = args[1]
        self.assertEqual(kwargs['auth'].token, TOKEN)
Example 4
class DatabricksHookTokenTest(unittest.TestCase):
    """
    Tests for DatabricksHook when auth is done with token.
    """

    @db.provide_session
    def setUp(self, session=None):
        conn = session.query(Connection) \
            .filter(Connection.conn_id == DEFAULT_CONN_ID) \
            .first()
        conn.extra = json.dumps({'token': TOKEN})
        session.commit()

        self.hook = DatabricksHook()

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_submit_run(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock
        json = {
            'notebook_task': NOTEBOOK_TASK,
            'new_cluster': NEW_CLUSTER
        }
        run_id = self.hook.submit_run(json)

        self.assertEqual(run_id, '1')
        args = mock_requests.post.call_args
        kwargs = args[1]
        self.assertEqual(kwargs['auth'].token, TOKEN)
Example 5
    def setUp(self, session=None):
        conn = session.query(Connection) \
            .filter(Connection.conn_id == DEFAULT_CONN_ID) \
            .first()
        conn.extra = json.dumps({'token': TOKEN})
        session.commit()

        self.hook = DatabricksHook()
Example 6
    def setUp(self, session=None):
        conn = session.query(Connection) \
            .filter(Connection.conn_id == DEFAULT_CONN_ID) \
            .first()
        conn.host = HOST
        conn.login = LOGIN
        conn.password = PASSWORD
        session.commit()

        self.hook = DatabricksHook(retry_delay=0)
Example 7
    def _run_spark_submit(self, databricks_json):
        task = self.task  # type: SparkTask
        _config = task.spark_engine

        from airflow.contrib.hooks.databricks_hook import DatabricksHook

        self.hook = DatabricksHook(
            _config.conn_id,
            _config.connection_retry_limit,
            retry_delay=_config.connection_retry_delay,
        )
        try:
            logging.debug("posted JSON:" + str(databricks_json))
            self.current_run_id = self.hook.submit_run(databricks_json)
            self.hook.log.setLevel(logging.WARNING)
            self._handle_databricks_operator_execution(self.current_run_id,
                                                       self.hook,
                                                       _config.task_id)
            self.hook.log.setLevel(logging.INFO)
        except AirflowException as e:
            raise failed_to_submit_databricks_job(e)
Example 8
    def test_do_api_call_waits_between_retries(self, mock_sleep):
        retry_delay = 5
        self.hook = DatabricksHook(retry_delay=retry_delay)

        for exception in [
                requests_exceptions.ConnectionError,
                requests_exceptions.SSLError,
                requests_exceptions.Timeout,
                requests_exceptions.ConnectTimeout,
                requests_exceptions.HTTPError]:
            with mock.patch(
                'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \
                    mock.patch.object(self.hook.log, 'error'):
                mock_sleep.reset_mock()
                setup_mock_requests(mock_requests, exception)

                with self.assertRaises(AirflowException):
                    self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

                self.assertEqual(len(mock_sleep.mock_calls), self.hook.retry_limit - 1)
                mock_sleep.assert_called_with(retry_delay)
Example 9
    def test_do_api_call_waits_between_retries(self, mock_sleep):
        retry_delay = 5
        self.hook = DatabricksHook(retry_delay=retry_delay)

        for exception in [requests_exceptions.ConnectionError,
                          requests_exceptions.SSLError,
                          requests_exceptions.Timeout,
                          requests_exceptions.ConnectTimeout,
                          requests_exceptions.HTTPError]:
            with mock.patch('airflow.contrib.hooks.databricks_hook.requests') as mock_requests:
                with mock.patch.object(self.hook.log, 'error'):
                    mock_sleep.reset_mock()
                    setup_mock_requests(mock_requests, exception)

                    with self.assertRaises(AirflowException):
                        self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

                    self.assertEqual(len(mock_sleep.mock_calls), self.hook.retry_limit - 1)
                    mock_sleep.assert_called_with(retry_delay)
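Several tests above and below call a setup_mock_requests helper that is defined elsewhere in the test module. The following is a rough reconstruction from its call sites (exception, status_code, error_count, response_content), offered only as a sketch; the real helper may differ.

import itertools
from unittest import mock

from requests import exceptions as requests_exceptions


def setup_mock_requests(mock_requests, exception, status_code=500,
                        error_count=None, response_content=None):
    # Sketch: make requests.post fail with `exception`; when `error_count`
    # is set, succeed with `response_content` after that many failures.
    mock_requests.codes.ok = 200

    if exception is requests_exceptions.HTTPError:
        # HTTPError is surfaced via a response whose raise_for_status fails,
        # which lets the non-retryable test pass status_code=400.
        side_effect = mock.MagicMock()
        side_effect.status_code = status_code
        side_effect.raise_for_status.side_effect = exception(response=side_effect)
    else:
        side_effect = exception()

    if error_count is None:
        # Fail on every attempt.
        mock_requests.post.side_effect = itertools.repeat(side_effect)
    else:
        # Fail `error_count` times, then return a successful response.
        ok_response = mock.MagicMock()
        ok_response.status_code = 200
        ok_response.json.return_value = response_content
        mock_requests.post.side_effect = \
            [side_effect] * error_count + [ok_response]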
Example 10
    def get_hook(self):
        return DatabricksHook(
            self.databricks_conn_id,
            retry_limit=self.databricks_retry_limit,
            retry_delay=self.databricks_retry_delay)
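The get_hook() pattern above is presumably taken from a Databricks operator. A minimal sketch of using a hook built the same way to submit a one-off run; the function name, connection id, and retry values are placeholders, not from the source.

from airflow.contrib.hooks.databricks_hook import DatabricksHook


def submit_notebook_run(notebook_task, new_cluster,
                        conn_id='databricks_default',
                        retry_limit=3, retry_delay=1.0):
    # Sketch: build the hook exactly as get_hook() does above,
    # then submit a run and return its run_id.
    hook = DatabricksHook(conn_id,
                          retry_limit=retry_limit,
                          retry_delay=retry_delay)
    return hook.submit_run({
        'notebook_task': notebook_task,
        'new_cluster': new_cluster,
    })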
Example 11
    def test_init_bad_retry_limit(self):
        with self.assertRaises(ValueError):
            DatabricksHook(retry_limit=0)
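Example 11 implies the constructor validates retry_limit at init time. A sketch of the kind of guard the test expects, not the actual implementation (note the older test in Example 13 expects a bare assert, i.e. AssertionError, instead of ValueError):

class DatabricksHookSketch(object):
    # Illustrative only: reject non-positive retry limits up front,
    # which is what test_init_bad_retry_limit asserts.
    def __init__(self, retry_limit=3, retry_delay=1.0):
        if retry_limit < 1:
            raise ValueError('Retry limit must be greater than or equal to 1')
        self.retry_limit = retry_limit
        self.retry_delay = retry_delay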
Example 12
class DatabricksHookTest(unittest.TestCase):
    """
    Tests for DatabricksHook.
    """
    @db.provide_session
    def setUp(self, session=None):
        conn = session.query(Connection) \
            .filter(Connection.conn_id == DEFAULT_CONN_ID) \
            .first()
        conn.host = HOST
        conn.login = LOGIN
        conn.password = PASSWORD
        conn.extra = None
        session.commit()

        self.hook = DatabricksHook(retry_delay=0)

    def test_parse_host_with_proper_host(self):
        host = self.hook._parse_host(HOST)
        self.assertEqual(host, HOST)

    def test_parse_host_with_scheme(self):
        host = self.hook._parse_host(HOST_WITH_SCHEME)
        self.assertEqual(host, HOST)

    def test_init_bad_retry_limit(self):
        with self.assertRaises(ValueError):
            DatabricksHook(retry_limit=0)

    def test_do_api_call_retries_with_retryable_error(self):
        for exception in [
                requests_exceptions.ConnectionError,
                requests_exceptions.SSLError, requests_exceptions.Timeout,
                requests_exceptions.ConnectTimeout,
                requests_exceptions.HTTPError
        ]:
            with mock.patch('airflow.contrib.hooks.databricks_hook.requests'
                            ) as mock_requests:
                with mock.patch.object(self.hook.log, 'error') as mock_errors:
                    setup_mock_requests(mock_requests, exception)

                    with self.assertRaises(AirflowException):
                        self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

                    self.assertEqual(mock_errors.call_count,
                                     self.hook.retry_limit)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_do_api_call_does_not_retry_with_non_retryable_error(
            self, mock_requests):
        setup_mock_requests(mock_requests,
                            requests_exceptions.HTTPError,
                            status_code=400)

        with mock.patch.object(self.hook.log, 'error') as mock_errors:
            with self.assertRaises(AirflowException):
                self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

            mock_errors.assert_not_called()

    def test_do_api_call_succeeds_after_retrying(self):
        for exception in [
                requests_exceptions.ConnectionError,
                requests_exceptions.SSLError, requests_exceptions.Timeout,
                requests_exceptions.ConnectTimeout,
                requests_exceptions.HTTPError
        ]:
            with mock.patch('airflow.contrib.hooks.databricks_hook.requests'
                            ) as mock_requests:
                with mock.patch.object(self.hook.log, 'error') as mock_errors:
                    setup_mock_requests(mock_requests,
                                        exception,
                                        error_count=2,
                                        response_content={'run_id': '1'})

                    response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

                    self.assertEqual(mock_errors.call_count, 2)
                    self.assertEqual(response, {'run_id': '1'})

    @mock.patch('airflow.contrib.hooks.databricks_hook.sleep')
    def test_do_api_call_waits_between_retries(self, mock_sleep):
        retry_delay = 5
        self.hook = DatabricksHook(retry_delay=retry_delay)

        for exception in [
                requests_exceptions.ConnectionError,
                requests_exceptions.SSLError, requests_exceptions.Timeout,
                requests_exceptions.ConnectTimeout,
                requests_exceptions.HTTPError
        ]:
            with mock.patch('airflow.contrib.hooks.databricks_hook.requests'
                            ) as mock_requests:
                with mock.patch.object(self.hook.log, 'error'):
                    mock_sleep.reset_mock()
                    setup_mock_requests(mock_requests, exception)

                    with self.assertRaises(AirflowException):
                        self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

                    self.assertEqual(len(mock_sleep.mock_calls),
                                     self.hook.retry_limit - 1)
                    mock_sleep.assert_called_with(retry_delay)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_submit_run(self, mock_requests):
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        json = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER}
        run_id = self.hook.submit_run(json)

        self.assertEqual(run_id, '1')
        mock_requests.post.assert_called_once_with(
            submit_run_endpoint(HOST),
            json={
                'notebook_task': NOTEBOOK_TASK,
                'new_cluster': NEW_CLUSTER,
            },
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_run_now(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock
        json = {
            'notebook_params': NOTEBOOK_PARAMS,
            'jar_params': JAR_PARAMS,
            'job_id': JOB_ID
        }
        run_id = self.hook.run_now(json)

        self.assertEqual(run_id, '1')

        mock_requests.post.assert_called_once_with(
            run_now_endpoint(HOST),
            json={
                'notebook_params': NOTEBOOK_PARAMS,
                'jar_params': JAR_PARAMS,
                'job_id': JOB_ID
            },
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_get_run_page_url(self, mock_requests):
        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE

        run_page_url = self.hook.get_run_page_url(RUN_ID)

        self.assertEqual(run_page_url, RUN_PAGE_URL)
        mock_requests.get.assert_called_once_with(
            get_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_get_run_state(self, mock_requests):
        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE

        run_state = self.hook.get_run_state(RUN_ID)

        self.assertEqual(
            run_state, RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE))
        mock_requests.get.assert_called_once_with(
            get_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_cancel_run(self, mock_requests):
        mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE

        self.hook.cancel_run(RUN_ID)

        mock_requests.post.assert_called_once_with(
            cancel_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_start_cluster(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        self.hook.start_cluster({"cluster_id": CLUSTER_ID})

        mock_requests.post.assert_called_once_with(
            start_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_restart_cluster(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        self.hook.restart_cluster({"cluster_id": CLUSTER_ID})

        mock_requests.post.assert_called_once_with(
            restart_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_terminate_cluster(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        self.hook.terminate_cluster({"cluster_id": CLUSTER_ID})

        mock_requests.post.assert_called_once_with(
            terminate_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)
Example 13
class DatabricksHookTest(unittest.TestCase):
    """
    Tests for DatabricksHook.
    """
    @db.provide_session
    def setUp(self, session=None):
        conn = session.query(Connection) \
            .filter(Connection.conn_id == DEFAULT_CONN_ID) \
            .first()
        conn.host = HOST
        conn.login = LOGIN
        conn.password = PASSWORD
        session.commit()

        self.hook = DatabricksHook()

    def test_parse_host_with_proper_host(self):
        host = self.hook._parse_host(HOST)
        self.assertEqual(host, HOST)

    def test_parse_host_with_scheme(self):
        host = self.hook._parse_host(HOST_WITH_SCHEME)
        self.assertEqual(host, HOST)

    def test_init_bad_retry_limit(self):
        with self.assertRaises(AssertionError):
            DatabricksHook(retry_limit=0)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_do_api_call_with_error_retry(self, mock_requests):
        for exception in [
                requests_exceptions.ConnectionError,
                requests_exceptions.Timeout
        ]:
            with mock.patch.object(self.hook.logger, 'error') as mock_errors:
                mock_requests.reset_mock()
                mock_requests.post.side_effect = exception()

                with self.assertRaises(AirflowException):
                    self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

                self.assertEqual(len(mock_errors.mock_calls),
                                 self.hook.retry_limit)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_do_api_call_with_bad_status_code(self, mock_requests):
        mock_requests.codes.ok = 200
        status_code_mock = mock.PropertyMock(return_value=500)
        type(mock_requests.post.return_value).status_code = status_code_mock
        with self.assertRaises(AirflowException):
            self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_submit_run(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock
        json = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER}
        run_id = self.hook.submit_run(json)

        self.assertEqual(run_id, '1')
        mock_requests.post.assert_called_once_with(
            submit_run_endpoint(HOST),
            json={
                'notebook_task': NOTEBOOK_TASK,
                'new_cluster': NEW_CLUSTER,
            },
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_get_run_page_url(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.get.return_value).status_code = status_code_mock

        run_page_url = self.hook.get_run_page_url(RUN_ID)

        self.assertEqual(run_page_url, RUN_PAGE_URL)
        mock_requests.get.assert_called_once_with(
            get_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_get_run_state(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.get.return_value).status_code = status_code_mock

        run_state = self.hook.get_run_state(RUN_ID)

        self.assertEqual(
            run_state, RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE))
        mock_requests.get.assert_called_once_with(
            get_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_cancel_run(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        self.hook.cancel_run(RUN_ID)

        mock_requests.post.assert_called_once_with(
            cancel_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)
Example 14
class DatabricksHookTest(unittest.TestCase):
    """
    Tests for DatabricksHook.
    """

    @db.provide_session
    def setUp(self, session=None):
        conn = session.query(Connection) \
            .filter(Connection.conn_id == DEFAULT_CONN_ID) \
            .first()
        conn.host = HOST
        conn.login = LOGIN
        conn.password = PASSWORD
        conn.extra = None
        session.commit()

        self.hook = DatabricksHook(retry_delay=0)

    def test_parse_host_with_proper_host(self):
        host = self.hook._parse_host(HOST)
        self.assertEqual(host, HOST)

    def test_parse_host_with_scheme(self):
        host = self.hook._parse_host(HOST_WITH_SCHEME)
        self.assertEqual(host, HOST)

    def test_init_bad_retry_limit(self):
        with self.assertRaises(ValueError):
            DatabricksHook(retry_limit=0)

    def test_do_api_call_retries_with_retryable_error(self):
        for exception in [requests_exceptions.ConnectionError,
                          requests_exceptions.SSLError,
                          requests_exceptions.Timeout,
                          requests_exceptions.ConnectTimeout,
                          requests_exceptions.HTTPError]:
            with mock.patch('airflow.contrib.hooks.databricks_hook.requests') as mock_requests:
                with mock.patch.object(self.hook.log, 'error') as mock_errors:
                    setup_mock_requests(mock_requests, exception)

                    with self.assertRaises(AirflowException):
                        self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

                    self.assertEqual(mock_errors.call_count, self.hook.retry_limit)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_requests):
        setup_mock_requests(
            mock_requests, requests_exceptions.HTTPError, status_code=400
        )

        with mock.patch.object(self.hook.log, 'error') as mock_errors:
            with self.assertRaises(AirflowException):
                self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

            mock_errors.assert_not_called()

    def test_do_api_call_succeeds_after_retrying(self):
        for exception in [requests_exceptions.ConnectionError,
                          requests_exceptions.SSLError,
                          requests_exceptions.Timeout,
                          requests_exceptions.ConnectTimeout,
                          requests_exceptions.HTTPError]:
            with mock.patch('airflow.contrib.hooks.databricks_hook.requests') as mock_requests:
                with mock.patch.object(self.hook.log, 'error') as mock_errors:
                    setup_mock_requests(
                        mock_requests,
                        exception,
                        error_count=2,
                        response_content={'run_id': '1'}
                    )

                    response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

                    self.assertEqual(mock_errors.call_count, 2)
                    self.assertEqual(response, {'run_id': '1'})

    @mock.patch('airflow.contrib.hooks.databricks_hook.sleep')
    def test_do_api_call_waits_between_retries(self, mock_sleep):
        retry_delay = 5
        self.hook = DatabricksHook(retry_delay=retry_delay)

        for exception in [requests_exceptions.ConnectionError,
                          requests_exceptions.SSLError,
                          requests_exceptions.Timeout,
                          requests_exceptions.ConnectTimeout,
                          requests_exceptions.HTTPError]:
            with mock.patch('airflow.contrib.hooks.databricks_hook.requests') as mock_requests:
                with mock.patch.object(self.hook.log, 'error'):
                    mock_sleep.reset_mock()
                    setup_mock_requests(mock_requests, exception)

                    with self.assertRaises(AirflowException):
                        self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

                    self.assertEqual(len(mock_sleep.mock_calls), self.hook.retry_limit - 1)
                    mock_sleep.assert_called_with(retry_delay)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_submit_run(self, mock_requests):
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        json = {
            'notebook_task': NOTEBOOK_TASK,
            'new_cluster': NEW_CLUSTER
        }
        run_id = self.hook.submit_run(json)

        self.assertEqual(run_id, '1')
        mock_requests.post.assert_called_once_with(
            submit_run_endpoint(HOST),
            json={
                'notebook_task': NOTEBOOK_TASK,
                'new_cluster': NEW_CLUSTER,
            },
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_run_now(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock
        json = {
            'notebook_params': NOTEBOOK_PARAMS,
            'jar_params': JAR_PARAMS,
            'job_id': JOB_ID
        }
        run_id = self.hook.run_now(json)

        self.assertEqual(run_id, '1')

        mock_requests.post.assert_called_once_with(
            run_now_endpoint(HOST),
            json={
                'notebook_params': NOTEBOOK_PARAMS,
                'jar_params': JAR_PARAMS,
                'job_id': JOB_ID
            },
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_get_run_page_url(self, mock_requests):
        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE

        run_page_url = self.hook.get_run_page_url(RUN_ID)

        self.assertEqual(run_page_url, RUN_PAGE_URL)
        mock_requests.get.assert_called_once_with(
            get_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_get_run_state(self, mock_requests):
        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE

        run_state = self.hook.get_run_state(RUN_ID)

        self.assertEqual(run_state, RunState(
            LIFE_CYCLE_STATE,
            RESULT_STATE,
            STATE_MESSAGE))
        mock_requests.get.assert_called_once_with(
            get_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_cancel_run(self, mock_requests):
        mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE

        self.hook.cancel_run(RUN_ID)

        mock_requests.post.assert_called_once_with(
            cancel_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_start_cluster(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        self.hook.start_cluster({"cluster_id": CLUSTER_ID})

        mock_requests.post.assert_called_once_with(
            start_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_restart_cluster(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        self.hook.restart_cluster({"cluster_id": CLUSTER_ID})

        mock_requests.post.assert_called_once_with(
            restart_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_terminate_cluster(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        self.hook.terminate_cluster({"cluster_id": CLUSTER_ID})

        mock_requests.post.assert_called_once_with(
            terminate_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)
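Taken together, the retry tests in Example 14 pin down the hook's retry contract: retry_limit total attempts, one logged error per failed attempt, a sleep(retry_delay) between attempts (hence retry_limit - 1 sleeps), and an AirflowException once the attempts are exhausted. A sketch of a loop with exactly that behavior, not the hook's actual code:

import logging
from time import sleep

from airflow.exceptions import AirflowException

log = logging.getLogger(__name__)


def do_api_call_sketch(call, retry_limit, retry_delay):
    # Sketch of the behavior the tests assert; `call` stands in for the
    # actual HTTP request. The real hook retries only specific
    # requests exceptions and retryable HTTP status codes.
    for attempt in range(1, retry_limit + 1):
        try:
            return call()
        except Exception as e:
            log.error('API request attempt %d failed: %s', attempt, e)
            if attempt == retry_limit:
                raise AirflowException(
                    'API requests to Databricks failed %d times. Giving up.'
                    % retry_limit)
            sleep(retry_delay)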
Example 15
class DatabricksCtrl(SparkCtrl):
    def __init__(self, task_run):
        super(DatabricksCtrl, self).__init__(task_run=task_run)
        self.databricks_config = task_run.task.spark_engine  # type: DatabricksConfig

        self.local_dbfs_mount = None
        if self.databricks_config.cloud_type == DatabricksCloud.azure:
            assert_plugin_enabled(
                "dbnd-azure",
                "Databricks on azure requires dbnd-azure module.")

            self.local_dbfs_mount = DatabricksAzureConfig().local_dbfs_mount

        self.current_run_id = None
        self.hook = None

    def _dbfs_scheme_to_local(self, path):
        if self.databricks_config.cloud_type == DatabricksCloud.aws:
            return path
        elif self.databricks_config.cloud_type == DatabricksCloud.azure:
            return path.replace("dbfs://", "/dbfs")

    def _dbfs_scheme_to_mount(self, path):
        if self.databricks_config.cloud_type != DatabricksCloud.azure:
            return path

        from dbnd_azure.fs.azure_blob import AzureBlobStorageClient

        (
            storage_account,
            container_name,
            blob_name,
        ) = AzureBlobStorageClient._path_to_account_container_and_blob(path)
        return "dbfs://%s" % (os.path.join(self.local_dbfs_mount, blob_name))

    def _handle_databricks_operator_execution(self, run_id, hook, task_id):
        """
        Handles the Airflow + Databricks lifecycle logic for a Databricks operator
        :param run_id: Databricks run_id
        :param hook: Airflow databricks hook
        :param task_id: Databand Task Id.

        """
        b = TextBanner("Spark task %s is submitted to Databricks cluster:" %
                       task_id,
                       color="cyan")
        url = hook.get_run_page_url(run_id)
        self.task_run.set_external_resource_urls({"databricks url": url})
        b.column("URL", url)
        logger.info(b.get_banner_str())
        while True:
            b = TextBanner(
                "Spark task %s is submitted to Databricks cluster:" % task_id,
                color="cyan",
            )
            b.column("URL", url)
            run_state = hook.get_run_state(run_id)
            if run_state.is_terminal:
                if run_state.is_successful:
                    b.column("Task completed successfully", task_id)
                    b.column("State:", run_state.life_cycle_state)
                    b.column("Message:", run_state.state_message)
                    break
                else:
                    b.column("State", run_state.result_state)
                    b.column("Error Message:", run_state.state_message)
                    logger.info(b.get_banner_str())
                    raise failed_to_run_databricks_job(run_state.result_state,
                                                       run_state.state_message,
                                                       url)
            else:
                b.column("State:", run_state.life_cycle_state)
                b.column("Message:", run_state.state_message)
                time.sleep(
                    self.databricks_config.status_polling_interval_seconds)
            logger.info(b.get_banner_str())

    def _create_spark_submit_json(self, spark_submit_parameters):
        spark_config = self.task.spark_config
        new_cluster = {
            "num_workers":
            self.databricks_config.num_workers,
            "spark_version":
            self.databricks_config.spark_version,
            # 'spark_conf' is not supported here; the cluster spec uses 'conf' instead
            "conf":
            self.databricks_config.spark_conf,
            "node_type_id":
            self.databricks_config.node_type_id,
            "init_scripts":
            self.databricks_config.init_scripts,
            "cluster_log_conf":
            self.databricks_config.cluster_log_conf,
            "spark_env_vars":
            self._get_env_vars(self.databricks_config.spark_env_vars),
            "py_files":
            self.deploy.arg_files(self.task.get_py_files()),
            "files":
            self.deploy.arg_files(spark_config.files),
            "jars":
            self.deploy.arg_files(spark_config.jars),
        }
        if self.databricks_config.cloud_type == DatabricksCloud.aws:
            attributes = DatabricksAwsConfig()
            new_cluster["aws_attributes"] = {
                "instance_profile_arn": attributes.aws_instance_profile_arn,
                "ebs_volume_type": attributes.aws_ebs_volume_type,
                "ebs_volume_count": attributes.aws_ebs_volume_count,
                "ebs_volume_size": attributes.aws_ebs_volume_size,
            }
        else:
            # TODO: check whether there are any relevant settings for Azure or other Databricks environments.
            pass

        # the Airflow connector does not support spark_submit_task for now, so it is implemented this way.
        return {
            "spark_submit_task": {
                "parameters": spark_submit_parameters
            },
            "new_cluster": new_cluster,
            "run_name": self.task.task_id,
        }

    def _create_pyspark_submit_json(self, python_file, parameters):
        spark_python_task_json = {
            "python_file": python_file,
            "parameters": parameters
        }
        # the Airflow connector does not support spark_submit_task for now, so it is implemented this way.
        return {
            "spark_python_task": spark_python_task_json,
            "existing_cluster_id": self.databricks_config.cluster_id,
            "run_name": self.task.task_id,
        }

    def _run_spark_submit(self, databricks_json):
        task = self.task  # type: SparkTask
        _config = task.spark_engine

        from airflow.contrib.hooks.databricks_hook import DatabricksHook

        self.hook = DatabricksHook(
            _config.conn_id,
            _config.connection_retry_limit,
            retry_delay=_config.connection_retry_delay,
        )
        try:
            logging.debug("posted JSON:" + str(databricks_json))
            self.current_run_id = self.hook.submit_run(databricks_json)
            self.hook.log.setLevel(logging.WARNING)
            self._handle_databricks_operator_execution(self.current_run_id,
                                                       self.hook,
                                                       _config.task_id)
            self.hook.log.setLevel(logging.INFO)
        except AirflowException as e:
            raise failed_to_submit_databricks_job(e)

    def run_pyspark(self, pyspark_script):
        # should be reimplemented using SparkSubmitHook (maybe from airflow)
        # note that config jars are not supported.
        if not self.databricks_config.cluster_id:
            spark_submit_parameters = [self.sync(pyspark_script)] + (
                list_of_strings(self.task.application_args()))
            databricks_json = self._create_spark_submit_json(
                spark_submit_parameters)
        else:
            pyspark_script = self.sync(pyspark_script)
            parameters = [
                self._dbfs_scheme_to_local(e)
                for e in list_of_strings(self.task.application_args())
            ]
            databricks_json = self._create_pyspark_submit_json(
                python_file=pyspark_script, parameters=parameters)

        return self._run_spark_submit(databricks_json)

    def run_spark(self, main_class):
        jars_list = []
        jars = self.config.jars
        if jars:
            jars_list = ["--jars"] + jars
        # should be reimplemented using SparkSubmitHook (maybe from airflow)
        spark_submit_parameters = [
            "--class",
            main_class,
            self.sync(self.config.main_jar),
        ] + (list_of_strings(self.task.application_args()) + jars_list)
        databricks_json = self._create_spark_submit_json(
            spark_submit_parameters)
        return self._run_spark_submit(databricks_json)

    def _report_step_status(self, step):
        logger.info(self._get_step_banner(step))

    def _get_step_banner(self, step):
        """
        {
          'id': 6,
          'state': 'success',
        }
        """
        t = self.task
        b = TextBanner("Spark Task %s is running at Emr:" % t.task_id,
                       color="yellow")

        b.column("TASK", t.task_id)
        b.column("EMR STEP STATE", step["Step"]["Status"]["State"])

        tracker_url = current_task_run().task_tracker_url
        if tracker_url:
            b.column("DATABAND LOG", tracker_url)

        b.new_line()
        b.column("EMR STEP ID", step["Step"]["Id"])
        b.new_section()
        return b.getvalue()

    def sync(self, local_file):
        synced = self.deploy.sync(local_file)
        if self.databricks_config.cloud_type == DatabricksCloud.azure:
            return self._dbfs_scheme_to_mount(synced)
        return synced

    def on_kill(self):
        if self.hook and self.current_run_id:
            self.hook.cancel_run(self.current_run_id)
Example 16
class DatabricksHookTest(unittest.TestCase):
    """
    Tests for DatabricksHook.
    """
    @db.provide_session
    def setUp(self, session=None):
        conn = session.query(Connection) \
            .filter(Connection.conn_id == DEFAULT_CONN_ID) \
            .first()
        conn.host = HOST
        conn.login = LOGIN
        conn.password = PASSWORD
        session.commit()

        self.hook = DatabricksHook()

    def test_parse_host_with_proper_host(self):
        host = self.hook._parse_host(HOST)
        self.assertEqual(host, HOST)

    def test_parse_host_with_scheme(self):
        host = self.hook._parse_host(HOST_WITH_SCHEME)
        self.assertEqual(host, HOST)

    def test_init_bad_retry_limit(self):
        with self.assertRaises(AssertionError):
            DatabricksHook(retry_limit=0)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_do_api_call_with_error_retry(self, mock_requests):
        for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]:
            with mock.patch.object(self.hook.log, 'error') as mock_errors:
                mock_requests.reset_mock()
                mock_requests.post.side_effect = exception()

                with self.assertRaises(AirflowException):
                    self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

                self.assertEqual(len(mock_errors.mock_calls), self.hook.retry_limit)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_do_api_call_with_bad_status_code(self, mock_requests):
        mock_requests.codes.ok = 200
        status_code_mock = mock.PropertyMock(return_value=500)
        type(mock_requests.post.return_value).status_code = status_code_mock
        with self.assertRaises(AirflowException):
            self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_submit_run(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock
        json = {
            'notebook_task': NOTEBOOK_TASK,
            'new_cluster': NEW_CLUSTER
        }
        run_id = self.hook.submit_run(json)

        self.assertEqual(run_id, '1')
        mock_requests.post.assert_called_once_with(
            submit_run_endpoint(HOST),
            json={
                'notebook_task': NOTEBOOK_TASK,
                'new_cluster': NEW_CLUSTER,
            },
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_get_run_page_url(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.get.return_value).status_code = status_code_mock

        run_page_url = self.hook.get_run_page_url(RUN_ID)

        self.assertEqual(run_page_url, RUN_PAGE_URL)
        mock_requests.get.assert_called_once_with(
            get_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_get_run_state(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.get.return_value).status_code = status_code_mock

        run_state = self.hook.get_run_state(RUN_ID)

        self.assertEqual(run_state, RunState(
            LIFE_CYCLE_STATE,
            RESULT_STATE,
            STATE_MESSAGE))
        mock_requests.get.assert_called_once_with(
            get_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_cancel_run(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        self.hook.cancel_run(RUN_ID)

        mock_requests.post.assert_called_once_with(
            cancel_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)