Example #1
0
    def test_create_athena_table_tasks(self) -> None:
        """create_athena_table_tasks yields one AwsCreateAthenaTableTask per account.

        Tasks are expected in the order find_account_by_ids returns the accounts.
        """
        mock_orgs = Mock(find_account_by_ids=Mock(return_value=[account("8", "eight"), account("3", "three")]))
        # Take a single snapshot of today's date. Two separate date.today() calls can
        # straddle midnight on 31 December and produce a mismatched (year, month) pair.
        today = date.today()
        year, month = today.year, today.month
        tasks = AwsTaskBuilder(mock_orgs, ["8", "3"]).create_athena_table_tasks(year, month)

        self.assert_cloudtrail_tasks_equal(AwsCreateAthenaTableTask(account("8", "eight"), partition()), tasks[0])
        self.assert_cloudtrail_tasks_equal(AwsCreateAthenaTableTask(account("3", "three"), partition()), tasks[1])
Example #2
0
    def test_principal_by_ip_finder_tasks(self) -> None:
        """principal_by_ip_finder_tasks yields one AwsPrincipalByIPFinderTask per target account."""
        mock_orgs = Mock(get_target_accounts=Mock(return_value=[account("1", "one"), account("2", "two")]))
        # Snapshot today's date once so year and month always come from the same day.
        # Two date.today() calls could straddle midnight on 31 December.
        today = date.today()
        year, month, source_ip = today.year, today.month, "127.0.0.1"
        tasks = AwsTaskBuilder(mock_orgs).principal_by_ip_finder_tasks(year, month, source_ip)

        self.assert_ip_tasks_equal(AwsPrincipalByIPFinderTask(account("1", "one"), partition(), "127.0.0.1"), tasks[0])
        self.assert_ip_tasks_equal(AwsPrincipalByIPFinderTask(account("2", "two"), partition(), "127.0.0.1"), tasks[1])
    def test_get_client(self) -> None:
        """_get_client builds a boto3 client from the credentials that _assume_role returns."""
        expected_client = Mock()
        boto_stub = Mock(client=Mock(return_value=expected_client))
        assume_role_stub = Mock(return_value=AwsCredentials("access", "secret", "session"))

        # Same three patches as nested `with` blocks, entered in the same order.
        with patch("src.clients.aws_client_factory.AwsClientFactory._get_session_token"), patch(
            "src.clients.aws_client_factory.AwsClientFactory._assume_role", assume_role_stub
        ), patch("src.clients.aws_client_factory.boto3", boto_stub):
            factory = AwsClientFactory(self.mfa, self.username)
            client = factory._get_client(self.service_name, account(), self.role)

        self.assertEqual(expected_client, client)
        assume_role_stub.assert_called_once_with(account(), self.role)
        boto_stub.client.assert_called_once_with(
            service_name="some_service",
            aws_access_key_id="access",
            aws_secret_access_key="secret",
            aws_session_token="session",
        )
Example #4
0
    def test_role_usage_scanner_tasks(self) -> None:
        """role_usage_scanner_tasks yields one AwsRoleUsageScannerTask per target account."""
        mock_orgs = Mock(get_target_accounts=Mock(return_value=[account("1", "one"), account("2", "two")]))
        # Snapshot today's date once so year and month always come from the same day.
        # Two date.today() calls could straddle midnight on 31 December.
        today = date.today()
        year, month, role = today.year, today.month, "SomeRole"
        tasks = AwsTaskBuilder(mock_orgs).role_usage_scanner_tasks(year, month, role)

        self.assert_role_tasks_equal(AwsRoleUsageScannerTask(account("1", "one"), partition(), "SomeRole"), tasks[0])
        self.assert_role_tasks_equal(AwsRoleUsageScannerTask(account("2", "two"), partition(), "SomeRole"), tasks[1])
 def test_setup(self) -> None:
     """_setup creates the database, then the account table, then the partition."""
     athena = Mock()
     task = cloudtrail_task()
     task._setup(athena)
     # assert_has_calls (any_order=False) requires these calls consecutively, in this order.
     expected_calls = [
         call.create_database(task._database),
         call.create_table(task._database, account()),
         call.add_partition(task._database, account(), partition()),
     ]
     athena.assert_has_calls(expected_calls)
 def test_get_ssm_client(self, _: Mock) -> None:
     """get_ssm_client wraps the SSM boto client obtained for the requested account."""
     boto_ssm = Mock()

     def fake_get_ssm_boto_client(acc):
         # Only hand back the stub when asked for the expected account.
         return boto_ssm if acc == account() else None

     with patch(f"{self.factory_path}.get_ssm_boto_client", side_effect=fake_get_ssm_boto_client):
         factory = AwsClientFactory(self.mfa, self.username)
         wrapped = factory.get_ssm_client(account())
         self.assertEqual(wrapped._ssm, boto_ssm)
    def test_assume_role(self) -> None:
        """_assume_role calls STS AssumeRole using session-token credentials.

        It then returns the credentials found in the STS response.
        """
        sts_response = {
            "Credentials": {
                "AccessKeyId": "some_access_key",
                "SecretAccessKey": "some_secret_access_key",
                "SessionToken": "some_session_token",
            }
        }
        sts_stub = Mock(assume_role=Mock(return_value=sts_response))
        boto_stub = Mock(client=Mock(return_value=sts_stub))
        session_creds = AwsCredentials("session_access_key", "session_secret_key", "session_token")

        with patch(
            "src.clients.aws_client_factory.AwsClientFactory._get_session_token",
            return_value=session_creds,
        ), patch("src.clients.aws_client_factory.boto3", boto_stub):
            factory = AwsClientFactory(self.mfa, self.username)
            creds = factory._assume_role(account(), self.role)

        self.assertEqual(creds, AwsCredentials("some_access_key", "some_secret_access_key", "some_session_token"))
        # The STS client must be built from the session-token credentials...
        boto_stub.client.assert_called_once_with(
            service_name="sts",
            aws_access_key_id="session_access_key",
            aws_secret_access_key="session_secret_key",
            aws_session_token="session_token",
        )
        # ...and asked to assume the expected role for one hour.
        sts_stub.assume_role.assert_called_once_with(
            DurationSeconds=3600,
            RoleArn="arn:aws:iam::account_id:role/some_role",
            RoleSessionName="boto3_assuming_some_role",
        )
    def test_run_tasks(self) -> None:
        """A failing task is logged at ERROR level and left out of the reports.

        The two succeeding tasks must each yield a report. The task that raises
        AwsScannerException yields none and produces exactly one error log line.
        """
        first_ok = athena_task(description="some task")
        first_ok._run_task = lambda _: {"outcome_1": "success_1"}

        second_ok = athena_task(description="other task")
        second_ok._run_task = lambda _: {"outcome_2": "success_2"}

        broken = athena_task(account=account("5678", "wrong account"), description="boom")
        broken._run_task = lambda _: _raise(AwsScannerException("oops"))

        with self.assertLogs("AwsParallelTaskRunner", level="ERROR") as captured:
            reports = AwsParallelTaskRunner(Mock()).run([first_ok, broken, second_ok])

        self.assertEqual(2, len(reports), "there should only be two task reports")
        for description, results in (
            ("some task", {"outcome_1": "success_1"}),
            ("other task", {"outcome_2": "success_2"}),
        ):
            self.assertIn(task_report(description=description, results=results, partition=None), reports)

        self.assertEqual(
            ["ERROR:AwsParallelTaskRunner:task 'boom' for 'wrong account (5678)' failed with: 'oops'"],
            captured.output,
        )
 def test_get_organizations_boto_client(self) -> None:
     """get_organizations_boto_client targets the "root" account using "orgs_role"."""
     root_account = account(identifier="999888777666", name="root")
     self.assert_get_client(
         method_under_test="get_organizations_boto_client",
         service="organizations",
         target_account=root_account,
         role="orgs_role",
     )
Example #10
0
 def test_create_table(self) -> None:
     """Check create_table via assert_query_run.

     Expects the query to be queries.CREATE_TABLE, with CreateTableException
     as the failure exception.
     """
     table_args = {
         "database": "some_database",
         "account": account("908173625490", "some_account"),
     }
     assert_query_run(
         test=self,
         method_under_test="create_table",
         method_args=table_args,
         query=queries.CREATE_TABLE,
         raise_on_failure=exception.CreateTableException,
     )
 def test_get_athena_boto_client(self) -> None:
     """get_athena_boto_client targets the "cloudtrail" account using "cloudtrail_role"."""
     cloudtrail_account = account(identifier="555666777888", name="cloudtrail")
     self.assert_get_client(
         method_under_test="get_athena_boto_client",
         service="athena",
         target_account=cloudtrail_account,
         role="cloudtrail_role",
     )
 def test_get_s3_boto_client(self) -> None:
     """get_s3_boto_client targets the account it is given, using "s3_role"."""
     target = account(identifier="122344566788", name="some_s3_account")
     self.assert_get_client(
         method_under_test="get_s3_boto_client",
         method_args=dict(account=target),
         service="s3",
         target_account=target,
         role="s3_role",
     )
 def test_teardown(self) -> None:
     """_teardown drops the account's table first, then the task's database."""
     athena = Mock()
     task = cloudtrail_task()
     task._teardown(athena)
     db = task._database
     athena.assert_has_calls([
         call.drop_table(db, account().identifier),
         call.drop_database(db),
     ])
 def test_run_task(self) -> None:
     """_run_task reports every bucket that the S3 client lists, under "buckets"."""
     names = ("bucket-1", "bucket-2", "another-bucket")
     all_buckets = [bucket(name) for name in names]
     client = Mock(list_buckets=Mock(return_value=all_buckets))
     report = AwsAuditS3Task(account())._run_task(client)
     self.assertEqual({"buckets": all_buckets}, report)
Example #15
0
 def test_find_account_by_id(self) -> None:
     """find_account_by_id converts the DescribeAccount response into an account."""
     account_id = "123456789012"

     def describe_account(AccountId):
         # Keyword name mirrors the boto3 call signature (describe_account(AccountId=...)).
         return responses.DESCRIBE_ACCOUNT if AccountId == account_id else None

     boto_orgs = Mock(describe_account=Mock(side_effect=describe_account))
     expected = account(account_id, "some test account")
     found = AwsOrganizationsClient(boto_orgs).find_account_by_id(account_id)
     self.assertEqual(expected, found)
Example #16
0
 def test_create_table(self, mock_wait_for_success: Mock) -> None:
     """Check create_table via assert_wait_for_success.

     Uses a 60-second timeout and CreateTableException as the failure exception.
     """
     table_args = dict(database="some_db", account=account())
     self.assert_wait_for_success(
         mock_wait_for_success=mock_wait_for_success,
         method_under_test="create_table",
         method_args=table_args,
         timeout_seconds=60,
         raise_on_failure=exceptions.CreateTableException,
     )
Example #17
0
    def test_aws_create_athena_table_task(self) -> None:
        """AwsCreateAthenaTableTask is an AwsCloudTrailTask whose _run_task only reports names.

        _run_task must return the table and database names without issuing any
        call on the Athena client.
        """
        task = AwsCreateAthenaTableTask(account(), partition())
        self.assertIsInstance(task, AwsCloudTrailTask)

        mock_athena = Mock()
        run_results = task._run_task(mock_athena)

        # Mock.assert_not_called() only checks calls to the mock object itself,
        # not to its methods (e.g. mock_athena.create_table(...)). mock_calls
        # records both, so an empty list proves Athena was never touched.
        self.assertEqual([], mock_athena.mock_calls)
        self.assertEqual("account_id", run_results["table"])
        self.assertTrue(run_results["database"].startswith("some_prefix_account_id_"))
Example #18
0
 def test_add_partition(self, mock_wait_for_success: Mock) -> None:
     """Check add_partition via assert_wait_for_success.

     Uses a 120-second timeout and AddPartitionException as the failure exception.
     """
     partition_args = dict(
         database="some_db",
         account=account(),
         partition=partition(2019, 8),
     )
     self.assert_wait_for_success(
         mock_wait_for_success=mock_wait_for_success,
         method_under_test="add_partition",
         method_args=partition_args,
         timeout_seconds=120,
         raise_on_failure=exceptions.AddPartitionException,
     )
Example #19
0
 def test_add_partition(self) -> None:
     """Check add_partition via assert_query_run.

     Expects queries.ADD_PARTITION_YEAR_MONTH, with AddPartitionException as the
     failure exception.
     """
     partition_args = dict(
         database="some_database",
         account=account("908173625490", "some_account"),
         partition=partition(2020, 7),
     )
     assert_query_run(
         test=self,
         method_under_test="add_partition",
         method_args=partition_args,
         query=queries.ADD_PARTITION_YEAR_MONTH,
         raise_on_failure=exception.AddPartitionException,
     )
Example #20
0
class TestAwsAthenaCleanerTask(AwsScannerTestCase):
    """Tests for AwsAthenaCleanerTask.

    Per the expectations below, only databases whose names start with
    "some_prefix_" are cleaned up. Their tables are dropped first, then the
    databases themselves.
    """

    # Database name -> table names, as reported by the mocked Athena client.
    database_mappings = {
        "db_1": ["table_1", "table_2", "table_3"],
        "some_prefix_db_2": ["table_1", "table_2"],
        "db_3": ["table_1"],
        "some_prefix_db_4": ["table_1", "table_2", "table_3"],
        "some_prefix_db_5": [],
    }
    # Report the cleaner should return for the mappings above.
    expected_report = task_report(
        account=account("555666777888", "cloudtrail"),
        description="clean scanner leftovers",
        partition=None,
        results={
            "dropped_tables": [
                "some_prefix_db_2.table_1",
                "some_prefix_db_2.table_2",
                "some_prefix_db_4.table_1",
                "some_prefix_db_4.table_2",
                "some_prefix_db_4.table_3",
            ],
            "dropped_databases": ["some_prefix_db_2", "some_prefix_db_4", "some_prefix_db_5"],
        },
    )

    def test_clean_task_databases(self) -> None:
        def tables_of(db):
            return self.database_mappings.get(db)

        athena = Mock(
            list_databases=Mock(return_value=list(self.database_mappings)),
            list_tables=Mock(side_effect=tables_of),
        )
        self.assertEqual(self.expected_report, AwsAthenaCleanerTask().run(athena))

        # All tables are dropped before any database is dropped.
        athena.assert_has_calls([
            call.list_databases(),
            call.list_tables("some_prefix_db_2"),
            call.drop_table("some_prefix_db_2", "table_1"),
            call.drop_table("some_prefix_db_2", "table_2"),
            call.list_tables("some_prefix_db_4"),
            call.drop_table("some_prefix_db_4", "table_1"),
            call.drop_table("some_prefix_db_4", "table_2"),
            call.drop_table("some_prefix_db_4", "table_3"),
            call.list_tables("some_prefix_db_5"),
            call.drop_database("some_prefix_db_2"),
            call.drop_database("some_prefix_db_4"),
            call.drop_database("some_prefix_db_5"),
        ])
 def test_run_task(self) -> None:
     """_run_task returns every SSM parameter along with per-type and total counts."""
     parameters = [
         secure_string_parameter("secure_1"),
         string_list_parameter("list"),
         string_parameter("string_1"),
         secure_string_parameter("secure_2"),
         string_parameter("string_2"),
         string_parameter("string_3"),
     ]
     ssm = Mock(list_parameters=Mock(return_value=parameters))
     report = AwsListSSMParametersTask(account())._run_task(ssm)
     expected = {
         "ssm_parameters": parameters,
         "type_count": {"SecureString": 2, "StringList": 1, "String": 3},
         "total_count": 6,
     }
     self.assertEqual(expected, report)
Example #22
0
 def test_find_account_by_ids(self) -> None:
     """find_account_by_ids skips unknown ids and keeps the request order for found accounts."""
     accounts_by_id = {
         "3": account("3", "account 3"),
         "8": account("8", "account 8"),
         "2": None,
         "5": account("5", "account 5"),
     }
     with patch(
         "src.clients.aws_organizations_client.AwsOrganizationsClient.find_account_by_id",
         side_effect=accounts_by_id.get,
     ):
         expected = [account("8", "account 8"), account("5", "account 5"), account("3", "account 3")]
         found = AwsOrganizationsClient(Mock()).find_account_by_ids(["8", "2", "5", "3"])
         self.assertEqual(expected, found)
Example #23
0
 def test_teardown_do_nothing(self) -> None:
     """AwsCreateAthenaTableTask._teardown must not issue any call on the Athena client."""
     task = AwsCreateAthenaTableTask(account(), partition())
     mock_athena = Mock()
     task._teardown(mock_athena)
     # assert_not_called() only checks calls to the mock itself, so a call such as
     # mock_athena.drop_table(...) would slip through. mock_calls also records
     # method calls, so an empty list proves teardown did nothing.
     self.assertEqual([], mock_athena.mock_calls)
 def __build_task_under_test(task_type, task_args):
     """Instantiate ``task_type`` for account() and partition(), pinned to database "some_db".

     ``task_args`` holds any extra keyword arguments the task constructor needs.
     A duplicate "account"/"partition" key raises TypeError, as a direct call would.
     """
     base_kwargs = {"account": account(), "partition": partition()}
     built = task_type(**base_kwargs, **task_args)
     built._database = "some_db"
     return built
Example #25
0
    def test_audit_s3_tasks(self) -> None:
        """audit_s3_tasks yields one AwsAuditS3Task per account.

        Tasks follow the order find_account_by_ids returns the accounts.
        """
        found_accounts = [account("3", "three"), account("5", "five")]
        mock_orgs = Mock(find_account_by_ids=Mock(return_value=found_accounts))
        tasks = AwsTaskBuilder(mock_orgs, ["5", "3"]).audit_s3_tasks()

        for index, (identifier, name) in enumerate((("3", "three"), ("5", "five"))):
            self.assert_tasks_equal(AwsAuditS3Task(account(identifier, name)), tasks[index])
 def test_run_s3_task(self) -> None:
     """_run_task gets the S3 client for account() from the factory and runs the task with it."""
     client = Mock()

     def s3_client_for(acc):
         return client if acc == account() else None

     client_factory = Mock(get_s3_client=Mock(side_effect=s3_client_for))
     task = s3_task()

     def fake_run(c):
         # Only produce a report when handed the client the factory returned.
         return task_report() if c == client else None

     task.run = Mock(side_effect=fake_run)  # type: ignore
     self.assertEqual(task_report(), AwsTaskRunner(client_factory)._run_task(task))
 def test_run_task(self) -> None:
     """_run_task reports, under "accounts", every account from get_all_accounts."""
     all_accounts = [account(), account()]
     orgs_client = Mock(get_all_accounts=Mock(return_value=all_accounts))
     expected = {"accounts": all_accounts}
     self.assertEqual(expected, AwsListAccountsTask()._run_task(orgs_client))
Example #28
0
    def test_list_ssm_parameters_tasks(self) -> None:
        """list_ssm_parameters_tasks yields one AwsListSSMParametersTask per account.

        Tasks follow the order find_account_by_ids returns the accounts.
        """
        found_accounts = [account("2", "two"), account("4", "four")]
        mock_orgs = Mock(find_account_by_ids=Mock(return_value=found_accounts))
        tasks = AwsTaskBuilder(mock_orgs, ["4", "2"]).list_ssm_parameters_tasks()

        for index, (identifier, name) in enumerate((("2", "two"), ("4", "four"))):
            self.assert_tasks_equal(AwsListSSMParametersTask(account(identifier, name)), tasks[index])
Example #29
0
 def test_get_account(self) -> None:
     """The task built by aws_task() exposes account() through its .account attribute."""
     task_account = aws_task().account
     self.assertEqual(task_account, account())
            "Name": "Root 1 > Org Unit 2 > Org Unit 2 > Account 1",
        },
        {
            "Id": "242243167582",
            "Name": "Root 1 > Org Unit 2 > Org Unit 2 > Account 2",
        },
    ]
}
# Mocked Organizations response whose "Accounts" list is empty.
EMPTY_ACCOUNTS = dict(Accounts=[])
EXPECTED_ORGANIZATION_TREE = [
    organizational_unit(
        identifier="r-root1",
        name="Root 1",
        root=True,
        accounts=[
            account(identifier="987654655432", name="Root 1 > Account 1")
        ],
        org_units=[
            organizational_unit(
                identifier="ou-root1-1",
                name="Root 1 > Org Unit 1",
                accounts=[
                    account(identifier="987651321565",
                            name="Root 1 > Org Unit 1 > Account 1"),
                    account(identifier="643758194672",
                            name="Root 1 > Org Unit 1 > Account 2"),
                    account(identifier="594678488453",
                            name="Root 1 > Org Unit 1 > Account 3"),
                ],
                org_units=[],
            ),