Example #1
    def update_user(
        self,
        user_email: str,
        password: Optional[str] = None,
        groups: Optional[List[str]] = None,
    ):
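        """Update an existing user's password and/or group memberships.

        Looks the user up by email in DynamoDB, applies the requested
        changes, signs the record, and writes it back.
        """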
        if not groups:
            groups = []

        user_ddb = self.users_table.query(
            KeyConditionExpression="username = :un",
            ExpressionAttributeValues={":un": user_email},
        )

        user = None

        if user_ddb and "Items" in user_ddb and len(user_ddb["Items"]) == 1:
            user = user_ddb["Items"][0]

        if not user:
            raise DataNotRetrievable(f"Unable to find user: {user_email}")

        timestamp = int(time.time())

        if password:
            pw = bytes(password, "utf-8")
            salt = bcrypt.gensalt()
            user["password"] = bcrypt.hashpw(pw, salt)

        if groups:
            user["groups"] = groups
        user["last_updated"] = timestamp

        user_entry = self.sign_request(user)
        try:
            self.users_table.put_item(
                Item=self._data_to_dynamo_replace(user_entry))
        except Exception as e:
            error = f"Unable to add user submission: {user_entry}: {str(e)}"
            log.error(error, exc_info=True)
            raise Exception(error)
        return user_entry
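Note on the password handling above: bcrypt.hashpw embeds the generated salt in its output, so the stored value is all that is needed to verify a login later. A minimal verification sketch (the helper name is illustrative, not part of the code above):

import bcrypt

def password_matches(candidate: str, stored_hash: bytes) -> bool:
    # checkpw re-derives the hash with the salt embedded in stored_hash
    # and compares the result in constant time.
    return bcrypt.checkpw(candidate.encode("utf-8"), stored_hash)

stored = bcrypt.hashpw(b"hunter2", bcrypt.gensalt())
assert password_matches("hunter2", stored)
assert not password_matches("wrong-password", stored)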
Example #2
async def retrieve_json_data_from_redis_or_s3(
    redis_key: Optional[str] = None,
    redis_data_type: str = "str",
    s3_bucket: Optional[str] = None,
    s3_key: Optional[str] = None,
    cache_to_redis_if_data_in_s3: bool = True,
    max_age: Optional[int] = None,
    default: Optional[Any] = None,
    json_object_hook: Optional[Callable] = None,
    json_encoder: Optional[Type[json.JSONEncoder]] = None,
):
    """
    Retrieve data from Redis as a priority. If data is unavailable in Redis, fall back to S3 and attempt to store
    data in Redis for quicker retrieval later

    :param redis_data_type: "str" or "hash", depending on how the data is stored in Redis
    :param redis_key: Redis Key to retrieve data from
    :param s3_bucket: S3 bucket to retrieve data from
    :param s3_key: S3 key to retrieve data from
    :param cache_to_redis_if_data_in_s3: Cache the data in Redis if the data is in S3 but not Redis
    :param max_age: Maximum allowed age of the data, in seconds; older data raises ExpiredData
    :param default: Value to return if no data could be retrieved
    :param json_object_hook: Optional object_hook passed to json.loads
    :param json_encoder: Optional JSON encoder class used when re-caching data to Redis and S3
    :return: The retrieved data
    """
    function = f"{__name__}.{sys._getframe().f_code.co_name}"
    last_updated_redis_key = config.get(
        "store_json_results_in_redis_and_s3.last_updated_redis_key",
        "STORE_JSON_RESULTS_IN_REDIS_AND_S3_LAST_UPDATED",
    )
    stats.count(
        f"{function}.called",
        tags={
            "redis_key": redis_key,
            "s3_bucket": s3_bucket,
            "s3_key": s3_key
        },
    )
    data = None
    if redis_key:
        if redis_data_type == "str":
            data_s = red.get(redis_key)
            if data_s:
                data = json.loads(data_s, object_hook=json_object_hook)
        elif redis_data_type == "hash":
            data = red.hgetall(redis_key)
        else:
            raise UnsupportedRedisDataType(
                "Unsupported redis_data_type passed")
        if data and max_age:
            current_time = int(time.time())
            last_updated = int(red.hget(last_updated_redis_key, redis_key))
            if current_time - last_updated > max_age:
                raise ExpiredData(
                    f"Data in Redis is older than {max_age} seconds.")

    # Fall back to S3 if there's no data
    if not data and s3_bucket and s3_key:
        s3_object = get_object(Bucket=s3_bucket, Key=s3_key)
        s3_object_content = s3_object["Body"].read()
        data_object = json.loads(s3_object_content, object_hook=json_object_hook)
        data = data_object["data"]

        if data and max_age:
            current_time = int(time.time())
            last_updated = data_object["last_updated"]
            if current_time - last_updated > max_age:
                raise ExpiredData(
                    f"Data in S3 is older than {max_age} seconds.")
        if redis_key and cache_to_redis_if_data_in_s3:
            await store_json_results_in_redis_and_s3(
                data,
                redis_key=redis_key,
                redis_data_type=redis_data_type,
                json_encoder=json_encoder,
            )

    if data is not None:
        return data
    if default is not None:
        return default
    raise DataNotRetrievable("Unable to retrieve expected data.")
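A hedged usage sketch for the function above; the key, bucket, and object names are placeholders, and the Redis/S3 wiring (red, get_object, config) is assumed to be configured by the surrounding module:

import asyncio

async def load_cached_roles():
    # Prefer Redis; on a miss, read from S3 and warm the Redis cache.
    return await retrieve_json_data_from_redis_or_s3(
        redis_key="ALL_ROLES",                # placeholder Redis key
        s3_bucket="example-consoleme-cache",  # placeholder bucket
        s3_key="cache/all_roles.json",        # placeholder object key
        max_age=3600,  # raise ExpiredData if the cached copy is older than 1h
        default={},    # returned instead of DataNotRetrievable on a total miss
    )

# roles = asyncio.run(load_cached_roles())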
Example #3
async def detect_cloudtrail_denies_and_update_cache(
    celery_app,
    event_ttl=config.get(
        "event_bridge.detect_cloudtrail_denies_and_update_cache.event_ttl",
        86400),
    max_num_messages_to_process=config.get(
        "event_bridge.detect_cloudtrail_denies_and_update_cache.max_num_messages_to_process",
        100,
    ),
) -> Dict[str, Any]:
    log_data = {"function": f"{__name__}.{sys._getframe().f_code.co_name}"}
    dynamo = UserDynamoHandler()
    queue_arn = config.get(
        "event_bridge.detect_cloudtrail_denies_and_update_cache.queue_arn",
        "").format(region=config.region)
    if not queue_arn:
        raise MissingConfigurationValue(
            "Unable to find required configuration value: "
            "`event_bridge.detect_cloudtrail_denies_and_update_cache.queue_arn`"
        )
    queue_name = queue_arn.split(":")[-1]
    queue_account_number = queue_arn.split(":")[4]
    queue_region = queue_arn.split(":")[3]
    # Optionally assume a role before receiving messages from the queue
    queue_assume_role = config.get(
        "event_bridge.detect_cloudtrail_denies_and_update_cache.assume_role")

    # Modify existing cloudtrail deny samples
    all_cloudtrail_denies_l = await dynamo.parallel_scan_table_async(
        dynamo.cloudtrail_table)
    all_cloudtrail_denies = {}
    for cloudtrail_deny in all_cloudtrail_denies_l:
        all_cloudtrail_denies[cloudtrail_deny["request_id"]] = cloudtrail_deny

    sqs_client = await sync_to_async(boto3_cached_conn)(
        "sqs",
        service_type="client",
        region=queue_region,
        retry_max_attempts=2,
        account_number=queue_account_number,
        assume_role=queue_assume_role,
        client_kwargs=config.get("boto3.client_kwargs", {}),
    )

    queue_url_res = await sync_to_async(sqs_client.get_queue_url)(
        QueueName=queue_name
    )
    queue_url = queue_url_res.get("QueueUrl")
    if not queue_url:
        raise DataNotRetrievable(
            f"Unable to retrieve Queue URL for {queue_arn}")
    messages_awaitable = await sync_to_async(sqs_client.receive_message)(
        QueueUrl=queue_url, MaxNumberOfMessages=10
    )
    new_events = 0
    messages = messages_awaitable.get("Messages", [])
    num_events = 0
    reached_limit_on_num_messages_to_process = False

    while messages:
        if num_events >= max_num_messages_to_process:
            reached_limit_on_num_messages_to_process = True
            break
        processed_messages = []
        for message in messages:
            try:
                message_body = json.loads(message["Body"])
                try:
                    if "Message" in message_body:
                        decoded_message = json.loads(
                            message_body["Message"])["detail"]
                    else:
                        decoded_message = message_body["detail"]
                except Exception as e:
                    log.error({
                        **log_data,
                        "message": "Unable to process Cloudtrail message",
                        "message_body": message_body,
                        "error": str(e),
                    })
                    sentry_sdk.capture_exception()
                    continue
                event_name = decoded_message.get("eventName")
                event_source = decoded_message.get("eventSource")
                for event_source_substitution in config.get(
                    "event_bridge.detect_cloudtrail_denies_and_update_cache.event_bridge_substitutions",
                    [".amazonaws.com"],
                ):
                    event_source = event_source.replace(
                        event_source_substitution, ""
                    )
                event_time = decoded_message.get("eventTime")
                utc_time = datetime.strptime(event_time, "%Y-%m-%dT%H:%M:%SZ")
                epoch_event_time = int(
                    (utc_time - datetime(1970, 1, 1)).total_seconds())
                # Skip entries older than a day
                if int(time.time()) - 86400 > epoch_event_time:
                    continue
                try:
                    session_name = decoded_message["userIdentity"]["arn"].split("/")[
                        -1
                    ]
                except (IndexError, KeyError):
                    # If IAM user, there won't be a session name
                    session_name = ""
                try:
                    principal_arn = decoded_message["userIdentity"]["sessionContext"][
                        "sessionIssuer"
                    ]["arn"]
                except KeyError:
                    # Skip events without a parsable ARN
                    continue

                event_call = f"{event_source}:{event_name}"

                ct_event = dict(
                    error_code=decoded_message.get("errorCode"),
                    error_message=decoded_message.get("errorMessage"),
                    arn=principal_arn,
                    # principal_owner=owner,
                    session_name=session_name,
                    source_ip=decoded_message["sourceIPAddress"],
                    event_call=event_call,
                    epoch_event_time=epoch_event_time,
                    ttl=epoch_event_time + event_ttl,
                    count=1,
                )
                resource = await get_resource_from_cloudtrail_deny(
                    ct_event, decoded_message)
                ct_event["resource"] = resource
                request_id = f"{principal_arn}-{session_name}-{event_call}-{resource}"
                ct_event["request_id"] = request_id
                generated_policy = await generate_policy_from_cloudtrail_deny(
                    ct_event)
                if generated_policy:
                    ct_event["generated_policy"] = generated_policy

                if all_cloudtrail_denies.get(request_id):
                    existing_count = all_cloudtrail_denies[request_id].get(
                        "count", 1)
                    ct_event["count"] += existing_count
                    all_cloudtrail_denies[request_id] = ct_event
                else:
                    all_cloudtrail_denies[request_id] = ct_event
                    new_events += 1
                num_events += 1
            except Exception as e:
                log.error({**log_data, "error": str(e)}, exc_info=True)
                sentry_sdk.capture_exception()
            processed_messages.append(
                {
                    "Id": message["MessageId"],
                    "ReceiptHandle": message["ReceiptHandle"],
                }
            )
        if processed_messages:
            await sync_to_async(sqs_client.delete_message_batch)(
                QueueUrl=queue_url, Entries=processed_messages
            )

        await sync_to_async(dynamo.batch_write_cloudtrail_events)(
            all_cloudtrail_denies.values()
        )
        messages_awaitable = await sync_to_async(sqs_client.receive_message)(
            QueueUrl=queue_url, MaxNumberOfMessages=10
        )
        messages = messages_awaitable.get("Messages", [])
    if reached_limit_on_num_messages_to_process:
        # We hit our limit. Let's spawn another task immediately to process remaining messages
        celery_app.send_task(
            "consoleme.celery_tasks.celery_tasks.cache_cloudtrail_denies"
        )
    log_data["message"] = "Successfully cached Cloudtrail Access Denies"
    log_data["num_events"] = num_events
    log_data["new_events"] = new_events
    log.debug(log_data)

    return log_data
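The queue_name / queue_account_number / queue_region extraction above relies on the fixed colon-separated layout of an SQS ARN; a small self-checking sketch of which index selects which field:

# Layout: arn:aws:sqs:<region>:<account-id>:<queue-name>
#          0   1   2     3          4            5
queue_arn = "arn:aws:sqs:us-east-1:123456789012:my-queue"
parts = queue_arn.split(":")
assert parts[3] == "us-east-1"      # queue_region
assert parts[4] == "123456789012"   # queue_account_number
assert parts[-1] == "my-queue"      # queue_name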
Example #4
async def detect_cloudtrail_denies_and_update_cache():
    log_data = {"function": f"{__name__}.{sys._getframe().f_code.co_name}"}
    dynamo = UserDynamoHandler()
    queue_arn = config.get(
        "event_bridge.detect_cloudtrail_denies_and_update_cache.queue_arn",
        "").format(region=config.region)
    if not queue_arn:
        raise MissingConfigurationValue(
            "Unable to find required configuration value: "
            "`event_bridge.detect_cloudtrail_denies_and_update_cache.queue_arn`"
        )
    queue_name = queue_arn.split(":")[-1]
    queue_account_number = queue_arn.split(":")[4]
    queue_region = queue_arn.split(":")[3]
    # Optionally assume a role before receiving messages from the queue
    queue_assume_role = config.get(
        "event_bridge.detect_cloudtrail_denies_and_update_cache.assume_role")

    sqs_client = await sync_to_async(boto3_cached_conn)(
        "sqs",
        service_type="client",
        region=queue_region,
        retry_max_attempts=2,
        account_number=queue_account_number,
        assume_role=queue_assume_role,
    )

    queue_url_res = await sync_to_async(sqs_client.get_queue_url)(
        QueueName=queue_name
    )
    queue_url = queue_url_res.get("QueueUrl")
    if not queue_url:
        raise DataNotRetrievable(
            f"Unable to retrieve Queue URL for {queue_arn}")
    ct_events = []
    messages_awaitable = await sync_to_async(sqs_client.receive_message)(
        QueueUrl=queue_url, MaxNumberOfMessages=10
    )
    messages = messages_awaitable.get("Messages", [])
    while messages:
        processed_messages = []
        for message in messages:
            try:
                message_body = json.loads(message["Body"])
                decoded_message = json.loads(message_body["Message"])["detail"]
                event_name = decoded_message.get("eventName")
                event_source = decoded_message.get("eventSource")
                for event_source_substitution in config.get(
                    "event_bridge.detect_cloudtrail_denies_and_update_cache.event_bridge_substitutions",
                    [".amazonaws.com"],
                ):
                    event_source = event_source.replace(
                        event_source_substitution, ""
                    )
                event_time = decoded_message.get("eventTime")
                utc_time = datetime.strptime(event_time, "%Y-%m-%dT%H:%M:%SZ")
                epoch_event_time = int(
                    (utc_time - datetime(1970, 1, 1)).total_seconds() * 1000)
                try:
                    session_name = decoded_message["userIdentity"]["arn"].split("/")[
                        -1
                    ]
                except (IndexError, KeyError):
                    # If IAM user, there won't be a session name
                    session_name = ""
                try:
                    role_arn = decoded_message["userIdentity"]["sessionContext"][
                        "sessionIssuer"
                    ]["arn"]
                except KeyError:
                    # Skip events without a parsable ARN
                    continue

                ct_event = dict(
                    error_code=decoded_message.get("errorCode"),
                    error_message=decoded_message.get("errorMessage"),
                    arn=role_arn,
                    session_name=session_name,
                    request_id=decoded_message["requestID"],
                    event_call=f"{event_source}:{event_name}",
                    epoch_event_time=epoch_event_time,
                    ttl=(epoch_event_time + 86400000) // 1000,  # DynamoDB TTL must be integer epoch seconds
                )
                ct_event["resource"] = await get_resource_from_cloudtrail_deny(
                    ct_event)
                generated_policy = await generate_policy_from_cloudtrail_deny(
                    ct_event)
                if generated_policy:
                    ct_event["generated_policy"] = generated_policy
                ct_events.append(ct_event)
            except Exception as e:
                log.error({**log_data, "error": str(e)}, exc_info=True)
                sentry_sdk.capture_exception()
            processed_messages.append(
                {
                    "Id": message["MessageId"],
                    "ReceiptHandle": message["ReceiptHandle"],
                }
            )
        await sync_to_async(sqs_client.delete_message_batch)(
            QueueUrl=queue_url, Entries=processed_messages
        )
        await sync_to_async(dynamo.batch_write_cloudtrail_events)(ct_events)
        messages_awaitable = await sync_to_async(sqs_client.receive_message)(
            QueueUrl=queue_url, MaxNumberOfMessages=10
        )
        messages = messages_awaitable.get("Messages", [])
    log.debug({
        **log_data,
        "num_events": len(ct_events),
        "message": "Successfully cached Cloudtrail Access Denies",
    })

    return ct_events
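Unlike Example #3, this version stores epoch_event_time in milliseconds and divides by 1000 when computing the TTL, since DynamoDB expects TTL attributes in epoch seconds. A timezone-aware sketch of the same conversion (the sample timestamp is illustrative):

from datetime import datetime, timezone

event_time = "2021-06-01T12:00:00Z"  # sample CloudTrail eventTime
utc_time = datetime.strptime(event_time, "%Y-%m-%dT%H:%M:%SZ").replace(
    tzinfo=timezone.utc
)
epoch_seconds = int(utc_time.timestamp())
epoch_millis = epoch_seconds * 1000
# Keep the record for one day past the event, expressed in seconds.
assert (epoch_millis + 86400000) // 1000 == epoch_seconds + 86400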
Example #5
async def retrieve_json_data_from_redis_or_s3(
    redis_key: Optional[str] = None,
    redis_data_type: str = "str",
    s3_bucket: Optional[str] = None,
    s3_key: Optional[str] = None,
    cache_to_redis_if_data_in_s3: bool = True,
    max_age: Optional[int] = None,
    default: Optional[Any] = None,
    json_object_hook: Optional[Callable] = None,
    json_encoder: Optional[Type[json.JSONEncoder]] = None,
):
    """
    Retrieve data from Redis as a priority. If data is unavailable in Redis, fall back to S3 and attempt to store
    data in Redis for quicker retrieval later.

    :param redis_data_type: "str" or "hash", depending on how the data is stored in Redis
    :param redis_key: Redis Key to retrieve data from
    :param s3_bucket: S3 bucket to retrieve data from
    :param s3_key: S3 key to retrieve data from
    :param cache_to_redis_if_data_in_s3: Cache the data in Redis if the data is in S3 but not Redis
    :param max_age: Maximum allowed age of the data, in seconds; older data raises ExpiredData
    :param default: Value to return if no data could be retrieved
    :param json_object_hook: Optional object_hook passed to json.loads
    :param json_encoder: Optional JSON encoder class used when re-caching data to Redis and S3
    :return: The retrieved data
    """
    function = f"{__name__}.{sys._getframe().f_code.co_name}"
    last_updated_redis_key = config.get(
        "store_json_results_in_redis_and_s3.last_updated_redis_key",
        "STORE_JSON_RESULTS_IN_REDIS_AND_S3_LAST_UPDATED",
    )
    stats.count(
        f"{function}.called",
        tags={"redis_key": redis_key, "s3_bucket": s3_bucket, "s3_key": s3_key},
    )

    # If we've defined an S3 key, but not a bucket, let's use the default bucket if it's defined in configuration.
    if s3_key and not s3_bucket:
        s3_bucket = config.get("consoleme_s3_bucket")

    data = None
    if redis_key:
        if redis_data_type == "str":
            data_s = red.get(redis_key)
            if data_s:
                data = json.loads(data_s, object_hook=json_object_hook)
        elif redis_data_type == "hash":
            data = red.hgetall(redis_key)
        else:
            raise UnsupportedRedisDataType("Unsupported redis_data_type passed")
        if data and max_age:
            current_time = int(time.time())
            last_updated = int(red.hget(last_updated_redis_key, redis_key))
            if current_time - last_updated > max_age:
                data = None
                # Fall back to S3 if expired.
                if not s3_bucket or not s3_key:
                    raise ExpiredData(f"Data in Redis is older than {max_age} seconds.")

    # Fall back to S3 if there's no data
    if not data and s3_bucket and s3_key:
        try:
            s3_object = get_object(Bucket=s3_bucket, Key=s3_key)
        except ClientError as e:
            # Match on the error code rather than the message text, which is brittle.
            if e.response["Error"]["Code"] == "NoSuchKey":
                if default is not None:
                    return default
            raise
        s3_object_content = await sync_to_async(s3_object["Body"].read)()
        if s3_key.endswith(".gz"):
            s3_object_content = gzip.decompress(s3_object_content)
        data_object = json.loads(s3_object_content, object_hook=json_object_hook)
        data = data_object["data"]

        if data and max_age:
            current_time = int(time.time())
            last_updated = data_object["last_updated"]
            if current_time - last_updated > max_age:
                raise ExpiredData(f"Data in S3 is older than {max_age} seconds.")
        if redis_key and cache_to_redis_if_data_in_s3:
            await store_json_results_in_redis_and_s3(
                data,
                redis_key=redis_key,
                redis_data_type=redis_data_type,
                json_encoder=json_encoder,
            )

    if data is not None:
        return data
    if default is not None:
        return default
    raise DataNotRetrievable("Unable to retrieve expected data.")
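The writer side (store_json_results_in_redis_and_s3, not shown here) is expected to produce the {"last_updated": ..., "data": ...} envelope this function unpacks, gzip-compressed when the object key ends in .gz. A minimal sketch of building such a payload, under that assumption:

import gzip
import json
import time

def make_s3_payload(data, s3_key: str) -> bytes:
    # Wrap the data in the envelope retrieve_json_data_from_redis_or_s3 expects.
    body = json.dumps({"last_updated": int(time.time()), "data": data}).encode()
    if s3_key.endswith(".gz"):
        body = gzip.compress(body)  # matched by the gzip.decompress branch above
    return body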
Example #6
def detect_role_changes_and_update_cache(celery_app):
    """
    This function detects role changes through EventBridge rules and forces a refresh of the affected roles.
    """
    log_data = {"function": f"{__name__}.{sys._getframe().f_code.co_name}"}
    queue_arn = config.get(
        "event_bridge.detect_role_changes_and_update_cache.queue_arn", ""
    ).format(region=config.region)

    if not queue_arn:
        raise MissingConfigurationValue(
            "Unable to find required configuration value: "
            "`event_bridge.detect_role_changes_and_update_cache.queue_arn`"
        )
    queue_name = queue_arn.split(":")[-1]
    queue_account_number = queue_arn.split(":")[4]
    queue_region = queue_arn.split(":")[3]
    # Optionally assume a role before receiving messages from the queue
    queue_assume_role = config.get(
        "event_bridge.detect_role_changes_and_update_cache.assume_role"
    )

    sqs_client = boto3_cached_conn(
        "sqs",
        service_type="client",
        region=queue_region,
        retry_max_attempts=2,
        account_number=queue_account_number,
        assume_role=queue_assume_role,
    )

    queue_url_res = sqs_client.get_queue_url(QueueName=queue_name)
    queue_url = queue_url_res.get("QueueUrl")
    if not queue_url:
        raise DataNotRetrievable(f"Unable to retrieve Queue URL for {queue_arn}")
    roles_to_update = set()
    messages = sqs_client.receive_message(
        QueueUrl=queue_url, MaxNumberOfMessages=10
    ).get("Messages", [])

    while messages:
        processed_messages = []
        for message in messages:
            try:
                message_body = json.loads(message["Body"])
                decoded_message = json.loads(message_body["Message"])
                role_name = decoded_message["detail"]["requestParameters"]["roleName"]
                role_account_id = decoded_message["account"]
                role_arn = f"arn:aws:iam::{role_account_id}:role/{role_name}"

                if role_arn not in roles_to_update:
                    celery_app.send_task(
                        "consoleme.celery_tasks.celery_tasks.refresh_iam_role",
                        args=[role_arn],
                    )
                roles_to_update.add(role_arn)
            except Exception as e:
                log.error(
                    {**log_data, "error": str(e), "raw_message": message}, exc_info=True
                )
                sentry_sdk.capture_exception()
            processed_messages.append(
                {
                    "Id": message["MessageId"],
                    "ReceiptHandle": message["ReceiptHandle"],
                }
            )
        sqs_client.delete_message_batch(QueueUrl=queue_url, Entries=processed_messages)
        messages = sqs_client.receive_message(
            QueueUrl=queue_url, MaxNumberOfMessages=10
        ).get("Messages", [])
    log.debug(
        {
            **log_data,
            "num_roles": len(roles_to_update),
            "message": "Triggered role cache update for roles that were created or changed",
        }
    )

    return roles_to_update
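A hedged sketch of wiring this into a periodic Celery task; the app name, task name, and schedule below are illustrative, not taken from ConsoleMe:

from celery import Celery

app = Celery("consoleme")  # illustrative app name

@app.task
def poll_role_change_queue():
    # Sets aren't JSON-serializable, so convert for the result backend.
    return sorted(detect_role_changes_and_update_cache(app))

app.conf.beat_schedule = {
    "detect-role-changes": {
        # Must match the registered task name (module-qualified in practice).
        "task": "poll_role_change_queue",
        "schedule": 60.0,  # poll every minute (assumption)
    },
}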