def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
    """Lambda entrypoint: run a single account scan described by the event.

    Args:
        event: lambda event dict containing the scan parameters
            (account_scan_plan, scan_id, artifact_path, etc.)
        context: lambda context object (unused)

    Returns:
        scan results as a JSON-serializable dict
    """
    # Remove any handlers the lambda runtime pre-installed so the
    # application's own logging configuration takes effect. Iterate over a
    # copy: removeHandler() mutates root.handlers, and removing while
    # iterating the live list skips entries.
    root = logging.getLogger()
    for handler in list(root.handlers):
        root.removeHandler(handler)
    account_scan_plan_dict = get_required_lambda_event_var(event, "account_scan_plan")
    account_scan_plan = AccountScanPlan.from_dict(account_scan_plan_dict)
    scan_id = get_required_lambda_event_var(event, "scan_id")
    artifact_path = get_required_lambda_event_var(event, "artifact_path")
    max_svc_scan_threads = get_required_lambda_event_var(event, "max_svc_scan_threads")
    preferred_account_scan_regions = get_required_lambda_event_var(
        event, "preferred_account_scan_regions"
    )
    scan_sub_accounts = get_required_lambda_event_var(event, "scan_sub_accounts")
    artifact_writer = ArtifactWriter.from_artifact_path(
        artifact_path=artifact_path, scan_id=scan_id
    )
    account_scanner = AccountScanner(
        account_scan_plan=account_scan_plan,
        artifact_writer=artifact_writer,
        max_svc_scan_threads=max_svc_scan_threads,
        preferred_account_scan_regions=preferred_account_scan_regions,
        scan_sub_accounts=scan_sub_accounts,
    )
    scan_results_dict = account_scanner.scan()
    # Round-trip through json to guarantee the returned payload is
    # JSON-serializable (json_encoder handles non-primitive values).
    scan_results_str = json.dumps(scan_results_dict, default=json_encoder)
    json_results = json.loads(scan_results_str)
    return json_results
def lambda_handler(event: Dict[str, Any], __: Any) -> Dict[str, Any]:
    """AWS Lambda Handler.

    Depending on the input event and env vars, run one of the aws2n,
    account_scan or graph-pruner processes.

    Args:
        event: lambda event dict; its shape selects which process runs
        __: lambda context object (unused)

    Returns:
        result dict of whichever process matched the input

    Raises:
        InvalidLambdaInputException: if neither the env vars nor the event
            match any known input model
    """
    # Remove any handlers the lambda runtime pre-installed so the
    # application's own logging configuration takes effect. Iterate over a
    # copy: removeHandler() mutates root.handlers, and removing while
    # iterating the live list skips entries.
    root = logging.getLogger()
    for handler in list(root.handlers):
        root.removeHandler(handler)
    # Dispatch by attempting each input model in turn; a ValidationError
    # means the env/event did not match that mode, so fall through.
    try:
        aws2n_config = AWS2NConfig()
        config = AWSConfig.from_path(path=aws2n_config.config_path)
        scan_id = generate_scan_id()
        muxer = LambdaAWSScanMuxer(
            scan_id=scan_id,
            config=config,
            account_scan_lambda_name=aws2n_config.account_scan_lambda_name,
            account_scan_lambda_timeout=aws2n_config.account_scan_lambda_timeout,
        )
        aws2n_result = aws2n(scan_id=scan_id, config=config, muxer=muxer, load_neptune=True)
        return asdict(aws2n_result)
    except ValidationError:
        pass
    try:
        account_scan_input = AccountScanLambdaEvent(**event)
        artifact_writer = ArtifactWriter.from_artifact_path(
            artifact_path=account_scan_input.artifact_path,
            scan_id=account_scan_input.scan_id,
        )
        account_scanner = AccountScanner(
            account_scan_plan=account_scan_input.account_scan_plan,
            artifact_writer=artifact_writer,
            max_svc_scan_threads=account_scan_input.max_svc_scan_threads,
            preferred_account_scan_regions=account_scan_input.preferred_account_scan_regions,
            scan_sub_accounts=account_scan_input.scan_sub_accounts,
        )
        scan_results = account_scanner.scan()
        return scan_results.dict()
    except ValidationError:
        pass
    try:
        pruner_config = GraphPrunerConfig()
        prune_results = prune_graph(pruner_config)
        return prune_results.dict()
    except ValidationError:
        pass
    # Nothing matched: surface the full environment and event for debugging.
    raise InvalidLambdaInputException(
        f"Invalid lambda input.\nENV: {os.environ}\nEvent: {event}")
def aws2neptune_lpg(scan_id: str, config: AWSConfig, muxer: AWSScanMuxer) -> None:
    """Scan AWS resources to json, convert to a labeled property graph and
    load into Neptune if config.neptune is defined."""
    reader = ArtifactReader.from_artifact_path(config.artifact_path)
    writer = ArtifactWriter.from_artifact_path(
        artifact_path=config.artifact_path, scan_id=scan_id
    )
    region_mapping_repo = build_aws_resource_region_mapping_repo(
        global_region_whitelist=config.scan.regions,
        preferred_account_scan_regions=config.scan.preferred_account_scan_regions,
        services_regions_json_url=config.services_regions_json_url,
    )
    logger.info(
        AWSLogEvents.ScanConfigured,
        config=str(config),
        reader=str(reader.__class__),
        writer=str(writer.__class__),
    )
    print("Beginning AWS Account Scan")
    _, graph_set = run_scan(
        muxer=muxer,
        config=config,
        aws_resource_region_mapping_repo=region_mapping_repo,
        artifact_writer=writer,
        artifact_reader=reader,
    )
    print("AWS Account Scan Complete. Beginning write to Amazon Neptune.")
    logger.info(LogEvent.NeptuneGremlinWriteStart)
    lpg_graph = graph_set.to_neptune_lpg(scan_id)
    if config.neptune is None:
        raise Exception("Can not load to Neptune because config.neptune is empty.")
    neptune_endpoint = NeptuneEndpoint(
        host=config.neptune.host,
        port=config.neptune.port,
        region=config.neptune.region,
        ssl=bool(config.neptune.ssl),
        auth_mode=str(config.neptune.auth_mode),
    )
    client = AltimeterNeptuneClient(max_age_min=1440, neptune_endpoint=neptune_endpoint)
    client.write_to_neptune_lpg(lpg_graph, scan_id)
    logger.info(LogEvent.NeptuneGremlinWriteEnd)
    print("Write to Amazon Neptune Complete")
def local_account_scan(
    scan_id: str,
    account_scan_plan: AccountScanPlan,
    config: AWSConfig,
    resource_spec_classes: Tuple[Type[AWSResourceSpec], ...],
) -> AccountScanResult:
    """Scan a set of accounts in-process.

    Args:
        scan_id: unique id of this scan; used to namespace written artifacts
        account_scan_plan: AccountScanPlan defining the scan
        config: AWSConfig object
        resource_spec_classes: AWSResourceSpec classes defining which
            resource types to scan

    Returns:
        AccountScanResult produced by AccountScanner.scan
    """
    artifact_writer = ArtifactWriter.from_artifact_path(
        artifact_path=config.artifact_path, scan_id=scan_id
    )
    account_scanner = AccountScanner(
        account_scan_plan=account_scan_plan,
        artifact_writer=artifact_writer,
        max_svc_scan_threads=config.concurrency.max_svc_scan_threads,
        scan_sub_accounts=config.scan.scan_sub_accounts,
        resource_spec_classes=resource_spec_classes,
    )
    return account_scanner.scan()
def aws2neptune_rdf(scan_id: str, config: Config, muxer: AWSScanMuxer) -> None:
    """Scan AWS resources to json, convert to RDF and load into Neptune if
    config.neptune is defined."""
    reader = ArtifactReader.from_artifact_path(config.artifact_path)
    writer = ArtifactWriter.from_artifact_path(
        artifact_path=config.artifact_path, scan_id=scan_id
    )
    logger.info(
        AWSLogEvents.ScanConfigured,
        config=str(config),
        reader=str(reader.__class__),
        writer=str(writer.__class__),
    )
    print("Beginning AWS Account Scan")
    _, graph_set = run_scan(
        muxer=muxer,
        config=config,
        artifact_writer=writer,
        artifact_reader=reader,
    )
    print("AWS Account Scan Complete. Beginning write to Amazon Neptune.")
    logger.info(LogEvent.NeptuneRDFWriteStart)
    rdf_graph = graph_set.to_rdf()
    if config.neptune is None:
        raise Exception("Can not load to Neptune because config.neptune is empty.")
    neptune_endpoint = NeptuneEndpoint(
        host=config.neptune.host,
        port=config.neptune.port,
        region=config.neptune.region,
        ssl=bool(config.neptune.ssl),
        auth_mode=str(config.neptune.auth_mode),
    )
    client = AltimeterNeptuneClient(max_age_min=1440, neptune_endpoint=neptune_endpoint)
    client.write_to_neptune_rdf(rdf_graph)
    logger.info(LogEvent.NeptuneRDFWriteEnd)
    print("Write to Amazon Neptune Complete")
def local_account_scan(
    scan_id: str,
    account_scan_plan_dict: Dict[str, Any],
    config: Config,
) -> List[Dict[str, Any]]:
    """Scan a set of accounts in-process.

    Args:
        scan_id: unique id of this scan; used to namespace written artifacts
        account_scan_plan_dict: dict form of an AccountScanPlan defining the scan
        config: Config object

    Returns:
        list of scan result dicts as produced by AccountScanner.scan
    """
    artifact_writer = ArtifactWriter.from_artifact_path(
        artifact_path=config.artifact_path, scan_id=scan_id
    )
    account_scan_plan = AccountScanPlan.from_dict(
        account_scan_plan_dict=account_scan_plan_dict
    )
    account_scanner = AccountScanner(
        account_scan_plan=account_scan_plan,
        artifact_writer=artifact_writer,
        max_svc_scan_threads=config.concurrency.max_svc_scan_threads,
        preferred_account_scan_regions=config.scan.preferred_account_scan_regions,
        scan_sub_accounts=config.scan.scan_sub_accounts,
    )
    return account_scanner.scan()
def aws2n(scan_id: str, config: Config, muxer: AWSScanMuxer, load_neptune: bool) -> AWS2NResult:
    """Scan AWS resources to json, convert to RDF and load into Neptune if
    config.neptune is defined.

    Args:
        scan_id: unique id of this scan; used to namespace written artifacts
        config: Config object
        muxer: AWSScanMuxer which distributes the per-account scans
        load_neptune: if True, additionally load the RDF into Neptune and
            publish an SNS notification when the load completes

    Returns:
        AWS2NResult with the json/rdf artifact paths and, if load_neptune
        was True, the Neptune graph metadata (otherwise None)

    Raises:
        Exception: if load_neptune is True but config.neptune is unset, or
            if the rdf artifact path is not a valid s3 uri
    """
    artifact_reader = ArtifactReader.from_artifact_path(config.artifact_path)
    artifact_writer = ArtifactWriter.from_artifact_path(
        artifact_path=config.artifact_path, scan_id=scan_id
    )
    logger = Logger()
    logger.info(
        AWSLogEvents.ScanConfigured,
        config=str(config),
        reader=str(artifact_reader.__class__),
        writer=str(artifact_writer.__class__),
    )
    scan_manifest, graph_set = run_scan(
        muxer=muxer,
        config=config,
        artifact_writer=artifact_writer,
        artifact_reader=artifact_reader,
    )
    json_path = scan_manifest.master_artifact
    rdf_path = artifact_writer.write_graph_set(
        name="master", graph_set=graph_set, compression=GZIP
    )
    graph_metadata = None
    if load_neptune:
        if config.neptune is None:
            raise Exception("Can not load to Neptune because config.neptune is empty.")
        endpoint = NeptuneEndpoint(
            host=config.neptune.host,
            port=config.neptune.port,
            region=config.neptune.region,
        )
        neptune_client = AltimeterNeptuneClient(max_age_min=1440, neptune_endpoint=endpoint)
        # The rdf artifact must live in s3 for Neptune's bulk loader.
        rdf_bucket, rdf_key = parse_s3_uri(rdf_path)
        if rdf_key is None:
            raise Exception(f"Invalid rdf s3 path {rdf_path}")
        graph_metadata = neptune_client.load_graph(
            bucket=rdf_bucket,
            key=rdf_key,
            load_iam_role_arn=str(config.neptune.iam_role_arn),
        )
        # Notify subscribers that a new graph is available.
        logger.info(event=LogEvent.GraphLoadedSNSNotificationStart)
        sns_client = boto3.client("sns")
        message_dict = {
            "uri": graph_metadata.uri,
            "name": graph_metadata.name,
            "version": graph_metadata.version,
            "start_time": graph_metadata.start_time,
            "end_time": graph_metadata.end_time,
            "neptune_endpoint": endpoint.get_endpoint_str(),
        }
        # SNS MessageStructure="json" requires a "default" key carrying the
        # fallback message body as a json string.
        message_dict["default"] = json.dumps(message_dict)
        sns_client.publish(
            TopicArn=config.neptune.graph_load_sns_topic_arn,
            MessageStructure="json",
            Message=json.dumps(message_dict),
        )
        logger.info(event=LogEvent.GraphLoadedSNSNotificationEnd)
    return AWS2NResult(json_path=json_path, rdf_path=rdf_path, graph_metadata=graph_metadata)
def run_scan(
    muxer: AWSScanMuxer,
    config: Config,
    artifact_writer: ArtifactWriter,
    artifact_reader: ArtifactReader,
) -> Tuple[ScanManifest, GraphSet]:
    """Run a complete scan via the given muxer and merge the per-account
    results into a single GraphSet and ScanManifest.

    Args:
        muxer: AWSScanMuxer which distributes the per-account scans
        config: Config object
        artifact_writer: writer for the master/manifest artifacts
        artifact_reader: reader for the per-account scan artifacts

    Returns:
        tuple of (ScanManifest, merged GraphSet)

    Raises:
        Exception: if no account produced a graph artifact
    """
    # Optionally expand the configured accounts to include sub-accounts.
    if config.scan.scan_sub_accounts:
        account_ids = get_sub_account_ids(config.scan.accounts, config.access.accessor)
    else:
        account_ids = config.scan.accounts
    account_scan_plan = AccountScanPlan(
        account_ids=account_ids, regions=config.scan.regions, accessor=config.access.accessor
    )
    logger = Logger()
    logger.info(event=AWSLogEvents.ScanAWSAccountsStart)
    # now combine account_scan_results and org_details to build a ScanManifest
    scanned_accounts: List[str] = []
    artifacts: List[str] = []
    errors: Dict[str, List[str]] = {}
    unscanned_accounts: List[str] = []
    stats = MultilevelCounter()
    graph_set = None
    for account_scan_manifest in muxer.scan(account_scan_plan=account_scan_plan):
        account_id = account_scan_manifest.account_id
        if account_scan_manifest.artifacts:
            # Merge every artifact of this account into the combined graph.
            for account_scan_artifact in account_scan_manifest.artifacts:
                artifacts.append(account_scan_artifact)
                artifact_graph_set_dict = artifact_reader.read_json(account_scan_artifact)
                artifact_graph_set = GraphSet.from_dict(artifact_graph_set_dict)
                if graph_set is None:
                    graph_set = artifact_graph_set
                else:
                    graph_set.merge(artifact_graph_set)
            # An account with errors is recorded as unscanned even if it
            # produced partial artifacts.
            if account_scan_manifest.errors:
                errors[account_id] = account_scan_manifest.errors
                unscanned_accounts.append(account_id)
            else:
                scanned_accounts.append(account_id)
        else:
            unscanned_accounts.append(account_id)
        # Aggregate api call statistics across all accounts.
        account_stats = MultilevelCounter.from_dict(account_scan_manifest.api_call_stats)
        stats.merge(account_stats)
    if graph_set is None:
        raise Exception("BUG: No graph_set generated.")
    master_artifact_path = artifact_writer.write_json(name="master", data=graph_set.to_dict())
    logger.info(event=AWSLogEvents.ScanAWSAccountsEnd)
    start_time = graph_set.start_time
    end_time = graph_set.end_time
    scan_manifest = ScanManifest(
        scanned_accounts=scanned_accounts,
        master_artifact=master_artifact_path,
        artifacts=artifacts,
        errors=errors,
        unscanned_accounts=unscanned_accounts,
        api_call_stats=stats.to_dict(),
        start_time=start_time,
        end_time=end_time,
    )
    artifact_writer.write_json("manifest", data=scan_manifest.to_dict())
    return scan_manifest, graph_set
def test_from_artifact_path_filepath(self):
    """A plain filesystem path yields a FileArtifactWriter."""
    result = ArtifactWriter.from_artifact_path(
        artifact_path="/file/path", scan_id="test-scan-id"
    )
    self.assertIsInstance(result, FileArtifactWriter)
def test_from_artifact_path_s3(self):
    """An s3:// uri yields an S3ArtifactWriter."""
    result = ArtifactWriter.from_artifact_path(
        artifact_path="s3://bucket", scan_id="test-scan-id"
    )
    self.assertIsInstance(result, S3ArtifactWriter)
def run_scan(
    muxer: AWSScanMuxer,
    config: AWSConfig,
    aws_resource_region_mapping_repo: AWSResourceRegionMappingRepository,
    artifact_writer: ArtifactWriter,
    artifact_reader: ArtifactReader,
) -> Tuple[ScanManifest, ValidatedGraphSet]:
    """Run a complete scan via the given muxer and combine the per-account
    results into a ValidatedGraphSet and ScanManifest.

    Args:
        muxer: AWSScanMuxer which distributes the per-account scans
        config: AWSConfig object
        aws_resource_region_mapping_repo: repository mapping resources to regions
        artifact_writer: writer for the master/manifest artifacts
        artifact_reader: reader for the per-account scan artifacts

    Returns:
        tuple of (ScanManifest, ValidatedGraphSet)

    Raises:
        Exception: if no account produced a graph artifact
    """
    # Determine the accounts to scan: explicit config, else the caller's
    # own account id via sts.
    if config.scan.accounts:
        scan_account_ids = config.scan.accounts
    else:
        sts_client = boto3.client("sts")
        scan_account_id = sts_client.get_caller_identity()["Account"]
        scan_account_ids = (scan_account_id,)
    if config.scan.scan_sub_accounts:
        account_ids = get_sub_account_ids(scan_account_ids, config.accessor)
    else:
        account_ids = scan_account_ids
    scan_plan = ScanPlan(
        account_ids=account_ids,
        regions=config.scan.regions,
        aws_resource_region_mapping_repo=aws_resource_region_mapping_repo,
        accessor=config.accessor,
    )
    logger = Logger()
    logger.info(event=AWSLogEvents.ScanAWSAccountsStart)
    # now combine account_scan_results and org_details to build a ScanManifest
    scanned_accounts: List[str] = []
    artifacts: List[str] = []
    errors: Dict[str, List[str]] = {}
    unscanned_accounts: Set[str] = set()
    graph_sets: List[GraphSet] = []
    for account_scan_manifest in muxer.scan(scan_plan=scan_plan):
        account_id = account_scan_manifest.account_id
        if account_scan_manifest.errors:
            errors[account_id] = account_scan_manifest.errors
            unscanned_accounts.add(account_id)
        # NOTE(review): an account with errors AND artifacts ends up in
        # both unscanned_accounts and scanned_accounts — confirm intended.
        if account_scan_manifest.artifacts:
            for account_scan_artifact in account_scan_manifest.artifacts:
                artifacts.append(account_scan_artifact)
                artifact_graph_set_dict = artifact_reader.read_json(account_scan_artifact)
                graph_sets.append(GraphSet.parse_obj(artifact_graph_set_dict))
            scanned_accounts.append(account_id)
        else:
            unscanned_accounts.add(account_id)
    if not graph_sets:
        raise Exception("BUG: No graph_sets generated.")
    # Merge all per-account graphs, then validate the combined result.
    validated_graph_set = ValidatedGraphSet.from_graph_set(GraphSet.from_graph_sets(graph_sets))
    master_artifact_path: Optional[str] = None
    if config.write_master_json:
        master_artifact_path = artifact_writer.write_json(name="master", data=validated_graph_set)
    logger.info(event=AWSLogEvents.ScanAWSAccountsEnd)
    start_time = validated_graph_set.start_time
    end_time = validated_graph_set.end_time
    scan_manifest = ScanManifest(
        scanned_accounts=scanned_accounts,
        master_artifact=master_artifact_path,
        artifacts=artifacts,
        errors=errors,
        unscanned_accounts=list(unscanned_accounts),
        start_time=start_time,
        end_time=end_time,
    )
    artifact_writer.write_json("manifest", data=scan_manifest)
    return scan_manifest, validated_graph_set
def scan(
    muxer: AWSScanMuxer,
    account_ids: List[str],
    regions: List[str],
    accessor: Accessor,
    artifact_writer: ArtifactWriter,
    artifact_reader: ArtifactReader,
    scan_sub_accounts: bool = False,
) -> ScanManifest:
    """Scan the given accounts via the muxer and combine the results into a
    single master artifact and ScanManifest.

    Args:
        muxer: AWSScanMuxer which distributes the per-account scans
        account_ids: account ids to scan
        regions: regions to scan
        accessor: Accessor used to authenticate to the accounts
        artifact_writer: writer for the master artifact
        artifact_reader: reader for the per-account scan artifacts
        scan_sub_accounts: if True, also discover and scan sub-accounts

    Returns:
        ScanManifest summarizing the scan
    """
    if scan_sub_accounts:
        account_ids = get_sub_account_ids(account_ids, accessor)
    account_scan_plans = build_account_scan_plans(
        accessor=accessor, account_ids=account_ids, regions=regions
    )
    logger = Logger()
    logger.info(event=AWSLogEvents.ScanAWSAccountsStart)
    account_scan_manifests = muxer.scan(account_scan_plans=account_scan_plans)
    # now combine account_scan_results and org_details to build a ScanManifest
    # NOTE(review): these two lists hold account_id values, which are used as
    # str keys of `errors` below — annotated accordingly.
    scanned_accounts: List[str] = []
    artifacts: List[str] = []
    errors: Dict[str, List[str]] = {}
    unscanned_accounts: List[str] = []
    stats = MultilevelCounter()
    for account_scan_manifest in account_scan_manifests:
        account_id = account_scan_manifest.account_id
        if account_scan_manifest.artifacts:
            artifacts += account_scan_manifest.artifacts
            # An account with errors is recorded as unscanned even if it
            # produced partial artifacts.
            if account_scan_manifest.errors:
                errors[account_id] = account_scan_manifest.errors
                unscanned_accounts.append(account_id)
            else:
                scanned_accounts.append(account_id)
        else:
            unscanned_accounts.append(account_id)
        # Aggregate api call statistics across all accounts.
        account_stats = MultilevelCounter.from_dict(account_scan_manifest.api_call_stats)
        stats.merge(account_stats)
    # Merge all per-account artifacts into one combined GraphSet.
    graph_set = None
    for artifact_path in artifacts:
        artifact_dict = artifact_reader.read_artifact(artifact_path)
        artifact_graph_set = GraphSet.from_dict(artifact_dict)
        if graph_set is None:
            graph_set = artifact_graph_set
        else:
            graph_set.merge(artifact_graph_set)
    master_artifact_path = None
    if graph_set:
        master_artifact_path = artifact_writer.write_artifact(
            name="master", data=graph_set.to_dict()
        )
    logger.info(event=AWSLogEvents.ScanAWSAccountsEnd)
    if graph_set:
        start_time = graph_set.start_time
        end_time = graph_set.end_time
    else:
        start_time, end_time = None, None
    return ScanManifest(
        scanned_accounts=scanned_accounts,
        master_artifact=master_artifact_path,
        artifacts=artifacts,
        errors=errors,
        unscanned_accounts=unscanned_accounts,
        api_call_stats=stats.to_dict(),
        start_time=start_time,
        end_time=end_time,
    )