    def read_json(self, path: str) -> Dict[str, Any]:
        """Read a JSON artifact from S3.

        Args:
            path: S3 URI of the artifact, e.g. s3://bucket/key/path

        Returns:
            the artifact content as a dict
        """
        bucket, key = parse_s3_uri(path)
        if key is None:
            raise ValueError(f"Unable to read from s3 uri, missing key: {path}")
        session = boto3.Session()
        s3_client = session.client("s3")
        logger = Logger()
        with logger.bind(bucket=bucket, key=key):
            with io.BytesIO() as artifact_bytes_buf:
                logger.info(event=LogEvent.ReadFromS3Start)
                s3_client.download_fileobj(bucket, key, artifact_bytes_buf)
                artifact_bytes_buf.flush()
                artifact_bytes_buf.seek(0)
                artifact_bytes = artifact_bytes_buf.read()
                logger.info(event=LogEvent.ReadFromS3End)
                artifact_str = artifact_bytes.decode("utf-8")
                artifact_dict = json.loads(artifact_str)
                return artifact_dict
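
The parse_s3_uri helper is not shown here; below is a minimal sketch of what it plausibly does, splitting an s3://bucket/key/path URI into a (bucket, key) pair (the real implementation may differ):

# Sketch only: assumes parse_s3_uri splits the URI with urllib; key is
# None when the URI names only a bucket, matching the check in read_json.
from typing import Optional, Tuple
from urllib.parse import urlparse

def parse_s3_uri(uri: str) -> Tuple[str, Optional[str]]:
    parsed = urlparse(uri)
    if parsed.scheme != "s3":
        raise ValueError(f"Not an s3 uri: {uri}")
    key = parsed.path.lstrip("/")
    return parsed.netloc, key if key else None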
Example #2
    def scan(
        self, account_scan_plan: AccountScanPlan
    ) -> Generator[AccountScanManifest, None, None]:
        """Scan accounts. Return a list of AccountScanManifest objects.

        Args:
            account_scan_plan: AccountScanPlan defining this scan op

        Yields:
            AccountScanManifest objects
        """
        num_total_accounts = len(account_scan_plan.account_ids)
        account_scan_plans = account_scan_plan.to_batches(
            max_accounts=self.config.concurrency.max_accounts_per_thread)
        num_account_batches = len(account_scan_plans)
        num_threads = min(num_account_batches,
                          self.config.concurrency.max_account_scan_threads)
        logger = Logger()
        with logger.bind(
                num_total_accounts=num_total_accounts,
                num_account_batches=num_account_batches,
                muxer=self.__class__.__name__,
                num_muxer_threads=num_threads,
        ):
            logger.info(event=AWSLogEvents.MuxerStart)
            with ThreadPoolExecutor(max_workers=num_threads) as executor:
                processed_accounts = 0
                futures = []
                for sub_account_scan_plan in account_scan_plans:
                    account_scan_future = self._schedule_account_scan(
                        executor, sub_account_scan_plan)
                    futures.append(account_scan_future)
                    logger.info(
                        event=AWSLogEvents.MuxerQueueScan,
                        account_ids=",".join(
                            sub_account_scan_plan.account_ids),
                    )
                for future in as_completed(futures):
                    scan_results_dicts = future.result()
                    for scan_results_dict in scan_results_dicts:
                        account_id = scan_results_dict["account_id"]
                        output_artifact = scan_results_dict["output_artifact"]
                        account_errors = scan_results_dict["errors"]
                        api_call_stats = scan_results_dict["api_call_stats"]
                        artifacts = [output_artifact] if output_artifact else []
                        account_scan_result = AccountScanManifest(
                            account_id=account_id,
                            artifacts=artifacts,
                            errors=account_errors,
                            api_call_stats=api_call_stats,
                        )
                        yield account_scan_result
                        processed_accounts += 1
                    logger.info(event=AWSLogEvents.MuxerStat,
                                num_scanned=processed_accounts)
            logger.info(event=AWSLogEvents.MuxerEnd)
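
AccountScanPlan.to_batches drives the batching above; here is a simplified sketch under the assumption that a plan mainly carries account_ids (the real class almost certainly has more fields, e.g. regions and credentials):

# Simplified sketch; not the real AccountScanPlan, which carries more state.
from dataclasses import dataclass
from typing import List

@dataclass(frozen=True)
class AccountScanPlan:
    account_ids: List[str]

    def to_batches(self, max_accounts: int) -> List["AccountScanPlan"]:
        # Chunk account_ids into sub-plans of at most max_accounts each;
        # the muxer then fans these out across scan threads.
        return [
            AccountScanPlan(account_ids=self.account_ids[i : i + max_accounts])
            for i in range(0, len(self.account_ids), max_accounts)
        ]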
Example #3
def remediator(event: Dict[str, Any]) -> None:
    """Run the remediation lambda for a QJ result set"""
    config = RemediatorConfig()
    logger = Logger()
    remediation = Remediation(**event)
    with logger.bind(remediation=remediation):
        logger.info(event=QJLogEvents.RemediationInit)
        qj_api_client = QJAPIClient(host=config.qj_api_host)
        latest_result_set = qj_api_client.get_job_latest_result_set(
            job_name=remediation.job_name)
        if not latest_result_set:
            msg = f"No latest_result_set present for {remediation.job_name}"
            logger.error(QJLogEvents.StaleResultSet, detail=msg)
            raise RemediationError(msg)
        if latest_result_set.result_set_id != remediation.result_set_id:
            msg = (
                f"Remediation result_set_id {remediation.result_set_id} does not match the "
                f"latest result_set_id {latest_result_set.result_set_id}")
            logger.error(QJLogEvents.StaleResultSet, detail=msg)
            raise RemediationError(msg)
        if not latest_result_set.job.remediate_sqs_queue:
            msg = f"Job {latest_result_set.job.name} has no remediator"
            logger.error(QJLogEvents.JobHasNoRemediator, detail=msg)
            raise RemediationError(msg)
        num_threads = 10  # TODO env var
        errors = []
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = []
            for result in latest_result_set.results:
                logger.info(event=QJLogEvents.ProcessResult, result=result)
                future = _schedule_result_remediation(
                    executor=executor,
                    lambda_name=latest_result_set.job.remediate_sqs_queue,
                    lambda_timeout=300,  # TODO env var?
                    result=result,
                )
                futures.append(future)
            for future in as_completed(futures):
                try:
                    lambda_result = future.result()
                    logger.info(QJLogEvents.ResultRemediationSuccessful,
                                lambda_result=lambda_result)
                except Exception as ex:
                    logger.info(
                        event=QJLogEvents.ResultSetRemediationFailed,
                        error=str(ex),
                    )
                    errors.append(str(ex))
        if errors:
            logger.error(event=QJLogEvents.ResultSetRemediationFailed,
                         errors=errors)
            raise RemediationError(
                f"Errors encountered during remediation of {latest_result_set.job.name} "
                f"/ {latest_result_set.result_set_id}: {errors}")
Example #4
    def list_from_aws(cls: Type["S3BucketResourceSpec"], client: BaseClient,
                      account_id: str, region: str) -> ListFromAWSResult:
        """Return a dict of dicts of the format:

            {'bucket_1_arn': {bucket_1_dict},
             'bucket_2_arn': {bucket_2_dict},
             ...}

        Where the dicts represent results from list_buckets."""
        logger = Logger()
        buckets = {}
        buckets_resp = client.list_buckets()
        for bucket in buckets_resp.get("Buckets", []):
            bucket_name = bucket["Name"]
            try:
                try:
                    bucket_region = get_s3_bucket_region(client, bucket_name)
                except S3BucketAccessDeniedException as s3ade:
                    logger.warn(
                        event=AWSLogEvents.ScanAWSResourcesNonFatalError,
                        msg=f"Unable to determine region for {bucket_name}: {s3ade}",
                    )
                    continue
                try:
                    bucket["Tags"] = get_s3_bucket_tags(client, bucket_name)
                except S3BucketAccessDeniedException as s3ade:
                    bucket["Tags"] = []
                    logger.warn(
                        event=AWSLogEvents.ScanAWSResourcesNonFatalError,
                        msg=f"Unable to determine tags for {bucket_name}: {s3ade}",
                    )
                try:
                    bucket["ServerSideEncryption"] = get_s3_bucket_encryption(
                        client, bucket_name)
                except S3BucketAccessDeniedException as s3ade:
                    bucket["ServerSideEncryption"] = {"Rules": []}
                    logger.warn(
                        event=AWSLogEvents.ScanAWSResourcesNonFatalError,
                        msg=f"Unable to determine encryption status for {bucket_name}: {s3ade}",
                    )
                resource_arn = cls.generate_arn(account_id=account_id,
                                                region=bucket_region,
                                                resource_id=bucket_name)
                buckets[resource_arn] = bucket
            except S3BucketDoesNotExistException as s3bdnee:
                logger.warn(
                    event=AWSLogEvents.ScanAWSResourcesNonFatalError,
                    msg=f"{bucket_name}: No longer exists: {s3bdnee}",
                )
        return ListFromAWSResult(resources=buckets)
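
get_s3_bucket_region, used above, is plausibly a thin wrapper over the S3 GetBucketLocation API; a sketch follows, assuming the access-denied translation matches the S3BucketAccessDeniedException handled by list_from_aws (imported from the same module as the example):

# Sketch only: the real helper and its error handling may differ.
from botocore.client import BaseClient
from botocore.exceptions import ClientError

def get_s3_bucket_region(client: BaseClient, bucket_name: str) -> str:
    try:
        resp = client.get_bucket_location(Bucket=bucket_name)
    except ClientError as c_err:
        if c_err.response.get("Error", {}).get("Code") == "AccessDenied":
            raise S3BucketAccessDeniedException(str(c_err)) from c_err
        raise
    # GetBucketLocation returns a null LocationConstraint for us-east-1.
    return resp.get("LocationConstraint") or "us-east-1"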
Example #5
def graph_set_from_s3(s3_client: BaseClient, json_bucket: str,
                      json_key: str) -> GraphSet:
    """Load a GraphSet from json located in an s3 object."""
    logger = Logger()
    logger.info(event=LogEvent.ReadFromS3Start)
    with io.BytesIO() as json_bytes_buf:
        s3_client.download_fileobj(json_bucket, json_key, json_bytes_buf)
        json_bytes_buf.flush()
        json_bytes_buf.seek(0)
        graph_set_bytes = json_bytes_buf.read()
        logger.info(event=LogEvent.ReadFromS3End)
    graph_set_str = graph_set_bytes.decode("utf-8")
    graph_set_dict = json.loads(graph_set_str)
    return GraphSet.from_dict(graph_set_dict)
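
A minimal usage sketch; the bucket and key names are placeholders:

import boto3

# graph_set_from_s3 downloads the object and deserializes it via
# GraphSet.from_dict.
s3_client = boto3.client("s3")
graph_set = graph_set_from_s3(
    s3_client=s3_client,
    json_bucket="example-graph-bucket",
    json_key="scans/example/graph.json",
)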