def get_session(self, account_id: str, region: Optional[str] = None) -> boto3.Session:
    """Get a boto3 session for a given account.

    Args:
        account_id: target account id
        region: session region

    Returns:
        boto3.Session object
    """
    logger = Logger()
    with logger.bind(auth_account_id=account_id):
        if self.multi_hop_accessors:
            for mha in self.multi_hop_accessors:  # pylint: disable=not-an-iterable
                with logger.bind(auth_accessor=str(mha)):
                    try:
                        session = mha.get_session(account_id=account_id, region=region)
                        return session
                    except Exception as ex:
                        logger.debug(event=LogEvent.AuthToAccountFailure, exception=str(ex))
            raise AccountAuthException(f"Unable to access {account_id} using {str(self)}")
        # local run mode
        session = boto3.Session(region_name=region)
        sts_client = session.client("sts")
        sts_account_id = sts_client.get_caller_identity()["Account"]
        if sts_account_id != account_id:
            raise ValueError(f"BUG: sts_account_id {sts_account_id} != {account_id}")
        return session
def lambda_handler(event, context):
    host = get_required_lambda_env_var("NEPTUNE_HOST")
    port = get_required_lambda_env_var("NEPTUNE_PORT")
    region = get_required_lambda_env_var("NEPTUNE_REGION")
    max_age_min = get_required_lambda_env_var("MAX_AGE_MIN")
    graph_name = get_required_lambda_env_var("GRAPH_NAME")
    try:
        max_age_min = int(max_age_min)
    except ValueError as ve:
        raise Exception(f"env var MAX_AGE_MIN must be an int: {ve}")
    now = int(datetime.now().timestamp())
    oldest_acceptable_graph_epoch = now - max_age_min * 60

    endpoint = NeptuneEndpoint(host=host, port=port, region=region)
    client = AltimeterNeptuneClient(max_age_min=max_age_min, neptune_endpoint=endpoint)
    logger = Logger()

    uncleared = []

    # first prune metadata - if clears below are partial we want to make sure no clients
    # consider this a valid graph still.
    logger.info(event=LogEvent.PruneNeptuneMetadataGraphStart)
    client.clear_old_graph_metadata(name=graph_name, max_age_min=max_age_min)
    logger.info(event=LogEvent.PruneNeptuneMetadataGraphEnd)
    # now clear actual graphs
    with logger.bind(neptune_endpoint=endpoint):
        logger.info(event=LogEvent.PruneNeptuneGraphsStart)
        for graph_metadata in client.get_graph_metadatas(name=graph_name):
            assert graph_metadata.name == graph_name
            graph_epoch = graph_metadata.end_time
            with logger.bind(graph_uri=graph_metadata.uri, graph_epoch=graph_epoch):
                if graph_epoch < oldest_acceptable_graph_epoch:
                    logger.info(event=LogEvent.PruneNeptuneGraphStart)
                    try:
                        client.clear_graph(graph_uri=graph_metadata.uri)
                        logger.info(event=LogEvent.PruneNeptuneGraphEnd)
                    except Exception as ex:
                        logger.error(
                            event=LogEvent.PruneNeptuneGraphError,
                            msg=f"Error pruning graph {graph_metadata.uri}: {ex}",
                        )
                        uncleared.append(graph_metadata.uri)
                        continue
                else:
                    logger.info(event=LogEvent.PruneNeptuneGraphSkip)
        logger.info(event=LogEvent.PruneNeptuneGraphsEnd)
    if uncleared:
        msg = f"Errors were found pruning {uncleared}."
        logger.error(event=LogEvent.PruneNeptuneGraphsError, msg=msg)
        raise Exception(msg)
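# get_required_lambda_env_var is used by the handler above but is not shown here. A minimal
# sketch of the behavior the handler relies on (return the env var value, fail loudly if it
# is unset) might look like the following; the real helper in this codebase may differ, so
# the name is suffixed with _sketch to mark it as hypothetical.
import os


def get_required_lambda_env_var_sketch(key: str) -> str:
    """Return the value of environment variable `key`, raising if it is not set."""
    value = os.environ.get(key)
    if value is None:
        raise ValueError(f"Missing required environment variable {key}")
    return value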
def write_graph_set(
    self, name: str, graph_set: GraphSet, compression: Optional[str] = None
) -> str:
    """Write a graph artifact

    Args:
        name: name
        graph_set: GraphSet object to write
        compression: optional compression type - GZIP or None

    Returns:
        path to written artifact
    """
    logger = Logger()
    os.makedirs(self.output_dir, exist_ok=True)
    if compression is None:
        artifact_path = os.path.join(self.output_dir, f"{name}.rdf")
    elif compression == GZIP:
        artifact_path = os.path.join(self.output_dir, f"{name}.rdf.gz")
    else:
        raise ValueError(f"Unknown compression arg {compression}")
    graph = graph_set.to_rdf()
    with logger.bind(artifact_path=artifact_path):
        logger.info(event=LogEvent.WriteToFSStart)
        with open(artifact_path, "wb") as fp:
            if compression is None:
                graph.serialize(fp)
            elif compression == GZIP:
                with gzip.GzipFile(fileobj=fp, mode="wb") as gz:
                    graph.serialize(gz)
            else:
                raise ValueError(f"Unknown compression arg {compression}")
        logger.info(event=LogEvent.WriteToFSEnd)
    return artifact_path
def scan_services(
    graph_name: str,
    graph_version: str,
    account_id: str,
    region: str,
    service: str,
    access_key: str,
    secret_key: str,
    token: str,
    resource_spec_classes: Tuple[Type[AWSResourceSpec], ...],
) -> Dict[str, Any]:
    """Scan a single service in a single account/region and return the resulting GraphSet
    as a dict."""
    logger = Logger()
    with logger.bind(region=region, service=service):
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceStart)
        session = boto3.Session(
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
            aws_session_token=token,
            region_name=region,
        )
        aws_accessor = AWSAccessor(session=session, account_id=account_id, region_name=region)
        graph_spec = GraphSpec(
            name=graph_name,
            version=graph_version,
            resource_spec_classes=resource_spec_classes,
            scan_accessor=aws_accessor,
        )
        graph_set = graph_spec.scan()
        graph_set_dict = graph_set.to_dict()
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceEnd)
        return graph_set_dict
def lambda_handler(cls, event: Dict[str, Any], _: Any) -> None:
    """lambda entrypoint"""
    config = Config()
    result = Result(**event)
    logger = Logger()
    errors: List[str] = []
    with logger.bind(result=result):
        logger.info(event=QJLogEvents.ResultRemediationStart)
        try:
            session = get_assumed_session(
                account_id=result.account_id,
                role_name=config.remediator_target_role_name,
                external_id=config.remediator_target_role_external_id,
            )
            cls.remediate(session=session, result=result.result, dry_run=config.dry_run)
            logger.info(event=QJLogEvents.ResultRemediationSuccessful)
        except Exception as ex:
            logger.error(event=QJLogEvents.ResultRemediationFailed, error=str(ex))
            errors.append(str(ex))
        if errors:
            raise RemediationError(f"Errors found during remediation: {errors}")
def scan(self) -> GraphSet:
    """Perform a scan on all of the resource classes in this GraphSpec and return
    a GraphSet containing the scanned data.

    Returns:
        GraphSet representing results of scanning this GraphSpec's resource_spec_classes.
    """
    resources: List[Resource] = []
    errors: List[str] = []
    stats = MultilevelCounter()
    start_time = int(time.time())
    logger = Logger()
    for resource_spec_class in self.resource_spec_classes:
        with logger.bind(resource_type=str(resource_spec_class.type_name)):
            logger.debug(event=LogEvent.ScanResourceTypeStart)
            resource_scan_result = resource_spec_class.scan(scan_accessor=self.scan_accessor)
            resources += resource_scan_result.resources
            errors += resource_scan_result.errors
            stats.merge(resource_scan_result.stats)
            logger.debug(event=LogEvent.ScanResourceTypeEnd)
    end_time = int(time.time())
    return GraphSet(
        name=self.name,
        version=self.version,
        start_time=start_time,
        end_time=end_time,
        resources=resources,
        errors=errors,
        stats=stats,
    )
def invoke_lambda(lambda_name: str, lambda_timeout: int, event: Dict[str, Any]) -> Dict[str, Any]:
    """Invoke an AWS Lambda function

    Args:
        lambda_name: name of lambda
        lambda_timeout: timeout of the lambda. Used to tell the boto3 lambda client to wait
                        at least this long for a response before timing out.
        event: event data to send to the lambda

    Returns:
        lambda response payload

    Raises:
        Exception if there was an error invoking the lambda.
    """
    logger = Logger()
    with logger.bind(lambda_name=lambda_name, lambda_timeout=lambda_timeout, event=event):
        logger.info(event=AWSLogEvents.RunAccountScanLambdaStart)
        boto_config = botocore.config.Config(
            read_timeout=lambda_timeout + 10, retries={"max_attempts": 0}
        )
        session = boto3.Session()
        lambda_client = session.client("lambda", config=boto_config)
        resp = lambda_client.invoke(
            FunctionName=lambda_name, Payload=json.dumps(event).encode("utf-8")
        )
        payload: bytes = resp["Payload"].read()
        if resp.get("FunctionError", None):
            raise Exception(f"Error invoking {lambda_name} with event {event}: {payload}")
        payload_dict = json.loads(payload)
        logger.info(event=AWSLogEvents.RunAccountScanLambdaEnd)
        return payload_dict
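# Illustrative usage of invoke_lambda above, as a hedged sketch. The function name and the
# event payload shape are hypothetical placeholders, not values defined by this codebase;
# adjust them to your deployment.
def _example_invoke_lambda_usage() -> None:
    response = invoke_lambda(
        lambda_name="example-account-scan-lambda",  # hypothetical lambda name
        lambda_timeout=300,  # seconds; the boto3 read timeout becomes this value + 10
        event={"account_id": "123456789012"},  # hypothetical event shape
    )
    print(response)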
def read_json(self, path: str) -> Dict[str, Any]:
    """Read a json artifact

    Args:
        path: s3 uri to artifact. s3://bucket/key/path

    Returns:
        artifact content
    """
    bucket, key = parse_s3_uri(path)
    if key is None:
        raise ValueError(f"Unable to read from s3 uri missing key: {path}")
    session = boto3.Session()
    s3_client = session.client("s3")
    logger = Logger()
    with logger.bind(bucket=bucket, key=key):
        with io.BytesIO() as artifact_bytes_buf:
            logger.info(event=LogEvent.ReadFromS3Start)
            s3_client.download_fileobj(bucket, key, artifact_bytes_buf)
            artifact_bytes_buf.flush()
            artifact_bytes_buf.seek(0)
            artifact_bytes = artifact_bytes_buf.read()
            logger.info(event=LogEvent.ReadFromS3End)
            artifact_str = artifact_bytes.decode("utf-8")
            artifact_dict = json.loads(artifact_str)
            return artifact_dict
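# parse_s3_uri is used by read_json above (and read_artifact further below) but is not shown
# here. A minimal sketch of the behavior the callers rely on ("s3://bucket/key/path" ->
# (bucket, key), with key possibly None) might look like this; the real helper in this
# codebase may differ, so the name is suffixed with _sketch.
from typing import Optional, Tuple
from urllib.parse import urlparse


def parse_s3_uri_sketch(uri: str) -> Tuple[str, Optional[str]]:
    """Split an s3:// uri into (bucket, key); key is None if the uri has no key part."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3":
        raise ValueError(f"Not an s3 uri: {uri}")
    bucket = parsed.netloc
    key = parsed.path.lstrip("/") or None
    return bucket, key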
def lambda_handler(event, context):
    json_bucket = event["Records"][0]["s3"]["bucket"]["name"]
    json_key = urllib.parse.unquote(event["Records"][0]["s3"]["object"]["key"])
    rdf_bucket = get_required_lambda_env_var("RDF_BUCKET")
    rdf_key = ".".join(json_key.split(".")[:-1]) + ".rdf.gz"
    session = boto3.Session()
    s3_client = session.client("s3")
    logger = Logger()
    with logger.bind(json_bucket=json_bucket, json_key=json_key):
        graph_pkg = graph_pkg_from_s3(
            s3_client=s3_client, json_bucket=json_bucket, json_key=json_key
        )
    with logger.bind(rdf_bucket=rdf_bucket, rdf_key=rdf_key):
        logger.info(event=LogEvent.WriteToS3Start)
        with io.BytesIO() as rdf_bytes_buf:
            with gzip.GzipFile(fileobj=rdf_bytes_buf, mode="wb") as gz:
                graph_pkg.graph.serialize(gz)
            rdf_bytes_buf.flush()
            rdf_bytes_buf.seek(0)
            s3_client.upload_fileobj(rdf_bytes_buf, rdf_bucket, rdf_key)
        s3_client.put_object_tagging(
            Bucket=rdf_bucket,
            Key=rdf_key,
            Tagging={
                "TagSet": [
                    {"Key": "name", "Value": graph_pkg.name},
                    {"Key": "version", "Value": graph_pkg.version},
                    {"Key": "start_time", "Value": str(graph_pkg.start_time)},
                    {"Key": "end_time", "Value": str(graph_pkg.end_time)},
                ]
            },
        )
        logger.info(event=LogEvent.WriteToS3End)
def scan(
    self, account_scan_plan: AccountScanPlan
) -> Generator[AccountScanManifest, None, None]:
    """Scan accounts. Yield AccountScanManifest objects.

    Args:
        account_scan_plan: AccountScanPlan defining this scan op

    Yields:
        AccountScanManifest objects
    """
    num_total_accounts = len(account_scan_plan.account_ids)
    account_scan_plans = account_scan_plan.to_batches(
        max_accounts=self.config.concurrency.max_accounts_per_thread
    )
    num_account_batches = len(account_scan_plans)
    num_threads = min(num_account_batches, self.config.concurrency.max_account_scan_threads)
    logger = Logger()
    with logger.bind(
        num_total_accounts=num_total_accounts,
        num_account_batches=num_account_batches,
        muxer=self.__class__.__name__,
        num_muxer_threads=num_threads,
    ):
        logger.info(event=AWSLogEvents.MuxerStart)
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            processed_accounts = 0
            futures = []
            for sub_account_scan_plan in account_scan_plans:
                account_scan_future = self._schedule_account_scan(executor, sub_account_scan_plan)
                futures.append(account_scan_future)
                logger.info(
                    event=AWSLogEvents.MuxerQueueScan,
                    account_ids=",".join(sub_account_scan_plan.account_ids),
                )
            for future in as_completed(futures):
                scan_results_dicts = future.result()
                for scan_results_dict in scan_results_dicts:
                    account_id = scan_results_dict["account_id"]
                    output_artifact = scan_results_dict["output_artifact"]
                    account_errors = scan_results_dict["errors"]
                    api_call_stats = scan_results_dict["api_call_stats"]
                    artifacts = [output_artifact] if output_artifact else []
                    account_scan_result = AccountScanManifest(
                        account_id=account_id,
                        artifacts=artifacts,
                        errors=account_errors,
                        api_call_stats=api_call_stats,
                    )
                    yield account_scan_result
                    processed_accounts += 1
                    logger.info(event=AWSLogEvents.MuxerStat, num_scanned=processed_accounts)
        logger.info(event=AWSLogEvents.MuxerEnd)
def scan_scan_unit(scan_unit: ScanUnit) -> Tuple[str, Dict[str, Any]]:
    """Scan a single ScanUnit (account/region/service) and return a tuple of
    (account_id, GraphSet dict)."""
    logger = Logger()
    with logger.bind(
        account_id=scan_unit.account_id,
        region=scan_unit.region_name,
        service=scan_unit.service,
        resource_classes=sorted(
            [
                resource_spec_class.__name__
                for resource_spec_class in scan_unit.resource_spec_classes
            ]
        ),
    ):
        start_t = time.time()
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceStart)
        session = boto3.Session(
            aws_access_key_id=scan_unit.access_key,
            aws_secret_access_key=scan_unit.secret_key,
            aws_session_token=scan_unit.token,
            region_name=scan_unit.region_name,
        )
        scan_accessor = AWSAccessor(
            session=session, account_id=scan_unit.account_id, region_name=scan_unit.region_name
        )
        graph_spec = GraphSpec(
            name=scan_unit.graph_name,
            version=scan_unit.graph_version,
            resource_spec_classes=scan_unit.resource_spec_classes,
            scan_accessor=scan_accessor,
        )
        start_time = int(time.time())
        resources: List[Resource] = []
        errors = []
        try:
            resources = graph_spec.scan()
        except Exception as ex:
            error_str = str(ex)
            trace_back = traceback.format_exc()
            logger.error(
                event=AWSLogEvents.ScanAWSAccountError, error=error_str, trace_back=trace_back
            )
            error = f"{str(ex)}\n{trace_back}"
            errors.append(error)
        end_time = int(time.time())
        graph_set = GraphSet(
            name=scan_unit.graph_name,
            version=scan_unit.graph_version,
            start_time=start_time,
            end_time=end_time,
            resources=resources,
            errors=errors,
            stats=scan_accessor.api_call_stats,
        )
        end_t = time.time()
        elapsed_sec = end_t - start_t
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceEnd, elapsed_sec=elapsed_sec)
        return (scan_unit.account_id, graph_set.to_dict())
def remediator(event: Dict[str, Any]) -> None:
    """Run the remediation lambda for a QJ result set"""
    config = RemediatorConfig()
    logger = Logger()
    remediation = Remediation(**event)
    with logger.bind(remediation=remediation):
        logger.info(event=QJLogEvents.RemediationInit)
        qj_api_client = QJAPIClient(host=config.qj_api_host)
        latest_result_set = qj_api_client.get_job_latest_result_set(job_name=remediation.job_name)
        if not latest_result_set:
            msg = f"No latest_result_set present for {remediation.job_name}"
            logger.error(QJLogEvents.StaleResultSet, detail=msg)
            raise RemediationError(msg)
        if latest_result_set.result_set_id != remediation.result_set_id:
            msg = (
                f"Remediation result_set_id {remediation.result_set_id} does not match the "
                f"latest result_set_id {latest_result_set.result_set_id}"
            )
            logger.error(QJLogEvents.StaleResultSet, detail=msg)
            raise RemediationError(msg)
        if not latest_result_set.job.remediate_sqs_queue:
            msg = f"Job {latest_result_set.job.name} has no remediator"
            logger.error(QJLogEvents.JobHasNoRemediator, detail=msg)
            raise RemediationError(msg)
        num_threads = 10  # TODO env var
        errors = []
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = []
            for result in latest_result_set.results:
                logger.info(event=QJLogEvents.ProcessResult, result=result)
                future = _schedule_result_remediation(
                    executor=executor,
                    lambda_name=latest_result_set.job.remediate_sqs_queue,
                    lambda_timeout=300,  # TODO env var?
                    result=result,
                )
                futures.append(future)
            for future in as_completed(futures):
                try:
                    lambda_result = future.result()
                    logger.info(
                        QJLogEvents.ResultRemediationSuccessful, lambda_result=lambda_result
                    )
                except Exception as ex:
                    logger.info(
                        event=QJLogEvents.ResultSetRemediationFailed,
                        error=str(ex),
                    )
                    errors.append(str(ex))
        if errors:
            logger.error(event=QJLogEvents.ResultSetRemediationFailed, errors=errors)
            raise RemediationError(
                f"Errors encountered during remediation of {latest_result_set.job.name} "
                f"/ {latest_result_set.result_set_id}: {errors}"
            )
def notify(self, notification: schemas.ResultSetNotification) -> None:
    """Publish a new-results notification to the configured SNS topic."""
    logger = Logger()
    with logger.bind(notification=notification):
        logger.info(event=QJLogEvents.NotifyNewResultsStart)
        session = boto3.Session(region_name=self.region_name)
        sns_client = session.client("sns", region_name=self.region_name)
        sns_client.publish(
            TopicArn=self.sns_topic_arn,
            Message=json.dumps({"default": notification.json()}),
            MessageStructure="json",
        )
        logger.info(event=QJLogEvents.NotifyNewResultsEnd)
def read_artifact(self, artifact_path: str) -> Dict[str, Any]:
    """Read an artifact

    Args:
        artifact_path: filesystem path to artifact

    Returns:
        artifact content
    """
    logger = Logger()
    with logger.bind(artifact_path=artifact_path):
        logger.info(event=LogEvent.ReadFromFSStart)
        with open(artifact_path, "r") as artifact_fp:
            data = json.load(artifact_fp)
        logger.info(event=LogEvent.ReadFromFSEnd)
        return data
def invoke_lambda(
    lambda_name: str, lambda_timeout: int, account_scan_lambda_event: AccountScanLambdaEvent
) -> AccountScanResult:
    """Invoke the AccountScan AWS Lambda function

    Args:
        lambda_name: name of lambda
        lambda_timeout: timeout of the lambda. Used to tell the boto3 lambda client to wait
                        at least this long for a response before timing out.
        account_scan_lambda_event: AccountScanLambdaEvent object to serialize to json and
                                   send to the lambda

    Returns:
        AccountScanResult

    Raises:
        Exception if there was an error invoking the lambda.
    """
    logger = Logger()
    account_id = account_scan_lambda_event.account_scan_plan.account_id
    with logger.bind(lambda_name=lambda_name, lambda_timeout=lambda_timeout, account_id=account_id):
        logger.info(event=AWSLogEvents.RunAccountScanLambdaStart)
        boto_config = botocore.config.Config(
            read_timeout=lambda_timeout + 10,
            retries={"max_attempts": 0},
        )
        session = boto3.Session()
        lambda_client = session.client("lambda", config=boto_config)
        try:
            resp = lambda_client.invoke(
                FunctionName=lambda_name, Payload=account_scan_lambda_event.json().encode("utf-8")
            )
        except Exception as invoke_ex:
            error = str(invoke_ex)
            logger.info(event=AWSLogEvents.RunAccountScanLambdaError, error=error)
            raise Exception(
                f"Error while invoking {lambda_name} "
                f"with event {account_scan_lambda_event.json()}: {error}"
            ) from invoke_ex
        payload: bytes = resp["Payload"].read()
        if resp.get("FunctionError", None):
            function_error = payload.decode()
            logger.info(event=AWSLogEvents.RunAccountScanLambdaError, error=function_error)
            raise Exception(
                f"Function error in {lambda_name} "
                f"with event {account_scan_lambda_event.json()}: {function_error}"
            )
        payload_dict = json.loads(payload)
        account_scan_result = AccountScanResult(**payload_dict)
        logger.info(event=AWSLogEvents.RunAccountScanLambdaEnd)
        return account_scan_result
def scan(self, account_scan_plans: List[AccountScanPlan]) -> List[AccountScanManifest]:
    """Scan accounts. Return a list of AccountScanManifest objects.

    Args:
        account_scan_plans: list of AccountScanPlan objects defining this scan op

    Returns:
        list of AccountScanManifest objects describing the output of the scan.
    """
    account_scan_results: List[AccountScanManifest] = []
    num_total_accounts = len(account_scan_plans)
    num_threads = min(num_total_accounts, self.max_threads)
    logger = Logger()
    with logger.bind(
        num_total_accounts=num_total_accounts,
        muxer=self.__class__.__name__,
        num_threads=num_threads,
    ):
        logger.info(event=AWSLogEvents.MuxerStart)
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            processed_accounts = 0
            futures = []
            for account_scan_plan in account_scan_plans:
                account_scan_future = self._schedule_account_scan(executor, account_scan_plan)
                futures.append(account_scan_future)
                logger.info(
                    event=AWSLogEvents.MuxerQueueScan, account_id=account_scan_plan.account_id
                )
            for future in as_completed(futures):
                scan_results_dict = future.result()
                account_id = scan_results_dict["account_id"]
                output_artifact = scan_results_dict["output_artifact"]
                account_errors = scan_results_dict["errors"]
                api_call_stats = scan_results_dict["api_call_stats"]
                artifacts = [output_artifact] if output_artifact else []
                account_scan_result = AccountScanManifest(
                    account_id=account_id,
                    artifacts=artifacts,
                    errors=account_errors,
                    api_call_stats=api_call_stats,
                )
                account_scan_results.append(account_scan_result)
                processed_accounts += 1
                logger.info(event=AWSLogEvents.MuxerStat, num_scanned=processed_accounts)
        logger.info(event=AWSLogEvents.MuxerEnd)
    return account_scan_results
def scan(self) -> List[Resource]:
    """Perform a scan on all of the resource classes in this GraphSpec and return
    a list of Resource objects.

    Returns:
        List of Resource objects
    """
    resources: List[Resource] = []
    logger = Logger()
    for resource_spec_class in self.resource_spec_classes:
        with logger.bind(resource_type=str(resource_spec_class.type_name)):
            logger.debug(event=LogEvent.ScanResourceTypeStart)
            scanned_resources = resource_spec_class.scan(scan_accessor=self.scan_accessor)
            resources += scanned_resources
            logger.debug(event=LogEvent.ScanResourceTypeEnd)
    return resources
def get_sub_account_ids(account_ids: Tuple[str, ...], accessor: Accessor) -> Tuple[str, ...]:
    """Return the given account ids plus the ids of all active member accounts of any
    organization for which a given account is the master account."""
    logger = Logger()
    logger.info(event=AWSLogEvents.GetSubAccountsStart)
    sub_account_ids: Set[str] = set(account_ids)
    for master_account_id in account_ids:
        with logger.bind(master_account_id=master_account_id):
            account_session = accessor.get_session(master_account_id)
            orgs_client = account_session.client("organizations")
            resp = orgs_client.describe_organization()
            if resp["Organization"]["MasterAccountId"] == master_account_id:
                accounts_paginator = orgs_client.get_paginator("list_accounts")
                for accounts_resp in accounts_paginator.paginate():
                    for account_resp in accounts_resp["Accounts"]:
                        if account_resp["Status"].lower() == "active":
                            account_id = account_resp["Id"]
                            sub_account_ids.add(account_id)
    logger.info(event=AWSLogEvents.GetSubAccountsEnd)
    return tuple(sub_account_ids)
def write_json(self, name: str, data: BaseModel) -> str:
    """Write artifact data to self.output_dir/name.json

    Args:
        name: filename
        data: data

    Returns:
        Full filesystem path of artifact file
    """
    logger = Logger()
    os.makedirs(self.output_dir, exist_ok=True)
    artifact_path = os.path.join(self.output_dir, f"{name}.json")
    with logger.bind(artifact_path=artifact_path):
        logger.info(event=LogEvent.WriteToFSStart)
        with open(artifact_path, "w") as artifact_fp:
            artifact_fp.write(data.json(exclude_unset=True))
        logger.info(event=LogEvent.WriteToFSEnd)
    return artifact_path
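# Small illustration of the pydantic v1 serialization used by write_json above:
# BaseModel.json(exclude_unset=True) omits fields that were never explicitly set. The model
# below is a hypothetical example, not a type from this codebase.
from typing import Optional

from pydantic import BaseModel


class _ExampleArtifactModel(BaseModel):
    account_id: str
    note: Optional[str] = None


# Prints '{"account_id": "123456789012"}' - the unset "note" field is omitted.
print(_ExampleArtifactModel(account_id="123456789012").json(exclude_unset=True))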
def write_artifact(self, name: str, data: Dict[str, Any]) -> str:
    """Write artifact data to self.output_dir/name.json

    Args:
        name: filename
        data: artifact data

    Returns:
        Full filesystem path of artifact file
    """
    logger = Logger()
    os.makedirs(self.output_dir, exist_ok=True)
    artifact_path = os.path.join(self.output_dir, f"{name}.json")
    with logger.bind(artifact_path=artifact_path):
        logger.info(event=LogEvent.WriteToFSStart)
        with open(artifact_path, "w") as artifact_fp:
            json.dump(data, artifact_fp, default=json_encoder)
        logger.info(event=LogEvent.WriteToFSEnd)
    return artifact_path
def _invoke_lambda(
    lambda_name: str,
    lambda_timeout: int,
    result: Result,
) -> Any:
    """Invoke a QJ's remediator function"""
    logger = Logger()
    with logger.bind(lambda_name=lambda_name, lambda_timeout=lambda_timeout, result=result):
        logger.info(event=QJLogEvents.InvokeResultRemediationLambdaStart)
        boto_config = botocore.config.Config(
            read_timeout=lambda_timeout + 10,
            retries={"max_attempts": 0},
        )
        session = boto3.Session()
        lambda_client = session.client("lambda", config=boto_config)
        event = result.json().encode("utf-8")
        try:
            resp = lambda_client.invoke(
                FunctionName=lambda_name,
                Payload=event,
            )
        except Exception as invoke_ex:
            error = str(invoke_ex)
            logger.info(event=QJLogEvents.InvokeResultRemediationLambdaError, error=error)
            raise Exception(
                f"Error while invoking {lambda_name} with event: {str(event)}: {error}"
            ) from invoke_ex
        lambda_result: bytes = resp["Payload"].read()
        if resp.get("FunctionError", None):
            error = lambda_result.decode()
            logger.info(event=QJLogEvents.ResultRemediationLambdaRunError, error=error)
            raise Exception(f"Function error in {lambda_name} with event {str(event)}: {error}")
        logger.info(event=QJLogEvents.InvokeResultRemediationLambdaEnd)
        return json.loads(lambda_result)
def write_json(self, name: str, data: BaseModel) -> str:
    """Write artifact data to s3://self.bucket/self.key_prefix/name.json

    Args:
        name: s3 key name
        data: data

    Returns:
        S3 uri (s3://bucket/key/path) to artifact
    """
    output_key = "/".join((self.key_prefix, f"{name}.json"))
    logger = Logger()
    with logger.bind(bucket=self.bucket, key=output_key):
        logger.info(event=LogEvent.WriteToS3Start)
        s3_client = boto3.Session().client("s3")
        results_str = data.json(exclude_unset=True)
        results_bytes = results_str.encode("utf-8")
        with io.BytesIO(results_bytes) as results_bytes_stream:
            s3_client.upload_fileobj(results_bytes_stream, self.bucket, output_key)
        logger.info(event=LogEvent.WriteToS3End)
    return f"s3://{self.bucket}/{output_key}"
def write_artifact(self, name: str, data: Dict[str, Any]) -> str:
    """Write artifact data to s3://self.bucket/self.key_prefix/name.json

    Args:
        name: s3 key name
        data: artifact data

    Returns:
        S3 uri (s3://bucket/key/path) to artifact
    """
    output_key = "/".join((self.key_prefix, f"{name}.json"))
    logger = Logger()
    with logger.bind(bucket=self.bucket, key=output_key):
        logger.info(event=LogEvent.WriteToS3Start)
        s3_client = boto3.client("s3")
        results_str = json.dumps(data, default=json_encoder)
        results_bytes = results_str.encode("utf-8")
        with io.BytesIO(results_bytes) as results_bytes_stream:
            s3_client.upload_fileobj(results_bytes_stream, self.bucket, output_key)
        logger.info(event=LogEvent.WriteToS3End)
    return f"s3://{self.bucket}/{output_key}"
def read_artifact(self, artifact_path: str) -> Dict[str, Any]:
    """Read an artifact

    Args:
        artifact_path: s3 uri to artifact. s3://bucket/key/path

    Returns:
        artifact content
    """
    bucket, key = parse_s3_uri(artifact_path)
    session = boto3.Session()
    s3_client = session.client("s3")
    logger = Logger()
    with io.BytesIO() as artifact_bytes_buf:
        with logger.bind(bucket=bucket, key=key):
            logger.info(event=LogEvent.ReadFromS3Start)
            s3_client.download_fileobj(bucket, key, artifact_bytes_buf)
            artifact_bytes_buf.flush()
            artifact_bytes_buf.seek(0)
            artifact_bytes = artifact_bytes_buf.read()
            logger.info(event=LogEvent.ReadFromS3End)
            artifact_str = artifact_bytes.decode("utf-8")
            artifact_dict = json.loads(artifact_str)
            return artifact_dict
def enqueue_queries(
    jobs: List[schemas.Job], queue_url: str, execution_hash: str, region: str
) -> None:
    """Enqueue queries by sending a message for each job key to queue_url"""
    sqs_client = boto3.client("sqs", region_name=region)
    logger = Logger()
    with logger.bind(queue_url=queue_url, execution_hash=execution_hash):
        for job in jobs:
            job_hash = hashlib.sha256()
            job_hash.update(json.dumps(job.json()).encode())
            message_group_id = job_hash.hexdigest()
            job_hash.update(execution_hash.encode())
            message_dedupe_id = job_hash.hexdigest()
            logger.info(
                QJLogEvents.ScheduleJob,
                job=job,
                message_group_id=message_group_id,
                message_dedupe_id=message_dedupe_id,
            )
            sqs_client.send_message(
                QueueUrl=queue_url,
                MessageBody=job.json(),
                MessageGroupId=message_group_id,
                MessageDeduplicationId=message_dedupe_id,
            )
def scan(self) -> Dict[str, Any]:
    """Scan an account and return a dict containing keys:

        * account_id: str
        * output_artifact: str
        * api_call_stats: Dict[str, Any]
        * errors: List[str]

    If errors is non-empty the results are incomplete for this account.
    output_artifact is a pointer to the actual scan data - either on the local fs or in s3.

    To scan an account we create a set of GraphSpecs, one for each region. Any ACCOUNT
    level granularity resources are only scanned in a single region (e.g. IAM Users)

    Returns:
        Dict of scan result, see above for details.
    """
    logger = Logger()
    with logger.bind(account_id=self.account_id):
        logger.info(event=AWSLogEvents.ScanAWSAccountStart)
        output_artifact = None
        stats = MultilevelCounter()
        errors: List[str] = []
        now = int(time.time())
        try:
            account_graph_set = GraphSet(
                name=self.graph_name,
                version=self.graph_version,
                start_time=now,
                end_time=now,
                resources=[],
                errors=[],
                stats=stats,
            )
            # sanity check
            session = self.get_session()
            sts_client = session.client("sts")
            sts_account_id = sts_client.get_caller_identity()["Account"]
            if sts_account_id != self.account_id:
                raise ValueError(
                    f"BUG: sts detected account_id {sts_account_id} != {self.account_id}"
                )
            if self.regions:
                scan_regions = tuple(self.regions)
            else:
                scan_regions = get_all_enabled_regions(session=session)
            # build graph specs.
            # build a dict of regions -> services -> List[AWSResourceSpec]
            regions_services_resource_spec_classes: DefaultDict[
                str, DefaultDict[str, List[Type[AWSResourceSpec]]]
            ] = defaultdict(lambda: defaultdict(list))
            resource_spec_class: Type[AWSResourceSpec]
            for resource_spec_class in self.resource_spec_classes:
                client_name = resource_spec_class.get_client_name()
                resource_class_scan_granularity = resource_spec_class.scan_granularity
                if resource_class_scan_granularity == ScanGranularity.ACCOUNT:
                    regions_services_resource_spec_classes[scan_regions[0]][client_name].append(
                        resource_spec_class
                    )
                elif resource_class_scan_granularity == ScanGranularity.REGION:
                    for region in scan_regions:
                        regions_services_resource_spec_classes[region][client_name].append(
                            resource_spec_class
                        )
                else:
                    raise NotImplementedError(
                        f"ScanGranularity {resource_class_scan_granularity} not implemented"
                    )
            with ThreadPoolExecutor(max_workers=self.max_svc_threads) as executor:
                futures = []
                for (
                    region,
                    services_resource_spec_classes,
                ) in regions_services_resource_spec_classes.items():
                    for (
                        service,
                        resource_spec_classes,
                    ) in services_resource_spec_classes.items():
                        region_session = self.get_session(region=region)
                        region_creds = region_session.get_credentials()
                        scan_future = schedule_scan_services(
                            executor=executor,
                            graph_name=self.graph_name,
                            graph_version=self.graph_version,
                            account_id=self.account_id,
                            region=region,
                            service=service,
                            access_key=region_creds.access_key,
                            secret_key=region_creds.secret_key,
                            token=region_creds.token,
                            resource_spec_classes=tuple(resource_spec_classes),
                        )
                        futures.append(scan_future)
                for future in as_completed(futures):
                    graph_set_dict = future.result()
                    graph_set = GraphSet.from_dict(graph_set_dict)
                    errors += graph_set.errors
                    account_graph_set.merge(graph_set)
            account_graph_set.validate()
        except Exception as ex:
            error_str = str(ex)
            trace_back = traceback.format_exc()
            logger.error(
                event=AWSLogEvents.ScanAWSAccountError, error=error_str, trace_back=trace_back
            )
            errors.append(" : ".join((error_str, trace_back)))
            unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                account_id=self.account_id, errors=errors
            )
            account_graph_set = GraphSet(
                name=self.graph_name,
                version=self.graph_version,
                start_time=now,
                end_time=now,
                resources=[unscanned_account_resource],
                errors=errors,
                stats=stats,
            )
            account_graph_set.validate()
        output_artifact = self.artifact_writer.write_artifact(
            name=self.account_id, data=account_graph_set.to_dict()
        )
        logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
        api_call_stats = account_graph_set.stats.to_dict()
        return {
            "account_id": self.account_id,
            "output_artifact": output_artifact,
            "errors": errors,
            "api_call_stats": api_call_stats,
        }
def scan(self) -> List[Dict[str, Any]]:
    """Scan all accounts in this account scan plan and return a list of per-account
    scan result dicts."""
    logger = Logger()
    scan_result_dicts = []
    now = int(time.time())
    prescan_account_ids_errors: DefaultDict[str, List[str]] = defaultdict(list)
    futures = []
    with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
        shuffled_account_ids = random.sample(
            self.account_scan_plan.account_ids, k=len(self.account_scan_plan.account_ids)
        )
        for account_id in shuffled_account_ids:
            with logger.bind(account_id=account_id):
                logger.info(event=AWSLogEvents.ScanAWSAccountStart)
                try:
                    session = self.account_scan_plan.accessor.get_session(account_id=account_id)
                    # sanity check
                    sts_client = session.client("sts")
                    sts_account_id = sts_client.get_caller_identity()["Account"]
                    if sts_account_id != account_id:
                        raise ValueError(
                            f"BUG: sts detected account_id {sts_account_id} != {account_id}"
                        )
                    if self.account_scan_plan.regions:
                        scan_regions = tuple(self.account_scan_plan.regions)
                    else:
                        scan_regions = get_all_enabled_regions(session=session)
                    account_gran_scan_region = random.choice(self.preferred_account_scan_regions)
                    # build a dict of regions -> services -> List[AWSResourceSpec]
                    regions_services_resource_spec_classes: DefaultDict[
                        str, DefaultDict[str, List[Type[AWSResourceSpec]]]
                    ] = defaultdict(lambda: defaultdict(list))
                    resource_spec_class: Type[AWSResourceSpec]
                    for resource_spec_class in self.resource_spec_classes:
                        client_name = resource_spec_class.get_client_name()
                        if resource_spec_class.scan_granularity == ScanGranularity.ACCOUNT:
                            if resource_spec_class.region_whitelist:
                                account_resource_scan_region = resource_spec_class.region_whitelist[0]
                            else:
                                account_resource_scan_region = account_gran_scan_region
                            regions_services_resource_spec_classes[account_resource_scan_region][
                                client_name
                            ].append(resource_spec_class)
                        elif resource_spec_class.scan_granularity == ScanGranularity.REGION:
                            if resource_spec_class.region_whitelist:
                                resource_scan_regions = tuple(
                                    region
                                    for region in scan_regions
                                    if region in resource_spec_class.region_whitelist
                                )
                                if not resource_scan_regions:
                                    resource_scan_regions = resource_spec_class.region_whitelist
                            else:
                                resource_scan_regions = scan_regions
                            for region in resource_scan_regions:
                                regions_services_resource_spec_classes[region][client_name].append(
                                    resource_spec_class
                                )
                        else:
                            raise NotImplementedError(
                                f"ScanGranularity {resource_spec_class.scan_granularity} unimplemented"
                            )
                    # Build and submit ScanUnits
                    shuffled_regions_services_resource_spec_classes = random.sample(
                        regions_services_resource_spec_classes.items(),
                        len(regions_services_resource_spec_classes),
                    )
                    for (
                        region,
                        services_resource_spec_classes,
                    ) in shuffled_regions_services_resource_spec_classes:
                        region_session = self.account_scan_plan.accessor.get_session(
                            account_id=account_id, region_name=region
                        )
                        region_creds = region_session.get_credentials()
                        shuffled_services_resource_spec_classes = random.sample(
                            services_resource_spec_classes.items(),
                            len(services_resource_spec_classes),
                        )
                        for (
                            service,
                            svc_resource_spec_classes,
                        ) in shuffled_services_resource_spec_classes:
                            future = schedule_scan(
                                executor=executor,
                                graph_name=self.graph_name,
                                graph_version=self.graph_version,
                                account_id=account_id,
                                region_name=region,
                                service=service,
                                access_key=region_creds.access_key,
                                secret_key=region_creds.secret_key,
                                token=region_creds.token,
                                resource_spec_classes=tuple(svc_resource_spec_classes),
                            )
                            futures.append(future)
                except Exception as ex:
                    error_str = str(ex)
                    trace_back = traceback.format_exc()
                    logger.error(
                        event=AWSLogEvents.ScanAWSAccountError,
                        error=error_str,
                        trace_back=trace_back,
                    )
                    prescan_account_ids_errors[account_id].append(f"{error_str}\n{trace_back}")
        account_ids_graph_set_dicts: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
        for future in as_completed(futures):
            account_id, graph_set_dict = future.result()
            account_ids_graph_set_dicts[account_id].append(graph_set_dict)
    # first make sure no account id appears both in account_ids_graph_set_dicts
    # and prescan_account_ids_errors - this should never happen
    doubled_accounts = set(account_ids_graph_set_dicts.keys()).intersection(
        set(prescan_account_ids_errors.keys())
    )
    if doubled_accounts:
        raise Exception(
            f"BUG: Account(s) {doubled_accounts} in both "
            "account_ids_graph_set_dicts and prescan_account_ids_errors."
        )
    # graph prescan error accounts
    for account_id, errors in prescan_account_ids_errors.items():
        with logger.bind(account_id=account_id):
            unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                account_id=account_id, errors=errors
            )
            account_graph_set = GraphSet(
                name=self.graph_name,
                version=self.graph_version,
                start_time=now,
                end_time=now,
                resources=[unscanned_account_resource],
                errors=errors,
                stats=MultilevelCounter(),
            )
            account_graph_set.validate()
            output_artifact = self.artifact_writer.write_json(
                name=account_id, data=account_graph_set.to_dict()
            )
            logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
            api_call_stats = account_graph_set.stats.to_dict()
            scan_result_dicts.append(
                {
                    "account_id": account_id,
                    "output_artifact": output_artifact,
                    "errors": errors,
                    "api_call_stats": api_call_stats,
                }
            )
    # graph rest
    for account_id, graph_set_dicts in account_ids_graph_set_dicts.items():
        with logger.bind(account_id=account_id):
            # if there are any errors whatsoever we generate an empty graph with
            # errors only
            errors = []
            for graph_set_dict in graph_set_dicts:
                errors += graph_set_dict["errors"]
            if errors:
                unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                    account_id=account_id, errors=errors
                )
                account_graph_set = GraphSet(
                    name=self.graph_name,
                    version=self.graph_version,
                    start_time=now,
                    end_time=now,
                    resources=[unscanned_account_resource],
                    errors=errors,
                    stats=MultilevelCounter(),  # ENHANCEMENT: could technically get partial stats.
                )
                account_graph_set.validate()
            else:
                account_graph_set = GraphSet(
                    name=self.graph_name,
                    version=self.graph_version,
                    start_time=now,
                    end_time=now,
                    resources=[],
                    errors=[],
                    stats=MultilevelCounter(),
                )
                for graph_set_dict in graph_set_dicts:
                    graph_set = GraphSet.from_dict(graph_set_dict)
                    account_graph_set.merge(graph_set)
            output_artifact = self.artifact_writer.write_json(
                name=account_id, data=account_graph_set.to_dict()
            )
            logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
            api_call_stats = account_graph_set.stats.to_dict()
            scan_result_dicts.append(
                {
                    "account_id": account_id,
                    "output_artifact": output_artifact,
                    "errors": errors,
                    "api_call_stats": api_call_stats,
                }
            )
    return scan_result_dicts
def scan(self) -> AccountScanResult:
    """Scan a single account and return an AccountScanResult."""
    logger = Logger()
    now = int(time.time())
    prescan_errors: List[str] = []
    futures: List[Future] = []
    account_id = self.account_scan_plan.account_id
    with logger.bind(account_id=account_id):
        with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
            logger.info(event=AWSLogEvents.ScanAWSAccountStart)
            try:
                session = self.account_scan_plan.accessor.get_session(account_id=account_id)
                # sanity check
                sts_client = session.client("sts")
                sts_account_id = sts_client.get_caller_identity()["Account"]
                if sts_account_id != account_id:
                    raise ValueError(
                        f"BUG: sts detected account_id {sts_account_id} != {account_id}"
                    )
                if self.account_scan_plan.regions:
                    account_scan_regions = tuple(self.account_scan_plan.regions)
                else:
                    account_scan_regions = get_all_enabled_regions(session=session)
                # build a dict of regions -> services -> List[AWSResourceSpec]
                regions_services_resource_spec_classes: DefaultDict[
                    str, DefaultDict[str, List[Type[AWSResourceSpec]]]
                ] = defaultdict(lambda: defaultdict(list))
                for resource_spec_class in self.resource_spec_classes:
                    resource_regions = self.account_scan_plan.aws_resource_region_mapping_repo.get_regions(
                        resource_spec_class=resource_spec_class,
                        region_whitelist=account_scan_regions,
                    )
                    for region in resource_regions:
                        regions_services_resource_spec_classes[region][
                            resource_spec_class.service_name
                        ].append(resource_spec_class)
                # Build and submit ScanUnits
                shuffled_regions_services_resource_spec_classes = random.sample(
                    regions_services_resource_spec_classes.items(),
                    len(regions_services_resource_spec_classes),
                )
                for (
                    region,
                    services_resource_spec_classes,
                ) in shuffled_regions_services_resource_spec_classes:
                    region_session = self.account_scan_plan.accessor.get_session(
                        account_id=account_id, region_name=region
                    )
                    region_creds = region_session.get_credentials()
                    shuffled_services_resource_spec_classes = random.sample(
                        services_resource_spec_classes.items(),
                        len(services_resource_spec_classes),
                    )
                    for (
                        service,
                        svc_resource_spec_classes,
                    ) in shuffled_services_resource_spec_classes:
                        parallel_svc_resource_spec_classes = [
                            svc_resource_spec_class
                            for svc_resource_spec_class in svc_resource_spec_classes
                            if svc_resource_spec_class.parallel_scan
                        ]
                        serial_svc_resource_spec_classes = [
                            svc_resource_spec_class
                            for svc_resource_spec_class in svc_resource_spec_classes
                            if not svc_resource_spec_class.parallel_scan
                        ]
                        for parallel_svc_resource_spec_class in parallel_svc_resource_spec_classes:
                            parallel_future = schedule_scan(
                                executor=executor,
                                graph_name=self.graph_name,
                                graph_version=self.graph_version,
                                account_id=account_id,
                                region_name=region,
                                service=service,
                                access_key=region_creds.access_key,
                                secret_key=region_creds.secret_key,
                                token=region_creds.token,
                                resource_spec_classes=(parallel_svc_resource_spec_class,),
                            )
                            futures.append(parallel_future)
                        serial_future = schedule_scan(
                            executor=executor,
                            graph_name=self.graph_name,
                            graph_version=self.graph_version,
                            account_id=account_id,
                            region_name=region,
                            service=service,
                            access_key=region_creds.access_key,
                            secret_key=region_creds.secret_key,
                            token=region_creds.token,
                            resource_spec_classes=tuple(serial_svc_resource_spec_classes),
                        )
                        futures.append(serial_future)
            except Exception as ex:
                error_str = str(ex)
                trace_back = traceback.format_exc()
                logger.error(
                    event=AWSLogEvents.ScanAWSAccountError,
                    error=error_str,
                    trace_back=trace_back,
                )
                prescan_errors.append(f"{error_str}\n{trace_back}")
            graph_sets: List[GraphSet] = []
            for future in as_completed(futures):
                graph_set = future.result()
                graph_sets.append(graph_set)
        # if there was a prescan error graph it and return the result
        if prescan_errors:
            unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                account_id=account_id, errors=prescan_errors
            )
            account_graph_set = GraphSet(
                name=self.graph_name,
                version=self.graph_version,
                start_time=now,
                end_time=now,
                resources=[unscanned_account_resource],
                errors=prescan_errors,
            )
            output_artifact = self.artifact_writer.write_json(
                name=account_id,
                data=account_graph_set,
            )
            logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
            return AccountScanResult(
                account_id=account_id,
                artifacts=[output_artifact],
                errors=prescan_errors,
            )
        # if there are any errors whatsoever we generate an empty graph with errors only
        errors = []
        for graph_set in graph_sets:
            errors += graph_set.errors
        if errors:
            unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                account_id=account_id, errors=errors
            )
            account_graph_set = GraphSet(
                name=self.graph_name,
                version=self.graph_version,
                start_time=now,
                end_time=now,
                resources=[unscanned_account_resource],
                errors=errors,
            )
        else:
            account_graph_set = GraphSet.from_graph_sets(graph_sets)
        output_artifact = self.artifact_writer.write_json(
            name=account_id,
            data=account_graph_set,
        )
        logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
        return AccountScanResult(
            account_id=account_id,
            artifacts=[output_artifact],
            errors=errors,
        )
def load_graph(self, bucket: str, key: str, load_iam_role_arn: str) -> GraphMetadata:
    """Load a graph into Neptune.

    Args:
        bucket: s3 bucket of graph rdf
        key: s3 key of graph rdf
        load_iam_role_arn: arn of iam role used to load the graph

    Returns:
        GraphMetadata object describing loaded graph

    Raises:
        NeptuneLoadGraphException if errors occur during graph load
    """
    session = boto3.Session(region_name=self._neptune_endpoint.region)
    s3_client = session.client("s3")
    rdf_object_tagging = s3_client.get_object_tagging(Bucket=bucket, Key=key)
    tag_set = rdf_object_tagging["TagSet"]
    graph_name = get_required_tag_value(tag_set, "name")
    graph_version = get_required_tag_value(tag_set, "version")
    graph_start_time = int(get_required_tag_value(tag_set, "start_time"))
    graph_end_time = int(get_required_tag_value(tag_set, "end_time"))
    graph_metadata = GraphMetadata(
        uri=f"{GRAPH_BASE_URI}/{graph_name}/{graph_version}/{graph_end_time}",
        name=graph_name,
        version=graph_version,
        start_time=graph_start_time,
        end_time=graph_end_time,
    )
    logger = Logger()
    with logger.bind(
        rdf_bucket=bucket,
        rdf_key=key,
        graph_uri=graph_metadata.uri,
        neptune_endpoint=self._neptune_endpoint.get_endpoint_str(),
    ):
        session = boto3.Session(region_name=self._neptune_endpoint.region)
        credentials = session.get_credentials()
        auth = AWSRequestsAuth(
            aws_access_key=credentials.access_key,
            aws_secret_access_key=credentials.secret_key,
            aws_token=credentials.token,
            aws_host=self._neptune_endpoint.get_endpoint_str(),
            aws_region=self._neptune_endpoint.region,
            aws_service="neptune-db",
        )
        post_body = {
            "source": f"s3://{bucket}/{key}",
            "format": "rdfxml",
            "iamRoleArn": load_iam_role_arn,
            "region": self._neptune_endpoint.region,
            "failOnError": "TRUE",
            "parallelism": "MEDIUM",
            "parserConfiguration": {
                "baseUri": GRAPH_BASE_URI,
                "namedGraphUri": graph_metadata.uri,
            },
        }
        logger.info(event=LogEvent.NeptuneLoadStart, post_body=post_body)
        submit_resp = requests.post(
            self._neptune_endpoint.get_loader_endpoint(), json=post_body, auth=auth
        )
        if submit_resp.status_code != 200:
            raise NeptuneLoadGraphException(
                f"Non 200 from Neptune: {submit_resp.status_code} : {submit_resp.text}"
            )
        submit_resp_json = submit_resp.json()
        load_id = submit_resp_json["payload"]["loadId"]
        with logger.bind(load_id=load_id):
            logger.info(event=LogEvent.NeptuneLoadPolling)
            while True:
                time.sleep(10)
                status_resp = requests.get(
                    f"{self._neptune_endpoint.get_loader_endpoint()}/{load_id}",
                    params={"details": "true", "errors": "true"},
                    auth=auth,
                )
                if status_resp.status_code != 200:
                    raise NeptuneLoadGraphException(
                        f"Non 200 from Neptune: {status_resp.status_code} : {status_resp.text}"
                    )
                status_resp_json = status_resp.json()
                status = status_resp_json["payload"]["overallStatus"]["status"]
                logger.info(event=LogEvent.NeptuneLoadPolling, status=status)
                if status == "LOAD_COMPLETED":
                    break
                if status not in ("LOAD_NOT_STARTED", "LOAD_IN_PROGRESS"):
                    logger.error(event=LogEvent.NeptuneLoadError, status=status)
                    raise NeptuneLoadGraphException(f"Error loading graph: {status_resp_json}")
            logger.info(event=LogEvent.NeptuneLoadEnd)
            logger.info(event=LogEvent.MetadataGraphUpdateStart)
            self._register_graph(graph_metadata=graph_metadata)
            logger.info(event=LogEvent.MetadataGraphUpdateEnd)
    return graph_metadata
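# get_required_tag_value is used by load_graph above but is not shown here. A minimal sketch
# of the behavior the caller relies on (look up a tag by Key in an S3 object TagSet, failing
# loudly if it is absent) might look like this; the real helper in this codebase may differ,
# so the name is suffixed with _sketch.
from typing import Any, Dict, List


def get_required_tag_value_sketch(tag_set: List[Dict[str, Any]], key: str) -> str:
    """Return the Value of the tag with the given Key, raising if no such tag exists."""
    for tag in tag_set:
        if tag.get("Key") == key:
            return str(tag["Value"])
    raise ValueError(f"Required tag {key} not found in tag set {tag_set}")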
def write_graph_set(
    self, name: str, graph_set: GraphSet, compression: Optional[str] = None
) -> str:
    """Write a graph artifact

    Args:
        name: name
        graph_set: GraphSet to write
        compression: optional compression type - GZIP or None

    Returns:
        path to written artifact
    """
    logger = Logger()
    if compression is None:
        key = f"{name}.rdf"
    elif compression == GZIP:
        key = f"{name}.rdf.gz"
    else:
        raise ValueError(f"Unknown compression arg {compression}")
    output_key = "/".join((self.key_prefix, key))
    graph = graph_set.to_rdf()
    with logger.bind(bucket=self.bucket, key_prefix=self.key_prefix, key=key):
        logger.info(event=LogEvent.WriteToS3Start)
        with io.BytesIO() as rdf_bytes_buf:
            if compression is None:
                graph.serialize(rdf_bytes_buf)
            elif compression == GZIP:
                with gzip.GzipFile(fileobj=rdf_bytes_buf, mode="wb") as gz:
                    graph.serialize(gz)
            else:
                raise ValueError(f"Unknown compression arg {compression}")
            rdf_bytes_buf.flush()
            rdf_bytes_buf.seek(0)
            session = boto3.Session()
            s3_client = session.client("s3")
            s3_client.upload_fileobj(rdf_bytes_buf, self.bucket, output_key)
        s3_client.put_object_tagging(
            Bucket=self.bucket,
            Key=output_key,
            Tagging={
                "TagSet": [
                    {"Key": "name", "Value": graph_set.name},
                    {"Key": "version", "Value": graph_set.version},
                    {"Key": "start_time", "Value": str(graph_set.start_time)},
                    {"Key": "end_time", "Value": str(graph_set.end_time)},
                ]
            },
        )
        logger.info(event=LogEvent.WriteToS3End)
    return f"s3://{self.bucket}/{output_key}"