def lambda_handler(cls, event: Dict[str, Any], _: Any) -> None:
    """Lambda entrypoint"""
    config = Config()
    result = Result(**event)
    logger = Logger()
    errors: List[str] = []
    with logger.bind(result=result):
        logger.info(event=QJLogEvents.ResultRemediationStart)
        try:
            session = get_assumed_session(
                account_id=result.account_id,
                role_name=config.remediator_target_role_name,
                external_id=config.remediator_target_role_external_id,
            )
            cls.remediate(session=session, result=result.result, dry_run=config.dry_run)
            logger.info(event=QJLogEvents.ResultRemediationSuccessful)
        except Exception as ex:
            logger.error(event=QJLogEvents.ResultRemediationFailed, error=str(ex))
            errors.append(str(ex))
        if errors:
            raise RemediationError(f"Errors found during remediation: {errors}")
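For reference, a hand-built event of the shape this handler appears to expect. The field names are assumptions inferred from `Result(**event)` and the `result.account_id` / `result.result` accesses above; the exact Result schema is not shown in this section.

# Illustrative event payload for the handler above; field names and values
# are assumptions, not a confirmed schema.
sample_event = {
    "account_id": "123456789012",
    "result": {"resource_id": "i-0abc1234def567890", "finding": "example-finding"},
}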
def scan_scan_unit(scan_unit: ScanUnit) -> Tuple[str, Dict[str, Any]]:
    logger = Logger()
    with logger.bind(
        account_id=scan_unit.account_id,
        region=scan_unit.region_name,
        service=scan_unit.service,
        resource_classes=sorted(
            [
                resource_spec_class.__name__
                for resource_spec_class in scan_unit.resource_spec_classes
            ]
        ),
    ):
        start_t = time.time()
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceStart)
        session = boto3.Session(
            aws_access_key_id=scan_unit.access_key,
            aws_secret_access_key=scan_unit.secret_key,
            aws_session_token=scan_unit.token,
            region_name=scan_unit.region_name,
        )
        scan_accessor = AWSAccessor(
            session=session,
            account_id=scan_unit.account_id,
            region_name=scan_unit.region_name,
        )
        graph_spec = GraphSpec(
            name=scan_unit.graph_name,
            version=scan_unit.graph_version,
            resource_spec_classes=scan_unit.resource_spec_classes,
            scan_accessor=scan_accessor,
        )
        start_time = int(time.time())
        resources: List[Resource] = []
        errors = []
        try:
            resources = graph_spec.scan()
        except Exception as ex:
            error_str = str(ex)
            trace_back = traceback.format_exc()
            logger.error(
                event=AWSLogEvents.ScanAWSAccountError,
                error=error_str,
                trace_back=trace_back,
            )
            errors.append(f"{error_str}\n{trace_back}")
        end_time = int(time.time())
        graph_set = GraphSet(
            name=scan_unit.graph_name,
            version=scan_unit.graph_version,
            start_time=start_time,
            end_time=end_time,
            resources=resources,
            errors=errors,
            stats=scan_accessor.api_call_stats,
        )
        end_t = time.time()
        elapsed_sec = end_t - start_t
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceEnd, elapsed_sec=elapsed_sec)
        return (scan_unit.account_id, graph_set.to_dict())
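The `schedule_scan` / `schedule_scan_services` helpers used by the scanners below are not shown in this section. A minimal sketch of how such a scheduler could wrap `scan_scan_unit`, assuming a ScanUnit constructed from keyword arguments (the constructor shape is an assumption inferred from the fields read above):

from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any


def schedule_scan_sketch(executor: ThreadPoolExecutor, **scan_unit_kwargs: Any) -> Future:
    # Build a ScanUnit from the caller's kwargs and submit the unit scan to
    # the shared executor; callers collect the returned Futures via
    # as_completed, as in the scan() methods below.
    scan_unit = ScanUnit(**scan_unit_kwargs)
    return executor.submit(scan_scan_unit, scan_unit)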
def lambda_handler(event, context):
    host = get_required_lambda_env_var("NEPTUNE_HOST")
    port = get_required_lambda_env_var("NEPTUNE_PORT")
    region = get_required_lambda_env_var("NEPTUNE_REGION")
    max_age_min = get_required_lambda_env_var("MAX_AGE_MIN")
    graph_name = get_required_lambda_env_var("GRAPH_NAME")
    try:
        max_age_min = int(max_age_min)
    except ValueError as ve:
        raise Exception(f"env var MAX_AGE_MIN must be an int: {ve}")
    now = int(datetime.now().timestamp())
    oldest_acceptable_graph_epoch = now - max_age_min * 60
    endpoint = NeptuneEndpoint(host=host, port=port, region=region)
    client = AltimeterNeptuneClient(max_age_min=max_age_min, neptune_endpoint=endpoint)
    logger = Logger()
    uncleared = []

    # first prune metadata - if the clears below are partial we want to make
    # sure no clients still consider this a valid graph.
    logger.info(event=LogEvent.PruneNeptuneMetadataGraphStart)
    client.clear_old_graph_metadata(name=graph_name, max_age_min=max_age_min)
    logger.info(event=LogEvent.PruneNeptuneMetadataGraphEnd)
    # now clear actual graphs
    with logger.bind(neptune_endpoint=endpoint):
        logger.info(event=LogEvent.PruneNeptuneGraphsStart)
        for graph_metadata in client.get_graph_metadatas(name=graph_name):
            assert graph_metadata.name == graph_name
            graph_epoch = graph_metadata.end_time
            with logger.bind(graph_uri=graph_metadata.uri, graph_epoch=graph_epoch):
                if graph_epoch < oldest_acceptable_graph_epoch:
                    logger.info(event=LogEvent.PruneNeptuneGraphStart)
                    try:
                        client.clear_graph(graph_uri=graph_metadata.uri)
                        logger.info(event=LogEvent.PruneNeptuneGraphEnd)
                    except Exception as ex:
                        logger.error(
                            event=LogEvent.PruneNeptuneGraphError,
                            msg=f"Error pruning graph {graph_metadata.uri}: {ex}",
                        )
                        uncleared.append(graph_metadata.uri)
                        continue
                else:
                    logger.info(event=LogEvent.PruneNeptuneGraphSkip)
        logger.info(event=LogEvent.PruneNeptuneGraphsEnd)
    if uncleared:
        msg = f"Errors were found pruning {uncleared}."
        logger.error(event=LogEvent.PruneNeptuneGraphsError, msg=msg)
        raise Exception(msg)
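For local testing, the environment this handler expects can be faked before invoking it. The variable names come straight from the `get_required_lambda_env_var` calls above; the values are illustrative only.

# Illustrative environment for exercising the pruner handler locally.
import os

os.environ.update(
    {
        "NEPTUNE_HOST": "my-neptune.cluster-abc.us-east-1.neptune.amazonaws.com",
        "NEPTUNE_PORT": "8182",
        "NEPTUNE_REGION": "us-east-1",
        "MAX_AGE_MIN": "1440",  # prune graphs older than 24 hours
        "GRAPH_NAME": "alti",
    }
)
lambda_handler(event={}, context=None)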
def remediator(event: Dict[str, Any]) -> None:
    """Run the remediation lambda for a QJ result set"""
    config = RemediatorConfig()
    logger = Logger()
    remediation = Remediation(**event)
    with logger.bind(remediation=remediation):
        logger.info(event=QJLogEvents.RemediationInit)
        qj_api_client = QJAPIClient(host=config.qj_api_host)
        latest_result_set = qj_api_client.get_job_latest_result_set(
            job_name=remediation.job_name
        )
        if not latest_result_set:
            msg = f"No latest_result_set present for {remediation.job_name}"
            logger.error(QJLogEvents.StaleResultSet, detail=msg)
            raise RemediationError(msg)
        if latest_result_set.result_set_id != remediation.result_set_id:
            msg = (
                f"Remediation result_set_id {remediation.result_set_id} does not match the "
                f"latest result_set_id {latest_result_set.result_set_id}"
            )
            logger.error(QJLogEvents.StaleResultSet, detail=msg)
            raise RemediationError(msg)
        if not latest_result_set.job.remediate_sqs_queue:
            msg = f"Job {latest_result_set.job.name} has no remediator"
            logger.error(QJLogEvents.JobHasNoRemediator, detail=msg)
            raise RemediationError(msg)
        num_threads = 10  # TODO env var
        errors = []
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = []
            for result in latest_result_set.results:
                logger.info(event=QJLogEvents.ProcessResult, result=result)
                future = _schedule_result_remediation(
                    executor=executor,
                    lambda_name=latest_result_set.job.remediate_sqs_queue,
                    lambda_timeout=300,  # TODO env var?
                    result=result,
                )
                futures.append(future)
            for future in as_completed(futures):
                try:
                    lambda_result = future.result()
                    logger.info(
                        QJLogEvents.ResultRemediationSuccessful,
                        lambda_result=lambda_result,
                    )
                except Exception as ex:
                    logger.info(
                        event=QJLogEvents.ResultRemediationFailed,
                        error=str(ex),
                    )
                    errors.append(str(ex))
        if errors:
            logger.error(event=QJLogEvents.ResultSetRemediationFailed, errors=errors)
            raise RemediationError(
                f"Errors encountered during remediation of {latest_result_set.job.name} "
                f"/ {latest_result_set.result_set_id}: {errors}"
            )
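`_schedule_result_remediation` is referenced above but not shown. A minimal sketch, assuming the remediation target is invoked as a Lambda function via boto3 (the `lambda_name`/`lambda_timeout` parameters suggest this); the payload shape and the `.dict()` call are assumptions, not the confirmed implementation.

import json
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any

import boto3
from botocore.config import Config as BotoConfig


def _schedule_result_remediation_sketch(
    executor: ThreadPoolExecutor, lambda_name: str, lambda_timeout: int, result: Any
) -> Future:
    def _invoke() -> Any:
        # One synchronous Lambda invoke per result; retries disabled so a
        # failure surfaces as a single error in the caller's error list.
        lambda_client = boto3.client(
            "lambda",
            config=BotoConfig(read_timeout=lambda_timeout, retries={"max_attempts": 0}),
        )
        resp = lambda_client.invoke(
            FunctionName=lambda_name,
            Payload=json.dumps(result.dict()).encode(),  # result schema assumed
        )
        return json.loads(resp["Payload"].read())

    return executor.submit(_invoke)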
def scan(self) -> Dict[str, Any]:
    """Scan an account and return a dict containing keys:

        * account_id: str
        * output_artifact: str
        * api_call_stats: Dict[str, Any]
        * errors: List[str]

    If errors is non-empty the results are incomplete for this account.
    output_artifact is a pointer to the actual scan data - either on the
    local fs or in s3.

    To scan an account we create a set of GraphSpecs, one for each region.
    Any ACCOUNT level granularity resources are only scanned in a single
    region (e.g. IAM Users).

    Returns:
        Dict of scan result, see above for details.
    """
    logger = Logger()
    with logger.bind(account_id=self.account_id):
        logger.info(event=AWSLogEvents.ScanAWSAccountStart)
        output_artifact = None
        stats = MultilevelCounter()
        errors: List[str] = []
        now = int(time.time())
        try:
            account_graph_set = GraphSet(
                name=self.graph_name,
                version=self.graph_version,
                start_time=now,
                end_time=now,
                resources=[],
                errors=[],
                stats=stats,
            )
            # sanity check
            session = self.get_session()
            sts_client = session.client("sts")
            sts_account_id = sts_client.get_caller_identity()["Account"]
            if sts_account_id != self.account_id:
                raise ValueError(
                    f"BUG: sts detected account_id {sts_account_id} != {self.account_id}"
                )
            if self.regions:
                scan_regions = tuple(self.regions)
            else:
                scan_regions = get_all_enabled_regions(session=session)
            # build graph specs.
            # build a dict of regions -> services -> List[AWSResourceSpec]
            regions_services_resource_spec_classes: DefaultDict[
                str, DefaultDict[str, List[Type[AWSResourceSpec]]]
            ] = defaultdict(lambda: defaultdict(list))
            resource_spec_class: Type[AWSResourceSpec]
            for resource_spec_class in self.resource_spec_classes:
                client_name = resource_spec_class.get_client_name()
                resource_class_scan_granularity = resource_spec_class.scan_granularity
                if resource_class_scan_granularity == ScanGranularity.ACCOUNT:
                    regions_services_resource_spec_classes[scan_regions[0]][
                        client_name
                    ].append(resource_spec_class)
                elif resource_class_scan_granularity == ScanGranularity.REGION:
                    for region in scan_regions:
                        regions_services_resource_spec_classes[region][
                            client_name
                        ].append(resource_spec_class)
                else:
                    raise NotImplementedError(
                        f"ScanGranularity {resource_class_scan_granularity} not implemented"
                    )
            with ThreadPoolExecutor(max_workers=self.max_svc_threads) as executor:
                futures = []
                for (
                    region,
                    services_resource_spec_classes,
                ) in regions_services_resource_spec_classes.items():
                    for (
                        service,
                        resource_spec_classes,
                    ) in services_resource_spec_classes.items():
                        region_session = self.get_session(region=region)
                        region_creds = region_session.get_credentials()
                        scan_future = schedule_scan_services(
                            executor=executor,
                            graph_name=self.graph_name,
                            graph_version=self.graph_version,
                            account_id=self.account_id,
                            region=region,
                            service=service,
                            access_key=region_creds.access_key,
                            secret_key=region_creds.secret_key,
                            token=region_creds.token,
                            resource_spec_classes=tuple(resource_spec_classes),
                        )
                        futures.append(scan_future)
                for future in as_completed(futures):
                    graph_set_dict = future.result()
                    graph_set = GraphSet.from_dict(graph_set_dict)
                    errors += graph_set.errors
                    account_graph_set.merge(graph_set)
            account_graph_set.validate()
        except Exception as ex:
            error_str = str(ex)
            trace_back = traceback.format_exc()
            logger.error(
                event=AWSLogEvents.ScanAWSAccountError,
                error=error_str,
                trace_back=trace_back,
            )
            errors.append(" : ".join((error_str, trace_back)))
            unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                account_id=self.account_id, errors=errors
            )
            account_graph_set = GraphSet(
                name=self.graph_name,
                version=self.graph_version,
                start_time=now,
                end_time=now,
                resources=[unscanned_account_resource],
                errors=errors,
                stats=stats,
            )
            account_graph_set.validate()
        output_artifact = self.artifact_writer.write_artifact(
            name=self.account_id, data=account_graph_set.to_dict()
        )
        logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
        api_call_stats = account_graph_set.stats.to_dict()
        return {
            "account_id": self.account_id,
            "output_artifact": output_artifact,
            "errors": errors,
            "api_call_stats": api_call_stats,
        }
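`get_all_enabled_regions`, used here and in the scanners below, is not shown in this section. A minimal sketch, assuming it asks EC2 which regions are enabled for the account (the filter values are standard EC2 opt-in statuses; the exact implementation is an assumption):

from typing import Tuple

import boto3


def get_all_enabled_regions_sketch(session: boto3.Session) -> Tuple[str, ...]:
    # Ask EC2 for regions, excluding those the account has not opted into.
    ec2_client = session.client("ec2", region_name="us-east-1")
    resp = ec2_client.describe_regions(
        Filters=[
            {"Name": "opt-in-status", "Values": ["opt-in-not-required", "opted-in"]}
        ]
    )
    return tuple(region["RegionName"] for region in resp["Regions"])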
def scan(self) -> List[Dict[str, Any]]:
    logger = Logger()
    scan_result_dicts = []
    now = int(time.time())
    prescan_account_ids_errors: DefaultDict[str, List[str]] = defaultdict(list)
    futures = []
    with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
        shuffled_account_ids = random.sample(
            self.account_scan_plan.account_ids,
            k=len(self.account_scan_plan.account_ids),
        )
        for account_id in shuffled_account_ids:
            with logger.bind(account_id=account_id):
                logger.info(event=AWSLogEvents.ScanAWSAccountStart)
                try:
                    session = self.account_scan_plan.accessor.get_session(
                        account_id=account_id
                    )
                    # sanity check
                    sts_client = session.client("sts")
                    sts_account_id = sts_client.get_caller_identity()["Account"]
                    if sts_account_id != account_id:
                        raise ValueError(
                            f"BUG: sts detected account_id {sts_account_id} != {account_id}"
                        )
                    if self.account_scan_plan.regions:
                        scan_regions = tuple(self.account_scan_plan.regions)
                    else:
                        scan_regions = get_all_enabled_regions(session=session)
                    account_gran_scan_region = random.choice(
                        self.preferred_account_scan_regions
                    )
                    # build a dict of regions -> services -> List[AWSResourceSpec]
                    regions_services_resource_spec_classes: DefaultDict[
                        str, DefaultDict[str, List[Type[AWSResourceSpec]]]
                    ] = defaultdict(lambda: defaultdict(list))
                    resource_spec_class: Type[AWSResourceSpec]
                    for resource_spec_class in self.resource_spec_classes:
                        client_name = resource_spec_class.get_client_name()
                        if resource_spec_class.scan_granularity == ScanGranularity.ACCOUNT:
                            if resource_spec_class.region_whitelist:
                                account_resource_scan_region = (
                                    resource_spec_class.region_whitelist[0]
                                )
                            else:
                                account_resource_scan_region = account_gran_scan_region
                            regions_services_resource_spec_classes[
                                account_resource_scan_region
                            ][client_name].append(resource_spec_class)
                        elif resource_spec_class.scan_granularity == ScanGranularity.REGION:
                            if resource_spec_class.region_whitelist:
                                resource_scan_regions = tuple(
                                    region
                                    for region in scan_regions
                                    if region in resource_spec_class.region_whitelist
                                )
                                if not resource_scan_regions:
                                    resource_scan_regions = (
                                        resource_spec_class.region_whitelist
                                    )
                            else:
                                resource_scan_regions = scan_regions
                            for region in resource_scan_regions:
                                regions_services_resource_spec_classes[region][
                                    client_name
                                ].append(resource_spec_class)
                        else:
                            raise NotImplementedError(
                                f"ScanGranularity {resource_spec_class.scan_granularity} unimplemented"
                            )
                    # Build and submit ScanUnits
                    shuffled_regions_services_resource_spec_classes = random.sample(
                        list(regions_services_resource_spec_classes.items()),
                        len(regions_services_resource_spec_classes),
                    )
                    for (
                        region,
                        services_resource_spec_classes,
                    ) in shuffled_regions_services_resource_spec_classes:
                        region_session = self.account_scan_plan.accessor.get_session(
                            account_id=account_id, region_name=region
                        )
                        region_creds = region_session.get_credentials()
                        shuffled_services_resource_spec_classes = random.sample(
                            list(services_resource_spec_classes.items()),
                            len(services_resource_spec_classes),
                        )
                        for (
                            service,
                            svc_resource_spec_classes,
                        ) in shuffled_services_resource_spec_classes:
                            future = schedule_scan(
                                executor=executor,
                                graph_name=self.graph_name,
                                graph_version=self.graph_version,
                                account_id=account_id,
                                region_name=region,
                                service=service,
                                access_key=region_creds.access_key,
                                secret_key=region_creds.secret_key,
                                token=region_creds.token,
                                resource_spec_classes=tuple(svc_resource_spec_classes),
                            )
                            futures.append(future)
                except Exception as ex:
                    error_str = str(ex)
                    trace_back = traceback.format_exc()
                    logger.error(
                        event=AWSLogEvents.ScanAWSAccountError,
                        error=error_str,
                        trace_back=trace_back,
                    )
                    prescan_account_ids_errors[account_id].append(
                        f"{error_str}\n{trace_back}"
                    )
    account_ids_graph_set_dicts: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for future in as_completed(futures):
        account_id, graph_set_dict = future.result()
        account_ids_graph_set_dicts[account_id].append(graph_set_dict)
    # first make sure no account id appears both in account_ids_graph_set_dicts
    # and prescan_account_ids_errors - this should never happen
    doubled_accounts = set(account_ids_graph_set_dicts.keys()).intersection(
        set(prescan_account_ids_errors.keys())
    )
    if doubled_accounts:
        raise Exception(
            f"BUG: Account(s) {doubled_accounts} in both "
            "account_ids_graph_set_dicts and prescan_account_ids_errors."
        )
    # graph prescan error accounts
    for account_id, errors in prescan_account_ids_errors.items():
        with logger.bind(account_id=account_id):
            unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                account_id=account_id, errors=errors
            )
            account_graph_set = GraphSet(
                name=self.graph_name,
                version=self.graph_version,
                start_time=now,
                end_time=now,
                resources=[unscanned_account_resource],
                errors=errors,
                stats=MultilevelCounter(),
            )
            account_graph_set.validate()
            output_artifact = self.artifact_writer.write_json(
                name=account_id, data=account_graph_set.to_dict()
            )
            logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
            api_call_stats = account_graph_set.stats.to_dict()
            scan_result_dicts.append(
                {
                    "account_id": account_id,
                    "output_artifact": output_artifact,
                    "errors": errors,
                    "api_call_stats": api_call_stats,
                }
            )
    # graph the rest
    for account_id, graph_set_dicts in account_ids_graph_set_dicts.items():
        with logger.bind(account_id=account_id):
            # if there are any errors whatsoever we generate an empty graph
            # with errors only
            errors = []
            for graph_set_dict in graph_set_dicts:
                errors += graph_set_dict["errors"]
            if errors:
                unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                    account_id=account_id, errors=errors
                )
                account_graph_set = GraphSet(
                    name=self.graph_name,
                    version=self.graph_version,
                    start_time=now,
                    end_time=now,
                    resources=[unscanned_account_resource],
                    errors=errors,
                    # ENHANCEMENT: could technically get partial stats.
                    stats=MultilevelCounter(),
                )
                account_graph_set.validate()
            else:
                account_graph_set = GraphSet(
                    name=self.graph_name,
                    version=self.graph_version,
                    start_time=now,
                    end_time=now,
                    resources=[],
                    errors=[],
                    stats=MultilevelCounter(),
                )
                for graph_set_dict in graph_set_dicts:
                    graph_set = GraphSet.from_dict(graph_set_dict)
                    account_graph_set.merge(graph_set)
            output_artifact = self.artifact_writer.write_json(
                name=account_id, data=account_graph_set.to_dict()
            )
            logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
            api_call_stats = account_graph_set.stats.to_dict()
            scan_result_dicts.append(
                {
                    "account_id": account_id,
                    "output_artifact": output_artifact,
                    "errors": errors,
                    "api_call_stats": api_call_stats,
                }
            )
    return scan_result_dicts
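`artifact_writer.write_json` above returns a pointer to where the scan data landed. A minimal local-filesystem sketch of such a writer; the real ArtifactWriter also targets S3, and this class name and layout are assumptions:

import json
import os
from typing import Any, Dict


class FileArtifactWriterSketch:
    def __init__(self, output_dir: str) -> None:
        self.output_dir = output_dir

    def write_json(self, name: str, data: Dict[str, Any]) -> str:
        # Persist the graph set dict as <output_dir>/<name>.json and return
        # the path as the output_artifact pointer.
        os.makedirs(self.output_dir, exist_ok=True)
        artifact_path = os.path.join(self.output_dir, f"{name}.json")
        with open(artifact_path, "w", encoding="utf-8") as fp:
            json.dump(data, fp, default=str)
        return artifact_path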
def scan(self) -> AccountScanResult:
    logger = Logger()
    now = int(time.time())
    prescan_errors: List[str] = []
    futures: List[Future] = []
    account_id = self.account_scan_plan.account_id
    with logger.bind(account_id=account_id):
        with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
            logger.info(event=AWSLogEvents.ScanAWSAccountStart)
            try:
                session = self.account_scan_plan.accessor.get_session(
                    account_id=account_id
                )
                # sanity check
                sts_client = session.client("sts")
                sts_account_id = sts_client.get_caller_identity()["Account"]
                if sts_account_id != account_id:
                    raise ValueError(
                        f"BUG: sts detected account_id {sts_account_id} != {account_id}"
                    )
                if self.account_scan_plan.regions:
                    account_scan_regions = tuple(self.account_scan_plan.regions)
                else:
                    account_scan_regions = get_all_enabled_regions(session=session)
                # build a dict of regions -> services -> List[AWSResourceSpec]
                regions_services_resource_spec_classes: DefaultDict[
                    str, DefaultDict[str, List[Type[AWSResourceSpec]]]
                ] = defaultdict(lambda: defaultdict(list))
                for resource_spec_class in self.resource_spec_classes:
                    resource_regions = (
                        self.account_scan_plan.aws_resource_region_mapping_repo.get_regions(
                            resource_spec_class=resource_spec_class,
                            region_whitelist=account_scan_regions,
                        )
                    )
                    for region in resource_regions:
                        regions_services_resource_spec_classes[region][
                            resource_spec_class.service_name
                        ].append(resource_spec_class)
                # Build and submit ScanUnits
                shuffled_regions_services_resource_spec_classes = random.sample(
                    list(regions_services_resource_spec_classes.items()),
                    len(regions_services_resource_spec_classes),
                )
                for (
                    region,
                    services_resource_spec_classes,
                ) in shuffled_regions_services_resource_spec_classes:
                    region_session = self.account_scan_plan.accessor.get_session(
                        account_id=account_id, region_name=region
                    )
                    region_creds = region_session.get_credentials()
                    shuffled_services_resource_spec_classes = random.sample(
                        list(services_resource_spec_classes.items()),
                        len(services_resource_spec_classes),
                    )
                    for (
                        service,
                        svc_resource_spec_classes,
                    ) in shuffled_services_resource_spec_classes:
                        # parallel-scan-capable specs get one ScanUnit each;
                        # the remainder are scanned serially in a single unit
                        parallel_svc_resource_spec_classes = [
                            svc_resource_spec_class
                            for svc_resource_spec_class in svc_resource_spec_classes
                            if svc_resource_spec_class.parallel_scan
                        ]
                        serial_svc_resource_spec_classes = [
                            svc_resource_spec_class
                            for svc_resource_spec_class in svc_resource_spec_classes
                            if not svc_resource_spec_class.parallel_scan
                        ]
                        for (
                            parallel_svc_resource_spec_class
                        ) in parallel_svc_resource_spec_classes:
                            parallel_future = schedule_scan(
                                executor=executor,
                                graph_name=self.graph_name,
                                graph_version=self.graph_version,
                                account_id=account_id,
                                region_name=region,
                                service=service,
                                access_key=region_creds.access_key,
                                secret_key=region_creds.secret_key,
                                token=region_creds.token,
                                resource_spec_classes=(parallel_svc_resource_spec_class,),
                            )
                            futures.append(parallel_future)
                        serial_future = schedule_scan(
                            executor=executor,
                            graph_name=self.graph_name,
                            graph_version=self.graph_version,
                            account_id=account_id,
                            region_name=region,
                            service=service,
                            access_key=region_creds.access_key,
                            secret_key=region_creds.secret_key,
                            token=region_creds.token,
                            resource_spec_classes=tuple(serial_svc_resource_spec_classes),
                        )
                        futures.append(serial_future)
            except Exception as ex:
                error_str = str(ex)
                trace_back = traceback.format_exc()
                logger.error(
                    event=AWSLogEvents.ScanAWSAccountError,
                    error=error_str,
                    trace_back=trace_back,
                )
                prescan_errors.append(f"{error_str}\n{trace_back}")
        graph_sets: List[GraphSet] = []
        for future in as_completed(futures):
            graph_set = future.result()
            graph_sets.append(graph_set)
        # if there was a prescan error, graph it and return the result
        if prescan_errors:
            unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                account_id=account_id, errors=prescan_errors
            )
            account_graph_set = GraphSet(
                name=self.graph_name,
                version=self.graph_version,
                start_time=now,
                end_time=now,
                resources=[unscanned_account_resource],
                errors=prescan_errors,
            )
            output_artifact = self.artifact_writer.write_json(
                name=account_id,
                data=account_graph_set,
            )
            logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
            return AccountScanResult(
                account_id=account_id,
                artifacts=[output_artifact],
                errors=prescan_errors,
            )
        # if there are any errors whatsoever we generate an empty graph with
        # errors only
        errors = []
        for graph_set in graph_sets:
            errors += graph_set.errors
        if errors:
            unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                account_id=account_id, errors=errors
            )
            account_graph_set = GraphSet(
                name=self.graph_name,
                version=self.graph_version,
                start_time=now,
                end_time=now,
                resources=[unscanned_account_resource],
                errors=errors,
            )
        else:
            account_graph_set = GraphSet.from_graph_sets(graph_sets)
        output_artifact = self.artifact_writer.write_json(
            name=account_id,
            data=account_graph_set,
        )
        logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
        return AccountScanResult(
            account_id=account_id,
            artifacts=[output_artifact],
            errors=errors,
        )
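`AccountScanResult` is constructed above with `account_id`, `artifacts`, and `errors`. A minimal sketch of such a model, assuming pydantic (which the surrounding `Config`/`Result` models suggest); the real definition is not shown here:

from typing import List

from pydantic import BaseModel


class AccountScanResultSketch(BaseModel):
    # Mirrors the three fields scan() populates.
    account_id: str
    artifacts: List[str]
    errors: List[str]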
class AltimeterNeptuneClient:
    """Client to run SPARQL queries against a Neptune instance, using graph
    name conventions to determine the most recent graph.

    Args:
        max_age_min: maximum acceptable age in minutes of graphs. Only graphs
                     which are found that meet this criteria will be queried.
        neptune_endpoint: NeptuneEndpoint object for this client
    """

    def __init__(self, max_age_min: int, neptune_endpoint: NeptuneEndpoint):
        self._neptune_endpoint = neptune_endpoint
        self._max_age_min = max_age_min
        self._auth = None
        # initially set this to a time in the past such that _get_auth's logic
        # is simpler regarding first run.
        self._auth_expiration = datetime.now() - timedelta(hours=24)
        self.logger = Logger()

    def run_query(self, graph_names: Set[str], query: str) -> QueryResult:
        """Runs a SPARQL query against the latest available graphs given a
        list of graph names.

        Args:
            graph_names: list of graph names to query
            query: query string. This query string should not include any
                   'from' clause; the graph_names param will be used to inject
                   the correct graph uris by locating the latest acceptable
                   (based on `max_age_min`) graph.

        Returns:
            QueryResult object
        """
        graph_uris_load_times: Dict[str, int] = {}
        for graph_name in graph_names:
            graph_metadata = self._get_latest_graph_metadata(name=graph_name)
            graph_uris_load_times[graph_metadata.uri] = graph_metadata.end_time
        finalized_query = finalize_query(
            query, graph_uris=list(graph_uris_load_times.keys())
        )
        query_result_set = self.run_raw_query(finalized_query)
        return QueryResult(graph_uris_load_times, query_result_set)

    def load_graph(self, bucket: str, key: str, load_iam_role_arn: str) -> GraphMetadata:
        """Load a graph into Neptune.

        Args:
            bucket: s3 bucket of graph rdf
            key: s3 key of graph rdf
            load_iam_role_arn: arn of iam role used to load the graph

        Returns:
            GraphMetadata object describing loaded graph

        Raises:
            NeptuneLoadGraphException if errors occur during graph load
        """
        session = boto3.Session(region_name=self._neptune_endpoint.region)
        s3_client = session.client("s3")
        rdf_object_tagging = s3_client.get_object_tagging(Bucket=bucket, Key=key)
        tag_set = rdf_object_tagging["TagSet"]
        graph_name = get_required_tag_value(tag_set, "name")
        graph_version = get_required_tag_value(tag_set, "version")
        graph_start_time = int(get_required_tag_value(tag_set, "start_time"))
        graph_end_time = int(get_required_tag_value(tag_set, "end_time"))
        graph_metadata = GraphMetadata(
            uri=f"{GRAPH_BASE_URI}/{graph_name}/{graph_version}/{graph_end_time}",
            name=graph_name,
            version=graph_version,
            start_time=graph_start_time,
            end_time=graph_end_time,
        )
        logger = self.logger
        with logger.bind(
            rdf_bucket=bucket,
            rdf_key=key,
            graph_uri=graph_metadata.uri,
            neptune_endpoint=self._neptune_endpoint.get_endpoint_str(),
        ):
            session = boto3.Session(region_name=self._neptune_endpoint.region)
            credentials = session.get_credentials()
            auth = AWSRequestsAuth(
                aws_access_key=credentials.access_key,
                aws_secret_access_key=credentials.secret_key,
                aws_token=credentials.token,
                aws_host=self._neptune_endpoint.get_endpoint_str(),
                aws_region=self._neptune_endpoint.region,
                aws_service="neptune-db",
            )
            post_body = {
                "source": f"s3://{bucket}/{key}",
                "format": "rdfxml",
                "iamRoleArn": load_iam_role_arn,
                "region": self._neptune_endpoint.region,
                "failOnError": "TRUE",
                "parallelism": "MEDIUM",
                "parserConfiguration": {
                    "baseUri": GRAPH_BASE_URI,
                    "namedGraphUri": graph_metadata.uri,
                },
            }
            logger.info(event=LogEvent.NeptuneLoadStart, post_body=post_body)
            submit_resp = requests.post(
                self._neptune_endpoint.get_loader_endpoint(),
                json=post_body,
                auth=auth,
            )
            if submit_resp.status_code != 200:
                raise NeptuneLoadGraphException(
                    f"Non 200 from Neptune: {submit_resp.status_code} : {submit_resp.text}"
                )
            submit_resp_json = submit_resp.json()
            load_id = submit_resp_json["payload"]["loadId"]
            with logger.bind(load_id=load_id):
                logger.info(event=LogEvent.NeptuneLoadPolling)
                while True:
                    time.sleep(10)
                    status_resp = requests.get(
                        f"{self._neptune_endpoint.get_loader_endpoint()}/{load_id}",
                        params={"details": "true", "errors": "true"},
                        auth=auth,
                    )
                    if status_resp.status_code != 200:
                        raise NeptuneLoadGraphException(
                            f"Non 200 from Neptune: {status_resp.status_code} : {status_resp.text}"
                        )
                    status_resp_json = status_resp.json()
                    status = status_resp_json["payload"]["overallStatus"]["status"]
                    logger.info(event=LogEvent.NeptuneLoadPolling, status=status)
                    if status == "LOAD_COMPLETED":
                        break
                    if status not in ("LOAD_NOT_STARTED", "LOAD_IN_PROGRESS"):
                        logger.error(event=LogEvent.NeptuneLoadError, status=status)
                        raise NeptuneLoadGraphException(
                            f"Error loading graph: {status_resp_json}"
                        )
                logger.info(event=LogEvent.NeptuneLoadEnd)
                logger.info(event=LogEvent.MetadataGraphUpdateStart)
                self._register_graph(graph_metadata=graph_metadata)
                logger.info(event=LogEvent.MetadataGraphUpdateEnd)
        return graph_metadata

    def _get_auth(self) -> AWSRequestsAuth:
        """Generate an AWSRequestsAuth object using a boto session for the
        current/local account.

        Returns:
            AWSRequestsAuth object
        """
        if datetime.now() >= self._auth_expiration:
            session = boto3.Session()
            credentials = session.get_credentials()
            region = (
                session.region_name
                if self._neptune_endpoint.region is None
                else self._neptune_endpoint.region
            )
            auth = AWSRequestsAuth(
                aws_access_key=credentials.access_key,
                aws_secret_access_key=credentials.secret_key,
                aws_token=credentials.token,
                aws_host=f"{self._neptune_endpoint.host}:{self._neptune_endpoint.port}",
                aws_region=region,
                aws_service="neptune-db",
            )
            self._auth = auth
            self._auth_expiration = datetime.now() + timedelta(
                minutes=SESSION_LIFETIME_MINUTES
            )
        return self._auth

    def _register_graph(self, graph_metadata: GraphMetadata) -> None:
        """Registers a GraphMetadata object into the metadata graph.

        The meta graph keeps track of graph uris and metadata. Run this after
        a graph is completely loaded and then use _get_latest_graph_metadata
        to query this graph to find the latest graph.

        Args:
            graph_metadata: GraphMetadata to load into the metadata graph.

        Raises:
            NeptuneUpdateGraphException if an error occurred during metadata
            graph update
        """
        auth = self._get_auth()
        neptune_sparql_url = self._neptune_endpoint.get_sparql_endpoint()
        update_stmt = (
            f"INSERT DATA {{\n"
            f"  GRAPH <{META_GRAPH_NAME}>\n"
            f"    {{ <alti:graph:{graph_metadata.uri}> "
            f'       <alti:uri> "{graph_metadata.uri}" ;\n'
            f'       <alti:name> "{graph_metadata.name}" ;\n'
            f'       <alti:version> "{graph_metadata.version}" ;\n'
            f"       <alti:start_time> {graph_metadata.start_time} ;\n"
            f"       <alti:end_time> {graph_metadata.end_time} ;\n"
            f"}}\n"
            "}\n"
        )
        resp = requests.post(neptune_sparql_url, data={"update": update_stmt}, auth=auth)
        if resp.status_code != 200:
            raise NeptuneUpdateGraphException(
                f"Error updating graph {META_GRAPH_NAME} with {update_stmt} : {resp.text}"
            )

    def get_graph_uris(self, name: str) -> List[str]:
        """Return all graph uris regardless of whether they have corresponding
        metadata entries.

        Args:
            name: graph name

        Returns:
            list of graph uris
        """
        query = "SELECT ?graph_uri WHERE { GRAPH ?graph_uri { } }"
        results = self.run_raw_query(query=query)
        results_list = results.to_list()
        all_graph_uris = [result["graph_uri"] for result in results_list]
        graph_prefix = f"{GRAPH_BASE_URI}/{name}/"
        graph_uris = [uri for uri in all_graph_uris if uri.startswith(graph_prefix)]
        return graph_uris

    def get_graph_metadatas(self, name: str, version: Optional[str] = None) -> List[GraphMetadata]:
        """Return all graph metadatas for a given name/version. These
        represent fully loaded graphs in the Neptune database.

        Args:
            name: graph name
            version: graph version

        Returns:
            list of GraphMetadata objects for the given graph name/version
        """
        if version is None:
            get_graph_metadatas_query = (
                "SELECT ?uri ?name ?version ?start_time ?end_time\n"
                f"FROM <{META_GRAPH_NAME}>\n"
                f"WHERE {{ ?graph_metadata <alti:uri> ?uri ;\n"
                f'                         <alti:name> "{name}" ;\n'
                f"                         <alti:name> ?name ;\n"
                f"                         <alti:version> ?version ;\n"
                f"                         <alti:start_time> ?start_time ;\n"
                f"                         <alti:end_time> ?end_time }}\n"
                f"ORDER BY DESC(?end_time)\n"
            )
        else:
            get_graph_metadatas_query = (
                "SELECT ?uri ?name ?version ?start_time ?end_time\n"
                f"FROM <{META_GRAPH_NAME}>\n"
                f"WHERE {{ ?graph_metadata <alti:uri> ?uri ;\n"
                f'                         <alti:name> "{name}" ;\n'
                f"                         <alti:name> ?name ;\n"
                f'                         <alti:version> "{version}" ;\n'
                f"                         <alti:version> ?version ;\n"
                f"                         <alti:start_time> ?start_time ;\n"
                f"                         <alti:end_time> ?end_time }}\n"
                f"ORDER BY DESC(?end_time)\n"
            )
        results = self.run_raw_query(query=get_graph_metadatas_query)
        results_list = results.to_list()
        graph_metadatas: List[GraphMetadata] = []
        for result in results_list:
            graph_metadata = GraphMetadata(
                uri=result["uri"],
                name=result["name"],
                version=result["version"],
                start_time=int(result["start_time"]),
                end_time=int(result["end_time"]),
            )
            graph_metadatas.append(graph_metadata)
        return graph_metadatas

    def _get_latest_graph_metadata(self, name: str, version: Optional[str] = None) -> GraphMetadata:
        """Return a GraphMetadata object representing the most recently
        successfully loaded graph for a given name / version.

        Args:
            name: graph name
            version: graph version

        Returns:
            GraphMetadata of the latest graph for the given name/version

        Raises:
            NeptuneNoGraphsFoundException if no matching graphs were found
            NeptuneNoFreshGraphFoundException if no graphs could be found
            within max_age_min
        """
        if version is None:
            get_graph_metadatas_query = (
                f"SELECT ?uri ?version ?start_time ?end_time\n"
                f"FROM <{META_GRAPH_NAME}>\n"
                f"WHERE {{\n"
                f"    ?graph_metadata <alti:uri> ?uri ;\n"
                f'                    <alti:name> "{name}" ;\n'
                f"                    <alti:version> ?version ;\n"
                f"                    <alti:start_time> ?start_time ;\n"
                f"                    <alti:end_time> ?end_time }}\n"
                f"ORDER BY DESC(?version) DESC(?end_time)\n"
                f"LIMIT 1"
            )
        else:
            get_graph_metadatas_query = (
                f"SELECT ?uri ?version ?start_time ?end_time\n"
                f"FROM <{META_GRAPH_NAME}>\n"
                f"WHERE {{\n"
                f"    ?graph_metadata <alti:uri> ?uri ;\n"
                f'                    <alti:name> "{name}" ;\n'
                f'                    <alti:version> "{version}" ;\n'
                f"                    <alti:version> ?version ;\n"
                f"                    <alti:start_time> ?start_time ;\n"
                f"                    <alti:end_time> ?end_time }}\n"
                f"ORDER BY DESC(?end_time)\n"
                f"LIMIT 1"
            )
        results = self.run_raw_query(query=get_graph_metadatas_query)
        results_list = results.to_list()
        if not results_list:
            raise NeptuneNoGraphsFoundException(f"No graphs found for graph name '{name}'")
        if len(results_list) != 1:
            raise RuntimeError("Logic error - more than one graph returned.")
        result = results_list[0]
        latest_uri = result["uri"]
        latest_version = result["version"]
        latest_start_time = int(result["start_time"])
        latest_end_time = int(result["end_time"])
        now = int(datetime.now().timestamp())
        oldest_acceptable_graph_end_time = now - self._max_age_min * 60
        if latest_end_time < oldest_acceptable_graph_end_time:
            raise NeptuneNoFreshGraphFoundException(
                f"Could not find a graph named '{name}' younger "
                f"than {self._max_age_min} minutes old. "
                f"Found: {results_list}"
            )
        return GraphMetadata(
            uri=latest_uri,
            name=name,
            version=latest_version,
            start_time=latest_start_time,
            end_time=latest_end_time,
        )

    def clear_registered_graph(self, name: str, uri: str) -> None:
        """Remove data and metadata for a graph by uri.

        Args:
            name: graph name
            uri: graph uri

        Raises:
            NeptuneUpdateGraphException if an error occurred during clearing
        """
        # clear metadata first such that clients will not use this graph if
        # data clear fails
        self.clear_graph_metadata(name=name, uri=uri)
        # then clear data
        self.clear_graph_data(uri=uri)

    def clear_graph_metadata(self, name: str, uri: str) -> None:
        """Clear a graph metadata entry"""
        auth = self._get_auth()
        neptune_sparql_url = self._neptune_endpoint.get_sparql_endpoint()
        delete_stmt = (
            f"WITH <{META_GRAPH_NAME}>\n"
            f"DELETE\n"
            f'  {{ ?graph <alti:uri> "{uri}" ;\n'
            f'            <alti:name> "{name}" ;\n'
            f"            <alti:version> ?version ;\n"
            f"            <alti:start_time> ?start_time ;\n"
            f"            <alti:end_time> ?end_time }}\n"
            f"WHERE\n"
            f'  {{ ?graph <alti:uri> "{uri}" ;\n'
            f'            <alti:name> "{name}" ;\n'
            f"            <alti:version> ?version ;\n"
            f"            <alti:start_time> ?start_time ;\n"
            f"            <alti:end_time> ?end_time }}\n"
        )
        resp = requests.post(neptune_sparql_url, data={"update": delete_stmt}, auth=auth)
        if resp.status_code != 200:
            raise NeptuneUpdateGraphException(
                f"Error updating graph {META_GRAPH_NAME} with {delete_stmt} : {resp.text}"
            )

    def clear_graph_data(self, uri: str) -> None:
        """Clear a graph in Neptune"""
        auth = self._get_auth()
        neptune_sparql_url = self._neptune_endpoint.get_sparql_endpoint()
        update_stmt = f"clear graph <{uri}>"
        resp = requests.post(neptune_sparql_url, data={"update": update_stmt}, auth=auth)
        if resp.status_code != 200:
            raise NeptuneClearGraphException(
                f"Error clearing graph {uri} with {update_stmt} : {resp.text}"
            )

    def run_raw_query(self, query: str) -> QueryResultSet:
        """Run a query against a Neptune instance and return the results.

        Generally this should be called from `run_query`.

        Args:
            query: complete query to run

        Returns:
            QueryResultSet object

        Raises:
            NeptuneQueryException if an error occurred running the query
        """
        neptune_sparql_url = self._neptune_endpoint.get_sparql_endpoint()
        auth = self._get_auth()
        resp = requests.post(
            neptune_sparql_url,
            headers={"te": "trailers"},
            data={"query": query},
            auth=auth,
        )
        if resp.status_code != 200:
            raise NeptuneQueryException(f"Error running query {query}: {resp.text}")
        try:
            results_json = resp.json()
        except json.decoder.JSONDecodeError as jde:
            neptune_status = resp.headers.get("X-Neptune-Status", "unknown")
            neptune_detail = resp.headers.get("X-Neptune-Detail", "unknown")
            raise NeptuneQueryException(
                f"Error running query {query}: {neptune_status}: {neptune_detail}"
            ) from jde
        return QueryResultSet.from_sparql_endpoint_json(results_json)

    @staticmethod
    def __normalize_query_string(query: str) -> str:
        """Normalize the query string"""
        kv = (list(map(str.strip, s.split("="))) for s in query.split("&") if len(s) > 0)
        normalized = "&".join("%s=%s" % (p[0], p[1] if len(p) > 1 else "") for p in sorted(kv))
        return normalized

    def __get_signature_key(
        self, key: str, datestamp: str, regionname: str, servicename: str
    ) -> bytes:
        """Get the signed signature key

        :return: The signed key
        """
        key_date = self.__sign(("AWS4" + key).encode("utf-8"), datestamp)
        key_region = self.__sign(key_date, regionname)
        key_service = self.__sign(key_region, servicename)
        key_signing = self.__sign(key_service, "aws4_request")
        return key_signing

    @staticmethod
    def __sign(key: bytes, msg: str) -> bytes:
        """Sign the msg with the key"""
        return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()

    def prepare_request(
        self, method: str = "GET", payload: str = "", querystring: Dict = {}
    ) -> RequestParameters:
        """Prepare the request for sigv4 signing.

        This is heavily influenced by the code here:
        https://github.com/awslabs/amazon-neptune-tools/tree/master/neptune-python-utils

        :param method: The method name
        :param payload: The request payload
        :param querystring: The request querystring
        :return: The request parameters
        """
        session = boto3.Session()
        credentials = session.get_credentials()
        access_key = credentials.access_key
        secret_key = credentials.secret_key
        session_token = credentials.token
        service = "neptune-db"
        algorithm = "AWS4-HMAC-SHA256"
        request_parameters = parse.urlencode(querystring).replace("%27", "%22")
        canonical_querystring = self.__normalize_query_string(request_parameters)
        t = datetime.utcnow()
        amz_date = t.strftime("%Y%m%dT%H%M%SZ")
        datestamp = t.strftime("%Y%m%d")
        canonical_headers = "host:{}:{}\nx-amz-date:{}\n".format(
            self._neptune_endpoint.host, self._neptune_endpoint.port, amz_date
        )
        signed_headers = "host;x-amz-date"
        payload_hash = hashlib.sha256(payload.encode("utf-8")).hexdigest()
        canonical_request = "{}\n/{}\n{}\n{}\n{}\n{}".format(
            method,
            "gremlin",
            canonical_querystring,
            canonical_headers,
            signed_headers,
            payload_hash,
        )
        credential_scope = "{}/{}/{}/aws4_request".format(
            datestamp, self._neptune_endpoint.region, service
        )
        string_to_sign = "{}\n{}\n{}\n{}".format(
            algorithm,
            amz_date,
            credential_scope,
            hashlib.sha256(canonical_request.encode("utf-8")).hexdigest(),
        )
        signing_key = self.__get_signature_key(
            secret_key, datestamp, self._neptune_endpoint.region, service
        )
        signature = hmac.new(
            signing_key, string_to_sign.encode("utf-8"), hashlib.sha256
        ).hexdigest()
        authorization_header = "{} Credential={}/{}, SignedHeaders={}, Signature={}".format(
            algorithm, access_key, credential_scope, signed_headers, signature
        )
        headers = {"x-amz-date": amz_date, "Authorization": authorization_header}
        if session_token:
            headers["x-amz-security-token"] = session_token
        return RequestParameters(
            "{}?{}".format(
                self._neptune_endpoint.get_gremlin_endpoint(self._neptune_endpoint.ssl),
                canonical_querystring,
            )
            if canonical_querystring
            else self._neptune_endpoint.get_gremlin_endpoint(),
            canonical_querystring,
            headers,
        )

    def connect_to_gremlin(self) -> Tuple[traversal, DriverRemoteConnection]:
        """Get the Gremlin traversal and connection for the Neptune endpoint.

        :return: The Traversal object
        """
        if self._neptune_endpoint.auth_mode.lower() == "default":
            gremlin_connection = DriverRemoteConnection(
                self._neptune_endpoint.get_gremlin_endpoint(self._neptune_endpoint.ssl), "g"
            )
        else:
            request_parameters = self.prepare_request()
            signed_ws_request = httpclient.HTTPRequest(
                request_parameters.uri, headers=request_parameters.headers
            )
            gremlin_connection = DriverRemoteConnection(signed_ws_request, "g")
        graph_traversal_source = traversal().withRemote(gremlin_connection)
        return graph_traversal_source, gremlin_connection

    def __write_vertices(self, g: traversal, vertices: List[Dict], scan_id: str) -> None:
        """Write the vertices to the labeled property graph.

        :param g: The graph traversal source
        :param vertices: A list of dictionaries for each vertex
        :return: None
        """
        cnt = 0
        t = g
        for r in vertices:
            vertex_id = f'{r["~id"]}_{scan_id}'
            t = (
                t.V(vertex_id)
                .fold()
                .coalesce(
                    __.unfold(),
                    __.addV(self.parse_arn(r["~label"])["resource"]).property(
                        T.id, vertex_id
                    ),
                )
            )
            for k in r.keys():
                # Need to handle numbers that are bigger than a Long in Java;
                # for now we stringify them
                if isinstance(r[k], int) and (
                    r[k] > 9223372036854775807 or r[k] < -9223372036854775807
                ):
                    r[k] = str(r[k])
                if k not in ["~id", "~label"]:
                    t = t.property(k, r[k])
            cnt += 1
            if cnt % 100 == 0 or cnt == len(vertices):
                try:
                    self.logger.info(
                        event=LogEvent.NeptunePeriodicWrite,
                        msg=f"Writing vertices {cnt} of {len(vertices)}",
                    )
                    t.next()
                    t = g
                except Exception as err:
                    self.logger.error(event=LogEvent.NeptuneLoadError, msg=str(err))
                    raise NeptuneLoadGraphException(
                        f"Error loading vertex {r} with {str(t.bytecode)}"
                    ) from err

    def __write_edges(self, g: traversal, edges: List[Dict], scan_id: str) -> None:
        """Write the edges to the labeled property graph.

        :param g: The graph traversal source
        :param edges: A list of dictionaries for each edge
        :return: None
        """
        cnt = 0
        t = g
        for r in edges:
            to_id = f'{r["~to"]}_{scan_id}'
            from_id = f'{r["~from"]}_{scan_id}'
            t = (
                t.addE(r["~label"])
                .property(T.id, str(r["~id"]))
                .from_(
                    __.V(from_id)
                    .fold()
                    .coalesce(
                        __.unfold(),
                        __.addV(self.parse_arn(r["~from"])["resource"])
                        .property(T.id, from_id)
                        .property("scan_id", scan_id)
                        .property("arn", r["~from"]),
                    )
                )
                .to(
                    __.V(to_id)
                    .fold()
                    .coalesce(
                        __.unfold(),
                        __.addV(self.parse_arn(r["~to"])["resource"])
                        .property(T.id, to_id)
                        .property("scan_id", scan_id)
                        .property("arn", r["~to"]),
                    )
                )
            )
            cnt += 1
            if cnt % 100 == 0 or cnt == len(edges):
                try:
                    self.logger.info(
                        event=LogEvent.NeptunePeriodicWrite,
                        msg=f"Writing edges {cnt} of {len(edges)}",
                    )
                    t.next()
                    t = g
                except Exception as err:
                    self.logger.error(event=LogEvent.NeptuneLoadError, msg=str(err))
                    raise NeptuneLoadGraphException(
                        f"Error loading edge {r} with {str(t.bytecode)}"
                    ) from err

    @staticmethod
    def parse_arn(arn: str) -> Dict:
        """Parse an ARN into its component pieces.

        :param arn: The arn to parse
        :return: A dictionary of the arn pieces
        """
        # http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html
        elements = str(arn).split(":", 5)
        result = {}
        if len(elements) == 6:
            result = {
                "arn": elements[0],
                "partition": elements[1],
                "service": elements[2],
                "region": elements[3],
                "account": elements[4],
                "resource": elements[5],
                "resource_type": None,
            }
        else:
            result["resource"] = str(arn)
        if "/" in str(result["resource"]):
            result["resource_type"], result["resource"] = str(result["resource"]).split("/", 1)
        elif ":" in str(result["resource"]):
            result["resource_type"], result["resource"] = str(result["resource"]).split(":", 1)
        if str(result["resource"]).startswith("ami-"):
            result["resource"] = result["resource_type"]
        return result

    def write_to_neptune_lpg(self, graph: Dict, scan_id: str) -> None:
        """Write the graph to a labeled property graph.

        :param scan_id: The unique string representing the scan
        :param graph: The graph to write
        :return: None
        """
        if "vertices" in graph and "edges" in graph and len(graph["vertices"]) > 0:
            g, conn = self.connect_to_gremlin()
            self.__write_vertices(g, graph["vertices"], scan_id)
            self.__write_edges(g, graph["edges"], scan_id)
            conn.close()
        else:
            raise NeptuneNoGraphsFoundException

    def write_to_neptune_rdf(self, graph: Dict) -> None:
        """Write the graph to an RDF graph.

        :param graph: The graph to write
        :return: None
        """
        auth = self._get_auth()
        neptune_sparql_url = self._neptune_endpoint.get_sparql_endpoint()
        triples = ""
        for subject, predicate, obj in graph:
            triples += subject.n3() + " " + predicate.n3() + " " + obj.n3() + " . \n"
        insert_stmt = (
            "INSERT DATA {\n"
            f"  GRAPH <{META_GRAPH_NAME}>\n"
            "{\n"
            f"{triples}"
            "}\n"
            "}\n"
        )
        resp = requests.post(neptune_sparql_url, data={"update": insert_stmt}, auth=auth)
        if resp.status_code != 200:
            raise NeptuneUpdateGraphException(
                f"Error updating graph {META_GRAPH_NAME} with {insert_stmt} : {resp.text}"
            )
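A hedged usage sketch of AltimeterNeptuneClient: query the latest graph by name, with the endpoint constructed as in the pruner entrypoints below. The host value and the SPARQL query are illustrative only.

# Illustrative only; host/port/region and the query are examples.
endpoint = NeptuneEndpoint(host="my-neptune-host", port=8182, region="us-east-1")
client = AltimeterNeptuneClient(max_age_min=1440, neptune_endpoint=endpoint)
query_result = client.run_query(
    graph_names={"alti"},
    query="SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10",
)
# Assumes QueryResult exposes a to_list like QueryResultSet does above.
print(query_result.to_list())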
def prune_graph(graph_pruner_config: GraphPrunerConfig) -> GraphPrunerResults:
    config = Config.from_path(path=graph_pruner_config.config_path)
    if config.neptune is None:
        raise InvalidConfigException("Configuration missing neptune section.")
    now = int(datetime.now().timestamp())
    oldest_acceptable_graph_epoch = now - config.pruner_max_age_min * 60
    endpoint = NeptuneEndpoint(
        host=config.neptune.host, port=config.neptune.port, region=config.neptune.region
    )
    client = AltimeterNeptuneClient(
        max_age_min=config.pruner_max_age_min, neptune_endpoint=endpoint
    )
    logger = Logger()
    uncleared = []
    pruned_graph_uris = []
    skipped_graph_uris = []
    logger.info(event=LogEvent.PruneNeptuneGraphsStart)
    all_graph_metadatas = client.get_graph_metadatas(name=config.graph_name)
    with logger.bind(neptune_endpoint=str(endpoint)):
        for graph_metadata in all_graph_metadatas:
            assert graph_metadata.name == config.graph_name
            graph_epoch = graph_metadata.end_time
            with logger.bind(graph_uri=graph_metadata.uri, graph_epoch=graph_epoch):
                if graph_epoch < oldest_acceptable_graph_epoch:
                    logger.info(event=LogEvent.PruneNeptuneGraphStart)
                    try:
                        client.clear_registered_graph(
                            name=config.graph_name, uri=graph_metadata.uri
                        )
                        logger.info(event=LogEvent.PruneNeptuneGraphEnd)
                        pruned_graph_uris.append(graph_metadata.uri)
                    except Exception as ex:
                        logger.error(
                            event=LogEvent.PruneNeptuneGraphError,
                            msg=f"Error pruning graph {graph_metadata.uri}: {ex}",
                        )
                        uncleared.append(graph_metadata.uri)
                        continue
                else:
                    logger.info(event=LogEvent.PruneNeptuneGraphSkip)
                    skipped_graph_uris.append(graph_metadata.uri)
        # now find orphaned graphs - these are in neptune but have no metadata
        registered_graph_uris = [g_m.uri for g_m in all_graph_metadatas]
        all_graph_uris = client.get_graph_uris(name=config.graph_name)
        orphaned_graphs = set(all_graph_uris) - set(registered_graph_uris)
        if orphaned_graphs:
            for orphaned_graph_uri in orphaned_graphs:
                with logger.bind(graph_uri=orphaned_graph_uri):
                    logger.info(event=LogEvent.PruneOrphanedNeptuneGraphStart)
                    try:
                        client.clear_graph_data(uri=orphaned_graph_uri)
                        logger.info(event=LogEvent.PruneOrphanedNeptuneGraphEnd)
                        pruned_graph_uris.append(orphaned_graph_uri)
                    except Exception as ex:
                        logger.error(
                            event=LogEvent.PruneNeptuneGraphError,
                            msg=f"Error pruning graph {orphaned_graph_uri}: {ex}",
                        )
                        uncleared.append(orphaned_graph_uri)
                        continue
    logger.info(event=LogEvent.PruneNeptuneGraphsEnd)
    if uncleared:
        msg = f"Errors were found pruning {uncleared}."
        logger.error(event=LogEvent.PruneNeptuneGraphsError, msg=msg)
        raise Exception(msg)
    return GraphPrunerResults(
        pruned_graph_uris=pruned_graph_uris,
        skipped_graph_uris=skipped_graph_uris,
    )
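GraphPrunerConfig and GraphPrunerResults are not defined in this section. Minimal sketches, assuming plain dataclass containers matching the fields used above, plus an illustrative call (the config path value is an example only):

from dataclasses import dataclass
from typing import List


@dataclass
class GraphPrunerConfigSketch:
    # The real GraphPrunerConfig exposes at least config_path (see above).
    config_path: str


@dataclass
class GraphPrunerResultsSketch:
    # Mirrors the two fields prune_graph returns.
    pruned_graph_uris: List[str]
    skipped_graph_uris: List[str]


results = prune_graph(GraphPrunerConfig(config_path="/etc/altimeter/config.toml"))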
def lambda_handler(event: Dict[str, Any], context: Any) -> None:
    """Entrypoint"""
    # AWS Lambda pre-installs a root log handler; remove any such handlers so
    # this service's structured logging configuration takes effect.
    root = logging.getLogger()
    if root.handlers:
        for handler in root.handlers:
            root.removeHandler(handler)
    config_path = get_required_str_env_var("CONFIG_PATH")
    config = Config.from_path(path=config_path)
    if config.neptune is None:
        raise InvalidConfigException("Configuration missing neptune section.")
    now = int(datetime.now().timestamp())
    oldest_acceptable_graph_epoch = now - config.pruner_max_age_min * 60
    endpoint = NeptuneEndpoint(
        host=config.neptune.host, port=config.neptune.port, region=config.neptune.region
    )
    client = AltimeterNeptuneClient(
        max_age_min=config.pruner_max_age_min, neptune_endpoint=endpoint
    )
    logger = Logger()
    uncleared = []
    logger.info(event=LogEvent.PruneNeptuneGraphsStart)
    all_graph_metadatas = client.get_graph_metadatas(name=config.graph_name)
    with logger.bind(neptune_endpoint=str(endpoint)):
        for graph_metadata in all_graph_metadatas:
            assert graph_metadata.name == config.graph_name
            graph_epoch = graph_metadata.end_time
            with logger.bind(graph_uri=graph_metadata.uri, graph_epoch=graph_epoch):
                if graph_epoch < oldest_acceptable_graph_epoch:
                    logger.info(event=LogEvent.PruneNeptuneGraphStart)
                    try:
                        client.clear_registered_graph(
                            name=config.graph_name, uri=graph_metadata.uri
                        )
                        logger.info(event=LogEvent.PruneNeptuneGraphEnd)
                    except Exception as ex:
                        logger.error(
                            event=LogEvent.PruneNeptuneGraphError,
                            msg=f"Error pruning graph {graph_metadata.uri}: {ex}",
                        )
                        uncleared.append(graph_metadata.uri)
                        continue
                else:
                    logger.info(event=LogEvent.PruneNeptuneGraphSkip)
        # now find orphaned graphs - these are in neptune but have no metadata
        registered_graph_uris = [g_m.uri for g_m in all_graph_metadatas]
        all_graph_uris = client.get_graph_uris(name=config.graph_name)
        orphaned_graphs = set(all_graph_uris) - set(registered_graph_uris)
        if orphaned_graphs:
            for orphaned_graph_uri in orphaned_graphs:
                with logger.bind(graph_uri=orphaned_graph_uri):
                    logger.info(event=LogEvent.PruneOrphanedNeptuneGraphStart)
                    try:
                        client.clear_graph_data(uri=orphaned_graph_uri)
                        logger.info(event=LogEvent.PruneOrphanedNeptuneGraphEnd)
                    except Exception as ex:
                        logger.error(
                            event=LogEvent.PruneNeptuneGraphError,
                            msg=f"Error pruning graph {orphaned_graph_uri}: {ex}",
                        )
                        uncleared.append(orphaned_graph_uri)
                        continue
    logger.info(event=LogEvent.PruneNeptuneGraphsEnd)
    if uncleared:
        msg = f"Errors were found pruning {uncleared}."
        logger.error(event=LogEvent.PruneNeptuneGraphsError, msg=msg)
        raise Exception(msg)