def test(self):
        """Check the resource produced by merging one unscanned-account resource.

        The merged resource must keep the requested id and the
        "aws:unscanned-account" type, and expose exactly two simple links:
        the account id and a combined error message.
        """
        scan_errors = ["foo", "boo"]
        unscanned = UnscannedAccountResourceSpec.create_resource(
            account_id="012345678901", errors=scan_errors)
        merged = ResourceSpec.merge_resources("foo", [unscanned])

        self.assertEqual(merged.resource_id, "foo")
        self.assertEqual(merged.type, "aws:unscanned-account")

        links = merged.link_collection.simple_links
        self.assertEqual(len(links), 2)

        # First link carries the account id verbatim.
        expected_account_link = SimpleLink(pred="account_id",
                                           obj="012345678901")
        self.assertEqual(links[0], expected_account_link)

        # Second link carries the newline-joined errors followed by a suffix.
        error_link = links[1]
        self.assertEqual(error_link.pred, "error")
        self.assertTrue(error_link.obj.startswith("foo\nboo - "))
Example #2
0
    def test(self):
        """Check the dict form of a merged unscanned-account resource.

        The serialized resource must report the "aws:unscanned-account" type
        and contain exactly two simple links: the account id and a combined
        error message.
        """
        scan_errors = ["foo", "boo"]
        unscanned = UnscannedAccountResourceSpec.create_resource(
            account_id="012345678901", errors=scan_errors)
        merged = ResourceSpec.merge_resources("foo", [unscanned])

        serialized = merged.to_dict()
        self.assertEqual(serialized["type"], "aws:unscanned-account")

        links = serialized["links"]
        self.assertEqual(len(links), 2)

        # First link is the account id, serialized as a simple link.
        expected_account_link = {
            "pred": "account_id",
            "obj": "012345678901",
            "type": "simple",
        }
        self.assertEqual(links[0], expected_account_link)

        # Second link holds the newline-joined errors followed by a suffix.
        error_link = links[1]
        self.assertEqual(error_link["pred"], "error")
        self.assertEqual(error_link["type"], "simple")
        self.assertTrue(error_link["obj"].startswith("foo\nboo - "))
Example #3
0
    def scan(self) -> Dict[str, Any]:
        """Scan an account and return a dict containing keys:

            * account_id: str
            * output_artifact: str
            * api_call_stats: Dict[str, Any]
            * errors: List[str]

        If errors is non-empty the results are incomplete for this account.
        output_artifact is a pointer to the actual scan data - either on the local fs or in s3.

        To scan an account we create a set of GraphSpecs, one for each region.  Any ACCOUNT
        level granularity resources are only scanned in a single region (e.g. IAM Users)

        Any exception raised while planning or running the scan is logged, appended to
        errors, and the account is written out as a graph containing a single
        unscanned-account resource instead.

        Returns:
            Dict of scan result, see above for details.

        Raises:
            Exception: anything raised while building/validating the fallback graph or
                while writing the artifact is not caught here.
        """
        logger = Logger()
        with logger.bind(account_id=self.account_id):
            logger.info(event=AWSLogEvents.ScanAWSAccountStart)
            stats = MultilevelCounter()
            errors: List[str] = []
            now = int(time.time())
            try:
                account_graph_set = GraphSet(
                    name=self.graph_name,
                    version=self.graph_version,
                    start_time=now,
                    end_time=now,
                    resources=[],
                    errors=[],
                    stats=stats,
                )
                # sanity check: the session must really belong to the expected account
                session = self.get_session()
                sts_client = session.client("sts")
                sts_account_id = sts_client.get_caller_identity()["Account"]
                if sts_account_id != self.account_id:
                    raise ValueError(
                        f"BUG: sts detected account_id {sts_account_id} != {self.account_id}"
                    )
                if self.regions:
                    scan_regions = tuple(self.regions)
                else:
                    scan_regions = get_all_enabled_regions(session=session)
                # build graph specs.
                # build a dict of regions -> services -> List[AWSResourceSpec]
                regions_services_resource_spec_classes: DefaultDict[
                    str,
                    DefaultDict[str,
                                List[Type[AWSResourceSpec]]]] = defaultdict(
                                    lambda: defaultdict(list))
                resource_spec_class: Type[AWSResourceSpec]
                for resource_spec_class in self.resource_spec_classes:
                    client_name = resource_spec_class.get_client_name()
                    resource_class_scan_granularity = resource_spec_class.scan_granularity
                    if resource_class_scan_granularity == ScanGranularity.ACCOUNT:
                        # account-level resources are scanned once, in the first region
                        regions_services_resource_spec_classes[scan_regions[
                            0]][client_name].append(resource_spec_class)
                    elif resource_class_scan_granularity == ScanGranularity.REGION:
                        for region in scan_regions:
                            regions_services_resource_spec_classes[region][
                                client_name].append(resource_spec_class)
                    else:
                        raise NotImplementedError(
                            f"ScanGranularity {resource_class_scan_granularity} not implemented"
                        )
                with ThreadPoolExecutor(
                        max_workers=self.max_svc_threads) as executor:
                    futures = []
                    # schedule one scan per (region, service) pair
                    for (
                            region,
                            services_resource_spec_classes,
                    ) in regions_services_resource_spec_classes.items():
                        for (
                                service,
                                resource_spec_classes,
                        ) in services_resource_spec_classes.items():
                            region_session = self.get_session(region=region)
                            region_creds = region_session.get_credentials()
                            scan_future = schedule_scan_services(
                                executor=executor,
                                graph_name=self.graph_name,
                                graph_version=self.graph_version,
                                account_id=self.account_id,
                                region=region,
                                service=service,
                                access_key=region_creds.access_key,
                                secret_key=region_creds.secret_key,
                                token=region_creds.token,
                                resource_spec_classes=tuple(
                                    resource_spec_classes),
                            )
                            futures.append(scan_future)
                    # fold each finished scan into the account-level graph set
                    for future in as_completed(futures):
                        graph_set_dict = future.result()
                        graph_set = GraphSet.from_dict(graph_set_dict)
                        errors += graph_set.errors
                        account_graph_set.merge(graph_set)
                    account_graph_set.validate()
            except Exception as ex:
                error_str = str(ex)
                trace_back = traceback.format_exc()
                logger.error(event=AWSLogEvents.ScanAWSAccountError,
                             error=error_str,
                             trace_back=trace_back)
                errors.append(" : ".join((error_str, trace_back)))
                # replace whatever was gathered with a single unscanned-account resource
                unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                    account_id=self.account_id, errors=errors)
                account_graph_set = GraphSet(
                    name=self.graph_name,
                    version=self.graph_version,
                    start_time=now,
                    end_time=now,
                    resources=[unscanned_account_resource],
                    errors=errors,
                    stats=stats,
                )
                account_graph_set.validate()
            # Write the artifact and emit ScanAWSAccountEnd inside the same
            # logger.bind context as ScanAWSAccountStart / ScanAWSAccountError.
            output_artifact = self.artifact_writer.write_artifact(
                name=self.account_id, data=account_graph_set.to_dict())
            logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
            api_call_stats = account_graph_set.stats.to_dict()
        return {
            "account_id": self.account_id,
            "output_artifact": output_artifact,
            "errors": errors,
            "api_call_stats": api_call_stats,
        }
Example #4
0
 def scan(self) -> List[Dict[str, Any]]:
     """Scan the accounts of the scan plan and write one artifact per result.

     Accounts are visited in random order.  For each one a session is
     obtained and checked against STS, the regions to scan are determined,
     and one scan is scheduled per (region, service) pair on a thread pool.

     Region selection per resource spec class:

         * ACCOUNT granularity: the first entry of the class's
           region_whitelist if it has one, otherwise a region picked at random
           from self.preferred_account_scan_regions.
         * REGION granularity: the account's scan regions restricted to the
           class's region_whitelist; if that restriction is empty the whole
           whitelist is used.  Without a whitelist, all scan regions.

     An account whose preparation raised is graphed as a single
     unscanned-account resource carrying the error text.  An account for
     which any scan reported errors is graphed the same way; otherwise its
     scan results are merged into one graph set.

     Returns:
         One dict per account that either failed pre-scan or produced scan
         results, with keys "account_id", "output_artifact", "errors" and
         "api_call_stats".

     Raises:
         Exception: if an account shows up both as a prescan failure and
             with scan results, or if a scheduled scan itself raised.
     """
     logger = Logger()
     scan_result_dicts = []
     now = int(time.time())
     prescan_account_ids_errors: DefaultDict[str,
                                             List[str]] = defaultdict(list)
     futures = []
     with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
         # random.sample needs a real sequence, so copy into a list first.
         account_ids = list(self.account_scan_plan.account_ids)
         shuffled_account_ids = random.sample(account_ids,
                                              k=len(account_ids))
         for account_id in shuffled_account_ids:
             with logger.bind(account_id=account_id):
                 logger.info(event=AWSLogEvents.ScanAWSAccountStart)
                 try:
                     session = self.account_scan_plan.accessor.get_session(
                         account_id=account_id)
                     # sanity check
                     sts_client = session.client("sts")
                     sts_account_id = sts_client.get_caller_identity(
                     )["Account"]
                     if sts_account_id != account_id:
                         raise ValueError(
                             f"BUG: sts detected account_id {sts_account_id} != {account_id}"
                         )
                     if self.account_scan_plan.regions:
                         scan_regions = tuple(
                             self.account_scan_plan.regions)
                     else:
                         scan_regions = get_all_enabled_regions(
                             session=session)
                     account_gran_scan_region = random.choice(
                         self.preferred_account_scan_regions)
                     # build a dict of regions -> services -> List[AWSResourceSpec]
                     regions_services_resource_spec_classes: DefaultDict[
                         str, DefaultDict[
                             str,
                             List[Type[AWSResourceSpec]]]] = defaultdict(
                                 lambda: defaultdict(list))
                     resource_spec_class: Type[AWSResourceSpec]
                     for resource_spec_class in self.resource_spec_classes:
                         client_name = resource_spec_class.get_client_name()
                         if resource_spec_class.scan_granularity == ScanGranularity.ACCOUNT:
                             if resource_spec_class.region_whitelist:
                                 account_resource_scan_region = resource_spec_class.region_whitelist[
                                     0]
                             else:
                                 account_resource_scan_region = account_gran_scan_region
                             regions_services_resource_spec_classes[
                                 account_resource_scan_region][
                                     client_name].append(
                                         resource_spec_class)
                         elif resource_spec_class.scan_granularity == ScanGranularity.REGION:
                             if resource_spec_class.region_whitelist:
                                 resource_scan_regions = tuple(
                                     region for region in scan_regions
                                     if region in
                                     resource_spec_class.region_whitelist)
                                 # nothing in common: fall back to the whitelist itself
                                 if not resource_scan_regions:
                                     resource_scan_regions = resource_spec_class.region_whitelist
                             else:
                                 resource_scan_regions = scan_regions
                             for region in resource_scan_regions:
                                 regions_services_resource_spec_classes[
                                     region][client_name].append(
                                         resource_spec_class)
                         else:
                             raise NotImplementedError(
                                 f"ScanGranularity {resource_spec_class.scan_granularity} unimplemented"
                             )
                     # Build and submit ScanUnits
                     # dict.items() is a view, not a sequence; random.sample
                     # rejects it on Python 3.11+, hence the list() copy.
                     shuffled_regions_services_resource_spec_classes = random.sample(
                         list(regions_services_resource_spec_classes.items()),
                         len(regions_services_resource_spec_classes),
                     )
                     for (
                             region,
                             services_resource_spec_classes,
                     ) in shuffled_regions_services_resource_spec_classes:
                         region_session = self.account_scan_plan.accessor.get_session(
                             account_id=account_id, region_name=region)
                         region_creds = region_session.get_credentials()
                         shuffled_services_resource_spec_classes = random.sample(
                             list(services_resource_spec_classes.items()),
                             len(services_resource_spec_classes),
                         )
                         for (
                                 service,
                                 svc_resource_spec_classes,
                         ) in shuffled_services_resource_spec_classes:
                             future = schedule_scan(
                                 executor=executor,
                                 graph_name=self.graph_name,
                                 graph_version=self.graph_version,
                                 account_id=account_id,
                                 region_name=region,
                                 service=service,
                                 access_key=region_creds.access_key,
                                 secret_key=region_creds.secret_key,
                                 token=region_creds.token,
                                 resource_spec_classes=tuple(
                                     svc_resource_spec_classes),
                             )
                             futures.append(future)
                 except Exception as ex:
                     error_str = str(ex)
                     trace_back = traceback.format_exc()
                     logger.error(
                         event=AWSLogEvents.ScanAWSAccountError,
                         error=error_str,
                         trace_back=trace_back,
                     )
                     prescan_account_ids_errors[account_id].append(
                         f"{error_str}\n{trace_back}")
     # group the per-(region, service) results by account
     account_ids_graph_set_dicts: Dict[str,
                                       List[Dict[str,
                                                 Any]]] = defaultdict(list)
     for future in as_completed(futures):
         account_id, graph_set_dict = future.result()
         account_ids_graph_set_dicts[account_id].append(graph_set_dict)
     # first make sure no account id appears both in account_ids_graph_set_dicts
     # and prescan_account_ids_errors - this should never happen
     doubled_accounts = (account_ids_graph_set_dicts.keys()
                         & prescan_account_ids_errors.keys())
     if doubled_accounts:
         raise Exception((
             f"BUG: Account(s) {doubled_accounts} in both "
             "account_ids_graph_set_dicts and prescan_account_ids_errors."))
     # graph prescan error accounts
     for account_id, errors in prescan_account_ids_errors.items():
         with logger.bind(account_id=account_id):
             unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                 account_id=account_id, errors=errors)
             account_graph_set = GraphSet(
                 name=self.graph_name,
                 version=self.graph_version,
                 start_time=now,
                 end_time=now,
                 resources=[unscanned_account_resource],
                 errors=errors,
                 stats=MultilevelCounter(),
             )
             account_graph_set.validate()
             output_artifact = self.artifact_writer.write_json(
                 name=account_id, data=account_graph_set.to_dict())
             logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
             api_call_stats = account_graph_set.stats.to_dict()
             scan_result_dicts.append({
                 "account_id": account_id,
                 "output_artifact": output_artifact,
                 "errors": errors,
                 "api_call_stats": api_call_stats,
             })
     # graph rest
     for account_id, graph_set_dicts in account_ids_graph_set_dicts.items():
         with logger.bind(account_id=account_id):
             # if there are any errors whatsoever we generate an empty graph with
             # errors only
             errors = []
             for graph_set_dict in graph_set_dicts:
                 errors += graph_set_dict["errors"]
             if errors:
                 unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                     account_id=account_id, errors=errors)
                 account_graph_set = GraphSet(
                     name=self.graph_name,
                     version=self.graph_version,
                     start_time=now,
                     end_time=now,
                     resources=[unscanned_account_resource],
                     errors=errors,
                     stats=MultilevelCounter(
                     ),  # ENHANCEMENT: could technically get partial stats.
                 )
                 account_graph_set.validate()
             else:
                 account_graph_set = GraphSet(
                     name=self.graph_name,
                     version=self.graph_version,
                     start_time=now,
                     end_time=now,
                     resources=[],
                     errors=[],
                     stats=MultilevelCounter(),
                 )
                 for graph_set_dict in graph_set_dicts:
                     graph_set = GraphSet.from_dict(graph_set_dict)
                     account_graph_set.merge(graph_set)
             output_artifact = self.artifact_writer.write_json(
                 name=account_id, data=account_graph_set.to_dict())
             logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
             api_call_stats = account_graph_set.stats.to_dict()
             scan_result_dicts.append({
                 "account_id": account_id,
                 "output_artifact": output_artifact,
                 "errors": errors,
                 "api_call_stats": api_call_stats,
             })
     return scan_result_dicts
Example #5
0
 def scan(self) -> AccountScanResult:
     """Scan the plan's account and write the resulting graph as one artifact.

     A session is obtained and checked against STS, the regions to scan are
     determined, and each resource spec class is mapped to its regions via
     the plan's aws_resource_region_mapping_repo.  For every
     (region, service) pair, each class flagged parallel_scan is scheduled
     as its own scan and the remaining classes are scheduled together as
     one scan.

     If preparation raised, or if any scan reported errors, the artifact
     contains a single unscanned-account resource carrying the errors;
     otherwise it contains the graph set built from all scan results.

     Returns:
         AccountScanResult with the account id, the written artifact and
         the collected errors (empty on full success).

     Raises:
         Exception: whatever a scheduled scan's future re-raises, or
             anything raised while building or writing the final graph set.
     """
     logger = Logger()
     now = int(time.time())
     prescan_errors: List[str] = []
     futures: List[Future] = []
     account_id = self.account_scan_plan.account_id
     with logger.bind(account_id=account_id):
         with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
             logger.info(event=AWSLogEvents.ScanAWSAccountStart)
             try:
                 session = self.account_scan_plan.accessor.get_session(
                     account_id=account_id)
                 # sanity check
                 sts_client = session.client("sts")
                 sts_account_id = sts_client.get_caller_identity(
                 )["Account"]
                 if sts_account_id != account_id:
                     raise ValueError(
                         f"BUG: sts detected account_id {sts_account_id} != {account_id}"
                     )
                 if self.account_scan_plan.regions:
                     account_scan_regions = tuple(
                         self.account_scan_plan.regions)
                 else:
                     account_scan_regions = get_all_enabled_regions(
                         session=session)
                 # build a dict of regions -> services -> List[AWSResourceSpec]
                 regions_services_resource_spec_classes: DefaultDict[
                     str, DefaultDict[
                         str, List[Type[AWSResourceSpec]]]] = defaultdict(
                             lambda: defaultdict(list))
                 for resource_spec_class in self.resource_spec_classes:
                     resource_regions = self.account_scan_plan.aws_resource_region_mapping_repo.get_regions(
                         resource_spec_class=resource_spec_class,
                         region_whitelist=account_scan_regions,
                     )
                     for region in resource_regions:
                         regions_services_resource_spec_classes[region][
                             resource_spec_class.service_name].append(
                                 resource_spec_class)
                 # Build and submit ScanUnits
                 # dict.items() is a view, not a sequence; random.sample
                 # rejects it on Python 3.11+, hence the list() copy.
                 shuffled_regions_services_resource_spec_classes = random.sample(
                     list(regions_services_resource_spec_classes.items()),
                     len(regions_services_resource_spec_classes),
                 )
                 for (
                         region,
                         services_resource_spec_classes,
                 ) in shuffled_regions_services_resource_spec_classes:
                     region_session = self.account_scan_plan.accessor.get_session(
                         account_id=account_id, region_name=region)
                     region_creds = region_session.get_credentials()
                     shuffled_services_resource_spec_classes = random.sample(
                         list(services_resource_spec_classes.items()),
                         len(services_resource_spec_classes),
                     )
                     for (
                             service,
                             svc_resource_spec_classes,
                     ) in shuffled_services_resource_spec_classes:
                         # split the service's classes by their parallel_scan flag
                         parallel_svc_resource_spec_classes = [
                             svc_resource_spec_class
                             for svc_resource_spec_class in
                             svc_resource_spec_classes
                             if svc_resource_spec_class.parallel_scan
                         ]
                         serial_svc_resource_spec_classes = [
                             svc_resource_spec_class
                             for svc_resource_spec_class in
                             svc_resource_spec_classes
                             if not svc_resource_spec_class.parallel_scan
                         ]
                         # one scan per parallel class
                         for (parallel_svc_resource_spec_class
                              ) in parallel_svc_resource_spec_classes:
                             parallel_future = schedule_scan(
                                 executor=executor,
                                 graph_name=self.graph_name,
                                 graph_version=self.graph_version,
                                 account_id=account_id,
                                 region_name=region,
                                 service=service,
                                 access_key=region_creds.access_key,
                                 secret_key=region_creds.secret_key,
                                 token=region_creds.token,
                                 resource_spec_classes=(
                                     parallel_svc_resource_spec_class, ),
                             )
                             futures.append(parallel_future)
                         # a single scan covering all remaining classes
                         serial_future = schedule_scan(
                             executor=executor,
                             graph_name=self.graph_name,
                             graph_version=self.graph_version,
                             account_id=account_id,
                             region_name=region,
                             service=service,
                             access_key=region_creds.access_key,
                             secret_key=region_creds.secret_key,
                             token=region_creds.token,
                             resource_spec_classes=tuple(
                                 serial_svc_resource_spec_classes),
                         )
                         futures.append(serial_future)
             except Exception as ex:
                 error_str = str(ex)
                 trace_back = traceback.format_exc()
                 logger.error(
                     event=AWSLogEvents.ScanAWSAccountError,
                     error=error_str,
                     trace_back=trace_back,
                 )
                 prescan_errors.append(f"{error_str}\n{trace_back}")
         graph_sets: List[GraphSet] = []
         for future in as_completed(futures):
             graph_set = future.result()
             graph_sets.append(graph_set)
         # if there was a prescan error graph it and return the result
         if prescan_errors:
             unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                 account_id=account_id, errors=prescan_errors)
             account_graph_set = GraphSet(
                 name=self.graph_name,
                 version=self.graph_version,
                 start_time=now,
                 end_time=now,
                 resources=[unscanned_account_resource],
                 errors=prescan_errors,
             )
             output_artifact = self.artifact_writer.write_json(
                 name=account_id,
                 data=account_graph_set,
             )
             logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
             return AccountScanResult(
                 account_id=account_id,
                 artifacts=[output_artifact],
                 errors=prescan_errors,
             )
         # if there are any errors whatsoever we generate an empty graph with errors only
         errors = []
         for graph_set in graph_sets:
             errors += graph_set.errors
         if errors:
             unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                 account_id=account_id, errors=errors)
             account_graph_set = GraphSet(
                 name=self.graph_name,
                 version=self.graph_version,
                 start_time=now,
                 end_time=now,
                 resources=[unscanned_account_resource],
                 errors=errors,
             )
         else:
             account_graph_set = GraphSet.from_graph_sets(graph_sets)
         output_artifact = self.artifact_writer.write_json(
             name=account_id,
             data=account_graph_set,
         )
         logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
         return AccountScanResult(
             account_id=account_id,
             artifacts=[output_artifact],
             errors=errors,
         )