def executor(event: Dict[str, Any]) -> None:
    """Execute all known QJs. If this was triggered by an SNS message, use that
    message as part of the deduplication id for each SQS message. Otherwise
    generate a unique id so that repeated manual runs of executor will not be
    dedupe'd."""
    sns_message = event.get("Records", [{}])[0].get("Sns", {}).get("Message")
    if sns_message:
        execution_hash = hashlib.sha256(sns_message.encode()).hexdigest()
    else:
        execution_hash = hashlib.sha256(str(uuid.uuid4()).encode()).hexdigest()
    exec_config = ExecutorConfig()
    logger = Logger()
    logger.info(
        event=QJLogEvents.InitConfig,
        sns_triggered=bool(sns_message),
        execution_hash=execution_hash,
    )
    qj_client = QJAPIClient(host=exec_config.api_host, port=exec_config.api_port)
    jobs = qj_client.get_jobs(active_only=True)
    logger.info(event=QJLogEvents.GetJobs, num_jobs=len(jobs))
    enqueue_queries(
        jobs=jobs,
        queue_url=exec_config.query_queue_url,
        execution_hash=execution_hash,
        region=exec_config.region,
    )
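
# enqueue_queries is referenced above but not defined in this module. A
# minimal sketch of what it might look like, assuming the query queue is an
# SQS FIFO queue and that the per-job MessageDeduplicationId combines the job
# name with execution_hash (this helper body is an illustrative assumption,
# not the project's actual implementation):
from typing import List

import boto3


def enqueue_queries(
    jobs: List[schemas.Job], queue_url: str, execution_hash: str, region: str
) -> None:
    """Send one SQS message per job, deduplicated per execution."""
    sqs_client = boto3.client("sqs", region_name=region)
    for job in jobs:
        sqs_client.send_message(
            QueueUrl=queue_url,
            MessageBody=job.json(),  # assumes Job is a pydantic model
            MessageGroupId=job.name,
            # identical (job, execution) pairs collapse to one message
            MessageDeduplicationId=f"{job.name}-{execution_hash}",
        )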

def query(event: Dict[str, Any]) -> None:
    """Run the query portion of a QJ"""
    query_config = QueryConfig()
    logger = Logger()
    logger.info(event=QJLogEvents.InitConfig, config=query_config)
    records = event.get("Records", [])
    if not records:
        raise Exception("No records found")
    if len(records) > 1:
        raise Exception(f"More than one record. BatchSize is probably not 1. event: {event}")
    body = records[0].get("body")
    if body is None:
        raise Exception(f"No record body found. BatchSize is probably not 1. event: {event}")
    body = json.loads(body)
    job = schemas.Job(**body)
    logger.info(event=QJLogEvents.InitJob, job=job)
    logger.info(event=QJLogEvents.RunQueryStart)
    query_result = run_query(job=job, config=query_config)
    logger.info(event=QJLogEvents.RunQueryEnd, num_results=query_result.get_length())
    results: List[schemas.Result] = []
    if query_config.account_id_key not in query_result.query_result_set.fields:
        raise Exception(f"Query results must contain field '{query_config.account_id_key}'")
    for q_r in query_result.to_list():
        account_id = q_r[query_config.account_id_key]
        result = schemas.Result(
            account_id=account_id,
            result={key: val for key, val in q_r.items() if key != query_config.account_id_key},
        )
        results.append(result)
    graph_spec = schemas.ResultSetGraphSpec(graph_uris_load_times=query_result.graph_uris_load_times)
    result_set = schemas.ResultSetCreate(job=job, graph_spec=graph_spec, results=results)
    api_key = get_api_key(region_name=query_config.region)
    qj_client = QJAPIClient(host=query_config.api_host, port=query_config.api_port, api_key=api_key)
    logger.info(event=QJLogEvents.CreateResultSetStart)
    qj_client.create_result_set(result_set=result_set)
    logger.info(event=QJLogEvents.CreateResultSetEnd)
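
# For reference, query() expects the event shape that SQS delivers with
# BatchSize=1: exactly one record whose "body" is a JSON-serialized Job. A
# hand-built example for local testing (the Job fields shown are illustrative
# guesses, not the actual schema):
import json
from typing import Any, Dict

_EXAMPLE_QUERY_EVENT: Dict[str, Any] = {
    "Records": [
        {
            "body": json.dumps(
                {
                    "name": "unused_security_groups",  # hypothetical job name
                    "query": "SELECT ?account_id WHERE { ... }",  # hypothetical query
                }
            )
        }
    ]
}
# query(_EXAMPLE_QUERY_EVENT) would parse the body into a schemas.Job and run it.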

def remediator(event: Dict[str, Any]) -> None:
    """Run the remediation lambda for a QJ result set"""
    config = RemediatorConfig()
    logger = Logger()
    remediation = Remediation(**event)
    with logger.bind(remediation=remediation):
        logger.info(event=QJLogEvents.RemediationInit)
        qj_api_client = QJAPIClient(host=config.qj_api_host)
        latest_result_set = qj_api_client.get_job_latest_result_set(job_name=remediation.job_name)
        if not latest_result_set:
            msg = f"No latest_result_set present for {remediation.job_name}"
            logger.error(QJLogEvents.StaleResultSet, detail=msg)
            raise RemediationError(msg)
        if latest_result_set.result_set_id != remediation.result_set_id:
            msg = (
                f"Remediation result_set_id {remediation.result_set_id} does not match the "
                f"latest result_set_id {latest_result_set.result_set_id}"
            )
            logger.error(QJLogEvents.StaleResultSet, detail=msg)
            raise RemediationError(msg)
        if not latest_result_set.job.remediate_sqs_queue:
            msg = f"Job {latest_result_set.job.name} has no remediator"
            logger.error(QJLogEvents.JobHasNoRemediator, detail=msg)
            raise RemediationError(msg)
        num_threads = 10  # TODO env var
        errors = []
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = []
            for result in latest_result_set.results:
                logger.info(event=QJLogEvents.ProcessResult, result=result)
                future = _schedule_result_remediation(
                    executor=executor,
                    lambda_name=latest_result_set.job.remediate_sqs_queue,
                    lambda_timeout=300,  # TODO env var?
                    result=result,
                )
                futures.append(future)
            for future in as_completed(futures):
                try:
                    lambda_result = future.result()
                    logger.info(
                        QJLogEvents.ResultRemediationSuccessful, lambda_result=lambda_result
                    )
                except Exception as ex:
                    logger.info(event=QJLogEvents.ResultSetRemediationFailed, error=str(ex))
                    errors.append(str(ex))
        if errors:
            logger.error(event=QJLogEvents.ResultSetRemediationFailed, errors=errors)
            raise RemediationError(
                f"Errors encountered during remediation of {latest_result_set.job.name} "
                f"/ {latest_result_set.result_set_id}: {errors}"
            )
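
# _schedule_result_remediation is referenced above but not defined in this
# module. A minimal sketch of one plausible implementation, assuming it wraps
# a synchronous boto3 Lambda invocation in the supplied ThreadPoolExecutor
# (only the call signature is taken from the handler above; the body is an
# illustrative assumption):
import json
from concurrent.futures import Future, ThreadPoolExecutor

import boto3
from botocore.config import Config


def _schedule_result_remediation(
    executor: ThreadPoolExecutor,
    lambda_name: str,
    lambda_timeout: int,
    result: schemas.Result,
) -> "Future[Any]":
    """Submit one result to the remediation Lambda on a worker thread."""

    def _invoke() -> Any:
        # allow the HTTP read to outlive the remediation Lambda's runtime
        lambda_client = boto3.client("lambda", config=Config(read_timeout=lambda_timeout))
        response = lambda_client.invoke(
            FunctionName=lambda_name,
            Payload=json.dumps(result.dict()).encode(),  # assumes pydantic Result
        )
        payload = response["Payload"].read()
        if response.get("FunctionError"):
            raise RemediationError(payload.decode())
        return json.loads(payload)

    return executor.submit(_invoke)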

def lambda_handler(_: Dict[str, Any], __: Any) -> None:
    """Lambda entrypoint"""
    logger = Logger()
    config = PrunerConfig()
    logger.info(event=QJLogEvents.InitConfig, config=config)
    api_key = get_api_key(region_name=config.region)
    qj_client = QJAPIClient(host=config.api_host, port=config.api_port, api_key=api_key)
    logger.info(event=QJLogEvents.DeleteStart)
    result = qj_client.delete_expired_result_sets()
    logger.info(event=QJLogEvents.DeleteEnd, result=result)

def pruner() -> None:
    """Prune results according to Job config settings"""
    logger = Logger()
    pruner_config = PrunerConfig()
    logger.info(event=QJLogEvents.InitConfig, config=pruner_config)
    api_key = get_api_key(region_name=pruner_config.region)
    qj_client = QJAPIClient(
        host=pruner_config.api_host, port=pruner_config.api_port, api_key=api_key
    )
    logger.info(event=QJLogEvents.DeleteStart)
    result = qj_client.delete_expired_result_sets()
    logger.info(event=QJLogEvents.DeleteEnd, result=result)
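
# get_api_key is used by query(), lambda_handler(), and pruner() above but is
# not defined in this module. A minimal sketch, assuming the key is stored in
# AWS Secrets Manager under a fixed secret name (both the secret name and the
# choice of Secrets Manager are assumptions for illustration):
import boto3

_API_KEY_SECRET_NAME = "qj-api-key"  # hypothetical secret name


def get_api_key(region_name: str) -> str:
    """Fetch the QJ API key for the given region."""
    sm_client = boto3.client("secretsmanager", region_name=region_name)
    secret = sm_client.get_secret_value(SecretId=_API_KEY_SECRET_NAME)
    return secret["SecretString"]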