def on_request_created(
    api_call_stats: MultilevelCounter,
    account_id: str,
    region_name: str,
    service_name: str,
    readonly: bool,
    **kwargs: Any,
) -> None:
    """Boto3 ``request-created`` hook that tallies API calls.

    The operation name is the third dot-separated component of
    ``kwargs["event_name"]``; it is counted in ``api_call_stats`` under
    account -> region -> service -> operation.

    Args:
        api_call_stats: MultilevelCounter that receives the increment.
        account_id: account id of the request.
        region_name: region of the request.
        service_name: service of the request.
        readonly: when True, only operations whose ``operation_name`` kwarg
            matches the permitted-operation regex are allowed.
        kwargs: keyword arguments passed through by the boto event callback;
            ``event_name`` is always read, ``operation_name`` only when
            ``readonly`` is set.

    Raises:
        ValueError: if ``event_name`` does not have exactly three
            dot-separated parts.
        Exception: if ``readonly`` is set and the operation is not permitted.
    """
    _event_kind, _event_service, event_op_name = kwargs["event_name"].split(".")
    # The permission check is made against boto's operation_name kwarg, while
    # the error message and the counter use the name parsed from the event.
    if readonly and not _PERMITTED_OPERATION_NAMES_RE.search(kwargs["operation_name"]):
        raise Exception(
            f"Operation name {event_op_name} did not match {_PERMITTED_OPERATION_NAMES_STR}"
        )
    api_call_stats.increment(account_id, region_name, service_name, event_op_name)
def scan(self) -> GraphSet:
    """Scan every resource spec class of this GraphSpec into one GraphSet.

    Each class in ``self.resource_spec_classes`` is scanned, in order, with
    ``self.scan_accessor``. Its resources and errors are concatenated, and its
    stats are merged into a single MultilevelCounter.

    Returns:
        GraphSet carrying this GraphSpec's name and version, the combined
        resources/errors/stats, and the start and end of the scan as
        integer ``time.time()`` values.
    """
    collected_resources: List[Resource] = []
    collected_errors: List[str] = []
    combined_stats = MultilevelCounter()
    scan_began = int(time.time())
    logger = Logger()
    for spec_class in self.resource_spec_classes:
        with logger.bind(resource_type=str(spec_class.type_name)):
            logger.debug(event=LogEvent.ScanResourceTypeStart)
            spec_result = spec_class.scan(scan_accessor=self.scan_accessor)
            collected_resources.extend(spec_result.resources)
            collected_errors.extend(spec_result.errors)
            combined_stats.merge(spec_result.stats)
            logger.debug(event=LogEvent.ScanResourceTypeEnd)
    scan_finished = int(time.time())
    return GraphSet(
        name=self.name,
        version=self.version,
        start_time=scan_began,
        end_time=scan_finished,
        resources=collected_resources,
        errors=collected_errors,
        stats=combined_stats,
    )
def test_increment(self):
    """A single three-level increment counts once at the root and every level."""
    counter = MultilevelCounter()
    counter.increment("foo", "boo", "goo")
    goo_level = {"count": 1}
    boo_level = {"count": 1, "goo": goo_level}
    foo_level = {"count": 1, "boo": boo_level}
    expected = {"count": 1, "foo": foo_level}
    self.assertDictEqual(expected, counter.to_dict())
def from_dict(cls: Type["GraphSet"], data: Dict[str, Any]) -> "GraphSet":
    """Rebuild a GraphSet from its dict form.

    Args:
        data: dict with ``name``, ``start_time``, ``end_time``, ``version``,
            ``errors``, ``stats`` (a MultilevelCounter dict) and ``resources``
            (a mapping of resource id -> resource dict).

    Returns:
        GraphSet object
    """
    # Keyword arguments are evaluated left to right, so keys are read in the
    # order name, start_time, end_time, version, errors, stats, resources.
    return cls(
        name=data["name"],
        start_time=data["start_time"],
        end_time=data["end_time"],
        version=data["version"],
        errors=data["errors"],
        stats=MultilevelCounter.from_dict(data["stats"]),
        resources=[
            Resource.from_dict(res_id, res_data)
            for res_id, res_data in data["resources"].items()
        ],
    )
def scan(cls: Type["TestResourceSpecB"], scan_accessor: Any) -> ResourceScanResult:
    """Return two fixed, link-less resources ("abc" and "def") of this type.

    ``scan_accessor`` is ignored. The result has an empty stats counter and
    no errors.
    """
    canned_ids = ("abc", "def")
    canned_resources = [
        Resource(resource_id=canned_id, type_name=cls.type_name, links=[])
        for canned_id in canned_ids
    ]
    return ResourceScanResult(
        resources=canned_resources, stats=MultilevelCounter(), errors=[]
    )
def __init__(self, session: boto3.Session, account_id: str, region_name: str, readonly: bool = True):
    """Record the session and scan target; start with empty stats and client cache.

    Args:
        session: boto3 session held by this object.
        account_id: account id this object targets.
        region_name: region this object targets (stored as ``self.region``).
        readonly: readonly flag, stored as-is (defaults to True).
    """
    self.session = session
    self.account_id = account_id
    self.region = region_name
    self.readonly = readonly
    self.client_cache: Dict[str, Any] = {}
    self.api_call_stats = MultilevelCounter()
def test_invalid_diff_versions(self):
    """Merging GraphSets with different versions raises UnmergableGraphSetsException."""

    def empty_graph(version, start_time, end_time):
        # Same name and no content; only version and timestamps vary.
        return GraphSet(
            name="graph-1",
            version=version,
            start_time=start_time,
            end_time=end_time,
            resources=[],
            errors=[],
            stats=MultilevelCounter(),
        )

    older = empty_graph("1", 10, 20)
    newer = empty_graph("2", 15, 25)
    with self.assertRaises(UnmergableGraphSetsException):
        older.merge(newer)
def test_unknown_type_name(self):
    """Constructing a GraphSet with a resource of type "test:c" raises
    ResourceSpecClassNotFoundException (presumably because no spec class is
    registered for that type)."""
    mixed_type_resources = [
        Resource(resource_id="xyz", type_name=type_name)
        for type_name in ("test:a", "test:c")
    ]
    with self.assertRaises(ResourceSpecClassNotFoundException):
        GraphSet(
            name="test-name",
            version="1",
            start_time=1234,
            end_time=4567,
            resources=mixed_type_resources,
            errors=[],
            stats=MultilevelCounter(),
        )
def test_invalid_resources_dupes_same_class_conflicting_types_no_allow_clobber(
        self):
    """Two resources sharing id "123" but with different type names make
    GraphSet construction raise UnmergableDuplicateResourceIdsFoundException."""
    clashing_resources = [
        Resource(resource_id="123", type_name=type_name)
        for type_name in ("test:a", "test:b")
    ]
    with self.assertRaises(UnmergableDuplicateResourceIdsFoundException):
        GraphSet(
            name="test-name",
            version="1",
            start_time=1234,
            end_time=4567,
            resources=clashing_resources,
            errors=[],
            stats=MultilevelCounter(),
        )
def test_orphaned_ref(self):
    """validate() raises GraphSetOrphanedReferencesException when a resource
    link points at id "456", which no resource in the set has."""
    dangling_resources = [
        Resource(resource_id="123",
                 type_name="test:a",
                 links=[SimpleLink(pred="has-foo", obj="goo")]),
        Resource(resource_id="abc",
                 type_name="test:b",
                 links=[ResourceLinkLink(pred="has-a", obj="456")]),
    ]
    orphaned_graph = GraphSet(
        name="test-name",
        version="1",
        start_time=1234,
        end_time=4567,
        resources=dangling_resources,
        errors=["test err 1", "test err 2"],
        stats=MultilevelCounter(),
    )
    with self.assertRaises(GraphSetOrphanedReferencesException):
        orphaned_graph.validate()
def setUp(self):
    """Build ``self.graph_set``: two "test:a" and two "test:b" resources plus two errors.

    Resource "abc" links to "123", so every resource link resolves inside
    the set.
    """
    fixture_resources = [
        Resource(resource_id="123",
                 type_name="test:a",
                 links=[SimpleLink(pred="has-foo", obj="goo")]),
        Resource(resource_id="456", type_name="test:a"),
        Resource(resource_id="abc",
                 type_name="test:b",
                 links=[ResourceLinkLink(pred="has-a", obj="123")]),
        Resource(resource_id="def",
                 type_name="test:b",
                 links=[SimpleLink(pred="name", obj="sue")]),
    ]
    self.graph_set = GraphSet(
        name="test-name",
        version="1",
        start_time=1234,
        end_time=4567,
        resources=fixture_resources,
        errors=["test err 1", "test err 2"],
        stats=MultilevelCounter(),
    )
def test_from_dict(self):
    """from_dict followed by to_dict reproduces a two-branch counter dict."""
    foo_branch = {"count": 1, "boo": {"count": 1, "goo": {"count": 1}}}
    boo_branch = {"count": 1, "goo": {"count": 1, "moo": {"count": 1}}}
    data = {"count": 2, "foo": foo_branch, "boo": boo_branch}
    round_tripped = MultilevelCounter.from_dict(data).to_dict()
    self.assertDictEqual(round_tripped, data)
def test_valid_merge(self):
    """Merging two same-name, same-version GraphSets combines their contents.

    Inputs span 10-20 and 15-25. The merged set is expected to report
    start_time 10 and end_time 25, all four errors, and all four resources.
    """
    res_a1 = Resource(resource_id="123",
                      type_name="test:a",
                      links=[SimpleLink(pred="has-foo", obj="goo")])
    res_a2 = Resource(resource_id="456", type_name="test:a")
    res_b1 = Resource(resource_id="abc",
                      type_name="test:b",
                      links=[ResourceLinkLink(pred="has-a", obj="123")])
    res_b2 = Resource(resource_id="def",
                      type_name="test:b",
                      links=[SimpleLink(pred="name", obj="sue")])
    target = GraphSet(
        name="graph-1",
        version="1",
        start_time=10,
        end_time=20,
        resources=[res_a1, res_a2],
        errors=["errora1", "errora2"],
        stats=MultilevelCounter(),
    )
    incoming = GraphSet(
        name="graph-1",
        version="1",
        start_time=15,
        end_time=25,
        resources=[res_b1, res_b2],
        errors=["errorb1", "errorb2"],
        stats=MultilevelCounter(),
    )
    target.merge(incoming)

    self.assertEqual(target.name, "graph-1")
    self.assertEqual(target.version, "1")
    self.assertEqual(target.start_time, 10)
    self.assertEqual(target.end_time, 25)
    self.assertCountEqual(target.errors,
                          ["errora1", "errora2", "errorb1", "errorb2"])

    def simple(pred, obj):
        return {"pred": pred, "obj": obj, "type": "simple"}

    expected_resource_dicts = [
        {"type": "test:a", "links": [simple("has-foo", "goo")]},
        {"type": "test:a"},
        {"type": "test:b",
         "links": [{"pred": "has-a", "obj": "123", "type": "resource_link"}]},
        {"type": "test:b", "links": [simple("name", "sue")]},
    ]
    actual_resource_dicts = [res.to_dict() for res in target.resources]
    self.assertCountEqual(expected_resource_dicts, actual_resource_dicts)
def scan(self) -> Dict[str, Any]:
    """Scan an account and return a dict containing keys:

        * account_id: str
        * output_artifact: str
        * api_call_stats: Dict[str, Any]
        * errors: List[str]

    If errors is non-empty the results are incomplete for this account.
    output_artifact is whatever ``self.artifact_writer.write_artifact``
    returns for the written account graph.

    Steps:

    1. Confirm via STS that the session really belongs to
       ``self.account_id``.
    2. Group the resource spec classes by region and service.
       ACCOUNT-granularity classes go only in the first scan region;
       REGION-granularity classes go in every scan region.
    3. Submit one ``schedule_scan_services`` job per (region, service) to a
       thread pool of ``self.max_svc_threads`` workers.
    4. Merge the resulting GraphSets into one account GraphSet.

    If any step raises, the error message and traceback are appended to
    ``errors`` and the written graph contains only an
    UnscannedAccountResourceSpec resource for the account.

    Returns:
        Dict of scan result, see above for details.
    """
    logger = Logger()
    with logger.bind(account_id=self.account_id):
        logger.info(event=AWSLogEvents.ScanAWSAccountStart)
        stats = MultilevelCounter()
        errors: List[str] = []
        now = int(time.time())
        try:
            account_graph_set = GraphSet(
                name=self.graph_name,
                version=self.graph_version,
                start_time=now,
                end_time=now,
                resources=[],
                errors=[],
                stats=stats,
            )
            # sanity check
            session = self.get_session()
            sts_client = session.client("sts")
            sts_account_id = sts_client.get_caller_identity()["Account"]
            if sts_account_id != self.account_id:
                raise ValueError(
                    f"BUG: sts detected account_id {sts_account_id} != {self.account_id}"
                )
            if self.regions:
                scan_regions = tuple(self.regions)
            else:
                scan_regions = get_all_enabled_regions(session=session)
            # build graph specs.
            # build a dict of regions -> services -> List[AWSResourceSpec]
            regions_services_resource_spec_classes: DefaultDict[
                str, DefaultDict[str, List[Type[AWSResourceSpec]]]
            ] = defaultdict(lambda: defaultdict(list))
            resource_spec_class: Type[AWSResourceSpec]
            for resource_spec_class in self.resource_spec_classes:
                client_name = resource_spec_class.get_client_name()
                resource_class_scan_granularity = resource_spec_class.scan_granularity
                if resource_class_scan_granularity == ScanGranularity.ACCOUNT:
                    # Account-wide resources are scanned once, in the first region.
                    regions_services_resource_spec_classes[scan_regions[0]][
                        client_name
                    ].append(resource_spec_class)
                elif resource_class_scan_granularity == ScanGranularity.REGION:
                    for region in scan_regions:
                        regions_services_resource_spec_classes[region][
                            client_name
                        ].append(resource_spec_class)
                else:
                    raise NotImplementedError(
                        f"ScanGranularity {resource_class_scan_granularity} not implemented"
                    )
            with ThreadPoolExecutor(max_workers=self.max_svc_threads) as executor:
                futures = []
                for (
                    region,
                    services_resource_spec_classes,
                ) in regions_services_resource_spec_classes.items():
                    # The session and credentials depend only on the region, so
                    # fetch them once per region rather than once per service.
                    region_session = self.get_session(region=region)
                    region_creds = region_session.get_credentials()
                    for (
                        service,
                        resource_spec_classes,
                    ) in services_resource_spec_classes.items():
                        scan_future = schedule_scan_services(
                            executor=executor,
                            graph_name=self.graph_name,
                            graph_version=self.graph_version,
                            account_id=self.account_id,
                            region=region,
                            service=service,
                            access_key=region_creds.access_key,
                            secret_key=region_creds.secret_key,
                            token=region_creds.token,
                            resource_spec_classes=tuple(resource_spec_classes),
                        )
                        futures.append(scan_future)
                for future in as_completed(futures):
                    graph_set_dict = future.result()
                    graph_set = GraphSet.from_dict(graph_set_dict)
                    errors += graph_set.errors
                    account_graph_set.merge(graph_set)
                account_graph_set.validate()
        except Exception as ex:
            error_str = str(ex)
            trace_back = traceback.format_exc()
            logger.error(
                event=AWSLogEvents.ScanAWSAccountError,
                error=error_str,
                trace_back=trace_back,
            )
            errors.append(" : ".join((error_str, trace_back)))
            # Replace any partial graph with a single "unscanned account" marker.
            unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                account_id=self.account_id, errors=errors
            )
            account_graph_set = GraphSet(
                name=self.graph_name,
                version=self.graph_version,
                start_time=now,
                end_time=now,
                resources=[unscanned_account_resource],
                errors=errors,
                stats=stats,
            )
            account_graph_set.validate()
        output_artifact = self.artifact_writer.write_artifact(
            name=self.account_id, data=account_graph_set.to_dict()
        )
        logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
        api_call_stats = account_graph_set.stats.to_dict()
        return {
            "account_id": self.account_id,
            "output_artifact": output_artifact,
            "errors": errors,
            "api_call_stats": api_call_stats,
        }
def scan(self) -> List[Dict[str, Any]]:
    """Scan every account in ``self.account_scan_plan`` and write one graph per account.

    Accounts are visited in random order. For each account:

    1. The session's STS identity is checked against the account id.
    2. Resource spec classes are grouped by region and service:
       * ACCOUNT-granularity classes are scanned in the first entry of their
         ``region_whitelist``, or otherwise in a region chosen at random
         (once per account) from ``self.preferred_account_scan_regions``.
       * REGION-granularity classes are scanned in the scan regions that
         appear in their whitelist, falling back to the whitelist itself when
         there is no overlap, or in all scan regions when there is no
         whitelist.
    3. One ``schedule_scan`` job per (region, service) is submitted in
       shuffled order.

    Exceptions while planning an account are recorded per account.

    Afterwards, each account gets one artifact written via
    ``self.artifact_writer.write_json``:
    * an "unscanned account" graph if planning failed or any scan unit
      reported errors;
    * otherwise the merge of all its scan-unit graphs.

    Returns:
        list of dicts with keys ``account_id``, ``output_artifact``,
        ``errors`` and ``api_call_stats``, one per account.

    Raises:
        Exception: if an account shows up both as a planning failure and with
            scan results (indicates a bug).
    """
    logger = Logger()
    scan_result_dicts = []
    now = int(time.time())
    prescan_account_ids_errors: DefaultDict[str, List[str]] = defaultdict(list)
    futures = []
    with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
        # random.sample requires a sequence (sets/dict views raise TypeError
        # on Python 3.11+), so materialize populations as lists.
        shuffled_account_ids = random.sample(
            list(self.account_scan_plan.account_ids),
            k=len(self.account_scan_plan.account_ids),
        )
        for account_id in shuffled_account_ids:
            with logger.bind(account_id=account_id):
                logger.info(event=AWSLogEvents.ScanAWSAccountStart)
                try:
                    session = self.account_scan_plan.accessor.get_session(
                        account_id=account_id
                    )
                    # sanity check
                    sts_client = session.client("sts")
                    sts_account_id = sts_client.get_caller_identity()["Account"]
                    if sts_account_id != account_id:
                        raise ValueError(
                            f"BUG: sts detected account_id {sts_account_id} != {account_id}"
                        )
                    if self.account_scan_plan.regions:
                        scan_regions = tuple(self.account_scan_plan.regions)
                    else:
                        scan_regions = get_all_enabled_regions(session=session)
                    account_gran_scan_region = random.choice(
                        self.preferred_account_scan_regions
                    )
                    # build a dict of regions -> services -> List[AWSResourceSpec]
                    regions_services_resource_spec_classes: DefaultDict[
                        str, DefaultDict[str, List[Type[AWSResourceSpec]]]
                    ] = defaultdict(lambda: defaultdict(list))
                    resource_spec_class: Type[AWSResourceSpec]
                    for resource_spec_class in self.resource_spec_classes:
                        client_name = resource_spec_class.get_client_name()
                        if resource_spec_class.scan_granularity == ScanGranularity.ACCOUNT:
                            if resource_spec_class.region_whitelist:
                                account_resource_scan_region = (
                                    resource_spec_class.region_whitelist[0]
                                )
                            else:
                                account_resource_scan_region = account_gran_scan_region
                            regions_services_resource_spec_classes[
                                account_resource_scan_region
                            ][client_name].append(resource_spec_class)
                        elif resource_spec_class.scan_granularity == ScanGranularity.REGION:
                            if resource_spec_class.region_whitelist:
                                resource_scan_regions = tuple(
                                    region
                                    for region in scan_regions
                                    if region in resource_spec_class.region_whitelist
                                )
                                if not resource_scan_regions:
                                    # No overlap with the requested regions:
                                    # fall back to the class's own whitelist.
                                    resource_scan_regions = (
                                        resource_spec_class.region_whitelist
                                    )
                            else:
                                resource_scan_regions = scan_regions
                            for region in resource_scan_regions:
                                regions_services_resource_spec_classes[region][
                                    client_name
                                ].append(resource_spec_class)
                        else:
                            raise NotImplementedError(
                                f"ScanGranularity {resource_spec_class.scan_granularity} unimplemented"
                            )
                    # Build and submit ScanUnits
                    shuffled_regions_services_resource_spec_classes = random.sample(
                        list(regions_services_resource_spec_classes.items()),
                        len(regions_services_resource_spec_classes),
                    )
                    for (
                        region,
                        services_resource_spec_classes,
                    ) in shuffled_regions_services_resource_spec_classes:
                        region_session = self.account_scan_plan.accessor.get_session(
                            account_id=account_id, region_name=region
                        )
                        region_creds = region_session.get_credentials()
                        shuffled_services_resource_spec_classes = random.sample(
                            list(services_resource_spec_classes.items()),
                            len(services_resource_spec_classes),
                        )
                        for (
                            service,
                            svc_resource_spec_classes,
                        ) in shuffled_services_resource_spec_classes:
                            future = schedule_scan(
                                executor=executor,
                                graph_name=self.graph_name,
                                graph_version=self.graph_version,
                                account_id=account_id,
                                region_name=region,
                                service=service,
                                access_key=region_creds.access_key,
                                secret_key=region_creds.secret_key,
                                token=region_creds.token,
                                resource_spec_classes=tuple(svc_resource_spec_classes),
                            )
                            futures.append(future)
                except Exception as ex:
                    error_str = str(ex)
                    trace_back = traceback.format_exc()
                    logger.error(
                        event=AWSLogEvents.ScanAWSAccountError,
                        error=error_str,
                        trace_back=trace_back,
                    )
                    prescan_account_ids_errors[account_id].append(
                        f"{error_str}\n{trace_back}"
                    )
        account_ids_graph_set_dicts: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
        for future in as_completed(futures):
            account_id, graph_set_dict = future.result()
            account_ids_graph_set_dicts[account_id].append(graph_set_dict)
    # first make sure no account id appears both in account_ids_graph_set_dicts
    # and prescan_account_ids_errors - this should never happen
    doubled_accounts = set(account_ids_graph_set_dicts.keys()).intersection(
        set(prescan_account_ids_errors.keys())
    )
    if doubled_accounts:
        raise Exception(
            (
                f"BUG: Account(s) {doubled_accounts} in both "
                "account_ids_graph_set_dicts and prescan_account_ids_errors."
            )
        )
    # graph prescan error accounts
    for account_id, errors in prescan_account_ids_errors.items():
        with logger.bind(account_id=account_id):
            unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                account_id=account_id, errors=errors
            )
            account_graph_set = GraphSet(
                name=self.graph_name,
                version=self.graph_version,
                start_time=now,
                end_time=now,
                resources=[unscanned_account_resource],
                errors=errors,
                stats=MultilevelCounter(),
            )
            account_graph_set.validate()
            output_artifact = self.artifact_writer.write_json(
                name=account_id, data=account_graph_set.to_dict()
            )
            logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
            api_call_stats = account_graph_set.stats.to_dict()
            scan_result_dicts.append(
                {
                    "account_id": account_id,
                    "output_artifact": output_artifact,
                    "errors": errors,
                    "api_call_stats": api_call_stats,
                }
            )
    # graph rest
    for account_id, graph_set_dicts in account_ids_graph_set_dicts.items():
        with logger.bind(account_id=account_id):
            # if there are any errors whatsoever we generate an empty graph with
            # errors only
            errors = []
            for graph_set_dict in graph_set_dicts:
                errors += graph_set_dict["errors"]
            if errors:
                unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                    account_id=account_id, errors=errors
                )
                account_graph_set = GraphSet(
                    name=self.graph_name,
                    version=self.graph_version,
                    start_time=now,
                    end_time=now,
                    resources=[unscanned_account_resource],
                    errors=errors,
                    # ENHANCEMENT: could technically get partial stats.
                    stats=MultilevelCounter(),
                )
                account_graph_set.validate()
            else:
                account_graph_set = GraphSet(
                    name=self.graph_name,
                    version=self.graph_version,
                    start_time=now,
                    end_time=now,
                    resources=[],
                    errors=[],
                    stats=MultilevelCounter(),
                )
                for graph_set_dict in graph_set_dicts:
                    graph_set = GraphSet.from_dict(graph_set_dict)
                    account_graph_set.merge(graph_set)
            output_artifact = self.artifact_writer.write_json(
                name=account_id, data=account_graph_set.to_dict()
            )
            logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
            api_call_stats = account_graph_set.stats.to_dict()
            scan_result_dicts.append(
                {
                    "account_id": account_id,
                    "output_artifact": output_artifact,
                    "errors": errors,
                    "api_call_stats": api_call_stats,
                }
            )
    return scan_result_dicts
def scan(cls, scan_accessor: Any) -> ResourceScanResult:
    """Return an empty scan result.

    There are no resources and no errors, and stats is a fresh empty
    MultilevelCounter. ``scan_accessor`` is not used.
    """
    no_stats = MultilevelCounter()
    return ResourceScanResult(resources=[], stats=no_stats, errors=[])
def scan(
    muxer: AWSScanMuxer,
    account_ids: List[str],
    regions: List[str],
    accessor: Accessor,
    artifact_writer: ArtifactWriter,
    artifact_reader: ArtifactReader,
    scan_sub_accounts: bool = False,
) -> ScanManifest:
    """Scan the given accounts with ``muxer`` and summarize the run in a ScanManifest.

    Args:
        muxer: runs the per-account scans and yields per-account manifests.
        account_ids: accounts to scan.
        regions: regions passed to ``build_account_scan_plans``.
        accessor: accessor passed to ``get_sub_account_ids`` and
            ``build_account_scan_plans``.
        artifact_writer: writes the combined "master" artifact.
        artifact_reader: reads each per-account artifact.
        scan_sub_accounts: when True, ``account_ids`` is first replaced by the
            result of ``get_sub_account_ids(account_ids, accessor)``.

    Classification of each account:
    * Listed in ``scanned_accounts`` if it produced artifacts and no errors.
    * Otherwise listed in ``unscanned_accounts``.
    * Errors are recorded per account id.

    All artifacts are merged into one GraphSet, which is written as "master".
    If no artifacts were produced, ``master_artifact``, ``start_time`` and
    ``end_time`` are None.

    Returns:
        ScanManifest describing the scan.
    """
    if scan_sub_accounts:
        account_ids = get_sub_account_ids(account_ids, accessor)
    account_scan_plans = build_account_scan_plans(
        accessor=accessor, account_ids=account_ids, regions=regions
    )
    logger = Logger()
    logger.info(event=AWSLogEvents.ScanAWSAccountsStart)
    account_scan_manifests = muxer.scan(account_scan_plans=account_scan_plans)
    # now combine account_scan_results and org_details to build a ScanManifest
    # Account ids are strings; errors are keyed by them.
    scanned_accounts: List[str] = []
    artifacts: List[str] = []
    errors: Dict[str, List[str]] = {}
    unscanned_accounts: List[str] = []
    stats = MultilevelCounter()

    for account_scan_manifest in account_scan_manifests:
        account_id = account_scan_manifest.account_id
        if account_scan_manifest.artifacts:
            artifacts += account_scan_manifest.artifacts
            if account_scan_manifest.errors:
                errors[account_id] = account_scan_manifest.errors
                unscanned_accounts.append(account_id)
            else:
                scanned_accounts.append(account_id)
        else:
            unscanned_accounts.append(account_id)
        account_stats = MultilevelCounter.from_dict(account_scan_manifest.api_call_stats)
        stats.merge(account_stats)

    graph_set = None
    for artifact_path in artifacts:
        artifact_dict = artifact_reader.read_artifact(artifact_path)
        artifact_graph_set = GraphSet.from_dict(artifact_dict)
        if graph_set is None:
            graph_set = artifact_graph_set
        else:
            graph_set.merge(artifact_graph_set)

    master_artifact_path = None
    start_time, end_time = None, None
    # Test against the None sentinel explicitly rather than relying on the
    # truthiness of a GraphSet object.
    if graph_set is not None:
        master_artifact_path = artifact_writer.write_artifact(
            name="master", data=graph_set.to_dict()
        )
        start_time = graph_set.start_time
        end_time = graph_set.end_time
    logger.info(event=AWSLogEvents.ScanAWSAccountsEnd)
    return ScanManifest(
        scanned_accounts=scanned_accounts,
        master_artifact=master_artifact_path,
        artifacts=artifacts,
        errors=errors,
        unscanned_accounts=unscanned_accounts,
        api_call_stats=stats.to_dict(),
        start_time=start_time,
        end_time=end_time,
    )
def run_scan(
    muxer: AWSScanMuxer,
    config: Config,
    artifact_writer: ArtifactWriter,
    artifact_reader: ArtifactReader,
) -> Tuple[ScanManifest, GraphSet]:
    """Run a scan described by ``config`` and write the master graph and manifest.

    Account ids come from ``config.scan.accounts``, expanded through
    ``get_sub_account_ids`` when ``config.scan.scan_sub_accounts`` is set.

    Every artifact reported by the muxer is read and merged into a single
    GraphSet. That GraphSet is written as "master", and a ScanManifest is
    written as "manifest".

    Classification of each account:
    * Listed in ``scanned_accounts`` if it produced artifacts without errors.
    * Otherwise listed in ``unscanned_accounts``.

    Returns:
        (scan manifest, merged GraphSet)

    Raises:
        Exception: if no artifacts at all were produced.
    """
    account_ids = (
        get_sub_account_ids(config.scan.accounts, config.access.accessor)
        if config.scan.scan_sub_accounts
        else config.scan.accounts
    )
    account_scan_plan = AccountScanPlan(
        account_ids=account_ids,
        regions=config.scan.regions,
        accessor=config.access.accessor,
    )
    logger = Logger()
    logger.info(event=AWSLogEvents.ScanAWSAccountsStart)
    # now combine account_scan_results and org_details to build a ScanManifest
    scanned_accounts: List[str] = []
    artifacts: List[str] = []
    errors: Dict[str, List[str]] = {}
    unscanned_accounts: List[str] = []
    stats = MultilevelCounter()
    merged_graph = None
    for manifest in muxer.scan(account_scan_plan=account_scan_plan):
        acct = manifest.account_id
        if not manifest.artifacts:
            unscanned_accounts.append(acct)
        else:
            for artifact_path in manifest.artifacts:
                artifacts.append(artifact_path)
                piece = GraphSet.from_dict(artifact_reader.read_json(artifact_path))
                if merged_graph is None:
                    merged_graph = piece
                else:
                    merged_graph.merge(piece)
            if manifest.errors:
                errors[acct] = manifest.errors
                unscanned_accounts.append(acct)
            else:
                scanned_accounts.append(acct)
        stats.merge(MultilevelCounter.from_dict(manifest.api_call_stats))
    if merged_graph is None:
        raise Exception("BUG: No graph_set generated.")
    master_artifact_path = artifact_writer.write_json(
        name="master", data=merged_graph.to_dict()
    )
    logger.info(event=AWSLogEvents.ScanAWSAccountsEnd)
    scan_manifest = ScanManifest(
        scanned_accounts=scanned_accounts,
        master_artifact=master_artifact_path,
        artifacts=artifacts,
        errors=errors,
        unscanned_accounts=unscanned_accounts,
        api_call_stats=stats.to_dict(),
        start_time=merged_graph.start_time,
        end_time=merged_graph.end_time,
    )
    artifact_writer.write_json("manifest", data=scan_manifest.to_dict())
    return scan_manifest, merged_graph
class TestScanAccessor:
    """Minimal stand-in (presumably for a scan accessor in tests) exposing only
    an ``api_call_stats`` counter."""

    # NOTE(review): this MultilevelCounter is created once, when the class is
    # defined, so the class and all of its instances share the same counter.
    # Confirm that sharing stats across uses is intended.
    api_call_stats: MultilevelCounter = MultilevelCounter()
def test_merge_does_not_update_other(self):
    """After ``merge(other)``, ``other`` still holds only its own counts."""
    receiver = MultilevelCounter()
    receiver.increment("foo", "boo", "goo")
    donor = MultilevelCounter()
    donor.increment("boo", "goo", "moo")
    receiver.merge(donor)
    # The donor must still contain just its boo -> goo -> moo path.
    donor_expected = {
        "count": 1,
        "boo": {"count": 1, "goo": {"count": 1, "moo": {"count": 1}}},
    }
    self.assertDictEqual(donor_expected, donor.to_dict())