Exemple #1
0
    def test_run_query_botocore_error(self) -> None:
        """A ParamValidationError from start_query_execution surfaces as RunQueryException.

        The validation report text must be present in the raised exception's first argument.
        """
        validation_report = 'Unknown parameter in QueryExecutionContext: "banana", must be one of: Database, Catalog'
        failing_start = Mock(side_effect=ParamValidationError(report=validation_report))
        athena_stub = Mock(start_query_execution=failing_start)

        with self.assertRaises(exception.RunQueryException) as raised:
            client = AwsAthenaAsyncClient(athena_stub)
            client.run_query("some query")

        self.assertIn(validation_report, raised.exception.args[0])
Exemple #2
0
 def test_has_query_succeeded(self) -> None:
     """has_query_succeeded delegates to _is_query_state_in with states.SUCCESS_STATES."""
     execution_id = "9876-6543"
     patch_target = "src.clients.aws_athena_async_client.AwsAthenaAsyncClient._is_query_state_in"
     with patch(patch_target) as state_check:
         AwsAthenaAsyncClient(Mock()).has_query_succeeded(execution_id)
     # The patched method keeps its call record after the patch context exits.
     state_check.assert_called_once_with(execution_id, states.SUCCESS_STATES)
Exemple #3
0
 def test_has_query_completed(self) -> None:
     """has_query_completed delegates to _is_query_state_in with states.COMPLETED_STATES."""
     execution_id = "4321-8765"
     patch_target = "src.clients.aws_athena_async_client.AwsAthenaAsyncClient._is_query_state_in"
     with patch(patch_target) as state_check:
         AwsAthenaAsyncClient(Mock()).has_query_completed(execution_id)
     # The patched method keeps its call record after the patch context exits.
     state_check.assert_called_once_with(execution_id, states.COMPLETED_STATES)
Exemple #4
0
 def test_query_state_is_not_in_completed_states(self) -> None:
     """Queued and running queries are not reported as being in COMPLETED_STATES.

     A fresh Athena stub is built per state, and each stub must be asked exactly once
     for the execution of the given query id.
     """
     execution_id = "6da1f9be-772e-4a19-a542-f9f13e037707"
     pending_states = [states.QUERY_QUEUED, states.QUERY_RUNNING]
     for current in pending_states:
         execution_payload = {"QueryExecution": {"Status": {"State": current}}}
         athena_stub = Mock(get_query_execution=Mock(return_value=execution_payload))
         client = AwsAthenaAsyncClient(athena_stub)
         in_completed = client._is_query_state_in(execution_id, states.COMPLETED_STATES)
         self.assertFalse(in_completed, f"query should not be considered completed when status is {current}")
         athena_stub.get_query_execution.assert_called_once_with(QueryExecutionId=execution_id)
Exemple #5
0
 def test_query_state_is_in_completed_states(self) -> None:
     """Succeeded, failed and cancelled queries are reported as being in COMPLETED_STATES.

     A fresh Athena stub is built per state, and each stub must be asked exactly once
     for the execution of the given query id.
     """
     execution_id = "48068afb-edde-4e9c-bcef-6bfa29987b1a"
     terminal_states = [states.QUERY_SUCCEEDED, states.QUERY_FAILED, states.QUERY_CANCELLED]
     for current in terminal_states:
         execution_payload = {"QueryExecution": {"Status": {"State": current}}}
         athena_stub = Mock(get_query_execution=Mock(return_value=execution_payload))
         client = AwsAthenaAsyncClient(athena_stub)
         in_completed = client._is_query_state_in(execution_id, states.COMPLETED_STATES)
         self.assertTrue(in_completed, f"query should be considered completed when status is {current}")
         athena_stub.get_query_execution.assert_called_once_with(QueryExecutionId=execution_id)
Exemple #6
0
    def test_run_query_success(self) -> None:
        """run_query without a database returns the started query id.

        The query must be started in the "AwsDataCatalog" catalog with results written to
        "s3://query-results-bucket", and no database in the execution context.
        """
        sql = "CREATE DATABASE something"
        start_stub = Mock(return_value={"QueryExecutionId": "1234"})
        athena_stub = Mock(start_query_execution=start_stub)

        returned_id = AwsAthenaAsyncClient(athena_stub).run_query(sql)

        self.assertEqual("1234", returned_id)
        expected_kwargs = {
            "QueryString": sql,
            "QueryExecutionContext": {"Catalog": "AwsDataCatalog"},
            "ResultConfiguration": {"OutputLocation": "s3://query-results-bucket"},
        }
        athena_stub.start_query_execution.assert_called_once_with(**expected_kwargs)
Exemple #7
0
 def test_get_query_results_has_results(self) -> None:
     """get_query_results returns the expected row list for GET_EVENT_USAGE_COUNT_RESULTS."""
     execution_id = "48068afb-edde-4e9c-bcef-6bfa29987b1a"
     results_stub = Mock(return_value=queries_results.GET_EVENT_USAGE_COUNT_RESULTS)
     athena_stub = Mock(get_query_results=results_stub)
     expected_rows = [
         {"Data": [{"VarCharValue": "GetParameter"}, {"VarCharValue": "274"}]},
         {"Data": [{"VarCharValue": "DescribeInstanceInformation"}, {"VarCharValue": "1"}]},
         {"Data": [{"VarCharValue": "GetParameters"}, {"VarCharValue": "570"}]},
         {"Data": [{"VarCharValue": "ListAssociations"}, {"VarCharValue": "1"}]},
     ]
     actual_rows = AwsAthenaAsyncClient(athena_stub).get_query_results(execution_id)
     self.assertEqual(expected_rows, actual_rows)
     athena_stub.get_query_results.assert_called_once_with(QueryExecutionId=execution_id)
Exemple #8
0
    def test_run_query_in_db_success(self) -> None:
        """run_query with a database returns the started query id.

        The database must be added next to the "AwsDataCatalog" catalog in the execution
        context, with results written to "s3://query-results-bucket".
        """
        sql = "SELECT something FROM somewhere WHERE other_thing = some_value"
        target_db = "some_database"
        athena_stub = Mock(start_query_execution=Mock(return_value={"QueryExecutionId": "1234"}))

        returned_id = AwsAthenaAsyncClient(athena_stub).run_query(sql, target_db)

        self.assertEqual("1234", returned_id)
        expected_context = {"Catalog": "AwsDataCatalog", "Database": target_db}
        athena_stub.start_query_execution.assert_called_once_with(
            QueryString=sql,
            QueryExecutionContext=expected_context,
            ResultConfiguration={"OutputLocation": "s3://query-results-bucket"},
        )
Exemple #9
0
 def test_run_query_client_error(self) -> None:
     """A ClientError raised by start_query_execution surfaces as RunQueryException.

     The AWS error message must be present in the raised exception's first argument.
     """
     error_message = "MultiFactorAuthentication failed with invalid MFA one time pass code."
     aws_error = {"Error": {"Code": "AccessDenied", "Message": error_message}}
     failure = ClientError(operation_name="AssumeRole", error_response=aws_error)
     athena_stub = Mock(start_query_execution=Mock(side_effect=failure))
     with self.assertRaises(exception.RunQueryException) as raised:
         client = AwsAthenaAsyncClient(athena_stub)
         client.run_query("some query")
     self.assertIn(error_message, raised.exception.args[0])
Exemple #10
0
 def test_query_state_is_unknown(self) -> None:
     """A ClientError from get_query_execution surfaces as UnknownQueryStateException.

     The AWS error message must be present in the raised exception's first argument.
     """
     execution_id = "6da1f9be-772e-4a19-a542-f9f13e037707"
     error_message = "some client error"
     failure = ClientError(
         operation_name="AssumeRole",
         error_response={"Error": {"Code": "AccessDenied", "Message": error_message}},
     )
     athena_stub = Mock(get_query_execution=Mock(side_effect=failure))
     with self.assertRaises(exception.UnknownQueryStateException) as raised:
         client = AwsAthenaAsyncClient(athena_stub)
         client._is_query_state_in(execution_id, states.COMPLETED_STATES)
     self.assertIn(error_message, raised.exception.args[0])
Exemple #11
0
def assert_success_query_run(
    test: AwsScannerTestCase,
    method_under_test: str,
    method_args: Dict[str, str],
    query: str,
    raise_on_failure: Type[Exception],
) -> None:
    """Assert that ``method_under_test`` delegates to ``run_query`` and returns its query id.

    ``AwsAthenaAsyncClient.run_query`` is patched to return a fixed id, and the method named
    ``method_under_test`` is invoked with ``method_args`` as keyword arguments. The patched
    ``run_query`` must have been called once with ``query`` and ``raise_on_failure``. It must
    also receive ``database`` exactly when ``method_args`` contains a "database" entry.
    """
    fake_query_id = "1234-5678-9012"
    with patch(
        "src.clients.aws_athena_async_client.AwsAthenaAsyncClient.run_query", return_value=fake_query_id
    ) as run_query_spy:
        bound_method = getattr(AwsAthenaAsyncClient(Mock()), method_under_test)
        result = bound_method(**method_args)

    test.assertEqual(fake_query_id, result)
    expected_call = {"query": query, "raise_on_failure": raise_on_failure}
    if "database" in method_args:
        expected_call["database"] = method_args["database"]
    run_query_spy.assert_called_once_with(**expected_call)
Exemple #12
0
 def test_get_query_results_failure(self) -> None:
     """A ClientError from get_query_results surfaces as GetQueryResultsException.

     The AWS error message must be present in the raised exception's first argument.
     """
     execution_id = "6da1f9be-772e-4a19-a542-f9f13e037707"
     error_message = "Query has not yet finished. Current state: QUEUED"
     aws_error = {"Error": {"Code": "InvalidRequestException", "Message": error_message}}
     failure = ClientError(operation_name="GetQueryResults", error_response=aws_error)
     athena_stub = Mock(get_query_results=Mock(side_effect=failure))
     with self.assertRaises(exception.GetQueryResultsException) as raised:
         client = AwsAthenaAsyncClient(athena_stub)
         client.get_query_results(execution_id)
     self.assertIn(error_message, raised.exception.args[0])
Exemple #13
0
 def get_client(self) -> AwsAthenaAsyncClient:
     """Build a client whose Athena stub answers list_table_metadata via self.list_table_metadata."""
     athena_stub = Mock(list_table_metadata=Mock(side_effect=self.list_table_metadata))
     client = AwsAthenaAsyncClient(athena_stub)
     return client
Exemple #14
0
 def test_get_query_error(self) -> None:
     """get_query_error returns the failure text for DROP_DATABASE_EXECUTION_FAILURE."""
     execution_id = "5789-3472-6589"
     expected_reason = "FAILED: SemanticException [Error 10072]: Database does not exist: 1234"
     execution_stub = Mock(return_value=queries_results.DROP_DATABASE_EXECUTION_FAILURE)
     athena_stub = Mock(get_query_execution=execution_stub)
     actual_reason = AwsAthenaAsyncClient(athena_stub).get_query_error(execution_id)
     self.assertEqual(expected_reason, actual_reason)
     athena_stub.get_query_execution.assert_called_once_with(QueryExecutionId=execution_id)
Exemple #15
0
 def test_get_query_results_has_empty_results(self) -> None:
     """get_query_results returns [] for GET_EVENT_USAGE_COUNT_EMPTY_RESULTS."""
     execution_id = "48068afb-edde-4e9c-bcef-6bfa29987b1a"
     empty_fixture = queries_results.GET_EVENT_USAGE_COUNT_EMPTY_RESULTS
     athena_stub = Mock(get_query_results=Mock(return_value=empty_fixture))
     rows = AwsAthenaAsyncClient(athena_stub).get_query_results(execution_id)
     self.assertEqual([], rows)
     athena_stub.get_query_results.assert_called_once_with(QueryExecutionId=execution_id)
Exemple #16
0
 def get_client(self, catalog: str) -> AwsAthenaAsyncClient:
     """Build a client bound to ``catalog`` whose Athena stub answers list_databases via self.list_databases."""
     athena_stub = Mock(list_databases=Mock(side_effect=self.list_databases))
     async_client = AwsAthenaAsyncClient(athena_stub)
     # Overrides the client's private catalog attribute directly for this test.
     async_client._catalog = catalog
     return async_client
 def __init__(self, boto_athena: BaseClient):
     """Store an AwsAthenaAsyncClient wrapping ``boto_athena`` on ``self._athena_async``."""
     async_client = AwsAthenaAsyncClient(boto_athena)
     self._athena_async = async_client
Exemple #18
0
def assert_failure_query_run(
    test: AwsScannerTestCase, method_under_test: str, method_args: Dict[str, str], raise_on_failure: Type[Exception]
) -> None:
    """Assert that ``method_under_test`` raises ``raise_on_failure`` when starting the query fails.

    The Athena stub's start_query_execution raises a botocore ParamValidationError. Invoking
    the named method with ``method_args`` as keyword arguments must then raise
    ``raise_on_failure``.
    """
    validation_failure = ParamValidationError(report="boom")
    athena_stub = Mock(start_query_execution=Mock(side_effect=validation_failure))
    with test.assertRaises(raise_on_failure):
        client = AwsAthenaAsyncClient(athena_stub)
        getattr(client, method_under_test)(**method_args)
class AwsAthenaClient:
    """Blocking wrapper around AwsAthenaAsyncClient.

    Each operation submits a query through the wrapped async client. It then polls
    until the query completes. A query that completes without succeeding raises the
    operation-specific exception, built from the query's error message.
    """

    def __init__(self, boto_athena: BaseClient):
        self._athena_async = AwsAthenaAsyncClient(boto_athena)

    def create_database(self, database_name: str) -> None:
        """Create ``database_name`` (timeout_seconds=60).

        Raises:
            exceptions.CreateDatabaseException: if the query completes unsuccessfully.
        """
        started_id = self._athena_async.create_database(database_name=database_name)
        self._wait_for_success(
            query_id=started_id,
            timeout_seconds=60,
            raise_on_failure=exceptions.CreateDatabaseException,
        )

    def drop_database(self, database_name: str) -> None:
        """Drop ``database_name`` (timeout_seconds=60).

        Raises:
            exceptions.DropDatabaseException: if the query completes unsuccessfully.
        """
        started_id = self._athena_async.drop_database(database_name=database_name)
        self._wait_for_success(
            query_id=started_id,
            timeout_seconds=60,
            raise_on_failure=exceptions.DropDatabaseException,
        )

    def create_table(self, database: str, account: Account) -> None:
        """Create the table for ``account`` in ``database`` (timeout_seconds=60).

        Raises:
            exceptions.CreateTableException: if the query completes unsuccessfully.
        """
        started_id = self._athena_async.create_table(database=database, account=account)
        self._wait_for_success(
            query_id=started_id,
            timeout_seconds=60,
            raise_on_failure=exceptions.CreateTableException,
        )

    def drop_table(self, database: str, table: str) -> None:
        """Drop ``table`` from ``database`` (timeout_seconds=60).

        Raises:
            exceptions.DropTableException: if the query completes unsuccessfully.
        """
        started_id = self._athena_async.drop_table(database=database, table=table)
        self._wait_for_success(
            query_id=started_id,
            timeout_seconds=60,
            raise_on_failure=exceptions.DropTableException,
        )

    def add_partition(self, database: str, account: Account,
                      partition: AwsAthenaDataPartition) -> None:
        """Add ``partition`` for ``account`` in ``database`` (timeout_seconds=120).

        Raises:
            exceptions.AddPartitionException: if the query completes unsuccessfully.
        """
        started_id = self._athena_async.add_partition(
            database=database, account=account, partition=partition
        )
        self._wait_for_success(
            query_id=started_id,
            timeout_seconds=120,
            raise_on_failure=exceptions.AddPartitionException,
        )

    def list_databases(self) -> List[str]:
        """Return the async client's database listing unchanged."""
        return self._athena_async.list_databases()

    def list_tables(self, database: str) -> List[str]:
        """Return the async client's table listing for ``database`` unchanged."""
        return self._athena_async.list_tables(database)

    def run_query(self, database: str, query: str) -> List[Any]:
        """Run ``query`` in ``database`` and return its results (timeout_seconds=300).

        Raises:
            exceptions.RunQueryException: if the query completes unsuccessfully.
        """
        started_id = self._athena_async.run_query(query=query, database=database)
        return self._wait_for_success(
            query_id=started_id,
            timeout_seconds=300,
            raise_on_failure=exceptions.RunQueryException,
        )

    def _wait_for_completion(self, query_id: str,
                             timeout_seconds: int) -> None:
        """Poll ``query_id`` until it completes, or raise TimeoutException.

        The query is checked at most ``timeout_seconds`` times. After every check that
        finds it still running, the loop sleeps ``_get_default_delay()`` seconds. The
        budget is therefore a number of polls: it only approximates seconds because
        the default delay is 1.
        """
        for _attempt in range(timeout_seconds):
            finished = self._athena_async.has_query_completed(query_id)
            if finished:
                return
            sleep(self._get_default_delay())
        raise exceptions.TimeoutException(f"query execution id: {query_id}")

    def _wait_for_success(self, query_id: str, timeout_seconds: int,
                          raise_on_failure: Type[Exception]) -> List[Any]:
        """Wait for ``query_id`` to complete and return its results.

        Raises:
            exceptions.TimeoutException: if the query does not complete within the poll budget.
            raise_on_failure: built from the query's error message when the query
                completed without succeeding.
        """
        self._wait_for_completion(query_id, timeout_seconds)
        if not self._athena_async.has_query_succeeded(query_id):
            raise raise_on_failure(self._athena_async.get_query_error(query_id))
        return self._athena_async.get_query_results(query_id)

    @staticmethod
    def _get_default_delay() -> int:
        """Seconds to sleep between completion checks."""
        delay_seconds = 1
        return delay_seconds