def read_json(self, path: str) -> Dict[str, Any]:
    """Read a json artifact

    Args:
        path: filesystem path to artifact

    Returns:
        artifact content
    """
    logger = Logger()
    with logger.bind(artifact_path=path):
        logger.info(event=LogEvent.ReadFromFSStart)
        with open(path, "r") as artifact_fp:
            data = json.load(artifact_fp)
        logger.info(event=LogEvent.ReadFromFSEnd)
        return data
def get_version(self, db_session: Session, job_name: str, created: datetime) -> Job:
    """Get a specific version of a Job by created timestamp"""
    logger = Logger()
    logger.info(event=QJLogEvents.GetJobVersion, job_name=job_name)
    query = db_session.query(Job).filter(Job.name == job_name).filter(Job.created == created)
    results = query.all()
    num_results = len(results)
    if num_results:
        if num_results > 1:
            raise Exception(f"More than one job found for {job_name} with version {created}")
        return results[0]
    raise JobVersionNotFound(f"Could not find job {job_name} / {created}")
def scan(
    self, account_scan_plan: AccountScanPlan
) -> Generator[AccountScanManifest, None, None]:
    """Scan accounts, yielding AccountScanManifest objects.

    Args:
        account_scan_plan: AccountScanPlan defining this scan op

    Yields:
        AccountScanManifest objects
    """
    num_total_accounts = len(account_scan_plan.account_ids)
    account_scan_plans = account_scan_plan.to_batches(
        max_accounts=self.config.concurrency.max_accounts_per_thread
    )
    num_account_batches = len(account_scan_plans)
    num_threads = min(num_account_batches, self.config.concurrency.max_account_scan_threads)
    logger = Logger()
    with logger.bind(
        num_total_accounts=num_total_accounts,
        num_account_batches=num_account_batches,
        muxer=self.__class__.__name__,
        num_muxer_threads=num_threads,
    ):
        logger.info(event=AWSLogEvents.MuxerStart)
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            processed_accounts = 0
            futures = []
            for sub_account_scan_plan in account_scan_plans:
                account_scan_future = self._schedule_account_scan(executor, sub_account_scan_plan)
                futures.append(account_scan_future)
                logger.info(
                    event=AWSLogEvents.MuxerQueueScan,
                    account_ids=",".join(sub_account_scan_plan.account_ids),
                )
            for future in as_completed(futures):
                scan_results_dicts = future.result()
                for scan_results_dict in scan_results_dicts:
                    account_id = scan_results_dict["account_id"]
                    output_artifact = scan_results_dict["output_artifact"]
                    account_errors = scan_results_dict["errors"]
                    api_call_stats = scan_results_dict["api_call_stats"]
                    artifacts = [output_artifact] if output_artifact else []
                    account_scan_result = AccountScanManifest(
                        account_id=account_id,
                        artifacts=artifacts,
                        errors=account_errors,
                        api_call_stats=api_call_stats,
                    )
                    yield account_scan_result
                    processed_accounts += 1
                logger.info(event=AWSLogEvents.MuxerStat, num_scanned=processed_accounts)
        logger.info(event=AWSLogEvents.MuxerEnd)
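# Illustrative usage sketch (not part of the source): drives the generator-based scan()
# above. "muxer" and "plan" are hypothetical instances of the muxer class and of
# AccountScanPlan; only the scan() signature and the AccountScanManifest fields used here
# come from the code above.
def _example_consume_scan(muxer: Any, plan: AccountScanPlan) -> List[str]:
    failed_account_ids: List[str] = []
    for manifest in muxer.scan(account_scan_plan=plan):
        # each AccountScanManifest carries account_id, artifacts, errors and api_call_stats
        if manifest.errors:
            failed_account_ids.append(manifest.account_id)
    return failed_account_ids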
def lambda_handler(event, context):
    json_bucket = event["Records"][0]["s3"]["bucket"]["name"]
    json_key = urllib.parse.unquote(event["Records"][0]["s3"]["object"]["key"])
    rdf_bucket = get_required_lambda_env_var("RDF_BUCKET")
    rdf_key = ".".join(json_key.split(".")[:-1]) + ".rdf.gz"
    session = boto3.Session()
    s3_client = session.client("s3")
    logger = Logger()
    with logger.bind(json_bucket=json_bucket, json_key=json_key):
        graph_pkg = graph_pkg_from_s3(
            s3_client=s3_client, json_bucket=json_bucket, json_key=json_key
        )
    with logger.bind(rdf_bucket=rdf_bucket, rdf_key=rdf_key):
        logger.info(event=LogEvent.WriteToS3Start)
        with io.BytesIO() as rdf_bytes_buf:
            with gzip.GzipFile(fileobj=rdf_bytes_buf, mode="wb") as gz:
                graph_pkg.graph.serialize(gz)
            rdf_bytes_buf.flush()
            rdf_bytes_buf.seek(0)
            s3_client.upload_fileobj(rdf_bytes_buf, rdf_bucket, rdf_key)
        s3_client.put_object_tagging(
            Bucket=rdf_bucket,
            Key=rdf_key,
            Tagging={
                "TagSet": [
                    {"Key": "name", "Value": graph_pkg.name},
                    {"Key": "version", "Value": graph_pkg.version},
                    {"Key": "start_time", "Value": str(graph_pkg.start_time)},
                    {"Key": "end_time", "Value": str(graph_pkg.end_time)},
                ]
            },
        )
        logger.info(event=LogEvent.WriteToS3End)
def create(self, db_session: Session, obj_in: schemas.ResultSetCreate) -> ResultSet:
    """Create a ResultSet"""
    logger = Logger()
    num_results = len(obj_in.results)
    logger.info(
        event=QJLogEvents.CreateResultSet,
        job=obj_in.job,
        created=obj_in.created,
        graph_spec=obj_in.graph_spec,
        num_results=num_results,
    )
    job = self._job_crud.get_version(
        db_session=db_session, job_name=obj_in.job.name, created=obj_in.job.created
    )
    if not job:
        raise JobVersionNotFound(f"Could not find job {obj_in.job.name} / {obj_in.job.created}")
    # create result_set db object
    if num_results > self._max_result_set_results:
        raise ResultSetResultsLimitExceeded(
            f"Result set has {num_results} results, limit is {self._max_result_set_results}"
        )
    result_set = ResultSet(
        job=job, created=obj_in.created, graph_spec=json.loads(obj_in.graph_spec.json())
    )
    db_session.add(result_set)
    # create result db objects
    for obj_in_result in obj_in.results:
        result_size = len(json.dumps(obj_in_result.result))
        if result_size > self._max_result_size_bytes:
            raise ResultSizeExceeded(
                f"Result size {result_size} exceeds max {self._max_result_size_bytes}: "
                f"{json.dumps(obj_in_result.result)[:self._max_result_size_bytes]}..."
            )
        result = Result(
            result_set=result_set,
            account_id=obj_in_result.account_id,
            result=obj_in_result.result,
        )
        db_session.add(result)
    db_session.commit()
    db_session.refresh(result_set)
    return result_set
def set_automated_backups(
    cls, client: BaseClient, dbinstances: Dict[str, Dict[str, Any]]
) -> None:
    logger = Logger()
    backup_paginator = client.get_paginator("describe_db_instance_automated_backups")
    for resp in backup_paginator.paginate():
        for backup in resp.get("DBInstanceAutomatedBackups", []):
            if backup["DBInstanceArn"] in dbinstances:
                dbinstances[backup["DBInstanceArn"]]["Backup"].append(backup)
            else:
                logger.info(
                    event=AWSLogEvents.ScanAWSResourcesNonFatalError,
                    msg=(
                        f'Unable to find matching DB Instance {backup["DBInstanceArn"]} '
                        "(Possible Deletion)"
                    ),
                )
def scan_scan_unit(scan_unit: ScanUnit) -> Tuple[str, Dict[str, Any]]:
    logger = Logger()
    with logger.bind(
        account_id=scan_unit.account_id,
        region=scan_unit.region_name,
        service=scan_unit.service,
    ):
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceStart)
        session = boto3.Session(
            aws_access_key_id=scan_unit.access_key,
            aws_secret_access_key=scan_unit.secret_key,
            aws_session_token=scan_unit.token,
            region_name=scan_unit.region_name,
        )
        scan_accessor = AWSAccessor(
            session=session, account_id=scan_unit.account_id, region_name=scan_unit.region_name
        )
        graph_spec = GraphSpec(
            name=scan_unit.graph_name,
            version=scan_unit.graph_version,
            resource_spec_classes=scan_unit.resource_spec_classes,
            scan_accessor=scan_accessor,
        )
        start_time = int(time.time())
        resources: List[Resource] = []
        errors = []
        try:
            resources = graph_spec.scan()
        except Exception as ex:
            error_str = str(ex)
            trace_back = traceback.format_exc()
            logger.error(
                event=AWSLogEvents.ScanAWSAccountError, error=error_str, trace_back=trace_back
            )
            error = f"{str(ex)}\n{trace_back}"
            errors.append(error)
        end_time = int(time.time())
        graph_set = GraphSet(
            name=scan_unit.graph_name,
            version=scan_unit.graph_version,
            start_time=start_time,
            end_time=end_time,
            resources=resources,
            errors=errors,
            stats=scan_accessor.api_call_stats,
        )
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceEnd)
        return (scan_unit.account_id, graph_set.to_dict())
def get_sub_account_ids(account_ids: Tuple[str, ...], accessor: Accessor) -> Tuple[str, ...]:
    logger = Logger()
    logger.info(event=AWSLogEvents.GetSubAccountsStart)
    sub_account_ids: Set[str] = set(account_ids)
    for master_account_id in account_ids:
        with logger.bind(master_account_id=master_account_id):
            account_session = accessor.get_session(master_account_id)
            orgs_client = account_session.client("organizations")
            resp = orgs_client.describe_organization()
            if resp["Organization"]["MasterAccountId"] == master_account_id:
                accounts_paginator = orgs_client.get_paginator("list_accounts")
                for accounts_resp in accounts_paginator.paginate():
                    for account_resp in accounts_resp["Accounts"]:
                        if account_resp["Status"].lower() == "active":
                            account_id = account_resp["Id"]
                            sub_account_ids.add(account_id)
    logger.info(event=AWSLogEvents.GetSubAccountsEnd)
    return tuple(sub_account_ids)
def remediator(event: Dict[str, Any]) -> None:
    """Run the remediation lambda for a QJ result set"""
    config = RemediatorConfig()
    logger = Logger()
    remediation = Remediation(**event)
    with logger.bind(remediation=remediation):
        logger.info(event=QJLogEvents.RemediationInit)
        qj_api_client = QJAPIClient(host=config.qj_api_host)
        latest_result_set = qj_api_client.get_job_latest_result_set(job_name=remediation.job_name)
        if not latest_result_set:
            msg = f"No latest_result_set present for {remediation.job_name}"
            logger.error(QJLogEvents.StaleResultSet, detail=msg)
            raise RemediationError(msg)
        if latest_result_set.result_set_id != remediation.result_set_id:
            msg = (
                f"Remediation result_set_id {remediation.result_set_id} does not match the "
                f"latest result_set_id {latest_result_set.result_set_id}"
            )
            logger.error(QJLogEvents.StaleResultSet, detail=msg)
            raise RemediationError(msg)
        if not latest_result_set.job.remediate_sqs_queue:
            msg = f"Job {latest_result_set.job.name} has no remediator"
            logger.error(QJLogEvents.JobHasNoRemediator, detail=msg)
            raise RemediationError(msg)
        num_threads = 10  # TODO env var
        errors = []
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = []
            for result in latest_result_set.results:
                logger.info(event=QJLogEvents.ProcessResult, result=result)
                future = _schedule_result_remediation(
                    executor=executor,
                    lambda_name=latest_result_set.job.remediate_sqs_queue,
                    lambda_timeout=300,  # TODO env var?
                    result=result,
                )
                futures.append(future)
            for future in as_completed(futures):
                try:
                    lambda_result = future.result()
                    logger.info(
                        QJLogEvents.ResultRemediationSuccessful, lambda_result=lambda_result
                    )
                except Exception as ex:
                    logger.info(
                        event=QJLogEvents.ResultSetRemediationFailed,
                        error=str(ex),
                    )
                    errors.append(str(ex))
        if errors:
            logger.error(event=QJLogEvents.ResultSetRemediationFailed, errors=errors)
            raise RemediationError(
                f"Errors encountered during remediation of {latest_result_set.job.name} "
                f"/ {latest_result_set.result_set_id}: {errors}"
            )
def get(
    self,
    db_session: Session,
    result_set_id: str,
) -> ResultSet:
    """Get a ResultSet by id"""
    logger = Logger()
    logger.info(event=QJLogEvents.GetResultSet, result_set_id=result_set_id)
    query = db_session.query(ResultSet).filter(ResultSet.result_set_id == result_set_id)
    result_sets = query.all()
    num_result_sets = len(result_sets)
    if num_result_sets:
        if num_result_sets > 1:
            raise Exception(f"More than one result_set found for {result_set_id}")
        return result_sets[0]
    raise ResultSetNotFound(f"No result set {result_set_id} found")
def get_active(
    self,
    db_session: Session,
    job_name: str,
) -> Job:
    """Get the active version of a Job"""
    logger = Logger()
    query = db_session.query(Job).filter(Job.active).filter(Job.name == job_name)
    results = query.all()
    num_results = len(results)
    logger.info(event=QJLogEvents.GetActiveJob, job_name=job_name, num_results=num_results)
    if num_results:
        assert num_results == 1, f"More than one active job found for {job_name}"
        return results[0]
    raise ActiveJobVersionNotFound(f"No active job version found for {job_name}")
def write_json(self, name: str, data: BaseModel) -> str:
    """Write artifact data to self.output_dir/name.json

    Args:
        name: filename
        data: data

    Returns:
        Full filesystem path of artifact file
    """
    logger = Logger()
    os.makedirs(self.output_dir, exist_ok=True)
    artifact_path = os.path.join(self.output_dir, f"{name}.json")
    with logger.bind(artifact_path=artifact_path):
        logger.info(event=LogEvent.WriteToFSStart)
        with open(artifact_path, "w") as artifact_fp:
            artifact_fp.write(data.json(exclude_unset=True))
        logger.info(event=LogEvent.WriteToFSEnd)
    return artifact_path
def create(self, db_session: Session, job_create_in: schemas.JobCreate) -> Job:
    """Create a Job"""
    logger = Logger()
    logger.info(event=QJLogEvents.CreateJob, job_create=job_create_in)
    try:
        query = rdflib.Graph().query(job_create_in.query)
    except Exception as ex:
        raise JobQueryInvalid(f"Invalid query {job_create_in.query}: {str(ex)}") from ex
    query_fields = [str(query_var) for query_var in query.vars]
    if self._account_id_key not in query_fields:
        raise JobQueryMissingAccountId(
            f"Query {job_create_in.query} missing '{self._account_id_key}' field"
        )
    if job_create_in.result_expiration_sec is None:
        job_create_in.result_expiration_sec = self._result_expiration_sec_default
    if job_create_in.result_expiration_sec > self._result_expiration_sec_limit:
        raise JobInvalid(
            f"Field result_expiration_sec value {job_create_in.result_expiration_sec} "
            f"must be <= {self._result_expiration_sec_limit}"
        )
    if job_create_in.max_graph_age_sec is None:
        job_create_in.max_graph_age_sec = self._max_graph_age_sec_default
    else:
        if job_create_in.max_graph_age_sec > self._max_graph_age_sec_limit:
            raise JobInvalid(
                f"Field max_graph_age_sec value {job_create_in.max_graph_age_sec} must be "
                f"<= {self._max_graph_age_sec_limit}"
            )
    if job_create_in.max_result_age_sec is None:
        job_create_in.max_result_age_sec = self._max_result_age_sec_default
    else:
        if job_create_in.max_result_age_sec > self._max_result_age_sec_limit:
            raise JobInvalid(
                f"Field max_result_age_sec value {job_create_in.max_result_age_sec} must be "
                f"<= {self._max_result_age_sec_limit}"
            )
    obj_in_data = job_create_in.dict()
    obj_in_data["query_fields"] = query_fields
    job_create = schemas.Job(**obj_in_data)
    db_obj = Job(**job_create.dict())  # type: ignore
    db_session.add(db_obj)
    db_session.commit()
    db_session.refresh(db_obj)
    return db_obj
def write_artifact(self, name: str, data: Dict[str, Any]) -> str:
    """Write artifact data to self.output_dir/name.json

    Args:
        name: filename
        data: artifact data

    Returns:
        Full filesystem path of artifact file
    """
    logger = Logger()
    os.makedirs(self.output_dir, exist_ok=True)
    artifact_path = os.path.join(self.output_dir, f"{name}.json")
    with logger.bind(artifact_path=artifact_path):
        logger.info(event=LogEvent.WriteToFSStart)
        with open(artifact_path, "w") as artifact_fp:
            json.dump(data, artifact_fp, default=json_encoder)
        logger.info(event=LogEvent.WriteToFSEnd)
    return artifact_path
def update_version(
    self,
    db_session: Session,
    job_name: str,
    created: datetime,
    job_update: schemas.JobUpdate,
) -> Job:
    """Update a Job"""
    logger = Logger()
    logger.info(
        QJLogEvents.UpdateJob, job_name=job_name, created=created, job_update=job_update
    )
    job_version = self.get_version(db_session=db_session, job_name=job_name, created=created)
    if job_update.description is not None:
        job_version.description = job_update.description
    if job_update.category is not None:
        job_version.category = job_update.category
    if job_update.result_expiration_sec is not None:
        job_version.result_expiration_sec = job_update.result_expiration_sec
    if job_update.max_graph_age_sec is not None:
        job_version.max_graph_age_sec = job_update.max_graph_age_sec
    if job_update.max_result_age_sec is not None:
        job_version.max_result_age_sec = job_update.max_result_age_sec
    if job_update.active is not None:
        if job_update.active:
            query = db_session.query(Job).filter(Job.name == job_name)
            job_versions = query.all()
            for _job_version in job_versions:
                if _job_version != job_version:
                    _job_version.active = False
        job_version.active = job_update.active
        if job_version.active:
            self._create_views(db_session=db_session, job_version=job_version)
    if job_update.notify_if_results is not None:
        job_version.notify_if_results = job_update.notify_if_results
    db_session.commit()
    db_session.refresh(job_version)
    return job_version
def invoke_lambda(
    lambda_name: str, lambda_timeout: int, event: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """Invoke an AWS Lambda function

    Args:
        lambda_name: name of lambda
        lambda_timeout: timeout of the lambda. Used to tell the boto3 lambda client to wait
            at least this long for a response before timing out.
        event: event data to send to the lambda

    Returns:
        lambda response payload

    Raises:
        Exception if there was an error invoking the lambda.
    """
    logger = Logger()
    account_ids = [account_id for account_id in event["account_ids"]]
    with logger.bind(
        lambda_name=lambda_name, lambda_timeout=lambda_timeout, account_ids=account_ids
    ):
        logger.info(event=AWSLogEvents.RunAccountScanLambdaStart)
        boto_config = botocore.config.Config(
            read_timeout=lambda_timeout + 10,
            retries={"max_attempts": 0},
        )
        session = boto3.Session()
        lambda_client = session.client("lambda", config=boto_config)
        try:
            resp = lambda_client.invoke(
                FunctionName=lambda_name, Payload=json.dumps(event).encode("utf-8")
            )
        except Exception as invoke_ex:
            error = str(invoke_ex)
            logger.info(event=AWSLogEvents.RunAccountScanLambdaError, error=error)
            raise Exception(
                f"Error while invoking {lambda_name} with event {event}: {error}"
            ) from invoke_ex
        payload: bytes = resp["Payload"].read()
        if resp.get("FunctionError", None):
            function_error = payload.decode()
            logger.info(event=AWSLogEvents.RunAccountScanLambdaError, error=function_error)
            raise Exception(
                f"Function error in {lambda_name} with event {event}: {function_error}"
            )
        payload_dict = json.loads(payload)
        logger.info(event=AWSLogEvents.RunAccountScanLambdaEnd)
        return payload_dict
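# Illustrative usage sketch (not part of the source): calls invoke_lambda above with a
# minimal event. The lambda name, timeout and account id are placeholders; the only
# requirement taken from the code above is that the event carries an "account_ids" list.
def _example_invoke_account_scan_lambda() -> List[Dict[str, Any]]:
    event = {"account_ids": ["123456789012"]}
    return invoke_lambda(lambda_name="accountscan", lambda_timeout=300, event=event)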
def lambda_handler(event, context):
    rdf_bucket = event["Records"][0]["s3"]["bucket"]["name"]
    rdf_key = urllib.parse.unquote(event["Records"][0]["s3"]["object"]["key"])
    neptune_host = get_required_lambda_env_var("NEPTUNE_HOST")
    neptune_port = get_required_lambda_env_var("NEPTUNE_PORT")
    neptune_region = get_required_lambda_env_var("NEPTUNE_REGION")
    neptune_load_iam_role_arn = get_required_lambda_env_var("NEPTUNE_LOAD_IAM_ROLE_ARN")
    on_success_sns_topic_arn = get_required_lambda_env_var("ON_SUCCESS_SNS_TOPIC_ARN")
    endpoint = NeptuneEndpoint(host=neptune_host, port=neptune_port, region=neptune_region)
    neptune_client = AltimeterNeptuneClient(max_age_min=1440, neptune_endpoint=endpoint)
    graph_metadata = neptune_client.load_graph(
        bucket=rdf_bucket, key=rdf_key, load_iam_role_arn=neptune_load_iam_role_arn
    )
    logger = Logger()
    logger.info(event=LogEvent.GraphLoadedSNSNotificationStart)
    sns_client = boto3.client("sns")
    message_dict = {
        "uri": graph_metadata.uri,
        "name": graph_metadata.name,
        "version": graph_metadata.version,
        "start_time": graph_metadata.start_time,
        "end_time": graph_metadata.end_time,
        "neptune_endpoint": endpoint.get_endpoint_str(),
    }
    message_dict["default"] = json.dumps(message_dict)
    sns_client.publish(
        TopicArn=on_success_sns_topic_arn,
        MessageStructure="json",
        Message=json.dumps(message_dict),
    )
    logger.info(event=LogEvent.GraphLoadedSNSNotificationEnd)
def write_artifact(self, name: str, data: Dict[str, Any]) -> str:
    """Write artifact data to s3://self.bucket/self.key_prefix/name.json

    Args:
        name: s3 key name
        data: artifact data

    Returns:
        S3 uri (s3://bucket/key/path) to artifact
    """
    output_key = "/".join((self.key_prefix, f"{name}.json"))
    logger = Logger()
    with logger.bind(bucket=self.bucket, key=output_key):
        logger.info(event=LogEvent.WriteToS3Start)
        s3_client = boto3.client("s3")
        results_str = json.dumps(data, default=json_encoder)
        results_bytes = results_str.encode("utf-8")
        with io.BytesIO(results_bytes) as results_bytes_stream:
            s3_client.upload_fileobj(results_bytes_stream, self.bucket, output_key)
        logger.info(event=LogEvent.WriteToS3End)
    return f"s3://{self.bucket}/{output_key}"
def write_json(self, name: str, data: BaseModel) -> str:
    """Write artifact data to s3://self.bucket/self.key_prefix/name.json

    Args:
        name: s3 key name
        data: data

    Returns:
        S3 uri (s3://bucket/key/path) to artifact
    """
    output_key = "/".join((self.key_prefix, f"{name}.json"))
    logger = Logger()
    with logger.bind(bucket=self.bucket, key=output_key):
        logger.info(event=LogEvent.WriteToS3Start)
        s3_client = boto3.Session().client("s3")
        results_str = data.json(exclude_unset=True)
        results_bytes = results_str.encode("utf-8")
        with io.BytesIO(results_bytes) as results_bytes_stream:
            s3_client.upload_fileobj(results_bytes_stream, self.bucket, output_key)
        logger.info(event=LogEvent.WriteToS3End)
    return f"s3://{self.bucket}/{output_key}"
def invoke_lambda(
    lambda_name: str,
    lambda_timeout: int,
    account_scan_lambda_event: AccountScanLambdaEvent,
) -> AccountScanResult:
    """Invoke the AccountScan AWS Lambda function

    Args:
        lambda_name: name of lambda
        lambda_timeout: timeout of the lambda. Used to tell the boto3 lambda client to wait
            at least this long for a response before timing out.
        account_scan_lambda_event: AccountScanLambdaEvent object to serialize to json and
            send to the lambda

    Returns:
        AccountScanResult

    Raises:
        Exception if there was an error invoking the lambda.
    """
    logger = Logger()
    account_id = account_scan_lambda_event.account_scan_plan.account_id
    with logger.bind(lambda_name=lambda_name, lambda_timeout=lambda_timeout, account_id=account_id):
        logger.info(event=AWSLogEvents.RunAccountScanLambdaStart)
        boto_config = botocore.config.Config(
            read_timeout=lambda_timeout + 10,
            retries={"max_attempts": 0},
        )
        session = boto3.Session()
        lambda_client = session.client("lambda", config=boto_config)
        try:
            resp = lambda_client.invoke(
                FunctionName=lambda_name,
                Payload=account_scan_lambda_event.json().encode("utf-8"),
            )
        except Exception as invoke_ex:
            error = str(invoke_ex)
            logger.info(event=AWSLogEvents.RunAccountScanLambdaError, error=error)
            raise Exception(
                f"Error while invoking {lambda_name} with event "
                f"{account_scan_lambda_event.json()}: {error}"
            ) from invoke_ex
        payload: bytes = resp["Payload"].read()
        if resp.get("FunctionError", None):
            function_error = payload.decode()
            logger.info(event=AWSLogEvents.RunAccountScanLambdaError, error=function_error)
            raise Exception(
                f"Function error in {lambda_name} with event "
                f"{account_scan_lambda_event.json()}: {function_error}"
            )
        payload_dict = json.loads(payload)
        account_scan_result = AccountScanResult(**payload_dict)
        logger.info(event=AWSLogEvents.RunAccountScanLambdaEnd)
        return account_scan_result
def _create_latest_view(self, db_session: Session, job_version: Job) -> None:
    """Create the _latest view for a Job"""
    logger = Logger()
    latest_view_name = self._get_latest_view_name(job_name=job_version.name)
    self._drop_view(db_session=db_session, view_name=latest_view_name)
    create_sql = (
        f"CREATE VIEW {latest_view_name} AS\n"
        f"SELECT result_created, {', '.join(job_version.query_fields)}\n"
        f"FROM\n"
        "(\n"
        f" SELECT\n"
        " rs.created as result_created,\n"
        f" lpad(r.account_id::text, 12, '0') as {self._account_id_key},\n"
    )
    for query_field in job_version.query_fields:
        if query_field != self._account_id_key:
            create_sql += f" result->>'{query_field}' as {query_field},\n"
    create_sql += (
        f" RANK () OVER (PARTITION BY r.account_id ORDER BY rs.created DESC) as rank_number\n"
        " FROM\n"
        " result r\n"
        " INNER JOIN result_set rs ON r.result_set_id = rs.id\n"
        " INNER JOIN job j ON rs.job_id = j.id\n"
        " WHERE\n"
        f" j.name = '{job_version.name}' AND\n"
        " j.active = true AND\n"
        f" rs.created > CURRENT_TIMESTAMP - INTERVAL '{job_version.max_result_age_sec} seconds'\n"
        ") ranked_query\n"
        "WHERE rank_number = 1\n"
        f"ORDER BY {self._account_id_key};\n"
    )
    logger.info(event=QJLogEvents.CreateView, view_name=latest_view_name)
    db_session.execute(create_sql)
    grant_sql = f"GRANT SELECT ON {latest_view_name} TO {self._db_ro_user};\n"
    db_session.execute(grant_sql)
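# Illustrative only (not produced by the source): an approximation of the SQL emitted by
# _create_latest_view above for a hypothetical job named "unencrypted_buckets" with
# query_fields ["account_id", "bucket"] and max_result_age_sec 86400. The view name and
# whitespace are assumptions; only the overall statement shape follows the code above.
EXAMPLE_LATEST_VIEW_SQL = """\
CREATE VIEW unencrypted_buckets_latest AS
SELECT result_created, account_id, bucket
FROM
(
 SELECT
 rs.created as result_created,
 lpad(r.account_id::text, 12, '0') as account_id,
 result->>'bucket' as bucket,
 RANK () OVER (PARTITION BY r.account_id ORDER BY rs.created DESC) as rank_number
 FROM
 result r
 INNER JOIN result_set rs ON r.result_set_id = rs.id
 INNER JOIN job j ON rs.job_id = j.id
 WHERE
 j.name = 'unencrypted_buckets' AND
 j.active = true AND
 rs.created > CURRENT_TIMESTAMP - INTERVAL '86400 seconds'
) ranked_query
WHERE rank_number = 1
ORDER BY account_id;
"""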
def read_artifact(self, artifact_path: str) -> Dict[str, Any]:
    """Read an artifact

    Args:
        artifact_path: s3 uri to artifact. s3://bucket/key/path

    Returns:
        artifact content
    """
    bucket, key = parse_s3_uri(artifact_path)
    session = boto3.Session()
    s3_client = session.client("s3")
    logger = Logger()
    with io.BytesIO() as artifact_bytes_buf:
        with logger.bind(bucket=bucket, key=key):
            logger.info(event=LogEvent.ReadFromS3Start)
            s3_client.download_fileobj(bucket, key, artifact_bytes_buf)
            artifact_bytes_buf.flush()
            artifact_bytes_buf.seek(0)
            artifact_bytes = artifact_bytes_buf.read()
            logger.info(event=LogEvent.ReadFromS3End)
            artifact_str = artifact_bytes.decode("utf-8")
            artifact_dict = json.loads(artifact_str)
            return artifact_dict
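# Illustrative usage sketch (not part of the source): a round trip through the S3-backed
# write_artifact/read_artifact methods defined above. "writer" and "reader" are
# hypothetical instances of the classes these methods belong to; their bucket/key_prefix
# configuration is assumed to already be set.
def _example_artifact_round_trip(writer: Any, reader: Any) -> Dict[str, Any]:
    # write_artifact returns an s3://bucket/key uri which read_artifact parses back out
    artifact_uri = writer.write_artifact(name="demo", data={"hello": "world"})
    return reader.read_artifact(artifact_path=artifact_uri)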
def scan(self, account_scan_plans: List[AccountScanPlan]) -> List[AccountScanManifest]:
    """Scan accounts. Return a list of AccountScanManifest objects.

    Args:
        account_scan_plans: list of AccountScanPlan objects defining this scan op

    Returns:
        list of AccountScanManifest objects describing the output of the scan.
    """
    account_scan_results: List[AccountScanManifest] = []
    num_total_accounts = len(account_scan_plans)
    num_threads = min(num_total_accounts, self.max_threads)
    logger = Logger()
    with logger.bind(
        num_total_accounts=num_total_accounts,
        muxer=self.__class__.__name__,
        num_threads=num_threads,
    ):
        logger.info(event=AWSLogEvents.MuxerStart)
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            processed_accounts = 0
            futures = []
            for account_scan_plan in account_scan_plans:
                account_scan_future = self._schedule_account_scan(executor, account_scan_plan)
                futures.append(account_scan_future)
                logger.info(
                    event=AWSLogEvents.MuxerQueueScan, account_id=account_scan_plan.account_id
                )
            for future in as_completed(futures):
                scan_results_dict = future.result()
                account_id = scan_results_dict["account_id"]
                output_artifact = scan_results_dict["output_artifact"]
                account_errors = scan_results_dict["errors"]
                api_call_stats = scan_results_dict["api_call_stats"]
                artifacts = [output_artifact] if output_artifact else []
                account_scan_result = AccountScanManifest(
                    account_id=account_id,
                    artifacts=artifacts,
                    errors=account_errors,
                    api_call_stats=api_call_stats,
                )
                account_scan_results.append(account_scan_result)
                processed_accounts += 1
                logger.info(event=AWSLogEvents.MuxerStat, num_scanned=processed_accounts)
        logger.info(event=AWSLogEvents.MuxerEnd)
    return account_scan_results
def enqueue_queries(
    jobs: List[schemas.Job], queue_url: str, execution_hash: str, region: str
) -> None:
    """Enqueue queries by sending a message for each job key to queue_url"""
    sqs_client = boto3.client("sqs", region_name=region)
    logger = Logger()
    with logger.bind(queue_url=queue_url, execution_hash=execution_hash):
        for job in jobs:
            job_hash = hashlib.sha256()
            job_hash.update(json.dumps(job.json()).encode())
            message_group_id = job_hash.hexdigest()
            job_hash.update(execution_hash.encode())
            message_dedupe_id = job_hash.hexdigest()
            logger.info(
                QJLogEvents.ScheduleJob,
                job=job,
                message_group_id=message_group_id,
                message_dedupe_id=message_dedupe_id,
            )
            sqs_client.send_message(
                QueueUrl=queue_url,
                MessageBody=job.json(),
                MessageGroupId=message_group_id,
                MessageDeduplicationId=message_dedupe_id,
            )
def lambda_handler(_: Dict[str, Any], __: Any) -> None:
    """Lambda entrypoint"""
    logger = Logger()
    config = PrunerConfig()
    logger.info(event=QJLogEvents.InitConfig, config=config)
    api_key = get_api_key(region_name=config.region)
    qj_client = QJAPIClient(host=config.api_host, port=config.api_port, api_key=api_key)
    logger.info(event=QJLogEvents.DeleteStart)
    result = qj_client.delete_expired_result_sets()
    logger.info(event=QJLogEvents.DeleteEnd, result=result)
def pruner() -> None:
    """Prune results according to Job config settings"""
    logger = Logger()
    pruner_config = PrunerConfig()
    logger.info(event=QJLogEvents.InitConfig, config=pruner_config)
    api_key = get_api_key(region_name=pruner_config.region)
    qj_client = QJAPIClient(
        host=pruner_config.api_host, port=pruner_config.api_port, api_key=api_key
    )
    logger.info(event=QJLogEvents.DeleteStart)
    result = qj_client.delete_expired_result_sets()
    logger.info(event=QJLogEvents.DeleteEnd, result=result)
def _invoke_lambda(
    lambda_name: str,
    lambda_timeout: int,
    result: Result,
) -> Any:
    """Invoke a QJ's remediator function"""
    logger = Logger()
    with logger.bind(lambda_name=lambda_name, lambda_timeout=lambda_timeout, result=result):
        logger.info(event=QJLogEvents.InvokeResultRemediationLambdaStart)
        boto_config = botocore.config.Config(
            read_timeout=lambda_timeout + 10,
            retries={"max_attempts": 0},
        )
        session = boto3.Session()
        lambda_client = session.client("lambda", config=boto_config)
        event = result.json().encode("utf-8")
        try:
            resp = lambda_client.invoke(
                FunctionName=lambda_name,
                Payload=event,
            )
        except Exception as invoke_ex:
            error = str(invoke_ex)
            logger.info(event=QJLogEvents.InvokeResultRemediationLambdaError, error=error)
            raise Exception(
                f"Error while invoking {lambda_name} with event: {str(event)}: {error}"
            ) from invoke_ex
        lambda_result: bytes = resp["Payload"].read()
        if resp.get("FunctionError", None):
            error = lambda_result.decode()
            logger.info(event=QJLogEvents.ResultRemediationLambdaRunError, error=error)
            raise Exception(f"Function error in {lambda_name} with event {str(event)}: {error}")
        logger.info(event=QJLogEvents.InvokeResultRemediationLambdaEnd)
        return json.loads(lambda_result)
def _drop_view(self, db_session: Session, view_name: str) -> None:
    """Drop a view by name"""
    logger = Logger()
    logger.info(event=QJLogEvents.DropView, view_name=view_name)
    drop_sql = f"DROP VIEW IF EXISTS {view_name};"
    db_session.execute(drop_sql)
def query(event: Dict[str, Any]) -> None:
    """Run the query portion of a QJ"""
    query_config = QueryConfig()
    logger = Logger()
    logger.info(event=QJLogEvents.InitConfig, config=query_config)
    records = event.get("Records", [])
    if not records:
        raise Exception("No records found")
    if len(records) > 1:
        raise Exception(f"More than one record. BatchSize is probably not 1. event: {event}")
    body = records[0].get("body")
    if body is None:
        raise Exception(f"No record body found. BatchSize is probably not 1. event: {event}")
    body = json.loads(body)
    job = schemas.Job(**body)
    logger.info(event=QJLogEvents.InitJob, job=job)
    logger.info(event=QJLogEvents.RunQueryStart)
    query_result = run_query(job=job, config=query_config)
    logger.info(event=QJLogEvents.RunQueryEnd, num_results=query_result.get_length())
    results: List[schemas.Result] = []
    if query_config.account_id_key not in query_result.query_result_set.fields:
        raise Exception(f"Query results must contain field '{query_config.account_id_key}'")
    for q_r in query_result.to_list():
        account_id = q_r[query_config.account_id_key]
        result = schemas.Result(
            account_id=account_id,
            result={
                key: val for key, val in q_r.items() if key != query_config.account_id_key
            },
        )
        results.append(result)
    graph_spec = schemas.ResultSetGraphSpec(
        graph_uris_load_times=query_result.graph_uris_load_times
    )
    result_set = schemas.ResultSetCreate(job=job, graph_spec=graph_spec, results=results)
    api_key = get_api_key(region_name=query_config.region)
    qj_client = QJAPIClient(
        host=query_config.api_host, port=query_config.api_port, api_key=api_key
    )
    logger.info(event=QJLogEvents.CreateResultSetStart)
    qj_client.create_result_set(result_set=result_set)
    logger.info(event=QJLogEvents.CreateResultSetEnd)
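# Illustrative only (not part of the source): the shape of the SQS-triggered event that
# query() above expects -- exactly one record whose body is a JSON-serialized schemas.Job.
# The job field names and values here are placeholders, not the full Job schema.
EXAMPLE_QUERY_EVENT: Dict[str, Any] = {
    "Records": [
        {"body": json.dumps({"name": "example-job", "query": "SELECT ?account_id WHERE {}"})}
    ]
}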
def lambda_handler(event, context):
    host = get_required_lambda_env_var("NEPTUNE_HOST")
    port = get_required_lambda_env_var("NEPTUNE_PORT")
    region = get_required_lambda_env_var("NEPTUNE_REGION")
    max_age_min = get_required_lambda_env_var("MAX_AGE_MIN")
    graph_name = get_required_lambda_env_var("GRAPH_NAME")
    try:
        max_age_min = int(max_age_min)
    except ValueError as ve:
        raise Exception(f"env var MAX_AGE_MIN must be an int: {ve}") from ve
    now = int(datetime.now().timestamp())
    oldest_acceptable_graph_epoch = now - max_age_min * 60
    endpoint = NeptuneEndpoint(host=host, port=port, region=region)
    client = AltimeterNeptuneClient(max_age_min=max_age_min, neptune_endpoint=endpoint)
    logger = Logger()
    uncleared = []
    # first prune metadata - if clears below are partial we want to make sure no clients
    # consider this a valid graph still.
    logger.info(event=LogEvent.PruneNeptuneMetadataGraphStart)
    client.clear_old_graph_metadata(name=graph_name, max_age_min=max_age_min)
    logger.info(event=LogEvent.PruneNeptuneMetadataGraphEnd)
    # now clear actual graphs
    with logger.bind(neptune_endpoint=endpoint):
        logger.info(event=LogEvent.PruneNeptuneGraphsStart)
        for graph_metadata in client.get_graph_metadatas(name=graph_name):
            assert graph_metadata.name == graph_name
            graph_epoch = graph_metadata.end_time
            with logger.bind(graph_uri=graph_metadata.uri, graph_epoch=graph_epoch):
                if graph_epoch < oldest_acceptable_graph_epoch:
                    logger.info(event=LogEvent.PruneNeptuneGraphStart)
                    try:
                        client.clear_graph(graph_uri=graph_metadata.uri)
                        logger.info(event=LogEvent.PruneNeptuneGraphEnd)
                    except Exception as ex:
                        logger.error(
                            event=LogEvent.PruneNeptuneGraphError,
                            msg=f"Error pruning graph {graph_metadata.uri}: {ex}",
                        )
                        uncleared.append(graph_metadata.uri)
                        continue
                else:
                    logger.info(event=LogEvent.PruneNeptuneGraphSkip)
        logger.info(event=LogEvent.PruneNeptuneGraphsEnd)
    if uncleared:
        msg = f"Errors were found pruning {uncleared}."
        logger.error(event=LogEvent.PruneNeptuneGraphsError, msg=msg)
        raise Exception(msg)