Example no. 1
    def read_json(self, path: str) -> Dict[str, Any]:
        """Read a json artifact

        Args:
            path: filesystem path to artifact

        Returns:
            artifact content
        """
        logger = Logger()
        with logger.bind(artifact_path=path):
            logger.info(event=LogEvent.ReadFromFSStart)
            with open(path, "r") as artifact_fp:
                data = json.load(artifact_fp)
            logger.info(event=LogEvent.ReadFromFSEnd)
            return data
Example no. 2
 def get_version(self, db_session: Session, job_name: str,
                 created: datetime) -> Job:
     """Get a specific version of a Job by created timestamp"""
     logger = Logger()
     logger.info(event=QJLogEvents.GetJobVersion, job_name=job_name)
     query = db_session.query(Job).filter(Job.name == job_name).filter(
         Job.created == created)
     results = query.all()
     num_results = len(results)
     if num_results:
         if num_results > 1:
             raise Exception(
                 f"More than one job found for {job_name} with version {created}"
             )
         return results[0]
     raise JobVersionNotFound(f"Could not find job {job_name} / {created}")
Example no. 3
    def scan(
        self, account_scan_plan: AccountScanPlan
    ) -> Generator[AccountScanManifest, None, None]:
        """Scan accounts, yielding AccountScanManifest objects.

        Args:
            account_scan_plan: AccountScanPlan defining this scan op

        Yields:
            AccountScanManifest objects
        """
        num_total_accounts = len(account_scan_plan.account_ids)
        account_scan_plans = account_scan_plan.to_batches(
            max_accounts=self.config.concurrency.max_accounts_per_thread)
        num_account_batches = len(account_scan_plans)
        num_threads = min(num_account_batches,
                          self.config.concurrency.max_account_scan_threads)
        logger = Logger()
        with logger.bind(
                num_total_accounts=num_total_accounts,
                num_account_batches=num_account_batches,
                muxer=self.__class__.__name__,
                num_muxer_threads=num_threads,
        ):
            logger.info(event=AWSLogEvents.MuxerStart)
            with ThreadPoolExecutor(max_workers=num_threads) as executor:
                processed_accounts = 0
                futures = []
                for sub_account_scan_plan in account_scan_plans:
                    account_scan_future = self._schedule_account_scan(
                        executor, sub_account_scan_plan)
                    futures.append(account_scan_future)
                    logger.info(
                        event=AWSLogEvents.MuxerQueueScan,
                        account_ids=",".join(
                            sub_account_scan_plan.account_ids),
                    )
                for future in as_completed(futures):
                    scan_results_dicts = future.result()
                    for scan_results_dict in scan_results_dicts:
                        account_id = scan_results_dict["account_id"]
                        output_artifact = scan_results_dict["output_artifact"]
                        account_errors = scan_results_dict["errors"]
                        api_call_stats = scan_results_dict["api_call_stats"]
                        artifacts = [output_artifact] if output_artifact else []
                        account_scan_result = AccountScanManifest(
                            account_id=account_id,
                            artifacts=artifacts,
                            errors=account_errors,
                            api_call_stats=api_call_stats,
                        )
                        yield account_scan_result
                        processed_accounts += 1
                    logger.info(event=AWSLogEvents.MuxerStat,
                                num_scanned=processed_accounts)
            logger.info(event=AWSLogEvents.MuxerEnd)
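
A hedged sketch of how a caller might consume this generator; the muxer and plan objects below are assumptions (their construction is not shown in this listing), and only the scan() signature above is taken as given.

# Hypothetical caller (muxer/plan construction assumed elsewhere): manifests
# are handled as each account batch completes rather than after the whole scan.
failed_account_ids = []
for manifest in muxer.scan(account_scan_plan=account_scan_plan):
    if manifest.errors:
        failed_account_ids.append(manifest.account_id)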
Example no. 4
def lambda_handler(event, context):
    json_bucket = event["Records"][0]["s3"]["bucket"]["name"]
    json_key = urllib.parse.unquote(event["Records"][0]["s3"]["object"]["key"])
    rdf_bucket = get_required_lambda_env_var("RDF_BUCKET")
    rdf_key = ".".join(json_key.split(".")[:-1]) + ".rdf.gz"
    session = boto3.Session()
    s3_client = session.client("s3")

    logger = Logger()
    with logger.bind(json_bucket=json_bucket, json_key=json_key):
        graph_pkg = graph_pkg_from_s3(s3_client=s3_client,
                                      json_bucket=json_bucket,
                                      json_key=json_key)

    with logger.bind(rdf_bucket=rdf_bucket, rdf_key=rdf_key):
        logger.info(event=LogEvent.WriteToS3Start)
        with io.BytesIO() as rdf_bytes_buf:
            with gzip.GzipFile(fileobj=rdf_bytes_buf, mode="wb") as gz:
                graph_pkg.graph.serialize(gz)
            rdf_bytes_buf.flush()
            rdf_bytes_buf.seek(0)
            s3_client.upload_fileobj(rdf_bytes_buf, rdf_bucket, rdf_key)
            s3_client.put_object_tagging(
                Bucket=rdf_bucket,
                Key=rdf_key,
                Tagging={
                    "TagSet": [
                        {
                            "Key": "name",
                            "Value": graph_pkg.name
                        },
                        {
                            "Key": "version",
                            "Value": graph_pkg.version
                        },
                        {
                            "Key": "start_time",
                            "Value": str(graph_pkg.start_time)
                        },
                        {
                            "Key": "end_time",
                            "Value": str(graph_pkg.end_time)
                        },
                    ]
                },
            )
        logger.info(event=LogEvent.WriteToS3End)
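
For reference, a hedged sketch of a minimal S3 notification event that would exercise this handler; the bucket and key are placeholders, and only the two fields the handler actually reads are included.

# Synthetic S3 event (bucket/key assumed) matching the fields read above:
# Records[0].s3.bucket.name and Records[0].s3.object.key.
test_event = {
    "Records": [
        {
            "s3": {
                "bucket": {"name": "example-json-bucket"},
                "object": {"key": "graphs/master/graph.json"},
            }
        }
    ]
}
lambda_handler(test_event, None)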
Example no. 5
    def create(self, db_session: Session,
               obj_in: schemas.ResultSetCreate) -> ResultSet:
        """Create a ResultSet"""
        logger = Logger()
        num_results = len(obj_in.results)
        logger.info(
            event=QJLogEvents.CreateResultSet,
            job=obj_in.job,
            created=obj_in.created,
            graph_spec=obj_in.graph_spec,
            num_results=num_results,
        )
        job = self._job_crud.get_version(db_session=db_session,
                                         job_name=obj_in.job.name,
                                         created=obj_in.job.created)
        if not job:
            raise JobVersionNotFound(
                f"Could not find job {obj_in.job.name} / {obj_in.job.created}")

        # create result_set db object
        if num_results > self._max_result_set_results:
            raise ResultSetResultsLimitExceeded(
                f"Result set has {num_results} results, limit is {self._max_result_set_results}"
            )
        result_set = ResultSet(job=job,
                               created=obj_in.created,
                               graph_spec=json.loads(obj_in.graph_spec.json()))
        db_session.add(result_set)

        # create result db objects
        for obj_in_result in obj_in.results:
            result_size = len(json.dumps(obj_in_result.result))
            if result_size > self._max_result_size_bytes:
                raise ResultSizeExceeded((
                    f"Result size {result_size} exceeds max {self._max_result_size_bytes}: "
                    f"{json.dumps(obj_in_result.result)[:self._max_result_size_bytes]}..."
                ))
            result = Result(
                result_set=result_set,
                account_id=obj_in_result.account_id,
                result=obj_in_result.result,
            )
            db_session.add(result)

        db_session.commit()
        db_session.refresh(result_set)
        return result_set
Example no. 6
 def set_automated_backups(cls, client: BaseClient,
                           dbinstances: Dict[str, Dict[str, Any]]) -> None:
     logger = Logger()
     backup_paginator = client.get_paginator(
         "describe_db_instance_automated_backups")
     for resp in backup_paginator.paginate():
         for backup in resp.get("DBInstanceAutomatedBackups", []):
             if backup["DBInstanceArn"] in dbinstances:
                 dbinstances[backup["DBInstanceArn"]]["Backup"].append(
                     backup)
             else:
                 logger.info(
                     event=AWSLogEvents.ScanAWSResourcesNonFatalError,
                     msg=(f'Unable to find matching DB Instance {backup["DBInstanceArn"]} '
                          "(Possible Deletion)"),
                 )
Example no. 7
def scan_scan_unit(scan_unit: ScanUnit) -> Tuple[str, Dict[str, Any]]:
    logger = Logger()
    with logger.bind(account_id=scan_unit.account_id,
                     region=scan_unit.region_name,
                     service=scan_unit.service):
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceStart)
        session = boto3.Session(
            aws_access_key_id=scan_unit.access_key,
            aws_secret_access_key=scan_unit.secret_key,
            aws_session_token=scan_unit.token,
            region_name=scan_unit.region_name,
        )
        scan_accessor = AWSAccessor(session=session,
                                    account_id=scan_unit.account_id,
                                    region_name=scan_unit.region_name)
        graph_spec = GraphSpec(
            name=scan_unit.graph_name,
            version=scan_unit.graph_version,
            resource_spec_classes=scan_unit.resource_spec_classes,
            scan_accessor=scan_accessor,
        )
        start_time = int(time.time())
        resources: List[Resource] = []
        errors = []
        try:
            resources = graph_spec.scan()
        except Exception as ex:
            error_str = str(ex)
            trace_back = traceback.format_exc()
            logger.error(event=AWSLogEvents.ScanAWSAccountError,
                         error=error_str,
                         trace_back=trace_back)
            error = f"{str(ex)}\n{trace_back}"
            errors.append(error)
        end_time = int(time.time())
        graph_set = GraphSet(
            name=scan_unit.graph_name,
            version=scan_unit.graph_version,
            start_time=start_time,
            end_time=end_time,
            resources=resources,
            errors=errors,
            stats=scan_accessor.api_call_stats,
        )
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceEnd)
        return (scan_unit.account_id, graph_set.to_dict())
Example no. 8
def get_sub_account_ids(account_ids: Tuple[str, ...], accessor: Accessor) -> Tuple[str, ...]:
    logger = Logger()
    logger.info(event=AWSLogEvents.GetSubAccountsStart)
    sub_account_ids: Set[str] = set(account_ids)
    for master_account_id in account_ids:
        with logger.bind(master_account_id=master_account_id):
            account_session = accessor.get_session(master_account_id)
            orgs_client = account_session.client("organizations")
            resp = orgs_client.describe_organization()
            if resp["Organization"]["MasterAccountId"] == master_account_id:
                accounts_paginator = orgs_client.get_paginator("list_accounts")
                for accounts_resp in accounts_paginator.paginate():
                    for account_resp in accounts_resp["Accounts"]:
                        if account_resp["Status"].lower() == "active":
                            account_id = account_resp["Id"]
                            sub_account_ids.add(account_id)
    logger.info(event=AWSLogEvents.GetSubAccountsEnd)
    return tuple(sub_account_ids)
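
A minimal usage sketch, assuming an Accessor instance has been built elsewhere; the account ids are placeholders.

# Hypothetical call: expand the configured master accounts into every
# active member account of their AWS Organizations.
scan_account_ids = get_sub_account_ids(
    account_ids=("123456789012", "210987654321"), accessor=accessor)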
Example no. 9
def remediator(event: Dict[str, Any]) -> None:
    """Run the remediation lambda for a QJ result set"""
    config = RemediatorConfig()
    logger = Logger()
    remediation = Remediation(**event)
    with logger.bind(remediation=remediation):
        logger.info(event=QJLogEvents.RemediationInit)
        qj_api_client = QJAPIClient(host=config.qj_api_host)
        latest_result_set = qj_api_client.get_job_latest_result_set(
            job_name=remediation.job_name)
        if not latest_result_set:
            msg = f"No latest_result_set present for {remediation.job_name}"
            logger.error(QJLogEvents.StaleResultSet, detail=msg)
            raise RemediationError(msg)
        if latest_result_set.result_set_id != remediation.result_set_id:
            msg = (
                f"Remediation result_set_id {remediation.result_set_id} does not match the "
                f"latest result_set_id {latest_result_set.result_set_id}")
            logger.error(QJLogEvents.StaleResultSet, detail=msg)
            raise RemediationError(msg)
        if not latest_result_set.job.remediate_sqs_queue:
            msg = f"Job {latest_result_set.job.name} has no remediator"
            logger.error(QJLogEvents.JobHasNoRemediator, detail=msg)
            raise RemediationError(msg)
        num_threads = 10  # TODO env var
        errors = []
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = []
            for result in latest_result_set.results:
                logger.info(event=QJLogEvents.ProcessResult, result=result)
                future = _schedule_result_remediation(
                    executor=executor,
                    lambda_name=latest_result_set.job.remediate_sqs_queue,
                    lambda_timeout=300,  # TODO env var?
                    result=result,
                )
                futures.append(future)
            for future in as_completed(futures):
                try:
                    lambda_result = future.result()
                    logger.info(QJLogEvents.ResultRemediationSuccessful,
                                lambda_result=lambda_result)
                except Exception as ex:
                    logger.info(
                        event=QJLogEvents.ResultSetRemediationFailed,
                        error=str(ex),
                    )
                    errors.append(str(ex))
        if errors:
            logger.error(event=QJLogEvents.ResultSetRemediationFailed,
                         errors=errors)
            raise RemediationError(
                f"Errors encountered during remediation of {latest_result_set.job.name} "
                f"/ {latest_result_set.result_set_id}: {errors}")
Example no. 10
 def get(
     self,
     db_session: Session,
     result_set_id: str,
 ) -> ResultSet:
     """Get a ResultSet by id"""
     logger = Logger()
     logger.info(event=QJLogEvents.GetResultSet,
                 result_set_id=result_set_id)
     query = db_session.query(ResultSet).filter(
         ResultSet.result_set_id == result_set_id)
     result_sets = query.all()
     num_result_sets = len(result_sets)
     if num_result_sets:
         if num_result_sets > 1:
             raise Exception(
                 f"More than one result_set found for {result_set_id}")
         return result_sets[0]
     raise ResultSetNotFound(f"No result set {result_set_id} found")
Example no. 11
 def get_active(
     self,
     db_session: Session,
     job_name: str,
 ) -> Job:
     """Get the active version of a Job"""
     logger = Logger()
     query = db_session.query(Job).filter(
         Job.active).filter(Job.name == job_name)
     results = query.all()
     num_results = len(results)
     logger.info(event=QJLogEvents.GetActiveJob,
                 job_name=job_name,
                 num_results=num_results)
     if num_results:
         assert num_results == 1, f"More than one active job found for {job_name}"
         return results[0]
     raise ActiveJobVersionNotFound(
         f"No active job version found for {job_name}")
Example no. 12
    def write_json(self, name: str, data: BaseModel) -> str:
        """Write artifact data to self.output_dir/name.json

        Args:
            name: filename
            data: data

        Returns:
            Full filesystem path of artifact file
        """
        logger = Logger()
        os.makedirs(self.output_dir, exist_ok=True)
        artifact_path = os.path.join(self.output_dir, f"{name}.json")
        with logger.bind(artifact_path=artifact_path):
            logger.info(event=LogEvent.WriteToFSStart)
            with open(artifact_path, "w") as artifact_fp:
                artifact_fp.write(data.json(exclude_unset=True))
            logger.info(event=LogEvent.WriteToFSEnd)
        return artifact_path
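
A hedged round-trip sketch pairing this writer with the read_json method from Example no. 1; the store instances and the model being written are assumptions.

# Hypothetical round trip (store objects assumed): serialize a pydantic model
# to <output_dir>/manifest.json, then load it back as a dict.
artifact_path = fs_writer.write_json(name="manifest", data=manifest_model)
manifest_dict = fs_reader.read_json(path=artifact_path)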
Example no. 13
 def create(self, db_session: Session,
            job_create_in: schemas.JobCreate) -> Job:
     """Create a Job"""
     logger = Logger()
     logger.info(event=QJLogEvents.CreateJob, job_create=job_create_in)
     try:
         query = rdflib.Graph().query(job_create_in.query)
     except Exception as ex:
         raise JobQueryInvalid(
             f"Invalid query {job_create_in.query}: {str(ex)}") from ex
     query_fields = [str(query_var) for query_var in query.vars]
     if self._account_id_key not in query_fields:
         raise JobQueryMissingAccountId(
             f"Query {job_create_in.query} missing '{self._account_id_key}' field"
         )
     if job_create_in.result_expiration_sec is None:
         job_create_in.result_expiration_sec = self._result_expiration_sec_default
     if job_create_in.result_expiration_sec > self._result_expiration_sec_limit:
         raise JobInvalid(
             f"Field result_expiration_sec value {job_create_in.result_expiration_sec} "
             f"must be <= {self._result_expiration_sec_limit}")
     if job_create_in.max_graph_age_sec is None:
         job_create_in.max_graph_age_sec = self._max_graph_age_sec_default
     else:
         if job_create_in.max_graph_age_sec > self._max_graph_age_sec_limit:
             raise JobInvalid(
                 f"Field max_graph_age_sec value {job_create_in.max_graph_age_sec} must be "
                 f"<= {self._max_graph_age_sec_limit}")
     if job_create_in.max_result_age_sec is None:
         job_create_in.max_result_age_sec = self._max_result_age_sec_default
     else:
         if job_create_in.max_result_age_sec > self._max_result_age_sec_limit:
             raise JobInvalid(
                 f"Field max_result_age_sec value {job_create_in.max_result_age_sec} must be "
                 f"<= {self._max_result_age_sec_limit}")
     obj_in_data = job_create_in.dict()
     obj_in_data["query_fields"] = query_fields
     job_create = schemas.Job(**obj_in_data)
     db_obj = Job(**job_create.dict())  # type: ignore
     db_session.add(db_obj)
     db_session.commit()
     db_session.refresh(db_obj)
     return db_obj
Example no. 14
    def write_artifact(self, name: str, data: Dict[str, Any]) -> str:
        """Write artifact data to self.output_dir/name.json

        Args:
            name: filename
            data: artifact data

        Returns:
            Full filesystem path of artifact file
        """
        logger = Logger()
        os.makedirs(self.output_dir, exist_ok=True)
        artifact_path = os.path.join(self.output_dir, f"{name}.json")
        with logger.bind(artifact_path=artifact_path):
            logger.info(event=LogEvent.WriteToFSStart)
            with open(artifact_path, "w") as artifact_fp:
                json.dump(data, artifact_fp, default=json_encoder)
            logger.info(event=LogEvent.WriteToFSEnd)
        return artifact_path
Example no. 15
 def update_version(
     self,
     db_session: Session,
     job_name: str,
     created: datetime,
     job_update: schemas.JobUpdate,
 ) -> Job:
     """Update a Job"""
     logger = Logger()
     logger.info(QJLogEvents.UpdateJob,
                 job_name=job_name,
                 created=created,
                 job_update=job_update)
     job_version = self.get_version(db_session=db_session,
                                    job_name=job_name,
                                    created=created)
     if job_update.description is not None:
         job_version.description = job_update.description
     if job_update.category is not None:
         job_version.category = job_update.category
     if job_update.result_expiration_sec is not None:
         job_version.result_expiration_sec = job_update.result_expiration_sec
     if job_update.max_graph_age_sec is not None:
         job_version.max_graph_age_sec = job_update.max_graph_age_sec
     if job_update.max_result_age_sec is not None:
         job_version.max_result_age_sec = job_update.max_result_age_sec
     if job_update.active is not None:
         if job_update.active:
             query = db_session.query(Job).filter(Job.name == job_name)
             job_versions = query.all()
             for _job_version in job_versions:
                 if _job_version != job_version:
                     _job_version.active = False
         job_version.active = job_update.active
         if job_version.active:
             self._create_views(db_session=db_session,
                                job_version=job_version)
     if job_update.notify_if_results is not None:
         job_version.notify_if_results = job_update.notify_if_results
     db_session.commit()
     db_session.refresh(job_version)
     return job_version
Example no. 16
def invoke_lambda(lambda_name: str, lambda_timeout: int,
                  event: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Invoke an AWS Lambda function

    Args:
        lambda_name: name of lambda
        lambda_timeout: timeout of the lambda. Used to tell the boto3 lambda client to wait
                        at least this long for a response before timing out.
        event: event data to send to the lambda

    Returns:
        lambda response payload

    Raises:
        Exception if there was an error invoking the lambda.
    """
    logger = Logger()
    account_ids = list(event["account_ids"])
    with logger.bind(lambda_name=lambda_name,
                     lambda_timeout=lambda_timeout,
                     account_ids=account_ids):
        logger.info(event=AWSLogEvents.RunAccountScanLambdaStart)
        boto_config = botocore.config.Config(
            read_timeout=lambda_timeout + 10,
            retries={"max_attempts": 0},
        )
        session = boto3.Session()
        lambda_client = session.client("lambda", config=boto_config)
        try:
            resp = lambda_client.invoke(
                FunctionName=lambda_name,
                Payload=json.dumps(event).encode("utf-8"))
        except Exception as invoke_ex:
            error = str(invoke_ex)
            logger.info(event=AWSLogEvents.RunAccountScanLambdaError,
                        error=error)
            raise Exception(
                f"Error while invoking {lambda_name} with event {event}: {error}"
            ) from invoke_ex
        payload: bytes = resp["Payload"].read()
        if resp.get("FunctionError", None):
            function_error = payload.decode()
            logger.info(event=AWSLogEvents.RunAccountScanLambdaError,
                        error=function_error)
            raise Exception(
                f"Function error in {lambda_name} with event {event}: {function_error}"
            )
        payload_dict = json.loads(payload)
        logger.info(event=AWSLogEvents.RunAccountScanLambdaEnd)
        return payload_dict
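
A hedged sketch of a call site; the lambda name is a placeholder, and the event carries only the account_ids key this helper reads when binding its log context.

# Hypothetical call (lambda name assumed): lambda_timeout should match the
# invoked function's configured timeout so the boto3 read_timeout above
# (timeout + 10s) waits long enough for a response.
scan_results = invoke_lambda(
    lambda_name="account-scan",
    lambda_timeout=300,
    event={"account_ids": ["123456789012"]},
)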
Example no. 17
def lambda_handler(event, context):
    rdf_bucket = event["Records"][0]["s3"]["bucket"]["name"]
    rdf_key = urllib.parse.unquote(event["Records"][0]["s3"]["object"]["key"])

    neptune_host = get_required_lambda_env_var("NEPTUNE_HOST")
    neptune_port = get_required_lambda_env_var("NEPTUNE_PORT")
    neptune_region = get_required_lambda_env_var("NEPTUNE_REGION")
    neptune_load_iam_role_arn = get_required_lambda_env_var(
        "NEPTUNE_LOAD_IAM_ROLE_ARN")
    on_success_sns_topic_arn = get_required_lambda_env_var(
        "ON_SUCCESS_SNS_TOPIC_ARN")

    endpoint = NeptuneEndpoint(host=neptune_host,
                               port=neptune_port,
                               region=neptune_region)
    neptune_client = AltimeterNeptuneClient(max_age_min=1440,
                                            neptune_endpoint=endpoint)
    graph_metadata = neptune_client.load_graph(
        bucket=rdf_bucket,
        key=rdf_key,
        load_iam_role_arn=neptune_load_iam_role_arn)

    logger = Logger()
    logger.info(event=LogEvent.GraphLoadedSNSNotificationStart)
    sns_client = boto3.client("sns")
    message_dict = {
        "uri": graph_metadata.uri,
        "name": graph_metadata.name,
        "version": graph_metadata.version,
        "start_time": graph_metadata.start_time,
        "end_time": graph_metadata.end_time,
        "neptune_endpoint": endpoint.get_endpoint_str(),
    }
    message_dict["default"] = json.dumps(message_dict)
    sns_client.publish(TopicArn=on_success_sns_topic_arn,
                       MessageStructure="json",
                       Message=json.dumps(message_dict))
    logger.info(event=LogEvent.GraphLoadedSNSNotificationEnd)
Example no. 18
    def write_artifact(self, name: str, data: Dict[str, Any]) -> str:
        """Write artifact data to s3://self.bucket/self.key_prefix/name.json

        Args:
            name: s3 key name
            data: artifact data

        Returns:
            S3 uri (s3://bucket/key/path) to artifact
        """

        output_key = "/".join((self.key_prefix, f"{name}.json"))
        logger = Logger()
        with logger.bind(bucket=self.bucket, key=output_key):
            logger.info(event=LogEvent.WriteToS3Start)
            s3_client = boto3.client("s3")
            results_str = json.dumps(data, default=json_encoder)
            results_bytes = results_str.encode("utf-8")
            with io.BytesIO(results_bytes) as results_bytes_stream:
                s3_client.upload_fileobj(results_bytes_stream, self.bucket,
                                         output_key)
            logger.info(event=LogEvent.WriteToS3End)
        return f"s3://{self.bucket}/{output_key}"
Example no. 19
    def write_json(self, name: str, data: BaseModel) -> str:
        """Write artifact data to s3://self.bucket/self.key_prefix/name.json

        Args:
            name: s3 key name
            data: data

        Returns:
            S3 uri (s3://bucket/key/path) to artifact
        """

        output_key = "/".join((self.key_prefix, f"{name}.json"))
        logger = Logger()
        with logger.bind(bucket=self.bucket, key=output_key):
            logger.info(event=LogEvent.WriteToS3Start)
            s3_client = boto3.Session().client("s3")
            results_str = data.json(exclude_unset=True)
            results_bytes = results_str.encode("utf-8")
            with io.BytesIO(results_bytes) as results_bytes_stream:
                s3_client.upload_fileobj(results_bytes_stream, self.bucket,
                                         output_key)
            logger.info(event=LogEvent.WriteToS3End)
        return f"s3://{self.bucket}/{output_key}"
Example no. 20
def invoke_lambda(
    lambda_name: str, lambda_timeout: int, account_scan_lambda_event: AccountScanLambdaEvent
) -> AccountScanResult:
    """Invoke the AccountScan AWS Lambda function

    Args:
        lambda_name: name of lambda
        lambda_timeout: timeout of the lambda. Used to tell the boto3 lambda client to wait
                        at least this long for a response before timing out.
        account_scan_lambda_event: AccountScanLambdaEvent object to serialize to json and send to the lambda

    Returns:
        AccountScanResult

    Raises:
        Exception if there was an error invoking the lambda.
    """
    logger = Logger()
    account_id = account_scan_lambda_event.account_scan_plan.account_id
    with logger.bind(lambda_name=lambda_name, lambda_timeout=lambda_timeout, account_id=account_id):
        logger.info(event=AWSLogEvents.RunAccountScanLambdaStart)
        boto_config = botocore.config.Config(
            read_timeout=lambda_timeout + 10, retries={"max_attempts": 0},
        )
        session = boto3.Session()
        lambda_client = session.client("lambda", config=boto_config)
        try:
            resp = lambda_client.invoke(
                FunctionName=lambda_name, Payload=account_scan_lambda_event.json().encode("utf-8")
            )
        except Exception as invoke_ex:
            error = str(invoke_ex)
            logger.info(event=AWSLogEvents.RunAccountScanLambdaError, error=error)
            raise Exception(
                f"Error while invoking {lambda_name} with event {account_scan_lambda_event.json()}: {error}"
            ) from invoke_ex
        payload: bytes = resp["Payload"].read()
        if resp.get("FunctionError", None):
            function_error = payload.decode()
            logger.info(event=AWSLogEvents.RunAccountScanLambdaError, error=function_error)
            raise Exception(
                f"Function error in {lambda_name} with event {account_scan_lambda_event.json()}: {function_error}"
            )
        payload_dict = json.loads(payload)
        account_scan_result = AccountScanResult(**payload_dict)
        logger.info(event=AWSLogEvents.RunAccountScanLambdaEnd)
        return account_scan_result
Example no. 21
 def _create_latest_view(self, db_session: Session,
                         job_version: Job) -> None:
     """Create the _latest view for a Job"""
     logger = Logger()
     latest_view_name = self._get_latest_view_name(
         job_name=job_version.name)
     self._drop_view(db_session=db_session, view_name=latest_view_name)
     create_sql = (
         f"CREATE VIEW {latest_view_name} AS\n"
         f"SELECT result_created, {', '.join(job_version.query_fields)}\n"
         f"FROM\n"
         "(\n"
         f"    SELECT\n"
         "        rs.created as result_created,\n"
         f"        lpad(r.account_id::text, 12, '0') as {self._account_id_key},\n"
     )
     for query_field in job_version.query_fields:
         if query_field != self._account_id_key:
             create_sql += f"        result->>'{query_field}' as {query_field},\n"
     create_sql += (
         f"    RANK () OVER (PARTITION BY r.account_id ORDER BY rs.created DESC) as rank_number\n"
         "    FROM\n"
         "        result r\n"
         "    INNER JOIN result_set rs ON r.result_set_id = rs.id\n"
         "    INNER JOIN job j ON rs.job_id = j.id\n"
         "    WHERE\n"
         f"        j.name = '{job_version.name}' AND\n"
         "        j.active = true AND\n"
         f"        rs.created > CURRENT_TIMESTAMP - INTERVAL '{job_version.max_result_age_sec} seconds'\n"
         ") ranked_query\n"
         "WHERE rank_number = 1\n"
         f"ORDER BY {self._account_id_key};\n")
     logger.info(event=QJLogEvents.CreateView, view_name=latest_view_name)
     db_session.execute(create_sql)
     grant_sql = f"GRANT SELECT ON {latest_view_name} TO {self._db_ro_user};\n"
     db_session.execute(grant_sql)
Example no. 22
    def read_artifact(self, artifact_path: str) -> Dict[str, Any]:
        """Read an artifact

        Args:
            artifact_path: s3 uri to artifact. s3://bucket/key/path

        Returns:
            artifact content
        """
        bucket, key = parse_s3_uri(artifact_path)
        session = boto3.Session()
        s3_client = session.client("s3")
        logger = Logger()
        with io.BytesIO() as artifact_bytes_buf:
            with logger.bind(bucket=bucket, key=key):
                logger.info(event=LogEvent.ReadFromS3Start)
                s3_client.download_fileobj(bucket, key, artifact_bytes_buf)
                artifact_bytes_buf.flush()
                artifact_bytes_buf.seek(0)
                artifact_bytes = artifact_bytes_buf.read()
                logger.info(event=LogEvent.ReadFromS3End)
                artifact_str = artifact_bytes.decode("utf-8")
                artifact_dict = json.loads(artifact_str)
                return artifact_dict
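
A hedged round-trip sketch with the S3 writer from Example no. 18; the store instances and data are assumptions, and only the two method signatures shown in this listing are relied on.

# Hypothetical round trip (store objects assumed): write a dict to S3, then
# read it back via the returned s3://bucket/key uri.
artifact_uri = s3_writer.write_artifact(name="123456789012", data=scan_dict)
scan_dict_roundtrip = s3_reader.read_artifact(artifact_path=artifact_uri)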
Example no. 23
    def scan(self, account_scan_plans: List[AccountScanPlan]) -> List[AccountScanManifest]:
        """Scan accounts. Return a list of AccountScanManifest objects.

        Args:
            account_scan_plans: list of AccountScanPlan objects defining this scan op

        Returns:
            list of AccountScanManifest objects describing the output of the scan.
        """
        account_scan_results: List[AccountScanManifest] = []
        num_total_accounts = len(account_scan_plans)
        num_threads = min(num_total_accounts, self.max_threads)
        logger = Logger()
        with logger.bind(
            num_total_accounts=num_total_accounts,
            muxer=self.__class__.__name__,
            num_threads=num_threads,
        ):
            logger.info(event=AWSLogEvents.MuxerStart)
            with ThreadPoolExecutor(max_workers=num_threads) as executor:
                processed_accounts = 0
                futures = []
                for account_scan_plan in account_scan_plans:
                    account_scan_future = self._schedule_account_scan(executor, account_scan_plan)
                    futures.append(account_scan_future)
                    logger.info(
                        event=AWSLogEvents.MuxerQueueScan, account_id=account_scan_plan.account_id
                    )
                for future in as_completed(futures):
                    scan_results_dict = future.result()
                    account_id = scan_results_dict["account_id"]
                    output_artifact = scan_results_dict["output_artifact"]
                    account_errors = scan_results_dict["errors"]
                    api_call_stats = scan_results_dict["api_call_stats"]
                    artifacts = [output_artifact] if output_artifact else []
                    account_scan_result = AccountScanManifest(
                        account_id=account_id,
                        artifacts=artifacts,
                        errors=account_errors,
                        api_call_stats=api_call_stats,
                    )
                    account_scan_results.append(account_scan_result)
                    processed_accounts += 1
                    logger.info(event=AWSLogEvents.MuxerStat, num_scanned=processed_accounts)
            logger.info(event=AWSLogEvents.MuxerEnd)
        return account_scan_results
Example no. 24
def enqueue_queries(jobs: List[schemas.Job], queue_url: str,
                    execution_hash: str, region: str) -> None:
    """Enqueue queries by sending a message for each job to queue_url"""
    sqs_client = boto3.client("sqs", region_name=region)
    logger = Logger()
    with logger.bind(queue_url=queue_url, execution_hash=execution_hash):
        for job in jobs:
            job_hash = hashlib.sha256()
            job_hash.update(json.dumps(job.json()).encode())
            message_group_id = job_hash.hexdigest()
            job_hash.update(execution_hash.encode())
            message_dedupe_id = job_hash.hexdigest()
            logger.info(
                QJLogEvents.ScheduleJob,
                job=job,
                message_group_id=message_group_id,
                message_dedupe_id=message_dedupe_id,
            )
            sqs_client.send_message(
                QueueUrl=queue_url,
                MessageBody=job.json(),
                MessageGroupId=message_group_id,
                MessageDeduplicationId=message_dedupe_id,
            )
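
A usage sketch under stated assumptions: the queue URL and job list are placeholders, and the MessageGroupId/MessageDeduplicationId parameters imply the target is an SQS FIFO queue; the per-job group id preserves per-job ordering, while folding the execution hash into the dedupe id lets the same job be enqueued again by a later execution.

import uuid

# Hypothetical call (queue URL and jobs assumed): one message per job is sent
# to a FIFO queue, deduplicated per (job, execution) pair.
enqueue_queries(
    jobs=active_jobs,
    queue_url="https://sqs.us-east-1.amazonaws.com/123456789012/qj-queries.fifo",
    execution_hash=str(uuid.uuid4()),  # any per-execution-unique string
    region="us-east-1",
)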
Example no. 25
def lambda_handler(_: Dict[str, Any], __: Any) -> None:
    """Lambda entrypoint"""
    logger = Logger()
    config = PrunerConfig()
    logger.info(event=QJLogEvents.InitConfig, config=config)
    api_key = get_api_key(region_name=config.region)
    qj_client = QJAPIClient(host=config.api_host,
                            port=config.api_port,
                            api_key=api_key)
    logger.info(event=QJLogEvents.DeleteStart)
    result = qj_client.delete_expired_result_sets()
    logger.info(event=QJLogEvents.DeleteEnd, result=result)
Example no. 26
def pruner() -> None:
    """Prune results according to Job config settings"""
    logger = Logger()
    pruner_config = PrunerConfig()
    logger.info(event=QJLogEvents.InitConfig, config=pruner_config)
    api_key = get_api_key(region_name=pruner_config.region)
    qj_client = QJAPIClient(
        host=pruner_config.api_host, port=pruner_config.api_port, api_key=api_key
    )
    logger.info(event=QJLogEvents.DeleteStart)
    result = qj_client.delete_expired_result_sets()
    logger.info(event=QJLogEvents.DeleteEnd, result=result)
Example no. 27
def _invoke_lambda(
    lambda_name: str,
    lambda_timeout: int,
    result: Result,
) -> Any:
    """Invoke a QJ's remediator function"""
    logger = Logger()
    with logger.bind(lambda_name=lambda_name,
                     lambda_timeout=lambda_timeout,
                     result=result):
        logger.info(event=QJLogEvents.InvokeResultRemediationLambdaStart)
        boto_config = botocore.config.Config(
            read_timeout=lambda_timeout + 10,
            retries={"max_attempts": 0},
        )
        session = boto3.Session()
        lambda_client = session.client("lambda", config=boto_config)
        event = result.json().encode("utf-8")
        try:
            resp = lambda_client.invoke(
                FunctionName=lambda_name,
                Payload=event,
            )
        except Exception as invoke_ex:
            error = str(invoke_ex)
            logger.info(event=QJLogEvents.InvokeResultRemediationLambdaError,
                        error=error)
            raise Exception(
                f"Error while invoking {lambda_name} with event: {str(event)}: {error}"
            ) from invoke_ex
        lambda_result: bytes = resp["Payload"].read()
        if resp.get("FunctionError", None):
            error = lambda_result.decode()
            logger.info(event=QJLogEvents.ResultRemediationLambdaRunError,
                        error=error)
            raise Exception(
                f"Function error in {lambda_name} with event {str(event)}: {error}"
            )
        logger.info(event=QJLogEvents.InvokeResultRemediationLambdaEnd)
        return json.loads(lambda_result)
Example no. 28
 def _drop_view(self, db_session: Session, view_name: str) -> None:
     """Drop a view by name"""
     logger = Logger()
     logger.info(event=QJLogEvents.DropView, view_name=view_name)
     drop_sql = f"DROP VIEW IF EXISTS {view_name};"
     db_session.execute(drop_sql)
Example no. 29
def query(event: Dict[str, Any]) -> None:
    """Run the query portion of a QJ"""
    query_config = QueryConfig()
    logger = Logger()
    logger.info(event=QJLogEvents.InitConfig, config=query_config)

    records = event.get("Records", [])
    if not records:
        raise Exception("No records found")
    if len(records) > 1:
        raise Exception(
            f"More than one record. BatchSize is probably not 1. event: {event}"
        )
    body = records[0].get("body")
    if body is None:
        raise Exception(
            f"No record body found. BatchSize is probably not 1. event: {event}"
        )
    body = json.loads(body)
    job = schemas.Job(**body)
    logger.info(event=QJLogEvents.InitJob, job=job)

    logger.info(event=QJLogEvents.RunQueryStart)
    query_result = run_query(job=job, config=query_config)
    logger.info(event=QJLogEvents.RunQueryEnd,
                num_results=query_result.get_length())

    results: List[schemas.Result] = []
    if query_config.account_id_key not in query_result.query_result_set.fields:
        raise Exception(
            f"Query results must contain field '{query_config.account_id_key}'"
        )
    for q_r in query_result.to_list():
        account_id = q_r[query_config.account_id_key]
        result = schemas.Result(
            account_id=account_id,
            result={
                key: val
                for key, val in q_r.items()
                if key != query_config.account_id_key
            },
        )
        results.append(result)

    graph_spec = schemas.ResultSetGraphSpec(
        graph_uris_load_times=query_result.graph_uris_load_times)
    result_set = schemas.ResultSetCreate(job=job,
                                         graph_spec=graph_spec,
                                         results=results)

    api_key = get_api_key(region_name=query_config.region)
    qj_client = QJAPIClient(host=query_config.api_host,
                            port=query_config.api_port,
                            api_key=api_key)
    logger.info(event=QJLogEvents.CreateResultSetStart)
    qj_client.create_result_set(result_set=result_set)
    logger.info(event=QJLogEvents.CreateResultSetEnd)
Example no. 30
def lambda_handler(event, context):
    host = get_required_lambda_env_var("NEPTUNE_HOST")
    port = get_required_lambda_env_var("NEPTUNE_PORT")
    region = get_required_lambda_env_var("NEPTUNE_REGION")
    max_age_min = get_required_lambda_env_var("MAX_AGE_MIN")
    graph_name = get_required_lambda_env_var("GRAPH_NAME")
    try:
        max_age_min = int(max_age_min)
    except ValueError as ve:
        raise Exception(f"env var MAX_AGE_MIN must be an int: {ve}") from ve
    now = int(datetime.now().timestamp())
    oldest_acceptable_graph_epoch = now - max_age_min * 60

    endpoint = NeptuneEndpoint(host=host, port=port, region=region)
    client = AltimeterNeptuneClient(max_age_min=max_age_min,
                                    neptune_endpoint=endpoint)
    logger = Logger()

    uncleared = []

    # prune metadata first - if the clears below are partial we want to make sure
    # no clients still consider this a valid graph.
    logger.info(event=LogEvent.PruneNeptuneMetadataGraphStart)
    client.clear_old_graph_metadata(name=graph_name, max_age_min=max_age_min)
    logger.info(event=LogEvent.PruneNeptuneMetadataGraphEnd)
    # now clear actual graphs
    with logger.bind(neptune_endpoint=endpoint):
        logger.info(event=LogEvent.PruneNeptuneGraphsStart)
        for graph_metadata in client.get_graph_metadatas(name=graph_name):
            assert graph_metadata.name == graph_name
            graph_epoch = graph_metadata.end_time
            with logger.bind(graph_uri=graph_metadata.uri,
                             graph_epoch=graph_epoch):
                if graph_epoch < oldest_acceptable_graph_epoch:
                    logger.info(event=LogEvent.PruneNeptuneGraphStart)
                    try:
                        client.clear_graph(graph_uri=graph_metadata.uri)
                        logger.info(event=LogEvent.PruneNeptuneGraphEnd)
                    except Exception as ex:
                        logger.error(
                            event=LogEvent.PruneNeptuneGraphError,
                            msg=f"Error pruning graph {graph_metadata.uri}: {ex}",
                        )
                        uncleared.append(graph_metadata.uri)
                        continue
                else:
                    logger.info(event=LogEvent.PruneNeptuneGraphSkip)
        logger.info(event=LogEvent.PruneNeptuneGraphsEnd)
        if uncleared:
            msg = f"Errors were found pruning {uncleared}."
            logger.error(event=LogEvent.PruneNeptuneGraphsError, msg=msg)
            raise Exception(msg)