Example #1
 def lambda_handler(cls, event: Dict[str, Any], _: Any) -> None:
     """lambda entrypoint"""
     config = Config()
     result = Result(**event)
     logger = Logger()
     errors: List[str] = []
     with logger.bind(result=result):
         logger.info(event=QJLogEvents.ResultRemediationStart)
         try:
             session = get_assumed_session(
                 account_id=result.account_id,
                 role_name=config.remediator_target_role_name,
                 external_id=config.remediator_target_role_external_id,
             )
             cls.remediate(session=session,
                           result=result.result,
                           dry_run=config.dry_run)
             logger.info(event=QJLogEvents.ResultRemediationSuccessful)
         except Exception as ex:
             logger.error(event=QJLogEvents.ResultRemediationFailed,
                          error=str(ex))
             errors.append(str(ex))
     if errors:
         raise RemediationError(
             f"Errors found during remediation: {errors}")
Example #2
def scan_scan_unit(scan_unit: ScanUnit) -> Tuple[str, Dict[str, Any]]:
    logger = Logger()
    with logger.bind(
            account_id=scan_unit.account_id,
            region=scan_unit.region_name,
            service=scan_unit.service,
            resource_classes=sorted([
                resource_spec_class.__name__
                for resource_spec_class in scan_unit.resource_spec_classes
            ]),
    ):
        start_t = time.time()
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceStart)
        session = boto3.Session(
            aws_access_key_id=scan_unit.access_key,
            aws_secret_access_key=scan_unit.secret_key,
            aws_session_token=scan_unit.token,
            region_name=scan_unit.region_name,
        )
        scan_accessor = AWSAccessor(session=session,
                                    account_id=scan_unit.account_id,
                                    region_name=scan_unit.region_name)
        graph_spec = GraphSpec(
            name=scan_unit.graph_name,
            version=scan_unit.graph_version,
            resource_spec_classes=scan_unit.resource_spec_classes,
            scan_accessor=scan_accessor,
        )
        start_time = int(time.time())
        resources: List[Resource] = []
        errors = []
        try:
            resources = graph_spec.scan()
        except Exception as ex:
            error_str = str(ex)
            trace_back = traceback.format_exc()
            logger.error(event=AWSLogEvents.ScanAWSAccountError,
                         error=error_str,
                         trace_back=trace_back)
            error = f"{str(ex)}\n{trace_back}"
            errors.append(error)
        end_time = int(time.time())
        graph_set = GraphSet(
            name=scan_unit.graph_name,
            version=scan_unit.graph_version,
            start_time=start_time,
            end_time=end_time,
            resources=resources,
            errors=errors,
            stats=scan_accessor.api_call_stats,
        )
        end_t = time.time()
        elapsed_sec = end_t - start_t
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceEnd,
                    elapsed_sec=elapsed_sec)
        return (scan_unit.account_id, graph_set.to_dict())
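A construction sketch for the scan_scan_unit entry point above. ScanUnit is assumed here to be a simple container exposing the attributes the function reads; the credential values and the resource spec class are placeholders.

scan_unit = ScanUnit(
    graph_name="alti",
    graph_version="2",
    account_id="123456789012",
    region_name="us-east-1",
    service="ec2",
    access_key="AKIA...",  # placeholder temporary credentials for the target account
    secret_key="...",
    token="...",
    resource_spec_classes=(VPCResourceSpec,),  # hypothetical AWSResourceSpec subclass
)
account_id, graph_set_dict = scan_scan_unit(scan_unit)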
Example #3
def lambda_handler(event, context):
    host = get_required_lambda_env_var("NEPTUNE_HOST")
    port = get_required_lambda_env_var("NEPTUNE_PORT")
    region = get_required_lambda_env_var("NEPTUNE_REGION")
    max_age_min = get_required_lambda_env_var("MAX_AGE_MIN")
    graph_name = get_required_lambda_env_var("GRAPH_NAME")
    try:
        max_age_min = int(max_age_min)
    except ValueError as ve:
        raise Exception(f"env var MAX_AGE_MIN must be an int: {ve}")
    now = int(datetime.now().timestamp())
    oldest_acceptable_graph_epoch = now - max_age_min * 60

    endpoint = NeptuneEndpoint(host=host, port=port, region=region)
    client = AltimeterNeptuneClient(max_age_min=max_age_min,
                                    neptune_endpoint=endpoint)
    logger = Logger()

    uncleared = []

    # first prune metadata - if clears below are partial we want to make sure no clients
    # consider this a valid graph still.
    logger.info(event=LogEvent.PruneNeptuneMetadataGraphStart)
    client.clear_old_graph_metadata(name=graph_name, max_age_min=max_age_min)
    logger.info(event=LogEvent.PruneNeptuneMetadataGraphEnd)
    # now clear actual graphs
    with logger.bind(neptune_endpoint=endpoint):
        logger.info(event=LogEvent.PruneNeptuneGraphsStart)
        for graph_metadata in client.get_graph_metadatas(name=graph_name):
            assert graph_metadata.name == graph_name
            graph_epoch = graph_metadata.end_time
            with logger.bind(graph_uri=graph_metadata.uri,
                             graph_epoch=graph_epoch):
                if graph_epoch < oldest_acceptable_graph_epoch:
                    logger.info(event=LogEvent.PruneNeptuneGraphStart)
                    try:
                        client.clear_graph(graph_uri=graph_metadata.uri)
                        logger.info(event=LogEvent.PruneNeptuneGraphEnd)
                    except Exception as ex:
                        logger.error(
                            event=LogEvent.PruneNeptuneGraphError,
                            msg=f"Error pruning graph {graph_metadata.uri}: {ex}",
                        )
                        uncleared.append(graph_metadata.uri)
                        continue
                else:
                    logger.info(event=LogEvent.PruneNeptuneGraphSkip)
        logger.info(event=LogEvent.PruneNeptuneGraphsEnd)
        if uncleared:
            msg = f"Errors were found pruning {uncleared}."
            logger.error(event=LogEvent.PruneNeptuneGraphsError, msg=msg)
            raise Exception(msg)
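A minimal local-invocation sketch for the pruning handler above. The environment variable names are taken from the code; the values shown are placeholders.

import os

os.environ["NEPTUNE_HOST"] = "my-cluster.cluster-xxxx.us-east-1.neptune.amazonaws.com"  # placeholder
os.environ["NEPTUNE_PORT"] = "8182"
os.environ["NEPTUNE_REGION"] = "us-east-1"
os.environ["MAX_AGE_MIN"] = "1440"  # prune graphs older than 24 hours
os.environ["GRAPH_NAME"] = "alti"   # placeholder graph name

lambda_handler(event={}, context=None)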
Example #4
def remediator(event: Dict[str, Any]) -> None:
    """Run the remediation lambda for a QJ result set"""
    config = RemediatorConfig()
    logger = Logger()
    remediation = Remediation(**event)
    with logger.bind(remediation=remediation):
        logger.info(event=QJLogEvents.RemediationInit)
        qj_api_client = QJAPIClient(host=config.qj_api_host)
        latest_result_set = qj_api_client.get_job_latest_result_set(
            job_name=remediation.job_name)
        if not latest_result_set:
            msg = f"No latest_result_set present for {remediation.job_name}"
            logger.error(QJLogEvents.StaleResultSet, detail=msg)
            raise RemediationError(msg)
        if latest_result_set.result_set_id != remediation.result_set_id:
            msg = (
                f"Remediation result_set_id {remediation.result_set_id} does not match the "
                f"latest result_set_id {latest_result_set.result_set_id}")
            logger.error(QJLogEvents.StaleResultSet, detail=msg)
            raise RemediationError(msg)
        if not latest_result_set.job.remediate_sqs_queue:
            msg = f"Job {latest_result_set.job.name} has no remediator"
            logger.error(QJLogEvents.JobHasNoRemediator, detail=msg)
            raise RemediationError(msg)
        num_threads = 10  # TODO env var
        errors = []
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = []
            for result in latest_result_set.results:
                logger.info(event=QJLogEvents.ProcessResult, result=result)
                future = _schedule_result_remediation(
                    executor=executor,
                    lambda_name=latest_result_set.job.remediate_sqs_queue,
                    lambda_timeout=300,  # TODO env var?
                    result=result,
                )
                futures.append(future)
            for future in as_completed(futures):
                try:
                    lambda_result = future.result()
                    logger.info(QJLogEvents.ResultRemediationSuccessful,
                                lambda_result=lambda_result)
                except Exception as ex:
                    logger.info(
                        event=QJLogEvents.ResultSetRemediationFailed,
                        error=str(ex),
                    )
                    errors.append(str(ex))
        if errors:
            logger.error(event=QJLogEvents.ResultSetRemediationFailed,
                         errors=errors)
            raise RemediationError(
                f"Errors encountered during remediation of {latest_result_set.job.name} "
                f"/ {latest_result_set.result_set_id}: {errors}")
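An illustrative invocation of the remediator above. The job_name and result_set_id keys mirror the Remediation fields the function reads; the values are placeholders.

event = {
    "job_name": "unused-security-groups",  # placeholder QJ job name
    "result_set_id": "4f6c2d0e",           # must match the job's latest result_set_id
}
remediator(event)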
Example #5
    def scan(self) -> Dict[str, Any]:
        """Scan an account and return a dict containing keys:

            * account_id: str
            * output_artifact: str
            * api_call_stats: Dict[str, Any]
            * errors: List[str]

        If errors is non-empty the results are incomplete for this account.
        output_artifact is a pointer to the actual scan data - either on the local fs or in s3.

        To scan an account we create a set of GraphSpecs, one for each region.  Any ACCOUNT
        level granularity resources are only scanned in a single region (e.g. IAM Users)

        Returns:
            Dict of scan result, see above for details.
        """
        logger = Logger()
        with logger.bind(account_id=self.account_id):
            logger.info(event=AWSLogEvents.ScanAWSAccountStart)
            output_artifact = None
            stats = MultilevelCounter()
            errors: List[str] = []
            now = int(time.time())
            try:
                account_graph_set = GraphSet(
                    name=self.graph_name,
                    version=self.graph_version,
                    start_time=now,
                    end_time=now,
                    resources=[],
                    errors=[],
                    stats=stats,
                )
                # sanity check
                session = self.get_session()
                sts_client = session.client("sts")
                sts_account_id = sts_client.get_caller_identity()["Account"]
                if sts_account_id != self.account_id:
                    raise ValueError(
                        f"BUG: sts detected account_id {sts_account_id} != {self.account_id}"
                    )
                if self.regions:
                    scan_regions = tuple(self.regions)
                else:
                    scan_regions = get_all_enabled_regions(session=session)
                # build graph specs.
                # build a dict of regions -> services -> List[AWSResourceSpec]
                regions_services_resource_spec_classes: DefaultDict[
                    str,
                    DefaultDict[str,
                                List[Type[AWSResourceSpec]]]] = defaultdict(
                                    lambda: defaultdict(list))
                resource_spec_class: Type[AWSResourceSpec]
                for resource_spec_class in self.resource_spec_classes:
                    client_name = resource_spec_class.get_client_name()
                    resource_class_scan_granularity = resource_spec_class.scan_granularity
                    if resource_class_scan_granularity == ScanGranularity.ACCOUNT:
                        regions_services_resource_spec_classes[scan_regions[
                            0]][client_name].append(resource_spec_class)
                    elif resource_class_scan_granularity == ScanGranularity.REGION:
                        for region in scan_regions:
                            regions_services_resource_spec_classes[region][
                                client_name].append(resource_spec_class)
                    else:
                        raise NotImplementedError(
                            f"ScanGranularity {resource_class_scan_granularity} not implemented"
                        )
                with ThreadPoolExecutor(
                        max_workers=self.max_svc_threads) as executor:
                    futures = []
                    for (
                            region,
                            services_resource_spec_classes,
                    ) in regions_services_resource_spec_classes.items():
                        for (
                                service,
                                resource_spec_classes,
                        ) in services_resource_spec_classes.items():
                            region_session = self.get_session(region=region)
                            region_creds = region_session.get_credentials()
                            scan_future = schedule_scan_services(
                                executor=executor,
                                graph_name=self.graph_name,
                                graph_version=self.graph_version,
                                account_id=self.account_id,
                                region=region,
                                service=service,
                                access_key=region_creds.access_key,
                                secret_key=region_creds.secret_key,
                                token=region_creds.token,
                                resource_spec_classes=tuple(
                                    resource_spec_classes),
                            )
                            futures.append(scan_future)
                    for future in as_completed(futures):
                        graph_set_dict = future.result()
                        graph_set = GraphSet.from_dict(graph_set_dict)
                        errors += graph_set.errors
                        account_graph_set.merge(graph_set)
                    account_graph_set.validate()
            except Exception as ex:
                error_str = str(ex)
                trace_back = traceback.format_exc()
                logger.error(event=AWSLogEvents.ScanAWSAccountError,
                             error=error_str,
                             trace_back=trace_back)
                errors.append(" : ".join((error_str, trace_back)))
                unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                    account_id=self.account_id, errors=errors)
                account_graph_set = GraphSet(
                    name=self.graph_name,
                    version=self.graph_version,
                    start_time=now,
                    end_time=now,
                    resources=[unscanned_account_resource],
                    errors=errors,
                    stats=stats,
                )
                account_graph_set.validate()
        output_artifact = self.artifact_writer.write_artifact(
            name=self.account_id, data=account_graph_set.to_dict())
        logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
        api_call_stats = account_graph_set.stats.to_dict()
        return {
            "account_id": self.account_id,
            "output_artifact": output_artifact,
            "errors": errors,
            "api_call_stats": api_call_stats,
        }
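A consumption sketch for the dict returned by scan() above. The key names come from the docstring; account_scanner stands in for an already constructed instance of the class this method belongs to.

scan_result = account_scanner.scan()
if scan_result["errors"]:
    print(f"Scan of account {scan_result['account_id']} is incomplete: {scan_result['errors']}")
print(f"Scan data written to {scan_result['output_artifact']}")
print(f"API call stats: {scan_result['api_call_stats']}")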
Example #6
 def scan(self) -> List[Dict[str, Any]]:
     logger = Logger()
     scan_result_dicts = []
     now = int(time.time())
     prescan_account_ids_errors: DefaultDict[str,
                                             List[str]] = defaultdict(list)
     futures = []
     with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
         shuffled_account_ids = random.sample(
             self.account_scan_plan.account_ids,
             k=len(self.account_scan_plan.account_ids))
         for account_id in shuffled_account_ids:
             with logger.bind(account_id=account_id):
                 logger.info(event=AWSLogEvents.ScanAWSAccountStart)
                 try:
                     session = self.account_scan_plan.accessor.get_session(
                         account_id=account_id)
                     # sanity check
                     sts_client = session.client("sts")
                     sts_account_id = sts_client.get_caller_identity(
                     )["Account"]
                     if sts_account_id != account_id:
                         raise ValueError(
                             f"BUG: sts detected account_id {sts_account_id} != {account_id}"
                         )
                     if self.account_scan_plan.regions:
                         scan_regions = tuple(
                             self.account_scan_plan.regions)
                     else:
                         scan_regions = get_all_enabled_regions(
                             session=session)
                     account_gran_scan_region = random.choice(
                         self.preferred_account_scan_regions)
                     # build a dict of regions -> services -> List[AWSResourceSpec]
                     regions_services_resource_spec_classes: DefaultDict[
                         str, DefaultDict[
                             str,
                             List[Type[AWSResourceSpec]]]] = defaultdict(
                                 lambda: defaultdict(list))
                     resource_spec_class: Type[AWSResourceSpec]
                     for resource_spec_class in self.resource_spec_classes:
                         client_name = resource_spec_class.get_client_name()
                         if resource_spec_class.scan_granularity == ScanGranularity.ACCOUNT:
                             if resource_spec_class.region_whitelist:
                                 account_resource_scan_region = resource_spec_class.region_whitelist[
                                     0]
                             else:
                                 account_resource_scan_region = account_gran_scan_region
                             regions_services_resource_spec_classes[
                                 account_resource_scan_region][
                                     client_name].append(
                                         resource_spec_class)
                         elif resource_spec_class.scan_granularity == ScanGranularity.REGION:
                             if resource_spec_class.region_whitelist:
                                 resource_scan_regions = tuple(
                                     region for region in scan_regions
                                     if region in
                                     resource_spec_class.region_whitelist)
                                 if not resource_scan_regions:
                                     resource_scan_regions = resource_spec_class.region_whitelist
                             else:
                                 resource_scan_regions = scan_regions
                             for region in resource_scan_regions:
                                 regions_services_resource_spec_classes[
                                     region][client_name].append(
                                         resource_spec_class)
                         else:
                             raise NotImplementedError(
                                 f"ScanGranularity {resource_spec_class.scan_granularity} unimplemented"
                             )
                     # Build and submit ScanUnits
                     shuffled_regions_services_resource_spec_classes = random.sample(
                         regions_services_resource_spec_classes.items(),
                         len(regions_services_resource_spec_classes),
                     )
                     for (
                             region,
                             services_resource_spec_classes,
                     ) in shuffled_regions_services_resource_spec_classes:
                         region_session = self.account_scan_plan.accessor.get_session(
                             account_id=account_id, region_name=region)
                         region_creds = region_session.get_credentials()
                         shuffled_services_resource_spec_classes = random.sample(
                             services_resource_spec_classes.items(),
                             len(services_resource_spec_classes),
                         )
                         for (
                                 service,
                                 svc_resource_spec_classes,
                         ) in shuffled_services_resource_spec_classes:
                             future = schedule_scan(
                                 executor=executor,
                                 graph_name=self.graph_name,
                                 graph_version=self.graph_version,
                                 account_id=account_id,
                                 region_name=region,
                                 service=service,
                                 access_key=region_creds.access_key,
                                 secret_key=region_creds.secret_key,
                                 token=region_creds.token,
                                 resource_spec_classes=tuple(
                                     svc_resource_spec_classes),
                             )
                             futures.append(future)
                 except Exception as ex:
                     error_str = str(ex)
                     trace_back = traceback.format_exc()
                     logger.error(
                         event=AWSLogEvents.ScanAWSAccountError,
                         error=error_str,
                         trace_back=trace_back,
                     )
                     prescan_account_ids_errors[account_id].append(
                         f"{error_str}\n{trace_back}")
     account_ids_graph_set_dicts: Dict[str,
                                       List[Dict[str,
                                                 Any]]] = defaultdict(list)
     for future in as_completed(futures):
         account_id, graph_set_dict = future.result()
         account_ids_graph_set_dicts[account_id].append(graph_set_dict)
     # first make sure no account id appears both in account_ids_graph_set_dicts
     # and prescan_account_ids_errors - this should never happen
     doubled_accounts = set(
         account_ids_graph_set_dicts.keys()).intersection(
             set(prescan_account_ids_errors.keys()))
     if doubled_accounts:
         raise Exception((
             f"BUG: Account(s) {doubled_accounts} in both "
             "account_ids_graph_set_dicts and prescan_account_ids_errors."))
     # graph prescan error accounts
     for account_id, errors in prescan_account_ids_errors.items():
         with logger.bind(account_id=account_id):
             unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                 account_id=account_id, errors=errors)
             account_graph_set = GraphSet(
                 name=self.graph_name,
                 version=self.graph_version,
                 start_time=now,
                 end_time=now,
                 resources=[unscanned_account_resource],
                 errors=errors,
                 stats=MultilevelCounter(),
             )
             account_graph_set.validate()
             output_artifact = self.artifact_writer.write_json(
                 name=account_id, data=account_graph_set.to_dict())
             logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
             api_call_stats = account_graph_set.stats.to_dict()
             scan_result_dicts.append({
                 "account_id": account_id,
                 "output_artifact": output_artifact,
                 "errors": errors,
                 "api_call_stats": api_call_stats,
             })
     # graph rest
     for account_id, graph_set_dicts in account_ids_graph_set_dicts.items():
         with logger.bind(account_id=account_id):
             # if there are any errors whatsoever we generate an empty graph with
             # errors only
             errors = []
             for graph_set_dict in graph_set_dicts:
                 errors += graph_set_dict["errors"]
             if errors:
                 unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                     account_id=account_id, errors=errors)
                 account_graph_set = GraphSet(
                     name=self.graph_name,
                     version=self.graph_version,
                     start_time=now,
                     end_time=now,
                     resources=[unscanned_account_resource],
                     errors=errors,
                     stats=MultilevelCounter(),  # ENHANCEMENT: could technically get partial stats.
                 )
                 account_graph_set.validate()
             else:
                 account_graph_set = GraphSet(
                     name=self.graph_name,
                     version=self.graph_version,
                     start_time=now,
                     end_time=now,
                     resources=[],
                     errors=[],
                     stats=MultilevelCounter(),
                 )
                 for graph_set_dict in graph_set_dicts:
                     graph_set = GraphSet.from_dict(graph_set_dict)
                     account_graph_set.merge(graph_set)
             output_artifact = self.artifact_writer.write_json(
                 name=account_id, data=account_graph_set.to_dict())
             logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
             api_call_stats = account_graph_set.stats.to_dict()
             scan_result_dicts.append({
                 "account_id": account_id,
                 "output_artifact": output_artifact,
                 "errors": errors,
                 "api_call_stats": api_call_stats,
             })
     return scan_result_dicts
Example #7
 def scan(self) -> AccountScanResult:
     logger = Logger()
     now = int(time.time())
     prescan_errors: List[str] = []
     futures: List[Future] = []
     account_id = self.account_scan_plan.account_id
     with logger.bind(account_id=account_id):
         with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
             logger.info(event=AWSLogEvents.ScanAWSAccountStart)
             try:
                 session = self.account_scan_plan.accessor.get_session(
                     account_id=account_id)
                 # sanity check
                 sts_client = session.client("sts")
                 sts_account_id = sts_client.get_caller_identity(
                 )["Account"]
                 if sts_account_id != account_id:
                     raise ValueError(
                         f"BUG: sts detected account_id {sts_account_id} != {account_id}"
                     )
                 if self.account_scan_plan.regions:
                     account_scan_regions = tuple(
                         self.account_scan_plan.regions)
                 else:
                     account_scan_regions = get_all_enabled_regions(
                         session=session)
                 # build a dict of regions -> services -> List[AWSResourceSpec]
                 regions_services_resource_spec_classes: DefaultDict[
                     str, DefaultDict[
                         str, List[Type[AWSResourceSpec]]]] = defaultdict(
                             lambda: defaultdict(list))
                 for resource_spec_class in self.resource_spec_classes:
                     resource_regions = self.account_scan_plan.aws_resource_region_mapping_repo.get_regions(
                         resource_spec_class=resource_spec_class,
                         region_whitelist=account_scan_regions,
                     )
                     for region in resource_regions:
                         regions_services_resource_spec_classes[region][
                             resource_spec_class.service_name].append(
                                 resource_spec_class)
                 # Build and submit ScanUnits
                 shuffled_regions_services_resource_spec_classes = random.sample(
                     regions_services_resource_spec_classes.items(),
                     len(regions_services_resource_spec_classes),
                 )
                 for (
                         region,
                         services_resource_spec_classes,
                 ) in shuffled_regions_services_resource_spec_classes:
                     region_session = self.account_scan_plan.accessor.get_session(
                         account_id=account_id, region_name=region)
                     region_creds = region_session.get_credentials()
                     shuffled_services_resource_spec_classes = random.sample(
                         services_resource_spec_classes.items(),
                         len(services_resource_spec_classes),
                     )
                     for (
                             service,
                             svc_resource_spec_classes,
                     ) in shuffled_services_resource_spec_classes:
                         parallel_svc_resource_spec_classes = [
                             svc_resource_spec_class
                             for svc_resource_spec_class in
                             svc_resource_spec_classes
                             if svc_resource_spec_class.parallel_scan
                         ]
                         serial_svc_resource_spec_classes = [
                             svc_resource_spec_class
                             for svc_resource_spec_class in
                             svc_resource_spec_classes
                             if not svc_resource_spec_class.parallel_scan
                         ]
                         for (parallel_svc_resource_spec_class
                              ) in parallel_svc_resource_spec_classes:
                             parallel_future = schedule_scan(
                                 executor=executor,
                                 graph_name=self.graph_name,
                                 graph_version=self.graph_version,
                                 account_id=account_id,
                                 region_name=region,
                                 service=service,
                                 access_key=region_creds.access_key,
                                 secret_key=region_creds.secret_key,
                                 token=region_creds.token,
                                 resource_spec_classes=(
                                     parallel_svc_resource_spec_class, ),
                             )
                             futures.append(parallel_future)
                         serial_future = schedule_scan(
                             executor=executor,
                             graph_name=self.graph_name,
                             graph_version=self.graph_version,
                             account_id=account_id,
                             region_name=region,
                             service=service,
                             access_key=region_creds.access_key,
                             secret_key=region_creds.secret_key,
                             token=region_creds.token,
                             resource_spec_classes=tuple(
                                 serial_svc_resource_spec_classes),
                         )
                         futures.append(serial_future)
             except Exception as ex:
                 error_str = str(ex)
                 trace_back = traceback.format_exc()
                 logger.error(
                     event=AWSLogEvents.ScanAWSAccountError,
                     error=error_str,
                     trace_back=trace_back,
                 )
                 prescan_errors.append(f"{error_str}\n{trace_back}")
         graph_sets: List[GraphSet] = []
         for future in as_completed(futures):
             graph_set = future.result()
             graph_sets.append(graph_set)
         # if there was a prescan error graph it and return the result
         if prescan_errors:
             unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                 account_id=account_id, errors=prescan_errors)
             account_graph_set = GraphSet(
                 name=self.graph_name,
                 version=self.graph_version,
                 start_time=now,
                 end_time=now,
                 resources=[unscanned_account_resource],
                 errors=prescan_errors,
             )
             output_artifact = self.artifact_writer.write_json(
                 name=account_id,
                 data=account_graph_set,
             )
             logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
             return AccountScanResult(
                 account_id=account_id,
                 artifacts=[output_artifact],
                 errors=prescan_errors,
             )
         # if there are any errors whatsoever we generate an empty graph with errors only
         errors = []
         for graph_set in graph_sets:
             errors += graph_set.errors
         if errors:
             unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                 account_id=account_id, errors=errors)
             account_graph_set = GraphSet(
                 name=self.graph_name,
                 version=self.graph_version,
                 start_time=now,
                 end_time=now,
                 resources=[unscanned_account_resource],
                 errors=errors,
             )
         else:
             account_graph_set = GraphSet.from_graph_sets(graph_sets)
         output_artifact = self.artifact_writer.write_json(
             name=account_id,
             data=account_graph_set,
         )
         logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
         return AccountScanResult(
             account_id=account_id,
             artifacts=[output_artifact],
             errors=errors,
         )
class AltimeterNeptuneClient:
    """Client to run SPARQL queries against a Neptune instance, using graph name conventions to
    determine the most recent graph.

    Args:
        max_age_min: maximum acceptable age in minutes of graphs.  Only graphs which are found
                     that meet this criterion will be queried.
        neptune_endpoint: NeptuneEndpoint object for this client
    """

    def __init__(self, max_age_min: int, neptune_endpoint: NeptuneEndpoint):
        self._neptune_endpoint = neptune_endpoint
        self._max_age_min = max_age_min
        self._auth = None
        # initially set this to a time in the past such that _get_auth's logic is simpler
        # regarding first run.
        self._auth_expiration = datetime.now() - timedelta(hours=24)
        self.logger = Logger()

    def run_query(self, graph_names: Set[str], query: str) -> QueryResult:
        """Runs a SPARQL query against the latest available graphs given a list of
        graph names.

        Args:
            graph_names: set of graph names to query
            query: query string. This query string should not include any 'from' clause;
                   the graph_names param will be used to inject the correct graph uris
                   by locating the latest acceptable (based on `max_age_min`) graph.

        Returns:
            QueryResult object
        """
        graph_uris_load_times: Dict[str, int] = {}
        for graph_name in graph_names:
            graph_metadata = self._get_latest_graph_metadata(name=graph_name)
            graph_uris_load_times[graph_metadata.uri] = graph_metadata.end_time
        finalized_query = finalize_query(query, graph_uris=list(graph_uris_load_times.keys()))
        query_result_set = self.run_raw_query(finalized_query)
        return QueryResult(graph_uris_load_times, query_result_set)
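
    # Illustrative usage of this client (an assumption, not part of this listing):
    #
    #     endpoint = NeptuneEndpoint(host="neptune.example.com", port=8182, region="us-east-1")
    #     client = AltimeterNeptuneClient(max_age_min=1440, neptune_endpoint=endpoint)
    #     result = client.run_query(
    #         graph_names={"alti"},
    #         query="SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10",
    #     )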

    def load_graph(self, bucket: str, key: str, load_iam_role_arn: str) -> GraphMetadata:
        """Load a graph into Neptune.
        Args:
             bucket: s3 bucket of graph rdf
             key: s3 key of graph rdf
             load_iam_role_arn: arn of iam role used to load the graph

        Returns:
            GraphMetadata object describing loaded graph

        Raises:
            NeptuneLoadGraphException if errors occur during graph load
        """
        session = boto3.Session(region_name=self._neptune_endpoint.region)
        s3_client = session.client("s3")
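        # Illustrative expectation (an assumption): the RDF object carries S3 object tags such as
        # name=alti, version=2, start_time=1617735000 and end_time=1617735600; get_object_tagging
        # below returns them and get_required_tag_value extracts each required value.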
        rdf_object_tagging = s3_client.get_object_tagging(Bucket=bucket, Key=key)
        tag_set = rdf_object_tagging["TagSet"]
        graph_name = get_required_tag_value(tag_set, "name")
        graph_version = get_required_tag_value(tag_set, "version")
        graph_start_time = int(get_required_tag_value(tag_set, "start_time"))
        graph_end_time = int(get_required_tag_value(tag_set, "end_time"))
        graph_metadata = GraphMetadata(
            uri=f"{GRAPH_BASE_URI}/{graph_name}/{graph_version}/{graph_end_time}",
            name=graph_name,
            version=graph_version,
            start_time=graph_start_time,
            end_time=graph_end_time,
        )
        logger = self.logger
        with logger.bind(
            rdf_bucket=bucket,
            rdf_key=key,
            graph_uri=graph_metadata.uri,
            neptune_endpoint=self._neptune_endpoint.get_endpoint_str(),
        ):
            session = boto3.Session(region_name=self._neptune_endpoint.region)
            credentials = session.get_credentials()
            auth = AWSRequestsAuth(
                aws_access_key=credentials.access_key,
                aws_secret_access_key=credentials.secret_key,
                aws_token=credentials.token,
                aws_host=self._neptune_endpoint.get_endpoint_str(),
                aws_region=self._neptune_endpoint.region,
                aws_service="neptune-db",
            )
            post_body = {
                "source": f"s3://{bucket}/{key}",
                "format": "rdfxml",
                "iamRoleArn": load_iam_role_arn,
                "region": self._neptune_endpoint.region,
                "failOnError": "TRUE",
                "parallelism": "MEDIUM",
                "parserConfiguration": {
                    "baseUri": GRAPH_BASE_URI,
                    "namedGraphUri": graph_metadata.uri,
                },
            }
            logger.info(event=LogEvent.NeptuneLoadStart, post_body=post_body)
            submit_resp = requests.post(
                self._neptune_endpoint.get_loader_endpoint(), json=post_body, auth=auth
            )
            if submit_resp.status_code != 200:
                raise NeptuneLoadGraphException(
                    f"Non 200 from Neptune: {submit_resp.status_code} : {submit_resp.text}"
                )
            submit_resp_json = submit_resp.json()
            load_id = submit_resp_json["payload"]["loadId"]
            with logger.bind(load_id=load_id):
                logger.info(event=LogEvent.NeptuneLoadPolling)
                while True:
                    time.sleep(10)
                    status_resp = requests.get(
                        f"{self._neptune_endpoint.get_loader_endpoint()}/{load_id}",
                        params={"details": "true", "errors": "true"},
                        auth=auth,
                    )
                    if status_resp.status_code != 200:
                        raise NeptuneLoadGraphException(
                            f"Non 200 from Neptune: {status_resp.status_code} : {status_resp.text}"
                        )
                    status_resp_json = status_resp.json()
                    status = status_resp_json["payload"]["overallStatus"]["status"]
                    logger.info(event=LogEvent.NeptuneLoadPolling, status=status)
                    if status == "LOAD_COMPLETED":
                        break
                    if status not in ("LOAD_NOT_STARTED", "LOAD_IN_PROGRESS"):
                        logger.error(event=LogEvent.NeptuneLoadError, status=status)
                        raise NeptuneLoadGraphException(f"Error loading graph: {status_resp_json}")
                logger.info(event=LogEvent.NeptuneLoadEnd)

                logger.info(event=LogEvent.MetadataGraphUpdateStart)
                self._register_graph(graph_metadata=graph_metadata)
                logger.info(event=LogEvent.MetadataGraphUpdateEnd)

                return graph_metadata

    def _get_auth(self) -> AWSRequestsAuth:
        """Generate an AWSRequestsAuth object using a boto session for the current/local account.

        Returns:
             AWSRequestsAuth object
        """
        if datetime.now() >= self._auth_expiration:
            session = boto3.Session()
            credentials = session.get_credentials()
            region = (
                session.region_name
                if self._neptune_endpoint.region is None
                else self._neptune_endpoint.region
            )
            auth = AWSRequestsAuth(
                aws_access_key=credentials.access_key,
                aws_secret_access_key=credentials.secret_key,
                aws_token=credentials.token,
                aws_host=f"{self._neptune_endpoint.host}:{self._neptune_endpoint.port}",
                aws_region=region,
                aws_service="neptune-db",
            )
            self._auth = auth
            self._auth_expiration = datetime.now() + timedelta(minutes=SESSION_LIFETIME_MINUTES)
        return self._auth

    def _register_graph(self, graph_metadata: GraphMetadata) -> None:
        """Registers a GraphMetadata object into the metadata graph
        The meta graph keeps track of graph uris and metadata.
        Run this after a graph is completely loaded and then use
        _get_latest_graph_metadata to query this graph to find the latest graph.

        Args:
            graph_metadata: GraphMetadata to load into the metadata graph.

        Raises:
            NeptuneUpdateGraphException if an error occurred during metadata graph update
        """

        auth = self._get_auth()
        neptune_sparql_url = self._neptune_endpoint.get_sparql_endpoint()
        update_stmt = (
            f"INSERT DATA {{\n"
            f"    GRAPH <{META_GRAPH_NAME}>\n"
            f"        {{ <alti:graph:{graph_metadata.uri}> "
            f'            <alti:uri>         "{graph_metadata.uri}" ;\n'
            f'            <alti:name>        "{graph_metadata.name}" ;\n'
            f'            <alti:version>     "{graph_metadata.version}" ;\n'
            f"            <alti:start_time>  {graph_metadata.start_time} ;\n"
            f"            <alti:end_time>    {graph_metadata.end_time} ;\n"
            f"}}\n"
            "}\n"
        )
        resp = requests.post(neptune_sparql_url, data={"update": update_stmt}, auth=auth)
        if resp.status_code != 200:
            raise NeptuneUpdateGraphException(
                (f"Error updating graph {META_GRAPH_NAME} " f"with {update_stmt} : {resp.text}")
            )

    def get_graph_uris(self, name: str) -> List[str]:
        """Return all graph uris regardless of whether they have corresponding metadata entries

        Args:
            name: graph name

        Returns:
            list of graph uris
        """
        query = "SELECT ?graph_uri WHERE { GRAPH ?graph_uri { } }"
        results = self.run_raw_query(query=query)
        results_list = results.to_list()
        all_graph_uris = [result["graph_uri"] for result in results_list]
        graph_prefix = f"{GRAPH_BASE_URI}/{name}/"
        graph_uris = [uri for uri in all_graph_uris if uri.startswith(graph_prefix)]
        return graph_uris

    def get_graph_metadatas(self, name: str, version: Optional[str] = None) -> List[GraphMetadata]:
        """Return all graph metadatas for a given name/version. These represent fully loaded
        graphs in the Neptune database.

        Args:
            name: graph name
            version: graph version

        Returns:
            list of GraphMetadata objects for the given graph name/version
        """
        if version is None:
            get_graph_metadatas_query = (
                "SELECT ?uri ?name ?version ?start_time ?end_time\n"
                f"FROM <{META_GRAPH_NAME}>\n"
                f"WHERE {{ ?graph_metadata <alti:uri> ?uri ;\n"
                f'                         <alti:name> "{name}" ;\n'
                f"                         <alti:name> ?name ;\n"
                f"                         <alti:version> ?version ;\n"
                f"                         <alti:start_time> ?start_time ;\n"
                f"                         <alti:end_time> ?end_time }}\n"
                f"ORDER BY DESC(?end_time)\n"
            )
        else:
            get_graph_metadatas_query = (
                "SELECT ?uri ?name ?version ?start_time ?end_time\n"
                f"FROM <{META_GRAPH_NAME}>\n"
                f"WHERE {{ ?graph_metadata <alti:uri> ?uri ;\n"
                f'                         <alti:name> "{name}" ;\n'
                f"                         <alti:name> ?name ;\n"
                f'                         <alti:version> "{version}" ;\n'
                f"                         <alti:version> ?version ;\n"
                f"                         <alti:start_time> ?start_time ;\n"
                f"                         <alti:end_time> ?end_time }}\n"
                f"ORDER BY DESC(?end_time)\n"
            )
        results = self.run_raw_query(query=get_graph_metadatas_query)
        results_list = results.to_list()
        graph_metadatas: List[GraphMetadata] = []
        for result in results_list:
            graph_metadata = GraphMetadata(
                uri=result["uri"],
                name=result["name"],
                version=result["version"],
                start_time=int(result["start_time"]),
                end_time=int(result["end_time"]),
            )
            graph_metadatas.append(graph_metadata)
        return graph_metadatas

    def _get_latest_graph_metadata(self, name: str, version: Optional[str] = None) -> GraphMetadata:
        """Return a GraphMetadata object representing the most recently successfully loaded graph
        for a given name / version.

        Args:
            name: graph name
            version: graph version

        Returns:
            GraphMetadata of the latest graph for the given name/version

        Raises:
            NeptuneNoGraphsFoundException if no matching graphs were found
            NeptuneNoFreshGraphFoundException if no graphs could be found within max_age_min
        """
        if version is None:
            get_graph_metadatas_query = (
                f"SELECT ?uri ?version ?start_time ?end_time\n"
                f"FROM <{META_GRAPH_NAME}>\n"
                f"WHERE {{\n"
                f"    ?graph_metadata <alti:uri> ?uri ;\n"
                f'                    <alti:name> "{name}" ;\n'
                f"                    <alti:version> ?version ;\n"
                f"                    <alti:start_time> ?start_time ;\n"
                f"                    <alti:end_time> ?end_time }}\n"
                f"ORDER BY DESC(?version) DESC(?end_time)\n"
                f"LIMIT 1"
            )
        else:
            get_graph_metadatas_query = (
                f"SELECT ?uri ?version ?start_time ?end_time\n"
                f"FROM <{META_GRAPH_NAME}>\n"
                f"WHERE {{\n"
                f"    ?graph_metadata <alti:uri> ?uri ;\n"
                f'                    <alti:name> "{name}" ;\n'
                f'                    <alti:version> "{version}" ;\n'
                f"                    <alti:version> ?version ;\n"
                f"                    <alti:start_time> ?start_time ;\n"
                f"                    <alti:end_time> ?end_time }}\n"
                f"ORDER BY DESC(?end_time)\n"
                f"LIMIT 1"
            )
        results = self.run_raw_query(query=get_graph_metadatas_query)
        results_list = results.to_list()
        if not results_list:
            raise NeptuneNoGraphsFoundException(f"No graphs found for graph name '{name}'")
        if len(results_list) != 1:
            raise RuntimeError("Logic error - more than one graph returned.")
        result = results_list[0]
        latest_uri = result["uri"]
        latest_version = result["version"]
        latest_start_time = int(result["start_time"])
        latest_end_time = int(result["end_time"])
        now = int(datetime.now().timestamp())
        oldest_acceptable_graph_end_time = now - self._max_age_min * 60
        if latest_end_time < oldest_acceptable_graph_end_time:
            raise NeptuneNoFreshGraphFoundException(
                (
                    f"Could not find a graph named '{name}' younger "
                    f"than {self._max_age_min} "
                    f"minutes old.  Found: {results_list}"
                )
            )
        return GraphMetadata(
            uri=latest_uri,
            name=name,
            version=latest_version,
            start_time=latest_start_time,
            end_time=latest_end_time,
        )

    def clear_registered_graph(self, name: str, uri: str) -> None:
        """Remove data and metadata for a graph by uri

        Args:
            name: graph name
            uri: graph uri

        Raises:
            NeptuneUpdateGraphException if an error occurred during clearing
        """
        # clear metadata first such that clients will not use this graph if
        # data clear fails
        self.clear_graph_metadata(name=name, uri=uri)
        # then clear data
        self.clear_graph_data(uri=uri)

    def clear_graph_metadata(self, name: str, uri: str) -> None:
        """Clear a graph metadata entry"""
        auth = self._get_auth()
        neptune_sparql_url = self._neptune_endpoint.get_sparql_endpoint()
        delete_stmt = (
            f"WITH <{META_GRAPH_NAME}>\n"
            f"DELETE\n"
            f'  {{ ?graph <alti:uri>         "{uri}" ;\n'
            f'            <alti:name>        "{name}" ;\n'
            f"            <alti:version>     ?version ;\n"
            f"            <alti:start_time>  ?start_time ;\n"
            f"            <alti:end_time>    ?end_time }}\n"
            f"WHERE\n"
            f'  {{ ?graph <alti:uri>         "{uri}" ;\n'
            f'            <alti:name>        "{name}" ;\n'
            f"            <alti:version>     ?version ;\n"
            f"            <alti:start_time>  ?start_time ;\n"
            f"            <alti:end_time>    ?end_time }}\n"
        )
        resp = requests.post(neptune_sparql_url, data={"update": delete_stmt}, auth=auth)
        if resp.status_code != 200:
            raise NeptuneUpdateGraphException(
                (f"Error updating graph {META_GRAPH_NAME} " f"with {delete_stmt} : {resp.text}")
            )

    def clear_graph_data(self, uri: str) -> None:
        """Clear a graph in Neptune"""
        auth = self._get_auth()
        neptune_sparql_url = self._neptune_endpoint.get_sparql_endpoint()
        update_stmt = f"clear graph <{uri}>"
        resp = requests.post(neptune_sparql_url, data={"update": update_stmt}, auth=auth)
        if resp.status_code != 200:
            raise NeptuneClearGraphException(
                (f"Error clearing graph {uri} " f"with {update_stmt} : {resp.text}")
            )

    def run_raw_query(self, query: str) -> QueryResultSet:
        """Run a query against a Neptune instance and return a QueryResultSet. Generally this
        should be called from `run_query`

        Args:
            query: complete query to run

        Returns:
            QueryResultSet object

        Raises:
            NeptuneQueryException if an error occurred running the query
        """
        neptune_sparql_url = self._neptune_endpoint.get_sparql_endpoint()
        auth = self._get_auth()
        resp = requests.post(
            neptune_sparql_url, headers={"te": "trailers"}, data={"query": query}, auth=auth
        )
        if resp.status_code != 200:
            raise NeptuneQueryException(f"Error running query {query}: {resp.text}")
        try:
            results_json = resp.json()
        except json.decoder.JSONDecodeError as jde:
            neptune_status = resp.headers.get("X-Neptune-Status", "unknown")
            neptune_detail = resp.headers.get("X-Neptune-Detail", "unknown")
            raise NeptuneQueryException(
                f"Error running query {query}: {neptune_status}: {neptune_detail}"
            ) from jde
        return QueryResultSet.from_sparql_endpoint_json(results_json)

    @staticmethod
    def __normalize_query_string(query: str) -> str:
        """Normalize the query string"""
        kv = (list(map(str.strip, s.split("="))) for s in query.split("&") if len(s) > 0)
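        # Worked example (illustrative): for query "b=2&a=1", kv yields [["b", "2"], ["a", "1"]]
        # and the sorted join below produces "a=1&b=2".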

        normalized = "&".join("%s=%s" % (p[0], p[1] if len(p) > 1 else "") for p in sorted(kv))
        return normalized

    def __get_signature_key(
        self, key: str, datestamp: str, regionname: str, servicename: str
    ) -> bytes:
        """Get the signed signature key
        :return: The signed key
        """
        key_date = self.__sign(("AWS4" + key).encode("utf-8"), datestamp)
        key_region = self.__sign(key_date, regionname)
        key_service = self.__sign(key_region, servicename)
        key_signing = self.__sign(key_service, "aws4_request")
        return key_signing

    @staticmethod
    def __sign(key: bytes, msg: str) -> bytes:
        """ Sign the msg with the key """
        return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()

    def prepare_request(
        self, method: str = "GET", payload: str = "", querystring: Dict = {}
    ) -> RequestParameters:
        """
        This prepares the request for sigv4 signing.  This is heavily influenced by the code here:
        https://github.com/awslabs/amazon-neptune-tools/tree/master/neptune-python-utils
        :param method: The method name
        :param payload: The request payload
        :param querystring: The request querystring
        :return: The request parameters
        """
        session = boto3.Session()
        credentials = session.get_credentials()
        access_key = credentials.access_key
        secret_key = credentials.secret_key
        session_token = credentials.token

        service = "neptune-db"
        algorithm = "AWS4-HMAC-SHA256"

        request_parameters = parse.urlencode(querystring).replace("%27", "%22")
        canonical_querystring = self.__normalize_query_string(request_parameters)

        t = datetime.utcnow()
        amz_date = t.strftime("%Y%m%dT%H%M%SZ")
        datestamp = t.strftime("%Y%m%d")
        canonical_headers = "host:{}:{}\nx-amz-date:{}\n".format(
            self._neptune_endpoint.host, self._neptune_endpoint.port, amz_date
        )
        signed_headers = "host;x-amz-date"
        payload_hash = hashlib.sha256(payload.encode("utf-8")).hexdigest()
        canonical_request = "{}\n/{}\n{}\n{}\n{}\n{}".format(
            method,
            "gremlin",
            canonical_querystring,
            canonical_headers,
            signed_headers,
            payload_hash,
        )
        credential_scope = "{}/{}/{}/aws4_request".format(
            datestamp, self._neptune_endpoint.region, service
        )
        string_to_sign = "{}\n{}\n{}\n{}".format(
            algorithm,
            amz_date,
            credential_scope,
            hashlib.sha256(canonical_request.encode("utf-8")).hexdigest(),
        )
        signing_key = self.__get_signature_key(
            secret_key, datestamp, self._neptune_endpoint.region, service
        )
        signature = hmac.new(
            signing_key, string_to_sign.encode("utf-8"), hashlib.sha256
        ).hexdigest()
        authorization_header = "{} Credential={}/{}, SignedHeaders={}, Signature={}".format(
            algorithm, access_key, credential_scope, signed_headers, signature
        )
        headers = {"x-amz-date": amz_date, "Authorization": authorization_header}
        if session_token:
            headers["x-amz-security-token"] = session_token
        return RequestParameters(
            "{}?{}".format(
                self._neptune_endpoint.get_gremlin_endpoint(self._neptune_endpoint.ssl),
                canonical_querystring,
            )
            if canonical_querystring
            else self._neptune_endpoint.get_gremlin_endpoint(),
            canonical_querystring,
            headers,
        )

    def connect_to_gremlin(self) -> Tuple[traversal, DriverRemoteConnection]:
        """
        Get the Gremlin traversal and connection for the Neptune endpoint
        :return: The Traversal object
        """
        if self._neptune_endpoint.auth_mode.lower() == "default":
            gremlin_connection = DriverRemoteConnection(
                self._neptune_endpoint.get_gremlin_endpoint(self._neptune_endpoint.ssl), "g"
            )
        else:
            request_parameters = self.prepare_request()
            signed_ws_request = httpclient.HTTPRequest(
                request_parameters.uri, headers=request_parameters.headers
            )
            gremlin_connection = DriverRemoteConnection(signed_ws_request, "g")
        graph_traversal_source = traversal().withRemote(gremlin_connection)

        return graph_traversal_source, gremlin_connection

    def __write_vertices(self, g: traversal, vertices: List[Dict], scan_id: str) -> None:
        """
        Writes the vertices to the labeled property graph
        :param g: The graph traversal source
        :param vertices: A list of dictionaries for each vertex
        :param scan_id: The unique string representing the scan
        :return: None
        """
        cnt = 0
        t = g
        for r in vertices:
            vertex_id = f'{r["~id"]}_{scan_id}'
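            # Upsert the vertex: reuse it if it already exists, otherwise add it
            # labeled with the resource portion of its ARN.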
            t = (
                t.V(vertex_id)
                .fold()
                .coalesce(
                    __.unfold(),
                    __.addV(self.parse_arn(r["~label"])["resource"]).property(T.id, vertex_id),
                )
            )
            for k in r.keys():
                # Numbers outside the range of a Java Long (64-bit signed) cannot be
                # stored as-is, so stringify them.
                if isinstance(r[k], int) and (
                    r[k] > 9223372036854775807 or r[k] < -9223372036854775807
                ):
                    r[k] = str(r[k])
                if k not in ["~id", "~label"]:
                    t = t.property(k, r[k])
            cnt += 1
            if cnt % 100 == 0 or cnt == len(vertices):
                try:
                    self.logger.info(
                        event=LogEvent.NeptunePeriodicWrite,
                        msg=f"Writing vertices {cnt} of {len(vertices)}",
                    )
                    t.next()
                    t = g
                except Exception as err:
                    self.logger.error(event=LogEvent.NeptuneLoadError, msg=str(err))
                    raise NeptuneLoadGraphException(
                        f"Error loading vertex {r} " f"with {str(t.bytecode)}"
                    ) from err

    def __write_edges(self, g: traversal, edges: List[Dict], scan_id: str) -> None:
        """
        Writes the edges to the labeled property graph
        :param g: The graph traversal source
        :param edges: A list of dictionaries for each edge
        :param scan_id: The unique string representing the scan
        :return: None
        """
        cnt = 0
        t = g
        for r in edges:
            to_id = f'{r["~to"]}_{scan_id}'
            from_id = f'{r["~from"]}_{scan_id}'
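            # Add the edge, upserting placeholder from/to vertices if they do not exist yet.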
            t = (
                t.addE(r["~label"])
                .property(T.id, str(r["~id"]))
                .from_(
                    __.V(from_id)
                    .fold()
                    .coalesce(
                        __.unfold(),
                        __.addV(self.parse_arn(r["~from"])["resource"])
                        .property(T.id, from_id)
                        .property("scan_id", scan_id)
                        .property("arn", r["~from"]),
                    )
                )
                .to(
                    __.V(to_id)
                    .fold()
                    .coalesce(
                        __.unfold(),
                        __.addV(self.parse_arn(r["~to"])["resource"])
                        .property(T.id, to_id)
                        .property("scan_id", scan_id)
                        .property("arn", r["~to"]),
                    )
                )
            )
            cnt += 1
            if cnt % 100 == 0 or cnt == len(edges):
                try:
                    self.logger.info(
                        event=LogEvent.NeptunePeriodicWrite,
                        msg=f"Writing edges {cnt} of {len(edges)}",
                    )
                    t.next()
                    t = g
                except Exception as err:
                    self.logger.error(event=LogEvent.NeptuneLoadError, msg=str(err))
                    raise NeptuneLoadGraphException(
                        f"Error loading edge {r} " f"with {str(t.bytecode)}"
                    ) from err

    @staticmethod
    def parse_arn(arn: str) -> Dict:
        """
        Parses an ARN into the component pieces
        :param arn: The arn to parse
        :return: A dictionary of the arn pieces
        """
        # http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html
        elements = str(arn).split(":", 5)
        result = {}
        if len(elements) == 6:
            result = {
                "arn": elements[0],
                "partition": elements[1],
                "service": elements[2],
                "region": elements[3],
                "account": elements[4],
                "resource": elements[5],
                "resource_type": None,
            }
        else:
            result["resource"] = str(arn)
        if "/" in str(result["resource"]):
            result["resource_type"], result["resource"] = str(result["resource"]).split("/", 1)
        elif ":" in str(result["resource"]):
            result["resource_type"], result["resource"] = str(result["resource"]).split(":", 1)

        if str(result["resource"]).startswith("ami-"):
            result["resource"] = result["resource_type"]

        return result

    def write_to_neptune_lpg(self, graph: Dict, scan_id: str) -> None:
        """
        Writes the graph to a labeled property graph
        :param scan_id: The unique string representing the scan
        :param graph: The graph to write
        :return: None
        """
        if "vertices" in graph and "edges" in graph and len(graph["vertices"]) > 0:
            g, conn = self.connect_to_gremlin()
            self.__write_vertices(g, graph["vertices"], scan_id)
            self.__write_edges(g, graph["edges"], scan_id)
            conn.close()
        else:
            raise NeptuneNoGraphsFoundException

    def write_to_neptune_rdf(self, graph: Dict) -> None:
        """
        Writes the graph to an RDF graph
        :param graph: The graph to write
        :return: None
        """

        auth = self._get_auth()
        neptune_sparql_url = self._neptune_endpoint.get_sparql_endpoint()
        triples = ""
        for subject, predicate, obj in graph:
            triples = triples + subject.n3() + " " + predicate.n3() + " " + obj.n3() + " . \n"

        insert_stmt = (
            "INSERT DATA {\n" + f"    GRAPH <{META_GRAPH_NAME}>\n" "{\n" f"{triples}" "}\n" "}\n"
        )

        resp = requests.post(neptune_sparql_url, data={"update": insert_stmt}, auth=auth)
        if resp.status_code != 200:
            raise NeptuneUpdateGraphException(
                f"Error updating graph {META_GRAPH_NAME} " f"with {insert_stmt} : {resp.text}"
            )
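
The ARN parsing above determines the vertex labels used in the upserts. A quick sanity-check sketch (the enclosing class is assumed to be AltimeterNeptuneClient, as in the pruner examples below):

# Hypothetical check; the class name is assumed from the later examples.
parsed = AltimeterNeptuneClient.parse_arn(
    "arn:aws:ec2:us-east-1:123456789012:instance/i-0abc1234"
)
assert parsed["service"] == "ec2"
assert parsed["region"] == "us-east-1"
assert parsed["resource_type"] == "instance"
assert parsed["resource"] == "i-0abc1234"  # becomes the vertex label in __write_vertices
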
Example #9
    def load_graph(self, bucket: str, key: str,
                   load_iam_role_arn: str) -> GraphMetadata:
        """Load a graph into Neptune.
        Args:
             bucket: s3 bucket of graph rdf
             key: s3 key of graph rdf
             load_iam_role_arn: arn of iam role used to load the graph

        Returns:
            GraphMetadata object describing loaded graph

        Raises:
            NeptuneLoadGraphException if errors occur during graph load
        """
        session = boto3.Session(region_name=self._neptune_endpoint.region)
        s3_client = session.client("s3")
        rdf_object_tagging = s3_client.get_object_tagging(Bucket=bucket,
                                                          Key=key)
        tag_set = rdf_object_tagging["TagSet"]
        graph_name = get_required_tag_value(tag_set, "name")
        graph_version = get_required_tag_value(tag_set, "version")
        graph_start_time = int(get_required_tag_value(tag_set, "start_time"))
        graph_end_time = int(get_required_tag_value(tag_set, "end_time"))
        graph_metadata = GraphMetadata(
            uri=f"{GRAPH_BASE_URI}/{graph_name}/{graph_version}/{graph_end_time}",
            name=graph_name,
            version=graph_version,
            start_time=graph_start_time,
            end_time=graph_end_time,
        )
        logger = Logger()
        with logger.bind(
                rdf_bucket=bucket,
                rdf_key=key,
                graph_uri=graph_metadata.uri,
                neptune_endpoint=self._neptune_endpoint.get_endpoint_str(),
        ):
            session = boto3.Session(region_name=self._neptune_endpoint.region)
            credentials = session.get_credentials()
            auth = AWSRequestsAuth(
                aws_access_key=credentials.access_key,
                aws_secret_access_key=credentials.secret_key,
                aws_token=credentials.token,
                aws_host=self._neptune_endpoint.get_endpoint_str(),
                aws_region=self._neptune_endpoint.region,
                aws_service="neptune-db",
            )
            post_body = {
                "source": f"s3://{bucket}/{key}",
                "format": "rdfxml",
                "iamRoleArn": load_iam_role_arn,
                "region": self._neptune_endpoint.region,
                "failOnError": "TRUE",
                "parallelism": "MEDIUM",
                "parserConfiguration": {
                    "baseUri": GRAPH_BASE_URI,
                    "namedGraphUri": graph_metadata.uri,
                },
            }
            logger.info(event=LogEvent.NeptuneLoadStart, post_body=post_body)
            submit_resp = requests.post(
                self._neptune_endpoint.get_loader_endpoint(),
                json=post_body,
                auth=auth)
            if submit_resp.status_code != 200:
                raise NeptuneLoadGraphException(
                    f"Non 200 from Neptune: {submit_resp.status_code} : {submit_resp.text}"
                )
            submit_resp_json = submit_resp.json()
            load_id = submit_resp_json["payload"]["loadId"]
            with logger.bind(load_id=load_id):
                logger.info(event=LogEvent.NeptuneLoadPolling)
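                # Poll the loader status endpoint every 10 seconds until the load
                # completes or enters a failure state.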
                while True:
                    time.sleep(10)
                    status_resp = requests.get(
                        f"{self._neptune_endpoint.get_loader_endpoint()}/{load_id}",
                        params={
                            "details": "true",
                            "errors": "true"
                        },
                        auth=auth,
                    )
                    if status_resp.status_code != 200:
                        raise NeptuneLoadGraphException(
                            f"Non 200 from Neptune: {status_resp.status_code} : {status_resp.text}"
                        )
                    status_resp_json = status_resp.json()
                    status = status_resp_json["payload"]["overallStatus"]["status"]
                    logger.info(event=LogEvent.NeptuneLoadPolling,
                                status=status)
                    if status == "LOAD_COMPLETED":
                        break
                    if status not in ("LOAD_NOT_STARTED", "LOAD_IN_PROGRESS"):
                        logger.error(event=LogEvent.NeptuneLoadError,
                                     status=status)
                        raise NeptuneLoadGraphException(
                            f"Error loading graph: {status_resp_json}")
                logger.info(event=LogEvent.NeptuneLoadEnd)

                logger.info(event=LogEvent.MetadataGraphUpdateStart)
                self._register_graph(graph_metadata=graph_metadata)
                logger.info(event=LogEvent.MetadataGraphUpdateEnd)

                return graph_metadata
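
A minimal invocation sketch for load_graph. The endpoint and client construction mirror the pruner examples below; the bucket, key and IAM role ARN are placeholders, and the S3 object is expected to carry the name/version/start_time/end_time tags read above:

# Hypothetical values; construction mirrors the later examples.
endpoint = NeptuneEndpoint(host="my-neptune-host", port=8182, region="us-east-1")
client = AltimeterNeptuneClient(max_age_min=1440, neptune_endpoint=endpoint)
graph_metadata = client.load_graph(
    bucket="my-altimeter-bucket",
    key="graphs/master/master.rdf",
    load_iam_role_arn="arn:aws:iam::123456789012:role/NeptuneLoaderRole",
)
print(graph_metadata.uri)
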
Example #10
def prune_graph(graph_pruner_config: GraphPrunerConfig) -> GraphPrunerResults:
    config = Config.from_path(path=graph_pruner_config.config_path)
    if config.neptune is None:
        raise InvalidConfigException("Configuration missing neptune section.")
    now = int(datetime.now().timestamp())
    oldest_acceptable_graph_epoch = now - config.pruner_max_age_min * 60
    endpoint = NeptuneEndpoint(
        host=config.neptune.host, port=config.neptune.port, region=config.neptune.region
    )
    client = AltimeterNeptuneClient(
        max_age_min=config.pruner_max_age_min, neptune_endpoint=endpoint
    )
    logger = Logger()
    uncleared = []
    pruned_graph_uris = []
    skipped_graph_uris = []
    logger.info(event=LogEvent.PruneNeptuneGraphsStart)
    all_graph_metadatas = client.get_graph_metadatas(name=config.graph_name)
    with logger.bind(neptune_endpoint=str(endpoint)):
        for graph_metadata in all_graph_metadatas:
            assert graph_metadata.name == config.graph_name
            graph_epoch = graph_metadata.end_time
            with logger.bind(graph_uri=graph_metadata.uri, graph_epoch=graph_epoch):
                if graph_epoch < oldest_acceptable_graph_epoch:
                    logger.info(event=LogEvent.PruneNeptuneGraphStart)
                    try:
                        client.clear_registered_graph(
                            name=config.graph_name, uri=graph_metadata.uri
                        )
                        logger.info(event=LogEvent.PruneNeptuneGraphEnd)
                        pruned_graph_uris.append(graph_metadata.uri)
                    except Exception as ex:
                        logger.error(
                            event=LogEvent.PruneNeptuneGraphError,
                            msg=f"Error pruning graph {graph_metadata.uri}: {ex}",
                        )
                        uncleared.append(graph_metadata.uri)
                        continue
                else:
                    logger.info(event=LogEvent.PruneNeptuneGraphSkip)
                    skipped_graph_uris.append(graph_metadata.uri)
        # now find orphaned graphs - these are in neptune but have no metadata
        registered_graph_uris = [g_m.uri for g_m in all_graph_metadatas]
        all_graph_uris = client.get_graph_uris(name=config.graph_name)
        orphaned_graphs = set(all_graph_uris) - set(registered_graph_uris)
        if orphaned_graphs:
            for orphaned_graph_uri in orphaned_graphs:
                with logger.bind(graph_uri=orphaned_graph_uri):
                    logger.info(event=LogEvent.PruneOrphanedNeptuneGraphStart)
                    try:
                        client.clear_graph_data(uri=orphaned_graph_uri)
                        logger.info(event=LogEvent.PruneOrphanedNeptuneGraphEnd)
                        pruned_graph_uris.append(orphaned_graph_uri)
                    except Exception as ex:
                        logger.error(
                            event=LogEvent.PruneNeptuneGraphError,
                            msg=f"Error pruning graph {orphaned_graph_uri}: {ex}",
                        )
                        uncleared.append(orphaned_graph_uri)
                        continue
        logger.info(event=LogEvent.PruneNeptuneGraphsEnd)
        if uncleared:
            msg = f"Errors were found pruning {uncleared}."
            logger.error(event=LogEvent.PruneNeptuneGraphsError, msg=msg)
            raise Exception(msg)
    return GraphPrunerResults(
        pruned_graph_uris=pruned_graph_uris, skipped_graph_uris=skipped_graph_uris,
    )
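
A short driver sketch for prune_graph. The GraphPrunerConfig constructor argument is assumed from the config_path attribute used above, and the path is a placeholder:

# Hypothetical driver; config_path is assumed to be a constructor argument.
pruner_config = GraphPrunerConfig(config_path="/path/to/altimeter_config.toml")
results = prune_graph(graph_pruner_config=pruner_config)
print(f"Pruned: {results.pruned_graph_uris}")
print(f"Skipped: {results.skipped_graph_uris}")
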
Example #11
def lambda_handler(event: Dict[str, Any], context: Any) -> None:
    """Entrypoint"""
    root = logging.getLogger()
    if root.handlers:
        for handler in root.handlers:
            root.removeHandler(handler)

    config_path = get_required_str_env_var("CONFIG_PATH")
    config = Config.from_path(path=config_path)

    if config.neptune is None:
        raise InvalidConfigException("Configuration missing neptune section.")

    now = int(datetime.now().timestamp())
    oldest_acceptable_graph_epoch = now - config.pruner_max_age_min * 60

    endpoint = NeptuneEndpoint(host=config.neptune.host,
                               port=config.neptune.port,
                               region=config.neptune.region)
    client = AltimeterNeptuneClient(max_age_min=config.pruner_max_age_min,
                                    neptune_endpoint=endpoint)

    logger = Logger()

    uncleared = []

    logger.info(event=LogEvent.PruneNeptuneGraphsStart)
    all_graph_metadatas = client.get_graph_metadatas(name=config.graph_name)
    with logger.bind(neptune_endpoint=str(endpoint)):
        for graph_metadata in all_graph_metadatas:
            assert graph_metadata.name == config.graph_name
            graph_epoch = graph_metadata.end_time
            with logger.bind(graph_uri=graph_metadata.uri,
                             graph_epoch=graph_epoch):
                if graph_epoch < oldest_acceptable_graph_epoch:
                    logger.info(event=LogEvent.PruneNeptuneGraphStart)
                    try:
                        client.clear_registered_graph(name=config.graph_name,
                                                      uri=graph_metadata.uri)
                        logger.info(event=LogEvent.PruneNeptuneGraphEnd)
                    except Exception as ex:
                        logger.error(
                            event=LogEvent.PruneNeptuneGraphError,
                            msg=f"Error pruning graph {graph_metadata.uri}: {ex}",
                        )
                        uncleared.append(graph_metadata.uri)
                        continue
                else:
                    logger.info(event=LogEvent.PruneNeptuneGraphSkip)
        # now find orphaned graphs - these are in neptune but have no metadata
        registered_graph_uris = [g_m.uri for g_m in all_graph_metadatas]
        all_graph_uris = client.get_graph_uris(name=config.graph_name)
        orphaned_graphs = set(all_graph_uris) - set(registered_graph_uris)
        if orphaned_graphs:
            for orphaned_graph_uri in orphaned_graphs:
                with logger.bind(graph_uri=orphaned_graph_uri):
                    logger.info(event=LogEvent.PruneOrphanedNeptuneGraphStart)
                    try:
                        client.clear_graph_data(uri=orphaned_graph_uri)
                        logger.info(
                            event=LogEvent.PruneOrphanedNeptuneGraphEnd)
                    except Exception as ex:
                        logger.error(
                            event=LogEvent.PruneNeptuneGraphError,
                            msg=f"Error pruning graph {orphaned_graph_uri}: {ex}",
                        )
                        uncleared.append(orphaned_graph_uri)
                        continue
        logger.info(event=LogEvent.PruneNeptuneGraphsEnd)
        if uncleared:
            msg = f"Errors were found pruning {uncleared}."
            logger.error(event=LogEvent.PruneNeptuneGraphsError, msg=msg)
            raise Exception(msg)
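
Both pruner examples reduce the prune/skip decision to an epoch comparison against pruner_max_age_min; a standalone sketch of that check:

from datetime import datetime
from typing import Optional

# Hypothetical helper mirroring the age check used in the pruner examples above.
def is_prunable(graph_end_time: int, max_age_min: int, now: Optional[int] = None) -> bool:
    """Return True when a graph's end_time is older than the allowed age."""
    now_epoch = int(datetime.now().timestamp()) if now is None else now
    return graph_end_time < now_epoch - max_age_min * 60

# A graph that ended two hours ago is pruned under a 60 minute maximum age.
assert is_prunable(graph_end_time=1_700_000_000, max_age_min=60, now=1_700_007_200)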