def test_create_athena_table_tasks(self) -> None:
    """Builder creates one AwsCreateAthenaTableTask per account, ordered as the organizations lookup returns them."""
    mock_orgs = Mock(find_account_by_ids=Mock(return_value=[account("8", "eight"), account("3", "three")]))
    # Read the clock once: two separate date.today() calls can straddle midnight at a
    # month/year boundary and produce an inconsistent (year, month) pair.
    today = date.today()
    year, month = today.year, today.month
    tasks = AwsTaskBuilder(mock_orgs, ["8", "3"]).create_athena_table_tasks(year, month)
    self.assert_cloudtrail_tasks_equal(AwsCreateAthenaTableTask(account("8", "eight"), partition()), tasks[0])
    self.assert_cloudtrail_tasks_equal(AwsCreateAthenaTableTask(account("3", "three"), partition()), tasks[1])
def test_principal_by_ip_finder_tasks(self) -> None:
    """Builder creates one AwsPrincipalByIPFinderTask per target account, all searching the same source IP."""
    mock_orgs = Mock(get_target_accounts=Mock(return_value=[account("1", "one"), account("2", "two")]))
    # Capture the date once so year and month always describe the same day.
    today = date.today()
    year, month, source_ip = today.year, today.month, "127.0.0.1"
    tasks = AwsTaskBuilder(mock_orgs).principal_by_ip_finder_tasks(year, month, source_ip)
    self.assert_ip_tasks_equal(AwsPrincipalByIPFinderTask(account("1", "one"), partition(), "127.0.0.1"), tasks[0])
    self.assert_ip_tasks_equal(AwsPrincipalByIPFinderTask(account("2", "two"), partition(), "127.0.0.1"), tasks[1])
def test_get_client(self) -> None:
    """_get_client assumes the role in the target account and builds a boto3 client from the returned credentials."""
    expected_client = Mock()
    fake_boto3 = Mock(client=Mock(return_value=expected_client))
    fake_assume_role = Mock(return_value=AwsCredentials("access", "secret", "session"))

    # Patch session-token retrieval, role assumption and boto3 so the factory runs entirely against mocks.
    session_token_patcher = patch("src.clients.aws_client_factory.AwsClientFactory._get_session_token")
    assume_role_patcher = patch("src.clients.aws_client_factory.AwsClientFactory._assume_role", fake_assume_role)
    boto3_patcher = patch("src.clients.aws_client_factory.boto3", fake_boto3)
    with session_token_patcher, assume_role_patcher, boto3_patcher:
        factory = AwsClientFactory(self.mfa, self.username)
        actual_client = factory._get_client(self.service_name, account(), self.role)

    self.assertEqual(expected_client, actual_client)
    fake_assume_role.assert_called_once_with(account(), self.role)
    fake_boto3.client.assert_called_once_with(
        service_name="some_service",
        aws_access_key_id="access",
        aws_secret_access_key="secret",
        aws_session_token="session",
    )
def test_role_usage_scanner_tasks(self) -> None:
    """Builder creates one AwsRoleUsageScannerTask per target account, all scanning for the same role."""
    mock_orgs = Mock(get_target_accounts=Mock(return_value=[account("1", "one"), account("2", "two")]))
    # Capture the date once so year and month always describe the same day.
    today = date.today()
    year, month, role = today.year, today.month, "SomeRole"
    tasks = AwsTaskBuilder(mock_orgs).role_usage_scanner_tasks(year, month, role)
    self.assert_role_tasks_equal(AwsRoleUsageScannerTask(account("1", "one"), partition(), "SomeRole"), tasks[0])
    self.assert_role_tasks_equal(AwsRoleUsageScannerTask(account("2", "two"), partition(), "SomeRole"), tasks[1])
def test_setup(self) -> None:
    """_setup creates the database, then the account table, then the partition, in that sequence."""
    athena = Mock()
    task = cloudtrail_task()
    task._setup(athena)
    expected_calls = [
        call.create_database(task._database),
        call.create_table(task._database, account()),
        call.add_partition(task._database, account(), partition()),
    ]
    athena.assert_has_calls(expected_calls)
def test_get_ssm_client(self, _: Mock) -> None:
    """get_ssm_client wraps the boto SSM client obtained for the requested account."""
    expected_boto_client = Mock()

    def fake_get_ssm_boto_client(acc):
        # Only hand out the client for the default test account.
        return expected_boto_client if acc == account() else None

    with patch(f"{self.factory_path}.get_ssm_boto_client", side_effect=fake_get_ssm_boto_client):
        factory = AwsClientFactory(self.mfa, self.username)
        ssm_client = factory.get_ssm_client(account())
    self.assertEqual(ssm_client._ssm, expected_boto_client)
def test_assume_role(self) -> None:
    """_assume_role calls STS using the session-token credentials and returns the assumed-role credentials."""
    sts_response = {
        "Credentials": {
            "AccessKeyId": "some_access_key",
            "SecretAccessKey": "some_secret_access_key",
            "SessionToken": "some_session_token",
        }
    }
    fake_sts = Mock(assume_role=Mock(return_value=sts_response))
    fake_boto3 = Mock(client=Mock(return_value=fake_sts))
    session_credentials = AwsCredentials("session_access_key", "session_secret_key", "session_token")

    session_token_patcher = patch(
        "src.clients.aws_client_factory.AwsClientFactory._get_session_token",
        return_value=session_credentials,
    )
    boto3_patcher = patch("src.clients.aws_client_factory.boto3", fake_boto3)
    with session_token_patcher, boto3_patcher:
        factory = AwsClientFactory(self.mfa, self.username)
        assumed_credentials = factory._assume_role(account(), self.role)

    expected_credentials = AwsCredentials("some_access_key", "some_secret_access_key", "some_session_token")
    self.assertEqual(assumed_credentials, expected_credentials)
    # The STS client must be built from the session-token credentials...
    fake_boto3.client.assert_called_once_with(
        service_name="sts",
        aws_access_key_id="session_access_key",
        aws_secret_access_key="session_secret_key",
        aws_session_token="session_token",
    )
    # ...and asked to assume the requested role for one hour.
    fake_sts.assume_role.assert_called_once_with(
        DurationSeconds=3600,
        RoleArn="arn:aws:iam::account_id:role/some_role",
        RoleSessionName="boto3_assuming_some_role",
    )
def test_run_tasks(self) -> None:
    """A failing task is logged at ERROR and dropped; reports of the successful tasks are still returned."""
    first_ok = athena_task(description="some task")
    first_ok._run_task = lambda _: {"outcome_1": "success_1"}
    second_ok = athena_task(description="other task")
    second_ok._run_task = lambda _: {"outcome_2": "success_2"}
    broken = athena_task(account=account("5678", "wrong account"), description="boom")
    broken._run_task = lambda _: _raise(AwsScannerException("oops"))

    # The failing task is placed between the two successful ones.
    with self.assertLogs("AwsParallelTaskRunner", level="ERROR") as captured:
        reports = AwsParallelTaskRunner(Mock()).run([first_ok, broken, second_ok])

    self.assertEqual(2, len(reports), "there should only be two task reports")
    expected_reports = [
        task_report(description="some task", results={"outcome_1": "success_1"}, partition=None),
        task_report(description="other task", results={"outcome_2": "success_2"}, partition=None),
    ]
    for expected in expected_reports:
        self.assertIn(expected, reports)
    self.assertEqual(
        ["ERROR:AwsParallelTaskRunner:task 'boom' for 'wrong account (5678)' failed with: 'oops'"],
        captured.output,
    )
def test_get_organizations_boto_client(self) -> None:
    """Organizations client is built for the root account using the orgs role."""
    root_account = account(identifier="999888777666", name="root")
    self.assert_get_client(
        method_under_test="get_organizations_boto_client",
        service="organizations",
        target_account=root_account,
        role="orgs_role",
    )
def test_create_table(self) -> None:
    """Delegates to assert_query_run, expecting queries.CREATE_TABLE and CreateTableException as the failure type."""
    target = account("908173625490", "some_account")
    args = {"database": "some_database", "account": target}
    assert_query_run(
        test=self,
        method_under_test="create_table",
        method_args=args,
        query=queries.CREATE_TABLE,
        raise_on_failure=exception.CreateTableException,
    )
def test_get_athena_boto_client(self) -> None:
    """Athena client is built for the cloudtrail account using the cloudtrail role."""
    cloudtrail_account = account(identifier="555666777888", name="cloudtrail")
    self.assert_get_client(
        method_under_test="get_athena_boto_client",
        service="athena",
        target_account=cloudtrail_account,
        role="cloudtrail_role",
    )
def test_get_s3_boto_client(self) -> None:
    """S3 client is built for the account passed as argument using the s3 role."""
    target = account(identifier="122344566788", name="some_s3_account")
    self.assert_get_client(
        method_under_test="get_s3_boto_client",
        method_args={"account": target},
        service="s3",
        target_account=target,
        role="s3_role",
    )
def test_teardown(self) -> None:
    """_teardown drops the account table before dropping the database."""
    athena = Mock()
    task = cloudtrail_task()
    task._teardown(athena)
    expected_calls = [
        call.drop_table(task._database, account().identifier),
        call.drop_database(task._database),
    ]
    athena.assert_has_calls(expected_calls)
def test_run_task(self) -> None:
    """AwsAuditS3Task reports every bucket returned by the S3 client."""
    all_buckets = [bucket(name) for name in ("bucket-1", "bucket-2", "another-bucket")]
    fake_s3 = Mock(list_buckets=Mock(return_value=all_buckets))
    report = AwsAuditS3Task(account())._run_task(fake_s3)
    self.assertEqual({"buckets": all_buckets}, report)
def test_find_account_by_id(self) -> None:
    """find_account_by_id turns the describe_account response into an account."""
    account_id = "123456789012"

    # Parameter name kept as AccountId so keyword calls (AccountId=...) still bind.
    def describe_account(AccountId):  # noqa: N803
        return responses.DESCRIBE_ACCOUNT if AccountId == account_id else None

    expected = account(account_id, "some test account")
    boto_orgs = Mock(describe_account=Mock(side_effect=describe_account))
    client = AwsOrganizationsClient(boto_orgs)
    self.assertEqual(expected, client.find_account_by_id(account_id))
def test_create_table(self, mock_wait_for_success: Mock) -> None:
    """Delegates to assert_wait_for_success with a 60 second timeout and CreateTableException as the failure type."""
    args = {"database": "some_db", "account": account()}
    self.assert_wait_for_success(
        mock_wait_for_success=mock_wait_for_success,
        method_under_test="create_table",
        method_args=args,
        timeout_seconds=60,
        raise_on_failure=exceptions.CreateTableException,
    )
def test_aws_create_athena_table_task(self) -> None:
    """AwsCreateAthenaTableTask is a CloudTrail task whose run touches no Athena API and reports its table and database."""
    task = AwsCreateAthenaTableTask(account(), partition())
    self.assertIsInstance(task, AwsCloudTrailTask)
    mock_athena = Mock()
    run_results = task._run_task(mock_athena)
    # assert_not_called() only covers calls on the mock itself, not on its methods
    # (e.g. mock_athena.create_table(...)); mock_calls records both.
    self.assertEqual([], mock_athena.mock_calls)
    self.assertEqual("account_id", run_results["table"])
    self.assertTrue(run_results["database"].startswith("some_prefix_account_id_"))
def test_add_partition(self, mock_wait_for_success: Mock) -> None:
    """Delegates to assert_wait_for_success with a 120 second timeout and AddPartitionException as the failure type."""
    args = {
        "database": "some_db",
        "account": account(),
        "partition": partition(2019, 8),
    }
    self.assert_wait_for_success(
        mock_wait_for_success=mock_wait_for_success,
        method_under_test="add_partition",
        method_args=args,
        timeout_seconds=120,
        raise_on_failure=exceptions.AddPartitionException,
    )
def test_add_partition(self) -> None:
    """Delegates to assert_query_run, expecting queries.ADD_PARTITION_YEAR_MONTH and AddPartitionException."""
    args = {
        "database": "some_database",
        "account": account("908173625490", "some_account"),
        "partition": partition(2020, 7),
    }
    assert_query_run(
        test=self,
        method_under_test="add_partition",
        method_args=args,
        query=queries.ADD_PARTITION_YEAR_MONTH,
        raise_on_failure=exception.AddPartitionException,
    )
class TestAwsAthenaCleanerTask(AwsScannerTestCase):
    """Tests that AwsAthenaCleanerTask removes the "some_prefix_" databases left behind by scanner runs."""

    # Fixture: database name -> tables it contains. Only the "some_prefix_*" entries
    # appear in the expected report below.
    database_mappings = {
        "db_1": ["table_1", "table_2", "table_3"],
        "some_prefix_db_2": ["table_1", "table_2"],
        "db_3": ["table_1"],
        "some_prefix_db_4": ["table_1", "table_2", "table_3"],
        "some_prefix_db_5": [],
    }

    # Report expected from running the cleaner against database_mappings.
    expected_report = task_report(
        account=account("555666777888", "cloudtrail"),
        description="clean scanner leftovers",
        partition=None,
        results={
            "dropped_tables": [
                "some_prefix_db_2.table_1",
                "some_prefix_db_2.table_2",
                "some_prefix_db_4.table_1",
                "some_prefix_db_4.table_2",
                "some_prefix_db_4.table_3",
            ],
            "dropped_databases": ["some_prefix_db_2", "some_prefix_db_4", "some_prefix_db_5"],
        },
    )

    def test_clean_task_databases(self) -> None:
        """Each prefixed database is listed and emptied table by table; all of them are then dropped."""

        def tables_of(db):
            return self.database_mappings.get(db)

        athena = Mock(
            list_databases=Mock(return_value=list(self.database_mappings)),
            list_tables=Mock(side_effect=tables_of),
        )
        self.assertEqual(self.expected_report, AwsAthenaCleanerTask().run(athena))

        prefixed_databases = ("some_prefix_db_2", "some_prefix_db_4", "some_prefix_db_5")
        expected_calls = [call.list_databases()]
        for db in prefixed_databases:
            expected_calls.append(call.list_tables(db))
            expected_calls.extend(call.drop_table(db, table) for table in self.database_mappings[db])
        # Databases are dropped only after every table has been dropped.
        expected_calls.extend(call.drop_database(db) for db in prefixed_databases)
        athena.assert_has_calls(expected_calls)
def test_run_task(self) -> None:
    """AwsListSSMParametersTask reports all parameters together with per-type and total counts."""
    params = [
        secure_string_parameter("secure_1"),
        string_list_parameter("list"),
        string_parameter("string_1"),
        secure_string_parameter("secure_2"),
        string_parameter("string_2"),
        string_parameter("string_3"),
    ]
    fake_ssm = Mock(list_parameters=Mock(return_value=params))
    report = AwsListSSMParametersTask(account())._run_task(fake_ssm)
    expected = {
        "ssm_parameters": params,
        "type_count": {"SecureString": 2, "StringList": 1, "String": 3},
        "total_count": 6,
    }
    self.assertEqual(expected, report)
def test_find_account_by_ids(self) -> None:
    """find_account_by_ids skips ids that resolve to None and keeps the requested order for the rest."""
    known_accounts = {
        "3": account("3", "account 3"),
        "8": account("8", "account 8"),
        "2": None,
        "5": account("5", "account 5"),
    }
    expected = [account("8", "account 8"), account("5", "account 5"), account("3", "account 3")]
    with patch(
        "src.clients.aws_organizations_client.AwsOrganizationsClient.find_account_by_id",
        side_effect=lambda acc_id: known_accounts.get(acc_id),
    ):
        found = AwsOrganizationsClient(Mock()).find_account_by_ids(["8", "2", "5", "3"])
        self.assertEqual(expected, found)
def test_teardown_do_nothing(self) -> None:
    """AwsCreateAthenaTableTask._teardown makes no call on the Athena client."""
    task = AwsCreateAthenaTableTask(account(), partition())
    mock_athena = Mock()
    task._teardown(mock_athena)
    # assert_not_called() would miss method calls such as mock_athena.drop_table(...);
    # mock_calls records calls on the mock and all of its children.
    self.assertEqual([], mock_athena.mock_calls)
def __build_task_under_test(task_type, task_args):
    """Instantiate ``task_type`` for the default test account and partition, with its database set to "some_db".

    ``task_args`` supplies any extra constructor keyword arguments.
    """
    built_task = task_type(account=account(), partition=partition(), **task_args)
    # Pin the database name; presumably so assertions can refer to "some_db" — confirm with callers.
    built_task._database = "some_db"
    return built_task
def test_audit_s3_tasks(self) -> None:
    """Builder creates one AwsAuditS3Task per account, ordered as the organizations lookup returns them."""
    mock_orgs = Mock(find_account_by_ids=Mock(return_value=[account("3", "three"), account("5", "five")]))
    # Ids are requested as ["5", "3"], but tasks follow the lookup's order (3, then 5).
    tasks = AwsTaskBuilder(mock_orgs, ["5", "3"]).audit_s3_tasks()
    expected_tasks = [AwsAuditS3Task(account("3", "three")), AwsAuditS3Task(account("5", "five"))]
    for position, expected in enumerate(expected_tasks):
        self.assert_tasks_equal(expected, tasks[position])
def test_run_s3_task(self) -> None:
    """AwsTaskRunner runs an S3 task with the S3 client obtained for account() and returns the task's report."""
    s3_client = Mock()

    def s3_client_for(acc):
        return s3_client if acc == account() else None

    def run_with(c):
        return task_report() if c == s3_client else None

    client_factory = Mock(get_s3_client=Mock(side_effect=s3_client_for))
    task = s3_task()
    task.run = Mock(side_effect=run_with)  # type: ignore
    self.assertEqual(task_report(), AwsTaskRunner(client_factory)._run_task(task))
def test_run_task(self) -> None:
    """AwsListAccountsTask reports every account returned by the organizations client."""
    all_accounts = [account() for _ in range(2)]
    orgs_client = Mock(get_all_accounts=Mock(return_value=all_accounts))
    report = AwsListAccountsTask()._run_task(orgs_client)
    self.assertEqual({"accounts": all_accounts}, report)
def test_list_ssm_parameters_tasks(self) -> None:
    """Builder creates one AwsListSSMParametersTask per account, ordered as the organizations lookup returns them."""
    mock_orgs = Mock(find_account_by_ids=Mock(return_value=[account("2", "two"), account("4", "four")]))
    # Ids are requested as ["4", "2"], but tasks follow the lookup's order (2, then 4).
    tasks = AwsTaskBuilder(mock_orgs, ["4", "2"]).list_ssm_parameters_tasks()
    expected_tasks = [AwsListSSMParametersTask(account("2", "two")), AwsListSSMParametersTask(account("4", "four"))]
    for position, expected in enumerate(expected_tasks):
        self.assert_tasks_equal(expected, tasks[position])
def test_get_account(self) -> None:
    """A task built by aws_task() exposes account() through its account property."""
    self.assertEqual(aws_task().account, account())
"Name": "Root 1 > Org Unit 2 > Org Unit 2 > Account 1", }, { "Id": "242243167582", "Name": "Root 1 > Org Unit 2 > Org Unit 2 > Account 2", }, ] } EMPTY_ACCOUNTS = {"Accounts": []} EXPECTED_ORGANIZATION_TREE = [ organizational_unit( identifier="r-root1", name="Root 1", root=True, accounts=[ account(identifier="987654655432", name="Root 1 > Account 1") ], org_units=[ organizational_unit( identifier="ou-root1-1", name="Root 1 > Org Unit 1", accounts=[ account(identifier="987651321565", name="Root 1 > Org Unit 1 > Account 1"), account(identifier="643758194672", name="Root 1 > Org Unit 1 > Account 2"), account(identifier="594678488453", name="Root 1 > Org Unit 1 > Account 3"), ], org_units=[], ),