예제 #1
0
 def get_hook(self):
     """Build a new AWSAthenaHook from this object's connection id and sleep interval."""
     conn_id = self.aws_conn_id
     return AWSAthenaHook(conn_id, sleep_time=self.sleep_time)
예제 #2
0
 def setUp(self):
     """Attach an AWSAthenaHook with a zero sleep interval as ``self.athena``."""
     hook = AWSAthenaHook(sleep_time=0)
     self.athena = hook
예제 #3
0
class TestAWSAthenaHook(unittest.TestCase):
    """Unit tests for AWSAthenaHook; ``get_conn`` is patched so the client is a mock."""

    def setUp(self):
        # sleep_time=0 — presumably so any polling waits in the hook are instantaneous.
        self.athena = AWSAthenaHook(sleep_time=0)

    @staticmethod
    def _client_with_execution(mock_conn, query_execution):
        """Stub the mocked client's ``get_query_execution`` result and return the client."""
        client = mock_conn.return_value
        client.get_query_execution.return_value = query_execution
        return client

    def test_init(self):
        hook = self.athena
        self.assertEqual(hook.aws_conn_id, 'aws_default')
        self.assertEqual(hook.sleep_time, 0)

    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_query_without_token(self, mock_conn):
        client = mock_conn.return_value
        client.start_query_execution.return_value = MOCK_QUERY_EXECUTION

        execution_id = self.athena.run_query(
            query=MOCK_DATA['query'],
            query_context=mock_query_context,
            result_configuration=mock_result_configuration,
        )

        client.start_query_execution.assert_called_with(
            QueryString=MOCK_DATA['query'],
            QueryExecutionContext=mock_query_context,
            ResultConfiguration=mock_result_configuration,
            WorkGroup=MOCK_DATA['workgroup'],
        )
        self.assertEqual(execution_id, MOCK_DATA['query_execution_id'])

    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_query_with_token(self, mock_conn):
        client = mock_conn.return_value
        client.start_query_execution.return_value = MOCK_QUERY_EXECUTION

        execution_id = self.athena.run_query(
            query=MOCK_DATA['query'],
            query_context=mock_query_context,
            result_configuration=mock_result_configuration,
            client_request_token=MOCK_DATA['client_request_token'],
        )

        client.start_query_execution.assert_called_with(
            QueryString=MOCK_DATA['query'],
            QueryExecutionContext=mock_query_context,
            ResultConfiguration=mock_result_configuration,
            ClientRequestToken=MOCK_DATA['client_request_token'],
            WorkGroup=MOCK_DATA['workgroup'],
        )
        self.assertEqual(execution_id, MOCK_DATA['query_execution_id'])

    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_get_query_results_with_non_succeeded_query(self, mock_conn):
        self._client_with_execution(mock_conn, MOCK_RUNNING_QUERY_EXECUTION)
        self.assertIsNone(
            self.athena.get_query_results(query_execution_id=MOCK_DATA['query_execution_id'])
        )

    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_get_query_results_with_default_params(self, mock_conn):
        client = self._client_with_execution(mock_conn, MOCK_SUCCEEDED_QUERY_EXECUTION)
        self.athena.get_query_results(query_execution_id=MOCK_DATA['query_execution_id'])
        client.get_query_results.assert_called_with(
            QueryExecutionId=MOCK_DATA['query_execution_id'],
            MaxResults=1000,
        )

    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_get_query_results_with_next_token(self, mock_conn):
        client = self._client_with_execution(mock_conn, MOCK_SUCCEEDED_QUERY_EXECUTION)
        self.athena.get_query_results(
            query_execution_id=MOCK_DATA['query_execution_id'],
            next_token_id=MOCK_DATA['next_token_id'],
        )
        client.get_query_results.assert_called_with(
            QueryExecutionId=MOCK_DATA['query_execution_id'],
            NextToken=MOCK_DATA['next_token_id'],
            MaxResults=1000,
        )

    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_get_paginator_with_non_succeeded_query(self, mock_conn):
        self._client_with_execution(mock_conn, MOCK_RUNNING_QUERY_EXECUTION)
        self.assertIsNone(
            self.athena.get_query_results_paginator(query_execution_id=MOCK_DATA['query_execution_id'])
        )

    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_get_paginator_with_default_params(self, mock_conn):
        client = self._client_with_execution(mock_conn, MOCK_SUCCEEDED_QUERY_EXECUTION)
        self.athena.get_query_results_paginator(query_execution_id=MOCK_DATA['query_execution_id'])
        client.get_paginator.return_value.paginate.assert_called_with(
            QueryExecutionId=MOCK_DATA['query_execution_id'],
            PaginationConfig={'MaxItems': None, 'PageSize': None, 'StartingToken': None},
        )

    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_get_paginator_with_pagination_config(self, mock_conn):
        client = self._client_with_execution(mock_conn, MOCK_SUCCEEDED_QUERY_EXECUTION)
        item_limit = MOCK_DATA['max_items']
        resume_token = MOCK_DATA['next_token_id']

        self.athena.get_query_results_paginator(
            query_execution_id=MOCK_DATA['query_execution_id'],
            max_items=item_limit,
            page_size=item_limit,
            starting_token=resume_token,
        )

        client.get_paginator.return_value.paginate.assert_called_with(
            QueryExecutionId=MOCK_DATA['query_execution_id'],
            PaginationConfig={
                'MaxItems': item_limit,
                'PageSize': item_limit,
                'StartingToken': resume_token,
            },
        )

    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_poll_query_when_final(self, mock_conn):
        client = self._client_with_execution(mock_conn, MOCK_SUCCEEDED_QUERY_EXECUTION)
        state = self.athena.poll_query_status(query_execution_id=MOCK_DATA['query_execution_id'])
        client.get_query_execution.assert_called_once()
        self.assertEqual(state, 'SUCCEEDED')

    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_poll_query_with_timeout(self, mock_conn):
        client = self._client_with_execution(mock_conn, MOCK_RUNNING_QUERY_EXECUTION)
        # max_tries=1: a still-running query should be checked exactly once.
        state = self.athena.poll_query_status(
            query_execution_id=MOCK_DATA['query_execution_id'], max_tries=1
        )
        client.get_query_execution.assert_called_once()
        self.assertEqual(state, 'RUNNING')
 def get_hook(self):
     """Create and return an AWSAthenaHook for ``self.aws_conn_id``.

     ``sleep_time`` is passed by keyword, like every other AWSAthenaHook
     construction in this file. NOTE(review): in newer Airflow releases the
     constructor is ``__init__(self, *args, sleep_time=30, **kwargs)``, so a
     positional second argument would not set ``sleep_time`` — confirm
     against the installed provider version.
     """
     return AWSAthenaHook(self.aws_conn_id, sleep_time=self.sleep_time)
예제 #5
0
파일: athena.py 프로젝트: zjkanjie/airflow
 def hook(self) -> AWSAthenaHook:
     """Construct an AWSAthenaHook from ``aws_conn_id`` and ``sleep_time`` and return it."""
     athena_hook = AWSAthenaHook(self.aws_conn_id, sleep_time=self.sleep_time)
     return athena_hook
예제 #6
0
 def get_hook(self):
     """Return the AWSAthenaHook cached on ``self.hook``, creating it on first use.

     NOTE(review): assumes the owning class initialises ``self.hook`` to
     ``None`` before the first call — confirm.
     """
     # Compare against the None sentinel rather than truthiness, so an
     # already-built hook is never rebuilt just because it evaluates falsy.
     if self.hook is None:
         # sleep_time by keyword, consistent with the other AWSAthenaHook
         # constructions; a positional second argument is not guaranteed to
         # map to sleep_time across hook versions.
         self.hook = AWSAthenaHook(self.aws_conn_id, sleep_time=self.sleep_time)
     return self.hook