def test_run_query_botocore_error(self) -> None:
    """A ParamValidationError from start_query_execution surfaces as RunQueryException.

    The raised exception's first argument must contain the botocore validation report.
    """
    report = 'Unknown parameter in QueryExecutionContext: "banana", must be one of: Database, Catalog'
    failing_start = Mock(side_effect=ParamValidationError(report=report))
    boto_athena = Mock(start_query_execution=failing_start)
    with self.assertRaises(exception.RunQueryException) as raised:
        AwsAthenaAsyncClient(boto_athena).run_query("some query")
    self.assertIn(report, raised.exception.args[0])
def test_has_query_succeeded(self) -> None:
    """has_query_succeeded delegates to _is_query_state_in with states.SUCCESS_STATES."""
    query_id = "9876-6543"
    state_check_target = "src.clients.aws_athena_async_client.AwsAthenaAsyncClient._is_query_state_in"
    with patch(state_check_target) as state_check:
        AwsAthenaAsyncClient(Mock()).has_query_succeeded(query_id)
    state_check.assert_called_once_with(query_id, states.SUCCESS_STATES)
def test_has_query_completed(self) -> None:
    """has_query_completed delegates to _is_query_state_in with states.COMPLETED_STATES."""
    query_id = "4321-8765"
    state_check_target = "src.clients.aws_athena_async_client.AwsAthenaAsyncClient._is_query_state_in"
    with patch(state_check_target) as state_check:
        AwsAthenaAsyncClient(Mock()).has_query_completed(query_id)
    state_check.assert_called_once_with(query_id, states.COMPLETED_STATES)
def test_query_state_is_not_in_completed_states(self) -> None:
    """Queued and running queries are not considered completed.

    Each pending state gets a fresh boto mock, so every iteration can assert a single lookup.
    """
    query_id = "6da1f9be-772e-4a19-a542-f9f13e037707"
    pending_states = [states.QUERY_QUEUED, states.QUERY_RUNNING]
    for pending in pending_states:
        execution = {"QueryExecution": {"Status": {"State": pending}}}
        boto_athena = Mock(get_query_execution=Mock(return_value=execution))
        completed = AwsAthenaAsyncClient(boto_athena)._is_query_state_in(query_id, states.COMPLETED_STATES)
        self.assertFalse(completed, f"query should not be considered completed when status is {pending}")
        boto_athena.get_query_execution.assert_called_once_with(QueryExecutionId=query_id)
def test_query_state_is_in_completed_states(self) -> None:
    """Succeeded, failed and cancelled queries are all considered completed.

    Each terminal state gets a fresh boto mock, so every iteration can assert a single lookup.
    """
    query_id = "48068afb-edde-4e9c-bcef-6bfa29987b1a"
    terminal_states = [states.QUERY_SUCCEEDED, states.QUERY_FAILED, states.QUERY_CANCELLED]
    for terminal in terminal_states:
        execution = {"QueryExecution": {"Status": {"State": terminal}}}
        boto_athena = Mock(get_query_execution=Mock(return_value=execution))
        completed = AwsAthenaAsyncClient(boto_athena)._is_query_state_in(query_id, states.COMPLETED_STATES)
        self.assertTrue(completed, f"query should be considered completed when status is {terminal}")
        boto_athena.get_query_execution.assert_called_once_with(QueryExecutionId=query_id)
def test_run_query_success(self) -> None:
    """run_query without a database starts the query in AwsDataCatalog and returns its execution id."""
    query = "CREATE DATABASE something"
    start_execution = Mock(return_value={"QueryExecutionId": "1234"})
    client = AwsAthenaAsyncClient(Mock(start_query_execution=start_execution))
    self.assertEqual("1234", client.run_query(query))
    start_execution.assert_called_once_with(
        QueryString=query,
        QueryExecutionContext={"Catalog": "AwsDataCatalog"},
        ResultConfiguration={"OutputLocation": "s3://query-results-bucket"},
    )
def test_get_query_results_has_results(self) -> None:
    """get_query_results returns the data rows expected for the canned GET_EVENT_USAGE_COUNT response."""
    query_id = "48068afb-edde-4e9c-bcef-6bfa29987b1a"
    fetch_results = Mock(return_value=queries_results.GET_EVENT_USAGE_COUNT_RESULTS)
    expected_rows = [
        {"Data": [{"VarCharValue": "GetParameter"}, {"VarCharValue": "274"}]},
        {"Data": [{"VarCharValue": "DescribeInstanceInformation"}, {"VarCharValue": "1"}]},
        {"Data": [{"VarCharValue": "GetParameters"}, {"VarCharValue": "570"}]},
        {"Data": [{"VarCharValue": "ListAssociations"}, {"VarCharValue": "1"}]},
    ]
    actual_rows = AwsAthenaAsyncClient(Mock(get_query_results=fetch_results)).get_query_results(query_id)
    self.assertEqual(expected_rows, actual_rows)
    fetch_results.assert_called_once_with(QueryExecutionId=query_id)
def test_run_query_in_db_success(self) -> None:
    """run_query with a database adds it to the execution context alongside the catalog."""
    query = "SELECT something FROM somewhere WHERE other_thing = some_value"
    database = "some_database"
    start_execution = Mock(return_value={"QueryExecutionId": "1234"})
    client = AwsAthenaAsyncClient(Mock(start_query_execution=start_execution))
    self.assertEqual("1234", client.run_query(query, database))
    start_execution.assert_called_once_with(
        QueryString=query,
        QueryExecutionContext={"Catalog": "AwsDataCatalog", "Database": database},
        ResultConfiguration={"OutputLocation": "s3://query-results-bucket"},
    )
def test_run_query_client_error(self) -> None:
    """A botocore ClientError from start_query_execution surfaces as RunQueryException with its message."""
    error_message = "MultiFactorAuthentication failed with invalid MFA one time pass code."
    access_denied = ClientError(
        operation_name="AssumeRole",
        error_response={"Error": {"Code": "AccessDenied", "Message": error_message}},
    )
    boto_athena = Mock(start_query_execution=Mock(side_effect=access_denied))
    with self.assertRaises(exception.RunQueryException) as raised:
        AwsAthenaAsyncClient(boto_athena).run_query("some query")
    self.assertIn(error_message, raised.exception.args[0])
def test_query_state_is_unknown(self) -> None:
    """A ClientError while fetching execution state surfaces as UnknownQueryStateException."""
    query_id = "6da1f9be-772e-4a19-a542-f9f13e037707"
    error_message = "some client error"
    access_denied = ClientError(
        operation_name="AssumeRole",
        error_response={"Error": {"Code": "AccessDenied", "Message": error_message}},
    )
    boto_athena = Mock(get_query_execution=Mock(side_effect=access_denied))
    with self.assertRaises(exception.UnknownQueryStateException) as raised:
        AwsAthenaAsyncClient(boto_athena)._is_query_state_in(query_id, states.COMPLETED_STATES)
    self.assertIn(error_message, raised.exception.args[0])
def assert_success_query_run(
    test: AwsScannerTestCase,
    method_under_test: str,
    method_args: Dict[str, str],
    query: str,
    raise_on_failure: Type[Exception],
) -> None:
    """Check that a query-building method forwards its query to run_query and returns its id.

    run_query is patched to return a fixed id. The helper asserts that:
    - the method under test returns that id;
    - run_query was called exactly once with the expected query and raise_on_failure;
    - run_query also received the database, but only when method_args contains "database".
    """
    fake_query_id = "1234-5678-9012"
    run_query_target = "src.clients.aws_athena_async_client.AwsAthenaAsyncClient.run_query"
    with patch(run_query_target, return_value=fake_query_id) as mock_run_query:
        bound_method = getattr(AwsAthenaAsyncClient(Mock()), method_under_test)
        test.assertEqual(fake_query_id, bound_method(**method_args))
    expected_kwargs = {"query": query, "raise_on_failure": raise_on_failure}
    if "database" in method_args:
        expected_kwargs["database"] = method_args["database"]
    mock_run_query.assert_called_once_with(**expected_kwargs)
def test_get_query_results_failure(self) -> None:
    """A ClientError from get_query_results surfaces as GetQueryResultsException with its message."""
    query_id = "6da1f9be-772e-4a19-a542-f9f13e037707"
    error_message = "Query has not yet finished. Current state: QUEUED"
    not_finished = ClientError(
        operation_name="GetQueryResults",
        error_response={"Error": {"Code": "InvalidRequestException", "Message": error_message}},
    )
    boto_athena = Mock(get_query_results=Mock(side_effect=not_finished))
    with self.assertRaises(exception.GetQueryResultsException) as raised:
        AwsAthenaAsyncClient(boto_athena).get_query_results(query_id)
    self.assertIn(error_message, raised.exception.args[0])
def get_client(self) -> AwsAthenaAsyncClient:
    """Build a client whose boto list_table_metadata calls are answered by self.list_table_metadata."""
    boto_athena = Mock(list_table_metadata=Mock(side_effect=self.list_table_metadata))
    return AwsAthenaAsyncClient(boto_athena)
def test_get_query_error(self) -> None:
    """get_query_error extracts the failure reason from the DROP_DATABASE_EXECUTION_FAILURE fixture."""
    query_id = "5789-3472-6589"
    expected_error = "FAILED: SemanticException [Error 10072]: Database does not exist: 1234"
    get_execution = Mock(return_value=queries_results.DROP_DATABASE_EXECUTION_FAILURE)
    client = AwsAthenaAsyncClient(Mock(get_query_execution=get_execution))
    self.assertEqual(expected_error, client.get_query_error(query_id))
    get_execution.assert_called_once_with(QueryExecutionId=query_id)
def test_get_query_results_has_empty_results(self) -> None:
    """get_query_results returns an empty list for the GET_EVENT_USAGE_COUNT_EMPTY_RESULTS fixture."""
    query_id = "48068afb-edde-4e9c-bcef-6bfa29987b1a"
    fetch_results = Mock(return_value=queries_results.GET_EVENT_USAGE_COUNT_EMPTY_RESULTS)
    rows = AwsAthenaAsyncClient(Mock(get_query_results=fetch_results)).get_query_results(query_id)
    self.assertEqual([], rows)
    fetch_results.assert_called_once_with(QueryExecutionId=query_id)
def get_client(self, catalog: str) -> AwsAthenaAsyncClient:
    """Build a client bound to `catalog` whose boto list_databases calls are answered by self.list_databases."""
    list_databases = Mock(side_effect=self.list_databases)
    athena_client = AwsAthenaAsyncClient(Mock(list_databases=list_databases))
    athena_client._catalog = catalog
    return athena_client
def __init__(self, boto_athena: BaseClient):
    """Wrap the given boto Athena client in an AwsAthenaAsyncClient stored as self._athena_async."""
    async_client = AwsAthenaAsyncClient(boto_athena)
    self._athena_async = async_client
def assert_failure_query_run(
    test: AwsScannerTestCase, method_under_test: str, method_args: Dict[str, str], raise_on_failure: Type[Exception]
) -> None:
    """Check that a query-building method raises `raise_on_failure` when starting the query fails.

    boto's start_query_execution is made to raise ParamValidationError.
    """
    invalid_params = ParamValidationError(report="boom")
    client = AwsAthenaAsyncClient(Mock(start_query_execution=Mock(side_effect=invalid_params)))
    with test.assertRaises(raise_on_failure):
        getattr(client, method_under_test)(**method_args)
class AwsAthenaClient:
    """Blocking facade over AwsAthenaAsyncClient.

    Each operation starts an Athena query through the async client, polls until the
    query completes (within a per-operation timeout), and raises an operation-specific
    exception if the query did not succeed.
    """

    def __init__(self, boto_athena: BaseClient):
        self._athena_async = AwsAthenaAsyncClient(boto_athena)

    def create_database(self, database_name: str) -> None:
        """Create `database_name` and wait up to 60 polls; raises CreateDatabaseException on failure."""
        self._wait_for_success(
            query_id=self._athena_async.create_database(database_name=database_name),
            timeout_seconds=60,
            raise_on_failure=exceptions.CreateDatabaseException,
        )

    def drop_database(self, database_name: str) -> None:
        """Drop `database_name` and wait up to 60 polls; raises DropDatabaseException on failure."""
        self._wait_for_success(
            query_id=self._athena_async.drop_database(database_name=database_name),
            timeout_seconds=60,
            raise_on_failure=exceptions.DropDatabaseException,
        )

    def create_table(self, database: str, account: Account) -> None:
        """Create the table for `account` in `database`; raises CreateTableException on failure."""
        self._wait_for_success(
            query_id=self._athena_async.create_table(database=database, account=account),
            timeout_seconds=60,
            raise_on_failure=exceptions.CreateTableException,
        )

    def drop_table(self, database: str, table: str) -> None:
        """Drop `table` from `database`; raises DropTableException on failure."""
        self._wait_for_success(
            query_id=self._athena_async.drop_table(database=database, table=table),
            timeout_seconds=60,
            raise_on_failure=exceptions.DropTableException,
        )

    def add_partition(self, database: str, account: Account, partition: AwsAthenaDataPartition) -> None:
        """Add `partition` to the `account` table in `database` (120 polls); raises AddPartitionException on failure."""
        self._wait_for_success(
            query_id=self._athena_async.add_partition(database=database, account=account, partition=partition),
            timeout_seconds=120,
            raise_on_failure=exceptions.AddPartitionException,
        )

    def list_databases(self) -> List[str]:
        """Return the database names reported by the async client."""
        return self._athena_async.list_databases()

    def list_tables(self, database: str) -> List[str]:
        """Return the table names in `database` reported by the async client."""
        return self._athena_async.list_tables(database)

    def run_query(self, database: str, query: str) -> List[Any]:
        """Run `query` against `database`, wait up to 300 polls, and return its result rows.

        Raises:
            exceptions.RunQueryException: the query completed without succeeding.
            exceptions.TimeoutException: the query did not complete in time.
        """
        return self._wait_for_success(
            query_id=self._athena_async.run_query(query=query, database=database),
            timeout_seconds=300,
            raise_on_failure=exceptions.RunQueryException,
        )

    def _wait_for_completion(self, query_id: str, timeout_seconds: int) -> None:
        """Poll until the query reaches a completed state.

        Polls `timeout_seconds` times, sleeping `_get_default_delay()` after each
        unsuccessful poll. It then polls once more, so a query that completes during
        the last sleep is not reported as a timeout.

        Raises:
            exceptions.TimeoutException: the query still had not completed at the final poll.
        """
        for _ in range(timeout_seconds):
            if self._athena_async.has_query_completed(query_id):
                return
            sleep(self._get_default_delay())
        # Without this final poll the last sleep would be wasted, and a query that
        # finished during it would be misreported as timed out.
        if self._athena_async.has_query_completed(query_id):
            return
        raise exceptions.TimeoutException(f"query execution id: {query_id}")

    def _wait_for_success(self, query_id: str, timeout_seconds: int, raise_on_failure: Type[Exception]) -> List[Any]:
        """Wait for completion and return the query results, or raise `raise_on_failure` with the query error."""
        self._wait_for_completion(query_id, timeout_seconds)
        if self._athena_async.has_query_succeeded(query_id):
            return self._athena_async.get_query_results(query_id)
        raise raise_on_failure(self._athena_async.get_query_error(query_id))

    @staticmethod
    def _get_default_delay() -> int:
        """Seconds to sleep between completion polls."""
        return 1