# Standard-library imports used by the handlers below. Project-level names such as
# Logger, QJAPIClient, QJLogEvents, the *Config classes, schemas, run_query,
# enqueue_queries, get_api_key, Remediation, RemediationError and
# _schedule_result_remediation are assumed to be provided by the surrounding package.
import hashlib
import json
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List


def executor(event: Dict[str, Any]) -> None:
    """Execute all known QJs. If this run was triggered by an SNS message, use that message as
    part of the deduplication id for each SQS message. Otherwise generate a unique id so that
    repeated manual runs of executor are not deduplicated."""
    sns_message = event.get("Records", [{}])[0].get("Sns", {}).get("Message")
    if sns_message:
        execution_hash = hashlib.sha256(sns_message.encode()).hexdigest()
    else:
        execution_hash = hashlib.sha256(str(uuid.uuid4()).encode()).hexdigest()
    exec_config = ExecutorConfig()
    logger = Logger()
    logger.info(
        event=QJLogEvents.InitConfig,
        sns_triggered=bool(sns_message),
        execution_hash=execution_hash,
    )
    qj_client = QJAPIClient(host=exec_config.api_host,
                            port=exec_config.api_port)
    jobs = qj_client.get_jobs(active_only=True)
    logger.info(event=QJLogEvents.GetJobs, num_jobs=len(jobs))
    enqueue_queries(
        jobs=jobs,
        queue_url=exec_config.query_queue_url,
        execution_hash=execution_hash,
        region=exec_config.region,
    )
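
# enqueue_queries is not shown in this listing. Below is a minimal sketch of what such a
# helper might look like, assuming a FIFO query queue and a pydantic-style Job model that
# exposes .name and .json() (all assumptions, not confirmed by the source). Combining the
# execution hash with the job name means a redelivered SNS message yields the same
# deduplication id, while a manual run (random hash) never collides with a previous run.
def enqueue_queries_sketch(
    jobs: List[Any], queue_url: str, execution_hash: str, region: str
) -> None:
    """Hypothetical sketch: enqueue one SQS message per job, deduplicated per execution."""
    import boto3  # assumed dependency of this sketch

    sqs = boto3.client("sqs", region_name=region)
    for job in jobs:
        dedup_id = hashlib.sha256(f"{execution_hash}:{job.name}".encode()).hexdigest()
        sqs.send_message(
            QueueUrl=queue_url,
            MessageBody=job.json(),          # assumes a pydantic-style Job model
            MessageGroupId=job.name,
            MessageDeduplicationId=dedup_id,
        )
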
def query(event: Dict[str, Any]) -> None:
    """Run the query portion of a QJ"""
    query_config = QueryConfig()
    logger = Logger()
    logger.info(event=QJLogEvents.InitConfig, config=query_config)

    records = event.get("Records", [])
    if not records:
        raise Exception("No records found")
    if len(records) > 1:
        raise Exception(
            f"More than one record. BatchSize is probably not 1. event: {event}"
        )
    body = records[0].get("body")
    if body is None:
        raise Exception(
            f"No record body found. BatchSize is probably not 1. event: {event}"
        )
    body = json.loads(body)
    job = schemas.Job(**body)
    logger.info(event=QJLogEvents.InitJob, job=job)

    logger.info(event=QJLogEvents.RunQueryStart)
    query_result = run_query(job=job, config=query_config)
    logger.info(event=QJLogEvents.RunQueryEnd,
                num_results=query_result.get_length())

    results: List[schemas.Result] = []
    if query_config.account_id_key not in query_result.query_result_set.fields:
        raise Exception(
            f"Query results must contain field '{query_config.account_id_key}'"
        )
    for q_r in query_result.to_list():
        account_id = q_r[query_config.account_id_key]
        result = schemas.Result(
            account_id=account_id,
            result={
                key: val
                for key, val in q_r.items()
                if key != query_config.account_id_key
            },
        )
        results.append(result)

    graph_spec = schemas.ResultSetGraphSpec(
        graph_uris_load_times=query_result.graph_uris_load_times)
    result_set = schemas.ResultSetCreate(job=job,
                                         graph_spec=graph_spec,
                                         results=results)

    api_key = get_api_key(region_name=query_config.region)
    qj_client = QJAPIClient(host=query_config.api_host,
                            port=query_config.api_port,
                            api_key=api_key)
    logger.info(event=QJLogEvents.CreateResultSetStart)
    qj_client.create_result_set(result_set=result_set)
    logger.info(event=QJLogEvents.CreateResultSetEnd)
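
# query() expects the event shape produced by an SQS trigger configured with a batch
# size of 1: exactly one record whose "body" is a JSON-serialized Job. A hypothetical
# event for local testing might look like the following (the Job fields shown are
# illustrative; the real required fields are defined by schemas.Job):
_example_query_event = {
    "Records": [
        {"body": json.dumps({"name": "example_job"})},  # hypothetical Job payload
    ]
}
# query(_example_query_event)
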
def remediator(event: Dict[str, Any]) -> None:
    """Run the remediation lambda for a QJ result set"""
    config = RemediatorConfig()
    logger = Logger()
    remediation = Remediation(**event)
    with logger.bind(remediation=remediation):
        logger.info(event=QJLogEvents.RemediationInit)
        qj_api_client = QJAPIClient(host=config.qj_api_host)
        latest_result_set = qj_api_client.get_job_latest_result_set(
            job_name=remediation.job_name)
        if not latest_result_set:
            msg = f"No latest_result_set present for {remediation.job_name}"
            logger.error(QJLogEvents.StaleResultSet, detail=msg)
            raise RemediationError(msg)
        if latest_result_set.result_set_id != remediation.result_set_id:
            msg = (
                f"Remediation result_set_id {remediation.result_set_id} does not match the "
                f"latest result_set_id {latest_result_set.result_set_id}")
            logger.error(QJLogEvents.StaleResultSet, detail=msg)
            raise RemediationError(msg)
        if not latest_result_set.job.remediate_sqs_queue:
            msg = f"Job {latest_result_set.job.name} has no remediator"
            logger.error(QJLogEvents.JobHasNoRemediator, detail=msg)
            raise RemediationError(msg)
        num_threads = 10  # TODO env var
        errors = []
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = []
            for result in latest_result_set.results:
                logger.info(event=QJLogEvents.ProcessResult, result=result)
                future = _schedule_result_remediation(
                    executor=executor,
                    lambda_name=latest_result_set.job.remediate_sqs_queue,
                    lambda_timeout=300,  # TODO env var?
                    result=result,
                )
                futures.append(future)
            for future in as_completed(futures):
                try:
                    lambda_result = future.result()
                    logger.info(QJLogEvents.ResultRemediationSuccessful,
                                lambda_result=lambda_result)
                except Exception as ex:
                    logger.info(
                        event=QJLogEvents.ResultSetRemediationFailed,
                        error=str(ex),
                    )
                    errors.append(str(ex))
        if errors:
            logger.error(event=QJLogEvents.ResultSetRemediationFailed,
                         errors=errors)
            raise RemediationError(
                f"Errors encountered during remediation of {latest_result_set.job.name} "
                f"/ {latest_result_set.result_set_id}: {errors}")
def lambda_handler(_: Dict[str, Any], __: Any) -> None:
    """Lambda entrypoint"""
    logger = Logger()
    config = PrunerConfig()
    logger.info(event=QJLogEvents.InitConfig, config=config)
    api_key = get_api_key(region_name=config.region)
    qj_client = QJAPIClient(host=config.api_host,
                            port=config.api_port,
                            api_key=api_key)
    logger.info(event=QJLogEvents.DeleteStart)
    result = qj_client.delete_expired_result_sets()
    logger.info(event=QJLogEvents.DeleteEnd, result=result)


def pruner() -> None:
    """Prune results according to Job config settings"""
    logger = Logger()
    pruner_config = PrunerConfig()
    logger.info(event=QJLogEvents.InitConfig, config=pruner_config)
    api_key = get_api_key(region_name=pruner_config.region)
    qj_client = QJAPIClient(
        host=pruner_config.api_host, port=pruner_config.api_port, api_key=api_key
    )
    logger.info(event=QJLogEvents.DeleteStart)
    result = qj_client.delete_expired_result_sets()
    logger.info(event=QJLogEvents.DeleteEnd, result=result)
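
# Both lambda_handler and pruner perform the same expired-result-set cleanup. Once the
# QJ API is reachable and the PrunerConfig environment variables are set for your
# deployment, either entrypoint can be exercised manually, e.g.:
#
#     lambda_handler({}, None)
#     pruner()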