Exemplo n.º 1
0
def on_request_created(
    api_call_stats: MultilevelCounter,
    account_id: str,
    region_name: str,
    service_name: str,
    readonly: bool,
    **kwargs: Any,
) -> None:
    """Called when a boto3 request is created. This handles api call statistics tracking.

    Intended to be used as a boto event callback: the operation name is parsed
    from the third dot-separated component of ``kwargs["event_name"]`` and the
    call is counted under (account_id, region_name, service_name, operation).

    Args:
        api_call_stats: MultilevelCounter to increment
        account_id: request account id
        region_name: request region
        service_name: request service
        readonly: if True only allow readonly calls
        kwargs: kwargs which are passed through by the boto event callback. Must
            contain ``event_name`` (exactly three dot-separated parts) and, when
            ``readonly`` is True, ``operation_name``.

    Raises:
        Exception: if ``readonly`` is True and the operation name does not match
            the permitted-operations regex.
    """
    _, _, operation_name = kwargs["event_name"].split(".")
    if readonly:
        # Report the exact name that was tested so the error message can never
        # disagree with the permission decision.
        checked_operation_name = kwargs["operation_name"]
        if not _PERMITTED_OPERATION_NAMES_RE.search(checked_operation_name):
            raise Exception(
                f"Operation name {checked_operation_name} did not match {_PERMITTED_OPERATION_NAMES_STR}"
            )
    api_call_stats.increment(account_id, region_name, service_name,
                             operation_name)
Exemplo n.º 2
0
    def scan(self) -> GraphSet:
        """Scan each resource spec class of this GraphSpec and bundle the output.

        Every class in ``self.resource_spec_classes`` is scanned in turn using
        ``self.scan_accessor``. The resources, errors and stats each class
        reports are accumulated into one GraphSet, stamped with the wall-clock
        start and end of the whole scan (whole seconds).

        Returns:
            GraphSet representing results of scanning this GraphSpec's resource_spec_classes.
        """
        all_resources: List[Resource] = []
        all_errors: List[str] = []
        combined_stats = MultilevelCounter()
        scan_started = int(time.time())
        logger = Logger()
        for spec_class in self.resource_spec_classes:
            bound_type = str(spec_class.type_name)
            with logger.bind(resource_type=bound_type):
                logger.debug(event=LogEvent.ScanResourceTypeStart)
                result = spec_class.scan(scan_accessor=self.scan_accessor)
                all_resources.extend(result.resources)
                all_errors.extend(result.errors)
                combined_stats.merge(result.stats)
                logger.debug(event=LogEvent.ScanResourceTypeEnd)
        scan_finished = int(time.time())
        return GraphSet(
            name=self.name,
            version=self.version,
            start_time=scan_started,
            end_time=scan_finished,
            resources=all_resources,
            errors=all_errors,
            stats=combined_stats,
        )
Exemplo n.º 3
0
 def test_increment(self):
     """A single three-level increment sets count 1 on the root and on every level."""
     counter = MultilevelCounter()
     counter.increment("foo", "boo", "goo")
     goo_level = {"count": 1}
     boo_level = {"count": 1, "goo": goo_level}
     foo_level = {"count": 1, "boo": boo_level}
     expected = {"count": 1, "foo": foo_level}
     self.assertDictEqual(expected, counter.to_dict())
Exemplo n.º 4
0
    def from_dict(cls: Type["GraphSet"], data: Dict[str, Any]) -> "GraphSet":
        """Build a GraphSet from its dict form.

        Args:
            data: dict with keys ``name``, ``start_time``, ``end_time``,
                ``version``, ``errors``, ``stats`` and ``resources`` (a mapping of
                resource id -> resource data).

        Returns:
            GraphSet object
        """
        graph_name = data["name"]
        graph_start = data["start_time"]
        graph_end = data["end_time"]
        graph_version = data["version"]
        graph_errors = data["errors"]
        graph_stats = MultilevelCounter.from_dict(data["stats"])
        graph_resources: List[Resource] = [
            Resource.from_dict(res_id, res_data)
            for res_id, res_data in data["resources"].items()
        ]
        return cls(
            name=graph_name,
            version=graph_version,
            start_time=graph_start,
            end_time=graph_end,
            resources=graph_resources,
            errors=graph_errors,
            stats=graph_stats,
        )
Exemplo n.º 5
0
 def scan(cls: Type["TestResourceSpecB"],
          scan_accessor: Any) -> ResourceScanResult:
     """Return a fixed result with two link-less resources ("abc", "def") of this type."""
     fixed_ids = ("abc", "def")
     found = [
         Resource(resource_id=res_id, type_name=cls.type_name, links=[])
         for res_id in fixed_ids
     ]
     return ResourceScanResult(resources=found, stats=MultilevelCounter(), errors=[])
Exemplo n.º 6
0
 def __init__(self,
              session: boto3.Session,
              account_id: str,
              region_name: str,
              readonly: bool = True):
     """Store the session and scan context, starting with empty stats and client cache.

     Args:
         session: boto3 session used by this accessor.
         account_id: account being accessed.
         region_name: region being accessed (stored as ``self.region``).
         readonly: stored as ``self.readonly``; defaults to True.
     """
     (self.session, self.account_id, self.region) = (session, account_id, region_name)
     self.api_call_stats = MultilevelCounter()
     self.client_cache: Dict[str, Any] = dict()
     self.readonly = readonly
Exemplo n.º 7
0
 def test_invalid_diff_versions(self):
     """GraphSets that differ only in version must refuse to merge."""

     def make_graph_set(version, start, end):
         # Fresh lists/counter per call so the two sets share no mutable state.
         return GraphSet(
             name="graph-1",
             version=version,
             start_time=start,
             end_time=end,
             resources=[],
             errors=[],
             stats=MultilevelCounter(),
         )

     older = make_graph_set("1", 10, 20)
     newer = make_graph_set("2", 15, 25)
     with self.assertRaises(UnmergableGraphSetsException):
         older.merge(newer)
Exemplo n.º 8
0
 def test_unknown_type_name(self):
     """Constructing a GraphSet must fail for a resource type with no spec class.

     Presumably ``test:c`` is the unregistered type here.
     """
     resources = [
         Resource(resource_id="xyz", type_name=type_name)
         for type_name in ("test:a", "test:c")
     ]
     with self.assertRaises(ResourceSpecClassNotFoundException):
         GraphSet(name="test-name", version="1", start_time=1234,
                  end_time=4567, resources=resources, errors=[],
                  stats=MultilevelCounter())
Exemplo n.º 9
0
 def test_invalid_resources_dupes_same_class_conflicting_types_no_allow_clobber(
         self):
     """A GraphSet must reject one resource id that appears with two type names."""
     duplicate_id = "123"
     resources = [
         Resource(resource_id=duplicate_id, type_name=type_name)
         for type_name in ("test:a", "test:b")
     ]
     with self.assertRaises(UnmergableDuplicateResourceIdsFoundException):
         GraphSet(name="test-name", version="1", start_time=1234,
                  end_time=4567, resources=resources, errors=[],
                  stats=MultilevelCounter())
Exemplo n.º 10
0
 def test_orphaned_ref(self):
     """validate() must flag a resource link pointing at an id not in the graph ("456")."""
     graph_set = GraphSet(
         name="test-name",
         version="1",
         start_time=1234,
         end_time=4567,
         resources=[
             Resource(resource_id="123", type_name="test:a",
                      links=[SimpleLink(pred="has-foo", obj="goo")]),
             Resource(resource_id="abc", type_name="test:b",
                      links=[ResourceLinkLink(pred="has-a", obj="456")]),
         ],
         errors=["test err 1", "test err 2"],
         stats=MultilevelCounter(),
     )
     with self.assertRaises(GraphSetOrphanedReferencesException):
         graph_set.validate()
Exemplo n.º 11
0
 def setUp(self):
     """Build the four-resource GraphSet (two test:a, two test:b) stored on self.graph_set."""
     resources = [
         Resource(resource_id="123", type_name="test:a",
                  links=[SimpleLink(pred="has-foo", obj="goo")]),
         Resource(resource_id="456", type_name="test:a"),
         Resource(resource_id="abc", type_name="test:b",
                  links=[ResourceLinkLink(pred="has-a", obj="123")]),
         Resource(resource_id="def", type_name="test:b",
                  links=[SimpleLink(pred="name", obj="sue")]),
     ]
     self.graph_set = GraphSet(name="test-name",
                               version="1",
                               start_time=1234,
                               end_time=4567,
                               resources=resources,
                               errors=["test err 1", "test err 2"],
                               stats=MultilevelCounter())
Exemplo n.º 12
0
 def test_from_dict(self):
     """from_dict followed by to_dict must round-trip a two-branch counter tree."""
     foo_branch = {"count": 1, "boo": {"count": 1, "goo": {"count": 1}}}
     boo_branch = {"count": 1, "goo": {"count": 1, "moo": {"count": 1}}}
     data = {"count": 2, "foo": foo_branch, "boo": boo_branch}
     round_tripped = MultilevelCounter.from_dict(data).to_dict()
     self.assertDictEqual(round_tripped, data)
Exemplo n.º 13
0
    def test_valid_merge(self):
        """Merging two compatible GraphSets combines time range, errors and resources."""
        graph_set_1 = GraphSet(
            name="graph-1",
            version="1",
            start_time=10,
            end_time=20,
            resources=[
                Resource(resource_id="123", type_name="test:a",
                         links=[SimpleLink(pred="has-foo", obj="goo")]),
                Resource(resource_id="456", type_name="test:a"),
            ],
            errors=["errora1", "errora2"],
            stats=MultilevelCounter(),
        )
        graph_set_2 = GraphSet(
            name="graph-1",
            version="1",
            start_time=15,
            end_time=25,
            resources=[
                Resource(resource_id="abc", type_name="test:b",
                         links=[ResourceLinkLink(pred="has-a", obj="123")]),
                Resource(resource_id="def", type_name="test:b",
                         links=[SimpleLink(pred="name", obj="sue")]),
            ],
            errors=["errorb1", "errorb2"],
            stats=MultilevelCounter(),
        )
        graph_set_1.merge(graph_set_2)

        self.assertEqual(graph_set_1.name, "graph-1")
        self.assertEqual(graph_set_1.version, "1")
        # Merged range spans the earliest start (10) and the latest end (25).
        self.assertEqual(graph_set_1.start_time, 10)
        self.assertEqual(graph_set_1.end_time, 25)
        self.assertCountEqual(graph_set_1.errors,
                              ["errora1", "errora2", "errorb1", "errorb2"])
        expected_resource_dicts = [
            {"type": "test:a",
             "links": [{"pred": "has-foo", "obj": "goo", "type": "simple"}]},
            {"type": "test:a"},
            {"type": "test:b",
             "links": [{"pred": "has-a", "obj": "123", "type": "resource_link"}]},
            {"type": "test:b",
             "links": [{"pred": "name", "obj": "sue", "type": "simple"}]},
        ]
        self.assertCountEqual(
            expected_resource_dicts,
            [res.to_dict() for res in graph_set_1.resources],
        )
Exemplo n.º 14
0
    def scan(self) -> Dict[str, Any]:
        """Scan an account and return a dict containing keys:

            * account_id: str
            * output_artifact: str
            * api_call_stats: Dict[str, Any]
            * errors: List[str]

        If errors is non-empty the results are incomplete for this account.
        output_artifact is a pointer to the actual scan data - either on the local fs or in s3.

        To scan an account we create a set of GraphSpecs, one for each region.  Any ACCOUNT
        level granularity resources are only scanned in a single region (e.g. IAM Users)

        Any exception raised while scanning is logged and appended to ``errors``; in
        that case the written artifact contains only an unscanned-account resource.

        Returns:
            Dict of scan result, see above for details.
        """
        logger = Logger()
        with logger.bind(account_id=self.account_id):
            logger.info(event=AWSLogEvents.ScanAWSAccountStart)
            stats = MultilevelCounter()
            errors: List[str] = []
            now = int(time.time())
            try:
                account_graph_set = GraphSet(
                    name=self.graph_name,
                    version=self.graph_version,
                    start_time=now,
                    end_time=now,
                    resources=[],
                    errors=[],
                    stats=stats,
                )
                # sanity check: the session must really belong to the requested account
                session = self.get_session()
                sts_client = session.client("sts")
                sts_account_id = sts_client.get_caller_identity()["Account"]
                if sts_account_id != self.account_id:
                    raise ValueError(
                        f"BUG: sts detected account_id {sts_account_id} != {self.account_id}"
                    )
                if self.regions:
                    scan_regions = tuple(self.regions)
                else:
                    scan_regions = get_all_enabled_regions(session=session)
                # build graph specs.
                # build a dict of regions -> services -> List[AWSResourceSpec]
                regions_services_resource_spec_classes: DefaultDict[
                    str, DefaultDict[str, List[Type[AWSResourceSpec]]]
                ] = defaultdict(lambda: defaultdict(list))
                resource_spec_class: Type[AWSResourceSpec]
                for resource_spec_class in self.resource_spec_classes:
                    client_name = resource_spec_class.get_client_name()
                    resource_class_scan_granularity = resource_spec_class.scan_granularity
                    if resource_class_scan_granularity == ScanGranularity.ACCOUNT:
                        # account-level resources are scanned in the first region only
                        regions_services_resource_spec_classes[scan_regions[0]][
                            client_name
                        ].append(resource_spec_class)
                    elif resource_class_scan_granularity == ScanGranularity.REGION:
                        for region in scan_regions:
                            regions_services_resource_spec_classes[region][
                                client_name
                            ].append(resource_spec_class)
                    else:
                        raise NotImplementedError(
                            f"ScanGranularity {resource_class_scan_granularity} not implemented"
                        )
                with ThreadPoolExecutor(max_workers=self.max_svc_threads) as executor:
                    futures = []
                    for (
                        region,
                        services_resource_spec_classes,
                    ) in regions_services_resource_spec_classes.items():
                        # The session and credentials depend only on the region, so
                        # fetch them once per region instead of once per service.
                        region_session = self.get_session(region=region)
                        region_creds = region_session.get_credentials()
                        for (
                            service,
                            resource_spec_classes,
                        ) in services_resource_spec_classes.items():
                            scan_future = schedule_scan_services(
                                executor=executor,
                                graph_name=self.graph_name,
                                graph_version=self.graph_version,
                                account_id=self.account_id,
                                region=region,
                                service=service,
                                access_key=region_creds.access_key,
                                secret_key=region_creds.secret_key,
                                token=region_creds.token,
                                resource_spec_classes=tuple(resource_spec_classes),
                            )
                            futures.append(scan_future)
                    for future in as_completed(futures):
                        graph_set_dict = future.result()
                        graph_set = GraphSet.from_dict(graph_set_dict)
                        errors += graph_set.errors
                        account_graph_set.merge(graph_set)
                    account_graph_set.validate()
            except Exception as ex:
                error_str = str(ex)
                trace_back = traceback.format_exc()
                logger.error(
                    event=AWSLogEvents.ScanAWSAccountError,
                    error=error_str,
                    trace_back=trace_back,
                )
                errors.append(" : ".join((error_str, trace_back)))
                unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                    account_id=self.account_id, errors=errors
                )
                account_graph_set = GraphSet(
                    name=self.graph_name,
                    version=self.graph_version,
                    start_time=now,
                    end_time=now,
                    resources=[unscanned_account_resource],
                    errors=errors,
                    stats=stats,
                )
                account_graph_set.validate()
            output_artifact = self.artifact_writer.write_artifact(
                name=self.account_id, data=account_graph_set.to_dict()
            )
            # Logged while account_id is still bound, matching the start event.
            logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
        api_call_stats = account_graph_set.stats.to_dict()
        return {
            "account_id": self.account_id,
            "output_artifact": output_artifact,
            "errors": errors,
            "api_call_stats": api_call_stats,
        }
Exemplo n.º 15
0
 def scan(self) -> List[Dict[str, Any]]:
     """Scan every account in the account scan plan; return one result dict per account.

     Accounts are visited in random order. For each account an STS sanity check is
     made, resource spec classes are grouped by region and service, and one scan
     per (region, service) is scheduled on a thread pool. Accounts whose pre-scan
     setup raised, or whose scans reported any error, get an artifact holding only
     an unscanned-account resource.

     Returns:
         List of dicts with keys ``account_id``, ``output_artifact``, ``errors`` and
         ``api_call_stats``.

     Raises:
         Exception: if an account shows up both as a pre-scan failure and among the
             scan results (should never happen).
     """
     logger = Logger()
     scan_result_dicts: List[Dict[str, Any]] = []
     now = int(time.time())
     prescan_account_ids_errors: DefaultDict[str, List[str]] = defaultdict(list)
     futures = []

     def _unscanned_account_graph_set(account_id: str, errors: List[str]) -> Any:
         """Build and validate a GraphSet containing only an unscanned-account resource."""
         unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
             account_id=account_id, errors=errors)
         account_graph_set = GraphSet(
             name=self.graph_name,
             version=self.graph_version,
             start_time=now,
             end_time=now,
             resources=[unscanned_account_resource],
             errors=errors,
             # ENHANCEMENT: could technically get partial stats.
             stats=MultilevelCounter(),
         )
         account_graph_set.validate()
         return account_graph_set

     def _record_account_result(account_id: str, account_graph_set: Any,
                                errors: List[str]) -> None:
         """Write the account's graph artifact, log the end event and collect its summary."""
         output_artifact = self.artifact_writer.write_json(
             name=account_id, data=account_graph_set.to_dict())
         logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
         scan_result_dicts.append({
             "account_id": account_id,
             "output_artifact": output_artifact,
             "errors": errors,
             "api_call_stats": account_graph_set.stats.to_dict(),
         })

     with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
         shuffled_account_ids = random.sample(
             self.account_scan_plan.account_ids,
             k=len(self.account_scan_plan.account_ids))
         for account_id in shuffled_account_ids:
             with logger.bind(account_id=account_id):
                 logger.info(event=AWSLogEvents.ScanAWSAccountStart)
                 try:
                     session = self.account_scan_plan.accessor.get_session(
                         account_id=account_id)
                     # sanity check: the session must belong to the requested account
                     sts_client = session.client("sts")
                     sts_account_id = sts_client.get_caller_identity()["Account"]
                     if sts_account_id != account_id:
                         raise ValueError(
                             f"BUG: sts detected account_id {sts_account_id} != {account_id}"
                         )
                     if self.account_scan_plan.regions:
                         scan_regions = tuple(self.account_scan_plan.regions)
                     else:
                         scan_regions = get_all_enabled_regions(session=session)
                     account_gran_scan_region = random.choice(
                         self.preferred_account_scan_regions)
                     # build a dict of regions -> services -> List[AWSResourceSpec]
                     regions_services_resource_spec_classes: DefaultDict[
                         str, DefaultDict[str, List[Type[AWSResourceSpec]]]
                     ] = defaultdict(lambda: defaultdict(list))
                     resource_spec_class: Type[AWSResourceSpec]
                     for resource_spec_class in self.resource_spec_classes:
                         client_name = resource_spec_class.get_client_name()
                         if resource_spec_class.scan_granularity == ScanGranularity.ACCOUNT:
                             if resource_spec_class.region_whitelist:
                                 account_resource_scan_region = resource_spec_class.region_whitelist[0]
                             else:
                                 account_resource_scan_region = account_gran_scan_region
                             regions_services_resource_spec_classes[
                                 account_resource_scan_region][client_name].append(
                                     resource_spec_class)
                         elif resource_spec_class.scan_granularity == ScanGranularity.REGION:
                             if resource_spec_class.region_whitelist:
                                 resource_scan_regions = tuple(
                                     region for region in scan_regions
                                     if region in resource_spec_class.region_whitelist)
                                 if not resource_scan_regions:
                                     resource_scan_regions = resource_spec_class.region_whitelist
                             else:
                                 resource_scan_regions = scan_regions
                             for region in resource_scan_regions:
                                 regions_services_resource_spec_classes[region][
                                     client_name].append(resource_spec_class)
                         else:
                             raise NotImplementedError(
                                 f"ScanGranularity {resource_spec_class.scan_granularity} unimplemented"
                             )
                     # Build and submit ScanUnits in random order.
                     # random.sample() requires a sequence: set-like dict views were
                     # deprecated in 3.9 and rejected with TypeError since 3.11.
                     region_items = list(regions_services_resource_spec_classes.items())
                     shuffled_regions_services_resource_spec_classes = random.sample(
                         region_items, len(region_items))
                     for (
                             region,
                             services_resource_spec_classes,
                     ) in shuffled_regions_services_resource_spec_classes:
                         region_session = self.account_scan_plan.accessor.get_session(
                             account_id=account_id, region_name=region)
                         region_creds = region_session.get_credentials()
                         service_items = list(services_resource_spec_classes.items())
                         shuffled_services_resource_spec_classes = random.sample(
                             service_items, len(service_items))
                         for (
                                 service,
                                 svc_resource_spec_classes,
                         ) in shuffled_services_resource_spec_classes:
                             future = schedule_scan(
                                 executor=executor,
                                 graph_name=self.graph_name,
                                 graph_version=self.graph_version,
                                 account_id=account_id,
                                 region_name=region,
                                 service=service,
                                 access_key=region_creds.access_key,
                                 secret_key=region_creds.secret_key,
                                 token=region_creds.token,
                                 resource_spec_classes=tuple(svc_resource_spec_classes),
                             )
                             futures.append(future)
                 except Exception as ex:
                     error_str = str(ex)
                     trace_back = traceback.format_exc()
                     logger.error(
                         event=AWSLogEvents.ScanAWSAccountError,
                         error=error_str,
                         trace_back=trace_back,
                     )
                     prescan_account_ids_errors[account_id].append(
                         f"{error_str}\n{trace_back}")
     account_ids_graph_set_dicts: DefaultDict[
         str, List[Dict[str, Any]]] = defaultdict(list)
     for future in as_completed(futures):
         account_id, graph_set_dict = future.result()
         account_ids_graph_set_dicts[account_id].append(graph_set_dict)
     # first make sure no account id appears both in account_ids_graph_set_dicts
     # and prescan_account_ids_errors - this should never happen
     doubled_accounts = set(account_ids_graph_set_dicts.keys()).intersection(
         set(prescan_account_ids_errors.keys()))
     if doubled_accounts:
         raise Exception((
             f"BUG: Account(s) {doubled_accounts} in both "
             "account_ids_graph_set_dicts and prescan_account_ids_errors."))
     # graph prescan error accounts
     for account_id, errors in prescan_account_ids_errors.items():
         with logger.bind(account_id=account_id):
             account_graph_set = _unscanned_account_graph_set(account_id, errors)
             _record_account_result(account_id, account_graph_set, errors)
     # graph rest
     for account_id, graph_set_dicts in account_ids_graph_set_dicts.items():
         with logger.bind(account_id=account_id):
             # if there are any errors whatsoever we generate an empty graph with
             # errors only
             errors = []
             for graph_set_dict in graph_set_dicts:
                 errors += graph_set_dict["errors"]
             if errors:
                 account_graph_set = _unscanned_account_graph_set(account_id, errors)
             else:
                 account_graph_set = GraphSet(
                     name=self.graph_name,
                     version=self.graph_version,
                     start_time=now,
                     end_time=now,
                     resources=[],
                     errors=[],
                     stats=MultilevelCounter(),
                 )
                 for graph_set_dict in graph_set_dicts:
                     account_graph_set.merge(GraphSet.from_dict(graph_set_dict))
             _record_account_result(account_id, account_graph_set, errors)
     return scan_result_dicts
Exemplo n.º 16
0
 def scan(cls, scan_accessor: Any) -> ResourceScanResult:
     """Return an empty scan result (no resources, no errors, fresh stats)."""
     no_resources: list = []
     fresh_stats = MultilevelCounter()
     return ResourceScanResult(resources=no_resources, stats=fresh_stats, errors=[])
Exemplo n.º 17
0
def scan(
    muxer: AWSScanMuxer,
    account_ids: List[str],
    regions: List[str],
    accessor: Accessor,
    artifact_writer: ArtifactWriter,
    artifact_reader: ArtifactReader,
    scan_sub_accounts: bool = False,
) -> ScanManifest:
    """Scan the given accounts and summarise the results in a ScanManifest.

    Args:
        muxer: runs the per-account scans.
        account_ids: accounts to scan.
        regions: regions passed to the account scan plans.
        accessor: used to build scan plans (and to list sub accounts).
        artifact_writer: writes the merged "master" artifact.
        artifact_reader: reads per-account artifacts for merging.
        scan_sub_accounts: if True, replace account_ids with the ids returned by
            ``get_sub_account_ids``.

    Returns:
        ScanManifest. ``master_artifact``, ``start_time`` and ``end_time`` are None
        when no account produced an artifact.
    """
    if scan_sub_accounts:
        account_ids = get_sub_account_ids(account_ids, accessor)
    account_scan_plans = build_account_scan_plans(accessor=accessor,
                                                  account_ids=account_ids,
                                                  regions=regions)
    logger = Logger()
    logger.info(event=AWSLogEvents.ScanAWSAccountsStart)
    account_scan_manifests = muxer.scan(account_scan_plans=account_scan_plans)
    # now combine account_scan_results and org_details to build a ScanManifest
    # (these hold account id strings, not dicts)
    scanned_accounts: List[str] = []
    artifacts: List[str] = []
    errors: Dict[str, List[str]] = {}
    unscanned_accounts: List[str] = []
    stats = MultilevelCounter()

    for account_scan_manifest in account_scan_manifests:
        account_id = account_scan_manifest.account_id
        if account_scan_manifest.artifacts:
            artifacts += account_scan_manifest.artifacts
            if account_scan_manifest.errors:
                errors[account_id] = account_scan_manifest.errors
                unscanned_accounts.append(account_id)
            else:
                scanned_accounts.append(account_id)
        else:
            unscanned_accounts.append(account_id)
        account_stats = MultilevelCounter.from_dict(
            account_scan_manifest.api_call_stats)
        stats.merge(account_stats)
    graph_set = None
    for artifact_path in artifacts:
        artifact_dict = artifact_reader.read_artifact(artifact_path)
        artifact_graph_set = GraphSet.from_dict(artifact_dict)
        if graph_set is None:
            graph_set = artifact_graph_set
        else:
            graph_set.merge(artifact_graph_set)
    master_artifact_path = None
    start_time = None
    end_time = None
    # Explicit None check: only "no artifacts at all" should skip the master artifact.
    if graph_set is not None:
        master_artifact_path = artifact_writer.write_artifact(
            name="master", data=graph_set.to_dict())
        start_time = graph_set.start_time
        end_time = graph_set.end_time
    logger.info(event=AWSLogEvents.ScanAWSAccountsEnd)
    return ScanManifest(
        scanned_accounts=scanned_accounts,
        master_artifact=master_artifact_path,
        artifacts=artifacts,
        errors=errors,
        unscanned_accounts=unscanned_accounts,
        api_call_stats=stats.to_dict(),
        start_time=start_time,
        end_time=end_time,
    )
Exemplo n.º 18
0
def run_scan(
    muxer: AWSScanMuxer,
    config: Config,
    artifact_writer: ArtifactWriter,
    artifact_reader: ArtifactReader,
) -> Tuple[ScanManifest, GraphSet]:
    """Scan the configured accounts, write the master graph and manifest, return both.

    Every per-account artifact reported by the muxer is read and merged into one
    GraphSet, which is written as "master". A ScanManifest summarising scanned and
    unscanned accounts, errors and merged api call stats is written as "manifest".

    Raises:
        Exception: if no account produced any artifact.
    """
    account_ids = (
        get_sub_account_ids(config.scan.accounts, config.access.accessor)
        if config.scan.scan_sub_accounts
        else config.scan.accounts
    )
    account_scan_plan = AccountScanPlan(account_ids=account_ids,
                                        regions=config.scan.regions,
                                        accessor=config.access.accessor)
    logger = Logger()
    logger.info(event=AWSLogEvents.ScanAWSAccountsStart)
    # now combine account_scan_results and org_details to build a ScanManifest
    scanned_accounts: List[str] = []
    artifacts: List[str] = []
    errors: Dict[str, List[str]] = {}
    unscanned_accounts: List[str] = []
    stats = MultilevelCounter()
    graph_set = None

    for manifest in muxer.scan(account_scan_plan=account_scan_plan):
        account_id = manifest.account_id
        if not manifest.artifacts:
            unscanned_accounts.append(account_id)
        else:
            for artifact_path in manifest.artifacts:
                artifacts.append(artifact_path)
                loaded = GraphSet.from_dict(artifact_reader.read_json(artifact_path))
                if graph_set is None:
                    graph_set = loaded
                else:
                    graph_set.merge(loaded)
            if manifest.errors:
                errors[account_id] = manifest.errors
                unscanned_accounts.append(account_id)
            else:
                scanned_accounts.append(account_id)
        stats.merge(MultilevelCounter.from_dict(manifest.api_call_stats))

    if graph_set is None:
        raise Exception("BUG: No graph_set generated.")
    master_artifact_path = artifact_writer.write_json(name="master",
                                                      data=graph_set.to_dict())
    logger.info(event=AWSLogEvents.ScanAWSAccountsEnd)
    scan_manifest = ScanManifest(
        scanned_accounts=scanned_accounts,
        master_artifact=master_artifact_path,
        artifacts=artifacts,
        errors=errors,
        unscanned_accounts=unscanned_accounts,
        api_call_stats=stats.to_dict(),
        start_time=graph_set.start_time,
        end_time=graph_set.end_time,
    )
    artifact_writer.write_json("manifest", data=scan_manifest.to_dict())
    return scan_manifest, graph_set
Exemplo n.º 19
0
class TestScanAccessor:
    """Class exposing only an ``api_call_stats`` MultilevelCounter."""

    # NOTE(review): this is a class attribute, so a single MultilevelCounter is
    # shared by the class and all its instances, and increments accumulate across
    # uses. Confirm that is intended.
    api_call_stats: MultilevelCounter = MultilevelCounter()
Exemplo n.º 20
0
    def test_merge_does_not_update_other(self):
        """Merging counter B into counter A must leave B's contents unchanged."""
        target = MultilevelCounter()
        target.increment("foo", "boo", "goo")
        source = MultilevelCounter()
        source.increment("boo", "goo", "moo")

        target.merge(source)

        untouched = {"count": 1, "boo": {"count": 1, "goo": {"count": 1, "moo": {"count": 1}}}}
        self.assertDictEqual(untouched, source.to_dict())