Exemplo n.º 1
0
def lambda_handler(event: Dict[str, Any], context: Any) -> Any:
    """Lambda entrypoint: run a single account scan described by ``event``.

    Args:
        event: Lambda event. The keys ``account_scan_plan``, ``scan_id``,
            ``artifact_path``, ``max_svc_scan_threads``,
            ``preferred_account_scan_regions`` and ``scan_sub_accounts``
            are read from it via ``get_required_lambda_event_var``.
        context: Lambda context object (unused).

    Returns:
        The scan results after a round trip through ``json.dumps`` (with
        ``json_encoder`` as the fallback serializer) and ``json.loads``.
    """
    root = logging.getLogger()
    # Iterate over a snapshot: removeHandler() mutates root.handlers, and
    # removing entries while walking the live list skips every other handler.
    for handler in list(root.handlers):
        root.removeHandler(handler)

    account_scan_plan_dict = get_required_lambda_event_var(
        event, "account_scan_plan")
    account_scan_plan = AccountScanPlan.from_dict(account_scan_plan_dict)
    scan_id = get_required_lambda_event_var(event, "scan_id")
    artifact_path = get_required_lambda_event_var(event, "artifact_path")
    max_svc_scan_threads = get_required_lambda_event_var(
        event, "max_svc_scan_threads")
    preferred_account_scan_regions = get_required_lambda_event_var(
        event, "preferred_account_scan_regions")
    scan_sub_accounts = get_required_lambda_event_var(event,
                                                      "scan_sub_accounts")

    artifact_writer = ArtifactWriter.from_artifact_path(
        artifact_path=artifact_path, scan_id=scan_id)
    account_scanner = AccountScanner(
        account_scan_plan=account_scan_plan,
        artifact_writer=artifact_writer,
        max_svc_scan_threads=max_svc_scan_threads,
        preferred_account_scan_regions=preferred_account_scan_regions,
        scan_sub_accounts=scan_sub_accounts,
    )
    scan_results_dict = account_scanner.scan()
    # Serialize and re-parse so the returned value went through json_encoder
    # for anything json cannot encode natively.
    scan_results_str = json.dumps(scan_results_dict, default=json_encoder)
    return json.loads(scan_results_str)
Exemplo n.º 2
0
def lambda_handler(event: Dict[str, Any], __: Any) -> Dict[str, Any]:
    """AWS Lambda handler that dispatches to one of three processes.

    The modes are attempted in this order; a mode is abandoned and the next
    one tried whenever a ``ValidationError`` is raised inside its block:

    1. aws2n: configured from ``AWS2NConfig()`` / ``AWSConfig.from_path``.
    2. account scan: configured from ``AccountScanLambdaEvent(**event)``.
    3. graph prune: configured from ``GraphPrunerConfig()``.

    Args:
        event: Lambda event; only used for the account scan mode and for the
            error message.
        __: Lambda context object (unused).

    Returns:
        The result of the first mode that completes, as a dict.

    Raises:
        InvalidLambdaInputException: if all three modes raised
            ``ValidationError``.
    """
    root = logging.getLogger()
    # Iterate over a snapshot: removeHandler() mutates root.handlers, and
    # removing entries while walking the live list skips every other handler.
    for handler in list(root.handlers):
        root.removeHandler(handler)
    try:
        aws2n_config = AWS2NConfig()
        config = AWSConfig.from_path(path=aws2n_config.config_path)
        scan_id = generate_scan_id()
        muxer = LambdaAWSScanMuxer(
            scan_id=scan_id,
            config=config,
            account_scan_lambda_name=aws2n_config.account_scan_lambda_name,
            account_scan_lambda_timeout=aws2n_config.
            account_scan_lambda_timeout,
        )
        aws2n_result = aws2n(scan_id=scan_id,
                             config=config,
                             muxer=muxer,
                             load_neptune=True)
        return asdict(aws2n_result)
    except ValidationError:
        # Not an aws2n invocation; fall through to the next mode.
        pass
    try:
        account_scan_input = AccountScanLambdaEvent(**event)
        artifact_writer = ArtifactWriter.from_artifact_path(
            artifact_path=account_scan_input.artifact_path,
            scan_id=account_scan_input.scan_id)
        account_scanner = AccountScanner(
            account_scan_plan=account_scan_input.account_scan_plan,
            artifact_writer=artifact_writer,
            max_svc_scan_threads=account_scan_input.max_svc_scan_threads,
            preferred_account_scan_regions=account_scan_input.
            preferred_account_scan_regions,
            scan_sub_accounts=account_scan_input.scan_sub_accounts,
        )
        scan_results = account_scanner.scan()
        return scan_results.dict()
    except ValidationError:
        # Not an account scan invocation; fall through to the next mode.
        pass
    try:
        pruner_config = GraphPrunerConfig()
        prune_results = prune_graph(pruner_config)
        return prune_results.dict()
    except ValidationError:
        # Not a graph prune invocation either; report invalid input below.
        pass
    # NOTE(review): this message embeds the whole process environment, which
    # may include sensitive values -- confirm that is acceptable for the logs.
    raise InvalidLambdaInputException(
        f"Invalid lambda input.\nENV: {os.environ}\nEvent: {event}")
def aws2neptune_lpg(scan_id: str, config: AWSConfig,
                    muxer: AWSScanMuxer) -> None:
    """Run an AWS scan and write the result to Neptune as an LPG graph.

    The scan is executed first; ``config.neptune`` is only checked afterwards,
    so a missing Neptune config raises after the scan has already run.

    Args:
        scan_id: identifier passed to the artifact writer, the LPG conversion
            and the Neptune write.
        config: scan configuration; ``config.neptune`` supplies the endpoint.
        muxer: scan muxer handed to ``run_scan``.

    Raises:
        Exception: if ``config.neptune`` is None.
    """
    reader = ArtifactReader.from_artifact_path(config.artifact_path)
    writer = ArtifactWriter.from_artifact_path(
        artifact_path=config.artifact_path,
        scan_id=scan_id,
    )

    region_mapping_repo = build_aws_resource_region_mapping_repo(
        global_region_whitelist=config.scan.regions,
        preferred_account_scan_regions=config.scan.preferred_account_scan_regions,
        services_regions_json_url=config.services_regions_json_url,
    )

    logger.info(
        AWSLogEvents.ScanConfigured,
        config=str(config),
        reader=str(reader.__class__),
        writer=str(writer.__class__),
    )
    print("Beginning AWS Account Scan")

    # Only the graph set is needed here; the manifest is discarded.
    scan_output = run_scan(
        muxer=muxer,
        config=config,
        aws_resource_region_mapping_repo=region_mapping_repo,
        artifact_writer=writer,
        artifact_reader=reader,
    )
    _, scanned_graph_set = scan_output
    print("AWS Account Scan Complete. Beginning write to Amazon Neptune.")
    logger.info(LogEvent.NeptuneGremlinWriteStart)
    lpg_graph = scanned_graph_set.to_neptune_lpg(scan_id)

    neptune_config = config.neptune
    if neptune_config is None:
        raise Exception(
            "Can not load to Neptune because config.neptune is empty.")

    neptune_endpoint = NeptuneEndpoint(
        host=neptune_config.host,
        port=neptune_config.port,
        region=neptune_config.region,
        ssl=bool(neptune_config.ssl),
        auth_mode=str(neptune_config.auth_mode),
    )
    client = AltimeterNeptuneClient(max_age_min=1440,
                                    neptune_endpoint=neptune_endpoint)
    client.write_to_neptune_lpg(lpg_graph, scan_id)
    logger.info(LogEvent.NeptuneGremlinWriteEnd)
    print("Write to Amazon Neptune Complete")
Exemplo n.º 4
0
def local_account_scan(
    scan_id: str,
    account_scan_plan: AccountScanPlan,
    config: AWSConfig,
    resource_spec_classes: Tuple[Type[AWSResourceSpec], ...],
) -> AccountScanResult:
    """Run an account scan in-process and return its result.

    Args:
        scan_id: scan identifier used when building the artifact writer.
        account_scan_plan: AccountScanPlan defining the scan.
        config: config object supplying the artifact path, thread limit and
            the scan_sub_accounts flag.
        resource_spec_classes: resource spec classes handed to the scanner.

    Returns:
        Whatever ``AccountScanner.scan()`` returns.
    """
    writer = ArtifactWriter.from_artifact_path(
        artifact_path=config.artifact_path,
        scan_id=scan_id,
    )
    return AccountScanner(
        account_scan_plan=account_scan_plan,
        artifact_writer=writer,
        max_svc_scan_threads=config.concurrency.max_svc_scan_threads,
        scan_sub_accounts=config.scan.scan_sub_accounts,
        resource_spec_classes=resource_spec_classes,
    ).scan()
Exemplo n.º 5
0
def aws2neptune_rdf(scan_id: str, config: Config, muxer: AWSScanMuxer) -> None:
    """Run an AWS scan, convert the graph set to RDF and write it to Neptune.

    The scan is executed first; ``config.neptune`` is only checked afterwards,
    so a missing Neptune config raises after the scan has already run.

    Args:
        scan_id: identifier passed to the artifact writer.
        config: scan configuration; ``config.neptune`` supplies the endpoint.
        muxer: scan muxer handed to ``run_scan``.

    Raises:
        Exception: if ``config.neptune`` is None.
    """
    reader = ArtifactReader.from_artifact_path(config.artifact_path)
    writer = ArtifactWriter.from_artifact_path(
        artifact_path=config.artifact_path,
        scan_id=scan_id,
    )

    logger.info(
        AWSLogEvents.ScanConfigured,
        config=str(config),
        reader=str(reader.__class__),
        writer=str(writer.__class__),
    )
    print("Beginning AWS Account Scan")
    # The manifest half of the result is not used by this function.
    _, scanned_graph_set = run_scan(
        muxer=muxer,
        config=config,
        artifact_writer=writer,
        artifact_reader=reader,
    )
    print("AWS Account Scan Complete. Beginning write to Amazon Neptune.")
    logger.info(LogEvent.NeptuneRDFWriteStart)
    rdf_graph = scanned_graph_set.to_rdf()

    neptune_config = config.neptune
    if neptune_config is None:
        raise Exception("Can not load to Neptune because config.neptune is empty.")

    neptune_endpoint = NeptuneEndpoint(
        host=neptune_config.host,
        port=neptune_config.port,
        region=neptune_config.region,
        ssl=bool(neptune_config.ssl),
        auth_mode=str(neptune_config.auth_mode),
    )
    client = AltimeterNeptuneClient(max_age_min=1440, neptune_endpoint=neptune_endpoint)
    client.write_to_neptune_rdf(rdf_graph)
    logger.info(LogEvent.NeptuneRDFWriteEnd)
    print("Write to Amazon Neptune Complete")
Exemplo n.º 6
0
def local_account_scan(
    scan_id: str,
    account_scan_plan_dict: Dict[str, Any],
    config: Config,
) -> List[Dict[str, Any]]:
    """Run an account scan in-process from a plan given as a dict.

    Args:
        scan_id: scan identifier used when building the artifact writer.
        account_scan_plan_dict: dict form of the AccountScanPlan defining the
            scan; converted with ``AccountScanPlan.from_dict``.
        config: Config object supplying the artifact path, thread limit,
            preferred regions and the scan_sub_accounts flag.

    Returns:
        Whatever ``AccountScanner.scan()`` returns.
    """
    writer = ArtifactWriter.from_artifact_path(
        artifact_path=config.artifact_path,
        scan_id=scan_id,
    )
    plan = AccountScanPlan.from_dict(
        account_scan_plan_dict=account_scan_plan_dict)
    scanner = AccountScanner(
        account_scan_plan=plan,
        artifact_writer=writer,
        max_svc_scan_threads=config.concurrency.max_svc_scan_threads,
        preferred_account_scan_regions=config.scan.preferred_account_scan_regions,
        scan_sub_accounts=config.scan.scan_sub_accounts,
    )
    results = scanner.scan()
    return results
Exemplo n.º 7
0
def aws2n(scan_id: str, config: Config, muxer: AWSScanMuxer,
          load_neptune: bool) -> AWS2NResult:
    """Scan AWS, write the merged graph set, and optionally load Neptune.

    Args:
        scan_id: identifier passed to the artifact writer.
        config: scan configuration; ``config.neptune`` is required when
            ``load_neptune`` is true.
        muxer: scan muxer handed to ``run_scan``.
        load_neptune: when true, load the written graph into Neptune and
            publish an SNS notification describing the loaded graph.

    Returns:
        AWS2NResult carrying the master JSON path, the graph-set path and the
        graph metadata (None when ``load_neptune`` is false).

    Raises:
        Exception: if ``load_neptune`` is true and ``config.neptune`` is None,
            or if no key can be parsed from the written graph-set path.
    """
    reader = ArtifactReader.from_artifact_path(config.artifact_path)
    writer = ArtifactWriter.from_artifact_path(
        artifact_path=config.artifact_path,
        scan_id=scan_id,
    )

    log = Logger()
    log.info(
        AWSLogEvents.ScanConfigured,
        config=str(config),
        reader=str(reader.__class__),
        writer=str(writer.__class__),
    )

    manifest, merged_graph_set = run_scan(
        muxer=muxer,
        config=config,
        artifact_writer=writer,
        artifact_reader=reader,
    )
    json_path = manifest.master_artifact
    rdf_path = writer.write_graph_set(
        name="master",
        graph_set=merged_graph_set,
        compression=GZIP,
    )

    # Nothing more to do when the caller did not ask for a Neptune load.
    if not load_neptune:
        return AWS2NResult(json_path=json_path,
                           rdf_path=rdf_path,
                           graph_metadata=None)

    if config.neptune is None:
        raise Exception(
            "Can not load to Neptune because config.neptune is empty.")
    neptune_endpoint = NeptuneEndpoint(
        host=config.neptune.host,
        port=config.neptune.port,
        region=config.neptune.region,
    )
    client = AltimeterNeptuneClient(max_age_min=1440,
                                    neptune_endpoint=neptune_endpoint)
    rdf_bucket, rdf_key = parse_s3_uri(rdf_path)
    if rdf_key is None:
        raise Exception(f"Invalid rdf s3 path {rdf_path}")
    graph_metadata = client.load_graph(
        bucket=rdf_bucket,
        key=rdf_key,
        load_iam_role_arn=str(config.neptune.iam_role_arn),
    )

    log.info(event=LogEvent.GraphLoadedSNSNotificationStart)
    sns = boto3.client("sns")
    payload = {
        "uri": graph_metadata.uri,
        "name": graph_metadata.name,
        "version": graph_metadata.version,
        "start_time": graph_metadata.start_time,
        "end_time": graph_metadata.end_time,
        "neptune_endpoint": neptune_endpoint.get_endpoint_str(),
    }
    # The "default" entry holds the JSON of the payload without itself; the
    # published message is the JSON of the payload including "default".
    notification = {**payload, "default": json.dumps(payload)}
    sns.publish(
        TopicArn=config.neptune.graph_load_sns_topic_arn,
        MessageStructure="json",
        Message=json.dumps(notification),
    )
    log.info(event=LogEvent.GraphLoadedSNSNotificationEnd)

    return AWS2NResult(json_path=json_path,
                       rdf_path=rdf_path,
                       graph_metadata=graph_metadata)
Exemplo n.º 8
0
def run_scan(
    muxer: AWSScanMuxer,
    config: Config,
    artifact_writer: ArtifactWriter,
    artifact_reader: ArtifactReader,
) -> Tuple[ScanManifest, GraphSet]:
    """Scan the configured accounts and merge their artifacts.

    Each account manifest yielded by ``muxer.scan`` contributes its artifacts
    to one merged GraphSet. The merged set is written as the "master" JSON
    artifact and a ScanManifest is written as "manifest".

    Args:
        muxer: scan muxer that yields per-account scan manifests.
        config: scan configuration (accounts, regions, accessor, sub-account
            flag).
        artifact_writer: used to write the master artifact and the manifest.
        artifact_reader: used to read each per-account artifact.

    Returns:
        Tuple of (scan manifest, merged graph set).

    Raises:
        Exception: if no artifact was read, so no graph set exists.
    """
    account_ids = (
        get_sub_account_ids(config.scan.accounts, config.access.accessor)
        if config.scan.scan_sub_accounts
        else config.scan.accounts
    )
    plan = AccountScanPlan(
        account_ids=account_ids,
        regions=config.scan.regions,
        accessor=config.access.accessor,
    )
    logger = Logger()
    logger.info(event=AWSLogEvents.ScanAWSAccountsStart)

    # Accumulators that end up in the ScanManifest.
    scanned_accounts: List[str] = []
    artifacts: List[str] = []
    errors: Dict[str, List[str]] = {}
    unscanned_accounts: List[str] = []
    stats = MultilevelCounter()
    graph_set = None

    for manifest in muxer.scan(account_scan_plan=plan):
        account_id = manifest.account_id
        if not manifest.artifacts:
            unscanned_accounts.append(account_id)
        else:
            for artifact_path in manifest.artifacts:
                artifacts.append(artifact_path)
                loaded = GraphSet.from_dict(
                    artifact_reader.read_json(artifact_path))
                if graph_set is not None:
                    graph_set.merge(loaded)
                else:
                    graph_set = loaded
            # An account that produced artifacts but also reported errors is
            # still listed as unscanned.
            if manifest.errors:
                errors[account_id] = manifest.errors
                unscanned_accounts.append(account_id)
            else:
                scanned_accounts.append(account_id)
        stats.merge(MultilevelCounter.from_dict(manifest.api_call_stats))

    if graph_set is None:
        raise Exception("BUG: No graph_set generated.")

    master_artifact_path = artifact_writer.write_json(
        name="master", data=graph_set.to_dict())
    logger.info(event=AWSLogEvents.ScanAWSAccountsEnd)
    scan_manifest = ScanManifest(
        scanned_accounts=scanned_accounts,
        master_artifact=master_artifact_path,
        artifacts=artifacts,
        errors=errors,
        unscanned_accounts=unscanned_accounts,
        api_call_stats=stats.to_dict(),
        start_time=graph_set.start_time,
        end_time=graph_set.end_time,
    )
    artifact_writer.write_json("manifest", data=scan_manifest.to_dict())
    return scan_manifest, graph_set
Exemplo n.º 9
0
 def test_from_artifact_path_filepath(self):
     """A plain filesystem path should yield a FileArtifactWriter."""
     self.assertIsInstance(
         ArtifactWriter.from_artifact_path(
             artifact_path="/file/path",
             scan_id="test-scan-id",
         ),
         FileArtifactWriter,
     )
Exemplo n.º 10
0
 def test_from_artifact_path_s3(self):
     """An s3:// path should yield an S3ArtifactWriter."""
     self.assertIsInstance(
         ArtifactWriter.from_artifact_path(
             artifact_path="s3://bucket",
             scan_id="test-scan-id",
         ),
         S3ArtifactWriter,
     )
def run_scan(
    muxer: AWSScanMuxer,
    config: AWSConfig,
    aws_resource_region_mapping_repo: AWSResourceRegionMappingRepository,
    artifact_writer: ArtifactWriter,
    artifact_reader: ArtifactReader,
) -> Tuple[ScanManifest, ValidatedGraphSet]:
    """Scan the configured accounts and build a validated graph set.

    When ``config.scan.accounts`` is empty, the account of the current STS
    caller identity is scanned instead. The per-account artifacts yielded by
    ``muxer.scan`` are combined into one ValidatedGraphSet, and a ScanManifest
    is written as "manifest".

    Args:
        muxer: scan muxer that yields per-account scan manifests.
        config: scan configuration.
        aws_resource_region_mapping_repo: passed through to the ScanPlan.
        artifact_writer: used to write the optional master artifact and the
            manifest.
        artifact_reader: used to read each per-account artifact.

    Returns:
        Tuple of (scan manifest, validated graph set).

    Raises:
        Exception: if no graph sets were read from any artifact.
    """
    if not config.scan.accounts:
        # No accounts configured: fall back to the caller's own account.
        caller_identity = boto3.client("sts").get_caller_identity()
        scan_account_ids = (caller_identity["Account"],)
    else:
        scan_account_ids = config.scan.accounts
    account_ids = (
        get_sub_account_ids(scan_account_ids, config.accessor)
        if config.scan.scan_sub_accounts
        else scan_account_ids
    )
    plan = ScanPlan(
        account_ids=account_ids,
        regions=config.scan.regions,
        aws_resource_region_mapping_repo=aws_resource_region_mapping_repo,
        accessor=config.accessor,
    )
    logger = Logger()
    logger.info(event=AWSLogEvents.ScanAWSAccountsStart)

    # Accumulators that end up in the ScanManifest.
    scanned_accounts: List[str] = []
    artifacts: List[str] = []
    errors: Dict[str, List[str]] = {}
    unscanned_accounts: Set[str] = set()
    graph_sets: List[GraphSet] = []

    for manifest in muxer.scan(scan_plan=plan):
        account_id = manifest.account_id
        if manifest.errors:
            errors[account_id] = manifest.errors
            unscanned_accounts.add(account_id)
        # NOTE(review): an account with both errors and artifacts ends up in
        # scanned_accounts and unscanned_accounts -- confirm this is intended.
        if not manifest.artifacts:
            unscanned_accounts.add(account_id)
            continue
        for artifact_path in manifest.artifacts:
            artifacts.append(artifact_path)
            raw_graph_set = artifact_reader.read_json(artifact_path)
            graph_sets.append(GraphSet.parse_obj(raw_graph_set))
        scanned_accounts.append(account_id)

    if not graph_sets:
        raise Exception("BUG: No graph_sets generated.")

    combined_graph_set = GraphSet.from_graph_sets(graph_sets)
    validated_graph_set = ValidatedGraphSet.from_graph_set(combined_graph_set)
    # The master JSON is only written when the config asks for it.
    master_artifact_path: Optional[str] = (
        artifact_writer.write_json(name="master", data=validated_graph_set)
        if config.write_master_json
        else None
    )
    logger.info(event=AWSLogEvents.ScanAWSAccountsEnd)
    scan_manifest = ScanManifest(
        scanned_accounts=scanned_accounts,
        master_artifact=master_artifact_path,
        artifacts=artifacts,
        errors=errors,
        unscanned_accounts=list(unscanned_accounts),
        start_time=validated_graph_set.start_time,
        end_time=validated_graph_set.end_time,
    )
    artifact_writer.write_json("manifest", data=scan_manifest)
    return scan_manifest, validated_graph_set
Exemplo n.º 12
0
def scan(
    muxer: AWSScanMuxer,
    account_ids: List[str],
    regions: List[str],
    accessor: Accessor,
    artifact_writer: ArtifactWriter,
    artifact_reader: ArtifactReader,
    scan_sub_accounts: bool = False,
) -> ScanManifest:
    """Scan accounts, merge their artifacts and return a ScanManifest.

    Args:
        muxer: scan muxer that runs the account scan plans.
        account_ids: accounts to scan; expanded through
            ``get_sub_account_ids`` when ``scan_sub_accounts`` is true.
        regions: regions passed to ``build_account_scan_plans``.
        accessor: accessor passed to the plan builder and sub-account lookup.
        artifact_writer: used to write the merged "master" artifact.
        artifact_reader: used to read each per-account artifact.
        scan_sub_accounts: whether to expand ``account_ids`` to sub-accounts.

    Returns:
        ScanManifest summarizing the scan. ``master_artifact``, ``start_time``
        and ``end_time`` are None when no truthy merged graph set exists.
    """
    if scan_sub_accounts:
        account_ids = get_sub_account_ids(account_ids, accessor)
    plans = build_account_scan_plans(
        accessor=accessor,
        account_ids=account_ids,
        regions=regions,
    )
    logger = Logger()
    logger.info(event=AWSLogEvents.ScanAWSAccountsStart)
    account_scan_manifests = muxer.scan(account_scan_plans=plans)

    # Accumulators that end up in the ScanManifest.
    # NOTE(review): the two account lists are annotated as lists of dicts but
    # receive account_id values -- confirm the intended element type.
    scanned_accounts: List[Dict[str, str]] = []
    artifacts: List[str] = []
    errors: Dict[str, List[str]] = {}
    unscanned_accounts: List[Dict[str, str]] = []
    stats = MultilevelCounter()

    for manifest in account_scan_manifests:
        account_id = manifest.account_id
        if not manifest.artifacts:
            unscanned_accounts.append(account_id)
        else:
            artifacts.extend(manifest.artifacts)
            # An account that produced artifacts but also reported errors is
            # still listed as unscanned.
            if manifest.errors:
                errors[account_id] = manifest.errors
                unscanned_accounts.append(account_id)
            else:
                scanned_accounts.append(account_id)
        stats.merge(MultilevelCounter.from_dict(manifest.api_call_stats))

    # Fold every artifact into a single graph set.
    graph_set = None
    for artifact_path in artifacts:
        loaded = GraphSet.from_dict(artifact_reader.read_artifact(artifact_path))
        if graph_set is not None:
            graph_set.merge(loaded)
        else:
            graph_set = loaded

    master_artifact_path = (
        artifact_writer.write_artifact(name="master", data=graph_set.to_dict())
        if graph_set
        else None
    )
    logger.info(event=AWSLogEvents.ScanAWSAccountsEnd)
    start_time, end_time = (
        (graph_set.start_time, graph_set.end_time) if graph_set else (None, None)
    )
    return ScanManifest(
        scanned_accounts=scanned_accounts,
        master_artifact=master_artifact_path,
        artifacts=artifacts,
        errors=errors,
        unscanned_accounts=unscanned_accounts,
        api_call_stats=stats.to_dict(),
        start_time=start_time,
        end_time=end_time,
    )