Example #1
class DirectoryServiceBackend(BaseBackend):
    """Implementation of DirectoryService APIs."""
    def __init__(self, region_name=None):
        self.region_name = region_name
        self.directories = {}
        self.tagger = TaggingService()

    def reset(self):
        """Re-initialize all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """List of dicts representing default VPC endpoints for this service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "ds")

    @staticmethod
    def _verify_subnets(region, vpc_settings):
        """Verify subnets are valid, else raise an exception.

        If settings are valid, add AvailabilityZones to vpc_settings.
        """
        if len(vpc_settings["SubnetIds"]) != 2:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones.")

        # Subnet IDs are checked before the VPC ID.  The Subnet IDs must
        # be valid and in different availability zones.
        try:
            subnets = ec2_backends[region].get_all_subnets(
                subnet_ids=vpc_settings["SubnetIds"])
        except InvalidSubnetIdError as exc:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones.") from exc

        zones = [subnet.availability_zone for subnet in subnets]
        if zones[0] == zones[1]:
            raise ClientException(
                "Invalid subnet ID(s). The two subnets must be in "
                "different Availability Zones.")

        vpcs = ec2_backends[region].describe_vpcs()
        if vpc_settings["VpcId"] not in [x.id for x in vpcs]:
            raise ClientException("Invalid VPC ID.")
        vpc_settings["AvailabilityZones"] = zones

    def connect_directory(
        self,
        region,
        name,
        short_name,
        password,
        description,
        size,
        connect_settings,
        tags,
    ):  # pylint: disable=too-many-arguments
        """Create a fake AD Connector."""
        if len(self.directories) > Directory.CONNECTED_DIRECTORIES_LIMIT:
            raise DirectoryLimitExceededException(
                f"Directory limit exceeded. A maximum of "
                f"{Directory.CONNECTED_DIRECTORIES_LIMIT} directories may be created"
            )

        validate_args([
            ("password", password),
            ("size", size),
            ("name", name),
            ("description", description),
            ("shortName", short_name),
            (
                "connectSettings.vpcSettings.subnetIds",
                connect_settings["SubnetIds"],
            ),
            (
                "connectSettings.customerUserName",
                connect_settings["CustomerUserName"],
            ),
            ("connectSettings.customerDnsIps",
             connect_settings["CustomerDnsIps"]),
        ])
        # ConnectSettings and VpcSettings both have a VpcId and Subnets.
        self._verify_subnets(region, connect_settings)

        errmsg = self.tagger.validate_tags(tags or [])
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags or []) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise DirectoryLimitExceededException("Tag Limit is exceeding")

        directory = Directory(
            region,
            name,
            password,
            "ADConnector",
            size=size,
            connect_settings=connect_settings,
            short_name=short_name,
            description=description,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags or [])
        return directory.directory_id

    def create_directory(self, region, name, short_name, password, description,
                         size, vpc_settings, tags):  # pylint: disable=too-many-arguments
        """Create a fake Simple Ad Directory."""
        if len(self.directories) > Directory.CLOUDONLY_DIRECTORIES_LIMIT:
            raise DirectoryLimitExceededException(
                f"Directory limit exceeded. A maximum of "
                f"{Directory.CLOUDONLY_DIRECTORIES_LIMIT} directories may be created"
            )

        # botocore doesn't look for missing vpc_settings, but boto3 does.
        if not vpc_settings:
            raise InvalidParameterException("VpcSettings must be specified.")
        validate_args([
            ("password", password),
            ("size", size),
            ("name", name),
            ("description", description),
            ("shortName", short_name),
            ("vpcSettings.subnetIds", vpc_settings["SubnetIds"]),
        ])
        self._verify_subnets(region, vpc_settings)

        errmsg = self.tagger.validate_tags(tags or [])
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags or []) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise DirectoryLimitExceededException("Tag Limit is exceeding")

        directory = Directory(
            region,
            name,
            password,
            "SimpleAD",
            size=size,
            vpc_settings=vpc_settings,
            short_name=short_name,
            description=description,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags or [])
        return directory.directory_id

    def _validate_directory_id(self, directory_id):
        """Raise an exception if the directory id is invalid or unknown."""
        # Validation of ID takes precedence over a check for its existence.
        validate_args([("directoryId", directory_id)])
        if directory_id not in self.directories:
            raise EntityDoesNotExistException(
                f"Directory {directory_id} does not exist")

    def create_alias(self, directory_id, alias):
        """Create and assign an alias to a directory."""
        self._validate_directory_id(directory_id)

        # The default alias is the same as the directory ID.  Check
        # whether this directory was already given an alias.
        directory = self.directories[directory_id]
        if directory.alias != directory_id:
            raise InvalidParameterException(
                "The directory in the request already has an alias. That "
                "alias must be deleted before a new alias can be created.")

        # Is the alias already in use?
        if alias in [x.alias for x in self.directories.values()]:
            raise EntityAlreadyExistsException(
                f"Alias '{alias}' already exists.")
        validate_args([("alias", alias)])

        directory.update_alias(alias)
        return {"DirectoryId": directory_id, "Alias": alias}

    def create_microsoft_ad(
        self,
        region,
        name,
        short_name,
        password,
        description,
        vpc_settings,
        edition,
        tags,
    ):  # pylint: disable=too-many-arguments
        """Create a fake Microsoft Ad Directory."""
        if len(self.directories) > Directory.CLOUDONLY_MICROSOFT_AD_LIMIT:
            raise DirectoryLimitExceededException(
                f"Directory limit exceeded. A maximum of "
                f"{Directory.CLOUDONLY_MICROSOFT_AD_LIMIT} directories may be created"
            )

        # boto3 looks for missing vpc_settings for create_microsoft_ad().
        validate_args([
            ("password", password),
            ("edition", edition),
            ("name", name),
            ("description", description),
            ("shortName", short_name),
            ("vpcSettings.subnetIds", vpc_settings["SubnetIds"]),
        ])
        self._verify_subnets(region, vpc_settings)

        errmsg = self.tagger.validate_tags(tags or [])
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags or []) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise DirectoryLimitExceededException("Tag Limit is exceeding")

        directory = Directory(
            region,
            name,
            password,
            "MicrosoftAD",
            vpc_settings=vpc_settings,
            short_name=short_name,
            description=description,
            edition=edition,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags or [])
        return directory.directory_id

    def delete_directory(self, directory_id):
        """Delete directory with the matching ID."""
        self._validate_directory_id(directory_id)
        self.tagger.delete_all_tags_for_resource(directory_id)
        self.directories.pop(directory_id)
        return directory_id

    def disable_sso(self, directory_id, username=None, password=None):
        """Disable single-sign on for a directory."""
        self._validate_directory_id(directory_id)
        validate_args([("ssoPassword", password), ("userName", username)])
        directory = self.directories[directory_id]
        directory.enable_sso(False)

    def enable_sso(self, directory_id, username=None, password=None):
        """Enable single-sign on for a directory."""
        self._validate_directory_id(directory_id)
        validate_args([("ssoPassword", password), ("userName", username)])

        directory = self.directories[directory_id]
        if directory.alias == directory_id:
            raise ClientException(
                f"An alias is required before enabling SSO. DomainId={directory_id}"
            )

        directory.enable_sso(True)

    @paginate(pagination_model=PAGINATION_MODEL)
    def describe_directories(self,
                             directory_ids=None,
                             next_token=None,
                             limit=0):  # pylint: disable=unused-argument
        """Return info on all directories or directories with matching IDs."""
        for directory_id in directory_ids or self.directories:
            self._validate_directory_id(directory_id)

        directories = list(self.directories.values())
        if directory_ids:
            directories = [
                x for x in directories if x.directory_id in directory_ids
            ]
        return sorted(directories, key=lambda x: x.launch_time)

    def get_directory_limits(self):
        """Return hard-coded limits for the directories."""
        counts = {"SimpleAD": 0, "MicrosoftAD": 0, "ConnectedAD": 0}
        for directory in self.directories.values():
            if directory.directory_type == "SimpleAD":
                counts["SimpleAD"] += 1
            elif directory.directory_type in [
                    "MicrosoftAD", "SharedMicrosoftAD"
            ]:
                counts["MicrosoftAD"] += 1
            elif directory.directory_type == "ADConnector":
                counts["ConnectedAD"] += 1

        return {
            "CloudOnlyDirectoriesLimit": Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyDirectoriesCurrentCount": counts["SimpleAD"],
            "CloudOnlyDirectoriesLimitReached":
                counts["SimpleAD"] == Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyMicrosoftADLimit": Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "CloudOnlyMicrosoftADCurrentCount": counts["MicrosoftAD"],
            "CloudOnlyMicrosoftADLimitReached":
                counts["MicrosoftAD"] == Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "ConnectedDirectoriesLimit": Directory.CONNECTED_DIRECTORIES_LIMIT,
            "ConnectedDirectoriesCurrentCount": counts["ConnectedAD"],
            "ConnectedDirectoriesLimitReached":
                counts["ConnectedAD"] == Directory.CONNECTED_DIRECTORIES_LIMIT,
        }

    def add_tags_to_resource(self, resource_id, tags):
        """Add or overwrite one or more tags for specified directory."""
        self._validate_directory_id(resource_id)
        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise TagLimitExceededException("Tag limit exceeded")
        self.tagger.tag_resource(resource_id, tags)

    def remove_tags_from_resource(self, resource_id, tag_keys):
        """Removes tags from a directory."""
        self._validate_directory_id(resource_id)
        self.tagger.untag_resource_using_names(resource_id, tag_keys)

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_tags_for_resource(
        self,
        resource_id,
        next_token=None,
        limit=None,
    ):  # pylint: disable=unused-argument
        """List all tags on a directory."""
        self._validate_directory_id(resource_id)
        return self.tagger.list_tags_for_resource(resource_id).get("Tags")
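
A minimal usage sketch for the backend above, assuming moto's mock_ds/mock_ec2 decorators and boto3 (the region, names, CIDR blocks, and password are illustrative):

import boto3
from moto import mock_ds, mock_ec2

@mock_ec2
@mock_ds
def create_simple_ad_directory():
    ec2 = boto3.client("ec2", region_name="us-east-1")
    vpc_id = ec2.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"]["VpcId"]
    # Two subnets in different Availability Zones, as _verify_subnets requires.
    subnet_ids = [
        ec2.create_subnet(VpcId=vpc_id, CidrBlock=cidr, AvailabilityZone=az)[
            "Subnet"]["SubnetId"]
        for cidr, az in [("10.0.1.0/24", "us-east-1a"),
                         ("10.0.2.0/24", "us-east-1b")]
    ]

    ds = boto3.client("ds", region_name="us-east-1")
    response = ds.create_directory(
        Name="corp.example.com",
        Password="Password1!",  # must satisfy the password complexity rules
        Size="Small",
        VpcSettings={"VpcId": vpc_id, "SubnetIds": subnet_ids},
    )
    return response["DirectoryId"]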
Example #2
class Route53ResolverBackend(BaseBackend):
    """Implementation of Route53Resolver APIs."""

    def __init__(self, region_name, account_id):
        super().__init__(region_name, account_id)
        self.resolver_endpoints = {}  # Key is self-generated ID (endpoint_id)
        self.resolver_rules = {}  # Key is self-generated ID (rule_id)
        self.resolver_rule_associations = {}  # Key is resolver_rule_association_id
        self.tagger = TaggingService()

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """List of dicts representing default VPC endpoints for this service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "route53resolver"
        )

    def associate_resolver_rule(self, region, resolver_rule_id, name, vpc_id):
        validate_args(
            [("resolverRuleId", resolver_rule_id), ("name", name), ("vPCId", vpc_id)]
        )

        associations = [
            x for x in self.resolver_rule_associations.values() if x.region == region
        ]
        if len(associations) > ResolverRuleAssociation.MAX_RULE_ASSOCIATIONS_PER_REGION:
            # This error message was not verified to be the same for AWS.
            raise LimitExceededException(
                f"Account '{get_account_id()}' has exceeded 'max-rule-association'"
            )

        if resolver_rule_id not in self.resolver_rules:
            raise ResourceNotFoundException(
                f"Resolver rule with ID '{resolver_rule_id}' does not exist."
            )

        vpcs = ec2_backends[region].describe_vpcs()
        if vpc_id not in [x.id for x in vpcs]:
            raise InvalidParameterException(f"The vpc ID '{vpc_id}' does not exist")

        # Can't duplicate resolver rule, vpc id associations.
        for association in self.resolver_rule_associations.values():
            if (
                resolver_rule_id == association.resolver_rule_id
                and vpc_id == association.vpc_id
            ):
                raise InvalidRequestException(
                    f"Cannot associate rules with same domain name with same "
                    f"VPC. Conflict with resolver rule '{resolver_rule_id}'"
                )

        rule_association_id = f"rslvr-rrassoc-{get_random_hex(17)}"
        rule_association = ResolverRuleAssociation(
            region, rule_association_id, resolver_rule_id, vpc_id, name
        )
        self.resolver_rule_associations[rule_association_id] = rule_association
        return rule_association

    @staticmethod
    def _verify_subnet_ips(region, ip_addresses, initial=True):
        """Perform additional checks on the IPAddresses.

        NOTE: This does not include IPv6 addresses.
        """
        # Only the initial endpoint creation requires at least two IP addresses.
        if initial:
            if len(ip_addresses) < 2:
                raise InvalidRequestException(
                    "Resolver endpoint needs to have at least 2 IP addresses"
                )

        subnets = defaultdict(set)
        for subnet_id, ip_addr in [(x["SubnetId"], x["Ip"]) for x in ip_addresses]:
            try:
                subnet_info = ec2_backends[region].get_all_subnets(
                    subnet_ids=[subnet_id]
                )[0]
            except InvalidSubnetIdError as exc:
                raise InvalidParameterException(
                    f"The subnet ID '{subnet_id}' does not exist"
                ) from exc

            # IP in IPv4 CIDR range and not reserved?
            if ip_address(ip_addr) in subnet_info.reserved_ips or ip_address(
                ip_addr
            ) not in ip_network(subnet_info.cidr_block):
                raise InvalidRequestException(
                    f"IP address '{ip_addr}' is either not in subnet "
                    f"'{subnet_id}' CIDR range or is reserved"
                )

            if ip_addr in subnets[subnet_id]:
                raise ResourceExistsException(
                    f"The IP address '{ip_addr}' in subnet '{subnet_id}' is already in use"
                )
            subnets[subnet_id].add(ip_addr)

    @staticmethod
    def _verify_security_group_ids(region, security_group_ids):
        """Perform additional checks on the security groups."""
        if len(security_group_ids) > 10:
            raise InvalidParameterException("Maximum of 10 security groups are allowed")

        for group_id in security_group_ids:
            if not group_id.startswith("sg-"):
                raise InvalidParameterException(
                    f"Malformed security group ID: Invalid id: '{group_id}' "
                    f"(expecting 'sg-...')"
                )
            try:
                ec2_backends[region].describe_security_groups(group_ids=[group_id])
            except InvalidSecurityGroupNotFoundError as exc:
                raise ResourceNotFoundException(
                    f"The security group '{group_id}' does not exist"
                ) from exc

    def create_resolver_endpoint(
        self,
        region,
        creator_request_id,
        name,
        security_group_ids,
        direction,
        ip_addresses,
        tags,
    ):  # pylint: disable=too-many-arguments
        """
        Return description for a newly created resolver endpoint.

        NOTE:  IPv6 IPs are currently not being filtered when
        calculating the create_resolver_endpoint() IpAddresses.
        """
        validate_args(
            [
                ("creatorRequestId", creator_request_id),
                ("direction", direction),
                ("ipAddresses", ip_addresses),
                ("name", name),
                ("securityGroupIds", security_group_ids),
                ("ipAddresses.subnetId", ip_addresses),
            ]
        )
        errmsg = self.tagger.validate_tags(
            tags or [], limit=ResolverEndpoint.MAX_TAGS_PER_RESOLVER_ENDPOINT
        )
        if errmsg:
            raise TagValidationException(errmsg)

        endpoints = [x for x in self.resolver_endpoints.values() if x.region == region]
        if len(endpoints) > ResolverEndpoint.MAX_ENDPOINTS_PER_REGION:
            raise LimitExceededException(
                f"Account '{get_account_id()}' has exceeded 'max-endpoints'"
            )

        for x in ip_addresses:
            if not x.get("Ip"):
                subnet_info = ec2_backends[region].get_all_subnets(
                    subnet_ids=[x["SubnetId"]]
                )[0]
                x["Ip"] = subnet_info.get_available_subnet_ip(self)

        self._verify_subnet_ips(region, ip_addresses)
        self._verify_security_group_ids(region, security_group_ids)
        if creator_request_id in [
            x.creator_request_id for x in self.resolver_endpoints.values()
        ]:
            raise ResourceExistsException(
                f"Resolver endpoint with creator request ID "
                f"'{creator_request_id}' already exists"
            )

        endpoint_id = (
            f"rslvr-{'in' if direction == 'INBOUND' else 'out'}-{get_random_hex(17)}"
        )
        resolver_endpoint = ResolverEndpoint(
            region,
            endpoint_id,
            creator_request_id,
            security_group_ids,
            direction,
            ip_addresses,
            name,
        )

        self.resolver_endpoints[endpoint_id] = resolver_endpoint
        self.tagger.tag_resource(resolver_endpoint.arn, tags or [])
        return resolver_endpoint

    def create_resolver_rule(
        self,
        region,
        creator_request_id,
        name,
        rule_type,
        domain_name,
        target_ips,
        resolver_endpoint_id,
        tags,
    ):  # pylint: disable=too-many-arguments
        """Return description for a newly created resolver rule."""
        validate_args(
            [
                ("creatorRequestId", creator_request_id),
                ("ruleType", rule_type),
                ("domainName", domain_name),
                ("name", name),
                *[("targetIps.port", x) for x in target_ips],
                ("resolverEndpointId", resolver_endpoint_id),
            ]
        )
        errmsg = self.tagger.validate_tags(
            tags or [], limit=ResolverRule.MAX_TAGS_PER_RESOLVER_RULE
        )
        if errmsg:
            raise TagValidationException(errmsg)

        rules = [x for x in self.resolver_rules.values() if x.region == region]
        if len(rules) > ResolverRule.MAX_RULES_PER_REGION:
            # Did not verify that this is the actual error message.
            raise LimitExceededException(
                f"Account '{get_account_id()}' has exceeded 'max-rules'"
            )

        # Per the AWS documentation and as seen with the AWS console, target
        # ips are only relevant when the value of RuleType is FORWARD.
        # However, boto3 ignores this condition and so shall we.

        for ip_addr in [x["Ip"] for x in target_ips]:
            try:
                # boto3 fails with an InternalServiceException if IPv6
                # addresses are used, which isn't helpful.
                if not isinstance(ip_address(ip_addr), IPv4Address):
                    raise InvalidParameterException(
                        f"Only IPv4 addresses may be used: '{ip_addr}'"
                    )
            except ValueError as exc:
                raise InvalidParameterException(
                    f"Invalid IP address: '{ip_addr}'"
                ) from exc

        # The boto3 documentation indicates that ResolverEndpoint is
        # optional, as does the AWS documentation.  But if resolver_endpoint_id
        # is set to None or an empty string, boto3 raises a
        # ParamValidationError regarding either the type or the string length.
        if resolver_endpoint_id:
            if resolver_endpoint_id not in [
                x.id for x in self.resolver_endpoints.values()
            ]:
                raise ResourceNotFoundException(
                    f"Resolver endpoint with ID '{resolver_endpoint_id}' does not exist."
                )

            if rule_type == "SYSTEM":
                raise InvalidRequestException(
                    "Cannot specify resolver endpoint ID and target IP "
                    "for SYSTEM type resolver rule"
                )

        if creator_request_id in [
            x.creator_request_id for x in self.resolver_rules.values()
        ]:
            raise ResourceExistsException(
                f"Resolver rule with creator request ID "
                f"'{creator_request_id}' already exists"
            )

        rule_id = f"rslvr-rr-{get_random_hex(17)}"
        resolver_rule = ResolverRule(
            region,
            rule_id,
            creator_request_id,
            rule_type,
            domain_name,
            target_ips,
            resolver_endpoint_id,
            name,
        )

        self.resolver_rules[rule_id] = resolver_rule
        self.tagger.tag_resource(resolver_rule.arn, tags or [])
        return resolver_rule

    def _validate_resolver_endpoint_id(self, resolver_endpoint_id):
        """Raise an exception if the id is invalid or unknown."""
        validate_args([("resolverEndpointId", resolver_endpoint_id)])
        if resolver_endpoint_id not in self.resolver_endpoints:
            raise ResourceNotFoundException(
                f"Resolver endpoint with ID '{resolver_endpoint_id}' does not exist"
            )

    def delete_resolver_endpoint(self, resolver_endpoint_id):
        self._validate_resolver_endpoint_id(resolver_endpoint_id)

        # Can't delete an endpoint if there are rules associated with it.
        rules = [
            x.id
            for x in self.resolver_rules.values()
            if x.resolver_endpoint_id == resolver_endpoint_id
        ]
        if rules:
            raise InvalidRequestException(
                f"Cannot delete resolver endpoint unless its related resolver "
                f"rules are deleted.  The following rules still exist for "
                f"this resolver endpoint:  {','.join(rules)}"
            )

        self.tagger.delete_all_tags_for_resource(resolver_endpoint_id)
        resolver_endpoint = self.resolver_endpoints.pop(resolver_endpoint_id)
        resolver_endpoint.delete_eni()
        resolver_endpoint.status = "DELETING"
        resolver_endpoint.status_message = resolver_endpoint.status_message.replace(
            "Successfully created", "Deleting"
        )
        return resolver_endpoint

    def _validate_resolver_rule_id(self, resolver_rule_id):
        """Raise an exception if the id is invalid or unknown."""
        validate_args([("resolverRuleId", resolver_rule_id)])
        if resolver_rule_id not in self.resolver_rules:
            raise ResourceNotFoundException(
                f"Resolver rule with ID '{resolver_rule_id}' does not exist"
            )

    def delete_resolver_rule(self, resolver_rule_id):
        self._validate_resolver_rule_id(resolver_rule_id)

        # Can't delete a rule unless its VPCs are disassociated.
        associations = [
            x.id
            for x in self.resolver_rule_associations.values()
            if x.resolver_rule_id == resolver_rule_id
        ]
        if associations:
            raise ResourceInUseException(
                "Please disassociate this resolver rule from VPC first "
                "before deleting"
            )

        self.tagger.delete_all_tags_for_resource(resolver_rule_id)
        resolver_rule = self.resolver_rules.pop(resolver_rule_id)
        resolver_rule.status = "DELETING"
        resolver_rule.status_message = resolver_rule.status_message.replace(
            "Successfully created", "Deleting"
        )
        return resolver_rule

    def disassociate_resolver_rule(self, resolver_rule_id, vpc_id):
        validate_args([("resolverRuleId", resolver_rule_id), ("vPCId", vpc_id)])

        # Non-existent rule or vpc ids?
        if resolver_rule_id not in self.resolver_rules:
            raise ResourceNotFoundException(
                f"Resolver rule with ID '{resolver_rule_id}' does not exist"
            )

        # Find the matching association for this rule and vpc.
        rule_association_id = None
        for association in self.resolver_rule_associations.values():
            if (
                resolver_rule_id == association.resolver_rule_id
                and vpc_id == association.vpc_id
            ):
                rule_association_id = association.id
                break
        else:
            raise ResourceNotFoundException(
                f"Resolver Rule Association between Resolver Rule "
                f"'{resolver_rule_id}' and VPC '{vpc_id}' does not exist"
            )

        rule_association = self.resolver_rule_associations.pop(rule_association_id)
        rule_association.status = "DELETING"
        rule_association.status_message = "Deleting Association"
        return rule_association

    def get_resolver_endpoint(self, resolver_endpoint_id):
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        return self.resolver_endpoints[resolver_endpoint_id]

    def get_resolver_rule(self, resolver_rule_id):
        """Return info for specified resolver rule."""
        self._validate_resolver_rule_id(resolver_rule_id)
        return self.resolver_rules[resolver_rule_id]

    def get_resolver_rule_association(self, resolver_rule_association_id):
        validate_args([("resolverRuleAssociationId", resolver_rule_association_id)])
        if resolver_rule_association_id not in self.resolver_rule_associations:
            raise ResourceNotFoundException(
                f"ResolverRuleAssociation '{resolver_rule_association_id}' does not Exist"
            )
        return self.resolver_rule_associations[resolver_rule_association_id]

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_resolver_endpoint_ip_addresses(self, resolver_endpoint_id):
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        endpoint = self.resolver_endpoints[resolver_endpoint_id]
        return endpoint.ip_descriptions()

    @staticmethod
    def _add_field_name_to_filter(filters):
        """Convert both styles of filter names to lowercase snake format.

        "IP_ADDRESS_COUNT" or "IpAddressCount" will become "ip_address_count".
        However, "HostVPCId" doesn't fit the pattern, so that's treated
        special.
        """
        for rr_filter in filters:
            filter_name = rr_filter["Name"]
            if "_" not in filter_name:
                if "Vpc" in filter_name:
                    filter_name = "WRONG"
                elif filter_name == "HostVPCId":
                    filter_name = "host_vpc_id"
                elif filter_name == "VPCId":
                    filter_name = "vpc_id"
                elif filter_name in ["Type", "TYPE"]:
                    filter_name = "rule_type"
                elif not filter_name.isupper():
                    filter_name = CAMEL_TO_SNAKE_PATTERN.sub("_", filter_name)
            rr_filter["Field"] = filter_name.lower()

    @staticmethod
    def _validate_filters(filters, allowed_filter_names):
        """Raise exception if filter names are not as expected."""
        for rr_filter in filters:
            if rr_filter["Field"] not in allowed_filter_names:
                raise InvalidParameterException(
                    f"The filter '{rr_filter['Name']}' is invalid"
                )
            if "Values" not in rr_filter:
                raise InvalidParameterException(
                    f"No values specified for filter {rr_filter['Name']}"
                )

    @staticmethod
    def _matches_all_filters(entity, filters):
        """Return True if this entity has fields matching all the filters."""
        for rr_filter in filters:
            field_value = getattr(entity, rr_filter["Field"])

            if isinstance(field_value, list):
                if not set(field_value).intersection(rr_filter["Values"]):
                    return False
            elif isinstance(field_value, int):
                if str(field_value) not in rr_filter["Values"]:
                    return False
            elif field_value not in rr_filter["Values"]:
                return False

        return True

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_resolver_endpoints(self, filters):
        if not filters:
            filters = []

        self._add_field_name_to_filter(filters)
        self._validate_filters(filters, ResolverEndpoint.FILTER_NAMES)

        endpoints = []
        for endpoint in sorted(self.resolver_endpoints.values(), key=lambda x: x.name):
            if self._matches_all_filters(endpoint, filters):
                endpoints.append(endpoint)
        return endpoints

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_resolver_rules(self, filters):
        if not filters:
            filters = []

        self._add_field_name_to_filter(filters)
        self._validate_filters(filters, ResolverRule.FILTER_NAMES)

        rules = []
        for rule in sorted(self.resolver_rules.values(), key=lambda x: x.name):
            if self._matches_all_filters(rule, filters):
                rules.append(rule)
        return rules

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_resolver_rule_associations(self, filters):
        if not filters:
            filters = []

        self._add_field_name_to_filter(filters)
        self._validate_filters(filters, ResolverRuleAssociation.FILTER_NAMES)

        rules = []
        for rule in sorted(
            self.resolver_rule_associations.values(), key=lambda x: x.name
        ):
            if self._matches_all_filters(rule, filters):
                rules.append(rule)
        return rules

    def _matched_arn(self, resource_arn):
        """Given ARN, raise exception if there is no corresponding resource."""
        for resolver_endpoint in self.resolver_endpoints.values():
            if resolver_endpoint.arn == resource_arn:
                return
        for resolver_rule in self.resolver_rules.values():
            if resolver_rule.arn == resource_arn:
                return
        raise ResourceNotFoundException(
            f"Resolver endpoint with ID '{resource_arn}' does not exist"
        )

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_tags_for_resource(self, resource_arn):
        self._matched_arn(resource_arn)
        return self.tagger.list_tags_for_resource(resource_arn).get("Tags")

    def tag_resource(self, resource_arn, tags):
        self._matched_arn(resource_arn)
        errmsg = self.tagger.validate_tags(
            tags, limit=ResolverEndpoint.MAX_TAGS_PER_RESOLVER_ENDPOINT
        )
        if errmsg:
            raise TagValidationException(errmsg)
        self.tagger.tag_resource(resource_arn, tags)

    def untag_resource(self, resource_arn, tag_keys):
        self._matched_arn(resource_arn)
        self.tagger.untag_resource_using_names(resource_arn, tag_keys)

    def update_resolver_endpoint(self, resolver_endpoint_id, name):
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        validate_args([("name", name)])
        resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]
        resolver_endpoint.update_name(name)
        return resolver_endpoint

    def associate_resolver_endpoint_ip_address(
        self, region, resolver_endpoint_id, ip_address
    ):
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]

        if not ip_address.get("Ip"):
            subnet_info = ec2_backends[region].get_all_subnets(
                subnet_ids=[ip_address.get("SubnetId")]
            )[0]
            ip_address["Ip"] = subnet_info.get_available_subnet_ip(self)
        self._verify_subnet_ips(region, [ip_address], False)

        resolver_endpoint.associate_ip_address(ip_address)
        return resolver_endpoint

    def disassociate_resolver_endpoint_ip_address(
        self, resolver_endpoint_id, ip_address
    ):
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]

        if not (ip_address.get("Ip") or ip_address.get("IpId")):
            raise InvalidRequestException(
                "[RSLVR-00503] Need to specify either the IP ID or both subnet and IP address in order to remove IP address."
            )

        resolver_endpoint.disassociate_ip_address(ip_address)
        return resolver_endpoint
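
A minimal usage sketch for the backend above, assuming moto's mock_route53resolver/mock_ec2 decorators and boto3 (the region, names, and addresses are illustrative):

import boto3
from moto import mock_ec2, mock_route53resolver

@mock_ec2
@mock_route53resolver
def create_inbound_endpoint():
    ec2 = boto3.client("ec2", region_name="us-east-1")
    vpc_id = ec2.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"]["VpcId"]
    subnet_ids = [
        ec2.create_subnet(VpcId=vpc_id, CidrBlock=cidr, AvailabilityZone=az)[
            "Subnet"]["SubnetId"]
        for cidr, az in [("10.0.1.0/24", "us-east-1a"),
                         ("10.0.2.0/24", "us-east-1b")]
    ]
    sg_id = ec2.create_security_group(
        GroupName="resolver-sg", Description="resolver", VpcId=vpc_id
    )["GroupId"]

    resolver = boto3.client("route53resolver", region_name="us-east-1")
    response = resolver.create_resolver_endpoint(
        CreatorRequestId="request-0001",
        Name="inbound-endpoint",
        SecurityGroupIds=[sg_id],
        Direction="INBOUND",
        # At least two IP addresses are required on initial creation,
        # per _verify_subnet_ips().
        IpAddresses=[
            {"SubnetId": subnet_ids[0], "Ip": "10.0.1.10"},
            {"SubnetId": subnet_ids[1], "Ip": "10.0.2.10"},
        ],
    )
    return response["ResolverEndpoint"]["Id"]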
Example #3
class DirectoryServiceBackend(BaseBackend):
    """Implementation of DirectoryService APIs."""

    def __init__(self, region_name=None):
        self.region_name = region_name
        self.directories = {}
        self.tagger = TaggingService()

    def reset(self):
        """Re-initialize all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """List of dicts representing default VPC endpoints for this service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "ds"
        )

    @staticmethod
    def _validate_create_directory_args(
        name, passwd, size, vpc_settings, description, short_name,
    ):  # pylint: disable=too-many-arguments
        """Raise exception if create_directory() args don't meet constraints.

        The error messages are accumulated before the exception is raised.
        """
        error_tuples = []
        passwd_pattern = (
            r"(?=^.{8,64}$)((?=.*\d)(?=.*[A-Z])(?=.*[a-z])|"
            r"(?=.*\d)(?=.*[^A-Za-z0-9\s])(?=.*[a-z])|"
            r"(?=.*[^A-Za-z0-9\s])(?=.*[A-Z])(?=.*[a-z])|"
            r"(?=.*\d)(?=.*[A-Z])(?=.*[^A-Za-z0-9\s]))^.*"
        )
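        # The pattern accepts 8-64 characters drawn from at least three of
        # the four classes: lowercase, uppercase, digits, and specials.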
        if not re.match(passwd_pattern, passwd):
            # Can't have an odd number of backslashes in a literal.
            json_pattern = passwd_pattern.replace("\\", r"\\")
            error_tuples.append(
                (
                    "password",
                    passwd,
                    fr"satisfy regular expression pattern: {json_pattern}",
                )
            )

        if size.lower() not in ["small", "large"]:
            error_tuples.append(
                ("size", size, "satisfy enum value set: [Small, Large]")
            )

        name_pattern = r"^([a-zA-Z0-9]+[\\.-])+([a-zA-Z0-9])+$"
        if not re.match(name_pattern, name):
            error_tuples.append(
                ("name", name, fr"satisfy regular expression pattern: {name_pattern}")
            )

        subnet_id_pattern = r"^(subnet-[0-9a-f]{8}|subnet-[0-9a-f]{17})$"
        for subnet in vpc_settings["SubnetIds"]:
            if not re.match(subnet_id_pattern, subnet):
                error_tuples.append(
                    (
                        "vpcSettings.subnetIds",
                        subnet,
                        fr"satisfy regular expression pattern: {subnet_id_pattern}",
                    )
                )

        if description and len(description) > 128:
            error_tuples.append(
                ("description", description, "have length less than or equal to 128")
            )

        short_name_pattern = r'^[^\/:*?"<>|.]+[^\/:*?"<>|]*$'
        if short_name and not re.match(short_name_pattern, short_name):
            json_pattern = short_name_pattern.replace("\\", r"\\").replace('"', r"\"")
            error_tuples.append(
                (
                    "shortName",
                    short_name,
                    fr"satisfy regular expression pattern: {json_pattern}",
                )
            )

        if error_tuples:
            raise DsValidationException(error_tuples)

    @staticmethod
    def _validate_vpc_setting_values(region, vpc_settings):
        """Raise exception if vpc_settings are invalid.

        If settings are valid, add AvailabilityZones to vpc_settings.
        """
        if len(vpc_settings["SubnetIds"]) != 2:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones."
            )

        from moto.ec2 import ec2_backends  # pylint: disable=import-outside-toplevel

        # Subnet IDs are checked before the VPC ID.  The Subnet IDs must
        # be valid and in different availability zones.
        try:
            subnets = ec2_backends[region].get_all_subnets(
                subnet_ids=vpc_settings["SubnetIds"]
            )
        except InvalidSubnetIdError as exc:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones."
            ) from exc

        zones = [subnet.availability_zone for subnet in subnets]
        if zones[0] == zones[1]:
            raise ClientException(
                "Invalid subnet ID(s). The two subnets must be in "
                "different Availability Zones."
            )

        vpcs = ec2_backends[region].describe_vpcs()
        if vpc_settings["VpcId"] not in [x.id for x in vpcs]:
            raise ClientException("Invalid VPC ID.")

        vpc_settings["AvailabilityZones"] = zones

    def create_directory(
        self, region, name, short_name, password, description, size, vpc_settings, tags
    ):  # pylint: disable=too-many-arguments
        """Create a fake Simple Ad Directory."""
        if len(self.directories) > Directory.CLOUDONLY_DIRECTORIES_LIMIT:
            raise DirectoryLimitExceededException(
                f"Directory limit exceeded. A maximum of "
                f"{Directory.CLOUDONLY_DIRECTORIES_LIMIT} directories may be created"
            )

        # botocore doesn't look for missing vpc_settings, but boto3 does.
        if not vpc_settings:
            raise InvalidParameterException("VpcSettings must be specified.")

        self._validate_create_directory_args(
            name, password, size, vpc_settings, description, short_name,
        )
        self._validate_vpc_setting_values(region, vpc_settings)

        errmsg = self.tagger.validate_tags(tags or [])
        if errmsg:
            raise ValidationException(errmsg)

        if len(tags or []) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise DirectoryLimitExceededException("Tag Limit is exceeding")

        directory = Directory(
            name,
            password,
            size,
            vpc_settings,
            directory_type="SimpleAD",
            short_name=short_name,
            description=description,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags or [])
        return directory.directory_id

    def _validate_directory_id(self, directory_id):
        """Raise an exception if the directory id is invalid or unknown."""
        # Validation of ID takes precedence over a check for its existence.
        id_pattern = r"^d-[0-9a-f]{10}$"
        if not re.match(id_pattern, directory_id):
            raise DsValidationException(
                [
                    (
                        "directoryId",
                        directory_id,
                        fr"satisfy regular expression pattern: {id_pattern}",
                    )
                ]
            )

        if directory_id not in self.directories:
            raise EntityDoesNotExistException(
                f"Directory {directory_id} does not exist"
            )

    def delete_directory(self, directory_id):
        """Delete directory with the matching ID."""
        self._validate_directory_id(directory_id)
        self.tagger.delete_all_tags_for_resource(directory_id)
        self.directories.pop(directory_id)
        return directory_id

    @paginate(pagination_model=PAGINATION_MODEL)
    def describe_directories(
        self, directory_ids=None, next_token=None, limit=0
    ):  # pylint: disable=unused-argument
        """Return info on all directories or directories with matching IDs."""
        for directory_id in directory_ids or self.directories:
            self._validate_directory_id(directory_id)

        directories = list(self.directories.values())
        if directory_ids:
            directories = [x for x in directories if x.directory_id in directory_ids]
        return sorted(directories, key=lambda x: x.launch_time)

    def get_directory_limits(self):
        """Return hard-coded limits for the directories."""
        counts = {"SimpleAD": 0, "MicrosoftAD": 0, "ConnectedAD": 0}
        for directory in self.directories.values():
            if directory.directory_type == "SimpleAD":
                counts["SimpleAD"] += 1
            elif directory.directory_type in ["MicrosoftAD", "SharedMicrosoftAD"]:
                counts["MicrosoftAD"] += 1
            elif directory.directory_type == "ADConnector":
                counts["ConnectedAD"] += 1

        return {
            "CloudOnlyDirectoriesLimit": Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyDirectoriesCurrentCount": counts["SimpleAD"],
            "CloudOnlyDirectoriesLimitReached": counts["SimpleAD"]
            == Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyMicrosoftADLimit": Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "CloudOnlyMicrosoftADCurrentCount": counts["MicrosoftAD"],
            "CloudOnlyMicrosoftADLimitReached": counts["MicrosoftAD"]
            == Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "ConnectedDirectoriesLimit": Directory.CONNECTED_DIRECTORIES_LIMIT,
            "ConnectedDirectoriesCurrentCount": counts["ConnectedAD"],
            "ConnectedDirectoriesLimitReached": counts["ConnectedAD"]
            == Directory.CONNECTED_DIRECTORIES_LIMIT,
        }

    def add_tags_to_resource(self, resource_id, tags):
        """Add or overwrite one or more tags for specified directory."""
        self._validate_directory_id(resource_id)
        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise TagLimitExceededException("Tag limit exceeded")
        self.tagger.tag_resource(resource_id, tags)

    def remove_tags_from_resource(self, resource_id, tag_keys):
        """Removes tags from a directory."""
        self._validate_directory_id(resource_id)
        self.tagger.untag_resource_using_names(resource_id, tag_keys)

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_tags_for_resource(
        self, resource_id, next_token=None, limit=None,
    ):  # pylint: disable=unused-argument
        """List all tags on a directory."""
        self._validate_directory_id(resource_id)
        return self.tagger.list_tags_for_resource(resource_id).get("Tags")
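
The password pattern used in _validate_create_directory_args() above can be checked in isolation; a small self-contained sketch (the sample passwords are illustrative):

import re

PASSWD_PATTERN = (
    r"(?=^.{8,64}$)((?=.*\d)(?=.*[A-Z])(?=.*[a-z])|"
    r"(?=.*\d)(?=.*[^A-Za-z0-9\s])(?=.*[a-z])|"
    r"(?=.*[^A-Za-z0-9\s])(?=.*[A-Z])(?=.*[a-z])|"
    r"(?=.*\d)(?=.*[A-Z])(?=.*[^A-Za-z0-9\s]))^.*"
)

assert re.match(PASSWD_PATTERN, "Password1")      # lower + upper + digit
assert re.match(PASSWD_PATTERN, "pass-word1")     # lower + special + digit
assert not re.match(PASSWD_PATTERN, "password")   # only one character class
assert not re.match(PASSWD_PATTERN, "Pw1!")       # shorter than 8 characters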
Example #4
class FirehoseBackend(BaseBackend):
    """Implementation of Firehose APIs."""
    def __init__(self, region_name=None):
        self.region_name = region_name
        self.delivery_streams = {}
        self.tagger = TaggingService()

    def reset(self):
        """Re-initializes all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """Default VPC endpoint service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region,
            zones,
            "firehose",
            special_service_name="kinesis-firehose")

    def create_delivery_stream(
        self,
        region,
        delivery_stream_name,
        delivery_stream_type,
        kinesis_stream_source_configuration,
        delivery_stream_encryption_configuration_input,
        s3_destination_configuration,
        extended_s3_destination_configuration,
        redshift_destination_configuration,
        elasticsearch_destination_configuration,
        splunk_destination_configuration,
        http_endpoint_destination_configuration,
        tags,
    ):  # pylint: disable=too-many-arguments,too-many-locals,unused-argument
        """Create a Kinesis Data Firehose delivery stream."""
        (destination_name,
         destination_config) = find_destination_config_in_args(locals())

        if delivery_stream_name in self.delivery_streams:
            raise ResourceInUseException(
                f"Firehose {delivery_stream_name} under accountId {get_account_id()} "
                f"already exists")

        if len(self.delivery_streams) == DeliveryStream.MAX_STREAMS_PER_REGION:
            raise LimitExceededException(
                f"You have already consumed your firehose quota of "
                f"{DeliveryStream.MAX_STREAMS_PER_REGION} hoses. Firehose "
                f"names: {list(self.delivery_streams.keys())}")

        # Rule out situations that are not yet implemented.
        if delivery_stream_encryption_configuration_input:
            warnings.warn(
                "A delivery stream with server-side encryption enabled is not "
                "yet implemented")

        if destination_name == "Splunk":
            warnings.warn(
                "A Splunk destination delivery stream is not yet implemented")

        if (kinesis_stream_source_configuration
                and delivery_stream_type != "KinesisStreamAsSource"):
            raise InvalidArgumentException(
                "KinesisSourceStreamConfig is only applicable for "
                "KinesisStreamAsSource stream type")

        # Validate the tags before proceeding.
        errmsg = self.tagger.validate_tags(tags or [])
        if errmsg:
            raise ValidationException(errmsg)

        if tags and len(tags) > MAX_TAGS_PER_DELIVERY_STREAM:
            raise ValidationException(
                f"1 validation error detected: Value '{tags}' at 'tags' "
                f"failed to satisify contstraint: Member must have length "
                f"less than or equal to {MAX_TAGS_PER_DELIVERY_STREAM}")

        # Create a DeliveryStream instance that will be stored and indexed
        # by delivery stream name.  This instance will update the state and
        # create the ARN.
        delivery_stream = DeliveryStream(
            region,
            delivery_stream_name,
            delivery_stream_type,
            kinesis_stream_source_configuration,
            destination_name,
            destination_config,
        )
        self.tagger.tag_resource(delivery_stream.delivery_stream_arn,
                                 tags or [])

        self.delivery_streams[delivery_stream_name] = delivery_stream
        return self.delivery_streams[delivery_stream_name].delivery_stream_arn

    def delete_delivery_stream(self,
                               delivery_stream_name,
                               allow_force_delete=False):  # pylint: disable=unused-argument
        """Delete a delivery stream and its data.

        AllowForceDelete option is ignored as we only superficially
        apply state.
        """
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found.")

        self.tagger.delete_all_tags_for_resource(
            delivery_stream.delivery_stream_arn)

        delivery_stream.delivery_stream_status = "DELETING"
        self.delivery_streams.pop(delivery_stream_name)

    def describe_delivery_stream(self, delivery_stream_name, limit,
                                 exclusive_start_destination_id):  # pylint: disable=unused-argument
        """Return description of specified delivery stream and its status.

        Note:  the 'limit' and 'exclusive_start_destination_id' parameters
        are not currently processed/implemented.
        """
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found.")

        result = {"DeliveryStreamDescription": {"HasMoreDestinations": False}}
        for attribute, attribute_value in vars(delivery_stream).items():
            if not attribute_value:
                continue

            # Convert from attribute's snake case to camel case for outgoing
            # JSON.
            name = "".join([x.capitalize() for x in attribute.split("_")])

            # Fooey ... always an exception to the rule:
            if name == "DeliveryStreamArn":
                name = "DeliveryStreamARN"

            if name != "Destinations":
                if name == "Source":
                    result["DeliveryStreamDescription"][name] = {
                        "KinesisStreamSourceDescription": attribute_value
                    }
                else:
                    result["DeliveryStreamDescription"][name] = attribute_value
                continue

            result["DeliveryStreamDescription"]["Destinations"] = []
            for destination in attribute_value:
                description = {}
                for key, value in destination.items():
                    if key == "destination_id":
                        description["DestinationId"] = value
                    else:
                        description[f"{key}DestinationDescription"] = value

                result["DeliveryStreamDescription"]["Destinations"].append(
                    description)

        return result

    def list_delivery_streams(self, limit, delivery_stream_type,
                              exclusive_start_delivery_stream_name):
        """Return list of delivery streams in alphabetic order of names."""
        result = {"DeliveryStreamNames": [], "HasMoreDeliveryStreams": False}
        if not self.delivery_streams:
            return result

        # If delivery_stream_type is specified, filter out any stream that's
        # not of that type.
        stream_list = self.delivery_streams.keys()
        if delivery_stream_type:
            stream_list = [
                x for x in stream_list
                if self.delivery_streams[x].delivery_stream_type ==
                delivery_stream_type
            ]

        # The list is sorted alphabetically, not alphanumerically.
        sorted_list = sorted(stream_list)

        # Determine the limit or number of names to return in the list.
        limit = limit or DeliveryStream.MAX_STREAMS_PER_REGION

        # If a starting delivery stream name is given, find the index into
        # the sorted list, then add one to get the name following it.  If the
        # exclusive_start_delivery_stream_name doesn't exist, it's ignored.
        start = 0
        if exclusive_start_delivery_stream_name:
            if self.delivery_streams.get(exclusive_start_delivery_stream_name):
                start = sorted_list.index(
                    exclusive_start_delivery_stream_name) + 1

        result["DeliveryStreamNames"] = sorted_list[start:start + limit]
        if len(sorted_list) > (start + limit):
            result["HasMoreDeliveryStreams"] = True
        return result

    def list_tags_for_delivery_stream(self, delivery_stream_name,
                                      exclusive_start_tag_key, limit):
        """Return list of tags."""
        result = {"Tags": [], "HasMoreTags": False}
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found.")

        tags = self.tagger.list_tags_for_resource(
            delivery_stream.delivery_stream_arn)["Tags"]
        keys = self.tagger.extract_tag_names(tags)

        # If a starting tag is given and can be found, find the index into
        # tags, then add one to get the tag following it.
        start = 0
        if exclusive_start_tag_key:
            if exclusive_start_tag_key in keys:
                start = keys.index(exclusive_start_tag_key) + 1

        limit = limit or MAX_TAGS_PER_DELIVERY_STREAM
        result["Tags"] = tags[start:start + limit]
        if len(tags) > (start + limit):
            result["HasMoreTags"] = True
        return result

    def put_record(self, delivery_stream_name, record):
        """Write a single data record into a Kinesis Data firehose stream."""
        result = self.put_record_batch(delivery_stream_name, [record])
        return {
            "RecordId": result["RequestResponses"][0]["RecordId"],
            "Encrypted": False,
        }

    @staticmethod
    def put_http_records(http_destination, records):
        """Put records to a HTTP destination."""
        # Mostly copied from localstack
        url = http_destination["EndpointConfiguration"]["Url"]
        headers = {"Content-Type": "application/json"}
        record_to_send = {
            "requestId": str(uuid4()),
            "timestamp": int(time()),
            "records": [{
                "data": record["Data"]
            } for record in records],
        }
        try:
            requests.post(url, json=record_to_send, headers=headers)
        except Exception as exc:
            # This could be better ...
            raise RuntimeError(
                "Firehose PutRecord(Batch) to HTTP destination failed"
            ) from exc
        return [{"RecordId": str(uuid4())} for _ in range(len(records))]

    @staticmethod
    def _format_s3_object_path(delivery_stream_name, version_id, prefix):
        """Return a S3 object path in the expected format."""
        # Taken from LocalStack's firehose logic, with minor changes.
        # See https://docs.aws.amazon.com/firehose/latest/dev/basic-deliver.html#s3-object-name
        # Path prefix pattern: myApp/YYYY/MM/DD/HH/
        # Object name pattern:
        # DeliveryStreamName-DeliveryStreamVersion-YYYY-MM-DD-HH-MM-SS-RandomString
        prefix = f"{prefix}{'' if prefix.endswith('/') else '/'}"
        now = datetime.utcnow()
        return (f"{prefix}{now.strftime('%Y/%m/%d/%H')}/"
                f"{delivery_stream_name}-{version_id}-"
                f"{now.strftime('%Y-%m-%d-%H-%M-%S')}-{str(uuid4())}")

    def put_s3_records(self, delivery_stream_name, version_id, s3_destination,
                       records):
        """Put records to a ExtendedS3 or S3 destination."""
        # Taken from LocalStack's firehose logic, with minor changes.
        bucket_name = s3_destination["BucketARN"].split(":")[-1]
        prefix = s3_destination.get("Prefix", "")
        object_path = self._format_s3_object_path(delivery_stream_name,
                                                  version_id, prefix)

        batched_data = b"".join([b64decode(r["Data"]) for r in records])
        try:
            s3_backend.put_object(bucket_name, object_path, batched_data)
        except Exception as exc:
            # This could be better ...
            raise RuntimeError(
                "Firehose PutRecord(Batch to S3 destination failed") from exc
        return [{"RecordId": str(uuid4())} for _ in range(len(records))]

    def put_record_batch(self, delivery_stream_name, records):
        """Write multiple data records into a Kinesis Data firehose stream."""
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found.")

        request_responses = []
        for destination in delivery_stream.destinations:
            if "ExtendedS3" in destination:
                # ExtendedS3 is handled like S3, but this will probably need
                # to be revisited in the future.  This destination must be
                # checked before S3, otherwise both destinations would be
                # processed instead of just ExtendedS3.
                request_responses = self.put_s3_records(
                    delivery_stream_name,
                    delivery_stream.version_id,
                    destination["ExtendedS3"],
                    records,
                )
            elif "S3" in destination:
                request_responses = self.put_s3_records(
                    delivery_stream_name,
                    delivery_stream.version_id,
                    destination["S3"],
                    records,
                )
            elif "HttpEndpoint" in destination:
                request_responses = self.put_http_records(
                    destination["HttpEndpoint"], records)
            elif "Elasticsearch" in destination or "Redshift" in destination:
                # These services aren't implemented, so ignore the data but
                # return a "proper" response.
                request_responses = [{
                    "RecordId": str(uuid4())
                } for _ in range(len(records))]

        return {
            "FailedPutCount": 0,
            "Encrypted": False,
            "RequestResponses": request_responses,
        }
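
    # A minimal usage sketch (hypothetical `backend` with an existing
    # S3-destination stream named "mystream" and a backing bucket):
    #
    #     from base64 import b64encode
    #     result = backend.put_record_batch(
    #         "mystream", [{"Data": b64encode(b"hello").decode("utf-8")}])
    #     assert result["FailedPutCount"] == 0
    #     assert len(result["RequestResponses"]) == 1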

    def tag_delivery_stream(self, delivery_stream_name, tags):
        """Add/update tags for specified delivery stream."""
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found.")

        if len(tags) > MAX_TAGS_PER_DELIVERY_STREAM:
            raise ValidationException(
                f"1 validation error detected: Value '{tags}' at 'tags' "
                f"failed to satisify contstraint: Member must have length "
                f"less than or equal to {MAX_TAGS_PER_DELIVERY_STREAM}")

        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)

        self.tagger.tag_resource(delivery_stream.delivery_stream_arn, tags)
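
    # Tags use the standard AWS Key/Value shape, e.g. (hypothetical values):
    #
    #     backend.tag_delivery_stream(
    #         "mystream", [{"Key": "env", "Value": "test"}])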

    def untag_delivery_stream(self, delivery_stream_name, tag_keys):
        """Removes tags from specified delivery stream."""
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found.")

        # If a tag key doesn't exist for the stream, boto3 ignores it.
        self.tagger.untag_resource_using_names(
            delivery_stream.delivery_stream_arn, tag_keys)

    def update_destination(
        self,
        delivery_stream_name,
        current_delivery_stream_version_id,
        destination_id,
        s3_destination_update,
        extended_s3_destination_update,
        s3_backup_mode,
        redshift_destination_update,
        elasticsearch_destination_update,
        splunk_destination_update,
        http_endpoint_destination_update,
    ):  # pylint: disable=unused-argument,too-many-arguments,too-many-locals
        """Updates specified destination of specified delivery stream."""
        (destination_name,
         destination_config) = find_destination_config_in_args(locals())

        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under accountId "
                f"{get_account_id()} not found.")

        if destination_name == "Splunk":
            warnings.warn(
                "A Splunk destination delivery stream is not yet implemented")

        if delivery_stream.version_id != current_delivery_stream_version_id:
            raise ConcurrentModificationException(
                f"Cannot update firehose: {delivery_stream_name} since the "
                f"current version id: {delivery_stream.version_id} and "
                f"specified version id: {current_delivery_stream_version_id} "
                f"do not match")

        destination = {}
        destination_idx = 0
        for destination in delivery_stream.destinations:
            if destination["destination_id"] == destination_id:
                break
            destination_idx += 1
        else:
            raise InvalidArgumentException(
                "Destination Id {destination_id} not found")

        # Switching between Amazon ES and other services is not supported.
        # For an Amazon ES destination, you can only update to another
        # Amazon ES destination; the same applies to HTTP endpoints.  Splunk
        # was not tested.
        if (destination_name == "Elasticsearch" and "Elasticsearch"
                not in destination) or (destination_name == "HttpEndpoint"
                                        and "HttpEndpoint" not in destination):
            raise InvalidArgumentException(
                f"Changing the destination type to or from {destination_name} "
                f"is not supported at this time.")

        # If this is a different type of destination configuration,
        # the existing configuration is reset first.
        if destination_name in destination:
            delivery_stream.destinations[destination_idx][
                destination_name].update(destination_config)
        else:
            delivery_stream.destinations[destination_idx] = {
                "destination_id": destination_id,
                destination_name: destination_config,
            }

        # Once S3 is updated to an ExtendedS3 destination, both remain in
        # the destination.  That means when one is updated, the other needs
        # to be updated as well.  The problem is that they don't have the
        # same fields.
        if destination_name == "ExtendedS3":
            delivery_stream.destinations[destination_idx][
                "S3"] = create_s3_destination_config(destination_config)
        elif destination_name == "S3" and "ExtendedS3" in destination:
            destination["ExtendedS3"] = {
                k: v
                for k, v in destination["S3"].items()
                if k in destination["ExtendedS3"]
            }

        # Increment version number and update the timestamp.
        delivery_stream.version_id = str(
            int(current_delivery_stream_version_id) + 1)
        delivery_stream.last_update_timestamp = datetime.now(
            timezone.utc).isoformat()

        # Unimplemented: processing of the "S3BackupMode" parameter.  Per the
        # documentation:  "You can update a delivery stream to enable Amazon
        # S3 backup if it is disabled.  If backup is enabled, you can't update
        # the delivery stream to disable it."

    def lookup_name_from_arn(self, arn):
        """Given an ARN, return the associated delivery stream name."""
        return self.delivery_streams.get(arn.split("/")[-1])
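
    # Firehose ARNs end in "deliverystream/<name>", so splitting on "/" and
    # taking the last element recovers the name, e.g. (illustrative):
    #
    #     arn = ("arn:aws:firehose:us-east-1:123456789012:"
    #            "deliverystream/mystream")
    #     arn.split("/")[-1]  # -> "mystream"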

    def send_log_event(
        self,
        delivery_stream_arn,
        filter_name,
        log_group_name,
        log_stream_name,
        log_events,
    ):  # pylint: disable=too-many-arguments
        """Send log events to a S3 bucket after encoding and gzipping it."""
        data = {
            "logEvents": log_events,
            "logGroup": log_group_name,
            "logStream": log_stream_name,
            "messageType": "DATA_MESSAGE",
            "owner": get_account_id(),
            "subscriptionFilters": [filter_name],
        }

        output = io.BytesIO()
        with GzipFile(fileobj=output, mode="w") as fhandle:
            fhandle.write(
                json.dumps(data, separators=(",", ":")).encode("utf-8"))
        gzipped_payload = b64encode(output.getvalue()).decode("utf-8")

        delivery_stream = self.lookup_name_from_arn(delivery_stream_arn)
        self.put_s3_records(
            delivery_stream.delivery_stream_name,
            delivery_stream.version_id,
            delivery_stream.destinations[0]["S3"],
            [{"Data": gzipped_payload}],
        )
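
    # The stored payload can be decoded by reversing the steps above, e.g.
    # (sketch):
    #
    #     from base64 import b64decode
    #     from gzip import decompress
    #     assert json.loads(decompress(b64decode(gzipped_payload))) == data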
Exemple #5
0
def test_validate_tags():
    """Unit tests for validate_tags()."""
    svc = TaggingService()
    # Key with invalid characters.
    errors = svc.validate_tags([{"Key": "foo!", "Value": "bar"}])
    assert (
        "Value 'foo!' at 'tags.1.member.key' failed to satisfy constraint: "
        "Member must satisfy regular expression pattern") in errors

    # Value with invalid characters.
    errors = svc.validate_tags([{"Key": "foo", "Value": "bar!"}])
    assert (
        "Value 'bar!' at 'tags.1.member.value' failed to satisfy "
        "constraint: Member must satisfy regular expression pattern") in errors

    # Key too long.
    errors = svc.validate_tags([{
        "Key": ("12345678901234567890123456789012345678901234567890"
                "12345678901234567890123456789012345678901234567890"
                "123456789012345678901234567890"),
        "Value": "foo",
    }])
    assert ("at 'tags.1.member.key' failed to satisfy constraint: Member must "
            "have length less than or equal to 128") in errors

    # Value too long.
    errors = svc.validate_tags([
        {
            "Key": "foo",
            "Value": ("12345678901234567890123456789012345678901234567890"
                      "12345678901234567890123456789012345678901234567890"
                      "12345678901234567890123456789012345678901234567890"
                      "12345678901234567890123456789012345678901234567890"
                      "12345678901234567890123456789012345678901234567890"
                      "1234567890"),
        },
    ])
    assert ("at 'tags.1.member.value' failed to satisfy constraint: Member "
            "must have length less than or equal to 256") in errors

    # Compound errors.
    errors = svc.validate_tags([
        {
            "Key": ("12345678901234567890123456789012345678901234567890"
                    "12345678901234567890123456789012345678901234567890"
                    "123456789012345678901234567890"),
            "Value": ("12345678901234567890123456789012345678901234567890"
                      "12345678901234567890123456789012345678901234567890"
                      "12345678901234567890123456789012345678901234567890"
                      "12345678901234567890123456789012345678901234567890"
                      "12345678901234567890123456789012345678901234567890"
                      "1234567890"),
        },
        {"Key": "foo!", "Value": "bar!"},
    ])
    assert "4 validation errors detected" in errors
    assert ("at 'tags.1.member.key' failed to satisfy constraint: Member must "
            "have length less than or equal to 128") in errors
    assert ("at 'tags.1.member.value' failed to satisfy constraint: Member "
            "must have length less than or equal to 256") in errors
    assert (
        "Value 'foo!' at 'tags.2.member.key' failed to satisfy constraint: "
        "Member must satisfy regular expression pattern") in errors
    assert (
        "Value 'bar!' at 'tags.2.member.value' failed to satisfy "
        "constraint: Member must satisfy regular expression pattern") in errors