Example #1
    def get_session(self,
                    account_id: str,
                    region: Optional[str] = None) -> boto3.Session:
        """Get a boto3 session for a given account.

        Args:
            account_id: target account id
            region: session region

        Returns:
            boto3.Session object
        """
        logger = Logger()
        with logger.bind(auth_account_id=account_id):
            if self.multi_hop_accessors:
                for mha in self.multi_hop_accessors:  # pylint: disable=not-an-iterable
                    with logger.bind(auth_accessor=str(mha)):
                        try:
                            session = mha.get_session(account_id=account_id,
                                                      region=region)
                            return session
                        except Exception as ex:
                            logger.debug(event=LogEvent.AuthToAccountFailure,
                                         exception=str(ex))

                raise AccountAuthException(
                    f"Unable to access {account_id} using {str(self)}")
            # local run mode
            session = boto3.Session(region_name=region)
            sts_client = session.client("sts")
            sts_account_id = sts_client.get_caller_identity()["Account"]
            if sts_account_id != account_id:
                raise ValueError(
                    f"BUG: sts_account_id {sts_account_id} != {account_id}")
            return session
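A minimal usage sketch for the accessor above. The accessor instance and account id are hypothetical; only get_session itself comes from this example:

# Hypothetical caller: resolve a session for a target account and list its S3 buckets.
session = accessor.get_session(account_id="123456789012", region="us-east-1")
s3_client = session.client("s3")
bucket_names = [b["Name"] for b in s3_client.list_buckets()["Buckets"]]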
Example #2
def lambda_handler(event, context):
    host = get_required_lambda_env_var("NEPTUNE_HOST")
    port = get_required_lambda_env_var("NEPTUNE_PORT")
    region = get_required_lambda_env_var("NEPTUNE_REGION")
    max_age_min = get_required_lambda_env_var("MAX_AGE_MIN")
    graph_name = get_required_lambda_env_var("GRAPH_NAME")
    try:
        max_age_min = int(max_age_min)
    except ValueError as ve:
        raise Exception(f"env var MAX_AGE_MIN must be an int: {ve}") from ve
    now = int(datetime.now().timestamp())
    oldest_acceptable_graph_epoch = now - max_age_min * 60

    endpoint = NeptuneEndpoint(host=host, port=port, region=region)
    client = AltimeterNeptuneClient(max_age_min=max_age_min,
                                    neptune_endpoint=endpoint)
    logger = Logger()

    uncleared = []

    # Prune metadata first - if the clears below are partial we want to make sure
    # no clients still consider this a valid graph.
    logger.info(event=LogEvent.PruneNeptuneMetadataGraphStart)
    client.clear_old_graph_metadata(name=graph_name, max_age_min=max_age_min)
    logger.info(event=LogEvent.PruneNeptuneMetadataGraphEnd)
    # now clear actual graphs
    with logger.bind(neptune_endpoint=endpoint):
        logger.info(event=LogEvent.PruneNeptuneGraphsStart)
        for graph_metadata in client.get_graph_metadatas(name=graph_name):
            assert graph_metadata.name == graph_name
            graph_epoch = graph_metadata.end_time
            with logger.bind(graph_uri=graph_metadata.uri,
                             graph_epoch=graph_epoch):
                if graph_epoch < oldest_acceptable_graph_epoch:
                    logger.info(event=LogEvent.PruneNeptuneGraphStart)
                    try:
                        client.clear_graph(graph_uri=graph_metadata.uri)
                        logger.info(event=LogEvent.PruneNeptuneGraphEnd)
                    except Exception as ex:
                        logger.error(
                            event=LogEvent.PruneNeptuneGraphError,
                            msg=f"Error pruning graph {graph_metadata.uri}: {ex}",
                        )
                        uncleared.append(graph_metadata.uri)
                        continue
                else:
                    logger.info(event=LogEvent.PruneNeptuneGraphSkip)
        logger.info(event=LogEvent.PruneNeptuneGraphsEnd)
        if uncleared:
            msg = f"Errors were found pruning {uncleared}."
            logger.error(event=LogEvent.PruneNeptuneGraphsError, msg=msg)
            raise Exception(msg)
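The handler above reads its configuration from Lambda environment variables; a sketch of the expected environment, with placeholder values:

# Hypothetical environment for the pruner lambda (values are placeholders).
environment = {
    "NEPTUNE_HOST": "neptune.example.com",
    "NEPTUNE_PORT": "8182",
    "NEPTUNE_REGION": "us-east-1",
    "MAX_AGE_MIN": "1440",  # prune graphs older than 24 hours
    "GRAPH_NAME": "alti",
}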
Example #3
    def write_graph_set(self,
                        name: str,
                        graph_set: GraphSet,
                        compression: Optional[str] = None) -> str:
        """Write a graph artifact

        Args:
            name: name
            graph_set: GraphSet object to write
            compression: optional compression type, e.g. GZIP

        Returns:
            path to written artifact
        """
        logger = Logger()
        os.makedirs(self.output_dir, exist_ok=True)
        if compression is None:
            artifact_path = os.path.join(self.output_dir, f"{name}.rdf")
        elif compression == GZIP:
            artifact_path = os.path.join(self.output_dir, f"{name}.rdf.gz")
        else:
            raise ValueError(f"Unknown compression arg {compression}")
        graph = graph_set.to_rdf()
        with logger.bind(artifact_path=artifact_path):
            logger.info(event=LogEvent.WriteToFSStart)
            with open(artifact_path, "wb") as fp:
                if compression is None:
                    graph.serialize(fp)
                elif compression == GZIP:
                    with gzip.GzipFile(fileobj=fp, mode="wb") as gz:
                        graph.serialize(gz)
                else:
                    raise ValueError(f"Unknown compression arg {compression}")
            logger.info(event=LogEvent.WriteToFSEnd)
        return artifact_path
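A sketch of calling this writer, assuming writer is an instance of the class defining write_graph_set, graph_set is an already-built GraphSet, and GZIP is the module-level constant checked above:

# Hypothetical call: persist a GraphSet as gzipped RDF and capture the artifact path.
artifact_path = writer.write_graph_set(name="master", graph_set=graph_set, compression=GZIP)
# artifact_path now looks like "<output_dir>/master.rdf.gz"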
Example #4
def scan_services(
    graph_name: str,
    graph_version: str,
    account_id: str,
    region: str,
    service: str,
    access_key: str,
    secret_key: str,
    token: str,
    resource_spec_classes: Tuple[Type[AWSResourceSpec], ...],
) -> Dict[str, Any]:
    logger = Logger()
    with logger.bind(region=region, service=service):
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceStart)
        session = boto3.Session(
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
            aws_session_token=token,
            region_name=region,
        )
        aws_accessor = AWSAccessor(session=session,
                                   account_id=account_id,
                                   region_name=region)
        graph_spec = GraphSpec(
            name=graph_name,
            version=graph_version,
            resource_spec_classes=resource_spec_classes,
            scan_accessor=aws_accessor,
        )
        graph_set = graph_spec.scan()
        graph_set_dict = graph_set.to_dict()
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceEnd)
        return graph_set_dict
Example #5
 def lambda_handler(cls, event: Dict[str, Any], _: Any) -> None:
     """lambda entrypoint"""
     config = Config()
     result = Result(**event)
     logger = Logger()
     errors: List[str] = []
     with logger.bind(result=result):
         logger.info(event=QJLogEvents.ResultRemediationStart)
         try:
             session = get_assumed_session(
                 account_id=result.account_id,
                 role_name=config.remediator_target_role_name,
                 external_id=config.remediator_target_role_external_id,
             )
             cls.remediate(session=session,
                           result=result.result,
                           dry_run=config.dry_run)
             logger.info(event=QJLogEvents.ResultRemediationSuccessful)
         except Exception as ex:
             logger.error(event=QJLogEvents.ResultRemediationFailed,
                          error=str(ex))
             errors.append(str(ex))
     if errors:
         raise RemediationError(
             f"Errors found during remediation: {errors}")
Example #6
    def scan(self) -> GraphSet:
        """Perform a scan on all of the resource classes in this GraphSpec and return
        a GraphSet containing the scanned data.

        Returns:
            GraphSet representing results of scanning this GraphSpec's resource_spec_classes.
        """

        resources: List[Resource] = []
        errors: List[str] = []
        stats = MultilevelCounter()
        start_time = int(time.time())
        logger = Logger()
        for resource_spec_class in self.resource_spec_classes:
            with logger.bind(resource_type=str(resource_spec_class.type_name)):
                logger.debug(event=LogEvent.ScanResourceTypeStart)
                resource_scan_result = resource_spec_class.scan(
                    scan_accessor=self.scan_accessor)
                resources += resource_scan_result.resources
                errors += resource_scan_result.errors
                stats.merge(resource_scan_result.stats)
                logger.debug(event=LogEvent.ScanResourceTypeEnd)
        end_time = int(time.time())
        return GraphSet(
            name=self.name,
            version=self.version,
            start_time=start_time,
            end_time=end_time,
            resources=resources,
            errors=errors,
            stats=stats,
        )
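Example #4 shows how the GraphSpec consumed by this method is constructed; a condensed sketch of the two together, where the account id, region, graph name/version, and the resource_spec_classes tuple are placeholders:

# Hypothetical wiring: build a GraphSpec over an AWSAccessor and scan it.
import boto3

aws_accessor = AWSAccessor(
    session=boto3.Session(), account_id="123456789012", region_name="us-east-1"
)
graph_spec = GraphSpec(
    name="alti",
    version="1",
    resource_spec_classes=resource_spec_classes,  # placeholder tuple of AWSResourceSpec subclasses
    scan_accessor=aws_accessor,
)
graph_set = graph_spec.scan()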
Example #7
def invoke_lambda(lambda_name: str, lambda_timeout: int, event: Dict[str, Any]) -> Dict[str, Any]:
    """Invoke an AWS Lambda function

    Args:
        lambda_name: name of lambda
        lambda_timeout: timeout of the lambda. Used to tell the boto3 lambda client to wait
                        at least this long for a response before timing out.
        event: event data to send to the lambda

    Returns:
        lambda response payload

    Raises:
        Exception if there was an error invoking the lambda.
    """
    logger = Logger()
    with logger.bind(lambda_name=lambda_name, lambda_timeout=lambda_timeout, event=event):
        logger.info(event=AWSLogEvents.RunAccountScanLambdaStart)
        boto_config = botocore.config.Config(
            read_timeout=lambda_timeout + 10, retries={"max_attempts": 0}
        )
        session = boto3.Session()
        lambda_client = session.client("lambda", config=boto_config)
        resp = lambda_client.invoke(
            FunctionName=lambda_name, Payload=json.dumps(event).encode("utf-8")
        )
        payload: bytes = resp["Payload"].read()
        if resp.get("FunctionError", None):
            raise Exception(f"Error invoking {lambda_name} with event {event}: {payload}")
        payload_dict = json.loads(payload)
        logger.info(event=AWSLogEvents.RunAccountScanLambdaEnd)
        return payload_dict
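A brief invocation sketch for invoke_lambda above; the function name and event payload are made up for illustration:

# Hypothetical invocation: call a lambda named "account-scan" and inspect the payload it returns.
response = invoke_lambda(
    lambda_name="account-scan",
    lambda_timeout=300,
    event={"account_id": "123456789012"},
)
print(response)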
Example #8
    def read_json(self, path: str) -> Dict[str, Any]:
        """Read a json artifact

        Args:
            path: s3 uri to artifact. s3://bucket/key/path

        Returns:
            artifact content
        """
        bucket, key = parse_s3_uri(path)
        if key is None:
            raise ValueError(f"Unable to read from s3 uri missing key: {path}")
        session = boto3.Session()
        s3_client = session.client("s3")
        logger = Logger()
        with logger.bind(bucket=bucket, key=key):
            with io.BytesIO() as artifact_bytes_buf:
                logger.info(event=LogEvent.ReadFromS3Start)
                s3_client.download_fileobj(bucket, key, artifact_bytes_buf)
                artifact_bytes_buf.flush()
                artifact_bytes_buf.seek(0)
                artifact_bytes = artifact_bytes_buf.read()
                logger.info(event=LogEvent.ReadFromS3End)
                artifact_str = artifact_bytes.decode("utf-8")
                artifact_dict = json.loads(artifact_str)
                return artifact_dict
Example #9
def lambda_handler(event, context):
    json_bucket = event["Records"][0]["s3"]["bucket"]["name"]
    json_key = urllib.parse.unquote(event["Records"][0]["s3"]["object"]["key"])
    rdf_bucket = get_required_lambda_env_var("RDF_BUCKET")
    rdf_key = ".".join(json_key.split(".")[:-1]) + ".rdf.gz"
    session = boto3.Session()
    s3_client = session.client("s3")

    logger = Logger()
    with logger.bind(json_bucket=json_bucket, json_key=json_key):
        graph_pkg = graph_pkg_from_s3(s3_client=s3_client,
                                      json_bucket=json_bucket,
                                      json_key=json_key)

    with logger.bind(rdf_bucket=rdf_bucket, rdf_key=rdf_key):
        logger.info(event=LogEvent.WriteToS3Start)
        with io.BytesIO() as rdf_bytes_buf:
            with gzip.GzipFile(fileobj=rdf_bytes_buf, mode="wb") as gz:
                graph_pkg.graph.serialize(gz)
            rdf_bytes_buf.flush()
            rdf_bytes_buf.seek(0)
            s3_client.upload_fileobj(rdf_bytes_buf, rdf_bucket, rdf_key)
            s3_client.put_object_tagging(
                Bucket=rdf_bucket,
                Key=rdf_key,
                Tagging={
                    "TagSet": [
                        {
                            "Key": "name",
                            "Value": graph_pkg.name
                        },
                        {
                            "Key": "version",
                            "Value": graph_pkg.version
                        },
                        {
                            "Key": "start_time",
                            "Value": str(graph_pkg.start_time)
                        },
                        {
                            "Key": "end_time",
                            "Value": str(graph_pkg.end_time)
                        },
                    ]
                },
            )
        logger.info(event=LogEvent.WriteToS3End)
Example #10
    def scan(
        self, account_scan_plan: AccountScanPlan
    ) -> Generator[AccountScanManifest, None, None]:
        """Scan accounts. Return a list of AccountScanManifest objects.

        Args:
            account_scan_plan: AccountScanPlan defining this scan op

        Yields:
            AccountScanManifest objects
        """
        num_total_accounts = len(account_scan_plan.account_ids)
        account_scan_plans = account_scan_plan.to_batches(
            max_accounts=self.config.concurrency.max_accounts_per_thread)
        num_account_batches = len(account_scan_plans)
        num_threads = min(num_account_batches,
                          self.config.concurrency.max_account_scan_threads)
        logger = Logger()
        with logger.bind(
                num_total_accounts=num_total_accounts,
                num_account_batches=num_account_batches,
                muxer=self.__class__.__name__,
                num_muxer_threads=num_threads,
        ):
            logger.info(event=AWSLogEvents.MuxerStart)
            with ThreadPoolExecutor(max_workers=num_threads) as executor:
                processed_accounts = 0
                futures = []
                for sub_account_scan_plan in account_scan_plans:
                    account_scan_future = self._schedule_account_scan(
                        executor, sub_account_scan_plan)
                    futures.append(account_scan_future)
                    logger.info(
                        event=AWSLogEvents.MuxerQueueScan,
                        account_ids=",".join(
                            sub_account_scan_plan.account_ids),
                    )
                for future in as_completed(futures):
                    scan_results_dicts = future.result()
                    for scan_results_dict in scan_results_dicts:
                        account_id = scan_results_dict["account_id"]
                        output_artifact = scan_results_dict["output_artifact"]
                        account_errors = scan_results_dict["errors"]
                        api_call_stats = scan_results_dict["api_call_stats"]
                        artifacts = [output_artifact] if output_artifact else []
                        account_scan_result = AccountScanManifest(
                            account_id=account_id,
                            artifacts=artifacts,
                            errors=account_errors,
                            api_call_stats=api_call_stats,
                        )
                        yield account_scan_result
                        processed_accounts += 1
                    logger.info(event=AWSLogEvents.MuxerStat,
                                num_scanned=processed_accounts)
            logger.info(event=AWSLogEvents.MuxerEnd)
Example #11
def scan_scan_unit(scan_unit: ScanUnit) -> Tuple[str, Dict[str, Any]]:
    logger = Logger()
    with logger.bind(
            account_id=scan_unit.account_id,
            region=scan_unit.region_name,
            service=scan_unit.service,
            resource_classes=sorted([
                resource_spec_class.__name__
                for resource_spec_class in scan_unit.resource_spec_classes
            ]),
    ):
        start_t = time.time()
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceStart)
        session = boto3.Session(
            aws_access_key_id=scan_unit.access_key,
            aws_secret_access_key=scan_unit.secret_key,
            aws_session_token=scan_unit.token,
            region_name=scan_unit.region_name,
        )
        scan_accessor = AWSAccessor(session=session,
                                    account_id=scan_unit.account_id,
                                    region_name=scan_unit.region_name)
        graph_spec = GraphSpec(
            name=scan_unit.graph_name,
            version=scan_unit.graph_version,
            resource_spec_classes=scan_unit.resource_spec_classes,
            scan_accessor=scan_accessor,
        )
        start_time = int(time.time())
        resources: List[Resource] = []
        errors = []
        try:
            resources = graph_spec.scan()
        except Exception as ex:
            error_str = str(ex)
            trace_back = traceback.format_exc()
            logger.error(event=AWSLogEvents.ScanAWSAccountError,
                         error=error_str,
                         trace_back=trace_back)
            error = f"{str(ex)}\n{trace_back}"
            errors.append(error)
        end_time = int(time.time())
        graph_set = GraphSet(
            name=scan_unit.graph_name,
            version=scan_unit.graph_version,
            start_time=start_time,
            end_time=end_time,
            resources=resources,
            errors=errors,
            stats=scan_accessor.api_call_stats,
        )
        end_t = time.time()
        elapsed_sec = end_t - start_t
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceEnd,
                    elapsed_sec=elapsed_sec)
        return (scan_unit.account_id, graph_set.to_dict())
Example #12
def remediator(event: Dict[str, Any]) -> None:
    """Run the remediation lambda for a QJ result set"""
    config = RemediatorConfig()
    logger = Logger()
    remediation = Remediation(**event)
    with logger.bind(remediation=remediation):
        logger.info(event=QJLogEvents.RemediationInit)
        qj_api_client = QJAPIClient(host=config.qj_api_host)
        latest_result_set = qj_api_client.get_job_latest_result_set(
            job_name=remediation.job_name)
        if not latest_result_set:
            msg = f"No latest_result_set present for {remediation.job_name}"
            logger.error(QJLogEvents.StaleResultSet, detail=msg)
            raise RemediationError(msg)
        if latest_result_set.result_set_id != remediation.result_set_id:
            msg = (
                f"Remediation result_set_id {remediation.result_set_id} does not match the "
                f"latest result_set_id {latest_result_set.result_set_id}")
            logger.error(QJLogEvents.StaleResultSet, detail=msg)
            raise RemediationError(msg)
        if not latest_result_set.job.remediate_sqs_queue:
            msg = f"Job {latest_result_set.job.name} has no remediator"
            logger.error(QJLogEvents.JobHasNoRemediator, detail=msg)
            raise RemediationError(msg)
        num_threads = 10  # TODO env var
        errors = []
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = []
            for result in latest_result_set.results:
                logger.info(event=QJLogEvents.ProcessResult, result=result)
                future = _schedule_result_remediation(
                    executor=executor,
                    lambda_name=latest_result_set.job.remediate_sqs_queue,
                    lambda_timeout=300,  # TODO env var?
                    result=result,
                )
                futures.append(future)
            for future in as_completed(futures):
                try:
                    lambda_result = future.result()
                    logger.info(QJLogEvents.ResultRemediationSuccessful,
                                lambda_result=lambda_result)
                except Exception as ex:
                    logger.info(
                        event=QJLogEvents.ResultSetRemediationFailed,
                        error=str(ex),
                    )
                    errors.append(str(ex))
        if errors:
            logger.error(event=QJLogEvents.ResultSetRemediationFailed,
                         errors=errors)
            raise RemediationError(
                f"Errors encountered during remediation of {latest_result_set.job.name} "
                f"/ {latest_result_set.result_set_id}: {errors}")
Example #13
 def notify(self, notification: schemas.ResultSetNotification) -> None:
     logger = Logger()
     with logger.bind(notification=notification):
         logger.info(event=QJLogEvents.NotifyNewResultsStart)
         session = boto3.Session(region_name=self.region_name)
         sns_client = session.client("sns", region_name=self.region_name)
         sns_client.publish(
             TopicArn=self.sns_topic_arn,
             Message=json.dumps({"default": notification.json()}),
             MessageStructure="json",
         )
         logger.info(event=QJLogEvents.NotifyNewResultsEnd)
Example #14
    def read_artifact(self, artifact_path: str) -> Dict[str, Any]:
        """Read an artifact

        Args:
            artifact_path: filesystem path to artifact

        Returns:
            artifact content
        """
        logger = Logger()
        with logger.bind(artifact_path=artifact_path):
            logger.info(event=LogEvent.ReadFromFSStart)
            with open(artifact_path, "r") as artifact_fp:
                data = json.load(artifact_fp)
            logger.info(event=LogEvent.ReadFromFSEnd)
            return data
Example #15
def invoke_lambda(
    lambda_name: str, lambda_timeout: int, account_scan_lambda_event: AccountScanLambdaEvent
) -> AccountScanResult:
    """Invoke the AccountScan AWS Lambda function

    Args:
        lambda_name: name of lambda
        lambda_timeout: timeout of the lambda. Used to tell the boto3 lambda client to wait
                        at least this long for a response before timing out.
        account_scan_lambda_event: AccountScanLambdaEvent object to serialize to json and send to the lambda

    Returns:
        AccountScanResult

    Raises:
        Exception if there was an error invoking the lambda.
    """
    logger = Logger()
    account_id = account_scan_lambda_event.account_scan_plan.account_id
    with logger.bind(lambda_name=lambda_name, lambda_timeout=lambda_timeout, account_id=account_id):
        logger.info(event=AWSLogEvents.RunAccountScanLambdaStart)
        boto_config = botocore.config.Config(
            read_timeout=lambda_timeout + 10, retries={"max_attempts": 0},
        )
        session = boto3.Session()
        lambda_client = session.client("lambda", config=boto_config)
        try:
            resp = lambda_client.invoke(
                FunctionName=lambda_name, Payload=account_scan_lambda_event.json().encode("utf-8")
            )
        except Exception as invoke_ex:
            error = str(invoke_ex)
            logger.info(event=AWSLogEvents.RunAccountScanLambdaError, error=error)
            raise Exception(
                f"Error while invoking {lambda_name} with event {account_scan_lambda_event.json()}: {error}"
            ) from invoke_ex
        payload: bytes = resp["Payload"].read()
        if resp.get("FunctionError", None):
            function_error = payload.decode()
            logger.info(event=AWSLogEvents.RunAccountScanLambdaError, error=function_error)
            raise Exception(
                f"Function error in {lambda_name} with event {account_scan_lambda_event.json()}: {function_error}"
            )
        payload_dict = json.loads(payload)
        account_scan_result = AccountScanResult(**payload_dict)
        logger.info(event=AWSLogEvents.RunAccountScanLambdaEnd)
        return account_scan_result
Example #16
    def scan(self, account_scan_plans: List[AccountScanPlan]) -> List[AccountScanManifest]:
        """Scan accounts. Return a list of AccountScanManifest objects.

        Args:
            account_scan_plans: list of AccountScanPlan objects defining this scan op

        Returns:
            list of AccountScanManifest objects describing the output of the scan.
        """
        account_scan_results: List[AccountScanManifest] = []
        num_total_accounts = len(account_scan_plans)
        num_threads = min(num_total_accounts, self.max_threads)
        logger = Logger()
        with logger.bind(
            num_total_accounts=num_total_accounts,
            muxer=self.__class__.__name__,
            num_threads=num_threads,
        ):
            logger.info(event=AWSLogEvents.MuxerStart)
            with ThreadPoolExecutor(max_workers=num_threads) as executor:
                processed_accounts = 0
                futures = []
                for account_scan_plan in account_scan_plans:
                    account_scan_future = self._schedule_account_scan(executor, account_scan_plan)
                    futures.append(account_scan_future)
                    logger.info(
                        event=AWSLogEvents.MuxerQueueScan, account_id=account_scan_plan.account_id
                    )
                for future in as_completed(futures):
                    scan_results_dict = future.result()
                    account_id = scan_results_dict["account_id"]
                    output_artifact = scan_results_dict["output_artifact"]
                    account_errors = scan_results_dict["errors"]
                    api_call_stats = scan_results_dict["api_call_stats"]
                    artifacts = [output_artifact] if output_artifact else []
                    account_scan_result = AccountScanManifest(
                        account_id=account_id,
                        artifacts=artifacts,
                        errors=account_errors,
                        api_call_stats=api_call_stats,
                    )
                    account_scan_results.append(account_scan_result)
                    processed_accounts += 1
                    logger.info(event=AWSLogEvents.MuxerStat, num_scanned=processed_accounts)
            logger.info(event=AWSLogEvents.MuxerEnd)
        return account_scan_results
Example #17
    def scan(self) -> List[Resource]:
        """Perform a scan on all of the resource classes in this GraphSpec and return
        a list of Resource objects.

        Returns:
            List of Resource objects
        """
        resources: List[Resource] = []
        logger = Logger()
        for resource_spec_class in self.resource_spec_classes:
            with logger.bind(resource_type=str(resource_spec_class.type_name)):
                logger.debug(event=LogEvent.ScanResourceTypeStart)
                scanned_resources = resource_spec_class.scan(
                    scan_accessor=self.scan_accessor)
                resources += scanned_resources
                logger.debug(event=LogEvent.ScanResourceTypeEnd)
        return resources
Example #18
def get_sub_account_ids(account_ids: Tuple[str, ...], accessor: Accessor) -> Tuple[str, ...]:
    logger = Logger()
    logger.info(event=AWSLogEvents.GetSubAccountsStart)
    sub_account_ids: Set[str] = set(account_ids)
    for master_account_id in account_ids:
        with logger.bind(master_account_id=master_account_id):
            account_session = accessor.get_session(master_account_id)
            orgs_client = account_session.client("organizations")
            resp = orgs_client.describe_organization()
            if resp["Organization"]["MasterAccountId"] == master_account_id:
                accounts_paginator = orgs_client.get_paginator("list_accounts")
                for accounts_resp in accounts_paginator.paginate():
                    for account_resp in accounts_resp["Accounts"]:
                        if account_resp["Status"].lower() == "active":
                            account_id = account_resp["Id"]
                            sub_account_ids.add(account_id)
    logger.info(event=AWSLogEvents.GetSubAccountsEnd)
    return tuple(sub_account_ids)
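A usage sketch for get_sub_account_ids, assuming accessor is an Accessor like the one whose get_session method appears in Example #1 and the master account has Organizations read access:

# Hypothetical call: expand a master account id into all of its active member accounts.
all_account_ids = get_sub_account_ids(account_ids=("123456789012",), accessor=accessor)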
Example #19
    def write_json(self, name: str, data: BaseModel) -> str:
        """Write artifact data to self.output_dir/name.json

        Args:
            name: filename
            data: data

        Returns:
            Full filesystem path of artifact file
        """
        logger = Logger()
        os.makedirs(self.output_dir, exist_ok=True)
        artifact_path = os.path.join(self.output_dir, f"{name}.json")
        with logger.bind(artifact_path=artifact_path):
            logger.info(event=LogEvent.WriteToFSStart)
            with open(artifact_path, "w") as artifact_fp:
                artifact_fp.write(data.json(exclude_unset=True))
            logger.info(event=LogEvent.WriteToFSEnd)
        return artifact_path
Example #20
    def write_artifact(self, name: str, data: Dict[str, Any]) -> str:
        """Write artifact data to self.output_dir/name.json

        Args:
            name: filename
            data: artifact data

        Returns:
            Full filesystem path of artifact file
        """
        logger = Logger()
        os.makedirs(self.output_dir, exist_ok=True)
        artifact_path = os.path.join(self.output_dir, f"{name}.json")
        with logger.bind(artifact_path=artifact_path):
            logger.info(event=LogEvent.WriteToFSStart)
            with open(artifact_path, "w") as artifact_fp:
                json.dump(data, artifact_fp, default=json_encoder)
            logger.info(event=LogEvent.WriteToFSEnd)
        return artifact_path
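Examples #14 and #20 form a filesystem round trip; a hedged sketch of using them together, where writer and reader stand in for instances of the respective classes:

# Hypothetical round trip: write a dict artifact to disk, then read it back.
path = writer.write_artifact(name="123456789012", data={"errors": []})
data = reader.read_artifact(artifact_path=path)
assert data == {"errors": []}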
Example #21
def _invoke_lambda(
    lambda_name: str,
    lambda_timeout: int,
    result: Result,
) -> Any:
    """Invoke a QJ's remediator function"""
    logger = Logger()
    with logger.bind(lambda_name=lambda_name,
                     lambda_timeout=lambda_timeout,
                     result=result):
        logger.info(event=QJLogEvents.InvokeResultRemediationLambdaStart)
        boto_config = botocore.config.Config(
            read_timeout=lambda_timeout + 10,
            retries={"max_attempts": 0},
        )
        session = boto3.Session()
        lambda_client = session.client("lambda", config=boto_config)
        event = result.json().encode("utf-8")
        try:
            resp = lambda_client.invoke(
                FunctionName=lambda_name,
                Payload=event,
            )
        except Exception as invoke_ex:
            error = str(invoke_ex)
            logger.info(event=QJLogEvents.InvokeResultRemediationLambdaError,
                        error=error)
            raise Exception(
                f"Error while invoking {lambda_name} with event: {str(event)}: {error}"
            ) from invoke_ex
        lambda_result: bytes = resp["Payload"].read()
        if resp.get("FunctionError", None):
            error = lambda_result.decode()
            logger.info(event=QJLogEvents.ResultRemediationLambdaRunError,
                        error=error)
            raise Exception(
                f"Function error in {lambda_name} with event {str(event)}: {error}"
            )
        logger.info(event=QJLogEvents.InvokeResultRemediationLambdaEnd)
        return json.loads(lambda_result)
Example #22
    def write_json(self, name: str, data: BaseModel) -> str:
        """Write artifact data to s3://self.bucket/self.key_prefix/name.json

        Args:
            name: s3 key name
            data: data

        Returns:
            S3 uri (s3://bucket/key/path) to artifact
        """

        output_key = "/".join((self.key_prefix, f"{name}.json"))
        logger = Logger()
        with logger.bind(bucket=self.bucket, key=output_key):
            logger.info(event=LogEvent.WriteToS3Start)
            s3_client = boto3.Session().client("s3")
            results_str = data.json(exclude_unset=True)
            results_bytes = results_str.encode("utf-8")
            with io.BytesIO(results_bytes) as results_bytes_stream:
                s3_client.upload_fileobj(results_bytes_stream, self.bucket,
                                         output_key)
            logger.info(event=LogEvent.WriteToS3End)
        return f"s3://{self.bucket}/{output_key}"
Example #23
    def write_artifact(self, name: str, data: Dict[str, Any]) -> str:
        """Write artifact data to s3://self.bucket/self.key_prefix/name.json

        Args:
            name: s3 key name
            data: artifact data

        Returns:
            S3 uri (s3://bucket/key/path) to artifact
        """

        output_key = "/".join((self.key_prefix, f"{name}.json"))
        logger = Logger()
        with logger.bind(bucket=self.bucket, key=output_key):
            logger.info(event=LogEvent.WriteToS3Start)
            s3_client = boto3.client("s3")
            results_str = json.dumps(data, default=json_encoder)
            results_bytes = results_str.encode("utf-8")
            with io.BytesIO(results_bytes) as results_bytes_stream:
                s3_client.upload_fileobj(results_bytes_stream, self.bucket,
                                         output_key)
            logger.info(event=LogEvent.WriteToS3End)
        return f"s3://{self.bucket}/{output_key}"
Example #24
    def read_artifact(self, artifact_path: str) -> Dict[str, Any]:
        """Read an artifact

        Args:
            artifact_path: s3 uri to artifact. s3://bucket/key/path

        Returns:
            artifact content
        """
        bucket, key = parse_s3_uri(artifact_path)
        session = boto3.Session()
        s3_client = session.client("s3")
        logger = Logger()
        with io.BytesIO() as artifact_bytes_buf:
            with logger.bind(bucket=bucket, key=key):
                logger.info(event=LogEvent.ReadFromS3Start)
                s3_client.download_fileobj(bucket, key, artifact_bytes_buf)
                artifact_bytes_buf.flush()
                artifact_bytes_buf.seek(0)
                artifact_bytes = artifact_bytes_buf.read()
                logger.info(event=LogEvent.ReadFromS3End)
                artifact_str = artifact_bytes.decode("utf-8")
                artifact_dict = json.loads(artifact_str)
                return artifact_dict
Example #25
def enqueue_queries(jobs: List[schemas.Job], queue_url: str,
                    execution_hash: str, region: str) -> None:
    """Enqueue querys by sending a message for each job key to queue_url"""
    sqs_client = boto3.client("sqs", region_name=region)
    logger = Logger()
    with logger.bind(queue_url=queue_url, execution_hash=execution_hash):
        for job in jobs:
            job_hash = hashlib.sha256()
            job_hash.update(json.dumps(job.json()).encode())
            message_group_id = job_hash.hexdigest()
            job_hash.update(execution_hash.encode())
            message_dedupe_id = job_hash.hexdigest()
            logger.info(
                QJLogEvents.ScheduleJob,
                job=job,
                message_group_id=message_group_id,
                message_dedupe_id=message_dedupe_id,
            )
            sqs_client.send_message(
                QueueUrl=queue_url,
                MessageBody=job.json(),
                MessageGroupId=message_group_id,
                MessageDeduplicationId=message_dedupe_id,
            )
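A call sketch for enqueue_queries; the queue URL and execution hash are placeholders, and the queue must be a FIFO queue since MessageGroupId and MessageDeduplicationId are set:

# Hypothetical call: fan job messages out to a FIFO queue for one execution.
enqueue_queries(
    jobs=jobs,  # List[schemas.Job], e.g. fetched from the QJ API
    queue_url="https://sqs.us-east-1.amazonaws.com/123456789012/qj-queries.fifo",
    execution_hash="0123abcd",
    region="us-east-1",
)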
Example #26
    def scan(self) -> Dict[str, Any]:
        """Scan an account and return a dict containing keys:

            * account_id: str
            * output_artifact: str
            * api_call_stats: Dict[str, Any]
            * errors: List[str]

        If errors is non-empty the results are incomplete for this account.
        output_artifact is a pointer to the actual scan data - either on the local fs or in s3.

        To scan an account we create a set of GraphSpecs, one for each region. Any ACCOUNT
        level granularity resources are only scanned in a single region (e.g. IAM Users).

        Returns:
            Dict of scan result, see above for details.
        """
        logger = Logger()
        with logger.bind(account_id=self.account_id):
            logger.info(event=AWSLogEvents.ScanAWSAccountStart)
            output_artifact = None
            stats = MultilevelCounter()
            errors: List[str] = []
            now = int(time.time())
            try:
                account_graph_set = GraphSet(
                    name=self.graph_name,
                    version=self.graph_version,
                    start_time=now,
                    end_time=now,
                    resources=[],
                    errors=[],
                    stats=stats,
                )
                # sanity check
                session = self.get_session()
                sts_client = session.client("sts")
                sts_account_id = sts_client.get_caller_identity()["Account"]
                if sts_account_id != self.account_id:
                    raise ValueError(
                        f"BUG: sts detected account_id {sts_account_id} != {self.account_id}"
                    )
                if self.regions:
                    scan_regions = tuple(self.regions)
                else:
                    scan_regions = get_all_enabled_regions(session=session)
                # build graph specs.
                # build a dict of regions -> services -> List[AWSResourceSpec]
                regions_services_resource_spec_classes: DefaultDict[
                    str, DefaultDict[str, List[Type[AWSResourceSpec]]]
                ] = defaultdict(lambda: defaultdict(list))
                resource_spec_class: Type[AWSResourceSpec]
                for resource_spec_class in self.resource_spec_classes:
                    client_name = resource_spec_class.get_client_name()
                    resource_class_scan_granularity = resource_spec_class.scan_granularity
                    if resource_class_scan_granularity == ScanGranularity.ACCOUNT:
                        regions_services_resource_spec_classes[scan_regions[
                            0]][client_name].append(resource_spec_class)
                    elif resource_class_scan_granularity == ScanGranularity.REGION:
                        for region in scan_regions:
                            regions_services_resource_spec_classes[region][
                                client_name].append(resource_spec_class)
                    else:
                        raise NotImplementedError(
                            f"ScanGranularity {resource_class_scan_granularity} not implemented"
                        )
                with ThreadPoolExecutor(
                        max_workers=self.max_svc_threads) as executor:
                    futures = []
                    for (
                            region,
                            services_resource_spec_classes,
                    ) in regions_services_resource_spec_classes.items():
                        for (
                                service,
                                resource_spec_classes,
                        ) in services_resource_spec_classes.items():
                            region_session = self.get_session(region=region)
                            region_creds = region_session.get_credentials()
                            scan_future = schedule_scan_services(
                                executor=executor,
                                graph_name=self.graph_name,
                                graph_version=self.graph_version,
                                account_id=self.account_id,
                                region=region,
                                service=service,
                                access_key=region_creds.access_key,
                                secret_key=region_creds.secret_key,
                                token=region_creds.token,
                                resource_spec_classes=tuple(
                                    resource_spec_classes),
                            )
                            futures.append(scan_future)
                    for future in as_completed(futures):
                        graph_set_dict = future.result()
                        graph_set = GraphSet.from_dict(graph_set_dict)
                        errors += graph_set.errors
                        account_graph_set.merge(graph_set)
                    account_graph_set.validate()
            except Exception as ex:
                error_str = str(ex)
                trace_back = traceback.format_exc()
                logger.error(event=AWSLogEvents.ScanAWSAccountError,
                             error=error_str,
                             trace_back=trace_back)
                errors.append(" : ".join((error_str, trace_back)))
                unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                    account_id=self.account_id, errors=errors)
                account_graph_set = GraphSet(
                    name=self.graph_name,
                    version=self.graph_version,
                    start_time=now,
                    end_time=now,
                    resources=[unscanned_account_resource],
                    errors=errors,
                    stats=stats,
                )
                account_graph_set.validate()
            output_artifact = self.artifact_writer.write_artifact(
                name=self.account_id, data=account_graph_set.to_dict())
            logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
            api_call_stats = account_graph_set.stats.to_dict()
            return {
                "account_id": self.account_id,
                "output_artifact": output_artifact,
                "errors": errors,
                "api_call_stats": api_call_stats,
            }
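A sketch of consuming the result dict that the docstring above describes; scanner stands in for an instance of the scanning class:

# Hypothetical consumer: run the account scan and surface any errors.
scan_result = scanner.scan()
if scan_result["errors"]:
    print(f"partial scan of {scan_result['account_id']}: {scan_result['errors']}")
else:
    print(f"scan artifact written to {scan_result['output_artifact']}")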
Example #27
 def scan(self) -> List[Dict[str, Any]]:
     logger = Logger()
     scan_result_dicts = []
     now = int(time.time())
     prescan_account_ids_errors: DefaultDict[str, List[str]] = defaultdict(list)
     futures = []
     with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
         shuffled_account_ids = random.sample(
             self.account_scan_plan.account_ids,
             k=len(self.account_scan_plan.account_ids))
         for account_id in shuffled_account_ids:
             with logger.bind(account_id=account_id):
                 logger.info(event=AWSLogEvents.ScanAWSAccountStart)
                 try:
                     session = self.account_scan_plan.accessor.get_session(
                         account_id=account_id)
                     # sanity check
                     sts_client = session.client("sts")
                     sts_account_id = sts_client.get_caller_identity()["Account"]
                     if sts_account_id != account_id:
                         raise ValueError(
                             f"BUG: sts detected account_id {sts_account_id} != {account_id}"
                         )
                     if self.account_scan_plan.regions:
                         scan_regions = tuple(
                             self.account_scan_plan.regions)
                     else:
                         scan_regions = get_all_enabled_regions(
                             session=session)
                     account_gran_scan_region = random.choice(
                         self.preferred_account_scan_regions)
                     # build a dict of regions -> services -> List[AWSResourceSpec]
                     regions_services_resource_spec_classes: DefaultDict[
                         str, DefaultDict[
                             str,
                             List[Type[AWSResourceSpec]]]] = defaultdict(
                                 lambda: defaultdict(list))
                     resource_spec_class: Type[AWSResourceSpec]
                     for resource_spec_class in self.resource_spec_classes:
                         client_name = resource_spec_class.get_client_name()
                         if resource_spec_class.scan_granularity == ScanGranularity.ACCOUNT:
                             if resource_spec_class.region_whitelist:
                                 account_resource_scan_region = resource_spec_class.region_whitelist[
                                     0]
                             else:
                                 account_resource_scan_region = account_gran_scan_region
                             regions_services_resource_spec_classes[
                                 account_resource_scan_region][
                                     client_name].append(
                                         resource_spec_class)
                         elif resource_spec_class.scan_granularity == ScanGranularity.REGION:
                             if resource_spec_class.region_whitelist:
                                 resource_scan_regions = tuple(
                                     region for region in scan_regions
                                     if region in
                                     resource_spec_class.region_whitelist)
                                 if not resource_scan_regions:
                                     resource_scan_regions = resource_spec_class.region_whitelist
                             else:
                                 resource_scan_regions = scan_regions
                             for region in resource_scan_regions:
                                 regions_services_resource_spec_classes[
                                     region][client_name].append(
                                         resource_spec_class)
                         else:
                             raise NotImplementedError(
                                 f"ScanGranularity {resource_spec_class.scan_granularity} unimplemented"
                             )
                     # Build and submit ScanUnits
                     shuffed_regions_services_resource_spec_classes = random.sample(
                         regions_services_resource_spec_classes.items(),
                         len(regions_services_resource_spec_classes),
                     )
                     for (
                             region,
                             services_resource_spec_classes,
                     ) in shuffed_regions_services_resource_spec_classes:
                         region_session = self.account_scan_plan.accessor.get_session(
                             account_id=account_id, region_name=region)
                         region_creds = region_session.get_credentials()
                         shuffled_services_resource_spec_classes = random.sample(
                             services_resource_spec_classes.items(),
                             len(services_resource_spec_classes),
                         )
                         for (
                                 service,
                                 svc_resource_spec_classes,
                         ) in shuffled_services_resource_spec_classes:
                             future = schedule_scan(
                                 executor=executor,
                                 graph_name=self.graph_name,
                                 graph_version=self.graph_version,
                                 account_id=account_id,
                                 region_name=region,
                                 service=service,
                                 access_key=region_creds.access_key,
                                 secret_key=region_creds.secret_key,
                                 token=region_creds.token,
                                 resource_spec_classes=tuple(
                                     svc_resource_spec_classes),
                             )
                             futures.append(future)
                 except Exception as ex:
                     error_str = str(ex)
                     trace_back = traceback.format_exc()
                     logger.error(
                         event=AWSLogEvents.ScanAWSAccountError,
                         error=error_str,
                         trace_back=trace_back,
                     )
                     prescan_account_ids_errors[account_id].append(
                         f"{error_str}\n{trace_back}")
     account_ids_graph_set_dicts: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
     for future in as_completed(futures):
         account_id, graph_set_dict = future.result()
         account_ids_graph_set_dicts[account_id].append(graph_set_dict)
     # first make sure no account id appears both in account_ids_graph_set_dicts
     # and prescan_account_ids_errors - this should never happen
     doubled_accounts = set(
         account_ids_graph_set_dicts.keys()).intersection(
             set(prescan_account_ids_errors.keys()))
     if doubled_accounts:
         raise Exception((
             f"BUG: Account(s) {doubled_accounts} in both "
             "account_ids_graph_set_dicts and prescan_account_ids_errors."))
     # graph prescan error accounts
     for account_id, errors in prescan_account_ids_errors.items():
         with logger.bind(account_id=account_id):
             unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                 account_id=account_id, errors=errors)
             account_graph_set = GraphSet(
                 name=self.graph_name,
                 version=self.graph_version,
                 start_time=now,
                 end_time=now,
                 resources=[unscanned_account_resource],
                 errors=errors,
                 stats=MultilevelCounter(),
             )
             account_graph_set.validate()
             output_artifact = self.artifact_writer.write_json(
                 name=account_id, data=account_graph_set.to_dict())
             logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
             api_call_stats = account_graph_set.stats.to_dict()
             scan_result_dicts.append({
                 "account_id": account_id,
                 "output_artifact": output_artifact,
                 "errors": errors,
                 "api_call_stats": api_call_stats,
             })
     # graph rest
     for account_id, graph_set_dicts in account_ids_graph_set_dicts.items():
         with logger.bind(account_id=account_id):
             # if there are any errors whatsoever we generate an empty graph with
             # errors only
             errors = []
             for graph_set_dict in graph_set_dicts:
                 errors += graph_set_dict["errors"]
             if errors:
                 unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                     account_id=account_id, errors=errors)
                 account_graph_set = GraphSet(
                     name=self.graph_name,
                     version=self.graph_version,
                     start_time=now,
                     end_time=now,
                     resources=[unscanned_account_resource],
                     errors=errors,
                     stats=MultilevelCounter(),  # ENHANCEMENT: could technically get partial stats.
                 )
                 account_graph_set.validate()
             else:
                 account_graph_set = GraphSet(
                     name=self.graph_name,
                     version=self.graph_version,
                     start_time=now,
                     end_time=now,
                     resources=[],
                     errors=[],
                     stats=MultilevelCounter(),
                 )
                 for graph_set_dict in graph_set_dicts:
                     graph_set = GraphSet.from_dict(graph_set_dict)
                     account_graph_set.merge(graph_set)
             output_artifact = self.artifact_writer.write_json(
                 name=account_id, data=account_graph_set.to_dict())
             logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
             api_call_stats = account_graph_set.stats.to_dict()
             scan_result_dicts.append({
                 "account_id": account_id,
                 "output_artifact": output_artifact,
                 "errors": errors,
                 "api_call_stats": api_call_stats,
             })
     return scan_result_dicts
Example #28
 def scan(self) -> AccountScanResult:
     logger = Logger()
     now = int(time.time())
     prescan_errors: List[str] = []
     futures: List[Future] = []
     account_id = self.account_scan_plan.account_id
     with logger.bind(account_id=account_id):
         with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
             logger.info(event=AWSLogEvents.ScanAWSAccountStart)
             try:
                 session = self.account_scan_plan.accessor.get_session(
                     account_id=account_id)
                 # sanity check
                 sts_client = session.client("sts")
                  sts_account_id = sts_client.get_caller_identity()["Account"]
                 if sts_account_id != account_id:
                     raise ValueError(
                         f"BUG: sts detected account_id {sts_account_id} != {account_id}"
                     )
                 if self.account_scan_plan.regions:
                     account_scan_regions = tuple(
                         self.account_scan_plan.regions)
                 else:
                     account_scan_regions = get_all_enabled_regions(
                         session=session)
                 # build a dict of regions -> services -> List[AWSResourceSpec]
                 regions_services_resource_spec_classes: DefaultDict[
                     str, DefaultDict[
                         str, List[Type[AWSResourceSpec]]]] = defaultdict(
                             lambda: defaultdict(list))
                 for resource_spec_class in self.resource_spec_classes:
                     resource_regions = self.account_scan_plan.aws_resource_region_mapping_repo.get_regions(
                         resource_spec_class=resource_spec_class,
                         region_whitelist=account_scan_regions,
                     )
                     for region in resource_regions:
                         regions_services_resource_spec_classes[region][
                             resource_spec_class.service_name].append(
                                 resource_spec_class)
                 # Build and submit ScanUnits
                 shuffed_regions_services_resource_spec_classes = random.sample(
                     regions_services_resource_spec_classes.items(),
                     len(regions_services_resource_spec_classes),
                 )
                 for (
                         region,
                         services_resource_spec_classes,
                 ) in shuffed_regions_services_resource_spec_classes:
                     region_session = self.account_scan_plan.accessor.get_session(
                         account_id=account_id, region_name=region)
                     region_creds = region_session.get_credentials()
                     shuffled_services_resource_spec_classes = random.sample(
                         services_resource_spec_classes.items(),
                         len(services_resource_spec_classes),
                     )
                     for (
                             service,
                             svc_resource_spec_classes,
                     ) in shuffled_services_resource_spec_classes:
                         parallel_svc_resource_spec_classes = [
                             svc_resource_spec_class
                             for svc_resource_spec_class in
                             svc_resource_spec_classes
                             if svc_resource_spec_class.parallel_scan
                         ]
                         serial_svc_resource_spec_classes = [
                             svc_resource_spec_class
                             for svc_resource_spec_class in
                             svc_resource_spec_classes
                             if not svc_resource_spec_class.parallel_scan
                         ]
                         for (parallel_svc_resource_spec_class
                              ) in parallel_svc_resource_spec_classes:
                             parallel_future = schedule_scan(
                                 executor=executor,
                                 graph_name=self.graph_name,
                                 graph_version=self.graph_version,
                                 account_id=account_id,
                                 region_name=region,
                                 service=service,
                                 access_key=region_creds.access_key,
                                 secret_key=region_creds.secret_key,
                                 token=region_creds.token,
                                 resource_spec_classes=(
                                     parallel_svc_resource_spec_class, ),
                             )
                             futures.append(parallel_future)
                         serial_future = schedule_scan(
                             executor=executor,
                             graph_name=self.graph_name,
                             graph_version=self.graph_version,
                             account_id=account_id,
                             region_name=region,
                             service=service,
                             access_key=region_creds.access_key,
                             secret_key=region_creds.secret_key,
                             token=region_creds.token,
                             resource_spec_classes=tuple(
                                 serial_svc_resource_spec_classes),
                         )
                         futures.append(serial_future)
             except Exception as ex:
                 error_str = str(ex)
                 trace_back = traceback.format_exc()
                 logger.error(
                     event=AWSLogEvents.ScanAWSAccountError,
                     error=error_str,
                     trace_back=trace_back,
                 )
                 prescan_errors.append(f"{error_str}\n{trace_back}")
         graph_sets: List[GraphSet] = []
         for future in as_completed(futures):
             graph_set = future.result()
             graph_sets.append(graph_set)
         # if there was a prescan error, graph it and return the result
         if prescan_errors:
             unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                 account_id=account_id, errors=prescan_errors)
             account_graph_set = GraphSet(
                 name=self.graph_name,
                 version=self.graph_version,
                 start_time=now,
                 end_time=now,
                 resources=[unscanned_account_resource],
                 errors=prescan_errors,
             )
             output_artifact = self.artifact_writer.write_json(
                 name=account_id,
                 data=account_graph_set,
             )
             logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
             return AccountScanResult(
                 account_id=account_id,
                 artifacts=[output_artifact],
                 errors=prescan_errors,
             )
         # if any scan unit reported errors, generate a graph containing only
         # an unscanned-account resource and the collected errors
         errors = []
         for graph_set in graph_sets:
             errors += graph_set.errors
         if errors:
             unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                 account_id=account_id, errors=errors)
             account_graph_set = GraphSet(
                 name=self.graph_name,
                 version=self.graph_version,
                 start_time=now,
                 end_time=now,
                 resources=[unscanned_account_resource],
                 errors=errors,
             )
         else:
             account_graph_set = GraphSet.from_graph_sets(graph_sets)
         output_artifact = self.artifact_writer.write_json(
             name=account_id,
             data=account_graph_set,
         )
         logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
         return AccountScanResult(
             account_id=account_id,
             artifacts=[output_artifact],
             errors=errors,
         )
Beispiel #29
0
    def load_graph(self, bucket: str, key: str,
                   load_iam_role_arn: str) -> GraphMetadata:
        """Load a graph into Neptune.
        Args:
             bucket: s3 bucket of graph rdf
             key: s3 key of graph rdf
             load_iam_role_arn: arn of iam role used to load the graph

        Returns:
            GraphMetadata object describing loaded graph

        Raises:
            NeptuneLoadGraphException if errors occur during graph load
        """
        session = boto3.Session(region_name=self._neptune_endpoint.region)
        s3_client = session.client("s3")
        rdf_object_tagging = s3_client.get_object_tagging(Bucket=bucket,
                                                          Key=key)
        tag_set = rdf_object_tagging["TagSet"]
        graph_name = get_required_tag_value(tag_set, "name")
        graph_version = get_required_tag_value(tag_set, "version")
        graph_start_time = int(get_required_tag_value(tag_set, "start_time"))
        graph_end_time = int(get_required_tag_value(tag_set, "end_time"))
        graph_metadata = GraphMetadata(
            uri=f"{GRAPH_BASE_URI}/{graph_name}/{graph_version}/{graph_end_time}",
            name=graph_name,
            version=graph_version,
            start_time=graph_start_time,
            end_time=graph_end_time,
        )
        logger = Logger()
        with logger.bind(
                rdf_bucket=bucket,
                rdf_key=key,
                graph_uri=graph_metadata.uri,
                neptune_endpoint=self._neptune_endpoint.get_endpoint_str(),
        ):
            session = boto3.Session(region_name=self._neptune_endpoint.region)
            credentials = session.get_credentials()
            auth = AWSRequestsAuth(
                aws_access_key=credentials.access_key,
                aws_secret_access_key=credentials.secret_key,
                aws_token=credentials.token,
                aws_host=self._neptune_endpoint.get_endpoint_str(),
                aws_region=self._neptune_endpoint.region,
                aws_service="neptune-db",
            )
            post_body = {
                "source": f"s3://{bucket}/{key}",
                "format": "rdfxml",
                "iamRoleArn": load_iam_role_arn,
                "region": self._neptune_endpoint.region,
                "failOnError": "TRUE",
                "parallelism": "MEDIUM",
                "parserConfiguration": {
                    "baseUri": GRAPH_BASE_URI,
                    "namedGraphUri": graph_metadata.uri,
                },
            }
            logger.info(event=LogEvent.NeptuneLoadStart, post_body=post_body)
            submit_resp = requests.post(
                self._neptune_endpoint.get_loader_endpoint(),
                json=post_body,
                auth=auth)
            if submit_resp.status_code != 200:
                raise NeptuneLoadGraphException(
                    f"Non 200 from Neptune: {submit_resp.status_code} : {submit_resp.text}"
                )
            submit_resp_json = submit_resp.json()
            load_id = submit_resp_json["payload"]["loadId"]
            with logger.bind(load_id=load_id):
                logger.info(event=LogEvent.NeptuneLoadPolling)
                while True:
                    time.sleep(10)
                    status_resp = requests.get(
                        f"{self._neptune_endpoint.get_loader_endpoint()}/{load_id}",
                        params={
                            "details": "true",
                            "errors": "true"
                        },
                        auth=auth,
                    )
                    if status_resp.status_code != 200:
                        raise NeptuneLoadGraphException(
                            f"Non 200 from Neptune: {status_resp.status_code} : {status_resp.text}"
                        )
                    status_resp_json = status_resp.json()
                    status = status_resp_json["payload"]["overallStatus"][
                        "status"]
                    logger.info(event=LogEvent.NeptuneLoadPolling,
                                status=status)
                    if status == "LOAD_COMPLETED":
                        break
                    if status not in ("LOAD_NOT_STARTED", "LOAD_IN_PROGRESS"):
                        logger.error(event=LogEvent.NeptuneLoadError,
                                     status=status)
                        raise NeptuneLoadGraphException(
                            f"Error loading graph: {status_resp_json}")
                logger.info(event=LogEvent.NeptuneLoadEnd)

                logger.info(event=LogEvent.MetadataGraphUpdateStart)
                self._register_graph(graph_metadata=graph_metadata)
                logger.info(event=LogEvent.MetadataGraphUpdateEnd)

                return graph_metadata
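
A hedged usage sketch for the loader above, assuming load_graph is a method of the AltimeterNeptuneClient used with NeptuneEndpoint elsewhere in this document; the import path, bucket, key, and role ARN are placeholder assumptions.

# Usage sketch only: the import path is assumed from the Altimeter project
# layout, and all bucket/key/ARN/endpoint values are placeholders.
from altimeter.core.neptune.client import AltimeterNeptuneClient, NeptuneEndpoint

endpoint = NeptuneEndpoint(host="my-neptune-host", port=8182, region="us-east-1")
client = AltimeterNeptuneClient(max_age_min=1440, neptune_endpoint=endpoint)
graph_metadata = client.load_graph(
    bucket="my-graph-bucket",
    key="graphs/example.rdf",
    load_iam_role_arn="arn:aws:iam::123456789012:role/neptune-loader",
)
print(graph_metadata.uri)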
Beispiel #30
0
    def write_graph_set(self,
                        name: str,
                        graph_set: GraphSet,
                        compression: Optional[str] = None) -> str:
        """Write a graph artifact

        Args:
            name: name
            graph_set: GraphSet to write

        Returns:
            path to written artifact
        """
        logger = Logger()
        if compression is None:
            key = f"{name}.rdf"
        elif compression == GZIP:
            key = f"{name}.rdf.gz"
        else:
            raise ValueError(f"Unknown compression arg {compression}")
        output_key = "/".join((self.key_prefix, key))
        graph = graph_set.to_rdf()
        with logger.bind(bucket=self.bucket,
                         key_prefix=self.key_prefix,
                         key=key):
            logger.info(event=LogEvent.WriteToS3Start)
            with io.BytesIO() as rdf_bytes_buf:
                if compression is None:
                    graph.serialize(rdf_bytes_buf)
                elif compression == GZIP:
                    with gzip.GzipFile(fileobj=rdf_bytes_buf, mode="wb") as gz:
                        graph.serialize(gz)
                else:
                    raise ValueError(f"Unknown compression arg {compression}")
                rdf_bytes_buf.flush()
                rdf_bytes_buf.seek(0)
                session = boto3.Session()
                s3_client = session.client("s3")
                s3_client.upload_fileobj(rdf_bytes_buf, self.bucket,
                                         output_key)
            s3_client.put_object_tagging(
                Bucket=self.bucket,
                Key=output_key,
                Tagging={
                    "TagSet": [
                        {
                            "Key": "name",
                            "Value": graph_set.name
                        },
                        {
                            "Key": "version",
                            "Value": graph_set.version
                        },
                        {
                            "Key": "start_time",
                            "Value": str(graph_set.start_time)
                        },
                        {
                            "Key": "end_time",
                            "Value": str(graph_set.end_time)
                        },
                    ]
                },
            )
            logger.info(event=LogEvent.WriteToS3End)
        return f"s3://{self.bucket}/{output_key}"